Ejemplo n.º 1
0
    def test_caffe_bn_decomposition_2(self):
        """Decompose BatchNormalization into Mul + Add with can_be_fused=False.

        The BN node is marked ``can_be_fused: False``; the expected Mul and Add
        nodes carry the same flag. For mean = var = [1, 2, 3] and eps = 1.2 the
        expected weights are 1/sqrt(var + eps) for Mul and -mean/sqrt(var + eps)
        for Add.
        """
        input_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'bn_op'),
            ('bn_mean', 'bn_op'),
            ('bn_var', 'bn_op'),
            ('bn_op', 'bn_data'),
            ('concat', 'concat_data'),
            ('bn_data', 'concat'),
        ]
        input_updates = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'bn_op': {'epsilon': 1.2, 'op': 'BatchNormalization', 'can_be_fused': False},
            'bn_mean': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
            'bn_var': {'shape': np.array([3]), 'value': np.array([1, 2, 3])},
            'bn_data': {'shape': np.array([1, 227, 227, 3])},
            'concat_data': {'is_output': True},
        }
        graph = build_graph(nodes_attributes, input_edges, input_updates)

        # Strip the 'in' attribute from these two edges of the input graph.
        for src, dst in (('placeholder_1', 'placeholder_1_data'), ('bn_op', 'bn_data')):
            del graph[src][dst][0]['in']

        ref_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'mul_1'),
            ('mul_1_w', 'mul_1'),
            ('mul_1', 'mul_1_data'),
            ('mul_1_data', 'add_1'),
            ('add_1_w', 'add_1'),
            ('add_1', 'add_1_data'),
            ('concat', 'concat_data'),
            ('add_1_data', 'concat'),
        ]
        ref_updates = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'mul_1_w': {
                'shape': np.array([3]),
                'value': np.array([0.67419986, 0.55901699, 0.48795004]),
            },
            'add_1_w': {
                'shape': np.array([3]),
                'value': np.array([-0.67419986, -1.11803399, -1.46385011]),
            },
            'add_1_data': {'shape': np.array([1, 227, 227, 3])},
            'mul_1': {'can_be_fused': False},
            'add_1': {'can_be_fused': False},
            'concat_data': {'is_output': True},
        }
        graph_ref = build_graph(nodes_attributes, ref_edges, ref_updates)

        graph.graph['layout'] = 'NHWC'
        convert_bn_to_mul_add(graph)
        graph_clean_up(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'concat_data')
        self.assertTrue(flag, resp)
Ejemplo n.º 2
0
 def _create_node(attrs: dict):
     """Return a ``Node`` wrapping an ONNX ConvTranspose proto.

     The proto has inputs ``X``, ``W`` and output ``Y`` and is built with the
     given ``attrs``. It lives in a single-node graph under the id 'node_0'.
     """
     node_id = 'node_0'
     proto = onnx.helper.make_node("ConvTranspose", ["X", "W"], ["Y"], **attrs)
     return Node(build_graph({node_id: {'pb': proto}}, []), node_id)
    def test(self):
        """Check that SwitchMergeOptimization replaces a Switch/Merge construct with Select.

        Three Switch nodes share one predicate ('switches_input', Switch port 1).
        'switch' and 'switch_1' feed 'some_op' -> 'identity' -> 'merge', while
        'switch_2' feeds 'merge' directly. In the expected graph all Switch and
        Merge nodes are gone: the Switch data inputs connect straight to
        'some_op', and a Select takes the predicate (port 0), 'switch_2_input'
        (port 1) and the 'identity' output (port 2).
        """
        nodes_attributes = {
            'switch_2_input': {
                'shape': int64_array([1, 3]),
                'type': 'Parameter',
                'kind': 'op',
                'op': 'Parameter'
            },
            'switches_input': {
                'shape': int64_array([1, 3]),
                'type': 'Parameter',
                'kind': 'op',
                'op': 'Parameter'
            },
            'switch_input_0': {
                'kind': 'op',
                'op': 'SomeOp'
            },
            'switch_1_input_0': {
                'kind': 'op',
                'op': 'SomeOp'
            },
            'switch': {
                'kind': 'op',
                'op': 'Switch'
            },
            'switch_1': {
                'kind': 'op',
                'op': 'Switch'
            },
            'switch_2': {
                'kind': 'op',
                'op': 'Switch'
            },
            'some_op': {
                'kind': 'op',
                'op': 'Max'
            },
            'identity': {
                'kind': 'op',
                'op': 'Identity'
            },
            'merge': {
                'kind': 'op',
                'op': 'Merge'
            },
            'select': {
                'kind': 'op',
                'op': 'Select'
            },
            'last': {
                'type': None,
                'value': None,
                'kind': 'op',
                'op': 'Result'
            },
        }

        # Check both cases: switch_2 feeding Merge input port 0 and input port 1.
        for merge_input_port in range(2):
            # Input graph: port 0 of each Switch is its data input, port 1 the
            # shared predicate; switch_2 and identity occupy complementary Merge ports.
            graph = build_graph(nodes_attributes, [
                ('switch_2_input', 'switch_2', {
                    'in': 0
                }),
                ('switch_input_0', 'switch', {
                    'in': 0
                }),
                ('switch_1_input_0', 'switch_1', {
                    'in': 0
                }),
                ('switches_input', 'switch', {
                    'in': 1,
                    'out': 0
                }),
                ('switches_input', 'switch_1', {
                    'in': 1,
                    'out': 0
                }),
                ('switches_input', 'switch_2', {
                    'in': 1,
                    'out': 0
                }),
                ('switch', 'some_op', {
                    'in': 0
                }),
                ('switch_1', 'some_op', {
                    'in': 1
                }),
                ('some_op', 'identity', {
                    'in': 0
                }),
                ('switch_2', 'merge', {
                    'in': merge_input_port
                }),
                ('identity', 'merge', {
                    'in': 1 - merge_input_port
                }),
                ('merge', 'last', {
                    'in': 0
                }),
            ],
                                nodes_with_edges_only=True)
            graph.stage = 'front'
            SwitchMergeOptimization().find_and_replace_pattern(graph)

            # The expected graph is the same regardless of merge_input_port:
            # Select(predicate, switch_2_input, identity) replaces the Merge.
            graph_ref = build_graph(nodes_attributes, [
                ('switches_input', 'select', {
                    'in': 0
                }),
                ('switch_2_input', 'select', {
                    'in': 1
                }),
                ('switch_input_0', 'some_op', {
                    'in': 0
                }),
                ('switch_1_input_0', 'some_op', {
                    'in': 1
                }),
                ('some_op', 'identity', {
                    'in': 0
                }),
                ('identity', 'select', {
                    'in': 2
                }),
                ('select', 'last', {
                    'in': 0
                }),
            ],
                                    nodes_with_edges_only=True)

            (flag, resp) = compare_graphs(graph,
                                          graph_ref,
                                          'last',
                                          check_op_attrs=True)
            self.assertTrue(flag, resp)
Ejemplo n.º 4
0
    def test_replace_node_several_consumers(self):
        """replace_node must move all three consumers of 'old' onto the new node.

        After replacement the graph still has 6 nodes and 5 edges, and each
        outgoing edge of 'new' keeps in port 0, a distinct out port and the
        edge attribute ``name == 'old'``.
        """
        def _op(node_type):
            return {'type': node_type, 'value': None, 'kind': 'op'}

        node_types = [
            ('input_1', 'Placeholder'),
            ('input_2', 'Placeholder'),
            ('old', 'Identity'),
            ('output_1', 'Identity'),
            ('output_2', 'Identity'),
            ('output_3', 'Identity'),
        ]
        edges = [
            ('input_1', 'old'),
            ('input_2', 'old'),
            ('old', 'output_3'),
            ('old', 'output_2'),
            ('old', 'output_1'),
        ]
        graph = build_graph({name: _op(kind) for name, kind in node_types}, edges)

        const_op = Const(graph, {'name': 'new'})
        new_node = const_op.create_node([Node(graph, 'input_1'), Node(graph, 'input_2')])
        replace_node(Node(graph, 'old'), new_node)

        self.assertEqual(len(graph.nodes()), 6)
        self.assertEqual(len(graph.edges()), 5)

        consumers = ['output_1', 'output_2', 'output_3']
        self.assertListEqual(sorted(graph.out_edges('new')),
                             [('new', consumer) for consumer in consumers])
        # Expected out ports: output_1 -> 2, output_2 -> 1, output_3 -> 0
        # (presumably mirroring the order of the 'old' -> output edges above).
        expected_result = [
            ('new', consumer, {'in': 0, 'out': out_port, 'name': 'old'})
            for consumer, out_port in zip(consumers, (2, 1, 0))
        ]
        self.assertListEqual(sorted(graph.out_edges('new', data=True)),
                             expected_result)
