PyTorch Lightning

What will you get with this integration?

PyTorch Lightning is a lightweight PyTorch wrapper for high-performance AI research. With the Neptune integration you can:

  • monitor model training live,

  • log training, validation, and testing metrics, and visualize them in the Neptune UI,

  • log hyperparameters,

  • monitor hardware usage,

  • log any additional metrics,

  • log performance charts and images,

  • save model checkpoints.


To install the Neptune + PyTorch Lightning integration, go to your console and run:

pip install 'neptune-client[pytorch-lightning]'


Create NeptuneLogger

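from neptune.new.integrations.pytorch_lightning import NeptuneLogger
neptune_logger = NeptuneLogger(
    api_key='YOUR_API_TOKEN',  # placeholder: your Neptune API token
    project='WORKSPACE/PROJECT',  # placeholder: your Neptune project name
    name='lightning-run',  # Optional
)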

Pass your Neptune Project name and API token to NeptuneLogger.
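One way to avoid hard-coding the project name and API token is to read them from environment variables. The sketch below assumes the variables NEPTUNE_API_TOKEN and NEPTUNE_PROJECT are set; the helper name neptune_credentials is hypothetical and is not part of Neptune or PyTorch Lightning.

```python
import os


def neptune_credentials(env=None):
    """Collect NeptuneLogger keyword arguments from environment variables.

    Hypothetical helper: not part of Neptune or PyTorch Lightning.
    """
    env = os.environ if env is None else env
    required = ("NEPTUNE_API_TOKEN", "NEPTUNE_PROJECT")
    # Fail early with a clear message instead of at logger creation time.
    missing = [key for key in required if not env.get(key)]
    if missing:
        raise RuntimeError("Missing environment variables: " + ", ".join(missing))
    return {"api_key": env["NEPTUNE_API_TOKEN"], "project": env["NEPTUNE_PROJECT"]}


# Usage (assumes NeptuneLogger has been imported as in this section):
# neptune_logger = NeptuneLogger(name='lightning-run', **neptune_credentials())
```

Keeping the token out of the source file also keeps it out of version control.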

Pass NeptuneLogger to Trainer

Pass the NeptuneLogger instance to the Lightning Trainer to log model training metadata to Neptune:

from pytorch_lightning import Trainer
trainer = Trainer(max_epochs=10, logger=neptune_logger)

Run model training

Pass your LightningModule and training DataLoader to the Trainer and run .fit():

trainer.fit(model, train_loader)

Explore Results

You just learned how to start logging PyTorch Lightning model training runs to Neptune using the Neptune logger.

Use the logger inside your LightningModule class

You can log images, model checkpoints, and other ML metadata from inside your training and evaluation steps.

To do that, access the Neptune Run through self.logger.experiment:

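from neptune.new.types import File
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def training_step(self, batch, batch_idx):
        # log metrics (the field names below are examples)
        acc = ...
        self.logger.experiment['train/acc'].log(acc)
        # log images
        img = ...
        self.logger.experiment['train/images'].log(File.as_image(img))

    def any_lightning_module_function_or_hook(self):
        # log model checkpoint (placeholder path)
        self.logger.experiment['checkpoints/last'].upload('path/to/checkpoint.ckpt')
        # generic recipe
        metadata = ...
        self.logger.experiment['your/metadata/structure'].log(metadata)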

You can log other model-building metadata like metrics, images, video, audio, interactive visualizations, and more. See What can you log and display?
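Neptune's Run addresses each piece of metadata by a slash-separated field path, with calls such as run['train/acc'].log(value) for series and run['some/file'].upload(path) for files. To try out logging code without a Neptune connection, a small in-memory stand-in can record the same calls. The sketch below is only that: DryRun and DryRunField are hypothetical names, not part of Neptune, and they mimic just these two methods.

```python
from collections import defaultdict


class DryRunField:
    """Records values for one field path in memory."""

    def __init__(self, store, path):
        self._store = store
        self._path = path

    def log(self, value):
        # Append to a series, as for a metric logged once per step.
        self._store[self._path].append(value)

    def upload(self, value):
        # Keep a single value, as for a file uploaded once.
        self._store[self._path] = [value]


class DryRun:
    """Hypothetical in-memory stand-in for the object behind self.logger.experiment."""

    def __init__(self):
        self.fields = defaultdict(list)

    def __getitem__(self, path):
        return DryRunField(self.fields, path)


run = DryRun()
for acc in (0.71, 0.78, 0.83):
    run["train/acc"].log(acc)  # one value per training step
run["checkpoints/last"].upload("last.ckpt")  # placeholder file name

print(dict(run.fields))
```

Grouping fields under prefixes such as train/ and checkpoints/ keeps related metadata together.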

Log after training is finished

If you want to log objects after the training is finished, use close_after_fit=False:

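neptune_logger = NeptuneLogger(..., close_after_fit=False)
trainer = Trainer(logger=neptune_logger)
trainer.fit(model, train_loader)
# Log confusion matrix after training
from neptune.new.types import File
from scikitplot.metrics import plot_confusion_matrix
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(16, 12))
plot_confusion_matrix(y_true, y_pred, ax=ax)
neptune_logger.experiment['test/confusion_matrix'].upload(File.as_image(fig))  # example field name
# Stop logging
neptune_logger.experiment.stop()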
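The confusion matrix example takes y_true and y_pred as given: two equal-length sequences of true and predicted class labels. As an illustration of what the plotted matrix contains, the pure-Python sketch below tallies the counts directly. confusion_counts is a hypothetical helper written for this example, not a function from scikit-plot or Neptune, and the label values are made up.

```python
def confusion_counts(y_true, y_pred):
    """Return (labels, matrix), where matrix[i][j] counts samples whose
    true label is labels[i] and whose predicted label is labels[j]."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    labels = sorted(set(y_true) | set(y_pred))
    index = {label: i for i, label in enumerate(labels)}
    matrix = [[0] * len(labels) for _ in labels]
    for true_label, predicted_label in zip(y_true, y_pred):
        matrix[index[true_label]][index[predicted_label]] += 1
    return labels, matrix


# Made-up labels for three classes.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
labels, matrix = confusion_counts(y_true, y_pred)
for label, row in zip(labels, matrix):
    print(label, row)
```

Rows correspond to true labels and columns to predictions, so correct predictions lie on the diagonal.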

Pass additional parameters to NeptuneLogger

You can also pass kwargs to specify the Run in greater detail, like tags and description:

neptune_logger = NeptuneLogger(
    description='mlp quick run with pytorch-lightning',
    tags=['mlp', 'quick-run'],
)
trainer = Trainer(max_epochs=3, logger=neptune_logger)

For more information about the Neptune Run, see Core Concepts.

What's next?