Beispiel #1
0
def test_StableRandom():
    """Exercise StableRandom: reference values, ranges, seeding and state.

    Covers fixed-seed reference outputs, the bounds of random() and
    _randbelow(), equality of identically seeded generators, divergence
    after jumpahead(), with different seeds or with time-based seeds,
    state transfer via getstate()/setstate(), and the first two moments
    of gauss_next().
    """
    # Reference outputs for seed 1 pin the generator's exact sequence.
    gen = StableRandom(1)
    assert gen.randint(1, 10) == 5
    assert gen.uniform(-10, 10) == approx(9.943696167306904)
    assert gen.random() == approx(0.7203244894557457)
    assert gen.sample(range(10), 3) == [9, 0, 1]

    items = [1, 2, 3, 4, 5]
    gen.shuffle(items)
    assert items == [5, 3, 1, 4, 2]

    # random() must stay within the half-open interval [0, 1).
    gen = StableRandom()
    assert max(gen.random() for _ in range(1000)) < 1.0
    assert min(gen.random() for _ in range(1000)) >= 0.0

    # Identical seeds produce identical streams.
    first, second = StableRandom(0), StableRandom(0)
    assert all(first.random() == second.random() for _ in range(100))

    # jumpahead() desynchronizes; copying the state resynchronizes.
    first, second = StableRandom(0), StableRandom(0)
    second.jumpahead(10)
    assert all(first.random() != second.random() for _ in range(100))
    second.setstate(first.getstate())
    assert all(first.random() == second.random() for _ in range(100))

    # Different explicit seeds produce different streams.
    first, second = StableRandom(0), StableRandom(1)
    assert all(first.random() != second.random() for _ in range(100))

    # Without an explicit seed the seed is based on system time, so two
    # generators created half a second apart must differ.
    first = StableRandom()
    sleep(0.5)
    second = StableRandom()
    assert all(first.random() != second.random() for _ in range(100))

    # _randbelow(10) must yield values in [0, 10).
    gen = StableRandom()
    draws = [gen._randbelow(10) for _ in range(1000)]
    assert 0 <= min(draws) and max(draws) < 10

    # gauss_next() should be roughly standard normal.
    gen = StableRandom()
    samples = [gen.gauss_next() for _ in range(10000)]
    mean, std = samples >> MeanStd()
    assert 0.0 == approx(mean, abs=0.1)
    assert 1.0 == approx(std, abs=0.1)
def SplitRandom(iterable, ratio=0.7, constraint=None, rand=None):
    """
    Randomly split iterable into partitions.

    For the same input data the same split is created every time, and the
    split is stable across Python versions 2.x and 3.x. A random number
    generator can be provided to create varying splits.

    >>> train, val = range(10) >> SplitRandom(ratio=0.7)
    >>> train, val
    ([6, 3, 1, 7, 0, 2, 4], [5, 9, 8])

    >>> range(10) >> SplitRandom(ratio=0.7)  # Same split again
    [[6, 3, 1, 7, 0, 2, 4], [5, 9, 8]]

    >>> train, val, test = range(10) >> SplitRandom(ratio=(0.6, 0.3, 0.1))
    >>> train, val, test
    ([6, 1, 4, 0, 3, 2], [8, 7, 9], [5])

    >>> data = zip('aabbccddee', range(10))
    >>> same_letter = lambda t: t[0]
    >>> train, val = data >> SplitRandom(ratio=0.6, constraint=same_letter)
    >>> train
    [('a', 1), ('a', 0), ('d', 7), ('b', 2), ('d', 6), ('b', 3)]
    >>> val
    [('c', 5), ('e', 8), ('e', 9), ('c', 4)]

    :param iterable iterable: Iterable over anything. Will be consumed!
    :param float|tuple ratio: Fraction of elements in the first of two
            partitions, e.g. a ratio of 0.7 means a 70%/30% split.
            Alternatively a list of ratios can be provided, e.g.
            ratio=(0.6, 0.3, 0.1). Note that ratios must sum up to one.
    :param function|None constraint: Function that returns the key by which
        the elements of the iterable are grouped before partitioning. Useful
        to ensure that related elements stay in the same partition, e.g.
        left and right eye images are not scattered across partitions.
        Note that constraints have precedence over ratios.
    :param Random|None rand: Random number generator. The default None
            ensures that the same split is created every time SplitRandom
            is called. This is important when continuing an interrupted
            training session or running the same training on machines with
            different Python versions. Note that Python's random.Random(0)
            generates different numbers for Python 2.x and 3.x!
    :return: partitions of iterable with sizes according to provided ratios.
            Exactly one partition per ratio is returned; a partition may be
            empty if the input runs out of elements (or groups) early.
    :rtype: list of lists
    :raise ValueError: if a sequence of ratios does not sum up to one.
    """
    rand = StableRandom(0) if rand is None else rand
    samples = list(iterable)
    if hasattr(ratio, '__iter__'):
        ratios = tuple(ratio)
        if abs(sum(ratios) - 1.0) > 1e-6:
            raise ValueError('Ratios must sum up to one: ' + str(ratios))
    else:
        ratios = (ratio, 1.0 - ratio)
    # Target size per partition (truncated by int()). Only the targets of
    # the first len(ratios) - 1 partitions are used; the last partition
    # receives whatever remains.
    ns = [int(len(samples) * r) for r in ratios]

    if constraint is None:
        groups = [[s] for s in samples]
    else:
        # sort to make stable across python 2.x, 3.x
        groups = sorted(group_by(samples, constraint).values())
    rand.shuffle(groups)
    # One shared iterator, so every group is consumed by exactly one split.
    groups = iter(groups)
    splits = []

    def append(split):
        # Shuffle the elements within the partition before storing it.
        rand.shuffle(split)
        splits.append(split)

    for n in ns[:-1]:
        split = []
        # Add whole groups until the target size is reached. The size is
        # checked after adding, so a partition may exceed its target and
        # receives one group even when its target is 0.
        for group in groups:
            split.extend(group)
            if len(split) >= n:
                break
        # Append unconditionally: if the groups ran out before the target
        # was reached, dropping this split would return fewer partitions
        # than ratios (e.g. empty input would yield [[]] instead of
        # [[], []]).
        append(split)
    append([e for g in groups for e in g])  # append remaining groups
    return splits