def test6(self):
    """ReduceSum over axis -2 with keep_dims=True is rewritten into Reshape -> Pool -> Reshape -> Power.

    Original graph:
        data(1,64,1) --> Reduce(axis=-2, keep_dims=True, reduce_type=Sum) --> data(1,1,1)

    Reference graph:
        data(1,64,1) -> Reshape(1,1,8,8) -> Pool(window=1,1,8,8) -> data(1,1,1,1)
            -> Reshape(1,1,1) -> Power(scale=64)

    The 64 reduced elements are laid out as an 8x8 pooling window. Power(scale=64) multiplies
    the pooled value by the element count; presumably the pooling averages, so this turns the
    mean back into a sum (TODO confirm against nodes_attributes['pooling']).
    """
    graph = build_graph(nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1_data', 'reduce_1'),
                         ('const', 'const_data'),
                         ('const_data', 'reduce_1', {'in': 1}),
                         ('reduce_1', 'reduce_1_data'),
                         ('reduce_1_data', 'concat'),
                         ],
                        {'placeholder_1': {'shape': int64_array([1, 64, 1])},
                         'placeholder_1_data': {'shape': int64_array([1, 64, 1])},
                         'reduce_1': {'keep_dims': True, 'type': 'ReduceSum'},
                         'const_data': {'value': int64_array([-2])},
                         'reduce_1_data': {'shape': int64_array([1, 1, 1])},
                         },
                        nodes_with_edges_only=True)
    graph.graph['layout'] = 'NCHW'

    graph_ref = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'reshape_1'),
                             ('reshape_1_const', 'reshape_1_const_data'),
                             ('reshape_1_const_data', 'reshape_1'),
                             ('reshape_1', 'reshape_1_data'),
                             ('reshape_1_data', 'pooling'),
                             ('pooling', 'pooling_data'),
                             ('pooling_data', 'reshape_2'),
                             ('reshape_2_const', 'reshape_2_const_data'),
                             ('reshape_2_const_data', 'reshape_2'),
                             ('reshape_2', 'reshape_2_data'),
                             ('reshape_2_data', 'power'),
                             ('power', 'power_data'),
                             ('power_data', 'concat'),
                             ],
                            {'placeholder_1': {'shape': int64_array([1, 64, 1])},
                             'placeholder_1_data': {'shape': int64_array([1, 64, 1])},
                             'reshape_1_const': {'value': int64_array([0, 1, 8, 8]), 'shape': int64_array([4])},
                             'reshape_1_const_data': {'value': int64_array([0, 1, 8, 8]),
                                                      'shape': int64_array([4])},
                             'reshape_1_data': {'shape': int64_array([1, 1, 8, 8])},
                             'pooling': {'window': int64_array([1, 1, 8, 8])},
                             'pooling_data': {'shape': int64_array([1, 1, 1, 1])},
                             'reshape_2_const': {'value': int64_array([0, 1, 1]), 'shape': int64_array([3])},
                             'reshape_2_const_data': {'value': int64_array([0, 1, 1]),
                                                      'shape': int64_array([3])},
                             'reshape_2_data': {'shape': int64_array([1, 1, 1])},
                             'power': {'scale': 64.0},
                             'power_data': {'shape': int64_array([1, 1, 1])},
                             },
                            nodes_with_edges_only=True)

    ReduceReplacer().find_and_replace_pattern(graph)
    shape_inference(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_consecutive_stride_slices_removal(self):
    """Two chained StridedSlices fed by the same parameter inputs are erased entirely.

    After UselessStridedSliceEraser and shape inference, the placeholder data must feed
    output_op directly and carry shape [4, 1, 6].
    """
    slice_params = ('strided_slice_input_1_data',
                    'strided_slice_input_2_data',
                    'strided_slice_input_3_data')

    # Edge order matters for port numbering, so each slice gets its data input first,
    # followed by the three parameter inputs in a fixed order.
    chain = [('placeholder', 'placeholder_data'),
             ('placeholder_data', 'strided_slice')]
    chain.extend((param, 'strided_slice') for param in slice_params)
    chain.extend([('strided_slice', 'strided_slice_data'),
                  ('strided_slice_data', 'strided_slice_2')])
    chain.extend((param, 'strided_slice_2') for param in slice_params)
    chain.extend([('strided_slice_2', 'strided_slice_2_data'),
                  ('strided_slice_2_data', 'output_op')])

    graph = build_graph(nodes_attributes, chain, {}, nodes_with_edges_only=True)

    UselessStridedSliceEraser().find_and_replace_pattern(graph)
    shape_inference(graph)

    expected_edges = [('placeholder', 'placeholder_data'),
                      ('placeholder_data', 'output_op')]
    expected_attrs = {'placeholder_data': {'shape': int64_array([4, 1, 6])}}
    graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs)

    flag, resp = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test5(self):
    """ReduceMax over the last axis (keep_dims=False) is rewritten into Reshape -> Pool -> Reshape.

    Original graph:
        data(1, 16, 64, 64, 64, 4) --> Reduce(axis=[5], keep_dims=False) --> data(1, 16, 64, 64, 64)

    Reference graph:
        data(1, 16, 64, 64, 64, 4) -> Reshape(1, 16*64*64*64, 2, 2) -> Pool(window=1, 1, 2, 2)
            -> data(1, 4194304, 1, 1) -> Reshape(1, 16, 64, 64, 64)

    The 4 reduced elements form a 2x2 pooling window. All 16*64*64*64 = 4194304 non-reduced
    elements after the batch dimension are folded into the channel axis. No Power node follows
    the pooling in this case.
    """
    graph = build_graph(nodes_attributes,
                        [('placeholder_1', 'placeholder_1_data'),
                         ('placeholder_1_data', 'reduce_1'),
                         ('const', 'const_data'),
                         ('const_data', 'reduce_1', {'in': 1}),
                         ('reduce_1', 'reduce_1_data'),
                         ('reduce_1_data', 'concat'),
                         ],
                        {'placeholder_1': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                         'placeholder_1_data': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                         'reduce_1': {'keep_dims': False, 'type': 'ReduceMax'},
                         'const_data': {'value': int64_array([5])},
                         'reduce_1_data': {'shape': int64_array([1, 16, 64, 64, 64])},
                         },
                        nodes_with_edges_only=True)
    graph.graph['layout'] = 'NCHW'

    graph_ref = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'reshape_1'),
                             ('reshape_1_const', 'reshape_1_const_data'),
                             ('reshape_1_const_data', 'reshape_1'),
                             ('reshape_1', 'reshape_1_data'),
                             ('reshape_1_data', 'pooling'),
                             ('pooling', 'pooling_data'),
                             ('pooling_data', 'reshape_2'),
                             ('reshape_2_const', 'reshape_2_const_data'),
                             ('reshape_2_const_data', 'reshape_2'),
                             ('reshape_2', 'reshape_2_data'),
                             ('reshape_2_data', 'concat'),
                             ],
                            {'placeholder_1': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                             'placeholder_1_data': {'shape': int64_array([1, 16, 64, 64, 64, 4])},
                             'reshape_1_const': {'value': int64_array([0, 4194304, 2, 2]),
                                                 'shape': int64_array([4])},
                             'reshape_1_const_data': {'value': int64_array([0, 4194304, 2, 2]),
                                                      'shape': int64_array([4])},
                             'reshape_1_data': {'shape': int64_array([1, 4194304, 2, 2])},
                             'pooling': {'window': int64_array([1, 1, 2, 2])},
                             'pooling_data': {'shape': int64_array([1, 4194304, 1, 1])},
                             'reshape_2_const': {'value': int64_array([0, 16, 64, 64, 64]),
                                                 'shape': int64_array([5])},
                             'reshape_2_const_data': {'value': int64_array([0, 16, 64, 64, 64]),
                                                      'shape': int64_array([5])},
                             'reshape_2_data': {'shape': int64_array([1, 16, 64, 64, 64])},
                             },
                            nodes_with_edges_only=True)

    ReduceReplacer().find_and_replace_pattern(graph)
    shape_inference(graph)

    (flag, resp) = compare_graphs(graph, graph_ref, 'concat', check_op_attrs=True)
    self.assertTrue(flag, resp)
