예제 #1
0
def _fuse_linear_node(graph, node, backward: bool):
    """
    Try to fuse one Mul/Add node into neighbouring Convolution/Deconvolution/FC layers.

    :param graph: graph being transformed (modified in place by the fuse helpers).
    :param node: candidate node; only Mul/Add nodes that have a value input
        (get_value_id is not None) and can_be_fused == True are processed.
    :param backward: True to fuse into layers found by backward_bfs,
        False to fuse into layers found by forward_bfs.
    :return: the result of _fuse_mul/_fuse_add, or False when node is not a candidate.
    """
    op = node.soft_get('op')
    if op not in ('Mul', 'Add'):
        return False
    # Explicit comparison with True (not plain truthiness) is kept on purpose.
    if get_value_id(node) is None or node.soft_get('can_be_fused') != True:
        return False

    fuse = _fuse_mul if op == 'Mul' else _fuse_add

    if backward:
        fuse_nodes = backward_bfs(node, [], ['Convolution', 'Deconvolution', 'FullyConnected'])
        return fuse(graph, node, fuse_nodes)

    if op == 'Mul':
        targets = ['Convolution', 'Deconvolution', 'FullyConnected']
    else:
        # NOTE(review): forward Add fusion is limited to FullyConnected; presumably an
        # Add feeding a (De)Convolution cannot be folded safely (e.g. because of
        # padding) — confirm before widening this list.
        targets = ['FullyConnected']
    fuse_nodes = forward_bfs(node, [], targets)
    return fuse(graph, node, fuse_nodes, False)


def fuse_linear_ops(graph: nx.MultiDiGraph):
    """
    This function makes fusing of linear operations (Mul,Add) to Convolution/FC.

    Two sweeps are made:
      * backward direction, over pseudo_topological_sort(graph): Mul/Add are fused
        into Convolution/Deconvolution/FullyConnected layers found by backward_bfs;
      * forward direction, over pseudo_topological_sort(graph, reverse=True): Mul is
        fused into Convolution/Deconvolution/FullyConnected, Add only into
        FullyConnected, using forward_bfs.

    :param graph: graph to transform in place.
    :return: number of fused nodes (also logged at debug level).
    """
    fuse_count = 0

    # Fusion in backward direction
    for idx in pseudo_topological_sort(graph):
        # The id list is computed up front; a previous fusion may have removed this node.
        if idx not in graph:
            continue
        fuse_count += _fuse_linear_node(graph, Node(graph, idx), backward=True)

    # Fusion in forward direction
    for idx in pseudo_topological_sort(graph, reverse=True):
        if idx not in graph:
            continue
        fuse_count += _fuse_linear_node(graph, Node(graph, idx), backward=False)

    log.debug("Fused {} nodes".format(fuse_count))
    return fuse_count
예제 #2
0
    def find_and_replace_pattern(self, graph: nx.MultiDiGraph):
        """
        Turn FakeQuantWithMinMaxVars / Quantize nodes into per-channel statistics.

        For each such node, the scalar low/high values from inputs 1 and 2 are each
        repeated C times as a ', '-joined string. C is dimension 1 of the producer's
        output shape. The producer is reached through node.in_node().in_node(). The
        strings are stored as {'min': ..., 'max': ...} under the producer's id.

        Quantize nodes are only used when inputs 1/2 equal inputs 3/4 element-wise.
        The collected intervals are merged into graph.graph['statistics']. When at
        least one interval was collected, every FakeQuantWithMinMaxVars and Quantize
        op node is removed via remove_op_nodes.

        :param graph: graph to transform in place.
        """
        intervals = {}
        for n in pseudo_topological_sort(graph):
            node = Node(graph, n)
            if not node.has('op') or node.op not in ('FakeQuantWithMinMaxVars', 'Quantize'):
                continue
            if node.op == 'Quantize':
                # check if input range matches output range
                low_match = np.all(node.in_node(1).value == node.in_node(3).value)
                high_match = np.all(node.in_node(2).value == node.in_node(4).value)
                if not low_match or not high_match:
                    # NOTE(review): this node is skipped here, but the
                    # remove_op_nodes(graph, {'op': 'Quantize'}) call below removes
                    # every Quantize op regardless — confirm that is intended.
                    continue

            prev_node = node.in_node().in_node()
            # assumes channels are on axis 1 (e.g. NCHW) — TODO confirm for other layouts
            channels = prev_node.out_node()['shape'][1]
            low_value = node.in_node(1).value
            high_value = node.in_node(2).value
            assert low_value.size == 1
            assert high_value.size == 1
            intervals[prev_node.id] = {
                'min': ', '.join([str(low_value.flatten()[0])] * channels),
                'max': ', '.join([str(high_value.flatten()[0])] * channels),
            }

        if intervals:
            if 'statistics' not in graph.graph:
                graph.graph['statistics'] = intervals
            else:
                graph.graph['statistics'].update(intervals)
            remove_op_nodes(graph, {'op': 'FakeQuantWithMinMaxVars'})
            remove_op_nodes(graph, {'op': 'Quantize'})
