示例#1
0
def test_output_handler_output_transform(dirname):
    """Check that OutputHandler writes the transformed engine output to TensorBoard.

    A bare scalar output is expected under ``<tag>/output``; a dict output is
    expected under ``<tag>/<key>``.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    # Identity transform: the raw scalar output is logged.
    handler = OutputHandler("tag", output_transform=lambda x: x)
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()
    handler(engine, logger, Events.ITERATION_STARTED)
    logger.writer.add_scalar.assert_called_once_with("tag/output", 12345, 123)

    # Dict transform: each key becomes its own scalar.
    handler = OutputHandler("another_tag", output_transform=lambda x: {"loss": x})
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()
    handler(engine, logger, Events.ITERATION_STARTED)
    logger.writer.add_scalar.assert_called_once_with("another_tag/loss", 12345, 123)
def test_grads_hist_handler(dummy_model_factory):
    """GradsHistHandler should emit one gradient histogram per dummy-model parameter (fc1/fc2 weight and bias)."""
    net = dummy_model_factory(with_grads=True, with_frozen_layer=False)

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    handler = GradsHistHandler(net)
    handler(engine, logger, Events.EPOCH_STARTED)

    expected_calls = [
        call(tag="grads/fc1/weight", values=ANY, global_step=5),
        call(tag="grads/fc1/bias", values=ANY, global_step=5),
        call(tag="grads/fc2/weight", values=ANY, global_step=5),
        call(tag="grads/fc2/bias", values=ANY, global_step=5),
    ]
    add_histogram = logger.writer.add_histogram
    assert add_histogram.call_count == len(expected_calls)
    add_histogram.assert_has_calls(expected_calls, any_order=True)
示例#3
0
    def _test(tag=None):
        """Check WeightsScalarHandler logs weight norms to Neptune, prefixed by ``tag`` when given."""
        engine = MagicMock()
        engine.state = State()
        engine.state.epoch = 5

        logger = MagicMock(spec=NeptuneLogger)
        logger.experiment = MagicMock()

        # ``model`` comes from the enclosing test's scope.
        handler = WeightsScalarHandler(model, tag=tag)
        handler(engine, logger, Events.EPOCH_STARTED)

        prefix = "" if not tag else "{}/".format(tag)
        log_metric = logger.experiment.log_metric
        assert log_metric.call_count == 4
        log_metric.assert_has_calls(
            [
                call(prefix + "weights_norm/fc1/weight", y=0.0, x=5),
                call(prefix + "weights_norm/fc1/bias", y=0.0, x=5),
                call(prefix + "weights_norm/fc2/weight", y=12.0, x=5),
                call(prefix + "weights_norm/fc2/bias", y=math.sqrt(12.0), x=5),
            ],
            any_order=True,
        )
示例#4
0
def test_no_grad():
    """Metric.iteration_completed should run ``update`` without tracking gradients.

    The prediction tensor still requires grad, yet a tensor computed from it
    inside ``update`` must not.
    """
    prediction = torch.zeros(4, requires_grad=True)
    target = torch.zeros(4, requires_grad=False)

    class DummyMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            pass

        def update(self, output):
            y_pred, y = output
            squared_error = (y_pred - y.view_as(y_pred)) ** 2
            assert y_pred.requires_grad
            assert not squared_error.requires_grad

    engine = MagicMock(state=State(output=(prediction, target)))
    DummyMetric().iteration_completed(engine)
示例#5
0
def test_neptune_saver_integration():
    """With n_saved=1, a second checkpoint should upload a new artifact and delete the previous one."""
    logger = MagicMock(spec=NeptuneLogger)
    logger.log_artifact = MagicMock()
    logger.delete_artifacts = MagicMock()

    checkpoint = Checkpoint(
        to_save={"model": torch.nn.Module()},
        save_handler=NeptuneSaver(logger),
        n_saved=1,
    )

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    for iteration in (0, 1):
        trainer.state.iteration = iteration
        checkpoint(trainer)

    assert logger.log_artifact.call_count == 2
    assert logger.delete_artifacts.call_count == 1
示例#6
0
def test_output_handler_output_transform():
    """OutputHandler should pass the transformed output to PolyaxonLogger.log_metrics.

    Each case is (tag, output_transform, expected metric key).
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    cases = (
        ("tag", lambda x: x, "tag/output"),
        ("another_tag", lambda x: {"loss": x}, "another_tag/loss"),
    )
    for tag, transform, key in cases:
        handler = OutputHandler(tag, output_transform=transform)
        logger = MagicMock(spec=PolyaxonLogger)
        logger.log_metrics = MagicMock()

        handler(engine, logger, Events.ITERATION_STARTED)

        logger.log_metrics.assert_called_once_with(step=123, **{key: 12345})
