def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
    """
    Duplicates shared constants that are consumed by more than one node.
    If a constant is consumed by several ports of one node, no duplication is done.
    """
    graph = node.graph

    for port_id in port_ids_to_duplicate:
        dst_port_map = defaultdict(list)
        for dst in node.in_port(port_id).get_source().get_connection().get_destinations():
            dst_port_map[dst.node].append(dst.idx)
        del dst_port_map[node]

        value = node.in_port(port_id).data.get_value()
        if value is None:
            log.debug('Cannot duplicate: no data for in_port {} of node {}'.format(port_id, node.name))
            continue

        for dst_node, idxs in dst_port_map.items():
            const = Const(graph, {'value': np.array(value),
                                  'name': dst_node.soft_get('name', dst_node.id) + '/duplicated_'}).create_node()
            for idx in idxs:
                dst_node.in_port(idx).disconnect()
                const.out_port(0).connect(dst_node.in_port(idx))
            const.infer(const)

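# A minimal usage sketch for resolve_shared_inputs (illustrative only): `pad_node` is a hypothetical
# node whose constant inputs on ports 1 and 2 may be shared with other consumers and therefore need
# private copies before being modified in place.
def _resolve_shared_inputs_example(pad_node):
    resolve_shared_inputs(node=pad_node, port_ids_to_duplicate=[1, 2])
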
def find_and_replace_pattern(self, graph: Graph):
    if graph.graph['layout'] != 'NHWC':
        # we check it here because this transformation is called explicitly from the pipeline
        return

    # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        input_shape = reinterp_shape_node.in_node(0).shape
        if not is_input_data_in_correct_layout(reinterp_shape_node, 0) and len(input_shape) >= 4:
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm}).create_node()
            permute_node = Transpose(graph, {
                'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'}).create_node()
            reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))
            order_const.infer(order_const)

            # do not infer the Transpose node because it should have the input data node in NCHW layout (but
            # currently it is NHWC because the data node attributes have not been permuted yet) and produce the
            # output in NHWC layout (which is true at this moment)
            permute_node['need_shape_inference'] = False
            # mark the Transpose output data node as having the correct layout so its shape will not be permuted
            mark_output_as_in_correct_layout(permute_node, 0)

            # keep the reinterp_shape_node in NHWC layout
            mark_input_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

    # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
    for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
        reinterp_shape_node = Node(graph, reinterp_shape_node_id)
        assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
            reinterp_shape_node_id, graph.dump_graph_for_graphviz())
        output_shape = reinterp_shape_node.out_node(0).shape
        if not is_output_data_in_correct_layout(reinterp_shape_node, 0) and len(output_shape) >= 4:
            order_const = Const(graph, {
                'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
            permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
            reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
            order_const.out_port(0).connect(permute_node.in_port(1))

            # the Reshape and Transpose operations should work in the original (NHWC) layout, so the Transpose
            # will convert it to NCHW
            mark_input_as_in_correct_layout(permute_node, 0)
            mark_input_as_in_correct_layout(permute_node, 1)

            # do not set the Transpose output data node 'correct_data_layout' attribute so the data node shape
            # will be permuted

            # keep the reinterp_shape_node in NHWC layout
            mark_output_as_in_correct_layout(reinterp_shape_node, 0)
            mark_input_as_in_correct_layout(reinterp_shape_node, 1)

            # do not re-infer the Transpose node because its output data node should be in NHWC layout to keep the
            # rest of the graph consistent
            permute_node['need_shape_inference'] = False

def replace_pattern(self, graph: Graph, match: dict):
    conv = match['conv']
    weights = match['weights']
    input_shape = conv.in_port(0).data.get_shape()
    new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) /
                                     (input_shape[1] / conv.group),
                                     input_shape[1] / conv.group,
                                     *weights.value.shape[2:]])
    new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape)}).create_node()
    weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
    new_weights.infer(new_weights)

def create_op_node_with_second_input(graph: Graph, op: callable, second_input_value: np.array,
                                     op_attrs=None, input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()

    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))

    second_input_node = Const(graph, {'name': node.name + '/value', 'value': second_input_value}).create_node()
    second_input_node.out_port(0).connect(node.in_port(1))
    if graph.stage != 'front':
        second_input_node.infer(second_input_node)
    return node

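# A minimal usage sketch for create_op_node_with_second_input (illustrative only): Reshape and
# int64_array are the operation class and helper used elsewhere in these transformations, while
# `producer` is a hypothetical existing node whose output should be flattened to 2D.
def _create_op_node_with_second_input_example(graph, producer):
    reshape = create_op_node_with_second_input(graph, Reshape, int64_array([0, -1]),
                                               {'name': producer.name + '/flatten'}, producer)
    # `reshape` now consumes `producer` on port 0 and a Const with value [0, -1] on port 1
    return reshape
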
def create_fake_quantize_node(graph: Graph, name):
    fq = FakeQuantize(graph, {'name': name, 'levels': 0, 'stop_value_propagation': True}).create_node()

    input_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
    input_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
    output_low = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()
    output_height = Const(graph, {'value': np.array(0.0).astype(np.float32)}).create_node()

    input_low.out_port(0).connect(fq.in_port(1))
    input_height.out_port(0).connect(fq.in_port(2))
    output_low.out_port(0).connect(fq.in_port(3))
    output_height.out_port(0).connect(fq.in_port(4))

    input_low.infer(input_low)
    input_height.infer(input_height)
    output_low.infer(output_low)
    output_height.infer(output_height)

    return fq

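# A minimal usage sketch for create_fake_quantize_node (illustrative only): the helper only wires the
# four range constants into ports 1-4, so the caller still has to connect the tensor to be quantized
# to port 0; `producer` is a hypothetical node feeding the FakeQuantize.
def _create_fake_quantize_node_example(graph, producer):
    fq = create_fake_quantize_node(graph, producer.name + '/fq')
    producer.out_port(0).connect(fq.in_port(0))
    return fq
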
def replace_pattern(graph: Graph, match: dict):
    node = match['op']
    input_shape = node.in_port(0).data.get_shape()

    if len(input_shape) > 2:
        new_shape = Const(graph, {'value': np.array([0, -1], dtype=np.int64)}).create_node()
        reshape = Reshape(graph, {}).create_node()

        source = node.in_port(0).get_source()
        node.in_port(0).get_connection().set_source(reshape.out_port(0))
        source.connect(reshape.in_port(0))
        new_shape.out_port(0).connect(reshape.in_port(1))

        new_shape.infer(new_shape)
        reshape.infer(reshape)

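# For reference, the [0, -1] pattern above follows the Reshape convention where 0 copies the
# corresponding input dimension, so an input of shape [N, C, H, W] becomes [N, C*H*W]. A minimal
# numpy illustration of the same flattening (not using the MO Reshape op):
def _flatten_to_2d_example():
    import numpy as np
    x = np.zeros((2, 3, 4, 5))
    flat = x.reshape(x.shape[0], -1)  # shape (2, 60), mirrors Reshape with pattern [0, -1]
    return flat.shape
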
def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dict[int, np.array],
                                op_attrs=None, input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()

    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))

    for idx, value in port_value_dict.items():
        node.add_input_port(idx, skip_if_exist=True)
        value_input_node = Const(graph, {'name': node.name + '_input_port_' + str(idx) + '/value',
                                         'value': value}).create_node()
        value_input_node.out_port(0).connect(node.in_port(idx))
        if graph.stage != 'front':
            value_input_node.infer(value_input_node)
    return node

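# A minimal usage sketch for create_op_with_const_inputs (illustrative only): it generalizes
# create_op_node_with_second_input to any number of constant ports. Mirroring the FakeQuantize helper
# above, the four range constants go to ports 1-4 while port 0 stays free for the data tensor;
# `producer` is a hypothetical node to be quantized and the zero ranges are placeholder values.
def _create_op_with_const_inputs_example(graph, producer):
    zero = np.array(0.0, dtype=np.float32)
    fq = create_op_with_const_inputs(graph, FakeQuantize, {1: zero, 2: zero, 3: zero, 4: zero},
                                     {'name': producer.name + '/fq', 'levels': 0,
                                      'stop_value_propagation': True}, producer)
    return fq
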
def replace_pattern(self, graph: Graph, match: dict):
    conv = match['conv']

    assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
    out_data = conv.out_node()

    assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
    out_shape = out_data.shape

    if out_shape.size != 3:
        return

    assert len(conv.in_nodes()) >= 1, "Convolution operation {} should have at least 1 input data node".format(
        conv.id)
    inp_data = conv.in_node()

    assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)
    inp_shape = inp_data.shape
    new_inp_shape = np.insert(inp_shape, 2, 1)

    # setting to None to be overwritten by the infer function
    conv.kernel_spatial_idx = None
    conv.spatial_dims = None

    # inserting fake H dimension
    conv.dilation = np.insert(conv.dilation, 2, 1)
    conv.kernel_spatial = np.append([1], conv.kernel_spatial)
    conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
    conv.stride = np.insert(conv.stride, 2, 1)

    weights_node = conv.in_node(1)
    weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
    weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

    reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
    reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
    conv.in_port(0).get_connection().insert_node(reshape)
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
    reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape.id + '/Dim'}).create_node()
    conv.out_port(0).get_connection().insert_node(reshape_back)
    reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

    # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
    reshape_dim.infer(reshape_dim)
    reshape.infer(reshape)
    conv.infer(conv)

def replace_pattern(self, graph: Graph, match: dict):
    node = match['pad']

    pb = node.pads[:, 0]
    pe = node.pads[:, 1]
    pm = node.mode

    pads_begin = Const(graph, {'value': np.array(pb)}).create_node()
    node.add_input_port(1, skip_if_exist=True)
    node.in_port(1).connect(pads_begin.out_port(0))
    pads_begin.infer(pads_begin)

    pads_end = Const(graph, {'value': np.array(pe)}).create_node()
    node.add_input_port(2, skip_if_exist=True)
    node.in_port(2).connect(pads_end.out_port(0))
    pads_end.infer(pads_end)

    del node['pads']

    if node.has_valid('fill_value') and pm == 'constant':
        pv = node.fill_value
        pad_value = Const(graph, {'value': np.array(pv)}).create_node()
        node.add_input_port(3, skip_if_exist=True)
        node.in_port(3).connect(pad_value.out_port(0))
        pad_value.infer(pad_value)
        del node['fill_value']

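# For reference, a minimal numpy illustration of how the `pads` attribute above splits into the two
# new constant inputs, assuming `pads` is stored as an [N, 2] array of per-axis [begin, end] pairs:
def _split_pads_example():
    import numpy as np
    pads = np.array([[0, 0], [1, 2], [3, 4]])          # per-axis [begin, end]
    pads_begin, pads_end = pads[:, 0], pads[:, 1]
    return pads_begin.tolist(), pads_end.tolist()       # ([0, 1, 3], [0, 2, 4])
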
def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Takes a Mul node and a list of Convolution/FullyConnected nodes and fuses the multiplication
    into their weights (and, for backward fusion, biases).

    If backward is False, the Convolution/FC goes after the Mul node; otherwise the Mul goes
    after the Convolution/FC.

    :param graph: graph to operate on
    :param node: Mul node to fuse
    :param fuse_nodes: Convolution/FC nodes to fuse the Mul into
    :param backward: fusion direction
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False
        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False
        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field output_channel_dim and/or '
                        'input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())
        value = np.squeeze(value)

        # TODO: ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
                node.soft_get('name', node.id), fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
        w_mul = node.copy_node({'name': mul_name,
                                'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        fuse_node_in_data = fuse_node.in_node(weights_port.idx)
        w_const_out_data = w_const.node.out_node(w_const.idx)

        # During this reconnection the new data node name is copied from the data node outgoing from the w_const
        # port. Duplicate names of data nodes lead to duplicate op node names after constant folding, so we
        # manually set a unique name for the new data node.
        if fuse_node_in_data.soft_get('name') == w_const_out_data.soft_get('name') and \
                fuse_node_in_data.soft_get('name', None) is not None:
            fuse_node.in_node(weights_port.idx)['name'] = graph.unique_id(mul_name)

        # If we fuse in backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # as the Mul node is added before the convolution, the output tensor of the Convolution node
        # corresponds to the original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused

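# For reference, a minimal numpy check of the backward fusion idea above (illustrative only, using a
# plain linear layer instead of the MO Convolution/FC nodes): scaling the output of a linear layer by
# a per-output-channel factor is equivalent to scaling its weights and bias by the same factor.
def _fuse_mul_backward_math_example():
    import numpy as np
    rng = np.random.default_rng(0)
    x = rng.normal(size=3)
    W, b = rng.normal(size=(4, 3)), rng.normal(size=4)
    s = rng.normal(size=4)                       # per-output-channel multiplier
    before = (W @ x + b) * s                     # Mul applied after the layer
    after = (W * s[:, None]) @ x + b * s         # Mul fused into weights and bias
    return np.allclose(before, after)            # True
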
def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Takes a Mul node and a list of Convolution/FullyConnected nodes and fuses the multiplication
    into their weights (and, for backward fusion, biases).

    If backward is False, the Convolution/FC goes after the Mul node; otherwise the Mul goes
    after the Convolution/FC.

    :param graph: graph to operate on
    :param node: Mul node to fuse
    :param fuse_nodes: Convolution/FC nodes to fuse the Mul into
    :param backward: fusion direction
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False
        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False
        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field output_channel_dim and/or '
                        'input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())
        value = np.squeeze(value)

        # TODO: ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        value = np.reshape(value, shape)

        # Weights multiplication
        mul_const = Const(graph, {'value': value}).create_node()
        w_mul = node.copy_node({'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        # If we fuse in backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused

def replace_pattern(graph: Graph, match: dict):
    node = match['matmul']
    name = node.soft_get('name', node.id)

    A_shape = node.in_port(0).data.get_shape()
    B_shape = node.in_port(1).data.get_shape()
    out_shape = node.out_port(0).data.get_shape()
    assert A_shape is not None and B_shape is not None and out_shape is not None

    B_value = node.in_port(1).data.get_value()
    if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) and \
            B_shape[B_shape != 1].size <= 2:
        # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
        # to FullyConnected representation: [I, K] * [O, K] = [I, O]
        B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

        # weights normalization
        if not node.transpose_b:
            # FullyConnected weights layout is OI
            # MatMul second input layout is (B)IO
            transpose_order = list(range(B_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(weights_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if node.in_port(1).data.get_shape().size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

            weights_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(weights_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
            "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

        node.in_port(1).bin = 'weights'
        del node['transpose_b']

        # input normalization
        if node.transpose_a:
            transpose_order = list(range(A_shape.size))
            transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

            order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
            transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(transpose.out_port(0))
            transpose.in_port(0).connect(input_source)
            transpose.in_port(1).connect(order.out_port(0))

            order.infer(order)
            transpose.infer(transpose)

        if A_shape.size != 2:
            const = Const(graph, {'value': int64_array([-1, K])}).create_node()
            reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

            input_source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            reshape.in_port(0).connect(input_source)
            reshape.in_port(1).connect(const.out_port(0))

            const.infer(const)
            reshape.infer(reshape)

        assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
            "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
            "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

        del node['transpose_a']

        FullyConnected.update_node_stat(node, {'out-size': O})

        # output normalization
        if out_shape.size != 2:
            const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
            reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

            dst = node.out_port(0).get_destination()
            node.out_port(0).get_connection().set_destination(reshape.in_port(0))
            const.out_port(0).connect(reshape.in_port(1))
            reshape.out_port(0).connect(dst)

            node.infer(node)
            const.infer(const)
            reshape.infer(reshape)
    else:
        assert A_shape.size == out_shape.size
        assert B_shape.size <= out_shape.size

        if B_shape.size != out_shape.size:
            unsqueeze_dim = Const(graph, {
                'value': int64_array(list(range(out_shape.size - B_shape.size)))}).create_node()
            unsqueeze = Unsqueeze(graph, {}).create_node()

            B_source = node.in_port(1).get_source()
            node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
            unsqueeze.in_port(0).connect(B_source)
            unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

            unsqueeze_dim.infer(unsqueeze_dim)
            unsqueeze.infer(unsqueeze)

        Gemm.update_node_stat(node, {
            'transpose_a': node.has_and_set('transpose_a'),
            'transpose_b': node.has_and_set('transpose_b'),
        })

def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Takes a Mul node and a list of Convolution/FullyConnected nodes and fuses the multiplication
    into their weights (and, for backward fusion, biases).

    If backward is False, the Convolution/FC goes after the Mul node; otherwise the Mul goes
    after the Convolution/FC.

    :param graph: graph to operate on
    :param node: Mul node to fuse
    :param fuse_nodes: Convolution/FC nodes to fuse the Mul into
    :param backward: fusion direction
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False
        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False
        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field output_channel_dim and/or '
                        'input_channel_dim in weights.'.format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())
        value = np.squeeze(value)

        # TODO: ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
                node.soft_get('name', node.id), fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
        w_mul = node.copy_node({'name': mul_name,
                                'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))

        r""" In this transformation we remove the Mul or Div node (node) that goes after fuse_node and create a
        new Mul node (w_mul), connect it with the corrected const value (mul_const) and insert w_mul before the
        fuse_node. So the input data of fuse_node becomes different. For this reason we need to use
        set_destination from the previous operation to w_mul, which guarantees that the data node will be reused
        on the previous_op -> w_mul connection and its attributes won't be copied to the data node of the
        w_mul -> fuse_node connection.

        BEFORE                            AFTER

                                      previous_op      mul_const
                                               \      /
        previous_op                             w_mul
             |                                    |
        fuse_node   const                     fuse_node
              \      /                            |
               node                            next_op
                |
             next_op
        """
        weights_port.get_connection().set_destination(w_mul.in_port(tensor_port.idx))
        w_mul.out_port(0).connect(weights_port)

        # As fusing is applied to convolutions, it is important to keep the 'permutation' and 'input_permutation'
        # attributes which were obtained from the original model. These attributes are stored on the incoming edge
        # to the operation node, and during reconnection they are moved to the new connection. But during
        # reconnection in this transformation they are moved to the previous node, so we need to manually set them
        # on the incoming edge to fuse_node.
        in_edge = w_mul.in_edge(tensor_port.idx)
        if 'permutation' in in_edge:
            fuse_node.in_edge(weights_port.idx)['permutation'] = in_edge['permutation']
        if 'input_permutation' in in_edge:
            fuse_node.in_edge(weights_port.idx)['input_permutation'] = in_edge['input_permutation']

        # If we fuse in backward direction we should multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # as the Mul node is added before the convolution, the output tensor of the Convolution node
        # corresponds to the original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused

def _fuse_add(graph: Graph, node: Node, fuse_nodes: List[Node], backward: bool = True):
    """
    Takes an Add node and Convolution/FC nodes, fuses the addition into their biases, and then
    deletes the Add node. If the Convolution/FC has no bias, one is created.
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_add for node {} because this node has wrong inputs'.format(node.id))
        return False

    # if len(node.in_node(const_id).shape) > 2 or any([x == 0 for x in node.in_node(const_id).shape]):
    #     log.warning('Cannot do fuse_add for node {} because this node has wrong shape'.format(node.id))
    #     return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'.format(
                fuse_node.name))
            return False
        if not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False
        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        # If forward, broadcast value
        if not backward:
            cnt = weights_port.data.get_shape()[-1] / const_port.data.get_shape()[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        value = np.squeeze(value)

        # Create BIAS data node if it does not exist
        if len(fuse_node.in_ports()) <= 2:
            fuse_node.add_input_port(idx=2)

        if fuse_node.in_port(2).disconnected() or fuse_node.in_port(2).data.get_value() is None:
            # Broadcast if scalar
            if value.size == 1:
                id = weights_port.data.get_attr('output_channel_dim') if backward else \
                    weights_port.data.get_attr('input_channel_dim')
                vshape = weights_port.data.get_shape()[id]
                value = np.full(vshape, value.item())

            if not backward:
                value = np.dot(weights_port.data.get_value(), value)

            const_bias_node = Const(graph, dict(name="bias_data", value=np.array(value))).create_node()

            fuse_node.in_port(2).connect(const_bias_node.out_port(0))
            fuse_node.in_port(2).bin = 'biases'
            const_bias_node.infer(const_bias_node)

            fuse_node['bias_term'] = True
        else:
            bias_value = fuse_node.in_port(2).data.get_value()
            if not backward:
                fuse_node.in_port(2).data.set_value(
                    bias_value + np.dot(fuse_node.in_port(1).data.get_value(), value))
            else:
                fuse_node.in_port(2).data.set_value(bias_value + value)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Add node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused

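# For reference, a minimal numpy check of the forward bias update above (illustrative only, using a
# plain linear layer instead of the MO Convolution/FC nodes): fusing Add(value) followed by
# FC(W, b) into FC(W, b + W @ value) leaves the output unchanged, by distributivity.
def _fuse_add_forward_math_example():
    import numpy as np
    rng = np.random.default_rng(0)
    x, value = rng.normal(size=3), rng.normal(size=3)
    W, b = rng.normal(size=(4, 3)), rng.normal(size=4)
    before = W @ (x + value) + b          # Add applied before the layer
    after = W @ x + (b + W @ value)       # Add fused into the bias
    return np.allclose(before, after)     # True
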
def replace_pattern(graph: Graph, match: dict):
    """
    Works around Tile shapes not supported by Inference Engine (Tile is supported only for 2D or 4D
    tensors): searches for Tiles with 3D shapes and wraps them with Reshapes.

    Example: Tile (axis=1, tiles=16):
        in_shape: [1,1,101]
        out_shape: [1,16,101]

    Old behaviour:
        Tile -> [1,16,101]
    New behaviour:
        Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
    """
    tile = match['tile']

    assert len(tile.out_nodes()) == 1, "Tile operation {} should have 1 output data node".format(tile.id)
    out_data = tile.out_node()

    assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(tile.id)
    out_shape = out_data.shape

    if out_shape.size != 3:
        return

    assert len(tile.in_nodes()) == 1, "Tile operation {} should have 1 input data node".format(tile.id)
    inp_data = tile.in_node()

    assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(tile.id)
    inp_shape = inp_data.shape
    new_inp_shape = np.append(inp_shape, [1])

    reshape = Reshape(graph, {'name': tile.name + '/reshape'}).create_node()
    reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
    tile.in_port(0).get_connection().insert_node(reshape)
    reshape.in_port(1).connect(reshape_dim.out_port(0))

    reshape_back = Reshape(graph, {'name': tile.name + '/reshape_back'}).create_node()
    reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape.id + '/Dim'}).create_node()
    tile.out_port(0).get_connection().insert_node(reshape_back)
    reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

    # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
    reshape_dim.infer(reshape_dim)
    reshape.infer(reshape)
    tile.infer(tile)