def replace_sub_graph(self, graph: Graph, match: dict):
    fbn = match['fbn']
    input = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(input.op, input.name))
    if input.id != match['mean'].in_node(0).id or input.id != match['sqdiff'].in_node(0).id:
        return

    log.debug('Confirmed MVN pattern after {} with name {}'.format(input.op, input.name))

    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        eps_mode='outside_sqrt',
        normalize_variance=1
    ))
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    input_gamma = fbn.in_node(1)
    input_beta = fbn.in_node(2)

    mean_reduction = match['mean'].in_node(1)
    variance_reduction = match['variance'].in_node(1)

    new_subgraph = add.create_node([
        mul.create_node([
            mvn.create_node([input, mean_reduction, variance_reduction]),
            input_gamma
        ]),
        input_beta
    ])

    fbn.replace_node(new_subgraph)
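# A minimal numpy sketch (illustrative only, not part of the replacement above) of
# the arithmetic the MVN -> Mul -> Add chain reproduces. The names `x`, `gamma`,
# `beta` and the helper itself are assumptions for the example, not identifiers
# from the transformation.
import numpy as np

def mvn_scale_shift(x, gamma, beta, axes, eps=1e-9):
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    # eps_mode='outside_sqrt': eps is added after the square root is taken
    normalized = (x - mean) / (np.sqrt(var) + eps)
    return normalized * gamma + beta

x = np.random.rand(2, 3, 4, 5).astype(np.float32)
gamma = np.ones((1, 3, 1, 1), dtype=np.float32)
beta = np.zeros((1, 3, 1, 1), dtype=np.float32)
y = mvn_scale_shift(x, gamma, beta, axes=(0, 2, 3))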
def replace_op(self, graph: Graph, node: Node):
    # Add new nodes
    mvn = MVN(graph, {'eps': node.epsilon, 'name': node.name + '/Ins_Norm/MVN_'}).create_node()
    mul = Mul(graph, {'axis': 1, 'name': node.name + '/Ins_Norm/mul_'}).create_node()
    add = Add(graph, {'axis': 1, 'name': node.name + '/Ins_Norm/add_'}).create_node()

    # Connect nodes
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    node.in_port(1).get_connection().set_destination(mul.in_port(1))
    node.in_port(2).get_connection().set_destination(add.in_port(1))

    mvn.out_port(0).connect(mul.in_port(0))
    mul.out_port(0).connect(add.in_port(0))

    return [add.id]
def extract(cls, node):
    name = node.soft_get('name', node.id)

    axes = onnx_attr(node, 'axes', 'ints',
                     default=np.array([0, 2, 3], dtype=np.int64),
                     dst_type=lambda x: np.array(x, dtype=np.int64))

    if axes is not None:
        if 0 in axes:
            raise Error('Reduction over the batch dimension in node "{}" '
                        'is not supported by the backend.'.format(name))

        # Dimension 4 (if it is present in the input tensor) should also be in the
        # list of axes for reduction. This case is handled on the MVN Op side,
        # because the input shape is not available at this stage.
        for i in (2, 3):
            if i not in axes:
                raise Error('Reduction over spatial dimensions in node "{}" '
                            'is obligatory for the backend.'.format(name))

    attrs = {
        'eps': 1e-9,
        'across_channels': 1 if 1 in axes else 0,
        'normalize_variance': 1,
        'axes': axes
    }

    MVN.update_node_stat(node, attrs)
    return cls.enabled
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(op='MVNCaffe'):
        node_name = node.soft_get('name', node.id)

        start_axis = 2
        if node['across_channels'] == 1:
            start_axis = 1

        rank = Rank(graph, {'name': node_name + '/Rank'}).create_node()

        # create range of axes based on `start_axis` and rank of input
        rng = create_op_with_const_inputs(graph, Range,
                                          {0: int64_array(start_axis), 2: int64_array(1)},
                                          {'name': node_name + '/Range', 'output_type': np.int64})
        rng.in_port(1).connect(rank.out_port(0))

        new_mvn = MVN(graph, {'eps': node.soft_get('eps', 1e-9),
                              'eps_mode': 'inside_sqrt',
                              'normalize_variance': node.soft_get('normalize_variance', 1)
                              }).create_node([node.in_port(0).get_source().node, rng])
        new_mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
        node.out_port(0).get_connection().set_source(new_mvn.out_port(0))

        rename_nodes([(node, node_name + '/tbd'), (new_mvn, node_name)])

        graph.remove_node(node.id)
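# For reference, a numpy analogue (assuming a 4D input; `rank_value` stands in for
# the runtime output of the Rank node) of what the Rank/Range pair above produces
# as the MVN reduction axes:
import numpy as np

rank_value = 4                                   # rank of the MVN input at runtime
start_axis = 1                                   # 1 when across_channels == 1, else 2
axes = np.arange(start_axis, rank_value, 1, dtype=np.int64)
# across_channels: axes == [1, 2, 3]; otherwise axes == [2, 3]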
def extract(cls, node):
    proto_layer = node.pb
    param = proto_layer.mvn_param

    attrs = collect_attributes(param)

    # update the attributes of the node
    MVN.update_node_stat(node, attrs)
    return cls.enabled
def replace_op(self, graph: Graph, node: Node):
    prefix = node.name + '/InstanceNormalization'
    mvn = MVN(graph, dict(name=prefix + '/MVN', eps=node.epsilon))
    mul = Mul(graph, dict(name=prefix + '/Mul', axis=1))
    add = Add(graph, dict(name=prefix + '/Add', axis=1))

    new_subgraph = add.create_node([
        mul.create_node([mvn.create_node([node.in_node(0)]), node.in_node(1)]),
        node.in_node(2)
    ])
    return [new_subgraph.id]
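# A rough numpy illustration (an assumption for exposition, not code from the
# replacement above) of InstanceNormalization as MVN over the spatial axes
# followed by a per-channel scale and bias broadcast along axis 1:
import numpy as np

def instance_norm(x, scale, bias, eps):
    spatial_axes = tuple(range(2, x.ndim))       # stats per sample and per channel
    mean = x.mean(axis=spatial_axes, keepdims=True)
    var = x.var(axis=spatial_axes, keepdims=True)
    shape = (1, -1) + (1,) * (x.ndim - 2)        # [C] -> [1, C, 1, ...]
    return (x - mean) / np.sqrt(var + eps) * scale.reshape(shape) + bias.reshape(shape)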
def replace_sub_graph(graph: Graph, match: dict):
    mvn = MVN(graph, dict(
        name=match['truediv'].name + '/MVN_',
        eps_mode='outside_sqrt',
        normalize_variance=1
    ))
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mean_reduction = match['mean'].in_node(1)
    variance_reduction = match['variance'].in_node(1)
    pow2 = match['pow'].in_node(1)
    eps = match['add'].in_node(0 if match['add'].in_node(0).id != match['variance'].id else 1)

    new_subgraph = mvn.create_node([match['mean'].in_node(0), mean_reduction,
                                    variance_reduction, pow2, eps])
    match['truediv'].replace_node(new_subgraph)
def replace_pattern(self, graph: Graph, match: dict):
    node = match['op']
    node.is_training = False

    shape = node.in_port(1).data.get_shape()
    assert shape is not None, 'The shape of scale input of the BatchNorm node {} is not defined'.format(node.name)

    bn_mean = Const(graph, {'name': node.name + '/mean', 'value': np.zeros(shape, dtype=np.float32),
                            'override_output_shape': True}).create_node()
    bn_std = Const(graph, {'name': node.name + '/std', 'value': np.ones(shape, dtype=np.float32),
                           'override_output_shape': True}).create_node()
    node.in_port(3).get_connection().set_source(bn_mean.out_port(0))
    node.in_port(4).get_connection().set_source(bn_std.out_port(0))

    # save the original shape
    original_shape = Shape(graph, {'name': node.in_port(0).get_source().node.soft_get('name')}).create_node()
    original_shape.in_port(0).connect(node.in_port(0).get_source())

    mvn = MVN(graph, {'name': node.name + '/mvn_', 'eps': node.soft_get('eps', 1e-6),
                      'override_output_shape': True}).create_node()
    node.in_port(0).get_connection().insert_node(mvn)

    reshape_4d = create_op_node_with_second_input(graph, Reshape, int64_array([1, -1, 0, 0]),
                                                  {'override_output_shape': True,
                                                   'name': node.soft_get('name') + '/fused_batch_and_channels'})
    mvn.in_port(0).get_connection().insert_node(reshape_4d)

    # restore original shape
    reshape_back = Reshape(graph, {'name': mvn.soft_get('name') + '/restore_shape',
                                   'override_output_shape': True}).create_node()
    reshape_back.in_port(1).connect(original_shape.out_port(0))
    mvn.out_port(0).get_connection().insert_node(reshape_back)
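# Sketch of the reshape trick above (a numpy analogue with assumed shapes and an
# assumed eps): fusing batch and channels into one axis makes the MVN that follows
# compute statistics separately for every (sample, channel) pair over the spatial
# dimensions, then the second reshape restores the original layout.
import numpy as np

x = np.random.rand(2, 3, 4, 4).astype(np.float32)
fused = x.reshape(1, -1, 4, 4)                   # [1, -1, 0, 0] -> [1, N*C, H, W]
mean = fused.mean(axis=(2, 3), keepdims=True)    # one mean per fused channel
var = fused.var(axis=(2, 3), keepdims=True)
y = ((fused - mean) / np.sqrt(var + 1e-6)).reshape(x.shape)  # restore original shape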
def extract(cls, node):
    name = node.soft_get('name', node.id)

    axes = onnx_attr(node, 'axes', 'ints',
                     default=np.array([0, 2, 3], dtype=np.int64),
                     dst_type=lambda x: np.array(x, dtype=np.int64))

    axes = Const(node.graph, {'value': axes, 'name': name + '/Axes'}).create_node()
    node.add_input_port(1, skip_if_exist=True)
    node.in_port(1).connect(axes.out_port(0))

    attrs = {
        'eps': 1e-9,
        'normalize_variance': 1,
        'eps_mode': 'outside_sqrt'
    }

    MVN.update_node_stat(node, attrs)
    return cls.enabled
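# The extractors and replacements in this section set eps_mode to either
# 'inside_sqrt' or 'outside_sqrt'. A small numpy sketch (assumed function and
# argument names) of the difference between the two modes:
import numpy as np

def mvn(x, axes, eps=1e-9, eps_mode='outside_sqrt'):
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    if eps_mode == 'inside_sqrt':
        return (x - mean) / np.sqrt(var + eps)   # eps added to the variance
    return (x - mean) / (np.sqrt(var) + eps)     # eps added to the standard deviation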
def replace_sub_graph(graph: Graph, match: dict):
    mvn = MVN(graph, dict(
        name=match['truediv'].name + '/MVN_',
        required_reduction_indices=[1, 2] if graph.graph['layout'] == 'NHWC' else [2, 3]
    ))
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mean_reduction = match['mean'].in_node(1)
    variance_reduction = match['variance'].in_node(1)
    pow2 = match['pow'].in_node(1)
    eps = match['add'].in_node(0 if match['add'].in_node(0).id != match['variance'].id else 1)

    new_subgraph = mvn.create_node([match['mean'].in_node(0), mean_reduction,
                                    variance_reduction, pow2, eps])
    match['truediv'].replace_node(new_subgraph)
def replace_sub_graph(self, graph: Graph, match: dict):
    inp = match['pool0']
    inp_port = inp.in_port(0).get_source()

    # check the eps value of the Add, the exponent of the Pow and that both
    # ReduceMean ops reduce over the same axes
    pow_param = match['pow_param']
    add_param = match['add_param']
    if add_param.value.size == 1 and pow_param.value.size == 1 and add_param.value.item() <= 1e-05 \
            and pow_param.value.item() == 0.5 and match['pool0_param'].value == match['pool1_param'].value:
        log.debug('Found LayerNorm pattern after {} with name {}'.format(inp_port.node.op, inp_port.node.name))
        mvn = MVN(graph, {'eps': add_param.value.item(),
                          'axes': match['pool1_param'].value,
                          'normalize_variance': 1}).create_node()
        div_name = match['div'].soft_get('name', match['div'].id)
        rename_nodes([(match['div'], div_name + '/to_be_removed'), (mvn, div_name)])

        inp_port.connect(mvn.in_port(0))
        match['div'].out_port(0).get_connection().set_source(mvn.out_port(0))
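# A numpy model (illustrative; `x`, `axes` and `eps` are assumed names) of the
# subgraph being matched above: two ReduceMeans, an eps Add, a Pow with exponent
# 0.5 and a Div, which together compute what the single MVN node provides.
import numpy as np

def matched_layer_norm(x, axes, eps):
    mean = x.mean(axis=axes, keepdims=True)                  # pool0: ReduceMean
    var = ((x - mean) ** 2).mean(axis=axes, keepdims=True)   # pool1: ReduceMean of squares
    std = (var + eps) ** 0.5                                 # add eps, then Pow(0.5)
    return (x - mean) / std                                  # div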
def replace_op(self, graph: Graph, node: Node):
    axis = Const(graph, {'value': int64_array([-1])}).create_node()
    mvn = MVN(graph, dict(name=node.name + '/mvn',
                          eps=node.module.eps,
                          normalize_variance=True,
                          eps_mode='inside_sqrt')).create_node([node.in_node(0), axis])

    weight = node.module.weight.detach().numpy()
    bias = node.module.bias.detach().numpy()
    w = Const(graph, {'value': weight}).create_node()
    b = Const(graph, {'value': bias}).create_node()
    mul = Mul(graph, dict(name=node.name + '/mul')).create_node([mvn, w])
    add = Add(graph, dict(name=node.name + '/add')).create_node([mul, b])
    return [add.id]
def replace_op(self, graph: Graph, node: Node):
    name = node.soft_get('name', node.id)

    # create range of axes for MVN based on `start_axis` and rank of input
    rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    rng = create_op_with_const_inputs(graph, Range,
                                      {0: int64_array(2), 2: int64_array(1)},
                                      {'name': name + '/Range', 'output_type': np.int64})
    mvn = MVN(graph, {'eps': node.epsilon,
                      'eps_mode': 'inside_sqrt',
                      'normalize_variance': 1,
                      'name': name + '/Ins_Norm/MVN_'}).create_node()
    node.in_port(0).get_connection().set_destination(mvn.in_port(0))
    rng.out_port(0).connect(mvn.in_port(1))
    mul = Mul(graph, {'axis': 1, 'name': name + '/Ins_Norm/mul_'}).create_node()
    mvn.out_port(0).connect(mul.in_port(0))
    node.in_port(1).get_connection().set_destination(mul.in_port(1))
    add = Add(graph, {'axis': 1, 'name': name + '/Ins_Norm/add_'}).create_node()
    mul.out_port(0).connect(add.in_port(0))
    node.in_port(2).get_connection().set_destination(add.in_port(1))

    mvn.in_port(0).get_connection().add_destination(rank.in_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    rename_nodes([(node, name + '/TBD'), (add, name)])

    return [add.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    if not check_applicability(match):
        return

    reshape = match['reshape']
    div_name = match['division'].name

    input_shape = Shape(graph, dict(name=div_name + '/shape/MVN_T_')).create_node()

    shape_of_reshape = reshape.in_port(1).get_connection().get_source().node.value
    c1, c2 = shape_of_reshape[1], shape_of_reshape[2]
    c = c1 * c2

    new_reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, 0, 0, c1, c2]),
                                                   dict(name=div_name + '/first_reshape/MVN_T_'))

    permute_order = int64_array([0, 1, 2, 4, 3])
    first_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                     dict(name=div_name + '/first_permute/MVN_T_'), new_reshape)

    add = match['add']
    variance = match['variance']
    eps_port_num = 0 if add.in_port(0).get_connection().get_source().node.id != variance.id else 1
    eps = add.in_port(eps_port_num).get_connection().get_source().node
    mvn_node = MVN(graph, dict(name=div_name + '/MVN/MVN_T_',
                               required_reduction_indices=[1, 2, 3],
                               eps=eps.value)).create_node()
    first_permute.out_port(0).connect(mvn_node.in_port(0))

    second_permute = create_op_node_with_second_input(graph, Transpose, permute_order,
                                                      dict(name=div_name + '/second_permute/MVN_T_'), mvn_node)
    new_reshape2 = Reshape(graph, dict(name=div_name + '/second_reshape/MVN_T_')).create_node()
    second_permute.out_port(0).connect(new_reshape2.in_port(0))
    gamma_val = np.reshape(match['gamma_identity'].in_port(0).get_connection().get_source().node.value,
                           int64_array([1, 1, 1, c]))
    new_mul = create_op_node_with_second_input(graph, Mul, gamma_val,
                                               dict(name=match['mul'].name + '/MVN_T_'), new_reshape2)
    beta_val = np.reshape(match['beta_identity'].in_port(0).get_connection().get_source().node.value,
                          int64_array([1, 1, 1, c]))
    new_add2 = create_op_node_with_second_input(graph, Add, beta_val,
                                                dict(name=match['add2'].name + '/MVN_T_'), new_mul)

    transpose_connection = match['transpose'].in_port(0).get_connection()
    before_transpose = transpose_connection.get_source().node
    transpose_connection.set_destination(new_reshape.in_port(0))
    input_shape.out_port(0).connect(new_reshape2.in_port(1))
    before_transpose.out_port(0).connect(input_shape.in_port(0))

    match['transpose2'].out_port(0).get_connection().set_source(new_add2.out_port(0))
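# A numpy trace (illustrative, with assumed sizes and an assumed eps in place of
# the matched Add constant) of the data movement above: split the last axis into
# [c1, c2], swap them, normalize over indices [1, 2, 3] of the permuted tensor,
# then swap back and restore the original shape before the gamma/beta affine.
import numpy as np

n, h, w, c1, c2 = 1, 4, 4, 2, 3
x = np.random.rand(n, h, w, c1 * c2).astype(np.float32)
t = x.reshape(n, h, w, c1, c2).transpose(0, 1, 2, 4, 3)   # first_reshape + first_permute
mean = t.mean(axis=(1, 2, 3), keepdims=True)
var = t.var(axis=(1, 2, 3), keepdims=True)
t = (t - mean) / np.sqrt(var + 1e-6)                      # MVN over indices [1, 2, 3]
y = t.transpose(0, 1, 2, 4, 3).reshape(x.shape)           # second_permute + second_reshape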
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node)
    initial_spatial_dims_node = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one
    new_shape_node = new_shape_node_from_shape_nodes(
        [batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node])

    reshape_for_mvn_node = Reshape(graph, {}).create_node()
    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # reshape the gamma and beta constants to the correct layout:
    # from [C] to [1, C], [1, C, 1], [1, C, 1, 1], etc.
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'across_channels': 1,
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
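# A numpy sketch (assumed shapes and names, not the transformation itself) of the
# GroupNorm decomposition built above: reshape to [N*G, C//G, spatial...],
# normalize over everything but the first axis (across_channels=1), reshape back,
# then apply gamma/beta broadcast along the channel axis.
import numpy as np

def group_norm(x, gamma, beta, num_groups, eps):
    n, c = x.shape[:2]
    fused = x.reshape((n * num_groups, c // num_groups) + x.shape[2:])
    axes = tuple(range(1, fused.ndim))           # reduce over all axes but the first
    mean = fused.mean(axis=axes, keepdims=True)
    var = fused.var(axis=axes, keepdims=True)
    y = ((fused - mean) / np.sqrt(var + eps)).reshape(x.shape)
    shape = (1, -1) + (1,) * (x.ndim - 2)        # [C] -> [1, C, 1, ...]
    return y * gamma.reshape(shape) + beta.reshape(shape)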
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    group_norm_node = match['op']
    group_norm_num_input_dims = len(group_norm_node.in_port(0).data.get_shape())

    # node computing initial GroupNorm input shape
    initial_shape_op_node = Shape(graph, {'name': group_norm_node.name + '/Shape'}).create_node()
    initial_shape_op_node.in_port(0).connect(group_norm_node.in_port(0).get_source())

    initial_shape_op_node_float = Cast(
        graph, {'name': initial_shape_op_node.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_shape_op_node.out_port(0).connect(initial_shape_op_node_float.in_port(0))

    initial_batch_dim_node = node_to_get_batch_value(initial_shape_op_node_float)
    initial_features_dim_node = node_to_get_features_dimension_value(initial_shape_op_node_float)
    initial_spatial_dims_node_int = node_to_get_spatial_dimensions_value(initial_shape_op_node)
    initial_spatial_dims_node = Cast(
        graph, {'name': initial_spatial_dims_node_int.name + '/to_float',
                'dst_type': data_type_str_to_np(graph.graph['cmd_params'].data_type)}).create_node()
    initial_spatial_dims_node_int.out_port(0).connect(initial_spatial_dims_node.in_port(0))

    group_size_node = Const(graph, {'value': int64_array([group_norm_node.num_groups]),
                                    'name': group_norm_node.name + '/GroupSize'}).create_node()

    # calculate "features // group_size" value
    reciprocal_group_size_node = Const(graph, {'value': np.array([1.0 / group_norm_node.num_groups]),
                                               'name': group_norm_node.name + '/ReciprocalGroupSize'}).create_node()

    c_div_g_node = Mul(graph, {}).create_node()
    c_div_g_node.in_port(0).connect(initial_features_dim_node.out_port(0))
    c_div_g_node.in_port(1).connect(reciprocal_group_size_node.out_port(0))

    batch_mul_group_size_node = Mul(graph, {}).create_node()
    batch_mul_group_size_node.in_port(0).connect(initial_batch_dim_node.out_port(0))
    batch_mul_group_size_node.in_port(1).connect(group_size_node.out_port(0))

    # create new node which concatenates several dims to one
    new_shape_node_float = new_shape_node_from_shape_nodes(
        [batch_mul_group_size_node, c_div_g_node, initial_spatial_dims_node])
    new_shape_node = Cast(graph, {'name': new_shape_node_float.name + '/to_int64',
                                  'dst_type': np.int64}).create_node()
    new_shape_node_float.out_port(0).connect(new_shape_node.in_port(0))

    reshape_for_mvn_node = Reshape(graph, {}).create_node()
    group_norm_node.in_port(0).get_connection().set_destination(reshape_for_mvn_node.in_port(0))
    reshape_for_mvn_node.in_port(1).connect(new_shape_node.out_port(0))

    # reshape the gamma and beta constants to the correct layout:
    # from [C] to [1, C], [1, C, 1], [1, C, 1, 1], etc.
    gamma_beta_shape = np.ones([group_norm_num_input_dims], dtype=np.int64)
    gamma_beta_shape[1] = -1

    gamma_value = group_norm_node.in_port(1).get_source().data.get_value()
    beta_value = group_norm_node.in_port(2).get_source().data.get_value()
    assert gamma_value is not None, 'The gamma should be constant'
    assert beta_value is not None, 'The beta should be constant'
    gamma_value = np.reshape(gamma_value, gamma_beta_shape)
    group_norm_node.in_port(1).get_source().data.set_value(gamma_value)
    beta_value = np.reshape(beta_value, gamma_beta_shape)
    group_norm_node.in_port(2).get_source().data.set_value(beta_value)

    # MVN
    mvn_node = MVN(graph, {'name': group_norm_node.name + '/MVN',
                           'normalize_variance': 1,
                           'eps': group_norm_node.eps,
                           'eps_mode': 'inside_sqrt'}).create_node()
    mvn_node.in_port(0).connect(reshape_for_mvn_node.out_port(0))

    # MVN axes
    _, rank = get_shape_and_rank_nodes_by_port(mvn_node.in_port(0).get_connection().get_source(),
                                               return_as_a_scalar=True)
    rng = create_op_with_const_inputs(graph, Range,
                                      {0: int64_array(1), 2: int64_array(1)},
                                      {'name': group_norm_node.name + '/Range', 'output_type': np.int64})
    mvn_node.in_port(1).connect(rng.out_port(0))
    rng.in_port(1).connect(rank.out_port(0))

    # reshape to the initial shape before multiplying with gamma and adding beta
    reshape_to_initial_shape_node = Reshape(graph, {}).create_node()
    reshape_to_initial_shape_node.in_port(0).connect(mvn_node.out_port(0))
    reshape_to_initial_shape_node.in_port(1).connect(initial_shape_op_node.out_port(0))

    mul_node = Mul(graph, {'name': mvn_node.name + '/Mul'}).create_node()
    mul_node.in_port(0).connect(reshape_to_initial_shape_node.out_port(0))
    group_norm_node.in_port(1).get_connection().set_destination(mul_node.in_port(1))

    add_node = Add(graph, {'name': mul_node.name + '/Add'}).create_node()
    add_node.in_port(0).connect(mul_node.out_port(0))
    group_norm_node.in_port(2).get_connection().set_destination(add_node.in_port(1))

    group_norm_node.out_port(0).get_connection().set_source(add_node.out_port(0))
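# Why the Cast nodes appear in this variant (a numpy illustration with assumed
# values): C // G is computed as C * (1/G) in floating point, so the assembled
# target shape is float and has to be cast back to int64 before feeding Reshape.
import numpy as np

c = np.array([64.0])                      # features dim, cast to float
recip_g = np.array([1.0 / 4])             # ReciprocalGroupSize for num_groups == 4
c_div_g = c * recip_g                     # 16.0, still a float value
new_shape = np.array([2 * 4, c_div_g.item(), 28, 28]).astype(np.int64)  # Cast to int64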