def replace_pattern(graph: Graph, match: dict):
    """
    Pulls a Transpose up through a FakeQuantize: the Transpose is detached from
    the FQ output and a copy of it is applied to every FQ input instead.
    Lower-rank limit inputs are unsqueezed with leading axes first so that the
    same permutation is applicable to them.
    """
    fq = match['fq']

    if len(fq.out_port(0).get_destinations()) > 1:
        # FQ should have only one child -- Transpose for optimization
        return

    transpose = match['transpose']
    name = fq.soft_get('name', fq.id)

    input_shape = transpose.in_port(0).data.get_shape()

    # detaching transpose from the graph
    transpose.out_port(0).get_connection().set_source(transpose.in_port(0).get_connection().get_source())
    transpose.in_port(0).disconnect()

    for idx, port in fq.in_ports().items():
        transpose_copy = transpose.copy_node({'override_output_shape': True})
        transpose.in_port(1).get_source().connect(transpose_copy.in_port(1))

        start_port = transpose_copy.in_port(0)

        idxs = np.arange(len(input_shape) - len(port.data.get_shape()))
        if idxs.size != 0:
            axis = Const(graph, {'name': name + '/in_{}_unsqueeze_axis'.format(idx),
                                 'value': int64_array(idxs)}).create_node()
            unsqueeze = Unsqueeze(graph, {'name': name + '/in_{}_unsqueeze'.format(idx)}).create_node()
            axis.out_port(0).connect(unsqueeze.in_port(1))
            unsqueeze.out_port(0).connect(transpose_copy.in_port(0))
            start_port = unsqueeze.in_port(0)

        src = port.get_source()
        port.get_connection().set_source(transpose_copy.out_port(0))
        src.connect(start_port)
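# A minimal sketch in plain numpy (not part of the transformation) of the
# unsqueeze arithmetic above: a lower-rank FakeQuantize limit input gets
# leading axes np.arange(rank_diff) prepended so the data permutation becomes
# applicable to it. All shape and order values below are invented for
# illustration.
def _sketch_fq_limit_unsqueeze():
    import numpy as np
    data = np.zeros((1, 3, 16, 16))             # NCHW data entering the Transpose
    limit = np.zeros((3, 1, 1))                 # per-channel FQ limit, rank 3
    idxs = np.arange(data.ndim - limit.ndim)    # -> [0], same as in the pass
    limit = np.expand_dims(limit, tuple(idxs))  # now rank 4: (1, 3, 1, 1)
    order = (0, 2, 3, 1)                        # example NCHW -> NHWC permutation
    assert np.transpose(limit, order).shape == (1, 1, 1, 3)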
def add_unsqueeze_for_new(graph: Graph, ss_node: Node):
    """
    Replaces the new_axis_mask of a StridedSlice with an explicit Unsqueeze:
    the mask is zeroed out, the slice output shape is shrunk accordingly, and
    an Unsqueeze restoring the inserted axes is placed after the node.
    """
    log.info("StridedSlice op '{}' with new axis mask has been detected".format(ss_node.id))

    if len(ss_node.in_nodes()) != 4 or len(ss_node.out_nodes()) != 1:
        return

    shape_out = ss_node.out_node().shape
    dim = np.array(range(len(ss_node['new_axis_mask'])))[np.array(ss_node['new_axis_mask'], dtype=bool)]
    ss_shape = []
    for i in range(0, len(ss_node['new_axis_mask'])):
        if not ss_node['new_axis_mask'][i]:
            ss_shape.append(shape_out[i])
        else:
            ss_node['new_axis_mask'][i] = 0

    ss_node.out_port(0).data.set_shape(ss_shape)

    # insert Unsqueeze
    unsqueeze_node = Unsqueeze(graph, dict(name=ss_node.name + '/Unsqueeze_new')).create_node()
    ss_node.out_port(0).get_connection().insert_node(unsqueeze_node)
    unsqueeze_node.out_port(0).data.set_shape(shape_out)

    dims_node = Const(graph, {'name': unsqueeze_node.id + '/Indices',
                              'value': int64_array(dim)}).create_node()
    dims_node.out_port(0).connect(unsqueeze_node.in_port(1))
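# A small numpy sketch (illustrative only) of the mask bookkeeping above:
# positions where new_axis_mask is set become the Unsqueeze indices, while the
# remaining positions form the slice's own output shape. Mask and shape values
# are invented for the example.
def _sketch_new_axis_split():
    import numpy as np
    new_axis_mask = np.array([0, 1, 0, 0])  # one inserted axis at position 1
    shape_out = np.array([1, 1, 16, 16])    # output shape including the new axis
    dim = np.arange(len(new_axis_mask))[new_axis_mask.astype(bool)]  # -> [1]
    ss_shape = shape_out[~new_axis_mask.astype(bool)]                # -> [1, 16, 16]
    restored = np.expand_dims(np.zeros(ss_shape), tuple(dim)).shape
    assert restored == tuple(shape_out)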
def find_and_replace_pattern(self, graph: Graph):
    for expand_dims_node in graph.get_op_nodes(op='ExpandDims'):
        if len(expand_dims_node.in_nodes()) == 1:
            expand_axis = expand_dims_node.expand_axis
            if not isinstance(expand_axis, np.ndarray):
                expand_axis = int64_array([expand_axis]).flatten()
            unsqueeze_node = Unsqueeze(graph, {'name': expand_dims_node.id}).create_node()
            unsqueeze_dims_node = Const(graph, {'name': expand_dims_node.id + '/Dims',
                                                'value': expand_axis}).create_node()
            expand_dims_node.in_port(0).get_connection().set_destination(unsqueeze_node.in_port(0))
            expand_dims_node.out_port(0).get_connection().set_source(unsqueeze_node.out_port(0))
            unsqueeze_node.in_port(1).connect(unsqueeze_dims_node.out_port(0))
        else:
            log.error('The ExpandDims node {} has more than 1 input'.format(expand_dims_node.soft_get('name')))
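# Illustrative sketch (plain numpy): the key step above is normalizing a scalar
# expand_axis attribute into a 1-D axes tensor, since Unsqueeze takes its axes
# as a second input rather than as a node attribute. Example values invented.
def _sketch_axis_normalization():
    import numpy as np
    expand_axis = 2                                           # scalar attribute form
    axes = np.array([expand_axis], dtype=np.int64).flatten()  # -> array([2])
    x = np.zeros((4, 5))
    assert np.expand_dims(x, tuple(axes)).shape == (4, 5, 1)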
def replace_pattern(graph: Graph, match: dict):
    """
    Works around a Tile layout not supported by Inference Engine (Tile is
    supported for 2D and 4D tensors only): searches for Tiles with 3D shapes
    and wraps them in Reshapes.

    Example: Tile (axis=1, tiles=16):
        in_shape:  [1,1,101]
        out_shape: [1,16,101]

    Old behaviour:
        Tile -> [1,16,101]
    New behaviour:
        Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
    """
    node = match['tile']
    name = node.soft_get('name', node.id)

    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None, 'Output shape is undefined for {} in back phase'.format(name)
    if out_shape.size != 3:
        return

    inp_shape = node.in_port(0).data.get_shape()
    assert inp_shape is not None, 'Input shape is undefined for {} in back phase'.format(name)

    unsqueeze_dim = Const(graph, {'name': name + '/3D_Tile_Unsqueeze_dim',
                                  'value': int64_array([3])}).create_node()
    unsqueeze = Unsqueeze(graph, {'name': name + '/3D_Tile_Unsqueeze',
                                  'override_output_shape': True}).create_node()
    unsqueeze_dim.out_port(0).connect(unsqueeze.in_port(1))

    const = Const(graph, {'name': name + '/additional_axis', 'value': int64_array([1])}).create_node()
    new_tiles = new_shape_node_from_shape_nodes([node.in_port(1).get_source().node, const])

    node.in_port(1).get_connection().set_source(new_tiles.out_port(0))

    squeeze_dim = Const(graph, {'name': name + '/3D_Tile_Squeeze_dim',
                                'value': int64_array([3])}).create_node()
    squeeze = Squeeze(graph, {'name': name + '/3D_Tile_Squeeze',
                              'override_output_shape': True}).create_node()
    squeeze_dim.out_port(0).connect(squeeze.in_port(1))

    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(unsqueeze.out_port(0))
    unsqueeze.in_port(0).connect(source)

    node.out_port(0).get_connection().set_source(squeeze.out_port(0))
    node.out_port(0).connect(squeeze.in_port(0))

    node['override_output_shape'] = True
    new_tiles['override_output_shape'] = True
    node['need_shape_inference'] = True
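# A numpy sketch (not part of the pass) of the equivalence the docstring
# describes: tiling a 3D tensor is reproduced by unsqueezing a trailing axis,
# tiling the resulting 4D tensor with the repeats extended by 1, and squeezing
# the axis back. Shapes follow the docstring example.
def _sketch_tile_3d_equivalence():
    import numpy as np
    x = np.random.rand(1, 1, 101)
    direct = np.tile(x, (1, 16, 1))                       # -> (1, 16, 101)
    wrapped = np.tile(x[..., np.newaxis], (1, 16, 1, 1))  # -> (1, 16, 101, 1)
    assert np.array_equal(direct, wrapped.squeeze(-1))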
def mxrepeat_decomposition(node: Node):
    """
    Decomposes an MXNet Repeat into Unsqueeze -> Tile -> Reshape: a new axis is
    inserted right after `axis`, tiled `repeats` times, and folded back into
    the original axis by the final Reshape.
    """
    graph = node.graph
    name = node.soft_get('name', node.id)
    rename_node(node, name + '/to_be_removed')

    # Unsqueeze
    input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(input_rank.in_port(0))

    axis = get_canonical_axis_index_node(input_rank, node.axis)
    unsqueeze_axis = create_op_node_with_second_input(
        graph, Add, int64_array([1]),
        {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

    # Tile (1, 1, ..., repeats, ..., 1)
    # we generate the tile array according to the following table:
    # parts:      | first           | repeats  | second      |
    # i:          | 0, 1, ..., axis | axis + 1 | ..., rank+1 |
    # tile_array: | 1, 1, ..., 1    | repeats  | ..., 1      |
    one = Const(graph, {'name': name + '/Broadcast/One',
                        'value': int64_array([1])}).create_node()
    first_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    first_ones.in_port(0).connect(one.out_port(0))
    first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats',
                            'value': int64_array([node.repeats])}).create_node()

    second_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    second_part_broadcast_shape = Sub(
        graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
    second_part_broadcast_shape.in_port(1).connect(unsqueeze_axis.out_port(0))
    second_ones.in_port(0).connect(one.out_port(0))
    second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([first_ones, repeats, second_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
    # we generate the reshape dim array according to the following table:
    # parts:     | first     | rep                         | second    |
    # i:         | 0, 1, ... | axis                        | ..., rank |
    # dim_array: | inp_sh[i] | input_shape[axis] * repeats | inp_sh[i] |
    input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(input_shape.in_port(0))

    first_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=0, end=node.axis,
        include_begin=True, include_end=False)

    original_axis_dim = create_op_with_const_inputs(
        graph, Gather, {2: int64_array(0)},
        {'name': name + '/OriginalDim'}, input_node=input_shape)
    original_axis_dim.in_port(1).connect(axis.out_port(0))

    repeated_dimension = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dimension.in_port(0).connect(original_axis_dim.out_port(0))
    repeated_dimension.in_port(1).connect(repeats.out_port(0))

    second_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=node.axis, end=-1,
        include_begin=False, include_end=True)

    output_shape = new_shape_node_from_shape_nodes(
        [first_input_shape_part, repeated_dimension, second_input_shape_part])

    reshape = Reshape(graph, {'name': name}).create_node()
    rename_node(reshape, name)
    reshape.in_port(1).connect(output_shape.out_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))
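# Plain-numpy sketch (illustrative) of the decomposition above: Repeat along
# `axis` equals Unsqueeze at axis+1, Tile with `repeats` on that new axis, and
# a Reshape folding the two axes back together. The helper is hypothetical.
def _sketch_mxrepeat(x, axis, repeats):
    import numpy as np
    unsqueezed = np.expand_dims(x, axis + 1)
    tile_array = [1] * unsqueezed.ndim
    tile_array[axis + 1] = repeats
    tiled = np.tile(unsqueezed, tile_array)
    new_shape = list(x.shape)
    new_shape[axis] *= repeats
    return tiled.reshape(new_shape)

# e.g. _sketch_mxrepeat(np.arange(6).reshape(2, 3), axis=1, repeats=2) matches
# np.repeat(np.arange(6).reshape(2, 3), 2, axis=1)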
def replace_pattern(graph: Graph, match: dict):
    """
    Normalizes a MatMul: when the second input is constant (or marked with
    stop_value_propagation) and has at most two non-unit dimensions, the node
    becomes a FullyConnected with OI-layout weights and 2D input/output
    (auxiliary Transpose/Reshape nodes are inserted as needed); otherwise it
    becomes a Gemm.
    """
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
            and B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()
            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))
            order.infer(order)
            transpose.infer(transpose)

        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()
            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))
            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]
            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()
            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))
            order.infer(order)
            transpose.infer(transpose)

        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()
            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))
            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()
            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)
            node.infer(node)
            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size
        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {
                'value': int64_array(list(range(out_shape.size - B_shape.size)))
            }).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()
            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))
            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })
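# Minimal numpy sketch (illustrative, shapes invented) of the normalization the
# FullyConnected branch above performs: a batched MatMul [B, I, K] x [K, O]
# collapses to a 2D FullyConnected [B*I, K] with OI-layout weights [O, K], and
# the 2D result is reshaped back to [B, I, O].
def _sketch_matmul_to_fc():
    import numpy as np
    B, I, K, O = 2, 3, 4, 5
    a = np.random.rand(B, I, K)
    w = np.random.rand(K, O)                    # MatMul weights layout (K, O)
    matmul_out = a @ w                          # -> (B, I, O)
    fc_in = a.reshape(-1, K)                    # input reshape to (B*I, K)
    fc_w = w.T                                  # weights transpose to OI layout (O, K)
    fc_out = (fc_in @ fc_w.T).reshape(B, I, O)  # output reshape back
    assert np.allclose(matmul_out, fc_out)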