Exemplo n.º 1
0
def train(model, model_name, train_loader, valid_loader, epochs=1000):
    """Fit ``model`` on ``train_loader``, validating on ``valid_loader``.

    Two checkpoints are written under ``./models/``:
    - a "best" one, created with ``save_best_only=True`` and ``restore_best=True``;
    - a "last" one.

    Per-epoch logs go to ``./train_logs/<model_name>.csv``. The LR is reduced
    on plateau (patience 3), and early stopping is used (patience 10,
    min_delta 1e-4).

    :param model: object exposing ``fit_generator`` (presumably a Poutyne
        ``Model`` -- TODO confirm).
    :param model_name: string used to build checkpoint and log file names.
    :param train_loader: training data, forwarded to ``fit_generator``.
    :param valid_loader: validation data, forwarded to ``fit_generator``.
    :param epochs: maximum number of epochs, default 1000.
    """
    scheduler = ReduceLROnPlateau(patience=3, verbose=True)
    stopper = EarlyStopping(patience=10, min_delta=1e-4, verbose=True)

    checkpoint_dir = './models/'
    os.makedirs(checkpoint_dir, exist_ok=True)

    def checkpoint_file(prefix):
        # e.g. './models/best_<model_name>.torch'
        return checkpoint_dir + prefix + model_name + '.torch'

    best_checkpoint = ModelCheckpoint(checkpoint_file('best_'),
                                      save_best_only=True,
                                      restore_best=True,
                                      temporary_filename=checkpoint_file('temp_best_'),
                                      verbose=True)
    last_checkpoint = ModelCheckpoint(checkpoint_file('last_'),
                                      temporary_filename=checkpoint_file('temp_last_'))

    log_dir = './train_logs/'
    os.makedirs(log_dir, exist_ok=True)
    csv_logger = CSVLogger(log_dir + model_name + '.csv')

    # Callback order matters: it is the order in which they are invoked.
    model.fit_generator(train_loader, valid_loader,
                        epochs=epochs,
                        callbacks=[scheduler, best_checkpoint, last_checkpoint,
                                   stopper, csv_logger])
Exemplo n.º 2
0
    def fit(self,
            x_train,
            y_train,
            x_valid,
            y_valid,
            n_epochs=100,
            batch_size=32,
            log_filename=None,
            checkpoint_filename=None,
            with_early_stopping=True):
        """
        Train the wrapped model on (x_train, y_train), validating on (x_valid, y_valid).

        :param x_train: training set examples
        :param y_train: training set labels
        :param x_valid: validation set examples
        :param y_valid: validation set labels
        :param n_epochs: int, number of epochs, default value 100
        :param batch_size: int, size of each batch, default value 32
            (originally documented as needing to be a multiple of 2)
        :param log_filename: optional, tab-separated CSV file receiving per-epoch
            training information
        :param checkpoint_filename: optional, file where the model with the lowest
            val_loss is saved
        :param with_early_stopping: whether to add an EarlyStopping callback on
            val_loss (patience 3)
        :return: self, the model
        """
        callbacks = []
        if with_early_stopping:
            callbacks.append(EarlyStopping(monitor='val_loss',
                                           patience=3,
                                           verbose=0))
        # NOTE(review): the LR schedule monitors the *training* loss while the
        # other callbacks monitor val_loss -- confirm this is intended.
        reduce_lr = ReduceLROnPlateau(monitor='loss',
                                      patience=2,
                                      factor=1 / 10,
                                      min_lr=1e-6)
        callbacks += [reduce_lr, BestModelRestore()]
        if log_filename:
            callbacks.append(CSVLogger(log_filename,
                                       batch_granularity=False,
                                       separator='\t'))
        if checkpoint_filename:
            callbacks.append(ModelCheckpoint(checkpoint_filename,
                                             monitor='val_loss',
                                             save_best_only=True))

        # Training must run whether or not a checkpoint file was requested.
        # Steps per epoch = number of full batches in each set.
        nb_steps_train = len(x_train) // batch_size
        nb_steps_valid = len(x_valid) // batch_size
        self.model.fit_generator(
            generator(x_train, y_train, batch_size),
            steps_per_epoch=nb_steps_train,
            valid_generator=generator(x_valid, y_valid, batch_size),
            validation_steps=nb_steps_valid,
            epochs=n_epochs,
            callbacks=callbacks,
        )
        return self
