models.generator.HeadGen

class HeadGen(cfg: str | ListGen, box_out: int, cls_out: int, in_channels: int = 2, init_weights=False)

Bases: ModelGen

Model head generator

The configuration for this module is structured differently from that of other generators. It is split into three separate lists of layer generators.

Configuration list example
cfgs: ListGen = [
    [  # data preparation network
        Conv(kernel_size=1),
        Norm(),
        LSTM(),
    ],
    [  # box prediction network
        Conv(box_out, 1),
    ],
    [  # class prediction network
        Conv(cls_out, 1),
    ],
]

The configuration includes three lists:

  • The first one is for data preparation

  • The second one is for box prediction

  • The third one is for class prediction

Box and class prediction models use the output of the preparation network as input.
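The sketch below illustrates this data flow in pure Python. The preparation output is computed once, and both prediction branches read it. It makes several simplifying assumptions. The preparation network is reduced to a single 1x1 convolution, with no Norm or LSTM. Feature maps are nested lists laid out as [channel][h][w] for one sample. The helpers conv1x1 and head are hypothetical names, not part of this module.

```python
from typing import List, Tuple

# Hypothetical layout: one sample's feature map as [channel][h][w].
FeatureMap = List[List[List[float]]]
Weights = List[List[float]]  # [out_channel][in_channel]


def conv1x1(x: FeatureMap, weights: Weights) -> FeatureMap:
    """1x1 convolution without bias: each output channel is a weighted
    sum of the input channels at the same (i, j) position."""
    h, w = len(x[0]), len(x[0][0])
    return [
        [
            [sum(wc * x[c][i][j] for c, wc in enumerate(row)) for j in range(w)]
            for i in range(h)
        ]
        for row in weights
    ]


def head(
    x: FeatureMap, prep_w: Weights, box_w: Weights, cls_w: Weights
) -> Tuple[FeatureMap, FeatureMap]:
    # Stand-in for the preparation list (Conv -> Norm -> LSTM), computed once.
    features = conv1x1(x, prep_w)
    # Both prediction branches consume the same prepared features.
    boxes = conv1x1(features, box_w)    # box_out channels
    classes = conv1x1(features, cls_w)  # cls_out channels
    return boxes, classes


# in_channels = 2 on a 2x2 map; preparation -> 3 channels, box_out = 4, cls_out = 1.
x = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.5], [0.5, 0.5]]]
prep_w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
box_w = [[1.0, 0.0, 0.0]] * 4
cls_w = [[0.0, 0.0, 1.0]]
boxes, classes = head(x, prep_w, box_w, cls_w)
print(len(boxes), len(classes))  # 4 1
```

The number of rows in box_w and cls_w plays the role of box_out and cls_out. It fixes how many channels each branch produces.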

Parameters:
  • cfg (str | ListGen) – Lists of layer generators.

  • box_out (int) – Number of output channels of the box prediction network.

  • cls_out (int) – Number of output channels of the class prediction network.

  • in_channels (int, optional) – Number of input channels. Defaults to 2.

  • init_weights (bool, optional) – If True, apply the weight initialization function. Defaults to False.

Methods

  • forward – Network pass for data containing time resolution.

  • forward_impl – State-based network pass.

Attributes

  • training

forward(X: Tensor) → Tensor

Network pass for data containing time resolution

Parameters:

X (torch.Tensor) – Input tensor of shape [ts, batch, channel, h, w].

Returns:

The resulting tensor.

Return type:

torch.Tensor

forward_impl(X: Tensor, state: ListState | None) → Tuple[Tensor, ListState]

State-based network pass

Parameters:
  • X (torch.Tensor) – Input tensor of shape [batch, channel, h, w].

  • state (ListState | None, optional) – List of block layer states. Defaults to None.

Returns:

The resulting tensor and the list of new states.

Return type:

Tuple[torch.Tensor, ListState]
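The relationship between forward and forward_impl can be sketched as a loop over the time axis. This is an assumption based on the two method descriptions, not the module's actual code. In this sketch, forward slices the input along ts and calls the per-step pass with the state returned by the previous step (None on the first step). It then collects the per-step outputs. The names forward_over_time and running_sum_step are hypothetical. A running sum stands in for the recurrent state an LSTM layer would carry, and plain lists stand in for tensors.

```python
from typing import Callable, List, Optional, Tuple

State = Optional[List[float]]
# Hypothetical per-step pass: (input at one time step, previous state) -> (output, new state)
Step = Callable[[List[float], State], Tuple[List[float], List[float]]]


def forward_over_time(X: List[List[float]], step: Step) -> List[List[float]]:
    """Apply a stateful per-step pass along the leading ts axis."""
    state: State = None  # the first step starts without a state
    outputs = []
    for x_t in X:  # one time step; a flat list here instead of [batch, channel, h, w]
        y_t, state = step(x_t, state)  # thread the new state into the next step
        outputs.append(y_t)
    return outputs  # per-step outputs, ordered along ts


def running_sum_step(x_t: List[float], state: State) -> Tuple[List[float], List[float]]:
    """Toy stand-in for forward_impl: the state is the running sum of inputs."""
    prev = state if state is not None else [0.0] * len(x_t)
    new_state = [p + v for p, v in zip(prev, x_t)]
    return new_state, new_state


X = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # ts = 3
print(forward_over_time(X, running_sum_step))  # [[1.0, 2.0], [4.0, 6.0], [9.0, 12.0]]
```

In this reading, the states stay internal to the loop and only the outputs are returned. That is consistent with forward returning a single torch.Tensor while forward_impl returns a (tensor, states) pair.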