def test_random_transforms():
    """Check that RandomApply runs its wrapped pipeline roughly half the time.

    The wrapped pipeline resizes a 10x10 image to 300 and then random-crops
    it to 224, so a leading output dimension of 224 marks a trial in which
    the pipeline ran. The observed hit ratio over 1000 trials is compared
    against 0.5.

    NOTE(review): another ``test_random_transforms`` is defined later in this
    file and rebinds this name, so pytest never collects this version.
    Rename or remove one of the two definitions.
    """
    from mxnet.gluon.data.vision import transforms

    resize_then_crop = transforms.Compose(
        [transforms.Resize(300), transforms.RandomResizedCrop(224)]
    )
    pipeline = transforms.Compose([transforms.RandomApply(resize_then_crop, 0.5)])

    image = mx.nd.ones((10, 10, 3), dtype='uint8')
    trials = 1000
    # Each comparison yields a bool; summing them counts the trials where the
    # resize/crop pipeline was applied.
    hits = sum(pipeline(image).shape[0] == 224 for _ in range(trials))
    # Third positional argument is presumably rtol in mxnet's
    # assert_almost_equal signature — confirm against mxnet.test_utils.
    assert_almost_equal(hits / float(trials), 0.5, 0.1)
def test_random_transforms():
    """Check that RandomApply invokes its wrapped callable with probability ``prob``.

    An identity function that counts its own invocations is wrapped in
    ``RandomApply``. After ``iteration`` calls of the pipeline, the number of
    invocations must be within 10% (relative) of ``iteration * prob``.

    NOTE(review): an earlier function in this file has the same name; this
    definition rebinds it, so only this version is collected by pytest.
    """
    from mxnet.gluon.data.vision import transforms

    prob = 0.5
    iteration = 10000
    counter = 0

    def transform_fn(x):
        # Identity transform; it only records how often RandomApply chose to run it.
        nonlocal counter
        counter += 1
        return x

    transform = transforms.Compose([transforms.RandomApply(transform_fn, prob)])

    img = mx.np.ones((10, 10, 3), dtype='uint8')
    for _ in range(iteration):
        # Output is irrelevant here; only the side effect on ``counter`` is checked.
        transform(img)
    # Derive the expected count from the test parameters rather than a
    # hard-coded literal, and name pytest.approx's relative tolerance explicitly.
    assert counter == pytest.approx(iteration * prob, rel=0.1)