Skip to content

Layers

beignet.nn.AlphaFold3

Bases: Module

Main Inference Loop for AlphaFold 3.

This module implements Algorithm 1 exactly, which is the main inference pipeline for AlphaFold 3. It processes input features through multiple stages including feature embedding, MSA processing, template embedding, Pairformer stacks, diffusion sampling, and confidence prediction.

Parameters:

Name Type Description Default
n_cycle int

Number of recycling cycles

4
c_s int

Single representation dimension

384
c_z int

Pair representation dimension

128
c_m int

MSA representation dimension

64
c_template int

Template feature dimension

64
n_blocks_pairformer int

Number of blocks in PairformerStack

48
n_head int

Number of attention heads

16

Examples:

>>> import torch
>>> from beignet.nn import AlphaFold3
>>> batch_size, n_tokens = 2, 64
>>> module = AlphaFold3(n_cycle=2)  # Smaller for example
>>> f_star = {
...     'asym_id': torch.randint(0, 5, (batch_size, n_tokens)),
...     'residue_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
...     'entity_id': torch.randint(0, 3, (batch_size, n_tokens)),
...     'token_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
...     'sym_id': torch.randint(0, 10, (batch_size, n_tokens)),
...     'token_bonds': torch.randn(batch_size, n_tokens, n_tokens, 32)
... }
>>> outputs = module(f_star)
>>> outputs['x_pred'].shape
torch.Size([2, 64, 3])
References

.. [1] AlphaFold 3 Algorithm 1: Main Inference Loop

Source code in src/beignet/nn/alphafold3/_alphafold3.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
class AlphaFold3(nn.Module):
    r"""
    Main Inference Loop for AlphaFold 3.

    This module implements Algorithm 1 exactly, which is the main inference
    pipeline for AlphaFold 3. It processes input features through multiple
    stages including feature embedding, MSA processing, template embedding,
    Pairformer stacks, diffusion sampling, and confidence prediction.

    Parameters
    ----------
    n_cycle : int, default=4
        Number of recycling cycles
    c_s : int, default=384
        Single representation dimension
    c_z : int, default=128
        Pair representation dimension
    c_m : int, default=64
        MSA representation dimension
    c_template : int, default=64
        Template feature dimension
    n_blocks_pairformer : int, default=48
        Number of blocks in PairformerStack
    n_head : int, default=16
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import AlphaFold3
    >>> batch_size, n_tokens = 2, 64
    >>> module = AlphaFold3(n_cycle=2)  # Smaller for example
    >>> f_star = {
    ...     'asym_id': torch.randint(0, 5, (batch_size, n_tokens)),
    ...     'residue_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
    ...     'entity_id': torch.randint(0, 3, (batch_size, n_tokens)),
    ...     'token_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
    ...     'sym_id': torch.randint(0, 10, (batch_size, n_tokens)),
    ...     'token_bonds': torch.randn(batch_size, n_tokens, n_tokens, 32)
    ... }
    >>> outputs = module(f_star)
    >>> outputs['x_pred'].shape
    torch.Size([2, 64, 3])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 1: Main Inference Loop
    """

    def __init__(
        self,
        n_cycle: int = 4,
        c_s: int = 384,
        c_z: int = 128,
        c_m: int = 64,
        c_template: int = 64,
        n_blocks_pairformer: int = 48,
        n_head: int = 16,
    ):
        super().__init__()

        self.n_cycle = n_cycle
        self.c_s = c_s
        self.c_z = c_z

        # Step 1: Input Feature Embedder
        self.input_feature_embedder = _InputFeatureEmbedder(
            c_atom=128,
            c_atompair=16,
            c_token=c_s,
        )

        # Step 2-3: Linear projections for initial representations
        self.single_linear = nn.Linear(c_s, c_s, bias=False)  # s_i^init
        self.pair_linear_i = nn.Linear(c_s, c_z, bias=False)  # z_ij^init from s_i
        self.pair_linear_j = nn.Linear(c_s, c_z, bias=False)  # z_ij^init from s_j

        # Step 4: Relative Position Encoding
        self.relative_position_encoding = RelativePositionEncoding(
            c_z=c_z,
        )

        # Step 5: Token bonds projection
        self.token_bonds_linear = nn.Linear(
            32, c_z, bias=False
        )  # Assuming 32 bond features

        # Step 8: Layer norm for pair initialization
        self.pair_layer_norm = nn.LayerNorm(c_z)

        # Step 9: Template Embedder
        self.template_embedder = TemplateEmbedder(
            c_z=c_z,
            c_template=c_template,
        )

        # Step 10: MSA Module
        self.msa_module = MSA(
            c_m=c_m,
            c_z=c_z,
            c_s=c_s,
        )

        # Step 11: Single representation update
        self.single_update_linear = nn.Linear(c_s, c_s, bias=False)
        self.single_layer_norm = nn.LayerNorm(c_s)

        # Step 12: Pairformer Stack
        self.pairformer_stack = PairformerStack(
            n_block=n_blocks_pairformer,
            c_s=c_s,
            c_z=c_z,
            n_head_single=n_head,
            n_head_pair=n_head // 4,  # Typically fewer heads for pair attention
        )

        # Step 15: Sample Diffusion
        self.sample_diffusion = SampleDiffusion()

        # Step 16: Confidence Head
        self.confidence_head = _Confidence(
            c_s=c_s,
            c_z=c_z,
        )

        # Step 17: Distogram Head
        self.distogram_head = _Distogram(
            c_z=c_z,
        )

    def forward(self, f_star: dict[str, Tensor]) -> dict[str, Tensor]:
        r"""
        Forward pass implementing Algorithm 1 exactly.

        Parameters
        ----------
        f_star : dict
            Dictionary containing input features with keys:
            - 'asym_id': asymmetric unit IDs (batch, n_tokens)
            - 'residue_index': residue indices (batch, n_tokens)
            - 'entity_id': entity IDs (batch, n_tokens)
            - 'token_index': token indices (batch, n_tokens)
            - 'sym_id': symmetry IDs (batch, n_tokens)
            - 'token_bonds': token bond features (batch, n_tokens, n_tokens, bond_dim)
            - Optional: 'template_features', 'msa_features', etc.

        Returns
        -------
        outputs : dict
            Dictionary containing:
            - 'x_pred': predicted coordinates (batch, n_tokens, 3)
            - 'p_plddt': pLDDT confidence (batch, n_tokens)
            - 'p_pae': PAE confidence (batch, n_tokens, n_tokens)
            - 'p_pde': PDE confidence (batch, n_tokens, n_tokens)
            - 'p_resolved': resolved confidence (batch, n_tokens)
            - 'p_distogram': distance distributions (batch, n_tokens, n_tokens, n_bins)
        """
        # Step 1: Input Feature Embedder
        embeddings = self.input_feature_embedder(f_star)
        s_inputs = embeddings["single"]  # (batch, n_tokens, c_s)

        # Step 2: Initialize single representation
        s_i_init = self.single_linear(s_inputs)  # (batch, n_tokens, c_s)

        # Step 3: Initialize pair representation
        # z_ij^init = LinearNoBias(s_i^inputs) + LinearNoBias(s_j^inputs)
        pair_i = self.pair_linear_i(s_inputs).unsqueeze(-2)  # (batch, n_tokens, 1, c_z)
        pair_j = self.pair_linear_j(s_inputs).unsqueeze(-3)  # (batch, 1, n_tokens, c_z)
        z_ij_init = pair_i + pair_j  # (batch, n_tokens, n_tokens, c_z)

        # Step 4: Add relative position encoding
        z_ij_init = z_ij_init + self.relative_position_encoding(f_star)

        # Step 5: Add token bonds (if available)
        if "token_bonds" in f_star:
            token_bonds = f_star["token_bonds"]  # (batch, n_tokens, n_tokens, bond_dim)
            z_ij_init = z_ij_init + self.token_bonds_linear(token_bonds)

        # Step 6: Initialize accumulators
        z_ij = torch.zeros_like(z_ij_init)
        s_i = torch.zeros_like(s_i_init)

        # Step 7-14: Main recycling loop
        for _ in range(self.n_cycle):
            # Step 8: Update pair representation
            z_ij = z_ij_init + self.pair_layer_norm(z_ij)

            # Step 9: Template Embedder
            z_ij = z_ij + self.template_embedder(f_star, z_ij)

            # Step 10: MSA Module
            if "msa_features" in f_star:
                z_ij = z_ij + self.msa_module(
                    f_star["msa_features"],
                    f_star.get("has_deletion"),
                    f_star.get("deletion_value"),
                    s_inputs,
                    z_ij,
                )

            # Step 11: Update single representation
            s_i = s_i_init + self.single_update_linear(self.single_layer_norm(s_i))

            # Step 12: Pairformer Stack
            s_i, z_ij = self.pairformer_stack(s_i, z_ij)

            # Step 13: Copy for next iteration (handled by loop)

        # Step 15: Sample Diffusion
        x_pred = self.sample_diffusion(
            f_star, s_inputs, s_i, z_ij, noise_schedule=torch.linspace(1.0, 0.01, 20)
        )

        # Step 16: Confidence Head
        confidence_outputs = self.confidence_head(
            {"token_single_initial_repr": s_inputs}, s_i, z_ij, x_pred
        )

        # Step 17: Distogram Head
        p_distogram = self.distogram_head(z_ij)

        # Step 18: Return all outputs
        return {
            "x_pred": x_pred,
            "p_plddt": confidence_outputs["p_plddt"],
            "p_pae": confidence_outputs["p_pae"],
            "p_pde": confidence_outputs["p_pde"],
            "p_resolved": confidence_outputs["p_resolved"],
            "p_distogram": p_distogram,
        }
forward
forward(f_star)

Forward pass implementing Algorithm 1 exactly.

Parameters:

Name Type Description Default
f_star dict

Dictionary containing input features with keys: - 'asym_id': asymmetric unit IDs (batch, n_tokens) - 'residue_index': residue indices (batch, n_tokens) - 'entity_id': entity IDs (batch, n_tokens) - 'token_index': token indices (batch, n_tokens) - 'sym_id': symmetry IDs (batch, n_tokens) - 'token_bonds': token bond features (batch, n_tokens, n_tokens, bond_dim) - Optional: 'template_features', 'msa_features', etc.

required

Returns:

Name Type Description
outputs dict

Dictionary containing: - 'x_pred': predicted coordinates (batch, n_tokens, 3) - 'p_plddt': pLDDT confidence (batch, n_tokens) - 'p_pae': PAE confidence (batch, n_tokens, n_tokens) - 'p_pde': PDE confidence (batch, n_tokens, n_tokens) - 'p_resolved': resolved confidence (batch, n_tokens) - 'p_distogram': distance distributions (batch, n_tokens, n_tokens, n_bins)

Source code in src/beignet/nn/alphafold3/_alphafold3.py
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
def forward(self, f_star: dict[str, Tensor]) -> dict[str, Tensor]:
    r"""
    Forward pass implementing Algorithm 1 exactly.

    Parameters
    ----------
    f_star : dict
        Dictionary containing input features with keys:
        - 'asym_id': asymmetric unit IDs (batch, n_tokens)
        - 'residue_index': residue indices (batch, n_tokens)
        - 'entity_id': entity IDs (batch, n_tokens)
        - 'token_index': token indices (batch, n_tokens)
        - 'sym_id': symmetry IDs (batch, n_tokens)
        - 'token_bonds': token bond features (batch, n_tokens, n_tokens, bond_dim)
        - Optional: 'template_features', 'msa_features', etc.

    Returns
    -------
    outputs : dict
        Dictionary containing:
        - 'x_pred': predicted coordinates (batch, n_tokens, 3)
        - 'p_plddt': pLDDT confidence (batch, n_tokens)
        - 'p_pae': PAE confidence (batch, n_tokens, n_tokens)
        - 'p_pde': PDE confidence (batch, n_tokens, n_tokens)
        - 'p_resolved': resolved confidence (batch, n_tokens)
        - 'p_distogram': distance distributions (batch, n_tokens, n_tokens, n_bins)
    """
    # Step 1: Input Feature Embedder
    embeddings = self.input_feature_embedder(f_star)
    s_inputs = embeddings["single"]  # (batch, n_tokens, c_s)

    # Step 2: Initialize single representation
    s_i_init = self.single_linear(s_inputs)  # (batch, n_tokens, c_s)

    # Step 3: Initialize pair representation
    # z_ij^init = LinearNoBias(s_i^inputs) + LinearNoBias(s_j^inputs)
    pair_i = self.pair_linear_i(s_inputs).unsqueeze(-2)  # (batch, n_tokens, 1, c_z)
    pair_j = self.pair_linear_j(s_inputs).unsqueeze(-3)  # (batch, 1, n_tokens, c_z)
    z_ij_init = pair_i + pair_j  # (batch, n_tokens, n_tokens, c_z)

    # Step 4: Add relative position encoding
    z_ij_init = z_ij_init + self.relative_position_encoding(f_star)

    # Step 5: Add token bonds (if available)
    if "token_bonds" in f_star:
        token_bonds = f_star["token_bonds"]  # (batch, n_tokens, n_tokens, bond_dim)
        z_ij_init = z_ij_init + self.token_bonds_linear(token_bonds)

    # Step 6: Initialize accumulators
    z_ij = torch.zeros_like(z_ij_init)
    s_i = torch.zeros_like(s_i_init)

    # Step 7-14: Main recycling loop
    for _ in range(self.n_cycle):
        # Step 8: Update pair representation
        z_ij = z_ij_init + self.pair_layer_norm(z_ij)

        # Step 9: Template Embedder
        z_ij = z_ij + self.template_embedder(f_star, z_ij)

        # Step 10: MSA Module
        if "msa_features" in f_star:
            z_ij = z_ij + self.msa_module(
                f_star["msa_features"],
                f_star.get("has_deletion"),
                f_star.get("deletion_value"),
                s_inputs,
                z_ij,
            )

        # Step 11: Update single representation
        s_i = s_i_init + self.single_update_linear(self.single_layer_norm(s_i))

        # Step 12: Pairformer Stack
        s_i, z_ij = self.pairformer_stack(s_i, z_ij)

        # Step 13: Copy for next iteration (handled by loop)

    # Step 15: Sample Diffusion
    x_pred = self.sample_diffusion(
        f_star, s_inputs, s_i, z_ij, noise_schedule=torch.linspace(1.0, 0.01, 20)
    )

    # Step 16: Confidence Head
    confidence_outputs = self.confidence_head(
        {"token_single_initial_repr": s_inputs}, s_i, z_ij, x_pred
    )

    # Step 17: Distogram Head
    p_distogram = self.distogram_head(z_ij)

    # Step 18: Return all outputs
    return {
        "x_pred": x_pred,
        "p_plddt": confidence_outputs["p_plddt"],
        "p_pae": confidence_outputs["p_pae"],
        "p_pde": confidence_outputs["p_pde"],
        "p_resolved": confidence_outputs["p_resolved"],
        "p_distogram": p_distogram,
    }

beignet.nn.alphafold3.AtomAttentionDecoder

Bases: Module

Atom Attention Decoder for AlphaFold 3.

This module broadcasts per-token activations to per-atom activations, applies cross attention transformer, and maps to position updates. Implements Algorithm 6 exactly.

Parameters:

Name Type Description Default
c_token int

Channel dimension for token representations

768
c_atom int

Channel dimension for atom representations

128
n_block int

Number of transformer blocks

3
n_head int

Number of attention heads

4

Examples:

>>> import torch
>>> from beignet.nn.alphafold3 import _AtomAttentionDecoder
>>> batch_size, n_tokens, n_atoms = 2, 32, 1000
>>> module = _AtomAttentionDecoder()
>>>
>>> a = torch.randn(batch_size, n_tokens, 768)  # Token representations
>>> q_skip = torch.randn(batch_size, n_atoms, 768)  # Query skip
>>> c_skip = torch.randn(batch_size, n_atoms, 128)  # Context skip
>>> p_skip = torch.randn(batch_size, n_atoms, n_atoms, 16)  # Pair skip
>>>
>>> r_update = module(a, q_skip, c_skip, p_skip)
>>> r_update.shape
torch.Size([2, 1000, 3])
References

.. [1] AlphaFold 3 Algorithm 6: Atom attention decoder

Source code in src/beignet/nn/alphafold3/_atom_attention_decoder.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
class AtomAttentionDecoder(nn.Module):
    r"""
    Atom Attention Decoder for AlphaFold 3.

    This module broadcasts per-token activations to per-atom activations,
    applies cross attention transformer, and maps to position updates.
    Implements Algorithm 6 exactly.

    Parameters
    ----------
    c_token : int, default=768
        Channel dimension for token representations
    c_atom : int, default=128
        Channel dimension for atom representations
    n_block : int, default=3
        Number of transformer blocks
    n_head : int, default=4
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn.alphafold3 import _AtomAttentionDecoder
    >>> batch_size, n_tokens, n_atoms = 2, 32, 1000
    >>> module = _AtomAttentionDecoder()
    >>>
    >>> a = torch.randn(batch_size, n_tokens, 768)  # Token representations
    >>> q_skip = torch.randn(batch_size, n_atoms, 768)  # Query skip
    >>> c_skip = torch.randn(batch_size, n_atoms, 128)  # Context skip
    >>> p_skip = torch.randn(batch_size, n_atoms, n_atoms, 16)  # Pair skip
    >>>
    >>> r_update = module(a, q_skip, c_skip, p_skip)
    >>> r_update.shape
    torch.Size([2, 1000, 3])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 6: Atom attention decoder
    """

    def __init__(
        self, c_token: int = 768, c_atom: int = 128, n_block: int = 3, n_head: int = 4
    ):
        super().__init__()

        self.c_token = c_token
        self.c_atom = c_atom
        self.n_block = n_block
        self.n_head = n_head

        # Step 1: Broadcast per-token activations to per-atom activations
        self.token_to_atom_proj = nn.Linear(c_token, c_token, bias=False)

        # Step 2: Cross attention transformer
        self.atom_transformer = AtomTransformer(
            n_block=n_block,
            n_head=n_head,
            c_q=c_token,  # Query dimension
            c_kv=c_atom,  # Key-value dimension
            c_pair=None,  # Will be inferred from p_skip
        )

        # Step 3: Map to position updates
        self.position_proj = nn.Linear(c_token, 3, bias=False)
        self.layer_norm = nn.LayerNorm(c_token)

    def forward(
        self, a: Tensor, q_skip: Tensor, c_skip: Tensor, p_skip: Tensor
    ) -> Tensor:
        r"""
        Forward pass of Atom Attention Decoder.

        Implements Algorithm 6 exactly:
        1. q_l = LinearNoBias(a_tok_idx(l)) + q_l^skip
        2. {q_l} = AtomTransformer({q_l}, {c_l^skip}, {p_lm^skip}, N_block=3, N_head=4)
        3. r_l^update = LinearNoBias(LayerNorm(q_l))

        Parameters
        ----------
        a : Tensor, shape=(batch_size, n_tokens, c_token)
            Token-level representations
        q_skip : Tensor, shape=(batch_size, n_atoms, c_token)
            Query skip connection
        c_skip : Tensor, shape=(batch_size, n_atoms, c_atom)
            Context skip connection
        p_skip : Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)
            Pair skip connection

        Returns
        -------
        r_update : Tensor, shape=(batch_size, n_atoms, 3)
            Position updates for atoms
        """
        batch_size, n_tokens, c_token = a.shape
        n_atoms = q_skip.shape[1]

        # Step 1: Broadcast per-token activations to per-atom activations and add skip connection
        # q_l = LinearNoBias(a_tok_idx(l)) + q_l^skip

        # Create token indices for each atom (simple broadcasting approach)
        # For simplicity, we'll map atoms to tokens cyclically
        token_indices = torch.arange(n_atoms, device=a.device) % n_tokens

        # Get corresponding token activations for each atom
        a_tok_idx = a[:, token_indices]  # (batch_size, n_atoms, c_token)

        # Apply linear projection and add skip connection
        q = self.token_to_atom_proj(a_tok_idx) + q_skip

        # Step 2: Cross attention transformer
        # {q_l} = AtomTransformer({q_l}, {c_l^skip}, {p_lm^skip}, N_block=3, N_head=4)
        q = self.atom_transformer(q, c_skip, p_skip)

        # Step 3: Map to positions update
        # r_l^update = LinearNoBias(LayerNorm(q_l))
        r_update = self.position_proj(self.layer_norm(q))

        return r_update