示例#7
0
def _test_checkpoint_with_ddp(device):
    """Checkpointing a DDP-wrapped model should save the wrapped module's own state_dict."""
    torch.manual_seed(0)

    model = DummyModel().to(device)
    # Per the PyTorch DDP docs, CPU modules must be wrapped with device_ids=None.
    if "cpu" in device.type:
        device_ids = None
    else:
        device_ids = [device]
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)

    save_handler = MagicMock(spec=BaseSaveHandler)
    checkpointer = Checkpoint({"model": ddp_model}, save_handler=save_handler)

    trainer = Engine(lambda e, b: None)
    trainer.state = State(epoch=0, iteration=0)
    checkpointer(trainer)

    assert save_handler.call_count == 1
    expected_metadata = {"basename": "model", "score_name": None, "priority": 0}
    # Expected payload is the unwrapped model's state_dict, not ddp_model's.
    save_handler.assert_called_with(model.state_dict(), "model_0.pt", expected_metadata)
示例#8
0
def test_output_handler_both():
    """OutputHandler should log both the selected metrics and the transformed output to Neptune."""
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=NeptuneLogger)
    logger.experiment = MagicMock()

    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    handler(engine, logger, Events.EPOCH_STARTED)

    log_metric = logger.experiment.log_metric
    assert log_metric.call_count == 3
    log_metric.assert_has_calls(
        [
            call("tag/a", y=12.23, x=5),
            call("tag/b", y=23.45, x=5),
            call("tag/loss", y=12345, x=5),
        ],
        any_order=True,
    )
示例#9
0
def test_weights_scalar_handler(dummy_model_factory):
    """WeightsScalarHandler should log one weight-norm scalar per dummy-model parameter to TensorBoard."""
    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5

    net = dummy_model_factory(with_grads=True, with_frozen_layer=False)
    handler = WeightsScalarHandler(net)

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    handler(engine, logger, Events.EPOCH_STARTED)

    expected = [
        call("weights_norm/fc1/weight", 0.0, 5),
        call("weights_norm/fc1/bias", 0.0, 5),
        call("weights_norm/fc2/weight", 12.0, 5),
        call("weights_norm/fc2/bias", math.sqrt(12.0), 5),
    ]
    add_scalar = logger.writer.add_scalar
    assert add_scalar.call_count == 4
    add_scalar.assert_has_calls(expected, any_order=True)
示例#10
0
    def _test(tag=None):
        """Check WeightsScalarHandler writes weight norms to TensorBoard, prefixed by ``tag`` when given."""
        engine = MagicMock()
        engine.state = State()
        engine.state.epoch = 5

        logger = MagicMock(spec=TensorboardLogger)
        logger.writer = MagicMock()

        # ``model`` comes from the enclosing test's scope.
        handler = WeightsScalarHandler(model, tag=tag)
        handler(engine, logger, Events.EPOCH_STARTED)

        prefix = "" if not tag else "{}/".format(tag)
        add_scalar = logger.writer.add_scalar
        assert add_scalar.call_count == 4
        add_scalar.assert_has_calls(
            [
                call(prefix + "weights_norm/fc1/weight", 0.0, 5),
                call(prefix + "weights_norm/fc1/bias", 0.0, 5),
                call(prefix + "weights_norm/fc2/weight", 12.0, 5),
                call(prefix + "weights_norm/fc2/bias", math.sqrt(12.0), 5),
            ],
            any_order=True,
        )
