Beispiel #1
0
    def test6(self):
        """ReduceSum(axis=-2, keep_dims=True) over 64 elements is lowered to
        Reshape -> Pool -> Reshape -> Power(scale=64).
        """
        #   Original graph
        #   data(1,64,1)-->Reduce(axis=-2,keep_dims=True, reduce_type=Sum)-->data(1,1,1)
        #
        #   Reference graph
        #   data(1,64,1)->Reshape(1,1,8,8)->Pool(1,1,1,1)->Reshape(1,1,1)->Power(scale=64)
        #
        #   NOTE(review): Power(scale=64) multiplies the pooled value by the number of
        #   reduced elements (64), which suggests the Pool computes an average -- confirm
        #   against ReduceReplacer.  (Original comment said "data(1,61,1)" -- typo for 64.)
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'reduce_1'),
                             ('const', 'const_data'),
                             ('const_data', 'reduce_1', {'in': 1}),
                             ('reduce_1', 'reduce_1_data'),
                             ('reduce_1_data', 'concat'),
                             ],
                            {'placeholder_1': {'shape': int64_array([1, 64, 1])},
                             'placeholder_1_data': {'shape': int64_array([1, 64, 1])},
                             'reduce_1': {'keep_dims': True, 'type': 'ReduceSum'},
                             'const_data': {'value': int64_array([-2])},
                             'reduce_1_data': {'shape': int64_array([1, 1, 1])},
                             }, nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'

        # Reference graph: a leading 0 in a Reshape constant means "copy that dim
        # from the input" -- TODO confirm against the Reshape op semantics used here.
        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'reshape_1'),
                                 ('reshape_1_const', 'reshape_1_const_data'),
                                 ('reshape_1_const_data', 'reshape_1'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_1_data', 'pooling'),
                                 ('pooling', 'pooling_data'),
                                 ('pooling_data', 'reshape_2'),
                                 ('reshape_2_const', 'reshape_2_const_data'),
                                 ('reshape_2_const_data', 'reshape_2'),
                                 ('reshape_2', 'reshape_2_data'),
                                 ('reshape_2_data', 'power'),
                                 ('power', 'power_data'),
                                 ('power_data', 'concat'),
                                 ],
                                {'placeholder_1': {'shape': int64_array([1, 64, 1])},
                                 'placeholder_1_data': {'shape': int64_array([1, 64, 1])},
                                 'reshape_1_const': {'value': int64_array([0, 1, 8, 8]), 'shape': int64_array([4])},
                                 'reshape_1_const_data': {'value': int64_array([0, 1, 8, 8]),
                                                          'shape': int64_array([4])},
                                 'reshape_1_data': {'shape': int64_array([1, 1, 8, 8])},
                                 'pooling': {'window': int64_array([1, 1, 8, 8])},
                                 'pooling_data': {'shape': int64_array([1, 1, 1, 1])},
                                 'reshape_2_const': {'value': int64_array([0, 1, 1]), 'shape': int64_array([3])},
                                 'reshape_2_const_data': {'value': int64_array([0, 1, 1]), 'shape': int64_array([3])},
                                 'reshape_2_data': {'shape': int64_array([1, 1, 1])},
                                 'power': {'scale': 64.0},
                                 'power_data': {'shape': int64_array([1, 1, 1])},
                                 }, nodes_with_edges_only=True)

        # Apply the transformation and re-infer shapes before comparing.
        ReduceReplacer().find_and_replace_pattern(graph)
        shape_inference(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_consecutive_stride_slices_removal(self):
        """Two consecutive useless StridedSlice ops are both erased, leaving a
        direct placeholder -> output connection."""
        def slice_input_edges(slice_node):
            # begin/end/stride data inputs shared by both StridedSlice nodes
            return [('strided_slice_input_%d_data' % idx, slice_node)
                    for idx in range(1, 4)]

        edges = [('placeholder', 'placeholder_data'),
                 ('placeholder_data', 'strided_slice')]
        edges += slice_input_edges('strided_slice')
        edges += [('strided_slice', 'strided_slice_data'),
                  ('strided_slice_data', 'strided_slice_2')]
        edges += slice_input_edges('strided_slice_2')
        edges += [('strided_slice_2', 'strided_slice_2_data'),
                  ('strided_slice_2_data', 'output_op')]
        graph = build_graph(nodes_attributes, edges, {},
                            nodes_with_edges_only=True)

        UselessStridedSliceEraser().find_and_replace_pattern(graph)
        shape_inference(graph)

        # Reference: both slices removed, placeholder wired straight to the output.
        ref_edges = [('placeholder', 'placeholder_data'),
                     ('placeholder_data', 'output_op')]
        graph_ref = build_graph(nodes_attributes, ref_edges,
                                {'placeholder_data': {'shape': int64_array([4, 1, 6])}})

        flag, resp = compare_graphs(graph, graph_ref, 'output_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
Beispiel #3
0
    def test5(self):
        """ReduceMax(axis=5, keep_dims=False) over 4 elements is lowered to
        Reshape -> Pool(2x2 window) -> Reshape.
        """
        #   Original graph
        #   data(1, 16, 64, 64, 64, 4)-->Reduce(axis=[5],keep_dims=False)-->data(1, 16, 64, 64, 64)
        #
        #   Reference graph (4194304 == 16*64*64*64; leading 0 in a Reshape const
        #   means "copy that dim from the input"):
        #   data(1, 16, 64, 64, 64, 4)->Reshape(1, 4194304, 2, 2)->Pool(window 1, 1, 2, 2)->Reshape(1, 16, 64, 64, 64)
        #
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'reduce_1'),
                             ('const', 'const_data'),
                             ('const_data', 'reduce_1', {'in': 1}),
                             ('reduce_1', 'reduce_1_data'),
                             ('reduce_1_data', 'concat'),
                             ],
                            {'placeholder_1': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                             'placeholder_1_data': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                             'reduce_1': {'keep_dims': False, 'type': 'ReduceMax'},
                             'const_data': {'value': int64_array([5])},
                             'reduce_1_data': {'shape': int64_array([1, 16, 64, 64, 64])},
                             }, nodes_with_edges_only=True)

        graph.graph['layout'] = 'NCHW'

        graph_ref = build_graph(nodes_attributes,
                                [('placeholder_1', 'placeholder_1_data'),
                                 ('placeholder_1_data', 'reshape_1'),
                                 ('reshape_1_const', 'reshape_1_const_data'),
                                 ('reshape_1_const_data', 'reshape_1'),
                                 ('reshape_1', 'reshape_1_data'),
                                 ('reshape_1_data', 'pooling'),
                                 ('pooling', 'pooling_data'),
                                 ('pooling_data', 'reshape_2'),
                                 ('reshape_2_const', 'reshape_2_const_data'),
                                 ('reshape_2_const_data', 'reshape_2'),
                                 ('reshape_2', 'reshape_2_data'),
                                 ('reshape_2_data', 'concat'),
                                 ],
                                {'placeholder_1': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                                 'placeholder_1_data': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                                 'reshape_1_const': {'value': int64_array([0, 4194304, 2, 2]),
                                                     'shape': int64_array([4])},
                                 'reshape_1_const_data': {'value': int64_array([0, 4194304, 2, 2]),
                                                          'shape': int64_array([4])},
                                 'reshape_1_data': {'shape': int64_array([1, 4194304, 2, 2])},
                                 'pooling': {'window': int64_array([1, 1, 2, 2])},
                                 'pooling_data': {'shape': int64_array([1, 4194304, 1, 1])},
                                 'reshape_2_const': {'value': int64_array([0, 16, 64, 64, 64]),
                                                     'shape': int64_array([5])},
                                 'reshape_2_const_data': {'value': int64_array([0, 16, 64, 64, 64]),
                                                          'shape': int64_array([5])},
                                 'reshape_2_data': {'shape': int64_array([1, 16, 64, 64, 64])},
                                 }, nodes_with_edges_only=True)

        # Apply the transformation and re-infer shapes before comparing.
        ReduceReplacer().find_and_replace_pattern(graph)
        shape_inference(graph)

        (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
        self.assertTrue(flag, resp)
Beispiel #4
0
def apply_transform(graph: Graph, replacer_cls, **kwargs):
    """
    Safely executes transform if it should be and validates graph after transform execution.

    :param graph: graph to transform (modified in place).
    :param replacer_cls: replacer class; instantiated here, may expose optional
        attributes 'replacement_id', 'enabled', 'graph_condition',
        'run_not_recursively', 'force_clean_up', 'force_shape_inference'.
    :raises Error: re-raised with the replacer identified in the message.
    :raises Exception: any other failure, likewise wrapped with the replacer id.
    """
    replacer = replacer_cls()
    # 'replacement_id' only serves to make error messages identify the replacer.
    replacement_id = getattr(replacer, 'replacement_id', 'REPLACEMENT_ID')

    if not getattr(replacer, 'enabled', True):
        log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
        return

    graph_condition = getattr(replacer, 'graph_condition', None)
    if graph_condition and not all(condition(graph) for condition in graph_condition):
        log.info("Skip replacer {} (graph_condition not satisfied)".format(
            replacer_cls))
        return

    log.debug("Run replacer {}".format(replacer_cls))

    try:
        # NOTE(review): only the *presence* of 'run_not_recursively' is checked,
        # not its value -- kept as-is to preserve the original behavior.
        if hasattr(replacer, 'run_not_recursively'):
            replacer.find_and_replace_pattern(graph)
        else:
            for_graph_and_each_sub_graph_recursively(
                graph, replacer.find_and_replace_pattern)

        if getattr(replacer, 'force_clean_up', False):
            for_graph_and_each_sub_graph_recursively(graph,
                                                     lambda G: G.clean_up())

        if getattr(replacer, 'force_shape_inference', False):
            shape_inference(graph)

        # Validate the main graph and every sub-graph after the transform.
        for_graph_and_each_sub_graph_recursively(
            graph, lambda _: graph.check_empty_graph(replacer_cls))
        for_graph_and_each_sub_graph_recursively(
            graph, lambda _: graph.check_shapes_consistency())

    except Error as err:
        raise Error(
            'Exception occurred during running replacer "{}" ({}): {}'.format(
                replacement_id,
                replacer_cls,
                str(err).replace('[REPLACEMENT_ID]', replacement_id),
            )) from err
    except Exception as err:
        # Fixed quote placement to match the Error branch: '"{}" ({})', not '"{} ({})"'.
        raise Exception(
            'Exception occurred during running replacer "{}" ({}): {}'.format(
                replacement_id,
                replacer_cls,
                str(err).replace('[REPLACEMENT_ID]', replacement_id),
            )) from err
    def test_single_stride_slice_with_shrink_and_new_removal(self):
        """A StridedSlice whose only effect comes from shrink_axis_mask and
        new_axis_mask is replaced by an explicit Unsqueeze + Squeeze pair.
        """
        graph = build_graph(nodes_attributes, [
            ('placeholder', 'placeholder_data'),
            ('placeholder_data', 'strided_slice'),
            ('strided_slice_input_1_data', 'strided_slice'),
            ('strided_slice_input_2_data', 'strided_slice'),
            ('strided_slice_input_3_data', 'strided_slice'),
            ('strided_slice', 'strided_slice_data'),
            ('strided_slice_data', 'output_op'),
        ], {
            # The slice removes one axis and inserts one, so the output shape
            # (4, 1, 6) matches the input -- the slicing itself is useless.
            'strided_slice': {
                'shrink_axis_mask': int64_array([0, 1, 0, 0]),
                'new_axis_mask': int64_array([0, 0, 1, 0])
            },
            'strided_slice_data': {
                'shape': int64_array([4, 1, 6])
            }
        },
                            nodes_with_edges_only=True)
        graph.graph['layout'] = 'NCHW'

        UselessStridedSliceEraser().find_and_replace_pattern(graph)
        shape_inference(graph)

        # Reference: the masks survive as Unsqueeze(axis=2) -> (4, 1, 1, 6) -> Squeeze.
        graph_ref = build_graph(nodes_attributes,
                                [('placeholder', 'placeholder_data'),
                                 ('placeholder_data', 'unsqueeze'),
                                 ('unsqueeze_const', 'unsqueeze_const_data'),
                                 ('unsqueeze_const_data', 'unsqueeze'),
                                 ('unsqueeze', 'unsqueeze_data'),
                                 ('unsqueeze_data', 'squeeze'),
                                 ('squeeze_const', 'squeeze_const_data'),
                                 ('squeeze_const_data', 'squeeze'),
                                 ('squeeze', 'strided_slice_data'),
                                 ('strided_slice_data', 'output_op')], {
                                     'placeholder_data': {
                                         'shape': int64_array([4, 1, 6])
                                     },
                                     'unsqueeze_data': {
                                         'shape': int64_array([4, 1, 1, 6])
                                     },
                                     'strided_slice_data': {
                                         'shape': int64_array([4, 1, 6])
                                     },
                                     'unsqueeze_const': {
                                         'value': int64_array([2])
                                     },
                                 },
                                nodes_with_edges_only=True)
        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output_op',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Beispiel #6
0
    def test_tile_reshaper(self):
        """TileReshaper rewrites the graph into the module-level reference form."""
        graph = build_graph(nodes_attributes, edge_attributes)
        graph_ref = build_graph(nodes_attributes_ref, edge_attributes_ref)

        # Run the transformation, then re-infer shapes before comparing.
        TileReshaper().find_and_replace_pattern(graph)
        shape_inference(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'next_op',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
Beispiel #7
0
    def graph_clean_up(graph: Graph, undead_node_types: list = None):
        """Eliminates dead nodes from the graph and re-runs shape inference.

        :param graph: graph to clean up (modified in place).
        :param undead_node_types: node types that must survive dead-node
            elimination; 'Shape' is never kept alive. The caller's list is
            not modified.
        """
        # Work on a copy: the original code called .remove() on the caller's
        # list, silently mutating the argument across calls.
        undead_node_types = [] if undead_node_types is None else list(undead_node_types)

        if 'Shape' in undead_node_types:
            undead_node_types.remove('Shape')

        mark_output_reachable_nodes(graph)
        mark_undead_nodes(graph, undead_node_types)
        mark_const_producer_nodes(graph)
        eliminate_dead_nodes(graph)
        # Add Const op for constant data nodes
        add_constant_operations(graph)
        shape_inference(graph)
Beispiel #8
0
    def test_gather_tree_normalizer(self):
        """GatherTree's 4th input gets a Squeeze(axis=0) inserted in front of it."""
        node_attrs = {}
        for name, shape, op_type in (('data_0', [100, 1, 10], 'Parameter'),
                                     ('data_1', [100, 1, 10], 'Parameter'),
                                     ('data_2', [1], 'Parameter'),
                                     ('gather_tree', [1], 'GatherTree')):
            node_attrs.update(
                regular_op_with_shaped_data(name, shape, {'type': op_type}))
        node_attrs.update(valued_const_with_data('const', np.array([2])))
        node_attrs.update(result('result'))

        edge_list = [
            *connect('data_0', '0:gather_tree'),
            *connect('data_1', '1:gather_tree'),
            *connect('data_2', '2:gather_tree'),
            *connect('const', '3:gather_tree'),
            *connect('gather_tree', 'result'),
        ]
        # Reference: the constant is routed through a Squeeze before port 3.
        expected_edges = [
            *connect('data_0', '0:gather_tree'),
            *connect('data_1', '1:gather_tree'),
            *connect('data_2', '2:gather_tree'),
            *connect('const', '0:squeeze'),
            *connect('squeeze_axis', '1:squeeze'),
            *connect('squeeze', '3:gather_tree'),
            *connect('gather_tree', 'result'),
        ]
        expected_nodes = dict(node_attrs)
        expected_nodes.update(valued_const_with_data('squeeze_axis', int64_array([0])))
        expected_nodes.update(
            regular_op_with_shaped_data('squeeze', [], {'type': 'Squeeze'}))

        graph = build_graph(node_attrs, edge_list)
        GatherTreeNormalizer().find_and_replace_pattern(graph)
        # run shape inference to make sure that shape overriding happened
        shape_inference(graph)

        ref_graph = build_graph(expected_nodes, expected_edges)

        flag, resp = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)
Beispiel #9
0
    def test_non_training(self):
        """FusedBatchNormTraining must leave a batch norm with is_training=False
        completely untouched (graph compares equal to a copy of itself)."""
        param_names = ['scale', 'offset', 'mean', 'variance']

        edges = [('placeholder', 'placeholder_data', {})]
        # data edges for the four batch-norm parameters
        edges += [(name, name + '_data') for name in param_names]
        # batch-norm inputs: data on port 0, parameters on ports 1..4
        edges.append(('placeholder_data', 'batchnorm', {'in': 0}))
        edges += [(name + '_data', 'batchnorm', {'in': port})
                  for port, name in enumerate(param_names, start=1)]
        edges += [('batchnorm', 'batchnorm_data'),
                  ('batchnorm_data', 'result')]

        graph = build_graph(nodes_attributes, edges,
                            {'batchnorm': {'is_training': False}},
                            nodes_with_edges_only=True)
        graph_ref = graph.copy()

        FusedBatchNormTraining().find_and_replace_pattern(graph)
        shape_inference(graph)

        flag, resp = compare_graphs(graph, graph_ref, 'result',
                                    check_op_attrs=True)
        self.assertTrue(flag, resp)
    def test_1(self):
        """ShuffleChannel(group=2) on (1, 10, 128, 128) is decomposed into
        Reshape(1, 2, 5, -1) -> Transpose(0, 2, 1, 3) -> Reshape(1, 10, 128, 128).
        """
        graph = build_graph(
            {
                'placeholder': {
                    'kind': 'op',
                    'op': 'Placeholder',
                    'type': 'Placeholder',
                    'shape': [1, 10, 128, 128]
                },
                'data': {
                    'shape': [1, 10, 128, 128],
                    'kind': 'data'
                },
                'shuffle': {
                    'type': 'ShuffleChannel',
                    'kind': 'op',
                    'op': 'ShuffleChannel',
                    'group': 2
                },
                'out_data': {
                    'shape': [1, 10, 128, 128],
                    'kind': 'data'
                },
                'output': {
                    'kind': 'op',
                    'op': 'OpOutput'
                }
            }, [('placeholder', 'data'), ('data', 'shuffle'),
                ('shuffle', 'out_data'), ('out_data', 'output')], {})
        graph.graph['layout'] = 'NCHW'

        # Reference decomposition: split the 10 channels into (group=2, 5),
        # swap the two channel axes, then restore the original shape.
        graph_ref = build_graph(
            {
                'placeholder': {
                    'kind': 'op',
                    'op': 'Placeholder',
                    'type': 'Placeholder',
                    'shape': [1, 10, 128, 128]
                },
                'data': {
                    'shape': [1, 10, 128, 128],
                    'kind': 'data'
                },
                'reshape': {
                    'type': 'Reshape',
                    'kind': 'op',
                    'op': 'Reshape'
                },
                'reshape_const': {
                    'type': 'Const',
                    'kind': 'op',
                    'op': 'Const',
                    'value': int64_array([1, 2, 5, -1]),
                    'shape': [4]
                },
                'reshape_const_data': {
                    'kind': 'data',
                    'value': [1, 2, 5, -1],
                    'shape': [4]
                },
                'reshape_data': {
                    'shape': [1, 2, 5, 128 * 128],
                    'kind': 'data'
                },
                'order_const': {
                    'kind': 'op',
                    'op': 'Const',
                    'value': np.array([0, 2, 1, 3])
                },
                'order_data': {
                    'kind': 'data',
                    'value': np.array([0, 2, 1, 3]),
                    'shape': np.array([4])
                },
                'transpose': {
                    'type': 'Transpose',
                    'kind': 'op',
                    'op': 'Transpose'
                },
                'transpose_data': {
                    'shape': [1, 5, 2, 128 * 128],
                    'kind': 'data'
                },
                'reshape_back': {
                    'type': 'Reshape',
                    'kind': 'op',
                    'op': 'Reshape'
                },
                'reshape_back_const': {
                    'type': 'Const',
                    'kind': 'op',
                    'op': 'Const',
                    'value': int64_array([1, 10, 128, 128]),
                    'shape': [4]
                },
                'reshape_back_const_data': {
                    'kind': 'data',
                    'value': [1, 10, 128, 128],
                    'shape': [4]
                },
                'out_data': {
                    'shape': [1, 10, 128, 128],
                    'kind': 'data'
                },
                'output': {
                    'kind': 'op',
                    'op': 'OpOutput'
                },
            }, [('placeholder', 'data'), ('data', 'reshape'),
                ('reshape_const', 'reshape_const_data'),
                ('reshape_const_data', 'reshape'), ('reshape', 'reshape_data'),
                ('order_const', 'order_data'),
                ('order_data', 'transpose', {
                    'in': 1
                }), ('reshape_data', 'transpose', {
                    'in': 0
                }), ('transpose', 'transpose_data'),
                ('transpose_data', 'reshape_back'),
                ('reshape_back_const', 'reshape_back_const_data'),
                ('reshape_back_const_data', 'reshape_back'),
                ('reshape_back', 'out_data'), ('out_data', 'output')], {})

        # Apply the transformation and re-infer shapes before comparing.
        ShuffleChannel().find_and_replace_pattern(graph)
        shape_inference(graph)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'output',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
def apply_replacements(graph: Graph, replacements_type):
    """
    Apply all patterns that do not have 'op' first, then apply patterns from registered_ops.
    If two or more classes replaces the same op (both have op class attribute and values match), such
    pattern is not applied (while registration it will warn user that we have a conflict).

    :param graph: graph to transform (modified in place).
    :param replacements_type: which registered class type to run; replacers are
        ordered topologically by their run_before()/run_after() constraints.
    :raises Error: re-raised with the offending replacer identified.
    :raises Exception: any other failure, likewise wrapped with the replacer id.
    """
    dependency_graph = DependencyGraph(name=ClassType(replacements_type).name)
    for class_type, classes_set in _registered_classes_dict.items():
        if class_type != replacements_type:
            continue
        # Patterns without an 'op' attribute come first, then op-specific ones.
        replacers = []
        for cls in classes_set:
            cur_cls_replacers = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
                                [c for op, c in cls.registered_ops.items() if c]
            replacers.extend(replacer for replacer in cur_cls_replacers
                             if replacer not in cls.excluded_replacers)

        for replacer_cls in replacers:
            dependency_graph.add_node(replacer_cls)

        for replacer_cls in replacers:
            # Instantiate once instead of once per run_before()/run_after() call.
            rep = replacer_cls()
            for cls_after in rep.run_before():
                log.debug("Replacer {} will be run before {}".format(
                    replacer_cls, cls_after))
                dependency_graph.add_edge(replacer_cls, cls_after)
            for cls_before in rep.run_after():
                log.debug("Replacer {} will be run after {}".format(
                    replacer_cls, cls_before))
                dependency_graph.add_edge(cls_before, replacer_cls)

    replacers_order = dependency_graph.determined_sort()
    for replacer_cls in replacers_order:
        replacer = replacer_cls()

        # 'replacement_id' only serves to make error messages identify the replacer.
        replacement_id = getattr(replacer, 'replacement_id', 'REPLACEMENT_ID')

        if not getattr(replacer, 'enabled', True):
            log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
            continue

        graph_condition = getattr(replacer, 'graph_condition', None)
        if graph_condition and not all(condition(graph) for condition in graph_condition):
            log.info("Skip replacer {} (graph_condition not satisfied)".format(
                replacer_cls))
            continue

        log.debug("Run replacer {}".format(replacer_cls))

        try:
            # NOTE(review): only the *presence* of 'run_not_recursively' is checked,
            # not its value -- kept as-is to preserve the original behavior.
            if hasattr(replacer, 'run_not_recursively'):
                replacer.find_and_replace_pattern(graph)
            else:
                for_graph_and_each_sub_graph_recursively(
                    graph, replacer.find_and_replace_pattern)

            if getattr(replacer, 'force_clean_up', False):
                # Pick the framework-specific clean-up routine.
                fw = graph.graph['fw']
                if fw == 'tf':
                    clean_up_fn = graph_clean_up_tf
                elif fw == 'onnx':
                    clean_up_fn = graph_clean_up_onnx
                else:
                    clean_up_fn = graph_clean_up
                for_graph_and_each_sub_graph_recursively(graph, clean_up_fn)

            if getattr(replacer, 'force_shape_inference', False):
                shape_inference(graph)

            # Validate the main graph and every sub-graph after the transform.
            for_graph_and_each_sub_graph_recursively(
                graph, lambda _: graph.check_empty_graph(replacer_cls))
            for_graph_and_each_sub_graph_recursively(
                graph, lambda _: graph.check_shapes_consistency())

        except Error as err:
            raise Error(
                'Exception occurred during running replacer "{}" ({}): {}'.
                format(
                    replacement_id,
                    replacer_cls,
                    str(err).replace('[REPLACEMENT_ID]', replacement_id),
                )) from err
        except Exception as err:
            # Fixed quote placement to match the Error branch: '"{}" ({})', not '"{} ({})"'.
            raise Exception(
                'Exception occurred during running replacer "{}" ({}): {}'.
                format(
                    replacement_id,
                    replacer_cls,
                    str(err).replace('[REPLACEMENT_ID]', replacement_id),
                )) from err
Beispiel #12
0
    def test_transformation(self, op: str):
        """FusedBatchNormTraining rewires a training batch norm of the given op
        type: the input is routed through Reshape -> MVN -> Reshape(ShapeOf), and
        the mean/variance inputs are replaced.

        :param op: batch-norm op name to assign to the 'batchnorm' node
            (the test appears to be parameterized over several op variants).
        """
        graph = build_graph(nodes_attributes, [
            ('placeholder', 'placeholder_data', {}),
            ('scale', 'scale_data'),
            ('offset', 'offset_data'),
            ('mean', 'mean_data'),
            ('variance', 'variance_data'),
            ('placeholder_data', 'batchnorm', {
                'in': 0
            }),
            ('scale_data', 'batchnorm', {
                'in': 1
            }),
            ('offset_data', 'batchnorm', {
                'in': 2
            }),
            ('mean_data', 'batchnorm', {
                'in': 3
            }),
            ('variance_data', 'batchnorm', {
                'in': 4
            }),
            ('batchnorm', 'batchnorm_data'),
            ('batchnorm_data', 'result'),
        ], {},
                            nodes_with_edges_only=True)
        graph.nodes['batchnorm']['op'] = op
        # Reference: ports 3/4 now take bn_mean/bn_variance (presumably constants
        # defined in nodes_attributes -- verify), and port 0 takes the input after
        # Reshape -> MVN -> Reshape-to-original-shape (via ShapeOf).
        graph_ref = build_graph(nodes_attributes, [
            ('placeholder', 'placeholder_data', {}),
            ('scale', 'scale_data'),
            ('offset', 'offset_data'),
            ('bn_mean', 'bn_mean_data'),
            ('bn_variance', 'bn_variance_data'),
            ('scale_data', 'batchnorm', {
                'in': 1
            }),
            ('offset_data', 'batchnorm', {
                'in': 2
            }),
            ('bn_mean_data', 'batchnorm', {
                'in': 3
            }),
            ('bn_variance_data', 'batchnorm', {
                'in': 4
            }),
            ('placeholder_data', 'reshape_1', {
                'in': 0
            }),
            ('reshape_1_const', 'reshape_1_const_data'),
            ('reshape_1_const_data', 'reshape_1', {
                'in': 1
            }),
            ('reshape_1', 'reshape_1_data', {}),
            ('reshape_1_data', 'mvn', {
                'in': 0
            }),
            ('mvn', 'mvn_data'),
            ('mvn_data', 'reshape_to_orig', {
                'in': 0
            }),
            ('placeholder_data', 'shapeof', {
                'in': 0
            }),
            ('shapeof', 'shapeof_data'),
            ('shapeof_data', 'reshape_to_orig', {
                'in': 1
            }),
            ('reshape_to_orig', 'reshape_to_orig_data'),
            ('reshape_to_orig_data', 'batchnorm', {
                'in': 0
            }),
            ('batchnorm', 'batchnorm_data'),
            ('batchnorm_data', 'result'),
        ], {
            'batchnorm': {
                'is_training': False
            },
        },
                                nodes_with_edges_only=True)
        FusedBatchNormTraining().find_and_replace_pattern(graph)
        shape_inference(graph)

        # The reference must carry the same op variant as the transformed graph.
        graph_ref.nodes['batchnorm']['op'] = op

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'result',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)