Beispiel #1
0
    def sac_q_loss(
        self,
        q1_out: Dict[str, torch.Tensor],
        q2_out: Dict[str, torch.Tensor],
        target_values: Dict[str, torch.Tensor],
        dones: torch.Tensor,
        rewards: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Evaluate the losses of the two Q networks against a shared backup.

        For every stream in ``q1_out`` the backup is
        ``rewards + (1 - use_dones_in_backup * dones) * gamma * target_values``,
        computed under ``torch.no_grad()``. Each network's squared error against
        that backup is averaged under ``loss_masks`` and halved, and the
        per-stream results are averaged.

        :param q1_out: Output of the first Q network, keyed by stream name.
        :param q2_out: Output of the second Q network, indexed with q1_out's keys.
        :param target_values: Target value estimates, keyed by stream name.
        :param dones: Done flags multiplied into the backup.
        :param rewards: Rewards, keyed by stream name.
        :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
        :return: Tuple of (q1 loss, q2 loss), each averaged over streams.
        """
        q1_losses = []
        q2_losses = []
        # Multiple q losses per stream. self.gammas is indexed by the position
        # of the stream in q1_out, the other lookups by the stream name.
        for i, name in enumerate(q1_out):
            q1_stream = q1_out[name].squeeze()
            q2_stream = q2_out[name].squeeze()
            with torch.no_grad():
                q_backup = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[i]
                    * target_values[name]
                )
            # reduction="none" keeps one squared error per element. With the
            # default "mean" reduction the masked-out elements are already
            # averaged into a scalar before the mask is applied, so loss_masks
            # could not exclude them.
            _q1_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q1_stream, reduction="none"),
                loss_masks,
            )
            _q2_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q2_stream, reduction="none"),
                loss_masks,
            )

            q1_losses.append(_q1_loss)
            q2_losses.append(_q2_loss)
        q1_loss = torch.mean(torch.stack(q1_losses))
        q2_loss = torch.mean(torch.stack(q2_losses))
        return q1_loss, q2_loss
Beispiel #2
0
    def sac_entropy_loss(
        self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
    ) -> torch.Tensor:
        """
        Evaluate the loss used to adjust the entropy coefficient.

        The gap between the log probabilities and ``self.target_entropy`` is
        computed under ``torch.no_grad()``, so no gradient flows back through
        ``log_probs``; the gap is then scaled by ``self._log_ent_coef``.

        :param log_probs: Log probabilities of the policy's actions.
        :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
        :param discrete: Selects the branched (discrete) computation when True.
        :return: Negated masked mean of the scaled gap.
        """
        if discrete:
            with torch.no_grad():
                # log(p) * p, split by self.act_size.
                ent_branches = ModelUtils.break_into_branches(
                    log_probs * log_probs.exp(), self.act_size
                )
                # One entry per branch: the branch sum plus the target entropy
                # paired with that branch.
                per_branch_diff = [
                    branch.sum(dim=1, keepdim=True) + target
                    for branch, target in zip(ent_branches, self.target_entropy)
                ]
                target_current_diff = torch.stack(per_branch_diff, dim=1).squeeze(
                    dim=2
                )
            scaled_diff = self._log_ent_coef * target_current_diff
            return -1 * ModelUtils.masked_mean(scaled_diff.mean(dim=1), loss_masks)

        with torch.no_grad():
            target_current_diff = (log_probs + self.target_entropy).sum(dim=1)
        return -1 * ModelUtils.masked_mean(
            self._log_ent_coef * target_current_diff, loss_masks
        )
Beispiel #3
0
 def sac_policy_loss(
     self,
     log_probs: torch.Tensor,
     q1p_outs: Dict[str, torch.Tensor],
     loss_masks: torch.Tensor,
     discrete: bool,
 ) -> torch.Tensor:
     """
     Evaluate the SAC policy loss.

     The Q term is the mean of all tensors in ``q1p_outs``; the entropy term
     is weighted by the exponential of ``self._log_ent_coef``.

     :param log_probs: Log probabilities of the policy's actions.
     :param q1p_outs: First Q network's outputs for the policy, keyed by stream name.
     :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
     :param discrete: Selects the branched (discrete) computation when True.
     :return: Masked mean of the per-sample policy loss.
     """
     ent_coef = self._log_ent_coef.exp()
     q1_mean = torch.stack(list(q1p_outs.values())).mean(dim=0)
     if discrete:
         probs = log_probs.exp()
         ent_branches = ModelUtils.break_into_branches(
             log_probs * probs, self.act_size
         )
         q_branches = ModelUtils.break_into_branches(q1_mean * probs, self.act_size)
         # Each branch uses its own entropy coefficient, selected by position.
         per_branch_loss = [
             torch.sum(ent_coef[idx] * ent - q_term, dim=1, keepdim=True)
             for idx, (ent, q_term) in enumerate(zip(ent_branches, q_branches))
         ]
         batch_policy_loss = torch.stack(per_branch_loss, dim=1).squeeze()
     else:
         batch_policy_loss = torch.mean(
             ent_coef * log_probs - q1_mean.unsqueeze(1), dim=1
         )
     return ModelUtils.masked_mean(batch_policy_loss, loss_masks)
Beispiel #4
0
def test_masked_mean():
    """Check ModelUtils.masked_mean with all-on, partly-on and all-off masks."""
    values = torch.tensor([1, 2, 3, 4, 5])
    cases = [
        # Every element selected.
        (torch.ones_like(values).bool(), 3.0),
        # Only the last three elements selected.
        (torch.tensor([False, False, True, True, True]), 4.0),
        # Make sure it works if all masks are off
        (torch.tensor([False, False, False, False, False]), 0.0),
    ]
    for mask, expected in cases:
        assert ModelUtils.masked_mean(values, masks=mask) == expected
Beispiel #5
0
    def ppo_policy_loss(
        self,
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the clipped surrogate policy loss for PPO.
        :param advantages: Computed advantages.
        :param log_probs: Log probabilities under the current policy.
        :param old_log_probs: Log probabilities under the past policy.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :return: Negated masked mean of the clipped surrogate objective.
        """
        # A trailing axis lets the advantage broadcast against the ratio.
        expanded_adv = advantages.unsqueeze(-1)
        # The clip range is read from the hyperparameters, not passed in.
        clip_range = self.hyperparameters.epsilon

        ratio = (log_probs - old_log_probs).exp()
        unclipped = ratio * expanded_adv
        clipped = ratio.clamp(1.0 - clip_range, 1.0 + clip_range) * expanded_adv
        surrogate = torch.min(unclipped, clipped)
        return -1 * ModelUtils.masked_mean(surrogate, loss_masks)
Beispiel #6
0
 def ppo_value_loss(
     self,
     values: Dict[str, torch.Tensor],
     old_values: Dict[str, torch.Tensor],
     returns: Dict[str, torch.Tensor],
     epsilon: float,
     loss_masks: torch.Tensor,
 ) -> torch.Tensor:
     """
     Compute the clipped value loss for PPO, averaged over value heads.
     :param values: Value output of the current network.
     :param old_values: Value stored with experiences in buffer.
     :param returns: Computed returns.
     :param epsilon: Clipping value for value estimate.
     :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
     :return: Mean over heads of the masked mean of the larger squared error.
     """
     per_head_losses = []
     for name, current in values.items():
         previous = old_values[name]
         target = returns[name]
         # Keep the new estimate within epsilon of the stored one.
         bounded = previous + (current - previous).clamp(-epsilon, epsilon)
         unclipped_err = (target - current) ** 2
         clipped_err = (target - bounded) ** 2
         # Per element, use whichever squared error is larger.
         worst_err = torch.max(unclipped_err, clipped_err)
         per_head_losses.append(ModelUtils.masked_mean(worst_err, loss_masks))
     return torch.stack(per_head_losses).mean()
    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        """
        Evaluate the entropy-coefficient loss for discrete and/or continuous actions.

        The gap between the log probabilities and the target entropy is computed
        under ``torch.no_grad()`` and then scaled by the matching entry of
        ``self._log_ent_coef``. A discrete term and a continuous term are added
        depending on the sizes reported by ``self._action_spec``.

        :param log_probs: Log probabilities of the policy's actions.
        :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
        :return: Sum of the applicable terms; the integer 0 if the action spec
            has neither discrete nor continuous actions.
        """
        log_ent_coef = self._log_ent_coef
        cont_coef = log_ent_coef.continuous
        disc_coef = log_ent_coef.discrete
        spec = self._action_spec

        entropy_loss = 0
        if spec.discrete_size > 0:
            with torch.no_grad():
                disc_lp = log_probs.all_discrete_tensor
                # log(p) * p, split by the discrete branch sizes.
                ent_branches = ModelUtils.break_into_branches(
                    disc_lp * disc_lp.exp(), spec.discrete_branches
                )
                per_branch_diff = [
                    branch.sum(dim=1, keepdim=True) + target
                    for branch, target in zip(
                        ent_branches, self.target_entropy.discrete
                    )
                ]
                disc_diff = torch.stack(per_branch_diff, dim=1).squeeze(dim=2)
            entropy_loss -= ModelUtils.masked_mean(
                torch.mean(disc_coef * disc_diff, dim=1), loss_masks
            )
        if spec.continuous_size > 0:
            with torch.no_grad():
                cont_diff = torch.sum(
                    log_probs.continuous_tensor + self.target_entropy.continuous,
                    dim=1,
                )
            # The continuous coefficients are all updated together as one block.
            entropy_loss -= ModelUtils.masked_mean(cont_coef * cont_diff, loss_masks)

        return entropy_loss
Beispiel #8
0
def test_masked_mean():
    """Check ModelUtils.masked_mean on 1D input and on 2D input of shape (mask_length, N)."""
    base = torch.tensor([1, 2, 3, 4, 5])
    last_three = torch.tensor([False, False, True, True, True])
    cases = [
        # Every element selected.
        (base, torch.ones_like(base).bool(), 3.0),
        # Only the last three elements selected.
        (base, last_three, 4.0),
        # Make sure it works if all masks are off
        (base, torch.tensor([False, False, False, False, False]), 0.0),
        # Make sure it works with 2d arrays of shape (mask_length, N)
        (base.repeat(2, 1).T, last_three, 4.0),
    ]
    for values, mask, expected in cases:
        assert ModelUtils.masked_mean(values, masks=mask) == expected
Beispiel #9
0
    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate the SAC policy loss for discrete and/or continuous actions.

        The Q term is the mean of all tensors in ``q1p_outs``. A discrete term
        and a continuous term are accumulated depending on the sizes reported
        by ``self._action_spec``; each is weighted by the exponential of the
        matching entry of ``self._log_ent_coef``.

        :param log_probs: Log probabilities of the policy's actions.
        :param q1p_outs: First Q network's outputs for the policy, keyed by stream name.
        :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
        :return: Masked mean of the accumulated per-sample policy loss.
        """
        cont_coef = self._log_ent_coef.continuous.exp()
        disc_coef = self._log_ent_coef.discrete.exp()

        q1_mean = torch.stack(list(q1p_outs.values())).mean(dim=0)
        spec = self._action_spec
        batch_policy_loss = 0
        if spec.discrete_size > 0:
            disc_lp = log_probs.all_discrete_tensor
            disc_probs = disc_lp.exp()
            ent_branches = ModelUtils.break_into_branches(
                disc_lp * disc_probs, spec.discrete_branches
            )
            q_branches = ModelUtils.break_into_branches(
                q1_mean * disc_probs, spec.discrete_branches
            )
            # Each branch uses its own entropy coefficient, selected by position.
            per_branch_loss = [
                torch.sum(disc_coef[idx] * ent - q_term, dim=1)
                for idx, (ent, q_term) in enumerate(zip(ent_branches, q_branches))
            ]
            batch_policy_loss += torch.stack(per_branch_loss, dim=1).sum(dim=1)
            # Probability-weighted Q summed over dim 1, reused by the
            # continuous term below.
            all_mean_q1 = (disc_probs * q1_mean).sum(dim=1)
        else:
            all_mean_q1 = q1_mean
        if spec.continuous_size > 0:
            cont_lp = log_probs.continuous_tensor
            batch_policy_loss += torch.mean(
                cont_coef * cont_lp - all_mean_q1.unsqueeze(1), dim=1
            )
        return ModelUtils.masked_mean(batch_policy_loss, loss_masks)
Beispiel #10
0
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.

        Builds tensors from the batch, evaluates the policy on them, combines
        the PPO policy loss, value loss and entropy term into one loss, and
        runs a single optimizer step. Afterwards every reward signal's own
        update is run and its stats are merged into the result.

        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process. Not read anywhere
            in this method body.
        :return: Results of update.
        """
        # Get decayed parameters
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(
            self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
        # Per reward signal: stored value estimates and returns from the batch.
        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[f"{name}_value_estimates"])
            returns[name] = ModelUtils.list_to_tensor(batch[f"{name}_returns"])

        vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        # Continuous actions get a trailing axis; otherwise the actions are
        # loaded as long tensors.
        if self.policy.use_continuous_act:
            actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
        else:
            actions = ModelUtils.list_to_tensor(batch["actions"],
                                                dtype=torch.long)

        # Keep only the memory entry at the start of each sequence, i.e. every
        # sequence_length-th entry.
        memories = [
            ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
                0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # Stack the selected memories and add a leading axis; an empty list is
        # passed on unchanged.
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        # One visual observation tensor per visual encoder of the network body.
        if self.policy.use_vis_obs:
            vis_obs = []
            for idx, _ in enumerate(
                    self.policy.actor_critic.network_body.visual_encoders):
                vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
                vis_obs.append(vis_ob)
        else:
            vis_obs = []
        log_probs, entropy, values = self.policy.evaluate_actions(
            vec_obs,
            vis_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        loss_masks = ModelUtils.list_to_tensor(batch["masks"],
                                               dtype=torch.bool)
        value_loss = self.ppo_value_loss(values, old_values, returns,
                                         decay_eps, loss_masks)
        policy_loss = self.ppo_policy_loss(
            ModelUtils.list_to_tensor(batch["advantages"]),
            log_probs,
            ModelUtils.list_to_tensor(batch["action_probs"]),
            loss_masks,
        )
        # Combined loss: policy term, half-weighted value term, minus the
        # beta-scaled masked mean entropy.
        loss = (policy_loss + 0.5 * value_loss -
                decay_bet * ModelUtils.masked_mean(entropy, loss_masks))

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        # The policy loss is reported as its absolute value.
        update_stats = {
            "Losses/Policy Loss": abs(policy_loss.detach().cpu().numpy()),
            "Losses/Value Loss": value_loss.detach().cpu().numpy(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        # Run each reward signal's own update and merge the stats it returns.
        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats
Beispiel #11
0
    def update(self, batch: AgentBuffer,
               num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.

        Builds tensors from the batch, evaluates the policy and the critic on
        them, combines the PPO policy loss, value loss and entropy term into
        one loss, and runs a single optimizer step. Afterwards every reward
        signal's own update is run and its stats are merged into the result.

        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process. Not read anywhere
            in this method body.
        :return: Results of update.
        """
        # Get decayed parameters
        decay_lr = self.decay_learning_rate.get_value(
            self.policy.get_current_step())
        decay_eps = self.decay_epsilon.get_value(
            self.policy.get_current_step())
        decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
        # Per reward signal: stored value estimates and returns from the batch.
        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)])
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)])

        # One observation entry per observation spec of the behavior.
        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        # Keep only the memory entry at the start of each sequence, i.e. every
        # sequence_length-th entry.
        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i]) for i in
            range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        # Stack the selected memories and add a leading axis; an empty list is
        # passed on unchanged.
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        # Get value memories
        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                           self.policy.sequence_length)
        ]
        if len(value_memories) > 0:
            value_memories = torch.stack(value_memories).unsqueeze(0)

        log_probs, entropy = self.policy.evaluate_actions(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        # The critic is evaluated separately, with its own memories; the second
        # return value is discarded.
        values, _ = self.critic.critic_pass(
            current_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        # Both the stored and the freshly computed log probabilities are
        # flattened before the policy loss is evaluated.
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS],
                                               dtype=torch.bool)
        value_loss = self.ppo_value_loss(values, old_values, returns,
                                         decay_eps, loss_masks)
        policy_loss = self.ppo_policy_loss(
            ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
            log_probs,
            old_log_probs,
            loss_masks,
        )
        # Combined loss: policy term, half-weighted value term, minus the
        # beta-scaled masked mean entropy.
        loss = (policy_loss + 0.5 * value_loss -
                decay_bet * ModelUtils.masked_mean(entropy, loss_masks))

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        # Run each reward signal's own update and merge the stats it returns.
        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats
Beispiel #12
0
    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluate the value-network loss for SAC.

        The regression target for each stream is the element-wise minimum of
        the two policy Q estimates, adjusted by entropy terms scaled with the
        exponentiated entropy coefficients. Targets are built under
        ``torch.no_grad()``, so no gradient flows through them.

        :param log_probs: Log probabilities of the policy's actions.
        :param values: Value estimates to train, keyed by stream name.
        :param q1p_out: First Q network's output for the policy, keyed by stream name.
        :param q2p_out: Second Q network's output for the policy, keyed by stream name.
        :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
        :return: Mean over streams of the halved masked mean squared error.
        :raises UnityTrainerException: If the resulting loss is inf or NaN.
        """
        min_policy_qs = {}
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous.exp()
            _disc_ent_coef = self._log_ent_coef.discrete.exp()
            for name in values:
                if self._action_spec.discrete_size <= 0:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    # Weight the Q outputs by the discrete action
                    # probabilities, sum within each branch, then average
                    # across branches.
                    disc_action_probs = log_probs.all_discrete_tensor.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if self._action_spec.discrete_size <= 0:
            for name in values:
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1
                    )
                # reduction="none" keeps one squared error per element. With
                # the default "mean" reduction the masked-out elements are
                # already averaged into a scalar before the mask is applied,
                # so loss_masks could not exclude them.
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(
                        values[name], v_backup, reduction="none"
                    ),
                    loss_masks,
                )
                value_losses.append(value_loss)
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values:
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                    # Add continuous entropy bonus to minimum Q
                    # NOTE(review): the continuous-only branch above subtracts
                    # sum(ent_coef * log_probs) while this hybrid branch adds
                    # it - confirm the intended sign.
                    if self._action_spec.continuous_size > 0:
                        v_backup += torch.sum(
                            _cont_ent_coef * log_probs.continuous_tensor,
                            dim=1,
                            keepdim=True,
                        )
                # Per-element squared error so that loss_masks can take effect
                # (see the continuous-only branch above).
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(
                        values[name], v_backup.squeeze(), reduction="none"
                    ),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        # The check covers both inf and NaN, although the message only names inf.
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss
Beispiel #13
0
    def sac_value_loss(
        self,
        log_probs: torch.Tensor,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
        discrete: bool,
    ) -> torch.Tensor:
        """
        Evaluate the value-network loss for SAC.

        The regression target for each stream is the element-wise minimum of
        the two policy Q estimates minus an entropy term scaled with the
        exponentiated entropy coefficient. Targets are built under
        ``torch.no_grad()``, so no gradient flows through them.

        :param log_probs: Log probabilities of the policy's actions.
        :param values: Value estimates to train, keyed by stream name.
        :param q1p_out: First Q network's output for the policy, keyed by stream name.
        :param q2p_out: Second Q network's output for the policy, keyed by stream name.
        :param loss_masks: Mask for losses, passed to ModelUtils.masked_mean.
        :param discrete: Selects the branched (discrete) computation when True.
        :return: Mean over streams of the halved masked mean squared error.
        :raises UnityTrainerException: If the resulting loss is inf or NaN.
        """
        min_policy_qs = {}
        with torch.no_grad():
            _ent_coef = torch.exp(self._log_ent_coef)
            for name in values:
                if not discrete:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    # Weight the Q outputs by the action probabilities, sum
                    # within each branch, then average across branches.
                    action_probs = log_probs.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * action_probs, self.act_size
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * action_probs, self.act_size
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if not discrete:
            for name in values:
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _ent_coef * log_probs, dim=1
                    )
                # reduction="none" keeps one squared error per element. With
                # the default "mean" reduction the masked-out elements are
                # already averaged into a scalar before the mask is applied,
                # so loss_masks could not exclude them.
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(
                        values[name], v_backup, reduction="none"
                    ),
                    loss_masks,
                )
                value_losses.append(value_loss)
        else:
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values:
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, axis=0
                    )
                # Per-element squared error so that loss_masks can take effect
                # (see the non-discrete branch above).
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(
                        values[name], v_backup.squeeze(), reduction="none"
                    ),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        # The check covers both inf and NaN, although the message only names inf.
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss