grelu.model.layers

Commonly used layers to build deep learning models.

Classes

Activation

A nonlinear activation layer.

Pool

A pooling layer.

AdaptivePool

An adaptive pooling layer. This layer does not have a defined pooling width; instead, it pools together all the values in the last axis.

Norm

A normalization layer (batch, synchronized batch, instance, or layer normalization).

ChannelTransform

A convolutional layer to transform the number of channels in the input.

Dropout

Optional dropout layer.

Crop

Optional cropping layer.

Attention

FlashAttention

Module Contents

class grelu.model.layers.Activation(func: str)

Bases: torch.nn.Module

A nonlinear activation layer.

Parameters:

func – The type of activation function. Supported values are ‘relu’, ‘elu’, ‘softplus’, ‘gelu’, ‘gelu_enformer’, and ‘exp’. If None, the layer is an identity (nn.Identity).

Raises:

NotImplementedError – If ‘func’ is not a supported activation function.

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
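
As an illustration of the dispatch described above, the following minimal sketch maps each documented func string to a module in plain PyTorch. It is not grelu's implementation: the helper make_activation and the classes Exp and EnformerGELU are hypothetical, and the formula used for ‘gelu_enformer’ (x · sigmoid(1.702 x), a sigmoid approximation of GELU) is an assumption.

```python
import torch
from torch import nn


class Exp(nn.Module):
    """Hypothetical stand-in for the 'exp' option: elementwise exponential."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.exp(x)


class EnformerGELU(nn.Module):
    """Sigmoid approximation of GELU; ASSUMED to match 'gelu_enformer'."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(1.702 * x) * x


def make_activation(func):
    """Hypothetical helper (not part of grelu) mapping a name to a module."""
    if func is None:
        return nn.Identity()
    layers = {
        "relu": nn.ReLU,
        "elu": nn.ELU,
        "softplus": nn.Softplus,
        "gelu": nn.GELU,
        "gelu_enformer": EnformerGELU,
        "exp": Exp,
    }
    if func not in layers:
        raise NotImplementedError(f"Unsupported activation: {func!r}")
    return layers[func]()


x = torch.randn(2, 4, 8)  # (N, C, L)
y = make_activation("relu")(x)
print(y.shape)  # torch.Size([2, 4, 8]); activations act elementwise
```

Every option acts elementwise, so the output always has the same shape as the input.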

class grelu.model.layers.Pool(func: str | None, pool_size: int | None = None, in_channels: int | None = None, **kwargs)

Bases: torch.nn.Module

A pooling layer.

Parameters:
  • func – Type of pooling function. Supported values are ‘avg’, ‘max’, or ‘attn’. If None, the layer is an identity (nn.Identity).

  • pool_size – The number of positions to pool together.

  • in_channels – Number of channels in the input. Only needed for attention pooling.

  • **kwargs – Additional arguments to pass to the pooling function.

Raises:

NotImplementedError – If ‘func’ is not a supported pooling function.

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
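
A plain-PyTorch sketch of the documented pooling options, assuming ‘avg’ and ‘max’ pool non-overlapping windows of width pool_size (nn.AvgPool1d / nn.MaxPool1d with their default stride) and that ‘attn’ takes a softmax-weighted average within each window, with weights produced by a pointwise convolution over in_channels. make_pool and AttentionPool are hypothetical names, and the attention-pooling design is an assumption rather than grelu's code.

```python
import torch
from torch import nn


class AttentionPool(nn.Module):
    """Hypothetical softmax-weighted pooling over non-overlapping windows."""

    def __init__(self, in_channels: int, pool_size: int):
        super().__init__()
        self.pool_size = pool_size
        # 1x1 convolution producing one attention logit per channel and position.
        self.to_logits = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, c, length = x.shape  # assumes length is divisible by pool_size
        x = x.reshape(n, c, length // self.pool_size, self.pool_size)
        weights = self.to_logits(x).softmax(dim=-1)  # normalize within each window
        return (x * weights).sum(dim=-1)  # (N, C, L // pool_size)


def make_pool(func, pool_size=None, in_channels=None):
    """Hypothetical helper (not part of grelu) mirroring the Pool options."""
    if func is None:
        return nn.Identity()
    if func == "avg":
        return nn.AvgPool1d(pool_size)
    if func == "max":
        return nn.MaxPool1d(pool_size)
    if func == "attn":
        if in_channels is None:
            raise ValueError("in_channels is required for attention pooling")
        return AttentionPool(in_channels, pool_size)
    raise NotImplementedError(f"Unsupported pooling: {func!r}")


x = torch.randn(2, 4, 10)  # (N, C, L)
print(make_pool("max", pool_size=2)(x).shape)  # torch.Size([2, 4, 5])
```

In all three cases the length axis shrinks by a factor of pool_size, while the channel count is unchanged.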

class grelu.model.layers.AdaptivePool(func: str | None = None)

Bases: torch.nn.Module

An adaptive pooling layer. This layer does not have a defined pooling width; instead, it pools together all the values in the last axis.

Parameters:

func – Type of pooling function. Supported values are ‘avg’ or ‘max’. If None, the layer is an identity (nn.Identity).

Raises:

NotImplementedError – If ‘func’ is not a supported pooling function.

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
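
A minimal sketch of adaptive pooling in plain PyTorch, assuming the layer behaves like nn.AdaptiveAvgPool1d(1) or nn.AdaptiveMaxPool1d(1). The helper make_adaptive_pool is hypothetical, and it keeps a length-1 last axis; whether grelu's layer keeps or squeezes that axis is not stated above.

```python
import torch
from torch import nn


def make_adaptive_pool(func=None):
    """Hypothetical helper (not part of grelu): pool over the whole last axis."""
    if func is None:
        return nn.Identity()
    if func == "avg":
        return nn.AdaptiveAvgPool1d(1)  # output length fixed at 1 (assumption)
    if func == "max":
        return nn.AdaptiveMaxPool1d(1)
    raise NotImplementedError(f"Unsupported pooling: {func!r}")


x = torch.randn(2, 4, 100)  # (N, C, L); L may differ between calls
print(make_adaptive_pool("avg")(x).shape)  # torch.Size([2, 4, 1])
```

Because the output size is fixed rather than the window width, inputs of any length L map to the same output shape.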

class grelu.model.layers.Norm(func: str | None = None, in_dim: int | None = None, **kwargs)

Bases: torch.nn.Module

A normalization layer (batch, synchronized batch, instance, or layer normalization).

Parameters:
  • func – Type of normalization function. Supported values are ‘batch’, ‘syncbatch’, ‘instance’, or ‘layer’. If None, the layer is an identity (nn.Identity).

  • in_dim – Number of features in the input tensor.

  • **kwargs – Additional arguments to pass to the normalization function.

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
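
The sketch below shows one way to build each documented normalization from torch.nn. make_norm and ChannelLayerNorm are hypothetical. For ‘layer’, the sketch assumes normalization runs over the channel axis, so it moves channels to the last axis around nn.LayerNorm; raising NotImplementedError for unknown values mirrors the other layers on this page and is likewise an assumption.

```python
import torch
from torch import nn


class ChannelLayerNorm(nn.Module):
    """Hypothetical LayerNorm over the channel axis of an (N, C, L) tensor."""

    def __init__(self, in_dim: int, **kwargs):
        super().__init__()
        self.norm = nn.LayerNorm(in_dim, **kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # nn.LayerNorm normalizes the last axis, so move channels there and back.
        return self.norm(x.transpose(1, 2)).transpose(1, 2)


def make_norm(func=None, in_dim=None, **kwargs):
    """Hypothetical helper (not part of grelu) mirroring the Norm options."""
    if func is None:
        return nn.Identity()
    if func == "batch":
        return nn.BatchNorm1d(in_dim, **kwargs)
    if func == "syncbatch":
        # Synchronizes statistics across processes in distributed training.
        return nn.SyncBatchNorm(in_dim, **kwargs)
    if func == "instance":
        return nn.InstanceNorm1d(in_dim, **kwargs)
    if func == "layer":
        return ChannelLayerNorm(in_dim, **kwargs)
    raise NotImplementedError(f"Unsupported normalization: {func!r}")


x = torch.randn(8, 4, 16)  # (N, C, L)
y = make_norm("layer", in_dim=4)(x)
print(y.shape)  # torch.Size([8, 4, 16])
```

Batch normalization computes statistics per channel across the batch and length axes, while the layer-norm variant here computes them per position across channels.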

class grelu.model.layers.ChannelTransform(in_channels: int, out_channels: int = 1, if_equal: bool = False, **kwargs)

Bases: torch.nn.Module

A convolutional layer to transform the number of channels in the input.

Parameters:
  • in_channels – Number of channels in the input.

  • out_channels – Number of channels in the output.

  • if_equal – Whether to create the layer even if the input and output channel counts are equal.

  • **kwargs – Additional arguments to pass to the convolutional layer.

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
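
A minimal sketch of the channel transform, assuming a pointwise (kernel_size=1) nn.Conv1d, and assuming that when the channel counts already match and if_equal is False, no layer is created (nn.Identity). The helper make_channel_transform is hypothetical.

```python
import torch
from torch import nn


def make_channel_transform(in_channels, out_channels=1, if_equal=False, **kwargs):
    """Hypothetical helper (not part of grelu) that changes the channel count."""
    if in_channels == out_channels and not if_equal:
        # Channel counts already match, so no layer is needed (assumption).
        return nn.Identity()
    # kernel_size=1 mixes channels independently at each position (assumption).
    return nn.Conv1d(in_channels, out_channels, kernel_size=1, **kwargs)


x = torch.randn(2, 64, 50)  # (N, C, L)
print(make_channel_transform(64, 16)(x).shape)  # torch.Size([2, 16, 50])
```

A pointwise convolution changes only the channel axis, so the sequence length is preserved.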

class grelu.model.layers.Dropout(p: float = 0.0)

Bases: torch.nn.Module

Optional dropout layer.

Parameters:

p – Dropout probability. If set to 0, the layer is an identity (nn.Identity).

layer
forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
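
A short sketch of the p = 0 shortcut described above, using a hypothetical helper make_dropout. It also shows standard nn.Dropout behavior: in training mode, surviving values are scaled by 1 / (1 - p), and in eval mode the layer passes its input through unchanged.

```python
import torch
from torch import nn


def make_dropout(p: float = 0.0):
    """Hypothetical helper (not part of grelu): skip dropout when p is 0."""
    return nn.Identity() if p == 0 else nn.Dropout(p)


layer = make_dropout(0.5)
x = torch.ones(4, 8, 16)  # (N, C, L)

layer.train()
y = layer(x)  # roughly half the entries zeroed, survivors scaled by 1 / (1 - 0.5) = 2

layer.eval()
assert torch.equal(layer(x), x)  # dropout is a no-op at inference time
```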

class grelu.model.layers.Crop(crop_len: int = 0, receptive_field: int | None = None)

Bases: torch.nn.Module

Optional cropping layer.

Parameters:
  • crop_len – Number of positions to crop at each end of the input.

  • receptive_field – Receptive field of the model, used to calculate crop_len. Only needed if crop_len is None.

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
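
Cropping crop_len positions from each end of the length axis can be sketched with slicing. The function crop is hypothetical, and the sketch does not derive crop_len from receptive_field, since that formula is not given above.

```python
import torch


def crop(x: torch.Tensor, crop_len: int = 0) -> torch.Tensor:
    """Hypothetical helper (not part of grelu): trim both ends of the last axis."""
    if crop_len == 0:
        # x[..., 0:-0] would be empty, so the no-crop case needs its own branch.
        return x
    return x[..., crop_len:-crop_len]


x = torch.arange(10).reshape(1, 1, 10)  # (N, C, L)
print(crop(x, crop_len=2))  # tensor([[[2, 3, 4, 5, 6, 7]]])
```

The output length is L - 2 · crop_len.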

class grelu.model.layers.Attention(in_len: int, key_len: int, value_len: int, n_heads: int, n_pos_features: int, pos_dropout: float = 0, attn_dropout: float = 0, device=None, dtype=None)

Bases: torch.nn.Module

in_len
key_len
value_len
n_heads
n_pos_features
to_q
to_k
to_v
to_out
positional_embed
to_pos_k
rel_content_bias
rel_pos_bias
pos_dropout
attn_dropout
_get_pos_k(x)
get_attn_scores(x, return_v=False)
forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
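
The attributes listed above (to_q, to_k, to_v, to_out, n_heads, key_len, value_len) suggest multi-head attention with separate key and value widths per head. The sketch below, SimpleMultiHeadAttention (hypothetical), implements only the content part: it omits the relative positional terms (positional_embed, to_pos_k, rel_content_bias, rel_pos_bias) and both dropouts, and it assumes features lie in the last axis, i.e. input of shape (N, L, in_len); an (N, C, L) tensor would need transposing first.

```python
import math

import torch
from torch import nn


class SimpleMultiHeadAttention(nn.Module):
    """Hypothetical content-only multi-head self-attention (no position terms)."""

    def __init__(self, in_len: int, key_len: int, value_len: int, n_heads: int):
        super().__init__()
        self.key_len, self.value_len, self.n_heads = key_len, value_len, n_heads
        self.to_q = nn.Linear(in_len, key_len * n_heads, bias=False)
        self.to_k = nn.Linear(in_len, key_len * n_heads, bias=False)
        self.to_v = nn.Linear(in_len, value_len * n_heads, bias=False)
        self.to_out = nn.Linear(value_len * n_heads, in_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        n, length, _ = x.shape  # (N, L, in_len): features in the last axis (assumption)
        h = self.n_heads

        def split_heads(t: torch.Tensor, dim: int) -> torch.Tensor:
            # (N, L, H * dim) -> (N, H, L, dim)
            return t.view(n, length, h, dim).transpose(1, 2)

        q = split_heads(self.to_q(x), self.key_len)
        k = split_heads(self.to_k(x), self.key_len)
        v = split_heads(self.to_v(x), self.value_len)

        # Scaled dot-product scores between every pair of positions: (N, H, L, L).
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.key_len)
        attn = scores.softmax(dim=-1)

        out = attn @ v  # (N, H, L, value_len)
        out = out.transpose(1, 2).reshape(n, length, h * self.value_len)
        return self.to_out(out)  # back to (N, L, in_len)


torch.manual_seed(0)
layer = SimpleMultiHeadAttention(in_len=16, key_len=8, value_len=8, n_heads=2)
x = torch.randn(2, 7, 16)
print(layer(x).shape)  # torch.Size([2, 7, 16])
```

Without positional terms this layer is permutation-equivariant: shuffling the input positions shuffles the output positions identically. Positional terms such as the ones listed above are what make attention sensitive to where each position sits in the sequence.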

class grelu.model.layers.FlashAttention(embed_dim: int, n_heads: int, dropout_p=0.0, device=None, dtype=None)

Bases: torch.nn.Module

embed_dim
n_heads
head_dim
dropout_p
qkv
out
rotary_embed
flash_attn_qkvpacked_func
forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (batch_size, seq_len, embed_dim)

Returns:

Output tensor
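
FlashAttention's attributes (qkv, out, head_dim, flash_attn_qkvpacked_func) point to a single packed QKV projection fed to a fused attention kernel. The CPU-friendly sketch below, PackedQKVAttention (hypothetical), follows that layout but substitutes torch.nn.functional.scaled_dot_product_attention (PyTorch 2.0 or later) for flash_attn_qkvpacked_func and omits the rotary embedding (rotary_embed), so it is not equivalent to the real layer.

```python
import torch
import torch.nn.functional as F
from torch import nn


class PackedQKVAttention(nn.Module):
    """Hypothetical self-attention with one packed QKV projection (no rotary)."""

    def __init__(self, embed_dim: int, n_heads: int, dropout_p: float = 0.0):
        super().__init__()
        assert embed_dim % n_heads == 0, "embed_dim must be divisible by n_heads"
        self.n_heads = n_heads
        self.head_dim = embed_dim // n_heads
        self.dropout_p = dropout_p
        self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
        self.out = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, length, d = x.shape  # (batch_size, seq_len, embed_dim)
        # Packed layout: (B, L, 3, H, head_dim).
        qkv = self.qkv(x).view(b, length, 3, self.n_heads, self.head_dim)
        # Rearrange to (3, B, H, L, head_dim) and split into q, k, v.
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        # Stand-in for the fused kernel; dropout is applied only while training.
        out = F.scaled_dot_product_attention(
            q, k, v, dropout_p=self.dropout_p if self.training else 0.0
        )
        out = out.transpose(1, 2).reshape(b, length, d)  # merge heads
        return self.out(out)


layer = PackedQKVAttention(embed_dim=32, n_heads=4).eval()
x = torch.randn(2, 10, 32)
print(layer(x).shape)  # torch.Size([2, 10, 32])
```

The (B, L, 3, H, head_dim) packing matches the qkv layout that flash-attn's flash_attn_qkvpacked_func expects; that kernel runs on CUDA with fp16 or bf16 inputs, which is one reason a portable stand-in is used here.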