Example #1
0
def get_outside_index(length, level, offset_cache=None, cuda=False):
    """Build parent and sibling index tensors for an outside-pass lookup.

    Args:
        length: Sentence length (chart width at level 0).
        level: Chart level whose outside pairs are gathered.
        offset_cache: Optional precomputed {level: flat offset} map; built
            from ``length`` when omitted.
        cuda: When True, place the tensors on the current CUDA device.

    Returns:
        Tuple ``(par_index, sis_index)`` of 1-D long tensors; entry i is the
        flattened chart position of the i-th parent / sibling node.
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    pairs = OutsideIndex().get_all_pairs(level, length)

    device = torch.cuda.current_device() if cuda else None

    def _flat_index(nodes):
        # A raw node (lvl, end) sits at chart slot offset_cache[lvl] + (end - lvl);
        # (end - lvl) is the node's position within its level.
        return torch.tensor(
            [offset_cache[lvl] + (end - lvl) for lvl, end in nodes],
            dtype=torch.long, device=device)

    # The original duplicated this construction for parent and sibling;
    # one helper covers both.
    par_index = _flat_index([par for par, _ in pairs])
    sis_index = _flat_index([sis for _, sis in pairs])

    return par_index, sis_index
Example #2
0
def get_inside_index_unique(length, level, offset_cache=None, cuda=False):
    """Collect the unique flattened chart indices of all children feeding
    the constituents at ``level``.

    Args:
        length: Sentence length.
        level: Chart level being computed.
        offset_cache: Optional {level: flat offset} map; built when omitted.
        cuda: When True, place the result on the current CUDA device.

    Returns:
        1-D long tensor of unique child indices, sorted ascending.  Sorting
        makes the ordering deterministic — ``list(set)`` order is otherwise
        unspecified (hash-dependent).
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    pairs = InsideIndex().get_all_pairs(level, length)

    L = length - level
    n_constituents = len(pairs) // L
    idx_set = set()

    for i in range(n_constituents):
        lvl_l = i
        lvl_r = level - i - 1
        lstart, lend = 0, L
        # NOTE: the right-child range is computed with the pre-wrap level,
        # matching get_inside_components.
        rstart, rend = length - L - lvl_r, length - lvl_r

        # Negative levels wrap around the chart.
        if lvl_l < 0:
            lvl_l = length + lvl_l
        if lvl_r < 0:
            lvl_r = length + lvl_r

        idx_set.update(offset_cache[lvl_l] + pos for pos in range(lstart, lend))
        idx_set.update(offset_cache[lvl_r] + pos for pos in range(rstart, rend))

    device = torch.cuda.current_device() if cuda else None
    # dtype spelled `torch.long` (== int64) for consistency with the sibling
    # helpers in this file; the result is already 1-D, so no flatten needed.
    return torch.tensor(sorted(idx_set), dtype=torch.long, device=device)
Example #3
0
def get_outside_target(length, level, offset_cache=None, cuda=False):
    """Return the flattened target indices for the outside pass at ``level``.

    The L = length - level consecutive indices of that level are tiled
    (L - 1) times into one 1-D long tensor.
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)

    L = length - level
    base = offset_cache[level]

    # Same span of L consecutive indices, repeated (L - 1) times.
    repeated = list(range(base, base + L)) * (L - 1)

    device = torch.cuda.current_device() if cuda else None
    return torch.tensor(repeated, dtype=torch.long, device=device)
Example #4
0
def get_inside_components(length, level, offset_cache=None):
    """Enumerate the child components of every constituent at ``level``.

    For each of the possible split points, returns a tuple
    ``(index_l, index_r, span_x, span_l, span_r)`` where the index lists hold
    flattened chart positions of the left/right children and the span lists
    hold (level, position) pairs for the target, left child, and right child.
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    all_pairs = InsideIndex().get_all_pairs(level, length)

    L = length - level
    n_splits = len(all_pairs) // L
    components = []

    for split in range(n_splits):
        left_lvl = split
        right_lvl = level - split - 1

        left_range = range(0, L)
        # Right-child range uses the pre-wrap level, as in the original.
        right_range = range(length - L - right_lvl, length - right_lvl)

        # Negative levels wrap around the chart.
        if left_lvl < 0:
            left_lvl += length
        if right_lvl < 0:
            right_lvl += length

        # The span being targeted.
        span_x = [(level, pos) for pos in left_range]

        # The left child.
        left_off = offset_cache[left_lvl]
        index_l = [left_off + pos for pos in left_range]
        span_l = [(left_lvl, pos) for pos in left_range]

        # The right child.
        right_off = offset_cache[right_lvl]
        index_r = [right_off + pos for pos in right_range]
        span_r = [(right_lvl, pos) for pos in right_range]

        components.append((index_l, index_r, span_x, span_l, span_r))

    return components
