def get_ucbs(self, x: TensorType):
    """Compute the LinUCB confidence width ``sqrt(x^T A x)`` for each input row.

    ``A`` is ``self.covariance``. The computation follows algorithm 1
    (LinUCB) of http://proceedings.mlr.press/v15/chu11a/chu11a.pdf.

    Args:
        x: Input feature tensor of shape
            (batch_size, [num_items]?, feature_dim).

    Returns:
        A tensor of shape (batch_size, num_items) when `x` is 3-D;
        otherwise the feature dimension is reduced away and one value
        per row of `x` is returned.
    """
    # A 3-D input carries an extra num-items axis that is folded into the
    # batch axis for the matrix product and restored at the end.
    has_item_dim = len(x.shape) == 3
    if has_item_dim:
        batch_size, num_items, feature_dim = x.shape
        flat = x.reshape([-1, feature_dim])
    else:
        flat = x

    # Column i of `projected` is A @ x_i; the row-wise dot product with
    # x_i then yields x_i^T A x_i.
    projected = self.covariance @ flat.T
    widths = (flat * projected.T).sum(dim=-1).sqrt()

    return widths.reshape([batch_size, num_items]) if has_item_dim else widths
def one_hot(x: TensorType, space: gym.Space) -> TensorType:
    """Returns a one-hot tensor, given an int tensor and a space.

    For a `Discrete` space, `x` is one-hot encoded with `space.n` slots.
    For a `MultiDiscrete` space, each column of `x` is one-hot encoded
    with its own number of slots and the results are concatenated along
    the last axis. If `space.nvec` is itself multi-dimensional, it is
    flattened and `x` is reshaped to (batch_size, -1) to match.

    Args:
        x: The input tensor.
        space: The space to use for generating the one-hot tensor.

    Returns:
        The resulting one-hot tensor.

    Raises:
        ValueError: If the given space is not a discrete one.

    Examples:
        >>> import torch
        >>> import gym
        >>> from ray.rllib.utils.torch_utils import one_hot
        >>> x = torch.IntTensor([0, 3])  # batch-dim=2
        >>> # Discrete space with 4 (one-hot) slots per batch item.
        >>> s = gym.spaces.Discrete(4)
        >>> one_hot(x, s) # doctest: +SKIP
        tensor([[1, 0, 0, 0], [0, 0, 0, 1]])
        >>> x = torch.IntTensor([[0, 1, 2, 3]])  # batch-dim=1
        >>> # MultiDiscrete space with 5 + 4 + 4 + 7 = 20 (one-hot) slots
        >>> # per batch item.
        >>> s = gym.spaces.MultiDiscrete([5, 4, 4, 7])
        >>> one_hot(x, s) # doctest: +SKIP
        tensor([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0]])
    """
    if isinstance(space, Discrete):
        return nn.functional.one_hot(x.long(), space.n)

    if not isinstance(space, MultiDiscrete):
        raise ValueError("Unsupported space for `one_hot`: {}".format(space))

    slot_counts = space.nvec
    if isinstance(slot_counts[0], np.ndarray):
        # Nested nvec: flatten the sizes and fold all non-batch dims of `x`
        # so that column i of `x` lines up with slot_counts[i].
        slot_counts = np.ravel(slot_counts)
        x = x.reshape(x.shape[0], -1)

    encoded_columns = []
    for column, num_slots in enumerate(slot_counts):
        encoded_columns.append(
            nn.functional.one_hot(x[:, column].long(), num_slots)
        )
    return torch.cat(encoded_columns, dim=-1)