def apply_transform(graph: Graph, replacer_cls, **kwargs):
    """
    Safely executes transform if it should be and validates graph after transform execution.

    The transform is skipped (with an info log) when its instance has a falsy ``enabled``
    attribute, or when any callable in its ``graph_condition`` list returns a falsy value
    for ``graph``.

    :param graph: graph to transform in place
    :param replacer_cls: transformation class; it is instantiated once without arguments
    :param kwargs: accepted for interface compatibility; not used by this function
    :raises Error: when an ``Error`` is raised while running or validating; the message is
        extended with the replacement id and the replacer class
    :raises Exception: any other exception, wrapped the same way
    """
    replacer = replacer_cls()
    replacement_id = 'REPLACEMENT_ID'
    if hasattr(replacer, 'replacement_id'):
        replacement_id = replacer.replacement_id

    if hasattr(replacer, 'enabled') and not replacer.enabled:
        log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
        return

    # A list (not a generator) is used, so every condition is evaluated even after one fails.
    if hasattr(replacer, 'graph_condition') and \
            not all([condition(graph) for condition in replacer.graph_condition]):
        log.info("Skip replacer {} (graph_condition not satisfied)".format(replacer_cls))
        return

    log.debug("Run replacer {}".format(replacer_cls))

    try:
        if hasattr(replacer, 'run_not_recursively'):
            replacer.find_and_replace_pattern(graph)
        else:
            for_graph_and_each_sub_graph_recursively(graph, replacer.find_and_replace_pattern)

        if hasattr(replacer, 'force_clean_up') and replacer.force_clean_up:
            for_graph_and_each_sub_graph_recursively(graph, lambda G: G.clean_up())

        if hasattr(replacer, 'force_shape_inference') and replacer.force_shape_inference:
            shape_inference(graph)

        # Validate the root graph exactly once. Iterating over sub-graphs with a callback that
        # ignores its argument would only repeat this same root-graph check per sub-graph
        # (assuming the checks have no side effects).
        # NOTE(review): nested sub-graph bodies are not validated here - confirm that is intended.
        graph.check_empty_graph(replacer_cls)
        graph.check_shapes_consistency()
    except Error as err:
        raise Error('Exception occurred during running replacer "{}" ({}): {}'.format(
            replacement_id,
            replacer_cls,
            str(err).replace('[REPLACEMENT_ID]', replacement_id),
        )) from err
    except Exception as err:
        raise Exception('Exception occurred during running replacer "{} ({})": {}'.format(
            replacement_id,
            replacer_cls,
            str(err).replace('[REPLACEMENT_ID]', replacement_id),
        )) from err
