예제 #1
0
파일: bandit.py 프로젝트: veds12/genrl
 def __init__(
     self,
     agent: Any,
     bandit: Any,
     logdir: str = "./logs",
     log_mode: List[str] = ["stdout"],
 ):
     """Store the agent/bandit pair and create a ``Logger``.

     Args:
         agent: Agent object; stored as ``self.agent``.
         bandit: Bandit object; stored as ``self.bandit``.
         logdir: Directory handed to ``Logger``; stored as ``self.logdir``.
         log_mode: Logger output formats, e.g. ``["stdout"]``. Any iterable
             of format names is accepted; a single string is treated as a
             one-element list.
     """
     self.agent = agent
     self.bandit = bandit
     self.logdir = logdir
     # ``[*"stdout"]`` would split a bare string into single characters, so
     # wrap it instead.
     if isinstance(log_mode, str):
         log_mode = [log_mode]
     # Copy so ``self.log_mode`` never aliases the shared mutable default
     # list (or a caller's list that might be mutated later).
     self.log_mode = list(log_mode)
     # Give the logger its own list, independent of ``self.log_mode``.
     self.logger = Logger(logdir=logdir, formats=[*self.log_mode])
예제 #2
0
파일: base.py 프로젝트: veds12/genrl
    def __init__(
        self,
        agent: Any,
        env: Union[gym.Env, VecEnv],
        log_mode: List[str] = ["stdout"],
        log_key: str = "timestep",
        log_interval: int = 10,
        logdir: str = "logs",
        epochs: int = 50,
        max_timesteps: Optional[int] = None,
        off_policy: bool = False,
        save_interval: int = 0,
        save_model: str = "checkpoints",
        run_num: Optional[int] = None,
        load_weights: Optional[str] = None,
        load_hyperparams: Optional[str] = None,
        render: bool = False,
        evaluate_episodes: int = 25,
        seed: Optional[int] = None,
    ):
        """Store the training configuration, optionally seed, and create a logger.

        Every argument except ``seed`` is stored on the instance under the
        same attribute name; how most of them are interpreted is decided by
        other methods not shown here.

        Args:
            agent: Agent object to be trained.
            env: A ``gym.Env`` or ``VecEnv`` instance.
            log_mode: Logger output formats, e.g. ``["stdout"]``. Any iterable
                of format names is accepted; a single string is treated as a
                one-element list.
            log_key: Stored as ``self.log_key``; presumably the key passed to
                ``Logger.write(..., log_key=...)`` — confirm in callers.
            log_interval: Stored as ``self.log_interval``.
            logdir: Directory handed to ``Logger``.
            epochs: Stored as ``self.epochs``.
            max_timesteps: Stored as ``self.max_timesteps``; may be ``None``.
            off_policy: Stored as ``self.off_policy``.
            save_interval: Stored as ``self.save_interval``.
            save_model: Stored as ``self.save_model``.
            run_num: Stored as ``self.run_num``; may be ``None``.
            load_weights: Stored as ``self.load_weights``; may be ``None``.
            load_hyperparams: Stored as ``self.load_hyperparams``; may be ``None``.
            render: Stored as ``self.render``.
            evaluate_episodes: Stored as ``self.evaluate_episodes``.
            seed: If not ``None``, ``set_seeds(seed, self.env)`` is called
                before the logger is created.
        """
        # ``[*"csv"]`` would split a bare string into single characters, so
        # wrap it; then copy so ``self.log_mode`` never aliases the shared
        # mutable default list or a caller-owned list.
        if isinstance(log_mode, str):
            log_mode = [log_mode]
        log_mode = list(log_mode)

        self.agent = agent
        self.env = env
        self.log_mode = log_mode
        self.log_key = log_key
        self.log_interval = log_interval
        self.logdir = logdir
        self.epochs = epochs
        self.max_timesteps = max_timesteps
        self.off_policy = off_policy
        self.save_interval = save_interval
        self.save_model = save_model
        self.run_num = run_num
        self.load_weights = load_weights
        self.load_hyperparams = load_hyperparams
        self.render = render
        self.evaluate_episodes = evaluate_episodes

        if seed is not None:
            set_seeds(seed, self.env)

        # ``[*log_mode]`` hands the logger its own list, separate from
        # ``self.log_mode``.
        self.logger = Logger(logdir=logdir, formats=[*log_mode])
예제 #3
0
def test_loggers():
    """Smoke test: a Logger with csv, stdout and tensorboard formats can
    write one row keyed on ``timestep`` and then close.

    Logs go to a private temporary directory, so the test never touches
    (or deletes) a ``./logs`` directory in the current working directory.
    Cleanup runs even when ``write`` raises.
    """
    import os
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_root:
        # Use a not-yet-existing subdirectory, mirroring the original
        # "./logs" situation, in case Logger expects to create it itself.
        logdir = os.path.join(tmp_root, "logs")
        logger = Logger(logdir, formats=["csv", "stdout", "tensorboard"])
        try:
            logger.write({"hello": 0, "timestep": 10}, log_key="timestep")
        finally:
            # Close before the temp dir is removed so writers flush and
            # release their file handles.
            logger.close()