def test_lm_generate_ctrl(self):
        """Generate from the pretrained "ctrl" checkpoint with sampling disabled
        and check the output matches a fixed, known token sequence."""
        ctrl_model = TFCTRLLMHeadModel.from_pretrained("ctrl")

        # Prompt ids; the trailing comment in the original test decodes them as
        # "Legal the president is".
        prompt_part = [11859, 0, 1611, 8]
        prompt_ids = tf.convert_to_tensor([prompt_part], dtype=tf.int32)

        # Expected output = the prompt echoed back, followed by the continuation
        # "a good guy and I don't want to lose my job. \n \n I have a".
        # NOTE(review): the total length (20 ids) presumably comes from
        # generate()'s default max_length -- confirm against the library version.
        continuation_part = [
            5, 150, 26449, 2, 19, 348, 469, 3,
            2595, 48, 20740, 246533, 246533, 19, 30, 5,
        ]
        expected_ids = prompt_part + continuation_part

        generated = ctrl_model.generate(prompt_ids, do_sample=False)
        first_sequence = generated[0].numpy().tolist()
        self.assertListEqual(first_sequence, expected_ids)
Ejemplo n.º 2
0
 def create_and_check_ctrl_lm_head(self, config, input_ids, input_mask,
                                   head_mask, token_type_ids, *args):
     """Build a TFCTRLLMHeadModel from ``config``, run a single forward pass,
     and assert the returned logits have shape
     (batch_size, seq_length, vocab_size).

     ``head_mask`` and any extra positional ``*args`` are accepted for
     signature compatibility but are not used here.
     """
     lm_head_model = TFCTRLLMHeadModel(config=config)
     feed = dict(
         input_ids=input_ids,
         attention_mask=input_mask,
         token_type_ids=token_type_ids,
     )
     logits_shape = lm_head_model(feed).logits.shape
     wanted_shape = (self.batch_size, self.seq_length, self.vocab_size)
     self.parent.assertEqual(logits_shape, wanted_shape)