Пример #1
0
def get_warmup_multistep_scheduler(optimizer, num_iterations_per_epoch,
                                   config):
    """Build a step-decay (MultiStepLR) schedule preceded by a linear warm-up.

    Durations in ``config`` are given in epochs and converted to iterations.
    The learning rate is multiplied by 0.2 at three milestones placed at
    2x, 1x and 0.5x ``cooldown_duration`` before the end of training, measured
    on the post-warm-up iteration axis.

    Args:
        optimizer: optimizer whose first param group gets ``initial_lr``.
        num_iterations_per_epoch: iterations per epoch.
        config: mapping with ``lr_max_value``, ``warmup_duration``,
            ``num_epochs`` and ``cooldown_duration``.

    Returns:
        The warm-up-wrapped scheduler from ``create_lr_scheduler_with_warmup``
        (with ``save_history=True``).
    """
    peak_lr = config['lr_max_value']
    warmup_iters = config['warmup_duration'] * num_iterations_per_epoch
    total_iters = config['num_epochs'] * num_iterations_per_epoch
    cooldown_iters = config['cooldown_duration'] * num_iterations_per_epoch

    # MultiStepLR reads ``initial_lr`` from the param group; seed it with the
    # peak value reached at the end of the warm-up.
    optimizer.param_groups[0]['initial_lr'] = peak_lr

    # The wrapped scheduler's milestones are offset by the warm-up length.
    post_warmup_iters = total_iters - warmup_iters
    milestones = [int(post_warmup_iters - factor * cooldown_iters)
                  for factor in (2, 1, 0.5)]

    step_scheduler = optim.lr_scheduler.MultiStepLR(optimizer=optimizer,
                                                    milestones=milestones,
                                                    gamma=0.2)

    return create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=0.0,
        warmup_end_value=peak_lr,
        warmup_duration=warmup_iters,
        save_history=True,
    )
Пример #2
0
def get_scheduler(optimizer, epochs, learning_rate, train_loader_size,
                  warmup_duration=1000):
    """Cosine-annealing learning-rate schedule preceded by a linear warm-up.

    The cosine cycle spans ``epochs * train_loader_size`` events and anneals
    from ``learning_rate`` down to ``learning_rate / 1000``.

    Args:
        optimizer: optimizer whose ``'lr'`` is scheduled.
        epochs: number of training epochs.
        learning_rate: peak learning rate (warm-up target and cosine start).
        train_loader_size: number of iterations per epoch.
        warmup_duration: number of warm-up events (default 1000, the value
            previously hard-coded).

    Returns:
        The warm-up-wrapped scheduler from ``create_lr_scheduler_with_warmup``.
    """
    scheduler = CosineAnnealingScheduler(optimizer, 'lr', learning_rate,
                                         learning_rate / 1000,
                                         epochs * train_loader_size)
    # Keyword arguments on purpose: ignite changed the positional order of
    # ``warmup_end_value`` / ``warmup_duration`` between releases, so a
    # positional call means different things depending on the installed version.
    return create_lr_scheduler_with_warmup(scheduler,
                                           warmup_start_value=0,
                                           warmup_end_value=learning_rate,
                                           warmup_duration=warmup_duration)
Пример #3
0
def get_lr_scheduler(
    config: ConfigSchema, optimizer: Optimizer, trainer: Engine, evaluator: Engine
):
    """Build the learning-rate scheduler selected by ``config.lr_scheduler``.

    Args:
        config: experiment configuration (reads ``num_epochs``,
            ``num_warmup_epochs``, ``lr_scheduler``, ``learning_rate`` and,
            for "reduce_at_plateau", ``patience``).
        optimizer: optimizer whose ``"lr"`` is scheduled.
        trainer: training engine (used by the plateau scheduler).
        evaluator: evaluation engine; its ``"accuracy"`` metric is the
            plateau scheduler's score.

    Returns:
        The configured scheduler object.

    Raises:
        ValueError: if ``config.lr_scheduler`` is neither "cosine" nor
            "reduce_at_plateau".
    """
    # Epochs left for the main schedule once the warm-up epochs are spent.
    if config.num_warmup_epochs:
        length = config.num_epochs - config.num_warmup_epochs
    else:
        length = config.num_epochs

    if config.lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingScheduler(
            optimizer,
            "lr",
            config.learning_rate,
            0.001 * config.learning_rate,
            cycle_size=length + 1,
        )
        if config.num_warmup_epochs:
            # Keyword arguments: the positional order of the warm-up
            # parameters differs between ignite releases. The warm-up ends at
            # the cosine's start value so the schedule is continuous.
            lr_scheduler = create_lr_scheduler_with_warmup(
                lr_scheduler,
                warmup_start_value=0.0,
                warmup_end_value=config.learning_rate,
                warmup_duration=config.num_warmup_epochs,
            )
    elif config.lr_scheduler == "reduce_at_plateau":
        lr_scheduler = LRReductionEarlyStopping(
            optimizer,
            trainer=trainer,
            reduction_rate=0.1,
            num_reduction=2,
            patience=config.patience,
            score_function=lambda _: evaluator.state.metrics["accuracy"],
            num_warmup_epochs=config.num_warmup_epochs,
            warmup_start_value=0.001 * config.learning_rate,
        )
    else:
        raise ValueError(f"unknown lr scheduler {config.lr_scheduler}")

    return lr_scheduler
Пример #4
0
def get_lr_scheduler(optimizer, num_iterations_per_epoch, config):
    """Warm-up, linear decay, then a linear cool-down of the learning rate.

    After the warm-up (0 -> ``lr_max_value``), the learning rate decays
    linearly from ``lr_max_value`` to ``0.4 * lr_max_value`` over the main
    phase. It then drops to ``0.2 * lr_max_value`` and decays linearly to
    ``0.01 * lr_max_value`` during the cool-down. Each linear ramp is half of
    a ``LinearCyclicalScheduler`` cycle, hence the ``* 2`` cycle sizes.

    Args:
        optimizer: optimizer whose ``"lr"`` is scheduled.
        num_iterations_per_epoch: iterations per epoch (durations in
            ``config`` are in epochs).
        config: mapping with ``lr_max_value``, ``warmup_duration``,
            ``num_epochs`` and ``cooldown_duration``.

    Returns:
        The warm-up-wrapped ``ConcatScheduler`` (``save_history=True``).
    """
    peak_lr = config['lr_max_value']
    warmup_iters = config['warmup_duration'] * num_iterations_per_epoch
    total_iters = config['num_epochs'] * num_iterations_per_epoch
    cooldown_iters = config['cooldown_duration'] * num_iterations_per_epoch

    # Iterations between the end of the warm-up and the start of cool-down.
    main_iters = total_iters - warmup_iters - cooldown_iters

    decay_phase = LinearCyclicalScheduler(optimizer,
                                          "lr",
                                          start_value=peak_lr,
                                          end_value=peak_lr * 0.4,
                                          cycle_size=main_iters * 2)
    cooldown_phase = LinearCyclicalScheduler(optimizer,
                                             "lr",
                                             start_value=peak_lr * 0.2,
                                             end_value=peak_lr * 0.01,
                                             cycle_size=cooldown_iters * 2)

    combined = ConcatScheduler(schedulers=[decay_phase, cooldown_phase],
                               durations=[main_iters])

    return create_lr_scheduler_with_warmup(
        combined,
        warmup_start_value=0.0,
        warmup_end_value=peak_lr,
        warmup_duration=warmup_iters,
        save_history=True,
    )
