def get_random_policy(env_name, tabular_obs):
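    """Returns a maximally exploratory (random) policy for `env_name`.

    The raw policy function is wrapped as a TF-Agents policy that emits
    log-probabilities; the policy-info spec is returned alongside it.
    """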
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
        policy_fn, policy_info_spec = taxi.get_taxi_policy(env,
                                                           env,
                                                           alpha=0.0,
                                                           py=False)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
        policy_fn, policy_info_spec = navigation.get_navigation_policy(
            env, epsilon_explore=1.0, py=False)
    elif env_name == 'tree':
        env = tree.Tree(branching=2, depth=10)
        policy_fn, policy_info_spec = tree.get_tree_policy(env,
                                                           epsilon_explore=1.0,
                                                           py=False)
    else:
        raise ValueError('Unknown environment: %s.' % env_name)

    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    tf_policy = common_utils.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                   tf_env.action_spec(),
                                                   policy_fn,
                                                   policy_info_spec,
                                                   emit_log_probability=True)

    return tf_policy, policy_info_spec
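
A brief usage sketch for the helper above (an illustration under assumed imports, not part of the source): navigation, gym_wrapper, and tf_py_environment are taken to be the same dice_rl / TF-Agents modules the snippet itself relies on.

# Hedged usage sketch: build a fully exploratory grid policy and query one action.
policy, policy_info_spec = get_random_policy('grid', tabular_obs=True)
env = navigation.GridWalk(tabular_obs=True)
tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
action_step = policy.action(tf_env.reset())  # standard TF-Agents policy API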
Example 2
def get_dataset(tabular_obs, epsilon_explore, limit):
    """Get on-policy dataset."""
    env = navigation.GridWalk(tabular_obs=tabular_obs)
    policy_fn, policy_info_spec = navigation.get_navigation_policy(
        env, epsilon_explore=epsilon_explore)
    dataset = gym_onpolicy_dataset.GymOnpolicyDataset(env,
                                                      policy_fn,
                                                      policy_info_spec,
                                                      episode_step_limit=limit)
    return dataset
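
A short usage sketch (hedged): the get_episode() call below assumes GymOnpolicyDataset follows dice_rl's OnpolicyDataset interface; that method name is an assumption, not confirmed by the snippet.

# Hedged usage sketch: collect one on-policy episode from the grid walk.
dataset = get_dataset(tabular_obs=True, epsilon_explore=0.5, limit=100)
episode = dataset.get_episode()  # assumed OnpolicyDataset API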
Example 3
def get_env_and_policy(load_dir,
                       env_name,
                       alpha,
                       env_seed=0,
                       tabular_obs=False):
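    """Builds the environment and a behavior policy for `env_name`.

    Larger `alpha` yields a less exploratory behavior policy (closer to the
    loaded/greedy policy). `load_dir` points to saved policy checkpoints for
    the environments that require them (taxi, CartPole, FrozenLake, Reacher,
    HalfCheetah).
    """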
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = taxi.get_taxi_policy(load_dir,
                                                           env,
                                                           alpha=alpha,
                                                           py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = navigation.get_navigation_policy(
            env, epsilon_explore=0.1 + 0.6 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'low_rank':
        env = low_rank.LowRank()
        env.seed(env_seed)
        policy_fn, policy_info_spec = low_rank.get_low_rank_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'tree':
        env = tree.Tree(branching=2, depth=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'lowrank_tree':
        env = tree.Tree(branching=2, depth=3, duplicate=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name.startswith('bandit'):
        num_arms = int(env_name[6:]) if len(env_name) > 6 else 2
        env = bandit.Bandit(num_arms=num_arms)
        env.seed(env_seed)
        policy_fn, policy_info_spec = bandit.get_bandit_policy(
            env, epsilon_explore=1 - alpha, py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'small_tree':
        env = tree.Tree(branching=2, depth=3, loop=True)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
        policy = common_lib.TFAgentsWrappedPolicy(tf_env.time_step_spec(),
                                                  tf_env.action_spec(),
                                                  policy_fn,
                                                  policy_info_spec,
                                                  emit_log_probability=True)
    elif env_name == 'CartPole-v0':
        tf_env, policy = get_env_and_dqn_policy(
            env_name,
            os.path.join(load_dir, 'CartPole-v0', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
    elif env_name == 'cartpole':  # Infinite-horizon cartpole.
        tf_env, policy = get_env_and_dqn_policy(
            'CartPole-v0',
            os.path.join(load_dir, 'CartPole-v0-250', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
        env = InfiniteCartPole()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name == 'FrozenLake-v0':
        tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0',
                                                os.path.join(
                                                    load_dir, 'FrozenLake-v0',
                                                    'train', 'policy'),
                                                env_seed=env_seed,
                                                epsilon=0.2 * (1 - alpha),
                                                ckpt_file='ckpt-100000')
    elif env_name == 'frozenlake':  # Infinite-horizon frozenlake.
        tf_env, policy = get_env_and_dqn_policy('FrozenLake-v0',
                                                os.path.join(
                                                    load_dir, 'FrozenLake-v0',
                                                    'train', 'policy'),
                                                env_seed=env_seed,
                                                epsilon=0.2 * (1 - alpha),
                                                ckpt_file='ckpt-100000')
        env = InfiniteFrozenLake()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name in ['Reacher-v2', 'reacher']:
        if env_name == 'Reacher-v2':
            env = suites.load_mujoco(env_name)
        else:
            env = gym_wrapper.GymWrapper(InfiniteReacher())
        env.seed(env_seed)
        tf_env = tf_py_environment.TFPyEnvironment(env)
        sac_policy = get_sac_policy(tf_env)
        directory = os.path.join(load_dir, 'Reacher-v2', 'train', 'policy')
        policy = load_policy(sac_policy, env_name, directory)
        policy = GaussianPolicy(policy,
                                0.4 - 0.3 * alpha,
                                emit_log_probability=True)
    elif env_name == 'HalfCheetah-v2':
        env = suites.load_mujoco(env_name)
        env.seed(env_seed)
        tf_env = tf_py_environment.TFPyEnvironment(env)
        sac_policy = get_sac_policy(tf_env)
        directory = os.path.join(load_dir, env_name, 'train', 'policy')
        policy = load_policy(sac_policy, env_name, directory)
        policy = GaussianPolicy(policy,
                                0.2 - 0.1 * alpha,
                                emit_log_probability=True)
    else:
        raise ValueError('Unrecognized environment %s.' % env_name)

    return tf_env, policy
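
A minimal usage sketch (hedged; the load_dir path is hypothetical and the helper modules are assumed to come from dice_rl / TF-Agents): build the grid environment with a mildly exploratory behavior policy and roll it out for one step using the standard TF-Agents policy API.

# Hedged usage sketch: one interaction step with the returned policy.
tf_env, policy = get_env_and_policy(
    load_dir='/path/to/policies',  # hypothetical checkpoint directory
    env_name='grid',
    alpha=0.5)
time_step = tf_env.reset()
action_step = policy.action(time_step)
next_time_step = tf_env.step(action_step.action)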