def training(config, local_rank=None, with_mlflow_logging=False, with_plx_logging=False): if not getattr(config, "use_fp16", True): raise RuntimeError("This training script uses by default fp16 AMP") set_seed(config.seed + local_rank) torch.cuda.set_device(local_rank) device = 'cuda' torch.backends.cudnn.benchmark = True train_loader = config.train_loader train_sampler = getattr(train_loader, "sampler", None) assert train_sampler is not None, "Train loader of type '{}' " \ "should have attribute 'sampler'".format(type(train_loader)) assert hasattr(train_sampler, 'set_epoch') and callable(train_sampler.set_epoch), \ "Train sampler should have a callable method `set_epoch`" train_eval_loader = config.train_eval_loader val_loader = config.val_loader model = config.model.to(device) optimizer = config.optimizer model, optimizer = amp.initialize(model, optimizer, opt_level=getattr( config, "fp16_opt_level", "O2"), num_losses=1) model = DDP(model, delay_allreduce=True) criterion = config.criterion.to(device) prepare_batch = getattr(config, "prepare_batch", _prepare_batch) non_blocking = getattr(config, "non_blocking", True) # Setup trainer accumulation_steps = getattr(config, "accumulation_steps", 1) model_output_transform = getattr(config, "model_output_transform", lambda x: x) def train_update_function(engine, batch): model.train() x, y = prepare_batch(batch, device=device, non_blocking=non_blocking) y_pred = model(x) y_pred = model_output_transform(y_pred) loss = criterion(y_pred, y) if isinstance(loss, Mapping): assert 'supervised batch loss' in loss loss_dict = loss output = {k: v.item() for k, v in loss_dict.items()} loss = loss_dict['supervised batch loss'] / accumulation_steps else: output = {'supervised batch loss': loss.item()} with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss: scaled_loss.backward() if engine.state.iteration % accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return output output_names = getattr(config, "output_names", [ 'supervised batch loss', ]) trainer = Engine(train_update_function) common.setup_common_distrib_training_handlers( trainer, train_sampler, to_save={ 'model': model, 'optimizer': optimizer }, save_every_iters=1000, output_path=config.output_path.as_posix(), lr_scheduler=config.lr_scheduler, with_gpu_stats=True, output_names=output_names, with_pbars=True, with_pbar_on_iters=with_mlflow_logging, log_every_iters=1) # Setup evaluators num_classes = config.num_classes cm_metric = ConfusionMatrix(num_classes=num_classes) val_metrics = { "IoU": IoU(cm_metric), "mIoU_bg": mIoU(cm_metric), } if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict): val_metrics.update(config.val_metrics) model_output_transform = getattr(config, "model_output_transform", lambda x: x) evaluator_args = dict(model=model, metrics=val_metrics, device=device, non_blocking=non_blocking, prepare_batch=prepare_batch, output_transform=lambda x, y, y_pred: ( model_output_transform(y_pred), y, )) train_evaluator = create_supervised_evaluator(**evaluator_args) evaluator = create_supervised_evaluator(**evaluator_args) if dist.get_rank() == 0 and with_mlflow_logging: ProgressBar(persist=False, desc="Train Evaluation").attach(train_evaluator) ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator) def run_validation(_): train_evaluator.run(train_eval_loader) evaluator.run(val_loader) if getattr(config, "start_by_validation", False): trainer.add_event_handler(Events.STARTED, run_validation) trainer.add_event_handler( 
Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1)), run_validation) trainer.add_event_handler(Events.COMPLETED, run_validation) score_metric_name = "mIoU_bg" if hasattr(config, "es_patience"): common.add_early_stopping_by_val_score(config.es_patience, evaluator, trainer, metric_name=score_metric_name) if dist.get_rank() == 0: tb_logger = common.setup_tb_logging(config.output_path.as_posix(), trainer, optimizer, evaluators={ "training": train_evaluator, "validation": evaluator }) if with_mlflow_logging: common.setup_mlflow_logging(trainer, optimizer, evaluators={ "training": train_evaluator, "validation": evaluator }) if with_plx_logging: common.setup_plx_logging(trainer, optimizer, evaluators={ "training": train_evaluator, "validation": evaluator }) common.save_best_model_by_val_score(config.output_path.as_posix(), evaluator, model, metric_name=score_metric_name, trainer=trainer) # Log train/val predictions: tb_logger.attach(evaluator, log_handler=predictions_gt_images_handler( img_denormalize_fn=config.img_denormalize, n_images=15, another_engine=trainer, prefix_tag="validation"), event_name=Events.EPOCH_COMPLETED) log_train_predictions = getattr(config, "log_train_predictions", False) if log_train_predictions: tb_logger.attach(train_evaluator, log_handler=predictions_gt_images_handler( img_denormalize_fn=config.img_denormalize, n_images=15, another_engine=trainer, prefix_tag="validation"), event_name=Events.EPOCH_COMPLETED) trainer.run(train_loader, max_epochs=config.num_epochs)
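# --- Hedged side sketch (not part of the training script above) ---
# The validation metrics above hang IoU and mIoU off a single shared
# ConfusionMatrix, so the matrix is accumulated once per evaluation run and both
# scores are derived from it. Minimal self-contained toy (3 classes, random
# tensors), assuming ignite's stock ConfusionMatrix/IoU/mIoU helpers:
import torch
from ignite.engine import Engine
from ignite.metrics import ConfusionMatrix, IoU, mIoU

def eval_step(engine, batch):
    y_pred, y = batch              # (N, C, H, W) logits, (N, H, W) labels
    return y_pred, y

toy_evaluator = Engine(eval_step)
cm = ConfusionMatrix(num_classes=3)
IoU(cm).attach(toy_evaluator, "IoU")     # per-class IoU tensor
mIoU(cm).attach(toy_evaluator, "mIoU")   # mean over classes

toy_batch = (torch.rand(4, 3, 8, 8), torch.randint(0, 3, (4, 8, 8)))
state = toy_evaluator.run([toy_batch], max_epochs=1)
print(state.metrics["IoU"], state.metrics["mIoU"])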
def attach_decorators(trainer, SR, feature_extractor, domain_classifier, resolution_classifier, sr_classif_critic, optim, loader): timer = Timer(average=True) checkpoint_handler = ModelCheckpoint( args.output_dir + '/checkpoints/domain_adaptation_training/', 'training', save_interval=1, n_saved=300, require_empty=False, iteration=args.epoch_c) monitoring_metrics = [ 'tgt_loss', 'src_loss', 'sr_loss', 'loss', 'GP', 'res_down_loss', 'res_up_loss', 'tv_loss', 'vgg_loss' ] RunningAverage(alpha=0.98, output_transform=lambda x: x['tgt_loss']).attach( trainer, 'tgt_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['src_loss']).attach( trainer, 'src_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['sr_loss']).attach( trainer, 'sr_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['loss']).attach( trainer, 'loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['GP']).attach(trainer, 'GP') # RunningAverage(alpha=0.98, output_transform=lambda x: x['g_loss']).attach(trainer, 'g_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['res_down_loss']).attach( trainer, 'res_down_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['res_up_loss']).attach( trainer, 'res_up_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['tv_loss']).attach( trainer, 'tv_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['vgg_loss']).attach( trainer, 'vgg_loss') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) trainer.add_event_handler( event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'feature_extractor': feature_extractor, 'SR': SR, # 'optim_feature': optim_feature, # 'optim_domain_classif': optim_domain_classif, # 'optim_res_classif': optim_res_classif, 'optim': optim, # 'optim_sr_critic': optim_sr_critic, 'domain_D': domain_classifier, 'res_D': resolution_classifier, 'sr_D': sr_classif_critic }) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.ITERATION_COMPLETED) def print_logs(engine): if (engine.state.iteration - 1) % PRINT_FREQ == 0: fname = os.path.join(args.output_dir, LOGS_FNAME) columns = engine.state.metrics.keys() values = [ str(round(value, 5)) for value in engine.state.metrics.values() ] with open(fname, 'a') as f: if f.tell() == 0: print('\t'.join(columns), file=f) print('\t'.join(values), file=f) i = (engine.state.iteration % len(loader)) message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format( epoch=engine.state.epoch, max_epoch=args.epochs, i=i, max_i=len(loader)) for name, value in zip(columns, values): message += ' | {name}: {value}'.format(name=name, value=value) pbar.log_message(message) @trainer.on(Events.ITERATION_COMPLETED) def save_real_example(engine): if (engine.state.iteration - 1) % PRINT_FREQ == 0: if (engine.state.iteration - 1) % PRINT_FREQ == 0: if not os.path.exists(args.output_dir + '/imgs/domain_adaptation_training/'): os.makedirs(args.output_dir + '/imgs/domain_adaptation_training/') px, py, px2, py2, px_up, _, px2_up, _ = engine.state.batch img = SR(feature_extractor(px2.cuda())) path = os.path.join( args.output_dir + '/imgs/domain_adaptation_training/', predtgt_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(img, path) path = os.path.join( args.output_dir + '/imgs/domain_adaptation_training/', targetY_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(py2, path) path = 
os.path.join( args.output_dir + '/imgs/domain_adaptation_training/', targetX_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(px2, path) path = os.path.join( args.output_dir + '/imgs/domain_adaptation_training/', sourceX_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(px, path) path = os.path.join( args.output_dir + '/imgs/domain_adaptation_training/', sourceY_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(py, path) img = SR(feature_extractor(px.cuda())) path = os.path.join( args.output_dir + '/imgs/domain_adaptation_training/', predsrc_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(img, path) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message('Epoch {} done. Time per batch: {:.3f}[s]'.format( engine.state.epoch, timer.value())) timer.reset() @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn('KeyboardInterrupt caught. Exiting gracefully.') checkpoint_handler( engine, { 'feature_extractor_{}'.format(engine.state.iteration): feature_extractor, 'SR_{}'.format(engine.state.iteration): SR, 'DOMAIN_D_{}'.format(engine.state.iteration): domain_classifier, 'RES_D_{}'.format(engine.state.iteration): resolution_classifier, 'SR_D_{}'.format(engine.state.iteration): sr_classif_critic, 'OPTIM_{}'.format(engine.state.iteration): optim }) else: raise e @trainer.on(Events.STARTED) def loaded(engine): if args.epoch_c != 0: engine.state.epoch = args.epoch_c engine.state.iteration = args.epoch_c * len(loader)
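# --- Hedged side sketch ---
# The block of near-identical RunningAverage(...).attach(...) calls above can be
# generated in a loop, since only the output key changes. Self-contained toy
# trainer standing in for the real one; alpha and key names copied from above.
import torch
from ignite.engine import Engine
from ignite.metrics import RunningAverage

keys = ['tgt_loss', 'src_loss', 'sr_loss', 'loss', 'GP',
        'res_down_loss', 'res_up_loss', 'tv_loss', 'vgg_loss']
toy_trainer = Engine(lambda engine, batch: {k: torch.rand(()).item() for k in keys})
for name in keys:
    # bind `name` through a default argument so each lambda keeps its own key
    RunningAverage(alpha=0.98,
                   output_transform=lambda out, key=name: out[key]).attach(toy_trainer, name)
toy_trainer.run([None] * 4, max_epochs=1)
print(toy_trainer.state.metrics)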
engine.add_event_handler(
        ignite.engine.Events.ITERATION_STARTED,
        handlers.calculate_gdl_lambda,
    )
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED(every=args.log_every),
        handlers.log_summaries,
        torch.utils.tensorboard.SummaryWriter(logs_directory),
    )
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED(every=args.checkpoint_every),
        handlers.save_checkpoint,
        model, optimizer, lr_scheduler, amp,
        args.checkpoint_last, checkpoint_directory,
    )
else:
    engine.add_event_handler(
        ignite.engine.Events.ITERATION_COMPLETED,
        handlers.save_output,
        outputs_directory,
    )

pbar = ProgressBar()
pbar.attach(engine, output_transform=lambda output: {"loss": output["loss"]})

e = engine.run(data=data_loader,
               max_epochs=args.epochs if args.mode == utils.TRAINING else 1)
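# --- Hedged side sketch ---
# The handlers above rely on two ignite features worth isolating: filtered
# events (`Events.ITERATION_COMPLETED(every=n)`) and the fact that extra
# positional arguments passed to add_event_handler are forwarded to the handler
# after the engine. Toy engine and an arbitrary `every=10`:
from ignite.engine import Engine, Events

def log_every_n(engine, prefix):
    print(f"{prefix}: iteration {engine.state.iteration}")

toy_engine = Engine(lambda engine, batch: None)
toy_engine.add_event_handler(Events.ITERATION_COMPLETED(every=10), log_every_n, "heartbeat")
toy_engine.run(range(25), max_epochs=1)   # prints at iterations 10 and 20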
device = torch.device(args.device)

tfms = albu.Compose([
    albu.Resize(256, 256),
    albu.CenterCrop(224, 224),
    albu.Normalize(),
    ToTensor(),
])

dataset_dir = os.path.join(os.environ.get('DATASET_DIR'), 'imagenet')
dataset = Imagenet(root_dir=dataset_dir, split='val', transforms=tfms)
# validation split, evaluated without shuffling
val_loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=False, num_workers=8)

model = (getattr(vgg, args.model))(3, 1000)
state_dict = torch.load(args.state_dict)
model.load_state_dict(state_dict)
model = model.to(device)  # respect --device instead of hard-coding .cuda()

evaluator = create_classification_evaluator(model, device=device)
ProgressBar(persist=True).attach(evaluator)

state = evaluator.run(val_loader)
print(state.metrics)
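# --- Hedged side sketch ---
# `create_classification_evaluator` is a project helper that is not shown here;
# an equivalent built only from stock ignite pieces (top-1/top-5 accuracy) might
# look like this. Names below are illustrative, not the project's API.
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Accuracy, TopKCategoricalAccuracy

def make_classification_evaluator(model, device):
    metrics = {
        "accuracy@1": Accuracy(),
        "accuracy@5": TopKCategoricalAccuracy(k=5),
    }
    return create_supervised_evaluator(model, metrics=metrics, device=device)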
def train(self, config, **kwargs): """Trains a model on the given configurations. :param config: A training configuration. Note that all parameters in the config can also be manually adjusted with --ARG=VALUE :param **kwargs: parameters to overwrite yaml config """ from pycocoevalcap.cider.cider import Cider conf = train_util.parse_config_or_kwargs(config, **kwargs) conf["seed"] = self.seed outputdir = os.path.join( conf["outputpath"], conf["model"], "{}_{}".format( datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%m'), uuid.uuid1().hex)) # Early init because of creating dir checkpoint_handler = ModelCheckpoint( outputdir, "run", n_saved=1, require_empty=False, create_dir=True, score_function=lambda engine: engine.state.metrics["score"], score_name="loss") logger = train_util.genlogger(os.path.join(outputdir, "train.log")) # print passed config parameters logger.info("Storing files in: {}".format(outputdir)) train_util.pprint_dict(conf, logger.info) zh = conf["zh"] vocabulary = torch.load(conf["vocab_file"]) train_loader, val_loader, info = self._get_dataloaders(conf, vocabulary) conf["inputdim"] = info["inputdim"] val_key2refs = info["val_key2refs"] logger.info("<== Estimating Scaler ({}) ==>".format(info["scaler"].__class__.__name__)) logger.info( "Feature: {} Input dimension: {} Vocab Size: {}".format( conf["feature_file"], info["inputdim"], len(vocabulary))) model = self._get_model(conf, len(vocabulary)) model = model.to(self.device) train_util.pprint_dict(model, logger.info, formatter="pretty") optimizer = getattr( torch.optim, conf["optimizer"] )(model.parameters(), **conf["optimizer_args"]) train_util.pprint_dict(optimizer, logger.info, formatter="pretty") XE_criterion = torch.nn.CrossEntropyLoss().to(self.device) seq_criterion = torch.nn.CosineEmbeddingLoss().to(self.device) crtrn_imprvd = train_util.criterion_improver(conf['improvecriterion']) def _train_batch(engine, batch): model.train() with torch.enable_grad(): optimizer.zero_grad() output = self._forward( model, batch, "train", ss_ratio=conf["ss_args"]["ss_ratio"]) XE_loss = XE_criterion(output["packed_logits"], output["word_targets"]) seq_loss = seq_criterion(output["seq_outputs"], output["sentence_targets"], torch.ones(batch[0].shape[0]).to(self.device)) loss = XE_loss + seq_loss * conf["seq_loss_ratio"] loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5) optimizer.step() output["XE_loss"] = XE_loss.item() output["seq_loss"] = seq_loss.item() output["loss"] = loss.item() return output trainer = Engine(_train_batch) RunningAverage(output_transform=lambda x: x["loss"]).attach(trainer, "running_loss") pbar = ProgressBar(persist=False, ascii=True, ncols=100) pbar.attach(trainer, ["running_loss"]) key2pred = {} def _inference(engine, batch): model.eval() keys = batch[3] with torch.no_grad(): output = self._forward(model, batch, "validation") output["seq_loss"] = seq_criterion(output["seq_outputs"], output["sentence_targets"], torch.ones(len(keys)).to(self.device)) seqs = output["seqs"].cpu().numpy() for (idx, seq) in enumerate(seqs): if keys[idx] in key2pred: continue candidate = self._convert_idx2sentence(seq, vocabulary, zh) key2pred[keys[idx]] = [candidate,] return output metrics = { "loss": Average(output_transform=lambda x: x["loss"]), "XE_loss": Average(output_transform=lambda x: x["XE_loss"]), "seq_loss": Average(output_transform=lambda x: x["seq_loss"]), } evaluator = Engine(_inference) def eval_val(engine, key2pred, key2refs): scorer = Cider(zh=zh) score, scores = scorer.compute_score(key2refs, 
key2pred) engine.state.metrics["score"] = score key2pred.clear() evaluator.add_event_handler( Events.EPOCH_COMPLETED, eval_val, key2pred, val_key2refs) for name, metric in metrics.items(): metric.attach(trainer, name) metrics["seq_loss"].attach(evaluator, "seq_loss") trainer.add_event_handler( Events.EPOCH_COMPLETED, train_util.log_results, evaluator, val_loader, logger.info, metrics.keys(), ["seq_loss", "score"]) if conf["ss"]: trainer.add_event_handler( Events.GET_BATCH_COMPLETED, train_util.update_ss_ratio, conf, len(train_loader)) evaluator.add_event_handler( Events.EPOCH_COMPLETED, train_util.save_model_on_improved, crtrn_imprvd, "score", { "model": model.state_dict(), "config": conf, "scaler": info["scaler"] }, os.path.join(outputdir, "saved.pth")) scheduler = getattr(torch.optim.lr_scheduler, conf["scheduler"])( optimizer, **conf["scheduler_args"]) evaluator.add_event_handler( Events.EPOCH_COMPLETED, train_util.update_lr, scheduler, "score") evaluator.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, { "model": model, } ) trainer.run(train_loader, max_epochs=conf["epochs"]) return outputdir
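# --- Hedged side sketch ---
# The checkpointing above keeps the single best model according to
# score_function (higher is better); score_name only labels the saved file, so
# keeping it consistent with what score_function returns ("score" here, rather
# than the "loss" label used above) avoids confusing file names. Temp directory
# and the commented-out attachment are placeholders, not the project's paths.
import tempfile
from ignite.handlers import ModelCheckpoint

best_ckpt = ModelCheckpoint(
    tempfile.mkdtemp(), "run",
    n_saved=1, create_dir=True, require_empty=False,
    score_function=lambda engine: engine.state.metrics["score"],
    score_name="score",
)
# evaluator.add_event_handler(Events.EPOCH_COMPLETED, best_ckpt, {"model": model})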
def main( dataset, dataroot, z_dim, g_filters, d_filters, batch_size, epochs, learning_rate, beta_1, saved_G, saved_D, seed, n_workers, device, alpha, output_dir, ): # seed check_manual_seed(seed) # data dataset, num_channels = check_dataset(dataset, dataroot) loader = data.DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True) # netowrks netG = Generator(z_dim, g_filters, num_channels).to(device) netD = Discriminator(num_channels, d_filters).to(device) # criterion bce = nn.BCELoss() # optimizers optimizerG = optim.Adam(netG.parameters(), lr=learning_rate, betas=(beta_1, 0.999)) optimizerD = optim.Adam(netD.parameters(), lr=learning_rate, betas=(beta_1, 0.999)) # load pre-trained models if saved_G: netG.load_state_dict(torch.load(saved_G)) if saved_D: netD.load_state_dict(torch.load(saved_D)) # misc real_labels = torch.ones(batch_size, device=device) fake_labels = torch.zeros(batch_size, device=device) fixed_noise = torch.randn(batch_size, z_dim, 1, 1, device=device) def get_noise(): return torch.randn(batch_size, z_dim, 1, 1, device=device) # The main function, processing a batch of examples def step(engine, batch): # unpack the batch. It comes from a dataset, so we have <images, labels> pairs. Discard labels. real, _ = batch real = real.to(device) # ----------------------------------------------------------- # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z))) netD.zero_grad() # train with real output = netD(real) errD_real = bce(output, real_labels) D_x = output.mean().item() errD_real.backward() # get fake image from generator noise = get_noise() fake = netG(noise) # train with fake output = netD(fake.detach()) errD_fake = bce(output, fake_labels) D_G_z1 = output.mean().item() errD_fake.backward() # gradient update errD = errD_real + errD_fake optimizerD.step() # ----------------------------------------------------------- # (2) Update G network: maximize log(D(G(z))) netG.zero_grad() # Update generator. 
We want to make a step that will make it more likely that discriminator outputs "real" output = netD(fake) errG = bce(output, real_labels) D_G_z2 = output.mean().item() errG.backward() # gradient update optimizerG.step() return { "errD": errD.item(), "errG": errG.item(), "D_x": D_x, "D_G_z1": D_G_z1, "D_G_z2": D_G_z2 } # ignite objects trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, CKPT_PREFIX, n_saved=10, require_empty=False) timer = Timer(average=True) # attach running average metrics monitoring_metrics = ["errD", "errG", "D_x", "D_G_z1", "D_G_z2"] RunningAverage(alpha=alpha, output_transform=lambda x: x["errD"]).attach( trainer, "errD") RunningAverage(alpha=alpha, output_transform=lambda x: x["errG"]).attach( trainer, "errG") RunningAverage(alpha=alpha, output_transform=lambda x: x["D_x"]).attach(trainer, "D_x") RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z1"]).attach( trainer, "D_G_z1") RunningAverage(alpha=alpha, output_transform=lambda x: x["D_G_z2"]).attach( trainer, "D_G_z2") # attach progress bar pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) @trainer.on(Events.ITERATION_COMPLETED(every=PRINT_FREQ)) def print_logs(engine): fname = os.path.join(output_dir, LOGS_FNAME) columns = [ "iteration", ] + list(engine.state.metrics.keys()) values = [ str(engine.state.iteration), ] + [str(round(value, 5)) for value in engine.state.metrics.values()] with open(fname, "a") as f: if f.tell() == 0: print("\t".join(columns), file=f) print("\t".join(values), file=f) message = f"[{engine.state.epoch}/{epochs}][{engine.state.iteration % len(loader)}/{len(loader)}]" for name, value in zip(columns, values): message += f" | {name}: {value}" pbar.log_message(message) # adding handlers using `trainer.on` decorator API @trainer.on(Events.EPOCH_COMPLETED) def save_fake_example(engine): fake = netG(fixed_noise) path = os.path.join(output_dir, FAKE_IMG_FNAME.format(engine.state.epoch)) vutils.save_image(fake.detach(), path, normalize=True) # adding handlers using `trainer.on` decorator API @trainer.on(Events.EPOCH_COMPLETED) def save_real_example(engine): img, y = engine.state.batch path = os.path.join(output_dir, REAL_IMG_FNAME.format(engine.state.epoch)) vutils.save_image(img, path, normalize=True) # adding handlers using `trainer.add_event_handler` method API trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ "netG": netG, "netD": netD }) # automatically adding handlers via a special `attach` method of `Timer` handler timer.attach( trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED, ) # adding handlers using `trainer.on` decorator API @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f"Epoch {engine.state.epoch} done. 
Time per batch: {timer.value():.3f}[s]" ) timer.reset() # adding handlers using `trainer.on` decorator API @trainer.on(Events.EPOCH_COMPLETED) def create_plots(engine): try: import matplotlib as mpl mpl.use("agg") import matplotlib.pyplot as plt import numpy as np import pandas as pd except ImportError: warnings.warn( "Loss plots will not be generated -- pandas or matplotlib not found" ) else: df = pd.read_csv(os.path.join(output_dir, LOGS_FNAME), delimiter="\t", index_col="iteration") _ = df.plot(subplots=True, figsize=(20, 20)) _ = plt.xlabel("Iteration number") fig = plt.gcf() path = os.path.join(output_dir, PLOT_FNAME) fig.savefig(path) # adding handlers using `trainer.on` decorator API @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn("KeyboardInterrupt caught. Exiting gracefully.") create_plots(engine) checkpoint_handler(engine, { "netG_exception": netG, "netD_exception": netD }) else: raise e # Setup is done. Now let's run the training trainer.run(loader, epochs)
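# --- Hedged side sketch ---
# The Timer wiring used above, isolated: with average=True, timer.value() is the
# mean length of the resume/pause intervals recorded by step(), i.e. the average
# time per batch. Toy engine that just sleeps a little per iteration.
import time
from ignite.engine import Engine, Events
from ignite.handlers import Timer

toy_engine = Engine(lambda engine, batch: time.sleep(0.01))
batch_timer = Timer(average=True)
batch_timer.attach(toy_engine,
                   start=Events.EPOCH_STARTED,
                   resume=Events.ITERATION_STARTED,
                   pause=Events.ITERATION_COMPLETED,
                   step=Events.ITERATION_COMPLETED)

@toy_engine.on(Events.EPOCH_COMPLETED)
def report_times(engine):
    print(f"~{batch_timer.value():.3f}s per batch")
    batch_timer.reset()

toy_engine.run(range(5), max_epochs=1)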
def run(conf: DictConfig): epochs = conf.train.epochs epoch_length = conf.train.epoch_length torch.manual_seed(conf.general.seed) dist_conf = conf.distributed local_rank = dist_conf.local_rank backend = dist_conf.backend distributed = backend is not None use_tpu = conf.tpu.enabled if use_tpu: rank = xm.get_ordinal() num_replicas = xm.xrt_world_size() device = xm.xla_device() else: if distributed: rank = dist.get_rank() num_replicas = dist.get_world_size() torch.cuda.set_device(local_rank) else: rank = 0 num_replicas = 1 torch.cuda.set_device(conf.general.gpu) device = torch.device('cuda') if rank == 0: print(conf.pretty()) if num_replicas > 1: epoch_length = epoch_length // num_replicas loader_args = dict(rank=rank, num_replicas=num_replicas) else: loader_args = dict() train_dl = create_train_loader(conf.data.train, epoch_length=epoch_length, **loader_args) valid_dl = create_val_loader(conf.data.val, **loader_args) train_sampler = train_dl.sampler if epoch_length < 1: epoch_length = len(train_dl) if use_tpu: train_dl = pl.ParallelLoader(train_dl, [device]) valid_dl = pl.ParallelLoader(valid_dl, [device]) model = instantiate(conf.model).to(device) if distributed: model = DistributedDataParallel(model, device_ids=[ local_rank, ], output_device=local_rank) model.to_y = model.module.to_y if rank == 0 and conf.logging.model: print(model) loss = instantiate(conf.loss) optim = instantiate(conf.optimizer, filter(lambda x: x.requires_grad, model.parameters())) metrics = create_metrics(loss.keys(), device if distributed else None) build_trainer_fn = create_tpu_trainer if use_tpu else create_trainer trainer = build_trainer_fn(model, loss, optim, device, conf, metrics) evaluator = create_evaluator(model, loss, device, metrics) every_iteration = Events.ITERATION_COMPLETED if 'lr_scheduler' in conf.keys(): # TODO: total_steps is wrong, it works only for one-cycle lr_scheduler = instantiate(conf.lr_scheduler, optim, total_steps=epoch_length) trainer.add_event_handler(every_iteration, lambda _: lr_scheduler.step()) if isinstance(lr_scheduler, torch.optim.lr_scheduler.OneCycleLR): initial_state = lr_scheduler.state_dict() trainer.add_event_handler( Events.ITERATION_COMPLETED(every=epoch_length), lambda _: lr_scheduler.load_state_dict(initial_state)) else: lr_scheduler = None trainer.add_event_handler(every_iteration, TerminateOnNan()) cp = conf.train.checkpoints to_save = { 'trainer': trainer, 'model': model.module if distributed else model, 'optimizer': optim, 'lr_scheduler': lr_scheduler } save_path = cp.get('base_dir', os.getcwd()) if rank == 0: log_freq = conf.logging.iter_freq log_event = Events.ITERATION_COMPLETED(every=log_freq) pbar = ProgressBar(persist=False) for engine, name in zip([trainer, evaluator], ['train', 'val']): engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start) engine.add_event_handler(log_event, log_iter, trainer, pbar, name, log_freq) engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch, trainer, name) pbar.attach(engine, metric_names=loss.keys()) if 'load' in cp.keys() and cp.load: logging.info("Resume from a checkpoint: {}".format(cp.load)) trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp, pbar) logging.info("Saving checkpoints to {}".format(save_path)) if rank == 0 or use_tpu: max_cp = max(int(cp.get('max_checkpoints', 1)), 1) Saver = TpuDiskSaver if use_tpu else DiskSaver save = Saver(save_path, create_dir=True, require_empty=True) make_checkpoint = Checkpoint(to_save, save, n_saved=max_cp) cp_iter = cp.interval_iteration cp_epoch = 
cp.interval_epoch if cp_iter > 0: save_event = Events.ITERATION_COMPLETED(every=cp_iter) trainer.add_event_handler(save_event, make_checkpoint) if cp_epoch > 0: if cp_iter < 1 or epoch_length % cp_iter: save_event = Events.EPOCH_COMPLETED(every=cp_epoch) trainer.add_event_handler(save_event, make_checkpoint) if 'load' in cp.keys() and cp.load: Checkpoint.load_objects(to_load=to_save, checkpoint=torch.load(cp.load, map_location=device)) assert train_sampler is not None trainer.add_event_handler( Events.EPOCH_STARTED, lambda e: train_sampler.set_epoch(e.state.epoch - 1)) def run_validation(e: Engine): if distributed: torch.cuda.synchronize(device) if use_tpu: xm.rendezvous('validate_{}'.format(e.state.iteration)) valid_it = valid_dl.per_device_loader(device) evaluator.run(valid_it, epoch_length=len(valid_dl)) else: evaluator.run(valid_dl) eval_event = Events.EPOCH_COMPLETED(every=conf.validate.interval) trainer.add_event_handler(eval_event, run_validation) try: if conf.train.skip: evaluator.run(valid_dl) else: loader = train_dl if use_tpu: # need to catch StopIteration before ignite, otherwise it will crash loader = iter(_regenerate(train_dl, device)) trainer.run(loader, max_epochs=epochs, epoch_length=epoch_length) except Exception as e: import traceback print(traceback.format_exc()) if rank == 0: pbar.close()
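# --- Hedged side sketch ---
# The Checkpoint/DiskSaver combination used above, reduced to a runnable toy
# (TpuDiskSaver is project code, so the stock DiskSaver is shown; the directory
# is a temp dir instead of the configured save_path).
import tempfile
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

toy_model = torch.nn.Linear(2, 2)
toy_optim = torch.optim.SGD(toy_model.parameters(), lr=0.1)
toy_trainer = Engine(lambda engine, batch: None)

to_save = {"model": toy_model, "optimizer": toy_optim, "trainer": toy_trainer}
saver = DiskSaver(tempfile.mkdtemp(), create_dir=True, require_empty=False)
make_checkpoint = Checkpoint(to_save, saver, n_saved=2)
toy_trainer.add_event_handler(Events.EPOCH_COMPLETED, make_checkpoint)
toy_trainer.run(range(3), max_epochs=2)

# Resuming later mirrors the Checkpoint.load_objects call above:
# Checkpoint.load_objects(to_load=to_save, checkpoint=torch.load(path, map_location="cpu"))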
def train(): parser = ArgumentParser() parser.add_argument( "--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.") parser.add_argument("--model_checkpoint", type=str, default="t5-small", help="Path, url or short name of the model") parser.add_argument("--max_history", type=int, default=7, help="Number of previous exchanges to keep in history") parser.add_argument("--train_batch_size", type=int, default=10, help="Batch size for training") parser.add_argument("--valid_batch_size", type=int, default=10, help="Batch size for validation") parser.add_argument("--gradient_accumulation_steps", type=int, default=12, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6e-4, help="Learning rate") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument( "--fp16", type=str, default="", help= "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") parser.add_argument( "--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") parser.add_argument("--save_name", type=str, default="") parser.add_argument("--mask_ratio", type=float, default=0.15) parser.add_argument("--objective", type=str, default="span_denosing", help="response_generation, span_denosing, both") args = parser.parse_args() # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", args.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info("Prepare tokenizer, pretrained model and optimizer.") tokenizer = T5Tokenizer.from_pretrained(args.model_checkpoint) model = T5ForConditionalGeneration.from_pretrained(args.model_checkpoint) model.to(args.device) # Add special tokens if they are not already added add_special_tokens_(model, tokenizer) optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) def collate_fn(data): batch = { "corrupted_context": [], "context": [], "target": [], "response": [] } padded_dataset = {} batch_size = len(data) resp_sos, context_sos = tokenizer.convert_tokens_to_ids([ "<go_r>", "<go_b>", ]) for x in data: corrupted_context = ["fill : "] target = [] length = len(x["context_words"]) mask_bool = random_spans_noise_mask(length=length, noise_density=args.mask_ratio, mean_noise_span_length=3.0) mask_id = 0 #print(mask_bool) for i in range(length): 
if mask_bool[i]: if i > 0 and mask_bool[i - 1]: target.append(x["context_words"][i]) else: target.append(f"<extra_id_{mask_id}>") target.append(x["context_words"][i]) corrupted_context.append(f"<extra_id_{mask_id}>") mask_id += 1 else: corrupted_context.append(x["context_words"][i]) target.append("<eos_b>") batch["context"].append( tokenizer.encode("response : " + " ".join(x["context_words"]))) batch["corrupted_context"].append( tokenizer.encode(" ".join(corrupted_context))) batch["target"].append(tokenizer.encode(" ".join(target))) batch["response"].append(tokenizer.encode(x["response"])) # print(" ".join(x["context_words"])) # print(" ".join(corrupted_context)) # print(" ".join(target)) # print("") # print(tokenizer.decode(batch["corrupted_context"][-1])) # print(tokenizer.decode(batch["target"][-1])) # print(tokenizer.decode(batch["response"][-1])) # print("") context_ids, context_masks = padInput(batch["context"]) input_ids, masks = padInput(batch["corrupted_context"]) target_ids, target_inputs = padOutput(batch["target"]) response_ids, response_inputs = padOutput(batch["response"]) #inputs padded_dataset["input_ids"] = torch.tensor(input_ids, dtype=torch.long) padded_dataset["masks"] = torch.tensor(masks, dtype=torch.long) padded_dataset["context_ids"] = torch.tensor(context_ids, dtype=torch.long) padded_dataset["context_masks"] = torch.tensor(context_masks, dtype=torch.long) padded_dataset["target_ids"] = torch.tensor(target_ids, dtype=torch.long) padded_dataset["response_ids"] = torch.tensor(response_ids, dtype=torch.long) padded_dataset["target_inputs"] = torch.tensor(np.concatenate((np.ones( (batch_size, 1)) * context_sos, target_inputs[:, :-1]), axis=1), dtype=torch.long) padded_dataset["response_inputs"] = torch.tensor(np.concatenate( (np.ones((batch_size, 1)) * resp_sos, response_inputs[:, :-1]), axis=1), dtype=torch.long) return padded_dataset logger.info("Prepare datasets") train_dataset, valid_dataset, train_sampler, valid_sampler = get_data( args, tokenizer) train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, shuffle=(not args.distributed), collate_fn=collate_fn, num_workers=4) val_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, shuffle=False, collate_fn=collate_fn, num_workers=4) logger.info("Train dataset length: {}".format(len(train_dataset))) logger.info("Valid dataset length: {}".format(len(valid_dataset))) # for batch in train_loader: # #print(batch) # exit(0) # Training function and trainer def update(engine, batch): model.train() batch = tuple(batch[input_name].to(args.device) for input_name in MODEL_INPUTS) input_ids, masks, context_ids, context_masks, target_ids, target_inputs, response_ids, response_inputs = batch # print("input") # print(tokenizer.decode(input_ids[0, :].tolist())) # print("context_ids") # print(tokenizer.decode(context_ids[0, :].tolist())) # print("target") # print(tokenizer.decode(target_ids[0, :].tolist())) # print("target In") # print(tokenizer.decode(target_inputs[0, :].tolist())) # print("response_ids") # print(tokenizer.decode(response_ids[0, :].tolist())) # print("response_inputs") # print(tokenizer.decode(response_inputs[0, :].tolist())) #exit(0) outputs = model(input_ids, attention_mask=masks, decoder_input_ids=target_inputs, lm_labels=target_ids) context_loss = outputs[0] outputs = model(context_ids, attention_mask=context_masks, decoder_input_ids=response_inputs, lm_labels=response_ids) resp_loss = outputs[0] loss = (context_loss + resp_loss) / 
args.gradient_accumulation_steps loss = (context_loss) / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple(batch[input_name].to(args.device) for input_name in MODEL_INPUTS) input_ids, masks, context_ids, context_masks, target_ids, target_inputs, response_ids, response_inputs = batch outputs = model( input_ids, attention_mask=masks, decoder_input_ids=target_inputs #, lm_labels=target_ids ) context_logits = outputs[0] outputs = model( context_ids, attention_mask=context_masks, decoder_input_ids=response_inputs, #lm_labels=response_ids ) resp_logits = outputs[0] context_logits_flat_shifted = context_logits.view( -1, context_logits.size(-1)) context_labels_flat_shifted = target_ids.view(-1) resp_logits_flat_shifted = resp_logits.view( -1, resp_logits.size(-1)) resp_labels_flat_shifted = response_ids.view(-1) return (context_logits_flat_shifted, resp_logits_flat_shifted), (context_labels_flat_shifted, resp_labels_flat_shifted) #return (context_logits_flat_shifted, context_logits_flat_shifted), (context_labels_flat_shifted, context_labels_flat_shifted) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) # if args.eval_before_start: # trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if args.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "span": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])), "response": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][1], x[1][1])) } metrics.update({ "average_span": MetricsLambda(average_distributed_scalar, metrics["span"], args), "average_response": MetricsLambda(average_distributed_scalar, metrics["response"], args) }) metrics["average_response"] = MetricsLambda(math.exp, metrics["average_response"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) 
evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) if not os.path.exists(f"pretrained_model/{args.save_name}"): os.makedirs(f"pretrained_model/{args.save_name}") log_dir = f"pretrained_model/{args.save_name}" tb_logger = TensorboardLogger(log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" takes care of distributed encapsulation torch.save(args, log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME)) tokenizer.save_pretrained(log_dir) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if args.local_rank in [-1, 0] and args.n_epochs > 0: os.rename( os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME) ) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
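# --- Hedged side sketch ---
# The linear learning-rate decay above, isolated: PiecewiseLinear interpolates
# the "lr" value of the optimizer between (event_index, value) milestones and is
# stepped once per iteration. Toy model/optimizer and milestone numbers.
import torch
from ignite.engine import Engine, Events
from ignite.contrib.handlers import PiecewiseLinear

toy_model = torch.nn.Linear(4, 1)
toy_optimizer = torch.optim.AdamW(toy_model.parameters(), lr=6e-4)
toy_trainer = Engine(lambda engine, batch: None)

total_steps = 3 * 100  # n_epochs * len(train_loader) in the real script
lr_schedule = PiecewiseLinear(toy_optimizer, "lr", [(0, 6e-4), (total_steps, 0.0)])
toy_trainer.add_event_handler(Events.ITERATION_STARTED, lr_schedule)

toy_trainer.run(range(100), max_epochs=3)
print(toy_optimizer.param_groups[0]["lr"])   # ~0.0 after the final milestone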
def run_once(self): log_dir = self.log_dir misc.check_manual_seed(self.seed) train_pairs, valid_pairs = dataset.prepare_data_CANCER() print(len(train_pairs)) # --------------------------- Dataloader train_augmentors = self.train_augmentors() train_dataset = dataset.DatasetSerial( train_pairs[:], shape_augs=iaa.Sequential(train_augmentors[0]), input_augs=iaa.Sequential(train_augmentors[1])) infer_augmentors = self.infer_augmentors() infer_dataset = dataset.DatasetSerial( valid_pairs[:], shape_augs=iaa.Sequential(infer_augmentors)) train_loader = data.DataLoader(train_dataset, num_workers=self.nr_procs_train, batch_size=self.train_batch_size, shuffle=True, drop_last=True) valid_loader = data.DataLoader(infer_dataset, num_workers=self.nr_procs_valid, batch_size=self.infer_batch_size, shuffle=True, drop_last=False) # --------------------------- Training Sequence if self.logging: misc.check_log_dir(log_dir) device = 'cuda' # networks input_chs = 3 net = DenseNet(input_chs, self.nr_classes) net = torch.nn.DataParallel(net).to(device) # print(net) # optimizers optimizer = optim.Adam(net.parameters(), lr=self.init_lr) scheduler = optim.lr_scheduler.StepLR(optimizer, self.lr_steps) # load pre-trained models if self.load_network: saved_state = torch.load(self.save_net_path) net.load_state_dict(saved_state) # trainer = Engine(lambda engine, batch: self.train_step( net, batch, optimizer, 'cuda')) inferer = Engine( lambda engine, batch: self.infer_step(net, batch, 'cuda')) train_output = ['loss', 'acc'] infer_output = ['prob', 'true'] ## if self.logging: checkpoint_handler = ModelCheckpoint(log_dir, self.chkpts_prefix, save_interval=1, n_saved=120, require_empty=False) # adding handlers using `trainer.add_event_handler` method API trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={'net': net}) timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) timer.attach(inferer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) # attach running average metrics computation # decay of EMA to 0.95 to match tensorpack default RunningAverage(alpha=0.95, output_transform=lambda x: x['loss']).attach( trainer, 'loss') RunningAverage(alpha=0.95, output_transform=lambda x: x['acc']).attach( trainer, 'acc') # attach progress bar pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=['loss']) pbar.attach(inferer) # adding handlers using `trainer.on` decorator API @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn('KeyboardInterrupt caught. 
Exiting gracefully.') checkpoint_handler(engine, {'net_exception': net}) else: raise e # writer for tensorboard logging if self.logging: writer = SummaryWriter(log_dir=log_dir) json_log_file = log_dir + '/stats.json' with open(json_log_file, 'w') as json_file: json.dump({}, json_file) # create empty file @trainer.on(Events.EPOCH_STARTED) def log_lrs(engine): if self.logging: lr = float(optimizer.param_groups[0]['lr']) writer.add_scalar("lr", lr, engine.state.epoch) # advance scheduler clock scheduler.step() #### def update_logs(output, epoch, prefix, color): # print values and convert max_length = len(max(output.keys(), key=len)) for metric in output: key = colored(prefix + '-' + metric.ljust(max_length), color) print('------%s : ' % key, end='') print('%0.7f' % output[metric]) if 'train' in prefix: lr = float(optimizer.param_groups[0]['lr']) key = colored(prefix + '-' + 'lr'.ljust(max_length), color) print('------%s : %0.7f' % (key, lr)) if not self.logging: return # create stat dicts stat_dict = {} for metric in output: metric_value = output[metric] stat_dict['%s-%s' % (prefix, metric)] = metric_value # json stat log file, update and overwrite with open(json_log_file) as json_file: json_data = json.load(json_file) current_epoch = str(epoch) if current_epoch in json_data: old_stat_dict = json_data[current_epoch] stat_dict.update(old_stat_dict) current_epoch_dict = {current_epoch: stat_dict} json_data.update(current_epoch_dict) with open(json_log_file, 'w') as json_file: json.dump(json_data, json_file) # log values to tensorboard for metric in output: writer.add_scalar(prefix + '-' + metric, output[metric], current_epoch) @trainer.on(Events.EPOCH_COMPLETED) def log_train_running_results(engine): """ running training measurement """ training_ema_output = engine.state.metrics # update_logs(training_ema_output, engine.state.epoch, prefix='train-ema', color='green') #### def get_init_accumulator(output_names): return {metric: [] for metric in output_names} import cv2 def process_accumulated_output(output): def uneven_seq_to_np(seq, batch_size=self.infer_batch_size): if self.infer_batch_size == 1: return np.squeeze(seq) item_count = batch_size * (len(seq) - 1) + len(seq[-1]) cat_array = np.zeros((item_count, ) + seq[0][0].shape, seq[0].dtype) for idx in range(0, len(seq) - 1): cat_array[idx * batch_size:(idx + 1) * batch_size] = seq[idx] cat_array[(idx + 1) * batch_size:] = seq[-1] return cat_array # prob = uneven_seq_to_np(output['prob']) true = uneven_seq_to_np(output['true']) # cmap = plt.get_cmap('jet') # epi = prob[...,1] # epi = (cmap(epi) * 255.0).astype('uint8') # cv2.imwrite('sample.png', cv2.cvtColor(epi, cv2.COLOR_RGB2BGR)) pred = np.argmax(prob, axis=-1) true = np.squeeze(true) # deal with ignore index pred = pred.flatten() true = true.flatten() pred = pred[true != 0] - 1 true = true[true != 0] - 1 acc = np.mean(pred == true) inter = (pred * true).sum() total = (pred + true).sum() dice = 2 * inter / total # proc_output = dict(acc=acc, dice=dice) return proc_output # @trainer.on(Events.EPOCH_COMPLETED) # def infer_valid(engine): # """ # inference measurement # """ # inferer.accumulator = get_init_accumulator(infer_output) # inferer.run(valid_loader) # output_stat = process_accumulated_output(inferer.accumulator) # update_logs(output_stat, engine.state.epoch, prefix='valid', color='red') @inferer.on(Events.ITERATION_COMPLETED) def accumulate_outputs(engine): batch_output = engine.state.output for key, item in batch_output.items(): engine.accumulator[key].extend([item]) ### #Setup is 
done. Now let's run the training trainer.run(train_loader, self.nr_epochs) return
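# --- Hedged side sketch ---
# The accumulate-then-post-process pattern used by the inferer above
# (engine.accumulator is plain attribute state set by the script, not an ignite
# feature): per-iteration outputs are collected in lists and reduced once at the
# end of the run. Toy inference step producing fake 'prob'/'true' arrays.
import numpy as np
from ignite.engine import Engine, Events

toy_inferer = Engine(lambda engine, batch: {"prob": np.random.rand(4, 2),
                                            "true": np.random.randint(0, 2, (4,))})
toy_inferer.accumulator = {"prob": [], "true": []}

@toy_inferer.on(Events.ITERATION_COMPLETED)
def accumulate_outputs(engine):
    for key, item in engine.state.output.items():
        engine.accumulator[key].append(item)

toy_inferer.run(range(3), max_epochs=1)
prob = np.concatenate(toy_inferer.accumulator["prob"])
true = np.concatenate(toy_inferer.accumulator["true"])
print(prob.shape, true.shape)   # (12, 2) (12,)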
def finetune_model(args, model, loader): optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True) def update(engine, batch): model.train() batch = tuple(batch[input_name].to(args.device) for input_name in MODEL_INPUTS) input_ids, lm_labels, token_type_ids, nodes_ids, attention_mask = batch if (not args.graph and not args.edge_list): nodes_ids = None if (not args.unilm): attention_mask = None (lm_loss), *_ = model(input_ids=input_ids, token_type_ids=token_type_ids, labels=lm_labels, nodes=nodes_ids, attention_mask=attention_mask) loss = lm_loss / args.gradient_accumulation_steps loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple(batch[input_name].to(args.device) for input_name in MODEL_INPUTS) input_ids, lm_labels, token_type_ids, nodes_ids, attention_mask = batch if (not args.graph and not args.edge_list): nodes_ids = None if (not args.unilm): attention_mask = None # if we dont send labels to model, it doesnt return losses lm_logits, *_ = model(input_ids=input_ids, token_type_ids=token_type_ids, nodes=nodes_ids, attention_mask=attention_mask) lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, ) trainer = Engine(update) evaluator = Engine(inference) trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(loader)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(), output_transform=lambda x: (x[0][0], x[1][0])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) trainer.run(loader, max_epochs=args.n_epochs) return model
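# --- Hedged side sketch ---
# Deriving perplexity from the NLL metric via MetricsLambda, as done above
# (distributed averaging left out; single-process toy logits/labels):
import math
import torch
from ignite.engine import Engine
from ignite.metrics import Loss, MetricsLambda

def toy_eval_step(engine, batch):
    logits = torch.randn(8, 50)           # (batch, vocab)
    labels = torch.randint(0, 50, (8,))   # (batch,)
    return logits, labels

toy_evaluator = Engine(toy_eval_step)
nll = Loss(torch.nn.CrossEntropyLoss())
nll.attach(toy_evaluator, "nll")
MetricsLambda(math.exp, nll).attach(toy_evaluator, "ppl")

state = toy_evaluator.run(range(5), max_epochs=1)
print(state.metrics["nll"], state.metrics["ppl"])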
# evaluator.add_event_handler(Events.COMPLETED, es_handler)
# setup_logger(es_handler._logger)

# Clear cuda cache between training/testing
def empty_cuda_cache(engine):
    torch.cuda.empty_cache()
    import gc
    gc.collect()

trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)
evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)
# train_evaluator.add_event_handler(Events.COMPLETED, empty_cuda_cache)

num_epochs = 80

ProgressBar(persist=True).attach(trainer)
trainer.run(train_loader, max_epochs=num_epochs)

print('The results are')
print(train_evaluator.state.metrics)
print(evaluator.state.metrics)

# Dill routine: serialize the full model object (not only the state_dict)
model_copy = dill.dumps(model)
torch.save(model_copy, 'complete_model_final.pt')
torch.save(train_evaluator.state.metrics, 'metrics_final.pt')
def setup_training(self, base_model, classifier, setops_model): # # Create the train and test dataset. # train_loader, train_subset_loader, val_loader = self.setup_datasets() logging.info("Setup logging and controls.") # # Setup metrics plotters. # mlflow_logger = MlflowLogger() # # Setup the optimizer. # logging.info("Setup optimizers and losses.") parameters = list(base_model.parameters()) parameters += list(setops_model.parameters()) if self.train_classifier: parameters += list(classifier.parameters()) if self.optimizer_cls == "SGD": optimizer = torch.optim.SGD(parameters, lr=self.lr1, momentum=0.9, weight_decay=self.weight_decay) else: optimizer = torch.optim.Adam(parameters, lr=self.lr1, weight_decay=self.weight_decay) if self.focal_loss: attr_loss = FocalLoss().cuda() else: attr_loss = torch.nn.MultiLabelSoftMarginLoss().cuda() recon_loss = torch.nn.MSELoss( ) if self.recon_loss == "mse" else torch.nn.L1Loss() # # Setup the trainer object and its logging. # logging.info("Setup trainer") trainer = create_setops_trainer(base_model, classifier, setops_model, optimizer, criterion1=attr_loss, criterion2=recon_loss.cuda(), params_object=self, device=self.device) ProgressBar(bar_format=None).attach(trainer) mlflow_logger.attach(engine=trainer, prefix="Train ", plot_event=Events.ITERATION_COMPLETED, update_period=LOG_INTERVAL, output_transform=lambda x: x) # # Define the evaluation metrics. # logging.info("Setup evaluator") evaluation_losses = { 'real class loss': Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["real class a"], o["targets"]["class a"])) + \ Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["real class b"], o["targets"]["class b"])), 'fake class loss': Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["fake class a"], o["targets"]["class a"])) + \ Loss(torch.nn.MultiLabelSoftMarginLoss().cuda(), lambda o: (o["outputs"]["fake class b"], o["targets"]["class b"])), '{} fake loss'.format(self.recon_loss): (Loss(recon_loss.cuda(), lambda o: (o["outputs"]["fake embed a"], o["targets"]["embed a"])) + Loss(recon_loss.cuda(), lambda o: (o["outputs"]["fake embed b"], o["targets"]["embed b"]))) / 2, } labels_list = train_loader.dataset.labels_list mask = labels_list_to_1hot(labels_list, labels_list).astype(np.bool) evaluation_accuracies = { 'real class acc': (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "real class a"], o["targets"]["class a"])) + MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "real class b"], o["targets"]["class b"]))) / 2, 'fake class acc': (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "fake class a"], o["targets"]["class a"])) + MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "fake class b"], o["targets"]["class b"]))) / 2, 'S class acc': (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "a_S_b class"], o["targets"]["a_S_b class"])) + MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "b_S_a class"], o["targets"]["b_S_a class"]))) / 2, 'I class acc': (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "a_I_b class"], o["targets"]["a_I_b class"])) + MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "b_I_a class"], o["targets"]["a_I_b class"]))) / 2, 'U class acc': (MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "a_U_b class"], o["targets"]["a_U_b class"])) + MultiLabelSoftMarginIOUaccuracy(lambda o: (o["outputs"][ "b_U_a class"], o["targets"]["a_U_b class"]))) / 2, 'MSE fake acc': (EWMeanSquaredError(lambda o: 
(o["outputs"]["fake embed a"], o[ "targets"]["embed a"])) + EWMeanSquaredError(lambda o: (o[ "outputs"]["fake embed b"], o["targets"]["embed b"]))) / 2, 'real mAP': mAP(mask=mask, output_transform=lambda o: (o["outputs"]["real class a"], o["targets"]["class a"])), 'fake mAP': mAP(mask=mask, output_transform=lambda o: (o["outputs"]["fake class a"], o["targets"]["class a"])), 'S mAP': mAP(mask=mask, output_transform=lambda o: (o["outputs"]["a_S_b class"], o["targets"]["a_S_b class"])), 'I mAP': mAP(mask=mask, output_transform=lambda o: (o["outputs"]["a_I_b class"], o["targets"]["a_I_b class"])), 'U mAP': mAP(mask=mask, output_transform=lambda o: (o["outputs"]["a_U_b class"], o["targets"]["a_U_b class"])), } # # Setup the training evaluator object and its logging. # train_evaluator = create_setops_evaluator( base_model, classifier, setops_model, metrics=evaluation_accuracies.copy(), device=self.device) mlflow_logger.attach(engine=train_evaluator, prefix="Train Eval ", plot_event=Events.EPOCH_COMPLETED, metric_names=list(evaluation_accuracies.keys())) ProgressBar(bar_format=None).attach(train_evaluator) # # Setup the evaluator object and its logging. # evaluator = create_setops_evaluator(base_model, classifier, setops_model, metrics={ **evaluation_losses, **evaluation_accuracies }, device=self.device) mlflow_logger.attach(engine=evaluator, prefix="Eval ", plot_event=Events.EPOCH_COMPLETED, metric_names=list({ **evaluation_losses, **evaluation_accuracies }.keys())) ProgressBar(bar_format=None).attach(evaluator) # # Checkpoint of the model # self.setup_checkpoint(base_model, classifier, setops_model, evaluator) logging.info("Setup schedulers.") # # Update learning rate manually using the Visdom interface. # one_cycle_size = len(train_loader) * self.warmup_epochs * 2 scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=self.lr1, end_value=self.lr2, cycle_size=one_cycle_size) scheduler_2 = ReduceLROnPlateau(optimizer, factor=0.5, patience=4 * len(train_loader), cooldown=len(train_loader), output_transform=lambda x: x["main"]) lr_scheduler = ConcatScheduler(schedulers=[scheduler_1, scheduler_2], durations=[one_cycle_size // 2], save_history=True) trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler) # # Evaluation # @trainer.on(Events.EPOCH_COMPLETED) def epoch_completed(engine): # # Re-randomize the indices of the training dataset. # train_loader.dataset.calc_indices() # # Run the evaluator on a subset of the training dataset. # logging.info("Evaluation on a subset of the training data.") train_evaluator.run(train_subset_loader) # # Run the evaluator on the validation set. # logging.info("Evaluation on the eval data.") evaluator.run(val_loader) return trainer, train_loader
def main( dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, output_dir, saved_optimizer, warmup, ): device = "cpu" if (not torch.cuda.is_available() or not cuda) else "cuda:0" check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds # Note: unsupported for now multi_class = False train_loader = data.DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True, ) test_loader = data.DataLoader( test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False, ) model = Glow( image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, ) model = model.to(device) optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) lr_lambda = lambda epoch: min(1.0, (epoch + 1) / warmup) # noqa scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll) losses["total_loss"].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction="none") else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction="none") return losses trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, "glow", n_saved=2, require_empty=False) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, { "model": model, "optimizer": optimizer }, ) monitoring_metrics = ["total_loss"] RunningAverage(output_transform=lambda x: x["total_loss"]).attach( trainer, "total_loss") evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss( lambda x, y: torch.mean(x), output_transform=lambda x: ( x["total_loss"], torch.empty(x["total_loss"].shape[0]), ), ).attach(evaluator, "total_loss") if y_condition: monitoring_metrics.extend(["nll"]) RunningAverage(output_transform=lambda x: x["nll"]).attach( trainer, "nll") # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss( lambda x, y: torch.mean(x), output_transform=lambda x: (x["nll"], torch.empty(x["nll"].shape[0])), ).attach(evaluator, "nll") pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: model.load_state_dict(torch.load(saved_model)) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split("_")[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = resume_epoch * 
len( engine.state.dataloader) @trainer.on(Events.STARTED) def init(engine): model.train() init_batches = [] init_targets = [] with torch.no_grad(): for batch, target in islice(train_loader, None, n_init_batches): init_batches.append(batch) init_targets.append(target) init_batches = torch.cat(init_batches).to(device) assert init_batches.shape[0] == n_init_batches * batch_size if y_condition: init_targets = torch.cat(init_targets).to(device) else: init_targets = None model(init_batches, init_targets) @trainer.on(Events.EPOCH_COMPLETED) def evaluate(engine): evaluator.run(test_loader) scheduler.step() metrics = evaluator.state.metrics losses = ", ".join( [f"{key}: {value:.2f}" for key, value in metrics.items()]) print(f"Validation Results - Epoch: {engine.state.epoch} {losses}") timer = Timer(average=True) timer.attach( trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED, ) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f"Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]" ) timer.reset() trainer.run(train_loader, epochs)
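# Hedged sketch of the warmup schedule used in main() above: LambdaLR multiplies
# the base learning rate by min(1, (epoch + 1) / warmup), so the rate ramps up
# linearly for `warmup` epochs and then stays at its base value. The model,
# base lr and warmup length below are placeholders, not values from the script.
import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adamax(model.parameters(), lr=5e-4)
warmup = 5
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda epoch: min(1.0, (epoch + 1) / warmup))

for epoch in range(10):
    optimizer.step()  # stands in for one full training epoch
    scheduler.step()
    print(epoch, optimizer.param_groups[0]["lr"])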
def train(dataset_path, dataset_cache='./dataset_cache', model_checkpoint='gpt2', num_candidates=2, max_history=2, train_batch_size=4, valid_batch_size=4, gradient_accumulation_steps=8, lr=6.25e-5, lm_coef=1.0, mc_coef=1.0, max_norm=1.0, n_epochs=3, personality_permutations=1, eval_before_start=False, device="cuda" if torch.cuda.is_available() else "cpu", fp16='', path_prefix='', log_dir='', local_rank=-1): args = {**locals()} # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if local_rank in [-1, 0] else logging.WARN) # This is a logger.warning: it will be printed by all distributed processes logger.warning("Running process %d", local_rank) logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed distributed = (local_rank != -1) args['distributed'] = distributed if distributed: torch.cuda.set_device(local_rank) device = torch.device("cuda", local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info("Prepare tokenizer, pretrained model and optimizer.") # cant use Autotokenizer because checkpoint could be a Path tokenizer_class = GPT2Tokenizer if "gpt2" in model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(model_checkpoint) model_class = GPT2DoubleHeadsModel if "gpt2" in model_checkpoint else OpenAIGPTDoubleHeadsModel model = model_class.from_pretrained(model_checkpoint) model.to(device) # Add special tokens if they are not already added add_special_tokens_(model, tokenizer) optimizer = AdamW(model.parameters(), lr=lr, correct_bias=True) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=fp16) if distributed: model = DistributedDataParallel(model, device_ids=[local_rank], output_device=local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( dataset_path, dataset_cache, num_candidates, personality_permutations, max_history, train_batch_size, valid_batch_size, distributed, tokenizer) # Training function and trainer def update(engine, batch): model.train() batch = tuple(input_tensor.to(device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch (lm_loss), (mc_loss), *_ = model(input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids, mc_labels=mc_labels, lm_labels=lm_labels) loss = (lm_loss * lm_coef + mc_loss * mc_coef) / \ gradient_accumulation_steps if fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm) if engine.state.iteration % gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple(input_tensor.to(device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) # if we dont send labels to model, it doesnt return losses lm_logits, 
mc_logits, *_ = model( input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids, ) lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, lr), (n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])), "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], local_rank, device), "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], local_rank, device) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) log_dir = log_dir if log_dir else make_logdir(model_checkpoint, path=path_prefix) tb_logger = TensorboardLogger(log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" takes care of distributed encapsulation torch.save(args, log_dir + '/model_training_bin') getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME)) tokenizer.save_pretrained(log_dir) # Run the training trainer.run(train_loader, max_epochs=n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if local_rank in [-1, 0] and n_epochs > 0: # TODO: PR 
in ignite to have better access to saved file paths (cleaner) os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(log_dir, WEIGHTS_NAME)) tb_logger.close()
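# Standalone sketch of the gradient-accumulation pattern used by `update` above:
# the loss is divided by the number of accumulation steps, gradients are clipped,
# and optimizer.step()/zero_grad() only run every `accumulation_steps` iterations.
# Model, data and hyper-parameters are illustrative placeholders.
import torch

model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.MSELoss()
accumulation_steps, max_norm = 8, 1.0

for iteration in range(1, 33):
    x, y = torch.randn(4, 10), torch.randn(4, 1)
    loss = criterion(model(x), y) / accumulation_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    if iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()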
def train(): args = parser.parse_args() # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", args.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint) model_class = GPT2LMHeadModel if "gpt2" in args.model_checkpoint else OpenAIGPTLMHeadModel model = model_class.from_pretrained(args.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) model.set_num_special_tokens(len(SPECIAL_TOKENS)) model.to(args.device) optimizer = OpenAIAdam(model.parameters(), lr=args.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( args, tokenizer) # Training function and trainer def update(engine, batch): model.train() batch = tuple(input_tensor.to(args.device) for input_tensor in batch) lm_loss = model(*batch) loss = lm_loss / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(args.device) for input_tensor in batch) input_ids, lm_labels, token_type_ids = batch # logger.info(tokenizer.decode(input_ids[0, :].tolist())) model_outputs = model(input_ids, token_type_ids=token_type_ids) lm_logits = model_outputs[0] lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return lm_logits_flat_shifted, lm_labels_flat_shifted evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset 
nicely between the distributed processes if args.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))} metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) tb_logger = TensorboardLogger(log_dir=args.output_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" take care of distributed encapsulation torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file( os.path.join(tb_logger.writer.log_dir, CONFIG_NAME)) tokenizer.save_vocabulary(tb_logger.writer.log_dir) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if args.local_rank in [-1, 0] and args.n_epochs > 0: os.rename( checkpoint_handler._saved[-1][1][-1], os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME) ) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
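# Hedged sketch of the linear lr decay used above (assuming ignite's contrib
# PiecewiseLinear scheduler, imported elsewhere by the script): the "lr"
# parameter is interpolated between the given (event_index, value) milestones,
# and attaching on ITERATION_STARTED decays it from the base lr to zero over
# n_epochs * len(train_loader) iterations. Model and loader sizes are dummies.
import torch
from ignite.engine import Engine, Events
from ignite.contrib.handlers import PiecewiseLinear

model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=6.25e-5)
trainer = Engine(lambda engine, batch: None)

n_epochs, iters_per_epoch = 3, 100
scheduler = PiecewiseLinear(
    optimizer, "lr",
    milestones_values=[(0, 6.25e-5), (n_epochs * iters_per_epoch, 0.0)])
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

trainer.run([0] * iters_per_epoch, max_epochs=n_epochs)
print(optimizer.param_groups[0]["lr"])  # close to zero after the last iteration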
def train(name, load, lrate, weight_decay, workers, smooth, device, validation, ground_truth): if not name: name = '{}_{}'.format(lrate, weight_decay) click.echo('model output name: {}'.format(name)) torch.set_num_threads(1) train_set = BaselineSet(glob.glob('{}/**/*.seeds.png'.format(ground_truth), recursive=True), smooth=smooth) train_data_loader = DataLoader(dataset=train_set, num_workers=workers, batch_size=1, shuffle=True, pin_memory=True) val_set = BaselineSet(glob.glob('{}/**/*.seeds.png'.format(validation), recursive=True), smooth=smooth) val_data_loader = DataLoader(dataset=val_set, num_workers=workers, batch_size=1, pin_memory=True) click.echo('loading network') model = ResUNet(refine_encoder=False).to(device) if load: click.echo('loading weights') model = torch.load(load, map_location=device) criterion = nn.BCEWithLogitsLoss() opti = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lrate, weight_decay=weight_decay) def score_function(engine): val_loss = engine.state.metrics['loss'] return -val_loss def output_preprocess(output): o, target = output o = torch.sigmoid(o) o = denoising_hysteresis_thresh(o.detach().squeeze().cpu().numpy(), 0.8, 0.9, 2.5) return torch.from_numpy(o.astype('f')).unsqueeze(0).unsqueeze(0).to( device), target.double().to(device) trainer = create_supervised_trainer(model, opti, criterion, device=device, non_blocking=True) accuracy = Accuracy(output_transform=output_preprocess) precision = Precision(output_transform=output_preprocess) recall = Recall(output_transform=output_preprocess) loss = Loss(criterion) precision = Precision(average=False) recall = Recall(average=False) f1 = (precision * recall * 2 / (precision + recall)).mean() evaluator = create_supervised_evaluator(model, device=device, non_blocking=True) accuracy.attach(evaluator, 'accuracy') precision.attach(evaluator, 'precision') recall.attach(evaluator, 'recall') loss.attach(evaluator, 'loss') f1.attach(evaluator, 'f1') ckpt_handler = ModelCheckpoint('.', name, save_interval=1, n_saved=10, require_empty=False) RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss') progress_bar = ProgressBar(persist=True) progress_bar.attach(trainer, ['loss']) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=ckpt_handler, to_save={'net': model}) trainer.add_event_handler(event_name=Events.ITERATION_COMPLETED, handler=TerminateOnNan()) @trainer.on(Events.EPOCH_COMPLETED) def log_validation_results(engine): evaluator.run(val_data_loader) metrics = evaluator.state.metrics progress_bar.log_message( 'eval results - epoch {} loss: {:.4f} f1: {:.4f}, accuracy: {:.4f} recall: {:.4f} precision {:.4f}' .format(engine.state.epoch, metrics['loss'], metrics['f1'], metrics['accuracy'], metrics['recall'], metrics['precision'])) trainer.run(train_data_loader, max_epochs=1000)
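# Sketch of the metric arithmetic used above: ignite metrics compose with
# elementwise operators, so a per-class F1 can be built from Precision and
# Recall and reduced with .mean(). The model and random data are placeholders
# meant only to show how the composed metric is attached and computed.
import torch
from ignite.engine import create_supervised_evaluator
from ignite.metrics import Precision, Recall

precision = Precision(average=False)
recall = Recall(average=False)
f1 = (precision * recall * 2 / (precision + recall)).mean()

model = torch.nn.Linear(4, 3)
evaluator = create_supervised_evaluator(
    model, metrics={'precision': precision, 'recall': recall, 'f1': f1})

data = [(torch.randn(16, 4), torch.randint(0, 3, (16,))) for _ in range(8)]
state = evaluator.run(data)
print(state.metrics['f1'])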
def train(): parser = ArgumentParser() parser.add_argument( "--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.") parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache") parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model") parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training") parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history") parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training") parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient") parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences") parser.add_argument( "--eval_before_start", action='store_true', help="If true start with a first evaluation before training") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument( "--fp16", type=str, default="", help= "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") parser.add_argument( "--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") args = parser.parse_args() # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. 
logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) logger.warning( "Running process %d", args.local_rank ) # This is a logger.warning: it will be printed by all distributed processes logger.info("Arguments: %s", pformat(args)) # Initialize distributed training if needed args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint) model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel model = model_class.from_pretrained(args.model_checkpoint) tokenizer.set_special_tokens(SPECIAL_TOKENS) model.set_num_special_tokens(len(SPECIAL_TOKENS)) model.to(args.device) optimizer = OpenAIAdam(model.parameters(), lr=args.lr) # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) logger.info("Prepare datasets") train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders( args, tokenizer) # Training function and trainer def update(engine, batch): model.train() batch = tuple(input_tensor.to(args.device) for input_tensor in batch) lm_loss, mc_loss = model(*batch) loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() trainer = Engine(update) # Evaluation function and evaluator (evaluator output is the input of the metrics) def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(args.device) for input_tensor in batch) input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch logger.info(tokenizer.decode(input_ids[0, -1, :].tolist())) model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids) lm_logits, mc_logits = model_outputs[0], model_outputs[ 1] # So we can also use GPT2 outputs lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels) evaluator = Engine(inference) # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: 
evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if args.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0][0], x[1][0])), "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args), "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) # tb_logger = TensorboardLogger(log_dir=None) # tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) # tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) # tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) # checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint', save_interval=1, n_saved=3) # trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model)}) # "getattr" take care of distributed encapsulation # torch.save(args, tb_logger.writer.log_dir + '/model_training_args.bin') # getattr(model, 'module', model).config.to_json_file(os.path.join(tb_logger.writer.log_dir, CONFIG_NAME)) # tokenizer.save_vocabulary(tb_logger.writer.log_dir) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs)
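# Minimal sketch of how perplexity is derived above: MetricsLambda applies
# math.exp to the averaged cross-entropy computed by the Loss metric, so
# "ppl" = exp(mean NLL). The identity inference step and the random
# logits/labels are placeholders standing in for the real evaluator output.
import math
import torch
from ignite.engine import Engine
from ignite.metrics import Loss, MetricsLambda

evaluator = Engine(lambda engine, batch: batch)  # output is already (logits, labels)

nll = Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))
nll.attach(evaluator, "nll")
MetricsLambda(math.exp, nll).attach(evaluator, "ppl")

vocab_size = 10
batches = [(torch.randn(6, vocab_size), torch.randint(0, vocab_size, (6,)))
           for _ in range(4)]
state = evaluator.run(batches)
print(state.metrics["nll"], state.metrics["ppl"])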
def attach_decorators(trainer, SR, D, vgg, loader, schedulerD, schedulerG, optimizerD, optimizerG, resume_epoch, resume_iter): timer = Timer(average=True) checkpoint_handler = ModelCheckpoint(args.output_dir, 'training', save_interval=1, n_saved=10, require_empty=False) monitoring_metrics = [ 'dloss_real', 'dloss_fake', 'd_loss', 'GP', 'WD', 'VGG', 'gloss' ] RunningAverage(alpha=0.98, output_transform=lambda x: x['dloss_real']).attach( trainer, 'dloss_real') RunningAverage(alpha=0.98, output_transform=lambda x: x['dloss_fake']).attach( trainer, 'dloss_fake') RunningAverage(alpha=0.98, output_transform=lambda x: x['GP']).attach(trainer, 'GP') RunningAverage(alpha=0.98, output_transform=lambda x: x['d_loss']).attach( trainer, 'd_loss') RunningAverage(alpha=0.98, output_transform=lambda x: x['WD']).attach(trainer, 'WD') RunningAverage(alpha=0.98, output_transform=lambda x: x['VGG']).attach(trainer, 'VGG') RunningAverage(alpha=0.98, output_transform=lambda x: x['gloss']).attach( trainer, 'gloss') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=checkpoint_handler, to_save={ 'SR': SR, 'D': D, 'VGG': vgg, 'optim_D': optimizerD, 'optim_G': optimizerG, 'sched_D': schedulerD, 'sched_G': schedulerG }) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.ITERATION_COMPLETED) def print_logs(engine): if (engine.state.iteration - 1) % PRINT_FREQ == 0: fname = os.path.join(args.output_dir, LOGS_FNAME) columns = engine.state.metrics.keys() values = [ str(round(value, 5)) for value in engine.state.metrics.values() ] with open(fname, 'a') as f: if f.tell() == 0: print('\t'.join(columns), file=f) print('\t'.join(values), file=f) i = (engine.state.iteration % len(loader)) message = '[{epoch}/{max_epoch}][{i}/{max_i}]'.format( epoch=engine.state.epoch, max_epoch=args.epochs, i=i, max_i=len(loader)) for name, value in zip(columns, values): message += ' | {name}: {value}'.format(name=name, value=value) pbar.log_message(message) @trainer.on(Events.ITERATION_COMPLETED) def save_real_example(engine): if (engine.state.iteration - 1) % PRINT_FREQ == 0: px, y = engine.state.batch img = SR(px.cuda()) path = os.path.join( args.output_dir, FAKE_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(img, path) path = os.path.join( args.output_dir, REAL_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(y, path) path = os.path.join( args.output_dir, TRAIN_IMG_FNAME.format(engine.state.epoch, engine.state.iteration)) vutils.save_image(px, path) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message('Epoch {} done. 
Time per batch: {:.3f}[s]'.format( engine.state.epoch, timer.value())) timer.reset() @trainer.on(Events.EPOCH_COMPLETED) def LRstep(engine): schedulerD.step() schedulerG.step() @trainer.on(Events.EPOCH_COMPLETED) def create_plots(engine): try: import matplotlib as mpl mpl.use('agg') import numpy as np import pandas as pd import matplotlib.pyplot as plt df = pd.read_csv(os.path.join(args.output_dir, LOGS_FNAME), delimiter='\t') x = np.arange(1, engine.state.iteration + 1, PRINT_FREQ) _ = df.plot(subplots=True, figsize=(20, 20), grid=True, xticks=x) _ = plt.xlabel('Iteration number') fig = plt.gcf() path = os.path.join(args.output_dir, PLOT_FNAME) fig.savefig(path) except ImportError: warnings.warn( 'Loss plots will not be generated -- pandas or matplotlib not found' ) @trainer.on(Events.EXCEPTION_RAISED) def handle_exception(engine, e): if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1): engine.terminate() warnings.warn('KeyboardInterrupt caught. Exiting gracefully.') create_plots(engine) checkpoint_handler( engine, { 'netG_{}'.format(engine.state.iteration): SR, 'netD_{}'.format(engine.state.iteration): D, 'optim_D_{}'.format(engine.state.iteration): optimizerD, 'optim_G_{}'.format(engine.state.iteration): optimizerG, 'sched_D_{}'.format(engine.state.iteration): schedulerD, 'sched_G_{}'.format(engine.state.iteration): schedulerG }) else: raise e @trainer.on(Events.STARTED) def resume_training(engine): engine.state.iteration = resume_iter engine.state.epoch = resume_epoch def OAR_shutdown(): raise KeyboardInterrupt signal.signal(signal.SIGUSR2, OAR_shutdown)
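# Standalone sketch of the per-batch timing used in attach_decorators above:
# ignite's Timer resumes/pauses around each iteration and its running average
# is logged at the end of every epoch. The sleeping update function is a
# placeholder for a real training step.
import time
from ignite.engine import Engine, Events
from ignite.handlers import Timer

trainer = Engine(lambda engine, batch: time.sleep(0.01))

timer = Timer(average=True)
timer.attach(trainer,
             start=Events.EPOCH_STARTED,
             resume=Events.ITERATION_STARTED,
             pause=Events.ITERATION_COMPLETED,
             step=Events.ITERATION_COMPLETED)

@trainer.on(Events.EPOCH_COMPLETED)
def print_times(engine):
    print(f"Epoch {engine.state.epoch}: {timer.value():.3f}s per batch")
    timer.reset()

trainer.run(list(range(20)), max_epochs=2)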
def _upd_pbar_iter_from_cp(engine: Engine, pbar: ProgressBar) -> None: pbar.n = engine.state.iteration
def run(args): if args.seed is not None: random.seed(args.seed) torch.manual_seed(args.seed) num_classes = CityscapesDataset.num_classes() device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') model = GoogLeNetFCN(num_classes) model.init_from_googlenet() device_count = torch.cuda.device_count() if device_count > 1: print("Using %d GPU(s)" % device_count) model = nn.DataParallel(model) args.batch_size = device_count * args.batch_size args.val_batch_size = device_count * args.val_batch_size model = model.to(device) train_loader, val_loader = get_data_loaders(args.dataset_dir, args.batch_size, args.val_batch_size, args.num_workers, args.include_coarse) criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='sum') optimizer = optim.SGD([{ 'params': [ param for name, param in model.named_parameters() if name.endswith('weight') ] }, { 'params': [ param for name, param in model.named_parameters() if name.endswith('bias') ], 'lr': args.lr * 2, 'weight_decay': 0 }], lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) if args.resume: if os.path.isfile(args.resume): print("Loading checkpoint '{}'".format(args.resume)) checkpoint = torch.load(args.resume) args.start_epoch = checkpoint['epoch'] best_iou = checkpoint.get('bestIoU', 0.0) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) print("Loaded checkpoint '{}' (Epoch {})".format( args.resume, checkpoint['epoch'])) else: print("No checkpoint found at '{}'".format(args.resume)) sys.exit() if args.freeze_bn: print("Freezing batch norm") model = freeze_batchnorm(model) trainer = create_supervised_trainer(model, optimizer, criterion, device, non_blocking=True) RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss') # attach progress bar pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=['loss']) cm = ConfusionMatrix(num_classes) evaluator = create_supervised_evaluator(model, metrics={ 'loss': Loss(criterion), 'IoU': IoU(cm), 'accuracy': cmAccuracy(cm) }, device=device, non_blocking=True) pbar2 = ProgressBar(persist=True, desc='Eval Epoch') pbar2.attach(evaluator) def _global_step_transform(engine, event_name): return trainer.state.iteration tb_logger = TensorboardLogger(args.log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag='training', metric_names=['loss']), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(trainer, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED) tb_logger.attach(evaluator, log_handler=OutputHandler( tag='validation', metric_names=['loss', 'IoU', 'accuracy'], global_step_transform=_global_step_transform), event_name=Events.EPOCH_COMPLETED) @evaluator.on(Events.EPOCH_COMPLETED) def save_checkpoint(engine): iou = engine.state.metrics['IoU'] * 100.0 mean_iou = iou.mean() is_best = mean_iou.item() > trainer.state.best_iou trainer.state.best_iou = max(mean_iou.item(), trainer.state.best_iou) name = 'epoch{}_mIoU={:.1f}.pth'.format(trainer.state.epoch, mean_iou) file = { 'model': model.state_dict(), 'epoch': trainer.state.epoch, 'iteration': engine.state.iteration, 'optimizer': optimizer.state_dict(), 'args': args, 'bestIoU': trainer.state.best_iou } save(file, args.output_dir, 'checkpoint_{}'.format(name)) if is_best: save(model.state_dict(), args.output_dir, 'model_{}'.format(name)) @trainer.on(Events.STARTED) def initialize(engine): if args.resume: engine.state.epoch = 
args.start_epoch engine.state.iteration = args.start_epoch * len( engine.state.dataloader) engine.state.best_iou = best_iou else: engine.state.best_iou = 0.0 @trainer.on(Events.EPOCH_COMPLETED) def log_validation_results(engine): pbar.log_message("Start Validation - Epoch: [{}/{}]".format( engine.state.epoch, engine.state.max_epochs)) evaluator.run(val_loader) metrics = evaluator.state.metrics loss = metrics['loss'] iou = metrics['IoU'] acc = metrics['accuracy'] mean_iou = iou.mean() pbar.log_message( "Validation results - Epoch: [{}/{}]: Loss: {:.2e}, Accuracy: {:.1f}, mIoU: {:.1f}" .format(engine.state.epoch, engine.state.max_epochs, loss, acc * 100.0, mean_iou * 100.0)) print("Start training") trainer.run(train_loader, max_epochs=args.epochs) tb_logger.close()
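# Sketch of the per-parameter-group optimizer setup used in run() above:
# bias parameters get twice the base learning rate and no weight decay,
# while weight parameters keep the defaults. The tiny model is a placeholder.
import torch

model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 8, 3))
base_lr, weight_decay = 0.01, 5e-4

weights = [p for n, p in model.named_parameters() if n.endswith('weight')]
biases = [p for n, p in model.named_parameters() if n.endswith('bias')]

optimizer = torch.optim.SGD(
    [{'params': weights},
     {'params': biases, 'lr': base_lr * 2, 'weight_decay': 0}],
    lr=base_lr, momentum=0.9, weight_decay=weight_decay)

for group in optimizer.param_groups:
    print(len(group['params']), group['lr'], group['weight_decay'])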
def run(args): device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = IQAModel(arch=args.arch, pool=args.pool, use_bn_end=args.use_bn_end, P6=args.P6, P7=args.P7).to(device) # print(model) if args.ft_lr_ratio == .0: for param in model.features.parameters(): param.requires_grad = False train_loader, val_loader, test_loader = get_data_loaders(args) optimizer = Adam( [ { 'params': model.regression.parameters() }, # The most important parameters. Maybe we need three levels of lrs { 'params': model.dr6.parameters() }, { 'params': model.dr7.parameters() }, { 'params': model.regr6.parameters() }, { 'params': model.regr7.parameters() }, { 'params': model.features.parameters(), 'lr': args.lr * args.ft_lr_ratio } ], lr=args.lr, weight_decay=args.weight_decay ) # Adam can be changed to other optimizers, such as SGD, Adadelta. # Initialization model, optimizer = amp.initialize(model, optimizer, opt_level=args.opt_level) mapping = True # args.loss_type != 'l1' and args.loss_type != 'mse' if args.evaluate: checkpoint = torch.load(args.trained_model_file) model.load_state_dict(checkpoint['model']) k = checkpoint['k'] b = checkpoint['b'] evaluator = create_supervised_evaluator(model, metrics={ 'IQA_performance': IQAPerformance( status='test', k=k, b=b, mapping=mapping) }, device=device) evaluator.run(test_loader) performance = evaluator.state.metrics for metric_print in metrics_printed: print('{}, {}: {:.3f}'.format(args.dataset, metric_print, performance[metric_print].item())) for metric_print in metrics_printed: print('{:.3f}'.format(performance[metric_print].item())) np.save(args.save_result_file, performance) return scheduler = lr_scheduler.StepLR(optimizer, step_size=args.lr_decay_step, gamma=args.lr_decay) loss_func = IQALoss( loss_type=args.loss_type, alpha=args.alpha, beta=args.beta, p=args.p, q=args.q, monotonicity_regularization=args.monotonicity_regularization, gamma=args.gamma, detach=args.detach) trainer = create_supervised_trainer( model, optimizer, loss_func, device=device, accumulation_steps=args.accumulation_steps) if args.pbar: from ignite.contrib.handlers import ProgressBar ProgressBar().attach(trainer) evaluator_for_train = create_supervised_evaluator(model, metrics={ 'IQA_performance': IQAPerformance( status='train', mapping=mapping) }, device=device) current_time = datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y") writer = SummaryWriter( log_dir='{}/{}-{}'.format(args.log_dir, args.format_str, current_time)) global best_val_criterion, best_epoch best_val_criterion, best_epoch = -100, -1 # larger, better, e.g., SROCC or PLCC. If RMSE is used, best_val_criterion <- 10000 @trainer.on(Events.ITERATION_COMPLETED) def iter_event_function(engine): writer.add_scalar("train/loss", engine.state.output, engine.state.iteration) @trainer.on(Events.EPOCH_COMPLETED) def epoch_event_function(engine): if args.test_during_training: evaluator_for_train.run( train_loader ) # It is better to re-make a train_loader_for_evaluation so as not to disturb the random number generator. 
performance = evaluator_for_train.state.metrics writer_add_scalar(writer, 'train', args.dataset, performance, engine.state.epoch) k = performance['k'] b = performance['b'] else: k = [1, 1, 1] b = [0, 0, 0] evaluator = create_supervised_evaluator(model, metrics={ 'IQA_performance': IQAPerformance( status='test', k=k, b=b, mapping=mapping) }, device=device) evaluator.run(val_loader) performance = evaluator.state.metrics writer_add_scalar(writer, 'val', args.dataset, performance, engine.state.epoch) val_criterion = abs( performance[args.val_criterion] ) # when alpha=[0,1],loss_type='linearity', test_during_training=False, SROCC/PLCC can be negative during training. if args.test_during_training: evaluator.run(test_loader) performance = evaluator.state.metrics writer_add_scalar(writer, 'test', args.dataset, performance, engine.state.epoch) global best_val_criterion, best_epoch if val_criterion > best_val_criterion: # If RMSE is used, then change ">" to "<". checkpoint = { 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'amp': amp.state_dict(), 'k': k, 'b': b } torch.save(checkpoint, args.trained_model_file) best_val_criterion = val_criterion best_epoch = engine.state.epoch print( 'Save current best model @best_val_criterion ({}): {:.3f} @epoch: {}' .format(args.val_criterion, best_val_criterion, best_epoch)) else: print( 'Model is not updated @val_criterion ({}): {:.3f} @epoch: {}'. format(args.val_criterion, val_criterion, engine.state.epoch)) scheduler.step(engine.state.epoch) @trainer.on(Events.COMPLETED) def final_testing_results(engine): writer.close() # close the Tensorboard writer print('best epoch: {}'.format(best_epoch)) checkpoint = torch.load(args.trained_model_file) model.load_state_dict(checkpoint['model']) if args.test_during_training: k = checkpoint['k'] b = checkpoint['b'] else: evaluator_for_train.run(train_loader) performance = evaluator_for_train.state.metrics k = performance['k'] b = performance['b'] checkpoint = { 'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'amp': amp.state_dict(), 'k': k, 'b': b } torch.save(checkpoint, args.trained_model_file) evaluator = create_supervised_evaluator(model, metrics={ 'IQA_performance': IQAPerformance( status='test', k=k, b=b, mapping=mapping) }, device=device) evaluator.run(test_loader) performance = evaluator.state.metrics for metric_print in metrics_printed: print('{}, {}: {:.3f}'.format(args.dataset, metric_print, performance[metric_print].item())) for metric_print in metrics_printed: print('{:.3f}'.format(performance[metric_print].item())) np.save(args.save_result_file, performance) trainer.run(train_loader, max_epochs=args.epochs)
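# Minimal sketch of the checkpoint layout used above: the model and optimizer
# state dicts are stored together with a few extra values (here k and b) in a
# single dict via torch.save and restored with load_state_dict. The file name
# and contents are illustrative, not the script's actual paths; the amp state
# saved by the real script is omitted here.
import torch

model = torch.nn.Linear(3, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

checkpoint = {
    'model': model.state_dict(),
    'optimizer': optimizer.state_dict(),
    'k': [1, 1, 1],
    'b': [0, 0, 0],
}
torch.save(checkpoint, 'trained_model.pt')

restored = torch.load('trained_model.pt')
model.load_state_dict(restored['model'])
optimizer.load_state_dict(restored['optimizer'])
print(restored['k'], restored['b'])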
def _setup_common_training_handlers( trainer: Engine, to_save: Optional[Mapping] = None, save_every_iters: int = 1000, output_path: Optional[str] = None, lr_scheduler: Optional[Union[ParamScheduler, _LRScheduler]] = None, with_gpu_stats: bool = False, output_names: Optional[Iterable[str]] = None, with_pbars: bool = True, with_pbar_on_iters: bool = True, log_every_iters: int = 100, stop_on_nan: bool = True, clear_cuda_cache: bool = True, save_handler: Optional[Union[Callable, BaseSaveHandler]] = None, **kwargs: Any, ) -> None: if output_path is not None and save_handler is not None: raise ValueError( "Arguments output_path and save_handler are mutually exclusive. Please, define only one of them" ) if stop_on_nan: trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan()) if lr_scheduler is not None: if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler): trainer.add_event_handler( Events.ITERATION_COMPLETED, lambda engine: cast(_LRScheduler, lr_scheduler).step()) elif isinstance(lr_scheduler, LRScheduler): trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler) else: trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler) if torch.cuda.is_available() and clear_cuda_cache: trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache) if to_save is not None: if output_path is None and save_handler is None: raise ValueError( "If to_save argument is provided then output_path or save_handler arguments should be also defined" ) if output_path is not None: save_handler = DiskSaver(dirname=output_path, require_empty=False) checkpoint_handler = Checkpoint(to_save, cast(Union[Callable, BaseSaveHandler], save_handler), filename_prefix="training", **kwargs) trainer.add_event_handler( Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler) if with_gpu_stats: GpuInfo().attach( trainer, name="gpu", event_name=Events.ITERATION_COMPLETED( every=log_every_iters) # type: ignore[arg-type] ) if output_names is not None: def output_transform(x: Any, index: int, name: str) -> Any: if isinstance(x, Mapping): return x[name] elif isinstance(x, Sequence): return x[index] elif isinstance(x, (torch.Tensor, numbers.Number)): return x else: raise TypeError( "Unhandled type of update_function's output. " f"It should either mapping or sequence, but given {type(x)}" ) for i, n in enumerate(output_names): RunningAverage(output_transform=partial(output_transform, index=i, name=n), epoch_bound=False).attach(trainer, n) if with_pbars: if with_pbar_on_iters: ProgressBar(persist=False).attach( trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED(every=log_every_iters)) ProgressBar(persist=True, bar_format="").attach(trainer, event_name=Events.EPOCH_STARTED, closing_event_name=Events.COMPLETED)
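# Hedged sketch of the periodic checkpointing wired up above, assuming an
# ignite version that provides Checkpoint, DiskSaver and filtered events:
# the handler saves the given objects every `save_every_iters` iterations
# under output_path. Engine, model and paths are placeholders.
import tempfile
import torch
from ignite.engine import Engine, Events
from ignite.handlers import Checkpoint, DiskSaver

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
trainer = Engine(lambda engine, batch: None)

output_path = tempfile.mkdtemp()
save_every_iters = 10
checkpoint_handler = Checkpoint(
    {'model': model, 'optimizer': optimizer},
    DiskSaver(dirname=output_path, require_empty=False),
    filename_prefix='training')
trainer.add_event_handler(
    Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

trainer.run(list(range(25)), max_epochs=2)
print(checkpoint_handler.last_checkpoint)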
**kwargs) # Load model if args.state_dict is not None: state_dict = torch.load(args.state_dict, map_location='cpu') model.load_state_dict(state_dict, strict=True) model = model.to(device) evaluator = create_segmentation_evaluator( model, device=device, num_classes=19, ) ProgressBar().attach(evaluator) state = evaluator.run(val_loader) classes = CLASSES[TRAIN_MAPPING != 255] metrics = { 'accuracy': state.metrics['accuracy'], 'miou': state.metrics['miou'], 'iou': {name: state.metrics['iou'][id].item() for id, name in enumerate(classes)}, } pprint(metrics)
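# The helpers referenced above (create_segmentation_evaluator, CLASSES,
# TRAIN_MAPPING) are defined elsewhere; as a hedged sketch, a comparable
# evaluator can be assembled from ignite's ConfusionMatrix with IoU/mIoU
# attached, reporting per-class IoU as a name -> value mapping. Class names,
# model and data below are placeholders.
import torch
from ignite.engine import create_supervised_evaluator
from ignite.metrics import ConfusionMatrix, IoU, mIoU

num_classes = 3
class_names = ['road', 'building', 'sky']  # illustrative names only
model = torch.nn.Conv2d(3, num_classes, kernel_size=1)

cm = ConfusionMatrix(num_classes=num_classes)
evaluator = create_supervised_evaluator(
    model, metrics={'iou': IoU(cm), 'miou': mIoU(cm)})

data = [(torch.randn(2, 3, 8, 8), torch.randint(0, num_classes, (2, 8, 8)))
        for _ in range(4)]
state = evaluator.run(data)
print({name: state.metrics['iou'][i].item() for i, name in enumerate(class_names)})
print(state.metrics['miou'].item())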
def main(dataset, dataroot, download, augment, batch_size, eval_batch_size, epochs, saved_model, seed, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, learn_top, y_condition, y_weight, max_grad_clip, max_grad_norm, lr, n_workers, cuda, n_init_batches, warmup_steps, output_dir, saved_optimizer, warmup, fresh, logittransform, gan, disc_lr): device = 'cpu' if (not torch.cuda.is_available() or not cuda) else 'cuda:0' check_manual_seed(seed) ds = check_dataset(dataset, dataroot, augment, download) image_shape, num_classes, train_dataset, test_dataset = ds # Note: unsupported for now multi_class = False train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=n_workers, drop_last=True) test_loader = data.DataLoader(test_dataset, batch_size=eval_batch_size, shuffle=False, num_workers=n_workers, drop_last=False) model = Glow(image_shape, hidden_channels, K, L, actnorm_scale, flow_permutation, flow_coupling, LU_decomposed, num_classes, learn_top, y_condition, logittransform) model = model.to(device) if gan: # Debug model = mine.Generator(32, 1).to(device) optimizer = optim.Adam(model.parameters(), lr=lr, betas=(.5, .99), weight_decay=0) discriminator = mine.Discriminator(image_shape[-1]) discriminator = discriminator.to(device) D_optimizer = optim.Adam(filter(lambda p: p.requires_grad, discriminator.parameters()), lr=disc_lr, betas=(.5, .99), weight_decay=0) else: optimizer = optim.Adamax(model.parameters(), lr=lr, weight_decay=5e-5) # lr_lambda = lambda epoch: lr * min(1., epoch+1 / warmup) # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda) i = 0 def step(engine, batch): model.train() optimizer.zero_grad() x, y = batch x = x.to(device) if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class) else: z, nll, y_logits = model(x, None) losses = compute_loss(nll) losses['total_loss'].backward() if max_grad_clip > 0: torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) if max_grad_norm > 0: torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) optimizer.step() return losses def gan_step(engine, batch): assert not y_condition if 'iter_ind' in dir(engine): engine.iter_ind += 1 else: engine.iter_ind = -1 losses = {} model.train() discriminator.train() x, y = batch x = x.to(device) # def generate_from_noise(batch_size): # _, c2, h, w = model.prior_h.shape # c = c2 // 2 # zshape = (batch_size, c, h, w) # randz = torch.autograd.Variable(torch.randn(zshape), requires_grad=True).to(device) # images = model(z= randz, y_onehot=None, temperature=1, reverse=True,batch_size=batch_size) # return images def generate_from_noise(batch_size): zshape = (batch_size, 32, 1, 1) randz = torch.randn(zshape).to(device) images = model(randz) return images / 2 def run_noised_disc(discriminator, x): x = uniform_binning_correction(x)[0] return discriminator(x) # Train Disc fake = generate_from_noise(x.size(0)) D_real_scores = run_noised_disc(discriminator, x.detach()) D_fake_scores = run_noised_disc(discriminator, fake.detach()) ones_target = torch.ones((x.size(0), 1), device=x.device) zeros_target = torch.zeros((x.size(0), 1), device=x.device) # D_real_accuracy = torch.sum(torch.round(F.sigmoid(D_real_scores)) == ones_target).float() / ones_target.size(0) # D_fake_accuracy = torch.sum(torch.round(F.sigmoid(D_fake_scores)) == zeros_target).float() / zeros_target.size(0) D_real_loss = F.binary_cross_entropy_with_logits( D_real_scores, 
ones_target) D_fake_loss = F.binary_cross_entropy_with_logits( D_fake_scores, zeros_target) D_loss = (D_real_loss + D_fake_loss) / 2 gp = gradient_penalty(x.detach(), fake.detach(), lambda _x: run_noised_disc(discriminator, _x)) D_loss_plus_gp = D_loss + 10 * gp D_optimizer.zero_grad() D_loss_plus_gp.backward() D_optimizer.step() # Train generator fake = generate_from_noise(x.size(0)) G_loss = F.binary_cross_entropy_with_logits( run_noised_disc(discriminator, fake), torch.ones((x.size(0), 1), device=x.device)) losses['total_loss'] = G_loss # G-step optimizer.zero_grad() losses['total_loss'].backward() params = list(model.parameters()) gnorm = [p.grad.norm() for p in params] optimizer.step() # if max_grad_clip > 0: # torch.nn.utils.clip_grad_value_(model.parameters(), max_grad_clip) # if max_grad_norm > 0: # torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm) if engine.iter_ind % 50 == 0: grid = make_grid((postprocess(fake.detach().cpu())[:30]), nrow=6).permute(1, 2, 0) plt.figure(figsize=(10, 10)) plt.imshow(grid) plt.axis('off') plt.savefig( os.path.join(output_dir, f'sample_{engine.iter_ind}.png')) grid = make_grid( (postprocess(uniform_binning_correction(x)[0].cpu())[:30]), nrow=6).permute(1, 2, 0) plt.figure(figsize=(10, 10)) plt.imshow(grid) plt.axis('off') plt.savefig(os.path.join(output_dir, f'data_{engine.iter_ind}.png')) return losses def eval_step(engine, batch): model.eval() x, y = batch x = x.to(device) with torch.no_grad(): if y_condition: y = y.to(device) z, nll, y_logits = model(x, y) losses = compute_loss_y(nll, y_logits, y_weight, y, multi_class, reduction='none') else: z, nll, y_logits = model(x, None) losses = compute_loss(nll, reduction='none') return losses if gan: trainer = Engine(gan_step) else: trainer = Engine(step) checkpoint_handler = ModelCheckpoint(output_dir, 'glow', save_interval=1, n_saved=2, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, { 'model': model, 'optimizer': optimizer }) monitoring_metrics = ['total_loss'] RunningAverage(output_transform=lambda x: x['total_loss']).attach( trainer, 'total_loss') evaluator = Engine(eval_step) # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['total_loss'], torch.empty(x['total_loss'].shape[0]))).attach( evaluator, 'total_loss') if y_condition: monitoring_metrics.extend(['nll']) RunningAverage(output_transform=lambda x: x['nll']).attach( trainer, 'nll') # Note: replace by https://github.com/pytorch/ignite/pull/524 when released Loss(lambda x, y: torch.mean(x), output_transform=lambda x: (x['nll'], torch.empty(x['nll'].shape[0]))).attach( evaluator, 'nll') pbar = ProgressBar() pbar.attach(trainer, metric_names=monitoring_metrics) # load pre-trained model if given if saved_model: model.load_state_dict(torch.load(saved_model)) model.set_actnorm_init() if saved_optimizer: optimizer.load_state_dict(torch.load(saved_optimizer)) file_name, ext = os.path.splitext(saved_model) resume_epoch = int(file_name.split('_')[-1]) @trainer.on(Events.STARTED) def resume_training(engine): engine.state.epoch = resume_epoch engine.state.iteration = resume_epoch * len( engine.state.dataloader) # @trainer.on(Events.STARTED) # def init(engine): # model.train() # init_batches = [] # init_targets = [] # with torch.no_grad(): # for batch, target in islice(train_loader, None, # n_init_batches): # init_batches.append(batch) # init_targets.append(target) # init_batches = 
torch.cat(init_batches).to(device) # assert init_batches.shape[0] == n_init_batches * batch_size # if y_condition: # init_targets = torch.cat(init_targets).to(device) # else: # init_targets = None # model(init_batches, init_targets) # @trainer.on(Events.EPOCH_COMPLETED) # def evaluate(engine): # evaluator.run(test_loader) # # scheduler.step() # metrics = evaluator.state.metrics # losses = ', '.join([f"{key}: {value:.2f}" for key, value in metrics.items()]) # myprint(f'Validation Results - Epoch: {engine.state.epoch} {losses}') timer = Timer(average=True) timer.attach(trainer, start=Events.EPOCH_STARTED, resume=Events.ITERATION_STARTED, pause=Events.ITERATION_COMPLETED, step=Events.ITERATION_COMPLETED) @trainer.on(Events.EPOCH_COMPLETED) def print_times(engine): pbar.log_message( f'Epoch {engine.state.epoch} done. Time per batch: {timer.value():.3f}[s]' ) timer.reset() trainer.run(train_loader, epochs)
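# The gradient_penalty helper used in gan_step above is defined elsewhere;
# this is a hedged, generic sketch of the usual WGAN-GP style penalty on
# interpolates between real and fake samples, which is what such a helper
# typically computes. The critic and tensors below are placeholders.
import torch

def gradient_penalty_sketch(real, fake, critic):
    # Random convex combination of real and fake samples, with gradients enabled.
    alpha = torch.rand(real.size(0), 1, 1, 1, device=real.device)
    interpolates = (alpha * real + (1 - alpha) * fake).requires_grad_(True)
    scores = critic(interpolates)
    grads, = torch.autograd.grad(outputs=scores.sum(), inputs=interpolates,
                                 create_graph=True)
    grads = grads.view(grads.size(0), -1)
    # Penalize deviation of the per-sample gradient norm from 1.
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()

critic = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 8 * 8, 1))
real, fake = torch.randn(4, 3, 8, 8), torch.randn(4, 3, 8, 8)
print(gradient_penalty_sketch(real, fake, critic))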
def attach_handlers(run, model, optimizer, trainer, train_evaluator, evaluator, train_loader, val_loader, params): # Tqdm logger pbar = ProgressBar(persist=True, bar_format=config.IGNITE_BAR_FORMAT) pbar.attach(trainer.engine, metric_names='all') tqdm_logger = TqdmLogger(pbar=pbar) # noinspection PyTypeChecker tqdm_logger.attach_output_handler( evaluator.engine, event_name=Events.COMPLETED, tag="validation", global_step_transform=global_step_from_engine(trainer.engine), ) # noinspection PyTypeChecker tqdm_logger.attach_output_handler( train_evaluator.engine, event_name=Events.COMPLETED, tag="train", global_step_transform=global_step_from_engine(trainer.engine), ) # Evaluators train_evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, train_loader) evaluator.attach(trainer.engine, Events.EPOCH_COMPLETED, data=val_loader) # Learning rate scheduling lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', verbose=True, patience=5, factor=0.5) evaluator.engine.add_event_handler( Events.COMPLETED, lambda engine: lr_scheduler.step(engine.state.metrics['accuracy'])) # Early stopping es_handler = EarlyStopping( patience=15, score_function=lambda engine: engine.state.metrics['accuracy'], trainer=trainer.engine, cumulative_delta=True, min_delta=0.0001) if 'train_all' in params and params['train_all']: train_evaluator.engine.add_event_handler(Events.COMPLETED, es_handler) else: evaluator.engine.add_event_handler(Events.COMPLETED, es_handler) es_handler.logger.setLevel(logging.DEBUG) # Model checkpoints name = run.replace('/', '-') mc_handler = ModelCheckpoint( config.MODELS_DIR, name, n_saved=1, create_dir=True, require_empty=False, score_name='acc', score_function=lambda engine: engine.state.metrics['accuracy'], global_step_transform=global_step_from_engine(trainer.engine)) evaluator.engine.add_event_handler(Events.EPOCH_COMPLETED, mc_handler, {'m': model}) # TensorBoard logger tb_logger = TensorboardLogger( log_dir=os.path.join(config.TENSORBOARD_DIR, run)) images, labels = next(iter(train_loader)) tb_logger.writer.add_graph(copy.deepcopy(model).cpu(), images) tb_logger.writer.add_hparams(params, {'hparam/dummy': 0}) # noinspection PyTypeChecker tb_logger.attach_output_handler( train_evaluator.engine, event_name=Events.COMPLETED, tag="train", metric_names="all", global_step_transform=global_step_from_engine(trainer.engine), ) # noinspection PyTypeChecker tb_logger.attach_output_handler( evaluator.engine, event_name=Events.COMPLETED, tag="validation", metric_names="all", global_step_transform=global_step_from_engine(trainer.engine), ) input_shape = tuple(next(iter(train_loader))[0].shape[1:]) tb_logger.attach(trainer.engine, log_handler=WeightsImageHandler(model, input_shape), event_name=Events.EPOCH_COMPLETED) tb_logger.attach(trainer.engine, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.EPOCH_STARTED) # tb_logger.attach(trainer.engine, log_handler=WeightsScalarHandler(model), event_name=Events.EPOCH_COMPLETED) # tb_logger.attach(trainer.engine, log_handler=WeightsHistHandler(model), event_name=Events.EPOCH_COMPLETED) # tb_logger.attach(trainer.engine, # log_handler=ActivationsHistHandler(model, layer_names=['linear1', 'batch_norm', 'repu']), # event_name=Events.ITERATION_COMPLETED) # tb_logger.attach(trainer.engine, # log_handler=NumActivationsScalarHandler(model, layer_names=['linear1', 'repu']), # event_name=Events.ITERATION_COMPLETED) # tb_logger.attach(trainer.engine, # log_handler=ActivationsScalarHandler(model, reduction=torch.mean, # 
layer_names=['linear1', 'batch_norm', 'repu']), # event_name=Events.ITERATION_COMPLETED) # tb_logger.attach(trainer.engine, # log_handler=ActivationsScalarHandler(model, reduction=torch.std, # layer_names=['linear1', 'batch_norm', 'repu']), # event_name=Events.ITERATION_COMPLETED) return es_handler, tb_logger
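# Hedged sketch of the validation-driven handlers attached above: the
# evaluator's "accuracy" metric feeds both a ReduceLROnPlateau scheduler and
# an EarlyStopping handler once evaluation completes. The engines, data and
# the constant dummy metric are placeholders for the real pipeline.
import torch
from ignite.engine import Engine, Events
from ignite.handlers import EarlyStopping

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='max', patience=5, factor=0.5)

trainer = Engine(lambda engine, batch: None)
evaluator = Engine(lambda engine, batch: None)

@evaluator.on(Events.COMPLETED)
def fake_metrics(engine):
    engine.state.metrics['accuracy'] = 0.5  # placeholder metric value

evaluator.add_event_handler(
    Events.COMPLETED,
    lambda engine: lr_scheduler.step(engine.state.metrics['accuracy']))

es_handler = EarlyStopping(
    patience=15,
    score_function=lambda engine: engine.state.metrics['accuracy'],
    trainer=trainer)
evaluator.add_event_handler(Events.COMPLETED, es_handler)

trainer.add_event_handler(Events.EPOCH_COMPLETED,
                          lambda engine: evaluator.run([0]))
trainer.run([0, 1, 2], max_epochs=2)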
def train(device, net, dataloader, val_loader, args, logger, experiment):
    def update(engine, data):
        input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
        input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
        # zero the parameter gradients
        optimizer.zero_grad()
        label = label.float()
        start = timer()
        output_rank_left, output_rank_right = net(input_left, input_right)
        end = timer()
        logger.info(f'FORWARD,{end - start:.4f}')
        # compute ranking loss
        start = timer()
        loss = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
        end = timer()
        logger.info(f'LOSS,{end - start:.4f}')
        # compute ranking accuracy
        start = timer()
        rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
        end = timer()
        logger.info(f'RANK-ACC,{end - start:.4f}')
        # backward step
        start = timer()
        loss.backward()
        optimizer.step()
        end = timer()
        logger.info(f'BACKWARD,{end - start:.4f}')
        scheduler.step()
        return {'loss': loss.item(), 'rank_acc': rank_acc}

    def inference(engine, data):
        with torch.no_grad():
            start = timer()
            input_left, input_right, label = data['left_image'], data['right_image'], data['winner']
            input_left, input_right, label = input_left.to(device), input_right.to(device), label.to(device)
            label = label.float()
            output_rank_left, output_rank_right = net(input_left, input_right)
            loss = compute_ranking_loss(output_rank_left, output_rank_right, label, rank_crit)
            rank_acc = compute_ranking_accuracy(output_rank_left, output_rank_right, label)
            end = timer()
            logger.info(f'INFERENCE,{end - start:.4f}')
            return {'loss': loss.item(), 'rank_acc': rank_acc}

    net = net.to(device)
    if args.equal:
        rank_crit = RankingLoss(margin=1, tie_margin=0)
        print("using RankingLoss with tie margin")
    else:
        rank_crit = nn.MarginRankingLoss(reduction='mean', margin=1)
    # optimizer = optim.SGD(net.parameters(), lr=args.lr, weight_decay=args.wd, momentum=0.9)
    optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=args.wd,
                           betas=(0.9, 0.98), eps=1e-09)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.995, last_epoch=-1)

    trainer = Engine(update)
    evaluator = Engine(inference)
    RunningAverage(output_transform=lambda x: x['loss']).attach(trainer, 'loss')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(trainer, 'rank_acc')
    RunningAverage(output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    RunningAverage(output_transform=lambda x: x['rank_acc']).attach(evaluator, 'rank_acc')

    if args.pbar:
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer, ['loss', 'rank_acc'])
        pbar = ProgressBar(persist=False)
        pbar.attach(evaluator, ['loss', 'rank_acc'])

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(trainer):
        net.eval()
        evaluator.run(val_loader)
        trainer.state.metrics['val_acc'] = evaluator.state.metrics['rank_acc']
        net.train()
        if hasattr(net, 'partial_eval'):
            net.partial_eval()
        metrics = {
            'train_rank_accuracy': trainer.state.metrics['rank_acc'],
            'train_loss': trainer.state.metrics['loss'],
            'val_rank_accuracy': evaluator.state.metrics['rank_acc'],
            'val_loss': evaluator.state.metrics['loss'],
        }
        comet_log(
            metrics,
            experiment,
            epoch=trainer.state.epoch,
            step=trainer.state.epoch,
        )
        console_log(metrics, {}, trainer.state.epoch)

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_results(trainer):
        if trainer.state.iteration % 100 == 0:
            metrics = {
                'train_rank_accuracy': trainer.state.metrics['rank_acc'],
                'train_loss': trainer.state.metrics['loss'],
                'lr': scheduler.get_lr(),
            }
            comet_log(
                metrics,
                experiment,
                step=trainer.state.iteration,
                epoch=trainer.state.epoch,
            )
            console_log(
                metrics,
                {},
                trainer.state.epoch,
                step=trainer.state.iteration,
            )

    model_name = '{}_{}_{}'.format(args.model, args.premodel, args.attribute)
    if args.tag:
        model_name += f'_{args.tag}'
    handler = ModelCheckpoint(args.model_dir, model_name,
                              n_saved=1,
                              create_dir=True,
                              save_as_state_dict=True,
                              require_empty=False,
                              score_function=lambda engine: engine.state.metrics['val_acc'])
    trainer.add_event_handler(Events.EPOCH_COMPLETED, handler, {'model': net})

    if args.resume:
        def start_epoch(engine):
            engine.state.epoch = args.epoch
        trainer.add_event_handler(Events.STARTED, start_epoch)
        evaluator.add_event_handler(Events.STARTED, start_epoch)

    trainer.run(dataloader, max_epochs=args.max_epochs)
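# compute_ranking_loss and compute_ranking_accuracy are imported from elsewhere in
# this codebase and are not shown here. As a rough illustration only (a sketch, not
# the project's actual helper), a pairwise ranking accuracy can be computed by
# checking whether the sign of (left_score - right_score) matches the winner label;
# the +1/-1 label convention and the fractional return value are assumptions.
def _example_ranking_accuracy(output_left, output_right, label):
    # label assumed to be +1 when the left image wins and -1 otherwise
    pred_sign = torch.sign(output_left - output_right).view(-1)
    correct = (pred_sign == label.view(-1)).float()
    return correct.mean().item()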
def _setup_common_training_handlers(
    trainer,
    to_save=None,
    save_every_iters=1000,
    output_path=None,
    lr_scheduler=None,
    with_gpu_stats=False,
    output_names=None,
    with_pbars=True,
    with_pbar_on_iters=True,
    log_every_iters=100,
    stop_on_nan=True,
    clear_cuda_cache=True,
):
    if stop_on_nan:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())

    if lr_scheduler is not None:
        if isinstance(lr_scheduler, torch.optim.lr_scheduler._LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lambda engine: lr_scheduler.step())
        elif isinstance(lr_scheduler, LRScheduler):
            trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)

    if torch.cuda.is_available() and clear_cuda_cache:
        trainer.add_event_handler(Events.EPOCH_COMPLETED, empty_cuda_cache)

    if to_save is not None:
        if output_path is None:
            raise ValueError("If the to_save argument is provided then the output_path argument should also be defined")
        checkpoint_handler = Checkpoint(
            to_save,
            DiskSaver(dirname=output_path, require_empty=False),
            filename_prefix="training",
        )
        trainer.add_event_handler(Events.ITERATION_COMPLETED(every=save_every_iters), checkpoint_handler)

    if with_gpu_stats:
        GpuInfo().attach(trainer, name="gpu",
                         event_name=Events.ITERATION_COMPLETED(every=log_every_iters))

    if output_names is not None:

        def output_transform(x, index, name):
            if isinstance(x, Mapping):
                return x[name]
            elif isinstance(x, Sequence):
                return x[index]
            elif isinstance(x, (torch.Tensor, numbers.Number)):
                return x
            else:
                raise ValueError("Unhandled type of update_function's output. "
                                 "It should be either a mapping or a sequence, but given {}".format(type(x)))

        for i, n in enumerate(output_names):
            RunningAverage(output_transform=partial(output_transform, index=i, name=n),
                           epoch_bound=False).attach(trainer, n)

    if with_pbars:
        if with_pbar_on_iters:
            ProgressBar(persist=False).attach(
                trainer,
                metric_names="all",
                event_name=Events.ITERATION_COMPLETED(every=log_every_iters))
        ProgressBar(persist=True, bar_format="").attach(trainer,
                                                        event_name=Events.EPOCH_STARTED,
                                                        closing_event_name=Events.COMPLETED)
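# A minimal usage sketch for _setup_common_training_handlers, assuming a toy update
# function and optimizer; the names toy_model, toy_update, _example_setup and the
# path '/tmp/toy-output' are illustrative only and do not appear elsewhere in this
# codebase.
def _example_setup():
    toy_model = torch.nn.Linear(2, 1)
    toy_optimizer = torch.optim.SGD(toy_model.parameters(), lr=0.1)

    def toy_update(engine, batch):
        x, y = batch
        toy_optimizer.zero_grad()
        loss = ((toy_model(x) - y) ** 2).mean()
        loss.backward()
        toy_optimizer.step()
        # returning a mapping matches the Mapping branch of output_transform above
        return {"batch loss": loss.item()}

    toy_trainer = Engine(toy_update)
    _setup_common_training_handlers(
        toy_trainer,
        to_save={"model": toy_model, "optimizer": toy_optimizer},
        output_path="/tmp/toy-output",
        output_names=["batch loss"],
        with_pbars=False,
    )
    return toy_trainer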
if args.mixed_precision:
    (model, loss_fn), optimizer = amp.initialize([model, loss_fn], optimizer)
if args.distributed:
    model = DistributedDataParallel(model, device_ids=[local_rank])

trainer = create_sr_trainer(
    model,
    loss_fn,
    optimizer,
    device=device,
    mixed_precision=args.mixed_precision,
)
ProgressBar(persist=False).attach(trainer, ['loss'])
trainer.add_event_handler(Events.ITERATION_COMPLETED, TerminateOnNan())
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _engine: scheduler.step())

evaluator = create_sr_evaluator(
    model,
    device=device,
    mean=MEAN,
)
if local_rank == 0:
    checkpointer = ModelCheckpoint(
        dirname='checkpoints',
        filename_prefix='model',
        score_name='pnsr',
def train(self, config, **kwargs):
    """Trains a given model specified in the config file or passed as the --model parameter.
    All options in the config file can be overwritten as needed by passing --PARAM.
    Options with variable length (e.g., kwargs) can be passed as --PARAM '{"PARAM1": VAR1, "PARAM2": VAR2}'.

    :param config: yaml config file
    :param **kwargs: parameters to overwrite yaml config
    """
    config_parameters = utils.parse_config_or_kwargs(config, **kwargs)
    outputdir = os.path.join(
        config_parameters['outputpath'], config_parameters['model'],
        "{}_{}".format(
            datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S'),
            uuid.uuid1().hex))
    # Create base dir
    Path(outputdir).mkdir(exist_ok=True, parents=True)

    logger = utils.getfile_outlogger(os.path.join(outputdir, 'train.log'))
    logger.info("Storing files in {}".format(outputdir))
    utils.pprint_dict(config_parameters, logger.info)
    logger.info("Running on device {}".format(DEVICE))
    labels_df = pd.read_csv(config_parameters['label'], sep=r'\s+').convert_dtypes()
    # In case of the AVE dataset, where the index is an int, change the
    # absolute filename to its basename
    if not np.all(labels_df['filename'].str.isnumeric()):
        labels_df.loc[:, 'filename'] = labels_df['filename'].apply(os.path.basename)
    encoder = utils.train_labelencoder(labels=labels_df['event_labels'])
    # These labels are only used for mode == stratified
    label_array, _ = utils.encode_labels(labels_df['event_labels'], encoder)
    if 'cv_label' in config_parameters:
        cv_df = pd.read_csv(config_parameters['cv_label'], sep=r'\s+').convert_dtypes()
        if not np.all(cv_df['filename'].str.isnumeric()):
            cv_df.loc[:, 'filename'] = cv_df['filename'].apply(os.path.basename)
        train_df = labels_df
        logger.info(f"Using CV labels from {config_parameters['cv_label']}")
    else:
        train_df, cv_df = utils.split_train_cv(labels_df,
                                               y=label_array,
                                               **config_parameters['data_args'])

    if 'cv_data' in config_parameters:
        cv_data = config_parameters['cv_data']
        logger.info(f"Using CV data {config_parameters['cv_data']}")
    else:
        cv_data = config_parameters['data']

    train_label_array, _ = utils.encode_labels(train_df['event_labels'], encoder)
    cv_label_array, _ = utils.encode_labels(cv_df['event_labels'], encoder)

    transform = utils.parse_transforms(config_parameters['transforms'])
    utils.pprint_dict({'Classes': encoder.classes_}, logger.info, formatter='pretty')
    torch.save(encoder, os.path.join(outputdir, 'run_encoder.pth'))
    torch.save(config_parameters, os.path.join(outputdir, 'run_config.pth'))
    logger.info("Transforms:")
    utils.pprint_dict(transform, logger.info, formatter='pretty')
    # For the unbalanced AudioSet data this is usually set
    if 'sampler' in config_parameters and config_parameters['sampler'] == 'MultiBalancedSampler':
        # Training sampler that oversamples the dataset to be roughly equally sized.
        # Calculates the mean over multiple instances, rather useful when the number
        # of classes is large.
        train_sampler = dataset.MultiBalancedSampler(
            train_label_array,
            num_samples=1 * train_label_array.shape[0],
            replacement=True)
        sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
    elif 'sampler' in config_parameters and config_parameters['sampler'] == 'MinimumOccupancySampler':
        # Asserts that each batch contains at least one instance of every class
        train_sampler = dataset.MinimumOccupancySampler(train_label_array, sampling_mode='same')
        sampling_kwargs = {"shuffle": False, "sampler": train_sampler}
    else:
        sampling_kwargs = {"shuffle": True}

    logger.info("Using Sampler {}".format(sampling_kwargs))
    trainloader = dataset.getdataloader(
        {
            'filename': train_df['filename'].values,
            'encoded': train_label_array
        },
        config_parameters['data'],
        transform=transform,
        batch_size=config_parameters['batch_size'],
        colname=config_parameters['colname'],
        num_workers=config_parameters['num_workers'],
        **sampling_kwargs)
    cvdataloader = dataset.getdataloader(
        {
            'filename': cv_df['filename'].values,
            'encoded': cv_label_array
        },
        cv_data,
        transform=None,
        shuffle=False,
        colname=config_parameters['colname'],
        batch_size=config_parameters['batch_size'],
        num_workers=config_parameters['num_workers'])

    model = getattr(models, config_parameters['model'],
                    models.CRNN)(inputdim=trainloader.dataset.datadim,
                                 outputdim=len(encoder.classes_),
                                 **config_parameters['model_args'])
    if 'pretrained' in config_parameters and config_parameters['pretrained'] is not None:
        models.load_pretrained(model,
                               config_parameters['pretrained'],
                               outputdim=len(encoder.classes_))
        logger.info("Loading pretrained model {}".format(config_parameters['pretrained']))

    model = model.to(DEVICE)
    if config_parameters['optimizer'] == 'AdaBound':
        try:
            import adabound
            optimizer = adabound.AdaBound(model.parameters(),
                                          **config_parameters['optimizer_args'])
        except ImportError:
            # adabound is not installed: fall back to Adam with default arguments
            config_parameters['optimizer'] = 'Adam'
            config_parameters['optimizer_args'] = {}
    if config_parameters['optimizer'] != 'AdaBound':
        optimizer = getattr(
            torch.optim,
            config_parameters['optimizer'],
        )(model.parameters(), **config_parameters['optimizer_args'])

    utils.pprint_dict(optimizer, logger.info, formatter='pretty')
    utils.pprint_dict(model, logger.info, formatter='pretty')
    if DEVICE.type != 'cpu' and torch.cuda.device_count() > 1:
        logger.info("Using {} GPUs!".format(torch.cuda.device_count()))
        model = torch.nn.DataParallel(model)
    criterion = getattr(losses, config_parameters['loss'])().to(DEVICE)

    def _train_batch(_, batch):
        model.train()
        with torch.enable_grad():
            optimizer.zero_grad()
            output = self._forward(model, batch)  # output is a tuple (clip, frame, target)
            loss = criterion(*output)
            loss.backward()
            # Single loss
            optimizer.step()
            return loss.item()

    def _inference(_, batch):
        model.eval()
        with torch.no_grad():
            return self._forward(model, batch)

    def thresholded_output_transform(output):
        # Output is (clip, frame, target)
        y_pred, _, y = output
        y_pred = torch.round(y_pred)
        return y_pred, y

    precision = Precision(thresholded_output_transform, average=False)
    recall = Recall(thresholded_output_transform, average=False)
    f1_score = (precision * recall * 2 / (precision + recall)).mean()
    metrics = {
        'Loss': losses.Loss(criterion),  # reimplementation of Loss, supports a 3-way loss
        'Precision': Precision(thresholded_output_transform),
        'Recall': Recall(thresholded_output_transform),
        'Accuracy': Accuracy(thresholded_output_transform),
        'F1': f1_score,
    }
    train_engine = Engine(_train_batch)
    inference_engine = Engine(_inference)
    for name, metric in metrics.items():
        metric.attach(inference_engine, name)

    def compute_metrics(engine):
        inference_engine.run(cvdataloader)
        results = inference_engine.state.metrics
        output_str_list = [
            "Validation Results - Epoch : {:<5}".format(engine.state.epoch)
        ]
        for metric in metrics:
            output_str_list.append("{} {:<5.2f}".format(metric, results[metric]))
        logger.info(" ".join(output_str_list))

    pbar = ProgressBar(persist=False)
    pbar.attach(train_engine)

    if 'itercv' in config_parameters and config_parameters['itercv'] is not None:
        train_engine.add_event_handler(
            Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
            compute_metrics)
    train_engine.add_event_handler(Events.EPOCH_COMPLETED, compute_metrics)

    # The default scheduler uses patience=3, factor=0.1
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, **config_parameters['scheduler_args'])

    @inference_engine.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        logger.info(f"Scheduling epoch {engine.state.epoch}")
        val_loss = engine.state.metrics['Loss']
        if 'ReduceLROnPlateau' == scheduler.__class__.__name__:
            scheduler.step(val_loss)
        else:
            scheduler.step()

    early_stop_handler = EarlyStopping(
        patience=config_parameters['early_stop'],
        score_function=self._negative_loss,
        trainer=train_engine)
    inference_engine.add_event_handler(Events.EPOCH_COMPLETED, early_stop_handler)

    if config_parameters['save'] == 'everyepoch':
        checkpoint_handler = ModelCheckpoint(outputdir,
                                             'run',
                                             n_saved=5,
                                             require_empty=False)
        train_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                       checkpoint_handler, {'model': model})
        train_engine.add_event_handler(
            Events.ITERATION_COMPLETED(every=config_parameters['itercv']),
            checkpoint_handler, {'model': model})
    else:
        checkpoint_handler = ModelCheckpoint(
            outputdir,
            'run',
            n_saved=1,
            require_empty=False,
            score_function=self._negative_loss,
            global_step_transform=global_step_from_engine(train_engine),  # so the model is saved with the epoch number
            score_name='loss')
        inference_engine.add_event_handler(Events.EPOCH_COMPLETED,
                                           checkpoint_handler, {'model': model})

    train_engine.run(trainloader, max_epochs=config_parameters['epochs'])
    return outputdir
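# self._negative_loss is used above as the score function for both EarlyStopping and
# ModelCheckpoint but is defined elsewhere in the class. Ignite score functions treat
# higher values as better, so a plausible sketch (an assumption, not necessarily the
# actual implementation) simply negates the validation loss attached to the
# inference engine:
def _example_negative_loss(engine):
    # 'Loss' is attached to the inference engine in the metrics dict above
    return -engine.state.metrics['Loss']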
def main(args):
    # Load a pre-defined tokenizer (GPT-2), create config and model
    logger.info("Prepare tokenizer, pretrained model and optimizer - "
                "add special tokens for fine-tuning")
    gpt_tokenizer = GPT2Tokenizer.from_pretrained(args.qgen_model_path,
                                                  cache_dir=args.dataset_cache)
    gpt_tokenizer.sep_token = '<sep>'
    gpt_tokenizer.add_tokens(SPECIAL_TOKENS)
    gpt_tokenizer.add_tokens(AMR_SPECIAL_TOKENS)

    if 'amr' in args.dataset_type:
        qgen = GPT2LMHeadModel.from_pretrained(args.qgen_model_path,
                                               cache_dir=args.dataset_cache)
    else:
        qgen = GPT2ConditionalLMHeadModel.from_pretrained(args.qgen_model_path,
                                                          cache_dir=args.dataset_cache)
    logger.info("Adjust model size to new tokens")
    qgen.resize_token_embeddings(len(gpt_tokenizer))
    logger.info("Set model to GPU usage")
    qgen.to(args.device)
    logger.info("Set up optimizer")
    qgen_optimizer = AdamW(qgen.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)

    bos, eos, ctx, ans, que, pad, gen = gpt_tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)

    # if args.n_gpu > 1:
    if False:
        logger.info("More than 1 GPU for training")
        qgen = torch.nn.DataParallel(qgen)

    logger.info("Prepare datasets")
    if args.use_silver_data:
        data_type = 'Silver'
    else:
        data_type = 'Train'
    dataloader = get_data_loaders(args, gpt_tokenizer, qgen, dataset_name=data_type)

    # Define training function
    def update(engine, batch):
        # remove extra pad from batches
        batch = trim_batch(batch, pad)
        qgen.train()
        loss = torch.tensor([0.0])

        ###################################
        # MLE training with teacher forcing
        ###################################
        if 'sl' in args.learning:
            input_ids, lm_labels, token_type_ids, attention_mask, _, _, _, _ = \
                tuple(input_tensor.to(args.device) for input_tensor in batch)
            loss_ce = qgen(input_ids=input_ids,
                           labels=lm_labels,
                           token_type_ids=token_type_ids)[0]
            loss = apply_loss(engine.state.iteration, qgen_optimizer, loss_ce, args)
        return loss.item()

    trainer = Engine(update)

    # Add progress bar with loss
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(qgen_optimizer, "lr",
                                [(0, args.learning_rate), (args.n_epochs * len(dataloader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Save checkpoints
    checkpoint_handler = ModelCheckpoint(args.checkpoint, 'checkpoint',
                                         save_interval=1, n_saved=20, require_empty=False)
    # "getattr" takes care of distributed encapsulation
    trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                              {'mymodel': getattr(qgen, 'module', qgen)})

    # Save training config (Namespace converted to a plain dict before saving)
    torch.save(vars(args), os.path.join(args.checkpoint, 'training_args.bin'))
    getattr(qgen, 'module', qgen).config.to_json_file(os.path.join(args.checkpoint, CONFIG_NAME))
    gpt_tokenizer.save_vocabulary(args.checkpoint)

    trainer.run(dataloader, max_epochs=args.n_epochs)
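# apply_loss is imported from elsewhere in this project and is not shown here. Given
# its call signature apply_loss(iteration, optimizer, loss, args), one plausible
# sketch (an assumption, not the project's actual helper) performs gradient
# accumulation with gradient clipping before stepping the optimizer; the
# args.gradient_accumulation_steps and args.max_norm flags are assumed to exist for
# illustration only.
def _example_apply_loss(iteration, optimizer, loss, args):
    # scale the loss so accumulated gradients average over the accumulation window
    loss = loss / args.gradient_accumulation_steps
    loss.backward()
    torch.nn.utils.clip_grad_norm_(
        (p for group in optimizer.param_groups for p in group['params']),
        args.max_norm)
    if iteration % args.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss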