예제 #3
0
 def find_and_replace_pattern(self, graph: Graph):
     """
     Erase every Switch op node from the graph.

     For each Switch op, the edge from its input 1 data node is removed first.
     Presumably that input is the predicate, going by the original variable name
     pred_id. The op is then removed with remove_op_node_with_data_node.

     :param graph: graph to transform in place.
     """
     for n in pseudo_topological_sort(graph):
         # The id list is computed up front; skip ids removed earlier in this sweep.
         if n not in graph:
             continue
         attrs = graph.node[n]
         # .get(): op nodes are not guaranteed to carry an 'op' attribute.
         if attrs.get('kind') == 'data' or attrs.get('op') != 'Switch':
             continue
         switch_op_node = Node(graph, n)
         pred_id_data_node = switch_op_node.in_node(1)
         graph.remove_edge(pred_id_data_node.id, switch_op_node.id)
         remove_op_node_with_data_node(graph, switch_op_node)
예제 #4
0
def fuse_mul_add_sequence(graph: Graph):
    """
    Fuse chains of linear Mul/Add nodes until no further fusion is possible.

    Each sweep visits nodes in pseudo-topological order. Every Mul/Add node that has
    a value input and can_be_fused == True is handed to _fuse_linear_sequence,
    which locates and fuses the full sequence. Sweeps repeat while at least one
    fusion succeeded.

    :param graph: graph to transform in place.
    """
    fused_any = True
    while fused_any:
        fused_any = False
        for node_id in list(pseudo_topological_sort(graph)):
            # Skip ids that are no longer in the graph (e.g. removed by an earlier fusion).
            if node_id not in graph:
                continue
            candidate = Node(graph, node_id)
            eligible = (candidate.soft_get('op') in ['Mul', 'Add']
                        and get_value_id(candidate) is not None
                        and candidate.soft_get('can_be_fused') == True)
            if eligible:
                fused_any |= _fuse_linear_sequence(graph, candidate)
            break
예제 #5
0
def _stride_propagation(graph: Graph, spatial_dims):
    """
    Run stride propagation over every non-Const op node.

    Nodes are visited in reverse pseudo-topological order. For node types listed in
    supported_ops, the registered 'attrs' are copied onto the node and the
    registered 'stride_prop' handler is called with True as its last argument.
    Every other node goes through _simple_stride_prop with False.

    :param graph: graph to process in place.
    :param spatial_dims: indices of the spatial axes, forwarded to the handlers.
    """
    all_nodes = [Node(graph, node_id) for node_id in pseudo_topological_sort(graph, reverse=True)]
    op_nodes = [n for n in all_nodes if n.kind == 'op' and n.soft_get('type') != 'Const']

    for op_node in op_nodes:
        if op_node.soft_get('type') not in supported_ops:
            _simple_stride_prop(graph, op_node, spatial_dims, False)
            continue

        descriptor = supported_ops[op_node.type]
        # Copy the op-specific attributes onto the node before propagating.
        for attr_name, attr_value in descriptor['attrs'].items():
            op_node[attr_name] = attr_value
        descriptor['stride_prop'](graph, op_node, spatial_dims, True)
예제 #6
0
def mark_const_producer_nodes(graph: nx.MultiDiGraph):
    """
    Mark nodes that produce constant values.

    Every node first gets is_const_producer=True. Nodes are then visited in
    pseudo-topological order, and two rules clear the flag:
      * For each incoming edge with a truthy 'control_flow_edge' attribute, both
        endpoints are marked False.
      * If the node has no 'value' attribute, or its value is None, all its
        direct predecessors are marked False.

    :param graph: graph to operate on; the 'is_const_producer' node attribute is set in place.
    :return: None.
    """
    nx.set_node_attributes(G=graph, name='is_const_producer', values=True)

    for n in pseudo_topological_sort(graph):
        node = Node(graph, n)
        for src, dst, attrs in graph.in_edges(n, data=True):
            if attrs.get('control_flow_edge'):
                graph.node[src]['is_const_producer'] = False
                graph.node[dst]['is_const_producer'] = False

        if not node.has('value') or node.value is None:
            for src, _ in graph.in_edges(n):
                graph.node[src]['is_const_producer'] = False
