def test_random_crop_dict_mismatch_raises():
    """Check that RandomCrop rejects each of several bad ``shape_dict`` values.

    Every entry below, paired with the module-level ``inner`` iterator, is
    expected to make the ``RandomCrop`` constructor raise
    ``IteratorValidationError``.
    """
    rejected_shape_dicts = (
        # NOTE(review): presumably 'images' is not a data name of `inner`
        # -- confirm against its definition.
        {'images': (1, 1)},
        # A bare int where the other tests pass a tuple.
        {'default': 10},
        # Three entries where the other tests pass two.
        {'default': (1, 1, 1)},
        # NOTE(review): presumably larger than the data held by `inner`
        # -- confirm.
        {'default': (10, 10)},
    )
    for bad_shapes in rejected_shape_dicts:
        with pytest.raises(IteratorValidationError):
            RandomCrop(inner, shape_dict=bad_shapes)
def test_non5d_data_raises():
    """Check that Flip, Pad and RandomCrop all reject 4-dimensional data.

    Each wrapper is given a fresh ``Undivided`` iterator holding an array of
    shape (2, 3, 1, 2) and must raise ``IteratorValidationError``.
    """
    wrappers_and_kwargs = (
        (Flip, {'prob_dict': {'default': 1}}),
        (Pad, {'size_dict': {'default': 1}}),
        (RandomCrop, {'shape_dict': {'default': (1, 1)}}),
    )
    for wrapper, kwargs in wrappers_and_kwargs:
        with pytest.raises(IteratorValidationError):
            # Build the iterator inside the context, one new array per case.
            four_dim_iter = Undivided(default=np.random.randn(2, 3, 1, 2))
            wrapper(four_dim_iter, **kwargs)
# Example #3
# 0
def test_random_crop():
    """Check the shapes RandomCrop produces for cropped and uncropped data.

    'default' (1, 3, 5, 5, 4) is cropped with (3, 3) and 'secondary'
    (1, 3, 4, 4, 1) with (2, 2); 'targets' has no entry in ``shape_dict``
    and must come through with unchanged shape and values.
    """
    primary = np.random.randn(1, 3, 5, 5, 4)
    auxiliary = np.random.randn(1, 3, 4, 4, 1)
    labels = np.random.randn(1, 3, 1)
    iterator = Undivided(default=primary, secondary=auxiliary, targets=labels)

    crop_shapes = {'default': (3, 3), 'secondary': (2, 2)}
    batches = RandomCrop(iterator, shape_dict=crop_shapes)(default_handler)
    batch = next(batches)

    # The wrapper must neither drop nor add data names.
    assert set(batch.keys()) == set(iterator.data.keys())

    expected_shapes = {
        'default': (1, 3, 3, 3, 4),
        'secondary': (1, 3, 2, 2, 1),
        'targets': (1, 3, 1),
    }
    for name, shape in expected_shapes.items():
        assert batch[name].shape == shape

    assert np.allclose(batch['targets'], labels)