Beispiel #1
0
  def compute_probs(self, trajs, env_dict):
    """Compute the probability of the trajs for contexts provided."""
    # Shift every context index up by one so zero is free to act as padding.
    padded_ctx, ctx_lens, _ = pad_sequences(
        [env_dict[t.env_name].context + 1 for t in trajs], 0)
    ctx_tensor = tf.stack(padded_ctx, axis=0)
    ctx_lens = np.array(ctx_lens, dtype=np.int32)

    padded_acts, seq_lens, maxlen = pad_sequences(
        [t.actions for t in trajs], 0)
    action_tensor = tf.stack(padded_acts, axis=0)

    logits = self.compute_logits(
        ctx_tensor, maxlen, ctx_lens, return_state=False)
    # Per-timestep negative log-likelihood of the actions actually taken.
    step_nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=action_tensor)
    # Mask out padded timesteps before summing over the sequence axis.
    valid_mask = tf.sequence_mask(seq_lens, dtype=step_nll.dtype)
    seq_logprob = -tf.reduce_sum(step_nll * valid_mask, axis=-1)
    return tf.exp(seq_logprob).numpy()
Beispiel #2
0
 def create_batch(self, samples, contexts=None):
   """Helper method for creating batches of data.

   Args:
     samples: Iterable of samples, each exposing a `.traj` trajectory
       (with `.actions` and `.rewards`) and a scalar `.prob` weight.
     contexts: Sequence of per-sample context index sequences. Required.

   Returns:
     A tuple `(batch_actions, batch_rews, batch_weights, kwargs)` where
     `batch_actions` is int32 `[batch_size, maxlen]`, `batch_rews` is
     float32 `[batch_size, maxlen]`, `batch_weights` is a list of per-step
     weight lists, and `kwargs` carries `sequence_length`, `contexts` and
     `context_lengths` for the policy call.

   Raises:
     ValueError: If `contexts` is None.
   """
   # Fail fast before any padding work is done.
   if contexts is None:
     raise ValueError('No Contexts passed.')
   kwargs = {}
   # Padding required for recurrent policies
   trajs = [s.traj for s in samples]
   batch_actions, sequence_length, maxlen = pad_sequences(
       [t.actions for t in trajs], 0)
   batch_rews, _, _ = pad_sequences(
       [self._discount_rewards(t.rewards) for t in trajs], 0, maxlen)
   # Each sample's probability is broadcast across all of its timesteps.
   batch_weights = [[s.prob] * maxlen for s in samples]
   kwargs.update(sequence_length=np.array(sequence_length, dtype=np.int32))
   contexts, context_lengths, _ = pad_sequences(contexts, 0)
   contexts = tf.stack(contexts, axis=0)
   context_lengths = np.array(context_lengths, dtype=np.int32)
   kwargs.update(contexts=contexts, context_lengths=context_lengths)
   batch_rews = np.array(batch_rews, dtype=np.float32)
   batch_actions = np.array(batch_actions, dtype=np.int32)  # [batch_size, maxlen]
   return batch_actions, batch_rews, batch_weights, kwargs
Beispiel #3
0
  def sample_trajs(self, envs, greedy=False):
    """Roll out the policy in all `envs` in lockstep and collect trajectories.

    Args:
      envs: Environments to sample from. Each must expose `name`, `context`,
        `reset()`, and `step(action)` returning `(reward, done)` — assumed
        from usage below; confirm against the env interface.
      greedy: If True, `_sample_action` picks actions greedily instead of
        sampling.

    Returns:
      A list of `Traj` records, one per environment, holding the env name,
      its action sequence, joint features, and per-step rewards.
    """
    env_names = [env.name for env in envs]
    env_dict = {env.name: env for env in envs}

    for env in envs:
      env.reset()
    # Per-environment accumulators, keyed by env name.
    rews = {env.name: [] for env in envs}
    actions = {env.name: [] for env in envs}

    kwargs = {'return_state': True}
    contexts = [env.context + 1 for env in envs]  # Zero is not a valid index
    contexts, context_lengths, _ = pad_sequences(contexts, 0)
    contexts = tf.stack(contexts, axis=0)
    context_lengths = np.array(context_lengths, dtype=np.int32)
    # Encode the contexts once; the rollout loop reuses the encoder output.
    encoded_context, initial_state = self.pi.encode_context(
        contexts, context_lengths=context_lengths)
    kwargs.update(enc_output=encoded_context)
    # pylint: disable=protected-access
    model_fn = self.pi._call
    # pylint: enable=protected-access

    # Step all still-active environments one action at a time; finished
    # environments are dropped from the batch at the end of each iteration.
    while env_names:
      logprobs, next_state = model_fn(
          num_inputs=1, initial_state=initial_state, **kwargs)
      # Only one timestep was decoded; drop the time dimension.
      logprobs = logprobs[:, 0]
      acs = self._sample_action(logprobs, greedy=greedy)
      dones = []
      new_env_names = []
      # `acs` rows are aligned with the current `env_names` ordering.
      for ac, name in zip(acs, env_names):
        rew, done = env_dict[name].step(ac)
        actions[name].append(ac)
        rews[name].append(rew)
        dones.append(done)
        if not done:
          new_env_names.append(name)
      env_names = new_env_names

      if env_names:
        # Remove the states of `done` environments from recurrent_state only if
        # at least one of them is not `done`
        if isinstance(next_state, tf.Tensor):
          next_state = [
              next_state[i] for i, done in enumerate(dones) if not done
          ]
          initial_state = tf.stack(next_state, axis=0)
        else:  # LSTM Tuple
          raise NotImplementedError

        # Keep the encoder output rows aligned with the surviving envs.
        enc_output = kwargs['enc_output']
        enc_output = [enc_output[i] for i, done in enumerate(dones) if not done]
        kwargs['enc_output'] = tf.stack(enc_output, axis=0)

    # Rebuild the full name collection; set iteration order is arbitrary but
    # stable within this call, so `features` and `trajs` below stay aligned.
    env_names = set([env.name for env in envs])
    features = [
        create_joint_features(actions[name], env_dict[name].context)
        for name in env_names
    ]
    # pylint: disable=g-complex-comprehension
    trajs = [
        Traj(
            env_name=name,
            actions=actions[name],
            features=f,
            rewards=rews[name]) for name, f in zip(env_names, features)
    ]
    # pylint: enable=g-complex-comprehension
    return trajs