def _test(num_workers):
    max_epochs = 3
    batch_size = 4
    num_iters = 21
    data = torch.randint(0, 1000, size=(num_iters * batch_size,))
    dataloader = torch.utils.data.DataLoader(
        data,
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory="cuda" in device,
        drop_last=True,
        shuffle=True,
    )

    seen_batchs = []

    def update_fn(engine, batch):
        batch_to_device = batch.to(device)
        seen_batchs.append(batch)

    engine = Engine(update_fn)

    def foo(engine):
        pass

    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo)
    engine.run(dataloader, max_epochs=max_epochs, seed=12)
    engine = None

    import gc

    gc.collect()
    assert len(gc.garbage) == 0
def test_state_get_event_attrib_value():
    state = State()
    state.iteration = 10
    state.epoch = 9

    e = Events.ITERATION_STARTED
    assert state.get_event_attrib_value(e) == state.iteration

    e = Events.ITERATION_COMPLETED
    assert state.get_event_attrib_value(e) == state.iteration

    e = Events.EPOCH_STARTED
    assert state.get_event_attrib_value(e) == state.epoch

    e = Events.EPOCH_COMPLETED
    assert state.get_event_attrib_value(e) == state.epoch

    e = Events.STARTED
    assert state.get_event_attrib_value(e) == state.epoch

    e = Events.COMPLETED
    assert state.get_event_attrib_value(e) == state.epoch

    e = Events.ITERATION_STARTED(every=10)
    assert state.get_event_attrib_value(e) == state.iteration

    e = Events.ITERATION_COMPLETED(every=10)
    assert state.get_event_attrib_value(e) == state.iteration

    e = Events.EPOCH_STARTED(once=5)
    assert state.get_event_attrib_value(e) == state.epoch

    e = Events.EPOCH_COMPLETED(once=5)
    assert state.get_event_attrib_value(e) == state.epoch
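# --- Hedged sketch (not part of the test above): besides the `every=` and `once=` filters
# exercised here, Events also accepts a custom `event_filter` callable. The filter receives
# (engine, event_value) and the handler runs only when it returns True. Engine/Events are the
# real ignite APIs; the handler body below is illustrative.
from ignite.engine import Engine, Events

sketch_engine = Engine(lambda engine, batch: None)

def every_third_iteration(engine, event):
    # `event` is the current value of the event attribute (here: the iteration counter)
    return event % 3 == 0

@sketch_engine.on(Events.ITERATION_COMPLETED(event_filter=every_third_iteration))
def on_filtered_iterations(engine):
    print("handler fired at iteration", engine.state.iteration)

sketch_engine.run(range(9), max_epochs=1)  # fires at iterations 3, 6 and 9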
def test_neg_event_filter_threshold_handlers_profiler():
    true_event_handler_time = 0.1
    true_max_epochs = 1
    true_num_iters = 1

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.EPOCH_STARTED(once=2))
    def do_something_once_on_2_epoch():
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    results = profiler.get_results()
    event_results = results[0]
    assert "do_something_once_on_2_epoch" in event_results[0]
    assert event_results[1] == "EPOCH_STARTED"
    assert event_results[2] == "not triggered"
def test_pos_event_filter_threshold_handlers_profiler():
    true_event_handler_time = HandlersTimeProfiler.EVENT_FILTER_THESHOLD_TIME
    true_max_epochs = 2
    true_num_iters = 1

    profiler = HandlersTimeProfiler()
    dummy_trainer = Engine(_do_nothing_update_fn)
    profiler.attach(dummy_trainer)

    @dummy_trainer.on(Events.EPOCH_STARTED(once=2))
    def do_something_once_on_2_epoch():
        time.sleep(true_event_handler_time)

    dummy_trainer.run(range(true_num_iters), max_epochs=true_max_epochs)

    results = profiler.get_results()
    event_results = results[0]
    assert "do_something_once_on_2_epoch" in event_results[0]
    assert event_results[1] == "EPOCH_STARTED"
    assert event_results[2] == approx(
        (true_max_epochs * true_num_iters * true_event_handler_time) / 2, abs=1e-1
    )  # total
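# --- Hedged usage sketch for the profiler exercised by the two tests above: attach a
# HandlersTimeProfiler before running, then print the aggregated per-handler timings.
# The import path is assumed for recent ignite releases (older versions expose the profiler
# under ignite.contrib.handlers); the trainer and handler below are illustrative.
import time

from ignite.engine import Engine, Events
from ignite.handlers import HandlersTimeProfiler

profiler = HandlersTimeProfiler()
trainer = Engine(lambda engine, batch: None)
profiler.attach(trainer)

@trainer.on(Events.EPOCH_STARTED(every=2))
def slow_handler_on_even_epochs(engine):
    time.sleep(0.01)

trainer.run(range(4), max_epochs=4)
results = profiler.get_results()
HandlersTimeProfiler.print_results(results)  # tabular summary: handler, event, total/avg times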
def test_event_handlers_with_decoration():
    engine = Engine(lambda e, b: b)

    def decorated(fun):
        @functools.wraps(fun)
        def wrapper(*args, **kwargs):
            return fun(*args, **kwargs)

        return wrapper

    values = []

    def foo():
        values.append("foo")

    @decorated
    def decorated_foo():
        values.append("decorated_foo")

    engine.add_event_handler(Events.EPOCH_STARTED, foo)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo)
    engine.add_event_handler(Events.EPOCH_STARTED, decorated_foo)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), decorated_foo)

    def foo_args(e):
        values.append("foo_args")
        values.append(e.state.iteration)

    @decorated
    def decorated_foo_args(e):
        values.append("decorated_foo_args")
        values.append(e.state.iteration)

    engine.add_event_handler(Events.EPOCH_STARTED, foo_args)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo_args)
    engine.add_event_handler(Events.EPOCH_STARTED, decorated_foo_args)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), decorated_foo_args)

    class Foo:
        def __init__(self):
            self.values = []

        def foo(self):
            self.values.append("foo")

        @decorated
        def decorated_foo(self):
            self.values.append("decorated_foo")

        def foo_args(self, e):
            self.values.append("foo_args")
            self.values.append(e.state.iteration)

        @decorated
        def decorated_foo_args(self, e):
            self.values.append("decorated_foo_args")
            self.values.append(e.state.iteration)

    foo = Foo()

    engine.add_event_handler(Events.EPOCH_STARTED, foo.foo)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo.foo)
    engine.add_event_handler(Events.EPOCH_STARTED, foo.decorated_foo)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo.decorated_foo)
    engine.add_event_handler(Events.EPOCH_STARTED, foo.foo_args)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo.foo_args)
    engine.add_event_handler(Events.EPOCH_STARTED, foo.decorated_foo_args)
    engine.add_event_handler(Events.EPOCH_STARTED(every=2), foo.decorated_foo_args)

    engine.run([0], max_epochs=2)

    assert values == foo.values
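# --- Hedged illustration of what the test above checks: ignite inspects a handler's
# signature to decide which arguments to pass it, so decorators should use functools.wraps
# to preserve that signature. Minimal standalone repro; names are illustrative, the
# Engine/Events API is real.
import functools

from ignite.engine import Engine, Events

def logged(fun):
    @functools.wraps(fun)  # keeps fun's signature visible to Engine's argument inspection
    def wrapper(*args, **kwargs):
        return fun(*args, **kwargs)

    return wrapper

demo_engine = Engine(lambda engine, batch: None)

@demo_engine.on(Events.EPOCH_STARTED(every=2))
@logged
def report_epoch(engine):
    print("even epoch:", engine.state.epoch)

demo_engine.run([0], max_epochs=4)  # prints on epochs 2 and 4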
def get_prepared_engine_for_handlers_profiler(true_event_handler_time):
    HANDLERS_SLEEP_COUNT = 11
    PROCESSING_SLEEP_COUNT = 3

    class CustomEvents(EventEnum):
        CUSTOM_STARTED = "custom_started"
        CUSTOM_COMPLETED = "custom_completed"

    def dummy_train_step(engine, batch):
        engine.fire_event(CustomEvents.CUSTOM_STARTED)
        time.sleep(true_event_handler_time)
        engine.fire_event(CustomEvents.CUSTOM_COMPLETED)

    dummy_trainer = Engine(dummy_train_step)
    dummy_trainer.register_events(*CustomEvents)

    @dummy_trainer.on(Events.STARTED)
    def delay_start(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.COMPLETED)
    def delay_complete(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.EPOCH_STARTED)
    def delay_epoch_start(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.EPOCH_COMPLETED)
    def delay_epoch_complete(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.ITERATION_STARTED)
    def delay_iter_start(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.ITERATION_COMPLETED)
    def delay_iter_complete(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.GET_BATCH_STARTED)
    def delay_get_batch_started(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.GET_BATCH_COMPLETED)
    def delay_get_batch_completed(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(CustomEvents.CUSTOM_STARTED)
    def delay_custom_started(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(CustomEvents.CUSTOM_COMPLETED)
    def delay_custom_completed(engine):
        time.sleep(true_event_handler_time)

    @dummy_trainer.on(Events.EPOCH_STARTED(once=1))
    def do_something_once_on_1_epoch():
        time.sleep(true_event_handler_time)

    return dummy_trainer, HANDLERS_SLEEP_COUNT, PROCESSING_SLEEP_COUNT
def run(output_path, config):
    device = "cuda"
    local_rank = config['local_rank']
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config['batch_size'] // ngpus
    num_workers = int((config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    train_labelled_loader, test_loader = get_train_test_loaders(path=config['data_path'],
                                                                batch_size=batch_size,
                                                                distributed=distributed,
                                                                num_workers=num_workers)

    model = get_model(config['model'])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank, ],
                                                          output_device=local_rank)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_labelled_loader)
    milestones_values = [(0, 0.0),
                         (le * config['num_warmup_epochs'], config['learning_rate']),
                         (le * config['num_epochs'], 0.0)]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(engine, labelled_batch):
        x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            'batch loss': loss.item(),
        }

    trainer = Engine(process_function)

    if not hasattr(lr_scheduler, "step"):
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    else:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())

    metric_names = [
        'batch loss',
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        # We compute running average values on the output (batch loss) across all devices
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False, device=device).attach(trainer, n)

    if rank == 0:
        checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="checkpoint")
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000),
                                  checkpoint_handler,
                                  {'model': model, 'optimizer': optimizer})

        ProgressBar(persist=True, bar_format="").attach(trainer,
                                                        event_name=Events.EPOCH_STARTED,
                                                        closing_event_name=Events.COMPLETED)
        if config['display_iters']:
            ProgressBar(persist=False, bar_format="").attach(trainer, metric_names=metric_names)

        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(trainer,
                         log_handler=tbOutputHandler(tag="train", metric_names=metric_names),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=tbOptimizerParamsHandler(optimizer, param_name="lr"),
                         event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None)
    }

    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_labelled_loader)
        evaluator.run(test_loader)
    trainer.add_event_handler(Events.EPOCH_STARTED(every=3), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config['display_iters']:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        tb_logger.attach(train_evaluator,
                         log_handler=tbOutputHandler(tag="train",
                                                     metric_names=list(metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)
        tb_logger.attach(evaluator,
                         log_handler=tbOutputHandler(tag="test",
                                                     metric_names=list(metrics.keys()),
                                                     another_engine=trainer),
                         event_name=Events.COMPLETED)

        # Store the best model
        def default_score_fn(engine):
            score = engine.state.metrics['accuracy']
            return score

        score_function = default_score_fn if not hasattr(config, "score_function") else config.score_function

        best_model_handler = ModelCheckpoint(dirname=output_path,
                                             filename_prefix="best",
                                             n_saved=3,
                                             global_step_transform=global_step_from_engine(trainer),
                                             score_name="val_accuracy",
                                             score_function=score_function)
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {'model': model, })

    trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])

    if rank == 0:
        tb_logger.close()
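# --- Hedged note on the two add_event_handler calls above that register run_validation for
# both EPOCH_STARTED(every=3) and COMPLETED: recent ignite versions (assumed >= 0.4) also
# allow OR-ing events into a single registration. Standalone sketch with an illustrative handler.
from ignite.engine import Engine, Events

demo_trainer = Engine(lambda engine, batch: None)

@demo_trainer.on(Events.EPOCH_STARTED(every=3) | Events.COMPLETED)
def run_validation_combined(engine):
    print("validation trigger at epoch", engine.state.epoch)

demo_trainer.run([0], max_epochs=7)  # fires on epochs 3 and 6, plus once at COMPLETED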
def run(output_dir, config):
    device = torch.device("cuda" if args.use_cuda else "cpu")

    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])

    # Rescale batch_size and num_workers
    ngpus_per_node = 1
    batch_size = config['batch_size']
    num_workers = int((config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    (train_loader, test_loader, mislabeled_train_loader) = get_train_test_loaders(
        path=config['data_path'],
        batch_size=batch_size,
        num_workers=num_workers,
        random_seed=config['seed'],
        random_labels_fraction=config['random_labels_fraction'],
    )

    model = get_mnist_model(args, device)

    optimizer = AdamFlexibleWeightDecay(model.parameters(),
                                        lr=config['init_lr'],
                                        weight_decay_order=config['weight_decay_order'],
                                        weight_decay=config['weight_decay'])

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    lr_scheduler = MultiStepLR(optimizer,
                               milestones=[le * config['epochs'] * 3 // 4],
                               gamma=0.1)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(unused_engine, batch):
        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        optimizer.zero_grad()
        y_pred = model(x)

        if config['agreement_threshold'] > 0.0:
            # The "batch_size" in this function refers to the batch size per env.
            # Since we treat every example as one env, we should set the parameter
            # n_agreement_envs equal to batch size.
            mean_loss, masks = and_mask_utils.get_grads(
                agreement_threshold=config['agreement_threshold'],
                batch_size=1,
                loss_fn=criterion,
                n_agreement_envs=config['batch_size'],
                params=optimizer.param_groups[0]['params'],
                output=y_pred,
                target=y,
                method=args.method,
                scale_grad_inverse_sparsity=config['scale_grad_inverse_sparsity'],
            )
        else:
            mean_loss = criterion(y_pred, y)
            mean_loss.backward()

        optimizer.step()

        return {}

    trainer = Engine(process_function)

    metric_names = []

    common.setup_common_training_handlers(trainer,
                                          output_path=output_dir,
                                          lr_scheduler=lr_scheduler,
                                          output_names=metric_names,
                                          with_pbar_on_iters=True,
                                          log_every_iters=10)

    tb_logger = TensorboardLogger(log_dir=output_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="train", metric_names=metric_names),
                     event_name=Events.ITERATION_COMPLETED)

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    test_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    mislabeled_train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

    def run_validation(engine):
        if args.use_cuda:
            torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        if config['random_labels_fraction'] > 0.0:
            mislabeled_train_evaluator.run(mislabeled_train_loader)
        test_evaluator.run(test_loader)

    def flush_metrics(engine):
        tb_logger.writer.flush()

    trainer.add_event_handler(Events.EPOCH_STARTED(every=1), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, flush_metrics)

    ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
    ProgressBar(persist=False, desc="Test evaluation").attach(test_evaluator)
    ProgressBar(persist=False, desc="Train (mislabeled portion) evaluation").attach(mislabeled_train_evaluator)

    tb_logger.attach(
        train_evaluator,
        log_handler=OutputHandler(
            tag="train",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.COMPLETED)

    tb_logger.attach(
        test_evaluator,
        log_handler=OutputHandler(
            tag="test",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.COMPLETED)

    tb_logger.attach(
        mislabeled_train_evaluator,
        log_handler=OutputHandler(
            tag="train_wrong",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.COMPLETED)

    trainer_rng = np.random.RandomState()
    trainer.run(train_loader, max_epochs=config['epochs'],
                seed=trainer_rng.randint(2**32))

    tb_logger.close()
                                              device=backend_conf.device)

if backend_conf.rank == 0:
    tb_logger = TensorboardLogger(log_dir=str(output_path))
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='train', metric_names=metric_names),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(optimizer, param_name='lr'),
                     event_name=Events.ITERATION_STARTED)
    # TODO: make sure hp params logging works here + use test eval metrics instead of training's
    tb_logger.attach(trainer,
                     log_handler=HyperparamsOutoutHandler(hp, metric_names=metric_names),
                     event_name=Events.COMPLETED)

def _metrics(prefix):
    return {**{f'{prefix}_{n}': m for n, m in metrics.items()},
            **{f'{prefix}_{n}': loss for n, loss in losses.items()}}

valid_evaluator = create_supervised_evaluator(model, metrics=_metrics('valid'), device=device, non_blocking=True)
train_evaluator = create_supervised_evaluator(model, metrics=_metrics('train'), device=device, non_blocking=True)

@trainer.on(Events.EPOCH_STARTED(every=hp['validate_every_epochs']))
@trainer.on(Events.COMPLETED)
def _run_validation(engine: Engine):
    if torch.cuda.is_available() and not backend_conf.is_cpu:
        torch.cuda.synchronize()

    # Trainset evaluation
    train_state = train_evaluator.run(trainset)
    train_metrics = {f'train_{n}': float(v) for n, v in train_state.metrics.items()}
    for n, v in train_metrics.items():
        mlflow.log_metric(n, v, step=engine.state.epoch)

    # Validset evaluation
    valid_state = valid_evaluator.run(validset_testset[0])
    valid_metrics = {f'valid_{n}': float(v) for n, v in valid_state.metrics.items()}
    for n, v in valid_metrics.items():
def run(output_path, config):
    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0

    manual_seed(config["seed"] + rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = utils.get_dataflow(config, distributed)
    model, optimizer = utils.get_model_optimizer(config, distributed)
    criterion = nn.CrossEntropyLoss().to(utils.device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    # Setup Ignite trainer:
    # - let's define the training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to set up learning rate scheduling,
    #    - ModelCheckpoint,
    #    - RunningAverage on `train_step` output,
    #    - two progress bars: on epochs and, optionally, on iterations

    def train_step(engine, batch):
        x = convert_tensor(batch[0], device=utils.device, non_blocking=True)
        y = convert_tensor(batch[1], device=utils.device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    if config["deterministic"] and rank == 0:
        print("Setup deterministic trainer")
    trainer = Engine(train_step) if not config["deterministic"] else DeterministicEngine(train_step)

    train_sampler = train_loader.sampler if distributed else None
    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        # Setup Tensorboard logger - wrapper on SummaryWriter
        tb_logger = TensorboardLogger(log_dir=output_path)
        # Attach logger to the trainer and log trainer's metrics (stored in trainer.state.metrics) every iteration
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        # Log optimizer's parameters: "lr" every iteration
        tb_logger.attach(
            trainer, log_handler=OptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED
        )

    # Let's now setup evaluator engines to perform model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(device=utils.device if distributed else None),
        "loss": Loss(criterion, device=utils.device if distributed else None),
    }

    # We define two evaluators as they won't have exactly the same roles:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)

    def run_validation(engine):
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup progress bars on evaluation engines
        if config["display_iters"]:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        # Let's log metrics of `train_evaluator` stored in `train_evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Let's log metrics of `evaluator` stored in `evaluator.state.metrics` when validation run is done
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Store 3 best models by validation accuracy:
        common.save_best_model_by_val_score(
            output_path, evaluator, model=model, metric_name="accuracy", n_saved=3, trainer=trainer, tag="test"
        )

        # Optionally log model gradients
        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(every=config["log_model_grads_every"]),
            )

    # In order to check training resuming we can emulate a crash
    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
def run(output_path, config):
    device = "cuda"
    local_rank = config["local_rank"]
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config["seed"] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config["batch_size"] // ngpus
    num_workers = int((config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers,
    )

    model = get_model(config["model"])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[local_rank, ],
            output_device=local_rank,
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def process_function(engine, batch):
        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
            event_name=Events.ITERATION_STARTED,
        )

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None),
    }

    evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config["display_iters"]:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(
            output_path,
            evaluator,
            model=model,
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="test",
        )

        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(every=config["log_model_grads_every"]),
            )

    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
trainer.add_event_handler(Events.COMPLETED, run_evaluation)
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

training_saver = ModelCheckpoint("checkpoint_190520",
                                 filename_prefix='checkpoint',
                                 save_interval=None,  # Save every 1000 iterations
                                 n_saved=None,
                                 atomic=True,
                                 save_as_state_dict=True,
                                 require_empty=False,
                                 create_dir=True)

# Changed from Events.ITERATION_COMPLETED to Events.EPOCH_COMPLETED. EDIT: changed to EPOCH_STARTED every 10
trainer.add_event_handler(Events.EPOCH_STARTED(every=3), training_saver,
                          {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler})

# Store the best model
def default_score_fn(engine):
    score = engine.state.metrics['Accuracy']
    return score

# Add early stopping
es_patience = 10
es_handler = EarlyStopping(patience=es_patience,
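# --- Hedged completion sketch (the snippet above is truncated mid-call): EarlyStopping is
# typically given a score_function and the trainer it should stop, then attached to the
# evaluation engine so patience is checked after each validation run. This is a self-contained
# toy version; in the script above the real trainer/evaluator, default_score_fn and es_patience
# would be used instead.
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

toy_trainer = Engine(lambda engine, batch: None)
toy_evaluator = Engine(lambda engine, batch: None)

def toy_score_fn(engine):
    # higher is better; EarlyStopping terminates the trainer once the score stops improving
    return -engine.state.iteration

es_handler = EarlyStopping(patience=10, score_function=toy_score_fn, trainer=toy_trainer)
toy_evaluator.add_event_handler(Events.COMPLETED, es_handler)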