def _test_neptune_saver_integration(device):
    model = torch.nn.Module().to(device)
    to_save_serializable = {"model": model}

    # Only rank 0 interacts with Neptune, so only rank 0 gets a mocked logger.
    mock_logger = None
    if idist.get_rank() == 0:
        mock_logger = MagicMock(spec=NeptuneLogger)
        mock_logger.log_artifact = MagicMock()
        mock_logger.delete_artifacts = MagicMock()

    saver = NeptuneSaver(mock_logger)

    checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=saver, n_saved=1)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)

    # With n_saved=1, two checkpoint calls mean two uploads and one deletion of the older artifact.
    checkpoint(trainer)
    trainer.state.iteration = 1
    checkpoint(trainer)

    if idist.get_rank() == 0:
        assert mock_logger.log_artifact.call_count == 2
        assert mock_logger.delete_artifacts.call_count == 1

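# A minimal caller for the helper above -- a sketch only; the real test suite may wire
# this into distributed fixtures differently. Assumes a plain CPU run is sufficient and
# the test name is illustrative, not taken from the source.
def test_neptune_saver_integration_cpu():
    _test_neptune_saver_integration("cpu")
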
def test_neptune_saver_serializable(dirname):
    mock_logger = MagicMock(spec=NeptuneLogger)
    mock_logger.log_artifact = MagicMock()
    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    saver = NeptuneSaver(mock_logger)
    fname = dirname / "test.pt"
    saver(to_save_serializable, fname)

    assert mock_logger.log_artifact.call_count == 1

def test_neptune_saver_non_serializable():
    mock_logger = MagicMock(spec=NeptuneLogger)
    mock_logger.log_artifact = MagicMock()

    to_save_non_serializable = {"model": lambda x: x}

    saver = NeptuneSaver(mock_logger)
    fname = "test.pt"
    try:
        with warnings.catch_warnings():
            # Ignore torch/serialization.py:292: UserWarning: Couldn't retrieve source code for container of type
            # DummyModel. It won't be checked for correctness upon loading.
            warnings.simplefilter("ignore", category=UserWarning)
            saver(to_save_non_serializable, fname)
    except Exception:
        pass

    assert mock_logger.log_artifact.call_count == 0

def run(train_batch_size, val_batch_size, epochs, lr, momentum):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cpu"

    if torch.cuda.is_available():
        device = "cuda"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    npt_logger = NeptuneLogger(
        api_token="ANONYMOUS",
        project_name="shared/pytorch-ignite-integration",
        name="ignite-mnist-example",
        params={
            "train_batch_size": train_batch_size,
            "val_batch_size": val_batch_size,
            "epochs": epochs,
            "lr": lr,
            "momentum": momentum,
        },
    )

    npt_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
    )

    for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
        npt_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    npt_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer
    )

    npt_logger.attach(
        trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
    )

    npt_logger.attach(
        trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
    )

    def score_function(engine):
        return engine.state.metrics["accuracy"]

    handler = Checkpoint(
        {"model": model},
        NeptuneSaver(npt_logger),
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, handler)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    npt_logger.close()
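
# A possible command-line entry point for run() -- a sketch, not taken from the original
# example; the argument names and default values below are assumptions for illustration.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="MNIST training with NeptuneLogger")
    parser.add_argument("--train_batch_size", type=int, default=64)
    parser.add_argument("--val_batch_size", type=int, default=1000)
    parser.add_argument("--epochs", type=int, default=10)
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--momentum", type=float, default=0.5)
    args = parser.parse_args()

    run(args.train_batch_size, args.val_batch_size, args.epochs, args.lr, args.momentum)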