Fine-tune Enformer to perform binary classification on human snATAC-seq data#

In this tutorial, we learn how to fine-tune the Enformer model (https://www.nature.com/articles/s41592-021-01252-x) on the CATLAS single-nucleus ATAC-seq dataset from human tissues (http://catlas.org/humanenhancer/).

We will perform binary classification, in which the model learns to predict the probability that a given sequence is accessible in the different cell types of the dataset. For an example of regression modeling, see the next tutorial.

import anndata
import os
import importlib
import pandas as pd
import numpy as np

%matplotlib inline

Set experiment parameters#

experiment='tutorial_2'
if not os.path.exists(experiment):
    os.makedirs(experiment)

Load data#

We download the CATlas ATAC-seq binary cell type x peak matrix from http://catlas.org/catlas_downloads/humantissues/cCRE_by_cell_type/ and load it as an AnnData object.

!wget http://catlas.org/catlas_downloads/humantissues/cCRE_by_cell_type/matrix.tsv.gz
ad = anndata.read_mtx('matrix.tsv.gz').T

# Prepare ad.obs
ad.obs = pd.read_table('http://catlas.org/catlas_downloads/humantissues/cCRE_by_cell_type/celltypes.txt.gz', header=None, names=['cell type'])
ad.obs_names = ad.obs['cell type']

# Prepare ad.var
var = pd.read_table('http://catlas.org/catlas_downloads/humantissues/cCRE_hg38.tsv.gz')
var.columns = ['chrom', 'start', 'end', 'cre_class', 'in_fetal', 'in_adult', 'cre_module']
var.index = var.index.astype(str)
ad.var = var

print(ad.shape)
--2024-05-28 05:45:03--  http://catlas.org/catlas_downloads/humantissues/cCRE_by_cell_type/matrix.tsv.gz
Resolving catlas.org (catlas.org)... 132.239.162.129
Connecting to catlas.org (catlas.org)|132.239.162.129|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 38772708 (37M) [application/x-gzip]
Saving to: ‘matrix.tsv.gz.2’

matrix.tsv.gz.2     100%[===================>]  36.98M  33.9MB/s    in 1.1s    

2024-05-28 05:45:05 (33.9 MB/s) - ‘matrix.tsv.gz.2’ saved [38772708/38772708]

(222, 1154611)

This contains a binary matrix representing the accessibility of 1154611 CREs measured in 222 cell types. Let us look at the components of this object:

ad.var.head()
chrom start end cre_class in_fetal in_adult cre_module
0 chr1 9955 10355 Promoter Proximal yes yes 146
1 chr1 29163 29563 Promoter yes yes 37
2 chr1 79215 79615 Distal no yes 75
3 chr1 102755 103155 Distal no yes 51
4 chr1 115530 115930 Distal yes no 36
ad.obs.head()
cell type
cell type
Follicular Follicular
Fibro General Fibro General
Acinar Acinar
T Lymphocyte 1 (CD8+) T Lymphocyte 1 (CD8+)
T lymphocyte 2 (CD4+) T lymphocyte 2 (CD4+)

The contents of this anndata object are binary values (0 or 1). 1 indicates accessibility of the peak in the cell type.

ad.X[:5, :5].todense()
matrix([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]], dtype=float32)

Filter peaks#

We will perform filtering of this dataset using the grelu.data.preprocess module.

First, we filter peaks within autosomes (chromosomes 1-22) or chromsomes X/Y. You can also supply autosomes or autosomesX to further restrict the chromosomes.

import grelu.data.preprocess

ad = grelu.data.preprocess.filter_chromosomes(ad, 'autosomesXY')
Keeping 1154464 intervals

Next, we drop peaks overlapping with the ENCODE blacklist regions for the hg38 genome.

ad = grelu.data.preprocess.filter_blacklist(ad, genome='hg38')
Keeping 1154464 intervals

Visualize data#

Next, we can plot the distribution of the data in various ways.

import grelu.visualize
%matplotlib inline

