"""Callback for Weights and Biases."""
from collections import defaultdict
from typing import Dict, Generator, List, Optional, Tuple, cast
import matplotlib
import numpy as np
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning.loggers import WandbLogger
from torch.utils.data import DataLoader
from foxai.array_utils import convert_standardized_float_to_uint8, standardize_array
from foxai.context_manager import ExplainerWithParams, FoXaiExplainer
from foxai.visualizer import mean_channels_visualization
AttributeMapType = Dict[str, List[np.ndarray]]
CaptionMapType = Dict[str, List[str]]
FigureMapType = Dict[str, List[matplotlib.pyplot.Figure]]


class WandBCallback(pl.callbacks.Callback):
"""Library callback for Weights and Biases."""

    def __init__( # pylint: disable = (too-many-arguments)
self,
wandb_logger: WandbLogger,
explainers: List[ExplainerWithParams],
idx_to_label: Dict[int, str],
max_artifacts: int = 3,
):
"""Initialize Callback class.
Args:
wandb_logger: Pytorch-lightning wandb logger.
idx_to_label: Index to label mapping.
explainers: List of explainer algorithms of type ExplainerWithParams.
idx_to_label: Dictionary with mapping from model index to label.
max_artifacts: Number of maximum number of artifacts to be logged.
Defaults to 3.
"""
super().__init__()
self.explainers = explainers
self.wandb_logger = wandb_logger
self.idx_to_label = idx_to_label
self.max_artifacts = max_artifacts

    def _save_idx_mapping(self) -> None:
"""Saving index to label mapping to experiment logs directory."""
self.wandb_logger.log_table(
key="idx2label",
columns=["index", "label"],
data=[[key, val] for key, val in self.idx_to_label.items()],
)

    def iterate_dataloader(
self, dataloader_list: List[DataLoader], max_items: int
) -> Generator[Tuple[torch.Tensor, torch.Tensor], None, None]:
"""Iterate over dataloader list with constraint on max items returned.
Args:
dataloader: Trainer dataloader.
max_items: Max items to return.
Yields:
Tuple containing training sample and corresponding label.
"""
index: int = 0
dataloader: DataLoader
item: torch.Tensor
target_label: torch.Tensor
for dataloader in dataloader_list:
for batch in dataloader:
for item, target_label in zip(*batch):
if index >= max_items:
break
index += 1
yield item, target_label

    def explain( # pylint: disable = (too-many-arguments)
self,
model: pl.LightningModule,
item: torch.Tensor,
target_label: torch.Tensor,
attributes_dict: AttributeMapType,
caption_dict: CaptionMapType,
figures_dict: FigureMapType,
    ) -> Tuple[AttributeMapType, CaptionMapType, FigureMapType]:
        """Calculate explainer attributes and create captions and figures.

        Args:
            model: Model to explain.
            item: Input data sample tensor.
            target_label: Sample label.
            attributes_dict: Mapping from explainer name to attributes for every sample.
            caption_dict: Mapping from explainer name to captions for every sample.
            figures_dict: Mapping from explainer name to figures for every sample.

        Returns:
            Tuple of maps containing attributes, captions and figures for
            every explainer and sample.
        """
with FoXaiExplainer(
model=model,
explainers=self.explainers,
target=int(target_label.item()),
) as xai_model:
_, attributes = xai_model(item.to(cast(torch.device, model.device)))
        for explainer in self.explainers:
            explainer_name: str = explainer.explainer_name.name
            explainer_attributes: torch.Tensor = attributes[explainer_name]

            caption_dict[explainer_name].append(f"label: {target_label.item()}")

            # visualize attributes averaged over channels on top of the input sample
            figure = mean_channels_visualization(
                attributions=explainer_attributes,
                transformed_img=item,
            )
            figures_dict[explainer_name].append(figure)

            # standardize attributes and convert them to uint8 arrays for image logging
            standardized_attr = standardize_array(
                explainer_attributes.detach().cpu().numpy().astype(float)
            )
            attributes_dict[explainer_name].append(
                convert_standardized_float_to_uint8(standardized_attr),
            )

        return attributes_dict, caption_dict, figures_dict

    def on_train_start(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule, # pylint: disable = (unused-argument)
) -> None:
"""Save index to labels mapping and validation samples to experiment
at `fit`.
Args:
trainer: Trainer object.
pl_module: Model to explain.
"""
if trainer.val_dataloaders is None:
return
self._save_idx_mapping()
image_matrix: Optional[torch.Tensor] = None
image_labels: List[str] = []
for item, target_label in self.iterate_dataloader(
dataloader_list=trainer.val_dataloaders, max_items=self.max_artifacts
):
if image_matrix is None:
image_matrix = item
else:
image_matrix = torch.cat( # pylint: disable = (no-member)
[image_matrix, item]
)
image_labels.append(f"label: {target_label.item()}")
if image_matrix is None:
return
list_of_images: List[torch.Tensor] = list(torch.split(image_matrix, 1))
self.wandb_logger.log_image(
key="validation_data",
images=list_of_images[: min(len(list_of_images), self.max_artifacts)],
caption=image_labels[: min(len(image_labels), self.max_artifacts)],
)

    def on_validation_epoch_end( # pylint: disable = (too-many-arguments, too-many-locals)
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
) -> None:
"""Export model's state dict in log directory on validation epoch end.
Args:
trainer: Trainer object.
pl_module: Model to explain.
"""
if trainer.val_dataloaders is None:
return
attributes_dict: AttributeMapType = defaultdict(list)
caption_dict: CaptionMapType = defaultdict(list)
figures_dict: FigureMapType = defaultdict(list)
for item, target_label in self.iterate_dataloader(
dataloader_list=trainer.val_dataloaders,
max_items=self.max_artifacts,
):
attributes_dict, caption_dict, figures_dict = self.explain(
model=pl_module,
item=item,
target_label=target_label,
attributes_dict=attributes_dict,
caption_dict=caption_dict,
figures_dict=figures_dict,
)
self.log_explanations(
attributes_dict=attributes_dict,
caption_dict=caption_dict,
figures_dict=figures_dict,
)

    def log_explanations(
self,
attributes_dict: AttributeMapType,
caption_dict: CaptionMapType,
figures_dict: FigureMapType,
) -> None:
"""Log explanation artifacts to W&B experiment.
Args:
attributes_dict: Numpy array attributes for every sample and every explainer.
caption_dict: Caption for every sample and every explainer.
figures_dict: Figure with attributes for every sample and every explainer.
"""
# upload artifacts to the wandb experiment
for explainer in self.explainers:
explainer_name: str = explainer.explainer_name.name
self.wandb_logger.log_image(
key=f"{explainer_name}",
images=attributes_dict[explainer_name],
caption=caption_dict[explainer_name],
)
            # matplotlib Figures cannot be logged directly via WandbLogger;
            # we have to use the native Run object from wandb, which is more powerful
wandb_image_list: List[wandb.Image] = []
for figure in figures_dict[explainer_name]:
wandb_image_list.append(wandb.Image(figure))
self.wandb_logger.experiment.log(
{f"{explainer_name}_explanations": wandb_image_list}
)
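

# Usage sketch (kept as a comment so it is not executed on import): how this callback
# could be wired into a pytorch_lightning.Trainer. The explainer enum import and member
# (CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER), the project name, and the
# `model` / `datamodule` objects are illustrative assumptions, not guaranteed API; only
# ExplainerWithParams.explainer_name is taken from this module's code. Check
# foxai.context_manager for the explainer names actually available.
#
#     from pytorch_lightning import Trainer
#     from pytorch_lightning.loggers import WandbLogger
#     from foxai.context_manager import CVClassificationExplainers, ExplainerWithParams
#
#     wandb_logger = WandbLogger(project="my-project")
#     callback = WandBCallback(
#         wandb_logger=wandb_logger,
#         explainers=[
#             ExplainerWithParams(
#                 explainer_name=CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER
#             ),
#         ],
#         idx_to_label={0: "cat", 1: "dog"},
#         max_artifacts=3,
#     )
#     trainer = Trainer(logger=wandb_logger, callbacks=[callback], max_epochs=5)
#     trainer.fit(model, datamodule=datamodule)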