forward
forward(a, q_skip, c_skip, p_skip)

Forward pass of Atom Attention Decoder.

Implements Algorithm 6 exactly: 1. q_l = LinearNoBias(a_tok_idx(l)) + q_l^skip 2. {q_l} = AtomTransformer({q_l}, {c_l^skip}, {p_lm^skip}, N_block=3, N_head=4) 3. r_l^update = LinearNoBias(LayerNorm(q_l))

Parameters:

Name Type Description Default
a Tensor, shape=(batch_size, n_tokens, c_token)

Token-level representations

required
q_skip Tensor, shape=(batch_size, n_atoms, c_token)

Query skip connection

required
c_skip Tensor, shape=(batch_size, n_atoms, c_atom)

Context skip connection

required
p_skip Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)

Pair skip connection

required

Returns:

Name Type Description
r_update Tensor, shape=(batch_size, n_atoms, 3)

Position updates for atoms

Source code in src/beignet/nn/alphafold3/_atom_attention_decoder.py
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
def forward(
    self, a: Tensor, q_skip: Tensor, c_skip: Tensor, p_skip: Tensor
) -> Tensor:
    r"""
    Forward pass of Atom Attention Decoder.

    Implements Algorithm 6 exactly:
    1. q_l = LinearNoBias(a_tok_idx(l)) + q_l^skip
    2. {q_l} = AtomTransformer({q_l}, {c_l^skip}, {p_lm^skip}, N_block=3, N_head=4)
    3. r_l^update = LinearNoBias(LayerNorm(q_l))

    Parameters
    ----------
    a : Tensor, shape=(batch_size, n_tokens, c_token)
        Token-level representations
    q_skip : Tensor, shape=(batch_size, n_atoms, c_token)
        Query skip connection
    c_skip : Tensor, shape=(batch_size, n_atoms, c_atom)
        Context skip connection
    p_skip : Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)
        Pair skip connection

    Returns
    -------
    r_update : Tensor, shape=(batch_size, n_atoms, 3)
        Position updates for atoms
    """
    batch_size, n_tokens, c_token = a.shape
    n_atoms = q_skip.shape[1]

    # Step 1: Broadcast per-token activations to per-atom activations and add skip connection
    # q_l = LinearNoBias(a_tok_idx(l)) + q_l^skip

    # Create token indices for each atom (simple broadcasting approach)
    # For simplicity, we'll map atoms to tokens cyclically
    token_indices = torch.arange(n_atoms, device=a.device) % n_tokens

    # Get corresponding token activations for each atom
    a_tok_idx = a[:, token_indices]  # (batch_size, n_atoms, c_token)

    # Apply linear projection and add skip connection
    q = self.token_to_atom_proj(a_tok_idx) + q_skip

    # Step 2: Cross attention transformer
    # {q_l} = AtomTransformer({q_l}, {c_l^skip}, {p_lm^skip}, N_block=3, N_head=4)
    q = self.atom_transformer(q, c_skip, p_skip)

    # Step 3: Map to positions update
    # r_l^update = LinearNoBias(LayerNorm(q_l))
    r_update = self.position_proj(self.layer_norm(q))

    return r_update

beignet.nn.alphafold3.AtomAttentionEncoder

Bases: Module

Atom Attention Encoder for AlphaFold 3.

This module implements Algorithm 5 exactly, creating atom single conditioning, embedding offsets and distances, running cross attention transformer, and aggregating per-atom to per-token representations.

Parameters:

Name Type Description Default
c_atom int

Channel dimension for atom representations

128
c_atompair int

Channel dimension for atom pair representations

16
c_token int

Channel dimension for token representations

384
n_block int

Number of transformer blocks

3
n_head int

Number of attention heads

4

Examples:

>>> import torch
>>> from beignet.nn import AtomAttentionEncoder
>>> batch_size, n_atoms = 2, 1000
>>> module = AtomAttentionEncoder()
>>>
>>> # Feature dictionary with all required atom features
>>> f_star = {
...     'ref_pos': torch.randn(batch_size, n_atoms, 3),
...     'ref_mask': torch.ones(batch_size, n_atoms),
...     'ref_element': torch.randint(0, 118, (batch_size, n_atoms)),
...     'ref_atom_name_chars': torch.randint(0, 26, (batch_size, n_atoms, 4)),
...     'ref_charge': torch.randn(batch_size, n_atoms),
...     'restype': torch.randint(0, 21, (batch_size, n_atoms)),
...     'profile': torch.randn(batch_size, n_atoms, 20),
...     'deletion_mean': torch.randn(batch_size, n_atoms),
...     'ref_space_uid': torch.randint(0, 1000, (batch_size, n_atoms)),
... }
>>> r_t = torch.randn(batch_size, n_atoms, 3)
>>> s_trunk = torch.randn(batch_size, 32, 384)
>>> z_ij = torch.randn(batch_size, n_atoms, n_atoms, 16)
>>>
>>> a, q_skip, c_skip, p_skip = module(f_star, r_t, s_trunk, z_ij)
>>> a.shape
torch.Size([2, 32, 384])
References

.. [1] AlphaFold 3 Algorithm 5: Atom attention encoder

Source code in src/beignet/nn/alphafold3/_atom_attention_encoder.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
class AtomAttentionEncoder(nn.Module):
    r"""
    Atom Attention Encoder for AlphaFold 3.

    This module implements Algorithm 5 exactly, creating atom single conditioning,
    embedding offsets and distances, running cross attention transformer, and
    aggregating per-atom to per-token representations.

    Parameters
    ----------
    c_atom : int, default=128
        Channel dimension for atom representations
    c_atompair : int, default=16
        Channel dimension for atom pair representations
    c_token : int, default=384
        Channel dimension for token representations
    n_block : int, default=3
        Number of transformer blocks
    n_head : int, default=4
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import AtomAttentionEncoder
    >>> batch_size, n_atoms = 2, 1000
    >>> module = AtomAttentionEncoder()
    >>>
    >>> # Feature dictionary with all required atom features
    >>> f_star = {
    ...     'ref_pos': torch.randn(batch_size, n_atoms, 3),
    ...     'ref_mask': torch.ones(batch_size, n_atoms),
    ...     'ref_element': torch.randint(0, 118, (batch_size, n_atoms)),
    ...     'ref_atom_name_chars': torch.randint(0, 26, (batch_size, n_atoms, 4)),
    ...     'ref_charge': torch.randn(batch_size, n_atoms),
    ...     'restype': torch.randint(0, 21, (batch_size, n_atoms)),
    ...     'profile': torch.randn(batch_size, n_atoms, 20),
    ...     'deletion_mean': torch.randn(batch_size, n_atoms),
    ...     'ref_space_uid': torch.randint(0, 1000, (batch_size, n_atoms)),
    ... }
    >>> r_t = torch.randn(batch_size, n_atoms, 3)
    >>> s_trunk = torch.randn(batch_size, 32, 384)
    >>> z_ij = torch.randn(batch_size, n_atoms, n_atoms, 16)
    >>>
    >>> a, q_skip, c_skip, p_skip = module(f_star, r_t, s_trunk, z_ij)
    >>> a.shape
    torch.Size([2, 32, 384])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 5: Atom attention encoder
    """

    def __init__(
        self,
        c_atom: int = 128,
        c_atompair: int = 16,
        c_token: int = 384,
        n_block: int = 3,
        n_head: int = 4,
    ):
        super().__init__()

        self.c_atom = c_atom
        self.c_atompair = c_atompair
        self.c_token = c_token
        self.n_block = n_block
        self.n_head = n_head

        # Step 1: Create atom single conditioning by embedding per-atom meta data
        # We'll concatenate all features and project them
        self.atom_feature_proj = nn.Linear(
            3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000,  # Approximate feature size
            c_atom,
            bias=False,
        )

        # Step 4: Embed pairwise inverse squared distances
        self.dist_proj_1 = nn.Linear(1, c_atompair, bias=False)

        # Step 6: Additional distance embedding
        self.dist_proj_2 = nn.Linear(1, c_atompair, bias=False)

        # Step 9-11: Trunk embedding projections (if provided)
        self.trunk_single_proj = nn.Linear(c_token, c_atom, bias=False)
        self.trunk_single_norm = nn.LayerNorm(c_token)

        self.trunk_pair_proj = nn.Linear(
            128, c_atompair, bias=False
        )  # Assuming z_trunk has 128 dims
        self.trunk_pair_norm = nn.LayerNorm(128)

        # Step 11: Add noisy positions
        self.noisy_pos_proj = nn.Linear(3, c_atom, bias=False)

        # Step 13: Pair representation updates
        self.pair_update_proj_1 = nn.Linear(c_atom, c_atompair, bias=False)
        self.pair_update_proj_2 = nn.Linear(c_atom, c_atompair, bias=False)

        # Step 14: Small MLP on pair activations
        self.pair_mlp = nn.Sequential(
            nn.Linear(c_atompair, c_atompair, bias=False),
            nn.ReLU(),
            nn.Linear(c_atompair, c_atompair, bias=False),
            nn.ReLU(),
            nn.Linear(c_atompair, c_atompair, bias=False),
            nn.ReLU(),
            nn.Linear(c_atompair, c_atompair, bias=False),
        )

        # Step 15: Cross attention transformer
        self.atom_transformer = AtomTransformer(
            n_block=n_block,
            n_head=n_head,
            c_q=c_atom,
            c_kv=c_atom,
            c_pair=c_atompair,
        )

        # Step 16: Aggregation to per-token representation
        self.aggregation_proj = nn.Linear(c_atom, c_token, bias=False)

    def forward(
        self, f_star: dict, r_t: Tensor, s_trunk: Tensor, z_ij: Tensor
    ) -> tuple[Tensor, Tensor, Tensor, Tensor]:
        r"""
        Forward pass of Atom Attention Encoder implementing Algorithm 5.

        Parameters
        ----------
        f_star : dict
            Dictionary containing atom features with keys:
            - 'ref_pos': reference positions (batch, n_atoms, 3)
            - 'ref_mask': mask (batch, n_atoms)
            - 'ref_element': element types (batch, n_atoms)
            - 'ref_atom_name_chars': atom name characters (batch, n_atoms, 4)
            - 'ref_charge': charges (batch, n_atoms)
            - 'restype': residue types (batch, n_atoms)
            - 'profile': sequence profile (batch, n_atoms, 20)
            - 'deletion_mean': deletion statistics (batch, n_atoms)
            - 'ref_space_uid': space UIDs (batch, n_atoms)
        r_t : Tensor, shape=(batch_size, n_atoms, 3)
            Noisy atomic positions at time t
        s_trunk : Tensor, shape=(batch_size, n_tokens, c_token)
            Trunk single representations (optional, can be None)
        z_ij : Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)
            Atom pair representations

        Returns
        -------
        a : Tensor, shape=(batch_size, n_tokens, c_token)
            Token-level representations
        q_skip : Tensor, shape=(batch_size, n_atoms, c_atom)
            Skip connection for queries
        c_skip : Tensor, shape=(batch_size, n_atoms, c_atom)
            Skip connection for atom features
        p_skip : Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)
            Skip connection for pair features
        """
        batch_size, n_atoms = r_t.shape[:2]
        device = r_t.device

        # Step 1: Create atom single conditioning by embedding per-atom meta data
        # For simplicity, we'll use basic features that are commonly available
        # In practice, you'd need to handle the full feature set properly

        # Use reference positions from f_star if available, otherwise zeros
        ref_pos = f_star.get("ref_pos", torch.zeros_like(r_t))

        # Create a concatenated feature vector (simplified version)
        # In practice, you'd properly embed each feature type
        atom_features = torch.cat(
            [
                ref_pos,  # (batch, n_atoms, 3)
                torch.ones(
                    batch_size, n_atoms, 1, device=device
                ),  # placeholder for other features
            ],
            dim=-1,
        )

        # Pad or project to expected input size
        if atom_features.shape[-1] < 3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000:
            # Pad with zeros for missing features
            pad_size = (
                3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000 - atom_features.shape[-1]
            )
            atom_features = torch.cat(
                [
                    atom_features,
                    torch.zeros(batch_size, n_atoms, pad_size, device=device),
                ],
                dim=-1,
            )

        c_l = self.atom_feature_proj(
            atom_features[:, :, : 3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000]
        )

        # Steps 2-4: Embed offsets and distances
        d_lm = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(
            -3
        )  # (batch, n_atoms, n_atoms, 3)

        # Step 3: Check for same reference space (simplified)
        same_space = torch.ones(
            batch_size, n_atoms, n_atoms, device=device
        )  # Simplified

        # Step 4: Embed pairwise inverse squared distances
        d_lm_norm = torch.norm(
            d_lm, dim=-1, keepdim=True
        )  # (batch, n_atoms, n_atoms, 1)
        inv_sq_dist = 1.0 / (1.0 + d_lm_norm**2)
        p_lm = self.dist_proj_1(inv_sq_dist) * same_space.unsqueeze(-1)

        # Steps 5-6: Additional distance embeddings
        p_lm = p_lm + self.dist_proj_2(same_space.unsqueeze(-1))

        # Step 7: Initialize atom single representation
        q_l = c_l.clone()

        # Steps 8-12: Add trunk embeddings and noisy positions if provided
        if s_trunk is not None and s_trunk.shape[1] > 0:
            # Step 9: Broadcast single embedding from trunk
            n_tokens = s_trunk.shape[1]
            token_indices = torch.arange(n_atoms, device=device) % n_tokens
            s_trunk_broadcast = s_trunk[:, token_indices]  # (batch, n_atoms, c_token)
            c_l = c_l + self.trunk_single_proj(
                self.trunk_single_norm(s_trunk_broadcast)
            )

            # Step 10: Add trunk pair embedding (simplified)
            if hasattr(self, "z_trunk") and self.z_trunk is not None:
                # In practice, you'd need the actual z_trunk tensor
                pass

        # Step 11: Add noisy positions
        q_l = q_l + self.noisy_pos_proj(r_t)

        # Step 13: Add combined single conditioning to pair representation
        p_lm = (
            p_lm
            + self.pair_update_proj_1(c_l).unsqueeze(-2)
            + self.pair_update_proj_2(c_l).unsqueeze(-3)
        )

        # Step 14: Run small MLP on pair activations
        p_lm = p_lm + self.pair_mlp(p_lm)

        # Step 15: Cross attention transformer
        q_l = self.atom_transformer(q_l, c_l, p_lm)

        # Step 16: Aggregate per-atom to per-token representation
        if s_trunk is not None:
            n_tokens = s_trunk.shape[1]
            # Simple mean aggregation within token groups
            token_assignment = torch.arange(n_atoms, device=device) // (
                n_atoms // n_tokens + 1
            )
            token_assignment = torch.clamp(token_assignment, 0, n_tokens - 1)

            # Aggregate atoms to tokens
            a_i = torch.zeros(
                batch_size, n_tokens, self.c_token, device=device, dtype=q_l.dtype
            )
            q_l_projected = self.aggregation_proj(q_l)

            for i in range(n_tokens):
                mask = token_assignment == i
                if mask.any():
                    a_i[:, i] = q_l_projected[:, mask].mean(dim=1)
        else:
            # If no trunk, create a single token representation
            a_i = self.aggregation_proj(q_l).mean(dim=1, keepdim=True)

        # Step 17: Skip connections
        q_skip = q_l
        c_skip = c_l
        p_skip = p_lm

        return a_i, q_skip, c_skip, p_skip
forward
forward(f_star, r_t, s_trunk, z_ij)

Forward pass of Atom Attention Encoder implementing Algorithm 5.

Parameters:

Name Type Description Default
f_star dict

Dictionary containing atom features with keys: - 'ref_pos': reference positions (batch, n_atoms, 3) - 'ref_mask': mask (batch, n_atoms) - 'ref_element': element types (batch, n_atoms) - 'ref_atom_name_chars': atom name characters (batch, n_atoms, 4) - 'ref_charge': charges (batch, n_atoms) - 'restype': residue types (batch, n_atoms) - 'profile': sequence profile (batch, n_atoms, 20) - 'deletion_mean': deletion statistics (batch, n_atoms) - 'ref_space_uid': space UIDs (batch, n_atoms)

required
r_t Tensor, shape=(batch_size, n_atoms, 3)

Noisy atomic positions at time t

required
s_trunk Tensor, shape=(batch_size, n_tokens, c_token)

Trunk single representations (optional, can be None)

required
z_ij Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)

Atom pair representations

required

Returns:

Name Type Description
a Tensor, shape=(batch_size, n_tokens, c_token)

Token-level representations

q_skip Tensor, shape=(batch_size, n_atoms, c_atom)

Skip connection for queries

c_skip Tensor, shape=(batch_size, n_atoms, c_atom)

Skip connection for atom features

p_skip Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)

Skip connection for pair features

