コード例 #1
0
ファイル: agent.py プロジェクト: xiaofei-w/spirl
 def load_model_weights(model, checkpoint, epoch='latest'):
     """Load weights for ``model`` from the given checkpoint directory.

     Args:
         model: object forwarded to ``CheckpointHandler.load_weights`` as ``model``.
         checkpoint: path to the ``weights`` directory itself or to the directory
             that contains it; a trailing path separator is tolerated.
         epoch: checkpoint selector forwarded to
             ``CheckpointHandler.get_resume_ckpt_file`` (defaults to ``'latest'``).
     """
     # Normalize before taking the basename: os.path.basename('run/weights/') is ''
     # and would otherwise send us to 'run/weights/weights'.
     if os.path.basename(os.path.normpath(checkpoint)) == 'weights':
         checkpoint_dir = checkpoint
     else:
         checkpoint_dir = os.path.join(checkpoint, 'weights')     # checkpts in 'weights' dir
     checkpoint_path = CheckpointHandler.get_resume_ckpt_file(epoch, checkpoint_dir)
     CheckpointHandler.load_weights(checkpoint_path, model=model)
コード例 #2
0
 def resume(self, ckpt, path=None):
     """Restore the agent from checkpoint ``ckpt`` and return the stored start epoch.

     Args:
         ckpt: checkpoint selector handed to
             ``CheckpointHandler.get_resume_ckpt_file``; must not be ``None``.
         path: directory containing the ``weights`` folder; ``self._hp.exp_path``
             is used when this is ``None``.

     Returns:
         The second element of the tuple returned by
         ``CheckpointHandler.load_weights``.
     """
     root_dir = self._hp.exp_path if path is None else path
     weights_dir = os.path.join(root_dir, 'weights')
     assert ckpt is not None  # need to specify resume epoch for loading checkpoint
     ckpt_file = CheckpointHandler.get_resume_ckpt_file(ckpt, weights_dir)
     # TODO(karl): check whether that actually loads the optimizer too
     loaded = CheckpointHandler.load_weights(
         ckpt_file, self.agent,
         load_step=True, strict=self.args.strict_weight_loading)
     self.global_step, start_epoch, _ = loaded
     # The agent state is always read from the experiment path, even when `path` is given.
     self.agent.load_state(self._hp.exp_path)
     self.agent.to(self.device)
     return start_epoch
コード例 #3
0
ファイル: train.py プロジェクト: clvrai/spirl
 def resume(self, ckpt, path=None):
     """Restore model and optimizer from checkpoint ``ckpt``; return the start epoch.

     Args:
         ckpt: checkpoint selector handed to
             ``CheckpointHandler.get_resume_ckpt_file``; must not be ``None``.
         path: directory containing the ``weights`` folder; ``self._hp.exp_path``
             is used when this is ``None``.

     Returns:
         The second element of the tuple returned by
         ``CheckpointHandler.load_weights``.
     """
     root_dir = self._hp.exp_path if path is None else path
     weights_dir = os.path.join(root_dir, 'weights')
     assert ckpt is not None  # need to specify resume epoch for loading checkpoint
     ckpt_file = CheckpointHandler.get_resume_ckpt_file(ckpt, weights_dir)
     loaded = CheckpointHandler.load_weights(
         ckpt_file, self.model,
         load_step=True, load_opt=True, optimizer=self.optimizer,
         strict=self.args.strict_weight_loading)
     self.global_step, start_epoch, _ = loaded
     # Move the model to the device it reports as its own.
     self.model.to(self.model.device)
     return start_epoch
コード例 #4
0
ファイル: train.py プロジェクト: clvrai/spirl
 def run_val_sweep(self):
     """Resume and validate every second stored checkpoint, in ascending epoch order."""
     weights_dir = os.path.join(self._hp.exp_path, 'weights')
     stored_epochs = sorted(CheckpointHandler.get_epochs(weights_dir))
     # Stride of two: only every other checkpoint is evaluated.
     for epoch in stored_epochs[::2]:
         self.resume(epoch)
         self.val()
コード例 #5
0
    def train(self, start_epoch):
        """Run outer training loop.

        Performs the warmup phase when ``self._hp.n_warmup_steps`` is positive,
        then trains one epoch at a time from ``start_epoch`` up to (excluding)
        ``self._hp.num_epochs``. After each epoch, unless saving is disabled or
        this instance is not the chef, a checkpoint and the agent state are
        written and validation is run.

        Args:
            start_epoch: index of the first epoch to train.
        """
        needs_warmup = self._hp.n_warmup_steps > 0
        if needs_warmup:
            self.warmup()

        for epoch in range(start_epoch, self._hp.num_epochs):
            print("Epoch {}".format(epoch))
            self.train_epoch(epoch)

            # Saving (and the validation tied to it) is skipped when disabled
            # or when this instance is not the chef.
            if self.args.dont_save or not self.is_chef:
                continue

            snapshot = {
                'epoch': epoch,
                'global_step': self.global_step,
                'state_dict': self.agent.state_dict(),
            }
            weights_dir = os.path.join(self._hp.exp_path, 'weights')
            ckpt_name = CheckpointHandler.get_ckpt_name(epoch)
            save_checkpoint(snapshot, weights_dir, ckpt_name)
            self.agent.save_state(self._hp.exp_path)
            self.val()
コード例 #6
0
ファイル: train.py プロジェクト: clvrai/spirl
    def train(self, start_epoch):
        """Run the outer training loop with periodic checkpointing and validation.

        Validates once up front unless ``self.args.skip_first_val`` is set, then
        trains one epoch at a time from ``start_epoch`` up to (excluding)
        ``self._hp.num_epochs``. After each epoch a checkpoint holding the model
        and optimizer state is written unless ``self.args.dont_save`` is set, and
        validation runs on epochs divisible by ``self.args.val_interval``.

        Args:
            start_epoch: index of the first epoch to train.
        """
        run_initial_val = not self.args.skip_first_val
        if run_initial_val:
            self.val()

        for epoch in range(start_epoch, self._hp.num_epochs):
            self.train_epoch(epoch)

            if not self.args.dont_save:
                snapshot = {
                    'epoch': epoch,
                    'global_step': self.global_step,
                    'state_dict': self.model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                }
                weights_dir = os.path.join(self._hp.exp_path, 'weights')
                ckpt_name = CheckpointHandler.get_ckpt_name(epoch)
                save_checkpoint(snapshot, weights_dir, ckpt_name)

            is_val_epoch = epoch % self.args.val_interval == 0
            if is_val_epoch:
                self.val()
コード例 #7
0
def delete_non_latest_checkpoint(dir):
    """Remove every ``*.pth`` file in ``dir`` except the latest checkpoint.

    The file to keep is the one ``CheckpointHandler.get_resume_ckpt_file``
    reports for ``"latest"``.

    Args:
        dir: directory holding the checkpoint files; may be a relative path.
    """
    # The glob results below are absolute. Normalize the path to keep as well,
    # otherwise a relative result for a relative `dir` never compares equal
    # and the latest checkpoint would be deleted along with the rest.
    latest_checkpoint = os.path.abspath(
        CheckpointHandler.get_resume_ckpt_file("latest", dir))
    # glob.escape keeps glob metacharacters in the directory name literal.
    pattern = os.path.join(glob.escape(os.path.abspath(dir)), "*.pth")
    for checkpoint_file in glob.glob(pattern):
        if os.path.abspath(checkpoint_file) != latest_checkpoint:
            os.remove(checkpoint_file)