def test_lm_generate_ctrl(self):
    """Check generation from the pretrained "ctrl" checkpoint against a fixed reference.

    Sampling is disabled (``do_sample=False``), so the generated ids are expected
    to be deterministic and must equal the prompt followed by a known continuation.
    """
    prompt_ids = [11859, 0, 1611, 8]  # Legal the president is
    # ... a good guy and I don't want to lose my job. \n \n I have a
    continuation_ids = [5, 150, 26449, 2, 19, 348, 469, 3, 2595, 48, 20740, 246533, 246533, 19, 30, 5]

    ctrl_model = CTRLLMHeadModel.from_pretrained("ctrl")
    ctrl_model.to(torch_device)
    prompt = torch.tensor([prompt_ids], dtype=torch.long, device=torch_device)

    generated = ctrl_model.generate(prompt, do_sample=False)
    self.assertListEqual(generated[0].tolist(), prompt_ids + continuation_ids)
def test_lm_generate_ctrl(self):
    """Check seeded generation from the pretrained "ctrl" checkpoint against a fixed reference.

    ``generate`` is called with its default decoding arguments. The global torch RNG
    is seeded immediately beforehand so the result is reproducible even if those
    defaults involve sampling.
    """
    model = CTRLLMHeadModel.from_pretrained("ctrl")
    # Build the prompt directly as int64. torch.Tensor(...) would go through float32,
    # which cannot represent every integer token id exactly.
    input_ids = torch.tensor([[11859, 586, 20984, 8]], dtype=torch.long)  # Legal My neighbor is
    expected_output_ids = [
        11859, 586, 20984, 8, 13391, 3, 980, 8258, 72, 327, 148, 2, 53, 29, 226, 3, 780, 49, 3, 980,
    ]  # Legal My neighbor is refusing to pay rent after 2 years and we are having to force him to pay
    # NOTE(review): model and inputs are left on their default device. The expected ids
    # are tied to the seeded RNG stream, so moving them to another device may change the
    # output -- confirm before adding a .to(...) call here.
    torch.manual_seed(0)
    output_ids = model.generate(input_ids)
    self.assertListEqual(output_ids[0].tolist(), expected_output_ids)