Source code in src/beignet/nn/alphafold3/_atom_attention_encoder.py
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
def forward(
    self, f_star: dict, r_t: Tensor, s_trunk: Tensor, z_ij: Tensor
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    r"""
    Forward pass of Atom Attention Encoder implementing Algorithm 5.

    Parameters
    ----------
    f_star : dict
        Dictionary containing atom features with keys:
        - 'ref_pos': reference positions (batch, n_atoms, 3)
        - 'ref_mask': mask (batch, n_atoms)
        - 'ref_element': element types (batch, n_atoms)
        - 'ref_atom_name_chars': atom name characters (batch, n_atoms, 4)
        - 'ref_charge': charges (batch, n_atoms)
        - 'restype': residue types (batch, n_atoms)
        - 'profile': sequence profile (batch, n_atoms, 20)
        - 'deletion_mean': deletion statistics (batch, n_atoms)
        - 'ref_space_uid': space UIDs (batch, n_atoms)
    r_t : Tensor, shape=(batch_size, n_atoms, 3)
        Noisy atomic positions at time t
    s_trunk : Tensor, shape=(batch_size, n_tokens, c_token)
        Trunk single representations (optional, can be None)
    z_ij : Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)
        Atom pair representations

    Returns
    -------
    a : Tensor, shape=(batch_size, n_tokens, c_token)
        Token-level representations
    q_skip : Tensor, shape=(batch_size, n_atoms, c_atom)
        Skip connection for queries
    c_skip : Tensor, shape=(batch_size, n_atoms, c_atom)
        Skip connection for atom features
    p_skip : Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair)
        Skip connection for pair features
    """
    batch_size, n_atoms = r_t.shape[:2]
    device = r_t.device

    # Step 1: Create atom single conditioning by embedding per-atom meta data
    # For simplicity, we'll use basic features that are commonly available
    # In practice, you'd need to handle the full feature set properly

    # Use reference positions from f_star if available, otherwise zeros
    ref_pos = f_star.get("ref_pos", torch.zeros_like(r_t))

    # Create a concatenated feature vector (simplified version)
    # In practice, you'd properly embed each feature type
    atom_features = torch.cat(
        [
            ref_pos,  # (batch, n_atoms, 3)
            torch.ones(
                batch_size, n_atoms, 1, device=device
            ),  # placeholder for other features
        ],
        dim=-1,
    )

    # Pad or project to expected input size
    if atom_features.shape[-1] < 3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000:
        # Pad with zeros for missing features
        pad_size = (
            3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000 - atom_features.shape[-1]
        )
        atom_features = torch.cat(
            [
                atom_features,
                torch.zeros(batch_size, n_atoms, pad_size, device=device),
            ],
            dim=-1,
        )

    c_l = self.atom_feature_proj(
        atom_features[:, :, : 3 + 1 + 118 + 4 * 26 + 1 + 21 + 20 + 1 + 1000]
    )

    # Steps 2-4: Embed offsets and distances
    d_lm = ref_pos.unsqueeze(-2) - ref_pos.unsqueeze(
        -3
    )  # (batch, n_atoms, n_atoms, 3)

    # Step 3: Check for same reference space (simplified)
    same_space = torch.ones(
        batch_size, n_atoms, n_atoms, device=device
    )  # Simplified

    # Step 4: Embed pairwise inverse squared distances
    d_lm_norm = torch.norm(
        d_lm, dim=-1, keepdim=True
    )  # (batch, n_atoms, n_atoms, 1)
    inv_sq_dist = 1.0 / (1.0 + d_lm_norm**2)
    p_lm = self.dist_proj_1(inv_sq_dist) * same_space.unsqueeze(-1)

    # Steps 5-6: Additional distance embeddings
    p_lm = p_lm + self.dist_proj_2(same_space.unsqueeze(-1))

    # Step 7: Initialize atom single representation
    q_l = c_l.clone()

    # Steps 8-12: Add trunk embeddings and noisy positions if provided
    if s_trunk is not None and s_trunk.shape[1] > 0:
        # Step 9: Broadcast single embedding from trunk
        n_tokens = s_trunk.shape[1]
        token_indices = torch.arange(n_atoms, device=device) % n_tokens
        s_trunk_broadcast = s_trunk[:, token_indices]  # (batch, n_atoms, c_token)
        c_l = c_l + self.trunk_single_proj(
            self.trunk_single_norm(s_trunk_broadcast)
        )

        # Step 10: Add trunk pair embedding (simplified)
        if hasattr(self, "z_trunk") and self.z_trunk is not None:
            # In practice, you'd need the actual z_trunk tensor
            pass

    # Step 11: Add noisy positions
    q_l = q_l + self.noisy_pos_proj(r_t)

    # Step 13: Add combined single conditioning to pair representation
    p_lm = (
        p_lm
        + self.pair_update_proj_1(c_l).unsqueeze(-2)
        + self.pair_update_proj_2(c_l).unsqueeze(-3)
    )

    # Step 14: Run small MLP on pair activations
    p_lm = p_lm + self.pair_mlp(p_lm)

    # Step 15: Cross attention transformer
    q_l = self.atom_transformer(q_l, c_l, p_lm)

    # Step 16: Aggregate per-atom to per-token representation
    if s_trunk is not None:
        n_tokens = s_trunk.shape[1]
        # Simple mean aggregation within token groups
        token_assignment = torch.arange(n_atoms, device=device) // (
            n_atoms // n_tokens + 1
        )
        token_assignment = torch.clamp(token_assignment, 0, n_tokens - 1)

        # Aggregate atoms to tokens
        a_i = torch.zeros(
            batch_size, n_tokens, self.c_token, device=device, dtype=q_l.dtype
        )
        q_l_projected = self.aggregation_proj(q_l)

        for i in range(n_tokens):
            mask = token_assignment == i
            if mask.any():
                a_i[:, i] = q_l_projected[:, mask].mean(dim=1)
    else:
        # If no trunk, create a single token representation
        a_i = self.aggregation_proj(q_l).mean(dim=1, keepdim=True)

    # Step 17: Skip connections
    q_skip = q_l
    c_skip = c_l
    p_skip = p_lm

    return a_i, q_skip, c_skip, p_skip

beignet.nn.alphafold3.AtomTransformer

Bases: Module

Atom Transformer for AlphaFold 3.

This module implements sequence-local atom attention using rectangular blocks along the diagonal. It applies the DiffusionTransformer with sequence-local attention masking based on query and key positions.

Parameters:

Name Type Description Default
n_block int

Number of transformer blocks

3
n_head int

Number of attention heads

4
n_queries int

Number of queries per block

32
n_keys int

Number of keys per block

128
subset_centres list

Centers for subset selection

[15.5, 47.5, 79.5, ...]
c_q int

Query dimension (inferred from input if None)

None
c_kv int

Key-value dimension (inferred from input if None)

None
c_pair int

Pair dimension (inferred from input if None)

None

Examples:

>>> import torch
>>> from beignet.nn import AtomTransformer
>>> batch_size, n_atoms = 2, 1000
>>> module = AtomTransformer()
>>> q = torch.randn(batch_size, n_atoms, 128)
>>> c = torch.randn(batch_size, n_atoms, 64)
>>> p = torch.randn(batch_size, n_atoms, n_atoms, 16)
>>> output = module(q, c, p)
>>> output.shape
torch.Size([2, 1000, 128])
References

.. [1] AlphaFold 3 Algorithm 7: Atom Transformer

Source code in src/beignet/nn/alphafold3/_atom_transformer.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
class AtomTransformer(nn.Module):
    r"""
    Atom Transformer for AlphaFold 3.

    This module implements sequence-local atom attention using rectangular blocks
    along the diagonal. It applies the DiffusionTransformer with sequence-local
    attention masking based on query and key positions.

    Parameters
    ----------
    n_block : int, default=3
        Number of transformer blocks
    n_head : int, default=4
        Number of attention heads
    n_queries : int, default=32
        Number of queries per block
    n_keys : int, default=128
        Number of keys per block
    subset_centres : list, default=[15.5, 47.5, 79.5, ...]
        Centers for subset selection
    c_q : int, default=None
        Query dimension (inferred from input if None)
    c_kv : int, default=None
        Key-value dimension (inferred from input if None)
    c_pair : int, default=None
        Pair dimension (inferred from input if None)

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import AtomTransformer
    >>> batch_size, n_atoms = 2, 1000
    >>> module = AtomTransformer()
    >>> q = torch.randn(batch_size, n_atoms, 128)
    >>> c = torch.randn(batch_size, n_atoms, 64)
    >>> p = torch.randn(batch_size, n_atoms, n_atoms, 16)
    >>> output = module(q, c, p)
    >>> output.shape
    torch.Size([2, 1000, 128])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 7: Atom Transformer
    """

    def __init__(
        self,
        n_block: int = 3,
        n_head: int = 4,
        n_queries: int = 32,
        n_keys: int = 128,
        subset_centres: list = None,
        c_q: int = None,
        c_kv: int = None,
        c_pair: int = None,
    ):
        super().__init__()

        self.n_block = n_block
        self.n_head = n_head
        self.n_queries = n_queries
        self.n_keys = n_keys

        if subset_centres is None:
            # Default subset centers as specified in the algorithm
            self.subset_centres = [15.5, 47.5, 79.5]  # Can be extended as needed
        else:
            self.subset_centres = subset_centres

        # Store dimensions (will be inferred from input if not provided)
        self.c_q = c_q
        self.c_kv = c_kv
        self.c_pair = c_pair

        # Will be initialized in first forward pass
        self.diffusion_transformer = None

    def _create_sequence_local_mask(self, q: Tensor, beta_lm: Tensor) -> Tensor:
        """
        Create sequence-local attention mask based on Algorithm 7.

        Parameters
        ----------
        q : Tensor, shape=(batch_size, n_atoms, c_q)
            Query tensor
        beta_lm : Tensor, shape=(batch_size, n_atoms, n_atoms, n_head)
            Base attention bias

        Returns
        -------
        beta_lm : Tensor, shape=(batch_size, n_atoms, n_atoms, n_head)
            Modified attention bias with sequence-local masking
        """
        batch_size, n_atoms = q.shape[:2]
        device = q.device

        # Create position indices
        l_idx = torch.arange(n_atoms, device=device)  # (n_atoms,)
        m_idx = torch.arange(n_atoms, device=device)  # (n_atoms,)

        # Create meshgrid for all pairs
        l_grid, m_grid = torch.meshgrid(
            l_idx, m_idx, indexing="ij"
        )  # (n_atoms, n_atoms)

        # Initialize mask with -10^10 (effectively -inf)
        mask = torch.full_like(beta_lm, -1e10)

        # For each subset center, create rectangular blocks along diagonal
        for c in self.subset_centres:
            # Condition: |l - c| < n_queries/2 ∧ |m - c| < n_keys/2
            l_condition = torch.abs(l_grid - c) < (self.n_queries / 2)
            m_condition = torch.abs(m_grid - c) < (self.n_keys / 2)

            # Combined condition for this subset
            subset_condition = l_condition & m_condition  # (n_atoms, n_atoms)

            # Expand to match beta_lm shape
            subset_condition = subset_condition.unsqueeze(0).unsqueeze(
                -1
            )  # (1, n_atoms, n_atoms, 1)
            subset_condition = subset_condition.expand(batch_size, -1, -1, self.n_head)

            # Set mask to 0 where condition is satisfied
            mask = torch.where(subset_condition, 0.0, mask)

        # Apply mask to beta_lm
        beta_lm = beta_lm + mask

        return beta_lm

    def forward(self, q: Tensor, c: Tensor, p: Tensor) -> Tensor:
        r"""
        Forward pass of Atom Transformer.

        Parameters
        ----------
        q : Tensor, shape=(batch_size, n_atoms, c_q)
            Query representations
        c : Tensor, shape=(batch_size, n_atoms, c_kv)
            Context (single) representations
        p : Tensor, shape=(batch_size, n_atoms, n_atoms, c_pair)
            Pair representations

        Returns
        -------
        q : Tensor, shape=(batch_size, n_atoms, c_q)
            Updated query representations
        """
        # Infer dimensions from input if not provided
        if self.diffusion_transformer is None:
            c_q = q.shape[-1] if self.c_q is None else self.c_q
            c_kv = c.shape[-1] if self.c_kv is None else self.c_kv
            c_pair = p.shape[-1] if self.c_pair is None else self.c_pair

            self.diffusion_transformer = DiffusionTransformer(
                c_a=c_q,  # Use query dimension as token dimension
                c_s=c_kv,  # Use context dimension as single dimension
                c_z=c_pair,  # Use pair dimension
                n_head=self.n_head,
                n_block=self.n_block,
            )

            # Move to same device as input
            self.diffusion_transformer = self.diffusion_transformer.to(q.device)

        # Create initial beta_lm (starts as zeros, will be modified by masking)
        batch_size, n_atoms = q.shape[:2]
        beta_lm = torch.zeros(
            batch_size, n_atoms, n_atoms, self.n_head, device=q.device, dtype=q.dtype
        )

        # Apply sequence-local masking
        beta_lm = self._create_sequence_local_mask(q, beta_lm)

        # Apply DiffusionTransformer
        q = self.diffusion_transformer(q, c, p, beta_lm)

        return q
forward
forward(q, c, p)

Forward pass of Atom Transformer.

Parameters:

Name Type Description Default
q Tensor, shape=(batch_size, n_atoms, c_q)

Query representations

required
c Tensor, shape=(batch_size, n_atoms, c_kv)

Context (single) representations

required
p Tensor, shape=(batch_size, n_atoms, n_atoms, c_pair)

Pair representations

required

Returns:

Name Type Description
q Tensor, shape=(batch_size, n_atoms, c_q)

Updated query representations

Source code in src/beignet/nn/alphafold3/_atom_transformer.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
def forward(self, q: Tensor, c: Tensor, p: Tensor) -> Tensor:
    r"""
    Forward pass of Atom Transformer.

    Parameters
    ----------
    q : Tensor, shape=(batch_size, n_atoms, c_q)
        Query representations
    c : Tensor, shape=(batch_size, n_atoms, c_kv)
        Context (single) representations
    p : Tensor, shape=(batch_size, n_atoms, n_atoms, c_pair)
        Pair representations

    Returns
    -------
    q : Tensor, shape=(batch_size, n_atoms, c_q)
        Updated query representations
    """
    # Infer dimensions from input if not provided
    if self.diffusion_transformer is None:
        c_q = q.shape[-1] if self.c_q is None else self.c_q
        c_kv = c.shape[-1] if self.c_kv is None else self.c_kv
        c_pair = p.shape[-1] if self.c_pair is None else self.c_pair

        self.diffusion_transformer = DiffusionTransformer(
            c_a=c_q,  # Use query dimension as token dimension
            c_s=c_kv,  # Use context dimension as single dimension
            c_z=c_pair,  # Use pair dimension
            n_head=self.n_head,
            n_block=self.n_block,
        )

        # Move to same device as input
        self.diffusion_transformer = self.diffusion_transformer.to(q.device)

    # Create initial beta_lm (starts as zeros, will be modified by masking)
    batch_size, n_atoms = q.shape[:2]
    beta_lm = torch.zeros(
        batch_size, n_atoms, n_atoms, self.n_head, device=q.device, dtype=q.dtype
    )

    # Apply sequence-local masking
    beta_lm = self._create_sequence_local_mask(q, beta_lm)

    # Apply DiffusionTransformer
    q = self.diffusion_transformer(q, c, p, beta_lm)

    return q

beignet.nn.alphafold3.AttentionPairBias

Bases: Module

Attention with pair bias and mask from AlphaFold 3 Algorithm 24.

This implements the AttentionPairBias operation with conditioning signal support for diffusion models. It uses AdaLN when conditioning is provided, or standard LayerNorm when not. Includes proper gating and output projection.

Parameters:

Name Type Description Default
c_a int

Channel dimension for input representation 'a'

required
c_s int

Channel dimension for conditioning signal 's' (can be None if no conditioning)

required
c_z int

Channel dimension for pair representation 'z'

required
n_head int

Number of attention heads

required

Examples:

>>> import torch
>>> from beignet.nn import AttentionPairBias
>>> batch_size, seq_len = 2, 10
>>> c_a, c_s, c_z, n_head = 256, 384, 128, 16
>>> module = AttentionPairBias(c_a=c_a, c_s=c_s, c_z=c_z, n_head=n_head)
>>> a = torch.randn(batch_size, seq_len, c_a)
>>> s = torch.randn(batch_size, seq_len, c_s)
>>> z = torch.randn(batch_size, seq_len, seq_len, c_z)
>>> beta = torch.randn(batch_size, seq_len, seq_len, n_head)
>>> a_out = module(a, s, z, beta)
>>> a_out.shape
torch.Size([2, 10, 256])
References

.. [1] AlphaFold 3 Algorithm 24: AttentionPairBias with pair bias and mask

