Example #1
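These snippets are excerpted from a larger test/example module, so their imports are not shown. A plausible import block they rely on (module paths follow the pytorch-ignite API these examples target and are partly an assumption) would be:

import warnings
from unittest.mock import MagicMock

import torch
import torch.nn as nn
from torch.optim import SGD

import ignite.distributed as idist
from ignite.contrib.handlers.neptune_logger import (GradsScalarHandler, NeptuneLogger,
                                                    NeptuneSaver, WeightsScalarHandler)
from ignite.engine import (Engine, Events, State, create_supervised_evaluator,
                           create_supervised_trainer)
from ignite.handlers import Checkpoint, global_step_from_engine
from ignite.metrics import Accuracy, Loss
from ignite.utils import setup_logger
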
def _test_neptune_saver_integration(device):

    model = torch.nn.Module().to(device)
    to_save_serializable = {"model": model}

    mock_logger = None
    if idist.get_rank() == 0:
        # Only rank 0 gets a (mocked) logger; the other ranks pass None and rely
        # on the saver acting on rank 0 only.
        mock_logger = MagicMock(spec=NeptuneLogger)
        mock_logger.log_artifact = MagicMock()
        mock_logger.delete_artifacts = MagicMock()

    saver = NeptuneSaver(mock_logger)

    checkpoint = Checkpoint(to_save=to_save_serializable,
                            save_handler=saver,
                            n_saved=1)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpoint(trainer)
    trainer.state.iteration = 1
    # With n_saved=1, the second save should evict the first artifact.
    checkpoint(trainer)
    if idist.get_rank() == 0:
        assert mock_logger.log_artifact.call_count == 2
        assert mock_logger.delete_artifacts.call_count == 1
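The leading underscore marks this as a shared helper rather than a test collected directly; in the surrounding module it would typically be driven by thin wrappers along these lines (the wrapper names and skip condition are assumptions):

import pytest


def test_neptune_saver_integration_cpu():
    _test_neptune_saver_integration(device="cpu")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no GPU")
def test_neptune_saver_integration_cuda():
    _test_neptune_saver_integration(device="cuda")
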
Example #2
def test_neptune_saver_serializable(dirname):

    mock_logger = MagicMock(spec=NeptuneLogger)
    mock_logger.log_artifact = MagicMock()
    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    saver = NeptuneSaver(mock_logger)
    fname = dirname / "test.pt"
    saver(to_save_serializable, fname)

    assert mock_logger.log_artifact.call_count == 1
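Here dirname is presumably a pytest temporary-directory fixture (e.g. built on tmp_path), so dirname / "test.pt" yields a pathlib.Path; the test only checks that saving a serializable checkpoint produces exactly one log_artifact call.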
Example #3
def test_neptune_saver_non_serializable():

    mock_logger = MagicMock(spec=NeptuneLogger)
    mock_logger.log_artifact = MagicMock()

    to_save_non_serializable = {"model": lambda x: x}

    saver = NeptuneSaver(mock_logger)
    fname = "test.pt"
    try:
        with warnings.catch_warnings():
            # Ignore torch serialization UserWarnings about containers whose
            # source code cannot be retrieved for correctness checks on loading.
            warnings.simplefilter("ignore", category=UserWarning)
            saver(to_save_non_serializable, fname)
    except Exception:
        # torch.save cannot pickle the lambda, so nothing reaches the logger.
        pass

    assert mock_logger.log_artifact.call_count == 0
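Taken together, the three tests pin down the contract these mocks exercise: serialize the checkpoint with torch.save, hand the resulting file to log_artifact, drop superseded checkpoints via delete_artifacts, and do all of it only on rank 0. A simplified save handler with that contract might look like the sketch below (an illustration inferred from the assertions above, not the library's actual source):

import tempfile


class SketchNeptuneSaver:
    """Illustrative save handler matching what the tests above assert."""

    def __init__(self, neptune_logger):
        self._logger = neptune_logger

    def __call__(self, checkpoint, filename, metadata=None):
        if idist.get_rank() == 0:
            with tempfile.NamedTemporaryFile() as tmp:
                # If torch.save raises (e.g. a lambda in the checkpoint),
                # log_artifact is never reached.
                torch.save(checkpoint, tmp)
                tmp.flush()
                self._logger.log_artifact(tmp.name, filename)

    def remove(self, filename):
        if idist.get_rank() == 0:
            self._logger.delete_artifacts(filename)


The run() function below is an end-to-end MNIST-style training script that wires NeptuneLogger and NeptuneSaver together; Net and get_data_loaders are assumed to be defined elsewhere in the same example module.
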
def run(train_batch_size, val_batch_size, epochs, lr, momentum):
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    npt_logger = NeptuneLogger(
        api_token="ANONYMOUS",
        project_name="shared/pytorch-ignite-integration",
        name="ignite-mnist-example",
        params={
            "train_batch_size": train_batch_size,
            "val_batch_size": val_batch_size,
            "epochs": epochs,
            "lr": lr,
            "momentum": momentum,
        },
    )

    npt_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
    )

    for tag, evaluator in [("training", train_evaluator),
                           ("validation", validation_evaluator)]:
        npt_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    npt_logger.attach_opt_params_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        optimizer=optimizer)

    npt_logger.attach(trainer,
                      log_handler=WeightsScalarHandler(model),
                      event_name=Events.ITERATION_COMPLETED(every=100))

    npt_logger.attach(trainer,
                      log_handler=GradsScalarHandler(model),
                      event_name=Events.ITERATION_COMPLETED(every=100))

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    handler = Checkpoint(
        {"model": model},
        NeptuneSaver(npt_logger),
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, handler)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    npt_logger.close()
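
In the upstream example, run() is invoked from an argparse-driven __main__ block; a minimal sketch (the flag names mirror run()'s parameters, the default values are assumptions):

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=64, help="input batch size for training")
    parser.add_argument("--val_batch_size", type=int, default=1000, help="input batch size for validation")
    parser.add_argument("--epochs", type=int, default=10, help="number of epochs to train")
    parser.add_argument("--lr", type=float, default=0.01, help="learning rate")
    parser.add_argument("--momentum", type=float, default=0.5, help="SGD momentum")
    args = parser.parse_args()

    run(args.batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum)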