def convert_to_tflite(self, graph_def, feed_dict, outputs):
    """Convert ``graph_def`` into a TFLite model file stored in the test data directory.

    Args:
        graph_def: the (frozen) TF graph to convert.
        feed_dict: mapping whose keys are the input tensor names of the graph.
        outputs: names of the output tensors.

    Returns:
        Path of the written "<test name>.tflite" file, or None when ``feed_dict`` is
        empty (TFLite needs at least one input) or the TFLite converter rejects the graph.
    """
    if not feed_dict:
        # Can't make TFlite model with no inputs
        return None

    tf_reset_default_graph()
    with tf_session() as sess:
        tf.import_graph_def(graph_def, name='')
        graph = sess.graph
        input_tensors = [graph.get_tensor_by_name(tensor_name) for tensor_name in feed_dict]
        output_tensors = [graph.get_tensor_by_name(tensor_name) for tensor_name in outputs]
        converter = tf_lite.TFLiteConverter.from_session(sess, input_tensors, output_tensors)

    from tensorflow.lite.python.convert import ConverterError  # pylint: disable=import-outside-toplevel
    try:
        tflite_model = converter.convert()
    except ConverterError:
        # The graph uses something TFLite cannot express; the caller skips TFLite testing.
        return None

    tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
    target_dir = os.path.dirname(tflite_path)
    if target_dir:
        os.makedirs(target_dir, exist_ok=True)
    with open(tflite_path, 'wb') as out_file:
        out_file.write(tflite_model)
    return tflite_path
def test_tensor_data(self):
    """Round-trip float32 TF constants of various shapes through tf_to_onnx_tensor and compare values."""
    expected_by_name = {
        "empty_tensor": np.array([], dtype=np.float32),
        "multi_dim_empty_tensor": np.array([[], []], dtype=np.float32),
        "scalar": np.array(1., dtype=np.float32),
        "one_item_array": np.array([1.], dtype=np.float32),
        "normal_array": np.array([[1., 2.], [2., 3.]], dtype=np.float32),
    }

    tf_reset_default_graph()
    with tf_session() as sess:
        for const_name, values in expected_by_name.items():
            tf.constant(values, dtype=tf.float32, name=const_name)

        for op in sess.graph.get_operations():
            op_name = op.name
            self.assertTrue(op_name in expected_by_name)
            self.assertTrue("value" in op.node_def.attr)

            # Convert the TF "value" attribute into an ONNX TensorProto ...
            onnx_tensor = tf_utils.tf_to_onnx_tensor(
                tf_utils.get_tf_node_attr(op, "value"),
                name=utils.port_name(op.name))
            # ... wrap it in an ONNX attribute and read it back the same way
            # node.get_tensor_value(is_list=False) would.
            value_attr = helper.make_attribute("value", onnx_tensor)
            round_tripped = numpy_helper.to_array(helper.get_attribute_value(value_attr))

            self.assertTrue(np.array_equal(expected_by_name[op_name], round_tripped))
def _run_test_case(self, input_names_with_port, output_names_with_port):
    """Freeze and optimize the current default graph, re-infer shapes, and compare them per op.

    The graph built by the test in the default graph is frozen (variables become constants),
    optimized with ``tf_optimize``, shape-inferred with ``infer_shape_for_graph``, and every
    op of the original graph that still exists after optimization is compared with
    ``self._compare_shape_for_op``.

    Args:
        input_names_with_port: graph input tensor names, e.g. ["x:0"].
        output_names_with_port: graph output tensor names, e.g. ["y:0"].
    """
    try:
        tf.compat.v1.disable_eager_execution()
    except Exception:  # pylint: disable=broad-except
        # Best effort: older TF versions may not provide this API.
        pass

    with tf_session() as sess:
        # freeze graph
        origin_graph = sess.graph
        variables_lib.global_variables_initializer().run()
        output_name_without_port = [n.split(':')[0] for n in output_names_with_port]
        graph_def = tf.graph_util.convert_variables_to_constants(
            sess, sess.graph_def, output_name_without_port)

    tf_reset_default_graph()
    tf.import_graph_def(graph_def, name='')

    # Optimize the *frozen* graph. Previously the closed session's original (unfrozen)
    # graph_def was passed here, which silently discarded the freeze step above.
    graph_def = tf_optimize(input_names_with_port, output_names_with_port, graph_def, True)

    if self.config.is_debug_mode:
        os.makedirs(self.test_data_directory, exist_ok=True)
        model_path = os.path.join(self.test_data_directory,
                                  self._testMethodName + "_after_tf_optimize.pb")
        utils.save_protobuf(model_path, graph_def)
        self.logger.debug("created file %s", model_path)

    tf_reset_default_graph()
    tf.import_graph_def(graph_def, name='')

    with tf_session() as sess:
        inferred_graph = infer_shape_for_graph(sess.graph)
        # compare each operation
        for op in origin_graph.get_operations():
            try:
                inferred_op = inferred_graph.get_operation_by_name(op.name)
            except KeyError:
                # The op was removed by optimization; nothing to compare.
                continue
            self._compare_shape_for_op(op, inferred_op)