Source code in src/beignet/nn/alphafold3/_attention_pair_bias.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
class AttentionPairBias(nn.Module):
    r"""
    Attention with pair bias and mask from AlphaFold 3 Algorithm 24.

    This implements the AttentionPairBias operation with conditioning signal support
    for diffusion models. It uses AdaLN when conditioning is provided, or standard
    LayerNorm when not. Includes proper gating and output projection.

    Parameters
    ----------
    c_a : int
        Channel dimension for input representation 'a'
    c_s : int
        Channel dimension for conditioning signal 's' (can be None if no conditioning)
    c_z : int
        Channel dimension for pair representation 'z'
    n_head : int
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import AttentionPairBias
    >>> batch_size, seq_len = 2, 10
    >>> c_a, c_s, c_z, n_head = 256, 384, 128, 16
    >>> module = AttentionPairBias(c_a=c_a, c_s=c_s, c_z=c_z, n_head=n_head)
    >>> a = torch.randn(batch_size, seq_len, c_a)
    >>> s = torch.randn(batch_size, seq_len, c_s)
    >>> z = torch.randn(batch_size, seq_len, seq_len, c_z)
    >>> beta = torch.randn(batch_size, seq_len, seq_len, n_head)
    >>> a_out = module(a, s, z, beta)
    >>> a_out.shape
    torch.Size([2, 10, 256])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 24: AttentionPairBias with pair bias and mask
    """

    def __init__(self, c_a: int, c_s: int, c_z: int, n_head: int):
        super().__init__()

        self.c_a = c_a
        self.c_s = c_s
        self.c_z = c_z
        self.n_head = n_head
        self.head_dim = c_a // n_head

        if c_a % n_head != 0:
            raise ValueError(
                f"Channel dimension {c_a} must be divisible by number of heads {n_head}"
            )

        # Input projections - Algorithm 24 steps 1-4
        # Step 1-2: If {si} ≠ ∅ then ai ← AdaLN(ai, si) else ai ← LayerNorm(ai)
        if c_s is not None:
            self.ada_ln = AdaptiveLayerNorm(c=c_a, c_s=c_s)
        else:
            self.layer_norm = nn.LayerNorm(c_a)

        # Step 6: q_i^h = Linear(ai)
        self.linear_q = nn.Linear(c_a, c_a, bias=True)

        # Step 7: k_i^h, v_i^h = LinearNoBias(ai)
        self.linear_k = nn.Linear(c_a, c_a, bias=False)
        self.linear_v = nn.Linear(c_a, c_a, bias=False)

        # Step 8: b_ij^h ← LinearNoBias(LayerNorm(zij)) + βij
        self.linear_b = nn.Linear(c_z, n_head, bias=False)
        self.layer_norm_z = nn.LayerNorm(c_z)

        # Step 9: g_i^h ← sigmoid(LinearNoBias(ai))
        self.linear_g = nn.Linear(c_a, c_a, bias=False)

        # Step 11: Output projection
        self.output_linear = nn.Linear(c_a, c_a, bias=False)

        # Scale factor for attention (Step 10: 1/√c where c = ca/Nhead)
        self.scale = 1.0 / math.sqrt(self.head_dim)

        # Output projection with adaLN-Zero pattern (Steps 12-13)
        if c_s is not None:
            # Step 13: sigmoid(Linear(si, biasinit=-2.0)) ⊙ ai
            self.linear_s_gate = nn.Linear(c_s, c_a, bias=True)
            # Initialize bias to -2.0 as specified
            with torch.no_grad():
                self.linear_s_gate.bias.fill_(-2.0)

    def forward(
        self, a: Tensor, s: Tensor = None, z: Tensor = None, beta: Tensor = None
    ) -> Tensor:
        r"""
        Forward pass of attention with pair bias.

        Parameters
        ----------
        a : Tensor, shape=(..., seq_len, c_a)
            Input representation
        s : Tensor, shape=(..., seq_len, c_s), optional
            Conditioning signal (if None, uses LayerNorm instead of AdaLN)
        z : Tensor, shape=(..., seq_len, seq_len, c_z), optional
            Pair representation for computing attention bias
        beta : Tensor, shape=(..., seq_len, seq_len, n_head), optional
            Additional bias terms

        Returns
        -------
        a_out : Tensor, shape=(..., seq_len, c_a)
            Updated representation after attention with pair bias
        """
        batch_shape = a.shape[:-2]
        seq_len = a.shape[-2]

        # Algorithm 24 Steps 1-4: Input projections
        # Step 1: if {si} ≠ ∅ then
        if s is not None and hasattr(self, "ada_ln"):
            # Step 2: ai ← AdaLN(ai, si)
            a_normed = self.ada_ln(a, s)
        else:
            # Step 4: ai ← LayerNorm(ai)
            a_normed = self.layer_norm(a)

        # Step 6: q_i^h = Linear(ai)
        q = self.linear_q(a_normed)
        q = q.view(*batch_shape, seq_len, self.n_head, self.head_dim)

        # Step 7: k_i^h, v_i^h = LinearNoBias(ai)
        k = self.linear_k(a_normed)
        k = k.view(*batch_shape, seq_len, self.n_head, self.head_dim)

        v = self.linear_v(a_normed)
        v = v.view(*batch_shape, seq_len, self.n_head, self.head_dim)

        # Step 8: b_ij^h ← LinearNoBias(LayerNorm(zij)) + βij
        if z is not None:
            z_normed = self.layer_norm_z(z)
            b_z = self.linear_b(z_normed)  # Shape: (..., seq_len, seq_len, n_head)
        else:
            b_z = torch.zeros(
                *batch_shape,
                seq_len,
                seq_len,
                self.n_head,
                device=a.device,
                dtype=a.dtype,
            )

        if beta is not None:
            b = b_z + beta
        else:
            b = b_z

        # Step 9: g_i^h ← sigmoid(LinearNoBias(ai))
        g = torch.sigmoid(self.linear_g(a_normed))
        g = g.view(*batch_shape, seq_len, self.n_head, self.head_dim)

        # Step 10: A_ij^h ← softmax_j(1/√c * q_i^h * k_j^h + b_ij^h)
        # Compute attention scores: q @ k.T / √d + bias
        # q: (..., seq_len, n_head, head_dim)
        # k: (..., seq_len, n_head, head_dim)
        # Want: (..., n_head, seq_len, seq_len)
        attn_logits = (
            torch.einsum("...ihd,...jhd->...hij", q, k) * self.scale
        )  # Shape: (..., n_head, seq_len, seq_len)

        # Add bias: b has shape (..., seq_len, seq_len, n_head)
        # We need to transpose to match attention shape
        b_transposed = b.transpose(-3, -1)  # (..., n_head, seq_len, seq_len)
        attn_logits = attn_logits + b_transposed

        # Apply softmax over the last dimension (keys)
        attn_weights = torch.softmax(attn_logits, dim=-1)

        # Step 11: ai ← LinearNoBias(concat_h(g_i^h ⊙ Σ_j A_ij^h v_j^h))
        # Apply attention to values and gate
        # attn_weights: (..., n_head, seq_len, seq_len)
        # v: (..., seq_len, n_head, head_dim)
        attended_v = torch.einsum(
            "...hij,...jhd->...hid", attn_weights, v
        )  # Shape: (..., n_head, seq_len, head_dim)

        # Reshape g to match attended_v shape: (..., seq_len, n_head, head_dim) -> (..., n_head, seq_len, head_dim)
        g_reshaped = g.transpose(-3, -2)  # Shape: (..., n_head, seq_len, head_dim)

        # Apply gating: g ⊙ attended_v
        gated_output = g_reshaped * attended_v  # Element-wise multiplication

        # Concatenate heads: reshape to (..., seq_len, c_a)
        concat_output = (
            gated_output.transpose(-3, -2)
            .contiguous()
            .view(*batch_shape, seq_len, self.c_a)
        )

        # Linear projection
        a_out = self.output_linear(concat_output)

        # Algorithm 24 Steps 12-14: Output projection (from adaLN-Zero)
        # Step 12: if {si} ≠ ∅ then
        if s is not None and hasattr(self, "linear_s_gate"):
            # Step 13: ai ← sigmoid(Linear(si, biasinit=-2.0)) ⊙ ai
            s_gate = torch.sigmoid(self.linear_s_gate(s))
            a_out = s_gate * a_out

        # Step 15: return {ai}
        return a_out
forward
forward(a, s=None, z=None, beta=None)

Forward pass of attention with pair bias.

Parameters:

Name Type Description Default
a Tensor, shape=(..., seq_len, c_a)

Input representation

required
s Tensor, shape=(..., seq_len, c_s)

Conditioning signal (if None, uses LayerNorm instead of AdaLN)

None
z Tensor, shape=(..., seq_len, seq_len, c_z)

Pair representation for computing attention bias

None
beta Tensor, shape=(..., seq_len, seq_len, n_head)

Additional bias terms

None

Returns:

Name Type Description
a_out Tensor, shape=(..., seq_len, c_a)

Updated representation after attention with pair bias

Source code in src/beignet/nn/alphafold3/_attention_pair_bias.py
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
def forward(
    self, a: Tensor, s: Tensor = None, z: Tensor = None, beta: Tensor = None
) -> Tensor:
    r"""
    Forward pass of attention with pair bias.

    Parameters
    ----------
    a : Tensor, shape=(..., seq_len, c_a)
        Input representation
    s : Tensor, shape=(..., seq_len, c_s), optional
        Conditioning signal (if None, uses LayerNorm instead of AdaLN)
    z : Tensor, shape=(..., seq_len, seq_len, c_z), optional
        Pair representation for computing attention bias
    beta : Tensor, shape=(..., seq_len, seq_len, n_head), optional
        Additional bias terms

    Returns
    -------
    a_out : Tensor, shape=(..., seq_len, c_a)
        Updated representation after attention with pair bias
    """
    batch_shape = a.shape[:-2]
    seq_len = a.shape[-2]

    # Algorithm 24 Steps 1-4: Input projections
    # Step 1: if {si} ≠ ∅ then
    if s is not None and hasattr(self, "ada_ln"):
        # Step 2: ai ← AdaLN(ai, si)
        a_normed = self.ada_ln(a, s)
    else:
        # Step 4: ai ← LayerNorm(ai)
        a_normed = self.layer_norm(a)

    # Step 6: q_i^h = Linear(ai)
    q = self.linear_q(a_normed)
    q = q.view(*batch_shape, seq_len, self.n_head, self.head_dim)

    # Step 7: k_i^h, v_i^h = LinearNoBias(ai)
    k = self.linear_k(a_normed)
    k = k.view(*batch_shape, seq_len, self.n_head, self.head_dim)

    v = self.linear_v(a_normed)
    v = v.view(*batch_shape, seq_len, self.n_head, self.head_dim)

    # Step 8: b_ij^h ← LinearNoBias(LayerNorm(zij)) + βij
    if z is not None:
        z_normed = self.layer_norm_z(z)
        b_z = self.linear_b(z_normed)  # Shape: (..., seq_len, seq_len, n_head)
    else:
        b_z = torch.zeros(
            *batch_shape,
            seq_len,
            seq_len,
            self.n_head,
            device=a.device,
            dtype=a.dtype,
        )

    if beta is not None:
        b = b_z + beta
    else:
        b = b_z

    # Step 9: g_i^h ← sigmoid(LinearNoBias(ai))
    g = torch.sigmoid(self.linear_g(a_normed))
    g = g.view(*batch_shape, seq_len, self.n_head, self.head_dim)

    # Step 10: A_ij^h ← softmax_j(1/√c * q_i^h * k_j^h + b_ij^h)
    # Compute attention scores: q @ k.T / √d + bias
    # q: (..., seq_len, n_head, head_dim)
    # k: (..., seq_len, n_head, head_dim)
    # Want: (..., n_head, seq_len, seq_len)
    attn_logits = (
        torch.einsum("...ihd,...jhd->...hij", q, k) * self.scale
    )  # Shape: (..., n_head, seq_len, seq_len)

    # Add bias: b has shape (..., seq_len, seq_len, n_head)
    # We need to transpose to match attention shape
    b_transposed = b.transpose(-3, -1)  # (..., n_head, seq_len, seq_len)
    attn_logits = attn_logits + b_transposed

    # Apply softmax over the last dimension (keys)
    attn_weights = torch.softmax(attn_logits, dim=-1)

    # Step 11: ai ← LinearNoBias(concat_h(g_i^h ⊙ Σ_j A_ij^h v_j^h))
    # Apply attention to values and gate
    # attn_weights: (..., n_head, seq_len, seq_len)
    # v: (..., seq_len, n_head, head_dim)
    attended_v = torch.einsum(
        "...hij,...jhd->...hid", attn_weights, v
    )  # Shape: (..., n_head, seq_len, head_dim)

    # Reshape g to match attended_v shape: (..., seq_len, n_head, head_dim) -> (..., n_head, seq_len, head_dim)
    g_reshaped = g.transpose(-3, -2)  # Shape: (..., n_head, seq_len, head_dim)

    # Apply gating: g ⊙ attended_v
    gated_output = g_reshaped * attended_v  # Element-wise multiplication

    # Concatenate heads: reshape to (..., seq_len, c_a)
    concat_output = (
        gated_output.transpose(-3, -2)
        .contiguous()
        .view(*batch_shape, seq_len, self.c_a)
    )

    # Linear projection
    a_out = self.output_linear(concat_output)

    # Algorithm 24 Steps 12-14: Output projection (from adaLN-Zero)
    # Step 12: if {si} ≠ ∅ then
    if s is not None and hasattr(self, "linear_s_gate"):
        # Step 13: ai ← sigmoid(Linear(si, biasinit=-2.0)) ⊙ ai
        s_gate = torch.sigmoid(self.linear_s_gate(s))
        a_out = s_gate * a_out

    # Step 15: return {ai}
    return a_out

beignet.nn.alphafold3.DiffusionTransformer

Bases: Module

Diffusion Transformer from AlphaFold 3 Algorithm 23.

This implements a transformer block for diffusion models that alternates between AttentionPairBias and ConditionedTransitionBlock operations. The module processes single representations {ai} conditioned on {si}, pair representations {zij}, and bias terms {βij}.

Parameters:

Name Type Description Default
c_a int

Channel dimension for single representation 'a'

required
c_s int

Channel dimension for conditioning signal 's'

required
c_z int

Channel dimension for pair representation 'z'

required
n_head int

Number of attention heads

required
n_block int

Number of transformer blocks

required
n int

Expansion factor for ConditionedTransitionBlock

2

Examples:

>>> import torch
>>> from beignet.nn import DiffusionTransformer
>>> batch_size, seq_len, c_a, c_s, c_z = 2, 32, 256, 384, 128
>>> n_head, n_block = 16, 4
>>> module = DiffusionTransformer(
...     c_a=c_a, c_s=c_s, c_z=c_z,
...     n_head=n_head, n_block=n_block
... )
>>> a = torch.randn(batch_size, seq_len, c_a)
>>> s = torch.randn(batch_size, seq_len, c_s)
>>> z = torch.randn(batch_size, seq_len, seq_len, c_z)
>>> beta = torch.randn(batch_size, seq_len, seq_len, n_head)
>>> a_out = module(a, s, z, beta)
>>> a_out.shape
torch.Size([2, 32, 256])
References

.. [1] AlphaFold 3 Algorithm 23: Diffusion Transformer

Source code in src/beignet/nn/alphafold3/_diffusion_transformer.py
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class DiffusionTransformer(nn.Module):
    r"""
    Diffusion Transformer from AlphaFold 3 Algorithm 23.

    This implements a transformer block for diffusion models that alternates between
    AttentionPairBias and ConditionedTransitionBlock operations. The module processes
    single representations {ai} conditioned on {si}, pair representations {zij},
    and bias terms {βij}.

    Parameters
    ----------
    c_a : int
        Channel dimension for single representation 'a'
    c_s : int
        Channel dimension for conditioning signal 's'
    c_z : int
        Channel dimension for pair representation 'z'
    n_head : int
        Number of attention heads
    n_block : int
        Number of transformer blocks
    n : int, default=2
        Expansion factor for ConditionedTransitionBlock

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import DiffusionTransformer
    >>> batch_size, seq_len, c_a, c_s, c_z = 2, 32, 256, 384, 128
    >>> n_head, n_block = 16, 4
    >>> module = DiffusionTransformer(
    ...     c_a=c_a, c_s=c_s, c_z=c_z,
    ...     n_head=n_head, n_block=n_block
    ... )
    >>> a = torch.randn(batch_size, seq_len, c_a)
    >>> s = torch.randn(batch_size, seq_len, c_s)
    >>> z = torch.randn(batch_size, seq_len, seq_len, c_z)
    >>> beta = torch.randn(batch_size, seq_len, seq_len, n_head)
    >>> a_out = module(a, s, z, beta)
    >>> a_out.shape
    torch.Size([2, 32, 256])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 23: Diffusion Transformer
    """

    def __init__(
        self,
        c_a: int,
        c_s: int,
        c_z: int,
        n_head: int,
        n_block: int,
        n: int = 2,
    ):
        super().__init__()

        self.c_a = c_a
        self.c_s = c_s
        self.c_z = c_z
        self.n_head = n_head
        self.n_block = n_block
        self.n = n

        # Create n_block pairs of (AttentionPairBias, ConditionedTransitionBlock)
        self.blocks = nn.ModuleList()
        for _ in range(n_block):
            # Each block contains:
            # 1. AttentionPairBias for step 2
            # 2. ConditionedTransitionBlock for step 3
            block = nn.ModuleDict(
                {
                    "attention": AttentionPairBias(
                        c_a=c_a, c_s=c_s, c_z=c_z, n_head=n_head
                    ),
                    "transition": _ConditionedTransitionBlock(c=c_a, c_s=c_s, n=n),
                }
            )
            self.blocks.append(block)

    def forward(self, a: Tensor, s: Tensor, z: Tensor, beta: Tensor) -> Tensor:
        r"""
        Forward pass of Diffusion Transformer.

        Parameters
        ----------
        a : Tensor, shape=(batch_size, seq_len, c_a)
            Single representations
        s : Tensor, shape=(batch_size, seq_len, c_s)
            Conditioning signal
        z : Tensor, shape=(batch_size, seq_len, seq_len, c_z)
            Pair representations
        beta : Tensor, shape=(batch_size, seq_len, seq_len, n_head)
            Bias terms for attention

        Returns
        -------
        a_out : Tensor, shape=(batch_size, seq_len, c_a)
            Updated single representations after diffusion transformer
        """
        # Algorithm 23: for all n ∈ [1, ..., N_block] do
        for block in self.blocks:
            # Algorithm 23 Step 2: {bi} = AttentionPairBias({ai}, {si}, {zij}, {βij}, N_head)
            b = block["attention"](a, s, z, beta)

            # Algorithm 23 Step 3: ai ← bi + ConditionedTransitionBlock(ai, si)
            a = b + block["transition"](a, s)

        # Algorithm 23 Step 5: return {ai}
        return a
forward
forward(a, s, z, beta)

Forward pass of Diffusion Transformer.

Parameters:

Name Type Description Default
a Tensor, shape=(batch_size, seq_len, c_a)

Single representations

required
s Tensor, shape=(batch_size, seq_len, c_s)

Conditioning signal

required
z Tensor, shape=(batch_size, seq_len, seq_len, c_z)

Pair representations

required
beta Tensor, shape=(batch_size, seq_len, seq_len, n_head)

Bias terms for attention

required

Returns:

Name Type Description
a_out Tensor, shape=(batch_size, seq_len, c_a)

Updated single representations after diffusion transformer

Source code in src/beignet/nn/alphafold3/_diffusion_transformer.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
def forward(self, a: Tensor, s: Tensor, z: Tensor, beta: Tensor) -> Tensor:
    r"""
    Forward pass of Diffusion Transformer.

    Parameters
    ----------
    a : Tensor, shape=(batch_size, seq_len, c_a)
        Single representations
    s : Tensor, shape=(batch_size, seq_len, c_s)
        Conditioning signal
    z : Tensor, shape=(batch_size, seq_len, seq_len, c_z)
        Pair representations
    beta : Tensor, shape=(batch_size, seq_len, seq_len, n_head)
        Bias terms for attention

    Returns
    -------
    a_out : Tensor, shape=(batch_size, seq_len, c_a)
        Updated single representations after diffusion transformer
    """
    # Algorithm 23: for all n ∈ [1, ..., N_block] do
    for block in self.blocks:
        # Algorithm 23 Step 2: {bi} = AttentionPairBias({ai}, {si}, {zij}, {βij}, N_head)
        b = block["attention"](a, s, z, beta)

        # Algorithm 23 Step 3: ai ← bi + ConditionedTransitionBlock(ai, si)
        a = b + block["transition"](a, s)

    # Algorithm 23 Step 5: return {ai}
    return a

beignet.nn.alphafold3.MSA

Bases: Module

Multiple Sequence Alignment Module from AlphaFold 3.

This implements Algorithm 8 from AlphaFold 3, which is a complete MSA processing module that combines multiple sub-modules in a structured way:

  1. MSA representation initialization and random sampling
  2. Communication block with OuterProductMean
  3. MSA stack with MSAPairWeightedAveraging and Transition
  4. Pair stack with triangle updates and attention

The module processes MSA features, single representations, and pair representations through multiple blocks to capture complex evolutionary and structural patterns.

Parameters:

Name Type Description Default
n_block int

Number of processing blocks

4
c_m int

Channel dimension for MSA representation

64
c_z int

Channel dimension for pair representation

128
c_s int

Channel dimension for single representation

256
n_head_msa int

Number of attention heads for MSA operations

8
n_head_pair int

Number of attention heads for pair operations

4
dropout_rate float

Dropout rate for MSA operations

0.15

Examples:

>>> import torch
>>> from beignet.nn import MSA
>>> batch_size, seq_len, n_seq = 2, 32, 16
>>> c_m, c_z, c_s = 64, 128, 256
>>>
>>> module = MSA(n_block=2, c_m=c_m, c_z=c_z, c_s=c_s)
>>>
>>> # Input features
>>> f_msa = torch.randn(batch_size, seq_len, n_seq, 23)  # MSA features
>>> f_has_deletion = torch.randn(batch_size, seq_len, n_seq, 1)
>>> f_deletion_value = torch.randn(batch_size, seq_len, n_seq, 1)
>>> s_inputs = torch.randn(batch_size, seq_len, c_s)  # Single inputs
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c_z)  # Pair representation
>>>
>>> # Forward pass
>>> updated_z_ij = module(f_msa, f_has_deletion, f_deletion_value, s_inputs, z_ij)
>>> updated_z_ij.shape
torch.Size([2, 32, 32, 128])
References

