def s_find_nearest(instance, data, k, dist_meas, instance_tie_breaker=min, exclude_self=True):
    '''
    Return (at most) the k elements of data that are nearest to instance.

    Ordering of (element, distance) pairs is delegated to KNN._comparer and the
    selection of tied maxima to max_multiple; both are defined outside this
    function (presumably max_multiple returns a list of every maximal element).

    @param instance: The instance whose neighbours are wanted.
    @param data: Iterable of candidate instances.
    @param k: Number of neighbours wanted. A non-positive k yields an empty list.
    @param dist_meas: Callable (instance, other) -> distance. It is called
        exactly once per retained candidate.
    @param instance_tie_breaker: Callable that is handed a list of instances
        tied for the last place and returns the one to prefer.
    @param exclude_self: When true, the element of data that *is* instance
        (identity, not equality) is ignored.
    @return: A list of instances. When there are no more than k+1 candidates
        the result keeps the order of data; otherwise it follows the order
        produced by sorting with KNN._comparer.
    '''
    # Asking for no neighbours has exactly one answer; returning early also
    # avoids indexing an empty list in the general loop below.
    if k <= 0:
        return []

    scored = [(d, dist_meas(instance, d))
              for d in data
              if not (d is instance and exclude_self)]

    # Not enough candidates to have to choose between them.
    if len(scored) <= k:
        return [pair[0] for pair in scored]

    cmp_key = functools.cmp_to_key(KNN._comparer)

    # Exactly one candidate too many: drop a single worst element.
    if len(scored) == k + 1:
        maxes = max_multiple(scored, key=cmp_key)
        # Each round the tie-breaker picks an instance to keep; whatever is
        # left at the end is the one that gets dropped.
        while len(maxes) > 1:
            best = instance_tie_breaker([m[0] for m in maxes])
            maxes = [m for m in maxes if m[0] is not best]
        dropped = maxes[0][0]
        return [pair[0] for pair in scored if pair[0] is not dropped]

    # General case: maintain a sorted window of the k best seen so far.
    current_nearest_sorted = []
    for candidate in scored:
        candidate_dist = candidate[1]
        if not (len(current_nearest_sorted) < k
                or candidate_dist <= current_nearest_sorted[-1][1]):
            continue

        current_nearest_sorted.append(candidate)
        current_nearest_sorted.sort(key=cmp_key)
        if len(current_nearest_sorted) <= k:
            continue

        # The window overflowed. Strip every entry tied for worst (assumed to
        # be the tail of the sorted window) and re-admit them one at a time,
        # in tie-breaker order, until the window is back to k entries.
        maxes = max_multiple(current_nearest_sorted, key=cmp_key)
        current_nearest_sorted = current_nearest_sorted[:-len(maxes)]
        while len(current_nearest_sorted) < k:
            max_instances = [m[0] for m in maxes]
            best = instance_tie_breaker(max_instances)
            best_index = max_instances.index(best)
            current_nearest_sorted.append(maxes[best_index])
            del maxes[best_index]

    return [pair[0] for pair in current_nearest_sorted]
def classify_from_neighbours(self, instance, neighbours):
    '''
    Pick a class for instance from the given neighbours.

    @param instance: The instance being classified.
    @param neighbours: The neighbours passed on to get_probabilities.
    @return: Whatever classification_tie_breaker selects from the classes
        that share the highest probability.
    '''
    scored_classes = self.get_probabilities(instance, neighbours)
    top_scored = max_multiple(scored_classes, key=lambda pair: pair[1])
    if not top_scored:
        # Nothing came back with a maximum, so fall back to every possible class.
        # NOTE(review): the entries are indexed with [0] below just like the
        # (class, probability) pairs - confirm possible_classes has that shape.
        top_scored = self.possible_classes
    candidates = (entry[0] for entry in top_scored)
    return self.classification_tie_breaker(candidates)
def __suppose_nn(self, case):
    '''
    Determine the changes that would occur to all NN / rNN sets if case were added.

    Only the would-be changes are recorded, in the returned SuppositionResults;
    this method does not itself write to case_info_lookup or case_base.

    @param case: The case to suppose the addition of. It must not already be
        present in case_info_lookup.
    @return: The SuppositionResults for case, holding one AddRemovalStore per
        affected case with its added / removed nearest neighbours and reverse
        nearest neighbours.
    '''
    # Just don't want to deal with the hassle of an already-known case right now.
    assert(not self.case_info_lookup.has_key(case))

    add_removals_dict = SuppositionResults(case)

    def get_or_create(_case):
        # has_key is kept deliberately: SuppositionResults is defined elsewhere
        # and may not support the `in` operator.
        if not add_removals_dict.has_key(_case):
            add_removals_dict[_case] = AddRemovalStore()
        return add_removals_dict[_case]

    # The new case gains its own nearest neighbours, and each of those gains
    # the new case as a reverse nearest neighbour.
    case_changes = get_or_create(case)
    case_nns = self.nns_getter(self.case_base, case)
    case_changes.added.nearest_neighbours.update(case_nns)
    for nn in case_nns:
        nn_changes = get_or_create(nn)
        nn_changes.added.reverse_nearest_neighbours.add(case)

    # Work out which existing cases would take the new case as a neighbour.
    for (other_case, other_case_profile) in self.case_info_lookup.items():
        current_nns = other_case_profile.nearest_neighbours

        if len(current_nns) < self.__k:
            # Less than, as I'm going to be adding to it (in which case if it
            # was k-1, it's going to become k). Nothing gets pushed out.
            shunted = None
        else:
            # TODO: There is the issue of comparability here. Change to pass
            # distance_measurer to __init__ instead of distance_constructor.
            dist_meas = self.__distance_constructor(current_nns)
            maxes = max_multiple(
                ((other_case_nn, dist_meas(other_case, other_case_nn))
                 for other_case_nn in current_nns),
                key=lambda el: el[1])
            max_ex, max_dist = maxes[0]
            case_dist = dist_meas(other_case, case)

            if case_dist > max_dist:
                # Further away than the current furthest neighbour: no change.
                continue
            elif case_dist == max_dist or len(maxes) > 1:
                # Unsure - deferring the decision to the nn finder for tie breaking.
                other_case_new_nns = self.nns_getter(chain(current_nns, (case,)), other_case)
                if case not in other_case_new_nns:
                    continue
                difference = set(current_nns).difference(other_case_new_nns)
                assert(len(difference) == 1)
                shunted = difference.pop()
            else:
                # Strictly closer than the single furthest neighbour, which is
                # therefore the one pushed out.
                shunted = max_ex

        other_case_changes = get_or_create(other_case)

        # Houston - we have a reverse nearest neighbour.
        case_changes.added.reverse_nearest_neighbours.add(other_case)
        other_case_changes.added.nearest_neighbours.add(case)

        # Now, did it shunt out another one for the position?
        if shunted is None:
            continue
        shunted_changes = get_or_create(shunted)
        other_case_changes.removed.nearest_neighbours.add(shunted)
        shunted_changes.removed.reverse_nearest_neighbours.add(other_case)

    # Sanity check: no case may end up listed as its own nearest neighbour.
    assert(all(ca not in add_removed.added.nearest_neighbours
               for (ca, add_removed) in add_removals_dict.items()))
    return add_removals_dict