import pytest
import torch

from ignite.engine import Engine, Events
# Import paths assume PyTorch-Ignite >= 0.4.7, where the state parameter schedulers
# live in ignite.handlers.state_param_scheduler.
from ignite.handlers.state_param_scheduler import (
    ExpStateScheduler,
    LambdaStateScheduler,
    MultiStepStateScheduler,
    PiecewiseLinearStateScheduler,
    StepStateScheduler,
)


class LambdaState:
    # Defined at module scope so that instances can be pickled by torch.save in
    # test_torch_save_load below (the same class is also defined locally in later tests).
    def __init__(self, initial_value, gamma):
        self.initial_value = initial_value
        self.gamma = gamma

    def __call__(self, event_index):
        return self.initial_value * self.gamma ** (event_index % 9)


def test_torch_save_load():
    lambda_state_parameter_scheduler = LambdaStateScheduler(
        param_name="custom_scheduled_param",
        lambda_obj=LambdaState(initial_value=10, gamma=0.99),
    )

    torch.save(lambda_state_parameter_scheduler, "dummy_lambda_state_parameter_scheduler.pt")
    loaded_lambda_state_parameter_scheduler = torch.load("dummy_lambda_state_parameter_scheduler.pt")

    engine1 = Engine(lambda e, b: None)
    lambda_state_parameter_scheduler.attach(engine1, Events.EPOCH_COMPLETED)
    engine1.run([0] * 8, max_epochs=2)
    torch.testing.assert_allclose(
        getattr(engine1.state, "custom_scheduled_param"), LambdaState(initial_value=10, gamma=0.99)(2)
    )

    engine2 = Engine(lambda e, b: None)
    loaded_lambda_state_parameter_scheduler.attach(engine2, Events.EPOCH_COMPLETED)
    engine2.run([0] * 8, max_epochs=2)
    torch.testing.assert_allclose(
        getattr(engine2.state, "custom_scheduled_param"), LambdaState(initial_value=10, gamma=0.99)(2)
    )

    torch.testing.assert_allclose(
        getattr(engine1.state, "custom_scheduled_param"), getattr(engine2.state, "custom_scheduled_param")
    )
def test_custom_scheduler():
    engine = Engine(lambda e, b: None)

    class LambdaState:
        def __init__(self, initial_value, gamma):
            self.initial_value = initial_value
            self.gamma = gamma

        def __call__(self, event_index):
            return self.initial_value * self.gamma ** (event_index % 9)

    lambda_state_parameter_scheduler = LambdaStateScheduler(
        param_name="custom_scheduled_param",
        lambda_obj=LambdaState(initial_value=10, gamma=0.99),
    )

    lambda_state_parameter_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=2)
    torch.testing.assert_allclose(
        getattr(engine.state, "custom_scheduled_param"), LambdaState(initial_value=10, gamma=0.99)(2)
    )

    engine.run([0] * 8, max_epochs=20)
    torch.testing.assert_allclose(
        getattr(engine.state, "custom_scheduled_param"), LambdaState(initial_value=10, gamma=0.99)(20)
    )

    state_dict = lambda_state_parameter_scheduler.state_dict()
    lambda_state_parameter_scheduler.load_state_dict(state_dict)
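
# Illustrative only, not part of the Ignite test suite: the values asserted in
# test_custom_scheduler above follow directly from the LambdaState formula
# initial_value * gamma ** (event_index % 9).
def _lambda_state_reference_values():
    initial_value, gamma = 10, 0.99
    value_after_2_epochs = initial_value * gamma ** (2 % 9)    # 10 * 0.99**2 ~= 9.801
    value_after_20_epochs = initial_value * gamma ** (20 % 9)  # 20 % 9 == 2, so also ~= 9.801
    return value_after_2_epochs, value_after_20_epochs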
def test_custom_scheduler_asserts():
    class LambdaState:
        def __init__(self, initial_value, gamma):
            self.initial_value = initial_value
            self.gamma = gamma

    with pytest.raises(ValueError, match=r"Expected lambda_obj to be callable."):
        lambda_state_parameter_scheduler = LambdaStateScheduler(
            param_name="custom_scheduled_param",
            lambda_obj=LambdaState(initial_value=10, gamma=0.99),
            create_new=True,
        )
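
# Illustrative only, not part of the Ignite test suite: the ValueError expected above is
# raised because that local LambdaState defines no __call__, so its instances are not
# callable. A quick standalone check of that precondition:
def _lambda_obj_not_callable_demo():
    class NonCallableLambdaState:
        def __init__(self, initial_value, gamma):
            self.initial_value = initial_value
            self.gamma = gamma

    assert not callable(NonCallableLambdaState(initial_value=10, gamma=0.99))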
def test_docstring_examples():
    # LambdaStateScheduler
    engine = Engine(lambda e, b: None)

    class LambdaState:
        def __init__(self, initial_value, gamma):
            self.initial_value = initial_value
            self.gamma = gamma

        def __call__(self, event_index):
            return self.initial_value * self.gamma ** (event_index % 9)

    param_scheduler = LambdaStateScheduler(param_name="param", lambda_obj=LambdaState(10, 0.99))
    param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=2)

    # PiecewiseLinearStateScheduler
    engine = Engine(lambda e, b: None)
    param_scheduler = PiecewiseLinearStateScheduler(
        param_name="param", milestones_values=[(10, 0.5), (20, 0.45), (21, 0.3), (30, 0.1), (40, 0.1)]
    )
    param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=40)

    # ExpStateScheduler
    engine = Engine(lambda e, b: None)
    param_scheduler = ExpStateScheduler(param_name="param", initial_value=10, gamma=0.99)
    param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=2)

    # StepStateScheduler
    engine = Engine(lambda e, b: None)
    param_scheduler = StepStateScheduler(param_name="param", initial_value=10, gamma=0.99, step_size=5)
    param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=10)

    # MultiStepStateScheduler
    engine = Engine(lambda e, b: None)
    param_scheduler = MultiStepStateScheduler(param_name="param", initial_value=10, gamma=0.99, milestones=[3, 6])
    param_scheduler.attach(engine, Events.EPOCH_COMPLETED)
    engine.run([0] * 8, max_epochs=10)
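
# Illustrative only, not part of the Ignite test suite: after each run in
# test_docstring_examples the current scheduled value is exposed on the engine state
# under the attribute named by ``param_name`` (here "param"), the same mechanism the
# asserts above read via ``getattr(engine.state, "custom_scheduled_param")``.
def _read_scheduled_param(engine, param_name="param"):
    return getattr(engine.state, param_name)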
# Additional imports needed by the EMAHandler excerpt below (assumed PyTorch-Ignite layout).
# ``EMAWarmUp`` is referenced by EMAHandler.__init__ but is not part of this excerpt;
# an illustrative sketch of it follows the class.
import warnings
from copy import deepcopy
from typing import Optional, Union

import torch.nn as nn

from ignite.engine import CallableEventWithFilter, EventsList
from ignite.handlers.param_scheduler import BaseParamScheduler


class EMAHandler:
    r"""Exponential moving average (EMA) handler can be used to compute a smoothed version of a model.
    The EMA model is updated as follows:

    .. math:: \theta_{\text{EMA}, t+1} = (1 - \lambda) \cdot \theta_{\text{EMA}, t} + \lambda \cdot \theta_{t}

    where :math:`\theta_{\text{EMA}, t}` and :math:`\theta_{t}` are the EMA weights and online model weights at
    :math:`t`-th iteration, respectively; :math:`\lambda` is the update momentum. Current momentum can be retrieved
    from ``Engine.state.ema_momentum``.

    Args:
        model: the online model for which an EMA model will be computed. If ``model`` is ``DataParallel`` or
            ``DistributedDataParallel``, the EMA smoothing will be applied to ``model.module``.
        momentum: the update momentum after the warmup phase, should be a float in range :math:`\left(0, 1 \right)`.
        momentum_warmup: the initial update momentum during the warmup phase.
        warmup_iters: number of warmup iterations.

    Attributes:
        ema_model: the exponential moving averaged model.
        model: the online model that is tracked by EMAHandler. It is ``model.module`` if ``model`` in
            the initialization method is an instance of ``DistributedDataParallel``.
        momentum: the update momentum.

    Note:
        The EMA model is already in ``eval`` mode. If model in the arguments is an ``nn.Module`` or
        ``DistributedDataParallel``, the EMA model is an ``nn.Module`` and it is on the same device as the online
        model. If the model is an ``nn.DataParallel``, then the EMA model is an ``nn.DataParallel``.

    Note:
        It is recommended to initialize and use an EMA handler in the following steps:

        1. Initialize ``model`` (``nn.Module`` or ``DistributedDataParallel``) and ``ema_handler`` (``EMAHandler``).
        2. Build ``trainer`` (``ignite.engine.Engine``).
        3. Resume from checkpoint for ``model`` and ``ema_handler.ema_model``.
        4. Attach ``ema_handler`` to ``trainer``.

    Examples:
        .. code-block:: python

            device = torch.device("cuda:0")
            model = nn.Linear(2, 1).to(device)
            ema_handler = EMAHandler(model, momentum=0.0002)
            # get the ema model, which is an instance of nn.Module
            ema_model = ema_handler.ema_model
            trainer = Engine(train_step_fn)
            to_load = {"model": model, "ema_model": ema_model, "trainer": trainer}
            if resume_from is not None:
                Checkpoint.load_objects(to_load, checkpoint=resume_from)

            # update the EMA model every 5 iterations
            ema_handler.attach(trainer, name="ema_momentum", event=Events.ITERATION_COMPLETED(every=5))

            # add other handlers
            to_save = to_load
            ckpt_handler = Checkpoint(to_save, DiskSaver(...), ...)
            trainer.add_event_handler(Events.EPOCH_COMPLETED, ckpt_handler)

            # current momentum can be retrieved from engine.state,
            # the attribute name is the `name` parameter used in the attach function
            @trainer.on(Events.ITERATION_COMPLETED)
            def print_ema_momentum(engine):
                print(f"current momentum: {engine.state.ema_momentum}")

            # use ema model for validation
            val_step_fn = get_val_step_fn(ema_model)
            evaluator = Engine(val_step_fn)

            @trainer.on(Events.EPOCH_COMPLETED)
            def run_validation(engine):
                evaluator.run(val_data_loader)

            trainer.run(...)

        The following example shows how to perform warm-up of the EMA momentum:

        .. code-block:: python

            device = torch.device("cuda:0")
            model = nn.Linear(2, 1).to(device)
            # linearly change the EMA momentum from 0.2 to 0.002 in the first 100 iterations,
            # then keep a constant EMA momentum of 0.002 afterwards
            ema_handler = EMAHandler(model, momentum=0.002, momentum_warmup=0.2, warmup_iters=100)
            engine = Engine(step_fn)
            ema_handler.attach(engine, name="ema_momentum")
            engine.run(...)

        The following example shows how to attach two handlers to the same trainer:

        .. code-block:: python

            generator = build_generator(...)
            discriminator = build_discriminator(...)

            gen_handler = EMAHandler(generator)
            disc_handler = EMAHandler(discriminator)

            step_fn = get_step_fn(...)
            engine = Engine(step_fn)

            # update EMA model of generator every iteration
            gen_handler.attach(engine, "gen_ema_momentum", event=Events.ITERATION_COMPLETED)
            # update EMA model of discriminator every 2 iterations
            disc_handler.attach(engine, "disc_ema_momentum", event=Events.ITERATION_COMPLETED(every=2))

            @engine.on(Events.ITERATION_COMPLETED)
            def print_ema_momentum(engine):
                print(f"current momentum for generator: {engine.state.gen_ema_momentum}")
                print(f"current momentum for discriminator: {engine.state.disc_ema_momentum}")

            engine.run(...)

    .. versionadded:: 0.4.6
    """

    def __init__(
        self,
        model: nn.Module,
        momentum: float = 0.0002,
        momentum_warmup: Optional[float] = None,
        warmup_iters: Optional[int] = None,
    ) -> None:
        if not 0 < momentum < 1:
            raise ValueError(f"Invalid momentum: {momentum}")
        self.momentum = momentum

        self._momentum_lambda_obj: Optional[EMAWarmUp] = None
        if momentum_warmup is not None and warmup_iters is not None:
            self.momentum_scheduler: Optional[BaseParamScheduler] = None
            self._momentum_lambda_obj = EMAWarmUp(momentum_warmup, warmup_iters, momentum)

        if not isinstance(model, nn.Module):
            raise ValueError(
                f"model should be an instance of nn.Module or its subclasses, but got "
                f"model: {model.__class__.__name__}"
            )

        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        self.model = model

        self.ema_model = deepcopy(self.model)
        for param in self.ema_model.parameters():
            param.detach_()
        self.ema_model.eval()

    def _update_ema_model(self, engine: Engine, name: str) -> None:
        """Update weights of ema model"""
        momentum = getattr(engine.state, name)
        for ema_p, model_p in zip(self.ema_model.parameters(), self.model.parameters()):
            ema_p.mul_(1.0 - momentum).add_(model_p.data, alpha=momentum)
        # assign the buffers
        for ema_b, model_b in zip(self.ema_model.buffers(), self.model.buffers()):
            ema_b.data = model_b.data

    def attach(
        self,
        engine: Engine,
        name: str = "ema_momentum",
        warn_if_exists: bool = True,
        event: Union[str, Events, CallableEventWithFilter, EventsList] = Events.ITERATION_COMPLETED,
    ) -> None:
        """Attach the handler to the engine. After the handler is attached, ``Engine.state`` will add a new
        attribute with name ``name`` if the attribute does not exist. Then, the current momentum can be retrieved
        from ``Engine.state`` when the engine runs.

        Note:
            There are two cases where a momentum with name ``name`` already exists: 1. the engine has loaded its
            state dict after resuming. In this case, there is no need to initialize the momentum again, and users
            can set ``warn_if_exists`` to False to suppress the warning message; 2. another handler has created
            a state attribute with the same name. In this case, users should choose another name for the ema
            momentum.

        Args:
            engine: trainer to which the handler will be attached.
            name: attribute name for retrieving EMA momentum from ``Engine.state``. It should be a unique name
                since a trainer can have multiple EMA handlers.
            warn_if_exists: if True, a warning will be raised if the momentum with name ``name`` already exists.
            event: event when the EMA momentum and EMA model are updated.
        """
        if hasattr(engine.state, name):
            if warn_if_exists:
                warnings.warn(
                    f"Attribute '{name}' already exists in Engine.state. It might be because 1. the engine has "
                    f"loaded its state dict or 2. {name} is already created by other handlers. Turn off this "
                    f"warning by setting warn_if_exists to False.",
                    category=UserWarning,
                )
        else:
            setattr(engine.state, name, self.momentum)

        if self._momentum_lambda_obj is not None:
            self.momentum_scheduler = LambdaStateScheduler(self._momentum_lambda_obj, param_name="ema_momentum")

            # first update the momentum and then update the EMA model
            self.momentum_scheduler.attach(engine, event)
        engine.add_event_handler(event, self._update_ema_model, name)
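
# Illustrative sketch, not the actual Ignite implementation: ``EMAHandler.__init__``
# above wraps (momentum_warmup, warmup_iters, momentum) in an ``EMAWarmUp`` callable,
# which the class docstring describes as changing the momentum linearly from
# ``momentum_warmup`` to ``momentum`` over the first ``warmup_iters`` iterations and
# keeping it constant afterwards. A minimal callable with that behaviour could look
# like this:
class EMAWarmUp:
    def __init__(self, momentum_warmup: float, warmup_iters: int, momentum: float) -> None:
        self.momentum_warmup = momentum_warmup
        self.warmup_iters = warmup_iters
        self.momentum = momentum

    def __call__(self, event_index: int) -> float:
        # Linearly interpolate between the warm-up momentum and the final momentum,
        # then clamp at the final momentum once the warm-up phase is over.
        denominator = max(1, self.warmup_iters - 1)
        curr_momentum = self.momentum_warmup + (self.momentum - self.momentum_warmup) * (
            event_index - 1
        ) / denominator
        if self.momentum >= self.momentum_warmup:
            return min(self.momentum, curr_momentum)
        return max(self.momentum, curr_momentum)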