Exemplo n.º 3
0
 def test_logging_append(self):
     """Run two fits sharing one CSV file (second with append=True) and check the combined history."""
     train_generator = some_data_generator(20)
     valid_generator = some_data_generator(20)

     first_logger = CSVLogger(self.csv_filename)
     first_history = self.model.fit_generator(train_generator, valid_generator,
                                              epochs=10, steps_per_epoch=5,
                                              callbacks=[first_logger])

     # Resume at epoch 10; the logger reopens the same file with append=True.
     appending_logger = CSVLogger(self.csv_filename, append=True)
     second_history = self.model.fit_generator(train_generator, valid_generator,
                                               epochs=20, steps_per_epoch=5,
                                               initial_epoch=10,
                                               callbacks=[appending_logger])

     self._test_logging(first_history + second_history)
Exemplo n.º 4
0
 def test_logging(self):
     """Fit for 10 epochs with a CSVLogger and check the log via _test_logging."""
     train_generator = some_data_generator(20)
     valid_generator = some_data_generator(20)
     csv_logger = CSVLogger(self.csv_filename)
     fit_history = self.model.fit_generator(train_generator, valid_generator,
                                            epochs=10, steps_per_epoch=5,
                                            callbacks=[csv_logger])
     self._test_logging(fit_history)
Exemplo n.º 5
0
 def test_logging_with_batch_granularity(self):
     """Fit with a batch-granularity CSVLogger; check it via the History callback's records."""
     train_generator = some_data_generator(20)
     valid_generator = some_data_generator(20)
     batch_logger = CSVLogger(self.csv_filename, batch_granularity=True)
     recorder = History()
     callback_list = [batch_logger, recorder]
     self.model.fit_generator(train_generator, valid_generator,
                              epochs=10, steps_per_epoch=5,
                              callbacks=callback_list)
     self._test_logging(recorder.history)
Exemplo n.º 6
0
    def fit(self,
            meta_train,
            meta_valid,
            n_epochs=100,
            steps_per_epoch=100,
            log_filename=None,
            checkpoint_filename=None,
            tboard_folder=None):
        """Train on ``meta_train`` while validating on ``meta_valid``.

        The following callbacks are always used:
        - EarlyStopping (patience 10);
        - ReduceLROnPlateau (patience 2, factor 1/2, min_lr 1e-6);
        - BestModelRestore.

        Optional extras:
        - a tab-separated CSV log (``log_filename``);
        - a best-val_loss checkpoint (``checkpoint_filename``);
        - a ``SummaryWriter`` stored on ``self.writer`` (``tboard_folder``).

        ``steps_per_epoch`` is used for both training and validation steps.

        :return: self, with ``is_fitted`` set to True.
        """
        # Clear the evaluation flags on the wrapped model (when it has one) and on self.
        if hasattr(self.model, 'is_eval'):
            self.model.is_eval = False
        self.is_eval = False
        self.steps_per_epoch = steps_per_epoch

        callbacks = [EarlyStopping(patience=10, verbose=False)]
        callbacks.append(ReduceLROnPlateau(patience=2,
                                           factor=1 / 2,
                                           min_lr=1e-6,
                                           verbose=True))
        callbacks.append(BestModelRestore())

        if log_filename:
            callbacks.append(CSVLogger(log_filename,
                                       batch_granularity=False,
                                       separator='\t'))
        if checkpoint_filename:
            tmp_checkpoint = checkpoint_filename + 'temp'
            callbacks.append(ModelCheckpoint(checkpoint_filename,
                                             monitor='val_loss',
                                             save_best_only=True,
                                             temporary_filename=tmp_checkpoint))

        if tboard_folder is not None:
            self.writer = SummaryWriter(tboard_folder)

        self.fit_generator(meta_train, meta_valid,
                           epochs=n_epochs,
                           steps_per_epoch=steps_per_epoch,
                           validation_steps=steps_per_epoch,
                           callbacks=callbacks,
                           verbose=True)
        self.is_fitted = True
        return self