示例#11
0
def test_output_handler_output_transform(dirname):
    """OutputHandler should open one Visdom window per transformed output and plot into it."""
    engine = MagicMock()
    engine.state = State()
    engine.state.output = 12345
    engine.state.iteration = 123

    def _run_and_check(handler, name):
        # Fresh logger per case so call assertions only see this handler's calls.
        logger = MagicMock(spec=VisdomLogger)
        logger.vis = MagicMock()
        logger.executor = _DummyExecutor()

        handler(engine, logger, Events.ITERATION_STARTED)

        assert len(handler.windows) == 1 and name in handler.windows
        assert handler.windows[name]['win'] is not None

        logger.vis.line.assert_called_once_with(
            X=[123, ], Y=[12345, ], env=logger.vis.env,
            win=None, update=None,
            opts=handler.windows[name]['opts'],
            name=name,
        )

    _run_and_check(OutputHandler("tag", output_transform=lambda x: x), "tag/output")
    _run_and_check(
        OutputHandler("another_tag", output_transform=lambda x: {"loss": x}),
        "another_tag/loss",
    )
示例#12
0
def _test_gpu_info(device="cpu"):
    """Exercise GpuInfo directly via ``compute`` and through ``completed`` on an Engine."""
    gpu_info = GpuInfo()

    # increase code cov
    gpu_info.reset()
    gpu_info.update(None)

    t = torch.rand(4, 10, 100, 100).to(device)
    data = gpu_info.compute()
    assert len(data) > 0
    first_gpu = data[0]

    assert "fb_memory_usage" in first_gpu
    mem_report = first_gpu["fb_memory_usage"]
    assert "used" in mem_report and "total" in mem_report
    assert mem_report["total"] > 0.0
    # Loose lower bound: element count scaled to MiB (element byte size is not applied).
    assert mem_report["used"] > t.numel() / 1024.0 / 1024.0

    assert "utilization" in first_gpu
    util_report = first_gpu["utilization"]
    assert "gpu_util" in util_report

    # with Engine
    engine = Engine(lambda engine, batch: 0.0)
    engine.state = State(metrics={})
    gpu_info.completed(engine, name="gpu")
    metrics = engine.state.metrics

    mem_key = "gpu:0 mem(%)"
    assert mem_key in metrics
    assert isinstance(metrics[mem_key], int)
    assert int(mem_report["used"] * 100.0 / mem_report["total"]) == metrics[mem_key]

    util_key = "gpu:0 util(%)"
    if util_report["gpu_util"] == "N/A":
        # Utilization reported as "N/A": no utilization metric should be published.
        assert util_key not in metrics
    else:
        assert util_key in metrics
        assert isinstance(metrics[util_key], int)
        assert int(util_report["gpu_util"]) == metrics[util_key]
示例#13
0
def test_optimizer_params():
    """OptimizerParamsHandler should plot each param group's value into its own Visdom window.

    Without a tag the window is named ``lr/group_0``. With ``tag="generator"``
    it is named ``generator/lr/group_0``.
    """
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)
    mock_engine = MagicMock()
    mock_engine.state = State()
    mock_engine.state.iteration = 123

    def _check(wrapper, name):
        # Fresh logger per case so call assertions only see this handler's calls.
        mock_logger = MagicMock(spec=VisdomLogger)
        mock_logger.vis = MagicMock()
        mock_logger.executor = _DummyExecutor()

        wrapper(mock_engine, mock_logger, Events.ITERATION_STARTED)

        assert len(wrapper.windows) == 1 and name in wrapper.windows
        assert wrapper.windows[name]['win'] is not None

        mock_logger.vis.line.assert_called_once_with(
            X=[123, ], Y=[0.01, ], env=mock_logger.vis.env,
            win=None, update=None,
            opts=wrapper.windows[name]['opts'],
            name=name
        )

    _check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    _check(
        OptimizerParamsHandler(optimizer=optimizer, param_name="lr", tag="generator"),
        "generator/lr/group_0",
    )
示例#14
0
    def _test(to_save, obj, name):
        """Check Checkpoint file naming, removal of the older file, and ``last_checkpoint``."""
        save_handler = MagicMock(spec=BaseSaveHandler)
        checkpointer = Checkpoint(to_save, save_handler=save_handler)
        assert checkpointer.last_checkpoint is None

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=0, iteration=0)

        first_name = "{}_0.pt".format(name)
        second_name = "{}_1234.pt".format(name)

        checkpointer(trainer)
        assert save_handler.call_count == 1
        save_handler.assert_called_with(obj, first_name)

        trainer.state.epoch, trainer.state.iteration = 12, 1234
        checkpointer(trainer)
        assert save_handler.call_count == 2
        save_handler.assert_called_with(obj, second_name)

        # With the default n_saved, the second save removes the first file.
        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with(first_name)
        assert checkpointer.last_checkpoint == second_name
