def test_trans_output_as_graph_outputs(self):
    """If a transpose's output is a graph output, don't optimize it away."""
    trans = helper.make_node("Transpose", ["X"], ["Y"], name="trans", perm=[0, 2, 3, 1])
    graph_proto = helper.make_graph(
        [trans],
        "trans-to-graph-output",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, (2, 4, 5, 3))],
    )

    graph = GraphUtil.create_graph_from_onnx_graph(graph_proto)
    # drop the identity node feeding the graph output, so the transpose
    # itself becomes the graph output
    identity_op = graph.get_node_by_output(graph.outputs[0])
    graph.outputs = [identity_op.input[0]]
    graph.remove_node(identity_op.name)

    optimized_graph = GraphUtil.optimize_graph(graph)

    self.assertTrue(optimized_graph, msg="graph after optimizer should not be None")

    trans_cnt = len(group_nodes_by_type(optimized_graph)["Transpose"])
    self.assertEqual(trans_cnt, 1,
                     msg="Expected 1 Transpose op left, but %d left" % trans_cnt)
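
# A minimal sketch (not part of the original file) of what the
# group_nodes_by_type helper used above could look like. The get_nodes()
# accessor and the node .type attribute are assumptions about the graph
# object's API; adjust to the actual interface if it differs.
def group_nodes_by_type_sketch(graph):
    """Bucket a graph's nodes by op type, e.g. {"Transpose": [node, ...]}."""
    from collections import defaultdict
    res = defaultdict(list)
    for node in graph.get_nodes():
        res[node.type].append(node)
    return res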
def run_test(self, name, backend="caffe2", debug=False, onnx_file=None, opset=None,
             perf=None, fold_const=None):
    """Run the complete test against the given backend."""
    print(name)
    self.perf = perf

    # get the model
    if self.url:
        _, dir_name = self.download_file()
        model_path = os.path.join(dir_name, self.local)
    else:
        model_path = self.local
        dir_name = os.path.dirname(self.local)
    print("\tdownloaded", model_path)

    inputs = list(self.input_names.keys())
    outputs = self.output_names
    if self.model_type in ["checkpoint"]:
        graph_def, inputs, outputs = loader.from_checkpoint(model_path, inputs, outputs)
    elif self.model_type in ["saved_model"]:
        graph_def, inputs, outputs = loader.from_saved_model(model_path, inputs, outputs)
    else:
        graph_def, inputs, outputs = loader.from_graphdef(model_path, inputs, outputs)

    # create the input data
    inputs = {}
    for k, v in self.input_names.items():
        if isinstance(v, six.text_type) and v.startswith("np."):
            inputs[k] = eval(v)  # pylint: disable=eval-used
        else:
            inputs[k] = self.make_input(v)
    if self.more_inputs:
        for k, v in self.more_inputs.items():
            inputs[k] = v

    graph_def = tf2onnx.tfonnx.tf_optimize(inputs.keys(), self.output_names,
                                           graph_def, fold_const)
    shape_override = {}
    g = tf.import_graph_def(graph_def, name='')
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True), graph=g) as sess:

        # fix inputs if needed: cast each feed value to the dtype the graph
        # reports for the corresponding tensor
        for k in inputs.keys():  # pylint: disable=consider-iterating-dictionary
            t = sess.graph.get_tensor_by_name(k)
            dtype = tf.as_dtype(t.dtype).name
            v = inputs[k]
            if v.dtype.name != dtype:
                inputs[k] = v.astype(dtype)
        if self.force_input_shape:
            for k, v in inputs.items():
                shape_override[k] = list(v.shape)

        # run the model with tensorflow
        if self.skip_tensorflow:
            print("\ttensorflow", "SKIPPED")
        else:
            tf_results = self.run_tensorflow(sess, inputs)
            print("\ttensorflow", "OK")

        model_proto = None
        try:
            # convert model to onnx
            onnx_graph = self.to_onnx(sess.graph, opset=opset,
                                      shape_override=shape_override,
                                      input_names=inputs.keys())
            model_proto = onnx_graph.make_model("converted from tf2onnx")
            new_model_proto = GraphUtil.optimize_graph(onnx_graph, "test", debug=debug)
            if new_model_proto:
                model_proto = new_model_proto
            else:
                print("\tNON-CRITICAL, optimizers were not applied successfully")
            print("\tto_onnx", "OK")
            if debug:
                onnx_graph.dump_graph()
            if onnx_file:
                self.create_onnx_file(name, model_proto, inputs, onnx_file)
        except Exception as ex:
            tb = traceback.format_exc()
            print("\tto_onnx", "FAIL", ex, tb)

        try:
            onnx_results = None
            if backend == "caffe2":
                onnx_results = self.run_caffe2(name, model_proto, inputs)
            elif backend == "onnxmsrtnext":
                onnx_results = self.run_onnxmsrtnext(name, model_proto, inputs)
            elif backend == "onnxruntime":
                onnx_results = self.run_onnxruntime(name, model_proto, inputs)
            else:
                raise ValueError("unknown backend")
            print("\trun_onnx OK")

            try:
                if self.skip_tensorflow:
                    print("\tResults: skipped tensorflow")
                else:
                    if self.check_only_shape:
                        for tf_res, onnx_res in zip(tf_results, onnx_results):
                            np.testing.assert_array_equal(tf_res.shape, onnx_res.shape)
                    else:
                        for tf_res, onnx_res in zip(tf_results, onnx_results):
                            np.testing.assert_allclose(tf_res, onnx_res,
                                                       rtol=self.rtol, atol=self.atol)
                    print("\tResults: OK")
                return True
            except Exception as ex:
                print("\tResults: ", ex)

        except Exception as ex:
            print("\trun_onnx", "FAIL", ex)

    return False
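
# A minimal, self-contained sketch of the input-dtype coercion performed in
# run_test above; the helper name is illustrative, not from the original
# file. Each numpy feed value is cast to the dtype tensorflow reports for
# the corresponding graph tensor.
def coerce_inputs_sketch(sess, inputs):
    """Cast feed values in-place so their dtypes match the graph's tensors."""
    import tensorflow as tf
    for name, value in inputs.items():
        expected = tf.as_dtype(sess.graph.get_tensor_by_name(name).dtype).name
        if value.dtype.name != expected:
            inputs[name] = value.astype(expected)
    return inputs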