models.generator.NeckGen

class NeckGen(cfg: str | ListGen, in_channels: int = 2, init_weights: bool = False)

Bases: ModelGen

Network neck generator

Returns a list of tensors that were stored in the models.modules.Return layers.
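The description above implies that the neck runs its layers in order and keeps the data seen at each models.modules.Return layer. The following is a minimal sketch of that idea in plain Python, assuming purely sequential layers. ReturnMarker and run_neck are hypothetical names (not part of this library), and numbers stand in for tensors.

```python
from typing import Any, Callable, List, Sequence


class ReturnMarker:
    """Hypothetical stand-in for models.modules.Return: passes data through unchanged."""

    def __call__(self, x: Any) -> Any:
        return x


def run_neck(layers: Sequence[Callable[[Any], Any]], x: Any) -> List[Any]:
    """Apply layers in order and collect the value seen at every ReturnMarker."""
    outputs: List[Any] = []
    for layer in layers:
        x = layer(x)
        if isinstance(layer, ReturnMarker):
            outputs.append(x)  # keep outputs in the order the markers appear
    return outputs


# Numbers stand in for tensors; the lambdas stand in for real layers.
layers = [lambda v: v * 2, ReturnMarker(), lambda v: v + 1, ReturnMarker()]
print(run_neck(layers, 3))  # [6, 7]
```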

Parameters:
  • cfg (str | ListGen) – Network configuration generator.

  • in_channels (int, optional) – Number of input channels. Defaults to 2.

  • init_weights (bool, optional) – If True, apply the weight initialization function. Defaults to False.

Methods

forward

Network pass for data containing time resolution

forward_impl

State-based network pass

Attributes

out_shape

Stores the format of the output data

training

forward(X: Tensor) → Tensor

Network pass for data containing time resolution

Parameters:

X (torch.Tensor) – Input tensor of shape [ts, batch, channel, h, w].

Returns:

The resulting tensor.

Return type:

torch.Tensor
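forward takes time-major input ([ts, batch, channel, h, w]), while forward_impl performs a single state-based pass. One common way to connect the two is to call the step function once per time slice and carry the returned state into the next step. The sketch below assumes that design; forward_over_time and accumulate are hypothetical helpers, plain Python values stand in for tensors, and how the real forward merges the per-step outputs into one tensor (for example with torch.stack) is left out.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

State = Optional[List[Any]]


def forward_over_time(
    step: Callable[[Any, State], Tuple[Any, State]], X: Sequence[Any]
) -> List[Any]:
    """Call a state-based step once per time slice X[t], carrying the state along."""
    state: State = None  # nothing has been seen before the first time step
    outputs: List[Any] = []
    for x_t in X:  # iterate over the leading (ts) dimension
        y_t, state = step(x_t, state)
        outputs.append(y_t)
    return outputs  # a tensor implementation would likely stack these along ts


def accumulate(x: float, state: State) -> Tuple[float, State]:
    """Hypothetical step function: a running sum kept in a one-element state list."""
    total = (state[0] if state is not None else 0) + x
    return total, [total]


print(forward_over_time(accumulate, [1, 2, 3]))  # [1, 3, 6]
```

Starting from None rather than a zero-filled state lets each step decide how to initialize itself on the first time slice.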

forward_impl(X: List[Tensor], state: ListState | None) → Tuple[Tensor, ListState]

State-based network pass

Parameters:
  • X (torch.Tensor) – Input tensor of shape [batch, channel, h, w].

  • state (ListState | None, optional) – List of block layer states. Defaults to None.

Returns:

The resulting tensor and the list of new states.

Return type:

Tuple[torch.Tensor, ListState]
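forward_impl receives state as a list of block layer states (ListState) or None. A minimal sketch of how such a list might be threaded through a sequence of stateful blocks, assuming each block maps (input, its own state) to (output, new state) and that None means "no previous state". stateful_pass and leaky are hypothetical, and floats stand in for tensors.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple

# A block maps (input, its own previous state) -> (output, its new state).
Block = Callable[[Any, Any], Tuple[Any, Any]]


def stateful_pass(
    blocks: Sequence[Block], x: Any, state: Optional[List[Any]]
) -> Tuple[Any, List[Any]]:
    """Run one pass through all blocks and return (output, new state list)."""
    if state is None:
        state = [None] * len(blocks)  # no previous state: every block starts fresh
    new_state: List[Any] = []
    for block, block_state in zip(blocks, state):
        x, s = block(x, block_state)  # each block only sees its own state entry
        new_state.append(s)
    return x, new_state


def leaky(x: float, s: Optional[float], decay: float = 0.5) -> Tuple[float, float]:
    """Hypothetical leaky integrator: adds the decayed previous value to the input."""
    v = x + (decay * s if s is not None else 0.0)
    return v, v


blocks = [leaky, leaky]
y1, s1 = stateful_pass(blocks, 1.0, None)  # y1 == 1.0, s1 == [1.0, 1.0]
y2, s2 = stateful_pass(blocks, 1.0, s1)    # y2 == 2.0, s2 == [1.5, 2.0]
```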

out_shape: List[int]

Stores the format of the output data

  • The number of elements is equal to the number of tensors in the output list.

  • The numeric value is equal to the number of channels of the corresponding tensor.

This data is required to initialize models.head.Head.
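Because each out_shape entry is the channel count of one returned tensor, it can in principle be derived by tracking the channel count through the layer list and recording it at every Return point. The sketch below does this for a hypothetical tuple-based layer description (not this library's configuration format), starting from an in_channels value; infer_out_shape is a hypothetical name.

```python
from typing import List, Sequence, Tuple


def infer_out_shape(layers: Sequence[Tuple], in_channels: int = 2) -> List[int]:
    """Record the channel count at every 'return' entry of a hypothetical layer list."""
    channels = in_channels
    out_shape: List[int] = []
    for layer in layers:
        kind = layer[0]
        if kind == "conv":
            channels = layer[1]  # assumed: a conv entry sets the new channel count
        elif kind == "return":
            out_shape.append(channels)  # one entry per returned tensor
        # any other kind (e.g. "pool") is assumed to keep the channel count
    return out_shape


layers = [("conv", 16), ("return",), ("pool",), ("conv", 32), ("return",)]
out_shape = infer_out_shape(layers)
print(out_shape)  # [16, 32]
```

A head built from this list would then need len(out_shape) inputs, the i-th one expecting out_shape[i] channels.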