grelu.model.trunks.borzoi

The Borzoi model architecture and its required classes.

Classes

BorzoiConvTower

Convolutional tower for the Borzoi model.

BorzoiTrunk

Trunk consisting of conv, transformer and U-net layers for the Borzoi model.

Module Contents

class grelu.model.trunks.borzoi.BorzoiConvTower(stem_channels: int, stem_kernel_size: int, init_channels: int, out_channels: int, kernel_size: int, n_blocks: int, norm_type='batch', norm_kwargs=None, dtype=None, device=None)

Bases: torch.nn.Module

Convolutional tower for the Borzoi model.

Parameters:
  • stem_channels – Number of channels in the first (stem) convolutional layer

  • stem_kernel_size – Width of the convolutional kernel in the first (stem) convolutional layer

  • init_channels – Number of channels in the first convolutional block after the stem

  • out_channels – Number of channels in the output

  • kernel_size – Width of the convolutional kernel

  • n_blocks – Number of convolutional/pooling blocks, including the stem

  • norm_type – Type of normalization to apply: ‘batch’, ‘syncbatch’, ‘layer’, ‘instance’ or None

  • norm_kwargs – Additional arguments to be passed to the normalization layer

  • dtype – Data type for the layers.

  • device – Device for the layers.
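
As a rough sketch of how these arguments fit together, the helper below assembles keyword arguments for BorzoiConvTower and checks norm_type against the values listed above before anything is constructed. The numeric values are small illustrative placeholders, not library defaults, and conv_tower_kwargs and build_conv_tower are hypothetical helper names. The import is deferred so the snippet still runs where gReLU is not installed.

```python
# Values documented above as accepted by `norm_type`.
ALLOWED_NORM_TYPES = ("batch", "syncbatch", "layer", "instance", None)


def conv_tower_kwargs(norm_type="batch", **overrides):
    """Assemble keyword arguments for BorzoiConvTower (hypothetical helper).

    The numbers below are small illustrative placeholders, not defaults
    taken from the library.
    """
    if norm_type not in ALLOWED_NORM_TYPES:
        raise ValueError(f"Unsupported norm_type: {norm_type!r}")
    kwargs = dict(
        stem_channels=32,
        stem_kernel_size=15,
        init_channels=48,
        out_channels=96,
        kernel_size=5,
        n_blocks=3,  # counts the stem, as documented above
        norm_type=norm_type,
    )
    kwargs.update(overrides)
    return kwargs


def build_conv_tower(**overrides):
    """Instantiate the tower; requires gReLU and PyTorch to be installed."""
    from grelu.model.trunks.borzoi import BorzoiConvTower  # deferred import

    return BorzoiConvTower(**conv_tower_kwargs(**overrides))
```

Validating norm_type up front gives a clear error message before any layers are allocated.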

blocks
filters
forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
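
To illustrate the (N, C, L) layout, the sketch below one-hot encodes DNA strings into nested lists of that shape. It assumes C = 4 channels ordered A, C, G, T, which is a common convention but is not stated on this page. one_hot_batch is a hypothetical helper; in practice the result would be converted to a floating-point torch.Tensor before being passed to forward.

```python
BASES = "ACGT"  # assumed channel order; not specified on this page


def one_hot_batch(seqs):
    """Encode equal-length DNA strings as nested lists of shape (N, 4, L)."""
    length = len(seqs[0])
    if any(len(s) != length for s in seqs):
        raise ValueError("All sequences in a batch must have the same length")
    batch = []
    for seq in seqs:
        seq = seq.upper()
        # One row per channel (base), one column per sequence position.
        # Characters outside ACGT (e.g. N) become all-zero columns.
        batch.append(
            [[1.0 if ch == base else 0.0 for ch in seq] for base in BASES]
        )
    return batch


x = one_hot_batch(["ACGT", "TTNA"])
print(len(x), len(x[0]), len(x[0][0]))  # N, C, L
```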

class grelu.model.trunks.borzoi.BorzoiTrunk(stem_channels: int, stem_kernel_size: int, init_channels: int, n_conv: int, kernel_size: int, channels: int, n_transformers: int, key_len: int, value_len: int, pos_dropout: float, attn_dropout: float, n_heads: int, n_pos_features: int, crop_len: int, flash_attn: bool, norm_type='batch', norm_kwargs=None, dtype=None, device=None)

Bases: torch.nn.Module

Trunk consisting of conv, transformer and U-net layers for the Borzoi model.

Parameters:
  • stem_channels – Number of channels in the first (stem) convolutional layer

  • stem_kernel_size – Width of the convolutional kernel in the first (stem) convolutional layer

  • init_channels – Number of channels in the first convolutional block after the stem

  • n_conv – Number of convolutional/pooling blocks, including the stem

  • kernel_size – Width of the convolutional kernel

  • channels – Number of channels in the output

  • n_transformers – Number of transformer blocks

  • key_len – Length of the key vectors

  • value_len – Length of the value vectors

  • pos_dropout – Dropout rate for positional embeddings

  • attn_dropout – Dropout rate for attention

  • n_heads – Number of attention heads

  • n_pos_features – Number of positional features

  • crop_len – Number of positions to crop at either end of the output

  • flash_attn – If True, uses Flash Attention with Rotary Position Embeddings (RoPE). In that case, key_len, value_len, pos_dropout and n_pos_features are ignored.

  • norm_type – Type of normalization to apply: ‘batch’, ‘syncbatch’, ‘layer’, ‘instance’ or None

  • norm_kwargs – Additional arguments to be passed to the normalization layer

  • dtype – Data type for the layers.

  • device – Device for the layers.
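
Because key_len, value_len, pos_dropout and n_pos_features have no defaults in the signature above, they presumably still have to be supplied when flash_attn=True, even though they are then ignored. The sketch below shows one way to keep that explicit. trunk_kwargs and ignored_args are hypothetical helpers, and every numeric value is an illustrative placeholder rather than a library default.

```python
# Arguments documented above as ignored when flash_attn=True.
IGNORED_WITH_FLASH_ATTN = ("key_len", "value_len", "pos_dropout", "n_pos_features")


def trunk_kwargs(flash_attn, **overrides):
    """Assemble the required BorzoiTrunk arguments (hypothetical helper)."""
    kwargs = dict(
        stem_channels=32,
        stem_kernel_size=15,
        init_channels=48,
        n_conv=3,
        kernel_size=5,
        channels=96,
        n_transformers=1,
        key_len=16,         # placeholder; ignored if flash_attn is True
        value_len=16,       # placeholder; ignored if flash_attn is True
        pos_dropout=0.0,    # placeholder; ignored if flash_attn is True
        attn_dropout=0.0,
        n_heads=4,
        n_pos_features=8,   # placeholder; ignored if flash_attn is True
        crop_len=0,
        flash_attn=flash_attn,
    )
    kwargs.update(overrides)
    return kwargs


def ignored_args(kwargs):
    """List the arguments that have no effect for this configuration."""
    return list(IGNORED_WITH_FLASH_ATTN) if kwargs["flash_attn"] else []


def build_trunk(flash_attn, **overrides):
    """Instantiate the trunk; requires gReLU and PyTorch to be installed."""
    from grelu.model.trunks.borzoi import BorzoiTrunk  # deferred import

    return BorzoiTrunk(**trunk_kwargs(flash_attn, **overrides))
```

Reporting the ignored names next to a saved configuration helps avoid the impression that changing them altered the model.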

conv_tower
transformer_tower
unet_tower
pointwise_conv
act
crop
forward(x: torch.Tensor) → torch.Tensor

Forward pass

Parameters:

x – Input tensor of shape (N, C, L)

Returns:

Output tensor
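
This page does not state the output shape. As a back-of-the-envelope sketch, the function below estimates the output length under three assumptions drawn from the published Borzoi design rather than from this page: each of the n_conv blocks halves the sequence length, the U-net stage upsamples by a total factor of 4, and crop_len positions are removed from each end. trunk_output_length and its upsample_factor argument are hypothetical, so compare the result with the actual output of forward before relying on it.

```python
def trunk_output_length(seq_len, n_conv, crop_len, upsample_factor=4):
    """Estimate the output length of the trunk under the stated assumptions."""
    # Assumption: every conv/pooling block halves the length.
    pooled, remainder = divmod(seq_len, 2 ** n_conv)
    if remainder:
        raise ValueError("seq_len should be divisible by 2 ** n_conv")
    # Assumption: the U-net upsamples by `upsample_factor` in total, and
    # cropping removes `crop_len` positions from each end.
    out_len = pooled * upsample_factor - 2 * crop_len
    if out_len <= 0:
        raise ValueError("crop_len removes the entire output")
    return out_len


# 524,288 bp input: 524288 / 128 = 4096, * 4 = 16384, - 2 * 5120 = 6144
print(trunk_output_length(524_288, n_conv=7, crop_len=5_120))  # 6144
```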