Exemplo n.º 1
0
 def sac_policy_loss(
     self,
     log_probs: torch.Tensor,
     q1p_outs: Dict[str, torch.Tensor],
     loss_masks: torch.Tensor,
     discrete: bool,
 ) -> torch.Tensor:
     """
     Computes the SAC policy loss, reduced with a masked mean.
     :param log_probs: Log probabilities produced by the policy. In the discrete
         case they are exponentiated here to recover the action probabilities.
     :param q1p_outs: Q1 outputs keyed by name; they are averaged together.
     :param loss_masks: Mask forwarded to ModelUtils.masked_mean.
     :param discrete: Selects the discrete (branched) or continuous formulation.
     :return: The masked mean of the per-sample policy loss.
     """
     ent_coef = torch.exp(self._log_ent_coef)
     # Collapse the Q1 estimates into a single averaged tensor.
     stacked_q1 = torch.stack(list(q1p_outs.values()))
     avg_q1 = torch.mean(stacked_q1, axis=0)
     if discrete:
         probs = log_probs.exp()
         entropy_branches = ModelUtils.break_into_branches(
             log_probs * probs, self.act_size
         )
         q_branches = ModelUtils.break_into_branches(avg_q1 * probs, self.act_size)
         per_branch_loss = []
         for idx, (ent_branch, q_branch) in enumerate(
             zip(entropy_branches, q_branches)
         ):
             # Each branch is weighted by its own entropy coefficient.
             per_branch_loss.append(
                 torch.sum(ent_coef[idx] * ent_branch - q_branch, dim=1, keepdim=True)
             )
         per_sample_loss = torch.squeeze(torch.stack(per_branch_loss, dim=1))
     else:
         avg_q1 = avg_q1.unsqueeze(1)
         per_sample_loss = torch.mean(ent_coef * log_probs - avg_q1, dim=1)
     return ModelUtils.masked_mean(per_sample_loss, loss_masks)
