"""File contains abstract base ObjectDetector class."""
from abc import ABC, abstractmethod
from typing import List, Tuple
import numpy as np
import torch
from torch import nn
from foxai.explainer.computer_vision.object_detection.types import PredictionOutput
from foxai.explainer.computer_vision.object_detection.utils import resize_image
class BaseObjectDetector(nn.Module, ABC):
    """Base ObjectDetector class which returns predictions with logits to explain.

    Subclasses must implement ``forward``; ``preprocessing`` is a shared static
    helper that turns a NumPy image (or batch of images) into a tensor.

    Code based on https://github.com/pooya-mohammadi/yolov5-gradcam.
    """

    @abstractmethod
    def forward(
        self,
        image: torch.Tensor,
    ) -> Tuple[List[PredictionOutput], List[torch.Tensor]]:
        """Forward pass of the network.

        Args:
            image: Image to process.

        Returns:
            Tuple of 2 values, first is list of predictions containing
            bounding-boxes, class number, class name and confidence; second
            value is list of tensors with logits per each detection.
        """

    @staticmethod
    def preprocessing(
        img: np.ndarray,
        new_shape: Tuple[int, int] = (640, 640),
        change_original_ratio: bool = False,
        scaleup: bool = True,
    ) -> torch.Tensor:
        """Preprocess image before prediction.

        Preprocessing is a process consisting of steps:
            * adding batch dimension
            * resizing images to desired shapes
            * adjusting image channels to (B x C x H x W)
            * conversion to float

        Args:
            img: Image to preprocess. An array that is not 4-dimensional gets
                a leading batch axis; a 4-dimensional array is treated as an
                already batched (B x H x W x C) input.
            new_shape: Desired shape of image. Defaults to (640, 640).
            change_original_ratio: If resized image should have different
                height to width ratio than original image. Defaults to False.
            scaleup: If scale up image. Defaults to True.

        Returns:
            Tensor containing preprocessed image.
        """
        if len(img.shape) != 4:
            # add batch dimension
            img = np.expand_dims(img, axis=0)

        # resize all images from batch; every argument is forwarded unchanged
        # to ``resize_image``, one call per image
        img = np.array(
            [
                resize_image(
                    image=im,
                    new_shape=new_shape,
                    change_original_ratio=change_original_ratio,
                    scaleup=scaleup,
                )
                for im in img
            ]
        )

        # convert array from (B x H x W x C) to (B x C x H x W); this requires
        # the resized batch to have exactly four axes
        img = img.transpose((0, 3, 1, 2))
        img_tensor = torch.from_numpy(img)

        # convert from uint8 to float; the division returns a new tensor, so
        # uint8 pixel values end up in the [0, 1] range
        img_tensor = img_tensor / 255.0
        return img_tensor