def build_q_losses(policy: Policy, model, _,
                   train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for DQNTFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config
    # Q-network evaluation.
    q_t, q_logits_t, q_dist_t, _ = compute_q_values(
        policy,
        model, {"obs": train_batch[SampleBatch.CUR_OBS]},
        state_batches=None,
        explore=False)

    # Target Q-network evaluation.
    q_tp1, q_logits_tp1, q_dist_tp1, _ = compute_q_values(
        policy,
        policy.target_q_model, {"obs": train_batch[SampleBatch.NEXT_OBS]},
        state_batches=None,
        explore=False)
    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_q_model.variables()

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(
        tf.cast(train_batch[SampleBatch.ACTIONS], tf.int32),
        policy.action_space.n)
    q_t_selected = tf.reduce_sum(q_t * one_hot_selection, 1)
    q_logits_t_selected = tf.reduce_sum(
        q_logits_t * tf.expand_dims(one_hot_selection, -1), 1)

    # Compute estimate of best possible value starting from state at t + 1.
    if config["double_q"]:
        q_tp1_using_online_net, q_logits_tp1_using_online_net, \
            q_dist_tp1_using_online_net, _ = compute_q_values(
                policy,
                model, {"obs": train_batch[SampleBatch.NEXT_OBS]},
                state_batches=None,
                explore=False)
        q_tp1_best_using_online_net = tf.argmax(q_tp1_using_online_net, 1)
        q_tp1_best_one_hot_selection = tf.one_hot(
            q_tp1_best_using_online_net, policy.action_space.n)
        q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        q_dist_tp1_best = tf.reduce_sum(
            q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)
    else:
        q_tp1_best_one_hot_selection = tf.one_hot(
            tf.argmax(q_tp1, 1), policy.action_space.n)
        q_tp1_best = tf.reduce_sum(q_tp1 * q_tp1_best_one_hot_selection, 1)
        q_dist_tp1_best = tf.reduce_sum(
            q_dist_tp1 * tf.expand_dims(q_tp1_best_one_hot_selection, -1), 1)

    policy.q_loss = QLoss(
        q_t_selected, q_logits_t_selected, q_tp1_best, q_dist_tp1_best,
        train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS],
        tf.cast(train_batch[SampleBatch.DONES], tf.float32),
        config["gamma"], config["n_step"], config["num_atoms"],
        config["v_min"], config["v_max"])

    return policy.q_loss.loss
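# Illustrative sketch (not part of RLlib): the scalar Bellman target that the
# QLoss above effectively regresses `q_t_selected` against in the plain
# (num_atoms == 1) case, using double-Q action selection. All names below are
# local to this example; gamma, n_step, dones, etc. stand in for the config
# values and batch columns used in `build_q_losses`.
import numpy as np

def double_q_target(rewards, dones, q_tp1_online, q_tp1_target, gamma, n_step):
    """y = r + gamma^n * (1 - done) * Q_target(s', argmax_a Q_online(s', a))."""
    best_actions = np.argmax(q_tp1_online, axis=1)  # Action choice: online net.
    q_tp1_best = q_tp1_target[np.arange(len(best_actions)), best_actions]  # Value: target net.
    return rewards + (gamma ** n_step) * (1.0 - dones) * q_tp1_best

# Tiny numeric check with a batch of 2 transitions:
y = double_q_target(
    rewards=np.array([1.0, 0.0]),
    dones=np.array([0.0, 1.0]),
    q_tp1_online=np.array([[0.2, 0.9], [0.5, 0.1]]),
    q_tp1_target=np.array([[0.3, 0.7], [0.4, 0.2]]),
    gamma=0.99, n_step=1)
# -> [1.0 + 0.99 * 0.7, 0.0] = [1.693, 0.0]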
def __init__(self, observation_space, action_space, config): assert tf.executing_eagerly() self.framework = config.get("framework", "tfe") Policy.__init__(self, observation_space, action_space, config) # Log device and worker index. from ray.rllib.evaluation.rollout_worker import get_global_worker worker = get_global_worker() worker_idx = worker.worker_index if worker else 0 if get_gpu_devices(): logger.info( "TF-eager Policy (worker={}) running on GPU.".format( worker_idx if worker_idx > 0 else "local")) else: logger.info( "TF-eager Policy (worker={}) running on CPU.".format( worker_idx if worker_idx > 0 else "local")) self._is_training = False self._loss_initialized = False self._loss = loss_fn self.batch_divisibility_req = get_batch_divisibility_req(self) if \ callable(get_batch_divisibility_req) else \ (get_batch_divisibility_req or 1) self._max_seq_len = config["model"]["max_seq_len"] if get_default_config: config = dict(get_default_config(), **config) if validate_spaces: validate_spaces(self, observation_space, action_space, config) if before_init: before_init(self, observation_space, action_space, config) self.config = config self.dist_class = None if action_sampler_fn or action_distribution_fn: if not make_model: raise ValueError( "`make_model` is required if `action_sampler_fn` OR " "`action_distribution_fn` is given") else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) if make_model: self.model = make_model(self, observation_space, action_space, config) else: self.model = ModelCatalog.get_model_v2( observation_space, action_space, logit_dim, config["model"], framework=self.framework, ) # Lock used for locking some methods on the object-level. # This prevents possible race conditions when calling the model # first, then its value function (e.g. in a loss function), in # between of which another model call is made (e.g. to compute an # action). self._lock = threading.RLock() # Auto-update model's inference view requirements, if recurrent. self._update_model_view_requirements_from_init_state() self.exploration = self._create_exploration() self._state_inputs = self.model.get_initial_state() self._is_recurrent = len(self._state_inputs) > 0 # Combine view_requirements for Model and Policy. self.view_requirements.update(self.model.view_requirements) if before_loss_init: before_loss_init(self, observation_space, action_space, config) if optimizer_fn: optimizers = optimizer_fn(self, config) else: optimizers = tf.keras.optimizers.Adam(config["lr"]) optimizers = force_list(optimizers) if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) # TODO: (sven) Allow tf policy to have more than 1 optimizer. # Just like torch Policy does. self._optimizer: LocalOptimizer = \ optimizers[0] if optimizers else None self._initialize_loss_from_dummy_batch( auto_remove_unneeded_view_reqs=True, stats_fn=stats_fn, ) self._loss_initialized = True if after_init: after_init(self, observation_space, action_space, config) # Got to reset global_timestep again after fake run-throughs. self.global_timestep = 0
def sac_actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TFActionDistribution],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[ActionDistribution]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    # Get the base model output from the train batch.
    model_out_t, _ = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Get the base model output from the next observations in the train batch.
    model_out_tp1, _ = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Get the target model's base outputs from the next observations in the
    # train batch.
    target_model_out_tp1, _ = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "is_training": policy._get_is_training_placeholder(),
        }, [], None)

    # Discrete actions case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = tf.nn.log_softmax(
            model.get_policy_output(model_out_t), -1)
        policy_t = tf.math.exp(log_pis_t)
        log_pis_tp1 = tf.nn.log_softmax(
            model.get_policy_output(model_out_tp1), -1)
        policy_tp1 = tf.math.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t)
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1)
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t)
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1)
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)
        q_tp1 -= model.alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = tf.one_hot(
            train_batch[SampleBatch.ACTIONS], depth=q_t.shape.as_list()[-1])
        q_t_selected = tf.reduce_sum(q_t * one_hot, axis=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.reduce_sum(twin_q_t * one_hot, axis=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = tf.reduce_sum(tf.multiply(policy_tp1, q_tp1), axis=-1)
        q_tp1_best_masked = \
            (1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t), policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = tf.expand_dims(action_dist_t.logp(policy_t), -1)

        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1), policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = tf.expand_dims(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t,
                                 train_batch[SampleBatch.ACTIONS])
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, train_batch[SampleBatch.ACTIONS])

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, policy_t)
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, policy_t)
            q_t_det_policy = tf.reduce_min(
                (q_t_det_policy, twin_q_t_det_policy), axis=0)

        # Target Q-network evaluation.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 policy_tp1)
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, policy_tp1)
            # Take min over both twin-NNs.
            q_tp1 = tf.reduce_min((q_tp1, twin_q_tp1), axis=0)

        q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1)
        if policy.config["twin_q"]:
            twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1)
        q_tp1 -= model.alpha * log_pis_tp1

        q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1)
        q_tp1_best_masked = (
            1.0 - tf.cast(train_batch[SampleBatch.DONES],
                          tf.float32)) * q_tp1_best

    # Compute RHS of Bellman equation for the Q-loss (critic(s)).
    q_t_selected_target = tf.stop_gradient(
        tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) + tf.cast(
            policy.config["gamma"]**policy.config["n_step"] *
            q_tp1_best_masked, tf.float32))

    # Compute the TD-error (potentially clipped).
    base_td_error = tf.math.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = tf.math.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    # Calculate one or two critic losses (2 in the twin_q case).
    critic_loss = [
        0.5 * tf.keras.losses.MSE(
            y_true=q_t_selected_target, y_pred=q_t_selected)
    ]
    if policy.config["twin_q"]:
        critic_loss.append(0.5 * tf.keras.losses.MSE(
            y_true=q_t_selected_target, y_pred=twin_q_t_selected))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        alpha_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    tf.stop_gradient(policy_t), -model.log_alpha *
                    tf.stop_gradient(log_pis_t + policy.target_entropy)),
                axis=-1))
        actor_loss = tf.reduce_mean(
            tf.reduce_sum(
                tf.multiply(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    model.alpha * log_pis_t - tf.stop_gradient(q_t)),
                axis=-1))
    else:
        alpha_loss = -tf.reduce_mean(
            model.log_alpha *
            tf.stop_gradient(log_pis_t + policy.target_entropy))
        actor_loss = tf.reduce_mean(model.alpha * log_pis_t - q_t_det_policy)

    # Save for stats function.
    policy.policy_t = policy_t
    policy.q_t = q_t
    policy.td_error = td_error
    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.alpha_value = model.alpha
    # policy.target_entropy = policy.target_entropy

    # In a custom apply op we handle the losses separately, but return them
    # combined in one loss here.
    return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
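# Illustrative sketch (not part of RLlib): the soft Bellman target that
# `q_t_selected_target` above implements for the continuous case, written out
# with NumPy. Names are local to this example; alpha, gamma, n_step and done
# stand in for model.alpha and the corresponding config/batch values.
import numpy as np

def sac_q_target(reward, done, q1_tp1, q2_tp1, logp_tp1, alpha, gamma, n_step):
    """y = r + gamma^n * (1 - done) * (min(Q1', Q2') - alpha * log pi(a'|s'))."""
    soft_q_tp1 = np.minimum(q1_tp1, q2_tp1) - alpha * logp_tp1
    return reward + (gamma ** n_step) * (1.0 - done) * soft_q_tp1

# E.g. reward=1, not done, twin target Qs 2.0/1.8, logp=-1.2, alpha=0.2:
# y = 1 + 0.99 * (1.8 - 0.2 * (-1.2)) = 1 + 0.99 * 2.04 = 3.0196
print(sac_q_target(1.0, 0.0, 2.0, 1.8, -1.2, alpha=0.2, gamma=0.99, n_step=1))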
def build_q_losses(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType: """Constructs the loss for DQNTorchPolicy. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. train_batch (SampleBatch): The training data. Returns: TensorType: A single loss tensor. """ config = policy.config # Q-network evaluation. q_t, q_logits_t, q_probs_t, _ = compute_q_values( policy, model, {"obs": train_batch[SampleBatch.CUR_OBS]}, explore=False, is_training=True) # Target Q-network evaluation. q_tp1, q_logits_tp1, q_probs_tp1, _ = compute_q_values( policy, policy.target_models[model], {"obs": train_batch[SampleBatch.NEXT_OBS]}, explore=False, is_training=True) # Q scores for actions which we know were selected in the given state. one_hot_selection = F.one_hot(train_batch[SampleBatch.ACTIONS].long(), policy.action_space.n) q_t_selected = torch.sum( torch.where(q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=q_t.device)) * one_hot_selection, 1) q_logits_t_selected = torch.sum( q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1) # compute estimate of best possible value starting from state at t + 1 if config["double_q"]: q_tp1_using_online_net, q_logits_tp1_using_online_net, \ q_dist_tp1_using_online_net, _ = compute_q_values( policy, model, {"obs": train_batch[SampleBatch.NEXT_OBS]}, explore=False, is_training=True) q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1) q_tp1_best_one_hot_selection = F.one_hot(q_tp1_best_using_online_net, policy.action_space.n) q_tp1_best = torch.sum( torch.where(q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)) * q_tp1_best_one_hot_selection, 1) q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1) else: q_tp1_best_one_hot_selection = F.one_hot(torch.argmax(q_tp1, 1), policy.action_space.n) q_tp1_best = torch.sum( torch.where(q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=q_tp1.device)) * q_tp1_best_one_hot_selection, 1) q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1) policy.q_loss = QLoss(q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best, train_batch[PRIO_WEIGHTS], train_batch[SampleBatch.REWARDS], train_batch[SampleBatch.DONES].float(), config["gamma"], config["n_step"], config["num_atoms"], config["v_min"], config["v_max"]) # Store td-error in model, such that for multi-GPU, we do not override # them during the parallel loss phase. TD-error tensor in final stats # can then be concatenated and retrieved for each individual batch item. model.td_error = policy.q_loss.td_error return policy.q_loss.loss
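# Illustrative sketch (not part of RLlib): why the torch.where(q > FLOAT_MIN, ...)
# guard in the torch loss above is useful before multiplying with the one-hot
# selection. If invalid actions carry a -inf Q-value (e.g. from action
# masking), the naive product one_hot * q turns 0 * -inf into nan; replacing
# those entries with 0.0 first keeps the selected Q-value finite. The -3.4e38
# threshold below is an assumption standing in for FLOAT_MIN. NumPy demo:
import numpy as np

q = np.array([1.5, -np.inf, 0.3], dtype=np.float32)     # Action 1 is masked out.
one_hot = np.array([0.0, 0.0, 1.0], dtype=np.float32)   # Action 2 was taken.

naive = np.sum(q * one_hot)                              # nan, since 0 * -inf = nan.
guarded = np.sum(np.where(q > -3.4e38, q, 0.0) * one_hot)  # 0.3, as intended.
print(naive, guarded)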
def ppo_surrogate_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ logits, state = model.from_batch(train_batch) curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: # Derive max_seq_len from the data itself, not from the seq_lens # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still # 0-padded up to T=5 (as it's the case for attention nets). B = tf.shape(train_batch["seq_lens"])[0] max_seq_len = tf.shape(logits)[0] // B mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) mask = tf.reshape(mask, [-1]) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) # non-RNN case: No masking. else: mask = None reduce_mean_valid = tf.reduce_mean prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = tf.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) action_kl = prev_action_dist.kl(curr_action_dist) mean_kl = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = tf.minimum( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_policy_loss = reduce_mean_valid(-surrogate_loss) if policy.config["use_gae"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] value_fn_out = model.value_function() vf_loss1 = tf.math.square(value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]) vf_clipped = prev_value_fn_out + tf.clip_by_value( value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], policy.config["vf_clip_param"]) vf_loss2 = tf.math.square(vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]) vf_loss = tf.maximum(vf_loss1, vf_loss2) mean_vf_loss = reduce_mean_valid(vf_loss) total_loss = reduce_mean_valid(-surrogate_loss + policy.kl_coeff * action_kl + policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) else: mean_vf_loss = tf.constant(0.0) total_loss = reduce_mean_valid(-surrogate_loss + policy.kl_coeff * action_kl - policy.entropy_coeff * curr_entropy) # Store stats in policy for stats_fn. policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._mean_kl = mean_kl return total_loss
def build_sac_model(policy: Policy, obs_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict) -> ModelV2: """Constructs the necessary ModelV2 for the Policy and returns it. Args: policy (Policy): The TFPolicy that will use the models. obs_space (gym.spaces.Space): The observation space. action_space (gym.spaces.Space): The action space. config (TrainerConfigDict): The SAC trainer's config dict. Returns: ModelV2: The ModelV2 to be used by the Policy. Note: An additional target model will be created in this function and assigned to `policy.target_model`. """ # With separate state-preprocessor (before obs+action concat). num_outputs = int(np.product(obs_space.shape)) # Force-ignore any additionally provided hidden layer sizes. # Everything should be configured using SAC's "Q_model" and "policy_model" # settings. policy_model_config = MODEL_DEFAULTS.copy() policy_model_config.update(config["policy_model"]) q_model_config = MODEL_DEFAULTS.copy() q_model_config.update(config["Q_model"]) default_model_cls = SACTorchModel if config["framework"] == "torch" \ else SACTFModel model = ModelCatalog.get_model_v2(obs_space=obs_space, action_space=action_space, num_outputs=num_outputs, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, name="sac_model", policy_model_config=policy_model_config, q_model_config=q_model_config, twin_q=config["twin_q"], initial_alpha=config["initial_alpha"], target_entropy=config["target_entropy"]) assert isinstance(model, default_model_cls) # Create an exact copy of the model and store it in `policy.target_model`. # This will be used for tau-synched Q-target models that run behind the # actual Q-networks and are used for target q-value calculations in the # loss terms. policy.target_model = ModelCatalog.get_model_v2( obs_space=obs_space, action_space=action_space, num_outputs=num_outputs, model_config=config["model"], framework=config["framework"], default_model=default_model_cls, name="target_sac_model", policy_model_config=policy_model_config, q_model_config=q_model_config, twin_q=config["twin_q"], initial_alpha=config["initial_alpha"], target_entropy=config["target_entropy"]) assert isinstance(policy.target_model, default_model_cls) return model
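# Illustrative sketch (not part of RLlib): the "tau-synched" target update
# mentioned in the comment above. After each training step, the target
# model's weights are moved a small fraction tau toward the online model's
# weights (Polyak averaging). Names are local to this example.
import numpy as np

def polyak_update(online_params, target_params, tau):
    """target <- tau * online + (1 - tau) * target, element-wise per tensor."""
    return [tau * w + (1.0 - tau) * w_t
            for w, w_t in zip(online_params, target_params)]

online = [np.array([1.0, 2.0]), np.array([[0.5]])]
target = [np.array([0.0, 0.0]), np.array([[0.0]])]
target = polyak_update(online, target, tau=0.005)
# -> [array([0.005, 0.01]), array([[0.0025]])]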
def build_q_losses_wt_additional_logs( policy: Policy, model, _, train_batch: SampleBatch ) -> TensorType: """ Copy of build_q_losses with additional values saved into the policy Made only 2 changes, see in comments. """ config = policy.config # Q-network evaluation. q_t, q_logits_t, q_probs_t = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.CUR_OBS], explore=False, is_training=True, ) # Addition 1 out of 2 policy.last_q_t = q_t.clone() # Target Q-network evaluation. q_tp1, q_logits_tp1, q_probs_tp1 = compute_q_values( policy, policy.target_q_model, train_batch[SampleBatch.NEXT_OBS], explore=False, is_training=True, ) # Addition 2 out of 2 policy.last_target_q_t = q_tp1.clone() # Q scores for actions which we know were selected in the given state. one_hot_selection = F.one_hot( train_batch[SampleBatch.ACTIONS], policy.action_space.n ) q_t_selected = torch.sum( torch.where( q_t > FLOAT_MIN, q_t, torch.tensor(0.0, device=policy.device) ) * one_hot_selection, 1, ) q_logits_t_selected = torch.sum( q_logits_t * torch.unsqueeze(one_hot_selection, -1), 1 ) # compute estimate of best possible value starting from state at t + 1 if config["double_q"]: ( q_tp1_using_online_net, q_logits_tp1_using_online_net, q_dist_tp1_using_online_net, ) = compute_q_values( policy, policy.q_model, train_batch[SampleBatch.NEXT_OBS], explore=False, is_training=True, ) q_tp1_best_using_online_net = torch.argmax(q_tp1_using_online_net, 1) q_tp1_best_one_hot_selection = F.one_hot( q_tp1_best_using_online_net, policy.action_space.n ) q_tp1_best = torch.sum( torch.where( q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=policy.device), ) * q_tp1_best_one_hot_selection, 1, ) q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1 ) else: q_tp1_best_one_hot_selection = F.one_hot( torch.argmax(q_tp1, 1), policy.action_space.n ) q_tp1_best = torch.sum( torch.where( q_tp1 > FLOAT_MIN, q_tp1, torch.tensor(0.0, device=policy.device), ) * q_tp1_best_one_hot_selection, 1, ) q_probs_tp1_best = torch.sum( q_probs_tp1 * torch.unsqueeze(q_tp1_best_one_hot_selection, -1), 1 ) if PRIO_WEIGHTS not in train_batch.keys(): assert config["prioritized_replay"] is False prio_weights = torch.tensor( [1.0] * len(train_batch[SampleBatch.REWARDS]) ).to(policy.device) else: prio_weights = train_batch[PRIO_WEIGHTS] policy.q_loss = QLoss( q_t_selected, q_logits_t_selected, q_tp1_best, q_probs_tp1_best, prio_weights, train_batch[SampleBatch.REWARDS], train_batch[SampleBatch.DONES].float(), config["gamma"], config["n_step"], config["num_atoms"], config["v_min"], config["v_max"], ) return policy.q_loss.loss
def appo_surrogate_loss( policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for APPO. With IS modifications and V-trace for Advantage Estimation. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ model_out, _ = model.from_batch(train_batch) action_dist = dist_class(model_out, model) if isinstance(policy.action_space, gym.spaces.Discrete): is_multidiscrete = False output_hidden_shape = [policy.action_space.n] elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True output_hidden_shape = policy.action_space.nvec.astype(np.int32) else: is_multidiscrete = False output_hidden_shape = 1 # TODO: (sven) deprecate this when trajectory view API gets activated. def make_time_major(*args, **kw): return _make_time_major(policy, train_batch.get("seq_lens"), *args, **kw) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] rewards = train_batch[SampleBatch.REWARDS] behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] target_model_out, _ = policy.target_model.from_batch(train_batch) prev_action_dist = dist_class(behaviour_logits, policy.model) values = policy.model.value_function() values_time_major = make_time_major(values) policy.model_vars = policy.model.variables() policy.target_model_vars = policy.target_model.variables() if policy.is_recurrent(): max_seq_len = tf.reduce_max(train_batch["seq_lens"]) - 1 mask = tf.sequence_mask(train_batch["seq_lens"], max_seq_len) mask = tf.reshape(mask, [-1]) mask = make_time_major(mask, drop_last=policy.config["vtrace"]) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) else: reduce_mean_valid = tf.reduce_mean if policy.config["vtrace"]: logger.debug("Using V-Trace surrogate loss (vtrace=True)") # Prepare actions for loss. loss_actions = actions if is_multidiscrete else tf.expand_dims(actions, axis=1) old_policy_behaviour_logits = tf.stop_gradient(target_model_out) old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) # Prepare KL for Loss mean_kl = make_time_major(old_policy_action_dist.multi_kl(action_dist), drop_last=True) unpacked_behaviour_logits = tf.split(behaviour_logits, output_hidden_shape, axis=1) unpacked_old_policy_behaviour_logits = tf.split( old_policy_behaviour_logits, output_hidden_shape, axis=1) # Compute vtrace on the CPU for better perf. 
with tf.device("/cpu:0"): vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=make_time_major( unpacked_behaviour_logits, drop_last=True), target_policy_logits=make_time_major( unpacked_old_policy_behaviour_logits, drop_last=True), actions=tf.unstack(make_time_major(loss_actions, drop_last=True), axis=2), discounts=tf.cast(~make_time_major(dones, drop_last=True), tf.float32) * policy.config["gamma"], rewards=make_time_major(rewards, drop_last=True), values=values_time_major[:-1], # drop-last=True bootstrap_value=values_time_major[-1], dist_class=Categorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=tf.cast( policy.config["vtrace_clip_rho_threshold"], tf.float32), clip_pg_rho_threshold=tf.cast( policy.config["vtrace_clip_pg_rho_threshold"], tf.float32), ) actions_logp = make_time_major(action_dist.logp(actions), drop_last=True) prev_actions_logp = make_time_major(prev_action_dist.logp(actions), drop_last=True) old_policy_actions_logp = make_time_major( old_policy_action_dist.logp(actions), drop_last=True) is_ratio = tf.clip_by_value( tf.math.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0) logp_ratio = is_ratio * tf.exp(actions_logp - prev_actions_logp) policy._is_ratio = is_ratio advantages = vtrace_returns.pg_advantages surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) action_kl = tf.reduce_mean(mean_kl, axis=0) \ if is_multidiscrete else mean_kl mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. delta = values_time_major[:-1] - vtrace_returns.vs value_targets = vtrace_returns.vs mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. actions_entropy = make_time_major(action_dist.multi_entropy(), drop_last=True) mean_entropy = reduce_mean_valid(actions_entropy) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss mean_kl = make_time_major(prev_action_dist.multi_kl(action_dist)) logp_ratio = tf.math.exp( make_time_major(action_dist.logp(actions)) - make_time_major(prev_action_dist.logp(actions))) advantages = make_time_major(train_batch[Postprocessing.ADVANTAGES]) surrogate_loss = tf.minimum( advantages * logp_ratio, advantages * tf.clip_by_value(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) action_kl = tf.reduce_mean(mean_kl, axis=0) \ if is_multidiscrete else mean_kl mean_kl = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = make_time_major( train_batch[Postprocessing.VALUE_TARGETS]) delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(tf.math.square(delta)) # The entropy loss. mean_entropy = reduce_mean_valid( make_time_major(action_dist.multi_entropy())) # The summed weighted loss total_loss = mean_policy_loss + \ mean_vf_loss * policy.config["vf_loss_coeff"] - \ mean_entropy * policy.config["entropy_coeff"] # Optional additional KL Loss if policy.config["use_kl_loss"]: total_loss += policy.kl_coeff * mean_kl policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_kl = mean_kl policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy policy._value_targets = value_targets # Store stats in policy for stats_fn. return total_loss
def __init__(self, observation_space, action_space, config):
    Policy.__init__(self, observation_space, action_space, config)
    self.action_space_shape = action_space.shape
    self.n_products = config['number_of_products']
    self.n_sources = config['number_of_sources']
def build_slateq_stats(policy: Policy, batch) -> Dict[str, TensorType]: stats = { "q_values": torch.mean(torch.stack(policy.get_tower_stats("q_values"))), "q_clicked": torch.mean(torch.stack(policy.get_tower_stats("q_clicked"))), "scores": torch.mean(torch.stack(policy.get_tower_stats("scores"))), "score_no_click": torch.mean( torch.stack(policy.get_tower_stats("score_no_click")) ), "slate_q_values": torch.mean( torch.stack(policy.get_tower_stats("slate_q_values")) ), "replay_click_q": torch.mean( torch.stack(policy.get_tower_stats("replay_click_q")) ), "bellman_reward": torch.mean( torch.stack(policy.get_tower_stats("bellman_reward")) ), "next_q_values": torch.mean( torch.stack(policy.get_tower_stats("next_q_values")) ), "target": torch.mean(torch.stack(policy.get_tower_stats("target"))), "next_q_target_slate": torch.mean( torch.stack(policy.get_tower_stats("next_q_target_slate")) ), "next_q_target_max": torch.mean( torch.stack(policy.get_tower_stats("next_q_target_max")) ), "target_clicked": torch.mean( torch.stack(policy.get_tower_stats("target_clicked")) ), "q_loss": torch.mean(torch.stack(policy.get_tower_stats("q_loss"))), "mean_actions": torch.mean(torch.stack(policy.get_tower_stats("mean_actions"))), "choice_loss": torch.mean(torch.stack(policy.get_tower_stats("choice_loss"))), # "choice_beta": torch.mean(torch.stack(policy.get_tower_stats("choice_beta"))), # "choice_score_no_click": torch.mean( # torch.stack(policy.get_tower_stats("choice_score_no_click")) # ), } # model_stats = { # k: torch.mean(var) # for k, var in policy.model.trainable_variables(as_dict=True).items() # } # stats.update(model_stats) return stats
def __init__( self, observation_space: gym.spaces.Space, action_space: gym.spaces.Space, config: TrainerConfigDict, **kwargs, ): self.framework = config.get("framework", "tf2") # Log device. logger.info("Creating TF-eager policy running on {}.".format( "GPU" if get_gpu_devices() else "CPU")) Policy.__init__(self, observation_space, action_space, config) config = dict(self.get_default_config(), **config) self.config = config self._is_training = False # Global timestep should be a tensor. self.global_timestep = tf.Variable(0, trainable=False, dtype=tf.int64) self.explore = tf.Variable(self.config["explore"], trainable=False, dtype=tf.bool) self._loss_initialized = False # Backward compatibility workaround so Policy will call self.loss() directly. # TODO(jungong): clean up after all policies are migrated to new sub-class # implementation. self._loss = None self.batch_divisibility_req = self.get_batch_divisibility_req() self._max_seq_len = config["model"]["max_seq_len"] self.validate_spaces(observation_space, action_space, config) # If using default make_model(), dist_class will get updated when # the model is created next. self.dist_class = self._init_dist_class() self.model = self.make_model() self._init_view_requirements() self.exploration = self._create_exploration() self._state_inputs = self.model.get_initial_state() self._is_recurrent = len(self._state_inputs) > 0 # Got to reset global_timestep again after fake run-throughs. self.global_timestep.assign(0) # Lock used for locking some methods on the object-level. # This prevents possible race conditions when calling the model # first, then its value function (e.g. in a loss function), in # between of which another model call is made (e.g. to compute an # action). self._lock = threading.RLock() # Only for `config.eager_tracing=True`: A counter to keep track of # how many times an eager-traced method (e.g. # `self._compute_actions_helper`) has been re-traced by tensorflow. # We will raise an error if more than n re-tracings have been # detected, since this would considerably slow down execution. # The variable below should only get incremented during the # tf.function trace operations, never when calling the already # traced function after that. self._re_trace_counter = 0
def actor_critic_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for the Soft Actor Critic. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[TorchDistributionWrapper]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ target_model = policy.target_models[model] # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] i = 0 state_batches = [] while "state_in_{}".format(i) in train_batch: state_batches.append(train_batch["state_in_{}".format(i)]) i += 1 assert state_batches seq_lens = train_batch.get(SampleBatch.SEQ_LENS) model_out_t, state_in_t = model({ "obs": train_batch[SampleBatch.CUR_OBS], "prev_actions": train_batch[SampleBatch.PREV_ACTIONS], "prev_rewards": train_batch[SampleBatch.PREV_REWARDS], "is_training": True, }, state_batches, seq_lens) states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"]) model_out_tp1, state_in_tp1 = model({ "obs": train_batch[SampleBatch.NEXT_OBS], "prev_actions": train_batch[SampleBatch.ACTIONS], "prev_rewards": train_batch[SampleBatch.REWARDS], "is_training": True, }, state_batches, seq_lens) states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"]) target_model_out_tp1, target_state_in_tp1 = target_model({ "obs": train_batch[SampleBatch.NEXT_OBS], "prev_actions": train_batch[SampleBatch.ACTIONS], "prev_rewards": train_batch[SampleBatch.REWARDS], "is_training": True, }, state_batches, seq_lens) target_states_in_tp1 = target_model.select_state(state_in_tp1, ["policy", "q", "twin_q"]) alpha = torch.exp(model.log_alpha) # Discrete case. if model.discrete: # Get all action probs directly from pi and form their logp. log_pis_t = F.log_softmax( model.get_policy_output(model_out_t, states_in_t["policy"], seq_lens)[0], dim=-1) policy_t = torch.exp(log_pis_t) log_pis_tp1 = F.log_softmax( model.get_policy_output(model_out_tp1, states_in_tp1["policy"], seq_lens)[0], -1) policy_tp1 = torch.exp(log_pis_tp1) # Q-values. q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0] # Target Q-values. q_tp1 = target_model.get_q_values( target_model_out_tp1, target_states_in_tp1["q"], seq_lens)[0] if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values( model_out_t, states_in_t["twin_q"], seq_lens)[0] twin_q_tp1 = target_model.get_twin_q_values( target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens)[0] q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1 -= alpha * log_pis_tp1 # Actually selected Q-values (from the actions batch). one_hot = F.one_hot( train_batch[SampleBatch.ACTIONS].long(), num_classes=q_t.size()[-1]) q_t_selected = torch.sum(q_t * one_hot, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1) # Discrete case: "Best" means weighted by the policy (prob) outputs. q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1) q_tp1_best_masked = \ (1.0 - train_batch[SampleBatch.DONES].float()) * \ q_tp1_best # Continuous actions case. else: # Sample single actions from distribution. 
action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_t = action_dist_class( model.get_policy_output(model_out_t, states_in_t["policy"], seq_lens)[0], model) policy_t = action_dist_t.sample() if not deterministic else \ action_dist_t.deterministic_sample() log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1) action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1, states_in_tp1["policy"], seq_lens)[0], model) policy_tp1 = action_dist_tp1.sample() if not deterministic else \ action_dist_tp1.deterministic_sample() log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1) # Q-values for the actually selected actions. q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens, train_batch[SampleBatch.ACTIONS])[0] if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values( model_out_t, states_in_t["twin_q"], seq_lens, train_batch[SampleBatch.ACTIONS])[0] # Q-values for current policy in given current state. q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"], seq_lens, policy_t)[0] if policy.config["twin_q"]: twin_q_t_det_policy = model.get_twin_q_values( model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0] q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy) # Target q network evaluation. q_tp1 = target_model.get_q_values(target_model_out_tp1, target_states_in_tp1["q"], seq_lens, policy_tp1)[0] if policy.config["twin_q"]: twin_q_tp1 = target_model.get_twin_q_values( target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens, policy_tp1)[0] # Take min over both twin-NNs. q_tp1 = torch.min(q_tp1, twin_q_tp1) q_t_selected = torch.squeeze(q_t, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) q_tp1 -= alpha * log_pis_tp1 q_tp1_best = torch.squeeze(input=q_tp1, dim=-1) q_tp1_best_masked = \ (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best # compute RHS of bellman equation q_t_selected_target = ( train_batch[SampleBatch.REWARDS] + (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked ).detach() # BURNIN # B = state_batches[0].shape[0] T = q_t_selected.shape[0] // B seq_mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], T) # Mask away also the burn-in sequence at the beginning. burn_in = policy.config["burn_in"] if burn_in > 0 and burn_in < T: seq_mask[:, :burn_in] = False seq_mask = seq_mask.reshape(-1) num_valid = torch.sum(seq_mask) def reduce_mean_valid(t): return torch.sum(t[seq_mask]) / num_valid # Compute the TD-error (potentially clipped). base_td_error = torch.abs(q_t_selected - q_t_selected_target) if policy.config["twin_q"]: twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error critic_loss = [ reduce_mean_valid( train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error)) ] if policy.config["twin_q"]: critic_loss.append( reduce_mean_valid( train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. # Discrete case: Multiply the action probs as weights with the original # loss terms (no expectations needed). if model.discrete: weighted_log_alpha_loss = policy_t.detach() * ( -model.log_alpha * (log_pis_t + model.target_entropy).detach()) # Sum up weighted terms and mean over all batch items. alpha_loss = reduce_mean_valid( torch.sum(weighted_log_alpha_loss, dim=-1)) # Actor loss. 
actor_loss = reduce_mean_valid( torch.sum( torch.mul( # NOTE: No stop_grad around policy output here # (compare with q_t_det_policy for continuous case). policy_t, alpha.detach() * log_pis_t - q_t.detach()), dim=-1)) else: alpha_loss = -reduce_mean_valid( model.log_alpha * (log_pis_t + model.target_entropy).detach()) # Note: Do not detach q_t_det_policy here b/c is depends partly # on the policy vars (policy sample pushed through Q-net). # However, we must make sure `actor_loss` is not used to update # the Q-net(s)' variables. actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t - q_t_det_policy) # Save for stats function. policy.q_t = q_t * seq_mask[..., None] policy.policy_t = policy_t * seq_mask[..., None] policy.log_pis_t = log_pis_t * seq_mask[..., None] # Store td-error in model, such that for multi-GPU, we do not override # them during the parallel loss phase. TD-error tensor in final stats # can then be concatenated and retrieved for each individual batch item. model.td_error = td_error * seq_mask policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.alpha_loss = alpha_loss policy.log_alpha_value = model.log_alpha policy.alpha_value = alpha policy.target_entropy = model.target_entropy # Return all loss terms corresponding to our optimizers. return tuple([policy.actor_loss] + policy.critic_loss + [policy.alpha_loss])
def ppo_surrogate_loss( policy: Policy, model: Union[ModelV2, "tf.keras.Model"], dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. model (Union[ModelV2, tf.keras.Model]): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ if isinstance(model, tf.keras.Model): logits, state, extra_outs = model(train_batch) value_fn_out = extra_outs[SampleBatch.VF_PREDS] else: logits, state = model(train_batch) value_fn_out = model.value_function() curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: # Derive max_seq_len from the data itself, not from the seq_lens # tensor. This is in case e.g. seq_lens=[2, 3], but the data is still # 0-padded up to T=5 (as it's the case for attention nets). B = tf.shape(train_batch[SampleBatch.SEQ_LENS])[0] max_seq_len = tf.shape(logits)[0] // B mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = tf.reshape(mask, [-1]) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, mask)) # non-RNN case: No masking. else: mask = None reduce_mean_valid = tf.reduce_mean prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = tf.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) # Only calculate kl loss if necessary (kl-coeff > 0.0). if policy.config["kl_coeff"] > 0.0: action_kl = prev_action_dist.kl(curr_action_dist) mean_kl_loss = reduce_mean_valid(action_kl) else: mean_kl_loss = 0.0 curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = tf.minimum( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * tf.clip_by_value( logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_policy_loss = reduce_mean_valid(-surrogate_loss) # Compute a value function loss. if policy.config["use_critic"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] vf_loss1 = tf.math.square(value_fn_out - train_batch[Postprocessing.VALUE_TARGETS]) vf_clipped = prev_value_fn_out + tf.clip_by_value( value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], policy.config["vf_clip_param"]) vf_loss2 = tf.math.square(vf_clipped - train_batch[Postprocessing.VALUE_TARGETS]) vf_loss = tf.maximum(vf_loss1, vf_loss2) mean_vf_loss = reduce_mean_valid(vf_loss) # Ignore the value function. else: vf_loss = mean_vf_loss = tf.constant(0.0) total_loss = reduce_mean_valid(-surrogate_loss + policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) # Add mean_kl_loss (already processed through `reduce_mean_valid`), # if necessary. if policy.config["kl_coeff"] > 0.0: total_loss += policy.kl_coeff * mean_kl_loss # Store stats in policy for stats_fn. policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_vf_loss = mean_vf_loss policy._mean_entropy = mean_entropy # Backward compatibility: Deprecate policy._mean_kl. policy._mean_kl_loss = policy._mean_kl = mean_kl_loss policy._value_fn_out = value_fn_out return total_loss
def ppo_surrogate_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for Proximal Policy Objective. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ logits, state = model(train_batch) curr_action_dist = dist_class(logits, model) # RNN case: Mask away 0-padded chunks at end of time axis. if state: B = len(train_batch[SampleBatch.SEQ_LENS]) max_seq_len = logits.shape[0] // B mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len, time_major=model.is_time_major()) mask = torch.reshape(mask, [-1]) num_valid = torch.sum(mask) def reduce_mean_valid(t): return torch.sum(t[mask]) / num_valid # non-RNN case: No masking. else: mask = None reduce_mean_valid = torch.mean prev_action_dist = dist_class(train_batch[SampleBatch.ACTION_DIST_INPUTS], model) logp_ratio = torch.exp( curr_action_dist.logp(train_batch[SampleBatch.ACTIONS]) - train_batch[SampleBatch.ACTION_LOGP]) action_kl = prev_action_dist.kl(curr_action_dist) mean_kl = reduce_mean_valid(action_kl) curr_entropy = curr_action_dist.entropy() mean_entropy = reduce_mean_valid(curr_entropy) surrogate_loss = torch.min( train_batch[Postprocessing.ADVANTAGES] * logp_ratio, train_batch[Postprocessing.ADVANTAGES] * torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_policy_loss = reduce_mean_valid(-surrogate_loss) # Compute a value function loss. if policy.config["use_critic"]: prev_value_fn_out = train_batch[SampleBatch.VF_PREDS] value_fn_out = model.value_function() vf_loss1 = torch.pow( value_fn_out - train_batch[Postprocessing.VALUE_TARGETS], 2.0) vf_clipped = prev_value_fn_out + torch.clamp( value_fn_out - prev_value_fn_out, -policy.config["vf_clip_param"], policy.config["vf_clip_param"]) vf_loss2 = torch.pow( vf_clipped - train_batch[Postprocessing.VALUE_TARGETS], 2.0) vf_loss = torch.max(vf_loss1, vf_loss2) mean_vf_loss = reduce_mean_valid(vf_loss) # Ignore the value function. else: vf_loss = mean_vf_loss = 0.0 total_loss = reduce_mean_valid(-surrogate_loss + policy.kl_coeff * action_kl + policy.config["vf_loss_coeff"] * vf_loss - policy.entropy_coeff * curr_entropy) # Store stats in policy for stats_fn. policy._total_loss = total_loss policy._mean_policy_loss = mean_policy_loss policy._mean_vf_loss = mean_vf_loss policy._vf_explained_var = explained_variance( train_batch[Postprocessing.VALUE_TARGETS], model.value_function()) policy._mean_entropy = mean_entropy policy._mean_kl = mean_kl return total_loss
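# Illustrative sketch (not part of RLlib): the clipped surrogate term computed
# in the PPO losses above, shown for a single sample. With a positive
# advantage the ratio is capped at 1 + clip_param; with a negative advantage
# it is floored at 1 - clip_param. Names are local to this example.
import numpy as np

def clipped_surrogate(advantage, logp_ratio, clip_param=0.3):
    return np.minimum(
        advantage * logp_ratio,
        advantage * np.clip(logp_ratio, 1 - clip_param, 1 + clip_param))

print(clipped_surrogate(advantage=2.0, logp_ratio=1.8))   # 2.0 * 1.3 = 2.6
print(clipped_surrogate(advantage=-1.0, logp_ratio=0.4))  # -1.0 * 0.7 = -0.7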
def __init__(self, observation_space, action_space, config):
    Policy.__init__(self, observation_space, action_space, config)
    self.blocks = config['blocks']
    self.fiftyone = config['fiftyone']
    self.extended = config['extended']
def before_init_fn(policy: Policy, obs_space: gym.spaces.Space,
                   action_space: gym.spaces.Space,
                   config: TrainerConfigDict) -> None:
    # Create global step for counting the number of update operations.
    policy.global_step = 0
def compute_and_clip_gradients(policy: Policy, optimizer: LocalOptimizer, loss: TensorType) -> ModelGradients: """Gradients computing function (from loss tensor, using local optimizer). Note: For SAC, optimizer and loss are ignored b/c we have 3 losses and 3 local optimizers (all stored in policy). `optimizer` will be used, though, in the tf-eager case b/c it is then a fake optimizer (OptimizerWrapper) object with a `tape` property to generate a GradientTape object for gradient recording. Args: policy (Policy): The Policy object that generated the loss tensor and that holds the given local optimizer. optimizer (LocalOptimizer): The tf (local) optimizer object to calculate the gradients with. loss (TensorType): The loss tensor for which gradients should be calculated. Returns: ModelGradients: List of the possibly clipped gradients- and variable tuples. """ # Eager: Use GradientTape (which is a property of the `optimizer` object # (an OptimizerWrapper): see rllib/policy/eager_tf_policy.py). if policy.config["framework"] in ["tf2", "tfe"]: tape = optimizer.tape pol_weights = policy.model.policy_variables() actor_grads_and_vars = list( zip(tape.gradient(policy.actor_loss, pol_weights), pol_weights)) q_weights = policy.model.q_variables() if policy.config["twin_q"]: half_cutoff = len(q_weights) // 2 grads_1 = tape.gradient(policy.critic_loss[0], q_weights[:half_cutoff]) grads_2 = tape.gradient(policy.critic_loss[1], q_weights[half_cutoff:]) critic_grads_and_vars = \ list(zip(grads_1, q_weights[:half_cutoff])) + \ list(zip(grads_2, q_weights[half_cutoff:])) else: critic_grads_and_vars = list( zip(tape.gradient(policy.critic_loss[0], q_weights), q_weights)) alpha_vars = [policy.model.log_alpha] alpha_grads_and_vars = list( zip(tape.gradient(policy.alpha_loss, alpha_vars), alpha_vars)) # Tf1.x: Use optimizer.compute_gradients() else: actor_grads_and_vars = policy._actor_optimizer.compute_gradients( policy.actor_loss, var_list=policy.model.policy_variables()) q_weights = policy.model.q_variables() if policy.config["twin_q"]: half_cutoff = len(q_weights) // 2 base_q_optimizer, twin_q_optimizer = policy._critic_optimizer critic_grads_and_vars = base_q_optimizer.compute_gradients( policy.critic_loss[0], var_list=q_weights[:half_cutoff] ) + twin_q_optimizer.compute_gradients( policy.critic_loss[1], var_list=q_weights[half_cutoff:]) else: critic_grads_and_vars = policy._critic_optimizer[ 0].compute_gradients(policy.critic_loss[0], var_list=q_weights) alpha_grads_and_vars = policy._alpha_optimizer.compute_gradients( policy.alpha_loss, var_list=[policy.model.log_alpha]) # Clip if necessary. if policy.config["grad_clip"]: clip_func = partial(tf.clip_by_norm, clip_norm=policy.config["grad_clip"]) else: clip_func = tf.identity # Save grads and vars for later use in `build_apply_op`. policy._actor_grads_and_vars = [(clip_func(g), v) for (g, v) in actor_grads_and_vars if g is not None] policy._critic_grads_and_vars = [(clip_func(g), v) for (g, v) in critic_grads_and_vars if g is not None] policy._alpha_grads_and_vars = [(clip_func(g), v) for (g, v) in alpha_grads_and_vars if g is not None] grads_and_vars = (policy._actor_grads_and_vars + policy._critic_grads_and_vars + policy._alpha_grads_and_vars) return grads_and_vars
def __init__(self, observation_space, action_space, config):
    Policy.__init__(self, observation_space, action_space, config)
    self.infiltrating = config['infiltrating']
def ddpg_actor_critic_loss(policy: Policy, model: ModelV2, _, train_batch: SampleBatch) -> TensorType: twin_q = policy.config["twin_q"] gamma = policy.config["gamma"] n_step = policy.config["n_step"] use_huber = policy.config["use_huber"] huber_threshold = policy.config["huber_threshold"] l2_reg = policy.config["l2_reg"] input_dict = SampleBatch(obs=train_batch[SampleBatch.CUR_OBS], _is_training=True) input_dict_next = SampleBatch(obs=train_batch[SampleBatch.NEXT_OBS], _is_training=True) model_out_t, _ = model(input_dict, [], None) model_out_tp1, _ = model(input_dict_next, [], None) target_model_out_tp1, _ = policy.target_model(input_dict_next, [], None) policy.target_q_func_vars = policy.target_model.variables() # Policy network evaluation. policy_t = model.get_policy_output(model_out_t) policy_tp1 = policy.target_model.get_policy_output(target_model_out_tp1) # Action outputs. if policy.config["smooth_target_policy"]: target_noise_clip = policy.config["target_noise_clip"] clipped_normal_sample = tf.clip_by_value( tf.random.normal(tf.shape(policy_tp1), stddev=policy.config["target_noise"]), -target_noise_clip, target_noise_clip, ) policy_tp1_smoothed = tf.clip_by_value( policy_tp1 + clipped_normal_sample, policy.action_space.low * tf.ones_like(policy_tp1), policy.action_space.high * tf.ones_like(policy_tp1), ) else: # No smoothing, just use deterministic actions. policy_tp1_smoothed = policy_tp1 # Q-net(s) evaluation. # prev_update_ops = set(tf.get_collection(tf.GraphKeys.UPDATE_OPS)) # Q-values for given actions & observations in given current q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy (no noise) in given current state q_t_det_policy = model.get_q_values(model_out_t, policy_t) if twin_q: twin_q_t = model.get_twin_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) # Target q-net(s) evaluation. q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1_smoothed) if twin_q: twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1, policy_tp1_smoothed) q_t_selected = tf.squeeze(q_t, axis=len(q_t.shape) - 1) if twin_q: twin_q_t_selected = tf.squeeze(twin_q_t, axis=len(q_t.shape) - 1) q_tp1 = tf.minimum(q_tp1, twin_q_tp1) q_tp1_best = tf.squeeze(input=q_tp1, axis=len(q_tp1.shape) - 1) q_tp1_best_masked = ( 1.0 - tf.cast(train_batch[SampleBatch.DONES], tf.float32)) * q_tp1_best # Compute RHS of bellman equation. q_t_selected_target = tf.stop_gradient( tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) + gamma**n_step * q_tp1_best_masked) # Compute the error (potentially clipped). if twin_q: td_error = q_t_selected - q_t_selected_target twin_td_error = twin_q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) + huber_loss( twin_td_error, huber_threshold) else: errors = 0.5 * tf.math.square(td_error) + 0.5 * tf.math.square( twin_td_error) else: td_error = q_t_selected - q_t_selected_target if use_huber: errors = huber_loss(td_error, huber_threshold) else: errors = 0.5 * tf.math.square(td_error) critic_loss = tf.reduce_mean( tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) * errors) actor_loss = -tf.reduce_mean(q_t_det_policy) # Add l2-regularization if required. if l2_reg is not None: for var in policy.model.policy_variables(): if "bias" not in var.name: actor_loss += l2_reg * tf.nn.l2_loss(var) for var in policy.model.q_variables(): if "bias" not in var.name: critic_loss += l2_reg * tf.nn.l2_loss(var) # Model self-supervised losses. 
if policy.config["use_state_preprocessor"]: # Expand input_dict in case custom_loss' need them. input_dict[SampleBatch.ACTIONS] = train_batch[SampleBatch.ACTIONS] input_dict[SampleBatch.REWARDS] = train_batch[SampleBatch.REWARDS] input_dict[SampleBatch.DONES] = train_batch[SampleBatch.DONES] input_dict[SampleBatch.NEXT_OBS] = train_batch[SampleBatch.NEXT_OBS] if log_once("ddpg_custom_loss"): logger.warning( "You are using a state-preprocessor with DDPG and " "therefore, `custom_loss` will be called on your Model! " "Please be aware that DDPG now uses the ModelV2 API, which " "merges all previously separate sub-models (policy_model, " "q_model, and twin_q_model) into one ModelV2, on which " "`custom_loss` is called, passing it " "[actor_loss, critic_loss] as 1st argument. " "You may have to change your custom loss function to handle " "this.") [actor_loss, critic_loss] = model.custom_loss([actor_loss, critic_loss], input_dict) # Store values for stats function. policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.td_error = td_error policy.q_t = q_t # Return one loss value (even though we treat them separately in our # 2 optimizers: actor and critic). return policy.critic_loss + policy.actor_loss
def __init__(self, observation_space, action_space, config):
    Policy.__init__(self, observation_space, action_space, config)
    x, y, r1, r2 = get_Nash_equilibrium(config['alphas'])
    self.infiltrating = y / config['alphas'][1]
def build_q_model_and_distribution(
        policy: Policy, obs_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        config: TrainerConfigDict
) -> Tuple[ModelV2, TorchDistributionWrapper]:
    """Build q_model and target_model for DQN.

    Args:
        policy (Policy): The policy, which will use the model for
            optimization.
        obs_space (gym.spaces.Space): The policy's observation space.
        action_space (gym.spaces.Space): The policy's action space.
        config (TrainerConfigDict): The Trainer's config dict.

    Returns:
        (q_model, TorchCategorical)

    Note: The target q model will not be returned, just assigned to
        `policy.target_model`.
    """
    if not isinstance(action_space, gym.spaces.Discrete):
        raise UnsupportedSpaceException(
            "Action space {} is not supported for DQN.".format(action_space))

    if config["hiddens"]:
        # Try to infer the last layer size, otherwise fall back to 256.
        num_outputs = ([256] + list(config["model"]["fcnet_hiddens"]))[-1]
        config["model"]["no_final_linear"] = True
    else:
        num_outputs = action_space.n

    # TODO(sven): Move option to add LayerNorm after each Dense
    #  generically into ModelCatalog.
    add_layer_norm = (
        isinstance(getattr(policy, "exploration", None), ParameterNoise)
        or config["exploration_config"]["type"] == "ParameterNoise")

    model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_SCOPE,
        q_hiddens=config["hiddens"],
        dueling=config["dueling"],
        num_atoms=config["num_atoms"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=add_layer_norm)

    policy.target_model = ModelCatalog.get_model_v2(
        obs_space=obs_space,
        action_space=action_space,
        num_outputs=num_outputs,
        model_config=config["model"],
        framework="torch",
        model_interface=DQNTorchModel,
        name=Q_TARGET_SCOPE,
        q_hiddens=config["hiddens"],
        dueling=config["dueling"],
        num_atoms=config["num_atoms"],
        use_noisy=config["noisy"],
        v_min=config["v_min"],
        v_max=config["v_max"],
        sigma0=config["sigma0"],
        # TODO(sven): Move option to add LayerNorm after each Dense
        #  generically into ModelCatalog.
        add_layer_norm=add_layer_norm)

    return model, TorchCategorical
def actor_critic_loss( policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: """Constructs the loss for the Soft Actor Critic. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[TorchDistributionWrapper]: The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ # Should be True only for debugging purposes (e.g. test cases)! deterministic = policy.config["_deterministic_loss"] model_out_t, _ = model( { "obs": train_batch[SampleBatch.CUR_OBS], "is_training": True, }, [], None) model_out_tp1, _ = model( { "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, }, [], None) target_model_out_tp1, _ = policy.target_model( { "obs": train_batch[SampleBatch.NEXT_OBS], "is_training": True, }, [], None) alpha = torch.exp(model.log_alpha) # Discrete case. if model.discrete: # Get all action probs directly from pi and form their logp. log_pis_t = F.log_softmax(model.get_policy_output(model_out_t), dim=-1) policy_t = torch.exp(log_pis_t) log_pis_tp1 = F.log_softmax(model.get_policy_output(model_out_tp1), -1) policy_tp1 = torch.exp(log_pis_tp1) # Q-values. q_t = model.get_q_values(model_out_t) # Target Q-values. q_tp1 = policy.target_model.get_q_values(target_model_out_tp1) if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values(model_out_t) twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1) q_tp1 = torch.min(q_tp1, twin_q_tp1) q_tp1 -= alpha * log_pis_tp1 # Actually selected Q-values (from the actions batch). one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(), num_classes=q_t.size()[-1]) q_t_selected = torch.sum(q_t * one_hot, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1) # Discrete case: "Best" means weighted by the policy (prob) outputs. q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1) q_tp1_best_masked = \ (1.0 - train_batch[SampleBatch.DONES].float()) * \ q_tp1_best # Continuous actions case. else: # Sample single actions from distribution. action_dist_class = _get_dist_class(policy.config, policy.action_space) action_dist_t = action_dist_class(model.get_policy_output(model_out_t), policy.model) policy_t = action_dist_t.sample() if not deterministic else \ action_dist_t.deterministic_sample() log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1) action_dist_tp1 = action_dist_class( model.get_policy_output(model_out_tp1), policy.model) policy_tp1 = action_dist_tp1.sample() if not deterministic else \ action_dist_tp1.deterministic_sample() log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1) # Q-values for the actually selected actions. q_t = model.get_q_values(model_out_t, train_batch[SampleBatch.ACTIONS]) if policy.config["twin_q"]: twin_q_t = model.get_twin_q_values( model_out_t, train_batch[SampleBatch.ACTIONS]) # Q-values for current policy in given current state. q_t_det_policy = model.get_q_values(model_out_t, policy_t) if policy.config["twin_q"]: twin_q_t_det_policy = model.get_twin_q_values( model_out_t, policy_t) q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy) # Target q network evaluation. 
q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) if policy.config["twin_q"]: twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1, policy_tp1) # Take min over both twin-NNs. q_tp1 = torch.min(q_tp1, twin_q_tp1) q_t_selected = torch.squeeze(q_t, dim=-1) if policy.config["twin_q"]: twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1) q_tp1 -= alpha * log_pis_tp1 q_tp1_best = torch.squeeze(input=q_tp1, dim=-1) q_tp1_best_masked = (1.0 - train_batch[SampleBatch.DONES].float()) * \ q_tp1_best # compute RHS of bellman equation q_t_selected_target = (train_batch[SampleBatch.REWARDS] + (policy.config["gamma"]**policy.config["n_step"]) * q_tp1_best_masked).detach() # Compute the TD-error (potentially clipped). base_td_error = torch.abs(q_t_selected - q_t_selected_target) if policy.config["twin_q"]: twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error critic_loss = [ torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(base_td_error)) ] if policy.config["twin_q"]: critic_loss.append( torch.mean(train_batch[PRIO_WEIGHTS] * huber_loss(twin_td_error))) # Alpha- and actor losses. # Note: In the papers, alpha is used directly, here we take the log. # Discrete case: Multiply the action probs as weights with the original # loss terms (no expectations needed). if model.discrete: weighted_log_alpha_loss = policy_t.detach() * ( -model.log_alpha * (log_pis_t + model.target_entropy).detach()) # Sum up weighted terms and mean over all batch items. alpha_loss = torch.mean(torch.sum(weighted_log_alpha_loss, dim=-1)) # Actor loss. actor_loss = torch.mean( torch.sum( torch.mul( # NOTE: No stop_grad around policy output here # (compare with q_t_det_policy for continuous case). policy_t, alpha.detach() * log_pis_t - q_t.detach()), dim=-1)) else: alpha_loss = -torch.mean(model.log_alpha * (log_pis_t + model.target_entropy).detach()) # Note: Do not detach q_t_det_policy here b/c is depends partly # on the policy vars (policy sample pushed through Q-net). # However, we must make sure `actor_loss` is not used to update # the Q-net(s)' variables. actor_loss = torch.mean(alpha.detach() * log_pis_t - q_t_det_policy) # Save for stats function. policy.q_t = q_t policy.policy_t = policy_t policy.log_pis_t = log_pis_t policy.td_error = td_error policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.alpha_loss = alpha_loss policy.log_alpha_value = model.log_alpha policy.alpha_value = alpha policy.target_entropy = model.target_entropy # Return all loss terms corresponding to our optimizers. return tuple([policy.actor_loss] + policy.critic_loss + [policy.alpha_loss])
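# Illustrative sketch (not from the original file): the soft Bellman target
# assembled above, written out in NumPy for a toy batch. `gamma`, `n_step` and
# `alpha` mirror the config entries; all numbers are made up.
import numpy as np

gamma, n_step, alpha = 0.99, 1, 0.2
rewards = np.array([1.0, 0.0])
dones = np.array([0.0, 1.0])
q1_tp1 = np.array([5.0, 3.0])          # target Q-net at s'
q2_tp1 = np.array([4.5, 3.5])          # twin target Q-net at s'
log_pis_tp1 = np.array([-1.2, -0.7])   # log pi(a'|s') for the sampled a'

q_tp1 = np.minimum(q1_tp1, q2_tp1) - alpha * log_pis_tp1   # soft value of s', a'
q_tp1_masked = (1.0 - dones) * q_tp1                       # zero out terminal steps
target = rewards + gamma ** n_step * q_tp1_masked          # treated as a constant (detached)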
def appo_surrogate_loss(policy: Policy, model: ModelV2, dist_class: Type[TorchDistributionWrapper], train_batch: SampleBatch) -> TensorType: """Constructs the loss for APPO. With IS modifications and V-trace for Advantage Estimation. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. dist_class (Type[ActionDistribution]): The action distr. class. train_batch (SampleBatch): The training data. Returns: Union[TensorType, List[TensorType]]: A single loss tensor or a list of loss tensors. """ target_model = policy.target_models[model] model_out, _ = model(train_batch) action_dist = dist_class(model_out, model) if isinstance(policy.action_space, gym.spaces.Discrete): is_multidiscrete = False output_hidden_shape = [policy.action_space.n] elif isinstance(policy.action_space, gym.spaces.multi_discrete.MultiDiscrete): is_multidiscrete = True output_hidden_shape = policy.action_space.nvec.astype(np.int32) else: is_multidiscrete = False output_hidden_shape = 1 def _make_time_major(*args, **kwargs): return make_time_major(policy, train_batch.get(SampleBatch.SEQ_LENS), *args, **kwargs) actions = train_batch[SampleBatch.ACTIONS] dones = train_batch[SampleBatch.DONES] rewards = train_batch[SampleBatch.REWARDS] behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS] target_model_out, _ = target_model(train_batch) prev_action_dist = dist_class(behaviour_logits, model) values = model.value_function() values_time_major = _make_time_major(values) drop_last = policy.config["vtrace"] and \ policy.config["vtrace_drop_last_ts"] if policy.is_recurrent(): max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS]) mask = sequence_mask(train_batch[SampleBatch.SEQ_LENS], max_seq_len) mask = torch.reshape(mask, [-1]) mask = _make_time_major(mask, drop_last=drop_last) num_valid = torch.sum(mask) def reduce_mean_valid(t): return torch.sum(t[mask]) / num_valid else: reduce_mean_valid = torch.mean if policy.config["vtrace"]: logger.debug("Using V-Trace surrogate loss (vtrace=True; " f"drop_last={drop_last})") old_policy_behaviour_logits = target_model_out.detach() old_policy_action_dist = dist_class(old_policy_behaviour_logits, model) if isinstance(output_hidden_shape, (list, tuple, np.ndarray)): unpacked_behaviour_logits = torch.split(behaviour_logits, list(output_hidden_shape), dim=1) unpacked_old_policy_behaviour_logits = torch.split( old_policy_behaviour_logits, list(output_hidden_shape), dim=1) else: unpacked_behaviour_logits = torch.chunk(behaviour_logits, output_hidden_shape, dim=1) unpacked_old_policy_behaviour_logits = torch.chunk( old_policy_behaviour_logits, output_hidden_shape, dim=1) # Prepare actions for loss. loss_actions = actions if is_multidiscrete else torch.unsqueeze( actions, dim=1) # Prepare KL for loss. action_kl = _make_time_major(old_policy_action_dist.kl(action_dist), drop_last=drop_last) # Compute vtrace on the CPU for better perf. 
vtrace_returns = vtrace.multi_from_logits( behaviour_policy_logits=_make_time_major(unpacked_behaviour_logits, drop_last=drop_last), target_policy_logits=_make_time_major( unpacked_old_policy_behaviour_logits, drop_last=drop_last), actions=torch.unbind(_make_time_major(loss_actions, drop_last=drop_last), dim=2), discounts=(1.0 - _make_time_major(dones, drop_last=drop_last).float()) * policy.config["gamma"], rewards=_make_time_major(rewards, drop_last=drop_last), values=values_time_major[:-1] if drop_last else values_time_major, bootstrap_value=values_time_major[-1], dist_class=TorchCategorical if is_multidiscrete else dist_class, model=model, clip_rho_threshold=policy.config["vtrace_clip_rho_threshold"], clip_pg_rho_threshold=policy.config["vtrace_clip_pg_rho_threshold"] ) actions_logp = _make_time_major(action_dist.logp(actions), drop_last=drop_last) prev_actions_logp = _make_time_major(prev_action_dist.logp(actions), drop_last=drop_last) old_policy_actions_logp = _make_time_major( old_policy_action_dist.logp(actions), drop_last=drop_last) is_ratio = torch.clamp( torch.exp(prev_actions_logp - old_policy_actions_logp), 0.0, 2.0) logp_ratio = is_ratio * torch.exp(actions_logp - prev_actions_logp) policy._is_ratio = is_ratio advantages = vtrace_returns.pg_advantages.to(logp_ratio.device) surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = vtrace_returns.vs.to(values_time_major.device) if drop_last: delta = values_time_major[:-1] - value_targets else: delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. mean_entropy = reduce_mean_valid( _make_time_major(action_dist.entropy(), drop_last=drop_last)) else: logger.debug("Using PPO surrogate loss (vtrace=False)") # Prepare KL for Loss action_kl = _make_time_major(prev_action_dist.kl(action_dist)) actions_logp = _make_time_major(action_dist.logp(actions)) prev_actions_logp = _make_time_major(prev_action_dist.logp(actions)) logp_ratio = torch.exp(actions_logp - prev_actions_logp) advantages = _make_time_major(train_batch[Postprocessing.ADVANTAGES]) surrogate_loss = torch.min( advantages * logp_ratio, advantages * torch.clamp(logp_ratio, 1 - policy.config["clip_param"], 1 + policy.config["clip_param"])) mean_kl_loss = reduce_mean_valid(action_kl) mean_policy_loss = -reduce_mean_valid(surrogate_loss) # The value function loss. value_targets = _make_time_major( train_batch[Postprocessing.VALUE_TARGETS]) delta = values_time_major - value_targets mean_vf_loss = 0.5 * reduce_mean_valid(torch.pow(delta, 2.0)) # The entropy loss. mean_entropy = reduce_mean_valid( _make_time_major(action_dist.entropy())) # The summed weighted loss total_loss = mean_policy_loss + \ mean_vf_loss * policy.config["vf_loss_coeff"] - \ mean_entropy * policy.entropy_coeff # Optional additional KL Loss if policy.config["use_kl_loss"]: total_loss += policy.kl_coeff * mean_kl_loss # Store values for stats function in model (tower), such that for # multi-GPU, we do not override them during the parallel loss phase. 
model.tower_stats["total_loss"] = total_loss model.tower_stats["mean_policy_loss"] = mean_policy_loss model.tower_stats["mean_kl_loss"] = mean_kl_loss model.tower_stats["mean_vf_loss"] = mean_vf_loss model.tower_stats["mean_entropy"] = mean_entropy model.tower_stats["value_targets"] = value_targets model.tower_stats["vf_explained_var"] = explained_variance( torch.reshape(value_targets, [-1]), torch.reshape( values_time_major[:-1] if drop_last else values_time_major, [-1]), ) return total_loss
def __init__(self, observation_space, action_space, config): assert tf.executing_eagerly() Policy.__init__(self, observation_space, action_space, config) self._is_training = False self._loss_initialized = False self._sess = None if get_default_config: config = dict(get_default_config(), **config) if before_init: before_init(self, observation_space, action_space, config) self.config = config if action_sampler_fn: if not make_model: raise ValueError( "make_model is required if action_sampler_fn is given") self._dist_class = None else: self._dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) if make_model: self.model = make_model(self, observation_space, action_space, config) else: self.model = ModelCatalog.get_model_v2( observation_space, action_space, logit_dim, config["model"], framework="tf", ) self.model( { SampleBatch.CUR_OBS: tf.convert_to_tensor(np.array([observation_space.sample() ])), SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( [_flatten_action(action_space.sample())]), SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]), }, [ tf.convert_to_tensor([s]) for s in self.model.get_initial_state() ], tf.convert_to_tensor([1])) if before_loss_init: before_loss_init(self, observation_space, action_space, config) self._initialize_loss_with_dummy_batch() self._loss_initialized = True if optimizer_fn: self._optimizer = optimizer_fn(self, config) else: self._optimizer = tf.train.AdamOptimizer(config["lr"]) if after_init: after_init(self, observation_space, action_space, config)
def r2d2_loss(policy: Policy, model, _, train_batch: SampleBatch) -> TensorType: """Constructs the loss for R2D2TFPolicy. Args: policy (Policy): The Policy to calculate the loss for. model (ModelV2): The Model to calculate the loss for. train_batch (SampleBatch): The training data. Returns: TensorType: A single loss tensor. """ config = policy.config # Construct internal state inputs. i = 0 state_batches = [] while "state_in_{}".format(i) in train_batch: state_batches.append(train_batch["state_in_{}".format(i)]) i += 1 assert state_batches # Q-network evaluation (at t). q, _, _, _ = compute_q_values( policy, model, train_batch, state_batches=state_batches, seq_lens=train_batch.get(SampleBatch.SEQ_LENS), explore=False, is_training=True, ) # Target Q-network evaluation (at t+1). q_target, _, _, _ = compute_q_values( policy, policy.target_model, train_batch, state_batches=state_batches, seq_lens=train_batch.get(SampleBatch.SEQ_LENS), explore=False, is_training=True, ) if not hasattr(policy, "target_q_func_vars"): policy.target_q_func_vars = policy.target_model.variables() actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64) dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32) rewards = train_batch[SampleBatch.REWARDS] weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32) B = tf.shape(state_batches[0])[0] T = tf.shape(q)[0] // B # Q scores for actions which we know were selected in the given state. one_hot_selection = tf.one_hot(actions, policy.action_space.n) q_selected = tf.reduce_sum( tf.where(q > tf.float32.min, q, tf.zeros_like(q)) * one_hot_selection, axis=1 ) if config["double_q"]: best_actions = tf.argmax(q, axis=1) else: best_actions = tf.argmax(q_target, axis=1) best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n) q_target_best = tf.reduce_sum( tf.where(q_target > tf.float32.min, q_target, tf.zeros_like(q_target)) * best_actions_one_hot, axis=1, ) if config["num_atoms"] > 1: raise ValueError("Distributional R2D2 not supported yet!") else: q_target_best_masked_tp1 = (1.0 - dones) * tf.concat( [q_target_best[1:], tf.constant([0.0])], axis=0 ) if config["use_h_function"]: h_inv = h_inverse(q_target_best_masked_tp1, config["h_function_epsilon"]) target = h_function( rewards + config["gamma"] ** config["n_step"] * h_inv, config["h_function_epsilon"], ) else: target = ( rewards + config["gamma"] ** config["n_step"] * q_target_best_masked_tp1 ) # Seq-mask all loss-related terms. seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS], T)[:, :-1] # Mask away also the burn-in sequence at the beginning. burn_in = policy.config["replay_buffer_config"]["replay_burn_in"] # Making sure, this works for both static graph and eager. if burn_in > 0: seq_mask = tf.cond( pred=tf.convert_to_tensor(burn_in, tf.int32) < T, true_fn=lambda: tf.concat( [tf.fill([B, burn_in], False), seq_mask[:, burn_in:]], 1 ), false_fn=lambda: seq_mask, ) def reduce_mean_valid(t): return tf.reduce_mean(tf.boolean_mask(t, seq_mask)) # Make sure to use the correct time indices: # Q(t) - [gamma * r + Q^(t+1)] q_selected = tf.reshape(q_selected, [B, T])[:, :-1] td_error = q_selected - tf.stop_gradient(tf.reshape(target, [B, T])[:, :-1]) td_error = td_error * tf.cast(seq_mask, tf.float32) weights = tf.reshape(weights, [B, T])[:, :-1] policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error)) # Store the TD-error per time chunk (b/c we need only one mean # prioritized replay weight per stored sequence). 
policy._td_error = tf.reduce_mean(td_error, axis=-1) policy._loss_stats = { "mean_q": reduce_mean_valid(q_selected), "min_q": tf.reduce_min(q_selected), "max_q": tf.reduce_max(q_selected), "mean_td_error": reduce_mean_valid(td_error), } return policy._total_loss
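# Illustrative sketch (not the actual h_function/h_inverse helpers referenced
# above, which live elsewhere in the code base): a common form of the invertible
# value rescaling used by R2D2-style losses,
# h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x, with its closed-form inverse.
# The epsilon default is an assumption; the real value comes from
# config["h_function_epsilon"].
import numpy as np

def h(x, eps=1e-2):
    return np.sign(x) * (np.sqrt(np.abs(x) + 1.0) - 1.0) + eps * x

def h_inv(x, eps=1e-2):
    return np.sign(x) * (
        ((np.sqrt(1.0 + 4.0 * eps * (np.abs(x) + 1.0 + eps)) - 1.0) /
         (2.0 * eps)) ** 2 - 1.0)

x = np.array([-10.0, 0.0, 5.0, 100.0])
assert np.allclose(h_inv(h(x)), x, atol=1e-4)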
def cql_loss(policy: Policy, model: ModelV2, dist_class: Type[TFActionDistribution], train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]: logger.info(f"Current iteration = {policy.cur_iter}") policy.cur_iter += 1 # For best performance, turn deterministic off deterministic = policy.config["_deterministic_loss"] assert not deterministic twin_q = policy.config["twin_q"] discount = policy.config["gamma"] # CQL Parameters bc_iters = policy.config["bc_iters"] cql_temp = policy.config["temperature"] num_actions = policy.config["num_actions"] min_q_weight = policy.config["min_q_weight"] use_lagrange = policy.config["lagrangian"] target_action_gap = policy.config["lagrangian_thresh"] obs = train_batch[SampleBatch.CUR_OBS] actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.float32) rewards = tf.cast(train_batch[SampleBatch.REWARDS], tf.float32) next_obs = train_batch[SampleBatch.NEXT_OBS] terminals = train_batch[SampleBatch.DONES] model_out_t, _ = model(SampleBatch(obs=obs, _is_training=True), [], None) model_out_tp1, _ = model(SampleBatch(obs=next_obs, _is_training=True), [], None) target_model_out_tp1, _ = policy.target_model( SampleBatch(obs=next_obs, _is_training=True), [], None) action_dist_class = _get_dist_class(policy, policy.config, policy.action_space) action_dist_t = action_dist_class(model.get_policy_output(model_out_t), model) policy_t, log_pis_t = action_dist_t.sample_logp() log_pis_t = tf.expand_dims(log_pis_t, -1) # Unlike original SAC, Alpha and Actor Loss are computed first. # Alpha Loss alpha_loss = -tf.reduce_mean( model.log_alpha * tf.stop_gradient(log_pis_t + model.target_entropy)) # Policy Loss (Either Behavior Clone Loss or SAC Loss) alpha = tf.math.exp(model.log_alpha) if policy.cur_iter >= bc_iters: min_q = model.get_q_values(model_out_t, policy_t) if twin_q: twin_q_ = model.get_twin_q_values(model_out_t, policy_t) min_q = tf.math.minimum(min_q, twin_q_) actor_loss = tf.reduce_mean( tf.stop_gradient(alpha) * log_pis_t - min_q) else: bc_logp = action_dist_t.logp(actions) actor_loss = tf.reduce_mean( tf.stop_gradient(alpha) * log_pis_t - bc_logp) # actor_loss = -tf.reduce_mean(bc_logp) # Critic Loss (Standard SAC Critic L2 Loss + CQL Entropy Loss) # SAC Loss: # Q-values for the batched actions. action_dist_tp1 = action_dist_class(model.get_policy_output(model_out_tp1), model) policy_tp1, _ = action_dist_tp1.sample_logp() q_t = model.get_q_values(model_out_t, actions) q_t_selected = tf.squeeze(q_t, axis=-1) if twin_q: twin_q_t = model.get_twin_q_values(model_out_t, actions) twin_q_t_selected = tf.squeeze(twin_q_t, axis=-1) # Target q network evaluation. q_tp1 = policy.target_model.get_q_values(target_model_out_tp1, policy_tp1) if twin_q: twin_q_tp1 = policy.target_model.get_twin_q_values( target_model_out_tp1, policy_tp1) # Take min over both twin-NNs. 
q_tp1 = tf.math.minimum(q_tp1, twin_q_tp1) q_tp1_best = tf.squeeze(input=q_tp1, axis=-1) q_tp1_best_masked = (1.0 - tf.cast(terminals, tf.float32)) * q_tp1_best # Compute the RHS of the Bellman equation. q_t_target = tf.stop_gradient(rewards + (discount**policy.config["n_step"]) * q_tp1_best_masked) # Compute the TD-error (potentially clipped), for the prioritized replay buffer. base_td_error = tf.math.abs(q_t_selected - q_t_target) if twin_q: twin_td_error = tf.math.abs(twin_q_t_selected - q_t_target) td_error = 0.5 * (base_td_error + twin_td_error) else: td_error = base_td_error critic_loss_1 = tf.keras.losses.MSE(q_t_selected, q_t_target) if twin_q: critic_loss_2 = tf.keras.losses.MSE(twin_q_t_selected, q_t_target) # CQL loss (we use the entropy version of CQL). rand_actions, _ = policy._random_action_generator.get_exploration_action( action_distribution=action_dist_class( tf.tile(action_dist_tp1.inputs, (num_actions, 1)), model), timestep=0, explore=True) curr_actions, curr_logp = policy_actions_repeat(model, action_dist_class, model_out_t, num_actions) next_actions, next_logp = policy_actions_repeat(model, action_dist_class, model_out_tp1, num_actions) q1_rand = q_values_repeat(model, model_out_t, rand_actions) q1_curr_actions = q_values_repeat(model, model_out_t, curr_actions) q1_next_actions = q_values_repeat(model, model_out_t, next_actions) if twin_q: q2_rand = q_values_repeat(model, model_out_t, rand_actions, twin=True) q2_curr_actions = q_values_repeat(model, model_out_t, curr_actions, twin=True) q2_next_actions = q_values_repeat(model, model_out_t, next_actions, twin=True) random_density = np.log(0.5**int(curr_actions.shape[-1])) cat_q1 = tf.concat([ q1_rand - random_density, q1_next_actions - tf.stop_gradient(next_logp), q1_curr_actions - tf.stop_gradient(curr_logp) ], 1) if twin_q: cat_q2 = tf.concat([ q2_rand - random_density, q2_next_actions - tf.stop_gradient(next_logp), q2_curr_actions - tf.stop_gradient(curr_logp) ], 1) min_qf1_loss_ = tf.reduce_mean( tf.reduce_logsumexp(cat_q1 / cql_temp, axis=1)) * min_q_weight * cql_temp min_qf1_loss = min_qf1_loss_ - (tf.reduce_mean(q_t) * min_q_weight) if twin_q: min_qf2_loss_ = tf.reduce_mean( tf.reduce_logsumexp(cat_q2 / cql_temp, axis=1)) * min_q_weight * cql_temp min_qf2_loss = min_qf2_loss_ - (tf.reduce_mean(twin_q_t) * min_q_weight) if use_lagrange: alpha_prime = tf.clip_by_value(tf.math.exp(model.log_alpha_prime), 0.0, 1000000.0)[0] min_qf1_loss = alpha_prime * (min_qf1_loss - target_action_gap) if twin_q: min_qf2_loss = alpha_prime * (min_qf2_loss - target_action_gap) alpha_prime_loss = 0.5 * (-min_qf1_loss - min_qf2_loss) else: alpha_prime_loss = -min_qf1_loss cql_loss = [min_qf1_loss] if twin_q: cql_loss.append(min_qf2_loss) critic_loss = [critic_loss_1 + min_qf1_loss] if twin_q: critic_loss.append(critic_loss_2 + min_qf2_loss) # Save for stats function. policy.q_t = q_t_selected policy.policy_t = policy_t policy.log_pis_t = log_pis_t policy.td_error = td_error policy.actor_loss = actor_loss policy.critic_loss = critic_loss policy.alpha_loss = alpha_loss policy.log_alpha_value = model.log_alpha policy.alpha_value = alpha policy.target_entropy = model.target_entropy # CQL Stats policy.cql_loss = cql_loss if use_lagrange: policy.log_alpha_prime_value = model.log_alpha_prime[0] policy.alpha_prime_value = alpha_prime policy.alpha_prime_loss = alpha_prime_loss # Return all loss terms corresponding to our optimizers.
if use_lagrange: return actor_loss + tf.math.add_n(critic_loss) + alpha_loss + \ alpha_prime_loss return actor_loss + tf.math.add_n(critic_loss) + alpha_loss
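# Illustrative sketch (not from the original file): the core CQL regularizer
# assembled above, written in NumPy. `cat_q` stands in for the concatenation of
# density-corrected Q-values of random, current-policy and next-policy actions;
# `q_data` stands in for Q(s, a) on the dataset actions. Numbers are made up;
# `cql_temp` and `min_q_weight` mirror the config entries.
import numpy as np

cql_temp, min_q_weight = 1.0, 5.0
q_data = np.array([1.0, 2.0])       # Q(s, a) for dataset (batch) actions
cat_q = np.random.randn(2, 30)      # density-corrected Q-values of sampled actions

y = cat_q / cql_temp
m = y.max(axis=1, keepdims=True)
lse = m.squeeze(1) + np.log(np.exp(y - m).sum(axis=1))        # stable logsumexp
min_qf_loss = lse.mean() * min_q_weight * cql_temp            # push down sampled Qs
min_qf_loss -= q_data.mean() * min_q_weight                   # push up dataset Qs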
def __init__(self, observation_space, action_space, config): Policy.__init__(self, observation_space, action_space, config)
def __init__(self, observation_space, action_space, config): assert tf.executing_eagerly() self.framework = "tfe" Policy.__init__(self, observation_space, action_space, config) self._is_training = False self._loss_initialized = False self._sess = None if get_default_config: config = dict(get_default_config(), **config) if validate_spaces: validate_spaces(self, observation_space, action_space, config) if before_init: before_init(self, observation_space, action_space, config) self.config = config self.dist_class = None if action_sampler_fn or action_distribution_fn: if not make_model: raise ValueError( "`make_model` is required if `action_sampler_fn` OR " "`action_distribution_fn` is given") else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) if make_model: self.model = make_model(self, observation_space, action_space, config) else: self.model = ModelCatalog.get_model_v2( observation_space, action_space, logit_dim, config["model"], framework=self.framework, ) self.exploration = self._create_exploration() self._state_in = [ tf.convert_to_tensor([s]) for s in self.model.get_initial_state() ] input_dict = { SampleBatch.CUR_OBS: tf.convert_to_tensor( np.array([observation_space.sample()])), SampleBatch.PREV_ACTIONS: tf.convert_to_tensor( [flatten_to_single_ndarray(action_space.sample())]), SampleBatch.PREV_REWARDS: tf.convert_to_tensor([0.]), } if action_distribution_fn: dist_inputs, self.dist_class, _ = action_distribution_fn( self, self.model, input_dict[SampleBatch.CUR_OBS]) else: self.model(input_dict, self._state_in, tf.convert_to_tensor([1])) if before_loss_init: before_loss_init(self, observation_space, action_space, config) self._initialize_loss_with_dummy_batch() self._loss_initialized = True if optimizer_fn: self._optimizer = optimizer_fn(self, config) else: self._optimizer = tf.keras.optimizers.Adam(config["lr"]) if after_init: after_init(self, observation_space, action_space, config)
def update_target_entropy(policy: Policy): # Target entropy is held constant for this policy; nothing to update. pass
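# Illustrative sketch (not from the original policy): common heuristics for
# choosing the constant target entropy that the function above leaves untouched.
# The spaces below are hypothetical examples.
import numpy as np
import gym

continuous_space = gym.spaces.Box(-1.0, 1.0, shape=(3,))
discrete_space = gym.spaces.Discrete(4)

# Continuous SAC: -dim(action_space), i.e. -3 here.
target_entropy_cont = -float(np.prod(continuous_space.shape))
# Discrete SAC: a fraction (often 0.98) of the maximum entropy log|A|.
target_entropy_disc = 0.98 * np.log(discrete_space.n)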
def __init__(self, observation_space, action_space, config): assert tf.executing_eagerly() self.framework = config.get("framework", "tfe") Policy.__init__(self, observation_space, action_space, config) self._is_training = False self._loss_initialized = False self._sess = None self._loss = loss_fn self.batch_divisibility_req = get_batch_divisibility_req(self) if \ callable(get_batch_divisibility_req) else \ (get_batch_divisibility_req or 1) self._max_seq_len = config["model"]["max_seq_len"] if get_default_config: config = dict(get_default_config(), **config) if validate_spaces: validate_spaces(self, observation_space, action_space, config) if before_init: before_init(self, observation_space, action_space, config) self.config = config self.dist_class = None if action_sampler_fn or action_distribution_fn: if not make_model: raise ValueError( "`make_model` is required if `action_sampler_fn` OR " "`action_distribution_fn` is given") else: self.dist_class, logit_dim = ModelCatalog.get_action_dist( action_space, self.config["model"]) if make_model: self.model = make_model(self, observation_space, action_space, config) else: self.model = ModelCatalog.get_model_v2( observation_space, action_space, logit_dim, config["model"], framework=self.framework, ) # Auto-update model's inference view requirements, if recurrent. self._update_model_inference_view_requirements_from_init_state() self.exploration = self._create_exploration() self._state_in = [ tf.convert_to_tensor([s]) for s in self.model.get_initial_state() ] # Combine view_requirements for Model and Policy. self.view_requirements.update( self.model.inference_view_requirements) if before_loss_init: before_loss_init(self, observation_space, action_space, config) if optimizer_fn: optimizers = optimizer_fn(self, config) else: optimizers = tf.keras.optimizers.Adam(config["lr"]) optimizers = force_list(optimizers) if getattr(self, "exploration", None): optimizers = self.exploration.get_exploration_optimizer( optimizers) # TODO: (sven) Allow tf policy to have more than 1 optimizer. # Just like torch Policy does. self._optimizer = optimizers[0] if optimizers else None self._initialize_loss_from_dummy_batch( auto_remove_unneeded_view_reqs=True, stats_fn=stats_fn, ) self._loss_initialized = True if after_init: after_init(self, observation_space, action_space, config) # Got to reset global_timestep again after fake run-throughs. self.global_timestep = 0