예제 #1
0
    def test_random_transforms(self):
        """Check the random image transforms and ``get_random_transform``.

        Covers three things:
        * every standalone ``image.random_*`` helper preserves the input shape;
        * a generator with every augmentation enabled yields non-zero
          parameters that differ between two seeds;
        * a default generator yields the identity parameters.
        """
        expected_shape = (2, 28, 28)
        x = np.random.random(expected_shape)

        # (transform, extra positional args) pairs, run in this exact order.
        shape_preserving_calls = (
            (image.random_rotation, (45,)),
            (image.random_shift, (1, 1)),
            (image.random_shear, (20,)),
            (image.random_zoom, ((5, 5),)),
            (image.random_channel_shift, (20,)),
        )
        for transform, extra_args in shape_preserving_calls:
            assert transform(x, *extra_args).shape == expected_shape

        # Test get_random_transform with predefined seed
        seed = 1
        augmentation_options = dict(
            rotation_range=90.,
            width_shift_range=0.1,
            height_shift_range=0.1,
            shear_range=0.5,
            zoom_range=0.2,
            channel_shift_range=0.1,
            brightness_range=(1, 5),
            horizontal_flip=True,
            vertical_flip=True,
        )
        generator = image.ImageDataGenerator(**augmentation_options)
        first = generator.get_random_transform(x.shape, seed)
        second = generator.get_random_transform(x.shape, seed * 2)

        # Each parameter must be active (non-zero) and must depend on the seed.
        randomized_keys = (
            'theta',
            'tx',
            'ty',
            'shear',
            'zx',
            'zy',
            'channel_shift_intensity',
            'brightness',
        )
        for key in randomized_keys:
            assert first[key] != 0
            assert first[key] != second[key]

        # Test get_random_transform without any randomness
        generator = image.ImageDataGenerator()
        identity = generator.get_random_transform(x.shape, seed)
        identity_values = (
            ('theta', 0),
            ('tx', 0),
            ('ty', 0),
            ('shear', 0),
            ('zx', 1),
            ('zy', 1),
        )
        for key, expected in identity_values:
            assert identity[key] == expected
        # The optional photometric parameters are disabled entirely.
        for key in ('channel_shift_intensity', 'brightness'):
            assert identity[key] is None
예제 #2
0
    def get_batch():
        """Build one augmented batch of images and one-hot labels.

        Reads images starting at the global ``current_index``, scales pixel
        values by 1/255, applies a random shift, a random rotation and a
        random horizontal flip, and advances ``current_index`` past every
        image it consumed or skipped.

        An image that fails to load or augment is skipped after its traceback
        is printed (best-effort behaviour kept from the original).

        Returns:
            A ``(B, labels)`` tuple. ``B`` has shape
            ``(n, IMAGE_SIZE, IMAGE_SIZE, 3)`` and ``labels`` is
            ``np_utils.to_categorical`` applied to the ``n`` class ids with
            ``num_classes``. ``n`` equals ``batch_size`` unless ``images`` is
            exhausted first, in which case only the filled rows are returned.
        """
        global current_index

        B = np.zeros(shape=(batch_size, IMAGE_SIZE, IMAGE_SIZE, 3))
        L = np.zeros(shape=(batch_size))
        num_images = len(images)

        # Start at row 0. Starting at 1 left the first row of every batch as
        # an all-zero image labelled 0.
        index = 0
        # Stop when the images run out: otherwise each further iteration
        # raises IndexError, is swallowed below, and the loop never ends.
        while index < batch_size and current_index < num_images:
            try:
                img = load_img(images[current_index],
                               target_size=(IMAGE_SIZE, IMAGE_SIZE))
                img = img_to_array(img)
                img /= 255.
                ## data augmentation
                # NOTE(review): if random_shift / random_rotation are the
                # Keras helpers, their default row/col/channel axes assume
                # channels-first input, while `B` is laid out channels-last.
                # Confirm the axis arguments against the imported functions.
                # random width and height shift
                img = random_shift(img, 0.2, 0.2)
                # random rotation
                img = random_rotation(img, 10)
                # random horizontal flip
                if np.random.random() < 0.5:
                    img = flip_axis(img, axis=1)
                # Vertical flip and cutout were commented out in the original
                # and are not applied.

                B[index] = img
                L[index] = labels[current_index]
                index += 1
            except Exception:
                # Skip the bad sample but keep going. Not a bare `except:`,
                # so KeyboardInterrupt / SystemExit still propagate.
                traceback.print_exc()
            # Advance the cursor whether the sample was used or skipped.
            current_index += 1

        if index < batch_size:
            # Ran out of images: drop the unfilled (all-zero) rows.
            B = B[:index]
            L = L[:index]
        return B, np_utils.to_categorical(L, num_classes)