def train(model, model_name, train_loader, valid_loader, epochs=1000):
    # Create callbacks and checkpoints
    lrscheduler = ReduceLROnPlateau(patience=3, verbose=True)
    early_stopping = EarlyStopping(patience=10, min_delta=1e-4, verbose=True)

    model_path = './models/'
    os.makedirs(model_path, exist_ok=True)
    ckpt_best = ModelCheckpoint(model_path + 'best_' + model_name + '.torch',
                                save_best_only=True,
                                restore_best=True,
                                temporary_filename=model_path + 'temp_best_' + model_name + '.torch',
                                verbose=True)
    ckpt_last = ModelCheckpoint(model_path + 'last_' + model_name + '.torch',
                                temporary_filename=model_path + 'temp_last_' + model_name + '.torch')

    logger_path = './train_logs/'
    os.makedirs(logger_path, exist_ok=True)
    csv_logger = CSVLogger(logger_path + model_name + '.csv')

    callbacks = [lrscheduler, ckpt_best, ckpt_last, early_stopping, csv_logger]

    # Fit the model
    model.fit_generator(train_loader, valid_loader, epochs=epochs, callbacks=callbacks)
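# Hypothetical usage sketch for the train() helper above (not part of the
# original code): it assumes the callbacks come from Poutyne and that `model`
# is a Poutyne Model wrapping a small PyTorch network. The network, optimizer,
# loss, and DataLoaders below are illustrative placeholders.
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from poutyne.framework import Model  # `from poutyne import Model` in newer releases

network = nn.Linear(10, 1)
optimizer = torch.optim.SGD(network.parameters(), lr=0.01)
model = Model(network, optimizer, nn.MSELoss())

train_loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randn(256, 1)), batch_size=32)
valid_loader = DataLoader(TensorDataset(torch.randn(64, 10), torch.randn(64, 1)), batch_size=32)

train(model, 'linear_baseline', train_loader, valid_loader, epochs=10)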
def test_restore_best_without_save_best_only(self):
    with self.assertRaises(ValueError):
        ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True,
                        save_best_only=False, restore_best=True)

    with self.assertRaises(ValueError):
        ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True,
                        restore_best=True)
def test_non_atomic_write(self):
    checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.ckpt')
    train_gen = some_data_generator(ModelCheckpointTest.batch_size)
    valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
    checkpointer = ModelCheckpoint(checkpoint_filename, monitor='val_loss', verbose=True,
                                   period=1, atomic_write=False)
    self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5,
                             callbacks=[checkpointer])
    self.assertTrue(os.path.isfile(checkpoint_filename))
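# The checkpoint tests in this section rely on a `some_data_generator` fixture
# whose definition is not shown here. A minimal stand-in, assuming the toy
# model under test maps a single input feature to a single target, could be:
import torch

def some_data_generator(batch_size):
    # Yield an endless stream of random (inputs, targets) batches.
    while True:
        x = torch.randn(batch_size, 1)
        y = torch.randn(batch_size, 1)
        yield x, y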
def fit(self, x_train, y_train, x_valid, y_valid, n_epochs=100, batch_size=32,
        log_filename=None, checkpoint_filename=None, with_early_stopping=True):
    """
    :param x_train: training set examples
    :param y_train: training set labels
    :param x_valid: validation set examples
    :param y_valid: validation set labels
    :param n_epochs: int, number of epochs (default: 100)
    :param batch_size: int, batch size (default: 32); must be a multiple of 2
    :param log_filename: optional, path to which the training information is logged
    :param checkpoint_filename: optional, path to which the model is saved
    :param with_early_stopping: whether to activate early stopping
    :return: self, the fitted model
    """
    callbacks = []
    if with_early_stopping:
        early_stopping = EarlyStopping(monitor='val_loss', patience=3, verbose=0)
        callbacks += [early_stopping]

    reduce_lr = ReduceLROnPlateau(monitor='loss', patience=2, factor=1 / 10, min_lr=1e-6)
    best_model_restore = BestModelRestore()
    callbacks += [reduce_lr, best_model_restore]

    if log_filename:
        logger = CSVLogger(log_filename, batch_granularity=False, separator='\t')
        callbacks += [logger]

    if checkpoint_filename:
        checkpointer = ModelCheckpoint(checkpoint_filename, monitor='val_loss', save_best_only=True)
        callbacks += [checkpointer]

    # self.model.fit(x_train, y_train, x_valid, y_valid,
    #                batch_size=batch_size, epochs=n_epochs,
    #                callbacks=callbacks)
    nb_steps_train, nb_steps_valid = int(len(x_train) / batch_size), int(len(x_valid) / batch_size)
    self.model.fit_generator(
        generator(x_train, y_train, batch_size),
        steps_per_epoch=nb_steps_train,
        valid_generator=generator(x_valid, y_valid, batch_size),
        validation_steps=nb_steps_valid,
        epochs=n_epochs,
        callbacks=callbacks,
    )
    return self
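# The fit() method above delegates batching to a `generator` helper that is
# not defined in this excerpt. A minimal sketch, assuming x and y are NumPy
# arrays (or tensors) of equal length and that fit_generator() consumes an
# endless stream of mini-batches when steps_per_epoch is given:
import numpy as np
import torch

def generator(x, y, batch_size):
    n = len(x)
    while True:
        # Reshuffle every pass and yield full mini-batches only.
        indices = np.random.permutation(n)
        for start in range(0, n - batch_size + 1, batch_size):
            batch = indices[start:start + batch_size]
            yield torch.as_tensor(x[batch]), torch.as_tensor(y[batch])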
def test_save_best_only_with_restore_best(self):
    checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True,
                                   save_best_only=True, restore_best=True)
    val_losses = [10, 3, 8, 5, 2]
    has_checkpoints = [True, True, False, False, True]
    self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)

    self._test_restore_best(val_losses)
def test_periodic_with_period_of_2(self):
    checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True,
                                   period=2, save_best_only=False)
    val_losses = [1] * 10
    has_checkpoints = [False, True] * 5
    self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)
def test_save_best_only_with_max(self):
    checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', mode='max',
                                   verbose=True, save_best_only=True)
    val_losses = [2, 3, 8, 5, 2]
    has_checkpoints = [True, True, True, False, False]
    self._test_checkpointer_with_val_losses(checkpointer, val_losses, has_checkpoints)
def test_temporary_filename_arg_with_differing_checkpoint_filename(self):
    epochs = 10
    tmp_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.tmp.ckpt')
    checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint_{epoch}.ckpt')
    train_gen = some_data_generator(ModelCheckpointTest.batch_size)
    valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
    checkpointer = ModelCheckpoint(checkpoint_filename, monitor='val_loss', verbose=True,
                                   period=1, temporary_filename=tmp_filename)
    self.model.fit_generator(train_gen, valid_gen, epochs=epochs, steps_per_epoch=5,
                             callbacks=[checkpointer])
    self.assertFalse(os.path.isfile(tmp_filename))
    for i in range(1, epochs + 1):
        self.assertTrue(os.path.isfile(checkpoint_filename.format(epoch=i)))
def test_integration(self):
    train_gen = some_data_generator(ModelCheckpointTest.batch_size)
    valid_gen = some_data_generator(ModelCheckpointTest.batch_size)
    checkpointer = ModelCheckpoint(self.checkpoint_filename, monitor='val_loss', verbose=True,
                                   save_best_only=True)
    self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5,
                             callbacks=[checkpointer])
def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
    callbacks = []
    best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                      monitor=self.monitor_metric,
                                      mode=self.monitor_mode,
                                      save_best_only=not save_every_epoch,
                                      restore_best=not save_every_epoch,
                                      verbose=not save_every_epoch,
                                      temporary_filename=self.best_checkpoint_tmp_filename)
    callbacks.append(best_checkpoint)

    if save_every_epoch:
        best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
        callbacks.append(best_restore)

    if initial_epoch > 1:
        # We set the current best metric score in the ModelCheckpoint so that
        # it does not save checkpoints it would not have saved if the
        # optimization had not been stopped.
        best_epoch_stats = self.get_best_epoch_stats()
        best_epoch = best_epoch_stats['epoch'].item()
        best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
        if not save_every_epoch:
            best_checkpoint.best_filename = best_filename
            best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()
        else:
            best_restore.best_weights = torch.load(best_filename, map_location='cpu')
            best_restore.current_best = best_epoch_stats[self.monitor_metric].item()

    return callbacks
def fit(self, meta_train, meta_valid, n_epochs=100, steps_per_epoch=100,
        log_filename=None, checkpoint_filename=None, tboard_folder=None):
    if hasattr(self.model, 'is_eval'):
        self.model.is_eval = False
        self.is_eval = False

    self.steps_per_epoch = steps_per_epoch
    callbacks = [
        EarlyStopping(patience=10, verbose=False),
        ReduceLROnPlateau(patience=2, factor=1 / 2, min_lr=1e-6, verbose=True),
        BestModelRestore()
    ]
    if log_filename:
        callbacks += [CSVLogger(log_filename, batch_granularity=False, separator='\t')]
    if checkpoint_filename:
        callbacks += [ModelCheckpoint(checkpoint_filename, monitor='val_loss', save_best_only=True,
                                      temporary_filename=checkpoint_filename + 'temp')]
    if tboard_folder is not None:
        self.writer = SummaryWriter(tboard_folder)

    self.fit_generator(meta_train, meta_valid,
                       epochs=n_epochs,
                       steps_per_epoch=steps_per_epoch,
                       validation_steps=steps_per_epoch,
                       callbacks=callbacks,
                       verbose=True)
    self.is_fitted = True
    return self
def test_temporary_filename_arg(self):
    tmp_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.tmp.ckpt')
    checkpoint_filename = os.path.join(self.temp_dir_obj.name, 'my_checkpoint.ckpt')
    train_gen = some_data_generator(20)
    valid_gen = some_data_generator(20)
    checkpointer = ModelCheckpoint(checkpoint_filename, monitor='val_loss', verbose=True,
                                   period=1, temporary_filename=tmp_filename)
    self.model.fit_generator(train_gen, valid_gen, epochs=10, steps_per_epoch=5,
                             callbacks=[checkpointer])
    self.assertFalse(os.path.isfile(tmp_filename))
    self.assertTrue(os.path.isfile(checkpoint_filename))
def train(self, train_loader, valid_loader=None, *,
          callbacks=[], lr_schedulers=[], save_every_epoch=False, disable_tensorboard=False,
          epochs=1000, steps_per_epoch=None, validation_steps=None, seed=42):
    if seed is not None:
        # Make training deterministic.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Copy the callback list so we do not mutate the caller's list.
    callbacks = list(callbacks)

    tensorboard_writer = None
    initial_epoch = 1
    if self.logging:
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)

        # Restart the optimization if needed.
        initial_epoch = self._load_epoch_state(lr_schedulers)

        callbacks += [CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)]

        callbacks += self._init_model_restoring_callbacks(initial_epoch, save_every_epoch)
        callbacks += [ModelCheckpoint(self.model_checkpoint_filename,
                                      verbose=False,
                                      temporary_filename=self.model_checkpoint_tmp_filename)]
        callbacks += [OptimizerCheckpoint(self.optimizer_checkpoint_filename,
                                          verbose=False,
                                          temporary_filename=self.optimizer_checkpoint_tmp_filename)]

        # We save the last epoch number at the end of each epoch so that
        # _load_epoch_state() knows from which epoch to restart the optimization.
        callbacks += [PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                         self.epoch_filename,
                                         temporary_filename=self.epoch_tmp_filename,
                                         open_mode='w')]

        tensorboard_writer, cb_list = self._init_tensorboard_callbacks(disable_tensorboard)
        callbacks += cb_list

    # This method returns callbacks that checkpoint the LR schedulers if logging is enabled.
    # Otherwise, it just returns the list of LR schedulers with a BestModelRestore callback.
    callbacks += self._init_lr_scheduler_callbacks(lr_schedulers)

    try:
        self.model.fit_generator(train_loader, valid_loader,
                                 epochs=epochs,
                                 steps_per_epoch=steps_per_epoch,
                                 validation_steps=validation_steps,
                                 initial_epoch=initial_epoch,
                                 callbacks=callbacks)
    finally:
        if tensorboard_writer is not None:
            tensorboard_writer.close()
def train(self, train_loader, valid_loader=None, callbacks=[], lr_schedulers=[],
          disable_tensorboard=False, epochs=1000, steps_per_epoch=None, validation_steps=None,
          seed=42):
    if seed is not None:
        # Make training deterministic.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    # Copy the callback list so we do not mutate the caller's list.
    callbacks = list(callbacks)

    initial_epoch = 1
    if self.logging:
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)

        # Restart the optimization if needed.
        initial_epoch = self._load_epoch_state(lr_schedulers)

        csv_logger = CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)

        best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                          monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          save_best_only=True,
                                          restore_best=True,
                                          verbose=True,
                                          temporary_filename=self.best_checkpoint_tmp_filename)
        if initial_epoch > 1:
            # We set the current best metric score in the ModelCheckpoint so that
            # it does not save checkpoints it would not have saved if the
            # optimization had not been stopped.
            best_epoch_stats = self.get_best_epoch_stats()
            best_epoch = best_epoch_stats['epoch'].item()
            best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
            best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

        checkpoint = ModelCheckpoint(self.model_checkpoint_filename,
                                     verbose=False,
                                     temporary_filename=self.model_checkpoint_tmp_filename)
        optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename,
                                                   verbose=False,
                                                   temporary_filename=self.optimizer_checkpoint_tmp_filename)

        # We save the last epoch number at the end of each epoch so that
        # _load_epoch_state() knows from which epoch to restart the optimization.
        save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                               self.epoch_filename,
                                               temporary_filename=self.epoch_tmp_filename,
                                               open_mode='w')

        callbacks += [csv_logger, best_checkpoint, checkpoint, optimizer_checkpoint, save_epoch_number]

        if not disable_tensorboard:
            if SummaryWriter is None:
                warnings.warn("tensorboardX does not seem to be installed. "
                              "To remove this warning, set the 'disable_tensorboard' flag to True.")
            else:
                writer = SummaryWriter(self.tensorboard_directory)
                tensorboard_logger = TensorBoardLogger(writer)
                callbacks.append(tensorboard_logger)

        for i, lr_scheduler in enumerate(lr_schedulers):
            filename = self.lr_scheduler_filename % i
            tmp_filename = self.lr_scheduler_tmp_filename % i
            lr_scheduler_checkpoint = LRSchedulerCheckpoint(lr_scheduler, filename,
                                                            verbose=False,
                                                            temporary_filename=tmp_filename)
            callbacks.append(lr_scheduler_checkpoint)
    else:
        for lr_scheduler in lr_schedulers:
            callbacks.append(lr_scheduler)

        best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode)
        callbacks.append(best_restore)

    self.model.fit_generator(train_loader, valid_loader,
                             epochs=epochs,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             initial_epoch=initial_epoch,
                             callbacks=callbacks)