def test_random_rotation():
    """Behavioral checks for ``transforms.RandomRotation``."""
    # A rotation probability outside [0, 1] must be rejected at construction.
    assertRaises(ValueError, transforms.RandomRotation, [-10, 10.], rotate_with_proba=1.1)
    assertRaises(ValueError, transforms.RandomRotation, [-10, 10.], rotate_with_proba=-0.3)
    # `forward` rejects integer images but preserves float image shapes,
    # both for a single CHW image and for an NCHW batch.
    rotate = transforms.RandomRotation([-10, 10.])
    assertRaises(TypeError, rotate, mx.nd.ones((3, 30, 60), dtype='uint8'))
    img = mx.nd.ones((3, 30, 60), dtype='float32')
    assert same(rotate(img).shape, (3, 30, 60))
    batch = mx.nd.ones((3, 3, 30, 60), dtype='float32')
    assert same(rotate(batch).shape, (3, 3, 30, 60))
    # With rotate_with_proba=0.0 the transform must act as the identity.
    identity = transforms.RandomRotation([-100., 100.], rotate_with_proba=0.0)
    sample = mx.nd.random_normal(shape=(3, 30, 60))
    assert_almost_equal(sample, identity(sample))
def test_transformer():
    """Smoke test: a pipeline of every transform runs end-to-end."""
    from mxnet.gluon.data.vision import transforms

    # Chain the spatial, color, tensor and normalization transforms once
    # and make sure the composition accepts a uint8 HWC image.
    stages = [
        transforms.Resize(300),
        transforms.Resize(300, keep_ratio=True),
        transforms.CenterCrop(256),
        transforms.RandomResizedCrop(224),
        transforms.RandomFlipLeftRight(),
        transforms.RandomColorJitter(0.1, 0.1, 0.1, 0.1),
        transforms.RandomBrightness(0.1),
        transforms.RandomContrast(0.1),
        transforms.RandomSaturation(0.1),
        transforms.RandomHue(0.1),
        transforms.RandomLighting(0.1),
        transforms.ToTensor(),
        transforms.RandomRotation([-10., 10.]),
        transforms.Normalize([0, 0, 0], [1, 1, 1]),
    ]
    pipeline = transforms.Compose(stages)
    pipeline(mx.nd.ones((245, 480, 3), dtype='uint8')).wait_to_read()
# Beispiel #3
def prepare_data_mxnet(batch_size=256, num_workers=8):
    """Build MNIST train/test DataLoaders with 32x32-resized, normalized inputs.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch for both loaders.
    num_workers : int
        Worker processes used by each DataLoader (was previously hard-coded
        to 8; kept as the default for backward compatibility).

    Returns
    -------
    (train_data, test_data) : tuple of gluon.data.DataLoader
        Training loader (shuffled) and test loader (unshuffled).
    """
    # Training adds a random rotation for augmentation; evaluation does not.
    train_transforms = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.RandomRotation((-30, 30)),
        transforms.Normalize(0.1307, 0.3081),
    ])
    test_transforms = transforms.Compose([
        transforms.Resize((32, 32)),
        transforms.ToTensor(),
        transforms.Normalize(0.1307, 0.3081),
    ])

    train_data = gluon.data.DataLoader(
        datasets.MNIST(train=True).transform_first(train_transforms),
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)
    test_data = gluon.data.DataLoader(
        datasets.MNIST(train=False).transform_first(test_transforms),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers)
    return train_data, test_data
# Beispiel #4

# ---- Data augmentation pipelines ----
transform_train = transforms.Compose([
    # Crop a random area (60%-100% of the image) and resize to the target size.
    transforms.RandomResizedCrop(opt.crop_size, scale=(0.6, 1.)),
    # Random horizontal and vertical flips.
    transforms.RandomFlipLeftRight(),
    transforms.RandomFlipTopBottom(),
    # Jitter brightness, contrast and saturation.
    transforms.RandomColorJitter(brightness=0.9, contrast=0.9, saturation=0.9),
    # HWC -> CHW and rescale pixel values from [0, 255] to [0, 1].
    transforms.ToTensor(),
    # Random rotation within +/-90 degrees (zoom_in semantics per transform docs).
    transforms.RandomRotation(angle_limits=(-90, 90), zoom_in=True),
    # Per-channel normalization with statistics computed over the dataset.
    transforms.Normalize([11.663384, 10.260227,  7.65015 ], [21.421959, 18.044296, 15.494861])
])

# Evaluation pipeline: deterministic resize + the same normalization.
transform_test = transforms.Compose([
    transforms.Resize(opt.crop_size),
    transforms.ToTensor(),
    transforms.Normalize([11.663384, 10.260227,  7.65015 ], [21.421959, 18.044296, 15.494861])
])

# ---- Dataset and DataLoader ----
from GalaxyZoo.src.GZooDataset import *

dataset_train = GZooData(root=os.path.join(opt.root, 'GalaxyZoo'), transform=transform_train)
datagen_train = gluon.data.DataLoader(
    dataset_train,
    batch_size=opt.batch_size,
    shuffle=True,
    last_batch='rollover',
    pin_memory=True,
    num_workers=16,
)