Ejemplo n.º 5
0
    def test(self):
        """DropoutWithRandomUniformReplacer replaces RandomUniform with Broadcast.

        The RandomUniform node ('dropout/RU') fed by ShapeOf(input) is expected
        to become a Broadcast of the constant 0.5 to that shape. The resulting
        Broadcast node must inherit the name 'dropout/RU'.
        """
        nodes = {
            **regular_op('input', {'type': 'Parameter'}),
            **regular_op('shape', {'type': 'ShapeOf', 'kind': 'op', 'op': 'ShapeOf'}),
            **regular_op('random_uniform', {'type': 'RandomUniform', 'kind': 'op',
                                            'op': 'RandomUniform', 'name': 'dropout/RU'}),
            **regular_op('mul', {'type': 'Mul', 'kind': 'op', 'op': 'Mul'}),
            **regular_op('add', {'type': 'Add', 'kind': 'op', 'op': 'Add'}),
            **regular_op('add2', {'type': 'Add', 'kind': 'op', 'op': 'Add'}),
            **regular_op('floor', {'type': 'Floor', 'kind': 'op', 'op': 'Floor'}),
            'add_const': {
                'kind': 'op',
                'op': 'Const',
                'value': np.array(0.0),
                'data_type': np.float32
            },
            **result('result'),

            # new nodes to be added
            'broadcast_const': {
                'kind': 'op',
                'op': 'Const',
                'value': np.array(0.5),
                'data_type': np.float32
            },
            **regular_op('broadcast', {'type': 'Broadcast', 'kind': 'op', 'op': 'Broadcast'}),
        }
        edges = [('input', 'shape'), ('shape', 'random_uniform'),
                 ('random_uniform', 'mul'), ('mul', 'add'),
                 ('add_const', 'add'), ('add', 'add2'), ('add2', 'floor'),
                 ('floor', 'result')]
        graph = build_graph(nodes, edges, nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'
        graph.stage = 'front'

        DropoutWithRandomUniformReplacer().find_and_replace_pattern(graph)

        edges_ref = [('input', 'shape'), ('broadcast_const', 'broadcast'),
                     ('shape', 'broadcast'), ('broadcast', 'mul'),
                     ('mul', 'add'), ('add_const', 'add'), ('add', 'add2'),
                     ('add2', 'floor'), ('floor', 'result')]
        graph_ref = build_graph(nodes, edges_ref, nodes_with_edges_only=True)

        # check graph structure after the transformation and output name
        flag, resp = compare_graphs(graph, graph_ref, 'result')
        self.assertTrue(flag, resp)

        # Exactly one Broadcast must exist before we index into the result list.
        broadcast_nodes = graph.get_nodes_with_attributes(op='Broadcast')
        self.assertEqual(len(broadcast_nodes), 1)
        # ``Graph.node`` was removed in networkx 2.4; ``Graph.nodes[...]`` works on all 2.x.
        # assertEqual reports both values on failure, unlike assertTrue(a == b).
        self.assertEqual(graph.nodes[broadcast_nodes[0]]['name'], 'dropout/RU')
Ejemplo n.º 6
0
 def test_2d(self):
     """The Unsqueeze->Tile->Reshape block must be left unchanged in this 2D case.

     The input and reference graphs are built from the same nodes, edges and
     attribute updates, so the transformation is expected to be a no-op here.
     """
     def make_attributes():
         # One source of truth for both graphs (previously duplicated verbatim).
         # A fresh dict with fresh arrays is returned on every call, so the two
         # graphs never share numpy arrays, in case build_graph stores these
         # values without copying them.
         return {
             'placeholder_data': {
                 'shape': int64_array([5, 8])
             },
             'dim': {
                 'value': int64_array([1])
             },
             'dim_data': {
                 'value': int64_array([1])
             },
             'unsqueeze_data': {
                 'shape': int64_array([5, 1, 8])
             },
             'multipliers': {
                 'value': int64_array([1, 10, 1])
             },
             'multipliers_data': {
                 'value': int64_array([1, 10, 1]),
                 'shape': int64_array([3])
             },
             'tile_data': {
                 'shape': int64_array([5, 10, 8])
             },
             'reshape_data': {
                 'shape': int64_array([50, 8])
             },
             'shape': {
                 'value': int64_array([50, 8]),
                 'shape': int64_array([2])
             },
             'shape_data': {
                 'value': int64_array([50, 8]),
                 'shape': int64_array([2])
             },
             'abs_data': {
                 'shape': int64_array([50, 8])
             },
         }

     graph = build_graph(
         nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
         edges=graph_edges_when_transformation_is_not_applicable,
         update_attributes=make_attributes())
     ref_graph = build_graph(
         nodes_attrs=graph_node_attrs_when_transformation_is_not_applicable,
         edges=graph_edges_when_transformation_is_not_applicable,
         update_attributes=make_attributes())
     UnsqueezeTileReshapeBlockToInterpolate().find_and_replace_pattern(graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
    def test_several_memory_concat(self):
        """Check MemoryOffsetAdjustment on several MemoryOffsets feeding one Concat.

        Inputs of 'concat' in the source graph: MemoryOffset t=2 (port 0),
        MemoryOffset t=1 (port 1), a direct edge from 'in' (port 2, offset 0)
        and MemoryOffset t=-3 (port 3). In the expected graph every offset is
        shifted by -2 (the largest t): port 0 becomes a direct edge from 'in',
        and ports 1, 2 and 3 get MemoryOffsets with t=-1, -2 and -5.
        """
        graph = build_graph(
            {
                'in': {
                    'kind': 'op',
                    'op': None
                },
                'memory_2': {
                    'kind': 'op',
                    'op': 'MemoryOffset',
                    't': 2
                },
                'memory_1': {
                    'kind': 'op',
                    'op': 'MemoryOffset',
                    't': 1
                },
                'memory__3': {
                    'kind': 'op',
                    'op': 'MemoryOffset',
                    't': -3
                },
                'concat': {
                    'kind': 'op',
                    'op': 'Concat'
                }
            }, [('in', 'memory_2', {
                'out': 0
            }), ('in', 'memory_1', {
                'out': 1
            }), ('in', 'memory__3', {
                'out': 3
            }), ('memory_2', 'concat', {
                'in': 0
            }), ('memory_1', 'concat', {
                'in': 1
            }), ('in', 'concat', {
                'in': 2,
                'out': 2
            }), ('memory__3', 'concat', {
                'in': 3
            })],
            nodes_with_edges_only=True)
        graph.stage = 'front'

        # Expected: offsets 2, 1, 0, -3 become 0 (no node), -1, -2, -5; each
        # Concat port keeps the same 'in' output port it read from before.
        ref_graph = build_graph(
            {
                'in': {
                    'kind': 'op',
                    'op': None
                },
                'memory__5': {
                    'kind': 'op',
                    'op': 'MemoryOffset',
                    't': -5
                },
                'memory__1': {
                    'kind': 'op',
                    'op': 'MemoryOffset',
                    't': -1
                },
                'memory__2': {
                    'kind': 'op',
                    'op': 'MemoryOffset',
                    't': -2
                },
                'concat': {
                    'kind': 'op',
                    'op': 'Concat'
                }
            }, [('in', 'memory__5', {
                'out': 3
            }), ('in', 'memory__1', {
                'out': 1
            }), ('in', 'memory__2', {
                'out': 2
            }), ('in', 'concat', {
                'in': 0,
                'out': 0
            }), ('memory__2', 'concat', {
                'in': 2
            }), ('memory__1', 'concat', {
                'in': 1
            }), ('memory__5', 'concat', {
                'in': 3
            })],
            nodes_with_edges_only=True)

        MemoryOffsetAdjustment().find_and_replace_pattern(graph)

        (flag, resp) = compare_graphs(graph,
                                      ref_graph,
                                      'concat',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Ejemplo n.º 8
0
 def test_there_are_four_inputs_and_first_and_third_input_have_zero_in_their_shapes(self):
     """Check CutInputHavingZeroDimFromConcat with four Concat inputs.

     Inputs 0 and 2 ('const0_data', 'const2_data') have shape [5, 0], i.e. zero
     size along the concat axis 1. The expected graph drops them (with their
     Const producers) and renumbers the remaining inputs: 'placeholder_data' ->
     port 0, 'const3_data' -> port 1. The output shape [5, 40] (17 + 23) is unchanged.
     """
     graph = build_graph(
         nodes_attrs={
             'const0': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([5, 0]),
                 'value': np.zeros((5, 0))
             },
             'const0_data': {'kind': 'data', 'shape': int64_array([5, 0]), 'value': None},
             'placeholder': {'kind': 'op', 'type': 'Parameter', 'op': 'Parameter'},
             'placeholder_data': {
                 'kind': 'data',
                 'value': None,
                 'shape': int64_array([5, 17]),
                 'data_type': None
             },
             'const2': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([5, 0]),
                 'value': np.zeros((5, 0))
             },
             'const2_data': {'kind': 'data', 'shape': int64_array([5, 0]), 'value': None},
             'const3': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([5, 23]),
                 'value': np.zeros((5, 23))
             },
             'const3_data': {'kind': 'data', 'shape': int64_array([5, 23]), 'value': None},
             'concat': {'kind': 'op', 'type': 'Concat', 'op': 'Concat', 'axis': 1},
             'concat_data': {'kind': 'data', 'shape': int64_array([5, 40]), 'value': None},
             'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
         },
         edges=[
             ('const0', 'const0_data'),
             ('placeholder', 'placeholder_data'),
             ('const2', 'const2_data'),
             ('const3', 'const3_data'),
             ('const0_data', 'concat', {'in': 0}),
             ('placeholder_data', 'concat', {'in': 1}),
             ('const2_data', 'concat', {'in': 2}),
             ('const3_data', 'concat', {'in': 3}),
             ('concat', 'concat_data'),
             ('concat_data', 'output')
         ]
     )
     # Expected graph: only the non-empty inputs remain, on consecutive ports.
     ref_graph = build_graph(
         nodes_attrs={
             'placeholder': {'kind': 'op', 'type': 'Parameter', 'op': 'Parameter'},
             'placeholder_data': {
                 'kind': 'data',
                 'value': None,
                 'shape': int64_array([5, 17]),
                 'data_type': None
             },
             'const3': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([5, 23]),
                 'value': np.zeros((5, 23))
             },
             'const3_data': {'kind': 'data', 'shape': int64_array([5, 23]), 'value': None},
             'concat': {'kind': 'op', 'type': 'Concat', 'op': 'Concat', 'axis': 1},
             'concat_data': {'kind': 'data', 'shape': int64_array([5, 40]), 'value': None},
             'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
         },
         edges=[
             ('placeholder', 'placeholder_data'),
             ('const3', 'const3_data'),
             ('placeholder_data', 'concat', {'in': 0}),
             ('const3_data', 'concat', {'in': 1}),
             ('concat', 'concat_data'),
             ('concat_data', 'output')
         ]
     )
     CutInputHavingZeroDimFromConcat().find_and_replace_pattern(graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
Ejemplo n.º 9
0
 def test_when_there_are_three_inputs_and_middle_constant_has_zero_in_shape(self):
     """Check CutInputHavingZeroDimFromConcat when the middle input is empty.

     'const1_data' (port 1) has shape [1, 2, 0], i.e. zero size along the
     concat axis 2. The expected graph removes it (with 'const1') and moves
     'placeholder_data' from port 2 to port 1. The output shape [1, 2, 22]
     (5 + 17) is unchanged.
     """
     graph = build_graph(
         nodes_attrs={
             'const0': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([1, 2, 5]),
                 'value': np.zeros((1, 2, 5))
             },
             'const0_data': {'kind': 'data', 'shape': int64_array([1, 2, 5]), 'value': None},
             'const1': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([1, 2, 0]),
                 'value': np.zeros((1, 2, 0))
             },
             'const1_data': {'kind': 'data', 'shape': int64_array([1, 2, 0]), 'value': None},
             'placeholder': {'kind': 'op', 'type': 'Parameter', 'op': 'Parameter'},
             'placeholder_data': {
                 'kind': 'data',
                 'value': None,
                 'shape': int64_array([1, 2, 17]),
                 'data_type': None
             },
             'concat': {'kind': 'op', 'type': 'Concat', 'op': 'Concat', 'axis': 2},
             'concat_data': {'kind': 'data', 'shape': int64_array([1, 2, 22]), 'value': None},
             'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
         },
         edges=[
             ('const0', 'const0_data'),
             ('const1', 'const1_data'),
             ('placeholder', 'placeholder_data'),
             ('const0_data', 'concat', {'in': 0}),
             ('const1_data', 'concat', {'in': 1}),
             ('placeholder_data', 'concat', {'in': 2}),
             ('concat', 'concat_data'),
             ('concat_data', 'output')
         ]
     )
     # Expected graph: the empty middle input is gone; the rest keep their order.
     ref_graph = build_graph(
         nodes_attrs={
             'const0': {
                 'kind': 'op',
                 'type': 'Const',
                 'op': 'Const',
                 'shape': int64_array([1, 2, 5]),
                 'value': np.zeros((1, 2, 5))
             },
             'const0_data': {'kind': 'data', 'shape': int64_array([1, 2, 5]), 'value': None},
             'placeholder': {'kind': 'op', 'type': 'Parameter', 'op': 'Parameter'},
             'placeholder_data': {
                 'kind': 'data',
                 'value': None,
                 'shape': int64_array([1, 2, 17]),
                 'data_type': None
             },
             'concat': {'kind': 'op', 'type': 'Concat', 'op': 'Concat', 'axis': 2},
             'concat_data': {'kind': 'data', 'shape': int64_array([1, 2, 22]), 'value': None},
             'output': {'kind': 'op', 'op': 'Result', 'type': 'Result'},
         },
         edges=[
             ('const0', 'const0_data'),
             ('placeholder', 'placeholder_data'),
             ('const0_data', 'concat', {'in': 0}),
             ('placeholder_data', 'concat', {'in': 1}),
             ('concat', 'concat_data'),
             ('concat_data', 'output')
         ]
     )
     CutInputHavingZeroDimFromConcat().find_and_replace_pattern(graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
Ejemplo n.º 10
0
 def test_conversion(self, input_shape, scales, axes):
     """Check UpsampleToResample on one parametrized case.

     ``input_shape``, ``scales`` and ``axes`` come from the test
     parametrization (not visible here). The reference graph sets 'ss_begin' /
     'ss_end' (presumably StridedSlice bounds — confirm) to [axes[0]] and
     [axes[-1] + 1], 'factor' to scales[2:], and 'axes_const' to ``axes``.
     """
     input_shape_as_array = int64_array(input_shape)
     scales_as_array = float32_array(scales)
     graph = build_graph(
         graph_node_attrs, graph_edges, {
             'placeholder_data': {
                 'shape': input_shape_as_array
             },
             'scales': {
                 'value': scales_as_array,
                 'shape': scales_as_array.shape
             },
             'scales_data': {
                 'value': scales_as_array,
                 'shape': scales_as_array.shape
             },
             'upsample_data': {
                 'shape':
                 ((input_shape_as_array + 1.e-5) * scales_as_array).astype(
                     np.int64)
             }
         })
     graph.graph['layout'] = 'NCHW'
     # NOTE(review): 'factor' uses scales[2:], which assumes ``axes`` covers every
     # dimension from index 2 onward — confirm against the parametrization.
     # NOTE(review): interpolate_data floors x * s + 1e-5, while upsample_data
     # above floors (x + 1e-5) * s; the two normally agree but differ in form —
     # confirm which rounding is intended.
     ref_graph = build_graph(
         new_ref_graph_node_attr, new_ref_graph_edges, {
             'placeholder_data': {
                 'shape': int64_array(input_shape)
             },
             'ss_begin': {
                 'value': int64_array([axes[0]])
             },
             'ss_end': {
                 'value': int64_array([axes[-1] + 1])
             },
             'ss_begin_data': {
                 'value': int64_array([axes[0]])
             },
             'ss_end_data': {
                 'value': int64_array([axes[-1] + 1])
             },
             'factor': {
                 'value': scales_as_array[2:],
                 'shape': scales_as_array[2:].shape
             },
             'factor_data': {
                 'value': scales_as_array[2:],
                 'shape': scales_as_array[2:].shape
             },
             'axes_const': {
                 'value': int64_array(axes),
                 'shape': int64_array(axes).shape
             },
             'interpolate_data': {
                 'shape':
                 (input_shape_as_array * scales_as_array + 1e-5).astype(
                     np.int64)
             },
         })
     UpsampleToResample().find_and_replace_pattern(graph)
     (flag, resp) = compare_graphs(graph, ref_graph, 'output')
     self.assertTrue(flag, resp)
Ejemplo n.º 11
0
    def test_1(self):
        """Check that ShuffleChannel is decomposed into Reshape -> Transpose -> Reshape.

        'reshape_split' reshapes the placeholder to a shape built by 'concat'
        from: a Gather of 'shape' (batch), 'group', 'convert' (fed by
        'output_channels', which combines a second Gather of 'shape' with
        'div_group'), and 'const'. The result goes through 'transpose' and
        'reshape_concat', whose target shape is 'shape' itself.
        """
        graph = build_graph(nodes_attributes,
                            [('placeholder', 'shuffle_channel'),
                             ('shuffle_channel', 'result')],
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'
        graph.stage = 'front'

        # Expected decomposition; every edge has explicit in/out ports except
        # the final one into 'result'.
        ref_graph = build_graph(nodes_attributes, [('placeholder', 'shape', {
            'in': 0,
            'out': 0
        }), ('shape', 'batch_gather', {
            'in': 0,
            'out': 0
        }), ('batch_gather_idx', 'batch_gather', {
            'in': 1,
            'out': 0
        }), ('batch_gather_axis', 'batch_gather', {
            'in': 2,
            'out': 0
        }), ('shape', 'channel_gather', {
            'in': 0,
            'out': 0
        }), ('channel_gather_idx', 'channel_gather', {
            'in': 1,
            'out': 0
        }), ('channel_gather_axis', 'channel_gather', {
            'in': 2,
            'out': 0
        }), ('channel_gather', 'output_channels', {
            'in': 0,
            'out': 0
        }), ('div_group', 'output_channels', {
            'in': 1,
            'out': 0
        }), ('output_channels', 'convert', {
            'in': 0,
            'out': 0
        }), ('batch_gather', 'concat', {
            'in': 0,
            'out': 0
        }), ('group', 'concat', {
            'in': 1,
            'out': 0
        }), ('convert', 'concat', {
            'in': 2,
            'out': 0
        }), ('const', 'concat', {
            'in': 3,
            'out': 0
        }), ('placeholder', 'reshape_split', {
            'in': 0,
            'out': 0
        }), ('concat', 'reshape_split', {
            'in': 1,
            'out': 0
        }), ('reshape_split', 'transpose', {
            'in': 0,
            'out': 0
        }), ('transpose_const', 'transpose', {
            'in': 1,
            'out': 0
        }), ('transpose', 'reshape_concat', {
            'in': 0,
            'out': 0
        }), ('shape', 'reshape_concat', {
            'in': 1,
            'out': 0
        }), ('reshape_concat', 'result')],
                                nodes_with_edges_only=True)

        ShuffleChannel().find_and_replace_pattern(graph)
        (flag, resp) = compare_graphs(graph,
                                      ref_graph,
                                      'result',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
        # The node now feeding 'result' must be named 'scname' — presumably the
        # original ShuffleChannel node's name in nodes_attributes (not visible here).
        self.assertTrue(
            Node(graph, 'result').in_port(0).get_source().node.name ==
            'scname')
Ejemplo n.º 12
0
 def create_graph(cls):
     """Build the input_data_node -> test_node -> output_data_node graph.

     Uses ``cls.nodes_attributes`` and stores the result on ``cls.graph``.
     """
     chain = [
         ('input_data_node', 'test_node'),
         ('test_node', 'output_data_node'),
     ]
     cls.graph = build_graph(cls.nodes_attributes, chain, nodes_with_edges_only=True)
    def test_splice(self):
        """ReplaceSpliceNodePattern rewrites a plain Splice into a memory loop.

        The reference graph reads state through ReadValue (initialised by a
        zero-filled Broadcast whose shape is [first dim of input, 143]), crops
        it, concatenates the current input and writes it back through Assign.
        """
        source_edges = [
            ('placeholder', 'in_node'),
            ('in_node', 'splice'),
            ('splice', 'splice_data'),
            ('splice_data', 'out_placeholder'),
        ]
        graph = build_graph(self.nodes_attributes, source_edges)
        ReplaceSpliceNodePattern().find_and_replace_pattern(graph)

        def op_node(op, **attrs):
            # Fresh attribute dict for an operation node.
            return {'kind': 'op', 'op': op, **attrs}

        def data_node(**attrs):
            # Fresh attribute dict for a data node.
            return {'kind': 'data', **attrs}

        expected_nodes = {
            'in_placeholder': op_node(None),
            'in_node': data_node(shape=[1, 13]),
            'shape': op_node('ShapeOf'),
            'shape_data': data_node(),
            'crop_batch': op_node('Crop', offset=int64_array([0])),
            'crop_batch_data': data_node(),
            'crop_batch_dim': op_node('Const', value=int64_array([1])),
            'crop_batch_dim_data': data_node(),
            'second_dim': op_node('Const', value=int64_array([143])),
            'second_dim_data': data_node(),
            'gather_shape': op_node('Concat'),
            'gather_shape_data': data_node(),
            'fill_value': op_node('Const', value=int64_array([0])),
            'fill_value_data': data_node(),
            'broadcast': op_node('Broadcast'),
            'broadcast_data': data_node(),
            'memory_in': op_node('ReadValue'),
            'memory_in_data': data_node(),
            'crop_mem': op_node('Crop', offset=13, dim=130),
            'crop_mem_data': data_node(),
            'concat': op_node('Concat'),
            'concat_data': data_node(shape=[1, 143]),
            'memory_out': op_node('Assign'),
            'memory_out_data': data_node(),
            'result': op_node('Result'),
            'out_placeholder': op_node('placeholder'),
        }
        expected_edges = [
            ('in_placeholder', 'in_node'),
            ('in_node', 'shape'),
            ('shape', 'shape_data'),
            ('shape_data', 'crop_batch'),
            ('crop_batch', 'crop_batch_data'),
            ('crop_batch_dim', 'crop_batch_dim_data'),
            ('crop_batch_dim_data', 'crop_batch', {'in': 1}),
            ('second_dim', 'second_dim_data'),
            ('second_dim_data', 'gather_shape', {'in': 1}),
            ('crop_batch_data', 'gather_shape', {'in': 0}),
            ('gather_shape', 'gather_shape_data'),
            ('fill_value', 'fill_value_data'),
            ('fill_value_data', 'broadcast', {'in': 0}),
            ('gather_shape_data', 'broadcast', {'in': 1}),
            ('broadcast', 'broadcast_data'),
            ('broadcast_data', 'memory_in'),
            ('memory_in', 'memory_in_data'),
            ('memory_in_data', 'crop_mem'),
            ('crop_mem', 'crop_mem_data'),
            ('crop_mem_data', 'concat', {'in': 0}),
            ('in_node', 'concat', {'in': 1}),
            ('concat', 'concat_data'),
            ('concat_data', 'memory_out'),
            ('memory_out', 'memory_out_data'),
            ('memory_out_data', 'result'),
            ('concat_data', 'out_placeholder'),
        ]
        ref_graph = build_graph(expected_nodes, expected_edges)

        flag, resp = compare_graphs(graph, ref_graph, 'out_placeholder')
        self.assertTrue(flag, resp)
    def test_splice_with_constdim(self):
        """ReplaceSpliceNodePattern handles a Splice with ``const_dim = 10``.

        The reference graph splits the input in two, gives each part its own
        ReadValue/Crop/Concat/Assign memory loop, crops the first 10 elements
        of the constant-part loop and concatenates it after the regular part,
        producing an output of shape [1, 43].
        """
        source_edges = [
            ('placeholder', 'in_node'),
            ('in_node', 'splice'),
            ('splice', 'splice_data'),
            ('splice_data', 'out_placeholder'),
        ]
        graph = build_graph(self.nodes_attributes, source_edges)
        Node(graph, 'splice')['const_dim'] = 10
        Node(graph, 'splice_data')['shape'] = [1, 43]
        ReplaceSpliceNodePattern().find_and_replace_pattern(graph)

        def op_node(op, **attrs):
            # Fresh attribute dict for an operation node.
            return {'kind': 'op', 'op': op, **attrs}

        def data_node(**attrs):
            # Fresh attribute dict for a data node.
            return {'kind': 'data', **attrs}

        expected_nodes = {
            'in_placeholder': op_node(None),
            'in_node': data_node(shape=[1, 13]),
            'split': op_node('Split'),
            'split_data_0': data_node(),
            'split_data_1': data_node(),
            # Memory loop for the regular (non-constant) part.
            'shape': op_node('ShapeOf'),
            'shape_data': data_node(),
            'crop_batch': op_node('Crop', offset=int64_array([0])),
            'crop_batch_data': data_node(),
            'crop_batch_dim': op_node('Const', value=int64_array([1])),
            'crop_batch_dim_data': data_node(),
            'second_dim': op_node('Const', value=int64_array([33])),
            'second_dim_data': data_node(),
            'gather_shape': op_node('Concat'),
            'gather_shape_data': data_node(),
            'fill_value': op_node('Const', value=int64_array([0])),
            'fill_value_data': data_node(),
            'broadcast': op_node('Broadcast'),
            'broadcast_data': data_node(),
            'memory_in': op_node('ReadValue'),
            'memory_in_data': data_node(),
            'crop_mem': op_node('Crop', offset=3, dim=30),
            'crop_mem_data': data_node(),
            'concat': op_node('Concat'),
            'concat_data': data_node(),
            'memory_out': op_node('Assign'),
            'memory_out_data': data_node(),
            'result': op_node('Result'),
            # Memory loop for the constant-dimension part.
            'shape_2': op_node('ShapeOf'),
            'shape_2_data': data_node(),
            'crop_batch_2': op_node('Crop', offset=int64_array([0])),
            'crop_batch_2_data': data_node(),
            'crop_batch_dim_2': op_node('Const', value=int64_array([1])),
            'crop_batch_dim_2_data': data_node(),
            'second_dim_2': op_node('Const', value=int64_array([33])),
            'second_dim_2_data': data_node(),
            'gather_shape_2': op_node('Concat'),
            'gather_shape_2_data': data_node(),
            'fill_value_2': op_node('Const', value=int64_array([0])),
            'fill_value_2_data': data_node(),
            'broadcast_2': op_node('Broadcast'),
            'broadcast_2_data': data_node(),
            'memory_in_constdims': op_node('ReadValue'),
            'memory_in_constdims_data': data_node(),
            'crop_mem_constdims': op_node('Crop', offset=10, dim=100),
            'crop_mem_constdims_data': data_node(),
            'concat_constdims': op_node('Concat'),
            'concat_constdims_data': data_node(),
            'memory_out_constdims': op_node('Assign'),
            'memory_out_constdims_data': data_node(),
            'result_constdims': op_node('Result'),
            'crop_first_constdims': op_node('Crop', offset=0, dim=10),
            'crop_first_constdims_data': data_node(),
            # Final merge of both parts.
            'concat_all': op_node('Concat'),
            'concat_all_data': data_node(shape=[1, 43]),
            'out_placeholder': op_node('placeholder'),
            # Split inputs: op nodes carry no 'op' attribute here.
            'axis_const': {'kind': 'op'},
            'axis_const_data': data_node(value=None, shape=None),
            'split_dim_const': {'kind': 'op'},
            'split_dim_const_data': data_node(value=None, shape=None),
        }
        expected_edges = [
            ('in_placeholder', 'in_node'),
            ('in_node', 'split', {'in': 0}),
            ('split', 'split_data_0', {'out': 0}),
            ('split', 'split_data_1', {'out': 1}),
            ('split_data_0', 'shape'),
            ('shape', 'shape_data'),
            ('shape_data', 'crop_batch'),
            ('crop_batch', 'crop_batch_data'),
            ('crop_batch_dim', 'crop_batch_dim_data'),
            ('crop_batch_dim_data', 'crop_batch', {'in': 1}),
            ('second_dim', 'second_dim_data'),
            ('second_dim_data', 'gather_shape', {'in': 1}),
            ('crop_batch_data', 'gather_shape', {'in': 0}),
            ('gather_shape', 'gather_shape_data'),
            ('fill_value', 'fill_value_data'),
            ('fill_value_data', 'broadcast', {'in': 0}),
            ('gather_shape_data', 'broadcast', {'in': 1}),
            ('broadcast', 'broadcast_data'),
            ('broadcast_data', 'memory_in'),
            ('memory_in', 'memory_in_data'),
            ('memory_in_data', 'crop_mem'),
            ('crop_mem', 'crop_mem_data'),
            ('crop_mem_data', 'concat', {'in': 0}),
            ('split_data_0', 'concat', {'in': 1}),
            ('concat', 'concat_data'),
            ('concat_data', 'memory_out'),
            ('memory_out', 'memory_out_data'),
            ('memory_out_data', 'result'),
            ('split_data_1', 'shape_2'),
            ('shape_2', 'shape_2_data'),
            ('shape_2_data', 'crop_batch_2'),
            ('crop_batch_2', 'crop_batch_2_data'),
            ('crop_batch_dim_2', 'crop_batch_dim_2_data'),
            ('crop_batch_dim_2_data', 'crop_batch_2', {'in': 1}),
            ('second_dim_2', 'second_dim_2_data'),
            ('second_dim_2_data', 'gather_shape_2', {'in': 1}),
            ('crop_batch_2_data', 'gather_shape_2', {'in': 0}),
            ('gather_shape_2', 'gather_shape_2_data'),
            ('fill_value_2', 'fill_value_2_data'),
            ('fill_value_2_data', 'broadcast_2', {'in': 0}),
            ('gather_shape_2_data', 'broadcast_2', {'in': 1}),
            ('broadcast_2', 'broadcast_2_data'),
            ('broadcast_2_data', 'memory_in_constdims'),
            ('memory_in_constdims', 'memory_in_constdims_data'),
            ('memory_in_constdims_data', 'crop_mem_constdims'),
            ('crop_mem_constdims', 'crop_mem_constdims_data'),
            ('crop_mem_constdims_data', 'concat_constdims', {'in': 0}),
            ('split_data_1', 'concat_constdims', {'in': 1}),
            ('concat_constdims', 'concat_constdims_data'),
            ('concat_constdims_data', 'memory_out_constdims'),
            ('memory_out_constdims', 'memory_out_constdims_data'),
            ('memory_out_constdims_data', 'result_constdims'),
            ('concat_constdims_data', 'crop_first_constdims'),
            ('crop_first_constdims', 'crop_first_constdims_data'),
            ('crop_first_constdims_data', 'concat_all', {'in': 1}),
            ('concat_data', 'concat_all', {'in': 0}),
            ('concat_all', 'concat_all_data'),
            ('concat_all_data', 'out_placeholder'),
            ('axis_const', 'axis_const_data'),
            ('split_dim_const', 'split_dim_const_data'),
            ('axis_const_data', 'split', {'in': 1}),
            ('split_dim_const_data', 'split', {'in': 2}),
        ]
        ref_graph = build_graph(expected_nodes, expected_edges)

        flag, resp = compare_graphs(graph, ref_graph, 'out_placeholder')
        self.assertTrue(flag, resp)
Ejemplo n.º 15
0
 def test_two_inputs_two_shapes_positive_1(self):
     """add_input_ops cuts the graph at 'node_1' and 'node_4' with given shapes.

     With ``before_infer=True`` each cut node is expected to be fed by a new
     node flagged ``is_input`` that carries the requested shape, while the
     original 'input_1' must not be flagged as an input.
     """
     shape_1 = [1, 2, 3, 4]
     shape_2 = [4, 3, 2, 1]
     inputs = {
         'node_1': [{'shape': shape_1}],
         'node_4': [{'shape': shape_2}],
     }
     # Every node except 'output' is an Identity op; only the two
     # 'input_*' nodes are Placeholders. Insertion order matches the
     # original literal: input_1, input_2, node_1..node_4, output.
     nodes = {
         name: {'type': 'Identity', 'kind': 'op', 'op': 'Placeholder'}
         for name in ('input_1', 'input_2')
     }
     nodes.update({
         name: {'type': 'Identity', 'kind': 'op', 'op': 'NotPlaceholder'}
         for name in ('node_1', 'node_2', 'node_3', 'node_4')
     })
     nodes['output'] = {'kind': 'op', 'op': 'OpOutput'}
     edges = [('input_1', 'node_1'), ('node_1', 'node_2'),
              ('node_3', 'output'), ('input_2', 'node_4'),
              ('node_4', 'output')]
     graph = build_graph(nodes, edges)
     add_input_ops(graph=graph,
                   user_defined_inputs=inputs,
                   before_infer=True)
     # The first producer feeding each cut node is expected to be the new input.
     new_input_1 = list(graph.in_edges('node_1'))[0][0]
     new_input_2 = list(graph.in_edges('node_4'))[0][0]
     self.assertFalse(graph.node['input_1']['is_input'])
     self.assertTrue(graph.node[new_input_1]['is_input'])
     self.assertTrue(graph.node[new_input_2]['is_input'])
     # assertIn reports the missing edge on failure, unlike assertTrue(x in y).
     self.assertIn((new_input_1, 'node_1'), graph.edges())
     self.assertIn((new_input_2, 'node_4'), graph.edges())
     self.assertListEqual(shape_1, graph.node[new_input_1]['shape'])
     self.assertListEqual(shape_2, graph.node[new_input_2]['shape'])
Ejemplo n.º 16
0
    def test_scaleshift2_axis1_to_mul(self):
        """ScaleShift whose scale comes from a second input becomes a Mul.

        On an NHWC graph with ``axis=1`` the length-227 scale is expected to be
        reshaped to ``[1, 227, 1, 1]`` before multiplying the first input.
        """
        src_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_1_data', 'scaleshift_1'),
            ('placeholder_2_data', 'scaleshift_1'),
            ('scaleshift_1', 'scaleshift_1_data'),
            ('scaleshift_1_data', 'op_output'),
        ]
        src_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'placeholder_2_data': {'shape': np.array([227])},
            'scaleshift_1': {'axis': 1},
            'scaleshift_1_data': {},
        }
        graph = build_graph(nodes_attributes, src_edges, src_attrs)

        ref_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_2', 'placeholder_2_data'),
            ('placeholder_2_data', 'placeholder_2/Reshape_'),
            ('placeholder_2/Reshape_const', 'placeholder_2/Reshape_const_data'),
            ('placeholder_2/Reshape_const_data', 'placeholder_2/Reshape_'),
            ('placeholder_2/Reshape_', 'placeholder_2/Reshape_data'),
            ('placeholder_1_data', 'mul_1'),
            ('placeholder_2/Reshape_data', 'mul_1'),
            ('mul_1', 'scaleshift_1_data'),
            ('scaleshift_1_data', 'op_output'),
        ]
        ref_attrs = {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'placeholder_2_data': {'shape': np.array([227])},
            'placeholder_2/Reshape_const': {
                'value': np.array([1, 227, 1, 1]),
                'shape': [4],
            },
            'placeholder_2/Reshape_const_data': {
                'value': np.array([1, 227, 1, 1]),
                'shape': [4],
            },
            'placeholder_2/Reshape_data': {'shape': np.array([1, 227, 1, 1])},
            'mul_1': {'can_be_fused': True},
            'scaleshift_1_data': {},
        }
        graph_ref = build_graph(nodes_attributes, ref_edges, ref_attrs)

        graph.graph['layout'] = 'NHWC'
        convert_scale_shift_to_mul_add(graph)
        graph.clean_up()
        flag, resp = compare_graphs(graph, graph_ref, 'placeholder_1')
        self.assertTrue(flag, resp)
