Source code for models.modules

"""
Custom modules and layer generators
"""

import torch
from torch import nn
import torch.nn.functional as F
from typing import Tuple, Optional
import norse.torch as snn

# Public API of this module. ``ConvLSTM`` is exported alongside the other
# custom modules (``SumPool2d``, ``Storage``) because the ``LSTM`` generator
# builds it and users may want to instantiate it directly.
__all__ = (
    "SumPool2d",
    "Storage",
    "ConvLSTM",
    "LayerGen",
    "Pass",
    "Conv",
    "Norm",
    "LIF",
    "LI",
    "ReLU",
    "SiLU",
    "Tanh",
    "LSTM",
    "Pool",
    "Up",
    "Return",
    "Residual",
    "Dense",
)


#####################################################################
#                          Custom modules                           #
#####################################################################


class SumPool2d(nn.Module):
    """Applies a 2D sum pooling over an input signal composed of several input planes

    Sums the values of the cells covered by each kernel window. To do this, it
    calls :external:func:`torch.nn.functional.avg_pool2d` with
    ``divisor_override=1``. The window total is then produced directly, instead of
    computing the mean and multiplying it back by the kernel area, which adds
    floating-point rounding. Implicit zero padding contributes nothing to the
    sums.
    """

    def __init__(self, kernel_size: int, stride: int = 1, padding: int = 0):
        """
        :param kernel_size: The size of the window.
        :type kernel_size: int
        :param stride: The stride of the window. Defaults to 1
        :type stride: int, optional
        :param padding: Implicit zero padding to be added on both sides. Defaults to 0
        :type padding: int, optional
        """
        super().__init__()
        self.kernel_size, self.stride, self.padding = kernel_size, stride, padding

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Direct Module pass

        :param X: Input tensor.
        :type X: torch.Tensor
        :return: Result of summing pool.
        :rtype: torch.Tensor
        """
        # Dividing each window total by 1 turns the averaging kernel into an
        # exact window sum.
        return F.avg_pool2d(
            X,
            self.kernel_size,
            self.stride,
            self.padding,
            divisor_override=1,
        )
class Storage(nn.Module):
    """Stores the forward pass values

    It is intended for use in feature pyramids, where you need to get multiple
    matrices from different places in the network.
    """

    # Last tensor passed through ``forward``. It is ``None`` until the first pass
    # and again after ``get_storage`` has consumed it. The class-level default
    # makes ``get_storage`` safe to call before any forward pass; previously the
    # bare annotation created no attribute, so that call raised AttributeError.
    _storage: Optional[torch.Tensor] = None

    def forward(self, X: torch.Tensor) -> torch.Tensor:
        """Store the input tensor and returns it back

        :param X: Input tensor.
        :type X: torch.Tensor
        :return: Input tensor.
        :rtype: torch.Tensor
        """
        self._storage = X
        return X

    def get_storage(self) -> Optional[torch.Tensor]:
        """Returns the stored tensor and clears the storage

        After this call the module no longer holds a reference to the tensor, so
        a second call returns ``None`` until the next forward pass.

        :return: Stored tensor, or ``None`` if nothing has been stored.
        :rtype: Optional[torch.Tensor]
        """
        stored = self._storage
        self._storage = None
        return stored
class ConvLSTM(nn.Module):
    """Convolutional LSTM

    For more details, see https://github.com/ndrplz/ConvLSTM_pytorch/tree/master.
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        kernel_size: int = 1,
        bias: bool = False,
    ):
        """
        :param in_channels: Number of input channels.
        :type in_channels: int
        :param hidden_channels: Number of hidden channels.
        :type hidden_channels: int
        :param kernel_size: Size of the convolving kernel. Should be odd so that
            the automatic padding keeps the spatial size unchanged. Defaults to 1.
        :type kernel_size: int, optional
        :param bias: If ``True``, adds a learnable bias to the output. Defaults to False.
        :type bias: bool, optional
        """
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # A single convolution produces all four gates at once (4 * hidden).
        # Padding by kernel_size // 2 keeps h and w unchanged for odd kernels.
        # ``forward`` relies on this because the gates are combined element-wise
        # with the cell state, which has the input's spatial size.
        self.conv = nn.Conv2d(
            in_channels=self.in_channels + self.hidden_channels,
            out_channels=4 * self.hidden_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2,
            bias=bias,
        )

    def _init_hidden(self, target: torch.Tensor):
        """Creates zero (hidden, cell) states with ``target``'s batch size,
        spatial size, device and dtype."""
        batch, _, h, w = target.shape
        shape = (batch, self.hidden_channels, h, w)
        return target.new_zeros(shape), target.new_zeros(shape)

    def forward(
        self, X: torch.Tensor, state: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        :param X: Input tensor. Shape [batch, channel, h, w].
        :type X: torch.Tensor
        :param state: Past state of the cell. Defaults to None (zero state).
            It is a tuple of the form: (hidden state, cell state).
        :type state: Optional[Tuple[torch.Tensor, torch.Tensor]], optional
        :return: Tuple of form: (next hidden state, (next hidden state, next cell state)).
        :rtype: Tuple[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]
        """
        hidden_state, cell_state = self._init_hidden(X) if state is None else state
        gates = self.conv(torch.cat([X, hidden_state], dim=1))
        input_gate, forget_gate, out_gate, in_node = torch.split(
            gates, self.hidden_channels, dim=1
        )
        # Lower-case names avoid shadowing the module-level ``F`` alias
        # (torch.nn.functional).
        i = torch.sigmoid(input_gate)
        f = torch.sigmoid(forget_gate)
        o = torch.sigmoid(out_gate)
        g = torch.tanh(in_node)
        cell_next = f * cell_state + i * g
        hidden_next = o * torch.tanh(cell_next)
        # This form is needed for the model generator to work
        return hidden_next, (hidden_next, cell_next)
#####################################################################
#                         Layer Generators                          #
#####################################################################
class LayerGen:
    """Base class for model layer generators

    Subclasses override ``get`` so that it builds the network module and hands
    it over to the generator (See :class:`BlockGen`).

    .. warning::
        This class can only be used as a base class for inheritance.
    """

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        """Builds and returns the network layer

        :param in_channels: Number of channels of the incoming tensor.
        :type in_channels: int
        :return: The generated module together with the number of channels a
            tensor with ``in_channels`` channels will have after this layer.
        :rtype: Tuple[nn.Module, int]
        :raises NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError
class Pass(LayerGen):
    """Placeholder layer generator that leaves its input untouched

    Uses :external:class:`torch.nn.Identity` module.
    """

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        identity = nn.Identity()
        # An identity mapping never changes the channel count.
        return identity, in_channels
class Conv(LayerGen):
    """Generator of standard 2d convolution

    Uses :external:class:`torch.nn.Conv2d` module. Bias defaults to ``False``,
    padding is calculated automatically.
    """

    def __init__(
        self, out_channels: Optional[int] = None, kernel_size: int = 3, stride: int = 1
    ):
        """
        :param out_channels: Number of channels produced by the convolution.
            Defaults to None, meaning the input channel count is kept.
        :type out_channels: int, optional
        :param kernel_size: Size of the convolving kernel. Defaults to 3.
        :type kernel_size: int, optional
        :param stride: Stride of the convolution. Defaults to 1.
        :type stride: int, optional
        """
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        """Creates the convolution layer

        :param in_channels: Number of input channels.
        :type in_channels: int
        :return: The :external:class:`torch.nn.Conv2d` module (without bias) and
            its number of output channels.
        :rtype: Tuple[nn.Module, int]
        """
        out = in_channels if self.out_channels is None else self.out_channels
        # kernel_size // 2 padding keeps the spatial size for odd kernels at
        # stride 1. Integer floor division avoids the float round-trip of
        # int(k / 2).
        return nn.Conv2d(
            in_channels,
            out,
            kernel_size=self.kernel_size,
            padding=self.kernel_size // 2,
            stride=self.stride,
            bias=False,
        ), out
class Pool(LayerGen):
    """Pooling layer generator

    Uses modules :external:class:`torch.nn.AvgPool2d`,
    :external:class:`torch.nn.MaxPool2d`, :class:`SumPool2d`.
    """

    def __init__(self, type: str, kernel_size: int = 2, stride: Optional[int] = None):
        """
        :param type: Pooling type.

            - ``A`` - :external:class:`torch.nn.AvgPool2d`.
            - ``M`` - :external:class:`torch.nn.MaxPool2d`.
            - ``S`` - :class:`SumPool2d`.

        :type type: str
        :param kernel_size: The size of the window. Defaults to 2.
        :type kernel_size: int, optional
        :param stride: The stride of the window. Default value is kernel_size.
        :type stride: Optional[int], optional
        :raises ValueError: Non-existent pool type.
        """
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        # Pick the pooling class; it is instantiated later in ``get``.
        if type == "A":
            pool_cls = nn.AvgPool2d
        elif type == "M":
            pool_cls = nn.MaxPool2d
        elif type == "S":
            pool_cls = SumPool2d
        else:
            raise ValueError(f'[ERROR]: Non-existent pool type "{type}"!')
        self.pool = pool_cls

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        layer = self.pool(self.kernel_size, self.stride)
        # Pooling acts per channel, so the channel count is unchanged.
        return layer, in_channels
class Up(LayerGen):
    """Upsample layer generator

    Uses :external:class:`torch.nn.Upsample` module.
    """

    def __init__(self, scale: int = 2, mode: str = "nearest"):
        """
        :param scale: Multiplier for spatial size. Defaults to 2.
        :type scale: int, optional
        :param mode: The upsampling algorithm: one of 'nearest', 'linear',
            'bilinear', 'bicubic' and 'trilinear'. Defaults to "nearest".
        :type mode: str, optional
        """
        self.scale, self.mode = scale, mode

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        upsample = nn.Upsample(mode=self.mode, scale_factor=self.scale)
        return upsample, in_channels
class Norm(LayerGen):
    """Batch Normalization layer generator

    Uses :external:class:`torch.nn.BatchNorm2d` module.
    """

    def __init__(self, bias: bool = False):
        """
        :param bias: If True, the layer keeps its learnable bias. Defaults to False.
        :type bias: bool, optional
        """
        self.bias = bias

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        batch_norm = nn.BatchNorm2d(in_channels)
        keep_bias = bool(self.bias)
        if not keep_bias:
            # Replace the learnable ``bias`` parameter with ``None``; the
            # ``weight`` parameter is left in place.
            batch_norm.bias = None
        return batch_norm, in_channels
class LIF(LayerGen):
    """Generator of a layer of LIF neurons

    Uses :external:class:`norse.torch.module.lif.LIFCell` module.
    """

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        cell = snn.LIFCell()
        return cell, in_channels
class LI(LayerGen):
    """Generator of a layer of LI neurons

    Uses :external:class:`norse.torch.module.leaky_integrator.LICell` module.
    """

    def get(self, in_channels: int) -> Tuple[snn.LICell, int]:
        cell = snn.LICell()
        return cell, in_channels
class ReLU(LayerGen):
    """ReLU activation layer generator

    Uses :external:class:`torch.nn.ReLU` module.
    """

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        activation = nn.ReLU()
        return activation, in_channels
class SiLU(LayerGen):
    """SiLU activation layer generator

    Uses :external:class:`torch.nn.SiLU` module.
    """

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        activation = nn.SiLU()
        return activation, in_channels
class Tanh(LayerGen):
    """Tanh layer generator

    Uses :external:class:`torch.nn.Tanh` module.
    """

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        """Creates a hyperbolic tangent activation

        :param in_channels: Number of input channels.
        :type in_channels: int
        :return: A :external:class:`torch.nn.Tanh` module and ``in_channels``
            (the activation does not change the channel count).
        :rtype: Tuple[nn.Module, int]
        """
        return nn.Tanh(), in_channels
class LSTM(LayerGen):
    """LSTM layer generator

    Uses :class:`ConvLSTM` module.
    """

    def __init__(self, hidden_size: Optional[int] = None):
        """
        :param hidden_size: Number of hidden channels. Defaults to None, in which
            case the input channel count is used.
        :type hidden_size: Optional[int], optional
        """
        self.hidden_size = hidden_size

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        hidden = self.hidden_size if self.hidden_size is not None else in_channels
        # The cell outputs its hidden state, so the channel count becomes ``hidden``.
        return ConvLSTM(in_channels, hidden), hidden
class Return(LayerGen):
    """Generator of layers for storing forward pass values

    It is intended for use in feature pyramids, where you need to get multiple
    matrices from different places in the network.

    Uses :class:`Storage` module.
    """

    # Channel count recorded by the most recent ``get`` call.
    out_channels: int

    def get(self, in_channels: int) -> Tuple[nn.Module, int]:
        self.out_channels = in_channels
        storage = Storage()
        # Storage passes its input through, so the channel count is unchanged.
        return storage, self.out_channels
class Residual(list):
    """Marker subclass of the built-in :external:class:`list` type

    It adds no behavior; its only purpose is to flag a network in the
    configuration as residual.

    .. code-block::
        :caption: Example

        Residual(
            [
                [*conv(out_channels, kernel)],
                [Conv(out_channels, 1)],
            ]
        )
    """
class Dense(list):
    """Marker subclass of the built-in :external:class:`list` type

    It adds no behavior; its only purpose is to flag a network in the
    configuration as densely connected.

    .. code-block::
        :caption: Example

        Dense(
            [
                [*conv(out_channels, kernel)],
                [Conv(out_channels, 1)],
            ]
        )
    """