Polyaxon lets you schedule Pytorch-Lightning experiments and track their metrics, outputs, and models.

With Polyaxon you can:

  • log hyperparameters for every run
  • see learning curves for losses and metrics during training
  • see hardware consumption and stdout/stderr output during training
  • log images, charts, and other assets
  • log git commit information
  • log environment information
  • log the model

Tracking API

Polyaxon provides a tracking API to track experiments and report metrics, artifacts, logs, and results to the Polyaxon dashboard.

You can use the tracking API to create a custom tracking experience with Pytorch-Lightning.


To use Polyaxon tracking with Pytorch-Lightning, you need to install the Polyaxon library:

pip install polyaxon

Initialize your script with Polyaxon

This step is optional. Use it if you need to do some manual tracking or to track some information before passing the callback.

from polyaxon import tracking

tracking.init()


Polyaxon callback

Pass Polyaxon’s callback to the Trainer as its logger:

import pytorch_lightning as pl
from polyaxon.tracking.contrib.pytorch_lightning import PolyaxonCallback

trainer = pl.Trainer(logger=PolyaxonCallback())

Customizing the callback

By default, the callback uses the currently initialized run. If you need more control, you can use a different run, since the Pytorch-Lightning callback accepts all tracking.init arguments:

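import pytorch_lightning as pl
from polyaxon.tracking.contrib.pytorch_lightning import PolyaxonCallback

trainer = pl.Trainer(
    logger=PolyaxonCallback(),  # pass any tracking.init arguments here
)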
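As a rough sketch of the customization described above, the run options can be collected in a dictionary and forwarded to the callback. The keyword names `name`, `description`, and `tags` are assumed here to be among the `tracking.init` arguments, the values are made up for illustration, and `make_polyaxon_logger` and `make_trainer` are hypothetical helpers, not part of Polyaxon. The imports are deferred so the module loads even where Polyaxon or Pytorch-Lightning is not installed.

```python
# Options forwarded to the callback. The keys are assumed to be valid
# tracking.init arguments; the values are illustrative only.
RUN_OPTIONS = {
    "name": "mnist-lightning",
    "description": "MNIST baseline tracked with Polyaxon",
    "tags": ["pytorch-lightning", "mnist"],
}


def make_polyaxon_logger(**init_kwargs):
    """Hypothetical helper: build the callback from tracking.init-style kwargs."""
    # Imported lazily so this module can be loaded without Polyaxon installed.
    from polyaxon.tracking.contrib.pytorch_lightning import PolyaxonCallback

    return PolyaxonCallback(**init_kwargs)


def make_trainer(**trainer_kwargs):
    """Hypothetical helper: create a Trainer that logs through Polyaxon."""
    import pytorch_lightning as pl

    return pl.Trainer(logger=make_polyaxon_logger(**RUN_OPTIONS), **trainer_kwargs)
```

Keeping the options in one place makes it easy to reuse the same run metadata across several training scripts.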

Manual logging

If you want more control, you can log to Polyaxon directly from your scripts by calling the tracking.log_... functions wherever needed:

  • log metrics
  • log artifacts
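A minimal sketch of manual metric logging follows. It assumes `tracking.log_metrics` accepts an optional `step` plus the metric values as keyword arguments. `log_epoch` and `InMemoryTracker` are hypothetical names introduced for illustration. The stand-in tracker lets the snippet run without a Polyaxon installation.

```python
class InMemoryTracker:
    """Hypothetical stand-in that mimics the assumed log_metrics signature."""

    def __init__(self):
        self.records = []

    def log_metrics(self, step=None, **metrics):
        # Store each call so the logged values can be inspected afterwards.
        self.records.append((step, metrics))


def log_epoch(tracker, epoch, train_losses, val_losses):
    """Average per-batch losses and report them as one point per epoch."""
    metrics = {
        "train_loss": sum(train_losses) / len(train_losses),
        "val_loss": sum(val_losses) / len(val_losses),
    }
    tracker.log_metrics(step=epoch, **metrics)
    return metrics


tracker = InMemoryTracker()
log_epoch(tracker, epoch=0, train_losses=[1.0, 0.5], val_losses=[0.75, 0.25])
```

In a real script, the `polyaxon.tracking` module (after `tracking.init()`) would presumably take the place of `InMemoryTracker`, since the helper only relies on a `log_metrics` callable.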

You can also use the callback instance to log extra information:

from polyaxon.tracking.contrib.pytorch_lightning import PolyaxonCallback

callback = PolyaxonCallback()


In this example we will go through the process of logging a Pytorch-Lightning model using Polyaxon’s callback.

This example can run in offline mode (POLYAXON_OFFLINE=true), in which case it does not require a Polyaxon API to run locally.

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
import pytorch_lightning as pl

from polyaxon.tracking.contrib.pytorch_lightning import PolyaxonCallback

class MNISTModel(pl.LightningModule):

    def __init__(self):
        super(MNISTModel, self).__init__()
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        tensorboard_logs = {'train_loss': loss}
        return {'loss': loss, 'log': tensorboard_logs}

    def validation_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        return {'val_loss': F.cross_entropy(y_hat, y)}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_nb):
        x, y = batch
        y_hat = self(x)
        return {'test_loss': F.cross_entropy(y_hat, y)}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean()
        logs = {'test_loss': avg_loss}
        return {'test_loss': avg_loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)

    def train_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def val_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

    def test_dataloader(self):
        return DataLoader(MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor()), batch_size=32)

train_loader = DataLoader(MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor()), batch_size=32)

mnist_model = MNISTModel()
# Use the Polyaxon callback as the Trainer's logger
trainer = pl.Trainer(logger=PolyaxonCallback())
trainer.fit(mnist_model, train_loader)
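To try the example above without a Polyaxon deployment, the POLYAXON_OFFLINE flag mentioned earlier can also be set from Python. This is only a sketch: `enable_offline_mode` is a hypothetical helper, and it assumes Polyaxon reads `POLYAXON_OFFLINE` from the process environment, so it should be called before `polyaxon` is imported.

```python
import os


def enable_offline_mode(environ=os.environ):
    """Set POLYAXON_OFFLINE=true unless the caller already chose a value."""
    environ.setdefault("POLYAXON_OFFLINE", "true")
    return environ["POLYAXON_OFFLINE"]


# Call this at the top of the script, before importing polyaxon.
enable_offline_mode()
```

Using `setdefault` keeps any value already exported in the shell, so the same script can still run against a live Polyaxon API when the variable is explicitly set to something else.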