def test_model2model_from_pretrained_not_bert(self):
        logging.basicConfig(level=logging.INFO)
        with self.assertRaises(ValueError):
            _ = Model2Model.from_pretrained('roberta')

        with self.assertRaises(ValueError):
            _ = Model2Model.from_pretrained('distilbert')

        with self.assertRaises(ValueError):
            _ = Model2Model.from_pretrained('does-not-exist')

def test_model2model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = Model2Model.from_pretrained(model_name)
            self.assertIsInstance(model.encoder, BertModel)
            self.assertIsInstance(model.decoder, BertForMaskedLM)
            self.assertEqual(model.decoder.config.is_decoder, True)
            self.assertEqual(model.encoder.config.is_decoder, False)
Example #3
def get_BertAbs_model(args):
    """ Initializes the BertAbs model for finetuning.
    """
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=False)
    decoder_config = BertConfig(
        hidden_size=768,
        num_hidden_layers=6,
        num_attention_heads=8,
        intermediate_size=2048,
        hidden_dropout_prob=0.2,
        attention_probs_dropout_prob=0.2,
    )
    # The decoder is a freshly initialized (random-weight) 6-layer BERT with an LM head.
    decoder_model = BertForMaskedLM(decoder_config)

    model = Model2Model.from_pretrained(args.model_name_or_path, decoder_model=decoder_model)
    return tokenizer, model
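A minimal usage sketch for the helper above, assuming an argparse-style namespace; the "bert-base-uncased" checkpoint is only an illustration, since get_BertAbs_model reads nothing but model_name_or_path.

from argparse import Namespace

# Hypothetical arguments for illustration; only model_name_or_path is read above.
args = Namespace(model_name_or_path="bert-base-uncased")
tokenizer, model = get_BertAbs_model(args)

# The encoder comes from the pretrained checkpoint; the decoder starts from random weights.
print(type(model.encoder).__name__, type(model.decoder).__name__)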
Example #4
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help="The input training data file (a text file).",
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )

    # Optional parameters
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--do_evaluate",
        action="store_true",
        help="Run model evaluation on out-of-sample data.",
    )
    parser.add_argument("--do_train", action="store_true", help="Run training.")
    parser.add_argument(
        "--do_overwrite_output_dir",
        action="store_true",
        help="Whether to overwrite the output dir.",
    )
    parser.add_argument(
        "--model_name_or_path",
        default="bert-base-cased",
        type=str,
        help="The model checkpoint to initialize the encoder and decoder's weights with.",
    )
    parser.add_argument(
        "--model_type",
        default="bert",
        type=str,
        help="The decoder architecture to be fine-tuned.",
    )
    parser.add_argument(
        "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
    )
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--to_cpu", action="store_true", help="Whether to force training on CPU."
    )
    parser.add_argument(
        "--num_train_epochs",
        default=10,
        type=int,
        help="Total number of training epochs to perform.",
    )
    parser.add_argument(
        "--per_gpu_eval_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for eval.",
    )
    parser.add_argument(
        "--per_gpu_train_batch_size",
        default=4,
        type=int,
        help="Batch size per GPU/CPU for training.",
    )
    parser.add_argument(
        "--input_block_size",
        default=256,
        type=int,
        help="Max seq length for input",
    )
    parser.add_argument(
        "--output_block_size",
        default=64,
        type=int,
        help="Max seq length for output",
    )

    parser.add_argument(
        "--trained_checkpoints",
        default="",
        type=str,
        help="trained_checkpoints",
    )

    parser.add_argument(
        "--decoding_type",
        default="pnt",
        type=str,
        help="",
    )

    parser.add_argument(
        "--encoder_lr",
        default=5e-4,
        type=float,
        help="encoder's learning rate",
    )

    parser.add_argument(
        "--decoder_lr",
        default=5e-4,
        type=float,
        help="encoder's learning rate",
    )

    parser.add_argument(
        "--encoder_warmup",
        default=10,
        type=int,
        help="encoder's learning rate",
    )

    parser.add_argument(
        "--decoder_warmup",
        default=100,
        type=int,
        help="encoder's learning rate",
    )

    parser.add_argument("--seed", default=42, type=int)
    args = parser.parse_args()

    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
        and not args.do_overwrite_output_dir
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty. Use --do_overwrite_output_dir to overwrite.".format(
                args.output_dir
            )
        )

    # Set up training device
    if args.to_cpu or not torch.cuda.is_available():
        args.device = torch.device("cpu")
        args.n_gpu = 0
    else:
        args.device = torch.device("cuda")
        args.n_gpu = torch.cuda.device_count()
        print(args.n_gpu)

    # Load pretrained model and tokenizer. The decoder's weights are randomly initialized.
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)
    #config = BertConfig.from_pretrained(args.model_name_or_path)
    #config.num_hidden_layers=3
    #config.is_decoder=True
    #decoder_model = BertForMaskedLM(config)
    decoder_model = BertForMaskedLM.from_pretrained(r'/data/zhuoyu/semantic_parsing/models')
    model = Model2Model.from_pretrained(
        args.model_name_or_path, decoder_model=decoder_model
    )
    #model = Model2Model.from_pretrained(
    #    args.model_name_or_path, decoder_model=None
    #)

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        0,
        args.device,
        args.n_gpu,
        False,
        False,
    )

    logger.info("Training/evaluation parameters %s", args)

    # Train the model
    model.to(args.device)
    if args.do_train:
        global_step, tr_loss = train(args, model, tokenizer)
        logger.info(" global_step = %s, average loss = %s", global_step, tr_loss)

        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)

        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        model_to_save = (
            model.module if hasattr(model, "module") else model
        )  # Take care of distributed/parallel training
        model_to_save.save_pretrained(args.output_dir)
        tokenizer.save_pretrained(args.output_dir)
        torch.save(args, os.path.join(args.output_dir, "training_arguments.bin"))

    # Evaluate the model
    results = {}
    if args.do_evaluate:
        checkpoints = [args.trained_checkpoints]
        logger.info("Evaluate the following checkpoints: %s", checkpoints)
        for checkpoint in checkpoints:
            encoder_checkpoint = os.path.join(checkpoint, "encoder")
            decoder_checkpoint = os.path.join(checkpoint, "decoder")
            #model = PreTrainedEncoderDecoder.from_pretrained(
            #    encoder_checkpoint, decoder_checkpoint
            #)
            #model = Model2Model.from_pretrained(encoder_checkpoint)
            #model.to(args.device)
            results = "placeholder"

            evaluate(args, model, tokenizer, "test")

    return results
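The evaluation branch above leaves the checkpoint reload commented out. A minimal sketch of what that reload could look like, assuming the encoder/decoder directory layout implied by the commented-out lines and that PreTrainedEncoderDecoder is imported from transformers:

# Hypothetical reload sketch; assumes save_pretrained() wrote encoder/ and decoder/ subfolders.
encoder_checkpoint = os.path.join(args.trained_checkpoints, "encoder")
decoder_checkpoint = os.path.join(args.trained_checkpoints, "decoder")
model = PreTrainedEncoderDecoder.from_pretrained(encoder_checkpoint, decoder_checkpoint)
model.to(args.device)
evaluate(args, model, tokenizer, "test")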
Example #5
# Encode the input to the decoder (the answer)
answer = "Jim Henson was a puppeteer"
encoded_answer = tokenizer.encode(answer)

# Convert inputs to PyTorch tensors
question_tensor = torch.tensor([encoded_question])
answer_tensor = torch.tensor([encoded_answer])

# In order to compute the loss we need to provide language model
# labels (the token ids that the model should have produced) to
# the decoder.
lm_labels = encoded_answer
labels_tensor = torch.tensor([lm_labels])

# Load pre-trained model (weights)
model = Model2Model.from_pretrained('bert-base-uncased')

# Set the model in evaluation mode to deactivate the DropOut modules
# This is IMPORTANT to have reproducible results during evaluation!
model.eval()

# If you have a GPU, put everything on cuda
question_tensor = question_tensor.to('cuda')
answer_tensor = answer_tensor.to('cuda')
labels_tensor = labels_tensor.to('cuda')
model.to('cuda')

# Predict hidden states features for each layer
with torch.no_grad():
    # See the models docstrings for the detail of the inputs
    outputs = model(question_tensor,
                    answer_tensor,
                    decoder_lm_labels=labels_tensor)
    # The first element of the output tuple is the LM loss
    lm_loss = outputs[0]
Example #6
def create_model():
    config = BertConfig.from_pretrained(BERT_PATH)
    decoder_model = BertForMaskedLM(config)
    QA_model = Model2Model.from_pretrained(BERT_PATH, decoder_model=decoder_model)

    return QA_model
Example #7
def run():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument(
        "--model", type=str, default="multi-bert",
        help="Model type")  # anything besides gpt2 will load openai-gpt
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="",
                        help="Path, url or short name of the model")
    parser.add_argument(
        "--max_turns",
        type=int,
        default=3,
        help="Number of previous utterances to keep in history")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")

    parser.add_argument("--no_sample",
                        action='store_true',
                        help="Set to use greedy decoding instead of sampling")
    parser.add_argument("--max_length",
                        type=int,
                        default=20,
                        help="Maximum length of the output utterances")
    parser.add_argument("--min_length",
                        type=int,
                        default=1,
                        help="Minimum length of the output utterances")
    parser.add_argument("--seed", type=int, default=0, help="Seed")
    parser.add_argument("--temperature",
                        type=int,
                        default=0.7,
                        help="Sampling softmax temperature")
    parser.add_argument(
        "--top_k",
        type=int,
        default=0,
        help="Filter top-k tokens before sampling (<=0: no filtering)")
    parser.add_argument(
        "--top_p",
        type=float,
        default=0.9,
        help="Nucleus filtering (top-p) before sampling (<=0.0: no filtering)")
    parser.add_argument("--test_lang",
                        type=str,
                        default="",
                        help="test monolingual model")
    parser.add_argument("--no_lang_id", action='store_true')
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__file__)
    logger.info(pformat(args))

    if args.seed != 0:
        random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

    logger.info("Get pretrained model and tokenizer")
    tokenizer = BertTokenizer.from_pretrained(args.model_checkpoint)
    if args.test_lang == "Jp":
        tokenizer = BertJapaneseTokenizer.from_pretrained(
            args.model_checkpoint)
    bertconfig = BertConfig.from_pretrained(args.model_checkpoint)
    bertconfig.is_decoder = True
    model = Model2Model.from_pretrained(args.model_checkpoint,
                                        **{"decoder_config": bertconfig})

    with open(args.model_checkpoint + "/added_tokens.json",
              encoding="utf-8") as f:
        special_map = json.load(f)

    model.load_state_dict(
        torch.load(args.model_checkpoint + "/pytorch_model.bin"))
    model.to(args.device)
    model.eval()

    personachat = get_dataset(tokenizer, args.dataset_path, args.dataset_cache,
                              args.test_lang)
    persona_text = {}
    context_text = {}
    output_text = {}
    ref_text = {}
    ppl = {}
    BLEU_score = {}
    if args.test_lang in ["En", "Fr", "It", "Id", "Jp", "Ko", "Zh"]:
        logdir = args.test_lang + "_bert2bert/"
    else:
        logdir = "multilingual_bert2bert/"
    for lang, dials in personachat["test"].items():
        if args.test_lang in ["En", "Fr", "It", "Id", "Jp", "Ko", "Zh"]:
            if lang != args.test_lang:
                continue
        persona_text[lang] = []
        context_text[lang] = []
        output_text[lang] = []
        ref_text[lang] = []
        loss_list = []
        for dial in dials:  #dial: {"persona":[], "history":[], "response":str}
            with torch.no_grad():
                out_ids, loss = sample_sequence(
                    dial["persona"],
                    dial["history"][-args.max_turns:],
                    tokenizer,
                    model,
                    args,
                    "<{}>".format(lang.lower()),
                    special_map,
                    ref=dial["response"])
                output_text[lang].append(
                    tokenizer.decode(out_ids, skip_special_tokens=True))
                # print(tokenizer.decode(dial["history"][-1]))
                # print(output_text[lang][-1])
                # print("-")
                ref_text[lang].append(
                    tokenizer.decode(dial["response"],
                                     skip_special_tokens=True))
                context_text[lang].append([
                    tokenizer.decode(turn, skip_special_tokens=True)
                    for turn in dial["history"]
                ])
                persona_text[lang].append([
                    tokenizer.decode(sent, skip_special_tokens=True)
                    for sent in dial["persona"]
                ])
                loss_list.append(loss)
        ppl[lang] = math.exp(np.mean(loss_list))
        print("{} PPL:{}".format(lang, ppl[lang]))
        if not os.path.exists("results/" + logdir):
            os.makedirs("results/" + logdir)
        with open("results/" + logdir + lang + "_output.txt",
                  'w',
                  encoding='utf-8') as f:
            for line in output_text[lang]:
                f.write(line)
                f.write('\n')

        with open("results/" + logdir + lang + "_ref.txt",
                  'w',
                  encoding='utf-8') as f:
            for line in ref_text[lang]:
                f.write(line)
                f.write('\n')
        print("Computing BLEU for {}".format(lang))
        BLEU = moses_multi_bleu(np.array(output_text[lang]),
                                np.array(ref_text[lang]))
        print("BLEU:{}".format(BLEU))
        BLEU_score[lang] = BLEU

        with open("results/" + logdir + lang + "_all.txt",
                  'w',
                  encoding='utf-8') as f:
            for i in range(len(ref_text[lang])):
                f.write(
                    "=====================================================\n")
                f.write("Persona:\n")
                for sent in persona_text[lang][i]:
                    f.write("".join(sent.split()) if lang in
                            ["jp", "zh"] else sent)
                    f.write('\n')
                f.write("History:\n")
                for sent in context_text[lang][i]:
                    f.write("".join(sent.split()) if lang in
                            ["jp", "zh"] else sent)
                    f.write('\n')
                f.write("Response: ")
                f.write("".join(output_text[lang][i].split()) if lang in
                        ["jp", "zh"] else output_text[lang][i])
                f.write('\n')
                f.write("Ref: ")
                f.write("".join(ref_text[lang][i].split()) if lang in
                        ["jp", "zh"] else ref_text[lang][i])
                f.write('\n')

    with open("results/" + logdir + "BLEU_score.txt", "w",
              encoding='utf-8') as f:
        for language, score in BLEU_score.items():
            f.write("{}\t PPL:{}, BLEU:{}\n".format(language, ppl[language],
                                                    score))
Example #8
# Convert inputs to PyTorch tensors
question_tensor = torch.tensor([encoded_question])
answer_tensor = torch.tensor([encoded_answer])
'''
Use Model2Model to get the value of the loss associated with this (question, answer) pair
'''

# In order to compute the loss we need to provide language model
# labels (the token ids that the model should have produced) to
# the decoder.
lm_labels = encoded_answer
labels_tensor = torch.tensor([lm_labels])

# Load pre-trained model (weights)
model = Model2Model.from_pretrained('bert-base-uncased')

# Set the model in evaluation mode to deactivate the DropOut modules
# This is IMPORTANT to have reproducible results during evaluation!
model.eval()

# Predict hidden states features for each layer
with torch.no_grad():
    # See the models docstrings for the detail of the inputs
    outputs = model(question_tensor,
                    answer_tensor,
                    decoder_lm_labels=labels_tensor)
    # Transformers models always output tuples.
    # See the models docstrings for the detail of all the outputs
    # In our case, the first element is the value of the LM loss
    lm_loss = outputs[0]
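Building on the loss computed above, a single fine-tuning step could look like the sketch below; the AdamW optimizer and its learning rate are assumptions, not part of the original snippet.

from torch.optim import AdamW

# Hypothetical training step; optimizer choice and learning rate are assumptions.
optimizer = AdamW(model.parameters(), lr=5e-5)
model.train()
outputs = model(question_tensor,
                answer_tensor,
                decoder_lm_labels=labels_tensor)
loss = outputs[0]  # first element of the tuple is the LM loss, as noted above
loss.backward()
optimizer.step()
optimizer.zero_grad()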
Example #9
def train():
    parser = ArgumentParser()
    parser.add_argument(
        "--dataset_path",
        type=str,
        default="",
        help="Path or url of the dataset. If empty download from S3.")
    parser.add_argument("--dataset_cache",
                        type=str,
                        default='./dataset_cache',
                        help="Path or url of the dataset cache")
    parser.add_argument("--model_checkpoint",
                        type=str,
                        default="multi-bert",
                        help="Path, url or short name of the model")
    parser.add_argument("--num_candidates",
                        type=int,
                        default=2,
                        help="Number of candidates for training")
    parser.add_argument("--max_turns",
                        type=int,
                        default=3,
                        help="Number of previous turns to keep in history")
    parser.add_argument("--train_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for training")
    parser.add_argument("--valid_batch_size",
                        type=int,
                        default=4,
                        help="Batch size for validation")
    parser.add_argument("--gradient_accumulation_steps",
                        type=int,
                        default=8,
                        help="Accumulate gradients on several steps")
    parser.add_argument("--lr",
                        type=float,
                        default=6.25e-5,
                        help="Learning rate")
    parser.add_argument("--lm_coef",
                        type=float,
                        default=1.0,
                        help="LM loss coefficient")
    parser.add_argument("--max_norm",
                        type=float,
                        default=1.0,
                        help="Clipping gradient norm")
    parser.add_argument("--n_epochs",
                        type=int,
                        default=3,
                        help="Number of training epochs")
    parser.add_argument("--personality_permutations",
                        type=int,
                        default=1,
                        help="Number of permutations of personality sentences")
    parser.add_argument(
        "--eval_before_start",
        action='store_true',
        help="If true start with a first evaluation before training")
    parser.add_argument("--device",
                        type=str,
                        default="cuda" if torch.cuda.is_available() else "cpu",
                        help="Device (cuda or cpu)")
    parser.add_argument(
        "--fp16",
        type=str,
        default="",
        help=
        "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=-1,
        help="Local rank for distributed training (-1: not distributed)")
    parser.add_argument("--random_init",
                        action='store_true',
                        help="If true random initailze the model")
    parser.add_argument(
        "--train_lang",
        type=str,
        default="",
        help="train monolingual model, defaul: multilingual model")
    args = parser.parse_args()

    # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", args.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    # Model
    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    model_path = 'bert-base-multilingual-cased'
    if args.train_lang in ["En", "It", "Jp",
                           "Zh"]:  # for Fr Ko Id we use MBERT
        model_path = LANG_2_MODEL[args.train_lang]

    tokenizer = BertTokenizer.from_pretrained(model_path)
    if args.train_lang == "Jp":
        tokenizer = BertJapaneseTokenizer.from_pretrained(model_path)
    model = Model2Model.from_pretrained(model_path)

    # tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    # if args.random_init:
    #     config = BertConfig.from_pretrained('bert-base-multilingual-cased')
    #     config.is_decoder = True
    #     bert_decoder = BertForMaskedLM(config)
    #     model = Model2Model.from_pretrained('bert-base-multilingual-cased', decoder_model=bert_decoder)
    # else:
    #     model = Model2Model.from_pretrained('bert-base-multilingual-cased')
    #     model_dict = model.state_dict()
    #     # initialize crossattention with selfattention
    #     model_dict.update({ name: model_dict[name.replace("crossattention", "attention")] for name in model_dict if "crossattention" in name })
    #     model.load_state_dict(model_dict)
    model.to(args.device)

    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[args.local_rank],
                                        output_device=args.local_rank,
                                        find_unused_parameters=True)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(batch[input_name].to(args.device)
                      for input_name in MODEL_INPUTS)

        #batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        encoder_mask, decoder_mask, encoder_input_ids, decoder_input_ids, lm_labels, token_type_ids, decoder_lang_id = batch
        model_kwargs = {
            "encoder_token_type_ids": token_type_ids,
            "decoder_token_type_ids": decoder_lang_id,
            "encoder_attention_mask": encoder_mask,
            "decoder_attention_mask": decoder_mask,
            "decoder_lm_labels": lm_labels
        }
        lm_loss, prediction_scores, *_ = model(
            encoder_input_ids=encoder_input_ids,
            decoder_input_ids=decoder_input_ids,
            **model_kwargs)

        loss = (lm_loss) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(batch[input_name].to(args.device)
                          for input_name in MODEL_INPUTS)
            #batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            encoder_mask, decoder_mask, encoder_input_ids, decoder_input_ids, lm_labels, token_type_ids, decoder_lang_id = batch
            logger.info(tokenizer.decode(encoder_input_ids[0, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            model_kwargs = {
                "encoder_token_type_ids": token_type_ids,
                "decoder_token_type_ids": decoder_lang_id,
                "encoder_attention_mask": encoder_mask,
                "decoder_attention_mask": decoder_mask
            }

            lm_logits, *_ = model(encoder_input_ids=encoder_input_ids,
                                  decoder_input_ids=decoder_input_ids,
                                  **model_kwargs)
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, ), (lm_labels_flat_shifted, )

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if args.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if args.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the distributed processes
    if args.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, args.lr),
                                 (args.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll":
        Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
             output_transform=lambda x: (x[0][0], x[1][0]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], args)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
    if args.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        log_dir = make_logdir(args.model_checkpoint)
        log_dir += "_lang_id"
        if args.random_init:
            log_dir = log_dir + "_random_init"
        tb_logger = TensorboardLogger(log_dir)

        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of distributed encapsulation

        torch.save(args, log_dir + '/model_training_args.bin')
        if args.distributed:
            getattr(model.module, 'encoder', model).config.to_json_file(
                os.path.join(log_dir, CONFIG_NAME)
            )  # the config for encoder and decoder should be the same
        else:
            getattr(model, 'encoder', model).config.to_json_file(
                os.path.join(log_dir, CONFIG_NAME)
            )  # the config for encoder and decoder should be the same
        tokenizer.save_pretrained(log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=args.n_epochs)

    # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method)
    if args.local_rank in [-1, 0] and args.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

'''
Assuming that we fine-tuned the model, let us now see how to generate an answer
'''

# Let's re-use the previous question
question = "Who was Jim Henson?"
encoded_question = tokenizer.encode(question)
question_tensor = torch.tensor([encoded_question])

# This time we try to generate the answer, so we start with an empty sequence
answer = "[CLS]"
encoded_answer = tokenizer.encode(answer, add_special_tokens=False)
answer_tensor = torch.tensor([encoded_answer])

# Load pre-trained model (weights)
model = Model2Model.from_pretrained('fine-tuned-weights')
model.eval()

# Predict all tokens
with torch.no_grad():
    outputs = model(question_tensor, answer_tensor)
    predictions = outputs[0]

# confirm we were able to predict 'jim'
predicted_index = torch.argmax(predictions[0, -1]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
assert predicted_token == 'jim'
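The assert above checks a single greedy step. Repeating that step and appending each predicted id to the decoder input gives a simple greedy decoding loop; the 20-token cap and the [SEP] stop condition are assumptions for illustration.

# Hypothetical greedy decoding loop built on the single-step prediction above.
generated = list(encoded_answer)   # start from the encoded "[CLS]" token
for _ in range(20):                # assumed maximum answer length
    answer_tensor = torch.tensor([generated])
    with torch.no_grad():
        predictions = model(question_tensor, answer_tensor)[0]
    next_id = torch.argmax(predictions[0, -1]).item()
    if next_id == tokenizer.sep_token_id:  # assumed stop condition
        break
    generated.append(next_id)

print(tokenizer.decode(generated, skip_special_tokens=True))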