示例#1
0
    def test_save_best_only_with_max(self):
        """Exercise ``BestModelRestore`` on ``val_loss`` with ``mode='max'``.

        The largest monitored value (8) is the third entry of the sequence,
        so epoch 3 is the one expected to be reported as best.
        """
        restore_callback = BestModelRestore(mode='max', monitor='val_loss')

        losses_per_epoch = [3, 2, 8, 5, 4]
        expected_best_epoch = 3
        self._test_restore_with_val_losses(
            restore_callback, losses_per_epoch, expected_best_epoch)
示例#2
0
    def test_basic_restore(self):
        """Exercise ``BestModelRestore`` on ``val_loss`` with its default mode.

        The smallest monitored value (2) is the second entry of the sequence,
        so epoch 2 is the one expected to be reported as best.
        """
        restore_callback = BestModelRestore(monitor='val_loss')

        losses_per_epoch = [3, 2, 8, 5, 4]
        expected_best_epoch = 2
        self._test_restore_with_val_losses(
            restore_callback, losses_per_epoch, expected_best_epoch)
    def fit(self,
            x_train,
            y_train,
            x_valid,
            y_valid,
            n_epochs=100,
            batch_size=32,
            log_filename=None,
            checkpoint_filename=None,
            with_early_stopping=True):
        """
        Train ``self.model`` on batches drawn from the given arrays.

        :param x_train: training set examples
        :param y_train: training set labels
        :param x_valid: validation set examples
        :param y_valid: validation set labels
        :param n_epochs: int, number of epochs, default value 100
        :param batch_size: int, size of the batch, default value 32, must be multiple of 2
        :param log_filename: optional, path of a tab-separated CSV log of the training
        :param checkpoint_filename: optional, path where the best model
            (according to ``val_loss``) is checkpointed
        :param with_early_stopping: whether to stop early when ``val_loss``
            has not improved for 3 epochs
        :return: self, the model
        """
        callbacks = []
        if with_early_stopping:
            callbacks.append(
                EarlyStopping(monitor='val_loss', patience=3, verbose=0))
        callbacks += [
            ReduceLROnPlateau(monitor='loss',
                              patience=2,
                              factor=1 / 10,
                              min_lr=1e-6),
            BestModelRestore(),
        ]
        if log_filename:
            callbacks.append(
                CSVLogger(log_filename,
                          batch_granularity=False,
                          separator='\t'))
        if checkpoint_filename:
            callbacks.append(
                ModelCheckpoint(checkpoint_filename,
                                monitor='val_loss',
                                save_best_only=True))

        # Training must happen whether or not a checkpoint file was requested;
        # it used to be nested under the ``checkpoint_filename`` branch, which
        # silently skipped the fit and returned None in the default case.
        # Only complete batches are counted: a trailing partial batch is not
        # included in the step count.
        nb_steps_train = int(len(x_train) / batch_size)
        nb_steps_valid = int(len(x_valid) / batch_size)
        self.model.fit_generator(
            generator(x_train, y_train, batch_size),
            steps_per_epoch=nb_steps_train,
            valid_generator=generator(x_valid, y_valid, batch_size),
            validation_steps=nb_steps_valid,
            epochs=n_epochs,
            callbacks=callbacks,
        )
        return self
示例#4
0
 def test_integration(self):
     """Run ``fit_generator`` end to end with a ``BestModelRestore`` callback.

     The test has no explicit assertion; it passes as long as training for
     10 epochs of 5 steps completes without raising.
     """
     training_batches = some_data_generator(20)
     validation_batches = some_data_generator(20)
     callbacks = [BestModelRestore(monitor='val_loss')]
     self.model.fit_generator(
         training_batches,
         validation_batches,
         callbacks=callbacks,
         steps_per_epoch=5,
         epochs=10,
     )
