models.generator.HeadGen
- class HeadGen(cfg: str | ListGen, box_out: int, cls_out: int, in_channels: int = 2, init_weights=False)
Bases: ModelGen
Model head generator.
The configuration lists for this module are structured differently from those of other model generators.
Configuration list example:

    cfgs: ListGen = [
        [
            Conv(kernel_size=1),
            Norm(),
            LSTM(),
        ],
        [
            Conv(box_out, 1),
        ],
        [
            Conv(cls_out, 1),
        ],
    ]

The configuration consists of three lists:
- the first list builds the data preparation network;
- the second list builds the box prediction network;
- the third list builds the class prediction network.
Box and class prediction models use the output of the preparation network as input.
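The data flow described above can be sketched without any framework. In this minimal sketch, a "tensor" is just a list of per-channel values, each layer is a plain Python callable, and the hypothetical helper `build_head` wires the three configuration lists together. Concatenating the box and class outputs into one result is an assumption made here to mirror the single tensor returned by `forward`; it is not taken from the HeadGen source.

```python
from typing import Callable, List, Sequence

# A "tensor" here is simply a list of channel values (hypothetical stand-in).
Tensor = List[float]
Layer = Callable[[Tensor], Tensor]


def sequential(layers: Sequence[Layer]) -> Layer:
    """Chain layers so that each one consumes the previous layer's output."""
    def run(x: Tensor) -> Tensor:
        for layer in layers:
            x = layer(x)
        return x
    return run


def build_head(cfg: Sequence[Sequence[Layer]]) -> Layer:
    """Hypothetical helper: cfg = [preparation, box, class] layer lists."""
    prep_cfg, box_cfg, cls_cfg = cfg
    prep, box, cls = sequential(prep_cfg), sequential(box_cfg), sequential(cls_cfg)

    def head(x: Tensor) -> Tensor:
        features = prep(x)  # shared preparation network
        # Both prediction branches read the same prepared features.
        return box(features) + cls(features)  # list concat = channel concat
    return head


def project(out_channels: int) -> Layer:
    """Toy 1x1 'convolution': every output channel is the sum of the inputs."""
    return lambda x: [sum(x)] * out_channels


box_out, cls_out = 4, 2
head = build_head([
    [lambda x: [2.0 * v for v in x]],  # preparation: scale features
    [project(box_out)],                # box prediction branch
    [project(cls_out)],                # class prediction branch
])
result = head([1.0, 0.5])
print(len(result))  # 6 = box_out + cls_out channels
```

Because both branches consume the same prepared features, the preparation network runs once per input regardless of how many prediction branches follow it.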
- Parameters:
cfg (str | ListGen) – Lists of layer generators.
box_out (int) – The number of output channels of the box prediction network.
cls_out (int) – The number of output channels of the class prediction network.
in_channels (int, optional) – Number of input channels. Defaults to 2.
init_weights (bool, optional) – If True, apply the weight initialization function. Defaults to False.
Methods
forward(X) – Network pass over data with a time dimension
forward_impl(X, state) – State-based network pass
Attributes
training
- forward(X: Tensor) → Tensor
Network pass over data with a time dimension.
- Parameters:
X (torch.Tensor) – Input tensor of shape [ts, batch, channel, h, w].
- Returns:
The resulting tensor.
- Return type:
torch.Tensor
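A network pass over time-resolved data is commonly built by unrolling a state-based pass over the first (ts) dimension. The sketch below assumes that is how `forward` relates to `forward_impl`; `step` is a hypothetical stand-in for `forward_impl`, and plain lists replace tensors.

```python
from typing import List, Optional, Tuple

State = Optional[float]


def step(x: float, state: State) -> Tuple[float, State]:
    """Hypothetical stand-in for forward_impl: a running-sum 'memory'."""
    new_state = x if state is None else state + x
    return new_state, new_state


def forward(X: List[float]) -> List[float]:
    """Apply step to every time step of X (index 0 plays the role of ts)."""
    state: State = None  # no state before the first time step
    outputs = []
    for x_t in X:
        y_t, state = step(x_t, state)  # carry the state to the next step
        outputs.append(y_t)
    return outputs  # stacked along the time dimension; states are discarded


print(forward([1.0, 2.0, 3.0]))  # [1.0, 3.0, 6.0]
```

Only the stacked outputs are returned, which matches the single tensor in the `forward` return type; the intermediate states live only inside the loop.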
- forward_impl(X: Tensor, state: ListState | None) → Tuple[Tensor, ListState]
State-based network pass
- Parameters:
X (torch.Tensor) – Input tensor of shape [batch, channel, h, w].
state (ListState | None, optional) – List of block layer states. Defaults to None.
- Returns:
The resulting tensor and the list of new states.
- Return type:
Tuple[torch.Tensor, ListState]
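The `state` argument is described as a list of block layer states. A minimal sketch under that reading: each hypothetical layer takes its input and its own previous state (or None on the first call) and returns an output and a new state, and the pass collects the new states in the same order. The layer classes `Scale` and `Leaky` are invented for illustration, and ListState is modeled as a plain Python list.

```python
from typing import Any, List, Optional, Tuple


class Scale:
    """Hypothetical stateless layer; its state slot stays None."""
    def __init__(self, factor: float) -> None:
        self.factor = factor

    def __call__(self, x: float, state: Any) -> Tuple[float, Any]:
        return x * self.factor, None


class Leaky:
    """Hypothetical stateful layer: exponentially decaying memory."""
    def __init__(self, decay: float) -> None:
        self.decay = decay

    def __call__(self, x: float, state: Optional[float]) -> Tuple[float, float]:
        mem = x if state is None else self.decay * state + x
        return mem, mem


def forward_impl(layers: List[Any], x: float,
                 state: Optional[List[Any]]) -> Tuple[float, List[Any]]:
    """Run the layers in order, threading each layer's own state through."""
    if state is None:
        state = [None] * len(layers)  # first call: no previous states
    new_state = []
    for layer, layer_state in zip(layers, state):
        x, s = layer(x, layer_state)
        new_state.append(s)
    return x, new_state


layers = [Scale(2.0), Leaky(0.5)]
y1, s1 = forward_impl(layers, 1.0, None)  # y1 == 2.0, s1 == [None, 2.0]
y2, s2 = forward_impl(layers, 1.0, s1)    # y2 == 0.5 * 2.0 + 2.0 == 3.0
```

Keeping one slot per layer, even for stateless layers, lets the returned list be passed straight back in on the next call without any bookkeeping about which layers hold state.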