Example #5
0
def get_outside_encoded_index(length, offset_cache=None, cuda=False):
    """Map each flattened chart index to a long tensor of its outside leaves.

    Args:
        length: Sentence length.
        offset_cache: Optional {level: flat offset} map; built when omitted.
        cuda: When True, leaf tensors go on the current CUDA device.

    Returns:
        Dict of flat chart index -> 1-D long tensor of leaf ids, one entry
        per node reported by ``OutsideIndex.get_leaf``.
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    node2leaf = OutsideIndex().get_leaf(length)

    device = torch.cuda.current_device() if cuda else None

    return {
        offset_cache[lvl] + pos: torch.tensor(leaf, dtype=torch.long,
                                              device=device)
        for (lvl, pos), leaf in node2leaf.items()
    }
Example #6
0
def get_outside_components(length, level, offset_cache=None):
    """Return (parent_span, sibling_span) pairs for ``level``.

    Each span is a (level, position) tuple derived from the raw node
    representation (lvl, end), where position = end - lvl.
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    pairs = OutsideIndex().get_all_pairs(level, length)

    def as_span(node):
        # Convert a raw (lvl, end) node into a (level, position) span.
        lvl, end = node
        return (lvl, end - lvl)

    return [(as_span(par), as_span(sis)) for par, sis in pairs]
Example #7
0
def get_topk_outside_index(length, level, K, offset_cache=None, cuda=False):
    """Build sorted parent/sibling index tensors (plus metadata) for the
    top-k outside pass at ``level``.

    Args:
        length: Sentence length.
        level: Chart level.
        K: Unused here — kept for signature compatibility with callers.
            NOTE(review): the top-k selection presumably happens in the
            caller using the returned info lists; confirm.
        offset_cache: Optional {level: flat offset} map; built when omitted.
        cuda: When True, index tensors go on the current CUDA device.

    Returns:
        ``(p_index, p_info, s_index, s_info)`` where each info list holds
        ``(x_pos, n_idx, level, pos, flat_idx)`` records sorted by
        ``(x_pos, n_idx)``, and each index tensor is the long tensor of the
        records' flat indices in that order.
    """
    if offset_cache is None:
        offset_cache = get_offset_cache(length)

    L = length - level  # number of target positions at this level

    components = get_outside_components(length, level, offset_cache)

    p_info, s_info = [], []
    for i, (p_span, s_span) in enumerate(components):
        p_level, p_pos = p_span
        s_level, s_pos = s_span
        n_idx = i // L   # which (parent, sibling) component
        x_pos = i % L    # which target position
        p_idx = offset_cache[p_level] + p_pos
        s_idx = offset_cache[s_level] + s_pos

        p_info.append((x_pos, n_idx, p_level, p_pos, p_idx))
        s_info.append((x_pos, n_idx, s_level, s_pos, s_idx))

    # Group records by target position, then by component index.
    p_info.sort(key=lambda rec: rec[:2])
    s_info.sort(key=lambda rec: rec[:2])

    device = torch.cuda.current_device() if cuda else None

    def _index_tensor(info):
        # The last field of each record is the flattened chart index.
        return torch.tensor([rec[-1] for rec in info],
                            dtype=torch.long, device=device)

    return _index_tensor(p_info), p_info, _index_tensor(s_info), s_info
Example #8
0
 def func():
     # NOTE(review): fragment — `length` is not defined in this scope, and the
     # one-space indent suggests this was clipped from a larger context;
     # presumably it should accept `length` as a parameter. Confirm upstream.
     return get_offset_cache(length)