예제 #1
0
파일: space.py 프로젝트: Wuqiman/FlexTensor
 def __init__(self, dim, total, allow_non_divisible='off'):
     """Build the split search space for dividing ``total`` into ``dim`` factors.

     The candidate splits come from ``any_factor_split(total, dim,
     allow_non_divisible=allow_non_divisible)`` and are kept in
     ``static_entities``; ``size`` is their count. ``directions`` lists every
     ordered pair ``(i, j)`` of distinct positions ``0 <= i, j < dim`` in
     row-major order, and ``num_direction`` is ``dim * (dim - 1)``.
     """
     super(SplitSpace, self).__init__()
     self.total = total
     self.allow_non_divisible = allow_non_divisible
     self.dim = dim
     candidates = any_factor_split(
         total, dim, allow_non_divisible=allow_non_divisible)
     self.static_entities = candidates
     self.size = len(candidates)
     self.num_direction = dim * (dim - 1)
     positions = range(self.dim)
     # Every ordered pair of distinct positions, outer index first.
     self.directions = [(src, dst)
                        for src in positions
                        for dst in positions
                        if src != dst]
     self.type_key = "split"
예제 #2
0
    s = tvm.create_schedule(outputs.op)
    schedule_yolo_conv_llvm(s, outputs, inputs, weight, config)

    arg_bufs = [inputs, weight, outputs]
    stmt = tvm.lower(s, arg_bufs, simple_mode=True)
    # print(stmt)
    dev_id = 0
    time_cost = _evaluate(s, arg_bufs, "llvm", dev_id, 10)
    print("Yolo conv17 use", time_cost, "ms\n")
    return time_cost


if __name__ == "__main__":
    # Sweep 3-way splits of the 1024-long rc axis whose middle factor is 1,
    # timing the Yolo conv for each and recording its throughput.
    config = Config()
    rc_split_lst = [split for split in any_factor_split(1024, 3)
                    if split[1] == 1]
    # NOTE(review): presumably 14x14x512 outputs, each doing 1024 multiplies
    # and 1023 adds over the rc axis — confirm against the conv's shape.
    flop = 14 * 14 * 512 * (1024 + 1023)
    record = []
    for ele in rc_split_lst:
        config.rc_factors = ele
        time_cost = try_yolo_conv(1, config)  # reported in ms
        # ms -> s, then FLOP/s -> GFLOP/s.
        record.append((ele, flop / (time_cost / 1e3) / 1e9))
    # Print the innermost rc factor of each split, then each throughput,
    # as two separate columns.
    for factors, _ in record:
        print(factors[2])
    for _, gflops in record:
        print(gflops)
예제 #3
0
def gemm_config(M, N, K, logits_dict, spatial_split_parts=4,
                reduce_split_parts=4, unroll_max_factor=10):
    """Choose a GEMM schedule ``Config`` by scoring candidates with MLPs.

    Every factorization returned by ``any_factor_split`` for the M, N and K
    extents, and every unroll candidate, becomes a feature row: the
    normalized candidate values concatenated with the matching logits. Each
    set of rows is scored by a ``model.MLP`` constructed here, and the
    argmax of each set selects the split / unroll written into the config.

    Args:
        M: extent of the first spatial axis (``y``).
        N: extent of the second spatial axis (``x``).
        K: extent of the reduction axis (``k``).
        logits_dict: mapping with keys ``"spatial"``, ``"reduce"`` and
            ``"unroll"``. ``["spatial"][0]`` / ``["spatial"][1]`` condition
            the y / x choices, ``["reduce"][0]`` the k choice and
            ``["unroll"]`` the unroll choice. Each is added to a
            ``(num_candidates, feature_size)`` zero tensor, where
            ``feature_size = len(logits_dict["spatial"][0])``, so it must
            broadcast to that shape.
        spatial_split_parts: number of factors each spatial axis is split into.
        reduce_split_parts: number of factors the reduction axis is split into.
        unroll_max_factor: unroll candidates are ``2**0 .. 2**unroll_max_factor``.

    Returns:
        ``Config(op_config, graph_config)`` holding the chosen y/x/k splits
        and unroll setting, together with a fixed graph config.
    """
    sy = any_factor_split(M, spatial_split_parts)
    sx = any_factor_split(N, spatial_split_parts)
    sk = any_factor_split(K, reduce_split_parts)
    # Unroll candidates are [i, 2**j] pairs; only i == 0 is enumerated.
    unroll = [[0, 2**j] for j in range(unroll_max_factor + 1)]

    def _rational(lst, max_val):
        # Divide every factor by the axis extent so features are relative.
        return torch.FloatTensor([[y / float(max_val) for y in x]
                                  for x in lst])

    nsy = _rational(sy, M)
    nsx = _rational(sx, N)
    nsk = _rational(sk, K)
    # Unroll features: i mapped through i / 2 + 0.5, the depth as log2.
    n_unroll = torch.FloatTensor([[x[0] / float(2) + 0.5, math.log2(x[1])]
                                  for x in unroll])

    # get logits
    spatial_logits = logits_dict["spatial"]
    reduce_logits = logits_dict["reduce"]
    unroll_logits = logits_dict["unroll"]

    # build classifiers; input width = candidate columns + logits width
    feature_size = len(spatial_logits[0])
    split_classifier = model.MLP(feature_size + spatial_split_parts)
    unroll_classifier = model.MLP(feature_size + 2)
    # k candidates have reduce_split_parts columns, so their classifier must
    # match that width. With the defaults both widths agree and the spatial
    # classifier is shared, exactly as before.
    if reduce_split_parts == spatial_split_parts:
        reduce_classifier = split_classifier
    else:
        reduce_classifier = model.MLP(feature_size + reduce_split_parts)

    def _choose(classifier, features, logits):
        # Append the logits to every candidate row, score all rows with the
        # classifier and return the index of the best-scoring one.
        context = torch.zeros([features.shape[0], feature_size]) + logits
        scores = classifier(torch.cat([features, context], dim=1))
        # NOTE(review): argmax without ``dim`` flattens ``scores``; this is a
        # row index only if the MLP emits one score per row — confirm.
        return torch.argmax(scores)

    # make choice (same classifier-call order as before)
    cy = _choose(split_classifier, nsy, spatial_logits[0])
    cx = _choose(split_classifier, nsx, spatial_logits[1])
    ck = _choose(reduce_classifier, nsk, reduce_logits[0])
    cu = _choose(unroll_classifier, n_unroll, unroll_logits)

    print(cy, cx, ck, cu)

    # print choice
    print("Print choice")
    print("split y =", sy[cy])
    print("split x =", sx[cx])
    print("split k =", sk[ck])
    print("unroll", unroll[cu])

    # make config
    op_config = [{
        "spatial": [sy[cy], sx[cx]],
        "reduce": [sk[ck]],
        "inline": [],
        "unroll": [unroll[cu]]
    }]
    graph_config = {"spatial": [], "reduce": [], "inline": [[0]], "unroll": []}
    return Config(op_config, graph_config)