def compute_probs(self, trajs, env_dict):
    """Compute the probability of the trajs for contexts provided.

    Args:
      trajs: sequence of Traj tuples; each `env_name` keys into `env_dict`
        and `actions` holds the integer action sequence taken.
      env_dict: mapping from environment name to environment object; each
        environment exposes a `context` sequence.

    Returns:
      Numpy array with one probability per trajectory.
    """
    # Contexts are shifted by one because zero is reserved for padding.
    raw_contexts = [env_dict[traj.env_name].context + 1 for traj in trajs]
    padded_contexts, ctx_lengths, _ = pad_sequences(raw_contexts, 0)
    ctx_batch = tf.stack(padded_contexts, axis=0)
    ctx_lengths = np.array(ctx_lengths, dtype=np.int32)

    padded_actions, action_lengths, maxlen = pad_sequences(
        [traj.actions for traj in trajs], 0)
    action_batch = tf.stack(padded_actions, axis=0)

    logits = self.compute_logits(
        ctx_batch, maxlen, ctx_lengths, return_state=False)
    # Per-timestep negative log-probability of the action actually taken.
    step_neg_logprobs = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=action_batch)
    # Mask out the padded timesteps before summing over time.
    mask = tf.sequence_mask(action_lengths, dtype=step_neg_logprobs.dtype)
    seq_logprobs = -tf.reduce_sum(step_neg_logprobs * mask, axis=-1)
    return tf.exp(seq_logprobs).numpy()
def create_batch(self, samples, contexts=None):
    """Helper method for creating batches of data.

    Args:
      samples: sequence of sample objects, each exposing `.traj` (a Traj with
        `actions` and `rewards`) and a scalar `.prob` weight.
      contexts: list of context sequences, one per sample. Required; `None`
        raises.

    Returns:
      Tuple `(batch_actions, batch_rews, batch_weights, kwargs)` where
      `kwargs` carries `sequence_length`, padded `contexts`, and
      `context_lengths` for recurrent policies.

    Raises:
      ValueError: if `contexts` is None.
    """
    # Fail fast before doing any padding work.
    if contexts is None:
        raise ValueError('No Contexts passed.')
    kwargs = {}
    trajs = [s.traj for s in samples]
    # Padding required for recurrent policies.
    batch_actions, sequence_length, maxlen = pad_sequences(
        [t.actions for t in trajs], 0)
    # Discounted rewards are padded to the same maxlen as the actions.
    batch_rews, _, _ = pad_sequences(
        [self._discount_rewards(t.rewards) for t in trajs], 0, maxlen)
    # Each sample's probability is broadcast across all timesteps.
    batch_weights = [[s.prob] * maxlen for s in samples]
    kwargs.update(sequence_length=np.array(sequence_length, dtype=np.int32))
    contexts, context_lengths, _ = pad_sequences(contexts, 0)
    contexts = tf.stack(contexts, axis=0)
    context_lengths = np.array(context_lengths, dtype=np.int32)
    kwargs.update(contexts=contexts, context_lengths=context_lengths)
    batch_rews = np.array(batch_rews, dtype=np.float32)
    batch_actions = np.array(batch_actions, dtype=np.int32)  # [batch_size]
    return batch_actions, batch_rews, batch_weights, kwargs
def sample_trajs(self, envs, greedy=False):
    """Roll out one trajectory per environment using the recurrent policy.

    Environments are stepped in lockstep; as each reports `done` it is
    dropped from the batch and the corresponding recurrent state and
    encoder output rows are pruned.

    Args:
      envs: environments exposing `name`, `context`, `reset()` and
        `step(action) -> (reward, done)`.
      greedy: if True, take the argmax action instead of sampling.

    Returns:
      List of Traj tuples, one per environment.
    """
    env_names = [env.name for env in envs]
    env_dict = {env.name: env for env in envs}
    for env in envs:
        env.reset()
    # Per-environment accumulators for rewards and actions.
    rews = {env.name: [] for env in envs}
    actions = {env.name: [] for env in envs}
    kwargs = {'return_state': True}
    contexts = [env.context + 1 for env in envs]  # Zero is not a valid index
    contexts, context_lengths, _ = pad_sequences(contexts, 0)
    contexts = tf.stack(contexts, axis=0)
    context_lengths = np.array(context_lengths, dtype=np.int32)
    # Encode contexts once; the encoder output is reused at every step.
    encoded_context, initial_state = self.pi.encode_context(
        contexts, context_lengths=context_lengths)
    kwargs.update(enc_output=encoded_context)
    # pylint: disable=protected-access
    model_fn = self.pi._call
    # pylint: enable=protected-access
    while env_names:
        # One decode step for every still-active environment.
        logprobs, next_state = model_fn(
            num_inputs=1, initial_state=initial_state, **kwargs)
        logprobs = logprobs[:, 0]
        acs = self._sample_action(logprobs, greedy=greedy)
        dones = []
        new_env_names = []
        # `acs` rows align with `env_names` order — TODO confirm against
        # _sample_action; the pruning below depends on this alignment.
        for ac, name in zip(acs, env_names):
            rew, done = env_dict[name].step(ac)
            actions[name].append(ac)
            rews[name].append(rew)
            dones.append(done)
            if not done:
                new_env_names.append(name)
        env_names = new_env_names
        if env_names:
            # Remove the states of `done` environments from recurrent_state only if
            # at least one of them is not `done`
            if isinstance(next_state, tf.Tensor):
                next_state = [
                    next_state[i] for i, done in enumerate(dones) if not done
                ]
                initial_state = tf.stack(next_state, axis=0)
            else:  # LSTM Tuple
                # Tuple-shaped recurrent states are not supported yet.
                raise NotImplementedError
            # Prune the encoder output rows for finished environments too,
            # keeping it aligned with the surviving recurrent state.
            enc_output = kwargs['enc_output']
            enc_output = [enc_output[i] for i, done in enumerate(dones) if not done]
            kwargs['enc_output'] = tf.stack(enc_output, axis=0)
    # NOTE(review): a set has no guaranteed order; `features` and the zip
    # below iterate the same set object so they stay aligned with each other.
    env_names = set([env.name for env in envs])
    features = [
        create_joint_features(actions[name], env_dict[name].context)
        for name in env_names
    ]
    # pylint: disable=g-complex-comprehension
    trajs = [
        Traj(
            env_name=name,
            actions=actions[name],
            features=f,
            rewards=rews[name]) for name, f in zip(env_names, features)
    ]
    # pylint: enable=g-complex-comprehension
    return trajs