Example #1
def train(model, model_name, train_loader, valid_loader, epochs=1000):
    # Create callbacks and checkpoints
    lrscheduler = ReduceLROnPlateau(patience=3, verbose=True)
    early_stopping = EarlyStopping(patience=10, min_delta=1e-4, verbose=True)
    model_path = './models/'

    os.makedirs(model_path, exist_ok=True)
    ckpt_best = ModelCheckpoint(model_path + 'best_' + model_name + '.torch',
                                save_best_only=True,
                                restore_best=True,
                                temporary_filename=model_path + 'temp_best_' + model_name + '.torch',
                                verbose=True)

    ckpt_last = ModelCheckpoint(model_path + 'last_' + model_name + '.torch',
                                temporary_filename=model_path + 'temp_last_' + model_name + '.torch')

    logger_path = './train_logs/'
    os.makedirs(logger_path, exist_ok=True)
    csv_logger = CSVLogger(logger_path + model_name + '.csv')

    callbacks = [lrscheduler, ckpt_best, ckpt_last, early_stopping, csv_logger]

    # Fit the model
    model.fit_generator(train_loader, valid_loader,
                        epochs=epochs, callbacks=callbacks)
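
A minimal usage sketch for this `train` function, assuming Poutyne's `Model` wrapper around a toy PyTorch network (the network, data, optimizer, and loss below are illustrative and not part of the original example; the import path can differ across Poutyne versions):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from poutyne.framework import Model

# Toy regression data wrapped in PyTorch DataLoaders.
x, y = torch.randn(256, 10), torch.randn(256, 1)
train_loader = DataLoader(TensorDataset(x[:200], y[:200]), batch_size=32)
valid_loader = DataLoader(TensorDataset(x[200:], y[200:]), batch_size=32)

network = nn.Linear(10, 1)
model = Model(network,
              torch.optim.SGD(network.parameters(), lr=0.01),
              nn.MSELoss())

train(model, 'linear_baseline', train_loader, valid_loader, epochs=50)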
Example #2
    def fit(self,
            x_train,
            y_train,
            x_valid,
            y_valid,
            n_epochs=100,
            batch_size=32,
            log_filename=None,
            checkpoint_filename=None,
            with_early_stopping=True):
        """
        :param x_train: training set examples
        :param y_train: training set labels
        :param x_valid: validation set examples
        :param y_valid: validation set labels
        :param n_epochs: int, number of epochs (default: 100)
        :param batch_size: int, batch size (default: 32); must be a multiple of 2
        :param log_filename: optional, file in which to log training information
        :param checkpoint_filename: optional, file in which to save the model
        :param with_early_stopping: whether to activate early stopping
        :return: self, the model
        """
        callbacks = []
        if with_early_stopping:
            early_stopping = EarlyStopping(monitor='val_loss',
                                           patience=3,
                                           verbose=0)
            callbacks += [early_stopping]
        reduce_lr = ReduceLROnPlateau(monitor='loss',
                                      patience=2,
                                      factor=1 / 10,
                                      min_lr=1e-6)
        best_model_restore = BestModelRestore()
        callbacks += [reduce_lr, best_model_restore]
        if log_filename:
            logger = CSVLogger(log_filename,
                               batch_granularity=False,
                               separator='\t')
            callbacks += [logger]
        if checkpoint_filename:
            checkpointer = ModelCheckpoint(checkpoint_filename,
                                           monitor='val_loss',
                                           save_best_only=True)
            callbacks += [checkpointer]

        # self.model.fit(x_train, y_train, x_valid, y_valid,
        #                batch_size=batch_size, epochs=n_epochs,
        #                callbacks=callbacks)
        # `generator` is a batching helper not shown in this snippet.
        nb_steps_train = int(len(x_train) / batch_size)
        nb_steps_valid = int(len(x_valid) / batch_size)
        self.model.fit_generator(
            generator(x_train, y_train, batch_size),
            steps_per_epoch=nb_steps_train,
            valid_generator=generator(x_valid, y_valid, batch_size),
            validation_steps=nb_steps_valid,
            epochs=n_epochs,
            callbacks=callbacks,
        )
        return self
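
The `generator` helper called above is not shown in the snippet; a minimal sketch of such a batching generator (hypothetical, assuming array-like inputs) could be:

def generator(x, y, batch_size):
    # Yield successive (inputs, targets) batches indefinitely, as expected by
    # fit_generator when steps_per_epoch / validation_steps are given.
    n = len(x)
    while True:
        for start in range(0, n - batch_size + 1, batch_size):
            yield x[start:start + batch_size], y[start:start + batch_size]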
Example #3
    def test_integration(self):
        train_gen = some_data_generator(20)
        valid_gen = some_data_generator(20)
        earlystopper = EarlyStopping(monitor='val_loss', min_delta=0, patience=2, verbose=False)
        self.model.fit_generator(train_gen, valid_gen,
                                 epochs=10,
                                 steps_per_epoch=5,
                                 callbacks=[earlystopper])
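
`some_data_generator` is likewise not part of the snippet; a plausible stand-in that yields random (inputs, targets) batches of the requested size could be:

import torch

def some_data_generator(batch_size):
    # Endless stream of random batches, enough to exercise the callbacks.
    while True:
        yield torch.rand(batch_size, 1), torch.rand(batch_size, 1)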
Example #4
    def test_early_stopping_with_delta(self):
        earlystopper = EarlyStopping(monitor='val_loss',
                                     min_delta=3,
                                     patience=2,
                                     verbose=False)

        val_losses = [8, 4, 5, 2, 2]
        early_stop_epoch = 4
        self._test_early_stopping(earlystopper, val_losses, early_stop_epoch)
Example #5
    def test_early_stopping_patience_of_1(self):
        earlystopper = EarlyStopping(monitor='val_loss',
                                     min_delta=0,
                                     patience=1,
                                     verbose=False)

        val_losses = [8, 4, 5, 2]
        early_stop_epoch = 3
        self._test_early_stopping(earlystopper, val_losses, early_stop_epoch)
Example #6
    def test_early_stopping_with_max(self):
        earlystopper = EarlyStopping(monitor='val_loss',
                                     min_delta=0,
                                     patience=2,
                                     verbose=False,
                                     mode='max')

        val_losses = [2, 8, 4, 5, 2]
        early_stop_epoch = 4
        self._test_early_stopping(earlystopper, val_losses, early_stop_epoch)
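
The `_test_early_stopping` helper shared by the three tests above is not shown. A sketch of what such a test-case method could look like, assuming a Poutyne/Keras-style callback API (`set_model`, `on_epoch_end`) and that `EarlyStopping` signals by setting a `stop_training` flag on the model:

from unittest.mock import MagicMock

    def _test_early_stopping(self, earlystopper, val_losses, early_stop_epoch):
        # Drive the callback by hand with a mocked model and record the epoch
        # at which it asks training to stop.
        model = MagicMock()
        model.stop_training = False
        earlystopper.set_model(model)
        earlystopper.on_train_begin({})
        stopped_epoch = None
        for epoch, val_loss in enumerate(val_losses, start=1):
            earlystopper.on_epoch_begin(epoch, {})
            earlystopper.on_epoch_end(epoch, {'val_loss': val_loss})
            if model.stop_training:
                stopped_epoch = epoch
                break
        earlystopper.on_train_end({})
        self.assertEqual(stopped_epoch, early_stop_epoch)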
Example #7
    def fit(self,
            meta_train,
            meta_valid,
            n_epochs=100,
            steps_per_epoch=100,
            log_filename=None,
            checkpoint_filename=None,
            tboard_folder=None):
        if hasattr(self.model, 'is_eval'):
            self.model.is_eval = False
        self.is_eval = False
        self.steps_per_epoch = steps_per_epoch
        callbacks = [
            EarlyStopping(patience=10, verbose=False),
            ReduceLROnPlateau(patience=2,
                              factor=1 / 2,
                              min_lr=1e-6,
                              verbose=True),
            BestModelRestore()
        ]
        if log_filename:
            callbacks += [
                CSVLogger(log_filename,
                          batch_granularity=False,
                          separator='\t')
            ]
        if checkpoint_filename:
            callbacks += [
                ModelCheckpoint(checkpoint_filename,
                                monitor='val_loss',
                                save_best_only=True,
                                temporary_filename=checkpoint_filename +
                                'temp')
            ]

        if tboard_folder is not None:
            self.writer = SummaryWriter(tboard_folder)

        self.fit_generator(meta_train,
                           meta_valid,
                           epochs=n_epochs,
                           steps_per_epoch=steps_per_epoch,
                           validation_steps=steps_per_epoch,
                           callbacks=callbacks,
                           verbose=True)
        self.is_fitted = True
        return self