def get_random_policy(env_name, tabular_obs):
    """Builds the "random" policy function for one of the tabular environments.

    The taxi policy is requested with `alpha=0.0`; the grid and tree policies
    are requested with `epsilon_explore=1.0`.

    Args:
      env_name: One of 'taxi', 'grid' or 'tree'.
      tabular_obs: Forwarded to the `taxi.Taxi` and `navigation.GridWalk`
        constructors. Not used for 'tree'.

    Returns:
      A `(policy_fn, policy_info_spec)` tuple exactly as produced by the
      environment module's policy factory. These are the values that
      `get_onpolicy_dataset` accepts as its last two arguments.

    Raises:
      ValueError: If `env_name` is not one of the supported names.
    """
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
        # NOTE(review): `env` is passed as the first positional argument here,
        # whereas get_env_and_policy passes `load_dir` in that position.
        # Confirm that get_taxi_policy accepts an environment there.
        policy_fn, policy_info_spec = taxi.get_taxi_policy(
            env, env, alpha=0.0, py=False)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
        policy_fn, policy_info_spec = navigation.get_navigation_policy(
            env, epsilon_explore=1.0, py=False)
    elif env_name == 'tree':
        env = tree.Tree(branching=2, depth=10)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=1.0, py=False)
    else:
        raise ValueError('Unknown environment: %s.' % env_name)

    return policy_fn, policy_info_spec
