def test_reduce_lr_on_plateau_integration(self):
    """Smoke test: a ReduceLROnPlateau scheduler wrapped in an
    LRSchedulerCheckpoint survives a full fit_generator() run.
    """
    # NOTE(review): these constants come from OptimizerCheckpointTest even
    # though this looks like an LR-scheduler test — possibly copy-paste;
    # confirm against the enclosing class.
    batch_size = OptimizerCheckpointTest.batch_size
    training_data = some_data_generator(batch_size)
    validation_data = some_data_generator(batch_size)

    plateau_scheduler = ReduceLROnPlateau(monitor='loss', patience=3)
    scheduler_checkpoint = LRSchedulerCheckpoint(plateau_scheduler,
                                                 self.checkpoint_filename,
                                                 period=1)

    self.model.fit_generator(training_data,
                             validation_data,
                             epochs=OptimizerCheckpointTest.epochs,
                             steps_per_epoch=5,
                             callbacks=[scheduler_checkpoint])
def test_any_scheduler_integration(self):
    """Smoke test: an ordinary epoch-based scheduler (ExponentialLR) wrapped
    in an LRSchedulerCheckpoint survives a full fit_generator() run.
    """
    # NOTE(review): constants are taken from OptimizerCheckpointTest — verify
    # this is intentional for this test class.
    batch_size = OptimizerCheckpointTest.batch_size
    training_data = some_data_generator(batch_size)
    validation_data = some_data_generator(batch_size)

    exp_scheduler = ExponentialLR(gamma=0.01)
    scheduler_checkpoint = LRSchedulerCheckpoint(exp_scheduler,
                                                 self.checkpoint_filename,
                                                 period=1)

    self.model.fit_generator(training_data,
                             validation_data,
                             epochs=OptimizerCheckpointTest.epochs,
                             steps_per_epoch=5,
                             callbacks=[scheduler_checkpoint])
def _init_lr_scheduler_callbacks(self, lr_schedulers):
    """Build the callback list for the given LR schedulers.

    When logging is enabled, each scheduler is wrapped in an
    LRSchedulerCheckpoint writing to a per-scheduler filename; otherwise the
    schedulers are used as callbacks directly. A BestModelRestore callback is
    always appended last.
    """
    if self.logging:
        callbacks = [
            LRSchedulerCheckpoint(scheduler,
                                  self.lr_scheduler_filename % idx,
                                  verbose=False,
                                  temporary_filename=self.lr_scheduler_tmp_filename % idx)
            for idx, scheduler in enumerate(lr_schedulers)
        ]
    else:
        callbacks = list(lr_schedulers)

    callbacks.append(BestModelRestore(monitor=self.monitor_metric,
                                      mode=self.monitor_mode,
                                      verbose=True))
    return callbacks
def test_reduce_lr_checkpoints(self):
    """Checkpointing works for a ReduceLROnPlateau scheduler."""
    plateau_scheduler = ReduceLROnPlateau(monitor='loss', patience=3)
    scheduler_checkpoint = LRSchedulerCheckpoint(plateau_scheduler,
                                                 self.checkpoint_filename,
                                                 period=1)
    self._test_checkpointer(scheduler_checkpoint, plateau_scheduler)
def test_any_scheduler_checkpoints(self):
    """Checkpointing works for an ordinary scheduler (ExponentialLR)."""
    exp_scheduler = ExponentialLR(gamma=0.01)
    scheduler_checkpoint = LRSchedulerCheckpoint(exp_scheduler,
                                                 self.checkpoint_filename,
                                                 period=1)
    self._test_checkpointer(scheduler_checkpoint, exp_scheduler)
