def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
    """
    Duplicates shared constants that are consumed by more than one node.
    If a constant is consumed by several ports of one node, no duplication is done for that node.
    """
    graph = node.graph

    for port_id in port_ids_to_duplicate:
        dst_port_map = defaultdict(list)
        for dst in node.in_port(port_id).get_source().get_connection().get_destinations():
            dst_port_map[dst.node].append(dst.idx)
        del dst_port_map[node]

        value = node.in_port(port_id).data.get_value()
        if value is None:
            log.debug('Can not duplicate due to missing data for in_port {} of node {}'.format(port_id, node.name))
            continue

        # Use a dedicated loop variable so the `node` argument is not shadowed
        # before the next iteration over port_ids_to_duplicate.
        for consumer, idxs in dst_port_map.items():
            const = Const(graph, {'value': mo_array(value),
                                  'name': consumer.soft_get('name', consumer.id) + '/duplicated_'}).create_node()
            for idx in idxs:
                consumer.in_port(idx).disconnect()
                const.out_port(0).connect(consumer.in_port(idx))
            const.infer(const)

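# Usage sketch (illustrative only, not part of the original sources; the node name and
# helper function are hypothetical). It shows the typical motivation for
# resolve_shared_inputs: before mutating a weights constant in place, make it private
# to the target node so the change does not leak to other consumers of the same Const.
def _example_make_weights_private(graph: Graph, conv_name: str, new_weights: np.ndarray):
    conv = Node(graph, conv_name)
    # Give every other consumer of the weights Const its own duplicated copy.
    resolve_shared_inputs(node=conv, port_ids_to_duplicate=[1])
    # After the call only `conv` consumes the original constant, so the update is local.
    conv.in_port(1).data.set_value(new_weights)
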
def create_op_node_with_second_input(graph: Graph, op: callable, second_input_value: np.array,
                                     op_attrs=None, input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()

    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))

    second_input_node = Const(graph, {'name': node.name + '/value', 'value': second_input_value}).create_node()
    second_input_node.out_port(0).connect(node.in_port(1))
    if graph.stage != 'front':
        second_input_node.infer(second_input_node)
    return node

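# Usage sketch (illustrative only, not part of the original sources; assumes the Reshape
# op class and int64_array are importable from the Model Optimizer packages). It shows
# the common pattern of creating an operation whose second input is a constant value.
def _example_add_flatten_reshape(graph: Graph, producer: Node) -> Node:
    reshape = create_op_node_with_second_input(
        graph, Reshape, int64_array([1, -1]),
        {'name': producer.soft_get('name', producer.id) + '/flatten'},
        input_node=producer)
    return reshape
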
def create_fake_quantize_node(graph: Graph, name, data_type=np.float32):
    fq = FakeQuantize(graph, {'name': name, 'levels': 0, 'stop_value_propagation': True}).create_node()

    # Constants feeding the input_low/input_high/output_low/output_high limit ports.
    input_low = Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node()
    input_high = Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node()
    output_low = Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node()
    output_high = Const(graph, {'value': np.array(0.0, dtype=data_type)}).create_node()

    input_low.out_port(0).connect(fq.in_port(1))
    input_high.out_port(0).connect(fq.in_port(2))
    output_low.out_port(0).connect(fq.in_port(3))
    output_high.out_port(0).connect(fq.in_port(4))

    input_low.infer(input_low)
    input_high.infer(input_high)
    output_low.infer(output_low)
    output_high.infer(output_high)

    return fq

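# Usage sketch (illustrative only, not part of the original sources; node names are
# hypothetical). The freshly created FakeQuantize has levels == 0 and zero-valued limit
# constants, so a caller is expected to connect the data input and fill the real
# quantization parameters afterwards.
def _example_insert_fq_after(graph: Graph, producer: Node) -> Node:
    fq = create_fake_quantize_node(graph, producer.soft_get('name', producer.id) + '/fq')
    producer.out_port(0).connect(fq.in_port(0))
    fq.levels = 256  # e.g. 8-bit quantization
    return fq
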
def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dict[int, np.array],
                                op_attrs=None, input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()

    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))

    for idx, value in port_value_dict.items():
        node.add_input_port(idx, skip_if_exist=True)
        value_input_node = Const(graph, {'name': node.name + '_input_port_' + str(idx) + '/value',
                                         'value': value}).create_node()
        value_input_node.out_port(0).connect(node.in_port(idx))
        if graph.stage != 'front':
            value_input_node.infer(value_input_node)
    return node

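# Usage sketch (illustrative only, not part of the original sources; assumes the Gather
# op class and int64_array are importable from the Model Optimizer packages). It builds
# a Gather that selects index 1 along axis 0, with both the indices (port 1) and the
# axis (port 2) supplied as constant inputs.
def _example_gather_channel_dim(graph: Graph, shape_node: Node) -> Node:
    gather = create_op_with_const_inputs(
        graph, Gather,
        {1: int64_array([1]), 2: int64_array(0)},
        {'name': shape_node.soft_get('name', shape_node.id) + '/gather'},
        input_node=shape_node)
    return gather
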
def test_v10_group_convolution_resolver_for_dynamic_weights(self):
    num_groups = 2
    C_OUT = 8

    nodes = {
        **regular_op_with_shaped_data('input', shape_array([1, dynamic_dimension_value, 224, 224]),
                                      {'type': 'Parameter'}),
        **valued_const_with_data('weights', np.ones([num_groups, C_OUT, 7, 7])),
        **regular_op_with_empty_data('reshape', {'type': 'Reshape'}),
        **regular_op_with_empty_data('ss', {'type': 'StridedSlice', 'begin_mask': [1], 'end_mask': [0],
                                            'new_axis_mask': [0], 'shrink_axis_mask': [0], 'ellipsis_mask': [0]}),
        **regular_op_with_empty_data('weights_shape', {'type': 'ShapeOf'}),
        **regular_op_with_empty_data('input_shape', {'type': 'ShapeOf'}),
        **regular_op_with_empty_data('gather', {'type': 'Gather'}),
        **regular_op_with_empty_data('concat', {'type': 'Concat'}),
        **regular_op_with_empty_data('div', {'type': 'Divide'}),
        **valued_const_with_data('channels_const', int64_array([num_groups, C_OUT / num_groups])),
        **valued_const_with_data('num_groups', int64_array(num_groups)),
        **valued_const_with_data('begin', int64_array([2])),
        **valued_const_with_data('end', int64_array([-1])),
        **valued_const_with_data('channel_index', int64_array([1])),
        **valued_const_with_data('axis', int64_array(0)),
        **regular_op_with_shaped_data('convolution', None,
                                      {'type': 'Convolution', 'group': num_groups, 'output': C_OUT}),
        **result(),
    }

    graph = build_graph(nodes, [
        *connect('input', '0:convolution'),
        *connect('weights', '1:convolution'),
        *connect('convolution', 'output'),
    ], nodes_with_edges_only=True)

    V10ConvolutionWithGroupsResolver().find_and_replace_pattern(graph)

    nodes['convolution']['type'] = 'GroupConvolution'
    del nodes['convolution']['group']

    graph_ref = build_graph(nodes, [
        *connect('input', '0:convolution'),
        *connect('weights', '0:reshape'),
        ('input_d', 'input_shape', {'in': 0, 'out': 0}),
        ('weights_d', 'weights_shape', {'in': 0, 'out': 0}),
        *connect('input_shape', '0:gather'),
        *connect('channel_index', '1:gather'),
        *connect('axis', '2:gather'),
        *connect('weights_shape', '0:ss'),
        *connect('begin', '1:ss'),
        *connect('end', '2:ss'),
        *connect('gather', '0:div'),
        *connect('num_groups', '1:div'),
        *connect('channels_const', '0:concat'),
        *connect('div', '1:concat'),
        *connect('ss', '2:concat'),
        *connect('concat', '1:reshape'),
        *connect('reshape', '1:convolution'),
        *connect('convolution', 'output'),
    ], nodes_with_edges_only=True)

    Const.infer(Node(graph, 'convolution/GroupsAndOutputChannelsSize'))
    Const.infer(Node(graph, 'convolution/Div_input_port_1/value'))

    (flag, resp) = compare_graphs(graph, graph_ref, last_node='output')
    self.assertTrue(flag, resp)

def _fuse_mul(graph: Graph, node: Node, fuse_nodes: list, backward: bool = True):
    """
    Takes a Mul node and a list of Convolution/FC nodes and fuses the multiplication
    into their weights (and, for backward fusion, into their biases).

    :param graph: graph to operate on
    :param node: Mul node to fuse
    :param fuse_nodes: Convolution/FC nodes to fuse the Mul into
    :param backward: if False, the Convolution/FC goes after the Mul node;
                     otherwise the Mul goes after the Convolution/FC
    :return: True if at least one node was fused, False otherwise
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning('Cannot do fuse_mul for node {} because this node has wrong inputs'.format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning('Node {} can\'t be used in fusing because attr can_be_fused = False'.format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning('Cannot do fuse_mul for node {} because there is no field '
                        'output_channel_dim and/or input_channel_dim in weights.'
                        .format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = mo_array(const_port.data.get_value())
        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according to the output/input channel dimension
        ch_dim = weights_port.data.get_attr('output_channel_dim' if backward else 'input_channel_dim')
        shape = mo_array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item(), dtype=value.dtype)

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = mo_array([], dtype=value.dtype)
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = mo_array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (e.g. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = mo_array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error("Cannot fuse const from {} to {}. Reshape failed. Skipping.".format(
                node.soft_get('name', node.id), fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {'value': value, 'name': mul_name + '/const'}).create_node()
        w_mul = node.copy_node({'name': mul_name,
                                'in_ports_count': len(node.in_ports()),
                                'out_ports_count': len(node.out_ports()),
                                'can_be_fused': False})
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        fuse_node_in_data = fuse_node.in_node(weights_port.idx)
        w_const_out_data = w_const.node.out_node(w_const.idx)

        # During this reconnection the new data node name is copied from the data node
        # outgoing from the w_const port. Duplicate names of data nodes lead to duplicate
        # op node names after constant folding, so we manually set a unique name for the
        # new data node.
        if fuse_node_in_data.soft_get('name') == w_const_out_data.soft_get('name') and \
                fuse_node_in_data.soft_get('name', None) is not None:
            fuse_node.in_node(weights_port.idx)['name'] = graph.unique_id(mul_name)

        # If we fuse in the backward direction we should also multiply biases if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() * np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # As the Mul node is added before the convolution, the output tensor from the
        # Convolution node corresponds to the original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused

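# Usage sketch (illustrative only, not part of the original sources; node names are
# hypothetical). With backward=True the Mul follows the Convolution and its constant is
# folded into the convolution weights (and biases); with backward=False the Mul is
# expected to precede the Convolution/FC instead.
def _example_fuse_scale_into_conv(graph: Graph):
    mul = Node(graph, 'scale_mul')   # Mul with one constant input and one tensor input
    conv = Node(graph, 'conv')       # Convolution producing the Mul's tensor input
    if _fuse_mul(graph, mul, [conv], backward=True):
        log.debug('scale_mul was fused into conv weights')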