⚡ Performance Optimization Examples¶
Overview¶
This guide provides comprehensive examples for optimizing performance in spatial transcriptomics analysis using SPEX. Learn how to handle large datasets efficiently, parallelize computations, and optimize memory usage.
Prerequisites¶
import spex as sp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import scanpy as sc
import time
import psutil
import multiprocessing as mp
from functools import partial
import dask.array as da
import dask.dataframe as dd
from memory_profiler import profile
# Set up plotting
plt.style.use('default')
sns.set_palette("husl")
sc.settings.verbosity = 3
sc.settings.set_figure_params(dpi=80, facecolor='white')
1. Memory Management¶
Memory Usage Monitoring¶
def get_memory_usage():
"""
Get current memory usage in MB
Returns:
--------
float : Memory usage in MB
"""
process = psutil.Process()
memory_info = process.memory_info()
return memory_info.rss / 1024 / 1024 # Convert to MB
def monitor_memory(func):
"""
Decorator to monitor memory usage of functions
Parameters:
-----------
func : function
Function to monitor
Returns:
--------
function : Wrapped function with memory monitoring
"""
def wrapper(*args, **kwargs):
memory_before = get_memory_usage()
start_time = time.time()
result = func(*args, **kwargs)
end_time = time.time()
memory_after = get_memory_usage()
print(f"Function: {func.__name__}")
print(f" Execution time: {end_time - start_time:.2f} seconds")
print(f" Memory usage: {memory_after - memory_before:.2f} MB")
return result
return wrapper
# Example usage
@monitor_memory
def load_large_image(file_path):
"""Load large image with memory monitoring"""
return sp.load_image(file_path)
# Test memory monitoring
Image, channel = load_large_image('large_sample_data.tiff')
print(f"Image shape: {Image.shape}")
print(f"Image memory usage: {Image.nbytes / 1024 / 1024:.2f} MB")
Efficient Image Loading¶
def load_image_chunked(file_path, chunk_size=1000):
    """
    Load large image in chunks to reduce memory usage
    Parameters:
    -----------
    file_path : str
        Path to image file
    chunk_size : int
        Size of chunks to load
    Returns:
    --------
    tuple : (image, channel_names)
    """
    import tifffile
    with tifffile.TiffFile(file_path) as tif:
        # Get image info
        series = tif.series[0]
        shape = series.shape
        dtype = series.dtype
        if len(shape) == 3:
            height, width, channels = shape
        else:
            height, width = shape
            channels = 1
        # Memory-map the pixel data so each slice below reads only the
        # requested region from disk (tifffile falls back to a temporary
        # memory-mapped file if the TIFF cannot be mapped directly)
        source = tif.asarray(out='memmap')
        if source.ndim == 2:
            source = source[:, :, np.newaxis]
        # Copy into the output array chunk by chunk
        image = np.zeros((height, width, channels), dtype=dtype)
        for y in range(0, height, chunk_size):
            y_end = min(y + chunk_size, height)
            for x in range(0, width, chunk_size):
                x_end = min(x + chunk_size, width)
                image[y:y_end, x:x_end, :] = source[y:y_end, x:x_end, :]
        # Get channel names
        channel_names = [f"Channel_{i}" for i in range(channels)]
    return image, channel_names
# Example usage
print("Loading large image with chunked approach...")
Image_chunked, channel_names = load_image_chunked('large_sample_data.tiff', chunk_size=500)
print(f"Loaded image shape: {Image_chunked.shape}")
Memory-Efficient Segmentation¶
def segment_large_image_chunked(image, chunk_size=512, overlap=50):
"""
Segment large image in chunks with overlap
Parameters:
-----------
image : np.ndarray
Input image
chunk_size : int
Size of processing chunks
overlap : int
Overlap between chunks
Returns:
--------
np.ndarray : Segmentation labels
"""
height, width = image.shape[:2]
labels = np.zeros((height, width), dtype=np.int32)
# Process image in chunks
for y in range(0, height, chunk_size - overlap):
y_end = min(y + chunk_size, height)
for x in range(0, width, chunk_size - overlap):
x_end = min(x + chunk_size, width)
# Extract chunk with overlap
y_start = max(0, y - overlap)
x_start = max(0, x - overlap)
y_end_ext = min(height, y_end + overlap)
x_end_ext = min(width, x_end + overlap)
chunk = image[y_start:y_end_ext, x_start:x_end_ext]
# Segment chunk
chunk_labels = sp.cellpose_cellseg(chunk, [0], diameter=30)
# Remove overlap and assign labels
if y > 0:
chunk_labels = chunk_labels[overlap:, :]
if x > 0:
chunk_labels = chunk_labels[:, overlap:]
if y_end < height:
chunk_labels = chunk_labels[:-(y_end_ext - y_end), :]
if x_end < width:
chunk_labels = chunk_labels[:, :-(x_end_ext - x_end)]
# Adjust labels to avoid conflicts
max_label = labels.max()
chunk_labels[chunk_labels > 0] += max_label
# Assign to main labels array
labels[y:y_end, x:x_end] = chunk_labels
return labels
# Example usage
print("Segmenting large image with chunked approach...")
labels_chunked = segment_large_image_chunked(Image_chunked, chunk_size=512, overlap=50)
print(f"Segmentation complete: {labels_chunked.max()} cells detected")
2. Parallel Processing¶
Parallel Feature Extraction¶
def _extract_single_channel(ch_idx, image, labels):
    """Extract features for a single channel (module level so it can be pickled)"""
    return sp.feature_extraction(image, labels, [ch_idx])

def extract_features_parallel(image, labels, channel_indices, n_jobs=-1):
    """
    Extract features in parallel for multiple channels
    Parameters:
    -----------
    image : np.ndarray
        Input image
    labels : np.ndarray
        Segmentation labels
    channel_indices : list
        List of channel indices to process
    n_jobs : int
        Number of parallel jobs (-1 for all cores)
    Returns:
    --------
    pd.DataFrame : Combined features
    """
    if n_jobs == -1:
        n_jobs = mp.cpu_count()
    # Bind the shared arrays to the module-level worker; a nested function
    # cannot be pickled by multiprocessing, so partial() is used instead
    worker = partial(_extract_single_channel, image=image, labels=labels)
    # Process channels in parallel
    with mp.Pool(n_jobs) as pool:
        results = pool.map(worker, channel_indices)
    # Combine per-channel results column-wise
    combined_features = pd.concat(results, axis=1)
    return combined_features
# Example usage
print("Extracting features in parallel...")
features_parallel = extract_features_parallel(Image_chunked, labels_chunked,
list(range(len(channel_names))), n_jobs=4)
print(f"Extracted {features_parallel.shape[1]} features for {features_parallel.shape[0]} cells")
Parallel Clustering¶
def _cluster_single_resolution(res, adata, method):
    """Cluster a shared AnnData object at one resolution (module level so it can be pickled)"""
    clusters = sp.cluster(adata, method=method, resolution=res)
    return res, clusters

def cluster_multiple_resolutions_parallel(features, resolutions, method='leiden', n_jobs=-1):
    """
    Perform clustering with multiple resolutions in parallel
    Parameters:
    -----------
    features : pd.DataFrame
        Feature matrix
    resolutions : list
        List of resolution parameters
    method : str
        Clustering method
    n_jobs : int
        Number of parallel jobs
    Returns:
    --------
    dict : Clustering results for each resolution
    """
    if n_jobs == -1:
        n_jobs = mp.cpu_count()
    # Build the AnnData object once from the supplied feature matrix instead of
    # re-extracting features inside every worker
    adata = sc.AnnData(features)
    worker = partial(_cluster_single_resolution, adata=adata, method=method)
    # Process resolutions in parallel
    with mp.Pool(n_jobs) as pool:
        results = pool.map(worker, resolutions)
    # Organize results
    clustering_results = {res: clusters for res, clusters in results}
    return clustering_results
# Example usage
resolutions = [0.1, 0.3, 0.5, 0.7, 1.0]
print("Performing clustering with multiple resolutions in parallel...")
clustering_results = cluster_multiple_resolutions_parallel(features_parallel, resolutions, n_jobs=4)
for res, clusters in clustering_results.items():
n_clusters = len(np.unique(clusters))
print(f"Resolution {res}: {n_clusters} clusters")
3. Dask Integration for Large Datasets¶
Dask-based Feature Extraction¶
def extract_features_dask(image, labels, channel_indices, chunk_size=1000):
    """
    Extract features lazily using Dask for large datasets
    Parameters:
    -----------
    image : np.ndarray
        Input image
    labels : np.ndarray
        Segmentation labels
    channel_indices : list
        List of channel indices
    chunk_size : int
        Size of Dask chunks (kept for interface compatibility; the per-channel
        tasks below operate on the full arrays)
    Returns:
    --------
    dask.delayed.Delayed : Lazy task that evaluates to a pd.DataFrame of features
    """
    from dask import delayed
    # Build one lazy task per channel; nothing is executed until .compute() is called
    per_channel = [
        delayed(sp.feature_extraction)(image, labels, [ch_idx])
        for ch_idx in channel_indices
    ]
    # Concatenate the per-channel feature tables column-wise, also lazily
    combined_features = delayed(pd.concat)(per_channel, axis=1)
    return combined_features
# Example usage
print("Extracting features using Dask...")
features_dask = extract_features_dask(Image_chunked, labels_chunked,
list(range(len(channel_names))), chunk_size=500)
# Compute results
features_computed = features_dask.compute()
print(f"Dask features shape: {features_computed.shape}")
Dask-based Spatial Analysis¶
def spatial_analysis_dask(labels, cluster_labels, chunk_size=1000):
"""
Perform spatial analysis using Dask
Parameters:
-----------
labels : np.ndarray
Segmentation labels
cluster_labels : np.ndarray
Cluster assignments
chunk_size : int
Size of Dask chunks
Returns:
--------
dict : Spatial analysis results
"""
    # Convert to Dask arrays so the reductions below stream over chunks
    labels_dask = da.from_array(labels, chunks=(chunk_size, chunk_size))
    # Encode cluster labels numerically (they may be categorical strings)
    cluster_codes = pd.Categorical(np.asarray(cluster_labels)).codes.astype(float)
    clusters_dask = da.from_array(cluster_codes, chunks=chunk_size)
    # Simplified placeholder metrics - in practice, the full spatial analysis
    # logic (e.g. neighborhood statistics) would be expressed here as Dask reductions
    mean_labels = labels_dask.mean().compute()
    std_clusters = clusters_dask.std().compute()
    return {
        'mean_labels': float(mean_labels),
        'std_clusters': float(std_clusters)
    }
# Example usage
print("Performing spatial analysis with Dask...")
spatial_results_dask = spatial_analysis_dask(labels_chunked, clustering_results[0.5])
print("Spatial analysis complete")
4. Performance Profiling¶
Function Performance Profiling¶
@profile
def profile_feature_extraction(image, labels, channel_indices):
"""
Profile feature extraction performance
Parameters:
-----------
image : np.ndarray
Input image
labels : np.ndarray
Segmentation labels
channel_indices : list
Channel indices
"""
features = sp.feature_extraction(image, labels, channel_indices)
return features
# Example usage
print("Profiling feature extraction...")
features_profiled = profile_feature_extraction(Image_chunked, labels_chunked,
list(range(len(channel_names))))
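memory_profiler reports per-line memory usage; for a time breakdown, the standard-library cProfile can be used alongside it. A short sketch profiling a single-channel feature extraction call:
import cProfile
import pstats
# Collect timing statistics for one feature extraction run
profiler = cProfile.Profile()
profiler.enable()
_ = sp.feature_extraction(Image_chunked, labels_chunked, [0])
profiler.disable()
# Print the ten most expensive calls by cumulative time
pstats.Stats(profiler).sort_stats('cumulative').print_stats(10)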
Performance Comparison¶
def compare_performance_methods(image, labels, channel_indices):
"""
Compare performance of different methods
Parameters:
-----------
image : np.ndarray
Input image
labels : np.ndarray
Segmentation labels
channel_indices : list
Channel indices
"""
methods = {
'Sequential': lambda: sp.feature_extraction(image, labels, channel_indices),
'Parallel': lambda: extract_features_parallel(image, labels, channel_indices, n_jobs=4),
'Dask': lambda: extract_features_dask(image, labels, channel_indices).compute()
}
results = {}
for method_name, method_func in methods.items():
print(f"Testing {method_name} method...")
# Measure time
start_time = time.time()
start_memory = get_memory_usage()
result = method_func()
end_time = time.time()
end_memory = get_memory_usage()
results[method_name] = {
'time': end_time - start_time,
'memory': end_memory - start_memory,
'result_shape': result.shape
}
return results
# Example usage
print("Comparing performance methods...")
performance_comparison = compare_performance_methods(Image_chunked, labels_chunked,
list(range(len(channel_names))))
# Display results
print("\nPerformance Comparison:")
for method, metrics in performance_comparison.items():
print(f"\n{method}:")
print(f" Time: {metrics['time']:.2f} seconds")
print(f" Memory: {metrics['memory']:.2f} MB")
print(f" Result shape: {metrics['result_shape']}")
# Create performance visualization
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
# Time comparison
methods = list(performance_comparison.keys())
times = [performance_comparison[m]['time'] for m in methods]
memory = [performance_comparison[m]['memory'] for m in methods]
axes[0].bar(methods, times, color=['blue', 'green', 'red'])
axes[0].set_ylabel('Time (seconds)')
axes[0].set_title('Execution Time Comparison')
axes[0].tick_params(axis='x', rotation=45)
axes[1].bar(methods, memory, color=['blue', 'green', 'red'])
axes[1].set_ylabel('Memory Usage (MB)')
axes[1].set_title('Memory Usage Comparison')
axes[1].tick_params(axis='x', rotation=45)
plt.tight_layout()
plt.show()
5. Optimization Strategies¶
Batch Processing for Large Datasets¶
def process_large_dataset_batch(file_list, batch_size=5):
"""
Process large dataset in batches
Parameters:
-----------
file_list : list
List of file paths
batch_size : int
Number of files to process in each batch
Returns:
--------
list : Results for all files
"""
all_results = []
for i in range(0, len(file_list), batch_size):
batch_files = file_list[i:i + batch_size]
print(f"Processing batch {i//batch_size + 1}/{(len(file_list) + batch_size - 1)//batch_size}")
batch_results = []
for file_path in batch_files:
try:
# Load and process file
image, channels = sp.load_image(file_path)
labels = sp.cellpose_cellseg(image, [0], diameter=30)
features = sp.feature_extraction(image, labels, list(range(len(channels))))
batch_results.append({
'file': file_path,
'features': features,
'n_cells': labels.max()
})
except Exception as e:
print(f"Error processing {file_path}: {e}")
batch_results.append({
'file': file_path,
'error': str(e)
})
all_results.extend(batch_results)
# Clear memory after each batch
import gc
gc.collect()
return all_results
# Example usage
file_list = ['sample_1.tiff', 'sample_2.tiff', 'sample_3.tiff', 'sample_4.tiff', 'sample_5.tiff']
print("Processing large dataset in batches...")
batch_results = process_large_dataset_batch(file_list, batch_size=2)
for result in batch_results:
if 'error' not in result:
print(f"{result['file']}: {result['n_cells']} cells")
else:
print(f"{result['file']}: ERROR - {result['error']}")
Memory-Efficient Data Structures¶
def optimize_data_types(features_df):
"""
Optimize data types to reduce memory usage
Parameters:
-----------
features_df : pd.DataFrame
Input features DataFrame
Returns:
--------
pd.DataFrame : Optimized DataFrame
"""
# Get memory usage before optimization
memory_before = features_df.memory_usage(deep=True).sum() / 1024 / 1024
# Optimize numeric columns
for col in features_df.select_dtypes(include=[np.number]).columns:
col_min = features_df[col].min()
col_max = features_df[col].max()
        # Optimize integer columns (bounds are inclusive of the dtype limits)
        if features_df[col].dtype == 'int64':
            if col_min >= 0:
                if col_max <= 255:
                    features_df[col] = features_df[col].astype(np.uint8)
                elif col_max <= 65535:
                    features_df[col] = features_df[col].astype(np.uint16)
                else:
                    features_df[col] = features_df[col].astype(np.uint32)
            else:
                if col_min >= -128 and col_max <= 127:
                    features_df[col] = features_df[col].astype(np.int8)
                elif col_min >= -32768 and col_max <= 32767:
                    features_df[col] = features_df[col].astype(np.int16)
                else:
                    features_df[col] = features_df[col].astype(np.int32)
# Optimize float columns
elif features_df[col].dtype == 'float64':
features_df[col] = features_df[col].astype(np.float32)
# Get memory usage after optimization
memory_after = features_df.memory_usage(deep=True).sum() / 1024 / 1024
print(f"Memory optimization:")
print(f" Before: {memory_before:.2f} MB")
print(f" After: {memory_after:.2f} MB")
print(f" Reduction: {((memory_before - memory_after) / memory_before * 100):.1f}%")
return features_df
# Example usage
print("Optimizing data types...")
features_optimized = optimize_data_types(features_parallel)
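For a quicker pass, pandas' built-in downcasting gives a similar reduction with less code. A short alternative sketch; downcast_with_pandas is a helper defined here for illustration:
def downcast_with_pandas(df):
    """Downcast numeric columns using pd.to_numeric's built-in rules"""
    out = df.copy()
    for col in out.select_dtypes(include=['integer']).columns:
        kind = 'unsigned' if (out[col] >= 0).all() else 'integer'
        out[col] = pd.to_numeric(out[col], downcast=kind)
    for col in out.select_dtypes(include=['float']).columns:
        out[col] = pd.to_numeric(out[col], downcast='float')
    return out

features_downcast = downcast_with_pandas(features_parallel)
print(f"Downcast memory: {features_downcast.memory_usage(deep=True).sum() / 1024 / 1024:.2f} MB")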
6. Performance Monitoring Dashboard¶
def create_performance_dashboard(performance_data):
"""
Create performance monitoring dashboard
Parameters:
-----------
performance_data : dict
Performance metrics data
"""
fig, axes = plt.subplots(2, 2, figsize=(15, 12))
# Method comparison
methods = list(performance_data.keys())
times = [performance_data[m]['time'] for m in methods]
memory = [performance_data[m]['memory'] for m in methods]
# Time comparison
axes[0, 0].bar(methods, times, color=['blue', 'green', 'red', 'orange'])
axes[0, 0].set_ylabel('Time (seconds)')
axes[0, 0].set_title('Execution Time by Method')
axes[0, 0].tick_params(axis='x', rotation=45)
# Memory comparison
axes[0, 1].bar(methods, memory, color=['blue', 'green', 'red', 'orange'])
axes[0, 1].set_ylabel('Memory Usage (MB)')
axes[0, 1].set_title('Memory Usage by Method')
axes[0, 1].tick_params(axis='x', rotation=45)
# Efficiency plot (time vs memory)
axes[1, 0].scatter(times, memory, c=range(len(methods)), cmap='viridis', s=100)
for i, method in enumerate(methods):
axes[1, 0].annotate(method, (times[i], memory[i]),
xytext=(5, 5), textcoords='offset points')
axes[1, 0].set_xlabel('Time (seconds)')
axes[1, 0].set_ylabel('Memory Usage (MB)')
axes[1, 0].set_title('Efficiency: Time vs Memory')
# Performance summary table
axes[1, 1].axis('off')
table_data = []
for method in methods:
table_data.append([
method,
f"{performance_data[method]['time']:.2f}s",
f"{performance_data[method]['memory']:.2f}MB",
f"{performance_data[method]['result_shape']}"
])
table = axes[1, 1].table(cellText=table_data,
colLabels=['Method', 'Time', 'Memory', 'Shape'],
cellLoc='center',
loc='center')
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)
plt.tight_layout()
plt.show()
# Example usage
print("Creating performance dashboard...")
create_performance_dashboard(performance_comparison)
Summary¶
This performance optimization guide provides comprehensive tools for handling large spatial transcriptomics datasets efficiently:
- Memory Management: Monitoring and optimizing memory usage
- Parallel Processing: Utilizing multiple CPU cores for faster computation
- Dask Integration: Handling datasets that don't fit in memory
- Performance Profiling: Identifying bottlenecks and optimization opportunities
- Batch Processing: Processing large datasets in manageable chunks
- Data Type Optimization: Reducing memory footprint through efficient data types
Use these techniques to scale your spatial transcriptomics analysis to handle large datasets while maintaining reasonable performance and memory usage.