.. [1] AlphaFold 3 paper, Algorithm 8: MSA Module

Source code in src/beignet/nn/alphafold3/_msa.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
class MSA(nn.Module):
    r"""
    Multiple Sequence Alignment Module from AlphaFold 3.

    This implements Algorithm 8 from AlphaFold 3, which is a complete MSA processing
    module that combines multiple sub-modules in a structured way:

    1. MSA representation initialization and random sampling
    2. Communication block with OuterProductMean
    3. MSA stack with MSAPairWeightedAveraging and Transition
    4. Pair stack with triangle updates and attention

    The module processes MSA features, single representations, and pair representations
    through multiple blocks to capture complex evolutionary and structural patterns.

    Parameters
    ----------
    n_block : int, default=4
        Number of processing blocks
    c_m : int, default=64
        Channel dimension for MSA representation
    c_z : int, default=128
        Channel dimension for pair representation
    c_s : int, default=256
        Channel dimension for single representation
    n_head_msa : int, default=8
        Number of attention heads for MSA operations
    n_head_pair : int, default=4
        Number of attention heads for pair operations
    dropout_rate : float, default=0.15
        Dropout rate for MSA operations

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import MSA
    >>> batch_size, seq_len, n_seq = 2, 32, 16
    >>> c_m, c_z, c_s = 64, 128, 256
    >>>
    >>> module = MSA(n_block=2, c_m=c_m, c_z=c_z, c_s=c_s)
    >>>
    >>> # Input features
    >>> f_msa = torch.randn(batch_size, seq_len, n_seq, 23)  # MSA features
    >>> f_has_deletion = torch.randn(batch_size, seq_len, n_seq, 1)
    >>> f_deletion_value = torch.randn(batch_size, seq_len, n_seq, 1)
    >>> s_inputs = torch.randn(batch_size, seq_len, c_s)  # Single inputs
    >>> z_ij = torch.randn(batch_size, seq_len, seq_len, c_z)  # Pair representation
    >>>
    >>> # Forward pass
    >>> updated_z_ij = module(f_msa, f_has_deletion, f_deletion_value, s_inputs, z_ij)
    >>> updated_z_ij.shape
    torch.Size([2, 32, 32, 128])

    References
    ----------
    .. [1] AlphaFold 3 paper, Algorithm 8: MSA Module
    """

    def __init__(
        self,
        n_block: int = 4,
        c_m: int = 64,
        c_z: int = 128,
        c_s: int = 256,
        n_head_msa: int = 8,
        n_head_pair: int = 4,
        dropout_rate: float = 0.15,
    ):
        super().__init__()

        self.n_block = n_block
        self.c_m = c_m
        self.c_z = c_z
        self.c_s = c_s
        self.n_head_msa = n_head_msa
        self.n_head_pair = n_head_pair
        self.dropout_rate = dropout_rate

        # Step 3: Initial linear projection for MSA (concatenated features -> c_m)
        # Input features: f_msa (23) + f_has_deletion (1) + f_deletion_value (1) = 25 channels
        self.msa_linear = nn.Linear(25, c_m, bias=False)

        # Step 4: Linear projection for single inputs (s_inputs -> c_m)
        self.single_linear = nn.Linear(c_s, c_m, bias=False)

        # Communication: OuterProductMean (step 6)
        self.outer_product_mean = _OuterProductMean(c=c_m, c_z=c_z)

        # MSA stack components (step 7-8)
        self.msa_pair_weighted_averaging = _MSAPairWeightedAveraging(
            c_m=c_m, c_z=c_z, n_head=n_head_msa
        )
        self.msa_transition = Transition(c=c_m, n=4)
        self.msa_dropout = nn.Dropout(dropout_rate)

        # Pair stack components (step 9-13)
        self.triangle_mult_outgoing = TriangleMultiplicationOutgoing(c=c_z)
        self.triangle_mult_incoming = TriangleMultiplicationIncoming(c=c_z)
        self.triangle_attention_starting = TriangleAttentionStartingNode(
            c=c_z, n_head=n_head_pair
        )
        self.triangle_attention_ending = TriangleAttentionEndingNode(
            c=c_z, n_head=n_head_pair
        )
        self.pair_transition = Transition(c=c_z, n=4)

        # Dropout layers for pair operations
        self.pair_dropout_rowwise = nn.Dropout(0.25)  # For steps 9,10,11
        self.pair_dropout_columnwise = nn.Dropout(0.25)  # For step 12

    def forward(
        self,
        f_msa: Tensor,
        f_has_deletion: Tensor,
        f_deletion_value: Tensor,
        s_inputs: Tensor,
        z_ij: Tensor,
    ) -> Tensor:
        r"""
        Forward pass of the MSA Module.

        Parameters
        ----------
        f_msa : Tensor, shape=(..., s, n_seq, 23)
            MSA features (amino acid profiles)
        f_has_deletion : Tensor, shape=(..., s, n_seq, 1)
            Has deletion features
        f_deletion_value : Tensor, shape=(..., s, n_seq, 1)
            Deletion value features
        s_inputs : Tensor, shape=(..., s, c_s)
            Single representation inputs
        z_ij : Tensor, shape=(..., s, s, c_z)
            Pair representation

        Returns
        -------
        z_ij : Tensor, shape=(..., s, s, c_z)
            Updated pair representation
        """
        # Input validation
        seq_len_msa = f_msa.shape[-3]
        seq_len_single = s_inputs.shape[-2]
        n_seq_msa = f_msa.shape[-2]

        if seq_len_msa != seq_len_single:
            raise ValueError(
                f"Sequence length mismatch: MSA has {seq_len_msa} residues, "
                f"single representation has {seq_len_single} residues"
            )

        # Check MSA feature consistency
        if (
            f_has_deletion.shape[-3] != seq_len_msa
            or f_deletion_value.shape[-3] != seq_len_msa
        ):
            raise ValueError("All MSA features must have the same sequence length")

        if (
            f_has_deletion.shape[-2] != n_seq_msa
            or f_deletion_value.shape[-2] != n_seq_msa
        ):
            raise ValueError("All MSA features must have the same number of sequences")

        # Check pair representation compatibility
        if z_ij.shape[-2] != seq_len_msa or z_ij.shape[-3] != seq_len_msa:
            raise ValueError(
                f"Pair representation shape mismatch: expected ({seq_len_msa}, {seq_len_msa}), "
                f"got ({z_ij.shape[-3]}, {z_ij.shape[-2]})"
            )

        m_si = self.msa_linear(
            torch.concatenate(
                [
                    f_msa,
                    f_has_deletion,
                    f_deletion_value,
                ],
                dim=-1,
            )
        ) + torch.unsqueeze(self.single_linear(s_inputs), -2)

        # Step 5: Process through N_block iterations
        for _ in range(self.n_block):
            # Step 6: Communication - OuterProductMean
            # OuterProductMean now properly handles MSA sequences and computes mean over outer products
            # Pass the full MSA representation to capture coevolutionary information
            z_ij = z_ij + self.outer_product_mean(m_si)  # m_si: (..., s, n_seq, c_m)

            # MSA stack (steps 7-8)
            # Step 7: MSA Pair Weighted Averaging with dropout
            m_si = m_si + self.msa_dropout(self.msa_pair_weighted_averaging(m_si, z_ij))

            # Step 8: MSA Transition
            m_si = m_si + self.msa_transition(m_si)

            # Pair stack (steps 9-13)
            # Step 9: Triangle Multiplication Outgoing with rowwise dropout
            z_ij = z_ij + self.pair_dropout_rowwise(self.triangle_mult_outgoing(z_ij))

            # Step 10: Triangle Multiplication Incoming with rowwise dropout
            z_ij = z_ij + self.pair_dropout_rowwise(self.triangle_mult_incoming(z_ij))

            # Step 11: Triangle Attention Starting Node with rowwise dropout
            z_ij = z_ij + self.pair_dropout_rowwise(
                self.triangle_attention_starting(z_ij)
            )

            # Step 12: Triangle Attention Ending Node with columnwise dropout
            z_ij = z_ij + self.pair_dropout_columnwise(
                self.triangle_attention_ending(z_ij)
            )

            # Step 13: Pair Transition
            z_ij = z_ij + self.pair_transition(z_ij)

        return z_ij
forward
forward(f_msa, f_has_deletion, f_deletion_value, s_inputs, z_ij)

Forward pass of the MSA Module.

Parameters:

Name Type Description Default
f_msa Tensor, shape=(..., s, n_seq, 23)

MSA features (amino acid profiles)

required
f_has_deletion Tensor, shape=(..., s, n_seq, 1)

Has deletion features

required
f_deletion_value Tensor, shape=(..., s, n_seq, 1)

Deletion value features

required
s_inputs Tensor, shape=(..., s, c_s)

Single representation inputs

required
z_ij Tensor, shape=(..., s, s, c_z)

Pair representation

required

Returns:

Name Type Description
z_ij Tensor, shape=(..., s, s, c_z)

Updated pair representation

Source code in src/beignet/nn/alphafold3/_msa.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
def forward(
    self,
    f_msa: Tensor,
    f_has_deletion: Tensor,
    f_deletion_value: Tensor,
    s_inputs: Tensor,
    z_ij: Tensor,
) -> Tensor:
    r"""
    Forward pass of the MSA Module.

    Parameters
    ----------
    f_msa : Tensor, shape=(..., s, n_seq, 23)
        MSA features (amino acid profiles)
    f_has_deletion : Tensor, shape=(..., s, n_seq, 1)
        Has deletion features
    f_deletion_value : Tensor, shape=(..., s, n_seq, 1)
        Deletion value features
    s_inputs : Tensor, shape=(..., s, c_s)
        Single representation inputs
    z_ij : Tensor, shape=(..., s, s, c_z)
        Pair representation

    Returns
    -------
    z_ij : Tensor, shape=(..., s, s, c_z)
        Updated pair representation
    """
    # Input validation
    seq_len_msa = f_msa.shape[-3]
    seq_len_single = s_inputs.shape[-2]
    n_seq_msa = f_msa.shape[-2]

    if seq_len_msa != seq_len_single:
        raise ValueError(
            f"Sequence length mismatch: MSA has {seq_len_msa} residues, "
            f"single representation has {seq_len_single} residues"
        )

    # Check MSA feature consistency
    if (
        f_has_deletion.shape[-3] != seq_len_msa
        or f_deletion_value.shape[-3] != seq_len_msa
    ):
        raise ValueError("All MSA features must have the same sequence length")

    if (
        f_has_deletion.shape[-2] != n_seq_msa
        or f_deletion_value.shape[-2] != n_seq_msa
    ):
        raise ValueError("All MSA features must have the same number of sequences")

    # Check pair representation compatibility
    if z_ij.shape[-2] != seq_len_msa or z_ij.shape[-3] != seq_len_msa:
        raise ValueError(
            f"Pair representation shape mismatch: expected ({seq_len_msa}, {seq_len_msa}), "
            f"got ({z_ij.shape[-3]}, {z_ij.shape[-2]})"
        )

    m_si = self.msa_linear(
        torch.concatenate(
            [
                f_msa,
                f_has_deletion,
                f_deletion_value,
            ],
            dim=-1,
        )
    ) + torch.unsqueeze(self.single_linear(s_inputs), -2)

    # Step 5: Process through N_block iterations
    for _ in range(self.n_block):
        # Step 6: Communication - OuterProductMean
        # OuterProductMean now properly handles MSA sequences and computes mean over outer products
        # Pass the full MSA representation to capture coevolutionary information
        z_ij = z_ij + self.outer_product_mean(m_si)  # m_si: (..., s, n_seq, c_m)

        # MSA stack (steps 7-8)
        # Step 7: MSA Pair Weighted Averaging with dropout
        m_si = m_si + self.msa_dropout(self.msa_pair_weighted_averaging(m_si, z_ij))

        # Step 8: MSA Transition
        m_si = m_si + self.msa_transition(m_si)

        # Pair stack (steps 9-13)
        # Step 9: Triangle Multiplication Outgoing with rowwise dropout
        z_ij = z_ij + self.pair_dropout_rowwise(self.triangle_mult_outgoing(z_ij))

        # Step 10: Triangle Multiplication Incoming with rowwise dropout
        z_ij = z_ij + self.pair_dropout_rowwise(self.triangle_mult_incoming(z_ij))

        # Step 11: Triangle Attention Starting Node with rowwise dropout
        z_ij = z_ij + self.pair_dropout_rowwise(
            self.triangle_attention_starting(z_ij)
        )

        # Step 12: Triangle Attention Ending Node with columnwise dropout
        z_ij = z_ij + self.pair_dropout_columnwise(
            self.triangle_attention_ending(z_ij)
        )

        # Step 13: Pair Transition
        z_ij = z_ij + self.pair_transition(z_ij)

    return z_ij

beignet.nn.alphafold3.PairformerStack

Bases: Module

Pairformer stack from AlphaFold 3 Algorithm 17.

This is the exact implementation of the Pairformer stack as specified in Algorithm 17, which processes single and pair representations through N_block iterations of triangle operations and attention mechanisms.

Parameters:

Name Type Description Default
n_block int

Number of Pairformer blocks (N_block in Algorithm 17)

48
c_s int

Channel dimension for single representation

384
c_z int

Channel dimension for pair representation

128
n_head_single int

Number of attention heads for single representation

16
n_head_pair int

Number of attention heads for pair representation

4
dropout_rate float

Dropout rate as specified in Algorithm 17

0.25
transition_n int

Multiplier for transition layer hidden dimension

4

Examples:

>>> import torch
>>> from beignet.nn import PairformerStack
>>> batch_size, seq_len = 2, 10
>>> n_block, c_s, c_z = 4, 384, 128
>>> module = PairformerStack(n_block=n_block, c_s=c_s, c_z=c_z)
>>> s_i = torch.randn(batch_size, seq_len, c_s)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c_z)
>>> s_out, z_out = module(s_i, z_ij)
>>> s_out.shape
torch.Size([2, 10, 384])
>>> z_out.shape
torch.Size([2, 10, 10, 128])
References

.. [1] Abramson, J., Adler, J., Dunger, J. et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3. Nature 630, 493–500 (2024). Algorithm 17: Pairformer stack

Source code in src/beignet/nn/alphafold3/_pairformer_stack.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
class PairformerStack(nn.Module):
    r"""
    Pairformer stack from AlphaFold 3 Algorithm 17.

    This is the exact implementation of the Pairformer stack as specified
    in Algorithm 17, which processes single and pair representations through
    N_block iterations of triangle operations and attention mechanisms.

    Parameters
    ----------
    n_block : int, default=48
        Number of Pairformer blocks (N_block in Algorithm 17)
    c_s : int, default=384
        Channel dimension for single representation
    c_z : int, default=128
        Channel dimension for pair representation
    n_head_single : int, default=16
        Number of attention heads for single representation
    n_head_pair : int, default=4
        Number of attention heads for pair representation
    dropout_rate : float, default=0.25
        Dropout rate as specified in Algorithm 17
    transition_n : int, default=4
        Multiplier for transition layer hidden dimension

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import PairformerStack
    >>> batch_size, seq_len = 2, 10
    >>> n_block, c_s, c_z = 4, 384, 128
    >>> module = PairformerStack(n_block=n_block, c_s=c_s, c_z=c_z)
    >>> s_i = torch.randn(batch_size, seq_len, c_s)
    >>> z_ij = torch.randn(batch_size, seq_len, seq_len, c_z)
    >>> s_out, z_out = module(s_i, z_ij)
    >>> s_out.shape
    torch.Size([2, 10, 384])
    >>> z_out.shape
    torch.Size([2, 10, 10, 128])

    References
    ----------
    .. [1] Abramson, J., Adler, J., Dunger, J. et al. Accurate structure prediction
           of biomolecular interactions with AlphaFold 3. Nature 630, 493–500 (2024).
           Algorithm 17: Pairformer stack
    """

    def __init__(
        self,
        n_block: int = 48,
        c_s: int = 384,
        c_z: int = 128,
        n_head_single: int = 16,
        n_head_pair: int = 4,
        dropout_rate: float = 0.25,
        transition_n: int = 4,
    ):
        super().__init__()

        self.n_block = n_block
        self.c_s = c_s
        self.c_z = c_z

        # Create n_block Pairformer stack blocks (each with its own parameters)
        self.blocks = nn.ModuleList(
            [
                _PairformerStackBlock(
                    c_s=c_s,
                    c_z=c_z,
                    n_head_single=n_head_single,
                    n_head_pair=n_head_pair,
                    dropout_rate=dropout_rate,
                    transition_n=transition_n,
                )
                for _ in range(n_block)
            ]
        )

    def forward(self, s_i: Tensor, z_ij: Tensor) -> tuple[Tensor, Tensor]:
        r"""
        Forward pass of Pairformer stack.

        Parameters
        ----------
        s_i : Tensor, shape=(..., s, c_s)
            Single representation where s is sequence length
        z_ij : Tensor, shape=(..., s, s, c_z)
            Pair representation

        Returns
        -------
        s_out : Tensor, shape=(..., s, c_s)
            Updated single representation after all blocks
        z_out : Tensor, shape=(..., s, s, c_z)
            Updated pair representation after all blocks
        """
        # Validate input shapes
        if s_i.shape[-2] != z_ij.shape[-2] or s_i.shape[-2] != z_ij.shape[-3]:
            raise ValueError(
                f"Sequence length mismatch: single representation has {s_i.shape[-2]} "
                f"residues but pair representation has shape {z_ij.shape[-3:]}"
            )

        if s_i.shape[-1] != self.c_s:
            raise ValueError(
                f"Single representation has {s_i.shape[-1]} channels, "
                f"expected {self.c_s}"
            )

        if z_ij.shape[-1] != self.c_z:
            raise ValueError(
                f"Pair representation has {z_ij.shape[-1]} channels, "
                f"expected {self.c_z}"
            )

        # Algorithm 17: for all l ∈ [1, ..., N_block] do
        for block in self.blocks:
            s_i, z_ij = block(s_i, z_ij)

        # Algorithm 17 step 10: return {s_i}, {z_ij}
        return s_i, z_ij
forward
forward(s_i, z_ij)

Forward pass of Pairformer stack.

Parameters:

Name Type Description Default
s_i Tensor, shape=(..., s, c_s)

Single representation where s is sequence length

required
z_ij Tensor, shape=(..., s, s, c_z)

Pair representation

required

Returns:

Name Type Description
s_out Tensor, shape=(..., s, c_s)

Updated single representation after all blocks

z_out Tensor, shape=(..., s, s, c_z)

Updated pair representation after all blocks

