コード例 #1
0
 def test_reduce_max(self):
   """Checks that `reduce_max` yields the per-segment maximum of the values."""
   # With indices [0, 1, 0, 1] the expected maxima below correspond to
   # segment 0 collecting {2, 0} and segment 1 collecting {1, 3}.
   segment_map = segmented_tensor.IndexMap(
       indices=[0, 1, 0, 1], num_segments=2)
   per_segment_max, _ = segmented_tensor.reduce_max([2, 1, 0, 3], segment_map)
   with self.session() as session:
     observed = session.run(per_segment_max)
     self.assertAllEqual(observed, [2, 3])
コード例 #2
0
def single_column_cell_selection(token_logits, column_logits, label_ids,
                                 cell_index, col_index, cell_mask):
    """Constrains cell-selection logits to the single most likely column.

    The column is the one picked by the *model*, i.e. the argmax of
    `column_logits`. Token logits are first averaged per cell; then every cell
    that is padding, lies outside the picked column, or carries the special
    column id 0 gets `CLOSE_ENOUGH_TO_LOG_ZERO` added to its logit. Finally the
    per-cell logits are gathered back to token positions.

    No loss is computed here: only the constrained logits are returned.

    Args:
      token_logits: <float>[batch_size, seq_length] Logits per token.
      column_logits: <float>[batch_size, max_num_cols] Logits per column.
      label_ids: <int32>[batch_size, seq_length] Labels per token. Only used to
        obtain the per-cell output index of `segmented_tensor.reduce_max`.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into cells.
      col_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into columns. Not used by this function; kept so that
        existing callers keep working.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] Input mask per
        cell, 1 for cells that exists in the example and 0 for padding.

    Returns:
      logits: <float>[batch_size, seq_length] New logits which are only allowed
        to select cells in a single column. Logits outside of the most likely
        column according to `column_logits` will be set to a very low value
        (such that the probabilities are 0).
    """
    # Reduce the logits to per-cell from per-token.
    logits_per_cell, _ = segmented_tensor.reduce_mean(token_logits, cell_index)
    # Only the output index of this reduction is needed; the reduced label
    # values themselves are discarded.
    _, labels_index = segmented_tensor.reduce_max(tf.cast(label_ids, tf.int32),
                                                  cell_index)

    # Column id of every cell.
    column_id_for_cells = cell_index.project_inner(labels_index).indices

    # Set the probs outside the selected column (selected by the *model*)
    # to 0. This ensures backwards compatibility with models that select
    # cells from multiple columns.
    selected_column_id = tf.argmax(column_logits,
                                   axis=-1,
                                   output_type=tf.int32)
    selected_column_mask = tf.cast(
        tf.equal(column_id_for_cells,
                 tf.expand_dims(selected_column_id, axis=-1)), tf.float32)
    # Never select cells with the special column id 0.
    selected_column_mask = tf.where(tf.equal(column_id_for_cells, 0),
                                    tf.zeros_like(selected_column_mask),
                                    selected_column_mask)
    # Cells that are padding or outside the selected column have a combined
    # mask of 0 and therefore receive the full penalty; all other cells are
    # left untouched.
    logits_per_cell += CLOSE_ENOUGH_TO_LOG_ZERO * (
        1.0 - cell_mask * selected_column_mask)
    # Broadcast the per-cell logits back to every token of the cell.
    logits = segmented_tensor.gather(logits_per_cell, cell_index)

    return logits