def prune_train_loop(model, params, ds, dset, min_y, base_data, model_id,
                     prune_type, device, batch_size, tpa, max_epochs=2):
    assert prune_type in ['global_unstructured', 'structured']
    total_prune_amount = tpa
    ds_train, ds_valid = ds
    train_set, valid_set = dset
    min_y_train, min_y_val = min_y
    model_id = f'{model_id}_{prune_type}_pruning_{tpa}'
    valid_freq = 200 * 500 // batch_size // 3
    conv_layers = [model.conv1]

    def prune_model(model):
        # remove_amount = total_prune_amount // (max_epochs)
        remove_amount = total_prune_amount
        print(f'pruned model by {remove_amount}')
        worst = select_filters(model, ds_valid, valid_set, remove_amount, device)
        worst = [k for k in Counter(torch.stack(worst).view(-1).cpu().numpy()).keys()]
        worst.sort(reverse=True)
        print(worst)
        for layer in conv_layers:
            for d in worst:
                TuckerStructured(layer, name='weight', amount=0, dim=0, filt=d)
        return worst

    bad = prune_model(model)

    # Sanity check: filters whose masks are all-zero should match the filters
    # selected for pruning.
    zeros = []
    wrong = []
    for i in range(len(model.conv1.weight_mask)):
        if torch.sum(model.conv1.weight_mask[i]) == 0.0:
            zeros.append(i)
    zeros.sort(reverse=True)
    if zeros == bad:
        print("correctly zero'd filters")
    else:
        if len(zeros) == len(bad):
            for i in range(len(zeros)):
                if zeros[i] != bad[i]:
                    wrong.append((bad[i], zeros[i]))
            print(wrong)
        else:
            print("diff number filters zero'd", zeros)

    with create_summary_writer(model, ds_train, base_data, model_id,
                               device=device) as writer:
        lr = params['lr']
        mom = params['momentum']
        wd = params['l2_wd']
        optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=mom,
                                    weight_decay=wd)
        sched = ReduceLROnPlateau(optimizer, factor=0.5, patience=5)
        loss = F.cross_entropy  # plain functional loss for the training step
        acc_metric = Accuracy(device=device)
        loss_metric = Loss(F.cross_entropy, device=device)
        acc_val_metric = Accuracy(device=device)
        loss_val_metric = Loss(F.cross_entropy, device=device)

        def train_step(engine, batch):
            model.train()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            optimizer.zero_grad()
            ans = model.forward(x)
            l = loss(ans, y)
            l.backward()
            optimizer.step()
            with torch.no_grad():
                for layer in conv_layers:
                    layer.weight *= layer.weight_mask  # make sure pruned weights stay 0
            return l.item()

        trainer = Engine(train_step)

        def train_eval_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_train
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        train_evaluator = Engine(train_eval_step)
        acc_metric.attach(train_evaluator, "accuracy")
        loss_metric.attach(train_evaluator, 'loss')

        def validation_step(engine, batch):
            model.eval()
            x, y = batch
            x = x.to(device)
            y = y.to(device) - min_y_val
            with torch.no_grad():
                ans = model.forward(x)
            return ans, y

        valid_evaluator = Engine(validation_step)
        acc_val_metric.attach(valid_evaluator, "accuracy")
        loss_val_metric.attach(valid_evaluator, 'loss')

        @trainer.on(Events.ITERATION_COMPLETED(every=valid_freq))
        # @trainer.on(Events.ITERATION_COMPLETED)
        def log_validation_results(engine):
            valid_evaluator.run(ds_valid)
            metrics = valid_evaluator.state.metrics
            valid_avg_accuracy = metrics['accuracy']
            avg_nll = metrics['loss']
            print("Validation Results - Epoch: {} Avg accuracy: {:.2f} Avg loss: {:.2f}"
                  .format(engine.state.epoch, valid_avg_accuracy, avg_nll))
            writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
            writer.add_scalar("validation/avg_accuracy", valid_avg_accuracy,
                              engine.state.epoch)
            writer.add_scalar("validation/avg_error", 1. - valid_avg_accuracy,
                              engine.state.epoch)
            # prune_model(model)

        @trainer.on(Events.EPOCH_COMPLETED)
        def lr_scheduler(engine):
            metrics = valid_evaluator.state.metrics
            # ReduceLROnPlateau defaults to mode='min', so step on the
            # validation loss rather than the accuracy.
            sched.step(metrics['loss'])

        @trainer.on(Events.ITERATION_COMPLETED(every=100))
        def log_training_loss(engine):
            batch = engine.state.batch
            ds = DataLoader(TensorDataset(*batch), batch_size=batch_size)
            train_evaluator.run(ds)
            metrics = train_evaluator.state.metrics
            accuracy = metrics['accuracy']
            nll = metrics['loss']
            # per-epoch iteration index ("iter" shadows the builtin, so rename)
            iter_num = (engine.state.iteration - 1) % len(ds_train) + 1
            if iter_num % 100 == 0:
                print("Epoch[{}] Iter[{}/{}] Accuracy: {:.2f} Loss: {:.2f}".format(
                    engine.state.epoch, iter_num, len(ds_train), accuracy, nll))
            writer.add_scalar("batchtraining/detloss", nll, engine.state.iteration)
            writer.add_scalar("batchtraining/accuracy", accuracy, engine.state.iteration)
            writer.add_scalar("batchtraining/error", 1. - accuracy, engine.state.iteration)
            writer.add_scalar("batchtraining/loss", engine.state.output, engine.state.iteration)

        @trainer.on(Events.EPOCH_COMPLETED)
        def log_lr(engine):
            writer.add_scalar("lr", optimizer.param_groups[0]['lr'], engine.state.epoch)

        # score function for checkpointing; reads the metrics computed by the
        # validation handler that fires at the same event
        def validation_value(engine):
            metrics = valid_evaluator.state.metrics
            return metrics['accuracy']

        to_save = {'model': model}
        handler = Checkpoint(
            to_save,
            DiskSaver(os.path.join(base_data, model_id), create_dir=True),
            score_function=validation_value,
            score_name="val_acc",
            global_step_transform=global_step_from_engine(trainer),
            n_saved=None)

        # kick everything off
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=valid_freq), handler)
        trainer.run(ds_train, max_epochs=max_epochs)
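# TuckerStructured above is project-specific code that is not shown here. For
# orientation only, a minimal sketch (hypothetical FilterPruning class, not the
# project's implementation) of how a filter-level pruning method can be built
# on torch.nn.utils.prune: compute_mask() zeroes one output filter, and apply()
# installs the weight_orig / weight_mask pair that train_step above relies on.
import torch
from torch.nn.utils import prune

class FilterPruning(prune.BasePruningMethod):
    PRUNING_TYPE = "structured"

    def __init__(self, filt, dim=0):
        self.filt = filt  # index of the output filter to zero out
        self.dim = dim    # dimension along which filters are indexed

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.select(self.dim, self.filt).zero_()  # zero the whole filter
        return mask

# usage sketch: prune filter 2 of a small conv layer
conv = torch.nn.Conv2d(3, 8, 3)
FilterPruning.apply(conv, name="weight", filt=2, dim=0)
assert conv.weight_mask.select(0, 2).sum() == 0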
    loss_v.backward()
    optimizer.step()
    epsilon_tracker.frame(engine.state.iteration)
    if engine.state.iteration % EVAL_EVERY_FRAME == 0:
        eval_states = getattr(engine.state, "eval_states", None)
        if eval_states is None:
            eval_states = buffer.sample(STATES_TO_EVALUATE)
            eval_states = [
                np.array(transition.state, copy=False)
                for transition in eval_states
            ]
            eval_states = np.array(eval_states, copy=False)
            # cache under the same attribute name that is read back above
            engine.state.eval_states = eval_states
        evaluate_states(eval_states, net, device, engine)
    return {"loss": loss_v.item(), "epsilon": selector.epsilon}


engine = Engine(process_batch)
common.setup_ignite(engine, params, exp_source, NAME,
                    extra_metrics=("adv", "val"))
engine.run(common.batch_generator(buffer, params.replay_initial,
                                  params.batch_size))
def test_attach_fail_with_string():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, "a")
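# For contrast, a minimal sketch of the accepted usage (assuming the same
# module-level update_fn as above): ProgressBar.attach expects metric_names to
# be a list of strings, not a bare string.
def test_attach_with_metric_names_list():
    engine = Engine(update_fn)
    pbar = ProgressBar()
    pbar.attach(engine, metric_names=["a"])  # list form is accepted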
def train(self, model, crit, optimizer, train_loader, valid_loader,
          src_vocab, tgt_vocab, n_epochs, lr_scheduler=None):
    trainer = Engine(self.step)
    trainer.config = self.config
    trainer.model, trainer.crit = model, crit
    trainer.optimizer, trainer.lr_scheduler = optimizer, lr_scheduler
    trainer.epoch_idx = 0

    evaluator = Engine(self.validate)
    evaluator.config = self.config
    evaluator.model, evaluator.crit = model, crit
    evaluator.best_loss = np.inf

    self.attach(trainer, evaluator, verbose=self.config.verbose)

    def run_validation(engine, evaluator, valid_loader):
        evaluator.run(valid_loader, max_epochs=1)
        if engine.lr_scheduler is not None and not engine.config.use_noam_decay:
            engine.lr_scheduler.step()

    trainer.add_event_handler(Events.EPOCH_COMPLETED, run_validation,
                              evaluator, valid_loader)
    evaluator.add_event_handler(Events.EPOCH_COMPLETED, self.check_best)
    evaluator.add_event_handler(
        Events.EPOCH_COMPLETED,
        self.save_model,
        trainer,
        self.config,
        src_vocab,
        tgt_vocab,
    )

    trainer.run(train_loader, max_epochs=n_epochs)
    return model
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0,
                        help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed
    # (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                           output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=None)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.logdir, 'checkpoint',
                                             save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, tb_logger.writer.logdir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.logdir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.logdir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.logdir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
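# average_distributed_scalar is referenced above but not shown. A plausible
# minimal implementation (an assumption, not necessarily this project's exact
# helper): average a scalar metric across distributed processes with
# all_reduce, passing it through unchanged in the non-distributed case.
def average_distributed_scalar(scalar, args):
    if args.local_rank == -1:  # not distributed: nothing to average
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device)
    scalar_t /= torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()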
    eps_tracker.frame(engine.state.iteration)
    if getattr(engine.state, "eval_states", None) is None:
        eval_states = buffer.sample(STATES_TO_EVALUATE)
        eval_states = [
            np.array(transition.state, copy=False)
            for transition in eval_states
        ]
        engine.state.eval_states = np.array(eval_states, copy=False)
    return {
        "loss": loss_v.item(),
        "epsilon": selector.epsilon,
    }


engine = Engine(process_batch)
tb = common.setup_ignite(engine, exp_source, f"conv-{args.run}",
                         extra_metrics=("values_mean",))


@engine.on(ptan.ignite.PeriodEvents.ITERS_1000_COMPLETED)
def sync_eval(engine: Engine):
    tgt_net.sync()

    mean_val = common.calc_values_of_states(engine.state.eval_states, net,
                                            device=device)
    engine.state.metrics["values_mean"] = mean_val
    if getattr(engine.state, "best_mean_val", None) is None:
        engine.state.best_mean_val = mean_val
def main(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu)
    cfg = load_config(args.config)

    # keep a copy of the config and model definition next to the run outputs
    path, config_name = os.path.split(args.config)
    copyfile(args.config, os.path.join(cfg.workdir, config_name))
    copyfile(os.path.join(path, "model.py"), os.path.join(cfg.workdir, "model.py"))

    pbar = ProgressBar()
    tb_logger = TensorboardLogger(log_dir=os.path.join(cfg.workdir, "tb_logs"))
    checkpointer = ModelCheckpoint(os.path.join(cfg.workdir, "checkpoints"), '',
                                   save_interval=1, n_saved=cfg.n_epochs,
                                   create_dir=True, atomic=True)

    def _update(engine, batch):
        cfg.model.train()
        cfg.optimizer.zero_grad()
        x, y = cfg.prepare_train_batch(batch)
        y_pred = cfg.model(**x)
        loss = cfg.loss_fn(y_pred, y)
        loss['loss'].backward()
        cfg.optimizer.step()
        for k in loss:
            loss[k] = loss[k].item()
        return loss

    trainer = Engine(_update)
    pbar.attach(trainer,
                output_transform=lambda x: {k: "{:.5f}".format(v) for k, v in x.items()})
    trainer.add_event_handler(Events.ITERATION_STARTED, cfg.scheduler)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpointer,
                              {'model': cfg.model, 'optimizer': cfg.optimizer})
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="training",
                                               output_transform=lambda x: x),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=OptimizerParamsHandler(cfg.optimizer),
                     event_name=Events.ITERATION_STARTED)
    # tb_logger.attach(trainer,
    #                  log_handler=WeightsScalarHandler(cfg.model),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer,
    #                  log_handler=WeightsHistHandler(cfg.model),
    #                  event_name=Events.EPOCH_COMPLETED)
    # tb_logger.attach(trainer,
    #                  log_handler=GradsScalarHandler(cfg.model),
    #                  event_name=Events.ITERATION_COMPLETED)
    # tb_logger.attach(trainer,
    #                  log_handler=GradsHistHandler(cfg.model),
    #                  event_name=Events.EPOCH_COMPLETED)

    def _evaluate(engine, batch):
        cfg.model.eval()
        x, y = cfg.prepare_train_batch(batch)
        batch_size = len(batch[list(batch.keys())[0]])
        with torch.no_grad():
            y_pred = cfg.model(**x)
            loss = cfg.loss_fn(y_pred, y)
        # accumulate dataset-weighted averages into engine.state.metrics
        for k in loss:
            loss[k] = loss[k].item()
            if k not in engine.state.metrics:
                engine.state.metrics[k] = 0.0
            engine.state.metrics[k] += loss[k] * batch_size / len(cfg.valid_ds)
        return loss

    evaluator = Engine(_evaluate)
    pbar.attach(evaluator,
                output_transform=lambda x: {k: "{:.5f}".format(v) for k, v in x.items()})

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate_on_valid_dl(engine):
        evaluator.run(cfg.valid_dl)

    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(
                         tag="validation",
                         metric_names=['loss', 'rot_loss_cos', 'rot_loss_l1',
                                       'trans_loss', 'true_distance', 'cls_loss'],
                         global_step_transform=global_step_from_engine(trainer)),
                     event_name=Events.EPOCH_COMPLETED)

    trainer.run(cfg.train_dl, cfg.n_epochs)
    tb_logger.close()
res["loss_distill"] = dist_loss_t.item() res.update( { "loss": loss_t.item(), "loss_value": loss_value_t.item(), "loss_policy": loss_policy_t.item(), "adv": adv_t.mean().item(), "loss_entropy": loss_entropy_t.item(), "time_batch": time.time() - start_ts, } ) return res engine = Engine(process_batch) common.setup_ignite( engine, params, exp_source, NAME + "_" + args.name, extra_metrics=("test_reward", "avg_test_reward", "test_steps"), ) @engine.on(ptan_ignite.PeriodEvents.ITERS_10000_COMPLETED) def test_network(engine): net.actor.train(False) obs = test_env.reset() reward = 0.0 steps = 0
def test_add_event_handler_raises_with_invalid_signature():
    engine = Engine(MagicMock())

    def handler(engine):
        pass

    engine.add_event_handler(Events.STARTED, handler)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler, 1)

    def handler_with_args(engine, a):
        pass

    engine.add_event_handler(Events.STARTED, handler_with_args, 1)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_args)

    def handler_with_kwargs(engine, b=42):
        pass

    engine.add_event_handler(Events.STARTED, handler_with_kwargs, b=2)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_kwargs, c=3)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_kwargs, 1, b=2)

    def handler_with_args_and_kwargs(engine, a, b=42):
        pass

    engine.add_event_handler(Events.STARTED, handler_with_args_and_kwargs, 1, b=2)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_args_and_kwargs, 1, 2, b=2)
    with pytest.raises(ValueError):
        engine.add_event_handler(Events.STARTED, handler_with_args_and_kwargs, 1, b=2, c=3)
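# A minimal sketch of why the cases above pass or fail: the extra *args and
# **kwargs given to add_event_handler are appended after the engine when the
# handler fires, so the handler's signature must be able to absorb exactly
# those arguments. (Hypothetical handler names for illustration.)
def on_started(engine, msg, repeat=1):
    print(msg * repeat)

engine = Engine(lambda e, b: None)
engine.add_event_handler(Events.STARTED, on_started, "hello ", repeat=2)
engine.run([0])  # prints "hello hello "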
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0'

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset, batch_size=batch_size,
                                   shuffle=True, num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size,
                                  shuffle=False, num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)
    model = model.to(device)

    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # LambdaLR multiplies the optimizer's base lr by the returned factor, so
    # the factor itself must not include lr again (the original "lr * min(...)"
    # would scale by lr twice).
    lr_lambda = lambda epoch: min(1., epoch / warmup)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1,
                                         n_saved=2, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(
        trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x: (x['total_loss'],
                                     torch.empty(x['total_loss'].shape[0]))
         ).attach(evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['nll'],
                                         torch.empty(x['nll'].shape[0]))
             ).attach(evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)
            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            # run one forward pass to data-dependently initialize ActNorm
            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        scheduler.step()

        metrics = evaluator.state.metrics
        losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer, start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=config["output_path"],
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
            )
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(
            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should either mapping or sequence, but given {type(x)}"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
                trainer, n
            )

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
            )

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED
        )
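# A minimal usage sketch of the public wrapper around the helper above,
# ignite.contrib.engines.common.setup_common_training_handlers, on a toy model
# and in-memory data (all names here are illustrative assumptions).
import torch
from ignite.engine import Engine
from ignite.contrib.engines import common

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

def train_step(engine, batch):
    x, y = batch
    loss = torch.nn.functional.mse_loss(model(x), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return {"batch loss": loss.item()}

trainer = Engine(train_step)
common.setup_common_training_handlers(
    trainer,
    output_names=["batch loss"],  # tracked via RunningAverage, shown in pbar
    with_pbars=True,
)
data = [(torch.randn(4, 2), torch.randn(4, 1))] * 8
trainer.run(data, max_epochs=2)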
def run_trainer(data_loader: dict, model: models, optimizer: optim,
                lr_scheduler: optim.lr_scheduler, criterion: nn,
                train_epochs: int, log_training_progress_every: int,
                log_val_progress_every: int, checkpoint_every: int,
                tb_summaries_dir: str, chkpt_dir: str, resume_from: str,
                to_device: object, to_cpu: object, attackers: object = None,
                train_adv_periodic_ops: int = None, *args, **kwargs):

    def mk_lr_step(loss):
        lr_scheduler.step(loss)

    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, y = map(lambda _: to_device(_), batch)
        # periodically replace the clean batch with an adversarial one
        if (train_adv_periodic_ops is not None) and (
                engine.state.iteration % train_adv_periodic_ops == 0):
            random_attacker = random.choice(list(attackers))
            x = attackers[random_attacker].perturb(x, y)
        y_pred = model(x)
        loss = criterion(y_pred, y)
        loss.backward()
        optimizer.step()
        return loss.item()

    def eval_step(engine, batch):
        model.eval()
        with torch.no_grad():
            x, y = map(lambda _: to_device(_), batch)
            # attack roughly half of the evaluation batches
            if random.random() < 0.5:
                random_attacker = random.choice(list(attackers))
                x = attackers[random_attacker].perturb(x, y)
            y_pred = model(x)
            return y_pred, y

    def chkpt_score_func(engine):
        val_eval.run(data_loader['val'])
        y_pred, y = val_eval.state.output
        loss = criterion(y_pred, y)
        # Checkpoint keeps the checkpoints with the highest scores, so return
        # the negated validation loss to keep the best models.
        return -np.mean(to_cpu(loss, convert_to_np=True))

    # set up ignite engines
    trainer = Engine(train_step)
    train_eval = Engine(eval_step)
    val_eval = Engine(eval_step)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_training_progress_every))
    def log_training_results(engine):
        step = True
        run_type = 'train'
        train_eval.run(data_loader['train'])
        y_pred, y = train_eval.state.output
        loss = criterion(y_pred, y)
        log_results(to_cpu(y_pred, convert_to_np=True),
                    to_cpu(y, convert_to_np=True),
                    to_cpu(loss, convert_to_np=True), run_type, step,
                    engine.state.iteration, total_train_steps, writer)

    @trainer.on(Events.ITERATION_COMPLETED(every=log_val_progress_every))
    def log_val_results(engine):
        step = True
        run_type = 'val'
        val_eval.run(data_loader['val'])
        y_pred, y = val_eval.state.output
        loss = criterion(y_pred, y)
        mk_lr_step(loss)
        log_results(to_cpu(y_pred, convert_to_np=True),
                    to_cpu(y, convert_to_np=True),
                    to_cpu(loss, convert_to_np=True), run_type, step,
                    engine.state.iteration, total_train_steps, writer)

    # set up vars
    total_train_steps = len(data_loader['train']) * train_epochs

    # reporter to identify memory usage bottlenecks throughout network
    reporter = MemReporter()
    print_model(model, reporter)

    # set up tensorboard summary writer
    writer = create_summary_writer(model, data_loader['train'], tb_summaries_dir)

    # move model to device
    model = to_device(model)

    # set up progress bar
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
    pbar = ProgressBar(persist=True, bar_format="")
    pbar.attach(trainer, ['loss'])

    # set up checkpoint
    objects_to_checkpoint = {
        'trainer': trainer,
        'model': model,
        'optimizer': optimizer,
        'lr_scheduler': lr_scheduler
    }
    training_checkpoint = Checkpoint(to_save=objects_to_checkpoint,
                                     save_handler=DiskSaver(chkpt_dir,
                                                            require_empty=False),
                                     n_saved=3,
                                     filename_prefix='best',
                                     score_function=chkpt_score_func,
                                     score_name='val_loss')

    # register events
    trainer.add_event_handler(Events.ITERATION_COMPLETED(every=checkpoint_every),
                              training_checkpoint)

    # if resuming
    if resume_from and os.path.exists(resume_from):
        print(f'resume model from: {resume_from}')
        checkpoint = torch.load(resume_from)
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    # fire training engine
    trainer.run(data_loader['train'], max_epochs=train_epochs)
def train(args):
    device = torch.device("cuda" if args.cuda else "cpu")

    train_loader = check_dataset(args)
    transformer = TransformerNet().to(device)
    optimizer = Adam(transformer.parameters(), args.lr)
    mse_loss = torch.nn.MSELoss()

    vgg = Vgg16(requires_grad=False).to(device)
    style_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.mul(255))
    ])
    style = utils.load_image(args.style_image, size=args.style_size)
    style = style_transform(style)
    style = style.repeat(args.batch_size, 1, 1, 1).to(device)

    features_style = vgg(utils.normalize_batch(style))
    gram_style = [utils.gram_matrix(y) for y in features_style]

    running_avgs = OrderedDict()

    def step(engine, batch):
        x, _ = batch
        x = x.to(device)
        n_batch = len(x)

        optimizer.zero_grad()

        y = transformer(x)

        x = utils.normalize_batch(x)
        y = utils.normalize_batch(y)

        features_x = vgg(x)
        features_y = vgg(y)

        content_loss = args.content_weight * mse_loss(features_y.relu2_2,
                                                      features_x.relu2_2)

        style_loss = 0.0
        for ft_y, gm_s in zip(features_y, gram_style):
            gm_y = utils.gram_matrix(ft_y)
            style_loss += mse_loss(gm_y, gm_s[:n_batch, :, :])
        style_loss *= args.style_weight

        total_loss = content_loss + style_loss
        total_loss.backward()
        optimizer.step()

        return {
            "content_loss": content_loss.item(),
            "style_loss": style_loss.item(),
            "total_loss": total_loss.item(),
        }

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(args.checkpoint_model_dir, "checkpoint",
                                         n_saved=10, require_empty=False,
                                         create_dir=True)
    progress_bar = Progbar(loader=train_loader, metrics=running_avgs)

    trainer.add_event_handler(
        event_name=Events.EPOCH_COMPLETED(every=args.checkpoint_interval),
        handler=checkpoint_handler,
        to_save={"net": transformer},
    )
    trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED,
                              handler=progress_bar)
    trainer.run(train_loader, max_epochs=args.epochs)
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default='wikitext-2',
                        help="One of ('wikitext-103', 'wikitext-2') or a dict of splits paths.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--embed_dim", type=int, default=410, help="Embeddings dim")
    parser.add_argument("--hidden_dim", type=int, default=2100, help="Hidden dimension")
    parser.add_argument("--num_max_positions", type=int, default=256, help="Max input length")
    parser.add_argument("--num_heads", type=int, default=10, help="Number of heads")
    parser.add_argument("--num_layers", type=int, default=16, help="Number of layers")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--initializer_range", type=float, default=0.02,
                        help="Normal initialization standard deviation")
    parser.add_argument("--sinusoidal_embeddings", action="store_true",
                        help="Use sinusoidal embeddings")
    parser.add_argument("--mlm", action="store_true",
                        help="Train with masked-language modeling loss instead of language modeling")
    parser.add_argument("--mlm_probability", type=float, default=0.15,
                        help="Ratio of tokens to mask for masked language modeling loss")
    parser.add_argument("--train_batch_size", type=int, default=8,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=8,
                        help="Batch size for validation")
    parser.add_argument("--lr", type=float, default=2.5e-4, help="Learning rate")
    parser.add_argument("--max_norm", type=float, default=0.25, help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=0.0, help="Weight decay")
    parser.add_argument("--n_epochs", type=int, default=200, help="Number of training epochs")
    parser.add_argument("--n_warmup", type=int, default=1000, help="Number of warmup iterations")
    parser.add_argument("--eval_every", type=int, default=-1,
                        help="Evaluate every X steps (-1 => end of epoch)")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Accumulate gradient")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log on main process only, logger.warning => log on all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))  # This is a logger.info: only printed on the first process

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, model and optimizer")
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased',
                                              do_lower_case=False)  # Let's use a pre-defined tokenizer
    args.num_embeddings = len(tokenizer.vocab)  # We need this to create the model (number of embeddings to use)
    model = TransformerWithLMHead(args)
    model.to(args.device)
    optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler, train_num_words, valid_num_words = \
        get_data_loaders(args, tokenizer)

    # Prepare masked tokens inputs/labels for masked language modeling:
    # 80% MASK, 10% random, 10% original
    def mask_tokens(inputs):
        labels = inputs.clone()
        masked_indices = torch.bernoulli(
            torch.full(labels.shape, args.mlm_probability)).byte()
        labels[~masked_indices] = -1  # We only compute loss on masked tokens

        # 80% of the time, replace masked input tokens with [MASK]
        indices_replaced = torch.bernoulli(
            torch.full(labels.shape, 0.8)).byte() & masked_indices
        inputs[indices_replaced] = tokenizer.vocab["[MASK]"]

        # 10% of the time, replace masked input tokens with a random word
        indices_random = torch.bernoulli(
            torch.full(labels.shape, 0.5)).byte() & masked_indices & ~indices_replaced
        random_words = torch.randint(args.num_embeddings, labels.shape,
                                     dtype=torch.long, device=args.device)
        inputs[indices_random] = random_words[indices_random]
        return inputs, labels

    # Training function and trainer
    def update(engine, batch):
        model.train()
        inputs = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
        inputs, labels = mask_tokens(inputs) if args.mlm else (inputs, inputs)  # Prepare masked input/labels if we use masked LM
        logits, loss = model(inputs, labels=labels)
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
            inputs, labels = mask_tokens(inputs) if args.mlm else (inputs, inputs)  # Prepare masked input/labels if we use masked LM
            logits = model(inputs)
            shift_logits = logits[:-1] if not args.mlm else logits
            shift_labels = labels[1:] if not args.mlm else labels
            return shift_logits.view(-1, logits.size(-1)), shift_labels.view(-1)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch
    # and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(val_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Learning rate schedule: linearly warm-up to lr and then decrease
    # the learning rate to zero with cosine schedule
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.n_epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr, args.n_warmup)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar,
                                                 metrics["nll"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    # Let's convert sub-word perplexities to word perplexities.
    # If you need details: http://sjmielke.com/comparing-perplexities.htm
    metrics["average_word_ppl"] = MetricsLambda(
        lambda x: math.exp(x * val_loader.dataset.numel() / valid_num_words),
        metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints
    # and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        checkpoint_handler, tb_logger = add_logging_and_checkpoint_saving(
            trainer, evaluator, metrics, model, optimizer, args)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint for easy re-loading
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        tb_logger.close()
def test_default_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)

    with raises(ValueError):
        engine.run([1])
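# By contrast, a hedged sketch of overriding the default behaviour: attaching
# a handler to Events.EXCEPTION_RAISED lets the caller intercept the error
# instead of having engine.run() re-raise it.
def test_custom_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)
    caught = []

    @engine.on(Events.EXCEPTION_RAISED)
    def on_exception(engine, e):
        caught.append(e)  # swallow the error; the run terminates gracefully

    engine.run([1])
    assert isinstance(caught[0], ValueError)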
def main(): """ Basic UNet as implemented in MONAI for Fetal Brain Segmentation, but using ignite to manage training and validation loop and checkpointing :return: """ """ Read input and configuration parameters """ parser = argparse.ArgumentParser( description='Run basic UNet with MONAI - Ignite version.') parser.add_argument('--config', dest='config', metavar='config', type=str, help='config file') args = parser.parse_args() with open(args.config) as f: config_info = yaml.load(f, Loader=yaml.FullLoader) # print to log the parameter setups print(yaml.dump(config_info)) # GPU params cuda_device = config_info['device']['cuda_device'] num_workers = config_info['device']['num_workers'] # training and validation params loss_type = config_info['training']['loss_type'] batch_size_train = config_info['training']['batch_size_train'] batch_size_valid = config_info['training']['batch_size_valid'] lr = float(config_info['training']['lr']) lr_decay = config_info['training']['lr_decay'] if lr_decay is not None: lr_decay = float(lr_decay) nr_train_epochs = config_info['training']['nr_train_epochs'] validation_every_n_epochs = config_info['training'][ 'validation_every_n_epochs'] sliding_window_validation = config_info['training'][ 'sliding_window_validation'] if 'model_to_load' in config_info['training'].keys(): model_to_load = config_info['training']['model_to_load'] if not os.path.exists(model_to_load): raise BlockingIOError( "cannot find model: {}".format(model_to_load)) else: model_to_load = None if 'manual_seed' in config_info['training'].keys(): seed = config_info['training']['manual_seed'] else: seed = None # data params data_root = config_info['data']['data_root'] training_list = config_info['data']['training_list'] validation_list = config_info['data']['validation_list'] # model saving out_model_dir = os.path.join( config_info['output']['out_model_dir'], datetime.now().strftime('%Y-%m-%d_%H-%M-%S') + '_' + config_info['output']['output_subfix']) print("Saving to directory ", out_model_dir) if 'cache_dir' in config_info['output'].keys(): out_cache_dir = config_info['output']['cache_dir'] else: out_cache_dir = os.path.join(out_model_dir, 'persistent_cache') max_nr_models_saved = config_info['output']['max_nr_models_saved'] val_image_to_tensorboad = config_info['output']['val_image_to_tensorboad'] monai.config.print_config() logging.basicConfig(stream=sys.stdout, level=logging.INFO) torch.cuda.set_device(cuda_device) if seed is not None: # set manual seed if required (both numpy and torch) set_determinism(seed=seed) # # set torch only seed # torch.manual_seed(seed) # torch.backends.cudnn.deterministic = True # torch.backends.cudnn.benchmark = False """ Data Preparation """ # create cache directory to store results for Persistent Dataset persistent_cache: Path = Path(out_cache_dir) persistent_cache.mkdir(parents=True, exist_ok=True) # create training and validation data lists train_files = create_data_list(data_folder_list=data_root, subject_list=training_list, img_postfix='_Image', label_postfix='_Label') print(len(train_files)) print(train_files[0]) print(train_files[-1]) val_files = create_data_list(data_folder_list=data_root, subject_list=validation_list, img_postfix='_Image', label_postfix='_Label') print(len(val_files)) print(val_files[0]) print(val_files[-1]) # data preprocessing for training: # - convert data to right format [batch, channel, dim, dim, dim] # - apply whitening # - resize to (96, 96) in-plane (preserve z-direction) # - define 2D patches to be extracted # - add data 
augmentation (random rotation and random flip) # - squeeze to 2D train_transforms = Compose([ LoadNiftid(keys=['img', 'seg']), AddChanneld(keys=['img', 'seg']), NormalizeIntensityd(keys=['img']), Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]), RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False), RandRotated(keys=['img', 'seg'], degrees=90, prob=0.2, spatial_axes=[0, 1], interp_order=[1, 0], reshape=False), RandFlipd(keys=['img', 'seg'], spatial_axis=[0, 1]), SqueezeDimd(keys=['img', 'seg'], dim=-1), ToTensord(keys=['img', 'seg']) ]) # create a training data loader # train_ds = monai.data.Dataset(data=train_files, transform=train_transforms) # train_ds = monai.data.CacheDataset(data=train_files, transform=train_transforms, cache_rate=1.0, # num_workers=num_workers) train_ds = monai.data.PersistentDataset(data=train_files, transform=train_transforms, cache_dir=persistent_cache) train_loader = DataLoader(train_ds, batch_size=batch_size_train, shuffle=True, num_workers=num_workers, collate_fn=list_data_collate, pin_memory=torch.cuda.is_available()) # check_train_data = monai.utils.misc.first(train_loader) # print("Training data tensor shapes") # print(check_train_data['img'].shape, check_train_data['seg'].shape) # data preprocessing for validation: # - convert data to right format [batch, channel, dim, dim, dim] # - apply whitening # - resize to (96, 96) in-plane (preserve z-direction) if sliding_window_validation: val_transforms = Compose([ LoadNiftid(keys=['img', 'seg']), AddChanneld(keys=['img', 'seg']), NormalizeIntensityd(keys=['img']), Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]), ToTensord(keys=['img', 'seg']) ]) do_shuffle = False collate_fn_to_use = None else: # - add extraction of 2D slices from validation set to emulate how loss is computed at training val_transforms = Compose([ LoadNiftid(keys=['img', 'seg']), AddChanneld(keys=['img', 'seg']), NormalizeIntensityd(keys=['img']), Resized(keys=['img', 'seg'], spatial_size=[96, 96], interp_order=[1, 0], anti_aliasing=[True, False]), RandSpatialCropd(keys=['img', 'seg'], roi_size=[96, 96, 1], random_size=False), SqueezeDimd(keys=['img', 'seg'], dim=-1), ToTensord(keys=['img', 'seg']) ]) do_shuffle = True collate_fn_to_use = list_data_collate # create a validation data loader # val_ds = monai.data.Dataset(data=val_files, transform=val_transforms) # val_ds = monai.data.CacheDataset(data=val_files, transform=val_transforms, cache_rate=1.0, # num_workers=num_workers) val_ds = monai.data.PersistentDataset(data=val_files, transform=val_transforms, cache_dir=persistent_cache) val_loader = DataLoader(val_ds, batch_size=batch_size_valid, shuffle=do_shuffle, collate_fn=collate_fn_to_use, num_workers=num_workers) # check_valid_data = monai.utils.misc.first(val_loader) # print("Validation data tensor shapes") # print(check_valid_data['img'].shape, check_valid_data['seg'].shape) """ Network preparation """ # Create UNet, DiceLoss and Adam optimizer. 
net = monai.networks.nets.UNet( dimensions=2, in_channels=1, out_channels=1, channels=(16, 32, 64, 128, 256), strides=(2, 2, 2, 2), num_res_units=2, ) loss_function = monai.losses.DiceLoss(do_sigmoid=True) opt = torch.optim.Adam(net.parameters(), lr) device = torch.cuda.current_device() if lr_decay is not None: lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=opt, gamma=lr_decay, last_epoch=-1) """ Set ignite trainer """ # function to manage batch at training def prepare_batch(batch, device=None, non_blocking=False): return _prepare_batch((batch['img'], batch['seg']), device, non_blocking) trainer = create_supervised_trainer(model=net, optimizer=opt, loss_fn=loss_function, device=device, non_blocking=False, prepare_batch=prepare_batch) # adding checkpoint handler to save models (network params and optimizer stats) during training if model_to_load is not None: checkpoint_handler = CheckpointLoader(load_path=model_to_load, load_dict={ 'net': net, 'opt': opt, }) checkpoint_handler.attach(trainer) state = trainer.state_dict() else: checkpoint_handler = ModelCheckpoint(out_model_dir, 'net', n_saved=max_nr_models_saved, require_empty=False) # trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=save_params) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'net': net, 'opt': opt }) # StatsHandler prints loss at every iteration and print metrics at every epoch train_stats_handler = StatsHandler(name='trainer') train_stats_handler.attach(trainer) # TensorBoardStatsHandler plots loss at every iteration and plots metrics at every epoch, same as StatsHandler writer_train = SummaryWriter(log_dir=os.path.join(out_model_dir, "train")) train_tensorboard_stats_handler = TensorBoardStatsHandler( summary_writer=writer_train) train_tensorboard_stats_handler.attach(trainer) if lr_decay is not None: print("Using Exponential LR decay") lr_schedule_handler = LrScheduleHandler(lr_scheduler, print_lr=True, name="lr_scheduler", writer=writer_train) lr_schedule_handler.attach(trainer) """ Set ignite evaluator to perform validation at training """ # set parameters for validation metric_name = 'Mean_Dice' # add evaluation metric to the evaluator engine val_metrics = { "Loss": 1.0 - MeanDice(add_sigmoid=True, to_onehot_y=False), "Mean_Dice": MeanDice(add_sigmoid=True, to_onehot_y=False) } def _sliding_window_processor(engine, batch): net.eval() with torch.no_grad(): val_images, val_labels = batch['img'].to(device), batch['seg'].to( device) roi_size = (96, 96, 1) seg_probs = sliding_window_inference(val_images, roi_size, batch_size_valid, net) return seg_probs, val_labels if sliding_window_validation: # use sliding window inference at validation print("3D evaluator is used") net.to(device) evaluator = Engine(_sliding_window_processor) for name, metric in val_metrics.items(): metric.attach(evaluator, name) else: # ignite evaluator expects batch=(img, seg) and returns output=(y_pred, y) at every iteration, # user can add output_transform to return other values print("2D evaluator is used") evaluator = create_supervised_evaluator(model=net, metrics=val_metrics, device=device, non_blocking=True, prepare_batch=prepare_batch) epoch_len = len(train_ds) // train_loader.batch_size validation_every_n_iters = validation_every_n_epochs * epoch_len @trainer.on(Events.ITERATION_COMPLETED(every=validation_every_n_iters)) def run_validation(engine): evaluator.run(val_loader) # add early stopping handler to evaluator # early_stopper = 
EarlyStopping(patience=4, # score_function=stopping_fn_from_metric(metric_name), # trainer=trainer) # evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper) # add stats event handler to print validation stats via evaluator val_stats_handler = StatsHandler( name='evaluator', output_transform=lambda x: None, # no need to print loss value, so disable per iteration output global_epoch_transform=lambda x: trainer.state.epoch ) # fetch global epoch number from trainer val_stats_handler.attach(evaluator) # add handler to record metrics to TensorBoard at every validation epoch writer_valid = SummaryWriter(log_dir=os.path.join(out_model_dir, "valid")) val_tensorboard_stats_handler = TensorBoardStatsHandler( summary_writer=writer_valid, output_transform=lambda x: None, # no need to plot loss value, so disable per iteration output global_epoch_transform=lambda x: trainer.state.iteration ) # fetch global iteration number from trainer val_tensorboard_stats_handler.attach(evaluator) # add handler to draw the first image and the corresponding label and model output in the last batch # here we draw the 3D output as GIF format along the depth axis, every 2 validation iterations. if val_image_to_tensorboad: val_tensorboard_image_handler = TensorBoardImageHandler( summary_writer=writer_valid, batch_transform=lambda batch: (batch['img'], batch['seg']), output_transform=lambda output: predict_segmentation(output[0]), global_iter_transform=lambda x: trainer.state.epoch) evaluator.add_event_handler( event_name=Events.ITERATION_COMPLETED(every=1), handler=val_tensorboard_image_handler) """ Run training """ state = trainer.run(train_loader, nr_train_epochs) print("Done!")
def attach(self, engine: Engine):
    engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    engine.register_events(*PeriodEvents)
    for e in PeriodEvents:
        State.event_to_attr[e] = "iteration"
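# For context, a minimal self-contained sketch (hypothetical TickEvents enum,
# standing in for PeriodEvents) of the pattern this attach() relies on: declare
# custom events as an EventEnum, register them, map them to a state counter,
# then fire them manually from a built-in event.
from ignite.engine import Engine, Events
from ignite.engine.events import EventEnum, State

class TickEvents(EventEnum):
    EVERY_10_ITERS = "every_10_iters"

engine = Engine(lambda e, b: None)
engine.register_events(*TickEvents)
State.event_to_attr[TickEvents.EVERY_10_ITERS] = "iteration"

@engine.on(Events.ITERATION_COMPLETED)
def maybe_fire(engine):
    if engine.state.iteration % 10 == 0:
        engine.fire_event(TickEvents.EVERY_10_ITERS)

@engine.on(TickEvents.EVERY_10_ITERS)
def on_tick(engine):
    print("tick at iteration", engine.state.iteration)

engine.run(range(30), max_epochs=1)  # ticks at iterations 10, 20, 30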
def train(self, config, **kwargs): """Trains a given model specified in the config file or passed as the --model parameter. All options in the config file can be overwritten as needed by passing --PARAM Options with variable lengths ( e.g., kwargs can be passed by --PARAM '{"PARAM1":VAR1, "PARAM2":VAR2}' :param config: yaml config file :param **kwargs: parameters to overwrite yaml config """ config_parameters = utils.parse_config_or_kwargs(config, **kwargs) outputdir = os.path.join( config_parameters['outputpath'], config_parameters['model'], "{}_{}".format( datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'), uuid.uuid1().hex)) # Early init because of creating dir checkpoint_handler = ModelCheckpoint( outputdir, 'run', n_saved=3, require_empty=False, create_dir=True, score_function=self._negative_loss, score_name='loss') logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log')) logger.info("Storing files in {}".format(outputdir)) # utils.pprint_dict utils.pprint_dict(config_parameters, logger.info) logger.info("Running on device {}".format(DEVICE)) label_df = pd.read_csv(config_parameters['label'], sep='\s+') data_df = pd.read_csv(config_parameters['data'], sep='\s+') # In case that both are not matching merged = data_df.merge(label_df, on='filename') common_idxs = merged['filename'] data_df = data_df[data_df['filename'].isin(common_idxs)] label_df = label_df[label_df['filename'].isin(common_idxs)] train_df, cv_df = utils.split_train_cv( label_df, **config_parameters['data_args']) train_label = utils.df_to_dict(train_df) cv_label = utils.df_to_dict(cv_df) data = utils.df_to_dict(data_df) transform = utils.parse_transforms(config_parameters['transforms']) torch.save(config_parameters, os.path.join(outputdir, 'run_config.pth')) logger.info("Transforms:") utils.pprint_dict(transform, logger.info, formatter='pretty') assert len(cv_df) > 0, "Fraction a bit too large?" 
    trainloader = dataset.gettraindataloader(
        h5files=data,
        h5labels=train_label,
        transform=transform,
        label_type=config_parameters['label_type'],
        batch_size=config_parameters['batch_size'],
        num_workers=config_parameters['num_workers'],
        shuffle=True,
    )
    cvdataloader = dataset.gettraindataloader(
        h5files=data,
        h5labels=cv_label,
        label_type=config_parameters['label_type'],
        transform=None,
        shuffle=False,
        batch_size=config_parameters['batch_size'],
        num_workers=config_parameters['num_workers'],
    )
    # NB: the getattr default must be the class itself, not the string 'CRNN';
    # a string fallback would not be callable
    model = getattr(models, config_parameters['model'], models.CRNN)(
        inputdim=trainloader.dataset.datadim,
        outputdim=2,
        **config_parameters['model_args'])
    if 'pretrained' in config_parameters and config_parameters['pretrained'] is not None:
        model_dump = torch.load(config_parameters['pretrained'], map_location='cpu')
        model_state = model.state_dict()
        # only copy tensors whose name and shape match the current model
        pretrained_state = {
            k: v
            for k, v in model_dump.items()
            if k in model_state and v.size() == model_state[k].size()
        }
        model_state.update(pretrained_state)
        model.load_state_dict(model_state)
        logger.info("Loading pretrained model {}".format(
            config_parameters['pretrained']))

    model = model.to(DEVICE)
    optimizer = getattr(torch.optim, config_parameters['optimizer'])(
        model.parameters(), **config_parameters['optimizer_args'])
    utils.pprint_dict(optimizer, logger.info, formatter='pretty')
    utils.pprint_dict(model, logger.info, formatter='pretty')
    if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
        logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

    def _train_batch(_, batch):
        model.train()
        with torch.enable_grad():
            optimizer.zero_grad()
            output = self._forward(model, batch)  # output is a tuple (clip, frame, target)
            loss = criterion(*output)
            loss.backward()  # single loss
            optimizer.step()
        return loss.item()

    def _inference(_, batch):
        model.eval()
        with torch.no_grad():
            return self._forward(model, batch)

    def thresholded_output_transform(output):
        # unpack the tuple from _forward: frame-level predictions/targets sit
        # at positions 1 and 2; the last element holds the sequence lengths
        _, y_pred, y, y_clip, length = output
        batchsize, timesteps, ndim = y.shape
        # build a padding mask, e.g. timesteps=4, length=[2, 4]
        # -> mask = [[1, 1, 0, 0], [1, 1, 1, 1]]
        idxs = torch.arange(timesteps, device='cpu').repeat(batchsize).view(
            batchsize, timesteps)
        mask = (idxs < length.view(-1, 1)).to(y.device)
        y = y * mask.unsqueeze(-1)
        y_pred = torch.round(y_pred)
        y = torch.round(y)
        return y_pred, y

    metrics = {
        'Loss': losses.Loss(criterion),  # reimplementation of Loss that supports the 3-way loss
        'Precision': Precision(thresholded_output_transform),
        'Recall': Recall(thresholded_output_transform),
        'Accuracy': Accuracy(thresholded_output_transform),
    }
    train_engine = Engine(_train_batch)
    inference_engine = Engine(_inference)
    for name, metric in metrics.items():
        metric.attach(inference_engine, name)

    def compute_metrics(engine):
        inference_engine.run(cvdataloader)
        results = inference_engine.state.metrics
        output_str_list = [
            "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
        ]
        for metric in metrics:
            output_str_list.append("{} {:<5.2f}".format(metric, results[metric]))
        logger.info(" ".join(output_str_list))
        pbar.n = pbar.last_print_n = 0  # reset the progress bar

    pbar = ProgressBar(persist=False)
    pbar.attach(train_engine)
    train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=5000),
                                   compute_metrics)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

    early_stop_handler = EarlyStopping(
        patience=config_parameters['early_stop'],
        score_function=self._negative_loss,
        trainer=train_engine)
    inference_engine.add_event_handler(Events.EPOCH_COMPLETED, early_stop_handler)
    inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                       checkpoint_handler, {'model': model})
    train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
    return outputdir
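# Both ModelCheckpoint and EarlyStopping above maximize their score_function,
# so self._negative_loss presumably negates the validation loss. A minimal
# sketch of such a method on the surrounding class (an assumption; the real
# definition is not shown in this excerpt):
@staticmethod
def _negative_loss(engine):
    # 'Loss' is the metric name attached to inference_engine above
    return -engine.state.metrics['Loss']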
def __call__(self, engine: Engine):
    for period, event in self.INTERVAL_TO_EVENT.items():
        if engine.state.iteration % period == 0:
            engine.fire_event(event)
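# For context, a minimal sketch of how the PeriodEvents enum and the
# INTERVAL_TO_EVENT map used by the attach()/__call__ pair above might be
# declared. This is an assumption inferred from those two methods, not code
# from the source; the member names and periods are illustrative.
import enum

class PeriodEvents(enum.Enum):
    ITERS_100_COMPLETED = "iterations_100_completed"
    ITERS_1000_COMPLETED = "iterations_1000_completed"

# maps an iteration period to the custom event fired at that period
INTERVAL_TO_EVENT = {
    100: PeriodEvents.ITERS_100_COMPLETED,
    1000: PeriodEvents.ITERS_1000_COMPLETED,
}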
def attach(self, engine: Engine): """ Args: engine: Ignite Engine, it can be a trainer, validator or evaluator. """ return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
def attach(self, engine: Engine):
    engine.add_event_handler(Events.ITERATION_COMPLETED, self)
    engine.register_events(*EpisodeEvents)
    State.event_to_attr[EpisodeEvents.EPISODE_COMPLETED] = "episode"
    State.event_to_attr[EpisodeEvents.BOUND_REWARD_REACHED] = "episode"
    State.event_to_attr[EpisodeEvents.BEST_REWARD_REACHED] = "episode"
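# Hypothetical sketch of the EpisodeEvents enum assumed by the attach() above,
# with the three members matching the event_to_attr entries (illustrative
# values, not code from the source):
import enum

class EpisodeEvents(enum.Enum):
    EPISODE_COMPLETED = "episode_completed"
    BOUND_REWARD_REACHED = "bound_reward_reached"
    BEST_REWARD_REACHED = "best_reward_reached"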
def main():
    config.print_config()
    logging.basicConfig(stream=sys.stdout, level=logging.INFO)

    tempdir = tempfile.mkdtemp()
    print(f"generating synthetic data to {tempdir} (this may take a while)")
    for i in range(5):
        im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1)
        n = nib.Nifti1Image(im, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"im{i:d}.nii.gz"))
        n = nib.Nifti1Image(seg, np.eye(4))
        nib.save(n, os.path.join(tempdir, f"seg{i:d}.nii.gz"))

    images = sorted(glob(os.path.join(tempdir, "im*.nii.gz")))
    segs = sorted(glob(os.path.join(tempdir, "seg*.nii.gz")))

    # define transforms for image and segmentation
    imtrans = Compose([ScaleIntensity(), AddChannel(), ToTensor()])
    segtrans = Compose([AddChannel(), ToTensor()])
    ds = NiftiDataset(images, segs, transform=imtrans, seg_transform=segtrans,
                      image_only=False)

    device = torch.device("cuda:0")
    net = UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    )
    net.to(device)

    # define sliding window size and batch size for windows inference
    roi_size = (96, 96, 96)
    sw_batch_size = 4

    def _sliding_window_processor(engine, batch):
        net.eval()
        with torch.no_grad():
            val_images, val_labels = batch[0].to(device), batch[1].to(device)
            seg_probs = sliding_window_inference(val_images, roi_size,
                                                 sw_batch_size, net)
            return seg_probs, val_labels

    evaluator = Engine(_sliding_window_processor)

    # add evaluation metric to the evaluator engine
    MeanDice(add_sigmoid=True, to_onehot_y=False).attach(evaluator, "Mean_Dice")

    # StatsHandler prints loss at every iteration and metrics at every epoch;
    # we don't need loss for an evaluator, so only metrics are printed
    # (print functions can be customized by the user)
    val_stats_handler = StatsHandler(
        name="evaluator",
        output_transform=lambda x: None,  # no need to print loss value, so disable per iteration output
    )
    val_stats_handler.attach(evaluator)

    # for the array data format, assume the 3rd item of batch data is the meta_data
    file_saver = SegmentationSaver(
        output_dir="tempdir",
        output_ext=".nii.gz",
        output_postfix="seg",
        name="evaluator",
        batch_transform=lambda x: x[2],
        output_transform=lambda output: predict_segmentation(output[0]),
    )
    file_saver.attach(evaluator)

    # the model was trained by the "unet_training_array" example
    # (renamed from ckpt_saver: this handler loads a checkpoint, it does not save one)
    ckpt_loader = CheckpointLoader(load_path="./runs/net_checkpoint_100.pth",
                                   load_dict={"net": net})
    ckpt_loader.attach(evaluator)

    # sliding window inference for one image at every iteration
    loader = DataLoader(ds, batch_size=1, num_workers=1,
                        pin_memory=torch.cuda.is_available())
    state = evaluator.run(loader)
    shutil.rmtree(tempdir)
def attach(self, engine: Engine, manual_step: bool = False):
    self._timer.attach(
        engine, step=None if manual_step else Events.ITERATION_COMPLETED)
    engine.add_event_handler(EpisodeEvents.EPISODE_COMPLETED, self)
    with torch.no_grad():
        # Anneal learning rate (Adam-style bias correction applied to mu)
        mu = next(mu_scheme)
        i = engine.state.iteration
        for group in optimizer.param_groups:
            group["lr"] = mu * math.sqrt(1 - 0.999 ** i) / (1 - 0.9 ** i)

    return {
        "elbo": elbo.item(),
        "kl": kl_divergence.item(),
        "sigma": sigma,
        "mu": mu
    }

# Trainer and metrics
trainer = Engine(step)
metric_names = ["elbo", "kl", "sigma", "mu"]
RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
ProgressBar().attach(trainer, metric_names=metric_names)

# Model checkpointing
checkpoint_handler = ModelCheckpoint("./", "checkpoint", save_interval=1,
                                     n_saved=3, require_empty=False)
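# The checkpoint handler above is created but not attached in this excerpt; a
# typical hookup (an assumption about the surrounding script; `model` is taken
# to be the module being optimized) would save at the end of every epoch:
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {"model": model})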
def main():
    # Init state params
    params = init_parms()
    device = params.get('device')

    # Load the model, optimizer & criterion
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    model = torch.nn.DataParallel(model)
    logger.info(
        f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    load_checkpoint(model, optimizer, params)
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()

    # Training / validation progress bars
    pbar = ProgressBar(persist=True, desc="Loss")
    pbar_valid = ProgressBar(persist=True, desc="Validate")

    # timer keeps track of the average time per batch
    timer = Timer(average=True)

    # load all the train data
    logger.info('Beginning to load Datasets')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path,
                                           'train-labelled-en')
    # form data loaders
    train = lmdbMultiDatasetTester(roots=[trainAirtelPaymentsPath],
                                   transform=image_val_transform)
    logger.info(f'loaded train & test dataset = {len(train)}')

    def train_update_function(engine, _):
        optimizer.zero_grad()
        imgs_sup, labels_sup, label_lengths, input_lengths = next(
            engine.state.train_loader_labelled)
        imgs_sup = imgs_sup.to(device)
        # labels stay on CPU: CTC-style losses accept CPU targets
        probs_sup = model(imgs_sup)
        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths,
                                 input_lengths)
        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        img, labels, label_lengths, image_lengths = batch
        y_pred = model(img.to(device))
        # occasionally print a decoded sample as a sanity check
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = sequence_to_string(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_loader = torch.utils.data.DataLoader(train,
                                               batch_size=train_batch_size,
                                               shuffle=True,
                                               num_workers=config.workers,
                                               pin_memory=True,
                                               collate_fn=allign_collate)
    trainer = Engine(train_update_function)
    evaluator = Engine(validate_update_function)
    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    iteration_log_step = int(0.33 * len(train_loader))
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='min',
        factor=config.lr_gamma,
        patience=int(config.epochs * 0.05),
        verbose=True,
        threshold_mode="abs",
        cooldown=int(config.epochs * 0.025),
        min_lr=1e-5)
    pbar.attach(trainer, output_transform=lambda x: {'loss': x})
    pbar_valid.attach(evaluator, ['wer', 'cer'],
                      event_name=Events.EPOCH_COMPLETED,
                      closing_event_name=Events.COMPLETED)
    timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        if hasattr(engine.state, 'train_loader_labelled'):
            del engine.state.train_loader_labelled
        engine.state.train_loader_labelled = iter(train_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step == 0) and (engine.state.iteration > 0):
            engine.state.epoch += 1
            train.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            # note: validation here runs over the training loader
            evaluator.run(train_loader)
            model.train()
            logger.info('Model set back to train mode')

    @trainer.on(Events.EPOCH_COMPLETED)
    def after_complete(engine):
        logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
            engine.state.epoch, timer.value()))
        timer.reset()

    trainer.run(train_loader, max_epochs=epochs)
    tb_logger.close()  # tb_logger is assumed to be created elsewhere in this script
def test_pbar_fail_with_non_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, output_transform=1)
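# Companion sketch (not in the source): the same attach call succeeds once the
# transform is callable, e.g. one that wraps the raw output in a dict for
# display. Assumes the update_fn fixture from the surrounding test module.
def test_pbar_with_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()
    pbar.attach(engine, output_transform=lambda out: {"loss": out})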
def main(
    dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs,
    saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation,
    flow_coupling, LU_decomposed, learn_top, y_condition, y_weight,
    max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches,
    output_dir, saved_optimizer, warmup, classifier_weight
):
    # use the function parameters directly; the original referenced an
    # undefined `args` object inside this scope
    device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0"

    wandb.init(project=dataset)
    check_manual_seed(seed)

    image_shape = (64, 64, 3)
    # if dataset == "task1": num_classes = 24
    # else: num_classes = 40
    num_classes = 40
    multi_class = True  # note: currently unused (multi-class support is pending)

    # if dataset == "task1":
    #     dataset_train = CLEVRDataset(root_folder=dataroot, img_folder=dataroot + 'images/')
    # else:
    #     dataset_train = CelebALoader(root_folder=dataroot)
    dataset_train = CelebALoader(root_folder=dataroot)
    train_loader = DataLoader(dataset_train, batch_size=batch_size,
                              shuffle=True, drop_last=True)

    model = Glow(
        image_shape,
        hidden_channels,
        K,
        L,
        actnorm_scale,
        flow_permutation,
        flow_coupling,
        LU_decomposed,
        num_classes,
        learn_top,
        y_condition,
    )
    model = model.to(device)
    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)  # noqa
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
    wandb.watch(model)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            # x: [B, 3, 64, 64]; y: [B, num_classes]; z: [B, 48, 8, 8]
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses["total_loss"].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()
        return losses

    trainer = Engine(step)
    # n_saved=None keeps every checkpoint on disk instead of rotating old ones
    checkpoint_handler = ModelCheckpoint(output_dir, "glow", n_saved=None,
                                         require_empty=False)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED,
        checkpoint_handler,
        {"model": model, "optimizer": optimizer},
    )

    monitoring_metrics = ["total_loss"]
    RunningAverage(output_transform=lambda x: x["total_loss"]).attach(
        trainer, "total_loss")

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    if saved_model:
        model.load_state_dict(torch.load(saved_model, map_location="cpu")['model'])
        model.set_actnorm_init()

    @trainer.on(Events.STARTED)
    def init(engine):
        # data-dependent initialization of the actnorm layers: run a few
        # batches through the model once before training starts
        model.train()
        init_batches = []
        init_targets = []
        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)
            init_batches = torch.cat(init_batches).to(device)
            assert init_batches.shape[0] == n_init_batches * batch_size
            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None
            model(init_batches, init_targets)

    # evaluator = evaluation_model(classifier_weight)
    # @trainer.on(Events.EPOCH_COMPLETED)
    # def evaluate(engine):
    #     if dataset == "task1":
    #         model.eval()
    #         with torch.no_grad():
    #             test_conditions = get_test_conditions(dataroot).cuda()
    #             predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
    #             score = evaluator.eval(predict_x, test_conditions)
    #             save_image(predict_x.float(), output_dir + f"/Epoch{engine.state.epoch}_score{score:.3f}.png", normalize=True)
    #             test_conditions = get_new_test_conditions(dataroot).cuda()
    #             predict_x = postprocess(model(y_onehot=test_conditions, temperature=1, reverse=True)).float()
    #             newscore = evaluator.eval(predict_x.float(), test_conditions)
    #             save_image(predict_x.float(), output_dir + f"/Epoch{engine.state.epoch}_newscore{newscore:.3f}.png", normalize=True)
    #             print(f"Iter: {engine.state.iteration} score:{score:.3f} newscore:{newscore:.3f}")
    #             wandb.log({"score": score, "new_score": newscore})

    trainer.run(train_loader, epochs)
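# Hypothetical sketch of the compute_loss helper assumed by step() above; in
# typical Glow implementations the per-sample NLL is reduced to a single
# scalar under the "total_loss" key so it can be backpropagated directly:
def compute_loss(nll):
    return {"total_loss": torch.mean(nll)}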
def test_add_event_handler_raises_with_invalid_event():
    engine = Engine(lambda e, b: 1)

    with pytest.raises(ValueError, match=r"is not a valid event for this Engine"):
        engine.add_event_handler("incorrect", lambda engine: None)
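# Companion sketch (not in the source): the same handler registration succeeds
# once the custom event name has been registered on the engine.
def test_add_event_handler_with_registered_custom_event():
    engine = Engine(lambda e, b: 1)
    engine.register_events("custom_event")
    engine.add_event_handler("custom_event", lambda engine: None)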
def train():
    ################################ Model Config ###################################
    if args.lbl_method == "BIO":
        num_labels_emo = 4  # O POS NEG NORM
        num_labels_ent = 3  # O B I
    else:
        num_labels_emo = 4  # O POS NEG NORM
        num_labels_ent = 5  # O B I E S

    if not args.freeze_step > 0:
        if args.net == "3":
            model = NetY3.from_pretrained(args.bert_model, cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
        elif args.net == "2":
            model = NetY2.from_pretrained(args.bert_model, cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
        elif args.net == "1":
            model = NetY1.from_pretrained(args.bert_model, cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
        elif args.net == "4":
            model = NetY4.from_pretrained(args.bert_model, cache_dir="",
                                          num_labels_ent=num_labels_ent,
                                          num_labels_emo=num_labels_emo,
                                          dp=args.dp)
    else:
        print("in net 5.................")
        if args.net == "5":
            model = NetY5_fz.from_pretrained(args.bert_model, cache_dir="",
                                             num_labels_ent=num_labels_ent,
                                             num_labels_emo=num_labels_emo,
                                             dp=args.dp)
        else:
            model = NetY3_fz.from_pretrained(args.bert_model, cache_dir="",
                                             num_labels_ent=num_labels_ent,
                                             num_labels_emo=num_labels_emo,
                                             dp=args.dp)
        ########################### Freeze First ###################################
        model.freeze()
        print("model frozen")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model = torch.nn.DataParallel(model)

    ################################# hyper parameters ###########################
    # earlier single-task runs: alpha = 0.5 (0.44), 0.6 (0.42), 0.7, 0.8, 1.2
    if not args.multi:
        alpha = args.alpha
    else:
        alphas = [0.1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
        print("alphas: ", " ".join(map(str, alphas)))

    # ------------------------------ load model from file -------------------------
    model_file = os.path.join(args.checkpoint_model_dir, args.ckp)
    if os.path.exists(model_file):
        model.load_state_dict(torch.load(model_file))
        print("load checkpoint: {} successfully!".format(model_file))
    # ------------------------------------------------------------------------------

    trn_dataloader, val_dataloader, trn_size = get_data_loader()

    ############################## Optimizer ###################################
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in param_optimizer
            if (not any(nd in n for nd in no_decay)) and p.requires_grad
        ],
        'weight_decay': args.wd
    }, {
        'params': [
            p for n, p in param_optimizer
            if any(nd in n for nd in no_decay) and p.requires_grad
        ],
        'weight_decay': 0.0
    }]
    # num_train_optimization_steps = int(trn_size / args.batch_size / args.gradient_accumulation_steps) * args.epochs
    num_train_optimization_steps = int(trn_size / args.batch_size) * args.epochs + 5
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.lr,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)
    # optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()), args.lr, weight_decay=5e-3)

    ############################## Criterion ###################################
    if not args.focal:
        if not args.wc:
            criterion = torch.nn.CrossEntropyLoss()
        else:
            # class weights for entity labels: O B I E S
            wc_ent = torch.tensor([0.5, 2, 1, 2, 2]).to(device)
            # class weights for emotion labels: O POS NEG NORM
            wc_emo = torch.tensor([0.5, 0.8, 1, 1]).to(device)
            criterion_ent = torch.nn.CrossEntropyLoss(weight=wc_ent)
            criterion_emo = torch.nn.CrossEntropyLoss(weight=wc_emo)
            print("weight classes over!")
    else:
        criterion = FocalLoss(args.gamma)

    if args.lf:
        print("lr finding.........")
        import math
        init_value = 1e-8
        final_value = 10
        beta = 0.98
        num = len(trn_dataloader) - 1
        mult = (final_value / init_value) ** (1 / num)
        lr = init_value
        optimizer.param_groups[0]['lr'] = lr
        optimizer.param_groups[1]['lr'] = lr

        def lr_find(engine, batch):
            batch_num = engine.state.iteration
            if engine.state.metrics.get("avg_loss") is None:
                engine.state.metrics["avg_loss"] = 0.
                engine.state.metrics["best_loss"] = 0.
                engine.state.metrics["lr"] = lr
                engine.state.metrics["losses"] = []
                engine.state.metrics["log_lrs"] = []
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch
            optimizer.zero_grad()
            act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model(
                input_ids, myinput_ids, segment_ids, input_mask, label_ent_ids,
                label_emo_ids)
            # only keep active parts of the loss
            if not args.wc:
                loss_ent = criterion(act_logits_ent, act_y_ent)
                loss_emo = criterion(act_logits_emo, act_y_emo)
            else:
                loss_ent = criterion_ent(act_logits_ent, act_y_ent)
                loss_emo = criterion_emo(act_logits_emo, act_y_emo)
            if not args.multi:
                loss = alpha * loss_ent + loss_emo
            else:
                loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo
            # compute the smoothed loss
            avg_loss = beta * engine.state.metrics.get("avg_loss") + (1 - beta) * loss.item()
            smoothed_loss = avg_loss / (1 - beta ** batch_num)
            engine.state.metrics["avg_loss"] = avg_loss
            # stop if the loss is exploding (or has become NaN)
            if batch_num > 1 and (smoothed_loss > 4 * engine.state.metrics.get("best_loss")
                                  or torch.isnan(torch.tensor(smoothed_loss))):
                engine.terminate_epoch()
            # record the best loss
            if smoothed_loss < engine.state.metrics.get("best_loss") or batch_num == 1:
                engine.state.metrics["best_loss"] = smoothed_loss
            # store the values
            engine.state.metrics["losses"].append(smoothed_loss)
            engine.state.metrics["log_lrs"].append(
                math.log10(engine.state.metrics["lr"]))
            # do the SGD step
            loss.backward()
            optimizer.step()
            # update the lr for the next step
            engine.state.metrics["lr"] *= mult
            optimizer.param_groups[0]['lr'] = engine.state.metrics["lr"]
            optimizer.param_groups[1]['lr'] = engine.state.metrics["lr"]

        trn_lr_finder = Engine(lr_find)

        @trn_lr_finder.on(Events.EPOCH_COMPLETED)
        def draw_lr(engine):
            import matplotlib.pyplot as plt
            plt.plot(engine.state.metrics.get("log_lrs")[10:-5],
                     engine.state.metrics.get("losses")[10:-5])
            plt.savefig("lr_finder.jpg")
            print("lr find end!")

        pbar = ProgressBar(persist=True)
        pbar.attach(trn_lr_finder)
        trn_lr_finder.run(trn_dataloader, max_epochs=1)
    else:
        # iterations per epoch, needed when registering the periodic evaluation
        # event below; computing it lazily inside step() (as the original did)
        # would leave it None at registration time, so compute it up front
        iterations = len(trn_dataloader)

        def step(engine, batch):
            if args.freeze_step > 0:
                if args.net != "5":
                    if engine.state.epoch - 1 == args.freeze_step:
                        freeze_paras = model.module.unfreeze()
                        optimizer.add_param_group({'params': freeze_paras})
                        args.freeze_step = -1  # run only once
                else:
                    if engine.state.epoch == args.freeze_step:
                        print("freeze net 5...")
                        for param in model.module.bert.parameters():
                            param.requires_grad = False
                        for param in model.module.classifier_ent.parameters():
                            param.requires_grad = False
                        print("freeze net 5 over")
                        args.freeze_step = -1  # run only once
            model.train()
            batch = tuple(t.to(device) for t in batch)
            input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch
            optimizer.zero_grad()
            act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model(
                input_ids, myinput_ids, segment_ids, input_mask, label_ent_ids,
                label_emo_ids)
            # only keep active parts of the loss; use the weighted criteria
            # when args.wc is set (the original recomputed unweighted losses
            # here, silently discarding the class weights)
            if args.net == "5" and args.freeze_step < 0 and not args.multi:
                # net 5 with the entity head frozen: train on emotion loss only
                loss_ent = torch.tensor([0.])
                loss_emo = criterion(act_logits_emo, act_y_emo)
                loss = loss_emo
            else:
                if not args.wc:
                    loss_ent = criterion(act_logits_ent, act_y_ent)
                    loss_emo = criterion(act_logits_emo, act_y_emo)
                else:
                    loss_ent = criterion_ent(act_logits_ent, act_y_ent)
                    loss_emo = criterion_emo(act_logits_emo, act_y_emo)
                if not args.multi:
                    loss = alpha * loss_ent + loss_emo
                else:
                    loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo
            # optional smoothing against the previous step's loss
            if args.smooth_beta > 0 and engine.state.metrics.get("smooth_loss"):
                loss = loss * (1 - args.smooth_beta) \
                    + args.smooth_beta * engine.state.metrics["smooth_loss"]
            if engine.state.metrics.get("total_loss") is None:
                engine.state.metrics["total_loss"] = loss.item()
                engine.state.metrics["ent_loss"] = loss_ent.item()
                engine.state.metrics["emo_loss"] = loss_emo.item()
            else:
                engine.state.metrics["total_loss"] += loss.item()
                engine.state.metrics["ent_loss"] += loss_ent.item()
                engine.state.metrics["emo_loss"] += loss_emo.item()
            engine.state.metrics["smooth_loss"] = loss
            engine.state.metrics["batchloss"] = loss.item()
            engine.state.metrics["batchloss_ent"] = loss_ent.item()
            engine.state.metrics["batchloss_emo"] = loss_emo.item()
            loss.backward()
            optimizer.step()
            return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids

        def infer(engine, batch):
            model.eval()
            batch = tuple(t.to(device) for t in batch)
            input_ids, myinput_ids, input_mask, segment_ids, label_ent_ids, label_emo_ids = batch
            with torch.no_grad():
                act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids = model(
                    input_ids, myinput_ids, segment_ids, input_mask,
                    label_ent_ids, label_emo_ids)
                # only keep active parts of the loss
                if not args.wc:
                    loss_ent = criterion(act_logits_ent, act_y_ent)
                    loss_emo = criterion(act_logits_emo, act_y_emo)
                else:
                    loss_ent = criterion_ent(act_logits_ent, act_y_ent)
                    loss_emo = criterion_emo(act_logits_emo, act_y_emo)
                if not args.multi:
                    loss = alpha * loss_ent + loss_emo
                else:
                    loss = loss_ent + alphas[engine.state.epoch - 1] * loss_emo
                if args.smooth_beta > 0 and engine.state.metrics.get("smooth_loss"):
                    loss = loss * (1 - args.smooth_beta) \
                        + args.smooth_beta * engine.state.metrics["smooth_loss"]
                if engine.state.metrics.get("total_loss") is None:
                    engine.state.metrics["total_loss"] = loss.item()
                    engine.state.metrics["ent_loss"] = loss_ent.item()
                    engine.state.metrics["emo_loss"] = loss_emo.item()
                else:
                    engine.state.metrics["total_loss"] += loss.item()
                    engine.state.metrics["ent_loss"] += loss_ent.item()
                    engine.state.metrics["emo_loss"] += loss_emo.item()
                engine.state.metrics["smooth_loss"] = loss
                engine.state.metrics["batchloss"] = loss.item()
                engine.state.metrics["batchloss_ent"] = loss_ent.item()
                engine.state.metrics["batchloss_emo"] = loss_emo.item()
            return loss.item(), act_logits_ent, act_y_ent, act_logits_emo, act_y_emo, act_myinput_ids

        trainer = Engine(step)
        trn_evaluator = Engine(infer)  # currently unused
        val_evaluator = Engine(infer)
        val_evaluator_iteration = Engine(infer)

        ############################## Custom Periodic Events ###################################
        # cpe1 = CustomPeriodicEvent(n_epochs=1); cpe1.attach(trainer)
        # cpe2 = CustomPeriodicEvent(n_epochs=2); cpe2.attach(trainer)
        # cpe3 = CustomPeriodicEvent(n_epochs=3); cpe3.attach(trainer)
        # cpe5 = CustomPeriodicEvent(n_epochs=5); cpe5.attach(trainer)
        if args.eval_step > 0:
            cpe = CustomPeriodicEvent(n_iterations=args.eval_step)
            cpe.attach(trainer)
        if args.eval_radio > 0:
            eval_every = int(iterations * args.eval_radio) + 1
            cpe2 = CustomPeriodicEvent(n_iterations=eval_every)
            cpe2.attach(trainer)

        ############################## My F1 ###################################
        F1 = FScore(output_transform=lambda x: [x[1], x[2], x[3], x[4], x[-1]],
                    lbl_method=args.lbl_method)
        F1.attach(val_evaluator, "F1")
        F1.attach(val_evaluator_iteration, "F1")

        ##################################### progress bar #########################
        RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'batch_loss')
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["batch_loss"])

        ##################################### Evaluate #########################
        @trainer.on(Events.EPOCH_COMPLETED)
        def compute_val_metric(engine):  # trainer engine
            engine.state.metrics["total_loss"] /= engine.state.iteration
            engine.state.metrics["ent_loss"] /= engine.state.iteration
            engine.state.metrics["emo_loss"] /= engine.state.iteration
            pbar.log_message(
                "Training - total_loss: {:.4f} ent_loss: {:.4f} emo_loss: {:.4f}".format(
                    engine.state.metrics["total_loss"],
                    engine.state.metrics["ent_loss"],
                    engine.state.metrics["emo_loss"]))
            val_evaluator.run(val_dataloader)
            metrics = val_evaluator.state.metrics
            ent_loss = metrics["ent_loss"]
            emo_loss = metrics["emo_loss"]
            f1 = metrics['F1']
            pbar.log_message(
                "Validation Results - Epoch: {} Ent_loss: {:.4f}, Emo_loss: {:.4f}, F1: {:.4f}".format(
                    engine.state.epoch, ent_loss, emo_loss, f1))
            pbar.n = pbar.last_print_n = 0

        def compute_val_metric_iteration(engine):  # trainer engine
            val_evaluator_iteration.run(val_dataloader)
            metrics = val_evaluator_iteration.state.metrics
            f1 = metrics['F1']
            pbar.log_message(
                "Validation Results - Iteration: {} F1: {:.4f}".format(
                    engine.state.iteration, f1))
            pbar.n = pbar.last_print_n = 0

        # getattr is safer than eval() for looking up the generated event names
        if args.eval_step > 0:
            trainer.add_event_handler(
                getattr(cpe.Events, f"ITERATIONS_{args.eval_step}_COMPLETED"),
                compute_val_metric_iteration)
        if args.eval_radio > 0:
            trainer.add_event_handler(
                getattr(cpe2.Events, f"ITERATIONS_{eval_every}_COMPLETED"),
                compute_val_metric_iteration)

        @val_evaluator.on(Events.EPOCH_COMPLETED)
        def reduct_step(engine):
            engine.state.metrics["total_loss"] /= engine.state.iteration
            engine.state.metrics["ent_loss"] /= engine.state.iteration
            engine.state.metrics["emo_loss"] /= engine.state.iteration
            pbar.log_message(
                "Validation - total_loss: {:.4f} ent_loss: {:.4f} emo_loss: {:.4f}".format(
                    engine.state.metrics["total_loss"],
                    engine.state.metrics["ent_loss"],
                    engine.state.metrics["emo_loss"]))

        ############################## checkpoint ###################################
        def best_f1(engine):
            return engine.state.metrics["F1"]

        if not args.lite:
            ckp_dir = os.path.join(args.checkpoint_model_dir, "full", args.hyper_cfg)
        else:
            ckp_dir = os.path.join(args.checkpoint_model_dir, "lite", args.hyper_cfg)
        checkpoint_handler = ModelCheckpoint(ckp_dir,
                                             'ckp',
                                             score_function=best_f1,
                                             score_name="F1",
                                             n_saved=5,
                                             require_empty=False,
                                             create_dir=True)
        val_evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                                        handler=checkpoint_handler,
                                        to_save={'model_title': model})
        if args.eval_step > 0:
            val_evaluator_iteration.add_event_handler(
                event_name=Events.EPOCH_COMPLETED,
                handler=checkpoint_handler,
                to_save={"model_iter": model})

        ############################## earlystopping ###################################
        stopping_handler = EarlyStopping(patience=2,
                                         score_function=best_f1,
                                         trainer=trainer)
        val_evaluator.add_event_handler(Events.COMPLETED, stopping_handler)

        #################################### tb logger ##################################
        # read metric values for logging only after they have been computed on
        # the corresponding engine (compute_val_metric / reduct_step)
        if not args.lite:
            tb_logger = TensorboardLogger(
                log_dir=os.path.join(args.log_dir, "full", args.hyper_cfg))
        else:
            tb_logger = TensorboardLogger(
                log_dir=os.path.join(args.log_dir, "lite", args.hyper_cfg))
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["batchloss", "batchloss_ent", "batchloss_emo"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(val_evaluator,
                         log_handler=OutputHandler(
                             tag="validation",
                             metric_names=["batchloss", "batchloss_ent", "batchloss_emo"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(
                             tag="training",
                             metric_names=["total_loss", "ent_loss", "emo_loss"]),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(val_evaluator,
                         log_handler=OutputHandler(
                             tag="validation",
                             metric_names=["total_loss", "ent_loss", "emo_loss", "F1"],
                             another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        if args.eval_step > 0:
            tb_logger.attach(val_evaluator_iteration,
                             log_handler=OutputHandler(
                                 tag="validation_iteration",
                                 metric_names=["F1"],
                                 another_engine=trainer),
                             event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer, "lr"),
                         event_name=Events.EPOCH_COMPLETED)
        # optional weight/gradient logging, disabled by default:
        # tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
        #                  event_name=Events.ITERATION_COMPLETED)
        # tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
        #                  event_name=Events.EPOCH_COMPLETED)
        # tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
        #                  event_name=Events.ITERATION_COMPLETED)
        # tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
        #                  event_name=Events.EPOCH_COMPLETED)

        trainer.run(trn_dataloader, max_epochs=args.epochs)
        tb_logger.close()