def _setup_common_distrib_training_handlers(
    trainer,
    train_sampler=None,
    to_save=None,
    save_every_iters=1000,
    output_path=None,
    lr_scheduler=None,
    with_gpu_stats=True,
    output_names=None,
    with_pbars=True,
    with_pbar_on_iters=True,
    log_every_iters=100,
    device="cuda",
):
    if not (dist.is_available() and dist.is_initialized()):
        raise RuntimeError(
            "Distributed setting is not initialized, please call `dist.init_process_group` before."
        )

    _setup_common_training_handlers(
        trainer,
        to_save=None,
        lr_scheduler=lr_scheduler,
        with_gpu_stats=with_gpu_stats,
        output_names=output_names,
        with_pbars=(dist.get_rank() == 0) and with_pbars,
        with_pbar_on_iters=with_pbar_on_iters,
        log_every_iters=log_every_iters,
        device=device,
    )

    if train_sampler is not None:
        if not callable(getattr(train_sampler, "set_epoch", None)):
            raise TypeError("Train sampler should have `set_epoch` method")

        @trainer.on(Events.EPOCH_STARTED)
        def distrib_set_epoch(engine):
            train_sampler.set_epoch(engine.state.epoch - 1)

    if dist.get_rank() == 0:
        if to_save is not None:
            if output_path is None:
                raise ValueError(
                    "If to_save argument is provided then output_path argument should be also defined"
                )
            checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="training", require_empty=False)
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler, to_save
            )
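# Hypothetical usage sketch (not part of the original source): how the distributed
# setup helper above might be wired into a per-process training script. The model,
# optimizer, data and output path are placeholders, and `dist.init_process_group`
# is assumed to have been called already (e.g. via torchrun) before this runs.
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler
from ignite.engine import Engine


def _sketch_distributed_training(local_rank):
    device = torch.device("cuda", local_rank)
    model = torch.nn.Linear(10, 2).to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
    sampler = DistributedSampler(dataset)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    def train_step(engine, batch):
        x, y = batch[0].to(device), batch[1].to(device)
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(model(x), y)
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(train_step)
    _setup_common_distrib_training_handlers(
        trainer,
        train_sampler=sampler,            # `set_epoch` is called on each EPOCH_STARTED
        to_save={"model": model, "optimizer": optimizer},
        save_every_iters=100,             # rank 0 checkpoints every 100 iterations
        output_path="/tmp/checkpoints",   # placeholder output directory
        with_gpu_stats=False,             # avoid the pynvml dependency in this sketch
        device="cuda",
    )
    trainer.run(loader, max_epochs=2)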
def setup_any_logging(logger, logger_module, trainer, optimizers, evaluators, log_every_iters):
    if optimizers is not None:
        from torch.optim.optimizer import Optimizer

        if not isinstance(optimizers, (Optimizer, Mapping)):
            raise TypeError("Argument optimizers should be either a single optimizer or a dictionary of optimizers")

    if evaluators is not None:
        if not isinstance(evaluators, (Engine, Mapping)):
            raise TypeError("Argument evaluators should be either a single engine or a dictionary of evaluators")

    if log_every_iters is None:
        log_every_iters = 1

    logger.attach(
        trainer,
        log_handler=logger_module.OutputHandler(tag="training", metric_names="all"),
        event_name=Events.ITERATION_COMPLETED(every=log_every_iters),
    )

    if optimizers is not None:
        # Log optimizer parameters
        if isinstance(optimizers, Optimizer):
            optimizers = {None: optimizers}

        for k, optimizer in optimizers.items():
            logger.attach(
                trainer,
                log_handler=logger_module.OptimizerParamsHandler(optimizer, param_name="lr", tag=k),
                event_name=Events.ITERATION_STARTED(every=log_every_iters),
            )

    if evaluators is not None:
        # Log evaluation metrics
        if isinstance(evaluators, Engine):
            evaluators = {"validation": evaluators}

        for k, evaluator in evaluators.items():
            gst = global_step_from_engine(trainer)
            logger.attach(
                evaluator,
                log_handler=logger_module.OutputHandler(tag=k, metric_names="all", global_step_transform=gst),
                event_name=Events.COMPLETED,
            )
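# Hypothetical usage sketch (not part of the original source): attaching the generic
# logging helper above to ignite's TensorBoard logger module. The engines, optimizer
# and log directory are placeholders.
import torch
from ignite.engine import Engine
import ignite.contrib.handlers.tensorboard_logger as tb_module


def _sketch_setup_tensorboard_logging():
    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    trainer = Engine(lambda engine, batch: {"loss": 0.0})
    evaluator = Engine(lambda engine, batch: None)

    tb_logger = tb_module.TensorboardLogger(log_dir="/tmp/tb-logs")  # placeholder directory
    setup_any_logging(
        tb_logger,
        tb_module,                             # module providing OutputHandler / OptimizerParamsHandler
        trainer,
        optimizers=optimizer,                  # a single optimizer or a dict of optimizers
        evaluators={"validation": evaluator},  # a single Engine or a dict of evaluators
        log_every_iters=100,
    )
    return tb_logger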
def test_pbar_wrong_events_order():

    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.ITERATION_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.ITERATION_COMPLETED, closing_event_name=Events.ITERATION_STARTED)

    with pytest.raises(ValueError, match="Closing event should not use any event filter"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
def test_lr_suggestion_mnist(lr_finder, mnist_to_save, dummy_engine_mnist, mnist_dataloader):

    max_iters = 50

    with lr_finder.attach(dummy_engine_mnist, mnist_to_save) as trainer_with_finder:
        with trainer_with_finder.add_event_handler(
            Events.ITERATION_COMPLETED(once=max_iters), lambda _: trainer_with_finder.terminate()
        ):
            trainer_with_finder.run(mnist_dataloader)

    assert 1e-4 <= lr_finder.lr_suggestion() <= 2
def create_vae_engines(
    model,
    optimizer,
    criterion=None,
    metrics=None,
    device=None,
    non_blocking=False,
    fig_dir=None,
    unflatten=None,
):
    device = model.device
    if criterion is None:
        criterion = get_default_autoencoder_loss()

    train_step = create_vae_train_step(model, optimizer, criterion, device=device, non_blocking=non_blocking)
    eval_step = create_vae_eval_step(model, device=device, non_blocking=non_blocking)

    if metrics is None:
        metrics = {}
    metrics.setdefault(
        "loss",
        Loss(criterion, output_transform=loss_eval_output_transform),
    )
    metrics.setdefault("mse", MeanSquaredError(output_transform=lambda x: x[:2]))

    trainer = Engine(train_step)
    evaluator = create_autoencoder_evaluator(eval_step, metrics=metrics)

    save_image_callback = create_save_image_callback(fig_dir, unflatten=unflatten)

    def _epoch_getter():
        return trainer.state.__dict__.get("epoch", None)

    evaluator.add_event_handler(
        Events.ITERATION_COMPLETED(once=1),
        save_image_callback,
        epoch=_epoch_getter,
    )

    val_log_handler, val_logger = create_log_handler(trainer)

    return trainer, evaluator, val_log_handler, val_logger
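# Hypothetical usage sketch (not part of the original source): driving the engines
# returned by the factory above. The model, optimizer and loaders are placeholders
# supplied by the caller; the model is assumed to expose a `device` attribute,
# since the factory reads `model.device`.
from ignite.engine import Events


def _sketch_run_vae_training(vae_model, optimizer, train_loader, val_loader):
    trainer, evaluator, val_log_handler, val_logger = create_vae_engines(vae_model, optimizer)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _run_validation(engine):
        # evaluate once per epoch; "loss" and "mse" land in evaluator.state.metrics
        evaluator.run(val_loader)

    trainer.run(train_loader, max_epochs=10)
    return evaluator.state.metrics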
def test_as_context_manager():

    n_epochs = 5
    data = list(range(50))

    class _DummyLogger(DummyLogger):
        def __init__(self, writer):
            self.writer = writer

        def close(self):
            self.writer.close()

    def _test(event, n_calls):
        global close_counter
        close_counter = 0

        losses = torch.rand(n_epochs * len(data))
        losses_iter = iter(losses)

        def update_fn(engine, batch):
            return next(losses_iter)

        writer = MagicMock()
        writer.close = MagicMock()

        with _DummyLogger(writer) as logger:
            assert isinstance(logger, _DummyLogger)
            trainer = Engine(update_fn)
            mock_log_handler = MagicMock()

            logger.attach(trainer, log_handler=mock_log_handler, event_name=event)

            trainer.run(data, max_epochs=n_epochs)

            mock_log_handler.assert_called_with(trainer, logger, event)
            assert mock_log_handler.call_count == n_calls

        writer.close.assert_called_once_with()

    _test(Events.ITERATION_STARTED, len(data) * n_epochs)
    _test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
    _test(Events.EPOCH_STARTED, n_epochs)
    _test(Events.EPOCH_COMPLETED, n_epochs)
    _test(Events.STARTED, 1)
    _test(Events.COMPLETED, 1)

    _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)
def test_callable_events_with_wrong_inputs():

    with pytest.raises(ValueError, match=r"Only one of the input arguments should be specified"):
        Events.ITERATION_STARTED()

    with pytest.raises(ValueError, match=r"Only one of the input arguments should be specified"):
        Events.ITERATION_STARTED(event_filter="123", every=12)

    with pytest.raises(TypeError, match=r"Argument event_filter should be a callable"):
        Events.ITERATION_STARTED(event_filter="123")

    with pytest.raises(ValueError, match=r"Argument every should be integer and greater than one"):
        Events.ITERATION_STARTED(every=-1)

    with pytest.raises(ValueError, match=r"but will be called with"):
        Events.ITERATION_STARTED(event_filter=lambda x: x)
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)

    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
def test_attach():

    n_epochs = 5
    data = list(range(50))

    def _test(event, n_calls, kwargs={}):

        losses = torch.rand(n_epochs * len(data))
        losses_iter = iter(losses)

        def update_fn(engine, batch):
            return next(losses_iter)

        trainer = Engine(update_fn)

        logger = DummyLogger()

        mock_log_handler = MagicMock()

        logger.attach(trainer, log_handler=mock_log_handler, event_name=event, **kwargs)

        trainer.run(data, max_epochs=n_epochs)

        if isinstance(event, EventsList):
            events = [e for e in event]
        else:
            events = [event]

        if len(kwargs) > 0:
            calls = [call(trainer, logger, e, **kwargs) for e in events]
        else:
            calls = [call(trainer, logger, e) for e in events]

        mock_log_handler.assert_has_calls(calls)
        assert mock_log_handler.call_count == n_calls

    _test(Events.ITERATION_STARTED, len(data) * n_epochs, kwargs={"a": 0})
    _test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
    _test(Events.EPOCH_STARTED, n_epochs)
    _test(Events.EPOCH_COMPLETED, n_epochs)
    _test(Events.STARTED, 1)
    _test(Events.COMPLETED, 1)

    _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)

    _test(Events.STARTED | Events.COMPLETED, 2)
def run(subj_ind: int, result_name: str, dataset_path: str, deep4_path: str, result_path: str,
        config: dict = default_config, model_builder: ProgressiveModelBuilder = default_model_builder):
    result_path_subj = os.path.join(result_path, result_name, str(subj_ind))
    os.makedirs(result_path_subj, exist_ok=True)

    joblib.dump(config, os.path.join(result_path_subj, 'config.dict'), compress=False)
    joblib.dump(model_builder, os.path.join(result_path_subj, 'model_builder.jblb'), compress=True)

    # create discriminator and generator modules
    discriminator = model_builder.build_discriminator()
    generator = model_builder.build_generator()

    # initiate weights
    generator.apply(weight_filler)
    discriminator.apply(weight_filler)

    # trainer engine
    trainer = GanSoftplusTrainer(10, discriminator, generator, config['r1_gamma'], config['r2_gamma'])

    # handles potential progression after each epoch
    progression_handler = ProgressionHandler(
        discriminator, generator, config['n_stages'], config['use_fade'],
        config['n_epochs_fade'], freeze_stages=config['freeze_stages'])
    progression_handler.set_progression(0, 1.)
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), progression_handler.advance_alpha)

    generator.train()
    discriminator.train()

    train(subj_ind, dataset_path, deep4_path, result_path_subj, progression_handler,
          trainer, config['n_batch'], config['lr_d'], config['lr_g'], config['betas'],
          config['n_epochs_per_stage'], config['n_epochs_metrics'], config['plot_every_epoch'],
          config['orig_fs'])
def train_epochs(self, max_epochs):
    self.trainer = Engine(self.train_one_step)
    self.evaluator = Engine(self.evaluate_one_step)
    self.metrics = {'Loss': Loss(self.criterion), 'Acc': Accuracy()}
    for name, metric in self.metrics.items():
        metric.attach(self.evaluator, name)

    with SummaryWriter(
            log_dir="/tmp/tensorboard/Transform" + str(type(self))[17:len(str(type(self))) - 2]) as writer:

        @self.trainer.on(Events.EPOCH_COMPLETED(every=1))  # every epoch
        def log_results(engine):
            # Evaluate on the training set
            self.eval()
            self.evaluator.run(self.train_loader)
            writer.add_scalar("train/loss", self.evaluator.state.metrics['Loss'], engine.state.epoch)
            writer.add_scalar("train/accy", self.evaluator.state.metrics['Acc'], engine.state.epoch)

            # Evaluate on the validation set
            self.evaluator.run(self.valid_loader)
            writer.add_scalar("valid/loss", self.evaluator.state.metrics['Loss'], engine.state.epoch)
            writer.add_scalar("valid/accy", self.evaluator.state.metrics['Acc'], engine.state.epoch)
            self.train()

        # Keep the best model according to validation loss
        best_model_handler = ModelCheckpoint(
            dirname='.',
            require_empty=False,
            filename_prefix="best",
            n_saved=1,
            score_function=lambda engine: -engine.state.metrics['Loss'],
            score_name="val_loss")

        # The following runs every time the validation loop finishes;
        # this class is itself the nn.Module, so `self` is what gets checkpointed
        self.evaluator.add_event_handler(
            Events.COMPLETED,
            best_model_handler,
            {f'Transform{str(type(self))[17:len(str(type(self))) - 2]}': self})

        self.trainer.run(self.train_loader, max_epochs=max_epochs)
def test_remove_event_handler_on_callable_events():

    engine = Engine(lambda e, b: 1)

    def foo(e):
        pass

    assert not engine.has_event_handler(foo)

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    assert engine.has_event_handler(foo)
    engine.remove_event_handler(foo, Events.EPOCH_STARTED)
    assert not engine.has_event_handler(foo)

    def bar(e):
        pass

    engine.add_event_handler(Events.EPOCH_COMPLETED(every=3), bar)
    assert engine.has_event_handler(bar)
    engine.remove_event_handler(bar, Events.EPOCH_COMPLETED)
    assert not engine.has_event_handler(bar)

    with pytest.raises(TypeError, match=r"Argument event_name should not be a filtered event"):
        engine.remove_event_handler(bar, Events.EPOCH_COMPLETED(every=3))
def set_defaults(self, is_training=True):
    """Fill in the default events for training or evaluation specs."""
    if self.metrics is None:
        self.metrics = {}
    if self.plot_event == 'default':
        self.plot_event = None  # Events.EPOCH_COMPLETED

    if is_training:
        # Log and print every 100 training iterations
        if self.log_event == 'default':
            self.log_event = Events.ITERATION_COMPLETED(every=100)
        if self.print_event == 'default':
            self.print_event = Events.ITERATION_COMPLETED(every=100)
        if self.print_fmt == 'default':
            self.print_fmt = TRAIN_MESSAGE
    else:
        # Log and print at the end of each evaluation
        if self.log_event == 'default':
            self.log_event = Events.EPOCH_COMPLETED
        if self.print_event == 'default':
            self.print_event = Events.EPOCH_COMPLETED
        if self.print_fmt == 'default':
            self.print_fmt = EVAL_MESSAGE
def test_ema_two_handlers(get_dummy_model):
    """Test when two EMA handlers are attached to a trainer"""
    model_1 = get_dummy_model()
    # momentum will be constantly 0.5
    ema_handler_1 = EMAHandler(model_1, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)

    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)

    def _step_fn(engine: Engine, batch: Any):
        model_1.weight.data.add_(1)
        model_2.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)
    assert not hasattr(engine.state, "ema_momentum_1")
    # handler_1 updates the EMA model of model_1 every iteration
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")

    # handler_2 updates the EMA model of model_2 every 2 iterations
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)

    ema_weight_1 = ema_handler_1.ema_model.weight.data
    ema_weight_2 = ema_handler_2.ema_model.weight.data
    torch.testing.assert_allclose(ema_weight_1, torch.full((1, 2), 4.0625))
    torch.testing.assert_allclose(ema_weight_2, torch.full((1, 2), 3.5))

    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    model_3 = get_dummy_model()
    ema_handler_3 = EMAHandler(model_3)
    with pytest.raises(ValueError, match="Please select another name"):
        ema_handler_3.attach(engine, "ema_momentum_2")
def test_mnist_lr_suggestion(lr_finder, mnist_model, mnist_optimizer, mnist_dataloader):
    criterion = nn.CrossEntropyLoss()
    trainer = create_supervised_trainer(mnist_model, mnist_optimizer, criterion)
    to_save = {"model": mnist_model, "optimizer": mnist_optimizer}
    max_iters = 50

    with lr_finder.attach(trainer, to_save) as trainer_with_finder:
        with trainer_with_finder.add_event_handler(
            Events.ITERATION_COMPLETED(once=max_iters), lambda _: trainer_with_finder.terminate()
        ):
            trainer_with_finder.run(mnist_dataloader)

    assert 1e-4 <= lr_finder.lr_suggestion() <= 10
def test_pbar_on_callable_events(capsys):

    n_epochs = 1
    loader = list(range(100))
    engine = Engine(update_fn)

    pbar = ProgressBar()
    pbar.attach(engine, event_name=Events.ITERATION_STARTED(every=10), closing_event_name=Events.EPOCH_COMPLETED)
    engine.run(loader, max_epochs=n_epochs)

    captured = capsys.readouterr()
    err = captured.err.split("\r")
    err = list(map(lambda x: x.strip(), err))
    err = list(filter(None, err))
    actual = err[-1]
    expected = "Iteration: [90/100] 90%|█████████ [00:00<00:00]"
    assert actual == expected
def set_handlers(trainer: Engine, evaluator: Engine, valloader: DataLoader,
                 model: nn.Module, optimizer: optim.Optimizer, args: Namespace) -> None:
    ROC_AUC(
        output_transform=lambda output: (output.logit, output.label)).attach(
            engine=evaluator, name='roc_auc')
    Accuracy(
        output_transform=lambda output: ((output.logit > 0).long(), output.label)).attach(
            engine=evaluator, name='accuracy')
    Loss(
        loss_fn=nn.BCEWithLogitsLoss(),
        output_transform=lambda output: (output.logit, output.label.float())).attach(
            engine=evaluator, name='loss')

    ProgressBar(persist=True, desc='Epoch').attach(
        engine=trainer, output_transform=lambda output: {'loss': output.loss})
    ProgressBar(persist=False, desc='Eval').attach(engine=evaluator)
    ProgressBar(persist=True, desc='Eval').attach(
        engine=evaluator,
        metric_names=['roc_auc', 'accuracy', 'loss'],
        event_name=Events.EPOCH_COMPLETED,
        closing_event_name=Events.COMPLETED)

    @trainer.on(Events.ITERATION_COMPLETED(every=args.evaluation_interval))
    def _evaluate(trainer: Engine):
        evaluator.run(valloader, max_epochs=1)

    evaluator.add_event_handler(
        event_name=Events.EPOCH_COMPLETED,
        handler=Checkpoint(
            to_save={'model': model, 'optimizer': optimizer, 'trainer': trainer},
            save_handler=DiskSaver(dirname=args.checkpoint_dir, atomic=True, create_dir=True, require_empty=False),
            filename_prefix='best',
            score_function=lambda engine: engine.state.metrics['roc_auc'],
            score_name='val_roc_auc',
            n_saved=1,
            global_step_transform=global_step_from_engine(trainer)))
def test_neg_event_filter_threshold_handlers_profiler():
    true_event_handler_time = 0.1
    true_max_epochs = 1
    true_num_iters = 1

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.EPOCH_STARTED(once=2))
    def do_something_once_on_2_epoch():
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    results = profiler.get_results()
    event_results = results[0]
    assert "do_something_once_on_2_epoch" in event_results[0]
    assert event_results[1] == "EPOCH_STARTED"
    assert event_results[2] == "not triggered"
def test_concepts_snippet_warning():
    def random_train_data_generator():
        while True:
            yield torch.randint(0, 100, size=(1,))

    def print_train_data(engine, batch):
        i = engine.state.iteration
        e = engine.state.epoch
        print("train", e, i, batch.tolist())

    trainer = DeterministicEngine(print_train_data)

    @trainer.on(Events.ITERATION_COMPLETED(every=3))
    def user_handler(_):
        # handler synchronizes the random state
        torch.manual_seed(12)
        a = torch.rand(1)

    trainer.run(random_train_data_generator(), max_epochs=3, epoch_length=5)
def test_ema_two_handlers(get_dummy_model):
    """Test when two EMA handlers are attached to a trainer"""
    model_1 = get_dummy_model()
    ema_handler_1 = EMAHandler(model_1, momentum=0.5)

    model_2 = get_dummy_model()
    ema_handler_2 = EMAHandler(model_2, momentum=0.5)

    def _step_fn(engine: Engine, batch: Any):
        model_1.weight.data.add_(1)
        model_2.weight.data.add_(1)
        return 0

    engine = Engine(_step_fn)
    assert not hasattr(engine.state, "ema_momentum_1")
    # handler_1 updates the EMA model of model_1 every iteration
    ema_handler_1.attach(engine, "ema_momentum_1", event=Events.ITERATION_COMPLETED)
    assert hasattr(engine.state, "ema_momentum_1")

    # handler_2 updates the EMA model of model_2 every 2 iterations
    ema_handler_2.attach(engine, "ema_momentum_2", event=Events.ITERATION_COMPLETED(every=2))
    assert hasattr(engine.state, "ema_momentum_2")

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)

    # explicitly cast to float32 to avoid test failure on XLA devices
    ema_weight_1 = ema_handler_1.ema_model.weight.data.to(torch.float32)
    ema_weight_2 = ema_handler_2.ema_model.weight.data.to(torch.float32)
    assert ema_weight_1.allclose(ema_weight_1.new_full((1, 2), 4.0625))
    assert ema_weight_2.allclose(ema_weight_2.new_full((1, 2), 3.5))

    assert engine.state.ema_momentum_1 == 0.5
    assert engine.state.ema_momentum_2 == 0.5

    model_3 = get_dummy_model()
    ema_handler_3 = EMAHandler(model_3)
    with pytest.warns(UserWarning, match="Attribute 'ema_momentum_1' already exists"):
        ema_handler_3.attach(engine, name="ema_momentum_1")
def setup_evaluation(
    trainer: Engine,
    evaluators: Dict[str, Engine],
    data_loaders: Dict[str, DataLoader],
    logger: Logger,
) -> None:
    # The evaluators won't have exactly similar roles:
    # - `evaluators["val"]` is also used to save the best model based on the validation score
    def _evaluation(engine: Engine) -> None:
        epoch = trainer.state.epoch
        for split in ["train", "val", "test"]:
            state = evaluators[split].run(data_loaders[split])
            log_metrics(logger, epoch, state.times["COMPLETED"], split, state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config.validate_every) | Events.COMPLETED,
        _evaluation,
    )
    return
def test_get_intermediate_results_during_run_basic_profiler(capsys):
    true_event_handler_time = 0.0645
    true_max_epochs = 2
    true_num_iters = 5

    profiler = BasicTimeProfiler()
    dummy_trainer = get_prepared_engine_for_basic_profiler(true_event_handler_time)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=3))
    def log_results(_):
        results = profiler.get_results()
        profiler.print_results(results)
        captured = capsys.readouterr()
        out = captured.out
        assert "BasicTimeProfiler._" not in out
        assert "nan" not in out
        assert " min/index: (0.0, " not in out, out

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)
def test_pos_event_filter_threshold_handlers_profiler():
    true_event_handler_time = HandlersTimeProfiler.EVENT_FILTER_THESHOLD_TIME
    true_max_epochs = 2
    true_num_iters = 1

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.EPOCH_STARTED(once=2))
    def do_something_once_on_2_epoch():
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    results = profiler.get_results()
    event_results = results[0]
    assert "do_something_once_on_2_epoch" in event_results[0]
    assert event_results[1] == "EPOCH_STARTED"
    assert event_results[2] == approx(
        (true_max_epochs * true_num_iters * true_event_handler_time) / 2, abs=1e-1
    )  # total
def test_run_finite_iterator_no_epoch_length_2():
    # FR: https://github.com/pytorch/ignite/issues/871
    known_size = 11

    def finite_size_data_iter(size):
        for i in range(size):
            yield i

    bc = BatchChecker(data=list(range(known_size)))

    engine = Engine(lambda e, b: bc.check(b))

    @engine.on(Events.ITERATION_COMPLETED(every=known_size))
    def restart_iter():
        engine.state.dataloader = finite_size_data_iter(known_size)

    data_iter = finite_size_data_iter(known_size)
    engine.run(data_iter, max_epochs=5)

    assert engine.state.epoch == 5
    assert engine.state.iteration == known_size * 5
def _test_ema_final_weight(model, device=None, ddp=False, interval=1):
    """Test if final smoothed weights are correct"""
    if device is None:
        # let horovod decide the device
        device = idist.device()
    if isinstance(device, str):
        device = torch.device(device)
    model = model.to(device)
    if ddp:
        model = idist.auto_model(model)

    step_fn = _get_dummy_step_fn(model)
    engine = Engine(step_fn)

    # momentum will be constantly 0.5
    ema_handler = EMAHandler(model, momentum_warmup=0.5, momentum=0.5, warmup_iters=1)
    ema_handler.attach(engine, "model", event=Events.ITERATION_COMPLETED(every=interval))

    # engine will run 4 iterations
    engine.run(range(2), max_epochs=2)

    # ema_model and model can be DP or DDP
    ema_weight = _unwrap_model(ema_handler.ema_model).weight.data
    model_weight = _unwrap_model(model).weight.data
    assert ema_weight.device == device
    assert model_weight.device == device

    if interval == 1:
        torch.testing.assert_allclose(ema_weight, torch.full((1, 2), 4.0625, device=device))
    elif interval == 2:
        torch.testing.assert_allclose(ema_weight, torch.full((1, 2), 3.5, device=device))
    else:
        pass

    torch.testing.assert_allclose(model_weight, torch.full((1, 2), 5.0, device=device))
def test_custom_event_with_arg_handlers_profiler():
    true_event_handler_time = 0.1
    true_max_epochs = 1
    true_num_iters = 2

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    dummy_trainer.register_events("custom_event")
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.ITERATION_COMPLETED(every=1))
    def trigger_custom_event():
        dummy_trainer.fire_event("custom_event")

    args = [122, 324]

    @dummy_trainer.on("custom_event", args)
    def on_custom_event(args):
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    results = profiler.get_results()
    event_results = None
    for row in results:
        if row[1] == "custom_event":
            event_results = row
            break

    assert event_results is not None
    assert "on_custom_event" in event_results[0]

    assert event_results[2] == approx(true_max_epochs * true_num_iters * true_event_handler_time, abs=1e-1)  # total
    assert event_results[3][0] == approx(true_event_handler_time, abs=1e-1)  # min
    assert event_results[4][0] == approx(true_event_handler_time, abs=1e-1)  # max
    assert event_results[5] == approx(true_event_handler_time, abs=1e-1)  # mean
    assert event_results[6] == approx(0.0, abs=1e-1)  # stddev
def __init__(self, model, loss, optimizer, lr_scheduler, device, logger, log_interval, output_dir=None):
    self.logger = logger
    self.lr_scheduler = lr_scheduler
    self.log_interval = log_interval
    self.progress_bar_desc = "ITERATION - loss: {:.2f}"
    self.trainer_engine = create_supervised_trainer(model, optimizer, loss, device=device)
    self.trainer_engine.add_event_handler(
        Events.ITERATION_COMPLETED(every=log_interval), self.log_training_loss)
    self.trainer_engine.add_event_handler(Events.EPOCH_COMPLETED, self.lr_step)
def setup_snapshots(trainer, sample_images, conf):
    # type: (Engine, SampleImages, DictConfig) -> None
    snapshots = conf.snapshots
    use_ema = conf.G_smoothing.enabled
    if snapshots.enabled:
        if use_ema:
            snap_event = Events.ITERATION_COMPLETED(every=snapshots.interval_iteration)
            snap_path = snapshots.get('save_dir', os.path.join(os.getcwd(), 'images'))
            if not os.path.exists(snap_path):
                os.makedirs(snap_path)
            logging.info("Saving snapshot images to {}".format(snap_path))
            trainer.add_event_handler(snap_event, handle_snapshot_images, sample_images, snap_path,
                                      dynamic_range=tuple(snapshots.dynamic_range))
        else:
            logging.warning(
                "Snapshot generation requires G_smoothing.enabled=true. "
                "Snapshots will be turned off for this run.")
def train():
    set_seed(train_param.seed)
    model = Model(model_param)
    optimizer = AdamW(model.parameters(), lr=train_param.lr, eps=1e-8)
    update_steps = train_param.epoch * len(train_loader)
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=update_steps)
    loss_fn = [translate, MSELoss()]
    device = torch.device(f'cuda:{train_param.device}')

    trainer = create_trainer(model, optimizer, scheduler, loss_fn, train_param.grad_norm, device)
    train_evaluator = create_evaluator(model, metric, device)
    dev_evaluator = create_evaluator(model, metric, device)

    trainer.add_event_handler(
        Events.ITERATION_COMPLETED(every=train_param.interval), log_training_loss)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results,
                              *(train_evaluator, train_loader, 'Train'))
    trainer.add_event_handler(Events.EPOCH_COMPLETED, log_results,
                              *(dev_evaluator, dev_loader, 'Dev'))

    es_handler = EarlyStopping(patience=train_param.patience, score_function=score_fn, trainer=trainer)
    dev_evaluator.add_event_handler(Events.COMPLETED, es_handler)

    ckpt_handler = ModelCheckpoint(train_param.save_path, '', score_function=score_fn,
                                   score_name='score', require_empty=False)
    dev_evaluator.add_event_handler(Events.COMPLETED, ckpt_handler, {
        'model': model,
        'param': model_param
    })

    print(
        f'Start running {train_param.save_path.split("/")[-1]} at device: {train_param.device}\t'
        f'lr: {train_param.lr}')
    trainer.run(train_loader, max_epochs=train_param.epoch)
def test_attach():

    n_epochs = 5
    data = list(range(50))

    def _test(event, n_calls):

        losses = torch.rand(n_epochs * len(data))
        losses_iter = iter(losses)

        def update_fn(engine, batch):
            return next(losses_iter)

        trainer = Engine(update_fn)

        logger = DummyLogger()

        mock_log_handler = MagicMock()

        logger.attach(trainer, log_handler=mock_log_handler, event_name=event)

        trainer.run(data, max_epochs=n_epochs)

        if isinstance(event, EventWithFilter):
            event = event.event

        mock_log_handler.assert_called_with(trainer, logger, event)
        assert mock_log_handler.call_count == n_calls

    _test(Events.ITERATION_STARTED, len(data) * n_epochs)
    _test(Events.ITERATION_COMPLETED, len(data) * n_epochs)
    _test(Events.EPOCH_STARTED, n_epochs)
    _test(Events.EPOCH_COMPLETED, n_epochs)
    _test(Events.STARTED, 1)
    _test(Events.COMPLETED, 1)

    _test(Events.ITERATION_STARTED(every=10), len(data) // 10 * n_epochs)