Example #1
0
    def replace_pattern(self, graph: Graph, match: dict):
        """
        Remove a matched ``sparse_reshape`` node that does not change the shape.

        The value on input port 1 (input shape) and on output port 1 (output
        shape) must both be known and equal. In that case:

        1. the op is disconnected from both of its input data nodes and both of
           its output data nodes;
        2. ``merge_data_nodes(graph, <output data>, <input data>)`` is called for
           port 0 and then for port 1;
        3. the op node and both input data nodes are deleted from the graph.

        :param graph: graph to modify in place
        :param match: pattern match; must contain the key 'sparse_reshape'
        :raises Error: if either shape value is undefined, or they differ
        """
        reshape_op = match['sparse_reshape']

        in_shape = reshape_op.in_port(1).data.get_value()
        out_shape = reshape_op.out_port(1).data.get_value()

        # Guard clauses: only a provably shape-preserving op may be removed.
        if in_shape is None or out_shape is None:
            raise Error(
                "Input shape and output shape values must be defined for node {}"
                .format(reshape_op.id))
        if not np.array_equal(in_shape, out_shape):
            raise Error(
                "Input shape and output shape values must be equal for node {}"
                .format(reshape_op.id))

        data_inputs = [reshape_op.in_node(port) for port in (0, 1)]
        data_outputs = [reshape_op.out_node(port) for port in (0, 1)]

        # Detach the op: input edge then output edge, port 0 first.
        for src, dst in zip(data_inputs, data_outputs):
            graph.remove_edge(src.id, reshape_op.id)
            graph.remove_edge(reshape_op.id, dst.id)

        # Fuse each output data node with the matching input data node.
        for src, dst in zip(data_inputs, data_outputs):
            merge_data_nodes(graph, dst, src)

        graph.remove_nodes_from([reshape_op.id] + [src.id for src in data_inputs])
