Exemplo n.º 1
0
def train(model, model_name, train_loader, valid_loader, epochs=1000,
          model_path='./models/', logger_path='./train_logs/'):
    """Fit ``model`` with LR scheduling, early stopping, checkpointing and CSV logging.

    :param model: object exposing ``fit_generator`` (e.g. a Poutyne ``Model``).
    :param model_name: base name used for every checkpoint and log file.
    :param train_loader: training data loader / generator.
    :param valid_loader: validation data loader / generator.
    :param epochs: maximum number of epochs (early stopping may end training sooner).
    :param model_path: directory receiving the checkpoints; created if missing.
    :param logger_path: directory receiving the CSV training log; created if missing.
    """
    # Learning-rate reduction after 3 stagnant epochs, full stop after 10.
    lrscheduler = ReduceLROnPlateau(patience=3, verbose=True)
    early_stopping = EarlyStopping(patience=10, min_delta=1e-4, verbose=True)

    os.makedirs(model_path, exist_ok=True)
    # "best" checkpoint: only written when the monitored metric improves;
    # restore_best=True asks the callback to restore those weights afterwards.
    # Saves go through a temporary file first.
    ckpt_best = ModelCheckpoint(
        os.path.join(model_path, 'best_' + model_name + '.torch'),
        save_best_only=True,
        restore_best=True,
        temporary_filename=os.path.join(model_path, 'temp_best_' + model_name + '.torch'),
        verbose=True)

    # "last" checkpoint: saved with ModelCheckpoint's default period, regardless
    # of whether the metric improved.
    ckpt_last = ModelCheckpoint(
        os.path.join(model_path, 'last_' + model_name + '.torch'),
        temporary_filename=os.path.join(model_path, 'temp_last_' + model_name + '.torch'))

    os.makedirs(logger_path, exist_ok=True)
    csv_logger = CSVLogger(os.path.join(logger_path, model_name + '.csv'))

    callbacks = [lrscheduler, ckpt_best, ckpt_last, early_stopping, csv_logger]

    # Fit the model
    model.fit_generator(train_loader, valid_loader,
                        epochs=epochs, callbacks=callbacks)
Exemplo n.º 2
0
    def test_restore_best_without_save_best_only(self):
        """ModelCheckpoint must reject restore_best=True unless save_best_only=True.

        Both an explicit ``save_best_only=False`` and leaving ``save_best_only``
        at its default are expected to raise ``ValueError``.
        """
        invalid_configurations = [
            {'save_best_only': False, 'restore_best': True},
            {'restore_best': True},
        ]
        for extra_kwargs in invalid_configurations:
            with self.assertRaises(ValueError):
                ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True, **extra_kwargs)
Exemplo n.º 3
0
 def test_non_atomic_write(self):
     """With atomic_write=False the checkpoint file must still be written to disk."""
     target_path = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.ckpt')
     batch_size = ModelCheckpointTest.batch_size
     train_generator = some_data_generator(batch_size)
     valid_generator = some_data_generator(batch_size)
     checkpointer = ModelCheckpoint(
         target_path,
         monitor='val_loss',
         verbose=True,
         period=1,
         atomic_write=False,
     )
     self.model.fit_generator(
         train_generator,
         valid_generator,
         epochs=10,
         steps_per_epoch=5,
         callbacks=[checkpointer],
     )
     self.assertTrue(os.path.isfile(target_path))