In how many accessible cell types is each peak accessible?

cell_types_per_peak = np.array(np.sum(ad.X > 0, axis=0))

grelu.visualize.plot_distribution(
    cell_types_per_peak,
    title='number of cell types per peak',
    method='density', # Alternative: histogram
    figsize=(3.6, 2), # width, height
)
../_images/243c0a551b4b0165d174b31fee579443be1be35240869a11e7865b49bba99ec3.png

How many peaks are accessible in each cell type?

fraction_accessible_per_cell_type = np.array(np.mean(ad.X > 0, axis=1))

grelu.visualize.plot_distribution(
    fraction_accessible_per_cell_type,
    title='Fraction of peaks accessible per cell type',
    method='histogram', # alternative: density
    binwidth=0.005,
    figsize=(3.6, 2), # width, height
)
../_images/8639fe25a4626315fe54993de8302bada1baa249149732aa38ffeac75f0ffad7.png

It seems that some cell types have very few accessible peaks. We will drop these cell types from the dataset.

print(ad.shape)
ad = ad[ad.X.mean(axis=1) > .03, :]
print(ad.shape)
(222, 1154464)
(203, 1154464)

We can plot the distribution once again to see the effect of this filtering:

fraction_accessible_per_cell_type = np.array(np.mean(ad.X > 0, axis=1))

grelu.visualize.plot_distribution(
    fraction_accessible_per_cell_type,
    title='Fraction of peaks accessible per cell type',
    method='histogram', # alternative: density
    binwidth=0.005,
    figsize=(3.6, 2), # width, height
)
../_images/df8df9bd13d7fc112998f661dabb765bd64a00de6e8572ba08fd73fc7d87840f.png

Resize peaks#

Finally, since the ATAC-seq peaks can have different lengths, we resize all of them to a constant length in order to train the model. Here, we take 200 bp. We can use the resize function in grelu.sequence.utils, which contains functions to manipulate DNA sequences.

import grelu.sequence.utils
seq_len = 200

ad.var = grelu.sequence.utils.resize(ad.var, seq_len)
ad.var.head(3)
chrom start end cre_class in_fetal in_adult cre_module
0 chr1 10055 10255 Promoter Proximal yes yes 146
1 chr1 29263 29463 Promoter yes yes 37
2 chr1 79315 79515 Distal no yes 75

Split data#

We will split the peaks by chromosome to create separate sets for training, validation and testing.

train_chroms='autosomes'
val_chroms=['chr10']
test_chroms=['chr11']

ad_train, ad_val, ad_test = grelu.data.preprocess.split(
    ad,
    train_chroms=train_chroms, val_chroms=val_chroms, test_chroms=test_chroms,
)
Selecting training samples
Keeping 1007912 intervals


Selecting validation samples
Keeping 56974 intervals


Selecting test samples
Keeping 56433 intervals
Final sizes: train: (203, 1007912), val: (203, 56974), test: (203, 56433)

Make labeled sequence datasets#

grelu.data.dataset contains PyTorch Dataset classes that can load and process genomic data for training a deep learning model. Here, we use the AnnDataSeqDataset class which loads data from an AnnData object. Other available dataset classes include DFSeqDataset and BigWigSeqDataset.

We first make the training dataset. To increase model robustness we use several forms of data augmentation here: rc=True (reverse complementing the input sequence), and max_seq_shift=1 (shifting the coordinates of the input sequence by upto 1 bp in either direction; also known as jitter). We use augmentation_mode="random" which means that at each iteration, the model sees a randomly selected augmented version for each sequence.

import grelu.data.dataset

train_dataset = grelu.data.dataset.AnnDataSeqDataset(
    ad_train.copy(),
    genome='hg38',
    rc=True, # reverse complement
    max_seq_shift=1, # Shift the sequence
    augment_mode="random", # Randomly select which augmentations to apply
)

We do not apply any augmentations to the validation and test datasets (although it is possible to do so).

