from collections.abc import Mapping

from ignite.engine import Engine, State


def test_state_dict_integration():
    engine = Engine(lambda e, b: 1)
    data = range(100)
    engine.run(data, max_epochs=10)
    sd = engine.state_dict()
    # the state dict holds the required keys (epoch_length, max_epochs)
    # plus one of iteration/epoch
    assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
    assert sd["iteration"] == engine.state.iteration == 10 * 100
    assert sd["epoch_length"] == engine.state.epoch_length == 100
    assert sd["max_epochs"] == engine.state.max_epochs == 10
# Variant of the same test for an earlier Engine API whose run() accepted a seed
# and whose state dict stored it as a fourth entry.
def test_state_dict_integration():
    engine = Engine(lambda e, b: 1)
    data = list(range(100))
    engine.run(data, max_epochs=10, seed=17)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 4
    assert sd["seed"] == engine.state.seed
    assert sd["iteration"] == engine.state.iteration == 10 * 100
    assert sd["epoch_length"] == engine.state.epoch_length == 100
    assert sd["max_epochs"] == engine.state.max_epochs == 10
def test_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 0

    def _test(state):
        engine.state = state
        sd = engine.state_dict()
        assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
        assert sd["iteration"] == engine.state.iteration
        assert sd["epoch_length"] == engine.state.epoch_length
        assert sd["max_epochs"] == engine.state.max_epochs

    _test(State(iteration=500, epoch_length=1000, max_epochs=100))
    _test(State(epoch=5, epoch_length=1000, max_epochs=100))
# Seed-era variant of test_state_dict: the state dict additionally carries "seed".
def test_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    assert isinstance(sd, Mapping) and len(sd) == 0

    def _test(state):
        engine.state = state
        sd = engine.state_dict()
        assert isinstance(sd, Mapping) and len(sd) == len(engine._state_dict_all_req_keys) + 1
        assert sd["seed"] == engine.state.seed
        assert sd["iteration"] == engine.state.iteration
        assert sd["epoch_length"] == engine.state.epoch_length
        assert sd["max_epochs"] == engine.state.max_epochs

    _test(State(seed=0, iteration=500, epoch_length=1000, max_epochs=100))
    _test(State(seed=0, epoch=5, epoch_length=1000, max_epochs=100))
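# The tests above pin down the round-trip contract: whatever state_dict() returns,
# load_state_dict() must accept. A minimal sketch of that round trip, assuming the
# trivial update function used in the tests; the checkpoint file name
# "engine_state.pt" is purely illustrative.
import torch

engine = Engine(lambda e, b: 1)
engine.run(range(100), max_epochs=10)
torch.save(engine.state_dict(), "engine_state.pt")

# later: restore the counters into a fresh engine
restored = Engine(lambda e, b: 1)
restored.load_state_dict(torch.load("engine_state.pt"))
assert restored.state_dict() == engine.state_dict()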
# Handler for a Ray Tune trainable; `evaluator`, `val_loader`, `model`,
# `optimizer` and `epochs` are taken from the enclosing script's scope.
def post_epoch_actions(trainer_instance: Engine):
    # evaluate model on validation set
    evaluator.run(val_loader)
    state_val_metrics = evaluator.state.metrics

    current_epoch: int = trainer_instance.state.epoch
    with tune.checkpoint_dir(current_epoch) as local_checkpoint_dir:
        # save model, optimizer, trainer and evaluator checkpoints
        path = os.path.join(local_checkpoint_dir, "checkpoint")
        torch.save(
            (model.state_dict(), optimizer.state_dict(), trainer_instance.state_dict(), evaluator.state_dict()),
            path,
        )

    # report validation scores to ray-tune
    report_dict: dict = {**state_val_metrics, "done": current_epoch == epochs}
    tune.report(**report_dict)
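# Counterpart sketch for restoring that checkpoint when Tune restarts a trial.
# The tuple unpacking mirrors the torch.save call above; `checkpoint_dir` is the
# directory Tune hands back, and `model`, `optimizer`, `trainer` and `evaluator`
# are assumed to be rebuilt exactly as when the checkpoint was written.
def restore_from_checkpoint(checkpoint_dir: str):
    path = os.path.join(checkpoint_dir, "checkpoint")
    model_sd, optimizer_sd, trainer_sd, evaluator_sd = torch.load(path)
    model.load_state_dict(model_sd)
    optimizer.load_state_dict(optimizer_sd)
    trainer.load_state_dict(trainer_sd)
    evaluator.load_state_dict(evaluator_sd)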
# Method of ignite's LR-finder handler class; module-level imports (contextlib,
# tempfile, pathlib.Path, torch, ignite's Checkpoint and Events) are omitted here.
@contextlib.contextmanager
def attach(
    self,
    trainer: Engine,
    to_save: Mapping,
    output_transform: Callable = lambda output: output,
    num_iter: Optional[int] = None,
    end_lr: float = 10.0,
    step_mode: str = "exp",
    smooth_f: float = 0.05,
    diverge_th: float = 5.0,
):
    """Attaches lr_finder to a given trainer. It also resets the model and the optimizer
    at the end of the run.

    Usage:

    .. code-block:: python

        to_save = {"model": model, "optimizer": optimizer}
        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
            trainer_with_lr_finder.run(dataloader)

    Args:
        trainer (Engine): lr_finder is attached to this trainer. Please, keep in mind that
            all attached handlers will be executed.
        to_save (Mapping): dictionary with the optimizer and other objects that need to be
            restored after running the LR finder. For example,
            `to_save={'optimizer': optimizer, 'model': model}`. All objects should implement
            `state_dict` and `load_state_dict` methods.
        output_transform (callable, optional): function that transforms the trainer's
            `state.output` after each iteration. It must return the loss of that iteration.
        num_iter (int, optional): number of iterations for the lr schedule between base lr
            and `end_lr`. By default, it will run for
            `trainer.state.epoch_length * trainer.state.max_epochs` iterations.
        end_lr (float, optional): upper bound for the lr search. Default, 10.0.
        step_mode (str, optional): "exp" or "linear", which way the lr should be increased
            from the optimizer's initial lr to `end_lr`. Default, "exp".
        smooth_f (float, optional): loss smoothing factor in the range `[0, 1)`. Default, 0.05.
        diverge_th (float, optional): used to stop the search when
            `current loss > diverge_th * best_loss`. Default, 5.0.

    Note:
        lr_finder cannot be attached to more than one trainer at a time.

    Returns:
        trainer_with_lr_finder: trainer used for finding the lr
    """
    if not isinstance(to_save, Mapping):
        raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

    Checkpoint._check_objects(to_save, "state_dict")
    Checkpoint._check_objects(to_save, "load_state_dict")

    if "optimizer" not in to_save:
        raise ValueError("Mapping to_save should contain 'optimizer' key")

    if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
        raise TypeError(
            "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
        )

    if smooth_f < 0 or smooth_f >= 1:
        raise ValueError("smooth_f is outside the range [0, 1)")
    if diverge_th < 1:
        raise ValueError("diverge_th should be larger than 1")
    if step_mode not in ["exp", "linear"]:
        raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
    if num_iter is not None:
        if not isinstance(num_iter, int):
            raise TypeError("if provided, num_iter should be an integer, but given {}".format(num_iter))
        if num_iter <= 0:
            raise ValueError("if provided, num_iter should be positive, but given {}".format(num_iter))

    # store to_save
    with tempfile.TemporaryDirectory() as tmpdirname:
        obj = {k: o.state_dict() for k, o in to_save.items()}
        # add trainer
        obj["trainer"] = trainer.state_dict()
        cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
        torch.save(obj, cache_filepath.as_posix())

        optimizer = to_save["optimizer"]
        # Attach handlers
        if not trainer.has_event_handler(self._run):
            trainer.add_event_handler(
                Events.STARTED,
                self._run,
                optimizer,
                output_transform,
                num_iter,
                end_lr,
                step_mode,
                smooth_f,
                diverge_th,
            )
        if not trainer.has_event_handler(self._warning):
            trainer.add_event_handler(Events.COMPLETED, self._warning)
        if not trainer.has_event_handler(self._reset):
            trainer.add_event_handler(Events.COMPLETED, self._reset)

        yield trainer
        self._detach(trainer)
        # restore to_save and reset trainer's state
        obj = torch.load(cache_filepath.as_posix())
        trainer.load_state_dict(obj["trainer"])
        for k, o in obj.items():
            if k in to_save:
                to_save[k].load_state_dict(o)
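# A usage sketch expanding the docstring example, assuming ignite's contrib
# FastaiLRFinder (the handler class this attach() method is assumed to belong to)
# and its lr_suggestion() helper; `model`, `optimizer`, `trainer` and `dataloader`
# come from the surrounding training script.
from ignite.contrib.handlers import FastaiLRFinder

lr_finder = FastaiLRFinder()
to_save = {"model": model, "optimizer": optimizer}

with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
    trainer_with_lr_finder.run(dataloader)

# on exit, trainer, model and optimizer are restored from the cached state dicts
print(lr_finder.lr_suggestion())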
def test_empty_state_dict_load_state_dict():
    engine = Engine(lambda e, b: 1)
    sd = engine.state_dict()
    engine.load_state_dict(sd)
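# Beyond the empty round trip, load_state_dict() is what enables mid-run
# resumption: seed a fresh engine with saved counters and run() continues from
# there. A sketch under the same trivial update function; per ignite's
# load_state_dict requirements, the dict must contain "max_epochs",
# "epoch_length" and one of "iteration" or "epoch".
engine = Engine(lambda e, b: 1)
engine.load_state_dict({"max_epochs": 10, "epoch_length": 100, "iteration": 500})
engine.run(range(100))  # resumes from iteration 500, i.e. the start of epoch 6
assert engine.state.iteration == 10 * 100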