Example #1
    def __init__(self):
        # Variable holding the description of the experiment
        self.description = "Training configuration file for the RGB version of the ResNet50 network."

        # System-dependent variables
        self._workers = 10
        self._multiprocessing = True

        # Variables for comet.ml
        self._project_name = "jpeg-deep"
        self._workspace = "classification_resnet50"

        # Network variables
        self._weights = None
        self._network = ResNet50()

        # Training variables
        self._epochs = 90
        self._batch_size = 32
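        # 1281167 and 50000 below are the sizes of the ImageNet (ILSVRC 2012)
        # train and validation splits.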
        self._steps_per_epoch = 1281167 // self._batch_size
        self._validation_steps = 50000 // self._batch_size
        self.optimizer_parameters = {"lr": 0.0125, "momentum": 0.9}
        self._optimizer = SGD(**self.optimizer_parameters)
        self._loss = categorical_crossentropy
        self._metrics = ['accuracy', 'top_k_categorical_accuracy']

        self.train_directory = join(environ["DATASET_PATH_TRAIN"], "train")
        self.validation_directory = join(environ["DATASET_PATH_VAL"],
                                         "validation")
        self.test_directory = join(environ["DATASET_PATH_TEST"], "validation")
        self.index_file = "data/imagenet_class_index.json"

        # Defining the transformations that will be applied to the inputs.
        self.train_transformations = [
            SmallestMaxSize(256),
            RandomCrop(224, 224),
            HorizontalFlip()
        ]

        self.validation_transformations = [
            SmallestMaxSize(256), CenterCrop(224, 224)
        ]

        self.test_transformations = [SmallestMaxSize(256)]

        # Keras callbacks
        self._callbacks = []

        self._train_generator = None
        self._validation_generator = None
        self._test_generator = None

        # Display helper built from the ImageNet class index
        self._displayer = ImageNetDisplayer(self.index_file)
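
The underscore-prefixed attributes suggest that the configuration class exposes them through read-only properties. The following is a minimal, hypothetical sketch of such accessors; the class name, the choice of properties shown, and the omitted body are assumptions for illustration, not part of the snippet above. Similar accessors would be assumed for the remaining private attributes (optimizer, loss, metrics, epochs, workers, etc.).

    class TrainingConfiguration(object):
        # __init__ as in Example #1 above (omitted here).

        @property
        def network(self):
            # Keras model to train, e.g. the ResNet50 instance built in __init__.
            return self._network

        @property
        def batch_size(self):
            return self._batch_size

        @property
        def steps_per_epoch(self):
            return self._steps_per_epoch

        @property
        def callbacks(self):
            return self._callbacks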
Example #2
    def __init__(self):
        # Variable holding the description of the experiment
        self.description = "Training configuration file for the VGG deconvolution network."

        # System-dependent variables
        self._workers = 5
        self._multiprocessing = True

        # Variables for comet.ml
        self._project_name = "jpeg_deep"
        self._workspace = "classification_dct_deconv"

        # Network variables
        self._weights = None
        self._network = VGG16_dct_deconv()

        # Training variables
        self._epochs = 180
        self._batch_size = 64
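        # As in Example #1, 1281167 and 50000 are the ImageNet train and
        # validation split sizes.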
        self._steps_per_epoch = 1281167 // self._batch_size
        self._validation_steps = 50000 // self._batch_size
        self.optimizer_parameters = {"lr": 0.0025, "momentum": 0.9}
        self._optimizer = SGD(**self.optimizer_parameters)
        self._loss = categorical_crossentropy
        self._metrics = ['accuracy', 'top_k_categorical_accuracy']

        self.train_directory = join(
            environ["DATASET_PATH_TRAIN"], "train")
        self.validation_directory = join(
            environ["DATASET_PATH_VAL"], "validation")
        self.test_directory = join(
            environ["DATASET_PATH_TEST"], "validation")
        self.index_file = "data/imagenet_class_index.json"

        # Defining the transformations that will be applied to the inputs.
        self.train_transformations = [
            SmallestMaxSize(256),
            RandomCrop(224, 224),
            HorizontalFlip()
        ]

        self.validation_transformations = [
            SmallestMaxSize(256), CenterCrop(224, 224)]

        self.test_transformations = [SmallestMaxSize(256)]

        # Keras callbacks
        self.reduce_lr_on_plateau = ReduceLROnPlateau(patience=5, verbose=1)
        self.terminate_on_nan = TerminateOnNaN()
        self.early_stopping = EarlyStopping(monitor='val_loss',
                                            min_delta=0,
                                            patience=11)

        self._callbacks = [self.reduce_lr_on_plateau,
                           self.terminate_on_nan, self.early_stopping]

        # Placeholders for the training, validation and test generators
        self._train_generator = None
        self._validation_generator = None
        self._test_generator = None

        # Display helper built from the ImageNet class index
        self._displayer = ImageNetDisplayer(self.index_file)
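
Neither snippet shows how the configuration object is consumed. Below is a rough, hypothetical driver sketch, assuming read-only accessors like those sketched after Example #1 exist for every private attribute and that the data generators are created elsewhere; none of the names below (`config`, `train_generator`, `validation_generator`) come from the repository itself.

    # Hypothetical usage sketch only; `config`, `train_generator` and
    # `validation_generator` are placeholders, not repository code.
    config = TrainingConfiguration()

    model = config.network
    model.compile(optimizer=config.optimizer,
                  loss=config.loss,
                  metrics=config.metrics)

    model.fit_generator(train_generator,
                        steps_per_epoch=config.steps_per_epoch,
                        epochs=config.epochs,
                        validation_data=validation_generator,
                        validation_steps=config.validation_steps,
                        callbacks=config.callbacks,
                        workers=config.workers,
                        use_multiprocessing=config.multiprocessing)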