def from_graph_def(graph_def, name=None, input_names=None, output_names=None, opset=None, custom_ops=None,
                   custom_op_handlers=None, custom_rewriter=None, inputs_as_nchw=None, extra_opset=None,
                   shape_override=None, target=None, large_model=False, tensors_to_rename=None, output_path=None):
    """Returns a ONNX model_proto for a tensorflow graphdef.

    Args:
        graph_def: the graphdef we want to convert
        input_names: list of input names
        output_names: list of output names
        name: A name for the graph (defaults to "unknown")
        opset: the opset to be used for the ONNX model, default is the latest
        custom_ops: dictionary of custom ops handlers (legacy spelling; merged with
            custom_op_handlers, which wins on conflicting keys)
        custom_op_handlers: dictionary of custom ops handlers
        custom_rewriter: list of custom graph rewriters
            (NOTE: currently not forwarded to the converter by this function)
        inputs_as_nchw: transpose inputs in list from nchw to nhwc
        extra_opset: list of extra opset's, for example the opset's used by custom ops
        shape_override: dict with inputs that override the shapes given by tensorflow
        target: list of workarounds applied to help certain platforms
        large_model: use the ONNX external tensor storage format
        tensors_to_rename: forwarded unchanged to the converter
        output_path: save model to output_path

    Returns:
        An ONNX model_proto and an external_tensor_storage dict.

    Raises:
        ValueError: if input_names or output_names is missing or empty.
    """
    if not input_names:
        raise ValueError("input_names needs to be provided")
    if not output_names:
        raise ValueError("output_names needs to be provided")
    if not name:
        name = "unknown"
    initialized_tables = None

    # Historically only `custom_ops` was forwarded (as the handler dict); keep honoring it,
    # and no longer drop the documented `custom_op_handlers` argument on the floor.
    op_handlers = custom_ops
    if custom_op_handlers:
        op_handlers = {**(custom_ops or {}), **custom_op_handlers}

    with tf.device("/cpu:0"):
        with tf.Graph().as_default() as tf_graph:
            with tf_loader.tf_session(graph=tf_graph) as sess:
                tf.import_graph_def(graph_def, name='')
                frozen_graph = tf_loader.freeze_session(sess, input_names=input_names, output_names=output_names)
                input_names = tf_loader.inputs_without_resource(sess, input_names)
                # Optimize the frozen graph; optimizing the raw `graph_def` discarded the freeze step.
                frozen_graph = tf_loader.tf_optimize(input_names, output_names, frozen_graph)

    model_proto, external_tensor_storage = _convert_common(
        frozen_graph,
        name=name,
        continue_on_error=True,
        target=target,
        opset=opset,
        custom_op_handlers=op_handlers,
        extra_opset=extra_opset,
        shape_override=shape_override,
        input_names=input_names,
        output_names=output_names,
        inputs_as_nchw=inputs_as_nchw,
        large_model=large_model,
        tensors_to_rename=tensors_to_rename,
        initialized_tables=initialized_tables,
        output_path=output_path)

    return model_proto, external_tensor_storage
