"""
Run xai alongside with inference.
Example:
with FoXaiExplainer(
model=classifier,
explainers=[
ExplainerWithParams(
explainer_name=Explainers.CV_GRADIENT_SHAP_EXPLAINER,
n_samples=100,
stdevs=0.0005,
),
],
target=pred_label_idx,
) as xai_model:
output, xai_explanations = xai_model(img_tensor)
"""
import logging
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, Tuple, Union, cast
import torch
from foxai import explainer
from foxai.explainer import (
DeconvolutionCVExplainer,
DeepLIFTCVExplainer,
DeepLIFTSHAPCVExplainer,
GradientSHAPCVExplainer,
GuidedBackpropCVExplainer,
GuidedGradCAMCVExplainer,
InputXGradientCVExplainer,
IntegratedGradientsCVExplainer,
LayerConductanceCVExplainer,
LayerDeepLIFTCVExplainer,
LayerDeepLIFTSHAPCVExplainer,
LayerGradCAMCVExplainer,
LayerGradCAMObjectDetectionExplainer,
LayerGradientSHAPCVExplainer,
LayerInputXGradientCVExplainer,
LayerIntegratedGradientsCVExplainer,
LayerLRPCVExplainer,
LayerNoiseTunnelCVExplainer,
LRPCVExplainer,
NoiseTunnelCVExplainer,
OcclusionCVExplainer,
SaliencyCVExplainer,
)
from foxai.explainer.base_explainer import CVExplainerT
from foxai.logger import create_logger
_LOGGER: Optional[logging.Logger] = None
def log() -> logging.Logger:
"""Get or create logger."""
# pylint: disable = global-statement
global _LOGGER
if _LOGGER is None:
_LOGGER = create_logger(__name__)
return _LOGGER
class CVClassificationExplainers(Enum):
"""Enum of supported computer vision classification explainers types."""
CV_OCCLUSION_EXPLAINER: str = OcclusionCVExplainer.__name__
CV_INTEGRATED_GRADIENTS_EXPLAINER: str = IntegratedGradientsCVExplainer.__name__
CV_NOISE_TUNNEL_EXPLAINER: str = NoiseTunnelCVExplainer.__name__
CV_GRADIENT_SHAP_EXPLAINER: str = GradientSHAPCVExplainer.__name__
CV_LRP_EXPLAINER: str = LRPCVExplainer.__name__
CV_GUIDEDGRADCAM_EXPLAINER: str = GuidedGradCAMCVExplainer.__name__
CV_LAYER_INTEGRATED_GRADIENTS_EXPLAINER: str = (
LayerIntegratedGradientsCVExplainer.__name__
)
CV_LAYER_NOISE_TUNNEL_EXPLAINER: str = LayerNoiseTunnelCVExplainer.__name__
CV_LAYER_GRADIENT_SHAP_EXPLAINER: str = LayerGradientSHAPCVExplainer.__name__
CV_LAYER_LRP_EXPLAINER: str = LayerLRPCVExplainer.__name__
CV_LAYER_GRADCAM_EXPLAINER: str = LayerGradCAMCVExplainer.__name__
CV_INPUT_X_GRADIENT_EXPLAINER: str = InputXGradientCVExplainer.__name__
CV_LAYER_INPUT_X_GRADIENT_EXPLAINER: str = LayerInputXGradientCVExplainer.__name__
CV_DEEPLIFT_EXPLAINER: str = DeepLIFTCVExplainer.__name__
CV_LAYER_DEEPLIFT_EXPLAINER: str = LayerDeepLIFTCVExplainer.__name__
CV_DEEPLIFT_SHAP_EXPLAINER: str = DeepLIFTSHAPCVExplainer.__name__
CV_LAYER_DEEPLIFT_SHAP_EXPLAINER: str = LayerDeepLIFTSHAPCVExplainer.__name__
CV_DECONVOLUTION_EXPLAINER: str = DeconvolutionCVExplainer.__name__
CV_LAYER_CONDUCTANCE_EXPLAINER: str = LayerConductanceCVExplainer.__name__
CV_SALIENCY_EXPLAINER: str = SaliencyCVExplainer.__name__
CV_GUIDED_BACKPOPAGATION_EXPLAINER: str = GuidedBackpropCVExplainer.__name__
class CVObjectDetectionExplainers(Enum):
"""Enum of supported computer vision object detection explainers types."""
CV_LAYER_GRADCAM_OBJECT_DETECTION_EXPLAINER: str = (
LayerGradCAMObjectDetectionExplainer.__name__
)
@dataclass
class ExplainerWithParams:
"""Holder for explainer name (class name) and it's params"""
explainer_name: Union[CVClassificationExplainers, CVObjectDetectionExplainers]
kwargs: Dict[str, Any] = field(default_factory=dict)
def __init__(
self,
explainer_name: Union[CVClassificationExplainers, CVObjectDetectionExplainers],
**kwargs
) -> None:
self.explainer_name = explainer_name
        self.kwargs = kwargs
@dataclass
class ExplainerClassWithParams(Generic[CVExplainerT]):
"""Holder for explainer class and it's params"""
explainer_class: CVExplainerT
kwargs: Dict[str, Any] = field(default_factory=dict)
def __init__(self, explainer_class: CVExplainerT, **kwargs) -> None:
self.explainer_class = explainer_class
        self.kwargs = kwargs
class FoXaiExplainer(Generic[CVExplainerT]):
    """Context manager for FoXAI explanation.

    Example:
        with FoXaiExplainer(
            model=classifier,
            explainers=[
                ExplainerWithParams(
                    explainer_name=CVClassificationExplainers.CV_GRADIENT_SHAP_EXPLAINER,
                    n_samples=100,
                    stdevs=0.0005,
                ),
            ],
            target=pred_label_idx,
        ) as xai_model:
            output, xai_explanations = xai_model(img_tensor)

    Raises:
        ValueError: if no explainer is provided.
    """
def __init__(
self,
model: torch.nn.Module,
explainers: List[ExplainerWithParams],
target: int = 0,
) -> None:
"""
Args:
            model: the torch model to evaluate with the CV explainers.
            explainers: list of explainers (with their parameters) to use
                for model evaluation.
            target: index of the predicted class for which to generate
                the explanation.
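
        Example (a sketch; ``classifier`` and ``pred_label_idx`` are assumed
        to be provided by the caller):
            FoXaiExplainer(
                model=classifier,
                explainers=[
                    ExplainerWithParams(
                        explainer_name=CVClassificationExplainers.CV_SALIENCY_EXPLAINER,
                    ),
                    ExplainerWithParams(
                        explainer_name=CVClassificationExplainers.CV_OCCLUSION_EXPLAINER,
                    ),
                ],
                target=pred_label_idx,
            )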
"""
if not explainers:
raise ValueError("At leas one explainer should be defined.")
self.model: torch.nn.Module = model
self.prev_model_training_state: bool = self.model.training
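        # Map each requested explainer's enum member name to an instantiated
        # explainer; the enum value is the class name looked up in the
        # foxai.explainer module.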
self.explainer_map: Dict[str, ExplainerClassWithParams] = {
explainer_with_params.explainer_name.name: ExplainerClassWithParams(
explainer_class=getattr(
explainer, explainer_with_params.explainer_name.value
)(),
**explainer_with_params.kwargs,
)
for explainer_with_params in explainers
}
self.target: int = target
def __enter__(self) -> "FoXaiExplainer":
"""Verify if model is in eval() mode.
Raises:
ValueError: if the model is in training mode.
Returns:
the foxai class instance.
"""
self.prev_torch_grad = torch.is_grad_enabled()
if not self.prev_torch_grad:
log_msg: str = (
"Torch model explainer can be called only with enabled "
+ "gradients, as it depends on gradients computations. The model is going "
+ "to be toggled to gradients enabled. For the "
+ "model prediction, the gradient is temporary turned off."
)
log().warning(log_msg)
torch.set_grad_enabled(True)
if self.model.training:
self.model.eval()
log().warning(
"The model should be in the eval model. Toggling it to eval mode right now."
)
return self
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
"""If the torch was not recording gradient, before entering in the
context manager mode, switch it back to no gradient recording mode.
If the torch was recording gradient before entering in the context
manager modes, nothings changes.
Setup model to previous state: `eval` or `training` to match initial
state.
"""
torch.set_grad_enabled(self.prev_torch_grad)
self.model.train(self.prev_model_training_state)
def __call__(self, *args, **kwargs) -> Tuple[Any, Dict[str, torch.Tensor]]:
"""Run model prediction and explain the model with given explainers.
Explainers and model are defined as the class parameter.
Args:
list of arguments for the torch.nn.Module forward method.
Returns:
the model output and explanations for each requested explainer.
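
        Example (a sketch; ``xai_model`` is this object, obtained from the
        context manager, and ``img_tensor`` is a batched input prepared by
        the caller):
            output, explanations = xai_model(img_tensor)
            # keys are enum member names, e.g. "CV_GRADIENT_SHAP_EXPLAINER"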
"""
with torch.no_grad():
model_output: Any = self.model(*args, **kwargs)
if len(args) != 1:
# TODO: add support in explainer for multiple input models
raise NotImplementedError(
"calculate_features() functions "
+ "in explainers does not support multiple inputs to the model."
)
        input_tensor: torch.Tensor = cast(torch.Tensor, args[0])
        # cache the tensor's requires_grad state
prev_requires_grad: bool = input_tensor.requires_grad
# turn on requires grad for the input tensor
input_tensor.requires_grad = True
explanations: Dict[str, torch.Tensor] = {}
for explainer_name in self.explainer_map:
# zero the previous gradient for the model
self.model.zero_grad()
# run explainer
explainer_kwargs: Dict[str, Any] = self.explainer_map[explainer_name].kwargs
explainer_class: CVExplainerT = self.explainer_map[
explainer_name
].explainer_class
explanations[explainer_name] = (
explainer_class.calculate_features(
model=self.model,
input_data=input_tensor,
pred_label_idx=self.target,
**explainer_kwargs,
)
.detach()
.cpu()
)
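            # Clear gradients accumulated on the input tensor by this explainer.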
input_tensor.grad = None
# restore tensor requires grad state
input_tensor.requires_grad = prev_requires_grad
return model_output, explanations