def set_ensemble_weights(policy, pid, weights):
    """Load the state dict stored under ``pid`` into the policy's dynamics model.

    Args:
        policy: Policy owning ``dynamics_model`` and ``device``.
        pid: Key selecting this policy's entry in ``weights``.
        weights: Mapping from ``pid`` to a state dict for the dynamics model.
    """
    # Select this policy's entry and move it onto the policy's device.
    state_dict = convert_to_torch_tensor(weights[pid], device=policy.device)
    dynamics_model = policy.dynamics_model
    dynamics_model.load_state_dict(state_dict)
def compute_log_likelihoods(
        self,
        actions: Union[List[TensorType], TensorType],
        obs_batch: Union[List[TensorType], TensorType],
        state_batches: Optional[List[TensorType]] = None,
        prev_action_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None,
        prev_reward_batch: Optional[Union[List[TensorType],
                                          TensorType]] = None
) -> TensorType:
    """Return the log-likelihoods of ``actions`` under the current policy.

    Everything runs under ``torch.no_grad()``. Exploration is disabled
    (``explore=False``) and the model is queried with ``is_training=False``.

    Args:
        actions: Actions whose log-probabilities are evaluated.
        obs_batch: Observations the actions were taken in.
        state_batches: Optional RNN states; converted onto ``self.device``.
        prev_action_batch: Optional previous actions.
        prev_reward_batch: Optional previous rewards.

    Returns:
        ``dist_class(dist_inputs, self.model).logp(actions)``.

    Raises:
        ValueError: If an ``action_sampler_fn`` is set but no
            ``action_distribution_fn`` is available to evaluate log-probs.
    """
    # A custom sampler on its own offers no way to evaluate log-probs.
    if self.action_sampler_fn and self.action_distribution_fn is None:
        raise ValueError("Cannot compute log-prob/likelihood w/o an "
                         "`action_distribution_fn` and a provided "
                         "`action_sampler_fn`!")

    with torch.no_grad():
        base_batch = {
            SampleBatch.CUR_OBS: obs_batch,
            SampleBatch.ACTIONS: actions,
        }
        input_dict = self._lazy_tensor_dict(base_batch)
        optional_fields = (
            (SampleBatch.PREV_ACTIONS, prev_action_batch),
            (SampleBatch.PREV_REWARDS, prev_reward_batch),
        )
        for field_name, field_value in optional_fields:
            if field_value is not None:
                input_dict[field_name] = field_value

        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        state_tensors = [
            convert_to_torch_tensor(state, self.device)
            for state in (state_batches or [])
        ]

        # Exploration hook before each forward pass.
        self.exploration.before_compute_actions(explore=False)

        if not self.action_distribution_fn:
            # Default action-dist inputs calculation.
            dist_class = self.dist_class
            dist_inputs, _ = self.model(input_dict, state_tensors, seq_lens)
        else:
            try:
                # New signature: supports state_batches and seq_lens.
                dist_inputs, dist_class, _ = self.action_distribution_fn(
                    self,
                    self.model,
                    input_dict=input_dict,
                    state_batches=state_tensors,
                    seq_lens=seq_lens,
                    explore=False,
                    is_training=False)
            except TypeError as err:
                # Only a signature mismatch triggers the legacy call path;
                # any other TypeError propagates unchanged.
                # TODO: Remove the legacy path in future.
                message = err.args[0]
                if "positional argument" not in message and \
                        "unexpected keyword argument" not in message:
                    raise err
                dist_inputs, dist_class, _ = self.action_distribution_fn(
                    policy=self,
                    model=self.model,
                    obs_batch=input_dict[SampleBatch.CUR_OBS],
                    explore=False,
                    is_training=False)

        action_dist = dist_class(dist_inputs, self.model)
        return action_dist.logp(input_dict[SampleBatch.ACTIONS])