def test_parse_tflite_graph(self):
    """Convert a small matmul/mul/add TF function to TFLite and check parse_tflite_graph's summary."""

    def func(a, b, c):
        alpha = tf.constant(1.1, dtype=tf.float32)
        beta = tf.constant(2.3, dtype=tf.float32)
        mul1 = tf.multiply(alpha, tf.matmul(a, b))
        mul2 = tf.multiply(beta, c)
        x_ = mul1 + mul2
        return tf.identity(x_, name="output")

    inp_shapes = [[2, 3], [3, 1], [2, 1]]
    inp_dtypes = [tf.float32, tf.float32, tf.float32]
    names = ['a', 'b', 'c']
    names_with_port = ['a:0', 'b:0', 'c:0']
    output_names = ['output']
    output_names_with_port = ['output:0']

    input_tensors = [tf.TensorSpec(shape=s, dtype=d, name=n) for s, d, n in zip(inp_shapes, inp_dtypes, names)]

    concrete_func = tf.function(func, input_signature=tuple(input_tensors))
    concrete_func = concrete_func.get_concrete_function()
    graph_def = from_function(concrete_func, input_names=names_with_port, output_names=output_names_with_port)
    with tf_session() as sess:
        tf.import_graph_def(graph_def, name='')
        sess_inputs = [sess.graph.get_tensor_by_name(k) for k in names_with_port]
        sess_outputs = [sess.graph.get_tensor_by_name(n) for n in output_names_with_port]
        converter = tf.compat.v1.lite.TFLiteConverter.from_session(sess, sess_inputs, sess_outputs)

    # Convert exactly once (the model used to be converted twice, discarding the first result).
    tflite_model = converter.convert()
    tflite_path = os.path.join(self.test_data_directory, self._testMethodName + ".tflite")
    dir_name = os.path.dirname(tflite_path)
    if dir_name:
        # os.makedirs("") raises, so only create a directory when the path has one.
        os.makedirs(dir_name, exist_ok=True)
    with open(tflite_path, 'wb') as f:
        f.write(tflite_model)

    tflite_graphs, opcodes_map, model, tensor_shapes = read_tflite_model(tflite_path)
    self.assertEqual(1, len(tflite_graphs))
    onnx_nodes, op_cnt, attr_cnt, output_shapes, dtypes, inputs, outputs, _ = \
        parse_tflite_graph(tflite_graphs[0], opcodes_map, model, tensor_shapes_override=tensor_shapes)
    self.assertEqual(2, op_cnt['MUL'])
    self.assertEqual(1, op_cnt['ADD'])
    self.assertEqual(1, op_cnt['FULLY_CONNECTED'])

    self.assertEqual(1, attr_cnt['WeightsFormat'])
    self.assertEqual(names, inputs)
    self.assertEqual(output_names, outputs)

    for name, shape, dtype in zip(names, inp_shapes, inp_dtypes):
        self.assertEqual(shape, output_shapes[name])
        self.assertEqual(dtype, dtypes[name])

    self.assertTrue(len(onnx_nodes) >= 4)