Exemplo n.º 2
0
    def sac_q_loss(
        self,
        q1_out: Dict[str, torch.Tensor],
        q2_out: Dict[str, torch.Tensor],
        target_values: Dict[str, torch.Tensor],
        dones: torch.Tensor,
        rewards: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the losses of the two Q networks, averaged over all streams.
        :param q1_out: Q1 outputs keyed by stream name.
        :param q2_out: Q2 outputs keyed by stream name.
        :param target_values: Target value estimates keyed by stream name.
        :param dones: Done flags, used to cut the bootstrap term for streams whose
            entry in self.use_dones_in_backup is non-zero.
        :param rewards: Rewards keyed by stream name.
        :param loss_masks: Mask forwarded to ModelUtils.masked_mean.
        :return: Tuple of (q1_loss, q2_loss).
        """
        q1_losses = []
        q2_losses = []
        # One pair of Q losses per stream
        for i, name in enumerate(q1_out):
            q1_stream = q1_out[name].squeeze()
            q2_stream = q2_out[name].squeeze()
            # The backup target is a constant w.r.t. the Q networks.
            with torch.no_grad():
                q_backup = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[i]
                    * target_values[name]
                )
            # reduction="none" keeps the per-sample squared errors. With the default
            # "mean" reduction the loss is already a scalar when it reaches
            # masked_mean, so loss_masks would have no effect.
            _q1_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q1_stream, reduction="none"),
                loss_masks,
            )
            _q2_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q2_stream, reduction="none"),
                loss_masks,
            )

            q1_losses.append(_q1_loss)
            q2_losses.append(_q2_loss)
        q1_loss = torch.mean(torch.stack(q1_losses))
        q2_loss = torch.mean(torch.stack(q2_losses))
        return q1_loss, q2_loss
Exemplo n.º 3
0
    def forward(
        self,
        obs_only: List[List[torch.Tensor]],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encodes the given observations (and observation/action pairs) through
        self-attention and returns the resulting encoding together with the memories.
        :param obs_only: Observations without corresponding actions. Each entry is
            encoded with observation_encoder and the stack is passed to obs_encoder.
        :param obs: Observations that do have corresponding actions. Each encoding is
            concatenated with its flattened action and the stack is passed to
            obs_action_encoder.
        :param actions: Actions paired with obs, in the same order.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        :return: Tuple of (encoding, memories). memories is returned unchanged when
            the LSTM is not used.
        """
        attn_masks = []
        attn_inputs = []

        if obs:
            obs_mask = self._get_masks_from_nans(obs)
            obs = self._copy_and_remove_nans_from_obs(obs, obs_mask)
            obs_action_encodings = [
                torch.cat(
                    [
                        self.observation_encoder(agent_obs),
                        agent_action.to_flat(self.action_spec.discrete_branches),
                    ],
                    dim=1,
                )
                for agent_obs, agent_action in zip(obs, actions)
            ]
            stacked_obs_actions = torch.stack(obs_action_encodings, dim=1)
            attn_masks.append(obs_mask)
            attn_inputs.append(self.obs_action_encoder(None, stacked_obs_actions))

        if obs_only:
            obs_only_mask = self._get_masks_from_nans(obs_only)
            obs_only = self._copy_and_remove_nans_from_obs(obs_only, obs_only_mask)
            obs_encodings = [
                self.observation_encoder(agent_obs) for agent_obs in obs_only
            ]
            stacked_obs = torch.stack(obs_encodings, dim=1)
            attn_masks.append(obs_only_mask)
            attn_inputs.append(self.obs_encoder(None, stacked_obs))

        # Entities from both groups attend to each other in a single pass.
        attended = self.self_attn(torch.cat(attn_inputs, dim=1), attn_masks)
        encoding = self.linear_encoder(attended)

        if not self.use_lstm:
            return encoding, memories

        # Resize to (batch, sequence length, encoding size) for the LSTM.
        lstm_input = encoding.reshape([-1, sequence_length, self.h_size])
        lstm_output, memories = self.lstm(lstm_input, memories)
        encoding = lstm_output.reshape([-1, self.m_size // 2])
        return encoding, memories
Exemplo n.º 4
0
 def _behavioral_cloning_loss(
     self,
     selected_actions: AgentAction,
     log_probs: ActionLogProbs,
     expert_actions: torch.Tensor,
 ) -> torch.Tensor:
     """
     Computes the behavioral cloning loss against the expert actions.
     The continuous part is an MSE between selected and expert actions; the
     discrete part is a per-branch cross entropy against the one-hot expert
     actions. Both parts are summed when the action spec has both.
     :param selected_actions: Actions selected by the policy.
     :param log_probs: Policy outputs; only all_discrete_tensor is read here.
     :param expert_actions: Expert actions, read through continuous_tensor and
         discrete_tensor.
     :return: The summed loss (the integer 0 if neither action type is present).
     """
     action_spec = self.policy.behavior_spec.action_spec
     bc_loss = 0
     if action_spec.continuous_size > 0:
         bc_loss += torch.nn.functional.mse_loss(
             selected_actions.continuous_tensor, expert_actions.continuous_tensor
         )
     if action_spec.discrete_size > 0:
         expert_one_hot = ModelUtils.actions_to_onehot(
             expert_actions.discrete_tensor, action_spec.discrete_branches
         )
         branch_outputs = ModelUtils.break_into_branches(
             log_probs.all_discrete_tensor, action_spec.discrete_branches
         )
         branch_losses = []
         for branch_output, branch_target in zip(branch_outputs, expert_one_hot):
             neg_log_softmax = -torch.nn.functional.log_softmax(branch_output, dim=1)
             branch_losses.append(torch.sum(neg_log_softmax * branch_target, dim=1))
         bc_loss += torch.mean(torch.stack(branch_losses))
     return bc_loss
Exemplo n.º 5
0
    def sample_actions(
        self,
        vec_obs: List[torch.Tensor],
        vis_obs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
    ]:
        """
        Samples actions from the actor-critic and gathers the related statistics.
        :param vec_obs: Vector observations, forwarded to get_dist_and_value.
        :param vis_obs: Visual observations, forwarded to get_dist_and_value.
        :param masks: Forwarded to get_dist_and_value.
        :param memories: Forwarded to get_dist_and_value.
        :param seq_len: Forwarded to get_dist_and_value.
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
        :return: Tuple of actions, log probabilities (dependent on all_log_probs),
            entropies, value heads and output memories.
        """
        dists, value_heads, memories = self.actor_critic.get_dist_and_value(
            vec_obs, vis_obs, masks, memories, seq_len
        )
        sampled = self.actor_critic.sample_action(dists)
        log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
            sampled, dists
        )
        stacked = torch.stack(sampled, dim=-1)
        # Drop the singleton axis; which one it is depends on the action type.
        actions = stacked[:, :, 0] if self.use_continuous_act else stacked[:, 0, :]
        chosen_log_probs = all_logs if all_log_probs else log_probs
        return actions, chosen_log_probs, entropies, value_heads, memories
Exemplo n.º 6
0
 def forward(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, int, int, int, int]:
     """
     Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

     Builds the action distributions (with a fixed sequence length of 1) and returns
     the action output followed by four model attributes.
     :param vec_inputs: Vector inputs, forwarded to get_dists.
     :param vis_inputs: Visual inputs, forwarded to get_dists.
     :param masks: Forwarded to get_dists.
     :param memories: Forwarded to get_dists.
     :return: Tuple of (action_out, version_number, memory size wrapped in a
         Tensor, is_continuous_int, act_size_vector).
     """
     # The second value returned by get_dists is not used here.
     dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
     if self.action_spec.is_continuous():
         # Continuous: output the sampled actions themselves.
         action_list = self.sample_action(dists)
         action_out = torch.stack(action_list, dim=-1)
         if self._clip_action_on_export:
             # Clamp to [-3, 3] and rescale so the exported action lies in [-1, 1].
             action_out = torch.clamp(action_out, -3, 3) / 3
     else:
         # Discrete: output the log probabilities of every action, all
         # distributions concatenated along dim 1.
         action_out = torch.cat([dist.all_log_prob() for dist in dists],
                                dim=1)
     return (
         action_out,
         self.version_number,
         torch.Tensor([self.network_body.memory_size]),
         self.is_continuous_int,
         self.act_size_vector,
     )
Exemplo n.º 7
0
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        """
        Computes the loss used to update the (log) entropy coefficient.
        The difference to the target entropy is computed without gradients, so
        only self._log_ent_coef receives a gradient from this loss.
        :param log_probs: Log probabilities produced by the policy.
        :param loss_masks: Mask forwarded to ModelUtils.masked_mean.
        :param discrete: Selects the discrete (branched) or continuous formulation.
        :return: The negated masked mean of the weighted entropy difference.
        """
        if discrete:
            with torch.no_grad():
                entropy_branches = ModelUtils.break_into_branches(
                    log_probs * log_probs.exp(), self.act_size
                )
                per_branch_diff = []
                for branch_ent, branch_target in zip(
                    entropy_branches, self.target_entropy
                ):
                    per_branch_diff.append(
                        torch.sum(branch_ent, axis=1, keepdim=True) + branch_target
                    )
                target_current_diff = torch.squeeze(
                    torch.stack(per_branch_diff, axis=1), axis=2
                )
            # Average across branches before masking over the batch.
            weighted_diff = torch.mean(
                self._log_ent_coef * target_current_diff, axis=1
            )
        else:
            with torch.no_grad():
                target_current_diff = torch.sum(
                    log_probs + self.target_entropy, dim=1
                )
            weighted_diff = self._log_ent_coef * target_current_diff

        return -1 * ModelUtils.masked_mean(weighted_diff, loss_masks)
Exemplo n.º 8
0
 def ppo_value_loss(
     self,
     values: Dict[str, torch.Tensor],
     old_values: Dict[str, torch.Tensor],
     returns: Dict[str, torch.Tensor],
     epsilon: float,
     loss_masks: torch.Tensor,
 ) -> torch.Tensor:
     """
     Computes the clipped value loss for PPO, averaged over all value heads.
     :param values: Value output of the current network, keyed by head name.
     :param old_values: Value stored with experiences in buffer.
     :param returns: Computed returns.
     :param epsilon: Clipping value for value estimate.
     :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
     :return: Mean over heads of the masked mean of the per-sample loss.
     """
     per_head_losses = []
     for name in values:
         estimate = values[name]
         previous = old_values[name]
         target = returns[name]
         # Keep the new estimate within epsilon of the stored one.
         clipped_delta = torch.clamp(estimate - previous, -1 * epsilon, epsilon)
         clipped_estimate = previous + clipped_delta
         unclipped_error = (target - estimate) ** 2
         clipped_error = (target - clipped_estimate) ** 2
         # Use the larger of the two errors for each sample.
         worst_error = torch.max(unclipped_error, clipped_error)
         per_head_losses.append(ModelUtils.masked_mean(worst_error, loss_masks))
     return torch.mean(torch.stack(per_head_losses))
Exemplo n.º 9
0
def test_evaluate_actions(rnn, visual, discrete):
    """Checks the output shapes of policy.evaluate_actions on a simulated rollout."""
    batch_size = 64
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(
        batch_size, policy.behavior_spec, memory_size=policy.m_size
    )
    act_masks = ModelUtils.list_to_tensor(buffer[BufferKey.ACTION_MASK])
    agent_action = AgentAction.from_buffer(buffer)
    np_obs = ObsUtil.from_buffer(buffer, len(policy.behavior_spec.observation_specs))
    tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

    # Take one memory per sequence (the first entry of each sequence).
    memories = []
    for start in range(0, len(buffer[BufferKey.MEMORY]), policy.sequence_length):
        memories.append(ModelUtils.list_to_tensor(buffer[BufferKey.MEMORY][start]))
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        tensor_obs,
        masks=act_masks,
        actions=agent_action,
        memories=memories,
        seq_len=policy.sequence_length,
    )

    action_spec = policy.behavior_spec.action_spec
    _size = action_spec.discrete_size if discrete else action_spec.continuous_size

    assert log_probs.flatten().shape == (batch_size, _size)
    assert entropy.shape == (batch_size,)
    for val in values.values():
        assert val.shape == (batch_size,)
Exemplo n.º 10
0
def test_sample_actions(rnn, visual, discrete):
    """Checks the output shapes of policy.sample_actions on a simulated rollout."""
    batch_size = 64
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(
        batch_size, policy.behavior_spec, memory_size=policy.m_size
    )
    act_masks = ModelUtils.list_to_tensor(buffer[BufferKey.ACTION_MASK])

    np_obs = ObsUtil.from_buffer(buffer, len(policy.behavior_spec.observation_specs))
    tensor_obs = [ModelUtils.list_to_tensor(obs) for obs in np_obs]

    # Take one memory per sequence (the first entry of each sequence).
    memories = []
    for start in range(0, len(buffer[BufferKey.MEMORY]), policy.sequence_length):
        memories.append(ModelUtils.list_to_tensor(buffer[BufferKey.MEMORY][start]))
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    (sampled_actions, log_probs, entropies, memories) = policy.sample_actions(
        tensor_obs, masks=act_masks, memories=memories, seq_len=policy.sequence_length
    )

    action_spec = policy.behavior_spec.action_spec
    if discrete:
        expected_shape = (batch_size, sum(action_spec.discrete_branches))
        assert log_probs.all_discrete_tensor.shape == expected_shape
    else:
        expected_shape = (batch_size, action_spec.continuous_size)
        assert log_probs.continuous_tensor.shape == expected_shape
    assert entropies.shape == (batch_size,)

    if rnn:
        assert memories.shape == (1, 1, policy.m_size)
Exemplo n.º 11
0
 def discrete_tensor(self) -> torch.Tensor:
     """
     Returns the discrete action list stacked along a new last dimension,
     or an empty tensor when the list is missing or empty.
     """
     branches = self.discrete_list
     if branches is None or len(branches) == 0:
         return torch.empty(0)
     return torch.stack(branches, dim=-1)
Exemplo n.º 12
0
    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the SAC policy loss for discrete and/or continuous actions.
        The discrete contribution is accumulated first; when both action types are
        present, the Q term used by the continuous part is the expectation of Q1
        under the discrete action probabilities.
        :param log_probs: Policy log probabilities, read through all_discrete_tensor
            and continuous_tensor.
        :param q1p_outs: Q1 outputs keyed by name; they are averaged together.
        :param loss_masks: Mask forwarded to ModelUtils.masked_mean.
        :return: The masked mean of the per-sample policy loss.
        """
        cont_ent_coef = self._log_ent_coef.continuous.exp()
        disc_ent_coef = self._log_ent_coef.discrete.exp()

        stacked_q1 = torch.stack(list(q1p_outs.values()))
        mean_q1 = torch.mean(stacked_q1, axis=0)
        batch_policy_loss = 0
        if self._action_spec.discrete_size > 0:
            disc_log_probs = log_probs.all_discrete_tensor
            disc_probs = disc_log_probs.exp()
            entropy_branches = ModelUtils.break_into_branches(
                disc_log_probs * disc_probs, self._action_spec.discrete_branches
            )
            q_branches = ModelUtils.break_into_branches(
                mean_q1 * disc_probs, self._action_spec.discrete_branches
            )
            per_branch_loss = []
            for idx, (ent_branch, q_branch) in enumerate(
                zip(entropy_branches, q_branches)
            ):
                # Each branch is weighted by its own entropy coefficient.
                per_branch_loss.append(
                    torch.sum(
                        disc_ent_coef[idx] * ent_branch - q_branch,
                        dim=1,
                        keepdim=False,
                    )
                )
            batch_policy_loss += torch.sum(torch.stack(per_branch_loss, dim=1), dim=1)
            all_mean_q1 = torch.sum(disc_probs * mean_q1, dim=1)
        else:
            all_mean_q1 = mean_q1
        if self._action_spec.continuous_size > 0:
            cont_log_probs = log_probs.continuous_tensor
            batch_policy_loss += torch.mean(
                cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
            )
        return ModelUtils.masked_mean(batch_policy_loss, loss_masks)
Exemplo n.º 13
0
def test_sample_actions(rnn, visual, discrete):
    """Checks the output shapes of policy.sample_actions on a simulated rollout."""
    batch_size = 64
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(
        batch_size, policy.behavior_spec, memory_size=policy.m_size
    )
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])

    # One visual observation tensor per visual processor of the network body.
    vis_obs = [
        ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors)
    ]

    # Take one memory per sequence (the first entry of each sequence).
    memories = []
    for start in range(0, len(buffer["memory"]), policy.sequence_length):
        memories.append(ModelUtils.list_to_tensor(buffer["memory"][start]))
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    outputs = policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=policy.sequence_length,
        all_log_probs=not policy.use_continuous_act,
    )
    sampled_actions, clipped_actions, log_probs, entropies, memories = outputs

    action_spec = policy.behavior_spec.action_spec
    if discrete:
        assert log_probs.shape == (batch_size, sum(action_spec.discrete_branches))
    else:
        assert log_probs.shape == (batch_size, action_spec.continuous_size)
        assert clipped_actions.shape == (batch_size, action_spec.continuous_size)
    assert entropies.shape == (batch_size,)

    if rnn:
        assert memories.shape == (1, 1, policy.m_size)
Exemplo n.º 14
0
 def get_probs_and_entropy(
     action_list: List[torch.Tensor], dists: List[DistInstance]
 ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     """
     Gathers log probabilities and entropies of the given actions under their
     distributions.
     :param action_list: Actions, one per distribution.
     :param dists: Distributions paired with action_list, in the same order.
     :return: Tuple of (log_probs, entropies, all_probs). The first two are stacked
         along a new last dimension. all_probs is the concatenation of
         all_log_prob() over the discrete distributions; when there are none it is
         None and the last dimension of the other two outputs is squeezed.
     """
     per_action_log_probs = []
     per_action_entropies = []
     discrete_all_log_probs = []
     for action, dist in zip(action_list, dists):
         per_action_log_probs.append(dist.log_prob(action))
         per_action_entropies.append(dist.entropy())
         if isinstance(dist, DiscreteDistInstance):
             discrete_all_log_probs.append(dist.all_log_prob())
     log_probs = torch.stack(per_action_log_probs, dim=-1)
     entropies = torch.stack(per_action_entropies, dim=-1)
     if discrete_all_log_probs:
         return log_probs, entropies, torch.cat(discrete_all_log_probs, dim=-1)
     return log_probs.squeeze(-1), entropies.squeeze(-1), None
Exemplo n.º 15
0
 def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor:
     """
     Get attention masks by grabbing an arbitrary obs across all the agents
     Since these are raw obs, the padded values are still NaN
     """
     only_first_obs = [_all_obs[0] for _all_obs in obs_tensors]
     # Just get the first element in each obs regardless of its dimension. This will speed up
     # searching for NaNs.
     only_first_obs_flat = torch.stack(
         [_obs.flatten(start_dim=1)[:, 0] for _obs in only_first_obs], dim=1
     )
     # Get the mask from NaNs
     attn_mask = only_first_obs_flat.isnan().float()
     return attn_mask
Exemplo n.º 16
0
    def sample_actions(
        self,
        vec_obs: List[torch.Tensor],
        vis_obs: List[torch.Tensor],
        masks: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        seq_len: int = 1,
        all_log_probs: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor]:
        """
        Samples actions from the actor-critic and gathers the related statistics.
        :param vec_obs: List of vector observations.
        :param vis_obs: List of visual observations.
        :param masks: Loss masks for RNN, else None.
        :param memories: Input memories when using RNN, else None.
        :param seq_len: Sequence length when using RNN.
        :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
        :return: Tuple of actions, actions clipped to -1, 1, log probabilities (dependent on all_log_probs),
            entropies summed across actions, and output memories.
        """
        if memories is not None:
            # With memories, the value path is run as well so that the critic
            # memories are produced; the value output itself is discarded.
            dists, _, memories = self.actor_critic.get_dist_and_value(
                vec_obs, vis_obs, masks, memories, seq_len
            )
        else:
            dists, memories = self.actor_critic.get_dists(
                vec_obs, vis_obs, masks, memories, seq_len
            )
        sampled = self.actor_critic.sample_action(dists)
        log_probs, entropies, all_logs = ModelUtils.get_probs_and_entropy(
            sampled, dists
        )
        stacked = torch.stack(sampled, dim=-1)
        # Drop the singleton axis; which one it is depends on the action type.
        actions = stacked[:, :, 0] if self.use_continuous_act else stacked[:, 0, :]
        # Entropy is summed across actions rather than averaged.
        entropy_sum = torch.sum(entropies, dim=1)

        clipped_action = actions
        if self._clip_action and self.use_continuous_act:
            clipped_action = torch.clamp(actions, -3, 3) / 3

        chosen_log_probs = all_logs if all_log_probs else log_probs
        return actions, clipped_action, chosen_log_probs, entropy_sum, memories
Exemplo n.º 17
0
    def _condense_q_streams(
            self, q_output: Dict[str, torch.Tensor],
            discrete_actions: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Reduces each Q stream to the Q value of the action that was taken.
        For every stream the Q output is split into branches, each branch is
        masked with the one-hot action and summed, and the branches are averaged.
        :param q_output: Q outputs keyed by stream name.
        :param discrete_actions: Discrete actions, converted to one-hot here.
        :return: Dict with the same keys holding the condensed Q values.
        """
        onehot_actions = ModelUtils.actions_to_onehot(discrete_actions,
                                                      self.act_size)
        condensed = {}
        for stream_name, stream_q in q_output.items():
            q_branches = ModelUtils.break_into_branches(stream_q, self.act_size)
            selected_q = []
            for onehot_branch, q_branch in zip(onehot_actions, q_branches):
                selected_q.append(
                    torch.sum(onehot_branch * q_branch, dim=1, keepdim=True))
            condensed[stream_name] = torch.mean(torch.stack(selected_q), dim=0)
        return condensed
Exemplo n.º 18
0
def test_simple_transformer_training():
    """
    Trains a residual self-attention block plus a linear layer to output the key
    closest to a query point, and checks that the error drops below 0.3.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    trainable = list(transformer.parameters()) + list(l_layer.parameters())
    optimizer = torch.optim.Adam(trainable, lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        # Points are drawn uniformly from [-point_range, point_range).
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # Target: the key closest to the query in euclidean distance.
            offsets = center.reshape((batch_size, 1, size)) - key
            distance = torch.sum(offsets ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = key[torch.arange(batch_size), argmin].detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = l_layer(transformer.forward(embeddings, masks))
        prediction = prediction.reshape((batch_size, size))
        per_sample_error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(per_sample_error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            # Every later step must beat the error of the first step.
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
Exemplo n.º 19
0
def test_evaluate_actions(rnn, visual, discrete):
    """Checks the output shapes of policy.evaluate_actions on a simulated rollout."""
    batch_size = 64
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    buffer = mb.simulate_rollout(
        batch_size, policy.behavior_spec, memory_size=policy.m_size
    )
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    if policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(buffer["actions"]).unsqueeze(-1)
    else:
        actions = ModelUtils.list_to_tensor(buffer["actions"], dtype=torch.long)

    # One visual observation tensor per visual processor of the network body.
    vis_obs = [
        ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors)
    ]

    # Take one memory per sequence (the first entry of each sequence).
    memories = []
    for start in range(0, len(buffer["memory"]), policy.sequence_length):
        memories.append(ModelUtils.list_to_tensor(buffer["memory"][start]))
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    log_probs, entropy, values = policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=policy.sequence_length,
    )

    action_spec = policy.behavior_spec.action_spec
    _size = action_spec.discrete_size if discrete else action_spec.continuous_size

    assert log_probs.shape == (batch_size, _size)
    assert entropy.shape == (batch_size,)
    for val in values.values():
        assert val.shape == (batch_size,)
Exemplo n.º 20
0
 def _behavioral_cloning_loss(self, selected_actions, log_probs,
                              expert_actions):
     """
     Computes the behavioral cloning loss against the expert actions.
     Continuous policies use an MSE between selected and expert actions. Otherwise
     log_probs is split into branches and a per-branch cross entropy against the
     matching entry of expert_actions is averaged.
     :param selected_actions: Actions selected by the policy (continuous case).
     :param log_probs: Policy outputs that are split into branches (discrete case).
     :param expert_actions: Expert actions; in the discrete case iterated branch by
         branch — presumably one-hot encoded, confirm against the caller.
     :return: The behavioral cloning loss.
     """
     if self.policy.use_continuous_act:
         return torch.nn.functional.mse_loss(selected_actions, expert_actions)

     branch_outputs = ModelUtils.break_into_branches(
         log_probs, self.policy.act_size)
     branch_losses = []
     for branch_output, branch_target in zip(branch_outputs, expert_actions):
         neg_log_softmax = -torch.nn.functional.log_softmax(branch_output, dim=1)
         branch_losses.append(torch.sum(neg_log_softmax * branch_target, dim=1))
     return torch.mean(torch.stack(branch_losses))
Exemplo n.º 21
0
def test_predict_closest_training():
    """
    Trains an entity embedding, a residual self-attention block and a linear layer
    to output the key closest to a query point, and checks that the final error is
    below 0.02.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    trainable = (
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters())
    )
    optimizer = torch.optim.Adam(trainable, lr=0.001, weight_decay=1e-6)
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Target: the key closest to the query in euclidean distance.
            offsets = center.reshape((batch_size, 1, size)) - key
            distance = torch.sum(offsets ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = key[torch.arange(batch_size), argmin].detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = l_layer(transformer.forward(embeddings, masks))
        prediction = prediction.reshape((batch_size, size))
        per_sample_error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(per_sample_error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the loss used to update the (log) entropy coefficients for
        discrete and/or continuous actions.
        The differences to the target entropies are computed without gradients, so
        only the entries of self._log_ent_coef receive a gradient from this loss.
        :param log_probs: Policy log probabilities, read through all_discrete_tensor
            and continuous_tensor.
        :param loss_masks: Mask forwarded to ModelUtils.masked_mean.
        :return: Sum of the discrete and continuous entropy losses (the integer 0 if
            neither action type is present).
        """
        cont_log_coef = self._log_ent_coef.continuous
        disc_log_coef = self._log_ent_coef.discrete
        entropy_loss = 0

        if self._action_spec.discrete_size > 0:
            with torch.no_grad():
                disc_log_probs = log_probs.all_discrete_tensor
                entropy_branches = ModelUtils.break_into_branches(
                    disc_log_probs * disc_log_probs.exp(),
                    self._action_spec.discrete_branches,
                )
                per_branch_diff = []
                for branch_ent, branch_target in zip(
                    entropy_branches, self.target_entropy.discrete
                ):
                    per_branch_diff.append(
                        torch.sum(branch_ent, axis=1, keepdim=True) + branch_target
                    )
                target_current_diff = torch.squeeze(
                    torch.stack(per_branch_diff, axis=1), axis=2
                )
            # Average across branches before masking over the batch.
            disc_weighted = torch.mean(disc_log_coef * target_current_diff, axis=1)
            entropy_loss += -1 * ModelUtils.masked_mean(disc_weighted, loss_masks)

        if self._action_spec.continuous_size > 0:
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # The continuous coefficient is updated as a single block.
            cont_weighted = cont_log_coef * target_current_diff
            entropy_loss += -1 * ModelUtils.masked_mean(cont_weighted, loss_masks)

        return entropy_loss
Exemplo n.º 23
0
def test_all_masking(mask_value):
    """
    Runs a few training steps with a mask filled entirely with mask_value, to make
    sure that a mask of all zeros or all ones does not trigger an error.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    trainable = (
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters())
    )
    optimizer = torch.optim.Adam(trainable, lr=0.001, weight_decay=1e-6)
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # Target: the key closest to the query in euclidean distance.
            offsets = center.reshape((batch_size, 1, size)) - key
            distance = torch.sum(offsets ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = key[torch.arange(batch_size), argmin].detach()

        embeddings = entity_embeddings(center, key)
        # A single mask with one entry per (batch, key), all set to mask_value.
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = l_layer(transformer.forward(embeddings, masks))
        prediction = prediction.reshape((batch_size, size))
        per_sample_error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(per_sample_error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
Exemplo n.º 24
0
def test_multi_head_attention_training():
    """
    Check that MultiHeadAttention can be trained to output the key nearest
    to the query.

    After the first step, every step's error must be lower than the error of
    the first step, and the error of the final step must be below 0.5.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    first_error = None  # error recorded on the first training step
    for _ in range(50):
        # torch.rand is rescaled to cover [-point_range, point_range)
        query = torch.rand((batch_size, n_q, size))
        query = query * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size))
        key = key * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # Target: the key closest to the query in euclidean distance
            nearest = torch.argmin(torch.sum((query - key)**2, dim=2), dim=1)
            target = torch.stack(
                [key[row, nearest[row], :] for row in range(batch_size)],
                dim=0,
            ).detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        per_sample_error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(per_sample_error) / 2
        if first_error is None:
            first_error = error.item()
        else:
            assert error.item() < first_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5
Exemplo n.º 25
0
 def forward(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, int, int, int, int]:
     """
     Inference entry point used for ONNX export.

     Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

     Returns a 5-tuple: the action output (the stacked sampled actions for a
     continuous action type, otherwise ``all_log_prob()`` of the first
     distribution), ``self.version_number``, the network body's memory size
     wrapped in a one-element tensor, ``self.is_continuous_int`` and
     ``self.act_size_vector``.
     """
     # NOTE(review): the third element is built with torch.Tensor(...) although
     # the return annotation says int - left untouched because of the ONNX contract.
     dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
     # Actions are always sampled and stacked, even if only the log probs are returned
     sampled_actions = torch.stack(self.sample_action(dists), dim=-1)
     action_out = (
         sampled_actions
         if self.act_type == ActionType.CONTINUOUS
         else dists[0].all_log_prob()
     )
     return (
         action_out,
         self.version_number,
         torch.Tensor([self.network_body.memory_size]),
         self.is_continuous_int,
         self.act_size_vector,
     )
Exemplo n.º 26
0
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Run a single PPO gradient step on the given batch.

        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process. Not read inside
            this method.
        :return: Dictionary of update statistics, extended with whatever each
            reward provider's own ``update`` call returns.
        """
        # Scheduled hyperparameters, each evaluated at the policy's current step
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(
            self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())

        # Stored value estimates and returns, one entry per reward signal
        old_values = {
            name: ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)])
            for name in self.reward_signals
        }
        returns = {
            name: ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)])
            for name in self.reward_signals
        }

        # Observations, converted to tensors
        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = [
            ModelUtils.list_to_tensor(obs)
            for obs in ObsUtil.from_buffer(batch, n_obs)
        ]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        sequence_length = self.policy.sequence_length

        # Policy memories: one entry every `sequence_length` steps
        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(0, len(batch[BufferKey.MEMORY]), sequence_length)
        ]
        if memories:
            memories = torch.stack(memories).unsqueeze(0)

        # Critic memories, sampled the same way
        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                           sequence_length)
        ]
        if value_memories:
            value_memories = torch.stack(value_memories).unsqueeze(0)

        log_probs, entropy = self.policy.evaluate_actions(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=sequence_length,
        )
        values, _ = self.critic.critic_pass(
            current_obs,
            memories=value_memories,
            sequence_length=sequence_length,
        )
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS],
                                               dtype=torch.bool)

        value_loss = self.ppo_value_loss(values, old_values, returns,
                                         decay_eps, loss_masks)
        advantages = ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES])
        policy_loss = self.ppo_policy_loss(advantages, log_probs,
                                           old_log_probs, loss_masks)
        # Entropy term is subtracted from the loss, scaled by the decayed beta
        entropy_term = decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        loss = policy_loss + 0.5 * value_loss - entropy_term

        # Apply the decayed learning rate, then take one optimizer step
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        # Each reward provider is updated on the same batch and reports its own stats
        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats
Exemplo n.º 27
0
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Performs one SAC update using a mini-batch from the buffer.

        Computes the Q, value, policy and entropy-coefficient losses, steps the
        policy, value and entropy optimizers (in that order) with the decayed
        learning rate, and finally soft-updates the target network.

        :param batch: Experience mini-batch.
        :param num_sequences: Number of trajectories in batch. Not read inside
            this method.
        :return: Output from update process (losses, mean entropy coefficients
            and the learning rate that was applied).
        """
        # Rewards of every reward signal, keyed by signal name
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)

        # One memory entry per sequence (every `sequence_length` steps)
        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
                0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # The "next" memories are read one step after each sequence start.
        # LSTM shouldn't have sequence length <1, but the offset falls back to 0
        # when sequence_length is not > 1 so the index stays in range.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                batch["memory"][i]
                [self.policy.m_size //
                 2:])  # only pass value part of memory to target network
            for i in range(offset, len(batch["memory"]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            # No recurrent memories stored in this batch
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (torch.zeros_like(next_memories)
                      if next_memories is not None else None)

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body)
        # Sample actions from the current policy and get its value estimates
        (
            sampled_actions,
            log_probs,
            _,
            value_estimates,
            _,
        ) = self.policy.actor_critic.get_action_stats_and_value(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        # Q values for the freshly sampled actions (used by the policy and value losses).
        # NOTE(review): q2_grad=False presumably blocks gradients through Q2 here - confirm.
        q1p_out, q2p_out = self.value_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        # Q values for the actions stored in the buffer (used by the Q losses)
        q1_out, q2_out = self.value_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        # With discrete actions, the Q outputs are condensed using the stored discrete actions
        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        # Target values are computed without tracking gradients
        with torch.no_grad():
            target_values, _ = self.target_network(
                next_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss + value_loss

        # All three optimizers use the same decayed learning rate
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self.policy.actor_critic.critic,
                               self.target_network, self.tau)
        update_stats = {
            "Losses/Policy Loss":
            policy_loss.item(),
            "Losses/Value Loss":
            value_loss.item(),
            "Losses/Q1 Loss":
            q1_loss.item(),
            "Losses/Q2 Loss":
            q2_loss.item(),
            "Policy/Discrete Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
            "Policy/Continuous Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
            "Policy/Learning Rate":
            decay_lr,
        }

        return update_stats
Exemplo n.º 28
0
    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the SAC value-function loss, averaged over all reward streams.

        For each stream in ``values`` the regression target ("backup") is the
        minimum of the two policy Q estimates minus the entropy-coefficient
        weighted log probabilities, i.e. ``min(Q1, Q2) - alpha * log_prob``.
        With discrete actions, the Q outputs are first weighted by the action
        probabilities, summed within each branch and averaged across branches.
        The backup is computed without gradient tracking.

        :param log_probs: Action log probabilities; ``all_discrete_tensor`` and
            ``continuous_tensor`` are read.
        :param values: Value estimates per reward stream; these are regressed
            towards the backup.
        :param q1p_out: First Q network's output per reward stream.
        :param q2p_out: Second Q network's output per reward stream.
        :param loss_masks: Mask handed to ``ModelUtils.masked_mean``.
        :return: Mean of the per-stream value losses.
        :raises UnityTrainerException: If the resulting loss is inf or NaN.
        """
        has_discrete = self._action_spec.discrete_size > 0

        def _expected_branch_q(q_out: torch.Tensor,
                               action_probs: torch.Tensor) -> torch.Tensor:
            # Probability-weighted Q, summed inside each branch, averaged over branches
            branches = ModelUtils.break_into_branches(
                q_out * action_probs,
                self._action_spec.discrete_branches,
            )
            return torch.mean(
                torch.stack([
                    torch.sum(_br, dim=1, keepdim=True) for _br in branches
                ]),
                dim=0,
            )

        min_policy_qs = {}
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous.exp()
            _disc_ent_coef = self._log_ent_coef.discrete.exp()
            for name in values:
                if not has_discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name],
                                                    q2p_out[name])
                else:
                    disc_action_probs = log_probs.all_discrete_tensor.exp()
                    _q1p_mean = _expected_branch_q(q1p_out[name],
                                                   disc_action_probs)
                    _q2p_mean = _expected_branch_q(q2p_out[name],
                                                   disc_action_probs)
                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        # NOTE(review): mse_loss is called with its default reduction, so what
        # reaches masked_mean is already averaged - confirm that loss_masks has
        # the intended effect.
        value_losses = []
        if not has_discrete:
            for name in values:
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1)
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name], v_backup),
                    loss_masks)
                value_losses.append(value_loss)
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            # p * log(p) per action, split per branch
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack([
                torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                for i, _lp in enumerate(branched_per_action_ent)
            ])
            for name in values:
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, dim=0)
                    # Continuous part of a hybrid action space. The weighted log
                    # prob is SUBTRACTED, exactly like in the continuous-only
                    # branch above (the backup is Q - alpha * log_prob); the
                    # previous code added it, which flipped the sign of the
                    # continuous entropy bonus.
                    if self._action_spec.continuous_size > 0:
                        v_backup -= torch.sum(
                            _cont_ent_coef * log_probs.continuous_tensor,
                            dim=1,
                            keepdim=True,
                        )
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(values[name],
                                                 v_backup.squeeze()),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        # The message only mentions Inf, but NaN is rejected as well
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
Exemplo n.º 29
0
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Run a single PPO gradient step on the given batch.

        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process. Not read inside
            this method.
        :return: Dictionary of update statistics, extended with whatever each
            reward provider's own ``update`` call returns.
        """
        # Scheduled hyperparameters, each evaluated at the policy's current step
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(
            self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())

        # Stored value estimates and returns, one entry per reward signal
        old_values = {
            name: ModelUtils.list_to_tensor(batch[f"{name}_value_estimates"])
            for name in self.reward_signals
        }
        returns = {
            name: ModelUtils.list_to_tensor(batch[f"{name}_returns"])
            for name in self.reward_signals
        }

        vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        # Continuous policies read "actions_pre"; discrete ones read "actions" as long
        if self.policy.use_continuous_act:
            raw_actions = ModelUtils.list_to_tensor(batch["actions_pre"])
            actions = raw_actions.unsqueeze(-1)
        else:
            actions = ModelUtils.list_to_tensor(batch["actions"],
                                                dtype=torch.long)

        sequence_length = self.policy.sequence_length

        # Memories: one entry every `sequence_length` steps
        memories = [
            ModelUtils.list_to_tensor(batch["memory"][i])
            for i in range(0, len(batch["memory"]), sequence_length)
        ]
        if memories:
            memories = torch.stack(memories).unsqueeze(0)

        # One visual observation tensor per visual processor, if visual obs are in use
        vis_obs = []
        if self.policy.use_vis_obs:
            visual_processors = (
                self.policy.actor_critic.network_body.visual_processors)
            vis_obs = [
                ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                for idx, _ in enumerate(visual_processors)
            ]

        log_probs, entropy, values = self.policy.evaluate_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=sequence_length,
        )
        loss_masks = ModelUtils.list_to_tensor(batch["masks"],
                                               dtype=torch.bool)

        value_loss = self.ppo_value_loss(values, old_values, returns,
                                         decay_eps, loss_masks)
        advantages = ModelUtils.list_to_tensor(batch["advantages"])
        stored_action_probs = ModelUtils.list_to_tensor(batch["action_probs"])
        policy_loss = self.ppo_policy_loss(advantages, log_probs,
                                           stored_action_probs, loss_masks)
        # Entropy term is subtracted from the loss, scaled by the decayed beta
        entropy_term = decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        loss = policy_loss + 0.5 * value_loss - entropy_term

        # Apply the decayed learning rate, then take one optimizer step
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        # Each reward provider is updated on the same batch and reports its own stats
        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats
Exemplo n.º 30
0
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Performs one SAC update using a mini-batch from the buffer.

        Computes the Q, value, policy and entropy-coefficient losses, steps the
        policy, value and entropy optimizers (in that order) with the decayed
        learning rate, and finally soft-updates the target network. The value
        loss is added to the policy loss when ``self.policy.shared_critic`` is
        set, and to the Q losses otherwise.

        :param batch: Experience mini-batch.
        :param num_sequences: Number of trajectories in batch. Not read inside
            this method.
        :return: Output from update process (losses, mean entropy coefficients
            and the learning rate that was applied).
        """
        # Rewards of every reward signal, keyed by signal name
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.rewards_key(name)])

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        # One policy memory entry per sequence (every `sequence_length` steps)
        memories_list = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i]) for i in
            range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        # Critic memories, sampled at the same positions as the policy memories
        value_memories_list = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                           self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            value_memories = torch.stack(value_memories_list).unsqueeze(0)
        else:
            # No recurrent memories stored in this batch
            memories = None
            value_memories = None

        # Q network memories are 0'ed out, since we don't have them during inference.
        # (The critic pass below receives value_memories unchanged.)
        q_memories = (torch.zeros_like(value_memories)
                      if value_memories is not None else None)

        # Copy normalizers from policy
        self.q_network.q1_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self.q_network.q2_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self.target_network.network_body.copy_normalization(
            self.policy.actor.network_body)
        self._critic.network_body.copy_normalization(
            self.policy.actor.network_body)
        # Sample actions from the current policy, then get the critic's value estimates
        sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )
        value_estimates, _ = self._critic.critic_pass(
            current_obs,
            value_memories,
            sequence_length=self.policy.sequence_length)

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        # Q values for the freshly sampled actions (used by the policy and value losses).
        # NOTE(review): q2_grad=False presumably blocks gradients through Q2 here - confirm.
        q1p_out, q2p_out = self.q_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        # Q values for the actions stored in the buffer (used by the Q losses)
        q1_out, q2_out = self.q_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        # With discrete actions, the Q outputs are condensed using the stored discrete actions
        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        # Target values are computed without tracking gradients
        with torch.no_grad():
            # Since we didn't record the next value memories, evaluate one step in the critic to
            # get them.
            if value_memories is not None:
                # Get the first observation in each sequence
                just_first_obs = [
                    _obs[::self.policy.sequence_length] for _obs in current_obs
                ]
                _, next_value_memories = self._critic.critic_pass(
                    just_first_obs, value_memories, sequence_length=1)
            else:
                next_value_memories = None
            target_values, _ = self.target_network(
                next_obs,
                memories=next_value_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS],
                                          dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])

        q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                           dones, rewards, masks)
        value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                         q2p_out, masks)
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        # The value loss is routed to whichever optimizer step it is grouped with:
        # the policy step for a shared critic, the value step otherwise.
        total_value_loss = q1_loss + q2_loss
        if self.policy.shared_critic:
            policy_loss += value_loss
        else:
            total_value_loss += value_loss

        # All three optimizers use the same decayed learning rate
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(self._critic, self.target_network, self.tau)
        # Note: with a shared critic, the reported policy loss includes the value loss
        update_stats = {
            "Losses/Policy Loss":
            policy_loss.item(),
            "Losses/Value Loss":
            value_loss.item(),
            "Losses/Q1 Loss":
            q1_loss.item(),
            "Losses/Q2 Loss":
            q2_loss.item(),
            "Policy/Discrete Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
            "Policy/Continuous Entropy Coeff":
            torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
            "Policy/Learning Rate":
            decay_lr,
        }

        return update_stats