示例#1
0
def get_params():
    """Assemble the parameter objects built from the parsed arguments.

    The arguments come from get_args() called with 'Image' as the only image
    column. The transformation parameters are created with samplewise_mean
    enabled and do not depend on the parsed arguments.

    Returns:
        A 4-tuple of (input parameters, training parameters,
        image generation parameters, transformation parameters).
    """
    parsed_args = get_args(image_cols=['Image'])

    #Tuple elements are evaluated left to right, i.e. in the order they are returned
    return (
        InputParameters(parsed_args),
        TrainingParameters(parsed_args),
        ImageGenerationParameters(parsed_args),
        ImageDataTransformation.Parameters(samplewise_mean=True),
    )
示例#2
0
    def test_transform_with_featurewise_mean_fit_not_called(self):
        """Expect transform() to raise ValueError when featurewise_mean is enabled and fit() was not called."""
        #Arrange
        examples = TestImageDataTransformation.get_featurewise_mean_examples()
        transformer = ImageDataTransformation(
            parameters = ImageDataTransformation.Parameters(featurewise_mean = True))

        #Act & Assert: fit() is deliberately skipped before transform()
        self.assertRaises(ValueError, transformer.transform, examples)
示例#3
0
    def test_transform_affine(self):
        """Run the transform_affine() check once for each affine transformation setting."""
        #One single-parameter configuration per affine transformation, in execution order
        affine_settings = (
            {'rotation_range': 20},         #Rotation transformation
            {'shear_range': 10},            #Shear transformation
            {'zoom_range': 0.2},            #Zoom transformation
            {'width_shift_range': 0.2},     #Width translation
            {'height_shift_range': 0.2},    #Height translation
        )

        for setting in affine_settings:
            self.transform_affine(ImageDataTransformation.Parameters(**setting))
示例#4
0
    def transform_featurewise_std_normalization(self, featurewise_std_normalization, images, result):
        """Fit and transform images, then compare their standard deviation along axis 0 with result.

        Arguments:
            featurewise_std_normalization: Value passed to the transformation parameter of the same name.
            images: Images used both to fit the transformation and as its input.
            result: Expected standard deviation of the transformed images along axis 0.
        """
        #Transformation object
        parameters = ImageDataTransformation.Parameters(featurewise_std_normalization = featurewise_std_normalization)
        transformation = ImageDataTransformation(parameters = parameters)

        #Fit and perform transformation
        transformation.fit(images)
        transformed_images = transformation.transform(images)
        transformed_std = transformed_images.std(axis = 0)

        #Assert
        #The message makes a mismatch report the compared values instead of "False is not true"
        self.assertTrue(
                np.allclose(transformed_std, result),
                "Transformed std: {} expected: {}".format(transformed_std, result))
示例#5
0
    def transform_horizontal_flip(self, horizontal_flip, images, results):
        """Transform images with horizontal_flip_prob fixed at 1.0 and compare the output with results.

        Arguments:
            horizontal_flip: Value passed as the horizontal_flip transformation parameter.
            images: Input handed to transform().
            results: Expected output; it must equal the transformed images exactly.
        """
        #Build the transformer
        transformer = ImageDataTransformation(
            parameters = ImageDataTransformation.Parameters(
                horizontal_flip = horizontal_flip,
                horizontal_flip_prob = 1.0))

        #Act
        actual = transformer.transform(images)

        #Assert
        is_match = np.array_equal(actual, results)
        failure_message = "transformed_images: {} != expected: {}".format(actual, results)
        self.assertTrue(is_match, failure_message)
示例#6
0
    def transform_samplewise_mean(self, samplewise_mean, images, results):
        """Transform images with the given samplewise_mean setting and compare the output with results.

        Arguments:
            samplewise_mean: Value passed as the samplewise_mean transformation parameter.
            images: Input handed to transform().
            results: Expected output; it must equal the transformed images exactly.
        """
        #Build the transformer
        transformer = ImageDataTransformation(
            parameters = ImageDataTransformation.Parameters(samplewise_mean = samplewise_mean))

        #Act
        actual = transformer.transform(images)

        #Assert
        is_match = np.array_equal(actual, results)
        failure_message = "transformed_images: {} != expected: {}".format(actual, results)
        self.assertTrue(is_match, failure_message)
