Code example #1
 def _make_batched_mock_gym_py_environment(self, multithreading, num_envs=3):
   self.time_step_spec = ts.time_step_spec(self.observation_spec)
   constructor = functools.partial(GymWrapperEnvironmentMock,
                                   self.observation_spec, self.action_spec)
   return batched_py_environment.BatchedPyEnvironment(
       envs=[constructor() for _ in range(num_envs)],
       multithreading=multithreading)
Code example #2
    def testBatchedPyEnvCompatible(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        actor_net = actor_network.ActorNetwork(
            self._observation_tensor_spec,
            self._action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(self._time_step_tensor_spec,
                                             self._action_tensor_spec,
                                             actor_network=actor_net)

        py_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_policy,
                                                       batch_time_steps=False)

        env_ctr = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
            self._observation_spec, self._action_spec)

        env = batched_py_environment.BatchedPyEnvironment(
            [env_ctr() for _ in range(3)])
        time_step = env.reset()

        for _ in range(20):
            action_step = py_policy.action(time_step)
            time_step = env.step(action_step.action)
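
The loop above works because BatchedPyEnvironment already returns time steps with a leading batch dimension, which is why the eager policy is created with batch_time_steps=False. A minimal, self-contained sketch of that behaviour (the spec shapes below are illustrative, assuming only that TF-Agents and NumPy are installed):

import numpy as np
from tf_agents.environments import batched_py_environment, random_py_environment
from tf_agents.specs import array_spec

observation_spec = array_spec.ArraySpec((2,), np.float32, name='observation')
action_spec = array_spec.BoundedArraySpec(
    (), np.int32, minimum=0, maximum=3, name='action')

# Three independent environments stepped together as one batch.
env = batched_py_environment.BatchedPyEnvironment(
    [random_py_environment.RandomPyEnvironment(observation_spec, action_spec)
     for _ in range(3)])
time_step = env.reset()
print(time_step.observation.shape)  # (3, 2): observations stacked on the batch axis.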
Code example #3
    def testTimeStepSpec(self, batch_py_env):
        py_env = PYEnvironmentMock()
        if batch_py_env:
            batched_py_env = batched_py_environment.BatchedPyEnvironment(
                [py_env])
            tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
        else:
            tf_env = tf_py_environment.TFPyEnvironment(py_env)
        spec = tf_env.time_step_spec()

        # step_type
        self.assertEqual(type(spec.step_type), specs.TensorSpec)
        self.assertEqual(spec.step_type.dtype, tf.int32)
        self.assertEqual(spec.step_type.shape, tf.TensorShape([]))

        # reward
        self.assertEqual(type(spec.reward), specs.TensorSpec)
        self.assertEqual(spec.reward.dtype, tf.float32)
        self.assertEqual(spec.reward.shape, tf.TensorShape([]))

        # discount
        self.assertEqual(type(spec.discount), specs.BoundedTensorSpec)
        self.assertEqual(spec.discount.dtype, tf.float32)
        self.assertEqual(spec.discount.shape, tf.TensorShape([]))
        self.assertEqual(spec.discount.minimum, 0.0)
        self.assertEqual(spec.discount.maximum, 1.0)

        # observation
        self.assertEqual(type(spec.observation), specs.TensorSpec)
Code example #4
File: tf_driver_test.py Project: Nitty12/MADRL
  def testBatchedEnvironment(self, max_steps, max_episodes, expected_length):

    expected_trajectories = [
        trajectory.Trajectory(
            step_type=np.array([0, 0]),
            observation=np.array([0, 0]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([1, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([1., 1.])),
        trajectory.Trajectory(
            step_type=np.array([1, 1]),
            observation=np.array([2, 1]),
            action=np.array([1, 2]),
            policy_info=np.array([2, 4]),
            next_step_type=np.array([2, 1]),
            reward=np.array([1., 1.]),
            discount=np.array([0., 1.])),
        trajectory.Trajectory(
            step_type=np.array([2, 1]),
            observation=np.array([3, 3]),
            action=np.array([2, 1]),
            policy_info=np.array([4, 2]),
            next_step_type=np.array([0, 2]),
            reward=np.array([0., 1.]),
            discount=np.array([1., 0.]))
    ]

    env1 = driver_test_utils.PyEnvironmentMock(final_state=3)
    env2 = driver_test_utils.PyEnvironmentMock(final_state=4)
    env = batched_py_environment.BatchedPyEnvironment([env1, env2])
    tf_env = tf_py_environment.TFPyEnvironment(env)

    policy = driver_test_utils.TFPolicyMock(
        tf_env.time_step_spec(),
        tf_env.action_spec(),
        batch_size=2,
        initial_policy_state=tf.constant([1, 2], dtype=tf.int32))

    replay_buffer_observer = MockReplayBufferObserver()

    driver = tf_driver.TFDriver(
        tf_env,
        policy,
        observers=[replay_buffer_observer],
        max_steps=max_steps,
        max_episodes=max_episodes,
    )
    initial_time_step = tf_env.reset()
    initial_policy_state = tf.constant([1, 2], dtype=tf.int32)
    self.evaluate(driver.run(initial_time_step, initial_policy_state))
    trajectories = replay_buffer_observer.gather_all()

    self.assertEqual(
        len(trajectories), len(expected_trajectories[:expected_length]))

    for t1, t2 in zip(trajectories, expected_trajectories[:expected_length]):
      for t1_field, t2_field in zip(t1, t2):
        self.assertAllEqual(t1_field, t2_field)
Code example #5
    def testBatchedFirstTimeStepAndOneStep(self):
        py_envs = [PYEnvironmentMock() for _ in range(3)]
        batched_py_env = batched_py_environment.BatchedPyEnvironment(py_envs)
        tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
        self.assertEqual(tf_env.batch_size, 3)
        time_step_0 = tf_env.current_time_step()
        time_step_0_val = self.evaluate(time_step_0)

        self.assertAllEqual([ts.StepType.FIRST] * 3, time_step_0_val.step_type)
        self.assertAllEqual([0.0] * 3, time_step_0_val.reward)
        self.assertAllEqual([1.0] * 3, time_step_0_val.discount)
        self.assertAllEqual(np.array([0, 0, 0]), time_step_0_val.observation)
        for py_env in py_envs:
            self.assertEqual([], py_env.actions_taken)
            self.assertEqual(1, py_env.resets)
            self.assertEqual(0, py_env.steps)
            self.assertEqual(0, py_env.episodes)

        time_step_1 = tf_env.step(np.array([1, 1, 1]))

        time_step_1_val = self.evaluate(time_step_1)

        self.assertAllEqual([ts.StepType.MID] * 3, time_step_1_val.step_type)
        self.assertAllEqual([0.] * 3, time_step_1_val.reward)
        self.assertAllEqual([1.0] * 3, time_step_1_val.discount)
        self.assertAllEqual(np.array([1, 1, 1]), time_step_1_val.observation)
        for py_env in py_envs:
            self.assertEqual([1], py_env.actions_taken)
            self.assertEqual(1, py_env.resets)
            self.assertEqual(1, py_env.steps)
            self.assertEqual(0, py_env.episodes)
Code example #6
 def _make_batched_py_environment(self, num_envs=3):
     self.time_step_spec = ts.time_step_spec(self.observation_spec)
     constructor = functools.partial(
         random_py_environment.RandomPyEnvironment, self.observation_spec,
         self.action_spec)
     return batched_py_environment.BatchedPyEnvironment(
         envs=[constructor() for _ in range(num_envs)])
Code example #7
File: tf_py_environment.py Project: weiddeng/agents
    def __init__(self, environment):
        """Initializes a new `TFPyEnvironment`.

    Args:
      environment: Environment to interact with, implementing
        `py_environment.PyEnvironment`.

    Raises:
      TypeError: If `environment` is not a subclass of
      `py_environment.PyEnvironment`.
    """
        if not isinstance(environment, py_environment.PyEnvironment):
            raise TypeError(
                'Environment should implement py_environment.PyEnvironment')

        if not environment.batched:
            environment = batched_py_environment.BatchedPyEnvironment(
                [environment])
        self._env = environment

        observation_spec = tensor_spec.from_spec(self._env.observation_spec())
        action_spec = tensor_spec.from_spec(self._env.action_spec())
        time_step_spec = ts.time_step_spec(observation_spec)
        batch_size = self._env.batch_size if self._env.batch_size else 1
        super(TFPyEnvironment, self).__init__(time_step_spec, action_spec,
                                              batch_size)

        # Gather all the dtypes of the elements in time_step.
        self._time_step_dtypes = [
            s.dtype for s in tf.nest.flatten(self.time_step_spec())
        ]

        self._time_step = None
        self._lock = threading.Lock()
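
As the constructor above shows, an unbatched PyEnvironment handed to TFPyEnvironment is wrapped in a single-element BatchedPyEnvironment, so the TF-side environment always exposes a batch dimension. A short hedged sketch of the observable effect (spec definitions are illustrative):

import numpy as np
from tf_agents.environments import random_py_environment, tf_py_environment
from tf_agents.specs import array_spec

observation_spec = array_spec.ArraySpec((2,), np.float32, name='observation')
action_spec = array_spec.BoundedArraySpec(
    (), np.int32, minimum=0, maximum=1, name='action')
py_env = random_py_environment.RandomPyEnvironment(observation_spec, action_spec)

# TFPyEnvironment wraps py_env in BatchedPyEnvironment([py_env]) internally.
tf_env = tf_py_environment.TFPyEnvironment(py_env)
print(tf_env.batched)     # True
print(tf_env.batch_size)  # 1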
Code example #8
 def _create_env():
   if batch_size is None:
     py_env = PYEnvironmentMock()
   else:
     py_env = [PYEnvironmentMock() for _ in range(batch_size)]
   if batch_py_env:
     py_env = batched_py_environment.BatchedPyEnvironment(
         py_env if isinstance(py_env, list) else [py_env])
   return py_env
Code example #9
 def testResetOp(self, batch_py_env):
   py_env = PYEnvironmentMock()
   if batch_py_env:
     batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
     tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
   else:
     tf_env = tf_py_environment.TFPyEnvironment(py_env)
   reset = tf_env.reset()
   self.evaluate(reset)
   self.assertEqual(1, py_env.resets)
   self.assertEqual(0, py_env.steps)
   self.assertEqual(0, py_env.episodes)
Code example #10
 def testActionSpec(self, batch_py_env):
     py_env = PYEnvironmentMock()
     if batch_py_env:
         py_env = batched_py_environment.BatchedPyEnvironment([py_env])
     tf_env = tf_py_environment.TFPyEnvironment(py_env)
     self.assertTrue(tf_env.batched)
     self.assertEqual(tf_env.batch_size, 1)
     spec = tf_env.action_spec()
     self.assertEqual(type(spec), specs.BoundedTensorSpec)
     self.assertEqual(spec.dtype, tf.int32)
     self.assertEqual(spec.shape, tf.TensorShape([]))
     self.assertEqual(spec.name, 'action')
Code example #11
 def testMultipleReset(self, batch_py_env):
   py_env = PYEnvironmentMock()
   if batch_py_env:
     batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
     tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
   else:
     tf_env = tf_py_environment.TFPyEnvironment(py_env)
   reset = tf_env.reset()
   self.evaluate(reset)
   self.assertEqual(1, py_env.resets)
   self.evaluate(reset)
   self.assertEqual(2, py_env.resets)
   self.evaluate(reset)
   self.assertEqual(3, py_env.resets)
Code example #12
  def testFirstObservationIsPreservedAfterTwoSteps(self, batch_py_env):
    py_env = PYEnvironmentMock()
    if batch_py_env:
      batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
      tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
    else:
      tf_env = tf_py_environment.TFPyEnvironment(py_env)
    time_step = tf_env.current_time_step()
    with tf.control_dependencies([time_step.step_type]):
      action = tf.constant([1])
    next_time_step = tf_env.step(action)
    with tf.control_dependencies([next_time_step.step_type]):
      action = tf.constant([2])
    _, observation = self.evaluate([tf_env.step(action), time_step.observation])

    self.assertEqual(np.array([0]), observation)
Code example #13
 def testFirstTimeStep(self, batch_py_env):
   py_env = PYEnvironmentMock()
   if batch_py_env:
     batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
     tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
   else:
     tf_env = tf_py_environment.TFPyEnvironment(py_env)
   time_step = tf_env.current_time_step()
   time_step = self.evaluate(time_step)
   self.assertAllEqual([ts.StepType.FIRST], time_step.step_type)
   self.assertAllEqual([0.0], time_step.reward)
   self.assertAllEqual([1.0], time_step.discount)
   self.assertAllEqual([0], time_step.observation)
   self.assertAllEqual([], py_env.actions_taken)
   self.assertEqual(1, py_env.resets)
   self.assertEqual(0, py_env.steps)
   self.assertEqual(0, py_env.episodes)
Code example #14
  def testMultipleReset(self, batch_py_env):
    if tf.executing_eagerly():
      self.skipTest('b/123881757')

    py_env = PYEnvironmentMock()
    if batch_py_env:
      batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
      tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
    else:
      tf_env = tf_py_environment.TFPyEnvironment(py_env)
    reset = tf_env.reset()
    self.evaluate(reset)
    self.assertEqual(1, py_env.resets)
    self.evaluate(reset)
    self.assertEqual(2, py_env.resets)
    self.evaluate(reset)
    self.assertEqual(3, py_env.resets)
Code example #15
 def __init__(self, flags):
     """Initialize runner."""
     self.num_episodes = flags['num_episodes']
     self.environment = batched_py_environment.BatchedPyEnvironment(
         [rl_env.make('Hanabi-Full', num_players=flags['players'])])
     self.agent_config = {
         'max_information_tokens':
         self.environment.envs[0].game.max_information_tokens(),
         'action_spec':
         self.environment.action_spec(),
         'observation_spec':
         self.environment.observation_spec(),
         'environment_batch_size':
         self.environment.batch_size
     }
     self.agent_1 = AGENT_CLASSES[flags['agent_1']](self.agent_config)
     self.agent_2 = AGENT_CLASSES[flags['agent_2']](self.agent_config)
Code example #16
  def testOneStep(self, batch_py_env):
    py_env = PYEnvironmentMock()
    if batch_py_env:
      batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
      tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
    else:
      tf_env = tf_py_environment.TFPyEnvironment(py_env)
    time_step = tf_env.current_time_step()
    with tf.control_dependencies([time_step.step_type]):
      action = tf.constant([1])
    time_step = self.evaluate(tf_env.step(action))

    self.assertAllEqual([ts.StepType.MID], time_step.step_type)
    self.assertAllEqual([0.], time_step.reward)
    self.assertAllEqual([1.0], time_step.discount)
    self.assertAllEqual([1], time_step.observation)
    self.assertAllEqual([1], py_env.actions_taken)
    self.assertEqual(1, py_env.resets)
    self.assertEqual(1, py_env.steps)
    self.assertEqual(0, py_env.episodes)
Code example #17
  def testTwoStepsDependenceOnTheFirst(self, batch_py_env):
    py_env = PYEnvironmentMock()
    if batch_py_env:
      batched_py_env = batched_py_environment.BatchedPyEnvironment([py_env])
      tf_env = tf_py_environment.TFPyEnvironment(batched_py_env)
    else:
      tf_env = tf_py_environment.TFPyEnvironment(py_env)
    time_step = tf_env.current_time_step()
    with tf.control_dependencies([time_step.step_type]):
      action = tf.constant([1])
    time_step = tf_env.step(action)
    with tf.control_dependencies([time_step.step_type]):
      action = tf.constant([2])
    time_step = self.evaluate(tf_env.step(action))

    self.assertEqual(ts.StepType.LAST, time_step.step_type)
    self.assertEqual([2], time_step.observation)
    self.assertEqual(1., time_step.reward)
    self.assertEqual(0., time_step.discount)
    self.assertEqual([1, 2], py_env.actions_taken)
Code example #18
  def test_metric_results_equal_with_batched_env(self):
    env_ctor = lambda: random_py_environment.RandomPyEnvironment(  # pylint: disable=g-long-lambda
        self._time_step_spec.observation, self._action_spec)
    batch_size = 5
    env = batched_py_environment.BatchedPyEnvironment(
        [env_ctor() for _ in range(batch_size)])
    tf_env = tf_py_environment.TFPyEnvironment(env)

    python_metrics, tensorflow_metrics = self._build_metrics(
        batch_size=batch_size)
    observers = python_metrics + tensorflow_metrics
    driver = dynamic_step_driver.DynamicStepDriver(
        tf_env, self._policy, observers=observers, num_steps=1000)

    self.evaluate(tf.compat.v1.global_variables_initializer())
    self.evaluate(driver.run())

    for python_metric, tensorflow_metric in zip(python_metrics,
                                                tensorflow_metrics):
      python_result = self.evaluate(python_metric.result())
      tensorflow_result = self.evaluate(tensorflow_metric.result())
      self.assertEqual(python_result, tensorflow_result)
Code example #19
def create_envs(env_name,
                use_multiprocessing,
                num_parallel_envs,
                visualize_eval=False,
                mock_train_envs=False):
    def env_load_fn(env_map_name, visualize=False, mock=False):
        env = gym_wrapper.GymWrapper(
            gym_env=SC2GymEnv(map_name=env_map_name,
                              visualize=visualize,
                              mock=mock),
            spec_dtype_map={
                gym.spaces.Box: np.float32,
                gym.spaces.Discrete: np.int32,
                gym.spaces.MultiBinary: np.float32
            },
        )
        return env

    if num_parallel_envs == 1:
        par_env = env_load_fn(env_map_name=env_name, mock=mock_train_envs)
    elif use_multiprocessing:
        par_env = parallel_py_environment.ParallelPyEnvironment(
            [lambda: env_load_fn(env_map_name=env_name, mock=mock_train_envs)
             ] * num_parallel_envs,
            start_serially=False)
    else:
        par_env = batched_py_environment.BatchedPyEnvironment(envs=[
            env_load_fn(env_map_name=env_name, mock=mock_train_envs)
            for _ in range(num_parallel_envs)
        ])
    tf_env = tf_py_environment.TFPyEnvironment(par_env)
    tf_env.reset()

    eval_env = env_load_fn(env_name, visualize=visualize_eval)
    eval_env = tf_py_environment.TFPyEnvironment(eval_env)
    eval_env.reset()

    return tf_env, eval_env
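
The branching above reflects the usual trade-off: ParallelPyEnvironment runs one subprocess per environment constructor, while BatchedPyEnvironment keeps every environment in the current process. A stand-alone illustration of the same choice (the CartPole-v0 name is only an example, assuming TF-Agents and Gym are installed):

from tf_agents.environments import (batched_py_environment,
                                    parallel_py_environment, suite_gym)

num_envs = 2
# In-process batching: environment instances live in this process.
in_process_env = batched_py_environment.BatchedPyEnvironment(
    [suite_gym.load('CartPole-v0') for _ in range(num_envs)])
# Subprocess batching: one worker process per environment constructor.
sub_process_env = parallel_py_environment.ParallelPyEnvironment(
    [lambda: suite_gym.load('CartPole-v0')] * num_envs)
print(in_process_env.batch_size, sub_process_env.batch_size)  # 2 2
sub_process_env.close()  # Shut down the worker processes.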
Code example #20
    def __init__(self, environment, check_dims=False, isolation=False):
        """Initializes a new `TFPyEnvironment`.

    Args:
      environment: Environment to interact with, implementing
        `py_environment.PyEnvironment`.  Or a `callable` that returns
        an environment of this form.  If a `callable` is provided and
        `isolation` is provided, the callable is executed in the
        dedicated thread.
      check_dims: Whether to check the batch dimensions of actions in `step`.
      isolation: If this value is `False` (default), interactions with
        the environment will occur within whatever thread the methods of the
        `TFPyEnvironment` are run from.  For example, in TF graph mode, methods
        like `step` are called from multiple threads created by the TensorFlow
        engine; calls to step the environment are guaranteed to be sequential,
        but not from the same thread.  This creates problems for environments
        that are not thread-safe.

        Using isolation ensures not only that a dedicated thread (or
        thread-pool) is used to interact with the environment, but also that
        interaction with the environment happens in a serialized manner.

        If `isolation == True`, a dedicated thread is created for
        interactions with the environment.

        If `isolation` is an instance of `multiprocessing.pool.Pool` (this
        includes instances of `multiprocessing.pool.ThreadPool`, nee
        `multiprocessing.dummy.Pool` and `multiprocessing.Pool`), then this
        pool is used to interact with the environment.

        **NOTE** If using `isolation` with a `BatchedPyEnvironment`, ensure
        you create the `BatchedPyEnvironment` with `multithreading=False`, since
        otherwise the multithreading in that wrapper reverses the effects of
        this one.

    Raises:
      TypeError: If `environment` is not an instance of
        `py_environment.PyEnvironment` or subclasses, or is a callable that does
        not return an instance of `PyEnvironment`.
      TypeError: If `isolation` is not `True`, `False`, or an instance of
        `multiprocessing.pool.Pool`.
    """
        if not isolation:
            self._pool = None
        elif isinstance(isolation, pool.Pool):
            self._pool = isolation
        elif isolation is True:  # A dedicated thread; other truthy values are rejected below.
            self._pool = pool.ThreadPool(1)
        else:
            raise TypeError(
                'isolation should be True, False, or an instance of '
                'a multiprocessing Pool or ThreadPool.  Saw: {}'.format(
                    isolation))

        if callable(environment):
            environment = self._execute(environment)
        if not isinstance(environment, py_environment.PyEnvironment):
            raise TypeError(
                'Environment should implement py_environment.PyEnvironment')

        if not environment.batched:
            # If executing in an isolated thread, do not enable multiprocessing for
            # this environment.
            environment = batched_py_environment.BatchedPyEnvironment(
                [environment], multithreading=not self._pool)
        self._env = environment
        self._check_dims = check_dims

        if isolation and getattr(self._env, '_parallel_execution', None):
            logging.warn(
                'Wrapped environment is executing in parallel.  '
                'Perhaps it is a BatchedPyEnvironment with multithreading=True, '
                'or it is a ParallelPyEnvironment.  This conflicts with the '
                '`isolation` arg passed to TFPyEnvironment: interactions with the '
                'wrapped environment are no longer guaranteed to happen in a common '
                'thread.  Environment: %s', (self._env, ))

        observation_spec = tensor_spec.from_spec(self._env.observation_spec())
        action_spec = tensor_spec.from_spec(self._env.action_spec())
        time_step_spec = ts.time_step_spec(observation_spec)
        batch_size = self._env.batch_size if self._env.batch_size else 1

        super(TFPyEnvironment, self).__init__(time_step_spec, action_spec,
                                              batch_size)

        # Gather all the dtypes of the elements in time_step.
        self._time_step_dtypes = [
            s.dtype for s in tf.nest.flatten(self.time_step_spec())
        ]

        self._time_step = None
        self._lock = threading.Lock()
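
Following the NOTE in the docstring above, here is a minimal sketch (spec definitions are illustrative) of combining `isolation` with a batched environment: the BatchedPyEnvironment is built with multithreading=False so that all environment interaction stays on the dedicated thread owned by TFPyEnvironment.

import numpy as np
from tf_agents.environments import (batched_py_environment,
                                    random_py_environment, tf_py_environment)
from tf_agents.specs import array_spec

observation_spec = array_spec.ArraySpec((2,), np.float32, name='observation')
action_spec = array_spec.BoundedArraySpec(
    (), np.int32, minimum=0, maximum=1, name='action')
envs = [random_py_environment.RandomPyEnvironment(observation_spec, action_spec)
        for _ in range(4)]

# multithreading=False: the wrapper steps its environments serially, so the
# isolation thread remains the only thread touching them.
batched_env = batched_py_environment.BatchedPyEnvironment(
    envs, multithreading=False)
tf_env = tf_py_environment.TFPyEnvironment(batched_env, isolation=True)
print(tf_env.batch_size)  # 4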
Code example #21
    def __init__(
            self,
            root_dir,
            env_name,
            num_iterations=200,
            max_episode_frames=108000,  # ALE frames
            terminal_on_life_loss=False,
            conv_layer_params=((32, (8, 8), 4), (64, (4, 4), 2), (64, (3, 3),
                                                                  1)),
            fc_layer_params=(512, ),
            # Params for collect
            initial_collect_steps=80000,  # ALE frames
            epsilon_greedy=0.01,
            epsilon_decay_period=1000000,  # ALE frames
            replay_buffer_capacity=1000000,
            # Params for train
            train_steps_per_iteration=1000000,  # ALE frames
            update_period=16,  # ALE frames
            target_update_tau=1.0,
            target_update_period=32000,  # ALE frames
            batch_size=32,
            learning_rate=2.5e-4,
            n_step_update=2,
            gamma=0.99,
            reward_scale_factor=1.0,
            gradient_clipping=None,
            # Params for eval
            do_eval=True,
            eval_steps_per_iteration=500000,  # ALE frames
            eval_epsilon_greedy=0.001,
            # Params for checkpoints, summaries, and logging
            log_interval=1000,
            summary_interval=1000,
            summaries_flush_secs=10,
            debug_summaries=True,
            summarize_grads_and_vars=True,
            eval_metrics_callback=None):
        """A simple Atari train and eval for DQN.

    Args:
      root_dir: Directory to write log files to.
      env_name: Fully-qualified name of the Atari environment (e.g. Pong-v0).
      num_iterations: Number of train/eval iterations to run.
      max_episode_frames: Maximum length of a single episode, in ALE frames.
      terminal_on_life_loss: Whether to simulate an episode termination when a
        life is lost.
      conv_layer_params: Params for convolutional layers of QNetwork.
      fc_layer_params: Params for fully connected layers of QNetwork.
      initial_collect_steps: Number of ALE frames to process before
        beginning to train. Since this is in ALE frames, there will be
        initial_collect_steps/4 items in the replay buffer when training starts.
      epsilon_greedy: Final epsilon value to decay to for training.
      epsilon_decay_period: Period over which to decay epsilon, from 1.0 to
        epsilon_greedy (defined above).
      replay_buffer_capacity: Maximum number of items to store in the replay
        buffer.
      train_steps_per_iteration: Number of ALE frames to run through for each
        iteration of training.
      update_period: Run a train operation every update_period ALE frames.
      target_update_tau: Coefficient for soft target network updates (1.0 ==
        hard updates).
      target_update_period: Period, in ALE frames, to copy the live network to
        the target network.
      batch_size: Number of frames to include in each training batch.
      learning_rate: RMS optimizer learning rate.
      n_step_update: The number of steps to consider when computing TD error and
        TD loss. Applies standard single-step updates when set to 1.
      gamma: Discount for future rewards.
      reward_scale_factor: Scaling factor for rewards.
      gradient_clipping: Norm length to clip gradients.
      do_eval: If True, run an eval every iteration. If False, skip eval.
      eval_steps_per_iteration: Number of ALE frames to run through for each
        iteration of evaluation.
      eval_epsilon_greedy: Epsilon value to use for the evaluation policy (0 ==
        totally greedy policy).
      log_interval: Log stats to the terminal every log_interval training
        steps.
      summary_interval: Write TF summaries every summary_interval training
        steps.
      summaries_flush_secs: Flush summaries to disk every summaries_flush_secs
        seconds.
      debug_summaries: If True, write additional summaries for debugging (see
        dqn_agent for which summaries are written).
      summarize_grads_and_vars: Include gradients in summaries.
      eval_metrics_callback: A callback function that takes (metric_dict,
        global_step) as parameters. Called after every eval with the results of
        the evaluation.
    """
        self._update_period = update_period / ATARI_FRAME_SKIP
        self._train_steps_per_iteration = (train_steps_per_iteration /
                                           ATARI_FRAME_SKIP)
        self._do_eval = do_eval
        self._eval_steps_per_iteration = eval_steps_per_iteration / ATARI_FRAME_SKIP
        self._eval_epsilon_greedy = eval_epsilon_greedy
        self._initial_collect_steps = initial_collect_steps / ATARI_FRAME_SKIP
        self._summary_interval = summary_interval
        self._num_iterations = num_iterations
        self._log_interval = log_interval
        self._eval_metrics_callback = eval_metrics_callback

        with gin.unlock_config():
            gin.bind_parameter(('tf_agents.environments.atari_preprocessing.'
                                'AtariPreprocessing.terminal_on_life_loss'),
                               terminal_on_life_loss)

        root_dir = os.path.expanduser(root_dir)
        train_dir = os.path.join(root_dir, 'train')
        eval_dir = os.path.join(root_dir, 'eval')

        train_summary_writer = tf.compat.v2.summary.create_file_writer(
            train_dir, flush_millis=summaries_flush_secs * 1000)
        train_summary_writer.set_as_default()
        self._train_summary_writer = train_summary_writer

        self._eval_summary_writer = None
        if self._do_eval:
            self._eval_summary_writer = tf.compat.v2.summary.create_file_writer(
                eval_dir, flush_millis=summaries_flush_secs * 1000)
            self._eval_metrics = [
                py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                               buffer_size=np.inf),
                py_metrics.AverageEpisodeLengthMetric(
                    name='PhaseAverageEpisodeLength', buffer_size=np.inf),
            ]

        self._global_step = tf.compat.v1.train.get_or_create_global_step()
        with tf.compat.v2.summary.record_if(lambda: tf.math.equal(
                self._global_step % self._summary_interval, 0)):
            self._env = suite_atari.load(
                env_name,
                max_episode_steps=max_episode_frames / ATARI_FRAME_SKIP,
                gym_env_wrappers=suite_atari.
                DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
            self._env = batched_py_environment.BatchedPyEnvironment(
                [self._env])

            observation_spec = tensor_spec.from_spec(
                self._env.observation_spec())
            time_step_spec = ts.time_step_spec(observation_spec)
            action_spec = tensor_spec.from_spec(self._env.action_spec())

            with tf.device('/cpu:0'):
                epsilon = tf.compat.v1.train.polynomial_decay(
                    1.0,
                    self._global_step,
                    epsilon_decay_period / ATARI_FRAME_SKIP /
                    self._update_period,
                    end_learning_rate=epsilon_greedy)

            with tf.device('/gpu:0'):
                optimizer = tf.compat.v1.train.RMSPropOptimizer(
                    learning_rate=learning_rate,
                    decay=0.95,
                    momentum=0.0,
                    epsilon=0.00001,
                    centered=True)
                categorical_q_net = AtariCategoricalQNetwork(
                    observation_spec,
                    action_spec,
                    conv_layer_params=conv_layer_params,
                    fc_layer_params=fc_layer_params)
                agent = categorical_dqn_agent.CategoricalDqnAgent(
                    time_step_spec,
                    action_spec,
                    categorical_q_network=categorical_q_net,
                    optimizer=optimizer,
                    epsilon_greedy=epsilon,
                    n_step_update=n_step_update,
                    target_update_tau=target_update_tau,
                    target_update_period=(target_update_period /
                                          ATARI_FRAME_SKIP /
                                          self._update_period),
                    gamma=gamma,
                    reward_scale_factor=reward_scale_factor,
                    gradient_clipping=gradient_clipping,
                    debug_summaries=debug_summaries,
                    summarize_grads_and_vars=summarize_grads_and_vars,
                    train_step_counter=self._global_step)

                self._collect_policy = py_tf_policy.PyTFPolicy(
                    agent.collect_policy)

                if self._do_eval:
                    self._eval_policy = py_tf_policy.PyTFPolicy(
                        epsilon_greedy_policy.EpsilonGreedyPolicy(
                            policy=agent.policy,
                            epsilon=self._eval_epsilon_greedy))

                py_observation_spec = self._env.observation_spec()
                py_time_step_spec = ts.time_step_spec(py_observation_spec)
                py_action_spec = policy_step.PolicyStep(
                    self._env.action_spec())
                data_spec = trajectory.from_transition(py_time_step_spec,
                                                       py_action_spec,
                                                       py_time_step_spec)
                self._replay_buffer = py_hashed_replay_buffer.PyHashedReplayBuffer(
                    data_spec=data_spec, capacity=replay_buffer_capacity)

            with tf.device('/cpu:0'):
                ds = self._replay_buffer.as_dataset(
                    sample_batch_size=batch_size, num_steps=n_step_update + 1)
                ds = ds.prefetch(4)
                ds = ds.apply(
                    tf.data.experimental.prefetch_to_device('/gpu:0'))

            with tf.device('/gpu:0'):
                self._ds_itr = tf.compat.v1.data.make_one_shot_iterator(ds)
                experience = self._ds_itr.get_next()
                self._train_op = agent.train(experience)

                self._env_steps_metric = py_metrics.EnvironmentSteps()
                self._step_metrics = [
                    py_metrics.NumberOfEpisodes(),
                    self._env_steps_metric,
                ]
                self._train_metrics = self._step_metrics + [
                    py_metrics.AverageReturnMetric(buffer_size=10),
                    py_metrics.AverageEpisodeLengthMetric(buffer_size=10),
                ]
                # The _train_phase_metrics average over an entire train iteration,
                # rather than the rolling average of the last 10 episodes.
                self._train_phase_metrics = [
                    py_metrics.AverageReturnMetric(name='PhaseAverageReturn',
                                                   buffer_size=np.inf),
                    py_metrics.AverageEpisodeLengthMetric(
                        name='PhaseAverageEpisodeLength', buffer_size=np.inf),
                ]
                self._iteration_metric = py_metrics.CounterMetric(
                    name='Iteration')

                # Summaries written from python should run every time they are
                # generated.
                with tf.compat.v2.summary.record_if(True):
                    self._steps_per_second_ph = tf.compat.v1.placeholder(
                        tf.float32, shape=(), name='steps_per_sec_ph')
                    self._steps_per_second_summary = tf.compat.v2.summary.scalar(
                        name='global_steps_per_sec',
                        data=self._steps_per_second_ph,
                        step=self._global_step)

                    for metric in self._train_metrics:
                        metric.tf_summaries(train_step=self._global_step,
                                            step_metrics=self._step_metrics)

                    for metric in self._train_phase_metrics:
                        metric.tf_summaries(
                            train_step=self._global_step,
                            step_metrics=(self._iteration_metric, ))
                    self._iteration_metric.tf_summaries(
                        train_step=self._global_step)

                    if self._do_eval:
                        with self._eval_summary_writer.as_default():
                            for metric in self._eval_metrics:
                                metric.tf_summaries(
                                    train_step=self._global_step,
                                    step_metrics=(self._iteration_metric, ))

                self._train_checkpointer = common.Checkpointer(
                    ckpt_dir=train_dir,
                    agent=agent,
                    global_step=self._global_step,
                    optimizer=optimizer,
                    metrics=metric_utils.MetricsGroup(
                        self._train_metrics + self._train_phase_metrics +
                        [self._iteration_metric], 'train_metrics'))
                self._policy_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'policy'),
                    policy=agent.policy,
                    global_step=self._global_step)
                self._rb_checkpointer = common.Checkpointer(
                    ckpt_dir=os.path.join(train_dir, 'replay_buffer'),
                    max_to_keep=1,
                    replay_buffer=self._replay_buffer)

                self._init_agent_op = agent.initialize()
Code example #22
File: oog_train_eval.py Project: landuber/agents
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    q_net = q_network.QNetwork(tensor_spec.from_spec(env.observation_spec()),
                               tensor_spec.from_spec(env.action_spec()),
                               fc_layer_params=fc_layer_params)

    # The agent must be in graph.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(ckpt_dir=os.path.join(
        train_dir, 'policy'),
                                              policy=agent.policy,
                                              global_step=global_step)

    ds = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                  num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Copy critic network values to the target critic network.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Reset timing to avoid counting eval time.
                timed_at_step = global_step_val
                start_time = time.time()
Code example #23
    def testMultiStepEpisodicReplayBuffer(self):
        num_episodes = 5
        num_driver_episodes = 5

        # Create mock environment.
        py_env = batched_py_environment.BatchedPyEnvironment([
            driver_test_utils.PyEnvironmentMock(final_state=i + 1)
            for i in range(num_episodes)
        ])
        env = tf_py_environment.TFPyEnvironment(py_env)

        # Create mock policy.
        policy = driver_test_utils.TFPolicyMock(env.time_step_spec(),
                                                env.action_spec(),
                                                batch_size=num_episodes)

        # Create replay buffer and driver.
        replay_buffer = self._make_replay_buffer(env)
        stateful_buffer = episodic_replay_buffer.StatefulEpisodicReplayBuffer(
            replay_buffer, num_episodes)
        driver = dynamic_episode_driver.DynamicEpisodeDriver(
            env,
            policy,
            num_episodes=num_driver_episodes,
            observers=[stateful_buffer.add_batch])

        run_driver = driver.run()

        end_episodes = replay_buffer._maybe_end_batch_episodes(
            stateful_buffer.episode_ids, end_episode=True)

        completed_episodes = replay_buffer._completed_episodes()

        self.evaluate([
            tf.compat.v1.local_variables_initializer(),
            tf.compat.v1.global_variables_initializer()
        ])

        self.evaluate(run_driver)

        self.evaluate(end_episodes)
        completed_episodes = self.evaluate(completed_episodes)
        eps = [replay_buffer._get_episode(ep) for ep in completed_episodes]
        eps = self.evaluate(eps)

        episodes_length = [tf.nest.flatten(ep)[0].shape[0] for ep in eps]

        # Compare with expected output.
        self.assertAllEqual(completed_episodes, [3, 4, 5, 6, 7])
        self.assertAllEqual(episodes_length, [4, 4, 2, 1, 1])

        first = ts.StepType.FIRST
        mid = ts.StepType.MID
        last = ts.StepType.LAST

        step_types = [ep.step_type for ep in eps]
        observations = [ep.observation for ep in eps]
        rewards = [ep.reward for ep in eps]
        actions = [ep.action for ep in eps]

        self.assertAllClose([[first, mid, mid, last], [first, mid, mid, mid],
                             [first, last], [first], [first]], step_types)

        self.assertAllClose([
            [0, 1, 3, 4],
            [0, 1, 3, 4],
            [0, 1],
            [0],
            [0],
        ], observations)

        self.assertAllClose([
            [1, 2, 1, 2],
            [1, 2, 1, 2],
            [1, 2],
            [1],
            [1],
        ], actions)

        self.assertAllClose([
            [1, 1, 1, 0],
            [1, 1, 1, 1],
            [1, 0],
            [1],
            [1],
        ], rewards)