def test_Mixup():
    """Check nb.Mixup applied to batches built from two number columns.

    With alpha=0.0 the batch must come back unchanged. With alpha=1.0 the
    test asserts that mixed values stay within the range of each original
    column and that inputs and outputs are shifted by the same amount.
    """
    numbers1 = [1, 2, 3]
    numbers2 = [4, 5, 6]
    # Materialized as a list so the same samples can feed two pipelines.
    samples = list(zip(numbers1, numbers2))
    build_batch = (nb.BuildBatch(3, prefetch=0)
                   .input(0, 'number', float)
                   .output(1, 'number', float))

    # no mixup (alpha=0.0), return original batch
    mixup = nb.Mixup(0.0)
    batches = samples >> build_batch >> mixup >> Collect()
    inputs, outputs = batches[0]
    assert list(inputs[0]) == numbers1
    assert list(outputs[0]) == numbers2

    # mixup with alpha=1.0
    mixup = nb.Mixup(1.0)
    batches = samples >> build_batch >> mixup >> Collect()
    # Use non-builtin names: the previous loop variable shadowed `input`.
    for batch_inputs, batch_outputs in batches:
        in_col, out_col = batch_inputs[0], batch_outputs[0]
        assert min(in_col) >= 1 and max(in_col) <= 3
        assert min(out_col) >= 4 and max(out_col) <= 6
        # Each output is its input + 3. The test therefore expects equal
        # shifts for the first element. This assumes Mixup applies the same
        # weights/permutation to inputs and outputs (TODO confirm in nb.Mixup).
        ri = in_col[0] - samples[0][0]
        ro = out_col[0] - samples[0][1]
        assert approx(ri, rel=1e-3) == ro
def test_BuildBatch_verbose():
    """A verbose BuildBatch prints one shape:dtype summary line per batch."""
    number_samples = [[1], [2], [3]]
    vector_samples = [(np.array([1, 2, 3]), 0),
                      (np.array([4, 5, 6]), 1),
                      (np.array([7, 8, 9]), 1)]

    # Single number column, batch size 2: a batch of 2 followed by a batch of 1.
    with Redirect() as captured:
        batcher = nb.BuildBatch(2, verbose=True).input(0, 'number', 'uint8')
        number_samples >> batcher >> Consume()
    assert captured.getvalue() == '[2:uint8]\n[1:uint8]\n'

    # Input plus output column: each line lists both column summaries.
    with Redirect() as captured:
        batcher = (nb.BuildBatch(2, verbose=True)
                   .input(0, 'vector', 'float32')
                   .output(1, 'one_hot', 'uint8', 2))
        vector_samples >> batcher >> Consume()
    want = '[[2x3:float32], [2x2:uint8]]\n[[1x3:float32], [1x2:uint8]]\n'
    assert captured.getvalue() == want
def test_BuildBatch():
    """BuildBatch with input()/output() columns yields (inputs, outputs) batches."""
    numbers = [4.1, 3.2, 1.1]
    vectors = [np.array([1, 2, 3]), np.array([2, 3, 4]), np.array([3, 4, 5])]
    images = [np.zeros((5, 3)), np.ones((5, 3)), np.ones((5, 3))]
    class_ids = [1, 2, 1]
    samples = zip(numbers, vectors, images, class_ids)

    build_batch = (nb.BuildBatch(2, prefetch=0)
                   .input(0, 'number', float)
                   .input(1, 'vector', np.uint8)
                   .input(2, 'image', np.uint8, False)
                   .output(3, 'one_hot', 'uint8', 3))
    batches = samples >> build_batch >> Collect()
    # Three samples with batch size 2: one full batch and one partial batch.
    assert len(batches) == 2

    first_batch = batches[0]
    assert len(first_batch) == 2, 'Expect inputs and outputs'
    in_cols, out_cols = first_batch
    assert len(in_cols) == 3, 'Expect three input columns in batch'
    assert len(out_cols) == 1, 'Expect one output column in batch'

    # Each input column must equal the matching standalone builder's result
    # for the first two samples.
    expected_inputs = (
        nb.build_number_batch(numbers[:2], float),
        nb.build_vector_batch(vectors[:2], 'uint8'),
        nb.build_image_batch(images[:2], 'uint8'),
    )
    for actual, expected in zip(in_cols, expected_inputs):
        assert np.array_equal(actual, expected)
    expected_one_hot = nb.build_one_hot_batch(class_ids[:2], 'uint8', 3)
    assert np.array_equal(out_cols[0], expected_one_hot)
