Example #1
    def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
        callbacks = []
        best_checkpoint = ModelCheckpoint(
            self.best_checkpoint_filename,
            monitor=self.monitor_metric,
            mode=self.monitor_mode,
            save_best_only=not save_every_epoch,
            restore_best=not save_every_epoch,
            verbose=not save_every_epoch,
            temporary_filename=self.best_checkpoint_tmp_filename)
        callbacks.append(best_checkpoint)

        if save_every_epoch:
            best_restore = BestModelRestore(monitor=self.monitor_metric,
                                            mode=self.monitor_mode,
                                            verbose=True)
            callbacks.append(best_restore)

        if initial_epoch > 1:
            # Set the current best metric score in the ModelCheckpoint so that
            # it does not save a checkpoint that it would not have saved had the
            # optimization not been interrupted.
            best_epoch_stats = self.get_best_epoch_stats()
            best_epoch = best_epoch_stats['epoch'].item()
            best_filename = self.best_checkpoint_filename.format(
                epoch=best_epoch)
            if not save_every_epoch:
                best_checkpoint.best_filename = best_filename
                best_checkpoint.current_best = best_epoch_stats[
                    self.monitor_metric].item()
            else:
                best_restore.best_weights = torch.load(best_filename,
                                                       map_location='cpu')
                best_restore.current_best = best_epoch_stats[
                    self.monitor_metric].item()

        return callbacks
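Note: the helper above is a method of an experiment-style wrapper and relies on attributes such as self.best_checkpoint_filename, self.monitor_metric and self.monitor_mode being set elsewhere. The standalone sketch below only illustrates the two restore strategies it builds; the import path, the metric name and the literal filenames are assumptions and may need adjusting for the installed version of the library.

# Hypothetical standalone sketch of the two restore strategies above
# (filenames, metric name and import path are illustrative assumptions).
from poutyne.framework.callbacks import ModelCheckpoint, BestModelRestore

monitor_metric, monitor_mode = 'val_loss', 'min'

# Strategy 1 (save_every_epoch=False): keep only the best checkpoint and
# restore its weights once training ends.
best_only_checkpoint = ModelCheckpoint('best.ckpt',
                                       monitor=monitor_metric,
                                       mode=monitor_mode,
                                       save_best_only=True,
                                       restore_best=True,
                                       verbose=True,
                                       temporary_filename='best.tmp.ckpt')

# Strategy 2 (save_every_epoch=True): save a checkpoint per epoch and let a
# separate BestModelRestore callback put the best weights back afterwards.
every_epoch_checkpoint = ModelCheckpoint('checkpoint_epoch_{epoch}.ckpt', verbose=False)
best_restore = BestModelRestore(monitor=monitor_metric, mode=monitor_mode, verbose=True)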
Example #2
    def train(self, train_loader, valid_loader=None,
              callbacks=None, lr_schedulers=None,
              disable_tensorboard=False,
              epochs=1000, steps_per_epoch=None, validation_steps=None,
              seed=42):
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Avoid mutable default arguments and copy the callback list.
        callbacks = [] if callbacks is None else list(callbacks)
        lr_schedulers = [] if lr_schedulers is None else lr_schedulers

        initial_epoch = 1
        if self.logging:
            if not os.path.exists(self.directory):
                os.makedirs(self.directory)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            csv_logger = CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)

            best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                              monitor=self.monitor_metric,
                                              mode=self.monitor_mode,
                                              save_best_only=True,
                                              restore_best=True,
                                              verbose=True,
                                              temporary_filename=self.best_checkpoint_tmp_filename)
            if initial_epoch > 1:
                # Set the current best metric score in the ModelCheckpoint so that
                # it does not save a checkpoint that it would not have saved had the
                # optimization not been interrupted.
                best_epoch_stats = self.get_best_epoch_stats()
                best_epoch = best_epoch_stats['epoch'].item()
                best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
                best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

            checkpoint = ModelCheckpoint(self.model_checkpoint_filename, verbose=False, temporary_filename=self.model_checkpoint_tmp_filename)
            optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename, verbose=False, temporary_filename=self.optimizer_checkpoint_tmp_filename)

            # Save the last epoch number at the end of each epoch so that
            # _load_epoch_state() knows from which epoch to restart the optimization.
            save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                                   self.epoch_filename,
                                                   temporary_filename=self.epoch_tmp_filename,
                                                   open_mode='w')

            callbacks += [csv_logger, best_checkpoint, checkpoint, optimizer_checkpoint, save_epoch_number]

            if not disable_tensorboard:
                if SummaryWriter is None:
                    warnings.warn("tensorboardX does not seem to be installed. To remove this warning, set the 'disable_tensorboard' flag to True.")
                else:
                    writer = SummaryWriter(self.tensorboard_directory)
                    tensorboard_logger = TensorBoardLogger(writer)
                    callbacks.append(tensorboard_logger)
            for i, lr_scheduler in enumerate(lr_schedulers):
                filename = self.lr_scheduler_filename % i
                tmp_filename = self.lr_scheduler_tmp_filename % i
                lr_scheduler_checkpoint = LRSchedulerCheckpoint(lr_scheduler, filename, verbose=False, temporary_filename=tmp_filename)
                callbacks.append(lr_scheduler_checkpoint)
        else:
            for lr_scheduler in lr_schedulers:
                callbacks.append(lr_scheduler)
            best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode)
            callbacks.append(best_restore)

        self.model.fit_generator(train_loader, valid_loader,
                                 epochs=epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps,
                                 initial_epoch=initial_epoch,
                                 callbacks=callbacks)
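A caveat on the seeding at the top of train(): seeding random, numpy and torch makes runs repeatable on CPU, but fully reproducible CUDA training usually also requires forcing cuDNN into deterministic mode. A minimal addendum using the standard PyTorch flags (not part of the snippet above):

import torch

# Trade some speed for reproducibility on CUDA; complements the seeding in train().
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False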