def test_single_stride_slice_with_shrink_and_new_removal(self):
    """A StridedSlice with both shrink_axis_mask and new_axis_mask set becomes Unsqueeze -> Squeeze.

    The slice inserts axis 2 (new_axis_mask) and drops axis 1 (shrink_axis_mask). The reference
    graph expresses this as an Unsqueeze (const [2]) producing [4, 1, 1, 6], followed by a
    Squeeze writing straight into the original strided_slice_data ([4, 1, 6]).
    """
    slice_params = ['strided_slice_input_1_data',
                    'strided_slice_input_2_data',
                    'strided_slice_input_3_data']

    source_edges = [('placeholder', 'placeholder_data'),
                    ('placeholder_data', 'strided_slice')]
    source_edges += [(param, 'strided_slice') for param in slice_params]
    source_edges += [('strided_slice', 'strided_slice_data'),
                     ('strided_slice_data', 'output_op')]
    source_attrs = {
        'strided_slice': {
            'shrink_axis_mask': int64_array([0, 1, 0, 0]),
            'new_axis_mask': int64_array([0, 0, 1, 0]),
        },
        'strided_slice_data': {'shape': int64_array([4, 1, 6])},
    }
    graph = build_graph(nodes_attributes, source_edges, source_attrs, nodes_with_edges_only=True)
    graph.graph['layout'] = 'NCHW'

    UselessStridedSliceEraser().find_and_replace_pattern(graph)
    shape_inference(graph)

    # Both Unsqueeze and Squeeze share the same wiring pattern: data in, const axis in, data out.
    expected_edges = [('placeholder', 'placeholder_data')]
    for op_name, src, dst in (('unsqueeze', 'placeholder_data', 'unsqueeze_data'),
                              ('squeeze', 'unsqueeze_data', 'strided_slice_data')):
        expected_edges += [(src, op_name),
                           (op_name + '_const', op_name + '_const_data'),
                           (op_name + '_const_data', op_name),
                           (op_name, dst)]
    expected_edges.append(('strided_slice_data', 'output_op'))

    expected_attrs = {
        'placeholder_data': {'shape': int64_array([4, 1, 6])},
        'unsqueeze_data': {'shape': int64_array([4, 1, 1, 6])},
        'strided_slice_data': {'shape': int64_array([4, 1, 6])},
        'unsqueeze_const': {'value': int64_array([2])},
    }
    graph_ref = build_graph(nodes_attributes, expected_edges, expected_attrs, nodes_with_edges_only=True)

    flag, resp = compare_graphs(graph, graph_ref, 'output_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_tile_reshaper(self):
    """TileReshaper must turn the graph from (nodes_attributes, edge_attributes) into the *_ref one."""
    graph, graph_ref = (build_graph(nodes_attributes, edge_attributes),
                        build_graph(nodes_attributes_ref, edge_attributes_ref))

    TileReshaper().find_and_replace_pattern(graph)
    shape_inference(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'next_op', check_op_attrs=True)
    self.assertTrue(flag, resp)
def graph_clean_up(graph: Graph, undead_node_types: list = None):
    """
    Eliminate dead nodes from ``graph``, re-add constant operations and re-run shape inference.

    Nodes are first marked by mark_output_reachable_nodes, mark_undead_nodes and
    mark_const_producer_nodes, and then eliminate_dead_nodes runs.

    :param graph: graph to clean up in place
    :param undead_node_types: node types handed to mark_undead_nodes (presumably kept even when
        not reachable from an output - confirm in mark_undead_nodes). 'Shape' is always
        excluded from this list. The caller's list is never modified.
    """
    # Build a filtered copy instead of calling .remove() on the argument: a caller may pass a
    # shared list, and mutating it would leak the change into that caller's later uses.
    # Every 'Shape' entry is dropped, not only the first one.
    undead_node_types = [node_type for node_type in (undead_node_types or [])
                         if node_type != 'Shape']

    mark_output_reachable_nodes(graph)
    mark_undead_nodes(graph, undead_node_types)
    mark_const_producer_nodes(graph)
    eliminate_dead_nodes(graph)
    # Add Const op for constant data nodes
    add_constant_operations(graph)
    shape_inference(graph)
def test_gather_tree_normalizer(self):
    """GatherTreeNormalizer routes the const feeding GatherTree port 3 through Squeeze(axis=0).

    The reference graph keeps the three Parameter inputs on ports 0-2. It inserts a Squeeze whose
    axis input is the int64 const [0], between 'const' and port 3 of gather_tree.
    """
    nodes = {}
    for name, shape in (('data_0', [100, 1, 10]),
                        ('data_1', [100, 1, 10]),
                        ('data_2', [1])):
        nodes.update(regular_op_with_shaped_data(name, shape, {'type': 'Parameter'}))
    nodes.update(regular_op_with_shaped_data('gather_tree', [1], {'type': 'GatherTree'}))
    nodes.update(valued_const_with_data('const', np.array([2])))
    nodes.update(result('result'))

    def parameter_edges():
        # Fresh edge tuples on every call, so the two graphs never share edge objects.
        return [*connect('data_0', '0:gather_tree'),
                *connect('data_1', '1:gather_tree'),
                *connect('data_2', '2:gather_tree')]

    edges = parameter_edges() + [*connect('const', '3:gather_tree'),
                                 *connect('gather_tree', 'result')]
    ref_edges = parameter_edges() + [*connect('const', '0:squeeze'),
                                     *connect('squeeze_axis', '1:squeeze'),
                                     *connect('squeeze', '3:gather_tree'),
                                     *connect('gather_tree', 'result')]

    ref_nodes = {**nodes,
                 **valued_const_with_data('squeeze_axis', int64_array([0])),
                 **regular_op_with_shaped_data('squeeze', [], {'type': 'Squeeze'})}

    graph = build_graph(nodes, edges)
    GatherTreeNormalizer().find_and_replace_pattern(graph)
    # run shape inference to make sure that shape overriding happened
    shape_inference(graph)

    ref_graph = build_graph(ref_nodes, ref_edges)

    flag, resp = compare_graphs(graph, ref_graph, 'result')
    self.assertTrue(flag, resp)
def test_non_training(self):
    """With is_training=False, FusedBatchNormTraining must leave the graph unchanged."""
    producers = ['scale', 'offset', 'mean', 'variance']
    bn_inputs = ['placeholder_data'] + [name + '_data' for name in producers]

    edges = [('placeholder', 'placeholder_data', {})]
    edges += [(name, name + '_data') for name in producers]
    # Port numbers follow the input list order: data on 0, then scale, offset, mean, variance.
    edges += [(src, 'batchnorm', {'in': port}) for port, src in enumerate(bn_inputs)]
    edges += [('batchnorm', 'batchnorm_data'),
              ('batchnorm_data', 'result')]

    graph = build_graph(nodes_attributes, edges, {'batchnorm': {'is_training': False}},
                        nodes_with_edges_only=True)
    # Snapshot taken before the transformation: the expected result is "no change".
    graph_ref = graph.copy()

    FusedBatchNormTraining().find_and_replace_pattern(graph)
    shape_inference(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)
def test_1(self):
    """ShuffleChannel(group=2) on a 1x10x128x128 input becomes Reshape -> Transpose -> Reshape.

    Expected decomposition:
        Reshape to [1, 2, 5, -1] -> Transpose with order [0, 2, 1, 3] -> Reshape back to [1, 10, 128, 128]
    """
    def typed_op(op, **extra):
        # Op node whose 'type' and 'op' coincide.
        return {'type': op, 'kind': 'op', 'op': op, **extra}

    def data_node(**attrs):
        return {'kind': 'data', **attrs}

    def input_nodes():
        # Fresh dicts on every call so the two graphs never share attribute objects.
        return {'placeholder': typed_op('Placeholder', shape=[1, 10, 128, 128]),
                'data': data_node(shape=[1, 10, 128, 128])}

    def output_nodes():
        return {'out_data': data_node(shape=[1, 10, 128, 128]),
                'output': {'kind': 'op', 'op': 'OpOutput'}}

    graph = build_graph({**input_nodes(),
                         'shuffle': typed_op('ShuffleChannel', group=2),
                         **output_nodes()},
                        [('placeholder', 'data'),
                         ('data', 'shuffle'),
                         ('shuffle', 'out_data'),
                         ('out_data', 'output')],
                        {})
    graph.graph['layout'] = 'NCHW'

    ref_nodes = {
        **input_nodes(),
        'reshape': typed_op('Reshape'),
        'reshape_const': typed_op('Const', value=int64_array([1, 2, 5, -1]), shape=[4]),
        'reshape_const_data': data_node(value=[1, 2, 5, -1], shape=[4]),
        'reshape_data': data_node(shape=[1, 2, 5, 128 * 128]),
        'order_const': {'kind': 'op', 'op': 'Const', 'value': np.array([0, 2, 1, 3])},
        'order_data': data_node(value=np.array([0, 2, 1, 3]), shape=np.array([4])),
        'transpose': typed_op('Transpose'),
        'transpose_data': data_node(shape=[1, 5, 2, 128 * 128]),
        'reshape_back': typed_op('Reshape'),
        'reshape_back_const': typed_op('Const', value=int64_array([1, 10, 128, 128]), shape=[4]),
        'reshape_back_const_data': data_node(value=[1, 10, 128, 128], shape=[4]),
        **output_nodes(),
    }
    ref_edges = [
        ('placeholder', 'data'),
        ('data', 'reshape'),
        ('reshape_const', 'reshape_const_data'),
        ('reshape_const_data', 'reshape'),
        ('reshape', 'reshape_data'),
        ('order_const', 'order_data'),
        ('order_data', 'transpose', {'in': 1}),
        ('reshape_data', 'transpose', {'in': 0}),
        ('transpose', 'transpose_data'),
        ('transpose_data', 'reshape_back'),
        ('reshape_back_const', 'reshape_back_const_data'),
        ('reshape_back_const_data', 'reshape_back'),
        ('reshape_back', 'out_data'),
        ('out_data', 'output'),
    ]
    graph_ref = build_graph(ref_nodes, ref_edges, {})

    ShuffleChannel().find_and_replace_pattern(graph)
    shape_inference(graph)

    flag, resp = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
    self.assertTrue(flag, resp)
def _collect_registered_replacers(classes_set) -> list:
    """Gather the non-excluded replacer classes registered on each class in ``classes_set``."""
    replacers = []
    for cls in classes_set:
        # Pattern replacers without an 'op' attribute are listed first, then the non-empty
        # op-keyed ones. The final run order is decided later by the dependency graph.
        candidates = [c for c in cls.registered_cls if not hasattr(c, 'op')] + \
                     [c for c in cls.registered_ops.values() if c]
        replacers.extend(c for c in candidates if c not in cls.excluded_replacers)
    return replacers


def _add_replacer_dependencies(dependency_graph, replacers: list):
    """Add every replacer as a node, and its run_before()/run_after() constraints as edges."""
    for replacer_cls in replacers:
        dependency_graph.add_node(replacer_cls)

    for replacer_cls in replacers:
        # One instance is enough to query both dependency lists.
        replacer = replacer_cls()
        for cls_after in replacer.run_before():
            log.debug("Replacer {} will be run before {}".format(replacer_cls, cls_after))
            dependency_graph.add_edge(replacer_cls, cls_after)
        for cls_before in replacer.run_after():
            log.debug("Replacer {} will be run after {}".format(replacer_cls, cls_before))
            dependency_graph.add_edge(cls_before, replacer_cls)


def _select_clean_up(graph: Graph):
    """Pick the clean-up routine for the graph's framework ('tf', 'onnx', or the generic one)."""
    framework = graph.graph['fw']
    if framework == 'tf':
        return graph_clean_up_tf
    if framework == 'onnx':
        return graph_clean_up_onnx
    return graph_clean_up


def _run_registered_replacer(graph: Graph, replacer_cls):
    """Run one replacer on ``graph`` unless it is disabled or its graph_condition fails, then validate.

    :raises Error: when an ``Error`` is raised while running or validating; the message is
        extended with the replacement id and the replacer class
    :raises Exception: any other exception, wrapped the same way
    """
    replacer = replacer_cls()
    replacement_id = 'REPLACEMENT_ID'
    if hasattr(replacer, 'replacement_id'):
        replacement_id = replacer.replacement_id

    if hasattr(replacer, 'enabled') and not replacer.enabled:
        log.info("Skip replacer {} (enabled = False)".format(replacer_cls))
        return

    # A list (not a generator) is used, so every condition is evaluated even after one fails.
    if hasattr(replacer, 'graph_condition') and \
            not all([condition(graph) for condition in replacer.graph_condition]):
        log.info("Skip replacer {} (graph_condition not satisfied)".format(replacer_cls))
        return

    log.debug("Run replacer {}".format(replacer_cls))

    try:
        if hasattr(replacer, 'run_not_recursively'):
            replacer.find_and_replace_pattern(graph)
        else:
            for_graph_and_each_sub_graph_recursively(graph, replacer.find_and_replace_pattern)

        if hasattr(replacer, 'force_clean_up') and replacer.force_clean_up:
            # The routine is chosen from the root graph's framework and applied to every sub-graph.
            for_graph_and_each_sub_graph_recursively(graph, _select_clean_up(graph))

        if hasattr(replacer, 'force_shape_inference') and replacer.force_shape_inference:
            shape_inference(graph)

        # Validate the root graph exactly once. Iterating over sub-graphs with a callback that
        # ignores its argument would only repeat this same root-graph check per sub-graph
        # (assuming the checks have no side effects).
        # NOTE(review): nested sub-graph bodies are not validated here - confirm that is intended.
        graph.check_empty_graph(replacer_cls)
        graph.check_shapes_consistency()
    except Error as err:
        raise Error('Exception occurred during running replacer "{}" ({}): {}'.format(
            replacement_id,
            replacer_cls,
            str(err).replace('[REPLACEMENT_ID]', replacement_id),
        )) from err
    except Exception as err:
        raise Exception('Exception occurred during running replacer "{} ({})": {}'.format(
            replacement_id,
            replacer_cls,
            str(err).replace('[REPLACEMENT_ID]', replacement_id),
        )) from err


def apply_replacements(graph: Graph, replacements_type):
    """
    Apply every replacer registered for ``replacements_type`` to ``graph``.

    Collection order:

    - For each registered class of that type, take the replacers without an ``op`` attribute
      first, then the non-empty entries of ``registered_ops``.
    - Drop anything listed in ``excluded_replacers``.

    The execution order is NOT this collection order. It is produced by
    ``DependencyGraph.determined_sort()`` from the replacers' ``run_before()`` / ``run_after()``
    constraints.

    Falsy ``registered_ops`` entries are skipped. Presumably these are ops claimed by several
    replacers, which registration reports as a conflict (confirm in the registration code).

    :param graph: graph to transform in place
    :param replacements_type: class type whose replacers are applied; ``ClassType(replacements_type)``
        must be valid because its name labels the dependency graph
    :raises Error: re-raised with replacer details when a replacer raises ``Error``
    :raises Exception: wrapping any other exception raised by a replacer
    """
    dependency_graph = DependencyGraph(name=ClassType(replacements_type).name)
    for class_type, classes_set in _registered_classes_dict.items():
        if class_type == replacements_type:
            _add_replacer_dependencies(dependency_graph, _collect_registered_replacers(classes_set))

    for replacer_cls in dependency_graph.determined_sort():
        _run_registered_replacer(graph, replacer_cls)
def test_transformation(self, op: str):
    """A training-mode batch norm of kind ``op`` is rewritten to normalize its input via MVN.

    ``op`` is the batch-norm op name under test (presumably supplied by a parametrization
    decorator outside this view).

    Expected reference graph:

    - The input is reshaped (reshape_1), normalized by MVN, and reshaped back to its original
      shape using a ShapeOf of the input.
    - The result feeds a non-training batchnorm whose mean/variance inputs are bn_mean/bn_variance.
    """
    producers = ['scale', 'offset', 'mean', 'variance']
    bn_inputs = ['placeholder_data'] + [name + '_data' for name in producers]

    edges = [('placeholder', 'placeholder_data', {})]
    edges += [(name, name + '_data') for name in producers]
    edges += [(src, 'batchnorm', {'in': port}) for port, src in enumerate(bn_inputs)]
    edges += [('batchnorm', 'batchnorm_data'),
              ('batchnorm_data', 'result')]

    graph = build_graph(nodes_attributes, edges, {}, nodes_with_edges_only=True)
    graph.nodes['batchnorm']['op'] = op

    ref_producers = ['scale', 'offset', 'bn_mean', 'bn_variance']
    ref_edges = [('placeholder', 'placeholder_data', {})]
    ref_edges += [(name, name + '_data') for name in ref_producers]
    # Ports 1..4 keep the parameters; port 0 is wired to the reshaped-back MVN output below.
    ref_edges += [(name + '_data', 'batchnorm', {'in': port})
                  for port, name in enumerate(ref_producers, start=1)]
    ref_edges += [
        ('placeholder_data', 'reshape_1', {'in': 0}),
        ('reshape_1_const', 'reshape_1_const_data'),
        ('reshape_1_const_data', 'reshape_1', {'in': 1}),
        ('reshape_1', 'reshape_1_data', {}),
        ('reshape_1_data', 'mvn', {'in': 0}),
        ('mvn', 'mvn_data'),
        ('mvn_data', 'reshape_to_orig', {'in': 0}),
        ('placeholder_data', 'shapeof', {'in': 0}),
        ('shapeof', 'shapeof_data'),
        ('shapeof_data', 'reshape_to_orig', {'in': 1}),
        ('reshape_to_orig', 'reshape_to_orig_data'),
        ('reshape_to_orig_data', 'batchnorm', {'in': 0}),
        ('batchnorm', 'batchnorm_data'),
        ('batchnorm_data', 'result'),
    ]
    graph_ref = build_graph(nodes_attributes, ref_edges, {'batchnorm': {'is_training': False}},
                            nodes_with_edges_only=True)

    FusedBatchNormTraining().find_and_replace_pattern(graph)
    shape_inference(graph)

    graph_ref.nodes['batchnorm']['op'] = op

    flag, resp = compare_graphs(graph, graph_ref, 'result', check_op_attrs=True)
    self.assertTrue(flag, resp)