Ejemplo n.º 17
0
    def test_add_input_with_output_port_after_infer(self):
        """Cutting at ``conv_1`` output port 0 after inference adds a new input.

        add_input_ops(before_infer=False) is expected to create the node
        'conv_1/placeholder_out_port_0' feeding 'conv_data' in place of
        'conv_1', leaving the rest of the graph unchanged.
        """
        shape = np.array([1, 2, 3, 4])
        inputs = {'conv_1': [{'shape': shape, 'out': 0}]}
        nodes = {
            'old_input': {
                'type': 'Identity',
                'kind': 'op',
                'op': 'Placeholder'
            },
            'inp_data': {
                'kind': 'data',
                # Deliberately differs from the shape requested for the new input.
                'shape': shape + 1
            },
            'conv_1': {
                'type': 'Convolution',
                'kind': 'op',
                'op': 'NotPlaceholder'
            },
            'conv_data': {
                'kind': 'data',
                'shape': shape
            },
            'relu_1': {
                'type': 'ReLU',
                'kind': 'op',
                'op': 'NotPlaceholder'
            },
        }
        edges = [
            ('old_input', 'inp_data'),
            ('inp_data', 'conv_1'),
            ('conv_1', 'conv_data'),
            ('conv_data', 'relu_1'),
        ]
        graph = build_graph(nodes, edges)
        add_input_ops(graph=graph,
                      user_defined_inputs=inputs,
                      before_infer=False)

        graph_ref = build_graph(
            nodes_attrs={
                'new_input': {
                    'kind': 'op',
                    'op': 'Placeholder',
                    'shape': shape
                },
                **nodes
            },
            edges=[
                ('old_input', 'inp_data'),
                ('inp_data', 'conv_1'),
                ('new_input', 'conv_data'),
                ('conv_data', 'relu_1'),
            ],
        )
        # Check that new input is added right (with right ports !)
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='relu_1')
        self.assertTrue(flag, resp)

        # Check that other graph is not damaged
        (flag, resp) = compare_graphs(graph, graph_ref, last_node='conv_1')
        self.assertTrue(flag, resp)

        # Checks for new input and edges. assertIn/assertNotIn report the
        # offending item on failure, unlike assertTrue(x in y).
        new_input = 'conv_1/placeholder_out_port_0'
        self.assertIn(new_input, graph.nodes())
        self.assertTrue(graph.node[new_input]['is_input'])
        self.assertIn((new_input, 'conv_data'), graph.edges())
        self.assertNotIn(('conv_1', 'conv_data'), graph.edges())
