# Example #1
0
def test_output_handler_with_global_step_from_engine():
    """Check OutputHandler when the step comes from a different engine.

    The handler is built with ``global_step_from_engine`` bound to a second
    (mocked) engine. Each time the handler fires, the scalar reported to the
    mocked ClearML logger is expected to use that second engine's epoch as
    ``iteration`` and the firing engine's ``state.output`` as ``value``.
    """

    def _make_engine(epoch, output):
        # Mocked engine that carries a real State with the given values.
        engine = MagicMock()
        engine.state = State()
        engine.state.epoch = epoch
        engine.state.output = output
        return engine

    step_engine = _make_engine(epoch=10, output=12.345)

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_from_engine(step_engine),
    )

    logger = MagicMock(spec=ClearMLLogger)
    logger.clearml_logger = MagicMock()

    firing_engine = _make_engine(epoch=1, output=0.123)

    def _fire_and_verify(expected_call_count):
        # Trigger the handler once, then check the running call count and
        # that a call with the current step/value pair was recorded.
        handler(firing_engine, logger, Events.EPOCH_STARTED)
        reporter = logger.clearml_logger.report_scalar
        assert reporter.call_count == expected_call_count
        expected = call(
            title="tag",
            series="loss",
            iteration=step_engine.state.epoch,
            value=firing_engine.state.output,
        )
        reporter.assert_has_calls([expected])

    # First firing: step 10 (from the other engine), value 0.123.
    _fire_and_verify(1)

    # Advance the step source and change the output, then fire again.
    step_engine.state.epoch = 11
    firing_engine.state.output = 1.123
    _fire_and_verify(2)
def run(train_batch_size, val_batch_size, epochs, lr, momentum):
    """Train ``Net`` with SGD and log the run through a ``ClearMLLogger``.

    Builds a supervised trainer plus two evaluators (one run on the training
    loader, one on the validation loader) that compute accuracy and loss
    after every training epoch. Batch loss, evaluator metrics, optimizer
    parameters, and weight/gradient statistics are attached to the ClearML
    logger, and a ``Checkpoint`` handler scored by validation accuracy is
    registered on the validation evaluator.

    Args:
        train_batch_size: Batch size forwarded to ``get_data_loaders`` for
            the training loader.
        val_batch_size: Batch size forwarded to ``get_data_loaders`` for the
            validation loader.
        epochs: Passed to ``trainer.run`` as ``max_epochs``.
        lr: Learning rate for the SGD optimizer.
        momentum: Momentum for the SGD optimizer.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        # Evaluate on both loaders at the end of every training epoch.
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    clearml_logger = ClearMLLogger(project_name="examples", task_name="ignite")

    # Everything from here on runs under try/finally so the logger is closed
    # even if attaching a handler or the training run itself raises.
    try:
        every_100_iterations = Events.ITERATION_COMPLETED(every=100)

        clearml_logger.attach_output_handler(
            trainer,
            event_name=every_100_iterations,
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
        )

        # Evaluator metrics are reported against the trainer's step so both
        # evaluators share the same x-axis.
        for tag, evaluator in [("training metrics", train_evaluator),
                               ("validation metrics", validation_evaluator)]:
            clearml_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                global_step_transform=global_step_from_engine(trainer),
            )

        clearml_logger.attach_opt_params_handler(
            trainer,
            event_name=every_100_iterations,
            optimizer=optimizer)

        clearml_logger.attach(trainer,
                              log_handler=WeightsScalarHandler(model),
                              event_name=every_100_iterations)

        # NOTE(review): the histogram handlers are attached on
        # EPOCH_COMPLETED(every=100), unlike the scalar handlers which use
        # ITERATION_COMPLETED(every=100) -- confirm the epoch-based cadence
        # is intended.
        clearml_logger.attach(trainer,
                              log_handler=WeightsHistHandler(model),
                              event_name=Events.EPOCH_COMPLETED(every=100))

        clearml_logger.attach(trainer,
                              log_handler=GradsScalarHandler(model),
                              event_name=every_100_iterations)

        clearml_logger.attach(trainer,
                              log_handler=GradsHistHandler(model),
                              event_name=Events.EPOCH_COMPLETED(every=100))

        # Checkpoint the model through ClearMLSaver, scored by the
        # validation evaluator's accuracy and stepped by the trainer.
        handler = Checkpoint(
            {"model": model},
            ClearMLSaver(),
            n_saved=1,
            score_function=lambda e: e.state.metrics["accuracy"],
            score_name="val_acc",
            filename_prefix="best",
            global_step_transform=global_step_from_engine(trainer),
        )
        validation_evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                               handler)

        # kick everything off
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        clearml_logger.close()