Exemplo n.º 1
0
def select_model_mode_for_export(model, mode):
    """Switch ``model`` into the train/eval mode requested for an ONNX export.

    This is a generator with a single ``yield``: the code before the ``yield``
    applies the export mode and the ``finally`` clause restores the model's
    original ``training`` flag.
    NOTE(review): presumably wrapped with ``contextlib.contextmanager`` where
    it is defined; the decorator is not visible in this excerpt -- confirm.

    Args:
        model: The model being exported. Instances of
            ``torch.jit.ScriptFunction`` are left untouched; for anything else
            ``model.training`` is read and ``model.train(...)`` is called.
        mode: A ``TrainingMode`` value, or ``None``. ``None`` is treated as
            ``TrainingMode.EVAL``, and a warning is issued if the model is
            currently in training mode.

    Yields:
        None, once the export mode has been applied.

    Note:
        Only the model's own ``training`` flag is restored on exit. The value
        passed to ``_set_training_mode`` is not reset by this function.
    """
    # Evaluate the type check once; the same answer guards both the setup
    # below and the restore step in ``finally``.
    is_script_function = isinstance(model, torch.jit.ScriptFunction)

    if not is_script_function:
        is_originally_training = model.training

        if mode is None:
            mode = TrainingMode.EVAL
            # The model is in training mode but the caller did not ask for a
            # training export: fall back to inference mode and say so.
            if is_originally_training:
                warnings.warn("You are exporting the model to ONNX while in training mode with "
                              "'train' parameter not specified. The model will default to inference mode export. "
                              "If you wish to export a training amenable ONNX model, specify training=TrainingMode.TRAINING or "
                              "training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export().")

        # EVAL, or PRESERVE on a model that was not training => inference export.
        is_export_training = (
            mode == TrainingMode.TRAINING
            or (mode == TrainingMode.PRESERVE and is_originally_training)
        )

        if is_export_training:
            # ONNX opset 12 has better support for training amenable models,
            # with updated versions of the dropout and batch_norm operators.
            # Imported here, at call time, so the currently configured opset
            # version is the one that gets checked.
            from torch.onnx.symbolic_helper import _export_onnx_opset_version
            if _export_onnx_opset_version < 12:
                # Each fragment ends with an explicit space: adjacent string
                # literals are concatenated verbatim.
                warnings.warn("You are exporting the model in training mode with onnx opset version {}. "
                              "Opset versions lower than opset 12 will not be able to export nodes such as "
                              "Dropout and BatchNorm correctly.".format(_export_onnx_opset_version))

        from torch.onnx.symbolic_helper import _set_training_mode
        _set_training_mode(is_export_training)
        model.train(is_export_training)

    try:
        yield
    finally:
        # Put the model back into whatever mode it was in before the export,
        # even if the body of the ``with`` block raised.
        if not is_script_function:
            model.train(is_originally_training)
Exemplo n.º 2
0
def run_model_test(self,
                   model,
                   batch_size=2,
                   state_dict=None,
                   input=None,
                   use_gpu=True,
                   rtol=0.001,
                   atol=1e-7,
                   example_outputs=None,
                   do_constant_folding=True,
                   dynamic_axes=None,
                   test_with_inputs=None,
                   input_names=None,
                   output_names=None,
                   fixed_batch_size=False):
    """Convert ``model`` to an ONNX graph in eval mode and check its inferred shapes.

    The model is run once on the inputs to obtain example outputs, the
    exporter's global settings are configured, the model is converted with
    ``utils._model_to_graph`` and the resulting graph is handed to
    ``verify_inferred_shape``.

    Args:
        model: Model to convert; ``model.eval()`` is called on it first.
        batch_size: Leading dimension of the default random input.
        input: A tensor or a tuple of inputs. When ``None`` a random tensor of
            shape ``(batch_size, 3, 224, 224)`` is used.
        do_constant_folding: Forwarded to ``utils._model_to_graph``.
        dynamic_axes: Forwarded to ``_validate_dynamic_axes`` and
            ``utils._model_to_graph``; ``None`` is replaced by an empty dict.
        input_names: Forwarded to ``_validate_dynamic_axes`` and
            ``utils._model_to_graph``.
        output_names: Forwarded to ``_validate_dynamic_axes`` and
            ``utils._model_to_graph``.
        state_dict, use_gpu, rtol, atol, example_outputs, test_with_inputs,
        fixed_batch_size: Not referenced in the visible body.
            NOTE(review): presumably kept for signature compatibility with a
            sibling test helper -- confirm.

    Reads ``self.opset_version`` and ``self.use_new_jit_passes``.
    """
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        # Normalize a single tensor to a 1-tuple so it can be splatted below.
        if isinstance(input, torch.Tensor):
            input = (input, )
        # In-place operators will update input tensor data as well.
        # Thus inputs are replicated before every forward call.
        input_copy = copy.deepcopy(input)
        output = model(*input_copy)
        # Same normalization for the outputs passed on as example outputs.
        if isinstance(output, torch.Tensor):
            output = (output, )

        # Configure the exporter's global settings before converting the model;
        # the values mirror the arguments passed to _model_to_graph below.
        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _set_onnx_shape_inference(True)
        _set_training_mode(False)
        if dynamic_axes is None:
            dynamic_axes = {}
        _validate_dynamic_axes(dynamic_axes, model, input_names, output_names)

        # Fresh copy again: the forward pass above ran on the previous copy.
        input_copy = copy.deepcopy(input)
        # Only the graph (first element of the returned triple) is used here.
        graph, _, _ = utils._model_to_graph(
            model,
            input_copy,
            input_names=input_names,
            output_names=output_names,
            operator_export_type=OperatorExportTypes.ONNX,
            example_outputs=output,
            do_constant_folding=do_constant_folding,
            training=TrainingMode.EVAL,
            use_new_jit_passes=self.use_new_jit_passes,
            dynamic_axes=dynamic_axes)
        verify_inferred_shape(graph)