def _test_ema_final_weight(model, device, ddp=False, interval=1):
    """Test if final smoothed weights are correct.

    Runs 4 iterations (2 epochs x 2 batches) of the dummy step function with a
    constant EMA momentum of 0.5 and checks the final EMA and model weights.

    Args:
        model: module to smooth; moved to ``device`` (and optionally wrapped
            with ``idist.auto_model`` when ``ddp`` is True).
        device: target device, as a ``str`` or ``torch.device``.
        ddp: wrap the model with ``idist.auto_model`` before training.
        interval: the EMA update runs every ``interval`` iterations. Only
            intervals 1 and 2 have a pinned expected EMA value.
    """
    if isinstance(device, str):
        device = torch.device(device)
    model = model.to(device)
    if ddp:
        model = idist.auto_model(model)
    step_fn = _get_dummy_step_fn(model)
    engine = Engine(step_fn)

    # momentum will be constantly 0.5
    ema_handler = EMAHandler(model, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
    ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)

    ema_weight = ema_handler.ema_model.weight.data
    # a DP/DDP wrapper does not expose ``weight`` directly; unwrap via ``.module``
    model_weight = getattr(model, "module", model).weight.data
    assert ema_weight.device == device
    assert model_weight.device == device

    # ``torch.testing.assert_allclose`` is deprecated in favour of ``assert_close``;
    # ``new_full`` keeps the expected tensor on the same dtype/device as the actual one
    if interval == 1:
        torch.testing.assert_close(ema_weight, ema_weight.new_full((1, 2), 4.0625))
    elif interval == 2:
        torch.testing.assert_close(ema_weight, ema_weight.new_full((1, 2), 3.5))
    # other intervals: no expected EMA value is pinned
    torch.testing.assert_close(model_weight, model_weight.new_full((1, 2), 5.0))
def test_ema_buffer():
    """Check that buffers (BatchNorm running statistics) are synchronized too.

    After every iteration, and again once training ends, the EMA model's
    ``running_mean`` / ``running_var`` must match those of the source model.
    """
    bn_model = nn.BatchNorm2d(2)
    for buf in (bn_model.running_mean, bn_model.running_var):
        buf.data.fill_(1.5)

    ema_handler = EMAHandler(bn_model)

    def _bn_step_fn(engine, batch):
        # a forward pass in training mode updates the running statistics
        bn_model(torch.rand(4, 2, 32, 32))
        return 1

    engine = Engine(_bn_step_fn)
    ema_handler.attach(engine)
    ema_model = ema_handler.ema_model

    def _assert_buffers_synced():
        assert ema_model.running_mean.allclose(bn_model.running_mean)
        assert ema_model.running_var.allclose(bn_model.running_var)

    # registered after ``attach`` so the EMA update runs before the check
    engine.add_event_handler(Events.ITERATION_COMPLETED, _assert_buffers_synced)

    # 2 epochs x 2 batches = 4 iterations
    engine.run([0, 1], max_epochs=2)
    _assert_buffers_synced()
def _test_ema_final_weight(model, device=None, ddp=False, interval=1):
    """Test if final smoothed weights are correct.

    Trains for 4 iterations (2 epochs x 2 batches) with ``momentum=0.5`` and
    compares the final EMA and model weights against known values.

    Args:
        model: module to smooth; moved to ``device`` and, when ``ddp`` is
            True, wrapped with ``idist.auto_model``.
        device: ``str`` or ``torch.device``; ``None`` falls back to
            ``idist.device()``.
        ddp: whether to wrap the model with ``idist.auto_model``.
        interval: the EMA update runs every ``interval`` iterations. Only
            intervals 1 and 2 have a pinned expected EMA value.
    """
    # without an explicit device, use whatever idist reports for this process
    device = idist.device() if device is None else device
    if isinstance(device, str):
        device = torch.device(device)

    model = model.to(device)
    if ddp:
        model = idist.auto_model(model)

    engine = Engine(_get_dummy_step_fn(model))
    ema_handler = EMAHandler(model, momentum=0.5)
    ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))

    # 2 epochs x 2 batches = 4 iterations
    engine.run(range(2), max_epochs=2)

    # ema_model and model can be DP or DDP, so unwrap them first; casting to
    # float32 avoids spurious failures on XLA devices
    ema_weight = _unwrap_model(ema_handler.ema_model).weight.data.to(torch.float32)
    model_weight = _unwrap_model(model).weight.data.to(torch.float32)
    assert ema_weight.device == device
    assert model_weight.device == device

    expected_ema = {1: 4.0625, 2: 3.5}.get(interval)
    if expected_ema is not None:
        assert ema_weight.allclose(ema_weight.new_full((1, 2), expected_ema))
    assert model_weight.allclose(model_weight.new_full((1, 2), 5.0))
