def replace_pattern(graph: Graph, match: Dict[str, Node]):
    """
    Replace the matched ``op`` node with a StridedSlice of its first input.

    The StridedSlice starts at 0 on every axis. Its ``end`` input is built at runtime:

    * if both inputs have the same rank, ``end`` is simply ShapeOf(second input);
    * otherwise ``end`` is the second input's dims ``[0 .. axes[-1]]`` (inclusive). When
      ``axes[-1]`` is not the last axis of the first input, the first input's dims after
      ``axes[-1]`` are appended.

    ``end_mask`` is 1 exactly on ``node.axes`` and 0 elsewhere; all other masks are zero.
    NOTE(review): presumably a 1 in this code base's mask convention means "use the end value",
    so axes not listed keep their full extent — confirm against the StridedSlice convention used
    at this stage.

    NOTE(review): the rank-mismatch branch relies on ``node.axes[-1]`` being the largest,
    non-negative axis (sorted and normalized axes) — confirm upstream normalization.

    The original node is renamed to ``<name>/ShouldBeDeleted`` and the StridedSlice takes over
    its name and connections.

    :param graph: graph being transformed
    :param match: matched pattern; the node to replace is stored under the key ``'op'``
    """
    node = match['op']
    name = node.soft_get('name', node.id)
    data_shape = node.in_port(0).data.get_shape()
    like_shape = node.in_port(1).data.get_shape()
    rank = len(data_shape)

    def zero_mask():
        # A fresh all-zero int64 mask with one entry per input dimension.
        return np.zeros(rank, dtype=np.int64)

    end_mask = zero_mask()
    end_mask[list(node.axes)] = 1

    strided_slice = create_op_with_const_inputs(
        graph, StridedSlice,
        port_value_dict={1: zero_mask()},  # begin = 0 on every axis
        op_attrs={
            'name': 'StridedSlice',
            'begin_mask': zero_mask(),
            'end_mask': end_mask,
            'new_axis_mask': zero_mask(),
            'shrink_axis_mask': zero_mask(),
            'ellipsis_mask': zero_mask(),
        })

    if data_shape.size == like_shape.size:
        # Same rank: the full shape of the second input is the slice end.
        end_node = Shape(graph, dict(name=name + '/End')).create_node()
        end_node.in_port(0).connect(node.in_port(1).get_source())
    else:
        last_axis = node.axes[-1]
        like_shape_node, like_rank_node = get_shape_and_rank_nodes_by_port(node.in_port(1).get_source())
        end_node = get_shape_values_by_range_idxs(like_shape_node, like_rank_node, 0, last_axis,
                                                  include_end=True)
        if data_shape.size - 1 != last_axis:
            # Pad the end with the first input's own trailing dims so they stay untouched.
            data_shape_node, data_rank_node = get_shape_and_rank_nodes_by_port(node.in_port(0).get_source())
            trailing_dims = get_shape_values_by_range_idxs(data_shape_node, data_rank_node, last_axis, -1,
                                                           include_begin=False, include_end=True)
            end_node = new_shape_node_from_shape_nodes([end_node, trailing_dims])
    strided_slice.in_port(2).connect(end_node.out_port(0))

    # Move the data input and all consumers over to the StridedSlice.
    node.in_port(0).get_connection().set_destination(strided_slice.in_port(0))
    node.in_port(1).disconnect()
    node.out_port(0).get_connection().set_source(strided_slice.out_port(0))

    rename_nodes([(node, name + '/ShouldBeDeleted'), (strided_slice, name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    """
    Replace a Flatten node with an equivalent Reshape.

    Flatten collapses the dims in the inclusive range [axis, end_axis] of its input into one.
    An input of shape [d_0, ..., d_n] becomes
    [d_0, ..., d_(axis-1), d_axis * ... * d_(end_axis), d_(end_axis+1), ..., d_n].

    The original node is renamed to ``<name>/ShouldBeDeleted`` and the new Reshape takes over its
    name. ``<name>`` is the node's ``name`` attribute, falling back to its id.

    :param graph: graph being transformed
    :param match: matched sub-graph; the Flatten node is stored under the key ``'flatten'``
    :raises AssertionError: if the node lacks a valid ``axis`` or ``end_axis`` attribute
    """
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
    assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

    axis = node.axis
    end_axis = node.end_axis

    if end_axis == -1 and axis >= 0:
        # Static target shape: `axis` zeros followed by -1. The zeros presumably rely on Reshape's
        # special_zero semantics (copy the input dim); -1 absorbs everything from `axis` onward.
        begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
        middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
        end_dims = Const(graph, {'value': int64_array([])}).create_node()
    else:
        # General case: compute the target shape at runtime from the input's ShapeOf and Rank.
        rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        begin_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=0, end=axis)
        middle_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
        end_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

        # Product of the collapsed dims, kept as a 1-element tensor so it can be concatenated.
        middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        middle_dims.out_port(0).connect(middle_dim.in_port(0))

    dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

    # Use the id-fallback `name` computed above; a bare soft_get('name') would not fall back to
    # node.id for unnamed nodes, unlike every other name derived in this method.
    reshape_node = Reshape(graph, {}).create_node()
    # Keep node with the same name to avoid confuse with renaming
    rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_node, name)])
    reshape_node.in_port(1).connect(dim.out_port(0))

    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
