Example #1
 def test_reduce_lr_on_plateau_integration(self):
     train_gen = some_data_generator(OptimizerCheckpointTest.batch_size)
     valid_gen = some_data_generator(OptimizerCheckpointTest.batch_size)
     reduce_lr = ReduceLROnPlateau(monitor='loss', patience=3)
     checkpointer = LRSchedulerCheckpoint(reduce_lr,
                                          self.checkpoint_filename,
                                          period=1)
     self.model.fit_generator(train_gen,
                              valid_gen,
                              epochs=OptimizerCheckpointTest.epochs,
                              steps_per_epoch=5,
                              callbacks=[checkpointer])
Example #2
 def test_any_scheduler_integration(self):
     train_gen = some_data_generator(OptimizerCheckpointTest.batch_size)
     valid_gen = some_data_generator(OptimizerCheckpointTest.batch_size)
     lr_scheduler = ExponentialLR(gamma=0.01)
     checkpointer = LRSchedulerCheckpoint(lr_scheduler,
                                          self.checkpoint_filename,
                                          period=1)
     self.model.fit_generator(train_gen,
                              valid_gen,
                              epochs=OptimizerCheckpointTest.epochs,
                              steps_per_epoch=5,
                              callbacks=[checkpointer])
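Both tests above depend on a some_data_generator fixture from the test suite. A minimal sketch of what such a helper could look like, assuming all fit_generator needs is an iterator that yields random (inputs, targets) batches indefinitely:

 import torch

 def some_data_generator(batch_size):
     # Yield random regression batches forever; fit_generator consumes
     # steps_per_epoch batches per epoch and then moves on.
     while True:
         x = torch.rand(batch_size, 1)
         y = torch.rand(batch_size, 1)
         yield x, y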
Example #3
 def _init_lr_scheduler_callbacks(self, lr_schedulers):
     callbacks = []
     if self.logging:
         for i, lr_scheduler in enumerate(lr_schedulers):
             filename = self.lr_scheduler_filename % i
             tmp_filename = self.lr_scheduler_tmp_filename % i
             callbacks += [
                 LRSchedulerCheckpoint(lr_scheduler,
                                       filename,
                                       verbose=False,
                                       temporary_filename=tmp_filename)
             ]
     else:
         callbacks += lr_schedulers
         callbacks += [
             BestModelRestore(monitor=self.monitor_metric,
                              mode=self.monitor_mode,
                              verbose=True)
         ]
     return callbacks
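When logging is enabled, each scheduler is wrapped in an LRSchedulerCheckpoint so that its state is written to disk and survives a restart; otherwise the schedulers are passed through as plain callbacks and a BestModelRestore reloads the best weights at the end of training. The %-formatting above implies filename templates with an integer placeholder, one file per scheduler index; a hypothetical pair consistent with that usage (the real attribute values are set elsewhere in the class):

 # Illustrative values only; '%d' is filled with the scheduler's
 # position in the lr_schedulers list.
 lr_scheduler_filename = 'lr_sched_%d.ckpt'          # -> lr_sched_0.ckpt, ...
 lr_scheduler_tmp_filename = 'lr_sched_%d.tmp.ckpt'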
Example #4
 def test_reduce_lr_checkpoints(self):
     reduce_lr = ReduceLROnPlateau(monitor='loss', patience=3)
     checkpointer = LRSchedulerCheckpoint(reduce_lr,
                                          self.checkpoint_filename,
                                          period=1)
     self._test_checkpointer(checkpointer, reduce_lr)
Example #5
 def test_any_scheduler_checkpoints(self):
     lr_scheduler = ExponentialLR(gamma=0.01)
     checkpointer = LRSchedulerCheckpoint(lr_scheduler,
                                          self.checkpoint_filename,
                                          period=1)
     self._test_checkpointer(checkpointer, lr_scheduler)
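Examples #4 and #5 delegate to a _test_checkpointer helper that is not shown on this page. A rough sketch of what it might verify, assuming checkpoint_filename contains an {epoch} placeholder and that period=1 produces one file per epoch (both are assumptions, not the suite's confirmed code):

 import os

 def _test_checkpointer(self, checkpointer, lr_scheduler):
     train_gen = some_data_generator(OptimizerCheckpointTest.batch_size)
     valid_gen = some_data_generator(OptimizerCheckpointTest.batch_size)
     self.model.fit_generator(train_gen,
                              valid_gen,
                              epochs=OptimizerCheckpointTest.epochs,
                              steps_per_epoch=5,
                              callbacks=[checkpointer])
     # With period=1, a scheduler checkpoint should exist for every epoch.
     for epoch in range(1, OptimizerCheckpointTest.epochs + 1):
         self.assertTrue(os.path.isfile(self.checkpoint_filename.format(epoch=epoch)))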
Example #6
    def train(self, train_loader, valid_loader=None,
              callbacks=[], lr_schedulers=[],
              disable_tensorboard=False,
              epochs=1000, steps_per_epoch=None, validation_steps=None,
              seed=42):
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # Copy callback list.
        callbacks = list(callbacks)

        initial_epoch = 1
        if self.logging:
            if not os.path.exists(self.directory):
                os.makedirs(self.directory)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            csv_logger = CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)

            best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                              monitor=self.monitor_metric,
                                              mode=self.monitor_mode,
                                              save_best_only=True,
                                              restore_best=True,
                                              verbose=True,
                                              temporary_filename=self.best_checkpoint_tmp_filename)
            if initial_epoch > 1:
                # We set the current best metric score in the ModelCheckpoint so that
                # it does not save a checkpoint it would not have saved had the
                # optimization not been interrupted.
                best_epoch_stats = self.get_best_epoch_stats()
                best_epoch = best_epoch_stats['epoch'].item()
                best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
                best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

            checkpoint = ModelCheckpoint(self.model_checkpoint_filename,
                                         verbose=False,
                                         temporary_filename=self.model_checkpoint_tmp_filename)
            optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename,
                                                       verbose=False,
                                                       temporary_filename=self.optimizer_checkpoint_tmp_filename)

            # We save the last epoch number at the end of each epoch so that
            # _load_epoch_state() knows at which epoch to restart the optimization.
            save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                                   self.epoch_filename,
                                                   temporary_filename=self.epoch_tmp_filename,
                                                   open_mode='w')

            callbacks += [csv_logger, best_checkpoint, checkpoint, optimizer_checkpoint, save_epoch_number]

            if not disable_tensorboard:
                if SummaryWriter is None:
                    warnings.warn("tensorboardX does not seem to be installed. To remove this warning, set the 'disable_tensorboard' flag to True.")
                else:
                    writer = SummaryWriter(self.tensorboard_directory)
                    tensorboard_logger = TensorBoardLogger(writer)
                    callbacks.append(tensorboard_logger)
            for i, lr_scheduler in enumerate(lr_schedulers):
                filename = self.lr_scheduler_filename % i
                tmp_filename = self.lr_scheduler_tmp_filename % i
                lr_scheduler_checkpoint = LRSchedulerCheckpoint(lr_scheduler,
                                                                filename,
                                                                verbose=False,
                                                                temporary_filename=tmp_filename)
                callbacks.append(lr_scheduler_checkpoint)
        else:
            for lr_scheduler in lr_schedulers:
                callbacks.append(lr_scheduler)
            best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode)
            callbacks.append(best_restore)

        self.model.fit_generator(train_loader, valid_loader,
                                 epochs=epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps,
                                 initial_epoch=initial_epoch,
                                 callbacks=callbacks)
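For context, a minimal driver for the train method above. Here expt stands in for an instance of the class that defines train, and ExponentialLR is the same scheduler wrapper used in Examples #2 and #5; these names are illustrative assumptions, not confirmed API:

 import torch
 from torch.utils.data import DataLoader, TensorDataset

 dataset = TensorDataset(torch.rand(100, 1), torch.rand(100, 1))
 train_loader = DataLoader(dataset, batch_size=10)
 valid_loader = DataLoader(dataset, batch_size=10)

 # expt is assumed to expose the train() method shown above.
 expt.train(train_loader, valid_loader,
            lr_schedulers=[ExponentialLR(gamma=0.95)],
            epochs=5)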