    def test_gather_tree_normalizer(self):
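        # the normalizer is expected to squeeze the 1-D fourth input (port 3) of
        # GatherTree down to a scalar via a Squeeze with axis 0, as the reference
        # graph below encodes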
        nodes = {
            **regular_op_with_shaped_data('data_0', [100, 1, 10], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('data_1', [100, 1, 10], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('data_2', [1], {'type': 'Parameter'}),
            **regular_op_with_shaped_data('gather_tree', [1], {'type': 'GatherTree'}),
            **valued_const_with_data('const', np.array([2])),
            **result('result'),
        }
        edges = [*connect('data_0', '0:gather_tree'),
                 *connect('data_1', '1:gather_tree'),
                 *connect('data_2', '2:gather_tree'),
                 *connect('const', '3:gather_tree'),
                 *connect('gather_tree', 'result'),
                 ]
        ref_edges = [*connect('data_0', '0:gather_tree'),
                     *connect('data_1', '1:gather_tree'),
                     *connect('data_2', '2:gather_tree'),
                     *connect('const', '0:squeeze'),
                     *connect('squeeze_axis', '1:squeeze'),
                     *connect('squeeze', '3:gather_tree'),
                     *connect('gather_tree', 'result'),]
        ref_nodes = nodes.copy()
        ref_nodes.update({**valued_const_with_data('squeeze_axis', int64_array([0])),
                          **regular_op_with_shaped_data('squeeze', [], {'type': 'Squeeze'})})
        graph = build_graph(nodes, edges)
        GatherTreeNormalizer().find_and_replace_pattern(graph)
        # run shape inference to make sure that shape overriding happened
        shape_inference(graph)

        ref_graph = build_graph(ref_nodes, ref_edges)

        (flag, resp) = compare_graphs(graph, ref_graph, 'result')
        self.assertTrue(flag, resp)
Example #2
    def test_non_training(self):
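        # with 'is_training' set to False the transformation must not change anything,
        # so the reference graph is simply a copy of the original graph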
        graph = build_graph(nodes_attributes, [
            ('placeholder', 'placeholder_data', {}),
            ('scale', 'scale_data'),
            ('offset', 'offset_data'),
            ('mean', 'mean_data'),
            ('variance', 'variance_data'),
            ('placeholder_data', 'batchnorm', {
                'in': 0
            }),
            ('scale_data', 'batchnorm', {
                'in': 1
            }),
            ('offset_data', 'batchnorm', {
                'in': 2
            }),
            ('mean_data', 'batchnorm', {
                'in': 3
            }),
            ('variance_data', 'batchnorm', {
                'in': 4
            }),
            ('batchnorm', 'batchnorm_data'),
            ('batchnorm_data', 'result'),
        ], {'batchnorm': {
            'is_training': False
        }},
                            nodes_with_edges_only=True)
        graph_ref = graph.copy()

        FusedBatchNormTraining().find_and_replace_pattern(graph)
        shape_inference(graph)

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'result',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)
Example #3
def apply_transform(graph: Graph, replacer_cls, **kwargs):
    """
    Safely executes transform if it should be and validates graph after transform execution
    """
    replacer = replacer_cls()
    replacement_id = 'REPLACEMENT_ID'
    if hasattr(replacer, 'replacement_id'):
        replacement_id = replacer.replacement_id

    if hasattr(replacer, 'enabled') and not replacer.enabled:
        log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
        return

    if hasattr(replacer, 'graph_condition') and \
            not all(condition(graph) for condition in replacer.graph_condition):
        log.info("Skip replacer {} (graph_condition not satisfied)".format(
            replacer_cls))
        return

    log.debug("Run replacer {}".format(replacer_cls))

    try:
        if hasattr(replacer,
                   'run_not_recursively') and replacer.run_not_recursively:
            replacer.find_and_replace_pattern(graph)
        else:
            for_graph_and_each_sub_graph_recursively(
                graph, replacer.find_and_replace_pattern)

        if hasattr(replacer, 'force_clean_up') and replacer.force_clean_up:
            for_graph_and_each_sub_graph_recursively(graph,
                                                     lambda G: G.clean_up())

        if hasattr(replacer,
                   'force_shape_inference') and replacer.force_shape_inference:
            shape_inference(graph)

        if hasattr(replacer,
                   'run_not_recursively') and replacer.run_not_recursively:
            graph.check_empty_graph(replacer_cls)
            graph.check_shapes_consistency()
        else:
            for_graph_and_each_sub_graph_recursively(
                graph, lambda G: G.check_empty_graph(replacer_cls))
            for_graph_and_each_sub_graph_recursively(
                graph, lambda G: G.check_shapes_consistency())

    except Error as err:
        raise Error(
            'Exception occurred while running replacer "{}" ({}): {}'.format(
                replacement_id,
                replacer_cls,
                str(err).replace('[REPLACEMENT_ID]', replacement_id),
            )) from err
    except FrameworkError as err:
        raise FrameworkError('{}'.format(str(err))) from err
    except Exception as err:
        raise Exception(
            'Exception occurred while running replacer "{}" ({}): {}'.format(
                replacement_id,
                replacer_cls,
                str(err).replace('[REPLACEMENT_ID]', replacement_id),
            )) from err
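
For context, a minimal usage sketch of apply_transform follows. NoOpReplacer is a hypothetical class used only to illustrate the optional attributes the function inspects (enabled, graph_condition, run_not_recursively, force_clean_up, force_shape_inference); a real transform would implement its rewriting logic in find_and_replace_pattern.

# Hypothetical no-op replacer, shown only to illustrate the attributes that
# apply_transform() checks; it performs no real graph rewriting.
class NoOpReplacer:
    replacement_id = 'NO_OP_REPLACER'  # substituted into error messages
    enabled = True                     # False makes apply_transform skip the replacer
    graph_condition = []               # optional predicates that must all hold on the graph
    run_not_recursively = True         # apply to the top-level graph only
    force_clean_up = False             # no clean_up pass after the transform
    force_shape_inference = False      # no extra shape inference pass

    def find_and_replace_pattern(self, graph):
        pass  # a real transform would modify 'graph' here


# apply_transform instantiates the class itself, so the class object is passed:
# apply_transform(graph, NoOpReplacer)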
Example #4
    def test_transformation(self, op: str):
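        # for a training batch norm node the transformation is expected to normalize
        # the data input on the fly (Reshape -> MVN over axes built from start/stop/step,
        # then Reshape back to the original shape taken from ShapeOf) before feeding
        # it to the fused op on port 0, as encoded in the reference graph below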
        graph = build_graph(nodes_attributes, [
            ('placeholder', 'placeholder_data', {}),
            ('scale', 'scale_data'),
            ('offset', 'offset_data'),
            ('mean', 'mean_data'),
            ('variance', 'variance_data'),
            ('placeholder_data', 'batchnorm', {
                'in': 0
            }),
            ('scale_data', 'batchnorm', {
                'in': 1
            }),
            ('offset_data', 'batchnorm', {
                'in': 2
            }),
            ('mean_data', 'batchnorm', {
                'in': 3
            }),
            ('variance_data', 'batchnorm', {
                'in': 4
            }),
            ('batchnorm', 'batchnorm_data'),
            ('batchnorm_data', 'result'),
        ], {},
                            nodes_with_edges_only=True)
        graph.nodes['batchnorm']['op'] = op
        graph_ref = build_graph(nodes_attributes, [
            ('placeholder', 'placeholder_data', {}),
            ('scale', 'scale_data'),
            ('offset', 'offset_data'),
            ('bn_mean', 'bn_mean_data'),
            ('bn_variance', 'bn_variance_data'),
            ('scale_data', 'batchnorm', {
                'in': 1
            }),
            ('offset_data', 'batchnorm', {
                'in': 2
            }),
            ('bn_mean_data', 'batchnorm', {
                'in': 3
            }),
            ('bn_variance_data', 'batchnorm', {
                'in': 4
            }),
            ('placeholder_data', 'reshape_1', {
                'in': 0
            }),
            ('reshape_1_const', 'reshape_1_const_data'),
            ('reshape_1_const_data', 'reshape_1', {
                'in': 1
            }),
            ('reshape_1', 'reshape_1_data', {}),
            ('reshape_1_data', 'mvn', {
                'in': 0
            }),
            ('mvn', 'mvn_data'),
            ('mvn_data', 'reshape_to_orig', {
                'in': 0
            }),
            ('start', 'start_data'),
            ('start_data', 'mvn_axes'),
            ('stop', 'stop_data'),
            ('stop_data', 'mvn_axes'),
            ('step', 'step_data'),
            ('step_data', 'mvn_axes'),
            ('mvn_axes', 'mvn_axes_data'),
            ('mvn_axes_data', 'mvn'),
            ('placeholder_data', 'shapeof', {
                'in': 0
            }),
            ('shapeof', 'shapeof_data'),
            ('shapeof_data', 'reshape_to_orig', {
                'in': 1
            }),
            ('reshape_to_orig', 'reshape_to_orig_data'),
            ('reshape_to_orig_data', 'batchnorm', {
                'in': 0
            }),
            ('batchnorm', 'batchnorm_data'),
            ('batchnorm_data', 'result'),
        ], {
            'batchnorm': {
                'is_training': False
            },
        },
                                nodes_with_edges_only=True)
        FusedBatchNormTraining().find_and_replace_pattern(graph)
        shape_inference(graph)

        graph_ref.nodes['batchnorm']['op'] = op

        (flag, resp) = compare_graphs(graph,
                                      graph_ref,
                                      'result',
                                      check_op_attrs=True)
        self.assertTrue(flag, resp)