Example #1
0
    def get_ucbs(self, x: TensorType):
        """Calculate upper confidence bounds using covariance matrix according
        to algorithm 1: LinUCB
        (http://proceedings.mlr.press/v15/chu11a/chu11a.pdf).

        For every feature vector ``x_i`` (a slice along the last axis of
        ``x``) this computes ``sqrt(x_i^T @ self.covariance @ x_i)``. No
        alpha scaling is applied here.

        Args:
            x: Input feature tensor of shape
                (batch_size, [num_items]?, feature_dim). Any number of
                leading dimensions is accepted; the last dimension must be
                the feature dimension matching ``self.covariance``.

        Returns:
            A tensor of shape ``x.shape[:-1]`` (e.g. (batch_size,) or
            (batch_size, num_items)) with one confidence width per feature
            vector.
        """
        leading_shape = x.shape[:-1]
        feature_dim = x.shape[-1]
        # Fold all leading (batch / num-items / ...) dims into one so a single
        # matmul handles every feature vector at once.
        x_folded_batch = x.reshape(-1, feature_dim)

        # Column i of `projections` is covariance @ x_i, so the row-wise dot
        # product with x_i yields the quadratic form x_i^T @ covariance @ x_i.
        projections = self.covariance @ x_folded_batch.T
        batch_dots = (x_folded_batch * projections.T).sum(dim=-1)
        batch_dots = batch_dots.sqrt()

        # Restore the original leading dimensions.
        return batch_dots.reshape(leading_shape)
Example #2
0
def one_hot(x: TensorType, space: gym.Space) -> TensorType:
    """Returns a one-hot tensor, given an int tensor and a space.

    Handles the MultiDiscrete case as well: every sub-space is one-hot
    encoded separately and the encodings are concatenated along the last
    axis. If the MultiDiscrete space's ``nvec`` has more than one dimension,
    it is flattened and ``x`` is reshaped to ``(x.shape[0], -1)`` so that
    column i of ``x`` lines up with the i-th flattened entry of ``nvec``.

    Args:
        x: The input tensor. For a MultiDiscrete space the first dim is
            treated as the batch dim.
        space: The space to use for generating the one-hot tensor.

    Returns:
        The resulting one-hot tensor (built with
        ``torch.nn.functional.one_hot``).

    Raises:
        ValueError: If the given space is not a discrete one.

    Examples:
        >>> import torch
        >>> import gym
        >>> from ray.rllib.utils.torch_utils import one_hot
        >>> x = torch.IntTensor([0, 3])  # batch-dim=2
        >>> # Discrete space with 4 (one-hot) slots per batch item.
        >>> s = gym.spaces.Discrete(4)
        >>> one_hot(x, s) # doctest: +SKIP
        tensor([[1, 0, 0, 0], [0, 0, 0, 1]])
        >>> x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
        >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
        >>> # per batch item.
        >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
        >>> one_hot(x, s) # doctest: +SKIP
        tensor([[1, 0, 0, 0, 0,
                 0, 1, 0, 0,
                 0, 0, 1, 0,
                 0, 0, 0, 1, 0, 0, 0]])
    """
    if isinstance(space, Discrete):
        return nn.functional.one_hot(x.long(), space.n)
    elif isinstance(space, MultiDiscrete):
        # Checking the rank of nvec directly (instead of probing nvec[0])
        # avoids an IndexError on an empty nvec and also covers nested lists.
        if np.ndim(space.nvec) > 1:
            nvec = np.ravel(space.nvec)
            x = x.reshape(x.shape[0], -1)
        else:
            nvec = space.nvec
        return torch.cat(
            [
                nn.functional.one_hot(x[:, i].long(), n)
                for i, n in enumerate(nvec)
            ],
            dim=-1,
        )
    else:
        raise ValueError("Unsupported space for `one_hot`: {}".format(space))