class OneCharStringStrategy(SearchStrategy):
    """A strategy which generates single character strings of text type.

    With probability ``pv.ascii_chance`` (per biased_coin) a character is
    picked from a small set of "simple" ASCII characters; otherwise an
    arbitrary code point is drawn, rejecting surrogates (category 'Cs').
    """
    descriptor = text_type
    # Ordered roughly from simplest ('0') to least simple; simplify() walks
    # towards the start of this string.
    ascii_characters = (
        text_type('0123456789') + text_type(string.ascii_letters) +
        text_type(' \t\n')
    )
    parameter = params.CompositeParameter(
        ascii_chance=params.UniformFloatParameter(0, 1))

    def produce(self, random, pv):
        """Draw one character using the parameter value ``pv``."""
        if dist.biased_coin(random, pv.ascii_chance):
            return random.choice(self.ascii_characters)
        else:
            while True:
                result = hunichr(random.randint(0, sys.maxunicode))
                # Lone surrogates are never produced, so reject and redraw.
                if unicodedata.category(result) != 'Cs':
                    return result

    def simplify(self, x):
        """Yield simpler single-character candidates for ``x``.

        Never yields ``x`` itself, and never yields a surrogate, matching
        what ``produce`` can generate.
        """
        if x in self.ascii_characters:
            # Start one position before x: yielding x itself is a no-op.
            for i in hrange(self.ascii_characters.index(x) - 1, -1, -1):
                yield self.ascii_characters[i]
        else:
            o = ord(x)
            for c in reversed(self.ascii_characters):
                yield text_type(c)
            if o > 0:
                seen = []
                for candidate in (hunichr(o // 2), hunichr(o - 1)):
                    # Halving/decrementing can land on a surrogate code
                    # point, which produce() would never emit; skip those
                    # and any duplicate (e.g. o == 2 gives 1 twice).
                    if unicodedata.category(candidate) == 'Cs':
                        continue
                    if candidate in seen:
                        continue
                    seen.append(candidate)
                    yield candidate
class FixedBoundedFloatStrategy(SearchStrategy):
    """A strategy for floats distributed between two endpoints.

    The conditional distribution tries to produce values clustered closer to
    one of the ends: a cut point is placed inside
    [lower_bound, upper_bound], and values are drawn uniformly from either
    the piece between lower_bound and the cut or the piece between the cut
    and upper_bound.
    """
    descriptor = float
    # ``cut`` is a fraction declared on [0, 1]; produce() maps it onto the
    # actual bounds of the instance.
    parameter = params.CompositeParameter(
        cut=params.UniformFloatParameter(0, 1),
        leftwards=params.BiasedCoin(0.5),
    )

    def __init__(self, lower_bound, upper_bound):
        SearchStrategy.__init__(self)
        self.lower_bound = float(lower_bound)
        self.upper_bound = float(upper_bound)

    def produce(self, random, pv):
        """Draw a float between lower_bound and upper_bound."""
        # The cut parameter is declared on [0, 1] independently of the
        # bounds; using it directly as an endpoint produced values outside
        # the range whenever the bounds were not [0, 1]. Rescale it.
        span = self.upper_bound - self.lower_bound
        cut = self.lower_bound + pv.cut * span
        if pv.leftwards:
            left, right = self.lower_bound, cut
        else:
            left, right = cut, self.upper_bound
        return left + random.random() * (right - left)

    def simplify(self, value):
        """Yield the lower bound, upper bound and midpoint as candidates.

        Candidates equal to ``value`` and repeated candidates (e.g. when the
        bounds coincide) are skipped.
        """
        midpoint = (self.lower_bound + self.upper_bound) * 0.5
        emitted = []
        for candidate in (self.lower_bound, self.upper_bound, midpoint):
            if candidate != value and candidate not in emitted:
                emitted.append(candidate)
                yield candidate
class FullRangeFloats(FloatStrategy):
    """Floats assembled from randomly drawn sign, exponent and mantissa bits.

    The sign bit is set with probability ``negative_probability``; with
    probability ``subnormal_probability`` the exponent field is forced to
    zero, otherwise 11 random exponent bits are used. The mantissa is always
    52 random bits. The fields are combined by ``compose_float`` (presumably
    into an IEEE-754 double — confirm against its definition).
    """
    parameter = params.CompositeParameter(
        negative_probability=params.UniformFloatParameter(0, 1),
        subnormal_probability=params.UniformFloatParameter(0, 0.5),
    )

    def produce(self, random, pv):
        negative = dist.biased_coin(random, pv.negative_probability)
        sign_bit = int(negative)
        zero_exponent = dist.biased_coin(random, pv.subnormal_probability)
        exponent_bits = 0 if zero_exponent else random.getrandbits(11)
        mantissa_bits = random.getrandbits(52)
        return compose_float(sign_bit, exponent_bits, mantissa_bits)

    def could_have_produced(self, value):
        """Any float may come out of produce()."""
        return isinstance(value, float)
class BoolStrategy(SearchStrategy):
    """Generate booleans by flipping a biased coin.

    The parameter is a probability drawn from UniformFloatParameter(0, 1);
    each value is the result of ``dist.biased_coin`` with that probability,
    i.e. a Bernoulli draw conditional on the parameter.
    """
    descriptor = bool
    parameter = params.UniformFloatParameter(0, 1)

    def produce(self, random, p):
        flip = dist.biased_coin(random, p)
        return flip
def __init__(self, main_strategy, examples):
    """Wrap ``main_strategy`` together with a non-empty set of examples.

    Args:
        main_strategy: strategy whose descriptor, parameter,
            has_immutable_data and (if present) element_strategy are
            mirrored by this wrapper.
        examples: non-empty iterable of values that ``main_strategy``
            could have produced. Any iterable is accepted, including a
            one-shot iterator.

    Raises:
        AssertionError: if ``examples`` is empty or contains a value that
            ``main_strategy.could_have_produced`` rejects.
    """
    # Materialize first: the examples are traversed several times below, and
    # a one-shot iterator would otherwise pass the emptiness check, be
    # exhausted by the all(...) check, and leave nothing to store.
    examples = tuple(examples)
    assert examples
    assert all(main_strategy.could_have_produced(e) for e in examples)
    self.examples = examples
    self.main_strategy = main_strategy
    self.descriptor = main_strategy.descriptor
    self.parameter = params.CompositeParameter(
        examples=params.NonEmptySubset(list(examples)),
        example_probability=params.UniformFloatParameter(0.0, 0.5),
        main=main_strategy.parameter)
    self.has_immutable_data = main_strategy.has_immutable_data
    # Forward the element strategy of collection-like strategies, if any.
    if hasattr(main_strategy, 'element_strategy'):
        self.element_strategy = main_strategy.element_strategy
class DatetimeStrategy(SearchStrategy):
    """A strategy for datetime values, either naive or localized via pytz."""
    descriptor = datetime
    parameter = params.CompositeParameter(
        p_hour=params.UniformFloatParameter(0, 1),
        p_minute=params.UniformFloatParameter(0, 1),
        p_second=params.UniformFloatParameter(0, 1),
        month=params.NonEmptySubset(list(range(1, 13))),
        naive_chance=params.UniformFloatParameter(0, 0.5),
        utc_chance=params.UniformFloatParameter(0, 1),
        timezones=params.NonEmptySubset(
            list(map(pytz.timezone, pytz.all_timezones))))

    def produce(self, random, pv):
        """Draw a datetime.

        The year is uniform over [MINYEAR, MAXYEAR] and the month comes from
        ``pv.month``. Hour, minute and second go through ``maybe_zero_or``
        (presumably zero with probability p_hour/p_minute/p_second —
        confirm). The result is naive with chance ``naive_chance``, else
        UTC-localized with chance ``utc_chance``, else localized in a zone
        from ``pv.timezones``.
        """
        year = random.randint(MINYEAR, MAXYEAR)
        month = random.choice(pv.month)
        base = datetime(
            year=year,
            month=month,
            day=draw_day_for_month(random, year, month),
            hour=maybe_zero_or(random, pv.p_hour, random.randint(0, 23)),
            minute=maybe_zero_or(random, pv.p_minute, random.randint(0, 59)),
            second=maybe_zero_or(random, pv.p_second, random.randint(0, 59)),
            microsecond=random.randint(0, 1000000 - 1),
        )
        if random.random() <= pv.naive_chance:
            return base
        if random.random() <= pv.utc_chance:
            return pytz.UTC.localize(base)
        return random.choice(pv.timezones).localize(base)

    def simplify(self, value):
        """Yield simpler datetimes, in a deterministic order.

        The candidates are produced in this order:

        1. A UTC version of the value.
        2. The value with individual fields reset (microsecond, second,
           minute, hour to 0; day, month to 1).
        3. Years moving towards 2000.

        ``value`` itself is never yielded.
        """
        # Prefer UTC: a naive value is localized as UTC; an aware non-UTC
        # value is converted to the same instant in UTC.
        if not value.tzinfo:
            yield pytz.UTC.localize(value)
        elif value.tzinfo != pytz.UTC:
            yield pytz.UTC.normalize(value.astimezone(pytz.UTC))

        # Ordered de-duplication. A set here would make the shrink order
        # depend on per-process hash randomization of datetime hashes.
        emitted = []
        for candidate in (
            value.replace(microsecond=0),
            value.replace(second=0),
            value.replace(minute=0),
            value.replace(hour=0),
            value.replace(day=1),
            value.replace(month=1),
        ):
            if candidate != value and candidate not in emitted:
                emitted.append(candidate)
                yield candidate

        year = value.year
        if year == 2000:
            return
        # 2000 is a leap year, so this replacement cannot fail even for
        # February 29.
        yield value.replace(year=2000)

        # replace(year=...) raises ValueError when value is February 29 and
        # the target year is not a leap year; such candidates are skipped.
        # The yields sit outside the try blocks so that an exception thrown
        # into this generator is never silently swallowed.
        mid = (year + 2000) // 2
        if mid != 2000 and mid != year:
            try:
                halfway = value.replace(year=mid)
            except ValueError:
                pass
            else:
                yield halfway

        # Walk year by year towards 2000. Start one step away from the
        # current year, since replacing with the same year would yield
        # value itself.
        step = -1 if year > 2000 else 1
        for candidate_year in hrange(year + step, 2000, step):
            if candidate_year == mid:
                continue
            try:
                candidate = value.replace(year=candidate_year)
            except ValueError:
                continue
            yield candidate