Example 1
    def test_stack_sequence_fields(self):
        sequence = [{
            'action': np.array([1.0]),
            'observation': (np.array([0.0, 1.0, 2.0]), ),
            'reward': np.array(1.0),
        }, {
            'action': np.array([0.5]),
            'observation': (np.array([1.0, 2.0, 3.0]), ),
            'reward': np.array(0.0),
        }, {
            'action': np.array([0.3]),
            'observation': (np.array([2.0, 3.0, 4.0]), ),
            'reward': np.array(0.5),
        }]

        stacked = tf2_utils.stack_sequence_fields(sequence)

        self.assertIsInstance(stacked, dict)
        self.assertLen(stacked.keys(), 3)
        self.assertLen(stacked['observation'], 1)

        self.assertEqual(stacked['action'].shape, (3, 1))
        self.assertEqual(stacked['observation'][0].shape, (3, 3))
        self.assertEqual(stacked['reward'].shape, (3, ))
        self.assertEqual(stacked['observation'][0].tolist(),
                         [[0., 1., 2.], [1., 2., 3.], [2., 3., 4.]])
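For reference, the stacking behaviour this test asserts can be reproduced with a few lines of dm-tree and NumPy. The helper below, stack_nests, is a hypothetical stand-in used only to illustrate the semantics; it is not acme's implementation of tf2_utils.stack_sequence_fields.

import numpy as np
import tree


def stack_nests(sequence):
    # Stack corresponding leaves of each nest along a new leading (time) axis,
    # preserving the dict/tuple structure of the individual steps.
    return tree.map_structure(lambda *leaves: np.stack(leaves), *sequence)


sequence = [
    {'action': np.array([1.0]), 'observation': (np.array([0.0, 1.0, 2.0]),)},
    {'action': np.array([0.5]), 'observation': (np.array([1.0, 2.0, 3.0]),)},
]
stacked = stack_nests(sequence)
print(stacked['action'].shape)          # (2, 1)
print(stacked['observation'][0].shape)  # (2, 3)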
Example 2
from acme.tf import utils as tf2_utils
import numpy as np
import tensorflow as tf
import tree

# Note: cal_mse (the Bellman residual MSE helper) is assumed to be defined
# elsewhere in this module.


def ope_evaluation(value_func,
                   policy_net,
                   environment,
                   num_init_samples,
                   mse_samples=0,
                   discount=0.99,
                   counter=None,
                   logger=None):
    """Run OPE evaluation."""
    mse = -1
    if mse_samples > 0:
        mse = cal_mse(value_func, policy_net, environment, mse_samples,
                      discount)

    # Compute policy value from the initial state distribution with a single
    # batched call to the policy and value networks.

    init_obs = []
    for _ in range(num_init_samples):
        timestep = environment.reset()
        init_obs.append(timestep.observation)
    init_obs = tf2_utils.stack_sequence_fields(init_obs)
    init_obs = tree.map_structure(tf.convert_to_tensor, init_obs)
    init_actions = policy_net(init_obs)
    q_0s = value_func(init_obs, init_actions).numpy().squeeze()

    results = {
        'Bellman_Residual_MSE': mse,
        'Q0_mean': np.mean(q_0s),
        'Q0_std_err': np.std(q_0s, ddof=0) / np.sqrt(len(q_0s)),
    }
    if counter is not None:
        counts = counter.increment(steps=1)
        results.update(counts)
    if logger is not None:
        logger.write(results)
    return results
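The important pattern above is that the initial observations are stacked into one batch and converted to tensors, so policy_net and value_func are each called once on the whole batch instead of once per sample. Below is a minimal sketch of that batching step, assuming only NumPy, TensorFlow and dm-tree; the Dense layer and the random observations are toy stand-ins, not part of the snippet above.

import numpy as np
import tensorflow as tf
import tree

# Pretend these came from environment.reset() over num_init_samples episodes.
init_obs = [np.random.randn(3).astype(np.float32) for _ in range(4)]

# Stack along a new leading axis (the same thing tf2_utils.stack_sequence_fields
# does for arbitrary nests) and convert the result to tensors.
batch = tree.map_structure(lambda *obs: np.stack(obs), *init_obs)  # shape (4, 3)
batch = tree.map_structure(tf.convert_to_tensor, batch)

toy_value_net = tf.keras.layers.Dense(1)       # stand-in for value_func
q_0s = toy_value_net(batch).numpy().squeeze()  # shape (4,)
print(np.mean(q_0s), np.std(q_0s, ddof=0) / np.sqrt(len(q_0s)))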
Example 3
from typing import Dict, Sequence, Union

from acme.adders.reverb import base
from acme.tf import utils as tf2_utils


def calculate_priorities(
    priority_fns: base.PriorityFnMapping,
    steps: Union[base.Step, Sequence[base.Step]]) -> Dict[str, float]:
  """Helper used to calculate the priority of a sequence of steps.

  This converts the sequence of steps into a PriorityFnInput tuple where the
  components of each step (actions, observations, etc.) are stacked along the
  time dimension.

  Priorities are calculated for the sequence or transition that starts from
  step[0].next_observation. As a result, the stack of observations comes from
  steps[0:] whereas all other components (e.g. actions, rewards, discounts,
  extras) correspond to steps[1:].

  Note: this means that all components of step[0] other than the observation
  are ignored. It also means that step[0] may correspond to an "initial step",
  in which case the action, reward, discount, and extras are each None; this
  case is handled correctly by this function.

  Args:
    priority_fns: a mapping from table names to priority functions (i.e. a
      callable of type PriorityFn). The given function will be used to generate
      the priority (a float) for the given table.
    steps: a list of Step objects used to compute the priorities.

  Returns:
    A dictionary mapping from table names to the priority (a float) for the
    given collection of steps.
  """

  if isinstance(steps, list):
    steps = tf2_utils.stack_sequence_fields(steps)

  if any(priority_fn is not None for priority_fn in priority_fns.values()):
    # Stack the steps and wrap them as PriorityFnInput.
    fn_input = base.PriorityFnInput(*steps)

  return {
      table: (priority_fn(fn_input) if priority_fn else 1.0)
      for table, priority_fn in priority_fns.items()
  }
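To see how the final dictionary comprehension behaves, here is a small usage sketch: tables whose priority function is None fall back to a constant priority of 1.0, while the others are evaluated on the stacked-step input. PriorityInput and both priority functions are hypothetical stand-ins for illustration; they are not acme's base.PriorityFnInput or any real table configuration.

import collections

import numpy as np

PriorityInput = collections.namedtuple('PriorityInput',
                                       ['observations', 'rewards'])

fn_input = PriorityInput(
    observations=np.zeros((3, 4)),   # observations stacked along the time axis
    rewards=np.array([0.0, 1.0, -2.0]))

priority_fns = {
    'uniform_table': None,           # no priority function -> priority 1.0
    'reward_table': lambda inp: float(np.max(np.abs(inp.rewards))),
}

priorities = {
    table: (priority_fn(fn_input) if priority_fn else 1.0)
    for table, priority_fn in priority_fns.items()
}
print(priorities)  # {'uniform_table': 1.0, 'reward_table': 2.0}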