Beispiel #1
0
def test_weights_scalar_handler_wrong_setup():
    """Check that ``WeightsScalarHandler`` rejects invalid setups.

    Covered in order:

    * ``model`` that is not a ``torch.nn.Module`` (``None``),
    * ``reduction`` that is not callable (``123``),
    * ``reduction`` returning its input unchanged (rejected as non-scalar),
    * calling the handler with a logger that is a plain ``MagicMock``
      rather than a ``VisdomLogger``.
    """
    bad_model_msg = "Argument model should be of type torch.nn.Module"
    bad_reduction_msg = "Argument reduction should be callable"
    non_scalar_msg = "Output of the reduction function should be a scalar"
    wrong_logger_msg = (
        "Handler 'WeightsScalarHandler' works only with VisdomLogger")

    # Constructor-time validation of ``model``.
    with pytest.raises(TypeError, match=bad_model_msg):
        WeightsScalarHandler(None)

    fake_model = MagicMock(spec=torch.nn.Module)

    # Constructor-time validation of ``reduction``.
    with pytest.raises(TypeError, match=bad_reduction_msg):
        WeightsScalarHandler(fake_model, reduction=123)

    with pytest.raises(TypeError, match=non_scalar_msg):
        WeightsScalarHandler(fake_model, reduction=lambda x: x)

    # Call-time validation of the logger type.
    handler = WeightsScalarHandler(fake_model)
    plain_logger = MagicMock()
    engine = MagicMock()
    with pytest.raises(RuntimeError, match=wrong_logger_msg):
        handler(engine, plain_logger, Events.ITERATION_STARTED)
Beispiel #2
0
    def _test(tag=None):
        """Run the handler once and verify the expected ``vis.line`` calls.

        ``model`` is not defined here; it is read from the enclosing scope,
        which is not visible in this snippet.

        Args:
            tag: optional tag forwarded to ``WeightsScalarHandler``. When it
                is truthy, every expected window/trace name is prefixed with
                ``"<tag>/"``.
        """
        handler = WeightsScalarHandler(model, tag=tag)

        visdom_logger = MagicMock(spec=VisdomLogger)
        visdom_logger.vis = MagicMock()
        visdom_logger.executor = _DummyExecutor()

        engine = MagicMock()
        engine.state = State()
        engine.state.epoch = 5

        handler(engine, visdom_logger, Events.EPOCH_STARTED)

        if tag:
            name_prefix = f"{tag}/weights_norm/"
        else:
            name_prefix = "weights_norm/"

        # (parameter path, expected Y); the fc2 bias value is not pinned and
        # is matched with ANY.
        expectations = (
            ("fc1/weight", [0.0]),
            ("fc1/bias", [0.0]),
            ("fc2/weight", [12.0]),
            ("fc2/bias", ANY),
        )

        line_mock = visdom_logger.vis.line
        assert line_mock.call_count == 4

        # Built only after the count check so that a wrong call count is
        # reported before any lookup into ``handler.windows``.
        expected_calls = [
            call(
                X=[5],
                Y=expected_y,
                env=visdom_logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[name_prefix + param]["opts"],
                name=name_prefix + param,
            )
            for param, expected_y in expectations
        ]
        line_mock.assert_has_calls(expected_calls, any_order=True)
Beispiel #3
0
def test_weights_scalar_handler_custom_reduction():
    """Verify logging when a custom ``reduction`` callable is supplied.

    A two-layer dummy model is built, the handler is created with a reduction
    that always returns ``12.34`` and with ``show_legend=True``, and the test
    asserts that ``vis.line`` is called once per parameter (four times) with
    ``X=[5]`` (the epoch set on the engine state) and ``Y=[12.34]``.
    """

    class DummyModel(torch.nn.Module):
        """Two linear layers: fc1 zero-filled, fc2 filled with ones."""

        def __init__(self):
            super().__init__()
            self.fc1 = torch.nn.Linear(10, 10)
            self.fc2 = torch.nn.Linear(12, 12)
            for zeroed in (self.fc1.weight, self.fc1.bias):
                zeroed.data.zero_()
            for filled in (self.fc2.weight, self.fc2.bias):
                filled.data.fill_(1.0)

    model = DummyModel()

    # NOTE(review): the expected window names below contain "weights_norm",
    # so this function must keep the name ``norm`` -- presumably the handler
    # derives the tag from the reduction's ``__name__``; confirm.
    def norm(x):
        return 12.34

    handler = WeightsScalarHandler(model, reduction=norm, show_legend=True)

    visdom_logger = MagicMock(spec=VisdomLogger)
    visdom_logger.vis = MagicMock()
    visdom_logger.executor = _DummyExecutor()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    handler(engine, visdom_logger, Events.EPOCH_STARTED)

    line_mock = visdom_logger.vis.line
    assert line_mock.call_count == 4

    # One expected call per model parameter; every value is the constant
    # returned by ``norm``.
    expected_calls = []
    for param in ("fc1/weight", "fc1/bias", "fc2/weight", "fc2/bias"):
        window_name = "weights_norm/" + param
        expected_calls.append(
            call(
                X=[5],
                Y=[12.34],
                env=visdom_logger.vis.env,
                win=None,
                update=None,
                opts=handler.windows[window_name]["opts"],
                name=window_name,
            ))
    line_mock.assert_has_calls(expected_calls, any_order=True)
