def _model_to_graph(self, model, input, do_constant_folding=True,
                    training=TrainingMode.EVAL,
                    operator_export_type=OperatorExportTypes.ONNX,
                    input_names=None, dynamic_axes=None):
    """Convert ``model`` to an ONNX JIT graph with ONNX shape inference disabled.

    The model is put into train or eval mode according to ``training``; any
    other mode leaves the model's current mode untouched. ``dynamic_axes`` is
    validated, then ``utils._model_to_graph`` is called with torch constant
    propagation disabled.

    Args:
        model: The module to export.
        input: Example input(s) forwarded to ``utils._model_to_graph``.
        do_constant_folding: Forwarded to ``utils._model_to_graph``.
        training: A ``TrainingMode`` selecting train/eval mode for ``model``.
        operator_export_type: Forwarded to ``utils._model_to_graph``.
        input_names: Forwarded to ``utils._model_to_graph``.
        dynamic_axes: Validated, then forwarded to ``utils._model_to_graph``.

    Returns:
        The ``(graph, params_dict, torch_out)`` tuple produced by
        ``utils._model_to_graph``.
    """
    if training == torch.onnx.TrainingMode.TRAINING:
        model.train()
    elif training == torch.onnx.TrainingMode.EVAL:
        model.eval()
    # Need to disable onnx_shape_inference for this test because it puts
    # const nodes into initializers.
    _set_onnx_shape_inference(False)
    try:
        utils._validate_dynamic_axes(dynamic_axes, model, None, None)
        graph, params_dict, torch_out = utils._model_to_graph(
            model, input, do_constant_folding=do_constant_folding,
            _disable_torch_constant_prop=True,
            operator_export_type=operator_export_type,
            training=training,
            input_names=input_names,
            dynamic_axes=dynamic_axes)
    finally:
        # Re-enable even when validation/export raises, so a failing test
        # does not leave this global flag off for every later test.
        _set_onnx_shape_inference(True)
    return graph, params_dict, torch_out
def test_validate_dynamic_axes_invalid_input_output_name(self):
    """Each unknown ``dynamic_axes`` key yields exactly one warning."""
    import warnings
    axes_spec = {'input1': {}, 'output': {},
                 'invalid_name1': {}, 'invalid_name2': {}}
    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, even ones already emitted earlier in the run.
        warnings.simplefilter("always")
        utils._validate_dynamic_axes(axes_spec, None,
                                     ['input1', 'input2'], ['output'])
    texts = [str(record.message) for record in recorded]
    expected_texts = (
        "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
        "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
    )
    for expected in expected_texts:
        assert expected in texts
    assert len(texts) == 2
def run_model_test(self, model, batch_size=2, state_dict=None, input=None,
                   use_gpu=True, rtol=0.001, atol=1e-7, example_outputs=None,
                   do_constant_folding=True, dynamic_axes=None,
                   test_with_inputs=None, input_names=None, output_names=None,
                   fixed_batch_size=False):
    """Export ``model`` to an ONNX graph and check the shapes inferred for it.

    This body reads only ``batch_size``, ``input``, ``do_constant_folding``,
    ``dynamic_axes``, ``input_names`` and ``output_names``. The remaining
    parameters are accepted but not used here. When ``input`` is None, a
    random ``(batch_size, 3, 224, 224)`` tensor is used.
    """
    model.eval()
    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
    if isinstance(input, torch.Tensor):
        input = (input, )

    with torch.no_grad():
        # Forward on a deep copy: in-place operators in the model would
        # otherwise mutate the tensors that are re-used for the export below.
        reference = model(*copy.deepcopy(input))
    reference = (reference, ) if isinstance(reference, torch.Tensor) else reference

    # Pin the exporter's global settings that this check depends on.
    _set_opset_version(self.opset_version)
    _set_operator_export_type(OperatorExportTypes.ONNX)
    _set_onnx_shape_inference(True)
    _set_training_mode(False)

    axes = {} if dynamic_axes is None else dynamic_axes
    _validate_dynamic_axes(axes, model, input_names, output_names)

    graph, _, _ = utils._model_to_graph(
        model,
        copy.deepcopy(input),
        input_names=input_names,
        output_names=output_names,
        operator_export_type=OperatorExportTypes.ONNX,
        example_outputs=reference,
        do_constant_folding=do_constant_folding,
        training=TrainingMode.EVAL,
        use_new_jit_passes=self.use_new_jit_passes,
        dynamic_axes=axes,
    )
    verify_inferred_shape(graph)
def test_validate_dynamic_axes_invalid_input_output_name(self):
    """Unknown keys in ``dynamic_axes`` are reported once each via warnings."""
    import warnings
    spec = {"input1": {}, "output": {}, "invalid_name1": {}, "invalid_name2": {}}
    with warnings.catch_warnings(record=True) as caught:
        # "always" ensures repeated warnings are not suppressed by the registry.
        warnings.simplefilter("always")
        utils._validate_dynamic_axes(spec, None, ["input1", "input2"], ["output"])
    emitted = [str(entry.message) for entry in caught]
    for want in (
        "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
        "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
    ):
        self.assertIn(want, emitted)
    self.assertEqual(len(emitted), 2)
def _onnx_graph_from_model(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    args: Tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    This sets ``GLOBALS.export_onnx_opset_version`` and
    ``GLOBALS.operator_export_type`` as a side effect. It does not restore
    them afterwards.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names

    if opset_version is None:
        opset_version = _constants.onnx_default_opset

    # Called only for its side effects (presumably registering the module map
    # used during tracing -- see utils); its return value is not needed here.
    utils._setup_trace_module_map(model, export_modules_as_functions)

    # NOTE(review): `not` treats any falsy value as "unset", not just None --
    # confirm this is intended if an OperatorExportTypes member can be falsy.
    if not operator_export_type:
        if _C_onnx._CAFFE2_ATEN_FALLBACK:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )

        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return onnx_graph