Source code in src/beignet/nn/alphafold3/_pairformer_stack.py
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
def forward(self, s_i: Tensor, z_ij: Tensor) -> tuple[Tensor, Tensor]:
    r"""
    Forward pass of Pairformer stack.

    Parameters
    ----------
    s_i : Tensor, shape=(..., s, c_s)
        Single representation where s is sequence length
    z_ij : Tensor, shape=(..., s, s, c_z)
        Pair representation

    Returns
    -------
    s_out : Tensor, shape=(..., s, c_s)
        Updated single representation after all blocks
    z_out : Tensor, shape=(..., s, s, c_z)
        Updated pair representation after all blocks
    """
    # Validate input shapes
    if s_i.shape[-2] != z_ij.shape[-2] or s_i.shape[-2] != z_ij.shape[-3]:
        raise ValueError(
            f"Sequence length mismatch: single representation has {s_i.shape[-2]} "
            f"residues but pair representation has shape {z_ij.shape[-3:]}"
        )

    if s_i.shape[-1] != self.c_s:
        raise ValueError(
            f"Single representation has {s_i.shape[-1]} channels, "
            f"expected {self.c_s}"
        )

    if z_ij.shape[-1] != self.c_z:
        raise ValueError(
            f"Pair representation has {z_ij.shape[-1]} channels, "
            f"expected {self.c_z}"
        )

    # Algorithm 17: for all l ∈ [1, ..., N_block] do
    for block in self.blocks:
        s_i, z_ij = block(s_i, z_ij)

    # Algorithm 17 step 10: return {s_i}, {z_ij}
    return s_i, z_ij

beignet.nn.alphafold3.RelativePositionEncoding

Bases: Module

Relative Position Encoding for AlphaFold 3.

This module implements Algorithm 3 exactly, computing relative position encodings based on asymmetric ID, residue index, entity ID, token index, and chain ID information.

Parameters:

Name Type Description Default
r_max int

Maximum residue separation for clipping

32
s_max int

Maximum chain separation for clipping

2
c_z int

Output channel dimension

128

Examples:

>>> import torch
>>> from beignet.nn import RelativePositionEncoding
>>> batch_size, n_tokens = 2, 100
>>> module = RelativePositionEncoding()
>>> f_star = {
...     'asym_id': torch.randint(0, 5, (batch_size, n_tokens)),
...     'residue_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
...     'entity_id': torch.randint(0, 3, (batch_size, n_tokens)),
...     'token_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
...     'sym_id': torch.randint(0, 10, (batch_size, n_tokens)),
... }
>>> p_ij = module(f_star)
>>> p_ij.shape
torch.Size([2, 100, 100, 128])
References

.. [1] AlphaFold 3 Algorithm 3: Relative position encoding

Source code in src/beignet/nn/alphafold3/_relative_position_encoding.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class RelativePositionEncoding(nn.Module):
    r"""
    Relative Position Encoding for AlphaFold 3.

    This module implements Algorithm 3 exactly, computing relative position
    encodings based on asymmetric ID, residue index, entity ID, token index,
    and chain ID information.

    Parameters
    ----------
    r_max : int, default=32
        Maximum residue separation for clipping
    s_max : int, default=2
        Maximum chain separation for clipping
    c_z : int, default=128
        Output channel dimension

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import RelativePositionEncoding
    >>> batch_size, n_tokens = 2, 100
    >>> module = RelativePositionEncoding()
    >>> f_star = {
    ...     'asym_id': torch.randint(0, 5, (batch_size, n_tokens)),
    ...     'residue_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
    ...     'entity_id': torch.randint(0, 3, (batch_size, n_tokens)),
    ...     'token_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
    ...     'sym_id': torch.randint(0, 10, (batch_size, n_tokens)),
    ... }
    >>> p_ij = module(f_star)
    >>> p_ij.shape
    torch.Size([2, 100, 100, 128])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 3: Relative position encoding
    """

    def __init__(self, r_max: int = 32, s_max: int = 2, c_z: int = 128):
        super().__init__()

        self.r_max = r_max
        self.s_max = s_max
        self.c_z = c_z

        # Final linear projection
        # 2 rel distance features (residue, token) * (2*r_max+2) + 1 same_entity + 1 chain feature * (2*s_max+2)
        self.linear = nn.Linear(
            2 * (2 * r_max + 2) + 1 + (2 * s_max + 2), c_z, bias=False
        )

    def forward(self, f_star: dict) -> Tensor:
        r"""
        Forward pass implementing Algorithm 3 exactly.

        Parameters
        ----------
        f_star : dict
            Dictionary containing features with keys:
            - 'asym_id': asymmetric unit IDs (batch, n_tokens)
            - 'residue_index': residue indices (batch, n_tokens)
            - 'entity_id': entity IDs (batch, n_tokens)
            - 'token_index': token indices (batch, n_tokens)
            - 'sym_id': symmetry IDs (batch, n_tokens)

        Returns
        -------
        p_ij : Tensor, shape=(batch, n_tokens, n_tokens, c_z)
            Relative position encodings
        """
        # Extract features
        asym_id_i = f_star["asym_id"]  # (batch, n_tokens)
        residue_index_i = f_star["residue_index"]
        entity_id_i = f_star["entity_id"]
        token_index_i = f_star["token_index"]
        sym_id_i = f_star["sym_id"]

        batch_size, n_tokens = asym_id_i.shape
        device = asym_id_i.device

        # Create pairwise comparisons
        asym_id_j = asym_id_i.unsqueeze(-1)  # (batch, n_tokens, 1)
        asym_id_i = asym_id_i.unsqueeze(-2)  # (batch, 1, n_tokens)

        residue_index_j = residue_index_i.unsqueeze(-1)
        residue_index_i = residue_index_i.unsqueeze(-2)

        entity_id_j = entity_id_i.unsqueeze(-1)
        entity_id_i = entity_id_i.unsqueeze(-2)

        token_index_j = token_index_i.unsqueeze(-1)
        token_index_i = token_index_i.unsqueeze(-2)

        sym_id_j = sym_id_i.unsqueeze(-1)
        sym_id_i = sym_id_i.unsqueeze(-2)

        # Step 1: b_ij^same_chain = (f_i^asym_id == f_j^asym_id)
        b_same_chain = (asym_id_i == asym_id_j).float()

        # Step 2: b_ij^same_residue = (f_i^residue_index == f_j^residue_index)
        b_same_residue = (residue_index_i == residue_index_j).float()

        # Step 3: b_ij^same_entity = (f_i^entity_id == f_j^entity_id)
        b_same_entity = (entity_id_i == entity_id_j).float()

        # Step 4: Relative residue distance
        d_residue = torch.where(
            b_same_chain.bool(),
            torch.clamp(
                residue_index_i - residue_index_j + self.r_max, 0, 2 * self.r_max
            ),
            2 * self.r_max + 1,
        ).long()

        # Step 5: One-hot encode residue distance
        a_rel_pos = one_hot(
            d_residue.float(),
            torch.arange(2 * self.r_max + 2, device=device, dtype=torch.float32),
        )  # (batch, n_tokens, n_tokens, 2*r_max+2)

        # Step 6: Relative token distance
        d_token = torch.where(
            (b_same_chain.bool()) & (b_same_residue.bool()),
            torch.clamp(token_index_i - token_index_j + self.r_max, 0, 2 * self.r_max),
            2 * self.r_max + 1,
        ).long()

        # Step 7: One-hot encode token distance
        a_rel_token = one_hot(
            d_token.float(),
            torch.arange(2 * self.r_max + 2, device=device, dtype=torch.float32),
        )  # (batch, n_tokens, n_tokens, 2*r_max+2)

        # Step 8: Relative chain distance
        d_chain = torch.where(
            ~b_same_chain.bool(),
            torch.clamp(sym_id_i - sym_id_j + self.s_max, 0, 2 * self.s_max),
            2 * self.s_max + 1,
        ).long()

        # Step 9: One-hot encode chain distance
        a_rel_chain = one_hot(
            d_chain.float(),
            torch.arange(2 * self.s_max + 2, device=device, dtype=torch.float32),
        )  # (batch, n_tokens, n_tokens, 2*s_max+2)

        # Step 10: Concatenate all features and apply linear projection
        all_features = torch.cat(
            [
                a_rel_pos,  # (batch, n_tokens, n_tokens, 2*r_max+2)
                a_rel_token,  # (batch, n_tokens, n_tokens, 2*r_max+2)
                b_same_entity.unsqueeze(-1),  # (batch, n_tokens, n_tokens, 1)
                a_rel_chain,  # (batch, n_tokens, n_tokens, 2*s_max+2)
            ],
            dim=-1,
        )

        # Step 11: Linear projection
        p_ij = self.linear(all_features)

        return p_ij
forward
forward(f_star)

Forward pass implementing Algorithm 3 exactly.

Parameters:

Name Type Description Default
f_star dict

Dictionary containing features with keys: - 'asym_id': asymmetric unit IDs (batch, n_tokens) - 'residue_index': residue indices (batch, n_tokens) - 'entity_id': entity IDs (batch, n_tokens) - 'token_index': token indices (batch, n_tokens) - 'sym_id': symmetry IDs (batch, n_tokens)

required

Returns:

Name Type Description
p_ij Tensor, shape=(batch, n_tokens, n_tokens, c_z)

Relative position encodings

Source code in src/beignet/nn/alphafold3/_relative_position_encoding.py
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
def forward(self, f_star: dict) -> Tensor:
    r"""
    Forward pass implementing Algorithm 3 exactly.

    Parameters
    ----------
    f_star : dict
        Dictionary containing features with keys:
        - 'asym_id': asymmetric unit IDs (batch, n_tokens)
        - 'residue_index': residue indices (batch, n_tokens)
        - 'entity_id': entity IDs (batch, n_tokens)
        - 'token_index': token indices (batch, n_tokens)
        - 'sym_id': symmetry IDs (batch, n_tokens)

    Returns
    -------
    p_ij : Tensor, shape=(batch, n_tokens, n_tokens, c_z)
        Relative position encodings
    """
    # Extract features
    asym_id_i = f_star["asym_id"]  # (batch, n_tokens)
    residue_index_i = f_star["residue_index"]
    entity_id_i = f_star["entity_id"]
    token_index_i = f_star["token_index"]
    sym_id_i = f_star["sym_id"]

    batch_size, n_tokens = asym_id_i.shape
    device = asym_id_i.device

    # Create pairwise comparisons
    asym_id_j = asym_id_i.unsqueeze(-1)  # (batch, n_tokens, 1)
    asym_id_i = asym_id_i.unsqueeze(-2)  # (batch, 1, n_tokens)

    residue_index_j = residue_index_i.unsqueeze(-1)
    residue_index_i = residue_index_i.unsqueeze(-2)

    entity_id_j = entity_id_i.unsqueeze(-1)
    entity_id_i = entity_id_i.unsqueeze(-2)

    token_index_j = token_index_i.unsqueeze(-1)
    token_index_i = token_index_i.unsqueeze(-2)

    sym_id_j = sym_id_i.unsqueeze(-1)
    sym_id_i = sym_id_i.unsqueeze(-2)

    # Step 1: b_ij^same_chain = (f_i^asym_id == f_j^asym_id)
    b_same_chain = (asym_id_i == asym_id_j).float()

    # Step 2: b_ij^same_residue = (f_i^residue_index == f_j^residue_index)
    b_same_residue = (residue_index_i == residue_index_j).float()

    # Step 3: b_ij^same_entity = (f_i^entity_id == f_j^entity_id)
    b_same_entity = (entity_id_i == entity_id_j).float()

    # Step 4: Relative residue distance
    d_residue = torch.where(
        b_same_chain.bool(),
        torch.clamp(
            residue_index_i - residue_index_j + self.r_max, 0, 2 * self.r_max
        ),
        2 * self.r_max + 1,
    ).long()

    # Step 5: One-hot encode residue distance
    a_rel_pos = one_hot(
        d_residue.float(),
        torch.arange(2 * self.r_max + 2, device=device, dtype=torch.float32),
    )  # (batch, n_tokens, n_tokens, 2*r_max+2)

    # Step 6: Relative token distance
    d_token = torch.where(
        (b_same_chain.bool()) & (b_same_residue.bool()),
        torch.clamp(token_index_i - token_index_j + self.r_max, 0, 2 * self.r_max),
        2 * self.r_max + 1,
    ).long()

    # Step 7: One-hot encode token distance
    a_rel_token = one_hot(
        d_token.float(),
        torch.arange(2 * self.r_max + 2, device=device, dtype=torch.float32),
    )  # (batch, n_tokens, n_tokens, 2*r_max+2)

    # Step 8: Relative chain distance
    d_chain = torch.where(
        ~b_same_chain.bool(),
        torch.clamp(sym_id_i - sym_id_j + self.s_max, 0, 2 * self.s_max),
        2 * self.s_max + 1,
    ).long()

    # Step 9: One-hot encode chain distance
    a_rel_chain = one_hot(
        d_chain.float(),
        torch.arange(2 * self.s_max + 2, device=device, dtype=torch.float32),
    )  # (batch, n_tokens, n_tokens, 2*s_max+2)

    # Step 10: Concatenate all features and apply linear projection
    all_features = torch.cat(
        [
            a_rel_pos,  # (batch, n_tokens, n_tokens, 2*r_max+2)
            a_rel_token,  # (batch, n_tokens, n_tokens, 2*r_max+2)
            b_same_entity.unsqueeze(-1),  # (batch, n_tokens, n_tokens, 1)
            a_rel_chain,  # (batch, n_tokens, n_tokens, 2*s_max+2)
        ],
        dim=-1,
    )

    # Step 11: Linear projection
    p_ij = self.linear(all_features)

    return p_ij

beignet.nn.alphafold3.SampleDiffusion

Bases: Module

Sample Diffusion for AlphaFold 3.

This module implements Algorithm 18 exactly, performing iterative denoising sampling for structure generation using a diffusion model.

Parameters:

Name Type Description Default
gamma_0 float

Initial gamma parameter for augmentation

0.8
gamma_min float

Minimum gamma threshold

1.0
noise_scale float

Noise scale lambda parameter

1.003
step_scale float

Step scale eta parameter

1.5
s_trans float

Translation scale for augmentation

1.0

Examples:

>>> import torch
>>> from beignet.nn.alphafold3 import SampleDiffusion
>>> batch_size, n_atoms, n_tokens = 2, 1000, 32
>>> module = SampleDiffusion()
>>>
>>> # Input features
>>> f_star = {'ref_pos': torch.randn(batch_size, n_atoms, 3)}
>>> s_inputs = torch.randn(batch_size, n_atoms, 100)
>>> s_trunk = torch.randn(batch_size, n_tokens, 384)
>>> z_trunk = torch.randn(batch_size, n_tokens, n_tokens, 128)
>>> noise_schedule = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
>>>
>>> x_t = module(f_star, s_inputs, s_trunk, z_trunk, noise_schedule)
>>> x_t.shape
torch.Size([2, 1000, 3])
References

.. [1] AlphaFold 3 Algorithm 18: Sample Diffusion

Source code in src/beignet/nn/alphafold3/_sample_diffusion.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
class SampleDiffusion(nn.Module):
    r"""
    Sample Diffusion for AlphaFold 3.

    This module implements Algorithm 18 exactly, performing iterative denoising
    sampling for structure generation using a diffusion model.

    Parameters
    ----------
    gamma_0 : float, default=0.8
        Initial gamma parameter for augmentation
    gamma_min : float, default=1.0
        Minimum gamma threshold
    noise_scale : float, default=1.003
        Noise scale lambda parameter
    step_scale : float, default=1.5
        Step scale eta parameter
    s_trans : float, default=1.0
        Translation scale for augmentation

    Examples
    --------
    >>> import torch
    >>> from beignet.nn.alphafold3 import SampleDiffusion
    >>> batch_size, n_atoms, n_tokens = 2, 1000, 32
    >>> module = SampleDiffusion()
    >>>
    >>> # Input features
    >>> f_star = {'ref_pos': torch.randn(batch_size, n_atoms, 3)}
    >>> s_inputs = torch.randn(batch_size, n_atoms, 100)
    >>> s_trunk = torch.randn(batch_size, n_tokens, 384)
    >>> z_trunk = torch.randn(batch_size, n_tokens, n_tokens, 128)
    >>> noise_schedule = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
    >>>
    >>> x_t = module(f_star, s_inputs, s_trunk, z_trunk, noise_schedule)
    >>> x_t.shape
    torch.Size([2, 1000, 3])

    References
    ----------
    .. [1] AlphaFold 3 Algorithm 18: Sample Diffusion
    """

    def __init__(
        self,
        gamma_0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
        s_trans: float = 1.0,
    ):
        super().__init__()

        self.gamma_0 = gamma_0
        self.gamma_min = gamma_min
        self.noise_scale = noise_scale  # lambda
        self.step_scale = step_scale  # eta
        self.s_trans = s_trans

        # Diffusion module for denoising
        self.diffusion_module = _Diffusion()

        # Centre random augmentation
        self.centre_random_augmentation = _CentreRandomAugmentation(s_trans=s_trans)

    def forward(
        self,
        f_star: dict,
        s_inputs: Tensor,
        s_trunk: Tensor,
        z_trunk: Tensor,
        noise_schedule: list[float],
    ) -> Tensor:
        r"""
        Forward pass implementing Algorithm 18 exactly.

        Parameters
        ----------
        f_star : dict
            Reference structure features
        s_inputs : Tensor, shape=(batch_size, n_atoms, c_s_inputs)
            Input single representations
        s_trunk : Tensor, shape=(batch_size, n_tokens, c_s)
            Trunk single representations
        z_trunk : Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)
            Trunk pair representations
        noise_schedule : list
            Noise schedule [c0, c1, ..., cT]

        Returns
        -------
        x_t : Tensor, shape=(batch_size, n_atoms, 3)
            Final denoised positions
        """
        device = s_inputs.device
        batch_size = s_inputs.shape[0]
        n_atoms = s_inputs.shape[1]

        # Step 1: x̃_t ∼ c_0 · N(0̃, I_3)
        c_0 = noise_schedule[0]
        x_t = c_0 * torch.randn(batch_size, n_atoms, 3, device=device)

        # Step 2: for all c_τ ∈ [c_1, ..., c_T] do
        for tau, c_tau in enumerate(noise_schedule[1:], 1):
            # Step 3: {x̃_t} ← CentreRandomAugmentation({x̃_t})
            x_t = self.centre_random_augmentation(x_t)

            # Step 4: γ = γ_0 if c_τ > γ_min else 0
            gamma = self.gamma_0 if c_tau > self.gamma_min else 0.0

            # Step 5: t̂ = c_{τ-1}(γ + 1)
            c_tau_minus_1 = noise_schedule[tau - 1]
            t_hat = c_tau_minus_1 * (gamma + 1)

            # Step 6: ζ̃_t = λ√(t̂^2 - c^2_{τ-1}) · N(0̃, I_3)
            variance = t_hat**2 - c_tau_minus_1**2
            if variance > 0:
                zeta_t = (
                    self.noise_scale
                    * torch.sqrt(torch.tensor(variance))
                    * torch.randn_like(x_t)
                )
            else:
                zeta_t = torch.zeros_like(x_t)

            # Step 7: x̃_t^noisy = x̃_t + ζ̃_t
            x_t_noisy = x_t + zeta_t

            # Step 8: {x̃_t^denoised} = AlphaFold3Diffusion({x̃_t^noisy}, t̂, {f*}, {s_i^inputs}, {s_i^trunk}, {z_{ij}^trunk})
            # Create timestep tensor
            t_tensor = torch.full(
                (batch_size, 1), t_hat, device=device, dtype=x_t.dtype
            )

            # Get reference positions from f_star
            f_star_pos = f_star.get("ref_pos", x_t * 0)  # Use zeros if not available

            # Create dummy z_atom for the diffusion module
            z_atom = torch.zeros(
                batch_size, n_atoms, n_atoms, 16, device=device, dtype=x_t.dtype
            )

            x_t_denoised = self.diffusion_module(
                x_noisy=x_t_noisy,
                t=t_tensor,
                f_star=f_star_pos,
                s_inputs=s_inputs,
                s_trunk=s_trunk,
                z_trunk=z_trunk,
                z_atom=z_atom,
            )

            # Step 9: δ̃_t = (x̃_t - x̃_t^denoised) / t̂
            delta_t = (x_t - x_t_denoised) / t_hat

            # Step 10: dt = c_τ - t̂
            dt = c_tau - t_hat

            # Step 11: x̃_t ← x̃_t^noisy + η · dt · δ̃_t
            x_t = x_t_noisy + self.step_scale * dt * delta_t

        # Step 12: end for
        # Step 13: return {x̃_t}
        return x_t