Exemplo n.º 7
0
    def train(self,
              train_loader,
              valid_loader=None,
              *,
              callbacks=None,
              lr_schedulers=None,
              save_every_epoch=False,
              disable_tensorboard=False,
              epochs=1000,
              steps_per_epoch=None,
              validation_steps=None,
              seed=42):
        """Train ``self.model``, optionally with checkpointing and resumption.

        When ``self.logging`` is true, ``self.directory`` is created if missing
        and the starting epoch is read back via ``_load_epoch_state``. Callbacks
        are then added for:
        - a tab-separated CSV log (appended to when resuming);
        - model-restoring checkpoints;
        - the latest model and optimizer state;
        - the last finished epoch number;
        - TensorBoard, via ``_init_tensorboard_callbacks``.

        LR-scheduler callbacks come from ``_init_lr_scheduler_callbacks`` in
        all cases. Any TensorBoard writer is closed even if training raises.

        :param train_loader: training data, forwarded to ``fit_generator``.
        :param valid_loader: optional validation data, forwarded to ``fit_generator``.
        :param callbacks: extra callbacks; the given sequence is copied, never
            mutated. ``None`` (the default) means no extra callbacks.
        :param lr_schedulers: LR scheduler callbacks; ``None`` (the default) means none.
        :param save_every_epoch: forwarded to ``_init_model_restoring_callbacks``.
        :param disable_tensorboard: forwarded to ``_init_tensorboard_callbacks``.
        :param epochs: forwarded to ``fit_generator``.
        :param steps_per_epoch: forwarded to ``fit_generator``.
        :param validation_steps: forwarded to ``fit_generator``.
        :param seed: if not None, seeds ``random``, NumPy and PyTorch first.
        """
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # None defaults avoid shared mutable default lists; copy the caller's callbacks.
        callbacks = [] if callbacks is None else list(callbacks)
        if lr_schedulers is None:
            lr_schedulers = []

        tensorboard_writer = None
        initial_epoch = 1
        if self.logging:
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(self.directory, exist_ok=True)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            callbacks += [
                CSVLogger(self.log_filename,
                          separator='\t',
                          append=initial_epoch != 1)
            ]

            callbacks += self._init_model_restoring_callbacks(
                initial_epoch, save_every_epoch)
            callbacks += [
                ModelCheckpoint(
                    self.model_checkpoint_filename,
                    verbose=False,
                    temporary_filename=self.model_checkpoint_tmp_filename)
            ]
            callbacks += [
                OptimizerCheckpoint(
                    self.optimizer_checkpoint_filename,
                    verbose=False,
                    temporary_filename=self.optimizer_checkpoint_tmp_filename)
            ]

            # We save the last epoch number after the end of the epoch so that the
            # _load_epoch_state() knows which epoch to restart the optimization.
            callbacks += [
                PeriodicSaveLambda(
                    lambda fd, epoch, logs: print(epoch, file=fd),
                    self.epoch_filename,
                    temporary_filename=self.epoch_tmp_filename,
                    open_mode='w')
            ]

            tensorboard_writer, cb_list = self._init_tensorboard_callbacks(
                disable_tensorboard)
            callbacks += cb_list

        # This method returns callbacks that checkpoint the LR schedulers if logging is enabled.
        # Otherwise, it just returns the list of LR schedulers with a BestModelRestore callback.
        callbacks += self._init_lr_scheduler_callbacks(lr_schedulers)

        try:
            self.model.fit_generator(train_loader,
                                     valid_loader,
                                     epochs=epochs,
                                     steps_per_epoch=steps_per_epoch,
                                     validation_steps=validation_steps,
                                     initial_epoch=initial_epoch,
                                     callbacks=callbacks)
        finally:
            if tensorboard_writer is not None:
                tensorboard_writer.close()
