# Source code for foxai.explainer.computer_vision.object_detection.base_object_detector

"""File contains abstract base ObjectDetector class."""
from abc import ABC, abstractmethod
from typing import List, Tuple

import numpy as np
import torch
from torch import nn

from foxai.explainer.computer_vision.object_detection.types import PredictionOutput
from foxai.explainer.computer_vision.object_detection.utils import resize_image


class BaseObjectDetector(nn.Module, ABC):
    """Abstract object detector exposing per-detection logits for explanation.

    Concrete detectors implement :meth:`forward`, which returns both the
    decoded predictions and the logits that explainers attribute against.
    Adapted from https://github.com/pooya-mohammadi/yolov5-gradcam.
    """

    @abstractmethod
    def forward(
        self,
        image: torch.Tensor,
    ) -> Tuple[List[PredictionOutput], List[torch.Tensor]]:
        """Run the detector on ``image``.

        Args:
            image: Image tensor to run detection on.

        Returns:
            A 2-tuple ``(predictions, logits)``. ``predictions`` is a list of
            ``PredictionOutput`` entries (bounding box, class number, class
            name and confidence); ``logits`` is a list of tensors holding the
            logits of each detection.
        """

    @staticmethod
    def preprocessing(
        img: np.ndarray,
        new_shape: Tuple[int, int] = (640, 640),
        change_original_ratio: bool = False,
        scaleup: bool = True,
    ) -> torch.Tensor:
        """Turn a raw image (or a batch of images) into a model-ready tensor.

        Steps performed:
            * prepend a batch axis when ``img`` is not already 4-dimensional,
            * resize every image of the batch with ``resize_image``,
            * reorder axes from (B x H x W x C) to (B x C x H x W),
            * divide by 255.0, which yields a floating-point tensor.

        Args:
            img: Single channels-last image (H x W x C) or batch of such
                images (B x H x W x C).
            new_shape: Target shape handed to ``resize_image``.
                Defaults to (640, 640).
            change_original_ratio: Whether the resized image may have a
                height-to-width ratio different from the original.
                Defaults to False.
            scaleup: Whether images may be scaled up when resized.
                Defaults to True.

        Returns:
            Preprocessed image tensor laid out as (B x C x H x W).
        """
        # Anything that is not already 4-D is treated as one image.
        batch = img if img.ndim == 4 else np.expand_dims(img, axis=0)

        resized_frames = [
            resize_image(
                image=frame,
                new_shape=new_shape,
                change_original_ratio=change_original_ratio,
                scaleup=scaleup,
            )
            for frame in batch
        ]

        # Channels-last (B x H x W x C) -> channels-first (B x C x H x W).
        channels_first = np.transpose(np.array(resized_frames), (0, 3, 1, 2))

        # Presumably 0-255 pixel values (e.g. uint8); dividing by a float
        # rescales them and produces a floating-point tensor.
        return torch.from_numpy(channels_first) / 255.0