grelu.lightning.losses
grelu.lightning.losses contains custom loss functions for training sequence-to-function models. These losses are used in grelu.lightning.
All loss functions inherit from torch.nn.Module and define a forward method that takes input and target tensors. All loss functions produce a single value per task, which can be averaged across tasks by setting reduction="mean".
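The convention described above can be illustrated with a minimal sketch of a compatible loss. PerTaskMSELoss is a hypothetical class used only for illustration (it is not part of grelu), and the (batch, tasks, length) tensor layout is an assumption. It subclasses torch.nn.Module, computes one value per task in forward, and averages those values when reduction="mean".

```python
import torch
from torch import nn


class PerTaskMSELoss(nn.Module):
    """Hypothetical loss following the convention above (not a grelu class).

    Assumes input and target have shape (batch, tasks, length).
    """

    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        if reduction not in ("mean", "none"):
            raise ValueError(f"unsupported reduction: {reduction!r}")
        self.reduction = reduction

    def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Squared error averaged over batch (dim 0) and length (dim 2),
        # leaving a single value per task.
        per_task = ((input - target) ** 2).mean(dim=(0, 2))
        if self.reduction == "mean":
            # Average the per-task values into one scalar.
            return per_task.mean()
        return per_task
```

With reduction="none" the module returns a tensor with one entry per task; with reduction="mean" it returns a scalar that can be passed directly to backward().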
Classes
PoissonMultinomialLoss – Poisson decomposition with multinomial specificity term.
Module Contents
- class grelu.lightning.losses.PoissonMultinomialLoss(total_weight: float = 1, eps: float = 1e-07, log_input: bool = True, reduction: str = 'mean', multinomial_axis: str = 'length')
Bases: torch.nn.Module
Poisson decomposition with multinomial specificity term.
- Parameters:
total_weight – Weight of the Poisson total term.
eps – Small value added to avoid log(0). Only needed if log_input = False.
log_input – If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts.
multinomial_axis – Either “length” or “task”, representing the axis along which the multinomial distribution should be calculated.
reduction – “mean” or “none”.
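To make the parameters above concrete, here is a minimal NumPy sketch of the general Poisson plus multinomial decomposition. It is not grelu's implementation: the (batch, tasks, length) layout, the mapping of "length" to axis 2 and "task" to axis 1, summing (rather than averaging) the multinomial term over the axis, and dropping the constant log(y!) and multinomial-coefficient terms are all assumptions. The function name poisson_multinomial_loss is hypothetical, and it returns the unreduced loss (as with reduction = "none").

```python
import numpy as np


def poisson_multinomial_loss(
    input: np.ndarray,
    target: np.ndarray,
    total_weight: float = 1.0,
    eps: float = 1e-7,
    log_input: bool = True,
    multinomial_axis: str = "length",
) -> np.ndarray:
    """Hypothetical sketch of a Poisson + multinomial loss (not grelu's code).

    Assumes arrays of shape (batch, tasks, length); the multinomial axis
    is collapsed in the returned (unreduced) loss.
    """
    # Turn the input into predicted counts, mirroring the log_input/eps options.
    counts = np.exp(input) if log_input else input + eps
    axis = {"length": 2, "task": 1}[multinomial_axis]

    # Poisson term on totals along the multinomial axis:
    # negative log-likelihood lambda - y * log(lambda), constant log(y!) dropped.
    total_pred = counts.sum(axis=axis)
    total_obs = target.sum(axis=axis)
    poisson_term = total_pred - total_obs * np.log(total_pred)

    # Multinomial (specificity) term: cross-entropy between observed counts
    # and the predicted profile p_i = counts_i / total_pred.
    log_profile = np.log(counts) - np.log(np.expand_dims(total_pred, axis))
    multinomial_term = -(target * log_profile).sum(axis=axis)

    return multinomial_term + total_weight * poisson_term
```

The split lets the Poisson term score the overall magnitude of the signal while the multinomial term scores only how that signal is distributed along the chosen axis; total_weight trades one off against the other.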