Example #1
    def compute_actions_from_input_dict(
            self,
            input_dict: SampleBatch,
            explore: Optional[bool] = None,
            timestep: Optional[int] = None,
            episodes: Optional[List["MultiAgentEpisode"]] = None,
            **kwargs) -> \
            Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        """Computes actions from collected samples (across multiple-agents).

        Uses the currently "forward-pass-registered" samples from the collector
        to construct the input_dict for the Model.

        Args:
            input_dict (SampleBatch): A SampleBatch containing the Tensors
                to compute actions. `input_dict` already abides by the
                Policy's as well as the Model's view requirements and can
                thus be passed to the Model as-is.
            explore (bool): Whether to pick an exploitation or exploration
                action (default: None -> use self.config["explore"]).
            timestep (Optional[int]): The current (sampling) time step.
            episodes (Optional[List["MultiAgentEpisode"]]): The episodes
                (if any) to pass on to `compute_actions()`.
            kwargs: Forward compatibility placeholder.

        Returns:
            Tuple:
                actions (TensorType): Batch of output actions, with shape
                    like [BATCH_SIZE, ACTION_SHAPE].
                state_outs (List[TensorType]): List of RNN state output
                    batches, if any, each with shape [BATCH_SIZE, STATE_SIZE].
                info (dict): Dictionary of extra feature batches, if any, with
                    shape like
                    {"f1": [BATCH_SIZE, ...], "f2": [BATCH_SIZE, ...]}.
        """
        # Default implementation just passes obs, prev-a/r, and states on to
        # `self.compute_actions()`.
        state_batches = [
            s for k, s in input_dict.items() if k[:9] == "state_in_"
        ]
        return self.compute_actions(
            input_dict[SampleBatch.OBS],
            state_batches,
            prev_action_batch=input_dict.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=input_dict.get(SampleBatch.PREV_REWARDS),
            info_batch=input_dict.get(SampleBatch.INFOS),
            explore=explore,
            timestep=timestep,
            episodes=episodes,
            **kwargs,
        )
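A minimal sketch of the state-extraction step above, using a plain dict in place of a SampleBatch (keys and values here are made up for illustration):

# Plain-dict stand-in for a SampleBatch (illustrative keys/values only).
fake_input_dict = {
    "obs": [[0.1, 0.2]],
    "state_in_0": [[0.0]],
    "state_in_1": [[1.0]],
    "prev_actions": [0],
}
# Same prefix test as in the method body (`k[:9] == "state_in_"`).
state_batches = [
    v for k, v in fake_input_dict.items() if k.startswith("state_in_")
]
assert state_batches == [[[0.0]], [[1.0]]]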
Example #2
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            drop_last = self.config["vtrace"] and self.config[
                "vtrace_drop_last_ts"]
            values_batched = _make_time_major(
                self,
                train_batch.get(SampleBatch.SEQ_LENS),
                self.model.value_function(),
                drop_last=drop_last,
            )

            return {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "policy_loss": self.vtrace_loss.mean_pi_loss,
                "entropy": self.vtrace_loss.mean_entropy,
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
                "var_gnorm": tf.linalg.global_norm(
                    self.model.trainable_variables()),
                "vf_loss": self.vtrace_loss.mean_vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self.vtrace_loss.value_targets, [-1]),
                    tf.reshape(values_batched, [-1]),
                ),
            }
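For reference, a self-contained numpy sketch of what `explained_variance` computes here, assuming the common definition 1 - Var[y - pred] / Var[y] (the exact RLlib implementation may differ, e.g. in how it clips or handles zero variance):

import numpy as np

def explained_variance_np(y, pred):
    # 1.0 = perfect fit, 0.0 = no better than predicting the mean,
    # negative = worse than predicting the mean.
    var_y = np.var(y)
    return float("nan") if var_y == 0 else 1.0 - np.var(y - pred) / var_y

y = np.array([1.0, 2.0, 3.0, 4.0])
assert explained_variance_np(y, y) == 1.0
assert explained_variance_np(y, np.full_like(y, y.mean())) == 0.0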
Example #3
        def stats_fn(self, train_batch: SampleBatch) -> Dict[str, TensorType]:
            values_batched = _make_time_major(
                self,
                train_batch.get(SampleBatch.SEQ_LENS),
                self.model.value_function(),
                drop_last=self.config["vtrace"] and self.config["vtrace_drop_last_ts"],
            )

            stats_dict = {
                "cur_lr": tf.cast(self.cur_lr, tf.float64),
                "total_loss": self._total_loss,
                "policy_loss": self._mean_policy_loss,
                "entropy": self._mean_entropy,
                "var_gnorm": tf.linalg.global_norm(self.model.trainable_variables()),
                "vf_loss": self._mean_vf_loss,
                "vf_explained_var": explained_variance(
                    tf.reshape(self._value_targets, [-1]),
                    tf.reshape(values_batched, [-1]),
                ),
                "entropy_coeff": tf.cast(self.entropy_coeff, tf.float64),
            }

            if self.config["vtrace"]:
                is_stat_mean, is_stat_var = tf.nn.moments(self._is_ratio, [0, 1])
                stats_dict["mean_IS"] = is_stat_mean
                stats_dict["var_IS"] = is_stat_var

            if self.config["use_kl_loss"]:
                stats_dict["kl"] = self._mean_kl_loss
                stats_dict["KL_Coeff"] = self.kl_coeff

            return stats_dict
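A numpy equivalent of the `tf.nn.moments(self._is_ratio, [0, 1])` call in the vtrace branch above, shown on made-up data; `tf.nn.moments` returns the mean and the biased (population) variance over the given axes:

import numpy as np

is_ratio = np.array([[0.9, 1.1], [1.0, 1.2]])  # e.g. [T, B], values made up
mean_is = is_ratio.mean(axis=(0, 1))           # -> 1.05
var_is = is_ratio.var(axis=(0, 1))             # biased variance, as in TF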
Example #4
    def call(
        self, input_dict: SampleBatch
    ) -> Tuple[TensorType, List[TensorType], Dict[str, TensorType]]:
        assert input_dict.get(SampleBatch.SEQ_LENS) is not None
        # Push obs through underlying (wrapped) model first.
        wrapped_out, _, _ = self.wrapped_keras_model(input_dict)

        # Concat. prev-action/reward if required.
        prev_a_r = []
        if self.lstm_use_prev_action:
            prev_a = input_dict[SampleBatch.PREV_ACTIONS]
            if isinstance(self.action_space, (Discrete, MultiDiscrete)):
                prev_a = one_hot(prev_a, self.action_space)
            prev_a_r.append(
                tf.reshape(tf.cast(prev_a, tf.float32), [-1, self.action_dim])
            )
        if self.lstm_use_prev_reward:
            prev_a_r.append(
                tf.reshape(
                    tf.cast(input_dict[SampleBatch.PREV_REWARDS], tf.float32), [-1, 1]
                )
            )

        if prev_a_r:
            wrapped_out = tf.concat([wrapped_out] + prev_a_r, axis=1)

        max_seq_len = (
            tf.shape(wrapped_out)[0] // tf.shape(input_dict[SampleBatch.SEQ_LENS])[0]
        )
        wrapped_out_plus_time_dim = add_time_dimension(
            wrapped_out, max_seq_len=max_seq_len, framework="tf"
        )
        model_out, value_out, h, c = self._rnn_model(
            [
                wrapped_out_plus_time_dim,
                input_dict[SampleBatch.SEQ_LENS],
                input_dict["state_in_0"],
                input_dict["state_in_1"],
            ]
        )
        model_out_no_time_dim = tf.reshape(
            model_out, tf.concat([[-1], tf.shape(model_out)[2:]], axis=0)
        )
        return (
            model_out_no_time_dim,
            [h, c],
            {SampleBatch.VF_PREDS: tf.reshape(value_out, [-1])},
        )
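A numpy sketch of the reshaping done by `add_time_dimension` here (assuming batch-major output, which is what the Keras LSTM layer expects): the flat [B*T, F] batch is folded back into [B, T, F] using the `max_seq_len` derived above:

import numpy as np

flat = np.arange(12, dtype=np.float32).reshape(6, 2)  # [B*T=6, F=2]
num_seqs = 2                                          # len(seq_lens)
max_seq_len = flat.shape[0] // num_seqs               # 3, as computed above
with_time_dim = flat.reshape(num_seqs, max_seq_len, -1)
assert with_time_dim.shape == (2, 3, 2)
# The final tf.reshape in `call()` is the inverse: it folds [B, T, ...]
# back into [B*T, ...] after the RNN pass.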
Example #5
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = _make_time_major(
        policy,
        train_batch.get(SampleBatch.SEQ_LENS),
        policy.model.value_function(),
        drop_last=policy.config["vtrace"]
        and policy.config["vtrace_drop_last_ts"],
    )

    stats_dict = {
        "cur_lr": tf.cast(policy.cur_lr, tf.float64),
        "total_loss": policy._total_loss,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": tf.linalg.global_norm(
            policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            tf.reshape(policy._value_targets, [-1]),
            tf.reshape(values_batched, [-1]),
        ),
        "entropy_coeff": tf.cast(policy.entropy_coeff, tf.float64),
    }

    if policy.config["vtrace"]:
        is_stat_mean, is_stat_var = tf.nn.moments(policy._is_ratio, [0, 1])
        stats_dict["mean_IS"] = is_stat_mean
        stats_dict["var_IS"] = is_stat_var

    if policy.config["use_kl_loss"]:
        stats_dict["kl"] = policy._mean_kl_loss
        stats_dict["KL_Coeff"] = policy.kl_coeff

    return stats_dict
Example #6
    def from_batch(self,
                   train_batch: SampleBatch,
                   is_training: bool = True
                   ) -> Tuple[TensorType, List[TensorType]]:
        """Convenience function that calls this model with a tensor batch.

        All this does is unpack the tensor batch to call this model with the
        right input dict, state, and seq len arguments.
        """

        train_batch["is_training"] = is_training
        states = []
        i = 0
        while "state_in_{}".format(i) in train_batch:
            states.append(train_batch["state_in_{}".format(i)])
            i += 1
        ret = self.__call__(train_batch, states, train_batch.get("seq_lens"))
        del train_batch["is_training"]
        return ret
Example #7
def stats(policy: Policy, train_batch: SampleBatch) -> Dict[str, TensorType]:
    """Stats function for APPO. Returns a dict with important loss stats.

    Args:
        policy (Policy): The Policy to generate stats for.
        train_batch (SampleBatch): The SampleBatch (already) used for training.

    Returns:
        Dict[str, TensorType]: The stats dict.
    """
    values_batched = make_time_major(policy,
                                     train_batch.get("seq_lens"),
                                     policy.model.value_function(),
                                     drop_last=policy.config["vtrace"])

    stats_dict = {
        "cur_lr": policy.cur_lr,
        "policy_loss": policy._mean_policy_loss,
        "entropy": policy._mean_entropy,
        "var_gnorm": global_norm(policy.model.trainable_variables()),
        "vf_loss": policy._mean_vf_loss,
        "vf_explained_var": explained_variance(
            torch.reshape(policy._value_targets, [-1]),
            torch.reshape(values_batched, [-1]),
        ),
    }

    if policy.config["vtrace"]:
        is_stat_mean = torch.mean(policy._is_ratio, [0, 1])
        is_stat_var = torch.var(policy._is_ratio, [0, 1])
        stats_dict.update({"mean_IS": is_stat_mean})
        stats_dict.update({"var_IS": is_stat_var})

    if policy.config["use_kl_loss"]:
        stats_dict.update({"kl": policy._mean_kl})
        stats_dict.update({"KL_Coeff": policy.kl_coeff})

    return stats_dict
Example #8
    def from_batch(self, train_batch: SampleBatch,
                   is_training: bool = True
                   ) -> Tuple[TensorType, List[TensorType]]:
        """Convenience function that calls this model with a tensor batch.

        All this does is unpack the tensor batch to call this model with the
        right input dict, state, and seq len arguments.
        """

        input_dict = {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "is_training": is_training,
        }
        if SampleBatch.PREV_ACTIONS in train_batch:
            input_dict["prev_actions"] = train_batch[SampleBatch.PREV_ACTIONS]
        if SampleBatch.PREV_REWARDS in train_batch:
            input_dict["prev_rewards"] = train_batch[SampleBatch.PREV_REWARDS]
        states = []
        i = 0
        while "state_in_{}".format(i) in train_batch:
            states.append(train_batch["state_in_{}".format(i)])
            i += 1
        return self.__call__(input_dict, states, train_batch.get("seq_lens"))
Example #9
    def postprocess_trajectory(self,
                               policy: "Policy",
                               sample_batch: SampleBatch,
                               tf_sess: Optional["tf.Session"] = None):
        noisy_action_dist = noise_free_action_dist = None
        # Adjust the stddev depending on the action (pi)-distance.
        # Also see [1] for details.
        # TODO(sven): Find out whether this can be scrapped by simply using
        #  the `sample_batch` to get the noisy/noise-free action dist.
        _, _, fetches = policy.compute_actions(
            obs_batch=sample_batch[SampleBatch.CUR_OBS],
            # TODO(sven): What about state-ins and seq-lens?
            prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
            explore=self.weights_are_currently_noisy)

        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
        # Deterministic (Gaussian actions, e.g. DDPG).
        elif policy.dist_class in [Deterministic, TorchDeterministic]:
            action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]
        else:
            raise NotImplementedError  # TODO(sven): Other action-dist cases.

        if self.weights_are_currently_noisy:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        _, _, fetches = policy.compute_actions(
            obs_batch=sample_batch[SampleBatch.CUR_OBS],
            prev_action_batch=sample_batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=sample_batch.get(SampleBatch.PREV_REWARDS),
            explore=not self.weights_are_currently_noisy)

        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            action_dist = softmax(fetches[SampleBatch.ACTION_DIST_INPUTS])
        # Deterministic (Gaussian actions, e.g. DDPG).
        elif policy.dist_class in [Deterministic, TorchDeterministic]:
            action_dist = fetches[SampleBatch.ACTION_DIST_INPUTS]

        if noisy_action_dist is None:
            noisy_action_dist = action_dist
        else:
            noise_free_action_dist = action_dist

        delta = distance = None
        # Categorical case (e.g. DQN).
        if policy.dist_class in (Categorical, TorchCategorical):
            # Calculate KL-divergence (DKL(clean||noisy)) according to [2].
            # TODO(sven): Allow KL-divergence to be calculated by our
            #  Distribution classes (don't support off-graph/numpy yet).
            distance = np.nanmean(
                np.sum(
                    noise_free_action_dist *
                    np.log(noise_free_action_dist /
                           (noisy_action_dist + SMALL_NUMBER)), 1))
            current_epsilon = self.sub_exploration.get_info(
                sess=tf_sess)["cur_epsilon"]
            delta = -np.log(1 - current_epsilon +
                            current_epsilon / self.action_space.n)
        elif policy.dist_class in [Deterministic, TorchDeterministic]:
            # Calculate MSE between noisy and non-noisy output (see [2]).
            distance = np.sqrt(
                np.mean(np.square(noise_free_action_dist - noisy_action_dist)))
            current_scale = self.sub_exploration.get_info(
                sess=tf_sess)["cur_scale"]
            delta = getattr(self.sub_exploration, "ou_sigma", 0.2) * \
                current_scale

        # Adjust stddev according to the calculated action-distance.
        if distance <= delta:
            self.stddev_val *= 1.01
        else:
            self.stddev_val /= 1.01

        # Set self.stddev to calculated value.
        if self.framework == "tf":
            self.stddev.load(self.stddev_val, session=tf_sess)
        else:
            self.stddev = self.stddev_val

        return sample_batch
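A numpy sketch of the categorical branch above on made-up distributions: the KL divergence DKL(clean||noisy) and the epsilon-derived threshold `delta` (here `SMALL_NUMBER` stands in for RLlib's constant):

import numpy as np

SMALL_NUMBER = 1e-6                   # stand-in for RLlib's constant
clean = np.array([[0.7, 0.2, 0.1]])   # noise-free action dist (made up)
noisy = np.array([[0.5, 0.3, 0.2]])   # noisy action dist (made up)
distance = np.nanmean(
    np.sum(clean * np.log(clean / (noisy + SMALL_NUMBER)), axis=1))
eps, n = 0.05, 3                      # cur_epsilon and action_space.n
delta = -np.log(1 - eps + eps / n)
# stddev_val is then scaled up by 1.01 if distance <= delta, else down.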
Example #10
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TorchPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(policy,
                                  model,
                                  train_batch,
                                  state_batches=state_batches,
                                  seq_lens=train_batch.get("seq_lens"),
                                  explore=False,
                                  is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(policy,
                                         policy.target_q_model,
                                         train_batch,
                                         state_batches=state_batches,
                                         seq_lens=train_batch.get("seq_lens"),
                                         explore=False,
                                         is_training=True)

    actions = train_batch[SampleBatch.ACTIONS].long()
    dones = train_batch[SampleBatch.DONES].float()
    rewards = train_batch[SampleBatch.REWARDS]
    weights = train_batch[PRIO_WEIGHTS]

    B = state_batches[0].shape[0]
    T = q.shape[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = F.one_hot(actions, policy.action_space.n)
    q_selected = torch.sum(
        torch.where(q > FLOAT_MIN, q, torch.tensor(0.0, device=policy.device))
        * one_hot_selection, 1)

    if config["double_q"]:
        best_actions = torch.argmax(q, dim=1)
    else:
        best_actions = torch.argmax(q_target, dim=1)

    best_actions_one_hot = F.one_hot(best_actions, policy.action_space.n)
    q_target_best = torch.sum(
        torch.where(q_target > FLOAT_MIN, q_target,
                    torch.tensor(0.0, device=policy.device)) *
        best_actions_one_hot,
        dim=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * torch.cat(
            [q_target_best[1:],
             torch.tensor([0.0], device=policy.device)])

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = sequence_mask(train_batch["seq_lens"], T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        if burn_in > 0 and burn_in < T:
            seq_mask[:, :burn_in] = False

        num_valid = torch.sum(seq_mask)

        def reduce_mean_valid(t):
            return torch.sum(t[seq_mask]) / num_valid

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = q_selected.reshape([B, T])[:, :-1]
        td_error = q_selected - target.reshape([B, T])[:, :-1].detach()
        td_error = td_error * seq_mask
        weights = weights.reshape([B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = td_error.reshape([-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": torch.min(q_selected),
            "max_q": torch.max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss
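`h_function` and `h_inverse` are not shown above; a hedged torch sketch of the invertible value rescaling from the R2D2 paper, which is presumably what they implement: h(x) = sign(x)(sqrt(|x| + 1) - 1) + eps * x, with its closed-form inverse:

import torch

def h_function(x, epsilon=1e-2):
    return torch.sign(x) * (
        torch.sqrt(torch.abs(x) + 1.0) - 1.0) + epsilon * x

def h_inverse(x, epsilon=1e-2):
    return torch.sign(x) * (
        ((torch.sqrt(1.0 + 4.0 * epsilon * (torch.abs(x) + 1.0 + epsilon))
          - 1.0) / (2.0 * epsilon)) ** 2 - 1.0)

# Round trip: the inverse recovers the original values.
x = torch.tensor([-10.0, -1.0, 0.0, 1.0, 10.0])
assert torch.allclose(h_inverse(h_function(x)), x, atol=1e-4)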
Example #11
def r2d2_loss(policy: Policy, model, _,
              train_batch: SampleBatch) -> TensorType:
    """Constructs the loss for R2D2TFPolicy.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        train_batch (SampleBatch): The training data.

    Returns:
        TensorType: A single loss tensor.
    """
    config = policy.config

    # Construct internal state inputs.
    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches

    # Q-network evaluation (at t).
    q, _, _, _ = compute_q_values(policy,
                                  model,
                                  train_batch,
                                  state_batches=state_batches,
                                  seq_lens=train_batch.get(
                                      SampleBatch.SEQ_LENS),
                                  explore=False,
                                  is_training=True)

    # Target Q-network evaluation (at t+1).
    q_target, _, _, _ = compute_q_values(policy,
                                         policy.target_model,
                                         train_batch,
                                         state_batches=state_batches,
                                         seq_lens=train_batch.get(
                                             SampleBatch.SEQ_LENS),
                                         explore=False,
                                         is_training=True)

    if not hasattr(policy, "target_q_func_vars"):
        policy.target_q_func_vars = policy.target_model.variables()

    actions = tf.cast(train_batch[SampleBatch.ACTIONS], tf.int64)
    dones = tf.cast(train_batch[SampleBatch.DONES], tf.float32)
    rewards = train_batch[SampleBatch.REWARDS]
    weights = tf.cast(train_batch[PRIO_WEIGHTS], tf.float32)

    B = tf.shape(state_batches[0])[0]
    T = tf.shape(q)[0] // B

    # Q scores for actions which we know were selected in the given state.
    one_hot_selection = tf.one_hot(actions, policy.action_space.n)
    q_selected = tf.reduce_sum(
        tf.where(q > tf.float32.min, q, tf.zeros_like(q)) * one_hot_selection,
        axis=1)

    if config["double_q"]:
        best_actions = tf.argmax(q, axis=1)
    else:
        best_actions = tf.argmax(q_target, axis=1)

    best_actions_one_hot = tf.one_hot(best_actions, policy.action_space.n)
    q_target_best = tf.reduce_sum(tf.where(q_target > tf.float32.min, q_target,
                                           tf.zeros_like(q_target)) *
                                  best_actions_one_hot,
                                  axis=1)

    if config["num_atoms"] > 1:
        raise ValueError("Distributional R2D2 not supported yet!")
    else:
        q_target_best_masked_tp1 = (1.0 - dones) * tf.concat(
            [q_target_best[1:], tf.constant([0.0])], axis=0)

        if config["use_h_function"]:
            h_inv = h_inverse(q_target_best_masked_tp1,
                              config["h_function_epsilon"])
            target = h_function(
                rewards + config["gamma"]**config["n_step"] * h_inv,
                config["h_function_epsilon"])
        else:
            target = rewards + \
                config["gamma"] ** config["n_step"] * q_target_best_masked_tp1

        # Seq-mask all loss-related terms.
        seq_mask = tf.sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                    T)[:, :-1]
        # Mask away also the burn-in sequence at the beginning.
        burn_in = policy.config["burn_in"]
        # Make sure this works for both static graph and eager mode.
        if burn_in > 0 and (config["framework"] == "tf" or burn_in < T):
            seq_mask = tf.concat(
                [tf.fill([B, burn_in], False), seq_mask[:, burn_in:]], axis=1)

        def reduce_mean_valid(t):
            return tf.reduce_mean(tf.boolean_mask(t, seq_mask))

        # Make sure to use the correct time indices:
        # Q(t) - [gamma * r + Q^(t+1)]
        q_selected = tf.reshape(q_selected, [B, T])[:, :-1]
        td_error = q_selected - tf.stop_gradient(
            tf.reshape(target, [B, T])[:, :-1])
        td_error = td_error * tf.cast(seq_mask, tf.float32)
        weights = tf.reshape(weights, [B, T])[:, :-1]
        policy._total_loss = reduce_mean_valid(weights * huber_loss(td_error))
        policy._td_error = tf.reshape(td_error, [-1])
        policy._loss_stats = {
            "mean_q": reduce_mean_valid(q_selected),
            "min_q": tf.reduce_min(q_selected),
            "max_q": tf.reduce_max(q_selected),
            "mean_td_error": reduce_mean_valid(td_error),
        }

    return policy._total_loss
Example #12
def pad_batch_to_sequences_of_same_size(
    batch: SampleBatch,
    max_seq_len: int,
    shuffle: bool = False,
    batch_divisibility_req: int = 1,
    feature_keys: Optional[List[str]] = None,
    view_requirements: Optional[ViewRequirementsDict] = None,
):
    """Applies padding to `batch` so it's choppable into same-size sequences.

    Shuffles `batch` (if desired), makes sure divisibility requirement is met,
    then pads the batch ([B, ...]) into same-size chunks ([B, ...]) w/o
    adding a time dimension (yet).
    Padding depends on episodes found in batch and `max_seq_len`.

    Args:
        batch (SampleBatch): The SampleBatch object. All values in here have
            the shape [B, ...].
        max_seq_len (int): The max. sequence length to use for chopping.
        shuffle (bool): Whether to shuffle batch sequences. Shuffle may
            be done in-place. This only makes sense if you're further
            applying minibatch SGD after getting the outputs.
        batch_divisibility_req (int): The int by which the batch dimension
            must be divisible.
        feature_keys (Optional[List[str]]): An optional list of keys to apply
            sequence-chopping to. If None, use all keys in batch that are not
            "state_in/out_"-type keys.
        view_requirements (Optional[ViewRequirementsDict]): An optional
            Policy ViewRequirements dict to be able to infer whether
            e.g. dynamic max'ing should be applied over the seq_lens.
    """
    if batch_divisibility_req > 1:
        meets_divisibility_reqs = (
            len(batch[SampleBatch.CUR_OBS]) % batch_divisibility_req == 0
            # not multiagent
            and max(batch[SampleBatch.AGENT_INDEX]) == 0)
    else:
        meets_divisibility_reqs = True

    states_already_reduced_to_init = False

    # RNN/attention net case. Figure out whether we should apply dynamic
    # max'ing over the list of sequence lengths.
    if "state_in_0" in batch or "state_out_0" in batch:
        # Check, whether the state inputs have already been reduced to their
        # init values at the beginning of each max_seq_len chunk.
        if batch.seq_lens is not None and \
                len(batch["state_in_0"]) == len(batch.seq_lens):
            states_already_reduced_to_init = True

        # RNN (or single timestep state-in): Set the max dynamically.
        if view_requirements["state_in_0"].shift_from is None:
            dynamic_max = True
        # Attention Nets (state inputs are over some range): No dynamic maxing
        # possible.
        else:
            dynamic_max = False
    # Multi-agent case.
    elif not meets_divisibility_reqs:
        max_seq_len = batch_divisibility_req
        dynamic_max = False
    # Simple case: No RNN/attention net, nor do we need to pad.
    else:
        if shuffle:
            batch.shuffle()
        return

    # RNN, attention net, or multi-agent case.
    state_keys = []
    feature_keys_ = feature_keys or []
    for k, v in batch.items():
        if k.startswith("state_in_"):
            state_keys.append(k)
        elif not feature_keys and not k.startswith("state_out_") and \
                k not in ["infos", "seq_lens"] and isinstance(v, np.ndarray):
            feature_keys_.append(k)

    feature_sequences, initial_states, seq_lens = \
        chop_into_sequences(
            feature_columns=[batch[k] for k in feature_keys_],
            state_columns=[batch[k] for k in state_keys],
            episode_ids=batch.get(SampleBatch.EPS_ID),
            unroll_ids=batch.get(SampleBatch.UNROLL_ID),
            agent_indices=batch.get(SampleBatch.AGENT_INDEX),
            seq_lens=getattr(batch, "seq_lens", batch.get("seq_lens")),
            max_seq_len=max_seq_len,
            dynamic_max=dynamic_max,
            states_already_reduced_to_init=states_already_reduced_to_init,
            shuffle=shuffle)

    for i, k in enumerate(feature_keys_):
        batch[k] = feature_sequences[i]
    for i, k in enumerate(state_keys):
        batch[k] = initial_states[i]
    batch["seq_lens"] = np.array(seq_lens)

    if log_once("rnn_ma_feed_dict"):
        logger.info("Padded input for RNN/Attn.Nets/MA:\n\n{}\n".format(
            summarize({
                "features": feature_sequences,
                "initial_states": initial_states,
                "seq_lens": seq_lens,
                "max_seq_len": max_seq_len,
            })))
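A minimal numpy sketch of the padding this function produces for a single feature column (two episodes of lengths 3 and 2, `max_seq_len=3`): the flat batch is zero-padded so it chops into same-size chunks, still without a time axis:

import numpy as np

feature = np.arange(5)         # flat [B]: episode 1 = 0,1,2; episode 2 = 3,4
seq_lens = np.array([3, 2])
max_seq_len = 3
padded = np.zeros(len(seq_lens) * max_seq_len, dtype=feature.dtype)
offset = 0
for i, sl in enumerate(seq_lens):
    padded[i * max_seq_len:i * max_seq_len + sl] = feature[offset:offset + sl]
    offset += sl
assert padded.tolist() == [0, 1, 2, 3, 4, 0]  # [B'], choppable into [2, 3]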
Example #13
def actor_critic_loss(
        policy: Policy, model: ModelV2,
        dist_class: Type[TorchDistributionWrapper],
        train_batch: SampleBatch) -> Union[TensorType, List[TensorType]]:
    """Constructs the loss for the Soft Actor Critic.

    Args:
        policy (Policy): The Policy to calculate the loss for.
        model (ModelV2): The Model to calculate the loss for.
        dist_class (Type[TorchDistributionWrapper]): The action distr. class.
        train_batch (SampleBatch): The training data.

    Returns:
        Union[TensorType, List[TensorType]]: A single loss tensor or a list
            of loss tensors.
    """
    # Should be True only for debugging purposes (e.g. test cases)!
    deterministic = policy.config["_deterministic_loss"]

    i = 0
    state_batches = []
    while "state_in_{}".format(i) in train_batch:
        state_batches.append(train_batch["state_in_{}".format(i)])
        i += 1
    assert state_batches
    seq_lens = train_batch.get("seq_lens")

    model_out_t, state_in_t = model(
        {
            "obs": train_batch[SampleBatch.CUR_OBS],
            "prev_actions": train_batch[SampleBatch.PREV_ACTIONS],
            "prev_rewards": train_batch[SampleBatch.PREV_REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_t = model.select_state(state_in_t, ["policy", "q", "twin_q"])

    model_out_tp1, state_in_tp1 = model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    states_in_tp1 = model.select_state(state_in_tp1, ["policy", "q", "twin_q"])

    target_model_out_tp1, target_state_in_tp1 = policy.target_model(
        {
            "obs": train_batch[SampleBatch.NEXT_OBS],
            "prev_actions": train_batch[SampleBatch.ACTIONS],
            "prev_rewards": train_batch[SampleBatch.REWARDS],
            "is_training": True,
        }, state_batches, seq_lens)
    target_states_in_tp1 = \
        policy.target_model.select_state(target_state_in_tp1,
                                         ["policy", "q", "twin_q"])

    alpha = torch.exp(model.log_alpha)

    # Discrete case.
    if model.discrete:
        # Get all action probs directly from pi and form their logp.
        log_pis_t = F.log_softmax(model.get_policy_output(
            model_out_t, states_in_t["policy"], seq_lens)[0],
                                  dim=-1)
        policy_t = torch.exp(log_pis_t)
        log_pis_tp1 = F.log_softmax(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], -1)
        policy_tp1 = torch.exp(log_pis_tp1)
        # Q-values.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens)[0]
        # Target Q-values.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 target_states_in_tp1["q"],
                                                 seq_lens)[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(model_out_t,
                                               states_in_t["twin_q"],
                                               seq_lens)[0]
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"],
                seq_lens)[0]
            q_tp1 = torch.min(q_tp1, twin_q_tp1)
        q_tp1 -= alpha * log_pis_tp1

        # Actually selected Q-values (from the actions batch).
        one_hot = F.one_hot(train_batch[SampleBatch.ACTIONS].long(),
                            num_classes=q_t.size()[-1])
        q_t_selected = torch.sum(q_t * one_hot, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.sum(twin_q_t * one_hot, dim=-1)
        # Discrete case: "Best" means weighted by the policy (prob) outputs.
        q_tp1_best = torch.sum(torch.mul(policy_tp1, q_tp1), dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * \
            q_tp1_best
    # Continuous actions case.
    else:
        # Sample single actions from distribution.
        action_dist_class = _get_dist_class(policy, policy.config,
                                            policy.action_space)
        action_dist_t = action_dist_class(
            model.get_policy_output(model_out_t, states_in_t["policy"],
                                    seq_lens)[0], policy.model)
        policy_t = action_dist_t.sample() if not deterministic else \
            action_dist_t.deterministic_sample()
        log_pis_t = torch.unsqueeze(action_dist_t.logp(policy_t), -1)
        action_dist_tp1 = action_dist_class(
            model.get_policy_output(model_out_tp1, states_in_tp1["policy"],
                                    seq_lens)[0], policy.model)
        policy_tp1 = action_dist_tp1.sample() if not deterministic else \
            action_dist_tp1.deterministic_sample()
        log_pis_tp1 = torch.unsqueeze(action_dist_tp1.logp(policy_tp1), -1)

        # Q-values for the actually selected actions.
        q_t = model.get_q_values(model_out_t, states_in_t["q"], seq_lens,
                                 train_batch[SampleBatch.ACTIONS])[0]
        if policy.config["twin_q"]:
            twin_q_t = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens,
                train_batch[SampleBatch.ACTIONS])[0]

        # Q-values for current policy in given current state.
        q_t_det_policy = model.get_q_values(model_out_t, states_in_t["q"],
                                            seq_lens, policy_t)[0]
        if policy.config["twin_q"]:
            twin_q_t_det_policy = model.get_twin_q_values(
                model_out_t, states_in_t["twin_q"], seq_lens, policy_t)[0]
            q_t_det_policy = torch.min(q_t_det_policy, twin_q_t_det_policy)

        # Target q network evaluation.
        q_tp1 = policy.target_model.get_q_values(target_model_out_tp1,
                                                 target_states_in_tp1["q"],
                                                 seq_lens, policy_tp1)[0]
        if policy.config["twin_q"]:
            twin_q_tp1 = policy.target_model.get_twin_q_values(
                target_model_out_tp1, target_states_in_tp1["twin_q"], seq_lens,
                policy_tp1)[0]
            # Take min over both twin-NNs.
            q_tp1 = torch.min(q_tp1, twin_q_tp1)

        q_t_selected = torch.squeeze(q_t, dim=-1)
        if policy.config["twin_q"]:
            twin_q_t_selected = torch.squeeze(twin_q_t, dim=-1)
        q_tp1 -= alpha * log_pis_tp1

        q_tp1_best = torch.squeeze(input=q_tp1, dim=-1)
        q_tp1_best_masked = \
            (1.0 - train_batch[SampleBatch.DONES].float()) * q_tp1_best

    # Compute the RHS of the Bellman equation.
    q_t_selected_target = (train_batch[SampleBatch.REWARDS] +
                           (policy.config["gamma"]**policy.config["n_step"]) *
                           q_tp1_best_masked).detach()

    # BURNIN #
    B = state_batches[0].shape[0]
    T = q_t_selected.shape[0] // B
    seq_mask = sequence_mask(train_batch["seq_lens"], T)
    # Mask away also the burn-in sequence at the beginning.
    burn_in = policy.config["burn_in"]
    if burn_in > 0 and burn_in < T:
        seq_mask[:, :burn_in] = False

    seq_mask = seq_mask.reshape(-1)
    num_valid = torch.sum(seq_mask)

    def reduce_mean_valid(t):
        return torch.sum(t[seq_mask]) / num_valid

    # Compute the TD-error (potentially clipped).
    base_td_error = torch.abs(q_t_selected - q_t_selected_target)
    if policy.config["twin_q"]:
        twin_td_error = torch.abs(twin_q_t_selected - q_t_selected_target)
        td_error = 0.5 * (base_td_error + twin_td_error)
    else:
        td_error = base_td_error

    critic_loss = [
        reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                          huber_loss(base_td_error))
    ]
    if policy.config["twin_q"]:
        critic_loss.append(
            reduce_mean_valid(train_batch[PRIO_WEIGHTS] *
                              huber_loss(twin_td_error)))

    # Alpha- and actor losses.
    # Note: The papers use alpha directly; here we optimize its log.
    # Discrete case: Multiply the action probs as weights with the original
    # loss terms (no expectations needed).
    if model.discrete:
        weighted_log_alpha_loss = policy_t.detach() * (
            -model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Sum up weighted terms and mean over all batch items.
        alpha_loss = reduce_mean_valid(
            torch.sum(weighted_log_alpha_loss, dim=-1))
        # Actor loss.
        actor_loss = reduce_mean_valid(
            torch.sum(
                torch.mul(
                    # NOTE: No stop_grad around policy output here
                    # (compare with q_t_det_policy for continuous case).
                    policy_t,
                    alpha.detach() * log_pis_t - q_t.detach()),
                dim=-1))
    else:
        alpha_loss = -reduce_mean_valid(
            model.log_alpha * (log_pis_t + model.target_entropy).detach())
        # Note: Do not detach q_t_det_policy here b/c it depends partly
        # on the policy vars (policy sample pushed through Q-net).
        # However, we must make sure `actor_loss` is not used to update
        # the Q-net(s)' variables.
        actor_loss = reduce_mean_valid(alpha.detach() * log_pis_t -
                                       q_t_det_policy)

    # Save for stats function.
    policy.q_t = q_t * seq_mask[..., None]
    policy.policy_t = policy_t * seq_mask[..., None]
    policy.log_pis_t = log_pis_t * seq_mask[..., None]

    # Store td-error in model, such that for multi-GPU, we do not override
    # them during the parallel loss phase. TD-error tensor in final stats
    # can then be concatenated and retrieved for each individual batch item.
    model.td_error = td_error * seq_mask

    policy.actor_loss = actor_loss
    policy.critic_loss = critic_loss
    policy.alpha_loss = alpha_loss
    policy.log_alpha_value = model.log_alpha
    policy.alpha_value = alpha
    policy.target_entropy = model.target_entropy

    # Return all loss terms corresponding to our optimizers.
    return tuple([policy.actor_loss] + policy.critic_loss +
                 [policy.alpha_loss])
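A hedged torch sketch of the `sequence_mask` helper plus the burn-in masking used in the BURNIN block above (RLlib's actual helper may differ in signature): a [B, T] boolean mask that is True for valid timesteps, with the first `burn_in` steps of every row zeroed out:

import torch

def sequence_mask(seq_lens, maxlen):
    # True where the timestep index is within the sequence length.
    return torch.arange(maxlen)[None, :] < seq_lens[:, None]

seq_lens = torch.tensor([4, 2])
mask = sequence_mask(seq_lens, maxlen=4)  # [[T,T,T,T], [T,T,F,F]]
burn_in = 1
mask[:, :burn_in] = False                 # [[F,T,T,T], [F,T,F,F]]
assert mask.sum().item() == 4             # only valid, non-burn-in steps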
Example #14
    def train_q(self, batch: SampleBatch) -> List[float]:
        """Trains self.q_model using Q-Reg loss on given batch.

        Args:
            batch: A SampleBatch of episodes to train on

        Returns:
            A list of losses for each training iteration
        """
        losses = []
        obs = torch.tensor(batch[SampleBatch.OBS], device=self.device)
        actions = torch.tensor(batch[SampleBatch.ACTIONS], device=self.device)
        ps = torch.zeros([batch.count], device=self.device)
        returns = torch.zeros([batch.count], device=self.device)
        discounts = torch.zeros([batch.count], device=self.device)

        # Necessary if the policy uses a recurrent/attention model.
        num_state_inputs = 0
        for k in batch.keys():
            if k.startswith("state_in_"):
                num_state_inputs += 1
        state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)]

        # Get rewards, old and new action log-probs.
        rewards = batch[SampleBatch.REWARDS]
        old_log_prob = torch.tensor(batch[SampleBatch.ACTION_LOGP])
        new_log_prob = (self.policy.compute_log_likelihoods(
            actions=batch[SampleBatch.ACTIONS],
            obs_batch=batch[SampleBatch.OBS],
            state_batches=[batch[k] for k in state_keys],
            prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS),
            prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS),
            actions_normalized=False,
        ).detach().cpu())
        prob_ratio = torch.exp(new_log_prob - old_log_prob)

        eps_begin = 0
        for episode in batch.split_by_episode():
            eps_end = eps_begin + episode.count

            # Calculate importance ratios and returns.

            for t in range(episode.count):
                discounts[eps_begin + t] = self.gamma**t
                if t == 0:
                    pt_prev = 1.0
                else:
                    pt_prev = ps[eps_begin + t - 1]
                ps[eps_begin + t] = pt_prev * prob_ratio[eps_begin + t]

                # O(n^3)
                # ret = 0
                # for t_prime in range(t, episode.count):
                #     gamma = self.gamma ** (t_prime - t)
                #     rho_t_1_t_prime = 1.0
                #     for k in range(t + 1, min(t_prime + 1, episode.count)):
                #         rho_t_1_t_prime = rho_t_1_t_prime * prob_ratio[eps_begin + k]
                #     r = rewards[eps_begin + t_prime]
                #     ret += gamma * rho_t_1_t_prime * r

                # O(n^2)
                ret = 0
                rho = 1
                for t_ in reversed(range(t, episode.count)):
                    ret = rewards[eps_begin + t_] + self.gamma * rho * ret
                    rho = prob_ratio[eps_begin + t_]

                returns[eps_begin + t] = ret

            # Update before next episode
            eps_begin = eps_end

        indices = np.arange(batch.count)
        for _ in range(self.n_iters):
            minibatch_losses = []
            np.random.shuffle(indices)
            for idx in range(0, batch.count, self.batch_size):
                idxs = indices[idx:idx + self.batch_size]
                q_values, _ = self.q_model({"obs": obs[idxs]}, [], None)
                q_acts = torch.gather(q_values, -1,
                                      actions[idxs].unsqueeze(-1)).squeeze(-1)
                loss = discounts[idxs] * ps[idxs] * (returns[idxs] - q_acts)**2
                loss = torch.mean(loss)
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad.clip_grad_norm_(self.q_model.variables(),
                                                   self.clip_grad_norm)
                self.optimizer.step()
                minibatch_losses.append(loss.item())
            iter_loss = sum(minibatch_losses) / len(minibatch_losses)
            losses.append(iter_loss)
            if iter_loss < self.delta:
                break
        return losses
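A numpy check (on made-up numbers) that the O(n^2) backward recursion above matches the commented-out O(n^3) definition: ret(t) = sum over t' >= t of gamma^(t'-t) * prod(ratio[t+1..t']) * r[t']:

import numpy as np

gamma = 0.9
r = np.array([1.0, 2.0, 3.0])
ratio = np.array([0.5, 2.0, 0.8])   # per-step prob ratios (made up)

def direct(t):
    # Straightforward version of the commented-out formula.
    return sum(gamma ** (tp - t) * np.prod(ratio[t + 1:tp + 1]) * r[tp]
               for tp in range(t, len(r)))

ret, rho = 0.0, 1.0
rets = np.zeros_like(r)
for t_ in reversed(range(len(r))):
    ret = r[t_] + gamma * rho * ret
    rho = ratio[t_]
    rets[t_] = ret
assert np.allclose(rets, [direct(t) for t in range(len(r))])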
Example #15
    def loss(
        self,
        model: ModelV2,
        dist_class: Type[ActionDistribution],
        train_batch: SampleBatch,
    ) -> Union[TensorType, List[TensorType]]:
        model_out, _ = model(train_batch)
        action_dist = dist_class(model_out, model)

        if isinstance(self.action_space, gym.spaces.Discrete):
            is_multidiscrete = False
            output_hidden_shape = [self.action_space.n]
        elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
            is_multidiscrete = True
            output_hidden_shape = self.action_space.nvec.astype(np.int32)
        else:
            is_multidiscrete = False
            output_hidden_shape = 1

        def _make_time_major(*args, **kw):
            return make_time_major(self, train_batch.get(SampleBatch.SEQ_LENS),
                                   *args, **kw)

        actions = train_batch[SampleBatch.ACTIONS]
        dones = train_batch[SampleBatch.DONES]
        rewards = train_batch[SampleBatch.REWARDS]
        behaviour_action_logp = train_batch[SampleBatch.ACTION_LOGP]
        behaviour_logits = train_batch[SampleBatch.ACTION_DIST_INPUTS]
        if isinstance(output_hidden_shape, (list, tuple, np.ndarray)):
            unpacked_behaviour_logits = torch.split(behaviour_logits,
                                                    list(output_hidden_shape),
                                                    dim=1)
            unpacked_outputs = torch.split(model_out,
                                           list(output_hidden_shape),
                                           dim=1)
        else:
            unpacked_behaviour_logits = torch.chunk(behaviour_logits,
                                                    output_hidden_shape,
                                                    dim=1)
            unpacked_outputs = torch.chunk(model_out,
                                           output_hidden_shape,
                                           dim=1)
        values = model.value_function()

        if self.is_recurrent():
            max_seq_len = torch.max(train_batch[SampleBatch.SEQ_LENS])
            mask_orig = sequence_mask(train_batch[SampleBatch.SEQ_LENS],
                                      max_seq_len)
            mask = torch.reshape(mask_orig, [-1])
        else:
            mask = torch.ones_like(rewards)

        # Prepare actions for loss.
        loss_actions = actions if is_multidiscrete else torch.unsqueeze(
            actions, dim=1)

        # Inputs are reshaped from [B * T] => [(T|T-1), B] for V-trace calc.
        drop_last = self.config["vtrace_drop_last_ts"]
        loss = VTraceLoss(
            actions=_make_time_major(loss_actions, drop_last=drop_last),
            actions_logp=_make_time_major(action_dist.logp(actions),
                                          drop_last=drop_last),
            actions_entropy=_make_time_major(action_dist.entropy(),
                                             drop_last=drop_last),
            dones=_make_time_major(dones, drop_last=drop_last),
            behaviour_action_logp=_make_time_major(behaviour_action_logp,
                                                   drop_last=drop_last),
            behaviour_logits=_make_time_major(unpacked_behaviour_logits,
                                              drop_last=drop_last),
            target_logits=_make_time_major(unpacked_outputs,
                                           drop_last=drop_last),
            discount=self.config["gamma"],
            rewards=_make_time_major(rewards, drop_last=drop_last),
            values=_make_time_major(values, drop_last=drop_last),
            bootstrap_value=_make_time_major(values)[-1],
            dist_class=TorchCategorical if is_multidiscrete else dist_class,
            model=model,
            valid_mask=_make_time_major(mask, drop_last=drop_last),
            config=self.config,
            vf_loss_coeff=self.config["vf_loss_coeff"],
            entropy_coeff=self.entropy_coeff,
            clip_rho_threshold=self.config["vtrace_clip_rho_threshold"],
            clip_pg_rho_threshold=self.config["vtrace_clip_pg_rho_threshold"],
        )

        # Store values for stats function in model (tower), such that for
        # multi-GPU, we do not override them during the parallel loss phase.
        model.tower_stats["pi_loss"] = loss.pi_loss
        model.tower_stats["vf_loss"] = loss.vf_loss
        model.tower_stats["entropy"] = loss.entropy
        model.tower_stats["mean_entropy"] = loss.mean_entropy
        model.tower_stats["total_loss"] = loss.total_loss

        values_batched = make_time_major(
            self,
            train_batch.get(SampleBatch.SEQ_LENS),
            values,
            drop_last=self.config["vtrace"] and drop_last,
        )
        model.tower_stats["vf_explained_var"] = explained_variance(
            torch.reshape(loss.value_targets, [-1]),
            torch.reshape(values_batched, [-1]))

        return loss.total_loss