Example #1
def grouped_convolutions_fusing(graph: Graph):
    while True:
        is_fused = False
        graph_clean_up(graph, ['TFCustomSubgraphCall', 'ShapeOf', 'Shape'])
        for node in graph.pseudo_topological_sort():
            if node.kind == 'op' and len(node.out_nodes()) > 1:
                if node.soft_get('can_be_fused') == False:
                    continue

                is_valid_convolutions = True
                last_layer = None

                next_nodes = get_next_operation(node)
                # Check that all operations after this one are Convolutions
                # and that all convolutions have the same output
                if len(next_nodes) > 1 and all(_node.soft_get('type') in ['Convolution', 'Deconvolution'] for _node in next_nodes):
                    for conv in next_nodes:
                        conv_outputs = get_next_operation(conv)
                        if conv.soft_get('can_be_fused') == False:
                            is_valid_convolutions = False
                        if len(conv_outputs) != 1:
                            is_valid_convolutions = False
                        if last_layer is None:
                            last_layer = conv_outputs[0].id
                        elif conv_outputs[0].id != last_layer:
                            is_valid_convolutions = False

                    if is_valid_convolutions:
                        is_fused = concat_convolutions(graph, node, Node(graph, last_layer))
                        if is_fused:
                            break

        if not is_fused:
            break
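Every example in this collection is built around the `get_next_operation` helper. Its implementation is not shown in these snippets; the sketch below is only an assumption inferred from how it is used (an op node goes in, the list of op nodes consuming its outputs comes out), written against the networkx API rather than taken from the library:

def get_next_operation_sketch(node: Node) -> list:
    # Hypothetical stand-in for get_next_operation: follow op -> data -> op edges
    # and collect the consumer op nodes, de-duplicated in traversal order.
    consumers = {}
    for data_name in node.graph.successors(node.id):       # op node -> its output data nodes
        for op_name in node.graph.successors(data_name):    # data node -> consumer op nodes
            consumers[op_name] = Node(node.graph, op_name)
    return list(consumers.values())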
Example #2
def _simple_stride_prop(graph: Graph,
                        node: Node,
                        spatial_dims,
                        supported=True):
    """
    This function handles stride propagation for op nodes. If the node is a supported operation, the stride can be
    propagated directly through it (its stride_prop is taken from the stride_prop of the ops below); otherwise it
    cannot, and the stride_prop attribute is set to 1,1,1,1.
    """
    next_ops = get_next_operation(node)
    stride_props, all_ops_are_valid = _check_next_ops(next_ops)

    if not supported or not all_ops_are_valid:
        # We have to insert pooling layers
        for op in next_ops:
            if op.has_valid('stride_prop') and not np.array_equal(op.stride_prop[spatial_dims], np.array([1, 1])) and \
                    (op.has_valid('has_stride') == False or op.soft_get('has_stride') == False):
                _insert_pooling(graph, node.out_node(), op, spatial_dims)
        # Stride cannot be propagated through this node, so reset its `stride_prop` to ones
        node['stride_prop'] = np.array([1, 1, 1, 1])
        return

    for op in next_ops:
        if op.soft_get('has_stride') == True:
            op.stride = np.array([1, 1, 1, 1])
            log.debug(
                "STRIDE PROP: {} {} strides was moved upper via {}".format(
                    op.type, op.name, node.name))

    node['stride_prop'] = np.array(
        stride_props[0]) if len(stride_props) > 0 else np.array([1, 1, 1, 1])
    node['is_partial_inferred'] = False
    _clean_fw_tensor_attrs(node.out_node())
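For intuition on the `stride_prop` check above: `op.stride_prop[spatial_dims]` selects only the spatial entries of a 4-D stride vector. With a hypothetical NCHW layout (`spatial_dims = [2, 3]`) the check behaves like this plain NumPy snippet:

import numpy as np

stride_prop = np.array([1, 1, 2, 2])     # [N, C, H, W] stride, hypothetical values
spatial_dims = np.array([2, 3])          # H and W axes for an NCHW layout

print(stride_prop[spatial_dims])                            # [2 2]
print(np.array_equal(stride_prop[spatial_dims], [1, 1]))    # False -> a pooling layer has to be inserted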
Example #3
def _conv_stride_prop(graph: Graph, node: Node, spatial_dims, supported=True):
    """
    This function handles convolution stride propagation. There are two cases: conv->(op) and conv->conv. In the
    first case we propagate the stride from the op, and in the second case we also change the stride of the second conv.
    """
    next_ops = get_next_operation(node)
    stride_props, all_ops_are_valid = _check_next_ops(next_ops)

    def _check_convolution(node: Node):
        return node.has_valid('kernel_spatial') and np.array_equal(
            node.kernel_spatial, np.array([1, 1]))

    # Check that all ops are valid and have same values
    if not all_ops_are_valid:
        # We have to insert pooling layers
        for op in next_ops:
            if op.has_valid('stride_prop') and not np.array_equal(
                    op.stride_prop[spatial_dims], np.array([1, 1])):
                # Insert pooling
                _insert_pooling(graph, node.out_node(), op, spatial_dims)
    elif len(stride_props) > 0:
        node.stride *= stride_props[0]
        log.debug('STRIDE PROP: {} got new strides {}'.format(
            node.name, node.stride))
        for op in next_ops:
            if op.soft_get('has_stride') == True:
                op.stride = np.array([1, 1, 1, 1])
        node['is_partial_inferred'] = False
        node['output_spatial_shape'] = False
        _clean_fw_tensor_attrs(node.out_node())

    # If Convolution is valid then set `stride_prop` to Convolution stride
    node['stride_prop'] = np.array(
        node.stride) if _check_convolution(node) else np.array([1, 1, 1, 1])
Example #4
    def find_and_replace_pattern(self, graph: Graph):
        from tensorflow.core.framework import types_pb2 as tf_types  # pylint: disable=no-name-in-module
        for node_name, node_attrs in list(graph.nodes(data=True)):
            node = Node(graph, node_name)
            pb = node_attrs.get('pb')
            if pb is not None and pb.op == 'Parameter' and pb.attr['dtype'].type != tf_types.DT_FLOAT:
                log.info('Placeholder "{}" has type that is different from DT_FLOAT'.format(node_name))
                next_ops = get_next_operation(node)
                # check that all output nodes are nodes of type 'ToFloat'
                if all([ChangePlaceholderTypes.is_node_casts_to_float(op) and
                        len(op.in_nodes()) == 1 for op in next_ops]):
                    ChangePlaceholderTypes.change_node_type(node, tf_types.DT_FLOAT)
                    ChangePlaceholderTypes.remove_node_preserving_edges(node, next_ops)  # remove 'Cast' nodes

                elif all(ChangePlaceholderTypes.is_node_gather(op) for op in next_ops):
                    ChangePlaceholderTypes.change_node_type(node, tf_types.DT_FLOAT)

                else:
                    raise Error(
                        ('Cannot convert type of placeholder "{}" because not all of its outputs are "Cast" to float '
                         'operations: {}. ' +
                         refer_to_faq_msg(49)),
                        node.soft_get('name'),
                        [op.soft_get('name') for op in next_ops]
                    )
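The `is_node_casts_to_float` check referenced above is not included in these snippets. A plausible sketch of it, assumed from how it is used (a TF 'Cast' node whose destination type is DT_FLOAT) rather than copied from the library:

from tensorflow.core.framework import types_pb2 as tf_types  # pylint: disable=no-name-in-module

def is_node_casts_to_float(node: Node) -> bool:
    # Hypothetical sketch: the node's TF protobuf is a 'Cast' op whose DstT attribute is DT_FLOAT.
    if not node.has_valid('pb'):
        return False
    pb = node.pb
    return pb.op == 'Cast' and pb.attr['DstT'].type == tf_types.DT_FLOAT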
Example #5
def change_placeholders_types_to_FP32(graph: nx.MultiDiGraph):
    for node_name, node_attrs in list(graph.nodes(data=True)):
        node = Node(graph, node_name)
        pb = node_attrs.get('pb')
        if pb is not None and pb.op == 'Placeholder' and pb.attr[
                'dtype'].type != tf_types.DT_FLOAT:
            log.info(
                'Placeholder "{}" has type that is different from DT_FLOAT'.
                format(node_name))
            next_ops = get_next_operation(node)
            # check that all output nodes are nodes of type 'ToFloat'
            if all([
                    is_node_casts_to_float(op) and len(op.in_nodes()) == 1
                    for op in next_ops
            ]):
                change_node_type(node, tf_types.DT_FLOAT)
                remove_node_preserving_edges(node,
                                             next_ops)  # remove 'Cast' nodes
            elif all(is_node_gather(op) for op in next_ops):
                change_node_type(node, tf_types.DT_FLOAT)
            else:
                raise Error((
                    'Cannot convert type of placeholder "{}" because not all of its outputs are "Cast" to float '
                    'operations: {}. ' + refer_to_faq_msg(49)),
                            node.soft_get('name'),
                            [op.soft_get('name') for op in next_ops])
    return graph
Example #6
    def find_and_replace_pattern(self, graph: Graph):
        for node in list(graph.nodes()):
            if node not in graph.nodes():
                continue
            permute_node = Node(graph, node)
            if permute_node.has_valid(
                    'type') and permute_node.type == 'Permute':
                list_of_permutes = [permute_node]
                # Get sequence of permutations
                node = permute_node
                while True:
                    next_ops = get_next_operation(node)
                    if len(next_ops) != 1:
                        break

                    next_op = next_ops[0]
                    if next_op.has_valid('type') and next_op.type == 'Permute':
                        list_of_permutes.append(next_op)
                        node = next_op
                    else:
                        break

                final_permutation = np.array(
                    [x for x in range(len(list_of_permutes[0].order))],
                    dtype=np.int64)
                for permute in list_of_permutes:
                    if not permute.has_valid('order'):
                        raise Error(
                            "Permute node {} has wrong attribute order = None".
                            format(permute.name))
                    final_permutation = final_permutation[np.array(
                        permute.order, dtype=np.int64)]

                if np.array_equal(
                        final_permutation,
                    [x for x in range(len(list_of_permutes[0].order))]):
                    first_data_node, last_data_node = list_of_permutes[
                        0].in_node(), list_of_permutes[-1].out_node()
                    graph.remove_edge(first_data_node.id,
                                      list_of_permutes[0].id)
                else:
                    if len(list_of_permutes) < 2:
                        continue
                    first_data_node, last_data_node = list_of_permutes[
                        0].out_node(), list_of_permutes[-1].out_node()
                    list_of_permutes[0].order = final_permutation
                    graph.remove_edge(first_data_node.id,
                                      first_data_node.out_node().id)

                graph.remove_edge(last_data_node.in_node().id,
                                  last_data_node.id)

                merge_data_nodes(graph, first_data_node, last_data_node)
                graph.remove_node(last_data_node.id)
                graph_clean_up_tf(graph)
Example #7
    def find_and_replace_pattern(self, graph: Graph):
        for permute_node in graph.get_op_nodes(type='Transpose'):
            if permute_node.id not in graph.nodes():
                continue

            list_of_permutes = [permute_node]
            # Get sequence of permutations
            node = permute_node
            while True:
                next_ops = get_next_operation(node)
                if len(next_ops) != 1:
                    break

                next_op = next_ops[0]
                if next_op.soft_get('type') == 'Transpose':
                    list_of_permutes.append(next_op)
                    node = next_op
                else:
                    break

            final_permutation = int64_array([
                x for x in range(
                    len(list_of_permutes[0].in_port(1).data.get_value()))
            ])
            for permute in list_of_permutes:
                order = permute.in_port(1).data.get_value()
                if order is None:
                    raise Error(
                        "Transpose node {} has wrong order for permute = None".
                        format(permute.name))
                final_permutation = final_permutation[int64_array(order)]

            if np.array_equal(final_permutation, [
                    x for x in range(
                        len(list_of_permutes[0].in_port(1).data.get_value()))
            ]):
                first_data_node, last_data_node = list_of_permutes[0].in_node(
                ), list_of_permutes[-1].out_node()
                graph.remove_edge(first_data_node.id, list_of_permutes[0].id)
            else:
                if len(list_of_permutes) < 2:
                    continue
                first_data_node, last_data_node = list_of_permutes[0].out_node(
                ), list_of_permutes[-1].out_node()
                list_of_permutes[0].in_port(1).data.set_value(
                    final_permutation)
                graph.remove_edge(first_data_node.id,
                                  first_data_node.out_node().id)

            graph.remove_edge(last_data_node.in_node().id, last_data_node.id)

            merge_data_nodes(graph, first_data_node, last_data_node)
            graph.remove_node(last_data_node.id)
            graph.clean_up()
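The composition step `final_permutation = final_permutation[order]` is what folds a chain of transposes into a single permutation. A quick standalone check of that arithmetic (plain NumPy, independent of the graph API):

import numpy as np

order1 = np.array([0, 2, 3, 1], dtype=np.int64)   # e.g. NCHW -> NHWC
order2 = np.array([0, 3, 1, 2], dtype=np.int64)   # e.g. NHWC -> NCHW

final = np.arange(4, dtype=np.int64)              # start from the identity permutation
for order in (order1, order2):
    final = final[order]

print(final)   # [0 1 2 3] -> the two transposes cancel out, so the whole chain can be removed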
Example #8
    def test_get_next_operation_3(self):
        # Placeholder-+--->Mul
        #             +-----^
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1', 'placeholder_2_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('placeholder_2_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'op_output')
                             ])

        res = get_next_operation(Node(graph, 'placeholder_1'))
        self.assertTrue(len(res) == 1 and res[0].id == 'mul_1', 'get_next_operation returned wrong op')
Example #9
    def test_get_next_operation_2(self):
        # Placeholder->Mul->Add, plus a direct Placeholder->Add edge
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'mul_1'),
                             ('placeholder_1_data', 'add_1'),
                             ('mul_1', 'mul_1_data'), ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output')])

        res = get_next_operation(Node(graph, 'placeholder_1'))
        self.assertTrue(
            len(res) == 2 and all([x.id in ['add_1', 'mul_1'] for x in res]),
            'get_next_operation returned wrong op')
Example #10
    def test_get_next_operation_1(self):
        # Placeholder->ScaleShift->Mul->Add
        graph = build_graph(nodes_attributes,
                            [('placeholder_1', 'placeholder_1_data'),
                             ('placeholder_1_data', 'scaleshift_1'),
                             ('scaleshift_1_w', 'scaleshift_1'),
                             ('scaleshift_1', 'scaleshift_1_data'),
                             ('scaleshift_1_data', 'mul_1'),
                             ('mul_1', 'mul_1_data'),
                             ('mul_1_data', 'add_1'),
                             ('add_1', 'add_1_data'),
                             ('add_1_data', 'op_output')
                             ])

        res = get_next_operation(Node(graph, 'mul_1'))
        self.assertTrue(len(res) == 1 and res[0].id == 'add_1', 'get_next_operation returned wrong op')
Example #11
    def find_and_replace_pattern(self, graph: Graph):
        for start_node in graph.pseudo_topological_sort():
            matched_nodes = []
            if self.is_node_match_for_optimization(start_node):
                next_node = start_node
                while self.is_node_match_for_optimization(next_node):
                    matched_nodes.append(next_node)
                    next_node[self.OPTIMIZED_NODE_FLAG] = True
                    next_nodes = get_next_operation(next_node)
                    if len(next_nodes) > 1:
                        log.debug('There are two consumers of the node {}. Stop matching sequence.'.format(
                            next_node.soft_get('name')))
                        break
                    next_node = next_nodes[0]
            # optimize sequence of three or more Transpose-Reshape nodes
            if len(matched_nodes) >= 3:
                self.optimize_permute_reshape_sequence(graph, matched_nodes)

        # run FuseReshapesSequence and RemoveRedundantReshapes to remove dummy (NOP) reshapes; after that we can fuse Transposes
        FuseReshapesSequence().find_and_replace_pattern(graph)
        RemoveRedundantReshapes().find_and_replace_pattern(graph)
        FuseTransposesSequence().find_and_replace_pattern(graph)
Example #12
    def find_and_replace_pattern(self, graph: Graph):
        reshape_nodes = graph.get_op_nodes(type='Reshape')
        for node in reshape_nodes:
            if not graph.has_node(node.id):
                # the Reshape node has been removed in the previous iteration
                continue
            if len(node.out_port(0).get_destinations()) == 1:
                log.debug('First phase for Reshape: {}'.format(
                    node.soft_get('name')))

                next_op = get_next_operation(node)[0]
                log.debug('second node: id={}, type={}'.format(
                    next_op.soft_get('id'), next_op.soft_get('type')))
                if next_op.has_valid('type') and next_op.type == 'Reshape':
                    dim_value = next_op.in_port(1).data.get_value()
                    if dim_value is None or 0 in dim_value or -1 in dim_value:
                        # we do not fuse reshape sequences with special symbols: 0, -1
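                        # (a 0 copies the corresponding input dimension and a -1 is inferred, so their
                        #  meaning would change if the first Reshape were removed)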
                        continue

                    # Detected Reshape1 --> data --> Reshape2 pattern without side edges. Remove Reshape1
                    log.debug('Second phase for Reshape: {}'.format(
                        node.soft_get('name')))
                    remove_op_node_with_data_node(graph, node)
Example #13
def concat_convolutions(graph: Graph, start_node: Node, last_node: Node):
    """
    This function converts group of convolutions into one
    """

    # Check that the concatenation is in the same order as the convolutions
    conv_nodes = get_next_operation(start_node)
    assert len(conv_nodes) == len(last_node.in_nodes())
    gconv = conv_nodes[0]

    for id in range(len(conv_nodes)):
        conv = conv_nodes[id]
        if conv.out_node().id != last_node.in_node(id).id:
            return False
        # Check that all convolutions have same weights shapes
        if not np.array_equal(conv.in_node(1).shape, gconv.in_node(1).shape):
            log.debug(
                'Grouped convolutions fusion : convolutions have different weights shape'
            )
            return False

    # Check that split and concat dims are valid
    channel_dim = gconv.channel_dims[0]
    if channel_dim != start_node.axis or channel_dim != last_node.axis:
        log.debug(
            'Grouped convolutions fusion : split or concat has weird axis!')
        return False

    # Check that all convolutions have the same parameters
    conv_attrs = ['pad', 'stride']
    for attr in conv_attrs:
        for id in range(len(conv_nodes)):
            conv = conv_nodes[id]
            if not np.array_equal(gconv[attr], conv[attr]):
                log.debug(
                    'Grouped convolutions fusion : attr {} does not match'.
                    format(attr))
                return False

    # Check that all convolutions have biases (if any exist)
    has_biases = False
    for id in range(len(conv_nodes)):
        conv = conv_nodes[id]
        if len(conv.in_nodes()) == 3:
            if not has_biases:
                has_biases = True
        elif has_biases:
            return False  # All convolutions must have biases

    # Check that all biases have same shape
    if has_biases:
        for id in range(len(conv_nodes)):
            conv = conv_nodes[id]
            if not np.array_equal(conv.in_node(2).shape, gconv.in_node(2).shape):
                log.debug(
                    'Group convolutions fusion : convolutions have different biases shape {} and {}'
                    .format(conv.in_node(2).shape,
                            gconv.in_node(2).shape))
                return False

    graph.remove_edge(gconv.in_node(0).id, gconv.id)
    graph.remove_edge(gconv.id, gconv.out_node().id)

    input = start_node.in_node(start_node.input_port)
    output = last_node.out_node()

    # Removing edges from data nodes to Split and Concat
    graph.remove_edge(input.id, start_node.id)
    graph.remove_edge(last_node.id, output.id)

    # Add edges to grouped convolution
    graph.add_edges_from([(input.id, gconv.id, {
        'in': 0
    }), (gconv.id, output.id, {
        'out': 0
    })])

    # Concatenation of convolutions
    weights_node = gconv.in_node(1)
    bias_node = gconv.in_node(2) if has_biases else None

    weights_value = np.array(weights_node.value)
    bias_value = np.array(bias_node.value) if has_biases else None

    feature_dim = 3 if graph.graph['layout'] == 'NHWC' else 1

    for conv in conv_nodes[1:]:
        weights_value = np.concatenate((weights_value, conv.in_node(1).value),
                                       axis=feature_dim)
        if has_biases:
            bias_value = np.concatenate((bias_value, conv.in_node(2).value),
                                        axis=-1)  # Not validated

    weights_node.value = np.array(weights_value)
    weights_node.shape = np.array(weights_value.shape)

    if has_biases:
        bias_node.value = np.array(bias_value)
        bias_node.shape = np.array(bias_value.shape)

    log.debug('Start node : {} Last node : {}  Nodes inside : {}'.format(
        start_node.id, last_node.id, len(start_node.out_nodes())))
    log.debug('Output shape : {}'.format(weights_value.shape))

    gconv.group = len(conv_nodes)
    gconv.output = weights_node.shape[feature_dim]
    gconv.output_shape[feature_dim] = weights_node.shape[feature_dim]

    return True
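For intuition on the weight concatenation above: with an NHWC graph (TF-style HWIO kernels assumed), `feature_dim == 3` is the output-channel axis, so stacking the per-branch kernels along it yields one kernel whose output-channel count is the sum over the branches, which is what the resulting grouped convolution expects:

import numpy as np

# Hypothetical per-branch kernels of a 2-group convolution, HWIO layout assumed.
w0 = np.zeros((3, 3, 8, 16))
w1 = np.zeros((3, 3, 8, 16))

fused = np.concatenate((w0, w1), axis=3)   # feature_dim == 3 for NHWC graphs
print(fused.shape)                         # (3, 3, 8, 32) -> 32 output channels, group = 2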