Example #1
def make_env(name, action_repeat):
    # Assumes `gym`, `RescaleAction`, and the project-specific wrappers
    # (ActionRepeat, ObservationNormalize, TestObservationNormalize) are in scope.
    env = gym.make(name)
    env = ActionRepeat(env, action_repeat)
    env = RescaleAction(env, -1.0, 1.0)  # expose a [-1, 1] action space
    train_env = ObservationNormalize(env)
    test_env = TestObservationNormalize(train_env)  # test env wraps the train env's normalizer
    return train_env, test_env
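
For context, RescaleAction only changes the reported action-space bounds and linearly maps incoming actions back to the native range. A minimal standalone sketch (Pendulum-v1 is an illustrative choice, not taken from the example above):

import gym
from gym.wrappers import RescaleAction

env = gym.make("Pendulum-v1")        # native action bounds: [-2, 2]
env = RescaleAction(env, -1.0, 1.0)  # agent now acts in [-1, 1]
print(env.action_space.low, env.action_space.high)  # [-1.] [1.]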
Example #2
def main():

    parser = ArgumentParser()
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--seed', type=int, default=100)
    parser.add_argument('--use_obs_filter',
                        dest='obs_filter',
                        action='store_true')
    parser.add_argument('--update_every_n_steps', type=int, default=1)
    parser.add_argument('--n_random_actions', type=int, default=10000)
    parser.add_argument('--n_collect_steps', type=int, default=1000)
    parser.add_argument('--n_evals', type=int, default=1)
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.set_defaults(obs_filter=False)
    parser.set_defaults(save_model=False)

    args = parser.parse_args()
    params = vars(args)

    seed = params['seed']
    env = gym.make(params['env'])
    env = RescaleAction(env, -1, 1)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    agent = SAC_Agent(seed, state_dim, action_dim)

    train_agent_model_free(agent=agent, env=env, params=params)
Example #3
def make_env(
    env_name,
    seed,
    save_dir = None,
    add_episode_monitor = True,
    action_repeat = 1,
    frame_stack = 1,
):
  """Env factory with wrapping.

  Args:
    env_name: The name of the environment.
    seed: The RNG seed.
    save_dir: Specify a save directory to wrap with `VideoRecorder`.
    add_episode_monitor: Set to True to wrap with `EpisodeMonitor`.
    action_repeat: A value > 1 will wrap with `ActionRepeat`.
    frame_stack: A value > 1 will wrap with `FrameStack`.

  Returns:
    gym.Env object.
  """
  # Check if the env is in x-magical.
  xmagical.register_envs()
  if env_name in xmagical.ALL_REGISTERED_ENVS:
    env = gym.make(env_name)
  else:
    raise ValueError(f"{env_name} is not a valid environment name.")

  if add_episode_monitor:
    env = wrappers.EpisodeMonitor(env)
  if action_repeat > 1:
    env = wrappers.ActionRepeat(env, action_repeat)
  env = RescaleAction(env, -1.0, 1.0)
  if save_dir is not None:
    env = wrappers.VideoRecorder(env, save_dir=save_dir)
  if frame_stack > 1:
    env = wrappers.FrameStack(env, frame_stack)

  # Seed.
  env.seed(seed)
  env.action_space.seed(seed)
  env.observation_space.seed(seed)

  return env
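
A hedged usage sketch of this factory; the environment name below is hypothetical (real names come from xmagical.ALL_REGISTERED_ENVS):

env = make_env(
    "SweepToTop-Gripper-State-Allo-Demo-v0",  # hypothetical name, for illustration
    seed=42,
    save_dir=None,       # set a path here to record videos
    action_repeat=2,
    frame_stack=3,
)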
Example #4
def test_rescale_action():
    # Only Box action spaces can be rescaled, so wrapping CartPole's
    # Discrete space raises.
    env = gym.make("CartPole-v1")
    with pytest.raises(AssertionError):
        env = RescaleAction(env, -1, 1)
    del env

    env = gym.make("Pendulum-v1")
    wrapped_env = RescaleAction(gym.make("Pendulum-v1"), -1, 1)

    seed = 0

    obs = env.reset(seed=seed)
    wrapped_obs = wrapped_env.reset(seed=seed)
    assert np.allclose(obs, wrapped_obs)

    # Pendulum's native bounds are [-2, 2], so 0.75 in [-1, 1] rescales to 1.5.
    obs, reward, _, _ = env.step([1.5])
    with pytest.raises(AssertionError):
        wrapped_env.step([1.5])  # outside the wrapped [-1, 1] bounds
    wrapped_obs, wrapped_reward, _, _ = wrapped_env.step([0.75])

    assert np.allclose(obs, wrapped_obs)
    assert np.allclose(reward, wrapped_reward)
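
The wrapper applies an affine map: an action a in [min_action, max_action] is sent to low + (high - low) * (a - min_action) / (max_action - min_action), where [low, high] are the env's native bounds. A quick standalone check of the 0.75 -> 1.5 correspondence used above (NumPy only):

import numpy as np

low, high = -2.0, 2.0      # Pendulum-v1's native action bounds
a_min, a_max = -1.0, 1.0   # bounds exposed by RescaleAction
a = 0.75
native = low + (high - low) * (a - a_min) / (a_max - a_min)
assert np.isclose(native, 1.5)  # matches the unwrapped env.step([1.5]) above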
Example #5
def main():

    parser = ArgumentParser()
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--seed', type=int, default=100)
    parser.add_argument('--use_obs_filter',
                        dest='obs_filter',
                        action='store_true')
    parser.add_argument('--update_every_n_steps', type=int, default=1)
    parser.add_argument('--n_random_actions', type=int, default=10000)
    parser.add_argument('--n_collect_steps', type=int, default=1000)
    parser.add_argument('--n_evals', type=int, default=1)
    parser.add_argument('--experiment_name', type=str, default='')
    parser.add_argument('--make_gif', dest='make_gif', action='store_true')
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.add_argument('--total_steps', type=int, default=int(1e7))
    parser.set_defaults(obs_filter=False)
    parser.set_defaults(save_model=False)

    args = parser.parse_args()
    params = vars(args)

    seed = params['seed']
    all_envs = gym.envs.registry.all()
    available_envs = [env_spec.id for env_spec in all_envs]
    env_name = params['env']
    if env_name in available_envs:
        env = gym.make(env_name)
    elif env_name == 'cartpole-swingup_sparse':
        env = dmc2gym.make(domain_name='cartpole',
                           task_name='swingup_sparse',
                           seed=seed)  # use the CLI seed rather than a hard-coded 0
    else:
        raise ValueError(f"{env_name} is not a valid environment name.")
    env = RescaleAction(env, -1, 1)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    agent = SAC_Agent(seed, state_dim, action_dim)

    train_agent_model_free(agent=agent, env=env, params=params)
Example #6
def main():

    parser = ArgumentParser()
    parser.add_argument('--env', type=str, default='HalfCheetah-v2')
    parser.add_argument('--alg',
                        type=str,
                        default='td3',
                        choices={'td3', 'sac', 'tds', 'mepg'})
    parser.add_argument('--yaml_config', type=str, default=None)
    parser.add_argument('--seed', type=int, default=100)
    parser.add_argument('--use_obs_filter',
                        dest='obs_filter',
                        action='store_true')
    parser.add_argument('--update_every_n_steps', type=int, default=None)
    parser.add_argument('--n_random_actions', type=int, default=None)
    parser.add_argument('--n_collect_steps', type=int, default=None)
    parser.add_argument('--n_evals', type=int, default=1)
    parser.add_argument('--total_timesteps', type=int, default=int(1e7))  # argparse does not coerce defaults, so pass an int
    parser.add_argument('--save_model', dest='save_model', action='store_true')
    parser.add_argument('--experiment_name', type=str, default=None)
    parser.add_argument('--make_gif', dest='make_gif', action='store_true')
    parser.add_argument('--checkpoint_interval', type=int, default=500000)
    # `type=bool` would treat any non-empty string (even "False") as True,
    # so use a flag instead; it defaults to False.
    parser.add_argument('--save_replay_pool', action='store_true')
    parser.add_argument('--load_model_path', type=str, default=None)
    parser.set_defaults(obs_filter=False)
    parser.set_defaults(save_model=False)

    args = parser.parse_args()
    params = vars(args)

    seed = params['seed']
    env = gym.make(params['env'])
    env = RescaleAction(env, -1, 1)

    state_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]

    params, agent = get_agent_and_update_params(seed, state_dim, action_dim,
                                                params)

    train_agent_model_free(agent=agent, env=env, params=params)
Example #7
def main():
    envs = {
        0: ['Walker2d-v2', 5],
        1: ['Hopper-v2', 5],
        2: ['HalfCheetah-v2', 1]
    }
    ind = 1

    env_name = envs[ind][0]
    env = gym.make(env_name)
    env = RescaleAction(env, -1, 1)

    obs_dim = env.observation_space.shape[0]
    action_dim = env.action_space.shape[0]
    print(action_dim, env.action_space.low, env.action_space.high)

    critic_net = DoubleQFunc(obs_dim, action_dim)
    target_net = copy.deepcopy(critic_net)
    target_net.eval()
    policy = Policy(obs_dim, action_dim)

    train(env, critic_net, target_net, policy)
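
Example #7 creates the critic's target network with copy.deepcopy and freezes it with .eval(). During training such a target is commonly refreshed with a Polyak (soft) update; a minimal sketch of that pattern, with an illustrative tau and a stand-in module rather than code from the example's train():

import copy
import torch
import torch.nn as nn

critic = nn.Linear(4, 1)   # stand-in for DoubleQFunc
target = copy.deepcopy(critic)
target.eval()

tau = 0.005                # illustrative smoothing coefficient
with torch.no_grad():
    for p_t, p in zip(target.parameters(), critic.parameters()):
        p_t.mul_(1.0 - tau).add_(tau * p)  # target <- (1 - tau) * target + tau * critic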
Example #8
# Signature inferred from the call site at the bottom of this example;
# the `render` argument gates env.render().
def evaluate_agent(env, policy, render=False, n_starts=1):
    reward_sum = 0
    for _ in range(n_starts):
        done = False
        state = env.reset()
        while not done:
            action = get_action(state, policy, deterministic=True)
            nextstate, reward, done, _ = env.step(action)
            reward_sum += reward
            state = nextstate
            if render:
                env.render()
    return reward_sum / n_starts

envs = {0: ['Walker2d-v2', 5], 1: ['Hopper-v2', 5], 2: ['HalfCheetah-v2', 1]}
ind = 1

env_name = envs[ind][0]
env = gym.make(env_name)
env = RescaleAction(env, -1, 1)

obs_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
print(action_dim, env.action_space.low, env.action_space.high)

save_path = "Hopper/Policy200000pt"


# policy  = Policy(obs_dim, action_dim)
policy = torch.load(save_path)

evaluate_agent(env, policy, False, n_starts=10000)