def do_test_expected(self):
    """Check that a TF op and its ONNX-converted counterpart agree.

    The test case is read from ``test_data``, a name resolved from the
    enclosing scope (presumably bound by a test-factory closure -- confirm
    against the caller):

    * ``test_data[1]`` -- callable that builds the TF op from positional
      inputs and keyword attributes.
    * ``test_data[2]`` -- name of the graph node / backend output to compare.
    * ``test_data[3]`` -- positional inputs. ``np.ndarray`` inputs are fed
      through placeholders named ``in_<idx>``; anything else is passed to the
      op unchanged.
    * ``test_data[4]`` -- keyword attributes for the op.

    The TF graph is converted with ``convert_graph``, run on the ONNX backend,
    run again in a TF session, and both results are compared with
    ``np.testing.assert_allclose``.
    """
    tf_op = test_data[1]
    output_name = test_data[2]
    inputs = test_data[3]
    attrs = test_data[4]

    # Start from a clean default graph. If an earlier case raised before its
    # cleanup ran, its leftover nodes would otherwise leak into this graph and
    # the placeholders below would be uniquified (e.g. "in_0_1"), so the ONNX
    # feed-dict keys would no longer match the graph inputs.
    tf.reset_default_graph()
    try:
        # Keyed by input name (ONNX side).
        onnx_feed_dict = {}
        # Keyed by placeholder op (TF side).
        tf_feed_dict = {}
        tf_param_list = []
        for idx, input_tensor in enumerate(inputs):
            if isinstance(input_tensor, np.ndarray):
                input_name = "in_" + str(idx)
                placeholder = tf.placeholder(input_tensor.dtype,
                                             shape=input_tensor.shape,
                                             name=input_name)
                onnx_feed_dict[input_name] = input_tensor
                tf_feed_dict[placeholder] = input_tensor
                tf_param_list.append(placeholder)
            else:
                tf_param_list.append(input_tensor)
        test_op = tf_op(*tf_param_list, **attrs)
        # Read the default graph rather than ``test_op.graph`` so ops that
        # return a list/tuple of tensors are supported as well.
        tf_graph = tf.get_default_graph().as_graph_def(add_shapes=True)

        # Construct onnx graph, run with backend.
        output_node = get_node_by_name(tf_graph.node, output_name)
        onnx_graph = convert_graph(tf_graph, output_node)
        backend_rep = prepare(helper.make_model(onnx_graph))
        backend_output = backend_rep.run(onnx_feed_dict)[output_name]

        with tf.Session() as sess:
            tf_output = sess.run(test_op, tf_feed_dict)
    finally:
        # Always leave a clean default graph behind, even when a step above
        # raised.
        tf.reset_default_graph()

    np.testing.assert_allclose(backend_output, tf_output)
def do_test_expected(self):
    """Check that a TF op and its ONNX-converted counterpart agree.

    The test case is read from ``test_data`` and ``test_option``, names
    resolved from the enclosing scope (presumably bound by a test-factory
    closure -- confirm against the caller):

    * ``test_data[1]`` -- callable that builds the TF op from positional
      inputs and keyword attributes.
    * ``test_data[2]`` -- name of the graph node / backend output to compare.
    * ``test_data[3]`` -- positional inputs. ``np.ndarray`` inputs are fed
      through placeholders named ``in_<idx>``; anything else is passed to the
      op unchanged.
    * ``test_data[4]`` -- keyword attributes for the op.
    * ``test_data[5]`` -- ``device`` forwarded to ``convert_graph`` and
      ``prepare``.
    * ``test_data[6]`` -- ``channel_last`` forwarded to ``convert_graph`` and
      ``prepare``.
    * ``test_option["call_only"]`` -- when truthy, both graphs are still run
      but their outputs are not compared.
    """
    tf_op = test_data[1]
    output_name = test_data[2]
    inputs = test_data[3]
    attrs = test_data[4]
    device = test_data[5]
    channel_last = test_data[6]

    # Start from a clean default graph. If an earlier case raised before its
    # cleanup ran, its leftover nodes would otherwise leak into this graph and
    # the placeholders below would be uniquified (e.g. "in_0_1"), so the ONNX
    # feed-dict keys would no longer match the graph inputs.
    tf.reset_default_graph()
    try:
        # Keyed by input name (ONNX side).
        onnx_feed_dict = {}
        # Keyed by placeholder op (TF side).
        tf_feed_dict = {}
        tf_param_list = []
        for idx, input_tensor in enumerate(inputs):
            if isinstance(input_tensor, np.ndarray):
                input_name = "in_" + str(idx)
                placeholder = tf.placeholder(input_tensor.dtype,
                                             shape=input_tensor.shape,
                                             name=input_name)
                # The same array is fed, unconverted, to both TF and the ONNX
                # backend. NOTE(review): presumably ``channel_last`` makes the
                # converter/backend treat it in TF's (NHWC) layout -- confirm.
                onnx_feed_dict[input_name] = input_tensor
                tf_feed_dict[placeholder] = input_tensor
                tf_param_list.append(placeholder)
            else:
                tf_param_list.append(input_tensor)
        test_op = tf_op(*tf_param_list, **attrs)
        tf_graph = tf.get_default_graph().as_graph_def(add_shapes=True)

        # Construct onnx graph, run with backend.
        output_node = get_node_by_name(tf_graph.node, output_name)
        onnx_graph = convert_graph(tf_graph, output_node,
                                   device=device, channel_last=channel_last)
        onnx_model = helper.make_model(onnx_graph)
        backend_rep = prepare(onnx_model, device=device,
                              channel_last=channel_last)
        backend_output = backend_rep.run(onnx_feed_dict)[output_name]

        with tf.Session() as sess:
            tf_output = sess.run(test_op, tf_feed_dict)
    finally:
        # Always leave a clean default graph behind, even when a step above
        # raised.
        tf.reset_default_graph()

    # Skip the comparison when the test option marks this as call-only.
    if test_option.get("call_only", False):
        return
    np.testing.assert_allclose(backend_output, tf_output)
