Example #1
    def loss(self, samples):
        """Samples have leading batch dimension [B,..] (but not time)."""
        qs = self.agent(*samples.agent_inputs)
        q = select_at_indexes(samples.action, qs)
        with torch.no_grad():
            target_qs = self.agent.target(*samples.target_inputs)
            if self.double_dqn:
                next_qs = self.agent(*samples.target_inputs)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
        disc_target_q = (self.discount**self.n_step_return) * target_q
        y = samples.return_ + (1 - samples.done_n.float()) * disc_target_q
        delta = y - q
        losses = 0.5 * delta**2
        abs_delta = abs(delta)
        if self.delta_clip is not None:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        if self.prioritized_replay:
            losses *= samples.is_weights
        td_abs_errors = torch.clamp(abs_delta.detach(), 0, self.delta_clip)
        if not self.mid_batch_reset:
            valid = valid_from_done(samples.done)
            loss = valid_mean(losses, valid)
            td_abs_errors *= valid
        else:
            loss = torch.mean(losses)

        return loss, td_abs_errors
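The piecewise branch above is the standard Huber loss with threshold ``delta_clip`` (and plain 0.5 * MSE when clipping is off). As a side check, a minimal sketch assuming a PyTorch version that provides ``F.huber_loss`` (>= 1.9); the tensors here are made up for illustration:

import torch
import torch.nn.functional as F

delta_clip = 1.0
y, q = torch.randn(8), torch.randn(8)
delta = y - q
abs_delta = delta.abs()
manual = torch.where(abs_delta <= delta_clip,
                     0.5 * delta ** 2,
                     delta_clip * (abs_delta - delta_clip / 2))
builtin = F.huber_loss(q, y, reduction="none", delta=delta_clip)
print(torch.allclose(manual, builtin))  # True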
Example #2
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's option and action distributions from inputs.
     Calls the model to get mean and std for all pi_w, plus q and beta for all options, and pi over options.
     Moves inputs to device and returns outputs back to CPU, for the
     sampler.  (no grad)
     """
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi = self.model(*model_inputs)
     dist_info_omega = DistInfo(prob=pi)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOC(dist_info=dist_info,
                              dist_info_o=dist_info_o,
                              q=q,
                              value=(pi * q).sum(-1),
                              termination=terminations,
                              dist_info_omega=dist_info_omega,
                              prev_o=self._prev_option,
                              o=new_o)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
Example #3
    def loss(self, samples):
        """
        Computes the Distributional Q-learning loss, based on projecting the
        discounted rewards + target Q-distribution into the current Q-domain,
        with cross-entropy loss.

        Returns loss and KL-divergence-errors for use in prioritization.
        """

        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount ** self.n_step_return)  # [P']
        next_z = torch.ger(1 - samples.done_n.float(), next_z)  # [B,P']
        ret = samples.return_.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(*samples.target_inputs)  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(*samples.target_inputs)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent(*samples.agent_inputs)  # [B,A,P]
        p = select_at_indexes(samples.action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= samples.is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
            (torch.log(target_p) - torch.log(p.detach())), dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            valid = valid_from_done(samples.done)
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)

        return loss, KL_div
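The projection above splits each atom of the shifted support ``next_z`` linearly between its two nearest atoms of the fixed support ``z``. A tiny standalone check of that behaviour (the numbers are illustrative only):

import torch

z = torch.tensor([0.0, 1.0, 2.0])   # fixed support, delta_z = 1
next_z = torch.tensor([[1.25]])     # one shifted atom, [B=1, P'=1]
z_bc = z.view(1, -1, 1)             # [1,P,1]
next_z_bc = next_z.unsqueeze(1)     # [B,1,P']
coeffs = torch.clamp(1 - (next_z_bc - z_bc).abs() / 1.0, 0, 1)
# Mass at 1.25 lands 75% on the atom at 1 and 25% on the atom at 2:
print(coeffs.squeeze())             # tensor([0.0000, 0.7500, 0.2500])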
Example #4
    def sample(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)

        cat_sample = torch.argmax(prob, dim=-1)
        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)

        if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
            cat_sample = cat_sample.unsqueeze(0)

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = select_at_indexes(cat_sample, mu)
            log_std = select_at_indexes(cat_sample, log_std)

            if len(prob.shape) == 1: # Edge case for when it gets buffer shapes
                mu, log_std = mu.squeeze(0), log_std.squeeze(0)

            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        if self.training:
            self.delta_distribution.set_std(None)
        else:
            self.delta_distribution.set_std(0)
        delta_sample = self.delta_distribution.sample(new_dist_info)
        return torch.cat((one_hot, delta_sample), dim=-1)
Example #5
    def rl_loss(self, latent, action, return_n, done_n, prev_action,
                prev_reward, next_state, next_prev_action, next_prev_reward,
                is_weights, done):

        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount**self.n_step_return)  # [P']
        next_z = torch.ger(1 - done_n.float(), next_z)  # [B,P']
        ret = return_n.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(next_state, next_prev_action,
                                          next_prev_reward)  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(next_state, next_prev_action,
                                     next_prev_reward)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent.head_forward(latent, prev_action,
                                     prev_reward)  # [B,A,P]
        p = select_at_indexes(action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
                           (torch.log(target_p) - torch.log(p.detach())),
                           dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            valid = valid_from_done(done[1])
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)

        return loss, KL_div
Example #6
    def loss(self, samples):
        """
        Computes the Q-learning loss, based on: 0.5 * (Q - target_Q) ^ 2.
        Implements regular DQN or Double-DQN for computing target_Q values
        using the agent's target network.  Computes the Huber loss using 
        ``delta_clip``, or if ``None``, uses MSE.  When using prioritized
        replay, multiplies losses by importance sample weights.

        Input ``samples`` have leading batch dimension [B,..] (but not time).

        Calls the agent to compute forward pass on training inputs, and calls
        ``agent.target()`` to compute target values.

        Returns loss and TD-absolute-errors for use in prioritization.

        Warning: 
            If not using mid_batch_reset, the sampler will only reset environments
            between iterations, so some samples in the replay buffer will be
            invalid.  This case is not supported here currently.
        """
        qs = self.agent(*samples.agent_inputs)
        q = select_at_indexes(samples.action, qs)
        with torch.no_grad():
            target_qs = self.agent.target(*samples.target_inputs)
            if self.double_dqn:
                next_qs = self.agent(*samples.target_inputs)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
        disc_target_q = (self.discount**self.n_step_return) * target_q
        y = samples.return_ + (1 - samples.done_n.float()) * disc_target_q
        delta = y - q
        losses = 0.5 * delta**2
        abs_delta = abs(delta)
        if self.delta_clip is not None:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        if self.prioritized_replay:
            losses *= samples.is_weights
        td_abs_errors = abs_delta.detach()
        if self.delta_clip is not None:
            td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip)
        if not self.mid_batch_reset:
            # FIXME: I think this is wrong, because the first "done" sample
            # is valid, but here there is no [T] dim, so there's no way to
            # know if a "done" sample is the first "done" in the sequence.
            raise NotImplementedError
            # valid = valid_from_done(samples.done)
            # loss = valid_mean(losses, valid)
            # td_abs_errors *= valid
        else:
            loss = torch.mean(losses)

        return loss, td_abs_errors
Example #7
 def value(self, observation, prev_action, prev_reward, device="cpu"):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     _pi, beta, q, pi_omega = self.model(*model_inputs)
     v = (q * pi_omega).sum(
         -1
     )  # Weight q value by probability of option. Average value if terminal
     q_prev_o = select_at_indexes(self.prev_option, q)
     beta_prev_o = select_at_indexes(self.prev_option, beta)
     value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o
     return value.to(device)
Example #8
    def dqn_rl_loss(self, qs, samples, index):
        """
        Computes the Q-learning loss, based on: 0.5 * (Q - target_Q) ^ 2.
        Implements regular DQN or Double-DQN for computing target_Q values
        using the agent's target network.  Computes the Huber loss using
        ``delta_clip``, or if ``None``, uses MSE.  When using prioritized
        replay, multiplies losses by importance sample weights.

        Input ``samples`` have leading batch dimension [B,..] (but not time).

        Calls the agent to compute forward pass on training inputs, and calls
        ``agent.target()`` to compute target values.

        Returns loss and TD-absolute-errors for use in prioritization.

        Warning:
            If not using mid_batch_reset, the sampler will only reset environments
            between iterations, so some samples in the replay buffer will be
            invalid.  This case is not supported here currently.
        """
        q = select_at_indexes(samples.all_action[index + 1], qs).cpu()
        with torch.no_grad():
            target_qs = self.agent.target(
                samples.all_observation[index + self.n_step_return],
                samples.all_action[index + self.n_step_return],
                samples.all_reward[index + self.n_step_return])  # [B,A,P']
            if self.double_dqn:
                next_qs = self.agent(
                    samples.all_observation[index + self.n_step_return],
                    samples.all_action[index + self.n_step_return],
                    samples.all_reward[index + self.n_step_return])  # [B,A,P']
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values

            disc_target_q = (self.discount**self.n_step_return) * target_q
            y = samples.return_[index] + (
                1 - samples.done_n[index].float()) * disc_target_q

        delta = y - q
        losses = 0.5 * delta**2
        abs_delta = abs(delta)
        if self.delta_clip > 0:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        td_abs_errors = abs_delta.detach()
        if self.delta_clip > 0:
            td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip)
        return losses, td_abs_errors
Example #9
    def dist_rl_loss(self, log_pred_ps, samples, index):
        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount**self.n_step_return)  # [P']
        next_z = torch.ger(1 - samples.done_n[index].float(), next_z)  # [B,P']
        ret = samples.return_[index].unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(
                samples.all_observation[index + self.n_step_return],
                samples.all_action[index + self.n_step_return],
                samples.all_reward[index + self.n_step_return])  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(
                    samples.all_observation[index + self.n_step_return],
                    samples.all_action[index + self.n_step_return],
                    samples.all_reward[index + self.n_step_return])  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        p = select_at_indexes(samples.all_action[index + 1].squeeze(-1),
                              log_pred_ps.cpu())  # [B,P]
        # p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * p, dim=1)  # Cross-entropy.

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p * (torch.log(target_p) - p.detach()),
                           dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        return losses, KL_div.detach()
Example #10
    def sample_loglikelihood(self, dist_info):
        logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist

        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)

        cat_sample = torch.argmax(prob, dim=-1)
        cat_loglikelihood = select_at_indexes(cat_sample, prob)

        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
        one_hot = (one_hot - prob).detach() + prob # Make action differentiable through prob

        if self._all_corners:
            mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
            mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
            mu = mu[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            log_std = log_std[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
            new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
        else:
            new_dist_info = delta_dist_info

        delta_sample, delta_loglikelihood = self.delta_distribution.sample_loglikelihood(new_dist_info)
        action = torch.cat((one_hot, delta_sample), dim=-1)
        log_likelihood = cat_loglikelihood + delta_loglikelihood
        return action, log_likelihood
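The line ``(one_hot - prob).detach() + prob`` is a straight-through estimator: the forward value is the hard one-hot sample, while gradients flow back through the (Gumbel-perturbed) softmax. A minimal check of that behaviour, with names local to this sketch:

import torch
import torch.nn.functional as F

logits = torch.tensor([[0.1, 2.0, -1.0]], requires_grad=True)
prob = F.softmax(logits, dim=-1)
hard = F.one_hot(prob.argmax(-1), num_classes=3).float()
action = (hard - prob).detach() + prob     # value == hard, grad path == prob
(action * torch.tensor([1.0, 2.0, 3.0])).sum().backward()
print(action)       # the hard one-hot sample
print(logits.grad)  # non-zero: gradient reached the logits through softmax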
Example #11
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              device="cpu"):
     """Performs forward pass on training data, for algorithm. Returns sampled distinfo, q, beta, and piomega distinfo"""
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, sampled_option),
         device=self.device)
     mu, log_std, beta, q, pi = self.model(*model_inputs[:-1])
     # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo)
     return buffer_to(
         (DistInfoStd(mean=select_at_indexes(sampled_option, mu),
                      log_std=select_at_indexes(sampled_option, log_std)),
          q, beta, DistInfo(prob=pi)),
         device=device)
Example #12
    def compute_true_delta(self, samples):
        """
        Helper method with no training purpose. Its only purpose is to compute
        the "true" return as samples come in, compute the current Q estimate,
        and report the difference (i.e. for evaluation and logging only).

        NOTE: if multiple trajectories are collected in a single sample,
              only the first trajectory will be used.
        :param samples: samples from environment sampler
        :return: tensor of delta between true G and predicted Q and target Q
                 of shape (T, 1)  (T being the length of valid traj)
        """

        # Extract information to estimate Q
        all_observation, all_action, all_reward = buffer_to(
            (samples.env.observation.clone().detach(),
             samples.agent.prev_action.clone().detach(),
             samples.env.prev_reward.clone().detach()),
            device=self.agent.device)

        action = samples.agent.prev_action[1:self.batch_T + 1]
        return_ = samples.env.reward[0:self.batch_T]
        done_n = samples.env.done[0:self.batch_T]

        # Get the behaviour Qs and target max q
        input_buffer = (all_observation, all_action, all_reward)
        with torch.no_grad():
            qs, target_q = self.compute_q_predictions(input_buffer)
            q = select_at_indexes(action, qs)

        # Valid length
        valid = valid_from_done(done_n)
        valid_T = int(torch.sum(valid))

        # lambda target
        lambda_G = self.compute_lambda_return(return_, target_q,
                                              valid)  # (T, 1)

        # ==
        # Compute true return (highly specific to the delayed_actions.py env)
        # NOTE: this is built specifically for the action independent, pure
        #       prediction variant of the delayed_actions.py env
        arm_num = int(samples.env.env_info.arm_num[(valid_T - 1)])
        true_R = 1.0 if (arm_num == 1) else -1.0

        true_G = torch.zeros((valid_T, 1))
        true_G[-1] = true_R
        for i in reversed(range(valid_T - 1)):
            true_G[i] = self.discount * true_G[i + 1]
        true_G[0] = 0.0  # first state has expected 0

        # ==
        # Compute delta to true value
        predic_true_delta = true_G - q[:valid_T]
        target_true_delta = true_G - lambda_G[:valid_T]

        return predic_true_delta, target_true_delta
Example #13
    def value(self, observation, prev_action, prev_reward, device="cpu"):
        """
        Compute the value estimate for the environment state, e.g. for the
        bootstrap value, V(s_{T+1}), in the sampler.

        For option-critic algorithms, this is the q(s_{T+1}, prev_o) * (1-beta(s_{T+1}, prev_o)) +
        beta(s_{T+1}, prev_o) * sum_{o} pi_omega(o|s_{T+1}) * q(s_{T+1}, o)
        (no grad)
        """
        model_inputs = buffer_to((observation, prev_action, prev_reward),
                                 device=self.device)
        _mu, _log_std, beta, q, pi = self.model(*model_inputs)  # [B, nOpt]
        v = (q * pi).sum(
            -1
        )  # Weight q value by probability of option. Average value if terminal
        q_prev_o = select_at_indexes(self.prev_option, q)
        beta_prev_o = select_at_indexes(self.prev_option, beta)
        value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o
        return value.to(device)
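For intuition, a toy evaluation of the bootstrap value formula from the docstring, with 2 options and previous option index 0 (all numbers are illustrative only):

import torch

q = torch.tensor([[1.0, 3.0]])             # [B=1, nOpt=2]
pi_omega = torch.tensor([[0.25, 0.75]])    # [B, nOpt]
beta = torch.tensor([[0.4, 0.9]])          # termination prob per option
prev_o = torch.tensor([0])
v = (q * pi_omega).sum(-1)                                     # 2.5
q_prev = q.gather(-1, prev_o.unsqueeze(-1)).squeeze(-1)        # 1.0
beta_prev = beta.gather(-1, prev_o.unsqueeze(-1)).squeeze(-1)  # 0.4
value = q_prev * (1 - beta_prev) + v * beta_prev
print(value)  # tensor([1.6000]) = 1.0 * 0.6 + 2.5 * 0.4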
Example #14
 def value(self, observation, prev_action, prev_reward, device="cpu"):
     prev_option_input = self._prev_option
     if prev_option_input is None:  # Hack to extract previous option
         prev_option_input = torch.full_like(prev_action, -1)
     prev_action = self.distribution.to_onehot(prev_action)
     prev_option_input = self.distribution_omega.to_onehot_with_invalid(
         prev_option_input)
     agent_inputs = buffer_to(
         (observation, prev_action, prev_reward, prev_option_input),
         device=self.device)
     _pi, beta, q, pi_omega, _rnn_state = self.model(
         *agent_inputs, self.prev_rnn_state)
     v = (q * pi_omega).sum(
         -1
     )  # Weight q value by probability of option. Average value if terminal
     q_prev_o = select_at_indexes(self.prev_option, q)
     beta_prev_o = select_at_indexes(self.prev_option, beta)
     value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o
     return value.to(device)
Example #15
    def compute_q_predictions(self, input_buffer):
        """
        Compute the behaviour and target network Q predictions
        Note this is a separate method since I re-use the method during
        training and also to evaluate progress on newly sampled trajectories

        :param input_buffer: observations, actions and reward of a trajectory
        :return: behaviour qs (size [T, B, A]) and target_q (size [T, B])
        """

        # Unpack the RNN input buffer
        all_observation, all_action, all_reward = input_buffer

        # all_action = torch.zeros(all_action.size())
        # all_reward = torch.zeros(all_reward.size())  # TODO make this a feature in future?

        # ==
        # Compute Q estimates (NOTE: no RNN warm-up)
        agent_slice = slice(0, self.batch_T)
        agent_inputs = AgentInputs(
            observation=all_observation[agent_slice].clone().detach(),
            prev_action=all_action[agent_slice].clone().detach(),
            prev_reward=all_reward[agent_slice].clone().detach(),
        )
        target_slice = slice(0, None)  # Same start t as agent. (0 + bT + nsr)
        target_inputs = AgentInputs(
            observation=all_observation[target_slice],
            prev_action=all_action[target_slice],
            prev_reward=all_reward[target_slice],
        )

        # NOTE: always initialize to None; assume to always have full traj
        # For how to sample rnn intermediate state from mid-run, see
        # https://github.com/astooke/rlpyt/blob/f04f23db1eb7b5915d88401fca67869968a07a37
        # /rlpyt/algos/dqn/r2d1.py#L280
        init_rnn_state = None
        target_rnn_state = None  # NOTE: no RNN warmup for target

        # Behavioural net Q estimate
        qs, _ = self.agent(*agent_inputs, init_rnn_state)  # [T,B,A]

        # Target network Q estimates
        with torch.no_grad():
            target_qs, _ = self.agent.target(*target_inputs, target_rnn_state)
            if self.double_dqn:
                next_qs, _ = self.agent(*target_inputs, init_rnn_state)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
            target_q = target_q[-self.batch_T:]  # Same length as q.

        return qs, target_q
Example #16
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              init_rnn_state,
              device="cpu"):
     """Performs forward pass on training data, for algorithm (requires
     recurrent state input). Returns sampled distinfo, q, beta, and piomega distinfo"""
     # Assume init_rnn_state already shaped: [N,B,H]
     model_inputs = buffer_to((observation, prev_action, prev_reward,
                               init_rnn_state, sampled_option),
                              device=self.device)
     mu, log_std, beta, q, pi, next_rnn_state = self.model(
         *model_inputs[:-1])
     # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo)
     dist_info, q, beta, dist_info_omega = buffer_to(
         (DistInfoStd(mean=select_at_indexes(sampled_option, mu),
                      log_std=select_at_indexes(sampled_option, log_std)),
          q, beta, DistInfo(prob=pi)),
         device=device)
     return dist_info, q, beta, dist_info_omega, next_rnn_state  # Leave rnn_state on device.
Example #17
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     """
     Compute policy's action distribution from inputs, and sample an
     action. Calls the model to produce mean, log_std, value estimate, and
     next recurrent state.  Moves inputs to device and returns outputs back
     to CPU, for the sampler.  Advances the recurrent state of the agent.
     (no grad)
     """
     agent_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     mu, log_std, beta, q, pi, rnn_state = self.model(
         *agent_inputs, self.prev_rnn_state)
     terminations = torch.bernoulli(beta).bool()  # Sample terminations
     dist_info_omega = DistInfo(prob=pi)
     new_o = self.sample_option(terminations, dist_info_omega)
     dist_info = DistInfoStd(mean=mu, log_std=log_std)
     dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu),
                               log_std=select_at_indexes(new_o, log_std))
     action = self.distribution.sample(dist_info_o)
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi * q).sum(-1),
                                 termination=terminations,
                                 inter_option_dist_info=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_rnn_state(rnn_state)  # Keep on device.
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
Example #18
 def __call__(self,
              observation,
              prev_action,
              prev_reward,
              sampled_option,
              device="cpu"):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, sampled_option),
         device=self.device)
     pi, beta, q, pi_omega = self.model(*model_inputs[:-1])
     return buffer_to(
         (DistInfo(prob=select_at_indexes(sampled_option, pi)), q, beta,
          DistInfo(prob=pi_omega)),
         device=device)
Example #19
    def select_at_indexes(self, indexes, tensor):
        """Returns the `tensor` data at the multi-dimensional integer array `indexes`.

        Parameters
        ----------
        indexes: tensor
            a tensor of indexes.
        tensor: tensor
            a tensor from which to retrieve the data of interest.

        Return
        ----------
        result: tensor
            the resulting data.
        """
        return select_at_indexes(indexes, tensor)
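For intuition, the common [B, A] case of ``select_at_indexes`` (e.g. picking Q(s, a) for a batch of actions) can be reproduced with a gather over the last dimension. This is a sketch of the 2-D case only, not the actual rlpyt helper, which also handles extra leading and trailing dimensions:

import torch

def select_at_indexes_2d(indexes, tensor):
    # tensor: [B, A], indexes: [B] -> result: [B]
    return tensor.gather(-1, indexes.long().unsqueeze(-1)).squeeze(-1)

qs = torch.arange(6.0).view(2, 3)        # [[0., 1., 2.], [3., 4., 5.]]
action = torch.tensor([2, 0])
print(select_at_indexes_2d(action, qs))  # tensor([2., 3.])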
Example #20
 def compute_input_priorities(self, samples):
     """Just for first input into replay buffer.
     Simple 1-step return TD-errors using recorded Q-values from online
     network and value scaling, with the T dimension reduced away (same
     priority applied to all samples in this batch; wherever the rnn state
     is kept--hopefully the first step--this priority will apply there).
     The samples duration T might be less than the training segment, so
     this is an approximation of an approximation, but hopefully will
     capture the right behavior.
     UPDATE 20190826: Trying using n-step returns.  For now using samples
     with full n-step return available...later could also use partial
     returns for samples at end of batch.  35/40 ain't bad tho.
     Might not carry/use internal state here, because might get executed
     by alternating memory copiers in async mode; do all with only the 
     samples available from input."""
     samples = torchify_buffer(samples)
     q = samples.agent.agent_info.q
     action = samples.agent.action
     q_max = torch.max(q, dim=-1).values
     q_at_a = select_at_indexes(action, q)
     return_n, done_n = discount_return_n_step(
         reward=samples.env.reward,
         done=samples.env.done,
         n_step=self.n_step_return,
         discount=self.discount,
         do_truncated=False,  # Only samples with full n-step return.
     )
     # y = self.value_scale(
     #     samples.env.reward[:-1] +
     #     (self.discount * (1 - samples.env.done[:-1].float()) *  # probably done.float()
     #         self.inv_value_scale(q_max[1:]))
     # )
     nm1 = max(1,
               self.n_step_return - 1)  # At least 1 bc don't have next Q.
     y = self.value_scale(return_n + (1 - done_n.float()) *
                          self.inv_value_scale(q_max[nm1:]))
     delta = abs(q_at_a[:-nm1] - y)
     # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
     if self.delta_clip is not None:  # Huber loss.
         delta = torch.clamp(delta, 0, self.delta_clip)
     valid = valid_from_done(samples.env.done[:-nm1])
     max_d = torch.max(delta * valid, dim=0).values
     mean_d = valid_mean(delta, valid, dim=0)  # Still high if less valid.
     priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]
     return priorities.numpy()
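``self.value_scale`` / ``self.inv_value_scale`` are not defined in this snippet; presumably they are the invertible value rescaling used by R2D2-style agents. A sketch under that assumption (``eps`` and the function names here are illustrative):

import torch

def value_scale(x, eps=1e-3):
    # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x
    return torch.sign(x) * (torch.sqrt(x.abs() + 1.) - 1.) + eps * x

def inv_value_scale(x, eps=1e-3):
    # Exact inverse of h, solving the quadratic in sqrt(|value| + 1).
    return torch.sign(x) * (
        ((torch.sqrt(1. + 4. * eps * (x.abs() + 1. + eps)) - 1.) / (2. * eps)) ** 2 - 1.)

x = torch.tensor([-10.0, 0.0, 10.0])
print(inv_value_scale(value_scale(x)))  # ~tensor([-10., 0., 10.])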
Example #21
    def sample_loglikelihood(self, dist_info):
        if isinstance(dist_info, DistInfoStd):
            action, log_likelihood = self.delta_distribution.sample_loglikelihood(dist_info)
        else:
            logits = dist_info

            u = torch.rand_like(logits)
            u = torch.clamp(u, 1e-5, 1 - 1e-5)
            gumbel = -torch.log(-torch.log(u))
            prob = F.softmax((logits + gumbel) / 10, dim=-1)

            cat_sample = torch.argmax(prob, dim=-1)
            log_likelihood = select_at_indexes(cat_sample, prob)

            one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
            action = (one_hot - prob).detach() + prob  # Make action differentiable through prob

        return action, log_likelihood
Example #22
 def sample_option(self, betas, option_dist_info):
     """Sample options according to which previous options are terminated and probability over options"""
     if self._prev_option is None:  # No previous option, store as -1
         self._prev_option = torch.full(betas.size()[:-1],
                                        -1,
                                        dtype=torch.long,
                                        device=betas.device)
     terminations = select_at_indexes(self._prev_option,
                                      torch.bernoulli(betas).bool())
     options = self._prev_option.clone()
     new_o = self.distribution_omega.sample(option_dist_info).expand_as(
         self._prev_option)
     options[self._prev_option == -1] = new_o[
         self._prev_option == -1]  # Must terminate, episode reset
     mask = self._prev_option != -1
     options[mask] = torch.where(
         terminations.view(-1)[mask].flatten(), new_o[mask],
         self._prev_option[mask])
     return options, terminations
Example #23
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     prev_action = self.distribution.to_onehot(prev_action)
     model_inputs = buffer_to((observation, prev_action, prev_reward),
                              device=self.device)
     pi, beta, q, pi_omega = self.model(*model_inputs)
     dist_info_omega = DistInfo(prob=pi_omega)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     dist_info = DistInfo(prob=pi)
     dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOC(dist_info=dist_info,
                              dist_info_o=dist_info_o,
                              q=q,
                              value=(pi_omega * q).sum(-1),
                              termination=terminations,
                              dist_info_omega=dist_info_omega,
                              prev_o=self._prev_option,
                              o=new_o)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     return AgentStep(action=action, agent_info=agent_info)
Example #24
 def step(self, observation, prev_action, prev_reward, device="cpu"):
     prev_option_input = self._prev_option
     if prev_option_input is None:  # Hack to extract previous option
         prev_option_input = torch.full_like(prev_action, -1)
     prev_action = self.distribution.to_onehot(prev_action)
     prev_option_input = self.distribution_omega.to_onehot_with_invalid(
         prev_option_input)
     model_inputs = buffer_to(
         (observation, prev_action, prev_reward, prev_option_input),
         device=self.device)
     pi, beta, q, pi_omega, rnn_state = self.model(*model_inputs,
                                                   self.prev_rnn_state)
     dist_info_omega = DistInfo(prob=pi_omega)
     new_o, terminations = self.sample_option(
         beta, dist_info_omega)  # Sample terminations and options
     # Model handles None, but Buffer does not, make zeros if needed:
     prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func(
         rnn_state, torch.zeros_like)
     # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
     # (Special case: model should always leave B dimension in.)
     prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
     dist_info = DistInfo(prob=pi)
     dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi))
     action = self.distribution.sample(dist_info_o)
     agent_info = AgentInfoOCRnn(dist_info=dist_info,
                                 dist_info_o=dist_info_o,
                                 q=q,
                                 value=(pi_omega * q).sum(-1),
                                 termination=terminations,
                                 dist_info_omega=dist_info_omega,
                                 prev_o=self._prev_option,
                                 o=new_o,
                                 prev_rnn_state=prev_rnn_state)
     action, agent_info = buffer_to((action, agent_info), device=device)
     self.advance_oc_state(new_o)
     self.advance_rnn_state(rnn_state)
     return AgentStep(action=action, agent_info=agent_info)
Example #25
 def likelihood_ratio(self, indexes, old_dist_info, new_dist_info):
     num = select_at_indexes(indexes, new_dist_info.prob)
     den = select_at_indexes(indexes, old_dist_info.prob)
     return (num + EPS) / (den + EPS)
Example #26
 def log_likelihood(self, indexes, dist_info):
     selected_likelihood = select_at_indexes(indexes, dist_info.prob)
     return torch.log(selected_likelihood + EPS)
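A hedged usage sketch (not part of the source): with categorical probabilities of shape [B, A] and integer actions of shape [B], the two helpers above give per-sample log pi(a|s) and the importance ratio, e.g. for a PPO-style clipped surrogate. All tensors and the clip range below are made up for illustration:

import torch

EPS = 1e-8
old_prob = torch.softmax(torch.randn(4, 3), dim=-1)  # behaviour policy probs
new_prob = torch.softmax(torch.randn(4, 3), dim=-1)  # current policy probs
actions = torch.tensor([0, 2, 1, 1])
num = new_prob.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
den = old_prob.gather(-1, actions.unsqueeze(-1)).squeeze(-1)
log_likelihood = torch.log(num + EPS)                # as in log_likelihood()
ratio = (num + EPS) / (den + EPS)                    # as in likelihood_ratio()
advantage = torch.randn(4)
surrogate = torch.min(ratio * advantage,
                      ratio.clamp(0.9, 1.1) * advantage).mean()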
Example #27
    def loss(self, samples):
        """Samples have leading Time and Batch dimentions [T,B,..]. Move all
        samples to device first, and then slice for sub-sequences.  Use same
        init_rnn_state for agent and target; start both at same t.  Warmup the
        RNN state first on the warmup subsequence, then train on the remaining
        subsequence.

        Returns loss (usually use MSE, not Huber), TD-error absolute values,
        and new sequence-wise priorities, based on weighted sum of max and mean
        TD-error over the sequence."""
        all_observation, all_action, all_reward = buffer_to(
            (samples.all_observation, samples.all_action, samples.all_reward),
            device=self.agent.device)
        wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return
        if wT > 0:
            warmup_slice = slice(None, wT)  # Same for agent and target.
            warmup_inputs = AgentInputs(
                observation=all_observation[warmup_slice],
                prev_action=all_action[warmup_slice],
                prev_reward=all_reward[warmup_slice],
            )
        agent_slice = slice(wT, wT + bT)
        agent_inputs = AgentInputs(
            observation=all_observation[agent_slice],
            prev_action=all_action[agent_slice],
            prev_reward=all_reward[agent_slice],
        )
        target_slice = slice(wT, None)  # Same start t as agent. (wT + bT + nsr)
        target_inputs = AgentInputs(
            observation=all_observation[target_slice],
            prev_action=all_action[target_slice],
            prev_reward=all_reward[target_slice],
        )
        action = samples.all_action[wT + 1:wT + 1 + bT]  # CPU.
        return_ = samples.return_[wT:wT + bT]
        done_n = samples.done_n[wT:wT + bT]
        if self.store_rnn_state_interval == 0:
            init_rnn_state = None
        else:
            # [B,N,H]-->[N,B,H] cudnn.
            init_rnn_state = buffer_method(samples.init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        if wT > 0:  # Do warmup.
            with torch.no_grad():
                _, target_rnn_state = self.agent.target(*warmup_inputs, init_rnn_state)
                _, init_rnn_state = self.agent(*warmup_inputs, init_rnn_state)
            # Recommend aligning sampling batch_T and store_rnn_interval with
            # warmup_T (and no mid_batch_reset), so that end of trajectory
            # during warmup leads to new trajectory beginning at start of
            # training segment of replay.
            warmup_invalid_mask = valid_from_done(samples.done[:wT])[-1] == 0  # [B]
            init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
            target_rnn_state[:, warmup_invalid_mask] = 0
        else:
            target_rnn_state = init_rnn_state

        qs, _ = self.agent(*agent_inputs, init_rnn_state)  # [T,B,A]
        q = select_at_indexes(action, qs)
        with torch.no_grad():
            target_qs, _ = self.agent.target(*target_inputs, target_rnn_state)
            if self.double_dqn:
                next_qs, _ = self.agent(*target_inputs, init_rnn_state)
                next_a = torch.argmax(next_qs, dim=-1)
                target_q = select_at_indexes(next_a, target_qs)
            else:
                target_q = torch.max(target_qs, dim=-1).values
            target_q = target_q[-bT:]  # Same length as q.

        disc = self.discount ** self.n_step_return
        y = self.value_scale(return_ + (1 - done_n.float()) * disc *
            self.inv_value_scale(target_q))  # [T,B]
        delta = y - q
        losses = 0.5 * delta ** 2
        abs_delta = abs(delta)
        # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
        if self.delta_clip is not None:  # Huber loss.
            b = self.delta_clip * (abs_delta - self.delta_clip / 2)
            losses = torch.where(abs_delta <= self.delta_clip, losses, b)
        if self.prioritized_replay:
            losses *= samples.is_weights.unsqueeze(0)  # weights: [B] --> [1,B]
        valid = valid_from_done(samples.done[wT:])  # 0 after first done.
        loss = valid_mean(losses, valid)
        td_abs_errors = abs_delta.detach()
        if self.delta_clip is not None:
            td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip)  # [T,B]
        valid_td_abs_errors = td_abs_errors * valid
        max_d = torch.max(valid_td_abs_errors, dim=0).values
        mean_d = valid_mean(td_abs_errors, valid, dim=0)  # Still high if less valid.
        priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]

        return loss, valid_td_abs_errors, priorities
Example #28
    def compute_input_priorities(self, samples):
        """Used when putting new samples into the replay buffer.  Computes
        n-step TD-errors using recorded Q-values from online network and
        value scaling.  Weights the max and the mean TD-error over each sequence
        to make a single priority value for that sequence.  

        Note:
            Although the original R2D2 implementation used the entire
            80-step sequence to compute the input priorities, we ran R2D1 with 40
            time-step sample batches, and so computed the priority for each
            80-step training sequence based on one of the two 40-step halves.
            Algorithm argument ``input_priority_shift`` determines which 40-step
            half is used as the priority for the 80-step sequence.  (Since this 
            method might get executed by alternating memory copiers in async mode,
            don't carry internal state here, do all computation with only the samples
            available in input.  Could probably reduce to one memory copier and keep
            state there, if needed.)
        """

        # """Just for first input into replay buffer.
        # Simple 1-step return TD-errors using recorded Q-values from online
        # network and value scaling, with the T dimension reduced away (same
        # priority applied to all samples in this batch; wherever the rnn state
        # is kept--hopefully the first step--this priority will apply there).
        # The samples duration T might be less than the training segment, so
        # this is an approximation of an approximation, but hopefully will
        # capture the right behavior.
        # UPDATE 20190826: Trying using n-step returns.  For now using samples
        # with full n-step return available...later could also use partial
        # returns for samples at end of batch.  35/40 ain't bad tho.
        # Might not carry/use internal state here, because might get executed
        # by alternating memory copiers in async mode; do all with only the
        # samples available from input."""
        samples = torchify_buffer(samples)
        q = samples.agent.agent_info.q
        action = samples.agent.action
        q_max = torch.max(q, dim=-1).values
        q_at_a = select_at_indexes(action, q)
        return_n, done_n = discount_return_n_step(
            reward=samples.env.reward,
            done=samples.env.done,
            n_step=self.n_step_return,
            discount=self.discount,
            do_truncated=False,  # Only samples with full n-step return.
        )
        # y = self.value_scale(
        #     samples.env.reward[:-1] +
        #     (self.discount * (1 - samples.env.done[:-1].float()) *  # probably done.float()
        #         self.inv_value_scale(q_max[1:]))
        # )
        nm1 = max(1, self.n_step_return - 1)  # At least 1 bc don't have next Q.
        y = self.value_scale(return_n +
            (1 - done_n.float()) * self.inv_value_scale(q_max[nm1:]))
        delta = abs(q_at_a[:-nm1] - y)
        # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
        if self.delta_clip is not None:  # Huber loss.
            delta = torch.clamp(delta, 0, self.delta_clip)
        valid = valid_from_done(samples.env.done[:-nm1])
        max_d = torch.max(delta * valid, dim=0).values
        mean_d = valid_mean(delta, valid, dim=0)  # Still high if less valid.
        priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]
        return priorities.numpy()
Example #29
    def loss(self, samples):
        """
        Computes the training loss: policy_loss + value_loss + entropy_loss.
        Policy loss: log-likelihood of actions * advantages
        Value loss: 0.5 * (estimated_value - return) ^ 2
        Organizes agent inputs from training samples, calls the agent instance
        to run forward pass on training data, and uses the
        ``agent.distribution`` to compute likelihoods and entropies.  Valid
        for feedforward or recurrent agents.
        """
        agent_inputs = AgentInputsOC(  # Move inputs to device once, index there.
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
            sampled_option=samples.agent.agent_info.o,
        )
        if self.agent.recurrent:
            init_rnn_state = samples.agent.agent_info.prev_rnn_state[
                0]  # T = 0.
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
            init_rnn_state = buffer_method(init_rnn_state, "contiguous")
            po = samples.agent.agent_info.prev_o
            (dist_info_o, q, beta, dist_info_omega), _rnn_state = self.agent(
                *agent_inputs,
                po,
                init_rnn_state,
                device=agent_inputs.prev_action.device)
        else:
            dist_info_o, q, beta, dist_info_omega = self.agent(
                *agent_inputs, device=agent_inputs.prev_action.device)
        dist = self.agent.distribution
        dist_omega = self.agent.distribution_omega
        # TODO: try to compute everything on device.
        return_, advantage, valid, beta_adv, not_init_states, op_adv = self.process_returns(
            samples)

        logli = dist.log_likelihood(samples.agent.action, dist_info_o)
        pi_loss = -valid_mean(logli * advantage, valid)

        o = samples.agent.agent_info.o
        q_o = select_at_indexes(o, q)
        value_error = 0.5 * (q_o - return_)**2
        value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

        # Termination loss
        prev_o = samples.agent.agent_info.prev_o
        beta_prev_o = select_at_indexes(prev_o, beta)
        beta_error = beta_prev_o * beta_adv
        beta_loss = self.termination_loss_coeff * valid_mean(
            beta_error, not_init_states)

        logli = dist_omega.log_likelihood(o, dist_info_omega)
        # pi_omega_loss = - valid_mean(logli * advantage, valid)
        pi_omega_loss = -valid_mean(logli * op_adv, valid)

        entropy = dist.mean_entropy(dist_info_o, valid)
        entropy_loss = -self.entropy_loss_coeff * entropy
        entropy_o = dist_omega.mean_entropy(dist_info_omega, valid)
        entropy_loss_omega = -self.omega_entropy_loss_coeff * entropy_o

        loss = pi_loss + pi_omega_loss + beta_loss + value_loss + entropy_loss + entropy_loss_omega

        # perplexity = dist.mean_perplexity(dist_info_o, valid)

        return loss, pi_loss, value_loss, beta_loss, pi_omega_loss, entropy, entropy_o
Example #30
    def loss(self, samples, itr, samples_nce):
        """Samples have leading batch dimension [B,..] (but not time)."""
        self.args['device'] = self.agent.device
        """
        Get rlpyt batch inputs and write them to GPU tensors
        """
        rl_agent_inputs = AgentInputs(
            observation=samples.agent_inputs.observation,
            prev_action=samples.agent_inputs.prev_action,
            prev_reward=None)
        rl_action = samples.action
        rl_return_ = samples.return_
        rl_target_inputs = AgentInputs(
            observation=samples.target_inputs.observation,
            prev_action=samples.target_inputs.prev_action,
            prev_reward=None)
        rl_done = samples.done
        rl_done_n = samples.done_n

        self.states[
            (self.nce_counter *
             self.args['batch_size']):(self.nce_counter + 1) * self.
            args['batch_size']] = samples_nce.agent_inputs.observation.type(
                torch.float32).to(self.args['device']) / 255.
        self.actions[(self.nce_counter *
                      self.args['batch_size']):(self.nce_counter + 1) *
                     self.args['batch_size']] = samples_nce.action.type(
                         torch.int64).to(self.args['device'])
        self.returns[(self.nce_counter *
                      self.args['batch_size']):(self.nce_counter + 1) *
                     self.args['batch_size']] = samples_nce.return_.type(
                         torch.float32).to(self.args['device'])
        self.next_states[
            (self.nce_counter *
             self.args['batch_size']):(self.nce_counter + 1) * self.
            args['batch_size']] = samples_nce.target_inputs.observation.type(
                torch.float32).to(self.args['device']) / 255.
        self.nonterminals[(self.nce_counter *
                           self.args['batch_size']):(self.nce_counter + 1) *
                          self.args['batch_size']] = samples_nce.done
        if self.prioritized_replay:
            rl_is_weights = samples.is_weights
            self.weights[(self.nce_counter *
                          self.args['batch_size']):(self.nce_counter + 1) *
                         self.args['batch_size']] = samples_nce.is_weights

        self.nce_counter += 1
        """
        C51 code from rlpyt (unchanged)
        """

        delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
        z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
        # Make 2-D tensor of contracted z_domain for each data point,
        # with zeros where next value should not be added.
        next_z = z * (self.discount**self.n_step_return)  # [P']
        next_z = torch.ger(1 - rl_done_n.float(), next_z)  # [B,P']
        ret = rl_return_.unsqueeze(1)  # [B,1]
        next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']

        z_bc = z.view(1, -1, 1)  # [1,P,1]
        next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
        abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
        projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
        # projection_coeffs is a 3-D tensor: [B,P,P']
        # dim-0: independent data entries
        # dim-1: base_z atoms (remains after projection)
        # dim-2: next_z atoms (summed in projection)

        with torch.no_grad():
            target_ps = self.agent.target(*rl_target_inputs)  # [B,A,P']
            if self.double_dqn:
                next_ps = self.agent(*rl_target_inputs)  # [B,A,P']
                next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(next_qs, dim=-1)  # [B]
            else:
                target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
                next_a = torch.argmax(target_qs, dim=-1)  # [B]
            target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
            target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
            target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
        ps = self.agent(*rl_agent_inputs)  # [B,A,P]
        p = select_at_indexes(rl_action, ps)  # [B,P]
        p = torch.clamp(p, EPS, 1)  # NaN-guard.
        losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.

        if self.prioritized_replay:
            losses *= rl_is_weights

        target_p = torch.clamp(target_p, EPS, 1)
        KL_div = torch.sum(target_p *
                           (torch.log(target_p) - torch.log(p.detach())),
                           dim=1)
        KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.

        if not self.mid_batch_reset:
            valid = valid_from_done(rl_done)
            loss = valid_mean(losses, valid)
            KL_div *= valid
        else:
            loss = torch.mean(losses)
        # else:
        #     KL_div = torch.tensor([0.]).cpu()
        #     loss = torch.tensor([0.]).to(self.args['device'])
        """
        NCE loss
        """
        loss_device = loss.get_device()
        if self.args['lambda_LL'] != 0 or self.args[
                'lambda_LG'] != 0 or self.args['lambda_GL'] != 0 or self.args[
                    'lambda_GG'] != 0:
            """
            Compute this only if one of the 4 lambdas != 0
            """
            if self.args['nce_batch_size'] // self.args[
                    'batch_size'] <= self.nce_counter:
                target = None
                # Select the proper NCE loss passed as argument
                dict_nce = globals()[self.args['nce_loss']](
                    self.agent.model.model,
                    self.states,
                    self.actions,
                    self.returns,
                    self.next_states,
                    self.args,
                    target=target)

                nce_scores = self.args['lambda_LL'] * dict_nce[
                    'nce_L_L'] + self.args['lambda_LG'] * dict_nce[
                        'nce_L_G'] + self.args['lambda_GL'] * dict_nce[
                            'nce_G_L'] + self.args['lambda_GG'] * dict_nce[
                                'nce_G_G']
                device_ = nce_scores.device
                nce_scores_raw = (dict_nce['nce_L_L']
                                  if self.args['lambda_LL'] > 0 else
                                  torch.tensor(0.).to(device_)).mean()
                nce_scores_raw += (dict_nce['nce_L_G']
                                   if self.args['lambda_LG'] > 0 else
                                   torch.tensor(0.).to(device_)).mean()
                nce_scores_raw += (dict_nce['nce_G_L']
                                   if self.args['lambda_GL'] > 0 else
                                   torch.tensor(0.).to(device_)).mean()
                nce_scores_raw += (dict_nce['nce_G_G']
                                   if self.args['lambda_GG'] > 0 else
                                   torch.tensor(0.).to(device_)).mean()
                if self.prioritized_replay:
                    nce_device = nce_scores.get_device()
                    if nce_device < 0:
                        nce_scores *= samples.is_weights
                    else:
                        nce_scores *= samples.is_weights.to(nce_device)
                info_nce_loss_weighted = (
                    -nce_scores).mean()  # decay by epsilon
                nce_scores_raw = (-nce_scores_raw).mean()

                if loss_device < 0:
                    info_nce_loss_weighted = info_nce_loss_weighted.to('cpu')
                    nce_scores_raw = nce_scores_raw.to('cpu')

                # self.reset_nce_accumulators(self.agent.device)
                self.nce_counter = 0
            else:
                if loss_device > 0:
                    info_nce_loss_weighted = torch.tensor(0.).to(loss_device)
                    nce_scores_raw = torch.tensor(0.).to(loss_device)
                else:
                    info_nce_loss_weighted = torch.tensor(0.).cpu()
                    nce_scores_raw = torch.tensor(0.).cpu()
        else:
            if self.args['nce_batch_size'] // self.args[
                    'batch_size'] <= self.nce_counter:
                # self.reset_nce_accumulators(self.agent.device)
                self.nce_counter = 0
            if loss_device > 0:
                info_nce_loss_weighted = torch.tensor(0.).to(loss_device)
                nce_scores_raw = torch.tensor(0.).to(loss_device)
            else:
                info_nce_loss_weighted = torch.tensor(0.).cpu()
                nce_scores_raw = torch.tensor(0.).cpu()

        return loss + (
            self.args['nce_batch_size'] // self.batch_size
        ) * info_nce_loss_weighted, KL_div, loss, info_nce_loss_weighted, nce_scores_raw