def test_ema_update_ema_momentum(get_dummy_model):
    """Check the warmup schedule of ``engine.state.ema_momentum``.

    Iteration 1 must use ``momentum_warmup``. Iterations strictly inside the
    warmup window must lie strictly between ``momentum_warmup`` and
    ``momentum``. From iteration ``warmup_iters`` onward, the final
    ``momentum`` must be used.
    """
    warmup_iters = 4
    momentum_warmup = 0.1
    momentum = 0.2

    model = get_dummy_model()
    engine = Engine(_get_dummy_step_fn(model))
    ema_handler = EMAHandler(
        model, momentum_warmup=momentum_warmup, momentum=momentum, warmup_iters=warmup_iters
    )
    ema_handler.attach(engine)

    # registered after ``attach`` so each check sees the freshly updated momentum
    @engine.on(Events.ITERATION_COMPLETED)
    def assert_momentum(engine: Engine):
        it = engine.state.iteration
        value = engine.state.ema_momentum
        if it == 1:
            assert value == momentum_warmup
        elif it < warmup_iters:
            assert momentum_warmup < value < momentum
        else:
            assert value == momentum

    engine.run(range(2), max_epochs=5)
def test_ema_get_const_momentum(get_dummy_model):
    """``engine.state.ema_momentum`` stays equal to the handler's momentum on every iteration."""
    model = get_dummy_model()
    engine = Engine(_get_dummy_step_fn(model))

    ema_handler = EMAHandler(model, momentum=0.002)
    ema_handler.attach(engine)

    expected = ema_handler.momentum

    # registered after ``attach`` so the check runs after the EMA handler
    @engine.on(Events.ITERATION_COMPLETED)
    def assert_const_momentum(engine: Engine):
        assert engine.state.ema_momentum == expected

    engine.run(range(10))
def test_ema_no_warmup_momentum(get_dummy_model):
    """Momentum is constant when either ``momentum_warmup`` or ``warmup_iters`` is ``None``."""
    model = get_dummy_model()
    step_fn = _get_dummy_step_fn(model)

    def assert_const_momentum(engine: Engine, const_momentum):
        assert engine.state.ema_momentum == const_momentum

    scenarios = (
        # no momentum_warmup
        dict(momentum=0.002, momentum_warmup=None, warmup_iters=1),
        # no warmup_iters
        dict(momentum=0.002, momentum_warmup=0.001, warmup_iters=None),
    )
    for handler_kwargs in scenarios:
        engine = Engine(step_fn)
        ema_handler = EMAHandler(model, **handler_kwargs)
        ema_handler.attach(engine)
        # registered after the EMA handler so momentum is updated first, then checked
        engine.add_event_handler(Events.ITERATION_COMPLETED, assert_const_momentum, ema_handler.momentum)
        engine.run(range(2))
def test_ema_invalid_model():
    """Passing something that is not an ``nn.Module`` raises ``ValueError``."""
    not_a_module = "Invalid Model"
    with pytest.raises(ValueError, match="model should be an instance of nn.Module or its subclasses"):
        EMAHandler(not_a_module)  # type: ignore
