def select_model_mode_for_export(model, mode):
    """Temporarily put ``model`` into the train/eval state requested for ONNX export.

    This is a single-``yield`` generator meant to be driven as a context
    manager (presumably wrapped with ``contextlib.contextmanager`` at its
    definition site -- confirm there).

    For anything other than a ``torch.jit.ScriptFunction``:

    * ``mode=None`` is treated as ``TrainingMode.EVAL``. If the model is
      currently training, a warning is emitted because it will be exported
      in inference mode.
    * The export is done in training mode when ``mode`` is
      ``TrainingMode.TRAINING``, or ``TrainingMode.PRESERVE`` while the model
      is training. In that case a warning is emitted if the opset version
      recorded in ``torch.onnx.symbolic_helper`` is below 12.
    * The resulting flag is passed to ``symbolic_helper._set_training_mode``
      and to ``model.train``.

    On exit (normal or via an exception), ``model.train`` is called with the
    model's original ``training`` value. The flag set through
    ``_set_training_mode`` is not reset by this function. A
    ``ScriptFunction`` is left untouched.

    Args:
        model: the module (or ``torch.jit.ScriptFunction``) being exported.
        mode: a ``TrainingMode`` value, or ``None``.
    """
    if not isinstance(model, torch.jit.ScriptFunction):
        is_originally_training = model.training

        if mode is None:
            mode = TrainingMode.EVAL
            # The caller did not ask for a training-mode export, so a model that
            # is currently training is exported in inference mode (the default).
            # Warn so the switch is not silent.
            if is_originally_training:
                warnings.warn("You are exporting the model to ONNX while in training mode with "
                              "'training' parameter not specified. The model will default to inference mode export. "
                              "If you wish to export a training amenable ONNX model, specify training=TrainingMode.TRAINING or "
                              "training=TrainingMode.PRESERVE (to preserve the original model state) in torch.onnx.export().")

        # EVAL, or PRESERVE on a model that is not training => inference-mode export.
        is_export_training = False
        # ONNX opset 12 has better support for training amenable models, with updated
        # versions of the dropout and batch_norm operators.
        if mode == TrainingMode.TRAINING or (mode == TrainingMode.PRESERVE and is_originally_training):
            # Imported inside the function, so the value current at call time is read.
            from torch.onnx.symbolic_helper import _export_onnx_opset_version
            if _export_onnx_opset_version < 12:
                # The literals below are joined by implicit concatenation, so the
                # trailing space after "such as" is required.
                warnings.warn("You are exporting the model in training mode with onnx opset version {}. "
                              "Opset versions lower than opset 12 will not be able to export nodes such as "
                              "Dropout and BatchNorm correctly.".format(_export_onnx_opset_version))
            is_export_training = True

        from torch.onnx.symbolic_helper import _set_training_mode
        _set_training_mode(is_export_training)
        model.train(is_export_training)
    try:
        yield
    finally:
        if not isinstance(model, torch.jit.ScriptFunction):
            # Restore the train/eval state the model had on entry.
            model.train(is_originally_training)
def run_model_test(self, model, batch_size=2, state_dict=None,
                   input=None, use_gpu=True, rtol=0.001, atol=1e-7,
                   example_outputs=None, do_constant_folding=True,
                   dynamic_axes=None, test_with_inputs=None,
                   input_names=None, output_names=None,
                   fixed_batch_size=False):
    """Convert ``model`` to an ONNX graph with shape inference on and check the inferred shapes.

    Steps:
      1. Switch ``model`` to eval mode.
      2. If ``input`` is ``None``, draw one with ``torch.randn`` of shape
         ``(batch_size, 3, 224, 224)`` and ``requires_grad=True``.
      3. Under ``torch.no_grad()``, run the model once on a deep copy of the
         inputs to obtain example outputs.
      4. Configure the exporter globals (opset ``self.opset_version``, ONNX
         operator export type, ONNX shape inference on, training mode off)
         and validate ``dynamic_axes``.
      5. Build the graph with ``utils._model_to_graph`` from a fresh deep
         copy of the inputs and pass it to ``verify_inferred_shape``.

    ``state_dict``, ``use_gpu``, ``rtol``, ``atol``, ``example_outputs``,
    ``test_with_inputs`` and ``fixed_batch_size`` are accepted but not read
    here. ``batch_size`` only matters when ``input`` is ``None``.
    Reads ``self.opset_version`` and ``self.use_new_jit_passes``.
    """
    model.eval()

    if input is None:
        input = torch.randn(batch_size, 3, 224, 224, requires_grad=True)

    with torch.no_grad():
        # A lone tensor becomes a 1-tuple so it can be star-unpacked below.
        args = (input, ) if isinstance(input, torch.Tensor) else input

        # In-place operators inside the model would mutate the tensors they
        # receive, so every call gets its own deep copy of the inputs.
        forward_args = copy.deepcopy(args)
        reference = model(*forward_args)
        observed_outputs = (reference, ) if isinstance(reference, torch.Tensor) else reference

        _set_opset_version(self.opset_version)
        _set_operator_export_type(OperatorExportTypes.ONNX)
        _set_onnx_shape_inference(True)
        _set_training_mode(False)

        axes = {} if dynamic_axes is None else dynamic_axes
        _validate_dynamic_axes(axes, model, input_names, output_names)

        export_args = copy.deepcopy(args)
        export_kwargs = {
            "input_names": input_names,
            "output_names": output_names,
            "operator_export_type": OperatorExportTypes.ONNX,
            "example_outputs": observed_outputs,
            "do_constant_folding": do_constant_folding,
            "training": TrainingMode.EVAL,
            "use_new_jit_passes": self.use_new_jit_passes,
            "dynamic_axes": axes,
        }
        graph, _params, _torch_out = utils._model_to_graph(model, export_args, **export_kwargs)
        verify_inferred_shape(graph)