def create_and_check_xlnet_base_model(self, config, input_ids_1, input_ids_2, input_ids_q, perm_mask, input_mask, target_mapping, segment_ids, lm_labels, sequence_labels, is_impossible_labels):
    """Check TFXLNetModel hidden-state and memory shapes, with and without mems.

    Only ``config``, ``input_ids_1``, ``input_mask`` and ``segment_ids`` are
    read here; the remaining arguments are unused by this check (presumably
    accepted so all create_and_check_* helpers share one signature).

    Note: sets ``config.mem_len = 0`` on the passed-in config object.
    ``self.parent`` is presumably the owning unittest.TestCase.
    """
    model = TFXLNetModel(config)

    # Dict-style call; unpacking into two names requires exactly two outputs.
    feed_dict = {'input_ids': input_ids_1, 'input_mask': input_mask, 'token_type_ids': segment_ids}
    _, _ = model(feed_dict)

    # List-style call; keep the outputs and memories for the shape checks.
    positional_inputs = [input_ids_1, input_mask]
    hidden_states, mems = model(positional_inputs)
    result = {
        "mems_1": [m.numpy() for m in mems],
        "outputs": hidden_states.numpy(),
    }

    # With memory disabled the model is expected to return a single output.
    config.mem_len = 0
    model = TFXLNetModel(config)
    no_mems_outputs = model(positional_inputs)
    self.parent.assertEqual(len(no_mems_outputs), 1)

    expected_output_shape = [self.batch_size, self.seq_length, self.hidden_size]
    self.parent.assertListEqual(list(result["outputs"].shape), expected_output_shape)

    mem_shapes = [list(m.shape) for m in result["mems_1"]]
    expected_mem_shapes = [
        [self.seq_length, self.batch_size, self.hidden_size]
        for _ in range(self.num_hidden_layers)
    ]
    self.parent.assertListEqual(mem_shapes, expected_mem_shapes)
def test_model_from_pretrained(self):
    """Smoke-test loading the first TF XLNet checkpoint into a throwaway cache.

    The checkpoint is loaded with ``cache_dir`` pointing at a temporary
    location. That directory is removed after each load attempt, including
    when ``from_pretrained`` raises, so a failing run does not leave files
    behind in it.
    """
    cache_dir = "/tmp/transformers_test/"
    for model_name in list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]:
        try:
            model = TFXLNetModel.from_pretrained(model_name, cache_dir=cache_dir)
        finally:
            # Best-effort cleanup. ignore_errors keeps a missing directory
            # (e.g. loading failed before anything was written) from raising
            # here and masking the original exception.
            shutil.rmtree(cache_dir, ignore_errors=True)
        self.assertIsNotNone(model)
def create_and_check_xlnet_base_model(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
):
    """Check TFXLNetModel ``last_hidden_state`` and ``mems`` shapes.

    Only ``config``, ``input_ids_1``, ``input_mask`` and ``segment_ids`` are
    read; the other arguments are unused by this check (presumably kept so all
    create_and_check_* helpers share one signature).

    Note: sets ``config.mem_len = 0`` on the passed-in config object.
    """
    model = TFXLNetModel(config)

    # Exercise the dict-input path; its return value is not inspected.
    model({"input_ids": input_ids_1, "input_mask": input_mask, "token_type_ids": segment_ids})

    list_inputs = [input_ids_1, input_mask]
    result = model(list_inputs)

    expected_hidden_shape = (self.batch_size, self.seq_length, self.hidden_size)
    expected_mem_shapes = [(self.seq_length, self.batch_size, self.hidden_size)] * self.num_hidden_layers

    # With memory disabled the model is expected to return a single output.
    config.mem_len = 0
    no_mem_model = TFXLNetModel(config)
    no_mems_outputs = no_mem_model(list_inputs)
    self.parent.assertEqual(len(no_mems_outputs), 1)

    self.parent.assertEqual(result.last_hidden_state.shape, expected_hidden_shape)
    self.parent.assertListEqual(
        [m.shape for m in result.mems],
        expected_mem_shapes,
    )
def test_model_from_pretrained(self):
    """Smoke-test loading the first entry of TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST."""
    checkpoints = TF_XLNET_PRETRAINED_MODEL_ARCHIVE_LIST[:1]
    for checkpoint in checkpoints:
        loaded = TFXLNetModel.from_pretrained(checkpoint)
        self.assertIsNotNone(loaded)
def test_model_from_pretrained(self):
    """Smoke-test loading the first key of TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.

    The checkpoint is cached under the module-level ``CACHE_DIR``.
    """
    first_names = list(TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP.keys())[:1]
    for name in first_names:
        pretrained = TFXLNetModel.from_pretrained(name, cache_dir=CACHE_DIR)
        self.assertIsNotNone(pretrained)