from typing import List, Tuple
import cv2
import matplotlib
import numpy as np
import torch
import torchvision
from matplotlib.pyplot import Figure
from foxai.array_utils import (
convert_standardized_float_to_uint8,
normalize_attributes,
resize_attributes,
retain_only_positive,
standardize_array,
transpose_array,
)
from foxai.explainer.computer_vision.object_detection.types import ObjectDetectionOutput
def _preprocess_img_and_attributes(
    attributes_matrix: np.ndarray,
    transformed_img_np: np.ndarray,
    only_positive_attr: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    """Pre-process image and attributes matrices for heatmap display.

    Pre-processing consists of:
        - normalizing the attributes with ``normalize_attributes`` (described
          as squashing the color dimension by averaging over channels),
        - optionally retaining only positive attributes,
        - resizing the attributes heatmap to the image's height and width,
        - standardizing both arrays with ``standardize_array`` (expected to
          scale values into the range [0, 1]),
        - converting the standardized image to uint8 and reordering its axes
          with ``transpose_array`` (intended: (C x H x W) -> (H x W x C) for
          matplotlib ``imshow``).

    Args:
        attributes_matrix: Attributions array computed for the image.
        transformed_img_np: Image array. It must have at least 3 dimensions,
            because its height and width are read from ``shape[1]`` and
            ``shape[2]``; presumably laid out as (C x H x W).
        only_positive_attr: Whether to keep only positive attributes
            (negative values are removed by ``retain_only_positive``).

    Returns:
        Tuple of (pre-processed attributes, pre-processed image).
    """
    single_channel_attributes: np.ndarray = normalize_attributes(
        attributes=attributes_matrix,
    )
    if only_positive_attr:
        single_channel_attributes = retain_only_positive(
            array=single_channel_attributes
        )
    # match the heatmap to the image's spatial size (dims 1 and 2 of C x H x W)
    resized_attributes: np.ndarray = resize_attributes(
        attributes=single_channel_attributes,
        dest_height=transformed_img_np.shape[1],
        dest_width=transformed_img_np.shape[2],
    )
    # standardize attributes (kept as floats for use as a heatmap)
    grayscale_attributes = standardize_array(resized_attributes)
    # standardize image; cast to float first so integer images are handled
    standardized_img = standardize_array(transformed_img_np.astype(np.dtype(float)))
    # convert to uint8 and move channels last, as matplotlib imshow expects
    normalized_transformed_img = transpose_array(
        convert_standardized_float_to_uint8(standardized_img),
    )
    return grayscale_attributes, normalized_transformed_img
def mean_channels_visualization(
    attributions: torch.Tensor,
    transformed_img: torch.Tensor,
    title: str = "",
    figsize: Tuple[int, int] = (8, 8),
    alpha: float = 0.5,
    only_positive_attr: bool = True,
) -> Figure:
    """Create a figure with the mean-over-channels attribution heatmap.

    Args:
        attributions: Attribution tensor computed for the image.
        transformed_img: Image tensor. It must be at least 3-dimensional,
            because pre-processing reads height and width from dimensions 1
            and 2; presumably laid out as (C x H x W).
        title: Title of the figure. Defaults to "".
        figsize: Tuple with size of figure. Defaults to (8, 8).
        alpha: Opacity level of the heatmap. Defaults to 0.5.
        only_positive_attr: Whether to display only positive or all
            attributes. Defaults to True.

    Returns:
        Figure built by ``generate_figure`` with the heatmap overlaid on the
        original image.
    """
    # move both tensors off the autograd graph and onto the host as numpy
    attributes_matrix: np.ndarray = attributions.detach().cpu().numpy()
    transformed_img_np: np.ndarray = transformed_img.detach().cpu().numpy()
    grayscale_attributes, normalized_transformed_img = _preprocess_img_and_attributes(
        attributes_matrix=attributes_matrix,
        transformed_img_np=transformed_img_np,
        only_positive_attr=only_positive_attr,
    )
    return generate_figure(
        attributions=grayscale_attributes,
        transformed_img=normalized_transformed_img,
        title=title,
        figsize=figsize,
        alpha=alpha,
    )
def single_channel_visualization(
    attributions: torch.Tensor,
    transformed_img: torch.Tensor,
    selected_channel: int,
    title: str = "",
    figsize: Tuple[int, int] = (8, 8),
    alpha: float = 0.5,
    only_positive_attr: bool = True,
) -> Figure:
    """Create a figure with the attribution heatmap of one channel.

    Args:
        attributions: Attribution tensor computed for the image; channels are
            indexed along its first dimension.
        transformed_img: Image tensor. It must be at least 3-dimensional,
            because pre-processing reads height and width from dimensions 1
            and 2; presumably laid out as (C x H x W).
        selected_channel: Index into the first dimension of ``attributions``.
        title: Title of the figure. Defaults to "".
        figsize: Tuple with size of figure. Defaults to (8, 8).
        alpha: Opacity level of the heatmap. Defaults to 0.5.
        only_positive_attr: Whether to display only positive or all
            attributes. Defaults to True.

    Returns:
        Figure built by ``generate_figure`` with the selected channel's
        heatmap overlaid on the original image.

    Raises:
        ValueError: If ``selected_channel`` is negative or not smaller than
            ``attributions.shape[0]``.
    """
    if not 0 <= selected_channel < attributions.shape[0]:
        raise ValueError(
            f"The selected channel exceeds color dimension. Selected channel: {selected_channel}",
        )
    # slice on the tensor first so only the selected channel is copied to host
    attributes_matrix: np.ndarray = (
        attributions[selected_channel].detach().cpu().numpy()
    )
    transformed_img_np: np.ndarray = transformed_img.detach().cpu().numpy()
    grayscale_attributes, normalized_transformed_img = _preprocess_img_and_attributes(
        attributes_matrix=attributes_matrix,
        transformed_img_np=transformed_img_np,
        only_positive_attr=only_positive_attr,
    )
    return generate_figure(
        attributions=grayscale_attributes,
        transformed_img=normalized_transformed_img,
        title=title,
        figsize=figsize,
        alpha=alpha,
    )
[docs]def preprocess_object_detection_image(input_image: torch.Tensor) -> np.ndarray:
"""Process input image to display.
Args:
input_image: Original image of type float in range [0-1].
Returns:
Converted image as np.ndarray in (C x H x W).
"""
return (
input_image.squeeze(0)
.mul(255)
.clamp_(0, 255)
.permute(1, 2, 0)
.detach()
.cpu()
.numpy()
)
def get_heatmap_bbox(
    heatmap: np.ndarray,
    bbox: List[int],
    mask_value: int = 0,
) -> np.ndarray:
    """Keep the heatmap inside a bounding box and scale it elsewhere.

    Every element outside the box is multiplied by ``mask_value`` (with the
    default 0 the outside is zeroed). Elements inside the box are kept as is.
    Code based on https://github.com/pooya-mohammadi/deep_utils/blob/main/deep_utils/utils/box_utils/boxes.py.

    Args:
        heatmap: Heatmap array; rows are indexed along axis 0 and columns
            along axis 1.
        bbox: Box as ``[row_start, col_start, row_end, col_end]``; ends are
            exclusive. Negative coordinates are clamped to 0 so that boxes
            extending past the top/left image edge are still masked
            correctly instead of wrapping around.
        mask_value: Multiplier applied outside the box. Defaults to 0.

    Returns:
        Array of the same shape and dtype as ``heatmap``, with the heatmap
        only present in the area of the given bounding box.
    """
    # Clamp: negative slice bounds would index from the end of the array.
    row_start, col_start, row_end, col_end = (max(0, coord) for coord in bbox[:4])
    # cv2.multiply requires both operands to share a dtype, so match the heatmap
    mask = np.ones_like(heatmap, dtype=heatmap.dtype) * mask_value
    mask[row_start:row_end, col_start:col_end] = 1
    return cv2.multiply(heatmap, mask)
def draw_heatmap_in_bbox(
    bbox: List[int],
    heatmap: torch.Tensor,
    img: np.ndarray,
) -> np.ndarray:
    """Draw a color-mapped heatmap inside a bounding box on an image.

    Processing steps:
        - Convert the heatmap with ``preprocess_object_detection_image``,
          which scales by 255 and clamps.
        - Cast it to uint8 and colorize it with OpenCV's JET colormap.
        - Zero it outside ``bbox``.
        - Add it to ``img``.
        - Rescale the sum so its maximum becomes 255.

    Args:
        bbox: Box as ``[row_start, col_start, row_end, col_end]`` (see
            ``get_heatmap_bbox``).
        heatmap: Heatmap tensor, presumably a single-channel saliency map of
            shape (1 x 1 x H x W) or (1 x H x W) with values in [0, 1].
        img: Original image; it is added with ``cv2.add`` to a float32
            (H x W x 3) array, so it is expected to be float32 with the same
            shape.

    Returns:
        ``uint8`` image with the heatmap displayed in the bounding box area.
    """
    heatmap_np = preprocess_object_detection_image(heatmap).astype(np.uint8)
    heatmap_np = cv2.applyColorMap(heatmap_np, cv2.COLORMAP_JET)
    masked_heatmap = get_heatmap_bbox(heatmap=heatmap_np, bbox=bbox).astype(np.float32)
    blended = cv2.add(img, masked_heatmap)
    peak = blended.max()
    # Guard against 0/0 (NaN) when the blended image is entirely zero.
    if peak != 0:
        blended = blended / peak
    return (blended * 255).astype(np.uint8)
def concat_images(images: List[np.ndarray]) -> np.ndarray:
    """Concatenate images side by side into a single image.

    Images are placed left to right, in list order, on a 3-channel ``uint8``
    canvas. Every image must have the same height as the first one; each
    image keeps its own width. Values are cast to ``uint8`` when written to
    the canvas (floats are truncated).

    Args:
        images: Non-empty list of images in (H x W x C) layout, where C is 3
            (a single channel is broadcast to all 3).

    Returns:
        ``uint8`` image of shape (H x total_width x 3).
    """
    height = images[0].shape[0]
    total_width = sum(img.shape[1] for img in images)
    canvas = np.zeros((height, total_width, 3), dtype=np.uint8)
    offset = 0
    for img in images:
        width = img.shape[1]
        canvas[:, offset : offset + width, ...] = img
        offset += width
    return canvas
def object_detection_visualization(
    detections: ObjectDetectionOutput,
    input_image: torch.Tensor,
) -> np.ndarray:
    """Create an image strip with one heatmap panel per detection.

    The first panel is the input image. Each following panel shows that
    image with one saliency map drawn inside its detection's bounding box,
    plus the box and class label.

    Args:
        detections: Object detection data class; ``saliency_maps`` are paired
            by position with ``predictions``.
        input_image: Image tensor with values in [0, 1], in shape
            (1 x C x H x W) or (C x H x W).

    Returns:
        ``uint8`` array with all panels concatenated horizontally.

    Raises:
        IndexError: If there are more saliency maps than predictions.
    """
    boxes = [pred.bbox for pred in detections.predictions]
    class_names = [pred.class_name for pred in detections.predictions]

    # (H x W x C); reversing the channel axis converts RGB to BGR (assuming
    # RGB input) to match the BGR output of cv2.applyColorMap.
    img_to_display = preprocess_object_detection_image(input_image)[..., ::-1]
    images = [img_to_display]
    for i, mask in enumerate(detections.saliency_maps):
        bbox = [int(val) for val in boxes[i]]
        res_img = draw_heatmap_in_bbox(bbox, mask, img_to_display.copy())
        # (H x W x C) -> (C x H x W), the layout torchvision expects
        res_img_tensor = torch.tensor(res_img).permute(2, 0, 1)
        # get_heatmap_bbox indexes rows with bbox[0]/bbox[2] and columns with
        # bbox[1]/bbox[3]; draw_bounding_boxes expects (xmin, ymin, xmax, ymax).
        xyxy = [bbox[1], bbox[0], bbox[3], bbox[2]]
        res_img_tensor = torchvision.utils.draw_bounding_boxes(
            image=res_img_tensor,
            boxes=torch.tensor([xyxy]),
            labels=[class_names[i]],
        )
        # (C x H x W) -> (H x W x C)
        images.append(res_img_tensor.permute(1, 2, 0).numpy())
    return concat_images(images)