from typing import List, Tuple
import cv2
import matplotlib
import numpy as np
import torch
import torchvision
from matplotlib.pyplot import Figure
from foxai.array_utils import (
    convert_standardized_float_to_uint8,
    normalize_attributes,
    resize_attributes,
    retain_only_positive,
    standardize_array,
    transpose_array,
)
from foxai.explainer.computer_vision.object_detection.types import ObjectDetectionOutput
def _preprocess_img_and_attributes(
    attributes_matrix: np.ndarray,
    transformed_img_np: np.ndarray,
    only_positive_attr: bool,
) -> Tuple[np.ndarray, np.ndarray]:
    """Pre-process image and attributes matrices for display.

    Pre-processing consists of:
        - squashing the attributes matrix to a single channel with
          ``normalize_attributes`` (mean over all color channels)
        - optionally retaining only positive attributes
        - resizing the attributes heatmap to match the image height and width
        - standardizing attributes and image to the value range [0-1]
        - converting the image to uint8 and transposing it from
          (C x H x W) to (H x W x C)

    Args:
        attributes_matrix: Attribution values computed for the image.
        transformed_img_np: Image in shape (C x H x W); the target height and
            width of the heatmap are read from axes 1 and 2.
        only_positive_attr: Whether to keep only positive attributes
            (``True``) or all attributes (``False``).

    Returns:
        Tuple of pre-processed (attributes, image) matrices.
    """
    single_channel_attributes: np.ndarray = normalize_attributes(
        attributes=attributes_matrix,
    )
    if only_positive_attr:
        single_channel_attributes = retain_only_positive(
            array=single_channel_attributes
        )
    # heatmap must have the same spatial size as the (C x H x W) image
    resized_attributes: np.ndarray = resize_attributes(
        attributes=single_channel_attributes,
        dest_height=transformed_img_np.shape[1],
        dest_width=transformed_img_np.shape[2],
    )
    # rescale attributes to the range [0-1] (see standardize_array)
    grayscale_attributes = standardize_array(resized_attributes)
    # standardize image as float before converting it to uint8
    standardized_img = standardize_array(transformed_img_np.astype(np.dtype(float)))
    # transpose image from (C x H x W) to (H x W x C) for matplotlib imshow
    normalized_transformed_img = transpose_array(
        convert_standardized_float_to_uint8(standardized_img),
    )
    return grayscale_attributes, normalized_transformed_img
def mean_channels_visualization(
    attributions: torch.Tensor,
    transformed_img: torch.Tensor,
    title: str = "",
    figsize: Tuple[int, int] = (8, 8),
    alpha: float = 0.5,
    only_positive_attr: bool = True,
) -> Figure:
    """Overlay a heatmap of channel-averaged attributions on an image.

    Args:
        attributions: Attribution tensor computed for ``transformed_img``.
        transformed_img: Image in shape (C x H x W).
        title: Title of the figure. Defaults to "".
        figsize: Size of the figure. Defaults to (8, 8).
        alpha: Opacity of the heatmap overlay. Defaults to 0.5.
        only_positive_attr: Whether to display only positive or all attributes.
            Defaults to True.

    Returns:
        Figure with the heatmap of mean channel values drawn over the image.
    """
    # Leave the autograd graph and the device before switching to NumPy;
    # the generator is consumed in order, attributions first.
    attributes_np, image_np = (
        tensor.detach().cpu().numpy() for tensor in (attributions, transformed_img)
    )
    heatmap, display_img = _preprocess_img_and_attributes(
        attributes_matrix=attributes_np,
        transformed_img_np=image_np,
        only_positive_attr=only_positive_attr,
    )
    return generate_figure(
        attributions=heatmap,
        transformed_img=display_img,
        title=title,
        figsize=figsize,
        alpha=alpha,
    )
