def _test_ensemble_encoder_export(self, test_args):
    """Export an ensemble of encoders to an ONNX protobuf and verify that
    the Caffe2 backend reproduces the PyTorch forward pass.

    Builds ``num_models`` identical models from ``test_args``, wraps them in
    an ``EncoderEnsemble``, exports to ``encoder.pb`` in a temp directory,
    then runs one sample through both the PyTorch ensemble and the
    Caffe2-backed ONNX model and compares every output tensor.  Finally
    exercises ``save_to_db`` on the predictor-export path.

    Args:
        test_args: parsed model/test configuration consumed by
            ``test_utils.prepare_inputs`` and ``models.build_model``.
    """
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)

    num_models = 3
    model_list = []
    for _ in range(num_models):
        model_list.append(models.build_model(test_args, src_dict, tgt_dict))
    encoder_ensemble = EncoderEnsemble(model_list)

    tmp_dir = tempfile.mkdtemp()
    encoder_pb_path = os.path.join(tmp_dir, 'encoder.pb')
    encoder_ensemble.onnx_export(encoder_pb_path)

    # test equivalence
    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample['net_input']['src_tokens'][0:1].t()
    src_lengths = sample['net_input']['src_lengths'][0:1].int()

    pytorch_encoder_outputs = encoder_ensemble(src_tokens, src_lengths)

    # Read-only access is all we need here ('rb', not 'r+b').
    with open(encoder_pb_path, 'rb') as f:
        onnx_model = onnx.load(f)
    onnx_encoder = caffe2_backend.prepare(onnx_model)

    caffe2_encoder_outputs = onnx_encoder.run(
        (
            src_tokens.numpy(),
            src_lengths.numpy(),
        ),
    )

    for i in range(len(pytorch_encoder_outputs)):
        caffe2_out_value = caffe2_encoder_outputs[i]
        # .detach() instead of the deprecated .data attribute: both yield a
        # gradient-free tensor, but .detach() is tracked correctly by autograd.
        pytorch_out_value = pytorch_encoder_outputs[i].detach().numpy()
        np.testing.assert_allclose(
            caffe2_out_value,
            pytorch_out_value,
            rtol=1e-4,
            atol=1e-6,
        )

    encoder_ensemble.save_to_db(
        os.path.join(tmp_dir, 'encoder.predictor_export'),
    )
def _test_ensemble_encoder_export(self, test_args):
    """Round-trip an encoder ensemble through ONNX zip-archive export.

    Builds a three-model ensemble via a ``DictionaryHolderTask``, exports it
    to ``encoder.pb``, then checks that the Caffe2 backend produced by
    ``prepare_zip_archive`` matches the PyTorch outputs numerically, and
    that the ensemble can be saved in predictor-export format.

    Args:
        test_args: parsed model/test configuration consumed by
            ``test_utils.prepare_inputs`` and ``task.build_model``.
    """
    samples, src_dict, tgt_dict = test_utils.prepare_inputs(test_args)
    task = tasks.DictionaryHolderTask(src_dict, tgt_dict)

    # Three identically-configured models form the ensemble.
    ensemble = EncoderEnsemble(
        [task.build_model(test_args) for _ in range(3)]
    )

    export_dir = tempfile.mkdtemp()
    pb_path = os.path.join(export_dir, "encoder.pb")
    ensemble.onnx_export(pb_path)

    # The discrepancy in types here is a temporary expedient.
    # PyTorch indexing requires int64 while support for tracing
    # pack_padded_sequence() requires int32.
    sample = next(samples)
    src_tokens = sample["net_input"]["src_tokens"][0:1].t()
    src_lengths = sample["net_input"]["src_lengths"][0:1].int()

    reference_outputs = ensemble(src_tokens, src_lengths)

    backend_rep = caffe2_backend.prepare_zip_archive(pb_path)
    exported_outputs = backend_rep.run(
        (src_tokens.numpy(), src_lengths.numpy())
    )

    # Compare every PyTorch output against its exported counterpart.
    for expected, actual in zip(reference_outputs, exported_outputs):
        np.testing.assert_allclose(
            actual, expected.detach().numpy(), rtol=1e-4, atol=1e-6
        )

    ensemble.save_to_db(
        os.path.join(export_dir, "encoder.predictor_export")
    )