Exemplo n.º 8
0
    def train(self, train_loader, valid_loader=None,
              callbacks=None, lr_schedulers=None,
              disable_tensorboard=False,
              epochs=1000, steps_per_epoch=None, validation_steps=None,
              seed=42):
        """Train ``self.model``, optionally with checkpointing and resumption.

        When ``self.logging`` is true, ``self.directory`` is created if missing
        and the starting epoch is read back via ``_load_epoch_state`` so an
        interrupted run can restart. Callbacks are then added for:
        - a tab-separated CSV log (appended to when resuming);
        - a best checkpoint selected by ``self.monitor_metric`` /
          ``self.monitor_mode``, created with ``restore_best=True``;
        - the latest model and optimizer state;
        - the last finished epoch number;
        - one checkpoint per LR scheduler;
        - a TensorBoard logger, unless disabled or ``SummaryWriter`` is
          unavailable.

        Without logging, the LR schedulers are used directly, followed by a
        ``BestModelRestore``. Any TensorBoard writer created here is closed
        even if training raises.

        :param train_loader: training data, forwarded to ``fit_generator``.
        :param valid_loader: optional validation data, forwarded to ``fit_generator``.
        :param callbacks: extra callbacks; the given sequence is copied, never
            mutated. ``None`` (the default) means no extra callbacks.
        :param lr_schedulers: LR scheduler callbacks; ``None`` (the default) means none.
        :param disable_tensorboard: skip the TensorBoard logger when true.
        :param epochs: forwarded to ``fit_generator``.
        :param steps_per_epoch: forwarded to ``fit_generator``.
        :param validation_steps: forwarded to ``fit_generator``.
        :param seed: if not None, seeds ``random``, NumPy and PyTorch first.
        """
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # None defaults avoid shared mutable default lists; copy the caller's callbacks.
        callbacks = [] if callbacks is None else list(callbacks)
        if lr_schedulers is None:
            lr_schedulers = []

        tensorboard_writer = None
        initial_epoch = 1
        if self.logging:
            # exist_ok avoids the race between an existence check and creation.
            os.makedirs(self.directory, exist_ok=True)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            csv_logger = CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)

            best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                              monitor=self.monitor_metric,
                                              mode=self.monitor_mode,
                                              save_best_only=True,
                                              restore_best=True,
                                              verbose=True,
                                              temporary_filename=self.best_checkpoint_tmp_filename)
            if initial_epoch > 1:
                # We set the current best metric score in the ModelCheckpoint so that
                # it does not save checkpoint it would not have saved if the
                # optimization was not stopped.
                best_epoch_stats = self.get_best_epoch_stats()
                best_epoch = best_epoch_stats['epoch'].item()
                best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
                best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

            checkpoint = ModelCheckpoint(self.model_checkpoint_filename,
                                         verbose=False,
                                         temporary_filename=self.model_checkpoint_tmp_filename)
            optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename,
                                                       verbose=False,
                                                       temporary_filename=self.optimizer_checkpoint_tmp_filename)

            # We save the last epoch number after the end of the epoch so that the
            # _load_epoch_state() knows which epoch to restart the optimization.
            save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                                   self.epoch_filename,
                                                   temporary_filename=self.epoch_tmp_filename,
                                                   open_mode='w')

            callbacks += [csv_logger, best_checkpoint, checkpoint, optimizer_checkpoint, save_epoch_number]

            if not disable_tensorboard:
                if SummaryWriter is None:
                    warnings.warn("tensorboardX does not seem to be installed. To remove this warning, set the 'disable_tensorboard' flag to True.")
                else:
                    # Keep a handle on the writer so it can be closed after training.
                    tensorboard_writer = SummaryWriter(self.tensorboard_directory)
                    callbacks.append(TensorBoardLogger(tensorboard_writer))
            for i, lr_scheduler in enumerate(lr_schedulers):
                filename = self.lr_scheduler_filename % i
                tmp_filename = self.lr_scheduler_tmp_filename % i
                callbacks.append(LRSchedulerCheckpoint(lr_scheduler, filename,
                                                       verbose=False,
                                                       temporary_filename=tmp_filename))
        else:
            callbacks.extend(lr_schedulers)
            callbacks.append(BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode))

        try:
            self.model.fit_generator(train_loader, valid_loader,
                                     epochs=epochs,
                                     steps_per_epoch=steps_per_epoch,
                                     validation_steps=validation_steps,
                                     initial_epoch=initial_epoch,
                                     callbacks=callbacks)
        finally:
            # Release the writer's file handles even when training fails.
            if tensorboard_writer is not None:
                tensorboard_writer.close()