def create_and_check_model_with_global_attention_mask(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Run LongformerModel with a global attention mask under several call signatures.

    The model is called three times, each time with fewer optional inputs. The
    output shapes of the final call are then checked against the tester's
    batch/sequence/hidden sizes. The earlier calls only need to run without
    raising.

    Only ``config``, ``input_ids``, ``token_type_ids`` and ``input_mask`` are
    used; the label arguments are accepted to match the shared helper signature.
    """
    model = LongformerModel(config=config)
    model.to(torch_device)
    model.eval()

    # Start from a copy of the input mask and zero its middle column.
    # NOTE(review): per the Longformer API, 1 = global and 0 = local attention;
    # assumes input_mask holds 0/1 values — confirm against the tester's inputs.
    middle_column = input_mask.shape[-1] // 2
    global_mask = input_mask.clone()
    global_mask[:, middle_column] = 0
    global_mask = global_mask.to(torch_device)

    # Each variant drops another optional input; only the last output is inspected.
    call_variants = (
        dict(attention_mask=input_mask, global_attention_mask=global_mask, token_type_ids=token_type_ids),
        dict(token_type_ids=token_type_ids, global_attention_mask=global_mask),
        dict(global_attention_mask=global_mask),
    )
    for extra_kwargs in call_variants:
        outputs = model(input_ids, **extra_kwargs)

    expected_sequence_shape = (self.batch_size, self.seq_length, self.hidden_size)
    expected_pooled_shape = (self.batch_size, self.hidden_size)
    self.parent.assertEqual(outputs.last_hidden_state.shape, expected_sequence_shape)
    self.parent.assertEqual(outputs.pooler_output.shape, expected_pooled_shape)
def create_and_check_longformer_model_with_global_attention_mask(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check LongformerModel output shapes when a global attention mask is supplied.

    The model is called three times, each time with fewer optional inputs. The
    sequence and pooled output shapes are asserted after EVERY call; the
    original version silently discarded the first two results.

    Outputs are read by attribute (``last_hidden_state`` / ``pooler_output``),
    like the sibling helpers in this tester. The earlier version tuple-unpacked
    the call. When the model returns a ``ModelOutput`` (a dict subclass),
    unpacking yields its *keys*, not the tensors, so the shape checks would fail.

    Only ``config``, ``input_ids``, ``token_type_ids`` and ``input_mask`` are
    used; the label arguments are accepted to match the shared helper signature.
    """
    model = LongformerModel(config=config)
    model.to(torch_device)
    model.eval()

    # Copy the input mask and zero its middle column.
    # NOTE(review): per the Longformer API, 1 = global and 0 = local attention;
    # assumes input_mask holds 0/1 values — confirm against the tester's inputs.
    global_attention_mask = input_mask.clone()
    global_attention_mask[:, input_mask.shape[-1] // 2] = 0
    global_attention_mask = global_attention_mask.to(torch_device)

    expected_sequence_shape = (self.batch_size, self.seq_length, self.hidden_size)
    expected_pooled_shape = (self.batch_size, self.hidden_size)

    call_variants = (
        {
            "attention_mask": input_mask,
            "global_attention_mask": global_attention_mask,
            "token_type_ids": token_type_ids,
        },
        {"token_type_ids": token_type_ids, "global_attention_mask": global_attention_mask},
        {"global_attention_mask": global_attention_mask},
    )
    for kwargs in call_variants:
        result = model(input_ids, **kwargs)
        self.parent.assertEqual(result.last_hidden_state.shape, expected_sequence_shape)
        self.parent.assertEqual(result.pooler_output.shape, expected_pooled_shape)
def create_and_check_attention_mask_determinism(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Check that an all-ones attention mask matches calling the model with no mask.

    The model runs twice on the same ``input_ids``: once with an explicit mask
    of ones, once with no ``attention_mask``. The first five features of the
    first token of the first example must agree within ``atol=1e-4``.

    Only ``config`` and ``input_ids`` are used; the remaining arguments are
    accepted to match the shared helper signature.
    """
    model = LongformerModel(config=config)
    model.to(torch_device)
    model.eval()

    all_ones_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
    hidden_masked = model(input_ids, attention_mask=all_ones_mask)["last_hidden_state"]
    hidden_unmasked = model(input_ids)["last_hidden_state"]

    # Compare a small slice: example 0, token 0, first five hidden features.
    masked_slice = hidden_masked[0, 0, :5]
    unmasked_slice = hidden_unmasked[0, 0, :5]
    slices_match = torch.allclose(masked_slice, unmasked_slice, atol=1e-4)
    self.parent.assertTrue(slices_match)
def create_and_check_longformer_model(
    self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
):
    """Run LongformerModel with progressively fewer optional inputs and check output shapes.

    The model is called three times: first with ``attention_mask`` and
    ``token_type_ids``, then with only ``token_type_ids``, then with only
    ``input_ids``. The shapes of the final call's outputs are checked; the
    earlier calls only need to run without raising.

    The label arguments are unused and are accepted to match the shared
    helper signature.
    """
    model = LongformerModel(config=config)
    model.to(torch_device)
    model.eval()

    optional_inputs = (
        {"attention_mask": input_mask, "token_type_ids": token_type_ids},
        {"token_type_ids": token_type_ids},
        {},
    )
    for extras in optional_inputs:
        outputs = model(input_ids, **extras)

    self.parent.assertEqual(
        outputs.last_hidden_state.shape,
        (self.batch_size, self.seq_length, self.hidden_size),
    )
    self.parent.assertEqual(
        outputs.pooler_output.shape,
        (self.batch_size, self.hidden_size),
    )