Example #1
0
def _test_ema_final_weight(model, device, ddp=False, interval=1):
    """Test if final smoothed weights are correct.

    Runs a dummy engine for 4 iterations with a constant EMA momentum of 0.5 and
    compares the EMA weights and the raw model weights against precomputed values.

    Args:
        model: model under test; it is moved to ``device`` (and passed through
            ``idist.auto_model`` when ``ddp`` is True).
        device: ``torch.device`` or device string the model is moved to.
        ddp: if True, wrap the model with ``idist.auto_model``.
        interval: EMA update period in iterations. Expected EMA weights are only
            checked for ``interval`` 1 and 2; the raw model weight is always checked.
    """
    if isinstance(device, str):
        device = torch.device(device)
    model = model.to(device)
    if ddp:
        model = idist.auto_model(model)
    step_fn = _get_dummy_step_fn(model)
    engine = Engine(step_fn)

    # momentum will be constantly 0.5
    ema_handler = EMAHandler(model, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
    ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)

    # `idist.auto_model` may wrap the model in DataParallel/DistributedDataParallel,
    # which do not expose `.weight` themselves; read it from the wrapped module.
    if isinstance(model, (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)):
        model = model.module

    ema_weight = ema_handler.ema_model.weight.data
    model_weight = model.weight.data
    assert ema_weight.device == device
    assert model_weight.device == device

    # `torch.testing.assert_allclose` is deprecated; `assert_close` replaces it but is
    # strict about dtype, so build the expected tensors with `new_full` (same dtype/device).
    expected_ema = {1: 4.0625, 2: 3.5}.get(interval)
    if expected_ema is not None:
        torch.testing.assert_close(ema_weight, ema_weight.new_full((1, 2), expected_ema))
    torch.testing.assert_close(model_weight, model_weight.new_full((1, 2), 5.0))
Example #2
0
def test_ema_buffer():
    """Check that BatchNorm buffers of the EMA model stay in sync with the source model."""
    model = nn.BatchNorm2d(2)
    for buffer in (model.running_mean, model.running_var):
        buffer.data.fill_(1.5)
    ema_handler = EMAHandler(model)

    def _bn_step_fn(engine, batch):
        # forward pass only; the running statistics are what the EMA model must follow
        _ = model(torch.rand(4, 2, 32, 32))
        return 1

    engine = Engine(_bn_step_fn)
    ema_handler.attach(engine)
    ema_model = ema_handler.ema_model

    def _assert_buffers_synced():
        assert ema_model.running_mean.allclose(model.running_mean)
        assert ema_model.running_var.allclose(model.running_var)

    # verify after every iteration, not only at the end
    engine.add_event_handler(Events.ITERATION_COMPLETED, _assert_buffers_synced)

    # 2 epochs x 2 batches -> 4 iterations
    engine.run([0, 1], max_epochs=2)

    _assert_buffers_synced()
Example #3
0
def _test_ema_final_weight(model, device=None, ddp=False, interval=1):
    """Run 4 EMA updates with momentum 0.5 and verify the final EMA and model weights.

    Args:
        model: model under test, moved to ``device`` and optionally wrapped by
            ``idist.auto_model`` when ``ddp`` is True.
        device: target device (``torch.device`` or string); ``None`` defers to
            ``idist.device()``.
        ddp: whether to pass the model through ``idist.auto_model``.
        interval: EMA update period; EMA weights are checked for 1 and 2 only.
    """
    if device is None:
        # let horovod decide the device
        device = idist.device()
    device = torch.device(device) if isinstance(device, str) else device

    model = model.to(device)
    if ddp:
        model = idist.auto_model(model)
    engine = Engine(_get_dummy_step_fn(model))

    ema_handler = EMAHandler(model, momentum=0.5)
    ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))

    # 2 epochs x 2 batches -> 4 iterations
    engine.run(range(2), max_epochs=2)

    # Both models may be DP/DDP-wrapped; compare in float32 so XLA devices pass as well.
    ema_weight = _unwrap_model(ema_handler.ema_model).weight.data.to(torch.float32)
    model_weight = _unwrap_model(model).weight.data.to(torch.float32)
    for weight in (ema_weight, model_weight):
        assert weight.device == device

    expected_ema = {1: 4.0625, 2: 3.5}.get(interval)
    if expected_ema is not None:
        assert ema_weight.allclose(ema_weight.new_full((1, 2), expected_ema))
    assert model_weight.allclose(model_weight.new_full((1, 2), 5.0))