def from_importance_weights(log_rhos,
                            discounts,
                            rewards,
                            values,
                            bootstrap_value,
                            clip_rho_threshold=1.0,
                            clip_pg_rho_threshold=1.0):
    r"""V-trace from log importance weights.

    Calculates V-trace actor critic targets as described in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    In the notation used throughout documentation and comments, T refers
    to the time dimension ranging from 0 to T-1. B refers to the batch size.
    This code also supports the case where all tensors have the same number
    of additional dimensions, e.g., `rewards` is [T, B, C], `values` is
    [T, B, C], `bootstrap_value` is [B, C].

    All inputs are converted to torch tensors on the CPU before use.

    Args:
        log_rhos: A float32 tensor of shape [T, B] representing the log
            importance sampling weights, i.e.
            log(target_policy(a) / behaviour_policy(a)). V-trace performs
            operations on rhos in log-space for numerical stability.
        discounts: A float32 tensor of shape [T, B] with discounts
            encountered when following the behaviour policy.
        rewards: A float32 tensor of shape [T, B] containing rewards
            generated by following the behaviour policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value
            function estimate at time T.
        clip_rho_threshold: A scalar float32 tensor with the clipping
            threshold for importance weights (rho) when calculating the
            baseline targets (vs). rho^bar in the paper. If None, no
            clipping is applied.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in
            \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).
            If None, no clipping is applied.

    Returns:
        A VTraceReturns namedtuple (vs, pg_advantages) where:
            vs: A float32 tensor of shape [T, B]. Can be used as target to
                train a baseline (V(x_t) - vs_t)^2.
            pg_advantages: A float32 tensor of shape [T, B]. Can be used as
                the advantage in the calculation of policy gradients.
        Both tensors are detached from the autograd graph.
    """
    log_rhos = convert_to_torch_tensor(log_rhos, device="cpu")
    discounts = convert_to_torch_tensor(discounts, device="cpu")
    rewards = convert_to_torch_tensor(rewards, device="cpu")
    values = convert_to_torch_tensor(values, device="cpu")
    bootstrap_value = convert_to_torch_tensor(bootstrap_value, device="cpu")

    # Make sure tensor ranks are consistent.
    rho_rank = len(log_rhos.size())  # Usually 2.
    assert rho_rank == len(values.size())
    assert rho_rank - 1 == len(bootstrap_value.size()),\
        "must have rank {}".format(rho_rank - 1)
    assert rho_rank == len(discounts.size())
    assert rho_rank == len(rewards.size())

    rhos = torch.exp(log_rhos)
    if clip_rho_threshold is not None:
        clipped_rhos = torch.clamp_max(rhos, clip_rho_threshold)
    else:
        clipped_rhos = rhos

    # Truncated trace coefficients c_s = min(1, rho_s).
    cs = torch.clamp_max(rhos, 1.0)
    # Append bootstrapped value to get [v1, ..., v_t+1]
    values_t_plus_1 = torch.cat(
        [values[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

    # Backward recursion: (vs - V)_t = delta_t + gamma_t * c_t * (vs - V)_{t+1},
    # seeded with zeros for the step after the last one.
    vs_minus_v_xs = [torch.zeros_like(bootstrap_value)]
    for i in reversed(range(len(discounts))):
        discount_t, c_t, delta_t = discounts[i], cs[i], deltas[i]
        vs_minus_v_xs.append(delta_t + discount_t * c_t * vs_minus_v_xs[-1])
    # Drop the zero seed.
    vs_minus_v_xs = torch.stack(vs_minus_v_xs[1:])
    # Reverse the results back to original order.
    vs_minus_v_xs = torch.flip(vs_minus_v_xs, dims=[0])
    # Add V(x_s) to get v_s.
    vs = vs_minus_v_xs + values

    # Advantage for policy gradient.
    vs_t_plus_1 = torch.cat(
        [vs[1:], torch.unsqueeze(bootstrap_value, 0)], dim=0)
    if clip_pg_rho_threshold is not None:
        clipped_pg_rhos = torch.clamp_max(rhos, clip_pg_rho_threshold)
    else:
        clipped_pg_rhos = rhos
    pg_advantages = (
        clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))

    # Make sure no gradients backpropagated through the returned values.
    return VTraceReturns(vs=vs.detach(), pg_advantages=pg_advantages.detach())
def set_weights(self, weights: ModelWeights) -> None:
    """Load ``weights`` into ``self.model`` after moving them to ``self.device``.

    Args:
        weights: State dict for ``self.model``.
    """
    state_dict = convert_to_torch_tensor(weights, device=self.device)
    self.model.load_state_dict(state_dict)
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Compute the CQL (SAC-based) losses for one training batch.

    Increments ``policy.cur_iter``. The actor is trained with a behaviour
    cloning objective until the incremented counter reaches
    ``config["bc_iters"]``, and with the SAC actor objective afterwards. Each
    critic loss is the SAC TD (MSE) loss plus the CQL(H) regulariser.
    Intermediate values are stored on ``policy`` for the stats function.

    Args:
        policy: The CQL policy; reads ``policy.config`` and
            ``policy.target_model``.
        model: The policy's model (actor, Q-nets, ``log_alpha`` and, with the
            Lagrangian variant, ``log_alpha_prime``).
        dist_class: Unused here; the action distribution class is looked up
            via ``_get_dist_class``.
        train_batch: Batch with obs, actions, rewards, next obs and dones.

    Returns:
        Tuple of the actor loss, one critic loss per Q-network, the alpha
        loss and, if ``config["lagrangian"]``, the alpha-prime loss.
    """
    policy.cur_iter += 1
    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t = action_dist_t.sample() if not deterministic else \
        action_dist_t.deterministic_sample()
    log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:
        bc_logp = action_dist_t.logp(actions)
        # Detach alpha so the actor loss does not push gradients into
        # log_alpha (that is alpha_loss's job), as in the SAC branch above.
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1 = action_dist_tp1.sample() if not deterministic else \
        action_dist_tp1.deterministic_sample()

    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1 = torch.squeeze(input=q_tp1, dim=-1)
    # Zero the bootstrap term for terminal transitions.
    q_tp1 = (1.0 - terminals.float()) * q_tp1

    # compute RHS of bellman equation
    q_t_target = (
        rewards + (discount**policy.config["n_step"]) * q_tp1).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [nn.MSELoss()(q_t, q_t_target)]
    if twin_q:
        critic_loss.append(nn.MSELoss()(twin_q_t, q_t_target))

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    obs, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    next_obs, num_actions)
    curr_logp = curr_logp.view(actions.shape[0], num_actions, 1)
    next_logp = next_logp.view(actions.shape[0], num_actions, 1)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # Log-density of a uniform distribution on [-1, 1]^d.
    # NOTE(review): rand_actions are drawn from [action_low, action_high],
    # so this assumes those bounds are -1/+1 - confirm.
    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss - q_t.mean() * min_q_weight
    if twin_q:
        min_qf2_loss = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss - twin_q_t.mean() * min_q_weight

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    # One CQL term per Q-network: Q1 first, then Q2 (twin_q only).
    cql_losses = [min_qf1_loss]
    if twin_q:
        cql_losses.append(min_qf2_loss)

    # Build new tensors rather than adding in place to the MSE outputs.
    critic_loss[0] = critic_loss[0] + min_qf1_loss
    if twin_q:
        critic_loss[1] = critic_loss[1] + min_qf2_loss

    # Save for stats function.
    policy.q_t = q_t
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_losses
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
def multi_from_logits(behaviour_policy_logits,
                      target_policy_logits,
                      actions,
                      discounts,
                      rewards,
                      values,
                      bootstrap_value,
                      dist_class,
                      model,
                      behaviour_action_log_probs=None,
                      clip_rho_threshold=1.0,
                      clip_pg_rho_threshold=1.0):
    r"""V-trace for softmax policies.

    Calculates V-trace actor critic targets for softmax policies as described
    in

    "IMPALA: Scalable Distributed Deep-RL with
    Importance Weighted Actor-Learner Architectures"
    by Espeholt, Soyer, Munos et al.

    Target policy refers to the policy we are interested in improving and
    behaviour policy refers to the policy that generated the given
    rewards and actions.

    In the notation used throughout documentation and comments, T refers
    to the time dimension ranging from 0 to T-1. B refers to the batch size
    and ACTION_SPACE refers to the list of numbers each representing a number
    of actions.

    Logits and actions are converted to torch tensors on the CPU.

    Args:
        behaviour_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax behavior policy.
        target_policy_logits: A list with length of ACTION_SPACE of float32
            tensors of shapes [T, B, ACTION_SPACE[0]], ...,
            [T, B, ACTION_SPACE[-1]] with un-normalized log-probabilities
            parameterizing the softmax target policy.
        actions: A list with length of ACTION_SPACE of tensors of shapes
            [T, B, ...], ..., [T, B, ...] with actions sampled from the
            behavior policy.
        discounts: A float32 tensor of shape [T, B] with the discount
            encountered when following the behavior policy.
        rewards: A float32 tensor of shape [T, B] with the rewards generated
            by following the behavior policy.
        values: A float32 tensor of shape [T, B] with the value function
            estimates wrt. the target policy.
        bootstrap_value: A float32 tensor of shape [B] with the value
            function estimate at time T.
        dist_class: Action distribution class for the logits.
        model: Backing ModelV2 instance.
        behaviour_action_log_probs: Precalculated log-probabilities of the
            behavior actions. Only used when there is a single logits
            tensor; otherwise they are recomputed from the logits.
        clip_rho_threshold: A scalar float32 tensor with the clipping
            threshold for importance weights (rho) when calculating the
            baseline targets (vs). rho^bar in the paper.
        clip_pg_rho_threshold: A scalar float32 tensor with the clipping
            threshold on rho_s in:
            \rho_s \delta log \pi(a|x) (r + \gamma v_{s+1} - V(x_s)).

    Returns:
        A `VTraceFromLogitsReturns` namedtuple with the following fields:
            vs: A float32 tensor of shape [T, B]. Can be used as target to
                train a baseline (V(x_t) - vs_t)^2.
            pg_advantages: A float32 tensor of shape [T, B]. Can be used as
                an estimate of the advantage in the calculation of policy
                gradients.
            log_rhos: A float32 tensor of shape [T, B] containing the log
                importance sampling weights (log rhos).
            behaviour_action_log_probs: A float32 tensor of shape [T, B]
                containing behaviour policy action log probabilities
                (log \mu(a_t)).
            target_action_log_probs: A float32 tensor of shape [T, B]
                containing target policy action log probabilities
                (log \pi(a_t)).
    """
    behaviour_policy_logits = convert_to_torch_tensor(
        behaviour_policy_logits, device="cpu")
    target_policy_logits = convert_to_torch_tensor(
        target_policy_logits, device="cpu")
    actions = convert_to_torch_tensor(actions, device="cpu")

    for i, behaviour_logits in enumerate(behaviour_policy_logits):
        # Make sure tensor ranks are as expected.
        # The remaining ranks are checked by from_importance_weights.
        assert len(behaviour_logits.size()) == 3
        assert len(target_policy_logits[i].size()) == 3

    target_action_log_probs = multi_log_probs_from_logits_and_actions(
        target_policy_logits, actions, dist_class, model)

    if (len(behaviour_policy_logits) > 1
            or behaviour_action_log_probs is None):
        # Can't use precalculated values, recompute them. Note that
        # recomputing won't work well for autoregressive action dists
        # which may have variables not captured by 'logits'.
        behaviour_action_log_probs = (multi_log_probs_from_logits_and_actions(
            behaviour_policy_logits, actions, dist_class, model))

    behaviour_action_log_probs = force_list(behaviour_action_log_probs)
    log_rhos = get_log_rhos(target_action_log_probs,
                            behaviour_action_log_probs)

    vtrace_returns = from_importance_weights(
        log_rhos=log_rhos,
        discounts=discounts,
        rewards=rewards,
        values=values,
        bootstrap_value=bootstrap_value,
        clip_rho_threshold=clip_rho_threshold,
        clip_pg_rho_threshold=clip_pg_rho_threshold)

    return VTraceFromLogitsReturns(
        log_rhos=log_rhos,
        behaviour_action_log_probs=behaviour_action_log_probs,
        target_action_log_probs=target_action_log_probs,
        **vtrace_returns._asdict())
def value(**input_dict):
    """Return the model's value estimate for the given keyword inputs.

    Uses ``self`` from the enclosing scope. The forward-pass outputs are
    discarded. ``value_function()`` is then read, presumably relying on state
    from that most recent forward pass (confirm against the model class).
    """
    batch = convert_to_torch_tensor(input_dict, self.device)
    _model_out, _state_out = self.model.from_batch(batch, is_training=False)
    # Index 0 strips the batch dimension.
    return self.model.value_function()[0]
def set_weights(self, weights: dict):
    """Restore module parameters and optimizer state from ``weights``.

    Args:
        weights: Dict with a ``"module"`` state dict (moved to
            ``self.device``) and an ``"optimizers"`` state dict (loaded
            as-is).
    """
    module_state = convert_to_torch_tensor(
        weights["module"], device=self.device)
    self.module.load_state_dict(module_state)
    # The optimizer state is passed through without any device conversion.
    self.optimizers.load_state_dict(weights["optimizers"])
def compute_actions(self,
                    obs_batch,
                    state_batches=None,
                    prev_action_batch=None,
                    prev_reward_batch=None,
                    info_batch=None,
                    episodes=None,
                    explore=None,
                    timestep=None,
                    **kwargs):
    """Compute actions for a batch of observations.

    Runs under ``torch.no_grad()``. If ``self.action_sampler_fn`` is set,
    actions come straight from it. Otherwise an action distribution is built,
    from ``self.action_distribution_fn`` or the model plus ``self.dist_class``,
    and actions are drawn via the exploration object.

    Args:
        obs_batch: Batch of observations.
        state_batches: Optional RNN states; moved onto ``self.device``.
        prev_action_batch: Optional previous actions.
        prev_reward_batch: Optional previous rewards.
        info_batch: Unused here.
        episodes: Unused here.
        explore: Whether to explore; defaults to ``self.config["explore"]``.
        timestep: Current timestep; defaults to ``self.global_timestep``.

    Returns:
        ``(actions, state_out, extra_fetches)``, converted with
        ``convert_to_non_torch_type``.

    Raises:
        ValueError: If the resolved ``dist_class`` is neither a
            ``functools.partial`` nor a ``TorchDistributionWrapper`` subclass.
    """
    explore = explore if explore is not None else self.config["explore"]
    timestep = timestep if timestep is not None else self.global_timestep

    with torch.no_grad():
        seq_lens = torch.ones(len(obs_batch), dtype=torch.int32)
        input_dict = self._lazy_tensor_dict({
            SampleBatch.CUR_OBS: np.asarray(obs_batch),
            "is_training": False,
        })
        if prev_action_batch is not None:
            input_dict[SampleBatch.PREV_ACTIONS] = \
                np.asarray(prev_action_batch)
        if prev_reward_batch is not None:
            input_dict[SampleBatch.PREV_REWARDS] = \
                np.asarray(prev_reward_batch)
        # Place RNN states on the policy's device, matching
        # compute_log_likelihoods; previously no device was passed here.
        state_batches = [
            convert_to_torch_tensor(s, self.device)
            for s in (state_batches or [])
        ]

        if self.action_sampler_fn:
            action_dist = dist_inputs = None
            state_out = []
            actions, logp = self.action_sampler_fn(
                self,
                self.model,
                input_dict[SampleBatch.CUR_OBS],
                explore=explore,
                timestep=timestep)
        else:
            # Call the exploration before_compute_actions hook.
            self.exploration.before_compute_actions(
                explore=explore, timestep=timestep)
            if self.action_distribution_fn:
                dist_inputs, dist_class, state_out = \
                    self.action_distribution_fn(
                        self,
                        self.model,
                        input_dict[SampleBatch.CUR_OBS],
                        explore=explore,
                        timestep=timestep,
                        is_training=False)
            else:
                dist_class = self.dist_class
                dist_inputs, state_out = self.model(input_dict, state_batches,
                                                    seq_lens)

            # issubclass() raises TypeError for non-class objects, so guard it
            # to surface the intended, descriptive ValueError instead.
            is_partial = isinstance(dist_class, functools.partial)
            is_wrapper_subclass = isinstance(dist_class, type) and \
                issubclass(dist_class, TorchDistributionWrapper)
            if not (is_partial or is_wrapper_subclass):
                raise ValueError(
                    "`dist_class` ({}) not a TorchDistributionWrapper "
                    "subclass! Make sure your `action_distribution_fn` or "
                    "`make_model_and_action_dist` return a correct "
                    "distribution class.".format(
                        getattr(dist_class, "__name__", dist_class)))
            action_dist = dist_class(dist_inputs, self.model)

            # Get the exploration action from the forward results.
            actions, logp = \
                self.exploration.get_exploration_action(
                    action_distribution=action_dist,
                    timestep=timestep,
                    explore=explore)

        input_dict[SampleBatch.ACTIONS] = actions

        # Add default and custom fetches.
        extra_fetches = self.extra_action_out(input_dict, state_batches,
                                              self.model, action_dist)

        # Action-logp and action-prob.
        if logp is not None:
            logp = convert_to_non_torch_type(logp)
            extra_fetches[SampleBatch.ACTION_PROB] = np.exp(logp)
            extra_fetches[SampleBatch.ACTION_LOGP] = logp

        # Action-dist inputs.
        if dist_inputs is not None:
            extra_fetches[SampleBatch.ACTION_DIST_INPUTS] = dist_inputs

        return convert_to_non_torch_type((actions, state_out, extra_fetches))
def cql_loss(policy: Policy, model: ModelV2,
             dist_class: Type[TorchDistributionWrapper],
             train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Compute CQL (SAC-based) losses and step the optimizers in place.

    Increments ``policy.cur_iter``. When the batch has exactly
    ``config["train_batch_size"]`` rows, it steps the alpha, actor and critic
    optimizers directly inside this function.

    The actor uses a behaviour-cloning log-likelihood until the incremented
    counter reaches ``config["bc_iters"]``, then the SAC actor objective.
    Each critic loss is the SAC TD (MSE) loss plus the CQL(H) regulariser.
    Intermediate values are stored on ``policy`` for the stats function.

    Args:
        policy: The CQL policy; reads ``policy.config``,
            ``policy.target_model`` and ``policy._optimizers``.
        model: The policy's model.
        dist_class: Unused here; the distribution class is looked up via
            ``_get_dist_class``.
        train_batch: Batch with obs, actions, rewards, next obs and dones.

    Returns:
        Tuple of the actor loss, one critic loss per Q-network, the alpha
        loss and, if ``config["lagrangian"]``, the alpha-prime loss.
    """
    logger.info("Current iteration = %s", policy.cur_iter)
    policy.cur_iter += 1
    # For best performance, turn deterministic off
    deterministic = policy.config["_deterministic_loss"]
    assert not deterministic
    twin_q = policy.config["twin_q"]
    discount = policy.config["gamma"]
    action_low = model.action_space.low[0]
    action_high = model.action_space.high[0]

    # CQL Parameters
    bc_iters = policy.config["bc_iters"]
    cql_temp = policy.config["temperature"]
    num_actions = policy.config["num_actions"]
    min_q_weight = policy.config["min_q_weight"]
    use_lagrange = policy.config["lagrangian"]
    target_action_gap = policy.config["lagrangian_thresh"]

    obs = train_batch[SampleBatch.CUR_OBS]
    actions = train_batch[SampleBatch.ACTIONS]
    rewards = train_batch[SampleBatch.REWARDS]
    next_obs = train_batch[SampleBatch.NEXT_OBS]
    terminals = train_batch[SampleBatch.DONES]

    # Optimizers are indexed like the returned loss tuple:
    # [actor, critic_1, (critic_2 if twin_q), alpha, ...]. Hard-coding the
    # twin_q layout picked the wrong alpha optimizer without twin_q.
    num_critics = 2 if twin_q else 1
    policy_optimizer = policy._optimizers[0]
    critic_optimizers = policy._optimizers[1:1 + num_critics]
    alpha_optimizer = policy._optimizers[1 + num_critics]

    # Optimizer steps only happen for full-size train batches (presumably to
    # skip e.g. loss-initialization dummy batches - confirm).
    do_update = obs.shape[0] == policy.config["train_batch_size"]

    model_out_t, _ = model({
        "obs": obs,
        "is_training": True,
    }, [], None)

    model_out_tp1, _ = model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    target_model_out_tp1, _ = policy.target_model({
        "obs": next_obs,
        "is_training": True,
    }, [], None)

    action_dist_class = _get_dist_class(policy.config, policy.action_space)
    action_dist_t = action_dist_class(
        model.get_policy_output(model_out_t), policy.model)
    policy_t, log_pis_t = action_dist_t.sample_logp()
    log_pis_t = torch.unsqueeze(log_pis_t, -1)

    # Unlike original SAC, Alpha and Actor Loss are computed first.
    # Alpha Loss
    alpha_loss = -(model.log_alpha *
                   (log_pis_t + model.target_entropy).detach()).mean()
    if do_update:
        alpha_optimizer.zero_grad()
        alpha_loss.backward()
        alpha_optimizer.step()

    # Policy Loss (Either Behavior Clone Loss or SAC Loss)
    alpha = torch.exp(model.log_alpha)
    if policy.cur_iter >= bc_iters:
        min_q = model.get_q_values(model_out_t, policy_t)
        if twin_q:
            twin_q_ = model.get_twin_q_values(model_out_t, policy_t)
            min_q = torch.min(min_q, twin_q_)
        actor_loss = (alpha.detach() * log_pis_t - min_q).mean()
    else:

        def bc_log(model, obs, actions):
            # Log-likelihood of dataset actions under the tanh-squashed
            # Gaussian policy head (inverse tanh + change-of-variables term).
            z = atanh(actions)
            logits = model.get_policy_output(obs)
            mean, log_std = torch.chunk(logits, 2, dim=-1)
            # Mean Clamping for Stability
            mean = torch.clamp(mean, MEAN_MIN, MEAN_MAX)
            log_std = torch.clamp(log_std, MIN_LOG_NN_OUTPUT,
                                  MAX_LOG_NN_OUTPUT)
            std = torch.exp(log_std)
            normal_dist = torch.distributions.Normal(mean, std)
            return torch.sum(
                normal_dist.log_prob(z) -
                torch.log(1 - actions * actions + SMALL_NUMBER),
                dim=-1)

        bc_logp = bc_log(model, model_out_t, actions)
        actor_loss = (alpha.detach() * log_pis_t - bc_logp).mean()

    if do_update:
        policy_optimizer.zero_grad()
        actor_loss.backward(retain_graph=True)
        policy_optimizer.step()

    # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss)
    # SAC Loss:
    action_dist_tp1 = action_dist_class(
        model.get_policy_output(model_out_tp1), policy.model)
    policy_tp1, _ = action_dist_tp1.sample_logp()

    # Q-values for the batched actions.
    q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS])
    q_t_selected = torch.squeeze(q_t, dim=-1)
    if twin_q:
        twin_q_t = model.get_twin_q_values(model_out_t,
                                           train_batch[SampleBatch.ACTIONS])
        twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)

    # Target q network evaluation.
    q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1)
    if twin_q:
        twin_q_tp1 = policy.target_model.get_twin_q_values(
            target_model_out_tp1, policy_tp1)
        # Take min over both twin-NNs.
        q_tp1 = torch.min(q_tp1, twin_q_tp1)

    q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
    # Zero the bootstrap term for terminal transitions.
    q_tp1_best_masked = (1.0 - terminals.float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_target = (
        rewards +
        (discount**policy.config["n_step"]) * q_tp1_best_masked).detach()

    # Compute the TD-error (potentially clipped), for priority replay buffer
    base_td_error = torch.abs(q_t_selected - q_t_target)
    if twin_q:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss_1 = nn.functional.mse_loss(q_t_selected, q_t_target)
    if twin_q:
        critic_loss_2 = nn.functional.mse_loss(twin_q_t_selected, q_t_target)

    # CQL Loss (We are using Entropy version of CQL (the best version))
    rand_actions = convert_to_torch_tensor(
        torch.FloatTensor(actions.shape[0] * num_actions,
                          actions.shape[-1]).uniform_(action_low, action_high),
        policy.device)
    curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_t, num_actions)
    next_actions, next_logp = policy_actions_repeat(model, action_dist_class,
                                                    model_out_tp1, num_actions)

    q1_rand = q_values_repeat(model, model_out_t, rand_actions)
    q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions)
    q1_next_actions = q_values_repeat(model, model_out_t, next_actions)

    if twin_q:
        q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True)
        q2_curr_actions = q_values_repeat(
            model, model_out_t, curr_actions, twin=True)
        q2_next_actions = q_values_repeat(
            model, model_out_t, next_actions, twin=True)

    # Log-density of a uniform distribution on [-1, 1]^d.
    # NOTE(review): rand_actions are drawn from [action_low, action_high],
    # so this assumes those bounds are -1/+1 - confirm.
    random_density = np.log(0.5**curr_actions.shape[-1])
    cat_q1 = torch.cat([
        q1_rand - random_density, q1_next_actions - next_logp.detach(),
        q1_curr_actions - curr_logp.detach()
    ], 1)
    if twin_q:
        cat_q2 = torch.cat([
            q2_rand - random_density, q2_next_actions - next_logp.detach(),
            q2_curr_actions - curr_logp.detach()
        ], 1)

    min_qf1_loss_ = torch.logsumexp(
        cat_q1 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
    min_qf1_loss = min_qf1_loss_ - (q_t.mean() * min_q_weight)
    if twin_q:
        min_qf2_loss_ = torch.logsumexp(
            cat_q2 / cql_temp, dim=1).mean() * min_q_weight * cql_temp
        min_qf2_loss = min_qf2_loss_ - (twin_q_t.mean() * min_q_weight)

    if use_lagrange:
        alpha_prime = torch.clamp(
            model.log_alpha_prime.exp(), min=0.0, max=1000000.0)[0]
        min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap)
        if twin_q:
            min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap)
            alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss)
        else:
            alpha_prime_loss = -min_qf1_loss

    cql_losses = [min_qf1_loss]
    if twin_q:
        cql_losses.append(min_qf2_loss)

    critic_loss = [critic_loss_1 + min_qf1_loss]
    if twin_q:
        critic_loss.append(critic_loss_2 + min_qf2_loss)

    if do_update:
        # One optimizer per critic loss; keep the graph alive for every
        # backward pass except the last one.
        last_index = len(critic_loss) - 1
        for index, (optimizer, loss) in enumerate(
                zip(critic_optimizers, critic_loss)):
            optimizer.zero_grad()
            loss.backward(retain_graph=index < last_index)
            optimizer.step()

    # Save for stats function.
    policy.q_t = q_t_selected
    policy.policy_t = policy_t
    policy.log_pis_t = log_pis_t
    model.td_error = td_error
    # Also expose the TD-error on the policy, like the other stats here.
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy
    # CQL Stats
    policy.cql_loss = cql_losses
    if use_lagrange:
        policy.log_alpha_prime_value = model.log_alpha_prime[0]
        policy.alpha_prime_value = alpha_prime
        policy.alpha_prime_loss = alpha_prime_loss

    # Return all loss terms corresponding to our optimizers.
    if use_lagrange:
        return tuple([policy.actor_loss] + policy.critic_loss +
                     [policy.alpha_loss] + [policy.alpha_prime_loss])
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
def _copy_weights_of_one_model(self, model_in_use, model_evading):
    """Overwrite ``model_in_use``'s parameters with those of ``model_evading``.

    The source state dict is converted onto ``self.device`` before loading.
    """
    source_state = convert_to_torch_tensor(
        model_evading.state_dict(), device=self.device)
    model_in_use.load_state_dict(source_state)