val_dataset = grelu.data.dataset.AnnDataSeqDataset(ad_val.copy(), genome='hg38')
test_dataset = grelu.data.dataset.AnnDataSeqDataset(ad_test.copy(), genome='hg38')

Build the enformer model#

gReLU contains many model architectures. One of these is a class called EnformerPretrainedModel. This class creates a model identical to the published Enformer model and initialized with the trained weights, but where you can change the number of transformer layers and the output head.

Models are created using the grelu.lightning module. In order to instantiate a model, we need to specify model_params (parameters for the model architecture) and train_params (parameters for training).

model_params = {
    'model_type':'EnformerPretrainedModel', # Type of model
    'n_tasks': ad.shape[0], # Number of cell types to predict
    'crop_len':0, # No cropping of the model output
    'n_transformers': 1, # Number of transformer layers; the published Enformer model has 11
}

train_params = {
    'task':'binary', # binary classification
    'lr':1e-4, # learning rate
    'logger': 'csv', # Logs will be written to a CSV file
    'batch_size': 1024,
    'num_workers': 8,
    'devices': 0, # GPU index
    'save_dir': experiment,
    'optimizer': 'adam',
    'max_epochs': 10,
    'checkpoint': True, # Save checkpoints
}

import grelu.lightning
model = grelu.lightning.LightningModel(model_params=model_params, train_params=train_params)

Train the model#

We train the model on the new training dataset using the train_on_dataset method. Note that here, we update the weights of the entire model during training. If you want to hold the enformer weights fixed and only learn a linear layer from the enformer embedding to the outputs, see the tune_on_dataset method.

