示例#1
0
def test_logger_init_hostname_port(visdom_server):
    """Build a logger from an explicit server and port and look for "main".

    ``visdom_server`` is indexed as ``[0]`` for the server value and ``[1]``
    for the port value.
    """
    host = visdom_server[0]
    port = visdom_server[1]
    logger = VisdomLogger(server=host, port=port, num_workers=0)

    env_names = logger.vis.get_env_list()
    assert "main" in env_names

    logger.close()
示例#2
0
def test_logger_init_env_vars(visdom_server):
    """Check that the logger can be configured through environment variables.

    ``VISDOM_SERVER_URL`` and ``VISDOM_PORT`` are set from ``visdom_server``
    (``[0]`` is the server value, ``[1]`` the port value) and the logger is
    created without explicit ``server``/``port`` arguments, so the connection
    settings have to come from the environment.

    NOTE(review): this relies on ``VisdomLogger`` reading these two variables
    when ``server``/``port`` are omitted -- confirm against the logger's
    documentation.
    """
    import os
    from unittest.mock import patch

    env_overrides = {
        "VISDOM_SERVER_URL": visdom_server[0],
        "VISDOM_PORT": str(visdom_server[1]),
    }
    # patch.dict restores os.environ on exit, so the overrides do not leak
    # into tests that run afterwards.
    with patch.dict(os.environ, env_overrides):
        # No explicit server/port here: passing them would bypass the
        # environment variables this test is meant to exercise.
        vd_logger = VisdomLogger(num_workers=0)
        try:
            assert "main" in vd_logger.vis.get_env_list()
        finally:
            # Close the logger even when the assertion fails.
            vd_logger.close()
示例#3
0
def test_integration_with_executor_as_context_manager(visdom_server,
                                                      visdom_server_stop):
    """Log per-iteration losses through a context-managed logger and verify them.

    An ``Engine`` whose update function replays pre-generated random losses is
    run for ``n_epochs`` over a 50-item dataset while an ``OutputHandler`` is
    attached on ``Events.ITERATION_COMPLETED``. The data stored in the
    resulting "training/loss" window is then compared with the expected
    iteration numbers and the generated losses.
    """
    n_epochs = 5
    data = list(range(50))
    # Computed once, up front, so later assertions cannot pick up a different
    # length by accident.
    n_iterations = n_epochs * len(data)

    losses = torch.rand(n_iterations)
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with VisdomLogger(server=visdom_server[0],
                      port=visdom_server[1],
                      num_workers=1) as vd_logger:

        # close all windows in 'main' environment
        vd_logger.vis.close()

        trainer = Engine(update_fn)
        output_handler = OutputHandler(tag="training",
                                       output_transform=lambda x: {"loss": x})
        vd_logger.attach(trainer,
                         log_handler=output_handler,
                         event_name=Events.ITERATION_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)

        assert len(output_handler.windows) == 1
        assert "training/loss" in output_handler.windows
        win_name = output_handler.windows["training/loss"]["win"]

        # Use dedicated names for the window payload instead of rebinding
        # `data`: the dataset list is what defines the expected x range.
        window_data = vd_logger.vis.get_window_data(win=win_name)
        window_data = _parse_content(window_data)
        assert "content" in window_data and "data" in window_data["content"]
        trace = window_data["content"]["data"][0]
        assert "x" in trace and "y" in trace
        x_vals, y_vals = trace["x"], trace["y"]

        expected_x = range(1, n_iterations + 1)
        assert all(int(x) == x_true for x, x_true in zip(x_vals, expected_x))
        assert all(y == y_true for y, y_true in zip(y_vals, losses))
示例#4
0
def test_integration_no_server():
    """Expect a ConnectionError when the logger is built with default arguments."""
    expected_message = "Error connecting to Visdom server"
    with pytest.raises(ConnectionError, match=expected_message):
        VisdomLogger()
示例#5
0
def test_no_concurrent():
    """Expect a RuntimeError when ``concurrent.futures`` is masked in sys.modules."""
    expected_message = r"This contrib module requires concurrent.futures"
    masked_modules = {"concurrent.futures": None}

    # Same nesting as before: pytest.raises is the outer context and the
    # sys.modules patch the inner one.
    with pytest.raises(RuntimeError, match=expected_message), \
            patch.dict("sys.modules", masked_modules):
        VisdomLogger(num_workers=1)
示例#6
0
def test_no_visdom(no_site_packages):
    """Expect a RuntimeError about the visdom package when building the logger.

    The ``no_site_packages`` fixture is requested but not referenced in the
    body.
    """
    expected_message = r"This contrib module requires visdom package"
    with pytest.raises(RuntimeError, match=expected_message):
        VisdomLogger()
示例#7
0
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    """Train ``Net`` with SGD and log the run to the "mnist_training" env.

    Creates a supervised trainer and two evaluators (one run on the training
    loader, one on the validation loader after every epoch), attaches Visdom
    handlers for the batch loss, evaluator metrics, optimizer parameters,
    weights and gradients, and checkpoints the model by validation accuracy.

    Args:
        train_batch_size: forwarded to ``get_data_loaders`` as first argument.
        val_batch_size: forwarded to ``get_data_loaders`` as second argument.
        epochs: passed to ``trainer.run`` as ``max_epochs``.
        lr: learning rate given to ``SGD``.
        momentum: momentum given to ``SGD``.
        log_dir: first argument of ``ModelCheckpoint``.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        # Evaluate on both loaders at the end of every training epoch.
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    vd_logger = VisdomLogger(env="mnist_training")

    # From here on the logger must be closed whatever happens, so handler
    # setup and the training run are guarded by try/finally.
    try:
        vd_logger.attach_output_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
        )

        for tag, evaluator in [("training", train_evaluator),
                               ("validation", validation_evaluator)]:
            vd_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                # Take the step from the trainer rather than the evaluator.
                global_step_transform=global_step_from_engine(trainer),
            )

        vd_logger.attach_opt_params_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            optimizer=optimizer)

        vd_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED(every=100))

        vd_logger.attach(trainer,
                         log_handler=GradsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED(every=100))

        def score_function(engine):
            return engine.state.metrics["accuracy"]

        model_checkpoint = ModelCheckpoint(
            log_dir,
            n_saved=2,
            filename_prefix="best",
            score_function=score_function,
            score_name="validation_accuracy",
            global_step_transform=global_step_from_engine(trainer),
        )
        validation_evaluator.add_event_handler(Events.COMPLETED,
                                               model_checkpoint,
                                               {"model": model})

        # kick everything off
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        vd_logger.close()