Ejemplo n.º 18
0
    def test_scaleshift_can_be_fused(self):
        """A ScaleShift marked ``can_be_fused=False`` is left as is.

        The reference graph has exactly the same topology as the input graph,
        so convert_scale_shift_to_mul_add is expected not to rewrite it.
        """
        topology = (
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'scaleshift_1'),
            ('const_scaleshift_1_w', 'scaleshift_1_w'),
            ('const_scaleshift_1_b', 'scaleshift_1_b'),
            ('scaleshift_1_w', 'scaleshift_1'),
            ('scaleshift_1_b', 'scaleshift_1'),
            ('scaleshift_1', 'scaleshift_1_data'),
            ('scaleshift_1_data', 'op_output'),
        )

        def const3(fill):
            # Fresh attrs for a length-3 constant filled with `fill`.
            return {'shape': np.array([3]), 'value': np.array([fill] * 3)}

        graph = build_graph(nodes_attributes, list(topology), {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'scaleshift_1_w': const3(1),
            'scaleshift_1_b': const3(0),
            'scaleshift_1': {'can_be_fused': False},
            'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
        })

        graph_ref = build_graph(nodes_attributes, list(topology), {
            'placeholder_1_data': {'shape': np.array([1, 227, 227, 3])},
            'const_scaleshift_1_w': const3(1),
            'scaleshift_1_w': const3(1),
            'const_scaleshift_1_b': const3(0),
            'scaleshift_1_b': const3(0),
            'scaleshift_1': {'can_be_fused': False},
            'scaleshift_1_data': {'shape': np.array([1, 227, 227, 3])},
        })

        convert_scale_shift_to_mul_add(graph)
        graph.clean_up()

        flag, resp = compare_graphs(graph, graph_ref, 'scaleshift_1_data')
        self.assertTrue(flag, resp)