示例#5
0
    def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
        """Build the callbacks in charge of saving/restoring the best model.

        A ``ModelCheckpoint`` on the best-checkpoint file is always created.
        When ``save_every_epoch`` is falsy it only saves the best epoch,
        restores it and is verbose; when truthy those three options are
        turned off and a verbose ``BestModelRestore`` is added instead.

        :param initial_epoch: epoch at which training (re)starts; a value
            greater than 1 means a previous run is being resumed.
        :param save_every_epoch: whether a checkpoint is kept for every epoch.
        :return: list of the created callbacks.
        """
        keep_best_only = not save_every_epoch
        checkpoint_cb = ModelCheckpoint(
            self.best_checkpoint_filename,
            monitor=self.monitor_metric,
            mode=self.monitor_mode,
            save_best_only=keep_best_only,
            restore_best=keep_best_only,
            verbose=keep_best_only,
            temporary_filename=self.best_checkpoint_tmp_filename)
        restoring_callbacks = [checkpoint_cb]

        if save_every_epoch:
            restore_cb = BestModelRestore(monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          verbose=True)
            restoring_callbacks.append(restore_cb)

        # Fresh run: there is no previous best state to seed the callbacks with.
        if not initial_epoch > 1:
            return restoring_callbacks

        # Resumed run: seed the callback with the best score reached so far so
        # that it does not save/restore a checkpoint it would not have picked
        # had the optimization never been stopped.
        stats = self.get_best_epoch_stats()
        best_epoch_number = stats['epoch'].item()
        best_path = self.best_checkpoint_filename.format(
            epoch=best_epoch_number)
        if save_every_epoch:
            restore_cb.best_weights = torch.load(best_path,
                                                 map_location='cpu')
            restore_cb.current_best = stats[self.monitor_metric].item()
        else:
            checkpoint_cb.best_filename = best_path
            checkpoint_cb.current_best = stats[self.monitor_metric].item()

        return restoring_callbacks
示例#6
0
    def fit(self,
            meta_train,
            meta_valid,
            n_epochs=100,
            steps_per_epoch=100,
            log_filename=None,
            checkpoint_filename=None,
            tboard_folder=None):
        """Train through ``self.fit_generator`` with a standard callback set.

        Early stopping (patience 10), learning-rate reduction on plateau and
        best-model restoring are always enabled.

        :param meta_train: training generator handed to ``fit_generator``.
        :param meta_valid: validation generator handed to ``fit_generator``.
        :param n_epochs: number of epochs.
        :param steps_per_epoch: number of steps per epoch; also used as the
            number of validation steps and stored on ``self``.
        :param log_filename: optional path of a tab-separated CSV log.
        :param checkpoint_filename: optional path where the best model
            (according to ``val_loss``) is checkpointed.
        :param tboard_folder: optional folder; when given, a ``SummaryWriter``
            on it is stored in ``self.writer``.
        :return: self
        """
        # Leave evaluation mode on both the wrapped model (if it exposes the
        # flag) and this object.
        if hasattr(self.model, 'is_eval'):
            self.model.is_eval = False
        self.is_eval = False
        self.steps_per_epoch = steps_per_epoch

        early_stopping = EarlyStopping(patience=10, verbose=False)
        lr_reducer = ReduceLROnPlateau(patience=2,
                                       factor=1 / 2,
                                       min_lr=1e-6,
                                       verbose=True)
        callbacks = [early_stopping, lr_reducer, BestModelRestore()]

        if log_filename:
            callbacks.append(
                CSVLogger(log_filename,
                          batch_granularity=False,
                          separator='\t'))
        if checkpoint_filename:
            callbacks.append(
                ModelCheckpoint(
                    checkpoint_filename,
                    monitor='val_loss',
                    save_best_only=True,
                    temporary_filename=checkpoint_filename + 'temp'))

        if tboard_folder is not None:
            self.writer = SummaryWriter(tboard_folder)

        self.fit_generator(
            meta_train,
            meta_valid,
            epochs=n_epochs,
            steps_per_epoch=steps_per_epoch,
            validation_steps=steps_per_epoch,
            callbacks=callbacks,
            verbose=True,
        )
        self.is_fitted = True
        return self
示例#7
0
 def _init_lr_scheduler_callbacks(self, lr_schedulers):
     """Return the callbacks needed to drive the given LR schedulers.

     Without logging, the schedulers are used directly as callbacks and a
     verbose ``BestModelRestore`` is appended. With logging, each scheduler
     is wrapped in an ``LRSchedulerCheckpoint`` whose file names are derived
     from the scheduler's position in ``lr_schedulers``.

     :param lr_schedulers: iterable of learning-rate scheduler callbacks.
     :return: list of callbacks.
     """
     if not self.logging:
         plain_callbacks = list(lr_schedulers)
         plain_callbacks.append(
             BestModelRestore(monitor=self.monitor_metric,
                              mode=self.monitor_mode,
                              verbose=True))
         return plain_callbacks

     return [
         LRSchedulerCheckpoint(
             scheduler,
             self.lr_scheduler_filename % index,
             verbose=False,
             temporary_filename=self.lr_scheduler_tmp_filename % index)
         for index, scheduler in enumerate(lr_schedulers)
     ]
