def call(self, inputs: Tuple[tf.Tensor, tf.Tensor]) -> tf.Tensor:
  """Restores listwise shape of flattened_logits.

  Args:
    inputs: A tuple of (flattened_logits, list_mask), which are described
      below.
      * `flattened_logits`: A `Tensor` of predicted logits for each pair of
        query and documents, 1D tensor of shape [batch_size * list_size] or
        2D tensor of shape [batch_size * list_size, 1].
      * `list_mask`: A boolean `Tensor` of shape [batch_size, list_size] to
        mask out the invalid examples.

  Returns:
    A `Tensor` of shape [batch_size, list_size].
    * When `self._by_scatter` is set, each logit is scattered to the position
      given by `utils.padded_nd_indices(list_mask)`. Logits landing on the
      same position are averaged. Positions that receive no logit are filled
      with log(_EPSILON).
    * Otherwise, positions where `list_mask` is False are filled with
      log(_EPSILON).

  Raises:
    ValueError: If `flattened_logits` is not of shape [batch_size * list_size]
      or [batch_size * list_size, 1].
  """
  flattened_logits, list_mask = inputs
  try:
    logits = tf.reshape(flattened_logits, shape=tf.shape(list_mask))
  except (ValueError, tf.errors.InvalidArgumentError) as e:
    # ValueError: statically incompatible shapes at graph-construction time.
    # InvalidArgumentError: element-count mismatch detected when executing
    # eagerly. NOTE: inside a traced graph, a mismatch that is only known at
    # run time surfaces when the op executes, not here.
    raise ValueError('`flattened_logits` needs to be either '
                     '1D of [batch_size * list_size] or '
                     '2D of [batch_size * list_size, 1].') from e
  if self._by_scatter:
    # NOTE(review): assumes the flattened logits were produced from features
    # gathered with the same padded indices (e.g. circular padding in a
    # flatten step), so scattering puts each logit back at its example's
    # position — TODO confirm against the caller.
    nd_indices, _ = utils.padded_nd_indices(is_valid=list_mask)
    # tf.scatter_nd sums values at duplicate indices; counting how many logits
    # land on each position lets us turn the sum into an average.
    counts = tf.scatter_nd(nd_indices, tf.ones_like(logits),
                           tf.shape(list_mask))
    logits = tf.scatter_nd(nd_indices, logits, tf.shape(list_mask))
    return tf.where(
        tf.math.greater(counts, 0.), logits / counts,
        tf.math.log(_EPSILON))
  else:
    return tf.where(list_mask, logits, tf.math.log(_EPSILON))
def call(self, flattened_logits: tf.Tensor, list_mask: tf.Tensor) -> tf.Tensor:
  """Restores listwise shape of flattened_logits.

  Args:
    flattened_logits: A `Tensor` of predicted logits for each pair of query
      and documents, 1D tensor of shape [batch_size * list_size] or 2D tensor
      of shape [batch_size * list_size, 1].
    list_mask: A boolean `Tensor` of shape [batch_size, list_size] to mask out
      the invalid examples.

  Returns:
    A `Tensor` of shape [batch_size, list_size].
    * When `self._by_scatter` is set, each logit is scattered to the position
      given by `utils.padded_nd_indices(list_mask)`. Logits landing on the
      same position are averaged. Positions that receive no logit are filled
      with log(_EPSILON).
    * Otherwise, positions where `list_mask` is False are filled with
      log(_EPSILON).

  Raises:
    ValueError: An error if the shape of `flattened_logits` is neither 1D nor
      2D with shape [batch_size * list_size, 1].
  """
  try:
    logits = tf.reshape(flattened_logits, shape=tf.shape(list_mask))
  except (ValueError, tf.errors.InvalidArgumentError) as e:
    # ValueError: statically incompatible shapes at graph-construction time.
    # InvalidArgumentError: element-count mismatch detected when executing
    # eagerly. NOTE: inside a traced graph, a mismatch that is only known at
    # run time surfaces when the op executes, not here.
    raise ValueError('`flattened_logits` needs to be either '
                     '1D of [batch_size * list_size] or '
                     '2D of [batch_size * list_size, 1].') from e
  if self._by_scatter:
    # NOTE(review): assumes the flattened logits were produced from features
    # gathered with the same padded indices, so scattering puts each logit
    # back at its example's position — TODO confirm against the caller.
    nd_indices, _ = utils.padded_nd_indices(is_valid=list_mask)
    # tf.scatter_nd sums values at duplicate indices; dividing by the
    # per-position count turns that sum into an average.
    counts = tf.scatter_nd(nd_indices, tf.ones_like(logits),
                           tf.shape(list_mask))
    logits = tf.scatter_nd(nd_indices, logits, tf.shape(list_mask))
    return tf.where(
        tf.math.greater(counts, 0.), logits / counts,
        tf.math.log(_EPSILON))
  else:
    return tf.where(list_mask, logits, tf.math.log(_EPSILON))
