    def test_model2model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = Model2Model.from_pretrained(model_name)
            self.assertIsInstance(model.encoder, BertModel)
            self.assertIsInstance(model.decoder, BertForMaskedLM)
            self.assertTrue(model.decoder.config.is_decoder)
            self.assertFalse(model.encoder.config.is_decoder)
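
The test methods in these examples are excerpts from a unittest.TestCase in the transformers test suite, so they reference self and omit their imports. A minimal harness along these lines would let them run; the class name is hypothetical, and the import paths assume a transformers 2.x release, where Model2Model and BERT_PRETRAINED_MODEL_ARCHIVE_MAP still existed:

import logging
import unittest

# All of these names were importable from the transformers top level in 2.x;
# the exact version is an assumption, not stated in the examples.
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForPreTraining,
    AutoModelForQuestionAnswering,
    BertConfig,
    BertForMaskedLM,
    BertForPreTraining,
    BertForQuestionAnswering,
    BertModel,
    Model2Model,
    PretrainedConfig,
    PreTrainedModel,
)
from transformers.modeling_bert import BERT_PRETRAINED_MODEL_ARCHIVE_MAP


class PretrainedLoadingTest(unittest.TestCase):  # hypothetical class name
    pass  # paste any of the test methods from these examples here


if __name__ == "__main__":
    unittest.main()
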
Example #2
    def test_question_answering_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = AutoModelForQuestionAnswering.from_pretrained(model_name)
            model, loading_info = AutoModelForQuestionAnswering.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertForQuestionAnswering)
Example #3
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = AutoModel.from_pretrained(model_name)
            model, loading_info = AutoModel.from_pretrained(model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)
Example #4
    def test_model_for_pretraining_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = AutoConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, BertConfig)

            model = AutoModelForPreTraining.from_pretrained(model_name)
            model, loading_info = AutoModelForPreTraining.from_pretrained(
                model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, BertForPreTraining)
            for key, value in loading_info.items():
                # Exactly one weight should be newly initialized and therefore
                # listed under "missing_keys"; every other list should be empty.
                self.assertEqual(len(value), 1 if key == "missing_keys" else 0)
Example #5
    def test_model_from_pretrained(self):
        logging.basicConfig(level=logging.INFO)
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            config = BertConfig.from_pretrained(model_name)
            self.assertIsNotNone(config)
            self.assertIsInstance(config, PretrainedConfig)

            model = BertModel.from_pretrained(model_name)
            model, loading_info = BertModel.from_pretrained(
                model_name, output_loading_info=True)
            self.assertIsNotNone(model)
            self.assertIsInstance(model, PreTrainedModel)
            for value in loading_info.values():
                self.assertEqual(len(value), 0)

            config = BertConfig.from_pretrained(model_name,
                                                output_attentions=True,
                                                output_hidden_states=True)
            model = BertModel.from_pretrained(model_name,
                                              output_attentions=True,
                                              output_hidden_states=True)
            self.assertTrue(model.config.output_attentions)
            self.assertTrue(model.config.output_hidden_states)
            self.assertEqual(model.config, config)
Example #6
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/transformers_test/"
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = BertModel.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)
Example #7
    def test_model_from_pretrained(self):
        for model_name in list(BERT_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
            model = BertModel.from_pretrained(model_name, cache_dir=CACHE_DIR)
            self.assertIsNotNone(model)
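
CACHE_DIR is a module-level constant defined elsewhere in the test file; the excerpt does not show it. One plausible definition, assumed rather than taken from the source, is a scratch directory under the system temp folder:

import os
import tempfile

# Assumed definition; the real test file supplies its own value.
CACHE_DIR = os.path.join(tempfile.gettempdir(), "transformers_test")
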
Example #8
    vocab = Vocab(
        list_of_tokens=idx_to_token,
        unknown_token="[UNK]",
        padding_token="[PAD]",
        bos_token=None,
        eos_token=None,
        reserved_tokens=["[CLS]", "[SEP]", "[MASK]"],
        token_to_idx=token_to_idx
    )
    vocab_filename = "{}-vocab.pkl".format(args.type)
    vocab_filepath = ptr_dir / vocab_filename

    if not vocab_filepath.exists():
        with open(vocab_filepath, mode="wb") as io:
            pickle.dump(vocab, io)
    else:
        print("Already you have {}".format(vocab_filename))

    print("Saving the vocab of {} is done".format(args.type))

    # Save the pretrained model weights
    weight_url = BERT_PRETRAINED_MODEL_ARCHIVE_MAP.get(args.type)
    weight_filename = weight_url.split("/")[-1]
    weight_filepath = ptr_dir / weight_filename

    if not weight_filepath.exists():
        urlretrieve(weight_url, weight_filepath)
    else:
        print("Already you have {}".format(weight_filename))

    print("Saving weights of {} is done".format(args.type))