def single_channel_visualization(
    attributions: torch.Tensor,
    transformed_img: torch.Tensor,
    selected_channel: int,
    title: str = "",
    figsize: Tuple[int, int] = (8, 8),
    alpha: float = 0.5,
    only_positive_attr: bool = True,
) -> Figure:
    """Create image with the heatmap of a single attribution channel.

    Args:
        attributions: Features; axis 0 is the color-channel axis.
        transformed_img: Image in shape (C x H x W).
        selected_channel: Index of the color channel to visualize, in range
            [0, attributions.shape[0]).
        title: Title of the figure. Defaults to "".
        figsize: Tuple with size of figure. Defaults to (8, 8).
        alpha: Opacity level of the heatmap. Defaults to 0.5.
        only_positive_attr: Whether to display only positive or all attributes.
            Defaults to True.

    Returns:
        Heatmap of single channel applied on original image.

    Raises:
        ValueError: if selected channel is negative or not smaller than the
            number of color channels of attributes.
    """
    num_channels = attributions.shape[0]
    if not 0 <= selected_channel < num_channels:
        # The message states the valid range, which covers both negative
        # and too-large indices.
        raise ValueError(
            f"The selected channel must be in range [0, {num_channels}). "
            f"Selected channel: {selected_channel}",
        )
    # Index on the tensor so only the selected slice is transferred when the
    # tensor lives on another device.
    attributes_matrix: np.ndarray = (
        attributions[selected_channel].detach().cpu().numpy()
    )
    transformed_img_np: np.ndarray = transformed_img.detach().cpu().numpy()
    grayscale_attributes, normalized_transformed_img = _preprocess_img_and_attributes(
        attributes_matrix=attributes_matrix,
        transformed_img_np=transformed_img_np,
        only_positive_attr=only_positive_attr,
    )
    return generate_figure(
        attributions=grayscale_attributes,
        transformed_img=normalized_transformed_img,
        title=title,
        figsize=figsize,
        alpha=alpha,
    )
def preprocess_object_detection_image(input_image: torch.Tensor) -> np.ndarray:
    """Convert an image tensor into a NumPy array ready for display.

    A leading axis of size 1 is removed with ``squeeze(0)``, values are scaled
    from [0-1] to [0-255] and clamped, and the channel axis is moved last.
    After the squeeze the tensor must be 3-D (C x H x W).

    Args:
        input_image: Original image of type float in range [0-1], in shape
            (1 x C x H x W) or (C x H x W).

    Returns:
        Converted image as np.ndarray in (H x W x C). The dtype of the input
        is kept; no cast to uint8 happens here.
    """
    # Detach first so the arithmetic below is not recorded by autograd.
    image = input_image.detach().squeeze(0)
    # mul() allocates a new tensor, so the input is never modified.
    return image.mul(255).clamp(0, 255).permute(1, 2, 0).cpu().numpy()
def get_heatmap_bbox(
    heatmap: np.ndarray,
    bbox: List[int],
    mask_value: int = 0,
) -> np.ndarray:
    """Keep the heatmap only inside the given bounding box.

    Elements inside the box are multiplied by 1 and elements outside by
    ``mask_value`` using ``cv2.multiply``.
    Code based on https://github.com/pooya-mohammadi/deep_utils/blob/main/deep_utils/utils/box_utils/boxes.py.

    Args:
        heatmap: Heatmap to visualize.
        bbox: Bounding box of detection. ``bbox[0]:bbox[2]`` selects axis 0
            (rows) and ``bbox[1]:bbox[3]`` selects axis 1 (columns) of the
            heatmap. Negative coordinates are clamped to 0.
        mask_value: Multiplier applied outside the box. Defaults to 0.

    Returns:
        Numpy array with heatmap only present in area of given bounding box.
    """
    # Clamp at 0: a negative slice bound would wrap around to the far end
    # of the axis and mask the wrong region (e.g. a box starting at -3).
    # Bounds past the array end are already handled by slicing.
    row_start, col_start, row_stop, col_stop = (
        max(int(coord), 0) for coord in (bbox[0], bbox[1], bbox[2], bbox[3])
    )
    # fill the outer area of the selected box
    mask = np.full_like(heatmap, mask_value, dtype=np.uint8)
    mask[row_start:row_stop, col_start:col_stop] = 1
    masked_heatmap = cv2.multiply(heatmap, mask)
    return masked_heatmap
