    with torch.no_grad():
        # Anneal learning rate
        mu = next(mu_scheme)
        i = engine.state.iteration
        for group in optimizer.param_groups:
            group["lr"] = mu * math.sqrt(1 - 0.999 ** i) / (1 - 0.9 ** i)

    return {
        "elbo": elbo.item(),
        "kl": kl_divergence.item(),
        "sigma": sigma,
        "mu": mu,
    }

# Trainer and metrics
trainer = Engine(step)
metric_names = ["elbo", "kl", "sigma", "mu"]
RunningAverage(output_transform=lambda x: x["elbo"]).attach(trainer, "elbo")
RunningAverage(output_transform=lambda x: x["kl"]).attach(trainer, "kl")
RunningAverage(output_transform=lambda x: x["sigma"]).attach(trainer, "sigma")
RunningAverage(output_transform=lambda x: x["mu"]).attach(trainer, "mu")
ProgressBar().attach(trainer, metric_names=metric_names)

# Model checkpointing
checkpoint_handler = ModelCheckpoint("./", "checkpoint", save_interval=1,
                                     n_saved=3, require_empty=False)
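# A minimal standalone sketch (illustration only, not part of the script above):
# the multiplier applied to the learning rate, sqrt(1 - 0.999**i) / (1 - 0.9**i),
# is Adam's bias-correction term for the default betas (0.9, 0.999), so the
# schedule modulates mu by that correction factor.
import math

for i in (1, 10, 100, 1000, 10000):
    factor = math.sqrt(1 - 0.999 ** i) / (1 - 0.9 ** i)
    print(i, f"{factor:.4f}")  # the factor tends to 1 as i grows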
def test_invert(self):
    set_determinism(seed=0)
    im_fname, seg_fname = [
        make_nifti_image(i)
        for i in create_test_image_3d(101, 100, 107, noise_max=100)
    ]
    transform = Compose([
        LoadImaged(KEYS),
        AddChanneld(KEYS),
        Orientationd(KEYS, "RPS"),
        Spacingd(KEYS, pixdim=(1.2, 1.01, 0.9), mode=["bilinear", "nearest"], dtype=np.float32),
        ScaleIntensityd("image", minv=1, maxv=10),
        RandFlipd(KEYS, prob=0.5, spatial_axis=[1, 2]),
        RandAxisFlipd(KEYS, prob=0.5),
        RandRotate90d(KEYS, spatial_axes=(1, 2)),
        RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True),
        RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True),
        RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"),
        ResizeWithPadOrCropd(KEYS, 100),
        ToTensord(KEYS),
        CastToTyped(KEYS, dtype=torch.uint8),
    ])
    data = [{"image": im_fname, "label": seg_fname} for _ in range(12)]

    # num_workers = 0 for macOS or GPU transforms
    num_workers = 0 if sys.platform == "darwin" or torch.cuda.is_available() else 2
    dataset = CacheDataset(data, transform=transform, progress=False)
    loader = DataLoader(dataset, num_workers=num_workers, batch_size=5)

    # set up engine
    def _train_func(engine, batch):
        self.assertTupleEqual(batch["image"].shape[1:], (1, 100, 100, 100))
        engine.state.output = batch
        engine.fire_event(IterationEvents.MODEL_COMPLETED)
        return engine.state.output

    engine = Engine(_train_func)
    engine.register_events(*IterationEvents)

    # set up testing handler
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="label",
        nearest_interp=True,
        postfix="inverted1",
        num_workers=num_workers,
    ).attach(engine)

    # test different nearest interpolation values
    TransformInverter(
        transform=transform,
        loader=loader,
        output_keys=["image", "label"],
        batch_keys="image",
        nearest_interp=[True, False],
        postfix="inverted2",
        num_workers=num_workers,
    ).attach(engine)

    engine.run(loader, max_epochs=1)
    set_determinism(seed=None)
    self.assertTupleEqual(engine.state.output["image"].shape, (2, 1, 100, 100, 100))
    self.assertTupleEqual(engine.state.output["label"].shape, (2, 1, 100, 100, 100))

    # check the nearest interpolation mode
    for i in engine.state.output["image_inverted1"] + engine.state.output["label_inverted1"]:
        torch.testing.assert_allclose(i.to(torch.uint8).to(torch.float), i.to(torch.float))
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))

    # check labels match
    reverted = engine.state.output["label_inverted1"][-1].detach().cpu().numpy()[0].astype(np.int32)
    original = LoadImaged(KEYS)(data[-1])["label"]
    n_good = np.sum(np.isclose(reverted, original, atol=1e-3))
    reverted_name = engine.state.output["label_meta_dict"]["filename_or_obj"][-1]
    original_name = data[-1]["label"]
    self.assertEqual(reverted_name, original_name)
    print("invert diff", reverted.size - n_good)
    # 25300: 2 workers (cpu, non-macos)
    # 1812: 0 workers (gpu or macos)
    # 1824: torch 1.5.1
    self.assertTrue((reverted.size - n_good) in (25300, 1812, 1824), "diff. in 3 possible values")

    # check the case where different items use different interpolation modes to invert transforms
    for i in engine.state.output["image_inverted2"]:
        # if the interpolation mode is nearest, the accumulated diff should be smaller than 1
        self.assertLess(
            torch.sum(i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 1.0)
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    for i in engine.state.output["label_inverted2"]:
        # if the interpolation mode is not nearest, the accumulated diff should be greater than 10000
        self.assertGreater(
            torch.sum(i.to(torch.float) - i.to(torch.uint8).to(torch.float)).item(), 10000.0)
        self.assertTupleEqual(i.shape, (1, 100, 101, 107))
    ]).to(device)

    y_pred = descriminator(x_gan)
    loss = loss_fn(y_pred, y_gan)

    if args.mixed_precision:
        with amp.scale_loss(loss, optimizer) as scaled_loss:
            scaled_loss.backward()
    else:
        loss.backward()
    optimizer.step()
    return loss

trainer = Engine(_update_model)
RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')
ProgressBar(persist=False).attach(trainer, ['loss'])

if local_rank == 0:
    checkpointer = ModelCheckpoint(
        dirname='checkpoints',
        filename_prefix='model',
        score_name='loss',
        score_function=lambda engine: engine.state.metrics['loss'],
        n_saved=5,
        global_step_transform=global_step_from_engine(trainer),
    )
    trainer.add_event_handler(
        Events.COMPLETED,
        checkpointer,
        to_save={'descriminator': descriminator if not args.distributed else descriminator.module},
    )
def create_setops_evaluator(base_model, classifier, setops_model, metrics={}, device=None):
    """
    Factory function for creating an evaluator for the set-operations models.

    Args:
        base_model (`torch.nn.Module`): the embedding model
        classifier (`torch.nn.Module`): the classification head applied to embeddings
        setops_model (`torch.nn.Module`): the set-operations model
        metrics (dict): map of metric names to Ignite metrics to attach
        device (str, optional): device type specification (default: None).
            Applies to both models and batches.

    Returns:
        Engine: an evaluator engine with a supervised inference function
    """
    if device:
        base_model.to(device)
        classifier.to(device)
        setops_model.to(device)

    def _inference(engine, batch):
        base_model.eval()
        classifier.eval()
        setops_model.eval()

        with torch.no_grad():
            input_a, input_b, target_a, target_b = _prepare_batch(batch, device=device)

            #
            # Apply the classification model
            #
            embed_a = base_model(input_a)
            output_a = classifier(embed_a)
            embed_b = base_model(input_b)
            output_b = classifier(embed_b)

            #
            # Apply the setops model.
            #
            outputs_setopt = setops_model(embed_a, embed_b)
            fake_a, fake_b, a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a, \
                a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a, \
                a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a = \
                [classifier(o) for o in outputs_setopt]
            fake_a_em, fake_b_em = outputs_setopt[:2]

            #
            # Calculate the target setops operations
            #
            target_a_bt = target_a.type(torch.cuda.ByteTensor)
            target_b_bt = target_b.type(torch.cuda.ByteTensor)

            target_a_I_b = target_a_bt & target_b_bt
            target_a_U_b = target_a_bt | target_b_bt
            target_a_S_b = target_a_bt & ~target_a_I_b
            target_b_S_a = target_b_bt & ~target_a_I_b

            target_a_I_b = target_a_I_b.type(torch.cuda.FloatTensor)
            target_a_U_b = target_a_U_b.type(torch.cuda.FloatTensor)
            target_a_S_b = target_a_S_b.type(torch.cuda.FloatTensor)
            target_b_S_a = target_b_S_a.type(torch.cuda.FloatTensor)

        return dict(
            outputs={
                "real class a": output_a,
                "real class b": output_b,
                "fake class a": fake_a,
                "fake class b": fake_b,
                "a_S_b class": a_S_b,
                "b_S_a class": b_S_a,
                "a_U_b class": a_U_b,
                "b_U_a class": b_U_a,
                "a_I_b class": a_I_b,
                "b_I_a class": b_I_a,
                "fake embed a": fake_a_em,
                "fake embed b": fake_b_em,
            },
            targets={
                "class a": target_a,
                "class b": target_b,
                "a_S_b class": target_a_S_b,
                "b_S_a class": target_b_S_a,
                "a_U_b class": target_a_U_b,
                "a_I_b class": target_a_I_b,
                "embed a": embed_a,
                "embed b": embed_b,
            },
        )

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan
    #    - handler to setup learning rate scheduling
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):
        x, y = batch[0], batch[1]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # This can be helpful for XLA to avoid a performance slowdown if we fetch loss.item() every iteration
        if config["log_every_iters"] > 0 and (engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {
            "batch loss": batch_loss,
        }

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
def test_pbar_fail_with_non_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, output_transform=1)
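# For contrast, a minimal sketch of the passing case (hypothetical test name,
# reusing the same update_fn): a callable output_transform that wraps the raw
# engine output into a dict of displayable values is accepted by ProgressBar.
def test_pbar_with_callable_transform():
    engine = Engine(update_fn)
    pbar = ProgressBar()
    pbar.attach(engine, output_transform=lambda out: {"loss": out})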
def _test_distrib_on_metric(device):
    import torch.distributed as dist

    rank = dist.get_rank()
    n_iters = 10
    n_epochs = 3
    batch_size = 10
    n_classes = 10

    data = list(range(n_iters))
    np.random.seed(12)
    all_y_true_batch_values = np.random.randint(
        0, n_classes, size=(dist.get_world_size(), n_epochs * n_iters, batch_size))
    all_y_pred_batch_values = np.random.rand(
        dist.get_world_size(), n_epochs * n_iters, batch_size, n_classes)

    y_true_batch_values = iter(all_y_true_batch_values[rank, ...])
    y_pred_batch_values = iter(all_y_pred_batch_values[rank, ...])

    def update_fn(engine, batch):
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return torch.from_numpy(y_pred_batch), torch.from_numpy(y_true_batch)

    trainer = Engine(update_fn)
    alpha = 0.98
    acc_metric = RunningAverage(
        Accuracy(output_transform=lambda x: [x[0], x[1]], device=device),
        alpha=alpha,
        epoch_bound=False,
    )
    acc_metric.attach(trainer, "running_avg_accuracy")

    running_avg_acc = [None]
    true_acc_metric = Accuracy(device=device)

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        i = engine.state.iteration - 1
        true_acc_metric.reset()
        for j in range(dist.get_world_size()):
            output = (
                torch.from_numpy(all_y_pred_batch_values[j, i, :, :]),
                torch.from_numpy(all_y_true_batch_values[j, i, :]),
            )
            true_acc_metric.update(output)

        batch_acc = true_acc_metric._num_correct * 1.0 / true_acc_metric._num_examples

        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], \
            "{} vs {}".format(engine.state.running_avg_acc,
                              engine.state.metrics["running_avg_accuracy"])

    trainer.run(data, max_epochs=3)
def run(args):
    train_loader, val_loader = get_data_loaders(args.dataset_dir, args.batch_size,
                                                args.val_batch_size, args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = KITTI.num_classes()
    model = LiLaNet(num_classes)

    device_count = torch.cuda.device_count()
    if device_count > 1:
        print("Using %d GPU(s)" % device_count)
        model = nn.DataParallel(model)
        args.batch_size = device_count * args.batch_size
        args.val_batch_size = device_count * args.val_batch_size

    model = model.to(device)

    criterion = nn.CrossEntropyLoss(weight=KITTI.class_weights()).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded checkpoint '{}' (Epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        distance, reflectivity, target = batch
        return (convert_tensor(distance, device=device, non_blocking=non_blocking),
                convert_tensor(reflectivity, device=device, non_blocking=non_blocking),
                convert_tensor(target, device=device, non_blocking=non_blocking))

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        distance, reflectivity, target = _prepare_batch(batch)
        pred = model(distance, reflectivity)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()
        return loss.item()

    trainer = Engine(_update)

    # attach running average metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=['loss'])

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            distance, reflectivity, target = _prepare_batch(batch)
            pred = model(distance, reflectivity)
            return pred, target

    evaluator = Engine(_inference)
    cm = ConfusionMatrix(num_classes)
    IoU(cm, ignore_index=0).attach(evaluator, 'IoU')
    Loss(criterion).attach(evaluator, 'loss')

    pbar2 = ProgressBar(persist=True, desc='Eval Epoch')
    pbar2.attach(evaluator)

    def _global_step_transform(engine, event_name):
        if trainer.state is not None:
            return trainer.state.iteration
        else:
            return 1

    tb_logger = TensorboardLogger(args.log_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag='training', metric_names=['loss']),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(evaluator,
                     log_handler=OutputHandler(tag='validation',
                                               metric_names=['loss', 'IoU'],
                                               global_step_transform=_global_step_transform),
                     event_name=Events.EPOCH_COMPLETED)

    @trainer.on(Events.STARTED)
    def initialize(engine):
        engine.state.exception_raised = False
        if args.resume:
            engine.state.epoch = args.start_epoch

    @evaluator.on(Events.EPOCH_COMPLETED)
    def save_checkpoint(engine):
        epoch = trainer.state.epoch if trainer.state is not None else 1
        iou = engine.state.metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        name = 'epoch{}_mIoU={:.1f}.pth'.format(epoch, mean_iou)
        file = {'model': model.state_dict(), 'epoch': epoch,
                'optimizer': optimizer.state_dict(), 'args': args}
        save(file, args.output_dir, 'checkpoint_{}'.format(name))
        save(model.state_dict(), args.output_dir, 'model_{}'.format(name))

    @trainer.on(Events.EPOCH_COMPLETED)
    def run_validation(engine):
        pbar.log_message("Start Validation - Epoch: [{}/{}]".format(
            engine.state.epoch, engine.state.max_epochs))
        evaluator.run(val_loader)

        metrics = evaluator.state.metrics
        loss = metrics['loss']
        iou = metrics['IoU'] * 100.0
        mean_iou = iou.mean()

        iou_text = ', '.join(['{}: {:.1f}'.format(KITTI.classes[i + 1].name, v)
                              for i, v in enumerate(iou.tolist())])
        pbar.log_message("Validation results - Epoch: [{}/{}]: Loss: {:.2e}\n IoU: {}\n mIoU: {:.1f}"
                         .format(engine.state.epoch, engine.state.max_epochs, loss, iou_text, mean_iou))

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        engine.state.exception_raised = True
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            name = 'epoch{}_exception.pth'.format(trainer.state.epoch)
            file = {'model': model.state_dict(), 'epoch': trainer.state.epoch,
                    'optimizer': optimizer.state_dict()}
            save(file, args.output_dir, 'checkpoint_{}'.format(name))
            save(model.state_dict(), args.output_dir, 'model_{}'.format(name))
        else:
            raise e

    if args.eval_on_start:
        print("Start validation")
        evaluator.run(val_loader, max_epochs=1)

    print("Start training")
    trainer.run(train_loader, max_epochs=args.epochs)
    tb_logger.close()
def train_model(learning_rate, scale, bins, la):
    model = Model()
    model2 = Model2()
    sf = torch.nn.Softmax(dim=1)
    device = torch.device("cuda:3" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model2.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    optimizer2 = torch.optim.Adam(model2.parameters(), lr=learning_rate, weight_decay=1e-4)

    def get_dataloader():
        dl_train = torch.utils.data.DataLoader(
            train_dataset, batch_size=128, shuffle=True, num_workers=0, drop_last=True)
        dl_val = torch.utils.data.DataLoader(
            val_dataset, batch_size=400, shuffle=False, num_workers=0)
        dl_test = torch.utils.data.DataLoader(
            test_dataset, batch_size=400, shuffle=False, num_workers=0)
        return dl_train, dl_test, dl_val

    # combine the predictions from the two views (element-wise max)
    def get_pred_max(y_pred, y_pred2):
        return torch.max(y_pred, y_pred2)

    # fraction of samples whose argmax prediction matches the one-hot target
    def get_acc(y_pred, y):
        acc_1 = 0
        for i in range(len(y)):
            if torch.argmax(y[i]) == torch.argmax(y_pred[i]):
                acc_1 += 1
        return acc_1 / len(y)

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()
        x, x2, y = batch
        y = F.one_hot(y, num_classes=10).float()
        x, x2, y = x.to(device), x2.to(device), y.to(device)
        x.requires_grad_(True)
        y_pred = sf(model(x))
        loss = F.binary_cross_entropy(y_pred, y)
        x.requires_grad_(False)
        loss.backward()
        optimizer.step()
        return loss.item()

    def step2(engine, batch):
        model2.train()
        optimizer2.zero_grad()
        x, x2, y = batch
        y = F.one_hot(y, num_classes=10).float()
        x, x2, y = x.to(device), x2.to(device), y.to(device)
        x2.requires_grad_(True)
        y_pred = sf(model2(x2))
        loss2 = F.binary_cross_entropy(y_pred, y)
        x2.requires_grad_(False)
        loss2.backward()
        optimizer2.step()
        return loss2.item()

    # note: only the first (full-size) validation batch is used
    def val_step():
        with torch.no_grad():
            for batch in dl_val:
                x, x2, y = batch
                y = F.one_hot(y, num_classes=10).float()
                x, x2, y = x.to(device), x2.to(device), y.to(device)
                y_pred = sf(model(x))
                y_pred2 = sf(model2(x2))
                return y_pred, y_pred2, y

    # note: only the first (full-size) test batch is used
    def eval_step():
        with torch.no_grad():
            for batch in dl_test:
                x, x2, y = batch
                y = F.one_hot(y, num_classes=10).float()
                x, x2, y = x.to(device), x2.to(device), y.to(device)
                y_pred = sf(model(x))
                y_pred2 = sf(model2(x2))
                acc1 = get_acc(y_pred, y)
                acc2 = get_acc(y_pred2, y)
                val_pred1, val_pred2, y_val = val_step()
                y_pred_max = get_pred_max(y_pred, y_pred2)
                y_pred_note = y_pred * get_acc(val_pred1, y_val) + y_pred2 * get_acc(val_pred2, y_val)
                acc_m = get_acc(y_pred_max, y)
                acc_m_note = get_acc(y_pred_note, y)
                y_pred_c = calibrate(y_pred, y_pred2, y, val_pred1, val_pred2, y_val, scale, bins, la)
                y_pred_max_c = get_pred_max(y_pred_c, y_pred2)
                y_pred_note2 = y_pred_c + y_pred2
                acc_m_note2 = get_acc(y_pred_note2, y)
                acc_m_c = get_acc(y_pred_max_c, y)
                print(acc_m_note, acc_m_note2)
                return acc1, acc2, acc_m, acc_m_c

    trainer = Engine(step)
    trainer2 = Engine(step2)
    dl_train, dl_test, dl_val = get_dataloader()
    trainer.run(dl_train, max_epochs=epoch)
    trainer2.run(dl_train, max_epochs=epoch)
    acc1, acc2, acc_m, acc_m_c = eval_step()
    return model, acc1, acc2, acc_m, acc_m_c
def process_batch(engine, batch):
    optimizer.zero_grad()
    loss_v = common.calc_loss_dqn(
        batch, net, tgt_net.target_model, gamma=params.gamma, device=device)
    loss_v.backward()
    optimizer.step()
    if engine.state.iteration % params.target_net_sync == 0:
        tgt_net.sync()
    return {
        "loss": loss_v.item(),
        "epsilon": batch_generator.epsilon,
    }

engine = Engine(process_batch)
ptan_ignite.EndOfEpisodeHandler(batch_generator, bound_avg_reward=17.0).attach(engine)
fps_handler.attach(engine, manual_step=True)

@engine.on(ptan_ignite.EpisodeEvents.EPISODE_COMPLETED)
def episode_completed(trainer: Engine):
    print("Episode %d: reward=%s, steps=%s, speed=%.3f frames/s, elapsed=%s" % (
        trainer.state.episode,
        trainer.state.episode_reward,
        trainer.state.episode_steps,
        trainer.state.metrics.get("avg_fps", 0),
        timedelta(seconds=trainer.state.metrics.get("time_passed", 0)),
    ))
def inference(config, local_rank, with_pbar_on_iters=True):
    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = 'cuda'

    torch.backends.cudnn.benchmark = True

    # Load model and weights
    model_weights_filepath = Path(get_artifact_path(config.run_uuid, config.weights_filename))
    assert model_weights_filepath.exists(), \
        "Model weights file '{}' is not found".format(model_weights_filepath.as_posix())

    model = config.model.to(device)
    model = torch.nn.parallel.DistributedDataParallel(model,
                                                      device_ids=[local_rank],
                                                      output_device=local_rank)

    if hasattr(config, "custom_weights_loading"):
        config.custom_weights_loading(model, model_weights_filepath)
    else:
        state_dict = torch.load(model_weights_filepath)
        if not all([k.startswith("module.") for k in state_dict]):
            state_dict = {f"module.{k}": v for k, v in state_dict.items()}
        model.load_state_dict(state_dict)

    model.eval()

    prepare_batch = config.prepare_batch
    non_blocking = getattr(config, "non_blocking", True)
    model_output_transform = getattr(config, "model_output_transform", lambda x: x)

    tta_transforms = getattr(config, "tta_transforms", None)

    def eval_update_function(engine, batch):
        with torch.no_grad():
            x, y, meta = prepare_batch(batch, device=device, non_blocking=non_blocking)

            if tta_transforms is not None:
                y_preds = []
                for t in tta_transforms:
                    t_x = t.augment_image(x)
                    t_y_pred = model(t_x)
                    t_y_pred = model_output_transform(t_y_pred)
                    y_pred = t.deaugment_mask(t_y_pred)
                    y_preds.append(y_pred)

                y_preds = torch.stack(y_preds, dim=0)
                y_pred = torch.mean(y_preds, dim=0)
            else:
                y_pred = model(x)
                y_pred = model_output_transform(y_pred)

            return {"y_pred": y_pred, "y": y, "meta": meta}

    evaluator = Engine(eval_update_function)

    has_targets = getattr(config, "has_targets", False)

    if has_targets:
        def output_transform(output):
            return output['y_pred'], output['y']

        num_classes = config.num_classes
        cm_metric = ConfusionMatrix(num_classes=num_classes, output_transform=output_transform)
        pr = cmPrecision(cm_metric, average=False)
        re = cmRecall(cm_metric, average=False)

        val_metrics = {
            "IoU": IoU(cm_metric),
            "mIoU_bg": mIoU(cm_metric),
            "Accuracy": cmAccuracy(cm_metric),
            "Precision": pr,
            "Recall": re,
            "F1": Fbeta(beta=1.0, output_transform=output_transform),
        }

        if hasattr(config, "metrics") and isinstance(config.metrics, dict):
            val_metrics.update(config.metrics)

        for name, metric in val_metrics.items():
            metric.attach(evaluator, name)

        if dist.get_rank() == 0:
            # Log val metrics:
            mlflow_logger = MLflowLogger()
            mlflow_logger.attach(evaluator,
                                 log_handler=OutputHandler(tag="validation",
                                                           metric_names=list(val_metrics.keys())),
                                 event_name=Events.EPOCH_COMPLETED)

    if dist.get_rank() == 0 and with_pbar_on_iters:
        ProgressBar(persist=True, desc="Inference").attach(evaluator)

    if dist.get_rank() == 0:
        do_save_raw_predictions = getattr(config, "do_save_raw_predictions", True)
        do_save_overlayed_predictions = getattr(config, "do_save_overlayed_predictions", True)

        if not has_targets:
            assert do_save_raw_predictions or do_save_overlayed_predictions, \
                "If there are no targets, either do_save_overlayed_predictions or " \
                "do_save_raw_predictions should be defined in the config and set to True"

        # Save predictions
        if do_save_raw_predictions:
            raw_preds_path = config.output_path / "raw"
            raw_preds_path.mkdir(parents=True)
            evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                        save_raw_predictions_with_geoinfo,
                                        raw_preds_path)

        if do_save_overlayed_predictions:
            overlayed_preds_path = config.output_path / "overlay"
            overlayed_preds_path.mkdir(parents=True)
            evaluator.add_event_handler(Events.ITERATION_COMPLETED,
                                        save_overlayed_predictions,
                                        overlayed_preds_path,
                                        img_denormalize_fn=config.img_denormalize,
                                        palette=default_palette)

    evaluator.add_event_handler(Events.EXCEPTION_RAISED, report_exception)

    # Run evaluation
    evaluator.run(config.data_loader)
def train():
    parser = ArgumentParser()
    parser.add_argument("--dataset_path", type=str, default="",
                        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache", type=str, default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint", type=str, default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates", type=int, default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_history", type=int, default=2,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=16,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4,
                        help="Batch size for validation")
    # parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
    #                     help="Accumulate gradients on several steps")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1,
                        help="Accumulate gradients on several steps")
    # parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--lr", type=float, default=1e-4, help="Learning rate")
    parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient")
    parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--init_model", default="model/pytorch_kogpt2_676e9bcfa7.params", type=str,
                        help="The model checkpoint for weights initialization. "
                             "Leave None if you want to train a model from scratch.")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank)  # logger.warning: printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    config = GPT2Config(vocab_size=50000)
    model = GPT2DoubleHeadsModel(config)
    if args.init_model:
        print("Load model from ", args.init_model)
        model.load_state_dict(torch.load(args.init_model), strict=False)
    model.to(args.device)
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed
    # (order is important: distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(
            input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids,
            mc_labels=mc_labels, lm_labels=lm_labels)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids, token_type_ids=token_type_ids, mc_token_ids=mc_token_ids)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100),
                           output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.init_model)
        tb_logger = TensorboardLogger(log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME))
        # tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(os.path.join(log_dir, checkpoint_handler._saved[-1][1]),
                  os.path.join(log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def train():
    config_file = "configs/train_daily_dialog_emotion_action_topic_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", config.local_rank)  # logger.warning: printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed
    # (order is important: distributed should be the last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model, device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = \
            tuple(input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids, token_action_ids)
        loss = (lm_loss * config.lm_coef + mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(config.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids, token_emotion_ids, token_action_ids = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids, mc_token_ids, token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if config.distributed:
        trainer.add_event_handler(Events.EPOCH_STARTED,
                                  lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(Events.EPOCH_STARTED,
                                    lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr), (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                           output_transform=lambda x: (x[0][0], x[1][0])),
               "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))}
    metrics.update({"average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], config),
                    "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)})
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training", metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir, 'checkpoint',
                                             save_interval=1, n_saved=3)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})  # "getattr" takes care of distributed encapsulation

        torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(checkpoint_handler._saved[-1][1][-1],
                  os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME))  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
def main(local_rank):
    params = init_parms(local_rank)
    device = params.get('device')
    model = ASRModel(input_features=config.num_mel_banks,
                     num_classes=config.vocab_size).to(device)
    logger.info(f'Model initialized with {get_model_size(model):.3f}M parameters')
    optimizer = Ranger(model.parameters(), lr=config.lr, eps=1e-5)
    model = DistributedDataParallel(model, device_ids=[local_rank],
                                    output_device=local_rank, check_reduction=True)
    load_checkpoint(model, optimizer, params)
    print(f"Loaded model on {local_rank}")
    start_epoch = params['start_epoch']
    sup_criterion = CustomCTCLoss()
    unsup_criterion = UDALoss()

    if args.local_rank == 0:
        tb_logger = TensorboardLogger(log_dir=log_path)
        pbar = ProgressBar(persist=True, desc="Training")
        pbar_valid = ProgressBar(persist=True, desc="Validation Clean")
        pbar_valid_other = ProgressBar(persist=True, desc="Validation Other")
        pbar_valid_airtel = ProgressBar(persist=True, desc="Validation Airtel")
        pbar_valid_airtel_payments = ProgressBar(persist=True, desc="Validation Airtel Payments")
        timer = Timer(average=True)

    best_meter = params.get('best_stats', BestMeter())

    trainCleanPath = os.path.join(lmdb_root_path, 'train-labelled')
    trainOtherPath = os.path.join(lmdb_root_path, 'train-unlabelled')
    trainCommonVoicePath = os.path.join(lmdb_commonvoice_root_path, 'train-labelled-en')
    trainAirtelPath = os.path.join(lmdb_airtel_root_path, 'train-labelled-en')
    trainAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path, 'train-labelled-en')
    testCleanPath = os.path.join(lmdb_root_path, 'test-clean')
    testOtherPath = os.path.join(lmdb_root_path, 'test-other')
    testAirtelPath = os.path.join(lmdb_airtel_root_path, 'test-labelled-en')
    testAirtelPaymentsPath = os.path.join(lmdb_airtel_payments_root_path, 'test-labelled-en')
    devOtherPath = os.path.join(lmdb_root_path, 'dev-other')

    train_clean = lmdbMultiDataset(
        roots=[trainCleanPath, trainOtherPath, trainCommonVoicePath,
               trainAirtelPath, trainAirtelPaymentsPath],
        transform=image_train_transform)
    train_other = lmdbMultiDataset(roots=[devOtherPath], transform=image_train_transform)

    test_clean = lmdbMultiDataset(roots=[testCleanPath], transform=image_val_transform)
    test_other = lmdbMultiDataset(roots=[testOtherPath], transform=image_val_transform)
    test_airtel = lmdbMultiDataset(roots=[testAirtelPath], transform=image_val_transform)
    test_payments_airtel = lmdbMultiDataset(roots=[testAirtelPaymentsPath],
                                            transform=image_val_transform)

    logger.info(
        f'Loaded Train & Test Datasets, train_labbeled={len(train_clean)}, '
        f'train_unlabbeled={len(train_other)}, test_clean={len(test_clean)}, '
        f'test_other={len(test_other)}, test_airtel={len(test_airtel)}, '
        f'test_payments_airtel={len(test_payments_airtel)} examples')

    def train_update_function(engine, _):
        optimizer.zero_grad()

        # Supervised gt, pred
        imgs_sup, labels_sup, label_lengths = next(engine.state.train_loader_labbeled)
        imgs_sup = imgs_sup.cuda(local_rank, non_blocking=True)
        probs_sup = model(imgs_sup)

        # Unsupervised gt, pred
        # imgs_unsup, augmented_imgs_unsup = next(engine.state.train_loader_unlabbeled)
        # with torch.no_grad():
        #     probs_unsup = model(imgs_unsup.to(device))
        # probs_aug_unsup = model(augmented_imgs_unsup.to(device))

        sup_loss = sup_criterion(probs_sup, labels_sup, label_lengths)
        # unsup_loss = unsup_criterion(probs_unsup, probs_aug_unsup)

        # Blend supervised and unsupervised losses till unsupervision_warmup_epoch
        # alpha = get_alpha(engine.state.epoch)
        # final_loss = ((1 - alpha) * sup_loss) + (alpha * unsup_loss)
        # final_loss = sup_loss

        sup_loss.backward()
        optimizer.step()
        return sup_loss.item()

    @torch.no_grad()
    def validate_update_function(engine, batch):
        img, labels, label_lengths = batch
        y_pred = model(img.cuda(local_rank, non_blocking=True))
        if np.random.rand() > 0.99:
            pred_sentences = get_most_probable(y_pred)
            labels_list = labels.tolist()
            idx = 0
            for i, length in enumerate(label_lengths.cpu().tolist()):
                pred_sentence = pred_sentences[i]
                gt_sentence = get_sentence(labels_list[idx:idx + length])
                idx += length
                print(f"Pred sentence: {pred_sentence}, GT: {gt_sentence}")
        return (y_pred, labels, label_lengths)

    train_sampler_labbeled = torch.utils.data.distributed.DistributedSampler(
        train_clean, num_replicas=3, rank=args.local_rank)
    train_sampler_unlabbeled = torch.utils.data.distributed.DistributedSampler(
        train_other, num_replicas=3, rank=args.local_rank)
    test_sampler_clean = torch.utils.data.distributed.DistributedSampler(
        test_clean, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_other = torch.utils.data.distributed.DistributedSampler(
        test_other, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel = torch.utils.data.distributed.DistributedSampler(
        test_airtel, num_replicas=3, rank=args.local_rank, shuffle=False)
    test_sampler_airtel_payments = torch.utils.data.distributed.DistributedSampler(
        test_payments_airtel, num_replicas=3, rank=args.local_rank, shuffle=False)

    train_loader_labbeled_loader = torch.utils.data.DataLoader(
        train_clean, batch_size=train_batch_size // 3, sampler=train_sampler_labbeled,
        num_workers=config.workers // 3, pin_memory=True, collate_fn=allign_collate)
    train_loader_unlabbeled_loader = torch.utils.data.DataLoader(
        train_other, batch_size=train_batch_size * 4, sampler=train_sampler_unlabbeled,
        num_workers=config.workers // 3, pin_memory=True, collate_fn=allign_collate)
    test_loader_clean = torch.utils.data.DataLoader(
        test_clean, batch_size=1, sampler=test_sampler_clean,
        num_workers=config.workers // 3, pin_memory=True, collate_fn=allign_collate)
    test_loader_other = torch.utils.data.DataLoader(
        test_other, batch_size=1, sampler=test_sampler_other,
        num_workers=config.workers // 3, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel = torch.utils.data.DataLoader(
        test_airtel, batch_size=1, sampler=test_sampler_airtel,
        num_workers=config.workers // 3, pin_memory=True, collate_fn=allign_collate)
    test_loader_airtel_payments = torch.utils.data.DataLoader(
        test_payments_airtel, batch_size=1, sampler=test_sampler_airtel_payments,
        num_workers=config.workers // 3, pin_memory=True, collate_fn=allign_collate)

    trainer = Engine(train_update_function)
    iteration_log_step = int(0.33 * len(train_loader_labbeled_loader))
    evaluator_clean = Engine(validate_update_function)
    evaluator_other = Engine(validate_update_function)
    evaluator_airtel = Engine(validate_update_function)
    evaluator_airtel_payments = Engine(validate_update_function)

    metrics = {'wer': WordErrorRate(), 'cer': CharacterErrorRate()}
    for name, metric in metrics.items():
        metric.attach(evaluator_clean, name)
        metric.attach(evaluator_other, name)
        metric.attach(evaluator_airtel, name)
        metric.attach(evaluator_airtel_payments, name)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=config.lr_gamma,
        patience=int(config.epochs * 0.05), verbose=True, threshold_mode="abs",
        cooldown=int(config.epochs * 0.025), min_lr=1e-5)

    if args.local_rank == 0:
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   output_transform=lambda loss: {'loss': loss}),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(trainer, log_handler=WeightsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(trainer, log_handler=WeightsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=GradsScalarHandler(model),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer, log_handler=GradsHistHandler(model),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_clean,
                         log_handler=OutputHandler(tag="validation_clean",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_other,
                         log_handler=OutputHandler(tag="validation_other",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel,
                         log_handler=OutputHandler(tag="validation_airtel",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        tb_logger.attach(evaluator_airtel_payments,
                         log_handler=OutputHandler(tag="validation_airtel_payments",
                                                   metric_names=["wer", "cer"],
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)
        pbar.attach(trainer, output_transform=lambda x: {'loss': x})
        pbar_valid.attach(evaluator_clean, ['wer', 'cer'],
                          event_name=Events.EPOCH_COMPLETED,
                          closing_event_name=Events.COMPLETED)
        pbar_valid_other.attach(evaluator_other, ['wer', 'cer'],
                                event_name=Events.EPOCH_COMPLETED,
                                closing_event_name=Events.COMPLETED)
        pbar_valid_airtel.attach(evaluator_airtel, ['wer', 'cer'],
                                 event_name=Events.EPOCH_COMPLETED,
                                 closing_event_name=Events.COMPLETED)
        pbar_valid_airtel_payments.attach(evaluator_airtel_payments, ['wer', 'cer'],
                                          event_name=Events.EPOCH_COMPLETED,
                                          closing_event_name=Events.COMPLETED)
        timer.attach(trainer)

    @trainer.on(Events.STARTED)
    def set_init_epoch(engine):
        engine.state.epoch = params['start_epoch']
        logger.info(f'Initial epoch for trainer set to {engine.state.epoch}')

    @trainer.on(Events.EPOCH_STARTED)
    def set_model_train(engine):
        if hasattr(engine.state, 'train_loader_labbeled'):
            del engine.state.train_loader_labbeled
        engine.state.train_loader_labbeled = iter(train_loader_labbeled_loader)
        # engine.state.train_loader_unlabbeled = iter(train_loader_unlabbeled_loader)

    @trainer.on(Events.ITERATION_COMPLETED)
    def iteration_completed(engine):
        if (engine.state.iteration % iteration_log_step == 0) and (engine.state.iteration > 0):
            engine.state.epoch += 1
            train_clean.set_epochs(engine.state.epoch)
            train_other.set_epochs(engine.state.epoch)
            model.eval()
            logger.info('Model set to eval mode')
            evaluator_clean.run(test_loader_clean)
            evaluator_other.run(test_loader_other)
            evaluator_airtel.run(test_loader_airtel)
            evaluator_airtel_payments.run(test_loader_airtel_payments)
            model.train()
            logger.info('Model set back to train mode')

    if args.local_rank == 0:
        @evaluator_other.on(Events.EPOCH_COMPLETED)
        def save_checkpoints(engine):
            metrics = engine.state.metrics
            wer = metrics['wer']
            cer = metrics['cer']
            epoch = trainer.state.epoch
            scheduler.step(wer)
            save_checkpoint(model, optimizer, best_meter, wer, cer, epoch)
            best_meter.update(wer, cer, epoch)

        @trainer.on(Events.EPOCH_COMPLETED)
        def after_complete(engine):
            logger.info('Epoch {} done. Time per batch: {:.3f}[s]'.format(
                engine.state.epoch, timer.value()))
            timer.reset()

    trainer.run(train_loader_labbeled_loader, max_epochs=epochs)

    if args.local_rank == 0:
        tb_logger.close()
def test_run_with_max_iters_greater_than_epoch_length():
    max_iters = 73
    engine = Engine(lambda e, b: 1)
    engine.run([0] * 20, max_iters=max_iters)
    assert engine.state.iteration == max_iters
def test_linear_scheduler():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler = LinearCyclicalScheduler(optimizer, 'lr', 1, 0, 10)
    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run([0] * 10, max_epochs=2)

    assert lrs == list(map(pytest.approx, [
        # Cycle 1
        1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
        # Cycle 2
        1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
    ]))

    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer, 'lr', 1, 0, 10, cycle_mult=2)

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    lrs = []
    trainer.run([0] * 10, max_epochs=3)

    assert lrs == list(map(pytest.approx, [
        # Cycle 1
        1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
        # Cycle 2
        1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1,
        0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9,
    ]))

    # With float cycle_size
    optimizer = torch.optim.SGD([tensor], lr=0)
    scheduler = LinearCyclicalScheduler(optimizer, 'lr',
                                        start_value=1.2, end_value=0.2,
                                        cycle_size=10.00000012, cycle_mult=1.0)

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)

    lrs = []
    trainer.run([0] * 10, max_epochs=2)

    assert lrs == list(map(pytest.approx, [
        # Cycle 1
        1.2, 1.0, 0.8, 0.6, 0.4, 0.2, 0.4, 0.6, 0.8, 1.0,
        # Cycle 2
        1.2, 1.0, 0.8, 0.6, 0.4, 0.2, 0.4, 0.6, 0.8, 1.0,
    ]))
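# A minimal sketch of the related simulate_values API (hypothetical test name):
# ParamScheduler subclasses can report [event_index, value] pairs without
# touching a real optimizer, which is a cheap way to sanity-check a schedule
# before wiring it into a trainer.
def test_linear_scheduler_simulate_values_sketch():
    simulated = LinearCyclicalScheduler.simulate_values(
        num_events=20, param_name="lr", start_value=1.0, end_value=0.0, cycle_size=10)
    # the first few values match the lrs recorded in the test above
    assert [v for i, v in simulated][:3] == list(map(pytest.approx, [1.0, 0.8, 0.6]))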
def main(
    architecture,
    batch_size,
    length_scale,
    centroid_size,
    learning_rate,
    l_gradient_penalty,
    gamma,
    weight_decay,
    final_model,
    output_dir,
):
    writer = SummaryWriter(log_dir=f"runs/{output_dir}")

    ds = all_datasets["CIFAR10"]()
    input_size, num_classes, dataset, test_dataset = ds

    # Split up training set
    idx = list(range(len(dataset)))
    random.shuffle(idx)

    if final_model:
        train_dataset = dataset
        val_dataset = test_dataset
    else:
        # 80/20 train/validation split
        train_size = int(len(dataset) * 0.8)
        train_dataset = torch.utils.data.Subset(dataset, idx[:train_size])
        val_dataset = torch.utils.data.Subset(dataset, idx[train_size:])
        # Test-time preprocessing for validation
        val_dataset.transform = test_dataset.transform

    if architecture == "WRN":
        model_output_size = 640
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = WideResNet()
    elif architecture == "ResNet18":
        model_output_size = 512
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet18()
    elif architecture == "ResNet50":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet50()
    elif architecture == "ResNet110":
        model_output_size = 2048
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = resnet110()
    elif architecture == "DenseNet121":
        model_output_size = 1024
        epochs = 200
        milestones = [60, 120, 160]
        feature_extractor = densenet121()
    else:
        raise ValueError(f"Unknown architecture: {architecture}")

    # Adapted ResNet from:
    # https://github.com/kuangliu/pytorch-cifar/blob/master/models/resnet.py
    feature_extractor.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
    feature_extractor.maxpool = torch.nn.Identity()
    feature_extractor.fc = torch.nn.Identity()

    if centroid_size is None:
        centroid_size = model_output_size

    model = ResNet_DUQ(
        feature_extractor,
        num_classes,
        centroid_size,
        model_output_size,
        length_scale,
        gamma,
    )
    model = model.cuda()

    optimizer = torch.optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.2)

    def calc_gradients_input(x, y_pred):
        gradients = torch.autograd.grad(
            outputs=y_pred,
            inputs=x,
            grad_outputs=torch.ones_like(y_pred),
            create_graph=True,
        )[0]
        gradients = gradients.flatten(start_dim=1)
        return gradients

    def calc_gradient_penalty(x, y_pred):
        gradients = calc_gradients_input(x, y_pred)
        # L2 norm
        grad_norm = gradients.norm(2, dim=1)
        # Two-sided penalty
        gradient_penalty = ((grad_norm - 1) ** 2).mean()
        return gradient_penalty

    def step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch
        x, y = x.cuda(), y.cuda()
        x.requires_grad_(True)

        y_pred = model(x)
        y = F.one_hot(y, num_classes).float()

        loss = F.binary_cross_entropy(y_pred, y, reduction="mean")
        if l_gradient_penalty > 0:
            gp = calc_gradient_penalty(x, y_pred)
            loss += l_gradient_penalty * gp

        loss.backward()
        optimizer.step()

        x.requires_grad_(False)

        with torch.no_grad():
            model.eval()
            model.update_embeddings(x, y)

        return loss.item()

    def eval_step(engine, batch):
        model.eval()
        x, y = batch
        x, y = x.cuda(), y.cuda()
        x.requires_grad_(True)
        y_pred = model(x)
        return {"x": x, "y": y, "y_pred": y_pred}

    trainer = Engine(step)
    evaluator = Engine(eval_step)

    metric = Average()
    metric.attach(trainer, "loss")

    metric = Accuracy(output_transform=lambda out: (out["y_pred"], out["y"]))
    metric.attach(evaluator, "accuracy")

    def bce_output_transform(out):
        return (out["y_pred"], F.one_hot(out["y"], num_classes).float())

    metric = Loss(F.binary_cross_entropy, output_transform=bce_output_transform)
    metric.attach(evaluator, "bce")

    metric = Loss(calc_gradient_penalty, output_transform=lambda out: (out["x"], out["y_pred"]))
    metric.attach(evaluator, "gradient_penalty")

    pbar = ProgressBar(dynamic_ncols=True)
    pbar.attach(trainer)

    kwargs = {"num_workers": 4, "pin_memory": True}
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, drop_last=True, **kwargs
    )
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, **kwargs)
    test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, **kwargs)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_results(trainer):
        metrics = trainer.state.metrics
        loss = metrics["loss"]
        print(f"Train - Epoch: {trainer.state.epoch} Loss: {loss:.2f}")
        writer.add_scalar("Loss/train", loss, trainer.state.epoch)

        if trainer.state.epoch > (epochs - 5):
            accuracy, auroc = get_cifar_svhn_ood(model)
            print(f"Test Accuracy: {accuracy}, AUROC: {auroc}")
            writer.add_scalar("OoD/test_accuracy", accuracy, trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc", auroc, trainer.state.epoch)

            accuracy, auroc = get_auroc_classification(val_dataset, model)
            print(f"AUROC - uncertainty: {auroc}")
            writer.add_scalar("OoD/val_accuracy", accuracy, trainer.state.epoch)
            writer.add_scalar("OoD/roc_auc_classification", auroc, trainer.state.epoch)

        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        acc = metrics["accuracy"]
        bce = metrics["bce"]
        GP = metrics["gradient_penalty"]
        loss = bce + l_gradient_penalty * GP

        print(
            f"Valid - Epoch: {trainer.state.epoch} "
            f"Acc: {acc:.4f} "
            f"Loss: {loss:.2f} "
            f"BCE: {bce:.2f} "
            f"GP: {GP:.2f} "
        )

        writer.add_scalar("Loss/valid", loss, trainer.state.epoch)
        writer.add_scalar("BCE/valid", bce, trainer.state.epoch)
        writer.add_scalar("GP/valid", GP, trainer.state.epoch)
        writer.add_scalar("Accuracy/valid", acc, trainer.state.epoch)

        scheduler.step()

    trainer.run(train_loader, max_epochs=epochs)
    evaluator.run(test_loader)

    acc = evaluator.state.metrics["accuracy"]
    print(f"Test - Accuracy {acc:.4f}")

    torch.save(model.state_dict(), f"runs/{output_dir}/model.pt")
    writer.close()
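# Self-contained illustration of the two-sided gradient penalty used in `main`
# above: it penalizes deviations of ||dy/dx||_2 from 1. `toy_model` is a
# hypothetical stand-in; only torch is assumed.
import torch

def toy_gradient_penalty_demo():
    toy_model = torch.nn.Linear(4, 2)
    x = torch.randn(8, 4, requires_grad=True)
    y_pred = toy_model(x)
    # Sum-of-outputs gradient w.r.t. the input, kept in the graph so the
    # penalty itself is differentiable w.r.t. the model parameters.
    gradients = torch.autograd.grad(
        outputs=y_pred,
        inputs=x,
        grad_outputs=torch.ones_like(y_pred),
        create_graph=True,
    )[0].flatten(start_dim=1)
    grad_norm = gradients.norm(2, dim=1)
    return ((grad_norm - 1) ** 2).mean()  # scalar two-sided penalty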
def test_concat_scheduler():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.0, cycle_size=10)
    scheduler_2 = CosineAnnealingScheduler(optimizer, "lr", start_value=0.0, end_value=1.0, cycle_size=10)
    durations = [10]
    concat_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2], durations=durations, save_history=True
    )

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs, schedulers=[scheduler_1, scheduler_2], durations=durations
    )

    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run(data, max_epochs=max_epochs)

    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1 of the LinearCyclicalScheduler
                1.0, 0.8, 0.6, 0.4, 0.2, 0.0, 0.2, 0.4, 0.6, 0.8,
                # Cycle 1 of the CosineAnnealingScheduler
                0.0,
                0.02447174185242318,
                0.09549150281252627,
                0.20610737385376332,
                0.3454915028125263,
                0.5,
                0.6545084971874737,
                0.7938926261462365,
                0.9045084971874737,
                0.9755282581475768,
            ],
        )
    )

    state_lrs = trainer.state.param_history['lr']
    assert len(state_lrs) == len(lrs)
    # Unpack singleton lists
    assert [group[0] for group in state_lrs] == lrs
    assert lrs == pytest.approx([v for i, v in simulated_values])
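# The CosineAnnealingScheduler values asserted above are consistent with the
# half-cosine ramp (an assumption inferred from the numbers, not from ignite's
# internals):  v(t) = start + (end - start) * (1 - cos(pi * t / cycle_size)) / 2
import math

def expected_cosine_values(start=0.0, end=1.0, cycle_size=10):
    # expected_cosine_values()[1] ~= 0.02447174185242318, matching the test.
    return [
        start + (end - start) * (1 - math.cos(math.pi * t / cycle_size)) / 2
        for t in range(cycle_size)
    ]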
def test_attach_fail_with_string():
    engine = Engine(update_fn)
    pbar = ProgressBar()

    with pytest.raises(TypeError):
        pbar.attach(engine, 'a')
def test_concat_scheduler_3_schedulers():
    tensor = torch.zeros([1], requires_grad=True)
    optimizer = torch.optim.SGD([tensor], lr=0)

    scheduler_1 = LinearCyclicalScheduler(optimizer, "lr", start_value=1.0, end_value=0.5, cycle_size=20)
    scheduler_2 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.45, cycle_size=10)
    scheduler_3 = LinearCyclicalScheduler(optimizer, "lr", start_value=0.5, end_value=0.0, cycle_size=20)
    durations = [10, 5]

    concat_scheduler = ConcatScheduler(
        schedulers=[scheduler_1, scheduler_2, scheduler_3], durations=durations, save_history=True
    )

    data = [0] * 10
    max_epochs = 2
    simulated_values = ConcatScheduler.simulate_values(
        num_events=len(data) * max_epochs,
        schedulers=[scheduler_1, scheduler_2, scheduler_3],
        durations=durations,
    )

    lrs = []

    def save_lr(engine):
        lrs.append(optimizer.param_groups[0]['lr'])

    trainer = Engine(lambda engine, batch: None)
    trainer.add_event_handler(Events.ITERATION_STARTED, concat_scheduler)
    trainer.add_event_handler(Events.ITERATION_COMPLETED, save_lr)
    trainer.run(data, max_epochs=max_epochs)

    assert lrs == list(
        map(
            pytest.approx,
            [
                # Cycle 1 of the first LinearCyclicalScheduler
                1.0, 0.95, 0.9, 0.85, 0.8, 0.75, 0.7, 0.65, 0.6, 0.55,
                # Cycle 1 of the second LinearCyclicalScheduler
                0.5, 0.49, 0.48, 0.47, 0.46,
                # Cycle 1 of the third LinearCyclicalScheduler
                0.5, 0.45, 0.4, 0.35, 0.3,
            ],
        )
    )

    state_lrs = trainer.state.param_history['lr']
    assert len(state_lrs) == len(lrs)
    # Unpack singleton lists
    assert [group[0] for group in state_lrs] == lrs
    assert lrs == pytest.approx([v for i, v in simulated_values])
def test_integration():
    n_iters = 100
    batch_size = 10
    n_classes = 10
    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    loss_values = iter(range(n_iters))

    def update_fn(engine, batch):
        loss_value = next(loss_values)
        y_true_batch = next(y_true_batch_values)
        y_pred_batch = next(y_pred_batch_values)
        return (
            loss_value,
            torch.from_numpy(y_pred_batch),
            torch.from_numpy(y_true_batch),
        )

    trainer = Engine(update_fn)
    alpha = 0.98

    acc_metric = RunningAverage(Accuracy(output_transform=lambda x: [x[1], x[2]]), alpha=alpha)
    acc_metric.attach(trainer, "running_avg_accuracy")

    avg_output = RunningAverage(output_transform=lambda x: x[0], alpha=alpha)
    avg_output.attach(trainer, "running_avg_output")

    running_avg_acc = [None]

    @trainer.on(Events.ITERATION_COMPLETED)
    def manual_running_avg_acc(engine):
        _, y_pred, y = engine.state.output
        indices = torch.max(y_pred, 1)[1]
        correct = torch.eq(indices, y).view(-1)
        num_correct = torch.sum(correct).item()
        num_examples = correct.shape[0]
        batch_acc = num_correct * 1.0 / num_examples
        if running_avg_acc[0] is None:
            running_avg_acc[0] = batch_acc
        else:
            running_avg_acc[0] = running_avg_acc[0] * alpha + (1.0 - alpha) * batch_acc
        engine.state.running_avg_acc = running_avg_acc[0]

    @trainer.on(Events.EPOCH_STARTED)
    def running_avg_output_init(engine):
        engine.state.running_avg_output = None

    @trainer.on(Events.ITERATION_COMPLETED)
    def running_avg_output_update(engine):
        if engine.state.running_avg_output is None:
            engine.state.running_avg_output = engine.state.output[0]
        else:
            engine.state.running_avg_output = (
                engine.state.running_avg_output * alpha + (1.0 - alpha) * engine.state.output[0]
            )

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_acc_values(engine):
        assert engine.state.running_avg_acc == engine.state.metrics["running_avg_accuracy"], "{} vs {}".format(
            engine.state.running_avg_acc, engine.state.metrics["running_avg_accuracy"]
        )

    @trainer.on(Events.ITERATION_COMPLETED)
    def assert_equal_running_avg_output_values(engine):
        assert engine.state.running_avg_output == engine.state.metrics["running_avg_output"], "{} vs {}".format(
            engine.state.running_avg_output, engine.state.metrics["running_avg_output"]
        )

    np.random.seed(10)
    running_avg_acc = [None]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)

    # Second run with fresh iterators to check that the metrics reset correctly
    running_avg_acc = [None]
    n_iters = 10
    batch_size = 10
    n_classes = 10
    data = list(range(n_iters))
    loss_values = iter(range(n_iters))
    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_iters, batch_size)))
    y_pred_batch_values = iter(np.random.rand(n_iters, batch_size, n_classes))
    trainer.run(data, max_epochs=1)
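# The manual bookkeeping in test_integration mirrors RunningAverage's
# exponential moving average: m_0 = x_0, then m_t = alpha * m_{t-1} + (1 - alpha) * x_t.
# A minimal standalone version (the function name is hypothetical):
def running_average(values, alpha=0.98):
    avg = None
    for v in values:
        avg = v if avg is None else alpha * avg + (1.0 - alpha) * v
    return avg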
def test_terminate():
    engine = Engine(lambda e, b: 1)
    assert not engine.should_terminate
    engine.terminate()
    assert engine.should_terminate
def create_setops_trainer(
    base_model,
    classifier,
    setops_model,
    optimizer,
    criterion1,
    criterion2,
    params_object,
    metrics=None,
    device=None,
):
    """Factory function for creating a trainer for the set-operations models.

    Args:
        base_model (`torch.nn.Module`): the feature-extraction model.
        classifier (`torch.nn.Module`): the classification head.
        setops_model (`torch.nn.Module`): the set-operations model.
        optimizer (`torch.optim.Optimizer`): the optimizer to use.
        criterion1 (torch.nn loss function): classification loss.
        criterion2 (torch.nn loss function): reconstruction (embedding) loss.
        params_object: experiment parameters controlling the loss terms.
        metrics (dict, optional): metrics to attach to the engine.
        device (str, optional): device type specification (default: None).
            Applies to both models and batches.

    Returns:
        Engine: a trainer engine with supervised update function.
    """
    if metrics is None:
        metrics = {}

    if device:
        base_model.to(device)
        classifier.to(device)
        setops_model.to(device)

    def _update(engine, batch):
        if params_object.train_base:
            base_model.train()
        else:
            base_model.eval()
        classifier.train()
        setops_model.train()

        optimizer.zero_grad()

        input_a, input_b, target_a, target_b = _prepare_batch(batch, device=device)

        #
        # Apply the classification model
        #
        with conditional(not params_object.train_base, torch.no_grad()):
            embed_a = base_model(input_a)
            embed_b = base_model(input_b)

        output_a = classifier(embed_a)
        output_b = classifier(embed_b)

        #
        # Apply the set-operations model.
        #
        outputs_setopt = setops_model(embed_a, embed_b)
        fake_a, fake_b, a_S_b, b_S_a, a_U_b, b_U_a, a_I_b, b_I_a, \
            a_S_b_b, b_S_a_a, a_I_b_b, b_I_a_a, a_U_b_b, b_U_a_a, \
            a_S_b_I_a, b_S_a_I_b, a_S_a_I_b, b_S_b_I_a = \
            [classifier(o) for o in outputs_setopt]
        fake_a_em, fake_b_em, a_S_b_em, b_S_a_em, a_U_b_em, b_U_a_em, a_I_b_em, b_I_a_em, \
            a_S_b_b_em, b_S_a_a_em, a_I_b_b_em, b_I_a_a_em, a_U_b_b_em, b_U_a_a_em, \
            a_S_b_I_a_em, b_S_a_I_b_em, a_S_a_I_b_em, b_S_b_I_a_em = outputs_setopt

        loss_class = criterion1(output_a, target_a) + criterion1(output_b, target_b)
        loss_class_out = criterion1(fake_a, target_a) + criterion1(fake_b, target_b)

        if params_object.mc_toggle:
            loss_recon = criterion2(embed_a, fake_a_em) + criterion2(embed_b, fake_b_em)
            return_loss_recon = loss_recon.item()
        else:
            loss_recon = 0
            return_loss_recon = 0

        #
        # Calculate the target set operations
        #
        target_a = target_a.type(torch.cuda.ByteTensor)
        target_b = target_b.type(torch.cuda.ByteTensor)
        target_a_I_b = target_a & target_b
        target_a_U_b = target_a | target_b
        target_a_S_b = target_a & ~target_a_I_b
        target_b_S_a = target_b & ~target_a_I_b
        target_a_I_b = target_a_I_b.type(torch.cuda.FloatTensor)
        target_a_U_b = target_a_U_b.type(torch.cuda.FloatTensor)
        target_a_S_b = target_a_S_b.type(torch.cuda.FloatTensor)
        target_b_S_a = target_b_S_a.type(torch.cuda.FloatTensor)

        loss_class_S = criterion1(a_S_b, target_a_S_b) + criterion1(b_S_a, target_b_S_a)
        loss_class_U = criterion1(a_U_b, target_a_U_b)
        loss_class_I = criterion1(a_I_b, target_a_I_b)

        if params_object.tautology_class_toggle:
            loss_class_S += criterion1(a_S_b_b, target_a_S_b) + criterion1(b_S_a_a, target_b_S_a)
            loss_class_S += criterion1(a_S_a_I_b, target_a_S_b) + criterion1(b_S_a_I_b, target_b_S_a) + \
                criterion1(b_S_b_I_a, target_b_S_a) + criterion1(a_S_b_I_a, target_a_S_b)
            loss_class_U += criterion1(a_U_b_b, target_a_U_b) + criterion1(b_U_a_a, target_a_U_b)
            loss_class_I += criterion1(a_I_b_b, target_a_I_b) + criterion1(b_I_a_a, target_a_I_b)

        if params_object.tautology_recon_toggle:
            loss_recon_S = criterion2(a_S_b_em, a_S_b_b_em) + criterion2(a_S_b_em, a_S_a_I_b_em) + \
                criterion2(a_S_b_em, a_S_b_I_a_em)
            loss_recon_S += criterion2(b_S_a_em, b_S_a_a_em) + criterion2(b_S_a_em, b_S_a_I_b_em) + \
                criterion2(b_S_a_em, b_S_b_I_a_em)
            return_recon_S = loss_recon_S.item()
        else:
            loss_recon_S = 0
            return_recon_S = 0

        if params_object.sym_class_toggle:
            loss_class_U += criterion1(b_U_a, target_a_U_b)
            loss_class_I += criterion1(b_I_a, target_a_I_b)

        if params_object.sym_recon_toggle:
            loss_recon_U = criterion2(a_U_b_em, b_U_a_em)
            loss_recon_I = criterion2(a_I_b_em, b_I_a_em)
            return_recon_U = loss_recon_U.item()
            return_recon_I = loss_recon_I.item()
        else:
            loss_recon_U = 0
            loss_recon_I = 0
            return_recon_U = 0
            return_recon_I = 0

        loss = loss_class
        loss += 0 if params_object.class_fake_loss_weight == 0 else \
            params_object.class_fake_loss_weight * loss_class_out
        loss += 0 if (params_object.recon_loss_weight == 0) or (not loss_recon) else \
            params_object.recon_loss_weight * loss_recon
        loss += 0 if params_object.class_S_loss_weight == 0 else \
            params_object.class_S_loss_weight * loss_class_S
        # Note: the original gated this term on loss_recon_I, which looks like a
        # copy-paste bug; it is gated on loss_recon_S here.
        loss += 0 if (params_object.recon_loss_weight == 0) or (not loss_recon_S) else \
            params_object.recon_loss_weight * loss_recon_S
        loss += 0 if params_object.class_U_loss_weight == 0 else \
            params_object.class_U_loss_weight * loss_class_U
        loss += 0 if (params_object.recon_loss_weight == 0) or (not loss_recon_U) else \
            params_object.recon_loss_weight * loss_recon_U
        loss += 0 if params_object.class_I_loss_weight == 0 else \
            params_object.class_I_loss_weight * loss_class_I
        loss += 0 if (params_object.recon_loss_weight == 0) or (not loss_recon_I) else \
            params_object.recon_loss_weight * loss_recon_I

        loss.backward()
        optimizer.step()

        return {
            "main": loss.item(),
            "real class": loss_class.item(),
            "fake class": loss_class_out.item(),
            "fake MSE": return_loss_recon,
            "S MSE": return_recon_S,
            "U MSE": return_recon_U,
            "I MSE": return_recon_I,
            "S class": loss_class_S.item(),
            "U class": loss_class_U.item(),
            "I class": loss_class_I.item(),
        }

    engine = Engine(_update)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine
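# The label targets built inside `_update` are elementwise set algebra on
# multi-hot vectors. A minimal CPU sketch of the same construction (the trainer
# itself uses torch.cuda.* tensor types; this helper is illustrative only):
import torch

def setop_targets(target_a, target_b):
    a = target_a.bool()
    b = target_b.bool()
    a_I_b = a & b        # intersection
    a_U_b = a | b        # union
    a_S_b = a & ~a_I_b   # a minus b
    b_S_a = b & ~a_I_b   # b minus a
    return tuple(t.float() for t in (a_I_b, a_U_b, a_S_b, b_S_a))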
def get_engine():
    engine = Engine(sum_data)
    average = Average()
    average.attach(engine, "average")
    return engine
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint,
    #    - `RunningAverage` on `train_step` output,
    #    - two progress bars on epochs and optionally on iterations

    cutmix_beta = config["cutmix_beta"]
    cutmix_prob = config["cutmix_prob"]
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            r = torch.rand(1).item()
            if cutmix_beta > 0 and r < cutmix_prob:
                output, loss = utils.cutmix_forward(model, x, criterion, y, cutmix_beta)
            else:
                output = model(x)
                loss = criterion(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        if idist.backend() == "horovod":
            optimizer.synchronize()
            with optimizer.skip_synchronize():
                scaler.step(optimizer)
                scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    if config["with_pbar"] and idist.get_rank() == 0:
        ProgressBar().attach(trainer)

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
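# Minimal sketch of the AMP pattern used in `train_step` above (non-horovod
# branch), assuming only torch; model/optimizer/criterion are placeholders:
import torch
from torch.cuda.amp import autocast, GradScaler

def amp_step(model, optimizer, criterion, x, y, scaler: GradScaler):
    model.train()
    with autocast(enabled=scaler.is_enabled()):
        loss = criterion(model(x), y)
    optimizer.zero_grad()
    scaler.scale(loss).backward()
    scaler.step(optimizer)  # unscales gradients; skips the step on inf/NaN
    scaler.update()
    return loss.item()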
def test_stopping_criterion_is_max_epochs():
    engine = Engine(MagicMock(return_value=1))
    max_epochs = 5
    state = engine.run([1], max_epochs=max_epochs)
    assert state.epoch == max_epochs
def train():
    parser = ArgumentParser()
    parser.add_argument("--train_path", type=str, default="data/train_set4DSTC7-AVSD.json",
                        help="Path of the trainset")
    parser.add_argument("--fea_path", type=str, default="data/",
                        help="Path of the features")
    parser.add_argument("--valid_path", type=str, default="data/valid_set4DSTC7-AVSD.json",
                        help="Path of the validset")
    parser.add_argument("--model_checkpoint", type=str, default="gpt2",
                        help="Path, url or short name of the model")
    parser.add_argument("--max_history", type=int, default=3,
                        help="Number of previous exchanges to keep in history")
    parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training")
    parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation")
    parser.add_argument("--drop_rate", type=float, default=0.5, help="Drop rate for caption")
    parser.add_argument("--gradient_accumulation_steps", type=int, default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate")
    parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm")
    parser.add_argument("--n_epochs", type=int, default=8, help="Number of training epochs")
    parser.add_argument("--eval_before_start", action='store_true',
                        help="If true start with a first evaluation before training")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--fp16", type=str, default="",
                        help="Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--log_path", type=str, default="log/", help="Log path")
    args = parser.parse_args()

    if not os.path.exists(args.log_path):
        os.makedirs(args.log_path)

    # Logging is set to INFO (resp. WARN) for the main (resp. auxiliary) processes:
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    # This is a logger.warning: it will be printed by all distributed processes
    logger.warning("Running process %d", args.local_rank)
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning")
    tokenizer_class = GPT2Tokenizer
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)
    model_class = VideoGPT2LMHeadModel
    model = model_class.from_pretrained(args.model_checkpoint)
    tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT)
    model.resize_token_embeddings(len(tokenizer))
    model.to(args.device)
    optimizer = AdamW(model.parameters(), lr=args.lr)

    # Prepare model for FP16 and distributed training if needed (order is important,
    # distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader = get_data_loaders_new(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch
        input_embs = model.transformer.wte(input_ids)
        video_embs = model.video_ff(i3d)
        input_embs = torch.cat([video_embs, input_embs], dim=1)
        token_type_ids = torch.cat(
            [
                torch.ones((i3d.size(0), i3d.size(1))).long().cuda()
                * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                token_type_ids,
            ],
            dim=1,
        )

        video_loss = model(
            input_embs,
            token_type_ids=token_type_ids,
            labels=(labels, i3d),
            attention_mask=[video_mask, input_mask],
            mode="video",
        )[0]
        reply_loss = model(
            input_embs,
            token_type_ids=token_type_ids,
            labels=(labels, i3d),
            attention_mask=[reply_mask, input_mask],
            mode="reply",
        )[0]
        loss = (video_loss + reply_loss) / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)

        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch
            input_embs = model.transformer.wte(input_ids)
            video_embs = model.video_ff(i3d)
            input_embs = torch.cat([video_embs, input_embs], dim=1)
            token_type_ids = torch.cat(
                [
                    torch.ones((i3d.size(0), i3d.size(1))).long().cuda()
                    * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]),
                    token_type_ids,
                ],
                dim=1,
            )

            model_outputs = model(
                input_embs, token_type_ids=token_type_ids, attention_mask=[reply_mask, input_mask]
            )[0]
            lm_logits = model_outputs  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return lm_logits_flat_shifted, lm_labels_flat_shifted

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and,
    # optionally, before training starts
    trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0], x[1]))
    }
    metrics.update({
        "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model,
    # configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED,
            lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)),
        )

        tb_logger = TensorboardLogger(log_dir="./tb_logs")
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="training", metric_names=["loss"]),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer),
            event_name=Events.ITERATION_STARTED,
        )
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(tag="validation", metric_names=list(metrics.keys()),
                                      another_engine=trainer),
            event_name=Events.EPOCH_COMPLETED,
        )

        checkpoint_handler = ModelCheckpoint(args.log_path, 'checkpoint', save_interval=1,
                                             n_saved=8, require_empty=False)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)},
        )  # "getattr" takes care of distributed encapsulation

        torch.save(args, args.log_path + 'model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(os.path.join(args.log_path, CONFIG_NAME))
        tokenizer.save_vocabulary(args.log_path)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint
    # (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1], os.path.join(args.log_path, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
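# The `update` function above steps the optimizer only once every
# `gradient_accumulation_steps` iterations. A stripped-down sketch of that
# pattern (names hypothetical; `model` is assumed to return a scalar loss):
def accumulation_update(engine, batch, model, optimizer, accumulation_steps=8):
    # Scale the loss so accumulated gradients average over the window
    loss = model(batch) / accumulation_steps
    loss.backward()
    if engine.state.iteration % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()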
def test_run_with_max_iters():
    max_iters = 8
    engine = Engine(lambda e, b: 1)
    engine.run([0] * 20, max_iters=max_iters)
    assert engine.state.iteration == max_iters
    assert engine.state.max_iters == max_iters
def _test_setup_logging(
    setup_logging_fn,
    kwargs_dict,
    output_handler_cls,
    opt_params_handler_cls,
    with_eval=True,
    with_optim=True,
    as_class=False,
    log_every_iters=1,
):
    trainer = Engine(lambda e, b: b)
    evaluators = None
    optimizers = None

    if with_eval:
        evaluator = Engine(lambda e, b: None)
        acc_scores = [0.1, 0.2, 0.3, 0.4, 0.3, 0.3, 0.2, 0.1, 0.1, 0.0]

        @trainer.on(Events.EPOCH_COMPLETED)
        def validate(engine):
            evaluator.run([0, 1])

        @evaluator.on(Events.EPOCH_COMPLETED)
        def set_eval_metric(engine):
            engine.state.metrics = {"acc": acc_scores[trainer.state.epoch - 1]}

        evaluators = {"validation": evaluator}
        if as_class:
            evaluators = evaluators["validation"]

    if with_optim:
        t = torch.tensor([0])
        optimizers = {"optimizer": torch.optim.SGD([t], lr=0.01)}
        if as_class:
            optimizers = optimizers["optimizer"]

    kwargs_dict["trainer"] = trainer
    kwargs_dict["optimizers"] = optimizers
    kwargs_dict["evaluators"] = evaluators
    kwargs_dict["log_every_iters"] = log_every_iters

    x_logger = setup_logging_fn(**kwargs_dict)

    handlers = trainer._event_handlers[Events.ITERATION_COMPLETED]
    for cls in [output_handler_cls]:
        assert any(isinstance(h[0], cls) for h in handlers), "{}".format(handlers)

    if with_optim:
        handlers = trainer._event_handlers[Events.ITERATION_STARTED]
        for cls in [opt_params_handler_cls]:
            assert any(isinstance(h[0], cls) for h in handlers), "{}".format(handlers)

    if with_eval:
        handlers = evaluator._event_handlers[Events.COMPLETED]
        for cls in [output_handler_cls]:
            assert any(isinstance(h[0], cls) for h in handlers), "{}".format(handlers)

    data = [0, 1, 2]
    trainer.run(data, max_epochs=10)

    if "output_path" in kwargs_dict:
        tb_files = list(os.listdir(kwargs_dict["output_path"]))
        assert len(tb_files) == 1
        for v in ["events"]:
            assert any(v in c for c in tb_files), "{}".format(tb_files)

    return x_logger
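# A hypothetical invocation of the helper above, assuming ignite's
# common.setup_tb_logging helper and tensorboard handler classes are available
# under these import paths and that setup_tb_logging accepts an `output_path`
# keyword alongside trainer/optimizers/evaluators/log_every_iters:
def test_setup_tb_logging(tmp_path):
    from ignite.contrib.engines import common
    from ignite.contrib.handlers.tensorboard_logger import OptimizerParamsHandler, OutputHandler

    _test_setup_logging(
        setup_logging_fn=common.setup_tb_logging,
        kwargs_dict={"output_path": str(tmp_path)},
        output_handler_cls=OutputHandler,
        opt_params_handler_cls=OptimizerParamsHandler,
    )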
def test_default_exception_handler():
    update_function = MagicMock(side_effect=ValueError())
    engine = Engine(update_function)

    with raises(ValueError):
        engine.run([1])