grelu.lightning.losses

Custom loss functions

Classes

PoissonMultinomialLoss

Poisson decomposition with multinomial specificity term.

Module Contents

class grelu.lightning.losses.PoissonMultinomialLoss(total_weight: float = 1, eps: float = 1e-07, log_input: bool = True, reduction: str = 'mean')

Bases: torch.nn.Module

Poisson decomposition with multinomial specificity term.

Parameters:
  • total_weight – Weight of the Poisson total term.

  • eps – Small value added to avoid log(0). Only needed if log_input is False.

  • log_input – If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts.

  • reduction – Reduction applied to the loss: either “mean” or “none”.
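
As a rough illustration of how log_input and eps interact, the sketch below maps raw model output to positive predicted counts. The helper name to_predicted_counts is hypothetical and not part of grelu. Adding eps directly to the counts is an assumption based on the parameter descriptions above, not a confirmed detail of the library.

```python
import torch


def to_predicted_counts(
    input: torch.Tensor, log_input: bool = True, eps: float = 1e-7
) -> torch.Tensor:
    """Hypothetical helper (not a grelu function): return positive predicted counts."""
    if log_input:
        # The input is treated as log-counts, so exponentiate it.
        return torch.exp(input)
    # The input is already counts; eps keeps a later log() away from log(0).
    # Assumption: eps is simply added to the counts.
    return input + eps


# Log-scale predictions of 0 correspond to a predicted count of 1.
log_preds = torch.zeros(2, 3, 4)
counts_from_log = to_predicted_counts(log_preds, log_input=True)

# Count-scale predictions of exactly 0 stay strictly positive once eps is added.
zero_counts = torch.zeros(2, 3, 4)
counts_from_raw = to_predicted_counts(zero_counts, log_input=False)
```

With log_input=True the exponential already guarantees positive counts, which is consistent with eps only being needed when log_input is False.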

forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Loss computation

Parameters:
  • input – Tensor of shape (B, T, L)

  • target – Tensor of shape (B, T, L)

Returns:

Loss value
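
The page does not spell out the formula, so the following is only a minimal sketch of one common way to combine a Poisson term on total counts with a multinomial term on the distribution of counts along the last axis. It should not be read as the library's implementation. The function name poisson_multinomial_sketch, the division of the Poisson term by the length L, and the use of the last axis for the multinomial are all assumptions. The sketch also assumes that B, T and L denote batch, track and length, which the source does not state.

```python
import torch
import torch.nn.functional as F


def poisson_multinomial_sketch(
    input: torch.Tensor,
    target: torch.Tensor,
    total_weight: float = 1.0,
    eps: float = 1e-7,
    log_input: bool = True,
    reduction: str = "mean",
) -> torch.Tensor:
    """Illustrative loss for tensors of shape (B, T, L); not the grelu code."""
    # Convert the predictions to positive counts.
    counts = torch.exp(input) if log_input else input + eps
    seq_len = target.shape[-1]

    # Poisson term on the totals, summed over the last axis: shape (B, T, 1).
    total_pred = counts.sum(dim=-1, keepdim=True)
    total_true = target.sum(dim=-1, keepdim=True)
    poisson_term = F.poisson_nll_loss(
        total_pred, total_true, log_input=False, reduction="none"
    )
    # Assumption: rescale by L so both terms are per-position averages.
    poisson_term = poisson_term / seq_len

    # Multinomial term: negative log-likelihood of the positional distribution.
    log_p = torch.log(counts / total_pred)
    multinomial_term = -(target * log_p).mean(dim=-1, keepdim=True)

    loss = multinomial_term + total_weight * poisson_term  # shape (B, T, 1)
    return loss.mean() if reduction == "mean" else loss


torch.manual_seed(0)
pred_log_counts = torch.randn(2, 3, 5)                 # (B, T, L), log scale
true_counts = torch.randint(0, 10, (2, 3, 5)).float()  # (B, T, L), counts

loss = poisson_multinomial_sketch(pred_log_counts, true_counts)
per_track = poisson_multinomial_sketch(
    pred_log_counts, true_counts, reduction="none"
)
```

In this sketch, total_weight scales only the Poisson term on total counts, so a smaller value puts relatively more emphasis on matching the shape of the profile than on matching its overall magnitude.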