forward
forward(f_star, s_inputs, s_trunk, z_trunk, noise_schedule)

Forward pass implementing Algorithm 18 exactly.

Parameters:

Name Type Description Default
f_star dict

Reference structure features

required
s_inputs Tensor, shape=(batch_size, n_atoms, c_s_inputs)

Input single representations

required
s_trunk Tensor, shape=(batch_size, n_tokens, c_s)

Trunk single representations

required
z_trunk Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)

Trunk pair representations

required
noise_schedule list

Noise schedule [c0, c1, ..., cT]

required

Returns:

Name Type Description
x_t Tensor, shape=(batch_size, n_atoms, 3)

Final denoised positions

Source code in src/beignet/nn/alphafold3/_sample_diffusion.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
def forward(
    self,
    f_star: dict,
    s_inputs: Tensor,
    s_trunk: Tensor,
    z_trunk: Tensor,
    noise_schedule: list[float],
) -> Tensor:
    r"""
    Forward pass implementing Algorithm 18 exactly.

    Parameters
    ----------
    f_star : dict
        Reference structure features
    s_inputs : Tensor, shape=(batch_size, n_atoms, c_s_inputs)
        Input single representations
    s_trunk : Tensor, shape=(batch_size, n_tokens, c_s)
        Trunk single representations
    z_trunk : Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)
        Trunk pair representations
    noise_schedule : list
        Noise schedule [c0, c1, ..., cT]

    Returns
    -------
    x_t : Tensor, shape=(batch_size, n_atoms, 3)
        Final denoised positions
    """
    device = s_inputs.device
    batch_size = s_inputs.shape[0]
    n_atoms = s_inputs.shape[1]

    # Step 1: x̃_t ∼ c_0 · N(0̃, I_3)
    c_0 = noise_schedule[0]
    x_t = c_0 * torch.randn(batch_size, n_atoms, 3, device=device)

    # Step 2: for all c_τ ∈ [c_1, ..., c_T] do
    for tau, c_tau in enumerate(noise_schedule[1:], 1):
        # Step 3: {x̃_t} ← CentreRandomAugmentation({x̃_t})
        x_t = self.centre_random_augmentation(x_t)

        # Step 4: γ = γ_0 if c_τ > γ_min else 0
        gamma = self.gamma_0 if c_tau > self.gamma_min else 0.0

        # Step 5: t̂ = c_{τ-1}(γ + 1)
        c_tau_minus_1 = noise_schedule[tau - 1]
        t_hat = c_tau_minus_1 * (gamma + 1)

        # Step 6: ζ̃_t = λ√(t̂^2 - c^2_{τ-1}) · N(0̃, I_3)
        variance = t_hat**2 - c_tau_minus_1**2
        if variance > 0:
            zeta_t = (
                self.noise_scale
                * torch.sqrt(torch.tensor(variance))
                * torch.randn_like(x_t)
            )
        else:
            zeta_t = torch.zeros_like(x_t)

        # Step 7: x̃_t^noisy = x̃_t + ζ̃_t
        x_t_noisy = x_t + zeta_t

        # Step 8: {x̃_t^denoised} = AlphaFold3Diffusion({x̃_t^noisy}, t̂, {f*}, {s_i^inputs}, {s_i^trunk}, {z_{ij}^trunk})
        # Create timestep tensor
        t_tensor = torch.full(
            (batch_size, 1), t_hat, device=device, dtype=x_t.dtype
        )

        # Get reference positions from f_star
        f_star_pos = f_star.get("ref_pos", x_t * 0)  # Use zeros if not available

        # Create dummy z_atom for the diffusion module
        z_atom = torch.zeros(
            batch_size, n_atoms, n_atoms, 16, device=device, dtype=x_t.dtype
        )

        x_t_denoised = self.diffusion_module(
            x_noisy=x_t_noisy,
            t=t_tensor,
            f_star=f_star_pos,
            s_inputs=s_inputs,
            s_trunk=s_trunk,
            z_trunk=z_trunk,
            z_atom=z_atom,
        )

        # Step 9: δ̃_t = (x̃_t - x̃_t^denoised) / t̂
        delta_t = (x_t - x_t_denoised) / t_hat

        # Step 10: dt = c_τ - t̂
        dt = c_tau - t_hat

        # Step 11: x̃_t ← x̃_t^noisy + η · dt · δ̃_t
        x_t = x_t_noisy + self.step_scale * dt * delta_t

    # Step 12: end for
    # Step 13: return {x̃_t}
    return x_t

beignet.nn.alphafold3.TemplateEmbedder

Bases: Module

Template Embedder for AlphaFold 3.

This module processes template structural information and adds it to the pair representation. Templates provide structural constraints from homologous structures that guide the prediction.

Parameters:

Name Type Description Default
c_z int

Pair representation dimension

128
c_template int

Template feature dimension

64
n_head int

Number of attention heads

4

Examples:

>>> import torch
>>> from beignet.nn import TemplateEmbedder
>>> batch_size, n_tokens = 2, 64
>>> module = TemplateEmbedder()
>>> f_star = {'template_features': torch.randn(batch_size, n_tokens, n_tokens, 64)}
>>> z_ij = torch.randn(batch_size, n_tokens, n_tokens, 128)
>>> output = module(f_star, z_ij)
>>> output.shape
torch.Size([2, 64, 64, 128])
Source code in src/beignet/nn/alphafold3/_template_embedder.py
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class TemplateEmbedder(nn.Module):
    r"""
    Template Embedder for AlphaFold 3.

    This module processes template structural information and adds it to the pair
    representation. Templates provide structural constraints from homologous
    structures that guide the prediction.

    Parameters
    ----------
    c_z : int, default=128
        Pair representation dimension
    c_template : int, default=64
        Template feature dimension
    n_head : int, default=4
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import TemplateEmbedder
    >>> batch_size, n_tokens = 2, 64
    >>> module = TemplateEmbedder()
    >>> f_star = {'template_features': torch.randn(batch_size, n_tokens, n_tokens, 64)}
    >>> z_ij = torch.randn(batch_size, n_tokens, n_tokens, 128)
    >>> output = module(f_star, z_ij)
    >>> output.shape
    torch.Size([2, 64, 64, 128])
    """

    def __init__(
        self,
        c_z: int = 128,
        c_template: int = 64,
        n_head: int = 4,
    ):
        super().__init__()

        self.c_z = c_z
        self.c_template = c_template
        self.n_head = n_head

        # Template processing layers
        self.template_proj = nn.Linear(c_template, c_z, bias=False)
        self.layer_norm = nn.LayerNorm(c_z)

        # Attention mechanism for template integration
        self.attention = nn.MultiheadAttention(
            embed_dim=c_z,
            num_heads=n_head,
            batch_first=True,
        )

        # Final projection
        self.output_proj = nn.Linear(c_z, c_z, bias=False)

    def forward(self, f_star: dict, z_ij: Tensor) -> Tensor:
        r"""
        Forward pass of Template Embedder.

        Parameters
        ----------
        f_star : dict
            Dictionary containing template features with key 'template_features'
        z_ij : Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)
            Current pair representations

        Returns
        -------
        z_ij : Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)
            Updated pair representations with template information
        """
        # Extract template features (if available)
        if "template_features" not in f_star:
            # No templates available, return unchanged
            return z_ij

        template_features = f_star[
            "template_features"
        ]  # (batch, n_tokens, n_tokens, c_template)
        batch_size, n_tokens, _, c_template = template_features.shape

        # Project template features to pair dimension
        template_proj = self.template_proj(
            template_features
        )  # (batch, n_tokens, n_tokens, c_z)

        # Reshape for attention: (batch, n_tokens^2, c_z)
        template_flat = template_proj.reshape(batch_size, n_tokens * n_tokens, self.c_z)
        z_flat = z_ij.reshape(batch_size, n_tokens * n_tokens, self.c_z)

        # Apply layer norm
        template_flat = self.layer_norm(template_flat)

        # Self-attention to integrate template information
        template_attended, _ = self.attention(
            template_flat, template_flat, template_flat
        )

        # Add template information to pair representations
        z_updated = z_flat + template_attended

        # Final projection and reshape back
        z_updated = self.output_proj(z_updated)
        z_updated = z_updated.reshape(batch_size, n_tokens, n_tokens, self.c_z)

        return z_updated
forward
forward(f_star, z_ij)

Forward pass of Template Embedder.

Parameters:

Name Type Description Default
f_star dict

Dictionary containing template features with key 'template_features'

required
z_ij Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)

Current pair representations

required

Returns:

Name Type Description
z_ij Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)

Updated pair representations with template information

Source code in src/beignet/nn/alphafold3/_template_embedder.py
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
def forward(self, f_star: dict, z_ij: Tensor) -> Tensor:
    r"""
    Forward pass of Template Embedder.

    Parameters
    ----------
    f_star : dict
        Dictionary containing template features with key 'template_features'
    z_ij : Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)
        Current pair representations

    Returns
    -------
    z_ij : Tensor, shape=(batch_size, n_tokens, n_tokens, c_z)
        Updated pair representations with template information
    """
    # Extract template features (if available)
    if "template_features" not in f_star:
        # No templates available, return unchanged
        return z_ij

    template_features = f_star[
        "template_features"
    ]  # (batch, n_tokens, n_tokens, c_template)
    batch_size, n_tokens, _, c_template = template_features.shape

    # Project template features to pair dimension
    template_proj = self.template_proj(
        template_features
    )  # (batch, n_tokens, n_tokens, c_z)

    # Reshape for attention: (batch, n_tokens^2, c_z)
    template_flat = template_proj.reshape(batch_size, n_tokens * n_tokens, self.c_z)
    z_flat = z_ij.reshape(batch_size, n_tokens * n_tokens, self.c_z)

    # Apply layer norm
    template_flat = self.layer_norm(template_flat)

    # Self-attention to integrate template information
    template_attended, _ = self.attention(
        template_flat, template_flat, template_flat
    )

    # Add template information to pair representations
    z_updated = z_flat + template_attended

    # Final projection and reshape back
    z_updated = self.output_proj(z_updated)
    z_updated = z_updated.reshape(batch_size, n_tokens, n_tokens, self.c_z)

    return z_updated

beignet.nn.alphafold3.Transition

Bases: Module

Transition layer from AlphaFold 3.

This implements Algorithm 11 from AlphaFold 3, which is a simple transition layer with layer normalization, two linear projections, and a SwiGLU activation function for enhanced non-linearity.

Parameters:

Name Type Description Default
c int

Input and output channel dimension

128
n int

Expansion factor for the hidden dimension

4

Examples:

>>> import torch
>>> from beignet.nn import Transition
>>> batch_size, seq_len, c = 2, 10, 128
>>> n = 4
>>> module = Transition(c=c, n=n)
>>> x = torch.randn(batch_size, seq_len, c)
>>> x_out = module(x)
>>> x_out.shape
torch.Size([2, 10, 128])
References

.. [1] AlphaFold 3 paper, Algorithm 11: Transition layer

Source code in src/beignet/nn/alphafold3/_transition.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
class Transition(nn.Module):
    r"""
    Transition layer from AlphaFold 3.

    This implements Algorithm 11 from AlphaFold 3, which is a simple
    transition layer with layer normalization, two linear projections,
    and a SwiGLU activation function for enhanced non-linearity.

    Parameters
    ----------
    c : int, default=128
        Input and output channel dimension
    n : int, default=4
        Expansion factor for the hidden dimension

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import Transition
    >>> batch_size, seq_len, c = 2, 10, 128
    >>> n = 4
    >>> module = Transition(c=c, n=n)
    >>> x = torch.randn(batch_size, seq_len, c)
    >>> x_out = module(x)
    >>> x_out.shape
    torch.Size([2, 10, 128])

    References
    ----------
    .. [1] AlphaFold 3 paper, Algorithm 11: Transition layer
    """

    def __init__(self, c: int = 128, n: int = 4):
        super().__init__()

        self.c = c
        self.n = n
        self.hidden_dim = n * c

        # Layer normalization (step 1)
        self.layer_norm = nn.LayerNorm(c)

        # First linear projection (step 2)
        self.linear_1 = nn.Linear(c, self.hidden_dim, bias=False)

        # Second linear projection (step 3)
        self.linear_2 = nn.Linear(c, self.hidden_dim, bias=False)

        # Final output projection (step 4)
        self.output_linear = nn.Linear(self.hidden_dim, c, bias=False)

    def forward(self, x: Tensor) -> Tensor:
        r"""
        Forward pass of transition layer.

        Parameters
        ----------
        x : Tensor, shape=(..., c)
            Input tensor where c is the channel dimension.

        Returns
        -------
        x : Tensor, shape=(..., c)
            Output tensor after transition layer processing.
        """
        x = self.layer_norm(x)

        x = self.output_linear(
            self.linear_1(x) * torch.sigmoid(self.linear_1(x)) * self.linear_2(x),
        )  # (..., c)

        return x
forward
forward(x)

Forward pass of transition layer.

Parameters:

Name Type Description Default
x Tensor, shape=(..., c)

Input tensor where c is the channel dimension.

required

Returns:

Name Type Description
x Tensor, shape=(..., c)

Output tensor after transition layer processing.

Source code in src/beignet/nn/alphafold3/_transition.py
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
def forward(self, x: Tensor) -> Tensor:
    r"""
    Forward pass of transition layer.

    Parameters
    ----------
    x : Tensor, shape=(..., c)
        Input tensor where c is the channel dimension.

    Returns
    -------
    x : Tensor, shape=(..., c)
        Output tensor after transition layer processing.
    """
    x = self.layer_norm(x)

    x = self.output_linear(
        self.linear_1(x) * torch.sigmoid(self.linear_1(x)) * self.linear_2(x),
    )  # (..., c)

    return x

beignet.nn.alphafold3.TriangleAttentionEndingNode

Bases: Module

Triangular gated self-attention around ending node from AlphaFold 3.

This implements Algorithm 15 from AlphaFold 3, which performs triangular gated self-attention where the attention is computed around the ending node of each edge in the triangle. The key differences from Algorithm 14 are in steps 5 and 6 where k^h_kj and v^h_kj are used instead of k^h_jk and v^h_jk.

Parameters:

Name Type Description Default
c int

Channel dimension for the pair representation

32
n_head int

Number of attention heads

4

Examples:

>>> import torch
>>> from beignet.nn import TriangleAttentionEndingNode
>>> batch_size, seq_len, c = 2, 10, 32
>>> n_head = 4
>>> module = TriangleAttentionEndingNode(c=c, n_head=n_head)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 32])
References

.. [1] AlphaFold 3 paper, Algorithm 15: Triangular gated self-attention around ending node

Source code in src/beignet/nn/alphafold3/_triangle_attention_ending_node.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
class TriangleAttentionEndingNode(Module):
    r"""
    Triangular gated self-attention around ending node from AlphaFold 3.

    This implements Algorithm 15 from AlphaFold 3, which performs triangular
    gated self-attention where the attention is computed around the ending
    node of each edge in the triangle. The key differences from Algorithm 14
    are in steps 5 and 6 where k^h_kj and v^h_kj are used instead of k^h_jk and v^h_jk.

    Parameters
    ----------
    c : int, default=32
        Channel dimension for the pair representation
    n_head : int, default=4
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import TriangleAttentionEndingNode
    >>> batch_size, seq_len, c = 2, 10, 32
    >>> n_head = 4
    >>> module = TriangleAttentionEndingNode(c=c, n_head=n_head)
    >>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
    >>> z_tilde_ij = module(z_ij)
    >>> z_tilde_ij.shape
    torch.Size([2, 10, 10, 32])

    References
    ----------
    .. [1] AlphaFold 3 paper, Algorithm 15: Triangular gated self-attention around ending node
    """

    def __init__(self, c: int = 32, n_head: int = 4):
        super().__init__()

        self.c = c
        self.n_head = n_head
        self.head_dim = c // n_head

        if c % n_head != 0:
            raise ValueError(
                f"Channel dimension {c} must be divisible by number of heads {n_head}"
            )

        # Layer normalization for input (step 1)
        self.layer_norm = LayerNorm(c)

        # Linear projections for queries, keys, values (step 2)
        self.linear_q = Linear(c, c, bias=False)
        self.linear_k = Linear(c, c, bias=False)
        self.linear_v = Linear(c, c, bias=False)

        # Bias projection (step 3)
        self.linear_b = Linear(c, n_head, bias=False)

        # Gate projection (step 4)
        self.linear_g = Linear(c, c, bias=False)

        # Output projection (step 7)
        self.output_linear = Linear(c, c, bias=False)

        # Scale factor for attention
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, z_ij: Tensor) -> Tensor:
        r"""
        Forward pass of triangular gated self-attention around ending node.

        Parameters
        ----------
        z_ij : Tensor, shape=(..., s, s, c)
            Input pair representation where s is sequence length
            and c is channel dimension.

        Returns
        -------
        z_tilde_ij : Tensor, shape=(..., s, s, c)
            Updated pair representation after triangular attention.
        """
        # Step 1: Layer normalization
        z_ij = self.layer_norm(z_ij)

        return self.output_linear(
            (
                torch.sigmoid(self.linear_g(z_ij)).view(
                    *(z_ij.shape[:-3]),
                    z_ij.shape[-2],
                    z_ij.shape[-2],
                    self.n_head,
                    self.head_dim,
                )
                * torch.einsum(
                    "...ijhk,...kjhd->...ijhd",
                    torch.softmax(
                        (
                            torch.einsum(
                                "...ijhd,...kjhd->...ijhk",
                                self.linear_q(z_ij).view(
                                    *(z_ij.shape[:-3]),
                                    z_ij.shape[-2],
                                    z_ij.shape[-2],
                                    self.n_head,
                                    self.head_dim,
                                ),
                                self.linear_k(z_ij).view(
                                    *(z_ij.shape[:-3]),
                                    z_ij.shape[-2],
                                    z_ij.shape[-2],
                                    self.n_head,
                                    self.head_dim,
                                ),
                            )
                            * self.scale
                            + torch.unsqueeze(self.linear_b(z_ij), -1)
                        ),
                        dim=-1,
                    ),
                    self.linear_v(z_ij).view(
                        *(z_ij.shape[:-3]),
                        z_ij.shape[-2],
                        z_ij.shape[-2],
                        self.n_head,
                        self.head_dim,
                    ),
                )
            ).view(
                *(z_ij.shape[:-3]),
                z_ij.shape[-2],
                z_ij.shape[-2],
                self.c,
            )
        )