Пример #5
0
 def schedule_lr(self, optimizer, name, params, warmup_start=None,
                 warmup_end=None, warmup_duration=None):
     """Create a named lr scheduler and attach it to ``self.trainer``.

     The scheduler is built by ``self._get_lr_scheduler(name)(optimizer,
     **params)``. It is wrapped with a linear warm-up when all warm-up
     settings are given, otherwise in ``LRScheduler``. It is then registered
     on ``Events.EPOCH_COMPLETED``.

     Args:
         optimizer: optimizer to schedule.
         name: scheduler name, or None to skip scheduling entirely.
         params: keyword arguments for the scheduler constructor.
         warmup_start: warm-up start value (0.0 is a valid value).
         warmup_end: warm-up end value.
         warmup_duration: warm-up length in events; falsy disables warm-up.

     Returns:
         The attached handler, or None when ``name`` is None.
     """
     if name is None:
         return None
     lr_scheduler = self._get_lr_scheduler(name)(optimizer, **params)
     # Compare start/end against None rather than using truthiness: warming
     # up from 0.0 is the usual case and must not silently disable warm-up.
     use_warmup = (warmup_start is not None and warmup_end is not None
                   and warmup_duration)
     if use_warmup:
         scheduler = create_lr_scheduler_with_warmup(
             lr_scheduler,
             warmup_start_value=warmup_start,
             warmup_end_value=warmup_end,
             warmup_duration=warmup_duration)
     else:
         scheduler = LRScheduler(lr_scheduler)
     self.trainer.add_event_handler(Events.EPOCH_COMPLETED, scheduler)
     return scheduler
def get_lr_scheduler(optimizer, config):
    """Build a MultiStepLR schedule preceded by a linear warm-up.

    The warm-up ramps from ``learning_rate * warmup_factor`` to
    ``learning_rate`` over ``num_warmup_iterations``. After it,
    ``MultiStepLR`` multiplies the lr by ``gamma`` at each milestone.

    Args:
        optimizer: optimizer whose lr is scheduled.
        config: mapping with ``learning_rate``, ``warmup_factor``,
            ``num_warmup_iterations``, ``learning_rate_milestone_iterations``
            and ``gamma``.

    Returns:
        The warm-up-wrapped scheduler from ``create_lr_scheduler_with_warmup``.
    """
    base_lr = config["learning_rate"]
    warmup_factor = config["warmup_factor"]
    warmup_iters = config["num_warmup_iterations"]
    configured_milestones = config["learning_rate_milestone_iterations"]
    gamma = config["gamma"]

    # The configured milestones are shifted back by the warm-up length,
    # presumably because the wrapped scheduler is only stepped after the
    # warm-up has finished.
    shifted_milestones = [
        milestone - warmup_iters for milestone in configured_milestones
    ]
    step_scheduler = MultiStepLR(optimizer=optimizer,
                                 gamma=gamma,
                                 milestones=shifted_milestones)

    return create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=base_lr * warmup_factor,
        warmup_end_value=base_lr,
        warmup_duration=warmup_iters,
    )
Пример #7
0
# Define training function
def update(engine, batch):
    """Run one language-model training iteration with gradient accumulation.

    Uses the module-level ``model``, ``optimizer`` and ``args``. The optimizer
    steps every ``args.gradient_accumulation_steps`` iterations.

    Returns:
        The loss of this micro-batch divided by
        ``args.gradient_accumulation_steps``, as a Python float.
    """
    model.train()
    batch = batch.transpose(0, 1).contiguous().to(args.device)  # to shape [seq length, batch]
    _, loss = model(batch, labels=batch)
    # Scale so that summing gradients over the accumulation window averages them.
    loss = loss / args.gradient_accumulation_steps
    loss.backward()
    if engine.state.iteration % args.gradient_accumulation_steps == 0:
        # Clip the fully accumulated gradient once, right before the update.
        # Clipping after every micro-batch would rescale the partial sums
        # repeatedly and distort the earlier micro-batches' contributions.
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        optimizer.step()
        optimizer.zero_grad()
    return loss.item()
trainer = Engine(update)

# Add progressbar with loss
RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
ProgressBar(persist=True).attach(trainer, metric_names=['loss'])

# Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine
cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, len(dataloader) * args.n_epochs)
# Warm-up settings are passed by keyword because the positional order of
# these arguments is not the same across ignite releases.
scheduler = create_lr_scheduler_with_warmup(cos_scheduler,
                                            warmup_start_value=0.0,
                                            warmup_end_value=args.lr,
                                            warmup_duration=args.n_warmup)
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

# Save checkpoints and training config
checkpoint_handler = ModelCheckpoint(args.log_dir, 'checkpoint', save_interval=1, n_saved=5)
trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': model})
torch.save(args, os.path.join(args.log_dir, 'training_args.bin'))

trainer.run(dataloader, max_epochs=args.n_epochs)
Пример #8
0
        },
    ],
    lr=args.learning_rate,
)

loss_fn = OHEMLoss(ignore_index=255)
loss_fn = loss_fn.cuda()

# Linear warm-up length in iterations. The cosine cycle is shortened by the
# same amount so that warm-up + annealing cover the whole run.
warmup_iterations = 1000

scheduler = CosineAnnealingScheduler(
    optimizer,
    'lr',
    args.learning_rate,
    args.learning_rate / 1000,
    args.epochs * len(train_loader) - warmup_iterations,
)
# Keyword arguments: the positional order of the warm-up parameters differs
# between ignite releases.
scheduler = create_lr_scheduler_with_warmup(
    scheduler,
    warmup_start_value=0,
    warmup_end_value=args.learning_rate,
    warmup_duration=warmup_iterations,
)

model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
if args.distributed:
    model = convert_syncbn_model(model)
    model = DistributedDataParallel(model)

trainer = create_segmentation_trainer(
    model,
    optimizer,
    loss_fn,
    device=device,
    use_f16=True,
)
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
Пример #9
0
    loss = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.lr,
                                 weight_decay=cfg.weight_decay)
    trainer = create_supervised_trainer(model, optimizer, loss, device)
    RunningAverage(output_transform=lambda x: x).attach(trainer, 'loss')

    trainer.add_event_handler(
        Events.ITERATION_COMPLETED,
        create_lr_scheduler_with_warmup(CosineAnnealingScheduler(
            optimizer,
            param_name='lr',
            start_value=cfg.lr,
            end_value=0,
            cycle_size=len(train_loader) * cfg.n_epochs,
            start_value_mult=0,
            end_value_mult=0),
                                        warmup_start_value=0.0,
                                        warmup_end_value=cfg.lr,
                                        warmup_duration=len(train_loader)))

    evaluator = create_supervised_evaluator(
        model,
        metrics={
            'loss': Loss(loss),
            'acc_smpl': Accuracy(threshold_output, is_multilabel=True),
            'p': Precision(threshold_output, average=True),
            'r': Recall(threshold_output, average=True),
            'f1': Fbeta(1.0, output_transform=threshold_output),
            'ap': AveragePrecision(output_transform=activate_output)
Пример #10
0
def _build_arg_parser():
    """Return the ``ArgumentParser`` with every command-line option of ``train``."""
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default='wikitext-2',
        help="One of ('wikitext-103', 'wikitext-2') or a dict of splits paths."
    )
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")

    parser.add_argument("--embed_dim",
                        type=int,
                        default=410,
                        help="Embeddings dim")
    parser.add_argument("--hidden_dim",
                        type=int,
                        default=2100,
                        help="Hidden dimension")
    parser.add_argument("--num_max_positions",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=16,
                        help="Number of layers")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--initializer_range",
                        type=float,
                        default=0.02,
                        help="Initializer range for model weights")

    parser.add_argument("--train_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=8,
                        help="Batch size for validation")
    parser.add_argument("--lr",
                        type=float,
                        default=2.5e-4,
                        help="Learning rate")
    parser.add_argument("--max_norm",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=200,
                        help="Number of training epochs")
    parser.add_argument("--n_warmup",
                        type=int,
                        default=1000,
                        help="Number of warmup iterations")
    parser.add_argument("--eval_every",
                        type=int,
                        default=-1,
                        help="Evaluate every X steps (-1 => end of epoch)")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=1,
                        help="Accumulate gradient")

    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    return parser


