Layers
beignet.nn.AlphaFold3
Bases: Module
Main Inference Loop for AlphaFold 3.
This module implements Algorithm 1 exactly, which is the main inference pipeline for AlphaFold 3. It processes input features through multiple stages including feature embedding, MSA processing, template embedding, Pairformer stacks, diffusion sampling, and confidence prediction.
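The stages above are tied together by recycling. Below is a minimal, hedged sketch of that recycling pattern only; `trunk` is a stand-in for the template/MSA/Pairformer stages, and the bare addition used to re-inject the previous cycle's pair output is a simplification of the normalized, projected re-injection used in Algorithm 1 (this is not the beignet implementation).

```python
import torch

# Hedged sketch of the recycling pattern in Algorithm 1. `trunk` stands in
# for the template embedding, MSA module, and Pairformer stack.
def run_with_recycling(trunk, s_init, z_init, n_cycle=4):
    z_prev = torch.zeros_like(z_init)
    for _ in range(n_cycle):
        z = z_init + z_prev      # seed this cycle with the previous cycle's pair output
        s, z = trunk(s_init, z)  # template -> MSA -> Pairformer stages
        z_prev = z
    return s, z

# Trivial stand-in trunk, just to show the call pattern:
s, z = run_with_recycling(
    lambda s, z: (s, z + 1.0),
    torch.zeros(1, 8, 384),
    torch.zeros(1, 8, 8, 128),
    n_cycle=2,
)
```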
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_cycle | int | Number of recycling cycles | 4 |
| c_s | int | Single representation dimension | 384 |
| c_z | int | Pair representation dimension | 128 |
| c_m | int | MSA representation dimension | 64 |
| c_template | int | Template feature dimension | 64 |
| n_blocks_pairformer | int | Number of blocks in PairformerStack | 48 |
| n_head | int | Number of attention heads | 16 |
Examples:
>>> import torch
>>> from beignet.nn import AlphaFold3
>>> batch_size, n_tokens = 2, 64
>>> module = AlphaFold3(n_cycle=2) # Smaller for example
>>> f_star = {
... 'asym_id': torch.randint(0, 5, (batch_size, n_tokens)),
... 'residue_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
... 'entity_id': torch.randint(0, 3, (batch_size, n_tokens)),
... 'token_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
... 'sym_id': torch.randint(0, 10, (batch_size, n_tokens)),
... 'token_bonds': torch.randn(batch_size, n_tokens, n_tokens, 32)
... }
>>> outputs = module(f_star)
>>> outputs['x_pred'].shape
torch.Size([2, 64, 3])
References
.. [1] AlphaFold 3 Algorithm 1: Main Inference Loop
Source code in src/beignet/nn/alphafold3/_alphafold3.py
forward
forward(f_star)
Forward pass implementing Algorithm 1 exactly.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_star | dict | Dictionary containing input features with keys: 'asym_id': asymmetric unit IDs (batch, n_tokens); 'residue_index': residue indices (batch, n_tokens); 'entity_id': entity IDs (batch, n_tokens); 'token_index': token indices (batch, n_tokens); 'sym_id': symmetry IDs (batch, n_tokens); 'token_bonds': token bond features (batch, n_tokens, n_tokens, bond_dim); optionally 'template_features', 'msa_features', etc. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| outputs | dict | Dictionary containing: 'x_pred': predicted coordinates (batch, n_tokens, 3); 'p_plddt': pLDDT confidence (batch, n_tokens); 'p_pae': PAE confidence (batch, n_tokens, n_tokens); 'p_pde': PDE confidence (batch, n_tokens, n_tokens); 'p_resolved': resolved confidence (batch, n_tokens); 'p_distogram': distance distributions (batch, n_tokens, n_tokens, n_bins) |
Source code in src/beignet/nn/alphafold3/_alphafold3.py
beignet.nn.alphafold3.AtomAttentionDecoder
Bases: Module
Atom Attention Decoder for AlphaFold 3.
This module broadcasts per-token activations to per-atom activations, applies a cross-attention transformer, and maps the result to per-atom position updates. It implements Algorithm 6 exactly.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c_token | int | Channel dimension for token representations | 768 |
| c_atom | int | Channel dimension for atom representations | 128 |
| n_block | int | Number of transformer blocks | 3 |
| n_head | int | Number of attention heads | 4 |
Examples:
>>> import torch
>>> from beignet.nn.alphafold3 import AtomAttentionDecoder
>>> batch_size, n_tokens, n_atoms = 2, 32, 1000
>>> module = AtomAttentionDecoder()
>>>
>>> a = torch.randn(batch_size, n_tokens, 768) # Token representations
>>> q_skip = torch.randn(batch_size, n_atoms, 768) # Query skip
>>> c_skip = torch.randn(batch_size, n_atoms, 128) # Context skip
>>> p_skip = torch.randn(batch_size, n_atoms, n_atoms, 16) # Pair skip
>>>
>>> r_update = module(a, q_skip, c_skip, p_skip)
>>> r_update.shape
torch.Size([2, 1000, 3])
References
.. [1] AlphaFold 3 Algorithm 6: Atom attention decoder
Source code in src/beignet/nn/alphafold3/_atom_attention_decoder.py
forward
forward(a, q_skip, c_skip, p_skip)
Forward pass of Atom Attention Decoder.
Implements Algorithm 6 exactly:
1. q_l = LinearNoBias(a_tok_idx(l)) + q_l^skip
2. {q_l} = AtomTransformer({q_l}, {c_l^skip}, {p_lm^skip}, N_block=3, N_head=4)
3. r_l^update = LinearNoBias(LayerNorm(q_l))
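As a rough illustration of steps 1 and 3 in plain PyTorch (the tensor names and the `tok_idx` atom-to-token mapping are hypothetical, and the AtomTransformer of step 2 is only indicated by a comment; this is not the beignet source):

```python
import torch
from torch import nn

batch, n_tokens, n_atoms, c_token = 2, 32, 1000, 768

a = torch.randn(batch, n_tokens, c_token)               # per-token activations
q_skip = torch.randn(batch, n_atoms, c_token)            # query skip connection
tok_idx = torch.randint(0, n_tokens, (batch, n_atoms))   # atom -> token mapping (hypothetical)

# Step 1: broadcast per-token activations to atoms and add the query skip.
broadcast = nn.Linear(c_token, c_token, bias=False)
q = broadcast(torch.gather(a, 1, tok_idx.unsqueeze(-1).expand(-1, -1, c_token))) + q_skip

# Step 2 would run AtomTransformer({q}, {c_skip}, {p_skip}) with N_block=3, N_head=4.

# Step 3: map normalized queries to per-atom position updates.
to_update = nn.Sequential(nn.LayerNorm(c_token), nn.Linear(c_token, 3, bias=False))
r_update = to_update(q)
print(r_update.shape)  # torch.Size([2, 1000, 3])
```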
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a | Tensor, shape=(batch_size, n_tokens, c_token) | Token-level representations | required |
| q_skip | Tensor, shape=(batch_size, n_atoms, c_token) | Query skip connection | required |
| c_skip | Tensor, shape=(batch_size, n_atoms, c_atom) | Context skip connection | required |
| p_skip | Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair) | Pair skip connection | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| r_update | Tensor, shape=(batch_size, n_atoms, 3) | Position updates for atoms |
Source code in src/beignet/nn/alphafold3/_atom_attention_decoder.py
beignet.nn.alphafold3.AtomAttentionEncoder
Bases: Module
Atom Attention Encoder for AlphaFold 3.
This module implements Algorithm 5 exactly: it creates the atom single conditioning, embeds offsets and distances, runs a cross-attention transformer, and aggregates per-atom representations into per-token representations.
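The final per-atom to per-token aggregation can be sketched as a mean over the atoms assigned to each token; the `tok_idx` mapping and tensor names below are hypothetical, not the beignet implementation:

```python
import torch

batch, n_atoms, n_tokens, c_token = 2, 1000, 32, 384
q = torch.randn(batch, n_atoms, c_token)                 # per-atom activations after the transformer
tok_idx = torch.randint(0, n_tokens, (batch, n_atoms))   # atom -> token assignment (hypothetical)

# Mean-pool atoms into their tokens with scatter_add_.
a = torch.zeros(batch, n_tokens, c_token)
counts = torch.zeros(batch, n_tokens, 1)
a.scatter_add_(1, tok_idx.unsqueeze(-1).expand(-1, -1, c_token), q)
counts.scatter_add_(1, tok_idx.unsqueeze(-1), torch.ones(batch, n_atoms, 1))
a = a / counts.clamp(min=1.0)
print(a.shape)  # torch.Size([2, 32, 384])
```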
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c_atom | int | Channel dimension for atom representations | 128 |
| c_atompair | int | Channel dimension for atom pair representations | 16 |
| c_token | int | Channel dimension for token representations | 384 |
| n_block | int | Number of transformer blocks | 3 |
| n_head | int | Number of attention heads | 4 |
Examples:
>>> import torch
>>> from beignet.nn import AtomAttentionEncoder
>>> batch_size, n_atoms = 2, 1000
>>> module = AtomAttentionEncoder()
>>>
>>> # Feature dictionary with all required atom features
>>> f_star = {
... 'ref_pos': torch.randn(batch_size, n_atoms, 3),
... 'ref_mask': torch.ones(batch_size, n_atoms),
... 'ref_element': torch.randint(0, 118, (batch_size, n_atoms)),
... 'ref_atom_name_chars': torch.randint(0, 26, (batch_size, n_atoms, 4)),
... 'ref_charge': torch.randn(batch_size, n_atoms),
... 'restype': torch.randint(0, 21, (batch_size, n_atoms)),
... 'profile': torch.randn(batch_size, n_atoms, 20),
... 'deletion_mean': torch.randn(batch_size, n_atoms),
... 'ref_space_uid': torch.randint(0, 1000, (batch_size, n_atoms)),
... }
>>> r_t = torch.randn(batch_size, n_atoms, 3)
>>> s_trunk = torch.randn(batch_size, 32, 384)
>>> z_ij = torch.randn(batch_size, n_atoms, n_atoms, 16)
>>>
>>> a, q_skip, c_skip, p_skip = module(f_star, r_t, s_trunk, z_ij)
>>> a.shape
torch.Size([2, 32, 384])
References
.. [1] AlphaFold 3 Algorithm 5: Atom attention encoder
Source code in src/beignet/nn/alphafold3/_atom_attention_encoder.py
forward
forward(f_star, r_t, s_trunk, z_ij)
Forward pass of Atom Attention Encoder implementing Algorithm 5.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_star | dict | Dictionary containing atom features with keys: 'ref_pos': reference positions (batch, n_atoms, 3); 'ref_mask': mask (batch, n_atoms); 'ref_element': element types (batch, n_atoms); 'ref_atom_name_chars': atom name characters (batch, n_atoms, 4); 'ref_charge': charges (batch, n_atoms); 'restype': residue types (batch, n_atoms); 'profile': sequence profile (batch, n_atoms, 20); 'deletion_mean': deletion statistics (batch, n_atoms); 'ref_space_uid': space UIDs (batch, n_atoms) | required |
| r_t | Tensor, shape=(batch_size, n_atoms, 3) | Noisy atomic positions at time t | required |
| s_trunk | Tensor, shape=(batch_size, n_tokens, c_token) | Trunk single representations (optional, can be None) | required |
| z_ij | Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair) | Atom pair representations | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| a | Tensor, shape=(batch_size, n_tokens, c_token) | Token-level representations |
| q_skip | Tensor, shape=(batch_size, n_atoms, c_atom) | Skip connection for queries |
| c_skip | Tensor, shape=(batch_size, n_atoms, c_atom) | Skip connection for atom features |
| p_skip | Tensor, shape=(batch_size, n_atoms, n_atoms, c_atompair) | Skip connection for pair features |
Source code in src/beignet/nn/alphafold3/_atom_attention_encoder.py
beignet.nn.alphafold3.AtomTransformer
Bases: Module
Atom Transformer for AlphaFold 3.
This module implements sequence-local atom attention using rectangular blocks along the diagonal. It applies the DiffusionTransformer with sequence-local attention masking based on query and key positions.
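The sequence-local masking can be sketched as follows: a query atom l and key atom m may attend to each other only if they fall in the same rectangular block of n_queries x n_keys around one of the subset centres. The construction below is an approximation of that pattern (the extension of the centres past the last atom is an assumption made so every atom is covered), not the beignet code:

```python
import torch

n_atoms, n_queries, n_keys = 1000, 32, 128

# Subset centres 15.5, 47.5, 79.5, ..., extended past n_atoms so trailing atoms are covered.
centres = torch.arange(15.5, n_atoms + n_queries, n_queries)

l = torch.arange(n_atoms).unsqueeze(1).float()   # query index
m = torch.arange(n_atoms).unsqueeze(0).float()   # key index

# Allowed if |l - c| < n_queries / 2 and |m - c| < n_keys / 2 for some centre c.
allowed = torch.zeros(n_atoms, n_atoms, dtype=torch.bool)
for c in centres:
    allowed |= (torch.abs(l - c) < n_queries / 2) & (torch.abs(m - c) < n_keys / 2)

# Additive attention bias: 0 inside a block, a large negative value outside.
beta = torch.full((n_atoms, n_atoms), -1e10)
beta[allowed] = 0.0
```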
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_block | int | Number of transformer blocks | 3 |
| n_head | int | Number of attention heads | 4 |
| n_queries | int | Number of queries per block | 32 |
| n_keys | int | Number of keys per block | 128 |
| subset_centres | list | Centers for subset selection | [15.5, 47.5, 79.5, ...] |
| c_q | int | Query dimension (inferred from input if None) | None |
| c_kv | int | Key-value dimension (inferred from input if None) | None |
| c_pair | int | Pair dimension (inferred from input if None) | None |
Examples:
>>> import torch
>>> from beignet.nn import AtomTransformer
>>> batch_size, n_atoms = 2, 1000
>>> module = AtomTransformer()
>>> q = torch.randn(batch_size, n_atoms, 128)
>>> c = torch.randn(batch_size, n_atoms, 64)
>>> p = torch.randn(batch_size, n_atoms, n_atoms, 16)
>>> output = module(q, c, p)
>>> output.shape
torch.Size([2, 1000, 128])
References
.. [1] AlphaFold 3 Algorithm 7: Atom Transformer
Source code in src/beignet/nn/alphafold3/_atom_transformer.py
forward
forward(q, c, p)
Forward pass of Atom Transformer.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor, shape=(batch_size, n_atoms, c_q) | Query representations | required |
| c | Tensor, shape=(batch_size, n_atoms, c_kv) | Context (single) representations | required |
| p | Tensor, shape=(batch_size, n_atoms, n_atoms, c_pair) | Pair representations | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| q | Tensor, shape=(batch_size, n_atoms, c_q) | Updated query representations |
Source code in src/beignet/nn/alphafold3/_atom_transformer.py
beignet.nn.alphafold3.AttentionPairBias
Bases: Module
Attention with pair bias and mask from AlphaFold 3 Algorithm 24.
This implements the AttentionPairBias operation with conditioning-signal support for diffusion models. It uses AdaLN when a conditioning signal is provided and standard LayerNorm otherwise, and includes gating and an output projection.
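The core attention-with-pair-bias computation (leaving out the AdaLN conditioning) can be sketched as below; this is a simplified approximation of Algorithm 24, not the beignet implementation:

```python
import torch
from torch import nn

batch, seq_len, c_a, c_z, n_head = 2, 10, 256, 128, 16
c_head = c_a // n_head

a = torch.randn(batch, seq_len, c_a)
z = torch.randn(batch, seq_len, seq_len, c_z)
beta = torch.randn(batch, seq_len, seq_len, n_head)

to_qkv = nn.Linear(c_a, 3 * c_a, bias=False)
pair_bias = nn.Sequential(nn.LayerNorm(c_z), nn.Linear(c_z, n_head, bias=False))
gate = nn.Linear(c_a, c_a, bias=False)
out_proj = nn.Linear(c_a, c_a, bias=False)

q, k, v = to_qkv(a).chunk(3, dim=-1)
q = q.view(batch, seq_len, n_head, c_head).transpose(1, 2)   # (B, H, S, c_head)
k = k.view(batch, seq_len, n_head, c_head).transpose(1, 2)
v = v.view(batch, seq_len, n_head, c_head).transpose(1, 2)

# The pair representation and beta enter as additive biases on the attention logits.
bias = (pair_bias(z) + beta).permute(0, 3, 1, 2)              # (B, H, S, S)
logits = q @ k.transpose(-1, -2) / c_head**0.5 + bias
weights = logits.softmax(dim=-1)
o = (weights @ v).transpose(1, 2).reshape(batch, seq_len, c_a)

a_out = out_proj(torch.sigmoid(gate(a)) * o)                  # gated output projection
```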
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c_a | int | Channel dimension for input representation 'a' | required |
| c_s | int | Channel dimension for conditioning signal 's' (can be None if no conditioning) | required |
| c_z | int | Channel dimension for pair representation 'z' | required |
| n_head | int | Number of attention heads | required |
Examples:
>>> import torch
>>> from beignet.nn import AttentionPairBias
>>> batch_size, seq_len = 2, 10
>>> c_a, c_s, c_z, n_head = 256, 384, 128, 16
>>> module = AttentionPairBias(c_a=c_a, c_s=c_s, c_z=c_z, n_head=n_head)
>>> a = torch.randn(batch_size, seq_len, c_a)
>>> s = torch.randn(batch_size, seq_len, c_s)
>>> z = torch.randn(batch_size, seq_len, seq_len, c_z)
>>> beta = torch.randn(batch_size, seq_len, seq_len, n_head)
>>> a_out = module(a, s, z, beta)
>>> a_out.shape
torch.Size([2, 10, 256])
References
.. [1] AlphaFold 3 Algorithm 24: AttentionPairBias with pair bias and mask
Source code in src/beignet/nn/alphafold3/_attention_pair_bias.py
forward
forward(a, s=None, z=None, beta=None)
Forward pass of attention with pair bias.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a | Tensor, shape=(..., seq_len, c_a) | Input representation | required |
| s | Tensor, shape=(..., seq_len, c_s) | Conditioning signal (if None, uses LayerNorm instead of AdaLN) | None |
| z | Tensor, shape=(..., seq_len, seq_len, c_z) | Pair representation for computing attention bias | None |
| beta | Tensor, shape=(..., seq_len, seq_len, n_head) | Additional bias terms | None |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| a_out | Tensor, shape=(..., seq_len, c_a) | Updated representation after attention with pair bias |
Source code in src/beignet/nn/alphafold3/_attention_pair_bias.py
beignet.nn.alphafold3.DiffusionTransformer
Bases: Module
Diffusion Transformer from AlphaFold 3 Algorithm 23.
This implements a transformer block for diffusion models that alternates between AttentionPairBias and ConditionedTransitionBlock operations. The module processes single representations {ai} conditioned on {si}, pair representations {zij}, and bias terms {βij}.
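The alternation can be sketched using the documented AttentionPairBias module and a stand-in for ConditionedTransitionBlock. The residual wiring and the toy transition below are illustrative assumptions based on the algorithm description, not the beignet source:

```python
import torch
from torch import nn
from beignet.nn import AttentionPairBias

c_a, c_s, c_z, n_head, n_block = 256, 384, 128, 16, 2

# Stand-in for ConditionedTransitionBlock (hypothetical; the real block uses
# AdaLN-style conditioning rather than the additive conditioning used here).
class ToyConditionedTransition(nn.Module):
    def __init__(self, c_a, c_s, n=2):
        super().__init__()
        self.norm = nn.LayerNorm(c_a)
        self.cond = nn.Linear(c_s, c_a)
        self.ff = nn.Sequential(nn.Linear(c_a, n * c_a), nn.SiLU(), nn.Linear(n * c_a, c_a))

    def forward(self, a, s):
        return self.ff(self.norm(a) + self.cond(s))

attn = nn.ModuleList([AttentionPairBias(c_a=c_a, c_s=c_s, c_z=c_z, n_head=n_head) for _ in range(n_block)])
ffn = nn.ModuleList([ToyConditionedTransition(c_a, c_s) for _ in range(n_block)])

a = torch.randn(2, 32, c_a)
s = torch.randn(2, 32, c_s)
z = torch.randn(2, 32, 32, c_z)
beta = torch.randn(2, 32, 32, n_head)

for block_attn, block_ffn in zip(attn, ffn):
    a = a + block_attn(a, s, z, beta)  # attention with pair bias
    a = a + block_ffn(a, s)            # conditioned transition
```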
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c_a | int | Channel dimension for single representation 'a' | required |
| c_s | int | Channel dimension for conditioning signal 's' | required |
| c_z | int | Channel dimension for pair representation 'z' | required |
| n_head | int | Number of attention heads | required |
| n_block | int | Number of transformer blocks | required |
| n | int | Expansion factor for ConditionedTransitionBlock | 2 |
Examples:
>>> import torch
>>> from beignet.nn import DiffusionTransformer
>>> batch_size, seq_len, c_a, c_s, c_z = 2, 32, 256, 384, 128
>>> n_head, n_block = 16, 4
>>> module = DiffusionTransformer(
... c_a=c_a, c_s=c_s, c_z=c_z,
... n_head=n_head, n_block=n_block
... )
>>> a = torch.randn(batch_size, seq_len, c_a)
>>> s = torch.randn(batch_size, seq_len, c_s)
>>> z = torch.randn(batch_size, seq_len, seq_len, c_z)
>>> beta = torch.randn(batch_size, seq_len, seq_len, n_head)
>>> a_out = module(a, s, z, beta)
>>> a_out.shape
torch.Size([2, 32, 256])
References
.. [1] AlphaFold 3 Algorithm 23: Diffusion Transformer
Source code in src/beignet/nn/alphafold3/_diffusion_transformer.py
forward
forward(a, s, z, beta)
Forward pass of Diffusion Transformer.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a | Tensor, shape=(batch_size, seq_len, c_a) | Single representations | required |
| s | Tensor, shape=(batch_size, seq_len, c_s) | Conditioning signal | required |
| z | Tensor, shape=(batch_size, seq_len, seq_len, c_z) | Pair representations | required |
| beta | Tensor, shape=(batch_size, seq_len, seq_len, n_head) | Bias terms for attention | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| a_out | Tensor, shape=(batch_size, seq_len, c_a) | Updated single representations after diffusion transformer |
Source code in src/beignet/nn/alphafold3/_diffusion_transformer.py
beignet.nn.alphafold3.MSA
Bases: Module
Multiple Sequence Alignment Module from AlphaFold 3.
This implements Algorithm 8 from AlphaFold 3, which is a complete MSA processing module that combines multiple sub-modules in a structured way:
- MSA representation initialization and random sampling
- Communication block with OuterProductMean
- MSA stack with MSAPairWeightedAveraging and Transition
- Pair stack with triangle updates and attention
The module processes MSA features, single representations, and pair representations through multiple blocks to capture complex evolutionary and structural patterns.
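The OuterProductMean communication step (MSA track into the pair track) can be sketched as below; the tensor layout and projection sizes are illustrative, not the beignet defaults:

```python
import torch
from torch import nn

batch, n_seq, seq_len, c_m, c_hidden, c_z = 2, 16, 32, 64, 32, 128

m = torch.randn(batch, n_seq, seq_len, c_m)   # MSA representation
proj_a = nn.Linear(c_m, c_hidden, bias=False)
proj_b = nn.Linear(c_m, c_hidden, bias=False)
to_pair = nn.Linear(c_hidden * c_hidden, c_z)

a = proj_a(m)                                 # (B, n_seq, S, c_hidden)
b = proj_b(m)

# Outer product over the hidden channels, averaged over the MSA sequences.
outer = torch.einsum('bsic,bsjd->bijcd', a, b) / n_seq
z_update = to_pair(outer.reshape(batch, seq_len, seq_len, -1))
print(z_update.shape)  # torch.Size([2, 32, 32, 128])
```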
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_block | int | Number of processing blocks | 4 |
| c_m | int | Channel dimension for MSA representation | 64 |
| c_z | int | Channel dimension for pair representation | 128 |
| c_s | int | Channel dimension for single representation | 256 |
| n_head_msa | int | Number of attention heads for MSA operations | 8 |
| n_head_pair | int | Number of attention heads for pair operations | 4 |
| dropout_rate | float | Dropout rate for MSA operations | 0.15 |
Examples:
>>> import torch
>>> from beignet.nn import MSA
>>> batch_size, seq_len, n_seq = 2, 32, 16
>>> c_m, c_z, c_s = 64, 128, 256
>>>
>>> module = MSA(n_block=2, c_m=c_m, c_z=c_z, c_s=c_s)
>>>
>>> # Input features
>>> f_msa = torch.randn(batch_size, seq_len, n_seq, 23) # MSA features
>>> f_has_deletion = torch.randn(batch_size, seq_len, n_seq, 1)
>>> f_deletion_value = torch.randn(batch_size, seq_len, n_seq, 1)
>>> s_inputs = torch.randn(batch_size, seq_len, c_s) # Single inputs
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c_z) # Pair representation
>>>
>>> # Forward pass
>>> updated_z_ij = module(f_msa, f_has_deletion, f_deletion_value, s_inputs, z_ij)
>>> updated_z_ij.shape
torch.Size([2, 32, 32, 128])
References
.. [1] AlphaFold 3 paper, Algorithm 8: MSA Module
Source code in src/beignet/nn/alphafold3/_msa.py
forward
forward(f_msa, f_has_deletion, f_deletion_value, s_inputs, z_ij)
Forward pass of the MSA Module.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_msa | Tensor, shape=(..., s, n_seq, 23) | MSA features (amino acid profiles) | required |
| f_has_deletion | Tensor, shape=(..., s, n_seq, 1) | Has deletion features | required |
| f_deletion_value | Tensor, shape=(..., s, n_seq, 1) | Deletion value features | required |
| s_inputs | Tensor, shape=(..., s, c_s) | Single representation inputs | required |
| z_ij | Tensor, shape=(..., s, s, c_z) | Pair representation | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| z_ij | Tensor, shape=(..., s, s, c_z) | Updated pair representation |
Source code in src/beignet/nn/alphafold3/_msa.py
beignet.nn.alphafold3.PairformerStack
Bases: Module
Pairformer stack from AlphaFold 3 Algorithm 17.
This is the exact implementation of the Pairformer stack as specified in Algorithm 17, which processes single and pair representations through N_block iterations of triangle operations and attention mechanisms.
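The pair-track portion of one block composes the triangle and transition modules documented later on this page roughly in the order below. This is a hedged sketch of the ordering in Algorithm 17 using beignet's documented constructors and call signatures; the single-track attention, dropout, and the exact placement of residual connections are omitted or left open:

```python
import torch
from beignet.nn import (
    Transition,
    TriangleAttentionEndingNode,
    TriangleAttentionStartingNode,
    TriangleMultiplicationIncoming,
    TriangleMultiplicationOutgoing,
)

c_z = 128
z = torch.randn(2, 10, 10, c_z)

# One Pairformer block applies these pair-track operations in sequence.
# Algorithm 17 treats each as a residual update; whether the residual is
# applied inside each beignet module or in the stack is not asserted here.
z = TriangleMultiplicationOutgoing(c=c_z)(z)
z = TriangleMultiplicationIncoming(c=c_z)(z)
z = TriangleAttentionStartingNode(c=c_z, n_head=4)(z)
z = TriangleAttentionEndingNode(c=c_z, n_head=4)(z)
z = Transition(c=c_z, n=4)(z)
```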
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| n_block | int | Number of Pairformer blocks (N_block in Algorithm 17) | 48 |
| c_s | int | Channel dimension for single representation | 384 |
| c_z | int | Channel dimension for pair representation | 128 |
| n_head_single | int | Number of attention heads for single representation | 16 |
| n_head_pair | int | Number of attention heads for pair representation | 4 |
| dropout_rate | float | Dropout rate as specified in Algorithm 17 | 0.25 |
| transition_n | int | Multiplier for transition layer hidden dimension | 4 |
Examples:
>>> import torch
>>> from beignet.nn import PairformerStack
>>> batch_size, seq_len = 2, 10
>>> n_block, c_s, c_z = 4, 384, 128
>>> module = PairformerStack(n_block=n_block, c_s=c_s, c_z=c_z)
>>> s_i = torch.randn(batch_size, seq_len, c_s)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c_z)
>>> s_out, z_out = module(s_i, z_ij)
>>> s_out.shape
torch.Size([2, 10, 384])
>>> z_out.shape
torch.Size([2, 10, 10, 128])
References
.. [1] Abramson, J., Adler, J., Dunger, J. et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3. Nature 630, 493–500 (2024). Algorithm 17: Pairformer stack
Source code in src/beignet/nn/alphafold3/_pairformer_stack.py
forward
forward(s_i, z_ij)
Forward pass of Pairformer stack.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| s_i | Tensor, shape=(..., s, c_s) | Single representation where s is sequence length | required |
| z_ij | Tensor, shape=(..., s, s, c_z) | Pair representation | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| s_out | Tensor, shape=(..., s, c_s) | Updated single representation after all blocks |
| z_out | Tensor, shape=(..., s, s, c_z) | Updated pair representation after all blocks |
Source code in src/beignet/nn/alphafold3/_pairformer_stack.py
beignet.nn.alphafold3.RelativePositionEncoding
Bases: Module
Relative Position Encoding for AlphaFold 3.
This module implements Algorithm 3 exactly, computing relative position encodings based on asymmetric ID, residue index, entity ID, token index, and chain ID information.
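The residue-offset part of the encoding can be sketched as a clipped, one-hot encoded relative position. The chain, entity, and token-index terms and the final linear projection to c_z are omitted here; this is an approximation of Algorithm 3, not the beignet code:

```python
import torch
import torch.nn.functional as F

r_max = 32
batch, n_tokens = 2, 100
residue_index = torch.arange(n_tokens).unsqueeze(0).expand(batch, -1)
asym_id = torch.randint(0, 5, (batch, n_tokens))

# Relative residue offset, clipped to [-r_max, r_max] and shifted to [0, 2 * r_max].
d = residue_index.unsqueeze(2) - residue_index.unsqueeze(1)
d = torch.clamp(d + r_max, 0, 2 * r_max)

# Pairs from different chains get the extra "out of range" bin 2 * r_max + 1.
same_chain = asym_id.unsqueeze(2) == asym_id.unsqueeze(1)
d = torch.where(same_chain, d, torch.full_like(d, 2 * r_max + 1))

rel_pos = F.one_hot(d, num_classes=2 * r_max + 2).float()  # (B, N, N, 66)
```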
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| r_max | int | Maximum residue separation for clipping | 32 |
| s_max | int | Maximum chain separation for clipping | 2 |
| c_z | int | Output channel dimension | 128 |
Examples:
>>> import torch
>>> from beignet.nn import RelativePositionEncoding
>>> batch_size, n_tokens = 2, 100
>>> module = RelativePositionEncoding()
>>> f_star = {
... 'asym_id': torch.randint(0, 5, (batch_size, n_tokens)),
... 'residue_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
... 'entity_id': torch.randint(0, 3, (batch_size, n_tokens)),
... 'token_index': torch.arange(n_tokens).unsqueeze(0).expand(batch_size, -1),
... 'sym_id': torch.randint(0, 10, (batch_size, n_tokens)),
... }
>>> p_ij = module(f_star)
>>> p_ij.shape
torch.Size([2, 100, 100, 128])
References
.. [1] AlphaFold 3 Algorithm 3: Relative position encoding
Source code in src/beignet/nn/alphafold3/_relative_position_encoding.py
forward
forward(f_star)
Forward pass implementing Algorithm 3 exactly.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_star | dict | Dictionary containing features with keys: 'asym_id': asymmetric unit IDs (batch, n_tokens); 'residue_index': residue indices (batch, n_tokens); 'entity_id': entity IDs (batch, n_tokens); 'token_index': token indices (batch, n_tokens); 'sym_id': symmetry IDs (batch, n_tokens) | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| p_ij | Tensor, shape=(batch, n_tokens, n_tokens, c_z) | Relative position encodings |
Source code in src/beignet/nn/alphafold3/_relative_position_encoding.py
beignet.nn.alphafold3.SampleDiffusion
Bases: Module
Sample Diffusion for AlphaFold 3.
This module implements Algorithm 18 exactly, performing iterative denoising sampling for structure generation using a diffusion model.
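The shape of the sampler is an iterative loop over the noise schedule: start from pure noise at the first noise level, then at each step re-noise slightly (controlled by gamma_0, gamma_min, and noise_scale) and move toward the denoiser's prediction (scaled by step_scale). The skeleton below is a generic, hedged sketch of that loop with a stand-in denoiser; the exact update rule and centring/augmentation steps of Algorithm 18 are not reproduced from the beignet source:

```python
import torch

def sample(denoiser, shape, noise_schedule, gamma_0=0.8, gamma_min=1.0,
           noise_scale=1.003, step_scale=1.5):
    # `denoiser(x, t)` stands in for the diffusion module.
    x = noise_schedule[0] * torch.randn(shape)
    for c_prev, c_cur in zip(noise_schedule[:-1], noise_schedule[1:]):
        gamma = gamma_0 if c_cur > gamma_min else 0.0
        t_hat = c_prev * (gamma + 1.0)
        # Re-noise slightly, then step toward the denoised prediction.
        x = x + noise_scale * (t_hat**2 - c_prev**2) ** 0.5 * torch.randn(shape)
        x_denoised = denoiser(x, t_hat)
        delta = (x - x_denoised) / t_hat
        x = x + step_scale * (c_cur - t_hat) * delta
    return x

# Stand-in denoiser that always predicts the origin, just to show the call pattern:
x = sample(lambda x, t: torch.zeros_like(x), (2, 1000, 3), [10.0, 5.0, 2.0, 1.0, 0.5])
```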
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| gamma_0 | float | Initial gamma parameter for augmentation | 0.8 |
| gamma_min | float | Minimum gamma threshold | 1.0 |
| noise_scale | float | Noise scale lambda parameter | 1.003 |
| step_scale | float | Step scale eta parameter | 1.5 |
| s_trans | float | Translation scale for augmentation | 1.0 |
Examples:
>>> import torch
>>> from beignet.nn.alphafold3 import SampleDiffusion
>>> batch_size, n_atoms, n_tokens = 2, 1000, 32
>>> module = SampleDiffusion()
>>>
>>> # Input features
>>> f_star = {'ref_pos': torch.randn(batch_size, n_atoms, 3)}
>>> s_inputs = torch.randn(batch_size, n_atoms, 100)
>>> s_trunk = torch.randn(batch_size, n_tokens, 384)
>>> z_trunk = torch.randn(batch_size, n_tokens, n_tokens, 128)
>>> noise_schedule = [0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0]
>>>
>>> x_t = module(f_star, s_inputs, s_trunk, z_trunk, noise_schedule)
>>> x_t.shape
torch.Size([2, 1000, 3])
References
.. [1] AlphaFold 3 Algorithm 18: Sample Diffusion
Source code in src/beignet/nn/alphafold3/_sample_diffusion.py
forward
forward(f_star, s_inputs, s_trunk, z_trunk, noise_schedule)
Forward pass implementing Algorithm 18 exactly.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_star | dict | Reference structure features | required |
| s_inputs | Tensor, shape=(batch_size, n_atoms, c_s_inputs) | Input single representations | required |
| s_trunk | Tensor, shape=(batch_size, n_tokens, c_s) | Trunk single representations | required |
| z_trunk | Tensor, shape=(batch_size, n_tokens, n_tokens, c_z) | Trunk pair representations | required |
| noise_schedule | list | Noise schedule [c0, c1, ..., cT] | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| x_t | Tensor, shape=(batch_size, n_atoms, 3) | Final denoised positions |
Source code in src/beignet/nn/alphafold3/_sample_diffusion.py
beignet.nn.alphafold3.TemplateEmbedder
Bases: Module
Template Embedder for AlphaFold 3.
This module processes template structural information and adds it to the pair representation. Templates provide structural constraints from homologous structures that guide the prediction.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c_z | int | Pair representation dimension | 128 |
| c_template | int | Template feature dimension | 64 |
| n_head | int | Number of attention heads | 4 |
Examples:
>>> import torch
>>> from beignet.nn import TemplateEmbedder
>>> batch_size, n_tokens = 2, 64
>>> module = TemplateEmbedder()
>>> f_star = {'template_features': torch.randn(batch_size, n_tokens, n_tokens, 64)}
>>> z_ij = torch.randn(batch_size, n_tokens, n_tokens, 128)
>>> output = module(f_star, z_ij)
>>> output.shape
torch.Size([2, 64, 64, 128])
Source code in src/beignet/nn/alphafold3/_template_embedder.py
forward
forward(f_star, z_ij)
Forward pass of Template Embedder.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| f_star | dict | Dictionary containing template features with key 'template_features' | required |
| z_ij | Tensor, shape=(batch_size, n_tokens, n_tokens, c_z) | Current pair representations | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| z_ij | Tensor, shape=(batch_size, n_tokens, n_tokens, c_z) | Updated pair representations with template information |
Source code in src/beignet/nn/alphafold3/_template_embedder.py
beignet.nn.alphafold3.Transition
Bases: Module
Transition layer from AlphaFold 3.
This implements Algorithm 11 from AlphaFold 3, which is a simple transition layer with layer normalization, two linear projections, and a SwiGLU activation function for enhanced non-linearity.
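A minimal sketch of this SwiGLU transition, structurally equivalent to Algorithm 11 with the default sizes below (this is a sketch, not the beignet source):

```python
import torch
import torch.nn.functional as F
from torch import nn

c, n = 128, 4
norm = nn.LayerNorm(c)
linear_a = nn.Linear(c, n * c, bias=False)   # gated branch
linear_b = nn.Linear(c, n * c, bias=False)   # linear branch
linear_out = nn.Linear(n * c, c, bias=False)

x = torch.randn(2, 10, c)
h = norm(x)
x_out = linear_out(F.silu(linear_a(h)) * linear_b(h))  # SwiGLU: swish gate times linear branch
print(x_out.shape)  # torch.Size([2, 10, 128])
```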
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | int | Input and output channel dimension | 128 |
| n | int | Expansion factor for the hidden dimension | 4 |
Examples:
>>> import torch
>>> from beignet.nn import Transition
>>> batch_size, seq_len, c = 2, 10, 128
>>> n = 4
>>> module = Transition(c=c, n=n)
>>> x = torch.randn(batch_size, seq_len, c)
>>> x_out = module(x)
>>> x_out.shape
torch.Size([2, 10, 128])
References
.. [1] AlphaFold 3 paper, Algorithm 11: Transition layer
Source code in src/beignet/nn/alphafold3/_transition.py
forward
forward(x)
Forward pass of transition layer.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor, shape=(..., c) | Input tensor where c is the channel dimension. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| x | Tensor, shape=(..., c) | Output tensor after transition layer processing. |
Source code in src/beignet/nn/alphafold3/_transition.py
beignet.nn.alphafold3.TriangleAttentionEndingNode
Bases: Module
Triangular gated self-attention around ending node from AlphaFold 3.
This implements Algorithm 15 from AlphaFold 3, which performs triangular gated self-attention where the attention is computed around the ending node of each edge in the triangle. The key differences from Algorithm 14 are in steps 5 and 6 where k^h_kj and v^h_kj are used instead of k^h_jk and v^h_jk.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | int | Channel dimension for the pair representation | 32 |
| n_head | int | Number of attention heads | 4 |
Examples:
>>> import torch
>>> from beignet.nn import TriangleAttentionEndingNode
>>> batch_size, seq_len, c = 2, 10, 32
>>> n_head = 4
>>> module = TriangleAttentionEndingNode(c=c, n_head=n_head)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 32])
References
.. [1] AlphaFold 3 paper, Algorithm 15: Triangular gated self-attention around ending node
Source code in src/beignet/nn/alphafold3/_triangle_attention_ending_node.py
forward
forward(z_ij)
Forward pass of triangular gated self-attention around ending node.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z_ij | Tensor, shape=(..., s, s, c) | Input pair representation where s is sequence length and c is channel dimension. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| z_tilde_ij | Tensor, shape=(..., s, s, c) | Updated pair representation after triangular attention. |
Source code in src/beignet/nn/alphafold3/_triangle_attention_ending_node.py
beignet.nn.alphafold3.TriangleAttentionStartingNode
Bases: Module
Triangular gated self-attention around starting node from AlphaFold 3.
This implements Algorithm 14 from AlphaFold 3, which performs triangular gated self-attention where the attention is computed around the starting node of each edge in the triangle.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | int | Channel dimension for the pair representation | 32 |
| n_head | int | Number of attention heads | 4 |
Examples:
>>> import torch
>>> from beignet.nn import TriangleAttentionStartingNode
>>> batch_size, seq_len, c = 2, 10, 32
>>> n_head = 4
>>> module = TriangleAttentionStartingNode(c=c, n_head=n_head)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 32])
References
.. [1] AlphaFold 3 paper, Algorithm 14: Triangular gated self-attention around starting node
Source code in src/beignet/nn/alphafold3/_triangle_attention_starting_node.py
forward
forward(z_ij)
Forward pass of triangular gated self-attention around starting node.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z_ij | Tensor, shape=(..., s, s, c) | Input pair representation where s is sequence length and c is channel dimension. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| z_tilde_ij | Tensor, shape=(..., s, s, c) | Updated pair representation after triangular attention. |
Source code in src/beignet/nn/alphafold3/_triangle_attention_starting_node.py
beignet.nn.alphafold3.TriangleMultiplicationIncoming
Bases: Module
Triangular multiplicative update using "incoming" edges from AlphaFold 3.
This implements Algorithm 13 from AlphaFold 3, which performs triangular multiplicative updates on pair representations using incoming edges. The key difference from the outgoing version is in step 4 where a_ki ⊙ b_kj is computed instead of a_ik ⊙ b_jk.
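The difference is easiest to see as an einsum over the third node k; the contrast below is schematic and ignores the gating and normalization that surround the contraction:

```python
import torch

a = torch.randn(2, 10, 10, 128)  # left projection of the pair representation
b = torch.randn(2, 10, 10, 128)  # right projection of the pair representation

# Outgoing edges (Algorithm 12): combine edges leaving i and j, i.e. a_ik * b_jk.
outgoing = torch.einsum('bikc,bjkc->bijc', a, b)

# Incoming edges (Algorithm 13): combine edges entering i and j, i.e. a_ki * b_kj.
incoming = torch.einsum('bkic,bkjc->bijc', a, b)
```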
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | int | Channel dimension for the pair representation | 128 |
Examples:
>>> import torch
>>> from beignet.nn import TriangleMultiplicationIncoming
>>> batch_size, seq_len, c = 2, 10, 128
>>> module = TriangleMultiplicationIncoming(c=c)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 128])
References
.. [1] AlphaFold 3 paper, Algorithm 13: Triangular multiplicative update using "incoming" edges
Source code in src/beignet/nn/alphafold3/_triangle_multiplication_incoming.py
forward
forward(z_ij)
Forward pass of triangular multiplicative update with incoming edges.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z_ij | Tensor, shape=(..., s, s, c) | Input pair representation where s is sequence length and c is channel dimension. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| z_tilde_ij | Tensor, shape=(..., s, s, c) | Updated pair representation after triangular multiplicative update. |
Source code in src/beignet/nn/alphafold3/_triangle_multiplication_incoming.py
beignet.nn.alphafold3.TriangleMultiplicationOutgoing
Bases: Module
Triangular multiplicative update using "outgoing" edges from AlphaFold 3.
This implements Algorithm 12 from AlphaFold 3, which performs triangular multiplicative updates on pair representations. The algorithm establishes direct communication between edges that connect 3 nodes in a triangle, allowing the network to detect inconsistencies in spatial relationships.
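A hedged sketch of the full gated update, following the published algorithm; the projection sizes and the placement of the residual connection are not taken from the beignet source:

```python
import torch
from torch import nn

c = 128
norm_in = nn.LayerNorm(c)
proj_a, gate_a = nn.Linear(c, c, bias=False), nn.Linear(c, c, bias=False)
proj_b, gate_b = nn.Linear(c, c, bias=False), nn.Linear(c, c, bias=False)
gate_out = nn.Linear(c, c, bias=False)
norm_out = nn.LayerNorm(c)
proj_out = nn.Linear(c, c, bias=False)

z = torch.randn(2, 10, 10, c)
h = norm_in(z)
a = torch.sigmoid(gate_a(h)) * proj_a(h)           # gated left edges
b = torch.sigmoid(gate_b(h)) * proj_b(h)           # gated right edges
update = torch.einsum('bikc,bjkc->bijc', a, b)      # "outgoing" contraction over k
z_tilde = torch.sigmoid(gate_out(h)) * proj_out(norm_out(update))
```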
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| c | int | Channel dimension for the pair representation | 128 |
Examples:
>>> import torch
>>> from beignet.nn import TriangleMultiplicationOutgoing
>>> batch_size, seq_len, c = 2, 10, 128
>>> module = TriangleMultiplicationOutgoing(c=c)
>>> z_ij = torch.randn(batch_size, seq_len, seq_len, c)
>>> z_tilde_ij = module(z_ij)
>>> z_tilde_ij.shape
torch.Size([2, 10, 10, 128])
References
.. [1] AlphaFold 3 paper, Algorithm 12: Triangular multiplicative update using "outgoing" edges
Source code in src/beignet/nn/alphafold3/_triangle_multiplication_outgoing.py
forward
forward(z_ij)
Forward pass of triangular multiplicative update with outgoing edges.
Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| z_ij | Tensor, shape=(..., s, s, c) | Input pair representation where s is sequence length and c is channel dimension. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| z_tilde_ij | Tensor, shape=(..., s, s, c) | Updated pair representation after triangular multiplicative update. |
Source code in src/beignet/nn/alphafold3/_triangle_multiplication_outgoing.py