Esempio n. 1
0
def test_simple_determinism():
    """Test to check that the extraction of the batches is deterministic.

    Two MixedSequence objects are built from identical data. During the
    first epoch they must yield identical batches. Only ``ms`` is advanced
    with ``on_epoch_end()``; ``ms2`` is never advanced, so from the second
    epoch onwards the batches of the two sequences are expected to differ.
    In every epoch, the labels of both sequences must stay aligned with
    their inputs.
    """
    classes = 10
    number = 100000
    epochs = 100
    batch_size = 10000

    # x holds the row indices themselves, so y[x_batch] is exactly the
    # label each input of a batch must carry.
    x = np.arange(0, number, dtype=np.int64)
    # Local seeded generator: failures are reproducible and the global
    # NumPy RNG state is left untouched.
    y = np.random.RandomState(42).randint(0, classes, size=number)

    ms = MixedSequence(VectorSequence(x, batch_size),
                       VectorSequence(y, batch_size))

    ms2 = MixedSequence(VectorSequence(x, batch_size),
                        VectorSequence(y, batch_size))

    for epoch in range(epochs):
        for step in range(ms.steps_per_epoch):
            xi, yi = ms[step]
            xj, yj = ms2[step]
            if epoch == 0:
                # During the first epoch both sequences must be aligned.
                assert (xi == xj).all()
                assert (yi == yj).all()
            else:
                # ms has been advanced with on_epoch_end() (expected to
                # reshuffle it) while ms2 never is, so their batches should
                # no longer match. An identical batch of this size is
                # very unlikely.
                assert (xi != xj).any()
                assert (yi != yj).any()
            # Whatever the order, inputs and labels must stay aligned.
            assert (y[xi] == yi).all()
            assert (y[xj] == yj).all()

        ms.on_epoch_end()
Esempio n. 2
0
def test_genomic_sequence_determinism():
    """Check that shuffled genomic batches stay aligned with their labels.

    The labels for each region table (enhancers, then promoters) are the
    row positions ``0 .. len(table) - 1``. Over several epochs, a shuffled
    MixedSequence is compared batch by batch against a reference sequence.
    The reference holds every region in one unshuffled batch. The inputs of
    each shuffled batch must equal the reference rows selected by that
    batch's labels.
    """
    epochs = 5
    batch_size = 32
    enhancers = pd.read_csv("tests/enhancers.csv")
    promoters = pd.read_csv("tests/promoters.csv")

    genome = Genome("hg19", chromosomes=["chr1"])
    region_tables = (enhancers, promoters)
    for table in tqdm(region_tables, desc="Region types"):
        n_rows = len(table)
        # Label every region with its own row position.
        labels = np.arange(n_rows, dtype=np.int64)

        shuffled = MixedSequence(
            x=BedSequence(genome, table, batch_size),
            y=VectorSequence(labels, batch_size),
        )
        # Reference: a single unshuffled batch containing every region.
        reference = MixedSequence(
            x=BedSequence(genome, table, batch_size=n_rows, shuffle=False),
            y=VectorSequence(labels, batch_size=n_rows, shuffle=False),
        )
        reference_x, _unused_labels = reference[0]

        for _epoch in trange(epochs, desc="Epochs", leave=False):
            for step in range(shuffled.steps_per_epoch):
                x_batch, y_batch = shuffled[step]
                # Each label is a row position, so it selects the reference
                # row that this input must match.
                expected = reference_x[y_batch.astype(int)]
                assert (expected == x_batch).all()
            shuffled.on_epoch_end()