def test_interpolate_concat_reshape_graph_comparison(self):
    """InterpolateWithConcat should rewire Interpolate's second input.

    Input graph:
        placeholder -> interpolate(0); out_shape -> interpolate(1);
        interpolate -> concat(0); placeholder_1 -> concat(1); concat -> output.

    Expected graph after the transformation and clean-up:
        the `out_shape` input of `interpolate` is replaced by the
        placeholder_1 -> shape -> gather(indices, axis) sub-graph, while
        placeholder_1 still feeds concat's second input.
    """
    source_edges = [
        *connect('placeholder', '0:interpolate'),
        *connect('out_shape', '1:interpolate'),
        *connect('interpolate', '0:concat'),
        *connect('placeholder_1', '1:concat'),
        *connect('concat', 'output'),
    ]
    actual = build_graph(nodes, source_edges, nodes_with_edges_only=True)

    InterpolateWithConcat().find_and_replace_pattern(actual)
    # NOTE(review): keep_shape_ops=True is set right before clean-up, presumably
    # so the shape-computing nodes added by the transformation survive — confirm.
    actual.graph['cmd_params'] = Namespace(keep_shape_ops=True)
    actual.clean_up()

    expected_edges = [
        *connect('placeholder', '0:interpolate'),
        *connect('placeholder_1', 'shape'),
        *connect('shape', '0:gather'),
        *connect('indices', '1:gather'),
        *connect('axis', '2:gather'),
        *connect('gather', '1:interpolate'),
        *connect('interpolate', '0:concat'),
        # placeholder_1 is already wired above (to `shape`); connect_data presumably
        # reuses that existing output tensor instead of creating a second one —
        # verify against the test-graph helpers.
        *connect_data('placeholder_1', '1:concat'),
        *connect('concat', 'output'),
    ]
    expected = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

    flag, resp = compare_graphs(actual, expected, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_interpolate_concat_negate(self):
    """Without a `concat` consumer, InterpolateWithConcat must not change the graph.

    `interpolate` fans out into two branches (identity_00 -> output and
    identity_01 -> output_1). After running the transformation and clean-up,
    the graph must still match a freshly built copy of the same topology.
    """
    def make_graph():
        # Built twice, once for the transformed graph and once for the reference,
        # so each call produces fresh edge data.
        return build_graph(nodes, [
            *connect('placeholder', '0:interpolate'),
            *connect('out_shape', '1:interpolate'),
            *connect('interpolate', 'identity_00'),
            *connect('interpolate', 'identity_01'),
            *connect('identity_00', 'output'),
            *connect('identity_01', 'output_1'),
        ], nodes_with_edges_only=True)

    actual = make_graph()
    InterpolateWithConcat().find_and_replace_pattern(actual)
    actual.clean_up()

    flag, resp = compare_graphs(actual, make_graph(), 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_negative_axes_conditions(self, update_attrs):
    """InterpolateWithConcat must leave the graph untouched for these attributes.

    `update_attrs` is applied to the nodes of a
    placeholder -> interpolate -> concat -> output graph. It is presumably a
    parametrized set of attribute overrides under which the pattern must not
    fire — confirm against the test's parameter generator.
    No clean-up is run here; the graph is compared directly after the pass.
    """
    def make_graph():
        # Same topology and overrides for both the transformed and reference graphs.
        return build_graph(nodes, [
            *connect('placeholder', '0:interpolate'),
            *connect('out_shape', '1:interpolate'),
            *connect('interpolate', '0:concat'),
            *connect('placeholder_1', '1:concat'),
            *connect('concat', 'output'),
        ], update_attributes=update_attrs, nodes_with_edges_only=True)

    actual = make_graph()
    InterpolateWithConcat().find_and_replace_pattern(actual)

    flag, resp = compare_graphs(actual, make_graph(), 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_interpolate_tf_style_concat(self):
    """InterpolateWithConcat must not change a concat whose second input is `N`.

    The `concat` node gets attribute N=1 and its port 1 is fed by node `N`
    (presumably modelling a TF-style Concat — confirm). The graph is
    snapshotted before the pass and must compare equal afterwards.
    """
    edges = [
        *connect('placeholder', '0:interpolate'),
        *connect('out_shape', '1:interpolate'),
        *connect('interpolate', '0:concat'),
        *connect('N', '1:concat'),
        *connect('concat', 'output'),
    ]
    actual = build_graph(nodes, edges, update_attributes={'concat': {'N': 1}},
                         nodes_with_edges_only=True)
    # Snapshot taken before the transformation runs; it serves as the reference.
    snapshot = actual.copy()

    InterpolateWithConcat().find_and_replace_pattern(actual)

    flag, resp = compare_graphs(actual, snapshot, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)