def test_logger_init_hostname_port(visdom_server):
    """Check that VisdomLogger connects when hostname and port are passed explicitly.

    ``visdom_server`` is a fixture whose first item is used as the server URL
    and whose second item is used as the port.
    """
    vd_logger = VisdomLogger(server=visdom_server[0], port=visdom_server[1], num_workers=0)
    try:
        assert "main" in vd_logger.vis.get_env_list()
    finally:
        # Close even when the assertion fails so the logger is not leaked.
        vd_logger.close()
def test_logger_init_env_vars(visdom_server):
    """Check that VisdomLogger takes its server URL and port from environment variables.

    No ``server``/``port`` arguments are passed, so the connection can only
    succeed if ``VISDOM_SERVER_URL`` and ``VISDOM_PORT`` are honoured.
    ``patch.dict`` restores ``os.environ`` on exit. This keeps the variables
    from leaking into other tests, e.g. ``test_integration_no_server``, which
    expects ``VisdomLogger()`` to fail to connect.
    """
    env = {
        "VISDOM_SERVER_URL": visdom_server[0],
        "VISDOM_PORT": str(visdom_server[1]),
    }
    with patch.dict("os.environ", env):
        vd_logger = VisdomLogger(num_workers=0)
        try:
            assert "main" in vd_logger.vis.get_env_list()
        finally:
            vd_logger.close()
def test_integration_with_executor_as_context_manager(visdom_server, visdom_server_stop):
    """Log a loss per iteration through a threaded VisdomLogger and verify the plotted series.

    The logger is used as a context manager with ``num_workers=1``. After
    training, the single ``training/loss`` window is read back from Visdom.
    Its x values must be the 1-based iteration numbers and its y values the
    losses produced by ``update_fn``.
    """
    n_epochs = 5
    data = list(range(50))
    # Compute up front: ``data`` must stay the batch list for the whole test.
    n_iterations = n_epochs * len(data)

    losses = torch.rand(n_iterations)
    losses_iter = iter(losses)

    def update_fn(engine, batch):
        return next(losses_iter)

    with VisdomLogger(server=visdom_server[0], port=visdom_server[1], num_workers=1) as vd_logger:
        # close all windows in 'main' environment
        vd_logger.vis.close()

        trainer = Engine(update_fn)
        output_handler = OutputHandler(tag="training", output_transform=lambda x: {"loss": x})
        vd_logger.attach(trainer, log_handler=output_handler, event_name=Events.ITERATION_COMPLETED)

        trainer.run(data, max_epochs=n_epochs)

        assert len(output_handler.windows) == 1
        assert "training/loss" in output_handler.windows
        win_name = output_handler.windows["training/loss"]["win"]

        window_data = _parse_content(vd_logger.vis.get_window_data(win=win_name))
        assert "content" in window_data and "data" in window_data["content"]
        series = window_data["content"]["data"][0]
        assert "x" in series and "y" in series

        x_vals, y_vals = series["x"], series["y"]
        expected_x = range(1, n_iterations + 1)
        assert all(int(x) == x_true for x, x_true in zip(x_vals, expected_x))
        assert all(y == y_true for y, y_true in zip(y_vals, losses))
def test_integration_no_server():
    """Constructing a default VisdomLogger must raise ConnectionError when no server is reachable.

    This assumes nothing is listening at the default / environment-configured
    Visdom location while this test runs.
    """
    expected_message = "Error connecting to Visdom server"
    with pytest.raises(ConnectionError, match=expected_message):
        VisdomLogger()
def test_no_concurrent():
    """A worker-backed VisdomLogger must raise RuntimeError if concurrent.futures is unavailable."""
    # A ``None`` entry in sys.modules makes any import of that module fail.
    hidden_modules = {"concurrent.futures": None}
    expected = r"This contrib module requires concurrent.futures"
    with pytest.raises(RuntimeError, match=expected), patch.dict("sys.modules", hidden_modules):
        VisdomLogger(num_workers=1)
def test_no_visdom(no_site_packages):
    """VisdomLogger must raise RuntimeError when the visdom package cannot be imported."""
    # NOTE(review): ``no_site_packages`` is a fixture defined elsewhere; presumably
    # it hides installed packages (including visdom) for the duration of the test.
    expected = r"This contrib module requires visdom package"
    with pytest.raises(RuntimeError, match=expected):
        VisdomLogger()
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    """Train ``Net`` with SGD and log progress to the Visdom env ``mnist_training``.

    At the end of every trainer epoch, both evaluators run on their loaders.
    Batch loss, optimizer params, weight and gradient scalars are logged every
    100 iterations. The two best models by validation accuracy are saved under
    ``log_dir``.

    Args:
        train_batch_size: batch size passed to ``get_data_loaders`` for training.
        val_batch_size: batch size passed to ``get_data_loaders`` for validation.
        epochs: number of training epochs (``max_epochs`` of the trainer run).
        lr: SGD learning rate.
        momentum: SGD momentum.
        log_dir: directory used by ``ModelCheckpoint``.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    # Context manager guarantees vd_logger.close() even if setup or training raises.
    with VisdomLogger(env="mnist_training") as vd_logger:
        vd_logger.attach_output_handler(
            trainer,
            event_name=Events.ITERATION_COMPLETED(every=100),
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
        )

        for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
            vd_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                global_step_transform=global_step_from_engine(trainer),
            )

        vd_logger.attach_opt_params_handler(
            trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer
        )

        vd_logger.attach(
            trainer, log_handler=WeightsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
        )

        vd_logger.attach(
            trainer, log_handler=GradsScalarHandler(model), event_name=Events.ITERATION_COMPLETED(every=100)
        )

        def score_function(engine):
            return engine.state.metrics["accuracy"]

        model_checkpoint = ModelCheckpoint(
            log_dir,
            n_saved=2,
            filename_prefix="best",
            score_function=score_function,
            score_name="validation_accuracy",
            global_step_transform=global_step_from_engine(trainer),
        )
        validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

        # kick everything off
        trainer.run(train_loader, max_epochs=epochs)