コード例 #1
0
ファイル: batcher.py プロジェクト: gsanou/nuts-ml
def build_image_batch(images, dtype, channelfirst=False):
    """
    Return batch of images.

    If images have no channel a channel axis is added. For channelfirst=True
    it will be added/moved to front otherwise the channel comes last.
    All images in batch will have a channel axis. Batch is of shape
    (n, c, h, w) or (n, h, w, c) depending on channelfirst, where n is
    the number of images in the batch.

    >>> from nutsml.datautil import shapestr
    >>> images = [np.zeros((2, 3)), np.ones((2, 3))]
    >>> batch = build_image_batch(images, 'uint8', True)
    >>> shapestr(batch)
    '2x1x2x3'

    >>> batch
    array([[[[0, 0, 0],
             [0, 0, 0]]],
    <BLANKLINE>
    <BLANKLINE>
           [[[1, 1, 1],
             [1, 1, 1]]]], dtype=uint8)

    :param iterable images: Images to batch, e.g. a list of numpy arrays.
           Each image must be of shape (h,w,c) or (h,w) with the channel
           as last axis. Gray-scale with channel (h,w,1) and images with
           alpha channel (h,w,4) are fine. All images must end up with
           the same shape once a channel axis is present.
    :param numpy data type dtype: Data type of batch, e.g. 'uint8'
    :param bool channelfirst: If True, channel is added/moved to front.
    :return: Image batch with shape (n, c, h, w) or (n, h, w, c).
    :rtype: np.array
    :raises ValueError: If there are no images, the first image is not
           2- or 3-dimensional, the first image has more channels than
           rows or columns (taken as channel not being the last axis),
           or the images vary in shape.
    """

    def _targetshape(image):
        # Shape (h, w, c) the image will have once a channel axis exists.
        shape = image.shape
        if image.ndim == 2:
            return shape[0], shape[1], 1
        if image.ndim == 3:
            return shape
        raise ValueError('Image must have 2 or 3 dimensions: ' + str(shape))

    # Materialize once so generators and other one-shot iterables work too.
    images = list(images)
    n = len(images)
    if not n:
        raise ValueError('No images to build batch!')
    h, w, c = _targetshape(images[0])  # shape of first(=all) images
    # Heuristic: more channels than rows/columns suggests the channel axis
    # is not the last one (e.g. an image already in (c, h, w) layout).
    if c > w or c > h:
        raise ValueError('Channel not at last axis: ' + str((h, w, c)))
    # Allocate directly in the target dtype instead of float64 + astype():
    # avoids a second full-size copy and keeps wide integer values exact.
    shape = (n, c, h, w) if channelfirst else (n, h, w, c)
    batch = np.empty(shape, dtype=dtype)
    for i, image in enumerate(images):
        image = ni.add_channel(image, channelfirst)
        if image.shape != batch.shape[1:]:
            raise ValueError('Images vary in shape: ' + str(image.shape))
        batch[i] = image
    return batch
コード例 #2
0
def test_add_channel():
    """Check ni.add_channel output shapes and rejection of other ranks."""
    # 2D image: expected shapes with the channel axis in front / at the back.
    image = np.ones((10, 20))
    assert ni.add_channel(image, True).shape == (1, 10, 20)
    assert ni.add_channel(image, False).shape == (10, 20, 1)

    # 3D image with trailing channel: expected channel-first / channel-last
    # shapes.
    image = np.ones((10, 20, 3))
    assert ni.add_channel(image, True).shape == (3, 10, 20)
    assert ni.add_channel(image, False).shape == (10, 20, 3)

    # Images of any other rank must be rejected. Only the call under test
    # sits inside pytest.raises, so an error from the setup cannot be
    # mistaken for the expected failure.
    for shape in [(10,), (10, 20, 3, 1)]:
        image = np.ones(shape)
        with pytest.raises(ValueError) as ex:
            ni.add_channel(image, True)
        assert str(ex.value).startswith('Image must be 2 or 3 channel!')