
API reference: skorch integration

You can use a Neptune callback to capture model training metadata when using skorch.

NeptuneLogger

Captures the training history of a skorch net, such as NeuralNetClassifier, and logs it to Neptune.

Parameters

run (Run, required): An existing run reference, as returned by neptune.init_run().

log_on_batch_end (bool, optional, default False): Whether to log loss and other metrics at the batch level.

close_after_train (bool, optional, default True): Whether to stop the run once training finishes. Set to False if you want to continue logging to the same run after training, or if you manage the run with a context manager.

keys_ignored (str or list of str, optional, default None): Key or list of keys that should not be logged to Neptune. In addition to the keys you provide, some keys are ignored by default, such as those starting with "event_" or ending with "_best".

base_namespace (str, optional, default "training"): Namespace under which all metadata logged by the Neptune callback is stored.
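The default filtering described for keys_ignored can be approximated in plain Python. The helper below is a hypothetical sketch, not skorch's implementation: it models only the two default rules named above (the "event_" prefix and the "_best" suffix), while skorch may ignore further keys. It also accepts a single string, matching the documented "str or list of str" type.

```python
from typing import Iterable, List, Optional, Union


def loggable_keys(
    keys: Iterable[str],
    keys_ignored: Optional[Union[str, List[str]]] = None,
) -> List[str]:
    """Return the history keys that would be logged under the documented rules.

    Hypothetical helper: only the "event_" prefix and "_best" suffix defaults
    from the parameter description are modeled here.
    """
    # keys_ignored may be a single key or a list of keys
    if keys_ignored is None:
        user_ignored = set()
    elif isinstance(keys_ignored, str):
        user_ignored = {keys_ignored}
    else:
        user_ignored = set(keys_ignored)

    result = []
    for key in keys:
        if key in user_ignored:
            continue  # explicitly ignored by the user
        if key.startswith("event_") or key.endswith("_best"):
            continue  # ignored by default
        result.append(key)
    return result


history_keys = ["train_loss", "train_loss_best", "valid_loss",
                "valid_loss_best", "valid_acc", "event_lr"]
print(loggable_keys(history_keys, keys_ignored="valid_acc"))
# ['train_loss', 'valid_loss']
```

Normalizing a single string into a set first keeps the membership check the same for both accepted types.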

Examples

Create a NeptuneLogger callback:

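import neptune.new as neptune
from skorch.callbacks import NeptuneLogger
# Keep the run open after fit() so more metadata can be logged to it
neptune_logger = NeptuneLogger(neptune.init_run(), close_after_train=False)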

(Optional) Create a Checkpoint callback that saves model checkpoints to a local directory:

from skorch.callbacks import Checkpoint
checkpoint_dirname = "./checkpoints"
checkpoint = Checkpoint(dirname=checkpoint_dirname)

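Pass the callbacks to the net's callbacks argument, then train:

from skorch import NeuralNetClassifier

# ClassifierModule is your torch.nn.Module; X and y are your training data
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.01,
    # Leave out checkpoint if you skipped the optional Checkpoint step
    callbacks=[neptune_logger, checkpoint],
)

# Run training
net.fit(X, y)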
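If you added the optional Checkpoint callback, its files are written to the local checkpoint directory. One way to keep them with the run is to upload them after training, while the run is still open (close_after_train=False). The helper below is a hypothetical sketch, not part of skorch or Neptune: it collects the files with glob and passes the list to the field's upload_files() method, which Neptune provides for uploading several files to one field. The field name "checkpoints" is an arbitrary choice.

```python
import glob
import os
from typing import Any, List


def upload_checkpoints(run: Any, dirname: str = "./checkpoints",
                       field: str = "checkpoints") -> List[str]:
    """Upload every regular file in `dirname` to `run[field]`; return the paths.

    `run` is assumed to be an open Neptune run (for example neptune_logger.run);
    only run[field].upload_files(...) is called on it.
    """
    paths = sorted(
        path for path in glob.glob(os.path.join(dirname, "*"))
        if os.path.isfile(path)  # skip subdirectories
    )
    if paths:
        run[field].upload_files(paths)
    return paths
```

For example, call upload_checkpoints(neptune_logger.run, checkpoint_dirname) after net.fit(X, y) and before stopping the run.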

Log additional metrics after training has finished:

from sklearn.metrics import roc_auc_score
y_pred = net.predict_proba(X)
auc = roc_auc_score(y, y_pred[:, 1])
neptune_logger.run["roc_auc_score"].append(auc)
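predict_proba returns one row per sample with one probability per class, so y_pred[:, 1] selects the probability of the positive class (label 1) in a binary problem. The plain-Python sketch below only illustrates what that score measures: it computes ROC AUC as the fraction of (positive, negative) pairs in which the positive sample gets the higher score, counting ties as half. This pairwise form is a standard equivalent of the area under the ROC curve; in practice, keep using roc_auc_score.

```python
from typing import List, Sequence


def positive_class_column(proba_rows: Sequence[Sequence[float]]) -> List[float]:
    """Equivalent of y_pred[:, 1] for a list of [p_class0, p_class1] rows."""
    return [row[1] for row in proba_rows]


def pairwise_auc(y_true: Sequence[int], scores: Sequence[float]) -> float:
    """ROC AUC as the share of positive/negative pairs ranked correctly (ties count 0.5)."""
    pos = [s for label, s in zip(y_true, scores) if label == 1]
    neg = [s for label, s in zip(y_true, scores) if label == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


y = [0, 0, 1, 1]
proba = [[0.9, 0.1], [0.6, 0.4], [0.65, 0.35], [0.2, 0.8]]
print(pairwise_auc(y, positive_class_column(proba)))  # 0.75
```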

Log charts, such as an ROC curve:

from scikitplot.metrics import plot_roc
import matplotlib.pyplot as plt

from neptune.new.types import File

fig, ax = plt.subplots(figsize=(16, 12))
plot_roc(y, y_pred, ax=ax)
neptune_logger.run["roc_curve"].upload(File.as_html(fig))

Save the trained net's parameters and upload the file to the run:

net.save_params(f_params="basic_model.pkl")
neptune_logger.run["basic_model"].upload("basic_model.pkl")
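save_params() stores the module's parameters, not the whole net, so restoring the model means building a net with the same architecture, initializing it, and then loading the file. The helper below sketches that sequence with skorch's initialize() and load_params(f_params=...); make_net is a hypothetical factory you would write to rebuild the same NeuralNetClassifier configuration.

```python
from typing import Any, Callable


def restore_net(make_net: Callable[[], Any], f_params: str = "basic_model.pkl") -> Any:
    """Rebuild a net and load parameters previously written by net.save_params().

    make_net (hypothetical) must return a net whose module has the same
    architecture as the one that was saved.
    """
    net = make_net()
    net.initialize()  # the module must exist before parameters can be loaded
    net.load_params(f_params=f_params)
    return net
```

For example, restore_net(lambda: NeuralNetClassifier(ClassifierModule)) rebuilds the net without the Neptune callback, so using the restored model does not log anything to the run.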

Close the run if needed

If you set close_after_train=False, close the run when done:

neptune_logger.run.stop()
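With close_after_train=False the callback leaves the run open, so your code is responsible for stopping it, including when training or the extra logging raises an exception. A try/finally block is one way to guarantee that stop() is called. The function below is a sketch; log_extras is a hypothetical callable standing in for the additional logging steps shown earlier.

```python
from typing import Any, Callable, Optional


def fit_and_log(
    net: Any,
    neptune_logger: Any,
    X: Any,
    y: Any,
    log_extras: Optional[Callable[[Any], None]] = None,
) -> Any:
    """Train, run optional extra logging, and always stop the Neptune run.

    Assumes neptune_logger was created with close_after_train=False.
    """
    try:
        net.fit(X, y)
        if log_extras is not None:
            log_extras(neptune_logger.run)  # e.g. log metrics or upload charts
        return net
    finally:
        neptune_logger.run.stop()  # runs even if fit() or log_extras() raises
```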