Пример #1
0
    def test_4D_multiple_consumers(self):
        """Check that L2NormToNorm still fuses when 'input_data' has an extra consumer.

        'input_data' also feeds 'result_2'. The expected graph keeps that edge
        and contains a NormalizeL2 node named 'l2_norm_name'.
        """
        input_shape = int64_array([1, 300, 300, 3])
        axes = int64_array([1, 2, 3])

        graph = build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
            ('result_2', dict(kind='op', op='Result'))
        ], edges + [('input_data', 'result_2')], nodes_with_edges_only=True)
        graph.stage = 'middle'

        L2NormToNorm().find_and_replace_pattern(graph)

        graph_ref = build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            # np.sort returns a sorted copy. ndarray.sort() sorts in place and
            # returns None, which made the expected value None instead of the axes.
            ('weights_node_data', dict(kind='data', value=np.sort(axes))),
            ('result_2', dict(kind='op', op='Result'))
        ], edges_after_replacement + [('input_data', 'result_2')], nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        # Assert the structural match first: if fusion did not happen, this
        # reports the graph diff instead of an IndexError from the lookup below.
        self.assertTrue(flag, resp)
        normalize_nodes = graph.get_nodes_with_attributes(type='NormalizeL2')
        self.assertEqual(graph.node[normalize_nodes[0]]['name'], 'l2_norm_name')
Пример #2
0
    def test_positive(self, input_shape, axes, layout):
        """Check that L2NormToNorm fuses the pattern for the given shape/axes/layout.

        The expected graph uses the sorted ``axes`` as the weights value and
        contains a NormalizeL2 node named 'l2_norm_name'.
        """
        graph = build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            ('square_data', dict(kind='data', shape=input_shape)),
            ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
        ], edges, nodes_with_edges_only=True)
        graph.stage = 'middle'
        graph.graph['layout'] = layout

        L2NormToNorm().find_and_replace_pattern(graph)

        graph_ref = build_graph_with_attrs(nodes + [
            ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
            ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
            # np.sort returns a sorted copy and leaves the (parametrized) ``axes``
            # untouched. ndarray.sort() would mutate it in place and return None.
            ('weights_node_data', dict(kind='data', value=np.sort(axes))),
        ], edges_after_replacement, nodes_with_edges_only=True)

        (flag, resp) = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        # Assert the structural match first: if fusion did not happen, this
        # reports the graph diff instead of an IndexError from the lookup below.
        self.assertTrue(flag, resp)
        normalize_nodes = graph.get_nodes_with_attributes(type='NormalizeL2')
        self.assertEqual(graph.node[normalize_nodes[0]]['name'], 'l2_norm_name')
Пример #3
0
    def test_negative(self, input_shape, axes, layout):
        """Check that L2NormToNorm leaves a non-matching graph unchanged.

        The graph after the transformation must equal a freshly built copy of
        the input graph.
        """
        def make_graph():
            # Build from fresh attribute dicts on each call, so the graph under
            # test and the reference never share attribute objects.
            extra_nodes = [
                ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
                ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
                ('square_data', dict(kind='data', shape=input_shape)),
                ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
            ]
            return build_graph_with_attrs(nodes + extra_nodes, edges, nodes_with_edges_only=True)

        graph = make_graph()
        graph.stage = 'middle'
        graph.graph['layout'] = layout

        L2NormToNorm().find_and_replace_pattern(graph)

        graph_ref = make_graph()

        flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)
Пример #4
0
    def test_4D_negative_4(self):
        """Check that a 4D input with sum axes [2, 0] is left unchanged by L2NormToNorm."""
        input_shape = int64_array([1, 300, 300, 3])
        axes = int64_array([2, 0])

        def build_unfused_graph():
            # The test expects the same topology before and after the pass, so one
            # builder serves both the graph under test and the reference.
            return build_graph_with_attrs(
                nodes + [
                    ('input', dict(kind='op', shape=input_shape, op='Parameter', data_type=np.float32)),
                    ('input_data', dict(kind='data', shape=input_shape, data_type=np.float32)),
                    ('square_data', dict(kind='data', shape=input_shape)),
                    ('sum_axes_data', dict(kind='data', value=axes, shape=None)),
                ],
                edges,
                nodes_with_edges_only=True,
            )

        graph = build_unfused_graph()
        graph.stage = 'middle'

        L2NormToNorm().find_and_replace_pattern(graph)

        graph_ref = build_unfused_graph()

        flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
        self.assertTrue(flag, resp)