def test_replase_eltwise_n_3(self):
    """An EltwiseN(sum) fed by four inputs is unrolled into three Eltwise nodes.

    ``replace_op`` is invoked directly on the ``add_n`` node (no pattern
    matching), and the test counts graph nodes whose ``type`` is 'Eltwise'.
    """
    def identity(**extra):
        # Common attribute set shared by every Identity op in this graph.
        return dict(type='Identity', value=None, kind='op', **extra)

    nodes = {
        'node_1': identity(op='Placeholder'),
        'node_2': identity(),
        'node_3': identity(),
        'node_4': identity(),
        'node_5': identity(),
        'add_n': {
            'value': None,
            'operation': 'sum',
            'type': None,
            'kind': 'op',
            'op': 'EltwiseN',
        },
        'node_6': identity(),
    }
    summands = ('node_2', 'node_3', 'node_4', 'node_5')
    edges = (
        [('node_1', 'node_2')]
        + [(src, 'add_n') for src in summands]
        + [('add_n', 'node_6')]
    )
    graph = build_graph(nodes, edges)

    replacer = EltwiseNReplacement()
    replacer.replace_op(graph, Node(graph, 'add_n'))

    # Four summands need three binary Eltwise ops after unrolling.
    eltwise_count = sum(
        1 for _, attrs in graph.nodes(data=True) if attrs['type'] == 'Eltwise'
    )
    self.assertEqual(eltwise_count, 3)
def test_eltwise_test_2(self):
    """EltwiseN with N = 3 inputs and operation='mul' becomes two chained Multiply ops.

    The front-stage replacer must turn ``p1, p2, p3 -> EltwiseN_1 -> last``
    into ``(p1 * p2) * p3 -> last``; the result is compared against a
    hand-built reference graph with ``check_op_attrs=True``.
    """
    placeholders = ('placeholder_1', 'placeholder_2', 'placeholder_3')

    def shaped_inputs(extra):
        # A fresh shape array per placeholder, then any extra node attributes.
        attrs = {name: {'shape': np.array([1, 227, 227, 3])} for name in placeholders}
        attrs.update(extra)
        return attrs

    source_edges = [
        ('placeholder_1', 'EltwiseN_1'),
        ('placeholder_2', 'EltwiseN_1'),
        ('placeholder_3', 'EltwiseN_1'),
        ('EltwiseN_1', 'last'),
    ]
    graph = build_graph(
        nodes_attributes,
        source_edges,
        shaped_inputs({'EltwiseN_1': {'operation': 'mul'}}),
        nodes_with_edges_only=True,
    )

    expected_edges = [
        ('placeholder_1', 'mul_1'),
        ('placeholder_2', 'mul_1'),
        ('mul_1', 'mul_2'),
        ('placeholder_3', 'mul_2'),
        ('mul_2', 'last'),
    ]
    graph_ref = build_graph(
        nodes_attributes,
        expected_edges,
        shaped_inputs({'mul_1': {'type': 'Multiply'}, 'mul_2': {'type': 'Multiply'}}),
        nodes_with_edges_only=True,
    )

    graph.stage = 'front'
    EltwiseNReplacement().find_and_replace_pattern(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'last', check_op_attrs=True)
    self.assertTrue(flag, resp)