class Mixture(object):
    """Mixture over ``(x, y)`` pairs built from four cooperating parts.

    The instance owns a ``PitmanYor.Mixture`` (``clustering``), two
    ``nich.Mixture`` objects (``feature_x`` and ``feature_y``) and a
    ``MixtureIdTracker`` (``id_tracker``).  Every operation is forwarded to
    each part with the same ``groupid`` so that the parts stay aligned
    group-for-group.
    """

    def __init__(self):
        self.clustering = PitmanYor.Mixture()
        self.feature_x = nich.Mixture()
        self.feature_y = nich.Mixture()
        self.id_tracker = MixtureIdTracker()

    def __len__(self):
        """Return the number of groups, as reported by the clustering."""
        return len(self.clustering)

    def init(self, model, empty_group_count=EMPTY_GROUP_COUNT):
        """Reset the mixture so that it holds only empty groups.

        Args:
            model: object exposing ``clustering`` and ``feature`` attributes,
                which are handed to the matching component mixtures.
            empty_group_count: number of empty groups to create; must be at
                least 1.

        Raises:
            AssertionError: if ``empty_group_count`` is below 1, or if the
                clustering does not end up with one group per count.
        """
        assert empty_group_count >= 1
        counts = [0] * empty_group_count
        self.clustering.init(model.clustering, counts)
        assert len(self.clustering) == len(counts)
        self.id_tracker.init(len(counts))

        # Rebuild both feature mixtures with one group per clustering group.
        self.feature_x.clear()
        self.feature_y.clear()
        for _ in xrange(empty_group_count):
            self.feature_x.add_group(model.feature)
            self.feature_y.add_group(model.feature)
        self.feature_x.init(model.feature)
        self.feature_y.init(model.feature)

    def score_value(self, model, xy, scores):
        """Score the pair ``xy`` against every group.

        The same ``scores`` buffer is passed first to the clustering and then
        to the x and y feature mixtures; the result is left in ``scores``.
        """
        x, y = xy
        self.clustering.score_value(model.clustering, scores)
        self.feature_x.score_value(model.feature, x, scores)
        self.feature_y.score_value(model.feature, y, scores)

    def add_value(self, model, groupid, xy):
        """Add the pair ``xy`` to group ``groupid``.

        When the clustering reports that a group was added, a matching group
        is added to both feature mixtures and to the id tracker.

        Returns:
            The flag returned by ``clustering.add_value`` (truthy when a
            group was added), so callers can mirror the bookkeeping.
        """
        x, y = xy
        group_added = self.clustering.add_value(model.clustering, groupid)
        self.feature_x.add_value(model.feature, groupid, x)
        self.feature_y.add_value(model.feature, groupid, y)
        if group_added:
            self.feature_x.add_group(model.feature)
            self.feature_y.add_group(model.feature)
            self.id_tracker.add_group()
        return group_added

    def remove_value(self, model, groupid, xy):
        """Remove the pair ``xy`` from group ``groupid``.

        When the clustering reports that a group was removed, group
        ``groupid`` is also removed from both feature mixtures and from the
        id tracker.

        Returns:
            The flag returned by ``clustering.remove_value`` (truthy when a
            group was removed), so callers can mirror the bookkeeping.
        """
        x, y = xy
        group_removed = self.clustering.remove_value(
            model.clustering, groupid)
        self.feature_x.remove_value(model.feature, groupid, x)
        self.feature_y.remove_value(model.feature, groupid, y)
        if group_removed:
            self.feature_x.remove_group(model.feature, groupid)
            self.feature_y.remove_group(model.feature, groupid)
            self.id_tracker.remove_group(groupid)
        return group_removed