Example #4
0
def test_ema_update_ema_momentum(get_dummy_model):
    """Check the EMA momentum reported on ``engine.state`` during and after warmup.

    Iteration 1 must report ``momentum_warmup``. Iterations strictly between 1 and
    ``warmup_iters`` must lie strictly between the two momenta. Every later
    iteration must report the final ``momentum``.
    """
    warmup_iters = 4
    momentum_warmup = 0.1
    momentum = 0.2

    model = get_dummy_model()
    engine = Engine(_get_dummy_step_fn(model))
    ema_handler = EMAHandler(
        model, momentum_warmup=momentum_warmup, momentum=momentum, warmup_iters=warmup_iters
    )
    ema_handler.attach(engine)

    # registered after the EMA handler, so it runs after it on each iteration
    @engine.on(Events.ITERATION_COMPLETED)
    def assert_momentum(engine: Engine):
        iteration = engine.state.iteration
        current = engine.state.ema_momentum
        if iteration == 1:
            assert current == momentum_warmup
        elif iteration < warmup_iters:
            assert momentum_warmup < current < momentum
        else:
            assert current == momentum

    # 5 epochs x 2 batches -> 10 iterations
    engine.run(range(2), max_epochs=5)
Example #5
0
def test_ema_get_const_momentum(get_dummy_model):
    """The momentum on ``engine.state`` stays equal to ``ema_handler.momentum`` every iteration."""
    model = get_dummy_model()
    engine = Engine(_get_dummy_step_fn(model))

    ema_handler = EMAHandler(model, momentum=0.002)
    ema_handler.attach(engine)

    def _check_momentum(engine: Engine, expected):
        assert engine.state.ema_momentum == expected

    # registered after the EMA handler, so it runs after it on each iteration
    engine.add_event_handler(Events.ITERATION_COMPLETED, _check_momentum, ema_handler.momentum)
    engine.run(range(10))
Example #6
0
def test_ema_no_warmup_momentum(get_dummy_model):
    """Momentum stays equal to ``ema_handler.momentum`` when the warmup is disabled.

    Two configurations are checked: ``momentum_warmup=None`` and ``warmup_iters=None``.
    """
    model = get_dummy_model()
    step_fn = _get_dummy_step_fn(model)

    def assert_const_momentum(engine: Engine, const_momentum):
        assert engine.state.ema_momentum == const_momentum

    # (momentum_warmup, warmup_iters): first no momentum_warmup, then no warmup_iters
    for momentum_warmup, warmup_iters in ((None, 1), (0.001, None)):
        engine = Engine(step_fn)
        ema_handler = EMAHandler(
            model, momentum=0.002, momentum_warmup=momentum_warmup, warmup_iters=warmup_iters
        )
        ema_handler.attach(engine)
        # attach the assertion handler after ema_handler, so the momentum is first updated and then tested
        engine.add_event_handler(Events.ITERATION_COMPLETED, assert_const_momentum, ema_handler.momentum)
        engine.run(range(2))
Example #7
0
def test_ema_invalid_model():
    """Constructing an EMAHandler from something that is not an nn.Module raises ValueError."""
    invalid_model = "Invalid Model"
    with pytest.raises(ValueError, match="model should be an instance of nn.Module or its subclasses"):
        EMAHandler(invalid_model)  # type: ignore
