Example 1
 def sac_policy_loss(
     self,
     log_probs: torch.Tensor,
     q1p_outs: Dict[str, torch.Tensor],
     loss_masks: torch.Tensor,
     discrete: bool,
 ) -> torch.Tensor:
     _ent_coef = torch.exp(self._log_ent_coef)
     mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
     if not discrete:
         mean_q1 = mean_q1.unsqueeze(1)
         batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
         policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
     else:
         action_probs = log_probs.exp()
         branched_per_action_ent = ModelUtils.break_into_branches(
             log_probs * action_probs, self.act_size
         )
         branched_q_term = ModelUtils.break_into_branches(
             mean_q1 * action_probs, self.act_size
         )
         branched_policy_loss = torch.stack(
             [
                 torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True)
                 for i, (_lp, _qt) in enumerate(
                     zip(branched_per_action_ent, branched_q_term)
                 )
             ],
             dim=1,
         )
         batch_policy_loss = torch.squeeze(branched_policy_loss)
         policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
     return policy_loss
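For reference, the continuous branch above is the standard SAC policy objective, averaged over the Q heads in q1p_outs and masked over padded timesteps:

    J_\pi(\theta) = \mathbb{E}_{a \sim \pi_\theta}\left[ \alpha \log \pi_\theta(a \mid s) - Q(s, a) \right]

where \alpha = \exp(\text{self.\_log\_ent\_coef}) is the entropy coefficient. The discrete branch computes the same quantity per action branch, taking the expectation over each branch's action probabilities.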
Example 2
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        if not discrete:
            with torch.no_grad():
                target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
            entropy_loss = -1 * ModelUtils.masked_mean(
                self._log_ent_coef * target_current_diff, loss_masks
            )
        else:
            with torch.no_grad():
                branched_per_action_ent = ModelUtils.break_into_branches(
                    log_probs * log_probs.exp(), self.act_size
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, axis=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy
                        )
                    ],
                    axis=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, axis=2
                )
            entropy_loss = -1 * ModelUtils.masked_mean(
                torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
            )

        return entropy_loss
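For reference, this implements SAC's automatic entropy-coefficient tuning: the trainable parameter is log α (self._log_ent_coef) and, in the continuous case, the loss is

    J(\alpha) = -\,\mathbb{E}_{a \sim \pi}\left[ \log\alpha \left( \log \pi(a \mid s) + \bar{H} \right) \right]

with \bar{H} the target entropy (self.target_entropy) and the log-probability term summed over action dimensions. That term is computed under torch.no_grad() so only log α receives gradients; the discrete case applies the same loss per action branch and averages over branches.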
Example 3
    def ppo_policy_loss(
        self,
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate PPO policy loss.
        :param advantages: Computed advantages.
        :param log_probs: Log probabilities of the actions under the current policy.
        :param old_log_probs: Log probabilities of the actions under the old policy.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        """
        advantage = advantages.unsqueeze(-1)

        decay_epsilon = self.hyperparameters.epsilon
        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = (
            torch.clamp(r_theta, 1.0 - decay_epsilon, 1.0 + decay_epsilon) *
            advantage)
        policy_loss = -1 * ModelUtils.masked_mean(torch.min(p_opt_a, p_opt_b),
                                                  loss_masks)
        return policy_loss
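For reference, this is the PPO clipped surrogate objective, negated so it can be minimized, with r_t the probability ratio and \hat{A}_t the advantage:

    r_t(\theta) = \exp\left( \log\pi_\theta(a_t \mid s_t) - \log\pi_{\theta_{\text{old}}}(a_t \mid s_t) \right)
    L^{\text{CLIP}}(\theta) = -\,\hat{\mathbb{E}}_t\left[ \min\left( r_t(\theta)\,\hat{A}_t,\ \text{clip}\left(r_t(\theta),\,1-\epsilon,\,1+\epsilon\right)\hat{A}_t \right) \right]

The expectation is the masked mean over non-padded timesteps, and \epsilon is self.hyperparameters.epsilon.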
Example 4
    def forward(self, x_self: torch.Tensor,
                entities: torch.Tensor) -> torch.Tensor:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max \
                    number of elements.")
            num_entities = entities.shape[1]

        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, we want to transpose the entities. This is
            # because ONNX only supports input in NCHW (channel first) format.
            # Barracuda also expects to get data in NCHW.
            entities = torch.transpose(entities, 2,
                                       1).reshape(-1, num_entities,
                                                  self.entity_size)

        if self.self_size > 0:
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            # Concatenate all observations with self
            entities = torch.cat([expanded_self, entities], dim=2)
        # Encode entities
        encoded_entities = self.self_ent_encoder(entities)
        return encoded_entities
Example 5
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> torch.Tensor:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             if num_entities < 0:
                 if exporting_to_onnx.is_exporting():
                     raise UnityTrainerException(
                         "Trying to export an attention mechanism that doesn't have a set max \
                         number of elements.")
                 num_entities = ent.shape[1]
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
     # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     encoded_entities = self.embedding_norm(encoded_entities)
     return encoded_entities
Example 6
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     batch_size = visual_obs.shape[0]
     hidden = self.sequential(visual_obs)
     before_out = hidden.reshape(batch_size, -1)
     return torch.relu(self.dense(before_out))
Example 7
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_mask: Optional[torch.Tensor] = None,
        number_of_keys: int = -1,
        number_of_queries: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size
        # This is to avoid using .size() when possible, as Barracuda does not support it.
        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
        n_k = number_of_keys if number_of_keys != -1 else key.size(1)

        query = self.fc_q(query)  # (b, n_q, h*d)
        key = self.fc_k(key)  # (b, n_k, h*d)
        value = self.fc_v(value)  # (b, n_k, h*d)

        query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
        key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
        value = value.reshape(b, n_k, self.n_heads, self.embedding_size)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb)
        # The next few lines are equivalent to : key.permute([0, 2, 3, 1])
        # This is a hack, ONNX will compress two permute operations and
        # Barracuda will not like seeing `permute([0,2,3,1])`
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (self.embedding_size**
                                        0.5) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb)

        value_attention = value_attention.permute([0, 2, 1,
                                                   3])  # (b, n_q, h, emb)
        value_attention = value_attention.reshape(
            b, n_q, self.n_heads * self.embedding_size)  # (b, n_q, h*emb)

        out = self.fc_out(value_attention)  # (b, n_q, emb)
        return out, att
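Aside from the Barracuda-friendly permute workaround, the forward pass above is standard multi-head scaled dot-product attention:

    \text{Attention}(Q, K, V) = \text{softmax}\!\left( \frac{Q K^\top}{\sqrt{d}} + M \right) V

where d is self.embedding_size and M places NEG_INF at masked key positions (the key_mask branch), so padded keys receive near-zero attention weight.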
Example 8
 def actions_to_onehot(discrete_actions: torch.Tensor,
                       action_size: List[int]) -> List[torch.Tensor]:
     """
     Takes a tensor of discrete actions and turns it into a List of onehot encoding for each
     action.
     :param discrete_actions: Actions in integer form.
     :param action_size: List of branch sizes. Should be of same size as discrete_actions'
     last dimension.
     :return: List of one-hot tensors, one representing each branch.
     """
     onehot_branches = [
         torch.nn.functional.one_hot(_act.T, action_size[i]).float()
         for i, _act in enumerate(discrete_actions.long().T)
     ]
     return onehot_branches
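A minimal usage sketch for the function above, with toy values (the tensors and branch sizes here are illustrative, not taken from the source):

    import torch

    # Batch of 2 agents, two discrete branches with 3 and 2 possible actions.
    discrete_actions = torch.tensor([[0, 1], [2, 0]])      # shape (2, 2)
    onehots = actions_to_onehot(discrete_actions, [3, 2])  # the function above is assumed in scope
    # onehots[0].shape == (2, 3)  -- one-hot over the first branch
    # onehots[1].shape == (2, 2)  -- one-hot over the second branch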
Example 9
    def update(self, vector_input: torch.Tensor) -> None:
        steps_increment = vector_input.size()[0]
        total_new_steps = self.normalization_steps + steps_increment

        input_to_old_mean = vector_input - self.running_mean
        new_mean = self.running_mean + (input_to_old_mean /
                                        total_new_steps).sum(0)

        input_to_new_mean = vector_input - new_mean
        new_variance = self.running_variance + (input_to_new_mean *
                                                input_to_old_mean).sum(0)
        # Update in-place
        self.running_mean.data.copy_(new_mean.data)
        self.running_variance.data.copy_(new_variance.data)
        self.normalization_steps.data.copy_(total_new_steps.data)
Example 10
    def update(self, vector_input: torch.Tensor) -> None:
        with torch.no_grad():
            steps_increment = vector_input.size()[0]
            total_new_steps = self.normalization_steps + steps_increment

            input_to_old_mean = vector_input - self.running_mean
            new_mean: torch.Tensor = self.running_mean + (
                input_to_old_mean / total_new_steps).sum(0)

            input_to_new_mean = vector_input - new_mean
            new_variance = self.running_variance + (input_to_new_mean *
                                                    input_to_old_mean).sum(0)
            # Update references. This is much faster than in-place data update.
            self.running_mean: torch.Tensor = new_mean
            self.running_variance: torch.Tensor = new_variance
            self.normalization_steps: torch.Tensor = total_new_steps
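Both normalizer variants above (Examples 9 and 10) implement the same batched, Welford-style running update; they differ only in whether the buffers are updated in place or rebound under torch.no_grad(). For a batch x_1, ..., x_n and previous count N, mean \mu and accumulator M:

    N' = N + n
    \mu' = \mu + \frac{1}{N'} \sum_{i=1}^{n} \left( x_i - \mu \right)
    M' = M + \sum_{i=1}^{n} \left( x_i - \mu' \right)\left( x_i - \mu \right)

Here running_variance stores the sum of squared deviations M rather than the variance itself; dividing by the step count happens wherever the statistics are consumed, which is not shown in these snippets.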
Example 11
 def _copy_and_remove_nans_from_obs(
         self, all_obs: List[List[torch.Tensor]],
         attention_mask: torch.Tensor) -> List[List[torch.Tensor]]:
     """
     Helper function to remove NaNs from observations using an attention mask.
     """
     obs_with_no_nans = []
     for i_agent, single_agent_obs in enumerate(all_obs):
         no_nan_obs = []
         for obs in single_agent_obs:
             new_obs = obs.clone()
             new_obs[attention_mask.bool()[:, i_agent], ::] = 0.0  # Remove NaNs fast
             no_nan_obs.append(new_obs)
         obs_with_no_nans.append(no_nan_obs)
     return obs_with_no_nans
Example 12
 def forward(self, x_self: torch.Tensor,
             entities: torch.Tensor) -> torch.Tensor:
     if self.self_size > 0:
         num_entities = self.entity_num_max_elements
         if num_entities < 0:
             if exporting_to_onnx.is_exporting():
                 raise UnityTrainerException(
                     "Trying to export an attention mechanism that doesn't have a set max \
                     number of elements.")
             num_entities = entities.shape[1]
         expanded_self = x_self.reshape(-1, 1, self.self_size)
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         # Concatenate all observations with self
         entities = torch.cat([expanded_self, entities], dim=2)
     # Encode entities
     encoded_entities = self.self_ent_encoder(entities)
     return encoded_entities
Example 13
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size

        query = query.reshape(b, n_q, self.n_heads,
                              self.head_size)  # (b, n_q, h, emb / h)
        key = key.reshape(b, n_k, self.n_heads,
                          self.head_size)  # (b, n_k, h, emb / h)
        value = value.reshape(b, n_k, self.n_heads,
                              self.head_size)  # (b, n_k, h, emb / h)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
        # The next few lines are equivalent to : key.permute([0, 2, 3, 1])
        # This is a hack, ONNX will compress two permute operations and
        # Barracuda will not like seeing `permute([0,2,3,1])`
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = (1 - key_mask) * qk / (self.embedding_size**
                                        0.5) + key_mask * self.NEG_INF

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)

        value_attention = value_attention.permute([0, 2, 1,
                                                   3])  # (b, n_q, h, emb / h)
        value_attention = value_attention.reshape(
            b, n_q, self.embedding_size)  # (b, n_q, emb)

        return value_attention, att
Example 14
 def forward(
     self,
     x_self: torch.Tensor,
     entities: List[torch.Tensor],
     key_masks: List[torch.Tensor],
 ) -> torch.Tensor:
     # Gather the maximum number of entities information
     if self.entities_num_max_elements is None:
         self.entities_num_max_elements = []
         for ent in entities:
             self.entities_num_max_elements.append(ent.shape[1])
     # Concatenate all observations with self
     self_and_ent: List[torch.Tensor] = []
     for num_entities, ent in zip(self.entities_num_max_elements, entities):
         expanded_self = x_self.reshape(-1, 1, self.self_size)
         # .repeat(
         #     1, num_entities, 1
         # )
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     # Generate the tensor that will serve as query, key and value to self attention
     qkv = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     mask = torch.cat(key_masks, dim=1)
     # Feed to self attention
     max_num_ent = sum(self.entities_num_max_elements)
     output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent,
                                max_num_ent)
     # Residual
     output = self.residual_layer(output) + qkv
     # Average Pooling
     numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1),
                           dim=1)
     denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
     output = numerator / denominator
     # Residual between x_self and the output of the module
     output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
     return output
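For reference, the "Average Pooling" step above is a masked mean over entities: with m_j the key mask (1 for entities that should be ignored) and o_j the post-residual attention outputs,

    \text{pool} = \frac{\sum_j (1 - m_j)\, o_j}{\sum_j (1 - m_j) + \varepsilon}

where \varepsilon (self.EPISLON) guards against division by zero when every entity is masked. The result is then concatenated with x_self and passed through the final residual layer.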
Example 15
 def trust_region_policy_loss(
     advantages: torch.Tensor,
     log_probs: torch.Tensor,
     old_log_probs: torch.Tensor,
     loss_masks: torch.Tensor,
     epsilon: float,
 ) -> torch.Tensor:
     """
     Evaluate policy loss clipped to stay within a trust region. Used for PPO and POCA.
     :param advantages: Computed advantages.
     :param log_probs: Log probabilities of the actions under the current policy.
     :param old_log_probs: Log probabilities of the actions under the old policy.
     :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
     :param epsilon: Clipping threshold for the probability ratio.
     """
     advantage = advantages.unsqueeze(-1)
     r_theta = torch.exp(log_probs - old_log_probs)
     p_opt_a = r_theta * advantage
     p_opt_b = torch.clamp(r_theta, 1.0 - epsilon,
                           1.0 + epsilon) * advantage
     policy_loss = -1 * ModelUtils.masked_mean(torch.min(p_opt_a, p_opt_b),
                                               loss_masks)
     return policy_loss
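A minimal usage sketch for the loss above, with toy tensors (shapes and values are illustrative, and the function plus its ModelUtils dependency are assumed to be importable):

    import torch

    advantages = torch.tensor([1.0, -0.5])          # (batch,)
    log_probs = torch.tensor([[-0.7], [-0.9]])      # current log-probs, (batch, 1)
    old_log_probs = torch.tensor([[-0.5], [-1.1]])
    loss_masks = torch.ones(2, dtype=torch.bool)    # no padded experiences
    loss = trust_region_policy_loss(advantages, log_probs, old_log_probs, loss_masks, epsilon=0.2)
    # `loss` is a scalar tensor; minimizing it maximizes the clipped surrogate objective.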
Example 16
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> torch.Tensor:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
     # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     return encoded_entities
Example 17
 def to_numpy(tensor: torch.Tensor) -> np.ndarray:
     """
     Converts a Torch Tensor to a numpy array. If the Tensor is on the GPU, it will
     be brought to the CPU.
     """
     return tensor.detach().cpu().numpy()
Example 18
    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values.keys():
                if not discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if not discrete:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup), loss_masks
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
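For reference, the target used in the continuous branch above is the soft value backup from SAC, computed per reward signal under torch.no_grad():

    V_{\text{target}}(s) = \min\left( Q_1(s, a),\, Q_2(s, a) \right) - \alpha \log \pi(a \mid s), \qquad a \sim \pi(\cdot \mid s)

and the loss is 0.5 times the masked mean squared error between values[name] and this target. The discrete branch forms the same backup as an expectation over each branch's action probabilities.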
Example 19
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = self.sequential(visual_obs)
     before_out = hidden.reshape(-1, self.final_flat_size)
     return torch.relu(self.dense(before_out))
Example 20
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = self.conv_layers(visual_obs)
     hidden = hidden.reshape([-1, self.final_flat])
     return self.dense(hidden)
Example 21
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = visual_obs.reshape(-1, self.input_size)
     return self.dense(hidden)
Example 22
    def _evaluate_by_sequence_team(
        self,
        self_obs: List[torch.Tensor],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        init_value_mem: torch.Tensor,
        init_baseline_mem: torch.Tensor,
    ) -> Tuple[
        Dict[str, torch.Tensor],
        Dict[str, torch.Tensor],
        AgentBufferField,
        AgentBufferField,
        torch.Tensor,
        torch.Tensor,
    ]:
        """
        Evaluate a trajectory sequence-by-sequence, assembling the result. This enables us to get the
        intermediate memories for the critic.
        :param self_obs: A List of tensors of shape (trajectory_len, <obs_dim>) that are the agent's own
            observations for this trajectory.
        :param obs: The observations of the agent's groupmates for this trajectory.
        :param actions: The actions of the agent's groupmates for this trajectory.
        :param init_value_mem: The value memory that precedes this trajectory. Of shape (1,1,<mem_size>),
            i.e. what is returned as the output of a MemoryModules.
        :param init_baseline_mem: The baseline memory that precedes this trajectory.
        :return: A Tuple of the value and baseline estimates as Dicts of [name, tensor], AgentBufferFields
            of the initial value and baseline memories to be used during the update, and the final value
            and baseline memories at the end of the trajectory.
        """
        num_experiences = self_obs[0].shape[0]
        all_next_value_mem = AgentBufferField()
        all_next_baseline_mem = AgentBufferField()
        # In the buffer, the first sequence is the one that is padded. So if seq_len = 3 and
        # the trajectory is of length 10, the first sequence is [pad, pad, obs].
        # Compute the number of elements in this padded sequence.
        leftover = num_experiences % self.policy.sequence_length

        # Compute values for the potentially truncated initial sequence

        first_seq_len = leftover if leftover > 0 else self.policy.sequence_length

        self_seq_obs = []
        groupmate_seq_obs = []
        groupmate_seq_act = []
        seq_obs = []
        for _self_obs in self_obs:
            first_seq_obs = _self_obs[0:first_seq_len]
            seq_obs.append(first_seq_obs)
        self_seq_obs.append(seq_obs)

        for groupmate_obs, groupmate_action in zip(obs, actions):
            seq_obs = []
            for _obs in groupmate_obs:
                first_seq_obs = _obs[0:first_seq_len]
                seq_obs.append(first_seq_obs)
            groupmate_seq_obs.append(seq_obs)
            _act = groupmate_action.slice(0, first_seq_len)
            groupmate_seq_act.append(_act)

        # For the first sequence, the initial memory should be the one at the
        # beginning of this trajectory.
        for _ in range(first_seq_len):
            all_next_value_mem.append(
                ModelUtils.to_numpy(init_value_mem.squeeze()))
            all_next_baseline_mem.append(
                ModelUtils.to_numpy(init_baseline_mem.squeeze()))

        all_seq_obs = self_seq_obs + groupmate_seq_obs
        init_values, _value_mem = self.critic.critic_pass(
            all_seq_obs, init_value_mem, sequence_length=first_seq_len)
        all_values = {
            signal_name: [init_values[signal_name]]
            for signal_name in init_values.keys()
        }

        groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
        init_baseline, _baseline_mem = self.critic.baseline(
            self_seq_obs[0],
            groupmate_obs_and_actions,
            init_baseline_mem,
            sequence_length=first_seq_len,
        )
        all_baseline = {
            signal_name: [init_baseline[signal_name]]
            for signal_name in init_baseline.keys()
        }

        # Evaluate the remaining sequences, carrying over _mem after each one
        for seq_num in range(
                1, math.ceil(
                    (num_experiences) / (self.policy.sequence_length))):
            for _ in range(self.policy.sequence_length):
                all_next_value_mem.append(
                    ModelUtils.to_numpy(_value_mem.squeeze()))
                all_next_baseline_mem.append(
                    ModelUtils.to_numpy(_baseline_mem.squeeze()))

            # Offset the slice by the size of the (possibly truncated) first sequence.
            start = first_seq_len + (seq_num - 1) * self.policy.sequence_length
            end = start + self.policy.sequence_length

            self_seq_obs = []
            groupmate_seq_obs = []
            groupmate_seq_act = []
            seq_obs = []
            for _self_obs in self_obs:
                seq_obs.append(_self_obs[start:end])
            self_seq_obs.append(seq_obs)

            for groupmate_obs, team_action in zip(obs, actions):
                seq_obs = []
                for _obs in groupmate_obs:
                    first_seq_obs = _obs[start:end]
                    seq_obs.append(first_seq_obs)
                groupmate_seq_obs.append(seq_obs)
                _act = team_action.slice(start, end)
                groupmate_seq_act.append(_act)

            all_seq_obs = self_seq_obs + groupmate_seq_obs
            values, _value_mem = self.critic.critic_pass(
                all_seq_obs,
                _value_mem,
                sequence_length=self.policy.sequence_length)
            for signal_name, _val in values.items():
                all_values[signal_name].append(_val)

            groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
            baselines, _baseline_mem = self.critic.baseline(
                self_seq_obs[0],
                groupmate_obs_and_actions,
                _baseline_mem,
                sequence_length=self.policy.sequence_length,
            )
            for signal_name, _val in baselines.items():
                all_baseline[signal_name].append(_val)
        # Create one tensor per reward signal
        all_value_tensors = {
            signal_name: torch.cat(value_list, dim=0)
            for signal_name, value_list in all_values.items()
        }
        all_baseline_tensors = {
            signal_name: torch.cat(baseline_list, dim=0)
            for signal_name, baseline_list in all_baseline.items()
        }
        next_value_mem = _value_mem
        next_baseline_mem = _baseline_mem
        return (
            all_value_tensors,
            all_baseline_tensors,
            all_next_value_mem,
            all_next_baseline_mem,
            next_value_mem,
            next_baseline_mem,
        )
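A short worked example of the sequence bookkeeping above, with illustrative numbers:

    # trajectory of 10 experiences, self.policy.sequence_length = 3
    #   leftover      = 10 % 3 = 1
    #   first_seq_len = 1            -> first (padded) sequence covers steps [0:1]
    #   remaining full sequences     -> [1:4], [4:7], [7:10]
    # if the trajectory length is a multiple of sequence_length (leftover == 0),
    # first_seq_len falls back to sequence_length and the split is simply [0:3], [3:6], ...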