Ejemplo n.º 19
0
    def test_replacer(self):
        """SsdAnchorsReplacer inserts the anchor post-processing subgraph.

        Both graphs share the model_reshape*/reshape0 -> concat prefix; in the
        reference graph the concat output goes through reshape/split/scale
        arithmetic before reaching input 2 of 'detection_output'.
        """
        def anchor_prefix():
            # Fresh edge list (new attr dicts per call) shared by both graphs.
            return [
                ('slice_like', 'model_reshape0', {'in': 0}),
                ('model_reshape0_const', 'model_reshape0', {'in': 1}),
                ('model_reshape0', 'model_reshape1', {'in': 0}),
                ('model_reshape1_const', 'model_reshape1', {'in': 1}),
                ('model_reshape1', 'model_reshape2', {'in': 0}),
                ('model_reshape2_const', 'model_reshape2', {'in': 1}),
                ('model_reshape2', 'reshape0', {'in': 0}),
                ('reshape0_const', 'reshape0', {'in': 1}),
                ('reshape0', 'concat'),
            ]

        source_edges = anchor_prefix()
        source_edges.append(('concat', 'detection_output', {'in': 2}))
        graph = build_graph(nodes_attrs=nodes_attributes,
                            edges=source_edges,
                            nodes_with_edges_only=True)

        expected_edges = anchor_prefix() + [
            ('concat', 'reshape1', {'in': 0}),
            ('reshape1_const', 'reshape1', {'in': 1}),
            ('reshape1', 'split', {'in': 0}),
            ('split_const', 'split', {'in': 1}),
            ('split', 'reshape2', {'out': 0, 'in': 0}),
            ('reshape2_const', 'reshape2', {'in': 1}),
            ('reshape2', 'value', {'in': 0}),
            ('value_const', 'value', {'in': 1}),
            ('value', 'xmin', {'out': 0, 'in': 0}),
            ('value', 'ymin', {'out': 1, 'in': 0}),
            ('value', 'xmax', {'out': 0, 'in': 1}),
            ('value', 'ymax', {'out': 1, 'in': 1}),
            ('value', 'div_1', {'out': 2, 'in': 0}),
            ('value', 'div_2', {'out': 3, 'in': 0}),
            ('div_1_const', 'div_1', {'in': 1}),
            ('div_2_const', 'div_2', {'in': 1}),
            ('div_1', 'xmin', {'in': 1, 'out': 0}),
            ('div_1', 'xmax', {'in': 0, 'out': 0}),
            ('div_2', 'ymin', {'in': 1, 'out': 0}),
            ('div_2', 'ymax', {'in': 0, 'out': 0}),
            ('xmin', 'concat_value', {'in': 0}),
            ('ymin', 'concat_value', {'in': 1}),
            ('xmax', 'concat_value', {'in': 2}),
            ('ymax', 'concat_value', {'in': 3}),
            ('concat_value', 'reshape3', {'in': 0}),
            ('reshape3_const', 'reshape3', {'in': 1}),
            ('reshape3', 'end_concat', {'in': 0}),
            ('split', 'end_concat', {'in': 1}),
            ('end_concat', 'detection_output', {'in': 2}),
        ]
        ref_graph = build_graph(nodes_attrs=nodes_attributes,
                                edges=expected_edges,
                                update_attributes={'concat': {'axis': 1}},
                                nodes_with_edges_only=True)

        graph.stage = 'front'
        graph.graph['cmd_params'].data_type = 'FP32'
        SsdAnchorsReplacer().find_and_replace_pattern(graph)
        flag, resp = compare_graphs(graph, ref_graph, 'detection_output',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
Ejemplo n.º 20
0
    def test_bn_decomposition_2(self):
        """Decompose a non-fusable NHWC BatchNorm into Mul -> Add -> Mul -> Add.

        All four BN parameters are [1, 2, 3] and eps is 1.2. The reference
        literals correspond to:
          * mul_1_w = 1 / sqrt(var + eps)
          * add_1_w = -mean * mul_1_w
          * mul_2_w = add_2_w = [1, 2, 3]
        Every produced Mul/Add is expected to carry can_be_fused=False,
        mirroring the source BN node.
        """
        image_shape = [1, 227, 227, 3]

        def const_attrs(values):
            # Build new ndarrays on each call so no two nodes share an object.
            return {'shape': np.array([3]), 'value': np.array(values)}

        # Edge order is kept exactly as written; port numbering may depend on it.
        bn_edges = [('placeholder_1', 'placeholder_1_data'),
                    ('placeholder_1_data', 'bn_op'),
                    ('const_bn_const', 'bn_const'),
                    ('const_bn_beta', 'bn_beta'),
                    ('const_bn_mean', 'bn_mean'),
                    ('const_bn_var', 'bn_var'),
                    ('bn_const', 'bn_op'), ('bn_beta', 'bn_op'),
                    ('bn_mean', 'bn_op'), ('bn_var', 'bn_op'),
                    ('bn_op', 'bn_data'), ('concat', 'concat_data'),
                    ('bn_data', 'concat'),
                    ('concat_data', 'op_output')]
        bn_attrs = {
            'placeholder_1_data': {'shape': np.array(image_shape)},
            'bn_op': {'eps': 1.2, 'can_be_fused': False},
            'bn_const': const_attrs([1, 2, 3]),
            'bn_beta': const_attrs([1, 2, 3]),
            'bn_mean': const_attrs([1, 2, 3]),
            'bn_var': const_attrs([1, 2, 3]),
            'bn_data': {'shape': np.array(image_shape)},
            'concat_data': {},
        }
        graph = build_graph(nodes_attributes, bn_edges, bn_attrs)

        scale = [0.67419986, 0.55901699, 0.48795004]
        shift = [-0.67419986, -1.11803399, -1.46385011]
        ref_edges = [('placeholder_1', 'placeholder_1_data'),
                     ('placeholder_1_data', 'mul_1'),
                     ('const_mul_1_w', 'mul_1_w'),
                     ('mul_1_w', 'mul_1'), ('mul_1', 'mul_1_data'),
                     ('mul_1_data', 'add_1'),
                     ('const_add_1_w', 'add_1_w'),
                     ('add_1_w', 'add_1'), ('add_1', 'add_1_data'),
                     ('add_1_data', 'mul_2'),
                     ('const_mul_2_w', 'mul_2_w'),
                     ('mul_2_w', 'mul_2'), ('mul_2', 'mul_2_data'),
                     ('mul_2_data', 'add_2'),
                     ('const_add_2_w', 'add_2_w'),
                     ('add_2_w', 'add_2'), ('add_2', 'add_2_data'),
                     ('concat', 'concat_data'),
                     ('add_2_data', 'concat'),
                     ('concat_data', 'op_output')]
        ref_attrs = {
            'placeholder_1_data': {'shape': np.array(image_shape)},
            'const_mul_1_w': const_attrs(scale),
            'mul_1_w': const_attrs(scale),
            'const_mul_2_w': const_attrs([1, 2, 3]),
            'mul_2_w': const_attrs([1, 2, 3]),
            'const_add_1_w': const_attrs(shift),
            'add_1_w': const_attrs(shift),
            'const_add_2_w': const_attrs([1, 2, 3]),
            'add_2_w': const_attrs([1, 2, 3]),
            'add_2_data': {'shape': np.array(image_shape)},
            'mul_1': {'can_be_fused': False},
            'mul_2': {'can_be_fused': False},
            'add_1': {'can_be_fused': False},
            'add_2': {'can_be_fused': False},
            'concat_data': {},
        }
        reference = build_graph(nodes_attributes, ref_edges, ref_attrs)

        graph.graph['layout'] = 'NHWC'
        convert_batch_norm(graph)
        graph.clean_up()

        flag, resp = compare_graphs(graph, reference, 'concat_data')
        self.assertTrue(flag, resp)
Ejemplo n.º 21
0
    def test_remove_noop_nodes_check_out_port(self):
        """Erasing a NoOp must rewire its consumers to the NoOp's producer.

        The NoOp feeds three Identity nodes from output port 1 into input
        ports 4, 2 and 10. After erase_node, the reference expects 'input' to
        feed the same three consumers directly. The input ports 4/2/10 are
        preserved, and the output port becomes 0 (the port of the edge
        input -> noop).
        """
        graph = build_graph(
            {
                'input': {
                    'type': 'Placeholder',
                    'value': None,
                    'kind': 'op'
                },
                'noop': {
                    'type': 'NoOp',
                    'value': None,
                    'kind': 'op'
                },
                'output_1': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op'
                },
                'output_2': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op'
                },
                'output_3': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op'
                },
            }, [('input', 'noop'), ('noop', 'output_1', {
                'in': 4,
                'out': 1
            }), ('noop', 'output_2', {
                'in': 2,
                'out': 1
            }), ('noop', 'output_3', {
                'in': 10,
                'out': 1
            })])

        ref_graph = build_graph(
            {
                'input': {
                    'type': 'Placeholder',
                    'value': None,
                    'kind': 'op'
                },
                'output_1': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op'
                },
                'output_2': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op'
                },
                'output_3': {
                    'type': 'Identity',
                    'value': None,
                    'kind': 'op'
                },
            }, [('input', 'output_1', {
                'in': 4,
                'out': 0
            }), ('input', 'output_2', {
                'in': 2,
                'out': 0
            }), ('input', 'output_3', {
                'in': 10,
                'out': 0
            })],
            nodes_with_edges_only=True)

        erase_node(Node(graph, 'noop'))

        # The comparison result used to be discarded, so a mismatch could
        # never fail this test. Assert on it like the other tests do.
        flag, resp = compare_graphs(graph, ref_graph, 'output_1')
        self.assertTrue(flag, resp)
 def test_positive(self):
     """TransposeReduce removes a Transpose that feeds ReduceMean.

     In the reference graph the placeholder feeds ReduceMean directly. The
     reduction axes come from a Gather whose inputs are the transpose
     order (port 0), the original axes (port 1) and gather_const (port 2).
     """
     order = [0, 2, 3, 1]
     axes = [1, 2]

     # Edge order is kept exactly as written; port numbering may depend on it.
     src_edges = [
         ('placeholder', 'placeholder_data'),
         ('placeholder_data', 'transpose', {'in': 0}),
         ('transpose_const', 'transpose_const_data'),
         ('transpose_const_data', 'transpose', {'in': 1}),
         ('transpose', 'transpose_data'),
         ('transpose_data', 'reduceMean', {'in': 0}),
         ('reduceMeanConst', 'reduceMeanConst_data'),
         ('reduceMeanConst_data', 'reduceMean', {'in': 1}),
         ('reduceMean', 'reduceMean_data'),
         ('reduceMean_data', 'convolution'),
     ]
     src_attrs = {
         'transpose_const': {'value': int64_array(order)},
         'transpose_const_data': {'value': int64_array(order)},
         'reduceMeanConst': {'value': int64_array(axes)},
         'reduceMeanConst_data': {'value': int64_array(axes)},
     }
     graph = build_graph(nodes_attributes, src_edges, src_attrs,
                         nodes_with_edges_only=True)

     ref_edges = [
         ('placeholder', 'placeholder_data'),
         ('placeholder_data', 'reduceMean', {'in': 0}),
         ('transpose_const', 'transpose_const_data'),
         ('transpose_const_data', 'gather', {'in': 0}),
         ('reduceMeanConst', 'reduceMeanConst_data'),
         ('reduceMeanConst_data', 'gather', {'in': 1}),
         ('gather_const', 'gather_const_data'),
         ('gather_const_data', 'gather', {'in': 2}),
         ('gather', 'gather_data'),
         ('gather_data', 'reduceMean', {'in': 1}),
         ('reduceMean', 'reduceMean_data'),
         ('reduceMean_data', 'convolution'),
     ]
     ref_attrs = {
         'transpose_const_data': {'value': int64_array(order)},
         'reduceMeanConst_data': {'value': int64_array(axes)},
     }
     expected = build_graph(nodes_attributes, ref_edges, ref_attrs,
                            nodes_with_edges_only=True)

     TransposeReduce().find_and_replace_pattern(graph)
     flag, resp = compare_graphs(graph, expected, 'convolution',
                                 check_op_attrs=True)
     self.assertTrue(flag, resp)
