def test_lm_forward(self):
     """Check the loss and logits returned by ``BartForMaskedLM``.

     The config, input ids and batch size come from
     ``self._get_config_and_data(output_past=False)``. Labels with the same
     (batch, sequence) shape as the inputs are drawn with ``ids_tensor``.
     The test asserts that the logits have shape
     ``(batch_size, seq_len, config.vocab_size)`` and that the loss
     converts to a Python ``float``.
     """
     config, input_ids, batch_size = self._get_config_and_data(output_past=False)
     seq_len = input_ids.shape[1]
     decoder_lm_labels = ids_tensor([batch_size, seq_len], self.vocab_size)

     lm_model = BartForMaskedLM(config)
     lm_model.to(torch_device)

     # Call the module instance instead of ``.forward`` so that any
     # registered forward hooks run, as recommended for ``nn.Module``.
     # The three-way unpack also checks that exactly three outputs come back.
     loss, logits, _enc_features = lm_model(
         input_ids=input_ids, lm_labels=decoder_lm_labels, decoder_input_ids=input_ids
     )

     expected_shape = (batch_size, seq_len, config.vocab_size)
     self.assertEqual(logits.shape, expected_shape)
     self.assertIsInstance(loss.item(), float)
    def test_lm_forward(self):
        """Run a small BART config through the classification and LM heads.

        A fixed batch of token ids is fed, as both encoder and decoder input,
        to ``BartForSequenceClassification`` and then to ``BartForMaskedLM``,
        both built from the same small ``BartConfig``. The test asserts:

        * classification logits have shape ``(batch_size, config.num_labels)``;
        * LM logits have shape ``(batch_size, seq_len, config.vocab_size)``;
        * the LM loss converts to a Python ``float``.
        """
        input_ids = torch.tensor(
            [
                [71, 82, 18, 33, 46, 91, 2],
                [68, 34, 26, 58, 30, 82, 2],
                [5, 97, 17, 39, 94, 40, 2],
                [76, 83, 94, 25, 70, 78, 2],
                [87, 59, 41, 35, 48, 66, 2],
                [55, 13, 16, 58, 5, 2, 1],  # note padding
                [64, 27, 31, 51, 12, 75, 2],
                [52, 64, 86, 17, 83, 39, 2],
                [48, 61, 9, 24, 71, 82, 2],
                [26, 1, 60, 48, 22, 13, 2],
                [21, 5, 62, 28, 14, 76, 2],
                [45, 98, 37, 86, 59, 48, 2],
                [70, 70, 50, 9, 28, 0, 2],
            ],
            dtype=torch.long,
            device=torch_device,
        )
        batch_size, seq_len = input_ids.shape
        decoder_lm_labels = ids_tensor([batch_size, seq_len], self.vocab_size)

        config = BartConfig(
            vocab_size=self.vocab_size,
            d_model=24,
            encoder_layers=2,
            decoder_layers=2,
            encoder_attention_heads=2,
            decoder_attention_heads=2,
            encoder_ffn_dim=32,
            decoder_ffn_dim=32,
            max_position_embeddings=48,
        )

        # --- Sequence-classification head -------------------------------
        model = BartForSequenceClassification(config)
        model.to(torch_device)
        # Call the module instance instead of ``.forward`` so that any
        # registered forward hooks run, as recommended for ``nn.Module``.
        outputs = model(input_ids=input_ids, decoder_input_ids=input_ids)
        logits = outputs[0]
        self.assertEqual(
            logits.shape, torch.Size((batch_size, config.num_labels))
        )

        # --- Language-modeling head -------------------------------------
        lm_model = BartForMaskedLM(config)
        lm_model.to(torch_device)
        # The three-way unpack also checks that exactly three outputs come back.
        loss, logits, _enc_features = lm_model(
            input_ids=input_ids,
            lm_labels=decoder_lm_labels,
            decoder_input_ids=input_ids,
        )
        expected_shape = (batch_size, seq_len, config.vocab_size)
        self.assertEqual(logits.shape, expected_shape)
        self.assertIsInstance(loss.item(), float)