models.generator.BlockGen
- class BlockGen(in_channels: int, cfgs: ListGen)
Bases: Module
Block generator for the network.
Takes two-dimensional lists of LayerGen objects as input. Elements of an inner list are applied sequentially, and the outputs of the outer list's entries are added together. A list can include blocks built from other two-dimensional lists; these are processed recursively.
Simple configuration list example

```python
# Conv, Norm, LIF, Pool and ListGen are provided by the library.
def vgg_block(out_channels: int, kernel: int = 3):
    return Conv(out_channels, kernel), Norm(), LIF()


cfgs: ListGen = [
    *vgg_block(8),
    Pool("S"),
    *vgg_block(32),
    Pool("S"),
    *vgg_block(64),
    Pool("S"),
]
```

Example of a configuration list with residual links

```python
def conv(out_channels: int, kernel: int = 3, stride: int = 1):
    return (
        Conv(out_channels, stride=stride, kernel_size=kernel),
        Norm(),
        LIF(),
    )


def res_block(out_channels: int, kernel: int = 3):
    return (
        Conv(out_channels, 1),
        # Residual block. The values from all branches are added together
        Residual(
            [
                [*conv(out_channels, kernel)],
                [Conv(out_channels, 1)],
            ]
        ),
        Conv(out_channels, 1),
    )


cfgs: ListGen = [
    *conv(64, 7, 2),
    *res_block(64, 5),
    *conv(128, 5, 2),
    *res_block(128),
]
```
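The composition rule (inner lists run sequentially, outer entries are summed, nesting is resolved recursively) can be illustrated with a minimal pure-Python sketch. It does not use the library: the `Branches` class is a hypothetical stand-in for `Residual`, and plain functions on numbers stand in for layers. It is not BlockGen's actual implementation.

```python
from typing import Callable, List, Union


class Branches:
    """Hypothetical stand-in for Residual: an outer list of branches.

    Each branch is an inner list whose elements are applied sequentially;
    the outputs of all branches are added together.
    """

    def __init__(self, branches: List[list]):
        self.branches = branches


Node = Union[Callable[[float], float], Branches]


def run_sequence(seq: List[Node], x: float) -> float:
    # Inner dimension: apply each element one after another.
    for node in seq:
        x = run_node(node, x)
    return x


def run_node(node: Node, x: float) -> float:
    if isinstance(node, Branches):
        # Outer dimension: feed the same input to every branch and sum the results.
        # A branch may itself contain Branches, so this recurses.
        return sum(run_sequence(branch, x) for branch in node.branches)
    return node(x)


def double(x: float) -> float:
    return 2 * x


def inc(x: float) -> float:
    return x + 1


cfg = [double, Branches([[inc, double], [inc]]), inc]
print(run_sequence(cfg, 3))  # 3 -> 6 -> (14 + 7) -> 22
```

In this sketch an empty branch returns its input unchanged, so `Branches([[double], []])` behaves like a skip connection that adds the input back to the branch output.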
- Parameters:
in_channels (int) – Number of input channels.
cfgs (ListGen) – Two-dimensional list of layer generators.
Methods
forward – Direct block pass.
Attributes
The number of channels the output will have after this block is applied to a tensor with in_channels channels.
training
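One way to derive the output channel count from a configuration is to walk the list, letting convolutions set the channel count and requiring every residual branch to agree. The sketch below assumes that Norm, LIF and Pool layers keep the channel count unchanged; `ConvSpec`, `KeepSpec`, `ResidualSpec` and `out_channels` are hypothetical names, not part of the library.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ConvSpec:
    """Hypothetical: a convolution that sets the channel count."""
    out_channels: int


@dataclass
class KeepSpec:
    """Hypothetical: Norm / LIF / Pool-like layers, assumed to keep the channel count."""
    name: str


@dataclass
class ResidualSpec:
    """Hypothetical: branches whose outputs are summed."""
    branches: List[list]


def out_channels(in_channels: int, cfg: list) -> int:
    channels = in_channels
    for item in cfg:
        if isinstance(item, ConvSpec):
            channels = item.out_channels
        elif isinstance(item, ResidualSpec):
            outs = {out_channels(channels, branch) for branch in item.branches}
            # Adding branch outputs only works if they all have the same channel count.
            if len(outs) != 1:
                raise ValueError(f"branch channel counts differ: {sorted(outs)}")
            channels = outs.pop()
        # KeepSpec items leave `channels` unchanged.
    return channels


def conv(c: int) -> list:
    return [ConvSpec(c), KeepSpec("norm"), KeepSpec("lif")]


def res_block(c: int) -> list:
    return [ConvSpec(c), ResidualSpec([conv(c), [ConvSpec(c)]]), ConvSpec(c)]


# Mirrors the structure of the residual configuration example.
cfg = [*conv(64), *res_block(64), *conv(128), *res_block(128)]
print(out_channels(3, cfg))  # 128
```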
- forward(X: Tensor, state: ListState | None = None) → Tuple[Tensor, ListState]
Direct (forward) pass through the block.
- Parameters:
X (torch.Tensor) – Input tensor of shape [batch, channel, h, w].
state (ListState | None, optional) – List of the states of the block's layers. Defaults to None.
- Returns:
The resulting tensor and the list of new states.
- Return type:
Tuple[torch.Tensor, ListState]
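Because forward returns the new list of layer states, a caller typically feeds that list back in on the next time step. The following pure-Python sketch shows this threading pattern under stated assumptions: `ToyLIF` and `ToyBlock` are hypothetical stand-ins working on scalars, not the library's LIF or BlockGen, and `None` is assumed to mean "start from a fresh state".

```python
from typing import List, Optional, Tuple


class ToyLIF:
    """Hypothetical scalar leaky integrate-and-fire layer (not the library's LIF)."""

    def __init__(self, decay: float = 0.5, threshold: float = 1.0):
        self.decay = decay
        self.threshold = threshold

    def __call__(self, x: float, v: Optional[float]) -> Tuple[float, float]:
        v = (0.0 if v is None else v) * self.decay + x
        if v >= self.threshold:
            return 1.0, 0.0  # spike, then reset the membrane potential
        return 0.0, v


class ToyBlock:
    """Hypothetical block mirroring the forward(X, state) -> (Y, new_state) contract."""

    def __init__(self, layers: list):
        self.layers = layers

    def forward(
        self, x: float, state: Optional[List[Optional[float]]] = None
    ) -> Tuple[float, List[Optional[float]]]:
        if state is None:
            state = [None] * len(self.layers)  # fresh state for every layer
        new_state = []
        for layer, s in zip(self.layers, state):
            x, s = layer(x, s)
            new_state.append(s)
        return x, new_state


block = ToyBlock([ToyLIF()])
state = None
outputs = []
for x in [0.6, 0.6, 0.6, 0.6]:
    y, state = block.forward(x, state)  # pass the returned state back in
    outputs.append(y)
print(outputs)  # [0.0, 0.0, 1.0, 0.0]
```

If `None` were passed on every call instead, the membrane potential would never accumulate and the toy neuron would never fire, which is why the returned state has to be carried between steps.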