def resolve_minus2(self, shape_node, input_index, reshape_index, dims):
    """
    Emit the input dims from ``input_index`` to the last one (both inclusive) as a shape node.

    Judging by the ``MXReshapeMinus2`` node-name suffix, this handles the ``-2`` special value of
    an MXNet-style reshape spec.

    :param shape_node: node producing the input shape (presumably a 1-D ShapeOf output)
    :param input_index: position in the input shape where copying starts
    :param reshape_index: position in the reshape spec currently being resolved
    :param dims: passed through unchanged
    :return: ``(None, reshape_index + 1, dims, shape_values_node)``; the input cursor is
             returned as None and the reshape cursor moves past the consumed entry.
    """
    graph = shape_node.graph
    rank_name = shape_node.id + '/RankShapeMXReshapeMinus2'

    # ShapeOf of the (presumably 1-D) shape tensor gives a 1-element tensor holding the rank.
    rank_node = Shape(graph, dict(name=rank_name)).create_node()
    rank_node.in_port(0).connect(shape_node.out_port(0))

    remaining_dims = get_shape_values_by_range_idxs(shape=shape_node, rank=rank_node,
                                                    begin=input_index, end=-1,
                                                    include_begin=True, include_end=True)
    return None, reshape_index + 1, dims, remaining_dims
def mxrepeat_decomposition(node: Node):
    """
    Decompose a repeat node into an Unsqueeze -> Tile -> Reshape chain.

    Each slice of the input along ``node.axis`` is repeated ``node.repeats`` times in a row:

    1. the input is unsqueezed at ``axis + 1``;
    2. it is tiled ``repeats`` times along that new unit axis;
    3. the new axis is merged back by a Reshape whose dim at ``axis`` is
       ``input_shape[axis] * repeats``.

    The original node is renamed to ``<name>/to_be_removed``. The final Reshape takes its name and
    its output connections.

    :param node: node to decompose (presumably an MXRepeat op, going by the function name); its
                 ``axis`` and ``repeats`` attributes are read.
    """
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    # Unsqueeze
    input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(input_rank.in_port(0))

    # Graph node holding the canonical (presumably non-negative) form of node.axis.
    axis = get_canonical_axis_index_node(input_rank, node.axis)
    # The new unit dimension is inserted right after `axis`.
    unsqueeze_axis = create_op_node_with_second_input(
        graph, Add, int64_array([1]), {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

    # Tile (1, 1, ..., repeats, ..., 1)
    # we generate tile array according to the following table:

    # parts:       |      first      |  repeats |  second     |
    # i:           | 0, 1, ..., axis,| axis + 1,| ..., rank+1 |
    # tile_array:  | 1, 1, ...,  1   ,| repeats ,| ..., 1      |

    one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()
    # (axis + 1) ones in front of the repeats entry.
    first_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    first_ones.in_port(0).connect(one.out_port(0))
    first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

    # rank - (axis + 1) ones after the repeats entry, so the tile array has rank + 1 entries,
    # matching the rank of the unsqueezed tensor.
    second_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    second_part_broadcast_shape = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
    second_part_broadcast_shape.in_port(1).connect(unsqueeze_axis.out_port(0))
    second_ones.in_port(0).connect(one.out_port(0))
    second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([first_ones, repeats, second_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
    # we generate reshape dim array according to the following table:

    # parts:       |    first   |                rep          |  second   |
    # i:           | 0, 1, ... ,|               axis,         | ..., rank |
    # dim_array:   | inp_sh[i] ,| input_shape[axis] * repeats ,| inp_sh[i] |

    input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(input_shape.in_port(0))

    # NOTE(review): uses the raw node.axis rather than the canonical `axis` node; presumably
    # get_shape_values_by_range_idxs normalizes negative indices itself — confirm.
    first_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=0, end=node.axis, include_begin=True, include_end=False)

    # Gather(input_shape, axis, 0) picks input_shape[axis] (port 2 is the Gather axis, fixed to 0).
    original_axis_dim = create_op_with_const_inputs(
        graph, Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'}, input_node=input_shape)
    original_axis_dim.in_port(1).connect(axis.out_port(0))

    repeated_dimention = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dimention.in_port(0).connect(original_axis_dim.out_port(0))
    repeated_dimention.in_port(1).connect(repeats.out_port(0))

    second_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=node.axis, end=-1, include_begin=False, include_end=True)

    output_shape = new_shape_node_from_shape_nodes(
        [first_input_shape_part, repeated_dimention, second_input_shape_part])

    reshape = Reshape(graph, {'name': name}).create_node()
    # NOTE(review): appears redundant, since the node was already created with this name.
    rename_node(reshape, name)
    reshape.in_port(1).connect(output_shape.out_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))