Exemplo n.º 4
0
    def fit(self,
            x_train,
            y_train,
            x_valid,
            y_valid,
            n_epochs=100,
            batch_size=32,
            log_filename=None,
            checkpoint_filename=None,
            with_early_stopping=True):
        """
        Train the wrapped model on in-memory data through ``fit_generator``.

        :param x_train: training set examples
        :param y_train: training set labels
        :param x_valid: validation set examples
        :param y_valid: validation set labels
        :param n_epochs: int, maximum number of epochs, default value 100
        :param batch_size: int, size of the batch, default value 32 (the original
            author notes it must be a multiple of 2)
        :param log_filename: optional, path of a tab-separated, epoch-level CSV log
        :param checkpoint_filename: optional, path where the best model according
            to val_loss is saved
        :param with_early_stopping: whether to stop once val_loss has not improved
            for 3 epochs
        :return: self, the model
        """
        callbacks = []
        if with_early_stopping:
            early_stopping = EarlyStopping(monitor='val_loss',
                                           patience=3,
                                           verbose=0)
            callbacks += [early_stopping]
        # NOTE: the LR schedule watches the *training* loss, unlike early stopping.
        reduce_lr = ReduceLROnPlateau(monitor='loss',
                                      patience=2,
                                      factor=1 / 10,
                                      min_lr=1e-6)
        best_model_restore = BestModelRestore()
        callbacks += [reduce_lr, best_model_restore]
        if log_filename:
            logger = CSVLogger(log_filename,
                               batch_granularity=False,
                               separator='\t')
            callbacks += [logger]
        if checkpoint_filename:
            checkpointer = ModelCheckpoint(checkpoint_filename,
                                           monitor='val_loss',
                                           save_best_only=True)
            callbacks += [checkpointer]

        # Training must run whether or not a checkpoint file was requested.
        # Steps are truncated, so a trailing partial batch is not counted.
        nb_steps_train = len(x_train) // batch_size
        nb_step_valid = len(x_valid) // batch_size
        self.model.fit_generator(
            generator(x_train, y_train, batch_size),
            steps_per_epoch=nb_steps_train,
            valid_generator=generator(x_valid, y_valid, batch_size),
            validation_steps=nb_step_valid,
            epochs=n_epochs,
            callbacks=callbacks,
        )
        return self
Exemplo n.º 5
0
    def test_save_best_only_with_restore_best(self):
        """Checkpoints are written only on new val_loss minima, then the best weights are restored."""
        # New minima occur at 10, 3 and 2; the 8 and 5 epochs must not checkpoint.
        val_losses = [10, 3, 8, 5, 2]
        has_checkpoints = [True, True, False, False, True]

        checkpointer = ModelCheckpoint(self.checkpoint_filename,
                                       monitor='val_loss',
                                       verbose=True,
                                       save_best_only=True,
                                       restore_best=True)
        self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)
        self._test_restore_best(val_losses)
Exemplo n.º 6
0
    def test_periodic_with_period_of_2(self):
        """With period=2 and save_best_only=False, only every second epoch is checkpointed."""
        num_epochs = 10
        # A constant loss ensures "improvement" plays no role in what gets saved.
        val_losses = [1 for _ in range(num_epochs)]
        # Epochs are 1-based: expect saves at epochs 2, 4, ..., 10.
        has_checkpoints = [epoch % 2 == 0 for epoch in range(1, num_epochs + 1)]

        checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True,
                                       period=2, save_best_only=False)
        self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)
Exemplo n.º 7
0
    def test_save_best_only_with_max(self):
        """In mode='max', only epochs that set a new maximum of the monitored value are saved."""
        # Running maxima are 2, 3, 8; the final 5 and 2 do not beat 8.
        val_losses = [2, 3, 8, 5, 2]
        has_checkpoints = [True, True, True, False, False]

        checkpointer = ModelCheckpoint(
            self.checkpoint_filename,
            monitor='val_loss',
            mode='max',
            verbose=True,
            save_best_only=True,
        )
        self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)
Exemplo n.º 8
0
 def test_temporary_filename_arg_with_differing_checkpoint_filename(self):
     """A per-epoch filename pattern yields one file per epoch; the temporary file is not left behind."""
     epochs = 10
     work_dir = self.temp_dir_obj.name
     tmp_filename = os.path.join(work_dir, 'my_checkpoint.tmp.ckpt')
     checkpoint_filename = os.path.join(work_dir, 'my_checkpoint_{epoch}.ckpt')
     batch_size = ModelCheckpointTest.batch_size
     train_gen = some_data_generator(batch_size)
     valid_gen = some_data_generator(batch_size)
     checkpointer = ModelCheckpoint(checkpoint_filename,
                                    monitor='val_loss',
                                    verbose=True,
                                    period=1,
                                    temporary_filename=tmp_filename)
     self.model.fit_generator(train_gen, valid_gen,
                              epochs=epochs,
                              steps_per_epoch=5,
                              callbacks=[checkpointer])

     self.assertFalse(os.path.isfile(tmp_filename))
     expected_files = (checkpoint_filename.format(epoch=epoch) for epoch in range(1, epochs + 1))
     for expected in expected_files:
         self.assertTrue(os.path.isfile(expected))
