Code example #1
def create_model(embeddings, d_model, d_ff, num_heads, num_layers, rpr_k, rpr_value_on, d_k, checkpoint_name, activation):
    """Build a tied-weight TransformerLanguageModel and restore it from a checkpoint for inference."""
    # Normalize rpr_k: an empty list or a leading value < 1 disables relative attention,
    # and a single-element list is unwrapped to a scalar.
    rpr_k = listify(rpr_k)

    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = rpr_k[0]

    logger.info("Creating tied encoder decoder model")
    model = TransformerLanguageModel.create({'x': embeddings},
                                            hsz=d_model,
                                            d_ff=d_ff,
                                            tie_weights=True,
                                            dropout=0,
                                            gpu=False,
                                            num_heads=num_heads,
                                            layers=num_layers,
                                            rpr_k=rpr_k,
                                            rpr_value_on=rpr_value_on,
                                            d_k=d_k,
                                            activation=activation,
                                            src_keys=['x'], tgt_key='x')
    # .npz checkpoints go through the NumPy-based loader; anything else is treated
    # as a PyTorch state dict.
    if checkpoint_name.endswith('npz'):
        load_tlm_npz(model, checkpoint_name)
    else:
        tlm_load_state_dict(model, checkpoint_name)
    model.eval()
    print(model)
    return model
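
A minimal usage sketch for create_model above; the toy vocabulary, checkpoint path, and hyperparameter values (including the 'gelu' activation) are placeholders rather than values from the original project, and the embedding is loaded the same way as in the later examples.

# Hedged sketch: the vocab, checkpoint path, and hyperparameters below are placeholders.
import baseline.embeddings

vocab_x = {'<PAD>': 0, 'hello': 1, 'world': 2}
loaded = baseline.embeddings.load_embeddings(
    'x', dsz=512, known_vocab=vocab_x, embed_type='learned-positional')
model = create_model(loaded['embeddings'],
                     d_model=512,
                     d_ff=2048,
                     num_heads=8,
                     num_layers=8,
                     rpr_k=[8],
                     rpr_value_on=True,
                     d_k=None,
                     checkpoint_name='checkpoint-step-100000.npz',
                     activation='gelu')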
Code example #2
def _round_trip(embed_type, rpr_k=None):
    """Save a freshly created TLM to an npz file, reload it into a second model, and check both produce the same outputs."""
    test_file = os.path.join(file_loc, "test_data", "blah.npz")
    d_model = 40
    vocab_x = {
        'a': 1,
        'aardvark': 100,
        'beandip': 42,
        'cheerio': 86,
        'dumdum': 129,
        'eoyre': 3
    }
    embeddings = {}
    vocabs = {'x': vocab_x}
    src_x_embedding = baseline.embeddings.load_embeddings(
        'x', dsz=d_model, known_vocab=vocab_x, embed_type=embed_type)
    embeddings['x'] = src_x_embedding['embeddings']

    src_model = TransformerLanguageModel.create(embeddings,
                                                hsz=d_model,
                                                dropout=0.1,
                                                gpu=False,
                                                num_heads=4,
                                                layers=2,
                                                rpr_k=rpr_k,
                                                src_keys=['x'],
                                                tgt_key='x')

    save_tlm_npz(src_model, test_file)

    dst_x_embedding = baseline.embeddings.load_embeddings(
        'x', dsz=d_model, known_vocab=vocab_x, embed_type=embed_type)
    embeddings['x'] = dst_x_embedding['embeddings']
    dst_model = TransformerLanguageModel.create(embeddings,
                                                hsz=d_model,
                                                dropout=0.1,
                                                gpu=False,
                                                num_heads=4,
                                                layers=2,
                                                rpr_k=rpr_k,
                                                src_keys=['x'],
                                                tgt_key='x')
    load_tlm_npz(dst_model, test_file)

    B = 4
    T = 7
    a_batch = torch.randint(0, 9, (B, T)).long()
    a_lengths = torch.randint(0, T, (B, )).long()
    out_pyt1 = _call_model(src_model, {
        'x': a_batch,
        'lengths': a_lengths
    }).detach().numpy()
    out_pyt2 = _call_model(dst_model, {
        'x': a_batch,
        'lengths': a_lengths
    }).detach().numpy()
    return np.allclose(out_pyt1, out_pyt2, atol=1e-6)
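
A hedged sketch of driving the _round_trip helper from pytest; the test names are illustrative, and the embed_type labels mirror the choices in the training script further down.

# Hypothetical pytest wrappers; the embed types and rpr_k value are illustrative.
def test_round_trip_default():
    assert _round_trip('default')

def test_round_trip_learned_positional_rpr():
    assert _round_trip('learned-positional', rpr_k=8)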
Code example #3
    @classmethod
    def load(cls, embeddings, **kwargs):
        c = cls("tlm-words-embed", **kwargs)

        if embeddings.endswith('.bin'):
            # HuggingFace checkpoint, convert on the fly
            from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
            unmatch = load_tlm_transformers_bin(
                c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
            if unmatch['missing'] or unmatch['unexpected']:
                raise Exception("Unable to load the HuggingFace checkpoint")
        # elif/else so a converted .bin checkpoint is not loaded a second time
        elif mime_type(embeddings) == 'application/zip':
            load_tlm_npz(c, embeddings)
        else:
            tlm_load_state_dict(c, embeddings)
        return c
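
A hedged sketch of calling this load classmethod; TLMWordsEmbed is a placeholder name for whichever embedding class defines the method, and the checkpoint paths and keyword arguments are illustrative.

# Placeholder class name and paths; the keyword arguments depend on the concrete subclass.
embed_kwargs = {}
from_npz = TLMWordsEmbed.load('bert-base.npz', **embed_kwargs)        # npz checkpoint
from_bin = TLMWordsEmbed.load('pytorch_model.bin', **embed_kwargs)    # HuggingFace .bin, converted on the fly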
Code example #4
File: model.py, Project: dpressel/mead-baseline
    @classmethod
    def create(cls, embeddings, **kwargs):

        lm = cls()
        lm.gpu = kwargs.get('gpu', True)
        lm.tgt_key = kwargs.get('tgt_key')
        if lm.tgt_key is None:
            raise Exception('Need a `tgt_key` to know which source vocabulary should be used for the destination')
        lm.src_keys = kwargs.get('src_keys', embeddings.keys())
        lm.create_layers(embeddings, **kwargs)
        # An optional `checkpoint` kwarg restores weights: .npz uses the NumPy loader,
        # anything else is loaded as a PyTorch state dict.
        checkpoint_name = kwargs.get('checkpoint')
        if checkpoint_name is not None:
            if checkpoint_name.endswith('npz'):
                load_tlm_npz(lm, checkpoint_name)
            else:
                lm.load_state_dict(torch.load(checkpoint_name))
        return lm
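
A hedged sketch of the create classmethod above via TransformerLanguageModel; the hyperparameters echo the other examples here and the checkpoint name is a placeholder (omit `checkpoint` to skip restoring weights).

# Hedged sketch: the checkpoint name is a placeholder, and `embeddings` is the
# {'x': ...} dict built as in the earlier examples.
lm = TransformerLanguageModel.create(embeddings,
                                     hsz=512,
                                     num_heads=8,
                                     layers=8,
                                     gpu=False,
                                     checkpoint='checkpoint-step-100000.pth',
                                     src_keys=['x'],
                                     tgt_key='x')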
Code example #5
File: embeddings.py, Project: wenshuoliu/baseline
    @classmethod
    def load(cls, embeddings, **kwargs):
        c = cls("tlm-words-embed", **kwargs)

        if embeddings.endswith('.bin'):
            # HuggingFace checkpoint, convert on the fly
            from eight_mile.pytorch.serialize import load_tlm_transformers_bin, BERT_HF_FT_LAYER_MAP
            unmatch = load_tlm_transformers_bin(c, embeddings, replace_layers=BERT_HF_FT_LAYER_MAP)
            if unmatch['missing'] or unmatch['unexpected']:
                raise Exception("Unable to load the HuggingFace checkpoint")
        # elif/else so a converted .bin checkpoint is not loaded a second time
        elif mime_type(embeddings) == 'application/zip' and not embeddings.endswith("pth"):
            keys_to_restore = set(c.embeddings.keys())
            filtered_keys = keys_to_restore.difference(c.skippable)
            if not keys_to_restore:
                raise Exception("No keys to restore!")
            if len(filtered_keys) < len(keys_to_restore):
                logger.warning("Restoring only keys [%s]", ' '.join(filtered_keys))
            load_tlm_npz(c, embeddings, filtered_keys)
        else:
            tlm_load_state_dict(c, embeddings)
        return c
Code example #6
def train():
    parser = ArgumentParser()
    parser.add_argument("--basedir", type=str)
    parser.add_argument("--train_file",
                        type=str,
                        required=True,
                        help='File path to use for train file')
    parser.add_argument("--valid_file",
                        type=str,
                        required=True,
                        help='File path to use for valid file')
    parser.add_argument("--dataset_key",
                        default="tlm",
                        help="dataset key for basedir")
    parser.add_argument(
        "--embed_type",
        type=str,
        default='default',
        choices=["default", "positional", "learned-positional"],
        help="register label of the embeddings")
    parser.add_argument("--d_model",
                        type=int,
                        default=512,
                        help="Model dimension (and embedding dsz)")
    parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension")
    parser.add_argument(
        "--d_k",
        type=int,
        default=None,
        help="Dimension per head.  Use if num_heads=1 to reduce dims")
    parser.add_argument("--num_heads",
                        type=int,
                        default=8,
                        help="Number of heads")
    parser.add_argument("--num_layers",
                        type=int,
                        default=8,
                        help="Number of layers")
    parser.add_argument("--num_train_workers",
                        type=int,
                        default=4,
                        help="Number train workers")
    parser.add_argument("--nctx",
                        type=int,
                        default=256,
                        help="Max input length")
    parser.add_argument("--file_type",
                        default='json',
                        help="Glob pattern for data")
    parser.add_argument("--batch_size",
                        type=int,
                        default=256,
                        help="Batch Size")
    parser.add_argument("--subword_model_file",
                        type=str,
                        help="The BPE model file",
                        required=True)
    parser.add_argument("--subword_vocab_file",
                        type=str,
                        help="The BPE subword vocab",
                        required=True)
    parser.add_argument("--dropout", type=float, default=0.1, help="Dropout")
    parser.add_argument("--ffn_pdrop",
                        type=float,
                        default=0.0,
                        help="Dropout in the dense stack")
    parser.add_argument("--layer_drop",
                        type=float,
                        default=0.0,
                        help="LayerDrop to apply")
    parser.add_argument("--lr_scheduler",
                        type=str,
                        default='cosine',
                        help="The type of learning rate decay scheduler")
    parser.add_argument("--lr_decay_steps",
                        type=int,
                        help="decay steps of lr scheduler")
    parser.add_argument("--lr_decay_rate",
                        type=float,
                        help="decay rate of lr scheduler")
    parser.add_argument("--lr_alpha",
                        type=float,
                        help="parameter alpha for cosine decay scheduler")
    parser.add_argument("--optim",
                        default="adamw",
                        type=str,
                        help="Optimizer to use (defaults to adamw)")
    parser.add_argument("--lr",
                        type=float,
                        default=4.0e-4,
                        help="Learning rate")
    parser.add_argument("--clip",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--weight_decay",
                        type=float,
                        default=1.0e-2,
                        help="Weight decay")
    parser.add_argument("--epochs",
                        type=int,
                        default=32,
                        help="Num training epochs")
    parser.add_argument(
        "--restart_from",
        type=str,
        help="Option allows you to restart from a previous checkpoint")
    parser.add_argument("--restart_tt",
                        type=str,
                        help="Optional param for legacy checkpoints",
                        choices=['step', 'epoch', 'ignore'])
    parser.add_argument("--warmup_steps",
                        type=int,
                        default=10000,
                        help="Num warmup steps")
    parser.add_argument("--saves_per_epoch",
                        type=int,
                        default=10,
                        help="The number of checkpoints to save per epoch")
    parser.add_argument("--mlm",
                        type=str2bool,
                        default=True,
                        help="Use Masked Language Model (MLM) objective")
    parser.add_argument("--preprocessed",
                        type=str2bool,
                        default=True,
                        help="Has the data already been preprocessed?")
    parser.add_argument(
        '--rpr_k',
        help='Relative attention positional sizes; pass 0 if you do not want relative attention',
        type=int,
        default=[8],
        nargs='+')
    parser.add_argument(
        '--rpr_value_on',
        type=str2bool,
        default=True,
        help="In relative attention, whether to add the positional correction to the values "
        "in addition to the correction to the attention matrix")
    parser.add_argument("--windowed_ra",
                        type=str2bool,
                        default=False,
                        help="whether prevent attention beyond rpr_k")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument("--distributed",
                        type=str2bool,
                        default=False,
                        help="Are we doing distributed training?")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help=
        "Local rank for distributed training (-1 means use the environment variables to find)"
    )

    args = parser.parse_args()

    if args.basedir is None:
        args.basedir = 'lm-{}-bpe-{}'.format(args.dataset_key, os.getpid())
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    num_gpus = get_num_gpus_multiworker()
    args.distributed = args.distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    do_on_demand_masking = args.mlm and not args.preprocessed
    if do_on_demand_masking:
        logger.info(f"On-demand masking is turned on")
    if args.distributed:
        args.device, updated_local_rank = init_distributed(args.local_rank)
        args.local_rank = updated_local_rank

    if args.file_type == 'tfrecord':
        reader_type = 'tfrecord'
    elif args.preprocessed:
        reader_type = 'preprocessed'
    else:
        reader_type = 'lang'
    reader = MultiFileDatasetReader(
        src_nctx=args.nctx,
        model_file=args.subword_model_file,
        vocab_file=args.subword_vocab_file,
        file_type=args.file_type,
        reader_type=reader_type,
        record_keys=['x', 'y'] if args.mlm else ['x'])

    # This looks a bit funny but the streaming reader ignores our vocab and gives us the one from the subword_model
    # However, we do need to get counts from our dataset for validation so we can calculate the perplexity
    vocab = reader.build_vocab([args.valid_file])
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=args.d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=args.embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(args.basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(args.basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(args.train_file, vocabs)
    valid_set = reader.load(args.valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=args.batch_size,
                              num_workers=args.num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=args.batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", args.embed_type)

    if args.mlm:
        mask_from = vocabs
        vocab_size = len(mask_from)
        # Default to -1 so the check below fires if "[MASK]" is missing from the vocab
        mask_value = mask_from.get("[MASK]", -1)
        if mask_value == -1:
            logger.error(
                "We could not find a suitable masking token in the vocab")
            return

    if len(args.rpr_k) == 0 or args.rpr_k[0] < 1:
        rpr_k = None
    elif len(args.rpr_k) == 1:
        rpr_k = args.rpr_k[0]
    else:
        rpr_k = args.rpr_k

    TLM = TransformerMaskedLanguageModel if args.mlm else TransformerLanguageModel
    model = TLM.create(embeddings,
                       hsz=args.d_model,
                       d_ff=args.d_ff,
                       tie_weights=True,
                       dropout=args.dropout,
                       gpu=False,
                       num_heads=args.num_heads,
                       layers=args.num_layers,
                       rpr_k=rpr_k,
                       d_k=args.d_k,
                       ffn_pdrop=args.ffn_pdrop,
                       windowed_ra=args.windowed_ra,
                       rpr_value_on=args.rpr_value_on,
                       layer_drop=args.layer_drop,
                       src_keys=['x'],
                       tgt_key='x')

    model.to(args.device)
    loss_function = model.create_loss()
    loss_function.to(args.device)

    logger.info("Loaded model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    valid_steps = len(valid_loader)
    update_on = steps_per_epoch // args.saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(args.lr_scheduler,
                            args.lr,
                            steps_per_epoch,
                            args.epochs,
                            logger,
                            decay_steps=args.lr_decay_steps,
                            decay_rate=args.lr_decay_rate,
                            alpha=args.lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr)

    global_step = 0
    start_epoch = 0
    if args.restart_from:

        if args.restart_from.endswith('npz'):
            load_tlm_npz(model, args.restart_from)
        else:
            model.load_state_dict(torch.load(args.restart_from))
        vec = args.restart_from.split("-")

        if args.restart_tt:
            tick_type = args.restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch

        elif tick_type == 'step':
            start_epoch = step_num // steps_per_epoch
            global_step = step_num
        else:
            logger.warning(
                f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0"
            )

        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            args.restart_from, global_step, start_epoch + 1)

    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=args.optim,
                                 lr=args.lr,
                                 lr_function=lr_sched,
                                 weight_decay=args.weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if args.distributed:
        # This program assumes pure data parallelism; each model replica is on a single GPU.
        # If we wanted to support model and data parallelism we would need to update
        # the selection of GPUs based on rank: it would need to select multiple ids
        # per rank. Here we select only a single GPU and use it for input and output.
        model = DistributedDataParallel(model,
                                        device_ids=[args.device],
                                        output_device=args.device)
        logger.info("Model located on %s", args.device)

    model_base = os.path.join(args.basedir, 'checkpoint')
    steps = global_step

    timer = Timer()
    for epoch in range(start_epoch, args.epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1
            x, y = batch
            inputs = x.to(args.device)
            labels = y.to(args.device)
            if do_on_demand_masking:
                inputs, labels, _ = on_demand_mlm_masking(
                    inputs, labels, mask_value, vocab_size)
            inputs = {'x': inputs}

            labels = labels.transpose(0, 1).contiguous()
            logits = model(inputs, None)[0].transpose(0, 1).contiguous()
            if args.mlm:
                loss = loss_function(logits, labels)
            else:
                shift_logits = logits[:-1]
                shift_labels = labels[1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)

            if (i + 1) % update_on == 0 and args.local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optimizer.current_lr)
                save_checkpoint(model, model_base, steps, tick_type='step')

        # How much time elapsed in minutes
        elapsed = timer.elapsed(True)
        train_token_loss = avg_loss.avg
        # This is the average training token-level loss across all machines
        # This is the token-level training perplexity
        train_token_ppl = math.exp(train_token_loss)
        metrics['train_elapsed_min'] = elapsed
        metrics['average_train_loss'] = train_token_loss
        metrics['train_ppl'] = train_token_ppl
        if args.local_rank < 1:
            avg_valid_loss = Average('average_valid_loss')
            timer.start()
            model.eval()
            valid_itr = iter(valid_loader)
            for j in range(valid_steps):
                batch = next(valid_itr)
                with torch.no_grad():
                    x, y = batch
                    inputs = x.to(args.device)
                    labels = y.to(args.device)

                    if do_on_demand_masking:
                        inputs, labels, _ = on_demand_mlm_masking(
                            inputs, labels, mask_value, vocab_size)
                    inputs = {'x': inputs}
                    labels = labels.transpose(0, 1).contiguous()
                    logits = model(inputs, None)[0].transpose(0,
                                                              1).contiguous()
                    if args.mlm:
                        loss = loss_function(logits, labels)
                    else:
                        shift_logits = logits[:-1]
                        shift_labels = labels[1:]
                        loss = loss_function(shift_logits, shift_labels)
                    avg_valid_loss.update(loss.item())

            valid_token_loss = avg_valid_loss.avg
            valid_token_ppl = math.exp(valid_token_loss)

            metrics['valid_elapsed_min'] = timer.elapsed(True)
            metrics['average_valid_loss'] = valid_token_loss
            metrics['average_valid_word_ppl'] = valid_token_ppl
            logger.info(metrics)
            save_checkpoint(model, model_base, epoch, save_npz=True)
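
A hedged sketch of invoking this training entry point programmatically by populating sys.argv before calling train(); every path below is a placeholder, and only the required flags plus a couple of size overrides are shown.

# Hedged sketch: all paths are placeholders.
import sys

sys.argv = [
    'pretrain_tlm',
    '--train_file', 'train.json',
    '--valid_file', 'valid.json',
    '--subword_model_file', 'bpe.model',
    '--subword_vocab_file', 'bpe.vocab',
    '--batch_size', '16',
    '--epochs', '1',
]
train()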
Code example #7
def run(basedir=None,
        train_file=None,
        valid_file=None,
        dataset_key='tlm',
        embed_type='default',
        d_model=512,
        d_ff=2048,
        d_k=None,
        num_heads=8,
        num_layers=8,
        num_train_workers=4,
        nctx=256,
        file_type='json',
        batch_size=256,
        subword_model_file=None,
        subword_vocab_file=None,
        dropout=0.1,
        ffn_pdrop=0.0,
        layer_drop=0.0,
        lr_scheduler='cosine',
        lr_decay_steps=None,
        lr_decay_rate=None,
        lr_alpha=0.0,
        optim='adamw',
        lr=4.0e-4,
        clip=1.0,
        weight_decay=1.0e-2,
        epochs=32,
        restart_from=None,
        restart_tt=None,
        warmup_steps=10000,
        saves_per_epoch=10,
        mlm=True,
        preprocessed=True,
        rpr_k=[8],
        rpr_value_on=False,
        windowed_ra=False,
        device="cuda",
        distributed=False,
        local_rank=-1,
        extra_tokens=["[CLS]", "[MASK]"],
        do_early_stopping=False,
        model_type='transformer-mlm',
        modules=[],
        ra_type=None,
        transformer_type=None,
        **kwargs):
    if basedir is None:
        basedir = 'lm-{}-bpe-{}'.format(dataset_key, os.getpid())
    logging.basicConfig(
        level=logging.INFO if local_rank in [-1, 0] else logging.WARN)

    for module in modules:
        import_user_module(module)
    num_gpus = get_num_gpus_multiworker()
    distributed = distributed or num_gpus > 1
    logger.info(f"Using {num_gpus} GPUs in this job.")

    do_on_demand_masking = mlm and not preprocessed
    if do_on_demand_masking:
        logger.info(f"On-demand masking is turned on")
    if distributed:
        device, updated_local_rank = init_distributed(local_rank)
        local_rank = updated_local_rank

    if file_type == 'tfrecord':
        reader_type = 'tfrecord'
    elif preprocessed:
        reader_type = 'preprocessed'
    else:
        reader_type = 'lang'
    reader = MultiFileDatasetReader(src_nctx=nctx,
                                    model_file=subword_model_file,
                                    vocab_file=subword_vocab_file,
                                    file_type=file_type,
                                    reader_type=reader_type,
                                    record_keys=['x', 'y'] if mlm else ['x'],
                                    extra_tokens=extra_tokens)

    # This looks a bit funny but the streaming reader ignores our vocab and gives us the one from the subword_model
    # However, we do need to get counts from our dataset for validation so we can calculate the perplexity
    vocab = reader.build_vocab([valid_file])
    # If we are not using chars, then use 'x' for both input and output
    preproc_data = baseline.embeddings.load_embeddings(
        'x',
        dsz=d_model,
        known_vocab=vocab['x'],
        preserve_vocab_indices=True,
        embed_type=embed_type)
    vocabs = preproc_data['vocab']

    os.makedirs(basedir, exist_ok=True)
    # We want to make sure to save our input vocab into the basedir for reuse later
    write_json(vocabs, os.path.join(basedir, 'vocabs.json'))
    embeddings = {'x': preproc_data['embeddings']}
    logger.info("Loaded embeddings")

    train_set = reader.load(train_file, vocabs)
    valid_set = reader.load(valid_file,
                            vocabs,
                            distribute=False,
                            shuffle=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=num_train_workers)
    valid_loader = DataLoader(valid_set, batch_size=batch_size)
    logger.info("Loaded datasets")
    logger.info("Using embedding type [%s]", embed_type)

    if 'mlm' in model_type:
        mask_from = vocabs
        vocab_size = len(mask_from)
        # Default to -1 so the check below fires if "[MASK]" is missing from the vocab
        mask_value = mask_from.get("[MASK]", -1)
        if mask_value == -1:
            logger.error(
                "We could not find a suitable masking token in the vocab")
            return

    if len(rpr_k) == 0 or rpr_k[0] < 1:
        rpr_k = None
    elif len(rpr_k) == 1:
        rpr_k = None if rpr_k[0] == 0 else rpr_k[0]
    if ra_type is not None and ra_type != 'shaw' and rpr_k is not None:
        logger.warning(
            f"Relative attention mismatch: you requested {ra_type} with rpr_k set. Disabling rpr_k."
        )
        rpr_k = None

    model = create_lang_model(
        embeddings,
        hsz=d_model,
        nctx=nctx,  # Only for gMLP
        d_ff=d_ff,
        tie_weights=True,
        dropout=dropout,
        gpu=False,
        num_heads=num_heads,
        layers=num_layers,
        rpr_k=rpr_k,
        d_k=d_k,
        ffn_pdrop=ffn_pdrop,
        windowed_ra=windowed_ra,
        rpr_value_on=rpr_value_on,
        layer_drop=layer_drop,
        model_type=model_type,
        ra_type=ra_type,
        transformer_type=transformer_type,
        src_keys=['x'],
        tgt_key='x')
    model.to(device)

    loss_function = model.create_loss()
    loss_function.to(device)

    logger.info("Loaded model and loss")

    steps_per_epoch = len(train_loader) // num_gpus
    update_on = steps_per_epoch // saves_per_epoch
    report_on = max(10, update_on) // 10
    logger.info(
        f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps."
    )
    lr_decay = get_lr_decay(lr_scheduler,
                            lr,
                            steps_per_epoch,
                            epochs,
                            logger,
                            decay_steps=lr_decay_steps,
                            decay_rate=lr_decay_rate,
                            alpha=lr_alpha)
    linear_warmup = WarmupLinearSchedulerPyTorch(warmup_steps, lr=lr)
    lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=lr)

    global_step = 0
    start_epoch = 0
    if restart_from:

        if restart_from.endswith('npz'):
            load_tlm_npz(model, restart_from)
        else:
            model.load_state_dict(torch.load(restart_from))
        vec = restart_from.split("-")

        if restart_tt:
            tick_type = restart_tt
        else:
            tick_type = vec[-2]
        step_num = int(vec[-1].split(".")[0])
        if tick_type == 'epoch':
            start_epoch = step_num
            global_step = start_epoch * steps_per_epoch

        elif tick_type == 'step':
            start_epoch = step_num // steps_per_epoch
            global_step = step_num
        else:
            logger.warning(
                f"The previous tick was {step_num} but command-line specifies to ignore, setting to 0"
            )

        logger.info(
            "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d",
            restart_from, global_step, start_epoch + 1)

    optimizer = OptimizerManager(model,
                                 global_step,
                                 optim=optim,
                                 lr=lr,
                                 lr_function=lr_sched,
                                 weight_decay=weight_decay)
    logger.info("Model has {:,} parameters".format(
        sum(p.numel() for p in model.parameters() if p.requires_grad)))

    # Prepare model for distributed training if needed
    if distributed:
        # This program assumes pure data parallelism; each model replica is on a single GPU.
        # If we wanted to support model and data parallelism we would need to update
        # the selection of GPUs based on rank: it would need to select multiple ids
        # per rank. Here we select only a single GPU and use it for input and output.
        model = DistributedDataParallel(model,
                                        device_ids=[device],
                                        output_device=device,
                                        find_unused_parameters=True)
        logger.info("Model located on %s", device)

    model_base = os.path.join(basedir, 'checkpoint')
    steps = global_step
    best_valid_loss = np.inf

    timer = Timer()
    for epoch in range(start_epoch, epochs):
        avg_loss = Average('average_train_loss')
        metrics = {}
        optimizer.zero_grad()
        timer.start()
        model.train()
        train_itr = iter(train_loader)
        for i in range(steps_per_epoch):
            batch = next(train_itr)
            steps += 1
            x, y = batch
            inputs = x.to(device)
            labels = y.to(device)
            if do_on_demand_masking:
                inputs, labels, _ = on_demand_mlm_masking(
                    inputs, labels, mask_value, vocab_size)
            inputs = {'x': inputs}

            labels = labels.contiguous()
            logits = model(inputs, None)[0].contiguous()
            if mlm:
                loss = loss_function(logits, labels)
            else:
                shift_logits = logits[:, :-1]
                shift_labels = labels[:, 1:]
                loss = loss_function(shift_logits, shift_labels)
            loss.backward()
            avg_loss.update(loss.item())

            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
            optimizer.step()
            optimizer.zero_grad()
            if (i + 1) % report_on == 0:
                logging.info(avg_loss)

            if (i + 1) % update_on == 0 and local_rank < 1:
                elapsed = timer.elapsed(True)
                logging.info('elapsed time this epoch %d min', elapsed)
                logging.info('elapsed step time %f steps/min', i / elapsed)
                logging.info('LR: %f', optimizer.current_lr)

                if not do_early_stopping:
                    save_checkpoint(model, model_base, steps, tick_type='step')
                else:
                    valid_token_loss = validate(model, loss_function,
                                                valid_loader, avg_loss, timer,
                                                metrics, do_on_demand_masking,
                                                mlm, mask_value, vocab_size,
                                                device)
                    if valid_token_loss < best_valid_loss:
                        best_valid_loss = valid_token_loss
                        logger.info(
                            f"New best valid loss: {best_valid_loss}. Saving checkpoint..."
                        )
                        save_checkpoint(model,
                                        model_base,
                                        steps,
                                        tick_type='step')
                    model.train()

        if not do_early_stopping:
            _ = validate(model, loss_function, valid_loader, avg_loss, timer,
                         metrics, do_on_demand_masking, mlm, mask_value,
                         vocab_size, device)
            save_checkpoint(model, model_base, epoch, tick_type='epoch')
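
A hedged sketch of calling run() directly with keyword arguments instead of a CLI wrapper; all paths are placeholders and every other parameter falls back to the defaults in the signature above.

# Hedged sketch: all paths are placeholders.
run(train_file='train.json',
    valid_file='valid.json',
    subword_model_file='bpe.model',
    subword_vocab_file='bpe.vocab',
    batch_size=16,
    epochs=1,
    mlm=True)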