forward
forward(z_ij)

Forward pass of triangular gated self-attention around ending node.

Parameters:

Name Type Description Default
z_ij Tensor, shape=(..., s, s, c)

Input pair representation where s is sequence length and c is channel dimension.

required

Returns:

Name Type Description
z_tilde_ij Tensor, shape=(..., s, s, c)

Updated pair representation after triangular attention.

Source code in src/beignet/nn/alphafold3/_triangle_attention_ending_node.py
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
def forward(self, z_ij: Tensor) -> Tensor:
    r"""
    Forward pass of triangular gated self-attention around ending node.

    Parameters
    ----------
    z_ij : Tensor, shape=(..., s, s, c)
        Input pair representation where s is sequence length
        and c is channel dimension.

    Returns
    -------
    z_tilde_ij : Tensor, shape=(..., s, s, c)
        Updated pair representation after triangular attention.
    """
    # Step 1: Layer normalization
    z_ij = self.layer_norm(z_ij)

    return self.output_linear(
        (
            torch.sigmoid(self.linear_g(z_ij)).view(
                *(z_ij.shape[:-3]),
                z_ij.shape[-2],
                z_ij.shape[-2],
                self.n_head,
                self.head_dim,
            )
            * torch.einsum(
                "...ijhk,...kjhd->...ijhd",
                torch.softmax(
                    (
                        torch.einsum(
                            "...ijhd,...kjhd->...ijhk",
                            self.linear_q(z_ij).view(
                                *(z_ij.shape[:-3]),
                                z_ij.shape[-2],
                                z_ij.shape[-2],
                                self.n_head,
                                self.head_dim,
                            ),
                            self.linear_k(z_ij).view(
                                *(z_ij.shape[:-3]),
                                z_ij.shape[-2],
                                z_ij.shape[-2],
                                self.n_head,
                                self.head_dim,
                            ),
                        )
                        * self.scale
                        + torch.unsqueeze(self.linear_b(z_ij), -1)
                    ),
                    dim=-1,
                ),
                self.linear_v(z_ij).view(
                    *(z_ij.shape[:-3]),
                    z_ij.shape[-2],
                    z_ij.shape[-2],
                    self.n_head,
                    self.head_dim,
                ),
            )
        ).view(
            *(z_ij.shape[:-3]),
            z_ij.shape[-2],
            z_ij.shape[-2],
            self.c,
        )
    )

beignet.nn.alphafold3.TriangleAttentionStartingNode

Bases: Module

Triangular gated self-attention around starting node from AlphaFold 3.

This implements Algorithm 14 from AlphaFold 3, which performs triangular gated self-attention where the attention is computed around the starting node of each edge in the triangle.

Parameters:

Name Type Description Default
c int

Channel dimension for the pair representation

32
n_head int

Number of attention heads

4

Examples:

>>> import torch
>>> from beignet.nn import TriangleAttentionStartingNode
>>> batch_size, seq_len, c = 2, 10, 32
>>> n_head = 4
>>> module = TriangleAttentionStartingNode(c=c, n_head=n_head)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 32])
References

.. [1] AlphaFold 3 paper, Algorithm 14: Triangular gated self-attention around starting node

Source code in src/beignet/nn/alphafold3/_triangle_attention_starting_node.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class TriangleAttentionStartingNode(Module):
    r"""
    Triangular gated self-attention around starting node from AlphaFold 3.

    This implements Algorithm 14 from AlphaFold 3, which performs triangular
    gated self-attention where the attention is computed around the starting
    node of each edge in the triangle.

    Parameters
    ----------
    c : int, default=32
        Channel dimension for the pair representation
    n_head : int, default=4
        Number of attention heads

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import TriangleAttentionStartingNode
    >>> batch_size, seq_len, c = 2, 10, 32
    >>> n_head = 4
    >>> module = TriangleAttentionStartingNode(c=c, n_head=n_head)
    >>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
    >>> z_tilde_ij = module(z_ij)
    >>> z_tilde_ij.shape
    torch.Size([2, 10, 10, 32])

    References
    ----------
    .. [1] AlphaFold 3 paper, Algorithm 14: Triangular gated self-attention around starting node
    """

    def __init__(self, c: int = 32, n_head: int = 4):
        super().__init__()

        self.c = c
        self.n_head = n_head
        self.head_dim = c // n_head

        if c % n_head != 0:
            raise ValueError(
                f"Channel dimension {c} must be divisible by number of heads {n_head}"
            )

        # Layer normalization for input (step 1)
        self.layer_norm = LayerNorm(c)

        # Linear projections for queries, keys, values (step 2)
        self.linear_q = Linear(c, c, bias=False)
        self.linear_k = Linear(c, c, bias=False)
        self.linear_v = Linear(c, c, bias=False)

        # Bias projection (step 3)
        self.linear_b = Linear(c, n_head, bias=False)

        # Gate projection (step 4)
        self.linear_g = Linear(c, c, bias=False)

        # Output projection (step 7)
        self.output_linear = Linear(c, c, bias=False)

        # Scale factor for attention
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, z_ij: Tensor) -> Tensor:
        r"""
        Forward pass of triangular gated self-attention around starting node.

        Parameters
        ----------
        z_ij : Tensor, shape=(..., s, s, c)
            Input pair representation where s is sequence length
            and c is channel dimension.

        Returns
        -------
        z_tilde_ij : Tensor, shape=(..., s, s, c)
            Updated pair representation after triangular attention.
        """
        z_ij = self.layer_norm(z_ij)

        return self.output_linear(
            (
                torch.sigmoid(self.linear_g(z_ij)).view(
                    *(z_ij.shape[:-3]),
                    z_ij.shape[-2],
                    z_ij.shape[-2],
                    self.n_head,
                    self.head_dim,
                )
                * torch.einsum(
                    "...ijhk,...jkhd->...ijhd",
                    torch.softmax(
                        (
                            torch.einsum(
                                "...ijhd,...jkhd->...ijhk",
                                self.linear_q(z_ij).view(
                                    *(z_ij.shape[:-3]),
                                    z_ij.shape[-2],
                                    z_ij.shape[-2],
                                    self.n_head,
                                    self.head_dim,
                                ),
                                self.linear_k(z_ij).view(
                                    *(z_ij.shape[:-3]),
                                    z_ij.shape[-2],
                                    z_ij.shape[-2],
                                    self.n_head,
                                    self.head_dim,
                                ),
                            )
                            * self.scale
                            + torch.unsqueeze(self.linear_b(z_ij), -1)
                        ),
                        dim=-1,
                    ),
                    self.linear_v(z_ij).view(
                        *(z_ij.shape[:-3]),
                        z_ij.shape[-2],
                        z_ij.shape[-2],
                        self.n_head,
                        self.head_dim,
                    ),
                )
            ).view(
                *(z_ij.shape[:-3]),
                z_ij.shape[-2],
                z_ij.shape[-2],
                self.c,
            )
        )
forward
forward(z_ij)

Forward pass of triangular gated self-attention around starting node.

Parameters:

Name Type Description Default
z_ij Tensor, shape=(..., s, s, c)

Input pair representation where s is sequence length and c is channel dimension.

required

Returns:

Name Type Description
z_tilde_ij Tensor, shape=(..., s, s, c)

Updated pair representation after triangular attention.

Source code in src/beignet/nn/alphafold3/_triangle_attention_starting_node.py
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def forward(self, z_ij: Tensor) -> Tensor:
    r"""
    Forward pass of triangular gated self-attention around starting node.

    Parameters
    ----------
    z_ij : Tensor, shape=(..., s, s, c)
        Input pair representation where s is sequence length
        and c is channel dimension.

    Returns
    -------
    z_tilde_ij : Tensor, shape=(..., s, s, c)
        Updated pair representation after triangular attention.
    """
    z_ij = self.layer_norm(z_ij)

    return self.output_linear(
        (
            torch.sigmoid(self.linear_g(z_ij)).view(
                *(z_ij.shape[:-3]),
                z_ij.shape[-2],
                z_ij.shape[-2],
                self.n_head,
                self.head_dim,
            )
            * torch.einsum(
                "...ijhk,...jkhd->...ijhd",
                torch.softmax(
                    (
                        torch.einsum(
                            "...ijhd,...jkhd->...ijhk",
                            self.linear_q(z_ij).view(
                                *(z_ij.shape[:-3]),
                                z_ij.shape[-2],
                                z_ij.shape[-2],
                                self.n_head,
                                self.head_dim,
                            ),
                            self.linear_k(z_ij).view(
                                *(z_ij.shape[:-3]),
                                z_ij.shape[-2],
                                z_ij.shape[-2],
                                self.n_head,
                                self.head_dim,
                            ),
                        )
                        * self.scale
                        + torch.unsqueeze(self.linear_b(z_ij), -1)
                    ),
                    dim=-1,
                ),
                self.linear_v(z_ij).view(
                    *(z_ij.shape[:-3]),
                    z_ij.shape[-2],
                    z_ij.shape[-2],
                    self.n_head,
                    self.head_dim,
                ),
            )
        ).view(
            *(z_ij.shape[:-3]),
            z_ij.shape[-2],
            z_ij.shape[-2],
            self.c,
        )
    )

beignet.nn.alphafold3.TriangleMultiplicationIncoming

Bases: Module

Triangular multiplicative update using "incoming" edges from AlphaFold 3.

This implements Algorithm 13 from AlphaFold 3, which performs triangular multiplicative updates on pair representations using incoming edges. The key difference from the outgoing version is in step 4 where a_ki ⊙ b_kj is computed instead of a_ik ⊙ b_jk.

Parameters:

Name Type Description Default
c int

Channel dimension for the pair representation

128

Examples:

>>> import torch
>>> from beignet.nn import TriangleMultiplicationIncoming
>>> batch_size, seq_len, c = 2, 10, 128
>>> module = TriangleMultiplicationIncoming(c=c)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 128])
References

.. [1] AlphaFold 3 paper, Algorithm 13: Triangular multiplicative update using "incoming" edges

Source code in src/beignet/nn/alphafold3/_triangle_multiplication_incoming.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class TriangleMultiplicationIncoming(Module):
    r"""
    Triangular multiplicative update using "incoming" edges from AlphaFold 3.

    This implements Algorithm 13 from AlphaFold 3, which performs triangular
    multiplicative updates on pair representations using incoming edges.
    The key difference from the outgoing version is in step 4 where
    a_ki ⊙ b_kj is computed instead of a_ik ⊙ b_jk.

    Parameters
    ----------
    c : int, default=128
        Channel dimension for the pair representation

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import TriangleMultiplicationIncoming
    >>> batch_size, seq_len, c = 2, 10, 128
    >>> module = TriangleMultiplicationIncoming(c=c)
    >>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
    >>> z_tilde_ij = module(z_ij)
    >>> z_tilde_ij.shape
    torch.Size([2, 10, 10, 128])

    References
    ----------
    .. [1] AlphaFold 3 paper, Algorithm 13: Triangular multiplicative update using "incoming" edges
    """

    def __init__(self, c: int = 128):
        super().__init__()

        self.c = c

        # Layer normalization (step 1)
        self.layer_norm = LayerNorm(c)

        # Linear projections without bias for a and b (step 2)
        self.linear_a = Linear(c, c, bias=False)
        self.linear_b = Linear(c, c, bias=False)

        # Linear projection without bias for g (step 3)
        self.linear_g = Linear(c, c, bias=False)

        # Final linear projection without bias with layer norm (step 4)
        self.final_layer_norm = LayerNorm(c)
        self.final_linear = Linear(c, c, bias=False)

    def forward(self, z_ij: Tensor) -> Tensor:
        r"""
        Forward pass of triangular multiplicative update with incoming edges.

        Parameters
        ----------
        z_ij : Tensor, shape=(..., s, s, c)
            Input pair representation where s is sequence length
            and c is channel dimension.

        Returns
        -------
        z_tilde_ij : Tensor, shape=(..., s, s, c)
            Updated pair representation after triangular multiplicative update.
        """
        z_ij = self.layer_norm(z_ij)

        return torch.sigmoid(self.linear_g(z_ij)) * self.final_linear(
            self.final_layer_norm(
                torch.einsum(
                    "...kic,...kjc->...ijc",
                    torch.sigmoid(self.linear_a(z_ij)).transpose(-3, -2),
                    torch.sigmoid(self.linear_b(z_ij)),
                )
            )
        )
forward
forward(z_ij)

Forward pass of triangular multiplicative update with incoming edges.

Parameters:

Name Type Description Default
z_ij Tensor, shape=(..., s, s, c)

Input pair representation where s is sequence length and c is channel dimension.

required

Returns:

Name Type Description
z_tilde_ij Tensor, shape=(..., s, s, c)

Updated pair representation after triangular multiplicative update.

Source code in src/beignet/nn/alphafold3/_triangle_multiplication_incoming.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
def forward(self, z_ij: Tensor) -> Tensor:
    r"""
    Forward pass of triangular multiplicative update with incoming edges.

    Parameters
    ----------
    z_ij : Tensor, shape=(..., s, s, c)
        Input pair representation where s is sequence length
        and c is channel dimension.

    Returns
    -------
    z_tilde_ij : Tensor, shape=(..., s, s, c)
        Updated pair representation after triangular multiplicative update.
    """
    z_ij = self.layer_norm(z_ij)

    return torch.sigmoid(self.linear_g(z_ij)) * self.final_linear(
        self.final_layer_norm(
            torch.einsum(
                "...kic,...kjc->...ijc",
                torch.sigmoid(self.linear_a(z_ij)).transpose(-3, -2),
                torch.sigmoid(self.linear_b(z_ij)),
            )
        )
    )

beignet.nn.alphafold3.TriangleMultiplicationOutgoing

Bases: Module

Triangular multiplicative update using "outgoing" edges from AlphaFold 3.

This implements Algorithm 12 from AlphaFold 3, which performs triangular multiplicative updates on pair representations. The algorithm establishes direct communication between edges that connect 3 nodes in a triangle, allowing the network to detect inconsistencies in spatial relationships.

Parameters:

Name Type Description Default
c int

Channel dimension for the pair representation

128

Examples:

>>> import torch
>>> from beignet.nn import TriangleMultiplicationOutgoing
>>> batch_size, seq_len, c = 2, 10, 128
>>> module = TriangleMultiplicationOutgoing(c=c)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 128])
References

.. [1] AlphaFold 3 paper, Algorithm 12: Triangular multiplicative update using "outgoing" edges

Source code in src/beignet/nn/alphafold3/_triangle_multiplication_outgoing.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
class TriangleMultiplicationOutgoing(Module):
    r"""
    Triangular multiplicative update using "outgoing" edges from AlphaFold 3.

    This implements Algorithm 12 from AlphaFold 3, which performs triangular
    multiplicative updates on pair representations. The algorithm establishes
    direct communication between edges that connect 3 nodes in a triangle,
    allowing the network to detect inconsistencies in spatial relationships.

    Parameters
    ----------
    c : int, default=128
        Channel dimension for the pair representation

    Examples
    --------
    >>> import torch
    >>> from beignet.nn import TriangleMultiplicationOutgoing
    >>> batch_size, seq_len, c = 2, 10, 128
    >>> module = TriangleMultiplicationOutgoing(c=c)
    >>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
    >>> z_tilde_ij = module(z_ij)
    >>> z_tilde_ij.shape
    torch.Size([2, 10, 10, 128])

    References
    ----------
    .. [1] AlphaFold 3 paper, Algorithm 12: Triangular multiplicative update using "outgoing" edges
    """

    def __init__(self, c: int = 128):
        super().__init__()

        self.c = c

        # Layer normalization (step 1)
        self.layer_norm = LayerNorm(c)

        # Linear projections without bias for a and b (step 2)
        self.linear_a = Linear(c, c, bias=False)
        self.linear_b = Linear(c, c, bias=False)

        # Linear projection without bias for g (step 3)
        self.linear_g = Linear(c, c, bias=False)

        # Final linear projection without bias with layer norm (step 4)
        self.final_layer_norm = LayerNorm(c)
        self.final_linear = Linear(c, c, bias=False)

    def forward(self, z_ij: Tensor) -> Tensor:
        r"""
        Forward pass of triangular multiplicative update with outgoing edges.

        Parameters
        ----------
        z_ij : Tensor, shape=(..., s, s, c)
            Input pair representation where s is sequence length
            and c is channel dimension.

        Returns
        -------
        z_tilde_ij : Tensor, shape=(..., s, s, c)
            Updated pair representation after triangular multiplicative update.
        """
        # Step 1: Layer normalization
        z_ij = self.layer_norm(z_ij)

        return torch.sigmoid(self.linear_g(z_ij)) * self.final_linear(
            self.final_layer_norm(
                torch.einsum(
                    "...ikc,...jkc->...ijc",
                    torch.sigmoid(self.linear_a(z_ij)),
                    torch.sigmoid(self.linear_b(z_ij)).transpose(-3, -2),
                )
            )
        )
forward
forward(z_ij)

Forward pass of triangular multiplicative update with outgoing edges.

Parameters:

Name Type Description Default
z_ij Tensor, shape=(..., s, s, c)

Input pair representation where s is sequence length and c is channel dimension.

required

Returns:

Name Type Description
z_tilde_ij Tensor, shape=(..., s, s, c)

Updated pair representation after triangular multiplicative update.

Source code in src/beignet/nn/alphafold3/_triangle_multiplication_outgoing.py
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def forward(self, z_ij: Tensor) -> Tensor:
    r"""
    Forward pass of triangular multiplicative update with outgoing edges.

    Parameters
    ----------
    z_ij : Tensor, shape=(..., s, s, c)
        Input pair representation where s is sequence length
        and c is channel dimension.

    Returns
    -------
    z_tilde_ij : Tensor, shape=(..., s, s, c)
        Updated pair representation after triangular multiplicative update.
    """
    # Step 1: Layer normalization
    z_ij = self.layer_norm(z_ij)

    return torch.sigmoid(self.linear_g(z_ij)) * self.final_linear(
        self.final_layer_norm(
            torch.einsum(
                "...ikc,...jkc->...ijc",
                torch.sigmoid(self.linear_a(z_ij)),
                torch.sigmoid(self.linear_b(z_ij)).transpose(-3, -2),
            )
        )
    )