
API reference: PyTorch integration

You can use the NeptuneLogger to capture model training metadata when working with PyTorch.


NeptuneLogger

Captures model training metadata and logs it to Neptune.

Parameters

| Name | Type | Default | Description |
|---|---|---|---|
| run | Run or Handler | - | (required) An existing run reference, as returned by neptune.init_run(), or a namespace handler. |
| base_namespace | str, optional | "training" | Namespace under which all metadata logged by the Neptune logger is stored. |
| model | torch.nn.Module | - | (required) PyTorch model object to be tracked. |
| log_model_diagram | bool, optional | False | Whether to save the model visualization. Requires torchviz to be installed. |
| log_gradients | bool, optional | False | Whether to track the Frobenius norm of the gradients. |
| log_parameters | bool, optional | False | Whether to track the Frobenius norm of the parameters. |
| log_freq | int, optional | 100 | How often to log the parameter and gradient norms. Applies only if log_parameters or log_gradients is set to True. |
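For reference, the Frobenius norm of a matrix is the square root of the sum of its squared entries. The pure-Python sketch below computes it for a small, made-up gradient matrix; frobenius_norm is a hypothetical helper for illustration, not part of neptune-pytorch. In PyTorch, Tensor.norm() with its default arguments computes the same quantity.

```python
import math


def frobenius_norm(matrix):
    """Square root of the sum of squared entries of a 2-D list (hypothetical helper)."""
    return math.sqrt(sum(value * value for row in matrix for value in row))


# A made-up 2x2 "gradient" matrix: 1 + 4 + 4 + 16 = 25, so the norm is 5.0
grad = [[1.0, 2.0],
        [2.0, 4.0]]
print(frobenius_norm(grad))  # 5.0
```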

Examples

Creating a Neptune run and logger

Create a run:

import neptune

run = neptune.init_run()
If Neptune can't find your project name or API token

As a best practice, you should save your Neptune API token and project name as environment variables:

export NEPTUNE_API_TOKEN="h0dHBzOi8aHR0cHM6Lkc78ghs74kl0jv...Yh3Kb8"
export NEPTUNE_PROJECT="ml-team/classification"
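If init_run() still can't find your credentials, it can help to check which of the two variables the Python process actually sees. A minimal sketch; missing_neptune_env is a hypothetical helper, not part of the neptune package.

```python
import os

# The two environment variables described above
REQUIRED_VARS = ("NEPTUNE_API_TOKEN", "NEPTUNE_PROJECT")


def missing_neptune_env(environ=None):
    """Return the names of the required variables that are unset or empty."""
    environ = os.environ if environ is None else environ
    return [name for name in REQUIRED_VARS if not environ.get(name)]


missing = missing_neptune_env()
if missing:
    print("Set these before calling neptune.init_run():", ", ".join(missing))
```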

Alternatively, you can pass the information when using a function that takes api_token and project as arguments:

run = neptune.init_run(
    api_token="h0dHBzOi8aHR0cHM6Lkc78ghs74kl0jv...Yh3Kb8",  # your token here
    project="ml-team/classification",  # your full project name here
)

- Passing the credentials this way also works for init_model(), init_model_version(), init_project(), and integrations that create Neptune runs under the hood, such as NeptuneLogger or NeptuneCallback.
- API token: In the bottom-left corner, expand the user menu and select Get my API token.
- Project name: You can copy the path from the project details (Edit project details).

If you haven't registered, you can log anonymously to a public project by passing these arguments instead:

api_token=neptune.ANONYMOUS_API_TOKEN
project="common/quickstarts"

Make sure not to publish sensitive data through your code!

Create the Neptune logger:

from neptune_pytorch import NeptuneLogger

# model is the torch.nn.Module you want to track
neptune_logger = NeptuneLogger(run=run, model=model)

Train your model:

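import torch.nn.functional as F

# model, optimizer, train_loader, and device are assumed to be defined already
for epoch in range(1, 4):
    model.train()  # enable training-mode behavior (dropout, batch norm updates)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()  # clear gradients from the previous step
        output = model(data)
        loss = F.nll_loss(output, target)  # expects log-probabilities, e.g. from log_softmax
        loss.backward()  # compute gradients
        optimizer.step()  # update the parameters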

Additional options

import neptune
from neptune_pytorch import NeptuneLogger

run = neptune.init_run(
    name="My PyTorch run",
    tags=["test", "pytorch"],
    dependencies="infer",
)

neptune_logger = NeptuneLogger(
    run=run,
    model=model,
    base_namespace="test",
    log_model_diagram=True,
    log_gradients=True,
    log_parameters=True,
    log_freq=50,
)
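To estimate how many norm values a given log_freq produces, the sketch below lists the training steps at which a value would be recorded. It assumes logging fires on every log_freq-th step, counted from 1; that counting rule is an assumption, not something this page specifies. logged_steps and the figure of 469 batches per epoch are hypothetical.

```python
def logged_steps(total_steps, log_freq):
    """1-based steps at which a norm would be recorded, assuming every log_freq-th step is logged."""
    return [step for step in range(1, total_steps + 1) if step % log_freq == 0]


# Hypothetical run: 3 epochs of 469 batches each, with log_freq=50 as in the example above
steps = logged_steps(total_steps=3 * 469, log_freq=50)
print(len(steps), steps[:3])  # 28 [50, 100, 150]
```

Lowering log_freq gives finer-grained norm charts at the cost of more logged values.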

log_checkpoint()

Uploads a model checkpoint to Neptune, into a namespace called model/checkpoints nested under the logger's base namespace (base_namespace).

The filename is set to checkpoint_<checkpoint number>.pt by default, but you can customize it with checkpoint_name.

Parameters

| Name | Type | Default | Description |
|---|---|---|---|
| checkpoint_name | str, optional | checkpoint_<checkpoint number>.pt | Name for the logged checkpoint file. If omitted, the default name is used: the checkpoint number starts at 1 and is incremented automatically on each call. The extension .pt is added automatically. |
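The naming rule can be sketched in plain Python. CheckpointNamer is a hypothetical class that only mimics the documented behavior, not neptune-pytorch's code; in particular, it assumes that passing a custom checkpoint_name does not advance the automatic counter.

```python
class CheckpointNamer:
    """Hypothetical sketch of the documented checkpoint naming rule."""

    def __init__(self):
        self._next_number = 1  # default numbering starts at 1

    def filename(self, checkpoint_name=None):
        if checkpoint_name is None:
            # Default name: use the current number, then advance the counter
            checkpoint_name = f"checkpoint_{self._next_number}"
            self._next_number += 1
        # The .pt extension is appended to whichever name is used
        return f"{checkpoint_name}.pt"


namer = CheckpointNamer()
print(namer.filename())        # checkpoint_1.pt
print(namer.filename())        # checkpoint_2.pt
print(namer.filename("best"))  # best.pt
```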

Example

from neptune_pytorch import NeptuneLogger

neptune_logger = NeptuneLogger(...)
...
for epoch in range(parameters["epochs"]):
    ...
    neptune_logger.log_checkpoint()

log_model()

Uploads the model to Neptune, into a namespace called model nested under the logger's base namespace (base_namespace).

The filename is set to model.pt by default, but you can customize it with model_name.
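Combining this with the log_checkpoint() description, the model and its checkpoints end up in two nested namespaces under base_namespace. The sketch below only builds those path strings to show where to look in the Neptune app; logger_namespaces is a hypothetical helper and does not contact Neptune.

```python
def logger_namespaces(base_namespace="training"):
    """Hypothetical helper: namespaces used for model and checkpoint uploads."""
    model_ns = f"{base_namespace}/model"
    return {
        "model": model_ns,                       # target of log_model()
        "checkpoints": f"{model_ns}/checkpoints",  # target of log_checkpoint()
    }


print(logger_namespaces())        # {'model': 'training/model', 'checkpoints': 'training/model/checkpoints'}
print(logger_namespaces("test"))  # {'model': 'test/model', 'checkpoints': 'test/model/checkpoints'}
```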

Parameters

| Name | Type | Default | Description |
|---|---|---|---|
| model_name | str, optional | model.pt | Name for the logged model file. The extension .pt is added automatically. |

Example

from neptune_pytorch import NeptuneLogger

neptune_logger = NeptuneLogger(...)
...
neptune_logger.log_model()

save_checkpoint

See log_checkpoint.

save_model

See log_model.


See also

neptune-pytorch repo on GitHub