def train():
    """Train ``TransformerWithLMHead`` as a language model from the command line.

    Parses the CLI options, optionally initializes NCCL distributed training,
    and builds the model, Adam optimizer, data loaders, trainer and evaluator.
    It then trains with a linear warm-up followed by a cosine learning-rate
    schedule. On the main process, logging/checkpointing is set up through
    ``add_logging_and_checkpoint_saving``. After training, the last
    checkpoint is renamed to ``WEIGHTS_NAME`` in the TensorBoard log dir.
    """
    args = _build_arg_parser().parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log on main process only, logger.warning => log on all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(
        args))  # This is a logger.info: only printed on the first process

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    logger.info("Prepare tokenizer, model and optimizer")
    tokenizer = BertTokenizer.from_pretrained(
        'bert-base-cased',
        do_lower_case=False)  # Let's use a pre-defined tokenizer
    args.num_embeddings = len(
        tokenizer.vocab
    )  # We need this to create the model at next line (number of embeddings to use)
    model = TransformerWithLMHead(args)
    model.to(args.device)
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler, train_num_words, valid_num_words = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = batch.transpose(0, 1).contiguous().to(
            args.device)  # to shape [seq length, batch]
        _, loss = model(batch, labels=batch)
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            # Clip the fully accumulated gradient once, right before the
            # update, instead of clipping partial sums after every micro-batch.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = batch.transpose(0, 1).contiguous().to(
                args.device)  # to shape [seq length, batch]
            logits = model(batch)
            # Predict token t+1 from position t: drop the last logit and the first label.
            shift_logits = logits[:-1]
            shift_labels = batch[1:]
            return shift_logits.view(-1,
                                     logits.size(-1)), shift_labels.view(-1)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(val_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine schedule
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.n_epochs)
    # Keyword arguments: the positional order of the warm-up parameters
    # differs between ignite releases.
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler,
                                                warmup_start_value=0.0,
                                                warmup_end_value=args.lr,
                                                warmup_duration=args.n_warmup)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    metrics["average_word_ppl"] = MetricsLambda(
        lambda x: math.exp(x * val_loader.dataset.numel() / valid_num_words),
        metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        checkpoint_handler, tb_logger = add_logging_and_checkpoint_saving(
            trainer, evaluator, metrics, model, optimizer, args)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint for easy re-loading
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        # NOTE(review): ``_saved`` is a private ModelCheckpoint attribute whose
        # layout may change between ignite versions.
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
Пример #11
0
)

loss_fn = OHEMLoss(ignore_index=255).cuda()

lr = args.learning_rate
# One peak learning rate per optimizer param group, index-aligned with
# ``optimizer.param_groups``.
lrs = [lr / 10, lr / 10, lr, lr]

# Each scheduler must drive its own param group; hard-coding
# ``param_group_index=0`` made all four schedulers fight over group 0.
schedulers = [
    CosineAnnealingScheduler(
        optimizer, 'lr',
        group_lr, group_lr * 1e-4,
        args.epochs * len(train_loader),
        param_group_index=index)
    for index, group_lr in enumerate(lrs)]
# Keyword arguments: the positional order of the warm-up parameters differs
# between ignite releases.
schedulers = [
    create_lr_scheduler_with_warmup(scheduler,
                                    warmup_start_value=0,
                                    warmup_end_value=group_lr,
                                    warmup_duration=1000)
    for scheduler, group_lr in zip(schedulers, lrs)
]


model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
if args.distributed:
    model = convert_syncbn_model(model)
    model = DistributedDataParallel(model)


trainer = create_segmentation_trainer(
    model, optimizer, loss_fn,
    device=device,
    use_f16=True,
)
Пример #12
0
def supervised_loss_fn(y_pred, y):
    """Main loss plus 0.4 times the summed auxiliary-head losses.

    Args:
        y_pred: pair ``(main_prediction, auxiliary_predictions)`` where the
            second item is an iterable of auxiliary outputs.
        y: target passed unchanged to ``loss_fn`` and ``aux_loss``.
    """
    main_pred, aux_preds = y_pred
    aux_total = sum(aux_loss(aux_pred, y) for aux_pred in aux_preds)
    return loss_fn(main_pred, y) + 0.4 * aux_total