def test_param_scheduler_with_ema_handler():
    """A state-param scheduler can drive the EMA momentum stored on ``engine.state``.

    The scheduler ramps ``engine.state.ema_decay`` linearly from 0.0 to 0.999
    over the first 10 iterations. The EMA handler is attached under the same
    state attribute name.
    """
    from ignite.handlers import EMAHandler

    model = nn.Linear(2, 1)
    # the step only runs a forward pass: the model's parameters never change
    trainer = Engine(lambda e, b: model(b))
    data = torch.rand(100, 2)

    param_name = "ema_decay"
    ema_handler = EMAHandler(model)
    ema_handler.attach(trainer, name=param_name, event=Events.ITERATION_COMPLETED)

    ema_decay_scheduler = PiecewiseLinearStateScheduler(
        param_name=param_name, milestones_values=[(0, 0.0), (10, 0.999)], save_history=True
    )
    ema_decay_scheduler.attach(trainer, Events.ITERATION_COMPLETED)

    trainer.run(data, max_epochs=20)

    # well past the last milestone, the scheduled value must sit at its final value
    assert abs(getattr(trainer.state, param_name) - 0.999) < 1e-8
    # the model is constant, so its exponential moving average must coincide with it
    ema_model = ema_handler.ema_model
    assert ema_model.weight.data.allclose(model.weight.data)
    assert ema_model.bias.data.allclose(model.bias.data)
def test_has_momentum_scheduler(get_dummy_model):
    """With warmup configured, the handler exposes ``momentum_scheduler`` and ``_momentum_lambda_obj``."""
    ema_handler = EMAHandler(get_dummy_model(), momentum_warmup=0.0, warmup_iters=10)
    for attr_name in ("momentum_scheduler", "_momentum_lambda_obj"):
        assert hasattr(ema_handler, attr_name)
def test_ema_invalid_momentum_start_end(get_dummy_model):
    """A ``momentum_warmup`` greater than ``momentum`` is rejected with ``ValueError``."""
    start, end = 0.1, 0.001  # momentum_warmup > momentum
    with pytest.raises(ValueError, match="momentum_warmup should be less than or equal to momentum"):
        EMAHandler(get_dummy_model(), momentum_warmup=start, momentum=end)
def test_ema_load_state_dict(get_dummy_model):
    """Loading another model's state dict into ``ema_model`` copies that model's weights."""
    source = get_dummy_model()
    source.weight.data.fill_(2)
    source_state = source.state_dict()

    ema_model = EMAHandler(get_dummy_model()).ema_model
    ema_model.load_state_dict(source_state)

    assert ema_model.weight.data.allclose(source.weight.data)
def test_ema_ema_model_on_cuda(get_dummy_model):
    """``ema_model`` is a plain ``nn.Module`` in eval mode, even for a model passed through ``idist.auto_model``.

    The EMA copy must never be a ``DataParallel`` / ``DistributedDataParallel`` wrapper.
    """
    model = idist.auto_model(get_dummy_model().to(idist.device()))
    ema_model = EMAHandler(model).ema_model

    parallel_wrappers = (nn.parallel.DistributedDataParallel, nn.parallel.DataParallel)
    assert isinstance(ema_model, nn.Module) and not isinstance(ema_model, parallel_wrappers)
    assert not ema_model.training
def test_ema_warmup_func(get_dummy_model):
    """Test the built-in linear warmup function for the EMA momentum.

    Both warmup directions are exercised: ``momentum_warmup < momentum`` and
    ``momentum_warmup > momentum``. At iteration 1 the momentum must equal
    ``momentum_warmup``. From iteration ``1 + warmup_iters`` onward it must
    equal the final momentum. In between it must stay within the two values.
    """
    momentum = 0.5
    warmup_iters = 5

    def check_ema_momentum(engine: Engine, momentum_warmup, final_momentum, warmup_iters):
        iteration = engine.state.iteration
        curr_momentum = engine.state.ema_momentum
        if iteration == 1:
            assert curr_momentum == momentum_warmup
        elif iteration >= 1 + warmup_iters:
            assert curr_momentum == final_momentum
        else:
            # bounds come from this handler's own arguments, not the enclosing
            # ``momentum``, so the check stays correct for any final momentum
            lower, upper = sorted((final_momentum, momentum_warmup))
            assert lower <= curr_momentum <= upper

    # 0.0: warmup below ``momentum``; 1.0: warmup above ``momentum``
    for momentum_warmup in (0.0, 1.0):
        model = get_dummy_model()
        engine = Engine(_get_dummy_step_fn(model))
        ema_handler = EMAHandler(
            model, momentum=momentum, momentum_warmup=momentum_warmup, warmup_iters=warmup_iters
        )
        ema_handler.attach(engine)
        engine.add_event_handler(
            Events.ITERATION_COMPLETED, check_ema_momentum, momentum_warmup, momentum, warmup_iters
        )
        engine.run(range(10))
