Example #1
  def test_with_normal_context_and_normal_reward(self):

    def _context_sampling_fn():
      return np.random.normal(0, 3, [1, 2])

    def _reward_fn(x):
      return np.random.normal(2 * x[0], abs(x[1]) + 1)

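    # Single-arm stationary stochastic bandit: Gaussian context; reward ~ N(2 * x[0], |x[1]| + 1).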
    env = sspe.StationaryStochasticPyEnvironment(_context_sampling_fn,
                                                 [_reward_fn])
    time_step_spec = env.time_step_spec()
    action_spec = env.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    for _ in range(5):
      time_step = env.reset()
      self.assertTrue(
          check_unbatched_time_step_spec(
              time_step=time_step,
              time_step_spec=time_step_spec,
              batch_size=env.batch_size))

      action = random_policy.action(time_step).action
      time_step = env.step(action)
Example #2
    def testMetricIsComputedCorrectly(self):
        def reward_fn(*unused_args):
            reward = np.random.uniform()
            reward_fn.total_reward += reward
            return reward

        reward_fn.total_reward = 0

        action_spec = array_spec.BoundedArraySpec((1, ), np.int32, -10, 10)
        observation_spec = array_spec.BoundedArraySpec((1, ), np.int32, -10,
                                                       10)
        env = random_py_environment.RandomPyEnvironment(observation_spec,
                                                        action_spec,
                                                        reward_fn=reward_fn)
        policy = random_py_policy.RandomPyPolicy(time_step_spec=None,
                                                 action_spec=action_spec)

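        # AverageReturnMetric averages the episode returns; the test compares it to reward_fn's running total.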
        average_return = py_metrics.AverageReturnMetric()

        num_episodes = 10
        results = metric_utils.compute([average_return], env, policy,
                                       num_episodes)
        self.assertAlmostEqual(reward_fn.total_reward / num_episodes,
                               results[average_return.name],
                               places=5)
Example #3
    def test_with_uniform_context_and_normal_mu_reward(self):
        def _context_sampling_fn():
            return np.random.randint(-10, 10, [1, 4])

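        # Three arms, one LinearNormalReward per weight vector theta.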
        reward_fns = [
            LinearNormalReward(theta)
            for theta in ([0, 1, 2, 3], [3, 2, 1, 0], [-1, -2, -3, -4])
        ]

        env = sspe.StationaryStochasticPyEnvironment(_context_sampling_fn,
                                                     reward_fns)
        time_step_spec = env.time_step_spec()
        action_spec = env.action_spec()

        random_policy = random_py_policy.RandomPyPolicy(
            time_step_spec=time_step_spec, action_spec=action_spec)

        for _ in range(5):
            time_step = env.reset()
            self.assertTrue(
                check_unbatched_time_step_spec(time_step=time_step,
                                               time_step_spec=time_step_spec,
                                               batch_size=env.batch_size))

            action = random_policy.action(time_step).action
            time_step = env.step(action)
Example #4
def validate_py_environment(
    environment: py_environment.PyEnvironment,
    episodes: int = 5,
    observation_and_action_constraint_splitter: Optional[
        types.Splitter] = None):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec,
        action_spec=action_spec,
        observation_and_action_constraint_splitter=(
            observation_and_action_constraint_splitter))

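    # Batched environments emit time steps with an extra outer batch dimension, so expand the spec accordingly.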
    if environment.batch_size is not None:
        batched_time_step_spec = array_spec.add_outer_dims_nest(
            time_step_spec, outer_dims=(environment.batch_size, ))
    else:
        batched_time_step_spec = time_step_spec

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, batched_time_step_spec):
            raise ValueError('Given `time_step`: %r does not match expected '
                             '`time_step_spec`: %r' %
                             (time_step, batched_time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        episode_count += np.sum(time_step.is_last())
Example #5
def collect(summary_dir: Text,
            environment_name: Text,
            collect_policy: py_tf_eager_policy.PyTFEagerPolicyBase,
            replay_buffer_server_address: Text,
            variable_container_server_address: Text,
            suite_load_fn: Callable[
                [Text], py_environment.PyEnvironment] = suite_mujoco.load,
            initial_collect_steps: int = 10000,
            max_train_steps: int = 2000000) -> None:
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now support only single environment collection.
  collect_env = suite_load_fn(environment_name)

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

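  # Seed the replay buffer with initial_collect_steps of experience from a uniform random policy.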
  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=summary_dir,
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  while train_step.numpy() < max_train_steps:
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
Example #6
def create_random_gif():
    """Create a gif showing a random policy."""
    env_params = {
        'monster_speed': 0.7,
        'timeout_factor': 20,
        'step_size': 0.05,
        'n_actions': 8
    }
    py_env = LakeMonsterEnvironment(**env_params)
    policy = random_py_policy.RandomPyPolicy(time_step_spec=None,
                                             action_spec=py_env.action_spec())

    save_path = os.path.join(configs.ASSETS_DIR, 'random.gif')
    episode_as_gif(py_env, policy, save_path=save_path)
Example #7
    def testGeneratesActions(self):
        action_spec = [
            array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10),
            array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
        ]
        policy = random_py_policy.RandomPyPolicy(time_step_spec=None,
                                                 action_spec=action_spec)

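        # With time_step_spec=None the policy ignores the time step and samples uniformly within the spec bounds.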
        action_step = policy.action(None)
        tf.nest.assert_same_structure(action_spec, action_step.action)

        self.assertTrue(np.all(action_step.action[0] >= -10))
        self.assertTrue(np.all(action_step.action[0] <= 10))
        self.assertTrue(np.all(action_step.action[1] >= -10))
        self.assertTrue(np.all(action_step.action[1] <= 10))
Example #8
 def _initial_collect(self):
   """Collect initial experience before training begins."""
   logging.info('Collecting initial experience...')
   time_step_spec = ts.time_step_spec(self._env.observation_spec())
   random_policy = random_py_policy.RandomPyPolicy(
       time_step_spec, self._env.action_spec())
   time_step = self._env.reset()
   while self._replay_buffer.size < self._initial_collect_steps:
     if self.game_over():
       time_step = self._env.reset()
     action_step = random_policy.action(time_step)
     next_time_step = self._env.step(action_step.action)
     self._replay_buffer.add_batch(trajectory.from_transition(
         time_step, action_step, next_time_step))
     time_step = next_time_step
   logging.info('Done.')
Example #9
    def test_with_random_policy(self):
        def _global_context_sampling_fn():
            abc = np.array(['a', 'b', 'c'])
            return {
                'global1': np.random.randint(-2, 3, [3, 4]),
                'global2': abc[np.random.randint(0, 2, [1])]
            }

        def _arm_context_sampling_fn():
            aabbcc = np.array(['aa', 'bb', 'cc'])
            return {
                'arm1': np.random.randint(-3, 4, [5]),
                'arm2': np.random.randint(-3, 4, [3, 1]),
                'arm3': aabbcc[np.random.randint(0, 2, [1])]
            }

        def _reward_fn(global_obs, arm_obs):
            return global_obs['global1'][2, 1] + arm_obs['arm1'][4]

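        # Structured per-arm environment: 6 arms, batch size 2; the reward combines a global and an arm feature.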
        env = ssspe.StationaryStochasticStructuredPyEnvironment(
            _global_context_sampling_fn,
            _arm_context_sampling_fn,
            6,
            _reward_fn,
            batch_size=2)
        time_step_spec = env.time_step_spec()
        action_spec = array_spec.BoundedArraySpec(shape=(),
                                                  minimum=0,
                                                  maximum=5,
                                                  dtype=np.int32)

        random_policy = random_py_policy.RandomPyPolicy(
            time_step_spec=time_step_spec, action_spec=action_spec)

        for _ in range(5):
            time_step = env.reset()
            self.assertTrue(
                check_unbatched_time_step_spec(time_step=time_step,
                                               time_step_spec=time_step_spec,
                                               batch_size=env.batch_size))

            action = random_policy.action(time_step).action
            self.assertAllEqual(action.shape, [2])
            self.assertAllGreaterEqual(action, 0)
            self.assertAllLess(action, 6)
            time_step = env.step(action)
            self.assertEqual(time_step.reward.shape, (2, ))
Example #10
  def testGeneratesBatchedActions(self):
    action_spec = [
        array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10),
        array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
    ]
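    # outer_dims=(3,) makes the policy emit a batch of 3 actions per call.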
    policy = random_py_policy.RandomPyPolicy(
        time_step_spec=None, action_spec=action_spec, outer_dims=(3,))

    action_step = policy.action(None)
    nest.assert_same_structure(action_spec, action_step.action)
    self.assertEqual((3, 2, 3), action_step.action[0].shape)
    self.assertEqual((3, 1, 2), action_step.action[1].shape)

    self.assertTrue(np.all(action_step.action[0] >= -10))
    self.assertTrue(np.all(action_step.action[0] <= 10))
    self.assertTrue(np.all(action_step.action[1] >= -10))
    self.assertTrue(np.all(action_step.action[1] <= 10))
Example #11
 def _insert_random_data(self,
                         env,
                         num_steps,
                         sequence_length=2,
                         additional_observers=None):
   """Insert `num_step` random observations into Reverb server."""
   observers = [] if additional_observers is None else additional_observers
   traj_obs = reverb_utils.ReverbAddTrajectoryObserver(
       self._py_client, self._table_name, sequence_length=sequence_length)
   observers.append(traj_obs)
   policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                            env.action_spec())
   driver = py_driver.PyDriver(env,
                               policy,
                               observers=observers,
                               max_steps=num_steps)
   time_step = env.reset()
   driver.run(time_step)
   traj_obs.close()
Example #12
    def test_with_variable_num_actions_masking(self):
        def _global_context_sampling_fn():
            return np.random.randint(-10, 10, [4])

        def _arm_context_sampling_fn():
            return np.random.randint(-2, 3, [5])

        def _num_actions_fn():
            return np.random.randint(0, 7)

        reward_fn = LinearNormalReward([0, 1, 2, 3, 4, 5, 6, 7, 8])

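        # The number of available actions varies per step; with add_num_actions_feature=False the count is not appended to the observation.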
        env = sspe.StationaryStochasticPerArmPyEnvironment(
            _global_context_sampling_fn,
            _arm_context_sampling_fn,
            6,
            reward_fn,
            _num_actions_fn,
            batch_size=2,
            add_num_actions_feature=False)
        time_step_spec = env.time_step_spec()
        self.assertAllEqual(time_step_spec.observation[1].shape, [6])
        action_spec = array_spec.BoundedArraySpec(shape=(),
                                                  minimum=0,
                                                  maximum=5,
                                                  dtype=np.int32)

        random_policy = random_py_policy.RandomPyPolicy(
            time_step_spec=time_step_spec, action_spec=action_spec)

        for _ in range(5):
            time_step = env.reset()
            self.assertTrue(
                check_unbatched_time_step_spec(time_step=time_step,
                                               time_step_spec=time_step_spec,
                                               batch_size=env.batch_size))

            action = random_policy.action(time_step).action
            self.assertAllEqual(action.shape, [2])
            self.assertAllGreaterEqual(action, 0)
            self.assertAllLess(action, 6)
            time_step = env.step(action)
Example #13
    def testGeneratesBatchedActionsWithoutSpecifyingOuterDims(self):
        action_spec = [
            array_spec.BoundedArraySpec((2, 3), np.int32, -10, 10),
            array_spec.BoundedArraySpec((1, 2), np.int32, -10, 10)
        ]
        time_step_spec = time_step.time_step_spec(
            observation_spec=array_spec.ArraySpec((1, ), np.int32))
        policy = random_py_policy.RandomPyPolicy(time_step_spec=time_step_spec,
                                                 action_spec=action_spec)

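        # The outer batch dimension (3) is inferred from the batched observation in the time step.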
        action_step = policy.action(
            time_step.restart(np.array([[1], [2], [3]], dtype=np.int32)))
        tf.nest.assert_same_structure(action_spec, action_step.action)
        self.assertEqual((3, 2, 3), action_step.action[0].shape)
        self.assertEqual((3, 1, 2), action_step.action[1].shape)

        self.assertTrue(np.all(action_step.action[0] >= -10))
        self.assertTrue(np.all(action_step.action[0] <= 10))
        self.assertTrue(np.all(action_step.action[1] >= -10))
        self.assertTrue(np.all(action_step.action[1] <= 10))
Example #14
  def testRandomPyPolicyGeneratesActionTensors(self):
    array_action_spec = array_spec.BoundedArraySpec((7,), np.int32, -10, 10)
    observation = tf.ones([3], tf.float32)
    time_step = ts.restart(observation)

    observation_spec = tensor_spec.TensorSpec.from_tensor(observation)
    time_step_spec = ts.time_step_spec(observation_spec)

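    # TFPyPolicy wraps the python policy so it can be driven with TF tensors.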
    tf_py_random_policy = tf_py_policy.TFPyPolicy(
        random_py_policy.RandomPyPolicy(time_step_spec=time_step_spec,
                                        action_spec=array_action_spec))

    batched_time_step = nest_utils.batch_nested_tensors(time_step)
    action_step = tf_py_random_policy.action(time_step=batched_time_step)
    action, new_policy_state = self.evaluate(
        [action_step.action, action_step.state])

    self.assertEqual((1,) + array_action_spec.shape, action.shape)
    self.assertTrue(np.all(action >= array_action_spec.minimum))
    self.assertTrue(np.all(action <= array_action_spec.maximum))
    self.assertEqual(new_policy_state, ())
Example #15
def _create_collect_actor(
        collect_env: YGOEnvironment, collect_policy: PyTFEagerPolicy,
        train_step, rb_observer: ReverbAddTrajectoryObserver) -> actor.Actor:

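    # Warm up the replay buffer with random episodes before building the on-policy collect actor.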
    initial_collect_actor = actor.Actor(
        collect_env,
        random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                        collect_env.action_spec()),
        train_step,
        episodes_per_run=_initial_collect_episodes,
        observers=[rb_observer])
    initial_collect_actor.run()

    return actor.Actor(collect_env,
                       collect_policy,
                       train_step,
                       episodes_per_run=1,
                       metrics=actor.collect_metrics(10),
                       summary_dir=os.path.join(tempdir, learner.TRAIN_DIR),
                       observers=[rb_observer,
                                  py_metrics.EnvironmentSteps()])
Example #16
def validate_py_environment(environment, episodes=5):
    """Validates the environment follows the defined specs."""
    time_step_spec = environment.time_step_spec()
    action_spec = environment.action_spec()

    random_policy = random_py_policy.RandomPyPolicy(
        time_step_spec=time_step_spec, action_spec=action_spec)

    episode_count = 0
    time_step = environment.reset()

    while episode_count < episodes:
        if not array_spec.check_arrays_nest(time_step, time_step_spec):
            raise ValueError(
                'Given `time_step`: %r does not match expected '
                '`time_step_spec`: %r' % (time_step, time_step_spec))

        action = random_policy.action(time_step).action
        time_step = environment.step(action)

        if time_step.is_last():
            episode_count += 1
Example #17
def profile_env(env_str, max_ep_len, n_steps=None, env_wrappers=()):
  n_steps = n_steps or max_ep_len * 2
  profile = [None]

  def profile_fn(p):
    assert isinstance(p, cProfile.Profile)
    profile[0] = p

  py_env = suite_gym.load(env_str, gym_env_wrappers=env_wrappers,
                          max_episode_steps=max_ep_len)
  env = wrappers.PerformanceProfiler(
    py_env, process_profile_fn=profile_fn,
    process_steps=n_steps)
  policy = random_py_policy.RandomPyPolicy(
      env.time_step_spec(), env.action_spec())

  driver = py_driver.PyDriver(env, policy, [], max_steps=n_steps)
  time_step = env.reset()
  policy_state = policy.get_initial_state()
  for _ in range(n_steps):
    time_step, policy_state = driver.run(time_step, policy_state)
  stats = pstats.Stats(profile[0])
  stats.print_stats()
Example #18
    def testPyPolicyIsBatchedTrue(self):
        action_dims = 5
        observation_dims = 3
        batch_size = 2
        array_action_spec = array_spec.BoundedArraySpec((action_dims, ),
                                                        np.int32, -10, 10)
        observation_spec = array_spec.ArraySpec((observation_dims, ),
                                                np.float32)
        array_time_step_spec = ts.time_step_spec(observation_spec)

        observation = tf.ones([batch_size, observation_dims], tf.float32)
        time_step = ts.restart(observation, batch_size=batch_size)

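        # py_policy_is_batched=True tells the wrapper that the python policy already handles batched time steps.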
        tf_py_random_policy = tf_py_policy.TFPyPolicy(
            random_py_policy.RandomPyPolicy(
                time_step_spec=array_time_step_spec,
                action_spec=array_action_spec),
            py_policy_is_batched=True)

        action_step = tf_py_random_policy.action(time_step=time_step)
        action = self.evaluate(action_step.action)

        self.assertEqual(action.shape, (batch_size, action_dims))
Example #19
  def testRandomPyPolicyGeneratesActionTensors(self):
    if tf.executing_eagerly():
      self.skipTest('b/123935604')

    py_action_spec = array_spec.BoundedArraySpec((7,), np.int32, -10, 10)

    observation = tf.ones([3], tf.float32)
    time_step = ts.restart(observation)
    observation_spec = tensor_spec.TensorSpec.from_tensor(observation)
    time_step_spec = ts.time_step_spec(observation_spec)

    tf_py_random_policy = tf_py_policy.TFPyPolicy(
        random_py_policy.RandomPyPolicy(time_step_spec=time_step_spec,
                                        action_spec=py_action_spec))

    action_step = tf_py_random_policy.action(time_step=time_step)
    py_action, py_new_policy_state = self.evaluate(
        [action_step.action, action_step.state])

    self.assertEqual(py_action.shape, py_action_spec.shape)
    self.assertTrue(np.all(py_action >= py_action_spec.minimum))
    self.assertTrue(np.all(py_action <= py_action_spec.maximum))
    self.assertEqual(py_new_policy_state, ())
Example #20
    def testMasking(self):
        batch_size = 1000

        time_step_spec = time_step.time_step_spec(
            observation_spec=array_spec.ArraySpec((1, ), np.int32))
        action_spec = array_spec.BoundedArraySpec((), np.int64, -5, 5)

        # We create a fixed mask here for testing purposes. Normally the mask would
        # be part of the observation.
        mask = [0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0]
        np_mask = np.array(mask)
        batched_mask = np.array([mask for _ in range(batch_size)])

        policy = random_py_policy.RandomPyPolicy(
            time_step_spec=time_step_spec,
            action_spec=action_spec,
            observation_and_action_constraint_splitter=(
                lambda obs: (obs, batched_mask)))

        my_time_step = time_step.restart(time_step_spec, batch_size)
        action_step = policy.action(my_time_step)
        tf.nest.assert_same_structure(action_spec, action_step.action)

        # Sample from the policy 1000 times, and ensure that actions considered
        # invalid according to the mask are never chosen.
        action_ = self.evaluate(action_step.action)
        self.assertTrue(np.all(action_ >= -5))
        self.assertTrue(np.all(action_ <= 5))
        self.assertAllEqual(np_mask[action_ - action_spec.minimum],
                            np.ones([batch_size]))

        # Ensure that all valid actions occur somewhere within the batch. Because we
        # sample 1000 times, the chance of this failing for any particular action is
        # (2/3)^1000, roughly 1e-176.
        for index in range(action_spec.minimum, action_spec.maximum + 1):
            if np_mask[index - action_spec.minimum]:
                self.assertIn(index, action_)
Example #21
 def testPolicyStateSpecIsEmpty(self):
     policy = random_py_policy.RandomPyPolicy(time_step_spec=None,
                                              action_spec=[])
     self.assertEqual(policy.policy_state_spec, ())
Example #22
def train_eval(
        root_dir,
        # Dataset params
        env_name,
        data_dir=None,
        load_pretrained=False,
        pretrained_model_dir=None,
        img_pad=4,
        frame_shape=(84, 84, 3),
        frame_stack=3,
        num_augmentations=2,  # K and M in DrQ
        # Training params
        contrastive_loss_weight=1.0,
        contrastive_loss_temperature=0.5,
        image_encoder_representation=True,
        initial_collect_steps=1000,
        num_train_steps=3000000,
        actor_fc_layers=(1024, 1024),
        critic_joint_fc_layers=(1024, 1024),
        # Agent params
        batch_size=256,
        actor_learning_rate=1e-3,
        critic_learning_rate=1e-3,
        alpha_learning_rate=1e-3,
        encoder_learning_rate=1e-3,
        actor_update_freq=2,
        gamma=0.99,
        target_update_tau=0.01,
        target_update_period=2,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        checkpoint_interval=10000,
        policy_save_interval=5000,
        eval_interval=10000,
        summary_interval=250,
        debug_summaries=False,
        eval_episodes_per_run=10,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    collect_env = env_utils.load_dm_env_for_training(env_name,
                                                     frame_shape,
                                                     frame_stack=frame_stack)
    eval_env = env_utils.load_dm_env_for_eval(env_name,
                                              frame_shape,
                                              frame_stack=frame_stack)

    logging.info('Data directory: %s', data_dir)
    logging.info('Num train steps: %d', num_train_steps)
    logging.info('Contrastive loss coeff: %.2f', contrastive_loss_weight)
    logging.info('Contrastive loss temperature: %.4f',
                 contrastive_loss_temperature)
    logging.info('load_pretrained: %s', 'yes' if load_pretrained else 'no')
    logging.info('encoder representation: %s',
                 'yes' if image_encoder_representation else 'no')

    load_episode_data = (contrastive_loss_weight > 0)
    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()
    image_encoder = networks.ImageEncoder(observation_tensor_spec)

    actor_net = model_utils.Actor(
        observation_tensor_spec,
        action_tensor_spec,
        image_encoder=image_encoder,
        fc_layers=actor_fc_layers,
        image_encoder_representation=image_encoder_representation)

    critic_net = networks.Critic((observation_tensor_spec, action_tensor_spec),
                                 image_encoder=image_encoder,
                                 joint_fc_layers=critic_joint_fc_layers)
    critic_net_2 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=image_encoder,
        joint_fc_layers=critic_joint_fc_layers)

    target_image_encoder = networks.ImageEncoder(observation_tensor_spec)
    target_critic_net_1 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=target_image_encoder)
    target_critic_net_2 = networks.Critic(
        (observation_tensor_spec, action_tensor_spec),
        image_encoder=target_image_encoder)

    agent = pse_drq_agent.DrQSacModifiedAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        critic_network_2=critic_net_2,
        target_critic_network=target_critic_net_1,
        target_critic_network_2=target_critic_net_2,
        actor_update_frequency=actor_update_freq,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        contrastive_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=encoder_learning_rate),
        contrastive_loss_weight=contrastive_loss_weight,
        contrastive_loss_temperature=contrastive_loss_temperature,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        use_log_alpha_in_alpha_loss=False,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step,
        num_augmentations=num_augmentations)
    agent.initialize()

    # Setup the replay buffer.
    reverb_replay, rb_observer = (
        replay_buffer_utils.get_reverb_buffer_and_observer(
            agent.collect_data_spec,
            sequence_length=2,
            replay_capacity=replay_capacity,
            port=reverb_port))

    # pylint: disable=g-long-lambda
    if num_augmentations == 0:
        image_aug = lambda traj, meta: (dict(
            experience=traj, augmented_obs=[], augmented_next_obs=[]), meta)
    else:
        image_aug = lambda traj, meta: pse_drq_agent.image_aug(
            traj, meta, img_pad, num_augmentations)
    augmented_dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                                 num_steps=2).unbatch().map(
                                                     image_aug,
                                                     num_parallel_calls=3)
    augmented_iterator = iter(augmented_dataset)

    trajs = augmented_dataset.batch(batch_size).prefetch(50)
    if load_episode_data:
        # Load full episodes and zip them
        episodes = dataset_utils.load_episodes(
            os.path.join(data_dir, 'episodes2'), img_pad)
        episode_iterator = iter(episodes)
        dataset = tf.data.Dataset.zip((trajs, episodes)).prefetch(10)
    else:
        dataset = trajs
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    learning_triggers = [
        triggers.PolicySavedModelTrigger(saved_model_dir,
                                         agent,
                                         train_step,
                                         interval=policy_save_interval),
        triggers.StepPerSecondLogTrigger(train_step,
                                         interval=summary_interval),
    ]

    agent_learner = model_utils.Learner(
        root_dir,
        train_step,
        agent,
        experience_dataset_fn=experience_dataset_fn,
        triggers=learning_triggers,
        checkpoint_interval=checkpoint_interval,
        summary_interval=summary_interval,
        load_episode_data=load_episode_data,
        use_kwargs_in_agent_train=True,
        # Turn off the initialization of the optimizer variables since the agent
        # expects different batching for the `training_data_spec` and
        # `train_argspec` which can't be handled in general by the initialization
        # logic in the learner.
        run_optimizer_variable_init=False)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the replay buffer with some experience.
    train_dir = os.path.join(root_dir, learner.TRAIN_DIR)

    # Code for loading pretrained policy.
    if load_pretrained:
        # Note that num_train_steps is the same as the max train step at which
        # we want to load the pretrained policy for our experiments.
        pretrained_policy = model_utils.load_pretrained_policy(
            pretrained_model_dir, num_train_steps)
        initial_collect_policy = pretrained_policy

        agent.policy.update_partial(pretrained_policy)
        agent.collect_policy.update_partial(pretrained_policy)
        logging.info('Restored pretrained policy.')
    else:
        initial_collect_policy = random_py_policy.RandomPyPolicy(
            collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        initial_collect_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                observers=[rb_observer],
                                metrics=actor.collect_metrics(buffer_size=10),
                                summary_dir=train_dir,
                                summary_interval=summary_interval,
                                name='CollectActor')

    # If restarting with train_step > 0, the replay buffer will be empty
    # except for random experience. Populate the buffer with some on-policy
    # experience.
    if load_pretrained or (agent_learner.train_step_numpy > 0):
        for _ in range(batch_size * 50):
            collect_actor.run()

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(eval_env,
                             greedy_policy,
                             train_step,
                             episodes_per_run=eval_episodes_per_run,
                             metrics=actor.eval_metrics(buffer_size=10),
                             summary_dir=os.path.join(root_dir, 'eval'),
                             summary_interval=-1,
                             name='EvalTrainActor')

    if eval_interval:
        logging.info('Evaluating.')
        img_summary(
            next(augmented_iterator)[0], eval_actor.summary_writer, train_step)
        if load_episode_data:
            contrastive_img_summary(next(episode_iterator), agent,
                                    eval_actor.summary_writer, train_step)
        eval_actor.run_and_log()

    logging.info('Saving operative gin config file.')
    gin_path = os.path.join(train_dir, 'train_operative_gin_config.txt')
    with tf.io.gfile.GFile(gin_path, mode='w') as f:
        f.write(gin.operative_config_str())

    logging.info('Training starting at: %r', train_step.numpy())
    while train_step < num_train_steps:
        collect_actor.run()
        agent_learner.run(iterations=1)
        if (not eval_interval) and (train_step % 10000 == 0):
            img_summary(
                next(augmented_iterator)[0],
                agent_learner.train_summary_writer, train_step)
        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            img_summary(
                next(augmented_iterator)[0], eval_actor.summary_writer,
                train_step)
            if load_episode_data:
                contrastive_img_summary(next(episode_iterator), agent,
                                        eval_actor.summary_writer, train_step)
            eval_actor.run_and_log()
Example #23
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Params for collect
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        epsilon_greedy=0.1,
        replay_buffer_capacity=100000,
        # Params for target update
        target_update_tau=0.05,
        target_update_period=5,
        # Params for train
        train_steps_per_iteration=1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        reward_scale_factor=1.0,
        gradient_clipping=None,
        # Params for eval
        num_eval_episodes=10,
        eval_interval=1000,
        # Params for checkpoints, summaries and logging
        train_checkpoint_interval=10000,
        policy_checkpoint_interval=5000,
        log_interval=1000,
        summaries_flush_secs=10,
        debug_summaries=False,
        summarize_grads_and_vars=False,
        eval_metrics_callback=None):
    """A simple train and eval for DQN."""
    root_dir = os.path.expanduser(root_dir)
    train_dir = os.path.join(root_dir, 'train')
    eval_dir = os.path.join(root_dir, 'eval')

    train_summary_writer = tf.compat.v2.summary.create_file_writer(
        train_dir, flush_millis=summaries_flush_secs * 1000)
    train_summary_writer.set_as_default()

    eval_summary_writer = tf.compat.v2.summary.create_file_writer(
        eval_dir, flush_millis=summaries_flush_secs * 1000)
    eval_metrics = [
        py_metrics.AverageReturnMetric(buffer_size=num_eval_episodes),
        py_metrics.AverageEpisodeLengthMetric(buffer_size=num_eval_episodes),
    ]

    # Note this is a python environment.
    env = batched_py_environment.BatchedPyEnvironment(
        [suite_gym.load(env_name)])
    eval_py_env = suite_gym.load(env_name)

    # Convert specs to BoundedTensorSpec.
    action_spec = tensor_spec.from_spec(env.action_spec())
    observation_spec = tensor_spec.from_spec(env.observation_spec())
    time_step_spec = ts.time_step_spec(observation_spec)

    q_net = q_network.QNetwork(tensor_spec.from_spec(env.observation_spec()),
                               tensor_spec.from_spec(env.action_spec()),
                               fc_layer_params=fc_layer_params)

    # The agent must be built in graph mode.
    global_step = tf.compat.v1.train.get_or_create_global_step()
    agent = dqn_agent.DqnAgent(
        time_step_spec,
        action_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=learning_rate),
        td_errors_loss_fn=dqn_agent.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=gradient_clipping,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=global_step)

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_policy.PyTFPolicy(tf_collect_policy)
    greedy_policy = py_tf_policy.PyTFPolicy(agent.policy)
    random_policy = random_py_policy.RandomPyPolicy(env.time_step_spec(),
                                                    env.action_spec())

    # Python replay buffer.
    replay_buffer = py_uniform_replay_buffer.PyUniformReplayBuffer(
        capacity=replay_buffer_capacity,
        data_spec=tensor_spec.to_nest_array_spec(agent.collect_data_spec))

    time_step = env.reset()

    # Initialize the replay buffer with some transitions. We use the random
    # policy to initialize the replay buffer to make sure we get a good
    # distribution of actions.
    for _ in range(initial_collect_steps):
        time_step = collect_step(env, time_step, random_policy, replay_buffer)

    # TODO(b/112041045) Use global_step as counter.
    train_checkpointer = common.Checkpointer(ckpt_dir=train_dir,
                                             agent=agent,
                                             global_step=global_step)

    policy_checkpointer = common.Checkpointer(
        ckpt_dir=os.path.join(train_dir, 'policy'),
        policy=agent.policy,
        global_step=global_step)

    ds = replay_buffer.as_dataset(sample_batch_size=batch_size,
                                  num_steps=n_step_update + 1)
    ds = ds.prefetch(4)
    itr = tf.compat.v1.data.make_initializable_iterator(ds)

    experience = itr.get_next()

    train_op = common.function(agent.train)(experience)

    with eval_summary_writer.as_default(), \
         tf.compat.v2.summary.record_if(True):
        for eval_metric in eval_metrics:
            eval_metric.tf_summaries(train_step=global_step)

    with tf.compat.v1.Session() as session:
        train_checkpointer.initialize_or_restore(session)
        common.initialize_uninitialized_variables(session)
        session.run(itr.initializer)
        # Copy critic network values to the target critic network.
        session.run(agent.initialize())
        train = session.make_callable(train_op)
        global_step_call = session.make_callable(global_step)
        session.run(train_summary_writer.init())
        session.run(eval_summary_writer.init())

        # Compute initial evaluation metrics.
        global_step_val = global_step_call()
        metric_utils.compute_summaries(
            eval_metrics,
            eval_py_env,
            greedy_policy,
            num_episodes=num_eval_episodes,
            global_step=global_step_val,
            log=True,
            callback=eval_metrics_callback,
        )

        timed_at_step = global_step_val
        collect_time = 0
        train_time = 0
        steps_per_second_ph = tf.compat.v1.placeholder(tf.float32,
                                                       shape=(),
                                                       name='steps_per_sec_ph')
        steps_per_second_summary = tf.compat.v2.summary.scalar(
            name='global_steps_per_sec',
            data=steps_per_second_ph,
            step=global_step)

        for _ in range(num_iterations):
            start_time = time.time()
            for _ in range(collect_steps_per_iteration):
                time_step = collect_step(env, time_step, collect_policy,
                                         replay_buffer)
            collect_time += time.time() - start_time
            start_time = time.time()
            for _ in range(train_steps_per_iteration):
                loss = train()
            train_time += time.time() - start_time
            global_step_val = global_step_call()
            if global_step_val % log_interval == 0:
                logging.info('step = %d, loss = %f', global_step_val,
                             loss.loss)
                steps_per_sec = ((global_step_val - timed_at_step) /
                                 (collect_time + train_time))
                session.run(steps_per_second_summary,
                            feed_dict={steps_per_second_ph: steps_per_sec})
                logging.info('%.3f steps/sec', steps_per_sec)
                logging.info(
                    '%s', 'collect_time = {}, train_time = {}'.format(
                        collect_time, train_time))
                timed_at_step = global_step_val
                collect_time = 0
                train_time = 0

            if global_step_val % train_checkpoint_interval == 0:
                train_checkpointer.save(global_step=global_step_val)

            if global_step_val % policy_checkpoint_interval == 0:
                policy_checkpointer.save(global_step=global_step_val)

            if global_step_val % eval_interval == 0:
                metric_utils.compute_summaries(
                    eval_metrics,
                    eval_py_env,
                    greedy_policy,
                    num_episodes=num_eval_episodes,
                    global_step=global_step_val,
                    log=True,
                    callback=eval_metrics_callback,
                )
                # Reset timing to avoid counting eval time.
                timed_at_step = global_step_val
                start_time = time.time()
Example #24
def _create_random_policy_from_env(env):
    return random_py_policy.RandomPyPolicy(
        ts.time_step_spec(env.observation_spec()), env.action_spec())
Example #25
def train_eval(
        root_dir,
        env_name='CartPole-v0',
        # Training params
        initial_collect_steps=1000,
        num_iterations=100000,
        fc_layer_params=(100, ),
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        n_step_update=1,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN."""
    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    time_step_tensor_spec = tensor_spec.from_spec(collect_env.time_step_spec())
    action_tensor_spec = tensor_spec.from_spec(collect_env.action_spec())

    train_step = train_utils.create_train_step()
    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1

    # Define a helper function to create Dense layers configured with the right
    # activation and kernel initializer.
    def dense_layer(num_units):
        return tf.keras.layers.Dense(
            num_units,
            activation=tf.keras.activations.relu,
            kernel_initializer=tf.keras.initializers.VarianceScaling(
                scale=2.0, mode='fan_in', distribution='truncated_normal'))

    # QNetwork consists of a sequence of Dense layers followed by a dense layer
    # with `num_actions` units to generate one q_value per available action as
    # its output.
    dense_layers = [dense_layer(num_units) for num_units in fc_layer_params]
    q_values_layer = tf.keras.layers.Dense(
        num_actions,
        activation=None,
        kernel_initializer=tf.keras.initializers.RandomUniform(minval=-0.03,
                                                               maxval=0.03),
        bias_initializer=tf.keras.initializers.Constant(-0.2))
    q_net = sequential.Sequential(dense_layers + [q_values_layer])

    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        n_step_update=n_step_update,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(num_parallel_calls=3,
                                       sample_batch_size=batch_size,
                                       num_steps=2).prefetch(3)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn,
                                  triggers=learning_triggers)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the replay buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=1,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Example #26
def train_eval(
        root_dir,
        env_name,
        # Training params
        train_sequence_length,
        initial_collect_steps=1000,
        collect_steps_per_iteration=1,
        num_iterations=100000,
        # RNN params.
        q_network_fn=q_lstm_network,  # defaults to q_lstm_network.
        # Agent params
        epsilon_greedy=0.1,
        batch_size=64,
        learning_rate=1e-3,
        gamma=0.99,
        target_update_tau=0.05,
        target_update_period=5,
        reward_scale_factor=1.0,
        # Replay params
        reverb_port=None,
        replay_capacity=100000,
        # Others
        policy_save_interval=1000,
        eval_interval=1000,
        eval_episodes=10):
    """Trains and evaluates DQN."""

    collect_env = suite_gym.load(env_name)
    eval_env = suite_gym.load(env_name)

    unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
    q_net = q_network_fn(num_actions=num_actions)

    sequence_length = train_sequence_length + 1
    agent = dqn_agent.DqnAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        q_network=q_net,
        epsilon_greedy=epsilon_greedy,
        # n-step updates aren't supported with RNNs yet.
        n_step_update=1,
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        td_errors_loss_fn=common.element_wise_squared_loss,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        train_step_counter=train_step)

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))
    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=sequence_length,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=sequence_length,
        stride_length=1,
        pad_end_of_episodes=True)

    def experience_dataset_fn():
        return reverb_replay.as_dataset(sample_batch_size=batch_size,
                                        num_steps=sequence_length)

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()

    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=100),
    ]

    dqn_learner = learner.Learner(root_dir,
                                  train_step,
                                  agent,
                                  experience_dataset_fn,
                                  triggers=learning_triggers)

    # If we haven't trained yet, make sure we collect some random samples first
    # to fill up the replay buffer with some experience.
    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(
        collect_env,
        collect_policy,
        train_step,
        steps_per_run=collect_steps_per_iteration,
        observers=[rb_observer, env_step_metric],
        metrics=actor.collect_metrics(10),
        summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
    )

    tf_greedy_policy = agent.policy
    greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                       use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        dqn_learner.run(iterations=1)

        if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
Example #27
    def __init__(
            self,
            greedy_policy: py_policy.PyPolicy,
            epsilon: types.Float,
            random_policy: Optional[random_py_policy.RandomPyPolicy] = None,
            epsilon_decay_end_count: Optional[types.Float] = None,
            epsilon_decay_end_value: Optional[types.Float] = None,
            random_seed: Optional[types.Seed] = None):
        """Initializes the epsilon-greedy policy.

    Args:
      greedy_policy: An instance of py_policy.PyPolicy to use as the greedy
        policy.
      epsilon: The probability 0.0 <= epsilon <= 1.0 with which an
        action will be selected at random.
      random_policy: An instance of random_py_policy.RandomPyPolicy to use as
        the random policy. If None, a RandomPyPolicy is created automatically
        from the greedy_policy's time_step_spec and action_spec, seeded with
        random_seed.
      epsilon_decay_end_count: If set, anneal epsilon every time this policy
        is used, until it reaches epsilon_decay_end_value.
      epsilon_decay_end_value: The value of epsilon to use once the policy
        usage count reaches epsilon_decay_end_count.
      random_seed: seed used to create numpy.random.RandomState.
        /dev/urandom will be used if it's None.

    Raises:
      ValueError: If epsilon is not between 0.0 and 1.0, or if
        epsilon_decay_end_value is invalid when epsilon_decay_end_count is
        set.
    """
        if not 0 <= epsilon <= 1.0:
            raise ValueError('epsilon should be in [0.0, 1.0]')

        self._greedy_policy = greedy_policy
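        # Build a default RandomPyPolicy from the greedy policy's specs when no random_policy is supplied.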
        if random_policy is None:
            self._random_policy = random_py_policy.RandomPyPolicy(
                time_step_spec=greedy_policy.time_step_spec,
                action_spec=greedy_policy.action_spec,
                seed=random_seed)
        else:
            self._random_policy = random_policy
        # TODO(b/110841809) consider making epsilon be provided by a function.
        self._epsilon = epsilon
        self._epsilon_decay_end_count = epsilon_decay_end_count
        if epsilon_decay_end_count is not None:
            if epsilon_decay_end_value is None or epsilon_decay_end_value >= epsilon:
                raise ValueError(
                    'Invalid value for epsilon_decay_end_value {}'.format(
                        epsilon_decay_end_value))
            self._epsilon_decay_step_factor = float(
                epsilon - epsilon_decay_end_value) / epsilon_decay_end_count
        self._epsilon_decay_end_value = epsilon_decay_end_value

        self._random_seed = random_seed  # Keep it for copy method.
        self._rng = np.random.RandomState(random_seed)

        # Total times action method has been called.
        self._count = 0

        super(EpsilonGreedyPolicy,
              self).__init__(greedy_policy.time_step_spec,
                             greedy_policy.action_spec,
                             greedy_policy.policy_state_spec,
                             greedy_policy.info_spec)
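
The constructor above only stores the decay bookkeeping; the per-call anneal itself happens in the policy's action method, which this snippet does not show. As a rough illustration of what that bookkeeping implies, the sketch below applies the same linear step factor outside the class. All concrete values are hypothetical.

# Standalone sketch (not part of the class above): the linear epsilon anneal
# implied by _epsilon_decay_step_factor. Values are hypothetical.
epsilon = 0.9
epsilon_decay_end_value = 0.1
epsilon_decay_end_count = 1000  # number of action() calls until the end value

step_factor = (epsilon - epsilon_decay_end_value) / epsilon_decay_end_count

def annealed_epsilon(count):
    # Decay linearly with the usage count, then hold at the end value.
    return max(epsilon_decay_end_value, epsilon - count * step_factor)

assert annealed_epsilon(0) == 0.9
assert abs(annealed_epsilon(500) - 0.5) < 1e-9
assert annealed_epsilon(5000) == 0.1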
Ejemplo n.º 28
0
# Assumed reconstruction of the truncated call opening, following the
# ReverbReplayBuffer setup used in the other examples.
reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
    tf_agent.collect_data_spec,
    sequence_length=2,
    table_name=table_name,
    local_server=reverb_server)

dataset = reverb_replay.as_dataset(sample_batch_size=HyperParms.batch_size,
                                   num_steps=2).prefetch(50)
experience_dataset_fn = lambda: dataset

print(f" --  POLICIES  ({now()})  -- ")
tf_eval_policy = tf_agent.policy
eval_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_eval_policy,
                                                 use_tf_function=True)
tf_collect_policy = tf_agent.collect_policy
collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                    use_tf_function=True)
random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                collect_env.action_spec())

print(f" --  ACTORS  ({now()})  -- ")
rb_observer = reverb_utils.ReverbAddTrajectoryObserver(reverb_replay.py_client,
                                                       table_name,
                                                       sequence_length=2,
                                                       stride_length=1)

initial_collect_actor = actor.Actor(
    collect_env,
    random_policy,
    train_step,
    steps_per_run=HyperParms.initial_collect_steps,
    observers=[rb_observer])
initial_collect_actor.run()
Ejemplo n.º 29
0
def train_eval(
        root_dir,
        env_name='HalfCheetah-v2',
        # Training params
        initial_collect_steps=10000,
        num_iterations=3200000,
        actor_fc_layers=(256, 256),
        critic_obs_fc_layers=None,
        critic_action_fc_layers=None,
        critic_joint_fc_layers=(256, 256),
        # Agent params
        batch_size=256,
        actor_learning_rate=3e-4,
        critic_learning_rate=3e-4,
        alpha_learning_rate=3e-4,
        gamma=0.99,
        target_update_tau=0.005,
        target_update_period=1,
        reward_scale_factor=0.1,
        # Replay params
        reverb_port=None,
        replay_capacity=1000000,
        # Others
        # Defaults to not checkpointing saved policy. If you wish to enable this,
        # please note the caveat explained in README.md.
        policy_save_interval=-1,
        eval_interval=10000,
        eval_episodes=30,
        debug_summaries=False,
        summarize_grads_and_vars=False):
    """Trains and evaluates SAC."""
    logging.info('Training SAC on: %s', env_name)
    collect_env = suite_mujoco.load(env_name)
    eval_env = suite_mujoco.load(env_name)

    observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
        spec_utils.get_tensor_specs(collect_env))

    train_step = train_utils.create_train_step()

    actor_net = actor_distribution_network.ActorDistributionNetwork(
        observation_tensor_spec,
        action_tensor_spec,
        fc_layer_params=actor_fc_layers,
        continuous_projection_net=tanh_normal_projection_network.
        TanhNormalProjectionNetwork)
    critic_net = critic_network.CriticNetwork(
        (observation_tensor_spec, action_tensor_spec),
        observation_fc_layer_params=critic_obs_fc_layers,
        action_fc_layer_params=critic_action_fc_layers,
        joint_fc_layer_params=critic_joint_fc_layers,
        kernel_initializer='glorot_uniform',
        last_kernel_initializer='glorot_uniform')

    agent = sac_agent.SacAgent(
        time_step_tensor_spec,
        action_tensor_spec,
        actor_network=actor_net,
        critic_network=critic_net,
        actor_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=actor_learning_rate),
        critic_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=critic_learning_rate),
        alpha_optimizer=tf.compat.v1.train.AdamOptimizer(
            learning_rate=alpha_learning_rate),
        target_update_tau=target_update_tau,
        target_update_period=target_update_period,
        td_errors_loss_fn=tf.math.squared_difference,
        gamma=gamma,
        reward_scale_factor=reward_scale_factor,
        gradient_clipping=None,
        debug_summaries=debug_summaries,
        summarize_grads_and_vars=summarize_grads_and_vars,
        train_step_counter=train_step)
    agent.initialize()

    table_name = 'uniform_table'
    table = reverb.Table(table_name,
                         max_size=replay_capacity,
                         sampler=reverb.selectors.Uniform(),
                         remover=reverb.selectors.Fifo(),
                         rate_limiter=reverb.rate_limiters.MinSize(1))

    reverb_server = reverb.Server([table], port=reverb_port)
    reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
        agent.collect_data_spec,
        sequence_length=2,
        table_name=table_name,
        local_server=reverb_server)
    rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
        reverb_replay.py_client,
        table_name,
        sequence_length=2,
        stride_length=1)

    dataset = reverb_replay.as_dataset(sample_batch_size=batch_size,
                                       num_steps=2).prefetch(50)
    experience_dataset_fn = lambda: dataset

    saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
    env_step_metric = py_metrics.EnvironmentSteps()
    learning_triggers = [
        triggers.PolicySavedModelTrigger(
            saved_model_dir,
            agent,
            train_step,
            interval=policy_save_interval,
            metadata_metrics={triggers.ENV_STEP_METADATA_KEY:
                              env_step_metric}),
        triggers.StepPerSecondLogTrigger(train_step, interval=1000),
    ]

    agent_learner = learner.Learner(root_dir,
                                    train_step,
                                    agent,
                                    experience_dataset_fn,
                                    triggers=learning_triggers)

    random_policy = random_py_policy.RandomPyPolicy(
        collect_env.time_step_spec(), collect_env.action_spec())
    initial_collect_actor = actor.Actor(collect_env,
                                        random_policy,
                                        train_step,
                                        steps_per_run=initial_collect_steps,
                                        observers=[rb_observer])
    logging.info('Doing initial collect.')
    initial_collect_actor.run()

    tf_collect_policy = agent.collect_policy
    collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                        use_tf_function=True)

    collect_actor = actor.Actor(collect_env,
                                collect_policy,
                                train_step,
                                steps_per_run=1,
                                metrics=actor.collect_metrics(10),
                                summary_dir=os.path.join(
                                    root_dir, learner.TRAIN_DIR),
                                observers=[rb_observer, env_step_metric])

    tf_greedy_policy = greedy_policy.GreedyPolicy(agent.policy)
    eval_greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(
        tf_greedy_policy, use_tf_function=True)

    eval_actor = actor.Actor(
        eval_env,
        eval_greedy_policy,
        train_step,
        episodes_per_run=eval_episodes,
        metrics=actor.eval_metrics(eval_episodes),
        summary_dir=os.path.join(root_dir, 'eval'),
    )

    if eval_interval:
        logging.info('Evaluating.')
        eval_actor.run_and_log()

    logging.info('Training.')
    for _ in range(num_iterations):
        collect_actor.run()
        agent_learner.run(iterations=1)

        if eval_interval and agent_learner.train_step_numpy % eval_interval == 0:
            logging.info('Evaluating.')
            eval_actor.run_and_log()

    rb_observer.close()
    reverb_server.stop()
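
For reference, a minimal way to call the train_eval above might look like the sketch below; the temporary root_dir and the shortened step counts are illustrative choices for a quick smoke run, not values taken from the example (MuJoCo and the HalfCheetah-v2 environment still need to be available).

# Hypothetical smoke-run invocation of the SAC train_eval defined above.
import tempfile

root_dir = tempfile.mkdtemp(prefix='sac_halfcheetah_')
train_eval(
    root_dir,
    env_name='HalfCheetah-v2',
    initial_collect_steps=100,  # shortened warm-up, illustration only
    num_iterations=1000,
    eval_interval=500,
    eval_episodes=1)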
Ejemplo n.º 30
0
def train_eval(
    root_dir,
    env_name='Pong-v0',
    # Training params
    update_frequency=4,  # Number of collect steps per policy update
    initial_collect_steps=50000,  # 50k collect steps
    num_iterations=50000000,  # 50M collect steps
    # Taken from Rainbow as it's not specified in Mnih,15.
    max_episode_frames_collect=50000,  # env frames observed by the agent
    max_episode_frames_eval=108000,  # env frames observed by the agent
    # Agent params
    epsilon_greedy=0.1,
    epsilon_decay_period=250000,  # 1M collect steps / update_frequency
    batch_size=32,
    learning_rate=0.00025,
    n_step_update=1,
    gamma=0.99,
    target_update_tau=1.0,
    target_update_period=2500,  # 10k collect steps / update_frequency
    reward_scale_factor=1.0,
    # Replay params
    reverb_port=None,
    replay_capacity=1000000,
    # Others
    policy_save_interval=250000,
    eval_interval=1000,
    eval_episodes=30,
    debug_summaries=True):
  """Trains and evaluates DQN."""

  collect_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_collect,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)
  eval_env = suite_atari.load(
      env_name,
      max_episode_steps=max_episode_frames_eval,
      gym_env_wrappers=suite_atari.DEFAULT_ATARI_GYM_WRAPPERS_WITH_STACKING)

  unused_observation_tensor_spec, action_tensor_spec, time_step_tensor_spec = (
      spec_utils.get_tensor_specs(collect_env))

  train_step = train_utils.create_train_step()

  num_actions = action_tensor_spec.maximum - action_tensor_spec.minimum + 1
  epsilon = tf.compat.v1.train.polynomial_decay(
      1.0,
      train_step,
      epsilon_decay_period,
      end_learning_rate=epsilon_greedy)
  agent = dqn_agent.DqnAgent(
      time_step_tensor_spec,
      action_tensor_spec,
      q_network=create_q_network(num_actions),
      epsilon_greedy=epsilon,
      n_step_update=n_step_update,
      target_update_tau=target_update_tau,
      target_update_period=target_update_period,
      optimizer=tf.compat.v1.train.RMSPropOptimizer(
          learning_rate=learning_rate,
          decay=0.95,
          momentum=0.95,
          epsilon=0.01,
          centered=True),
      td_errors_loss_fn=common.element_wise_huber_loss,
      gamma=gamma,
      reward_scale_factor=reward_scale_factor,
      train_step_counter=train_step,
      debug_summaries=debug_summaries)

  table_name = 'uniform_table'
  table = reverb.Table(
      table_name,
      max_size=replay_capacity,
      sampler=reverb.selectors.Uniform(),
      remover=reverb.selectors.Fifo(),
      rate_limiter=reverb.rate_limiters.MinSize(1))
  reverb_server = reverb.Server([table], port=reverb_port)
  reverb_replay = reverb_replay_buffer.ReverbReplayBuffer(
      agent.collect_data_spec,
      sequence_length=2,
      table_name=table_name,
      local_server=reverb_server)
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb_replay.py_client, table_name,
      sequence_length=2,
      stride_length=1)

  dataset = reverb_replay.as_dataset(
      sample_batch_size=batch_size, num_steps=2).prefetch(3)
  experience_dataset_fn = lambda: dataset

  saved_model_dir = os.path.join(root_dir, learner.POLICY_SAVED_MODEL_DIR)
  env_step_metric = py_metrics.EnvironmentSteps()

  learning_triggers = [
      triggers.PolicySavedModelTrigger(
          saved_model_dir,
          agent,
          train_step,
          interval=policy_save_interval,
          metadata_metrics={triggers.ENV_STEP_METADATA_KEY: env_step_metric}),
      triggers.StepPerSecondLogTrigger(train_step, interval=100),
  ]

  dqn_learner = learner.Learner(
      root_dir,
      train_step,
      agent,
      experience_dataset_fn,
      triggers=learning_triggers)

  # If we haven't trained yet make sure we collect some random samples first to
  # fill up the Replay Buffer with some experience.
  random_policy = random_py_policy.RandomPyPolicy(collect_env.time_step_spec(),
                                                  collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  tf_collect_policy = agent.collect_policy
  collect_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_collect_policy,
                                                      use_tf_function=True)

  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=update_frequency,
      observers=[rb_observer, env_step_metric],
      metrics=actor.collect_metrics(10),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR),
  )

  tf_greedy_policy = agent.policy
  greedy_policy = py_tf_eager_policy.PyTFEagerPolicy(tf_greedy_policy,
                                                     use_tf_function=True)

  eval_actor = actor.Actor(
      eval_env,
      greedy_policy,
      train_step,
      episodes_per_run=eval_episodes,
      metrics=actor.eval_metrics(eval_episodes),
      reference_metrics=[env_step_metric],
      summary_dir=os.path.join(root_dir, 'eval'),
  )

  if eval_interval:
    logging.info('Evaluating.')
    eval_actor.run_and_log()

  logging.info('Training.')
  for _ in range(num_iterations):
    collect_actor.run()
    dqn_learner.run(iterations=1)

    if eval_interval and dqn_learner.train_step_numpy % eval_interval == 0:
      logging.info('Evaluating.')
      eval_actor.run_and_log()

  rb_observer.close()
  reverb_server.stop()
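
The schedule defaults above are given in train steps, while the comments quote the usual Atari budgets in collect steps; since the collect actor runs update_frequency environment steps per learner iteration, the conversion is a simple division, as in the sketch below.

# Sketch of the collect-step to train-step conversion implied by the comments above.
update_frequency = 4

epsilon_decay_collect_steps = 1_000_000  # "1M collect steps"
target_update_collect_steps = 10_000     # "10k collect steps"

assert epsilon_decay_collect_steps // update_frequency == 250000  # epsilon_decay_period
assert target_update_collect_steps // update_frequency == 2500    # target_update_period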