def compute_const_folding_using_tf(g, const_node_values, graph_outputs):
    """Find nodes with constant inputs and compute their values using TF.

    Repeatedly sweeps the ops of ``g``: every single-output op whose inputs all have known
    values is executed in an isolated mini graph (placeholders fed with the known values),
    and its result is recorded. StridedSlice ops that index into a Shape op with a known
    static dimension are folded without running TF. Sweeps continue until no new value is
    produced; after each sweep, values that no still-unfolded op consumes are dropped to
    bound memory use.

    Args:
        g: the TF graph to analyze.
        const_node_values: optional dict mapping Const node names to raw ``tensor_content``
            bytes that replace the node's stored content before reading it.
        graph_outputs: tensor names that are graph outputs; these are never folded.

    Returns:
        (outputs_to_values, outputs_to_dtypes): dicts keyed by tensor name holding the folded
        numpy values and their TF dtypes. Values of original Const nodes are not included.
    """
    if const_node_values is None:
        const_node_values = {}
    graph_outputs = set(graph_outputs)
    from tf2onnxnightly.tf_loader import tf_session, tf_placeholder  # pylint: disable=import-outside-toplevel

    ops = g.get_operations()
    outputs_to_values = {}
    outputs_to_dtypes = {}
    outputs_to_shapes = {}
    shape_node_outputs = {}

    # np.product was deprecated in NumPy 1.25 and removed in 2.0; np.prod is the supported name.
    def is_small_shape(x):
        return np.prod(x) <= 1000

    def is_huge_shape(x):
        return np.prod(x) >= 1000000

    for node in ops:
        # Load values of constants. Use const_node_values if possible
        if node.type in ["Const", "ConstV2"]:
            tensor = node.node_def.attr["value"].tensor
            if node.name in const_node_values:
                tensor.tensor_content = const_node_values[node.name]
            outputs_to_values[node.outputs[0].name] = get_tf_tensor_data(tensor)
            outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
        for out in node.outputs:
            outputs_to_shapes[out.name] = get_tf_tensor_shape(out)

    for node in ops:
        # Remember the static input shape of every Shape op so StridedSlice(Shape(x)) can fold.
        if node.type == "Shape":
            shape = outputs_to_shapes.get(node.inputs[0].name)
            if shape is not None:
                shape_node_outputs[node.outputs[0].name] = shape

    unneeded_outputs = set()
    progress = True
    while progress:
        progress = False
        for node in ops:
            # Find ops with constant inputs and compute their values
            input_names = [i.name for i in node.inputs]
            output_names = [i.name for i in node.outputs]
            if node.type == 'StridedSlice' and input_names[0] in shape_node_outputs \
                    and output_names[0] not in outputs_to_values:
                shape = shape_node_outputs[input_names[0]]
                i = get_index_from_strided_slice_of_shape(node, outputs_to_values)
                if i is not None and 0 <= i < len(shape) and shape[i] is not None:
                    np_dtype = map_onnx_to_numpy_type(map_tf_dtype(node.outputs[0].dtype))
                    outputs_to_values[output_names[0]] = np.array(shape[i], dtype=np_dtype)
                    outputs_to_dtypes[node.outputs[0].name] = node.outputs[0].dtype
                    progress = True
            can_fold = node.type not in ['Enter', 'Placeholder', 'PlaceholderWithDefault']
            can_fold = can_fold and not node.type.startswith('Random')
            can_fold = can_fold and len(input_names) > 0 and all(inp in outputs_to_values for inp in input_names)
            # We can only fold nodes with a single output
            can_fold = can_fold and len(output_names) == 1 and output_names[0] not in outputs_to_values
            # Skip if value already computed, used, and discarded
            can_fold = can_fold and output_names[0] not in unneeded_outputs and output_names[0] not in graph_outputs
            if can_fold:
                # Make a mini graph containing just the node to fold
                g2 = tf.Graph()
                with g2.as_default():
                    for inp in input_names:
                        tf_placeholder(outputs_to_dtypes[inp], name=inp.split(':')[0])
                    mini_graph_def = g2.as_graph_def()
                    mini_graph_def.node.append(node.node_def)
                g3 = tf.Graph()
                with g3.as_default():
                    feed_dict = {}
                    inp_shapes = []
                    for inp in input_names:
                        inp_np = outputs_to_values[inp]
                        feed_dict[inp] = inp_np
                        inp_shapes.append(inp_np.shape)
                    try:
                        with tf_session() as sess:
                            tf.import_graph_def(mini_graph_def, name='')
                            results = sess.run(output_names, feed_dict=feed_dict)
                        if is_huge_shape(results[0].shape) and all(is_small_shape(inp) for inp in inp_shapes):
                            # Don't materialize a huge constant from small inputs (e.g. Fill/Tile).
                            logger.debug("Skipping folding of node %s since result shape %s is much larger "
                                         "than input shapes %s", node.name, results[0].shape, inp_shapes)
                        else:
                            outputs_to_values[output_names[0]] = results[0]
                            outputs_to_dtypes[output_names[0]] = node.outputs[0].dtype
                            progress = True
                    except Exception:  # pylint: disable=broad-except
                        # Folding is best effort; ops TF can't run in isolation are left alone.
                        logger.debug("Could not fold node %s", node.name)
        unneeded_outputs.update(outputs_to_values.keys())
        for node in ops:
            # Mark values we need to keep
            input_names = [i.name for i in node.inputs]
            output_names = [i.name for i in node.outputs]
            if len(output_names) == 1 and output_names[0] in outputs_to_values:
                continue
            for i in input_names:
                if i in unneeded_outputs:
                    unneeded_outputs.remove(i)
        for out_name in unneeded_outputs:
            # Remove unneeded values to prevent memory usage explosion
            if out_name in outputs_to_values:
                del outputs_to_values[out_name]
                del outputs_to_dtypes[out_name]

    for node in ops:
        # We don't need the constants any more
        if node.type in ["Const", "ConstV2"] and node.outputs[0].name in outputs_to_values:
            del outputs_to_values[node.outputs[0].name]
            del outputs_to_dtypes[node.outputs[0].name]

    logger.info("Computed %d values for constant folding", len(outputs_to_values))
    return outputs_to_values, outputs_to_dtypes
