def testSavedModel(self):
        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)
        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(self.time_step_spec,
                                                       rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        original_action = self.tf_policy.action(batched_sample_time_step)
        unbatched_original_action = nest_utils.unbatch_nested_tensors(
            original_action)
        original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                                   unbatched_original_action)
        saved_policy_action = eager_py_policy.action(sample_time_step)

        tf.nest.assert_same_structure(saved_policy_action.action,
                                      self.action_spec)

        np.testing.assert_array_almost_equal(original_action_np.action,
                                             saved_policy_action.action)

    def testUpdateFromCheckpoint(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in TF2.x.')

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        eager_py_policy.update_from_checkpoint(checkpoint_path)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()),
                              check_types=False)
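
The two tests above exercise the same workflow: save a policy as a SavedModel, reload it eagerly in Python, and refresh its variables from a later checkpoint. Below is a minimal standalone sketch of that flow, assuming a `tf_policy` with matching array-spec `time_step_spec` and `action_spec` already exists; the /tmp paths are placeholders.

from tf_agents.policies import policy_saver
from tf_agents.policies import py_tf_eager_policy

saved_path = '/tmp/saved_policy'      # hypothetical SavedModel directory
ckpt_path = '/tmp/policy_checkpoint'  # hypothetical checkpoint directory

saver = policy_saver.PolicySaver(tf_policy)
saver.save(saved_path)            # writes the full SavedModel
saver.save_checkpoint(ckpt_path)  # writes a variables-only checkpoint

eager_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    saved_path, time_step_spec, action_spec)
# Pull in fresh weights later without reloading the SavedModel graph, as the
# first test above does.
eager_policy.update_from_checkpoint(ckpt_path)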
Example #3
    def testUpdateFromCheckpoint(self):
        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)
        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        checkpoint = tf.train.Checkpoint(policy=eager_py_policy._policy)
        manager = tf.train.CheckpointManager(checkpoint,
                                             directory=checkpoint_path,
                                             max_to_keep=None)

        eager_py_policy.update_from_checkpoint(manager.latest_checkpoint)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()))
Example #4
def wait_for_policy(
    policy_dir: Text,
    sleep_time_secs: int = _WAIT_DEFAULT_SLEEP_TIME_SECS,
    num_retries: int = _WAIT_DEFAULT_NUM_RETRIES,
    **saved_model_policy_args) -> py_tf_eager_policy.PyTFEagerPolicyBase:
  """Blocks until the policy in `policy_dir` becomes available.

  The default settings allow a fairly loose, but not infinite, wait time of one
  day for this function to block, checking the `policy_dir` every second.

  Args:
    policy_dir: The directory containing the policy files.
    sleep_time_secs: Number of seconds to sleep between retries.
    num_retries: Number of times the existence of the file is checked.
    **saved_model_policy_args: Additional keyword arguments passed directly to
      the `SavedModelPyTFEagerPolicy` policy constructor which loads the policy
      from `policy_dir` once the policy becomes available.

  Returns:
    The policy loaded from the `policy_dir`.

  Raises:
    TimeoutError: If the policy does not become available within the given
      number of retries.
  """
  # TODO(b/173815037): Write and wait for a DONE file instead.
  last_written_policy_file = os.path.join(policy_dir, 'policy_specs.pbtxt')
  wait_for_file(
      last_written_policy_file,
      sleep_time_secs=sleep_time_secs,
      num_retries=num_retries)
  return py_tf_eager_policy.SavedModelPyTFEagerPolicy(policy_dir,
                                                      **saved_model_policy_args)
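
A hedged usage sketch of `wait_for_policy`; the policy directory is a placeholder, and the extra keyword argument is forwarded straight to the `SavedModelPyTFEagerPolicy` constructor.

policy = wait_for_policy(
    '/tmp/root_dir/policies/greedy_policy',  # placeholder policy directory
    sleep_time_secs=1,
    num_retries=120,
    load_specs_from_pbtxt=True)
print('Policy ready at train step', policy.get_train_step())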
Example #5
    def testRegisteredFunction(self):
        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.register_function('actor_net', self._actor_net,
                                self.time_step_spec.observation)
        saver.save(path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)
        sample_observation = tensor_spec.sample_spec_nest(
            tensor_spec.from_spec(self.time_step_spec.observation),
            outer_dims=(3, ))
        eager_py_policy.actor_net(sample_observation)
Example #6
    def testInferenceFromCheckpoint(self):
        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(self.tf_policy)
        saver.save(path)

        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(self.time_step_spec,
                                                       rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        self.evaluate(
            tf.nest.map_structure(lambda v: v.assign(v * 0 + -1),
                                  self.tf_policy.variables()))
        checkpoint_path = os.path.join(self.get_temp_dir(), 'checkpoint')
        saver.save_checkpoint(checkpoint_path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, self.time_step_spec, self.action_spec)

        # Use evaluate to force a copy.
        saved_model_variables = self.evaluate(eager_py_policy.variables())

        checkpoint = tf.train.Checkpoint(policy=eager_py_policy._policy)
        manager = tf.train.CheckpointManager(checkpoint,
                                             directory=checkpoint_path,
                                             max_to_keep=None)

        eager_py_policy.update_from_checkpoint(manager.latest_checkpoint)

        assert_np_not_equal = lambda a, b: self.assertFalse(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_not_equal, saved_model_variables,
                              self.evaluate(eager_py_policy.variables()))

        assert_np_all_equal = lambda a, b: self.assertTrue(
            np.equal(a, b).all())
        tf.nest.map_structure(assert_np_all_equal,
                              self.evaluate(self.tf_policy.variables()),
                              self.evaluate(eager_py_policy.variables()))

        # We can't check that the action changed after the checkpoint update,
        # since depending on variable initialization it may be the same either
        # way. Instead, check that the restored policy and the current policy
        # now always produce the same action.
        checkpoint_action = eager_py_policy.action(sample_time_step)

        current_policy_action = self.tf_policy.action(batched_sample_time_step)
        current_policy_action = self.evaluate(
            nest_utils.unbatch_nested_tensors(current_policy_action))
        tf.nest.map_structure(assert_np_all_equal, current_policy_action,
                              checkpoint_action)
Example #7
  def testGetTrainStep(self, train_step):
    path = os.path.join(self.get_temp_dir(), 'saved_policy')
    if train_step is None:
      # Use the default argument, which should set the train step to be -1.
      saver = policy_saver.PolicySaver(self.tf_policy)
      expected_train_step = -1
    else:
      saver = policy_saver.PolicySaver(
          self.tf_policy, train_step=tf.constant(train_step))
      expected_train_step = train_step
    saver.save(path)

    eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        path, self.time_step_spec, self.action_spec)

    self.assertEqual(expected_train_step, eager_py_policy.get_train_step())
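
A short sketch of the train-step round trip checked above, reusing the imports from the examples and assuming `tf_policy` exists; the output path is a placeholder.

train_step = tf.Variable(100, dtype=tf.int64)  # step value to embed in the SavedModel
saver = policy_saver.PolicySaver(tf_policy, train_step=train_step)
saver.save('/tmp/saved_policy')  # placeholder path

loaded = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
    '/tmp/saved_policy', load_specs_from_pbtxt=True)
print(loaded.get_train_step())  # -> 100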
Example #8
  def testSavedModelLoadingSpecs(self):
    path = os.path.join(self.get_temp_dir(), 'saved_policy')
    saver = policy_saver.PolicySaver(self.tf_policy)
    saver.save(path)

    eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        path, load_specs_from_pbtxt=True)

    # Bounded specs get converted to regular specs when saved into a proto.
    def assert_specs_mostly_equal(loaded_spec, expected_spec):
      self.assertEqual(loaded_spec.shape, expected_spec.shape)
      self.assertEqual(loaded_spec.dtype, expected_spec.dtype)

    tf.nest.map_structure(assert_specs_mostly_equal,
                          eager_py_policy.time_step_spec, self.time_step_spec)
    tf.nest.map_structure(assert_specs_mostly_equal,
                          eager_py_policy.action_spec, self.action_spec)
Example #9
    def testSavedModel(self):
        if not common.has_eager_been_enabled():
            self.skipTest('Only supported in eager.')

        observation_spec = array_spec.ArraySpec([2], np.float32)
        action_spec = array_spec.BoundedArraySpec([1], np.float32, 2, 3)
        time_step_spec = ts.time_step_spec(observation_spec)

        observation_tensor_spec = tensor_spec.from_spec(observation_spec)
        action_tensor_spec = tensor_spec.from_spec(action_spec)
        time_step_tensor_spec = tensor_spec.from_spec(time_step_spec)

        actor_net = actor_network.ActorNetwork(
            observation_tensor_spec,
            action_tensor_spec,
            fc_layer_params=(10, ),
        )

        tf_policy = actor_policy.ActorPolicy(time_step_tensor_spec,
                                             action_tensor_spec,
                                             actor_network=actor_net)

        path = os.path.join(self.get_temp_dir(), 'saved_policy')
        saver = policy_saver.PolicySaver(tf_policy)
        saver.save(path)

        eager_py_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            path, time_step_spec, action_spec)

        rng = np.random.RandomState()
        sample_time_step = array_spec.sample_spec_nest(time_step_spec, rng)
        batched_sample_time_step = nest_utils.batch_nested_array(
            sample_time_step)

        original_action = tf_policy.action(batched_sample_time_step)
        unbatched_original_action = nest_utils.unbatch_nested_tensors(
            original_action)
        original_action_np = tf.nest.map_structure(lambda t: t.numpy(),
                                                   unbatched_original_action)
        saved_policy_action = eager_py_policy.action(sample_time_step)

        tf.nest.assert_same_structure(saved_policy_action.action, action_spec)

        np.testing.assert_array_almost_equal(original_action_np.action,
                                             saved_policy_action.action)
Example #10
def load(saved_model_path, checkpoint_path=None):
    """Loads a policy.

  The argument `saved_model_path` is the path of a directory containing a full
  saved model for the policy. The path typically looks like
  '/root_dir/policies/policy' and may contain trailing numbers for the
  train_step.

  `saved_model_path` is expected to contain the following files. (There can be
  additional shards for the `variables.data` files.)
     * `saved_model.pb`
     * `policy_specs.pbtxt`
     * `variables/variables.index`
     * `variables/variables.data-00000-of-00001`

  The optional argument `checkpoint_path` is the path to a directory that
  contains variable checkpoints (as opposed to full saved models) for the
  policy. The path also typically ends with the checkpoint number,
  for example: '/my/save/dir/checkpoint/000022100'.

  If specified, `checkpoint_path` is expected to contain the following
  files. (There can be additional shards for the `variables.data` files.)
     * `variables/variables.index`
     * `variables/variables.data-00000-of-00001`

  `load()` recreates a policy from the saved model and, if `checkpoint_path`
  is specified, updates the policy's variables from the checkpoint. It returns
  the policy.

  Args:
    saved_model_path: string. Path to a directory containing a full saved model.
    checkpoint_path: string. Optional path to a directory containing a
      checkpoint of the model variables.

  Returns:
    A `tf_agents.policies.SavedModelPyTFEagerPolicy`.
  """
    policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        saved_model_path, load_specs_from_pbtxt=True)
    if checkpoint_path:
        policy.update_from_checkpoint(checkpoint_path)
    return policy
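
A hedged invocation of `load()` matching the directory layout described in the docstring; both paths are placeholders.

policy = load(
    '/tmp/root_dir/policies/greedy_policy',                 # full SavedModel
    checkpoint_path='/tmp/root_dir/checkpoints/000022100')  # variables only
print('Loaded policy at train step', policy.get_train_step())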
Example #11
def evaluate(env_name,
             saved_model_dir,
             env_load_fn=env_utils.load_dm_env_for_eval,
             num_episodes=1,
             eval_log_dir=None,
             continuous=False,
             max_train_step=math.inf,
             seconds_between_checkpoint_polls=5,
             num_retries=100,
             log_measurements=lambda metrics, current_step: None):
    """Evaluates a checkpoint directory.

  Checkpoints for the saved model to evaluate are assumed to be at the same
  directory level as the saved_model dir, i.e.:

  * saved_model_dir: root_dir/policies/greedy_policy
  * checkpoints_dir: root_dir/checkpoints

  Args:
    env_name: Name of the environment to evaluate in.
    saved_model_dir: String path to the saved model directory.
    env_load_fn: Function to load the environment specified by env_name.
    num_episodes: Number of episodes to evaluate per checkpoint.
    eval_log_dir: Optional path to output summaries of the evaluations. If
      None, a default directory relative to the saved_model_dir will be used.
    continuous: If True, the evaluation will keep polling for new
      checkpoints.
    max_train_step: Maximum train_step to evaluate. Once a train_step greater
      than or equal to this is evaluated, the evaluation terminates. Should be
      set to <= train_eval.num_iterations to ensure that eval terminates.
    seconds_between_checkpoint_polls: The amount of time in seconds to wait
      between polls to see if new checkpoints appear in the continuous setting.
    num_retries: Number of retries for reading checkpoints.
    log_measurements: Function to log measurements.

  Raises:
    IOError: on repeated failures to read checkpoints after all the retries.
  """
    split = os.path.split(saved_model_dir)
    # Remove trailing slash if we have one.
    if not split[-1]:
        saved_model_dir = split[0]

    env = env_load_fn(env_name)

    # Load saved model.
    saved_model_path = os.path.join(saved_model_dir, 'saved_model.pb')
    while continuous and not tf.io.gfile.exists(saved_model_path):
        logging.info(
            'Waiting on the first checkpoint to become available at: %s',
            saved_model_path)
        time.sleep(seconds_between_checkpoint_polls)

    for _ in range(num_retries):
        try:
            policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
                saved_model_dir, load_specs_from_pbtxt=True)
            break
        except (tf.errors.OpError, tf.errors.DataLossError, IndexError,
                FileNotFoundError):
            logging.warning(
                'Encountered an error while loading a policy. This can '
                'happen when reading a checkpoint before it is fully written. '
                'Retrying...')
            time.sleep(seconds_between_checkpoint_polls)
    else:
        logging.error('Failed to load a checkpoint after retrying: %s',
                      saved_model_dir)

    if max_train_step and policy.get_train_step() > max_train_step:
        logging.info(
            'Policy train_step (%d) > max_train_step (%d). No evaluations performed.',
            policy.get_train_step(), max_train_step)
        return

    # Assume saved_model dir is of the form: root_dir/policies/greedy_policy. This
    # requires going up two levels to get the root_dir.
    root_dir = os.path.dirname(os.path.dirname(saved_model_dir))
    log_dir = eval_log_dir or os.path.join(root_dir, 'eval')

    # evaluated_file = os.path.join(log_dir, EVALUATED_STEPS_FILE)
    evaluated_checkpoints = set()

    train_step = tf.Variable(policy.get_train_step(), dtype=tf.int64)
    metrics = actor.eval_metrics(buffer_size=num_episodes)
    eval_actor = actor.Actor(env,
                             policy,
                             train_step,
                             metrics=metrics,
                             episodes_per_run=num_episodes,
                             summary_dir=log_dir)

    checkpoint_list = _get_checkpoints_to_evaluate(evaluated_checkpoints,
                                                   saved_model_dir)

    latest_eval_step = policy.get_train_step()
    while (checkpoint_list
           or continuous) and latest_eval_step < max_train_step:
        while not checkpoint_list and continuous:
            logging.info('Waiting on new checkpoints to become available.')
            time.sleep(seconds_between_checkpoint_polls)
            checkpoint_list = _get_checkpoints_to_evaluate(
                evaluated_checkpoints, saved_model_dir)
        checkpoint = checkpoint_list.pop()
        for _ in range(num_retries):
            try:
                policy.update_from_checkpoint(checkpoint)
                break
            except (tf.errors.OpError, IndexError):
                logging.warning(
                    'Encountered an error while evaluating a checkpoint. This can '
                    'happen when reading a checkpoint before it is fully written. '
                    'Retrying...')
                time.sleep(seconds_between_checkpoint_polls)
        else:
            # This seems to happen rarely. Just skip this checkpoint.
            logging.error('Failed to evaluate checkpoint after retrying: %s',
                          checkpoint)
            continue

        logging.info('Evaluating:\n\tStep:%d\tcheckpoint: %s',
                     policy.get_train_step(), checkpoint)
        eval_actor.train_step.assign(policy.get_train_step())

        train_step = policy.get_train_step()
        if triggers.ENV_STEP_METADATA_KEY in policy.get_metadata():
            env_step = policy.get_metadata()[
                triggers.ENV_STEP_METADATA_KEY].numpy()
            eval_actor.training_env_step = env_step

        if latest_eval_step <= train_step:
            eval_actor.run_and_log()
            latest_eval_step = policy.get_train_step()
        else:
            logging.info(
                'Skipping over train_step %d to avoid logging backwards in time.',
                train_step)
        evaluated_checkpoints.add(checkpoint)
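
A hedged example of driving the evaluation loop above. The environment name and directory are placeholders (the name format depends on `env_load_fn`); with `continuous=False` the call evaluates the checkpoints that already exist and then returns.

evaluate(
    env_name='cartpole_swingup',  # placeholder; format depends on env_load_fn
    saved_model_dir='/tmp/root_dir/policies/greedy_policy',
    num_episodes=5,
    continuous=False,
    max_train_step=1_000_000)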
Example #12
def load_policy(saved_model_dir,
                max_train_step,
                seconds_between_checkpoint_polls=5,
                num_retries=10):
    """Loads the latest checkpoint in a directory.

  Checkpoints for the saved model to evaluate are assumed to be at the same
  directory level as the saved_model dir, i.e.:

  * saved_model_dir: root_dir/policies/greedy_policy
  * checkpoints_dir: root_dir/checkpoints

  Args:
    saved_model_dir: String path to the saved model directory.
    max_train_step: Int. Maximum train step of the checkpoint to load.
    seconds_between_checkpoint_polls: The amount of time in seconds to wait
      between polls to see if new checkpoints appear in the continuous setting.
    num_retries: Number of retries for reading checkpoints.

  Returns:
    Policy loaded from the latest checkpoint in saved_model_dir.

  Raises:
    IOError: on repeated failures to read checkpoints after all the retries.
  """
    split = os.path.split(saved_model_dir)
    # Remove trailing slash if we have one.
    if not split[-1]:
        saved_model_dir = split[0]
    for _ in range(num_retries):
        try:
            policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
                saved_model_dir, load_specs_from_pbtxt=True)
            break
        except (tf.errors.OpError, tf.errors.DataLossError, IndexError,
                FileNotFoundError):
            logging.warning(
                'Encountered an error while loading a policy. This can '
                'happen when reading a checkpoint before it is fully written. '
                'Retrying...')
            time.sleep(seconds_between_checkpoint_polls)
    else:
        logging.error('Failed to load a checkpoint after retrying: %s',
                      saved_model_dir)

    checkpoint_list = _get_checkpoints_to_evaluate(set(), saved_model_dir)
    checkpoint_numbers = [int(ckpt.split('_')[-1]) for ckpt in checkpoint_list]
    checkpoint_list = [
        ckpt for ckpt, num in zip(checkpoint_list, checkpoint_numbers)
        if num <= max_train_step
    ]
    latest_checkpoint = checkpoint_list.pop()
    assert int(latest_checkpoint.split('_')[-1]) <= max_train_step, (
        'Expected a checkpoint at or below max_train_step')

    for _ in range(num_retries):
        try:
            policy.update_from_checkpoint(latest_checkpoint)
            break
        except (tf.errors.OpError, IndexError):
            logging.warning(
                'Encountered an error while loading a checkpoint. This can '
                'happen when reading a checkpoint before it is fully written. '
                'Retrying...')
            time.sleep(seconds_between_checkpoint_polls)

    logging.info('Loading:\n\tStep:%d\tcheckpoint: %s',
                 policy.get_train_step(), latest_checkpoint)
    return policy
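
A hedged call to `load_policy()`; the directory is a placeholder following the root_dir/policies/greedy_policy layout noted in the docstring.

policy = load_policy(
    '/tmp/root_dir/policies/greedy_policy',  # placeholder saved model directory
    max_train_step=1_000_000)
print('Loaded checkpoint at train step', policy.get_train_step())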
Example #13
def main(_):
    if FLAGS.eager:
        tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    tf.random.set_seed(FLAGS.seed)
    # np.random.seed(FLAGS.seed)
    # random.seed(FLAGS.seed)

    if 'procgen' in FLAGS.env_name:
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        env = procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,
            env_name=env_name,
            num_levels=int(train_levels),
            start_level=0)

    elif FLAGS.env_name.startswith('pixels-dm'):
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)

    hparam_str = utils.make_hparam_string(FLAGS.xm_parameters,
                                          algo_name=FLAGS.algo_name,
                                          seed=FLAGS.seed,
                                          task_name=FLAGS.env_name,
                                          ckpt_timesteps=FLAGS.ckpt_timesteps)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))

    if FLAGS.env_name.startswith('procgen'):
        # Map the env name to an integer id in [1, 16].
        env_id = [
            i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
        ][0] + 1
        if FLAGS.ckpt_timesteps == 10_000_000:
            ckpt_iter = '0000020480'
        elif FLAGS.ckpt_timesteps == 25_000_000:
            ckpt_iter = '0000051200'
        policy_weights_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/checkpoints/'
                              'policy_checkpoint_%s/' % (env_id, ckpt_iter))
        policy_def_dir = ('ppo_darts/'
                          '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
        model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            policy_def_dir,
            time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
            action_spec=env._action_spec,  # pylint: disable=protected-access
            policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
            info_spec=tf.TensorSpec(shape=(None, )),
            load_specs_from_pbtxt=False)
        model.update_from_checkpoint(policy_weights_dir)
        model = TfAgentsPolicy(model)
    else:
        if 'ddpg' in FLAGS.algo_name:
            model = ddpg.DDPG(env.observation_spec(),
                              env.action_spec(),
                              cross_norm='crossnorm' in FLAGS.algo_name)
        elif 'crr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='bin_max')
        elif 'awr' in FLAGS.algo_name:
            model = awr.AWR(env.observation_spec(),
                            env.action_spec(),
                            f='exp_mean')
        elif 'sac_v1' in FLAGS.algo_name:
            model = sac_v1.SAC(env.observation_spec(),
                               env.action_spec(),
                               target_entropy=-env.action_spec().shape[0])
        elif 'asac' in FLAGS.algo_name:
            model = asac.ASAC(env.observation_spec(),
                              env.action_spec(),
                              target_entropy=-env.action_spec().shape[0])
        elif 'sac' in FLAGS.algo_name:
            model = sac.SAC(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0],
                            cross_norm='crossnorm' in FLAGS.algo_name,
                            pcl_actor_update='pc' in FLAGS.algo_name)
        elif 'pcl' in FLAGS.algo_name:
            model = pcl.PCL(env.observation_spec(),
                            env.action_spec(),
                            target_entropy=-env.action_spec().shape[0])
        if 'distractor' in FLAGS.env_name:
            ckpt_path = os.path.join(
                ('experiments/20210622_2023.policy_weights_sac'
                 '_1M_dmc_distractor_hard_pixel/'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))
        else:
            ckpt_path = os.path.join(
                ('experiments/20210607_2023.'
                 'policy_weights_dmc_1M_SAC_pixel'), 'results',
                FLAGS.env_name + '__' + str(FLAGS.ckpt_timesteps))

        model.load_weights(ckpt_path)
    print('Loaded model weights')

    with summary_writer.as_default():
        env = procgen_wrappers.TFAgentsParallelProcGenEnv(
            1,
            normalize_rewards=False,
            env_name=env_name,
            num_levels=0,
            start_level=0)
        (avg_returns,
         avg_len) = evaluation.evaluate(env,
                                        model,
                                        num_episodes=100,
                                        return_distributions=False)
        tf.summary.scalar('evaluation/returns-all', avg_returns, step=0)
        tf.summary.scalar('evaluation/length-all', avg_len, step=0)
Example #14
def main(_):
    tf.config.experimental_run_functions_eagerly(FLAGS.eager)

    print('Num GPUs Available: ', len(tf.config.list_physical_devices('GPU')))
    if FLAGS.env_name.startswith('procgen'):
        print('Test env: %s' % FLAGS.env_name)
        _, env_name, train_levels, _ = FLAGS.env_name.split('-')
        print('Train env: %s' % FLAGS.env_name)
        env = tf_py_environment.TFPyEnvironment(
            procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,  # no normalization for evaluation
                env_name=env_name,
                num_levels=int(train_levels),
                start_level=0))
        env_all = tf_py_environment.TFPyEnvironment(
            procgen_wrappers.TFAgentsParallelProcGenEnv(
                1,
                normalize_rewards=False,  # no normalization for evaluation
                env_name=env_name,
                num_levels=0,
                start_level=0))

        if int(train_levels) == 0:
            train_levels = '200'

    elif FLAGS.env_name.startswith('pixels-dm'):
        if 'distractor' in FLAGS.env_name:
            _, _, domain_name, _, _ = FLAGS.env_name.split('-')
        else:
            _, _, domain_name, _ = FLAGS.env_name.split('-')

        if domain_name in ['cartpole']:
            FLAGS.set_default('action_repeat', 8)
        elif domain_name in ['reacher', 'cheetah', 'ball_in_cup', 'hopper']:
            FLAGS.set_default('action_repeat', 4)
        elif domain_name in ['finger', 'walker']:
            FLAGS.set_default('action_repeat', 2)

        env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                FLAGS.action_repeat, FLAGS.frame_stack,
                                FLAGS.obs_type)

        if FLAGS.obs_type == 'pixels':
            env, _ = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)
        else:
            _, env = utils.load_env(FLAGS.env_name, FLAGS.seed,
                                    FLAGS.action_repeat, FLAGS.frame_stack,
                                    FLAGS.obs_type)

    if FLAGS.obs_type != 'state':
        if FLAGS.env_name.startswith('procgen'):
            bcq = bcq_pixel
            cql = cql_pixel
            fisher_brac = fisher_brac_pixel
            deepmdp = deepmdp_pixel
            vpn = vpn_pixel
            cssc = cssc_pixel
            pse = pse_pixel
    else:
        bcq = bcq_state
        cql = cql_state
    print('Loading dataset')

    # Use load_tfrecord_dataset_sequence to load transitions of size k>=2.
    if FLAGS.numpy_dataset:
        n_shards = 10

        def shard_fn(shard):
            return ('experiments/'
                    '20210617_0105.dataset_dmc_50k,100k,'
                    '200k_SAC_pixel_numpy/datasets/'
                    '%s__%d__%d__%d.npy' %
                    (FLAGS.env_name, FLAGS.ckpt_timesteps, FLAGS.max_timesteps,
                     shard))

        np_observer = tf_utils.NumpyObserver(shard_fn, env)
        dataset = np_observer.load(n_shards)
    else:
        if FLAGS.env_name.startswith('procgen'):
            if FLAGS.n_step_returns > 0:
                if FLAGS.max_timesteps == 100_000:
                    dataset_path = ('experiments/'
                                    '20210624_2033.dataset_procgen__ppo_pixel/'
                                    'datasets/%s__%d__%d.tfrecord' %
                                    (FLAGS.env_name, FLAGS.ckpt_timesteps,
                                     FLAGS.max_timesteps))
                elif FLAGS.max_timesteps == 3_000_000:
                    if int(train_levels) == 1:
                        print('Using dataset with 1 level')
                        dataset_path = (
                            'experiments/'
                            '20210713_1557.dataset_procgen__ppo_pixel_1_level/'
                            'datasets/%s__%d__%d.tfrecord' %
                            (FLAGS.env_name, FLAGS.ckpt_timesteps,
                             FLAGS.max_timesteps))
                    elif int(train_levels) == 200:
                        print('Using dataset with 200 levels')
                        # Mixture dataset between 10M,15M,20M and 25M in equal amounts
                        # dataset_path = 'experiments/
                        # 20210718_1522.dataset_procgen__ppo_pixel_mixture10,15,20,25M/
                        # datasets/%s__%d__%d.tfrecord'%(FLAGS.env_name,
                        # FLAGS.ckpt_timesteps,FLAGS.max_timesteps)
                        # PPO after 25M steps
                        dataset_path = (
                            'experiments/'
                            '20210702_2234.dataset_procgen__ppo_pixel/'
                            'datasets/%s__%d__%d.tfrecord' %
                            (FLAGS.env_name, FLAGS.ckpt_timesteps,
                             FLAGS.max_timesteps))
                elif FLAGS.max_timesteps == 5_000_000:
                    # epsilon-greedy, eps: 0.1->0.001
                    dataset_path = (
                        'experiments/'
                        '20210805_1958.dataset_procgen__ppo_pixel_'
                        'egreedy_levelIDs/datasets/'
                        '%s__%d__%d.tfrecord*' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))
                    # Pure greedy (epsilon=0)
                    # dataset_path = ('experiments/'
                    #                 '20210820_1348.dataset_procgen__ppo_pixel_'
                    #                 'egreedy_levelIDs/datasets/'
                    #                 '%s__%d__%d.tfrecord*' %
                    #                 (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000))

        elif FLAGS.env_name.startswith('pixels-dm'):
            if 'distractor' in FLAGS.env_name:
                dataset_path = (
                    'experiments/'
                    '20210623_1749.dataset_dmc__sac_pixel/datasets/'
                    '%s__%d__%d.tfrecord' %
                    (FLAGS.env_name, FLAGS.ckpt_timesteps,
                     FLAGS.max_timesteps))
            else:
                if FLAGS.obs_type == 'pixels':
                    dataset_path = (
                        'experiments/'
                        '20210612_1644.dataset_dmc_50k,100k,200k_SAC_pixel/'
                        'datasets/%s__%d__%d.tfrecord' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps,
                         FLAGS.max_timesteps))
                else:
                    dataset_path = (
                        'experiments/'
                        '20210621_1436.dataset_dmc__SAC_pixel/datasets/'
                        '%s__%d__%d.tfrecord' %
                        (FLAGS.env_name, FLAGS.ckpt_timesteps,
                         FLAGS.max_timesteps))
        shards = tf.io.gfile.glob(dataset_path)
        shards = [s for s in shards if not s.endswith('.spec')]
        print('Found %d shards under path %s' % (len(shards), dataset_path))
        if FLAGS.n_step_returns > 1:
            # Load sequences of length N
            dataset = load_tfrecord_dataset_sequence(
                shards,
                buffer_size_per_shard=FLAGS.dataset_size // len(shards),
                deterministic=False,
                compress_image=True,
                seq_len=FLAGS.n_step_returns)  # spec=data_spec,
            dataset = dataset.take(FLAGS.dataset_size).shuffle(
                buffer_size=FLAGS.batch_size,
                reshuffle_each_iteration=False).batch(
                    FLAGS.batch_size,
                    drop_remainder=True).prefetch(1).repeat()

            dataset_iter = iter(dataset)
        else:
            dataset_iter = tf_utils.create_data_iterator(
                ('experiments/20210805'
                 '_1958.dataset_procgen__ppo_pixel_egreedy_'
                 'levelIDs/datasets/%s__%d__%d.tfrecord.shard-*-of-*' %
                 (FLAGS.env_name, FLAGS.ckpt_timesteps, 100000)),
                FLAGS.batch_size,
                shuffle_buffer_size=FLAGS.batch_size,
                obs_to_float=False)

    tf.random.set_seed(FLAGS.seed)

    hparam_str = utils.make_hparam_string(
        FLAGS.xm_parameters,
        algo_name=FLAGS.algo_name,
        seed=FLAGS.seed,
        task_name=FLAGS.env_name,
        ckpt_timesteps=FLAGS.ckpt_timesteps,
        rep_learn_keywords=FLAGS.rep_learn_keywords)
    summary_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'tb', hparam_str))
    result_writer = tf.summary.create_file_writer(
        os.path.join(FLAGS.save_dir, 'results', hparam_str))

    pretrain = (FLAGS.pretrain > 0)

    if FLAGS.env_name.startswith('procgen'):
        # disable entropy reg for discrete spaces
        action_dim = env.action_spec().maximum.item() + 1
    else:
        action_dim = env.action_spec().shape[0]
    if 'cql' in FLAGS.algo_name:
        model = cql.CQL(env.observation_spec(),
                        env.action_spec(),
                        reg=FLAGS.f_reg,
                        target_entropy=-action_dim,
                        num_augmentations=FLAGS.num_data_augs,
                        rep_learn_keywords=FLAGS.rep_learn_keywords,
                        batch_size=FLAGS.batch_size)
    elif 'bcq' in FLAGS.algo_name:
        model = bcq.BCQ(env.observation_spec(),
                        env.action_spec(),
                        num_augmentations=FLAGS.num_data_augs)
    elif 'fbrac' in FLAGS.algo_name:
        model = fisher_brac.FBRAC(env.observation_spec(),
                                  env.action_spec(),
                                  target_entropy=-action_dim,
                                  f_reg=FLAGS.f_reg,
                                  reward_bonus=FLAGS.reward_bonus,
                                  num_augmentations=FLAGS.num_data_augs,
                                  env_name=FLAGS.env_name,
                                  batch_size=FLAGS.batch_size)
    elif 'ours' in FLAGS.algo_name:
        model = ours.OURS(env.observation_spec(),
                          env.action_spec(),
                          target_entropy=-action_dim,
                          f_reg=FLAGS.f_reg,
                          reward_bonus=FLAGS.reward_bonus,
                          num_augmentations=FLAGS.num_data_augs,
                          env_name=FLAGS.env_name,
                          rep_learn_keywords=FLAGS.rep_learn_keywords,
                          batch_size=FLAGS.batch_size,
                          n_quantiles=FLAGS.n_quantiles,
                          temp=FLAGS.temp,
                          num_training_levels=train_levels)
        bc_pretraining_steps = FLAGS.pretrain
        if pretrain:
            model_save_path = os.path.join(FLAGS.save_dir, 'weights',
                                           hparam_str)
            checkpoint = tf.train.Checkpoint(**model.model_dict)
            tf_step_counter = tf.Variable(0, dtype=tf.int32)
            manager = tf.train.CheckpointManager(
                checkpoint,
                directory=model_save_path,
                max_to_keep=1,
                checkpoint_interval=FLAGS.save_interval,
                step_counter=tf_step_counter)

            # Load the checkpoint in case it exists
            state = manager.restore_or_initialize()
            if state is not None:
                # loaded variables from checkpoint folder
                timesteps_already_done = int(
                    re.findall('ckpt-([0-9]*)',
                               state)[0])  #* FLAGS.save_interval
                print('Loaded model from timestep %d' % timesteps_already_done)
            else:
                print('Training from scratch')
                timesteps_already_done = 0

            tf_step_counter.assign(timesteps_already_done)

            print('Pretraining')
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.update_step(dataset_iter,
                                              train_target='encoder')
                # (quantile_states, quantile_bins)
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'pretrain/{k}', v, step=i)

                tf_step_counter.assign(i)
                manager.save(checkpoint_number=i)
    elif 'bc' in FLAGS.algo_name:
        model = bc_pixel.BehavioralCloning(
            env.observation_spec(),
            env.action_spec(),
            mixture=False,
            encoder=None,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            env_name=FLAGS.env_name,
            batch_size=FLAGS.batch_size)
    elif 'deepmdp' in FLAGS.algo_name:
        model = deepmdp.DeepMdpLearner(
            env.observation_spec(),
            env.action_spec(),
            embedding_dim=512,
            num_distributions=1,
            sequence_length=2,
            learning_rate=3e-4,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            batch_size=FLAGS.batch_size)
    elif 'vpn' in FLAGS.algo_name:
        model = vpn.ValuePredictionNetworkLearner(
            env.observation_spec(),
            env.action_spec(),
            embedding_dim=512,
            learning_rate=3e-4,
            num_augmentations=FLAGS.num_data_augs,
            rep_learn_keywords=FLAGS.rep_learn_keywords,
            batch_size=FLAGS.batch_size)
    elif 'cssc' in FLAGS.algo_name:
        model = cssc.CSSC(env.observation_spec(),
                          env.action_spec(),
                          embedding_dim=512,
                          actor_lr=3e-4,
                          critic_lr=3e-4,
                          num_augmentations=FLAGS.num_data_augs,
                          rep_learn_keywords=FLAGS.rep_learn_keywords,
                          batch_size=FLAGS.batch_size)
    elif 'pse' in FLAGS.algo_name:
        model = pse.PSE(env.observation_spec(),
                        env.action_spec(),
                        embedding_dim=512,
                        actor_lr=3e-4,
                        critic_lr=3e-4,
                        num_augmentations=FLAGS.num_data_augs,
                        rep_learn_keywords=FLAGS.rep_learn_keywords,
                        batch_size=FLAGS.batch_size,
                        temperature=FLAGS.temp)
        bc_pretraining_steps = FLAGS.pretrain
        if pretrain:
            print('Pretraining')
            for i in tqdm.tqdm(range(bc_pretraining_steps)):
                info_dict = model.update_step(dataset_iter,
                                              train_target='encoder')
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'pretrain/{k}', v, step=i)

    if 'fbrac' in FLAGS.algo_name or FLAGS.algo_name == 'bc':
        # Either load the online policy:
        if FLAGS.load_bc and FLAGS.env_name.startswith('procgen'):
            env_id = [
                i for i, name in enumerate(PROCGEN_ENVS) if name == env_name
            ][0] + 1  # Map the env name to an integer id in [1, 16].
            if FLAGS.ckpt_timesteps == 10_000_000:
                ckpt_iter = '0000020480'
            elif FLAGS.ckpt_timesteps == 25_000_000:
                ckpt_iter = '0000051200'
            policy_weights_dir = (
                'ppo_darts/'
                '2021-06-22-16-36-54/%d/policies/checkpoints/'
                'policy_checkpoint_%s/' % (env_id, ckpt_iter))
            policy_def_dir = ('ppo_darts/'
                              '2021-06-22-16-36-54/%d/policies/policy/' %
                              (env_id))
            bc = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
                policy_def_dir,
                time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
                action_spec=env._action_spec,  # pylint: disable=protected-access
                policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
                info_spec=tf.TensorSpec(shape=(None, )),
                load_specs_from_pbtxt=False)
            bc.update_from_checkpoint(policy_weights_dir)
            model.bc.policy = tf_utils.TfAgentsPolicy(bc)
        else:
            if FLAGS.algo_name == 'fbrac':
                bc_pretraining_steps = 100_000
            elif FLAGS.algo_name == 'bc':
                bc_pretraining_steps = 1_000_000

            if 'fbrac' in FLAGS.algo_name:
                bc = model.bc
            else:
                bc = model
            for i in tqdm.tqdm(range(bc_pretraining_steps)):

                info_dict = bc.update_step(dataset_iter)
                if i % FLAGS.log_interval == 0:
                    with summary_writer.as_default():
                        for k, v in info_dict.items():
                            v = tf.reduce_mean(v)
                            tf.summary.scalar(f'bc/{k}', v, step=i)

                if FLAGS.algo_name == 'bc':
                    if (i + 1) % FLAGS.eval_interval == 0:
                        average_returns, average_length = evaluation.evaluate(
                            env, bc)  # (FLAGS.env_name.startswith('procgen'))
                        average_returns_all, average_length_all = evaluation.evaluate(
                            env_all, bc)

                        with result_writer.as_default():
                            tf.summary.scalar('evaluation/returns',
                                              average_returns,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/length',
                                              average_length,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/returns-all',
                                              average_returns_all,
                                              step=i + 1)
                            tf.summary.scalar('evaluation/length-all',
                                              average_length_all,
                                              step=i + 1)

    if FLAGS.algo_name == 'bc':
        exit()

    if not (FLAGS.algo_name == 'ours' and pretrain):
        model_save_path = os.path.join(FLAGS.save_dir, 'weights', hparam_str)
        checkpoint = tf.train.Checkpoint(**model.model_dict)
        tf_step_counter = tf.Variable(0, dtype=tf.int32)
        manager = tf.train.CheckpointManager(
            checkpoint,
            directory=model_save_path,
            max_to_keep=1,
            checkpoint_interval=FLAGS.save_interval,
            step_counter=tf_step_counter)

        # Load the checkpoint in case it exists
        weights_path = tf.io.gfile.glob(model_save_path + '/ckpt-*.index')
        key_fn = lambda x: int(re.findall(r'(\d+)', x)[-1])
        weights_path.sort(key=key_fn)
        if weights_path:
            weights_path = weights_path[-1]  # take most recent
        state = manager.restore_or_initialize()  # restore(weights_path)
        if state is not None:
            # loaded variables from checkpoint folder
            timesteps_already_done = int(
                re.findall('ckpt-([0-9]*)', state)[0])  #* FLAGS.save_interval
            print('Loaded model from timestep %d' % timesteps_already_done)
        else:
            print('Training from scratch')
            timesteps_already_done = 0

    tf_step_counter.assign(timesteps_already_done)

    for i in tqdm.tqdm(range(timesteps_already_done, FLAGS.num_updates)):
        with summary_writer.as_default():
            info_dict = model.update_step(
                dataset_iter, train_target='rl' if pretrain else 'both')
        if i % FLAGS.log_interval == 0:
            with summary_writer.as_default():
                for k, v in info_dict.items():
                    v = tf.reduce_mean(v)
                    tf.summary.scalar(f'training/{k}', v, step=i)

        if (i + 1) % FLAGS.eval_interval == 0:
            average_returns, average_length = evaluation.evaluate(env, model)
            average_returns_all, average_length_all = evaluation.evaluate(
                env_all, model)

            with result_writer.as_default():
                tf.summary.scalar('evaluation/returns-200',
                                  average_returns,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length-200',
                                  average_length,
                                  step=i + 1)
                tf.summary.scalar('evaluation/returns-all',
                                  average_returns_all,
                                  step=i + 1)
                tf.summary.scalar('evaluation/length-all',
                                  average_length_all,
                                  step=i + 1)

        tf_step_counter.assign(i)
        manager.save(checkpoint_number=i)
Example #15
def collect(task,
            root_dir,
            replay_buffer_server_address,
            variable_container_server_address,
            create_env_fn,
            initial_collect_steps=10000,
            num_iterations=10000000):
  """Collects experience using a policy updated after every episode."""
  # Create the environment. For now, only single-environment collection is supported.
  collect_env = create_env_fn()

  # Create the path for the serialized collect policy.
  collect_policy_saved_model_path = os.path.join(
      root_dir, learner.POLICY_SAVED_MODEL_DIR,
      learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  saved_model_pb_path = os.path.join(collect_policy_saved_model_path,
                                     'saved_model.pb')
  try:
    # Wait for the collect policy to be output by the learner (timeout after 2
    # days), then load it.
    train_utils.wait_for_file(
        saved_model_pb_path, sleep_time_secs=2, num_retries=86400)
    collect_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        collect_policy_saved_model_path, load_specs_from_pbtxt=True)
  except TimeoutError as e:
    # If the collect policy does not become available during the wait time of
    # the call `wait_for_file`, that probably means the learner is not running.
    logging.error('Could not get the file %s. Exiting.', saved_model_pb_path)
    raise e

  # Create the variable container.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container = reverb_variable_container.ReverbVariableContainer(
      variable_container_server_address,
      table_names=[reverb_variable_container.DEFAULT_TABLE])
  variable_container.update(variables)

  # Create the replay buffer observer.
  rb_observer = reverb_utils.ReverbAddTrajectoryObserver(
      reverb.Client(replay_buffer_server_address),
      table_name=reverb_replay_buffer.DEFAULT_TABLE,
      sequence_length=2,
      stride_length=1)

  random_policy = random_py_policy.RandomPyPolicy(
      collect_env.time_step_spec(), collect_env.action_spec())
  initial_collect_actor = actor.Actor(
      collect_env,
      random_policy,
      train_step,
      steps_per_run=initial_collect_steps,
      observers=[rb_observer])
  logging.info('Doing initial collect.')
  initial_collect_actor.run()

  env_step_metric = py_metrics.EnvironmentSteps()
  collect_actor = actor.Actor(
      collect_env,
      collect_policy,
      train_step,
      steps_per_run=1,
      metrics=actor.collect_metrics(10),
      summary_dir=os.path.join(root_dir, learner.TRAIN_DIR, str(task)),
      observers=[rb_observer, env_step_metric])

  # Run the experience collection loop.
  for _ in range(num_iterations):
    logging.info('Collecting with policy at step: %d', train_step.numpy())
    collect_actor.run()
    variable_container.update(variables)
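
A hedged sketch of launching the collector above; the root directory, Reverb server addresses, and the CartPole environment are placeholders for whatever the surrounding training job provides.

from tf_agents.environments import suite_gym

collect(
    task=0,
    root_dir='/tmp/root_dir',                            # placeholder
    replay_buffer_server_address='localhost:8008',       # placeholder Reverb server
    variable_container_server_address='localhost:8008',  # placeholder Reverb server
    create_env_fn=lambda: suite_gym.load('CartPole-v0'),
    initial_collect_steps=1000,
    num_iterations=100000)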
Example #16
    def __init__(self, env, policy_dir="./predictor/models/policy"):
        self.policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
            policy_dir, env.time_step_spec(), env.action_spec())

        self.env = env
        self.time_step = self.env.reset()
Example #17
  def load_model(checkpoint):
    checkpoint = int(checkpoint)
    print(checkpoint)
    if FLAGS.env_name.startswith('procgen'):
      env_id = [i for i, name in enumerate(
          PROCGEN_ENVS) if name == env_name][0]+1
      if checkpoint == 10_000_000:
        ckpt_iter = '0000020480'
      elif checkpoint == 15_000_000:
        ckpt_iter = '0000030720'
      elif checkpoint == 20_000_000:
        ckpt_iter = '0000040960'
      elif checkpoint == 25_000_000:
        ckpt_iter = '0000051200'
      policy_weights_dir = ('ppo_darts/'
                            '2021-06-22-16-36-54/%d/policies/checkpoints/'
                            'policy_checkpoint_%s/' % (env_id, ckpt_iter))
      policy_def_dir = ('ppo_darts/'
                        '2021-06-22-16-36-54/%d/policies/policy/' % (env_id))
      model = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
          policy_def_dir,
          time_step_spec=env._time_step_spec,  # pylint: disable=protected-access
          action_spec=env._action_spec,  # pylint: disable=protected-access
          policy_state_spec=env._observation_spec,  # pylint: disable=protected-access
          info_spec=tf.TensorSpec(shape=(None,)),
          load_specs_from_pbtxt=False)
      model.update_from_checkpoint(policy_weights_dir)
      model.actor = model.action
    else:
      if 'ddpg' in FLAGS.algo_name:
        model = ddpg.DDPG(
            env.observation_spec(),
            env.action_spec(),
            cross_norm='crossnorm' in FLAGS.algo_name)
      elif 'crr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='bin_max')
      elif 'awr' in FLAGS.algo_name:
        model = awr.AWR(
            env.observation_spec(),
            env.action_spec(), f='exp_mean')
      elif 'sac_v1' in FLAGS.algo_name:
        model = sac_v1.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'asac' in FLAGS.algo_name:
        model = asac.ASAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      elif 'sac' in FLAGS.algo_name:
        model = sac.SAC(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0],
            cross_norm='crossnorm' in FLAGS.algo_name,
            pcl_actor_update='pc' in FLAGS.algo_name)
      elif 'pcl' in FLAGS.algo_name:
        model = pcl.PCL(
            env.observation_spec(),
            env.action_spec(),
            target_entropy=-env.action_spec().shape[0])
      if 'distractor' in FLAGS.env_name:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210622_2023.policy_weights_sac_1M_dmc_distractor_hard_pixel/'),
            'results', FLAGS.env_name+'__'+str(checkpoint))
      else:
        ckpt_path = os.path.join(
            ('experiments/'
             '20210607_2023.policy_weights_dmc_1M_SAC_pixel'), 'results',
            FLAGS.env_name + '__' + str(checkpoint))

      model.load_weights(ckpt_path)
    print('Loaded model weights')
    return model
Example #18
def main(_):
  logging.set_verbosity(logging.INFO)

  # Create the path for the serialized collect policy.
  collect_policy_saved_model_path = os.path.join(
      FLAGS.root_dir, learner.POLICY_SAVED_MODEL_DIR,
      learner.COLLECT_POLICY_SAVED_MODEL_DIR)
  saved_model_pb_path = os.path.join(collect_policy_saved_model_path,
                                     'saved_model.pb')

  samples_per_insert = FLAGS.samples_per_insert
  min_table_size_before_sampling = FLAGS.min_table_size_before_sampling

  try:
    # Wait for the collect policy to be output by the learner (timeout after 2
    # days), then load it.
    train_utils.wait_for_file(
        saved_model_pb_path, sleep_time_secs=2, num_retries=86400)
    collect_policy = py_tf_eager_policy.SavedModelPyTFEagerPolicy(
        collect_policy_saved_model_path, load_specs_from_pbtxt=True)
  except TimeoutError as e:
    # If the collect policy does not become available during the wait time of
    # the call `wait_for_file`, that probably means the learner is not running.
    logging.error('Could not get the file %s. Exiting.', saved_model_pb_path)
    raise e

  # Create the signature for the variable container holding the policy weights.
  train_step = train_utils.create_train_step()
  variables = {
      reverb_variable_container.POLICY_KEY: collect_policy.variables(),
      reverb_variable_container.TRAIN_STEP_KEY: train_step
  }
  variable_container_signature = tf.nest.map_structure(
      lambda variable: tf.TensorSpec(variable.shape, dtype=variable.dtype),
      variables)
  logging.info('Signature of variables: \n%s', variable_container_signature)

  # Create the signature for the replay buffer holding observed experience.
  replay_buffer_signature = tensor_spec.from_spec(
      collect_policy.collect_data_spec)
  logging.info('Signature of experience: \n%s', replay_buffer_signature)

  if samples_per_insert is not None:
    # Use SamplesPerInsertRatio limiter
    samples_per_insert_tolerance = _SAMPLES_PER_INSERT_TOLERANCE_RATIO * samples_per_insert
    error_buffer = min_table_size_before_sampling * samples_per_insert_tolerance

    experience_rate_limiter = reverb.rate_limiters.SampleToInsertRatio(
        min_size_to_sample=min_table_size_before_sampling,
        samples_per_insert=samples_per_insert,
        error_buffer=error_buffer)
  else:
    # Use MinSize limiter
    experience_rate_limiter = reverb.rate_limiters.MinSize(
        min_table_size_before_sampling)

  # Create and start the replay buffer and variable container server.
  server = reverb.Server(
      tables=[
          reverb.Table(  # Replay buffer storing experience.
              name=reverb_replay_buffer.DEFAULT_TABLE,
              sampler=reverb.selectors.Uniform(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=experience_rate_limiter,
              max_size=FLAGS.replay_buffer_capacity,
              max_times_sampled=0,
              signature=replay_buffer_signature,
          ),
          reverb.Table(  # Variable container storing policy parameters.
              name=reverb_variable_container.DEFAULT_TABLE,
              sampler=reverb.selectors.Uniform(),
              remover=reverb.selectors.Fifo(),
              rate_limiter=reverb.rate_limiters.MinSize(1),
              max_size=1,
              max_times_sampled=0,
              signature=variable_container_signature,
          ),
      ],
      port=FLAGS.port)
  server.wait()