def test_padded_nd_indices(self):
  # is_valid and the returned mask are [batch_size, list_size] = [2, 3];
  # the returned indices are [batch_size, list_size, 2].
  is_valid = [[True, True, True], [True, True, False]]
  expected_mask = [[True, True, True], [True, True, False]]
  # Without shuffling, the invalid slot [1, 2] is filled with index [1, 0].
  expected_unshuffled_indices = [
      [[0, 0], [0, 1], [0, 2]],
      [[1, 0], [1, 1], [1, 0]],
  ]
  # With shuffling (graph seed 1, op seed 87124), the second list's valid
  # indices come out as [1, 1], [1, 0], and the padding slot repeats [1, 1].
  expected_shuffled_indices = [
      [[0, 0], [0, 1], [0, 2]],
      [[1, 1], [1, 0], [1, 1]],
  ]

  with tf.Graph().as_default():
    tf.compat.v1.set_random_seed(1)

    # Disable shuffling.
    indices_op, mask_op = utils.padded_nd_indices(is_valid, shuffle=False)
    with tf.compat.v1.Session() as sess:
      indices_val, mask_val = sess.run([indices_op, mask_op])
    self.assertAllEqual(indices_val, expected_unshuffled_indices)
    self.assertAllEqual(mask_val, expected_mask)

    # Enable shuffling.
    indices_op, mask_op = utils.padded_nd_indices(
        is_valid, shuffle=True, seed=87124)
    with tf.compat.v1.Session() as sess:
      indices_val, mask_val = sess.run([indices_op, mask_op])
    self.assertAllEqual(indices_val, expected_shuffled_indices)
    self.assertAllEqual(mask_val, expected_mask)
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=True,
                   mask=None):
  """Scores context and examples to return a score per example.

  Args:
    context_features: (dict) context feature names to tensors of shape
      [batch_size, ...]. None is treated as an empty dict.
    example_features: (dict) example feature names to tensors of shape
      [batch_size, list_size, ...]. Must be non-empty.
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one. If mask is None,
      all entries are valid.

  Returns:
    (tf.Tensor) A logits tensor of shape [batch_size, list_size]. Entries
    where the padded mask is False are set to zero.

  Raises:
    ValueError: If `example_features` is None or an empty dict.
  """
  if not example_features:
    raise ValueError('Need a valid example feature.')
  context_features = context_features or {}

  first_example_tensor = next(six.itervalues(example_features))
  batch_size = tf.shape(first_example_tensor)[0]
  list_size = tf.shape(first_example_tensor)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

  # Expand context features to [batch_size, list_size, ...], then flatten the
  # first two dims into one "large batch" of size batch_size * list_size.
  large_batch_context_features = {}
  for name, tensor in six.iteritems(context_features):
    x = tf.expand_dims(input=tensor, axis=1)
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    large_batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [batch_size * list_size])

  large_batch_example_features = {}
  for name, tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones.
    padded_tensor = tf.gather_nd(tensor, nd_indices)
    large_batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [batch_size * list_size])

  # Every scoring input must share the leading batch_size * list_size dim so
  # they can be concatenated on axis 1. Dense context features therefore come
  # from the expanded dict, not from the raw [batch_size, ...] inputs.
  # Context features are visited before example features, as before.
  sparse_input, dense_input = [], []
  for features in (large_batch_context_features,
                   large_batch_example_features):
    for name, value in six.iteritems(features):
      if name in self._sparse_embed_layers:
        sparse_input.append(self._sparse_embed_layers[name](value))
      else:
        dense_input.append(value)
  sparse_input = [tf.keras.layers.Flatten()(inpt) for inpt in sparse_input]
  inputs = tf.concat(sparse_input + dense_input, 1)

  # Get scores for the large batch.
  outputs = inputs
  for layer in self._scoring_layers:
    outputs = layer(outputs, training=training)
  scores = self._output_score_layer(outputs, training=training)
  logits = tf.reshape(scores, shape=[batch_size, list_size])

  # Apply nd_mask to zero out invalid entries.
  logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
  return logits