Exemplo n.º 9
0
 def test_integration(self):
     """Smoke test: a save-best-only ModelCheckpoint runs through a full fit_generator call."""
     batch_size = ModelCheckpointTest.batch_size
     # Training generator first, validation generator second.
     generators = [some_data_generator(batch_size) for _ in range(2)]
     checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss',
                                    verbose=True, save_best_only=True)
     self.model.fit_generator(*generators, epochs=10, steps_per_epoch=5, callbacks=[checkpointer])
Exemplo n.º 10
0
    def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
        """Build the callbacks that track (and later restore) the best model.

        If ``save_every_epoch`` is falsy, a single ``ModelCheckpoint`` saves only
        improving epochs and is configured with ``restore_best=True``. Otherwise the
        checkpoint saves every epoch, and a separate ``BestModelRestore`` callback
        tracks the best weights.

        When resuming (``initial_epoch > 1``), the best epoch recorded so far (from
        ``self.get_best_epoch_stats()``) is injected into whichever callback tracks
        the best model. This keeps it from saving checkpoints it would not have
        saved had optimization not been interrupted.

        :return: list of callbacks, checkpoint first.
        """
        only_best = not save_every_epoch
        best_checkpoint = ModelCheckpoint(
            self.best_checkpoint_filename,
            monitor=self.monitor_metric,
            mode=self.monitor_mode,
            save_best_only=only_best,
            restore_best=only_best,
            verbose=only_best,
            temporary_filename=self.best_checkpoint_tmp_filename)
        callbacks = [best_checkpoint]

        best_restore = None
        if save_every_epoch:
            best_restore = BestModelRestore(monitor=self.monitor_metric,
                                            mode=self.monitor_mode,
                                            verbose=True)
            callbacks.append(best_restore)

        # Fresh run: there is no prior best score to carry over.
        if initial_epoch <= 1:
            return callbacks

        stats = self.get_best_epoch_stats()
        best_epoch = stats['epoch'].item()
        best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
        if only_best:
            best_checkpoint.best_filename = best_filename
            best_checkpoint.current_best = stats[self.monitor_metric].item()
        else:
            # Weights are loaded onto the CPU regardless of the device they were saved from.
            best_restore.best_weights = torch.load(best_filename, map_location='cpu')
            best_restore.current_best = stats[self.monitor_metric].item()
        return callbacks
Exemplo n.º 11
0
    def fit(self,
            meta_train,
            meta_valid,
            n_epochs=100,
            steps_per_epoch=100,
            log_filename=None,
            checkpoint_filename=None,
            tboard_folder=None):
        """Train through ``self.fit_generator`` using ``meta_train`` and validate on ``meta_valid``.

        :param meta_train: training generator.
        :param meta_valid: validation generator.
        :param n_epochs: maximum number of epochs (early stopping uses patience 10).
        :param steps_per_epoch: steps per epoch, also used as the validation step count.
        :param log_filename: optional path for a tab-separated, epoch-level CSV log.
        :param checkpoint_filename: optional path for a best-val_loss checkpoint;
            ``'temp'`` is appended to it to form the temporary filename.
        :param tboard_folder: optional folder for a ``SummaryWriter`` stored on ``self.writer``.
        :return: self, with ``is_fitted`` set to True.
        """
        # Clear the evaluation flags on the wrapped model (when it has one) and on self.
        if hasattr(self.model, 'is_eval'):
            self.model.is_eval = False
        self.is_eval = False
        self.steps_per_epoch = steps_per_epoch

        early_stopping = EarlyStopping(patience=10, verbose=False)
        lr_plateau = ReduceLROnPlateau(patience=2, factor=1 / 2, min_lr=1e-6, verbose=True)
        callbacks = [early_stopping, lr_plateau, BestModelRestore()]

        if log_filename:
            callbacks.append(CSVLogger(log_filename, batch_granularity=False, separator='\t'))

        if checkpoint_filename:
            tmp_checkpoint = checkpoint_filename + 'temp'
            callbacks.append(ModelCheckpoint(checkpoint_filename,
                                             monitor='val_loss',
                                             save_best_only=True,
                                             temporary_filename=tmp_checkpoint))

        if tboard_folder is not None:
            self.writer = SummaryWriter(tboard_folder)

        self.fit_generator(meta_train, meta_valid,
                           epochs=n_epochs,
                           steps_per_epoch=steps_per_epoch,
                           validation_steps=steps_per_epoch,
                           callbacks=callbacks,
                           verbose=True)
        self.is_fitted = True
        return self
