def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['mxreshape']
    input_index = 0
    reshape_index = 0

    shape_node = Shape(graph, dict(name=node.id + '/ShapeMXReshape')).create_node()
    shape_node.in_port(0).connect(node.in_port(0).get_source())
    output_dims_nodes = []
    for d in node.dim:
        if reshape_index < len(node.dim):
            input_index, reshape_index, output_dims_nodes = self.resolve(
                input_index, reshape_index, node.dim, shape_node, output_dims_nodes)

    concat_node = Concat(shape_node.graph, dict(name=shape_node.id + '/ConcatMXReshape_',
                                                axis=0,
                                                in_ports_count=len(output_dims_nodes))).create_node()

    for in_port_index, dim_node in enumerate(output_dims_nodes):
        concat_node.in_port(in_port_index).connect(dim_node.out_port(0))

    reshape_node = Reshape(graph, dict(name=node.id + '/Reshape_')).create_node()
    reshape_node.in_port(1).connect(concat_node.out_port(0))

    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
def append_variances(priors_scale_node: Node, variance: list):
    graph = priors_scale_node.graph
    name = priors_scale_node.name

    sp_shape = Shape(graph, {'name': name + '/shape'}).create_node()
    priors_scale_node.out_port(0).connect(sp_shape.in_port(0))

    begin = Const(graph, {'value': int64_array([-2])}).create_node()
    end = Const(graph, {'value': int64_array([-1])}).create_node()
    stride = Const(graph, {'value': int64_array([1])}).create_node()
    shape_part_for_tiling = StridedSlice(graph, {'name': name + '/get_-2_dim',
                                                 'begin_mask': int64_array([1]),
                                                 'end_mask': int64_array([1]),
                                                 'new_axis_mask': int64_array([0]),
                                                 'shrink_axis_mask': int64_array([0]),
                                                 'ellipsis_mask': int64_array([0])}).create_node()

    sp_shape.out_port(0).connect(shape_part_for_tiling.in_port(0))
    begin.out_port(0).connect(shape_part_for_tiling.in_port(1))
    end.out_port(0).connect(shape_part_for_tiling.in_port(2))
    stride.out_port(0).connect(shape_part_for_tiling.in_port(3))

    shape_concat = create_op_node_with_second_input(graph, Concat, int64_array([4]),
                                                    {'name': name + '/shape_for_tiling',
                                                     'in_ports_count': 2,
                                                     'axis': int64_array(0)},
                                                    shape_part_for_tiling)

    variance = Const(graph, {'name': name + '/variance', 'value': float32_array(variance)}).create_node()
    tile = Broadcast(graph, {'name': name + '/variance_tile'}).create_node()
    variance.out_port(0).connect(tile.in_port(0))
    shape_concat.out_port(0).connect(tile.in_port(1))

    reshape_dim = Const(graph, {'value': int64_array([-1, 4])}).create_node()
    sp_reshape = Reshape(graph, {'name': name + '/reshape'}).create_node()
    sp_reshape.in_port(0).connect(priors_scale_node.out_port(0))
    sp_reshape.in_port(1).connect(reshape_dim.out_port(0))

    concat = Concat(graph, {'name': name + '/priors_concat',
                            'axis': int64_array(0),
                            'in_ports_count': 2}).create_node()
    sp_reshape.out_port(0).connect(concat.in_port(0))
    tile.out_port(0).connect(concat.in_port(1))

    output_dims = Const(graph, {'value': int64_array([1, 2, -1])}).create_node()
    output_node = Reshape(graph, {'name': name + '/3D_priors_wth_variances'}).create_node()
    concat.out_port(0).connect(output_node.in_port(0))
    output_dims.out_port(0).connect(output_node.in_port(1))

    return output_node
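# A minimal NumPy sketch of what the sub-graph above computes, assuming the priors node
# produces a [1, num_priors, 4] tensor (the exact prior layout is an assumption here):
# the priors are reshaped to [-1, 4], the 4 variance values are broadcast to one row per
# prior, both parts are concatenated along axis 0, and the result is reshaped to [1, 2, -1].
import numpy as np

priors = np.arange(32, dtype=np.float32).reshape(1, 8, 4)   # hypothetical 8 prior boxes
variance = np.float32([0.1, 0.1, 0.2, 0.2])

boxes = priors.reshape(-1, 4)                               # Reshape to [-1, 4]
tiled_variance = np.broadcast_to(variance, boxes.shape)     # Broadcast to [num_priors, 4]
result = np.concatenate([boxes, tiled_variance], axis=0).reshape(1, 2, -1)

assert result.shape == (1, 2, 32)   # row 0 holds priors, row 1 holds per-box variances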
def replace_sub_graph(self, graph: Graph, match: dict):
    mxreshape = match['op']
    if not mxreshape.reverse:
        return

    shape_node = Shape(graph, dict(name=mxreshape.id + '/Shape')).create_node()
    forward_reverse_unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), dict(name=str(mxreshape.id) + '/ForwardUnsqueeze'))
    forward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/ForwardReverse', axis=1)).create_node()
    forward_reverse_squeeze_node = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), dict(name=str(mxreshape.id) + '/ForwardSqueeze'))
    reshape_node = Reshape(graph, dict(name=mxreshape.id + '/Reshape')).create_node()

    shape_node.in_port(0).connect(mxreshape.in_port(0).get_source())
    mxreshape.in_port(0).get_connection().set_destination(reshape_node.in_port(0))

    forward_reverse_unsqueeze_node.in_port(0).connect(shape_node.out_port(0))
    forward_reverse_node.in_port(0).connect(forward_reverse_unsqueeze_node.out_port(0))
    forward_reverse_squeeze_node.in_port(0).connect(forward_reverse_node.out_port(0))
    reshape_node.in_port(1).connect(forward_reverse_squeeze_node.out_port(0))

    reshape_shape_node = create_op_node_with_second_input(
        graph, Reshape, int64_array(np.flip(mxreshape.dim, 0)), dict(name=str(mxreshape.id) + '/ReshapeShape'))
    if np.sum(np.in1d([-2, -3, -4], mxreshape.dim), axis=0):
        reshape_shape_node = MXReshape(graph, dict(name=mxreshape.id + '/Reshape',
                                                   dim=int64_array(np.flip(mxreshape.dim, 0)))).create_node()

    reshape_shape_node.in_port(0).connect(reshape_node.out_port(0))

    backward_shape_node = Shape(graph, dict(name=mxreshape.id + '/BackwardShape')).create_node()
    backward_reverse_unsqueeze_node = create_op_node_with_second_input(
        graph, Unsqueeze, int64_array([0]), dict(name=str(mxreshape.id) + '/BackwardUnsqueeze'))
    backward_reverse_node = Reverse(graph, dict(name=mxreshape.id + '/BackwardReverse', axis=1)).create_node()
    backward_reverse_squeeze_node = create_op_node_with_second_input(
        graph, Squeeze, int64_array([0]), dict(name=str(mxreshape.id) + '/BackwardSqueeze'))
    backward_reshape_node = Reshape(graph, dict(name=mxreshape.id + '/BackwardReshape')).create_node()

    backward_shape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reverse_unsqueeze_node.in_port(0).connect(backward_shape_node.out_port(0))
    backward_reverse_node.in_port(0).connect(backward_reverse_unsqueeze_node.out_port(0))
    backward_reverse_squeeze_node.in_port(0).connect(backward_reverse_node.out_port(0))

    backward_reshape_node.in_port(0).connect(reshape_shape_node.out_port(0))
    backward_reshape_node.in_port(1).connect(backward_reverse_squeeze_node.out_port(0))

    mxreshape.out_port(0).get_connection().set_source(backward_reshape_node.out_port(0))
def find_and_replace_pattern(self, graph: Graph):
    layout = graph.graph['layout']
    for eltwise_op_node in graph.get_op_nodes(is_eltwise=True):
        out_shape = eltwise_op_node.out_port().data.get_shape()
        if 4 <= len(out_shape) <= 5:
            out_features = out_shape[get_features_dim(layout, len(out_shape))]
            for port, node in eltwise_op_node.in_nodes().items():
                if len(node.shape) != len(out_shape) and len(node.shape) == 1 and out_features == node.shape[0]:
                    new_shape = shape_for_layout(layout, batch=1, features=out_features, height=1, width=1,
                                                 depth=1 if len(out_shape) == 5 else None)
                    dim_const = Const(graph, {'value': new_shape, 'name': node.id + '/Dim'}).create_node()
                    reshape_op = Reshape(graph, attrs={'dim': new_shape, 'name': node.id + '/Broadcast'}).create_node()

                    eltwise_op_node.in_port(port).get_source().node.out_port(0).get_connection().set_destination(
                        reshape_op.in_port(0))
                    reshape_op.in_port(1).connect(dim_const.out_port(0))

                    reshape_op.out_port(0).connect(eltwise_op_node.in_port(port))
def convert_fft_to_dft(self, graph: Graph, mx_fft: Node):
    mx_fft_name = mx_fft.soft_get('name', mx_fft.id)
    unsqueeze_node = create_op_with_const_inputs(graph, Unsqueeze, {1: int64_array([-1])},
                                                 {'name': mx_fft_name + '/Unsqueeze'})
    rank_node = Rank(graph, {'name': mx_fft_name + '/Rank'}).create_node()

    mx_fft_connection = mx_fft.in_port(0).get_connection()
    mx_fft_connection.set_destination(unsqueeze_node.in_port(0))
    mx_fft_connection.get_source().connect(rank_node.in_port(0))

    add_node = create_op_with_const_inputs(graph, Add, {1: int64_array(1)},
                                           {'name': mx_fft_name + '/Add'}, rank_node)
    broadcast_node1 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/Pad_broadcast'})
    add_node.out_port(0).connect(broadcast_node1.in_port(1))

    scatter_node = create_op_with_const_inputs(graph, ScatterUpdate,
                                               {2: int64_array(1), 3: int64_array(0)},
                                               {'name': mx_fft_name + '/ScatterUpdate'})
    broadcast_node1.out_port(0).connect(scatter_node.in_port(0))
    rank_node.out_port(0).connect(scatter_node.in_port(1))

    pad_node = Pad(graph, {'name': mx_fft_name + '/Pad', 'mode': 'constant'}).create_node(
        [unsqueeze_node, broadcast_node1, scatter_node])

    dft_node = create_op_with_const_inputs(graph, DFT, {1: int64_array([-1])},
                                           {'name': mx_fft_name + '/DFT', 'in_ports_count': 2},
                                           pad_node)

    sub_node = create_op_with_const_inputs(graph, Sub, {1: int64_array(1)},
                                           {'name': mx_fft_name + '/Sub'})
    rank_node.out_port(0).connect(sub_node.in_port(0))
    broadcast_node2 = create_op_with_const_inputs(graph, Broadcast, {0: int64_array(0)},
                                                  {'name': mx_fft_name + '/Reshape_broadcast'})
    sub_node.out_port(0).connect(broadcast_node2.in_port(1))
    concat_node = create_op_with_const_inputs(graph, Concat, {1: int64_array([-1, 2])},
                                              {'name': mx_fft_name + '/New_shape',
                                               'in_ports_count': 2,
                                               'axis': 0},
                                              broadcast_node2)
    reshape_node = Reshape(graph, {}).create_node([dft_node, concat_node])

    mx_fft.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    rename_nodes([(mx_fft, mx_fft_name + '/to_be_removed'), (reshape_node, mx_fft_name)])
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), \
        'Flatten {} should have `axis` attribute extracted, but it\'s not'.format(name)
    axis = node.axis

    reshape_node = Reshape(graph, {'name': node.id + '/Reshape'}).create_node()

    if axis == 0:
        dim = Const(graph, {'value': int64_array([1, -1]),
                            'name': reshape_node.name + '/shape'}).create_node()
    elif axis == 1:
        dim = Const(graph, {'value': int64_array([0, -1]),
                            'name': reshape_node.name + '/shape'}).create_node()
    else:
        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()

        idxs = list(range(axis)) if axis > 0 else list(range(axis, 0))

        axis_shape_portion = node_to_get_shape_value_of_indices(shape, idxs)
        first_dims = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]),
                                                      {'name': name + '/first_dims', 'keep_dims': True})
        second_dims = Const(graph, {'value': int64_array([-1]),
                                    'name': name + '/second_dims'}).create_node()

        node.in_port(0).get_source().connect(shape.in_port(0))
        axis_shape_portion.out_port(0).connect(first_dims.in_port(0))

        order_of_dims = [first_dims, second_dims] if axis > 0 else [second_dims, first_dims]

        dim = new_shape_node_from_shape_nodes(order_of_dims)

    reshape_node.in_port(1).connect(dim.out_port(0))

    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['flatten']
    name = node.soft_get('name', node.id)

    assert node.has_valid('axis'), 'Flatten {} has no mandatory `axis` attribute'.format(name)
    assert node.has_valid('end_axis'), 'Flatten {} has no mandatory `end_axis` attribute'.format(name)

    axis = node.axis
    end_axis = node.end_axis

    if end_axis == -1 and axis >= 0:
        begin_dims = Const(graph, {'value': int64_array([0] * axis)}).create_node()
        middle_dim = Const(graph, {'value': int64_array([-1])}).create_node()
        end_dims = Const(graph, {'value': int64_array([])}).create_node()
    else:
        rank = Rank(graph, {'name': name + '/input_rank'}).create_node()
        node.in_port(0).get_source().connect(rank.in_port(0))

        shape = Shape(graph, {'name': name + '/input_shape'}).create_node()
        node.in_port(0).get_source().connect(shape.in_port(0))

        begin_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=0, end=axis)
        middle_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=axis, end=end_axis, include_end=True)
        end_dims = get_shape_values_by_range_idxs(
            shape=shape, rank=rank, begin=end_axis, end=-1, include_begin=False, include_end=True)

        middle_dim = create_op_node_with_second_input(graph, ReduceProd, int64_array([0]), {'keep_dims': True})
        middle_dims.out_port(0).connect(middle_dim.in_port(0))

    dim = new_shape_node_from_shape_nodes([begin_dims, middle_dim, end_dims])

    original_name = node.soft_get('name')
    abandoned_name = original_name + '/ShouldBeDeleted'
    reshape_node = Reshape(graph, {}).create_node()
    # keep the original name on the new Reshape node to avoid confusion after the replacement
    rename_nodes([(node, abandoned_name), (reshape_node, original_name)])

    reshape_node.in_port(1).connect(dim.out_port(0))
    node.out_port(0).get_connection().set_source(reshape_node.out_port(0))
    node.in_port(0).get_connection().set_destination(reshape_node.in_port(0))
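# A minimal numeric sketch of the shape computed by the sub-graph above, assuming
# already-normalized (non-negative) `axis`/`end_axis` values chosen here for illustration:
# dims before `axis` are kept, dims [axis..end_axis] collapse into their product, the rest are kept.
import numpy as np

def flatten_shape(shape, axis, end_axis):
    begin = list(shape[:axis])
    middle = [int(np.prod(shape[axis:end_axis + 1]))]
    end = list(shape[end_axis + 1:])
    return begin + middle + end

x = np.zeros((2, 3, 4, 5))
new_shape = flatten_shape(x.shape, axis=1, end_axis=2)
assert new_shape == [2, 12, 5]
assert x.reshape(new_shape).shape == (2, 12, 5)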
def replace_sub_graph(self, graph: Graph, match: dict):
    node = match['op']
    name = node.soft_get('name', node.id)
    axis = node.axis
    input_shape_node = Shape(graph, {'name': name + '/ShapeOf'}).create_node()
    range_node = create_op_with_const_inputs(graph, Range,
                                             {0: mo_array(node.start), 2: mo_array(node.step)},
                                             {'name': name + '/Range'})
    node.in_port(0).get_connection().set_destination(input_shape_node.in_port(0))

    if axis is not None:
        '''
        Replace the arange_like op with the sub-graph:
        Shape - Gather - Range
        '''
        gather_node = create_op_with_const_inputs(graph, Gather,
                                                  {1: int64_array([axis]), 2: int64_array(0)},
                                                  {'name': name + '/Gather'})
        input_shape_node.out_port(0).connect(gather_node.in_port(0))
        gather_node.out_port(0).connect(range_node.in_port(1))
        node.out_port(0).get_connection().set_source(range_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'), (range_node, name)])
    else:
        r'''
        Replace the arange_like op with the sub-graph:
              |
           ShapeOf ------|
              |          |
          ReduceProd     |
              |          |
            Range        |
              |          |
           Reshape ------|
              |
        '''
        flattened_shape_node = create_op_with_const_inputs(graph, ReduceProd, {1: int64_array([0])},
                                                           {'name': input_shape_node.name + '/ReduceProd',
                                                            'keep_dims': True})
        reshape_backward_node = Reshape(graph, {'name': name + '/Reshape_backward'}).create_node()

        input_shape_node.out_port(0).connect(flattened_shape_node.in_port(0))
        flattened_shape_node.out_port(0).connect(range_node.in_port(1))
        range_node.out_port(0).connect(reshape_backward_node.in_port(0))
        input_shape_node.out_port(0).connect(reshape_backward_node.in_port(1))
        node.out_port(0).get_connection().set_source(reshape_backward_node.out_port(0))
        rename_nodes([(node, name + '/ShouldBeDeleted'), (reshape_backward_node, name)])

    if node.repeat != 1:
        r"""
        First, we generate the correct stop value for Range as new_stop_value = stop_value // repeat + 1.
        Then each value of the interval is repeated using Tile. Since that may produce a longer interval,
        we trim it with Slice.

        The sub-graph after the Range node will look like:

        Range - Reshape([-1, 1]) - Tile([1, repeat]) - Reshape(-1) - Slice
        """
        if node.repeat < 1:
            raise Error("Unexpected value {} of the attribute 'repeat' for the node {}".format(node.repeat, name))

        div_node = create_op_with_const_inputs(graph, Div, {1: int64_array([node.repeat])},
                                               {'name': name + '/Divide'})
        add_node = create_op_with_const_inputs(graph, Add, {1: int64_array([1])},
                                               {'name': div_node.name + '/Add'})
        cast_node = Cast(graph, {'name': name + '/ConvertToI64', 'dst_type': np.int64}).create_node()

        cast_node.out_port(0).connect(div_node.in_port(0))
        div_node.out_port(0).connect(add_node.in_port(0))
        range_node.in_port(1).get_connection().set_destination(cast_node.in_port(0))
        add_node.out_port(0).connect(range_node.in_port(1))

        tile_forward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1, 1])},
                                                           {'name': range_node.name + '/ForwardReshape'})
        tile = create_op_with_const_inputs(graph, Tile, {1: int64_array([1, node.repeat])},
                                           {'name': tile_forward_reshape.name + '/Tile'})
        tile_backward_reshape = create_op_with_const_inputs(graph, Reshape, {1: int64_array([-1])},
                                                            {'name': tile.name + '/BackwardReshape'})
        slice_node = create_op_with_const_inputs(graph, Slice,
                                                 {1: int64_array([0]), 3: int64_array([0]), 4: int64_array([1])},
                                                 {'name': tile_backward_reshape.name + '/Slice'})

        tile_forward_reshape.out_port(0).connect(tile.in_port(0))
        tile.out_port(0).connect(tile_backward_reshape.in_port(0))
        tile_backward_reshape.out_port(0).connect(slice_node.in_port(0))
        slice_node.in_port(2).connect(div_node.in_port(0).get_source())

        range_node.out_port(0).get_connection().set_source(slice_node.out_port(0))
        range_node.out_port(0).connect(tile_forward_reshape.in_port(0))

        if axis is not None:
            rename_nodes([(range_node, name + '/Range'), (slice_node, name)])

    # MXNet arange_like op has no stop attribute and the result tensor always matches the input shape, so
    # we have to correct the stop value for the Range node if step != 1 or start != 0
    if node.step != 1:
        # If the step attribute is not an integer, we generate an interval of a larger size and then reduce it
        # using Slice
        true_elements_count_port = range_node.in_port(1).get_source()
        mul_value = np.ceil(node.step) if node.step > 0 else np.floor(node.step)
        stop_value = create_op_with_const_inputs(graph, Mul,
                                                 port_value_dict={1: mo_array(np.ceil(mul_value))},
                                                 op_attrs={'name': range_node.name + '/Stop'})
        range_node.in_port(1).get_connection().insert_node(stop_value)

        slice_range_values = create_op_with_const_inputs(graph, Slice,
                                                         {1: int64_array([0]), 3: int64_array([0]),
                                                          4: int64_array([1])},
                                                         {'name': range_node.name + '/Slice'})
        slice_range_values.in_port(2).connect(true_elements_count_port)
        range_node.out_port(0).get_connection().insert_node(slice_range_values)

        if axis is not None and node.repeat == 1:
            rename_nodes([(range_node, name + '/Range'), (slice_range_values, name)])

    if node.start != 0:
        correct_stop_value = create_op_with_const_inputs(graph, Add,
                                                         port_value_dict={1: mo_array(node.start)},
                                                         op_attrs={'name': range_node.name + '/Correct_Stop'})
        range_node.in_port(1).get_connection().insert_node(correct_stop_value)

    # Range node supports only scalar inputs
    squeeze_node = create_op_with_const_inputs(graph, Squeeze, port_value_dict={1: int64_array(0)},
                                               op_attrs={"name": range_node.name + '/Stop/Squeeze'})
    range_node.in_port(1).get_connection().insert_node(squeeze_node)
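# A minimal numeric sketch of the `repeat` branch above, assuming start=0, step=1 and a
# target length n taken from the input shape (values here are illustrative only):
import numpy as np

n, repeat = 7, 3
stop = n // repeat + 1                                          # Div + Add
ramp = np.arange(0, stop)                                       # Range
tiled = np.tile(ramp.reshape(-1, 1), (1, repeat)).reshape(-1)   # Reshape([-1, 1]) - Tile - Reshape(-1)
result = tiled[:n]                                              # Slice back to the required length

assert np.array_equal(result, np.repeat(np.arange(stop), repeat)[:n])
assert len(result) == n   # result: [0, 0, 0, 1, 1, 1, 2]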
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_shape_op_node_float = Cast(
        graph, {'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
    initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    initial_spatial_dims_node = Cast(
        graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one
    new_shape_node_float = new_shape_node_from_shape_nodes(
        [batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node])
    new_shape_node = Cast(graph, {'name': new_shape_node_float.name + '/to_int64',
                                  'dst_type': np.int64}).create_node()
    new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

    reshape_for_mvn_node = Reshape(graph, {}).create_node()

    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps,
                           'eps_mode': 'inside_sqrt'}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # MVN axes
    _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    rng = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                      {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
    mvn_node.in_port(1).connect(rng.out_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
def mxrepeat_decomposition(node: Node):
    graph = node.graph
    name = node.soft_get('name', node.id)

    rename_node(node, name + '/to_be_removed')

    # Unsqueeze
    input_rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    node.in_port(0).get_source().connect(input_rank.in_port(0))

    axis = get_canonical_axis_index_node(input_rank, node.axis)
    unsqueeze_axis = create_op_node_with_second_input(
        graph, Add, int64_array([1]), {'name': name + '/Unsqueeze/Axis'}, input_node=axis)

    unsqueeze = Unsqueeze(graph, {'name': name + '/Unsqueeze'}).create_node()
    unsqueeze.in_port(1).connect(unsqueeze_axis.out_port(0))

    # Tile (1, 1, ..., repeats, ..., 1)
    # we generate the tile array according to the following table:
    # parts:      | first           | repeats   | second      |
    # i:          | 0, 1, ..., axis,| axis + 1, | ..., rank+1 |
    # tile_array: | 1, 1, ..., 1,   | repeats,  | ..., 1      |
    one = Const(graph, {'name': name + '/Broadcast/One', 'value': int64_array([1])}).create_node()
    first_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_first_part'}).create_node()
    first_ones.in_port(0).connect(one.out_port(0))
    first_ones.in_port(1).connect(unsqueeze_axis.out_port(0))

    repeats = Const(graph, {'name': name + '/repeats', 'value': int64_array([node.repeats])}).create_node()

    second_ones = Broadcast(graph, {'name': name + '/Broadcast/Ones_second_part'}).create_node()
    second_part_broadcast_shape = Sub(graph, {'name': name + '/Broadcast/Shape/second_part'}).create_node()
    second_part_broadcast_shape.in_port(0).connect(input_rank.out_port(0))
    second_part_broadcast_shape.in_port(1).connect(unsqueeze_axis.out_port(0))
    second_ones.in_port(0).connect(one.out_port(0))
    second_ones.in_port(1).connect(second_part_broadcast_shape.out_port(0))

    tile_repeats = new_shape_node_from_shape_nodes([first_ones, repeats, second_ones])
    tile = Tile(graph, {'name': name + '/Tile'}).create_node()
    tile.in_port(1).connect(tile_repeats.out_port(0))

    # Reshape (input_shape[:axis], input_shape[axis] * repeats, input_shape[axis+1:])
    # we generate the reshape dim array according to the following table:
    # parts:     | first      | rep                           | second    |
    # i:         | 0, 1, ..., | axis,                         | ..., rank |
    # dim_array: | inp_sh[i], | input_shape[axis] * repeats,  | inp_sh[i] |
    input_shape = Shape(graph, {'name': name + '/Shape'}).create_node()
    node.in_port(0).get_source().connect(input_shape.in_port(0))

    first_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=0, end=node.axis, include_begin=True, include_end=False)

    original_axis_dim = create_op_with_const_inputs(
        graph, Gather, {2: int64_array(0)}, {'name': name + '/OriginalDim'}, input_node=input_shape)
    original_axis_dim.in_port(1).connect(axis.out_port(0))

    repeated_dimension = Mul(graph, {'name': name + '/RepeatedDim'}).create_node()
    repeated_dimension.in_port(0).connect(original_axis_dim.out_port(0))
    repeated_dimension.in_port(1).connect(repeats.out_port(0))

    second_input_shape_part = get_shape_values_by_range_idxs(
        input_shape, input_rank, begin=node.axis, end=-1, include_begin=False, include_end=True)

    output_shape = new_shape_node_from_shape_nodes(
        [first_input_shape_part, repeated_dimension, second_input_shape_part])

    reshape = Reshape(graph, {'name': name}).create_node()
    rename_node(reshape, name)
    reshape.in_port(1).connect(output_shape.out_port(0))

    # Final connections
    node.in_port(0).get_connection().set_destination(unsqueeze.in_port(0))
    tile.in_port(0).connect(unsqueeze.out_port(0))
    reshape.in_port(0).connect(tile.out_port(0))
    node.out_port(0).get_connection().set_source(reshape.out_port(0))
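# A minimal numeric sketch of the decomposition above (Unsqueeze - Tile - Reshape), with
# illustrative axis/repeats values; it reproduces np.repeat along the given axis:
import numpy as np

x = np.arange(12).reshape(2, 3, 2)
axis, repeats = 1, 2

expanded = np.expand_dims(x, axis + 1)      # Unsqueeze at axis + 1
tile_array = [1] * expanded.ndim
tile_array[axis + 1] = repeats              # (1, ..., repeats, ..., 1)
tiled = np.tile(expanded, tile_array)       # Tile

new_shape = list(x.shape)
new_shape[axis] *= repeats                  # input_shape[axis] * repeats
result = tiled.reshape(new_shape)           # Reshape

assert np.array_equal(result, np.repeat(x, repeats, axis=axis))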
def replace_pattern(graph: Graph, match: dict):
    node = match['pool']
    node_name = node.soft_get('name', node.id)

    if node.pool_step is None:
        node.stride = int64_array([1, 1, node.window[-1], node.window[-1]])

    # create Reshape before the pooling
    # shape = [in_shape[0], pool_stride, 1, in_shape[1]/pool_stride]
    i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()

    dst_dtype = np.float32  # even if data_type=FP16 use float32 for shape values

    shape = Cast(graph, {'name': node_name + '/to_float',
                         'dst_type': dst_dtype}).create_node()
    i_shape.in_port(0).connect(node.in_port(0).get_source())
    shape.in_port(0).connect(i_shape.out_port(0))

    N, H = node_to_get_shape_value_of_indices(shape, [0]), node_to_get_shape_value_of_indices(shape, [1])

    div = create_op_with_const_inputs(
        graph, Div, {1: float32_array([node.pool_stride])}, {'name': node_name + '/div_stride_h'})
    div.in_port(0).connect(H.out_port(0))

    concat = create_op_with_const_inputs(
        graph, Concat, {1: float32_array([node.pool_stride]), 2: float32_array([1])},
        {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(N.out_port(0))
    concat.in_port(3).connect(div.out_port(0))

    reshape_pattern = Cast(graph, {'name': node_name + '/to_int', 'dst_type': np.int64}).create_node()
    concat.out_port(0).connect(reshape_pattern.in_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(reshape_pattern.out_port(0))

    # create Reshape after the pooling
    reshape_out = create_op_node_with_second_input(
        graph, Reshape, int64_array([0, -1]), {'name': node_name + '/reshape_out'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))
def replace_pattern(graph: Graph, match: dict):
    node = match['conv']
    node_name = node.soft_get('name', node.id)

    # create Reshape before the convolution
    # if a Transpose will be applied (new models):
    #   shape = [in_shape[0], t = in_shape[1]/(patch_stride*t), patch_stride, C=1]
    # else (for old models, to avoid failures on GNA - should be removed once GNA is changed):
    #   shape = [in_shape[0], t = in_shape[1]/(patch_stride*t), C=1, patch_stride]
    sp_dim_1 = 1
    if node.has_valid('patch_stride'):
        channel_dim = 2
        sp_dim_2 = 3
        frame_height = node.patch_stride
    else:
        channel_dim = 3
        sp_dim_2 = 2
        frame_height = node.height_in

    i_shape = Shape(graph, {'name': node_name + '/Shape'}).create_node()
    i_shape.in_port(0).connect(node.in_port(0).get_source())

    N, H = node_to_get_shape_value_of_indices(i_shape, [0]), node_to_get_shape_value_of_indices(i_shape, [1])

    div = create_op_with_const_inputs(
        graph, Div, {1: int64_array([frame_height * node.kernel[1]])}, {'name': node_name + '/div_stride_h'})
    div.in_port(0).connect(H.out_port(0))

    concat = create_op_with_const_inputs(
        graph, Concat, {sp_dim_2: int64_array([frame_height]), channel_dim: int64_array([node.kernel[1]])},
        {'name': node_name + '/concat_all_dims', 'in_ports_count': 4, 'axis': 0})
    concat.in_port(0).connect(N.out_port(0))
    concat.in_port(sp_dim_1).connect(div.out_port(0))

    reshape_in = Reshape(graph, {'name': node_name + '/reshape_in'}).create_node()
    reshape_in.in_port(1).connect(concat.out_port(0))

    # change layout from NHWC to NCHW
    # should be replaced by the common Permute logic in the future
    transpose = None
    if channel_dim == 3 and node.channel_dims == 1:
        transpose = create_op_node_with_second_input(
            graph, Transpose, int64_array([0, 3, 1, 2]), {'name': node.name + '/Transpose'}, reshape_in)

    # create Reshape after the convolution
    reshape_out = create_op_node_with_second_input(
        graph, Reshape, int64_array([0, -1]), {'name': node_name + '/reshape_out'})

    # connect input_reshape_node
    source = node.in_port(0).get_source()
    node.in_port(0).get_connection().set_source(transpose.out_port(0) if transpose else reshape_in.out_port(0))
    reshape_in.in_port(0).connect(source)
    # connect output_reshape_node
    node.out_port(0).get_connection().set_source(reshape_out.out_port(0))
    node.out_port(0).connect(reshape_out.in_port(0))