Example No. 1
    def loss(self, samples):
        """
        Computes losses for twin Q-values against the min of twin target Q-values
        and an entropy term.  Computes reparameterized policy loss, and loss for
        tuning entropy weighting, alpha.  
        
        Input samples have leading batch dimension [B,..] (but not time).
        """
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))

        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # When the episode ended only because of a time limit, the
            # (1 - done_n) factor wrongly drops the bootstrap value; rather
            # than fix the target, turn off training on those samples.
            valid *= (1 - samples.timeout_n.float())

        q1, q2 = self.agent.q(*agent_inputs, action)
        with torch.no_grad():
            target_action, target_log_pi, _ = self.agent.pi(*target_inputs)
            target_q1, target_q2 = self.agent.target_q(*target_inputs,
                                                       target_action)
        min_target_q = torch.min(target_q1, target_q2)
        target_value = min_target_q - self._alpha * target_log_pi
        disc = self.discount**self.n_step_return
        y = (self.reward_scale * samples.return_ +
             (1 - samples.done_n.float()) * disc * target_value)

        q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
        q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

        new_action, log_pi, (pi_mean,
                             pi_log_std) = self.agent.pi(*agent_inputs)
        if not self.reparameterize:
            new_action = new_action.detach()  # No grad.
        log_target1, log_target2 = self.agent.q(*agent_inputs, new_action)
        min_log_target = torch.min(log_target1, log_target2)
        prior_log_pi = self.get_action_prior(new_action.cpu())

        if self.reparameterize:
            pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
        else:
            raise NotImplementedError

        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)
        pi_loss = valid_mean(pi_losses, valid)

        if self.target_entropy is not None:
            alpha_losses = -self._log_alpha * (log_pi.detach() +
                                               self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)
        else:
            alpha_loss = None

        losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, pi_mean, pi_log_std))
        return losses, values
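
For reference, here is a minimal standalone sketch (not from the repo) of the soft Bellman target used above; all tensor values and the hyperparameters B, discount, alpha, and reward_scale are hypothetical stand-ins, and valid_mean reduces to a plain mean when every sample is valid:

# Hedged sketch: y = scale*G + (1 - done) * gamma^n * (min(Q1', Q2') - alpha * log_pi').
import torch

B, discount, n_step, alpha, reward_scale = 4, 0.99, 1, 0.2, 1.0
return_ = torch.randn(B)                 # n-step discounted return G
done_n = torch.zeros(B)                  # 1.0 where the n-step window hit a terminal
target_q1, target_q2 = torch.randn(B), torch.randn(B)
target_log_pi = torch.randn(B)           # log pi(a'|s') for the sampled next action
q1 = torch.randn(B)                      # online Q at the stored (s, a)

target_value = torch.min(target_q1, target_q2) - alpha * target_log_pi
y = reward_scale * return_ + (1 - done_n) * discount**n_step * target_value
q1_loss = 0.5 * ((y - q1) ** 2).mean()   # valid_mean with all-ones valid is just mean()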
Example No. 2
    def loss(self, samples):
        """
        Computes the Distributional Q-learning loss, based on projecting the
        discounted rewards + target Q-distribution into the current Q-domain,
        with cross-entropy loss.  

        Returns loss and KL-divergence-errors for use in prioritization.
        """

        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make a 2-D tensor of the contracted z_domain for each data point,
        # with zeros where the next value should not be added.
        next_z = z * (self.discount**self.n_step_return)  # [P']
        next_z = torch.ger(1 - samples.done_n.float(), next_z)  # [B,P']
        ret = samples.return_.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Mostly 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(*samples.target_inputs)  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(*samples.target_inputs)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent(*samples.agent_inputs)  # [B,A,P]
        p = select_at_indexes(samples.action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= samples.is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
                           (torch.log(target_p) - torch.log(p.detach())),
                           dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            valid = valid_from_done(samples.done)
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)

        return loss, KL_div
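
As a sanity check on the projection step above, the snippet below (not from the repo) builds the same clamp(1 - |next_z - z| / delta_z, 0, 1) coefficients on a tiny hand-made support; a single Bellman-updated atom at 0.2 splits its probability 60/40 between the neighboring support atoms 0.0 and 0.5:

# Hedged sketch of the categorical projection coefficients on a 5-atom support.
import torch

V_min, V_max, n_atoms = -1.0, 1.0, 5
delta_z = (V_max - V_min) / (n_atoms - 1)        # 0.5
z = torch.linspace(V_min, V_max, n_atoms)        # [-1.0, -0.5, 0.0, 0.5, 1.0]
next_z = torch.tensor([[0.2]])                   # one updated atom, shape [B=1, P'=1]

z_bc = z.view(1, -1, 1)                          # [1, P, 1]
next_z_bc = next_z.unsqueeze(1)                  # [B, 1, P']
coeffs = torch.clamp(1 - (next_z_bc - z_bc).abs() / delta_z, 0, 1)  # [B, P, P']
print(coeffs.squeeze())   # approx. [0.0, 0.0, 0.6, 0.4, 0.0]; each next_z column sums to 1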
Example No. 3
    def compute_input_priorities(self, samples):
        """Used when putting new samples into the replay buffer.  Computes
        n-step TD-errors using recorded Q-values from online network and
        value scaling.  Weights the max and the mean TD-error over each sequence
        to make a single priority value for that sequence.  

        Note:
            Although the original R2D2 implementation used the entire
            80-step sequence to compute the input priorities, we ran R2D1 with 40
            time-step sample batches, and so computed the priority for each
            80-step training sequence based on one of the two 40-step halves.
            Algorithm argument ``input_priority_shift`` determines which 40-step
            half is used as the priority for the 80-step sequence.  (Since this
            method might get executed by alternating memory copiers in async mode,
            don't carry internal state here; do all computation with only the
            samples available in the input.  Could probably reduce to one memory
            copier and keep state there, if needed.)
        """

        # """Just for first input into replay buffer.
        # Simple 1-step return TD-errors using recorded Q-values from online
        # network and value scaling, with the T dimension reduced away (same
        # priority applied to all samples in this batch; wherever the rnn state
        # is kept--hopefully the first step--this priority will apply there).
        # The samples duration T might be less than the training segment, so
        # this is an approximation of an approximation, but hopefully will
        # capture the right behavior.
        # UPDATE 20190826: Trying using n-step returns.  For now using samples
        # with full n-step return available...later could also use partial
        # returns for samples at end of batch.  35/40 ain't bad tho.
        # Might not carry/use internal state here, because might get executed
        # by alternating memory copiers in async mode; do all with only the
        # samples available from input."""
        samples = torchify_buffer(samples)
        q = samples.agent.agent_info.q
        action = samples.agent.action
        q_max = torch.max(q, dim=-1).values
        q_at_a = select_at_indexes(action, q)
        return_n, done_n = discount_return_n_step(
            reward=samples.env.reward,
            done=samples.env.done,
            n_step=self.n_step_return,
            discount=self.discount,
            do_truncated=False,  # Only samples with full n-step return.
        )
        # y = self.value_scale(
        #     samples.env.reward[:-1] +
        #     (self.discount * (1 - samples.env.done[:-1].float()) *  # probably done.float()
        #         self.inv_value_scale(q_max[1:]))
        # )
        # At least 1 because we don't have the next Q.
        nm1 = max(1, self.n_step_return - 1)
        y = self.value_scale(return_n + (1 - done_n.float()) *
                             self.inv_value_scale(q_max[nm1:]))
        delta = abs(q_at_a[:-nm1] - y)
        # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
        if self.delta_clip is not None:  # Huber loss.
            delta = torch.clamp(delta, 0, self.delta_clip)
        valid = valid_from_done(samples.env.done[:-nm1])
        max_d = torch.max(delta * valid, dim=0).values
        # Mean over valid steps only, so it stays high even when few steps are valid.
        mean_d = valid_mean(delta, valid, dim=0)
        priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]
        return priorities.numpy()
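
The value_scale / inv_value_scale pair referenced above is not shown in this snippet; below is a hedged sketch compatible with the invertible rescaling h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x used in R2D2, with a hypothetical eps default (the repo's exact implementation may differ):

# Hedged sketch of an invertible value rescaling and its closed-form inverse.
import torch

def value_scale(x, eps=1e-3):
    return torch.sign(x) * (torch.sqrt(x.abs() + 1.) - 1.) + eps * x

def inv_value_scale(y, eps=1e-3):
    # Closed-form inverse of value_scale.
    return torch.sign(y) * (
        ((torch.sqrt(1. + 4. * eps * (y.abs() + 1. + eps)) - 1.) / (2. * eps)) ** 2 - 1.)

x = torch.tensor([-10., 0., 3.])
print(inv_value_scale(value_scale(x)))   # approx. [-10., 0., 3.]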
Example No. 4
    def loss(self, samples):
        """Samples have leading Time and Batch dimentions [T,B,..]. Move all
        samples to device first, and then slice for sub-sequences.  Use same
        init_rnn_state for agent and target; start both at same t.  Warmup the
        RNN state first on the warmup subsequence, then train on the remaining
        subsequence.

        Returns loss (usually use MSE, not Huber), TD-error absolute values,
        and new sequence-wise priorities, based on weighted sum of max and mean
        TD-error over the sequence."""
        all_observation, all_action, all_reward = buffer_to(
            (samples.all_observation, samples.all_action, samples.all_reward),
            device=self.agent.device)
        wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return
        if wT > 0:
            warmup_slice = slice(None, wT)  # Same for agent and target.
            warmup_inputs = AgentInputs(
                observation=all_observation[warmup_slice],
                prev_action=all_action[warmup_slice],
                prev_reward=all_reward[warmup_slice],
            )
        agent_slice = slice(wT, wT + bT)
        agent_inputs = AgentInputs(
            observation=all_observation[agent_slice],
            prev_action=all_action[agent_slice],
            prev_reward=all_reward[agent_slice],
        )
        target_slice = slice(wT, None)  # Same start t as agent; runs to the end (wT + bT + nsr).
        target_inputs = AgentInputs(
            observation=all_observation[target_slice],
            prev_action=all_action[target_slice],
            prev_reward=all_reward[target_slice],
        )
        action = samples.all_action[wT + 1:wT + 1 + bT]  # CPU.
        return_ = samples.return_[wT:wT + bT]
        done_n = samples.done_n[wT:wT + bT]
        if self.store_rnn_state_interval == 0:
            init_rnn_state = None
        else:
            # [B,N,H]-->[N,B,H] cudnn.
            init_rnn_state = buffer_method(samples.init_rnn_state, "transpose",
                                           0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        if wT > 0:  # Do warmup.
            with torch.no_grad():
                _, target_rnn_state = self.agent.target(
                    *warmup_inputs, init_rnn_state)
                _, init_rnn_state = self.agent(*warmup_inputs, init_rnn_state)
            # Recommend aligning sampling batch_T and store_rnn_interval with
            # warmup_T (and no mid_batch_reset), so that end of trajectory
            # during warmup leads to new trajectory beginning at start of
            # training segment of replay.
            warmup_invalid_mask = valid_from_done(
                samples.done[:wT])[-1] == 0  # [B]
            init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
            target_rnn_state[:, warmup_invalid_mask] = 0
        else:
            target_rnn_state = init_rnn_state

        qs, _ = self.agent(*agent_inputs, init_rnn_state)  # [T,B,A]
        q = select_at_indexes(action, qs)
        with torch.no_grad():
            target_qs, _ = self.agent.target(*target_inputs, target_rnn_state)
            if self.double_dqn:
                next_qs, _ = self.agent(*target_inputs, init_rnn_state)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
            target_q = target_q[-bT:]  # Same length as q.

        disc = self.discount**self.n_step_return
        y = self.value_scale(return_ + (1 - done_n.float()) * disc *
                             self.inv_value_scale(target_q))  # [T,B]
        delta = y - q
        losses = 0.5 * delta**2
        abs_delta = abs(delta)
        # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
        if self.delta_clip is not None:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        if self.prioritized_replay:
            losses *= samples.is_weights.unsqueeze(0)  # weights: [B] --> [1,B]
        valid = valid_from_done(samples.done[wT:])  # 0 after first done.
        loss = valid_mean(losses, valid)
        td_abs_errors = abs_delta.detach()
        if self.delta_clip is not None:
            td_abs_errors = torch.clamp(td_abs_errors, 0,
                                        self.delta_clip)  # [T,B]
        valid_td_abs_errors = td_abs_errors * valid
        max_d = torch.max(valid_td_abs_errors, dim=0).values
        # Mean over valid steps only, so it stays high even when few steps are valid.
        mean_d = valid_mean(td_abs_errors, valid, dim=0)
        priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]

        return loss, valid_td_abs_errors, priorities
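
The sequence-priority mix at the end mirrors Example No. 3; the standalone sketch below (not from the repo) shows it on tiny hand-made TD-errors with a validity mask over time:

# Hedged sketch of priorities = eta * max_t + (1 - eta) * valid_mean_t of |TD error|.
import torch

pri_eta = 0.9
td_abs_errors = torch.tensor([[1.0, 0.2],
                              [0.5, 0.4],
                              [2.0, 0.1]])      # [T=3, B=2]
valid = torch.tensor([[1., 1.],
                      [1., 1.],
                      [0., 1.]])                # last step of the first sequence is past 'done'

max_d = torch.max(td_abs_errors * valid, dim=0).values             # [B]
mean_d = (td_abs_errors * valid).sum(dim=0) / valid.sum(dim=0)     # valid_mean over time
priorities = pri_eta * max_d + (1 - pri_eta) * mean_d              # [B]
print(priorities)   # approx. [0.975, 0.383]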
Example No. 5
    def loss(self, samples):
        """
        Computes losses for twin Q-values against the min of twin target Q-values
        and an entropy term.  Computes reparameterized policy loss, and loss for
        tuning entropy weighting, alpha.  
        
        Input samples have leading batch dimension [B,..] (but not time).
        """
        agent_inputs, target_inputs, action = buffer_to(
            (samples.agent_inputs, samples.target_inputs, samples.action))

        if self.mid_batch_reset and not self.agent.recurrent:
            valid = torch.ones_like(samples.done, dtype=torch.float)  # or None
        else:
            valid = valid_from_done(samples.done)
        if self.bootstrap_timelimit:
            # When the episode ended only because of a time limit, the
            # (1 - done_n) factor wrongly drops the bootstrap value; rather
            # than fix the target, turn off training on those samples.
            valid *= (1 - samples.timeout_n.float())

        # Q-values for all actions, [B,A]; optionally detach the encoder for the critic.
        q1_logits, q2_logits = self.agent.q(*agent_inputs, detach_encoder=self.detach_critic)

        # print(action, action.requires_grad)
        # Q-values at the taken actions, [B,1].
        q1 = q1_logits.gather(1, action.long().unsqueeze(-1))
        q2 = q2_logits.gather(1, action.long().unsqueeze(-1))

        with torch.no_grad():
            target_action, target_log_pi, target_dist_info = self.agent.pi(*target_inputs) # TODO Get act prob, and correct target log pi
            target_q1, target_q2 = self.agent.target_q(*target_inputs) # Note: remove action input
            min_target_q = torch.min(target_q1, target_q2)
            # Expectation over the action distribution.
            target_value = torch.sum(target_dist_info.prob * (min_target_q - self._alpha * target_log_pi),
                                     dim=1, keepdim=True)  # TODO: Verify if this should be sum or mean.
        disc = self.discount ** self.n_step_return
        y = (self.reward_scale * samples.return_.unsqueeze(-1) +
            (1 - samples.done_n.float().unsqueeze(-1)) * disc * target_value)

        q1_loss = 0.5 * valid_mean((y - q1) ** 2, valid)
        q2_loss = 0.5 * valid_mean((y - q2) ** 2, valid)
        critic_loss = q1_loss + q2_loss
        
        new_action, log_pi, dist_info = self.agent.pi(*agent_inputs, detach_encoder=True)

        log_target1, log_target2 = self.agent.q(*agent_inputs, detach_encoder=True)
        min_log_target = torch.min(log_target1, log_target2)
        # prior_log_pi = self.get_action_prior(new_action.cpu())

        # Expectation over the action distribution of (alpha * log_pi - min Q).
        pi_losses = torch.sum(dist_info.prob * (self._alpha * log_pi - min_log_target),
                              dim=1, keepdim=True)
        # print("Losses Shape", pi_losses.shape)
        # if self.policy_output_regularization > 0:
        #     pi_losses += self.policy_output_regularization * torch.mean(
        #         0.5 * pi_mean ** 2 + 0.5 * pi_log_std ** 2, dim=-1)

        pi_loss = valid_mean(pi_losses, valid)

        if self.target_entropy is not None:
            pi_entropy = torch.sum(dist_info.prob * log_pi, dim=1)  # E[log pi], i.e. negative entropy.
            alpha_losses = - self._log_alpha * (pi_entropy.detach() + self.target_entropy)
            alpha_loss = valid_mean(alpha_losses, valid)

            # Alternate Alpha Loss from explicit entropy? # TODO: Investigate
            # alpha_losses = - self.log_alpha * ()
        else:
            alpha_loss = None

        losses = (critic_loss, pi_loss, alpha_loss)
        values = tuple(val.detach() for val in (q1, q2, log_pi))
        return losses, values
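
For the discrete-action expectation used in this variant, here is a tiny standalone sketch (not from the repo) of the per-sample policy objective sum_a pi(a|s) * (alpha * log pi(a|s) - min_Q(s,a)), along with the E[log pi] term fed to the alpha loss; the probabilities and Q-values are hand-made stand-ins:

# Hedged sketch of the discrete-SAC policy loss and entropy term for one state.
import torch

alpha = 0.2
prob = torch.tensor([[0.7, 0.2, 0.1]])      # pi(a|s), shape [B=1, A=3]
log_pi = torch.log(prob)
q1 = torch.tensor([[1.0, 0.5, -0.2]])
q2 = torch.tensor([[0.8, 0.6, 0.0]])
min_q = torch.min(q1, q2)

pi_losses = torch.sum(prob * (alpha * log_pi - min_q), dim=1, keepdim=True)  # [B,1]
neg_entropy = torch.sum(prob * log_pi, dim=1)   # E[log pi] = -entropy, used in the alpha loss
print(pi_losses, neg_entropy)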