def get_distribution_inputs_and_class(policy,
                                      model,
                                      obs_batch,
                                      *,
                                      explore=True,
                                      is_training=False,
                                      **kwargs):
    q_vals = compute_q_values(policy, model, obs_batch, explore, is_training)
    q_vals = q_vals[0] if isinstance(q_vals, tuple) else q_vals
    policy.q_values = q_vals
    return policy.q_values, TorchCategorical, []  # state-out
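# Minimal usage sketch (illustrative only, not part of the policy above): the
# returned tuple feeds RLlib's action-sampling path -- the Q-values act as the
# logits of the TorchCategorical distribution, and the exploration object
# (e.g. epsilon-greedy) decides whether to act greedily. The helper name and
# the `obs_batch` argument here are assumptions for illustration.
def example_greedy_actions(policy, obs_batch):
    dist_inputs, dist_class, _ = get_distribution_inputs_and_class(
        policy, policy.model, obs_batch, explore=False)
    action_dist = dist_class(dist_inputs, policy.model)
    # In exploit mode, the chosen action is simply the argmax over Q-values.
    return torch.argmax(dist_inputs, dim=1), action_dist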
def r2d2_loss(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q,
                    torch.tensor(0.0, device=policy.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=policy.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
            [q_target_best[1:],
             torch.tensor([0.0], device=policy.device)])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"]**config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss
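# `h_function` / `h_inverse` are called above but not defined in this excerpt.
# They are assumed to implement the invertible value rescaling from the R2D2
# paper, h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x. The sketch below uses
# distinct names so it does not shadow the module's own helpers; treat it as a
# reference for the formula, not as the original implementation.
def example_h_function(x, epsilon=1.0):
    """Value rescaling: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x."""
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x


def example_h_inverse(x, epsilon=1.0):
    """Closed-form inverse of the rescaling above."""
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon)) -
          1.0) / (2.0 * epsilon))**2 - 1.0)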
def build_q_losses_wt_additional_logs(
    policy: Policy, model, _, train_batch: SampleBatch
) -> TensorType:
    """Copy of build_q_losses with additional values saved into the policy.

    Only two changes were made; see the numbered comments below.
    """
    config = policy.config

    # Q-network evaluation.
    q_t, q_logits_t, q_probs_t = compute_q_values(
        policy,
        policy.q_model,
        train_batch[SampleBatch.CUR_OBS],
        explore=False,
        is_training=True,
    )
    # Addition 1 out of 2.
    policy.last_q_t = q_t.clone()

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch[SampleBatch.NEXT_OBS],
        explore=False,
        is_training=True,
    )
    # Addition 2 out of 2.
    policy.last_target_q_t = q_tp1.clone()

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(
        train_batch[SampleBatch.ACTIONS], policy.action_space.n
    )
    q_t_selected = torch.sum(
        torch.where(
            q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=policy.device)
        )
        * one_hot_selection,
        1,
    )
    q_logits_t_selected = torch.sum(
        q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1
    )

    # Compute estimate of the best possible value starting from state at t + 1.
    if config["double_q"]:
        (
            q_tp1_using_online_net,
            q_logits_tp1_using_online_net,
            q_dist_tp1_using_online_net,
        ) = compute_q_values(
            policy,
            policy.q_model,
            train_batch[SampleBatch.NEXT_OBS],
            explore=False,
            is_training=True,
        )
        q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = F.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN,
                q_tp1,
                torch.tensor(0.0, device=policy.device),
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )
    else:
        q_tp1_best_one_hot_selection = F.one_hot(
            torch.argmax(q_tp1, 1), policy.action_space.n
        )
        q_tp1_best = torch.sum(
            torch.where(
                q_tp1 > FLOAT_MIN,
                q_tp1,
                torch.tensor(0.0, device=policy.device),
            )
            * q_tp1_best_one_hot_selection,
            1,
        )
        q_probs_tp1_best = torch.sum(
            q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1
        )

    if PRIO_WEIGHTS not in train_batch.keys():
        assert config["prioritized_replay"] is False
        prio_weights = torch.tensor(
            [1.0] * len(train_batch[SampleBatch.REWARDS])
        ).to(policy.device)
    else:
        prio_weights = train_batch[PRIO_WEIGHTS]

    policy.q_loss = QLoss(
        q_t_selected,
        q_logits_t_selected,
        q_tp1_best,
        q_probs_tp1_best,
        prio_weights,
        train_batch[SampleBatch.REWARDS],
        train_batch[SampleBatch.DONES].float(),
        config["gamma"],
        config["n_step"],
        config["num_atoms"],
        config["v_min"],
        config["v_max"],
    )

    return policy.q_loss.loss
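# Illustrative sketch (hypothetical helper, not from the original code) of why
# `policy.last_q_t` / `policy.last_target_q_t` are stashed above: a custom
# stats function can surface them as training metrics. The metric key names
# here are assumptions.
def example_q_value_stats(policy, train_batch):
    return {
        "last_q_t_mean": torch.mean(policy.last_q_t),
        "last_q_t_max": torch.max(policy.last_q_t),
        "last_target_q_t_mean": torch.mean(policy.last_target_q_t),
    }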
def build_drq_q_losses(policy, model, _, train_batch):
    """Uses input augmentation on both the Q targets and the Q updates."""
    config = policy.config
    aug_num = config["aug_num"]

    # Target Q-network evaluation.
    q_tp1_best_avg = 0
    orig_nxt_obs = train_batch[SampleBatch.NEXT_OBS].clone()
    for _ in range(aug_num):
        # Augment the next observations (NHWC -> NCHW for the transform, then back).
        aug_nxt_obs = model.trans(
            orig_nxt_obs.permute(0, 3, 1, 2).float()).permute(0, 2, 3, 1)
        q_tp1 = compute_q_values(
            policy,
            policy.target_q_model,
            aug_nxt_obs,
            explore=False,
            is_training=True)

        # Compute estimate of the best possible value starting from state at t + 1.
        if config["double_q"]:
            q_tp1_using_online_net = compute_q_values(
                policy,
                policy.q_model,
                aug_nxt_obs,
                explore=False,
                is_training=True)
            q_tp1_best_using_online_net = torch.argmax(
                q_tp1_using_online_net, 1)
            q_tp1_best_one_hot_selection = F.one_hot(
                q_tp1_best_using_online_net, policy.action_space.n)
            q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        else:
            q_tp1_best_one_hot_selection = F.one_hot(
                torch.argmax(q_tp1, 1), policy.action_space.n)
            q_tp1_best = torch.sum(q_tp1 * q_tp1_best_one_hot_selection, 1)

        # Accumulate the target Q over augmented next observations.
        q_tp1_best_avg += q_tp1_best
    q_tp1_best_avg /= aug_num

    # Q-network evaluation.
    aug_loss = 0
    orig_cur_obs = train_batch[SampleBatch.CUR_OBS].clone()
    for _ in range(aug_num):
        # Augment the current observations.
        aug_cur_obs = model.trans(
            orig_cur_obs.permute(0, 3, 1, 2).float()).permute(0, 2, 3, 1)
        q_t = compute_q_values(
            policy,
            policy.q_model,
            aug_cur_obs,
            explore=False,
            is_training=True)

        # Q scores for actions which we know were selected in the given state.
        one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS],
                                      policy.action_space.n)
        q_t_selected = torch.sum(q_t * one_hot_selection, 1)

        # Bellman error.
        policy.q_loss = QLoss(
            q_t_selected, q_tp1_best_avg, train_batch[PRIO_WEIGHTS],
            train_batch[SampleBatch.REWARDS],
            train_batch[SampleBatch.DONES].float(), config["gamma"],
            config["n_step"], config["num_atoms"], config["v_min"],
            config["v_max"])

        # Accumulate the loss over augmented current observations.
        aug_loss += policy.q_loss.loss

    return aug_loss / aug_num
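# `model.trans` is not defined in this excerpt. DrQ-style augmentation is
# typically a random shift (replication-pad, then random crop back to the
# original size). The class below is a minimal sketch of such a transform for
# NCHW float images, matching how it is called above; the class name and pad
# size are illustrative assumptions, not the original implementation.
class ExampleRandomShiftAug(torch.nn.Module):
    def __init__(self, pad=4):
        super().__init__()
        self.pad = pad

    def forward(self, x):
        n, c, h, w = x.shape
        # Pad all four sides, then crop each image back to (h, w) at a
        # per-sample random offset.
        padded = F.pad(x, [self.pad] * 4, mode="replicate")
        out = torch.empty_like(x)
        for i in range(n):
            top = int(torch.randint(0, 2 * self.pad + 1, (1,)))
            left = int(torch.randint(0, 2 * self.pad + 1, (1,)))
            out[i] = padded[i, :, top:top + h, left:left + w]
        return out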