def loss(self, samples): """Samples have leading batch dimension [B,..] (but not time).""" qs = self.agent(*samples.agent_inputs) q = select_at_indexes(samples.action, qs) with torch.no_grad(): target_qs = self.agent.target(*samples.target_inputs) if self.double_dqn: next_qs = self.agent(*samples.target_inputs) next_a = torch.argmax(next_qs, dim=-1) target_q = select_at_indexes(next_a, target_qs) else: target_q = torch.max(target_qs, dim=-1).values disc_target_q = (self.discount**self.n_step_return) * target_q y = samples.return_ + (1 - samples.done_n.float()) * disc_target_q delta = y - q losses = 0.5 * delta**2 abs_delta = abs(delta) if self.delta_clip is not None: # Huber loss. b = self.delta_clip * (abs_delta - self.delta_clip / 2) losses = torch.where(abs_delta <= self.delta_clip, losses, b) if self.prioritized_replay: losses *= samples.is_weights td_abs_errors = torch.clamp(abs_delta.detach(), 0, self.delta_clip) if not self.mid_batch_reset: valid = valid_from_done(samples.done) loss = valid_mean(losses, valid) td_abs_errors *= valid else: loss = torch.mean(losses) return loss, td_abs_errors
def step(self, observation, prev_action, prev_reward, device="cpu"): """ Compute policy's option and action distributions from inputs. Calls model to get mean, std for all pi_w, q, beta for all options, pi over options Moves inputs to device and returns outputs back to CPU, for the sampler. (no grad) """ model_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) mu, log_std, beta, q, pi = self.model(*model_inputs) dist_info_omega = DistInfo(prob=pi) new_o, terminations = self.sample_option( beta, dist_info_omega) # Sample terminations and options dist_info = DistInfoStd(mean=mu, log_std=log_std) dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu), log_std=select_at_indexes(new_o, log_std)) action = self.distribution.sample(dist_info_o) agent_info = AgentInfoOC(dist_info=dist_info, dist_info_o=dist_info_o, q=q, value=(pi * q).sum(-1), termination=terminations, dist_info_omega=dist_info_omega, prev_o=self._prev_option, o=new_o) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_oc_state(new_o) return AgentStep(action=action, agent_info=agent_info)
def loss(self, samples): """ Computes the Distributional Q-learning loss, based on projecting the discounted rewards + target Q-distribution into the current Q-domain, with cross-entropy loss. Returns loss and KL-divergence-errors for use in prioritization. """ delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1) z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms) # Makde 2-D tensor of contracted z_domain for each data point, # with zeros where next value should not be added. next_z = z * (self.discount ** self.n_step_return) # [P'] next_z = torch.ger(1 - samples.done_n.float(), next_z) # [B,P'] ret = samples.return_.unsqueeze(1) # [B,1] next_z = torch.clamp(ret + next_z, self.V_min, self.V_max) # [B,P'] z_bc = z.view(1, -1, 1) # [1,P,1] next_z_bc = next_z.unsqueeze(1) # [B,1,P'] abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1) # Most 0. # projection_coeffs is a 3-D tensor: [B,P,P'] # dim-0: independent data entries # dim-1: base_z atoms (remains after projection) # dim-2: next_z atoms (summed in projection) with torch.no_grad(): target_ps = self.agent.target(*samples.target_inputs) # [B,A,P'] if self.double_dqn: next_ps = self.agent(*samples.target_inputs) # [B,A,P'] next_qs = torch.tensordot(next_ps, z, dims=1) # [B,A] next_a = torch.argmax(next_qs, dim=-1) # [B] else: target_qs = torch.tensordot(target_ps, z, dims=1) # [B,A] next_a = torch.argmax(target_qs, dim=-1) # [B] target_p_unproj = select_at_indexes(next_a, target_ps) # [B,P'] target_p_unproj = target_p_unproj.unsqueeze(1) # [B,1,P'] target_p = (target_p_unproj * projection_coeffs).sum(-1) # [B,P] ps = self.agent(*samples.agent_inputs) # [B,A,P] p = select_at_indexes(samples.action, ps) # [B,P] p = torch.clamp(p, EPS, 1) # NaN-guard. losses = -torch.sum(target_p * torch.log(p), dim=1) # Cross-entropy. if self.prioritized_replay: losses *= samples.is_weights target_p = torch.clamp(target_p, EPS, 1) KL_div = torch.sum(target_p * (torch.log(target_p) - torch.log(p.detach())), dim=1) KL_div = torch.clamp(KL_div, EPS, 1 / EPS) # Avoid <0 from NaN-guard. if not self.mid_batch_reset: valid = valid_from_done(samples.done) loss = valid_mean(losses, valid) KL_div *= valid else: loss = torch.mean(losses) return loss, KL_div
def sample(self, dist_info):
    logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
    u = torch.rand_like(logits)
    u = torch.clamp(u, 1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))
    prob = F.softmax((logits + gumbel) / 10, dim=-1)
    cat_sample = torch.argmax(prob, dim=-1)
    one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
    if len(prob.shape) == 1:  # Edge case for when it gets buffer shapes.
        cat_sample = cat_sample.unsqueeze(0)
    if self._all_corners:
        mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
        mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
        mu = select_at_indexes(cat_sample, mu)
        log_std = select_at_indexes(cat_sample, log_std)
        if len(prob.shape) == 1:  # Edge case for when it gets buffer shapes.
            mu, log_std = mu.squeeze(0), log_std.squeeze(0)
        new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
    else:
        new_dist_info = delta_dist_info
    if self.training:
        self.delta_distribution.set_std(None)
    else:
        self.delta_distribution.set_std(0)
    delta_sample = self.delta_distribution.sample(new_dist_info)
    return torch.cat((one_hot, delta_sample), dim=-1)

def rl_loss(self, latent, action, return_n, done_n, prev_action, prev_reward,
            next_state, next_prev_action, next_prev_reward, is_weights, done):
    delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
    z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
    # Make 2-D tensor of contracted z_domain for each data point,
    # with zeros where next value should not be added.
    next_z = z * (self.discount ** self.n_step_return)  # [P']
    next_z = torch.ger(1 - done_n.float(), next_z)  # [B,P']
    ret = return_n.unsqueeze(1)  # [B,1]
    next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']
    z_bc = z.view(1, -1, 1)  # [1,P,1]
    next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
    abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
    projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
    # projection_coeffs is a 3-D tensor: [B,P,P']
    # dim-0: independent data entries
    # dim-1: base_z atoms (remains after projection)
    # dim-2: next_z atoms (summed in projection)
    with torch.no_grad():
        target_ps = self.agent.target(next_state, next_prev_action,
                                      next_prev_reward)  # [B,A,P']
        if self.double_dqn:
            next_ps = self.agent(next_state, next_prev_action,
                                 next_prev_reward)  # [B,A,P']
            next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(next_qs, dim=-1)  # [B]
        else:
            target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(target_qs, dim=-1)  # [B]
        target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
        target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
        target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
    ps = self.agent.head_forward(latent, prev_action, prev_reward)  # [B,A,P]
    p = select_at_indexes(action, ps)  # [B,P]
    p = torch.clamp(p, EPS, 1)  # NaN-guard.
    losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.
    if self.prioritized_replay:
        losses *= is_weights
    target_p = torch.clamp(target_p, EPS, 1)
    KL_div = torch.sum(target_p *
        (torch.log(target_p) - torch.log(p.detach())), dim=1)
    KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.
    if not self.mid_batch_reset:
        valid = valid_from_done(done[1])
        loss = valid_mean(losses, valid)
        KL_div *= valid
    else:
        loss = torch.mean(losses)
    return loss, KL_div

def loss(self, samples): """ Computes the Q-learning loss, based on: 0.5 * (Q - target_Q) ^ 2. Implements regular DQN or Double-DQN for computing target_Q values using the agent's target network. Computes the Huber loss using ``delta_clip``, or if ``None``, uses MSE. When using prioritized replay, multiplies losses by importance sample weights. Input ``samples`` have leading batch dimension [B,..] (but not time). Calls the agent to compute forward pass on training inputs, and calls ``agent.target()`` to compute target values. Returns loss and TD-absolute-errors for use in prioritization. Warning: If not using mid_batch_reset, the sampler will only reset environments between iterations, so some samples in the replay buffer will be invalid. This case is not supported here currently. """ qs = self.agent(*samples.agent_inputs) q = select_at_indexes(samples.action, qs) with torch.no_grad(): target_qs = self.agent.target(*samples.target_inputs) if self.double_dqn: next_qs = self.agent(*samples.target_inputs) next_a = torch.argmax(next_qs, dim=-1) target_q = select_at_indexes(next_a, target_qs) else: target_q = torch.max(target_qs, dim=-1).values disc_target_q = (self.discount**self.n_step_return) * target_q y = samples.return_ + (1 - samples.done_n.float()) * disc_target_q delta = y - q losses = 0.5 * delta**2 abs_delta = abs(delta) if self.delta_clip is not None: # Huber loss. b = self.delta_clip * (abs_delta - self.delta_clip / 2) losses = torch.where(abs_delta <= self.delta_clip, losses, b) if self.prioritized_replay: losses *= samples.is_weights td_abs_errors = abs_delta.detach() if self.delta_clip is not None: td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip) if not self.mid_batch_reset: # FIXME: I think this is wrong, because the first "done" sample # is valid, but here there is no [T] dim, so there's no way to # know if a "done" sample is the first "done" in the sequence. raise NotImplementedError # valid = valid_from_done(samples.done) # loss = valid_mean(losses, valid) # td_abs_errors *= valid else: loss = torch.mean(losses) return loss, td_abs_errors
def value(self, observation, prev_action, prev_reward, device="cpu"): prev_action = self.distribution.to_onehot(prev_action) model_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) _pi, beta, q, pi_omega = self.model(*model_inputs) v = (q * pi_omega).sum( -1 ) # Weight q value by probability of option. Average value if terminal q_prev_o = select_at_indexes(self.prev_option, q) beta_prev_o = select_at_indexes(self.prev_option, beta) value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o return value.to(device)
def dqn_rl_loss(self, qs, samples, index):
    """
    Computes the Q-learning loss, based on: 0.5 * (Q - target_Q) ^ 2.
    Implements regular DQN or Double-DQN for computing target_Q values
    using the agent's target network.  Computes the Huber loss using
    ``delta_clip``, or if ``None``, uses MSE.  When using prioritized
    replay, multiplies losses by importance sample weights.

    Input ``samples`` have leading batch dimension [B,..] (but not time).

    Calls the agent to compute forward pass on training inputs, and calls
    ``agent.target()`` to compute target values.

    Returns loss and TD-absolute-errors for use in prioritization.

    Warning:
        If not using mid_batch_reset, the sampler will only reset
        environments between iterations, so some samples in the replay
        buffer will be invalid.  This case is not supported here currently.
    """
    q = select_at_indexes(samples.all_action[index + 1], qs).cpu()
    with torch.no_grad():
        target_qs = self.agent.target(
            samples.all_observation[index + self.n_step_return],
            samples.all_action[index + self.n_step_return],
            samples.all_reward[index + self.n_step_return])  # [B,A,P']
        if self.double_dqn:
            next_qs = self.agent(
                samples.all_observation[index + self.n_step_return],
                samples.all_action[index + self.n_step_return],
                samples.all_reward[index + self.n_step_return])  # [B,A,P']
            next_a = torch.argmax(next_qs, dim=-1)
            target_q = select_at_indexes(next_a, target_qs)
        else:
            target_q = torch.max(target_qs, dim=-1).values
    disc_target_q = (self.discount ** self.n_step_return) * target_q
    y = samples.return_[index] + (
        1 - samples.done_n[index].float()) * disc_target_q
    delta = y - q
    losses = 0.5 * delta ** 2
    abs_delta = abs(delta)
    if self.delta_clip > 0:  # Huber loss.
        b = self.delta_clip * (abs_delta - self.delta_clip / 2)
        losses = torch.where(abs_delta <= self.delta_clip, losses, b)
    td_abs_errors = abs_delta.detach()
    if self.delta_clip > 0:
        td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip)
    return losses, td_abs_errors

def dist_rl_loss(self, log_pred_ps, samples, index):
    delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
    z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
    # Make 2-D tensor of contracted z_domain for each data point,
    # with zeros where next value should not be added.
    next_z = z * (self.discount ** self.n_step_return)  # [P']
    next_z = torch.ger(1 - samples.done_n[index].float(), next_z)  # [B,P']
    ret = samples.return_[index].unsqueeze(1)  # [B,1]
    next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']
    z_bc = z.view(1, -1, 1)  # [1,P,1]
    next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
    abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
    projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
    # projection_coeffs is a 3-D tensor: [B,P,P']
    # dim-0: independent data entries
    # dim-1: base_z atoms (remains after projection)
    # dim-2: next_z atoms (summed in projection)
    with torch.no_grad():
        target_ps = self.agent.target(
            samples.all_observation[index + self.n_step_return],
            samples.all_action[index + self.n_step_return],
            samples.all_reward[index + self.n_step_return])  # [B,A,P']
        if self.double_dqn:
            next_ps = self.agent(
                samples.all_observation[index + self.n_step_return],
                samples.all_action[index + self.n_step_return],
                samples.all_reward[index + self.n_step_return])  # [B,A,P']
            next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(next_qs, dim=-1)  # [B]
        else:
            target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(target_qs, dim=-1)  # [B]
        target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
        target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
        target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
    p = select_at_indexes(samples.all_action[index + 1].squeeze(-1),
                          log_pred_ps.cpu())  # [B,P]
    # p = torch.clamp(p, EPS, 1)  # NaN-guard.
    losses = -torch.sum(target_p * p, dim=1)  # Cross-entropy (p is log-probs).
    target_p = torch.clamp(target_p, EPS, 1)
    KL_div = torch.sum(target_p * (torch.log(target_p) - p.detach()), dim=1)
    KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.
    return losses, KL_div.detach()

def sample_loglikelihood(self, dist_info):
    logits, delta_dist_info = dist_info.cat_dist, dist_info.delta_dist
    u = torch.rand_like(logits)
    u = torch.clamp(u, 1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))
    prob = F.softmax((logits + gumbel) / 10, dim=-1)
    cat_sample = torch.argmax(prob, dim=-1)
    cat_loglikelihood = select_at_indexes(cat_sample, prob)
    one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
    one_hot = (one_hot - prob).detach() + prob  # Make action differentiable through prob.
    if self._all_corners:
        mu, log_std = delta_dist_info.mean, delta_dist_info.log_std
        mu, log_std = mu.view(-1, 4, 3), log_std.view(-1, 4, 3)
        mu = mu[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
        log_std = log_std[torch.arange(len(cat_sample)), cat_sample.squeeze(-1)]
        new_dist_info = DistInfoStd(mean=mu, log_std=log_std)
    else:
        new_dist_info = delta_dist_info
    delta_sample, delta_loglikelihood = self.delta_distribution.sample_loglikelihood(
        new_dist_info)
    action = torch.cat((one_hot, delta_sample), dim=-1)
    log_likelihood = cat_loglikelihood + delta_loglikelihood
    return action, log_likelihood

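# Standalone sketch (illustrative, not the original API) of the
# straight-through trick used above: the forward pass emits the hard
# one-hot sample while gradients flow through the softmax probabilities.
# The temperature of 10 mirrors the hard-coded divisor above.
import torch
import torch.nn.functional as F
def straight_through_gumbel(logits, temperature=10.0):
    u = torch.clamp(torch.rand_like(logits), 1e-5, 1 - 1e-5)
    gumbel = -torch.log(-torch.log(u))
    prob = F.softmax((logits + gumbel) / temperature, dim=-1)
    hard = F.one_hot(prob.argmax(dim=-1), logits.shape[-1]).float()
    return (hard - prob).detach() + prob  # Hard value, soft gradient.
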
def __call__(self, observation, prev_action, prev_reward, sampled_option, device="cpu"): """Performs forward pass on training data, for algorithm. Returns sampled distinfo, q, beta, and piomega distinfo""" model_inputs = buffer_to( (observation, prev_action, prev_reward, sampled_option), device=self.device) mu, log_std, beta, q, pi = self.model(*model_inputs[:-1]) # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo) return buffer_to( (DistInfoStd(mean=select_at_indexes(sampled_option, mu), log_std=select_at_indexes(sampled_option, log_std)), q, beta, DistInfo(prob=pi)), device=device)
def compute_true_delta(self, samples): """ Helper method with no training purpose. Only purpose is to compute the "true" return as samples come in, make the current Q estimate and see what the difference is (i.e. for evaluation and logging only) NOTE: if multiple trajectories are collected in a single sample, only the first trajectory will be used. :param samples: samples from environment sampler :return: tensor of delta between true G and predicted Q and target Q of shape (T, 1) (T being the length of valid traj) """ # Extract information to estimate Q all_observation, all_action, all_reward = buffer_to( (samples.env.observation.clone().detach(), samples.agent.prev_action.clone().detach(), samples.env.prev_reward.clone().detach()), device=self.agent.device) action = samples.agent.prev_action[1:self.batch_T + 1] return_ = samples.env.reward[0:self.batch_T] done_n = samples.env.done[0:self.batch_T] # Get the behaviour Qs and target max q input_buffer = (all_observation, all_action, all_reward) with torch.no_grad(): qs, target_q = self.compute_q_predictions(input_buffer) q = select_at_indexes(action, qs) # Valid length valid = valid_from_done(done_n) valid_T = int(torch.sum(valid)) # lambda target lambda_G = self.compute_lambda_return(return_, target_q, valid) # (T, 1) # == # Compute true return (highly specific to the delay action.py env) # NOTE: this is built specifically for the action independent, pure # prediction variant of the delayed_actions.py env arm_num = int(samples.env.env_info.arm_num[(valid_T - 1)]) true_R = 1.0 if (arm_num == 1) else -1.0 true_G = torch.zeros((valid_T, 1)) true_G[-1] = true_R for i in reversed(range(valid_T - 1)): true_G[i] = self.discount * true_G[i + 1] true_G[0] = 0.0 # first state has expected 0 # == # Compute delta to true value predic_true_delta = true_G - q[:valid_T] target_true_delta = true_G - lambda_G[:valid_T] return predic_true_delta, target_true_delta
def value(self, observation, prev_action, prev_reward, device="cpu"): """ Compute the value estimate for the environment state, e.g. for the bootstrap value, V(s_{T+1}), in the sampler. For option-critic algorithms, this is the q(s_{T+1}, prev_o) * (1-beta(s_{T+1}, prev_o)) + beta(s_{T+1}, prev_o) * sum_{o} pi_omega(o|s_{T+1}) * q(s_{T+1}, o) (no grad) """ model_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) _mu, _log_std, beta, q, pi = self.model(*model_inputs) # [B, nOpt] v = (q * pi).sum( -1 ) # Weight q value by probability of option. Average value if terminal q_prev_o = select_at_indexes(self.prev_option, q) beta_prev_o = select_at_indexes(self.prev_option, beta) value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o return value.to(device)
def value(self, observation, prev_action, prev_reward, device="cpu"): prev_option_input = self._prev_option if prev_option_input is None: # Hack to extract previous option prev_option_input = torch.full_like(prev_action, -1) prev_action = self.distribution.to_onehot(prev_action) prev_option_input = self.distribution_omega.to_onehot_with_invalid( prev_option_input) agent_inputs = buffer_to( (observation, prev_action, prev_reward, prev_option_input), device=self.device) _pi, beta, q, pi_omega, _rnn_state = self.model( *agent_inputs, self.prev_rnn_state) v = (q * pi_omega).sum( -1 ) # Weight q value by probability of option. Average value if terminal q_prev_o = select_at_indexes(self.prev_option, q) beta_prev_o = select_at_indexes(self.prev_option, beta) value = q_prev_o * (1 - beta_prev_o) + v * beta_prev_o return value.to(device)
def compute_q_predictions(self, input_buffer):
    """
    Compute the behaviour and target network Q predictions.

    Note this is a separate method since it is re-used both during training
    and to evaluate progress on newly sampled trajectories.

    :param input_buffer: observations, actions and rewards of a trajectory
    :return: behaviour qs (size [T, B, A]) and target_q (size [T, B])
    """
    # Unpack the RNN input buffer
    all_observation, all_action, all_reward = input_buffer
    # all_action = torch.zeros(all_action.size())
    # all_reward = torch.zeros(all_reward.size())
    # TODO make this a feature in future?
    # ==
    # Compute Q estimates (NOTE: no RNN warm-up)
    agent_slice = slice(0, self.batch_T)
    agent_inputs = AgentInputs(
        observation=all_observation[agent_slice].clone().detach(),
        prev_action=all_action[agent_slice].clone().detach(),
        prev_reward=all_reward[agent_slice].clone().detach(),
    )
    target_slice = slice(0, None)  # Same start t as agent. (0 + bT + nsr)
    target_inputs = AgentInputs(
        observation=all_observation[target_slice],
        prev_action=all_action[target_slice],
        prev_reward=all_reward[target_slice],
    )
    # NOTE: always initialize to None; assume we always have the full traj.
    # For how to sample the rnn intermediate state from mid-run, see
    # https://github.com/astooke/rlpyt/blob/f04f23db1eb7b5915d88401fca67869968a07a37
    # /rlpyt/algos/dqn/r2d1.py#L280
    init_rnn_state = None
    target_rnn_state = None  # NOTE: no RNN warmup for target
    # Behavioural net Q estimate
    qs, _ = self.agent(*agent_inputs, init_rnn_state)  # [T,B,A]
    # Target network Q estimates
    with torch.no_grad():
        target_qs, _ = self.agent.target(*target_inputs, target_rnn_state)
        if self.double_dqn:
            next_qs, _ = self.agent(*target_inputs, init_rnn_state)
            next_a = torch.argmax(next_qs, dim=-1)
            target_q = select_at_indexes(next_a, target_qs)
        else:
            target_q = torch.max(target_qs, dim=-1).values
    target_q = target_q[-self.batch_T:]  # Same length as q.
    return qs, target_q

def __call__(self, observation, prev_action, prev_reward, sampled_option, init_rnn_state, device="cpu"): """Performs forward pass on training data, for algorithm (requires recurrent state input). Returnssampled distinfo, q, beta, and piomega distinfo""" # Assume init_rnn_state already shaped: [N,B,H] model_inputs = buffer_to((observation, prev_action, prev_reward, init_rnn_state, sampled_option), device=self.device) mu, log_std, beta, q, pi, next_rnn_state = self.model( *model_inputs[:-1]) # Need gradients from intra-option (DistInfoStd), q_o (q), termination (beta), and pi_omega (DistInfo) dist_info, q, beta, dist_info_omega = buffer_to( (DistInfoStd(mean=select_at_indexes(sampled_option, mu), log_std=select_at_indexes(sampled_option, log_std)), q, beta, DistInfo(prob=pi)), device=device) return dist_info, q, beta, dist_info_omega, next_rnn_state # Leave rnn_state on device.
def step(self, observation, prev_action, prev_reward, device="cpu"): """ Compute policy's action distribution from inputs, and sample an action. Calls the model to produce mean, log_std, value estimate, and next recurrent state. Moves inputs to device and returns outputs back to CPU, for the sampler. Advances the recurrent state of the agent. (no grad) """ agent_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) mu, log_std, beta, q, pi, rnn_state = self.model( *agent_inputs, self.prev_rnn_state) terminations = torch.bernoulli(beta).bool() # Sample terminations dist_info_omega = DistInfo(prob=pi) new_o = self.sample_option(terminations, dist_info_omega) dist_info = DistInfoStd(mean=mu, log_std=log_std) dist_info_o = DistInfoStd(mean=select_at_indexes(new_o, mu), log_std=select_at_indexes(new_o, log_std)) action = self.distribution.sample(dist_info_o) # Model handles None, but Buffer does not, make zeros if needed: prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func( rnn_state, torch.zeros_like) # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage. # (Special case: model should always leave B dimension in.) prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1) agent_info = AgentInfoOCRnn(dist_info=dist_info, dist_info_o=dist_info_o, q=q, value=(pi * q).sum(-1), termination=terminations, inter_option_dist_info=dist_info_omega, prev_o=self._prev_option, o=new_o, prev_rnn_state=prev_rnn_state) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_rnn_state(rnn_state) # Keep on device. self.advance_oc_state(new_o) return AgentStep(action=action, agent_info=agent_info)
def __call__(self, observation, prev_action, prev_reward, sampled_option, device="cpu"): prev_action = self.distribution.to_onehot(prev_action) model_inputs = buffer_to( (observation, prev_action, prev_reward, sampled_option), device=self.device) pi, beta, q, pi_omega = self.model(*model_inputs[:-1]) return buffer_to( (DistInfo(prob=select_at_indexes(sampled_option, pi)), q, beta, DistInfo(prob=pi_omega)), device=device)
def select_at_indexes(self, indexes, tensor):
    """Returns the `tensor` data at the multi-dimensional integer array
    `indexes`.

    Parameters
    ----------
    indexes : tensor
        A tensor of indexes.
    tensor : tensor
        A tensor from which to retrieve the data of interest.

    Returns
    -------
    result : tensor
        The resulting data.
    """
    return select_at_indexes(indexes, tensor)

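# A plausible implementation sketch of the wrapped rlpyt helper
# (illustrative; not copied from the library source): index `tensor` along
# the dimension that follows the leading dims shared with `indexes`.
import torch
def select_at_indexes_sketch(indexes, tensor):
    dim = len(indexes.shape)
    assert indexes.shape == tensor.shape[:dim]
    num = indexes.numel()
    t_flat = tensor.view((num,) + tensor.shape[dim:])
    s_flat = t_flat[torch.arange(num), indexes.reshape(-1)]
    return s_flat.view(tensor.shape[:dim] + tensor.shape[dim + 1:])
# e.g. qs: [B,A] with action: [B] -> [B];  ps: [B,A,P] with action: [B] -> [B,P].
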
def compute_input_priorities(self, samples): """Just for first input into replay buffer. Simple 1-step return TD-errors using recorded Q-values from online network and value scaling, with the T dimension reduced away (same priority applied to all samples in this batch; whereever the rnn state is kept--hopefully the first step--this priority will apply there). The samples duration T might be less than the training segment, so this is an approximation of an approximation, but hopefully will capture the right behavior. UPDATE 20190826: Trying using n-step returns. For now using samples with full n-step return available...later could also use partial returns for samples at end of batch. 35/40 ain't bad tho. Might not carry/use internal state here, because might get executed by alternating memory copiers in async mode; do all with only the samples avialable from input.""" samples = torchify_buffer(samples) q = samples.agent.agent_info.q action = samples.agent.action q_max = torch.max(q, dim=-1).values q_at_a = select_at_indexes(action, q) return_n, done_n = discount_return_n_step( reward=samples.env.reward, done=samples.env.done, n_step=self.n_step_return, discount=self.discount, do_truncated=False, # Only samples with full n-step return. ) # y = self.value_scale( # samples.env.reward[:-1] + # (self.discount * (1 - samples.env.done[:-1].float()) * # probably done.float() # self.inv_value_scale(q_max[1:])) # ) nm1 = max(1, self.n_step_return - 1) # At least 1 bc don't have next Q. y = self.value_scale(return_n + (1 - done_n.float()) * self.inv_value_scale(q_max[nm1:])) delta = abs(q_at_a[:-nm1] - y) # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None. if self.delta_clip is not None: # Huber loss. delta = torch.clamp(delta, 0, self.delta_clip) valid = valid_from_done(samples.env.done[:-nm1]) max_d = torch.max(delta * valid, dim=0).values mean_d = valid_mean(delta, valid, dim=0) # Still high if less valid. priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d # [B] return priorities.numpy()
def sample_loglikelihood(self, dist_info):
    if isinstance(dist_info, DistInfoStd):
        action, log_likelihood = self.delta_distribution.sample_loglikelihood(
            dist_info)
    else:
        logits = dist_info
        u = torch.rand_like(logits)
        u = torch.clamp(u, 1e-5, 1 - 1e-5)
        gumbel = -torch.log(-torch.log(u))
        prob = F.softmax((logits + gumbel) / 10, dim=-1)
        cat_sample = torch.argmax(prob, dim=-1)
        log_likelihood = select_at_indexes(cat_sample, prob)
        one_hot = to_onehot(cat_sample, 4, dtype=torch.float32)
        action = (one_hot - prob).detach() + prob  # Make action differentiable through prob.
    return action, log_likelihood

def sample_option(self, betas, option_dist_info):
    """Sample options according to which previous options are terminated
    and the probability over options."""
    if self._prev_option is None:  # No previous option, store as -1.
        self._prev_option = torch.full(betas.size()[:-1], -1,
                                       dtype=torch.long, device=betas.device)
    terminations = select_at_indexes(self._prev_option,
                                     torch.bernoulli(betas).bool())
    options = self._prev_option.clone()
    new_o = self.distribution_omega.sample(option_dist_info).expand_as(
        self._prev_option)
    options[self._prev_option == -1] = new_o[
        self._prev_option == -1]  # Must terminate, episode reset.
    mask = self._prev_option != -1
    options[mask] = torch.where(terminations.view(-1)[mask].flatten(),
                                new_o[mask], self._prev_option[mask])
    return options, terminations

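# Worked example (not original code) of the option-switching rule above:
# keep the previous option unless its termination head fired, and always
# re-sample where prev_option == -1 (episode reset).
import torch
prev_o = torch.tensor([2, -1, 0])
terminated = torch.tensor([False, True, True])
resampled = torch.tensor([1, 1, 1])
options = torch.where((prev_o == -1) | terminated, resampled, prev_o)
# options == tensor([2, 1, 1])
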
def step(self, observation, prev_action, prev_reward, device="cpu"): prev_action = self.distribution.to_onehot(prev_action) model_inputs = buffer_to((observation, prev_action, prev_reward), device=self.device) pi, beta, q, pi_omega = self.model(*model_inputs) dist_info_omega = DistInfo(prob=pi_omega) new_o, terminations = self.sample_option( beta, dist_info_omega) # Sample terminations and options dist_info = DistInfo(prob=pi) dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi)) action = self.distribution.sample(dist_info_o) agent_info = AgentInfoOC(dist_info=dist_info, dist_info_o=dist_info_o, q=q, value=(pi_omega * q).sum(-1), termination=terminations, dist_info_omega=dist_info_omega, prev_o=self._prev_option, o=new_o) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_oc_state(new_o) return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward, device="cpu"): prev_option_input = self._prev_option if prev_option_input is None: # Hack to extract previous option prev_option_input = torch.full_like(prev_action, -1) prev_action = self.distribution.to_onehot(prev_action) prev_option_input = self.distribution_omega.to_onehot_with_invalid( prev_option_input) model_inputs = buffer_to( (observation, prev_action, prev_reward, prev_option_input), device=self.device) pi, beta, q, pi_omega, rnn_state = self.model(*model_inputs, self.prev_rnn_state) dist_info_omega = DistInfo(prob=pi_omega) new_o, terminations = self.sample_option( beta, dist_info_omega) # Sample terminations and options # Model handles None, but Buffer does not, make zeros if needed: prev_rnn_state = self.prev_rnn_state if self.prev_rnn_state is not None else buffer_func( rnn_state, torch.zeros_like) # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage. # (Special case: model should always leave B dimension in.) prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1) dist_info = DistInfo(prob=pi) dist_info_o = DistInfo(prob=select_at_indexes(new_o, pi)) action = self.distribution.sample(dist_info_o) agent_info = AgentInfoOCRnn(dist_info=dist_info, dist_info_o=dist_info_o, q=q, value=(pi_omega * q).sum(-1), termination=terminations, dist_info_omega=dist_info_omega, prev_o=self._prev_option, o=new_o, prev_rnn_state=prev_rnn_state) action, agent_info = buffer_to((action, agent_info), device=device) self.advance_oc_state(new_o) self.advance_rnn_state(rnn_state) return AgentStep(action=action, agent_info=agent_info)
def likelihood_ratio(self, indexes, old_dist_info, new_dist_info):
    num = select_at_indexes(indexes, new_dist_info.prob)
    den = select_at_indexes(indexes, old_dist_info.prob)
    return (num + EPS) / (den + EPS)

def log_likelihood(self, indexes, dist_info):
    selected_likelihood = select_at_indexes(indexes, dist_info.prob)
    return torch.log(selected_likelihood + EPS)

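# Usage sketch (illustrative): with probs [[0.2, 0.8]] and index [1] the
# result is log(0.8 + EPS); the EPS term keeps the log finite when the
# selected probability is exactly zero.  The EPS value here is assumed.
import torch
EPS = 1e-8  # assumed; the module defines its own EPS constant
prob = torch.tensor([[0.2, 0.8]])
idx = torch.tensor([1])
log_lik = torch.log(prob.gather(-1, idx.unsqueeze(-1)).squeeze(-1) + EPS)
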
def loss(self, samples): """Samples have leading Time and Batch dimentions [T,B,..]. Move all samples to device first, and then slice for sub-sequences. Use same init_rnn_state for agent and target; start both at same t. Warmup the RNN state first on the warmup subsequence, then train on the remaining subsequence. Returns loss (usually use MSE, not Huber), TD-error absolute values, and new sequence-wise priorities, based on weighted sum of max and mean TD-error over the sequence.""" all_observation, all_action, all_reward = buffer_to( (samples.all_observation, samples.all_action, samples.all_reward), device=self.agent.device) wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return if wT > 0: warmup_slice = slice(None, wT) # Same for agent and target. warmup_inputs = AgentInputs( observation=all_observation[warmup_slice], prev_action=all_action[warmup_slice], prev_reward=all_reward[warmup_slice], ) agent_slice = slice(wT, wT + bT) agent_inputs = AgentInputs( observation=all_observation[agent_slice], prev_action=all_action[agent_slice], prev_reward=all_reward[agent_slice], ) target_slice = slice(wT, None) # Same start t as agent. (wT + bT + nsr) target_inputs = AgentInputs( observation=all_observation[target_slice], prev_action=all_action[target_slice], prev_reward=all_reward[target_slice], ) action = samples.all_action[wT + 1:wT + 1 + bT] # CPU. return_ = samples.return_[wT:wT + bT] done_n = samples.done_n[wT:wT + bT] if self.store_rnn_state_interval == 0: init_rnn_state = None else: # [B,N,H]-->[N,B,H] cudnn. init_rnn_state = buffer_method(samples.init_rnn_state, "transpose", 0, 1) init_rnn_state = buffer_method(init_rnn_state, "contiguous") if wT > 0: # Do warmup. with torch.no_grad(): _, target_rnn_state = self.agent.target(*warmup_inputs, init_rnn_state) _, init_rnn_state = self.agent(*warmup_inputs, init_rnn_state) # Recommend aligning sampling batch_T and store_rnn_interval with # warmup_T (and no mid_batch_reset), so that end of trajectory # during warmup leads to new trajectory beginning at start of # training segment of replay. warmup_invalid_mask = valid_from_done(samples.done[:wT])[-1] == 0 # [B] init_rnn_state[:, warmup_invalid_mask] = 0 # [N,B,H] (cudnn) target_rnn_state[:, warmup_invalid_mask] = 0 else: target_rnn_state = init_rnn_state qs, _ = self.agent(*agent_inputs, init_rnn_state) # [T,B,A] q = select_at_indexes(action, qs) with torch.no_grad(): target_qs, _ = self.agent.target(*target_inputs, target_rnn_state) if self.double_dqn: next_qs, _ = self.agent(*target_inputs, init_rnn_state) next_a = torch.argmax(next_qs, dim=-1) target_q = select_at_indexes(next_a, target_qs) else: target_q = torch.max(target_qs, dim=-1).values target_q = target_q[-bT:] # Same length as q. disc = self.discount ** self.n_step_return y = self.value_scale(return_ + (1 - done_n.float()) * disc * self.inv_value_scale(target_q)) # [T,B] delta = y - q losses = 0.5 * delta ** 2 abs_delta = abs(delta) # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None. if self.delta_clip is not None: # Huber loss. b = self.delta_clip * (abs_delta - self.delta_clip / 2) losses = torch.where(abs_delta <= self.delta_clip, losses, b) if self.prioritized_replay: losses *= samples.is_weights.unsqueeze(0) # weights: [B] --> [1,B] valid = valid_from_done(samples.done[wT:]) # 0 after first done. 
loss = valid_mean(losses, valid) td_abs_errors = abs_delta.detach() if self.delta_clip is not None: td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip) # [T,B] valid_td_abs_errors = td_abs_errors * valid max_d = torch.max(valid_td_abs_errors, dim=0).values mean_d = valid_mean(td_abs_errors, valid, dim=0) # Still high if less valid. priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d # [B] return loss, valid_td_abs_errors, priorities
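# A plausible sketch (not the library source) of valid_from_done as used
# above: along the time dimension, entries stay valid up to and including
# the first done in each batch column and are zeroed afterwards.
import torch
def valid_from_done_sketch(done):
    done = done.float()
    valid = torch.ones_like(done)
    valid[1:] = 1 - torch.clamp(torch.cumsum(done[:-1], dim=0), max=1)
    return valid  # [T,B]
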
def compute_input_priorities(self, samples): """Used when putting new samples into the replay buffer. Computes n-step TD-errors using recorded Q-values from online network and value scaling. Weights the max and the mean TD-error over each sequence to make a single priority value for that sequence. Note: Although the original R2D2 implementation used the entire 80-step sequence to compute the input priorities, we ran R2D1 with 40 time-step sample batches, and so computed the priority for each 80-step training sequence based on one of the two 40-step halves. Algorithm argument ``input_priority_shift`` determines which 40-step half is used as the priority for the 80-step sequence. (Since this method might get executed by alternating memory copiers in async mode, don't carry internal state here, do all computation with only the samples available in input. Could probably reduce to one memory copier and keep state there, if needed.) """ # """Just for first input into replay buffer. # Simple 1-step return TD-errors using recorded Q-values from online # network and value scaling, with the T dimension reduced away (same # priority applied to all samples in this batch; whereever the rnn state # is kept--hopefully the first step--this priority will apply there). # The samples duration T might be less than the training segment, so # this is an approximation of an approximation, but hopefully will # capture the right behavior. # UPDATE 20190826: Trying using n-step returns. For now using samples # with full n-step return available...later could also use partial # returns for samples at end of batch. 35/40 ain't bad tho. # Might not carry/use internal state here, because might get executed # by alternating memory copiers in async mode; do all with only the # samples avialable from input.""" samples = torchify_buffer(samples) q = samples.agent.agent_info.q action = samples.agent.action q_max = torch.max(q, dim=-1).values q_at_a = select_at_indexes(action, q) return_n, done_n = discount_return_n_step( reward=samples.env.reward, done=samples.env.done, n_step=self.n_step_return, discount=self.discount, do_truncated=False, # Only samples with full n-step return. ) # y = self.value_scale( # samples.env.reward[:-1] + # (self.discount * (1 - samples.env.done[:-1].float()) * # probably done.float() # self.inv_value_scale(q_max[1:])) # ) nm1 = max(1, self.n_step_return - 1) # At least 1 bc don't have next Q. y = self.value_scale(return_n + (1 - done_n.float()) * self.inv_value_scale(q_max[nm1:])) delta = abs(q_at_a[:-nm1] - y) # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None. if self.delta_clip is not None: # Huber loss. delta = torch.clamp(delta, 0, self.delta_clip) valid = valid_from_done(samples.env.done[:-nm1]) max_d = torch.max(delta * valid, dim=0).values mean_d = valid_mean(delta, valid, dim=0) # Still high if less valid. priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d # [B] return priorities.numpy()
def loss(self, samples): """ Computes the training loss: policy_loss + value_loss + entropy_loss. Policy loss: log-likelihood of actions * advantages Value loss: 0.5 * (estimated_value - return) ^ 2 Organizes agent inputs from training samples, calls the agent instance to run forward pass on training data, and uses the ``agent.distribution`` to compute likelihoods and entropies. Valid for feedforward or recurrent agents. """ agent_inputs = AgentInputsOC( # Move inputs to device once, index there. observation=samples.env.observation, prev_action=samples.agent.prev_action, prev_reward=samples.env.prev_reward, sampled_option=samples.agent.agent_info.o, ) if self.agent.recurrent: init_rnn_state = samples.agent.agent_info.prev_rnn_state[ 0] # T = 0. # [B,N,H] --> [N,B,H] (for cudnn). init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1) init_rnn_state = buffer_method(init_rnn_state, "contiguous") po = samples.agent.agent_info.prev_o (dist_info_o, q, beta, dist_info_omega), _rnn_state = self.agent( *agent_inputs, po, init_rnn_state, device=agent_inputs.prev_action.device) else: dist_info_o, q, beta, dist_info_omega = self.agent( *agent_inputs, device=agent_inputs.prev_action.device) dist = self.agent.distribution dist_omega = self.agent.distribution_omega # TODO: try to compute everyone on device. return_, advantage, valid, beta_adv, not_init_states, op_adv = self.process_returns( samples) logli = dist.log_likelihood(samples.agent.action, dist_info_o) pi_loss = -valid_mean(logli * advantage, valid) o = samples.agent.agent_info.o q_o = select_at_indexes(o, q) value_error = 0.5 * (q_o - return_)**2 value_loss = self.value_loss_coeff * valid_mean(value_error, valid) # Termination loss prev_o = samples.agent.agent_info.prev_o beta_prev_o = select_at_indexes(prev_o, beta) beta_error = beta_prev_o * beta_adv beta_loss = self.termination_loss_coeff * valid_mean( beta_error, not_init_states) logli = dist_omega.log_likelihood(o, dist_info_omega) # pi_omega_loss = - valid_mean(logli * advantage, valid) pi_omega_loss = -valid_mean(logli * op_adv, valid) entropy = dist.mean_entropy(dist_info_o, valid) entropy_loss = -self.entropy_loss_coeff * entropy entropy_o = dist_omega.mean_entropy(dist_info_omega, valid) entropy_loss_omega = -self.omega_entropy_loss_coeff * entropy_o loss = pi_loss + pi_omega_loss + beta_loss + value_loss + entropy_loss + entropy_loss_omega # perplexity = dist.mean_perplexity(dist_info_o, valid) return loss, pi_loss, value_loss, beta_loss, pi_omega_loss, entropy, entropy_o
def loss(self, samples, itr, samples_nce):
    """Samples have leading batch dimension [B,..] (but not time)."""
    self.args['device'] = self.agent.device
    # Get rlpyt batch inputs and write them to GPU tensors.
    rl_agent_inputs = AgentInputs(
        observation=samples.agent_inputs.observation,
        prev_action=samples.agent_inputs.prev_action,
        prev_reward=None)
    rl_action = samples.action
    rl_return_ = samples.return_
    rl_target_inputs = AgentInputs(
        observation=samples.target_inputs.observation,
        prev_action=samples.target_inputs.prev_action,
        prev_reward=None)
    rl_done = samples.done
    rl_done_n = samples.done_n
    nce_slice = slice(self.nce_counter * self.args['batch_size'],
                      (self.nce_counter + 1) * self.args['batch_size'])
    self.states[nce_slice] = samples_nce.agent_inputs.observation.type(
        torch.float32).to(self.args['device']) / 255.
    self.actions[nce_slice] = samples_nce.action.type(
        torch.int64).to(self.args['device'])
    self.returns[nce_slice] = samples_nce.return_.type(
        torch.float32).to(self.args['device'])
    self.next_states[nce_slice] = samples_nce.target_inputs.observation.type(
        torch.float32).to(self.args['device']) / 255.
    self.nonterminals[nce_slice] = samples_nce.done
    if self.prioritized_replay:
        rl_is_weights = samples.is_weights
        self.weights[nce_slice] = samples_nce.is_weights
    self.nce_counter += 1

    # C51 code from rlpyt (unchanged).
    delta_z = (self.V_max - self.V_min) / (self.agent.n_atoms - 1)
    z = torch.linspace(self.V_min, self.V_max, self.agent.n_atoms)
    # Make 2-D tensor of contracted z_domain for each data point,
    # with zeros where next value should not be added.
    next_z = z * (self.discount ** self.n_step_return)  # [P']
    next_z = torch.ger(1 - rl_done_n.float(), next_z)  # [B,P']
    ret = rl_return_.unsqueeze(1)  # [B,1]
    next_z = torch.clamp(ret + next_z, self.V_min, self.V_max)  # [B,P']
    z_bc = z.view(1, -1, 1)  # [1,P,1]
    next_z_bc = next_z.unsqueeze(1)  # [B,1,P']
    abs_diff_on_delta = abs(next_z_bc - z_bc) / delta_z
    projection_coeffs = torch.clamp(1 - abs_diff_on_delta, 0, 1)  # Most 0.
    # projection_coeffs is a 3-D tensor: [B,P,P']
    # dim-0: independent data entries
    # dim-1: base_z atoms (remains after projection)
    # dim-2: next_z atoms (summed in projection)
    with torch.no_grad():
        target_ps = self.agent.target(*rl_target_inputs)  # [B,A,P']
        if self.double_dqn:
            next_ps = self.agent(*rl_target_inputs)  # [B,A,P']
            next_qs = torch.tensordot(next_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(next_qs, dim=-1)  # [B]
        else:
            target_qs = torch.tensordot(target_ps, z, dims=1)  # [B,A]
            next_a = torch.argmax(target_qs, dim=-1)  # [B]
        target_p_unproj = select_at_indexes(next_a, target_ps)  # [B,P']
        target_p_unproj = target_p_unproj.unsqueeze(1)  # [B,1,P']
        target_p = (target_p_unproj * projection_coeffs).sum(-1)  # [B,P]
    ps = self.agent(*rl_agent_inputs)  # [B,A,P]
    p = select_at_indexes(rl_action, ps)  # [B,P]
    p = torch.clamp(p, EPS, 1)  # NaN-guard.
    losses = -torch.sum(target_p * torch.log(p), dim=1)  # Cross-entropy.
    if self.prioritized_replay:
        losses *= rl_is_weights
    target_p = torch.clamp(target_p, EPS, 1)
    KL_div = torch.sum(target_p *
        (torch.log(target_p) - torch.log(p.detach())), dim=1)
    KL_div = torch.clamp(KL_div, EPS, 1 / EPS)  # Avoid <0 from NaN-guard.
    if not self.mid_batch_reset:
        valid = valid_from_done(rl_done)
        loss = valid_mean(losses, valid)
        KL_div *= valid
    else:
        loss = torch.mean(losses)
    # else:
    #     KL_div = torch.tensor([0.]).cpu()
    #     loss = torch.tensor([0.]).to(self.args['device'])

    # NCE loss.
    loss_device = loss.get_device()
    if (self.args['lambda_LL'] != 0 or self.args['lambda_LG'] != 0
            or self.args['lambda_GL'] != 0 or self.args['lambda_GG'] != 0):
        # Compute this only if one of the 4 lambdas != 0.
        if self.args['nce_batch_size'] // self.args['batch_size'] <= self.nce_counter:
            target = None
            # Select the proper NCE loss passed as argument.
            dict_nce = globals()[self.args['nce_loss']](
                self.agent.model.model, self.states, self.actions,
                self.returns, self.next_states, self.args, target=target)
            nce_scores = (self.args['lambda_LL'] * dict_nce['nce_L_L']
                          + self.args['lambda_LG'] * dict_nce['nce_L_G']
                          + self.args['lambda_GL'] * dict_nce['nce_G_L']
                          + self.args['lambda_GG'] * dict_nce['nce_G_G'])
            device_ = nce_scores.device
            nce_scores_raw = (dict_nce['nce_L_L'] if self.args['lambda_LL'] > 0
                              else torch.tensor(0.).to(device_)).mean()
            nce_scores_raw += (dict_nce['nce_L_G'] if self.args['lambda_LG'] > 0
                               else torch.tensor(0.).to(device_)).mean()
            nce_scores_raw += (dict_nce['nce_G_L'] if self.args['lambda_GL'] > 0
                               else torch.tensor(0.).to(device_)).mean()
            nce_scores_raw += (dict_nce['nce_G_G'] if self.args['lambda_GG'] > 0
                               else torch.tensor(0.).to(device_)).mean()
            if self.prioritized_replay:
                nce_device = nce_scores.get_device()
                if nce_device < 0:
                    nce_scores *= samples.is_weights
                else:
                    nce_scores *= samples.is_weights.to(nce_device)
            info_nce_loss_weighted = (-nce_scores).mean()  # decay by epsilon
            nce_scores_raw = (-nce_scores_raw).mean()
            if loss_device < 0:
                info_nce_loss_weighted = info_nce_loss_weighted.to('cpu')
                nce_scores_raw = nce_scores_raw.to('cpu')
            # self.reset_nce_accumulators(self.agent.device)
            self.nce_counter = 0
        else:
            if loss_device > 0:
                info_nce_loss_weighted = torch.tensor(0.).to(loss_device)
                nce_scores_raw = torch.tensor(0.).to(loss_device)
            else:
                info_nce_loss_weighted = torch.tensor(0.).cpu()
                nce_scores_raw = torch.tensor(0.).cpu()
    else:
        if self.args['nce_batch_size'] // self.args['batch_size'] <= self.nce_counter:
            # self.reset_nce_accumulators(self.agent.device)
            self.nce_counter = 0
        if loss_device > 0:
            info_nce_loss_weighted = torch.tensor(0.).to(loss_device)
            nce_scores_raw = torch.tensor(0.).to(loss_device)
        else:
            info_nce_loss_weighted = torch.tensor(0.).cpu()
            nce_scores_raw = torch.tensor(0.).cpu()
    return (loss + (self.args['nce_batch_size'] // self.batch_size)
            * info_nce_loss_weighted,
            KL_div, loss, info_nce_loss_weighted, nce_scores_raw)