Code example #1
    def test_advanced_inputs(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs()
        config.use_cache = False
        inputs_dict["input_ids"][:, -2:] = config.pad_token_id
        # build the decoder inputs and masks the same way the model does internally
        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
            config, inputs_dict["input_ids"]
        )
        model = FSMTModel(config).to(torch_device).eval()

        # features must match whether the decoder mask is created internally
        # or passed in explicitly (inverted back to the "1 = attend" convention)
        decoder_features_with_created_mask = model(**inputs_dict)[0]
        decoder_features_with_passed_mask = model(
            decoder_attention_mask=invert_mask(decoder_attn_mask), decoder_input_ids=decoder_input_ids, **inputs_dict
        )[0]
        _assert_tensors_equal(decoder_features_with_passed_mask, decoder_features_with_created_mask)
        useless_mask = torch.zeros_like(decoder_attn_mask)
        decoder_features = model(decoder_attention_mask=useless_mask, **inputs_dict)[0]
        self.assertTrue(isinstance(decoder_features, torch.Tensor))  # no hidden states or attentions
        self.assertEqual(
            decoder_features.size(),
            (self.model_tester.batch_size, self.model_tester.seq_length, config.tgt_vocab_size),
        )
        if decoder_attn_mask.min().item() < -1e3:  # some tokens were masked
            self.assertFalse((decoder_features_with_created_mask == decoder_features).all().item())

        # an integer (long) encoder attention mask must produce the same features
        decoder_features_with_long_encoder_mask = model(
            inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"].long()
        )[0]
        _assert_tensors_equal(decoder_features_with_long_encoder_mask, decoder_features_with_created_mask)
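
The test passes invert_mask(decoder_attn_mask) because _prepare_fsmt_decoder_inputs returns a padding mask in the "True = padded" convention, while the model's public decoder_attention_mask argument expects the usual "1 = attend" convention. A minimal sketch of such a helper (assuming it matches the invert_mask in modeling_fsmt.py):

    import torch

    def invert_mask(attention_mask: torch.Tensor) -> torch.Tensor:
        # flip the mask convention: 1/True becomes 0/False and vice versa,
        # turning a "1 = attend" mask into a "True = ignore" mask (and back)
        assert attention_mask.dim() == 2
        return attention_mask.eq(0)
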
Code example #2
    def test_prepare_fsmt_decoder_inputs(self):
        config, *_ = self._get_config_and_data()
        input_ids = _long_tensor([4, 4, 2])
        decoder_input_ids = _long_tensor([[26388, 2, config.pad_token_id]])
        ignore = float("-inf")  # value used to block attention in the causal mask
        decoder_input_ids, decoder_attn_mask, causal_mask = _prepare_fsmt_decoder_inputs(
            config, input_ids, decoder_input_ids
        )
        expected_causal_mask = torch.tensor(
            [[0, ignore, ignore],
             [0, 0, ignore],
             [0, 0, 0]]  # never attend to the final token, because it's pad
        ).to(input_ids.device)
        self.assertEqual(decoder_attn_mask.size(), decoder_input_ids.size())
        self.assertTrue(torch.eq(expected_causal_mask, causal_mask).all())
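
The expected mask encodes standard causal attention: position i may attend only to positions at or before i, and float("-inf") blocks everything above the diagonal. The same pattern can be built generically with torch.triu (a sketch, not the library's internal helper):

    import torch

    def make_causal_mask(seq_len, dtype=torch.float32):
        # -inf strictly above the diagonal (blocked), 0 on and below it (allowed)
        mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype)
        return mask.triu(diagonal=1)

    # make_causal_mask(3) reproduces expected_causal_mask from the test above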