示例#7
0
def get_augmentation_executor():
    """Create an ImageAugmentation object configured with six augmentation instances.

    Returns:
        The ImageAugmentation built from, in order: a horizontal flip instance
        producing 1 output image, then rotation, zoom, shear, width shift and
        height shift instances producing 5 output images each.
    """
    #(transformation parameter overrides, number of output images), in instance order
    augmentation_specs = (
        #Horizontal flip
        ({'horizontal_flip': True, 'horizontal_flip_prob': 1.0}, 1),
        #Rotation
        ({'rotation_range': 20}, 5),
        #Zoom
        ({'zoom_range': 0.25}, 5),
        #Shear
        ({'shear_range': 15}, 5),
        #Width shift
        ({'width_shift_range': 0.25}, 5),
        #Height shift
        ({'height_shift_range': 0.20}, 5),
    )

    #Augmentation instances
    instances = [
        ImageAugmentation.Instance(
            ImageDataTransformation.Parameters(**overrides),
            num_output_images=output_count)
        for overrides, output_count in augmentation_specs
    ]

    return ImageAugmentation(instances)
示例#8
0
    def transform_featurewise_mean(self, featurewise_mean, images, result):
        """Fit and transform images, then compare the sum of the output along axis 0 with result.

        Arguments:
            featurewise_mean: Value passed as the featurewise_mean transformation parameter.
            images: Images used both to fit the transformation and as its input.
            result: Expected sum of the transformed images along axis 0; it must match exactly.
        """
        #Build the transformer
        transformer = ImageDataTransformation(
            parameters = ImageDataTransformation.Parameters(featurewise_mean = featurewise_mean))

        #Fit first, then transform the same images and sum them along axis 0
        transformer.fit(images)
        summed = transformer.transform(images).sum(axis = 0)

        #Assert
        is_match = np.array_equal(summed, result)
        failure_message = "Sum images: {} expected: {}".format(summed, result)
        self.assertTrue(is_match, failure_message)
示例#9
0
    def transform_samplewise_std_normalization(self, samplewise_std_normalization, images):
        """Transform images and check the sum of their per-sample standard deviations.

        When samplewise_std_normalization is truthy the sum must equal 2 to
        2 decimal places; otherwise the sum must not be exactly 2.

        Arguments:
            samplewise_std_normalization: Value passed to the transformation parameter of the same name.
            images: Input handed to transform().
        """
        #Build the transformer
        transformer = ImageDataTransformation(
            parameters = ImageDataTransformation.Parameters(
                samplewise_std_normalization = samplewise_std_normalization))

        #Transform
        transformed = transformer.transform(images)

        #Standard deviation of each sample, reduced over axes 1, 2 and 3
        per_sample_std = np.std(transformed, axis = (1, 2, 3))

        #Assert
        if not samplewise_std_normalization:
            #NOTE(review): exact float comparison, unlike the tolerance check below - confirm this is intended
            self.assertNotEqual(np.sum(per_sample_std), 2.)
            return

        #Presumably the callers supply two images, each normalized to unit deviation - TODO confirm
        self.assertAlmostEqual(
            np.sum(per_sample_std),
            2.,
            places = 2,
            msg = "standard_deviations: {} != expected: 2.".format(per_sample_std))
