def load_model_weights(model, checkpoint, epoch='latest'):
    """Restore a model's parameters from a checkpoint directory.

    Args:
        model: network whose weights get overwritten in place.
        checkpoint: experiment directory, or directly its 'weights' subdirectory.
        epoch: which checkpoint to load ('latest' or an epoch identifier).
    """
    # Accept either the experiment root or the 'weights' folder itself.
    if os.path.basename(checkpoint) == 'weights':
        weights_dir = checkpoint
    else:
        weights_dir = os.path.join(checkpoint, 'weights')
    ckpt_file = CheckpointHandler.get_resume_ckpt_file(epoch, weights_dir)
    CheckpointHandler.load_weights(ckpt_file, model=model)
def resume(self, ckpt, path=None):
    """Resume agent training from a saved checkpoint.

    Args:
        ckpt: epoch identifier of the checkpoint to restore; must not be None.
        path: optional experiment directory; defaults to self._hp.exp_path.

    Returns:
        Epoch index at which training should continue.
    """
    assert ckpt is not None  # need to specify resume epoch for loading checkpoint
    base_dir = self._hp.exp_path if path is None else path
    weights_file = CheckpointHandler.get_resume_ckpt_file(
        ckpt, os.path.join(base_dir, 'weights'))
    # TODO(karl): check whether that actually loads the optimizer too
    self.global_step, start_epoch, _ = CheckpointHandler.load_weights(
        weights_file, self.agent, load_step=True,
        strict=self.args.strict_weight_loading)
    self.agent.load_state(self._hp.exp_path)
    self.agent.to(self.device)
    return start_epoch
def resume(self, ckpt, path=None):
    """Resume model training, restoring weights and optimizer state.

    Args:
        ckpt: epoch identifier of the checkpoint to restore; must not be None.
        path: optional experiment directory; defaults to self._hp.exp_path.

    Returns:
        Epoch index at which training should continue.
    """
    assert ckpt is not None  # need to specify resume epoch for loading checkpoint
    base_dir = self._hp.exp_path if path is None else path
    weights_file = CheckpointHandler.get_resume_ckpt_file(
        ckpt, os.path.join(base_dir, 'weights'))
    self.global_step, start_epoch, _ = CheckpointHandler.load_weights(
        weights_file, self.model, load_step=True,
        load_opt=True, optimizer=self.optimizer,
        strict=self.args.strict_weight_loading)
    self.model.to(self.model.device)
    return start_epoch
def run_val_sweep(self):
    """Validate every other stored checkpoint, in ascending epoch order."""
    weights_dir = os.path.join(self._hp.exp_path, 'weights')
    # sorted() already yields a list, so slicing with step 2 works directly.
    for epoch in sorted(CheckpointHandler.get_epochs(weights_dir))[::2]:
        self.resume(epoch)
        self.val()
def train(self, start_epoch):
    """Run outer training loop: optional warmup, then per-epoch train/save/val."""
    if self._hp.n_warmup_steps > 0:
        self.warmup()

    for epoch in range(start_epoch, self._hp.num_epochs):
        print("Epoch {}".format(epoch))
        self.train_epoch(epoch)

        # Only the chief worker persists checkpoints and runs validation.
        if not self.args.dont_save and self.is_chef:
            checkpoint_state = {
                'epoch': epoch,
                'global_step': self.global_step,
                'state_dict': self.agent.state_dict(),
            }
            save_checkpoint(checkpoint_state,
                            os.path.join(self._hp.exp_path, 'weights'),
                            CheckpointHandler.get_ckpt_name(epoch))
            self.agent.save_state(self._hp.exp_path)
            self.val()
def train(self, start_epoch):
    """Outer training loop: initial validation, then train/checkpoint/validate."""
    if not self.args.skip_first_val:
        self.val()

    for epoch in range(start_epoch, self._hp.num_epochs):
        self.train_epoch(epoch)

        if not self.args.dont_save:
            checkpoint_state = {
                'epoch': epoch,
                'global_step': self.global_step,
                'state_dict': self.model.state_dict(),
                'optimizer': self.optimizer.state_dict(),
            }
            save_checkpoint(checkpoint_state,
                            os.path.join(self._hp.exp_path, 'weights'),
                            CheckpointHandler.get_ckpt_name(epoch))

        # Periodic validation in addition to the optional initial one.
        if epoch % self.args.val_interval == 0:
            self.val()
def delete_non_latest_checkpoint(dir):
    """Delete every '*.pth' checkpoint in *dir* except the most recent one.

    Args:
        dir: directory containing checkpoint files.
    """
    # NOTE: parameter name 'dir' shadows the builtin, but is kept for
    # backward compatibility with keyword callers.
    # Normalize both sides of the comparison: get_resume_ckpt_file may return
    # a path in a different form (e.g. relative) than the absolute paths
    # produced by glob below; comparing raw strings could then fail to match
    # and accidentally delete the latest checkpoint as well.
    latest_checkpoint = os.path.abspath(
        CheckpointHandler.get_resume_ckpt_file("latest", dir))
    for ckpt_file in glob.glob(os.path.join(os.path.abspath(dir), "*.pth")):
        if os.path.abspath(ckpt_file) != latest_checkpoint:
            os.remove(ckpt_file)