def replace_op(self, graph: Graph, node: Node):
    """Expand ``node`` into an MVN -> Mul -> Add chain.

    Input 0 of ``node`` is re-routed to the MVN, input 1 to the second input
    of the Mul and input 2 to the second input of the Add.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the final Add node
    """
    base = node.name

    # New operations
    normalize = MVN(graph, {'eps': node.epsilon, 'name': base + '/Ins_Norm/MVN_'}).create_node()
    scale = Mul(graph, {'axis': 1, 'name': base + '/Ins_Norm/mul_'}).create_node()
    shift = Add(graph, {'axis': 1, 'name': base + '/Ins_Norm/add_'}).create_node()

    # Hand the three inputs of the original node over to the new operations
    new_consumers = (normalize.in_port(0), scale.in_port(1), shift.in_port(1))
    for port_idx, consumer in enumerate(new_consumers):
        node.in_port(port_idx).get_connection().set_destination(consumer)

    # Chain the new operations together
    normalize.out_port(0).connect(scale.in_port(0))
    scale.out_port(0).connect(shift.in_port(0))

    return [shift.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched ``fbn`` node with an MVN -> Mul -> Add chain.

    Nothing is changed unless the first input of ``fbn`` is also the first
    input of both the matched ``mean`` and ``sqdiff`` nodes.

    :param graph: graph to modify
    :param match: matched nodes; reads the keys 'fbn', 'mean', 'sqdiff' and 'variance'
    """
    fbn = match['fbn']
    data = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(data.op, data.name))

    shares_input = (data.id == match['mean'].in_node(0).id
                    and data.id == match['sqdiff'].in_node(0).id)
    if not shares_input:
        return

    log.debug('Confirmed MVN pattern after {} with name {}'.format(data.op, data.name))

    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        eps_mode='outside_sqrt',
        normalize_variance=1
    ))
    # Keep the stock infer function reachable and install the one of the enclosing class
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    gamma = fbn.in_node(1)
    beta = fbn.in_node(2)
    mean_axes = match['mean'].in_node(1)
    variance_axes = match['variance'].in_node(1)

    normalized = mvn.create_node([data, mean_axes, variance_axes])
    scaled = mul.create_node([normalized, gamma])
    shifted = add.create_node([scaled, beta])

    fbn.replace_node(shifted)
def replace_pattern(graph: Graph, match: dict):
    """Move the bias input (port 2) of the matched op into a separate Add node.

    The Add takes over the original name of the op, while the op itself gets
    the '/WithoutBiases' suffix. If the bias comes from a Const node, the data
    has more than two dimensions and the graph layout is 'NCHW', a Reshape is
    placed in front of the bias input of the Add.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    op_node = match['op']

    if not op_node.has_port('in', 2):
        return
    if op_node.in_port(2).disconnected() or op_node.has_and_set('shape_input'):
        return

    original_name = op_node.name
    stripped_name = op_node.name + '/WithoutBiases'

    bias_add = Add(graph, dict(name=original_name)).create_node()
    rename_nodes([(op_node, stripped_name), (bias_add, original_name)])

    # Consumers of the op now read from the Add; the op and its bias feed the Add
    op_node.out_port(0).get_connection().set_source(bias_add.out_port(0))
    op_node.out_port(0).connect(bias_add.in_port(0))
    op_node.in_port(2).get_connection().set_destination(bias_add.in_port(1))

    bias_producer = bias_add.in_port(1).get_source().node
    if not (bias_producer.has_valid("type") and bias_producer.type == "Const"):
        return

    input_shape = bias_add.in_port(0).data.get_shape()
    if not len(input_shape) > 2:
        return

    if graph.graph['layout'] == 'NCHW':
        dims_to_add = len(input_shape) - 2
    else:
        dims_to_add = 0
    if not dims_to_add > 0:
        return

    target_shape = np.array([input_shape[1]] + [1] * dims_to_add, dtype=np.int64)
    reshape = create_op_node_with_second_input(graph, Reshape, target_shape,
                                               {'name': op_node.id + '/Dims'})
    bias_add.in_port(1).get_connection().set_destination(reshape.in_port(0))
    reshape.out_port(0).connect(bias_add.in_port(1))
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched op with an explicit Mul / Add / Pow chain.

    A stage is emitted only when the corresponding attribute differs from its
    neutral value: 'scale' != 1 gives a Mul, 'shift' != 0 an Add and
    'power' != 1 a Pow, in that order. The consumers of the op are then fed
    from the last emitted stage, or directly from the producer of the op
    input when no stage was needed.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    op = match['op']
    current = op.in_port(0).get_source()

    # (attribute, neutral value, operation class, name suffix)
    stages = (
        ('scale', 1, Mul, '/mul_'),
        ('shift', 0, Add, '/add_'),
        ('power', 1, Pow, '/pow_'),
    )

    for attr, neutral, op_class, suffix in stages:
        if op.soft_get(attr, neutral) != neutral:
            operand = Const(graph, {'value': np.array(getattr(op, attr))}).create_node()
            stage = op_class(graph, {'name': op.name + suffix}).create_node()
            operand.out_port(0).connect(stage.in_port(1))
            current.connect(stage.in_port(0))
            current = stage.out_port(0)

    op.out_port(0).get_connection().set_source(current)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Rewrite the matched node into a MatMul with explicit bias and reshapes.

    Steps performed:
      * a connected bias input (port 2) is moved into a separate Add node;
      * the weights (port 1) are reshaped using the 'out-size' attribute and
        additionally transposed when 'transpose_weights' is set;
      * for the 'caffe' and 'mxnet' frameworks the data input is reshaped with [0, -1];
      * the node attributes are updated through ``MatMul.update_node_stat``.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    fc = match['op']
    name = fc.soft_get('name', fc.id)

    # biases normalization
    bias_connected = 2 in fc.in_ports() and not fc.in_port(2).disconnected()
    if bias_connected:
        bias_add = Add(graph, {'name': name + '/Bias_'}).create_node()
        if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
            # The Add inherits the original name, the node itself gets a suffix
            rename_nodes([(fc, fc.name + '/WithoutBiases'), (bias_add, fc.name)])
        fc.out_port(0).get_connection().set_source(bias_add.out_port(0))
        fc.in_port(2).get_connection().set_destination(bias_add.in_port(1))
        fc.out_port(0).connect(bias_add.in_port(0))

    # weights normalization
    assert fc.has_valid('out-size')
    out_size = fc['out-size']
    transposed = fc.has_and_set('transpose_weights')
    weights_shape = int64_array([out_size, -1]) if transposed else int64_array([-1, out_size])

    fc.insert_op_on_input_port(in_port_idx=1,
                               new_op_class=Reshape,
                               new_op_attrs={'name': name + '/weights_reshape'},
                               value=weights_shape)
    if transposed:
        fc.insert_op_on_input_port(in_port_idx=1,
                                   new_op_class=Transpose,
                                   new_op_attrs={'name': name + '/weights_transpose'},
                                   value=int64_array([1, 0]))

    # input normalization for 4D Caffe and MxNet FullyConnected
    if graph.graph['fw'] in ['caffe', 'mxnet']:
        fc.insert_op_on_input_port(in_port_idx=0,
                                   new_op_class=Reshape,
                                   new_op_attrs={'name': name + '/flatten_fc_input'},
                                   value=int64_array([0, -1]))

    MatMul.update_node_stat(fc, {})
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Replace the matched 'complex' + 'abs' pair with Pow(Pow(a, 2) + Pow(b, 2), 0.5).

    ``a`` and ``b`` are the producers of inputs 0 and 1 of the 'complex' node.
    The final Pow takes over the name of the 'abs' node.

    :param graph: graph to modify
    :param match: matched nodes; reads the keys 'complex' and 'abs'
    """
    complex_node = match['complex']
    abs_node = match['abs']
    abs_name = abs_node.soft_get('name', abs_node.id)

    # Exponent constants use the data type requested on the command line
    exponent_type = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    real_squared = create_op_with_const_inputs(graph, Pow, {1: exponent_type(2.0)},
                                               {'name': abs_name + '/real_part_squared'})
    imag_squared = create_op_with_const_inputs(graph, Pow, {1: exponent_type(2.0)},
                                               {'name': abs_name + '/imag_part_squared'})

    complex_node.in_port(0).get_connection().set_destination(real_squared.in_port(0))
    complex_node.in_port(1).get_connection().set_destination(imag_squared.in_port(0))

    squared_abs = Add(graph, {'name': abs_name + '/squared_abs'}).create_node(
        [real_squared, imag_squared])

    root = create_op_with_const_inputs(graph, Pow, {1: exponent_type(0.5)}, {})
    squared_abs.out_port(0).connect(root.in_port(0))

    abs_node.out_port(0).get_connection().set_source(root.out_port(0))
    rename_nodes([(abs_node, abs_name + '/to_be_removed'), (root, abs_name)])
def apply_mean_value(graph: Graph, input_node: Node, node_mean_scale_values: dict):
    """Insert an Add with the negated mean values right after ``input_node``.

    Nothing is done when there is no 'mean' entry, when it is None or when
    all the mean values are zero.

    :param graph: graph to modify
    :param input_node: node whose output gets the mean subtracted
    :param node_mean_scale_values: dictionary that may carry a 'mean' entry
    :raises Error: if ``input_node`` has no valid 'shape' attribute
    """
    if 'mean' not in node_mean_scale_values or node_mean_scale_values['mean'] is None:
        return
    if all(x == 0 for x in node_mean_scale_values['mean']):
        return

    out_node = input_node.out_node()

    if not input_node.has_valid('shape'):
        raise Error("Node {} has not valid shape attribute".format(input_node.id))
    input_shape = input_node.shape

    # Create Add node
    graph.remove_edge(input_node.id, out_node.id)

    negated_mean = np.array(node_mean_scale_values['mean']) * (-1)

    add_node = Add(graph, dict(name="Add_"))
    add_data = Op.create_input_data_node(graph, "data_add_", np.array(negated_mean))

    # Extra dimensions are appended for the 'NCHW' layout only
    if graph.graph['layout'] == 'NCHW':
        dims_to_add = len(input_shape) - 2
    else:
        dims_to_add = 0
    Op.expand_node_shape(add_data, dims_to_add)

    add_input = Op.create_data_node(graph, input_node, {'shape': out_node.shape})
    add_node.create_node_with_data(inputs=[add_input, add_data], data_nodes=out_node)
def extract(cls, node: Node):
    """Fill ``node`` with Add attributes taken from its protobuf description.

    The axis is read from ``node.pb.bias_param`` and the first blob of
    ``node.model_pb`` is embedded as the 'bias' input on port 1.

    :param node: node to update
    :return: the ``enabled`` flag of the class
    """
    add_attrs = {'axis': node.pb.bias_param.axis}
    bias_blob = node.model_pb.blobs[0].data
    embed_input(add_attrs, 1, 'bias', bias_blob, 'biases')

    Add.update_node_stat(node, add_attrs)
    return cls.enabled
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched op with a Mul -> Add sequence.

    The Mul is skipped when every element of ``op.scale`` equals 1 and the
    Add is skipped when every element of ``op.bias`` equals 0.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    op = match['op']

    # Weights / biases made only of neutral values do not need an operation
    has_weights = not all(x == 1 for x in np.nditer(op.scale))
    has_bias = not all(x == 0 for x in np.nditer(op.bias))

    assert len(op.in_ports()) == 1

    tail = op.in_port(0).get_source()

    if has_weights:
        weights = Const(graph, dict(value=op.scale, shape=op.scale.shape)).create_node()
        mul = Mul(graph, dict(name=op.id + '/mul_')).create_node()
        op.in_port(0).get_connection().set_destination(mul.in_port(0))
        weights.out_port(0).connect(mul.in_port(1))
        tail = mul.out_port(0)

    if has_bias:
        biases = Const(graph, dict(value=op.bias, shape=op.bias.shape)).create_node()
        add = Add(graph, dict(name=op.id + '/add_')).create_node()
        tail.get_connection().set_destination(add.in_port(0))
        biases.out_port(0).connect(add.in_port(1))
        tail = add.out_port(0)

    # Detach the original op and feed its consumers from the end of the new chain
    op.in_port(0).disconnect()
    op.out_port(0).get_connection().set_source(tail)
def replace_pattern(self, graph: Graph, match: dict):
    """Decompose the matched op into elementwise operations.

    The resulting expression is
    ``(input + mean * -1) * (variance + eps) ** -0.5 * scale + offset``
    and its result is written to the output data node of the original op,
    which is then removed from the graph.

    The op is left untouched unless its data format is b'NHWC', it has five
    inputs, scale and offset are known 1D values, and input, mean and
    variance have no known value.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    node = match['op']

    applicable = (node.data_format == b'NHWC'
                  and len(node.in_nodes()) == 5
                  and node.in_node(0).value is None  # input
                  and node.in_node(1).value is not None  # scale
                  and node.in_node(2).value is not None  # offset
                  and node.in_node(3).value is None  # mean
                  and node.in_node(4).value is None  # variance
                  and node.in_node(1).value.ndim == 1
                  and node.in_node(2).value.ndim == 1)
    if not applicable:
        return

    base = node.name

    scale_mul = Mul(graph, dict(name=base + '/scale_mul_'))
    shift_add = Add(graph, dict(name=base + '/shift_add_'))
    mean_add = Add(graph, dict(name=base + '/mean_add_'))
    variance_mul = Mul(graph, dict(name=base + '/variance_mul_'))

    # input - mean, expressed as input + mean * (-1)
    neg_const = Const(graph, dict(value=np.array(-1), name=base + '/mean_negate_'))
    mean_negate = Mul(graph, dict(name=base + '/mean_negate_'))
    minus_one = neg_const.create_node_with_data()
    negated_mean = mean_negate.create_node_with_data([node.in_node(3), minus_one])
    mean_arg = mean_add.create_node_with_data([node.in_node(0), negated_mean])

    # (variance + eps) ** -0.5
    shift_const = Const(graph, dict(value=node.eps, name=base + '/variance_denom_shift_const_'))
    power_const = Const(graph, dict(value=-0.5, name=base + '/variance_denom_power_const_'))
    variance_denom_shift = Add(graph, dict(name=base + '/variance_denom_shift_'))
    variance_denom_power = Pow(graph, dict(name=base + '/variance_denom_power_'))

    eps_data = shift_const.create_node_with_data()
    shifted_variance = variance_denom_shift.create_node_with_data([node.in_node(4), eps_data])
    exponent = power_const.create_node_with_data()
    inv_std = variance_denom_power.create_node_with_data([shifted_variance, exponent])

    variance_arg = variance_mul.create_node_with_data([mean_arg, inv_std])

    # * scale + offset, written to the output of the original op
    scaled = scale_mul.create_node_with_data([variance_arg, node.in_node(1)])
    shift_add.create_node_with_data([scaled, node.in_node(2)], data_nodes=node.out_node())

    node.graph.remove_node(node.id)
def replace_pattern(self, graph: Graph, match: dict):
    """Replace the matched 'BiasAdd' node with a plain Add.

    When the data format of the original node is 'NCHW', the 1D bias is
    additionally passed through an Unsqueeze that adds every dimension of
    the input except the feature one.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'BiasAdd'
    """
    bias_add = match['BiasAdd']

    # Replace BiasAdd by Add operation
    add = Add(graph, {'name': bias_add.id + '/Add'}).create_node()
    for port_idx in (0, 1):
        bias_add.in_port(port_idx).get_connection().set_destination(add.in_port(port_idx))
    bias_add.out_port(0).get_connection().set_source(add.out_port(0))

    if bias_add.data_format != 'NCHW':
        return

    input_shape = add.in_port(0).data.get_shape()
    bias_shape = add.in_port(1).data.get_shape()
    assert len(bias_shape) == 1

    # All the dimensions of the input except the feature dimension
    all_dims = np.arange(len(input_shape))
    feature_dim = get_features_dim('NCHW', len(input_shape))
    dims_to_add = np.delete(all_dims, feature_dim, 0)

    unsqueeze = Unsqueeze(graph, {'name': add.id + '/BiasUnsqueeze'}).create_node()
    dims_const = Const(graph, {'name': add.id + '/Dims', 'value': dims_to_add}).create_node()

    # Reconnecting nodes
    unsqueeze.in_port(1).connect(dims_const.out_port(0))
    unsqueeze['override_output_shape'] = True
    add.in_port(1).get_connection().insert_node(unsqueeze)
def replace_sub_graph(self, graph: Graph, match: [dict, SubgraphMatch]):
    """Rewrite the matched node into MatMul followed by an Add of its third input.

    'alpha' and 'beta' attributes that are not close to 1 become Mul nodes on
    inputs 0 and 1 of the Add respectively and are deleted from the node.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    gemm = match['op']
    name = gemm.soft_get('name', gemm.id)

    # biases normalization
    bias_add = Add(graph, {'name': name + '/Bias_', 'can_be_scaleshift': False}).create_node()
    if not graph.graph['cmd_params'].generate_deprecated_IR_V7:
        # The Add inherits the original name, the node itself gets a suffix
        rename_nodes([(gemm, gemm.name + '/WithoutBiases'), (bias_add, gemm.name)])
    gemm.out_port(0).get_connection().set_source(bias_add.out_port(0))
    gemm.in_port(2).get_connection().set_destination(bias_add.in_port(1))
    gemm.out_port(0).connect(bias_add.in_port(0))

    needs_alpha = gemm.has_valid('alpha') and not math.isclose(gemm.alpha, 1)
    if needs_alpha:
        bias_add.insert_op_on_input_port(
            in_port_idx=0,
            new_op_class=Mul,
            value=np.array(gemm.alpha),
            new_op_attrs={'name': name + '/Alpha_', 'can_be_scaleshift': False})
        del gemm['alpha']

    needs_beta = gemm.has_valid('beta') and not math.isclose(gemm.beta, 1)
    if needs_beta:
        bias_add.insert_op_on_input_port(
            in_port_idx=1,
            new_op_class=Mul,
            value=np.array(gemm.beta),
            new_op_attrs={'name': name + '/Beta_', 'can_be_scaleshift': False})
        del gemm['beta']

    transpose_attrs = {
        'transpose_a': gemm.has_and_set('transpose_a'),
        'transpose_b': gemm.has_and_set('transpose_b'),
    }
    MatMul.update_node_stat(gemm, transpose_attrs)
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with a ScaleShift followed by an Add.

    The ScaleShift receives inputs 1 and 0 of ``node`` (in this order) and
    the Add combines its result with input 2.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the Add node
    """
    first, second, third = node.in_node(0), node.in_node(1), node.in_node(2)

    scale_shift_op = ScaleShiftOp(graph, {'name': node.id + "/ScaleShift_", 'axis': 0})
    scale_shift = scale_shift_op.create_node(inputs=[second, first])

    add_op = Add(graph, {'name': node.id + "/Add_"})
    result = add_op.create_node(inputs=[scale_shift, third])

    return [result.id]
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with Add(Mul(input 1, -1), input 0).

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the Add node
    """
    # New operations
    minus_one = Const(graph, dict(value=np.array(-1, dtype=np.int32))).create_node()
    negated = Mul(graph, {'name': node.name + '/negate_'}).create_node()
    total = Add(graph, {'name': node.name + '/add_'}).create_node()

    # The second input gets multiplied by -1 ...
    node.in_port(1).get_connection().set_destination(negated.in_port(0))
    minus_one.out_port(0).connect(negated.in_port(1))

    # ... and is then summed with the first input
    node.in_port(0).get_connection().set_destination(total.in_port(1))
    negated.out_port(0).connect(total.in_port(0))

    # The "explicit" version of the return value is: [(out_node.id, 0)])
    return [total.id]
def replace_pattern(graph: Graph, match: [str, Node]):
    """Replace the matched 'sub' node with Add(Mul(input 1, -1), input 0).

    The consumers of the 'sub' output are re-attached to the output of the Add.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'sub'
    """
    sub = match['sub']

    # New operations
    minus_one = Const(graph, dict(name=sub.name + '/negate_const',
                                  value=np.array(-1))).create_node()
    negated = Mul(graph, {'name': sub.name + '/negate_'}).create_node()
    total = Add(graph, {'name': sub.name + '/add_'}).create_node()

    # The second input gets multiplied by -1 ...
    sub.in_port(1).get_connection().set_destination(negated.in_port(0))
    minus_one.out_port(0).connect(negated.in_port(1))

    # ... and is then summed with the first input
    sub.in_port(0).get_connection().set_destination(total.in_port(1))
    negated.out_port(0).connect(total.in_port(0))

    sub.out_port(0).get_connection().set_source(total.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with a MatMul (``transpose_b=True``) and an optional bias Add.

    The Add is created only when ``node`` has more than two inputs; its
    second operand is input 2 of ``node``.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the last created node
    """
    data, weights = node.in_node(0), node.in_node(1)
    result = MatMul(graph, dict(name=node.name, transpose_b=True)).create_node([data, weights])

    # Bias
    has_bias = len(node.in_nodes()) > 2
    if has_bias:
        bias_add = Add(graph, dict(name=node.name + '/bias'))
        result = bias_add.create_node([result, node.in_node(2)])

    return [result.id]
def create_bias_node(graph: Graph, src_node):
    """Insert an Add with an all-zero constant between ``src_node`` and its consumers.

    The constant has ones in every dimension except dimension 1, which is
    copied from the output shape of ``src_node``. Its data type follows the
    weights of ``src_node`` when their type is defined, np.float32 otherwise.

    :param graph: graph to modify
    :param src_node: node whose output 0 receives the new bias
    """
    logger.debug('Creating new bias for {}'.format(src_node.name))

    # Remember the consumers before the output gets disconnected
    consumers = list(src_node.out_port(0).get_destinations())

    # Create Add and constant with zero bias
    output_shape = src_node.out_port(0).data.get_shape()
    zero_bias_shape = [1] * len(output_shape)
    zero_bias_shape[1] = output_shape[1]

    weights = get_weights_for_node(src_node)
    if weights and weights.out_port(0).is_data_type_defined():
        bias_dtype = weights.out_port(0).get_data_type()
    else:
        bias_dtype = np.float32

    zero_bias = Const(graph, {
        'value': np.zeros(zero_bias_shape, dtype=bias_dtype),
        'shape': zero_bias_shape,
        'need_shape_inference': True
    }).create_node()
    bias_add = Add(graph, {
        'name': src_node.name + '/add_',
        'need_shape_inference': True
    }).create_node()

    # Connect Const to Add node
    bias_add.in_port(1).connect(zero_bias.out_port(0))

    # Reconnect src_node -> output to src_node -> Add -> output
    src_node.out_port(0).disconnect()
    src_node.out_port(0).get_connection().set_destination(bias_add.in_port(0))
    for consumer in consumers:
        bias_add.out_port(0).connect(consumer)

    zero_bias.out_node(0)['Insert_Convert_operation_after'] = True
def get_canonical_axis_index_node(rank: Node, axis: int) -> Node:
    """
    Builds a node that produces the non-negative counterpart of the axis.

    A non-negative axis is emitted as a Const; a negative one as Add(rank, axis).

    :param rank: the node of 0D output shape to get rank of tensor from
    :param axis: integer value from [-rank; rank - 1]
    :return: node producing positive integer value of axis
    """
    graph = rank.graph
    name = rank.soft_get('name', rank.id)

    if not axis < 0:
        return Const(graph, {'name': name + '/positive_axis',
                             'value': int64_array(axis)}).create_node()

    negative_axis = Const(graph, {'name': name + '/negative_axis',
                                  'value': int64_array(axis)}).create_node()
    shifted = Add(graph, {'name': name + '/positive_axis'}).create_node()
    rank.out_port(0).connect(shifted.in_port(0))
    negative_axis.out_port(0).connect(shifted.in_port(1))
    return shifted
def get_range_node_of_idxs(rank: Node, begin: int, end: int,
                           include_begin: bool = True, include_end: bool = False) -> Node:
    """
    Builds a Range node that yields the indices between begin and end with step 1.

    Both bounds are first canonicalized with get_canonical_axis_index_node. The start
    is increased by one when include_begin is False and the end is increased by one
    when include_end is True.

    :param rank: the node of 0D output shape to get rank of tensor from
    :param begin: integer value from [-rank; rank - 1]
    :param end: integer value from [-rank; +rank]
    :param include_begin: boolean flag to include or exclude start point from range output
    :param include_end: boolean flag to include or exclude end point from range output
    :return: range node producing 1D output
    """
    graph = rank.graph
    name = rank.soft_get('name', rank.id)

    def _plus_one(index_node: Node, const_name: str, add_name: str) -> Node:
        # Returns a node computing index_node + 1
        one = Const(graph, {'value': int64_array(1), 'name': const_name}).create_node()
        incremented = Add(graph, {'name': add_name}).create_node()
        index_node.out_port(0).connect(incremented.in_port(0))
        one.out_port(0).connect(incremented.in_port(1))
        return incremented

    start_idx = get_canonical_axis_index_node(rank, begin)
    end_idx = get_canonical_axis_index_node(rank, end)

    if not include_begin:
        start_idx = _plus_one(start_idx, name + '/exclude_begin/value', name + '/exclude_begin')

    if include_end:
        end_idx = _plus_one(end_idx, name + '/including_end/value', name + '/including_end')

    delta = Const(graph, {'name': name + '/delta', 'value': int64_array(1)}).create_node()
    range_node = Range(graph, {'name': name + '/range_idxs'}).create_node()

    # Range inputs: 0 - start, 1 - end, 2 - step
    for port_idx, producer in enumerate((start_idx, end_idx, delta)):
        producer.out_port(0).connect(range_node.in_port(port_idx))

    return range_node
def replace_pattern(self, graph: Graph, match: dict):
    """Feed inputs 1 and 2 of the matched 'quantize' node with their average.

    The producers of inputs 1 and 2 are summed, the sum is multiplied by 0.5
    and the result replaces both of these inputs.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'quantize'
    """
    quantize = match['quantize']

    total = Add(graph, {}).create_node()
    half = Const(graph, {'value': np.array(0.5)}).create_node()
    average = Mul(graph, {}).create_node()

    average.in_port(0).connect(total.out_port(0))
    average.in_port(1).connect(half.out_port(0))

    # Inputs 1 and 2 of the quantize node become operands 0 and 1 of the sum
    for port_idx in (1, 2):
        producer = quantize.in_port(port_idx).get_connection().get_source()
        producer.connect(total.in_port(port_idx - 1))

    for port_idx in (1, 2):
        quantize.in_port(port_idx).disconnect()

    for port_idx in (1, 2):
        average.out_port(0).connect(quantize.in_port(port_idx))
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with MatMul(input, weight), followed by Add(bias) when a bias exists.

    Weight and bias are read from ``node.module`` and embedded as Const nodes.
    A module whose ``bias`` is None yields the MatMul alone, instead of failing
    on ``None.detach()``.

    NOTE(review): ``node.module`` is presumably a torch.nn.Linear-like module
    (tensors exposing ``detach().numpy()``) -- confirm against the extractor.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the last created node
    """
    module = node.module
    has_bias = module.bias is not None

    weight = Const(graph, {'value': module.weight.detach().numpy()}).create_node()
    # Node creation order (weight, bias, MatMul, Add) is kept as it was
    bias = None
    if has_bias:
        bias = Const(graph, {'value': module.bias.detach().numpy()}).create_node()

    result = MatMul(graph, dict(name=node.name)).create_node([node.in_node(0), weight])
    if has_bias:
        result = Add(graph, dict(name=node.name + '/bias')).create_node([result, bias])

    return [result.id]
def add_constant_to_negative_values(node: Node, port_idx: int, added_value: np.array):
    """
    Inserts the sub-graph ``x + Less(x, 0) * added_value`` in front of the given input port,
    where ``x`` is the value that originally came to this port.

    :param node: node with corrected values in the input port port_idx
    :param port_idx: input port index for negative values
    :param added_value: the value to add
    :return: None
    """
    graph = node.graph

    values_source = node.in_port(port_idx).get_source()
    producer = node.in_port(port_idx).get_source().node
    producer_name = producer.soft_get('name', producer.id)

    zero = np.array(0, dtype=added_value.dtype)
    is_negative = create_op_with_const_inputs(graph, Less, {1: zero},
                                              {'name': producer_name + '/Less'})
    correction = create_op_with_const_inputs(graph, Mul, {1: added_value},
                                             {'name': producer_name + '/Mul'})

    # The original connection now ends at the comparison
    node.in_port(port_idx).get_connection().set_destination(is_negative.in_port(0))
    is_negative.out_port(0).connect(correction.in_port(0))

    # The corrected values are the original ones plus the correction
    corrected = Add(graph, {}).create_node()
    correction.out_port(0).connect(corrected.in_port(1))
    values_source.connect(corrected.in_port(0))
    corrected.out_port(0).connect(node.in_port(port_idx))
def calculate_prior_box_value(value: Node, value_to_div: Port, value_to_add: Port):
    """
    Builds the Sub and Add nodes described by the formulas:
        min = value[value_to_add] - (value[value_to_div] / 2)
        max = value[value_to_add] + (value[value_to_div] / 2)

    The divisor constant uses the data type requested on the command line.

    :param value: Node with value. Here is supposed the node with op='Split'
    :param value_to_div: Output port with values to be divided by 2
    :param value_to_add: Output port with values to be added to values from value_to_div port
    :return: Sub and Add nodes
    """
    graph = value.graph
    dtype = data_type_str_to_np(graph.graph['cmd_params'].data_type)

    lower = Sub(graph, dict(name=value.name + '/Sub')).create_node()

    two = np.array([2], dtype=dtype)
    half = create_op_node_with_second_input(graph, Div, two,
                                            op_attrs=dict(name=value.name + '/Div'))
    half.in_port(0).connect(value_to_div)

    lower.in_port(0).connect(value_to_add)
    lower.in_port(1).connect(half.out_port(0))

    upper = Add(graph, dict(name=value.name + '/Add')).create_node()
    upper.in_port(0).connect(half.out_port(0))
    upper.in_port(1).connect(value_to_add)

    return lower, upper
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with an MVN -> Mul -> Add chain with explicit MVN axes.

    The axes are produced by Range(2, Rank(input), 1). Input 1 of ``node``
    becomes the second operand of the Mul and input 2 the second operand of
    the Add, which also takes over the name of ``node``.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the Add node
    """
    name = node.soft_get('name', node.id)

    # create range of axes for MVN based on `start_axis` and rank of input
    rank = Rank(graph, {'name': name + '/Rank'}).create_node()
    axes = create_op_with_const_inputs(graph, Range,
                                       {0: int64_array(2), 2: int64_array(1)},
                                       {'name': name + '/Range', 'output_type': np.int64})

    mvn_attrs = {
        'eps': node.epsilon,
        'eps_mode': 'inside_sqrt',
        'normalize_variance': 1,
        'name': name + '/Ins_Norm/MVN_',
    }
    normalize = MVN(graph, mvn_attrs).create_node()
    node.in_port(0).get_connection().set_destination(normalize.in_port(0))
    axes.out_port(0).connect(normalize.in_port(1))

    scale = Mul(graph, {'axis': 1, 'name': name + '/Ins_Norm/mul_'}).create_node()
    normalize.out_port(0).connect(scale.in_port(0))
    node.in_port(1).get_connection().set_destination(scale.in_port(1))

    shift = Add(graph, {'axis': 1, 'name': name + '/Ins_Norm/add_'}).create_node()
    scale.out_port(0).connect(shift.in_port(0))
    node.in_port(2).get_connection().set_destination(shift.in_port(1))

    # The rank is taken from the same producer that feeds the MVN
    normalize.in_port(0).get_connection().add_destination(rank.in_port(0))
    axes.in_port(1).connect(rank.out_port(0))

    rename_nodes([(node, name + '/TBD'), (shift, name)])
    return [shift.id]
def sub_to_add_replacement(sub: Node):
    """Replace ``sub`` with Add(A, Mul(B, -1)); the Add keeps the name of ``sub``.

    The node is left untouched when both of its inputs have a known value.

    :param sub: node to replace
    """
    # we execute this transformation for V10 IR later on middle phase despite graph_condition
    # so we prevent Sub replacement on shape-calculating sub-graphs
    first_value = sub.in_port(0).data.get_value()
    if first_value is not None and sub.in_port(1).data.get_value() is not None:
        return

    graph = sub.graph
    name = sub.soft_get('name', sub.id)

    # keep Add name the same as Sub -- because of mathematical equality of output tensors
    rename_node(node=sub, name=name + '/to_be_removed')

    # reconnect Sub in(out)puts to Add
    add = Add(graph, {'name': name}).create_node()
    rename_node(add, name)

    for port_idx in (0, 1):
        sub.in_port(port_idx).get_connection().set_destination(add.in_port(port_idx))
    sub.out_port(0).get_connection().set_source(add.out_port(0))

    # restore mathematical equivalence to Sub operation: Sub(A, B) = Add(A, Mul(B, -1))
    const_dtype = sub.soft_get('data_type', np.float32)
    minus_one = np.array(-1, dtype=const_dtype)
    negate = create_op_with_const_inputs(graph, Mul, {1: minus_one}, {'name': name + '/neg_'})
    add.in_port(1).get_connection().insert_node(negate)
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched ``fbn`` node with an MVN -> Mul -> Add chain.

    Nothing is changed unless the first input of ``fbn`` is also the first
    input of both the matched ``mean`` and ``sqdiff`` nodes. The MVN gets
    ``required_reduction_indices`` [1, 2] for the b'NHWC' data format and
    [2, 3] otherwise.

    :param graph: graph to modify
    :param match: matched nodes; reads the keys 'fbn', 'mean', 'sqdiff' and 'variance'
    """
    fbn = match['fbn']
    data = fbn.in_node(0)
    log.debug('Found potential MVN pattern after {} with name {}'.format(data.op, data.name))

    shares_input = (data.id == match['mean'].in_node(0).id
                    and data.id == match['sqdiff'].in_node(0).id)
    if not shares_input:
        return

    log.debug('Confirmed MVN pattern after {} with name {}'.format(data.op, data.name))

    MVN = Op.get_op_class_by_name('MVN')

    if fbn.data_format == b'NHWC':
        reduction_indices = [1, 2]
    else:
        reduction_indices = [2, 3]

    mvn = MVN(graph, dict(
        name=fbn.name + '/MVN_',
        eps=fbn.eps,
        required_reduction_indices=reduction_indices
    ))
    # Keep the stock infer function reachable and install the one of the enclosing class
    mvn.attrs['old_infer'] = mvn.attrs['infer']
    mvn.attrs['infer'] = __class__.infer

    mul = Mul(graph, dict(operation='mul', name=fbn.name + '/Mul_'))
    add = Add(graph, dict(operation='sum', name=fbn.name + '/Add_'))

    gamma = fbn.in_node(1)
    beta = fbn.in_node(2)
    mean_axes = match['mean'].in_node(1)
    variance_axes = match['variance'].in_node(1)

    normalized = mvn.create_node([data, mean_axes, variance_axes])
    scaled = mul.create_node([normalized, gamma])
    shifted = add.create_node([scaled, beta])

    fbn.replace_node(shifted)
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with MVN over axis -1, then Mul by weight and Add of bias.

    ``eps``, ``weight`` and ``bias`` are read from ``node.module``; weight and
    bias are embedded into the graph as Const nodes.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the Add node
    """
    module = node.module

    last_axis = Const(graph, {'value': int64_array([-1])}).create_node()
    mvn_op = MVN(graph, dict(name=node.name + '/mvn',
                             eps=module.eps,
                             normalize_variance=True,
                             eps_mode='inside_sqrt'))
    normalized = mvn_op.create_node([node.in_node(0), last_axis])

    weight_value = module.weight.detach().numpy()
    bias_value = module.bias.detach().numpy()
    weight = Const(graph, {'value': weight_value}).create_node()
    bias = Const(graph, {'value': bias_value}).create_node()

    scaled = Mul(graph, dict(name=node.name + '/mul')).create_node([normalized, weight])
    shifted = Add(graph, dict(name=node.name + '/add')).create_node([scaled, bias])
    return [shifted.id]
def replace_sub_graph(self, graph: Graph, match: dict):
    """Replace the matched op (inputs: data, begin, size) with a Slice (data, start, end).

    The end is computed as ``Select(size == -1, int32_max, Cast(begin + size, int64))``.
    The new Slice takes over the name of the matched op.

    :param graph: graph to modify
    :param match: matched nodes; reads the key 'op'
    """
    tf_slice = match['op']
    base_name = tf_slice.soft_get('name', tf_slice.id)

    new_slice = Slice(graph).create_node()
    rename_nodes([(tf_slice, base_name + '/to_be_removed'), (new_slice, base_name)])

    is_minus_one = Equal(graph, {'name': base_name + '/equal'}).create_node()
    minus_one = Const(graph, {'name': base_name + '/minus_one',
                              'value': np.array(-1)}).create_node()
    int32_max = Const(graph, {'name': base_name + '/int32_max',
                              'value': np.iinfo(np.int32).max}).create_node()
    end_select = Select(graph, {'name': base_name + '/select'}).create_node()

    # begin + size, i.e. the sizes converted to ends
    ends_sum = Add(graph, {'name': base_name + '/end_const'}).create_node()

    # data input: matched op -> Slice
    tf_slice.in_port(0).get_source().connect(new_slice.in_port(0))
    tf_slice.in_port(0).disconnect()

    # begin input of the matched op becomes the start of the Slice
    tf_slice.in_port(1).get_source().connect(new_slice.in_port(1))
    tf_slice.in_port(1).disconnect()

    # both the begins (already attached to the Slice) and the sizes feed the sum
    new_slice.in_port(1).get_source().connect(ends_sum.in_port(0))
    tf_slice.in_port(2).get_source().connect(ends_sum.in_port(1))
    tf_slice.in_port(2).disconnect()

    # size[i] == -1 selects int32_max as end[i]
    ends_sum.in_port(1).get_source().connect(is_minus_one.in_port(0))
    minus_one.out_port(0).connect(is_minus_one.in_port(1))

    # Select inputs: 0 - condition, 1 - int32_max, 2 - computed ends
    is_minus_one.out_port(0).connect(end_select.in_port(0))
    int32_max.out_port(0).connect(end_select.in_port(1))
    ends_sum.out_port(0).connect(end_select.in_port(2))

    # the selected value is the end input (port 2) of the Slice
    end_select.out_port(0).connect(new_slice.in_port(2))

    to_int64 = Cast(graph, dict(name=ends_sum.name + '/CastToI64',
                                dst_type=np.int64)).create_node()
    end_select.in_port(2).get_connection().insert_node(to_int64)

    tf_slice.out_port(0).get_connection().set_source(new_slice.out_port(0))
def replace_op(self, graph: Graph, node: Node):
    """Replace ``node`` with Mul and Add using constants folded from ``node.module``.

    With ``w = weight / sqrt(running_var + eps)`` and ``b = bias - w * running_mean``
    the result is ``input * w + b``. Both constants are reshaped so that every
    dimension is 1 except dimension 1, which is inferred (-1); the number of
    dimensions is ``node.module.dims``.

    :param graph: graph the new nodes are created in
    :param node: node being replaced
    :return: single-element list with the id of the Add node
    """
    module = node.module
    running_mean = module.running_mean.detach().numpy()
    running_var = module.running_var.detach().numpy()
    gamma = module.weight.detach().numpy()
    beta = module.bias.detach().numpy()

    scale = gamma / np.sqrt(running_var + module.eps)
    shift = beta - scale * running_mean

    const_shape = np.ones(module.dims, dtype=np.int32)
    const_shape[1] = -1  # channels

    scale_const = Const(graph, {'value': scale.reshape(const_shape)}).create_node()
    shift_const = Const(graph, {'value': shift.reshape(const_shape)}).create_node()

    scaled = Mul(graph, dict(name=node.name + '/mul')).create_node([node.in_node(0), scale_const])
    shifted = Add(graph, dict(name=node.name + '/add')).create_node([scaled, shift_const])
    return [shifted.id]
def find_and_replace_pattern(self, graph: Graph):
    """Turn every 'Gemm' op of the graph into a MatMul with explicit alpha / bias / beta.

    For each node:
      * an 'alpha' not close to 1 becomes a Mul inserted after the node output;
      * a connected input 2 is moved into an Add that takes over the node name,
        with a Mul on the bias input when 'beta' is not close to 1;
      * the node attributes are updated through ``MatMul.update_node_stat``.

    :param graph: graph to modify
    """
    for gemm in graph.get_op_nodes(op='Gemm'):
        name = gemm.soft_get('name', gemm.id)
        tail_port = gemm.out_port(0)

        scale_output = gemm.has_valid('alpha') and not math.isclose(gemm.alpha, 1)
        if scale_output:
            alpha_mul = create_op_with_const_inputs(
                graph, Mul, {1: np.array(gemm.alpha)},
                {'name': name + '/Alpha', 'can_be_scaleshift': False})
            tail_port.get_connection().insert_node(alpha_mul)
            tail_port = alpha_mul.out_port(0)
            del gemm['alpha']

        if gemm.is_in_port_connected(2):
            # biases normalization
            bias_add = Add(graph, {'name': name + '/Bias_',
                                   'can_be_scaleshift': False}).create_node()
            renamed_gemm = name + '/WithoutBiases'
            rename_nodes([(gemm, renamed_gemm), (bias_add, name)])

            tail_port.get_connection().set_source(bias_add.out_port(0))
            gemm.in_port(2).get_connection().set_destination(bias_add.in_port(1))
            tail_port.connect(bias_add.in_port(0))

            scale_bias = gemm.has_valid('beta') and not math.isclose(gemm.beta, 1)
            if scale_bias:
                bias_add.insert_op_on_input_port(
                    in_port_idx=1,
                    new_op_class=Mul,
                    value=np.array(gemm.beta),
                    new_op_attrs={'name': name + '/Beta', 'can_be_scaleshift': False})
                del gemm['beta']

        transpose_attrs = {
            'transpose_a': gemm.has_and_set('transpose_a'),
            'transpose_b': gemm.has_and_set('transpose_b'),
        }
        MatMul.update_node_stat(gemm, transpose_attrs)