def test_reduce_max(self):
    """reduce_max yields the per-segment maximum of a flat list of values."""
    data = [2, 1, 0, 3]
    # Segment 0 holds data[0] and data[2] (max 2); segment 1 holds data[1]
    # and data[3] (max 3).
    index = segmented_tensor.IndexMap(
        indices=[0, 1, 0, 1],
        num_segments=2,
    )
    per_segment_max, _unused_index = segmented_tensor.reduce_max(data, index)
    expected = [2, 3]
    with self.session() as sess:
        result = sess.run(per_segment_max)
        self.assertAllEqual(result, expected)
def single_column_cell_selection(token_logits, column_logits, label_ids,
                                 cell_index, col_index, cell_mask, *,
                                 return_loss=False):
    """Restricts cell-selection logits to a single column (optionally + loss).

    Cell selection is modelled as a hierarchical log-likelihood: the model
    first predicts a column and then selects cells within that column
    (conditioned on the column). Cells outside the selected column are never
    selected.

    Args:
      token_logits: <float>[batch_size, seq_length] Logits per token.
      column_logits: <float>[batch_size, max_num_cols] Logits per column.
      label_ids: <int32>[batch_size, seq_length] Labels per token.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into cells.
      col_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into columns. Only used when `return_loss` is True.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] Input mask
        per cell, 1 for cells that exist in the example and 0 for padding.
      return_loss: If True, also compute the hierarchical selection loss and
        return `(selection_loss_per_example, logits)`. Defaults to False, which
        returns only `logits`, as this function always has.

    Returns:
      If `return_loss` is False (default):
        logits: <float>[batch_size, seq_length] New logits which are only
          allowed to select cells in a single column. Each cell's logit is the
          mean of its token logits. The per-cell logit is shifted by
          `CLOSE_ENOUGH_TO_LOG_ZERO` for cells outside the most likely column
          according to `column_logits`, for padding cells (cell_mask == 0) and
          for cells with the special column id 0. The result is then gathered
          back to tokens. (Presumably CLOSE_ENOUGH_TO_LOG_ZERO is a large
          negative constant so those probabilities are ~0; confirm at its
          definition.)
      If `return_loss` is True, a tuple:
        selection_loss_per_example: <float>[batch_size] Column negative
          log-likelihood plus the cell negative log-likelihood within the
          target column. The cell term is normalized by the number of real
          cells in that column, and it is skipped when no cell is selected.
        logits: as above.
    """
    # Reduce the token logits to one (mean) logit per cell.
    logits_per_cell, _ = segmented_tensor.reduce_mean(token_logits, cell_index)
    # The second output of reduce_max is projected through `cell_index` to
    # recover the column id of every cell.
    labels_per_cell, labels_index = segmented_tensor.reduce_max(
        tf.cast(label_ids, tf.int32), cell_index)
    column_id_for_cells = cell_index.project_inner(labels_index).indices

    selection_loss_per_example = None
    if return_loss:
        # Guards the per-column normalization below against division by zero.
        epsilon_zero_division = 1e-10

        # Target column: the one with the largest number of selected tokens.
        labels_per_column, _ = segmented_tensor.reduce_sum(
            tf.cast(label_ids, tf.float32), col_index)
        column_label = tf.argmax(
            labels_per_column, axis=-1, output_type=tf.int32)
        # With no selected cell at all, the target is the special column id 0,
        # which means "select nothing".
        no_cell_selected = tf.equal(
            tf.reduce_max(labels_per_column, axis=-1), 0)
        column_label = tf.where(
            no_cell_selected, tf.zeros_like(column_label), column_label)

        column_dist = tfp.distributions.Categorical(logits=column_logits)
        column_loss_per_example = -column_dist.log_prob(column_label)

        # Cell log-likelihood, restricted to real cells of the target column.
        # This uses the *unmasked* per-cell logits.
        column_mask = tf.cast(
            tf.equal(column_id_for_cells,
                     tf.expand_dims(column_label, axis=-1)),
            tf.float32)
        cell_dist = tfp.distributions.Bernoulli(logits=logits_per_cell)
        cell_log_prob = cell_dist.log_prob(labels_per_cell)
        cell_loss = -tf.reduce_sum(
            cell_log_prob * column_mask * cell_mask, axis=-1)
        cell_loss /= (tf.reduce_sum(column_mask * cell_mask, axis=-1) +
                      epsilon_zero_division)

        selection_loss_per_example = column_loss_per_example + tf.where(
            no_cell_selected, tf.zeros_like(column_loss_per_example),
            cell_loss)

    # Set the probs outside the column selected by the *model* to 0. This keeps
    # backwards compatibility with models that select cells from multiple
    # columns.
    selected_column_id = tf.argmax(
        column_logits, axis=-1, output_type=tf.int32)
    selected_column_mask = tf.cast(
        tf.equal(column_id_for_cells,
                 tf.expand_dims(selected_column_id, axis=-1)),
        tf.float32)
    # Never select cells with the special column id 0.
    selected_column_mask = tf.where(
        tf.equal(column_id_for_cells, 0),
        tf.zeros_like(selected_column_mask),
        selected_column_mask)
    masked_logits_per_cell = logits_per_cell + CLOSE_ENOUGH_TO_LOG_ZERO * (
        1.0 - cell_mask * selected_column_mask)
    # Broadcast the per-cell logits back to per-token logits.
    logits = segmented_tensor.gather(masked_logits_per_cell, cell_index)

    if return_loss:
        return selection_loss_per_example, logits
    return logits