def call(
    self,
    inputs: Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor], tf.Tensor]
) -> Tuple[Dict[str, tf.Tensor], Dict[str, tf.Tensor]]:
  """Flattens context and example features into one large batch.

  Args:
    inputs: A tuple of (context_features, example_features, list_mask):
      * `context_features`: A map of context features to 2D tensors of shape
        [batch_size, feature_dim].
      * `example_features`: A map of example features to 3D tensors of shape
        [batch_size, list_size, feature_dim].
      * `list_mask`: A Tensor of shape [batch_size, list_size] to mask out the
        invalid examples.

  Returns:
    A tuple of (flattened_context_features, flattened_example_features).
    * `flattened_context_features` maps each context feature to a tensor of
      shape [batch_size * list_size, feature_dim]. Each query's context is
      repeated list_size times.
    * `flattened_example_features` maps each example feature to a tensor of
      shape [batch_size * list_size, feature_dim].
    When `self._circular_padding` is set, invalid examples are replaced by
    valid ones before flattening.

  Raises:
    ValueError: If `example_features` is empty dict or None.
  """
  context_features, example_features, list_mask = inputs
  if not example_features:
    raise ValueError('Need a valid example feature.')

  mask_shape = tf.shape(list_mask)
  batch_size, list_size = mask_shape[0], mask_shape[1]

  def _flatten(tensor):
    # Merge the leading [batch_size, list_size] dims into one.
    return utils.reshape_first_ndims(tensor, 2, [batch_size * list_size])

  # Repeat every context feature along a new list axis, then flatten.
  flattened_context_features = {
      name: _flatten(
          tf.repeat(
              tf.expand_dims(tensor, axis=1), repeats=[list_size], axis=1))
      for name, tensor in context_features.items()
  }

  nd_indices = (
      utils.padded_nd_indices(is_valid=list_mask)[0]
      if self._circular_padding else None)

  flattened_example_features = {}
  for name, tensor in example_features.items():
    # With circular padding, invalid example features are replaced with valid
    # ones; otherwise the tensor is flattened as-is.
    source = tensor if nd_indices is None else tf.gather_nd(tensor, nd_indices)
    flattened_example_features[name] = _flatten(source)

  return flattened_context_features, flattened_example_features
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=None,
                   mask=None):
  """Scores context and examples to return a score per document.

  Args:
    context_features: (dict) context feature names to 2D tensors of shape
      [batch_size, feature_dims]. None is treated as an empty dict.
    example_features: (dict) example feature names to 3D tensors of shape
      [batch_size, list_size, feature_dims]. Must be non-empty.
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one. If mask is None,
      all entries are valid.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size]. Entries
    where the padded mask is False are set to zero.

  Raises:
    ValueError: If `example_features` is None or an empty dict.
  """
  if not example_features:
    raise ValueError('Need a valid example feature.')
  context_features = context_features or {}

  first_example_tensor = next(six.itervalues(example_features))
  batch_size = tf.shape(first_example_tensor)[0]
  list_size = tf.shape(first_example_tensor)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

  # Expand query features to be of [batch_size, list_size, ...], then flatten
  # to [batch_size * list_size, ...].
  large_batch_context_features = {}
  for name, tensor in six.iteritems(context_features):
    x = tf.expand_dims(input=tensor, axis=1)
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    large_batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [batch_size * list_size])

  large_batch_example_features = {}
  for name, tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones.
    padded_tensor = tf.gather_nd(tensor, nd_indices)
    large_batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [batch_size * list_size])

  # Get scores for large batch.
  scores = self.score(
      context_features=large_batch_context_features,
      example_features=large_batch_example_features,
      training=training)
  logits = tf.reshape(scores, shape=[batch_size, list_size])

  # Apply nd_mask to zero out invalid entries.
  logits = tf.where(nd_mask, logits, tf.zeros_like(logits))
  return logits
def listwise_scoring(scorer,
                     context_features,
                     example_features,
                     training=None,
                     mask=None):
  """Listwise scoring op for context and example features.

  Context features are broadcast to every example in their list. Invalid
  example features are replaced with valid ones using the indices from
  `utils.padded_nd_indices`. The resulting batch of size
  batch_size * list_size is scored with a single `scorer` call. Scores where
  the padded mask is False are zeroed.

  Args:
    scorer: A callable (e.g., A keras layer instance, a function) for scoring
      with the following signature:
      * Args:
        `context_features`: (dict) A dict of Tensors with the shape
          [batch_size, ...].
        `example_features`: (dict) A dict of Tensors with the shape
          [batch_size, ...].
        `training`: (bool) whether in training or inference mode.
      * Returns: The computed logits, a Tensor of shape
        [batch_size, output_size].
    context_features: (dict) context feature names to dense 2D tensors of
      shape [batch_size, ...].
    example_features: (dict) example feature names to dense 3D tensors of
      shape [batch_size, list_size, ...].
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size, output_size].

  Raises:
    ValueError: If example features is None or an empty dict.
  """
  # Guard: example features drive batch/list sizes and cannot be absent.
  if not example_features:
    raise ValueError('Need a valid example feature.')

  reference_shape = tf.shape(next(six.itervalues(example_features)))
  batch_size, list_size = reference_shape[0], reference_shape[1]

  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

  # Gathering index 0 list_size times along a new axis repeats each context
  # row once per example.
  repeat_indices = tf.zeros([list_size], tf.int32)
  large_batch_context_features = {
      name: utils.reshape_first_ndims(
          tf.gather(
              tf.expand_dims(input=value, axis=1), repeat_indices, axis=1),
          2, [batch_size * list_size])
      for name, value in six.iteritems(context_features)
  }

  # Invalid example slots are overwritten with valid examples via nd_indices.
  large_batch_example_features = {
      name: utils.reshape_first_ndims(
          tf.gather_nd(value, nd_indices), 2, [batch_size * list_size])
      for name, value in six.iteritems(example_features)
  }

  flat_scores = scorer(
      large_batch_context_features,
      large_batch_example_features,
      training=training)
  scores = tf.reshape(flat_scores, shape=[batch_size, list_size, -1])

  # Broadcast the [batch_size, list_size, 1] mask over the output dimension.
  keep = tf.expand_dims(nd_mask, axis=2)
  return tf.where(keep, scores, tf.zeros_like(scores))
