Beispiel #1
0
def main():
    """Build an agent for PongNoFrameskip-v4, then train or test it.

    Command-line arguments come from ``parse_args()``. The agent config is
    loaded from ``args.cfg_path``; environment and logging information are
    attached to ``cfg.agent`` before the agent is built. ``args.test`` and
    ``args.grad_cam`` select which agent method is run.
    """
    args = parse_args()

    # environment
    env_name = "PongNoFrameskip-v4"
    env = atari_env_generator(env_name, args.max_episode_steps)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    cfg.agent.env_info = {
        "env_name": env_name,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
        "is_discrete": True,
    }
    cfg.agent.log_cfg = {"agent": cfg.agent.type, "curr_time": curr_time}
    agent = build_agent(cfg.agent, {"args": args, "env": env})

    # Training is the default; the grad-cam flag only matters in test mode.
    if not args.test:
        agent.train()
        return
    if args.grad_cam:
        agent.test_with_gradcam()
        return
    agent.test()
def main():
    """Build an agent for Reacher-v2, then train or test it.

    Command-line arguments come from ``parse_args()``. The agent config is
    loaded from ``args.cfg_path`` and the agent is trained unless
    ``args.test`` is set, in which case it is tested.
    """
    args = parse_args()

    # environment
    env_id = "Reacher-v2"
    env = gym.make(env_id)
    # NOTE(review): the return value of set_env is not used here; presumably
    # it configures `env` in place -- confirm against env_utils.
    env_utils.set_env(env, args)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    cfg.agent.env_info = {
        "env_name": env_id,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
        "is_discrete": False,
    }
    cfg.agent.log_cfg = {"agent": cfg.agent.type, "curr_time": curr_time}
    agent = build_agent(cfg.agent, {"args": args, "env": env})

    # pick the entry point from the --test flag and run it
    run = agent.test if args.test else agent.train
    run()
def main():
    """Build an agent for LunarLanderContinuous-v2, then train or test it.

    Command-line arguments come from ``parse_args()``. When
    ``args.integration_test`` is set, the loaded config is first passed
    through ``common_utils.set_cfg_for_intergration_test``.
    """
    args = parse_args()

    # environment; set_env's return value replaces the raw gym env
    env_name = "LunarLanderContinuous-v2"
    env = env_utils.set_env(gym.make(env_name), args)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)

    # integration tests run with a config rewritten by the helper
    if args.integration_test:
        cfg = common_utils.set_cfg_for_intergration_test(cfg)

    cfg.agent.env_info = {
        "name": env_name,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
    }
    cfg.agent.log_cfg = {"agent": cfg.agent.type, "curr_time": curr_time}
    agent = build_agent(cfg.agent, {"args": args, "env": env})

    if args.test:
        agent.test()
    else:
        agent.train()
def main():
    """Build an agent for the "Drone_discrete" Unity environment and run it.

    Command-line arguments come from ``parse_args()``. The environment is
    created in train mode exactly when ``args.test`` is falsy, and the agent
    is then trained or tested accordingly.
    """
    args = parse_args()

    # environment: train mode is the negation of the --test flag
    env = unity_env_generator("Drone_discrete", not args.test, args.worker_id)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    cfg.agent["log_cfg"] = {"agent": cfg.agent.type, "curr_time": curr_time}
    agent = build_agent(cfg.agent, {"args": args, "env": env})

    if args.test:
        agent.test()
    else:
        agent.train()
def main():
    """Build an agent for PongNoFrameskip-v4 from a YAML config and run it.

    Command-line arguments come from ``parse_args()``. Environment
    information, logging information and run options are handed to
    ``build_agent`` as explicit keyword entries. ``args.test``,
    ``args.grad_cam`` and ``args.saliency_map`` select the agent method.
    """
    args = parse_args()

    # environment: keep the generator, it is also passed on in env_info
    env_gen = env_generator(
        "PongNoFrameskip-v4",
        args.max_episode_steps,
        frame_stack=args.framestack,
    )
    env = env_gen(0)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = YamlConfig({"agent": args.cfg_path}).get_config_dict()

    # integration tests run with a config rewritten by the helper
    if args.integration_test:
        cfg = common_utils.set_cfg_for_intergration_test(cfg)

    env_info = {
        "name": env.spec.id,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
        "is_atari": True,
        "env_generator": env_gen,
    }
    log_cfg = {
        "agent": cfg.agent.type,
        "curr_time": curr_time,
        "cfg_path": args.cfg_path,
    }
    agent_kwargs = {
        "env": env,
        "env_info": env_info,
        "log_cfg": log_cfg,
        "is_test": args.test,
        "load_from": args.load_from,
        "is_render": args.render,
        "render_after": args.render_after,
        "is_log": args.log,
        "save_period": args.save_period,
        "episode_num": args.episode_num,
        # the step limit is read back from the created env's spec
        "max_episode_steps": env.spec.max_episode_steps,
        "interim_test_num": args.interim_test_num,
    }
    agent = build_agent(cfg.agent, agent_kwargs)

    # Training is the default; visualisation flags only matter in test mode,
    # with grad-cam taking precedence over the saliency map.
    if not args.test:
        agent.train()
        return
    if args.grad_cam:
        agent.test_with_gradcam()
        return
    if args.saliency_map:
        agent.test_with_saliency_map()
        return
    agent.test()