# Per-param-group cosine schedules, each preceded by a 1000-iteration linear
# warm-up (the cosine cycle is shortened by the same 1000 iterations).
# Warm-up settings are passed by keyword because the positional order of
# these arguments differs between ignite releases.
scheduler1 = CosineAnnealingScheduler(
    optimizer,
    param_name='lr',
    start_value=args.learning_rate / 10,
    end_value=args.learning_rate / 10 * 1e-4,
    cycle_size=args.epochs * len(train_loader) - 1000,
    param_group_index=0,
)
scheduler1 = create_lr_scheduler_with_warmup(
    scheduler1,
    warmup_start_value=0,
    warmup_end_value=args.learning_rate / 10,
    warmup_duration=1000,
)
scheduler2 = CosineAnnealingScheduler(
    optimizer,
    param_name='lr',
    start_value=args.learning_rate / 10,
    end_value=args.learning_rate / 10 * 1e-4,
    cycle_size=args.epochs * len(train_loader) - 1000,
    param_group_index=1,
)
scheduler2 = create_lr_scheduler_with_warmup(
    scheduler2,
    warmup_start_value=0,
    warmup_end_value=args.learning_rate / 10,
    warmup_duration=1000,
)
scheduler3 = CosineAnnealingScheduler(
    optimizer,
    param_name='lr',
    start_value=args.learning_rate,
    end_value=args.learning_rate * 1e-4,
Пример #13
0
def run(train_config, logger, **kwargs):
    """Train ``train_config.model`` with a supervised loss plus a consistency loss.

    Everything (data loaders, model, optimizer, criteria, schedulers,
    intervals, save paths) is read from ``train_config``. Training is driven
    by an ignite ``Engine`` over ``len(train1_sup_loader)`` steps per epoch.
    Optional extras:
    - Polyaxon, TensorBoard and mlflow logging;
    - Apex AMP half precision;
    - checkpointing;
    - periodic evaluation on the train and test loaders.

    Args:
        train_config: config module/object; must define ``__file__`` (it is
            copied into the config save dir) plus the attributes read below.
        logger: ignored; it is immediately replaced by
            ``logging.getLogger('UDA')``.
        **kwargs: unused.

    Raises:
        ImportError: if ``use_fp_16`` is set but apex is not installed.
        ValueError: if ``num_warmup_steps > 0`` but no ``lr_scheduler`` is
            configured.
    """
    # NOTE: the ``logger`` argument is intentionally overridden here.
    logger = logging.getLogger('UDA')
    if getattr(train_config, 'debug', False):
        setup_logger(logger, logging.DEBUG)

    # Set Polyaxon environment if needed
    plx_logger = None
    save_dir = None
    output_experiment_path = None
    try:
        plx_logger = PolyaxonLogger()
        experiment = plx_logger.experiment
        save_dir = get_outputs_path()
        output_experiment_path = get_outputs_refs_paths()
        output_experiment_path = output_experiment_path['experiments'][
            0] if output_experiment_path else None
        logger.debug("Experiment info: {}".format(
            experiment.get_experiment_info()))
    except PolyaxonClientException as e:
        logger.warning('Logger Polyaxon : ' + str(e))

    # Path configuration
    saves_dict = getattr(train_config, 'saves', {})

    save_dir = saves_dict.get('save_dir', '') if save_dir is None else save_dir
    log_dir = os.path.join(save_dir, saves_dict.get('log_dir', ''))
    save_model_dir = os.path.join(save_dir, saves_dict.get('model_dir', ''))
    save_prediction_dir = os.path.join(save_dir,
                                       saves_dict.get('prediction_dir', ''))
    save_config_dir = os.path.join(save_dir, saves_dict.get('config_dir', ''))
    load_model_file = saves_dict.get('load_model_file', '')
    load_optimizer_file = saves_dict.get('load_optimizer_file', '')

    # Create folders
    create_save_folders(save_dir, saves_dict)

    # When running under Polyaxon, checkpoints to resume from are resolved
    # relative to the referenced experiment's output directory.
    if output_experiment_path is not None:
        model_dir = saves_dict.get('model_dir', '')
        load_model_file = os.path.join(
            output_experiment_path, model_dir,
            load_model_file) if load_model_file else None
        load_optimizer_file = os.path.join(
            output_experiment_path, model_dir,
            load_optimizer_file) if load_optimizer_file else None

    num_epochs = getattr(train_config, 'num_epochs')
    num_classes = getattr(train_config, 'num_classes')
    device = getattr(train_config, 'device', 'cpu')

    # Set magical acceleration
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    else:
        assert device == 'cpu', 'CUDA device selected but none is available'

    # Set half precision if required
    use_fp_16 = getattr(train_config, 'use_fp_16', False)

    train1_sup_loader = getattr(train_config, 'train1_sup_loader')
    train1_unsup_loader = getattr(train_config, 'train1_unsup_loader')
    train2_unsup_loader = getattr(train_config, 'train2_unsup_loader')
    test_loader = getattr(train_config, 'test_loader')

    save_interval = saves_dict.get('save_interval', 0)
    n_saved = saves_dict.get('n_saved', 0)

    val_interval = getattr(train_config, 'val_interval', 1)
    # 0 (the default) disables saving trainer predictions.
    pred_interval = getattr(train_config, 'pred_interval', 0)

    model = getattr(train_config, 'model').to(device)

    optimizer = getattr(train_config, 'optimizer')

    criterion = getattr(train_config, 'criterion').to(device)
    consistency_criterion = getattr(train_config,
                                    'consistency_criterion').to(device)

    cm_metric = getattr(
        train_config, 'cm_metric',
        ConfusionMatrix(num_classes=num_classes,
                        output_transform=lambda x: (x['y_pred'], x['y'])))

    # AMP initialization for half precision
    if use_fp_16:
        assert 'cuda' in device
        assert torch.backends.cudnn.enabled, "NVIDIA/Apex:Amp requires cudnn backend to be enabled."
        try:
            from apex import amp
        except ImportError as err:
            # Only a missing apex is translated; other errors propagate as-is.
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to run this example."
            ) from err
        # Initialize amp
        model, optimizer = amp.initialize(model, optimizer, opt_level="O2")

    # Load checkpoint
    load_params(model,
                optimizer=optimizer,
                model_file=load_model_file,
                optimizer_file=load_optimizer_file,
                device_name=device)

    # Add batch norm
    is_bn = getattr(train_config, 'is_bn', False)
    if is_bn:
        batch_norm = nn.BatchNorm2d(3).to(device)
        if use_fp_16:
            batch_norm = amp.initialize(batch_norm)
        batch_norm.reset_parameters()
        model = nn.Sequential(batch_norm, model)

    # Copy the config file
    shutil.copy2(os.path.abspath(train_config.__file__),
                 os.path.join(save_config_dir, 'checkpoint_module.py'))

    le = len(train1_sup_loader)
    num_train_steps = le * num_epochs
    mlflow.log_param("num train steps", num_train_steps)

    lr = getattr(train_config, 'learning_rate')
    num_warmup_steps = getattr(train_config, 'num_warmup_steps', 0)

    lr_scheduler = getattr(train_config, 'lr_scheduler', None)
    if lr_scheduler is not None:
        lr_scheduler = lr_scheduler(optimizer)

    if num_warmup_steps > 0:
        if lr_scheduler is None:
            # Fail with a clear message instead of inside ignite.
            raise ValueError(
                "num_warmup_steps > 0 requires train_config.lr_scheduler to be set")
        lr_scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=lr * (1.0 + 1.0 / num_warmup_steps),
            warmup_duration=num_warmup_steps)

    train1_sup_loader_iter = cycle(train1_sup_loader)
    train1_unsup_loader_iter = cycle(train1_unsup_loader)
    train2_unsup_loader_iter = cycle(train2_unsup_loader)

    # Reduce on plateau
    reduce_on_plateau = getattr(train_config, 'reduce_on_plateau', None)

    # Output transform model
    output_transform_model = getattr(train_config, 'output_transform_model',
                                     lambda x: x)

    inference_fn = getattr(train_config, 'inference_fn', inference_standard)

    lam = getattr(train_config, 'consistency_lambda')
    beta = getattr(train_config, 'consistency_beta', lam)

    tsa = TrainingSignalAnnealing(
        num_steps=num_train_steps,
        min_threshold=getattr(train_config, 'TSA_proba_min'),
        max_threshold=getattr(train_config, 'TSA_proba_max'))

    with_tsa = getattr(train_config, 'with_TSA', False)

    cfg = {
        'tsa': tsa,
        'lambda': lam,
        'beta': beta,
        'with_tsa': with_tsa,
        'device': device,
        'consistency_criterion': consistency_criterion,
        'criterion': criterion
    }

    trainer = Engine(
        partial(train_update_function,
                model=model,
                optimizer=optimizer,
                cfg=cfg,
                train1_sup_loader_iter=train1_sup_loader_iter,
                train1_unsup_loader_iter=train1_unsup_loader_iter,
                train2_unsup_loader_iter=train2_unsup_loader_iter,
                output_transform_model=output_transform_model,
                use_fp_16=use_fp_16))

    # Register events
    for e in CustomEvents:
        State.event_to_attr[e] = 'iteration'

    trainer.register_events(*CustomEvents)

    if with_tsa:
        trainer.add_event_handler(Events.ITERATION_COMPLETED, log_tsa, tsa)

    # ignite schedulers are callable handlers; objects exposing ``step()``
    # (e.g. torch schedulers) are stepped explicitly instead.
    if lr_scheduler is not None:
        if not hasattr(lr_scheduler, "step"):
            trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
        else:
            trainer.add_event_handler(Events.ITERATION_STARTED,
                                      lambda engine: lr_scheduler.step())

    trainer.add_event_handler(Events.ITERATION_COMPLETED, log_learning_rate,
                              optimizer)

    metric_names = [
        'supervised batch loss', 'consistency batch loss', 'final batch loss'
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        RunningAverage(
            output_transform=partial(output_transform, name=n)).attach(
                trainer, n)

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)

    # Handlers for Tensorboard logging
    tb_logger = TensorboardLogger(log_dir=log_dir)
    tb_logger.attach(trainer,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=metric_names),
                     event_name=CustomEvents.ITERATION_K_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=tbOptimizerParamsHandler(optimizer,
                                                          param_name="lr"),
                     event_name=CustomEvents.ITERATION_K_STARTED)

    # Handlers for Polyaxon logging
    if plx_logger is not None:
        plx_logger.attach(trainer,
                          log_handler=plxOutputHandler(
                              tag="train", metric_names=metric_names),
                          event_name=CustomEvents.ITERATION_K_COMPLETED)

    metrics = {
        'loss': Loss(criterion,
                     output_transform=lambda x: (x['y_pred'], x['y'])),
        'mAcc': cmAccuracy(cm_metric).mean(),
        'mPr': cmPrecision(cm_metric).mean(),
        'mRe': cmRecall(cm_metric).mean(),
        'mIoU': mIoU(cm_metric),
        'mF1': cmFbeta(cm_metric, 1).mean()
    }
    iou = IoU(cm_metric)
    for i in range(num_classes):
        key_name = 'IoU_{}'.format(str(i))
        metrics[key_name] = iou[i]

    inference_update_fn = partial(
        inference_update_function,
        model=model,
        cfg=cfg,
        output_transform_model=output_transform_model,
        inference_fn=inference_fn)

    evaluator = Engine(inference_update_fn)
    train_evaluator = Engine(inference_update_fn)

    for name, metric in metrics.items():
        metric.attach(train_evaluator, name)
        metric.attach(evaluator, name)

    # Add checkpoint
    if save_model_dir:
        checkpoint = ModelCheckpoint(dirname=save_model_dir,
                                     filename_prefix='checkpoint',
                                     save_interval=save_interval,
                                     n_saved=n_saved,
                                     create_dir=True)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint, {
            'mymodel': model,
            'optimizer': optimizer
        })

    def trigger_k_iteration_started(engine, k):
        if engine.state.iteration % k == 0:
            engine.fire_event(CustomEvents.ITERATION_K_STARTED)

    def trigger_k_iteration_completed(engine, k):
        if engine.state.iteration % k == 0:
            engine.fire_event(CustomEvents.ITERATION_K_COMPLETED)

    def run_validation(engine, validation_interval):
        if (trainer.state.epoch - 1) % validation_interval == 0:
            train_evaluator.run(train1_sup_loader)
            evaluator.run(test_loader)

            if save_prediction_dir:
                train_output = train_evaluator.state.output
                test_output = evaluator.state.output

                iteration = str(trainer.state.iteration)
                epoch = str(trainer.state.epoch)

                save_prediction('train_{}_{}'.format(iteration, epoch),
                                save_prediction_dir,
                                train_output['x'],
                                torch.argmax(
                                    train_output['y_pred'][0, :, :, :], dim=0),
                                y=train_output['y'][0, :, :])

                save_prediction('test_{}_{}'.format(iteration, epoch),
                                save_prediction_dir,
                                test_output['x'],
                                torch.argmax(test_output['y_pred'][0, :, :, :],
                                             dim=0),
                                y=test_output['y'][0, :, :])

            # Drop references to the last batch outputs.
            train_evaluator.state.output = None
            evaluator.state.output = None

            if reduce_on_plateau is not None:
                reduce_on_plateau.step(evaluator.state.metrics['mIoU'])

    trainer.add_event_handler(Events.ITERATION_STARTED,
                              trigger_k_iteration_started,
                              k=10)
    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              trigger_k_iteration_completed,
                              k=10)

    trainer.add_event_handler(Events.EPOCH_STARTED,
                              run_validation,
                              validation_interval=val_interval)
    trainer.add_event_handler(Events.COMPLETED,
                              run_validation,
                              validation_interval=1)

    def trainer_prediction_save(engine, prediction_interval):
        if (engine.state.iteration - 1) % prediction_interval == 0:

            if save_prediction_dir:
                trainer_output = trainer.state.output['unsup pred']

                iteration = str(trainer.state.iteration)
                epoch = str(trainer.state.epoch)

                save_prediction('trainer_{}_{}'.format(iteration, epoch),
                                save_prediction_dir, trainer_output['x'],
                                trainer_output['y_pred'])

                logger.debug(
                    'Saved trainer prediction for iteration {}'.format(
                        str(engine.state.iteration)))

            trainer.state.output = None

    # With the default pred_interval of 0 the modulo above would raise
    # ZeroDivisionError on the first iteration, so only register when enabled.
    if pred_interval > 0:
        trainer.add_event_handler(Events.ITERATION_COMPLETED,
                                  trainer_prediction_save,
                                  prediction_interval=pred_interval)

    tb_logger.attach(train_evaluator,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=list(
                                                     metrics.keys())),
                     event_name=Events.EPOCH_COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=tbOutputHandler(tag="test",
                                                 metric_names=list(
                                                     metrics.keys())),
                     event_name=Events.EPOCH_COMPLETED)

    # Handlers for Polyaxon logging
    if plx_logger is not None:
        plx_logger.attach(train_evaluator,
                          log_handler=plxOutputHandler(tag="train",
                                                       metric_names=list(
                                                           metrics.keys())),
                          event_name=Events.EPOCH_COMPLETED)

        plx_logger.attach(evaluator,
                          log_handler=plxOutputHandler(tag="test",
                                                       metric_names=list(
                                                           metrics.keys())),
                          event_name=Events.EPOCH_COMPLETED)

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mlflow_batch_metrics_logging, "train", trainer)
    train_evaluator.add_event_handler(Events.COMPLETED,
                                      mlflow_val_metrics_logging, "train",
                                      trainer)
    evaluator.add_event_handler(Events.COMPLETED, mlflow_val_metrics_logging,
                                "test", trainer)

    # The update function pulls batches from its own cycled iterators, so the
    # engine only needs a sequence of the right length to define an epoch.
    data_steps = list(range(len(train1_sup_loader)))

    logger.debug('Start training')
    trainer.run(data_steps, max_epochs=num_epochs)
    logger.debug('Finished training')