Exemplo n.º 12
0
 def test_temporary_filename_arg(self):
     """After training, the temporary file must be gone and the final checkpoint present."""
     work_dir = self.temp_dir_obj.name
     tmp_filename = os.path.join(work_dir, 'my_checkpoint.tmp.ckpt')
     checkpoint_filename = os.path.join(work_dir, 'my_checkpoint.ckpt')
     train_gen, valid_gen = some_data_generator(20), some_data_generator(20)
     checkpointer = ModelCheckpoint(checkpoint_filename, monitor='val_loss', verbose=True,
                                    period=1, temporary_filename=tmp_filename)
     self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5,
                              callbacks=[checkpointer])
     self.assertFalse(os.path.isfile(tmp_filename))
     self.assertTrue(os.path.isfile(checkpoint_filename))
Exemplo n.º 13
0
    def train(self,
              train_loader,
              valid_loader=None,
              *,
              callbacks=None,
              lr_schedulers=None,
              save_every_epoch=False,
              disable_tensorboard=False,
              epochs=1000,
              steps_per_epoch=None,
              validation_steps=None,
              seed=42):
        """Train ``self.model`` with ``fit_generator``, adding logging/checkpoint callbacks.

        :param train_loader: training data loader / generator.
        :param valid_loader: optional validation data loader / generator.
        :param callbacks: extra callbacks; the caller's list is copied, never mutated.
            Defaults to no extra callbacks.
        :param lr_schedulers: LR scheduler callbacks; defaults to none.
        :param save_every_epoch: forwarded to ``_init_model_restoring_callbacks``.
        :param disable_tensorboard: forwarded to ``_init_tensorboard_callbacks``.
        :param epochs: last epoch to train up to.
        :param steps_per_epoch: forwarded to ``fit_generator``.
        :param validation_steps: forwarded to ``fit_generator``.
        :param seed: if not None, seeds ``random``, NumPy and PyTorch.
        """
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # ``None`` sentinels instead of shared mutable ``[]`` defaults.
        # Copy the callback list so the appends below never touch the caller's list.
        callbacks = [] if callbacks is None else list(callbacks)
        if lr_schedulers is None:
            lr_schedulers = []

        tensorboard_writer = None
        initial_epoch = 1
        if self.logging:
            os.makedirs(self.directory, exist_ok=True)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            # Append to the existing log when resuming instead of overwriting it.
            callbacks += [
                CSVLogger(self.log_filename,
                          separator='\t',
                          append=initial_epoch != 1)
            ]

            callbacks += self._init_model_restoring_callbacks(
                initial_epoch, save_every_epoch)
            callbacks += [
                ModelCheckpoint(
                    self.model_checkpoint_filename,
                    verbose=False,
                    temporary_filename=self.model_checkpoint_tmp_filename)
            ]
            callbacks += [
                OptimizerCheckpoint(
                    self.optimizer_checkpoint_filename,
                    verbose=False,
                    temporary_filename=self.optimizer_checkpoint_tmp_filename)
            ]

            # We save the last epoch number after the end of the epoch so that the
            # _load_epoch_state() knows which epoch to restart the optimization.
            callbacks += [
                PeriodicSaveLambda(
                    lambda fd, epoch, logs: print(epoch, file=fd),
                    self.epoch_filename,
                    temporary_filename=self.epoch_tmp_filename,
                    open_mode='w')
            ]

            tensorboard_writer, cb_list = self._init_tensorboard_callbacks(
                disable_tensorboard)
            callbacks += cb_list

        # This method returns callbacks that checkpoints the LR scheduler if logging is enabled.
        # Otherwise, it just returns the list of LR schedulers with a BestModelRestore callback.
        callbacks += self._init_lr_scheduler_callbacks(lr_schedulers)

        try:
            self.model.fit_generator(train_loader,
                                     valid_loader,
                                     epochs=epochs,
                                     steps_per_epoch=steps_per_epoch,
                                     validation_steps=validation_steps,
                                     initial_epoch=initial_epoch,
                                     callbacks=callbacks)
        finally:
            # Release the TensorBoard writer even if training raises.
            if tensorboard_writer is not None:
                tensorboard_writer.close()
