def main(dataset, dataroot, z_dim, g_filters, d_filters, batch_size, epochs,
         learning_rate, beta_1, saved_G, saved_D, seed, n_workers, device,
         alpha, output_dir):

    # seed
    check_manual_seed(seed)

    # SummaryWriter
    writer = SummaryWriter(output_dir + "board")

    # logger
    logger = set_logger("GAN_model", output_dir, 0)
    logger.info("start training")

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset,
                             batch_size=batch_size,
                             shuffle=True,
                             num_workers=n_workers,
                             drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))
    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):
        # Unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()
        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()
        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()

        # Update the generator: we want a step that makes the discriminator
        # more likely to output "real" for the generated images.
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()
        errG.backward()

        # gradient update
        optimizerG.step()

        return {
            'errD': errD.item(),
            'errG': errG.item(),
            'D_x': D_x,
            'D_G_z1': D_G_z1,
            'D_G_z2': D_G_z2
        }

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir,
                                         CKPT_PREFIX,
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2']
    RunningAverage(alpha=alpha, output_transform=lambda x: x['errD']).attach(trainer, 'errD')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['errG']).attach(trainer, 'errG')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_x']).attach(trainer, 'D_x')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z1']).attach(trainer, 'D_G_z1')
    RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z2']).attach(trainer, 'D_G_z2')

    # log training metrics and write them to tensorboard
    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_results(engine):
        i = (engine.state.iteration - 1) % len(loader) + 1
        if i % PRINT_FREQ == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] errD:{:.3f} errG:{:.3f} "
                "D_x:{:.3f} D_G_z1:{:.3f} D_G_z2:{:.3f}".format(
                    engine.state.epoch, i, len(loader),
                    engine.state.metrics["errD"],
                    engine.state.metrics["errG"],
                    engine.state.metrics["D_x"],
                    engine.state.metrics["D_G_z1"],
                    engine.state.metrics["D_G_z2"]))
        writer.add_scalars("train", {"errD": engine.state.metrics["errD"]}, engine.state.iteration)
        writer.add_scalars("train", {"errG": engine.state.metrics["errG"]}, engine.state.iteration)
        writer.add_scalars("train", {"D_x": engine.state.metrics["D_x"]}, engine.state.iteration)
        writer.add_scalars("train", {"D_G_z1": engine.state.metrics["D_G_z1"]}, engine.state.iteration)
        writer.add_scalars("train", {"D_G_z2": engine.state.metrics["D_G_z2"]}, engine.state.iteration)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
                                  'netG': netG,
                                  'netD': netD
                              })

    # automatically adding handlers via a special `attach` method of `Timer` handler
    '''
    engine (Engine): Engine that this timer will be attached to.
    start (Events): Event which should start (reset) the timer.
    pause (Events): Event which should pause the timer.
    resume (Events, optional): Event which should resume the timer.
    step (Events, optional): Event which should call the `step` method of the counter.
    '''
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # adding handlers using `trainer.on` decorator API
    # timer.step_count: number of timed steps; timer.value(): average running time per batch
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info("Epoch {} done. Time per batch:{:.3f}[s] Speed:{:.1f}[samples/s]"
                    .format(engine.state.epoch,
                            timer.value() * timer.step_count,
                            loader.batch_size / timer.value()))
        logger.info("timer.step_count:{:.3f}, timer.value:{:.3f}".format(
            timer.step_count, timer.value()))
        timer.reset()

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
            create_plots(engine)
            checkpoint_handler(engine, {
                'netG_exception': netG,
                'netD_exception': netD
            })
        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, epochs)
    writer.close()
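# A hypothetical invocation of the GAN main() above, added as a usage sketch.
# Every argument value here is an assumption chosen to mirror common DCGAN
# defaults; nothing below is taken from the original script.
if __name__ == "__main__":
    main(dataset="cifar10",        # any name understood by check_dataset
         dataroot="./data",
         z_dim=100,
         g_filters=64,
         d_filters=64,
         batch_size=64,
         epochs=25,
         learning_rate=2e-4,
         beta_1=0.5,
         saved_G=None,
         saved_D=None,
         seed=42,
         n_workers=2,
         device="cuda" if torch.cuda.is_available() else "cpu",
         alpha=0.98,               # smoothing factor for the RunningAverage metrics
         output_dir="./output/")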
def attach_running_average(engine, metric_name):
    RunningAverage(output_transform=lambda x: x[metric_name]).attach(
        engine,
        metric_name,
    )
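# A minimal, self-contained sketch of how attach_running_average() is meant to
# be used: the process function returns a dict, and the helper exposes one key
# as an exponential-moving-average metric. All names below are illustrative.
from ignite.engine import Engine
from ignite.metrics import RunningAverage


def _demo_step(engine, batch):
    return {"loss": float(batch)}


demo_engine = Engine(_demo_step)
attach_running_average(demo_engine, "loss")
demo_engine.run([0.5, 1.5, 2.5], max_epochs=1)
print(demo_engine.state.metrics["loss"])  # EMA of the returned loss values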
def do_train_with_center(cfg, model, center_criterion, train_loader, val_loader,
                         optimizer, optimizer_center, scheduler, loss_fn,
                         num_query, start_epoch):
    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer_with_center(
        model, center_criterion, optimizer, optimizer_center, loss_fn,
        cfg.SOLVER.CENTER_LOSS_WEIGHT, device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period,
                                   n_saved=10, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpointer, {
            'model': model,
            'optimizer': optimizer,
            'optimizer_center': optimizer_center
        })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1
        if ITER % log_period == 0:
            logger.info(
                "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                .format(engine.state.epoch, ITER, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc'],
                        scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.1f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("mAP: {:.1%}".format(mAP))
            for r in [1, 5, 10]:
                logger.info("CMC curve, Rank-{:<3}:{:.1%}".format(r, cmc[r - 1]))

    trainer.run(train_loader, max_epochs=epochs)
# process_function used for evaluation
def evaluate_process(engine, batch):
    model.float().to(device).eval()
    with torch.no_grad():
        _, font = batch
        font = font.float().to(device)
        font_hat, latent_vectors = model(font)
    return font, font_hat, latent_vectors


trainer = Engine(train_process)
evaluator = Engine(evaluate_process)

RunningAverage(output_transform=lambda x: x).attach(trainer, 'mse')
Loss(F.mse_loss, output_transform=lambda x: [x[1], x[0]]).attach(evaluator, 'mse')

desc = "ITERATION - loss: {:.5f}"
pbar = tqdm(initial=0, leave=False, total=len(train_loader), desc=desc.format(0))
train_history = []
# valid_history = []


@trainer.on(Events.ITERATION_COMPLETED)
def log_training_loss(engine):
    # update the tqdm bar with the latest training loss
    pbar.desc = desc.format(engine.state.output)
    pbar.update(1)
def __call__(self, model, train_dataset, val_dataset=None, **_):
    """Train a PyTorch model.

    Args:
        model (torch.nn.Module): PyTorch model to train.
        train_dataset (torch.utils.data.Dataset): Dataset used to train.
        val_dataset (torch.utils.data.Dataset, optional): Dataset used to validate.

    Returns:
        trained_model (torch.nn.Module): Trained PyTorch model.
    """
    assert train_dataset is not None
    train_params = self.train_params
    mlflow_logging = self.mlflow_logging

    if mlflow_logging:
        try:
            import mlflow  # NOQA
        except ImportError:
            log.warning("Failed to import mlflow. MLflow logging is disabled.")
            mlflow_logging = False

    loss_fn = train_params.get("loss_fn")
    assert loss_fn
    epochs = train_params.get("epochs")
    seed = train_params.get("seed")
    optimizer = train_params.get("optimizer")
    assert optimizer
    optimizer_params = train_params.get("optimizer_params", dict())

    train_dataset_size_limit = train_params.get("train_dataset_size_limit")
    if train_dataset_size_limit:
        train_dataset = PartialDataset(train_dataset, train_dataset_size_limit)
        log.info("train dataset size is set to {}".format(len(train_dataset)))

    val_dataset_size_limit = train_params.get("val_dataset_size_limit")
    if val_dataset_size_limit and (val_dataset is not None):
        val_dataset = PartialDataset(val_dataset, val_dataset_size_limit)
        log.info("val dataset size is set to {}".format(len(val_dataset)))

    train_data_loader_params = train_params.get("train_data_loader_params", dict())
    val_data_loader_params = train_params.get("val_data_loader_params", dict())
    evaluation_metrics = train_params.get("evaluation_metrics")
    evaluate_train_data = train_params.get("evaluate_train_data")
    evaluate_val_data = train_params.get("evaluate_val_data")
    progress_update = train_params.get("progress_update")
    scheduler = train_params.get("scheduler")
    scheduler_params = train_params.get("scheduler_params", dict())
    model_checkpoint = train_params.get("model_checkpoint")
    model_checkpoint_params = train_params.get("model_checkpoint_params")
    early_stopping_params = train_params.get("early_stopping_params")
    time_limit = train_params.get("time_limit")
    cudnn_deterministic = train_params.get("cudnn_deterministic")
    cudnn_benchmark = train_params.get("cudnn_benchmark")

    if seed:
        torch.manual_seed(seed)
        np.random.seed(seed)
    if cudnn_deterministic:
        torch.backends.cudnn.deterministic = cudnn_deterministic
    if cudnn_benchmark:
        torch.backends.cudnn.benchmark = cudnn_benchmark

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    optimizer_ = optimizer(model.parameters(), **optimizer_params)
    trainer = create_supervised_trainer(model, optimizer_, loss_fn=loss_fn, device=device)

    train_data_loader_params.setdefault("shuffle", True)
    train_data_loader_params.setdefault("drop_last", True)
    train_data_loader_params["batch_size"] = _clip_batch_size(
        train_data_loader_params.get("batch_size", 1), train_dataset, "train")
    train_loader = DataLoader(train_dataset, **train_data_loader_params)

    RunningAverage(output_transform=lambda x: x, alpha=0.98).attach(trainer, "ema_loss")
    # an alpha this close to zero makes the "running average" track the raw batch loss
    RunningAverage(output_transform=lambda x: x, alpha=2**(-1022)).attach(trainer, "batch_loss")

    if scheduler:

        class ParamSchedulerSavingAsMetric(ParamSchedulerSavingAsMetricMixIn, scheduler):
            pass

        cycle_epochs = scheduler_params.pop("cycle_epochs", 1)
        scheduler_params.setdefault("cycle_size", int(cycle_epochs * len(train_loader)))
        scheduler_params.setdefault("param_name", "lr")
        scheduler_ = ParamSchedulerSavingAsMetric(optimizer_, **scheduler_params)
        trainer.add_event_handler(Events.ITERATION_STARTED, scheduler_)

    if evaluate_train_data:
        evaluator_train = create_supervised_evaluator(
            model, metrics=evaluation_metrics, device=device)

    if evaluate_val_data:
        val_data_loader_params["batch_size"] = _clip_batch_size(
            val_data_loader_params.get("batch_size", 1), val_dataset, "val")
        val_loader = DataLoader(val_dataset, **val_data_loader_params)
        evaluator_val = create_supervised_evaluator(
            model, metrics=evaluation_metrics, device=device)

    if model_checkpoint_params:
        assert isinstance(model_checkpoint_params, dict)
        minimize = model_checkpoint_params.pop("minimize", True)
        save_interval = model_checkpoint_params.get("save_interval", None)
        if not save_interval:
            model_checkpoint_params.setdefault(
                "score_function", get_score_function("ema_loss", minimize=minimize))
            model_checkpoint_params.setdefault("score_name", "ema_loss")
        mc = model_checkpoint(**model_checkpoint_params)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, mc, {"model": model})

    if early_stopping_params:
        assert isinstance(early_stopping_params, dict)
        metric = early_stopping_params.pop("metric", None)
        assert (metric is None) or (metric in evaluation_metrics)
        minimize = early_stopping_params.pop("minimize", False)
        if metric:
            assert ("score_function" not in early_stopping_params), \
                "Remove either 'metric' or 'score_function' from early_stopping_params: {}".format(
                    early_stopping_params)
            early_stopping_params["score_function"] = get_score_function(
                metric, minimize=minimize)

        es = EarlyStopping(trainer=trainer, **early_stopping_params)
        if evaluate_val_data:
            evaluator_val.add_event_handler(Events.COMPLETED, es)
        elif evaluate_train_data:
            evaluator_train.add_event_handler(Events.COMPLETED, es)
        elif early_stopping_params:
            log.warning("Early Stopping is disabled because neither "
                        "evaluate_val_data nor evaluate_train_data is set True.")

    if time_limit:
        assert isinstance(time_limit, (int, float))
        tl = TimeLimit(limit_sec=time_limit)
        trainer.add_event_handler(Events.ITERATION_COMPLETED, tl)

    pbar = None
    if progress_update:
        if not isinstance(progress_update, dict):
            progress_update = dict()
        progress_update.setdefault("persist", True)
        progress_update.setdefault("desc", "")
        pbar = ProgressBar(**progress_update)
        pbar.attach(trainer, ["ema_loss"])
    else:

        def log_train_metrics(engine):
            log.info("[Epoch: {} | {}]".format(engine.state.epoch, engine.state.metrics))

        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_metrics)

    if evaluate_train_data:

        def log_evaluation_train_data(engine):
            evaluator_train.run(train_loader)
            train_report = _get_report_str(engine, evaluator_train, "Train Data")
            if pbar:
                pbar.log_message(train_report)
            else:
                log.info(train_report)

        eval_train_event = (Events[evaluate_train_data]
                            if isinstance(evaluate_train_data, str)
                            else Events.EPOCH_COMPLETED)
        trainer.add_event_handler(eval_train_event, log_evaluation_train_data)

    if evaluate_val_data:

        def log_evaluation_val_data(engine):
            evaluator_val.run(val_loader)
            val_report = _get_report_str(engine, evaluator_val, "Val Data")
            if pbar:
                pbar.log_message(val_report)
            else:
                log.info(val_report)

        eval_val_event = (Events[evaluate_val_data]
                          if isinstance(evaluate_val_data, str)
                          else Events.EPOCH_COMPLETED)
        trainer.add_event_handler(eval_val_event, log_evaluation_val_data)

    if mlflow_logging:
        mlflow_logger = MLflowLogger()

        logging_params = {
            "train_n_samples": len(train_dataset),
            "train_n_batches": len(train_loader),
            "optimizer": _name(optimizer),
            "loss_fn": _name(loss_fn),
            "pytorch_version": torch.__version__,
            "ignite_version": ignite.__version__,
        }
        logging_params.update(_loggable_dict(optimizer_params, "optimizer"))
        logging_params.update(_loggable_dict(train_data_loader_params, "train"))
        if scheduler:
            logging_params.update({"scheduler": _name(scheduler)})
            logging_params.update(_loggable_dict(scheduler_params, "scheduler"))
        if evaluate_val_data:
            logging_params.update({
                "val_n_samples": len(val_dataset),
                "val_n_batches": len(val_loader),
            })
            logging_params.update(_loggable_dict(val_data_loader_params, "val"))
        mlflow_logger.log_params(logging_params)

        batch_metric_names = ["batch_loss", "ema_loss"]
        if scheduler:
            batch_metric_names.append(scheduler_params.get("param_name"))

        mlflow_logger.attach(
            trainer,
            log_handler=OutputHandler(
                tag="step",
                metric_names=batch_metric_names,
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.ITERATION_COMPLETED,
        )
        if evaluate_train_data:
            mlflow_logger.attach(
                evaluator_train,
                log_handler=OutputHandler(
                    tag="train",
                    metric_names=list(evaluation_metrics.keys()),
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.COMPLETED,
            )
        if evaluate_val_data:
            mlflow_logger.attach(
                evaluator_val,
                log_handler=OutputHandler(
                    tag="val",
                    metric_names=list(evaluation_metrics.keys()),
                    global_step_transform=global_step_from_engine(trainer),
                ),
                event_name=Events.COMPLETED,
            )

    trainer.run(train_loader, max_epochs=epochs)

    try:
        if pbar and pbar.pbar:
            pbar.pbar.close()
    except Exception as e:
        log.error(e, exc_info=True)

    model = load_latest_model(model_checkpoint_params)(model)

    return model
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size,
         epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale,
         flow_permutation, flow_coupling, LU_decomposed, learn_top,
         y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers,
         cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer,
         warmup, fresh, gpuid):

    device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:' + str(gpuid)

    check_manual_seed(seed)

    ds = check_dataset(dataset, dataroot, augment, download)
    image_shape, num_classes, train_dataset, test_dataset = ds

    # Note: unsupported for now
    multi_class = False

    train_loader = data.DataLoader(train_dataset,
                                   batch_size=batch_size,
                                   shuffle=True,
                                   num_workers=n_workers,
                                   drop_last=True)
    test_loader = data.DataLoader(test_dataset,
                                  batch_size=eval_batch_size,
                                  shuffle=False,
                                  num_workers=n_workers,
                                  drop_last=False)

    model = Glow(image_shape, hidden_channels, K, L, actnorm_scale,
                 flow_permutation, flow_coupling, LU_decomposed, num_classes,
                 learn_top, y_condition)
    model = model.to(device)

    optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5)

    # LambdaLR multiplies the base lr by the returned factor, so the lambda
    # returns a warmup factor in (0, 1] (returning lr * factor would scale the
    # learning rate by lr twice); (epoch + 1) keeps the first epoch's lr non-zero.
    lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x = x.to(device)

        if y_condition:
            y = y.to(device)
            z, nll, y_logits = model(x, y)
            losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class)
        else:
            z, nll, y_logits = model(x, None)
            losses = compute_loss(nll)

        losses['total_loss'].backward()

        if max_grad_clip > 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip)
        if max_grad_norm > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

        optimizer.step()

        return losses

    def eval_step(engine, batch):
        model.eval()

        x, y = batch
        x = x.to(device)

        with torch.no_grad():
            if y_condition:
                y = y.to(device)
                z, nll, y_logits = model(x, y)
                losses = compute_loss_y(nll, y_logits, y_weight, y,
                                        multi_class, reduction='none')
            else:
                z, nll, y_logits = model(x, None)
                losses = compute_loss(nll, reduction='none')

        return losses

    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1,
                                         n_saved=2, require_empty=False)

    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {
        'model': model,
        'optimizer': optimizer
    })

    monitoring_metrics = ['total_loss']
    RunningAverage(output_transform=lambda x: x['total_loss']).attach(trainer, 'total_loss')

    evaluator = Engine(eval_step)

    # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
    Loss(lambda x, y: torch.mean(x),
         output_transform=lambda x: (x['total_loss'],
                                     torch.empty(x['total_loss'].shape[0]))
         ).attach(evaluator, 'total_loss')

    if y_condition:
        monitoring_metrics.extend(['nll'])
        RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')

        # Note: replace by https://github.com/pytorch/ignite/pull/524 when released
        Loss(lambda x, y: torch.mean(x),
             output_transform=lambda x: (x['nll'],
                                         torch.empty(x['nll'].shape[0]))
             ).attach(evaluator, 'nll')

    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    # load pre-trained model if given
    if saved_model:
        model.load_state_dict(torch.load(saved_model))
        model.set_actnorm_init()

        if saved_optimizer:
            optimizer.load_state_dict(torch.load(saved_optimizer))

        file_name, ext = os.path.splitext(saved_model)
        resume_epoch = int(file_name.split('_')[-1])

        @trainer.on(Events.STARTED)
        def resume_training(engine):
            engine.state.epoch = resume_epoch
            engine.state.iteration = resume_epoch * len(engine.state.dataloader)

    @trainer.on(Events.STARTED)
    def init(engine):
        model.train()

        init_batches = []
        init_targets = []

        with torch.no_grad():
            for batch, target in islice(train_loader, None, n_init_batches):
                init_batches.append(batch)
                init_targets.append(target)

            init_batches = torch.cat(init_batches).to(device)

            assert init_batches.shape[0] == n_init_batches * batch_size

            if y_condition:
                init_targets = torch.cat(init_targets).to(device)
            else:
                init_targets = None

            model(init_batches, init_targets)

    @trainer.on(Events.EPOCH_COMPLETED)
    def evaluate(engine):
        evaluator.run(test_loader)
        scheduler.step()
        metrics = evaluator.state.metrics
        losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()])
        print(f'Validation Results - Epoch: {engine.state.epoch} {losses}')

    timer = Timer(average=True)
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]')
        timer.reset()

    trainer.run(train_loader, epochs)
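# A throwaway check of the LambdaLR semantics relied on above: the scheduler
# multiplies the optimizer's *base* lr by the factor returned from the lambda,
# which is why the warmup lambda returns a factor in (0, 1] rather than
# lr * factor. All values here are arbitrary.
import torch

_p = torch.nn.Parameter(torch.zeros(1))
_opt = torch.optim.Adamax([_p], lr=5e-4)
_sched = torch.optim.lr_scheduler.LambdaLR(_opt, lr_lambda=lambda e: min(1.0, (e + 1) / 10))

for _epoch in range(3):
    print(_epoch, _opt.param_groups[0]["lr"])  # 5e-4 * (epoch + 1) / 10
    _sched.step()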
"Episode %d: reward=%s, steps=%s, speed=%.3f frames/s, elapsed=%s" % (trainer.state.episode, trainer.state.episode_reward, trainer.state.episode_steps, trainer.state.metrics.get('avg_fps', 0), timedelta(seconds=trainer.state.metrics.get('time_passed', 0)))) @engine.on(ptan_ignite.EpisodeEvents.BOUND_REWARD_REACHED) def game_solved(trainer: Engine): print("Game solved in %s, after %d episodes and %d iterations!" % (timedelta(seconds=trainer.state.metrics['time_passed']), trainer.state.episode, trainer.state.iteration)) trainer.should_terminate = True logdir = f"runs/{datetime.now().isoformat(timespec='minutes')}-{params.run_name}-{NAME}" tb = tb_logger.TensorboardLogger(log_dir=logdir) RunningAverage(output_transform=lambda v: v['loss']).attach( engine, "avg_loss") episode_handler = tb_logger.OutputHandler( tag="episodes", metric_names=['reward', 'steps', 'avg_reward']) tb.attach(engine, log_handler=episode_handler, event_name=ptan_ignite.EpisodeEvents.EPISODE_COMPLETED) # write to tensorboard every 100 iterations ptan_ignite.PeriodicEvents().attach(engine) handler = tb_logger.OutputHandler(tag="train", metric_names=['avg_loss', 'avg_fps'], output_transform=lambda a: a) tb.attach(engine, log_handler=handler, event_name=ptan_ignite.PeriodEvents.ITERS_100_COMPLETED)
def test_integration():
    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return (
            loss_value,
            torch.from_numpy(y_pred_batch),
            torch.from_numpy(y_true_batch),
        )

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")

    running_avg_acc = [
        None,
    ]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha +
                (1.0 - alpha) * engine.state.output[0])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert (engine.state.running_avg_acc ==
                engine.state.metrics["running_avg_accuracy"]), "{} vs {}".format(
                    engine.state.running_avg_acc,
                    engine.state.metrics["running_avg_accuracy"])

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert (engine.state.running_avg_output ==
                engine.state.metrics["running_avg_output"]), "{} vs {}".format(
                    engine.state.running_avg_output,
                    engine.state.metrics["running_avg_output"])

    np.random.seed(10)
    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    running_avg_acc = [
        None,
    ]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
def _setup_common_training_handlers(
    trainer: Engine,
    to_save: Optional[Mapping] = None,
    save_every_iters: int = 1000,
    output_path: Optional[str] = None,
    lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None,
    with_gpu_stats: bool = False,
    output_names: Optional[Iterable[str]] = None,
    with_pbars: bool = True,
    with_pbar_on_iters: bool = True,
    log_every_iters: int = 100,
    stop_on_nan: bool = True,
    clear_cuda_cache: bool = True,
    save_handler: Optional[Union[Callable, BaseSaveHandler]] = None,
    **kwargs: Any,
) -> None:
    if output_path is not None and save_handler is not None:
        raise ValueError(
            "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them"
        )

    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()
            )
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None and save_handler is None:
            raise ValueError(
                "If to_save argument is provided then output_path or save_handler arguments should be also defined"
            )
        if output_path is not None:
            save_handler = DiskSaver(dirname=output_path, require_empty=False)

        checkpoint_handler = Checkpoint(
            to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(
            trainer, name="gpu", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)  # type: ignore[arg-type]
        )

    if output_names is not None:

        def output_transform(x: Any, index: int, name: str) -> Any:
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise TypeError(
                    "Unhandled type of update_function's output. "
                    f"It should be either a mapping or a sequence, but {type(x)} was given"
                )

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(
                trainer, n
            )

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)
            )

        ProgressBar(persist=True, bar_format="").attach(
            trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED
        )
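# A minimal sketch wiring _setup_common_training_handlers() to a dummy trainer.
# The model, optimizer, step function, and output path are all placeholders
# invented for this example, not part of the helper itself.
import torch
from ignite.engine import Engine

_model = torch.nn.Linear(2, 1)
_optimizer = torch.optim.SGD(_model.parameters(), lr=0.1)


def _train_step(engine, batch):
    _optimizer.zero_grad()
    loss = _model(batch).mean()
    loss.backward()
    _optimizer.step()
    return {"batch_loss": loss.item()}


_trainer = Engine(_train_step)
_setup_common_training_handlers(
    _trainer,
    to_save={"model": _model, "optimizer": _optimizer},
    save_every_iters=500,
    output_path="/tmp/checkpoints",  # assumed destination directory
    output_names=["batch_loss"],     # exposed as a RunningAverage metric
)
_trainer.run([torch.randn(4, 2) for _ in range(10)], max_epochs=2)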
def train():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_checkpoint", type=str, default=PRETRAINED_MODEL_URL,
                        help="Path to the pretrained model checkpoint")
    parser.add_argument("--dataset_path", type=str, default='../data/sst',
                        help="Path to the dataset directory.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path to dataset cache")
    parser.add_argument("--logdir", type=str, default='./transformer_results',
                        help="Path to logs")
    parser.add_argument("--num_classes", type=int, default=5,
                        help="Number of classes for the target classification task")
    parser.add_argument("--adapters_dim", type=int, default=-1,
                        help="If >0 add adapters to the model with adapters_dim dimension")
    parser.add_argument("--dropout", type=float, default=0.1,
                        help="Dropout for transformer module")
    parser.add_argument("--clf_loss_coef", type=float, default=1,
                        help="If >0 add a classification loss")
    parser.add_argument("--train_batch_size", type=int, default=32,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=32,
                        help="Batch size for validation")
    parser.add_argument("--valid_pct", type=float, default=0.1,
                        help="Percentage of training data to use for validation")
    parser.add_argument("--lr", type=float, default=6.5e-5, help="Learning rate")
    parser.add_argument("--n_warmup", type=int, default=10,
                        help="Number of warmup iterations")
    parser.add_argument("--max_norm", type=float, default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay", type=float, default=0.0,
                        help="Weight decay")
    parser.add_argument("--n_epochs", type=int, default=3,
                        help="Number of training epochs")
    parser.add_argument("--gradient_acc_steps", type=int, default=2,
                        help="Number of update steps to accumulate before a backward pass.")
    parser.add_argument("--init_range", type=float, default=0.02,
                        help="Normal initialization standard deviation")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    args = parser.parse_args()

    # Define pretrained model and optimizer
    model, state_dict, config = load_pretrained_model(args)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=False)
    num_model_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Model has {num_model_params:,} parameters")

    # Define datasets
    datasets = read_sst5(args.dataset_path)

    # Define labels
    labels = list(set(datasets["train"][LABEL_COL].tolist()))
    # The specified number of classes must equal the number of classes in the dataset!
    assert len(labels) == args.num_classes
    label2int = {label: i for i, label in enumerate(labels)}
    int2label = {i: label for label, i in label2int.items()}

    # Get BertTokenizer for this pretrained model
    tokenizer = BertTokenizer.from_pretrained('bert-base-cased', do_lower_case=False)
    clf_token = tokenizer.vocab['[CLS]']  # classifier token
    pad_token = tokenizer.vocab['[PAD]']  # pad token
    processor = TextProcessor(tokenizer, label2int, clf_token, pad_token,
                              max_length=config.num_max_positions)

    train_dl = processor.create_dataloader(datasets["train"], shuffle=True,
                                           batch_size=args.train_batch_size, valid_pct=None)
    valid_dl = processor.create_dataloader(datasets["dev"],
                                           batch_size=args.train_batch_size, valid_pct=None)
    test_dl = processor.create_dataloader(datasets["test"],
                                          batch_size=args.valid_batch_size, valid_pct=None)

    # Training function and trainer
    def update(engine, batch):
        "update function for training"
        model.train()
        inputs, labels = (t.to(args.device) for t in batch)
        inputs = inputs.transpose(0, 1).contiguous()  # to shape [seq length, batch]
        _, loss = model(inputs,
                        clf_tokens_mask=(inputs == clf_token),
                        clf_labels=labels)
        loss = loss / args.gradient_acc_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_acc_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch, labels = (t.to(args.device) for t in batch)
            inputs = batch.transpose(0, 1).contiguous()  # to shape [seq length, batch]
            clf_logits = model(inputs,
                               clf_tokens_mask=(inputs == clf_token),
                               padding_mask=(batch == pad_token))
        return clf_logits, labels

    evaluator = Engine(inference)

    # add metric to evaluator
    Accuracy().attach(evaluator, "accuracy")

    # add evaluator to trainer: eval on valid set after each epoch
    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(valid_dl)
        print(f"validation epoch: {engine.state.epoch} "
              f"acc: {100*evaluator.state.metrics['accuracy']:.3f}%")

    # Learning rate schedule: linearly warm up to lr and then decay to zero
    scheduler = PiecewiseLinear(optimizer, 'lr',
                                [(0, 0.0), (args.n_warmup, args.lr),
                                 (len(train_dl) * args.n_epochs, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Add progressbar with loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

    # Save checkpoints and finetuning config
    checkpoint_handler = ModelCheckpoint(args.logdir, 'checkpoint',
                                         save_interval=1, require_empty=False)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'sst_model': model})

    # Save metadata
    torch.save({
        "config": config,
        "config_ft": args,
        "int2label": int2label
    }, os.path.join(args.logdir, "model_training_args.bin"))

    # Run trainer
    trainer.run(train_dl, max_epochs=args.n_epochs)

    # Evaluate
    evaluator.run(test_dl)
    print(f"test results - acc: {100*evaluator.state.metrics['accuracy']:.3f}")

    # Save fine-tuned model weights
    torch.save(model.state_dict(), os.path.join(args.logdir, "model_weights.pth"))
def _test_distrib_on_metric(device):
    import torch.distributed as dist

    rank = dist.get_rank()
    n_iters = 10
    n_epochs = 3
    batch_size = 10
    n_classes = 10

    data = list(range(n_iters))
    np.random.seed(12)
    all_y_true_batch_values = np.random.randint(
        0, n_classes, size=(dist.get_world_size(), n_epochs * n_iters, batch_size))
    all_y_pred_batch_values = np.random.rand(
        dist.get_world_size(), n_epochs * n_iters, batch_size, n_classes)

    y_true_batch_values = iter(all_y_true_batch_values[rank, ...])
    y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...])

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[0], x[1]], device=device),
        alpha=alpha,
        epoch_bound=False,
    )
    acc_metric.attach(trainer, "running_avg_accuracy")

    running_avg_acc = [
        None,
    ]
    true_acc_metric = Accuracy(device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        i = engine.state.iteration - 1
        true_acc_metric.reset()
        for j in range(dist.get_world_size()):
            output = (
                torch.from_numpy(all_y_pred_batch_values[j, i, :, :]),
                torch.from_numpy(all_y_true_batch_values[j, i, :]),
            )
            true_acc_metric.update(output)

        batch_acc = true_acc_metric._num_correct * 1.0 / true_acc_metric._num_examples

        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert (engine.state.running_avg_acc ==
                engine.state.metrics["running_avg_accuracy"]), "{} vs {}".format(
                    engine.state.running_avg_acc,
                    engine.state.metrics["running_avg_accuracy"])

    trainer.run(data, max_epochs=3)
def run(args):
    train_loader, val_loader = get_data_loaders(args.dataset_dir,
                                                args.batch_size,
                                                args.val_batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = KITTI.num_classes()
    model = LiLaNet(num_classes)

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    criterion = nn.CrossEntropyLoss(weight=KITTI.class_weights()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        distance, reflectivity, target = batch
        return (convert_tensor(distance, device=device, non_blocking=non_blocking),
                convert_tensor(reflectivity, device=device, non_blocking=non_blocking),
                convert_tensor(target, device=device, non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        distance, reflectivity, target = _prepare_batch(batch)
        pred = model(distance, reflectivity)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(_update)

    # attach running average metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            distance, reflectivity, target = _prepare_batch(batch)
            pred = model(distance, reflectivity)
            return pred, target

    evaluator = Engine(_inference)
    cm = ConfusionMatrix(num_classes)
    IoU(cm, ignore_index=0).attach(evaluator, 'IoU')
    Loss(criterion).attach(evaluator, 'loss')

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        if trainer.state is not None:
            return trainer.state.iteration
        else:
            return 1

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training', metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag='validation',
                                               metric_names=['loss', 'IoU'],
                                               global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @trainer.on(Events.STARTED)
    def initialize(engine):
        engine.state.exception_raised = False
        if args.resume:
            engine.state.epoch = args.start_epoch

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        epoch = trainer.state.epoch if trainer.state is not None else 1
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        name = 'epoch{}_mIoU={:.1f}.pth'.format(epoch, mean_iou)
        file = {
            'model': model.state_dict(),
            'epoch': epoch,
            'optimizer': optimizer.state_dict(),
            'args': args
        }
        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(
            engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)

        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        iou_text = ', '.join([
            '{}: {:.1f}'.format(KITTI.classes[i + 1].name, v)
            for i, v in enumerate(iou.tolist())
        ])
        pbar.log_message(
            "Validation results - Epoch: [{}/{}]: Loss: {:.2e}\n IoU: {}\n mIoU: {:.1f}"
            .format(engine.state.epoch, engine.state.max_epochs, loss, iou_text, mean_iou))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        engine.state.exception_raised = True
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            name = 'epoch{}_exception.pth'.format(trainer.state.epoch)
            file = {
                'model': model.state_dict(),
                'epoch': trainer.state.epoch,
                'optimizer': optimizer.state_dict()
            }
            save(file, args.output_dir, 'checkpoint_{}'.format(name))
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))
        else:
            raise e

    if args.eval_on_start:
        print("Start validation")
        evaluator.run(val_loader, max_epochs=1)

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
        mu = next(mu_scheme)
        i = engine.state.iteration
        for group in optimizer.param_groups:
            group["lr"] = mu * math.sqrt(1 - 0.999**i) / (1 - 0.9**i)

        return {
            "elbo": elbo.item(),
            "kl": kl_divergence.item(),
            "sigma": sigma,
            "mu": mu
        }

    # Trainer and metrics
    trainer = Engine(step)
    metric_names = ["elbo", "kl", "sigma", "mu"]
    RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
    RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
    RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
    RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
    ProgressBar().attach(trainer, metric_names=metric_names)

    # Model checkpointing
    checkpoint_handler = ModelCheckpoint("./", "checkpoint",
                                         save_interval=1,
                                         n_saved=3,
                                         require_empty=False)
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={
def run(train_batch_size, val_batch_size, epochs, lr, momentum, log_interval, log_dir):
    train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
    model = Net()
    writer = SummaryWriter(log_dir=log_dir)

    # Use TPU device
    device = xm.xla_device()
    model.to(device)  # Move model before creating optimizer
    optimizer = SGD(model.parameters(), lr=lr, momentum=momentum)
    criterion = nn.NLLLoss()

    # Create trainer and evaluator
    trainer = create_supervised_trainer(
        model, optimizer, criterion, device=device,
        output_transform=lambda x, y, y_pred, loss: [loss.item(), ])

    val_metrics = {"accuracy": Accuracy(), "nll": Loss(criterion)}
    evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

    tracker = xm.RateTracker()

    # Add RateTracker as an output of the training step
    @trainer.on(Events.ITERATION_COMPLETED)
    def add_rate_tracker(engine):
        tracker.add(len(engine.state.batch))
        engine.state.output.append(tracker.global_rate())

    # Set up the output values of the training step as EMA metrics
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, "batch_loss")
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, "global_rate")

    # Let's log the EMA metrics every `log_interval` iterations
    @trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
    def log_training_loss(engine):
        writer.add_scalar("training/batch_loss", engine.state.metrics["batch_loss"],
                          engine.state.iteration)
        writer.add_scalar("training/global_rate", engine.state.metrics["global_rate"],
                          engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        evaluator.run(train_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(f"Training Results - Epoch: {engine.state.epoch} "
              f"Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}")
        writer.add_scalar("training/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("training/avg_accuracy", avg_accuracy, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        avg_accuracy = metrics["accuracy"]
        avg_nll = metrics["nll"]
        print(f"Validation Results - Epoch: {engine.state.epoch} "
              f"Avg accuracy: {avg_accuracy:.2f} Avg loss: {avg_nll:.2f}")
        writer.add_scalar("validation/avg_loss", avg_nll, engine.state.epoch)
        writer.add_scalar("validation/avg_accuracy", avg_accuracy, engine.state.epoch)

    # kick everything off
    trainer.run(train_loader, max_epochs=epochs)
    writer.close()
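# A CPU-only sketch of the pattern used in run() above: the training step
# returns a list, an ITERATION_COMPLETED handler appends an extra value to
# engine.state.output, and one RunningAverage per slot turns both into EMA
# metrics. The constant 42.0 simply stands in for tracker.global_rate().
from ignite.engine import Engine, Events
from ignite.metrics import RunningAverage


def _step(engine, batch):
    return [float(batch)]  # slot 0: the "loss"


_engine = Engine(_step)


@_engine.on(Events.ITERATION_COMPLETED)
def _append_rate(engine):
    engine.state.output.append(42.0)  # slot 1: the "global rate"


# Handlers run in registration order, so the metrics (attached below) see the
# already-extended output list.
RunningAverage(output_transform=lambda x: x[0]).attach(_engine, "batch_loss")
RunningAverage(output_transform=lambda x: x[1]).attach(_engine, "global_rate")

_engine.run([1.0, 2.0, 3.0], max_epochs=1)
print(_engine.state.metrics)  # EMA-smoothed values for both slots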
def setup_ignite(
    engine: Engine,
    params: SimpleNamespace,
    exp_source,
    run_name: str,
    model,
    optimizer,
    extra_metrics: Iterable[str] = (),
):
    warnings.simplefilter("ignore", category=UserWarning)

    handler = ptan_ignite.EndOfEpisodeHandler(
        exp_source, bound_avg_reward=params.stop_reward)
    handler.attach(engine)
    ptan_ignite.EpisodeFPSHandler().attach(engine)

    objects_to_checkpoint = {
        'model': model,
        'optimizer': optimizer,
        'trainer': engine
    }
    checkpoint_dir = Path("models")
    saver = DiskSaver(str(checkpoint_dir), create_dir=True, require_empty=False)
    handler = Checkpoint(objects_to_checkpoint, saver, n_saved=2)
    engine.add_event_handler(Events.ITERATION_COMPLETED(every=1000), handler)

    checkpoints_paths = list(checkpoint_dir.iterdir())
    if checkpoints_paths:
        checkpoint = torch.load(checkpoints_paths[-1])
        print(f"Loading checkpoint {checkpoints_paths[-1].name}")
        Checkpoint.load_objects(to_load=objects_to_checkpoint, checkpoint=checkpoint)

    @engine.on(ptan_ignite.EpisodeEvents.EPISODE_COMPLETED)
    def episode_completed(trainer: Engine):
        passed = trainer.state.metrics.get('time_passed', 0)
        print("Episode %d: reward=%.2f, steps=%s, "
              "speed=%.1f f/s, elapsed=%s" %
              (trainer.state.episode, trainer.state.episode_reward,
               trainer.state.episode_steps,
               trainer.state.metrics.get('avg_fps', 0),
               timedelta(seconds=int(passed))))

    @engine.on(ptan_ignite.EpisodeEvents.BOUND_REWARD_REACHED)
    def game_solved(trainer: Engine):
        passed = trainer.state.metrics['time_passed']
        print("Game solved in %s, after %d episodes "
              "and %d iterations!" %
              (timedelta(seconds=int(passed)),
               trainer.state.episode, trainer.state.iteration))
        trainer.should_terminate = True

    now = datetime.now().isoformat(timespec='minutes').replace(":", "-")
    logdir = f"runs/{now}-{params.run_name}-{run_name}"
    tb = tb_logger.TensorboardLogger(log_dir=logdir)

    run_avg = RunningAverage(output_transform=lambda v: v['loss'])
    run_avg.attach(engine, "avg_loss")

    metrics = ['reward', 'steps', 'avg_reward']
    handler = tb_logger.OutputHandler(tag="episodes", metric_names=metrics)
    event = ptan_ignite.EpisodeEvents.EPISODE_COMPLETED
    tb.attach(engine, log_handler=handler, event_name=event)

    # write to tensorboard every 100 iterations
    ptan_ignite.PeriodicEvents().attach(engine)
    metrics = ['avg_loss', 'avg_fps']
    metrics.extend(extra_metrics)
    handler = tb_logger.OutputHandler(tag="train", metric_names=metrics,
                                      output_transform=lambda a: a)
    event = ptan_ignite.PeriodEvents.ITERS_100_COMPLETED
    tb.attach(engine, log_handler=handler, event_name=event)
def do_train(cfg, model, train_loader, val_loader, optimizer, scheduler,
             loss_fn, num_query, start_epoch, image_map_label2, num_classes2):
    # ---------------------- LOSS start -----------------------------
    print('----------Initialize Loss Start...')
    criterion = torch.nn.CrossEntropyLoss()
    criterion_lsr = LSR()
    criterion_mse = torch.nn.MSELoss()  # (size_average=True)
    criterion_lsr_direction = LSR_direction()
    criterion_adaptive_lsr = AdaptiveLSR(0.25)

    criterion_lsr.set_epsilon(0.1)
    criterion_lsr_direction.set_alpha(0.6)
    criterion_lsr_direction.set_beta(0.15)
    print('******\nalpha:', criterion_lsr_direction.alpha,
          ' beta:', criterion_lsr_direction.beta)

    same_id_list = get_same_id_list(image_map_label2)
    criterion_lsr_direction.set_mask(same_id_list, num_classes2)

    mask_tensor_matrix = torch.zeros(num_classes2, num_classes2)
    epsilon = [1, 1, 1]
    for ids_item in same_id_list:
        if len(ids_item) == 2:
            mask_tensor_matrix[ids_item[0], ids_item[1]] = epsilon[1]
        if len(ids_item) == 3:
            mask_tensor_matrix[ids_item[0], ids_item[1]] = epsilon[2] / 3
            mask_tensor_matrix[ids_item[0], ids_item[2]] = epsilon[2] / 3
            mask_tensor_matrix[ids_item[1], ids_item[2]] = epsilon[2] / 3
    mask_tensor_matrix = mask_tensor_matrix.float()
    # mask_tensor_matrix = Variable(mask_tensor_matrix.cuda())
    print('mask_tensor_matrix.shape:', mask_tensor_matrix.shape,
          type(mask_tensor_matrix), '\n\n\n')
    print('----------Initialize Loss End!!!')
    # ---------------------------------------------------------------

    global mAP_path, model_dir
    mAP_path = osp.join(cfg.OUTPUT_DIR, 'map_cmc.txt')
    model_dir = cfg.OUTPUT_DIR
    map_cmc_txt = open(mAP_path, 'a+')
    map_cmc_txt.close()

    log_period = cfg.SOLVER.LOG_PERIOD
    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    eval_period = cfg.SOLVER.EVAL_PERIOD
    output_dir = cfg.OUTPUT_DIR
    device = cfg.MODEL.DEVICE
    epochs = cfg.SOLVER.MAX_EPOCHS

    logger = logging.getLogger("reid_baseline.train")
    logger.info("Start training")
    trainer = create_supervised_trainer(
        model, optimizer, loss_fn, criterion, criterion_mse, criterion_lsr,
        criterion_adaptive_lsr, criterion_lsr_direction, mask_tensor_matrix,
        device, cfg.SOLVER.MIXUP, cfg.SOLVER.RICAP, cfg.MODEL.FREEZE_BASE,
        cfg.MODEL.FREEZE_BASE_EPOCH)
    # evaluator = create_supervised_evaluator(model, metrics={'r1_mAP': R1_mAP_reranking(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)}, device=device)
    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'r1_mAP': R1_mAP(num_query, max_rank=50, feat_norm=cfg.TEST.FEAT_NORM)
        },
        device=device)
    checkpointer = ModelCheckpoint(output_dir, cfg.MODEL.NAME, checkpoint_period,
                                   n_saved=3, require_empty=False)
    timer = Timer(average=True)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED, checkpointer, {
            'model': model,          # .state_dict()
            'optimizer': optimizer   # .state_dict()
        })
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # average metric to attach on trainer
    RunningAverage(output_transform=lambda x: x[0]).attach(trainer, 'avg_loss')
    RunningAverage(output_transform=lambda x: x[1]).attach(trainer, 'avg_acc')

    @trainer.on(Events.STARTED)
    def start_training(engine):
        engine.state.epoch = start_epoch

    @trainer.on(Events.EPOCH_STARTED)
    def adjust_learning_rate(engine):
        if cfg.SOLVER.MY_WARMUP == 'yes':
            if engine.state.epoch <= cfg.SOLVER.MY_WARMUP_EPOCH:
                print('--- warmup')
            else:
                scheduler.step()
        else:
            scheduler.step()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        global ITER
        ITER += 1
        if ITER % log_period == 0:
            if cfg.SOLVER.MY_SCHEDULER == 'yes':
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}".format(
                        engine.state.epoch, ITER, len(train_loader),
                        engine.state.metrics['avg_loss'],
                        engine.state.metrics['avg_acc']))
            else:
                logger.info(
                    "Epoch[{}] Iteration[{}/{}] Loss: {:.3f}, Acc: {:.3f}, Base Lr: {:.2e}"
                    .format(engine.state.epoch, ITER, len(train_loader),
                            engine.state.metrics['avg_loss'],
                            engine.state.metrics['avg_acc'],
                            scheduler.get_lr()[0]))
        if len(train_loader) == ITER:
            ITER = 0

    # adding handlers using `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        logger.info(
            'Epoch {} done. Time per batch: {:.3f}[s] Speed: {:.2f}[samples/s]'
            .format(engine.state.epoch,
                    timer.value() * timer.step_count,
                    train_loader.batch_size / timer.value()))
        logger.info('-' * 10)
        timer.reset()

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        global best_mAP, best_epoch, mAP_path, save_flag
        if engine.state.epoch % eval_period == 0:
            evaluator.run(val_loader)
            cmc, mAP = evaluator.state.metrics['r1_mAP']
            logger.info("Validation Results - Epoch: {}".format(engine.state.epoch))
            logger.info("[Epoch {}] mAP: {:.2%}".format(engine.state.epoch, mAP))
            for r in [1, 5, 10, 20]:
                logger.info("CMC curve, Rank-{:<3}:{:.2%}".format(r, cmc[r - 1]))

            if float(mAP) > float(best_mAP):
                print('+++ get best_mAP: ', best_mAP, '-->', mAP)
                best_mAP = mAP
                best_epoch = int(engine.state.epoch)
                save_flag = True
                print('    set save_flag: True')

            map_cmc_txt = open(mAP_path, 'a+')
            map_cmc_txt.write(
                "Epoch[{}] best_mAP: {:.2f}  best_epoch: {} \n".format(
                    engine.state.epoch, best_mAP * 100, best_epoch))
            map_cmc_txt.write(
                "         mAP: {:.2f}  Rank-1: {:.2f}  Rank-5: {:.2f}  Rank-10: {:.2f}  Rank-20: {:.2f}\n"
                .format(float(mAP) * 100, cmc[0] * 100, cmc[4] * 100,
                        cmc[9] * 100, cmc[19] * 100))
            map_cmc_txt.flush()
            os.fsync(map_cmc_txt.fileno())
            map_cmc_txt.close()

    trainer.run(train_loader, max_epochs=epochs)
def train(): parser = ArgumentParser() parser.add_argument("--dataset_path", type=str, default="", help="Path or url of the dataset.") parser.add_argument("--use_adapter", default=False, action='store_true', help="Use adapter or not") parser.add_argument("--keyword_module", type=str, default="", help="add, attention, ") parser.add_argument("--train_batch_size", type=int, default=20, help="Batch size for training") parser.add_argument("--valid_batch_size", type=int, default=20, help="Batch size for validation") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--n_epochs", type=int, default=5, help="Number of training epochs") parser.add_argument( "--eval_before_start", action='store_true', help="If true start with a first evaluation before training") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument( "--fp16", type=str, default="", help= "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") parser.add_argument( "--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") parser.add_argument("--gpt2_model_name", type=str, default="gpt2", help="Path, url or short name of the model") args = parser.parse_args() # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", args.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info("Prepare tokenizer, pretrained model and optimizer.") bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased') bert_model = BertModel.from_pretrained('bert-base-uncased') bert_model.to(args.device) bert_model.eval() tokenizer_class = GPT2Tokenizer if "gpt2" in args.gpt2_model_name else OpenAIGPTTokenizer # cant use Autotokenizer because checkpoint could be a Path tokenizer = tokenizer_class.from_pretrained(args.gpt2_model_name) config_class = GPT2Config if "gpt2" in args.gpt2_model_name else OpenAIGPTConfig gpt_config = config_class.from_pretrained(args.gpt2_model_name) gpt_config.adapter = args.use_adapter gpt_config.keyword_module = args.keyword_module model_class = GPT2LMHeadModel if "gpt2" in args.gpt2_model_name else OpenAIGPTLMHeadModel model = model_class.from_pretrained(args.gpt2_model_name, config=gpt_config) model.to(args.device) # Add special tokens if they are not already added add_special_tokens_(model, tokenizer) optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: 
model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( args, bert_tokenizer, tokenizer) # Training function and trainer def update(engine, batch): model.train() batch = tuple(input_tensor.to(args.device) for input_tensor in batch) source_ids, target_ids, lm_labels = batch encoded_layers, _ = bert_model(source_ids) (lm_loss), *_ = model(target_ids, encoded_layers, labels=lm_labels) loss = lm_loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(args.device) for input_tensor in batch) source_ids, target_ids, lm_labels = batch logger.info(tokenizer.decode(target_ids[0].tolist())) encoded_layers, _ = bert_model(source_ids) lm_logits, *_ = model(target_ids, encoded_layers) lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, ) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if args.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) log_dir = make_logdir(args.gpt2_model_name, args.dataset_path, args.use_adapter, 
args.keyword_module) tb_logger = TensorboardLogger(log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=4) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" takes care of distributed encapsulation torch.save(args, log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME)) tokenizer.save_pretrained(log_dir) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if args.local_rank in [-1, 0] and args.n_epochs > 0: os.rename( os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME) ) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
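# --- Aside: a minimal, self-contained sketch of the gradient-accumulation
# pattern used in the `update` functions of these scripts: the loss is divided
# by the accumulation factor so that N small backward passes sum to one
# effective batch, and the optimizer only steps every N iterations. The linear
# model and random data are stand-ins, not the GPT-2 setup above.
import torch
from torch import nn

accum_steps = 8
model = nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=6.25e-5)

batches = [(torch.randn(4, 10), torch.randn(4, 1)) for _ in range(32)]
for iteration, (x, y) in enumerate(batches, start=1):
    loss = nn.functional.mse_loss(model(x), y) / accum_steps
    loss.backward()                          # gradients accumulate in .grad
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    if iteration % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()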
def train(): parser = ArgumentParser() parser.add_argument("--train_path", type=str, default="data/train_set4DSTC7-AVSD.json", help="Path of the trainset") parser.add_argument("--fea_path", type=str, default="data/", help="Path of the trainset") parser.add_argument("--valid_path", type=str, default="data/valid_set4DSTC7-AVSD.json", help="Path of the validset") parser.add_argument("--model_checkpoint", type=str, default="gpt2", help="Path, url or short name of the model") parser.add_argument("--max_history", type=int, default=3, help="Number of previous exchanges to keep in history") parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training") parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation") parser.add_argument("--drop_rate", type=float, default=0.5, help="drop rate for caption") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--n_epochs", type=int, default=8, help="Number of training epochs") parser.add_argument( "--eval_before_start", action='store_true', help="If true start with a first evaluation before training") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument( "--fp16", type=str, default="", help= "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") parser.add_argument( "--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") parser.add_argument("--log_path", type=str, default="log/", help="Log path") args = parser.parse_args() if not os.path.exists(args.log_path): os.makedirs(args.log_path) # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. 
logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", args.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint) model_class = VideoGPT2LMHeadModel model = model_class.from_pretrained(args.model_checkpoint) tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) model.resize_token_embeddings(len(tokenizer)) model.to(args.device) optimizer = AdamW(model.parameters(), lr=args.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) logger.info("Prepare datasets") train_loader, val_loader = get_data_loaders_new(args, tokenizer) # Training function and trainer def update(engine, batch): model.train() batch = tuple(input_tensor.to(args.device) for input_tensor in batch) input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch input_embs = model.transformer.wte(input_ids) video_embs = model.video_ff(i3d) input_embs = torch.cat([video_embs, input_embs], dim=1) token_type_ids = torch.cat([ torch.ones((i3d.size(0), i3d.size(1))).long().cuda() * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids ], dim=1) video_loss = model(input_embs, token_type_ids=token_type_ids, labels=(labels, i3d), attention_mask=[video_mask, input_mask], mode="video")[0] reply_loss = model(input_embs, token_type_ids=token_type_ids, labels=(labels, i3d), attention_mask=[reply_mask, input_mask], mode="reply")[0] loss = (video_loss + reply_loss) / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(args.device) for input_tensor in batch) input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch input_embs = model.transformer.wte(input_ids) video_embs = model.video_ff(i3d) input_embs = torch.cat([video_embs, input_embs], dim=1) token_type_ids = torch.cat([ torch.ones((i3d.size(0), i3d.size(1))).long().cuda() * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids ], dim=1) model_outputs = model(input_embs, token_type_ids=token_type_ids, attention_mask=[reply_mask, input_mask])[0] 
lm_logits = model_outputs # So we can also use GPT2 outputs lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return lm_logits_flat_shifted, lm_labels_flat_shifted evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0], x[1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) tb_logger = TensorboardLogger(log_dir="./tb_logs") tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(args.log_path, 'checkpoint', save_interval=1, n_saved=8, require_empty=False) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" take care of distributed encapsulation torch.save(args, args.log_path + 'model_training_args.bin') getattr(model, 'module', model).config.to_json_file( os.path.join(args.log_path, CONFIG_NAME)) tokenizer.save_vocabulary(args.log_path) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if args.local_rank in [-1, 0] and args.n_epochs > 0: os.rename( checkpoint_handler._saved[-1][1][-1], os.path.join(args.log_path, WEIGHTS_NAME) ) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
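# --- Aside: a numeric sketch of the logits/labels shift performed in the
# `inference` functions above: position t's logits predict token t+1, so
# logits drop the last step and labels drop the first before the flattened
# cross-entropy (here with ignore_index=-1, as in the script above).
import torch
import torch.nn.functional as F

batch, seq_len, vocab = 2, 5, 11
lm_logits = torch.randn(batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab, (batch, seq_len))
lm_labels[:, :2] = -1                        # e.g. context positions excluded from the loss

lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, vocab)
lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
nll = F.cross_entropy(lm_logits_flat_shifted, lm_labels_flat_shifted, ignore_index=-1)
print(nll.item())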
def run_once(self): # self.log_path = 'log/%s/' % self.dataset # self.model_name = 'efficientnet-b0_MSI_{0}fold_random_tile_patch'.format(self.fold_idx) # self.log_dir = self.log_path + self.model_name log_dir = self.log_dir check_manual_seed(self.seed) train_pairs, valid_pairs = dataset.prepare_PAIP2020_PANDA( self.fold_idx) print(len(train_pairs)) print(len(valid_pairs)) train_augmentors = self.train_augmentors() train_dataset = dataset.DatasetSerial(train_pairs[:], self.tile_size, self.num_tile, train_mode=True) infer_augmentors = self.infer_augmentors() # HACK at has_aux infer_dataset = dataset.DatasetSerial(valid_pairs[:], self.tile_size, self.num_tile, train_mode=False) train_loader = data.DataLoader(train_dataset, num_workers=self.nr_procs_train, batch_size=self.train_batch_size, shuffle=True, drop_last=True) valid_loader = data.DataLoader(infer_dataset, num_workers=self.nr_procs_valid, batch_size=self.infer_batch_size, shuffle=True, drop_last=False) # --------------------------- Training Sequence if self.logging: check_log_dir(log_dir) # device = 'cuda' # networksv input_chs = 3 # TODO: dynamic config # ### VGGNet net = EfficientNet.from_pretrained('efficientnet-b0', num_classes=2) #net =DenseNet(3,2) # load pre-trained models net = torch.nn.DataParallel(net).to(device) if self.load_network: saved_state = torch.load(self.save_net_path) net.load_state_dict(saved_state) # optimizers optimizer = optim.Adam(net.parameters(), lr=self.init_lr) scheduler = StepLR(optimizer, self.lr_steps, gamma=0.1) scheduler = LRScheduler(scheduler) # trainer = Engine(lambda engine, batch: self.train_step( net, batch, optimizer, device)) valider = Engine( lambda engine, batch: self.infer_step(net, batch, device)) infer_output = ['prob', 'true'] ## if self.logging: checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix, save_interval=1, n_saved=100, require_empty=False) # adding handlers using `trainer.add_event_handler` method API trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net}) timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) timer.attach(valider, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) # attach running average metrics computation # decay of EMA to 0.95 to match tensorpack default # TODO: refactor this RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach( trainer, 'acc') RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach( trainer, 'loss') # attach progress bar pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=['loss']) pbar.attach(valider) # #early Stopping # def score_function(engine): # val_acc=engine.state.metrics["valid-acc"] # return val_acc # early_stopping_handler=EarlyStopping(patience=10,score_function=score_function,trainer=trainer) # adding handlers using `trainer.on` decorator API @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn('KeyboardInterrupt caught. 
Exiting gracefully.') checkpoint_handler(engine, {'net_exception': net}) else: raise e # writer for tensorboard logging tfwriter = None # HACK temporary if self.logging: tfwriter = SummaryWriter(log_dir) json_log_file = log_dir + '/stats.json' with open(json_log_file, 'w') as json_file: json.dump({}, json_file) # create empty file ### TODO refactor again log_info_dict = { 'logging': self.logging, 'optimizer': optimizer, 'tfwriter': tfwriter, 'json_file': json_log_file, 'nr_classes': self.nr_classes, 'metric_names': infer_output, 'infer_batch_size': self.infer_batch_size # too cumbersome } trainer.add_event_handler(Events.EPOCH_COMPLETED, log_train_ema_results, log_info_dict) trainer.add_event_handler(Events.EPOCH_COMPLETED, inference, valider, valid_loader, log_info_dict) valider.add_event_handler(Events.ITERATION_COMPLETED, accumulate_outputs) # Setup is done. Now let's run the training trainer.run(train_loader, self.nr_epochs) return
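# --- Aside: a minimal sketch of the Events.EXCEPTION_RAISED pattern used
# above: a Ctrl-C mid-training saves an emergency checkpoint and terminates
# cleanly, while any other exception is re-raised. save_interval follows the
# older ignite API these scripts use; the engine and net are toy stand-ins.
import warnings
import torch.nn as nn
from ignite.engine import Engine, Events
from ignite.handlers import ModelCheckpoint

net = nn.Linear(3, 1)
trainer = Engine(lambda engine, batch: None)
ckpt = ModelCheckpoint('checkpoints', 'run', save_interval=1, n_saved=1,
                       require_empty=False)

@trainer.on(Events.EXCEPTION_RAISED)
def handle_exception(engine, e):
    if isinstance(e, KeyboardInterrupt) and engine.state.iteration > 1:
        engine.terminate()
        warnings.warn('KeyboardInterrupt caught. Exiting gracefully.')
        ckpt(engine, {'net_exception': net})  # emergency save of current weights
    else:
        raise e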
def train(run_name, forward_func, sample_func, model, train_set, val_set, n_epochs, batch_size, lr_i, lr_f, lr_n, sig_i, sig_f, sig_n): # Make the run directory save_dir = os.path.join('training/saved_runs', run_name) if run_name == 'debug': shutil.rmtree(save_dir, ignore_errors=True) os.mkdir(save_dir) model = model.to(device) train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=True) val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=True, drop_last=True) optimizer = torch.optim.Adam(model.parameters(), lr=lr_i) lr_scheduler = utils.AnnealingStepLR(optimizer, mu_i=lr_i, mu_f=lr_f, n=lr_n) sigma_scheduler = utils.AnnealingStepSigma(sig_i, sig_f, sig_n) # Training step def step(engine, batch): model.train() if isinstance(batch, list): batch = [tensor.to(device) for tensor in batch] else: batch = batch.to(device) x_mu, x_q, kl = forward_func(model, batch) # Log likelihood sigma = sigma_scheduler.sigma lr = lr_scheduler.get_lr()[0] ll = Normal(x_mu, sigma).log_prob(x_q) likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3])) kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3])) # Evidence lower bound elbo = likelihood - kl_divergence loss = -elbo loss.backward() optimizer.step() optimizer.zero_grad() lr_scheduler.step() sigma_scheduler.step() return { 'elbo': elbo.item(), 'likelihood': likelihood.item(), 'kl': kl_divergence.item(), 'lr': lr, 'sigma': sigma } # Trainer and metrics trainer = Engine(step) metric_names = ['elbo', 'likelihood', 'kl', 'lr', 'sigma'] RunningAverage(output_transform=lambda x: x['elbo']).attach( trainer, 'elbo') RunningAverage(output_transform=lambda x: x['likelihood']).attach( trainer, 'likelihood') RunningAverage(output_transform=lambda x: x['kl']).attach(trainer, 'kl') RunningAverage(output_transform=lambda x: x['lr']).attach(trainer, 'lr') RunningAverage(output_transform=lambda x: x['sigma']).attach( trainer, 'sigma') ProgressBar().attach(trainer, metric_names=metric_names) Timer(average=True).attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) # Model checkpointing checkpoint_handler = ModelCheckpoint(os.path.join(save_dir, 'checkpoints'), type(model).__name__, save_interval=1, n_saved=3, require_empty=False) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'sigma_scheduler': sigma_scheduler }) # Tensorbard writer writer = SummaryWriter(log_dir=os.path.join(save_dir, 'logs')) @trainer.on(Events.ITERATION_COMPLETED) def log_metrics(engine): if engine.state.iteration % 100 == 0: for metric, value in engine.state.metrics.items(): writer.add_scalar('training/{}'.format(metric), value, engine.state.iteration) def save_images(engine, batch): x_mu, x_q, r = sample_func(model, batch) r_dim = r.shape[1] if isinstance(model, VVGQN): r = (r + 1) / 2 r = r.view(-1, 1, int(math.sqrt(r_dim)), int(math.sqrt(r_dim))) x_mu = x_mu.detach().cpu().float() r = r.detach().cpu().float() writer.add_image('representation', make_grid(r), engine.state.epoch) writer.add_image('generation', make_grid(x_mu), engine.state.epoch) writer.add_image('query', make_grid(x_q), engine.state.epoch) @trainer.on(Events.EPOCH_COMPLETED) def validate(engine): model.eval() with torch.no_grad(): batch = next(iter(val_loader)) if isinstance(batch, list): batch = [tensor.to(device) for tensor in batch] else: batch = batch.to(device) x_mu, x_q, kl = 
forward_func(model, batch) # Validate at last sigma ll = Normal(x_mu, sigma_scheduler.sigma).log_prob(x_q) likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3])) kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3])) # Evidence lower bound elbo = likelihood - kl_divergence writer.add_scalar('validation/elbo', elbo.item(), engine.state.epoch) writer.add_scalar('validation/likelihood', likelihood.item(), engine.state.epoch) writer.add_scalar('validation/kl', kl_divergence.item(), engine.state.epoch) save_images(engine, batch) @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): writer.close() engine.terminate() if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): import warnings warnings.warn('KeyboardInterrupt caught. Exiting gracefully.') checkpoint_handler(engine, {'model_exception': model}) else: raise e start_time = time.time() trainer.run(train_loader, n_epochs) writer.close() end_time = time.time() print('Total training time: {}'.format( timedelta(seconds=end_time - start_time)))
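# --- Aside: a numeric sketch of the ELBO assembled in `step`/`validate`
# above: the per-pixel Gaussian log-likelihood and the KL term are summed over
# the (C, H, W) dimensions and averaged over the batch; the training loss is
# the negative ELBO. All tensors here are random stand-ins.
import torch
from torch.distributions import Normal

x_mu = torch.rand(4, 3, 8, 8)                # decoder mean
x_q = torch.rand(4, 3, 8, 8)                 # query/target image
kl = torch.rand(4, 3, 8, 8) * 0.01           # per-location KL term
sigma = 0.7                                  # annealed observation std

ll = Normal(x_mu, sigma).log_prob(x_q)
likelihood = torch.mean(torch.sum(ll, dim=[1, 2, 3]))
kl_divergence = torch.mean(torch.sum(kl, dim=[1, 2, 3]))
elbo = likelihood - kl_divergence
loss = -elbo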
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", default="./dataset/icdar2015/train")
    parser.add_argument("--test")
    parser.add_argument("--batch_size", default=32, type=int)
    parser.add_argument("--epochs", default=100, type=int)
    parser.add_argument("--scale", default=4, type=int)
    parser.add_argument("--logdir")
    parser.add_argument("--checkpoint")
    parser.add_argument("--restore")
    parser.add_argument("--seed", default=42, type=int)
    parser.add_argument("--excitation", choices=["cse", "sse", "scse", "none"], default=None)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    image_size = (512, 512)
    dataset = ICDAR15Dataset(os.path.join(args.train, "images"),
                             os.path.join(args.train, "labels"),
                             image_size=image_size, scale=args.scale, training=True)
    dataloader = data.DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8)
    if args.test is not None:
        test_dataset = ICDAR15Dataset(os.path.join(args.test, "images"),
                                      os.path.join(args.test, "labels"),
                                      image_size=image_size, scale=args.scale, training=False)
    else:
        # random_split requires integer lengths, so truncate the 5% split
        n_test = min(1000, int(len(dataset) * 0.05))
        dataset, test_dataset = torch.utils.data.random_split(
            dataset, [len(dataset) - n_test, n_test])
        # indices = np.arange(len(dataset))
        # test_dataset = torch.utils.data.Subset(dataset, indices[:n_test])
        # dataset = torch.utils.data.Subset(dataset, indices[n_test:])
        # print(len(dataset), len(test_dataset))
    test_dataloader = data.DataLoader(test_dataset, batch_size=8, shuffle=False, num_workers=8)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # pixellink = net.PixelLink(args.scale, pretrained=False).to(device)
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer)
    if args.restore:
        if torch.cuda.is_available():
            map_location = None
        else:
            def map_location(storage, loc):
                return storage
        pixellink = torch.load(args.restore, map_location=map_location).to(device)
    else:
        excitation_cls = {"cse": net.CSE, "sse": net.SSE, "scse": net.SCSE}.get(args.excitation, None)
        print(excitation_cls)
        pixellink = net.MobileNetV2PixelLink(args.scale, excitation_cls=excitation_cls).to(device)
    optimizer = torch.optim.Adam(pixellink.parameters(), lr=1e-3)

    def step_fn(training):
        def fn(engine, batch):
            if training:
                pixellink.train()
            else:
                pixellink.eval()
            with torch.set_grad_enabled(training):
                images, pos_pixel_masks, neg_pixel_masks, pixel_weights, link_masks = batch
                if training:
                    optimizer.zero_grad()
                images = images.to(device)
                pos_pixel_masks = pos_pixel_masks.to(device)
                neg_pixel_masks = neg_pixel_masks.to(device)
                pixel_weights = pixel_weights.to(device)
                link_masks = link_masks.to(device)
                pixel_input, link_input = pixellink(images)
                loss_object = net.PixelLinkLoss(pixel_input, pos_pixel_masks, neg_pixel_masks,
                                                pixel_weights, link_input, link_masks)
                # loss_object = net.PixelLinkFocalLoss(pixel_input, pos_pixel_masks, neg_pixel_masks, pixel_weights, link_input, link_masks)
                if training:
                    loss_object.loss.backward()
                    optimizer.step()
                return {
                    "loss": loss_object.loss.item(),
                    "loss/pixel": loss_object.pixel_loss.item(),
                    "loss/link": loss_object.link_loss.item(),
                    "accuracy/pixel": loss_object.pixel_accuracy,
                    "accuracy/link": np.mean(loss_object.link_accuracy),
                    "accuracy/positive_pixel": loss_object.positive_pixel_accuracy,
                }
        return fn

    dummy = torch.randn(1, 3, image_size[0], image_size[1], dtype=torch.float).to(device)
    writer = create_summary_writer(pixellink, dummy, os.path.join(args.logdir, "train"))
    test_writer = create_summary_writer(pixellink, dummy,
os.path.join(args.logdir, "test")) trainer = Engine(step_fn(training=True)) evaluator = Engine(step_fn(training=False)) checkpoint_handler = ModelCheckpoint( args.checkpoint, "networks", n_saved=5, require_empty=False, score_function=lambda engine: -engine.state.metrics["loss"], score_name="loss") biggest_checkpoint_handler = ModelCheckpoint( args.checkpoint, "biggest", n_saved=5, score_function=lambda engine: engine.state.metrics["loss"], score_name="loss", require_empty=False) evaluator.add_event_handler(Events.COMPLETED, handler=checkpoint_handler, to_save={"net": pixellink}) evaluator.add_event_handler(Events.COMPLETED, handler=biggest_checkpoint_handler, to_save={"net": pixellink}) timer = Timer(average=True) monitoring_metrics = [ "loss", "loss/pixel", "loss/link", "accuracy/pixel", "accuracy/link", "accuracy/positive_pixel" ] for metric in monitoring_metrics: def output_transform(m): def fn(x): return x[m] return fn RunningAverage(output_transform=output_transform(metric)).attach( trainer, metric) RunningAverage(output_transform=output_transform(metric)).attach( evaluator, metric) pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) @trainer.on(Events.ITERATION_COMPLETED) def print_logs(engine): if (engine.state.iteration - 1) % LOG_FREQ != 0: return for key, value in engine.state.metrics.items(): writer.add_scalar(key, value, engine.state.iteration) message = "[{epoch}/{max_epoch}][{i}/{max_i}] (train)\t".format( epoch=engine.state.epoch, max_epoch=args.epochs, i=(engine.state.iteration % len(dataloader)), max_i=len(dataloader), ) for key, value in engine.state.metrics.items(): message += ' | {key}: {value}'.format(key=key, value=str(round(value, 5))) pbar.log_message(message) @trainer.on(Events.ITERATION_COMPLETED) def print_validation_results(engine): if (engine.state.iteration - 1) % LOG_FREQ != 0: return evaluator.run(test_dataloader) for key, value in evaluator.state.metrics.items(): test_writer.add_scalar(key, value, engine.state.iteration) message = "[{epoch}/{max_epoch}][{i}/{max_i}] (test) \t".format( epoch=engine.state.epoch, max_epoch=args.epochs, i=(engine.state.iteration % len(dataloader)), max_i=len(dataloader), ) for key, value in evaluator.state.metrics.items(): message += ' | {key}: {value}'.format(key=key, value=str(round(value, 5))) pbar.log_message(message) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED) def print_time(engine): pbar.log_message("Epoch {} done. Time per batch: {:.3f}[s]".format( engine.state.epoch, timer.value(), )) timer.reset() @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn("KeyboardInterrupt caught. Exiting gracefully.") checkpoint_handler(engine, {"net": pixellink}) else: raise e trainer.run(dataloader, args.epochs) writer.close() test_writer.close()
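# --- Aside: why the loop above builds its output_transform through a factory.
# Closing over the loop variable directly late-binds, so every RunningAverage
# would end up reading the *last* metric key; binding the key through a
# factory (or a default argument) captures each key correctly.
monitoring_metrics = ["loss", "loss/pixel", "loss/link"]

# Buggy: every fn sees metric == "loss/link" once the loop has finished.
bad = [lambda out: out[metric] for metric in monitoring_metrics]

# Correct: each fn captures its own key.
def make_transform(key):
    def fn(out):
        return out[key]
    return fn

good = [make_transform(metric) for metric in monitoring_metrics]

sample = {"loss": 1.0, "loss/pixel": 2.0, "loss/link": 3.0}
print([f(sample) for f in bad])    # [3.0, 3.0, 3.0]
print([f(sample) for f in good])   # [1.0, 2.0, 3.0]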
def setup(self, training_metrics): def metric_name(n) -> str: if n.endswith('Accuracy'): n = 'acc' else: n = n[:-6] if n.endswith('Metric') else n return n def print_metrics(metrics) -> str: rv = '' metric_keys = sorted(k for k in metrics) for k in metric_keys: if k == 'Accuracy': rv += f'{metric_name(k)}: {metrics[k]:.3}' else: rv += f'{metric_name(k)}: {metrics[k]:.6}' return rv if self.seed: set_seed_everywhere(self.seed, self.cuda) pbar = ProgressBar() names = [] for k, v in training_metrics.items(): name = f'r{k}' names.append(name) RunningAverage(v).attach(self.trainer, name) RunningAverage(None, output_transform=lambda x: x[-1] * self. loss_accumulation_steps).attach(self.trainer, 'rloss') names.append('rloss') pbar.attach(self.trainer, names) pbar = ProgressBar() pbar.attach(self.evaluator) # A few events handler. To add / modify the events handler, you need to extend the __init__ method of RunnerABC # Ignite provides the necessary abstractions and a furnished repository of useful tools @self.trainer.on(Events.EPOCH_COMPLETED) def log_validation_results(trainer): self.evaluator.run(self.dataset_splits.val_data_loader()) metrics = self.evaluator.state.metrics logger.info( f"Validation Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}" ) if self.scheduler: self.scheduler.step( metrics[self.loss_metric.__class__.__name__]) @self.trainer.on(Events.COMPLETED) def log_test_results(trainer): self.evaluator.run(self.dataset_splits.test_data_loader()) metrics = self.evaluator.state.metrics logger.info( f"Test Results - Epoch: {trainer.state.epoch} {print_metrics(metrics)}" ) if self.tensorboard_logs: tb_logger = TensorboardLogger(log_dir=self.tensorboard_logs) tb_logger.attach(self.trainer, log_handler=OutputHandler( tag="training", output_transform=lambda loss: {'loss': loss}), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(self.evaluator, log_handler=OutputHandler( tag="validation", metric_names=["LossMetric"], another_engine=self.trainer), event_name=Events.EPOCH_COMPLETED) tb_logger.attach(self.trainer, log_handler=OptimizerParamsHandler( self.optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(self.trainer, log_handler=WeightsScalarHandler(self.model), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(self.trainer, log_handler=WeightsHistHandler(self.model), event_name=Events.EPOCH_COMPLETED) tb_logger.attach(self.trainer, log_handler=GradsScalarHandler(self.model), event_name=Events.ITERATION_COMPLETED) # This is important to close the tensorboard file logger @self.trainer.on(Events.COMPLETED) def end_tensorboard(trainer): logger.info("Training completed") tb_logger.close() if self.embeddings_name: @self.trainer.on(Events.COMPLETED) def log_embeddings(trainer): if hasattr(self.model, self.embeddings_name) and hasattr( self.dataset_splits, "vectorizer"): logger.info( f"Logging embeddings ({self.embeddings_name}) to Tensorboard!" ) embeddings = getattr(self.model, self.embeddings_name).weight.data metadata = [ str(self.dataset_splits.vectorizer.data_vocab. _id2token[token_index]).encode('utf-8') for token_index in range(embeddings.shape[0]) ] self.writer.add_embedding( mat=embeddings, metadata=metadata, global_step=self.trainer.state.epoch)
def train(): config_file = "configs/train_full_config.json" config = Config.from_json_file(config_file) # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", config.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(config)) # Initialize distributed training if needed config.distributed = (config.local_rank != -1) if config.distributed: torch.cuda.set_device(config.local_rank) config.device = torch.device("cuda", config.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint) model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel model = model_class.from_pretrained(config.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) model.set_num_special_tokens(len(SPECIAL_TOKENS)) model.to(config.device) optimizer = OpenAIAdam(model.parameters(), lr=config.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if config.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16) if config.distributed: model = DistributedDataParallel(model, device_ids=[config.local_rank], output_device=config.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( config, tokenizer) # Training function and trainer def update(engine, batch): model.train() input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = tuple( input_tensor.to(config.device) for input_tensor in batch) lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids) loss = (lm_loss * config.lm_coef + mc_loss * config.mc_coef) / config.gradient_accumulation_steps if config.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm) if engine.state.iteration % config.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(config.device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids = batch #logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids, token_emotion_ids=token_emotion_ids) lm_logits, mc_logits = model_outputs[0], model_outputs[ 1] # So we can also use GPT2 outputs lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, 
mc_logits), (lm_labels_flat_shifted, mc_labels) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if config.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if config.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if config.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, config.lr), (config.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])), "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], config), "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], config) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if config.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) tb_logger = TensorboardLogger(log_dir=config.log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" take care of distributed encapsulation torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file( os.path.join(tb_logger.writer.log_dir, CONFIG_NAME)) tokenizer.save_vocabulary(tb_logger.writer.log_dir) # Run the training trainer.run(train_loader, max_epochs=config.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if config.local_rank in [-1, 0] and config.n_epochs > 0: os.rename( checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME) ) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
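# --- Aside: a minimal sketch of the MetricsLambda composition used above:
# perplexity is derived by applying math.exp to the NLL metric, so it updates
# automatically whenever the underlying Loss metric does. The engine here is a
# toy stand-in emitting random (y_pred, y) pairs.
import math
import torch
from ignite.engine import Engine
from ignite.metrics import Loss, MetricsLambda

evaluator = Engine(lambda e, b: (torch.randn(8, 10), torch.randint(0, 10, (8,))))
nll = Loss(torch.nn.CrossEntropyLoss())
ppl = MetricsLambda(math.exp, nll)
nll.attach(evaluator, "nll")
ppl.attach(evaluator, "ppl")
state = evaluator.run([None] * 4)
print(state.metrics["nll"], state.metrics["ppl"])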
'D_x': D_x, 'D_G_z1': D_G_z1, 'D_G_z2': D_G_z2 } trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False) timer = Timer(average=True) # attach running average metrics monitoring_metrics = ['errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2'] RunningAverage(alpha=alpha, output_transform=lambda x: x['errD']).attach(trainer, 'errD') RunningAverage(alpha=alpha, output_transform=lambda x: x['errG']).attach(trainer, 'errG') RunningAverage(alpha=alpha, output_transform=lambda x: x['D_x']).attach(trainer, 'D_x') RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z1']).attach( trainer, 'D_G_z1') RunningAverage(alpha=alpha, output_transform=lambda x: x['D_G_z2']).attach( trainer, 'D_G_z2') # attach progress bar pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) @trainer.on(Events.ITERATION_COMPLETED)
if trainer.state.iteration % SAVE_IMAGE_EVERY_ITER == 0: fake_img = vutils.make_grid( gen_output_v.data[:64], normalize=True) trainer.tb.writer.add_image( "fake", fake_img, trainer.state.iteration) real_img = vutils.make_grid( batch_v.data[:64], normalize=True) trainer.tb.writer.add_image( "real", real_img, trainer.state.iteration) trainer.tb.writer.flush() return dis_loss.item(), gen_loss.item() engine = Engine(process_batch) tb = tb_logger.TensorboardLogger(log_dir=None) engine.tb = tb RunningAverage(output_transform=lambda out: out[0]).\ attach(engine, "avg_loss_gen") RunningAverage(output_transform=lambda out: out[1]).\ attach(engine, "avg_loss_dis") handler = tb_logger.OutputHandler(tag="train", metric_names=['avg_loss_gen', 'avg_loss_dis']) tb.attach(engine, log_handler=handler, event_name=Events.ITERATION_COMPLETED) @engine.on(Events.ITERATION_COMPLETED) def log_losses(trainer): if trainer.state.iteration % REPORT_EVERY_ITER == 0: log.info("%d: gen_loss=%f, dis_loss=%f", trainer.state.iteration, trainer.state.metrics['avg_loss_gen'], trainer.state.metrics['avg_loss_dis'])
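# --- Aside: a self-contained sketch of the periodic image logging done above:
# the first 64 samples are normalized into a grid and written under a fixed
# tag, so TensorBoard shows a slider over iterations. A plain SummaryWriter
# stands in for trainer.tb.writer; the fake batch is random.
import torch
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('runs/gan_images')
fake_batch = torch.randn(128, 3, 64, 64)     # stand-in generator output
iteration = 1000

grid = vutils.make_grid(fake_batch[:64], normalize=True)
writer.add_image('fake', grid, iteration)
writer.flush()
writer.close()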
def train(model, train_loader, eval_loaders, optimizer, loss_fn, n_it_max, patience, split_names, select_metric='Val accuracy_0', select_mode='max', viz=None, device='cpu', lr_scheduler=None, name=None, log_steps=None, log_epoch=False, _run=None, prepare_batch=_prepare_batch, single_pass=False, n_ep_max=None): # print(model) if not log_steps and not log_epoch: logger.warning('/!\\ No logging during training /!\\') if log_steps is None: log_steps = [] epoch_steps = len(train_loader) if log_epoch: log_steps.append(epoch_steps) if single_pass: max_epoch = 1 elif n_ep_max is None: assert n_it_max is not None max_epoch = int(n_it_max / epoch_steps) + 1 else: assert n_it_max is None max_epoch = n_ep_max all_metrics = defaultdict(dict) trainer = create_supervised_trainer(model, optimizer, loss_fn, device=device, prepare_batch=prepare_batch) if hasattr(model, 'new_epoch_hook'): trainer.add_event_handler(Events.EPOCH_STARTED, model.new_epoch_hook) if hasattr(model, 'new_iter_hook'): trainer.add_event_handler(Events.ITERATION_STARTED, model.new_iter_hook) trainer.logger.setLevel(logging.WARNING) # trainer output is in the format (x, y, y_pred, loss, optionals) train_loss = RunningAverage(output_transform=lambda out: out[3].item(), epoch_bound=True) train_loss.attach(trainer, 'Trainer loss') if hasattr(model, 's'): met = Average(output_transform=lambda _: float('nan') if model.s is None else model.s) met.attach(trainer, 'cur_s') trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed, 'cur_s') if hasattr(model, 'arch_sampler') and model.arch_sampler.distrib_dim > 0: met = Average(output_transform=lambda _: float('nan') if model.cur_split is None else model.cur_split) met.attach(trainer, 'Trainer split') trainer.add_event_handler(Events.ITERATION_COMPLETED, met.completed, 'Trainer split') # trainer.add_event_handler(Events.EPOCH_STARTED, met.started) all_ent = Average( output_transform=lambda out: out[-1]['arch_entropy_avg'].item()) all_ent.attach(trainer, 'Trainer all entropy') trainer.add_event_handler(Events.ITERATION_COMPLETED, all_ent.completed, 'Trainer all entropy') train_ent = Average( output_transform=lambda out: out[-1]['arch_entropy_sample'].item()) train_ent.attach(trainer, 'Trainer sampling entropy') trainer.add_event_handler(Events.ITERATION_COMPLETED, train_ent.completed, 'Trainer sampling entropy') trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda engine: model.check_arch_freezing( ent=train_ent.compute(), epoch=engine.state.iteration/(epoch_steps*max_epoch)) ) def log_always(engine, name): val = engine.state.output[-1][name] all_metrics[name][engine.state.iteration/epoch_steps] = val.mean().item() def log_always_dict(engine, name): for node, val in engine.state.output[-1][name].items(): all_metrics['node {} {}'.format(node, name)][engine.state.iteration/epoch_steps] = val.mean().item() trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always_dict, name='arch_grads') trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always_dict, name='arch_probas') trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always_dict, name='node_grads') trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always, name='task all_loss') trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always, name='arch all_loss') trainer.add_event_handler(Events.ITERATION_COMPLETED, log_always, name='entropy all_loss') if n_it_max is not None: StopAfterIterations([n_it_max]).attach(trainer) # epoch_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name, # 
persist=True, disable=not (_run or viz)) # epoch_pbar.attach(trainer, metric_names=['Train loss']) # # training_pbar = ProgressBar(bar_format='{l_bar}{bar}{r_bar}', desc=name, # persist=True, disable=not (_run or viz)) # training_pbar.attach(trainer, event_name=Events.EPOCH_COMPLETED, # closing_event_name=Events.COMPLETED) total_time = Timer(average=False) eval_time = Timer(average=False) eval_time.pause() data_time = Timer(average=False) forward_time = Timer(average=False) forward_time.attach(trainer, start=Events.EPOCH_STARTED, pause=Events.ITERATION_COMPLETED, resume=Events.ITERATION_STARTED, step=Events.ITERATION_COMPLETED) epoch_time = Timer(average=False) epoch_time.attach(trainer, start=Events.EPOCH_STARTED, pause=Events.EPOCH_COMPLETED, resume=Events.EPOCH_STARTED, step=Events.EPOCH_COMPLETED) def get_loss(y_pred, y): l = loss_fn(y_pred, y) if not torch.is_tensor(l): l, *l_details = l return l.mean() def get_member(x, n=0): if isinstance(x, (list, tuple)): return x[n] return x eval_metrics = {'loss': Loss(get_loss)} for i in range(model.n_out): out_trans = get_attr_transform(i) def extract_ys(out): x, y, y_pred, loss, _ = out return out_trans((y_pred, y)) train_acc = Accuracy(extract_ys) train_acc.attach(trainer, 'Trainer accuracy_{}'.format(i)) trainer.add_event_handler(Events.ITERATION_COMPLETED, train_acc.completed, 'Trainer accuracy_{}'.format(i)) eval_metrics['accuracy_{}'.format(i)] = \ Accuracy(output_transform=out_trans) # if isinstance(model, SSNWrapper): # model.arch_sampler.entropy().mean() evaluator = create_supervised_evaluator(model, metrics=eval_metrics, device=device, prepare_batch=prepare_batch) last_iteration = 0 patience_counter = 0 best = {'value': float('inf') * 1 if select_mode == 'min' else -1, 'iter': -1, 'state_dict': None } def is_better(new, old): if select_mode == 'min': return new < old else: return new > old def log_results(evaluator, data_loader, iteration, split_name): evaluator.run(data_loader) metrics = evaluator.state.metrics log_metrics = {} for metric_name, metric_val in metrics.items(): log_name = '{} {}'.format(split_name, metric_name) if viz: first = iteration == 0 and split_name == split_names[0] viz.line([metric_val], X=[iteration], win=metric_name, name=log_name, update=None if first else 'append', opts={'title': metric_name, 'showlegend': True, 'width': 500, 'xlabel': 'iterations'}) viz.line([metric_val], X=[iteration/epoch_steps], win='{}epoch'.format(metric_name), name=log_name, update=None if first else 'append', opts={'title': metric_name, 'showlegend': True, 'width': 500, 'xlabel': 'epoch'}) if _run: _run.log_scalar(log_name, metric_val, iteration) log_metrics[log_name] = metric_val all_metrics[log_name][iteration] = metric_val return log_metrics if lr_scheduler is not None: @trainer.on(Events.EPOCH_COMPLETED) def step(_): lr_scheduler.step() # logger.warning('current lr {:.5e}'.format( # optimizer.param_groups[0]['lr'])) @trainer.on(Events.ITERATION_COMPLETED) def log_event(trainer): iteration = trainer.state.iteration if trainer.state else 0 nonlocal last_iteration, patience_counter, best if not log_steps or not \ (iteration in log_steps or iteration % log_steps[-1] == 0): return epoch_time.pause() eval_time.resume() all_metrics['training_epoch'][iteration] = iteration / epoch_steps all_metrics['training_iteration'][iteration] = iteration if hasattr(model, 'arch_sampler'): all_metrics['training_archs'][iteration] = \ model.arch_sampler().squeeze().detach() # if hasattr(model, 'distrib_gen'): # entropy = model.distrib_gen.entropy() 
# all_metrics['entropy'][iteration] = entropy.mean().item() # if trainer.state and len(trainer.state.metrics) > 1: # raise ValueError(trainer.state.metrics) all_metrics['data time'][iteration] = data_time.value() all_metrics['data time_ps'][iteration] = data_time.value() / max(data_time.step_count, 1.) all_metrics['forward time'][iteration] = forward_time.value() all_metrics['forward time_ps'][iteration] = forward_time.value() / max(forward_time.step_count, 1.) all_metrics['epoch time'][iteration] = epoch_time.value() all_metrics['epoch time_ps'][iteration] = epoch_time.value() / max(epoch_time.step_count, 1.) if trainer.state: # logger.warning(trainer.state.metrics) for metric, value in trainer.state.metrics.items(): all_metrics[metric][iteration] = value if viz: viz.line([value], X=[iteration], win=metric.split()[-1], name=metric, update=None if iteration==0 else 'append', opts={'title': metric, 'showlegend': True, 'width': 500, 'xlabel': 'iterations'}) iter_this_step = iteration - last_iteration for d_loader, name in zip(eval_loaders, split_names): if name == 'Train': if iteration == 0: all_metrics['Trainer loss'][iteration] = float('nan') all_metrics['Trainer accuracy_0'][iteration] = float('nan') if hasattr(model, 'arch_sampler'): all_metrics['Trainer all entropy'][iteration] = float('nan') all_metrics['Trainer sampling entropy'][iteration] = float('nan') # if hasattr(model, 'cur_split'): all_metrics['Trainer split'][iteration] = float('nan') continue split_metrics = log_results(evaluator, d_loader, iteration, name) if select_metric not in split_metrics: continue if is_better(split_metrics[select_metric], best['value']): best['value'] = split_metrics[select_metric] best['iter'] = iteration best['state_dict'] = copy.deepcopy(model.state_dict()) if patience > 0: patience_counter = 0 elif patience > 0: patience_counter += iter_this_step if patience_counter >= patience: logger.info('#####') logger.info('# Early stopping Run') logger.info('#####') trainer.terminate() last_iteration = iteration eval_time.pause() eval_time.step() all_metrics['eval time'][iteration] = eval_time.value() all_metrics['eval time_ps'][iteration] = eval_time.value() / eval_time.step_count all_metrics['total time'][iteration] = total_time.value() epoch_time.resume() log_event(trainer) # # @trainer.on(Events.EPOCH_COMPLETED) # def log_epoch(trainer): # iteration = trainer.state.iteration if trainer.state else 0 # epoch = iteration/epoch_steps # fw_t = forward_time.value() # fw_t_ps = fw_t / forward_time.step_count # d_t = data_time.value() # d_t_ps = d_t / data_time.step_count # e_t = epoch_time.value() # e_t_ps = e_t / epoch_time.step_count # ev_t = eval_time.value() # ev_t_ps = ev_t / eval_time.step_count # logger.warning('<{}> Epoch {}/{} finished (Forward: {:.3f}s({:.3f}), ' # 'data: {:.3f}s({:.3f}), epoch: {:.3f}s({:.3f}),' # ' Eval: {:.3f}s({:.3f}), Total: ' # '{:.3f}s)'.format(type(model).__name__, epoch, # max_epoch, fw_t, fw_t_ps, d_t, d_t_ps, # e_t, e_t_ps, ev_t, ev_t_ps, # total_time.value())) data_time.attach(trainer, start=Events.STARTED, pause=Events.ITERATION_STARTED, resume=Events.ITERATION_COMPLETED, step=Events.ITERATION_STARTED) if hasattr(model, 'iter_per_epoch'): model.iter_per_epoch = len(train_loader) trainer.run(train_loader, max_epochs=max_epoch) return trainer.state.iteration, all_metrics, best
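# --- Aside: the loop above implements patience by hand; ignite also ships an
# EarlyStopping handler that does the same bookkeeping (it appears, commented
# out, in an earlier script in this collection). A minimal sketch with toy
# engines; the 'accuracy_0' metric key mirrors the select_metric naming above
# and is an assumption here.
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)

def score_function(engine):
    # higher is better; EarlyStopping terminates the trainer after `patience`
    # validations without improvement
    return engine.state.metrics['accuracy_0']

handler = EarlyStopping(patience=5, score_function=score_function, trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, handler)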
def main():
    args = parse_args()
    logger.info('Num GPU: {}'.format(num_gpus))
    logger.info('Load Dataset')
    data = get_dataset(args.dataset, args.data_root, args.batch_size)
    data1, _ = data['train'][0]
    dims = list(data1.shape)
    param = dict(zdim=args.zdim, hdim=args.hdim, quant=args.quantization)
    model, optimizer = get_model(args.model, args.learning_rate, param, *dims)
    model = torch.nn.DataParallel(model) if num_gpus > 1 else model
    model.to(device)
    logger.info(model)
    kwargs = {
        'pin_memory': True if use_gpu else False,
        'shuffle': True,
        'num_workers': num_gpus * 4
    }
    logdir = get_logdir_name(args, param)
    logger.info('Log Dir: {}'.format(logdir))
    writer = SummaryWriter(logdir)
    os.makedirs(logdir, exist_ok=True)
    train_loader = DataLoader(data['train'], args.batch_size, **kwargs)
    kwargs['shuffle'] = False
    test_loader = DataLoader(data['test'], args.batch_size, **kwargs)
    if args.quantization:
        q = Quantization(device=device)
    else:
        q = Dummy()

    def get_recon_error(recon, x, sigma):
        if x.shape[1] == 1:
            # Binary image
            ll = Bernoulli(recon).log_prob(x)
        elif x.shape[1] == 3:
            # RGB image
            ll = Normal(recon, sigma).log_prob(x)
        else:
            # the exception must be raised, not merely instantiated
            raise NotImplementedError('X must be either 1 or 3')
        return -ll.sum()

    def step(engine, batch):
        model.train()
        x, _ = batch
        x = x.to(device)
        x_quant = q.preprocess(x)
        recon, kl = model(x_quant)
        nll = get_recon_error(recon, x, sigma(engine.state.epoch, args.sigma_switch))
        loss = nll + kl
        elbo = -loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr = optimizer.param_groups[0]['lr']
        ret = {
            'elbo': elbo.item() / len(x),
            'nll': nll.item() / len(x),
            'kl': kl.item() / len(x),
            'lr': lr,
            'sigma': sigma(engine.state.epoch, args.sigma_switch)
        }
        return ret

    trainer = Engine(step)
    metric_names = ['elbo', 'nll', 'kl', 'lr', 'sigma']
    RunningAverage(output_transform=lambda x: x['elbo']).attach(trainer, 'elbo')
    RunningAverage(output_transform=lambda x: x['nll']).attach(trainer, 'nll')
    RunningAverage(output_transform=lambda x: x['kl']).attach(trainer, 'kl')
    RunningAverage(output_transform=lambda x: x['lr']).attach(trainer, 'lr')
    RunningAverage(output_transform=lambda x: x['sigma']).attach(trainer, 'sigma')
    ProgressBar().attach(trainer, metric_names=metric_names)
    Timer(average=True).attach(trainer)
    add_events(trainer, model, writer, logdir, args.log_interval)

    @trainer.on(Events.EPOCH_COMPLETED)
    def validate(engine):
        model.eval()
        val_elbo = 0
        val_kl = 0
        val_nll = 0
        with torch.no_grad():
            for i, (x, _) in enumerate(test_loader):
                x = x.to(device)
                x_quant = q.preprocess(x)
                recon, kl = model(x_quant)
                nll = get_recon_error(recon, x, sigma(engine.state.epoch, args.sigma_switch))
                loss = nll + kl
                elbo = -loss
                val_elbo += elbo
                val_kl += kl
                val_nll += nll
                if i == 0:
                    batch, *xdims = x.shape
                    row = 8
                    n = min(x.shape[0], row)
                    comparison = torch.cat([x[:n], recon[:n]])
                    grid = make_grid(comparison.detach().cpu().float(), nrow=row)
                    writer.add_image('val/reconstruction', grid, engine.state.iteration)
        val_elbo /= len(test_loader.dataset)
        val_kl /= len(test_loader.dataset)
        val_nll /= len(test_loader.dataset)
        writer.add_scalar('val/elbo', val_elbo.item(), engine.state.iteration)
        writer.add_scalar('val/kl', val_kl.item(), engine.state.iteration)
        writer.add_scalar('val/nll', val_nll.item(), engine.state.iteration)
        print('{:3d} /{:3d} : ELBO: {:.4f}, KL: {:.4f}, NLL: {:.4f}'.format(
            engine.state.epoch, engine.state.max_epochs, val_elbo, val_kl, val_nll))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handler_exception(engine, e):
        writer.close()
        engine.terminate()
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            logger.warning('KeyboardInterrupt caught. Exiting gracefully.')  # logger.warn is deprecated
        else:
            raise e

    logger.info('Start training. Max epoch = {}, Batch = {}, # Trainset = {}'.format(
        args.epoch, args.batch_size, len(data['train'])))
    trainer.run(train_loader, args.epoch)
    logger.info('Done training')
    writer.close()
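# --- Aside: the sigma(epoch, switch) helper used above is not defined in this
# excerpt. A plausible minimal sketch, assuming a simple two-stage schedule
# where the observation std drops after `switch` epochs (names and values here
# are purely illustrative, not the original implementation):
def sigma(epoch, switch, sigma_before=1.0, sigma_after=0.6):
    """Annealed observation std for the Normal reconstruction likelihood."""
    return sigma_before if epoch < switch else sigma_after

assert sigma(1, 10) == 1.0 and sigma(12, 10) == 0.6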
def train():
    parser = ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    # use the GPU whenever one is available (not only on multi-GPU machines)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
    tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    DISTRIBUTED = args.local_rank != -1
    if DISTRIBUTED and torch.distributed.is_available():
        print("Distributed")
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        # BATCH_SIZE *= 2

    def average_distributed_scalar(scalar):
        if not DISTRIBUTED:
            return scalar
        scalar_t = torch.tensor(scalar, dtype=torch.float,
                                device=device) / torch.distributed.get_world_size()
        torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
        return scalar_t.item()

    optimizer = AdamW(model.parameters(), lr=6.25e-5)
    ds = dataloader.Conv_GPT2_DataClass(tokenizer)
    v_ds = dataloader.Conv_GPT2_DataClass(tokenizer, dev=True)
    orig_added_tokens = len(tokenizer.encoder)
    num_added_tokens = tokenizer.add_special_tokens(dataloader.ATTR_SPECIAL_TOKENS)
    if num_added_tokens > 0:
        model.resize_token_embeddings(new_num_tokens=orig_added_tokens + num_added_tokens)
    model = model.to(device)
    train_sampler = torch.utils.data.distributed.DistributedSampler(ds) if DISTRIBUTED else None
    valid_sampler = torch.utils.data.distributed.DistributedSampler(v_ds) if DISTRIBUTED else None
    dl = DataLoader(ds, sampler=train_sampler, batch_size=BATCH_SIZE, shuffle=not DISTRIBUTED)
    v_dl = DataLoader(v_ds, sampler=valid_sampler, shuffle=False)
    metrics = {
        # inference below returns a flat (logits, labels) pair, so the
        # transform must not index into the tensors themselves
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                    output_transform=lambda x: (x[0], x[1])),
    }
    metrics.update({
        "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"]),
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])

    def update(engine, batch):
        model.train()
        batch = tuple(t.to(device) for t in batch)
        lm_loss, *_ = model(batch[0], token_type_ids=batch[1], lm_labels=batch[2])
        loss = lm_loss / ITERATION_STEP
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if engine.state.iteration % ITERATION_STEP == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(t.to(device) for t in batch)
            input_ids, token_type_ids, lm_labels = batch
            model_outputs = model(input_ids, token_type_ids=token_type_ids)
            lm_logits = model_outputs[0]
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    trainer = Engine(update)
    evaluator = Engine(inference)
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, 6.25e-5), (EPOCHS * len(ds) // BATCH_SIZE, 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    # trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(v_dl))
    if DISTRIBUTED:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        # evaluator.add_event_handler(Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch))
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    for name, metric in metrics.items():
        metric.attach(evaluator, name)
    if args.local_rank in [0, -1]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        # evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir='./logs')
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        # tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint('./checkpoint', '_checkpoint', n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'gpt2_qg': getattr(model, 'module', model)})  # getattr unwraps DistributedDataParallel
        getattr(model, 'module', model).config.to_json_file(os.path.join('./checkpoint', 'config'))
        tokenizer.save_pretrained('./checkpoint')

    trainer.run(dl, max_epochs=EPOCHS)

    if args.local_rank in [0, -1]:
        tb_logger.close()
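# The update() above folds gradient accumulation into the Ignite step: the loss is
# scaled by 1/ITERATION_STEP and the optimizer only steps on every ITERATION_STEP-th
# iteration. A stripped-down, framework-free sketch of the same idea -- model, loader
# and ACCUM_STEPS here are placeholders, not names from this file:
import torch
import torch.nn.functional as F

ACCUM_STEPS = 4  # hypothetical accumulation factor, analogous to ITERATION_STEP above

def train_one_epoch(model, loader, optimizer):
    model.train()
    optimizer.zero_grad()
    for i, (inputs, labels) in enumerate(loader, start=1):
        loss = F.cross_entropy(model(inputs), labels)
        (loss / ACCUM_STEPS).backward()  # scale so accumulated grads average rather than sum
        if i % ACCUM_STEPS == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()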
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    # parser.add_argument("--model_checkpoint", type=str, default="/home/rohola/codes/transfer-learning-conv-ai/runs/Jun18_10-40-49_rohola-pc", help="Path, url or short name of the model")
    parser.add_argument("--model_checkpoint", type=str, default="openai-gpt",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=2, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=1, help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=20, help="Number of training epochs")
    parser.add_argument("--personality_permutations", type=int, default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda:0" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--log_dir", type=str, default="",
                        help="Tensorboard log directory")  # fixed: help text was copy-pasted from --local_rank
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for the main (resp. auxiliary) process;
    # logger.info => logged by the main process only, logger.warning => logged by all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # a warning, so it is printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(args.device)
    optimizer = OpenAIAdam(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order matters: distributed wrapping should come last)
    if args.fp16:
        from apex import amp  # Apex is only required for fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        # persona_ids, history_ids, reply_ids, mc_token_ids, lm_labels, mc_labels, history_token_type = batch
        lm_loss, mc_loss = model(*batch)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            # input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            persona_ids, history_ids, reply_ids, mc_token_ids, lm_labels, mc_labels, history_token_type = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids, past=engine.state.past)
            model_outputs = model(persona_ids, history_ids, reply_ids, mc_token_ids,
                                  history_token_type=history_token_type)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]
            # engine.state.presents = presents
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)
    # Attach evaluation to the trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                    output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])),
    }
    metrics.update({
        "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
        "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args),
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints,
    # and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(Events.COMPLETED,
                                    lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=args.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation
        torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close the tensorboard logger and rename the last checkpoint
    # (for easy re-loading with the OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite for cleaner access to saved file paths
        tb_logger.close()
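# The metrics above call average_distributed_scalar(metric, args), which is not
# defined anywhere in this excerpt. A sketch of what it presumably looks like,
# mirroring the single-argument variant defined in the previous train() -- an
# assumption based on that code, not the verified original helper:
def average_distributed_scalar(scalar, args):
    # Average a scalar over all distributed processes; identity on a single process.
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()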
def main(dataset, dataroot, z_dim, g_filters, d_filters, batch_size, epochs,
         learning_rate, beta_1, saved_G, saved_D, seed, n_workers, device, alpha, output_dir):
    # seed
    check_manual_seed(seed)

    # data
    dataset, num_channels = check_dataset(dataset, dataroot)
    loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True,
                             num_workers=n_workers, drop_last=True)

    # networks
    netG = Generator(z_dim, g_filters, num_channels).to(device)
    netD = Discriminator(num_channels, d_filters).to(device)

    # criterion
    bce = nn.BCELoss()

    # optimizers
    optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999))
    optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999))

    # load pre-trained models
    if saved_G:
        netG.load_state_dict(torch.load(saved_G))
    if saved_D:
        netD.load_state_dict(torch.load(saved_D))

    # misc
    real_labels = torch.ones(batch_size, device=device)
    fake_labels = torch.zeros(batch_size, device=device)
    fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device)

    def get_noise():
        return torch.randn(batch_size, z_dim, 1, 1, device=device)

    # The main function, processing a batch of examples
    def step(engine, batch):
        # Unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard the labels.
        real, _ = batch
        real = real.to(device)

        # -----------------------------------------------------------
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        netD.zero_grad()

        # train with real
        output = netD(real)
        errD_real = bce(output, real_labels)
        D_x = output.mean().item()
        errD_real.backward()

        # get fake image from generator
        noise = get_noise()
        fake = netG(noise)

        # train with fake
        output = netD(fake.detach())
        errD_fake = bce(output, fake_labels)
        D_G_z1 = output.mean().item()
        errD_fake.backward()

        # gradient update
        errD = errD_real + errD_fake
        optimizerD.step()

        # -----------------------------------------------------------
        # (2) Update G network: maximize log(D(G(z)))
        netG.zero_grad()
        # Update the generator: take a step that makes the discriminator more likely to output "real"
        output = netD(fake)
        errG = bce(output, real_labels)
        D_G_z2 = output.mean().item()
        errG.backward()

        # gradient update
        optimizerG.step()

        return {"errD": errD.item(), "errG": errG.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2}

    # ignite objects
    trainer = Engine(step)
    checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False)
    timer = Timer(average=True)

    # attach running average metrics
    monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"]
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach(trainer, "errD")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach(trainer, "errG")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_x"]).attach(trainer, "D_x")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach(trainer, "D_G_z1")
    RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach(trainer, "D_G_z2")

    # attach progress bar
    pbar = ProgressBar()
    pbar.attach(trainer, metric_names=monitoring_metrics)

    @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ))
    def print_logs(engine):
        fname = os.path.join(output_dir, LOGS_FNAME)
        columns = ["iteration"] + list(engine.state.metrics.keys())
        values = [str(engine.state.iteration)] + [str(round(value, 5)) for value in engine.state.metrics.values()]
        with open(fname, "a") as f:
            if f.tell() == 0:
                print("\t".join(columns), file=f)
            print("\t".join(values), file=f)
        message = "[{epoch}/{max_epoch}][{i}/{max_i}]".format(
            epoch=engine.state.epoch, max_epoch=epochs,
            i=(engine.state.iteration - 1) % len(loader) + 1,  # fixed: 1-based index within the epoch (was 0 at epoch end)
            max_i=len(loader))
        for name, value in zip(columns, values):
            message += " | {name}: {value}".format(name=name, value=value)
        pbar.log_message(message)

    # adding handlers using the `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_fake_example(engine):
        fake = netG(fixed_noise)
        path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(fake.detach(), path, normalize=True)

    # adding handlers using the `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def save_real_example(engine):
        img, y = engine.state.batch
        path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch))
        vutils.save_image(img, path, normalize=True)

    # adding handlers using the `trainer.add_event_handler` method API
    trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                              handler=checkpoint_handler,
                              to_save={"netG": netG, "netD": netD})

    # automatically adding handlers via the special `attach` method of the `Timer` handler
    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    # adding handlers using the `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message("Epoch {} done. Time per batch: {:.3f}[s]".format(engine.state.epoch, timer.value()))
        timer.reset()

    # adding handlers using the `trainer.on` decorator API
    @trainer.on(Events.EPOCH_COMPLETED)
    def create_plots(engine):
        try:
            import matplotlib as mpl
            mpl.use("agg")
            import pandas as pd
            import matplotlib.pyplot as plt
        except ImportError:
            warnings.warn("Loss plots will not be generated -- pandas or matplotlib not found")
        else:
            df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter="\t", index_col="iteration")
            _ = df.plot(subplots=True, figsize=(20, 20))
            _ = plt.xlabel("Iteration number")
            fig = plt.gcf()
            path = os.path.join(output_dir, PLOT_FNAME)
            fig.savefig(path)

    # adding handlers using the `trainer.on` decorator API
    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")
            create_plots(engine)
            checkpoint_handler(engine, {"netG_exception": netG, "netD_exception": netD})
        else:
            raise e

    # Setup is done. Now let's run the training
    trainer.run(loader, max_epochs=epochs)
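# Since main() accepts saved_G/saved_D, resuming from the checkpoints written by
# checkpoint_handler is just a matter of passing those paths back in. A hypothetical
# resume call -- every value below is illustrative only, and the checkpoint filenames
# depend on CKPT_PREFIX and the ignite version, so check what actually landed in
# output_dir before copying this:
if __name__ == "__main__":
    main(dataset="cifar10", dataroot="./data", z_dim=100, g_filters=64, d_filters=64,
         batch_size=64, epochs=25, learning_rate=2e-4, beta_1=0.5,
         saved_G=os.path.join("./out", "ckpt_netG_20.pt"),    # illustrative filename
         saved_D=os.path.join("./out", "ckpt_netD_20.pt"),    # illustrative filename
         seed=23, n_workers=2, device="cuda", alpha=0.98, output_dir="./out")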