def test_random_crop_dict_mismatch_raises():
    """Check that RandomCrop rejects each unsuitable ``shape_dict`` on construction.

    Every candidate is tried inside its own ``pytest.raises`` context, so the
    test fails as soon as one of them is accepted.
    """
    # NOTE(review): `inner` is defined elsewhere in this module; why each entry
    # counts as a mismatch depends on the data it holds -- confirm there.
    rejected_shape_dicts = (
        {'images': (1, 1)},       # keyed by 'images' instead of 'default'
        {'default': 10},          # bare int instead of a tuple
        {'default': (1, 1, 1)},   # three entries instead of two
        {'default': (10, 10)},    # presumably larger than inner's data; verify
    )
    for bad_shape_dict in rejected_shape_dicts:
        with pytest.raises(IteratorValidationError):
            RandomCrop(inner, shape_dict=bad_shape_dict)
def test_non5d_data_raises():
    """Check that Flip, Pad and RandomCrop raise when wrapping 4-d data.

    Each wrapper is constructed around an ``Undivided`` iterator whose
    'default' array has four dimensions, and is expected to raise
    ``IteratorValidationError``.
    """
    wrappers_and_kwargs = (
        (Flip, {'prob_dict': {'default': 1}}),
        (Pad, {'size_dict': {'default': 1}}),
        (RandomCrop, {'shape_dict': {'default': (1, 1)}}),
    )
    for wrapper, kwargs in wrappers_and_kwargs:
        with pytest.raises(IteratorValidationError):
            # A fresh 4-d array per wrapper, created inside the raises
            # context so the inner iterator is built under the same guard.
            four_dim_iter = Undivided(default=np.random.randn(2, 3, 1, 2))
            wrapper(four_dim_iter, **kwargs)
def test_random_crop():
    """Check the shapes RandomCrop yields for cropped and uncropped entries.

    'default' and 'secondary' are listed in ``shape_dict`` and are expected to
    come back with their third and fourth axes replaced by the requested
    sizes; 'targets' is not listed and is expected to come back unchanged.
    """
    default_data = np.random.randn(1, 3, 5, 5, 4)
    secondary_data = np.random.randn(1, 3, 4, 4, 1)
    target_data = np.random.randn(1, 3, 1)
    source = Undivided(default=default_data,
                       secondary=secondary_data,
                       targets=target_data)

    crop_shapes = {'default': (3, 3), 'secondary': (2, 2)}
    cropped = RandomCrop(source, shape_dict=crop_shapes)(default_handler)
    batch = next(cropped)

    # The crop must not add or drop any named entry.
    assert set(batch.keys()) == set(source.data.keys())

    expected_shapes = {
        'default': (1, 3, 3, 3, 4),
        'secondary': (1, 3, 2, 2, 1),
        'targets': (1, 3, 1),
    }
    for name, shape in expected_shapes.items():
        assert batch[name].shape == shape

    # Entries absent from shape_dict pass through with their values intact.
    assert np.allclose(batch['targets'], target_data)