def create_and_check_model_with_global_attention_mask(
            self, config, input_ids, token_type_ids, input_mask,
            sequence_labels, token_labels, choice_labels):
        """Run LongformerModel with a global attention mask and check output shapes.

        The model is called three times with progressively fewer optional
        inputs: attention mask + token type ids, token type ids only, and
        neither. Every call receives the same ``global_attention_mask``, and
        the ``last_hidden_state`` / ``pooler_output`` shapes of every call are
        asserted. (Previously only the last call's output was checked.)

        ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted
        to match the shared tester signature but are not used here.
        """
        model = LongformerModel(config=config)
        model.to(torch_device)
        model.eval()

        # Clone so that input_mask itself is not modified, then zero the middle
        # sequence position. NOTE(review): assumes, per the LongformerModel API,
        # that nonzero entries request global attention — confirm.
        global_attention_mask = input_mask.clone()
        global_attention_mask[:, input_mask.shape[-1] // 2] = 0
        global_attention_mask = global_attention_mask.to(torch_device)

        # Each entry is one call form; all of them must produce correctly
        # shaped outputs.
        call_kwargs = (
            {"attention_mask": input_mask, "token_type_ids": token_type_ids},
            {"token_type_ids": token_type_ids},
            {},
        )
        for kwargs in call_kwargs:
            result = model(input_ids, global_attention_mask=global_attention_mask, **kwargs)
            self.parent.assertEqual(
                result.last_hidden_state.shape,
                (self.batch_size, self.seq_length, self.hidden_size))
            self.parent.assertEqual(result.pooler_output.shape,
                                    (self.batch_size, self.hidden_size))
    def create_and_check_longformer_model_with_global_attention_mask(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Run LongformerModel with a global attention mask and check output sizes.

        The model is called three times with progressively fewer optional
        inputs: attention mask + token type ids, token type ids only, and
        neither. All calls share the same ``global_attention_mask``. For every
        call, the first output must have size
        ``[batch_size, seq_length, hidden_size]`` and the second output must
        have size ``[batch_size, hidden_size]``.

        ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted
        to match the shared tester signature but are not used here.
        """
        model = LongformerModel(config=config)
        model.to(torch_device)
        model.eval()

        # Clone so that input_mask itself is not modified, then zero the middle
        # sequence position. NOTE(review): assumes nonzero entries request
        # global attention, per the LongformerModel API — confirm.
        global_attention_mask = input_mask.clone()
        global_attention_mask[:, input_mask.shape[-1] // 2] = 0
        global_attention_mask = global_attention_mask.to(torch_device)

        call_kwargs = (
            {"attention_mask": input_mask, "token_type_ids": token_type_ids},
            {"token_type_ids": token_type_ids},
            {},
        )
        for kwargs in call_kwargs:
            outputs = model(input_ids, global_attention_mask=global_attention_mask, **kwargs)
            # Index positionally instead of tuple-unpacking. This works whether
            # the model returns a plain tuple or a dict-like ModelOutput.
            # Unpacking a ModelOutput iterates its keys, and unpacking a tuple
            # with extra entries raises ValueError.
            sequence_output, pooled_output = outputs[0], outputs[1]
            self.parent.assertListEqual(
                list(sequence_output.size()), [self.batch_size, self.seq_length, self.hidden_size]
            )
            self.parent.assertListEqual(list(pooled_output.size()), [self.batch_size, self.hidden_size])
    def create_and_check_attention_mask_determinism(
        self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
    ):
        """Check that an all-ones attention mask matches passing no mask at all.

        Only a small probe is compared: the first five features of the first
        token of the first batch element of ``last_hidden_state``, with an
        absolute tolerance of 1e-4. The remaining arguments are accepted to
        match the shared tester signature and are unused.
        """
        longformer = LongformerModel(config=config)
        longformer.to(torch_device)
        longformer.eval()

        # A mask of ones with the same shape as input_ids attends to every token.
        ones_mask = torch.ones(input_ids.shape, dtype=torch.long, device=torch_device)
        hidden_masked = longformer(input_ids, attention_mask=ones_mask)["last_hidden_state"]
        hidden_unmasked = longformer(input_ids)["last_hidden_state"]

        probe_masked = hidden_masked[0, 0, :5]
        probe_unmasked = hidden_unmasked[0, 0, :5]
        outputs_match = torch.allclose(probe_masked, probe_unmasked, atol=1e-4)
        self.parent.assertTrue(outputs_match)
Beispiel #4
0
 def create_and_check_longformer_model(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Run LongformerModel (no global attention mask) and check output shapes.

     The model is called three times with progressively fewer optional
     inputs: attention mask + token type ids, token type ids only, and
     neither. The ``last_hidden_state`` and ``pooler_output`` shapes of every
     call are asserted. (Previously only the last call's output was checked.)

     ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted
     to match the shared tester signature but are not used here.
     """
     model = LongformerModel(config=config)
     model.to(torch_device)
     model.eval()
     call_kwargs = (
         {"attention_mask": input_mask, "token_type_ids": token_type_ids},
         {"token_type_ids": token_type_ids},
         {},
     )
     for kwargs in call_kwargs:
         result = model(input_ids, **kwargs)
         self.parent.assertEqual(result.last_hidden_state.shape, (self.batch_size, self.seq_length, self.hidden_size))
         self.parent.assertEqual(result.pooler_output.shape, (self.batch_size, self.hidden_size))