def test_target_selector_will_use_novel_examples_preferentially():
    """A selection made right after adding an example is always unseen.

    The pool is capped at three examples. After every add, the next
    selection must return an example whose identifier has never been
    returned before.
    """
    selector = TargetSelector(random=Random(0), pool_size=3)
    served_ids = set()
    for added in range(1, 101):
        selector.add(FakeData())
        assert len(selector) == min(added, 3)
        chosen_id = selector.select().global_identifier
        assert chosen_id not in served_ids
        served_ids.add(chosen_id)
def test_always_starts_with_rare_tags(rnd, n):
    """A newly introduced tag is picked next, however long the selector ran.

    After ``n`` selections over an example tagged ``0``, an example tagged
    ``1`` is added; the very next selection must return data carrying tag 1.
    """
    selector = TargetSelector(rnd)
    selector.add(FakeConjectureData(tags=frozenset({0})))
    # Burn through n selections so the selector has some history.
    for _ in hrange(n):
        selector.select()
    selector.add(FakeConjectureData(tags=frozenset({1})))
    chosen = selector.select()[1]
    assert 1 in chosen.tags
class TargetSelectorMachine(RuleBasedStateMachine):
    """Stateful test checking TargetSelector against a simple tag model.

    The machine records every added example, the union of their tags
    (``tags``) and the intersection of their tags (``tag_intersections``),
    and checks the selector's answers against those.
    """

    def __init__(self):
        super(TargetSelectorMachine, self).__init__()
        self.target_selector = None
        self.data = []
        # Union of the tags of every example added so far.
        self.tags = set()
        # Intersection of the tags of every example added so far; None
        # until the first example has been added.
        self.tag_intersections = None

    @precondition(lambda self: self.target_selector is None)
    @rule(rnd=fake_randoms())
    def initialize(self, rnd):
        self.target_selector = TargetSelector(rnd)

    @precondition(lambda self: self.target_selector is not None)
    @rule(
        data=st.builds(FakeConjectureData, st.frozensets(st.integers(0, 10)))
    )
    def add_data(self, data):
        self.target_selector.add(data)
        self.data.append(data)
        self.tags.update(data.tags)
        if self.tag_intersections is None:
            # Take an immutable copy so the `&=` below can never mutate the
            # example's own tag set in place (it would if tags were a set).
            self.tag_intersections = frozenset(data.tags)
        else:
            self.tag_intersections &= data.tags

    @precondition(lambda self: self.data)
    @rule()
    def select_target(self):
        tag, data = self.target_selector.select()
        assert self.target_selector.has_tag(tag, data)
        # When the examples do not all carry identical tags, the selector
        # must not answer with the universal tag.
        if self.tags != self.tag_intersections:
            assert tag != universal

    @precondition(lambda self: self.data)
    @rule()
    def cycle_through_tags(self):
        """Every known tag must be observed within a bounded number of picks.

        The bound scales with the number of distinct tags and with the
        selector's ``mutation_counts``.
        """
        seen = set()
        budget = (2 * len(self.tags) + 1) * (
            1 + self.target_selector.mutation_counts
        )
        for _ in hrange(budget):
            _, data = self.target_selector.select()
            seen.update(data.tags)
            if seen == self.tags:
                break
        else:
            # Explicit raise (not `assert False`) so the failure survives -O
            # and reports what was missed.
            raise AssertionError(
                "Only saw tags %r out of %r after %d selections"
                % (seen, self.tags, budget)
            )
def test_selects_non_universal_tag(rnd):
    """Two consecutive picks both avoid the universal tag and differ.

    With one example tagged ``0`` and one untagged example, the first two
    selections must use non-universal tags, and both the tags and the
    returned examples must differ between them.
    """
    selector = TargetSelector(rnd)
    selector.add(FakeConjectureData({0}))
    selector.add(FakeConjectureData(set()))
    picks = []
    for _ in range(2):
        tag, chosen = selector.select()
        assert tag is not universal
        picks.append((tag, chosen))
    (first_tag, first_data), (second_tag, second_data) = picks
    assert first_tag != second_tag
    assert first_data != second_data
def test_a_negated_tag_is_also_interesting(rnd):
    """The lone example lacking a common tag is the first one selected.

    Two examples carry tag ``0`` and one carries no tags; the first
    selection must return the untagged example.
    """
    selector = TargetSelector(rnd)
    for tag_set in (frozenset({0}), frozenset({0}), frozenset()):
        selector.add(FakeConjectureData(tags=tag_set))
    chosen = selector.select()[1]
    assert not chosen.tags
