def _test(epoch_length=None):
    """Resume from every intermediate epoch and verify the batches seen afterwards.

    For each ``resume_epoch`` in ``[1, max_epochs)`` a fresh ``Engine`` is restored via
    ``load_state_dict`` (``epoch=resume_epoch``) and then run on ``data``. Every batch
    passed to ``update_fn`` is checked by a ``BatchChecker`` whose counter starts at
    ``resume_epoch * epoch_length``. After the run the engine must report the final
    epoch and iteration.

    Args:
        epoch_length: iterations per epoch; defaults to ``num_iters`` (the length of ``data``).
    """
    max_epochs = 10
    num_iters = 21
    data = torch.randint(0, 1000, size=(num_iters, ))
    if epoch_length is None:
        epoch_length = num_iters

    for resume_epoch in range(1, max_epochs):
        batch_checker = BatchChecker(data, init_counter=resume_epoch * epoch_length)

        def update_fn(engine, batch):
            assert batch_checker.check(batch), \
                "{} | {}: {} vs {}".format(
                    resume_epoch, batch_checker.counter, batch_checker.true_batch, batch)

        engine = Engine(update_fn)
        resume_state_dict = dict(epoch=resume_epoch, max_epochs=max_epochs, epoch_length=epoch_length, seed=0)
        engine.load_state_dict(resume_state_dict)
        # Actually run the restored engine: update_fn's batch checks only fire during
        # the run, and the assertions below describe a completed run.
        engine.run(data)
        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == epoch_length * max_epochs
def _test(epoch_length=None):
    """Resume from many intermediate iterations and verify the batches seen afterwards.

    Every 4th iteration in ``[1, min(num_iters, epoch_length) * max_epochs)`` is used as a
    resume point. A fresh ``Engine`` is restored via ``load_state_dict`` and run on ``data``.
    Each batch is checked by a ``BatchChecker`` whose counter starts at the resume
    iteration. At every EPOCH_COMPLETED the engine iteration must equal the checker's
    counter.

    Args:
        epoch_length: iterations per epoch; defaults to ``num_iters`` (the length of ``data``).
    """
    max_epochs = 5
    num_iters = 21
    data = torch.randint(0, 1000, size=(num_iters, ))
    if epoch_length is None:
        epoch_length = num_iters

    for resume_iteration in range(1, min(num_iters * max_epochs, epoch_length * max_epochs), 4):
        batch_checker = BatchChecker(data, init_counter=resume_iteration)

        def update_fn(engine, batch):
            assert batch_checker.check(batch), \
                "{} | {}: {} vs {}".format(
                    resume_iteration, batch_checker.counter, batch_checker.true_batch, batch)

        engine = Engine(update_fn)

        @engine.on(Events.EPOCH_COMPLETED)
        def check_iteration(engine):
            assert engine.state.iteration == batch_checker.counter

        resume_state_dict = {
            "iteration": resume_iteration,
            "max_epochs": max_epochs,
            "epoch_length": epoch_length,
            "seed": 0,
        }
        engine.load_state_dict(resume_state_dict)
        # Run the restored engine; without it neither update_fn nor check_iteration
        # is ever invoked and the final assertions cannot hold.
        engine.run(data)
        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == epoch_length * max_epochs
def test_load_state_dict_asserts():
    """Malformed inputs to ``Engine.load_state_dict`` raise the expected errors.

    Covers a non-dict argument, an empty dict, dicts with neither or both of
    ``iteration``/``epoch``, and a missing user key registered in
    ``state_dict_user_keys``.
    """
    engine = Engine(lambda e, b: 1)

    with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"):
        engine.load_state_dict("123")

    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
        engine.load_state_dict({})

    # First dict has neither "iteration" nor "epoch", second has both.
    ambiguous_state_dicts = [
        {"max_epochs": 100, "epoch_length": 120},
        {"max_epochs": 100, "epoch_length": 120, "iteration": 12, "epoch": 123},
    ]
    for bad_state_dict in ambiguous_state_dicts:
        with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
            engine.load_state_dict(bad_state_dict)

    # A user key registered on the engine must be present in the loaded dict.
    keyed_engine = Engine(lambda e, b: 1)
    keyed_engine.state_dict_user_keys.append("alpha")
    with pytest.raises(ValueError, match=r"Required user state attribute"):
        keyed_engine.load_state_dict({"max_epochs": 100, "epoch_length": 120, "iteration": 12})
