def test_BuildBatch():
    """Check BuildBatch built with the ``input``/``output`` column API.

    Three samples are batched with batch size 2, which must yield two
    batches. The first batch must be an (inputs, outputs) pair: three input
    columns and one output column. Each column must equal what the
    standalone ``nb.build_*_batch`` helper produces for the first two
    samples.

    NOTE(review): a second function named ``test_BuildBatch`` (using the
    ``.by`` API) is defined later in this file. The later ``def`` rebinds
    the name at import time, so pytest collects only that one and this body
    never runs. Confirm which BuildBatch API is current, then rename or
    delete this test.
    """
    head = 2  # batch size, and the number of samples expected in batches[0]

    numbers = [4.1, 3.2, 1.1]
    vectors = [np.array([1, 2, 3]), np.array([2, 3, 4]), np.array([3, 4, 5])]
    images = [np.zeros((5, 3)), np.ones((5, 3)), np.ones((5, 3))]
    class_ids = [1, 2, 1]

    # Columns 0-2 become inputs, column 3 becomes a one-hot output (3 classes).
    builder = nb.BuildBatch(head, prefetch=0)
    builder = builder.input(0, 'number', float)
    builder = builder.input(1, 'vector', np.uint8)
    builder = builder.input(2, 'image', np.uint8, False)
    builder = builder.output(3, 'one_hot', 'uint8', 3)

    rows = zip(numbers, vectors, images, class_ids)
    batches = rows >> builder >> Collect()
    assert len(batches) == 2

    first = batches[0]
    assert len(first) == 2, 'Expect inputs and outputs'
    inputs, outputs = first
    assert len(inputs) == 3, 'Expect three input columns in batch'
    assert len(outputs) == 1, 'Expect one output column in batch'

    expected_inputs = [
        nb.build_number_batch(numbers[:head], float),
        nb.build_vector_batch(vectors[:head], 'uint8'),
        nb.build_image_batch(images[:head], 'uint8'),
    ]
    for idx, want in enumerate(expected_inputs):
        assert np.array_equal(inputs[idx], want)

    want_onehot = nb.build_one_hot_batch(class_ids[:head], 'uint8', 3)
    assert np.array_equal(outputs[0], want_onehot)
def test_build_image_batch_exceptions():
    """Check that ``nb.build_image_batch`` rejects malformed image lists.

    Each case calls the builder with dtype 'uint8'. It must raise
    ValueError, and the error message must start with the listed prefix.
    The cases are:
    - an empty list;
    - two images of shape (3, 10, 5), expected to be rejected with
      'Channel not at last axis';
    - two images of different shapes, (10, 5) and (15, 5).
    """
    bad_inputs = [
        ([], 'No images '),
        ([np.zeros((3, 10, 5)), np.ones((3, 10, 5))],
         'Channel not at last axis'),
        ([np.zeros((10, 5)), np.ones((15, 5))],
         'Images vary in shape'),
    ]
    for imgs, prefix in bad_inputs:
        with pytest.raises(ValueError) as excinfo:
            nb.build_image_batch(imgs, 'uint8')
        # startswith (not a regex match) mirrors how the messages are pinned.
        assert str(excinfo.value).startswith(prefix)
def test_BuildBatch():
    """Check BuildBatch built with the per-column ``by`` API.

    Three samples are batched with batch size 2, which must yield two
    batches. The first batch must be a flat sequence of four columns (number,
    vector, image, one-hot). Each column must equal what the matching
    standalone ``nb.build_*_batch`` helper produces for the first two
    samples.
    """
    head = 2  # batch size, and the number of samples expected in batches[0]

    class_ids = [1, 2, 1]
    images = [np.zeros((5, 3)), np.ones((5, 3)), np.ones((5, 3))]
    vectors = [np.array([1, 2, 3]), np.array([2, 3, 4]), np.array([3, 4, 5])]
    numbers = [4.1, 3.2, 1.1]

    pipeline_step = nb.BuildBatch(head, prefetch=0)
    pipeline_step = pipeline_step.by(0, 'number', float)
    pipeline_step = pipeline_step.by(1, 'vector', np.uint8)
    pipeline_step = pipeline_step.by(2, 'image', np.uint8, False)
    pipeline_step = pipeline_step.by(3, 'one_hot', 'uint8', 3)

    records = zip(numbers, vectors, images, class_ids)
    batches = records >> pipeline_step >> Collect()
    assert len(batches) == 2

    first = batches[0]
    assert len(first) == 4, 'Expect four columns in batch'

    # Expected column contents, in the same order as the .by(...) calls above.
    expected_columns = [
        nb.build_number_batch(numbers[:head], float),
        nb.build_vector_batch(vectors[:head], 'uint8'),
        nb.build_image_batch(images[:head], 'uint8'),
        nb.build_one_hot_batch(class_ids[:head], 'uint8', 3),
    ]
    for col, want in enumerate(expected_columns):
        assert np.array_equal(first[col], want)