예제 #1
0
 def test_assert_jit_vs_nonjit_(self):
     """Compare eager and TorchScript LSTMModel outputs on random batches.

     Builds an LSTMModel from default parser arguments, puts it in eval mode,
     compiles it with ``torch.jit.script``, then feeds the eager and scripted
     models identical random inputs for ``num_iters`` iterations. It asserts
     that the first two elements of their outputs are equal.
     """
     task, parser = get_dummy_task_and_parser()
     LSTMModel.add_args(parser)
     args = parser.parse_args([])
     args.criterion = ""
     model = LSTMModel.build_model(args, task)
     model.eval()
     scripted_model = torch.jit.script(model)
     scripted_model.eval()
     vocab_size = len(task.source_dictionary)
     # Named to avoid shadowing the builtin ``iter``.
     num_iters = 100
     # Draw every batch shape up front; each value lies in [1, 10).
     seq_lens = torch.randint(1, 10, (num_iters,))
     batch_sizes = torch.randint(1, 10, (num_iters,))
     for seq_len_t, num_samples_t in zip(seq_lens, batch_sizes):
         # Plain Python ints keep the shape tuples and bounds below unambiguous.
         seq_len = int(seq_len_t)
         num_samples = int(num_samples_t)
         # No trailing comma: build the tensor itself, not a 1-tuple.
         src_tokens = torch.randint(0, vocab_size, (num_samples, seq_len))
         src_lengths = torch.randint(1, seq_len + 1, (num_samples,))
         src_lengths, _ = torch.sort(src_lengths, descending=True)
         # After the descending sort, index 0 is the longest sample. Pin it
         # to the full token width (seq_len). Presumably the model expects
         # max(src_lengths) == src_tokens.size(1) — confirm.
         src_lengths[0] = seq_len
         prev_output_tokens = torch.randint(0, vocab_size, (num_samples, 1))
         result = model(src_tokens, src_lengths, prev_output_tokens, None)
         scripted_result = scripted_model(
             src_tokens, src_lengths, prev_output_tokens, None
         )
         self.assertTensorEqual(result[0], scripted_result[0])
         self.assertTensorEqual(result[1], scripted_result[1])
예제 #2
0
 def test_jit_and_export_lstm(self):
     """Script a default-configured LSTMModel and hand it to the save/load helper.

     The model is built from an empty argument list (with ``criterion`` blanked),
     compiled via ``torch.jit.script``, and passed to ``self._test_save_and_load``.
     """
     dummy_task, arg_parser = get_dummy_task_and_parser()
     LSTMModel.add_args(arg_parser)
     model_args = arg_parser.parse_args([])
     model_args.criterion = ""
     lstm = LSTMModel.build_model(model_args, dummy_task)
     self._test_save_and_load(torch.jit.script(lstm))