Beispiel #1
0
    def test_interpolate_concat_reshape_graph_comparison(self):
        """Check the rewritten graph when an Interpolate feeds port 0 of a Concat.

        Before the pass, the Interpolate's second input is the node
        ``out_shape``. After ``InterpolateWithConcat`` and ``clean_up`` the
        expected graph instead feeds that input from
        ``placeholder_1 -> shape -> gather`` (with ``indices`` and ``axis`` as
        the other Gather inputs), while ``placeholder_1`` still feeds the Concat.
        """
        source_edges = [
            *connect('placeholder', '0:interpolate'),
            *connect('out_shape', '1:interpolate'),
            *connect('interpolate', '0:concat'),
            *connect('placeholder_1', '1:concat'),
            *connect('concat', 'output'),
        ]
        actual = build_graph(nodes, source_edges, nodes_with_edges_only=True)

        InterpolateWithConcat().find_and_replace_pattern(actual)
        actual.clean_up()

        # placeholder_1 has two consumers in the expected graph ('shape' and the
        # Concat); connect_data is presumably used so both share one data node
        # -- confirm against the connect/connect_data helpers.
        expected_edges = [
            *connect('placeholder', '0:interpolate'),
            *connect('placeholder_1', 'shape'),
            *connect('shape', '0:gather'),
            *connect('indices', '1:gather'),
            *connect('axis', '2:gather'),
            *connect('gather', '1:interpolate'),
            *connect('interpolate', '0:concat'),
            *connect_data('placeholder_1', '1:concat'),
            *connect('concat', 'output'),
        ]
        expected = build_graph(nodes, expected_edges, nodes_with_edges_only=True)

        is_same, diff = compare_graphs(actual, expected, 'output', check_op_attrs=True)
        self.assertTrue(is_same, diff)
Beispiel #2
0
    def test_negative_axes_conditions(self, update_attrs):
        """Check that the pass leaves the graph unchanged for ``update_attrs``.

        The same Interpolate -> Concat topology (with ``update_attrs`` applied)
        is built twice: once to run ``InterpolateWithConcat`` on, and once
        afterwards as the reference. The two must compare equal. Presumably
        ``update_attrs`` supplies axis settings the pass must not rewrite (per
        the test name) -- confirm against the test's parameterization.
        """
        def make_graph():
            # Fresh edge list and graph on every call; nothing is shared.
            edges = [
                *connect('placeholder', '0:interpolate'),
                *connect('out_shape', '1:interpolate'),
                *connect('interpolate', '0:concat'),
                *connect('placeholder_1', '1:concat'),
                *connect('concat', 'output'),
            ]
            return build_graph(nodes, edges,
                               update_attributes=update_attrs,
                               nodes_with_edges_only=True)

        transformed = make_graph()
        InterpolateWithConcat().find_and_replace_pattern(transformed)
        reference = make_graph()

        matched, report = compare_graphs(transformed, reference, 'output',
                                         check_op_attrs=True)
        self.assertTrue(matched, report)
Beispiel #3
0
    def test_interpolate_concat_negate(self):
        """Check that an Interpolate with no Concat consumer is left as-is.

        The Interpolate fans out to two Identity branches ending in ``output``
        and ``output_1``. After the pass and ``clean_up`` the graph must still
        match a freshly built copy of that same topology.
        """
        def make_fan_out_graph():
            # Interpolate -> two independent Identity -> Output branches.
            edges = [
                *connect('placeholder', '0:interpolate'),
                *connect('out_shape', '1:interpolate'),
                *connect('interpolate', 'identity_00'),
                *connect('interpolate', 'identity_01'),
                *connect('identity_00', 'output'),
                *connect('identity_01', 'output_1'),
            ]
            return build_graph(nodes, edges, nodes_with_edges_only=True)

        subject = make_fan_out_graph()
        InterpolateWithConcat().find_and_replace_pattern(subject)
        subject.clean_up()
        baseline = make_fan_out_graph()

        equal, details = compare_graphs(subject, baseline, 'output',
                                        check_op_attrs=True)
        self.assertTrue(equal, details)
Beispiel #4
0
 def test_interpolate_tf_style_concat(self):
     """Check that a Concat carrying an ``N`` attribute is not rewritten.

     ``N`` is set as a Concat *attribute* (``N == 1``), not wired in as a
     graph node, so the Concat gets a single data input from the Interpolate.
     The reference graph is built independently after the pass, instead of
     with ``graph.copy()`` beforehand. A networkx copy shares nested attribute
     values, so an in-place mutation by the pass could otherwise appear in
     both graphs and go unnoticed.
     """
     def make_tf_style_graph():
         # Fresh edges and attributes per call so the two graphs share nothing.
         edges = [
             *connect('placeholder', '0:interpolate'),
             *connect('out_shape', '1:interpolate'),
             *connect('interpolate', '0:concat'),
             *connect('concat', 'output'),
         ]
         return build_graph(nodes, edges,
                            update_attributes={'concat': {'N': 1}},
                            nodes_with_edges_only=True)

     graph = make_tf_style_graph()
     InterpolateWithConcat().find_and_replace_pattern(graph)
     graph_ref = make_tf_style_graph()

     (flag, resp) = compare_graphs(graph,
                                   graph_ref,
                                   'output',
                                   check_op_attrs=True)
     self.assertTrue(flag, resp)