def run_test_case(self, func, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
                  convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=True,
                  check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None, as_session=False,
                  large_model=False, premade_placeholders=False):
    """Run ``func`` in TF, convert it to ONNX (directly and via TFLite) and compare the results.

    The TF path converts the frozen TF graph with ``process_tf_graph`` and compares the ONNX
    backend results with the TF results. The TFLite path, when enabled and convertible,
    converts the TFLite model to ONNX and compares against the TFLite interpreter results.
    ``convert_var_to_const`` is accepted for interface compatibility and not used here.

    Returns:
        The last converted ONNX graph (from the TFLite path if it ran, else the TF path).

    Raises:
        unittest.SkipTest: if neither the TF nor the TFLite path ran.
    """
    test_tf = not self.config.skip_tf_tests
    test_tflite = not self.config.skip_tflite_tests
    run_tfl_consistency_test = test_tf and test_tflite and self.config.run_tfl_consistency_test
    # optional - passed to process_tf_graph
    process_args = {} if process_args is None else process_args
    # optional - pass distinct feed_dict to onnx runtime
    onnx_feed_dict = feed_dict if onnx_feed_dict is None else onnx_feed_dict
    input_names_with_port = list(feed_dict)
    tf_reset_default_graph()
    test_tflite = test_tflite and tf_lite is not None
    g = None

    expected, graph_def, initialized_tables = self.freeze_and_run_tf(
        func, feed_dict, output_names_with_port, as_session, premade_placeholders, large_model, constant_fold)

    tflite_path = self.convert_to_tflite(graph_def, feed_dict, output_names_with_port) if test_tflite else None
    test_tflite = tflite_path is not None

    if test_tf:
        tf_reset_default_graph()
        with tf_session() as sess:
            const_node_values = compress_graph_def(graph_def) if large_model else None
            tf.import_graph_def(graph_def, name='')
            g = process_tf_graph(sess.graph,
                                 opset=self.config.opset,
                                 input_names=input_names_with_port,
                                 output_names=output_names_with_port,
                                 target=self.config.target,
                                 const_node_values=const_node_values,
                                 initialized_tables=initialized_tables,
                                 **process_args)
            g = optimizer.optimize_graph(g, catch_errors=False)
            actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)

        self.assert_results_equal(expected, actual, rtol, atol, check_value, check_shape, check_dtype)
        if graph_validator:
            self.assertTrue(graph_validator(g))

    if not test_tflite:
        if g is None:
            raise unittest.SkipTest("Both tf and tflite marked to skip")
        return g

    tfl_results, tfl_outputs = self.run_tflite(tflite_path, feed_dict)
    if tfl_results is not None:
        if run_tfl_consistency_test:
            self.assert_results_equal(expected, tfl_results, rtol, atol, check_value, check_shape, check_dtype)

        def strip_port(tensor_name):
            return tensor_name.split(':')[0]

        tfl_process_args = process_args.copy()
        if 'inputs_as_nchw' in tfl_process_args:
            # TFLite tensor names carry no ":0" port suffix.
            tfl_process_args['inputs_as_nchw'] = [strip_port(i) for i in tfl_process_args['inputs_as_nchw']]

        g = process_tf_graph(None,
                             opset=self.config.opset,
                             input_names=[strip_port(inp) for inp in feed_dict.keys()],
                             output_names=tfl_outputs,
                             target=self.config.target,
                             tflite_path=tflite_path,
                             **tfl_process_args)
        g = optimizer.optimize_graph(g)
        onnx_feed_dict_without_port = {strip_port(k): v for k, v in onnx_feed_dict.items()}
        onnx_from_tfl_res = self.run_backend(g, tfl_outputs, onnx_feed_dict_without_port, postfix="_from_tflite")

        self.assert_results_equal(tfl_results, onnx_from_tfl_res, rtol, atol, check_value, check_shape, check_dtype)
        if graph_validator:
            self.assertTrue(graph_validator(g))

    if g is None:
        raise unittest.SkipTest("Both tf and tflite marked to skip")
    return g