class TargetSelectorMachine(RuleBasedStateMachine):
    """Stateful test checking TargetSelector against a simple tag model.

    The machine records every added example, the union of their tags
    (``tags``) and the intersection of their tags (``tag_intersections``),
    and checks the selector's answers against those.
    """

    def __init__(self):
        super(TargetSelectorMachine, self).__init__()
        self.target_selector = None
        self.data = []
        # Union of the tags of every example added so far.
        self.tags = set()
        # Intersection of the tags of every example added so far; None
        # until the first example has been added.
        self.tag_intersections = None

    @precondition(lambda self: self.target_selector is None)
    @rule(rnd=fake_randoms())
    def initialize(self, rnd):
        self.target_selector = TargetSelector(rnd)

    @precondition(lambda self: self.target_selector is not None)
    @rule(
        data=st.builds(FakeConjectureData, st.frozensets(st.integers(0, 10)))
    )
    def add_data(self, data):
        self.target_selector.add(data)
        self.data.append(data)
        self.tags.update(data.tags)
        if self.tag_intersections is None:
            # Take an immutable copy so the `&=` below can never mutate the
            # example's own tag set in place (it would if tags were a set).
            self.tag_intersections = frozenset(data.tags)
        else:
            self.tag_intersections &= data.tags

    @precondition(lambda self: self.data)
    @rule()
    def select_target(self):
        tag, data = self.target_selector.select()
        assert self.target_selector.has_tag(tag, data)
        # When the examples do not all carry identical tags, the selector
        # must not answer with the universal tag.
        if self.tags != self.tag_intersections:
            assert tag != universal

    @precondition(lambda self: self.data)
    @rule()
    def cycle_through_tags(self):
        """Every known tag must be observed within a bounded number of picks.

        The bound scales with the number of distinct tags and with the
        selector's ``mutation_counts``.
        """
        seen = set()
        budget = (2 * len(self.tags) + 1) * (
            1 + self.target_selector.mutation_counts
        )
        for _ in hrange(budget):
            _, data = self.target_selector.select()
            seen.update(data.tags)
            if seen == self.tags:
                break
        else:
            # Explicit raise (not `assert False`) so the failure survives -O
            # and reports what was missed.
            raise AssertionError(
                "Only saw tags %r out of %r after %d selections"
                % (seen, self.tags, budget)
            )
def test_target_selector_will_maintain_a_bounded_size_with_scores(pool_size):
    """Scored examples never push the pool past ``pool_size``.

    After every add, the selector size must be within bounds and each
    label's pool in ``scored_examples`` must be ordered by descending score.
    """
    selector = TargetSelector(random=Random(0), pool_size=pool_size)
    selector.add(FakeData())
    for step in range(100):
        label = str(step // 3 == 0)
        scored = FakeData(target_observations={label: float(step % 30)})
        selector.add(scored)
        assert len(selector) <= pool_size
        for name, examples in selector.scored_examples.items():
            observed = [ex.target_observations[name] for ex in examples]
            assert observed == sorted(observed, reverse=True)
        selector.add(FakeData())
        assert len(selector) <= pool_size
def test_cycles_through_all_tags_in_bounded_time_mixed(rnd, d1, d2):
    """Bounded tag cycling holds after a batch and after each later add.

    All of ``d1`` is added before the first check; then every example from
    ``d2`` is added one at a time, re-checking after each.
    """
    selector = TargetSelector(rnd)
    for example in d1:
        selector.add(example)
    check_bounded_cycle(selector)
    for example in d2:
        selector.add(example)
        check_bounded_cycle(selector)
def test_target_selector_will_eventually_reuse_examples():
    """Once both pooled examples have been served, selections repeat them.

    The first two selections must return distinct, previously unseen
    examples; the next two must each return one of those.
    """
    selector = TargetSelector(random=Random(0), pool_size=2)
    served_ids = set()
    for _ in range(2):
        selector.add(FakeData())
    # First pass: each stored example is handed out exactly once.
    for _ in range(2):
        ident = selector.select().global_identifier
        assert ident not in served_ids
        served_ids.add(ident)
    # Second pass: every further selection must be a repeat.
    for _ in range(2):
        assert selector.select().global_identifier in served_ids
def test_target_selector_will_maintain_a_bounded_pool():
    """The selector grows one example per add until it reaches its cap of 3."""
    selector = TargetSelector(random=Random(0), pool_size=3)
    for added in range(1, 101):
        selector.add(FakeData())
        expected = added if added < 3 else 3
        assert len(selector) == expected
def initialize(self, rnd):
    """Create the TargetSelector under test from ``rnd``."""
    selector = TargetSelector(rnd)
    self.target_selector = selector
def test_cycles_through_all_tags_in_bounded_time(rnd, datas):
    """After adding every example in ``datas``, tag cycling is bounded."""
    selector = TargetSelector(rnd)
    for example in datas:
        selector.add(example)
    check_bounded_cycle(selector)
def test_target_selector_can_discard_labels():
    """Ten distinct target labels never grow the pool beyond its cap of 2."""
    cap = 2
    selector = TargetSelector(random=Random(0), pool_size=cap)
    for idx in range(10):
        observations = {str(idx): 0.0}
        selector.add(FakeData(target_observations=observations))
        assert len(selector) <= cap