def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    quantize = match['quantize']
    preop = match['preop']

    # Skip the fusion if one of the Add inputs is produced by Convolution/Deconvolution/MatMul
    for i in [0, 1]:
        if preop.in_port(i).get_source().node.soft_get('type') in ['Convolution', 'Deconvolution', 'MatMul']:
            return

    tensor_port, value_port = get_tensor_in_port(preop), get_value_in_port(preop)
    if value_port is None or value_port.data.get_value() is None:
        log.debug('AddQuantizeFuse: cannot fuse because Add op has dynamic inputs')
        return

    # Direct modifications are performed on the values feeding FakeQuantize input ports 1 and 2
    # (input_low and input_high). The data nodes at those ports must not have more than one consumer
    # (or at most two, if both are ports of this same FakeQuantize), so we duplicate the
    # FakeQuantize in_port 1, 2 data if needed.
    resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[1, 2])

    quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() - value_port.data.get_value())
    if quantize.in_node(1).id != quantize.in_node(2).id:
        quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() - value_port.data.get_value())

    # Reconnect FakeQuantize directly to the Add input tensor, bypassing the Add
    in_add_connection = quantize.in_port(0).get_source().node.in_port(0).get_connection()
    quantize.in_port(0).disconnect()
    in_add_connection.add_destination(quantize.in_port(0))
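# A minimal numeric sketch (not part of the transformation above) of the identity this pass relies on:
# FakeQuantize(x + b, input_low, input_high, ...) == FakeQuantize(x, input_low - b, input_high - b, ...).
# The fake_quantize() helper below is a simplified reference of FakeQuantize semantics written only for
# this check; its name and exact formula are assumptions, not the production implementation.
import numpy as np

def fake_quantize(x, in_low, in_high, out_low, out_high, levels=256):
    x = np.clip(x, in_low, in_high)
    q = np.round((x - in_low) / (in_high - in_low) * (levels - 1))
    return q / (levels - 1) * (out_high - out_low) + out_low

x = np.linspace(-3, 3, 7)
b = 0.5  # constant value_port of the Add
lhs = fake_quantize(x + b, in_low=-1.0, in_high=1.0, out_low=0.0, out_high=255.0)
rhs = fake_quantize(x, in_low=-1.0 - b, in_high=1.0 - b, out_low=0.0, out_high=255.0)
assert np.allclose(lhs, rhs)  # the Add is absorbed into the FakeQuantize input range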
def replace_pattern(graph: Graph, match: dict):
    node = match['fc']
    name = node.soft_get('name', node.id)
    add = match['add']

    # Bias input is already connected -- nothing to fuse
    if 2 in node.in_ports() and not node.in_port(2).disconnected():
        return

    out_size = node.soft_get('out-size', None)
    assert out_size is not None, \
        "FullyConnected should have `out-size` parameter, but it doesn't for node {}".format(name)

    tensor_port, value_port = get_tensor_in_port(add), get_value_in_port(add)
    if value_port is None:
        return

    # The shift constant must be broadcastable to a [1, out-size] bias
    shift_shape = value_port.data.get_shape()
    if not any([np.array_equal(int64_array(suitable_shape), shift_shape)
                for suitable_shape in [[1, out_size], [1, 1], [out_size], [1], []]]):
        return

    broadcasted_value = np.broadcast_to(value_port.data.get_value(), [1, out_size])
    const = Const(graph, {'name': name + '/Bias_', 'value': broadcasted_value}).create_node()

    node.add_input_port(2, skip_if_exist=True)
    const.out_port(0).connect(node.in_port(2))

    # Reconnect the Add consumers to the FullyConnected output, bypassing the Add
    add.out_port(0).get_connection().set_source(tensor_port.get_source())
    node.in_port(2).bin = 'biases'
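# Illustrative sketch (standalone, not the pass itself): an Add constant of any of the accepted shapes
# after a FullyConnected of out-size = 4 equals feeding a [1, out-size] bias into port 2.
# The FullyConnected is modelled here as x @ W.T with W of shape [out-size, in-size]; that weight
# layout is an assumption of this sketch only.
import numpy as np

out_size, in_size = 4, 3
x = np.random.rand(2, in_size)
W = np.random.rand(out_size, in_size)
shift = np.random.rand(out_size)              # one of the "suitable" shift shapes
bias = np.broadcast_to(shift, [1, out_size])  # what the pass connects to port 2
assert np.allclose(x @ W.T + shift, x @ W.T + bias)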
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    quantize = match['quantize']
    preop = match['preop']

    tensor_port, value_port = get_tensor_in_port(preop), get_value_in_port(preop)
    if value_port is None or value_port.data.get_value() is None:
        log.debug('MulQuantizeFuse: cannot fuse because Mul op has dynamic inputs')
        return

    mul_val = value_port.data.get_value()
    # This simplified variant handles only strictly positive multipliers
    if np.any(mul_val <= 0):
        return

    # Direct modifications are performed on the values feeding FakeQuantize input ports 1 and 2
    # (input_low and input_high). The data nodes at those ports must not have more than one consumer
    # (or at most two, if both are ports of this same FakeQuantize), so we duplicate the
    # FakeQuantize in_port 1, 2 data if needed.
    resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[1, 2])

    # TODO: need some special processing for values that are exactly equal to the threshold
    quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() / mul_val)
    if quantize.in_node(1).id != quantize.in_node(2).id:
        quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() / mul_val)

    # Reconnect Mul as it is no longer needed for the current FakeQuantize
    in_mul_connection = quantize.in_port(0).get_source().node.in_port(0).get_connection()
    quantize.in_port(0).disconnect()
    in_mul_connection.add_destination(quantize.in_port(0))
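# Companion sketch (standalone) for the positive-multiplier case handled above:
# for s > 0, FakeQuantize(s * x, il, ih, ...) == FakeQuantize(x, il / s, ih / s, ...).
# As before, fake_quantize() is a simplified reference formula assumed for this check only.
import numpy as np

def fake_quantize(x, in_low, in_high, out_low, out_high, levels=256):
    x = np.clip(x, in_low, in_high)
    q = np.round((x - in_low) / (in_high - in_low) * (levels - 1))
    return q / (levels - 1) * (out_high - out_low) + out_low

x = np.linspace(-3, 3, 13)
s = 0.5  # strictly positive Mul constant
lhs = fake_quantize(s * x, in_low=-1.0, in_high=1.0, out_low=0.0, out_high=255.0)
rhs = fake_quantize(x, in_low=-1.0 / s, in_high=1.0 / s, out_low=0.0, out_high=255.0)
assert np.allclose(lhs, rhs)  # the Mul is absorbed by dividing input_low/input_high by s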
def replace_pattern(graph: Graph, match: Dict[str, Node]):
    op = match['op']
    op_type = op.type

    const_port, tensor_port = get_value_in_port(op), get_tensor_in_port(op)
    if const_port is None or tensor_port is None:
        return

    value = const_port.data.get_value()
    assert value is not None
    if value.size != 1:
        return
    value = value.item(0)

    assert op_type in EltwisesWithScalarInputToPower.eltw_types
    if op_type == 'Add':
        delete_node = value == 0
        Power.update_node_stat(op, {'shift': value})
    elif op_type == 'Multiply':
        delete_node = value == 1
        Power.update_node_stat(op, {'scale': value})
    elif op_type == 'Pow':
        delete_node = value == 1
        Power.update_node_stat(op, {'power': value})

    const_port.disconnect()
    if tensor_port.idx != 0:
        tensor_port.get_connection().set_destination(op.in_port(0))
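# Quick numeric sketch (illustrative, not the pass itself): the rewrite above assumes the Power
# operation computes y = (scale * x + shift) ** power with defaults scale=1, shift=0, power=1,
# so a scalar second input folds into exactly one of those attributes. The helper name and the
# assumed Power formula are mine, not taken from the code above.
import numpy as np

def power_op(x, scale=1.0, shift=0.0, power=1.0):
    return (scale * x + shift) ** power

x = np.array([1.0, 2.0, 3.0])
assert np.allclose(x + 0.5, power_op(x, shift=0.5))   # Add with a scalar      -> shift
assert np.allclose(x * 2.0, power_op(x, scale=2.0))   # Multiply with a scalar -> scale
assert np.allclose(x ** 3.0, power_op(x, power=3.0))  # Pow with a scalar      -> power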
def mark_eltwise_node(self, node, feature_channel=None):
    tensor_port, value_port = get_tensor_in_port(node), get_value_in_port(node)
    if tensor_port is None or value_port is None:
        self.set_flags_to_false(node, ['can_be_fused', 'can_be_scaleshift'])
        return

    connected_in_ports = {idx: port for idx, port in node.in_ports().items() if not port.disconnected()}
    if len(connected_in_ports) != 2:
        return

    tensor_shape = tensor_port.data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert tensor_shape is not None and out_shape is not None
    if not np.array_equal(tensor_shape, out_shape):
        # ScaleShift operation doesn't support broadcasting
        self.set_flags_to_false(node, ['can_be_fused', 'can_be_scaleshift'])
        return

    value_shape = value_port.data.get_shape()
    assert value_shape is not None
    assert len(value_shape) <= len(tensor_shape), \
        "No broadcasting was done for elementwise node {} due to previous checks in EltwiseChecker class. " \
        "But constant input rank is larger than tensor input rank, that is inconsistent".format(node.name)

    # if both tensors are 0D they cannot be converted to scaleshift
    if len(tensor_shape) == 0 and len(value_shape) == 0:
        self.set_flags_to_false(node, ['can_be_scaleshift'])
        return

    broadcasted_value_shape = np.insert(value_shape, 0, [1] * (len(tensor_shape) - len(value_shape)))

    feature_dim = min(1, tensor_shape.size - 1) if node.graph.graph['layout'] == 'NCHW' else -1
    if feature_channel is not None:
        feature_dim = feature_channel
    ones = np.ones(len(tensor_shape))
    possible_shape = ones.copy()
    np.put(possible_shape, feature_dim, tensor_shape.item(feature_dim))

    if not np.array_equal(broadcasted_value_shape, ones) and \
            not np.array_equal(broadcasted_value_shape, possible_shape):
        # ScaleShift weights should have [1,C,1,1]-like or [1,1,1,1]-like shape
        self.set_flags_to_false(node, ['can_be_fused', 'can_be_scaleshift'])
        return

    if len(tensor_shape) not in [2, 4, 5]:
        # ScaleShift operation is supported for 2D, 4D or 5D tensor inputs
        self.set_flags_to_false(node, ['can_be_scaleshift'])
        return
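# Standalone sketch of the shape test above: for an NCHW tensor the constant input is accepted only
# if it broadcasts to a per-tensor ([1, 1, 1, 1]-like) or per-channel ([1, C, 1, 1]-like) shape.
# The helper name is mine; the check mirrors the numpy logic of mark_eltwise_node.
import numpy as np

def is_scaleshift_compatible(tensor_shape, value_shape, feature_dim=1):
    tensor_shape = np.array(tensor_shape, dtype=np.int64)
    value_shape = np.array(value_shape, dtype=np.int64)
    broadcasted = np.insert(value_shape, 0, [1] * (len(tensor_shape) - len(value_shape)))
    ones = np.ones(len(tensor_shape))
    per_channel = ones.copy()
    np.put(per_channel, feature_dim, tensor_shape.item(feature_dim))
    return np.array_equal(broadcasted, ones) or np.array_equal(broadcasted, per_channel)

assert is_scaleshift_compatible([1, 16, 32, 32], [1, 16, 1, 1])       # per-channel constant: accepted
assert is_scaleshift_compatible([1, 16, 32, 32], [1])                 # per-tensor constant: accepted
assert not is_scaleshift_compatible([1, 16, 32, 32], [1, 1, 32, 32])  # spatial constant: rejected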
def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
    quantize = match['quantize']
    preop = match['preop']

    tensor_port, value_port = get_tensor_in_port(preop), get_value_in_port(preop)
    if value_port is None or value_port.data.get_value() is None:
        log.debug('MulQuantizeFuse: cannot fuse because Mul op has dynamic inputs')
        return

    mul_val = value_port.data.get_value()

    # Direct modifications are performed on the values feeding FakeQuantize input ports 1 and 2
    # (input_low and input_high). The data nodes at those ports must not have more than one consumer
    # (or at most two, if both are ports of this same FakeQuantize), so we duplicate the
    # FakeQuantize in_port 1, 2 data if needed.
    resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[1, 2])

    # TODO: need some special processing for values that are exactly equal to the threshold

    # Need to flip output_low and output_high for those elements that have multiplier < 0
    if np.all(mul_val < 0):
        mi_o_node = quantize.in_port(3).get_source()
        ma_o_node = quantize.in_port(4).get_source()

        quantize.in_port(3).disconnect()
        quantize.in_port(4).disconnect()

        mi_o_node.connect(quantize.in_port(4))
        ma_o_node.connect(quantize.in_port(3))
    elif np.any(mul_val < 0):
        # Flipping values should be done on exclusive inputs of the FakeQuantize node, so we duplicate them if needed
        resolve_shared_inputs(node=quantize, port_ids_to_duplicate=[3, 4])

        # Flipping is done on broadcasted arrays
        mi_o_val = quantize.in_port(3).data.get_value()
        ma_o_val = quantize.in_port(4).data.get_value()
        mul_val, mi_o_val, ma_o_val = [np.array(a) for a in np.broadcast_arrays(mul_val, mi_o_val, ma_o_val)]

        neg_idx = np.where(mul_val < 0)
        mi_o_val[neg_idx], ma_o_val[neg_idx] = ma_o_val[neg_idx], mi_o_val[neg_idx]

        # TODO: revert broadcasting where unnecessary
        quantize.in_port(3).data.set_value(mi_o_val)
        quantize.in_port(4).data.set_value(ma_o_val)

    quantize.in_port(1).data.set_value(quantize.in_port(1).data.get_value() / mul_val)
    if quantize.in_node(1).id != quantize.in_node(2).id:
        quantize.in_port(2).data.set_value(quantize.in_port(2).data.get_value() / mul_val)

    # Reconnect Mul as it is no longer needed for the current FakeQuantize
    in_mul_connection = quantize.in_port(0).get_source().node.in_port(0).get_connection()
    quantize.in_port(0).disconnect()
    in_mul_connection.add_destination(quantize.in_port(0))
def convert_add_or_mul_to_scaleshift(graph: Graph):
    if graph.graph['cmd_params'].generate_experimental_IR_V10:
        return
    graph.strict_mode = False
    for node in graph.get_op_nodes():
        if node.soft_get('op') in ['Add', 'Mul'] and len(node.in_ports()) == 2:

            tensor_port, value_port = get_tensor_in_port(node), get_value_in_port(node)

            if tensor_port is not None and not tensor_port.disconnected() and value_port is not None \
                    and node.soft_get('can_be_scaleshift') is not False:
                original_value = value_port.data.get_value()
                if original_value.size == 1:
                    continue

                # Remove 1-dims from the value array (it should be 1D); shapes are updated accordingly
                value_port.data.set_value(np.squeeze(original_value))

                # Create ScaleShift operation
                scsh_op = ScaleShiftOp(graph, dict(name='ScaleShift/{}'.format(node.name))).create_node()

                if node.op == 'Mul':
                    # Create fake biases for the ScaleShift node
                    const_op = Const(graph, dict(name='{}/biases'.format(scsh_op.name),
                                                 value=np.zeros(value_port.data.get_shape(), dtype=np.float32),
                                                 shape=np.array(value_port.data.get_shape()),
                                                 )).create_node()

                    # Reconnect input and weights to the ScaleShift node
                    tensor_port.get_connection().set_destination(scsh_op.in_port(0))
                    value_port.get_connection().set_destination(scsh_op.in_port(1))
                    const_op.out_port(0).connect(scsh_op.in_port(2))
                else:
                    # Create fake weights for the ScaleShift node
                    const_op = Const(graph, dict(name='{}/weights'.format(scsh_op.name),
                                                 value=np.ones(value_port.data.get_shape(), dtype=np.float32),
                                                 shape=np.array(value_port.data.get_shape()),
                                                 )).create_node()

                    # Reconnect input and biases to the ScaleShift node
                    tensor_port.get_connection().set_destination(scsh_op.in_port(0))
                    const_op.out_port(0).connect(scsh_op.in_port(1))
                    value_port.get_connection().set_destination(scsh_op.in_port(2))

                node.out_port(0).get_connection().set_source(scsh_op.out_port(0))

                # Set bin attribute on the ScaleShift input ports
                scsh_op.in_port(1).bin = 'weights'
                scsh_op.in_port(2).bin = 'biases'

    graph.strict_mode = True
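# Illustrative sketch (standalone): ScaleShift computes y = x * weights + biases with the constants
# broadcast over the channel dimension, so a Mul node maps to (weights=value, biases=0) and an Add
# node to (weights=1, biases=value), matching the two branches above. Channel-wise broadcasting over
# NCHW and the scale_shift() helper are assumptions of this sketch.
import numpy as np

def scale_shift(x, weights, biases):
    # x: [N, C, H, W]; weights/biases: per-channel vectors of length C
    return x * weights.reshape(1, -1, 1, 1) + biases.reshape(1, -1, 1, 1)

x = np.random.rand(1, 3, 2, 2)
value = np.array([0.5, 2.0, -1.0])
assert np.allclose(x * value.reshape(1, 3, 1, 1),
                   scale_shift(x, weights=value, biases=np.zeros(3)))   # Mul branch
assert np.allclose(x + value.reshape(1, 3, 1, 1),
                   scale_shift(x, weights=np.ones(3), biases=value))    # Add branch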
def find_and_replace_pattern(self, graph: Graph):
    for node in graph.get_op_nodes(is_eltwise=True):
        tensor_port, value_port = get_tensor_in_port(node), get_value_in_port(node)
        if tensor_port is None or value_port is None:
            self.set_flags_to_false(node, ['can_be_fused', 'can_be_scaleshift'])
            continue

        tensor_shape = tensor_port.data.get_shape()
        out_shape = node.out_port(0).data.get_shape()
        assert tensor_shape is not None and out_shape is not None
        if not np.array_equal(tensor_shape, out_shape):
            # ScaleShift operation doesn't support broadcasting
            self.set_flags_to_false(node, ['can_be_fused', 'can_be_scaleshift'])
            continue

        value_shape = value_port.data.get_shape()
        assert value_shape is not None
        assert len(value_shape) <= len(tensor_shape), \
            "No broadcasting was done for elementwise node {} due to previous checks in EltwiseChecker class. " \
            "But constant input rank is larger than tensor input rank, that is inconsistent".format(node.name)

        broadcasted_value_shape = np.insert(value_shape, 0, [1] * (len(tensor_shape) - len(value_shape)))

        feature_dim = min(1, tensor_shape.size - 1) if node.graph.graph['layout'] == 'NCHW' else -1
        ones = np.ones(len(tensor_shape))
        possible_shape = ones.copy()
        np.put(possible_shape, feature_dim, tensor_shape.item(feature_dim))

        if not np.array_equal(broadcasted_value_shape, ones) and \
                not np.array_equal(broadcasted_value_shape, possible_shape):
            # ScaleShift weights should have [1,C,1,1]-like or [1,1,1,1]-like shape
            self.set_flags_to_false(node, ['can_be_fused', 'can_be_scaleshift'])
            continue

        if len(tensor_shape) not in [2, 4, 5]:
            # ScaleShift operation is supported for 2D, 4D or 5D tensor inputs
            self.set_flags_to_false(node, ['can_be_scaleshift'])
            continue
def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Fuses the given Mul node into an array of Convolution/FC nodes.

    :param graph: graph to operate on
    :param node: Mul node to be fused
    :param fuse_nodes: Convolution/FC nodes the Mul is fused into
    :param backward: if False, the Convolution/FC goes after the Mul node;
                     otherwise the Mul goes after the Convolution/FC
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field output_channel_dim and/or '
                        'input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
                node.soft_get('name', node.id), fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
        w_mul = node.copy_node({'name': mul_name,
                                'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        fuse_node_in_data = fuse_node.in_node(weights_port.idx)
        w_const_out_data = w_const.node.out_node(w_const.idx)

        # During this reconnection the new data node name is copied from the data node
        # outgoing from the w_const port. Duplicate names of data nodes lead to appearing
        # of duplicate op node names after constant folding. So we should manually
        # set a unique name for the new data node.
        if fuse_node_in_data.soft_get('name') == w_const_out_data.soft_get('name') and \
                fuse_node_in_data.soft_get('name', None) is not None:
            fuse_node.in_node(weights_port.idx)['name'] = graph.unique_id(mul_name)

        # If we fuse in the backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # as the Mul node is added before the convolution, the output tensor from the Convolution node
        # corresponds to the original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused
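# Standalone sketch of the reshape performed above for backward fusion into a Convolution with
# weights of shape [O, I, kH, kW] (output_channel_dim = 0 and dims_number = 4 are assumptions of
# this illustration): the per-output-channel scale is reshaped to [O, 1, 1, 1] so that it multiplies
# every weight of a given output channel, which by linearity scales that channel of the convolution output.
import numpy as np

weights = np.random.rand(8, 3, 3, 3)      # [O, I, kH, kW]
scale = np.random.rand(8)                 # per-output-channel multiplier from the Mul const
ch_dim, dims_number = 0, weights.ndim
shape = np.array([weights.shape[ch_dim]])
for _ in range(dims_number - ch_dim - 1):  # expand [8] -> [8, 1, 1, 1]
    shape = np.append(shape, 1)
scaled_weights = weights * np.reshape(scale, shape)
# e.g. the whole kernel of output channel 5 is scaled uniformly by scale[5]
assert np.allclose(scaled_weights[5], weights[5] * scale[5])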
def _fuse_linear_sequence(graph: Graph, start_node: Node):
    """
    This function finds a sequence of Mul/Add operations and replaces the whole sequence
    with two ops (Mul->Add).
    :param graph: graph to operate on
    :param start_node: the first operation of the sequence
    """
    fnodes = [start_node]
    while True:
        node = fnodes[-1]
        destinations = node.out_port(0).get_destinations()
        if len(destinations) != 1:
            break
        dst_node = destinations[0].node
        if dst_node.soft_get('op') in ['Mul', 'Add'] and get_value_in_port(dst_node) is not None and \
                dst_node.soft_get('can_be_fused') is True:
            fnodes.append(dst_node)
        else:
            break

    # Nothing to fold: a single node or an already canonical Mul->Add pair
    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul' and fnodes[1].op == 'Add'):
        return False

    input_shape = get_tensor_in_port(start_node).data.get_shape()

    init_dims_cnt = len(input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    mul = np.ones([1 for x in range(init_dims_cnt)])
    add = np.zeros([1 for x in range(init_dims_cnt)])

    first_mul_name = None
    first_add_name = None

    # Accumulate the whole chain into a single multiplier and a single addend
    for node in fnodes:
        const_port_value = get_value_in_port(node).data.get_value()
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            mul = mul * const_port_value
            add = add * const_port_value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + const_port_value

    # If mul is a scalar we broadcast it to the biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.array([mul[0] for x in range(add.shape[0])])

    assert (np.array_equal(get_tensor_in_port(fnodes[0]).data.get_shape(), fnodes[-1].out_port(0).data.get_shape()))

    mul_op = Mul(graph, dict(name='{}/Fused_Mul_'.format(first_mul_name or '')))
    add_op = Add(graph, dict(name='{}/Fused_Add_'.format(first_add_name or '')))

    in_port = get_tensor_in_port(fnodes[0])
    out_port = fnodes[-1].out_port(0)

    """
    Four cases are considered below:
        1. Mul and Add have valid values (mul value != 1 and add value != 0)
        2. Only Mul has valid values, so we add only a Mul node
        3. Only Add has valid values, so we add only an Add node
        4. Neither Mul nor Add has valid values, so we just merge the two data nodes
    """
    if any([x != 0 for x in np.nditer(add)]) and any([x != 1 for x in np.nditer(mul)]):
        #  Const\    Const\
        #  ----->Mul------>Add-->
        mul_const = Const(graph, dict(name="data_mul_", value=np.array(mul))).create_node()
        add_const = Const(graph, dict(name="data_add_", value=np.array(add))).create_node()

        mul_node = mul_op.create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))

        mul_node.out_port(0).connect(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    elif any([x != 1 for x in np.nditer(mul)]):
        #  Const\
        #  ----->Mul-->
        mul_const = Const(graph, dict(name="data_mul_", value=np.array(mul))).create_node()
        mul_node = mul_op.create_node()

        in_port.get_connection().set_destination(mul_node.in_port(0))
        mul_const.out_port(0).connect(mul_node.in_port(1))
        out_port.get_connection().set_source(mul_node.out_port(0))
    elif any([x != 0 for x in np.nditer(add)]):
        #  Const\
        #  ----->Add-->
        add_const = Const(graph, dict(name="data_add_", value=np.array(add))).create_node()
        add_node = add_op.create_node()

        in_port.get_connection().set_destination(add_node.in_port(0))
        add_const.out_port(0).connect(add_node.in_port(1))
        out_port.get_connection().set_source(add_node.out_port(0))
    else:
        source_node = in_port.get_source()
        in_port.disconnect()
        out_port.get_connection().set_source(source_node)

    # Remove fused nodes
    for node in fnodes:
        graph.remove_node(node.id)

    log.debug('Fused {} operations'.format(len(fnodes)))
    return True
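# Standalone numeric sketch of what the sequence collapse above computes: a chain of elementwise
# Mul/Add nodes with constant second inputs folds into a single y = x * mul + add, with the
# accumulators updated exactly as in the loop over fnodes (mul *= c and add *= c for a Mul,
# add += c for an Add). The chain below is an arbitrary example.
import numpy as np

x = np.random.rand(4)
chain = [('Mul', 2.0), ('Add', 3.0), ('Mul', -0.5), ('Add', 1.0)]

mul, add = np.ones(1), np.zeros(1)
y = x.copy()
for op, c in chain:
    if op == 'Mul':
        y = y * c
        mul, add = mul * c, add * c
    else:
        y = y + c
        add = add + c
assert np.allclose(y, x * mul + add)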
def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Fuses the given Mul node into an array of Convolution/FC nodes.

    :param graph: graph to operate on
    :param node: Mul node to be fused
    :param fuse_nodes: Convolution/FC nodes the Mul is fused into
    :param backward: if False, the Convolution/FC goes after the Mul node;
                     otherwise the Mul goes after the Convolution/FC
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field output_channel_dim and/or '
                        'input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        value = np.reshape(value, shape)

        # Weights multiplication
        mul_const = Const(graph, {'value': value}).create_node()
        w_mul = node.copy_node({'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        # If we fuse in the backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused
def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Fuses the given Mul node into an array of Convolution/FC nodes.

    :param graph: graph to operate on
    :param node: Mul node to be fused
    :param fuse_nodes: Convolution/FC nodes the Mul is fused into
    :param backward: if False, the Convolution/FC goes after the Mul node;
                     otherwise the Mul goes after the Convolution/FC
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field output_channel_dim and/or '
                        'input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
                node.soft_get('name', node.id), fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
        w_mul = node.copy_node({'name': mul_name,
                                'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))

        r"""
        In this transformation we remove the Mul or Div node (node) that goes after fuse_node and
        create a new Mul node (w_mul), connect it with the corrected const value (mul_const) and
        insert w_mul before the fuse_node. So the input data of fuse_node becomes different.
        For this reason we need to use set_destination from the previous operation to w_mul, which
        guarantees that the data node will be reused on the previous_op -> w_mul connection and its
        attributes won't be copied to the data node of the w_mul -> fuse_node connection.

        BEFORE                      AFTER

                                previous_op   mul_const
                                        \     /
        previous_op                      w_mul
            |                              |
        fuse_node   const              fuse_node
             \     /                       |
              node                      next_op
                |
            next_op
        """
        weights_port.get_connection().set_destination(w_mul.in_port(tensor_port.idx))
        w_mul.out_port(0).connect(weights_port)

        # As fusing is applied to convolutions it is important to keep the 'permutation' and
        # 'input_permutation' attributes which were obtained from the original model. These attributes
        # are stored on the incoming edge to the operation node and during the reconnection they are
        # moved to the new connection. But during reconnection in this transformation these attributes
        # are moved to the previous node. So we need to manually set them at the incoming edge to fuse_node.
        in_edge = w_mul.in_edge(tensor_port.idx)
        if 'permutation' in in_edge:
            fuse_node.in_edge(weights_port.idx)['permutation'] = in_edge['permutation']
        if 'input_permutation' in in_edge:
            fuse_node.in_edge(weights_port.idx)['input_permutation'] = in_edge['input_permutation']

        # If we fuse in the backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # as the Mul node is added before the convolution, the output tensor from the Convolution node
        # corresponds to the original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused
def _fuse_add(graph: Graph, node: Node, fuse_nodes: List[Node], backward: bool = True):
    """
    Fuses the given Add node into the Convolution/FC nodes and then deletes the Add node.
    If the Convolution/FC has no bias input, it is created.
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_add for node {} because this node has wrong inputs'.format(node.id))
        return False

    # if len(node.in_node(const_id).shape) > 2 or any([x == 0 for x in node.in_node(const_id).shape]):
    #     log.warning('Cannot do fuse_add for node {} because this node has wrong shape'.format(node.id))
    #     return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'.format(
                fuse_node.name))
            return False
        if not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False
        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        # If forward, broadcast value
        if not backward:
            cnt = weights_port.data.get_shape()[-1] / const_port.data.get_shape()[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        value = np.squeeze(value)

        # Create bias data node if it does not exist
        if len(fuse_node.in_ports()) <= 2:
            fuse_node.add_input_port(idx=2)

        if fuse_node.in_port(2).disconnected() or fuse_node.in_port(2).data.get_value() is None:
            # Broadcast if scalar
            if value.size == 1:
                id = weights_port.data.get_attr('output_channel_dim') if backward \
                    else weights_port.data.get_attr('input_channel_dim')
                vshape = weights_port.data.get_shape()[id]
                value = np.full(vshape, value.item())

            if not backward:
                value = np.dot(weights_port.data.get_value(), value)

            const_bias_node = Const(graph, dict(name="bias_data", value=np.array(value))).create_node()

            fuse_node.in_port(2).connect(const_bias_node.out_port(0))
            fuse_node.in_port(2).bin = 'biases'
            const_bias_node.infer(const_bias_node)

            fuse_node['bias_term'] = True
        else:
            bias_value = fuse_node.in_port(2).data.get_value()
            if not backward:
                fuse_node.in_port(2).data.set_value(
                    bias_value + np.dot(fuse_node.in_port(1).data.get_value(), value))
            else:
                fuse_node.in_port(2).data.set_value(bias_value + value)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Add node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused
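# Standalone sketch of the bias arithmetic above. For backward fusion (Add after the FC) the Add
# constant lands directly in the bias; for forward fusion (Add before the FC) it must first be pushed
# through the weights, which is the np.dot(weights, value) term. The FC is modelled here as
# y = x @ W.T + b with W of shape [out, in]; that weight layout is an assumption of this sketch only.
import numpy as np

out_size, in_size = 4, 3
x = np.random.rand(2, in_size)
W = np.random.rand(out_size, in_size)
b = np.random.rand(out_size)

# backward: FC(x) + v_out  ==  FC with bias (b + v_out)
v_out = np.random.rand(out_size)
assert np.allclose((x @ W.T + b) + v_out, x @ W.T + (b + v_out))

# forward: FC(x + v_in)  ==  FC with bias (b + W @ v_in)
v_in = np.random.rand(in_size)
assert np.allclose((x + v_in) @ W.T + b, x @ W.T + (b + np.dot(W, v_in)))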