def test_returns_state():
    engine = Engine(MagicMock(return_value=1))
    state = engine.run([])

    assert isinstance(state, State)
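# A minimal stand-alone illustration (not part of the original suite) of the behaviour
# asserted above: Engine.run returns its State, and the State counters reflect the
# finished run. It assumes Engine and State are imported exactly as in the test above.
def _state_example():
    engine = Engine(lambda e, batch: None)
    state = engine.run([0, 1, 2], max_epochs=2)
    assert isinstance(state, State)
    assert state.epoch == 2 and state.iteration == 6
    return state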
def _test(save_history):
    tensor = torch.ones([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0.001)

    max_epochs = 25
    lr_max_value = 0.4
    num_iterations_per_epoch = 128
    num_iterations = max_epochs * num_iterations_per_epoch
    warmup_duration = 5 * num_iterations_per_epoch
    cooldown_duration = 5 * num_iterations_per_epoch

    scheduler_1 = LinearCyclicalScheduler(
        optimizer,
        "lr",
        start_value=lr_max_value,
        end_value=lr_max_value * 0.9,
        cycle_size=(num_iterations - warmup_duration - cooldown_duration) * 2,
    )
    scheduler_2 = LinearCyclicalScheduler(
        optimizer, "lr", start_value=lr_max_value, end_value=0.0, cycle_size=cooldown_duration * 2
    )

    lr_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2],
        durations=[num_iterations - warmup_duration - cooldown_duration],
        save_history=False,
    )
    lr_values = [None] * num_iterations
    scheduler = create_lr_scheduler_with_warmup(
        lr_scheduler,
        warmup_start_value=0.0,
        warmup_end_value=lr_max_value,
        warmup_duration=warmup_duration,
        save_history=save_history,
        output_simulated_values=lr_values,
    )
    state_dict = scheduler.state_dict()

    trainer = Engine(lambda engine, batch: None)

    @trainer.on(Events.ITERATION_COMPLETED)
    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    data = [0] * num_iterations_per_epoch

    for _ in range(2):
        lrs = []
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == pytest.approx([v for i, v in lr_values])

        if save_history:
            param_history = trainer.state.param_history["lr"]
            assert lrs == pytest.approx([v[0] for v in param_history])
            trainer.state.param_history = None

        scheduler.load_state_dict(state_dict)
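# Hedged sketch (not part of the original test) isolating the pattern the test relies
# on: create_lr_scheduler_with_warmup can fill `output_simulated_values` with
# (event index, lr) pairs up front, and the recorded run is then compared against them.
# Assumes the same imports as the test above.
def _simulate_warmup_values(lr_scheduler, num_events, warmup_duration, lr_max_value):
    simulated = [None] * num_events
    create_lr_scheduler_with_warmup(
        lr_scheduler,
        warmup_start_value=0.0,
        warmup_end_value=lr_max_value,
        warmup_duration=warmup_duration,
        output_simulated_values=simulated,
    )
    return simulated  # each entry becomes a (event index, lr) pair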
class Trainer:
    _STEPS_PER_LOSS_WRITE = 10
    _STEPS_PER_GRAD_WRITE = 10
    _STEPS_PER_LR_WRITE = 10

    def __init__(
            self,
            module,
            device,
            train_metrics,
            train_loader,
            opt,
            lr_scheduler,
            max_epochs,
            max_grad_norm,
            test_metrics,
            test_loader,
            epochs_per_test,
            early_stopping,
            valid_loss,
            valid_loader,
            max_bad_valid_epochs,
            visualizer,
            writer,
            should_checkpoint_latest,
            should_checkpoint_best_valid
    ):
        self._module = module
        self._device = device
        self._train_metrics = train_metrics
        self._train_loader = train_loader
        self._opt = opt
        self._lr_scheduler = lr_scheduler
        self._max_epochs = max_epochs
        self._max_grad_norm = max_grad_norm
        self._test_metrics = test_metrics
        self._test_loader = test_loader
        self._epochs_per_test = epochs_per_test
        self._valid_loss = valid_loss
        self._valid_loader = valid_loader
        self._max_bad_valid_epochs = max_bad_valid_epochs
        self._best_valid_loss = float("inf")
        self._num_bad_valid_epochs = 0
        self._visualizer = visualizer
        self._writer = writer
        self._should_checkpoint_best_valid = should_checkpoint_best_valid

        ### Training
        self._trainer = Engine(self._train_batch)
        AverageMetric().attach(self._trainer)
        ProgressBar(persist=True).attach(self._trainer, ["loss"])
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
        self._trainer.add_event_handler(Events.ITERATION_COMPLETED, self._log_training_info)

        ### Validation
        if early_stopping:
            self._validator = Engine(self._validate_batch)
            AverageMetric().attach(self._validator)
            ProgressBar(persist=False, desc="Validating").attach(self._validator)
            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._validate)

        ### Testing
        self._tester = Engine(self._test_batch)
        AverageMetric().attach(self._tester)
        ProgressBar(persist=False, desc="Testing").attach(self._tester)
        self._trainer.add_event_handler(Events.EPOCH_COMPLETED, self._test_and_log)

        ### Checkpointing
        if should_checkpoint_latest:
            self._trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: self._save_checkpoint("latest"))

        try:
            self._load_checkpoint("latest")
        except FileNotFoundError:
            print("Did not find `latest' checkpoint.", file=sys.stderr)

        try:
            self._load_checkpoint("best_valid")
        except FileNotFoundError:
            print("Did not find `best_valid' checkpoint.", file=sys.stderr)

    def train(self):
        self._trainer.run(data=self._train_loader, max_epochs=self._max_epochs)

    def _train_batch(self, engine, batch):
        self._module.train()

        x, _ = batch  # TODO: Potentially pass y also for genericity
        x = x.to(self._device)

        self._opt.zero_grad()
        train_metrics = self._train_metrics(self._module, x)
        loss = train_metrics["loss"]
        loss.backward()

        if self._max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self._module.parameters(), self._max_grad_norm)

        self._opt.step()
        self._lr_scheduler.step()

        return {"metrics": train_metrics}

    def test(self):
        self._module.eval()
        return self._tester.run(data=self._test_loader).metrics

    @torch.no_grad()
    def _test_and_log(self, engine):
        epoch = engine.state.epoch

        if (epoch - 1) % self._epochs_per_test == 0:  # Test after first epoch
            for k, v in self.test().items():
                self._writer.write_scalar(f"test/{k}", v, global_step=engine.state.epoch)

                if not torch.isfinite(v):
                    self._save_checkpoint(tag="nan_during_test")

            self._visualizer.visualize(self._module, epoch)

    def _test_batch(self, engine, batch):
        x, _ = batch
        x = x.to(self._device)
        return {"metrics": self._test_metrics(self._module, x)}

    @torch.no_grad()
    def _validate(self, engine):
        self._module.eval()
        state = self._validator.run(data=self._valid_loader)
        valid_loss = state.metrics["loss"]

        if valid_loss < self._best_valid_loss:
            print(f"Best validation loss {valid_loss} after epoch {engine.state.epoch}")
            self._num_bad_valid_epochs = 0
            self._best_valid_loss = valid_loss

            if self._should_checkpoint_best_valid:
                self._save_checkpoint(tag="best_valid")

        else:
            if not torch.isfinite(valid_loss):
                self._save_checkpoint(tag="nan_during_validation")

            self._num_bad_valid_epochs += 1

            # We do this manually (i.e. don't use Ignite's early stopping) to permit
            # saving/resuming more easily
            if self._num_bad_valid_epochs > self._max_bad_valid_epochs:
                print(
                    f"No validation improvement after {self._num_bad_valid_epochs} epochs. Terminating."
                )
                self._trainer.terminate()

    def _validate_batch(self, engine, batch):
        x, _ = batch
        x = x.to(self._device)
        return {"metrics": {"loss": self._valid_loss(self._module, x)}}

    def _log_training_info(self, engine):
        i = engine.state.iteration

        if i % self._STEPS_PER_LOSS_WRITE == 0:
            for k, v in engine.state.output["metrics"].items():
                self._writer.write_scalar("train/" + k, v, global_step=i)

        # TODO: Inefficient to recompute this if we are doing gradient clipping
        if i % self._STEPS_PER_GRAD_WRITE == 0:
            self._writer.write_scalar("train/grad-norm", self._get_grad_norm(), global_step=i)

        # TODO: We should do this _before_ calling self._lr_scheduler.step(), since
        # we will not correspond to the learning rate used at iteration i otherwise
        if i % self._STEPS_PER_LR_WRITE == 0:
            self._writer.write_scalar("train/lr", self._get_lr(), global_step=i)

    def _get_grad_norm(self):
        norm = 0
        for param in self._module.parameters():
            if param.grad is not None:
                norm += param.grad.norm().item()**2
        return np.sqrt(norm)

    def _get_lr(self):
        param_group, = self._opt.param_groups
        return param_group["lr"]

    def _save_checkpoint(self, tag):
        # We do this manually (i.e. don't use Ignite's checkpointing) because
        # Ignite only allows saving objects, not scalars (e.g. the current epoch)
        checkpoint = {
            "epoch": self._trainer.state.epoch,
            "iteration": self._trainer.state.iteration,
            "module_state_dict": self._module.state_dict(),
            "opt_state_dict": self._opt.state_dict(),
            "best_valid_loss": self._best_valid_loss,
            "num_bad_valid_epochs": self._num_bad_valid_epochs,
            "lr_scheduler_state_dict": self._lr_scheduler.state_dict()
        }

        self._writer.write_checkpoint(tag, checkpoint)

    def _load_checkpoint(self, tag):
        checkpoint = self._writer.load_checkpoint(tag, device=self._device)

        @self._trainer.on(Events.STARTED)
        def resume_trainer_state(engine):
            engine.state.epoch = checkpoint["epoch"]
            engine.state.iteration = checkpoint["iteration"]

        self._module.load_state_dict(checkpoint["module_state_dict"])
        self._opt.load_state_dict(checkpoint["opt_state_dict"])
        self._best_valid_loss = checkpoint["best_valid_loss"]
        self._num_bad_valid_epochs = checkpoint["num_bad_valid_epochs"]

        try:
            self._lr_scheduler.load_state_dict(checkpoint["lr_scheduler_state_dict"])
        except KeyError:
            print("No lr scheduler in saved checkpoint")

        print(f"Loaded checkpoint `{tag}' after epoch {checkpoint['epoch']}", file=sys.stderr)
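# Hedged usage sketch (not part of the original module): the Trainer above owns its
# ignite Engines internally, so a call site only needs to construct it and call
# train()/test(). The factory argument below is a hypothetical placeholder for however
# the surrounding project assembles the collaborators the constructor expects.
def _run_training_example(make_trainer_kwargs):
    trainer = Trainer(**make_trainer_kwargs())  # hypothetical factory returning the kwargs dict
    trainer.train()   # runs the internal engine, with validation/testing/checkpoint handlers attached
    return trainer.test()  # metrics dict produced by the tester engine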
def _test(duration_vals_as_np_int):
    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
    scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10)

    durations = [10]
    if duration_vals_as_np_int:
        durations = [np.int64(t) for t in durations]

    concat_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2], durations=durations, save_history=True
    )
    state_dict = concat_scheduler.state_dict()

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs, schedulers=[scheduler_1, scheduler_2], durations=durations
    )

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1 of the LinearCyclicalScheduler
                    1.0, 0.8, 0.6, 0.4, 0.2,
                    0.0, 0.2, 0.4, 0.6, 0.8,
                    # Cycle 1 of the CosineAnnealingScheduler
                    0.0,
                    0.02447174185242318,
                    0.09549150281252627,
                    0.20610737385376332,
                    0.3454915028125263,
                    0.5,
                    0.6545084971874737,
                    0.7938926261462365,
                    0.9045084971874737,
                    0.9755282581475768,
                ],
            )
        )

        state_lrs = trainer.state.param_history["lr"]
        assert len(state_lrs) == len(lrs)
        # Unpack singleton lists
        assert [group[0] for group in state_lrs] == lrs
        assert lrs == pytest.approx([v for i, v in simulated_values])

        concat_scheduler.load_state_dict(state_dict)
        trainer.state.param_history = None
def test_concat_scheduler_3_schedulers():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.5, cycle_size=20)
    scheduler_2 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.45, cycle_size=10)
    scheduler_3 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.0, cycle_size=20)
    durations = [10, 5]

    concat_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2, scheduler_3], durations=durations, save_history=True
    )
    state_dict = concat_scheduler.state_dict()

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs,
        schedulers=[scheduler_1, scheduler_2, scheduler_3],
        durations=durations
    )

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]["lr"])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    for _ in range(2):
        lrs = []
        trainer.run(data, max_epochs=max_epochs)

        assert lrs == list(
            map(
                pytest.approx,
                [
                    # Cycle 1 of the first LinearCyclicalScheduler
                    1.0, 0.95, 0.9, 0.85, 0.8,
                    0.75, 0.7, 0.65, 0.6, 0.55,
                    # Cycle 1 of the second LinearCyclicalScheduler
                    0.5, 0.49, 0.48, 0.47, 0.46,
                    # Cycle 1 of the third LinearCyclicalScheduler
                    0.5, 0.45, 0.4, 0.35, 0.3,
                ],
            )
        )

        state_lrs = trainer.state.param_history["lr"]
        assert len(state_lrs) == len(lrs)
        # Unpack singleton lists
        assert [group[0] for group in state_lrs] == lrs
        assert lrs == pytest.approx([v for i, v in simulated_values])

        concat_scheduler.load_state_dict(state_dict)
        trainer.state.param_history = None
    gamma=params.gamma, steps_count=N_STEPS)
buffer = dqn_extra.PrioReplayBuffer(exp_source, params.replay_size, PRIO_REPLAY_ALPHA)
optimizer = optim.Adam(net.parameters(), lr=params.learning_rate)

def process_batch(engine, batch_data):
    batch, batch_indices, batch_weights = batch_data
    optimizer.zero_grad()
    loss_v, sample_prios = calc_loss_prio(
        batch, batch_weights, net, tgt_net.target_model,
        gamma=params.gamma**N_STEPS, device=device)
    loss_v.backward()
    optimizer.step()
    buffer.update_priorities(batch_indices, sample_prios)
    if engine.state.iteration % params.target_net_sync == 0:
        tgt_net.sync()
    return {
        "loss": loss_v.item(),
        "beta": buffer.update_beta(engine.state.iteration),
    }

engine = Engine(process_batch)
common.setup_ignite(engine, params, exp_source, NAME)
engine.run(common.batch_generator(buffer, params.replay_initial, params.batch_size))
@trainer.on(Events.COMPLETED)
def plot_font_results(engine):
    evaluator.run(valid_loader)
    real_font, fake_font, latent_vectors = evaluator.state.output
    print(real_font.shape)
    print(fake_font)
    plt.figure(figsize=(6, 100))
    for i, (real, fake) in enumerate(zip(real_font, fake_font)):
        plt.subplot(107, 2, 2 * i + 1)
        plt.imshow(real.cpu().detach().numpy())
        plt.subplot(107, 2, 2 * i + 2)
        plt.imshow(fake.cpu().detach().numpy())
    # plt.savefig('real_fake_fonts_{}_for_category_5layers.png'.format(engine.state.epoch))
    plt.close()

@trainer.on(Events.COMPLETED)
def plot_latent_vectors(engine):
    evaluator.run(valid_loader)
    _, _, latent_vectors = evaluator.state.output
    print(latent_vectors.shape)
    plt.figure()
    latent_vectors = latent_vectors.cpu().detach().numpy()
    for i in range(len(latent_vectors)):
        plt.plot(latent_vectors[i, 0], latent_vectors[i, 1], marker='o')
    # plt.plot(latent_vectors[:, 0], latent_vectors[:, 1], marker='.')
    # plt.savefig('latent_vectors_for_category_layers.png')
    plt.close()

trainer.run(train_loader, max_epochs=epochs)
        trainer.tb.writer.add_image("fake", fake_img, trainer.state.iteration)
        real_img = vutils.make_grid(batch_v.data[:64], normalize=True)
        trainer.tb.writer.add_image("real", real_img, trainer.state.iteration)
        trainer.tb.writer.flush()
    return dis_loss.item(), gen_loss.item()

engine = Engine(process_batch)
tb = tb_logger.TensorboardLogger(log_dir=None)
engine.tb = tb
RunningAverage(output_transform=lambda out: out[0]).attach(
    engine, "avg_loss_gen")
RunningAverage(output_transform=lambda out: out[1]).attach(
    engine, "avg_loss_dis")

handler = tb_logger.OutputHandler(
    tag="train", metric_names=['avg_loss_gen', 'avg_loss_dis'])
tb.attach(engine, log_handler=handler,
          event_name=Events.ITERATION_COMPLETED)

@engine.on(Events.ITERATION_COMPLETED)
def log_losses(trainer):
    if trainer.state.iteration % REPORT_EVERY_ITER == 0:
        log.info("%d: gen_loss=%f, dis_loss=%f",
                 trainer.state.iteration,
                 trainer.state.metrics['avg_loss_gen'],
                 trainer.state.metrics['avg_loss_dis'])

engine.run(data=iterate_batches(envs))
class DataflowBenchmark:
    def __init__(self, num_iters=100, prepare_batch=None):
        from ignite.handlers import Timer

        device = idist.device()

        def upload_to_gpu(engine, batch):
            if prepare_batch is not None:
                x, y = prepare_batch(batch, device=device, non_blocking=False)

        self.num_iters = num_iters
        self.benchmark_dataflow = Engine(upload_to_gpu)

        @self.benchmark_dataflow.on(Events.ITERATION_COMPLETED(once=num_iters))
        def stop_benchmark_dataflow(engine):
            engine.terminate()

        if idist.get_rank() == 0:

            @self.benchmark_dataflow.on(
                Events.ITERATION_COMPLETED(every=num_iters // 100))
            def show_progress_benchmark_dataflow(engine):
                print(".", end=" ")

        self.timer = Timer(average=False)
        self.timer.attach(
            self.benchmark_dataflow,
            start=Events.EPOCH_STARTED,
            resume=Events.ITERATION_STARTED,
            pause=Events.ITERATION_COMPLETED,
            step=Events.ITERATION_COMPLETED,
        )

    def attach(self, trainer, train_loader):
        from torch.utils.data import DataLoader

        @trainer.on(Events.STARTED)
        def run_benchmark(_):
            if idist.get_rank() == 0:
                print("-" * 50)
                print(" - Dataflow benchmark")

            self.benchmark_dataflow.run(train_loader)
            t = self.timer.value()

            if idist.get_rank() == 0:
                print(" ")
                print(f" Total time ({self.num_iters} iterations) : {t:.5f} seconds")
                print(f" time per iteration : {t / self.num_iters} seconds")

                if isinstance(train_loader, DataLoader):
                    num_images = train_loader.batch_size * self.num_iters
                    print(f" number of images / s : {num_images / t}")

                print("-" * 50)
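# Hedged usage sketch (not part of the original file): DataflowBenchmark hooks itself
# onto the trainer's STARTED event, so it times `num_iters` batches of pure dataflow
# (loading plus the optional host-to-device copy) once, before real training begins.
def _attach_dataflow_benchmark(trainer, train_loader, prepare_batch=None):
    bench = DataflowBenchmark(num_iters=100, prepare_batch=prepare_batch)
    bench.attach(trainer, train_loader)
    return bench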
def run(output_path, config):
    device = "cuda"
    local_rank = config["local_rank"]
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config["seed"] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config["batch_size"] // ngpus
    num_workers = int(
        (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers,
    )

    model = get_model(config["model"])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[
                local_rank,
            ],
            output_device=local_rank,
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def process_function(engine, batch):
        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
            event_name=Events.ITERATION_STARTED,
        )

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(
        Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config["display_iters"]:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(
            output_path,
            evaluator,
            model=model,
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="test",
        )

        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(
                    every=config["log_model_grads_every"]),
            )

    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(
                engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=1,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=2.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=1000,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2DoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels)
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module',
                model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def main():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model",
                        type=str,
                        default="",
                        help="Model type, one of: %s" % ', '.join(MODELS.keys()))
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="",
                        help="Path, url or short name of a pretrained model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history",
                        type=int,
                        default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--mc_coef",
                        type=float,
                        default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--adv_coef",
                        type=float,
                        default=1.0,
                        help="Adversarial dataset prediction loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight decay if we apply some.")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    #parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    #parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument(
        "--max_sequence_length",
        type=int,
        default=-1,
        help="If set, use this to manually restrict the sequence length. "
        "This might be helpful to save resources (memory). "
        "If not set, this is looked up from the model config (n_ctx value).")
    parser.add_argument(
        "--adversarial_dataset_prediction",
        action='store_true',
        help="Set to train with adversarial dataset prediction")
    parser.add_argument("--seed",
                        type=int,
                        default=None,
                        help='set random seed')
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    if args.seed is not None:
        torch.manual_seed(args.seed)

    args.distributed = (args.local_rank != -1)

    logger.info("Prepare tokenizer and data")
    if not args.model:
        logger.warning(
            '"model" parameter is not set! This is deprecated. Please use one of: %s. '
            'To mimic deprecated behaviour, "model_checkpoint" will be used as "model"'
            % ', '.join(MODELS.keys()))
        args.model = args.model_checkpoint
    if args.model not in MODELS:
        raise NotImplementedError(
            'model "%s" not implemented. use one of: %s' %
            (args.model, ', '.join(MODELS.keys())))
    config_class, tokenizer_class, model_class, _ = MODELS[args.model]
    if not args.model_checkpoint:
        args.model_checkpoint = args.model

    model_config = config_class.from_pretrained(args.model_checkpoint)
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    additional_special_tokens = [TYPE_BACKGROUND, TYPE_BOT, TYPE_USER]

    # for adversarial training (dataset prediction)
    dataset_labels = None
    if args.adversarial_dataset_prediction:
        dataset_labels = [
            get_dataset_label(dataset_path)
            for dataset_path in args.dataset_path.split(',')
        ]
        #additional_special_tokens.extend(dataset_labels)
        #if model_class not in ADV_MODELS.values():
        assert model_class in ADV_MODELS, f'no adversarial model implemented for model class: {model_class.__name__}'
        model_class = ADV_MODELS[model_class]
        if not hasattr(model_config, 'cls'):
            model_config.cls = {}
        if 'dataset_labels' in model_config.cls:
            assert all([dl in model_config.cls['dataset_labels']['labels'] for dl in dataset_labels]), \
                f'loaded dataset_labels [{model_config.cls["dataset_labels"]["labels"]}] do not contain all ' \
                f'current dataset_labels [{dataset_labels}]'
            dataset_labels = model_config.cls['dataset_labels']['labels']
        else:
            model_config.cls['dataset_labels'] = {
                'labels': dataset_labels,
                'is_adversarial': True
            }
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "dataset_labels", "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "cl_loss_0", "lm_logits", "mc_logits",
            "cl_logits_0", "presents"
        ]
    else:
        model_input_names = [
            "input_ids", "mc_token_ids", "lm_labels", "mc_labels",
            "token_type_ids"
        ]
        # not yet used
        model_output_names = [
            "lm_loss", "mc_loss", "lm_logits", "mc_logits", "presents"
        ]

    tokenizer.add_special_tokens({
        'bos_token': TYPE_BOS,
        'eos_token': TYPE_EOS,
        'pad_token': TYPE_PAD,
        'additional_special_tokens': additional_special_tokens
    })

    logger.info("Prepare datasets")
    max_sequence_length = model_config.n_ctx if args.max_sequence_length <= 0 else args.max_sequence_length
    assert max_sequence_length <= model_config.n_ctx, \
        'max_sequence_length [%i] was set to a value higher than ' \
        'supported by the model (config.n_ctx [%i]). Please use a lower ' \
        'value or do not set it [-1] to use the highest supported one.' \
        % (max_sequence_length, model_config.n_ctx)
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args=args,
        tokenizer=tokenizer,
        model_input_names=model_input_names,
        max_sequence_length=max_sequence_length,
        dataset_labels=dataset_labels)

    logger.info(
        "Prepare pretrained model and optimizer - add special tokens for fine-tuning"
    )

    # Initialize distributed training if needed
    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Barrier to make sure only the first process in distributed training download model & vocab

    #model = model_class.from_pretrained(args.model_checkpoint, num_cl_labels=len(dataset_ids))  # for GPT2DoubleHeadsModelwithAdversarial
    model = model_class.from_pretrained(args.model_checkpoint,
                                        config=model_config)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # End of barrier to make sure only the first process in distributed training download model & vocab

    ####################################################################################################################

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay': args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay': 0.0
    }]
    #optimizer = OpenAIAdam(model.parameters(), lr=args.lr)
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.lr)
    # scheduler is set below (see ignite)
    #scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps,
    #                                            num_training_steps=len(train_loader) // args.train_batch_size + 1)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_checkpoint, 'optimizer.pt')) and os.path.isfile(
                os.path.join(args.model_checkpoint, 'scheduler.pt')):
        # Load in optimizer and scheduler states
        # TODO: this needs to be dumped somewhere
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_checkpoint, 'optimizer.pt')))
        #scheduler.load_state_dict(torch.load(os.path.join(args.model_checkpoint, 'scheduler.pt')))

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = {
            model_input_names[i]: input_tensor.to(args.device)
            for i, input_tensor in enumerate(batch)
        }
        model_output = model(**batch)
        losses = model_output[:3] if args.adversarial_dataset_prediction else model_output[:2]
        if args.n_gpu > 1:  # mean() to average on multi-gpu.
            losses = list(losses)
            for i in range(len(losses)):
                losses[i] = losses[i].mean()
        lm_loss, mc_loss = losses[0], losses[1]
        loss = (lm_loss * args.lm_coef +
                mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        # handle adversarial loss
        loss_wo_adv = loss.clone()
        if args.adversarial_dataset_prediction:
            adv_loss = model_output[2]
            loss += (adv_loss * args.adv_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            #scheduler.step()  # Update learning rate schedule  # already DONE below!
            optimizer.zero_grad()
        return loss_wo_adv.item(), loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(args.device) for input_tensor in batch)
            if args.adversarial_dataset_prediction:
                input_ids, mc_token_ids, lm_labels, mc_labels, dataset_labels, token_type_ids = batch
            else:
                input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.debug(
                tokenizer.decode(input_ids[0, -1, :].tolist()).replace(
                    TYPE_PAD, ''))
            model_outputs = model(input_ids=input_ids,
                                  mc_token_ids=mc_token_ids,
                                  token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted,
                    mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero (scheduler)
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "loss")
    if args.adversarial_dataset_prediction:
        RunningAverage(output_transform=lambda x: x[1]).attach(
            trainer, "loss_w/_adv")
        RunningAverage(output_transform=lambda x: x[1] - x[0]).attach(
            trainer, "loss_only_adv")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy":
        Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        if args.adversarial_dataset_prediction:
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 metric_names=["loss_w/_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
            tb_logger.attach(trainer,
                             log_handler=OutputHandler(
                                 tag="training",
                                 metric_names=["loss_only_adv"]),
                             event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        logger.info('save checkpoints to: %s' % tb_logger.writer.log_dir)
        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_pretrained(tb_logger.writer.log_dir)
        #logger.debug("Saving optimizer and scheduler states to %s", tb_logger.writer.log_dir)
        #torch.save(optimizer.state_dict(), os.path.join(tb_logger.writer.log_dir, 'optimizer.pt'))
        #torch.save(scheduler.state_dict(), os.path.join(tb_logger.writer.log_dir, 'scheduler.pt'))

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def test_stopping_criterion_is_max_epochs():
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5
    state = engine.run([1], max_epochs=max_epochs)
    assert state.epoch == max_epochs
def test_default_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)

    with raises(ValueError):
        engine.run([1])
def test_integration():
    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return loss_value, torch.from_numpy(y_pred_batch), torch.from_numpy(
            y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")

    running_avg_acc = [
        None,
    ]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (
                1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha +
                (1.0 - alpha) * engine.state.output[0])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert engine.state.running_avg_acc == engine.state.metrics[
            "running_avg_accuracy"], "{} vs {}".format(
                engine.state.running_avg_acc,
                engine.state.metrics["running_avg_accuracy"])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert engine.state.running_avg_output == engine.state.metrics[
            "running_avg_output"], "{} vs {}".format(
                engine.state.running_avg_output,
                engine.state.metrics["running_avg_output"])

    np.random.seed(10)
    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(
        np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
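# For reference (not part of the original test): the manual handlers above implement the
# same exponential moving average that RunningAverage computes,
# avg_t = alpha * avg_{t-1} + (1 - alpha) * x_t, seeded with the first observed value.
def _running_average(values, alpha=0.98):
    avg = None
    for x in values:
        avg = x if avg is None else avg * alpha + (1.0 - alpha) * x
    return avg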
def main(dataset, dataroot, z_dim, g_filters, d_filters, batch_size, epochs,
         learning_rate, beta_1, saved_G, saved_D, seed, n_workers, device,
         alpha, output_dir):

    # seed
    check_manual_seed(seed)

    # networks
    netG = Generator(z_dim, g_filters).to(device)
    netD = Discriminator(d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(),
                            lr=learning_rate,
                            betas=(beta_1, 0.999))

    # data
    dataset = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=n_workers,
                             drop_last=True)

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))

    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):

        # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()

        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()

        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update generator. We want to make a step that will make it more likely that discriminator outputs "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()

        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            'errD': errD.item(),
            'errG': errG.item(),
            'D_x': D_x,
            'D_G_z1': D_G_z1,
            'D_G_z2': D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         CKPT_PREFIX,
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
    for metric in monitoring_metrics:
        RunningAverage(alpha=alpha,
                       output_transform=lambda x: x[metric]).attach(
                           trainer, metric)

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = os.path.join(output_dir, LOGS_FNAME)
            columns = engine.state.metrics.keys()
            values = [
                str(round(value, 5))
                for value in engine.state.metrics.values()
            ]

            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=epochs,
                i=(engine.state.iteration % len(loader)),
                max_i=len(loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)

            pbar.log_message(message)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir,
                            FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir,
                            REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'netG': netG,
                                  'netD': netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl
            mpl.use('agg')

            import numpy as np
            import pandas as pd
            import matplotlib.pyplot as plt

        except ImportError:
            warnings.warn(
                'Loss plots will not be generated -- pandas or matplotlib not found'
            )

        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME),
                             delimiter='\t')
            x = np.arange(1, engine.state.iteration + 1, PRINT_FREQ)
            _ = df.plot(x=x, subplots=True, figsize=(20, 20))
            _ = plt.xlabel('Iteration number')
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)

            fig.savefig(path)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')

            create_plots(engine)
            checkpoint_handler(engine, {
                'netG_exception': netG,
                'netD_exception': netD
            })

        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
class UnisstBaseExperiment(BaseExperiment): def __init__( self, gen, dis_img, dis_vid, corruption, train, test=None, val=None, optim_gen=None, optim_dis=None, sacred_run=None, writers=None, root=None, nepoch=10, niter=300, display_frequency=1, num_dis_step:int = 1, device='cuda:0', fid_fvd: bool = True, colorize=False, **kwargs): super().__init__(**kwargs) self.train = train self.test = test self.val = val self.sacred_run = sacred_run self.sacred_run.result = float('Inf') self.device = device self.nepoch = nepoch self.display_frequency = display_frequency self.colorize = colorize if isinstance(niter, str) and niter.find('epoch') > 0: nepoch = int(niter.split(' ')[0]) niter = nepoch * len(train) self.niter = niter self.fid_fvd = fid_fvd if root is not None: self.basedir = os.path.join(root, str(sacred_run._id)) else: writers = None checkpoint = None if writers is not None: self.writers = init_writers(*writers, sacred_run=sacred_run, dirname=self.basedir) else: self.writers = None if checkpoint is not None: self.checkpoint = init_checkpoint_handler(dirname=self.basedir, **checkpoint) self.trainer = Engine(self.train_step) self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.evaluate) self.trainer.add_event_handler(Events.ITERATION_COMPLETED, self.log) self.tester = Engine(self.test_step) self.evaluator = Engine(self.test_step) self.gen = gen.to(self.device) self.dis_img = dis_img.to(self.device) self.dis_vid = dis_vid.to(self.device) self.optim_gen = optim_gen self.optim_dis = optim_dis self.scheduler_gen = ExponentialLR(self.optim_gen, gamma=0.99) self.scheduler_dis = ExponentialLR(self.optim_dis, gamma=0.99) self.corruption = corruption self.num_dis_step = num_dis_step RunningAverage(alpha=0.9, output_transform=lambda x: x['loss_gen'].item()).attach(self.trainer, 'loss_gen') RunningAverage(alpha=0.9, output_transform=lambda x: x['loss_dis_img'].item()).attach(self.trainer, 'loss_dis_img') RunningAverage(alpha=0.9, output_transform=lambda x: x['loss_dis_vid'].item()).attach(self.trainer, 'loss_dis_vid') if self.fid_fvd: self.dims = 2048 block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[self.dims] self.model = InceptionV3([block_idx]).to('cuda') self.fid_score = float('inf') def train_step(self, engine, batch): self.training() batch = convert_tensor(batch, self.device) loss_gen, loss_dis_img, loss_dis_vid, output = self.forward_backward(**batch) metric = self.metric(**output, **batch) loss = { 'loss_gen': loss_gen, 'loss_dis_img': loss_dis_img, 'loss_dis_vid': loss_dis_vid, } return { **batch, **output, **loss, **metric, } def test_step(self, engine, batch): self.evaluating() with torch.no_grad(): batch = convert_tensor(batch, self.device) _, _, _, output = self.forward_backward(backward=False, **batch) metric = self.metric(**output, **batch) return { **batch, **output, **metric, } def evaluate(self, engine): iteration = engine.state.iteration if iteration % self.niter == 0: self.step(engine, iteration) if self.val is not None: self.evaluator.run(self.val, max_epochs=1) self.step(self.evaluator, iteration, dataset_name='val') columns = self.evaluator.state.metrics.keys() values = [value for value in self.evaluator.state.metrics.values()] message = 'Val: ' for name, value in zip(columns, values): message += ' | {name}: {value:.4f}, {std:.4f}'.format(name=name, value=statistics.mean(value), std=statistics.stdev(value)) print(message) if self.test is not None: self.tester.run(self.test, max_epochs=1) self.step(self.tester, iteration, dataset_name='test') columns = 
self.tester.state.metrics.keys() values = [value for value in self.tester.state.metrics.values()] message = 'Test: ' for name, value in zip(columns, values): message += ' | {name}: {value:.4f}, {std:.4f}'.format(name=name, value=statistics.mean(value), std=statistics.stdev(value)) print(message) def log(self, engine): iteration = engine.state.iteration if iteration % self.display_frequency == 0: columns = engine.state.metrics.keys() values = [value for value in engine.state.metrics.values()] message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(epoch=engine.state.epoch, max_epoch=self.nepoch, i=(engine.state.iteration % len(self.train)), max_i=len(self.train)) for name, value in zip(columns, values): message += ' | {name}: {value:.4f}'.format(name=name, value=value) print(message) def write(self, engine, dataset_name): iteration = self.trainer.state.iteration # Logging Images o = engine.state.output b = engine.state.batch img_tensor, nrow = self.get_tensor(o, b) img = make_grid( img_tensor, nrow=nrow, normalize=True, range=(-1, 1) ) try: self.writers.add_image(dataset_name, img, iteration) except Exception: print('IMPOSSIBLE TO SAVE') def forward(self, **kwargs): raise NotImplementedError def backward_gen(self, **kwargs): raise NotImplementedError def backward_dis(self, **kwargs): raise NotImplementedError def get_tensor(self, o, b, limit=2): batch_size, nc, seq_len, height, width = b['x'].shape if batch_size < limit: limit = batch_size x = b['x'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width) y = o['y'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width) x_hat = o['x_hat'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width) if self.colorize: mask = b['mask'][:limit].cpu().permute(0, 2, 1, 3, 4).contiguous().view(-1, nc, height, width) x = colorize(x) y = colorize(y) x_hat = colorize(x_hat) y[mask.expand_as(y)] = 1 list_tensor = [x, y, x_hat] nrow = seq_len return torch.cat(list_tensor), nrow def step(self, engine, iteration, dataset_name='train'): values = {c: value for value, c in zip(engine.state.metrics.values(), engine.state.metrics.keys())} if dataset_name == 'train': self.scheduler_gen.step() self.scheduler_dis.step() if values['loss_dis_img'] + values['loss_dis_vid'] < 0.001: raise CustomInterrupt('DIS_TOO_SMALL') if values['loss_gen'] < 0.001: raise CustomInterrupt('GEN_TOO_SMALL') if values['recon_mae'] > 1.9: raise CustomInterrupt('RECON_TOO_HIGH') metrics = engine.state.metrics if self.writers is not None: for name, value in metrics.items(): metric_name = dataset_name + '/' + name if dataset_name in ['test', 'val']: m = statistics.mean(value) s = statistics.stdev(value) self.writers.add_scalar(metric_name, m, iteration) self.writers.add_scalar(metric_name + '_std', s, iteration) else: self.writers.add_scalar(metric_name, value, iteration) print(f"saving {iteration}") self.write(engine, dataset_name) if iteration % (2 * self.niter) == 0 and self.fid_fvd: self.compute_fid(iteration, dataset_name) self.compute_fvd(iteration, dataset_name) def compute_fvd(self, iteration, dataset_name): self.evaluating() fake_list, real_list = [], [] if dataset_name == 'train': dataset = self.train elif dataset_name == 'val': dataset = self.val else: dataset = self.test with torch.no_grad(): for i, batch in enumerate(dataset): batch = convert_tensor(batch, self.device) output = self.forward_backward(**batch, backward=False)[-1] real_seq_len = batch['seq_len'] batch_size, nc, _, _, _ = batch['x'].shape x_hat = output['x_hat'] x =
batch['x'] # B x C x T x H x W if nc != 3: fake = x_hat.repeat(1, 3, 1, 1, 1) true = x.repeat(1, 3, 1, 1, 1) fake_list.append(fake.cpu()) real_list.append(true.cpu()) if i == 15: break fake_vid = torch.cat(fake_list, dim=0) real_vid = torch.cat(real_list, dim=0) fvd_score = fvd(real_vid, fake_vid) print(f"FVD_{dataset_name} : {fvd_score}") if self.writers is not None: self.writers.add_scalar(f'FVD_{dataset_name}', fvd_score, iteration) def compute_fid(self, iteration, dataset_name): self.evaluating() fake_list, real_list = [], [] if dataset_name == 'train': dataset = self.train elif dataset_name == 'val': dataset = self.val else: dataset = self.test with torch.no_grad(): for i, batch in enumerate(dataset): batch = convert_tensor(batch, self.device) output = self.forward_backward(**batch, backward=False)[-1] real_seq_len = batch['seq_len'] batch_size, nc, _, _, _ = batch['x'].shape x_hat = output['x_hat'] x = batch['x'] fake = [] true = [] for bi in range(batch_size): fake.append(torch.narrow(x_hat[bi], 1, 0, real_seq_len[bi]).permute(1, 0, 2, 3)) true.append(torch.narrow( x[bi], 1, 0, real_seq_len[bi]).permute(1, 0, 2, 3)) fake = torch.cat(fake, dim=0) true = torch.cat(true, dim=0) if nc != 3: fake = fake.repeat(1, 3, 1, 1) true = true.repeat(1, 3, 1, 1) fake_list.append((fake.cpu().numpy() + 1.0) / 2.0) real_list.append((true.cpu().numpy() + 1.0) / 2.0) if i == 15: break fake_images = np.concatenate(fake_list) real_images = np.concatenate(real_list) mu_fake, sigma_fake = metrics.calculate_activation_statistics( fake_images, self.model, self.train.batch_size, device=self.device ) mu_real, sigma_real = metrics.calculate_activation_statistics( real_images, self.model, self.train.batch_size, device=self.device ) fid_score = metrics.calculate_frechet_distance( mu_fake, sigma_fake, mu_real, sigma_real ) print(f"FID_{dataset_name} : {fid_score}") if self.writers is not None: self.writers.add_scalar(f'FID_{dataset_name}', fid_score, iteration) def run(self): self.trainer.run(self.train, max_epochs=self.nepoch)
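The RunningAverage attachments in the experiment above depend on train_step returning a dict of losses; a stripped-down sketch of that wiring (the loss keys are illustrative):

from ignite.engine import Engine
from ignite.metrics import RunningAverage

def attach_running_losses(trainer: Engine, loss_keys=("loss_gen", "loss_dis_img", "loss_dis_vid")):
    # Each metric pulls a single entry out of the dict returned by the step function.
    # The k=key default argument binds the loop variable at definition time.
    for key in loss_keys:
        RunningAverage(alpha=0.9,
                       output_transform=lambda out, k=key: float(out[k])).attach(trainer, key)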
@engine.on(PeriodEvents.ITERS_10000_COMPLETED) def test_network(engine: Engine): dqn_model.train(False) test_reward: float test_steps: float test_reward, test_steps, test_deers = test_model( dqn_model, device, configuration) dqn_model.train(True) engine.state.metrics[TEST_REWARD_METRIC] = test_reward engine.state.metrics[TEST_STEPS_METRTIC] = test_steps engine.state.metrics[TEST_DEERS_METRIC] = test_deers print( "Test done: got %.3f reward after %.2f steps. Deer survival %.3f " % (test_reward, test_steps, test_deers)) global best_test_reward if best_test_reward is None: best_test_reward = test_reward elif best_test_reward < test_reward: print("Best test reward updated %.3f <- %.3f, save model" % (best_test_reward, test_reward)) best_test_reward = test_reward torch.save(dqn_model.state_dict(), os.path.join(saves_path, "best_%.3f.dat" % test_reward)) engine.run( batch_generator(replay_buffer, PARAMETERS.replay_initial, PARAMETERS.batch_size))
best_test_reward = None @engine.on(ptan_ignite.PeriodEvents.ITERS_1000_COMPLETED) def test_network(engine): net.train(False) a_reward, a_steps, b_reward, b_steps = test_model(net, device, config) net.train(True) engine.state.metrics['test_reward_a'] = a_reward engine.state.metrics['test_steps_a'] = a_steps engine.state.metrics['test_reward_b'] = b_reward engine.state.metrics['test_steps_b'] = b_steps print( "Test done: A got %.3f reward after %.2f steps, B %.3f reward after %.2f steps" % (a_reward, a_steps, b_reward, b_steps)) global best_test_reward reward = max(a_reward, b_reward) if best_test_reward is None: best_test_reward = reward elif best_test_reward < reward: print("Best test reward updated %.3f <- %.3f, save model" % (best_test_reward, reward)) best_test_reward = reward torch.save(net.state_dict(), os.path.join(saves_path, "best_%.3f.dat" % reward)) engine.run( batch_generator(a_exp_source, b_exp_source, buffer, PARAMS.replay_initial, PARAMS.batch_size))
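Both snippets above rely on ptan's PeriodEvents for periodic evaluation; with plain ignite the same hook can be expressed through filtered events. A minimal sketch, where evaluate is a placeholder returning a scalar reward:

from ignite.engine import Engine, Events

def attach_periodic_test(trainer: Engine, evaluate, every: int = 1000):
    best_reward = None

    @trainer.on(Events.ITERATION_COMPLETED(every=every))
    def _run_test(engine):
        nonlocal best_reward
        reward = evaluate()
        engine.state.metrics["test_reward"] = reward
        if best_reward is None or reward > best_reward:
            best_reward = reward
            print("Best test reward updated: %.3f" % reward)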
class Evaluator: """ Class which setups the evaluation logic which mainly involves defining callback handlers and attaching them to the evaluation loop. """ def __init__(self, model, config, data_loaders, tb_writer, run_info, logger, checkpoint_dir): """ Creates a new evaluator object for evaluating a model. :param model: model to train. Needs to inherit from the BaseModel class. :param config: dictionary containing the whole configuration of the experiment :param data_loaders: (dictionary) the keys represent the name and each value contains a pytorch data loader providing the validation data :param tb_writer: tensorboardX summary writer :param run_info: sacred run info for loging training progress :param logger: python logger object :param checkpoint_dir: directory path for storing checkpoints """ self.run_info = run_info self.logger = logger self.data_loaders = data_loaders self.config = config self.engine = Engine(self._step) self.model = model self.tb_writer = tb_writer self.trainer = None # Using custom metric wrapper which retrieves metrics from dictionary instead of separately calculating them. self.metrics = {k: LossFromDict(k) for k in self.model.metric_names} self.non_scalar_metrics = { k: LossFromDict(k, reduce=False) for k in self.model.non_scalar_metrics_names } if 'external_metrics' in config['val_data']: for idx, name in enumerate(config['val_data']['external_metrics']): if 'external_metrics_kw_args' in config['val_data']: self.metrics[name] = get_subclass(name, Metric)( config['devices'][0], **config['val_data']['external_metrics_kw_args'][idx]) else: self.metrics[name] = get_subclass(name, Metric)() self._handle_save_best_checkpoint_handler = \ ModelCheckpoint(checkpoint_dir, 'best', score_function=lambda engine: -self.model.main_metric(engine.state.metrics), score_name=self.model.name_main_metric, n_saved=1, require_empty=False) self.add_handler() self.best_loss = None self.current_data_loader = None self.main_data_loader = config['val_data']['main_dataset'] def run(self): """ Start the evaluation run which will run through one epoch for each validation dataset :return: """ for name, data_loader in self.data_loaders.items(): self.current_data_loader = name self.engine.run(data_loader) def set_trainer(self, trainer): """ Setter method for setting the trainer object which is mainly needed for getting information on the current training iteration. :param trainer: Trainer object :return: """ self.trainer = trainer def add_handler(self): """ Adds all the callback handlers to the trainer engine. Should be called in the end of the init. :return: """ for name, metric in self.metrics.items(): metric.attach(self.engine, name) for name, non_scalar_metric in self.non_scalar_metrics.items(): non_scalar_metric.attach(self.engine, name) # on epoch complete self.engine.add_event_handler( Events.EPOCH_COMPLETED, self._handle_save_best_checkpoint_handler, self.model.networks) self.engine.add_event_handler(Events.EPOCH_COMPLETED, self._handle_log_validation_results) # on iteration complete self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._handle_log_val_images) def _step(self, engine, batch): """ Definition of a single evaluation step. This function gets automatically called by the engine every iteration. 
:param engine: evaluator engine :param batch: one batch provided by the data loader :return: """ self.model.eval() self.model.set_input(batch) self.model.test() return self.model.state def _handle_log_validation_results(self, engine): """ Handler for writing the losses to tensorboard and sacred. :param engine: evaluation engine :return: """ metrics = self.engine.state.metrics loss = self.model.main_metric(metrics) metrics[self.model.name_main_metric] = loss for name, m in metrics.items(): if 'non_scalar_metric_' not in name: # Only add scalars # log to sacred self.run_info.log_scalar( f"val_{self.current_data_loader}.{name}.", m, self.trainer.engine.state.iteration) self.tb_writer.add_scalar( f"val_{self.current_data_loader}/{name}.", m, self.trainer.engine.state.iteration) self.logger.info( "Validation Results for {} - Epoch: {} Avg loss: {:.6f}".format( self.current_data_loader, self.trainer.engine.state.epoch, loss)) if self.current_data_loader == self.main_data_loader and \ (self.best_loss is None or loss < self.best_loss): self.best_loss = loss self.run_info.result = self.best_loss self._handle_complete_val_dataset_figure(engine) def _handle_log_val_images(self, engine): """ Handler for writing visual samples to tensorboard. :param engine: evaluation engine :return: """ if engine.state.iteration == 1: for name, visual in self.model.visuals.items(): self.tb_writer.add_image( f"val_{self.current_data_loader}/{name}.", visual.transpose(2, 0, 1), self.trainer.engine.state.iteration) def _score_function(self, engine): """ Helper method used in ModelCheckpoint to save the best model. The sign is flipped because ModelCheckpoint keeps the checkpoints with the highest scores. :param engine: evaluation engine :return: """ val_loss = engine.state.metrics[self.model.name_main_metric] return -val_loss def _handle_complete_val_dataset_figure(self, engine): """ Adds the complete validation dataset metric figures to tensorboard. :param engine: evaluation engine :return: """ figures = self.model.get_validation_figures(engine.state) for name, figure in figures.items(): self.tb_writer.add_figure( f"val_{self.current_data_loader}_metrics/{name}", figure, self.trainer.engine.state.iteration)
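The best-checkpoint handler above reduces to ignite's ModelCheckpoint with a negated score, since ModelCheckpoint retains the highest scores; a sketch assuming the monitored value is stored under the 'loss' metric key:

from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

def attach_best_checkpoint(evaluator: Engine, to_save: dict, dirname: str):
    handler = ModelCheckpoint(
        dirname, "best",
        score_function=lambda engine: -engine.state.metrics["loss"],  # lower loss -> higher score
        score_name="neg_loss",
        n_saved=1,
        require_empty=False,
    )
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, handler, to_save)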
def _run(self, tempdir): my_rank = dist.get_rank() fnames = ["aaa" * 300, "bbb" * 301, "ccc" * 302] metrics_saver = MetricsSaver( save_dir=tempdir, metrics=["metric1", "metric2"], metric_details=["metric3", "metric4"], batch_transform=lambda x: x[PostFix.meta("image")], summary_ops="*", delimiter="\t", ) def _val_func(engine, batch): pass engine = Engine(_val_func) if my_rank == 0: data = [{PostFix.meta("image"): {"filename_or_obj": [fnames[0]]}}] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics0(engine): engine.state.metrics = {"metric1": 1, "metric2": 2} engine.state.metric_details = { "metric3": torch.tensor([[1, 2]]), "metric4": torch.tensor([[5, 6]]) } if my_rank == 1: # different ranks have different data length data = [ { PostFix.meta("image"): { "filename_or_obj": [fnames[1]] } }, { PostFix.meta("image"): { "filename_or_obj": [fnames[2]] } }, ] @engine.on(Events.EPOCH_COMPLETED) def _save_metrics1(engine): engine.state.metrics = {"metric1": 1, "metric2": 2} engine.state.metric_details = { "metric3": torch.tensor([[2, 3], [3, 4]]), "metric4": torch.tensor([[6, 7], [7, 8]]), } @engine.on(Events.EPOCH_COMPLETED) def _all_gather(engine): scores = engine.state.metric_details["metric3"] engine.state.metric_details[ "metric3"] = evenly_divisible_all_gather(data=scores, concat=True) scores = engine.state.metric_details["metric4"] engine.state.metric_details[ "metric4"] = evenly_divisible_all_gather(data=scores, concat=True) metrics_saver.attach(engine) engine.run(data, max_epochs=1) if my_rank == 0: # check the metrics.csv and content self.assertTrue( os.path.exists(os.path.join(tempdir, "metrics.csv"))) with open(os.path.join(tempdir, "metrics.csv")) as f: f_csv = csv.reader(f) for i, row in enumerate(f_csv): self.assertEqual(row, [f"metric{i + 1}\t{i + 1}"]) self.assertTrue( os.path.exists(os.path.join(tempdir, "metric3_raw.csv"))) # check the metric_raw.csv and content with open(os.path.join(tempdir, "metric3_raw.csv")) as f: f_csv = csv.reader(f) for i, row in enumerate(f_csv): if i > 0: expected = [ f"{fnames[i-1]}\t{float(i):.4f}\t{float(i + 1):.4f}\t{i + 0.5:.4f}" ] self.assertEqual(row, expected) self.assertTrue( os.path.exists(os.path.join(tempdir, "metric3_summary.csv"))) # check the metric_summary.csv and content with open(os.path.join(tempdir, "metric3_summary.csv")) as f: f_csv = csv.reader(f) for i, row in enumerate(f_csv): if i == 1: self.assertEqual(row, [ "class0\t2.0000\t2.0000\t3.0000\t1.0000\t2.8000\t0.8165\t3.0000" ]) elif i == 2: self.assertEqual(row, [ "class1\t3.0000\t3.0000\t4.0000\t2.0000\t3.8000\t0.8165\t3.0000" ]) elif i == 3: self.assertEqual(row, [ "mean\t2.5000\t2.5000\t3.5000\t1.5000\t3.3000\t0.8165\t3.0000" ]) self.assertTrue( os.path.exists(os.path.join(tempdir, "metric4_raw.csv"))) self.assertTrue( os.path.exists(os.path.join(tempdir, "metric4_summary.csv"))) dist.barrier()
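The per-rank gather in the test relies on MONAI's evenly_divisible_all_gather, which pads tensors so ranks with different row counts can be concatenated; a sketch of the handler pattern (the import path varies across MONAI versions, so treat it as an assumption):

from ignite.engine import Engine, Events

def attach_detail_gather(engine: Engine, keys=("metric3", "metric4")):
    # In recent MONAI this helper is exposed under monai.utils; older releases place it elsewhere.
    from monai.utils import evenly_divisible_all_gather

    @engine.on(Events.EPOCH_COMPLETED)
    def _gather_details(engine):
        for key in keys:
            scores = engine.state.metric_details[key]
            engine.state.metric_details[key] = evenly_divisible_all_gather(data=scores, concat=True)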
class Trainer: """ Class which setups the training logic which mainly involves defining callback handlers and attaching them to the training loop. """ def __init__(self, model, config, evaluator, data_loader, tb_writer, run_info, logger, checkpoint_dir): """ Creates a new trainer object for training a model. :param model: model to train. Needs to inherit from the BaseModel class. :param config: dictionary containing the whole configuration of the experiment :param evaluator: Instance of the evaluator class, used to run evaluation on a specified schedule :param data_loader: pytorch data loader providing the training data :param tb_writer: tensorboardX summary writer :param run_info: sacred run info for loging training progress :param logger: python logger object :param checkpoint_dir: directory path for storing checkpoints """ self.run_info = run_info self.logger = logger self.data_loader = data_loader self.evaluator = evaluator self.engine = Engine(self._step) self.model = model self.config = config self.train_cfg = config['train'] self.tb_writer = tb_writer self.pbar = ProgressBar(ascii=True, desc='* Epoch') self.timer = Timer(average=True) self.save_last_checkpoint_handler = ModelCheckpoint( checkpoint_dir, 'last', save_interval=self.train_cfg['save_interval'], n_saved=self.train_cfg['save_n_last'], require_empty=False) self.add_handler() def run(self): """ Start the training loop which will run until all epochs are complete :return: """ self.engine.run(self.data_loader, max_epochs=self.train_cfg['n_epochs']) def add_handler(self): """ Adds all the callback handlers to the trainer engine. Should be called in the end of the init. :return: """ # Learning rate decay for lr_s in self.model.schedulers: self.engine.add_event_handler(Events.ITERATION_STARTED, lr_s) # Checkpoint saving self.engine.add_event_handler(Events.EPOCH_STARTED, self.save_last_checkpoint_handler, self.model.networks) # Progbar monitoring_metrics = self.model.metric_names for mm in monitoring_metrics: RunningAverage(output_transform=self._extract_loss(mm)).attach( self.engine, mm) self.pbar.attach(self.engine, metric_names=monitoring_metrics) # Timer self.timer.attach(self.engine, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) # Logging self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._handle_log_train_results) self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._handle_log_train_images) self.engine.add_event_handler(Events.ITERATION_COMPLETED, self._handle_run_evaluation) self.engine.add_event_handler(Events.EPOCH_COMPLETED, self._handle_print_times) # Exception handling self.engine.add_event_handler(Events.EXCEPTION_RAISED, self._handle_exception) def _step(self, engine, batch): """ Definition of a single training step. This function gets automatically called by the engine every iteration. :param engine: trainer engine :param batch: one batch provided by the dataloader :return: """ self.model.train() self.model.set_input(batch) self.model.optimize_parameters() return self.model.state def _handle_log_train_results(self, engine): """ Handler for writing the losses to tensorboard and sacred. 
:param engine: train engine :return: """ if (engine.state.iteration - 1) % self.train_cfg['log_interval'] == 0: metrics = engine.state.metrics # does not include non scalar metrics, since loggers can not handle this for m_name, m_val in metrics.items(): if m_val is None: raise ValueError(f'Value for {m_name} is None') self.run_info.log_scalar("train.%s" % m_name, m_val, engine.state.iteration) self.tb_writer.add_scalar("train/%s" % m_name, m_val, engine.state.iteration) for lr_name, lr_val in self.model.learning_rates.items(): if lr_val is None: raise ValueError(f'Value for {lr_name} is None') self.run_info.log_scalar("train.%s" % lr_name, lr_val, engine.state.iteration) self.tb_writer.add_scalar("train/%s" % lr_name, lr_val, engine.state.iteration) def _handle_log_train_images(self, engine): """ Handler for writing visual samples to tensorboard. :param engine: train engine :return: """ if (engine.state.iteration - 1) % self.train_cfg['img_log_interval'] == 0: for name, visual in self.model.visuals.items(): # TODO remove the visual.transpose here and put it in the visualization function of the models self.tb_writer.add_image('train/%s' % name, visual.transpose(2, 0, 1), engine.state.iteration) for name, figure in self.model.figures.items(): self.tb_writer.add_figure('train_metrics/%s' % name, figure, engine.state.iteration) def _handle_run_evaluation(self, engine): """ Handler which will execute evaluation by running the evaluator object. :param engine: train engine :return: """ if (engine.state.iteration - 1) % self.train_cfg['eval_interval'] == 0: self.evaluator.run() def _handle_exception(self, engine, e): """ Exception handler which ensures that the model gets saved when stopped through a keyboard interruption. :param engine: train engine :param e: the exception which caused the training to stop :return: """ if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() self.logger.warning( 'KeyboardInterrupt caught. Exiting gracefully.') self.save_last_checkpoint_handler(engine, self.model.networks) else: raise e def _handle_print_times(self, engine): """ Handler for logging timer information for different training and evaluation steps. :param engine: train engine :return: """ self.logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format( engine.state.epoch, self.timer.value())) self.timer.reset() @staticmethod def _extract_loss(key): """ Helper method to return losses for the RunningAverage :param key: (str) loss name :return: (fn) for the corresponding key """ def _func(losses): return losses[key] return _func
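The timer wiring used by _handle_print_times is a reusable pattern; a compact sketch of an average per-batch timer that resets at every epoch boundary:

from ignite.engine import Engine, Events
from ignite.handlers import Timer

def attach_epoch_timer(trainer: Engine, log=print):
    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _log_time(engine):
        log('Epoch {} done. Time per batch: {:.3f}[s]'.format(engine.state.epoch, timer.value()))
        timer.reset()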
def test_linear_scheduler(): with pytest.raises(TypeError, match=r"Argument optimizer should be torch.optim.Optimizer"): LinearCyclicalScheduler({}, "lr", 1, 0, cycle_size=0) tensor = torch.zeros([1], requires_grad=True) optimizer = torch.optim.SGD([tensor], lr=0.0) with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"): LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=0) with pytest.raises(ValueError, match=r"Argument cycle_size should be positive and larger than 1"): LinearCyclicalScheduler(optimizer, "lr", 1, 0, cycle_size=1) scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10) state_dict = scheduler.state_dict() def save_lr(engine): lrs.append(optimizer.param_groups[0]["lr"]) trainer = Engine(lambda engine, batch: None) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr) for _ in range(2): lrs = [] trainer.run([0] * 9, max_epochs=2) assert lrs == list( map( pytest.approx, [ # Cycle 1 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8, # Cycle 2 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, # 0.6, 0.8, ], ) ) scheduler.load_state_dict(state_dict) optimizer = torch.optim.SGD([tensor], lr=0) scheduler = LinearCyclicalScheduler(optimizer, "lr", 1, 0, 10, cycle_mult=2) state_dict = scheduler.state_dict() trainer = Engine(lambda engine, batch: None) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr) for _ in range(2): lrs = [] trainer.run([0] * 10, max_epochs=3) assert lrs == list( map( pytest.approx, [ # Cycle 1 1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8, # Cycle 2 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, ], ) ) scheduler.load_state_dict(state_dict) # With float cycle_size optimizer = torch.optim.SGD([tensor], lr=0) scheduler = LinearCyclicalScheduler( optimizer, "lr", start_value=1.2, end_value=0.2, cycle_size=10.00000012, cycle_mult=1.0 ) state_dict = scheduler.state_dict() trainer = Engine(lambda engine, batch: None) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr) for _ in range(2): lrs = [] trainer.run([0] * 9, max_epochs=2) assert lrs == list( map( pytest.approx, [ # Cycle 1 1.2, 1.0, 0.8, 0.6, 0.4, 0.2, 0.4, 0.6, 0.8, 1.0, # Cycle 2 1.2, 1.0, 0.8, 0.6, 0.4, 0.2, 0.4, 0.6, # 0.8, 1.0, ], ) ) scheduler.load_state_dict(state_dict)
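The expected lr sequences in the test can also be generated without an Engine through the param schedulers' simulate_values classmethod; a sketch (the import path differs between ignite releases, e.g. ignite.contrib.handlers in older ones):

from ignite.handlers.param_scheduler import LinearCyclicalScheduler

# First cycle of the 1.0 -> 0.0 scheduler with cycle_size=10 used above.
values = LinearCyclicalScheduler.simulate_values(
    num_events=10, param_name="lr", start_value=1.0, end_value=0.0, cycle_size=10
)
# Returns [event_index, lr] pairs: [[0, 1.0], [1, 0.8], [2, 0.6], ...]
print(values)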
def adv_train_loop(model, params, ds, min_y, base_data, model_id, attack_type, device, batch_size, max_epochs=5): print('training adversarial:', attack_type) ds_train, ds_valid = ds min_y_train, min_y_val = min_y original_model = copy.deepcopy( model) # used to generate adv images for the trained model original_model.eval() model = copy.deepcopy( model) # making a copy so that original model is not changed model = model.to(device) model_id = f'{model_id}_{attack_type}' with create_summary_writer(model, ds_train, base_data, model_id, device=device) as writer: lr = params['lr'] mom = params['momentum'] wd = params['l2_wd'] optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=mom, weight_decay=wd) sched = ReduceLROnPlateau(optimizer, factor=0.5, patience=5) funcs = {'accuracy': Accuracy(), 'loss': Loss(F.cross_entropy)} loss = funcs['loss']._loss_fn acc_metric = Accuracy(device=device) loss_metric = Loss(F.cross_entropy, device=device) acc_val_metric = Accuracy(device=device) loss_val_metric = Loss(F.cross_entropy, device=device) classifier = PyTorchClassifier( model=original_model, clip_values=(0, 1), loss=nn.CrossEntropyLoss(), optimizer=optimizer, input_shape=(3, 64, 64), nb_classes=200, ) attack = None # if attack_type == "fgsm": # attack = FastGradientMethod(estimator=classifier, eps=0.2) # elif attack_type == "bim": # attack = BasicIterativeMethod(estimator=classifier, eps=0.2) # elif attack_type == "carlini": # attack = CarliniLInfMethod(classifier=classifier) # elif attack_type == "deepfool": # attack = DeepFool(classifier=classifier) if attack_type == "fgsm": attack = GradientSignAttack(model, loss_fn=loss, eps=0.2) elif attack_type == "ffa": attack = FastFeatureAttack(model, loss_fn=loss, eps=0.3) elif attack_type == "carlini": attack = CarliniWagnerL2Attack(model, 200, max_iterations=1000) elif attack_type == "lbfgs": attack = DeepFool(classifier=classifier) def train_step(engine, batch): model.train() x, y = batch x = x.to(device) y = y.to(device) - min_y_train with ctx_noparamgrad_and_eval(model): x_adv = attack.perturb(x, y) optimizer.zero_grad() x = torch.cat((x, x_adv)) y = torch.cat((y, y)) ans = model.forward(x) l = loss(ans, y) optimizer.zero_grad() l.backward() optimizer.step() # return ans, y return l.item() trainer = Engine(train_step) # acc_metric.attach(trainer, "accuracy") # loss_metric.attach(trainer, 'loss') def train_eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) y = y.to(device) - min_y_train x_adv = attack.perturb(x, y) x = torch.cat((x, x_adv)) y = torch.cat((y, y)) with torch.no_grad(): ans = model.forward(x) return ans, y train_evaluator = Engine(train_eval_step) acc_metric.attach(train_evaluator, "accuracy") loss_metric.attach(train_evaluator, 'loss') def validation_step(engine, batch): model.eval() x, y = batch x = x.to(device) y = y.to(device) - min_y_train x_adv = attack.perturb(x, y) x = torch.cat((x, x_adv)) y = torch.cat((y, y)) with torch.no_grad(): ans = model.forward(x) return ans, y valid_evaluator = Engine(validation_step) acc_val_metric.attach(valid_evaluator, "accuracy") loss_val_metric.attach(valid_evaluator, 'loss') @trainer.on( Events.ITERATION_COMPLETED(every=200 * 5000 // batch_size // 10)) def log_validation_results(engine): valid_evaluator.run(ds_valid) metrics = valid_evaluator.state.metrics valid_avg_accuracy = metrics['accuracy'] avg_nll = metrics['loss'] print( "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}" .format(engine.state.epoch, valid_avg_accuracy, avg_nll)) 
writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch) writer.add_scalar("validation/avg_accuracy", valid_avg_accuracy, engine.state.epoch) writer.add_scalar("validation/avg_error", 1. - valid_avg_accuracy, engine.state.epoch) @trainer.on(Events.EPOCH_COMPLETED) def lr_scheduler(engine): metrics = valid_evaluator.state.metrics avg_nll = metrics['accuracy'] sched.step(avg_nll) @trainer.on(Events.ITERATION_COMPLETED(every=50)) def log_training_loss(engine): batch = engine.state.batch ds = DataLoader(TensorDataset(*batch), batch_size=batch_size) train_evaluator.run(ds) metrics = train_evaluator.state.metrics # metrics = engine.state.metrics accuracy = metrics['accuracy'] nll = metrics['loss'] iter = (engine.state.iteration - 1) % len(ds_train) + 1 if (iter % 50) == 0: print("Epoch[{}] Iter[{}/{}] Accuracy: {:.2f} Loss: {:.2f}". format(engine.state.epoch, iter, len(ds_train), accuracy, nll)) writer.add_scalar("batchtraining/detloss", nll, engine.state.epoch) writer.add_scalar("batchtraining/accuracy", accuracy, engine.state.iteration) writer.add_scalar("batchtraining/error", 1. - accuracy, engine.state.iteration) writer.add_scalar("batchtraining/loss", engine.state.output, engine.state.iteration) @trainer.on(Events.EPOCH_COMPLETED) def log_lr(engine): writer.add_scalar("lr", optimizer.param_groups[0]['lr'], engine.state.epoch) # @trainer.on(Events.EPOCH_COMPLETED) # def log_training_results(engine): # train_evaluator.run(ds_train) # metrics = train_evaluator.state.metrics # # metrics = engine.state.metrics # avg_accuracy = metrics['accuracy'] # avg_nll = metrics['loss'] # print("Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}" # .format(engine.state.epoch, avg_accuracy, avg_nll)) # writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch) # writer.add_scalar("training/avg_accuracy", # avg_accuracy, engine.state.epoch) # writer.add_scalar("training/avg_error", 1. - # avg_accuracy, engine.state.epoch) @trainer.on( Events.ITERATION_COMPLETED(every=200 * 5000 // batch_size // 10)) def validation_value(engine): metrics = valid_evaluator.state.metrics valid_avg_accuracy = metrics['accuracy'] return valid_avg_accuracy to_save = {'model': model} handler = Checkpoint( to_save, DiskSaver(os.path.join(base_data, model_id), create_dir=True), score_function=validation_value, score_name="val_acc", global_step_transform=global_step_from_engine(trainer), n_saved=None) # kick everything off trainer.add_event_handler( Events.ITERATION_COMPLETED(every=200 * 5000 // batch_size // 10), handler) trainer.run(ds_train, max_epochs=max_epochs)
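The core of the adversarial loop above is: craft the attack with parameters frozen, then optimize on the clean plus adversarial samples. A self-contained sketch of one such step, assuming the advertorch helpers the original appears to use:

import torch
import torch.nn.functional as F
from advertorch.attacks import GradientSignAttack
from advertorch.context import ctx_noparamgrad_and_eval

def adversarial_train_step(model, optimizer, attack, x, y, device="cuda"):
    model.train()
    x, y = x.to(device), y.to(device)
    with ctx_noparamgrad_and_eval(model):   # no parameter grads while crafting x_adv
        x_adv = attack.perturb(x, y)
    logits = model(torch.cat((x, x_adv)))
    loss = F.cross_entropy(logits, torch.cat((y, y)))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()

# attack = GradientSignAttack(model, loss_fn=F.cross_entropy, eps=0.2)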
def _test(milestones_as_np_int): tensor = torch.zeros([1], requires_grad=True) optimizer = torch.optim.SGD([tensor], lr=0) milestones_values = [(5, 0.5), (15, 1.0), (25, 0.0), (35, 1.0), (40, 0.5)] if milestones_as_np_int: milestones_values = [(np.int64(t), v) for t, v in milestones_values] scheduler = PiecewiseLinear(optimizer, "lr", milestones_values=milestones_values) state_dict = scheduler.state_dict() def save_lr(engine): lrs.append(optimizer.param_groups[0]["lr"]) trainer = Engine(lambda engine, batch: None) trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler) trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr) for _ in range(2): lrs = [] trainer.run([0] * 25, max_epochs=2) assert lrs == list( map( pytest.approx, [ 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, ], ) ) scheduler.load_state_dict(state_dict)
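PiecewiseLinear keeps the first milestone value until its event index and then interpolates linearly between consecutive (event, value) pairs; with the milestones above, event 10 sits halfway between (5, 0.5) and (15, 1.0), giving lr = 0.75. A small sketch using the same simulate_values helper (the import path may be ignite.contrib.handlers on older releases):

from ignite.handlers.param_scheduler import PiecewiseLinear

values = PiecewiseLinear.simulate_values(
    num_events=16, param_name="lr",
    milestones_values=[(5, 0.5), (15, 1.0), (25, 0.0)]
)
# values[10] -> [10, 0.75]
print(values[10])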
def train(run_name, forward_func, model, train_set, val_set, n_epochs, batch_size, lr): # Make the run directory save_dir = os.path.join('training/simple/saved_runs', run_name) if run_name == 'debug': shutil.rmtree(save_dir, ignore_errors=True) os.mkdir(save_dir) model = model.to(device) train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True) val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True) optimizer = torch.optim.Adam(model.parameters(), lr=lr) # Training step def step(engine, batch): model.train() if isinstance(batch, list): batch = [tensor.to(device) for tensor in batch] else: batch = batch.to(device) x_gen, x_q, _ = forward_func(model, batch) loss = F.l1_loss(x_gen, x_q) loss.backward() optimizer.step() optimizer.zero_grad() return {'L1': loss} # Trainer and metrics trainer = Engine(step) metric_names = ['L1'] RunningAverage(output_transform=lambda x: x['L1']).attach(trainer, 'L1') ProgressBar().attach(trainer, metric_names=metric_names) Timer(average=True).attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) # Model checkpointing checkpoint_handler = ModelCheckpoint(os.path.join(save_dir, 'checkpoints'), type(model).__name__, save_interval=1, n_saved=3, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'model': model, 'optimizer': optimizer }) # Tensorbard writer writer = SummaryWriter(log_dir=os.path.join(save_dir, 'logs')) @trainer.on(Events.ITERATION_COMPLETED) def log_metrics(engine): if engine.state.iteration % 100 == 0: for metric, value in engine.state.metrics.items(): writer.add_scalar('training/{}'.format(metric), value, engine.state.iteration) def save_images(engine, batch): x_gen, x_q, r = forward_func(model, batch) r_dim = r.shape[1] if isinstance(model, SimpleVVGQN): r = (r + 1) / 2 r = r.view(-1, 1, int(math.sqrt(r_dim)), int(math.sqrt(r_dim))) x_gen = x_gen.detach().cpu().float() r = r.detach().cpu().float() writer.add_image('representation', make_grid(r), engine.state.epoch) writer.add_image('generation', make_grid(x_gen), engine.state.epoch) writer.add_image('query', make_grid(x_q), engine.state.epoch) @trainer.on(Events.EPOCH_COMPLETED) def validate(engine): model.eval() with torch.no_grad(): batch = next(iter(val_loader)) if isinstance(batch, list): batch = [tensor.to(device) for tensor in batch] else: batch = batch.to(device) x_gen, x_q, r = forward_func(model, batch) loss = F.l1_loss(x_gen, x_q) writer.add_scalar('validation/L1', loss.item(), engine.state.epoch) save_images(engine, batch) @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): writer.close() engine.terminate() if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): import warnings warnings.warn('KeyboardInterrupt caught. Exiting gracefully.') checkpoint_handler(engine, {'model_exception': model}) else: raise e start_time = time.time() trainer.run(train_loader, n_epochs) writer.close() end_time = time.time() print('Total training time: {}'.format( timedelta(seconds=end_time - start_time)))
def main( batch_size, epochs, length_scale, centroid_size, model_output_size, learning_rate, l_gradient_penalty, gamma, weight_decay, final_model, ): name = f"DUQ_{length_scale}__{l_gradient_penalty}_{gamma}_{centroid_size}" writer = SummaryWriter(comment=name) ds = all_datasets["CIFAR10"]() input_size, num_classes, dataset, test_dataset = ds # Split up training set idx = list(range(len(dataset))) random.shuffle(idx) if final_model: train_dataset = dataset val_dataset = test_dataset else: val_size = int(len(dataset) * 0.8) train_dataset = torch.utils.data.Subset(dataset, idx[:val_size]) val_dataset = torch.utils.data.Subset(dataset, idx[val_size:]) val_dataset.transform = (test_dataset.transform ) # Test time preprocessing for validation model = ResNet_DUQ(input_size, num_classes, centroid_size, model_output_size, length_scale, gamma) model = model.cuda() optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50, 75], gamma=0.2) def bce_loss_fn(y_pred, y): bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div( num_classes * y_pred.shape[0]) return bce def output_transform_bce(output): y_pred, y, x = output y = F.one_hot(y, num_classes).float() return y_pred, y def output_transform_acc(output): y_pred, y, x = output return y_pred, y def output_transform_gp(output): y_pred, y, x = output return x, y_pred def calc_gradients_input(x, y_pred): gradients = torch.autograd.grad( outputs=y_pred, inputs=x, grad_outputs=torch.ones_like(y_pred), create_graph=True, )[0] gradients = gradients.flatten(start_dim=1) return gradients def calc_gradient_penalty(x, y_pred): gradients = calc_gradients_input(x, y_pred) # L2 norm grad_norm = gradients.norm(2, dim=1) # Two sided penalty gradient_penalty = ((grad_norm - 1)**2).mean() return gradient_penalty def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x, y = x.cuda(), y.cuda() if l_gradient_penalty > 0: x.requires_grad_(True) z, y_pred = model(x) y = F.one_hot(y, num_classes).float() loss = bce_loss_fn(y_pred, y) if l_gradient_penalty > 0: loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred) loss.backward() optimizer.step() x.requires_grad_(False) with torch.no_grad(): model.eval() model.update_embeddings(x, y) return loss.item() def eval_step(engine, batch): model.eval() x, y = batch x, y = x.cuda(), y.cuda() x.requires_grad_(True) z, y_pred = model(x) return y_pred, y, x trainer = Engine(step) evaluator = Engine(eval_step) metric = Average() metric.attach(trainer, "loss") metric = Accuracy(output_transform=output_transform_acc) metric.attach(evaluator, "accuracy") metric = Loss(F.binary_cross_entropy, output_transform=output_transform_bce) metric.attach(evaluator, "bce") metric = Loss(calc_gradient_penalty, output_transform=output_transform_gp) metric.attach(evaluator, "gradient_penalty") kwargs = {"num_workers": 4, "pin_memory": True} train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs) val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1000, shuffle=False, **kwargs) test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, **kwargs) @trainer.on(Events.EPOCH_COMPLETED) def log_results(trainer): metrics = trainer.state.metrics loss = metrics["loss"] print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f} ") writer.add_scalar("Loss/train", loss, trainer.state.epoch) if 
trainer.state.epoch % 5 == 0 or trainer.state.epoch > 65: accuracy, auroc = get_cifar_svhn_ood(model) print(f"Test Accuracy: {accuracy}, AUROC: {auroc}") writer.add_scalar("OoD/test_accuracy", accuracy, trainer.state.epoch) writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch) accuracy, auroc = get_auroc_classification(val_dataset, model) print(f"AUROC - uncertainty: {auroc}") writer.add_scalar("OoD/val_accuracy", accuracy, trainer.state.epoch) writer.add_scalar("OoD/roc_auc_classification", auroc, trainer.state.epoch) evaluator.run(val_loader) metrics = evaluator.state.metrics acc = metrics["accuracy"] bce = metrics["bce"] GP = metrics["gradient_penalty"] loss = bce + l_gradient_penalty * GP print((f"Valid - Epoch: {trainer.state.epoch} " f"Acc: {acc:.4f} " f"Loss: {loss:.2f} " f"BCE: {bce:.2f} " f"GP: {GP:.2f} ")) writer.add_scalar("Loss/valid", loss, trainer.state.epoch) writer.add_scalar("BCE/valid", bce, trainer.state.epoch) writer.add_scalar("GP/valid", GP, trainer.state.epoch) writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch) print(f"Centroid norm: {torch.norm(model.m / model.N, dim=0)}") scheduler.step() if trainer.state.epoch > 65: torch.save(model.state_dict(), f"saved_models/{name}_{trainer.state.epoch}.pt") pbar = ProgressBar(dynamic_ncols=True) pbar.attach(trainer) trainer.run(train_loader, max_epochs=epochs) evaluator.run(test_loader) acc = evaluator.state.metrics["accuracy"] print(f"Test - Accuracy {acc:.4f}") writer.close()
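The evaluator above returns the full (y_pred, y, x) triple and lets each metric's output_transform slice out what it needs; a minimal sketch of that wiring:

import torch.nn.functional as F
from ignite.engine import Engine
from ignite.metrics import Accuracy, Loss

def build_evaluator(eval_step):
    # eval_step is expected to return (y_pred, y, x).
    evaluator = Engine(eval_step)
    Accuracy(output_transform=lambda out: (out[0], out[1])).attach(evaluator, "accuracy")
    Loss(F.cross_entropy, output_transform=lambda out: (out[0], out[1])).attach(evaluator, "loss")
    return evaluator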
def _test(metric_device): data = list(range(n_iters)) np.random.seed(12) all_y_true_batch_values = np.random.randint( 0, n_classes, size=(idist.get_world_size(), n_epochs * n_iters, batch_size)) all_y_pred_batch_values = np.random.rand(idist.get_world_size(), n_epochs * n_iters, batch_size, n_classes) y_true_batch_values = iter(all_y_true_batch_values[rank, ...]) y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...]) def update_fn(engine, batch): y_true_batch = next(y_true_batch_values) y_pred_batch = next(y_pred_batch_values) return torch.from_numpy(y_pred_batch), torch.from_numpy( y_true_batch) trainer = Engine(update_fn) alpha = 0.98 acc_metric = RunningAverage(Accuracy( output_transform=lambda x: [x[0], x[1]], device=metric_device), alpha=alpha, epoch_bound=False) acc_metric.attach(trainer, "running_avg_accuracy") running_avg_acc = [ None, ] true_acc_metric = Accuracy(device=metric_device) @trainer.on(Events.ITERATION_COMPLETED) def manual_running_avg_acc(engine): i = engine.state.iteration - 1 true_acc_metric.reset() for j in range(idist.get_world_size()): output = ( torch.from_numpy(all_y_pred_batch_values[j, i, :, :]), torch.from_numpy(all_y_true_batch_values[j, i, :]), ) true_acc_metric.update(output) batch_acc = true_acc_metric._num_correct.item( ) * 1.0 / true_acc_metric._num_examples if running_avg_acc[0] is None: running_avg_acc[0] = batch_acc else: running_avg_acc[0] = running_avg_acc[0] * alpha + ( 1.0 - alpha) * batch_acc engine.state.running_avg_acc = running_avg_acc[0] @trainer.on(Events.ITERATION_COMPLETED) def assert_equal_running_avg_acc_values(engine): assert engine.state.running_avg_acc == engine.state.metrics[ "running_avg_accuracy"], "{} vs {}".format( engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"]) trainer.run(data, max_epochs=3)
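The manual reference computation in the test is just the exponential moving average that RunningAverage maintains: avg <- alpha * avg + (1 - alpha) * batch_value, seeded with the first batch value. A tiny standalone illustration:

def running_average(batch_values, alpha=0.98):
    avg = None
    for v in batch_values:
        avg = v if avg is None else alpha * avg + (1.0 - alpha) * v
        yield avg

print(list(running_average([0.5, 0.7, 0.6])))  # [0.5, 0.504, 0.50592]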
def train_model( name="", resume="", base_dir=utils.BASE_DIR, model_name="v0", chosen_diseases=None, n_epochs=10, batch_size=4, oversample=False, max_os=None, shuffle=False, opt="sgd", opt_params={}, loss_name="wbce", loss_params={}, train_resnet=False, log_metrics=None, flush_secs=120, train_max_images=None, val_max_images=None, test_max_images=None, experiment_mode="debug", save=True, save_cms=True, # Note that in this case, save_cms (to disk) includes write_cms (to TB) write_graph=False, write_emb=False, write_emb_img=False, write_img=False, image_format="RGB", multiple_gpu=False, ): # Choose GPU device = utilsT.get_torch_device() print("Using device: ", device) # Common folders dataset_dir = os.path.join(base_dir, "dataset") # Dataset handling print("Loading train dataset...") train_dataset, train_dataloader = utilsT.prepare_data( dataset_dir, "train", chosen_diseases, batch_size, oversample=oversample, max_os=max_os, shuffle=shuffle, max_images=train_max_images, image_format=image_format, ) train_samples, _ = train_dataset.size() print("Loading val dataset...") val_dataset, val_dataloader = utilsT.prepare_data( dataset_dir, "val", chosen_diseases, batch_size, max_images=val_max_images, image_format=image_format, ) val_samples, _ = val_dataset.size() # Should be the same than chosen_diseases chosen_diseases = list(train_dataset.classes) print("Chosen diseases: ", chosen_diseases) if resume: # Load model and optimizer model, model_name, optimizer, opt, loss_name, loss_params, chosen_diseases = models.load_model( base_dir, resume, experiment_mode="", device=device) model.train(True) else: # Create model model = models.init_empty_model(model_name, chosen_diseases, train_resnet=train_resnet).to(device) # Create optimizer OptClass = optimizers.get_optimizer_class(opt) optimizer = OptClass(model.parameters(), **opt_params) # print("OPT: ", opt_params) # Allow multiple GPUs if multiple_gpu: model = DataParallel(model) # Tensorboard log options run_name = utils.get_timestamp() if name: run_name += "_{}".format(name) if len(chosen_diseases) == 1: run_name += "_{}".format(chosen_diseases[0]) elif len(chosen_diseases) == 14: run_name += "_all" log_dir = get_log_dir(base_dir, run_name, experiment_mode=experiment_mode) print("Run name: ", run_name) print("Saved TB in: ", log_dir) writer = SummaryWriter(log_dir=log_dir, flush_secs=flush_secs) # Create validator engine validator = Engine( utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params, False)) val_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1) val_loss.attach(validator, loss_name) utilsT.attach_metrics(validator, chosen_diseases, "prec", Precision, True) utilsT.attach_metrics(validator, chosen_diseases, "recall", Recall, True) utilsT.attach_metrics(validator, chosen_diseases, "acc", Accuracy, True) utilsT.attach_metrics(validator, chosen_diseases, "roc_auc", utilsT.RocAucMetric, False) utilsT.attach_metrics(validator, chosen_diseases, "cm", ConfusionMatrix, get_transform_fn=utilsT.get_transform_cm, metric_args=(2, )) utilsT.attach_metrics(validator, chosen_diseases, "positives", RunningAverage, get_transform_fn=utilsT.get_count_positives) # Create trainer engine trainer = Engine( utilsT.get_step_fn(model, optimizer, device, loss_name, loss_params, True)) train_loss = RunningAverage(output_transform=lambda x: x[0], alpha=1) train_loss.attach(trainer, loss_name) utilsT.attach_metrics(trainer, chosen_diseases, "acc", Accuracy, True) utilsT.attach_metrics(trainer, chosen_diseases, "prec", Precision, True) 
utilsT.attach_metrics(trainer, chosen_diseases, "recall", Recall, True) utilsT.attach_metrics(trainer, chosen_diseases, "roc_auc", utilsT.RocAucMetric, False) utilsT.attach_metrics(trainer, chosen_diseases, "cm", ConfusionMatrix, get_transform_fn=utilsT.get_transform_cm, metric_args=(2, )) utilsT.attach_metrics(trainer, chosen_diseases, "positives", RunningAverage, get_transform_fn=utilsT.get_count_positives) timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED) # TODO: Early stopping # def score_function(engine): # val_loss = engine.state.metrics[loss_name] # return -val_loss # handler = EarlyStopping(patience=10, score_function=score_function, trainer=trainer) # validator.add_event_handler(Events.COMPLETED, handler) # Metrics callbacks if log_metrics is None: log_metrics = list(ALL_METRICS) def _write_metrics(run_type, metrics, epoch, wall_time): loss = metrics.get(loss_name, 0) writer.add_scalar("Loss/" + run_type, loss, epoch, wall_time) for metric_base_name in log_metrics: for disease in chosen_diseases: metric_value = metrics.get( "{}_{}".format(metric_base_name, disease), -1) writer.add_scalar( "{}_{}/{}".format(metric_base_name, disease, run_type), metric_value, epoch, wall_time) @trainer.on(Events.EPOCH_COMPLETED) def tb_write_metrics(trainer): epoch = trainer.state.epoch max_epochs = trainer.state.max_epochs # Run on evaluation validator.run(val_dataloader, 1) # Common time wall_time = time.time() # Log all metrics to TB _write_metrics("train", trainer.state.metrics, epoch, wall_time) _write_metrics("val", validator.state.metrics, epoch, wall_time) train_loss = trainer.state.metrics.get(loss_name, 0) val_loss = validator.state.metrics.get(loss_name, 0) tb_write_histogram(writer, model, epoch, wall_time) print("Finished epoch {}/{}, loss {:.3f}, val loss {:.3f} (took {})". 
format(epoch, max_epochs, train_loss, val_loss, utils.duration_to_str(int(timer._elapsed())))) # Hparam dict hparam_dict = { "resume": resume, "n_diseases": len(chosen_diseases), "diseases": ",".join(chosen_diseases), "n_epochs": n_epochs, "batch_size": batch_size, "shuffle": shuffle, "model_name": model_name, "opt": opt, "loss": loss_name, "samples (train, val)": "{},{}".format(train_samples, val_samples), "train_resnet": train_resnet, "multiple_gpu": multiple_gpu, } def copy_params(params_dict, base_name): for name, value in params_dict.items(): hparam_dict["{}_{}".format(base_name, name)] = value copy_params(loss_params, "loss") copy_params(opt_params, "opt") print("HPARAM: ", hparam_dict) # Train print("-" * 50) print("Training...") trainer.run(train_dataloader, n_epochs) # Capture time secs_per_epoch = timer.value() duration_per_epoch = utils.duration_to_str(int(secs_per_epoch)) print("Average time per epoch: ", duration_per_epoch) print("-" * 50) ## Write all hparams hparam_dict["duration_per_epoch"] = duration_per_epoch # FIXME: this is commented to avoid having too many hparams in TB frontend # metrics # def copy_metrics(engine, engine_name): # for metric_name, metric_value in engine.state.metrics.items(): # hparam_dict["{}_{}".format(engine_name, metric_name)] = metric_value # copy_metrics(trainer, "train") # copy_metrics(validator, "val") print("Writing TB hparams") writer.add_hparams(hparam_dict, {}) # Save model to disk if save: print("Saving model...") models.save_model(base_dir, run_name, model_name, experiment_mode, hparam_dict, trainer, model, optimizer) # Write graph to TB if write_graph: print("Writing TB graph...") tb_write_graph(writer, model, train_dataloader, device) # Write embeddings to TB if write_emb: print("Writing TB embeddings...") image_size = 256 if write_emb_img else 0 # FIXME: be able to select images (balanced, train vs val, etc) image_list = list(train_dataset.label_index["FileName"])[:1000] # disease = chosen_diseases[0] # positive = train_dataset.label_index[train_dataset.label_index[disease] == 1] # negative = train_dataset.label_index[train_dataset.label_index[disease] == 0] # positive_images = list(positive["FileName"])[:25] # negative_images = list(negative["FileName"])[:25] # image_list = positive_images + negative_images all_images, all_embeddings, all_predictions, all_ground_truths = gen_embeddings( model, train_dataset, device, image_list=image_list, image_size=image_size) tb_write_embeddings( writer, chosen_diseases, all_images, all_embeddings, all_predictions, all_ground_truths, global_step=n_epochs, use_images=write_emb_img, tag="1000_{}".format("img" if write_emb_img else "no_img"), ) # Save confusion matrices (is expensive to calculate them afterwards) if save_cms: print("Saving confusion matrices...") # Assure folder cms_dir = os.path.join(base_dir, "cms", experiment_mode) os.makedirs(cms_dir, exist_ok=True) base_fname = os.path.join(cms_dir, run_name) n_diseases = len(chosen_diseases) def extract_cms(metrics): """Extract confusion matrices from a metrics dict.""" cms = [] for disease in chosen_diseases: key = "cm_" + disease if key not in metrics: cm = np.array([[-1, -1], [-1, -1]]) else: cm = metrics[key].numpy() cms.append(cm) return np.array(cms) # Train confusion matrix train_cms = extract_cms(trainer.state.metrics) np.save(base_fname + "_train", train_cms) tb_write_cms(writer, "train", chosen_diseases, train_cms) # Validation confusion matrix val_cms = extract_cms(validator.state.metrics) np.save(base_fname + "_val", val_cms) 
tb_write_cms(writer, "val", chosen_diseases, val_cms) # All confusion matrix (train + val) all_cms = train_cms + val_cms np.save(base_fname + "_all", all_cms) # Print to console if len(chosen_diseases) == 1: print("Train CM: ") print(train_cms[0]) print("Val CM: ") print(val_cms[0]) # print("Train CM 2: ") # print(trainer.state.metrics["cm_" + chosen_diseases[0]]) # print("Val CM 2: ") # print(validator.state.metrics["cm_" + chosen_diseases[0]]) if write_img: # NOTE: this option is not recommended, use Testing notebook to plot and analyze images print("Writing images to TB...") test_dataset, test_dataloader = utilsT.prepare_data( dataset_dir, "test", chosen_diseases, batch_size, max_images=test_max_images, ) # TODO: add a way to select images? # image_list = list(test_dataset.label_index["FileName"])[:3] # Examples in test_dataset (with bboxes available): image_list = [ # "00010277_000.png", # (Effusion, Infiltrate, Mass, Pneumonia) # "00018427_004.png", # (Atelectasis, Effusion, Mass) # "00021703_001.png", # (Atelectasis, Effusion, Infiltrate) # "00028640_008.png", # (Effusion, Infiltrate) # "00019124_104.png", # (Pneumothorax) # "00019124_090.png", # (Nodule) # "00020318_007.png", # (Pneumothorax) "00000003_000.png", # (0) # "00000003_001.png", # (0) # "00000003_002.png", # (0) "00000732_005.png", # (Cardiomegaly, Pneumothorax) # "00012261_001.png", # (Cardiomegaly, Pneumonia) # "00013249_033.png", # (Cardiomegaly, Pneumonia) # "00029808_003.png", # (Cardiomegaly, Pneumonia) # "00022215_012.png", # (Cardiomegaly, Pneumonia) # "00011402_007.png", # (Cardiomegaly, Pneumonia) # "00019018_007.png", # (Cardiomegaly, Infiltrate) # "00021009_001.png", # (Cardiomegaly, Infiltrate) # "00013670_151.png", # (Cardiomegaly, Infiltrate) # "00005066_030.png", # (Cardiomegaly, Infiltrate, Effusion) "00012288_000.png", # (Cardiomegaly) "00008399_007.png", # (Cardiomegaly) "00005532_000.png", # (Cardiomegaly) "00005532_014.png", # (Cardiomegaly) "00005532_016.png", # (Cardiomegaly) "00005827_000.png", # (Cardiomegaly) # "00006912_007.png", # (Cardiomegaly) # "00007037_000.png", # (Cardiomegaly) # "00007043_000.png", # (Cardiomegaly) # "00012741_004.png", # (Cardiomegaly) # "00007551_020.png", # (Cardiomegaly) # "00007735_040.png", # (Cardiomegaly) # "00008339_010.png", # (Cardiomegaly) # "00008365_000.png", # (Cardiomegaly) # "00012686_003.png", # (Cardiomegaly) ] tb_write_images(writer, model, test_dataset, chosen_diseases, n_epochs, device, image_list) # Close TB writer if experiment_mode != "debug": writer.close() # Run post_train print("-" * 50) print("Running post_train...") print("Loading test dataset...") test_dataset, test_dataloader = utilsT.prepare_data( dataset_dir, "test", chosen_diseases, batch_size, max_images=test_max_images) save_cms_with_names(run_name, experiment_mode, model, test_dataset, test_dataloader, chosen_diseases) evaluate_model(run_name, model, optimizer, device, loss_name, loss_params, chosen_diseases, test_dataloader, experiment_mode=experiment_mode, base_dir=base_dir) # Return values for debugging model_run = ModelRun(model, run_name, model_name, chosen_diseases) if experiment_mode == "debug": model_run.save_debug_data(writer, trainer, validator, train_dataset, train_dataloader, val_dataset, val_dataloader) return model_run
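The hparam logging at the end uses tensorboard's add_hparams, which takes a dict of hyperparameters and a dict of final metrics (the code above passes an empty metrics dict); a tiny sketch:

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/hparams_example")  # path is illustrative
hparams = {"model_name": "v0", "batch_size": 4, "loss": "wbce"}
writer.add_hparams(hparams, {"final/val_loss": 0.42})
writer.close()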
def main(config, needs_save, study_name, k, n_splits): if config.run.visible_devices: os.environ['CUDA_VISIBLE_DEVICES'] = config.run.visible_devices seed = check_manual_seed(config.run.seed) print('Using seed: {}'.format(seed)) train_data_loader, test_data_loader, data_train = get_k_hold_data_loader( config.dataset, k=k, n_splits=n_splits, ) data_train = torch.from_numpy(data_train).float().cuda(non_blocking=True) data_train = torch.t(data_train) model = get_model(config.model) model.cuda() model = nn.DataParallel(model) print('count params: ', count_parameters(model.module)) saved_model_path, _, _ = get_saved_model_path( config, study_name, config.model.checkpoint_epoch, k, n_splits, ) model.load_state_dict(torch.load(saved_model_path)['model']) model.eval() if config.model.model_name == 'MLP': embedding = model.module.get_embedding() elif config.model.model_name == 'ModifiedMLP': embedding = model.module.get_embedding() elif config.model.model_name == 'DietNetworks': embedding = model.module.get_embedding(data_train) elif config.model.model_name == 'ModifiedDietNetworks': embedding = model.module.get_embedding(data_train) embedding = embedding.detach().cpu().numpy() emb_pca = PCA(n_components=2) emb_pca.fit_transform(embedding) if config.run.decomp == '1D': print('Approximate by 1D PCA') axis_1= torch.from_numpy(emb_pca.components_[0]) score_1 = np.dot(embedding, axis_1) approx = np.outer(score_1, axis_1) elif config.run.decomp == '2D': print('Approximate by 2D PCA') axis_1= torch.from_numpy(emb_pca.components_[0]) score_1 = np.dot(embedding, axis_1) axis_2= torch.from_numpy(emb_pca.components_[1]) score_2 = np.dot(embedding, axis_2) approx = np.outer(score_1, axis_1) + np.outer(score_2, axis_2) # approx = np.outer(score_2, axis_2) approx = torch.from_numpy(approx).float().cuda(non_blocking=True) criterion = nn.CrossEntropyLoss() def inference(engine, batch): x = batch['data'].float().cuda(non_blocking=True) y = batch['label'].long().cuda(non_blocking=True) assert config.run.transposed_matrix == 'overall' x_t = data_train with torch.no_grad(): out, _ = model.module.approx(x, approx) l_discriminative = criterion(out, y) l_total = l_discriminative metrics = calc_metrics(out, y) metrics.update({ 'l_total': l_total.item(), 'l_discriminative': l_discriminative.item(), }) torch.cuda.synchronize() return metrics evaluator = Engine(inference) monitoring_metrics = ['l_total', 'l_discriminative', 'accuracy'] for metric in monitoring_metrics: RunningAverage( alpha=0.98, output_transform=partial(lambda x, metric: x[metric], metric=metric) ).attach(evaluator, metric) pbar = ProgressBar() pbar.attach(evaluator, metric_names=monitoring_metrics) evaluator.run(test_data_loader, 1) columns = ['k', 'n_splits', 'epoch', 'iteration'] + list(evaluator.state.metrics.keys()) values = [str(k), str(n_splits), str(evaluator.state.epoch), str(evaluator.state.iteration)] \ + [str(value) for value in evaluator.state.metrics.values()] values = {c: v for (c, v) in zip(columns, values)} values.update({ 'variance_ratio_1': emb_pca.explained_variance_ratio_[0], 'variance_ratio_2': emb_pca.explained_variance_ratio_[1], }) return values
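For reference, the 1D/2D approximation above is a low-rank reconstruction of the embedding: project onto the leading principal axes and sum the outer products (note that, like the code above, this projects the raw embedding without subtracting the PCA mean). A small numpy-only sketch:

import numpy as np
from sklearn.decomposition import PCA

def low_rank_approx(embedding: np.ndarray, k: int = 2) -> np.ndarray:
    pca = PCA(n_components=k)
    pca.fit(embedding)
    approx = np.zeros_like(embedding)
    for axis in pca.components_[:k]:
        scores = embedding @ axis          # scalar score per sample on this axis
        approx += np.outer(scores, axis)   # rank-1 contribution
    return approx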