def create_and_check_xlnet_base_model(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
):
    """Run ``TFXLNetModel`` on dict and list inputs and check output shapes.

    Only ``config``, ``input_ids_1``, ``input_mask`` and ``segment_ids`` are
    used here; the remaining parameters are accepted but not read by this
    check.

    Assertions are issued through ``self.parent`` (presumably a
    ``unittest.TestCase`` -- confirm against the tester class).
    """
    xlnet = TFXLNetModel(config)

    # Dict calling convention: the call must yield exactly two values,
    # which are otherwise not inspected.
    _, _ = xlnet({
        'input_ids': input_ids_1,
        'input_mask': input_mask,
        'token_type_ids': segment_ids,
    })

    # List calling convention: these are the outputs that get shape-checked.
    inputs = [input_ids_1, input_mask]
    outputs, mems_1 = xlnet(inputs)
    mem_arrays = [mem.numpy() for mem in mems_1]
    hidden_array = outputs.numpy()

    # NOTE: this mutates the config object passed in by the caller.
    # With mem_len set to 0 the model is expected to return one output only.
    config.mem_len = 0
    memoryless_xlnet = TFXLNetModel(config)
    no_mems_outputs = memoryless_xlnet(inputs)
    self.parent.assertEqual(len(no_mems_outputs), 1)

    self.parent.assertListEqual(
        list(hidden_array.shape),
        [self.batch_size, self.seq_length, self.hidden_size],
    )
    # One memory array per hidden layer, each laid out as (seq, batch, hidden).
    self.parent.assertListEqual(
        [list(mem.shape) for mem in mem_arrays],
        [[self.seq_length, self.batch_size, self.hidden_size]]
        * self.num_hidden_layers,
    )
 def test_model_from_pretrained(self):
     """Load the first archived XLNet checkpoint and check a model is returned.

     ``cache_dir`` is passed to ``from_pretrained`` and removed afterwards,
     whether or not loading and the assertion succeeded.
     """
     cache_dir = "/tmp/transformers_test/"
     for model_name in list(
             TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
         try:
             model = TFXLNetModel.from_pretrained(model_name,
                                                  cache_dir=cache_dir)
             self.assertIsNotNone(model)
         finally:
             # Clean up even when loading fails; ignore_errors keeps a
             # missing or partially written cache directory from masking
             # the original exception.
             shutil.rmtree(cache_dir, ignore_errors=True)
    def create_and_check_xlnet_base_model(
        self,
        config,
        input_ids_1,
        input_ids_2,
        input_ids_q,
        perm_mask,
        input_mask,
        target_mapping,
        segment_ids,
        lm_labels,
        sequence_labels,
        is_impossible_labels,
    ):
        """Run ``TFXLNetModel`` on dict and list inputs and check output shapes.

        Only ``config``, ``input_ids_1``, ``input_mask`` and ``segment_ids``
        are used here; the remaining parameters are accepted but not read by
        this check.

        Assertions are issued through ``self.parent`` (presumably a
        ``unittest.TestCase`` -- confirm against the tester class).
        """
        xlnet = TFXLNetModel(config)

        # Exercise the dict calling convention; its output is not inspected.
        xlnet({
            "input_ids": input_ids_1,
            "input_mask": input_mask,
            "token_type_ids": segment_ids,
        })

        # The list calling convention provides the outputs checked below.
        positional_inputs = [input_ids_1, input_mask]
        result = xlnet(positional_inputs)

        # NOTE: this mutates the config object passed in by the caller.
        # With mem_len set to 0 the model is expected to return one output.
        config.mem_len = 0
        memoryless_xlnet = TFXLNetModel(config)
        no_mems_outputs = memoryless_xlnet(positional_inputs)
        self.parent.assertEqual(len(no_mems_outputs), 1)

        hidden_shape = result.last_hidden_state.shape
        self.parent.assertEqual(
            hidden_shape,
            (self.batch_size, self.seq_length, self.hidden_size),
        )

        # One memory entry per hidden layer, each laid out as
        # (seq, batch, hidden).
        mem_shapes = [mem.shape for mem in result.mems]
        per_layer_shape = (self.seq_length, self.batch_size, self.hidden_size)
        self.parent.assertListEqual(
            mem_shapes,
            [per_layer_shape] * self.num_hidden_layers,
        )
# Example #4
# 0
 def test_model_from_pretrained(self):
     """Load the first listed XLNet checkpoint and check a model is returned."""
     # Only the first archive entry is exercised.
     first_checkpoints = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
     for checkpoint in first_checkpoints:
         loaded = TFXLNetModel.from_pretrained(checkpoint)
         self.assertIsNotNone(loaded)
 def test_model_from_pretrained(self):
     """Load the first archived XLNet checkpoint and check a model is returned.

     ``CACHE_DIR`` is forwarded to ``from_pretrained`` as ``cache_dir``.
     """
     archive_names = list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())
     # Only the first archive entry is exercised.
     for checkpoint in archive_names[:1]:
         loaded = TFXLNetModel.from_pretrained(
             checkpoint,
             cache_dir=CACHE_DIR,
         )
         self.assertIsNotNone(loaded)