示例#8
0
    def train(self, train_loader, valid_loader=None,
              callbacks=None, lr_schedulers=None,
              disable_tensorboard=False,
              epochs=1000, steps_per_epoch=None, validation_steps=None,
              seed=42):
        """Train ``self.model`` with ``fit_generator``, resuming if possible.

        When ``self.logging`` is enabled, the experiment directory is created,
        the epoch to restart from is loaded, and logging/checkpointing
        callbacks (CSV log, best-model checkpoint, model/optimizer/LR-scheduler
        checkpoints, epoch number, optionally TensorBoard) are added.
        Otherwise the LR schedulers are used directly and a
        ``BestModelRestore`` callback is appended.

        :param train_loader: training generator handed to ``fit_generator``.
        :param valid_loader: optional validation generator.
        :param callbacks: optional iterable of extra callbacks; it is copied,
            never modified in place. ``None`` means no extra callbacks.
        :param lr_schedulers: optional iterable of LR scheduler callbacks.
            ``None`` means no schedulers.
        :param disable_tensorboard: when False and logging is enabled, a
            ``TensorBoardLogger`` is added if ``SummaryWriter`` is available;
            a warning is emitted if it is not.
        :param epochs: number of epochs, forwarded to ``fit_generator``.
        :param steps_per_epoch: forwarded to ``fit_generator``.
        :param validation_steps: forwarded to ``fit_generator``.
        :param seed: seed for ``random``, NumPy and PyTorch; ``None`` leaves
            the random generators untouched.
        """
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # ``None`` sentinels avoid sharing mutable default lists across calls;
        # the caller's callback list is copied so it is never mutated.
        callbacks = [] if callbacks is None else list(callbacks)
        if lr_schedulers is None:
            lr_schedulers = []

        initial_epoch = 1
        if self.logging:
            # ``exist_ok`` avoids the check-then-create race of testing for
            # the directory before making it.
            os.makedirs(self.directory, exist_ok=True)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            csv_logger = CSVLogger(self.log_filename,
                                   separator='\t',
                                   append=initial_epoch != 1)

            best_checkpoint = ModelCheckpoint(
                self.best_checkpoint_filename,
                monitor=self.monitor_metric,
                mode=self.monitor_mode,
                save_best_only=True,
                restore_best=True,
                verbose=True,
                temporary_filename=self.best_checkpoint_tmp_filename)
            if initial_epoch > 1:
                # We set the current best metric score in the ModelCheckpoint
                # so that it does not save checkpoint it would not have saved
                # if the optimization was not stopped.
                best_epoch_stats = self.get_best_epoch_stats()
                best_epoch = best_epoch_stats['epoch'].item()
                best_checkpoint.best_filename = \
                    self.best_checkpoint_filename.format(epoch=best_epoch)
                best_checkpoint.current_best = \
                    best_epoch_stats[self.monitor_metric].item()

            checkpoint = ModelCheckpoint(
                self.model_checkpoint_filename,
                verbose=False,
                temporary_filename=self.model_checkpoint_tmp_filename)
            optimizer_checkpoint = OptimizerCheckpoint(
                self.optimizer_checkpoint_filename,
                verbose=False,
                temporary_filename=self.optimizer_checkpoint_tmp_filename)

            # We save the last epoch number after the end of the epoch so that
            # the load_epoch_state() knows which epoch to restart the
            # optimization.
            save_epoch_number = PeriodicSaveLambda(
                lambda fd, epoch, logs: print(epoch, file=fd),
                self.epoch_filename,
                temporary_filename=self.epoch_tmp_filename,
                open_mode='w')

            callbacks += [csv_logger, best_checkpoint, checkpoint,
                          optimizer_checkpoint, save_epoch_number]

            if not disable_tensorboard:
                if SummaryWriter is None:
                    warnings.warn(
                        "tensorboardX does not seem to be installed. "
                        "To remove this warning, set the "
                        "'disable_tensorboard' flag to True.")
                else:
                    writer = SummaryWriter(self.tensorboard_directory)
                    callbacks.append(TensorBoardLogger(writer))

            for i, lr_scheduler in enumerate(lr_schedulers):
                filename = self.lr_scheduler_filename % i
                tmp_filename = self.lr_scheduler_tmp_filename % i
                callbacks.append(
                    LRSchedulerCheckpoint(lr_scheduler,
                                          filename,
                                          verbose=False,
                                          temporary_filename=tmp_filename))
        else:
            callbacks.extend(lr_schedulers)
            callbacks.append(
                BestModelRestore(monitor=self.monitor_metric,
                                 mode=self.monitor_mode))

        self.model.fit_generator(train_loader, valid_loader,
                                 epochs=epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps,
                                 initial_epoch=initial_epoch,
                                 callbacks=callbacks)