def test_pbar_file(tmp_path):
    n_epochs = 2
    loader = [1, 2]
    engine = Engine(update_fn)

    file_path = tmp_path / "temp.txt"
    file = open(str(file_path), "w+")

    pbar = ProgressBar(file=file)
    pbar.attach(engine, ["a"])
    engine.run(loader, max_epochs=n_epochs)

    file.close()  # Force a flush of the buffer. file.flush() does not work.
    file = open(str(file_path), "r")
    lines = file.readlines()

    if get_tqdm_version() < LooseVersion("4.49.0"):
        expected = "Epoch [2/2]: [1/2]  50%|█████     , a=1 [00:00<00:00]\n"
    else:
        expected = "Epoch [2/2]: [1/2]  50%|█████     , a=1 [00:00<?]\n"
    assert lines[-2] == expected
def test_pbar_wrong_events_order():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.COMPLETED, closing_event_name=Events.ITERATION_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.EPOCH_COMPLETED, closing_event_name=Events.EPOCH_COMPLETED)

    with pytest.raises(ValueError, match="should be called before closing event"):
        pbar.attach(engine, event_name=Events.ITERATION_COMPLETED, closing_event_name=Events.ITERATION_STARTED)

    with pytest.raises(ValueError, match="should not be a filtered event"):
        pbar.attach(engine, event_name=Events.ITERATION_STARTED, closing_event_name=Events.EPOCH_COMPLETED(every=10))
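# `update_fn` and `get_tqdm_version` are helpers defined elsewhere in the test
# module; a minimal sketch of what these tests assume (the exact bodies are an
# assumption, not part of the original snippets): `update_fn` publishes a metric
# "a" so `pbar.attach(engine, ["a"])` has something to render, and
# `get_tqdm_version` returns a comparable version object.
import tqdm
from packaging.version import Version
from ignite.engine import Engine

def update_fn(engine, batch):
    engine.state.metrics["a"] = 1
    return 1

def get_tqdm_version():
    # older snippets compare against distutils' LooseVersion instead
    return Version(tqdm.__version__)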
def fit(self, train_loader, valid_loader, n_epochs):
    trainer = create_supervised_trainer(self.model, self.optim, self.loss_fn, device=self.device)
    evaluator = create_supervised_evaluator(
        self.model, metrics={'loss': Loss(self.loss_fn)}, device=self.device)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
    pbar = ProgressBar(persist=False)
    pbar.attach(trainer, metric_names="all")

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results():
        train_loss = trainer.state.metrics['loss']
        evaluator.run(valid_loader)
        valid_loss = evaluator.state.metrics['loss']

        self.history['train_loss'].append(train_loss)
        self.history['valid_loss'].append(valid_loss)

        if valid_loss < self.best_loss:
            self.best_loss = valid_loss
            self.best_epoch = trainer.state.epoch
            self.best_model = deepcopy(self.model.state_dict())

        template = "Epoch [%3d/%3d] >> train_loss = %.4f, valid_loss = %.4f, "
        template += "lowest_loss = %.4f @epoch = %d"
        # Report the running-average train loss; the original passed
        # trainer.state.output, which is only the last batch's loss.
        pbar.log_message(template % (trainer.state.epoch, trainer.state.max_epochs,
                                     train_loss, valid_loss,
                                     self.best_loss, self.best_epoch))

    trainer.run(train_loader, max_epochs=n_epochs)
    self.model.load_state_dict(self.best_model)
def start_to_learn(trainer, train_loader, tester, test_loader, epochs, model_metrics):
    # ---Log Message Initializing---
    log_msg = ProgressBar(persist=True, bar_format=" ")

    # ---After Training starts Testing---
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        tester.run(test_loader)
        epoch = engine.state.epoch
        metrics = tester.state.metrics
        model_metrics.update({epoch: metrics})
        # Arguments now follow the placeholder order; the original swapped
        # precision and recall.
        log_msg.log_message("Epoch: {} \nAccuracy: {:.3f} \tLoss: {:.3f} \tRecall: {:.3f} \tPrecision: {:.3f}\n\n"
                            .format(epoch, metrics['accuracy'], metrics['loss'],
                                    metrics['recall'], metrics['precision']))
        log_msg.n = log_msg.last_print_n = 0

    start = time.time()
    trainer.run(train_loader, epochs)  # Training Starting
    duration = time.time() - start
    print('Duration of execution: ', duration)
    return duration
def test_pbar_with_state_attrs(capsys):
    n_iters = 2
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))

    def step(engine, batch):
        loss_value = next(loss_values)
        return loss_value

    trainer = Engine(step)
    trainer.state.alpha = 3.899
    trainer.state.beta = torch.tensor(12.21)
    trainer.state.gamma = torch.tensor([21.0, 6.0])

    RunningAverage(alpha=0.5, output_transform=lambda x: x).attach(trainer, "batchloss")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=["batchloss"], state_attributes=["alpha", "beta", "gamma"])

    trainer.run(data=data, max_epochs=1)

    captured = capsys.readouterr()
    err = captured.err.split("\r")
    err = list(map(lambda x: x.strip(), err))
    err = list(filter(None, err))
    actual = err[-1]
    if get_tqdm_version() < Version("4.49.0"):
        expected = (
            "Iteration: [1/2]  50%|█████     , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<00:00]"
        )
    else:
        expected = (
            "Iteration: [1/2]  50%|█████     , batchloss=0.5, alpha=3.9, beta=12.2, gamma_0=21, gamma_1=6 [00:00<?]"
        )
    assert actual == expected
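# Outside of a test, `state_attributes` renders custom Engine.state fields the
# same way; a minimal sketch (all names illustrative; requires an ignite version
# whose ProgressBar.attach supports `state_attributes`):
from ignite.engine import Engine
from ignite.contrib.handlers import ProgressBar

trainer = Engine(lambda engine, batch: None)
trainer.state.alpha = 0.1
ProgressBar().attach(trainer, state_attributes=["alpha"])
trainer.run([0, 1, 2], max_epochs=1)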
def build_trainer(self) -> Engine:
    loss_fn: callable = F.nll_loss
    optimizer: torch.optim.Adam = torch.optim.Adam(self.parameters(), 1e-3)
    model = self

    def process_function(engine: Engine, batch: Tuple[torch.Tensor, torch.Tensor, List[int]]) -> \
            Tuple[float, torch.Tensor, torch.Tensor]:
        """Single training loop to be attached to trainer Engine"""
        model.train()
        optimizer.zero_grad()
        x, y, lengths = batch
        x, y = x.to(model.device), y.to(model.device)
        y_pred: torch.Tensor = model(x, lengths)
        loss: torch.Tensor = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item(), torch.max(y_pred, dim=1)[1], y

    def eval_function(engine: Engine, batch: Tuple[torch.Tensor, torch.Tensor, List[int]]) -> \
            Tuple[torch.Tensor, torch.Tensor]:
        """Single evaluator loop to be attached to trainer and evaluator Engine"""
        model.eval()
        with torch.no_grad():
            x, y, lengths = batch
            x, y = x.to(model.device), y.to(model.device)
            y_pred: torch.Tensor = model(x, lengths)
            return y_pred, y

    trainer: Engine = Engine(process_function)
    train_evaluator: Engine = Engine(eval_function)
    validation_evaluator: Engine = Engine(eval_function)

    ConcatPoolingGRUAdaptive.track_progress(train_evaluator, validation_evaluator, loss_fn, trainer)

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['loss', 'acc'])

    self.log_results(train_evaluator, validation_evaluator, pbar, trainer)
    return trainer
def create_segmentation_evaluator(
        model,
        device,
        num_classes=19,
        loss_fn=None,
        non_blocking=True):
    cm = partial(ConfusionMatrix, num_classes)
    metrics = {
        'iou': IoU(cm()),
        'miou': mIoU(cm()),
        'accuracy': cmAccuracy(cm()),
        'dice': DiceCoefficient(cm()),
    }
    if loss_fn is not None:
        metrics['loss'] = Loss(loss_fn)

    evaluator = create_supervised_evaluator(model, metrics, device, non_blocking)

    ProgressBar(persist=False).attach(evaluator)

    return evaluator
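# A minimal, self-contained usage sketch of the factory above, with a toy
# single-conv "segmentation" model and random data (every name below is
# illustrative, not from the original project):
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

toy_model = nn.Conv2d(3, 19, kernel_size=1)  # per-pixel logits over 19 classes
images = torch.randn(4, 3, 8, 8)
masks = torch.randint(0, 19, (4, 8, 8))
loader = DataLoader(TensorDataset(images, masks), batch_size=2)

evaluator = create_segmentation_evaluator(toy_model, device='cpu', num_classes=19)
state = evaluator.run(loader)
print(state.metrics['miou'], state.metrics['accuracy'])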
def thresholded_output_transform(output):
    # Reconstructed header: only the body of this transform survived in the
    # excerpt; the name matches its use in the attach calls below.
    y_pred, y = output
    y_pred = torch.round(y_pred)
    return y_pred, y

RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

# validation
Accuracy(output_transform=thresholded_output_transform).attach(train_evaluator, 'accuracy')
Loss(criterion).attach(train_evaluator, 'bce')

# test
Accuracy(output_transform=thresholded_output_transform).attach(validation_evaluator, 'accuracy')
Loss(criterion).attach(validation_evaluator, 'bce')

pbar = ProgressBar(persist=True, bar_format="")
pbar.attach(trainer, ['loss'])

trainer.run(train_iterator, max_epochs=N_EPOCH)


def binary_accuracy(preds, y):
    """
    Returns accuracy per batch, i.e. if you get 8/10 right, this returns 0.8, NOT 8
    """
    # round predictions to the closest integer
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()  # convert into float for division
    acc = correct.sum() / len(correct)
    return acc
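# Quick sanity check of `binary_accuracy` (assumes `preds` are raw logits, as
# the sigmoid inside implies); the toy tensors are illustrative:
import torch

preds = torch.tensor([2.0, -1.0, 0.5, -3.0])
y = torch.tensor([1.0, 0.0, 1.0, 1.0])
print(binary_accuracy(preds, y))  # tensor(0.7500): 3 of 4 thresholded predictions match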
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):
    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )
    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # EMA on parameters
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

    # TODO: Need to inspect a bug
    # if idist.get_rank() == 0:
    #     from ignite.contrib.handlers import ProgressBar
    #
    #     profiler = BasicTimeProfiler()
    #     profiler.attach(trainer)
    #
    #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
    #     def log_profiling(_):
    #         results = profiler.get_results()
    #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()
            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
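# `cycle`, `to_list_str`, `sup_prepare_batch`, `deserialize` and `stats` come
# from the surrounding project. `cycle` is presumably an endless re-iterating
# helper (itertools.cycle would cache a DataLoader's first epoch); a plausible
# sketch, offered as an assumption:
def cycle(iterable):
    while True:
        for item in iterable:
            yield item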
def train():
    config_file = "configs/train_daily_dialog_emotion_action_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", config.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids, token_action_ids)
        loss = (lm_loss * config.lm_coef + mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()
    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)
    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr), (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                    output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition,
         y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda,
         n_init_batches, warmup_steps, output_dir, saved_optimizer, warmup,
         fresh, logittransform, gan, disc_lr, sn, flowgan, eval_every,
         ld_on_samples, weight_gan, weight_prior, weight_logdet, jac_reg_lambda,
         affine_eps, no_warm_up, optim_name, clamp, svd_every, eval_only,
         no_actnorm, affine_scale_eps, actnorm_max_scale, no_conv_actnorm,
         affine_max_scale, actnorm_eps, init_sample, no_split, disc_arch,
         weight_entropy_reg, db):
    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition, logittransform, sn, affine_eps,
                 no_actnorm, affine_scale_eps, actnorm_max_scale,
                 no_conv_actnorm, affine_max_scale, actnorm_eps, no_split)
    model = model.to(device)

    if disc_arch == 'mine':
        discriminator = mine.Discriminator(image_shape[-1])
    elif disc_arch == 'biggan':
        discriminator = cgan_models.Discriminator(
            image_channels=image_shape[-1], conditional_D=False)
    elif disc_arch == 'dcgan':
        discriminator = DCGANDiscriminator(image_shape[0], 64, image_shape[-1])
    elif disc_arch == 'inv':
        discriminator = InvDiscriminator(
            image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation,
            flow_coupling, LU_decomposed, num_classes, learn_top, y_condition,
            logittransform, sn, affine_eps, no_actnorm, affine_scale_eps,
            actnorm_max_scale, no_conv_actnorm, affine_max_scale, actnorm_eps,
            no_split)

    discriminator = discriminator.to(device)
    D_optimizer = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()),
                             lr=disc_lr, betas=(.5, .99), weight_decay=0)

    if optim_name == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99), weight_decay=0)
    elif optim_name == 'adamax':
        optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    if not no_warm_up:
        lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    iteration_fieldnames = [
        'global_iteration', 'fid', 'sample_pad', 'train_bpd', 'eval_bpd',
        'pad', 'batch_real_acc', 'batch_fake_acc', 'batch_acc'
    ]
    iteration_logger = CSVLogger(fieldnames=iteration_fieldnames,
                                 filename=os.path.join(output_dir, 'iteration_log.csv'))

    iteration_fieldnames = [
        'global_iteration', 'condition_num', 'max_sv', 'min_sv',
        'inverse_condition_num', 'inverse_max_sv', 'inverse_min_sv'
    ]
    svd_logger = CSVLogger(fieldnames=iteration_fieldnames,
                           filename=os.path.join(output_dir, 'svd_log.csv'))

    #
    test_iter = test_loader.__iter__()
    N_inception = 1000
    x_real_inception = torch.cat([
        test_iter.__next__()[0].to(device)
        for _ in range(N_inception // args.batch_size + 1)
    ], 0)[:N_inception]
    x_real_inception = x_real_inception + .5
    x_for_recon = test_iter.__next__()[0].to(device)

    def gan_step(engine, batch):
        assert not y_condition
        if 'iter_ind' in dir(engine):
            engine.iter_ind += 1
        else:
            engine.iter_ind = -1
        losses = {}
        model.train()
        discriminator.train()

        x, y = batch
        x = x.to(device)

        def run_noised_disc(discriminator, x):
            x = uniform_binning_correction(x)[0]
            return discriminator(x)

        real_acc = fake_acc = acc = 0
        if weight_gan > 0:
            fake = generate_from_noise(model, x.size(0), clamp=clamp)

            D_real_scores = run_noised_disc(discriminator, x.detach())
            D_fake_scores = run_noised_disc(discriminator, fake.detach())

            ones_target = torch.ones((x.size(0), 1), device=x.device)
            zeros_target = torch.zeros((x.size(0), 1), device=x.device)

            D_real_accuracy = torch.sum(
                torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0)
            D_fake_accuracy = torch.sum(
                torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0)

            D_real_loss = F.binary_cross_entropy_with_logits(D_real_scores, ones_target)
            D_fake_loss = F.binary_cross_entropy_with_logits(D_fake_scores, zeros_target)

            D_loss = (D_real_loss + D_fake_loss) / 2
            gp = gradient_penalty(x.detach(), fake.detach(),
                                  lambda _x: run_noised_disc(discriminator, _x))
            D_loss_plus_gp = D_loss + 10 * gp
            D_optimizer.zero_grad()
            D_loss_plus_gp.backward()
            D_optimizer.step()

            # Train generator
            fake = generate_from_noise(model, x.size(0), clamp=clamp, guard_nans=False)
            G_loss = F.binary_cross_entropy_with_logits(
                run_noised_disc(discriminator, fake),
                torch.ones((x.size(0), 1), device=x.device))

            # Trace
            real_acc = D_real_accuracy.item()
            fake_acc = D_fake_accuracy.item()
            acc = .5 * (D_fake_accuracy.item() + D_real_accuracy.item())

        z, nll, y_logits, (prior, logdet) = model.forward(x, None, return_details=True)
        train_bpd = nll.mean().item()

        loss = 0
        if weight_gan > 0:
            loss = loss + weight_gan * G_loss
        if weight_prior > 0:
            loss = loss + weight_prior * -prior.mean()
        if weight_logdet > 0:
            loss = loss + weight_logdet * -logdet.mean()

        if weight_entropy_reg > 0:
            _, _, _, (sample_prior, sample_logdet) = model.forward(fake, None, return_details=True)
            # notice this is actually "decreasing" sample likelihood.
            loss = loss + weight_entropy_reg * (sample_prior.mean() + sample_logdet.mean())

        # Jac Reg
        if jac_reg_lambda > 0:
            # Sample
            x_samples = generate_from_noise(model, args.batch_size, clamp=clamp).detach()
            x_samples.requires_grad_()
            z = model.forward(x_samples, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1) for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            sample_foward_jac = compute_jacobian_regularizer(x_samples, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            randz = torch.randn(zshape).to(device)
            randz = torch.autograd.Variable(randz, requires_grad=True)
            images = model(z=randz, y_onehot=None, temperature=1, reverse=True, batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [randz] + other_zs
            sample_inverse_jac = compute_jacobian_regularizer_manyinputs(all_z, images, n_proj=1)

            # Data
            x.requires_grad_()
            z = model.forward(x, None, return_details=True)[0]
            other_zs = torch.cat([
                split._last_z2.view(x.size(0), -1) for split in model.flow.splits
            ], -1)
            all_z = torch.cat([other_zs, z.view(x.size(0), -1)], -1)
            data_foward_jac = compute_jacobian_regularizer(x, all_z, n_proj=1)
            _, c2, h, w = model.prior_h.shape
            c = c2 // 2
            zshape = (batch_size, c, h, w)
            z.requires_grad_()
            images = model(z=z, y_onehot=None, temperature=1, reverse=True, batch_size=0)
            other_zs = [split._last_z2 for split in model.flow.splits]
            all_z = [z] + other_zs
            data_inverse_jac = compute_jacobian_regularizer_manyinputs(all_z, images, n_proj=1)

            # loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac)
            loss = loss + jac_reg_lambda * (sample_foward_jac + sample_inverse_jac
                                            + data_foward_jac + data_inverse_jac)

        if not eval_only:
            optimizer.zero_grad()
            loss.backward()
            if not db:
                assert max_grad_clip == max_grad_norm == 0
            if max_grad_clip > 0:
                torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
            if max_grad_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

            # Replace NaN gradient with 0
            for p in model.parameters():
                if p.requires_grad and p.grad is not None:
                    g = p.grad.data
                    g[g != g] = 0
            optimizer.step()

        if engine.iter_ind % 100 == 0:
            with torch.no_grad():
                fake = generate_from_noise(model, x.size(0), clamp=clamp)
                z = model.forward(fake, None, return_details=True)[0]
            print("Z max min")
            print(z.max().item(), z.min().item())
            if (fake != fake).float().sum() > 0:
                title = 'NaNs'
            else:
                title = "Good"
            grid = make_grid((postprocess(fake.detach().cpu(), dataset)[:30]), nrow=6).permute(1, 2, 0)
            plt.figure(figsize=(10, 10))
            plt.imshow(grid)
            plt.axis('off')
            plt.title(title)
            plt.savefig(os.path.join(output_dir, f'sample_{engine.iter_ind}.png'))

        if engine.iter_ind % eval_every == 0:
            def check_all_zero_except_leading(x):
                return x % 10**np.floor(np.log10(x)) == 0

            if engine.iter_ind == 0 or check_all_zero_except_leading(engine.iter_ind):
                torch.save(model.state_dict(),
                           os.path.join(output_dir, f'ckpt_sd_{engine.iter_ind}.pt'))

            model.eval()
            with torch.no_grad():
                # Plot recon
                fpath = os.path.join(output_dir, '_recon', f'recon_{engine.iter_ind}.png')
                sample_pad = run_recon_evolution(
                    model,
                    generate_from_noise(model, args.batch_size, clamp=clamp).detach(),
                    fpath)
                print(f"Iter: {engine.iter_ind}, Recon Sample PAD: {sample_pad}")

                pad = run_recon_evolution(model, x_for_recon, fpath)
                print(f"Iter: {engine.iter_ind}, Recon PAD: {pad}")
                pad = pad.item()
                sample_pad = sample_pad.item()

                # Inception score
                sample = torch.cat([
                    generate_from_noise(model, args.batch_size, clamp=clamp)
                    for _ in range(N_inception // args.batch_size + 1)
                ], 0)[:N_inception]
                sample = sample + .5

                if (sample != sample).float().sum() > 0:
                    print("Sample NaNs")
                    raise
                else:
                    fid = run_fid(x_real_inception.clamp_(0, 1), sample.clamp_(0, 1))
                    print(f'fid: {fid}, global_iter: {engine.iter_ind}')

                # Eval BPD
                eval_bpd = np.mean([
                    model.forward(x.to(device), None, return_details=True)[1].mean().item()
                    for x, _ in test_loader
                ])

                stats_dict = {
                    'global_iteration': engine.iter_ind,
                    'fid': fid,
                    'train_bpd': train_bpd,
                    'pad': pad,
                    'eval_bpd': eval_bpd,
                    'sample_pad': sample_pad,
                    'batch_real_acc': real_acc,
                    'batch_fake_acc': fake_acc,
                    'batch_acc': acc
                }
                iteration_logger.writerow(stats_dict)
                plot_csv(iteration_logger.filename)
            model.train()

        # Parenthesized: `+` binds looser than `%`, the original
        # `iter_ind + 2 % svd_every` almost never triggered.
        if (engine.iter_ind + 2) % svd_every == 0:
            model.eval()
            svd_dict = {}
            ret = utils.computeSVDjacobian(x_for_recon, model)
            D_for, D_inv = ret['D_for'], ret['D_inv']
            cn = float(D_for.max() / D_for.min())
            cn_inv = float(D_inv.max() / D_inv.min())
            svd_dict['global_iteration'] = engine.iter_ind
            svd_dict['condition_num'] = cn
            svd_dict['max_sv'] = float(D_for.max())
            svd_dict['min_sv'] = float(D_for.min())
            svd_dict['inverse_condition_num'] = cn_inv
            svd_dict['inverse_max_sv'] = float(D_inv.max())
            svd_dict['inverse_min_sv'] = float(D_inv.min())
            svd_logger.writerow(svd_dict)
            # plot_utils.plot_stability_stats(output_dir)
            # plot_utils.plot_individual_figures(output_dir, 'svd_log.csv')
            model.train()

        if eval_only:
            sys.exit()

        # Dummy
        losses['total_loss'] = torch.mean(nll).item()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(gan_step)
    # else:
    #     trainer = Engine(step)

    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=5,
                                         n_saved=1, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'model': model, 'optimizer': optimizer})

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x: (x['total_loss'],
                                     torch.empty(x['total_loss'].shape[0]))).attach(evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['nll'],
                                         torch.empty(x['nll'].shape[0]))).attach(evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        print("Loading...")
        print(saved_model)
        loaded = torch.load(saved_model)
        # if 'Glow' in str(type(loaded)):
        #     model = loaded
        # else:
        #     raise
        # # if 'Glow' in str(type(loaded)):
        # #     loaded = loaded.state_dict()
        model.load_state_dict(loaded)
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        if saved_model:
            return
        model.train()
        print("Initializing Actnorm...")
        init_batches = []
        init_targets = []
        if n_init_batches == 0:
            model.set_actnorm_init()
            return
        with torch.no_grad():
            if init_sample:
                generate_from_noise(model, args.batch_size * args.n_init_batches)
            else:
                for batch, target in islice(train_loader, None, n_init_batches):
                    init_batches.append(batch)
                    init_targets.append(target)
                init_batches = torch.cat(init_batches).to(device)
                assert init_batches.shape[0] == n_init_batches * batch_size
                if y_condition:
                    init_targets = torch.cat(init_targets).to(device)
                else:
                    init_targets = None
                model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        if not no_warm_up:
            scheduler.step()
        metrics = evaluator.state.metrics
        losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
def main(config, needs_save):
    os.environ['CUDA_VISIBLE_DEVICES'] = config.training.visible_devices

    seed = check_manual_seed(config.training.seed)
    print('Using manual seed: {}'.format(seed))

    if config.dataset.patient_ids == 'TRAIN_PATIENT_IDS':
        patient_ids = TRAIN_PATIENT_IDS
    elif config.dataset.patient_ids == 'TEST_PATIENT_IDS':
        patient_ids = TEST_PATIENT_IDS
    else:
        raise NotImplementedError

    data_loader = get_data_loader(
        mode=config.dataset.mode,
        dataset_name=config.dataset.name,
        patient_ids=patient_ids,
        root_dir_path=config.dataset.root_dir_path,
        use_augmentation=config.dataset.use_augmentation,
        batch_size=config.dataset.batch_size,
        num_workers=config.dataset.num_workers,
        image_size=config.dataset.image_size)

    E = Encoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.enc_filters,
                activation=config.model.enc_activation).float()
    D = Decoder(input_dim=config.model.input_dim,
                z_dim=config.model.z_dim,
                filters=config.model.dec_filters,
                activation=config.model.dec_activation,
                final_activation=config.model.dec_final_activation).float()

    if config.model.enc_spectral_norm:
        apply_spectral_norm(E)
    if config.model.dec_spectral_norm:
        apply_spectral_norm(D)

    if config.training.use_cuda:
        E.cuda()
        D.cuda()
        E = nn.DataParallel(E)
        D = nn.DataParallel(D)

    if config.model.saved_E:
        print(config.model.saved_E)
        E.load_state_dict(torch.load(config.model.saved_E))
    if config.model.saved_D:
        print(config.model.saved_D)
        D.load_state_dict(torch.load(config.model.saved_D))

    print(E)
    print(D)

    e_optim = optim.Adam(filter(lambda p: p.requires_grad, E.parameters()),
                         config.optimizer.enc_lr, [0.9, 0.9999])
    d_optim = optim.Adam(filter(lambda p: p.requires_grad, D.parameters()),
                         config.optimizer.dec_lr, [0.9, 0.9999])

    alpha = config.training.alpha
    beta = config.training.beta
    margin = config.training.margin
    batch_size = config.dataset.batch_size

    fixed_z = torch.randn(calc_latent_dim(config))

    if 'ssim' in config.training.loss:
        ssim_loss = pytorch_ssim.SSIM(window_size=11)

    def l_recon(recon: torch.Tensor, target: torch.Tensor):
        if config.training.loss == 'l2':
            loss = F.mse_loss(recon, target, reduction='sum')
        elif config.training.loss == 'l1':
            loss = F.l1_loss(recon, target, reduction='sum')
        elif config.training.loss == 'ssim':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon)
        elif config.training.loss == 'ssim+l1':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                + F.l1_loss(recon, target, reduction='sum')
        elif config.training.loss == 'ssim+l2':
            loss = (1.0 - ssim_loss(recon, target)) * torch.numel(recon) \
                + F.mse_loss(recon, target, reduction='sum')
        else:
            raise NotImplementedError
        return beta * loss / batch_size

    def l_reg(mu: torch.Tensor, log_var: torch.Tensor):
        loss = -0.5 * torch.sum(1 + log_var - mu**2 - torch.exp(log_var))
        return loss / batch_size

    def update(engine, batch):
        E.train()
        D.train()

        image = norm(batch['image'])
        if config.training.use_cuda:
            image = image.cuda(non_blocking=True).float()
        else:
            image = image.float()

        e_optim.zero_grad()
        d_optim.zero_grad()

        z, z_mu, z_logvar = E(image)
        x_r = D(z)

        l_vae_reg = l_reg(z_mu, z_logvar)
        l_vae_recon = l_recon(x_r, image)
        l_vae_total = l_vae_reg + l_vae_recon

        l_vae_total.backward()
        e_optim.step()
        d_optim.step()

        if config.training.use_cuda:
            torch.cuda.synchronize()

        return {
            'TotalLoss': l_vae_total.item(),
            'EncodeLoss': l_vae_reg.item(),
            'ReconLoss': l_vae_recon.item(),
        }

    output_dir = get_output_dir_path(config)
    trainer = Engine(update)
    timer = Timer(average=True)

    monitoring_metrics = ['TotalLoss', 'EncodeLoss', 'ReconLoss']
    for metric in monitoring_metrics:
        RunningAverage(alpha=0.98,
                       output_transform=partial(lambda x, metric: x[metric], metric=metric)).attach(trainer, metric)

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.STARTED)
    def save_config(engine):
        config_to_save = defaultdict(dict)
        for key, child in config._asdict().items():
            for k, v in child._asdict().items():
                config_to_save[key][k] = v
        config_to_save['seed'] = seed
        config_to_save['output_dir'] = output_dir
        print('Training starts by the following configuration: ', config_to_save)
        if needs_save:
            save_path = os.path.join(output_dir, 'config.json')
            with open(save_path, 'w') as f:
                json.dump(config_to_save, f)

    @trainer.on(Events.ITERATION_COMPLETED)
    def show_logs(engine):
        if (engine.state.iteration - 1) % config.save.log_iter_interval == 0:
            columns = ['epoch', 'iteration'] + list(engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                + [str(value) for value in engine.state.metrics.values()]
            message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format(
                epoch=engine.state.epoch,
                max_epoch=config.training.n_epochs,
                i=engine.state.iteration,
                max_i=len(data_loader))
            for name, value in zip(columns, values):
                message += ' | {name}: {value}'.format(name=name, value=value)
            pbar.log_message(message)

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_logs(engine):
        if needs_save:
            fname = os.path.join(output_dir, 'logs.tsv')
            columns = ['epoch', 'iteration'] + list(engine.state.metrics.keys())
            values = [str(engine.state.epoch), str(engine.state.iteration)] \
                + [str(value) for value in engine.state.metrics.values()]
            with open(fname, 'a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def save_images(engine):
        if needs_save:
            if engine.state.epoch % config.save.save_epoch_interval == 0:
                image = norm(engine.state.batch['image'])
                with torch.no_grad():
                    z, _, _ = E(image)
                    x_r = D(z)
                    x_p = D(fixed_z)

                image = denorm(image).detach().cpu()
                x_r = denorm(x_r).detach().cpu()
                x_p = denorm(x_p).detach().cpu()

                image = image[:config.save.n_save_images, ...]
                x_r = x_r[:config.save.n_save_images, ...]
                x_p = x_p[:config.save.n_save_images, ...]

                save_path = os.path.join(output_dir, 'result_{}.png'.format(engine.state.epoch))
                save_image(torch.cat([image, x_r, x_p]).data, save_path)

    if needs_save:
        checkpoint_handler = ModelCheckpoint(
            output_dir,
            config.save.study_name,
            save_interval=config.save.save_epoch_interval,
            n_saved=config.save.n_saved,
            create_dir=True,
        )
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'E': E, 'D': D})

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    print('Training starts: [max_epochs] {}, [max_iterations] {}'.format(
        config.training.n_epochs, config.training.n_epochs * len(data_loader)))

    trainer.run(data_loader, config.training.n_epochs)
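# `norm`/`denorm` above come from the surrounding project; they presumably map
# images between [0, 1] and [-1, 1]. A plausible sketch, offered as an
# assumption rather than the original definitions:
def norm(x):
    # [0, 1] -> [-1, 1]
    return x * 2.0 - 1.0

def denorm(x):
    # [-1, 1] -> [0, 1]
    return (x + 1.0) / 2.0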
def train(device, net, dataloader, val_loader, args, logger, experiment):
    def update(engine, data):
        input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
        input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
        rank_label = label.clone()
        inverse_label = label.clone()
        label[label == -1] = 0
        # zero the parameter gradients
        optimizer.zero_grad()
        rank_label = rank_label.float()

        start = timer()
        output_clf, output_rank_left, output_rank_right = net(input_left, input_right)
        end = timer()
        logger.info(f'FORWARD,{end-start:.4f}')

        # compute clf loss
        start = timer()
        loss_clf = clf_crit(output_clf, label)
        # compute ranking loss
        loss_rank = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
        loss = loss_clf + loss_rank
        end = timer()
        logger.info(f'LOSS,{end-start:.4f}')

        # compute ranking accuracy
        start = timer()
        rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
        end = timer()
        logger.info(f'RANK-ACC,{end-start:.4f}')

        # backward step
        start = timer()
        loss.backward()
        optimizer.step()
        end = timer()
        logger.info(f'BACKWARD,{end-start:.4f}')

        # swapped forward
        start = timer()
        inverse_label *= -1  # swap label
        inverse_rank_label = inverse_label.clone()
        inverse_rank_label = inverse_rank_label.float()
        inverse_label[inverse_label == -1] = 0
        end = timer()
        logger.info(f'SWAPPED-SETUP,{end-start:.4f}')

        start = timer()
        outputs, output_rank_left, output_rank_right = net(input_right, input_left)  # pass swapped input
        end = timer()
        logger.info(f'SWAPPED-FORWARD,{end-start:.4f}')

        start = timer()
        inverse_loss_clf = clf_crit(outputs, inverse_label)
        # compute ranking loss
        inverse_loss_rank = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
        # swapped backward
        inverse_loss = inverse_loss_clf + inverse_loss_rank
        end = timer()
        logger.info(f'SWAPPED-LOSS,{end-start:.4f}')

        start = timer()
        inverse_loss.backward()
        optimizer.step()
        end = timer()
        logger.info(f'SWAPPED-BACKWARD,{end-start:.4f}')

        return {
            'loss': loss.item(),
            'loss_clf': loss_clf.item(),
            'loss_rank': loss_rank.item(),
            'y': label,
            'y_pred': output_clf,
            'rank_acc': rank_acc
        }

    def inference(engine, data):
        with torch.no_grad():
            start = timer()
            input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
            input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
            rank_label = label.clone()
            label[label == -1] = 0
            rank_label = rank_label.float()
            # forward
            output_clf, output_rank_left, output_rank_right = net(input_left, input_right)
            loss_clf = clf_crit(output_clf, label)
            loss_rank = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
            rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
            loss = loss_clf + loss_rank
            end = timer()
            logger.info(f'INFERENCE,{end-start:.4f}')
            return {
                'loss': loss.item(),
                'loss_clf': loss_clf.item(),
                'loss_rank': loss_rank.item(),
                'y': label,
                'y_pred': output_clf,
                'rank_acc': rank_acc
            }

    net = net.to(device)
    clf_crit = nn.NLLLoss()
    rank_crit = nn.MarginRankingLoss(reduction='mean', margin=1)
    optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    lamb = Variable(torch.FloatTensor([1]), requires_grad=False).cuda()[0]
    trainer = Engine(update)
    evaluator = Engine(inference)
    writer = SummaryWriter()

    RunningAverage(output_transform=lambda x: x['loss']).attach(trainer, 'loss')
    RunningAverage(output_transform=lambda x: x['loss_clf']).attach(trainer, 'loss_clf')
    RunningAverage(output_transform=lambda x: x['loss_rank']).attach(trainer, 'loss_rank')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(trainer, 'rank_acc')
    RunningAverage(Accuracy(output_transform=lambda x: (x['y_pred'], x['y']))).attach(trainer, 'avg_acc')

    RunningAverage(output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    RunningAverage(output_transform=lambda x: x['loss_clf']).attach(evaluator, 'loss_clf')
    RunningAverage(output_transform=lambda x: x['loss_rank']).attach(evaluator, 'loss_rank')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(evaluator, 'rank_acc')
    RunningAverage(Accuracy(output_transform=lambda x: (x['y_pred'], x['y']))).attach(evaluator, 'avg_acc')

    if args.pbar:
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer, ['loss', 'avg_acc', 'rank_acc'])
        pbar = ProgressBar(persist=False)
        pbar.attach(evaluator, ['loss', 'loss_clf', 'loss_rank', 'avg_acc'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        net.eval()
        evaluator.run(val_loader)
        trainer.state.metrics['val_acc'] = evaluator.state.metrics['rank_acc']
        net.train()
        tb_log(
            {
                "accuracy": {
                    'accuracy': trainer.state.metrics['avg_acc'],
                    'rank_accuracy': trainer.state.metrics['rank_acc']
                },
                "loss": {
                    'total': trainer.state.metrics['loss'],
                    'clf': trainer.state.metrics['loss_clf'],
                    'rank': trainer.state.metrics['loss_rank']
                }
            },
            {
                "accuracy": {
                    'accuracy': evaluator.state.metrics['avg_acc'],
                    'rank_accuracy': evaluator.state.metrics['rank_acc']
                },
                "loss": {
                    'total': evaluator.state.metrics['loss'],
                    'clf': evaluator.state.metrics['loss_clf'],
                    'rank': evaluator.state.metrics['loss_rank']
                }
            },
            writer, args.attribute, trainer.state.epoch
        )

    handler = ModelCheckpoint(args.model_dir,
                              '{}_{}_{}'.format(args.model, args.premodel, args.attribute),
                              n_saved=1,
                              create_dir=True,
                              save_as_state_dict=True,
                              require_empty=False,
                              score_function=lambda engine: engine.state.metrics['val_acc'])
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'model': net})

    if args.resume:
        def start_epoch(engine):
            engine.state.epoch = args.epoch
        trainer.add_event_handler(Events.STARTED, start_epoch)
        evaluator.add_event_handler(Events.STARTED, start_epoch)

    trainer.run(dataloader, max_epochs=args.max_epochs)
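# `tb_log` is a project helper, not defined in this excerpt; it presumably
# writes the nested train/val scalar dicts to TensorBoard, roughly like this
# sketch (names and tag layout are illustrative assumptions):
def tb_log(train_metrics, val_metrics, writer, attribute, epoch):
    for group, values in train_metrics.items():
        for name, value in values.items():
            writer.add_scalar(f'{attribute}/train/{group}/{name}', value, epoch)
    for group, values in val_metrics.items():
        for name, value in values.items():
            writer.add_scalar(f'{attribute}/val/{group}/{name}', value, epoch)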
def run_once(self):
    log_dir = self.log_dir
    misc.check_manual_seed(self.seed)
    train_pairs, valid_pairs = dataset.prepare_data_VIABLE_2048()
    print(len(train_pairs))

    # --------------------------- Dataloader
    train_augmentors = self.train_augmentors()
    train_dataset = dataset.DatasetSerial(train_pairs[:],
                                          shape_augs=iaa.Sequential(train_augmentors[0]),
                                          input_augs=iaa.Sequential(train_augmentors[1]))
    infer_augmentors = self.infer_augmentors()
    infer_dataset = dataset.DatasetSerial(valid_pairs[:],
                                          shape_augs=iaa.Sequential(infer_augmentors))
    train_loader = data.DataLoader(train_dataset,
                                   num_workers=self.nr_procs_train,
                                   batch_size=self.train_batch_size,
                                   shuffle=True, drop_last=True)
    valid_loader = data.DataLoader(infer_dataset,
                                   num_workers=self.nr_procs_valid,
                                   batch_size=self.infer_batch_size,
                                   shuffle=True, drop_last=False)

    # --------------------------- Training Sequence
    if self.logging:
        misc.check_log_dir(log_dir)

    device = 'cuda'

    # networks
    input_chs = 3
    net = DenseNet(input_chs, self.nr_classes)
    net = torch.nn.DataParallel(net).to(device)
    # print(net)

    # optimizers
    optimizer = optim.Adam(net.parameters(), lr=self.init_lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps)

    # load pre-trained models
    if self.load_network:
        saved_state = torch.load(self.save_net_path)
        net.load_state_dict(saved_state)

    #
    trainer = Engine(lambda engine, batch: self.train_step(net, batch, optimizer, 'cuda'))
    inferer = Engine(lambda engine, batch: self.infer_step(net, batch, 'cuda'))

    train_output = ['loss', 'acc']
    infer_output = ['prob', 'true']

    ##
    if self.logging:
        checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix,
                                             save_interval=1, n_saved=120,
                                             require_empty=False)
        # adding handlers using `trainer.add_event_handler` method API
        trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                  handler=checkpoint_handler,
                                  to_save={'net': net})

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)
    timer.attach(inferer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    # attach running average metrics computation
    # decay of EMA to 0.95 to match tensorpack default
    RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach(trainer, 'loss')
    RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach(trainer, 'acc')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])
    pbar.attach(inferer)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            checkpoint_handler(engine, {'net_exception': net})
        else:
            raise e

    # writer for tensorboard logging
    if self.logging:
        writer = SummaryWriter(log_dir=log_dir)
        json_log_file = log_dir + '/stats.json'
        with open(json_log_file, 'w') as json_file:
            json.dump({}, json_file)  # create empty file

    @trainer.on(Events.EPOCH_STARTED)
    def log_lrs(engine):
        if self.logging:
            lr = float(optimizer.param_groups[0]['lr'])
            writer.add_scalar("lr", lr, engine.state.epoch)
        # advance scheduler clock
        scheduler.step()

    ####
    def update_logs(output, epoch, prefix, color):
        # print values and convert
        max_length = len(max(output.keys(), key=len))
        for metric in output:
            key = colored(prefix + '-' + metric.ljust(max_length), color)
            print('------%s : ' % key, end='')
            print('%0.7f' % output[metric])
        if 'train' in prefix:
            lr = float(optimizer.param_groups[0]['lr'])
            key = colored(prefix + '-' + 'lr'.ljust(max_length), color)
            print('------%s : %0.7f' % (key, lr))

        if not self.logging:
            return

        # create stat dicts
        stat_dict = {}
        for metric in output:
            metric_value = output[metric]
            stat_dict['%s-%s' % (prefix, metric)] = metric_value

        # json stat log file, update and overwrite
        with open(json_log_file) as json_file:
            json_data = json.load(json_file)

        current_epoch = str(epoch)
        if current_epoch in json_data:
            old_stat_dict = json_data[current_epoch]
            stat_dict.update(old_stat_dict)
        current_epoch_dict = {current_epoch: stat_dict}
        json_data.update(current_epoch_dict)

        with open(json_log_file, 'w') as json_file:
            json.dump(json_data, json_file)

        # log values to tensorboard
        for metric in output:
            writer.add_scalar(prefix + '-' + metric, output[metric], current_epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_train_running_results(engine):
        """running training measurement"""
        training_ema_output = engine.state.metrics
        # update_logs(training_ema_output, engine.state.epoch, prefix='train-ema', color='green')

    ####
    def get_init_accumulator(output_names):
        return {metric: [] for metric in output_names}

    import cv2

    def process_accumulated_output(output):
        def uneven_seq_to_np(seq, batch_size=self.infer_batch_size):
            if self.infer_batch_size == 1:
                return np.squeeze(seq)
            item_count = batch_size * (len(seq) - 1) + len(seq[-1])
            cat_array = np.zeros((item_count,) + seq[0][0].shape, seq[0].dtype)
            for idx in range(0, len(seq) - 1):
                cat_array[idx * batch_size : (idx + 1) * batch_size] = seq[idx]
            cat_array[(idx + 1) * batch_size:] = seq[-1]
            return cat_array

        #
        prob = uneven_seq_to_np(output['prob'])
        true = uneven_seq_to_np(output['true'])

        # cmap = plt.get_cmap('jet')
        # epi = prob[..., 1]
        # epi = (cmap(epi) * 255.0).astype('uint8')
        # cv2.imwrite('sample.png', cv2.cvtColor(epi, cv2.COLOR_RGB2BGR))

        pred = np.argmax(prob, axis=-1)
        true = np.squeeze(true)

        # deal with ignore index
        pred = pred.flatten()
        true = true.flatten()
        pred = pred[true != 0] - 1
        true = true[true != 0] - 1

        acc = np.mean(pred == true)
        inter = (pred * true).sum()
        total = (pred + true).sum()
        dice = 2 * inter / total

        #
        proc_output = dict(acc=acc, dice=dice)
        return proc_output

    @trainer.on(Events.EPOCH_COMPLETED)
    def infer_valid(engine):
        """inference measurement"""
        inferer.accumulator = get_init_accumulator(infer_output)
        inferer.run(valid_loader)
        output_stat = process_accumulated_output(inferer.accumulator)
        update_logs(output_stat, engine.state.epoch, prefix='valid', color='red')

    @inferer.on(Events.ITERATION_COMPLETED)
    def accumulate_outputs(engine):
        batch_output = engine.state.output
        for key, item in batch_output.items():
            engine.accumulator[key].extend([item])

    ###
    # Setup is done. Now let's run the training.
    trainer.run(train_loader, self.nr_epochs)
    return
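The accumulator idiom in run_once above — stash raw per-batch outputs on the engine from an ITERATION_COMPLETED handler, then reduce after inferer.run() returns — works on any Engine. A minimal, self-contained sketch of just that pattern (the toy update function and loader here are illustrative, not from the original code):

from ignite.engine import Engine, Events

def infer_step(engine, batch):
    # pretend "inference": return the batch itself as the per-iteration output
    return {'prob': batch}

inferer = Engine(infer_step)

@inferer.on(Events.EPOCH_STARTED)
def reset_accumulator(engine):
    # re-initialize before every run, so consecutive runs do not mix results
    engine.accumulator = {'prob': []}

@inferer.on(Events.ITERATION_COMPLETED)
def accumulate_outputs(engine):
    for key, item in engine.state.output.items():
        engine.accumulator[key].append(item)

inferer.run([1, 2, 3], max_epochs=1)
assert inferer.accumulator['prob'] == [1, 2, 3]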
def main(hparams):
    results_dir = get_results_directory(hparams.output_dir)
    writer = SummaryWriter(log_dir=str(results_dir))

    ds = get_dataset(hparams.dataset, root=hparams.data_root)
    input_size, num_classes, train_dataset, test_dataset = ds

    hparams.seed = set_seed(hparams.seed)

    if hparams.n_inducing_points is None:
        hparams.n_inducing_points = num_classes

    print(f"Training with {hparams}")
    hparams.save(results_dir / "hparams.json")

    if hparams.ard:
        # Hardcoded to WRN output size
        ard = 640
    else:
        ard = None

    feature_extractor = WideResNet(
        spectral_normalization=hparams.spectral_normalization,
        dropout_rate=hparams.dropout_rate,
        coeff=hparams.coeff,
        n_power_iterations=hparams.n_power_iterations,
        batchnorm_momentum=hparams.batchnorm_momentum,
    )

    initial_inducing_points, initial_lengthscale = initial_values_for_GP(
        train_dataset, feature_extractor, hparams.n_inducing_points
    )

    gp = GP(
        num_outputs=num_classes,
        initial_lengthscale=initial_lengthscale,
        initial_inducing_points=initial_inducing_points,
        separate_inducing_points=hparams.separate_inducing_points,
        kernel=hparams.kernel,
        ard=ard,
        lengthscale_prior=hparams.lengthscale_prior,
    )

    model = DKL_GP(feature_extractor, gp)
    model = model.cuda()

    likelihood = SoftmaxLikelihood(num_classes=num_classes, mixing_weights=False)
    likelihood = likelihood.cuda()

    elbo_fn = VariationalELBO(likelihood, gp, num_data=len(train_dataset))

    parameters = [
        {"params": feature_extractor.parameters(), "lr": hparams.learning_rate},
        {"params": gp.parameters(), "lr": hparams.learning_rate},
        {"params": likelihood.parameters(), "lr": hparams.learning_rate},
    ]
    optimizer = torch.optim.SGD(
        parameters, momentum=0.9, weight_decay=hparams.weight_decay
    )
    milestones = [60, 120, 160]
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=milestones, gamma=0.2
    )

    def step(engine, batch):
        model.train()
        likelihood.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        y_pred = model(x)
        elbo = -elbo_fn(y_pred, y)
        elbo.backward()
        optimizer.step()

        return elbo.item()

    def eval_step(engine, batch):
        model.eval()
        likelihood.eval()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        with torch.no_grad():
            y_pred = model(x)

        return y_pred, y

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "elbo")

    def output_transform(output):
        y_pred, y = output
        # Sample softmax values independently for classification at test time
        y_pred = y_pred.to_data_independent_dist()
        # The mean here is over likelihood samples
        y_pred = likelihood(y_pred).probs.mean(0)
        return y_pred, y

    metric = Accuracy(output_transform=output_transform)
    metric.attach(evaluator, "accuracy")

    metric = Loss(lambda y_pred, y: -elbo_fn(y_pred, y))
    metric.attach(evaluator, "elbo")

    kwargs = {"num_workers": 4, "pin_memory": True}

    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=hparams.batch_size, shuffle=True,
        drop_last=True, **kwargs,
    )
    test_loader = torch.utils.data.DataLoader(
        test_dataset, batch_size=512, shuffle=False, **kwargs
    )

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        elbo = metrics["elbo"]

        print(f"Train - Epoch: {trainer.state.epoch} ELBO: {elbo:.2f} ")
        writer.add_scalar("Likelihood/train", elbo, trainer.state.epoch)

        if hparams.spectral_normalization:
            for name, layer in model.feature_extractor.named_modules():
                if isinstance(layer, torch.nn.Conv2d):
                    writer.add_scalar(
                        f"sigma/{name}", layer.weight_sigma, trainer.state.epoch
                    )

        if not hparams.ard:
            # Otherwise it's too much to submit to tensorboard
            length_scales = model.gp.covar_module.base_kernel.lengthscale.squeeze()
            for i in range(length_scales.shape[0]):
                writer.add_scalar(
                    f"length_scale/{i}", length_scales[i], trainer.state.epoch
                )

        if trainer.state.epoch > 150 and trainer.state.epoch % 5 == 0:
            _, auroc, aupr = get_ood_metrics(
                hparams.dataset, "SVHN", model, likelihood, hparams.data_root
            )
            print(f"OoD Metrics - AUROC: {auroc}, AUPR: {aupr}")
            writer.add_scalar("OoD/auroc", auroc, trainer.state.epoch)
            writer.add_scalar("OoD/auprc", aupr, trainer.state.epoch)

        evaluator.run(test_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        elbo = metrics["elbo"]

        print(
            f"Test - Epoch: {trainer.state.epoch} "
            f"Acc: {acc:.4f} "
            f"ELBO: {elbo:.2f} "
        )

        writer.add_scalar("Likelihood/test", elbo, trainer.state.epoch)
        writer.add_scalar("Accuracy/test", acc, trainer.state.epoch)

        scheduler.step()

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=200)

    # Done training - time to evaluate
    results = {}

    evaluator.run(train_loader)
    train_acc = evaluator.state.metrics["accuracy"]
    train_elbo = evaluator.state.metrics["elbo"]
    results["train_accuracy"] = train_acc
    results["train_elbo"] = train_elbo

    evaluator.run(test_loader)
    test_acc = evaluator.state.metrics["accuracy"]
    test_elbo = evaluator.state.metrics["elbo"]
    results["test_accuracy"] = test_acc
    results["test_elbo"] = test_elbo

    _, auroc, aupr = get_ood_metrics(
        hparams.dataset, "SVHN", model, likelihood, hparams.data_root
    )
    results["auroc_ood_svhn"] = auroc
    results["aupr_ood_svhn"] = aupr

    print(f"Test - Accuracy {results['test_accuracy']:.4f}")

    results_json = json.dumps(results, indent=4, sort_keys=True)
    (results_dir / "results.json").write_text(results_json)

    torch.save(model.state_dict(), results_dir / "model.pt")
    torch.save(likelihood.state_dict(), results_dir / "likelihood.pt")

    writer.close()
def test_attach_fail_with_string():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, "a")
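For contrast with the failure case tested above, a hedged sketch of the accepted forms of the metric_names argument to ProgressBar.attach (the exact import path and whether the "all" shortcut is available depend on the installed ignite version):

from ignite.engine import Engine
from ignite.contrib.handlers import ProgressBar  # moved in more recent ignite releases

engine = Engine(lambda e, b: {'a': 1})
pbar = ProgressBar()
pbar.attach(engine, metric_names=['a'])    # OK: a list of metric names
# pbar.attach(engine, metric_names='all')  # OK on recent versions: show all metrics
# pbar.attach(engine, 'a')                 # TypeError: a bare metric name is rejected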
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs,
         saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation,
         flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip,
         max_grad_norm, lr, n_workers, cuda, n_init_batches, warmup_steps, output_dir,
         saved_optimizer, fresh):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                                   num_workers=n_workers, drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False,
                                  num_workers=n_workers, drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation,
                 flow_coupling, LU_decomposed, num_classes, learn_top, y_condition)
    model = model.to(device)

    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class,
                                        reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1, n_saved=2,
                                         require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'model': model, 'optimizer': optimizer})

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x: (x['total_loss'],
                                     torch.empty(x['total_loss'].shape[0]))).attach(
        evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['nll'],
                                         torch.empty(x['nll'].shape[0]))).attach(
            evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)

        metrics = evaluator.state.metrics
        losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])

        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
def test_pbar_fail_with_non_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, output_transform=1)
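And the matching success case: output_transform must be a callable that maps the process function's output to whatever the bar should display. A minimal sketch, under the same assumption about the ProgressBar import path as above:

from ignite.engine import Engine
from ignite.contrib.handlers import ProgressBar

engine = Engine(lambda e, b: {'loss': 0.5, 'y_pred': None})
pbar = ProgressBar()
# a callable is accepted; a non-callable such as 1 raises TypeError
pbar.attach(engine, output_transform=lambda out: {'loss': out['loss']})
engine.run([0, 1, 2], max_epochs=1)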
def main():
    SEED = 1234
    torch.manual_seed(SEED)
    torch.cuda.manual_seed(SEED)
    torch.backends.cudnn.deterministic = True

    TEXT = data.Field(lower=True, batch_first=True, tokenize='spacy')
    LABEL = data.LabelField(dtype=torch.float)
    train_data, test_data = datasets.IMDB.splits(TEXT, LABEL, root='/tmp/imdb/')
    train_data, valid_data = train_data.split(split_ratio=0.8,
                                              random_state=random.seed(SEED))
    TEXT.build_vocab(train_data,
                     vectors=GloVe(name='6B', dim=100, cache='/tmp/glove/'),
                     unk_init=torch.Tensor.normal_)
    LABEL.build_vocab(train_data)

    BATCH_SIZE = 64
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    train_iterator, valid_iterator, test_iterator = data.BucketIterator.splits(
        (train_data, valid_data, test_data), batch_size=BATCH_SIZE, device=device)
    vocab_size, embedding_dim = TEXT.vocab.vectors.shape

    class SentimentAnalysisCNN(nn.Module):
        def __init__(self, vocab_size, embedding_dim, kernel_sizes, num_filters,
                     num_classes, d_prob, mode, use_drop=False):
            """
            Args:
                vocab_size : int - size of the vocabulary in the dictionary
                embedding_dim : int - dimension of the word embedding vectors
                kernel_sizes : list of int - sizes of the kernels in this architecture
                num_filters : int - how many filters are used in each layer
                num_classes : int - number of classes to classify
                d_prob : probability for the dropout layer
                mode : one of:
                    static : pretrained weights, non-trainable
                    nonstatic : pretrained weights, trainable
                    rand : randomly initialized weights
                use_drop : whether to use dropout in this class
            """
            super(SentimentAnalysisCNN, self).__init__()
            self.vocab_size = vocab_size
            self.embedding_dim = embedding_dim
            self.kernel_sizes = kernel_sizes
            self.num_filters = num_filters
            self.num_classes = num_classes
            self.d_prob = d_prob
            self.mode = mode
            self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=1)
            self.load_embeddings()
            self.conv = nn.ModuleList([
                nn.Sequential(
                    nn.Conv1d(in_channels=embedding_dim, out_channels=num_filters,
                              kernel_size=k, stride=1),
                    nn.Dropout(p=0.5, inplace=True)) for k in kernel_sizes
            ])
            self.use_drop = use_drop
            if self.use_drop:
                self.dropout = nn.Dropout(d_prob)
            self.fc = nn.Linear(len(kernel_sizes) * num_filters, num_classes)

        def forward(self, x):
            batch_size, sequence_length = x.shape
            x = self.embedding(x).transpose(1, 2)
            x = [F.relu(conv(x)) for conv in self.conv]
            x = [F.max_pool1d(c, c.size(-1)).squeeze(dim=-1) for c in x]
            x = torch.cat(x, dim=1)
            if self.use_drop:
                x = self.fc(self.dropout(x))
            else:
                x = self.fc(x)
            return torch.sigmoid(x).squeeze()

        def load_embeddings(self):
            if 'static' in self.mode:
                self.embedding.weight.data.copy_(TEXT.vocab.vectors)
                if 'non' not in self.mode:
                    self.embedding.weight.data.requires_grad = False
                    print('Loaded pretrained embeddings, weights are not trainable.')
                else:
                    self.embedding.weight.data.requires_grad = True
                    print('Loaded pretrained embeddings, weights are trainable.')
            elif self.mode == 'rand':
                print('Randomly initialized embeddings are used.')
            else:
                raise ValueError(
                    'Unexpected value of mode. Please choose from static, nonstatic, rand.')

    model = SentimentAnalysisCNN(vocab_size=vocab_size,
                                 embedding_dim=embedding_dim,
                                 kernel_sizes=[3, 4, 5],
                                 num_filters=100,
                                 num_classes=1,
                                 d_prob=0.5,
                                 mode='static')
    model.to(device)

    ## switch back and forth between the two optimizers
    # optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-3)
    ## the Adam optimizer gives better performance but overfits quickly
    optimizer = Ranger(model.parameters(), weight_decay=0.1)
    criterion = nn.BCELoss()

    def process_function(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = batch.text, batch.label
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_function(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = batch.text, batch.label
            y_pred = model(x)
            return y_pred, y

    trainer = Engine(process_function)
    train_evaluator = Engine(eval_function)
    validation_evaluator = Engine(eval_function)

    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    def thresholded_output_transform(output):
        y_pred, y = output
        y_pred = torch.round(y_pred)
        return y_pred, y

    Accuracy(output_transform=thresholded_output_transform).attach(train_evaluator, 'accuracy')
    Loss(criterion).attach(train_evaluator, 'bce')
    Accuracy(output_transform=thresholded_output_transform).attach(validation_evaluator, 'accuracy')
    Loss(criterion).attach(validation_evaluator, 'bce')

    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['loss'])

    def score_function(engine):
        val_loss = engine.state.metrics['bce']
        return -val_loss

    handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
    validation_evaluator.add_event_handler(Events.COMPLETED, handler)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_iterator)
        metrics = train_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_bce = metrics['bce']
        pbar.log_message(
            "Training Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_bce))

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        validation_evaluator.run(valid_iterator)
        metrics = validation_evaluator.state.metrics
        avg_accuracy = metrics['accuracy']
        avg_bce = metrics['bce']
        pbar.log_message(
            "Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
            .format(engine.state.epoch, avg_accuracy, avg_bce))
        pbar.n = pbar.last_print_n = 0

    checkpointer = ModelCheckpoint('/tmp/models', 'textcnn_ranger_wd_0_1', save_interval=1,
                                   n_saved=2, create_dir=True, save_as_state_dict=True,
                                   require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'textcnn_ranger_wd_0_1': model})
    # trainer.add_event_handler(Events.EPOCH_COMPLETED, log_validation_results)
    trainer.run(train_iterator, max_epochs=20)
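A note on the score_function above: EarlyStopping treats larger scores as better, which is why the validation loss is negated. A condensed sketch of just that wiring, with toy engines standing in for the real trainer and evaluator:

from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda e, b: None)
evaluator = Engine(lambda e, b: None)

# higher score == better, so a loss metric must be negated
handler = EarlyStopping(patience=5,
                        score_function=lambda engine: -engine.state.metrics['bce'],
                        trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)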
def __init__(self: TrainerType,
             model: nn.Module,
             optimizer: Optimizer,
             checkpoint_dir: str = '../../checkpoints',
             experiment_name: str = 'experiment',
             model_checkpoint: Optional[str] = None,
             optimizer_checkpoint: Optional[str] = None,
             metrics: types.GenericDict = None,
             patience: int = 10,
             validate_every: int = 1,
             accumulation_steps: int = 1,
             loss_fn: Union[_Loss, DataParallelCriterion] = None,
             non_blocking: bool = True,
             retain_graph: bool = False,
             dtype: torch.dtype = torch.float,
             device: str = 'cpu',
             parallel: bool = False) -> None:
    self.dtype = dtype
    self.retain_graph = retain_graph
    self.non_blocking = non_blocking
    self.device = device
    self.loss_fn = loss_fn
    self.validate_every = validate_every
    self.patience = patience
    self.accumulation_steps = accumulation_steps
    self.checkpoint_dir = checkpoint_dir

    model_checkpoint = self._check_checkpoint(model_checkpoint)
    optimizer_checkpoint = self._check_checkpoint(optimizer_checkpoint)

    self.model = cast(
        nn.Module,
        from_checkpoint(model_checkpoint, model, map_location=torch.device('cpu')))
    self.model = self.model.type(dtype).to(device)
    self.optimizer = from_checkpoint(optimizer_checkpoint, optimizer)
    self.parallel = parallel
    if parallel:
        if device == 'cpu':
            raise ValueError("parallel can be used only with cuda device")
        self.model = DataParallelModel(self.model).to(device)
        self.loss_fn = DataParallelCriterion(self.loss_fn)  # type: ignore
    if metrics is None:
        metrics = {}
    if 'loss' not in metrics:
        if self.parallel:
            metrics['loss'] = Loss(
                lambda x, y: self.loss_fn(x, y).mean())  # type: ignore
        else:
            metrics['loss'] = Loss(self.loss_fn)
    self.trainer = Engine(self.train_step)
    self.train_evaluator = Engine(self.eval_step)
    self.valid_evaluator = Engine(self.eval_step)
    for name, metric in metrics.items():
        metric.attach(self.train_evaluator, name)
        metric.attach(self.valid_evaluator, name)
    self.pbar = ProgressBar()
    self.val_pbar = ProgressBar(desc='Validation')
    if checkpoint_dir is not None:
        self.checkpoint = CheckpointHandler(checkpoint_dir,
                                            experiment_name,
                                            score_name='validation_loss',
                                            score_function=self._score_fn,
                                            n_saved=2,
                                            require_empty=False,
                                            save_as_state_dict=True)
    self.early_stop = EarlyStopping(patience, self._score_fn, self.trainer)
    self.val_handler = EvaluationHandler(pbar=self.pbar,
                                         validate_every=1,
                                         early_stopping=self.early_stop)
    self.attach()
    log.info(
        f'Trainer configured to run {experiment_name}\n'
        f'\tpretrained model: {model_checkpoint} {optimizer_checkpoint}\n'
        f'\tcheckpoint directory: {checkpoint_dir}\n'
        f'\tpatience: {patience}\n'
        f'\taccumulation steps: {accumulation_steps}\n'
        f'\tnon blocking: {non_blocking}\n'
        f'\tretain graph: {retain_graph}\n'
        f'\tdevice: {device}\n'
        f'\tmodel dtype: {dtype}\n'
        f'\tparallel: {parallel}')
def run(args):
    train_loader, val_loader = get_data_loaders(args.dir, args.batch_size, args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    num_classes = CityscapesDataset.num_instance_classes() + 1

    model = models.box2pix(num_classes=num_classes)
    model.init_from_googlenet()

    writer = create_summary_writer(model, train_loader, args.log_dir)

    if torch.cuda.device_count() > 1:
        print("Using %d GPU(s)" % torch.cuda.device_count())
        model = nn.DataParallel(model)
    model = model.to(device)

    semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
    offsets_criterion = nn.MSELoss()
    box_criterion = BoxLoss(num_classes, gamma=2)
    multitask_criterion = MultiTaskLoss().to(device)

    box_coder = BoxCoder()
    optimizer = optim.Adam([{'params': model.parameters(), 'weight_decay': 5e-4},
                            {'params': multitask_criterion.parameters()}],
                           lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            multitask_criterion.load_state_dict(checkpoint['multitask'])
            print("Loaded checkpoint '{}' (Epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        x, instance, boxes, labels = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(instance, device=device, non_blocking=non_blocking),
                convert_tensor(boxes, device=device, non_blocking=non_blocking),
                convert_tensor(labels, device=device, non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, instance, boxes, labels = _prepare_batch(batch)
        boxes, labels = box_coder.encode(boxes, labels)

        loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)
        semantics_loss = semantics_criterion(semantics_pred, instance)
        offsets_loss = offsets_criterion(offsets_pred, instance)
        box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels)

        loss = multitask_criterion(semantics_loss, offsets_loss, box_loss, conf_loss)

        loss.backward()
        optimizer.step()

        return {'loss': loss.item(),
                'loss_semantics': semantics_loss.item(),
                'loss_offsets': offsets_loss.item(),
                'loss_ssdbox': box_loss.item(),
                'loss_ssdclass': conf_loss.item()}

    trainer = Engine(_update)
    checkpoint_handler = ModelCheckpoint(args.output_dir, 'checkpoint', save_interval=1,
                                         n_saved=10, require_empty=False, create_dir=True,
                                         save_as_state_dict=False)
    timer = Timer(average=True)

    # attach running average metrics
    train_metrics = ['loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox', 'loss_ssdclass']
    for m in train_metrics:
        transform = partial(lambda x, metric: x[metric], metric=m)
        RunningAverage(output_transform=transform).attach(trainer, m)

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=train_metrics)

    # Build the checkpoint dict at save time: trainer.state does not exist before
    # trainer.run() is called, and the state dicts would otherwise be captured
    # once at registration and never refreshed.
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        checkpoint = {'model': model.state_dict(),
                      'epoch': engine.state.epoch,
                      'optimizer': optimizer.state_dict(),
                      'multitask': multitask_criterion.state_dict()}
        checkpoint_handler(engine, {'checkpoint': checkpoint})

    timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            x, instance, boxes, labels = _prepare_batch(batch)
            loc_preds, conf_preds, semantics, offsets_pred = model(x)
            boxes_preds, labels_preds, scores_preds = box_coder.decode(
                loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01)

            semantics_loss = semantics_criterion(semantics, instance)
            offsets_loss = offsets_criterion(offsets_pred, instance)
            box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds, labels)

            semantics_pred = semantics.argmax(dim=1)
            instances = helper.assign_pix2box(semantics_pred, offsets_pred,
                                              boxes_preds, labels_preds)

        return {'loss': (semantics_loss, offsets_loss,
                         {'box_loss': box_loss, 'conf_loss': conf_loss}),
                'objects': (boxes_preds, labels_preds, scores_preds, boxes, labels),
                'semantics': semantics_pred,
                'instances': instances}

    train_evaluator = Engine(_inference)
    Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
    MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach(
        train_evaluator, 'objects')
    IntersectionOverUnion(num_classes, output_transform=lambda x: x['semantics']).attach(
        train_evaluator, 'semantics')

    evaluator = Engine(_inference)
    Loss(multitask_criterion, output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    MeanAveragePrecision(num_classes, output_transform=lambda x: x['objects']).attach(
        evaluator, 'objects')
    IntersectionOverUnion(num_classes, output_transform=lambda x: x['semantics']).attach(
        evaluator, 'semantics')

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message("Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(
            engine.state.epoch, engine.state.max_epochs, timer.value()))
        timer.reset()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration % args.log_interval == 0:
            writer.add_scalar("training/loss", engine.state.output['loss'],
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss, mean_ap, iou * 100.0))

        writer.add_scalar("train-val/loss", loss, engine.state.epoch)
        writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("train-val/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss, mean_ap, iou * 100.0))

        writer.add_scalar("validation/loss", loss, engine.state.epoch)
        writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("validation/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    @trainer.on(Events.COMPLETED)
    def save_final_model(engine):
        checkpoint_handler(engine, {'final': model})

    trainer.run(train_loader, max_epochs=args.epochs)
    writer.close()
def train_gan(logger: Logger, experiment_dir: Path, data_dir: Path, batch_size: int,
              z_dim: int, g_filters: int, d_filters: int, learning_rate: float,
              beta_1: float, epochs: int, saved_g: bool = False, saved_d: bool = False,
              seed: Optional[int] = None, g_extra_layers: int = 0,
              d_extra_layers: int = 0, scheduler: bool = False) -> None:
    seed = fix_random_seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Train started with seed: {seed}")

    dataset = HDF5ImageDataset(image_dir=data_dir)
    desired_minkowski = pickle.load((data_dir / 'minkowski.pkl').open(mode='rb'))
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True,
                        pin_memory=True)
    iterations = epochs * len(loader)
    img_size = dataset.shape[-1]
    num_channels = dataset.shape[0]

    # networks
    net_g = Generator(img_size=img_size, z_dim=z_dim, num_channels=num_channels,
                      num_filters=g_filters, num_extra_layers=g_extra_layers).to(device)
    net_d = Discriminator(img_size=img_size, num_channels=num_channels,
                          num_filters=d_filters, num_extra_layers=d_extra_layers).to(device)
    summary(net_g, (z_dim, 1, 1, 1))
    summary(net_d, (num_channels, img_size, img_size, img_size))

    if saved_g:
        net_g.load_state_dict(torch.load(experiment_dir / G_CHECKPOINT_NAME))
        logger.info("Loaded generator checkpoint")
    if saved_d:
        net_d.load_state_dict(torch.load(experiment_dir / D_CHECKPOINT_NAME))
        logger.info("Loaded discriminator checkpoint")

    # criterion
    criterion = nn.BCELoss()

    optimizer_g = optim.Adam(net_g.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizer_d = optim.Adam(net_d.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    patience = int(3000 / len(loader))
    scheduler_g = optim.lr_scheduler.ReduceLROnPlateau(optimizer_g, min_lr=1e-6,
                                                       verbose=True, patience=patience)
    scheduler_d = optim.lr_scheduler.ReduceLROnPlateau(optimizer_d, min_lr=1e-6,
                                                       verbose=True, patience=patience)

    # label smoothing
    real_labels = torch.full((batch_size,), fill_value=0.9, device=device)
    fake_labels = torch.zeros((batch_size,), device=device)
    fixed_noise = torch.randn(1, z_dim, 1, 1, 1, device=device)

    def step(engine: Engine, batch: torch.Tensor) -> Dict[str, float]:
        """
        Train step function

        :param engine: pytorch ignite train engine
        :param batch: batch to process
        :return: batch metrics
        """
        # get batch of fake images from generator
        fake_batch = net_g(torch.randn(batch_size, z_dim, 1, 1, 1, device=device))

        # 1. Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        batch = batch.to(device)
        optimizer_d.zero_grad()

        # train D with real and fake batches
        d_out_real = net_d(batch)
        d_out_fake = net_d(fake_batch.detach())
        loss_d_real = criterion(d_out_real, real_labels)
        loss_d_fake = criterion(d_out_fake, fake_labels)

        # mean probabilities
        p_real = d_out_real.mean().item()
        p_fake = d_out_fake.mean().item()

        loss_d = (loss_d_real + loss_d_fake) / 2
        loss_d.backward()
        optimizer_d.step()

        # 2. Update G network: maximize log(D(G(z)))
        loss_g = None
        p_gen = None
        for _ in range(1):
            fake_batch = net_g(torch.randn(batch_size, z_dim, 1, 1, 1, device=device))
            optimizer_g.zero_grad()
            d_out_fake = net_d(fake_batch)
            loss_g = criterion(d_out_fake, real_labels)
            # mean fake generator probability
            p_gen = d_out_fake.mean().item()
            loss_g.backward()
            optimizer_g.step()

        # Minkowski functional measures
        cube = net_g(fixed_noise).detach().squeeze().cpu()
        cube = cube.mul(0.5).add(0.5).numpy()
        cube = postprocess_cube(cube)
        cube = np.pad(cube, ((1, 1), (1, 1), (1, 1)), mode='constant', constant_values=0)
        v, s, b, xi = compute_minkowski(cube)

        return {'loss_d': loss_d.item(), 'loss_g': loss_g.item(), 'p_real': p_real,
                'p_fake': p_fake, 'p_gen': p_gen, 'V': v, 'S': s, 'B': b, 'Xi': xi}

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(dirname=str(experiment_dir),
                                         filename_prefix=CKPT_PREFIX,
                                         save_interval=5, n_saved=50, require_empty=False)

    # attach running average metrics
    monitoring_metrics = ['loss_d', 'loss_g', 'p_real', 'p_fake', 'p_gen', 'V', 'S', 'B', 'Xi']
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['loss_d']).attach(trainer, 'loss_d')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['loss_g']).attach(trainer, 'loss_g')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_real']).attach(trainer, 'p_real')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_fake']).attach(trainer, 'p_fake')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['p_gen']).attach(trainer, 'p_gen')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['V']).attach(trainer, 'V')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['S']).attach(trainer, 'S')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['B']).attach(trainer, 'B')
    RunningAverage(alpha=ALPHA, output_transform=lambda x: x['Xi']).attach(trainer, 'Xi')

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED)
    def print_logs(engine):
        if (engine.state.iteration - 1) % PRINT_FREQ == 0:
            fname = experiment_dir / LOGS_FNAME
            columns = ['iter'] + list(engine.state.metrics.keys())
            values = [str(engine.state.iteration)] + \
                     [str(round(value, 7)) for value in engine.state.metrics.values()]
            with fname.open(mode='a') as f:
                if f.tell() == 0:
                    print('\t'.join(columns), file=f)
                print('\t'.join(values), file=f)
            message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration:04d}/{iterations}]"
            for name, value in zip(engine.state.metrics.keys(), engine.state.metrics.values()):
                message += f" | {name}: {value:0.5f}"
            pbar.log_message(message)

    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={'net_g': net_g, 'net_d': net_d})

    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        df = pd.read_csv(experiment_dir / LOGS_FNAME, delimiter='\t')

        fig_1 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['loss_d'], label='loss_d', linestyle='dashed')
        plt.plot(df['iter'], df['loss_g'], label='loss_g')
        plt.xlabel('Iteration number')
        plt.legend()
        fig_1.savefig(experiment_dir / ('loss_' + PLOT_FNAME))
        plt.close(fig_1)

        fig_2 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['p_real'], label='p_real', linestyle='dashed')
        plt.plot(df['iter'], df['p_fake'], label='p_fake', linestyle='dashdot')
        plt.plot(df['iter'], df['p_gen'], label='p_gen')
        plt.xlabel('Iteration number')
        plt.legend()
        fig_2.savefig(experiment_dir / PLOT_FNAME)
        plt.close(fig_2)

        desired_v = [desired_minkowski[0]] * len(df['iter'])
        desired_s = [desired_minkowski[1]] * len(df['iter'])
        desired_b = [desired_minkowski[2]] * len(df['iter'])
        desired_xi = [desired_minkowski[3]] * len(df['iter'])

        fig_3 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['V'], label='V', color='b')
        plt.plot(df['iter'], desired_v, color='b', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional V')
        plt.legend()
        fig_3.savefig(experiment_dir / ('minkowski_V_' + PLOT_FNAME))
        plt.close(fig_3)

        fig_4 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['S'], label='S', color='r')
        plt.plot(df['iter'], desired_s, color='r', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional S')
        plt.legend()
        fig_4.savefig(experiment_dir / ('minkowski_S_' + PLOT_FNAME))
        plt.close(fig_4)

        fig_5 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['B'], label='B', color='g')
        plt.plot(df['iter'], desired_b, color='g', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional B')
        plt.legend()
        fig_5.savefig(experiment_dir / ('minkowski_B_' + PLOT_FNAME))
        plt.close(fig_5)

        fig_6 = plt.figure(figsize=(18, 12))
        plt.plot(df['iter'], df['Xi'], label='Xi', color='y')
        plt.plot(df['iter'], desired_xi, color='y', linestyle='dashed')
        plt.xlabel('Iteration number')
        plt.ylabel('Minkowski functional Xi')
        plt.legend()
        fig_6.savefig(experiment_dir / ('minkowski_Xi_' + PLOT_FNAME))
        plt.close(fig_6)

    if scheduler:
        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            desired_b = desired_minkowski[2]
            desired_xi = desired_minkowski[3]
            current_b = engine.state.metrics['B']
            current_xi = engine.state.metrics['Xi']
            delta = abs(desired_b - current_b) + abs(desired_xi - current_xi)
            scheduler_d.step(delta)
            scheduler_g.step(delta)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            create_plots(engine)
            checkpoint_handler(engine, {'net_g_exception': net_g,
                                        'net_d_exception': net_d})
        else:
            raise e

    trainer.run(loader, epochs)
def train(self, config, **kwargs):
    """Trains a model on the given configurations.

    :param config: A training configuration. Note that all parameters in the config
        can also be manually adjusted with --ARG=VALUE
    :param **kwargs: parameters to overwrite the yaml config
    """
    from pycocoevalcap.cider.cider import Cider

    config_parameters = train_util.parse_config_or_kwargs(config, **kwargs)
    config_parameters["seed"] = self.seed
    outputdir = os.path.join(
        config_parameters["outputpath"], config_parameters["model"],
        "{}_{}".format(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
                       uuid.uuid1().hex))

    # Early init because of creating dir
    checkpoint_handler = ModelCheckpoint(
        outputdir, "run", n_saved=1, require_empty=False, create_dir=True,
        score_function=lambda engine: engine.state.metrics["score"],
        score_name="score")

    logger = train_util.genlogger(os.path.join(outputdir, "train.log"))
    # print passed config parameters
    logger.info("Storing files in: {}".format(outputdir))
    train_util.pprint_dict(config_parameters, logger.info)

    zh = config_parameters["zh"]
    vocabulary = torch.load(config_parameters["vocab_file"])
    train_loader, cv_loader, info = self._get_dataloaders(config_parameters, vocabulary)
    config_parameters["inputdim"] = info["inputdim"]
    cv_key2refs = info["cv_key2refs"]
    logger.info("<== Estimating Scaler ({}) ==>".format(info["scaler"].__class__.__name__))
    logger.info("Feature: {} Input dimension: {} Vocab Size: {}".format(
        config_parameters["feature_file"], info["inputdim"], len(vocabulary)))

    model = self._get_model(config_parameters, len(vocabulary))
    if "pretrained_word_embedding" in config_parameters:
        embeddings = np.load(config_parameters["pretrained_word_embedding"])
        model.load_word_embeddings(
            embeddings,
            tune=config_parameters["tune_word_embedding"],
            projection=True)
    model = model.to(self.device)
    train_util.pprint_dict(model, logger.info, formatter="pretty")

    optimizer = getattr(torch.optim, config_parameters["optimizer"])(
        model.parameters(), **config_parameters["optimizer_args"])
    train_util.pprint_dict(optimizer, logger.info, formatter="pretty")

    criterion = torch.nn.CrossEntropyLoss().to(self.device)
    crtrn_imprvd = train_util.criterion_improver(config_parameters['improvecriterion'])

    def _train_batch(engine, batch):
        model.train()
        with torch.enable_grad():
            optimizer.zero_grad()
            output = self._forward(model, batch, "train")
            loss = criterion(output["packed_logits"], output["targets"]).to(self.device)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            optimizer.step()
            output["loss"] = loss.item()
            return output

    trainer = Engine(_train_batch)
    RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss")
    pbar = ProgressBar(persist=False, ascii=True, ncols=100)
    pbar.attach(trainer, ["running_loss"])

    key2pred = {}

    def _inference(engine, batch):
        model.eval()
        keys = batch[2]
        with torch.no_grad():
            output = self._forward(model, batch, "validation")
            seqs = output["seqs"].cpu().numpy()
            for (idx, seq) in enumerate(seqs):
                if keys[idx] in key2pred:
                    continue
                candidate = self._convert_idx2sentence(seq, vocabulary, zh)
                key2pred[keys[idx]] = [candidate, ]
            return output

    metrics = {
        "loss": Loss(criterion,
                     output_transform=lambda x: (x["packed_logits"], x["targets"]))
    }

    evaluator = Engine(_inference)

    def eval_cv(engine, key2pred, key2refs):
        scorer = Cider(zh=zh)
        score, scores = scorer.compute_score(key2refs, key2pred)
        engine.state.metrics["score"] = score
        key2pred.clear()

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, eval_cv, key2pred, cv_key2refs)

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, train_util.log_results, evaluator,
                              cv_loader, logger.info, ["loss", "score"])
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd, "score",
        {"model": model.state_dict(), "config": config_parameters, "scaler": info["scaler"]},
        os.path.join(outputdir, "saved.pth"))

    scheduler = getattr(torch.optim.lr_scheduler, config_parameters["scheduler"])(
        optimizer, **config_parameters["scheduler_args"])
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, train_util.update_lr,
                                scheduler, "score")

    evaluator.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {"model": model, })

    trainer.run(train_loader, max_epochs=config_parameters["epochs"])
    return outputdir
def main():
    # region Setup
    conf = parse_args()
    setup_seeds(conf.session.seed)
    tb_logger, tb_img_logger, json_logger = setup_all_loggers(conf)
    logger.info("Parsed configuration:\n" +
                pyaml.dump(OmegaConf.to_container(conf), safe=True, sort_dicts=False,
                           force_embed=True))

    # region Predicate classification engines
    datasets, dataset_metadata = build_datasets(conf.dataset)
    dataloaders = build_dataloaders(conf, datasets)

    model = build_model(conf.model, dataset_metadata["train"]).to(conf.session.device)
    criterion = PredicateClassificationCriterion(conf.losses)

    pred_class_trainer = Trainer(pred_class_training_step, conf)
    pred_class_trainer.model = model
    pred_class_trainer.criterion = criterion
    pred_class_trainer.optimizer, scheduler = build_optimizer_and_scheduler(
        conf.optimizer, pred_class_trainer.model)

    pred_class_validator = Validator(pred_class_validation_step, conf)
    pred_class_validator.model = model
    pred_class_validator.criterion = criterion

    pred_class_tester = Validator(pred_class_validation_step, conf)
    pred_class_tester.model = model
    pred_class_tester.criterion = criterion
    # endregion

    if "resume" in conf:
        checkpoint = Path(conf.resume.checkpoint).expanduser().resolve()
        logger.debug(f"Resuming checkpoint from {checkpoint}")
        Checkpoint.load_objects(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            checkpoint=torch.load(checkpoint, map_location=conf.session.device),
        )
        logger.info(f"Resumed from {checkpoint}, "
                    f"epoch {pred_class_trainer.state.epoch}, "
                    f"samples {pred_class_trainer.global_step()}")
    # endregion

    # region Predicate classification training callbacks
    def increment_samples(trainer: Trainer):
        images = trainer.state.batch[0]
        trainer.state.samples += len(images)

    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED, increment_samples)

    ProgressBar(persist=True, desc="Pred class train").attach(
        pred_class_trainer, output_transform=itemgetter("losses"))

    tb_logger.attach(
        pred_class_trainer,
        OptimizerParamsHandler(
            pred_class_trainer.optimizer,
            param_name="lr",
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_STARTED,
    )

    pred_class_trainer.add_event_handler(
        Events.ITERATION_COMPLETED, PredicateClassificationMeanAveragePrecisionBatch())
    pred_class_trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                         RecallAtBatch(sizes=(5, 10)))

    tb_logger.attach(
        pred_class_trainer,
        OutputHandler(
            "train",
            output_transform=lambda o: {
                **o["losses"],
                "pc/mAP": o["pc/mAP"].mean().item(),
                **{k: r.mean().item() for k, r in o["recalls"].items()},
            },
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.ITERATION_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification training",
        "train",
        json_logger=None,
        tb_logger=tb_logger,
        global_step_fn=pred_class_trainer.global_step,
    )
    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="train",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["train"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    tb_logger.attach(
        pred_class_trainer,
        EpochHandler(
            pred_class_trainer,
            tag="z",
            global_step_transform=pred_class_trainer.global_step,
        ),
        Events.EPOCH_COMPLETED,
    )

    pred_class_trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda _: pred_class_validator.run(dataloaders["val"]))
    # endregion

    # region Predicate classification validation callbacks
    ProgressBar(persist=True, desc="Pred class val").attach(pred_class_validator)

    if conf.losses["bce"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/bce"]).attach(
            pred_class_validator, "loss/bce")
    if conf.losses["rank"]["weight"] > 0:
        Average(output_transform=lambda o: o["losses"]["loss/rank"]).attach(
            pred_class_validator, "loss/rank")
    Average(output_transform=lambda o: o["losses"]["loss/total"]).attach(
        pred_class_validator, "loss/total")
    PredicateClassificationMeanAveragePrecisionEpoch(
        itemgetter("target", "output")).attach(pred_class_validator, "pc/mAP")
    RecallAtEpoch((5, 10), itemgetter("target", "output")).attach(
        pred_class_validator, "pc/recall_at")

    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        lambda val_engine: scheduler.step(val_engine.state.metrics["loss/total"]),
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        log_metrics,
        "Predicate classification validation",
        "val",
        json_logger,
        tb_logger,
        pred_class_trainer.global_step,
    )
    pred_class_validator.add_event_handler(
        Events.EPOCH_COMPLETED,
        PredicateClassificationLogger(
            grid=(2, 3),
            tag="val",
            logger=tb_img_logger.writer,
            metadata=dataset_metadata["val"],
            global_step_fn=pred_class_trainer.global_step,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        EarlyStopping(
            patience=conf.session.early_stopping.patience,
            score_function=lambda val_engine: -val_engine.state.metrics["loss/total"],
            trainer=pred_class_trainer,
        ),
    )
    pred_class_validator.add_event_handler(
        Events.COMPLETED,
        Checkpoint(
            {
                "model": pred_class_trainer.model,
                "optimizer": pred_class_trainer.optimizer,
                "scheduler": scheduler,
                "trainer": pred_class_trainer,
            },
            DiskSaver(Path(conf.checkpoint.folder).expanduser().resolve() / conf.fullname),
            score_function=lambda val_engine: val_engine.state.metrics["pc/recall_at_5"],
            score_name="pc_recall_at_5",
            n_saved=conf.checkpoint.keep,
            global_step_transform=pred_class_trainer.global_step,
        ),
    )
    # endregion

    if "test" in conf.dataset:
        # region Predicate classification testing callbacks
        if conf.losses["bce"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/bce"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/bce")
        if conf.losses["rank"]["weight"] > 0:
            Average(
                output_transform=lambda o: o["losses"]["loss/rank"],
                device=conf.session.device,
            ).attach(pred_class_tester, "loss/rank")
        Average(
            output_transform=lambda o: o["losses"]["loss/total"],
            device=conf.session.device,
        ).attach(pred_class_tester, "loss/total")
        PredicateClassificationMeanAveragePrecisionEpoch(
            itemgetter("target", "output")).attach(pred_class_tester, "pc/mAP")
        RecallAtEpoch((5, 10), itemgetter("target", "output")).attach(
            pred_class_tester, "pc/recall_at")

        ProgressBar(persist=True, desc="Pred class test").attach(pred_class_tester)

        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            log_metrics,
            "Predicate classification test",
            "test",
            json_logger,
            tb_logger,
            pred_class_trainer.global_step,
        )
        pred_class_tester.add_event_handler(
            Events.EPOCH_COMPLETED,
            PredicateClassificationLogger(
                grid=(2, 3),
                tag="test",
                logger=tb_img_logger.writer,
                metadata=dataset_metadata["test"],
                global_step_fn=pred_class_trainer.global_step,
            ),
        )
        # endregion

    # region Run
    log_effective_config(conf, pred_class_trainer, tb_logger)
    if not ("resume" in conf and conf.resume.test_only):
        max_epochs = conf.session.max_epochs
        if "resume" in conf:
            max_epochs += pred_class_trainer.state.epoch
        pred_class_trainer.run(
            dataloaders["train"],
            max_epochs=max_epochs,
            seed=conf.session.seed,
            epoch_length=len(dataloaders["train"]),
        )

    if "test" in conf.dataset:
        pred_class_tester.run(dataloaders["test"])

    add_session_end(tb_logger.writer, "SUCCESS")
    tb_logger.close()
    tb_img_logger.close()
def run(config, logger):
    plx_logger = PolyaxonLogger()

    set_seed(config.seed)

    plx_logger.log_params(**{
        "seed": config.seed,
        "batch_size": config.batch_size,
        "pytorch version": torch.__version__,
        "ignite version": ignite.__version__,
        "cuda version": torch.version.cuda
    })

    device = config.device
    non_blocking = config.non_blocking
    prepare_batch = config.prepare_batch

    def stats_collect_function(engine, batch):
        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)

        y_ohe = to_onehot(y.reshape(-1), config.num_classes)
        class_distrib = y_ohe.mean(dim=0).cpu()
        class_presence = (class_distrib > 1e-3).cpu().float()
        num_classes = (class_distrib > 1e-3).sum().item()

        engine.state.class_presence += class_presence
        engine.state.class_presence -= (1 - class_presence)

        return {
            "class_distrib": class_distrib,
            "class_presence": engine.state.class_presence,
            "num_classes": num_classes
        }

    stats_collector = Engine(stats_collect_function)
    ProgressBar(persist=True).attach(stats_collector)

    @stats_collector.on(Events.STARTED)
    def init_vars(engine):
        engine.state.class_presence = torch.zeros(config.num_classes)

    log_dir = get_outputs_path()
    if log_dir is None:
        log_dir = "output"
    tb_logger = TensorboardLogger(log_dir=log_dir)

    tb_handler = tb_output_handler(tag="training", output_transform=lambda x: x)
    tb_logger.attach(stats_collector, log_handler=tb_handler,
                     event_name=Events.ITERATION_COMPLETED)

    stats_collector.run(config.train_loader, max_epochs=1)

    remove_handler(stats_collector, tb_handler, Events.ITERATION_COMPLETED)

    tb_logger.attach(stats_collector,
                     log_handler=tb_output_handler(tag="validation",
                                                   output_transform=lambda x: x),
                     event_name=Events.ITERATION_COMPLETED)

    stats_collector.run(config.val_loader, max_epochs=1)
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. "
            "Please, define only one of them")

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED,
                lambda engine: cast(_LRScheduler, lr_scheduler).step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler "
                "arguments should be also defined")
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(to_save,
                                        cast(Union[Callable, BaseSaveHandler], save_handler),
                                        filename_prefix="training", **kwargs)
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters),
                                  checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer,
            name="gpu",
            event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should either mapping or sequence, but given {type(x)}")

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED,
            closing_event_name=Events.COMPLETED)
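This private helper is the worker behind ignite's public ignite.contrib.engines.common.setup_common_training_handlers. A hedged usage sketch of the public wrapper (keyword set as of ignite 0.4; the toy model, update function, and loader are placeholders, not from the original code):

import torch
from torch import nn
from ignite.engine import Engine
from ignite.contrib.engines import common

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def update_fn(engine, batch):
    optimizer.zero_grad()
    loss = model(batch).sum()
    loss.backward()
    optimizer.step()
    return {'batch loss': loss.item()}

trainer = Engine(update_fn)
common.setup_common_training_handlers(
    trainer,
    to_save={'model': model, 'optimizer': optimizer},
    output_path='/tmp/checkpoints',  # DiskSaver writes "training_*" checkpoints here
    output_names=['batch loss'],     # running averages shown in the progress bar
    with_pbars=True,
    log_every_iters=10,
)

loader = [torch.randn(4, 10) for _ in range(20)]
trainer.run(loader, max_epochs=2)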
def run(output_path, config): device = "cuda" local_rank = config['local_rank'] distributed = backend is not None if distributed: torch.cuda.set_device(local_rank) device = "cuda" rank = dist.get_rank() if distributed else 0 # Rescale batch_size and num_workers ngpus_per_node = torch.cuda.device_count() ngpus = dist.get_world_size() if distributed else 1 batch_size = config['batch_size'] // ngpus num_workers = int( (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node) train_labelled_loader, test_loader = \ get_train_test_loaders(path=config['data_path'], batch_size=batch_size, distributed=distributed, num_workers=num_workers) model = get_model(config['model']) model = model.to(device) if distributed: model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[ local_rank, ], output_device=local_rank) optimizer = optim.SGD(model.parameters(), lr=config['learning_rate'], momentum=config['momentum'], weight_decay=config['weight_decay'], nesterov=True) criterion = nn.CrossEntropyLoss().to(device) le = len(train_labelled_loader) milestones_values = [(0, 0.0), (le * config['num_warmup_epochs'], config['learning_rate']), (le * config['num_epochs'], 0.0)] lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values) def _prepare_batch(batch, device, non_blocking): x, y = batch return (convert_tensor(x, device=device, non_blocking=non_blocking), convert_tensor(y, device=device, non_blocking=non_blocking)) def process_function(engine, labelled_batch): x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True) model.train() # Supervised part y_pred = model(x) loss = criterion(y_pred, y) optimizer.zero_grad() loss.backward() optimizer.step() return { 'batch loss': loss.item(), } trainer = Engine(process_function) if not hasattr(lr_scheduler, "step"): trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) else: trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step()) metric_names = [ 'batch loss', ] def output_transform(x, name): return x[name] for n in metric_names: # We compute running average values on the output (batch loss) across all devices RunningAverage(output_transform=partial(output_transform, name=n), epoch_bound=False, device=device).attach(trainer, n) if rank == 0: checkpoint_handler = ModelCheckpoint(dirname=output_path, filename_prefix="checkpoint") trainer.add_event_handler(Events.ITERATION_COMPLETED(every=1000), checkpoint_handler, { 'model': model, 'optimizer': optimizer }) ProgressBar(persist=True, bar_format="").attach(trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED) if config['display_iters']: ProgressBar(persist=False, bar_format="").attach(trainer, metric_names=metric_names) tb_logger = TensorboardLogger(log_dir=output_path) tb_logger.attach(trainer, log_handler=tbOutputHandler( tag="train", metric_names=metric_names), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=tbOptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED) metrics = { "accuracy": Accuracy(device=device if distributed else None), "loss": Loss(criterion, device=device if distributed else None) } evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=device, non_blocking=True) def run_validation(engine): torch.cuda.synchronize() train_evaluator.run(train_labelled_loader) evaluator.run(test_loader) 
    trainer.add_event_handler(Events.EPOCH_STARTED(every=3), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config['display_iters']:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=tbOutputHandler(tag="train", metric_names=list(metrics.keys()), another_engine=trainer),
            event_name=Events.COMPLETED,
        )
        tb_logger.attach(
            evaluator,
            log_handler=tbOutputHandler(tag="test", metric_names=list(metrics.keys()), another_engine=trainer),
            event_name=Events.COMPLETED,
        )

        # Store the best model
        def default_score_fn(engine):
            score = engine.state.metrics['accuracy']
            return score

        # `config` is a mapping here (it is indexed with config['...'] throughout),
        # so look the optional score function up with .get() rather than hasattr(),
        # which only works for attribute access and would always fall through.
        score_function = config.get("score_function", default_score_fn)

        best_model_handler = ModelCheckpoint(
            dirname=output_path,
            filename_prefix="best",
            n_saved=3,
            global_step_transform=global_step_from_engine(trainer),
            score_name="val_accuracy",
            score_function=score_function,
        )
        evaluator.add_event_handler(Events.COMPLETED, best_model_handler, {'model': model})

    trainer.run(train_labelled_loader, max_epochs=config['num_epochs'])

    if rank == 0:
        tb_logger.close()
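For reference, a hedged sketch of the config mapping `run` expects. Every key below is read somewhere in the function body, but the values are illustrative placeholders, and the call assumes the script's module-level globals (`backend`, the data/model helpers) are already set up, e.g. by a distributed launcher.

# Illustrative config for run() above; values are placeholders.
config = {
    "local_rank": 0,
    "batch_size": 128,        # divided by the number of participating GPUs
    "num_workers": 8,         # rescaled per node inside run()
    "data_path": "/tmp/cifar10",
    "model": "resnet18",
    "learning_rate": 0.1,
    "momentum": 0.9,
    "weight_decay": 1e-4,
    "num_warmup_epochs": 4,   # linear warmup segment of PiecewiseLinear
    "num_epochs": 24,
    "display_iters": True,
    # optional: "score_function": my_score_fn,
}
run(output_path="/tmp/output", config=config)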
def main(batch_size, epochs, length_scale, centroid_size, model_output_size, learning_rate,
         l_gradient_penalty, gamma, weight_decay, final_model, input_dep_ls, use_grad_norm):
    # Dataset prep
    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])

    val_dataset.transform = test_dataset.transform

    kwargs = {"num_workers": 4, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs
    )

    # Model
    global model
    model = ResNet_DUQ(input_size, num_classes, centroid_size, model_output_size, length_scale, gamma)
    model = model.cuda()
    # model.load_state_dict(torch.load("DUQ_CIFAR_75.pt"))

    # Optimiser
    optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[25, 50, 75], gamma=0.2)

    def bce_loss_fn(y_pred, y):
        bce = F.binary_cross_entropy(y_pred, y, reduction="sum").div(num_classes * y_pred.shape[0])
        return bce

    def output_transform_bce(output):
        y_pred, y, x = output
        y = F.one_hot(y, num_classes).float()
        return y_pred, y

    def output_transform_acc(output):
        y_pred, y, x = output
        return y_pred, y

    def output_transform_gp(output):
        y_pred, y, x = output
        return x, y_pred

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]
        gradients = gradients.flatten(start_dim=1)
        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)
        # L2 norm
        grad_norm = gradients.norm(2, dim=1)
        # Two sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()
        return gradient_penalty

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()

        if l_gradient_penalty > 0:
            x.requires_grad_(True)

        z, y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = bce_loss_fn(y_pred, y)

        # Only pay for the penalty computation when it is enabled
        if l_gradient_penalty > 0:
            loss += l_gradient_penalty * calc_gradient_penalty(x, y_pred)

        if use_grad_norm:
            # Gradient normalisation
            loss /= (1 + l_gradient_penalty)

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    trainer = Engine(step)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        # Log every 10th epoch and the last few epochs
        if trainer.state.epoch % 10 == 0 or trainer.state.epoch > epochs - 5:
            # Accuracy on the CIFAR test set and AUROC on the CIFAR+SVHN test sets
            testacc, auroc_cifsv = get_cifar_svhn_ood(model)
            # Accuracy on the CIFAR val set and self-AUROC on the CIFAR val set
            val_acc, self_auroc = get_auroc_classification(val_dataset, model)
            print(f"Test Accuracy: {testacc}, AUROC: {auroc_cifsv}")
            print(f"AUROC - uncertainty: {self_auroc}, Val Accuracy: {val_acc}")

        scheduler.step()

        # Save a checkpoint near the end of training
        if trainer.state.epoch == epochs - 1:
            torch.save(model.state_dict(), f"model_{trainer.state.epoch}.pt")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    trainer.run(train_loader, max_epochs=epochs)

    testacc, auroc_cifsv = get_cifar_svhn_ood(model)
    val_acc, self_auroc = get_auroc_classification(val_dataset, model)

    return testacc, auroc_cifsv, val_acc, self_auroc
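To make the two-sided gradient penalty concrete, here is a self-contained sketch that reproduces the two helpers above and checks the arithmetic on a toy linear map (this snippet is purely illustrative and does not depend on the rest of the script):

import torch

def calc_gradients_input(x, y_pred):
    # d(y_pred)/dx, flattened to one vector per sample
    grads = torch.autograd.grad(
        outputs=y_pred,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True,
    )[0]
    return grads.flatten(start_dim=1)

def calc_gradient_penalty(x, y_pred):
    grad_norm = calc_gradients_input(x, y_pred).norm(2, dim=1)
    # Two-sided: penalises norms both above and below 1
    return ((grad_norm - 1) ** 2).mean()

# y = sum(2*x) per sample => gradient is 2 everywhere, norm = 2*sqrt(dim)
x = torch.randn(4, 3, requires_grad=True)
y = (2 * x).sum(dim=1)
print(calc_gradient_penalty(x, y).item())  # (2*sqrt(3) - 1)^2 ≈ 6.07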
    epoch = engine.state.epoch
    evaluator.run(val_ld)
    val_wra_vle = round(evaluator.state.metrics['WRA'], 3)
    print(f"EPOCH:[{epoch}] VAL WRA:{val_wra_vle}")

@trainer.on(Events.COMPLETED)
def test(engine):
    print("TEST EVAL")
    evaluator.run(test_ld)
    test_wra_vle = round(evaluator.state.metrics["WRA"], 3)
    report = f"{RUN_NAME};{test_wra_vle}\n"
    with EVALUATION_RESULTS_FILE_PATH.open(mode='a') as f:
        f.write(report)
    print(f"TRAINING IS DONE FOR {RUN_NAME} RUN.")

pbar = ProgressBar()

checkpointer = ModelCheckpoint(
    CHECKPOINTS_RUN_DIR_PATH,
    filename_prefix=RUN_NAME.lower(),
    n_saved=None,
    score_function=lambda engine: round(engine.state.metrics['WRA'], 3),
    score_name='WRA',
    atomic=True,
    require_empty=True,
    create_dir=True,
    archived=False,
    global_step_transform=global_step_from_engine(trainer),
)

nan_handler = TerminateOnNan()

coslr = CosineAnnealingScheduler(opt, "lr",