def test_load_state_dict_integration():
    """Restore an engine at epoch 5 (of 100, 120 iterations each) and run it.

    ``IterationCounter(5 * 120 + 1)`` and ``EpochCounter(6)`` are attached to
    ITERATION_COMPLETED / EPOCH_COMPLETED. Presumably they assert the iteration and
    epoch numbers counted from those starting values; confirm against their definitions.
    """
    engine = Engine(lambda e, b: 1)
    state_dict = {"max_epochs": 100, "epoch_length": 120, "epoch": 5}
    engine.load_state_dict(state_dict)
    engine.add_event_handler(Events.ITERATION_COMPLETED, IterationCounter(5 * 120 + 1))
    engine.add_event_handler(Events.EPOCH_COMPLETED, EpochCounter(6))
    data = range(120)
    # Without running, the counter handlers are never invoked and the test checks nothing.
    engine.run(data)
def test_load_state_dict_asserts():
    """Malformed inputs to ``Engine.load_state_dict`` raise the expected errors.

    Covers a non-dict argument, an empty dict, and dicts with neither or both of
    ``iteration``/``epoch``.
    """
    engine = Engine(lambda e, b: 1)

    with pytest.raises(TypeError, match=r"Argument state_dict should be a dictionary"):
        engine.load_state_dict("123")

    with pytest.raises(ValueError, match=r"is absent in provided state_dict"):
        engine.load_state_dict({})

    # Same base keys; the extras add neither or both of "iteration"/"epoch".
    base_keys = {"seed": 0, "max_epochs": 100, "epoch_length": 120}
    for extra_keys in ({}, {"iteration": 12, "epoch": 123}):
        with pytest.raises(ValueError, match=r"state_dict should contain only one of"):
            engine.load_state_dict({**base_keys, **extra_keys})
def _test(epoch_length=None):
    """Resume a run over an infinite iterator and compare against a reference run.

    A reference engine runs ``infinite_data_iterator()`` for ``max_epochs`` epochs with
    ``seed=12`` and records every batch. Then, every 7th iteration, a fresh engine is
    restored via ``load_state_dict`` and run on a new iterator. Each batch it sees is
    checked by a ``BatchChecker`` over the recorded batches, starting at the resume
    iteration.

    ``device`` is a free name, presumably supplied by the enclosing test; confirm.

    Args:
        epoch_length: iterations per epoch; defaults to ``num_iters``.
    """
    max_epochs = 3
    batch_size = 4
    num_iters = 17

    def infinite_data_iterator():
        # Reseeds on first use, so every new iterator yields the same stream.
        torch.manual_seed(0)
        while True:
            for _ in range(num_iters):
                data = torch.randint(0, 1000, size=(batch_size, ), device=device)
                yield data

    if epoch_length is None:
        epoch_length = num_iters

    # Reference run: record every batch an uninterrupted run sees.
    seen_batchs = []

    def record_batch(engine, batch):
        seen_batchs.append(batch)

    engine = Engine(record_batch)
    engine.run(infinite_data_iterator(), max_epochs=max_epochs, seed=12, epoch_length=epoch_length)

    for resume_iteration in range(1, min(num_iters * max_epochs, epoch_length * max_epochs), 7):
        batch_checker = BatchChecker(seen_batchs, init_counter=resume_iteration)

        def update_fn(engine, batch):
            assert batch_checker.check(batch), "{} | {}: {} vs {}".format(
                resume_iteration, batch_checker.counter, batch_checker.true_batch, batch)

        engine = Engine(update_fn)
        resume_state_dict = dict(iteration=resume_iteration, max_epochs=max_epochs,
                                 epoch_length=epoch_length, seed=12)
        engine.load_state_dict(resume_state_dict)
        engine.run(infinite_data_iterator())
        assert engine.state.epoch == max_epochs
        assert engine.state.iteration == epoch_length * max_epochs, "{} | {} vs {}".format(
            resume_iteration, engine.state.iteration, epoch_length * max_epochs)