def test_ema_invalid_momentum(get_dummy_model, momentum):
    """An invalid ``momentum`` makes ``EMAHandler`` raise ``ValueError``.

    ``momentum`` is presumably supplied by a pytest parametrization that is not
    shown here.
    """
    model = get_dummy_model()
    with pytest.raises(ValueError, match="Invalid momentum"):
        EMAHandler(model, momentum=momentum)
def test_ema_two_handlers(get_dummy_model):
    """Two EMA handlers attached to one engine keep independent momenta and EMA models."""
    model_1 = get_dummy_model()
    ema_handler_1 = EMAHandler(model_1, momentum=0.5)
    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum=0.5)

    def _step_fn(engine: Engine, batch: Any):
        for m in (model_1, model_2):
            m.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)

    # each ``attach`` must create its own attribute on ``engine.state``
    assert not hasattr(engine.state, "ema_momentum_1")
    # handler 1 updates its EMA model on every iteration
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")
    # handler 2 updates its EMA model every 2 iterations
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # 2 epochs x 2 batches = 4 iterations
    engine.run(range(2), max_epochs=2)

    # cast to float32 so the comparison also passes on XLA devices
    for handler, expected in ((ema_handler_1, 4.0625), (ema_handler_2, 3.5)):
        weight = handler.ema_model.weight.data.to(torch.float32)
        assert weight.allclose(weight.new_full((1, 2), expected))
    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    # reusing an existing state attribute name emits a warning
    ema_handler_3 = EMAHandler(get_dummy_model())
    with pytest.warns(UserWarning, match="Attribute 'ema_momentum_1' already exists"):
        ema_handler_3.attach(engine, name="ema_momentum_1")
def test_ema_two_handlers(get_dummy_model):
    """Test when two EMA handlers are attached to a trainer.

    Each handler must own its state attribute and its EMA model. Attaching a
    third handler under an already-used name must raise ``ValueError``.
    """
    model_1 = get_dummy_model()
    # momentum will be constantly 0.5
    ema_handler_1 = EMAHandler(model_1, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)

    def _step_fn(engine: Engine, batch: Any):
        model_1.weight.data.add_(1)
        model_2.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)
    assert not hasattr(engine.state, "ema_momentum_1")
    # handler_1 update EMA model of model_1 every 1 iteration
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")
    # handler_2 update EMA model for model_2 every 2 iterations
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)

    ema_weight_1 = ema_handler_1.ema_model.weight.data
    ema_weight_2 = ema_handler_2.ema_model.weight.data
    # ``torch.testing.assert_allclose`` is deprecated; ``assert_close`` with
    # ``new_full`` keeps the expected tensor on the same dtype/device
    torch.testing.assert_close(ema_weight_1, ema_weight_1.new_full((1, 2), 4.0625))
    torch.testing.assert_close(ema_weight_2, ema_weight_2.new_full((1, 2), 3.5))
    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    model_3 = get_dummy_model()
    ema_handler_3 = EMAHandler(model_3)
    with pytest.raises(ValueError, match="Please select another name"):
        ema_handler_3.attach(engine, "ema_momentum_2")
def test_ema_invalid_momentum_warmup(get_dummy_model, momentum_warmup):
    """An invalid ``momentum_warmup`` makes ``EMAHandler`` raise ``ValueError``.

    ``momentum_warmup`` is presumably supplied by a pytest parametrization that
    is not shown here.
    """
    # build a real model instance: passing the ``get_dummy_model`` factory itself
    # would hand EMAHandler a non-``nn.Module`` and test the wrong validation
    model = get_dummy_model()
    with pytest.raises(ValueError, match="Invalid momentum_warmup"):
        EMAHandler(model, momentum_warmup=momentum_warmup)