示例#15
0
def test_completed_on_cuda():
    """Metric.completed should store a CUDA tensor result as a CPU tensor.

    Checks https://github.com/pytorch/ignite/issues/1635#issuecomment-863026919
    """
    # compute() below allocates on CUDA, so this cannot run on CPU-only hosts.
    if not torch.cuda.is_available():
        pytest.skip("Skip if no GPU")

    class DummyMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            return torch.tensor([1.0, 2.0, 3.0], device="cuda")

        def update(self, output):
            pass

    m = DummyMetric()

    # tensor
    engine = MagicMock(state=State(metrics={}))
    m.completed(engine, "metric")
    assert "metric" in engine.state.metrics
    assert isinstance(engine.state.metrics["metric"], torch.Tensor)
    assert engine.state.metrics["metric"].device.type == "cpu"
def test_output_handler_both(dirname):
    """OutputHandler should write selected metrics and the transformed output as TensorBoard scalars."""
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    assert add_scalar.call_count == 3
    add_scalar.assert_has_calls(
        [
            call("tag/a", 12.23, 5),
            call("tag/b", 23.45, 5),
            call("tag/loss", 12345, 5),
        ],
        any_order=True,
    )
示例#17
0
    def _test(filename_prefix, to_save, obj, name):
        """Check file naming with ``filename_prefix`` and an epoch-based global step.

        The expected metadata priority follows the iteration (1, then 1234),
        while the filename suffix follows the epoch (1, then 12).
        """
        save_handler = MagicMock(spec=BaseSaveHandler)
        checkpointer = Checkpoint(
            to_save,
            save_handler=save_handler,
            filename_prefix=filename_prefix,
            global_step_transform=lambda e, _: e.state.epoch,
        )

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=1, iteration=1)
        checkpointer(trainer)
        assert save_handler.call_count == 1

        # A non-empty prefix is joined to the object name with an underscore.
        prefix = filename_prefix + "_" if filename_prefix else filename_prefix
        basename = "{}{}".format(prefix, name)
        first_file = "{}_1.pt".format(basename)
        second_file = "{}_12.pt".format(basename)

        metadata = {"basename": basename, "score_name": None, "priority": 1}
        save_handler.assert_called_with(obj, first_file, metadata)

        trainer.state.epoch = 12
        trainer.state.iteration = 1234
        checkpointer(trainer)
        assert save_handler.call_count == 2
        metadata["priority"] = 1234
        save_handler.assert_called_with(obj, second_file, metadata)

        assert save_handler.remove.call_count == 1
        save_handler.remove.assert_called_with(first_file)
        assert checkpointer.last_checkpoint == second_file
示例#18
0
def test_trains_disk_saver_integration_no_logger():
    """TrainsSaver without a logger should warn about a temp dir and keep only the newest checkpoint.

    The ``trains`` internals are patched only for the duration of this test.
    Assigning them directly on the module, as before, left the mocks in place
    for every test that ran afterwards.
    """
    from unittest.mock import patch

    model = torch.nn.Module()
    to_save_serializable = {"model": model}

    with patch.object(trains.Task, "current_task", Mock(return_value=object())), patch.object(
        trains.binding.frameworks.WeightsFileHandler, "create_output_model", MagicMock()
    ) as create_output_model:
        with pytest.warns(UserWarning, match="TrainsSaver created a temporary checkpoints directory"):
            trains_saver = TrainsSaver()
            checkpoint = Checkpoint(to_save=to_save_serializable, save_handler=trains_saver, n_saved=1)

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=0, iteration=0)
        checkpoint(trainer)
        trainer.state.iteration = 1
        checkpoint(trainer)

        if trains_saver._atomic:
            assert create_output_model.call_count == 2
        else:
            saved_files = list(os.listdir(trains_saver.dirname))
            assert len(saved_files) == 1
            assert saved_files[0] == "model_1.pt"
