Lightning Module API¶
PyTorch Lightning integration for training orchestration and distributed computing.
Overview¶
The Lightning module provides three main components:
ConnectomicsModule: Lightning wrapper for models
ConnectomicsDataModule: Lightning data handling
create_trainer: Convenience function for trainer creation
Quick Example¶
from connectomics.config import load_config
from connectomics.training.lightning import (
    ConnectomicsModule,
    create_datamodule,
    create_trainer,
)
from pytorch_lightning import seed_everything
# Load config
cfg = load_config("tutorials/minimal.yaml")
# Set seed
seed_everything(cfg.system.seed)
# Create components
datamodule = create_datamodule(cfg)
model = ConnectomicsModule(cfg)
trainer = create_trainer(cfg)
# Train
trainer.fit(model, datamodule=datamodule)
# Test
trainer.test(model, datamodule=datamodule)
Module Reference¶
ConnectomicsModule¶
ConnectomicsDataModule¶
create_trainer¶
Training Features¶
Distributed Training¶
When more than one GPU is configured, the trainer uses DistributedDataParallel (DDP) automatically:
system:
  num_gpus: 4  # Uses DDP automatically

trainer = create_trainer(cfg)  # DDP enabled automatically
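To make the effect of DDP on data loading concrete, here is a minimal sketch of how a dataset is typically split across ranks. It mirrors the unshuffled behaviour of PyTorch's `DistributedSampler`; that this library relies on that sampler is an assumption, and `shard_indices` is a hypothetical helper, not part of the API.

```python
import math


def shard_indices(num_samples: int, world_size: int, rank: int) -> list[int]:
    """Return the dataset indices one rank would process (no shuffling)."""
    per_rank = math.ceil(num_samples / world_size)
    total = per_rank * world_size
    indices = list(range(num_samples))
    # Pad by repeating samples from the start so every rank gets the same count.
    pad = total - num_samples
    indices += (indices * world_size)[:pad]
    # Each rank takes every world_size-th index, starting at its own rank.
    return indices[rank:total:world_size]


# Hypothetical example: 10 samples split across 4 GPUs.
shards = [shard_indices(10, 4, rank) for rank in range(4)]
for rank, shard in enumerate(shards):
    print(f"rank {rank}: {shard}")
```

Each rank sees a different slice of the data, so with a fixed per-GPU batch size the global batch grows with the number of GPUs.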
Mixed Precision¶
Enable mixed precision for faster training:
optimization:
  precision: "16-mixed"        # FP16
  # or, on Ampere or newer GPUs:
  # precision: "bf16-mixed"    # BFloat16
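The practical difference between the two settings is range versus precision. The sketch below emulates both formats with the standard library only; `to_fp16`, `to_bf16`, and `fits_fp16` are hypothetical helpers, and the bfloat16 emulation truncates instead of rounding, which is a simplification.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip x through IEEE half precision (5 exponent, 10 mantissa bits)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Emulate bfloat16 by keeping the top 16 bits of the float32 encoding.

    Simplification: this truncates, whereas hardware usually rounds to nearest.
    """
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


def fits_fp16(x: float) -> bool:
    """Return False if x is too large to be represented in FP16."""
    try:
        to_fp16(x)
        return True
    except (OverflowError, struct.error):
        return False


# FP16 keeps more mantissa bits, so small values are represented more precisely.
fp16_error = abs(to_fp16(0.1) - 0.1)
bf16_error = abs(to_bf16(0.1) - 0.1)

# BFloat16 keeps the float32 exponent range, so large values do not overflow.
large_ok_fp16 = fits_fp16(70000.0)
large_bf16 = to_bf16(70000.0)
```

FP16 overflows above roughly 65504, which is why FP16 mixed precision usually needs loss scaling, while BFloat16 trades precision for a much wider range.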
Gradient Accumulation¶
Accumulate gradients over several batches before each optimizer step to simulate a larger batch size:
optimization:
  accumulate_grad_batches: 4
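The sketch below illustrates why accumulation approximates a larger batch: for equally sized micro-batches, the average of their gradients equals the gradient of the combined batch. It uses a one-parameter least-squares model in plain Python; the function names and data are hypothetical, and the per-step division by the number of accumulation steps is assumed to match what the trainer does.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def accumulated_grad(w, xs, ys, accumulate_grad_batches):
    """Sum micro-batch gradients, each scaled by 1 / accumulate_grad_batches."""
    size = len(xs) // accumulate_grad_batches
    total = 0.0
    for i in range(accumulate_grad_batches):
        lo, hi = i * size, (i + 1) * size
        total += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulate_grad_batches
    return total


# Hypothetical data: 8 samples, processed as 4 micro-batches of 2.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.0 * x + 1.0 for x in xs]

full_batch_grad = mse_grad(0.5, xs, ys)
accum_grad = accumulated_grad(0.5, xs, ys, accumulate_grad_batches=4)

# Effective batch size = per-step batch size * accumulation steps.
effective_batch_size = 2 * 4
```

The equivalence is exact only for losses averaged over samples and for layers without batch-dependent statistics such as batch normalization.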
Gradient Clipping¶
Clip gradients to keep them from exploding:
optimization:
  gradient_clip_val: 1.0
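As an illustration, here is a minimal sketch of clipping by global norm, which is the default clipping algorithm in PyTorch Lightning. Whether this library keeps that default is an assumption, and `clip_grad_norm` here is a hypothetical pure-Python helper, not the PyTorch function.

```python
import math


def clip_grad_norm(grads, max_norm):
    """Rescale grads so their L2 norm does not exceed max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        # Already within the limit: leave the gradients unchanged.
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


# Gradient with norm 5.0 is scaled down to norm 1.0; direction is preserved.
clipped = clip_grad_norm([3.0, 4.0], max_norm=1.0)
clipped_norm = math.sqrt(sum(g * g for g in clipped))

# Gradient with norm 0.5 is left as is.
unchanged = clip_grad_norm([0.3, 0.4], max_norm=1.0)
```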
Learning Rate Scheduling¶
Automatic LR scheduling with warmup:
optimization:
  scheduler:
    name: CosineAnnealingLR
    warmup_epochs: 5
    min_lr: 1e-6
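The resulting schedule can be sketched as a linear warmup followed by cosine decay down to `min_lr`. The linear warmup shape, `base_lr`, and `max_epochs` are assumptions for illustration; the library's exact warmup behaviour may differ.

```python
import math


def lr_at_epoch(epoch, base_lr=1e-3, min_lr=1e-6, warmup_epochs=5, max_epochs=100):
    """Learning rate for a given epoch: linear warmup, then cosine annealing."""
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Fraction of the post-warmup schedule that has elapsed, from 0 to 1.
    progress = (epoch - warmup_epochs) / (max_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))


schedule = [lr_at_epoch(epoch) for epoch in range(101)]
```

Plotting `schedule` against the epoch index gives a quick sanity check before committing to a long run.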
Deep Supervision¶
Multi-scale loss computation:
model:
  loss:
    deep_supervision: true
    losses:
      - function: DiceLoss
        weight: 1.0
The module automatically:
Computes losses at multiple scales
Resizes ground truth to match each scale
Averages losses across scales
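The three steps above can be sketched in plain Python on 1-D data. This is an illustration only: the nearest-neighbour resize, the soft Dice formula, and all function names are assumptions, and the module itself operates on image volumes with its own loss implementations.

```python
def downsample_nearest(target, size):
    """Nearest-neighbour resize of a 1-D label list to `size` elements."""
    step = len(target) / size
    return [target[int(i * step)] for i in range(size)]


def soft_dice_loss(pred, target, eps=1e-6):
    """1 - Dice coefficient, with eps to avoid division by zero."""
    intersection = sum(p * t for p, t in zip(pred, target))
    denominator = sum(pred) + sum(target)
    return 1.0 - (2.0 * intersection + eps) / (denominator + eps)


def deep_supervision_loss(preds_per_scale, target):
    """Average the Dice loss over predictions at several resolutions."""
    losses = []
    for pred in preds_per_scale:
        # Resize the ground truth to match this prediction's resolution.
        resized = downsample_nearest(target, len(pred))
        losses.append(soft_dice_loss(pred, resized))
    return sum(losses) / len(losses)


target = [0, 0, 1, 1, 1, 1, 0, 0]
# Hypothetical outputs at full and half resolution.
perfect_loss = deep_supervision_loss([[0, 0, 1, 1, 1, 1, 0, 0], [0, 1, 1, 0]], target)
empty_loss = deep_supervision_loss([[0] * 8, [0] * 4], target)
```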
Callbacks¶
The trainer includes several useful callbacks:
Model Checkpointing¶
monitor:
  checkpoint:
    monitor: "val/loss"
    mode: "min"
    save_top_k: 3
    save_last: true
    filename: "epoch{epoch:02d}-loss{val/loss:.2f}"
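To show what `save_top_k` and `mode` select, here is a minimal sketch that ranks epochs by validation loss and names the surviving checkpoints. The helpers and loss values are hypothetical, and the filename is built with an ordinary f-string that imitates the template above, not with Lightning's own formatter.

```python
def keep_top_k(val_losses, save_top_k=3, mode="min"):
    """Return the epochs whose checkpoints would be kept, best first."""
    ranked = sorted(
        enumerate(val_losses),
        key=lambda item: item[1],
        reverse=(mode == "max"),
    )
    return [epoch for epoch, _ in ranked[:save_top_k]]


def checkpoint_name(epoch, val_loss):
    """Imitate the 'epoch{epoch:02d}-loss{val/loss:.2f}' template."""
    return f"epoch{epoch:02d}-loss{val_loss:.2f}"


# Hypothetical validation losses for epochs 0 to 4.
val_losses = [0.9, 0.5, 0.7, 0.3, 0.8]
kept_epochs = keep_top_k(val_losses, save_top_k=3, mode="min")
kept_names = [checkpoint_name(e, val_losses[e]) for e in kept_epochs]
```

With `save_last: true`, the most recent checkpoint is kept in addition to these.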
Early Stopping¶
monitor:
  early_stopping:
    enabled: true
    monitor: "val/loss"
    patience: 10
    mode: "min"
    min_delta: 0.0
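The interaction of `patience` and `min_delta` can be sketched as a simple counter, assuming `mode: "min"`. `early_stop_epoch` is a hypothetical helper that follows the usual early-stopping rule; it is not the callback's actual code.

```python
def early_stop_epoch(val_losses, patience=10, min_delta=0.0):
    """Return the epoch at which training would stop, or None if it never does."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement by more than min_delta: record it and reset the counter.
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


# Best loss is reached at epoch 1; two epochs without improvement follow.
stopped_at = early_stop_epoch([1.0, 0.8, 0.9, 0.85, 0.7], patience=2)
never_stops = early_stop_epoch([1.0, 0.9, 0.8], patience=2)
```

In this example training stops before the later improvement to 0.7 is seen, which is the trade-off of a small `patience`.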
Learning Rate Monitoring¶
The learning rate is logged automatically to TensorBoard or Weights & Biases.
Logging¶
TensorBoard (Default)¶
monitor:
  logging:
    scalar:
      loss_every_n_steps: 10
Logs are saved to outputs/lightning_logs/.
View with:
tensorboard --logdir outputs/lightning_logs
Weights & Biases (Optional)¶
monitor:
  wandb:
    use_wandb: true
    project: "connectomics"
    entity: "your_team"
    name: "lucchi_exp"
Advanced Usage¶
Custom Callbacks¶
from pytorch_lightning.callbacks import Callback

class MyCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} finished!")

# Add to trainer
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=100,
    callbacks=[MyCallback()],
)
Custom Training Step¶
from connectomics.training.lightning import ConnectomicsModule

class CustomModule(ConnectomicsModule):
    def training_step(self, batch, batch_idx):
        # Custom training logic
        images, labels = batch
        outputs = self.model(images)
        # Custom loss computation
        loss = self.compute_loss(outputs, labels)
        # Log metrics
        self.log('train/loss', loss)
        return loss
Inference¶
Single Batch Prediction¶
import torch

# Load trained model
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg,
)
model.eval()
model.cuda()

# Predict (input_batch must be on the same device as the model)
with torch.no_grad():
    output = model(input_batch)
Full Dataset Inference¶
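from connectomics.training.lightning import (
    ConnectomicsModule,
    create_datamodule,
    create_trainer,
)

# Load model
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg,
)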
# Create datamodule
datamodule = create_datamodule(cfg)
# Create trainer
trainer = create_trainer(cfg)
# Run inference
predictions = trainer.predict(model, datamodule=datamodule)
Resuming Training¶
# Resume from checkpoint
trainer = create_trainer(cfg)
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path="outputs/last.ckpt",
)
Or from command line:
python scripts/main.py \
    --config tutorials/minimal.yaml \
    --resume outputs/last.ckpt