Example #1
0
    def create_transfo_xl_lm_head_trainer_incompatible_tuple(
            self, config, input_ids_1, input_ids_2, lm_labels):
        """Run the LM head model with ``trainer_compatible=False`` and collect
        the tuple-format outputs of two chained forward passes.

        The second segment reuses the first segment's memories (``mems``).
        Returns a dict holding per-token losses, memories, logits and the
        trailing (reduced) loss for both passes.
        """
        config.trainer_compatible = False
        model = TransfoXLLMHeadModel(config)
        model.to(torch_device)
        model.eval()

        lm_logits_1 = model(input_ids_1, return_dict=False)[0]
        outputs1 = model(input_ids_1, labels=lm_labels, return_dict=False)
        losses_1, _, mems_1 = outputs1[:3]
        loss_1 = outputs1[-1]
        lm_logits_2 = model(input_ids_2, mems=mems_1, return_dict=False)[0]
        # Fix: request tuple output here as well, matching the first labelled
        # pass above — this helper exists to exercise the tuple
        # (trainer-incompatible) return format.
        outputs2 = model(
            input_ids_2, labels=lm_labels, mems=mems_1, return_dict=False)
        losses_2, _, mems_2 = outputs2[:3]
        loss_2 = outputs2[-1]

        outputs = {
            "losses_1": losses_1,
            "mems_1": mems_1,
            "lm_logits_1": lm_logits_1,
            "loss_1": loss_1,
            "losses_2": losses_2,
            "mems_2": mems_2,
            "lm_logits_2": lm_logits_2,
            "loss_2": loss_2,
        }

        # Reset so later helpers see a neutral value rather than False.
        config.trainer_compatible = None
        return outputs
Example #2
0
def convert_transfo_xl_checkpoint_to_pytorch(tf_checkpoint_path,
                                             transfo_xl_config_file,
                                             pytorch_dump_folder_path,
                                             transfo_xl_dataset_file):
    """Convert a Transformer-XL TensorFlow checkpoint and/or dataset to PyTorch.

    Args:
        tf_checkpoint_path: Path to a pre-trained TF checkpoint; if truthy,
            the model weights are converted and saved.
        transfo_xl_config_file: JSON config for the model; an empty string
            falls back to the default ``TransfoXLConfig()``.
        pytorch_dump_folder_path: Output folder for the converted files.
        transfo_xl_dataset_file: Pickled pre-processed corpus (from the
            original TF repo); if truthy, its vocab and dataset are saved.
    """
    if transfo_xl_dataset_file:
        # Convert a pre-processed corpus (see original TensorFlow repo).
        # NOTE(review): pickle.load on an arbitrary file is unsafe — only run
        # this on corpora you trust.
        with open(transfo_xl_dataset_file, "rb") as fp:
            corpus = pickle.load(fp, encoding="latin1")
        # Save vocabulary and dataset cache as dictionaries (should be better
        # than pickles for the long term). Use os.path.join for consistency
        # with the checkpoint branch below and for portability.
        pytorch_vocab_dump_path = os.path.join(
            pytorch_dump_folder_path, VOCAB_FILES_NAMES["pretrained_vocab_file"])
        print(f"Save vocabulary to {pytorch_vocab_dump_path}")
        corpus_vocab_dict = corpus.vocab.__dict__
        torch.save(corpus_vocab_dict, pytorch_vocab_dump_path)

        # Drop the vocab from the corpus dict so it is not stored twice.
        corpus_dict_no_vocab = corpus.__dict__
        corpus_dict_no_vocab.pop("vocab", None)
        pytorch_dataset_dump_path = os.path.join(
            pytorch_dump_folder_path, CORPUS_NAME)
        print(f"Save dataset to {pytorch_dataset_dump_path}")
        torch.save(corpus_dict_no_vocab, pytorch_dataset_dump_path)

    if tf_checkpoint_path:
        # Convert a pre-trained TensorFlow model.
        config_path = os.path.abspath(transfo_xl_config_file)
        tf_path = os.path.abspath(tf_checkpoint_path)

        print(
            f"Converting Transformer XL checkpoint from {tf_path} with config at {config_path}."
        )
        # Initialise PyTorch model from the given config (or defaults).
        if transfo_xl_config_file == "":
            config = TransfoXLConfig()
        else:
            config = TransfoXLConfig.from_json_file(transfo_xl_config_file)
        print(f"Building PyTorch model from configuration: {config}")
        model = TransfoXLLMHeadModel(config)

        model = load_tf_weights_in_transfo_xl(model, config, tf_path)
        # Save pytorch model weights and config side by side.
        pytorch_weights_dump_path = os.path.join(pytorch_dump_folder_path,
                                                 WEIGHTS_NAME)
        pytorch_config_dump_path = os.path.join(pytorch_dump_folder_path,
                                                CONFIG_NAME)
        print(
            f"Save PyTorch model to {os.path.abspath(pytorch_weights_dump_path)}"
        )
        torch.save(model.state_dict(), pytorch_weights_dump_path)
        print(
            f"Save configuration file to {os.path.abspath(pytorch_config_dump_path)}"
        )
        with open(pytorch_config_dump_path, "w", encoding="utf-8") as f:
            f.write(config.to_json_string())
Example #3
0
        def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
            """Run the LM head over two segments in tuple-return mode,
            feeding the first segment's memories into the second, and
            return losses, memories and logits for both passes."""
            model = TransfoXLLMHeadModel(config)
            model.eval()

            # Segment 1: a logits-only pass, then a labelled pass whose
            # memories carry over to segment 2.
            first_logits, first_mems = model(input_ids_1)
            first_loss, _, first_mems = model(input_ids_1, labels=lm_labels)
            # Segment 2: same pair of passes, primed with segment 1's mems.
            second_logits, second_mems = model(input_ids_2, mems=first_mems)
            second_loss, _, second_mems = model(
                input_ids_2, labels=lm_labels, mems=first_mems)

            return {
                "loss_1": first_loss,
                "mems_1": first_mems,
                "lm_logits_1": first_logits,
                "loss_2": second_loss,
                "mems_2": second_mems,
                "lm_logits_2": second_logits,
            }
    def create_transfo_xl_lm_head(self, config, input_ids_1, input_ids_2, lm_labels):
        """Run two chained forward passes in dict-return mode and gather
        losses, memories and prediction scores for both segments."""
        model = TransfoXLLMHeadModel(config)
        model.to(torch_device)
        model.eval()

        # Segment 1: one pass for the scores, one labelled pass whose
        # memories are carried into segment 2.
        scores_1 = model(input_ids_1)["prediction_scores"]
        first = model(input_ids_1, labels=lm_labels)
        carried_mems = first["mems"]

        # Segment 2: same pattern, primed with segment 1's memories.
        scores_2 = model(input_ids_2, mems=carried_mems)["prediction_scores"]
        second = model(input_ids_2, labels=lm_labels, mems=carried_mems)

        return {
            "loss_1": first["losses"],
            "mems_1": carried_mems,
            "lm_logits_1": scores_1,
            "loss_2": second["losses"],
            "mems_2": second["mems"],
            "lm_logits_2": scores_2,
        }