def forward(self, user: TensorType, doc: TensorType) -> TensorType:
    """Evaluate the user-doc Q model

    Args:
        user: User embedding of shape (batch_size, user embedding size).
            Note that `self.embedding_size` is the sum of both user- and
            doc-embedding size.
        doc: Doc embeddings of shape (batch_size, num_docs, doc embedding size).
            Note that `self.embedding_size` is the sum of both user- and
            doc-embedding size.

    Returns:
        The q_values per document of shape (batch_size, num_docs + 1). +1 due to
        also having a Q-value for the non-interaction (no click/no doc).
    """
    batch_size, num_docs, embedding_size = doc.shape
    # Batch-major flattening: rows are b0d0, b0d1, ..., b1d0, b1d1, ...
    # `reshape` (unlike `view`) also accepts non-contiguous `doc` tensors.
    doc_flat = doc.reshape((batch_size * num_docs, embedding_size))

    # Concat everything.
    # No user features.
    if user.shape[-1] == 0:
        x = doc_flat
    # User features: repeat each user embedding n times (n=num docs) so that
    # row i of `user_repeated` belongs to the same batch item as row i of
    # `doc_flat` (u0, u0, ..., u1, u1, ...). Note that
    # `user.repeat(num_docs, 1)` would instead tile the whole batch
    # (u0, u1, ..., u0, u1, ...) and pair docs with the wrong users.
    else:
        user_repeated = user.repeat_interleave(num_docs, dim=0)
        x = torch.cat([user_repeated, doc_flat], dim=1)
    x = self.layers(x)

    # Similar to Google's SlateQ implementation in RecSim, we force the
    # Q-values to zeros if there are no clicks.
    # See https://arxiv.org/abs/1905.12767 for details.
    # Match the network output's dtype/device so the concat below is exact.
    x_no_click = torch.zeros((batch_size, 1), dtype=x.dtype, device=x.device)
    return torch.cat([x.reshape((batch_size, num_docs)), x_no_click], dim=1)
def add_time_dimension(
    padded_inputs: TensorType,
    *,
    max_seq_len: int,
    framework: str = "tf",
    time_major: bool = False,
):
    """Adds a time dimension to padded inputs.

    Args:
        padded_inputs (TensorType): a padded batch of sequences. That is,
            for seq_lens=[1, 2, 2], then inputs=[A, *, B, B, C, C], where
            A, B, C are sequence elements and * denotes padding.
        max_seq_len (int): The max. sequence length in padded_inputs.
        framework (str): The framework string ("tf2", "tf", "tfe", "torch").
        time_major (bool): Whether data should be returned in time-major (TxB)
            format or not (BxT).

    Returns:
        TensorType: Reshaped tensor of shape [B, T, ...] or [T, B, ...].
    """

    # Sequence lengths have to be specified for LSTM batch inputs. The
    # input batch must be padded to the max seq length given here. That is,
    # batch_size == len(seq_lens) * max(seq_lens)
    if framework in ["tf2", "tf", "tfe"]:
        assert time_major is False, "time-major not supported yet for tf!"

        padded_batch_size = tf.shape(padded_inputs)[0]
        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        # Build the 1-D target shape [B, T, <trailing dims...>].
        # `tf.concat` works for inputs of any rank; `tf.stack` would require
        # all three parts to have the same shape, which only holds when the
        # input has exactly one trailing dimension.
        new_shape = tf.concat(
            [
                tf.expand_dims(new_batch_size, axis=0),
                tf.expand_dims(max_seq_len, axis=0),
                tf.shape(padded_inputs)[1:],
            ],
            axis=0,
        )
        ret = tf.reshape(padded_inputs, new_shape)
        # Restore the statically known trailing dims lost by dynamic reshape.
        ret.set_shape([None, None] + padded_inputs.shape[1:].as_list())
        return ret
    else:
        assert framework == "torch", "`framework` must be either tf or torch!"
        padded_batch_size = padded_inputs.shape[0]

        # Dynamically reshape the padded batch to introduce a time dimension.
        new_batch_size = padded_batch_size // max_seq_len
        batch_major_shape = (new_batch_size, max_seq_len) + tuple(
            padded_inputs.shape[1:]
        )
        # `reshape` (unlike `view`) also handles non-contiguous inputs,
        # copying only when a view is impossible.
        padded_outputs = padded_inputs.reshape(batch_major_shape)

        if time_major:
            # Swap the batch and time dimensions
            padded_outputs = padded_outputs.transpose(0, 1)
        return padded_outputs
def forward(self, user: TensorType, doc: TensorType) -> TensorType:
    """Evaluate the user-doc Q model

    Args:
        user (TensorType): User embedding of shape
            (batch_size, user embedding size).
        doc (TensorType): Doc embeddings of shape
            (batch_size, num_docs, doc embedding size).

    Returns:
        score (TensorType): q_values of shape (batch_size, num_docs + 1).
            The last column is the (always zero) Q-value of the no-click
            option.
    """
    batch_size, num_docs, embedding_size = doc.shape
    # Batch-major flattening: rows are b0d0, b0d1, ..., b1d0, b1d1, ...
    # `reshape` (unlike `view`) also accepts non-contiguous `doc` tensors.
    doc_flat = doc.reshape((batch_size * num_docs, embedding_size))
    # Repeat each user embedding num_docs times (u0, u0, ..., u1, u1, ...)
    # so row i lines up with row i of `doc_flat`. Note that
    # `user.repeat(num_docs, 1)` would tile the whole batch instead
    # (u0, u1, ..., u0, u1, ...) and pair docs with the wrong users.
    user_repeated = user.repeat_interleave(num_docs, dim=0)
    x = torch.cat([user_repeated, doc_flat], dim=1)
    x = self.layers(x)
    # Similar to Google's SlateQ implementation in RecSim, we force the
    # Q-values to zeros if there are no clicks.
    # Match the network output's dtype/device so the concat below is exact.
    x_no_click = torch.zeros((batch_size, 1), dtype=x.dtype, device=x.device)
    return torch.cat([x.reshape((batch_size, num_docs)), x_no_click], dim=1)