Example #1
    def beta_dist_loss(self, advantage_mask, phi, action, dist_info,
                       old_dist_info, valid, opt_info):
        # Rescale actions from [-1, 1] to [0, 1], the support of the Beta distribution.
        action = (action + 1) / 2
        # The dist_info fields (named mean / log_std) carry the two Beta
        # concentration parameters here.
        distribution = torch.distributions.beta.Beta(dist_info.mean,
                                                     dist_info.log_std)
        old_dist = torch.distributions.beta.Beta(old_dist_info.mean,
                                                 old_dist_info.log_std)
        pi_loss = -torch.sum(
            advantage_mask *
            (phi.detach() * distribution.log_prob(action).sum(dim=-1)))
        kl = torch.distributions.kl_divergence(old_dist,
                                               distribution).sum(dim=-1)
        alpha_loss = valid_mean(
            self.alpha * (self.epsilon_alpha - kl.detach()) +
            self.alpha.detach() * kl, valid)
        entropy = valid_mean(distribution.entropy().sum(dim=-1), valid)
        alpha_loss -= 0.01 * entropy

        mode = self.agent.beta_dist_mode(old_dist_info.mean,
                                         old_dist_info.log_std)
        opt_info.alpha.append(self.alpha.item())
        opt_info.policy_kl.append(kl.mean().item())
        opt_info.pi_mu.append(mode.mean().item())
        opt_info.pi_log_std.append(
            old_dist.entropy().sum(dim=-1).mean().item())
        return pi_loss, alpha_loss, opt_info
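
Note on the alpha update in Example #1: the expression alpha * (epsilon_alpha - kl.detach()) + alpha.detach() * kl is the usual detach trick for a learnable Lagrange multiplier, so the first term only trains alpha and the second only penalizes the policy. A minimal, self-contained sketch with toy values (not the project's code):

import torch

alpha = torch.tensor(1.0, requires_grad=True)         # learnable KL temperature
epsilon = 0.01                                         # KL budget (illustrative value)
kl = torch.tensor([0.02, 0.05], requires_grad=True)    # stand-in for a policy-dependent KL

alpha_loss = (alpha * (epsilon - kl.detach()) + alpha.detach() * kl).mean()
alpha_loss.backward()

print(alpha.grad)  # = epsilon - kl.mean() = -0.025 < 0, so gradient descent raises alpha
print(kl.grad)     # = alpha / 2 per element: the policy side is pushed to shrink the KL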
Example #2
File: ppo_alae.py  Project: jhejna/ul_gen
    def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution
        ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
            new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
            1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = - valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_) ** 2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = - self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss # + self.vae_loss_coeff * vae_loss
        
        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
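
Most of these examples lean on rlpyt's valid_mean helper. A minimal stand-in consistent with how it is used here (masked mean over entries where valid is 1, plain mean when valid is None); the real utility ships with rlpyt, so this is only for reference:

import torch

def valid_mean(tensor, valid=None):
    """Mean over valid entries only; reduces to a plain mean when valid is None."""
    if valid is None:
        return tensor.mean()
    return (tensor * valid).sum() / valid.sum()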
Example #3
    def compute_loss(self, observations, next_observations, actions, valid):
        #------------------------------------------------------------#
        # hacky dimension add for when you have only one environment (debugging)
        if actions.dim() == 2: 
            actions = actions.unsqueeze(1)
        #------------------------------------------------------------#
        phi1, phi2, predicted_phi2, predicted_phi2_stacked, predicted_action = self.forward(observations, next_observations, actions)
        actions = torch.max(actions.view(-1, *actions.shape[2:]), 1)[1]  # convert action to (T * B, action_size), then get target indexes
        inverse_loss = nn.functional.cross_entropy(
            predicted_action.view(-1, *predicted_action.shape[2:]),
            actions.detach(),
            reduction='none').view(phi1.shape[0], phi2.shape[1])
        inverse_loss = valid_mean(inverse_loss, valid)

        # Sum the (dropout-masked) forward-model errors from each of the four prediction heads.
        forward_loss = torch.tensor(0.0, device=self.device)
        for predicted in predicted_phi2[:4]:
            head_loss = nn.functional.dropout(
                nn.functional.mse_loss(predicted, phi2.detach(), reduction='none'),
                p=0.2).sum(-1) / self.feature_size
            forward_loss += valid_mean(head_loss, valid)

        return self.inverse_loss_wt*inverse_loss, self.forward_loss_wt*forward_loss
Example #4
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             init_rnn_state=None):
        if init_rnn_state is not None:
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)

        else:
            dist_info, value = self.agent(*agent_inputs)

        dist = self.agent.distribution

        lr = dist.likelihood_ratio(action,
                                   old_dist_info=old_dist_info,
                                   new_dist_info=dist_info)
        kl = dist.kl(old_dist_info=old_dist_info, new_dist_info=dist_info)

        if init_rnn_state is not None:
            raise NotImplementedError
        else:
            mean_kl = valid_mean(kl)
            surr_loss = -valid_mean(lr * advantage)

        loss = surr_loss
        entropy = dist.mean_entropy(dist_info, valid)
        perplexity = dist.mean_perplexity(dist_info, valid)

        return loss, entropy, perplexity
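
dist.likelihood_ratio(...) above is exp(log pi_new(a) - log pi_old(a)). For reference, a hedged sketch for a diagonal Gaussian parameterized by mean and log_std (mirroring the dist_info fields used in these examples; not rlpyt's actual implementation):

import torch

def gaussian_likelihood_ratio(action, old_mean, old_log_std, new_mean, new_log_std):
    """ratio = exp(log N(a; new) - log N(a; old)), summed over action dimensions."""
    old = torch.distributions.Normal(old_mean, old_log_std.exp())
    new = torch.distributions.Normal(new_mean, new_log_std.exp())
    return (new.log_prob(action) - old.log_prob(action)).sum(dim=-1).exp()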
Example #5
    def pi_alpha_loss(self, samples, valid, conv_out):
        # PI LOSS.
        # Uses detached conv out; avoid re-computing.
        conv_detach = conv_out.detach()
        agent_inputs = samples.agent_inputs._replace(observation=conv_detach)
        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            # new_action = new_action.detach()  # No grad.
            raise NotImplementedError
        # Re-use the detached latent.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        if self.reparameterize:
            pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError
        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        # ALPHA LOSS.
        if self.target_entropy is not None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        return pi_loss, alpha_loss, pi_mean.detach(), pi_log_std.detach()
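
The alpha loss in Example #5 is the standard SAC temperature objective. A toy, self-contained illustration of how its gradient behaves (target_entropy is commonly set to -dim(action); the numbers here are made up):

import torch

log_alpha = torch.zeros(1, requires_grad=True)   # alpha = exp(log_alpha) = 1.0
target_entropy = -2.0                            # commonly -dim(action); assumed here
log_pi = torch.tensor([0.5, 1.5])                # stand-in minibatch of action log-probs

alpha_loss = (-log_alpha * (log_pi + target_entropy).detach()).mean()
alpha_loss.backward()

# Policy entropy (~ -log_pi.mean() = -1.0) is above the target (-2.0), so the gradient
# is positive and gradient descent lowers log_alpha, i.e. weakens the entropy bonus.
print(log_alpha.grad)  # tensor([1.])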
Example #6
    def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
            init_rnn_state=None):
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
            new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
            1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = - valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_) ** 2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = - self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
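
The PPO examples (#2, #6, #9, #10, #12, #30) all share the same clipped-surrogate core. Factored out as a standalone reference function (a sketch, not the library's code):

import torch

def clipped_surrogate_loss(ratio, advantage, valid=None, ratio_clip=0.2):
    """Negative pessimistic (min) surrogate, averaged over valid entries."""
    surr_1 = ratio * advantage
    surr_2 = torch.clamp(ratio, 1. - ratio_clip, 1. + ratio_clip) * advantage
    surrogate = torch.min(surr_1, surr_2)
    if valid is None:
        return -surrogate.mean()
    return -(surrogate * valid).sum() / valid.sum()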
Example #7
    def loss(self, samples):
        agent_inputs = AgentInputs(
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        if self.agent.recurrent:
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[
                0]  # T = 0.
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        # TODO: try to compute everything on device.
        return_, advantage, valid = self.process_returns(samples)

        dist = self.agent.distribution
        logli = dist.log_likelihood(samples.agent.action, dist_info)
        pi_loss = -valid_mean(logli * advantage, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)

        return loss, entropy, perplexity
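
Examples #7 and #19 get return_, advantage, and valid from self.process_returns(samples), which is not shown. In rlpyt-style actor-critic code this is typically a discounted-return / GAE computation; the following standalone sketch shows that calculation under that assumption (it is not the project's helper):

import torch

def gae(rewards, values, bootstrap_value, dones, discount=0.99, gae_lambda=0.97):
    """rewards, values, dones: [T, B]; bootstrap_value: [B]. Returns (advantage, return_)."""
    T = rewards.shape[0]
    advantage = torch.zeros_like(rewards)
    next_value = bootstrap_value
    next_advantage = torch.zeros_like(bootstrap_value)
    for t in reversed(range(T)):
        not_done = 1. - dones[t].float()
        delta = rewards[t] + discount * next_value * not_done - values[t]
        next_advantage = delta + discount * gae_lambda * not_done * next_advantage
        advantage[t] = next_advantage
        next_value = values[t]
    return_ = advantage + values   # lambda-return target for the value loss
    return advantage, return_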
Example #8
    def loss(self, samples):
        """
        Computes losses for twin Q-values against the min of twin target Q-values
        and an entropy term.  Computes reparameterized policy loss, and loss for
        tuning entropy weighting, alpha.  
        
        Input samples have leading batch dimension [B,..] (but not time).
        """
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))

        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - samples.timeout_n.float())

        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action)
        min_target_q = torch.min(target_q1, target_q2)
        target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount ** self.n_step_return
        y = (self.reward_scale * samples.return_ +
            (1 - samples.done_n.float()) * disc * target_value)

        q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
        q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

        new_action, log_pi, (pi_mean, pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError

        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        if self.target_entropy is not None and self.fixed_alpha is None:
            alpha_losses = - self._log_alpha * (log_pi.detach() + self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std))
        return losses, values
Example #9
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             old_value,
             init_rnn_state=None):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state,
                                                      device=action.device)
        else:
            dist_info, value = self.agent(*agent_inputs, device=action.device)
        dist = self.agent.distribution

        # Surrogate policy loss
        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                    1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        # Surrogate (clipped) value loss, if enabled
        if self.clip_vf_loss:
            v_loss_unclipped = (value - return_)**2
            v_clipped = old_value + torch.clamp(
                value - old_value, -self.ratio_clip, self.ratio_clip)
            v_loss_clipped = (v_clipped - return_)**2
            v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
            value_error = 0.5 * v_loss_max.mean()
        else:
            value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, pi_loss, value_loss, entropy, perplexity
Example #10
    def loss(self,
             agent_inputs,
             action,
             return_,
             advantage,
             valid,
             old_dist_info,
             init_rnn_state=None):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            if self.agent.both_actions:
                dist_info, og_dist_info, value, og_value = self.agent(
                    *agent_inputs)
            else:
                dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution
        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                    1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        if self.similarity_loss:  # Try KL next
            # pi_sim = self.agent.distribution.kl(og_dist_info, dist_info)
            pi_sim = F.cosine_similarity(dist_info.prob, og_dist_info.prob)
            value_sim = (value - og_value)**2
            # Only add the similarity terms when they are actually computed.
            loss += -self.similarity_coeff * pi_sim.mean() + 0.5 * value_sim.mean()
            # loss += self.similarity_coeff * (pi_sim.mean() + value_sim)

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
Example #11
File: sac_v.py  Project: afansi/rlpyt
    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs)
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= (1 - samples.timeout_n.float())

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        v = self.agent.v(*agent_inputs)
        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        v_target = (min_log_target - log_pi +
                    prior_log_pi).detach()  # No grad.

        v_loss = 0.5 * valid_mean((v - v_target)**2, valid)

        if self.reparameterize:
            pi_losses = log_pi - min_log_target
        else:
            pi_factor = (v - v_target).detach()
            pi_losses = log_pi * pi_factor
        if self.policy_output_regularization > 0:
            pi_losses += self.policy_output_regularization * torch.mean(
                0.5 * pi_mean**2 + 0.5 * pi_log_std**2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        losses = (q1_loss, q2_loss, v_loss, pi_loss)
        values = tuple(val.detach()
                       for val in (q1, q2, v, pi_mean, pi_log_std))
        return losses, values
Example #12
    def loss(
        self,
        agent_inputs,
        action,
        return_,
        advantage,
        valid,
        old_dist_info,
        init_rnn_state=None,
    ):
        """
        Compute the training loss: policy_loss + value_loss + entropy_loss
        Policy loss: min(likelihood-ratio * advantage, clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
        Value loss:  0.5 * (estimated_value - return) ^ 2
        Calls the agent to compute forward pass on training data, and uses
        the ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                      init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)
        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action,
                                      old_dist_info=old_dist_info,
                                      new_dist_info=dist_info)
        ratio = ratio.clamp_max(1000)  # added (to prevent ratio == inf)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1.0 - self.ratio_clip,
                                    1.0 + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = -valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
Example #13
    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        qs = self.agent(*samples.agent_inputs)
        q = select_at_indexes(samples.action, qs)
        with torch.no_grad():
            target_qs = self.agent.target(*samples.target_inputs)
            if self.double_dqn:
                next_qs = self.agent(*samples.target_inputs)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
        disc_target_q = (self.discount**self.n_step_return) * target_q
        y = samples.return_ + (1 - samples.done_n.float()) * disc_target_q
        delta = y - q
        losses = 0.5 * delta**2
        abs_delta = abs(delta)
        if self.delta_clip is not None:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        if self.prioritized_replay:
            losses *= samples.is_weights
        td_abs_errors = torch.clamp(abs_delta.detach(), 0, self.delta_clip)
        if not self.mid_batch_reset:
            valid = valid_from_done(samples.done)
            loss = valid_mean(losses, valid)
            td_abs_errors *= valid
        else:
            loss = torch.mean(losses)

        return loss, td_abs_errors
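
When delta_clip is set, the piecewise expression in Example #13 is exactly the Huber loss. A quick self-check, assuming a PyTorch new enough (>= 1.9) to provide torch.nn.functional.huber_loss:

import torch
import torch.nn.functional as F

delta_clip = 1.0
q = torch.randn(64)
y = torch.randn(64)

delta = y - q
losses = 0.5 * delta ** 2
b = delta_clip * (delta.abs() - delta_clip / 2)
losses = torch.where(delta.abs() <= delta_clip, losses, b)

# F.huber_loss computes 0.5 * d^2 for |d| <= delta, else delta * (|d| - delta / 2).
assert torch.allclose(losses, F.huber_loss(q, y, reduction='none', delta=delta_clip))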
Example #14
    def loss(self, samples):
        """
        Computes the Distributional Q-learning loss, based on projecting the
        discounted rewards + target Q-distribution into the current Q-domain,
        with cross-entropy loss.

        Returns loss and KL-divergence-errors for use in prioritization.
        """

        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make a 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount ** self.n_step_return)  # [P']
        next_z = torch.ger(1 - samples.done_n.float(), next_z)  # [B,P']
        ret = samples.return_.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(*samples.target_inputs)  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(*samples.target_inputs)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent(*samples.agent_inputs)  # [B,A,P]
        p = select_at_indexes(samples.action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= samples.is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
            (torch.log(target_p) - torch.log(p.detach())), dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            valid = valid_from_done(samples.done)
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)

        return loss, KL_div
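
The projection_coeffs tensor in Example #14 spreads each clamped next_z atom's probability over its two neighbouring base atoms, so total probability mass is preserved. A small illustrative check (shapes and values are made up):

import torch

V_min, V_max, n_atoms = -10., 10., 51
delta_z = (V_max - V_min) / (n_atoms - 1)
z = torch.linspace(V_min, V_max, n_atoms)                          # [P]
next_z = torch.clamp(torch.randn(8, n_atoms) * 5., V_min, V_max)   # [B,P'], already clamped

z_bc = z.view(1, -1, 1)          # [1,P,1]
next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
projection_coeffs = torch.clamp(1 - (next_z_bc - z_bc).abs() / delta_z, 0, 1)  # [B,P,P']

# Each next_z atom falls between two base atoms whose coefficients sum to 1, so the
# projected distribution (target_p_unproj * projection_coeffs).sum(-1) keeps total mass 1.
assert torch.allclose(projection_coeffs.sum(dim=1), torch.ones_like(next_z))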
Example #15
File: vae.py  Project: kevinghst/rl_ul
    def vae_loss(self, samples):
        observation = samples.observation[0]  # [T,B,C,H,W]->[B,C,H,W]
        target_observation = samples.observation[self.delta_T]
        if self.delta_T > 0:
            action = samples.action[:-1]  # [T-1,B,A]  don't need the last one
            if self.onehot_action:
                action = self.distribution.to_onehot(action)
            t, b = action.shape[:2]
            action = action.transpose(1, 0)  # [B,T-1,A]
            action = action.reshape(b, -1)
        else:
            action = None
        observation, target_observation, action = buffer_to(
            (observation, target_observation, action),
            device=self.device
        )

        h, conv_out = self.encoder(observation)
        z, mu, logvar = self.vae_head(h, action)
        recon_z = self.decoder(z)

        if target_observation.dtype == torch.uint8:
            target_observation = target_observation.type(torch.float)
            target_observation = target_observation.mul_(1 / 255.)

        b, c, h, w = target_observation.shape
        recon_losses = F.binary_cross_entropy(
            input=recon_z.reshape(b * c, h, w),
            target=target_observation.reshape(b * c, h, w),
            reduction="none",
        )
        if self.delta_T > 0:
            valid = valid_from_done(samples.done).type(torch.bool)  # [T,B]
            valid = valid[-1]  # [B]
            valid = valid.to(self.device)
        else:
            valid = None  # all are valid
        recon_losses = recon_losses.view(b, c, h, w).sum(dim=(2, 3))  # sum over H,W
        recon_losses = recon_losses.mean(dim=1)  # mean over C (o/w loss is HUGE)
        recon_loss = valid_mean(recon_losses, valid=valid)  # mean over batch

        kl_losses = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_losses = kl_losses.sum(dim=-1)  # sum over latent dimension
        kl_loss = -0.5 * valid_mean(kl_losses, valid=valid)  # mean over batch
        kl_loss = self.kl_coeff * kl_loss

        return recon_loss, kl_loss, conv_out
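
The kl_losses line in Example #15 is the closed-form KL(N(mu, exp(logvar)) || N(0, I)). A quick check against torch.distributions (shapes are illustrative):

import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(16, 32)
logvar = torch.randn(16, 32)

closed_form = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=-1)
reference = kl_divergence(
    Normal(mu, (0.5 * logvar).exp()),                       # q(z|x), std = exp(logvar / 2)
    Normal(torch.zeros_like(mu), torch.ones_like(logvar)),  # p(z) = N(0, I)
).sum(dim=-1)

assert torch.allclose(closed_form, reference, atol=1e-5)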
Example #16
    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action),
            device=self.agent.device)  # Move to device once, re-use.
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs)
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        v = self.agent.v(*agent_inputs)
        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())
        v_target = (min_log_target - log_pi +
                    prior_log_pi).detach()  # No grad.
        v_loss = 0.5 * valid_mean((v - v_target)**2, valid)

        if self.reparameterize:
            pi_losses = log_pi - min_log_target
        else:
            pi_factor = (v - v_target).detach()  # No grad.
            pi_losses = log_pi * pi_factor
        if self.policy_output_regularization > 0:
            pi_losses += torch.sum(
                self.policy_output_regularization * 0.5 * pi_mean**2 +
                pi_log_std**2,
                dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        losses = (q1_loss, q2_loss, v_loss, pi_loss)
        values = tuple(val.detach()
                       for val in (q1, q2, v, pi_mean, pi_log_std))
        return losses, values
Example #17
File: rlpyt_algos.py  Project: zivzone/spr
    def rl_loss(self, latent, action, return_n, done_n, prev_action,
                prev_reward, next_state, next_prev_action, next_prev_reward,
                is_weights, done):

        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount**self.n_step_return)  # [P']
        next_z = torch.ger(1 - done_n.float(), next_z)  # [B,P']
        ret = return_n.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(next_state, next_prev_action,
                                          next_prev_reward)  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(next_state, next_prev_action,
                                     next_prev_reward)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent.head_forward(latent, prev_action,
                                     prev_reward)  # [B,A,P]
        p = select_at_indexes(action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
                           (torch.log(target_p) - torch.log(p.detach())),
                           dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            valid = valid_from_done(done[1])
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)

        return loss, KL_div
Example #18
    def q_loss(self, samples):
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # To avoid non-use of bootstrap when environment is 'done' due to
            # time-limit, turn off training on these samples.
            valid *= 1 - samples.timeout_n.float()

        # Run the convolution only once, return so pi_loss can use it.
        if self.store_latent:
            conv_out = None
            q_inputs = samples.agent_inputs
        else:
            conv_out = self.agent.conv(samples.agent_inputs.observation)
            if self.stop_conv_grad:
                conv_out = conv_out.detach()
            q_inputs = samples.agent_inputs._replace(observation=conv_out)

        # Q LOSS.
        q1, q2 = self.agent.q(*q_inputs, samples.action)
        with torch.no_grad():
            # Run the target convolution only once.
            if self.store_latent:
                target_inputs = samples.target_inputs
            else:
                target_conv_out = self.agent.target_conv(
                    samples.target_inputs.observation
                )
                target_inputs = samples.target_inputs._replace(
                    observation=target_conv_out
                )
            target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs, target_action)
            min_target_q = torch.min(target_q1, target_q2)
            target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount ** self.n_step_return
        y = (
            self.reward_scale * samples.return_
            + (1 - samples.done_n.float()) * disc * target_value
        )
        q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
        q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)

        return q1_loss, q2_loss, valid, conv_out, q1.detach(), q2.detach()
Example #19
    def loss(self, samples):
        """
        Computes the training loss: policy_loss + value_loss + entropy_loss.
        Policy loss: log-likelihood of actions * advantages
        Value loss: 0.5 * (estimated_value - return) ^ 2
        Organizes agent inputs from training samples, calls the agent instance
        to run forward pass on training data, and uses the
        ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        agent_inputs = AgentInputs(
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        )
        if self.agent.recurrent:
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[
                0]  # T = 0.
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(
                *agent_inputs,
                init_rnn_state,
                device=agent_inputs.prev_action.device)
        else:
            dist_info, value = self.agent(
                *agent_inputs, device=agent_inputs.prev_action.device)
        # TODO: try to compute everything on device.
        return_, advantage, valid = self.process_returns(samples)

        dist = self.agent.distribution
        logli = dist.log_likelihood(samples.agent.action, dist_info)
        pi_loss = -valid_mean(logli * advantage, valid)

        value_error = 0.5 * (value - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)

        return loss, pi_loss, value_loss, entropy, perplexity
Example #20
 def compute_loss(self, observations, next_observations, actions, valid):
     # dimension add for when you have only one environment
     if actions.dim() == 2: actions = actions.unsqueeze(1)
     phi1, phi2, predicted_phi2, predicted_action = self.forward(
         observations, next_observations, actions)
     actions = torch.max(actions.view(-1, *actions.shape[2:]),
                         1)[1]  # convert action to (T * B, action_size)
     inverse_loss = nn.functional.cross_entropy(
         predicted_action.view(-1, *predicted_action.shape[2:]),
         actions.detach(),
         reduction='none').view(phi1.shape[0], phi1.shape[1])
     forward_loss = nn.functional.mse_loss(
         predicted_phi2, phi2.detach(),
         reduction='none').sum(-1) / self.feature_size
     inverse_loss = valid_mean(inverse_loss, valid.detach())
     forward_loss = valid_mean(forward_loss, valid.detach())
     return self.inverse_loss_wt * inverse_loss, self.forward_loss_wt * forward_loss
Example #21
File: sac_v2.py  Project: wilson1yan/rlpyt
    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))
        qs = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs).detach()
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q_losses = [0.5 * valid_mean((y - q)**2, valid) for q in qs]

        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_targets = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(torch.stack(log_targets, dim=0), dim=0)[0]
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            alpha = self.agent.log_alpha.exp().detach()
            pi_losses = alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError  # pi_losses / alpha undefined otherwise.

        if self.policy_output_regularization > 0:
            pi_losses += torch.sum(
                self.policy_output_regularization * 0.5 * pi_mean**2 +
                pi_log_std**2,
                dim=-1)

        pi_loss = valid_mean(pi_losses, valid)

        # Calculate log_alpha loss
        alpha_loss = -valid_mean(self.agent.log_alpha *
                                 (log_pi + self.target_entropy).detach())

        losses = (pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (pi_mean, pi_log_std, alpha))
        q_values = tuple(q.detach() for q in qs)
        return q_losses, losses, values, q_values
Example #22
    def beta_kl_losses(
        self,
        agent_inputs,
        action,
        return_,
        advantage,
        valid,
        old_dist_info,
        c_return,
        c_advantage,
        init_rnn_state=None,
    ):
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            r_dist_info, c_dist_info = self.agent.beta_dist_infos(
                *agent_inputs, init_rnn_state)
        else:
            r_dist_info, c_dist_info = self.agent.beta_dist_infos(
                *agent_inputs)
        dist = self.agent.distribution

        r_ratio = dist.likelihood_ratio(action,
                                        old_dist_info=old_dist_info,
                                        new_dist_info=r_dist_info)
        surr_1 = r_ratio * advantage
        r_clipped_ratio = torch.clamp(r_ratio, 1.0 - self.ratio_clip,
                                      1.0 + self.ratio_clip)
        surr_2 = r_clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        beta_r_loss = -valid_mean(surrogate, valid)

        c_ratio = dist.likelihood_ratio(action,
                                        old_dist_info=old_dist_info,
                                        new_dist_info=c_dist_info)
        c_surr_1 = c_ratio * c_advantage
        c_clipped_ratio = torch.clamp(c_ratio, 1.0 - self.ratio_clip,
                                      1.0 + self.ratio_clip)
        c_surr_2 = c_clipped_ratio * c_advantage
        c_surrogate = torch.max(c_surr_1, c_surr_2)
        beta_c_loss = valid_mean(c_surrogate, valid)

        return beta_r_loss, beta_c_loss
Example #23
    def compute_loss(self, observations, next_observations, actions, valid):
        #------------------------------------------------------------#
        # hacky dimension add for when you have only one environment (debugging)
        if actions.dim() == 2:
            actions = actions.unsqueeze(1)
        #------------------------------------------------------------#
        phi2, predicted_phi2, _ = self.forward(observations, next_observations,
                                               actions)

        # Sum the (dropout-masked) forward-model errors from each of the five prediction heads.
        forward_loss = torch.tensor(0.0, device=self.device)
        for predicted in predicted_phi2[:5]:
            head_loss = nn.functional.dropout(
                nn.functional.mse_loss(predicted, phi2.detach(), reduction='none'),
                p=0.2).sum(-1) / self.feature_size
            forward_loss += valid_mean(head_loss, valid)

        return self.forward_loss_wt * forward_loss
Example #24
File: ddpg.py  Project: BorenTsai/rlpyt
 def q_loss(self, samples, valid):
     """Samples have leading batch dimension [B,..] (but not time)."""
     q = self.agent.q(*samples.agent_inputs, samples.action)
     with torch.no_grad():
         target_q = self.agent.target_q_at_mu(*samples.target_inputs)
     disc = self.discount**self.n_step_return
     y = samples.return_ + (1 - samples.done_n.float()) * disc * target_q
     y = torch.clamp(y, -self.q_target_clip, self.q_target_clip)
     q_losses = 0.5 * (y - q)**2
     q_loss = valid_mean(q_losses, valid)  # valid can be None.
     return q_loss
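
samples.return_ and samples.done_n in the replay-based examples are assumed to already hold the n-step discounted return and the "done within n steps" flag computed by the replay buffer. For reference only, a standalone sketch of those quantities from a reward sequence (truncated at the end of the batch; not the project's buffer code):

import torch

def n_step_return(rewards, dones, n_step=1, discount=0.99):
    """rewards, dones: [T, B]. Returns (return_, done_n), each [T, B]."""
    T, B = rewards.shape
    return_ = torch.zeros(T, B)
    done_n = torch.zeros(T, B)
    for t in range(T):
        gamma_k = 1.
        done = torch.zeros(B)
        for k in range(min(n_step, T - t)):
            return_[t] += gamma_k * (1. - done) * rewards[t + k]
            done = torch.max(done, dones[t + k].float())
            gamma_k *= discount
        done_n[t] = done
    return return_, done_n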
Example #25
 def compute_loss(self, observations, valid):
     phi, predicted_phi, T, B = self.forward(observations, done=None)
     forward_loss = nn.functional.mse_loss(
         predicted_phi, phi.detach(),
         reduction='none').sum(-1) / self.feature_size
     mask = torch.rand(forward_loss.shape)
     mask = (mask > self.drop_probability).type(torch.FloatTensor).to(
         self.device)
     forward_loss = forward_loss * mask.detach()
     forward_loss = valid_mean(forward_loss, valid.detach())
     return forward_loss
Example #26
 def q_loss(self, samples, valid):
     q1, q2 = self.agent.q(*samples.agent_inputs, samples.action)
     with torch.no_grad():
         target_q1, target_q2 = self.agent.target_q_at_mu(
             *samples.target_inputs)  # Includes target action noise.
         target_q = torch.min(target_q1, target_q2)
     disc = self.discount**self.n_step_return
     y = samples.return_ + (1 - samples.done_n.float()) * disc * target_q
     q1_losses = 0.5 * (y - q1)**2
     q2_losses = 0.5 * (y - q2)**2
     q_loss = valid_mean(q1_losses + q2_losses, valid)  # valid can be None.
     return q_loss
Example #27
    def continuous_actions_loss(self, advantage_mask, phi, action, dist_info,
                                old_dist_info, valid, opt_info):
        d = np.prod(action.shape[-1])
        distribution = torch.distributions.normal.Normal(
            loc=dist_info.mean, scale=dist_info.log_std)
        pi_loss = -torch.sum(
            advantage_mask *
            (phi.detach() * distribution.log_prob(action).sum(dim=-1)))
        # pi_loss = - torch.sum(advantage_mask * (phi.detach() * self.agent.distribution.log_likelihood(action, dist_info)))
        new_std = dist_info.log_std
        old_std = old_dist_info.log_std
        old_covariance = torch.diag_embed(old_std)
        old_covariance_inverse = torch.diag_embed(1 / old_std)
        new_covariance_inverse = torch.diag_embed(1 / new_std)
        old_covariance_determinant = torch.prod(old_std, dim=-1)
        new_covariance_determinant = torch.prod(new_std, dim=-1)

        mu_kl = 0.5 * utils.batched_quadratic_form(
            dist_info.mean - old_dist_info.mean, old_covariance_inverse)
        trace = utils.batched_trace(
            torch.matmul(new_covariance_inverse, old_covariance))
        sigma_kl = 0.5 * (trace - d + torch.log(
            new_covariance_determinant / old_covariance_determinant))
        alpha_mu_loss = valid_mean(
            self.alpha_mu * (self.epsilon_alpha_mu - mu_kl.detach()) +
            self.alpha_mu.detach() * mu_kl, valid)
        alpha_sigma_loss = valid_mean(
            self.alpha_sigma * (self.epsilon_alpha_sigma - sigma_kl.detach()) +
            self.alpha_sigma.detach() * sigma_kl, valid)
        opt_info.alpha_mu.append(self.alpha_mu.item())
        opt_info.alpha_sigma.append(self.alpha_sigma.item())
        opt_info.alpha_mu_loss.append(alpha_mu_loss.item())
        opt_info.mu_kl.append(valid_mean(mu_kl, valid).item())
        opt_info.sigma_kl.append(valid_mean(sigma_kl, valid).item())
        opt_info.alpha_sigma_loss.append(
            valid_mean(self.epsilon_alpha_sigma - sigma_kl, valid).item())
        opt_info.pi_mu.append(dist_info.mean.mean().item())
        opt_info.pi_log_std.append(dist_info.log_std.mean().item())
        return pi_loss, alpha_mu_loss + alpha_sigma_loss, opt_info
Example #28
    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))
        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_v = self.agent.target_v(*target_inputs).detach()
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_v)
        if self.mid_batch_reset and not self.agent.recurrent:
            valid = None  # OR: torch.ones_like(samples.done, dtype=torch.float)
        else:
            valid = valid_from_done(samples.done)

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        new_action, log_pi, _ = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            alpha = self.agent.log_alpha.exp().detach()
            pi_losses = alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError  # pi_losses / alpha undefined otherwise.

        pi_loss = valid_mean(pi_losses, valid)

        # Calculate log_alpha loss
        alpha_loss = -valid_mean(self.agent.log_alpha *
                                 (log_pi + self.target_entropy).detach())

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, alpha))
        return losses, values
Example #29
 def discrete_actions_loss(self, advantage_mask, phi, action, dist_info,
                           old_dist_info, valid, opt_info):
     dist = self.agent.distribution
     pi_loss = -torch.sum(advantage_mask *
                          (phi.detach() * dist.log_likelihood(
                              action.contiguous(), dist_info)))
     policy_kl = dist.kl(old_dist_info, dist_info)
     alpha_loss = valid_mean(
         self.alpha * (self.epsilon_alpha - policy_kl.detach()) +
         self.alpha.detach() * policy_kl, valid)
     opt_info.alpha_loss.append(alpha_loss.item())
     opt_info.alpha.append(self.alpha.item())
     opt_info.policy_kl.append(policy_kl.mean().item())
     opt_info.entropy.append(dist.entropy(dist_info).mean().item())
     return pi_loss, alpha_loss, opt_info
Example #30
File: ppo.py  Project: ajabri/rlpyt
    def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info,
            init_rnn_state=None):
        if init_rnn_state is not None:
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
        else:
            dist_info, value = self.agent(*agent_inputs)

        # TODO IF MULTIAGENT, reshape things
        # Just kidding. It seems that the dist.* functions can operate on multi-dimensional tensors
        # Need to double check this is true, but things seem to be ok
        # (entropy and likelihood ratios are computed along last dim)

        dist = self.agent.distribution

        ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
            new_dist_info=dist_info)
        surr_1 = ratio * advantage
        clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
            1. + self.ratio_clip)
        surr_2 = clipped_ratio * advantage
        surrogate = torch.min(surr_1, surr_2)
        pi_loss = - valid_mean(surrogate, valid)

        value_error = 0.5 * (value - return_) ** 2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        entropy = dist.mean_entropy(dist_info, valid)
        entropy_loss = - self.entropy_loss_coeff * entropy

        loss = pi_loss + value_loss + entropy_loss

        perplexity = dist.mean_perplexity(dist_info, valid)
        return loss, entropy, perplexity
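
Finally, a note on the entropy / perplexity statistics returned throughout: in the rlpyt distributions these appear to be per-sample quantities, with perplexity = exp(entropy), averaged over valid entries. A minimal sketch for a categorical dist_info.prob of shape [..., A] (an assumption about the convention, not rlpyt's exact code):

import torch

def entropy(prob, eps=1e-8):
    return -(prob * (prob + eps).log()).sum(dim=-1)

def mean_perplexity(prob, valid=None):
    perplexity = entropy(prob).exp()   # a uniform policy over A actions gives A
    if valid is None:
        return perplexity.mean()
    return (perplexity * valid).sum() / valid.sum()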