Пример #1
0
def run(args):
    """Train a SAC_AE agent on the environment produced by ``make_dmc_env``.

    Two environment instances are created with identical settings: one is
    handed to the trainer as ``env`` and the other as ``env_test``. The agent
    is built from the first environment's observation and action spaces.

    Logs are written under
    ``logs/<domain_name>-<task_name>/<algo>-seed<seed>-<YYYYmmdd-HHMM>``.

    Args:
        args: Namespace-like object providing ``domain_name``, ``task_name``,
            ``action_repeat``, ``num_agent_steps``, ``eval_interval`` and
            ``seed``.
    """
    env_spec = (args.domain_name, args.task_name, args.action_repeat)
    env = make_dmc_env(*env_spec)
    env_test = make_dmc_env(*env_spec)

    agent = SAC_AE(
        num_agent_steps=args.num_agent_steps,
        state_space=env.observation_space,
        action_space=env.action_space,
        seed=args.seed,
    )

    # The timestamp is taken after the agent exists, right before the
    # run directory name is assembled.
    timestamp = datetime.now().strftime("%Y%m%d-%H%M")
    experiment = f"{args.domain_name}-{args.task_name}"
    run_name = f"{str(agent)}-seed{args.seed}-{timestamp}"
    run_dir = os.path.join("logs", experiment, run_name)

    trainer_kwargs = dict(
        env=env,
        env_test=env_test,
        algo=agent,
        log_dir=run_dir,
        num_agent_steps=args.num_agent_steps,
        action_repeat=args.action_repeat,
        eval_interval=args.eval_interval,
        seed=args.seed,
    )
    Trainer(**trainer_kwargs).train()
Пример #2
0
def run(args):
    """Train a DDPG agent on the environment named by ``args.env_id``.

    ``make_continuous_env`` is called twice with the same id: the first
    instance is passed to the trainer as ``env``, the second as ``env_test``.
    The agent is built from the first environment's observation and action
    spaces.

    Logs are written under
    ``logs/<env_id>/<algo>-seed<seed>-<YYYYmmdd-HHMM>``.

    Args:
        args: Namespace-like object providing ``env_id``,
            ``num_agent_steps``, ``eval_interval`` and ``seed``.
    """
    env_id = args.env_id
    env = make_continuous_env(env_id)
    env_test = make_continuous_env(env_id)

    total_steps = args.num_agent_steps
    seed = args.seed
    agent = DDPG(
        num_agent_steps=total_steps,
        state_space=env.observation_space,
        action_space=env.action_space,
        seed=seed,
    )

    # One directory per (algorithm, seed, start minute) under the env id.
    started_at = datetime.now().strftime("%Y%m%d-%H%M")
    run_name = f"{str(agent)}-seed{seed}-{started_at}"
    run_dir = os.path.join("logs", env_id, run_name)

    Trainer(
        env=env,
        env_test=env_test,
        algo=agent,
        log_dir=run_dir,
        num_agent_steps=total_steps,
        eval_interval=args.eval_interval,
        seed=seed,
    ).train()
Пример #3
0
def run(args):
    """Train a SAC_Discrete agent on the environment named by ``args.env_id``.

    The two environments are deliberately configured differently:
    the one passed as ``env`` is built with ``sign_rewards=False`` and
    ``clip_rewards=True``; the one passed as ``env_test`` is built with
    ``episode_life=False`` and ``sign_rewards=False``. The agent is built
    from the first environment's observation and action spaces.

    Logs are written under
    ``logs/<env_id>/<algo>-seed<seed>-<YYYYmmdd-HHMM>``. The trainer is
    always given ``action_repeat=4``.

    Args:
        args: Namespace-like object providing ``env_id``,
            ``num_agent_steps``, ``eval_interval`` and ``seed``.
    """
    train_env_options = {"sign_rewards": False, "clip_rewards": True}
    eval_env_options = {"episode_life": False, "sign_rewards": False}
    env = make_atari_env(args.env_id, **train_env_options)
    env_test = make_atari_env(args.env_id, **eval_env_options)

    agent = SAC_Discrete(
        num_agent_steps=args.num_agent_steps,
        state_space=env.observation_space,
        action_space=env.action_space,
        seed=args.seed,
    )

    # Run directory: logs/<env_id>/<algo>-seed<seed>-<timestamp>.
    stamp = datetime.now().strftime("%Y%m%d-%H%M")
    run_dir = os.path.join(
        "logs",
        args.env_id,
        f"{str(agent)}-seed{args.seed}-{stamp}",
    )

    runner = Trainer(
        env=env,
        env_test=env_test,
        algo=agent,
        log_dir=run_dir,
        num_agent_steps=args.num_agent_steps,
        action_repeat=4,
        eval_interval=args.eval_interval,
        seed=args.seed,
    )
    runner.train()