Пример #14
0
def run(output_path, config):
    """Train a classifier with Unsupervised Data Augmentation (UDA).

    Builds labelled / unlabelled / test loaders, a Nesterov SGD optimizer
    (optionally wrapped in SWA), a cosine learning-rate schedule with an
    optional warmup, and an ignite ``Engine`` whose update step combines a
    supervised cross-entropy loss (optionally filtered by Training Signal
    Annealing) with a consistency loss between the model's predictions on
    unlabelled inputs and on their augmented counterparts. Metrics are logged
    to TensorBoard (under ``output_path``) and MLflow.

    Args:
        output_path: log directory for the TensorBoard logger.
        config: dict of hyper-parameters. Keys read here: 'batch_size',
            'dataset', 'num_labelled_samples', 'data_path', 'num_workers',
            'model', 'learning_rate', 'momentum', 'weight_decay', 'with_SWA',
            'consistency_criterion' ("MSE" or "KL"), 'num_epochs',
            'min_lr_ratio', 'num_warmup_steps', 'consistency_lambda',
            'TSA_proba_min', 'TSA_proba_max', 'with_TSA', 'no_UDA' and
            optionally 'unlabelled_batch_size'.

    Raises:
        RuntimeError: if ``config['consistency_criterion']`` is neither
            "MSE" nor "KL".
    """

    device = "cuda"
    batch_size = config['batch_size']

    train_labelled_loader, train_unlabelled_loader, test_loader = \
        get_train_test_loaders(dataset_name=config['dataset'],
                               num_labelled_samples=config['num_labelled_samples'],
                               path=config['data_path'],
                               batch_size=batch_size,
                               unlabelled_batch_size=config.get('unlabelled_batch_size', None),
                               num_workers=config['num_workers'])

    model = get_model(config['model'])
    model = model.to(device)

    optimizer = optim.SGD(model.parameters(),
                          lr=config['learning_rate'],
                          momentum=config['momentum'],
                          weight_decay=config['weight_decay'],
                          nesterov=True)

    with_SWA = config['with_SWA']
    if with_SWA:
        optimizer = torchcontrib.optim.SWA(optimizer)

    criterion = nn.CrossEntropyLoss().to(device)
    if config['consistency_criterion'] == "MSE":
        consistency_criterion = nn.MSELoss()
    elif config['consistency_criterion'] == "KL":
        consistency_criterion = nn.KLDivLoss(reduction='batchmean')
    else:
        raise RuntimeError("Unknown consistency criterion {}".format(
            config['consistency_criterion']))

    # nn.KLDivLoss expects log-probabilities as its input, whereas nn.MSELoss
    # has to compare probabilities with probabilities.
    consistency_takes_log_input = config['consistency_criterion'] == "KL"

    consistency_criterion = consistency_criterion.to(device)

    # Bound once: also needed by the SWA handler below (it was previously an
    # undefined name there, raising NameError at the end of every epoch).
    num_epochs = config['num_epochs']

    le = len(train_labelled_loader)
    num_train_steps = le * num_epochs
    mlflow.log_param("num train steps", num_train_steps)

    lr = config['learning_rate']
    eta_min = lr * config['min_lr_ratio']
    num_warmup_steps = config['num_warmup_steps']

    # Cosine decay over the iterations that remain after warmup.
    lr_scheduler = CosineAnnealingLR(optimizer,
                                     eta_min=eta_min,
                                     T_max=num_train_steps - num_warmup_steps)

    if num_warmup_steps > 0:
        # NOTE(review): the warmup ends slightly above `lr`; the reason is not
        # visible here — confirm it is intentional.
        lr_scheduler = create_lr_scheduler_with_warmup(
            lr_scheduler,
            warmup_start_value=0.0,
            warmup_end_value=lr * (1.0 + 1.0 / num_warmup_steps),
            warmup_duration=num_warmup_steps)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def cycle(iterable):
        # Endless iterator: the unlabelled loader is re-iterated as many times
        # as the labelled epochs require.
        while True:
            for i in iterable:
                yield i

    train_unlabelled_loader_iter = cycle(train_unlabelled_loader)

    lam = config['consistency_lambda']

    tsa = TrainingSignalAnnealing(num_steps=num_train_steps,
                                  min_threshold=config['TSA_proba_min'],
                                  max_threshold=config['TSA_proba_max'])

    with_tsa = config['with_TSA']
    with_UDA = not config['no_UDA']

    def uda_process_function(engine, labelled_batch):
        """One optimisation step: supervised loss (+ UDA consistency loss)."""

        x, y = _prepare_batch(labelled_batch, device=device, non_blocking=True)

        if with_UDA:
            unsup_x, unsup_aug_x = next(train_unlabelled_loader_iter)
            unsup_x = convert_tensor(unsup_x, device=device, non_blocking=True)
            unsup_aug_x = convert_tensor(unsup_aug_x,
                                         device=device,
                                         non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        supervised_loss = loss
        step = engine.state.iteration - 1
        if with_tsa and with_UDA:
            new_y_pred, new_y = tsa(y_pred, y, step=step)
            new_loss = criterion(new_y_pred, new_y)

            engine.state.tsa_log = {
                "new_y_pred": new_y_pred,
                "loss": loss.item(),
                "tsa_loss": new_loss.item()
            }
            supervised_loss = new_loss

        # Unsupervised part
        if with_UDA:
            # The prediction on the original input is the target: detached so
            # no gradient flows through it.
            unsup_orig_y_pred = model(unsup_x).detach()
            unsup_orig_y_probas = torch.softmax(unsup_orig_y_pred, dim=-1)

            unsup_aug_y_pred = model(unsup_aug_x)
            if consistency_takes_log_input:
                unsup_aug_y_probas = torch.log_softmax(unsup_aug_y_pred, dim=-1)
            else:
                unsup_aug_y_probas = torch.softmax(unsup_aug_y_pred, dim=-1)

            consistency_loss = consistency_criterion(unsup_aug_y_probas,
                                                     unsup_orig_y_probas)

        # Out-of-place addition: `final_loss += ...` would modify the tensor
        # that `supervised_loss` also names, so the reported supervised loss
        # would silently equal the final loss.
        if with_UDA:
            final_loss = supervised_loss + lam * consistency_loss
        else:
            final_loss = supervised_loss

        optimizer.zero_grad()
        final_loss.backward()
        optimizer.step()

        return {
            'supervised batch loss': supervised_loss.item(),
            'consistency batch loss':
            consistency_loss.item() if with_UDA else 0.0,
            'final batch loss': final_loss.item(),
        }

    trainer = Engine(uda_process_function)

    if with_UDA and with_tsa:

        @trainer.on(Events.ITERATION_COMPLETED)
        def log_tsa(engine):
            step = engine.state.iteration - 1
            if step % 50 == 0:
                mlflow.log_metric("TSA threshold",
                                  tsa.thresholds[step].item(),
                                  step=step)
                mlflow.log_metric("TSA selection",
                                  engine.state.tsa_log['new_y_pred'].shape[0],
                                  step=step)
                mlflow.log_metric("Original X Loss",
                                  engine.state.tsa_log['loss'],
                                  step=step)
                mlflow.log_metric("TSA X Loss",
                                  engine.state.tsa_log['tsa_loss'],
                                  step=step)

    # An ignite scheduler is a callable handler; a bare torch scheduler
    # (no warmup wrapper) has to be stepped explicitly.
    if not hasattr(lr_scheduler, "step"):
        trainer.add_event_handler(Events.ITERATION_STARTED, lr_scheduler)
    else:
        trainer.add_event_handler(Events.ITERATION_STARTED,
                                  lambda engine: lr_scheduler.step())

    @trainer.on(Events.ITERATION_STARTED)
    def log_learning_rate(engine):
        step = engine.state.iteration - 1
        if step % 50 == 0:
            lr = optimizer.param_groups[0]['lr']
            mlflow.log_metric("learning rate", lr, step=step)

    if with_SWA:

        @trainer.on(Events.COMPLETED)
        def swap_swa_sgd(engine):
            optimizer.swap_swa_sgd()
            optimizer.bn_update(train_labelled_loader, model)

        @trainer.on(Events.EPOCH_COMPLETED)
        def update_swa(engine):
            # Average weights only over the last quarter of training.
            if engine.state.epoch - 1 > int(num_epochs * 0.75):
                optimizer.update_swa()

    metric_names = [
        'supervised batch loss', 'consistency batch loss', 'final batch loss'
    ]

    def output_transform(x, name):
        return x[name]

    for n in metric_names:
        RunningAverage(output_transform=partial(output_transform, name=n),
                       epoch_bound=False).attach(trainer, n)

    ProgressBar(persist=True,
                bar_format="").attach(trainer,
                                      event_name=Events.EPOCH_STARTED,
                                      closing_event_name=Events.COMPLETED)

    tb_logger = TensorboardLogger(log_dir=output_path)
    tb_logger.attach(trainer,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=[
                                                     'final batch loss',
                                                     'consistency batch loss',
                                                     'supervised batch loss'
                                                 ]),
                     event_name=Events.ITERATION_COMPLETED)
    tb_logger.attach(trainer,
                     log_handler=tbOptimizerParamsHandler(optimizer,
                                                          param_name="lr"),
                     event_name=Events.ITERATION_STARTED)

    metrics = {
        "accuracy": Accuracy(),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine, val_interval):
        if (engine.state.epoch - 1) % val_interval == 0:
            train_evaluator.run(train_labelled_loader)
            evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED,
                              run_validation,
                              val_interval=2)
    trainer.add_event_handler(Events.COMPLETED, run_validation, val_interval=1)

    tb_logger.attach(train_evaluator,
                     log_handler=tbOutputHandler(tag="train",
                                                 metric_names=list(
                                                     metrics.keys()),
                                                 another_engine=trainer),
                     event_name=Events.COMPLETED)

    tb_logger.attach(evaluator,
                     log_handler=tbOutputHandler(tag="test",
                                                 metric_names=list(
                                                     metrics.keys()),
                                                 another_engine=trainer),
                     event_name=Events.COMPLETED)

    def mlflow_batch_metrics_logging(engine, tag):
        step = trainer.state.iteration
        for name, value in engine.state.metrics.items():
            mlflow.log_metric("{} {}".format(tag, name), value, step=step)

    def mlflow_val_metrics_logging(engine, tag):
        step = trainer.state.epoch
        for name in metrics.keys():
            value = engine.state.metrics[name]
            mlflow.log_metric("{} {}".format(tag, name), value, step=step)

    trainer.add_event_handler(Events.ITERATION_COMPLETED,
                              mlflow_batch_metrics_logging, "train")
    train_evaluator.add_event_handler(Events.COMPLETED,
                                      mlflow_val_metrics_logging, "train")
    evaluator.add_event_handler(Events.COMPLETED, mlflow_val_metrics_logging,
                                "test")

    trainer.run(train_labelled_loader, max_epochs=num_epochs)
