Exemplo n.º 1
0
def test_build_vector_batch():
    """build_vector_batch stacks vectors row-wise and rejects an empty list.

    Two length-3 vectors must become a (2, 3) uint8 array holding the same
    values. An empty input must raise ValueError with a message that begins
    with 'No vectors '.
    """
    rows = [np.array([1, 2, 3]), np.array([2, 3, 4])]
    stacked = nb.build_vector_batch(rows, 'uint8')

    assert stacked.dtype == np.uint8
    wanted = np.array([[1, 2, 3], [2, 3, 4]], dtype='uint8')
    assert np.array_equal(stacked, wanted)

    # `match` is applied with re.search, so the anchored pattern is the
    # same check as str(exc).startswith('No vectors ').
    with pytest.raises(ValueError, match=r'^No vectors '):
        nb.build_vector_batch([], 'uint8')
Exemplo n.º 2
0
def test_BuildBatch():
    """Check BuildBatch with input/output columns, including the last batch.

    Three samples grouped by a batch size of 2 yield two batches: a full one
    with samples 0-1 and a trailing partial one with the remaining sample.
    Each batch is a pair (inputs, outputs). Every column must equal the
    matching nb.build_*_batch helper applied to the same slice of samples.
    """
    numbers = [4.1, 3.2, 1.1]
    vectors = [np.array([1, 2, 3]), np.array([2, 3, 4]), np.array([3, 4, 5])]
    images = [np.zeros((5, 3)), np.ones((5, 3)), np.ones((5, 3))]
    class_ids = [1, 2, 1]
    samples = zip(numbers, vectors, images, class_ids)

    build_batch = (nb.BuildBatch(2, prefetch=0)
                   .input(0, 'number', float)
                   .input(1, 'vector', np.uint8)
                   .input(2, 'image', np.uint8, False)
                   .output(3, 'one_hot', 'uint8', 3))
    batches = samples >> build_batch >> Collect()
    assert len(batches) == 2

    # Sample rows each batch should contain. The second batch is the short
    # tail; it was previously never verified.
    expected_rows = [slice(0, 2), slice(2, 3)]
    for batch, rows in zip(batches, expected_rows):
        assert len(batch) == 2, 'Expect inputs and outputs'
        ins, outs = batch
        assert len(ins) == 3, 'Expect three input columns in batch'
        assert len(outs) == 1, 'Expect one output column in batch'
        assert np.array_equal(ins[0],
                              nb.build_number_batch(numbers[rows], float))
        assert np.array_equal(ins[1],
                              nb.build_vector_batch(vectors[rows], 'uint8'))
        assert np.array_equal(ins[2],
                              nb.build_image_batch(images[rows], 'uint8'))
        assert np.array_equal(
            outs[0], nb.build_one_hot_batch(class_ids[rows], 'uint8', 3))
Exemplo n.º 3
0
def test_BuildBatch():
    """Check column-wise BuildBatch (``.by``), including the last batch.

    Three samples grouped by a batch size of 2 yield two batches: a full one
    with samples 0-1 and a trailing partial one with the remaining sample.
    Each batch is a sequence of four columns. Every column must equal the
    matching nb.build_*_batch helper applied to the same slice of samples.
    """
    numbers = [4.1, 3.2, 1.1]
    vectors = [np.array([1, 2, 3]), np.array([2, 3, 4]), np.array([3, 4, 5])]
    images = [np.zeros((5, 3)), np.ones((5, 3)), np.ones((5, 3))]
    class_ids = [1, 2, 1]
    samples = zip(numbers, vectors, images, class_ids)

    build_batch = (nb.BuildBatch(2, prefetch=0)
                   .by(0, 'number', float)
                   .by(1, 'vector', np.uint8)
                   .by(2, 'image', np.uint8, False)
                   .by(3, 'one_hot', 'uint8', 3))
    batches = samples >> build_batch >> Collect()
    assert len(batches) == 2

    # Sample rows each batch should contain. The second batch is the short
    # tail; it was previously never verified.
    expected_rows = [slice(0, 2), slice(2, 3)]
    for batch, rows in zip(batches, expected_rows):
        assert len(batch) == 4, 'Expect four columns in batch'
        assert np.array_equal(batch[0],
                              nb.build_number_batch(numbers[rows], float))
        assert np.array_equal(batch[1],
                              nb.build_vector_batch(vectors[rows], 'uint8'))
        assert np.array_equal(batch[2],
                              nb.build_image_batch(images[rows], 'uint8'))
        assert np.array_equal(
            batch[3], nb.build_one_hot_batch(class_ids[rows], 'uint8', 3))