# See the 'tutorial_2' folder for logs
trainer = model.train_on_dataset(
    train_dataset=train_dataset,
    val_dataset=val_dataset
)
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:06<00:00,  8.18it/s]
Epoch 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.73it/s, v_num=0, train_loss_step=0.163]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.57it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:03, 13.61it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.61it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.53it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.54it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.50it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.51it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.51it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.52it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.52it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.50it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.51it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.51it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.50it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:02, 13.50it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.49it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.50it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.51it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.59it/s]
Epoch 1: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.72it/s, v_num=0, train_loss_step=0.170, train_loss_epoch=0.170]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.55it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:03, 13.55it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.42it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.46it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.47it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.50it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.53it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.50it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.51it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.51it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.53it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.54it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.54it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.54it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.51it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.50it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.49it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.49it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:04<00:00, 13.50it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.50it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.58it/s]
Epoch 2: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:29<00:00,  4.71it/s, v_num=0, train_loss_step=0.162, train_loss_epoch=0.151]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.63it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:03, 13.60it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.61it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.65it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.66it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.65it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.61it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.61it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.55it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.57it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.58it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.58it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.59it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.57it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.58it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.57it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.57it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.57it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.57it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.57it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.56it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.56it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.56it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.55it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.55it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:01<00:02, 13.55it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.55it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:01, 13.55it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.55it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.54it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.53it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.53it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.53it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.54it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.54it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.53it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.61it/s]
Epoch 3: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.71it/s, v_num=0, train_loss_step=0.142, train_loss_epoch=0.148]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.49it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:03, 13.57it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.54it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.48it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.42it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.43it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.45it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.47it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.51it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.50it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.48it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.49it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.50it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.50it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.50it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.50it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.50it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.50it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:01<00:02, 13.51it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.50it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.52it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.52it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.52it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.60it/s]
Epoch 4: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.72it/s, v_num=0, train_loss_step=0.139, train_loss_epoch=0.146]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.37it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:04, 13.47it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.39it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.45it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.44it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.46it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.45it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.47it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.44it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.44it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.42it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.43it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.42it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.43it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.44it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.43it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.43it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.45it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.45it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.44it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:02<00:02, 13.44it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.44it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:02, 13.45it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.46it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.46it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.46it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.48it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.48it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.48it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.48it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:04<00:00, 13.48it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.48it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.56it/s]
Epoch 5: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.72it/s, v_num=0, train_loss_step=0.147, train_loss_epoch=0.145]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.32it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:04, 13.08it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.25it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.34it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.34it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.36it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.41it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.39it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.36it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.33it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.36it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.37it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.38it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.38it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.40it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.40it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.40it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.40it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.41it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.40it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.40it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.41it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.41it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.42it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.42it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.43it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:02<00:02, 13.43it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.43it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:02, 13.43it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.44it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.45it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.46it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.46it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.46it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.46it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.46it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.46it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.47it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.48it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.48it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:04<00:00, 13.48it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.48it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.55it/s]
Epoch 6: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.72it/s, v_num=0, train_loss_step=0.142, train_loss_epoch=0.144]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:05, 10.87it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:04, 11.18it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:04, 11.33it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:04, 11.55it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:04, 11.69it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:04, 11.93it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:04, 12.11it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 12.29it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 12.36it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 12.48it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 12.58it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 12.66it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:01<00:03, 12.71it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 12.77it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 12.80it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:03, 12.85it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:03, 12.88it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 12.92it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 12.95it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 12.97it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 12.99it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.01it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.03it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.04it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.06it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.08it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:02<00:02, 13.09it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.10it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:02, 13.11it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:02, 12.99it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.00it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.01it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.03it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.04it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.06it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.07it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.08it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.09it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.10it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:03<00:01, 13.11it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.12it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.13it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.14it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.15it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.16it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.16it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.17it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.18it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.19it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.19it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.20it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.21it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:04<00:00, 13.21it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:04<00:00, 13.22it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.22it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.30it/s]
Epoch 7: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.71it/s, v_num=0, train_loss_step=0.138, train_loss_epoch=0.143]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.62it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:03, 13.63it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.64it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.52it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.54it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.58it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.54it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.53it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.54it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.54it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.54it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.55it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.54it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.54it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.54it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.53it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:01<00:02, 13.52it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.52it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.52it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.51it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.51it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.51it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:03<00:00, 13.53it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.53it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.61it/s]
Epoch 8: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.71it/s, v_num=0, train_loss_step=0.158, train_loss_epoch=0.142]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 12.59it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:04, 13.09it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:04, 13.07it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:04, 12.94it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.06it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.17it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.23it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.26it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.26it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.30it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.28it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.29it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.29it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.30it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.32it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:03, 13.31it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.32it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.33it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.34it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.35it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.35it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.35it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.36it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.36it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.36it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.36it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:02<00:02, 13.36it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.36it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:02, 13.36it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.36it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.35it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.35it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.36it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.37it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.37it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.38it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.38it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.38it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.38it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.38it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.38it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.39it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.39it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.39it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.39it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.39it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.40it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:04<00:00, 13.40it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.40it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.48it/s]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:28<00:00,  4.72it/s, v_num=0, train_loss_step=0.151, train_loss_epoch=0.141]
Validation: |                                                                                                                                       | 0/? [00:00<?, ?it/s]
Validation:   0%|                                                                                                                                  | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   0%|                                                                                                                     | 0/56 [00:00<?, ?it/s]
Validation DataLoader 0:   2%|█▉                                                                                                           | 1/56 [00:00<00:04, 13.52it/s]
Validation DataLoader 0:   4%|███▉                                                                                                         | 2/56 [00:00<00:03, 13.56it/s]
Validation DataLoader 0:   5%|█████▊                                                                                                       | 3/56 [00:00<00:03, 13.59it/s]
Validation DataLoader 0:   7%|███████▊                                                                                                     | 4/56 [00:00<00:03, 13.47it/s]
Validation DataLoader 0:   9%|█████████▋                                                                                                   | 5/56 [00:00<00:03, 13.48it/s]
Validation DataLoader 0:  11%|███████████▋                                                                                                 | 6/56 [00:00<00:03, 13.50it/s]
Validation DataLoader 0:  12%|█████████████▋                                                                                               | 7/56 [00:00<00:03, 13.48it/s]
Validation DataLoader 0:  14%|███████████████▌                                                                                             | 8/56 [00:00<00:03, 13.47it/s]
Validation DataLoader 0:  16%|█████████████████▌                                                                                           | 9/56 [00:00<00:03, 13.41it/s]
Validation DataLoader 0:  18%|███████████████████▎                                                                                        | 10/56 [00:00<00:03, 13.44it/s]
Validation DataLoader 0:  20%|█████████████████████▏                                                                                      | 11/56 [00:00<00:03, 13.46it/s]
Validation DataLoader 0:  21%|███████████████████████▏                                                                                    | 12/56 [00:00<00:03, 13.48it/s]
Validation DataLoader 0:  23%|█████████████████████████                                                                                   | 13/56 [00:00<00:03, 13.49it/s]
Validation DataLoader 0:  25%|███████████████████████████                                                                                 | 14/56 [00:01<00:03, 13.48it/s]
Validation DataLoader 0:  27%|████████████████████████████▉                                                                               | 15/56 [00:01<00:03, 13.48it/s]
Validation DataLoader 0:  29%|██████████████████████████████▊                                                                             | 16/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  30%|████████████████████████████████▊                                                                           | 17/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  32%|██████████████████████████████████▋                                                                         | 18/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  34%|████████████████████████████████████▋                                                                       | 19/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  36%|██████████████████████████████████████▌                                                                     | 20/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  38%|████████████████████████████████████████▌                                                                   | 21/56 [00:01<00:02, 13.47it/s]
Validation DataLoader 0:  39%|██████████████████████████████████████████▍                                                                 | 22/56 [00:01<00:02, 13.47it/s]
Validation DataLoader 0:  41%|████████████████████████████████████████████▎                                                               | 23/56 [00:01<00:02, 13.47it/s]
Validation DataLoader 0:  43%|██████████████████████████████████████████████▎                                                             | 24/56 [00:01<00:02, 13.47it/s]
Validation DataLoader 0:  45%|████████████████████████████████████████████████▏                                                           | 25/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  46%|██████████████████████████████████████████████████▏                                                         | 26/56 [00:01<00:02, 13.48it/s]
Validation DataLoader 0:  48%|████████████████████████████████████████████████████                                                        | 27/56 [00:02<00:02, 13.48it/s]
Validation DataLoader 0:  50%|██████████████████████████████████████████████████████                                                      | 28/56 [00:02<00:02, 13.48it/s]
Validation DataLoader 0:  52%|███████████████████████████████████████████████████████▉                                                    | 29/56 [00:02<00:02, 13.48it/s]
Validation DataLoader 0:  54%|█████████████████████████████████████████████████████████▊                                                  | 30/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  55%|███████████████████████████████████████████████████████████▊                                                | 31/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  57%|█████████████████████████████████████████████████████████████▋                                              | 32/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  59%|███████████████████████████████████████████████████████████████▋                                            | 33/56 [00:02<00:01, 13.47it/s]
Validation DataLoader 0:  61%|█████████████████████████████████████████████████████████████████▌                                          | 34/56 [00:02<00:01, 13.47it/s]
Validation DataLoader 0:  62%|███████████████████████████████████████████████████████████████████▌                                        | 35/56 [00:02<00:01, 13.47it/s]
Validation DataLoader 0:  64%|█████████████████████████████████████████████████████████████████████▍                                      | 36/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  66%|███████████████████████████████████████████████████████████████████████▎                                    | 37/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  68%|█████████████████████████████████████████████████████████████████████████▎                                  | 38/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  70%|███████████████████████████████████████████████████████████████████████████▏                                | 39/56 [00:02<00:01, 13.48it/s]
Validation DataLoader 0:  71%|█████████████████████████████████████████████████████████████████████████████▏                              | 40/56 [00:02<00:01, 13.49it/s]
Validation DataLoader 0:  73%|███████████████████████████████████████████████████████████████████████████████                             | 41/56 [00:03<00:01, 13.49it/s]
Validation DataLoader 0:  75%|█████████████████████████████████████████████████████████████████████████████████                           | 42/56 [00:03<00:01, 13.49it/s]
Validation DataLoader 0:  77%|██████████████████████████████████████████████████████████████████████████████████▉                         | 43/56 [00:03<00:00, 13.49it/s]
Validation DataLoader 0:  79%|████████████████████████████████████████████████████████████████████████████████████▊                       | 44/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  80%|██████████████████████████████████████████████████████████████████████████████████████▊                     | 45/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  82%|████████████████████████████████████████████████████████████████████████████████████████▋                   | 46/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  84%|██████████████████████████████████████████████████████████████████████████████████████████▋                 | 47/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  86%|████████████████████████████████████████████████████████████████████████████████████████████▌               | 48/56 [00:03<00:00, 13.50it/s]
Validation DataLoader 0:  88%|██████████████████████████████████████████████████████████████████████████████████████████████▌             | 49/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  89%|████████████████████████████████████████████████████████████████████████████████████████████████▍           | 50/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  91%|██████████████████████████████████████████████████████████████████████████████████████████████████▎         | 51/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  93%|████████████████████████████████████████████████████████████████████████████████████████████████████▎       | 52/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  95%|██████████████████████████████████████████████████████████████████████████████████████████████████████▏     | 53/56 [00:03<00:00, 13.51it/s]
Validation DataLoader 0:  96%|████████████████████████████████████████████████████████████████████████████████████████████████████████▏   | 54/56 [00:03<00:00, 13.52it/s]
Validation DataLoader 0:  98%|██████████████████████████████████████████████████████████████████████████████████████████████████████████  | 55/56 [00:04<00:00, 13.52it/s]
Validation DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:04<00:00, 13.60it/s]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 985/985 [03:56<00:00,  4.16it/s, v_num=0, train_loss_step=0.151, train_loss_epoch=0.140]
Epoch 9: 100%|██████████████████████████████████████████████████████████████████| 985/985 [04:02<00:00,  4.07it/s, v_num=0, train_loss_step=0.151, train_loss_epoch=0.140]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃      Validate metric             DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       val_accuracy            0.5015555024147034     │
│         val_auroc             0.5114357471466064     │
│        val_avgprec             0.110662542283535     │
│        val_best_f1            0.17474424839019775    │
│         val_loss              0.6944765448570251     │
└───────────────────────────┴───────────────────────────┘

