def _init_model_restoring_callbacks(self, initial_epoch, save_every_epoch):
    """Build the callbacks that checkpoint the model and restore the best weights.

    When ``save_every_epoch`` is False, a single ``ModelCheckpoint`` both keeps
    only the best model and restores it at the end of training. When True, the
    checkpoint saves every epoch and a separate ``BestModelRestore`` callback
    takes over the restoring duty.

    Args:
        initial_epoch: Epoch the training (re)starts from; a value greater
            than 1 means we are resuming an interrupted optimization.
        save_every_epoch: Whether a checkpoint is written at every epoch
            instead of only on improvement of the monitored metric.

    Returns:
        The list of configured callbacks.
    """
    keep_only_best = not save_every_epoch

    best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                      monitor=self.monitor_metric,
                                      mode=self.monitor_mode,
                                      save_best_only=keep_only_best,
                                      restore_best=keep_only_best,
                                      verbose=keep_only_best,
                                      temporary_filename=self.best_checkpoint_tmp_filename)
    callbacks = [best_checkpoint]

    if save_every_epoch:
        best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode, verbose=True)
        callbacks.append(best_restore)

    if initial_epoch > 1:
        # Seed the callbacks with the best score seen so far so that a resumed
        # run does not save checkpoints it would not have saved had the
        # optimization never been stopped.
        best_epoch_stats = self.get_best_epoch_stats()
        best_epoch = best_epoch_stats['epoch'].item()
        best_monitor_value = best_epoch_stats[self.monitor_metric].item()
        best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)

        if save_every_epoch:
            best_restore.best_weights = torch.load(best_filename, map_location='cpu')
            best_restore.current_best = best_monitor_value
        else:
            best_checkpoint.best_filename = best_filename
            best_checkpoint.current_best = best_monitor_value

    return callbacks
def train(self, train_loader, valid_loader=None, callbacks=None, lr_schedulers=None,
          disable_tensorboard=False, epochs=1000, steps_per_epoch=None,
          validation_steps=None, seed=42):
    """Run the optimization, with logging/checkpointing when ``self.logging`` is set.

    Args:
        train_loader: Iterable of training batches passed to ``fit_generator``.
        valid_loader: Optional iterable of validation batches.
        callbacks: Optional list of extra callbacks; the list is copied, the
            caller's list is never mutated.
        lr_schedulers: Optional list of learning-rate schedulers.
        disable_tensorboard: If True, skip the TensorBoard logger entirely.
        epochs: Maximum number of epochs to train for.
        steps_per_epoch: Optional number of batches per training epoch.
        validation_steps: Optional number of batches per validation pass.
        seed: Seed for ``random``, NumPy and PyTorch RNGs; ``None`` disables
            seeding.
    """
    # Fix: the defaults used to be mutable ([]), a classic Python trap where
    # the same list object is shared across calls. Use None sentinels instead.
    callbacks = [] if callbacks is None else list(callbacks)  # copy: we append below
    lr_schedulers = [] if lr_schedulers is None else lr_schedulers

    if seed is not None:
        # Make training deterministic.
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    initial_epoch = 1
    if self.logging:
        if not os.path.exists(self.directory):
            os.makedirs(self.directory)

        # Restarting optimization if needed.
        initial_epoch = self._load_epoch_state(lr_schedulers)

        # Append (not overwrite) the CSV log when resuming a previous run.
        csv_logger = CSVLogger(self.log_filename, separator='\t', append=initial_epoch != 1)

        best_checkpoint = ModelCheckpoint(self.best_checkpoint_filename,
                                          monitor=self.monitor_metric,
                                          mode=self.monitor_mode,
                                          save_best_only=True,
                                          restore_best=True,
                                          verbose=True,
                                          temporary_filename=self.best_checkpoint_tmp_filename)
        if initial_epoch > 1:
            # We set the current best metric score in the ModelCheckpoint so that
            # it does not save checkpoints it would not have saved if the
            # optimization had not been stopped.
            best_epoch_stats = self.get_best_epoch_stats()
            best_epoch = best_epoch_stats['epoch'].item()
            best_checkpoint.best_filename = self.best_checkpoint_filename.format(epoch=best_epoch)
            best_checkpoint.current_best = best_epoch_stats[self.monitor_metric].item()

        checkpoint = ModelCheckpoint(self.model_checkpoint_filename, verbose=False,
                                     temporary_filename=self.model_checkpoint_tmp_filename)
        optimizer_checkpoint = OptimizerCheckpoint(self.optimizer_checkpoint_filename, verbose=False,
                                                   temporary_filename=self.optimizer_checkpoint_tmp_filename)

        # We save the last epoch number after the end of the epoch so that
        # _load_epoch_state() knows which epoch to restart the optimization from.
        save_epoch_number = PeriodicSaveLambda(lambda fd, epoch, logs: print(epoch, file=fd),
                                               self.epoch_filename,
                                               temporary_filename=self.epoch_tmp_filename,
                                               open_mode='w')

        callbacks += [csv_logger, best_checkpoint, checkpoint, optimizer_checkpoint, save_epoch_number]

        if not disable_tensorboard:
            if SummaryWriter is None:
                # tensorboardX is an optional dependency; warn instead of failing.
                warnings.warn("tensorboardX does not seem to be installed. To remove this warning, set the 'disable_tensorboard' flag to True.")
            else:
                writer = SummaryWriter(self.tensorboard_directory)
                tensorboard_logger = TensorBoardLogger(writer)
                callbacks.append(tensorboard_logger)

        # When logging, wrap each scheduler in a checkpoint so its state
        # survives a restart.
        for i, lr_scheduler in enumerate(lr_schedulers):
            filename = self.lr_scheduler_filename % i
            tmp_filename = self.lr_scheduler_tmp_filename % i
            lr_scheduler_checkpoint = LRSchedulerCheckpoint(lr_scheduler, filename, verbose=False,
                                                            temporary_filename=tmp_filename)
            callbacks.append(lr_scheduler_checkpoint)
    else:
        # No logging: use the schedulers as-is and restore the best weights
        # in memory at the end of training.
        for lr_scheduler in lr_schedulers:
            callbacks.append(lr_scheduler)

        best_restore = BestModelRestore(monitor=self.monitor_metric, mode=self.monitor_mode)
        callbacks.append(best_restore)

    self.model.fit_generator(train_loader, valid_loader,
                             epochs=epochs,
                             steps_per_epoch=steps_per_epoch,
                             validation_steps=validation_steps,
                             initial_epoch=initial_epoch,
                             callbacks=callbacks)