Fine-tune Enformer to perform binary classification on human snATAC-seq data#

In this tutorial, we learn how to fine-tune the Enformer model (https://www.nature.com/articles/s41592-021-01252-x) on the CATLAS single-nucleus ATAC-seq dataset from human tissues (http://catlas.org/humanenhancer/).

We will perform binary classification, in which the model learns to predict the probability that a given sequence is accessible in the different cell types of the dataset. For an example of regression modeling, see the next tutorial.

import anndata
import os
import importlib
import pandas as pd
import numpy as np

%matplotlib inline

Set experiment parameters#

experiment='tutorial_2'
os.makedirs(experiment, exist_ok=True)  # create the output folder if it does not exist

Download data#

We download the CATLAS ATAC-seq binary cell type x peak matrix from the gReLU model zoo. For more details on downloading models and datasets from the zoo, see Tutorial 6. The original source of this data is https://decoder-genetics.wustl.edu/catlasv1/humanenhancer/data/cCRE_by_cell_type/. For more details, see https://decoder-genetics.wustl.edu/catlasv1/catlas_humanenhancer/#!/.

import grelu.resources
artifact = grelu.resources.get_artifact(
    name="dataset",
    project = 'tutorial-2',
)
artifact_dir = artifact.download()
os.listdir(artifact_dir)
['data.h5ad']

Load data#

As you can see above, we downloaded an h5ad file, data.h5ad, containing the binarized chromatin accessibility data for multiple human cell types. We now load this as an AnnData object.

ad = anndata.read_h5ad(os.path.join(artifact_dir, 'data.h5ad'))
ad
AnnData object with n_obs × n_vars = 222 × 1154611
    obs: 'cell type'
    var: 'chrom', 'start', 'end', 'Class', 'Present in fetal tissues', 'Present in adult tissues', 'CRE module', 'width'

This contains a binary matrix representing the accessibility of 1154611 CREs measured in 222 cell types. Let us look at the components of this object:

ad.var.head()
chrom start end Class Present in fetal tissues Present in adult tissues CRE module width
0 chr1 9955 10355 Promoter Proximal yes yes 146 400
1 chr1 29163 29563 Promoter yes yes 37 400
2 chr1 79215 79615 Distal no yes 75 400
3 chr1 102755 103155 Distal no yes 51 400
4 chr1 115530 115930 Distal yes no 36 400
ad.obs.head()
cell type
cell type
Follicular Follicular
Fibro General Fibro General
Acinar Acinar
T Lymphocyte 1 (CD8+) T Lymphocyte 1 (CD8+)
T lymphocyte 2 (CD4+) T lymphocyte 2 (CD4+)

The contents of this anndata object are binary values (0 or 1). 1 indicates accessibility of the peak in the cell type.

ad.X[:5, :5].todense()
matrix([[1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0.]], dtype=float32)

Filter peaks#

We will perform filtering of this dataset using the grelu.data.preprocess module.

First, we keep only peaks on autosomes (chromosomes 1-22) or chromosomes X/Y. You can also supply autosomes or autosomesX to restrict the chromosomes further.

import grelu.data.preprocess

ad = grelu.data.preprocess.filter_chromosomes(ad, 'autosomesXY')
Keeping 1154464 intervals
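
To make the chromosome filter concrete, here is a minimal pandas sketch of the same idea on a toy interval table. It is not gReLU's implementation: the helper keep_chromosomes and the toy coordinates are made up for illustration, and the only behavior assumed is that rows on other contigs are dropped.

```python
import pandas as pd

# Toy interval table; the coordinates are made up for illustration.
intervals = pd.DataFrame({
    "chrom": ["chr1", "chrX", "chrM", "chr1_KI270706v1_random", "chr22", "chrY"],
    "start": [100, 200, 300, 400, 500, 600],
    "end": [500, 600, 700, 800, 900, 1000],
})


def keep_chromosomes(df, include_sex=True):
    """Hypothetical helper: keep rows on chr1-chr22, optionally also chrX/chrY."""
    allowed = {f"chr{i}" for i in range(1, 23)}
    if include_sex:
        allowed |= {"chrX", "chrY"}
    return df[df["chrom"].isin(allowed)].reset_index(drop=True)


filtered = keep_chromosomes(intervals)
print(f"Keeping {len(filtered)} intervals")
```

In this toy table, the mitochondrial and unplaced-contig rows are removed, while chr1, chr22, chrX and chrY are kept.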

Next, we drop peaks overlapping with the ENCODE blacklist regions for the hg38 genome.

ad = grelu.data.preprocess.filter_blacklist(ad, genome='hg38')
Keeping 1154464 intervals
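
The overlap test behind a blacklist filter can be sketched as follows. This is an illustrative sketch under stated assumptions, not the code gReLU runs: drop_overlapping is a hypothetical helper, the intervals are toy values, and coordinates are treated as half-open [start, end).

```python
import pandas as pd

# Toy peaks and toy blacklist regions (not real ENCODE coordinates).
peaks = pd.DataFrame({
    "chrom": ["chr1", "chr1", "chr2"],
    "start": [100, 1000, 100],
    "end": [300, 1200, 300],
})
blacklist = pd.DataFrame({
    "chrom": ["chr1", "chr3"],
    "start": [250, 100],
    "end": [400, 300],
})


def drop_overlapping(peaks, blacklist):
    """Hypothetical helper: drop peaks that overlap any blacklist interval."""
    keep = []
    for peak in peaks.itertuples(index=False):
        same_chrom = blacklist[blacklist["chrom"] == peak.chrom]
        # Two half-open intervals overlap if each starts before the other ends.
        overlaps = (
            (same_chrom["start"] < peak.end) & (same_chrom["end"] > peak.start)
        ).any()
        keep.append(not overlaps)
    return peaks[keep].reset_index(drop=True)


clean = drop_overlapping(peaks, blacklist)
print(f"Keeping {len(clean)} intervals")
```

A row-by-row loop is fine for a toy table; for a million peaks, an interval-tree or sorted-sweep approach would be the more practical choice.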

Visualize data#

Next, we can plot the distribution of the data in various ways.

import grelu.visualize
%matplotlib inline

In how many cell types is each peak accessible?

cell_types_per_peak = np.array(np.sum(ad.X > 0, axis=0))

grelu.visualize.plot_distribution(
    cell_types_per_peak,
    title='number of cell types per peak',
    method='density', # Alternative: histogram
    figsize=(3.6, 2), # width, height
)
[Figure: density plot of the number of cell types per peak]

What fraction of peaks is accessible in each cell type?

fraction_accessible_per_cell_type = np.array(np.mean(ad.X > 0, axis=1))

grelu.visualize.plot_distribution(
    fraction_accessible_per_cell_type,
    title='Fraction of peaks accessible per cell type',
    method='histogram', # alternative: density
    binwidth=0.005,
    figsize=(3.6, 2), # width, height
)
[Figure: histogram of the fraction of peaks accessible per cell type]

It seems that some cell types have very few accessible peaks. We keep only the cell types in which more than 3% of peaks are accessible and drop the rest.

print(ad.shape)
ad = ad[ad.X.mean(axis=1) > .03, :]
print(ad.shape)
(222, 1154464)
(203, 1154464)
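
The two summary statistics and the row filter used above can be checked by hand on a small dense matrix. The sketch below uses plain NumPy and a toy threshold of 0.3 rather than the tutorial's 0.03; the matrix values are made up.

```python
import numpy as np

# Toy binary matrix: 4 cell types (rows) x 5 peaks (columns).
X = np.array([
    [1, 1, 0, 1, 0],
    [0, 0, 0, 0, 0],
    [1, 0, 0, 0, 0],
    [1, 1, 1, 1, 1],
])

# Number of cell types in which each peak is accessible (one value per column).
cell_types_per_peak = (X > 0).sum(axis=0)

# Fraction of peaks accessible in each cell type (one value per row).
fraction_per_cell_type = (X > 0).mean(axis=1)

# Keep only cell types above the threshold (toy value; the tutorial uses 0.03).
min_fraction = 0.3
keep = fraction_per_cell_type > min_fraction
X_filtered = X[keep, :]

print(cell_types_per_peak)
print(fraction_per_cell_type)
print(X.shape, "->", X_filtered.shape)
```

Here the second and third cell types fall below the toy threshold, so only two of the four rows remain.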

We can plot the distribution once again to see the effect of this filtering:

fraction_accessible_per_cell_type = np.array(np.mean(ad.X > 0, axis=1))

grelu.visualize.plot_distribution(
    fraction_accessible_per_cell_type,
    title='Fraction of peaks accessible per cell type',
    method='histogram', # alternative: density
    binwidth=0.005,
    figsize=(3.6, 2), # width, height
)
[Figure: histogram of the fraction of peaks accessible per cell type, after filtering]

Resize peaks#

Finally, since the ATAC-seq peaks can have different lengths, we resize all of them to a constant length in order to train the model. Here, we take 200 bp. We can use the resize function in grelu.sequence.utils, which contains functions to manipulate DNA sequences.

import grelu.sequence.utils
seq_len = 200

ad.var = grelu.sequence.utils.resize(ad.var, seq_len)
ad.var.head(3)
chrom start end Class Present in fetal tissues Present in adult tissues CRE module width
0 chr1 10055 10255 Promoter Proximal yes yes 146 400
1 chr1 29263 29463 Promoter yes yes 37 400
2 chr1 79315 79515 Distal no yes 75 400
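
The coordinates printed above are consistent with trimming each 400 bp peak to 200 bp around its midpoint. A minimal sketch of that arithmetic is shown here; resize_centered is a hypothetical helper written for this illustration, not the gReLU function, and it is only checked against the three rows shown in this tutorial.

```python
import pandas as pd


def resize_centered(df, seq_len):
    """Hypothetical helper: resize intervals to seq_len bp around their midpoint."""
    out = df.copy()
    centers = (out["start"] + out["end"]) // 2
    out["start"] = centers - seq_len // 2
    out["end"] = out["start"] + seq_len
    return out


# The first three peaks from ad.var before resizing.
peaks = pd.DataFrame({
    "chrom": ["chr1", "chr1", "chr1"],
    "start": [9955, 29163, 79215],
    "end": [10355, 29563, 79615],
})

resized = resize_centered(peaks, seq_len=200)
print(resized)
```

For the first peak, the midpoint is (9955 + 10355) // 2 = 10155, which gives the interval 10055 to 10255.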

Split data#

We will split the peaks by chromosome to create separate sets for training, validation and testing.

train_chroms='autosomes'
val_chroms=['chr10']
test_chroms=['chr11']

ad_train, ad_val, ad_test = grelu.data.preprocess.split(
    ad,
    train_chroms=train_chroms, val_chroms=val_chroms, test_chroms=test_chroms,
)
Selecting training samples
Keeping 1007912 intervals


Selecting validation samples
Keeping 56974 intervals


Selecting test samples
Keeping 56433 intervals
Final sizes: train: (203, 1007912), val: (203, 56974), test: (203, 56433)
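
The logic of a chromosome-based split can be sketched with pandas as below. This is an illustration under two assumptions that the logged counts appear to support but that are not taken from the gReLU source: the validation and test chromosomes are excluded from the training set, and 'autosomes' excludes chrX and chrY (the three sets above sum to 1121319 intervals, fewer than the 1154464 available). The helper split_by_chrom and the toy table are hypothetical.

```python
import pandas as pd


def split_by_chrom(df, val_chroms, test_chroms):
    """Hypothetical helper: split intervals into train/val/test by chromosome."""
    autosomes = {f"chr{i}" for i in range(1, 23)}
    is_val = df["chrom"].isin(val_chroms)
    is_test = df["chrom"].isin(test_chroms)
    # Assumption: training uses autosomes minus the held-out chromosomes.
    is_train = df["chrom"].isin(autosomes) & ~is_val & ~is_test
    return df[is_train], df[is_val], df[is_test]


# Toy peak table with one peak per row.
toy_peaks = pd.DataFrame({
    "chrom": ["chr1", "chr2", "chr10", "chr11", "chrX", "chr1"],
    "start": [0, 0, 0, 0, 0, 500],
})

train_df, val_df, test_df = split_by_chrom(toy_peaks, ["chr10"], ["chr11"])
print(len(train_df), len(val_df), len(test_df))
```

Splitting by chromosome rather than at random keeps neighboring, and therefore potentially correlated, peaks in the same set.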

Make labeled sequence datasets#

grelu.data.dataset contains PyTorch Dataset classes that can load and process genomic data for training a deep learning model. Here, we use the AnnDataSeqDataset class which loads data from an AnnData object. Other available dataset classes include DFSeqDataset and BigWigSeqDataset.

We first make the training dataset. To increase model robustness, we use two forms of data augmentation here: rc=True (reverse complementing the input sequence) and max_seq_shift=1 (shifting the coordinates of the input sequence by up to 1 bp in either direction; also known as jitter). We use augment_mode="random", which means that at each iteration, the model sees a randomly selected augmented version of each sequence.

import grelu.data.dataset
train_dataset = grelu.data.dataset.AnnDataSeqDataset(
    ad_train.copy(),
    genome='hg38',
    rc=True, # reverse complement
    max_seq_shift=1, # Shift the sequence
    augment_mode="random", # Randomly select which augmentations to apply
)
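
As a rough illustration of what these two augmentations do to a single sequence, consider the plain-Python sketch below. It is not how AnnDataSeqDataset is implemented: the function names, the uniform choice of shift, and the 50% chance of reverse complementing are all assumptions made for this example.

```python
import random

COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq):
    """Complement each base, then reverse the string."""
    return seq.translate(COMPLEMENT)[::-1]


def augment(chrom_seq, start, seq_len, max_seq_shift, rc, rng):
    """Hypothetical helper: return one randomly augmented window."""
    # Jitter: move the window by up to max_seq_shift bp in either direction.
    shift = rng.randint(-max_seq_shift, max_seq_shift)
    window = chrom_seq[start + shift : start + shift + seq_len]
    # Assumption: reverse complement with probability 0.5 when rc is enabled.
    if rc and rng.random() < 0.5:
        window = reverse_complement(window)
    return window


rng = random.Random(0)
toy_chrom = "GATTACAGGCTA"  # toy sequence, not real genomic DNA
samples = [augment(toy_chrom, 4, 4, 1, True, rng) for _ in range(5)]
print(samples)
```

With max_seq_shift=1 and rc=True, each sequence has up to 3 x 2 = 6 augmented versions, and one of them is drawn each time the sequence is requested.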

We do not apply any augmentations to the validation and test datasets (although it is possible to do so).

val_dataset = grelu.data.dataset.AnnDataSeqDataset(ad_val.copy(), genome='hg38')
test_dataset = grelu.data.dataset.AnnDataSeqDataset(ad_test.copy(), genome='hg38')

Build the Enformer model#

gReLU contains many model architectures. One of these is a class called EnformerPretrainedModel. This class creates a model identical to the published Enformer model and initialized with the trained weights, but where you can change the number of transformer layers and the output head.

Models are created using the grelu.lightning module. In order to instantiate a model, we need to specify model_params (parameters for the model architecture) and train_params (parameters for training).

model_params = {
    'model_type':'EnformerPretrainedModel', # Type of model
    'n_tasks': ad.shape[0], # Number of cell types to predict
    'crop_len':0, # No cropping of the model output
    'n_transformers': 1, # Number of transformer layers; the published Enformer model has 11
}

train_params = {
    'task':'binary', # binary classification
    'lr':1e-4, # learning rate
    'logger': 'csv', # Logs will be written to a CSV file
    'batch_size': 1024,
    'num_workers': 8,
    'devices': 0, # GPU index
    'save_dir': experiment,
    'optimizer': 'adam',
    'max_epochs': 10,
    'checkpoint': True, # Save checkpoints
}

import grelu.lightning
model = grelu.lightning.LightningModel(model_params=model_params, train_params=train_params)
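
For intuition about the 'binary' task, the NumPy sketch below computes binary cross-entropy from logits, averaged over every (sequence, cell type) entry. That this is the loss gReLU uses for task='binary' is an assumption here, and the function is a hypothetical illustration rather than library code. An uninformative head that outputs a logit of 0 (probability 0.5) for every entry scores ln 2, roughly 0.693, which is close to the val_loss of about 0.694 reported by the validation pass that runs before training.

```python
import numpy as np


def binary_cross_entropy_with_logits(logits, labels):
    """Mean binary cross-entropy, computed in a numerically stable form."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    # Stable form of -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))]:
    # max(x, 0) - x*y + log(1 + exp(-|x|))
    loss = (
        np.maximum(logits, 0)
        - logits * labels
        + np.log1p(np.exp(-np.abs(logits)))
    )
    return loss.mean()


# Toy labels: 2 sequences x 3 cell types.
labels = np.array([[1, 0, 0], [0, 0, 1]])

# Logit 0 everywhere corresponds to a predicted probability of 0.5.
chance_loss = binary_cross_entropy_with_logits(np.zeros(labels.shape), labels)

# Large logits with the correct sign correspond to confident, correct predictions.
confident_loss = binary_cross_entropy_with_logits(10.0 * (2 * labels - 1), labels)

print(chance_loss, confident_loss)
```

A loss well below ln 2 during training therefore indicates that the model is doing better than an uninformative predictor on this objective.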

Train the model#

We train the model on the new training dataset using the train_on_dataset method. Note that here, we update the weights of the entire model during training. If you want to hold the Enformer weights fixed and only learn a linear layer from the Enformer embedding to the outputs, see the tune_on_dataset method.

# See the 'tutorial_2' folder for logs
trainer = model.train_on_dataset(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
)
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:07<00:00,  7.67it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      val_accuracy           0.499550461769104
        val_auroc           0.5052121877670288
       val_avgprec          0.11139532178640366
       val_best_f1          0.17386437952518463
        val_loss            0.6942023038864136
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
Epoch 0: 100%|███████████| 985/985 [03:12<00:00,  5.11it/s, v_num=0, train_loss_step=0.151]
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.38it/s]
Epoch 1: 100%|█| 985/985 [03:14<00:00,  5.07it/s, v_num=0, train_loss_step=0.142, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.50it/s]
Epoch 2: 100%|█| 985/985 [03:13<00:00,  5.08it/s, v_num=0, train_loss_step=0.144, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.46it/s]
Epoch 3: 100%|█| 985/985 [03:14<00:00,  5.07it/s, v_num=0, train_loss_step=0.146, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.48it/s]
Epoch 4: 100%|█| 985/985 [03:14<00:00,  5.07it/s, v_num=0, train_loss_step=0.148, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.48it/s]
Epoch 5: 100%|█| 985/985 [03:13<00:00,  5.08it/s, v_num=0, train_loss_step=0.137, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.52it/s]
Epoch 6: 100%|█| 985/985 [03:13<00:00,  5.08it/s, v_num=0, train_loss_step=0.146, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.52it/s]
Epoch 7: 100%|█| 985/985 [03:13<00:00,  5.08it/s, v_num=0, train_loss_step=0.134, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.54it/s]
Epoch 8: 100%|█| 985/985 [03:13<00:00,  5.08it/s, v_num=0, train_loss_step=0.138, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.41it/s]
Epoch 9: 100%|█| 985/985 [03:14<00:00,  5.08it/s, v_num=0, train_loss_step=0.129, train_los
Validation DataLoader 0: 100%|█████████████████████████████| 56/56 [00:04<00:00, 13.49it/s]
Epoch 9: 100%|█| 985/985 [03:26<00:00,  4.77it/s, v_num=0, train_loss_step=0.129, train_los

Load best model from checkpoint#

During training, the performance of the model on the validation set is checked after each epoch. We will load the version of the model that had the best validation set performance.

best_checkpoint = trainer.checkpoint_callback.best_model_path
print(best_checkpoint)
model = grelu.lightning.LightningModel.load_from_checkpoint(best_checkpoint)
tutorial_2/2025_26_06_23_37/version_0/checkpoints/epoch=7-step=7880.ckpt
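
The checkpoint file name records when the best model was saved. As a small sketch (the helper `parse_checkpoint_name` and its regular expression are illustrative and not part of gReLU), we can parse the epoch and step from the path printed above. With 985 training batches per epoch, as shown in the progress bar, step 7880 corresponds to eight completed epochs, which matches the zero-based label epoch=7.

```python
import re


def parse_checkpoint_name(path):
    """Extract the zero-based epoch and the global step from a checkpoint path.

    Hypothetical helper: assumes the 'epoch=<n>-step=<m>.ckpt' naming seen above.
    """
    match = re.search(r"epoch=(\d+)-step=(\d+)\.ckpt$", path)
    if match is None:
        raise ValueError(f"Unrecognized checkpoint name: {path}")
    return int(match.group(1)), int(match.group(2))


path = "tutorial_2/2025_26_06_23_37/version_0/checkpoints/epoch=7-step=7880.ckpt"
epoch, step = parse_checkpoint_name(path)

steps_per_epoch = 985  # training batches per epoch, read from the progress bar
completed_epochs = step // steps_per_epoch

print(epoch, step, completed_epochs)
```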

Calculate performance metrics on the test set#

We calculate global performance metrics using the test_on_dataset method.

test_metrics = model.test_on_dataset(
    test_dataset,
    devices=0,
    num_workers=8,
    batch_size=1024,
)
Testing DataLoader 0: 100%|████████████████████████████████| 56/56 [00:07<00:00,  7.82it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
      test_accuracy         0.9458227157592773
       test_auroc           0.9040488600730896
      test_avgprec           0.605280876159668
      test_best_f1          0.5670120120048523
        test_loss           0.15482398867607117
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

test_metrics is a dataframe containing metrics for each model task.

test_metrics.head()
test_accuracy test_auroc test_avgprec test_best_f1
Follicular 0.908741 0.879009 0.627896 0.580418
Fibro General 0.916751 0.882644 0.624048 0.575924
Acinar 0.944270 0.913380 0.649751 0.601112
T Lymphocyte 1 (CD8+) 0.961264 0.929714 0.631676 0.593047
T lymphocyte 2 (CD4+) 0.966828 0.938453 0.616687 0.570218
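
Accuracy is much higher here than average precision or best F1. One plausible explanation is class imbalance: if accessible elements are a minority in each cell type, a model can reach high accuracy while still missing many positives. The sketch below uses synthetic labels with 5% positives (an assumed figure, not measured from this dataset) to show that a trivial classifier predicting "not accessible" everywhere already reaches 95% accuracy with zero recall.

```python
import numpy as np

n = 10_000
labels = np.zeros(n, dtype=int)
labels[:500] = 1  # assumption: exactly 5% of elements are accessible

preds = np.zeros(n, dtype=int)  # trivial classifier: always predict "not accessible"

accuracy = (preds == labels).mean()
true_positives = np.sum((preds == 1) & (labels == 1))
recall = true_positives / labels.sum()

print(f"accuracy={accuracy:.2f}, recall={recall:.2f}")
```

This is why it helps to look at AUROC, average precision and F1 alongside accuracy when positives are rare.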

Visualize performance metrics#

We can plot the distribution of each metric across all cell types:

grelu.visualize.plot_distribution(
    test_metrics.test_best_f1,
    method='histogram',
    title='Best F1 Score per task',
    binwidth=0.005,
    figsize=(3,2),
)
../_images/59fc0f08c28e7d596a467550910e3ae999122f2b9ac22af9e06a29f3be5456d8.png
grelu.visualize.plot_distribution(
    test_metrics.test_avgprec,
    method='histogram',
    title='Average Precision per task',
    binwidth=0.005,
    figsize=(3,2),
)
../_images/a7af3bd1e737fd0685636c76085c7672a618f10aa995b4f5d02055b2e3ee683c.png
grelu.visualize.plot_distribution(
    test_metrics.test_auroc,
    method='histogram',
    title='AUROC per task',
    binwidth=0.005,
    figsize=(3,2),
)
../_images/499f7fd514fe056ab4ba1dbecc51db35902cf428b24f08a07ebbd8135685934b.png

Run inference on the test set#

Instead of overall metrics, we can also get the individual predictions for each test set example.

probs = model.predict_on_dataset(
    test_dataset,
    devices=0,
    num_workers=8,
    batch_size=1024,
    return_df=True # Return the output as a pandas dataframe
)

probs.head()
Predicting DataLoader 0: 100%|█████████████████████████████| 56/56 [00:03<00:00, 14.31it/s]
Follicular Fibro General Acinar T Lymphocyte 1 (CD8+) T lymphocyte 2 (CD4+) Natural Killer T Naive T Fibro Epithelial Cardiac Pericyte 1 Pericyte General 1 ... Fetal Cardiac Fibroblast Fetal Fibro General 2 Fetal Fibro Muscle 1 Fetal Fibro General 3 Fetal Mesangial 2 Fetal Stellate Fetal Alveolar Epithelial 1 Fetal Cilliated Fetal Excitatory Neuron 1 Fetal Excitatory Neuron 2
0 0.339327 0.677568 0.594792 0.690015 0.485817 0.296680 0.412475 0.608682 0.313247 0.704991 ... 0.050871 0.067126 0.103167 0.081013 0.114244 0.023298 0.083920 0.041261 0.194604 0.082871
1 0.010077 0.026092 0.003200 0.134437 0.152224 0.066077 0.038752 0.009628 0.054473 0.007203 ... 0.069931 0.027279 0.020134 0.025630 0.011467 0.014433 0.022086 0.014063 0.032595 0.012260
2 0.018396 0.030480 0.009119 0.026245 0.035581 0.041032 0.014913 0.019749 0.132521 0.010570 ... 0.047521 0.014308 0.018792 0.035934 0.019014 0.007841 0.017029 0.007948 0.333166 0.026401
3 0.598535 0.685552 0.684242 0.665471 0.679174 0.539423 0.570127 0.532968 0.625712 0.516697 ... 0.630569 0.611157 0.587350 0.698837 0.594404 0.536245 0.757952 0.664568 0.593913 0.404532
4 0.229919 0.398327 0.320435 0.488460 0.401142 0.289747 0.350900 0.268153 0.273715 0.192955 ... 0.148155 0.193750 0.161295 0.207824 0.115282 0.110903 0.351928 0.277693 0.132093 0.061690

5 rows × 203 columns

Since this is a binary classification model, the output takes the form of probabilities ranging from 0 to 1. We interpret these as the predicted probabilities of the element being accessible in each cell type.
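
If discrete calls are needed, the probabilities can be thresholded. The following sketch rebuilds a small dataframe from the first three columns of the output above and applies a cutoff of 0.5. The cutoff is an assumption chosen for illustration; in practice it could be tuned per task.

```python
import pandas as pd

# First five rows and three columns of the predictions shown above
probs_example = pd.DataFrame(
    {
        "Follicular": [0.339327, 0.010077, 0.018396, 0.598535, 0.229919],
        "Fibro General": [0.677568, 0.026092, 0.030480, 0.685552, 0.398327],
        "Acinar": [0.594792, 0.003200, 0.009119, 0.684242, 0.320435],
    }
)

threshold = 0.5  # illustrative cutoff, not a value recommended by the tutorial
calls = (probs_example > threshold).astype(int)

# Number of these three cell types in which each sequence is called accessible
n_accessible = calls.sum(axis=1)
print(calls)
print(n_accessible.tolist())
```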

Plot additional visualizations of the test set predictions#

We can plot a calibration curve for all the tasks. This shows the fraction of examples that are truly positive at each level of model-predicted probability.

grelu.visualize.plot_calibration_curve(
    probs, labels=test_dataset.labels, aggregate=False, show_legend=False
)
../_images/47fa9e637485d440ed5b17ded86787fe37ec4ffe8755c77e9a9cb137fc7c17d6.png
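
To make the idea behind a calibration curve concrete, here is a minimal NumPy sketch for a single task. It bins the predicted probabilities and compares the mean prediction in each bin to the observed fraction of positives. The function name, the equal-width binning and the synthetic data are all assumptions for illustration; this is not necessarily how `plot_calibration_curve` computes its curves.

```python
import numpy as np


def calibration_curve(probs, labels, n_bins=10):
    """Return the mean predicted probability and observed positive fraction per bin."""
    probs = np.asarray(probs, dtype=float)
    labels = np.asarray(labels, dtype=float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    # Assign each prediction to one of n_bins equal-width bins on [0, 1]
    bin_ids = np.clip(np.digitize(probs, edges[1:-1]), 0, n_bins - 1)
    mean_pred, frac_pos = [], []
    for b in range(n_bins):
        in_bin = bin_ids == b
        if in_bin.any():  # skip empty bins
            mean_pred.append(probs[in_bin].mean())
            frac_pos.append(labels[in_bin].mean())
    return np.array(mean_pred), np.array(frac_pos)


# Synthetic data that is well calibrated by construction
rng = np.random.default_rng(0)
probs = rng.uniform(size=50_000)
labels = (rng.uniform(size=50_000) < probs).astype(int)

mean_pred, frac_pos = calibration_curve(probs, labels)
max_gap = np.abs(mean_pred - frac_pos).max()
print(np.round(mean_pred, 2))
print(np.round(frac_pos, 2))
```

For a well-calibrated model the two arrays are close to each other, so the curve lies near the diagonal.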

We can also pick any cell type and compare the predictions on accessible and non-accessible elements.

grelu.visualize.plot_binary_preds(
    probs,
    labels=test_dataset.labels,
    tasks='Microglia',
    figure_size=(3, 2) # Width, height
)
../_images/cc05e6b5113c2c7e7989bc03592d99f25be1a275c883023859e26b09c65f3401.png

Interpret model predictions (for microglia) using TF-modisco#

Suppose we want to focus specifically on microglia. We can create a transform: a class that takes in the model’s predictions and returns a function of them, e.g. the predictions for only the subset of cell types that we are interested in. All subsequent analyses will then be based only on this subset of the full prediction matrix.

Transforms can also compute more complicated functions of the model’s output. For example, the Specificity class of transforms takes in the full set of predictions produced by the model and computes the predicted specificity of each sequence to a particular cell type or output track.

In this example, we use the Specificity transform to analyze the predicted cell type specificity of sequences in microglia - i.e. the predicted accessibility in microglia minus the mean predicted accessibility in other cell types. See the Aggregate class if you want to just subset the predictions to microglia.

from grelu.transforms.prediction_transforms import Specificity

microglia_scorer = Specificity(
    on_tasks = ["Microglia"],
    model = model,
    compare_func='subtract' # Subtract the prediction in other cell types from the prediction in microglia
)

The simplest way to use a transform is to apply it directly to the model’s predictions using the .compute() method. Here, we apply this microglia specificity transform to the test set predictions:

# Transforms expect a 3-D input - add an axis 2 to the predictions.
mcg_specificity = microglia_scorer.compute(np.expand_dims(probs, 2)).squeeze()

# This is the predicted specificity for each test sequence in microglia
mcg_specificity[:5]
array([-0.07410772, -0.0159848 ,  0.02259409, -0.23796564, -0.18606862],
      dtype=float32)
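
As a rough illustration of the quantity described above (the prediction in the selected cell type minus the mean prediction in the other cell types), here is a NumPy sketch on a toy prediction matrix. The function `subtract_specificity` is hypothetical, and the Specificity class in gReLU may handle details that are not modeled here.

```python
import numpy as np


def subtract_specificity(preds, on_idx):
    """Prediction for one task minus the mean prediction over all other tasks.

    preds: array of shape (n_sequences, n_tasks)
    on_idx: column index of the task of interest
    """
    preds = np.asarray(preds, dtype=float)
    off_mask = np.ones(preds.shape[1], dtype=bool)
    off_mask[on_idx] = False
    return preds[:, on_idx] - preds[:, off_mask].mean(axis=1)


# Toy predictions for 2 sequences x 3 tasks; task 0 plays the role of microglia
toy_preds = np.array(
    [
        [0.9, 0.1, 0.2],  # high in task 0, low elsewhere -> positive specificity
        [0.3, 0.6, 0.6],  # lower in task 0 than elsewhere -> negative specificity
    ]
)
spec = subtract_specificity(toy_preds, on_idx=0)
print(spec)
```

Negative values, like most of those printed above, indicate sequences predicted to be less accessible in microglia than in the average of the other cell types.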

We can now select the 250 test set peaks with the highest predicted specificity in microglia:

mcg_peaks = ad.var.iloc[np.argsort(mcg_specificity)].tail(250)
mcg_peaks
chrom start end Class Present in fetal tissues Present in adult tissues CRE module width
14465 chr1 27138447 27138647 Distal yes yes 114 400
44865 chr1 98763570 98763770 Distal yes yes 39 400
6833 chr1 12807886 12808086 Distal no yes 54 400
36037 chr1 76252103 76252303 Distal yes no 118 400
53392 chr1 121167309 121167509 Distal no yes 61 400
... ... ... ... ... ... ... ... ...
46517 chr1 103765391 103765591 Distal yes no 72 400
18728 chr1 36232525 36232725 Distal yes yes 52 400
53991 chr1 146640865 146641065 Distal no yes 40 400
52625 chr1 118788443 118788643 Distal yes yes 114 400
22410 chr1 43592990 43593190 Distal yes no 7 400

250 rows × 8 columns
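
The selection above relies on `np.argsort` sorting in ascending order, so the last rows after reordering are the ones with the highest scores. The toy example below (made-up scores and names) shows the pattern. Note that it only works when the scores are in the same row order as the table being indexed.

```python
import numpy as np
import pandas as pd

# Made-up scores, one per row of the table, in the same order
scores = np.array([0.2, -0.5, 0.9, 0.1, 0.7])
table = pd.DataFrame({"name": ["a", "b", "c", "d", "e"]})

# argsort is ascending, so tail(k) keeps the k highest-scoring rows
top2 = table.iloc[np.argsort(scores)].tail(2)
print(top2)
```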

Transforms are more powerful than this, however: they allow you to perform downstream tasks such as interpretation, variant effect prediction and design with respect to the computed quantity.

Here, we will run TF-Modisco on these selected peaks. TF-Modisco identifies motifs that consistently contribute to the model’s output. Since we pass the microglia_scorer transform as the prediction transform, attributions are computed with respect to predicted specificity to microglia, so we will only get motifs that contribute to increased or decreased specificity. We also use TOMTOM to match the TF-Modisco motifs to a set of reference motifs. Here, we use the HOCOMOCO v12 motif set (https://hocomoco12.autosome.org/).

See the subsequent tutorials for more examples of how to use transforms.

%%time
import grelu.interpret.modisco
grelu.interpret.modisco.run_modisco(
    model,
    seqs=mcg_peaks, 
    genome="hg38",
    prediction_transform=microglia_scorer, # Attributions will be calculated with respect to this output
    meme_file="hocomoco_v12", # We will compare the Modisco CWMs to HOCOMOCO motifs
    method="saliency", # Base-level attribution scores will be calculated using saliency. You can also use ISM here.
    correct_grad=True, # Gradient correction; only applied with saliency
    out_dir=experiment,
    batch_size=256,
    devices=0,
    num_workers=32,
    seed=0,

    # Here, we are demonstrating this function on a small number of peaks,
    # so we set these modisco parameters to low values. In real runs, we suggest
    # using larger or default values.
    min_metacluster_size=6,
    final_min_cluster_size=6,    
)
Getting attributions
Running modisco
4 positive and 3 negative patterns were found.
Writing modisco output
Creating sequence logos
Creating html report
Running TOMTOM
CPU times: user 53min 9s, sys: 23.6 s, total: 53min 33s
Wall time: 1min 37s
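
The comments in the cell above mention ISM (in silico mutagenesis) as an alternative to saliency for computing base-level attributions. The sketch below illustrates the general idea on a toy example: mutate every position to every other base and record the change in score. The scoring function `toy_score` is a hypothetical stand-in for a model and simply counts occurrences of the motif GGAA; this is not gReLU's implementation.

```python
import numpy as np

BASES = "ACGT"


def toy_score(seq):
    """Hypothetical stand-in for a model: counts occurrences of the motif GGAA."""
    return float(sum(seq[i:i + 4] == "GGAA" for i in range(len(seq) - 3)))


def ism(seq, score_fn):
    """Return a (4, len(seq)) array of score changes for every single-base substitution."""
    ref_score = score_fn(seq)
    effects = np.zeros((len(BASES), len(seq)))
    for pos in range(len(seq)):
        for b, base in enumerate(BASES):
            if base != seq[pos]:
                mutant = seq[:pos] + base + seq[pos + 1:]
                effects[b, pos] = score_fn(mutant) - ref_score
    return effects


seq = "TTAGGAAGTT"
effects = ism(seq, toy_score)

# Importance of each position: the largest drop in score caused by mutating it
importance = -effects.min(axis=0)
print(importance)
```

Positions inside the motif receive a high importance because mutating them removes the motif, while flanking positions have no effect on the toy score.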

Load TOMTOM output for modisco motifs#

The full output of TF-Modisco and TOMTOM can be found in the experiment folder. Here, we read the output of TOMTOM and list the significant TOMTOM matches, i.e. known TF motifs that are similar to those found by TF-MoDISco.

tomtom_file = os.path.join(experiment, 'tomtom.csv')
tomtom = pd.read_csv(tomtom_file, index_col=0)

tomtom[tomtom['q-value'] < 1e-3] # Display most significant matches
Query_ID Target_ID Optimal_offset p-value E-value q-value Overlap Query_consensus Target_consensus Orientation
1603 pos_pattern_1 EHF.H12CORE.0.P.B -26.0 2.949101e-08 0.000043 0.000099 15.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC GAACCAGGAAGTGGG -
1610 pos_pattern_1 ELF4.H12CORE.1.M.B -27.0 2.633492e-07 0.000380 0.000443 15.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC GGAAACAGGAAGTAA -
1611 pos_pattern_1 ELF5.H12CORE.0.PSM.A -27.0 2.053718e-09 0.000003 0.000021 16.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC AGGAAGGAGGAAGTAA -
1612 pos_pattern_1 ELF5.H12CORE.1.S.B -27.0 8.806208e-07 0.001271 0.000890 14.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC GAACCCGGAAGTAC -
1646 pos_pattern_1 ETV6.H12CORE.1.P.B -27.0 8.682664e-07 0.001253 0.000890 12.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC AAGAGGAAGTGG -
2393 pos_pattern_1 SPI1.H12CORE.0.P.B -27.0 3.370440e-07 0.000486 0.000486 14.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC AAAAGAGGAAGTGA -
2394 pos_pattern_1 SPI1.H12CORE.1.S.B -27.0 4.934789e-07 0.000712 0.000623 12.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC AAGGGGAAGTAG -
2395 pos_pattern_1 SPIB.H12CORE.0.P.B -24.0 5.263555e-08 0.000076 0.000133 16.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC AAAGAGGAAGTGAAAG -
2396 pos_pattern_1 SPIB.H12CORE.1.S.C -28.0 1.888892e-07 0.000273 0.000382 13.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC AAAAGCGGAAGTA -
2397 pos_pattern_1 SPIB.H12CORE.2.SM.B -27.0 1.823026e-08 0.000026 0.000092 16.0 CCCCACTCTTACAGATTCCTTTAGTTCCCACTTCCGCTTTCCAC GGGAATGAGGAAGTAG -