Exemplo n.º 1
0
    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model.

        Args:
            user: User embedding of shape (batch_size, user embedding size).
                The user embedding size may be 0 (no user features); in that
                case only the doc embeddings are fed into `self.layers`.
                Note that `self.embedding_size` is the sum of both user- and
                doc-embedding size.
            doc: Doc embeddings of shape (batch_size, num_docs, doc embedding
                size). Note that `self.embedding_size` is the sum of both
                user- and doc-embedding size.

        Returns:
            The q_values per document of shape (batch_size, num_docs + 1). +1
            due to also having a Q-value for the non-interaction (no click/no
            doc), which is always 0.
        """
        batch_size, num_docs, doc_embedding_size = doc.shape
        # Row i of `doc_flat` holds doc (i % num_docs) of batch item
        # (i // num_docs). `reshape` (unlike `view`) also accepts
        # non-contiguous inputs.
        doc_flat = doc.reshape((batch_size * num_docs, doc_embedding_size))

        if user.shape[-1] == 0:
            # No user features: the model only sees the docs.
            x = doc_flat
        else:
            # Repeat each user embedding num_docs times in place
            # (u0, u0, ..., u1, u1, ...) so it lines up row-by-row with
            # `doc_flat`. `user.repeat(num_docs, 1)` would tile the whole
            # batch instead (u0, u1, ..., u0, u1, ...) and pair users with
            # the wrong docs whenever batch_size > 1.
            user_repeated = user.repeat_interleave(num_docs, dim=0)
            x = torch.cat([user_repeated, doc_flat], dim=1)

        x = self.layers(x)

        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks.
        # See https://arxiv.org/abs/1905.12767 for details.
        # `new_zeros` matches both the device and the dtype of the output.
        x_no_click = x.new_zeros((batch_size, 1))

        return torch.cat([x.reshape((batch_size, num_docs)), x_no_click], dim=1)
Exemplo n.º 2
0
def add_time_dimension(
    padded_inputs: TensorType,
    *,
    max_seq_len: int,
    framework: str = "tf",
    time_major: bool = False,
):
    """Adds a time dimension to padded inputs.

    Args:
        padded_inputs (TensorType): a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        max_seq_len (int): The max. sequence length in padded_inputs.
        framework (str): The framework string ("tf2", "tf", "tfe", "torch").
        time_major (bool): Whether data should be returned in time-major (TxB)
            format or not (BxT). Only supported for framework="torch".

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...], where
            T == max_seq_len and B == padded batch size // max_seq_len.

    Raises:
        AssertionError: If `time_major` is requested for a tf framework, or if
            `framework` is neither a tf variant nor "torch". These checks use
            `assert` and are therefore skipped under `python -O`.
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf", "tfe"]:
        assert time_major is False, "time-major not supported yet for tf!"
        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        # Build the 1-D shape [B, T, *trailing_dims] with `tf.concat`, not
        # `tf.stack`: the pieces have lengths 1, 1 and rank-1, and `tf.stack`
        # requires equal shapes, so it only works for rank-2 inputs.
        new_shape = tf.concat(
            [
                tf.expand_dims(new_batch_size, axis=0),
                tf.expand_dims(max_seq_len, axis=0),
                tf.shape(padded_inputs)[1:],
            ],
            axis=0,
        )
        ret = tf.reshape(padded_inputs, new_shape)
        # Re-attach the statically known trailing dims, which a reshape to a
        # dynamically computed shape may not preserve.
        ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
        return ret
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        # Only dim 0 is split into (B, T), so `view` never needs a copy here.
        new_batch_size = padded_batch_size // max_seq_len
        batch_major_shape = (new_batch_size,
                             max_seq_len) + padded_inputs.shape[1:]
        padded_outputs = padded_inputs.view(batch_major_shape)

        if time_major:
            # Swap the batch and time dimensions
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
Exemplo n.º 3
0
    def forward(self, user: TensorType, doc: TensorType) -> TensorType:
        """Evaluate the user-doc Q model.

        Args:
            user (TensorType): User embedding of shape (batch_size,
                embedding_size).
            doc (TensorType): Doc embeddings of shape (batch_size, num_docs,
                embedding_size).

        Returns:
            score (TensorType): q_values of shape (batch_size, num_docs + 1).
                The last column is the (always zero) no-click Q-value.
        """
        batch_size, num_docs, embedding_size = doc.shape
        # Row i of `doc_flat` holds doc (i % num_docs) of batch item
        # (i // num_docs). `reshape` also accepts non-contiguous inputs.
        doc_flat = doc.reshape((batch_size * num_docs, embedding_size))
        # Repeat each user row num_docs times in place (u0, u0, ..., u1, ...)
        # so it pairs with the matching rows of `doc_flat`. `repeat(num_docs,
        # 1)` would tile the batch (u0, u1, ..., u0, ...) and mismatch users
        # and docs whenever batch_size > 1.
        user_repeated = user.repeat_interleave(num_docs, dim=0)
        x = torch.cat([user_repeated, doc_flat], dim=1)
        x = self.layers(x)
        # Similar to Google's SlateQ implementation in RecSim, we force the
        # Q-values to zeros if there are no clicks. `new_zeros` matches the
        # output's device and dtype.
        x_no_click = x.new_zeros((batch_size, 1))
        return torch.cat([x.reshape((batch_size, num_docs)), x_no_click], dim=1)