Beispiel #6
0
def main():
    """Build an agent for Reacher-v2 from a YAML config, then train or test.

    Command-line arguments come from ``parse_args()``. ``env_utils.set_env``
    returns both the environment and the episode step limit that is passed on
    to ``build_agent``.
    """
    args = parse_args()

    # environment
    raw_env = gym.make("Reacher-v2")
    env, max_episode_steps = env_utils.set_env(raw_env, args.max_episode_steps)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = YamlConfig({"agent": args.cfg_path}).get_config_dict()

    # integration tests run with a config rewritten by the helper
    if args.integration_test:
        cfg = common_utils.set_cfg_for_intergration_test(cfg)

    env_info = {
        "name": env.spec.id,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
        "is_atari": False,
    }
    log_cfg = {
        "agent": cfg.agent.type,
        "curr_time": curr_time,
        "cfg_path": args.cfg_path,
    }
    agent_kwargs = {
        "env": env,
        "env_info": env_info,
        "log_cfg": log_cfg,
        "is_test": args.test,
        "load_from": args.load_from,
        "is_render": args.render,
        "render_after": args.render_after,
        "is_log": args.log,
        "save_period": args.save_period,
        "episode_num": args.episode_num,
        "max_episode_steps": max_episode_steps,
        "interim_test_num": args.interim_test_num,
    }
    agent = build_agent(cfg.agent, agent_kwargs)

    # pick the entry point from the --test flag and run it
    run = agent.test if args.test else agent.train
    run()
Beispiel #7
0
def test_config_registry():
    """Check that ``build_agent`` yields an ``Agent`` for the loaded config.

    Arguments are parsed from ``["--test"]`` and the config is read from the
    resulting ``args.cfg_path``.
    """
    args = parse_args(["--test"])

    # environment
    env_id = "LunarLanderContinuous-v2"
    env = gym.make(env_id)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    cfg.agent.env_info = {
        "env_name": env_id,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
    }
    cfg.agent.log_cfg = {"agent": cfg.agent.type, "curr_time": curr_time}

    agent = build_agent(cfg.agent, {"args": args, "env": env})
    assert isinstance(agent, Agent)
Beispiel #8
0
def test_config_registry():
    """Check that ``build_agent`` yields an ``Agent`` from explicit build args.

    Arguments are parsed from ``["--test"]``; environment info, log config and
    run options are passed to ``build_agent`` as separate dictionary entries.
    """
    args = parse_args(["--test"])

    # environment
    env = gym.make("LunarLanderContinuous-v2")

    # timestamp stored in the log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    env_info = {
        "name": env.spec.id,
        "observation_space": env.observation_space,
        "action_space": env.action_space,
        "is_atari": False,
    }
    log_cfg = {
        "agent": cfg.agent.type,
        "curr_time": curr_time,
        "cfg_path": args.cfg_path,
    }
    agent_kwargs = {
        "env": env,
        "env_info": env_info,
        "log_cfg": log_cfg,
        "is_test": args.test,
        "load_from": args.load_from,
        "is_render": args.render,
        "render_after": args.render_after,
        "is_log": args.log,
        "save_period": args.save_period,
        "episode_num": args.episode_num,
        "max_episode_steps": args.max_episode_steps,
        "interim_test_num": args.interim_test_num,
    }

    agent = build_agent(cfg.agent, agent_kwargs)
    assert isinstance(agent, Agent)
def main():
    """Build an agent for LunarLander-v2, then train or test it.

    Command-line arguments come from ``parse_args()``. The agent config is
    loaded from ``args.cfg_path`` and the agent is trained unless
    ``args.test`` is set, in which case it is tested.
    """
    args = parse_args()

    # environment
    env = gym.make("LunarLander-v2")
    # NOTE(review): the return value of set_env is not used here; presumably
    # it configures `env` in place -- confirm against env_utils.
    env_utils.set_env(env, args)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    cfg.agent["log_cfg"] = {"agent": cfg.agent.type, "curr_time": curr_time}
    agent = build_agent(cfg.agent, {"args": args, "env": env})

    if args.test:
        agent.test()
    else:
        agent.train()
def main():
    """Build an agent for PongNoFrameskip-v4, then train or test it.

    Command-line arguments come from ``parse_args()``. The agent config is
    loaded from ``args.cfg_path`` and the agent is trained unless
    ``args.test`` is set, in which case it is tested.
    """
    args = parse_args()

    # environment
    env = atari_env_generator("PongNoFrameskip-v4", args.max_episode_steps)

    # random seed
    common_utils.set_random_seed(args.seed, env)

    # timestamp stored in the agent's log config
    curr_time = datetime.datetime.now().strftime("%y%m%d_%H%M%S")

    cfg = Config.fromfile(args.cfg_path)
    cfg.agent["log_cfg"] = {"agent": cfg.agent.type, "curr_time": curr_time}
    agent = build_agent(cfg.agent, {"args": args, "env": env})

    # pick the entry point from the --test flag and run it
    run = agent.test if args.test else agent.train
    run()
def build_agent_from_config(cfg, env, args=None) -> Agent:
    """Build an agent from ``cfg.agent`` for the given environment.

    Args:
        cfg: Config object whose ``agent`` entry is forwarded to
            ``build_agent``.
        env: Environment forwarded to ``build_agent`` under the ``"env"`` key.
        args: Parsed command-line arguments forwarded under the ``"args"``
            key. When omitted, a module-level ``args`` is used if one exists,
            which is what the original two-argument form implicitly relied on.

    Returns:
        The object returned by ``build_agent``.

    Raises:
        NameError: If ``args`` is not passed and no module-level ``args``
            is defined.
    """
    if args is None:
        # The original body read a bare `args` that is neither a parameter
        # nor a local; keep that lookup as a fallback so existing callers
        # behave the same, but fail with an actionable message otherwise.
        try:
            args = globals()["args"]
        except KeyError:
            raise NameError(
                "build_agent_from_config() needs `args`: pass it explicitly "
                "or define a module-level `args`"
            ) from None

    return build_agent(cfg.agent, dict(args=args, env=env))