def test_weights_scalar_handler_wrong_setup():
    """WeightsScalarHandler must validate its arguments and the logger type."""
    # The model argument must be a torch.nn.Module.
    with pytest.raises(TypeError, match="Argument model should be of type torch.nn.Module"):
        WeightsScalarHandler(None)

    model = MagicMock(spec=torch.nn.Module)

    # The reduction argument must be callable.
    with pytest.raises(TypeError, match="Argument reduction should be callable"):
        WeightsScalarHandler(model, reduction=123)

    # The reduction must map a tensor to a scalar.
    with pytest.raises(TypeError, match="Output of the reduction function should be a scalar"):
        WeightsScalarHandler(model, reduction=lambda x: x)

    handler = WeightsScalarHandler(model)
    mock_logger = MagicMock()
    mock_engine = MagicMock()

    # Invoking the handler with anything but a VisdomLogger is rejected at call time.
    with pytest.raises(RuntimeError, match="Handler 'WeightsScalarHandler' works only with VisdomLogger"):
        handler(mock_engine, mock_logger, Events.ITERATION_STARTED)
def _test(tag=None):
    # NOTE(review): relies on `model` from the enclosing scope — presumably a
    # two-layer net whose fc1 parameters are zeroed and fc2 weights filled;
    # confirm against the enclosing test.
    handler = WeightsScalarHandler(model, tag=tag)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    prefix = f"{tag}/" if tag else ""
    # One line per parameter tensor: fc1.weight, fc1.bias, fc2.weight, fc2.bias.
    assert logger.vis.line.call_count == 4

    expected = [
        (prefix + "weights_norm/fc1/weight", [0.0]),
        (prefix + "weights_norm/fc1/bias", [0.0]),
        (prefix + "weights_norm/fc2/weight", [12.0]),
        (prefix + "weights_norm/fc2/bias", ANY),
    ]
    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=y_values,
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[window_name]["opts"],
                name=window_name,
            )
            for window_name, y_values in expected
        ],
        any_order=True,
    )
def test_weights_scalar_handler_custom_reduction():
    """A user-supplied reduction is applied to every parameter tensor."""

    class DummyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10)
            self.fc2 = torch.nn.Linear(12, 12)
            self.fc1.weight.data.zero_()
            self.fc1.bias.data.zero_()
            self.fc2.weight.data.fill_(1.0)
            self.fc2.bias.data.fill_(1.0)

    model = DummyModel()

    def norm(x):
        # Constant reduction: every parameter should be reported as 12.34.
        return 12.34

    handler = WeightsScalarHandler(model, reduction=norm, show_legend=True)
    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, logger, Events.EPOCH_STARTED)

    # One visdom line per parameter tensor of DummyModel.
    assert logger.vis.line.call_count == 4

    window_names = [
        "weights_norm/fc1/weight",
        "weights_norm/fc1/bias",
        "weights_norm/fc2/weight",
        "weights_norm/fc2/bias",
    ]
    logger.vis.line.assert_has_calls(
        [
            call(
                X=[5],
                Y=[12.34],
                env=logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[window_name]["opts"],
                name=window_name,
            )
            for window_name in window_names
        ],
        any_order=True,
    )
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    """Train an MNIST model while streaming metrics, weights and grads to Visdom.

    Sets up a trainer plus train/validation evaluators, attaches Visdom logging
    handlers, checkpoints the two best models by validation accuracy into
    ``log_dir``, and runs for ``epochs`` epochs.
    """
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()

    trainer = create_supervised_trainer(model, optimizer, criterion, device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")

    validation_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        # Evaluate on both splits at the end of every epoch.
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    vd_logger = VisdomLogger(env="mnist_training")

    # Training batch loss, sampled every 100 iterations.
    vd_logger.attach_output_handler(
        trainer,
        event_name=Events.ITERATION_COMPLETED(every=100),
        tag="training",
        output_transform=lambda loss: {"batchloss": loss},
    )

    # Epoch-level metrics for both evaluators, indexed by the trainer's global step.
    for tag, evaluator in [("training", train_evaluator), ("validation", validation_evaluator)]:
        vd_logger.attach_output_handler(
            evaluator,
            event_name=Events.EPOCH_COMPLETED,
            tag=tag,
            metric_names=["loss", "accuracy"],
            global_step_transform=global_step_from_engine(trainer),
        )

    # Optimizer parameters (e.g. learning rate), every 100 iterations.
    vd_logger.attach_opt_params_handler(
        trainer, event_name=Events.ITERATION_COMPLETED(every=100), optimizer=optimizer
    )

    # Weight and gradient norms, every 100 iterations.
    vd_logger.attach(
        trainer,
        log_handler=WeightsScalarHandler(model),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )
    vd_logger.attach(
        trainer,
        log_handler=GradsScalarHandler(model),
        event_name=Events.ITERATION_COMPLETED(every=100),
    )

    def score_function(engine):
        # Checkpoint selection criterion: validation accuracy.
        return engine.state.metrics["accuracy"]

    model_checkpoint = ModelCheckpoint(
        log_dir,
        n_saved=2,
        filename_prefix="best",
        score_function=score_function,
        score_name="validation_accuracy",
        global_step_transform=global_step_from_engine(trainer),
    )
    validation_evaluator.add_event_handler(Events.COMPLETED, model_checkpoint, {"model": model})

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)

    vd_logger.close()