def compute_actions_from_input_dict(
        self,
        input_dict: SampleBatch,
        explore: bool = None,
        timestep: Optional[int] = None,
        episodes: Optional[List["MultiAgentEpisode"]] = None,
        **kwargs) -> \
        Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
    """Computes actions from collected samples (across multiple-agents).

    Uses the currently "forward-pass-registered" samples from the collector
    to construct the input_dict for the Model.

    Args:
        input_dict (SampleBatch): A SampleBatch containing the Tensors
            to compute actions. `input_dict` already abides by the Policy's
            as well as the Model's view requirements and can thus be passed
            to the Model as-is.
        explore (bool): Whether to pick an exploitation or exploration
            action (default: None -> use self.config["explore"]).
        timestep (Optional[int]): The current (sampling) time step.
        kwargs: Forward compatibility placeholder.

    Returns:
        Tuple:
            actions (TensorType): Batch of output actions, with shape like
                [BATCH_SIZE, ACTION_SHAPE].
            state_outs (List[TensorType]): List of RNN state output batches,
                if any, each with shape [BATCH_SIZE, STATE_SIZE].
            info (dict): Dictionary of extra feature batches, if any, with
                shape like
                {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
    """
    # Default implementation just passes obs, prev-a/r, and states on to
    # `self.compute_actions()`.
    state_batches = [
        s for k, s in input_dict.items() if k[:9] == "state_in_"
    ]
    return self.compute_actions(
        input_dict[SampleBatch.OBS],
        state_batches,
        prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
        info_batch=input_dict.get(SampleBatch.INFOS),
        explore=explore,
        timestep=timestep,
        episodes=episodes,
        **kwargs,
    )

def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    drop_last = self.config["vtrace"] and self.config["vtrace_drop_last_ts"]
    values_batched = _make_time_major(
        self,
        train_batch.get(SampleBatch.SEQ_LENS),
        self.model.value_function(),
        drop_last=drop_last,
    )

    return {
        "cur_lr": tf.cast(self.cur_lr, tf.float64),
        "policy_loss": self.vtrace_loss.mean_pi_loss,
        "entropy": self.vtrace_loss.mean_entropy,
        "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
        "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
        "vf_loss": self.vtrace_loss.mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(self.vtrace_loss.value_targets, [-1]),
            tf.reshape(values_batched, [-1]),
        ),
    }

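The `_make_time_major` helper used here (and in the loss and stats functions below) reshapes RLlib's flattened `[B * T, ...]` sample tensors into time-major `[T, B, ...]` layout, optionally dropping the last timestep, which is the layout the V-trace computation expects. A simplified illustration of that reshaping with made-up values (not RLlib's actual helper, which also covers the non-recurrent case):

import tensorflow as tf

# Illustration only: two sequences (B=2), each padded to T=3, feature dim 1.
flat = tf.reshape(tf.range(6, dtype=tf.float32), [6, 1])  # [B * T, 1]
B, T = 2, 3

folded = tf.reshape(flat, [B, T, -1])         # [B, T, ...]
time_major = tf.transpose(folded, [1, 0, 2])  # [T, B, ...]
# `drop_last=True` in the calls above would additionally slice off the last
# timestep:
time_major_without_last_ts = time_major[:-1]
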
def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
    values_batched = _make_time_major(
        self,
        train_batch.get(SampleBatch.SEQ_LENS),
        self.model.value_function(),
        drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
    )

    stats_dict = {
        "cur_lr": tf.cast(self.cur_lr, tf.float64),
        "total_loss": self._total_loss,
        "policy_loss": self._mean_policy_loss,
        "entropy": self._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
        "vf_loss": self._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(self._value_targets, [-1]),
            tf.reshape(values_batched, [-1]),
        ),
        "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
    }

    if self.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if self.config["use_kl_loss"]:
        stats_dict["kl"] = self._mean_kl_loss
        stats_dict["KL_Coeff"] = self.kl_coeff

    return stats_dict

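`explained_variance` compares the value-function outputs with the value targets. Assuming the usual definition (this sketch is not RLlib's helper), it computes `1 - Var[targets - predictions] / Var[targets]`, so a perfect value function scores 1.0 and anything no better than a constant predictor scores at or below 0.0:

import tensorflow as tf

def explained_variance_sketch(y: tf.Tensor, pred: tf.Tensor) -> tf.Tensor:
    # Assumed definition: 1 - Var[y - pred] / Var[y], floored at -1.0.
    _, y_var = tf.nn.moments(y, axes=[0])
    _, diff_var = tf.nn.moments(y - pred, axes=[0])
    return tf.maximum(-1.0, 1.0 - diff_var / y_var)
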
def call(
    self, input_dict: SampleBatch
) -> (TensorType, List[TensorType], Dict[str, TensorType]):
    assert input_dict.get(SampleBatch.SEQ_LENS) is not None
    # Push obs through underlying (wrapped) model first.
    wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

    # Concat. prev-action/reward if required.
    prev_a_r = []
    if self.lstm_use_prev_action:
        prev_a = input_dict[SampleBatch.PREV_ACTIONS]
        if isinstance(self.action_space, (Discrete, MultiDiscrete)):
            prev_a = one_hot(prev_a, self.action_space)
        prev_a_r.append(
            tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
        )
    if self.lstm_use_prev_reward:
        prev_a_r.append(
            tf.reshape(
                tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32),
                [-1, 1],
            )
        )
    if prev_a_r:
        wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

    max_seq_len = (
        tf.shape(wrapped_out)[0]
        // tf.shape(input_dict[SampleBatch.SEQ_LENS])[0]
    )
    wrapped_out_plus_time_dim = add_time_dimension(
        wrapped_out, max_seq_len=max_seq_len, framework="tf"
    )
    model_out, value_out, h, c = self._rnn_model(
        [
            wrapped_out_plus_time_dim,
            input_dict[SampleBatch.SEQ_LENS],
            input_dict["state_in_0"],
            input_dict["state_in_1"],
        ]
    )
    model_out_no_time_dim = tf.reshape(
        model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
    )
    return (
        model_out_no_time_dim,
        [h, c],
        {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])},
    )

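`add_time_dimension` folds the padded, flattened features back into a `[B, T, ...]` sequence tensor before they are fed to the LSTM; the final `tf.reshape` in `call` then removes that time dimension again so downstream code sees a flat batch. A minimal sketch of the folding step for the batch-major TF case (an assumption about the helper; the real one also handles time-major layouts and torch):

import tensorflow as tf

def add_time_dimension_sketch(padded_inputs: tf.Tensor,
                              max_seq_len: tf.Tensor) -> tf.Tensor:
    # Assumed behavior: fold a padded [B * T, D] tensor into [B, T, D].
    feature_dim = padded_inputs.shape[-1]  # static feature size D
    return tf.reshape(padded_inputs, [-1, max_seq_len, feature_dim])
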
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = _make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"]
        and policy.config["vtrace_drop_last_ts"],
    )

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy._value_targets, [-1]),
            tf.reshape(values_batched, [-1]),
        ),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }

    if policy.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy._mean_kl_loss
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict

def from_batch(self, train_batch: SampleBatch,
               is_training: bool = True) -> (TensorType, List[TensorType]):
    """Convenience function that calls this model with a tensor batch.

    All this does is unpack the tensor batch to call this model with the
    right input dict, state, and seq len arguments.
    """
    train_batch["is_training"] = is_training
    states = []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states.append(train_batch["state_in_{}".format(i)])
        i += 1
    ret = self.__call__(train_batch, states, train_batch.get("seq_lens"))
    del train_batch["is_training"]
    return ret

def stats(policy: Policy, train_batch: SampleBatch):
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = make_time_major(
        policy,
        train_batch.get("seq_lens"),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"])

    stats_dict = {
        "cur_lr": policy.cur_lr,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            torch.reshape(policy._value_targets, [-1]),
            torch.reshape(values_batched, [-1])),
    }

    if policy.config["vtrace"]:
        is_stat_mean = torch.mean(policy._is_ratio, [0, 1])
        is_stat_var = torch.var(policy._is_ratio, [0, 1])
        stats_dict.update({"mean_IS": is_stat_mean})
        stats_dict.update({"var_IS": is_stat_var})

    if policy.config["use_kl_loss"]:
        stats_dict.update({"kl": policy._mean_kl})
        stats_dict.update({"KL_Coeff": policy.kl_coeff})

    return stats_dict

def from_batch(self, train_batch: SampleBatch,
               is_training: bool = True) -> (TensorType, List[TensorType]):
    """Convenience function that calls this model with a tensor batch.

    All this does is unpack the tensor batch to call this model with the
    right input dict, state, and seq len arguments.
    """
    input_dict = {
        "obs": train_batch[SampleBatch.CUR_OBS],
        "is_training": is_training,
    }
    if SampleBatch.PREV_ACTIONS in train_batch:
        input_dict["prev_actions"] = train_batch[SampleBatch.PREV_ACTIONS]
    if SampleBatch.PREV_REWARDS in train_batch:
        input_dict["prev_rewards"] = train_batch[SampleBatch.PREV_REWARDS]
    states = []
    i = 0
    while "state_in_{}".format(i) in train_batch:
        states.append(train_batch["state_in_{}".format(i)])
        i += 1
    return self.__call__(input_dict, states, train_batch.get("seq_lens"))

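For context, a loss function built on ModelV2 would typically call `from_batch` roughly as follows. This is an illustrative, hypothetical `my_loss` (not an RLlib API), assuming torch tensors in the batch:

from ray.rllib.policy.sample_batch import SampleBatch

def my_loss(policy, model, dist_class, train_batch):
    # `from_batch` unpacks obs, prev. actions/rewards, internal states, and
    # seq_lens from the SampleBatch and runs the forward pass.
    logits, state = model.from_batch(train_batch, is_training=True)
    action_dist = dist_class(logits, model)
    log_probs = action_dist.logp(train_batch[SampleBatch.ACTIONS])
    # Simple REINFORCE-style stand-in objective, for illustration only.
    return -(log_probs * train_batch[SampleBatch.REWARDS]).mean()
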
def postprocess_trajectory(self,
                           policy: "Policy",
                           sample_batch: SampleBatch,
                           tf_sess: Optional["tf.Session"] = None):
    noisy_action_dist = noise_free_action_dist = None
    # Adjust the stddev depending on the action (pi)-distance.
    # Also see [1] for details.
    # TODO(sven): Find out whether this can be scrapped by simply using
    #  the `sample_batch` to get the noisy/noise-free action dist.
    _, _, fetches = policy.compute_actions(
        obs_batch=sample_batch[SampleBatch.CUR_OBS],
        # TODO(sven): What about state-ins and seq-lens?
        prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
        explore=self.weights_are_currently_noisy)

    # Categorical case (e.g. DQN).
    if policy.dist_class in (Categorical, TorchCategorical):
        action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
    # Deterministic (Gaussian actions, e.g. DDPG).
    elif policy.dist_class in [Deterministic, TorchDeterministic]:
        action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]
    else:
        raise NotImplementedError  # TODO(sven): Other action-dist cases.

    if self.weights_are_currently_noisy:
        noisy_action_dist = action_dist
    else:
        noise_free_action_dist = action_dist

    _, _, fetches = policy.compute_actions(
        obs_batch=sample_batch[SampleBatch.CUR_OBS],
        prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
        prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
        explore=not self.weights_are_currently_noisy)

    # Categorical case (e.g. DQN).
    if policy.dist_class in (Categorical, TorchCategorical):
        action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
    # Deterministic (Gaussian actions, e.g. DDPG).
    elif policy.dist_class in [Deterministic, TorchDeterministic]:
        action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]

    if noisy_action_dist is None:
        noisy_action_dist = action_dist
    else:
        noise_free_action_dist = action_dist

    delta = distance = None
    # Categorical case (e.g. DQN).
    if policy.dist_class in (Categorical, TorchCategorical):
        # Calculate KL-divergence (DKL(clean||noisy)) according to [2].
        # TODO(sven): Allow KL-divergence to be calculated by our
        #  Distribution classes (don't support off-graph/numpy yet).
        distance = np.nanmean(
            np.sum(
                noise_free_action_dist *
                np.log(noise_free_action_dist /
                       (noisy_action_dist + SMALL_NUMBER)), 1))
        current_epsilon = self.sub_exploration.get_info(
            sess=tf_sess)["cur_epsilon"]
        delta = -np.log(1 - current_epsilon +
                        current_epsilon / self.action_space.n)
    elif policy.dist_class in [Deterministic, TorchDeterministic]:
        # Calculate MSE between noisy and non-noisy output (see [2]).
        distance = np.sqrt(
            np.mean(np.square(noise_free_action_dist - noisy_action_dist)))
        current_scale = self.sub_exploration.get_info(
            sess=tf_sess)["cur_scale"]
        delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * current_scale

    # Adjust stddev according to the calculated action-distance.
    if distance <= delta:
        self.stddev_val *= 1.01
    else:
        self.stddev_val /= 1.01

    # Set self.stddev to calculated value.
    if self.framework == "tf":
        self.stddev.load(self.stddev_val, session=tf_sess)
    else:
        self.stddev = self.stddev_val

    return sample_batch

def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        policy.target_q_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get("seq_lens"),
        explore=False,
        is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q,
                    torch.tensor(0.0, device=policy.device)) *
        one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=policy.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
            [q_target_best[1:],
             torch.tensor([0.0], device=policy.device)])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss

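`h_function` and `h_inverse` are not shown in this excerpt. Assuming they implement the usual invertible value rescaling from the R2D2 line of work (an assumption about these helpers), they would look roughly like this:

import torch

def h_function(x: torch.Tensor, epsilon: float = 1e-2) -> torch.Tensor:
    # Assumed definition: h(x) = sign(x) * (sqrt(|x| + 1) - 1) + epsilon * x.
    return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x: torch.Tensor, epsilon: float = 1e-2) -> torch.Tensor:
    # Assumed definition: the closed-form inverse of h_function above.
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon)) -
          1.0) / (2.0 * epsilon))**2 - 1.0)

Squashing the targets through `h` keeps TD targets in a small numeric range, while `h_inverse` undoes the squashing before rewards are added, as done in the loss above.
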
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(
        policy,
        model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(
        policy,
        policy.target_model,
        train_batch,
        state_batches=state_batches,
        seq_lens=train_batch.get(SampleBatch.SEQ_LENS),
        explore=False,
        is_training=True)

    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64)
    dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    rewards = train_batch[SampleBatch.REWARDS]
    weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)

    B = tf.shape(state_batches[0])[0]
    T = tf.shape(q)[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(actions, policy.action_space.n)
    q_selected = tf.reduce_sum(
        tf.where(q > tf.float32.min, q, tf.zeros_like(q)) *
        one_hot_selection,
        axis=1)

    if config["double_q"]:
        best_actions = tf.argmax(q, axis=1)
    else:
        best_actions = tf.argmax(q_target, axis=1)

    best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n)
    q_target_best = tf.reduce_sum(
        tf.where(q_target > tf.float32.min, q_target,
                 tf.zeros_like(q_target)) * best_actions_one_hot,
        axis=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * tf.concat(
            [q_target_best[1:], tf.constant([0.0])], axis=0)

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                    T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        # Make sure this works for both static graph and eager mode.
        if burn_in > 0 and (config["framework"] == "tf" or burn_in < T):
            seq_mask = tf.concat(
                [tf.fill([B, burn_in], False), seq_mask[:, burn_in:]],
                axis=1)

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, seq_mask))

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
        td_error = q_selected - tf.stop_gradient(
            tf.reshape(target, [B, T])[:, :-1])
        td_error = td_error * tf.cast(seq_mask, tf.float32)
        weights = tf.reshape(weights, [B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = tf.reshape(td_error, [-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": tf.reduce_min(q_selected),
            "max_q": tf.reduce_max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss

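The burn-in handling in both R2D2 losses sits on top of a standard sequence mask: timesteps beyond each `seq_len` are padding and get masked out, and the first `burn_in` timesteps of every sequence are additionally excluded from the loss (they only serve to warm up the RNN state). A tiny illustration with made-up values:

import tensorflow as tf

seq_lens = tf.constant([4, 2])  # two sequences of lengths 4 and 2
T = 4                           # padded (max) sequence length
burn_in = 1

seq_mask = tf.sequence_mask(seq_lens, T)
# [[True,  True,  True,  True ],
#  [True,  True,  False, False]]

# Exclude the burn-in prefix, as in the losses above.
seq_mask = tf.concat(
    [tf.fill([tf.shape(seq_lens)[0], burn_in], False), seq_mask[:, burn_in:]],
    axis=1)
# [[False, True,  True,  True ],
#  [False, True,  False, False]]
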
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure the divisibility requirement is
    met, then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet). Padding depends on the episodes found in
    the batch and on `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffling may be
            done in-place. This only makes sense if you're further applying
            minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be divisible.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in the batch that are
            not "state_in/out_"-type keys.
        view_requirements (Optional[ViewRequirementsDict]): An optional
            Policy ViewRequirements dict to be able to infer whether e.g.
            dynamic max'ing should be applied over the seq_lens.
    """
    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.seq_lens is not None and \
                len(batch["state_in_0"]) == len(batch.seq_lens):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention nets (state inputs are over some range): No dynamic
        # max'ing possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif not feature_keys and not k.startswith("state_out_") and \
                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = np.array(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))

def actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get("seq_lens")

    model_out_t, state_in_t = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "prev_actions": train_batch[SampleBatch.PREV_ACTIONS],
            "prev_rewards": train_batch[SampleBatch.PREV_REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_tp1 = model.select_state(state_in_tp1,
                                       ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    target_states_in_tp1 = policy.target_model.select_state(
        target_state_in_tp1, ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)
    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = F.log_softmax(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0],
            dim=-1)
        policy_t = torch.exp(log_pis_t)
        log_pis_tp1 = F.log_softmax(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0]
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens)[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens)[0]
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)[0]
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(
            train_batch[SampleBatch.ACTIONS].long(),
            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0], policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                 train_batch[SampleBatch.ACTIONS])[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens,
                train_batch[SampleBatch.ACTIONS])[0]

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"],
                                            seq_lens, policy_t)[0]
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0]
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1 = policy.target_model.get_q_values(
            target_model_out_tp1, target_states_in_tp1["q"], seq_lens,
            policy_tp1)[0]
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens, policy_tp1)[0]
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # compute RHS of bellman equation
    q_t_selected_target = (
        train_batch[SampleBatch.REWARDS] +
        (policy.config["gamma"]**policy.config["n_step"]) *
        q_tp1_best_masked).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch["seq_lens"], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: In the papers, alpha is used directly, here we take the log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Save for stats function.
    policy.q_t = q_t * seq_mask[..., None]
    policy.policy_t = policy_t * seq_mask[..., None]
    policy.log_pis_t = log_pis_t * seq_mask[..., None]

    # Store td-error in model, such that for multi-GPU, we do not override
    # them during the parallel loss phase. TD-error tensor in final stats
    # can then be concatenated and retrieved for each individual batch item.
    model.td_error = td_error * seq_mask

    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy

    # Return all loss terms corresponding to our optimizers.
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])

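For reference, the actor and temperature terms computed above match the standard SAC objectives (continuous case shown; the discrete branch takes the expectation over actions explicitly by weighting with the policy probabilities, and, as the in-code note says, `log_alpha` is optimized in place of alpha itself):

J_\pi = \mathbb{E}_{a \sim \pi}\big[\alpha \,\log \pi(a \mid s) - Q(s, a)\big],
\qquad
J(\alpha) = \mathbb{E}_{a \sim \pi}\big[-\alpha \,\big(\log \pi(a \mid s) + \bar{\mathcal{H}}\big)\big],

with \bar{\mathcal{H}} the target entropy (`model.target_entropy`).
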
def train_q(self, batch: SampleBatch) -> TensorType:
    """Trains self.q_model using Q-Reg loss on given batch.

    Args:
        batch: A SampleBatch of episodes to train on.

    Returns:
        A list of losses for each training iteration.
    """
    losses = []
    obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
    actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device)
    ps = torch.zeros([batch.count], device=self.device)
    returns = torch.zeros([batch.count], device=self.device)
    discounts = torch.zeros([batch.count], device=self.device)

    # Necessary if policy uses a recurrent/attention model.
    num_state_inputs = 0
    for k in batch.keys():
        if k.startswith("state_in_"):
            num_state_inputs += 1
    state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

    # Get rewards, old_prob, new_prob.
    rewards = batch[SampleBatch.REWARDS]
    old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP])
    new_log_prob = (
        self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=False,
        ).detach().cpu())
    prob_ratio = torch.exp(new_log_prob - old_log_prob)

    eps_begin = 0
    for episode in batch.split_by_episode():
        eps_end = eps_begin + episode.count

        # Calculate importance ratios and returns.
        for t in range(episode.count):
            discounts[eps_begin + t] = self.gamma**t
            if t == 0:
                pt_prev = 1.0
            else:
                pt_prev = ps[eps_begin + t - 1]
            ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t]

            # O(n^3)
            # ret = 0
            # for t_prime in range(t, episode.count):
            #     gamma = self.gamma ** (t_prime - t)
            #     rho_t_1_t_prime = 1.0
            #     for k in range(t + 1, min(t_prime + 1, episode.count)):
            #         rho_t_1_t_prime = \
            #             rho_t_1_t_prime * prob_ratio[eps_begin + k]
            #     r = rewards[eps_begin + t_prime]
            #     ret += gamma * rho_t_1_t_prime * r

            # O(n^2)
            ret = 0
            rho = 1
            for t_ in reversed(range(t, episode.count)):
                ret = rewards[eps_begin + t_] + self.gamma * rho * ret
                rho = prob_ratio[eps_begin + t_]
            returns[eps_begin + t] = ret

        # Update before next episode.
        eps_begin = eps_end

    indices = np.arange(batch.count)
    for _ in range(self.n_iters):
        minibatch_losses = []
        np.random.shuffle(indices)
        for idx in range(0, batch.count, self.batch_size):
            idxs = indices[idx:idx + self.batch_size]
            q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
            q_acts = torch.gather(q_values, -1,
                                  actions[idxs].unsqueeze(-1)).squeeze(-1)
            loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts)**2
            loss = torch.mean(loss)
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad.clip_grad_norm_(self.q_model.variables(),
                                               self.clip_grad_norm)
            self.optimizer.step()
            minibatch_losses.append(loss.item())
        iter_loss = sum(minibatch_losses) / len(minibatch_losses)
        losses.append(iter_loss)
        if iter_loss < self.delta:
            break

    return losses

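Reading the minibatch loss off the code above: per timestep t of an episode (with \pi the new policy and \mu the behavior policy that produced ACTION_LOGP), the Q-Reg objective being minimized is

\mathcal{L} = \frac{1}{N}\sum_{t} \gamma^{t}\,\rho_{0:t}\,\big(R_t - Q(s_t, a_t)\big)^2,
\qquad
\rho_{0:t} = \prod_{k=0}^{t}\frac{\pi(a_k \mid s_k)}{\mu(a_k \mid s_k)},
\qquad
R_t = \sum_{t' \ge t} \gamma^{\,t'-t}\Big(\prod_{k=t+1}^{t'} \rho_k\Big)\, r_{t'},

which is exactly `discounts * ps * (returns - q_acts) ** 2` averaged over the minibatch.
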
def loss(
    self,
    model: ModelV2,
    dist_class: Type[ActionDistribution],
    train_batch: SampleBatch,
) -> Union[TensorType, List[TensorType]]:
    model_out, _ = model(train_batch)
    action_dist = dist_class(model_out, model)

    if isinstance(self.action_space, gym.spaces.Discrete):
        is_multidiscrete = False
        output_hidden_shape = [self.action_space.n]
    elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
        is_multidiscrete = True
        output_hidden_shape = self.action_space.nvec.astype(np.int32)
    else:
        is_multidiscrete = False
        output_hidden_shape = 1

    def _make_time_major(*args, **kw):
        return make_time_major(self, train_batch.get(SampleBatch.SEQ_LENS),
                               *args, **kw)

    actions = train_batch[SampleBatch.ACTIONS]
    dones = train_batch[SampleBatch.DONES]
    rewards = train_batch[SampleBatch.REWARDS]
    behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
    behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
    if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
        unpacked_behaviour_logits = torch.split(
            behaviour_logits, list(output_hidden_shape), dim=1)
        unpacked_outputs = torch.split(
            model_out, list(output_hidden_shape), dim=1)
    else:
        unpacked_behaviour_logits = torch.chunk(
            behaviour_logits, output_hidden_shape, dim=1)
        unpacked_outputs = torch.chunk(model_out, output_hidden_shape, dim=1)
    values = model.value_function()

    if self.is_recurrent():
        max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
        mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                  max_seq_len)
        mask = torch.reshape(mask_orig, [-1])
    else:
        mask = torch.ones_like(rewards)

    # Prepare actions for loss.
    loss_actions = actions if is_multidiscrete else torch.unsqueeze(
        actions, dim=1)

    # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
    drop_last = self.config["vtrace_drop_last_ts"]
    loss = VTraceLoss(
        actions=_make_time_major(loss_actions, drop_last=drop_last),
        actions_logp=_make_time_major(
            action_dist.logp(actions), drop_last=drop_last),
        actions_entropy=_make_time_major(
            action_dist.entropy(), drop_last=drop_last),
        dones=_make_time_major(dones, drop_last=drop_last),
        behaviour_action_logp=_make_time_major(
            behaviour_action_logp, drop_last=drop_last),
        behaviour_logits=_make_time_major(
            unpacked_behaviour_logits, drop_last=drop_last),
        target_logits=_make_time_major(unpacked_outputs, drop_last=drop_last),
        discount=self.config["gamma"],
        rewards=_make_time_major(rewards, drop_last=drop_last),
        values=_make_time_major(values, drop_last=drop_last),
        bootstrap_value=_make_time_major(values)[-1],
        dist_class=TorchCategorical if is_multidiscrete else dist_class,
        model=model,
        valid_mask=_make_time_major(mask, drop_last=drop_last),
        config=self.config,
        vf_loss_coeff=self.config["vf_loss_coeff"],
        entropy_coeff=self.entropy_coeff,
        clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
        clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
    )

    # Store values for stats function in model (tower), such that for
    # multi-GPU, we do not override them during the parallel loss phase.
    model.tower_stats["pi_loss"] = loss.pi_loss
    model.tower_stats["vf_loss"] = loss.vf_loss
    model.tower_stats["entropy"] = loss.entropy
    model.tower_stats["mean_entropy"] = loss.mean_entropy
    model.tower_stats["total_loss"] = loss.total_loss

    values_batched = make_time_major(
        self,
        train_batch.get(SampleBatch.SEQ_LENS),
        values,
        drop_last=self.config["vtrace"] and drop_last,
    )
    model.tower_stats["vf_explained_var"] = explained_variance(
        torch.reshape(loss.value_targets, [-1]),
        torch.reshape(values_batched, [-1]))

    return loss.total_loss