def _test(with_load_state_dict=False):
    """Custom attributes set on ``engine.state`` must survive an engine run.

    ``alpha`` and ``beta`` are set before the run (and, optionally, before a
    ``load_state_dict`` call). They are asserted unchanged in the STARTED,
    EPOCH_STARTED, EPOCH_COMPLETED and COMPLETED handlers.

    Args:
        with_load_state_dict: if True, load ``{"iteration": 3, "max_epochs": 5,
            "epoch_length": 5}`` after setting the attributes and before the run.
    """
    engine = Engine(lambda e, b: None)
    engine.state.alpha = 0.0
    engine.state.beta = 1.0

    if with_load_state_dict:
        engine.load_state_dict({"iteration": 3, "max_epochs": 5, "epoch_length": 5})

    @engine.on(Events.STARTED | Events.EPOCH_STARTED | Events.EPOCH_COMPLETED | Events.COMPLETED)
    def check_custom_attr():
        assert hasattr(engine.state, "alpha") and engine.state.alpha == 0.0
        assert hasattr(engine.state, "beta") and engine.state.beta == 1.0

    engine.run([0, 1, 2, 3, 4], max_epochs=5)
def test_load_state_dict_integration(counter_factory):
    """Restore an engine at epoch 5 (of 100, 120 iterations each) and run it.

    ``counter_factory`` builds the 'iter' and 'epoch' handlers with starting values
    ``5 * 120 + 1`` and ``6``. Presumably they assert the counts observed after
    resuming; confirm against the fixture.
    """
    engine = Engine(lambda e, b: 1)
    state_dict = {
        "seed": 0,
        "max_epochs": 100,
        "epoch_length": 120,
        "epoch": 5,
    }
    engine.load_state_dict(state_dict)
    engine.add_event_handler(Events.ITERATION_COMPLETED, counter_factory('iter', 5 * 120 + 1))
    engine.add_event_handler(Events.EPOCH_COMPLETED, counter_factory('epoch', 6))
    data = list(range(120))
    # Without running, the counter handlers are never invoked and the test checks nothing.
    engine.run(data)
def test_load_state_dict_with_params_overriding_integration():
    """Arguments passed to ``run`` after ``load_state_dict`` interact with the loaded state.

    - A larger ``max_epochs`` overrides the loaded value and the run finishes there.
    - A ``max_epochs`` below the loaded start epoch raises ``ValueError``.
    - An ``epoch_length`` different from the loaded one raises ``ValueError``.
    """
    state_dict = {"max_epochs": 100, "epoch_length": 120, "epoch": 5}
    data = range(120)

    # Override max_epochs
    new_max_epochs = 10
    engine = Engine(lambda e, b: 1)
    engine.load_state_dict(state_dict)
    state = engine.run(data, max_epochs=new_max_epochs)
    assert state.max_epochs == new_max_epochs
    assert state.iteration == state_dict["epoch_length"] * new_max_epochs
    assert state.epoch == new_max_epochs

    with pytest.raises(ValueError, match=r"Argument max_epochs should be larger than the start epoch"):
        engine.load_state_dict(state_dict)
        engine.run(data, max_epochs=3)

    # Override epoch_length
    with pytest.raises(ValueError, match=r"Argument epoch_length should be same as in the state"):
        engine.load_state_dict(state_dict)
        engine.run(data, epoch_length=90)