def test_param_scheduler_with_ema_handler():
    """EMAHandler's momentum state attribute can be driven by a state parameter scheduler.

    The EMA handler is attached under the state name ``ema_decay``. A
    PiecewiseLinearStateScheduler attached afterwards ramps that attribute from 0.0
    to 0.999 over the first 10 events, so it should end the run at 0.999.
    """

    from ignite.handlers import EMAHandler

    model = nn.Linear(2, 1)
    trainer = Engine(lambda e, b: model(b))
    data = torch.rand(100, 2)

    param_name = "ema_decay"

    ema_handler = EMAHandler(model)
    ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)

    ema_decay_scheduler = PiecewiseLinearStateScheduler(
        param_name=param_name, milestones_values=[(0, 0.0), (10, 0.999)], save_history=True
    )
    ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)
    trainer.run(data, max_epochs=20)

    # Beyond "does not crash": the scheduler must have driven the state attribute
    # used by the EMA handler to the final milestone value.
    assert getattr(trainer.state, param_name) == pytest.approx(0.999)
Example #9
0
def test_has_momentum_scheduler(get_dummy_model):
    """A handler built with a warmup exposes ``momentum_scheduler`` and ``_momentum_lambda_obj``."""
    ema_handler = EMAHandler(get_dummy_model(), momentum_warmup=0.0, warmup_iters=10)
    for attr_name in ("momentum_scheduler", "_momentum_lambda_obj"):
        assert hasattr(ema_handler, attr_name)
Example #10
0
def test_ema_invalid_momentum_start_end(get_dummy_model):
    """A ``momentum_warmup`` larger than ``momentum`` is rejected with ValueError."""
    model = get_dummy_model()
    with pytest.raises(ValueError, match="momentum_warmup should be less than or equal to momentum"):
        EMAHandler(model, momentum_warmup=0.1, momentum=0.001)
Example #11
0
def test_ema_load_state_dict(get_dummy_model):
    """Loading another model's state_dict into the EMA model copies that model's weights."""
    source = get_dummy_model()
    source.weight.data.fill_(2)
    source_state = source.state_dict()

    ema_handler = EMAHandler(get_dummy_model())
    ema_model = ema_handler.ema_model
    ema_model.load_state_dict(source_state)

    assert ema_model.weight.data.allclose(source.weight.data)
Example #12
0
def test_ema_ema_model_on_cuda(get_dummy_model):
    """ema_handler.ema_model is a plain nn.Module (not DP/DDP-wrapped) and is in eval mode."""
    model = idist.auto_model(get_dummy_model().to(idist.device()))
    ema_model = EMAHandler(model).ema_model

    # separate asserts so a failure points at the exact violated property
    assert isinstance(ema_model, nn.Module)
    assert not isinstance(ema_model, (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel))
    assert not ema_model.training
Example #13
0
def test_ema_warmup_func(get_dummy_model):
    """Test the built-in linear warmup function for the EMA momentum.

    The momentum must equal ``momentum_warmup`` at iteration 1 and the final
    ``momentum`` from iteration ``1 + warmup_iters`` on. In between it must lie
    between the two. Both an increasing (0.0 -> 0.5) and a decreasing
    (1.0 -> 0.5) warmup are checked.
    """
    momentum = 0.5
    warmup_iters = 5

    def check_ema_momentum(engine: Engine, momentum_warmup, final_momentum, warmup_iters):
        iteration = engine.state.iteration
        ema_momentum = engine.state.ema_momentum
        if iteration == 1:
            assert ema_momentum == momentum_warmup
        elif iteration >= 1 + warmup_iters:
            assert ema_momentum == final_momentum
        else:
            # bounds come from the handler's own arguments (not the enclosing scope),
            # so the check is correct for whatever range it is registered with
            low, high = sorted((momentum_warmup, final_momentum))
            assert low <= ema_momentum <= high

    # first momentum_warmup < momentum, then momentum_warmup > momentum
    for momentum_warmup in (0.0, 1.0):
        model = get_dummy_model()
        engine = Engine(_get_dummy_step_fn(model))
        ema_handler = EMAHandler(
            model, momentum=momentum, momentum_warmup=momentum_warmup, warmup_iters=warmup_iters
        )
        ema_handler.attach(engine)
        engine.add_event_handler(
            Events.ITERATION_COMPLETED, check_ema_momentum, momentum_warmup, momentum, warmup_iters
        )
        engine.run(range(10))