示例#19
0
def test_output_handler_both(dirname):
    """OutputHandler should report each metric and output key as a Trains scalar series titled ``tag``."""
    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})

    logger = MagicMock(spec=TrainsLogger)
    logger.trains_logger = MagicMock()

    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    handler(engine, logger, Events.EPOCH_STARTED)

    expected = [
        call(title="tag", series=series, iteration=5, value=value)
        for series, value in (("a", 12.23), ("b", 23.45), ("loss", 12345))
    ]
    report_scalar = logger.trains_logger.report_scalar
    assert report_scalar.call_count == 3
    report_scalar.assert_has_calls(expected, any_order=True)
示例#20
0
def test_mlflow_bad_metric_name_handling(dirname):
    """An invalid metric name should warn, while valid metrics are still logged each epoch.

    "metric:0 in %" is expected to trigger the warning. "metric 0" is expected
    to end up in the run history as "training metric 0", one value per handler
    call.
    """
    import mlflow
    from mlflow.tracking import MlflowClient

    true_values = [123.0, 23.4, 333.4]
    with MLflowLogger(os.path.join(dirname, "mlruns")) as mlflow_logger:

        active_run = mlflow.active_run()

        handler = OutputHandler(tag="training", metric_names="all")
        engine = Engine(lambda e, b: None)
        engine.state = State(metrics={
            "metric:0 in %": 123.0,
            "metric 0": 1000.0
        })

        with pytest.warns(
                UserWarning,
                match=
                r"MLflowLogger output_handler encountered an invalid metric name"
        ):

            engine.state.epoch = 1
            handler(engine, mlflow_logger, event_name=Events.EPOCH_COMPLETED)

            for v in true_values:
                engine.state.epoch += 1
                engine.state.metrics["metric 0"] = v
                handler(engine,
                        mlflow_logger,
                        event_name=Events.EPOCH_COMPLETED)

    client = MlflowClient(tracking_uri=os.path.join(dirname, "mlruns"))
    stored_values = client.get_metric_history(active_run.info.run_id,
                                              "training metric 0")

    expected_values = [1000.0] + true_values
    # zip() silently truncates: check lengths first so an empty or short
    # history cannot make the loop below pass vacuously.
    assert len(stored_values) == len(expected_values)
    for t, s in zip(expected_values, stored_values):
        assert t == s.value
示例#21
0
def test_output_handler_both():
    """OutputHandler should send metrics and the transformed output to Polyaxon in one log_metrics call."""
    engine = MagicMock()
    engine.state = State(metrics={"a": 12.23, "b": 23.45})
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=PolyaxonLogger)
    logger.log_metrics = MagicMock()

    handler = OutputHandler("tag", metric_names=["a", "b"], output_transform=lambda x: {"loss": x})
    handler(engine, logger, Events.EPOCH_STARTED)

    expected = {"tag/a": 12.23, "tag/b": 23.45, "tag/loss": 12345}
    assert logger.log_metrics.call_count == 1
    logger.log_metrics.assert_called_once_with(step=5, **expected)
示例#22
0
def test_transform():
    """``output_transform`` must be applied to the engine output before it reaches ``update``."""
    y_pred = torch.Tensor([[2.0], [-2.0]])
    y = torch.zeros(2)

    def transform(output):
        # Unpack a (prediction dict, target dict) pair into the two 'y' tensors.
        pred_dict, target_dict = output
        return pred_dict['y'], target_dict['y']

    class DummyMetric(Metric):
        def reset(self):
            pass

        def compute(self):
            pass

        def update(self, output):
            assert output == (y_pred, y)

    engine = MagicMock(state=State(output=({'y': y_pred}, {'y': y})))
    DummyMetric(output_transform=transform).iteration_completed(engine)
示例#23
0
def test_checkpoint_last_checkpoint_on_score():
    """With a score function and n_saved=None, every call saves and last_checkpoint names the top score.

    Scores increase monotonically (0.0 .. 0.9), so the final save is also the
    best-scoring one.
    """
    save_handler = MagicMock(spec=BaseSaveHandler)
    to_save = {"model": DummyModel()}

    checkpointer = Checkpoint(
        to_save,
        save_handler=save_handler,
        n_saved=None,
        score_name="val_acc",
        score_function=lambda e: e.state.metrics["val_acc"],
    )

    trainer = Engine(lambda e, b: None)

    for i in range(10):
        val_acc = i * 0.1
        trainer.state = State(epoch=1, iteration=i, metrics={"val_acc": val_acc})
        checkpointer(trainer)

    # n_saved=None: all 10 calls are expected to reach the save handler.
    assert save_handler.call_count == 10
    # 9 * 0.1 is expected in the filename with four decimals.
    assert checkpointer.last_checkpoint == "model_val_acc=0.9000.pt"