def attach(
    self,
    trainer: Engine,
    to_save: Mapping,
    output_transform: Callable = lambda output: output,
    num_iter: Optional[int] = None,
    end_lr: float = 10.0,
    step_mode: str = "exp",
    smooth_f: float = 0.05,
    diverge_th: float = 5.0,
):
    """Attaches lr_finder to a given trainer. It also resets model and optimizer at the end of the run.

    Usage:

    .. code-block:: python

        to_save = {"model": model, "optimizer": optimizer}
        with lr_finder.attach(trainer, to_save=to_save) as trainer_with_lr_finder:
            trainer_with_lr_finder.run(dataloader)

    Args:
        trainer (Engine): lr_finder is attached to this trainer. Please, keep in mind that all attached handlers
            will be executed.
        to_save (Mapping): dictionary with optimizer and other objects that needs to be restored after running
            the LR finder. For example, `to_save={'optimizer': optimizer, 'model': model}`. All objects should
            implement `state_dict` and `load_state_dict` methods.
        output_transform (callable, optional): function that transforms the trainer's `state.output` after each
            iteration. It must return the loss of that iteration.
        num_iter (int, optional): number of iterations for lr schedule between base lr and end_lr. Default, it
            will run for `trainer.state.epoch_length * trainer.state.max_epochs`.
        end_lr (float, optional): upper bound for lr search. Default, 10.0.
        step_mode (str, optional): "exp" or "linear", which way should the lr be increased from optimizer's
            initial lr to `end_lr`. Default, "exp".
        smooth_f (float, optional): loss smoothing factor in range `[0, 1)`. Default, 0.05
        diverge_th (float, optional): Used for stopping the search when `current loss > diverge_th * best_loss`.
            Default, 5.0.

    Note:
        lr_finder cannot be attached to more than one trainer at a time.
        This is a generator (it ``yield``s the trainer). Presumably it is wrapped with
        ``contextlib.contextmanager`` where it is defined; confirm.

    Returns:
        trainer_with_lr_finder: trainer used for finding the lr

    Raises:
        TypeError: if `to_save` is not a mapping, `to_save["optimizer"]` is not a torch optimizer,
            or `num_iter` is given and is not an int.
        ValueError: if `to_save` has no "optimizer" key, `smooth_f` is outside `[0, 1)`,
            `diverge_th < 1`, `step_mode` is not "exp"/"linear", or `num_iter <= 0`.
    """
    if not isinstance(to_save, Mapping):
        raise TypeError("Argument to_save should be a mapping, but given {}".format(type(to_save)))

    Checkpoint._check_objects(to_save, "state_dict")
    Checkpoint._check_objects(to_save, "load_state_dict")

    if "optimizer" not in to_save:
        raise ValueError("Mapping to_save should contain 'optimizer' key")

    if not isinstance(to_save["optimizer"], torch.optim.Optimizer):
        raise TypeError(
            "Object to_save['optimizer'] should be torch optimizer, but given {}".format(type(to_save["optimizer"]))
        )

    # NOTE(review): the accepted range is [0, 1) but the message says [0, 1]; left
    # unchanged because existing tests may match on the exact text.
    if smooth_f < 0 or smooth_f >= 1:
        raise ValueError("smooth_f is outside the range [0, 1]")
    if diverge_th < 1:
        raise ValueError("diverge_th should be larger than 1")
    if step_mode not in ["exp", "linear"]:
        raise ValueError("step_mode should be 'exp' or 'linear', but given {}".format(step_mode))
    if num_iter is not None:
        if not isinstance(num_iter, int):
            raise TypeError("if provided, num_iter should be an integer, but given {}".format(num_iter))
        if num_iter <= 0:
            raise ValueError("if provided, num_iter should be positive, but given {}".format(num_iter))

    # Snapshot every object in to_save plus the trainer so they can be restored afterwards.
    with tempfile.TemporaryDirectory() as tmpdirname:
        obj = {k: o.state_dict() for k, o in to_save.items()}
        # add trainer
        obj["trainer"] = trainer.state_dict()
        cache_filepath = Path(tmpdirname) / "ignite_lr_finder_cache.pt"
        torch.save(obj, cache_filepath.as_posix())

        optimizer = to_save["optimizer"]
        # Attach handlers (only once, even if attach is re-entered with the same trainer)
        if not trainer.has_event_handler(self._run):
            trainer.add_event_handler(
                Events.STARTED,
                self._run,
                optimizer,
                output_transform,
                num_iter,
                end_lr,
                step_mode,
                smooth_f,
                diverge_th,
            )
        if not trainer.has_event_handler(self._warning):
            trainer.add_event_handler(Events.COMPLETED, self._warning)
        if not trainer.has_event_handler(self._reset):
            trainer.add_event_handler(Events.COMPLETED, self._reset)

        try:
            yield trainer
        finally:
            # Runs even if the caller's `with` body raises, so handlers are always
            # detached and model/optimizer/trainer state is always restored.
            self._detach(trainer)
            # restore to_save and reset trainer's state
            obj = torch.load(cache_filepath.as_posix())
            trainer.load_state_dict(obj["trainer"])
            for k, o in obj.items():
                if k in to_save:
                    to_save[k].load_state_dict(o)