def freeze_and_run_tf(self, func, feed_dict, outputs, as_session, premade_placeholders, large_model, constant_fold):
    """Run ``func`` with TensorFlow and return its results plus an optimized, frozen graph_def.

    Args:
        func: the TF test function; called with one tensor per ``feed_dict`` entry (in graph
            mode, called with no arguments when ``premade_placeholders`` is set).
        feed_dict: mapping of input tensor names (presumably with port, e.g. "x:0" — the code
            strips it via utils.node_name) to numpy values.
        outputs: names of the output tensors to fetch and keep in the frozen graph.
        as_session: use the graph/session path even when running under TF2.
        premade_placeholders: in graph mode, do not create placeholders; ``func`` is assumed
            to create its own matching ``feed_dict`` (TODO confirm with callers).
        large_model: forwarded to ``from_function`` on the eager path.
        constant_fold: forwarded to ``tf_optimize`` as ``fold_constant``.

    Returns:
        (result, graph_def, initialized_tables): the TF outputs as a list of numpy arrays, the
        optimized graph_def (also saved as "<test>_after_tf_optimize.pb" in the test data
        directory), and a dict of hash-table name -> (keys, values) (None on the eager path).
    """
    np.random.seed(1)  # Make it reproducible.
    clean_feed_dict = {utils.node_name(k): v for k, v in feed_dict.items()}
    if is_tf2() and not as_session:
        #
        # use eager to execute the tensorflow func
        #
        # numpy doesn't work for all ops, make it tf.Tensor()
        input_tensors = [tf.TensorSpec(shape=v.shape, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                         for k, v in feed_dict.items()]
        input_list = [tf.convert_to_tensor(v, dtype=tf.as_dtype(v.dtype), name=utils.node_name(k))
                      for k, v in feed_dict.items()]
        # Seed right before calling func so random ops produce reproducible values.
        tf.random.set_seed(1)
        result = func(*input_list)
        if isinstance(result, (list, tuple)):
            # list or tuple
            result = [x.numpy() for x in result]
        else:
            # single result
            result = [result.numpy()]

        # now make the eager functions a graph
        concrete_func = tf.function(func, input_signature=tuple(input_tensors))
        concrete_func = concrete_func.get_concrete_function()
        graph_def = from_function(concrete_func,
                                  input_names=list(feed_dict.keys()),
                                  output_names=outputs,
                                  large_model=large_model)
        initialized_tables = None
    else:
        #
        # use graph to execute the tensorflow func
        #
        with tf_session() as sess:
            tf_set_random_seed(1)
            input_list = []
            if not premade_placeholders:
                for k, v in clean_feed_dict.items():
                    input_list.append(tf_placeholder(name=k, shape=v.shape, dtype=tf.as_dtype(v.dtype)))
            func(*input_list)
            variables_lib.global_variables_initializer().run()
            tf_tables_initializer().run()
            output_dict = []
            for out_name in outputs:
                output_dict.append(sess.graph.get_tensor_by_name(out_name))
            result = sess.run(output_dict, feed_dict=feed_dict)
            graph_def = freeze_session(sess,
                                       input_names=list(feed_dict.keys()),
                                       output_names=outputs)
            # Export the contents of every hash table in the frozen graph while the
            # session (and its initialized tables) is still alive.
            table_names, key_dtypes, value_dtypes = get_hash_table_info(graph_def)
            initialized_tables = {}
            for n, k_dtype, val_dtype in zip(table_names, key_dtypes, value_dtypes):
                h = lookup_ops.hash_table_v2(k_dtype, val_dtype, shared_name=n)
                k, v = lookup_ops.lookup_table_export_v2(h, k_dtype, val_dtype)
                initialized_tables[n] = (sess.run(k), sess.run(v))

    # Re-import the frozen graph into a fresh default graph and optimize it (both paths).
    tf_reset_default_graph()
    with tf_session() as sess:
        tf.import_graph_def(graph_def, name='')
        graph_def = tf_optimize(list(feed_dict.keys()), outputs, graph_def, fold_constant=constant_fold)

    model_path = os.path.join(self.test_data_directory, self._testMethodName + "_after_tf_optimize.pb")
    utils.save_protobuf(model_path, graph_def)
    self.logger.debug("created file %s", model_path)
    return result, graph_def, initialized_tables
def run_test(self, name, backend="caffe2", onnx_file=None, opset=None, extra_opset=None, perf=None, fold_const=None):
    """Run complete test against backend.

    Loads the configured model, runs it with TensorFlow (or the TFLite interpreter),
    converts it to ONNX, runs the ONNX model with ``backend`` and compares the results.

    Args:
        name: test name; used for debug/output file names and passed to the backend runner.
        backend: "caffe2" or "onnxruntime"; any other value is reported as a Run_ONNX failure.
        onnx_file: if set, ``create_onnx_file`` is called with the converted model.
        opset: ONNX opset passed to ``to_onnx``.
        extra_opset: extra opsets passed to ``to_onnx``.
        perf: if truthy, also time PERFITER TF/TFLite runs into ``self.tf_runtime``.
        fold_const: accepted for interface compatibility; not used by this method.

    Returns:
        True if conversion, ONNX execution and the result comparison all succeed, else False.
    """
    self.perf = perf

    # get the model
    if self.url:
        _, dir_name = self.download_model()
        logger.info("Downloaded to %s", dir_name)
        model_path = os.path.join(dir_name, self.local) if self.local != "." else dir_name
    else:
        model_path = self.local

    logger.info("Load model from %s", model_path)
    input_names = list(self.input_names.keys())
    initialized_tables = {}
    outputs = self.output_names
    tflite_path = None
    to_rename = None
    if self.model_type in ["checkpoint"]:
        graph_def, input_names, outputs = tf_loader.from_checkpoint(model_path, input_names, outputs)
    elif self.model_type in ["saved_model"]:
        loaded = tf_loader.from_saved_model(model_path, None, None, self.tag, self.signatures,
                                            self.concrete_function, self.large_model,
                                            return_concrete_func=not self.run_tf_frozen,
                                            return_initialized_tables=True, return_tensors_to_rename=True)
        if not self.run_tf_frozen:
            # Must maintain ref to imported since concrete_func uses weak refs
            # pylint: disable=unused-variable
            graph_def, input_names, outputs, concrete_func, imported, initialized_tables, to_rename = loaded
        else:
            graph_def, input_names, outputs, initialized_tables, to_rename = loaded
    elif self.model_type in ["keras"]:
        graph_def, input_names, outputs = tf_loader.from_keras(model_path, input_names, outputs)
    elif self.model_type in ["tflite"]:
        tflite_path = model_path
        graph_def = None
    else:
        graph_def, input_names, outputs = tf_loader.from_graphdef(model_path, input_names, outputs)

    # tflite models have no graph_def; saving None would crash in debug mode.
    if utils.is_debug_mode() and graph_def is not None:
        utils.save_protobuf(os.path.join(TEMP_DIR, name + "_after_tf_optimize.pb"), graph_def)

    if tflite_path is not None:
        inputs = {}
        for k in input_names:
            v = self.input_names[k]
            inputs[k] = self.make_input(v)

        interpreter = tf.lite.Interpreter(tflite_path)
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        input_name_to_index = {n['name'].split(':')[0]: n['index'] for n in input_details}
        for k, v in inputs.items():
            interpreter.resize_tensor_input(input_name_to_index[k], v.shape)
        interpreter.allocate_tensors()

        def run_tflite():
            for k, v in inputs.items():
                interpreter.set_tensor(input_name_to_index[k], v)
            interpreter.invoke()
            result = [interpreter.get_tensor(output['index']) for output in output_details]
            return result

        tf_results = run_tflite()
        if self.perf:
            logger.info("Running TFLite perf")
            start = time.time()
            for _ in range(PERFITER):
                _ = run_tflite()
            self.tf_runtime = time.time() - start
        logger.info("TFLite OK")

    # NOTE(review): concrete_func is only bound on the saved_model path; this branch assumes
    # run_tf_frozen is truthy for every other model type — confirm against the config loader.
    if not self.run_tf_frozen:
        inputs = {}
        for k in input_names:
            v = self.input_names[k]
            inputs[k.split(":")[0]] = tf.constant(self.make_input(v))
        tf_func = tf.function(concrete_func)
        logger.info("Running TF")
        tf_results_d = tf_func(**inputs)
        # If there is only a single output a dict might not be returned
        if isinstance(tf_results_d, tf.Tensor):
            tf_results = [tf_results_d]
        else:
            tf_results = [tf_results_d[k] for k in sorted(tf_results_d.keys())]
        tf_results = [tf_res.numpy() for tf_res in tf_results]
        if self.perf:
            logger.info("Running TF perf")
            start = time.time()
            for _ in range(PERFITER):
                _ = concrete_func(**inputs)
            self.tf_runtime = time.time() - start
        logger.info("TensorFlow OK")

    shape_override = {}
    const_node_values = None
    tf_graph = None

    if graph_def is not None:
        inputs = {}
        tf_reset_default_graph()

        with tf.Graph().as_default() as tf_graph:
            from tf2onnxnightly.tf_utils import compress_graph_def  # pylint: disable=import-outside-toplevel
            if self.large_model:
                const_node_values = compress_graph_def(graph_def)
            tf.import_graph_def(graph_def, name='')

        with tf_session(graph=tf_graph) as sess:
            # create the input data
            for k in input_names:
                v = self.input_names[k]
                t = sess.graph.get_tensor_by_name(k)
                expected_dtype = tf.as_dtype(t.dtype).name
                if isinstance(v, six.text_type) and v.startswith("np."):
                    # NOTE(review): eval() of a config-supplied expression; only safe for trusted configs.
                    np_value = eval(v)  # pylint: disable=eval-used
                    if expected_dtype != np_value.dtype:
                        logger.warning("dtype mismatch for input %s: expected=%s, actual=%s", k, expected_dtype,
                                       np_value.dtype)
                    inputs[k] = np_value.astype(expected_dtype)
                else:
                    if expected_dtype == "string":
                        # np.str / np.object were aliases of the builtins and were removed in NumPy 1.24.
                        inputs[k] = self.make_input(v).astype(str).astype(object)
                    else:
                        inputs[k] = self.make_input(v).astype(expected_dtype)

            if self.force_input_shape:
                for k, v in inputs.items():
                    shape_override[k] = list(v.shape)

            # run the model with tensorflow
            if self.skip_tensorflow:
                logger.info("TensorFlow SKIPPED")
            elif self.run_tf_frozen:
                tf_results = self.run_tensorflow(sess, inputs)
                logger.info("TensorFlow OK")
            tf_graph = sess.graph

    model_proto = None
    if self.skip_conversion:
        if self.large_model:
            external_tensor_storage = ExternalTensorStorage()
            model_proto = utils.model_proto_from_zip(self.converted_model, external_tensor_storage)
        else:
            external_tensor_storage = None
            model_proto = utils.model_proto_from_file(self.converted_model)
        logger.info("ONNX loaded from file")
    else:
        try:
            # convert model to onnx
            onnx_graph = self.to_onnx(tf_graph, opset=opset, extra_opset=extra_opset,
                                      shape_override=shape_override, input_names=inputs.keys(),
                                      const_node_values=const_node_values,
                                      initialized_tables=initialized_tables, tflite_path=tflite_path,
                                      tensors_to_rename=to_rename)
            onnx_graph = optimizer.optimize_graph(onnx_graph)
            print("ONNX", onnx_graph.dump_node_statistics())
            external_tensor_storage = ExternalTensorStorage() if self.large_model else None
            model_proto = onnx_graph.make_model("converted from tf2onnx",
                                                external_tensor_storage=external_tensor_storage)
            logger.info("To_ONNX, OK")
            if onnx_file:
                self.create_onnx_file(name, model_proto, inputs, onnx_file, external_tensor_storage)
            if self.converted_model:
                if self.large_model:
                    utils.save_onnx_zip(self.converted_model, model_proto, external_tensor_storage)
                else:
                    utils.save_protobuf(self.converted_model, model_proto)
                logger.info("Created %s", self.converted_model)

        except Exception:
            logger.error("To_ONNX FAIL", exc_info=1)
            return False

    try:
        onnx_results = None
        if backend == "caffe2":
            onnx_results = self.run_caffe2(name, model_proto, inputs)
        elif backend == "onnxruntime":
            if to_rename is None:
                struc_outputs = self.output_names
            else:
                struc_outputs = [to_rename.get(k, k) for k in self.output_names]
            onnx_results = self.run_onnxruntime(name, model_proto, inputs, struc_outputs, external_tensor_storage)
        else:
            raise ValueError("unknown backend")
        logger.info("Run_ONNX OK")

        try:
            if self.skip_tensorflow:
                logger.info("Results: skipped tensorflow")
            else:
                if self.check_only_shape:
                    for tf_res, onnx_res in zip(tf_results, onnx_results):
                        np.testing.assert_array_equal(tf_res.shape, onnx_res.shape)
                else:
                    for tf_res, onnx_res in zip(tf_results, onnx_results):
                        good_cnt = np.count_nonzero(np.isclose(tf_res, onnx_res, rtol=self.rtol, atol=self.atol))
                        bad_cnt = tf_res.size - good_cnt
                        # Tolerate up to ptol percent of mismatching elements.
                        if bad_cnt > self.ptol / 100 * tf_res.size:
                            # Prints a nice error message with stats
                            np.testing.assert_allclose(tf_res, onnx_res, rtol=self.rtol, atol=self.atol)
                logger.info("Results: OK")
            return True
        except Exception:
            logger.error("Results", exc_info=1)

    except Exception:
        logger.error("Run_ONNX FAIL", exc_info=1)

    return False