Example #1
 def _model_to_graph(self,
                     model,
                     input,
                     do_constant_folding=True,
                     training=TrainingMode.EVAL,
                     operator_export_type=OperatorExportTypes.ONNX,
                     input_names=None,
                     dynamic_axes=None):
     if training == torch.onnx.TrainingMode.TRAINING:
         model.train()
     elif training == torch.onnx.TrainingMode.EVAL:
         model.eval()
     # Need to disable onnx_shape_inference for this test because it puts const nodes into initializers.
     _set_onnx_shape_inference(False)
     utils._validate_dynamic_axes(dynamic_axes, model, None, None)
     graph, params_dict, torch_out = utils._model_to_graph(
         model,
         input,
         do_constant_folding=do_constant_folding,
         _disable_torch_constant_prop=True,
         operator_export_type=operator_export_type,
         training=training,
         input_names=input_names,
         dynamic_axes=dynamic_axes)
     _set_onnx_shape_inference(True)
     return graph, params_dict, torch_out
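A minimal usage sketch for the helper above, assuming it lives on the same test class; the hypothetical test method, module, input shape, and axis indices are illustrative only (note the helper forwards dynamic_axes through utils._validate_dynamic_axes before building the graph).

def test_tiny_model_graph(self):
    class TinyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2

    x = torch.randn(2, 3)
    # Build the ONNX JIT graph via the helper defined above.
    graph, params_dict, torch_out = self._model_to_graph(
        TinyModel(), (x, ),
        input_names=['x'],
        dynamic_axes={'x': [0, 1]})
    # The returned torch._C.Graph can be inspected node by node.
    for node in graph.nodes():
        print(node.kind())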
Example #2
 def test_validate_dynamic_axes_invalid_input_output_name(self):
     import warnings
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         utils._validate_dynamic_axes({'input1': {}, 'output': {},
                                      'invalid_name1': {}, 'invalid_name2': {}},
                                      None, ['input1', 'input2'], ['output'])
         messages = [str(warning.message) for warning in w]
     assert "Provided key invalid_name1 for dynamic axes is not a valid input/output name" in messages
     assert "Provided key invalid_name2 for dynamic axes is not a valid input/output name" in messages
     assert len(messages) == 2
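For contrast, a hedged sketch of the valid-name path, assuming (as the test above implies) that keys matching declared input/output names raise no warning; the test name is illustrative.

def test_validate_dynamic_axes_valid_input_output_name(self):
    import warnings
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        # Every key matches a declared input/output name, so no warning is expected.
        utils._validate_dynamic_axes({'input1': {0: 'batch'}, 'output': {0: 'batch'}},
                                     None, ['input1', 'input2'], ['output'])
    assert len(w) == 0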
Example #3
def run_model_test(self,
                   model,
                   batch_size=2,
                   state_dict=None,
                   input=None,
                   use_gpu=True,
                   rtol=0.001,
                   atol=1e-7,
                   example_outputs=None,
                   do_constant_folding=True,
                   dynamic_axes=None,
                   test_with_inputs=None,
                   input_names=None,
                   output_names=None,
                   fixed_batch_size=False):
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        if isinstance(input, torch.Tensor):
            input = (input, )
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_copy = copy.deepcopy(input)
        output = model(*input_copy)
        if isinstance(output, torch.Tensor):
            output = (output, )

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _set_onnx_shape_inference(True)
        _set_training_mode(False)
        if dynamic_axes is None:
            dynamic_axes = {}
        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        input_copy = copy.deepcopy(input)
        graph, _, _ = utils._model_to_graph(
            model,
            input_copy,
            input_names=input_names,
            output_names=output_names,
            operator_export_type=OperatorExportTypes.ONNX,
            example_outputs=output,
            do_constant_folding=do_constant_folding,
            training=TrainingMode.EVAL,
            use_new_jit_passes=self.use_new_jit_passes,
            dynamic_axes=dynamic_axes)
        verify_inferred_shape(graph)
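A hedged sketch of a call site for run_model_test; the model, input shape, and tensor names are illustrative, and self.opset_version / self.use_new_jit_passes are assumed to be provided by the surrounding test class, since the helper reads both.

def test_conv_export(self):
    model = torch.nn.Conv2d(3, 16, kernel_size=3)
    x = torch.randn(2, 3, 224, 224, requires_grad=True)
    # All dynamic-axes keys match declared names, so _validate_dynamic_axes emits no warnings.
    self.run_model_test(model, input=(x, ),
                        input_names=['images'],
                        output_names=['features'],
                        dynamic_axes={'images': {0: 'batch'},
                                      'features': {0: 'batch'}})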
Example #4
 def test_validate_dynamic_axes_invalid_input_output_name(self):
     import warnings
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
         utils._validate_dynamic_axes({"input1": {}, "output": {},
                                      "invalid_name1": {}, "invalid_name2": {}},
                                      None, ["input1", "input2"], ["output"])
         messages = [str(warning.message) for warning in w]
     self.assertIn(
         "Provided key invalid_name1 for dynamic axes is not a valid input/output name",
         messages)
     self.assertIn(
         "Provided key invalid_name2 for dynamic axes is not a valid input/output name",
         messages)
     self.assertEqual(len(messages), 2)
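As a complement, a hedged sketch of the list form of dynamic axes: torch.onnx.utils._validate_dynamic_axes is expected to warn and rewrite a list of axis indices in place into an {index: name} dict with auto-generated names; the test name and exact assertions are assumptions of this sketch.

def test_validate_dynamic_axes_list_is_normalized(self):
    import warnings
    dynamic_axes = {"input1": [0, 1]}
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        utils._validate_dynamic_axes(dynamic_axes, None, ["input1"], ["output"])
    # The list of axis indices should have been rewritten into a dict keyed by those indices.
    self.assertIsInstance(dynamic_axes["input1"], dict)
    self.assertEqual(sorted(dynamic_axes["input1"].keys()), [0, 1])
    self.assertTrue(len(w) >= 1)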
Example #5
def _onnx_graph_from_model(
    model: Union[torch.nn.Module, torch.jit.ScriptModule],
    args: Tuple[Any, ...],
    kwargs: Mapping[str, Any],
    export_options: _experimental.ExportOptions,
) -> _C.Graph:
    """As part of the ONNX export steps, export an ONNX JIT graph from a PyTorch model.

    Args:
        model: See :func:`check_export_model_diff`.
        args: See :func:`check_export_model_diff`.
        kwargs: See :func:`check_export_model_diff`.
        export_options: See :func:`check_export_model_diff`.

    Returns:
        onnx_graph (_C.Graph): An ONNX JIT graph.
    """
    # TODO: refactor utils.py to remove duplicated code of context setup. See #78834
    opset_version = export_options.opset_version
    operator_export_type = export_options.operator_export_type
    export_modules_as_functions = export_options.export_modules_as_functions
    training = export_options.training
    verbose = export_options.verbose
    dynamic_axes = export_options.dynamic_axes
    input_names = export_options.input_names
    output_names = export_options.output_names

    if opset_version is None:
        opset_version = _constants.onnx_default_opset

    export_modules_as_functions = utils._setup_trace_module_map(
        model, export_modules_as_functions
    )

    if not operator_export_type:
        if _C_onnx._CAFFE2_ATEN_FALLBACK:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
        else:
            operator_export_type = _C_onnx.OperatorExportTypes.ONNX

    GLOBALS.export_onnx_opset_version = opset_version
    GLOBALS.operator_export_type = operator_export_type

    with utils.exporter_context(model, training, verbose):
        do_constant_folding = utils._decide_constant_folding(
            export_options.do_constant_folding, operator_export_type, training
        )

        if dynamic_axes is None:
            dynamic_axes = {}
        utils._validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        export_inputs = _prepare_input_for_export(args, kwargs)
        export_inputs = utils._decide_input_format(model, export_inputs)
        onnx_graph, _, _ = utils._model_to_graph(
            model,
            export_inputs,
            verbose,
            input_names,
            output_names,
            operator_export_type,
            do_constant_folding,
            training=training,
            dynamic_axes=dynamic_axes,
        )

        return onnx_graph
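A hedged sketch of calling this helper directly; the module and option values are illustrative, and the ExportOptions fields used here are assumed to match the attributes the function body reads above.

class TinyModel(torch.nn.Module):
    def forward(self, x):
        return x.relu()


export_options = _experimental.ExportOptions(
    opset_version=None,  # falls back to _constants.onnx_default_opset
    do_constant_folding=True,
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {0: "batch"}},
)
onnx_graph = _onnx_graph_from_model(
    TinyModel(), (torch.randn(2, 4),), {}, export_options)
print(onnx_graph)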