grelu.lightning.losses

Custom loss functions

Classes

PoissonMultinomialLoss

Poisson decomposition with multinomial specificity term.

Module Contents

class grelu.lightning.losses.PoissonMultinomialLoss(total_weight: float = 1, eps: float = 1e-07, log_input: bool = True, reduction: str = 'mean', multinomial_axis: str = 'length')

Bases: torch.nn.Module

Poisson decomposition with multinomial specificity term.
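The name refers to splitting a Poisson count likelihood into a "how many" part and a "where" part. The derivation below is a minimal sketch of the standard form of this decomposition, not a transcription of grelu's code: it assumes one slice along the multinomial axis with $N$ entries, observed counts $y_i$, predicted counts $\hat y_i$, totals $S = \sum_i y_i$ and $\hat S = \sum_i \hat y_i$, and $w$ standing for total_weight. Whether grelu additionally averages the terms over the axis is not stated here.

$$
\mathcal{L} = \underbrace{-\sum_{i=1}^{N} y_i \log \frac{\hat y_i}{\hat S}}_{\text{multinomial (shape)}} \; + \; w \, \underbrace{\left(\hat S - S \log \hat S\right)}_{\text{Poisson (total)}}
$$

With $w = 1$, expanding the logarithm and using $\sum_i y_i = S$ gives

$$
\mathcal{L} = -\sum_{i=1}^{N} y_i \log \hat y_i + S \log \hat S + \hat S - S \log \hat S = \sum_{i=1}^{N} \left(\hat y_i - y_i \log \hat y_i\right),
$$

which is the per-position Poisson negative log-likelihood up to the constant $\log y_i!$ terms. Setting $w \neq 1$ reweights the total-count error relative to the shape (specificity) error.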

Parameters:
  • total_weight – Weight of the Poisson total term.

  • eps – Small value added to avoid log(0). Only needed if log_input = False.

  • log_input – If True, the input is transformed with torch.exp to produce predicted counts. Otherwise, the input is assumed to already represent predicted counts.

  • multinomial_axis – Either “length” or “task”, representing the axis along which the multinomial distribution should be calculated.

  • reduction – “mean” or “none”.
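To make multinomial_axis concrete, the sketch below normalizes a toy count tensor along each choice of axis. It assumes the (B, T, L) layout documented for forward, so “task” maps to dim 1 and “length” to dim 2; the AXIS_TO_DIM mapping, the tensor values, and the variable names are hypothetical and only illustrate the idea.

```python
import torch

# Hypothetical mapping from a multinomial_axis name to a dim of a (B, T, L) tensor.
AXIS_TO_DIM = {"task": 1, "length": 2}

# Toy observed counts: B=1 example, T=2 tasks, L=4 positions.
counts = torch.tensor([[[1.0, 3.0, 0.0, 4.0],
                        [2.0, 2.0, 2.0, 2.0]]])

# "length": each task's counts become a distribution over positions.
length_dim = AXIS_TO_DIM["length"]
length_profile = counts / counts.sum(dim=length_dim, keepdim=True)

# "task": each position's counts become a distribution over tasks.
task_dim = AXIS_TO_DIM["task"]
task_profile = counts / counts.sum(dim=task_dim, keepdim=True)

print(length_profile[0, 0])  # tensor([0.1250, 0.3750, 0.0000, 0.5000])
print(task_profile[0, :, 1])  # tensor([0.6000, 0.4000])
```

With “length”, the multinomial term scores the shape of each track across positions; with “task”, it scores how counts at each position are split across tasks. In both cases the matching totals are what the Poisson term sees.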

eps
total_weight
log_input
reduction
forward(input: torch.Tensor, target: torch.Tensor) → torch.Tensor

Compute the loss.

Parameters:
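  • input – Predicted values, a tensor of shape (B, T, L) (batch size, number of tasks, sequence length). Treated as log counts if log_input = True and as counts otherwise.

  • target – Observed counts, a tensor of shape (B, T, L).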

Returns:

Loss value
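The forward computation can be sketched as a standalone function. This is a minimal re-implementation for illustration, not grelu's source: the name poisson_multinomial_loss is hypothetical, it assumes input and target share the (B, T, L) layout, and it sums (rather than averages) each term over the multinomial axis before applying the reduction, so absolute values may differ from the library's by a normalization factor.

```python
import torch
import torch.nn.functional as F


def poisson_multinomial_loss(
    input: torch.Tensor,
    target: torch.Tensor,
    total_weight: float = 1.0,
    eps: float = 1e-7,
    log_input: bool = True,
    reduction: str = "mean",
    multinomial_axis: str = "length",
) -> torch.Tensor:
    """Hypothetical sketch of a Poisson-multinomial loss on (B, T, L) tensors."""
    # Assumed layout: dim 1 = tasks, dim 2 = positions along the sequence.
    dims = {"task": 1, "length": 2}
    if multinomial_axis not in dims:
        raise ValueError("multinomial_axis must be 'length' or 'task'")
    dim = dims[multinomial_axis]

    # Predicted counts: exponentiate log counts, or guard raw counts against log(0).
    pred = torch.exp(input) if log_input else input + eps

    # Totals along the multinomial axis, kept as a size-1 dim for broadcasting.
    pred_total = pred.sum(dim=dim, keepdim=True)
    target_total = target.sum(dim=dim, keepdim=True)

    # Poisson negative log-likelihood of the totals (constant log(k!) dropped).
    poisson_term = pred_total - target_total * torch.log(pred_total)

    # Multinomial negative log-likelihood of how counts are spread along the axis.
    log_prob = torch.log(pred) - torch.log(pred_total)
    multinomial_term = -(target * log_prob).sum(dim=dim, keepdim=True)

    loss = multinomial_term + total_weight * poisson_term
    if reduction == "mean":
        return loss.mean()
    if reduction == "none":
        return loss
    raise ValueError("reduction must be 'mean' or 'none'")


# Usage: random log counts and Poisson-sampled targets for B=2, T=3, L=5.
torch.manual_seed(0)
x = torch.randn(2, 3, 5, dtype=torch.float64)
y = torch.poisson(torch.exp(x))
per_slice = poisson_multinomial_loss(x, y, reduction="none")  # shape (2, 3, 1)
reference = F.poisson_nll_loss(x, y, log_input=True, reduction="none").sum(dim=2, keepdim=True)
print(torch.allclose(per_slice, reference))  # True when total_weight = 1
```

With total_weight = 1 the two terms add up to the ordinary per-position Poisson negative log-likelihood summed along the axis, which is why the comparison with torch.nn.functional.poisson_nll_loss holds. Changing total_weight moves emphasis between predicting the right total and predicting the right profile.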