
Lightning Module API

PyTorch Lightning integration for training orchestration and distributed computing.

Overview

The Lightning module provides three main components:

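  1. ConnectomicsModule: LightningModule wrapper around the model

  2. ConnectomicsDataModule: LightningDataModule that handles data loading (created with create_datamodule in the example below)

  3. create_trainer: convenience function that builds a configured Trainer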

Quick Example

from connectomics.config import load_config
from connectomics.training.lightning import (
    ConnectomicsModule,
    create_datamodule,
    create_trainer
)
from pytorch_lightning import seed_everything

# Load config
cfg = load_config("tutorials/minimal.yaml")

# Set seed
seed_everything(cfg.system.seed)

# Create components
datamodule = create_datamodule(cfg)
model = ConnectomicsModule(cfg)
trainer = create_trainer(cfg)

# Train
trainer.fit(model, datamodule=datamodule)

# Test
trainer.test(model, datamodule=datamodule)

Module Reference

ConnectomicsModule

ConnectomicsDataModule

create_trainer

Training Features

Distributed Training

When more than one GPU is configured, training automatically uses DistributedDataParallel (DDP):

system:
  num_gpus: 4  # more than one GPU, so DDP is used

trainer = create_trainer(cfg)  # DDP is enabled without extra arguments
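Under DDP each GPU runs its own process and loads its own batches, so the effective global batch size is the per-GPU batch size multiplied by the number of GPUs. The sketch below illustrates how the dataset is split between processes by mimicking the padding-and-striding scheme of PyTorch's DistributedSampler with shuffling off. The helper shard_indices is hypothetical and only shows the idea; it is not part of the connectomics API.

```python
import math


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> list[int]:
    """Return the sample indices seen by one DDP rank (no shuffling).

    Hypothetical helper mirroring torch.utils.data.DistributedSampler
    with shuffle=False and drop_last=False.
    """
    indices = list(range(dataset_len))
    # Pad by repeating leading indices so every rank gets the same count.
    total_size = math.ceil(dataset_len / num_replicas) * num_replicas
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / max(len(indices), 1)))[:padding]
    # Each rank takes every num_replicas-th index, starting at its own rank.
    return indices[rank:total_size:num_replicas]


# 10 samples on 4 GPUs: each rank gets ceil(10 / 4) = 3 samples,
# and two samples are repeated to even out the shards.
shards = [shard_indices(10, num_replicas=4, rank=r) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```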

Mixed Precision

Enable mixed precision for faster training:

optimization:
  precision: "16-mixed"  # FP16
  # or, on Ampere or newer GPUs:
  # precision: "bf16-mixed"  # BFloat16
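The two options trade range for precision differently. FP16 has 10 mantissa bits but a maximum value of 65504, which is why FP16 mixed precision is normally paired with dynamic loss scaling; BF16 keeps FP32's 8 exponent bits (and therefore its range) with only 7 mantissa bits. The plain-Python sketch below makes this concrete by round-tripping values through each format. The helper names are hypothetical, and BF16 is approximated by truncation, whereas GPUs round to nearest.

```python
import struct


def to_fp16(x: float) -> float:
    """Round-trip a Python float through IEEE 754 half precision (FP16)."""
    # struct's "e" format is binary16; it raises OverflowError for
    # magnitudes beyond the FP16 range (max finite value 65504).
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_bf16(x: float) -> float:
    """Approximate BFloat16 by keeping only the top 16 bits of an FP32 value."""
    # Sign + 8 exponent bits + 7 mantissa bits survive the mask, so the
    # FP32 dynamic range is preserved at reduced precision (truncation).
    (bits,) = struct.unpack("<I", struct.pack("<f", x))
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFF0000))[0]


print(to_fp16(1.0001))   # 1.0: FP16 spacing near 1.0 is about 0.001
print(to_bf16(70000.0))  # 69632.0: in range for BF16, but coarse
try:
    to_fp16(70000.0)
except OverflowError:
    print("70000.0 does not fit in FP16")
```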

Gradient Accumulation

Simulate larger batch sizes by accumulating gradients over several batches before each optimizer step:

optimization:
  accumulate_grad_batches: 4
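With accumulate_grad_batches: 4, gradients from four consecutive batches are combined before a single optimizer step, so the effective batch size is four times the configured batch size (and, under DDP, also multiplied by the number of GPUs). The sketch below checks the underlying identity on a made-up one-parameter model y_hat = w * x: averaging the gradients of equal-sized micro-batches (each loss divided by the number of accumulation steps) gives the same gradient as one large batch. Layers that depend on batch statistics, such as BatchNorm, still only see each micro-batch, so the equivalence is not exact for every network.

```python
def grad_mse(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of the mean squared error of y_hat = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Toy data and weight, chosen only for illustration.
xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 8.1, 9.8, 12.2, 13.9, 16.1]
w = 0.5
accumulate_grad_batches = 4
micro = len(xs) // accumulate_grad_batches  # micro-batch size of 2

# Gradient from one large batch of 8 samples.
full_grad = grad_mse(w, xs, ys)

# Gradient accumulated over four micro-batches, each scaled by 1/4.
accum_grad = 0.0
for i in range(accumulate_grad_batches):
    sl = slice(i * micro, (i + 1) * micro)
    accum_grad += grad_mse(w, xs[sl], ys[sl]) / accumulate_grad_batches

print(full_grad, accum_grad)  # equal up to floating-point rounding
```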

Gradient Clipping

Prevent exploding gradients:

optimization:
  gradient_clip_val: 1.0
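Assuming gradient_clip_val is passed through to the PyTorch Lightning Trainer, it clips by the global L2 norm of all gradients by default (the Trainer's gradient_clip_algorithm defaults to "norm"). The plain-Python sketch below shows that rule on a flat list of gradients; clip_grad_norm is a hypothetical helper written in the same form as torch.nn.utils.clip_grad_norm_, which never scales gradients up.

```python
import math


def clip_grad_norm(grads: list[float], max_norm: float,
                   eps: float = 1e-6) -> list[float]:
    """Rescale a flat gradient vector so its global L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Scale by max_norm / norm when the norm is too large; otherwise keep as is.
    clip_coef = min(max_norm / (total_norm + eps), 1.0)
    return [g * clip_coef for g in grads]


grads = [3.0, 4.0]                       # L2 norm = 5.0
clipped = clip_grad_norm(grads, max_norm=1.0)
print(clipped)                           # approximately [0.6, 0.8]
print(clip_grad_norm([0.3, 0.4], 1.0))   # norm 0.5 <= 1.0, unchanged
```

Clipping the global norm preserves the direction of the update and only shrinks its length, unlike clipping each value independently.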

Learning Rate Scheduling

Automatic LR scheduling with warmup:

optimization:
  scheduler:
    name: CosineAnnealingLR
    warmup_epochs: 5
    min_lr: 1e-6
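The sketch below shows one common way a schedule like this behaves per epoch: a linear warmup over warmup_epochs, then cosine annealing from the base learning rate down to min_lr, using the CosineAnnealingLR formula with eta_min = min_lr. How the library actually combines warmup and cosine decay is an assumption here, and base_lr and max_epochs are hypothetical parameters not shown in the config above.

```python
import math


def lr_at_epoch(epoch: int, base_lr: float, max_epochs: int,
                warmup_epochs: int = 5, min_lr: float = 1e-6) -> float:
    """Linear warmup followed by cosine annealing down to min_lr (a sketch)."""
    if epoch < warmup_epochs:
        # Ramp linearly from base_lr / warmup_epochs up to base_lr.
        return base_lr * (epoch + 1) / warmup_epochs
    # Cosine decay over the remaining epochs, reaching min_lr at max_epochs.
    t = epoch - warmup_epochs
    t_max = max(max_epochs - warmup_epochs, 1)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * t / t_max))


for e in [0, 4, 5, 50, 100]:
    print(e, f"{lr_at_epoch(e, base_lr=1e-3, max_epochs=100):.2e}")
```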

Deep Supervision

Multi-scale loss computation:

model:
  loss:
    deep_supervision: true
    losses:
      - function: DiceLoss
        weight: 1.0

The module automatically:

  • Computes losses at multiple scales

  • Resizes ground truth to match each scale

  • Averages losses across scales
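The three steps above can be sketched on toy 2-D label maps in plain Python. This is not the library's implementation: the helper names are hypothetical, ground truth is resized by nearest-neighbour striding, every scale gets equal weight, and the Dice loss is written out by hand. The real module works on 3-D tensors and may resize or weight the scales differently.

```python
def downsample_nearest(label: list[list[int]], factor: int) -> list[list[int]]:
    """Nearest-neighbour downsampling of a 2-D label map by an integer factor."""
    return [row[::factor] for row in label[::factor]]


def dice_loss(pred: list[list[float]], target: list[list[int]],
              eps: float = 1e-6) -> float:
    """Soft Dice loss 1 - 2|P*T| / (|P| + |T|) on flattened 2-D maps."""
    p = [v for row in pred for v in row]
    t = [v for row in target for v in row]
    inter = sum(a * b for a, b in zip(p, t))
    return 1.0 - (2.0 * inter + eps) / (sum(p) + sum(t) + eps)


def deep_supervision_loss(preds_by_scale, label):
    """Average the loss over predictions made at several resolutions."""
    losses = []
    for pred in preds_by_scale:
        # Resize the ground truth to match this output's resolution.
        factor = len(label) // len(pred)
        losses.append(dice_loss(pred, downsample_nearest(label, factor)))
    return sum(losses) / len(losses)


label = [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 0, 0, 0],
         [0, 0, 0, 0]]
full_res = [[1.0, 1.0, 0.0, 0.0],
            [1.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0]]  # perfect prediction, Dice loss ~0
half_res = [[0.5, 0.0],
            [0.0, 0.0]]            # under-confident, Dice loss ~1/3
loss = deep_supervision_loss([full_res, half_res], label)
print(round(loss, 4))  # 0.1667
```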

Callbacks

The trainer built by create_trainer includes the following callbacks:

Model Checkpointing

monitor:
  checkpoint:
    monitor: "val/loss"
    mode: "min"
    save_top_k: 3
    save_last: true
    filename: "epoch{epoch:02d}-loss{val/loss:.2f}"
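With save_top_k: 3 and mode: "min", only the three checkpoints with the lowest val/loss are kept, and save_last: true additionally keeps a last.ckpt from the most recent save. The sketch below reproduces the top-k selection rule in plain Python with a heap; track_top_k is a hypothetical helper and the loss values are made up.

```python
import heapq


def track_top_k(val_losses: list[float], k: int = 3) -> list[tuple[float, int]]:
    """Return the (val_loss, epoch) pairs a save_top_k=k, mode="min" policy keeps."""
    # Negated losses turn heapq's min-heap into a max-heap: the root is
    # the worst checkpoint currently kept.
    heap: list[tuple[float, int]] = []
    for epoch, loss in enumerate(val_losses):
        if len(heap) < k:
            heapq.heappush(heap, (-loss, epoch))
        elif loss < -heap[0][0]:
            # The new checkpoint beats the worst kept one, so it replaces it.
            heapq.heapreplace(heap, (-loss, epoch))
    return sorted((-neg_loss, epoch) for neg_loss, epoch in heap)


losses = [0.90, 0.70, 0.75, 0.60, 0.65, 0.80]
print(track_top_k(losses))  # [(0.6, 3), (0.65, 4), (0.7, 1)]
```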

Early Stopping

monitor:
  early_stopping:
    enabled: true
    monitor: "val/loss"
    patience: 10
    mode: "min"
    min_delta: 0.0
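With mode: "min", an epoch counts as an improvement only if val/loss falls below the best value so far by more than min_delta, and training stops after patience consecutive epochs without improvement. The plain-Python sketch below applies that rule to a made-up list of validation losses; early_stop_epoch is a hypothetical helper, not the callback itself.

```python
from typing import Optional


def early_stop_epoch(val_losses: list[float], patience: int = 10,
                     min_delta: float = 0.0) -> Optional[int]:
    """Return the epoch at which a mode="min" early-stopping rule would stop.

    Returns None if training would run through all given epochs.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember the new best and reset the counter.
            best, wait = loss, 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return None


losses = [1.0, 0.8, 0.7, 0.71, 0.72, 0.70, 0.73]
print(early_stop_epoch(losses, patience=3))                  # 5
print(early_stop_epoch(losses, patience=3, min_delta=0.15))  # 4
```

Note that matching the best loss exactly (0.70 at epoch 5) is not an improvement, and a larger min_delta makes the rule stricter, so training stops earlier.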

Learning Rate Monitoring

The learning rate is logged automatically to TensorBoard or Weights & Biases.

Logging

TensorBoard (Default)

monitor:
  logging:
    scalar:
      loss_every_n_steps: 10

Logs are saved to outputs/lightning_logs/.

View with:

tensorboard --logdir outputs/lightning_logs

Weights & Biases (Optional)

monitor:
  wandb:
    use_wandb: true
    project: "connectomics"
    entity: "your_team"
    name: "lucchi_exp"

Advanced Usage

Custom Callbacks

from pytorch_lightning.callbacks import Callback

class MyCallback(Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        print(f"Epoch {trainer.current_epoch} finished!")

# Add to trainer
from pytorch_lightning import Trainer

trainer = Trainer(
    max_epochs=100,
    callbacks=[MyCallback()]
)

Custom Training Step

from connectomics.training.lightning import ConnectomicsModule

class CustomModule(ConnectomicsModule):
    def training_step(self, batch, batch_idx):
        # Custom training logic
        images, labels = batch
        outputs = self.model(images)

        # Custom loss computation
        loss = self.compute_loss(outputs, labels)

        # Log metrics
        self.log('train/loss', loss)

        return loss

Inference

Single Batch Prediction

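import torch
from connectomics.training.lightning import ConnectomicsModule

# Load trained model (cfg is the config used for training)
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg
)

model.eval()
model.cuda()

# Predict; input_batch is a tensor and must be on the same device as the model
with torch.no_grad():
    output = model(input_batch.cuda())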

Full Dataset Inference

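from connectomics.training.lightning import (
    ConnectomicsModule,
    create_datamodule,
    create_trainer
)

# Load model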
model = ConnectomicsModule.load_from_checkpoint(
    "outputs/epoch=99.ckpt",
    cfg=cfg
)

# Create datamodule
datamodule = create_datamodule(cfg)

# Create trainer
trainer = create_trainer(cfg)

# Run inference
predictions = trainer.predict(model, datamodule=datamodule)

Resuming Training

# Resume from checkpoint
trainer = create_trainer(cfg)
trainer.fit(
    model,
    datamodule=datamodule,
    ckpt_path="outputs/last.ckpt"
)

Or from command line:

python scripts/main.py \
    --config tutorials/minimal.yaml \
    --resume outputs/last.ckpt

See Also