def test_random_transforms():
    """Check that RandomApply(p=0.5) runs its wrapped pipeline about half the time.

    The wrapped pipeline resizes the 10x10 input to 300 and then random-crops
    to 224; the test treats an output whose first dimension is 224 as evidence
    that the pipeline was applied on that call.
    """
    from mxnet.gluon.data.vision import transforms

    resize_then_crop = transforms.Compose([
        transforms.Resize(300),
        transforms.RandomResizedCrop(224),
    ])
    maybe_apply = transforms.Compose(
        [transforms.RandomApply(resize_then_crop, 0.5)]
    )
    image = mx.nd.ones((10, 10, 3), dtype='uint8')
    trials = 1000

    # One call per trial, in order; each True (pipeline applied) counts as 1.
    applied = sum(
        maybe_apply(image).shape[0] == 224 for _ in range(trials)
    )

    # NOTE(review): third positional argument is presumably the relative
    # tolerance (mxnet.test_utils.assert_almost_equal) — confirm the import.
    assert_almost_equal(applied / float(trials), 0.5, 0.1)
def test_random_transforms():
    """Check that RandomApply with probability p invokes its transform ~p of the time.

    A counting identity function is wrapped in RandomApply and called
    ``iteration`` times; the number of invocations must be within 10% of
    ``iteration * p``.
    """
    from mxnet.gluon.data.vision import transforms

    p = 0.5
    iteration = 10000
    counter = 0

    def transform_fn(x):
        # Identity transform that records how many times it was invoked.
        nonlocal counter
        counter += 1
        return x

    transform = transforms.Compose([transforms.RandomApply(transform_fn, p)])
    img = mx.np.ones((10, 10, 3), dtype='uint8')
    for _ in range(iteration):
        transform(img)

    # Expected count is derived from iteration and p so the two cannot drift
    # apart; rel=1e-1 allows a 10% relative deviation.
    assert counter == pytest.approx(iteration * p, rel=1e-1)