models.generator.BlockGen


class BlockGen(in_channels: int, cfgs: ListGen)

Bases: Module

Block generator for the network.

Takes as input a two-dimensional list of LayerGen objects. Layers within an inner list are applied sequentially, and the outputs of the inner lists (the branches of the outer list) are added together. A list can include blocks defined by other two-dimensional lists; these are processed recursively.
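The composition rule above can be modelled in plain Python. This is a minimal sketch, not BlockGen's implementation: the layers are hypothetical callables on floats instead of LayerGen objects, a top-level list is treated as a sequential chain (as in the examples below), and a nested list is treated as a set of branches whose results are summed. The names run_sequential, run_branches, double and inc are illustrative only.

```python
def run_sequential(items: list, x: float) -> float:
    """Apply the items of one list one after another."""
    for item in items:
        if callable(item):
            # A plain layer: apply it to the running value.
            x = item(x)
        else:
            # A nested list: interpret it as parallel branches (recursion).
            x = run_branches(item, x)
    return x


def run_branches(branches: list, x: float) -> float:
    """Run every branch on the same input and add the results together."""
    return sum(run_sequential(branch, x) for branch in branches)


# Hypothetical toy layers.
def double(x: float) -> float:
    return 2 * x


def inc(x: float) -> float:
    return x + 1


# double, then two summed branches ([inc, double] and [inc]), then inc.
cfg = [double, [[inc, double], [inc]], inc]
result = run_sequential(cfg, 3.0)
print(result)  # 3 -> 6 -> (14 + 7) = 21 -> 22.0
```

Because branches are handled by calling run_sequential again, a branch may itself contain another nested list, which mirrors the recursive handling described above.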

Simple configuration list example
# One VGG-style stage: convolution, normalization, LIF activation
def vgg_block(out_channels: int, kernel: int = 3):
    return Conv(out_channels, kernel), Norm(), LIF()

cfgs: ListGen = [
    *vgg_block(8), Pool("S"), *vgg_block(32), Pool("S"), *vgg_block(64), Pool("S")
]
Example of a configuration list with residual connections
# Convolution (with optional stride) followed by normalization and LIF activation
def conv(out_channels: int, kernel: int = 3, stride: int = 1):
    return (
        Conv(out_channels, stride=stride, kernel_size=kernel),
        Norm(),
        LIF(),
    )

# 1x1 convolution, two summed branches, then another 1x1 convolution
def res_block(out_channels: int, kernel: int = 3):
    return (
        Conv(out_channels, 1),
        # Residual block. The values from all branches are added together
        Residual(
            [
                [*conv(out_channels, kernel)],
                [Conv(out_channels, 1)],
            ]
        ),
        Conv(out_channels, 1),
    )

cfgs: ListGen = [
    *conv(64, 7, 2), *res_block(64, 5), *conv(128, 5, 2), *res_block(128)
]
Parameters:
  • in_channels (int) – Number of input channels.

  • cfgs (ListGen) – Two-dimensional lists of layer generators.

Methods

  • forward – Forward pass of the block.

Attributes

  • out_channels – Number of output channels produced when this block is applied to a tensor with in_channels channels.

  • training – Whether the module is in training mode (inherited from torch.nn.Module).

forward(X: Tensor, state: ListState | None = None) → Tuple[Tensor, ListState]

Forward pass of the block.

Parameters:
  • X (torch.Tensor) – Input tensor of shape [batch, channel, h, w].

  • state (ListState | None, optional) – List of block layer states. Defaults to None.

Returns:

The resulting tensor and the list of new states.

Return type:

Tuple[torch.Tensor, ListState]
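Because forward returns the new states alongside the output, a caller can presumably process a sequence step by step by passing None on the first call and feeding the returned list back in on the next. The sketch below illustrates only that calling convention under this assumption: ToyLeakyLayer and ToyBlock are hypothetical plain-Python stand-ins working on floats, not the package's layers, and the real state contents are not documented here.

```python
from typing import List, Optional, Tuple


class ToyLeakyLayer:
    """Hypothetical layer following the (output, new_state) convention."""

    def __init__(self, decay: float = 0.5):
        self.decay = decay

    def forward(self, x: float, state: Optional[float] = None) -> Tuple[float, float]:
        v = 0.0 if state is None else state  # no state yet: start from rest
        v = self.decay * v + x               # leaky accumulation of the input
        return v, v


class ToyBlock:
    """Hypothetical block that keeps one state per layer, like ListState."""

    def __init__(self, layers: List[ToyLeakyLayer]):
        self.layers = layers

    def forward(
        self, x: float, state: Optional[List[Optional[float]]] = None
    ) -> Tuple[float, List[float]]:
        if state is None:
            state = [None] * len(self.layers)
        new_state = []
        for layer, s in zip(self.layers, state):
            x, s = layer.forward(x, s)
            new_state.append(s)
        return x, new_state


block = ToyBlock([ToyLeakyLayer(0.5), ToyLeakyLayer(0.5)])
state = None  # first call: no previous states
outputs = []
for x_t in [1.0, 1.0, 1.0]:  # feed the sequence one time step at a time
    y, state = block.forward(x_t, state)
    outputs.append(y)
print(outputs)  # [1.0, 2.0, 2.75]
```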

out_channels: int = 0

Number of output channels produced when this block is applied to a tensor with in_channels channels.
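One way such a channel count can be derived is by propagating it through the configuration. This is a minimal sketch under stated assumptions, not BlockGen's implementation: a convolution sets the channel count, normalization, activation and pooling leave it unchanged, and parallel branches must agree because their outputs are summed. ToyConv, ToyKeep and out_channels here are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class ToyConv:
    out_channels: int  # a convolution sets the channel count


@dataclass
class ToyKeep:
    pass  # stands in for Norm, LIF or Pool: channel count unchanged


def out_channels(in_channels: int, items: list) -> int:
    c = in_channels
    for item in items:
        if isinstance(item, ToyConv):
            c = item.out_channels
        elif isinstance(item, list):
            # Branches are summed, so they must produce the same channel count.
            outs = {out_channels(c, branch) for branch in item}
            if len(outs) != 1:
                raise ValueError(f"branches disagree on channels: {sorted(outs)}")
            c = outs.pop()
        # Any other item (ToyKeep) leaves c unchanged.
    return c


# Shaped like res_block(64) from the example above, applied to a 3-channel input.
res_like = [
    ToyConv(64),
    [[ToyConv(64), ToyKeep(), ToyKeep()], [ToyConv(64)]],
    ToyConv(64),
]
print(out_channels(3, res_like))  # 64
```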