grelu.model.trunks.enformer

The Enformer model architecture and its required classes

Classes

EnformerConvTower

Convolutional/pooling tower for the Enformer model.

EnformerTransformerBlock

A single transformer block for the Enformer model.

EnformerTransformerTower

Transformer tower for the Enformer model.

EnformerTrunk

Enformer model architecture.

Module Contents

class grelu.model.trunks.enformer.EnformerConvTower(n_blocks: int, out_channels: int)

Bases: torch.nn.Module

Parameters:
  • n_blocks – Number of convolutional/pooling blocks including the stem.

  • out_channels – Number of channels in the output

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
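The length of the tower's output can be estimated with simple arithmetic. The sketch below assumes, as in the original Enformer design, that each of the n_blocks blocks (stem included) ends in a pooling step of size 2 and that the last block emits out_channels channels. The helper conv_tower_output_shape is hypothetical and is not part of gReLU.

```python
from typing import Tuple


def conv_tower_output_shape(
    n: int, length: int, n_blocks: int, out_channels: int, pool_size: int = 2
) -> Tuple[int, int, int]:
    """Estimate the (N, C, L) output shape of a conv/pooling tower.

    Assumption (not taken from gReLU): every one of the n_blocks blocks,
    including the stem, ends with a pooling layer of size `pool_size`
    that floors the sequence length.
    """
    if n_blocks < 1:
        raise ValueError("n_blocks must be at least 1 (the stem)")
    out_len = length
    for _ in range(n_blocks):
        # Each pooling step shrinks the length by a factor of pool_size.
        out_len //= pool_size
    return (n, out_channels, out_len)


# Enformer-sized input: batch of 2, 196,608 bp, 7 blocks, 1536 channels.
print(conv_tower_output_shape(2, 196_608, n_blocks=7, out_channels=1536))
# (2, 1536, 1536)
```

Under this assumption the total downsampling factor is 2 ** n_blocks, so 7 blocks map each output position to 128 input positions.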

class grelu.model.trunks.enformer.EnformerTransformerBlock(in_len: int, n_heads: int, key_len: int, attn_dropout: float, pos_dropout: float, ff_dropout: float)

Bases: torch.nn.Module

A single transformer block for the Enformer model.

Parameters:
  • in_len – Length of the input

  • n_heads – Number of attention heads

  • key_len – Length of the key vectors

  • pos_dropout – Dropout probability in the positional embeddings

  • attn_dropout – Dropout probability in the attention layer

  • ff_dropout – Dropout probability in the linear feed-forward layers

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
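Because the block takes six positional hyperparameters, it can help to collect and check them before constructing the module. The sketch below uses a hypothetical TransformerBlockArgs dataclass (not part of gReLU) that mirrors the signature above; the only constraints it enforces are that sizes are positive integers and that dropout probabilities lie in [0, 1], the range torch.nn.Dropout accepts.

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class TransformerBlockArgs:
    """Hypothetical container mirroring EnformerTransformerBlock's arguments."""

    in_len: int
    n_heads: int
    key_len: int
    attn_dropout: float
    pos_dropout: float
    ff_dropout: float

    def __post_init__(self) -> None:
        # Sizes must be positive integers.
        for name in ("in_len", "n_heads", "key_len"):
            value = getattr(self, name)
            if not isinstance(value, int) or value < 1:
                raise ValueError(f"{name} must be a positive integer, got {value!r}")
        # Dropout probabilities must lie in [0, 1].
        for name in ("attn_dropout", "pos_dropout", "ff_dropout"):
            p = getattr(self, name)
            if not 0.0 <= p <= 1.0:
                raise ValueError(f"{name} must be in [0, 1], got {p!r}")

    def as_kwargs(self) -> dict:
        """Return the arguments as a dict of keyword arguments."""
        return asdict(self)


# Illustrative values; n_heads, key_len and the dropouts match EnformerTrunk's defaults.
args = TransformerBlockArgs(
    in_len=1536, n_heads=8, key_len=64,
    attn_dropout=0.05, pos_dropout=0.01, ff_dropout=0.4,
)
print(args.as_kwargs()["n_heads"])  # 8
```

The resulting dict could then be unpacked into the constructor, e.g. EnformerTransformerBlock(**args.as_kwargs()), since its keys match the parameter names in the signature.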

class grelu.model.trunks.enformer.EnformerTransformerTower(in_channels: int, n_blocks: int, n_heads: int, key_len: int, attn_dropout: float, pos_dropout: float, ff_dropout: float)

Bases: torch.nn.Module

Transformer tower for the Enformer model.

Parameters:
  • in_channels – Number of channels in the input

  • n_blocks – Number of stacked transformer blocks

  • n_heads – Number of attention heads

  • key_len – Length of the key vectors

  • pos_dropout – Dropout probability in the positional embeddings

  • attn_dropout – Dropout probability in the attention layer

  • ff_dropout – Dropout probability in the linear feed-forward layers

forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
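Conceptually, the tower stacks n_blocks transformer blocks and feeds each block's output into the next, so every block must return a tensor of the same shape it received. The framework-free sketch below shows only that composition; make_tower and the toy residual block are hypothetical stand-ins, not the gReLU implementation, which is built from PyTorch modules.

```python
from functools import reduce
from typing import Callable, List, TypeVar

T = TypeVar("T")


def make_tower(block_factory: Callable[[], Callable[[T], T]], n_blocks: int) -> Callable[[T], T]:
    """Build a tower of n_blocks independently constructed blocks applied in sequence."""
    # One fresh block per layer, mirroring a list of stacked modules.
    blocks: List[Callable[[T], T]] = [block_factory() for _ in range(n_blocks)]

    def tower(x: T) -> T:
        # The output of each block becomes the input of the next.
        return reduce(lambda h, block: block(h), blocks, x)

    return tower


# Toy residual block: x + f(x) with f(x) = 1 for every element.
toy_tower = make_tower(lambda: (lambda v: [vi + 1 for vi in v]), n_blocks=11)
print(toy_tower([0, 1, 2]))  # [11, 12, 13]
```

Constructing a fresh block per layer matters in the real model: each stacked block has its own weights, rather than one block being applied n_blocks times.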

class grelu.model.trunks.enformer.EnformerTrunk(n_conv: int = 7, channels: int = 1536, n_transformers: int = 11, n_heads: int = 8, key_len: int = 64, attn_dropout: float = 0.05, pos_dropout: float = 0.01, ff_dropout: float = 0.4, crop_len: int = 0)

Bases: torch.nn.Module

Enformer model architecture.

Parameters:
  • n_conv – Number of convolutional/pooling blocks

  • channels – Number of output channels for the convolutional tower

  • n_transformers – Number of stacked transformer blocks

  • n_heads – Number of attention heads

  • key_len – Length of the key vectors

  • pos_dropout – Dropout probability in the positional embeddings

  • attn_dropout – Dropout probability in the attention layer

  • ff_dropout – Dropout probability in the linear feed-forward layers

  • crop_len – Number of positions to crop at either end of the output
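The effect of n_conv and crop_len on the number of output positions can be sketched as follows. This assumes that each of the n_conv blocks pools by a factor of 2, that the transformer tower preserves sequence length, and that crop_len positions are removed from each end as described above. The helper trunk_output_length is hypothetical and not part of gReLU.

```python
def trunk_output_length(
    seq_len: int, n_conv: int = 7, crop_len: int = 0, pool_size: int = 2
) -> int:
    """Estimate the number of output positions of an Enformer-style trunk.

    Assumptions: every conv/pooling block pools by `pool_size`, the
    transformer tower keeps the length unchanged, and crop_len positions
    are removed from each end of the output.
    """
    length = seq_len
    for _ in range(n_conv):
        # Each conv/pooling block shrinks the length by pool_size.
        length //= pool_size
    cropped = length - 2 * crop_len
    if cropped < 1:
        raise ValueError(
            f"crop_len={crop_len} leaves no positions (length after pooling: {length})"
        )
    return cropped


print(trunk_output_length(196_608))                # 1536
print(trunk_output_length(196_608, crop_len=320))  # 896
```

With the default crop_len = 0 nothing is removed; cropping 320 positions per side reproduces the 896 central positions predicted by the original Enformer setup.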

forward(x)