Load best model from checkpoint#

During training, the performance of the model on the validation set is checked after each epoch. We will load the version of the model that had the best validation set performance.

best_checkpoint = trainer.checkpoint_callback.best_model_path
print(best_checkpoint)
model = grelu.lightning.LightningModel.load_from_checkpoint(best_checkpoint)
tutorial_2/2024_28_05_05_46/version_0/checkpoints/epoch=9-step=9850.ckpt

Calculate performance metrics on the test set#

We calculate global performance metrics using the test_on_dataset method.

test_metrics = model.test_on_dataset(
    test_dataset,
    devices=0,
    num_workers=8,
    batch_size=1024,
)
Testing DataLoader 0: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:06<00:00,  9.26it/s]
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃        Test metric               DataLoader 0        ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩
│       test_accuracy           0.9454869627952576     │
│        test_auroc             0.9040013551712036     │
│       test_avgprec            0.6041162610054016     │
│       test_best_f1            0.5663296580314636     │
│         test_loss             0.15561392903327942    │
└───────────────────────────┴───────────────────────────┘

test_metrics is a dataframe containing metrics for each model task.

test_metrics.head()
test_accuracy test_auroc test_avgprec test_best_f1
Follicular 0.908617 0.879266 0.626482 0.578347
Fibro General 0.915900 0.881493 0.621228 0.574634
Acinar 0.943260 0.913378 0.644469 0.599259
T Lymphocyte 1 (CD8+) 0.961228 0.930512 0.632422 0.591905
T lymphocyte 2 (CD4+) 0.966917 0.939098 0.618306 0.571317

