def conv2d(block: Block,
           inputs,
           out_channels,
           kernel=(1, 1),
           stride=(1, 1),
           padding=(0, 0),
           groups=1,
           act="relu",
           is_exit=False):
    """
    Add a convolution operator to the end of given block.

    :param block: ios.Block
        The block to add the operator

    :param inputs: Sequence[Sequence[Value]]
        The inputs of the convolution. 'inputs' contains a list of terms. A term contains a list of values. The values
        in a term are added up. The terms are concatenated along with the channel dimension.

    :param out_channels:
        The number of output channels.

    :param kernel: Tuple[int, int], default (1, 1)
        The kernel size.

    :param stride: Tuple[int, int], default (1, 1)
        The stride size.

    :param padding: Tuple[int, int], default (0, 0)

    :param groups: int, default 1
        The number of groups. It must be a common factor of the input channels and output channels.

    :param act: str, default 'relu'
        The activation applied to the output of convolution.

    :param is_exit: boolean, default False
        Whether this operator is the exit operator of the block.

    :return: Value
        A value represents the output of the operator.
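
    Example (a sketch; assumes `v1` and `v2` are Values produced by earlier
    operators in `block` and have matching spatial sizes)::

        # [[v1], [v2]] concatenates v1 and v2 along the channel dimension;
        # [[v1, v2]] would instead add them up element-wise.
        out = conv2d(block, inputs=[[v1], [v2]], out_channels=64,
                     kernel=(3, 3), stride=(1, 1), padding=(1, 1))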
    """
    name = new_name()
    conv = Conv(name, name, inputs, out_channels, kernel, stride, padding,
                groups, act, None)
    conv.infer_shape()
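    # register this conv as a use of each input value: (user node, term index, value index)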
    for ti, term in enumerate(inputs):
        for vi, value in enumerate(term):
            value.node.uses.append((conv, ti, vi))
    if is_exit:
        block.exit_node = conv
    else:
        block.inner_nodes.append(conv)
    return Value(conv, 0, out_channels)
def relu_conv(block: Block,
              inputs,
              out_channels,
              kernel=(1, 1),
              stride=(1, 1),
              padding=(0, 0),
              groups=1,
              is_exit=False):
    """
    Add a compound operator that contains a Relu operator and a convolution operator at the end of given block.

    :param block: ios.Block
        The block to add the operator

    :param inputs: Sequence[Sequence[Value]]
        The inputs of the convolution. 'inputs' contains a list of terms. A term contains a list of values. The values
        in a term are added up. The terms are concatenated along with the channel dimension.

    :param out_channels: int
        The number of output channels of the separate convolution.

    :param kernel: Tuple[int, int], default (1, 1)
        The kernel size of the convolution.

    :param stride: Tuple[int, int], default (1, 1)
        The stride of the convolution.

    :param padding: Tuple[int, int], default (0, 0)
        The padding size of the convolution.

    :param is_exit: boolean, default False
        Whether this operator is the exit operator of the block.

    :return: Value
        A value represents the output of the operator.
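
    Example (a sketch; assumes `v` is a Value produced by an earlier operator
    in `block`)::

        out = relu_conv(block, inputs=[[v]], out_channels=128,
                        kernel=(1, 1), stride=(1, 1), padding=(0, 0))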
    """
    names = [new_name(), new_name()]
    nodes = [
        Relu(names[0], names[0], inputs, None),
        Conv(names[1],
             names[1],
             None,
             out_channels=out_channels,
             kernel=kernel,
             stride=stride,
             padding=padding,
             groups=groups,
             act="identity",
             output_shape=None),
    ]
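    # the conv's inputs are None: `sequential` chains it to the Relu's output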
    return sequential(block,
                      hint_name='ReluConv',
                      nodes=nodes,
                      is_exit=is_exit)
def latency(stage: Tuple[List[List[int]], str], block, merge_latency, parallel_latency, cost_model, idn, nid,
            batch_size, warmup, number, repeat) -> float:
    """
    Measure the latency of a stage.
    """
    stage_seqs, qtype = stage
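    # encode the stage's node set as a bitmask; it keys the latency memo tables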
    ss = sum(1 << u for u in itertools.chain(*stage_seqs))
    if qtype == 'merge':
        if ss in merge_latency:
            return merge_latency[ss]
        snodes = state2nset(ss, idn)
        if len(stage_seqs) == 1:
            assert len(snodes) == 1
            merge_latency[ss] = float(
                np.mean(cost_model.get_stage_latency([[snodes[0]]], batch_size, warmup, number, repeat)))
        else:
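            # build a proxy conv with the merged parameters and measure it as the fused stage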
            convs = [nd for nd in snodes if isinstance(nd, Conv)]
            assert len(convs) == len(snodes)
            terms = get_input(convs, block, nid, idn)
            out_channels = sum(nd.out_channels for nd in convs)
            kernel = (max(nd.kernel[0] for nd in convs), max(nd.kernel[1] for nd in convs))
            stride = (convs[0].stride[0], convs[0].stride[1])
            padding = (max(nd.padding[0] for nd in convs), max(nd.padding[1] for nd in convs))
            groups = convs[0].groups
            act = convs[0].act
            conv = Conv('c', '', inputs=terms, out_channels=out_channels, kernel=kernel, stride=stride, padding=padding,
                        groups=groups, act=act, output_shape=None)
            merge_latency[ss] = float(
                np.mean(cost_model.get_stage_latency([[conv]], batch_size, warmup, number, repeat)))
        return merge_latency[ss]
    elif qtype == 'parallel':
        if ss in parallel_latency:
            return parallel_latency[ss]
        stage_seqs_nodes = []
        for seq in stage_seqs:
            seq_nodes = []
            for uid in seq:
                seq_nodes.append(idn[uid])
            stage_seqs_nodes.append(seq_nodes)
        parallel_latency[ss] = float(
            np.mean(cost_model.get_stage_latency(stage_seqs_nodes, batch_size, warmup, number, repeat)))
        return parallel_latency[ss]
    else:
        raise ValueError(f"unknown stage type: {qtype}")
def construct(stage_list: List[Tuple[List[List[int]], str]], block, constructed_blocks,
              graph_enter, idn, nid, compute_weight) -> Block:
    """
    Construct the optimized computation graph.
    """
    inner_nodes = []
    stages = []
    if len(constructed_blocks) == 0:
        new_enter_node = graph_enter
    else:
        new_enter_node = constructed_blocks[-1].exit_node
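    # out_dict maps each original node to (new_node, begin, end): the node that now
    # produces its output and that output's channel range within the new node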
    out_dict = {
        block.enter_node: (new_enter_node, 0, new_enter_node.output_shape[0])
    }

    def merge_inputs(inputs: List[List[Value]]):
        # Merge adjacent single-value terms that cover contiguous channel ranges
        # of the same node into one value, e.g. [v[0:32]], [v[32:64]] -> [v[0:64]].
        while True:
            merged = False
            for i in range(1, len(inputs)):
                if len(inputs[i - 1]) > 1 or len(inputs[i]) > 1:
                    continue
                va, vb = inputs[i - 1][0], inputs[i][0]
                if va.node == vb.node and va.end == vb.begin:
                    vc = Value(va.node, va.begin, vb.end)
                    inputs = inputs[:i - 1] + [[vc]] + inputs[i + 1:]
                    merged = True
                    break
            if not merged:
                break
        return inputs

    def get_new_terms(terms, new_node, do_sort=True):
        nterms = []
        for ti, term in enumerate(terms):
            nterm = []
            for vi, value in enumerate(term):
                nv = out_dict[value.node]
                nterm.append(
                    Value(nv[0], nv[1] + value.begin, nv[1] + value.end))
            nterms.append(nterm)
        if do_sort:
            nterms = sorted(nterms,
                            key=lambda nterm: (len(nterm), nterm[0].node.name)
                            )  # get rid of duplicate terms
        for ti, term in enumerate(nterms):
            for vi, value in enumerate(term):
                value.node.uses.append((new_node, ti, vi))
        return nterms

    def copy_weights(dst_node, src_node):
        if isinstance(dst_node, Conv):
            assert isinstance(src_node, Conv)
            dst_node.weight = src_node.weight.copy()
            dst_node.bias = src_node.bias.copy()

    for stage_seqs, qtype in stage_list:
        if qtype == 'merge' and len(stage_seqs) > 1:  # only merge convolutions
            inodes = list(itertools.chain(*stage_seqs))
            snodes = [
                nd for nd in [idn[i] for i in inodes] if isinstance(nd, Conv)
            ]
            assert len(snodes) == len(inodes)
            #  get the parameters of merged conv
            out_channels = sum(nd.out_channels for nd in snodes)
            kernel = (max(nd.kernel[0] for nd in snodes),
                      max(nd.kernel[1] for nd in snodes))
            stride = (snodes[0].stride[0], snodes[0].stride[1])
            padding = (max(nd.padding[0] for nd in snodes),
                       max(nd.padding[1] for nd in snodes))
            groups = 1
            #  construct merged conv
            terms = get_input(snodes, block, nid, idn)
            new_node = Conv(snodes[0].name, " ".join(nd.hint_name
                                                     for nd in snodes), None,
                            out_channels, kernel, stride, padding, groups,
                            snodes[0].act, None)
            new_node.inputs = get_new_terms(terms, new_node)
            new_node.infer_shape()
            if compute_weight:
                new_node.weight = np.zeros(shape=new_node.weight_shape,
                                           dtype=np.float32)
                new_node.bias = np.zeros(shape=new_node.bias_shape,
                                         dtype=np.float32)
            #  set weights and out_dict
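            #  Each source conv takes a contiguous output-channel slice of the merged
            #  conv; its weights are copied to the matching input-channel offset and
            #  centered within the (possibly larger) merged kernel.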
            out_begin = 0
            for node in snodes:
                for term in node.inputs:
                    in_begin = 0
                    for ti, t in enumerate(terms):
                        found = True
                        if len(term) != len(t):
                            found = False
                        else:
                            for va, vb in zip(term, t):
                                if not (va.node == vb.node and va.begin
                                        == vb.begin and va.end == vb.end):
                                    found = False
                                    break
                        if found:
                            break
                        in_begin += t[0].length
                    in_end = in_begin + term[0].length
                    out_end = out_begin + node.out_channels
                    kernel_begin = ((kernel[0] - node.kernel[0]) // 2,
                                    (kernel[1] - node.kernel[1]) // 2)
                    kernel_end = (kernel_begin[0] + node.kernel[0],
                                  kernel_begin[1] + node.kernel[1])
                    if compute_weight:
                        new_node.weight[out_begin:out_end, in_begin:in_end,
                                        kernel_begin[0]:kernel_end[0],
                                        kernel_begin[1]:kernel_end[1]] = node.weight.copy()
                        new_node.bias[out_begin:out_end] = node.bias.copy()
                out_dict[node] = (new_node, out_begin,
                                  out_begin + node.out_channels)
                out_begin += node.out_channels
            new_node.inputs = merge_inputs(new_node.inputs)
            new_node.infer_shape()
            inner_nodes.append(new_node)
            stages.append([[new_node.name]])
        else:
            seq_in_stage = []
            for seq in stage_seqs:
                inodes = seq
                snodes = [idn[i] for i in inodes]
                new_nodes = []
                for snode in snodes:
                    snode_config = snode.export_config()
                    if isinstance(snode, Sequential):
                        snode_config["nodes"][0]["inputs"] = []
                        new_node = Node.from_config(snode_config, {})
                        new_node.nodes[0].inputs = merge_inputs(
                            get_new_terms(snode.nodes[0].inputs,
                                          new_node,
                                          do_sort=False))
                        new_node.inputs = new_node.nodes[0].inputs
                        if compute_weight:
                            for dst_nd, src_nd in zip(new_node.nodes,
                                                      snode.nodes):
                                copy_weights(dst_nd, src_nd)
                        new_node.infer_shape()
                        new_nodes.append(new_node)
                        out_dict[snode] = (new_node, 0,
                                           new_node.output_shape[0])
                    else:
                        snode_config["inputs"] = []
                        new_node = Node.from_config(snode_config, {})
                        new_node.inputs = merge_inputs(
                            get_new_terms(snode.inputs,
                                          new_node,
                                          do_sort=False))
                        if compute_weight:
                            copy_weights(new_node, snode)
                        new_node.infer_shape()
                        new_nodes.append(new_node)
                        out_dict[snode] = (new_node, 0,
                                           new_node.output_shape[0])
                inner_nodes.extend(new_nodes)
                seq_in_stage.append([new_node.name for new_node in new_nodes])
            stages.append(seq_in_stage)
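    # the last constructed node becomes the block's exit; the others remain inner nodes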
    new_exit_node = inner_nodes.pop()
    return Block(new_enter_node, new_exit_node, inner_nodes, stages)
def sep_conv(block: Block,
             inputs: List[List[Value]],
             out_channels,
             kernel,
             stride,
             padding,
             is_exit=False):
    """
    Add a separate convolution at the end of given block. This operator is a compound operator consists of a depth-wise
    convolution and a point-wise convolution.

    :param block: ios.Block
        The block to add the operator

    :param inputs: Sequence[Sequence[Value]]
        The inputs of the convolution. 'inputs' contains a list of terms. A term contains a list of values. The values
        in a term are added up. The terms are concatenated along with the channel dimension.

    :param out_channels:
        The number of output channels of the point-wise convolution.

    :param kernel:
        The kernel size of the depth-wise convolution.

    :param stride:
        The stride size of the depth-wise convolution.

    :param padding:
        The padding size of the depth-wise convolution.

    :param is_exit: boolean, default False
        Whether this operator is the exit operator of the block.

    :return: Value
        A value represents the output of the operator.
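
    Example (a sketch; assumes `v` is a Value produced by an earlier operator
    in `block`)::

        out = sep_conv(block, inputs=[[v]], out_channels=256,
                       kernel=(3, 3), stride=(1, 1), padding=(1, 1))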
    """
    names = [new_name(), new_name()]
    in_channels = sum(term[0].node.output_shape[0] for term in inputs)
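    # depth-wise conv: groups == in_channels convolves each input channel separately;
    # the 1x1 point-wise conv (inputs=None, chained by `sequential`) then mixes channels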
    nodes = [
        Conv(names[0],
             names[0],
             inputs,
             out_channels=in_channels,
             kernel=kernel,
             stride=stride,
             padding=padding,
             groups=in_channels,
             act="identity",
             output_shape=None),
        Conv(names[1],
             names[1],
             None,
             out_channels=out_channels,
             kernel=(1, 1),
             stride=(1, 1),
             padding=(0, 0),
             groups=1,
             act="identity",
             output_shape=None)
    ]
    return sequential(block, hint_name='SepConv', nodes=nodes, is_exit=is_exit)