def resolve_shared_inputs(node: Node, port_ids_to_duplicate: List[int]):
    """
    Duplicates shared constants that are consumed by more than one node. 
    If constant is consumed by several ports of one node - no duplication gets done
    """
    graph = node.graph

    for port_id in port_ids_to_duplicate:
        dst_port_map = defaultdict(list)
        for dst in node.in_port(
                port_id).get_source().get_connection().get_destinations():
            dst_port_map[dst.node].append(dst.idx)
        del dst_port_map[node]
        value = node.in_port(port_id).data.get_value()
        if value is None:
            log.debug('Cannot duplicate: no data available for in_port {} of node {}'.format(port_id, node.name))
            continue
        # use a dedicated loop variable so the outer `node` is not shadowed
        for dst_node, idxs in dst_port_map.items():
            const = Const(
                graph, {
                    'value': np.array(value),
                    'name': dst_node.soft_get('name', dst_node.id) + '/duplicated_'
                }).create_node()
            for idx in idxs:
                dst_node.in_port(idx).disconnect()
                const.out_port(0).connect(dst_node.in_port(idx))
            const.infer(const)
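
# Standalone illustration (plain numpy, not the MO API) of why the duplication
# above matters: two consumers sharing one constant would otherwise both observe
# an in-place edit intended for only one of them.
import numpy as np

shared = np.zeros(3)
consumer_a = consumer_b = shared        # two consumers of the same constant
consumer_a += 1                         # in-place edit intended for consumer_a only
print(consumer_b)                       # [1. 1. 1.] -- the edit leaked to consumer_b
consumer_b = shared.copy()              # duplicating the constant isolates the consumers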
Example 2
    def find_and_replace_pattern(self, graph: Graph):
        if graph.graph['layout'] != 'NHWC':
            # we check it here because this transformation is called explicitly from the pipeline
            return

        # reshape from 4D-5D -> ND. Insert Transpose(NC(D)HW->N(D)HWC) before Reshape
        for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
            reinterp_shape_node = Node(graph, reinterp_shape_node_id)
            assert 0 in reinterp_shape_node.in_nodes(), 'Node {} does not have 0 input. \n{}'.format(
                reinterp_shape_node_id, graph.dump_graph_for_graphviz())
            input_shape = reinterp_shape_node.in_node(0).shape
            if not is_input_data_in_correct_layout(reinterp_shape_node, 0) and len(input_shape) >= 4:
                order_const = Const(graph, {'value': PermuteAttrs().get_nchw_to_nhwc_permutation(len(input_shape)).perm
                                            }).create_node()
                permute_node = Transpose(graph,
                                         {'name': reinterp_shape_node.in_port(0).get_source().node.name + '/Transpose'
                                          }).create_node()
                reinterp_shape_node.in_port(0).get_connection().insert_node(permute_node)
                order_const.out_port(0).connect(permute_node.in_port(1))
                order_const.infer(order_const)

                # do not infer the Transpose node because it should have an input data node in NCHW layout (but
                # currently it is NHWC because the data node attributes have not been permuted yet) and produce
                # output in NHWC layout (which is true at this moment)
                permute_node['need_shape_inference'] = False
                # mark the Transpose output data node as having the correct layout so its shape will not be permuted
                mark_output_as_in_correct_layout(permute_node, 0)

                # keep the reinterp_shape_node in NHWC layout
                mark_input_as_in_correct_layout(reinterp_shape_node, 0)
                mark_input_as_in_correct_layout(reinterp_shape_node, 1)

        # reshape from ND -> 4D-5D. Insert Transpose(N(D)HWC->NC(D)HW) after Reshape
        for reinterp_shape_node_id in graph.get_nodes_with_attributes(reinterp_shape=True):
            reinterp_shape_node = Node(graph, reinterp_shape_node_id)
            assert 0 in reinterp_shape_node.out_nodes(), 'Node {} does not have 0 output. \n{}'.format(
                reinterp_shape_node_id, graph.dump_graph_for_graphviz())
            output_shape = reinterp_shape_node.out_node(0).shape
            if not is_output_data_in_correct_layout(reinterp_shape_node, 0) and len(output_shape) >= 4:
                order_const = Const(graph, {
                    'value': PermuteAttrs().get_nhwc_to_nchw_permutation(len(output_shape)).perm}).create_node()
                permute_node = Transpose(graph, {'name': reinterp_shape_node.id + '/Transpose'}).create_node()
                reinterp_shape_node.out_port(0).get_connection().insert_node(permute_node)
                order_const.out_port(0).connect(permute_node.in_port(1))

                # the Reshape and Transpose operations should work in the original (NHWC) layout, so the Transpose
                # will convert it to NCHW
                mark_input_as_in_correct_layout(permute_node, 0)
                mark_input_as_in_correct_layout(permute_node, 1)
                # do not set Transpose output data node 'correct_data_layout' attribute so the data node shape will be
                # permuted

                # keep the reinterp_shape_node in NHWC layout
                mark_output_as_in_correct_layout(reinterp_shape_node, 0)
                mark_input_as_in_correct_layout(reinterp_shape_node, 1)

                # do not re-infer the Transpose node because its output data node should be in NHWC layout to keep the
                # rest of the graph consistent
                permute_node['need_shape_inference'] = False
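
# Standalone numpy sketch of the layout permutations used above. Assumption:
# PermuteAttrs().get_nchw_to_nhwc_permutation(4).perm is [0, 2, 3, 1] and the
# inverse NHWC->NCHW order is [0, 3, 1, 2]; verify against your MO version.
import numpy as np

x_nchw = np.random.rand(1, 3, 4, 5)            # N, C, H, W
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))    # Transpose(NCHW -> NHWC)
x_back = np.transpose(x_nhwc, (0, 3, 1, 2))    # Transpose(NHWC -> NCHW)
assert np.array_equal(x_nchw, x_back)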
Example 3
    def replace_pattern(self, graph: Graph, match: dict):
        conv = match['conv']
        weights = match['weights']
        input_shape = conv.in_port(0).data.get_shape()
        new_weights_shape = int64_array([(weights.value.shape[0] * weights.value.shape[1]) / (input_shape[1] / conv.group),
                                         input_shape[1] / conv.group, *weights.value.shape[2:]])
        new_weights = Const(graph, {'value': np.reshape(weights.value, new_weights_shape)}).create_node()
        weights.out_port(0).get_connection().set_source(new_weights.out_port(0))
        new_weights.infer(new_weights)
Example 4
def create_op_node_with_second_input(graph: Graph, op: callable, second_input_value: np.array, op_attrs=None,
                                     input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()
    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))
    second_input_node = Const(graph, {'name': node.name + '/value', 'value': second_input_value}).create_node()
    second_input_node.out_port(0).connect(node.in_port(1))
    if graph.stage != 'front':
        second_input_node.infer(second_input_node)
    return node
Example 5
def create_fake_quantize_node(graph: Graph, name):
    fq = FakeQuantize(graph, {
        'name': name,
        'levels': 0,
        'stop_value_propagation': True
    }).create_node()

    input_low = Const(graph, {
        'value': np.array(0.0).astype(np.float32)
    }).create_node()
    input_high = Const(graph, {
        'value': np.array(0.0).astype(np.float32)
    }).create_node()
    output_low = Const(graph, {
        'value': np.array(0.0).astype(np.float32)
    }).create_node()
    output_high = Const(graph, {
        'value': np.array(0.0).astype(np.float32)
    }).create_node()

    input_low.out_port(0).connect(fq.in_port(1))
    input_high.out_port(0).connect(fq.in_port(2))
    output_low.out_port(0).connect(fq.in_port(3))
    output_high.out_port(0).connect(fq.in_port(4))

    input_low.infer(input_low)
    input_high.infer(input_high)
    output_low.infer(output_low)
    output_high.infer(output_high)

    return fq
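
# Rough numpy sketch of the element-wise math a FakeQuantize node performs once
# real levels and ranges are set (simplified from the OpenVINO FakeQuantize
# spec; the node created above is only a placeholder with levels=0).
import numpy as np

def fake_quantize_ref(x, in_low, in_high, out_low, out_high, levels):
    x = np.clip(x, in_low, in_high)                                  # clamp to the input range
    q = np.round((x - in_low) / (in_high - in_low) * (levels - 1))   # snap to integer levels
    return q / (levels - 1) * (out_high - out_low) + out_low         # rescale to the output range

print(fake_quantize_ref(np.array([0.1, 0.5, 0.9]), 0.0, 1.0, 0.0, 1.0, levels=256))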
Example 6
    def replace_pattern(graph: Graph, match: dict):
        node = match['op']
        input_shape = node.in_port(0).data.get_shape()
        if len(input_shape) > 2:
            new_shape = Const(graph, {
                'value': np.array([0, -1], dtype=np.int64)
            }).create_node()
            reshape = Reshape(graph, {}).create_node()
            source = node.in_port(0).get_source()
            node.in_port(0).get_connection().set_source(reshape.out_port(0))
            source.connect(reshape.in_port(0))
            new_shape.out_port(0).connect(reshape.in_port(1))
            new_shape.infer(new_shape)
            reshape.infer(reshape)
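
# Numpy sketch of the [0, -1] reshape above, assuming the Reshape special_zero
# semantics where 0 copies the corresponding input dimension and -1 is inferred,
# i.e. flatten everything after the first axis (numpy has no 0 shortcut, so the
# first dimension is passed explicitly here).
import numpy as np

x = np.random.rand(2, 3, 4, 5)
flat = x.reshape(x.shape[0], -1)    # equivalent of Reshape(x, [0, -1])
print(flat.shape)                   # (2, 60)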
Example 7
def create_op_with_const_inputs(graph: Graph, op: callable, port_value_dict: Dict[int, np.array],
                                op_attrs=None, input_node=None):
    operation = op(graph, op_attrs)
    node = operation.create_node()
    if input_node is not None:
        input_node.out_port(0).connect(node.in_port(0))

    for idx, value in port_value_dict.items():
        node.add_input_port(idx, skip_if_exist=True)
        value_input_node = Const(graph, {'name': node.name + '_input_port_' + str(idx) + '/value',
                                         'value': value}).create_node()
        value_input_node.out_port(0).connect(node.in_port(idx))
        if graph.stage != 'front':
            value_input_node.infer(value_input_node)
    return node
Example 8
    def replace_pattern(self, graph: Graph, match: dict):
        conv = match['conv']

        assert len(conv.out_nodes()) == 1, "Convolution operation {} should have 1 output data node".format(conv.id)
        out_data = conv.out_node()

        assert out_data.has_valid('shape'), 'Output shape is undefined for {} in back phase'.format(conv.id)
        out_shape = out_data.shape

        if out_shape.size != 3:
            return

        assert len(conv.in_nodes()) >= 1, "Convolution operation {} should have at least 1 input data node".format(
            conv.id)
        inp_data = conv.in_node()

        assert inp_data.has_valid('shape'), 'Input shape is undefined for {} in back phase'.format(conv.id)
        inp_shape = inp_data.shape
        new_inp_shape = np.insert(inp_shape, 2, 1)

        # setting to None to be overwritten by infer function
        conv.kernel_spatial_idx = None
        conv.spatial_dims = None

        # inserting fake H dimension
        conv.dilation = np.insert(conv.dilation, 2, 1)
        conv.kernel_spatial = np.append([1], conv.kernel_spatial)
        conv.pad = np.insert(conv.pad, 2, [0, 0], axis=0)
        conv.stride = np.insert(conv.stride, 2, 1)

        weights_node = conv.in_node(1)
        weights_node.value = np.reshape(weights_node.value, np.insert(weights_node.value.shape, 2, 1))
        weights_node.shape = np.array(weights_node.value.shape, dtype=np.int64)

        reshape = Reshape(graph, {'name': conv.name + '/reshape'}).create_node()
        reshape_dim = Const(graph, {'value': new_inp_shape, 'name': reshape.id + '/Dim'}).create_node()
        conv.in_port(0).get_connection().insert_node(reshape)
        reshape.in_port(1).connect(reshape_dim.out_port(0))

        reshape_back = Reshape(graph, {'name': conv.name + '/reshape_back'}).create_node()
        reshape_back_dim = Const(graph, {'value': out_shape, 'name': reshape.id + '/Dim'}).create_node()
        conv.out_port(0).get_connection().insert_node(reshape_back)
        reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

        # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
        reshape_dim.infer(reshape_dim)
        reshape.infer(reshape)
        conv.infer(conv)
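
# Numpy sketch (shapes only, no convolution is run) of the fake-H-dimension
# trick above: a 1D convolution over [N, C, W] is rewritten as a 2D convolution
# over [N, C, 1, W], with kH = 1 prepended to the kernel, and reshaped back.
import numpy as np

inp_shape = np.array([1, 16, 100])          # N, C, W
new_inp_shape = np.insert(inp_shape, 2, 1)  # -> [1, 16, 1, 100]
kernel_spatial = np.append([1], [3])        # kW = 3 -> (kH, kW) = (1, 3)
print(new_inp_shape, kernel_spatial)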
Example 9
    def replace_pattern(self, graph: Graph, match: dict):
        node = match['pad']

        pb = node.pads[:, 0]
        pe = node.pads[:, 1]
        pm = node.mode

        pads_begin = Const(graph, {'value': np.array(pb)}).create_node()
        node.add_input_port(1, skip_if_exist=True)
        node.in_port(1).connect(pads_begin.out_port(0))
        pads_begin.infer(pads_begin)

        pads_end = Const(graph, {'value': np.array(pe)}).create_node()
        node.add_input_port(2, skip_if_exist=True)
        node.in_port(2).connect(pads_end.out_port(0))
        pads_end.infer(pads_end)

        del node['pads']

        if node.has_valid('fill_value') and pm == 'constant':
            pv = node.fill_value
            pad_value = Const(graph, {'value': np.array(pv)}).create_node()
            node.add_input_port(3, skip_if_exist=True)
            node.in_port(3).connect(pad_value.out_port(0))
            pad_value.infer(pad_value)

        del node['fill_value']
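
# Numpy sketch of the attribute split above: the legacy `pads` attribute is a
# [rank, 2] array whose columns become the pads_begin / pads_end inputs.
import numpy as np

pads = np.array([[0, 0], [0, 0], [1, 2], [1, 2]])   # per-axis (begin, end) pairs
pads_begin, pads_end = pads[:, 0], pads[:, 1]
print(pads_begin, pads_end)                         # [0 0 1 1] [0 0 2 2]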
Example 10
def _fuse_mul(graph: Graph,
              node: Node,
              fuse_nodes: list,
              backward: bool = True):
    """
    This function takes Mul node and array of convolution/fc nodes for further fusion
    Parameters
    ----------
    x : bool
        If backward is False, that means that Convolution/FC goes after Mul node
        else means that Mul goes after Convolutions/FC
        :param backward:
        :param fuse_nodes:
        :param node:
        :param graph:
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning(
            'Cannot do fuse_mul for node {} because this node has wrong inputs'
            .format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning(
                'Node {} can\'t be used in fusing because attr can_be_fused = False'
                .format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning(
                'Cannot do fuse_mul for node {} because there is no field '
                'output_channel_dim and/or input_channel_dim in weights.'
                .format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(
                fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according output/input channel dimension
        ch_dim = weights_port.data.get_attr(
            'output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error(
                "Cannot fuse const from {} to {}. Reshape failed. Skipping.".
                format(node.soft_get('name', node.id),
                       fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {
            'value': value,
            'name': mul_name + '/const'
        }).create_node()
        w_mul = node.copy_node({
            'name': mul_name,
            'in_ports_count': len(node.in_ports()),
            'out_ports_count': len(node.out_ports()),
            'can_be_fused': False
        })
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        fuse_node_in_data = fuse_node.in_node(weights_port.idx)
        w_const_out_data = w_const.node.out_node(w_const.idx)

        # During this reconnection new data node name is copied from the data node
        # outgoing from w_const port. Duplicate names of data nodes lead to appearing
        # of duplicate op node names after constant folding. So we should manually
        # set a unique name for the new data node.
        if fuse_node_in_data.soft_get('name') == w_const_out_data.soft_get('name') and \
                fuse_node_in_data.soft_get('name', None) is not None:
            fuse_node.in_node(
                weights_port.idx)['name'] = graph.unique_id(mul_name)

        # If we fuse in the backward direction we should also multiply the biases, if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() *
                                     np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # as Mul node is added before convolution, output tensor from Convolution node
        # corresponds to original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused
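
# Minimal numpy sketch (not the MO API) of what the backward fusion above
# achieves: conv(x, W) * s == conv(x, W * s) for a per-output-channel scale s,
# so the Mul constant can be folded into the weights and, if present, the bias.
import numpy as np

W = np.random.rand(8, 3, 5, 5)          # OIHW convolution weights
b = np.random.rand(8)                   # per-output-channel bias
s = np.random.rand(8)                   # scale taken from the Mul constant
W_fused = W * s.reshape(8, 1, 1, 1)     # broadcast over the output-channel dim
b_fused = b * s                         # biases are scaled as well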
Example 11
def _fuse_mul(graph: Graph,
              node: Node,
              fuse_nodes: list,
              backward: bool = True):
    """
    This function takes Mul node and array of convolution/fc nodes for further fusion
    Parameters
    ----------
    x : bool
        If backward is False, that means that Convolution/FC goes after Mul node
        else means that Mul goes after Convolutions/FC
        :param backward:
        :param fuse_nodes:
        :param node:
        :param graph:
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning(
            'Cannot do fuse_mul for node {} because this node has wrong inputs'
            .format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning(
                'Node {} can\'t be used in fusing because attr can_be_fused = False'
                .format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning(
                'Cannot do fuse_mul for node {} because there is no field '
                'output_channel_dim and/or input_channel_dim in weights.'
                .format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(
                fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according output/input channel dimension
        ch_dim = weights_port.data.get_attr(
            'output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        value = np.reshape(value, shape)

        # Weights multiplication
        mul_const = Const(graph, {'value': value}).create_node()
        w_mul = node.copy_node({
            'in_ports_count': len(node.in_ports()),
            'out_ports_count': len(node.out_ports()),
            'can_be_fused': False
        })
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        w_const = weights_port.get_source()
        weights_port.get_connection().set_source(w_mul.out_port(0))
        w_const.connect(w_mul.in_port(tensor_port.idx))

        # If we fuse in the backward direction we should also multiply the biases, if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() *
                                     np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused
Example 12
    def replace_pattern(graph: Graph, match: dict):
        node = match['matmul']
        name = node.soft_get('name', node.id)

        A_shape = node.in_port(0).data.get_shape()
        B_shape = node.in_port(1).data.get_shape()
        out_shape = node.out_port(0).data.get_shape()

        assert A_shape is not None and B_shape is not None and out_shape is not None

        B_value = node.in_port(1).data.get_value()
        if (B_value is not None or node.in_port(1).get_source().node.has_and_set('stop_value_propagation')) \
                and B_shape[B_shape != 1].size <= 2:
            # transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
            # to FullyConnected representation: [I, K] * [O, K] = [I, O]
            B, I, K, O, aligned_A_shape, aligned_B_shape = MatMulToFullyConnected.get_matmul_BIKO(node)

            # weights normalization
            if not node.transpose_b:
                # FullyConnected weights layout is OI
                # MatMul second input layout is (B)IO
                transpose_order = list(range(B_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/weights_transpose'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(weights_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            if node.in_port(1).data.get_shape().size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/weights_reshape'}).create_node()

                weights_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(reshape.out_port(0))

                reshape.in_port(0).connect(weights_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(1).data.get_shape(), int64_array([O, K]))), \
                "MatMul `{}` was not converted to FullyConnected: wrong weights shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(1).data.get_shape(), B, I, K, O)

            node.in_port(1).bin = 'weights'
            del node['transpose_b']

            # input normalization
            if node.transpose_a:
                transpose_order = list(range(A_shape.size))
                transpose_order[-1], transpose_order[-2] = transpose_order[-2], transpose_order[-1]

                order = Const(graph, {'value': int64_array(transpose_order)}).create_node()
                transpose = Transpose(graph, {'name': name + '/input_transpose'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(transpose.out_port(0))
                transpose.in_port(0).connect(input_source)
                transpose.in_port(1).connect(order.out_port(0))

                order.infer(order)
                transpose.infer(transpose)

            if A_shape.size != 2:
                const = Const(graph, {'value': int64_array([-1, K])}).create_node()
                reshape = Reshape(graph, {'name': name + '/input_reshape'}).create_node()

                input_source = node.in_port(0).get_source()
                node.in_port(0).get_connection().set_source(reshape.out_port(0))
                reshape.in_port(0).connect(input_source)
                reshape.in_port(1).connect(const.out_port(0))

                const.infer(const)
                reshape.infer(reshape)

            assert np.all(np.array_equal(node.in_port(0).data.get_shape(), int64_array([np.prod(B) * I, K]))), \
                "MatMul `{}` wasn't converted to FullyConnected: wrong input shape: {}, " \
                "B={}, I={}, K={}, O={}".format(name, node.in_port(0).data.get_shape(), B, I, K, O)

            del node['transpose_a']

            FullyConnected.update_node_stat(node, {'out-size': O})

            # output normalization
            if out_shape.size != 2:
                const = Const(graph, {'value': int64_array([*B, I, O])}).create_node()
                reshape = Reshape(graph, {'name': name + '/output_reshape'}).create_node()

                dst = node.out_port(0).get_destination()
                node.out_port(0).get_connection().set_destination(reshape.in_port(0))
                const.out_port(0).connect(reshape.in_port(1))
                reshape.out_port(0).connect(dst)

                node.infer(node)

                const.infer(const)
                reshape.infer(reshape)

        else:
            assert A_shape.size == out_shape.size
            assert B_shape.size <= out_shape.size
            if B_shape.size != out_shape.size:
                unsqueeze_dim = Const(graph, {'value': int64_array(list(range(out_shape.size - B_shape.size)))
                                              }).create_node()
                unsqueeze = Unsqueeze(graph, {}).create_node()
                B_source = node.in_port(1).get_source()
                node.in_port(1).get_connection().set_source(unsqueeze.out_port(0))
                unsqueeze.in_port(0).connect(B_source)
                unsqueeze.in_port(1).connect(unsqueeze_dim.out_port(0))

                unsqueeze_dim.infer(unsqueeze_dim)
                unsqueeze.infer(unsqueeze)

            Gemm.update_node_stat(node, {
                'transpose_a': node.has_and_set('transpose_a'),
                'transpose_b': node.has_and_set('transpose_b'),
            })
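
# Numpy sketch of the MatMul -> FullyConnected rewrite above: [B, I, K] x [K, O]
# with constant weights collapses to a 2D GEMM [B*I, K] x [O, K]^T followed by a
# reshape back to [B, I, O] (the weights are transposed to the OI layout).
import numpy as np

B, I, K, O = 2, 3, 4, 5
A = np.random.rand(B, I, K)
W = np.random.rand(K, O)                                   # MatMul weights, IO layout
matmul_out = A @ W                                         # [B, I, O]
fc_weights = W.T                                           # FullyConnected weights, OI layout: [O, K]
fc_out = (A.reshape(-1, K) @ fc_weights.T).reshape(B, I, O)
assert np.allclose(matmul_out, fc_out)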
Example 13
def _fuse_mul(graph: Graph,
              node: Node,
              fuse_nodes: list,
              backward: bool = True):
    """
    This function takes Mul node and array of convolution/fc nodes for further fusion
    Parameters
    ----------
    x : bool
        If backward is False, that means that Convolution/FC goes after Mul node
        else means that Mul goes after Convolutions/FC
        :param backward:
        :param fuse_nodes:
        :param node:
        :param graph:
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning(
            'Cannot do fuse_mul for node {} because this node has wrong inputs'
            .format(node.id))
        return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning(
                'Node {} can\'t be used in fusing because attr can_be_fused = False'
                .format(fuse_node.name))
            return False

        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

        if not backward and not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False

        weights_port = fuse_node.in_port(1)
        if not weights_port.data.has_valid('output_channel_dim') or \
                not weights_port.data.has_valid('input_channel_dim'):
            log.warning(
                'Cannot do fuse_mul for node {} because there is no field '
                'output_channel_dim and/or input_channel_dim in weights.'
                .format(fuse_node.soft_get('name')))
            return False

        inp_ch = weights_port.data.get_attr('input_channel_dim')
        out_ch = weights_port.data.get_attr('output_channel_dim')
        if max(inp_ch, out_ch) >= len(weights_port.data.get_shape()):
            log.warning('Node {} has wrong weights shape'.format(
                fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        value = np.squeeze(value)

        # TODO : ch_dim should be equal to node.in_node(1).value.shape
        # We will multiply weights according output/input channel dimension
        ch_dim = weights_port.data.get_attr(
            'output_channel_dim' if backward else 'input_channel_dim')
        shape = np.array([weights_port.data.get_shape()[ch_dim]])

        # Scalar broadcast
        if value.size == 1:
            value = np.full(shape, value.item())

        # Common broadcast for forward fusion
        if not backward:
            cnt = shape[-1] / value.shape[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        # Expand dims for multiplication (ex. [38] to [38, 1, 1])
        wdims_number = weights_port.data.get_attr('dims_number')
        for x in range(wdims_number - ch_dim - 1):
            shape = np.append(shape, 1)

        mul_val = np.array(value)
        # If the value fails to reshape to the provided shape, skip fusing.
        # This can happen in case of group != 1 of the convolution.
        try:
            value = np.reshape(value, shape)
        except ValueError:
            log.error(
                "Cannot fuse const from {} to {}. Reshape failed. Skipping.".
                format(node.soft_get('name', node.id),
                       fuse_node.soft_get('name', fuse_node.id)),
                extra={'is_warning': True})
            return False

        # Weights multiplication
        mul_name = node.name + '_copy'
        mul_const = Const(graph, {
            'value': value,
            'name': mul_name + '/const'
        }).create_node()
        w_mul = node.copy_node({
            'name': mul_name,
            'in_ports_count': len(node.in_ports()),
            'out_ports_count': len(node.out_ports()),
            'can_be_fused': False
        })
        w_mul.in_port(const_port.idx).connect(mul_const.out_port(0))
        r"""
        In this transformation we remove the Mul or Div node (node) that goes after fuse_node,
        create a new Mul node (w_mul), connect it with the corrected const value (mul_const) and
        insert w_mul before the fuse_node. So the input data of fuse_node becomes different.
        For this reason we need to use set_destination from the previous operation to w_mul, which
        guarantees that the data node will be reused on the previous_op -> w_mul connection and its
        attributes won't be copied to the data node of the w_mul -> fuse_node connection.
        
        BEFORE                        AFTER

                                 previous_op      mul_const
                                         \     /
            previous_op                   w_mul
               |                            |
             fuse_node   const          fuse_node     
                 \     /                    |       
                  node                   next_op      
                   |                              
                 next_op                      
        """
        weights_port.get_connection().set_destination(
            w_mul.in_port(tensor_port.idx))
        w_mul.out_port(0).connect(weights_port)

        # As fusing is applied to convolutions it is important to keep the 'permutation' and 'input_permutation'
        # attributes which were obtained from the original model. These attributes are stored on the incoming edge to
        # the operation node and during reconnection they are moved to the new connection. But during the reconnection
        # in this transformation these attributes are moved to the previous node, so we need to manually set them on
        # the incoming edge to fuse_node.
        in_edge = w_mul.in_edge(tensor_port.idx)
        if 'permutation' in in_edge:
            fuse_node.in_edge(
                weights_port.idx)['permutation'] = in_edge['permutation']
        if 'input_permutation' in in_edge:
            fuse_node.in_edge(
                weights_port.idx
            )['input_permutation'] = in_edge['input_permutation']

        # If we fuse in the backward direction we should also multiply the biases, if they exist
        if backward and len(fuse_node.in_ports()) == 3 and not fuse_node.in_port(2).disconnected() and \
                not fuse_node.has_and_set('shape_input'):
            conv_bias = fuse_node.in_port(2)
            conv_bias.data.set_value(conv_bias.data.get_value() *
                                     np.squeeze(mul_val))

        mul_const.infer(mul_const)
        w_mul.infer(w_mul)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Mul node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        # as Mul node is added before convolution, output tensor from Convolution node
        # corresponds to original Mul node
        node.out_port(0).get_connection().set_source(producer_port, "dest")

    return is_fused
Example 14
def _fuse_add(graph: Graph,
              node: Node,
              fuse_nodes: List[Node],
              backward: bool = True):
    """
    This function takes Add node and Convolution/FC nodes for further fusion and then deletes Add node
    In case if Convolution/FC Bias absence it will be created
    """
    is_fused = False
    const_port, tensor_port = get_value_in_port(node), get_tensor_in_port(node)

    if const_port is None or tensor_port is None:
        log.warning(
            'Cannot do fuse_add for node {} because this node has wrong inputs'
            .format(node.id))
        return False

    # if len(node.in_node(const_id).shape) > 2 or any([x == 0 for x in node.in_node(const_id).shape]):
    #     log.warning('Cannot do fuse_add for node {} because this node has wrong shape'.format(node.id))
    #     return False

    for fuse_node in fuse_nodes:
        if fuse_node.soft_get('can_be_fused') is False:
            log.warning(
                'Node {} can\'t be used in fusing due to user specified attr can_be_fused = False'
                .format(fuse_node.name))
            return False
        if not fuse_node.has_valid('layout'):
            log.warning('Node {} has no layout attr'.format(fuse_node.name))
            return False
        if len(fuse_node.in_ports()) < 2:
            log.warning('Node {} has no weights node'.format(fuse_node.name))
            return False

    for fuse_node in fuse_nodes:
        weights_port = fuse_node.in_port(1)
        value = np.array(const_port.data.get_value())

        # If forward, broadcast value
        if not backward:
            cnt = weights_port.data.get_shape(
            )[-1] / const_port.data.get_shape()[0]
            if fuse_node.layout == 'NCHW':
                tmp = []
                for val in value:
                    tmp = np.concatenate((tmp, np.repeat(val, cnt)))
                value = np.array(tmp)
            else:
                value = np.tile(value, int(cnt))

        value = np.squeeze(value)

        # Create the bias data node if it does not exist
        if len(fuse_node.in_ports()) <= 2:
            fuse_node.add_input_port(idx=2)
        if fuse_node.in_port(2).disconnected(
        ) or fuse_node.in_port(2).data.get_value() is None:
            # Broadcast if scalar
            if value.size == 1:
                id = weights_port.data.get_attr(
                    'output_channel_dim'
                ) if backward else weights_port.data.get_attr(
                    'input_channel_dim')
                vshape = weights_port.data.get_shape()[id]
                value = np.full(vshape, value.item())

            if not backward:
                value = np.dot(weights_port.data.get_value(), value)

            const_bias_node = Const(
                graph, dict(name="bias_data",
                            value=np.array(value))).create_node()

            fuse_node.in_port(2).connect(const_bias_node.out_port(0))
            fuse_node.in_port(2).bin = 'biases'
            const_bias_node.infer(const_bias_node)

            fuse_node['bias_term'] = True
        else:
            bias_value = fuse_node.in_port(2).data.get_value()
            if not backward:
                fuse_node.in_port(2).data.set_value(
                    bias_value +
                    np.dot(fuse_node.in_port(1).data.get_value(), value))
            else:
                fuse_node.in_port(2).data.set_value(bias_value + value)

        log.debug('Fused: {} to {}'.format(node.name, fuse_node.name))
        is_fused = True

    if is_fused:
        # Delete Add node
        producer_port = tensor_port.get_source()
        tensor_port.disconnect()
        const_port.disconnect()
        node.out_port(0).get_connection().set_source(producer_port)

    return is_fused
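
# Minimal numpy sketch (not the MO API) of what _fuse_add folds for a
# FullyConnected node with OI-layout weights W and bias b:
#   backward: fc(x) + a  ->  bias becomes b + a
#   forward:  fc(x + a)  ->  bias becomes b + W @ a
import numpy as np

W = np.random.rand(4, 3)                # FC weights, OI layout
b = np.random.rand(4)
a = np.random.rand(3)                   # constant from the Add node
x = np.random.rand(3)
assert np.allclose(x @ W.T + (b + W @ a), (x + a) @ W.T + b)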
Example 15
    def replace_pattern(graph: Graph, match: dict):
        """
        Works around a Tile configuration that is not supported by the Inference Engine (Tile is supported for 2D or
        4D tensors only): searches for Tiles with 3D shapes and wraps them with Reshapes.

        Example: Tile (axis=1, tiles=16):
            in_shape: [1,1,101]
            out_shape: [1,16,101]

        Old behaviour:
            Tile -> [1,16,101]
        New behaviour:
            Reshape [1,1,101,1] -> Tile -> [1,16,101,1] -> Reshape [1,16,101]
        """
        tile = match['tile']

        assert len(tile.out_nodes(
        )) == 1, "Tile operation {} should have 1 output data node".format(
            tile.id)
        out_data = tile.out_node()

        assert out_data.has_valid(
            'shape'), 'Output shape is undefined for {} in back phase'.format(
                tile.id)
        out_shape = out_data.shape

        if out_shape.size != 3:
            return

        assert len(tile.in_nodes(
        )) == 1, "Tile operation {} should have 1 input data node".format(
            tile.id)
        inp_data = tile.in_node()

        assert inp_data.has_valid(
            'shape'), 'Input shape is undefined for {} in back phase'.format(
                tile.id)
        inp_shape = inp_data.shape
        new_inp_shape = np.append(inp_shape, [1])

        reshape = Reshape(graph, {
            'name': tile.name + '/reshape'
        }).create_node()
        reshape_dim = Const(graph, {
            'value': new_inp_shape,
            'name': reshape.id + '/Dim'
        }).create_node()
        tile.in_port(0).get_connection().insert_node(reshape)
        reshape.in_port(1).connect(reshape_dim.out_port(0))

        reshape_back = Reshape(graph, {
            'name': tile.name + '/reshape_back'
        }).create_node()
        reshape_back_dim = Const(graph, {
            'value': out_shape,
            'name': reshape.id + '/Dim'
        }).create_node()
        tile.out_port(0).get_connection().insert_node(reshape_back)
        reshape_back.in_port(1).connect(reshape_back_dim.out_port(0))

        # run shape inference manually for several nodes to override shapes of the model nodes which changed behaviour
        reshape_dim.infer(reshape_dim)
        reshape.infer(reshape)
        tile.infer(tile)
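
# Numpy sketch verifying the 3D -> 4D Tile workaround from the docstring above:
# tiling a [1, 1, 101] tensor 16 times along axis 1 gives the same result
# directly and via the 4D round-trip inserted by the transformation.
import numpy as np

x = np.random.rand(1, 1, 101)
direct = np.tile(x, (1, 16, 1))                                               # old behaviour
via_4d = np.tile(x.reshape(1, 1, 101, 1), (1, 16, 1, 1)).reshape(1, 16, 101)  # new behaviour
assert np.array_equal(direct, via_4d)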