Ejemplo n.º 23
0
 def setUp(self):
     """Give every test its own graph built from the shared nodes/edges."""
     fresh_graph = build_graph(nodes, edges)
     self.graph = fresh_graph
Ejemplo n.º 24
0
    def test_remove_duplication_neibor(self):
        """Merge two neighbouring Splice ops that read the same input.

        Source graph: in_node (shape [1, 13]) feeds splice_1 with context
        range(-5, 1), i.e. 6 frames -> 78 features, and splice_2 with context
        range(0, 2), i.e. 2 frames -> 26 features.

        Expected after MergeNeighborSplicePattern: one Splice with context
        range(-5, 2) (7 frames -> 91 features). Its output is sliced along
        the last axis by two Crops:
          * crop_1: offset 0,  dim 78 -> frames -5..0, feeding placeholder_1
          * crop_2: offset 65, dim 26 -> frames 0..1 (65 = 5 * 13), feeding
            placeholder_2
        """
        graph = build_graph(
            {
                'input': {
                    'kind': 'op',
                    'op': 'Parameter'
                },
                'in_node': {
                    'kind': 'data',
                    'shape': [1, 13]
                },
                # 6 frames (-5..0) x 13 features = 78
                'splice_1': {
                    'kind': 'op',
                    'op': 'Splice',
                    'context': range(-5, 1)
                },
                'splice_data_1': {
                    'kind': 'data',
                    'shape': [1, 78],
                    'value': None
                },
                'placeholder_1': {
                    'kind': 'op',
                    'op': None
                },
                # 2 frames (0..1) x 13 features = 26
                'splice_2': {
                    'kind': 'op',
                    'op': 'Splice',
                    'context': range(0, 2)
                },
                'splice_data_2': {
                    'kind': 'data',
                    'shape': [1, 26],
                    'value': None
                },
                'placeholder_2': {
                    'kind': 'op',
                    'op': None
                },
            }, [
                ('input', 'in_node'),
                ('in_node', 'splice_1'),
                ('splice_1', 'splice_data_1'),
                ('splice_data_1', 'placeholder_1'),
                ('in_node', 'splice_2'),
                ('splice_2', 'splice_data_2'),
                ('splice_data_2', 'placeholder_2'),
            ],
            nodes_with_edges_only=True)
        MergeNeighborSplicePattern().find_and_replace_pattern(graph)
        # Reference: a single merged Splice followed by one Crop per original consumer.
        ref_graph = build_graph(
            {
                'input': {
                    'kind': 'op',
                    'op': 'Parameter'
                },
                'in_node': {
                    'kind': 'data',
                    'shape': [1, 13]
                },
                # Union of both contexts: 7 frames (-5..1) x 13 = 91
                'splice_1': {
                    'kind': 'op',
                    'op': 'Splice',
                    'context': range(-5, 2)
                },
                'splice_data_1': {
                    'kind': 'data',
                    'shape': [1, 91],
                    'value': None
                },
                # First 6 frames (-5..0) -> original splice_1 output
                'crop_1': {
                    'kind': 'op',
                    'op': 'Crop',
                    'offset': 0,
                    'dim': 78,
                    'axis': -1
                },
                'crop_1_data': {
                    'kind': 'data',
                    'shape': [1, 78]
                },
                'placeholder_1': {
                    'kind': 'op'
                },
                # Skip 5 frames (5 * 13 = 65), take frames 0..1 -> original splice_2 output
                'crop_2': {
                    'kind': 'op',
                    'op': 'Crop',
                    'offset': 65,
                    'dim': 26,
                    'axis': -1
                },
                'splice_data_2': {
                    'kind': 'data',
                    'shape': [1, 26],
                    'value': None
                },
                'placeholder_2': {
                    'kind': 'op'
                },
            }, [
                ('input', 'in_node'),
                ('in_node', 'splice_1'),
                ('splice_1', 'splice_data_1'),
                ('splice_data_1', 'crop_1'),
                ('crop_1', 'crop_1_data'),
                ('crop_1_data', 'placeholder_1'),
                ('splice_data_1', 'crop_2'),
                ('crop_2', 'splice_data_2'),
                ('splice_data_2', 'placeholder_2'),
            ],
            nodes_with_edges_only=True)

        # Compare the transformed graph with the reference, anchored at placeholder_2.
        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_2')
        self.assertTrue(flag, resp)