def get_onpolicy_dataset(env_name, tabular_obs, policy_fn, policy_info_spec):
    """Builds an on-policy dataset for a tabular environment and policy function.

    A fresh environment is created for `env_name`, wrapped as a TF-Agents
    environment, and paired with `policy_fn` wrapped as a TF-Agents policy
    that emits log-probabilities.

    Args:
      env_name: One of 'taxi', 'grid' or 'tree'.
      tabular_obs: Forwarded to the `taxi.Taxi` and `navigation.GridWalk`
        constructors. Not used for 'tree'.
      policy_fn: Policy function handed to `TFAgentsWrappedPolicy`.
      policy_info_spec: Info spec handed to `TFAgentsWrappedPolicy`.

    Returns:
      A `TFAgentsOnpolicyDataset` over the wrapped environment and policy.

    Raises:
      ValueError: If `env_name` is not one of the supported names.
    """
    # Factories are lazy so that only the requested environment is built.
    env_factories = {
        'taxi': lambda: taxi.Taxi(tabular_obs=tabular_obs),
        'grid': lambda: navigation.GridWalk(tabular_obs=tabular_obs),
        'tree': lambda: tree.Tree(branching=2, depth=10),
    }
    make_env = env_factories.get(env_name)
    if make_env is None:
        raise ValueError('Unknown environment: %s.' % env_name)

    py_env = make_env()
    wrapped_env = tf_py_environment.TFPyEnvironment(
        gym_wrapper.GymWrapper(py_env))
    wrapped_policy = common_utils.TFAgentsWrappedPolicy(
        wrapped_env.time_step_spec(),
        wrapped_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
    return TFAgentsOnpolicyDataset(wrapped_env, wrapped_policy)
def _wrap_env_and_policy(env, policy_fn, policy_info_spec):
    """Wraps a gym-style env and a policy function as TF-Agents objects.

    Args:
      env: Environment to wrap with `gym_wrapper.GymWrapper`.
      policy_fn: Policy function handed to `TFAgentsWrappedPolicy`.
      policy_info_spec: Info spec handed to `TFAgentsWrappedPolicy`.

    Returns:
      A `(tf_env, policy)` tuple; the policy is created with
      `emit_log_probability=True`.
    """
    tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    policy = common_lib.TFAgentsWrappedPolicy(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        policy_fn,
        policy_info_spec,
        emit_log_probability=True)
    return tf_env, policy


def _load_gaussian_sac_policy(env, env_seed, load_dir, env_name,
                              policy_subdir, scale):
    """Seeds `env`, loads its SAC policy and wraps it in a `GaussianPolicy`.

    Args:
      env: Environment accepted directly by `TFPyEnvironment`.
      env_seed: Seed passed to `env.seed`.
      load_dir: Root directory containing saved policies.
      env_name: Name forwarded to `load_policy`.
      policy_subdir: Sub-directory of `load_dir` holding 'train/policy'.
      scale: Second positional argument of `GaussianPolicy` (presumably the
        noise scale -- confirm against the GaussianPolicy definition).

    Returns:
      A `(tf_env, policy)` tuple.
    """
    env.seed(env_seed)
    tf_env = tf_py_environment.TFPyEnvironment(env)
    sac_policy = get_sac_policy(tf_env)
    directory = os.path.join(load_dir, policy_subdir, 'train', 'policy')
    policy = load_policy(sac_policy, env_name, directory)
    policy = GaussianPolicy(policy, scale, emit_log_probability=True)
    return tf_env, policy


def get_env_and_policy(load_dir, env_name, alpha, env_seed=0,
                       tabular_obs=False):
    """Creates a TF environment and a matching behavior policy.

    Supported `env_name` values: 'taxi', 'grid', 'low_rank', 'tree',
    'lowrank_tree', 'bandit' or 'bandit<N>' (N arms, default 2), 'small_tree',
    'CartPole-v0', 'cartpole', 'FrozenLake-v0', 'frozenlake', 'Reacher-v2',
    'reacher' and 'HalfCheetah-v2'.

    Args:
      load_dir: Root directory from which saved policies are loaded (used by
        the taxi, DQN and SAC branches).
      env_name: Name of the environment, see above.
      alpha: Number from which each branch derives its exploration or noise
        parameter (for example `epsilon_explore=0.1 + 0.8 * (1 - alpha)` for
        the tree environments).
      env_seed: Seed applied to the environment where the branch seeds it.
      tabular_obs: Forwarded to the `taxi.Taxi` and `navigation.GridWalk`
        constructors; ignored by every other branch.

    Returns:
      A `(tf_env, policy)` tuple.

    Raises:
      ValueError: If `env_name` is not recognized.
    """
    if env_name == 'taxi':
        env = taxi.Taxi(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = taxi.get_taxi_policy(
            load_dir, env, alpha=alpha, py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name == 'grid':
        env = navigation.GridWalk(tabular_obs=tabular_obs)
        env.seed(env_seed)
        policy_fn, policy_info_spec = navigation.get_navigation_policy(
            env, epsilon_explore=0.1 + 0.6 * (1 - alpha), py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name == 'low_rank':
        env = low_rank.LowRank()
        env.seed(env_seed)
        policy_fn, policy_info_spec = low_rank.get_low_rank_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name == 'tree':
        env = tree.Tree(branching=2, depth=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name == 'lowrank_tree':
        env = tree.Tree(branching=2, depth=3, duplicate=10)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name.startswith('bandit'):
        # An optional numeric suffix selects the number of arms ('bandit5').
        arms_suffix = env_name[len('bandit'):]
        num_arms = int(arms_suffix) if arms_suffix else 2
        env = bandit.Bandit(num_arms=num_arms)
        env.seed(env_seed)
        policy_fn, policy_info_spec = bandit.get_bandit_policy(
            env, epsilon_explore=1 - alpha, py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name == 'small_tree':
        env = tree.Tree(branching=2, depth=3, loop=True)
        env.seed(env_seed)
        policy_fn, policy_info_spec = tree.get_tree_policy(
            env, epsilon_explore=0.1 + 0.8 * (1 - alpha), py=False)
        tf_env, policy = _wrap_env_and_policy(env, policy_fn, policy_info_spec)
    elif env_name == 'CartPole-v0':
        tf_env, policy = get_env_and_dqn_policy(
            env_name,
            os.path.join(load_dir, 'CartPole-v0', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
    elif env_name == 'cartpole':
        # Infinite-horizon cartpole: the policy comes from the CartPole-v0
        # loader, then the environment is swapped for the infinite variant.
        tf_env, policy = get_env_and_dqn_policy(
            'CartPole-v0',
            os.path.join(load_dir, 'CartPole-v0-250', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.3 + 0.15 * (1 - alpha))
        # NOTE(review): `env_seed` is not applied to this replacement
        # environment, unlike the other branches -- confirm this is intended.
        env = InfiniteCartPole()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name == 'FrozenLake-v0':
        tf_env, policy = get_env_and_dqn_policy(
            'FrozenLake-v0',
            os.path.join(load_dir, 'FrozenLake-v0', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.2 * (1 - alpha),
            ckpt_file='ckpt-100000')
    elif env_name == 'frozenlake':
        # Infinite-horizon frozenlake: same policy as FrozenLake-v0, with the
        # environment swapped for the infinite variant.
        tf_env, policy = get_env_and_dqn_policy(
            'FrozenLake-v0',
            os.path.join(load_dir, 'FrozenLake-v0', 'train', 'policy'),
            env_seed=env_seed,
            epsilon=0.2 * (1 - alpha),
            ckpt_file='ckpt-100000')
        # NOTE(review): `env_seed` is not applied to this replacement
        # environment, unlike the other branches -- confirm this is intended.
        env = InfiniteFrozenLake()
        tf_env = tf_py_environment.TFPyEnvironment(gym_wrapper.GymWrapper(env))
    elif env_name in ['Reacher-v2', 'reacher']:
        if env_name == 'Reacher-v2':
            env = suites.load_mujoco(env_name)
        else:
            env = gym_wrapper.GymWrapper(InfiniteReacher())
        # Both variants load the policy saved under 'Reacher-v2'.
        tf_env, policy = _load_gaussian_sac_policy(
            env, env_seed, load_dir, env_name, 'Reacher-v2',
            0.4 - 0.3 * alpha)
    elif env_name == 'HalfCheetah-v2':
        env = suites.load_mujoco(env_name)
        tf_env, policy = _load_gaussian_sac_policy(
            env, env_seed, load_dir, env_name, env_name, 0.2 - 0.1 * alpha)
    else:
        raise ValueError('Unrecognized environment %s.' % env_name)

    return tf_env, policy