def _remove_node(tt: TourneyTree, pos: Position) -> int:
    """Remove the value held at ``pos`` and recompute the subtree rooted there.

    The removal recurses into ``tt.get_same_child(pos)`` (presumably the child
    holding the same value as ``pos``) down to a node with no children. That
    node is set to -1, and each node on the way back up is recomputed.

    Args:
        tt: the tournament tree, modified in place.
        pos: position whose current value is being removed.

    Returns:
        The value now stored at ``pos``, or -1 if its subtree has no values left.
    """
    children = tt.get_children(pos)
    if len(children) == 0:
        # Bottom of the tree: the value is removed outright; -1 marks "empty".
        tt[pos] = -1
        return -1
    if len(children) == 1:
        # Pass-through node: it just mirrors its only child.
        new_val = _remove_node(tt, children[0])
        tt[pos] = new_val
        return new_val

    same_child = tt.get_same_child(pos)
    # Read the other side before recursing (the recursion only touches same_child's subtree).
    other_val = tt[tt.get_sibling(same_child)]
    new_same_val = _remove_node(tt, same_child)

    if new_same_val == -1:
        new_val = other_val
    elif other_val == -1:
        # The other subtree is already exhausted, so the outcome is known.
        # Calling COMPARE against the -1 sentinel would waste a comparison
        # and could let -1 overwrite a live value.
        new_val = new_same_val
    else:
        new_val = new_same_val if COMPARE(new_same_val, other_val) else other_val
    tt[pos] = new_val
    return new_val
def _guess_k_smallest_2(tt: TourneyTree, k: int) -> [int]: # item: (row on, row of partner) row_scores = dict() row_scores[tt[tt.top()]] = (tt.height - 1, tt.height - 1) for row in range(tt.height - 2, -1, -1): for pos in tt.iter_row(row): if tt[pos] not in row_scores: row_scores[tt[pos]] = (row, row_scores[tt[tt.get_sibling(pos)]][0]) if len(row_scores) >= k: return sorted(row_scores, key=row_scores.__getitem__, reverse=True)[:k] return sorted(row_scores, key=row_scores.__getitem__, reverse=True)[:k]
def _populate_row(tt: TourneyTree, row: int):
    """Set every node in ``row`` from the values of its children.

    A node with no children becomes -1 (empty). A node with one child copies
    it. A node with two children takes whichever one ``COMPARE`` favours; an
    empty (-1) child automatically loses without a comparison.

    Args:
        tt: the tournament tree, modified in place.
        row: index of the row to fill.

    Raises:
        ValueError: if a node reports more than two children.
    """
    for pos in tt.iter_row(row):
        children = tt.get_children(pos)
        if len(children) == 0:
            tt[pos] = -1
            continue
        if len(children) == 1:
            tt[pos] = tt[children[0]]
            continue
        if len(children) != 2:
            raise ValueError("should only be one or two children")
        left, right = tt[children[0]], tt[children[1]]
        # Never pass the -1 empty marker to COMPARE: the result is known, and
        # comparing against it could let -1 beat a real value.
        if left == -1:
            tt[pos] = right
        elif right == -1:
            tt[pos] = left
        else:
            tt[pos] = left if COMPARE(left, right) else right
def _guess_k_smallest_3(tt: TourneyTree, k: int) -> [int]: # item: (row on, row of partner) row_scores = dict() row_scores[tt[tt.top()]] = (tt.height - 1, tt.height - 1) sort_by = dict() sort_by[tt[tt.top()]] = 2 * (tt.height - 1) for row in range(tt.height - 2, -1, -1): for pos in tt.iter_row(row): if tt[pos] not in row_scores: row_scores[tt[pos]] = (row, row_scores[tt[tt.get_sibling(pos)]][0]) # you can weight this 0.001 constant differently # doesn't seem fruitful though sort_by[tt[pos]] = row + 0.001 * row_scores[tt[tt.get_sibling( pos)]][0] return sorted(row_scores, key=sort_by.__getitem__, reverse=True)[:k]
def _guess_k_smallest(tt: TourneyTree, k: int) -> [int]: ksmallest_guess = list() for row in range(tt.height - 1, -1, -1): for pos in tt.iter_row(row): if tt[pos] not in ksmallest_guess: ksmallest_guess.append(tt[pos]) if len(ksmallest_guess) == k: return ksmallest_guess return ksmallest_guess
def _run(tt: TourneyTree, k: int) -> [int]:
    """Populate ``tt`` and read off its top value ``k`` times.

    After each read except the last, the top is removed, so the next read sees
    the next winner. The collected values are passed through ``GET_LIST``.
    """
    _populate_table(tt)
    winners = []
    for step in range(k):
        if step > 0:
            # Knock out the previous winner only when another pick is needed.
            _remove_top(tt)
        winners.append(tt[tt.top()])
    return GET_LIST(winners)
