Source code for reglm.metrics
import numpy as np
from reglm.dataset import CharDataset
[docs]def compute_accuracy(model, seqs, shuffle_labels=False, batch_size=64, num_workers=8):
"""
Compute per-base accuracy of a trained regLM model on labeled sequences
Args:
model (pl.LightningModule): Trained regLM model
seqs (pd.DataFrame): Dataframe containing sequences under 'Sequence'
and labels under 'label'.
shuffle_labels (bool): Whether to shuffle the labels among sequences
before computing accuracy.
batch_size (int): Batch size for inference
num_workers (int): Number of workers for inference
Returns:
seqs (pd.DataFrame): original dataframe with added columns for per-
base and average accuracy.
"""
# Extract labels
labels = seqs.label
# Shuffle labels if needed
if shuffle_labels:
labels = seqs.label.sample(len(seqs))
labels = labels.tolist()
# Create dataset
ds = CharDataset(seqs=seqs.Sequence.tolist(), labels=labels)
# Compute per-base accuracy
acc = model.compute_accuracy_on_dataset(
ds, batch_size=batch_size, num_workers=num_workers
)
# Add results to dataframe
if shuffle_labels:
seqs["acc_shuf"] = acc
seqs["acc_shuf_mean"] = seqs["acc_shuf"].apply(np.mean)
avg_acc = seqs["acc_shuf_mean"].mean()
else:
seqs["acc"] = acc
seqs["acc_mean"] = seqs["acc"].apply(np.mean)
avg_acc = seqs["acc_mean"].mean()
# Print overall mean
print(f"Mean accuracy: {avg_acc:.3f}")
return seqs