def build_image_batch(images, dtype, channelfirst=False):
    """
    Return batch of images.

    If images have no channel a channel axis is added. For channelfirst=True
    it will be added/moved to front otherwise the channel comes last.
    All images in batch will have a channel axis. Batch is of shape
    (n, c, h, w) or (n, h, w, c) depending on channelfirst, where n is
    the number of images in the batch.

    >>> from nutsml.datautil import shapestr
    >>> images = [np.zeros((2, 3)), np.ones((2, 3))]
    >>> batch = build_image_batch(images, 'uint8', True)
    >>> shapestr(batch)
    '2x1x2x3'
    >>> batch
    array([[[[0, 0, 0],
             [0, 0, 0]]],
    <BLANKLINE>
    <BLANKLINE>
           [[[1, 1, 1],
             [1, 1, 1]]]], dtype=uint8)

    :param numpy array images: Images to batch. Must be of shape (h,w,c)
           or (h,w). Gray-scale with channel is fine (h,w,1) and also
           alpha channel is fine (h,w,4). All images must have the same
           shape as the first one (after a channel axis is added).
    :param numpy data type dtype: Data type of batch, e.g. 'uint8'
    :param bool channelfirst: If True, channel is added/moved to front.
    :return: Image batch with shape (n, c, h, w) or (n, h, w, c).
    :rtype: np.array
    :raises ValueError: If images is empty, the first image is not 2- or
           3-dimensional, its channel axis does not appear to be last,
           or the images vary in shape.
    """

    def _targetshape(image):
        """Return (h, w, c) of image; a 2-D image is treated as c=1."""
        shape = image.shape
        if image.ndim == 2:
            return shape[0], shape[1], 1
        if image.ndim == 3:
            return shape
        raise ValueError('Image must be 2 or 3 dimensional: ' + str(shape))

    n = len(images)
    if not n:
        raise ValueError('No images to build batch!')
    h, w, c = _targetshape(images[0])  # shape of first(=all) images
    # Heuristic: a channel axis longer than height or width suggests the
    # channel is not the last axis of the input image.
    if c > w or c > h:
        raise ValueError('Channel not at last axis: ' + str((h, w, c)))

    # Allocate directly in the target dtype instead of float64 + astype():
    # avoids a second full-size copy and a float64 round trip that cannot
    # represent large int64 values exactly.
    shape = (n, c, h, w) if channelfirst else (n, h, w, c)
    batch = np.empty(shape, dtype=dtype)
    for i, image in enumerate(images):
        image = ni.add_channel(image, channelfirst)
        if image.shape != batch.shape[1:]:
            raise ValueError('Images vary in shape: ' + str(image.shape))
        batch[i] = image
    return batch
def test_add_channel():
    """ni.add_channel adds or moves the channel axis, rejecting 1-D/4-D."""
    # (input shape, channelfirst flag, expected output shape)
    valid_cases = [
        ((10, 20), True, (1, 10, 20)),
        ((10, 20), False, (10, 20, 1)),
        ((10, 20, 3), True, (3, 10, 20)),
        ((10, 20, 3), False, (10, 20, 3)),
    ]
    for in_shape, first, expected in valid_cases:
        result = ni.add_channel(np.ones(in_shape), first)
        assert result.shape == expected

    # 1-D and 4-D inputs must be rejected with a ValueError.
    for bad_shape in [(10,), (10, 20, 3, 1)]:
        with pytest.raises(ValueError) as ex:
            ni.add_channel(np.ones(bad_shape), True)
        assert str(ex.value).startswith('Image must be 2 or 3 channel!')