예제 #7
0
def shape_inference(graph: Graph):
    """
    Re-infer shapes of nodes flagged with need_shape_inference.

    Nodes are visited in pseudo-topological order. For each flagged node, its
    output-port shapes are captured, node.infer(node) is run, and the shapes are
    captured again. The flag is then cleared.

    :param graph: graph to process in place.
    :raises Error: if a previously known (non-None) output shape differs from the
        shape produced by re-inference.
    """
    def output_shapes(op):
        return [port.data.get_shape() for port in op.out_ports().values()]

    for node_id in pseudo_topological_sort(graph):
        node = Node(graph, node_id)
        if not node.has_and_set('need_shape_inference'):
            continue

        shapes_before = output_shapes(node)
        node.infer(node)
        shapes_after = output_shapes(node)

        for before, after in zip(shapes_before, shapes_after):
            if before is None or np.array_equal(before, after):
                continue
            raise Error(
                "After partial shape inference were found shape collision for node {} (old shape: {}, new shape: {})"
                .format(node.name, before, after))

        node.need_shape_inference = False
예제 #8
0
def stride_optimization(graph: Graph):
    """
    Main entry point of the stride optimization pass.

    The spatial axes are chosen from graph.graph['layout']: [2, 3] for NCHW and
    [1, 2] for NHWC. Any other layout logs a warning and returns without changes.
    After _stride_propagation, infer is re-run on every node whose
    is_partial_inferred attribute equals False.

    :param graph: graph to process in place.
    """
    layout = graph.graph['layout']
    spatial_axes_by_layout = {
        'NCHW': [2, 3],
        'NHWC': [1, 2],
    }
    if layout not in spatial_axes_by_layout:
        log.warning('STRIDE PROP: layout {} is not supported'.format(layout))
        return

    _stride_propagation(graph, np.array(spatial_axes_by_layout[layout]))

    # Collect all pending nodes before inferring any of them.
    all_nodes = [Node(graph, node_id) for node_id in pseudo_topological_sort(graph)]
    pending = [node for node in all_nodes if node.soft_get('is_partial_inferred') == False]
    for node in pending:
        node.infer(node)
예제 #9
0
def _find_common_conv_consumer(convolutions):
    """
    Return the id of the single operation that all given convolutions feed, or None.

    None is returned when any convolution has can_be_fused == False, when any
    convolution does not have exactly one next operation, or when the
    convolutions feed different operations.
    """
    common_id = None
    for conv in convolutions:
        if conv.soft_get('can_be_fused') == False:
            return None
        conv_outputs = get_next_operation(conv)
        # Return before indexing: with zero outputs, conv_outputs[0] would raise IndexError.
        if len(conv_outputs) != 1:
            return None
        if common_id is None:
            common_id = conv_outputs[0].id
        elif conv_outputs[0].id != common_id:
            return None
    return common_id


def grouped_convolutions_fusing(graph: Graph):
    """
    Merge parallel Convolution/Deconvolution branches into a grouped convolution.

    The pass looks for an op node that meets all of these conditions:
      * it has more than one output data node and is not marked can_be_fused == False;
      * it has more than one next operation;
      * every next operation is a Convolution/Deconvolution;
      * all of those convolutions feed one common operation.

    Such a node is passed to concat_convolutions. After each successful fusion the
    sweep restarts, with graph_clean_up run before every sweep. The pass stops once
    a full sweep performs no fusion.

    :param graph: graph to transform in place.
    """
    while True:
        is_fused = False
        graph_clean_up(graph, ['TFCustomSubgraphCall', 'Shape'])
        for idx in pseudo_topological_sort(graph):
            node = Node(graph, idx)
            if node.kind != 'op' or len(node.out_nodes()) <= 1:
                continue
            if node.soft_get('can_be_fused') == False:
                continue

            next_nodes = get_next_operation(node)
            # Check that all operations after this one are Convolutions
            if len(next_nodes) <= 1 or not all(
                    _node.soft_get('type') in ['Convolution', 'Deconvolution'] for _node in next_nodes):
                continue

            # ...and that all those convolutions have the same output
            last_layer = _find_common_conv_consumer(next_nodes)
            if last_layer is None:
                continue

            is_fused = concat_convolutions(graph, node, Node(graph, last_layer))
            if is_fused:
                # The graph changed; restart the sweep with a fresh node order.
                break

        if not is_fused:
            break