def compute_column_logits(output_layer, cell_index, cell_mask,
                          init_cell_selection_weights_to_zero,
                          allow_empty_column_selection):
    """Scores every column of the table with a single logit.

    A linear projection maps each token's hidden state to a scalar. Those token
    scores are averaged within each cell, and the cell scores are then averaged
    over the non-padding cells of each column.

    Args:
      output_layer: <float>[batch_size, seq_length, hidden_dim] Output of the
        encoder layer.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into cells.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] 1 for cells
        that exist in the example, 0 for padding cells.
      init_cell_selection_weights_to_zero: If true, the projection weights are
        initialized to zero (otherwise `classification_initializer()` is used),
        so that every column starts with the same prior probability. The same
        flag governs the cell-selection weights.
      allow_empty_column_selection: If false, the special column id 0
        ("outside the table", i.e. select nothing) is masked out.

    Returns:
      <float>[batch_size, max_num_cols] Column logits. Columns that have no
      real cell (other than id 0) get CLOSE_ENOUGH_TO_LOG_ZERO added so their
      probability is effectively 0; column id 0 gets the same penalty unless
      `allow_empty_column_selection` is set.
    """
    dim = output_layer.shape.as_list()[-1]
    if init_cell_selection_weights_to_zero:
        weight_init = tf.zeros_initializer()
    else:
        weight_init = classification_initializer()
    weights = tf.get_variable(
        "column_output_weights", [dim], initializer=weight_init)
    bias = tf.get_variable(
        "column_output_bias", shape=(), initializer=tf.zeros_initializer())

    # Scalar score per token: [batch_size, seq_length].
    per_token = tf.einsum("bsj,j->bs", output_layer, weights) + bias

    # Mean token score per cell. Averaging is linear, so projecting first and
    # averaging afterwards equals averaging the embeddings, and is cheaper.
    # [batch_size, max_num_cols * max_num_rows]
    per_cell, per_cell_index = segmented_tensor.reduce_mean(
        per_token, cell_index)

    # Masked mean of the cell scores inside each column:
    # [batch_size, max_num_cols].
    col_of_cell = cell_index.project_inner(per_cell_index)
    masked_total, col_index_out = segmented_tensor.reduce_sum(
        per_cell * cell_mask, col_of_cell)
    cells_in_col, _ = segmented_tensor.reduce_sum(cell_mask, col_of_cell)
    logits = masked_total / (cells_in_col + EPSILON_ZERO_DIVISION)

    col_ids = col_index_out.indices
    # A column with no real cell is padding, except the special id 0.
    padded = tf.logical_and(cells_in_col < 0.5, tf.not_equal(col_ids, 0))
    logits = logits + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(padded, tf.float32)

    if allow_empty_column_selection:
        return logits
    # Selecting "nothing" (column id 0) is not allowed: push it to ~log(0).
    return logits + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
        tf.equal(col_ids, 0), tf.float32)
 def test_reduce_sum(self):
   """Sums the test tables per row, per column and per (row, column) cell."""
   values, rows, cols = self._prepare_tables()
   cells = segmented_tensor.ProductIndexMap(rows, cols)
   per_row, _ = segmented_tensor.reduce_sum(values, rows)
   per_col, _ = segmented_tensor.reduce_sum(values, cols)
   per_cell, _ = segmented_tensor.reduce_sum(values, cells)

   expected_rows = [[6.0, 3.0, 8.0], [6.0, 3.0, 8.0]]
   expected_cols = [[9.0, 8.0, 0.0], [4.0, 5.0, 8.0]]
   expected_cells = [[3.0, 3.0, 0.0, 2.0, 1.0, 0.0, 4.0, 4.0, 0.0],
                     [1.0, 2.0, 3.0, 2.0, 0.0, 1.0, 1.0, 3.0, 4.0]]
   with self.session() as sess:
     got_rows, got_cols, got_cells = sess.run((per_row, per_col, per_cell))
     self.assertAllClose(got_rows, expected_rows)
     self.assertAllClose(got_cols, expected_cols)
     self.assertAllClose(got_cells, expected_cells)
 def test_reduce_sum_vectorized(self):
   """Segments the leading axis of an unbatched tensor (batch_dims=0)."""
   values = [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]
   # Rows 0 and 1 go to segment 0, row 2 goes to segment 1.
   index = segmented_tensor.IndexMap(
       indices=[0, 0, 1], num_segments=2, batch_dims=0)
   sums, new_index = segmented_tensor.reduce_sum(values, index)
   with self.session() as sess:
     got_sums, got_indices, got_num_segments = sess.run(
         (sums, new_index.indices, new_index.num_segments))
     self.assertAllClose(got_sums, [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]])
     self.assertAllEqual(got_indices, [0, 1])
     self.assertEqual(got_num_segments, 2)
     # batch_dims is a plain Python attribute; no session evaluation needed.
     self.assertEqual(new_index.batch_dims, 0)