示例#24
0
def test_output_handler_state_attrs():
    """Engine state attributes should be logged to MLflow, with tensor attributes unpacked.

    A 0-d tensor is expected as a single value; a 1-d tensor is expected as
    one entry per element, suffixed with its index.
    """
    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 5
    engine.state.alpha = 3.899
    engine.state.beta = torch.tensor(12.21)
    engine.state.gamma = torch.tensor([21.0, 6.0])

    logger = MagicMock(spec=MLflowLogger)
    logger.log_metrics = MagicMock()

    handler = OutputHandler("tag", state_attributes=["alpha", "beta", "gamma"])
    handler(engine, logger, Events.ITERATION_STARTED)

    expected = {
        "tag alpha": 3.899,
        "tag beta": torch.tensor(12.21).item(),
        "tag gamma 0": 21.0,
        "tag gamma 1": 6.0,
    }
    logger.log_metrics.assert_called_once_with(expected, step=5)
示例#25
0
def test_best_k_with_suffix(dirname):
    """ModelCheckpoint with n_saved=2 should keep only the two highest-scoring checkpoints.

    Scores are consumed one per epoch. The two highest are epochs 1
    (0.3456789) and 3 (0.4567), so only those files should remain.
    """
    scores = [0.3456789, 0.1234, 0.4567, 0.134567]
    scores_iter = iter(scores)

    def score_function(engine):
        return next(scores_iter)

    h = ModelCheckpoint(dirname, _PREFIX, create_dir=False, n_saved=2,
                        score_function=score_function, score_name="val_loss")

    engine = Engine(lambda e, b: None)
    engine.state = State(epoch=0, iteration=0)

    model = DummyModel()
    to_save = {'model': model}
    for _ in range(len(scores)):
        engine.state.epoch += 1
        h(engine, to_save)

    expected = ['{}_{}_val_loss={:.4}.pth'.format(_PREFIX, 'model', scores[e - 1]) for e in [1, 3]]

    # Compare order-independently. ``expected`` is built in epoch order, which
    # only coincides with lexicographic order for these particular scores.
    assert sorted(os.listdir(dirname)) == sorted(expected)
示例#26
0
    def _test(ext, require_empty, archived):
        """Save into a directory holding a pre-existing file and check both files coexist.

        Args:
            ext: extension of the decoy ``<_PREFIX>_obj_1<ext>`` file written
                before saving.
            require_empty: forwarded to ModelCheckpoint.
            archived: forwarded to ModelCheckpoint; the saved checkpoint is
                expected to end in ``.pth.tar`` when true and ``.pth`` otherwise.
        """
        previous_fname = os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'obj', 1, ext))
        with open(previous_fname, 'w') as f:
            f.write("test")

        h = ModelCheckpoint(dirname, _PREFIX, create_dir=True, require_empty=require_empty, archived=archived)
        engine = Engine(lambda e, b: None)
        engine.state = State(epoch=0, iteration=1)

        model = DummyModel()
        to_save = {'model': model}
        h(engine, to_save)

        fname = h.last_checkpoint
        # Extension of the file ModelCheckpoint wrote. Kept separate from the
        # ``ext`` argument, which only names the decoy file above.
        saved_ext = ".pth.tar" if archived else ".pth"
        assert isinstance(fname, str)
        assert os.path.join(dirname, '{}_{}_{}{}'.format(_PREFIX, 'model', 1, saved_ext)) == fname
        assert os.path.exists(fname)
        assert os.path.exists(previous_fname)
        loaded_objects = torch.load(fname)
        assert loaded_objects == model.state_dict()
        os.remove(fname)