Пример #15
0
def train(args):
    """Train a Transformer language model with pytorch-ignite.

    Uses the ``bert-base-uncased`` BERT tokenizer; its vocabulary size is
    stored in ``args.num_embeddings`` before the model is built. Trains
    ``TransformerWithLMHead`` with Adam, gradient accumulation and gradient
    norm clipping under a warmup + cosine learning-rate schedule. Validation
    NLL / perplexity metrics are computed at the end of every epoch, and
    every ``args.eval_every`` iterations when that is positive.

    With ``args.mlm`` the model is trained as a masked LM (BERT-style masking
    via ``mask_tokens``). Otherwise the labels are the inputs themselves, and
    evaluation scores next-token prediction (logits[:-1] vs labels[1:]).

    Args:
        args: namespace. Fields read here: device, lr, weight_decay, mlm,
            mlm_probability, gradient_accumulation_steps, max_norm,
            eval_every, n_epochs, n_warmup, local_rank. ``num_embeddings``
            is written.
    """
    # NOTE(review): an *uncased* vocabulary combined with do_lower_case=False
    # means upper-case text is not lower-cased before lookup — confirm the
    # data is already lower-cased or that this is intended.
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased",
                                              do_lower_case=False)
    args.num_embeddings = len(
        tokenizer.vocab
    )  # We need this to create the model at next line (number of embeddings to use)
    model = TransformerWithLMHead(args)
    model.to(args.device)
    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)

    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler, train_num_words, valid_num_words = get_data_loaders(
        args, tokenizer)

    # Prepare masked tokens inputs/labels for masked language modeling: 80% MASK, 10% random, 10% original
    def mask_tokens(inputs):
        """Mask ``inputs`` in place for MLM; return ``(inputs, labels)``.

        Labels are -1 (ignored by the loss) at positions that were not
        selected for masking. The masks are boolean tensors on the same
        device as the tokens. With ``.byte()`` masks, ``~`` is a *bitwise*
        NOT on PyTorch >= 1.2 (1 -> 254, 0 -> 255). Every position then
        counts as "not masked", every label becomes -1, and the MLM loss has
        nothing to learn from.
        """
        labels = inputs.clone()
        mask_device = labels.device

        def _draw(probability):
            # Independent Bernoulli(probability) draw for every token position.
            probabilities = torch.full(labels.shape, float(probability),
                                       device=mask_device)
            return torch.bernoulli(probabilities).bool()

        masked_indices = _draw(args.mlm_probability)
        labels[~masked_indices] = -1  # We only compute loss on masked tokens
        indices_replaced = _draw(0.8) & masked_indices
        inputs[indices_replaced] = tokenizer.vocab[
            "[MASK]"]  # 80% of the time, replace masked input tokens with [MASK]
        # Half of the remaining 20% -> 10% of masked tokens become random words.
        indices_random = _draw(0.5) & masked_indices & ~indices_replaced
        random_words = torch.randint(args.num_embeddings,
                                     labels.shape,
                                     dtype=torch.long,
                                     device=args.device)
        inputs[indices_random] = random_words[
            indices_random]  # 10% of the time, replace masked input tokens with random word
        return inputs, labels

    def update(engine, batch):
        model.train()
        # NOTE(review): batches are transposed so the first axis is presumably
        # the sequence axis — confirm against get_data_loaders.
        inputs = batch.transpose(0, 1).contiguous().to(args.device)
        inputs, labels = mask_tokens(inputs) if args.mlm else (inputs, inputs)
        logits, loss = model(inputs, labels=labels)
        # Scale so the accumulated gradient is an average over the
        # accumulation window; the optimizer only steps every
        # `gradient_accumulation_steps` iterations.
        loss = loss / args.gradient_accumulation_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            inputs = batch.transpose(0, 1).contiguous().to(args.device)
            inputs, labels = mask_tokens(inputs) if args.mlm else (
                inputs,
                inputs)  # Prepare masked input/labels if we use masked LM
            logits = model(inputs)
            # Causal LM: logits at position t are scored against token t+1.
            shift_logits = logits[:-1] if not args.mlm else logits
            shift_labels = labels[1:] if not args.mlm else labels
            return shift_logits.view(-1,
                                     logits.size(-1)), shift_labels.view(-1)

    evaluator = Engine(inference)

    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(val_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))

    # Learning rate schedule: linearly warm-up to lr and then decrease the learning rate to zero with cosine schedule
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.n_epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr,
                                                args.n_warmup)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we average distributed metrics using average_distributed_scalar
    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    # Let's convert sub-word perplexities in word perplexities. If you need details: http://sjmielke.com/comparing-perplexities.htm
    metrics["average_word_ppl"] = MetricsLambda(
        lambda x: math.exp(x * val_loader.dataset.numel() / valid_num_words),
        metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model and configuration before we start to train
    if args.local_rank in [-1, 0]:
        checkpoint_handler, tb_logger = add_logging_and_checkpoint_saving(
            trainer, evaluator, metrics, model, optimizer, args)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)
