beignet.lightning.AlphaFold3LightningModule

Bases: LightningModule
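
A LightningModule wrapper around the AlphaFold3 network: it defines the forward pass, the training and validation steps, and optimizer/scheduler construction. The loss currently evaluates only the smooth local distance difference test (lDDT) term; the remaining terms are weighted in loss_weights but not yet combined.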

Source code in src/beignet/lightning/_alphafold3_lightning_module.py
from collections.abc import Callable

from lightning import LightningModule
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler

# NOTE: the exact export paths below are assumed from the package layout.
from beignet import smooth_local_distance_difference_test
from beignet.nn import AlphaFold3


class AlphaFold3LightningModule(LightningModule):
    # `optimizer` and `scheduler` are factory callables (e.g.,
    # functools.partial) that configure_optimizers applies later.
    def __init__(
        self,
        optimizer: Callable[..., Optimizer],
        scheduler: Callable[..., LRScheduler] | None,
        *,
        module: Module | None = None,
    ):
        super().__init__()

        # Relative weights for each AlphaFold3 loss term; loss() currently
        # evaluates only the smooth lDDT term.
        self.loss_weights = {
            "aligned_error": 0.5,
            "diffusion": 1.0,
            "distance_error": 0.25,
            "distogram": 1.0,
            "experimentally_resolved": 0.25,
            "local_distance_difference_test": 0.0,
        }

        # Build a default network when none is supplied (assumes AlphaFold3
        # is constructible without arguments).
        self.module = module if module is not None else AlphaFold3()

        self.optimizer, self.scheduler = optimizer, scheduler

        # Persist the factory callables, but not the network weights, in
        # self.hparams so checkpoints can restore them.
        self.save_hyperparameters(logger=False, ignore=["module"])

    def forward(
        self,
        x: Tensor,
    ) -> tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
        return self.module(x)

    def loss(
        self,
        input: tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor],
        target: tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor],
    ) -> Tensor:
        (
            input,
            local_distance_difference_test,
            aligned_error,
            distance_error,
            experimentally_resolved,
            distogram,
        ) = input

        (
            target,
            target_local_distance_difference_test,
            target_aligned_error,
            target_distance_error,
            target_experimentally_resolved,
            target_distogram,
        ) = target

        # Only the smooth lDDT term is evaluated for now; return its value
        # rather than discarding it. The other predicted heads (and
        # loss_weights) are unpacked above but not yet combined.
        return smooth_local_distance_difference_test(
            local_distance_difference_test,
            target_local_distance_difference_test,
        )

    def training_step(self, batch, batch_idx):
        # Each batch pairs an input tensor with a six-tensor target tuple.
        inputs, targets = batch

        output = self(inputs)

        loss = self.loss(output, targets)

        self.log("Loss (Train)", loss)

        return loss

    def validation_step(self, batch, batch_idx):
        inputs, targets = batch

        outputs = self(inputs)

        loss = self.loss(outputs, targets)

        self.log("Loss (Validation)", loss)

    def configure_optimizers(self):
        # Apply the factory callables captured in __init__ (via
        # save_hyperparameters) to build the actual optimizer and scheduler.
        optimizer = self.hparams.optimizer(self.parameters())

        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer)

            # Stretch schedules that expose T_max (e.g., CosineAnnealingLR)
            # across the whole run.
            if hasattr(scheduler, "T_max"):
                scheduler.T_max = self.trainer.max_epochs

            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "Loss (Validation)",
                },
            }

        return optimizer
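
A minimal construction sketch, assuming the import paths shown above and that AlphaFold3 builds with default arguments. The optimizer and scheduler arguments are factory callables (here via functools.partial), which configure_optimizers later applies to the module's parameters:

from functools import partial

import torch

from beignet.lightning import AlphaFold3LightningModule

# Factories, not instances: configure_optimizers calls the first with
# self.parameters() and the second with the built optimizer.
model = AlphaFold3LightningModule(
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    scheduler=partial(torch.optim.lr_scheduler.CosineAnnealingLR, T_max=1),
)

Training then follows the usual Lightning pattern, for example Trainer(max_epochs=10).fit(model, train_dataloader), where each batch pairs an input tensor with a six-tensor target tuple; note that configure_optimizers overwrites the scheduler's T_max with trainer.max_epochs.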