def test_conversion(self, input_shape, scales, axes):
    """Check that UpsampleToResample rewrites the input graph into the reference graph.

    The input graph is built from ``graph_node_attrs``/``graph_edges`` and given
    an NCHW layout. After ``UpsampleToResample().find_and_replace_pattern`` runs
    on it, it must compare equal (starting from the ``'output'`` node) to a
    reference graph built from ``new_ref_graph_node_attr``/``new_ref_graph_edges``.

    :param input_shape: shape assigned to the ``placeholder_data`` node.
    :param scales: values for the ``scales``/``scales_data`` nodes; the reference
        ``factor``/``factor_data`` nodes use ``scales[2:]``.
    :param axes: values for the reference ``axes_const`` node; its first entry
        and its last entry + 1 become the ``ss_begin``/``ss_end`` constants.
    """
    # Each helper builds new arrays on every call, so no two nodes (and no node
    # in both graphs) share the same numpy buffer.
    def placeholder_attrs():
        return {'shape': int64_array(input_shape)}

    def const_attrs(values):
        return {'value': int64_array(values), 'shape': int64_array(values).shape}

    def scaled_shape_attrs():
        # Element-wise product of the input shape and the scales.
        return {'shape': int64_array(input_shape) * int64_array(scales)}

    def factor_attrs():
        # The reference factor keeps only the scales from index 2 onwards.
        return {'value': int64_array(scales)[2:], 'shape': int64_array(scales[2:]).shape}

    graph = build_graph(graph_node_attrs, graph_edges, {
        'placeholder_data': placeholder_attrs(),
        'scales': const_attrs(scales),
        'scales_data': const_attrs(scales),
        'upsample_data': scaled_shape_attrs(),
    })
    graph.graph['layout'] = 'NCHW'

    begin = [axes[0]]
    end = [axes[-1] + 1]
    ref_graph = build_graph(new_ref_graph_node_attr, new_ref_graph_edges, {
        'placeholder_data': placeholder_attrs(),
        'ss_begin': {'value': int64_array(begin)},
        'ss_end': {'value': int64_array(end)},
        'ss_begin_data': {'value': int64_array(begin)},
        'ss_end_data': {'value': int64_array(end)},
        'factor': factor_attrs(),
        'factor_data': factor_attrs(),
        'axes_const': const_attrs(axes),
        'interpolate_data': scaled_shape_attrs(),
    })

    UpsampleToResample().find_and_replace_pattern(graph)
    flag, resp = compare_graphs(graph, ref_graph, 'output')
    self.assertTrue(flag, resp)
    def test_pattern_does_not_satisfy(self, input_shape, scales):
        """Check that UpsampleToResample leaves the graph unchanged for these inputs.

        Two identical graphs are built from ``graph_node_attrs``/``graph_edges``.
        Only the first one (given an NCHW layout) is run through
        ``UpsampleToResample().find_and_replace_pattern``. It must still compare
        equal, from the ``'output'`` node, to the untouched second copy. Per the
        test name, the inputs are expected not to satisfy the replacement pattern.

        :param input_shape: shape assigned to the ``placeholder_data`` node.
        :param scales: values for the ``scales``/``scales_data`` nodes.
        """
        def make_graph():
            # Every call creates new arrays, so the two graphs never share data.
            return build_graph(graph_node_attrs, graph_edges, {
                'placeholder_data': {'shape': int64_array(input_shape)},
                'scales': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
                'scales_data': {'value': int64_array(scales), 'shape': int64_array(scales).shape},
                'upsample_data': {'shape': int64_array(input_shape) * int64_array(scales)},
            })

        graph = make_graph()
        graph.graph['layout'] = 'NCHW'
        ref_graph = make_graph()

        UpsampleToResample().find_and_replace_pattern(graph)
        flag, resp = compare_graphs(graph, ref_graph, 'output')
        self.assertTrue(flag, resp)
Example #3
0
    def test_conversion(self, input_shape, scales):
        """Check that UpsampleToResample rewrites the input graph into the reference graph.

        The input graph is built from ``graph_node_attrs``/``graph_edges`` and
        given an NCHW layout. After the transformation it must compare equal
        (from the ``'output'`` node) to a reference built from
        ``ref_graph_node_attrs``/``ref_graph_edges``. In that reference, the
        ``factor`` node holds ``scales[2:]`` and the ``interpolate`` node's
        ``axes`` attribute lists dimensions ``2 .. len(input_shape) - 1``.

        :param input_shape: shape assigned to the ``placeholder_data`` node.
        :param scales: values for the ``scales``/``scales_data`` nodes.
        """
        # Helpers return freshly built arrays on each call, so nodes never
        # share a numpy buffer.
        def placeholder_attrs():
            return {'shape': int64_array(input_shape)}

        def scales_attrs():
            return {'value': int64_array(scales), 'shape': int64_array(scales).shape}

        def scaled_shape_attrs():
            # Element-wise product of the input shape and the scales.
            return {'shape': int64_array(input_shape) * int64_array(scales)}

        graph = build_graph(graph_node_attrs, graph_edges, {
            'placeholder_data': placeholder_attrs(),
            'scales': scales_attrs(),
            'scales_data': scales_attrs(),
            'upsample_data': scaled_shape_attrs(),
        })
        graph.graph['layout'] = 'NCHW'

        interpolate_axes = list(range(2, len(input_shape)))
        ref_graph = build_graph(ref_graph_node_attrs, ref_graph_edges, {
            'placeholder_data': placeholder_attrs(),
            'factor': {'value': int64_array(scales)[2:], 'shape': int64_array(scales[2:]).shape},
            'interpolate_data': scaled_shape_attrs(),
            'interpolate': {'axes': interpolate_axes},
        })

        UpsampleToResample().find_and_replace_pattern(graph)
        flag, resp = compare_graphs(graph, ref_graph, 'output')
        self.assertTrue(flag, resp)