def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched sub-graph with ``Range(0, batch, 1)``.

    ``batch`` is the value produced by ``node_to_get_batch_value`` for the shape of the tensor
    that feeds ``match['split']``. It is squeezed to 0D before being used as the Range limit.
    All consumers of ``match['cast']`` are re-attached to the Range output. Every matched node
    is then removed from the graph.

    :param graph: graph being transformed.
    :param match: pattern match; must contain at least the ``'split'`` and ``'cast'`` nodes.
    """
    source_connection = match['split'].in_port(0).get_connection()
    # The exact output port feeding `split`. A multi-output producer does not necessarily feed it
    # from out_port(0), so the Shape below must read this port rather than a hard-coded port 0.
    source_port = source_connection.get_source()
    source_node = source_port.node
    cast_node = match['cast']

    range_node = Range(graph, {'name': source_node.id + '/Range'}).create_node()
    start_node = Const(graph, {'name': range_node.id + '/Start', 'value': int64_array(0)}).create_node()
    step_node = Const(graph, {'name': range_node.id + '/Step', 'value': int64_array(1)}).create_node()

    input_shape_node = Shape(graph, {'name': start_node.id + '/Shape'}).create_node()
    input_shape_node.in_port(0).connect(source_port)

    limit_node_1D = node_to_get_batch_value(input_shape_node)
    # Squeeze axis 0 so the Range limit is a 0D value (see the '/batch_0D_value' name).
    limit_node = create_op_node_with_second_input(graph, Squeeze, int64_array([0]),
                                                  {'name': source_node.id + '/batch_0D_value'},
                                                  limit_node_1D)

    range_node.in_port(0).connect(start_node.out_port(0))
    range_node.in_port(1).connect(limit_node.out_port(0))
    range_node.in_port(2).connect(step_node.out_port(0))

    # Everything that consumed the Cast output now consumes the Range output instead.
    cast_node.out_port(0).get_connection().set_source(range_node.out_port(0))

    graph.remove_nodes_from([node.id for node in match.values()])
def decompose_shuffle_channel(node: Node):
    """Rewrite a channel-shuffle node as Reshape -> Transpose(0, 2, 1, 3) -> Reshape.

    The input is viewed as ``[batch, group, channels / group, -1]``. The two middle axes are
    swapped, and the result is reshaped back to the runtime shape of the original input.
    The final Reshape takes over the original node's name. The original node is renamed with a
    ``/to_be_removed`` suffix, and its input and output connections are moved onto the new
    sub-graph.

    :param node: node to decompose; its ``group`` attribute gives the number of groups.
    """
    graph = node.graph
    op_name = node.soft_get('name', node.id)
    # Free the original name so the final Reshape can take it over.
    rename_node(node, op_name + '/to_be_removed')

    input_shape = Shape(graph, {'name': op_name + '/InputShape'}).create_node()
    input_shape.in_port(0).connect(node.in_port(0).get_source())

    # Pieces of the split shape [batch, group, channels / group, -1].
    batch_dim = node_to_get_batch_value(input_shape)
    group_dim = Const(graph, {'name': op_name + '/Rows', 'value': int64_array([node.group])}).create_node()
    tail_dim = Const(graph, {'name': op_name + '/Const', 'value': int64_array([-1])}).create_node()
    channels_dim = node_to_get_features_dimension_value(input_shape)
    per_group = create_op_node_with_second_input(graph, Div, np.int64(node.group),
                                                 {'name': op_name + '/Cols'}, input_node=channels_dim)
    per_group_int = Cast(graph, {'name': per_group.name + '/Convert', 'dst_type': np.int64}).create_node()
    per_group.out_port(0).connect(per_group_int.in_port(0))
    split_shape = new_shape_node_from_shape_nodes([batch_dim, group_dim, per_group_int, tail_dim])

    split_reshape = Reshape(graph, {'name': op_name + '/Reshape_split_'}).create_node()
    split_shape.out_port(0).connect(split_reshape.in_port(1))

    # Swap the "group" and "channels / group" axes.
    swap_axes = create_op_node_with_second_input(graph, Transpose, int64_array([0, 2, 1, 3]),
                                                 {'name': op_name + '/Transpose_'}, input_node=split_reshape)

    # Back to the original input shape; this node carries the original name.
    merge_reshape = Reshape(graph, {'name': op_name}).create_node()
    rename_node(merge_reshape, op_name)
    input_shape.out_port(0).connect(merge_reshape.in_port(1))
    swap_axes.out_port(0).connect(merge_reshape.in_port(0))

    # The producer now feeds the first Reshape; consumers now read from the last one.
    node.in_port(0).get_connection().set_destination(split_reshape.in_port(0))
    node.out_port(0).get_connection().set_source(merge_reshape.out_port(0))
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """Decompose GroupNorm into Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

    1. The input is reshaped to ``[batch * num_groups, -1, *spatial]``. Reshape infers the
       ``-1`` dimension, which equals ``features / num_groups``.
    2. MVN is applied with ``across_channels=1``.
    3. The result is reshaped back to the runtime input shape.
    4. It is multiplied by gamma and added to beta. Both are reshaped in place from ``[C]``
       to ``[1, C, 1, ...]``.

    :param graph: graph being transformed.
    :param match: pattern match; ``match['op']`` is the GroupNorm node. Port 0 is the data
        input, port 1 is gamma and port 2 is beta; gamma and beta must have constant values.
    :raises AssertionError: if gamma or beta has no constant value.
    """
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
    initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # "features // num_groups" is left for Reshape to infer (-1). The previous version computed it
    # as int64 features * float64(1.0 / num_groups). That mixed an integer shape value with a float
    # constant and gave a non-integer, sometimes inexact shape element (49 * (1 / 49) != 1.0).
    features_per_group_node = Const(graph, {'value': int64_array([-1]),
                                            'name': group_norm_node.name + '/FeaturesPerGroup'}).create_node()

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one: [batch * G, -1, *spatial] (all int64)
    new_shape_node = new_shape_node_from_shape_nodes([
        batch_mul_group_size_node, features_per_group_node, initial_spatial_dims_node
    ])

    reshape_for_mvn_node = Reshape(graph, {}).create_node()
    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # Reshape the gamma and beta constants to correct layout from [C] to [1,C], [1,C,1], [1,C,1,1] etc
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'across_channels': 1,
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
def replace_pattern(graph: Graph, match: dict):
    """Convert the matched Interpolate node, in place, into a Resample or Interp node.

    Modes 'nearest', 'cubic' and 'area', or any node with ``convert_to_resample`` set, go through
    ``ResampleOp.update_node_stat``. Mode 'linear' goes through ``InterpOp.update_node_stat``.
    ``width``, ``height`` and ``factor`` come from the input and output shapes via
    ``factor_update``. Width and height are zeroed when a factor drives the resize and the node
    has no explicit non-zero width or height (presumably so the factor wins; confirm against
    the Resample/Interp ops).

    :param graph: graph being transformed.
    :param match: pattern match; ``match['interpolate']`` is the Interpolate node. Port 1 must
        be connected and hold a constant value.
    :raises AssertionError: on unsupported modes or ranks, or on a missing/non-constant port 1.
    """
    node = match['interpolate']
    # The message used to contain an unformatted '{}'; fill in the node name.
    assert 1 in node.in_ports() and not node.in_port(1).disconnected() and \
        node.in_port(1).data.get_value() is not None, \
        'Interpolate node {} is corrupted: no 1-port input found'.format(node.soft_get('name', node.id))

    # common
    mode = node.mode
    assert mode in ['linear', 'nearest', 'cubic', 'area']
    in_shape = node.in_port(0).data.get_shape()
    assert in_shape is not None and len(in_shape) in [4, 5]
    out_shape = node.out_port(0).data.get_shape()
    assert out_shape is not None and len(out_shape) in [4, 5]

    in_height, in_width = in_shape[2], in_shape[3]
    out_height, out_width = out_shape[2], out_shape[3]

    factor = factor_update(
        None if not node.has_valid('factor') else node.factor,
        [float(out_height) / in_height, float(out_width) / in_width],
        [in_height, in_width],
        [out_height, out_width],
        node.soft_get('name'))

    update_attrs = {
        'width': out_width,
        'height': out_height,
        'factor': factor,
    }

    if (node.has_valid('shrink_factor') and node.has_valid('zoom_factor')) or factor is None:
        del update_attrs['factor']
        if node.has('factor'):
            del node['factor']

    if ((node.has_valid('shrink_factor') and node.shrink_factor != 1) or
            (node.has_valid('zoom_factor') and node.zoom_factor != 1) or
            'factor' in update_attrs) \
            and ((not node.has_valid('width') or node.width == 0) and
                 (not node.has_valid('height') or node.height == 0)):
        update_attrs['width'] = 0
        update_attrs['height'] = 0

    # specific
    if mode in ['nearest', 'cubic', 'area'] or node.has_and_set('convert_to_resample'):
        assert not node.align_corners
        assert node.pads_begin == 0 and node.pads_end == 0
        update_attrs['resample_type'] = InterpolateToInterpOrResample.type_map[mode]
        ResampleOp.update_node_stat(node, update_attrs)

        if not graph.graph['cmd_params'].keep_shape_ops or graph.graph['fw'] != 'tf':
            node.in_port(1).disconnect()
        else:
            # we avoid making resample non-reshapable for tf version
            shape = Shape(graph, {}).create_node()
            node.in_port(0).get_source().connect(shape.in_port(0))

            batch = node_to_get_batch_value(shape)
            features = node_to_get_features_dimension_value(shape)
            # [batch, features, <original port-1 value>] becomes the new port-1 input
            full_shape = new_shape_node_from_shape_nodes([batch, features, node.in_port(1).get_source().node])
            node.in_port(1).get_connection().set_source(full_shape.out_port(0))
            full_shape['override_output_shape'] = True

    elif mode == 'linear':
        assert len(in_shape) == 4, 'Interp does not support 5D input'
        update_attrs.update({
            'pad_beg': node.pads_begin,
            'pad_end': node.pads_end,
            'align_corners': node.align_corners,
        })
        InterpOp.update_node_stat(node, update_attrs)
        node.in_port(1).disconnect()
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    """Lower GroupNorm to Reshape -> MVN -> Reshape -> Mul(gamma) -> Add(beta).

    The first Reshape targets ``[batch * num_groups, features / num_groups, *spatial]``. That
    shape is computed in the floating-point type selected by ``cmd_params.data_type`` and then
    cast to int64. MVN normalizes over axes ``[1, rank)`` of the reshaped tensor; the axes are
    built at runtime with ``Range(1, rank, 1)``. The MVN result is reshaped back to the runtime
    input shape, scaled by gamma and shifted by beta. Gamma and beta are reshaped in place from
    ``[C]`` to ``[1, C, 1, ...]``.

    :param graph: graph being transformed.
    :param match: pattern match; ``match['op']`` is the GroupNorm node. Port 0 is the data
        input, port 1 is gamma and port 2 is beta; gamma and beta must have constant values.
    :raises AssertionError: if gamma or beta has no constant value.
    """
    gn = match['op']
    input_rank = len(gn.in_port(0).data.get_shape())

    # Runtime shape of the GroupNorm input (int), plus a float copy for the arithmetic below.
    shape_of = Shape(graph, {'name': gn.name + '/Shape'}).create_node()
    shape_of.in_port(0).connect(gn.in_port(0).get_source())

    float_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    shape_as_float = Cast(graph, {'name': shape_of.name + '/to_float', 'dst_type': float_type}).create_node()
    shape_of.out_port(0).connect(shape_as_float.in_port(0))

    batch_f = node_to_get_batch_value(shape_as_float)
    features_f = node_to_get_features_dimension_value(shape_as_float)
    spatial_i = node_to_get_spatial_dimensions_value(shape_of)
    spatial_f = Cast(graph, {'name': spatial_i.name + '/to_float', 'dst_type': float_type}).create_node()
    spatial_i.out_port(0).connect(spatial_f.in_port(0))

    num_groups = gn.num_groups
    groups_const = Const(graph, {'value': int64_array([num_groups]),
                                 'name': gn.name + '/GroupSize'}).create_node()
    # features / num_groups is evaluated as features * (1 / num_groups)
    inv_groups_const = Const(graph, {'value': np.array([1.0 / num_groups]),
                                     'name': gn.name + '/ReciprocalGroupSize'}).create_node()

    features_per_group = Mul(graph, {}).create_node()
    features_per_group.in_port(0).connect(features_f.out_port(0))
    features_per_group.in_port(1).connect(inv_groups_const.out_port(0))

    batch_times_groups = Mul(graph, {}).create_node()
    batch_times_groups.in_port(0).connect(batch_f.out_port(0))
    batch_times_groups.in_port(1).connect(groups_const.out_port(0))

    # Concatenate [batch * G, C / G, *spatial] and turn it back into an integer shape.
    mvn_shape_f = new_shape_node_from_shape_nodes([batch_times_groups, features_per_group, spatial_f])
    mvn_shape = Cast(graph, {'name': mvn_shape_f.name + '/to_int64', 'dst_type': np.int64}).create_node()
    mvn_shape_f.out_port(0).connect(mvn_shape.in_port(0))

    to_groups = Reshape(graph, {}).create_node()
    gn.in_port(0).get_connection().set_destination(to_groups.in_port(0))
    to_groups.in_port(1).connect(mvn_shape.out_port(0))

    # gamma / beta: [C] -> [1, C], [1, C, 1], [1, C, 1, 1], ... so they broadcast over the input.
    affine_shape = np.ones([input_rank], dtype=np.int64)
    affine_shape[1] = -1
    gamma = gn.in_port(1).get_source().data.get_value()
    beta = gn.in_port(2).get_source().data.get_value()
    assert gamma is not None, 'The gamma should be constant'
    assert beta is not None, 'The beta should be constant'
    gn.in_port(1).get_source().data.set_value(np.reshape(gamma, affine_shape))
    gn.in_port(2).get_source().data.set_value(np.reshape(beta, affine_shape))

    mvn = MVN(graph, {'name': gn.name + '/MVN',
                      'normalize_variance': 1,
                      'eps': gn.eps,
                      'eps_mode': 'inside_sqrt'}).create_node()
    mvn.in_port(0).connect(to_groups.out_port(0))

    # Normalization axes: Range(1, rank, 1) over the rank of the reshaped tensor.
    _, rank = get_shape_and_rank_nodes_by_port(mvn.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    axes = create_op_with_const_inputs(graph, Range, {0: int64_array(1), 2: int64_array(1)},
                                       {'name': gn.name + '/Range', 'output_type': np.int64})
    mvn.in_port(1).connect(axes.out_port(0))
    axes.in_port(1).connect(rank.out_port(0))

    # Restore the original layout before applying gamma and beta.
    from_groups = Reshape(graph, {}).create_node()
    from_groups.in_port(0).connect(mvn.out_port(0))
    from_groups.in_port(1).connect(shape_of.out_port(0))

    scale = Mul(graph, {'name': mvn.name + '/Mul'}).create_node()
    scale.in_port(0).connect(from_groups.out_port(0))
    gn.in_port(1).get_connection().set_destination(scale.in_port(1))

    shift = Add(graph, {'name': scale.name + '/Add'}).create_node()
    shift.in_port(0).connect(scale.out_port(0))
    gn.in_port(2).get_connection().set_destination(shift.in_port(1))

    gn.out_port(0).get_connection().set_source(shift.out_port(0))