Visualize performance metrics#

We can plot the distribution of each metric across all cell types:

grelu.visualize.plot_distribution(
    test_metrics.test_best_f1,
    method='histogram',
    title='Best F1 Score per task',
    binwidth=0.005,
    figsize=(3,2),
)
../_images/1586b59295f0fa9085c1191aa61d06d881389c19bd25b3916dacaa7f53ac37f9.png
grelu.visualize.plot_distribution(
    test_metrics.test_avgprec,
    method='histogram',
    title='Average Precision per task',
    binwidth=0.005,
    figsize=(3,2),
)
../_images/a6e18102767da3d1a3a3e494083826b37d95959075649b046d672d1983e090ed.png
grelu.visualize.plot_distribution(
    test_metrics.test_auroc,
    method='histogram',
    title='AUROC per task',
    binwidth=0.005,
    figsize=(3,2),
)
../_images/c443c7e0c686335938a77190631d52d4d14bbf35a2780614bc5e47c46fbd99d1.png

Run inference on the test set#

Instead of overall metrics, we can also get the individual predictions for each test set example.

probs = model.predict_on_dataset(
    test_dataset,
    devices=0,
    num_workers=8,
    batch_size=1024,
    return_df=True # Return the output as a pandas dataframe
)

probs.head()
Predicting DataLoader 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████| 56/56 [00:03<00:00, 14.23it/s]
Follicular Fibro General Acinar T Lymphocyte 1 (CD8+) T lymphocyte 2 (CD4+) Natural Killer T Naive T Fibro Epithelial Cardiac Pericyte 1 Pericyte General 1 ... Fetal Cardiac Fibroblast Fetal Fibro General 2 Fetal Fibro Muscle 1 Fetal Fibro General 3 Fetal Mesangial 2 Fetal Stellate Fetal Alveolar Epithelial 1 Fetal Cilliated Fetal Excitatory Neuron 1 Fetal Excitatory Neuron 2
0 0.391466 0.736395 0.667206 0.835141 0.652502 0.554888 0.590892 0.710168 0.323553 0.747944 ... 0.080258 0.117112 0.128430 0.094408 0.126652 0.017271 0.113082 0.067970 0.145862 0.055562
1 0.012470 0.029601 0.002608 0.090092 0.132307 0.053100 0.032421 0.008285 0.046283 0.003837 ... 0.070223 0.019960 0.011101 0.017780 0.007612 0.019467 0.024475 0.010269 0.029854 0.013195
2 0.013281 0.017351 0.006397 0.045844 0.026573 0.054169 0.011686 0.026472 0.093590 0.009771 ... 0.031002 0.012968 0.014757 0.025925 0.028630 0.009786 0.020447 0.011602 0.294897 0.037187
3 0.821954 0.810962 0.875161 0.829395 0.839307 0.715744 0.766881 0.691507 0.757547 0.603458 ... 0.784233 0.813539 0.717051 0.807969 0.775087 0.652014 0.875736 0.851576 0.727039 0.445682
4 0.237077 0.326297 0.346979 0.439788 0.319168 0.176935 0.330699 0.198229 0.273167 0.124448 ... 0.149903 0.158350 0.106138 0.195934 0.105235 0.103239 0.310690 0.277557 0.134348 0.038449