예제 #4
0
    def call(self, inputs, cell_index, cell_mask):
        """Computes one logit per column from the encoder output.

        Token scores come from a linear projection with
        `self.column_output_weights` and `self.column_output_bias`. They are
        averaged per cell, and the cell scores are then averaged over the
        non-padding cells of each column.

        Args:
          inputs: <float>[batch_size, seq_length, hidden_dim] Output of the
            encoder layer.
          cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index
            that groups tokens into cells.
          cell_mask: <float>[batch_size, max_num_rows * max_num_cols] 1 for
            cells that exist in the example, 0 for padding.

        Returns:
          <float>[batch_size, max_num_cols] Column logits. Columns with no
          real cell (other than id 0) get CLOSE_ENOUGH_TO_LOG_ZERO added;
          column id 0 gets it too unless `self.allow_empty_column_selection`
          is set.
        """
        # Scalar score per token: [batch_size, seq_length].
        scores = tf.einsum("bsj,j->bs", inputs, self.column_output_weights)
        scores = scores + self.column_output_bias

        # Mean token score per cell. Averaging commutes with the linear
        # projection, so projecting before averaging is equivalent and cheaper.
        # [batch_size, max_num_cols * max_num_rows]
        cell_scores, cell_scores_index = segmented_tensor.reduce_mean(
            scores, cell_index)

        # Masked mean of the cell scores within each column.
        # [batch_size, max_num_cols]
        to_column = cell_index.project_inner(cell_scores_index)
        score_sum, column_index_out = segmented_tensor.reduce_sum(
            cell_scores * cell_mask, to_column)
        real_cells, _ = segmented_tensor.reduce_sum(cell_mask, to_column)
        result = score_sum / (real_cells + EPSILON_ZERO_DIVISION)

        ids = column_index_out.indices
        # Columns without any real cell are padding (id 0 is handled below).
        empty_column = tf.logical_and(real_cells < 0.5, tf.not_equal(ids, 0))
        result = result + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
            empty_column, tf.float32)

        if self.allow_empty_column_selection:
            return result
        # Otherwise forbid the special id 0 ("select nothing").
        return result + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
            tf.equal(ids, 0), tf.float32)
  def test_gather(self):
    """Gathering per-cell sums back gives every element its cell's sum."""
    values, row_index, col_index = self._prepare_tables()
    cells = segmented_tensor.ProductIndexMap(row_index, col_index)

    # Reduce to one sum per cell, then broadcast each sum back onto every
    # element of that cell; the result must keep the table's shape.
    per_cell_sum, _ = segmented_tensor.reduce_sum(values, cells)
    broadcast = segmented_tensor.gather(per_cell_sum, cells)
    broadcast.shape.assert_is_compatible_with(values.shape)

    expected = [[[3.0, 3.0, 3.0], [2.0, 2.0, 1.0], [4.0, 4.0, 4.0]],
                [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]]]
    with self.session() as sess:
      self.assertAllClose(sess.run(broadcast), expected)
예제 #6
0
def single_column_cell_selection(token_logits, column_logits, label_ids,
                                 cell_index, col_index, cell_mask):
    """Restricts token logits to the cells of the column the model picks.

    The column with the highest `column_logits` is treated as selected. Every
    cell outside that column, every padding cell (per `cell_mask`) and every
    cell in the special column id 0 receives an additive
    CLOSE_ENOUGH_TO_LOG_ZERO penalty, so only cells of the chosen column keep
    a non-negligible probability. The per-cell logits are then gathered back
    onto tokens.

    No selection loss is computed here; only the masked logits are returned.

    Args:
      token_logits: <float>[batch_size, seq_length] Logits per token.
      column_logits: <float>[batch_size, max_num_cols] Logits per column.
      label_ids: <int32>[batch_size, seq_length] Labels per token. Only used
        to derive the per-cell output index; the label values themselves do
        not affect the result.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into cells.
      col_index: segmented_tensor.IndexMap [batch_size, seq_length] Index that
        groups tokens into columns. Unused; kept for API compatibility.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] Input mask
        per cell, 1 for cells that exist in the example and 0 for padding.

    Returns:
      logits: <float>[batch_size, seq_length] Each token gets the mean logit of
        its cell. Tokens outside the most likely column according to
        `column_logits` are set to a very low value (probability effectively
        0).
    """
    del col_index  # Unused; see docstring.

    # Average the token logits within each cell.
    logits_per_cell, _ = segmented_tensor.reduce_mean(token_logits, cell_index)
    # Only the output index is needed (to map cells to columns); the reduced
    # label values are discarded.
    _, labels_index = segmented_tensor.reduce_max(
        tf.cast(label_ids, tf.int32), cell_index)

    # Column id of every cell.
    column_id_for_cells = cell_index.project_inner(labels_index).indices

    # Keep only cells in the column chosen by the *model* (argmax of
    # column_logits). This ensures backwards compatibility with models that
    # select cells from multiple columns.
    selected_column_id = tf.argmax(
        column_logits, axis=-1, output_type=tf.int32)
    selected_column_mask = tf.cast(
        tf.equal(column_id_for_cells,
                 tf.expand_dims(selected_column_id, axis=-1)), tf.float32)
    # Never select cells with the special column id 0 ("outside the table").
    selected_column_mask = tf.where(tf.equal(column_id_for_cells, 0),
                                    tf.zeros_like(selected_column_mask),
                                    selected_column_mask)
    # Padding cells (cell_mask == 0) and unselected cells get ~log(0).
    logits_per_cell += CLOSE_ENOUGH_TO_LOG_ZERO * (
        1.0 - cell_mask * selected_column_mask)
    return segmented_tensor.gather(logits_per_cell, cell_index)