def train(self, train_loader, valid_loader=None, callbacks=None, lr_schedulers=None,
          disable_tensorboard=False, epochs=1000, steps_per_epoch=None,
          validation_steps=None, seed=42):
    """Train the model, optionally resuming a previously interrupted run.

    Args:
        train_loader: Generator/loader of training batches.
        valid_loader: Optional generator/loader of validation batches.
        callbacks: Optional list of extra callbacks (not mutated).
        lr_schedulers: Optional list of LR schedulers; when logging is
            enabled, each is wrapped in an LRSchedulerCheckpoint.
        disable_tensorboard: If True, skip TensorBoard logging entirely.
        epochs: Maximum number of epochs.
        steps_per_epoch: Optional number of batches per training epoch.
        validation_steps: Optional number of batches per validation pass.
        seed: If not None, seeds random, numpy and torch for determinism.

    When ``self.logging`` is enabled, this sets up CSV logging, best-model and
    periodic checkpoints, optimizer and scheduler checkpoints, and an
    epoch-number file so that a stopped optimization can be restarted from
    where it left off.
    """
    # Fix: the original used mutable default arguments (callbacks=[],
    # lr_schedulers=[]), which are shared across calls. Use None sentinels
    # instead; this is backward compatible for all callers.
    callbacks = [] if callbacks is None else list(callbacks)  # copy: we append below
    lr_schedulers = [] if lr_schedulers is None else lr_schedulers

    if seed is not None:
        # Make training deterministic.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    initial_epoch = 1
    if self.logging:
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)

        # Restarting optimization if needed.
        initial_epoch = self._load_epoch_state(lr_schedulers)

        # Append to the existing log when resuming instead of overwriting it.
        csv_logger = CSVLogger(self.log_filename,
                               separator='\t',
                               append=initial_epoch != 1)

        best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                          monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          save_best_only=True,
                                          restore_best=True,
                                          verbose=True,
                                          temporary_filename=self.best_checkpoint_tmp_filename)
        if initial_epoch > 1:
            # We set the current best metric score in the ModelCheckpoint so
            # that it does not save checkpoints it would not have saved if the
            # optimization had not been stopped.
            best_epoch_stats = self.get_best_epoch_stats()
            best_epoch = best_epoch_stats['epoch'].item()
            best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
            best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

        checkpoint = ModelCheckpoint(self.model_checkpoint_filename,
                                     verbose=False,
                                     temporary_filename=self.model_checkpoint_tmp_filename)
        optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename,
                                                   verbose=False,
                                                   temporary_filename=self.optimizer_checkpoint_tmp_filename)

        # We save the last epoch number after the end of the epoch so that
        # load_epoch_state() knows which epoch to restart the optimization at.
        save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                               self.epoch_filename,
                                               temporary_filename=self.epoch_tmp_filename,
                                               open_mode='w')

        callbacks += [csv_logger, best_checkpoint, checkpoint,
                      optimizer_checkpoint, save_epoch_number]

        if not disable_tensorboard:
            if SummaryWriter is None:
                warnings.warn("tensorboardX does not seem to be installed. To remove this warning, set the 'disable_tensorboard' flag to True.")
            else:
                writer = SummaryWriter(self.tensorboard_directory)
                tensorboard_logger = TensorBoardLogger(writer)
                callbacks.append(tensorboard_logger)

        for i, lr_scheduler in enumerate(lr_schedulers):
            # One checkpoint file per scheduler so each can be restored.
            filename = self.lr_scheduler_filename % i
            tmp_filename = self.lr_scheduler_tmp_filename % i
            lr_scheduler_checkpoint = LRSchedulerCheckpoint(lr_scheduler,
                                                            filename,
                                                            verbose=False,
                                                            temporary_filename=tmp_filename)
            callbacks.append(lr_scheduler_checkpoint)
    else:
        # Without logging, the schedulers act as plain callbacks.
        for lr_scheduler in lr_schedulers:
            callbacks.append(lr_scheduler)

    # Restore the best-scoring weights at the end of training.
    best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode)
    callbacks.append(best_restore)

    self.model.fit_generator(train_loader,
                             valid_loader,
                             epochs=epochs,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             initial_epoch=initial_epoch,
                             callbacks=callbacks)