예제 #1
0
    def test_interpolate_concat_reshape_graph_comparison(self):
        """Check the graph produced by InterpolateWithConcat for an Interpolate -> Concat pair.

        The source graph feeds ``interpolate`` from ``placeholder`` (port 0) and
        ``out_shape`` (port 1), and joins it with ``placeholder_1`` in ``concat``.
        After the pass, port 1 of ``interpolate`` is expected to be driven by a
        ``shape`` -> ``gather`` chain that starts at ``placeholder_1`` (with
        ``indices`` and ``axis`` as the other ``gather`` inputs) instead of by
        ``out_shape``.
        """
        source_links = (
            ('placeholder', '0:interpolate'),
            ('out_shape', '1:interpolate'),
            ('interpolate', '0:concat'),
            ('placeholder_1', '1:concat'),
            ('concat', 'output'),
        )
        source_edges = [edge for src, dst in source_links for edge in connect(src, dst)]
        graph = build_graph(nodes, source_edges, nodes_with_edges_only=True)

        InterpolateWithConcat().find_and_replace_pattern(graph)
        # cmd_params is attached after the pass and before clean_up();
        # presumably keep_shape_ops=True stops clean_up() from dropping the
        # newly inserted shape sub-graph -- TODO confirm against Graph.clean_up.
        graph.graph['cmd_params'] = Namespace(keep_shape_ops=True)
        graph.clean_up()

        # Each entry: (edge factory, source, destination). The second use of
        # placeholder_1 goes through connect_data rather than connect.
        expected_links = (
            (connect, 'placeholder', '0:interpolate'),
            (connect, 'placeholder_1', 'shape'),
            (connect, 'shape', '0:gather'),
            (connect, 'indices', '1:gather'),
            (connect, 'axis', '2:gather'),
            (connect, 'gather', '1:interpolate'),
            (connect, 'interpolate', '0:concat'),
            (connect_data, 'placeholder_1', '1:concat'),
            (connect, 'concat', 'output'),
        )
        expected_edges = [
            edge
            for make_edges, src, dst in expected_links
            for edge in make_edges(src, dst)
        ]
        graph_ref = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(is_equal, message)
    def test_interpolate_concat_negate(self):
        """Check that InterpolateWithConcat leaves a graph without Concat untouched.

        ``interpolate`` feeds two identity nodes (``identity_00`` and
        ``identity_01``), each ending in its own output. The transformed graph
        is compared against a graph built from the very same edges.
        """
        def make_edges():
            # Built anew on every call so the tested graph and the reference
            # graph never share edge objects.
            links = (
                ('placeholder', '0:interpolate'),
                ('out_shape', '1:interpolate'),
                ('interpolate', 'identity_00'),
                ('interpolate', 'identity_01'),
                ('identity_00', 'output'),
                ('identity_01', 'output_1'),
            )
            return [edge for src, dst in links for edge in connect(src, dst)]

        graph = build_graph(nodes, make_edges(), nodes_with_edges_only=True)
        InterpolateWithConcat().find_and_replace_pattern(graph)
        graph.clean_up()

        graph_ref = build_graph(nodes, make_edges(), nodes_with_edges_only=True)

        is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(is_equal, message)
예제 #3
0
    def test_negative_axes_conditions(self, update_attrs):
        """Check that InterpolateWithConcat does not change the graph for ``update_attrs``.

        ``update_attrs`` is forwarded as ``update_attributes`` to ``build_graph``
        for both the tested graph and the reference graph; judging by the test
        name these overrides presumably describe axes for which the pass must
        not fire -- confirm against the parametrisation of this test.
        """
        def make_graph():
            # Fresh edges and a fresh graph on every call.
            links = (
                ('placeholder', '0:interpolate'),
                ('out_shape', '1:interpolate'),
                ('interpolate', '0:concat'),
                ('placeholder_1', '1:concat'),
                ('concat', 'output'),
            )
            edges = [edge for src, dst in links for edge in connect(src, dst)]
            return build_graph(nodes, edges,
                               update_attributes=update_attrs,
                               nodes_with_edges_only=True)

        graph = make_graph()
        InterpolateWithConcat().find_and_replace_pattern(graph)

        # The reference is the same graph without the transformation applied.
        graph_ref = make_graph()

        is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
        self.assertTrue(is_equal, message)
 def test_interpolate_tf_style_concat(self):
     """Check that InterpolateWithConcat leaves a Concat carrying an ``N`` attribute untouched.

     The graph wires ``interpolate`` into port 0 and a node named ``N`` into
     port 1 of ``concat``, with ``concat`` overridden to ``{'N': 1}``. A copy
     taken before the pass serves as the reference.
     """
     links = (
         ('placeholder', '0:interpolate'),
         ('out_shape', '1:interpolate'),
         ('interpolate', '0:concat'),
         ('N', '1:concat'),
         ('concat', 'output'),
     )
     edges = [edge for src, dst in links for edge in connect(src, dst)]
     graph = build_graph(nodes, edges,
                         update_attributes={'concat': {'N': 1}},
                         nodes_with_edges_only=True)

     # Snapshot before the pass runs.
     # NOTE(review): if graph.copy() is a shallow copy, attribute values are
     # shared with `graph`, so in-place attribute mutation by the pass would
     # go unnoticed here -- confirm.
     graph_ref = graph.copy()
     InterpolateWithConcat().find_and_replace_pattern(graph)

     is_equal, message = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
     self.assertTrue(is_equal, message)