Example #1
    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model

        Args:
            user: User embedding of shape (batch_size, user embedding size).
                Note that `self.embedding_size` is the sum of both user- and
                doc-embedding size.
            doc: Doc embeddings of shape (batch_size, num_docs, doc embedding size).
                Note that `self.embedding_size` is the sum of both user- and
                doc-embedding size.

        Returns:
            The q_values per document of shape (batch_size, num_docs + 1). +1 due to
            also having a Q-value for the non-interaction (no click/no doc).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))

        # Concat everything.
        # No user features.
        if user.shape[-1] == 0:
            x = doc_flat
        # User features: repeat each user embedding `num_docs` times so that
        # row i of `user_repeated` lines up with row i of `doc_flat`.
        else:
            user_repeated = user.repeat_interleave(num_docs, dim=0)
            x = torch.cat([user_repeated, doc_flat], dim=1)

        x = self.layers(x)

        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        # See https://arxiv.org/abs/1905.12767 for details.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)

        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)
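As a rough shape walk-through of the forward pass above, here is a minimal sketch with illustrative sizes only; `layers` is a stand-in for `self.layers` (any MLP ending in a single output unit):

import torch

# Illustrative sizes, not from the source.
batch_size, num_docs, user_dim, doc_dim = 2, 3, 4, 5

user = torch.randn(batch_size, user_dim)
doc = torch.randn(batch_size, num_docs, doc_dim)

doc_flat = doc.view(batch_size * num_docs, doc_dim)       # (6, 5)
user_repeated = user.repeat_interleave(num_docs, dim=0)   # (6, 4), aligned with doc_flat
x = torch.cat([user_repeated, doc_flat], dim=1)           # (6, 9)

layers = torch.nn.Linear(user_dim + doc_dim, 1)           # one scalar Q per user-doc pair
q_per_doc = layers(x).view(batch_size, num_docs)          # (2, 3)
q_no_click = torch.zeros(batch_size, 1)
q_values = torch.cat([q_per_doc, q_no_click], dim=1)      # (2, 4) = (batch, num_docs + 1)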
Example #2
 def write(self, observation: TensorType, array: np.ndarray,
           offset: int) -> None:
     if not isinstance(observation, OrderedDict):
         observation = OrderedDict(sorted(observation.items()))
     assert len(observation) == len(self.preprocessors), \
         (len(observation), len(self.preprocessors))
     for o, p in zip(observation.values(), self.preprocessors):
         p.write(o, array, offset)
         offset += p.size
Example #3
 def logp(self, actions: TensorType) -> TensorType:
     # If tensor is provided, unstack it into list.
     if isinstance(actions, tf.Tensor):
         if isinstance(self.action_space, gym.spaces.Box):
             actions = tf.reshape(
                 actions, [-1, int(np.prod(self.action_space.shape))])
         elif isinstance(self.action_space, gym.spaces.MultiDiscrete):
             actions.set_shape((None, len(self.cats)))
         actions = tf.unstack(tf.cast(actions, tf.int32), axis=1)
     logps = tf.stack(
         [cat.logp(act) for cat, act in zip(self.cats, actions)])
     return tf.reduce_sum(logps, axis=0)
Example #4
    def forward(self, inputs: TensorType,
                memory: TensorType = None) -> TensorType:
        T = list(inputs.size())[1]  # length of segment (time)
        H = self._num_heads  # number of attention heads
        d = self._head_dim  # attention head dimension

        # Add previous memory chunk (as const, w/o gradient) to input.
        # Tau (number of (prev) time slices in each memory chunk).
        Tau = list(memory.shape)[1] if memory is not None else 0
        if memory is not None:
            memory.requires_grad_(False)
            inputs = torch.cat((memory, inputs), dim=1)

        # Apply the Layer-Norm.
        if self._input_layernorm is not None:
            inputs = self._input_layernorm(inputs)

        qkv = self._qkv_layer(inputs)

        queries, keys, values = torch.chunk(input=qkv, chunks=3, dim=-1)
        # Cut out Tau memory timesteps from query.
        queries = queries[:, -T:]

        queries = torch.reshape(queries, [-1, T, H, d])
        keys = torch.reshape(keys, [-1, T + Tau, H, d])
        values = torch.reshape(values, [-1, T + Tau, H, d])

        R = self._pos_proj(self._rel_pos_encoder)
        R = torch.reshape(R, [T + Tau, H, d])

        # b=batch
        # i and j=time indices (i=max-timesteps (inputs); j=Tau memory space)
        # h=head
        # d=head-dim (over which we will reduce-sum)
        score = torch.einsum("bihd,bjhd->bijh", queries + self._uvar, keys)
        pos_score = torch.einsum("bihd,jhd->bijh", queries + self._vvar, R)
        score = score + self.rel_shift(pos_score)
        score = score / d**0.5

        # causal mask of the same length as the sequence
        mask = sequence_mask(
            torch.arange(Tau + 1, T + Tau + 1), dtype=score.dtype)
        mask = mask[None, :, :, None]

        masked_score = score * mask + 1e30 * (mask.to(torch.float32) - 1.)
        wmat = nn.functional.softmax(masked_score, dim=2)

        out = torch.einsum("bijh,bjhd->bihd", wmat, values)
        shape = list(out.shape)[:2] + [H * d]
        out = torch.reshape(out, shape)

        return self._linear_layer(out)
Example #5
def one_hot(x: TensorType, space: gym.Space) -> TensorType:
    """Returns a one-hot tensor, given and int tensor and a space.

    Handles the MultiDiscrete case as well.

    Args:
        x: The input tensor.
        space: The space to use for generating the one-hot tensor.

    Returns:
        The resulting one-hot tensor.

    Raises:
        ValueError: If the given space is not a discrete one.

    Examples:
        >>> import torch
        >>> import gym
        >>> from ray.rllib.utils.torch_utils import one_hot
        >>> x = torch.IntTensor([0, 3])  # batch-dim=2
        >>> # Discrete space with 4 (one-hot) slots per batch item.
        >>> s = gym.spaces.Discrete(4)
        >>> one_hot(x, s) # doctest: +SKIP
        tensor([[1, 0, 0, 0], [0, 0, 0, 1]])
        >>> x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
        >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
        >>> # per batch item.
        >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
        >>> one_hot(x, s) # doctest: +SKIP
        tensor([[1, 0, 0, 0, 0,
                 0, 1, 0, 0,
                 0, 0, 1, 0,
                 0, 0, 0, 1, 0, 0, 0]])
    """
    if isinstance(space, Discrete):
        return nn.functional.one_hot(x.long(), space.n)
    elif isinstance(space, MultiDiscrete):
        if isinstance(space.nvec[0], np.ndarray):
            nvec = np.ravel(space.nvec)
            x = x.reshape(x.shape[0], -1)
        else:
            nvec = space.nvec
        return torch.cat(
            [
                nn.functional.one_hot(x[:, i].long(), n)
                for i, n in enumerate(nvec)
            ],
            dim=-1,
        )
    else:
        raise ValueError("Unsupported space for `one_hot`: {}".format(space))
Example #6
 def _value(self, t: TensorType) -> TensorType:
     """Returns the result of: initial_p * decay_rate ** (`t`/t_max).
     """
     if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
         t = t.float()
     return self.initial_p * \
         self.decay_rate ** (t / self.schedule_timesteps)
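For intuition, a quick numeric check of the formula with arbitrary example values (not taken from the source):

# Arbitrary example values: initial_p=1.0, decay_rate=0.1, t_max=100.
initial_p, decay_rate, schedule_timesteps = 1.0, 0.1, 100
t = 50
print(initial_p * decay_rate ** (t / schedule_timesteps))  # 1.0 * 0.1**0.5 ~= 0.316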
Example #7
    def get_policy_output(self, model_out: TensorType) -> TensorType:
        """Returns policy outputs, given the output of self.__call__().

        For continuous action spaces, these will be the mean/stddev
        distribution inputs for the (SquashedGaussian) action distribution.
        For discrete action spaces, these will be the logits for a categorical
        distribution.

        Args:
            model_out (TensorType): Feature outputs from the model layers
                (result of doing `self.__call__(obs)`).

        Returns:
            TensorType: Distribution inputs for sampling actions.
        """
        # Model outs may come as original Tuple/Dict observations, concat them
        # here if this is the case.
        if isinstance(self.action_model.obs_space, Box):
            if isinstance(model_out, (list, tuple)):
                model_out = tf.concat(model_out, axis=-1)
            elif isinstance(model_out, dict):
                model_out = tf.concat([
                    tf.expand_dims(val, 1) if len(val.shape) == 1 else val
                    for val in tree.flatten(model_out.values())
                ],
                                      axis=-1)
        out, _ = self.action_model({"obs": model_out}, [], None)
        return out
Example #8
    def get_ucbs(self, x: TensorType):
        """Calculate upper confidence bounds using covariance matrix according
        to algorithm 1: LinUCB
        (http://proceedings.mlr.press/v15/chu11a/chu11a.pdf).

        Args:
            x: Input feature tensor of shape
                (batch_size, [num_items]?, feature_dim)
        """
        # Fold batch and num-items dimensions into one dim.
        if len(x.shape) == 3:
            B, C, F = x.shape
            x_folded_batch = x.reshape([-1, F])
        # Only batch and feature dims.
        else:
            x_folded_batch = x

        projections = self.covariance @ x_folded_batch.T
        batch_dots = (x_folded_batch * projections.T).sum(dim=-1)
        batch_dots = batch_dots.sqrt()

        # Restore original B and C dimensions.
        if len(x.shape) == 3:
            batch_dots = batch_dots.reshape([B, C])
        return batch_dots
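The two matrix operations above compute, per folded row x_i, the quadratic form sqrt(x_i^T Sigma x_i) with the stored covariance. A minimal sketch verifying that equivalence with made-up shapes; `covariance` stands in for `self.covariance`:

import torch

# Made-up shapes: 4 folded rows with feature_dim=3.
covariance = torch.eye(3) * 2.0
x_folded_batch = torch.randn(4, 3)

projections = covariance @ x_folded_batch.T                       # (3, 4)
batch_dots = (x_folded_batch * projections.T).sum(dim=-1).sqrt()  # (4,)

# Same quantity computed row by row: sqrt(x_i^T @ covariance @ x_i).
reference = torch.stack(
    [torch.sqrt(row @ covariance @ row) for row in x_folded_batch])
assert torch.allclose(batch_dots, reference)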
Example #9
    def _get_q_value(
        self,
        model_out: TensorType,
        actions,
        net,
        state_in: List[TensorType],
        seq_lens: TensorType,
    ) -> (TensorType, List[TensorType]):
        # Continuous case -> concat actions to model_out.
        if (actions is not None
                and model_out.get("obs_and_action_concatenated") is not True):
            # Make sure that, if we call this method twice with the same
            # input, we don't concatenate twice.
            model_out["obs_and_action_concatenated"] = True

            if self.concat_obs_and_actions:
                model_out[SampleBatch.OBS] = torch.cat(
                    [model_out[SampleBatch.OBS], actions], dim=-1)
            else:
                model_out[SampleBatch.OBS] = force_list(
                    model_out[SampleBatch.OBS]) + [actions]

        # Switch on training mode (when getting Q-values, we are usually in
        # training).
        model_out["is_training"] = True

        out, state_out = net(model_out, state_in, seq_lens)
        return out, state_out
Example #10
def _repeat_tensor(t: TensorType, n: int):
    # Insert a new dimension at position 1 into tensor t.
    t_rep = t.unsqueeze(1)
    # Repeat tensor t_rep along new dimension n times
    t_rep = torch.repeat_interleave(t_rep, n, dim=1)
    # Merge new dimension into batch dimension
    t_rep = t_rep.view(-1, *t.shape[1:])
    return t_rep
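A small usage sketch of the helper above, with arbitrarily chosen shapes:

import torch

t = torch.tensor([[1., 2.], [3., 4.]])  # shape (2, 2)
t_rep = _repeat_tensor(t, 3)             # shape (6, 2)
# Each row is repeated consecutively:
# [[1., 2.], [1., 2.], [1., 2.], [3., 4.], [3., 4.], [3., 4.]]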
Example #11
 def _value(self, t: TensorType) -> TensorType:
     """Returns the result of:
     final_p + (initial_p - final_p) * (1 - `t`/t_max) ** power
     """
     if self.framework == "torch" and torch and isinstance(t, torch.Tensor):
         t = t.float()
     t = min(t, self.schedule_timesteps)
     return self.final_p + (self.initial_p - self.final_p) * (
         1.0 - (t / self.schedule_timesteps))**self.power
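Again for intuition only, plugging arbitrary example values (not from the source) into the formula:

# Arbitrary example values: initial_p=1.0, final_p=0.1, t_max=100, power=2.
initial_p, final_p, schedule_timesteps, power = 1.0, 0.1, 100, 2
t = min(50, schedule_timesteps)
print(final_p + (initial_p - final_p) * (1.0 - t / schedule_timesteps) ** power)
# 0.1 + 0.9 * 0.5**2 = 0.325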
Example #12
def add_time_dimension(
    padded_inputs: TensorType,
    *,
    max_seq_len: int,
    framework: str = "tf",
    time_major: bool = False,
):
    """Adds a time dimension to padded inputs.

    Args:
        padded_inputs (TensorType): a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        max_seq_len (int): The max. sequence length in padded_inputs.
        framework (str): The framework string ("tf2", "tf", "tfe", "torch").
        time_major (bool): Whether data should be returned in time-major (TxB)
            format or not (BxT).

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf", "tfe"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        new_shape = tf.squeeze(
            tf.stack(
                [
                    tf.expand_dims(new_batch_size, axis=0),
                    tf.expand_dims(max_seq_len, axis=0),
                    tf.shape(padded_inputs)[1:],
                ],
                axis=0,
            ))
        ret = tf.reshape(padded_inputs, new_shape)
        ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
        return ret
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        batch_major_shape = (new_batch_size,
                             max_seq_len) + padded_inputs.shape[1:]
        padded_outputs = padded_inputs.view(batch_major_shape)

        if time_major:
            # Swap the batch and time dimensions
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
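A minimal usage sketch of the torch branch, assuming a hypothetical padded batch of 3 sequences padded to max_seq_len=2 (3 * 2 = 6 rows of 4 features each):

import torch

padded = torch.arange(24, dtype=torch.float32).reshape(6, 4)
out = add_time_dimension(padded, max_seq_len=2, framework="torch")
print(out.shape)  # torch.Size([3, 2, 4]) -> [B, T, ...]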
Example #13
def linear(x: TensorType,
           size: int,
           name: str,
           initializer: Optional[Any] = None,
           bias_init: float = 0.0) -> TensorType:
    w = tf1.get_variable(
        name + "/w", [x.get_shape()[1], size], initializer=initializer)
    b = tf1.get_variable(
        name + "/b", [size], initializer=tf1.constant_initializer(bias_init))
    return tf.matmul(x, w) + b
Example #14
    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
        """
        batch_size, num_docs, embedding_size = doc.shape
        doc_flat = doc.view((batch_size * num_docs, embedding_size))
        # Repeat each user embedding `num_docs` times so rows line up with `doc_flat`.
        user_repeated = user.repeat_interleave(num_docs, dim=0)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        x_no_click = torch.zeros((batch_size, 1), device=x.device)
        return torch.cat([x.view((batch_size, num_docs)), x_no_click], dim=1)
Example #15
    def representation_function(self, obs: TensorType) -> TensorType:
        obs = obs.float().permute(0, 3, 1, 2)
        output = self.representation(obs)
        self.hidden = output

        if not self.cache:
            self.cache = [self.hidden] * self.order
        else:
            self.cache.append(self.hidden)
            self.cache.pop(0)

        return output
Example #16
    def forward_rnn(self, inputs: TensorType, state: List[TensorType],
                    seq_lens: TensorType) -> (TensorType, List[TensorType]):
        # To make Attention work with current RLlib's ModelV2 API:
        # We assume `state` is the history of L recent observations (all
        # concatenated into one tensor) and append the current inputs to the
        # end and only keep the most recent (up to `max_seq_len`). This allows
        # us to deal with timestep-wise inference and full sequence training
        # within the same logic.
        state = [torch.from_numpy(item) for item in state]
        observations = state[0]
        memory = state[1:]

        inputs = torch.reshape(inputs, [1, -1, observations.shape[-1]])
        observations = torch.cat((observations, inputs),
                                 axis=1)[:, -self.max_seq_len:]

        all_out = observations
        for i in range(len(self.layers)):
            # MHA layers which need memory passed in.
            if i % 2 == 1:
                all_out = self.layers[i](all_out, memory=memory[i // 2])
            # Either linear layers or MultiLayerPerceptrons.
            else:
                all_out = self.layers[i](all_out)

        logits = self.logits(all_out)
        self._value_out = self.values_out(all_out)

        memory_outs = all_out[2:]
        # If memory_tau > max_seq_len -> overlap w/ previous `memory` input.
        if self.memory_tau > self.max_seq_len:
            memory_outs = [
                torch.cat(
                    [memory[i][:, -(self.memory_tau - self.max_seq_len):], m],
                    axis=1) for i, m in enumerate(memory_outs)
            ]
        else:
            memory_outs = [m[:, -self.memory_tau:] for m in memory_outs]

        T = list(inputs.size())[1]  # Length of input segment (time).

        # Postprocessing final output.
        logits = logits[:, -T:]
        self._value_out = self._value_out[:, -T:]

        return logits, [observations] + memory_outs
Example #17
File: misc.py  Project: smorad/ray
def conv2d(
    x: TensorType,
    num_filters: int,
    name: str,
    filter_size: Tuple[int, int] = (3, 3),
    stride: Tuple[int, int] = (1, 1),
    pad: str = "SAME",
    dtype: Optional[Any] = None,
    collections: Optional[Any] = None,
) -> TensorType:
    if dtype is None:
        dtype = tf.float32

    with tf1.variable_scope(name):
        stride_shape = [1, stride[0], stride[1], 1]
        filter_shape = [
            filter_size[0],
            filter_size[1],
            int(x.get_shape()[3]),
            num_filters,
        ]

        # There are "num input feature maps * filter height * filter width"
        # inputs to each hidden unit.
        fan_in = np.prod(filter_shape[:3])
        # Each unit in the lower layer receives a gradient from: "num output
        # feature maps * filter height * filter width" / pooling size.
        fan_out = np.prod(filter_shape[:2]) * num_filters
        # Initialize weights with random weights.
        w_bound = np.sqrt(6 / (fan_in + fan_out))

        w = tf1.get_variable(
            "W",
            filter_shape,
            dtype,
            tf1.random_uniform_initializer(-w_bound, w_bound),
            collections=collections,
        )
        b = tf1.get_variable(
            "b",
            [1, 1, 1, num_filters],
            initializer=tf1.constant_initializer(0.0),
            collections=collections,
        )
        return tf1.nn.conv2d(x, w, stride_shape, pad) + b
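For reference, the fan-in/fan-out arithmetic behind the Glorot-style uniform bound above, with illustrative filter sizes (not from the source):

import numpy as np

# Illustrative setup: 3x3 kernel, 4 input channels, 8 output filters.
filter_shape = [3, 3, 4, 8]
fan_in = np.prod(filter_shape[:3])          # 3 * 3 * 4 = 36
fan_out = np.prod(filter_shape[:2]) * 8     # 3 * 3 * 8 = 72
w_bound = np.sqrt(6 / (fan_in + fan_out))   # sqrt(6 / 108) ~= 0.236
print(fan_in, fan_out, w_bound)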
Example #18
def add_time_dimension(padded_inputs: TensorType,
                       *,
                       max_seq_len: int,
                       framework: str = "tf",
                       time_major: bool = False):
    """Adds a time dimension to padded inputs.

    Args:
        padded_inputs (TensorType): a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        max_seq_len (int): The max. sequence length in padded_inputs.
        framework (str): The framework string ("tf2", "tf", "tfe", "torch").
        time_major (bool): Whether data should be returned in time-major (TxB)
            format or not (BxT).

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf", "tfe"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        new_shape = ([new_batch_size, max_seq_len] +
                     padded_inputs.get_shape().as_list()[1:])
        return tf.reshape(padded_inputs, new_shape)
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        new_shape = (new_batch_size, max_seq_len) + padded_inputs.shape[1:]
        padded_outputs = torch.reshape(padded_inputs, new_shape)
        if time_major:
            # Inputs are padded batch-major, so build [B, T, ...] first and
            # then swap the batch and time dimensions.
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
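A minimal sketch of the time-major path, reusing the same hypothetical padded batch of 3 sequences of length 2 as in the earlier variant:

import torch

padded = torch.arange(24, dtype=torch.float32).reshape(6, 4)
out = add_time_dimension(
    padded, max_seq_len=2, framework="torch", time_major=True)
print(out.shape)  # torch.Size([2, 3, 4]) -> [T, B, ...]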
Example #19
def sequence_mask(
    lengths: TensorType,
    maxlen: Optional[int] = None,
    dtype=None,
    time_major: bool = False,
) -> TensorType:
    """Offers same behavior as tf.sequence_mask for torch.

    Thanks to Dimitris Papatheodorou
    (https://discuss.pytorch.org/t/pytorch-equivalent-for-tf-sequence-mask/
    39036).

    Args:
        lengths: The tensor of individual lengths to mask by.
        maxlen: The maximum length to use for the time axis. If None, use
            the max of `lengths`.
        dtype: The torch dtype to use for the resulting mask.
        time_major: Whether to return the mask as [B, T] (False; default) or
            as [T, B] (True).

    Returns:
         The sequence mask resulting from the given input and parameters.
    """
    # If maxlen not given, use the longest lengths in the `lengths` tensor.
    if maxlen is None:
        maxlen = int(lengths.max())

    mask = ~(
        torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t()
        > lengths
    )
    # Time major transformation.
    if not time_major:
        mask = mask.t()

    # By default, set the mask to be boolean.
    mask = mask.type(dtype or torch.bool)

    return mask
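A small usage sketch with made-up lengths:

import torch

lengths = torch.tensor([1, 3, 2])
print(sequence_mask(lengths, maxlen=4))
# tensor([[ True, False, False, False],
#         [ True,  True,  True, False],
#         [ True,  True, False, False]])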
Example #20
 def _hidden_layers(self, obs: TensorType) -> TensorType:
     res = self._convs(obs.permute(0, 3, 1, 2))  # switch to channel-major
     # res = res.squeeze(3)
     res = res.squeeze(2)
     return res
Example #21
 def forward(self, x: TensorType) -> TensorType:
     if x.dim() == 4:
         x = torch.squeeze(x, dim=2)
     return self._model(x)
Example #22
File: misc.py  Project: smorad/ray
def flatten(x: TensorType) -> TensorType:
    return tf.reshape(x, [-1, np.prod(x.get_shape().as_list()[1:])])
Example #23
 def representation_function(self, obs: TensorType) -> TensorType:
     obs = obs.float().permute(0, 3, 1, 2)
     output = self.representation(obs)
     self.hidden = output
     return output
Example #24
    def __init__(
        self,
        q_t_selected: TensorType,
        q_logits_t_selected: TensorType,
        q_tp1_best: TensorType,
        q_probs_tp1_best: TensorType,
        importance_weights: TensorType,
        rewards: TensorType,
        done_mask: TensorType,
        gamma=0.99,
        n_step=1,
        num_atoms=1,
        v_min=-10.0,
        v_max=10.0,
    ):

        if num_atoms > 1:
            # Distributional Q-learning which corresponds to an entropy loss
            z = torch.arange(
                0.0, num_atoms, dtype=torch.float32).to(rewards.device)
            z = v_min + z * (v_max - v_min) / float(num_atoms - 1)

            # (batch_size, 1) * (1, num_atoms) = (batch_size, num_atoms)
            r_tau = torch.unsqueeze(
                rewards, -1) + gamma**n_step * torch.unsqueeze(
                    1.0 - done_mask, -1) * torch.unsqueeze(z, 0)
            r_tau = torch.clamp(r_tau, v_min, v_max)
            b = (r_tau - v_min) / ((v_max - v_min) / float(num_atoms - 1))
            lb = torch.floor(b)
            ub = torch.ceil(b)

            # Indispensable judgement which is missed in most implementations
            # when b happens to be an integer, lb == ub, so pr_j(s', a*) will
            # be discarded because (ub-b) == (b-lb) == 0.
            floor_equal_ceil = ((ub - lb) < 0.5).float()

            # (batch_size, num_atoms, num_atoms)
            l_project = F.one_hot(lb.long(), num_atoms)
            # (batch_size, num_atoms, num_atoms)
            u_project = F.one_hot(ub.long(), num_atoms)
            ml_delta = q_probs_tp1_best * (ub - b + floor_equal_ceil)
            mu_delta = q_probs_tp1_best * (b - lb)
            ml_delta = torch.sum(l_project * torch.unsqueeze(ml_delta, -1),
                                 dim=1)
            mu_delta = torch.sum(u_project * torch.unsqueeze(mu_delta, -1),
                                 dim=1)
            m = ml_delta + mu_delta

            # Rainbow paper claims that using this cross entropy loss for
            # priority is robust and insensitive to `prioritized_replay_alpha`
            self.td_error = softmax_cross_entropy_with_logits(
                logits=q_logits_t_selected, labels=m.detach())
            self.loss = torch.mean(self.td_error * importance_weights)
            self.stats = {
                # TODO: better Q stats for dist dqn
            }
        else:
            q_tp1_best_masked = (1.0 - done_mask) * q_tp1_best

            # compute RHS of bellman equation
            q_t_selected_target = rewards + gamma**n_step * q_tp1_best_masked

            # compute the error (potentially clipped)
            self.td_error = q_t_selected - q_t_selected_target.detach()
            self.loss = torch.mean(importance_weights.float() *
                                   huber_loss(self.td_error))
            self.stats = {
                "mean_q": torch.mean(q_t_selected),
                "min_q": torch.min(q_t_selected),
                "max_q": torch.max(q_t_selected),
            }
Example #25
 def transform(self, observation: TensorType) -> np.ndarray:
     self.check_shape(observation)
     return (observation.astype("float32") - 128) / 128