# Example #1
# 0
    def test_replase_eltwise_n_3(self):
        """replace_op on a four-input 'sum' EltwiseN should leave three 'Eltwise' nodes."""
        # Attributes shared by every pass-through node around add_n; each node
        # gets its own copy so build_graph never sees a shared dict.
        identity_attrs = {'type': 'Identity', 'value': None, 'kind': 'op'}
        nodes = {
            'node_1': dict(identity_attrs, op='Placeholder'),
            'node_2': dict(identity_attrs),
            'node_3': dict(identity_attrs),
            'node_4': dict(identity_attrs),
            'node_5': dict(identity_attrs),
            'add_n': {
                'value': None,
                'operation': 'sum',
                'type': None,
                'kind': 'op',
                'op': 'EltwiseN',
            },
            'node_6': dict(identity_attrs),
        }

        # node_2 .. node_5 all feed add_n, giving it four inputs.
        edges = [('node_1', 'node_2')]
        edges += [('node_{}'.format(idx), 'add_n') for idx in range(2, 6)]
        edges.append(('add_n', 'node_6'))
        graph = build_graph(nodes, edges)

        target = Node(graph, 'add_n')
        replacer = EltwiseNReplacement()
        replacer.replace_op(graph, target)

        # Count nodes whose 'type' attribute is exactly 'Eltwise'.
        eltwise_count = sum(
            1 for _, attrs in graph.nodes(data=True) if attrs['type'] == 'Eltwise'
        )
        self.assertEqual(eltwise_count, 3)
# Example #2
# 0
    def test_eltwise_test_2(self):
        """EltwiseN over 3 placeholders with operation 'mul' should match a chain of two Multiply nodes."""

        def placeholder_attrs():
            # Build a fresh dict (and fresh ndarrays) on every call so the
            # source and reference graphs never share shape objects.
            return {
                'placeholder_{}'.format(idx): {'shape': np.array([1, 227, 227, 3])}
                for idx in (1, 2, 3)
            }

        # Graph under test: three placeholders feeding one EltwiseN node.
        source_edges = [
            ('placeholder_1', 'EltwiseN_1'),
            ('placeholder_2', 'EltwiseN_1'),
            ('placeholder_3', 'EltwiseN_1'),
            ('EltwiseN_1', 'last'),
        ]
        source_attrs = placeholder_attrs()
        source_attrs['EltwiseN_1'] = {'operation': 'mul'}
        graph = build_graph(nodes_attributes, source_edges, source_attrs,
                            nodes_with_edges_only=True)

        # Reference graph: (placeholder_1 * placeholder_2) * placeholder_3 -> last.
        expected_edges = [
            ('placeholder_1', 'mul_1'),
            ('placeholder_2', 'mul_1'),
            ('mul_1', 'mul_2'),
            ('placeholder_3', 'mul_2'),
            ('mul_2', 'last'),
        ]
        expected_attrs = placeholder_attrs()
        expected_attrs['mul_1'] = {'type': 'Multiply'}
        expected_attrs['mul_2'] = {'type': 'Multiply'}
        graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs,
                                nodes_with_edges_only=True)

        graph.stage = 'front'
        EltwiseNReplacement().find_and_replace_pattern(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
        self.assertTrue(flag, resp)