import torch


def get_outside_index(length, level, offset_cache=None, cuda=False):
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    index = OutsideIndex()
    pairs = index.get_all_pairs(level, length)

    par_lvl, par_pos = [], []
    sis_lvl, sis_pos = [], []
    for pair in pairs:
        par, sis = pair
        # Convert absolute (level, end) coordinates into (level, position).
        par_lvl.append(par[0])
        par_pos.append(par[1] - par[0])
        sis_lvl.append(sis[0])
        sis_pos.append(sis[1] - sis[0])

    device = torch.cuda.current_device() if cuda else None

    # Parent indices into the flattened chart.
    index = []
    for lvl, pos in zip(par_lvl, par_pos):
        offset = offset_cache[lvl]
        idx = offset + pos
        index.append(idx)
    par_index = torch.tensor(index, dtype=torch.long, device=device)

    # Sibling indices into the flattened chart.
    index = []
    for lvl, pos in zip(sis_lvl, sis_pos):
        offset = offset_cache[lvl]
        idx = offset + pos
        index.append(idx)
    sis_index = torch.tensor(index, dtype=torch.long, device=device)

    return par_index, sis_index

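# Usage sketch (illustrative, not part of the original module): the index
# tensors above are typically consumed with `index_select` to gather parent
# and sibling rows from a flattened chart. `chart` here is an assumed tensor
# of shape (batch, n_cells, dim); the helper name is hypothetical.
def _demo_gather_outside(chart, par_index, sis_index):
    # par_index / sis_index: 1-D LongTensors aligned element-wise, so row i
    # of `parents` pairs with row i of `siblings`.
    parents = chart.index_select(1, par_index)
    siblings = chart.index_select(1, sis_index)
    return parents, siblings
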
def get_inside_index_unique(length, level, offset_cache=None, cuda=False):
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    index = InsideIndex()
    pairs = index.get_all_pairs(level, length)
    L = length - level
    n_constituents = len(pairs) // L

    # Collect the set of chart cells consulted by any (left, right) split,
    # with each cell appearing exactly once.
    idx_set = set()
    for i in range(n_constituents):
        lvl_l = i
        lvl_r = level - i - 1
        lstart, lend = 0, L
        rstart, rend = length - L - lvl_r, length - lvl_r
        if lvl_l < 0:
            lvl_l = length + lvl_l
        if lvl_r < 0:
            lvl_r = length + lvl_r
        for pos in range(lstart, lend):
            offset = offset_cache[lvl_l]
            idx = offset + pos
            idx_set.add(idx)
        for pos in range(rstart, rend):
            offset = offset_cache[lvl_r]
            idx = offset + pos
            idx_set.add(idx)

    device = torch.cuda.current_device() if cuda else None
    idx_lst = torch.tensor(list(idx_set), dtype=torch.long, device=device).flatten()

    return idx_lst

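# Hedged sketch: the deduplicated index set above gathers every child cell
# consulted at this level exactly once, e.g. for a one-shot read of all
# relevant rows. `chart` is an assumed (batch, n_cells, dim) tensor.
def _demo_read_unique(chart, idx_lst):
    # Returns (batch, len(idx_lst), dim) with no duplicated rows.
    return chart.index_select(1, idx_lst)
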
def get_outside_target(length, level, offset_cache=None, cuda=False):
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    L = length - level
    offset = offset_cache[level]

    # Each of the L spans at this level has L - 1 outside candidates, so the
    # target row indices are repeated once per candidate.
    target = []
    for i in range(L - 1):
        target.extend(range(offset, offset + L))

    device = torch.cuda.current_device() if cuda else None
    target = torch.tensor(target, dtype=torch.long, device=device)

    return target

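# Hedged sanity sketch: `target` has one entry per outside candidate, so
# per-candidate scores can be accumulated onto their target cells. `scores`
# is an assumed (batch, L * (L - 1)) tensor and the helper is hypothetical.
def _demo_scatter_outside(scores, target, n_cells):
    batch = scores.shape[0]
    out = torch.zeros(batch, n_cells, dtype=scores.dtype, device=scores.device)
    # index_add_ accumulates every candidate's score onto its target cell.
    out.index_add_(1, target, scores)
    return out
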
def get_inside_components(length, level, offset_cache=None):
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    index = InsideIndex()
    pairs = index.get_all_pairs(level, length)
    L = length - level
    n_constituents = len(pairs) // L

    output = []
    for i in range(n_constituents):
        index_l, index_r = [], []
        span_x, span_l, span_r = [], [], []

        l_level = i
        r_level = level - l_level - 1
        l_start = 0
        l_end = L
        r_start = length - L - r_level
        r_end = length - r_level

        if l_level < 0:
            l_level = length + l_level
        if r_level < 0:
            r_level = length + r_level

        # The span being targeted.
        for pos in range(l_start, l_end):
            span_x.append((level, pos))

        # The left child.
        for pos in range(l_start, l_end):
            offset = offset_cache[l_level]
            idx = offset + pos
            index_l.append(idx)
            span_l.append((l_level, pos))

        # The right child.
        for pos in range(r_start, r_end):
            offset = offset_cache[r_level]
            idx = offset + pos
            index_r.append(idx)
            span_r.append((r_level, pos))

        output.append((index_l, index_r, span_x, span_l, span_r))

    return output

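# Hedged usage sketch: the component lists above can be turned into index
# tensors so each split's left and right children are gathered in one call.
# `chart` is an assumed (batch, n_cells, dim) tensor; names are illustrative.
def _demo_inside_gather(chart, components, device=None):
    for index_l, index_r, span_x, span_l, span_r in components:
        l_index = torch.tensor(index_l, dtype=torch.long, device=device)
        r_index = torch.tensor(index_r, dtype=torch.long, device=device)
        left = chart.index_select(1, l_index)    # (batch, L, dim)
        right = chart.index_select(1, r_index)   # (batch, L, dim)
        yield left, right
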
def get_outside_encoded_index(length, offset_cache=None, cuda=False):
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    index = OutsideIndex()
    outside_leaf = {}
    node2leaf = index.get_leaf(length)
    device = torch.cuda.current_device() if cuda else None

    # Map each node's flattened chart index to a tensor of its leaf indices.
    for (lvl, pos), leaf in node2leaf.items():
        offset = offset_cache[lvl]
        idx = offset + pos
        leaf = torch.tensor(leaf, dtype=torch.long, device=device)
        outside_leaf[idx] = leaf

    return outside_leaf

def get_outside_components(length, level, offset_cache=None):
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    index = OutsideIndex()
    pairs = index.get_all_pairs(level, length)

    output = []
    for pair in pairs:
        par, sis = pair
        # Convert absolute (level, end) coordinates into (level, position).
        par_lvl = par[0]
        par_pos = par[1] - par[0]
        par_span = (par_lvl, par_pos)
        sis_lvl = sis[0]
        sis_pos = sis[1] - sis[0]
        sis_span = (sis_lvl, sis_pos)
        output.append((par_span, sis_span))

    return output

def get_topk_outside_index(length, level, K, offset_cache=None, cuda=False):
    # NOTE: K is not used in this body.
    if offset_cache is None:
        offset_cache = get_offset_cache(length)
    L = length - level          # Number of target spans at this level.
    N = length - level - 1      # Number of outside candidates per span.
    components = get_outside_components(length, level, offset_cache)

    p_info, s_info = [], []
    for i, (p_span, s_span) in enumerate(components):
        p_level, p_pos = p_span
        s_level, s_pos = s_span
        n_idx = i // L   # Candidate index.
        x_pos = i % L    # Target span position.
        p_idx = offset_cache[p_level] + p_pos
        s_idx = offset_cache[s_level] + s_pos
        p_info.append((x_pos, n_idx, p_level, p_pos, p_idx))
        s_info.append((x_pos, n_idx, s_level, s_pos, s_idx))

    def sort_key(x):
        x_pos, n_idx, inp_level, inp_pos, inp_idx = x
        return (x_pos, n_idx)

    def get_val(x):
        x_pos, n_idx, inp_level, inp_pos, inp_idx = x
        return inp_idx

    # Group rows by target span first, then by candidate, so gathered
    # tensors can be viewed as (batch, L, N, ...).
    p_info = sorted(p_info, key=sort_key)
    s_info = sorted(s_info, key=sort_key)

    device = torch.cuda.current_device() if cuda else None
    p_index = torch.tensor([get_val(x) for x in p_info], dtype=torch.long, device=device)
    s_index = torch.tensor([get_val(x) for x in s_info], dtype=torch.long, device=device)

    return p_index, p_info, s_index, s_info

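# Hedged usage sketch: because p_info/s_info are sorted by (x_pos, n_idx),
# rows gathered with p_index/s_index group candidates by target span and can
# be reshaped to (batch, L, N, dim). `chart` is an assumed flattened chart
# tensor of shape (batch, n_cells, dim); the helper name is hypothetical.
def _demo_topk_outside_gather(chart, p_index, s_index, L, N):
    batch, _, dim = chart.shape
    parents = chart.index_select(1, p_index).view(batch, L, N, dim)
    siblings = chart.index_select(1, s_index).view(batch, L, N, dim)
    return parents, siblings
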
def func():
    # Fragment: `length` is not defined here, so as written this only works
    # as a closure with `length` bound in an enclosing scope.
    return get_offset_cache(length)