コード例 #1
0
def test_queuemask():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', get_queues=(True, False, False), val_size=0, _rand_data=True)
    assert test_queue is None
    assert val_queue is None
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', get_queues=(True, False, True), _rand_data=True)
    assert test_queue is None
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', get_queues=(False, True, False), val_size=0, _rand_data=True)
    assert train_queue is None
    assert val_queue is None
コード例 #2
0
def test_pullfromqueue():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', _rand_data=True)
    sleep(1)
    data, labels = train_queue.get_batch(100)
    test, labels = test_queue.get_batch(100)
    val, labels = val_queue.get_batch(100)
コード例 #3
0
def test_lastbatch():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', _rand_data=True)
    sleep(1)
    train_queue._read_count = train_queue.epoch_size - 200
    data, labels = train_queue.get_batch(100)
    assert not train_queue.last_batch
    data, labels = train_queue.get_batch(100)
    assert train_queue.last_batch
コード例 #4
0
def test_transform():
    def transform(x):
        return x-np.mean(x)

    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', transform=transform, _rand_data=True)
    sleep(1)
    data, labels = train_queue.get_batch(100)
    for im in data:
        assert abs(np.mean(im)) <= 0.0001
コード例 #5
0
def test_epochlimit():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', max_epochs=1, _rand_data=True)
    sleep(1)
    epoch_size = train_queue.epoch_size
    for i in range(int(np.ceil(epoch_size//100))):
        data, labels = train_queue.get_batch(100)

    with pytest.raises(FileQueueDepleted, message="Expecting queue empty exception at epoch limit"):  # noqa
        data, labels = train_queue.get_batch(100)
コード例 #6
0
def test_transform2():
    def transform_tr(x):
        return x-np.mean(x)
    def transform_te(x):
        return x-np.mean(x) + 100

    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', transform=(transform_tr, transform_te, None), _rand_data=True)
    sleep(1)
    data, labels = train_queue.get_batch(100)
    for im in data:
        assert abs(np.mean(im)) <= 0.0001
    data, labels = test_queue.get_batch(100)
    for im in data:
        assert abs(np.mean(im)-100) <= 0.0001
コード例 #7
0
def test_lastbatch_reset():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', _rand_data=True)
    sleep(1)
    train_queue._read_count = train_queue.epoch_size - 200

    for i in range(2):
        while not train_queue.last_batch:
            data, labels = train_queue.get_batch(100)
        if i == 0:
            old_data = data
        train_queue._read_count = train_queue.epoch_size - 100

    # test that the last_batch flag was reset properly so that the data should
    # be different now
    old_data = np.array(old_data)
    data = np.array(data)
    with pytest.raises(AssertionError, message="Expecting Different arrays"):
        np.testing.assert_array_equal(old_data, data)
コード例 #8
0
def test_getqueuenoval():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', val_size=0, _rand_data=True)
コード例 #9
0
def test_getqueue():
    train_queue, test_queue, val_queue = cifar.get_cifar_queues(
        '', _rand_data=True)