コード例 #1
0
ファイル: train_reader.py プロジェクト: silencio94/DPR
    def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str:
        """Write a training checkpoint to disk and return its path.

        The file is created inside ``self.args.output_dir`` and is named
        ``<checkpoint_file_name>.<epoch>``; ``.<offset>`` is appended to the
        name only when ``offset`` is greater than zero.

        :param scheduler: object whose ``state_dict()`` is stored in the checkpoint
        :param epoch: epoch number; part of the file name and stored in the checkpoint
        :param offset: stored in the checkpoint; also a file-name suffix when positive
        :return: path of the checkpoint file that was written
        """
        args = self.args
        reader_model = get_model_obj(self.reader)

        # A non-positive offset leaves the file name without an offset suffix.
        if offset > 0:
            offset_suffix = '.' + str(offset)
        else:
            offset_suffix = ''
        file_name = args.checkpoint_file_name + '.' + str(epoch) + offset_suffix
        checkpoint_path = os.path.join(args.output_dir, file_name)

        encoder_params = get_encoder_params_state(args)

        # Collect the three state dicts in the order CheckpointState takes them.
        model_state = reader_model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        scheduler_state = scheduler.state_dict()
        checkpoint = CheckpointState(model_state, optimizer_state, scheduler_state, offset, epoch, encoder_params)

        # The checkpoint is persisted as a plain dict obtained from _asdict().
        torch.save(checkpoint._asdict(), checkpoint_path)
        return checkpoint_path
コード例 #2
0
 def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str:
     """Write a training checkpoint to disk, log its location and return the path.

     The file is created inside ``self.cfg.output_dir`` and is named
     ``<checkpoint_file_name>.<epoch>``. ``offset`` is stored in the
     checkpoint but is not part of the file name.

     :param scheduler: object whose ``state_dict()`` is stored in the checkpoint
     :param epoch: epoch number; part of the file name and stored in the checkpoint
     :param offset: value stored in the checkpoint alongside the epoch
     :return: path of the checkpoint file that was written
     """
     cfg = self.cfg
     biencoder_model = get_model_obj(self.biencoder)

     checkpoint_name = cfg.checkpoint_file_name + "." + str(epoch)
     checkpoint_path = os.path.join(cfg.output_dir, checkpoint_name)

     encoder_params = get_encoder_params_state_from_cfg(cfg)

     # Collect the three state dicts in the order CheckpointState takes them.
     model_state = biencoder_model.get_state_dict()
     optimizer_state = self.optimizer.state_dict()
     scheduler_state = scheduler.state_dict()
     checkpoint = CheckpointState(model_state, optimizer_state, scheduler_state, offset, epoch, encoder_params)

     # The checkpoint is persisted as a plain dict obtained from _asdict().
     torch.save(checkpoint._asdict(), checkpoint_path)
     logger.info("Saved checkpoint at %s", checkpoint_path)
     return checkpoint_path
コード例 #3
0
    def _save_checkpoint(self, scheduler, epoch: int, offset: int) -> str:
        """Write a training checkpoint to disk, log its location and return the path.

        The file is created inside ``self.args.output_dir`` and is named
        ``<checkpoint_file_name>.<epoch>``; ``.<offset>`` is appended to the
        name only when ``offset`` is greater than zero.

        :param scheduler: object whose ``state_dict()`` is stored in the checkpoint
        :param epoch: epoch number; part of the file name and stored in the checkpoint
        :param offset: stored in the checkpoint; also a file-name suffix when positive
        :return: path of the checkpoint file that was written
        """
        args = self.args
        biencoder_model = get_model_obj(self.biencoder)

        # Start from "<name>.<epoch>" and extend it only for a positive offset.
        file_name = args.checkpoint_file_name + "." + str(epoch)
        if offset > 0:
            file_name = file_name + ("." + str(offset))
        checkpoint_path = os.path.join(args.output_dir, file_name)

        encoder_params = get_encoder_params_state(args)

        # Collect the three state dicts in the order CheckpointState takes them.
        model_state = biencoder_model.state_dict()
        optimizer_state = self.optimizer.state_dict()
        scheduler_state = scheduler.state_dict()
        checkpoint = CheckpointState(
            model_state, optimizer_state, scheduler_state, offset, epoch, encoder_params
        )

        # The checkpoint is persisted as a plain dict obtained from _asdict().
        torch.save(checkpoint._asdict(), checkpoint_path)
        logger.info("Saved checkpoint at %s", checkpoint_path)
        return checkpoint_path