Code Example #1
    def test_batching_scheme_does_not_restart(self):
        """Test if BatchEnv correctly handle environments that come back to life."""

        env = batch_env.BatchEnv(
            [FakeEnvThatComesBackToLife(i) for i in range(BATCH_SIZE)])
        env.reset()
        policy = policies.NormalPolicyFixedStd(ACTION_SPACE, std=0.5)
        spec = env_spec.EnvSpec(env)
        trajectories = trajectory_collector.collect_trajectories(
            env, policy, max_steps=MAX_STEPS, env_spec=spec)
        _, _, masks, _ = trajectories
        checks = []
        for t in range(1, masks.shape[0]):
            # The logic here is:
            # (1) Find environments that were terminated (mask = 0) in the
            # previous time step.
            # (2) Check that in the current time step they are still terminated.
            # At the end we check that this held for every pair of consecutive
            # time steps, which is True only when trajectories do not come back
            # to life.
            prev_time_step_end = np.where(masks[t - 1] == 0)
            checks.append(np.all(masks[t, prev_time_step_end] == 0))

        # Assert that no environments came back to life.
        self.assertTrue(np.all(checks))
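
Note: FakeEnvThatComesBackToLife is referenced above but not listed. Below is a minimal sketch of such a test double, assuming the classic gym reset/step interface and assuming the constructor index is only used to stagger when each environment terminates; the real class may differ.

import gym
import numpy as np

class FakeEnvThatComesBackToLife(gym.Env):
    """Hypothetical test double: reports done once, then keeps stepping."""

    def __init__(self, index):
        self.action_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))
        self.observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1,))
        self._end_step = 3 + index  # Each env "dies" at a different step.
        self._step = 0

    def reset(self):
        self._step = 0
        return np.zeros(1, dtype=np.float32)

    def step(self, action):
        self._step += 1
        # Report done exactly once, but keep producing transitions afterwards,
        # as if the episode had come back to life. BatchEnv is expected to
        # keep such environments masked out once they terminate.
        done = self._step == self._end_step
        return np.zeros(1, dtype=np.float32), 1.0, done, {}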
Code Example #2
    def test_collector_in_env(self, env_name, policy_fn, policy_args):
        """Will do a rollout in the environment.

    The goal of this test is two fold:
    - trajectory collections can happen.
    - action clipping happens.

    Args:
      env_name: Name of the environment to load.
      policy_fn: a policies.* object that executes actions in the environment.
      policy_args: The arguments needed to load the policy.
    """
        env = batch_env.BatchEnv(
            [gym.make(env_name) for _ in range(BATCH_SIZE)])
        env.reset()
        policy = policy_fn(env.action_space.shape[0], **policy_args)
        spec = env_spec.EnvSpec(env)
        trajectories = trajectory_collector.collect_trajectories(
            env, policy, max_steps=MAX_STEPS, env_spec=spec)
        self.assertIsNotNone(trajectories)
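
The docstring above mentions action clipping. One common way a collector can enforce it per step is sketched below; this is an assumption about trajectory_collector's internals, not something this test verifies directly.

import numpy as np

def clip_action(action, action_space):
    # Clip a sampled action into the Box bounds of the action space.
    return np.clip(action, action_space.low, action_space.high)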
Code Example #3
    def test_repeated_trajectory_collector_has_gradients(self):
        """Make sure concatenating trajectories maintains gradient information."""
        env = batch_env.BatchEnv(
            [gym.make('Pendulum-v0') for _ in range(BATCH_SIZE)])
        env.reset()
        policy = policies.NormalPolicyFixedStd(env.action_space.shape[0],
                                               std=0.5)
        spec = env_spec.EnvSpec(env)
        objective = objectives.REINFORCE()
        with tf.GradientTape() as tape:
            (rewards, log_probs, masks,
             _) = trajectory_collector.repeat_collect_trajectories(
                 env,
                 policy,
                 n_trajectories=BATCH_SIZE * 5,
                 env_spec=spec,
                 max_steps=100)
            returns = rl_utils.compute_discounted_return(rewards, 0.99, masks)
            loss = objective(log_probs=log_probs, returns=returns, masks=masks)
            grads = tape.gradient(loss, policy.trainable_variables)

        self.assertTrue(len(grads))
        self.assertFalse(np.all([np.all(t.numpy() == 0) for t in grads]))
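
For context, rl_utils.compute_discounted_return is called but not listed. Below is a minimal masked discounted-return sketch, assuming rewards and masks are laid out as [time, batch] (as the mask indexing in Code Example #1 suggests); the actual rl_utils implementation may differ.

import numpy as np

def compute_discounted_return(rewards, gamma, masks):
    """Illustrative stand-in: per-step discounted returns, cut at episode ends."""
    returns = np.zeros_like(rewards)
    running = np.zeros(rewards.shape[1])
    for t in reversed(range(rewards.shape[0])):
        # A mask of 0 marks a terminated step, so the return is cut there.
        running = (rewards[t] + gamma * running) * masks[t]
        returns[t] = running
    return returns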
Code Example #4
def get_batched_environment(env_name):
    """Returns a batched version of the environment."""
    return batch_env.BatchEnv([gym.make(env_name) for _ in range(BATCH_SIZE)])
Code Example #5
def get_batched_environment(env_name, batch_size):
    """Returns a batched version of the environment."""
    return batch_env.BatchEnv([gym.make(env_name) for _ in range(batch_size)])
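
A hypothetical call site for this helper, reusing names from the examples above ('Pendulum-v0' and BATCH_SIZE are taken from Code Example #3):

env = get_batched_environment('Pendulum-v0', batch_size=BATCH_SIZE)
env.reset()
spec = env_spec.EnvSpec(env)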
Code Example #6
def create_env():
    """Creates environments and useful things to roll out trajectories."""
    env = batch_env.BatchEnv([gym.make(ENV_NAME) for _ in range(5)])
    spec = env_spec.EnvSpec(env)
    others = {'max_steps_env': 1500, 'n_trajectories': 5}
    return env, spec, others
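
A sketch of how create_env's outputs might feed collection, mirroring the call pattern in Code Example #3 (the NormalPolicyFixedStd construction is an assumption carried over from that example):

env, spec, others = create_env()
env.reset()
policy = policies.NormalPolicyFixedStd(env.action_space.shape[0], std=0.5)
rewards, log_probs, masks, _ = trajectory_collector.repeat_collect_trajectories(
    env,
    policy,
    n_trajectories=others['n_trajectories'],
    env_spec=spec,
    max_steps=others['max_steps_env'])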