grelu.model.trunks.borzoi
The Borzoi model architecture and its required classes.
Classes
| BorzoiConvTower | Convolutional tower for the Borzoi model. |
| BorzoiTrunk | Trunk consisting of conv, transformer and U-net layers for the Borzoi model. |
Module Contents
- class grelu.model.trunks.borzoi.BorzoiConvTower(stem_channels: int, stem_kernel_size: int, init_channels: int, out_channels: int, kernel_size: int, n_blocks: int, norm_type='batch', norm_kwargs=None, dtype=None, device=None)
Bases: torch.nn.Module
Convolutional tower for the Borzoi model.
- Parameters:
stem_channels – Number of channels in the first (stem) convolutional layer
stem_kernel_size – Width of the convolutional kernel in the first (stem) convolutional layer
init_channels – Number of channels in the first convolutional block after the stem
out_channels – Number of channels in the output
kernel_size – Width of the convolutional kernel
n_blocks – Number of convolutional/pooling blocks, including the stem
norm_type – Type of normalization to apply: ‘batch’, ‘syncbatch’, ‘layer’, ‘instance’ or None
norm_kwargs – Additional arguments to be passed to the normalization layer
dtype – Data type for the layers.
device – Device for the layers.
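To make the interplay of stem_channels, init_channels, out_channels and n_blocks concrete, here is a minimal, hypothetical sketch of the shape bookkeeping for a tower like this one. It is not grelu's implementation: the helper conv_tower_plan is invented for illustration, and it assumes that every block (the stem included) ends with a pooling step of size 2 and that channel counts after the stem are interpolated linearly from init_channels to out_channels. The library may use a different channel schedule.

```python
from typing import List, Tuple


def conv_tower_plan(
    seq_len: int,
    stem_channels: int,
    init_channels: int,
    out_channels: int,
    n_blocks: int,
    pool_size: int = 2,
) -> List[Tuple[int, int]]:
    """Return (channels, length) after each block of a hypothetical
    BorzoiConvTower-like stack.

    Assumptions (not taken from grelu): every block, including the stem,
    divides the length by `pool_size`, and channel counts after the stem
    grow linearly from `init_channels` to `out_channels`.
    """
    if n_blocks < 2:
        raise ValueError("n_blocks includes the stem; need at least 2")
    if seq_len % pool_size**n_blocks != 0:
        raise ValueError("seq_len must be divisible by pool_size ** n_blocks")

    n_after_stem = n_blocks - 1
    channels = [stem_channels]  # the stem comes first
    for i in range(n_after_stem):
        if n_after_stem == 1:
            # Only one block after the stem: it produces the output channels.
            channels.append(out_channels)
        else:
            frac = i / (n_after_stem - 1)
            channels.append(round(init_channels + frac * (out_channels - init_channels)))

    plan = []
    length = seq_len
    for ch in channels:
        length //= pool_size  # each block divides the length by pool_size (assumed)
        plan.append((ch, length))
    return plan


for ch, length in conv_tower_plan(1024, 16, 32, 64, 3):
    print(ch, length)  # prints: 16 512, then 32 256, then 64 128
```

Under these assumptions the sequence length shrinks by a factor of 2**n_blocks overall, so the input length should be divisible by that factor.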
- class grelu.model.trunks.borzoi.BorzoiTrunk(stem_channels: int, stem_kernel_size: int, init_channels: int, n_conv: int, kernel_size: int, channels: int, n_transformers: int, key_len: int, value_len: int, pos_dropout: float, attn_dropout: float, n_heads: int, n_pos_features: int, crop_len: int, flash_attn: bool, norm_type='batch', norm_kwargs=None, dtype=None, device=None)
Bases: torch.nn.Module
Trunk consisting of conv, transformer and U-net layers for the Borzoi model.
- Parameters:
stem_channels – Number of channels in the first (stem) convolutional layer
stem_kernel_size – Width of the convolutional kernel in the first (stem) convolutional layer
init_channels – Number of channels in the first convolutional block after the stem
n_conv – Number of convolutional/pooling blocks, including the stem
kernel_size – Width of the convolutional kernel
channels – Number of channels in the output
n_transformers – Number of transformer blocks
key_len – Length of the key vectors in each attention head
value_len – Length of the value vectors in each attention head
pos_dropout – Dropout rate for positional embeddings
attn_dropout – Dropout rate for attention
n_heads – Number of attention heads
n_pos_features – Number of positional features
crop_len – Number of positions to crop from each end of the output
flash_attn – If True, uses Flash Attention with Rotary Position Embeddings; in that case key_len, value_len, pos_dropout and n_pos_features are ignored.
norm_type – Type of normalization to apply: ‘batch’, ‘syncbatch’, ‘layer’, ‘instance’ or None
norm_kwargs – Additional arguments to be passed to the normalization layer
dtype – Data type for the layers.
device – Device for the layers.
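The following hypothetical sketch shows how the trunk's length-related parameters might combine, and which attention arguments the flash_attn description says are ignored. It is not grelu's code. The helpers trunk_output_length and effective_attention_kwargs are invented for illustration, and the length arithmetic rests on labeled assumptions: each of the n_conv blocks pools by 2, the transformer blocks keep the length unchanged, the U-net layers upsample by a hypothetical n_unet steps of factor 2, and crop_len positions are removed from each end. Only the list of ignored arguments comes directly from the parameter description above.

```python
from typing import Any, Dict


def trunk_output_length(
    seq_len: int,
    n_conv: int,
    crop_len: int,
    n_unet: int = 2,  # hypothetical: number of U-net upsampling steps
    pool_size: int = 2,
    upsample_factor: int = 2,
) -> int:
    """Hypothetical length bookkeeping for a BorzoiTrunk-like stack.

    Assumptions (not taken from grelu): every conv block pools by
    `pool_size`, transformer blocks keep the length unchanged, each of the
    `n_unet` U-net layers upsamples by `upsample_factor`, and `crop_len`
    positions are removed from each end.
    """
    downsample = pool_size**n_conv
    if seq_len % downsample != 0:
        raise ValueError(f"seq_len must be divisible by {downsample}")
    length = seq_len // downsample  # after the conv tower
    # Transformer blocks: length unchanged.
    length *= upsample_factor**n_unet  # after the U-net layers
    length -= 2 * crop_len  # crop both ends
    if length <= 0:
        raise ValueError("crop_len removes the entire sequence")
    return length


# Per the flash_attn description, these are ignored when flash_attn=True.
IGNORED_WITH_FLASH_ATTN = ("key_len", "value_len", "pos_dropout", "n_pos_features")


def effective_attention_kwargs(flash_attn: bool, **kwargs: Any) -> Dict[str, Any]:
    """Return only the attention arguments that still have an effect."""
    if flash_attn:
        return {k: v for k, v in kwargs.items() if k not in IGNORED_WITH_FLASH_ATTN}
    return dict(kwargs)


print(trunk_output_length(4096, n_conv=4, crop_len=16))
print(effective_attention_kwargs(True, key_len=64, n_heads=8, attn_dropout=0.05))
```

Under these assumptions, a 4,096-position input with n_conv=4 and crop_len=16 is pooled to 256 positions, upsampled to 1,024 and cropped to 992. With flash_attn=True, only n_heads and attn_dropout survive the filtering in the second call.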