示例#10
0
def parse_args():
    """Define the command line options of the training script and parse them.

    Returns:
        The object returned by the parser's parse_args() call.
    """
    parser = ArgumentParser(description = 'It trains a siamese network for whale identification.')

    #Names listed in the help text of --transformations
    transformation_options = ImageDataTransformation.Parameters().__dict__.keys()

    #(flags, keyword arguments) of every option, in registration order
    argument_specs = [
        (('-m', '--model_name'), dict(
            required = True,
            help = 'It specifies the name of the model to train.')),
        (('-d', '--dataset_location'), dict(
            required = True, type = Path,
            help = 'It specifies the input dataset location.')),
        (('--image_cols',), dict(
            required = True, nargs = '+',
            help = 'It specifies the names of the image column in the dataframe.')),
        (('--image_transform_cols',), dict(
            nargs = '+',
            help = 'It specifies the names of the image column in the dataframe that are to be transformed.')),
        (('--label_col',), dict(
            required = True,
            help = 'It specifies the names of the label column.')),
        (('-b', '--batch_size'), dict(
            default = 128, type = int,
            help = 'It specifies the training batch size.')),
        (('-c', '--image_cache_size'), dict(
            default = 512, type = int,
            help = 'It specifies the image cache size.')),
        (('-s', '--validation_split'), dict(
            type = float,
            help = 'It specifies the validation split on the training dataset. It must be a float between 0 and 1')),
        (('-r', '--learning_rate'), dict(
            type = float,
            help = 'It specifies the learning rate of the optimization algorithm. It must be a float between 0 and 1')),
        (('-t', '--transformations'), dict(
            nargs = '+', default = [],
            type = kv_str_to_tuple,
            help = 'It specifies transformation parameters. Options: {}'
                        .format(transformation_options))),
        (('-x', '--num_fit_images'), dict(
            default = 1000, type = int,
            help = 'It specifies the number of images to send to fit()')),
        (('--epoch_id',), dict(
            default = 0, type = int,
            help = 'It specifies the start epoch id.')),
        (('--batch_id',), dict(
            default = 0, type = int,
            help = 'It specifies the start batch id.')),
        (('-e', '--number_of_epochs'), dict(
            type = int, default = 1,
            help = 'It specifies the number of epochs to train per input set.')),
        (('--input_shape',), dict(
            default = [224, 224, 3],
            type = int, nargs = 3,
            help = 'It specifies the shape of the image input.')),
        (('--number_prediction_steps',), dict(
            default = 2,
            type = int,
            help = 'It specifies the number of prediction steps to evaluate trained model.')),
        (('--checkpoint_batch_interval',), dict(
            default = 1,
            type = int,
            help = 'It specifies the number of batches after which to take a checkpoint.')),
        (('--training_method',), dict(
            default = TrainingMethod.TRAIN_ON_BATCH,
            type = TrainingMethod, choices = list(TrainingMethod),
            help = 'It specifies the training method to use')),
        (('-p', '--dropbox_parameters'), dict(
            nargs = 2,
            help = 'It specifies dropbox parameters required to upload the checkpoints.')),
        (('-l', '--log_to_console'), dict(
            action = 'store_true', default = False,
            help = 'It enables logging to console')),
    ]

    #Register the options in the listed order
    for flags, options in argument_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
#Unittest
import unittest as ut
from unittest.mock import MagicMock

#Constants
from common import constants
from common import ut_constants

#Image augmentation
from operation.transform import ImageDataTransformation
from operation.augmentation import ImageAugmentation

#Data processing
import numpy as np

#Horizontal flip augmentation: number of output images and transformation parameters
#NOTE: the 'tranformation' spelling is kept since code outside this view may reference the name
hflip_num_output_images = 1
hflip_tranformation_params = ImageDataTransformation.Parameters(horizontal_flip=True)

#Shear augmentation: number of output images and transformation parameters
shear_num_output_images = 3
shear_transformation_params = ImageDataTransformation.Parameters(shear_range=10)


class TestImageAugmentationInstance(ut.TestCase):
    def test_init_invalid_params(self):
        """Expect ImageAugmentation.Instance(None, 1) to raise ValueError."""
        self.assertRaises(ValueError, ImageAugmentation.Instance, None, 1)

    def test_init(self):
        #Arrange
        augmentation_instance = ImageAugmentation.Instance(