"""
This file contains the neural network architectures.
These are all you need for inference.
"""
import torch
from torch import nn
import torch.nn.functional as F
from typing import List


class Encoder(nn.Module):
"""A class that encapsulates the encoder."""
def __init__(
self,
n_genes: int,
latent_dim: int = 128,
hidden_dim: List[int] = [1024, 1024],
dropout: float = 0.5,
input_dropout: float = 0.4,
residual: bool = False,
):
"""Constructor.
Parameters
----------
n_genes: int
The number of genes in the gene space, representing the input dimensions.
latent_dim: int, default: 128
The latent space dimensions
hidden_dim: List[int], default: [1024, 1024]
A list of hidden layer dimensions, describing the number of layers and their dimensions.
Hidden layers are constructed in the order of the list for the encoder and in reverse
for the decoder.
dropout: float, default: 0.5
The dropout rate for hidden layers
input_dropout: float, default: 0.4
The dropout rate for the input layer
residual: bool, default: False
Use residual connections.
"""
super().__init__()
self.latent_dim = latent_dim
self.network = nn.ModuleList()
self.residual = residual
if self.residual:
assert len(set(hidden_dim)) == 1
for i in range(len(hidden_dim)):
if i == 0: # input layer
self.network.append(
nn.Sequential(
nn.Dropout(p=input_dropout),
nn.Linear(n_genes, hidden_dim[i]),
nn.BatchNorm1d(hidden_dim[i]),
nn.PReLU(),
)
)
else: # hidden layers
self.network.append(
nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
nn.BatchNorm1d(hidden_dim[i]),
nn.PReLU(),
)
)
# output layer
self.network.append(nn.Linear(hidden_dim[-1], latent_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor corresponding to the input layer.

        Returns
        -------
        torch.Tensor
            Output tensor corresponding to the latent embedding, L2-normalized along dim 1.
        """
        for i, layer in enumerate(self.network):
            if self.residual and (0 < i < len(self.network) - 1):
                x = layer(x) + x
            else:
                x = layer(x)
        return F.normalize(x, p=2, dim=1)

    def save_state(self, filename: str):
        """Save model state.

        Parameters
        ----------
        filename: str
            Filename to save the model state.
        """
        torch.save({"state_dict": self.state_dict()}, filename)

    def load_state(self, filename: str, use_gpu: bool = False):
        """Load model state.

        Parameters
        ----------
        filename: str
            Filename containing the model state.
        use_gpu: bool, default: False
            Boolean indicating whether or not to use GPUs.
        """
        if not use_gpu:
            ckpt = torch.load(
                filename, map_location=torch.device("cpu"), weights_only=False
            )
        else:
            ckpt = torch.load(filename, weights_only=False)
        self.load_state_dict(ckpt["state_dict"])


class Decoder(nn.Module):
"""A class that encapsulates the decoder."""
def __init__(
self,
n_genes: int,
latent_dim: int = 128,
hidden_dim: List[int] = [1024, 1024],
dropout: float = 0.5,
residual: bool = False,
):
"""Constructor.
Parameters
----------
n_genes: int
The number of genes in the gene space, representing the input dimensions.
latent_dim: int, default: 128
The latent space dimensions
hidden_dim: List[int], default: [1024, 1024]
A list of hidden layer dimensions, describing the number of layers and their dimensions.
Hidden layers are constructed in the order of the list for the encoder and in reverse
for the decoder.
dropout: float, default: 0.5
The dropout rate for hidden layers
residual: bool, default: False
Use residual connections.
"""
super().__init__()
self.latent_dim = latent_dim
self.network = nn.ModuleList()
self.residual = residual
if self.residual:
assert len(set(hidden_dim)) == 1
for i in range(len(hidden_dim)):
if i == 0: # first hidden layer
self.network.append(
nn.Sequential(
nn.Linear(latent_dim, hidden_dim[i]),
nn.BatchNorm1d(hidden_dim[i]),
nn.PReLU(),
)
)
else: # other hidden layers
self.network.append(
nn.Sequential(
nn.Dropout(p=dropout),
nn.Linear(hidden_dim[i - 1], hidden_dim[i]),
nn.BatchNorm1d(hidden_dim[i]),
nn.PReLU(),
)
)
# reconstruction layer
self.network.append(nn.Linear(hidden_dim[-1], n_genes))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward.

        Parameters
        ----------
        x: torch.Tensor
            Input tensor corresponding to the latent embedding.

        Returns
        -------
        torch.Tensor
            Output tensor corresponding to the reconstructed gene expression.
        """
        for i, layer in enumerate(self.network):
            if self.residual and (0 < i < len(self.network) - 1):
                x = layer(x) + x
            else:
                x = layer(x)
        return x

    def save_state(self, filename: str):
        """Save model state.

        Parameters
        ----------
        filename: str
            Filename to save the model state.
        """
        torch.save({"state_dict": self.state_dict()}, filename)

    def load_state(self, filename: str, use_gpu: bool = False):
        """Load model state.

        Parameters
        ----------
        filename: str
            Filename containing the model state.
        use_gpu: bool, default: False
            Boolean indicating whether or not to use GPUs.
        """
        if not use_gpu:
            ckpt = torch.load(
                filename, map_location=torch.device("cpu"), weights_only=False
            )
        else:
            ckpt = torch.load(filename, weights_only=False)
        self.load_state_dict(ckpt["state_dict"])
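

# A minimal usage sketch, not part of the library API: it wires an Encoder and a
# Decoder together on random data to illustrate the expected tensor shapes, the
# L2-normalized latent output, and the save/load round trip. The gene count and
# batch size below are illustrative assumptions, not values prescribed by this module.
if __name__ == "__main__":
    import tempfile

    n_genes, batch_size = 2000, 8  # hypothetical gene space and batch size
    encoder = Encoder(n_genes=n_genes, latent_dim=128, hidden_dim=[1024, 1024])
    decoder = Decoder(n_genes=n_genes, latent_dim=128, hidden_dim=[1024, 1024])
    encoder.eval()  # eval mode disables dropout and uses BatchNorm running statistics
    decoder.eval()

    with torch.no_grad():
        x = torch.randn(batch_size, n_genes)
        z = encoder(x)      # shape (batch_size, 128); rows have unit L2 norm
        x_hat = decoder(z)  # shape (batch_size, n_genes); reconstruction

    print(z.shape, x_hat.shape)
    print(torch.allclose(z.norm(dim=1), torch.ones(batch_size), atol=1e-5))

    # Save the encoder weights to a temporary checkpoint and reload them on CPU.
    with tempfile.NamedTemporaryFile(suffix=".ckpt") as f:
        encoder.save_state(f.name)
        encoder.load_state(f.name, use_gpu=False)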