5 rows × 203 columns

Since this is a binary classification model, the output takes the form of probabilities ranging from 0 to 1. We interpret these as the predicted probabilities of the element being accessible in each cell type.

Plot additional visualizations of the test set predictions#

We can plot a calibration curve for all the tasks. This shows us the fraction of true positive examples, for different levels of model-predicted probability.

grelu.visualize.plot_calibration_curve(
    probs, labels=test_dataset.labels, aggregate=False, show_legend=False
)
../_images/a394de89206b1d8d796e0c931bd4621e40efcc93664a2ac1bd9a5c6aeb4d0717.png

We can also pick any cell type and compare the predictions on accessible and non-accessible elements.

grelu.visualize.plot_binary_preds(
    probs,
    labels=test_dataset.labels,
    tasks='Microglia',
    figure_size=(3, 2) # Width, height
)
../_images/68f227cc3b4f79102b8902ada1717bdbc1390438db5cc9535d6bf0aff2e22aae.png

Interpret model predictions (for microglia) using TF-modisco#

Suppose we want to focus specifically on Microglia. We can create a transform - a class that takes in the model’s prediction and returns a function of the prediction, e.g. the prediction for only a subset of cell types that we are interested in. Then, all subsquent analyses that we do will be based only on this subset of the full prediction matrix.