def draw_heatmap_in_bbox(
    bbox: List[int],
    heatmap: torch.Tensor,
    img: np.ndarray,
) -> np.ndarray:
    """Draw a color-mapped heatmap inside a bounding box on an image.

    The heatmap is converted with ``preprocess_object_detection_image``,
    colored with OpenCV's JET colormap, masked to ``bbox`` with
    ``get_heatmap_bbox`` and added to ``img``. The sum is then rescaled by its
    global maximum to the range [0-255].

    Args:
        bbox: List of coordinates for bounding box (same convention as
            ``get_heatmap_bbox``).
        heatmap: Heatmap to display, presumably float values in range [0-1].
        img: Original image. Presumably float32 (H x W x 3): ``cv2.add`` is
            called without an explicit ``dtype``, so it must match the float32
            masked heatmap.

    Returns:
        uint8 image with the heatmap displayed in the bounding box area. If the
        blended image has no positive value, it is not rescaled, which avoids
        a 0 / 0 division.
    """
    heatmap_np = preprocess_object_detection_image(heatmap).astype(np.uint8)
    heatmap_np = cv2.applyColorMap(heatmap_np, cv2.COLORMAP_JET)
    masked_heatmap = get_heatmap_bbox(heatmap=heatmap_np, bbox=bbox).astype(np.float32)
    blended = cv2.add(img, masked_heatmap)
    # Normalize by the global maximum. Guard against a zero maximum:
    # dividing would yield NaNs that become garbage after the uint8 cast.
    max_value = blended.max()
    if max_value > 0:
        blended = blended / max_value
    return (blended * 255).astype(np.uint8)
def concat_images(images: List[np.ndarray]) -> np.ndarray:
    """Concatenate images horizontally (side by side) into one image.

    Each image is written into a pre-allocated uint8 canvas of shape
    (H x (W * len(images)) x 3), where H and W are taken from the first image.
    Values are therefore cast to uint8 on assignment. Every image must
    broadcast to (H x W x 3).

    Args:
        images: Non-empty list of images to merge, in (H x W x C) layout.

    Returns:
        Final uint8 image of shape (H x (W * len(images)) x 3).
    """
    # shape[:2] of an (H x W x C) array is (height, width)
    height, width = images[0].shape[:2]
    canvas = np.zeros((height, width * len(images), 3), dtype=np.uint8)
    for index, img in enumerate(images):
        canvas[:, width * index : width * (index + 1), ...] = img
    return canvas
def object_detection_visualization(
    detections: ObjectDetectionOutput,
    input_image: torch.Tensor,
) -> np.ndarray:
    """Create array with detection heatmaps.

    The first tile is the input image. Each following tile shows the input
    image with one saliency map drawn inside its detection bounding box,
    plus the box outline and class label. Tiles are concatenated side by
    side with ``concat_images``.

    Args:
        detections: Object detection data class providing ``saliency_maps``
            and ``predictions`` (each prediction exposing ``bbox`` and
            ``class_name``).
        input_image: Image of type float in range [0-1], in shape
            (1 x C x H x W) or (C x H x W).

    Returns:
        uint8 array (channels in reversed, i.e. BGR, order) with the series of
        images with heatmap displayed on detection bounding boxes.
    """
    masks = detections.saliency_maps
    boxes = [pred.bbox for pred in detections.predictions]
    class_names = [pred.class_name for pred in detections.predictions]
    img_to_display = preprocess_object_detection_image(input_image)
    # reverse channel order (RGB -> BGR) for the OpenCV-based drawing
    img_to_display = img_to_display[..., ::-1]
    images = [img_to_display]
    for index, mask in enumerate(masks):
        bbox = [int(val) for val in boxes[index]]
        cls_name = class_names[index]
        res_img = draw_heatmap_in_bbox(bbox, mask, img_to_display.copy())
        # (H x W x C) -> (C x H x W) as expected by torchvision
        res_img_tensor = torch.tensor(res_img).permute(2, 0, 1)
        # get_heatmap_bbox uses bbox[0]/bbox[2] on axis 0 (rows) and
        # bbox[1]/bbox[3] on axis 1 (columns); draw_bounding_boxes expects
        # (xmin, ymin, xmax, ymax), i.e. column coordinates first.
        box_xyxy = [bbox[1], bbox[0], bbox[3], bbox[2]]
        res_img_tensor = torchvision.utils.draw_bounding_boxes(
            image=res_img_tensor,
            boxes=torch.tensor([box_xyxy]),
            labels=[cls_name],
        )
        # (C x H x W) -> (H x W x C)
        images.append(res_img_tensor.permute(1, 2, 0).numpy())
    return concat_images(images)