"""File contains functions to modify DNN models."""
from typing import List, Tuple, Type

import torch
def modify_modules(model: torch.nn.Module) -> torch.nn.Module:
    """Disable in-place computation for every ReLU inside a model.

    Every module reachable through ``model.modules()`` is inspected, and each
    one that is a ``torch.nn.ReLU`` gets its ``inplace`` attribute set to
    ``False``. The model object is mutated directly and also returned.

    Args:
        model: Neural network object to be modified.

    Returns:
        The same ``model`` object, after modification.
    """
    relu_layers = (
        layer
        for layer in model.modules()  # pylint: disable = (duplicate-code)
        if isinstance(layer, torch.nn.ReLU)
    )
    for relu in relu_layers:
        relu.inplace = False
    return model
def get_last_conv_model_layer(
    model: torch.nn.Module,
    *,
    layer_types: Tuple[Type[torch.nn.Module], ...] = (torch.nn.Conv2d,),
) -> torch.nn.Module:
    """Get the last convolutional layer from the torch model.

    Modules are visited in the order yielded by ``model.modules()`` and the
    last one that is an instance of ``layer_types`` is returned.

    Args:
        model: Neural network object to search.
        layer_types: Layer classes treated as convolutional. Defaults to
            ``(torch.nn.Conv2d,)``, which matches the original behavior. Pass
            e.g. ``(torch.nn.Conv1d,)`` or ``(torch.nn.Conv3d,)`` for models
            built from other convolution types.

    Returns:
        The last matching convolutional layer of the model.

    Raises:
        ValueError: If the model does not contain any layer of ``layer_types``.
    """
    conv_layers: List[torch.nn.Module] = [
        module for module in model.modules() if isinstance(module, layer_types)
    ]
    if not conv_layers:
        raise ValueError("The model does not contain convolution layers.")
    return conv_layers[-1]