def test_BuildBatch_exceptions():
    """An unknown column builder name must raise ValueError('Invalid builder...')."""
    numbers = [4.1, 3.2]
    class_ids = [1, 2]
    pairs = zip(numbers, class_ids)
    with pytest.raises(ValueError) as ex:
        batcher = (nb.BuildBatch(2, prefetch=0)
                   .by(0, 'number', float)
                   .by(1, 'invalid', 'uint8', 3))
        pairs >> batcher >> Collect()
    # Checked after the `with`: code after the raising call inside it never runs.
    message = str(ex.value)
    assert message.startswith('Invalid builder')
def test_BuildBatch_fmt():
    """The fmt callback restructures the built batch tuple before it is emitted."""
    numbers1 = [1, 2, 3]
    numbers2 = [4, 5, 6]

    def fmt(cols):
        # (c0, c1) -> ((c0, c1, c0), c1)
        return (cols[0], cols[1], cols[0]), cols[1]

    samples = zip(numbers1, numbers2)
    batcher = (nb.BuildBatch(3, prefetch=0, fmt=fmt)
               .by(0, 'number', float)
               .by(1, 'number', float))
    batches = samples >> batcher >> Collect()
    assert len(batches) == 1

    (a, b, c), d = batches[0]
    checks = ((a, numbers1), (b, numbers2), (c, numbers1), (d, numbers2))
    for column, expected in checks:
        assert list(column) == expected
def test_BuildBatch_by():
    """BuildBatch configured via by() yields one flat tuple of columns per batch.

    NOTE: this test was previously also named ``test_BuildBatch``. That second
    definition rebound the module-level name, so pytest never collected the
    input()/output() variant with the same name.
    """
    numbers = [4.1, 3.2, 1.1]
    vectors = [np.array([1, 2, 3]), np.array([2, 3, 4]), np.array([3, 4, 5])]
    images = [np.zeros((5, 3)), np.ones((5, 3)), np.ones((5, 3))]
    class_ids = [1, 2, 1]
    samples = zip(numbers, vectors, images, class_ids)
    build_batch = (nb.BuildBatch(2, prefetch=0)
                   .by(0, 'number', float)
                   .by(1, 'vector', np.uint8)
                   .by(2, 'image', np.uint8, False)
                   .by(3, 'one_hot', 'uint8', 3))
    batches = samples >> build_batch >> Collect()
    # Three samples with batch size 2: one full batch and one partial batch.
    assert len(batches) == 2
    batch = batches[0]
    assert len(batch) == 4, 'Expect four columns in batch'
    assert np.array_equal(batch[0], nb.build_number_batch(numbers[:2], float))
    assert np.array_equal(batch[1], nb.build_vector_batch(vectors[:2], 'uint8'))
    assert np.array_equal(batch[2], nb.build_image_batch(images[:2], 'uint8'))
    assert np.array_equal(batch[3],
                          nb.build_one_hot_batch(class_ids[:2], 'uint8', 3))
def test_BuildBatch_prefetch():
    """A BuildBatch with prefetch=1 still produces the expected batch contents."""
    samples = [[1], [2]]
    batcher = nb.BuildBatch(2, prefetch=1).input(0, 'number', 'uint8')
    collected = samples >> batcher >> Collect()
    first_input_column = collected[0][0]
    expected = np.array([1, 2], dtype='uint8')
    assert np.array_equal(first_input_column, expected)