示例#27
0
def test_optimizer_params():
    """OptimizerParamsHandler should log each param group's lr to Neptune, optionally prefixed by a tag."""
    optimizer = torch.optim.SGD([torch.Tensor(0)], lr=0.01)

    engine = MagicMock()
    engine.state = State()
    engine.state.iteration = 123

    def _check(handler, expected_name):
        # Fresh logger per case so call assertions only see this handler's calls.
        logger = MagicMock(spec=NeptuneLogger)
        logger.log_metric = MagicMock()
        handler(engine, logger, Events.ITERATION_STARTED)
        logger.log_metric.assert_called_once_with(expected_name, y=0.01, x=123)

    _check(OptimizerParamsHandler(optimizer=optimizer, param_name="lr"), "lr/group_0")
    _check(
        OptimizerParamsHandler(optimizer, param_name="lr", tag="generator"),
        "generator/lr/group_0",
    )
示例#28
0
def test_grads_scalar_handler(dummy_model_factory, norm_mock):
    """GradsScalarHandler should apply ``reduction`` once per parameter and log each result."""
    net = dummy_model_factory(with_grads=True, with_frozen_layer=False)

    handler = GradsScalarHandler(net, reduction=norm_mock)
    logger = MagicMock(spec=TensorboardLogger)
    logger.writer = MagicMock()

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    # Clear any earlier calls so the count below reflects only this handler run.
    norm_mock.reset_mock()

    handler(engine, logger, Events.EPOCH_STARTED)

    add_scalar = logger.writer.add_scalar
    add_scalar.assert_has_calls(
        [
            call("grads_norm/fc1/weight", ANY, 5),
            call("grads_norm/fc1/bias", ANY, 5),
            call("grads_norm/fc2/weight", ANY, 5),
            call("grads_norm/fc2/bias", ANY, 5),
        ],
        any_order=True,
    )
    assert add_scalar.call_count == 4
    assert norm_mock.call_count == 4
示例#29
0
    def _test(to_save, obj, name):
        """Check that a Checkpoint with include_self=True saves, and can restore, its own state."""
        save_handler = MagicMock(spec=BaseSaveHandler)
        checkpointer = Checkpoint(to_save, save_handler=save_handler, include_self=True)
        assert checkpointer.last_checkpoint is None

        trainer = Engine(lambda e, b: None)
        trainer.state = State(epoch=0, iteration=0)
        checkpointer(trainer)
        assert save_handler.call_count == 1

        old_fname = "{}_0.pt".format(name)
        # include_self=True: the expected payload also carries the checkpointer's own state.
        obj["checkpointer"] = OrderedDict([("saved", [(0, old_fname)])])
        metadata = {"basename": name, "score_name": None, "priority": 0}
        save_handler.assert_called_with(obj, old_fname, metadata)

        # Swap object, state should be maintained
        restored = Checkpoint(to_save, save_handler=save_handler, include_self=True)
        restored.load_state_dict(checkpointer.state_dict())
        assert restored.last_checkpoint == old_fname

        trainer.state.epoch = 12
        trainer.state.iteration = 1234
        restored(trainer)
        assert save_handler.call_count == 2
        metadata["priority"] = 1234

        # This delete only happens if state was restored correctly.
        save_handler.remove.assert_called_with(old_fname)

        new_fname = "{}_1234.pt".format(name)
        obj["checkpointer"] = OrderedDict([("saved", [(1234, new_fname)])])
        save_handler.assert_called_with(obj, new_fname, metadata)
        assert save_handler.remove.call_count == 1
        assert restored.last_checkpoint == new_fname
示例#30
0
def test_output_handler_with_global_step_transform():
    """A custom global_step_transform should supply the X value plotted to Visdom."""
    def global_step_transform(*args, **kwargs):
        # Constant differs from engine.state.epoch (5), so the assertion shows which one was used.
        return 10

    engine = MagicMock()
    engine.state = State()
    engine.state.epoch = 5
    engine.state.output = 12345

    logger = MagicMock(spec=VisdomLogger)
    logger.vis = MagicMock()
    logger.executor = _DummyExecutor()

    handler = OutputHandler(
        "tag",
        output_transform=lambda x: {"loss": x},
        global_step_transform=global_step_transform,
    )
    handler(engine, logger, Events.EPOCH_STARTED)

    assert logger.vis.line.call_count == 1
    assert len(handler.windows) == 1 and "tag/loss" in handler.windows
    window = handler.windows["tag/loss"]
    assert window["win"] is not None

    logger.vis.line.assert_called_once_with(
        X=[10],
        Y=[12345],
        env=logger.vis.env,
        win=None,
        update=None,
        opts=window["opts"],
        name="tag/loss",
    )