def _get_element_positions(tt: TourneyTree, elements: [int]) -> { int: Position }: elements_left = set(elements) element_positions = dict() for row in range(tt.height - 1, -1, -1): for pos in tt.iter_row(row): parent = tt.get_parent(pos) if tt[pos] not in elements_left: continue element_positions[tt[pos]] = pos elements_left.remove(tt[pos]) if len(elements_left) == 0: return element_positions return element_positions
def _create_smart_compare(tt: TourneyTree, elements: [int]) -> SmartCompare:
    """Build a SmartCompare seeded with orderings already visible in ``tt``.

    Each element is located at its highest position. If that position has a
    parent, the element is recorded as greater than the parent's value.
    """
    smart = SmartCompare()
    positions = _get_element_positions(tt, elements)
    for pos in positions.values():
        parent = tt.get_parent(pos)
        if parent:
            smart.set_greater_than(tt[pos], tt[parent])
    return smart
def _evaluate_contender(tt: TourneyTree, sc: SmartCompare, val: int,
                        contenders: [int], k: int):
    """Insert the siblings along ``val``'s path into ``contenders``, in place.

    For each node yielded by ``tt.iter_same_child`` starting at ``val``'s
    highest position, the sibling's value is considered. Negative values and
    values already in ``contenders`` are skipped. Otherwise the sibling is
    recorded in ``sc`` as greater than ``val``, compared once with the current
    last contender, and inserted after ``val`` at the index chosen by
    ``smart_binary_search``. The list is then trimmed back to ``k`` entries.
    ``val`` must already be in ``contenders``, because ``.index(val)`` is used.
    """
    home = _get_element_positions(tt, [val])[val]
    for same_child in tt.iter_same_child(home):
        rival = tt[tt.get_sibling(same_child)]
        if rival < 0 or rival in contenders:
            continue
        sc.set_greater_than(rival, val)
        sc.compare(rival, contenders[-1])
        # A rival known to exceed val can only land after val's slot.
        lo = contenders.index(val) + 1
        slot = smart_binary_search(contenders, sc, rival, lo, len(contenders))
        contenders.insert(slot, rival)
        if len(contenders) > k:
            contenders.pop()
def _test_second_method():
    """Run ``_run_method_2`` for 1000 trials and print running averages.

    After each trial this prints the trial index, the mean number of returned
    values below ``K``, and the mean value of ``get_num_comparisons()``.
    """
    trials = 1000
    hits = comparisons = 0
    tt = TourneyTree(NUM_ELEMENTS)
    for trial in range(trials):
        reset()
        result = _run_method_2(tt, K)
        hits += sum(1 for v in result if v < K)
        comparisons += get_num_comparisons()
        done = trial + 1
        print(f'{trial}: {hits / done:.3f} {comparisons / done:.3f}')
def _test_accuracy():
    """Run ``_run`` for 1000 trials and print running miss statistics.

    A trial is a miss when ``_run`` does not return exactly ``0..K-1`` in
    order. After each trial this prints the index, the miss count, the miss
    percentage, and the mean of ``get_num_comparisons()``.
    """
    trials = 1000
    misses = 0
    comparisons = 0
    tt = TourneyTree(NUM_ELEMENTS)
    expected = list(range(K))
    for trial in range(trials):
        reset()
        if _run(tt, K) != expected:
            misses += 1
        comparisons += get_num_comparisons()
        done = trial + 1
        print(f'{trial}: {misses} {misses / done * 100:.2f} {comparisons / done:.3f}')
def _remove_top(tt: TourneyTree):
    """Remove the value at the top of ``tt`` and recompute the tree in place."""
    root = tt.top()
    _remove_node(tt, root)