Ejemplo n.º 25
0
 def test_infer_invalid4(self):
     """CTCGreedyDecoderSeqLen inference must assert on the inputs4_inv setup."""
     graph = build_graph(nodes_attributes, edges1, inputs4_inv)
     decoder = Node(graph, 'ctcgreedydecoder_node')
     with self.assertRaises(AssertionError):
         CTCGreedyDecoderSeqLenOp.infer(decoder)
Ejemplo n.º 26
0
    def test_remove_duplication(self):
        """Replace a Splice whose context lies inside another Splice's context.

        splice_1 (context range(-5, 6): 11 frames -> 143 features) and
        splice_2 (context range(-1, 2): 3 frames -> 39 features) both read
        the same [1, 13] input. After RemoveMemoryDuplicationPattern, the
        reference has no splice_2. Its consumer is fed instead by a Crop of
        splice_1's output with offset 52 and dim 39. Frame -1 is the 5th
        frame of -5..5, so the offset is 4 * 13 = 52.
        """
        graph = build_graph(
            {
                'input': {
                    'kind': 'op',
                    'op': 'Parameter'
                },
                'in_node': {
                    'kind': 'data',
                    'shape': [1, 13]
                },
                # 11 frames (-5..5) x 13 features = 143
                'splice_1': {
                    'kind': 'op',
                    'op': 'Splice',
                    'context': range(-5, 6)
                },
                'splice_data_1': {
                    'kind': 'data',
                    'shape': [1, 143]
                },
                'placeholder_1': {
                    'kind': 'op',
                    'op': None
                },
                # 3 frames (-1..1) x 13 features = 39; fully covered by splice_1
                'splice_2': {
                    'kind': 'op',
                    'op': 'Splice',
                    'context': range(-1, 2)
                },
                'splice_data_2': {
                    'kind': 'data',
                    'shape': [1, 39]
                },
                'placeholder_2': {
                    'kind': 'op',
                    'op': None
                },
            }, [
                ('input', 'in_node'),
                ('in_node', 'splice_1'),
                ('splice_1', 'splice_data_1'),
                ('splice_data_1', 'placeholder_1'),
                ('in_node', 'splice_2'),
                ('splice_2', 'splice_data_2'),
                ('splice_data_2', 'placeholder_2'),
            ],
            nodes_with_edges_only=True)
        RemoveMemoryDuplicationPattern().find_and_replace_pattern(graph)
        # Reference: splice_1 unchanged; splice_2 replaced by a Crop over splice_1's output.
        ref_graph = build_graph(
            {
                'input': {
                    'kind': 'op',
                    'op': 'Parameter'
                },
                'in_node': {
                    'kind': 'data',
                    'shape': [1, 13]
                },
                'splice_1': {
                    'kind': 'op',
                    'op': 'Splice',
                    'context': range(-5, 6)
                },
                'splice_data_1': {
                    'kind': 'data',
                    'shape': [1, 143]
                },
                'placeholder_1': {
                    'kind': 'op'
                },
                # Skip frames -5..-2 (4 * 13 = 52), keep frames -1..1 (39 features)
                'crop_2': {
                    'kind': 'op',
                    'op': 'Crop',
                    'offset': 52,
                    'dim': 39,
                    'axis': -1
                },
                'splice_data_2': {
                    'kind': 'data',
                    'shape': [1, 39]
                },
                'placeholder_2': {
                    'kind': 'op'
                },
            }, [
                ('input', 'in_node'),
                ('in_node', 'splice_1'),
                ('splice_1', 'splice_data_1'),
                ('splice_data_1', 'placeholder_1'),
                ('splice_data_1', 'crop_2'),
                ('crop_2', 'splice_data_2'),
                ('splice_data_2', 'placeholder_2'),
            ],
            nodes_with_edges_only=True)

        # Compare the transformed graph with the reference, anchored at placeholder_2.
        (flag, resp) = compare_graphs(graph, ref_graph, 'placeholder_2')
        self.assertTrue(flag, resp)
