def create_and_check_xlnet_base_model_with_att_output(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check the structure of the attentions returned with ``target_mapping``.

    Builds an ``XLNetModel`` from ``config``, moves it to ``torch_device``,
    puts it in eval mode and runs it on ``input_ids_1`` with
    ``target_mapping`` and ``output_attentions=True``. The ``"attentions"``
    entry of the output must hold ``config.n_layer`` items, and the first
    item must be a tuple of length 2.

    Only ``config``, ``input_ids_1`` and ``target_mapping`` are used; the
    remaining parameters are accepted so every checker shares one signature.
    """
    xlnet = XLNetModel(config)
    xlnet.to(torch_device)
    xlnet.eval()

    model_output = xlnet(
        input_ids_1,
        target_mapping=target_mapping,
        output_attentions=True,
    )
    attentions = model_output["attentions"]

    check = self.parent
    check.assertEqual(len(attentions), config.n_layer)

    first_layer = attentions[0]
    check.assertIsInstance(first_layer, tuple)
    check.assertEqual(len(first_layer), 2)

    # NOTE(review): this assertion is effectively a no-op. ``assertTrue``
    # treats its second argument as the failure message, so the shape is only
    # tested for truthiness and never compared with anything. The intended
    # expected shape is not visible from here -- TODO confirm and replace with
    # a real ``assertEqual`` against the expected dimensions.
    first_stream_shape = first_layer[0].shape
    check.assertTrue(first_stream_shape, first_stream_shape)
def create_and_check_xlnet_base_model(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Run the base XLNet model and check the sizes of its tuple outputs.

    A first model built from ``config`` is called with ``input_mask``,
    ``attention_mask`` and ``token_type_ids`` in turn, then with the input ids
    alone; every call is unpacked into exactly two values. ``config.mem_len``
    is then set to 0 (this mutates the caller's ``config``) and a second model
    built from it must return an output of length 1.

    Finally the hidden states from the first model must have size
    ``[batch_size, seq_length, hidden_size]`` and each of its mems must have
    size ``[seq_length, batch_size, hidden_size]``, one per hidden layer.
    """
    xlnet = XLNetModel(config)
    xlnet.to(torch_device)
    xlnet.eval()

    # Two-value unpacking doubles as a check that each call returns a pair.
    optional_inputs = (
        {"input_mask": input_mask},
        {"attention_mask": input_mask},
        {"token_type_ids": segment_ids},
    )
    for extra_kwargs in optional_inputs:
        _, _ = xlnet(input_ids_1, **extra_kwargs)
    hidden_states, mems = xlnet(input_ids_1)

    # Rebuild the model from the same config object with mem_len forced to 0.
    config.mem_len = 0
    xlnet_without_mems = XLNetModel(config)
    xlnet_without_mems.to(torch_device)
    xlnet_without_mems.eval()
    output_without_mems = xlnet_without_mems(input_ids_1)

    check = self.parent
    check.assertEqual(len(output_without_mems), 1)

    check.assertListEqual(
        list(hidden_states.size()),
        [self.batch_size, self.seq_length, self.hidden_size],
    )

    single_mem_size = [self.seq_length, self.batch_size, self.hidden_size]
    check.assertListEqual(
        [list(mem.size()) for mem in mems],
        [single_mem_size] * self.num_hidden_layers,
    )
def get_model(model_dir, saved_dir, n_labels, task, model_type='xlnet-base-cased'):
    """Build an XLNet model from saved weights and load its tokeniser.

    Args:
        model_dir: Location passed to ``XLNetTokenizer.from_pretrained``.
        saved_dir: Directory containing ``pytorch_model.bin``, the state dict
            loaded into the model. May be a ``str`` or any path-like object.
        n_labels: Forwarded to the config as ``num_labels``.
        task: Forwarded to the config as ``finetuning_task``.
        model_type: Name or path given to ``XLNetConfig.from_pretrained``.

    Returns:
        A ``(model, tokeniser)`` tuple; the model is in evaluation mode.
    """
    import os

    config = XLNetConfig.from_pretrained(
        model_type,
        num_labels=n_labels,
        finetuning_task=task,
    )
    model = XLNetModel(config=config)

    # os.path.join instead of string concatenation: works for path-like
    # inputs and does not turn an empty directory into "/pytorch_model.bin".
    weights_path = os.path.join(saved_dir, 'pytorch_model.bin')

    # map_location='cpu' lets a checkpoint saved on a GPU load on a host
    # without one. The freshly built model is on the CPU, so the loaded
    # parameters end up in the same place as before.
    # NOTE: torch.load unpickles the file -- only load trusted checkpoints.
    state_dict = torch.load(weights_path, map_location='cpu')
    model.load_state_dict(state_dict)
    model.eval()  # switch dropout (and similar layers) to evaluation behaviour

    # NOTE(review): ``config`` and ``from_tf`` look like model-loading options
    # rather than tokenizer ones; kept as-is -- confirm they are needed here.
    tokeniser = XLNetTokenizer.from_pretrained(
        model_dir,
        config=config,
        from_tf=True,
    )
    return model, tokeniser
def create_and_check_xlnet_base_model(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Run the base XLNet model and check the sizes of its named outputs.

    A first model built from ``config`` is called with ``input_mask``,
    ``attention_mask`` and ``token_type_ids`` in turn, then with the input ids
    alone; only that last result is inspected. ``config.mem_len`` is then set
    to 0 (this mutates the caller's ``config``) and a second model built from
    it must return an output of length 2.

    ``result["last_hidden_state"]`` must have size
    ``[batch_size, seq_length, hidden_size]`` and every entry of
    ``result["mems"]`` must have size
    ``[seq_length, batch_size, hidden_size]``, one per hidden layer.
    """
    xlnet = XLNetModel(config)
    xlnet.to(torch_device)
    xlnet.eval()

    # These forward passes only need to run; their outputs are not inspected.
    xlnet(input_ids_1, input_mask=input_mask)
    xlnet(input_ids_1, attention_mask=input_mask)
    xlnet(input_ids_1, token_type_ids=segment_ids)
    result = xlnet(input_ids_1)

    # Rebuild the model from the same config object with mem_len forced to 0.
    config.mem_len = 0
    xlnet_mem_len_zero = XLNetModel(config)
    xlnet_mem_len_zero.to(torch_device)
    xlnet_mem_len_zero.eval()
    base_model_output = xlnet_mem_len_zero(input_ids_1)

    check = self.parent
    check.assertEqual(len(base_model_output), 2)

    hidden_size_list = list(result["last_hidden_state"].size())
    check.assertListEqual(
        hidden_size_list,
        [self.batch_size, self.seq_length, self.hidden_size],
    )

    mem_size_lists = [list(mem.size()) for mem in result["mems"]]
    check.assertListEqual(
        mem_size_lists,
        [[self.seq_length, self.batch_size, self.hidden_size]] * self.num_hidden_layers,
    )
def create_and_check_xlnet_model_use_cache(
    self,
    config,
    input_ids_1,
    input_ids_2,
    input_ids_q,
    perm_mask,
    input_mask,
    target_mapping,
    segment_ids,
    lm_labels,
    sequence_labels,
    is_impossible_labels,
    token_labels,
):
    """Check that feeding cached ``mems`` matches a full forward pass.

    First pass: the model is run on ``input_ids_1`` with ``use_cache=True``,
    with ``use_cache=False`` and with default settings. The cached output must
    have the same length as the default one and be one element longer than
    the uncached one.

    Second pass: one extra token per sequence is generated. The model is run
    once on the extended input and once on the new token alone together with
    the ``mems`` from the first pass. For a randomly chosen feature index, the
    last position of the full pass must match the single position of the
    cached pass within ``atol=1e-3``.
    """
    batch = input_ids_1.shape[0]
    prompt_len = input_ids_1.shape[1]

    def _upper_triangular_ones(length):
        # (batch, length, length) float tensor: ones on and above the
        # diagonal, zeros below. Used as ``perm_mask``.
        ones = torch.ones(batch, length, length, dtype=torch.float, device=torch_device)
        return torch.triu(ones, diagonal=0)

    xlnet = XLNetModel(config=config)
    xlnet.to(torch_device)
    xlnet.eval()

    # --- first forward pass -------------------------------------------------
    prompt_mask = _upper_triangular_ones(prompt_len)
    with_cache = xlnet(input_ids_1, use_cache=True, perm_mask=prompt_mask)
    without_cache = xlnet(input_ids_1, use_cache=False, perm_mask=prompt_mask)
    with_defaults = xlnet(input_ids_1)

    check = self.parent
    check.assertTrue(len(with_cache) == len(with_defaults))
    check.assertTrue(len(with_cache) == len(without_cache) + 1)

    # The cached output must unpack into exactly (hidden states, mems).
    _, mems = with_cache.to_tuple()

    # Create a hypothetical next token and extend the input ids with it.
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
    next_input_ids = torch.cat([input_ids_1, next_tokens], dim=-1)

    extended_mask = _upper_triangular_ones(prompt_len + 1)
    single_mask = torch.ones(batch, 1, 1, dtype=torch.float, device=torch_device)

    # --- second forward pass ------------------------------------------------
    hidden_full = xlnet(next_input_ids, perm_mask=extended_mask)["last_hidden_state"]
    hidden_cached = xlnet(next_tokens, mems=mems, perm_mask=single_mask)["last_hidden_state"]

    # Compare one randomly selected feature of the newest position.
    feature_idx = ids_tensor((1,), hidden_cached.shape[-1]).item()
    slice_full = hidden_full[:, -1, feature_idx].detach()
    slice_cached = hidden_cached[:, 0, feature_idx].detach()

    check.assertTrue(torch.allclose(slice_cached, slice_full, atol=1e-3))