Beispiel #4
0
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_dir):
    """Train ``Net`` with SGD and log the run to a Visdom environment.

    Builds the data loaders, model, optimizer, trainer and two evaluators,
    attaches Visdom handlers for batch loss, evaluation metrics, optimizer
    parameters, weight norms and gradient norms, keeps the two best
    checkpoints ranked by validation accuracy, and runs the trainer.

    Args:
        train_batch_size: batch size passed to ``get_data_loaders`` for the
            training loader.
        val_batch_size: batch size passed to ``get_data_loaders`` for the
            validation loader.
        epochs: ``max_epochs`` given to ``trainer.run``.
        lr: learning rate given to ``SGD``.
        momentum: momentum given to ``SGD``.
        log_dir: first argument of ``ModelCheckpoint`` (where checkpoints
            are written).
    """
    train_loader, val_loader = get_data_loaders(train_batch_size,
                                                val_batch_size)
    model = Net()
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(model,
                                        optimizer,
                                        criterion,
                                        device=device)
    trainer.logger = setup_logger("Trainer")

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device)
    train_evaluator.logger = setup_logger("Train Evaluator")
    validation_evaluator = create_supervised_evaluator(model,
                                                       metrics=metrics,
                                                       device=device)
    validation_evaluator.logger = setup_logger("Val Evaluator")

    @trainer.on(Events.EPOCH_COMPLETED)
    def compute_metrics(engine):
        # Evaluate on both splits at the end of every training epoch.
        train_evaluator.run(train_loader)
        validation_evaluator.run(val_loader)

    vd_logger = VisdomLogger(env="mnist_training")

    # Everything after the logger is created runs under try/finally so the
    # logger is closed even if attaching handlers, building the checkpoint
    # handler, or training itself raises.
    try:
        every_100_iters = Events.ITERATION_COMPLETED(every=100)

        vd_logger.attach_output_handler(
            trainer,
            event_name=every_100_iters,
            tag="training",
            output_transform=lambda loss: {"batchloss": loss},
        )

        for tag, evaluator in [("training", train_evaluator),
                               ("validation", validation_evaluator)]:
            vd_logger.attach_output_handler(
                evaluator,
                event_name=Events.EPOCH_COMPLETED,
                tag=tag,
                metric_names=["loss", "accuracy"],
                # Use the trainer's step as the x-axis for evaluator metrics.
                global_step_transform=global_step_from_engine(trainer),
            )

        vd_logger.attach_opt_params_handler(trainer,
                                            event_name=every_100_iters,
                                            optimizer=optimizer)

        vd_logger.attach(trainer,
                         log_handler=WeightsScalarHandler(model),
                         event_name=every_100_iters)

        vd_logger.attach(trainer,
                         log_handler=GradsScalarHandler(model),
                         event_name=every_100_iters)

        def score_function(engine):
            return engine.state.metrics["accuracy"]

        model_checkpoint = ModelCheckpoint(
            log_dir,
            n_saved=2,
            filename_prefix="best",
            score_function=score_function,
            score_name="validation_accuracy",
            global_step_transform=global_step_from_engine(trainer),
        )
        validation_evaluator.add_event_handler(Events.COMPLETED,
                                               model_checkpoint,
                                               {"model": model})

        # kick everything off
        trainer.run(train_loader, max_epochs=epochs)
    finally:
        vd_logger.close()