Ejemplo n.º 27
0
    def test_conv_infer_3D_convolution(self):
        """Convolution.infer on a 5D NCHW-layout input fills in 3D defaults.

        Input [1, 3, 16, 224, 224] and weights [3, 64, 1, 7, 7] are used,
        with stride, dilation and bias_term left as None. Inference is
        expected to produce:
          * unit strides and dilations;
          * zero pads;
          * 64 output channels;
          * output shape [1, 64, 16, 218, 218].
        """
        conv_attrs = {
            'type': 'Convolution',
            'bias_term': None,
            'stride': None,
            'dilation': None,
            'batch_dims': int64_array([0]),
            'channel_dims': int64_array([1]),
            'output_spatial_shape': None,
            'input_feature_channel': 0,
            'output_feature_channel': 1,
            'group': 1,
            'output_shape': None,
            'layout': 'NCHW',
        }
        graph = build_graph(
            nodes_attributes,
            [
                ('conv_input', 'conv_node'),
                ('conv_weights', 'conv_node'),
                ('conv_node', 'conv_output'),
                ('conv_output', 'op_output'),
            ],
            {
                'conv_output': {'shape': None},
                'conv_input': {'shape': int64_array([1, 3, 16, 224, 224])},
                'conv_weights': {
                    'shape': int64_array([3, 64, 1, 7, 7]),
                    'dim_attrs': ['spatial_dims', 'channel_dims', 'batch_dims', 'axis'],
                },
                'conv_node': conv_attrs,
            })

        conv_node = Node(graph, 'conv_node')
        conv_output = Node(graph, 'conv_output')

        Convolution.infer(conv_node)

        # bias_term must end up set and falsy.
        self.assertTrue(conv_node.has_valid('bias_term'))
        self.assertTrue(not conv_node.bias_term)

        # Spatial metadata detected for the three spatial axes.
        for attr_name, expected in (
                ('kernel_spatial_idx', int64_array([2, 3, 4])),
                ('spatial_dims', int64_array([2, 3, 4])),
                ('kernel_spatial', int64_array([1, 7, 7])),
        ):
            self.assertTrue(conv_node.has_valid(attr_name))
            self.assertTrue(np.array_equal(expected, getattr(conv_node, attr_name)))

        # Output channel count (weights dim at output_feature_channel = 1).
        self.assertTrue(conv_node.has_valid('output'))
        self.assertEqual(64, conv_node.output)

        # Defaults filled in because the node specified none.
        for attr_name, expected in (
                ('dilation', int64_array([1, 1, 1, 1, 1])),
                ('stride', int64_array([1, 1, 1, 1, 1])),
                ('pad', int64_array([[0, 0], [0, 0], [0, 0], [0, 0], [0, 0]])),
                ('pad_spatial_shape', int64_array([[0, 0], [0, 0], [0, 0]])),
        ):
            self.assertTrue(conv_node.has_valid(attr_name))
            self.assertTrue(np.array_equal(expected, getattr(conv_node, attr_name)))

        # Resulting output shape: 224 - 7 + 1 = 218 on each 7-wide axis.
        self.assertTrue(np.array_equal(int64_array([1, 64, 16, 218, 218]), conv_output.shape))
Ejemplo n.º 28
0
 def test_one_input_no_shape(self):
     """Redirect conv_1's input without giving a shape.

     After add_input_ops (before_infer=False):
       * conv_1 must be fed by a new data node whose producer is marked as
         an input;
       * old_input must no longer be marked as an input;
       * the old edge old_input_data -> conv_1 must be gone;
       * the new data node must still carry a shape.
     """
     nodes = {
         'old_input': {'type': 'Placeholder', 'kind': 'op', 'op': 'Placeholder'},
         'old_input_data': {'kind': 'data', 'value': None,
                            'shape': np.array([-1, 224, 224, 3])},
         'conv_1': {'type': 'Convolution', 'kind': 'op', 'op': 'NotPlaceholder'},
         'conv_1_data': {'kind': 'data', 'value': True,
                         'shape': np.array([-1, 224, 224, 3])},
         'relu_1': {'type': 'ReLU', 'kind': 'op', 'op': 'NotPlaceholder'},
         'relu_1_data': {'kind': 'data', 'value': None,
                         'shape': np.array([-1, 112, 112, 64])},
         'output': {'type': 'SoftMax', 'kind': 'op', 'op': 'NotPlaceholder'},
         'output_data': {'name': 'output_data', 'kind': 'data',
                         'shape': np.array([-1, 112, 112, 64])},
         'op_output': {'kind': 'op', 'op': 'OpOutput'},
     }
     edges = [('old_input', 'old_input_data'), ('old_input_data', 'conv_1'),
              ('conv_1', 'conv_1_data'), ('conv_1_data', 'relu_1'),
              ('relu_1', 'relu_1_data'), ('relu_1_data', 'output'),
              ('output', 'output_data'), ('output_data', 'op_output')]
     graph = build_graph(nodes, edges)

     add_input_ops(graph=graph,
                   user_defined_inputs={'conv_1': [{'shape': None}]},
                   before_infer=False)

     # conv_1 <- new_input_data <- new_input
     new_input_data = list(graph.in_edges('conv_1'))[0][0]
     new_input = list(graph.in_edges(new_input_data))[0][0]

     self.assertFalse(graph.node['old_input']['is_input'])
     self.assertTrue(graph.node[new_input]['is_input'])
     self.assertTrue((new_input_data, 'conv_1') in graph.edges())
     self.assertTrue(('old_input_data', 'conv_1') not in graph.edges())
     self.assertIsNotNone(graph.node[new_input_data]['shape'])
Ejemplo n.º 29
0
 def test_infer(self, input_value, exp_value, axis=-1):
     """Run OneHot.infer and compare the value stored on 'one_hot_d'.

     ``input_value`` is wrapped with int64_array before building the graph.
     The computed value is cast to int64 before comparing with ``exp_value``.
     """
     graph = build_graph(generate_nodes(int64_array(input_value), axis), edges)
     OneHot.infer(Node(graph, 'one_hot'))
     computed = int64_array(graph.node['one_hot_d']['value'])
     self.assertTrue(np.array_equal(exp_value, computed))
Ejemplo n.º 30
0
    def test_scaleshift_to_mul_add(self):
        """ScaleShift(w=[1, 2, 3], b=[3, 2, 1]) becomes Mul(w) then Add(b).

        Under NHWC layout, the reference expects both new ops to be marked
        can_be_fused=True. The Add must write straight into the original
        scaleshift_1_data node.
        """
        nhwc_shape = [1, 227, 227, 3]
        scale = [1, 2, 3]
        shift = [3, 2, 1]

        # Edge order is kept exactly as written; port numbering may depend on it.
        src_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'scaleshift_1'),
            ('scaleshift_1_w', 'scaleshift_1'),
            ('scaleshift_1_b', 'scaleshift_1'),
            ('scaleshift_1', 'scaleshift_1_data'),
        ]
        src_attrs = {
            'placeholder_1_data': {'shape': np.array(nhwc_shape)},
            'scaleshift_1_w': {'shape': np.array([3]), 'value': np.array(scale)},
            'scaleshift_1_b': {'shape': np.array([3]), 'value': np.array(shift)},
            'scaleshift_1_data': {'is_output': True},
        }
        graph = build_graph(nodes_attributes, src_edges, src_attrs)

        ref_edges = [
            ('placeholder_1', 'placeholder_1_data'),
            ('placeholder_1_data', 'mul_1'),
            ('mul_1_w', 'mul_1'),
            ('mul_1', 'mul_1_data'),
            ('mul_1_data', 'add_1'),
            ('add_1_w', 'add_1'),
            ('add_1', 'scaleshift_1_data'),
        ]
        ref_attrs = {
            'placeholder_1_data': {'shape': np.array(nhwc_shape)},
            'mul_1_w': {'shape': np.array([3]), 'value': np.array(scale)},
            'add_1_w': {'shape': np.array([3]), 'value': np.array(shift)},
            'mul_1_data': {'shape': np.array(nhwc_shape)},
            'add_1': {'can_be_fused': True},
            'mul_1': {'can_be_fused': True},
            'scaleshift_1_data': {'is_output': True},
        }
        expected = build_graph(nodes_attributes, ref_edges, ref_attrs)

        graph.graph['layout'] = 'NHWC'
        convert_scale_shift_to_mul_add(graph)
        graph_clean_up(graph)
        flag, resp = compare_graphs(graph, expected, 'scaleshift_1_data')
        self.assertTrue(flag, resp)