grelu.model.trunks#

Some general purpose model architectures.

Classes#

  • ConvTower – A module that consists of multiple convolutional blocks and takes a one-hot encoded DNA sequence as input.

  • GRUBlock – Stacked bidirectional GRU layers followed by a feed-forward network.

  • TransformerTower – Multiple stacked transformer encoder layers.

  • ConvTrunk – A fully convolutional trunk that optionally includes pooling, residual connections, batch normalization, or dilated convolutions.

  • DilatedConvTrunk – A model architecture based on dilated convolutional layers with residual connections.

  • ConvGRUTrunk – A model consisting of a convolutional tower followed by a bidirectional GRU layer and optional pooling.

  • ConvTransformerTrunk – A model consisting of a convolutional tower followed by a transformer encoder layer and optional pooling.

Package Contents#

class grelu.model.trunks.ConvTower(stem_channels: int, stem_kernel_size: int, n_blocks: int = 2, channel_init: int = 16, channel_mult: float = 1, kernel_size: int = 5, dilation_init: int = 1, dilation_mult: float = 1, act_func: str = 'relu', norm: bool = False, pool_func: str | None = None, pool_size: int | None = None, residual: bool = False, dropout: float = 0.0, order: str = 'CDNRA', crop_len: int | str = 0, dtype=None, device=None)[source]#

Bases: torch.nn.Module

A module that consists of multiple convolutional blocks and takes a one-hot encoded DNA sequence as input.

Parameters:
  • n_blocks – Number of convolutional blocks, including the stem

  • stem_channels – Number of channels in the stem

  • stem_kernel_size – Kernel width for the stem

  • kernel_size – Convolutional kernel width

  • channel_init – Initial number of channels

  • channel_mult – Factor by which to multiply the number of channels in each block

  • dilation_init – Initial dilation

  • dilation_mult – Factor by which to multiply the dilation in each block

  • act_func – Name of the activation function

  • pool_func – Name of the pooling function

  • pool_size – Width of the pooling layers

  • dropout – Dropout probability

  • norm – If True, apply batch norm

  • residual – If True, apply residual connection

  • order – A string representing the order in which operations are to be performed on the input. For example, “CDNRA” means that the operations will be performed in the order: convolution, dropout, batch norm, residual addition, activation. Pooling is not included as it is always performed last.

  • crop_len – Number of positions to crop at either end of the output

  • dtype – Data type of the weights

  • device – Device on which to store the weights

blocks[source]#
receptive_field[source]#
pool_factor = 1[source]#
out_channels[source]#
crop[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
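
A minimal usage sketch with illustrative argument values; the 4-channel input assumes a one-hot A/C/G/T encoding as described in the class summary:

import torch
from grelu.model.trunks import ConvTower

# Stem plus one additional convolutional block, with max pooling after each block.
tower = ConvTower(
    stem_channels=16,
    stem_kernel_size=15,
    n_blocks=2,
    channel_init=16,
    kernel_size=5,
    pool_func="max",
    pool_size=2,
)

x = torch.randn(8, 4, 200)  # (N, C, L); a real one-hot DNA batch would contain 0/1 values
y = tower(x)                # pooling and crop_len determine the output length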

class grelu.model.trunks.GRUBlock(in_channels: int, n_layers: int = 1, dropout: float = 0.0, act_func: str = 'relu', norm: bool = False, dtype=None, device=None)[source]#

Bases: torch.nn.Module

Stacked bidirectional GRU layers followed by a feed-forward network.

Parameters:
  • in_channels – The number of channels in the input

  • n_layers – The number of GRU layers

  • gru_hidden_size – Number of hidden elements in GRU layers

  • dropout – Dropout probability

  • act_func – Name of the activation function for feed-forward network

  • norm – If True, include layer normalization in feed-forward network.

  • dtype – Data type of the weights

  • device – Device on which to store the weights

gru[source]#
ffn[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
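
A minimal usage sketch with illustrative argument values; the channel dimension of the input must match in_channels:

import torch
from grelu.model.trunks import GRUBlock

gru = GRUBlock(in_channels=16, n_layers=2, dropout=0.1, norm=True)

x = torch.randn(8, 16, 200)  # (N, C, L) with C = in_channels
y = gru(x)                   # bidirectional GRUs followed by the feed-forward network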

class grelu.model.trunks.TransformerTower(in_channels: int, n_blocks: int = 1, n_heads: int = 1, n_pos_features: int = 32, key_len: int = 64, value_len: int = 64, pos_dropout: float = 0.0, attn_dropout: float = 0.0, ff_dropout: float = 0.0, flash_attn: bool = False, dtype=None, device=None)[source]#

Bases: torch.nn.Module

Multiple stacked transformer encoder layers.

Parameters:
  • in_channels – Number of channels in the input

  • n_blocks – Number of stacked transformer blocks

  • n_heads – Number of attention heads

  • n_pos_features – Number of positional embedding features

  • key_len – Length of the key vectors

  • value_len – Length of the value vectors.

  • pos_dropout – Dropout probability in the positional embeddings

  • attn_dropout – Dropout probability in the output layer

  • ff_dropout – Dropout probability in the linear feed-forward layers

  • flash_attn – If True, uses Flash Attention with Rotational Position Embeddings. key_len, value_len, pos_dropout and n_pos_features are ignored.

  • dtype – Data type of the weights

  • device – Device on which to store the weights

blocks[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
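
A minimal usage sketch with illustrative argument values; the channel dimension of the input must match in_channels:

import torch
from grelu.model.trunks import TransformerTower

transformer = TransformerTower(
    in_channels=32,
    n_blocks=2,
    n_heads=4,
    n_pos_features=32,
    key_len=16,
    value_len=16,
    attn_dropout=0.1,
    ff_dropout=0.1,
)

x = torch.randn(8, 32, 200)  # (N, C, L) with C = in_channels
y = transformer(x)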

class grelu.model.trunks.ConvTrunk(stem_channels: int = 64, stem_kernel_size: int = 15, n_conv: int = 2, channel_init: int = 64, channel_mult: float = 1, kernel_size: int = 5, dilation_init: int = 1, dilation_mult: float = 1, act_func: str = 'relu', norm: bool = False, pool_func: str | None = None, pool_size: int | None = None, residual: bool = False, dropout: float = 0.0, crop_len: int = 0, **kwargs)[source]#

Bases: torch.nn.Module

A fully convolutional trunk that optionally includes pooling, residual connections, batch normalization, or dilated convolutions.

Parameters:
  • stem_channels – Number of channels in the stem

  • stem_kernel_size – Kernel width for the stem

  • n_conv – Number of convolutional blocks, not including the stem

  • kernel_size – Convolutional kernel width

  • channel_init – Initial number of channels

  • channel_mult – Factor by which to multiply the number of channels in each block

  • dilation_init – Initial dilation

  • dilation_mult – Factor by which to multiply the dilation in each block

  • act_func – Name of the activation function

  • pool_func – Name of the pooling function

  • pool_size – Width of the pooling layers

  • dropout – Dropout probability

  • norm – If True, apply batch norm

  • residual – If True, apply residual connection

  • order – A string representing the order in which operations are to be performed on the input. For example, “CDNRA” means that the operations will be performed in the order: convolution, dropout, batch norm, residual addition, activation. Pooling is not included as it is always performed last.

  • crop_len – Number of positions to crop at either end of the output

  • kwargs – Additional keyword arguments for the convolutional blocks

conv_tower[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
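
A minimal usage sketch with illustrative argument values; the 4-channel input assumes a one-hot encoded DNA sequence:

import torch
from grelu.model.trunks import ConvTrunk

trunk = ConvTrunk(
    stem_channels=64,
    stem_kernel_size=15,
    n_conv=4,
    channel_init=64,
    channel_mult=2,
    kernel_size=5,
    norm=True,
    pool_func="avg",
    pool_size=2,
)

x = torch.randn(2, 4, 1024)  # (N, 4, L) one-hot encoded DNA
y = trunk(x)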

class grelu.model.trunks.DilatedConvTrunk(channels: int = 64, stem_kernel_size: int = 21, kernel_size: int = 3, dilation_mult: float = 2, act_func: str = 'relu', n_conv: int = 8, crop_len: str | int = 'auto', **kwargs)[source]#

Bases: torch.nn.Module

A model architecture based on dilated convolutional layers with residual connections. Inspired by the ChromBPnet model architecture.

Parameters:
  • channels – Number of channels for all convolutional layers

  • stem_kernel_size – Kernel width for the stem

  • n_conv – Number of convolutional blocks, not including the stem

  • kernel_size – Convolutional kernel width

  • dilation_mult – Factor by which to multiply the dilation in each block

  • act_func – Name of the activation function

  • crop_len – Number of positions to crop at either end of the output

  • kwargs – Additional keyword arguments for the dilated-convolutional blocks

conv_tower[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
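
A minimal usage sketch with illustrative argument values, in the spirit of the ChromBPnet-style architecture described above:

import torch
from grelu.model.trunks import DilatedConvTrunk

trunk = DilatedConvTrunk(
    channels=64,
    stem_kernel_size=21,
    kernel_size=3,
    dilation_mult=2,
    n_conv=8,
    crop_len="auto",  # presumably crops edge positions based on the receptive field
)

x = torch.randn(2, 4, 2048)  # (N, 4, L) one-hot encoded DNA
y = trunk(x)                 # shorter than the input when cropping is applied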

class grelu.model.trunks.ConvGRUTrunk(stem_channels: int = 16, stem_kernel_size: int = 15, n_conv: int = 2, channel_init: int = 16, channel_mult: float = 1, kernel_size: int = 5, act_func: str = 'relu', conv_norm: bool = False, pool_func: str | None = None, pool_size: int | None = None, residual: bool = False, crop_len: int = 0, n_gru: int = 1, dropout: float = 0.0, gru_norm: bool = False, dtype=None, device=None)[source]#

Bases: torch.nn.Module

A model consisting of a convolutional tower followed by a bidirectional GRU layer and optional pooling.

Parameters:
  • stem_channels – Number of channels in the stem

  • stem_kernel_size – Kernel width for the stem

  • n_conv – Number of convolutional blocks, not including the stem

  • kernel_size – Convolutional kernel width

  • channel_init – Initial number of channels

  • channel_mult – Factor by which to multiply the number of channels in each block

  • act_func – Name of the activation function

  • pool_func – Name of the pooling function

  • pool_size – Width of the pooling layers

  • conv_norm – If True, apply batch normalization in the convolutional layers.

  • residual – If True, apply residual connections in the convolutional layers.

  • crop_len – Number of positions to crop at either end of the output

  • n_gru – Number of GRU layers

  • dropout – Dropout for GRU and feed-forward layers

  • gru_norm – If True, include layer normalization in feed-forward network.

  • dtype – Data type for the layers.

  • device – Device for the layers.

conv_tower[source]#
gru_tower[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
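
A minimal usage sketch with illustrative argument values; the 4-channel input assumes a one-hot encoded DNA sequence:

import torch
from grelu.model.trunks import ConvGRUTrunk

trunk = ConvGRUTrunk(
    stem_channels=16,
    stem_kernel_size=15,
    n_conv=2,
    channel_init=16,
    kernel_size=5,
    pool_func="max",
    pool_size=2,
    n_gru=1,
    dropout=0.1,
    gru_norm=True,
)

x = torch.randn(4, 4, 500)  # (N, 4, L) one-hot encoded DNA
y = trunk(x)                # convolutional tower followed by the bidirectional GRU block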

class grelu.model.trunks.ConvTransformerTrunk(stem_channels: int = 16, stem_kernel_size: int = 15, n_conv: int = 2, channel_init: int = 16, channel_mult: float = 1, kernel_size: int = 5, act_func: str = 'relu', norm: bool = False, pool_func: str | None = None, pool_size: int | None = None, residual: bool = False, crop_len: int = 0, n_transformers=1, key_len: int = 8, value_len: int = 8, n_heads: int = 1, n_pos_features: int = 4, pos_dropout: float = 0.0, attn_dropout: float = 0.0, ff_dropout: float = 0.0, dtype=None, device=None)[source]#

Bases: torch.nn.Module

A model consisting of a convolutional tower followed by a transformer encoder layer and optional pooling.

Parameters:
  • stem_channels – Number of channels in the stem

  • stem_kernel_size – Kernel width for the stem

  • n_conv – Number of convolutional blocks, not including the stem

  • kernel_size – Convolutional kernel width

  • channel_init – Initial number of channels

  • channel_mult – Factor by which to multiply the number of channels in each block

  • act_func – Name of the activation function

  • pool_func – Name of the pooling function

  • pool_size – Width of the pooling layers

  • norm – If True, apply batch normalization in the convolutional layers.

  • residual – If True, apply residual connections in the convolutional layers.

  • crop_len – Number of positions to crop at either end of the output

  • n_transformers – Number of transformer encoder layers

  • n_heads – Number of heads in each multi-head attention layer

  • n_pos_features – Number of positional embedding features

  • key_len – Length of the key vectors

  • value_len – Length of the value vectors.

  • pos_dropout – Dropout probability in the positional embeddings

  • attn_dropout – Dropout probability in the output layer

  • ff_dropout – Dropout probability in the linear feed-forward layers

  • device – Device for the layers.

  • dtype – Data type for the layers.

conv_tower[source]#
transformer_tower[source]#
forward(x: torch.Tensor) torch.Tensor[source]#

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
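
A minimal usage sketch with illustrative argument values; the 4-channel input assumes a one-hot encoded DNA sequence:

import torch
from grelu.model.trunks import ConvTransformerTrunk

trunk = ConvTransformerTrunk(
    stem_channels=16,
    stem_kernel_size=15,
    n_conv=2,
    channel_init=16,
    kernel_size=5,
    pool_func="max",
    pool_size=2,
    n_transformers=2,
    n_heads=2,
    key_len=8,
    value_len=8,
    n_pos_features=4,
)

x = torch.randn(4, 4, 500)  # (N, 4, L) one-hot encoded DNA
y = trunk(x)                # convolutional tower followed by the transformer encoder layers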