Example #2
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Collapse chains of consecutive 'Permute' operations.

        Starting from every Permute node, the chain is extended while
        ``get_next_operation`` returns exactly one op and that op is a Permute.
        The chain's ``order`` attributes are composed into a single permutation:

        * identity result: the chain's input data node is disconnected from the
          first Permute, the chain's output data node is merged into it with
          ``merge_data_nodes`` and then removed;
        * non-identity result and at least two Permutes: the first Permute gets
          the composed order, its output data node is disconnected from its
          consumer, and the chain's output data node is merged into it and
          removed;
        * a single non-identity Permute is left untouched.

        ``graph_clean_up_tf`` is run after every rewrite (presumably to drop the
        ops left disconnected by it).

        :param graph: graph to transform in place
        :raises Error: if a Permute in a chain has no valid 'order' attribute
        """
        for node_id in list(graph.nodes()):
            # A previous rewrite may already have removed this node.
            if node_id not in graph.nodes():
                continue
            permute_node = Node(graph, node_id)
            if not (permute_node.has_valid('type')
                    and permute_node.type == 'Permute'):
                continue

            # Get sequence of permutations
            list_of_permutes = [permute_node]
            current = permute_node
            while True:
                next_ops = get_next_operation(current)
                if len(next_ops) != 1:
                    break
                next_op = next_ops[0]
                if not (next_op.has_valid('type') and next_op.type == 'Permute'):
                    break
                list_of_permutes.append(next_op)
                current = next_op

            # Validate every order before using any of them. The first order's
            # length is needed to build the identity, so an invalid order must
            # be reported here rather than failing earlier inside len(...).
            for permute in list_of_permutes:
                if not permute.has_valid('order'):
                    raise Error(
                        "Permute node {} has wrong attribute order = None".
                        format(permute.name))

            identity = np.arange(len(list_of_permutes[0].order), dtype=np.int64)

            # Compose the chain's orders, in chain order, into one permutation.
            final_permutation = identity
            for permute in list_of_permutes:
                final_permutation = final_permutation[np.array(
                    permute.order, dtype=np.int64)]

            if np.array_equal(final_permutation, identity):
                # The chain is a no-op: bypass it entirely.
                first_data_node = list_of_permutes[0].in_node()
                last_data_node = list_of_permutes[-1].out_node()
                graph.remove_edge(first_data_node.id, list_of_permutes[0].id)
            else:
                if len(list_of_permutes) < 2:
                    continue
                # Keep only the first Permute, carrying the composed order.
                first_data_node = list_of_permutes[0].out_node()
                last_data_node = list_of_permutes[-1].out_node()
                list_of_permutes[0].order = final_permutation
                graph.remove_edge(first_data_node.id,
                                  first_data_node.out_node().id)

            graph.remove_edge(last_data_node.in_node().id, last_data_node.id)

            merge_data_nodes(graph, first_data_node, last_data_node)
            graph.remove_node(last_data_node.id)
            graph_clean_up_tf(graph)
Example #3
0
    def find_and_replace_pattern(self, graph: Graph):
        """
        Collapse chains of consecutive 'Transpose' operations.

        Starting from every Transpose op, the chain is extended while
        ``get_next_operation`` returns exactly one op whose type is
        'Transpose'. Each Transpose's order is the value on its input port 1.
        The orders are composed into a single permutation:

        * identity result: the chain's input data node is disconnected from the
          first Transpose, the chain's output data node is merged into it with
          ``merge_data_nodes`` and then removed;
        * non-identity result and at least two Transposes: the first
          Transpose's order value is replaced by the composed one, its output
          data node is disconnected from its consumer, and the chain's output
          data node is merged into it and removed;
        * a single non-identity Transpose is left untouched.

        ``graph.clean_up()`` is run after every rewrite.

        :param graph: graph to transform in place
        :raises Error: if a Transpose in a chain has no order value (None)
        """
        for permute_node in graph.get_op_nodes(type='Transpose'):
            # A previous rewrite may already have removed this node.
            if permute_node.id not in graph.nodes():
                continue

            # Get sequence of permutations
            list_of_permutes = [permute_node]
            current = permute_node
            while True:
                next_ops = get_next_operation(current)
                if len(next_ops) != 1:
                    break
                next_op = next_ops[0]
                if next_op.soft_get('type') != 'Transpose':
                    break
                list_of_permutes.append(next_op)
                current = next_op

            # Fetch and validate every order before using any of them. The
            # first order's length is needed to build the identity, so a None
            # order must be reported here rather than failing inside len(None).
            orders = [permute.in_port(1).data.get_value()
                      for permute in list_of_permutes]
            for permute, order in zip(list_of_permutes, orders):
                if order is None:
                    raise Error(
                        "Transpose node {} has wrong order for permute = None".
                        format(permute.name))

            identity = np.arange(len(orders[0]), dtype=np.int64)

            # Compose the chain's orders, in chain order, into one permutation.
            final_permutation = identity
            for order in orders:
                final_permutation = final_permutation[int64_array(order)]

            if np.array_equal(final_permutation, identity):
                # The chain is a no-op: bypass it entirely.
                first_data_node = list_of_permutes[0].in_node()
                last_data_node = list_of_permutes[-1].out_node()
                graph.remove_edge(first_data_node.id, list_of_permutes[0].id)
            else:
                if len(list_of_permutes) < 2:
                    continue
                # Keep only the first Transpose, carrying the composed order.
                first_data_node = list_of_permutes[0].out_node()
                last_data_node = list_of_permutes[-1].out_node()
                # NOTE(review): this writes into the data node feeding port 1;
                # if that constant is shared with other Transposes they would
                # see the new order too -- confirm it is never shared.
                list_of_permutes[0].in_port(1).data.set_value(
                    final_permutation)
                graph.remove_edge(first_data_node.id,
                                  first_data_node.out_node().id)

            graph.remove_edge(last_data_node.in_node().id, last_data_node.id)

            merge_data_nodes(graph, first_data_node, last_data_node)
            graph.remove_node(last_data_node.id)
            graph.clean_up()
Example #4
0
def _fuse_linear_sequence(graph: nx.MultiDiGraph, start_node: Node):
    """
    Find the sequence of Mul/Add operations starting at ``start_node`` and
    replace it with at most two ops (Mul -> Add).

    The sequence grows while the current op's output data node has exactly one
    consumer, and that consumer:

    * is a 'Mul' or 'Add';
    * has a value input (``get_value_id`` is not None);
    * has ``can_be_fused`` equal to True.

    :param graph: graph containing the sequence; modified in place
    :param start_node: the first operation of the sequence
    :return: False when there is nothing to fuse (a single op, or exactly a
             Mul followed by an Add); True after the sequence was replaced
    """
    fnodes = [start_node]
    while True:
        data_node = fnodes[-1].out_node()
        if len(data_node.out_nodes()) != 1:
            break
        consumer = data_node.out_node()
        # Explicit comparison with True (not truthiness) is kept on purpose:
        # only a value equal to True marks the op as fusable.
        if (consumer.op in ['Mul', 'Add']
                and get_value_id(consumer) is not None
                and consumer.soft_get('can_be_fused') == True):  # noqa: E712
            fnodes.append(consumer)
        else:
            break

    if len(fnodes) == 1 or (len(fnodes) == 2 and fnodes[0].op == 'Mul'
                            and fnodes[1].op == 'Add'):
        return False

    input_shape = start_node.in_node(get_tensor_id(start_node)).shape

    # Number of size-1 dims for the accumulated constants:
    # rank - 2 for NCHW layout, 1 otherwise.
    init_dims_cnt = len(
        input_shape) - 2 if graph.graph['layout'] == 'NCHW' else 1

    mul = np.ones([1] * init_dims_cnt)
    add = np.zeros([1] * init_dims_cnt)

    first_mul_name = None
    first_add_name = None

    # Fold the sequence into a single  x * mul + add:
    # a Mul by c scales both terms, an Add of c only shifts the bias.
    for node in fnodes:
        const_port = get_value_id(node)
        if node.op == 'Mul':
            if first_mul_name is None:
                first_mul_name = node.name
            value = node.in_node(const_port).value
            mul = mul * value
            add = add * value
        elif node.op == 'Add':
            if first_add_name is None:
                first_add_name = node.name
            add = add + node.in_node(const_port).value

    # If mul is scalar we broadcast it to biases shape
    if mul.shape != add.shape and len(mul.shape) == 1 and mul.shape[0] == 1:
        mul = np.repeat(mul, add.shape[0])

    assert (np.array_equal(fnodes[0].in_node(get_tensor_id(fnodes[0])).shape,
                           fnodes[-1].out_node().shape))

    mul_node = Mul(
        graph,
        dict(name=(first_mul_name + '/Fused_Mul_')
             if first_mul_name is not None else ''))
    add_node = Add(
        graph,
        dict(name=(first_add_name + '/Fused_Add_')
             if first_add_name is not None else ''))

    in_node = fnodes[0].in_node(get_tensor_id(fnodes[0]))
    out_node = fnodes[-1].out_node()

    graph.remove_edge(in_node.id, fnodes[0].id)
    graph.remove_edge(fnodes[-1].id, out_node.id)

    # Remove deleted subgraph
    for node in fnodes:
        for tmp_node in node.in_nodes().values():
            # Remove node only if it has one consumer (for case with shared weights)
            if len(tmp_node.out_nodes()) == 1:
                graph.remove_node(tmp_node.id)
        for tmp_node in node.out_nodes().values():
            graph.remove_node(tmp_node.id)
        graph.remove_node(node.id)

    # Four cases considered below:
    #   1. Mul and Add have valid values (mul value != 1 and add value != 0)
    #   2. Only Mul has valid values, so we add only Mul node
    #   3. Only Add has valid values, so we add only Add node
    #   4. When neither Mul nor Add has valid values we just merge two data nodes
    has_mul = bool(np.any(mul != 1))
    has_add = bool(np.any(add != 0))

    if has_mul and has_add:
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        mul_out = mul_node.create_node_with_data([in_node, data_mul])
        add_node.create_node_with_data(inputs=[mul_out, data_add],
                                       data_nodes=out_node)
    elif has_mul:
        data_mul = Op.create_input_data_node(graph, "data_mul_", np.array(mul))
        mul_node.create_node_with_data(inputs=[in_node, data_mul],
                                       data_nodes=out_node)
    elif has_add:
        data_add = Op.create_input_data_node(graph, "data_add_", np.array(add))
        add_node.create_node_with_data(inputs=[in_node, data_add],
                                       data_nodes=out_node)
    else:
        merge_data_nodes(graph, out_node, in_node)
        graph.remove_node(in_node.id)

    log.debug('Fused %d operations', len(fnodes))
    return True
Example #5
0
def _broadcast_scale_shift_param(graph, node, param_node, input_node,
                                 broadcast_dims_cnt):
    """
    Prepare a ScaleShift parameter (scale or shift) to be fed into Mul/Add.

    If ``param_node`` has a constant value, it is expanded in place with
    ``Op.expand_node_shape(param_node, broadcast_dims_cnt)`` and returned.
    Otherwise a Reshape is inserted whose ``dim`` is 1 everywhere except at
    positions ``node.axis .. node.axis + len(param_node.shape) - 1``, which hold
    ``param_node.shape``. The Reshape's output data node is returned.

    :param graph: graph being transformed
    :param node: the ScaleShift op (its ``axis`` attribute is read)
    :param param_node: scale or shift data node of ``node``
    :param input_node: data input of ``node`` (its rank sizes the Reshape)
    :param broadcast_dims_cnt: dims to add to a constant parameter
    :return: the node to connect to the Mul/Add
    """
    if param_node.has_valid("value"):
        Op.expand_node_shape(param_node, broadcast_dims_cnt)
        return param_node

    # insert reshape to make shapes similar
    reshape_dims = np.zeros(len(input_node.shape), dtype=np.int64)
    for i in range(0, node.axis):
        reshape_dims[i] = 1
    for i in range(node.axis, node.axis + len(param_node.shape)):
        reshape_dims[i] = param_node.shape[i - node.axis]
    for i in range(node.axis + len(param_node.shape), len(input_node.shape)):
        reshape_dims[i] = 1
    reshape = Reshape(
        graph, dict(name=param_node.name + "/Broadcast_", dim=reshape_dims))
    return reshape.create_node_with_data(inputs=[param_node])


def convert_scale_shift_to_mul_add(graph: nx.MultiDiGraph):
    """
    Replace every 'ScaleShift' op with equivalent Mul and/or Add ops.

    ScaleShift ops whose ``can_be_fused`` is False are skipped. For the others:

    * biases present: input -> Mul(scale) -> Add(shift) -> output (the Mul is
      created even when the scale is all ones);
    * no biases but non-trivial weights: input -> Mul(scale) -> output;
    * neither: the output data node is merged into the input data node and
      removed.

    Biases are dropped when there is no third input or the shift is a constant
    of all zeros. Weights are dropped when the scale is a constant of all
    ones. Non-constant scale/shift values are kept and broadcast via a
    Reshape.

    :param graph: graph to transform in place
    """
    scale_shift_nodes = [
        candidate for candidate in (Node(graph, n) for n in graph.nodes())
        if candidate.soft_get('op') == 'ScaleShift'
    ]
    for node in scale_shift_nodes:
        if node.soft_get('can_be_fused') is False:
            continue

        input_node = node.in_node(0)
        scale_node = node.in_node(1)
        output_node = node.out_node()

        # We don't need zero biases. A bias without a constant value cannot be
        # proven zero, so it is kept (iterating its missing value used to raise).
        bias_candidate = node.in_node(2) if len(node.in_nodes()) >= 3 else None
        has_biases = bias_candidate is not None and not (
            bias_candidate.has_valid('value')
            and np.all(np.asarray(bias_candidate.value) == 0))
        shift_node = bias_candidate if has_biases else None

        # Unit weights are only detectable for a constant scale.
        has_weights = not (
            scale_node.has_valid("value")
            and np.all(np.asarray(scale_node.value) == 1))

        mul_node = Mul(graph, dict(name=node.name + "/Mul_"))
        add_node = Add(graph, dict(name=node.name + "/Add_"))

        # Disconnect ScaleShift node
        graph.remove_edge(input_node.id, node.id)
        graph.remove_edge(node.id, output_node.id)

        # Expand dims for current layout
        broadcast_dims_cnt = len(
            input_node.shape) - 2 if graph.graph['layout'] == 'NCHW' else 0
        scale_node = _broadcast_scale_shift_param(
            graph, node, scale_node, input_node, broadcast_dims_cnt)
        # Only touch the shift when it is actually used (it may be None).
        if has_biases:
            shift_node = _broadcast_scale_shift_param(
                graph, node, shift_node, input_node, broadcast_dims_cnt)

        # Connect input->mul->out->add->out
        if has_biases:
            mul_out = mul_node.create_node_with_data(
                inputs=[input_node, scale_node])
            add_node.create_node_with_data(inputs=[mul_out, shift_node],
                                           data_nodes=output_node)
        elif has_weights:
            mul_node.create_node_with_data(inputs=[input_node, scale_node],
                                           data_nodes=output_node)
        else:
            merge_data_nodes(graph, input_node, output_node)
            graph.remove_node(output_node.id)