API reference: skorch integration
You can use a Neptune callback to capture model training metadata when using skorch.
NeptuneLogger
Captures the NeuralNetClassifier history and logs the metadata to Neptune.
Parameters
| Name | Type | Default | Description |
|---|---|---|---|
| run | Run | - | An existing run reference, as returned by neptune.init_run(), or a namespace handler. |
| log_on_batch_end | bool, optional | False | Whether to log loss and other metrics at the batch level. |
| close_after_train | bool, optional | True | Whether to close the run object once training finishes. Set to False if you want to continue logging to the same run after training, or if you use the run as a context manager. |
| keys_ignored | str or list of str, optional | None | Key or list of keys that should not be logged to Neptune. In addition to the keys provided by the user, keys such as those starting with "event_" or ending with "_best" are ignored by default. |
| base_namespace | str, optional | "training" | Namespace under which all metadata logged by the Neptune callback is stored. |
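The keys_ignored rule can be illustrated with a small sketch. filter_keys below is a hypothetical helper, not part of skorch or Neptune, and it applies the rule to a single made-up epoch entry from the history. It models only what the table states: the keys you pass in, plus keys starting with "event_" or ending with "_best". The real callback may skip additional keys.

```python
from typing import Iterable, Iterator, List, Optional, Union


def filter_keys(
    keys: Iterable[str],
    keys_ignored: Optional[Union[str, List[str]]] = None,
) -> Iterator[str]:
    """Yield the keys that pass the documented ignore rules (hypothetical helper)."""
    # keys_ignored may be a single string or a list of strings
    if isinstance(keys_ignored, str):
        keys_ignored = [keys_ignored]
    ignored = set(keys_ignored or [])
    for key in keys:
        # Default rule: skip event markers and "best so far" flags
        if key in ignored or key.startswith("event_") or key.endswith("_best"):
            continue
        yield key


# A made-up epoch entry, shaped like a row of a skorch history
epoch_row = {
    "epoch": 3,
    "train_loss": 0.41,
    "train_loss_best": True,
    "valid_loss": 0.45,
    "valid_acc": 0.81,
    "event_cp": True,
    "dur": 0.9,
}
print(list(filter_keys(epoch_row, keys_ignored="dur")))
# ['epoch', 'train_loss', 'valid_loss', 'valid_acc']
```

Passing "dur" as a plain string works the same as passing ["dur"]. This mirrors the str or list of str type listed in the table.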
Examples
Create a NeptuneLogger callback:
(Optional) Set the path to the checkpoints directory:
Pass the callbacks to the callbacks argument of the net:
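```python
from skorch import NeuralNetClassifier

# ClassifierModule (a torch.nn.Module subclass), X, and y are assumed to be
# defined beforehand; neptune_logger and checkpoint come from the steps above
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.01,
    callbacks=[neptune_logger, checkpoint],
)

# Run training
net.fit(X, y)
```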
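Log additional metrics after training has finished. This example and the ones that follow require close_after_train=False, because with the default value the run is closed once training finishes: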
```python
from sklearn.metrics import roc_auc_score

# Predicted class probabilities; column 1 holds the positive class
y_pred = net.predict_proba(X)
auc = roc_auc_score(y, y_pred[:, 1])
neptune_logger.run["roc_auc_score"].append(auc)
```
Log charts, such as an ROC curve:
```python
import matplotlib.pyplot as plt
from neptune.types import File
from scikitplot.metrics import plot_roc

# Plot the ROC curve from the probabilities computed in the previous example
fig, ax = plt.subplots(figsize=(16, 12))
plot_roc(y, y_pred, ax=ax)
neptune_logger.run["roc_curve"].upload(File.as_html(fig))
```
Log the trained model parameters after training:
```python
# Save the learned parameters locally, then upload the file to the run
net.save_params(f_params="basic_model.pkl")
neptune_logger.run["basic_model"].upload("basic_model.pkl")
```
Close the run if needed
If you set close_after_train=False, close the run when you are done:
See also
NeptuneLogger in the skorch API reference