def compute_logits(self,
                   context_features=None,
                   example_features=None,
                   training=None,
                   mask=None):
  """Scores context and examples to return a score per example.

  Args:
    context_features: (dict) context feature names to 2D tensors of shape
      [batch_size, feature_dims]. None is treated as an empty dict.
    example_features: (dict) example feature names to 3D tensors of shape
      [batch_size, list_size, feature_dims]. Must be non-empty.
    training: (bool) whether in train or inference mode.
    mask: (tf.Tensor) Mask is a tensor of shape [batch_size, list_size], which
      is True for a valid example and False for invalid one. If mask is None,
      all entries are valid.

  Returns:
    (tf.Tensor) A score tensor of shape [batch_size, list_size]. Entries
    where the padded mask is False are set to zero.

  Raises:
    ValueError: If `example_features` is None or empty, or if `scorer` does
      not return a scalar output.
  """
  if not example_features:
    raise ValueError('Need a valid example feature.')
  context_features = context_features or {}

  first_example_tensor = next(six.itervalues(example_features))
  batch_size = tf.shape(first_example_tensor)[0]
  list_size = tf.shape(first_example_tensor)[1]
  if mask is None:
    mask = tf.ones(shape=[batch_size, list_size], dtype=tf.bool)
  nd_indices, nd_mask = utils.padded_nd_indices(is_valid=mask)

  # Expand context features to be of [batch_size, list_size, ...].
  batch_context_features = {}
  for name, tensor in six.iteritems(context_features):
    x = tf.expand_dims(input=tensor, axis=1)
    x = tf.gather(x, tf.zeros([list_size], tf.int32), axis=1)
    batch_context_features[name] = utils.reshape_first_ndims(
        x, 2, [batch_size, list_size])

  batch_example_features = {}
  for name, tensor in six.iteritems(example_features):
    # Replace invalid example features with valid ones.
    padded_tensor = tf.gather_nd(tensor, nd_indices)
    batch_example_features[name] = utils.reshape_first_ndims(
        padded_tensor, 2, [batch_size, list_size])

  # All inputs must be [batch_size, list_size, dim] to concatenate on the last
  # axis. Dense context features therefore come from the expanded dict, not
  # from the raw [batch_size, dim] inputs. Context features are visited before
  # example features, as before.
  sparse_inputs, dense_inputs = [], []
  for features in (batch_context_features, batch_example_features):
    for name, value in six.iteritems(features):
      if name in self._sparse_embed_layers:
        sparse_inputs.append(self._sparse_embed_layers[name](value))
      else:
        dense_inputs.append(value)
  # NOTE(review): presumably each sparse feature holds a single id per
  # example, so its embedding has a size-1 axis 2 to drop — TODO confirm.
  sparse_inputs = [tf.squeeze(inpt, axis=2) for inpt in sparse_inputs]
  inputs = tf.concat(sparse_inputs + dense_inputs, axis=-1)

  scores = self.score(inputs, nd_mask, training=training)
  scores = tf.reshape(scores, shape=[batch_size, list_size, -1])

  # Apply nd_mask to zero out invalid entries.
  # Expand dimension and use broadcasting for filtering.
  expanded_nd_mask = tf.expand_dims(nd_mask, axis=2)
  scores = tf.where(expanded_nd_mask, scores, tf.zeros_like(scores))

  # Remove last dimension of shape = 1.
  try:
    logits = tf.squeeze(scores, axis=2)
  except (ValueError, tf.errors.InvalidArgumentError) as e:
    # ValueError: statically known non-1 last dim at graph-construction time.
    # InvalidArgumentError: non-1 last dim detected when executing eagerly.
    # NOTE: inside a traced graph with a dynamic last dim, the error surfaces
    # when the op executes, not here.
    raise ValueError(
        'Logits not of shape: [batch_size, list_size, 1]. '
        'This could occur if the `scorer` does not return '
        'a scalar output.') from e
  return logits