grelu.model.heads

Model head layers to return the final prediction outputs.

Classes

ConvHead

A 1x1 Conv layer that transforms the number of channels in the input and then optionally pools along the length axis.

MLPHead

This block implements the multi-layer perceptron (MLP) module.

Module Contents

class grelu.model.heads.ConvHead(n_tasks: int, in_channels: int, act_func: str | None = None, pool_func: str | None = None, norm: bool = False, dtype=None, device=None)

Bases: torch.nn.Module

A 1x1 Conv layer that transforms the number of channels in the input and then optionally pools along the length axis.

Parameters:
  • n_tasks – Number of tasks (output channels)

  • in_channels – Number of channels in the input

  • norm – If True, batch normalization will be included.

  • act_func – Activation function for the convolutional layer

  • pool_func – Pooling function.

  • dtype – Data type for the layers.

  • device – Device for the layers.

n_tasks
in_channels
act_func
pool_func
norm
channel_transform
pool
forward(x: torch.Tensor) → torch.Tensor
Parameters:

x – Input data.
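
The behavior described for ConvHead (a 1x1 convolution that changes the channel count, followed by optional pooling along the length axis) can be illustrated with plain PyTorch. This is a minimal sketch, not gReLU's implementation: the class name ConvHeadSketch and its pool flag are hypothetical, the input is assumed to have shape (batch, in_channels, length), average pooling stands in for whatever pool_func selects, and activation and batch normalization are omitted.

```python
import torch
from torch import nn


class ConvHeadSketch(nn.Module):
    """Hypothetical stand-in for the behavior described above, not gReLU's code."""

    def __init__(self, n_tasks: int, in_channels: int, pool: bool = False):
        super().__init__()
        # A kernel size of 1 mixes channels independently at every position,
        # so the length axis is left unchanged by this layer.
        self.channel_transform = nn.Conv1d(in_channels, n_tasks, kernel_size=1)
        # Assumption: pooling averages over the length axis, leaving length 1.
        self.pool = nn.AdaptiveAvgPool1d(1) if pool else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is assumed to have shape (batch, in_channels, length).
        return self.pool(self.channel_transform(x))


x = torch.randn(4, 16, 100)
unpooled = ConvHeadSketch(n_tasks=3, in_channels=16)(x)
pooled = ConvHeadSketch(n_tasks=3, in_channels=16, pool=True)(x)
print(unpooled.shape)  # torch.Size([4, 3, 100])
print(pooled.shape)    # torch.Size([4, 3, 1])
```

Without pooling the head yields one prediction per task per position; with pooling it yields a single prediction per task for the whole sequence.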

class grelu.model.heads.MLPHead(n_tasks: int, in_channels: int, in_len: int, act_func: str | None = None, hidden_size: List[int] = [], norm: bool = False, dropout: float = 0.0, dtype=None, device=None)

Bases: torch.nn.Module

This block implements the multi-layer perceptron (MLP) module.

Parameters:
  • n_tasks – Number of tasks (output channels)

  • in_channels – Number of channels in the input

  • in_len – Length of the input

  • norm – If True, batch normalization will be included.

  • act_func – Activation function for the linear layers

  • hidden_size – A list of dimensions for each hidden layer of the MLP.

  • dropout – Dropout probability for the linear layers.

  • dtype – Data type for the layers.

  • device – Device for the layers.

n_tasks
in_channels
in_len
act_func
hidden_size
norm
dropout
blocks
forward(x: torch.Tensor) → torch.Tensor
Parameters:

x – Input data.
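
To show how in_channels, in_len and hidden_size might fit together, here is a minimal PyTorch sketch of an MLP head. It is not gReLU's implementation: the class name MLPHeadSketch is hypothetical, the input is assumed to have shape (batch, in_channels, in_len) and to be flattened before the first linear layer, ReLU is an arbitrary choice of activation, and dropout and batch normalization are omitted. The output shape of the real MLPHead may differ from the (batch, n_tasks) shape produced here.

```python
from typing import Sequence

import torch
from torch import nn


class MLPHeadSketch(nn.Module):
    """Hypothetical stand-in for the behavior described above, not gReLU's code."""

    def __init__(
        self,
        n_tasks: int,
        in_channels: int,
        in_len: int,
        hidden_size: Sequence[int] = (),
    ):
        super().__init__()
        # Assumption: channels and positions are flattened into one feature axis.
        layers = [nn.Flatten()]
        width = in_channels * in_len
        # One linear layer plus activation per entry in hidden_size.
        for hidden in hidden_size:
            layers += [nn.Linear(width, hidden), nn.ReLU()]
            width = hidden
        # The final linear layer maps to one output per task.
        layers.append(nn.Linear(width, n_tasks))
        self.blocks = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.blocks(x)


x = torch.randn(4, 16, 100)
head = MLPHeadSketch(n_tasks=3, in_channels=16, in_len=100, hidden_size=[64, 32])
out = head(x)
print(out.shape)  # torch.Size([4, 3])
```

Because the first linear layer's width is in_channels * in_len, a head built this way only accepts inputs of the length it was constructed for, which is why in_len is a required argument.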