def do_test_expected(self):
    """Check that a TF op and its ONNX-converted counterpart agree, output by output.

    The test case is read from ``test_data`` and ``test_option``, names
    resolved from the enclosing scope (presumably bound by a test-factory
    closure -- confirm against the caller):

    * ``test_data[1]`` -- callable that builds the TF op (or ops) from
      positional inputs and keyword attributes.
    * ``test_data[2]`` -- name of the graph node handed to ``convert_graph``.
    * ``test_data[3]`` -- positional inputs. ``np.ndarray`` inputs are fed
      through placeholders named ``in_<idx>``; anything else is passed to the
      op unchanged.
    * ``test_data[4]`` -- keyword attributes for the op.
    * ``test_option["call_only"]`` -- when truthy, both graphs are still run
      but their outputs are not compared.

    Every output listed in ``backend_rep.predict_net.external_output`` is
    compared, in order, against the corresponding TF result. A mismatch in
    the number of outputs is reported as a failure.

    Raises:
        AssertionError: if the output counts differ or any pair of outputs is
            not close.
    """
    # Start from a clean default graph so nodes from earlier cases cannot
    # leak in and rename the "in_<idx>" placeholders.
    tf.reset_default_graph()
    tf_op = test_data[1]
    output_name = test_data[2]
    inputs = test_data[3]
    attrs = test_data[4]

    # Keyed by input name (ONNX side).
    onnx_feed_dict = {}
    # Keyed by placeholder op (TF side).
    tf_feed_dict = {}
    tf_param_list = []
    for idx, input_tensor in enumerate(inputs):
        if isinstance(input_tensor, np.ndarray):
            input_name = "in_" + str(idx)
            placeholder = tf.placeholder(input_tensor.dtype,
                                         shape=input_tensor.shape,
                                         name=input_name)
            onnx_feed_dict[input_name] = input_tensor
            tf_feed_dict[placeholder] = input_tensor
            tf_param_list.append(placeholder)
        else:
            tf_param_list.append(input_tensor)
    test_op = tf_op(*tf_param_list, **attrs)
    tf_graph = tf.get_default_graph().as_graph_def(add_shapes=True)

    # Construct onnx graph, run with backend.
    output_node = get_node_by_name(tf_graph.node, output_name)
    onnx_graph = convert_graph(tf_graph, output_node)
    onnx_model = helper.make_model(onnx_graph)
    backend_rep = prepare(onnx_model)
    backend_rep_outputs = backend_rep.run(onnx_feed_dict)
    # Keep the outputs as a list instead of stacking them with np.asarray:
    # stacking fails for outputs of different shapes, and squeezing and then
    # zipping a single array compared rows while silently ignoring shape and
    # length mismatches.
    backend_outputs = [backend_rep_outputs[ext_output]
                       for ext_output in backend_rep.predict_net.external_output]

    with tf.Session() as sess:
        tf_output = sess.run(test_op, tf_feed_dict)

    # Skip the comparison when the test option marks this as call-only.
    if test_option.get("call_only", False):
        return

    # sess.run returns a list/tuple when fetching a list/tuple of tensors and
    # a single value otherwise; normalize to a list.
    if isinstance(tf_output, (list, tuple)):
        tf_outputs = list(tf_output)
    else:
        tf_outputs = [tf_output]
    if len(backend_outputs) != len(tf_outputs):
        raise AssertionError(
            "backend produced %d output(s) but TF produced %d"
            % (len(backend_outputs), len(tf_outputs)))
    for backend_o, tf_o in zip(backend_outputs, tf_outputs):
        # Compares whole arrays, so shape mismatches are caught as well.
        np.testing.assert_allclose(backend_o, tf_o)