Exemplo n.º 1
0
def test_random_crop_on_list(patch_size, n_inputs):
    '''test tf.data compatible random crop function on a list.

    Builds ``n_inputs`` inputs ``key * img`` for ``key`` in ``1..n_inputs``,
    where ``img`` is an ``arange`` volume twice as large as ``patch_size``
    along every axis. Then checks that the cropper:

    - returns one patch per input,
    - produces patches of shape ``patch_size``,
    - crops every input at the same location, and
    - does not return the very first patch again on later calls.

    NOTE: this test was previously documented as expected to fail because
    there is no extra handling of lists/tuples/dicts -- TODO confirm whether
    that still holds and whether an xfail marker is applied to it.

    '''
    tf.random.set_seed(17)
    img_shape = tuple(2 * val for val in patch_size)
    img = np.arange(np.prod(img_shape)).reshape(img_shape)
    # Compare shapes as tuples so that a list/array ``patch_size`` works too
    # (a tuple never compares equal to a list).
    expected_shape = tuple(patch_size)

    cropper = random_crop(patch_size)

    input_list = [key * img for key in range(1, n_inputs + 1)]
    first_patch = [vals.numpy() for vals in cropper(input_list)]
    assert len(first_patch) == n_inputs

    for _ in range(10):
        patch = [vals.numpy() for vals in cropper(input_list)]

        assert len(patch) == n_inputs

        # check if patches have the correct shape
        for vals in patch:
            assert vals.ndim == img.ndim
            assert vals.shape == expected_shape

        # check if all inputs were cropped in the same location: input ``k``
        # is ``k * img``, so patch ``i`` scaled by ``j`` must equal patch
        # ``j`` scaled by ``i``.
        for (ii, first), (jj, second) in pairwise(enumerate(patch, start=1)):
            assert np.all(first * jj == second * ii)

        # make sure we're not drawing the same patch over and over
        for vals, first_vals in zip(patch, first_patch):
            assert not np.all(vals == first_vals)
Exemplo n.º 2
0
def test_random_crop_on_dict(patch_size, n_inputs):
    '''test tf.data compatible random crop function on an input_dict.

    Builds a dict mapping ``key`` to ``key * img`` for ``key`` in
    ``1..n_inputs``, where ``img`` is an ``arange`` volume twice as large as
    ``patch_size`` along every axis. Then checks that the cropper:

    - returns a patch under exactly the same keys,
    - produces patches of shape ``patch_size``,
    - crops every entry at the same location, and
    - does not return the very first patch again on later calls.

    '''
    tf.random.set_seed(17)
    img_shape = tuple(2 * val for val in patch_size)
    img = np.arange(np.prod(img_shape)).reshape(img_shape)
    # Compare shapes as tuples so that a list/array ``patch_size`` works too
    # (a tuple never compares equal to a list).
    expected_shape = tuple(patch_size)

    cropper = random_crop(patch_size)

    input_dict = {key: key * img for key in range(1, n_inputs + 1)}
    first_patch = {
        key: vals.numpy() for key, vals in cropper(input_dict).items()
    }

    for _ in range(10):
        patch = {key: vals.numpy() for key, vals in cropper(input_dict).items()}

        # the output must carry exactly the input keys
        assert set(patch) == set(input_dict)
        assert len(patch) == n_inputs

        # check if patches have the correct shape
        for vals in patch.values():
            assert vals.ndim == img.ndim
            assert vals.shape == expected_shape

        # check if all inputs were cropped in the same location: entry ``k``
        # is ``k * img``, so patch ``a`` scaled by ``b`` must equal patch
        # ``b`` scaled by ``a``.
        for first_key, second_key in pairwise(patch):
            assert np.all(patch[first_key] * second_key ==
                          patch[second_key] * first_key)

        # make sure we're not drawing the same patch over and over
        for key, vals in patch.items():
            assert not np.all(vals == first_patch[key])
Exemplo n.º 3
0
def test_mismatching_shapes(shapes):
    '''Check that cropping inputs whose shapes disagree raises.

    One all-ones array is built per entry of ``shapes`` (presumably supplied
    by a parametrize fixture). Cropping them together with a fixed
    ``(13, 13, 1)`` patch size must raise ``tf.errors.InvalidArgumentError``.

    '''
    crop_shape = (13, 13, 1)
    arrays = list(map(np.ones, shapes))
    crop_fn = random_crop(crop_shape)

    with pytest.raises(tf.errors.InvalidArgumentError):
        crop_fn(arrays)
Exemplo n.º 4
0
def test_random_crop_flexible(shapes, patch_size):
    '''test cropping with flexible channels.

    Entries of ``patch_size`` equal to -1 mark flexible dimensions: there,
    the patch must keep the full extent of the corresponding input. Every
    other dimension must match ``patch_size`` exactly.

    '''
    inputs = [np.ones(shape) for shape in shapes]

    cropper = random_crop(patch_size)
    patches = cropper(inputs)

    patch_size = np.asarray(patch_size)
    is_flexible = patch_size == -1
    n_dims = len(patch_size)

    # one patch per input
    assert len(patches) == len(inputs)

    for patch, inp in zip(patches, inputs):
        # flexible dims keep the input's extent, fixed dims use patch_size
        expected = np.where(is_flexible, inp.shape[:n_dims], patch_size)
        actual = [patch.shape[dim] for dim in range(n_dims)]
        assert actual == expected.tolist(), (actual, expected.tolist())