Example #1
    def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
        callbacks = []
        best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                          monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          save_best_only=not save_every_epoch,
                                          restore_best=not save_every_epoch,
                                          verbose=not save_every_epoch,
                                          temporary_filename=self.best_checkpoint_tmp_filename)
        callbacks.append(best_checkpoint)

        if save_every_epoch:
            best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
            callbacks.append(best_restore)

        if initial_epoch > 1:
            # We set the current best metric score in the ModelCheckpoint so that
            # it does not save a checkpoint it would not have saved if the
            # optimization had not been stopped.
            best_epoch_stats = self.get_best_epoch_stats()
            best_epoch = best_epoch_stats['epoch'].item()
            best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
            if not save_every_epoch:
                best_checkpoint.best_filename = best_filename
                best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()
            else:
                best_restore.best_weights = torch.load(best_filename, map_location='cpu')
                best_restore.current_best = best_epoch_stats[self.monitor_metric].item()

        return callbacks
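
The resume branch above rebuilds state from the checkpoint written during the interrupted run: it formats the best epoch number into the checkpoint filename template and reads the best metric value back from the epoch statistics. Below is a minimal standalone sketch of that logic (not Poutyne's implementation); the filename template and the one-row pandas DataFrame are hypothetical stand-ins for self.best_checkpoint_filename and the result of get_best_epoch_stats().

import pandas as pd

# Hypothetical stand-ins for the attributes used in the method above.
best_checkpoint_filename = 'checkpoint_epoch_{epoch}.ckpt'
monitor_metric = 'val_loss'

# get_best_epoch_stats() is assumed to return a one-row DataFrame with the best epoch's logs.
best_epoch_stats = pd.DataFrame([{'epoch': 7, 'val_loss': 0.123}])

best_epoch = best_epoch_stats['epoch'].item()                        # 7
best_filename = best_checkpoint_filename.format(epoch=best_epoch)    # 'checkpoint_epoch_7.ckpt'
current_best = best_epoch_stats[monitor_metric].item()               # 0.123
print(best_filename, current_best)
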
Example #2
    def test_save_best_only_with_max(self):
        model_restore = BestModelRestore(monitor='val_loss', mode='max')

        val_losses = [3, 2, 8, 5, 4]
        best_epoch = 3
        self._test_restore_with_val_losses(model_restore, val_losses,
                                           best_epoch)
Example #3
    def test_basic_restore(self):
        model_restore = BestModelRestore(monitor='val_loss')

        val_losses = [3, 2, 8, 5, 4]
        best_epoch = 2
        self._test_restore_with_val_losses(model_restore, val_losses,
                                           best_epoch)
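
Both tests above feed the same validation losses and only differ in the monitoring mode. As a quick standalone sketch (not BestModelRestore's actual code), this is the selection rule they expect, using 1-based epoch numbers:

val_losses = [3, 2, 8, 5, 4]
epochs = range(1, len(val_losses) + 1)

best_epoch_min = min(epochs, key=lambda e: val_losses[e - 1])  # mode='min' (default): lowest value wins
best_epoch_max = max(epochs, key=lambda e: val_losses[e - 1])  # mode='max': highest value wins

assert best_epoch_min == 2  # expected by test_basic_restore
assert best_epoch_max == 3  # expected by test_save_best_only_with_max
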
Example #4
    def test_integration(self):
        train_gen = some_data_generator(20)
        valid_gen = some_data_generator(20)
        model_restore = BestModelRestore(monitor='val_loss', verbose=True)
        self.model.fit_generator(train_gen,
                                 valid_gen,
                                 epochs=10,
                                 steps_per_epoch=5,
                                 callbacks=[model_restore])
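
The integration test relies on fixtures (self.model and some_data_generator) that are not shown here. The sketch below is a self-contained version of the same idea, assuming a recent Poutyne release where Model and BestModelRestore can be imported from the top-level package; some_data_generator is a hypothetical stand-in for the test helper of the same name.

import torch
import torch.nn as nn
from poutyne import Model, BestModelRestore


def some_data_generator(batch_size):
    # Infinite generator of random (input, target) batches.
    while True:
        x = torch.rand(batch_size, 1)
        y = 2 * x + 0.1 * torch.randn(batch_size, 1)
        yield x, y


network = nn.Linear(1, 1)
model = Model(network, torch.optim.SGD(network.parameters(), lr=0.01), nn.MSELoss())

model_restore = BestModelRestore(monitor='val_loss', verbose=True)
model.fit_generator(some_data_generator(20),
                    some_data_generator(20),
                    epochs=10,
                    steps_per_epoch=5,
                    validation_steps=5,
                    callbacks=[model_restore])
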
Example #5
    def _init_lr_scheduler_callbacks(self, lr_schedulers):
        callbacks = []
        if self.logging:
            for i, lr_scheduler in enumerate(lr_schedulers):
                filename = self.lr_scheduler_filename % i
                tmp_filename = self.lr_scheduler_tmp_filename % i
                callbacks += [
                    LRSchedulerCheckpoint(lr_scheduler, filename, verbose=False, temporary_filename=tmp_filename)
                ]
        else:
            callbacks += lr_schedulers
            callbacks += [BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)]
        return callbacks
Example #6
    def fit(self, meta_train, meta_valid, meta_test=None, n_epochs=100, steps_per_epoch=100, log_filename=None,
            mse_filename=None, checkpoint_filename=None, tboard_folder=None, grads_inspect_dir=None,
            graph_flow_filename=None, do_early_stopping=True, mse_test=False, config=None):

        if hasattr(self.model, 'is_eval'):
            self.model.is_eval = False
        self.is_eval = False

        try:
            self.model.filename = log_filename[:-8]
        except TypeError:
            # log_filename is None and cannot be sliced; fall back to a default name.
            self.model.filename = 'test'

        self.steps_per_epoch = steps_per_epoch

        callbacks = [ReduceLROnPlateau(patience=2, factor=1 / 2, min_lr=1e-6, verbose=True),
                     BestModelRestore(verbose=True)]
        if do_early_stopping:
            callbacks.append(EarlyStopping(patience=10, verbose=False))

        if log_filename:
            callbacks.append(CSVLogger(log_filename, batch_granularity=False, separator='\t'))

        if mse_test:
            callbacks.append(MseMetaTest(meta_test=meta_test, filename=mse_filename, periodicity='epoch'))

        if checkpoint_filename:
            callbacks.append(ModelCheckpoint(checkpoint_filename, monitor='val_loss', save_best_only=True,
                                             temporary_filename=checkpoint_filename + 'temp'))

        if tboard_folder is not None:
            self.writer = SummaryWriter(tboard_folder)
            self.plain_writer = PlainWriter(tboard_folder)

        self.fit_generator(meta_train, meta_valid,
                           epochs=n_epochs,
                           steps_per_epoch=steps_per_epoch,
                           validation_steps=steps_per_epoch,
                           callbacks=callbacks,
                           verbose=True)
        self.is_fitted = True

        if self.plain_writer is not None:
            self.plain_writer.close()

        return self