Here, we use the Aggregate class to create a transform that will take in the full set of predictions produced by the model and subset only the predictions in Microglia.

from grelu.transforms.prediction_transforms import Aggregate

microglia_score = Aggregate(
    tasks = ["Microglia"],
    model = model,
)

Let us now identify all peaks in the test set that are accessible in microglia.

is_accessible_in_mcg = (ad_test["Microglia", :].X.A==1).squeeze()
mcg_peaks = ad_test.var[is_accessible_in_mcg]
len(mcg_peaks)
2273

We now run TF-Modisco on these peaks. TF-modisco identifies motifs that consistently contribute to the model’s output. Since we are using the microglia_score filter to limit the model’s prediction to microglia, we will only get motifs relevant to this cell type. We also use TOMTOM to match the TF-Modisco motifs to a set of reference motifs. Here, we use a reference set of non-redundant motifs (https://www.vierstra.org/resources/motif_clustering) that are provided with gReLU as consensus.

%%time
import grelu.interpret.score

grelu.interpret.score.run_modisco(
    model,
    seqs=mcg_peaks, 
    genome="hg38",
    prediction_transform=microglia_score, # Base importance scores will be calculated with respect to this output
    meme_file="consensus", # We will compare the Modisco CWMs to consensus motifs
    method="ism", # Base-level attribution scores will be calculated using ISM. You can also use "deepshap".
    out_dir=experiment,
    batch_size=1024,
    devices=0,
    window=100, # ISM scores will be calculated over the central 100 bp of each peak
    seed=0,
)
Performing ISM
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:00<00:00, 15.63it/s]
Predicting DataLoader 0: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 666/666 [00:47<00:00, 13.99it/s]
Running modisco
Writing modisco output
Making report
CPU times: user 1h 56min 38s, sys: 57.6 s, total: 1h 57min 35s
Wall time: 2min 19s

Load TOMTOM output for modisco motifs#

The full output of TF-Modisco and TOMTOM can be found in the experiment folder. Here, we read the output of TOMTOM and list the significant TOMTOM matches, i.e. known TF motifs that are similar to those found by TF-MoDISco.

import grelu.io
tomtom_dir = os.path.join(experiment, 'tomtom')

tomtom = grelu.io.read_tomtom(tomtom_dir, qthresh=.01) # Read only significant matches
tomtom
Query_ID Target_ID Optimal_offset p-value E-value q-value Overlap Query_consensus Target_consensus Orientation
0 pos_patterns.pattern_0 AC0227:SPI_BCL11A:Ets -1.0 1.487280e-06 9.473990e-04 1.642520e-03 10.0 AAAGAGGAAGT AAGAGGAAGTGA +
1 pos_patterns.pattern_0 AC0622:ELF_SPIB:Ets -1.0 2.578520e-06 1.642520e-03 1.642520e-03 10.0 AAAGAGGAAGT AAGAGGAAGT +
31 pos_patterns.pattern_1 AC0078:CTCF_CTCFL:C2H2_ZF 0.0 7.231420e-10 4.606410e-07 9.212820e-07 13.0 TGCCCTCTAGTGG CGCCCCCTGGTGG +
44 pos_patterns.pattern_3 AC0078:CTCF_CTCFL:C2H2_ZF -1.0 4.728010e-06 3.011740e-03 5.996730e-03 13.0 GCTCCCCCTCGTGG CGCCCCCTGGTGG +