Пример #16
0
# Per-class loss weights w_c = 1 / ln(1.02 + freq_c): the smaller a class's
# frequency, the larger its weight.
# NOTE(review): assumes Cityscapes.CLASS_FREQ holds per-class frequencies in
# [0, 1] — confirm against its definition.
class_freq = torch.from_numpy(Cityscapes.CLASS_FREQ).float()
weight = 1 / torch.log(1.02 + class_freq)
# Target pixels equal to 255 are excluded from the loss (presumably the
# void / unlabelled id — confirm against the dataset's label encoding).
loss_fn = torch.nn.CrossEntropyLoss(ignore_index=255, weight=weight)
loss_fn = loss_fn.cuda()

# Number of iterations used to warm up the learning rate before the cosine
# schedule takes over.
warmup_iterations = 1000

# Cosine annealing of the 'lr' value from args.learning_rate down to
# args.learning_rate * 1e-4. The cycle length subtracts the warmup
# iterations, presumably so that warmup + cosine together span
# args.epochs * len(train_loader) iterations.
scheduler = CosineAnnealingScheduler(
    optimizer,
    'lr',
    args.learning_rate,
    args.learning_rate * 1e-4,
    cycle_size=args.epochs * len(train_loader) - warmup_iterations,
)
# Prepend a warmup phase that ramps the LR from 0 to args.learning_rate over
# warmup_iterations iterations.
scheduler = create_lr_scheduler_with_warmup(scheduler, 0, args.learning_rate,
                                            warmup_iterations)

# Mixed precision via apex amp at opt level "O2".
# NOTE(review): `scheduler` was built with the optimizer object from before
# amp.initialize; this relies on amp returning the same (patched) optimizer —
# confirm.
model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
if args.distributed:
    # NOTE(review): BatchNorm layers are converted to SyncBN *after* the
    # optimizer and amp were set up. If the replacement layers get new
    # Parameter objects, the optimizer will not update their weight/bias.
    # Apex's own examples convert before building the optimizer and calling
    # amp.initialize — verify.
    model = convert_syncbn_model(model)
    model = DistributedDataParallel(model)