def test_mixture_score_matches_score_add_value(Model, EXAMPLE, *unused): sample_count = 200 model = Model() model.load(EXAMPLE) if Model.__name__ == 'LowEntropy' and sample_count > model.dataset_size: raise SkipTest('skipping trivial example') assignment_vector = model.sample_assignments(sample_count) assignments = dict(enumerate(assignment_vector)) nonempty_counts = count_assignments(assignments) nonempty_group_count = len(nonempty_counts) assert_greater(nonempty_group_count, 1, "test is inaccurate") def check_counts(mixture, counts, empty_group_count): # print 'counts =', counts empty_groupids = frozenset(mixture.empty_groupids) assert_equal(len(empty_groupids), empty_group_count) for groupid in empty_groupids: assert_equal(counts[groupid], 0) def check_scores(mixture, counts, empty_group_count): sample_count = sum(counts) nonempty_group_count = len(counts) - empty_group_count expected = [ model.score_add_value( group_size, nonempty_group_count, sample_count, empty_group_count) for group_size in counts ] noise = numpy.random.randn(len(counts)) actual = numpy.zeros(len(counts), dtype=numpy.float32) actual[:] = noise mixture.score_value(model, actual) assert_close(actual, expected) return actual for empty_group_count in [1, 10]: print 'empty_group_count =', empty_group_count counts = nonempty_counts + [0] * empty_group_count numpy.random.shuffle(counts) mixture = Model.Mixture() id_tracker = MixtureIdTracker() print 'init' mixture.init(model, counts) id_tracker.init(len(counts)) check_counts(mixture, counts, empty_group_count) check_scores(mixture, counts, empty_group_count) print 'adding' groupids = [] for _ in xrange(sample_count): check_counts(mixture, counts, empty_group_count) scores = check_scores(mixture, counts, empty_group_count) probs = scores_to_probs(scores) groupid = sample_discrete(probs) expected_group_added = (counts[groupid] == 0) counts[groupid] += 1 actual_group_added = mixture.add_value(model, groupid) assert_equal(actual_group_added, 
expected_group_added) groupids.append(groupid) if actual_group_added: id_tracker.add_group() counts.append(0) check_counts(mixture, counts, empty_group_count) check_scores(mixture, counts, empty_group_count) print 'removing' for global_groupid in groupids: groupid = id_tracker.global_to_packed(global_groupid) counts[groupid] -= 1 expected_group_removed = (counts[groupid] == 0) actual_group_removed = mixture.remove_value(model, groupid) assert_equal(actual_group_removed, expected_group_removed) if expected_group_removed: id_tracker.remove_group(groupid) back = counts.pop() if groupid < len(counts): counts[groupid] = back check_counts(mixture, counts, empty_group_count) check_scores(mixture, counts, empty_group_count)
def __init__(self):
    """Create the component mixtures and the group-id tracker."""
    self.clustering = PitmanYor.Mixture()
    # feature_x and feature_y are two separate nich.Mixture instances.
    self.feature_x, self.feature_y = nich.Mixture(), nich.Mixture()
    self.id_tracker = MixtureIdTracker()
def test_mixture_score_matches_score_add_value(Model, EXAMPLE, *unused): sample_count = 200 model = Model() model.load(EXAMPLE) if Model.__name__ == 'LowEntropy' and sample_count > model.dataset_size: raise SkipTest('skipping trivial example') assignment_vector = model.sample_assignments(sample_count) assignments = dict(enumerate(assignment_vector)) nonempty_counts = count_assignments(assignments) nonempty_group_count = len(nonempty_counts) assert_greater(nonempty_group_count, 1, "test is inaccurate") def check_counts(mixture, counts, empty_group_count): # print 'counts =', counts empty_groupids = frozenset(mixture.empty_groupids) assert_equal(len(empty_groupids), empty_group_count) for groupid in empty_groupids: assert_equal(counts[groupid], 0) def check_scores(mixture, counts, empty_group_count): sample_count = sum(counts) nonempty_group_count = len(counts) - empty_group_count expected = [ model.score_add_value(group_size, nonempty_group_count, sample_count, empty_group_count) for group_size in counts ] noise = numpy.random.randn(len(counts)) actual = numpy.zeros(len(counts), dtype=numpy.float32) actual[:] = noise mixture.score_value(model, actual) assert_close(actual, expected) return actual for empty_group_count in [1, 10]: print 'empty_group_count =', empty_group_count counts = nonempty_counts + [0] * empty_group_count numpy.random.shuffle(counts) mixture = Model.Mixture() id_tracker = MixtureIdTracker() print 'init' mixture.init(model, counts) id_tracker.init(len(counts)) check_counts(mixture, counts, empty_group_count) check_scores(mixture, counts, empty_group_count) print 'adding' groupids = [] for _ in xrange(sample_count): check_counts(mixture, counts, empty_group_count) scores = check_scores(mixture, counts, empty_group_count) probs = scores_to_probs(scores) groupid = sample_discrete(probs) expected_group_added = (counts[groupid] == 0) counts[groupid] += 1 actual_group_added = mixture.add_value(model, groupid) assert_equal(actual_group_added, 
expected_group_added) groupids.append(groupid) if actual_group_added: id_tracker.add_group() counts.append(0) check_counts(mixture, counts, empty_group_count) check_scores(mixture, counts, empty_group_count) print 'removing' for global_groupid in groupids: groupid = id_tracker.global_to_packed(global_groupid) counts[groupid] -= 1 expected_group_removed = (counts[groupid] == 0) actual_group_removed = mixture.remove_value(model, groupid) assert_equal(actual_group_removed, expected_group_removed) if expected_group_removed: id_tracker.remove_group(groupid) back = counts.pop() if groupid < len(counts): counts[groupid] = back check_counts(mixture, counts, empty_group_count) check_scores(mixture, counts, empty_group_count)