Exemplo n.º 14
0
    def train(self, train_loader, valid_loader=None,
              callbacks=None, lr_schedulers=None,
              disable_tensorboard=False,
              epochs=1000, steps_per_epoch=None, validation_steps=None,
              seed=42):
        """Train ``self.model`` with ``fit_generator``, adding logging/checkpoint callbacks.

        When ``self.logging`` is true, training can resume from a previous run.
        The following callbacks are added:

        - a CSV logger that appends when resuming;
        - a best-model checkpoint;
        - last-model and optimizer checkpoints;
        - an epoch-number file;
        - an optional TensorBoard logger;
        - per-scheduler checkpoints.

        Otherwise the LR schedulers are used directly, together with a
        ``BestModelRestore`` callback.

        :param train_loader: training data loader / generator.
        :param valid_loader: optional validation data loader / generator.
        :param callbacks: extra callbacks; the caller's list is copied, never mutated.
            Defaults to no extra callbacks.
        :param lr_schedulers: LR scheduler callbacks; defaults to none.
        :param disable_tensorboard: skip creating the TensorBoard writer/logger.
        :param epochs: last epoch to train up to.
        :param steps_per_epoch: forwarded to ``fit_generator``.
        :param validation_steps: forwarded to ``fit_generator``.
        :param seed: if not None, seeds ``random``, NumPy and PyTorch.
        """
        if seed is not None:
            # Make training deterministic.
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)

        # ``None`` sentinels instead of shared mutable ``[]`` defaults.
        # Copy the callback list so the appends below never touch the caller's list.
        callbacks = [] if callbacks is None else list(callbacks)
        if lr_schedulers is None:
            lr_schedulers = []

        writer = None
        initial_epoch = 1
        if self.logging:
            os.makedirs(self.directory, exist_ok=True)

            # Restarting optimization if needed.
            initial_epoch = self._load_epoch_state(lr_schedulers)

            csv_logger = CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)

            best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                              monitor=self.monitor_metric,
                                              mode=self.monitor_mode,
                                              save_best_only=True,
                                              restore_best=True,
                                              verbose=True,
                                              temporary_filename=self.best_checkpoint_tmp_filename)
            if initial_epoch > 1:
                # We set the current best metric score in the ModelCheckpoint so that
                # it does not save checkpoint it would not have saved if the
                # optimization was not stopped.
                best_epoch_stats = self.get_best_epoch_stats()
                best_epoch = best_epoch_stats['epoch'].item()
                best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
                best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

            checkpoint = ModelCheckpoint(self.model_checkpoint_filename, verbose=False, temporary_filename=self.model_checkpoint_tmp_filename)
            optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename, verbose=False, temporary_filename=self.optimizer_checkpoint_tmp_filename)

            # We save the last epoch number after the end of the epoch so that the
            # load_epoch_state() knows which epoch to restart the optimization.
            save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                                   self.epoch_filename,
                                                   temporary_filename=self.epoch_tmp_filename,
                                                   open_mode='w')

            callbacks += [csv_logger, best_checkpoint, checkpoint, optimizer_checkpoint, save_epoch_number]

            if not disable_tensorboard:
                if SummaryWriter is None:
                    warnings.warn("tensorboardX does not seem to be installed. To remove this warning, set the 'disable_tensorboard' flag to True.")
                else:
                    writer = SummaryWriter(self.tensorboard_directory)
                    callbacks.append(TensorBoardLogger(writer))
            # One checkpoint file per scheduler, indexed by its position in the list.
            for i, lr_scheduler in enumerate(lr_schedulers):
                filename = self.lr_scheduler_filename % i
                tmp_filename = self.lr_scheduler_tmp_filename % i
                lr_scheduler_checkpoint = LRSchedulerCheckpoint(lr_scheduler, filename, verbose=False, temporary_filename=tmp_filename)
                callbacks.append(lr_scheduler_checkpoint)
        else:
            callbacks.extend(lr_schedulers)
            best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode)
            callbacks.append(best_restore)

        try:
            self.model.fit_generator(train_loader, valid_loader,
                                     epochs=epochs,
                                     steps_per_epoch=steps_per_epoch,
                                     validation_steps=validation_steps,
                                     initial_epoch=initial_epoch,
                                     callbacks=callbacks)
        finally:
            # The writer was previously never closed; release it even if training raises.
            if writer is not None:
                writer.close()