Skip to content

Schedulers

beignet.diffusers.schedulers.AlphaFold3Scheduler

Bases: SchedulerMixin

AlphaFold 3 (AF3)-style coordinate diffusion scheduler, intended to be paired with a DiffusionModule.

Uses an update scheme consistent with the SampleDiffusion sampler:

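$$\hat{t} = c_{\tau-1}\,(\gamma + 1)$$

$$\zeta \sim \lambda \sqrt{\max\left(\hat{t}^{2} - c_{\tau-1}^{2},\ 0\right)}\ \mathcal{N}(0, I)$$

$$x_{\tau} = x_{\text{noisy}} + \eta\,(c_{\tau} - \hat{t})\,\frac{x_{\tau-1} - x_{\text{denoised}}}{\hat{t}}$$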

Args:

- schedule: 1-D tensor of shape (T+1,) with noise levels [c0, ..., cT].
- gamma0: γ₀ (default 0.8).
- gamma_min: threshold; if c_τ > gamma_min then γ = γ₀, otherwise γ = 0 (default 1.0).
- noise_scale: λ (default 1.003).
- step_scale: η (default 1.5).

Source code in src/beignet/diffusers/schedulers/_alphafold3_scheduler.py
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
class AlphaFold3Scheduler(SchedulerMixin):
    r"""
    AF3-style coordinate diffusion scheduler (paired with a DiffusionModule).

    Uses the update scheme consistent with the SampleDiffusion sampler:

        γ     = γ_0 if c_τ > γ_min else 0
        t_hat = c_{τ-1} * (γ + 1)
        ζ     ∼ λ * sqrt(max(t_hat^2 - c_{τ-1}^2, 0)) * N(0, I)
        x_{τ} = x_noisy + η * (c_τ - t_hat) * ((x_{τ-1} - x_denoised) / t_hat)

    Args:
        schedule: 1-D tensor (or anything ``torch.as_tensor`` accepts) of shape
            (T+1,) with noise levels [c0,...,cT]. Stored as float32.
        gamma0: γ_0 (default 0.8)
        gamma_min: threshold: if c_τ > gamma_min → γ=γ_0 else 0 (default 1.0)
        noise_scale: λ (default 1.003)
        step_scale:  η (default 1.5)

    Raises:
        ValueError: if ``schedule`` is not 1-D or has fewer than two entries.
    """

    def __init__(
        self,
        schedule: Tensor,
        gamma0: float = 0.8,
        gamma_min: float = 1.0,
        noise_scale: float = 1.003,
        step_scale: float = 1.5,
    ):
        super().__init__()
        schedule = torch.as_tensor(schedule, dtype=torch.float32)
        # Explicit check rather than `assert`: asserts are stripped under `python -O`.
        if schedule.ndim != 1 or schedule.numel() < 2:
            raise ValueError("Schedule must be (T+1,) with T>=1")
        self.register_to_config(
            gamma0=gamma0,
            gamma_min=gamma_min,
            noise_scale=noise_scale,
            step_scale=step_scale,
        )
        # Noise levels c_0..c_T, kept as a persistent buffer.
        self.register_buffer("c", schedule.clone(), persistent=True)

    @property
    def num_inference_steps(self) -> int:
        """Number of sampler steps T (the schedule holds T+1 noise levels)."""
        return int(self.c.numel() - 1)

    # ---------- Teacher-forced training helper ----------
    def add_noise(
        self, x_clean: Tensor, t_index: Tensor | int
    ) -> tuple[Tensor, Tensor]:
        """
        Build x_noisy and t_hat from clean x at schedule index `t_index` (>=1).

        Args:
            x_clean: (B, N, 3)
            t_index: int, 0-d integer tensor, or (B,) integer tensor with values
                in [1, T]. Scalar indices are broadcast to the whole batch.
        Returns:
            x_noisy: (B, N, 3)
            t_hat  : (B,) time scalars matching AF3 equations
        Raises:
            IndexError: if any entry of `t_index` lies outside [1, T].
        """
        device, dtype = x_clean.device, x_clean.dtype
        c = self.c.to(device=device, dtype=dtype)
        num_steps = c.numel() - 1

        # Normalise ints, integer scalars and tensors to a long tensor on the
        # schedule's device, then broadcast a scalar index over the batch.
        t_index = torch.as_tensor(t_index, device=device, dtype=torch.long)
        if t_index.ndim == 0:
            t_index = t_index.expand(x_clean.shape[0])

        # Without this check t_index == 0 would make c[t_index - 1] silently
        # read c[-1] through negative indexing.
        if t_index.numel() > 0 and bool(
            ((t_index < 1) | (t_index > num_steps)).any()
        ):
            raise IndexError(f"t_index values must lie in [1, {num_steps}]")

        c_tau = c[t_index]  # (B,)
        c_prev = c[t_index - 1]  # (B,)
        gamma = torch.where(
            c_tau > self.config.gamma_min,
            torch.as_tensor(self.config.gamma0, device=device, dtype=dtype),
            torch.zeros((), device=device, dtype=dtype),
        )
        t_hat = c_prev * (gamma + 1.0)  # (B,)

        # Clamp guards against tiny negative values from rounding when γ == 0.
        variance = (t_hat**2 - c_prev**2).clamp_min(0.0).view(-1, 1, 1)  # (B,1,1)
        eps = torch.randn_like(x_clean)
        zeta = self.config.noise_scale * variance.sqrt() * eps
        x_noisy = x_clean + zeta
        return x_noisy, t_hat  # (B,N,3), (B,)

    # ---------- Sampling step ----------
    @torch.no_grad()
    def step(
        self,
        x_denoised: Tensor,  # model's prediction x̂ (B,N,3)
        x_noisy: Tensor,  # current noisy sample (B,N,3)
        x_prev_ref: Tensor,  # previous sample x_{τ-1} (B,N,3)
        t_index: int,  # integer τ in [1, T]
    ) -> SchedulerOutput:
        """
        One sampler update:
            γ     = γ_0 if c_τ > γ_min else 0
            t_hat = c_{τ-1} * (γ + 1)
            delta = (x_{τ-1} - x̂) / t_hat
            x_τ   = x_noisy + η * (c_τ - t_hat) * delta

        Returns:
            SchedulerOutput whose ``prev_sample`` holds x_τ.
        Raises:
            IndexError: if `t_index` lies outside [1, T].
        """
        device, dtype = x_noisy.device, x_noisy.dtype
        c = self.c.to(device=device, dtype=dtype)
        num_steps = c.numel() - 1

        # Reject out-of-range steps; t_index == 0 would otherwise wrap to c[-1].
        if not 1 <= t_index <= num_steps:
            raise IndexError(f"t_index must lie in [1, {num_steps}], got {t_index}")

        c_tau = c[t_index]
        c_prev = c[t_index - 1]

        gamma = self.config.gamma0 if float(c_tau) > self.config.gamma_min else 0.0
        t_hat = c_prev * (gamma + 1.0)

        delta = (x_prev_ref - x_denoised) / t_hat
        dt = c_tau - t_hat
        x_next = x_noisy + self.config.step_scale * dt * delta
        return SchedulerOutput(prev_sample=x_next)
add_noise
add_noise(x_clean, t_index)

Build x_noisy and t_hat from clean x at schedule index t_index (>=1).

Args:

- x_clean: (B, N, 3)
- t_index: int or (B,) int tensor with values in [1, T]

Returns:

- x_noisy: (B, N, 3)
- t_hat: (B,) time scalars matching the AF3 equations

Source code in src/beignet/diffusers/schedulers/_alphafold3_scheduler.py
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
def add_noise(
    self, x_clean: Tensor, t_index: Tensor | int
) -> tuple[Tensor, Tensor]:
    """
    Build x_noisy and t_hat from clean x at schedule index `t_index` (>=1).

    Args:
        x_clean: (B, N, 3)
        t_index: int or (B,) int tensor with values in [1, T]
    Returns:
        x_noisy: (B, N, 3)
        t_hat  : (B,) time scalars matching AF3 equations
    """
    device, dtype = x_clean.device, x_clean.dtype
    c = self.c.to(device=device, dtype=dtype)

    if isinstance(t_index, int):
        t_index = torch.full(
            (x_clean.shape[0],), t_index, device=device, dtype=torch.long
        )

    c_tau = c[t_index]  # (B,)
    c_prev = c[t_index - 1]  # (B,)
    gamma = torch.where(
        c_tau > self.config.gamma_min,
        torch.as_tensor(self.config.gamma0, device=device, dtype=dtype),
        torch.zeros((), device=device, dtype=dtype),
    )
    t_hat = c_prev * (gamma + 1.0)  # (B,)

    variance = (t_hat**2 - c_prev**2).clamp_min(0.0).view(-1, 1, 1)  # (B,1,1)
    eps = torch.randn_like(x_clean)
    zeta = self.config.noise_scale * variance.sqrt() * eps
    x_noisy = x_clean + zeta
    return x_noisy, t_hat  # (B,N,3), (B,)
step
step(x_denoised, x_noisy, x_prev_ref, t_index)

One sampler update:

$$\delta = \frac{x_{\tau-1} - \hat{x}}{\hat{t}}$$

$$x_{\tau} = x_{\text{noisy}} + \eta\,(c_{\tau} - \hat{t})\,\delta$$

Source code in src/beignet/diffusers/schedulers/_alphafold3_scheduler.py
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
@torch.no_grad()
def step(
    self,
    x_denoised: Tensor,  # model's prediction x̂ (B,N,3)
    x_noisy: Tensor,  # current noisy sample (B,N,3)
    x_prev_ref: Tensor,  # previous sample x_{τ-1} (B,N,3)
    t_index: int,  # integer τ in [1, T]
) -> SchedulerOutput:
    """
    One sampler update:
        γ     = γ_0 if c_τ > γ_min else 0
        t_hat = c_{τ-1} * (γ + 1)
        delta = (x_{τ-1} - x̂) / t_hat
        x_τ   = x_noisy + η * (c_τ - t_hat) * delta

    Returns:
        SchedulerOutput whose ``prev_sample`` holds x_τ.
    Raises:
        IndexError: if `t_index` lies outside [1, T].
    """
    device, dtype = x_noisy.device, x_noisy.dtype
    c = self.c.to(device=device, dtype=dtype)
    num_steps = c.numel() - 1

    # Reject out-of-range steps; t_index == 0 would otherwise wrap to c[-1].
    if not 1 <= t_index <= num_steps:
        raise IndexError(f"t_index must lie in [1, {num_steps}], got {t_index}")

    c_tau = c[t_index]
    c_prev = c[t_index - 1]

    gamma = self.config.gamma0 if float(c_tau) > self.config.gamma_min else 0.0
    t_hat = c_prev * (gamma + 1.0)

    delta = (x_prev_ref - x_denoised) / t_hat
    dt = c_tau - t_hat
    x_next = x_noisy + self.config.step_scale * dt * delta
    return SchedulerOutput(prev_sample=x_next)