def test_empty_state_dict_load_state_dict():
    """A state_dict taken from a never-run engine can be loaded straight back into it."""
    engine = Engine(lambda e, b: 1)
    engine.load_state_dict(engine.state_dict())
def _test(epoch_length=None):
    """Resume a DataLoader-driven run from many iterations and compare batches.

    For ``num_workers`` in ``[0, 4]``:

    - A reference engine runs ``orig_dataloader`` with ``seed=12`` and records every batch.
    - Every 7th iteration is then used as a resume point. A fresh engine is restored via
      ``load_state_dict`` and run on a freshly built DataLoader.
    - Each batch it sees is checked by a ``BatchChecker`` over the recorded batches.

    ``device``, ``sampler_type`` and ``_setup_sampler`` are free names, presumably
    supplied by the enclosing test; confirm.

    Args:
        epoch_length: iterations per epoch; defaults to ``num_iters``.
    """
    max_epochs = 3
    batch_size = 4
    num_iters = 17
    data = torch.randint(0, 1000, size=(num_iters * batch_size, ))

    if epoch_length is None:
        epoch_length = num_iters

    for num_workers in [0, 4]:
        sampler = _setup_sampler(sampler_type, num_iters, batch_size)
        orig_dataloader = torch.utils.data.DataLoader(
            data,
            batch_size=batch_size,
            num_workers=num_workers,
            pin_memory="cuda" in device,
            sampler=sampler,
            drop_last=True,
            shuffle=sampler is None,
        )

        seen_batchs = []

        def update_fn(engine, batch):
            # Exercise the device transfer; the original `batch` (not the moved copy) is recorded.
            batch.to(device)
            seen_batchs.append(batch)

        engine = Engine(update_fn)

        if sampler_type == "distributed":

            @engine.on(Events.EPOCH_STARTED)
            def _(engine):
                sampler.set_epoch(engine.state.epoch)

        engine.run(orig_dataloader, max_epochs=max_epochs, seed=12, epoch_length=epoch_length)

        for resume_iteration in range(1, min(num_iters * max_epochs, epoch_length * max_epochs), 7):
            batch_checker = BatchChecker(seen_batchs, init_counter=resume_iteration)

            sampler = _setup_sampler(sampler_type, num_iters, batch_size)
            resume_dataloader = torch.utils.data.DataLoader(
                data,
                batch_size=batch_size,
                num_workers=num_workers,
                pin_memory="cuda" in device,
                sampler=sampler,
                drop_last=True,
                shuffle=sampler is None,
            )

            def update_fn(engine, batch):
                batch.to(device)
                assert batch_checker.check(batch), \
                    "{} {} | {}: {} vs {}".format(
                        num_workers, resume_iteration, batch_checker.counter, batch_checker.true_batch, batch)

            engine = Engine(update_fn)

            if sampler_type == "distributed":

                @engine.on(Events.EPOCH_STARTED)
                def _(engine):
                    sampler.set_epoch(engine.state.epoch)

            resume_state_dict = dict(iteration=resume_iteration, max_epochs=max_epochs,
                                     epoch_length=epoch_length, seed=12)
            engine.load_state_dict(resume_state_dict)
            engine.run(resume_dataloader)
            assert engine.state.epoch == max_epochs
            assert engine.state.iteration == epoch_length * max_epochs, \
                "{}, {} | {} vs {}".format(num_workers, resume_iteration, engine.state.iteration,
                                           epoch_length * max_epochs)