trainer = create_segmentation_trainer(
    model,
    optimizer,
    loss_fn,
    device=device,
    use_f16=True,
)
# The LR scheduler is advanced after every training iteration.
# NOTE(review): the other examples attach schedulers to ITERATION_STARTED.
# With ITERATION_COMPLETED the first iteration presumably runs at the
# optimizer's initial lr rather than the warmup start value (0) — confirm
# this is intended.
trainer.add_event_handler(Events.ITERATION_COMPLETED, scheduler)
Пример #17
0
def train():
    """Train a Transformer language model from the command line.

    Steps:

    1. Parse CLI arguments.
    2. Optionally initialise ``torch.distributed`` (NCCL backend, ``env://``).
    3. Load the dataset (explicit train/valid files, or a ``DATASETS`` entry
       fetched through ``DataDownloader``) together with its vocabulary and
       embeddings. The input vocabulary is written to
       ``<basedir>/vocabs.json``.
    4. Build a ``TransformerLanguageModel`` and an Adam optimizer.
    5. Train with an ignite ``Engine`` using gradient accumulation, gradient
       norm clipping and a warmup + cosine LR schedule.

    Validation metrics (NLL, perplexities) are computed at the end of every
    epoch, and every ``--eval_every`` iterations when that is positive. On the
    process with ``local_rank < 1``, the model is checkpointed into
    ``--basedir`` after every epoch, keeping the last 3.

    Returns ``None`` early, after logging an error, when only one of
    ``--train_file`` / ``--valid_file`` is given.
    """
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--dataset_key",
                        type=str,
                        default='wikitext-2',
                        help="key from DATASETS global")
    parser.add_argument("--train_file",
                        type=str,
                        help='Optional file path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        help='Optional file path to use for valid file')
    parser.add_argument("--dataset_cache",
                        type=str,
                        default=os.path.expanduser('~/.bl-data'),
                        help="Path or url of the dataset cache")
    parser.add_argument("--cache_features", type=str2bool, default=True)
    parser.add_argument("--d_model",
                        type=int,
                        default=410,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2100, help="FFN dimension")
    parser.add_argument("--num_heads",
                        type=int,
                        default=10,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch Size")
    parser.add_argument("--tokens",
                        choices=["words", "chars", "subwords"],
                        default="subwords",
                        help="What tokens to use")
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=0.25,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.0,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=20,
                        help="Num training epochs")
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=1000,
                        help="Num warmup steps")
    parser.add_argument("--eval_every",
                        type=int,
                        default=-1,
                        help="Evaluate every X steps (-1 => end of epoch)")

    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )
    parser.add_argument("--chars_per_word",
                        type=int,
                        default=40,
                        help="How many max characters per word")
    parser.add_argument(
        "--accum_grad_steps",
        type=int,
        default=1,
        help="Create effective batch size by accumulating grads without updates"
    )
    args = parser.parse_args()

    # --train_file and --valid_file must be given together (or not at all).
    if args.train_file and not args.valid_file:
        logger.error(
            "If you provide a train_file, you must provide a valid_file")
        return

    if not args.train_file and args.valid_file:
        logger.error(
            "If you provide a valid_file, you must also provide a train_file")
        return

    # Default output directory: unique per dataset/token type/process id.
    if args.basedir is None:
        args.basedir = 'transformer-{}-{}-{}'.format(args.dataset_key,
                                                     args.tokens, os.getpid())
    # Only rank -1/0 logs at INFO; other ranks log warnings and above.
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.info("Cache directory [%s]", args.dataset_cache)

    # Distributed if requested explicitly or if WORLD_SIZE > 1 in the
    # environment (presumably set by the launcher).
    args.distributed = args.distributed or int(os.environ.get("WORLD_SIZE",
                                                              1)) > 1

    if args.distributed:
        if args.local_rank == -1:
            # https://github.com/kubeflow/pytorch-operator/issues/128
            # https://github.com/pytorch/examples/blob/master/imagenet/main.py
            # NOTE(review): RANK is usually the *global* rank; using it as the
            # CUDA device index assumes a single node — confirm.
            logger.info("Setting local rank to RANK env variable")
            args.local_rank = int(os.environ['RANK'])
        logger.warning("Local rank (%d)", args.local_rank)
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    if args.train_file:
        dataset = {
            'train_file': args.train_file,
            'valid_file': args.valid_file
        }
    else:
        dataset = DataDownloader(DATASETS[args.dataset_key],
                                 args.dataset_cache).download()
    reader = create_reader(args.tokens, args.nctx, args.chars_per_word)

    preproc_data = load_embed_and_vocab(args.tokens, reader, dataset,
                                        args.dataset_key, args.d_model,
                                        args.cache_features)

    vocabs = preproc_data['vocabs']
    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs['x'], os.path.join(args.basedir, 'vocabs.json'))
    embeddings = preproc_data['embeddings']
    valid_num_words = preproc_data['valid_num_words']
    tgt_key = preproc_data['tgt_key']
    logger.info("Loaded embeddings")

    train_set = load_data(args.tokens, reader, dataset, 'train_file', vocabs,
                          args.cache_features)
    valid_set = load_data(args.tokens, reader, dataset, 'valid_file', vocabs,
                          args.cache_features)
    logger.info("valid. tokens [%s], valid. words [%s]",
                valid_set.tensors[-1].numel(), valid_num_words)

    # With a DistributedSampler the loader must not shuffle itself.
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_set) if args.distributed else None
    train_loader = DataLoader(train_set,
                              sampler=train_sampler,
                              batch_size=args.batch_size,
                              shuffle=(not args.distributed))

    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        valid_set) if args.distributed else None
    valid_loader = DataLoader(valid_set,
                              sampler=valid_sampler,
                              batch_size=args.batch_size,
                              shuffle=False)

    logger.info("Loaded datasets")

    model = TransformerLanguageModel.create(
        embeddings,
        hsz=args.d_model,
        d_ff=args.d_ff,
        tie_weights=(args.tokens != 'chars'),
        dropout=args.dropout,
        gpu=False,
        num_heads=args.num_heads,
        layers=args.num_layers,
        src_keys=['x'],
        tgt_key=tgt_key)
    model.to(args.device)
    train_loss = model.create_loss()
    train_loss.to(args.device)

    logger.info("Loaded model and loss")

    optimizer = Adam(model.parameters(),
                     lr=args.lr,
                     weight_decay=args.weight_decay)
    logger.info("Model has %s parameters",
                sum(p.numel() for p in model.parameters() if p.requires_grad))

    # Prepare model for distributed training if needed
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank)
        logger.info("Model located on %d", args.local_rank)

    def update(engine, batch):
        # One training iteration: next-token loss, gradients accumulated over
        # `accum_grad_steps` iterations before each optimizer step.
        model.train()
        x, y = batch
        inputs = {'x': x.to(args.device)}
        labels = y.to(args.device).transpose(0, 1).contiguous()
        logits = model(inputs, None)[0].transpose(0, 1).contiguous()
        # Logits at position t are scored against the label at t+1.
        shift_logits = logits[:-1]
        shift_labels = labels[1:]
        loss = train_loss(shift_logits, shift_labels)
        loss = loss / args.accum_grad_steps
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
        if engine.state.iteration % args.accum_grad_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        # NOTE(review): this is the loss already divided by accum_grad_steps.
        return loss.item()

    trainer = Engine(update)

    def inference(_, batch):
        # Returns flattened (logits, labels) for the evaluator's Loss metric.
        model.eval()
        with torch.no_grad():
            x, y = batch
            inputs = {'x': x.to(args.device)}
            labels = y.to(args.device).transpose(0, 1).contiguous()
            logits = model(inputs, None)[0].transpose(0, 1).contiguous()
            shift_logits = logits[:-1]
            shift_labels = labels[1:]
            return shift_logits.view(-1,
                                     logits.size(-1)), shift_labels.view(-1)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate at the end of each epoch and every 'eval_every' iterations if needed
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(valid_loader))
    if args.eval_every > 0:
        trainer.add_event_handler(
            Events.ITERATION_COMPLETED,
            lambda engine: evaluator.run(valid_loader)
            if engine.state.iteration % args.eval_every == 0 else None)
    if args.distributed:
        # Reseed the distributed samplers each epoch.
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Cosine decay from args.lr to 0 over len(train_loader) * epochs
    # iterations, preceded by a warmup from 0 to args.lr over warmup_steps.
    cos_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0,
                                             len(train_loader) * args.epochs)
    scheduler = create_lr_scheduler_with_warmup(cos_scheduler, 0.0, args.lr,
                                                args.warmup_steps)
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    metrics = {"nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1))}
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })

    if args.tokens == 'subwords':
        # If we compute subwords, need to renormalize for num words
        metrics["average_subword_ppl"] = MetricsLambda(math.exp,
                                                       metrics["average_nll"])
        metrics["average_word_ppl"] = MetricsLambda(
            lambda x: math.exp(x * valid_set.tensors[-1].numel() /
                               valid_num_words), metrics["average_nll"])
    else:
        metrics["average_word_ppl"] = MetricsLambda(math.exp,
                                                    metrics["average_nll"])

    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # Console output and checkpointing only on the main process.
    if args.local_rank < 1:
        # NOTE(review): despite its name, "valid_loss" is a running average of
        # the *training* update output and is not read anywhere in this
        # function.
        RunningAverage(output_transform=lambda x: x).attach(
            trainer, "valid_loss")
        # NOTE(review): prints the last batch's loss (already divided by
        # accum_grad_steps), not an epoch average, so the reported
        # perplexity is understated when accum_grad_steps > 1.
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, lambda _: print(
                "Epoch[{}] Training Loss: {:.2f}, Perplexity {:.2f}".format(
                    trainer.state.epoch, trainer.state.output,
                    np.exp(trainer.state.output))))
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: print("Validation: %s" % pformat(
                evaluator.state.metrics)))
        checkpoint_handler = ModelCheckpoint(args.basedir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3,
                                             create_dir=False)
        # Save the unwrapped model (DDP keeps it under `.module`).
        trainer.add_event_handler(Events.EPOCH_COMPLETED, checkpoint_handler,
                                  {'mymodel': getattr(model, 'module', model)})
    trainer.run(train_loader, max_epochs=args.epochs)