def test_lm_generate_ctrl(self):
        """Check that the pretrained "ctrl" checkpoint produces a fixed continuation.

        The prompt ids and the expected output ids are hard-coded. The decoded
        text of each is recorded in the comments next to them.
        """
        ctrl_lm = CTRLLMHeadModel.from_pretrained("ctrl")
        ctrl_lm.to(torch_device)

        # Prompt: "Legal the president is"
        prompt_ids = torch.tensor(
            [[11859, 0, 1611, 8]], dtype=torch.long, device=torch_device
        )

        # Expected: "Legal the president is a good guy and I don't want to lose my job. \n \n I have a"
        reference = [
            11859, 0, 1611, 8, 5, 150, 26449, 2, 19, 348,
            469, 3, 2595, 48, 20740, 246533, 246533, 19, 30, 5,
        ]

        # do_sample=False turns sampling off, so an exact comparison against a
        # fixed list of ids is meaningful.
        generated = ctrl_lm.generate(prompt_ids, do_sample=False)
        first_sequence = generated[0].tolist()
        self.assertListEqual(first_sequence, reference)
示例#2
0
    def test_lm_generate_ctrl(self):
        """Check that the pretrained "ctrl" checkpoint produces a fixed continuation.

        The torch RNG is seeded immediately before ``generate``. If the
        checkpoint's generation settings sample, the output is still
        reproducible.
        """
        model = CTRLLMHeadModel.from_pretrained("ctrl")
        # Build the ids directly as int64. The legacy ``torch.Tensor(...).long()``
        # form first materialises a float32 tensor and then casts it. float32
        # represents integers exactly only up to 2**24, so that form is lossy
        # for large token ids.
        input_ids = torch.tensor(
            [[11859, 586, 20984, 8]], dtype=torch.long
        )  # Legal My neighbor is
        expected_output_ids = [
            11859, 586, 20984, 8, 13391, 3, 980, 8258, 72, 327,
            148, 2, 53, 29, 226, 3, 780, 49, 3, 980,
        ]  # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
        # Seed right before generation so any random draws inside generate() are deterministic.
        torch.manual_seed(0)

        # NOTE(review): model and inputs are not moved to torch_device here. The
        # expected ids were presumably recorded on the default (CPU) device.
        # Confirm this before changing devices, since seeded sampling can differ
        # across devices.
        output_ids = model.generate(input_ids)
        self.assertListEqual(output_ids[0].tolist(), expected_output_ids)