Example #14
0
def test_ema_invalid_momentum(get_dummy_model, momentum):
    """An invalid ``momentum`` value makes EMAHandler raise ValueError at construction."""
    model = get_dummy_model()
    with pytest.raises(ValueError, match="Invalid momentum"):
        EMAHandler(model, momentum=momentum)
Example #15
0
def test_ema_two_handlers(get_dummy_model):
    """Two EMA handlers on one engine keep independent state names and update periods."""
    model_1 = get_dummy_model()
    ema_handler_1 = EMAHandler(model_1, momentum=0.5)

    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum=0.5)

    def _step_fn(engine: Engine, batch: Any):
        # shift both source models' weights by +1 on every iteration
        for source in (model_1, model_2):
            source.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)

    # handler_1 updates model_1's EMA every iteration
    assert not hasattr(engine.state, "ema_momentum_1")
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")

    # handler_2 updates model_2's EMA every 2 iterations
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # 2 epochs x 2 batches -> 4 iterations
    engine.run(range(2), max_epochs=2)

    # cast to float32 so the comparison also works on XLA devices
    for handler, expected in ((ema_handler_1, 4.0625), (ema_handler_2, 3.5)):
        weight = handler.ema_model.weight.data.to(torch.float32)
        assert weight.allclose(weight.new_full((1, 2), expected))

    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    # a third handler reusing an existing state attribute name must warn
    ema_handler_3 = EMAHandler(get_dummy_model())
    with pytest.warns(UserWarning, match="Attribute 'ema_momentum_1' already exists"):
        ema_handler_3.attach(engine, name="ema_momentum_1")
Example #16
0
def test_ema_two_handlers(get_dummy_model):
    """Test when two EMA handlers are attached to a trainer.

    handler_1 updates every iteration and handler_2 every 2 iterations, both with a
    constant momentum of 0.5. The final EMA weights must reflect the two update
    periods. Attaching a third handler under an already used state name must raise.
    """
    model_1 = get_dummy_model()
    # momentum will be constantly 0.5
    ema_handler_1 = EMAHandler(model_1, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)

    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)

    def _step_fn(engine: Engine, batch: Any):
        model_1.weight.data.add_(1)
        model_2.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)
    assert not hasattr(engine.state, "ema_momentum_1")
    # handler_1 update EMA model of model_1 every 1 iteration
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")

    # handler_2 update EMA model for model_2 every 2 iterations
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)
    ema_weight_1 = ema_handler_1.ema_model.weight.data
    ema_weight_2 = ema_handler_2.ema_model.weight.data
    # `torch.testing.assert_allclose` is deprecated; `assert_close` is strict about
    # dtype/device, so expected tensors are built with `new_full` from the actual ones.
    torch.testing.assert_close(ema_weight_1, ema_weight_1.new_full((1, 2), 4.0625))
    torch.testing.assert_close(ema_weight_2, ema_weight_2.new_full((1, 2), 3.5))

    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    model_3 = get_dummy_model()
    ema_handler_3 = EMAHandler(model_3)
    with pytest.raises(ValueError, match="Please select another name"):
        ema_handler_3.attach(engine, "ema_momentum_2")
Example #17
0
def test_ema_invalid_momentum_warmup(get_dummy_model, momentum_warmup):
    """An invalid ``momentum_warmup`` makes EMAHandler raise ValueError.

    ``get_dummy_model`` is a model factory and must be called to get an actual
    module. Passing the factory itself hands the handler a function instead of an
    ``nn.Module``, so the test could pass or fail for the wrong reason.
    """
    model = get_dummy_model()
    with pytest.raises(ValueError, match="Invalid momentum_warmup"):
        EMAHandler(model, momentum_warmup=momentum_warmup)