Example #1
    def __init__(self, args):
        # Training hyperparameters
        self.learning_rate = args.learning_rate
        self.batch_size = args.batch_size
        self.epoch = args.epoch
        self.save_model_period = 1 # save model weights every N epochs
        # Training and validation dataset paths
        self.train_data_path = './data/train'
        self.val_data_path = './data/validation'
        # Where to save and load model weights (=checkpoints)
        self.checkpoints_dir = './checkpoints'
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        self.ckpt_save_name = 'classTemplate'
        # Where to save tensorboard summaries
        self.summaries_dir = './summaries'
        if not os.path.exists(self.summaries_dir):
            os.makedirs(self.summaries_dir)

        # Get training dataset as lists of image paths
        self.train_gt_data_list = get_filepaths_from_dir(self.train_data_path)
        if len(self.train_gt_data_list) == 0:
            raise ValueError("No training data found in folder {}".format(self.train_data_path))
        elif (len(self.train_gt_data_list) < self.batch_size):
            raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
                .format(self.batch_size, len(self.train_gt_data_list)))

        # Get validation dataset if provided
        self.has_val_data = True
        self.val_gt_data_list = get_filepaths_from_dir(self.val_data_path)
        if len(self.val_gt_data_list) == 0:
            print("No validation data found in {}, 20% of training data will be used as validation data".format(self.val_data_path))
            self.has_val_data = False
            self.validation_split = 0.2
        elif (len(self.val_gt_data_list) < self.batch_size):
            raise ValueError("Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
                .format(self.batch_size, len(self.val_gt_data_list)))
        else:
            print_("Number of validation data: {}\n".format(len(self.val_gt_data_list)), 'm')
            self.validation_split = 0.0
        
        self.train_labels = get_labels_from_dir(self.train_data_path)
        # Check class labels are the same
        if self.has_val_data:            
            self.val_labels = get_labels_from_dir(self.val_data_path)
            if self.train_labels != self.val_labels:
                if len(self.train_labels) != len(self.val_labels):
                    raise ValueError("{} and {} should have the same number of subdirectories ({}!={})"
                    .format(self.train_data_path, self.val_data_path, len(self.train_labels), len(self.val_labels)))
                raise ValueError("{} and {} should have the same subdirectory label names ({}!={})"
                    .format(self.train_data_path, self.val_data_path, self.train_labels, self.val_labels))
        
        # Compute and print training hyperparameters
        self.batch_per_epoch = int(np.ceil(len(self.train_gt_data_list) / float(self.batch_size)))
        self.max_steps = int(self.epoch * (self.batch_per_epoch))
        print_("Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
            .format(len(self.train_gt_data_list), self.batch_per_epoch, self.batch_size, self.epoch, self.max_steps), 'm')
        print("Class labels: {}".format(self.train_labels))
Example #2
    def load(self, sess, checkpoint_dir):
        ckpt_names = get_ckpt_list(checkpoint_dir)
        if not ckpt_names:  # list is empty
            print_("No checkpoints found in {}\n".format(checkpoint_dir), 'm')
            return False
        else:
            print_("Found checkpoints:\n", 'm')
            for name in ckpt_names:
                print("    {}".format(name))
            # Ask the user whether to start training from scratch or resume training from a specific checkpoint
            while True:
                mode = input(
                    'Start training from scratch (start) or resume training from a previous checkpoint (choose one of the above): '
                )
                if mode == 'start' or mode in ckpt_names:
                    break
                print("Answer should be 'start' or one of the following checkpoints: {}"
                      .format(ckpt_names))
            if mode == 'start':
                return False
            elif mode in ckpt_names:
                # Try to load the given intermediate checkpoint
                print_("Loading trained model...\n", 'm')
                self.saver.restore(sess, os.path.join(checkpoint_dir, mode))
                print_("...Checkpoint {} loaded\n".format(mode), 'm')
                return True
            else:
                raise ValueError(
                    "User input is neither 'start' nor a valid checkpoint")
Example #3
    def evaluate(self, test_data_path, weights):
        """Evaluate a trained model on the test dataset
        
        Args:
            test_data_path (str): path to directory containing images for testing
            weights (str): name of the tensorflow checkpoint (weights) to evaluate
        """
        test_data_list = get_filepaths_from_dir(test_data_path)
        if not test_data_list:
            raise ValueError(
                "No test data found in folder {}".format(test_data_path))
        elif (len(test_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of test data = {})"
                .format(self.batch_size, len(test_data_list)))
        self.is_exr = is_exr(test_data_list[0])

        # Get and create test dataset
        ds_test = self.get_data(test_data_list, self.batch_size, 1)
        for x, y in ds_test.take(1):  # take one batch from ds_test
            testX, testY = x, y
        print_("Number of test data: {}\n".format(len(test_data_list)), 'm')
        print("Input shape {}, target shape: {}".format(
            testX.shape, testY.shape))

        # Build model
        model = self.get_compiled_model(testX.shape[1:])

        # Load model weights
        print_("Loading trained model for testing...\n", 'm')
        model.load_weights(os.path.join(self.ckpt_dir,
                                        weights)).expect_partial()
        print_("...Checkpoint {} loaded\n".format(weights), 'm')

        # Test final model on this unseen dataset
        results = model.evaluate(ds_test)
        print("test loss, test acc:", results)
        print_("--------End of testing--------\n", 'm')
Example #4
    def train(self):
        # Build model
        self.model = mobilenet_transfer(len(self.train_labels))
        # Configure the model for training
        self.model.compile(optimizer=tf.keras.optimizers.Adam(),
                           loss='categorical_crossentropy',
                           metrics=['accuracy'])
        # Print current model layers
        # self.model.summary()

        # Set preprocessing function
        datagen = tf.keras.preprocessing.image.ImageDataGenerator(
            # scale pixels between -1 and 1, sample-wise
            preprocessing_function=tf.keras.applications.mobilenet.preprocess_input,
            validation_split=self.validation_split)
        # Get classification data
        train_generator = datagen.flow_from_directory(
            self.train_data_path,
            target_size=(224, 224),
            color_mode='rgb',
            batch_size=self.batch_size,
            class_mode='categorical',
            shuffle=True,
            subset='training')
        if self.has_val_data:
            validation_generator = datagen.flow_from_directory(
                self.val_data_path,
                target_size=(224, 224),
                color_mode='rgb',
                batch_size=self.batch_size,
                class_mode='categorical',
                shuffle=True)
        else:  # Generate a split of the training data as validation data
            validation_generator = datagen.flow_from_directory(
                self.train_data_path,  # subset from training data path
                target_size=(224, 224),
                color_mode='rgb',
                batch_size=self.batch_size,
                class_mode='categorical',
                shuffle=True,
                subset='validation')

        # Callback for creating Tensorboard summary
        summary_name = "classif_data{}_bch{}_ep{}".format(
            len(self.train_gt_data_list), self.batch_size, self.epoch)
        tensorboard_callback = tf.keras.callbacks.TensorBoard(
            log_dir=os.path.join(self.summaries_dir, summary_name))
        # Callback for saving models periodically
        class_labels_save = '_'.join(self.train_labels) + '.'
        # 'acc' is the training accuracy and 'val_acc' is the validation set accuracy
        self.ckpt_save_name = class_labels_save + self.ckpt_save_name + "-val_acc{val_acc:.2f}-acc{acc:.2f}-ep{epoch:04d}.h5"
        checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=os.path.join(self.checkpoints_dir, self.ckpt_save_name),
            save_weights_only=False,
            period=self.save_model_period,
            save_best_only=True,
            monitor='val_acc',
            mode='max')

        # Check if there is an intermediate trained model to load
        # Uncomment the following lines if you want to resume from a previously saved model
        # if not self.load_model():
        #     print_("Starting training from scratch\n", 'm')

        # Train the model
        fit_history = self.model.fit_generator(
            generator=train_generator,
            steps_per_epoch=train_generator.n // self.batch_size,
            validation_data=validation_generator,
            validation_steps=validation_generator.n // self.batch_size,
            epochs=self.epoch,
            callbacks=[checkpoint_callback, tensorboard_callback])

        print_("--------End of training--------\n", 'm')
Example #5
    def __init__(self, args):
        # Training hyperparameters
        self.learning_rate = args.learning_rate
        self.batch_size = args.batch_size
        self.epoch = args.epoch
        self.no_resume = args.no_resume
        # A random seed (!=None) allows you to reproduce your training results
        self.seed = args.seed
        if self.seed is not None:
            # Set all seeds necessary for deterministic training
            enable_deterministic_training(self.seed, args.no_gpu_patch)
        self.crop_size = 256
        self.n_levels = 3
        self.scale = 0.5
        self.channels = 3  # input / output channels
        # Training and validation dataset paths
        train_in_data_path = './data/train/input'
        train_gt_data_path = './data/train/groundtruth'
        val_in_data_path = './data/validation/input'
        val_gt_data_path = './data/validation/groundtruth'

        # Where to save and load model weights (=checkpoints)
        self.checkpoints_dir = './checkpoints'
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        self.ckpt_save_name = args.ckpt_save_name
        # Maximum number of recent checkpoint files to keep
        self.max_ckpts_to_keep = 50
        # In addition keep one checkpoint file for every N hours of training
        self.keep_ckpt_every_n_hours = 1
        # How often, in training steps, we save model checkpoints
        self.ckpts_save_freq = 1000
        # How often, in training steps, we print training losses to the console
        self.training_print_freq = 10

        # Where to save tensorboard summaries
        self.summaries_dir = './summaries'
        if not os.path.exists(self.summaries_dir):
            os.makedirs(self.summaries_dir)
        # How often, in training steps, we save tensorboard summaries
        self.summaries_save_freq = 10
        # How often, in secs, we flush the pending tensorboard summaries to disk
        self.summary_flush_secs = 30

        # Get training dataset as lists of image paths
        self.train_in_data_list = get_filepaths_from_dir(train_in_data_path)
        self.train_gt_data_list = get_filepaths_from_dir(train_gt_data_path)
        if not self.train_in_data_list or not self.train_gt_data_list:
            raise ValueError(
                "No training data found in folders {} or {}".format(
                    train_in_data_path, train_gt_data_path))
        elif len(self.train_in_data_list) != len(self.train_gt_data_list):
            raise ValueError(
                "{} ({} data) and {} ({} data) should have the same number of input data"
                .format(train_in_data_path, len(self.train_in_data_list),
                        train_gt_data_path, len(self.train_gt_data_list)))
        elif (len(self.train_in_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
                .format(self.batch_size, len(self.train_in_data_list)))
        self.is_exr = is_exr(self.train_in_data_list[0])

        # Get validation dataset if provided
        self.has_val_data = True
        self.val_in_data_list = get_filepaths_from_dir(val_in_data_path)
        self.val_gt_data_list = get_filepaths_from_dir(val_gt_data_path)
        if not self.val_in_data_list or not self.val_gt_data_list:
            print("No validation data found in {} or {}".format(
                val_in_data_path, val_gt_data_path))
            self.has_val_data = False
        elif len(self.val_in_data_list) != len(self.val_gt_data_list):
            raise ValueError(
                "{} ({} data) and {} ({} data) should have the same number of input data"
                .format(val_in_data_path, len(self.val_in_data_list),
                        val_gt_data_path, len(self.val_gt_data_list)))
        elif (len(self.val_in_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
                .format(self.batch_size, len(self.val_in_data_list)))
        else:
            val_is_exr = is_exr(self.val_in_data_list[0])
            if val_is_exr != self.is_exr:
                raise TypeError(
                    "Train and validation data should have the same file format"
                )
            print("Number of validation data: {}".format(
                len(self.val_in_data_list)))

        # Compute and print training hyperparameters
        batch_per_epoch = (len(self.train_in_data_list)) // self.batch_size
        self.max_steps = int(self.epoch * (batch_per_epoch))
        print_(
            "Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
            .format(len(self.train_in_data_list), batch_per_epoch,
                    self.batch_size, self.epoch, self.max_steps), 'm')
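The constructor above calls enable_deterministic_training when a seed is given. A minimal sketch of what such a helper could do (the handling of no_gpu_patch is an assumption):

import os
import random
import numpy as np
import tensorflow as tf

def enable_deterministic_training(seed, no_gpu_patch=False):
    # Seed the Python, NumPy and TensorFlow random number generators.
    random.seed(seed)
    np.random.seed(seed)
    tf.compat.v1.set_random_seed(seed)
    if not no_gpu_patch:
        # Assumed: also request deterministic GPU kernels where supported.
        os.environ['TF_DETERMINISTIC_OPS'] = '1'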
Example #6
    def train(self):
        # Build model
        model = EncoderDecoder(self.n_levels, self.scale, self.channels)

        # Learning rate decay
        global_step = tf.Variable(initial_value=0,
                                  dtype=tf.int32,
                                  trainable=False)
        self.lr = tf.compat.v1.train.polynomial_decay(
            self.learning_rate,
            global_step,
            decay_steps=self.max_steps,
            end_learning_rate=0.0,
            power=0.3)
        tf.compat.v1.summary.scalar('learning_rate', self.lr)
        # Training operator
        adam = tf.compat.v1.train.AdamOptimizer(self.lr)

        # Get next data from preprocessed training dataset
        img_in, img_gt = self.get_data(self.train_in_data_list,
                                       self.train_gt_data_list,
                                       self.batch_size, self.epoch)
        print('img_in, img_gt', img_in.shape, img_gt.shape)
        tf.compat.v1.summary.image('img_in', im2uint8(img_in))
        tf.compat.v1.summary.image('img_gt', im2uint8(img_gt))

        # Compute image loss
        n_outputs = model(img_in, reuse=False)
        loss_op = self.loss(n_outputs, img_gt)
        # By default, Adam optimises all trainable variables in the current graph,
        # so train_op should be the last operation added to the training graph.
        train_op = adam.minimize(loss_op, global_step)

        # Create session
        sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)))

        # Initialise all the variables in current session
        init = tf.compat.v1.global_variables_initializer()
        sess.run(init)
        self.saver = tf.compat.v1.train.Saver(
            max_to_keep=self.max_ckpts_to_keep,
            keep_checkpoint_every_n_hours=self.keep_ckpt_every_n_hours)

        # Check if there is an intermediate trained model to load
        if self.no_resume or not self.load(sess, self.checkpoints_dir):
            print_("Starting training from scratch\n", 'm')

        # Tensorboard summary
        summary_op = tf.compat.v1.summary.merge_all()
        summary_name = ("data{}_bch{}_ep{}".format(
            len(self.train_in_data_list), self.batch_size, self.epoch))
        summary_name += ("_seed{}".format(self.seed)
                         if self.seed is not None else "")
        summary_writer = tf.compat.v1.summary.FileWriter(
            os.path.join(self.summaries_dir, summary_name),
            graph=sess.graph,
            flush_secs=self.summary_flush_secs)

        # Compute loss on validation dataset to check overfitting
        if self.has_val_data:
            val_loss_op = self.validate(model)
            # Save validation loss to tensorboard
            val_summary_op = tf.compat.v1.summary.scalar(
                'val_loss', val_loss_op)
            # Compute initial loss
            val_loss, val_summary = sess.run([val_loss_op, val_summary_op])
            summary_writer.add_summary(val_summary, global_step=0)
            print(
                "Initial Loss on validation dataset: {:.6f}".format(val_loss))

        ################ TRAINING ################
        train_start = time.time()
        for step in range(sess.run(global_step), self.max_steps):
            start_time = time.time()
            val_str = ''
            if step % self.summaries_save_freq == 0 or step == self.max_steps - 1:
                # Train model and record summaries
                _, loss_total, summary = sess.run(
                    [train_op, loss_op, summary_op])
                summary_writer.add_summary(summary, global_step=step)
                duration = time.time() - start_time
                if self.has_val_data and step != 0:
                    # Compute validation loss
                    val_loss, val_summary = sess.run(
                        [val_loss_op, val_summary_op])
                    summary_writer.add_summary(val_summary, global_step=step)
                    val_str = ', val loss: {:.6f}'.format(val_loss)
            else:  # Train only
                _, loss_total = sess.run([train_op, loss_op])
                duration = time.time() - start_time
            assert not np.isnan(loss_total), 'Model diverged with loss = NaN'

            if step % self.training_print_freq == 0 or step == self.max_steps - 1:
                examples_per_sec = self.batch_size / duration
                sec_per_batch = float(duration)
                format_str = (
                    '{}: step {}, loss: {:.6f} ({:.1f} data/s; {:.3f} s/bch)'.
                    format(datetime.now().strftime('%Y-%m-%d %H:%M:%S'), step,
                           loss_total, examples_per_sec, sec_per_batch))
                print(format_str + val_str)

            if (step + 1
                ) % self.ckpts_save_freq == 0 or step == self.max_steps - 1:
                # Save current model in a checkpoint
                self.save(sess, self.checkpoints_dir, step + 1)
        print_(
            "Training duration: {:0.4f}s\n".format(time.time() - train_start),
            'm')
        print_("--------End of training--------\n", 'm')
        # Free all resources associated with the session
        sess.close()
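The image summaries in train() go through an im2uint8 helper. A minimal sketch, assuming the images are float tensors in the [0, 1] range:

import tensorflow as tf

def im2uint8(img):
    # Clip to [0, 1] and rescale to uint8 [0, 255] for TensorBoard display.
    return tf.image.convert_image_dtype(
        tf.clip_by_value(img, 0.0, 1.0), tf.uint8)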
Example #7
    def train(self):
        # Create a session so that tf.keras doesn't allocate all GPU memory at once
        sess = tf.compat.v1.Session(config=tf.compat.v1.ConfigProto(
            gpu_options=tf.compat.v1.GPUOptions(allow_growth=True)))
        tf.compat.v1.keras.backend.set_session(sess)

        # Get training and validation dataset
        ds_train = self.get_data(self.train_data_list, self.batch_size,
                                 self.epoch)
        for x, y in ds_train.take(1):  # take one batch from ds_train
            trainX, trainY = x, y
        print("Input shape {}, target shape: {}".format(
            trainX.shape, trainY.shape))
        if self.has_val_data:
            ds_val = self.get_data(self.val_data_list, self.batch_size,
                                   self.epoch)
        print("********Data Created********")

        # Build model
        model = self.get_compiled_model(trainX.shape[1:])

        # Check if there is an intermediate trained model to load
        if self.no_resume or not self.load(model):
            print_("Starting training from scratch\n", 'm')

        # Callback for creating Tensorboard summary
        summary_name = ("data{}_bch{}_ep{}".format(len(self.train_data_list),
                                                   self.batch_size,
                                                   self.epoch))
        summary_name += ("_seed{}".format(self.seed)
                         if self.seed is not None else "")
        summary_writer = tf.contrib.summary.create_file_writer(
            os.path.join(self.summaries_dir, summary_name))
        tb_callback = self.tensorboard_callback(summary_writer)

        # Callback for saving model's weights
        ckpt_path = os.path.join(self.ckpt_dir,
                                 self.ckpt_save_name + "-ep{epoch:02d}")
        ckpt_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath=ckpt_path,
            # save best model based on monitor value
            monitor='val_loss' if self.has_val_data else 'loss',
            verbose=1,
            save_best_only=True,
            save_weights_only=True)

        # Evaluate the model before training
        if self.has_val_data:
            val_loss, val_bin_acc = model.evaluate(ds_val.take(20), verbose=1)
            print(
                "Initial Loss on validation dataset: {:.4f}".format(val_loss))

        # TRAIN model
        print_("--------Start of training--------\n", 'm')
        print("NOTE:\tDuring training, the latest model is saved only if its\n"
              "\t(validation) loss is better than the last best model.")
        train_start = time.time()
        model.fit(ds_train,
                  validation_data=ds_val if self.has_val_data else None,
                  epochs=self.epoch,
                  steps_per_epoch=self.batch_per_epoch,
                  validation_steps=self.val_batch_per_epoch
                  if self.has_val_data else None,
                  callbacks=[ckpt_callback, tb_callback],
                  verbose=1)
        print_(
            "Training duration: {:0.4f}s\n".format(time.time() - train_start),
            'm')
        print_("--------End of training--------\n", 'm')

        # Show predictions on the first batch of training data
        print(
            "Parameter prediction (PR) compared to groundtruth (GT) for first batch of training data:"
        )
        preds_train = model.predict(trainX.numpy())
        print("Train GT:", trainY.numpy().flatten())
        print("Train PR:", preds_train.flatten())
        # Make predictions on the first batch of validation data
        if self.has_val_data:
            print("For first batch of validation data:")
            for x, y in ds_val.take(1):  # take one batch from ds_val
                valX, valY = x, y
            preds_val = model.predict(valX)
            print("Val GT:", valY.numpy().flatten())
            print("Val PR:", preds_val.flatten())
        # Free all resources associated with the session
        sess.close()
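These trainer classes all read their hyperparameters from an args object. A hypothetical driver sketch showing how such an object could be built with argparse; the flag names simply mirror the attributes read in __init__ and the class name is a placeholder:

import argparse

parser = argparse.ArgumentParser(description='Train the template model')
parser.add_argument('--learning_rate', type=float, default=1e-4)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--epoch', type=int, default=100)
parser.add_argument('--seed', type=int, default=None)
parser.add_argument('--no_resume', action='store_true')
parser.add_argument('--no_gpu_patch', action='store_true')
parser.add_argument('--ckpt_save_name', type=str, default='model')
args = parser.parse_args()

# trainer = TrainTemplate(args)  # placeholder class name
# trainer.train()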
Example #8
    def __init__(self, args):
        # Training hyperparameters
        self.learning_rate = args.learning_rate
        self.batch_size = args.batch_size
        self.epoch = args.epoch
        self.patch_size = 50
        self.channels = 3  # input / output channels
        self.output_param_number = 1
        self.no_resume = args.no_resume
        # A random seed (!=None) allows you to reproduce your training results
        self.seed = args.seed
        if self.seed is not None:
            # Set all seeds necessary for deterministic training
            enable_deterministic_training(self.seed, args.no_gpu_patch)
        # Training and validation dataset paths
        train_data_path = './data/train/'
        val_data_path = './data/validation/'

        # Where to save and load model weights (=checkpoints)
        self.ckpt_dir = './checkpoints'
        if not os.path.exists(self.ckpt_dir):
            os.makedirs(self.ckpt_dir)
        self.ckpt_save_name = args.ckpt_save_name

        # Where to save tensorboard summaries
        self.summaries_dir = './summaries/'
        if not os.path.exists(self.summaries_dir):
            os.makedirs(self.summaries_dir)

        # Get training dataset as list of image paths
        self.train_data_list = get_filepaths_from_dir(train_data_path)
        if not self.train_data_list:
            raise ValueError(
                "No training data found in folder {}".format(train_data_path))
        elif (len(self.train_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
                .format(self.batch_size, len(self.train_data_list)))
        self.is_exr = is_exr(self.train_data_list[0])

        # Compute and print training hyperparameters
        self.batch_per_epoch = (len(self.train_data_list)) // self.batch_size
        max_steps = int(self.epoch * (self.batch_per_epoch))
        print_(
            "Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
            .format(len(self.train_data_list), self.batch_per_epoch,
                    self.batch_size, self.epoch, max_steps), 'm')

        # Get validation dataset if provided
        self.has_val_data = True
        self.val_data_list = get_filepaths_from_dir(val_data_path)
        if not self.val_data_list:
            print("No validation data found in {}".format(val_data_path))
            self.has_val_data = False
        elif (len(self.val_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
                .format(self.batch_size, len(self.val_data_list)))
        else:
            val_is_exr = is_exr(self.val_data_list[0])
            if val_is_exr != self.is_exr:
                raise TypeError(
                    "Train and validation data should have the same file format"
                )
            self.val_batch_per_epoch = (len(
                self.val_data_list)) // self.batch_size
            print(
                "Number of validation data: {}\nNumber of validation batches per epoch: {} (batch size = {})"
                .format(len(self.val_data_list), self.val_batch_per_epoch,
                        self.batch_size))
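The is_exr check used above to keep train and validation file formats consistent can be as simple as an extension test. A minimal sketch (an assumption, not the original helper):

import os

def is_exr(filepath):
    # True when the file has an OpenEXR extension.
    return os.path.splitext(filepath)[1].lower() == '.exr'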
Example #9
    def __init__(self, args):
        # Training hyperparameters
        self.learning_rate = args.learning_rate
        self.batch_size = args.batch_size
        self.epoch = args.epoch
        self.crop_size = 256
        self.n_levels = 3
        self.scale = 0.5
        self.channels = 3  # input / output channels
        # Training and validation dataset paths
        train_in_data_path = './data/train/input'
        train_gt_data_path = './data/train/groundtruth'
        val_in_data_path = './data/val/input'
        val_gt_data_path = './data/val/groundtruth'
        # Where to save and load model weights (=checkpoints)
        self.checkpoints_dir = './checkpoints'
        if not os.path.exists(self.checkpoints_dir):
            os.makedirs(self.checkpoints_dir)
        self.ckpt_save_name = 'trainingTemplateTF.model'
        # Where to save tensorboard summaries
        self.summaries_dir = './summaries'
        if not os.path.exists(self.summaries_dir):
            os.makedirs(self.summaries_dir)

        # Get training dataset as lists of image paths
        self.train_in_data_list = get_filepaths_from_dir(train_in_data_path)
        self.train_gt_data_list = get_filepaths_from_dir(train_gt_data_path)
        if len(self.train_in_data_list) == 0 or len(
                self.train_gt_data_list) == 0:
            raise ValueError(
                "No training data found in folders {} or {}".format(
                    train_in_data_path, train_gt_data_path))
        elif len(self.train_in_data_list) != len(self.train_gt_data_list):
            raise ValueError(
                "{} ({} data) and {} ({} data) should have the same number of input data"
                .format(train_in_data_path, len(self.train_in_data_list),
                        train_gt_data_path, len(self.train_gt_data_list)))
        elif (len(self.train_in_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of training data = {})"
                .format(self.batch_size, len(self.train_in_data_list)))
        self.is_exr = is_exr(self.train_in_data_list[0])

        # Get validation dataset if provided
        self.has_val_data = True
        self.val_in_data_list = get_filepaths_from_dir(val_in_data_path)
        self.val_gt_data_list = get_filepaths_from_dir(val_gt_data_path)
        if len(self.val_in_data_list) == 0 or len(self.val_gt_data_list) == 0:
            print("No validation data found in {} or {}".format(
                val_in_data_path, val_gt_data_path))
            self.has_val_data = False
        elif len(self.val_in_data_list) != len(self.val_gt_data_list):
            raise ValueError(
                "{} ({} data) and {} ({} data) should have the same number of input data"
                .format(val_in_data_path, len(self.val_in_data_list),
                        val_gt_data_path, len(self.val_gt_data_list)))
        elif (len(self.val_in_data_list) < self.batch_size):
            raise ValueError(
                "Batch size must be smaller than the dataset (batch size = {}, number of validation data = {})"
                .format(self.batch_size, len(self.val_in_data_list)))
        else:
            val_is_exr = is_exr(self.val_in_data_list[0])
            if val_is_exr != self.is_exr:
                raise TypeError(
                    "Train and validation data should have the same file format"
                )
            print("Number of validation data: {}".format(
                len(self.val_in_data_list)))

        # Compute and print training hyperparameters
        batch_per_epoch = (len(self.train_in_data_list)) // self.batch_size
        self.max_steps = int(self.epoch * (batch_per_epoch))
        print_(
            "Number of training data: {}\nNumber of batches per epoch: {} (batch size = {})\nNumber of training steps for {} epochs: {}\n"
            .format(len(self.train_in_data_list), batch_per_epoch,
                    self.batch_size, self.epoch, self.max_steps), 'm')
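To make the step arithmetic at the end of these constructors concrete, a small worked example with illustrative numbers:

# 1000 training images with batch size 8 -> 125 batches per epoch (floor division);
# training for 20 epochs therefore runs 20 * 125 = 2500 steps.
train_count, batch_size, epochs = 1000, 8, 20
batch_per_epoch = train_count // batch_size  # 125
max_steps = epochs * batch_per_epoch         # 2500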