def compute_column_logits(output_layer, cell_index, cell_mask,
                          init_cell_selection_weights_to_zero,
                          allow_empty_column_selection):
    """Scores every table column from the encoder output.

    Each token is projected to a scalar logit, the token logits are averaged
    per cell, and the cell logits are averaged over the unmasked cells of
    each column.

    Args:
      output_layer: <float>[batch_size, seq_length, hidden_dim] Output of the
        encoder layer.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index
        that groups tokens into cells.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] Input mask
        per cell, 1 for cells that exist in the example and 0 for padding.
      init_cell_selection_weights_to_zero: If true, the projection weights are
        zero-initialized so that all columns start with the same prior
        probability; otherwise `classification_initializer()` is used.
      allow_empty_column_selection: If false, the special column id 0 is
        masked out too, so "no column" cannot be selected.

    Returns:
      <float>[batch_size, max_num_cols] Logits per column. Columns that do not
      appear in the table get a very low logit (probability ~0); the special
      id 0 ("outside the table") gets one as well unless
      `allow_empty_column_selection` is set.
    """
    hidden_dim = output_layer.shape.as_list()[-1]

    # Only the initializer that is actually needed gets constructed.
    if init_cell_selection_weights_to_zero:
        weights_initializer = tf.zeros_initializer()
    else:
        weights_initializer = classification_initializer()
    projection_weights = tf.get_variable(
        "column_output_weights", [hidden_dim],
        initializer=weights_initializer)
    projection_bias = tf.get_variable(
        "column_output_bias", shape=(), initializer=tf.zeros_initializer())

    # Project first, average afterwards: the mean is linear, so the result is
    # the same as averaging embeddings, but the reduced tensors are smaller.
    projected = tf.einsum("bsj,j->bs", output_layer, projection_weights)
    per_token_logits = projected + projection_bias

    # [batch_size, max_num_cols * max_num_rows]
    per_cell_logits, per_cell_index = segmented_tensor.reduce_mean(
        per_token_logits, cell_index)
    per_column_index = cell_index.project_inner(per_cell_index)

    # [batch_size, max_num_cols]
    masked_cell_logits = per_cell_logits * cell_mask
    logit_sums, column_ids = segmented_tensor.reduce_sum(
        masked_cell_logits, per_column_index)
    cells_per_column, _ = segmented_tensor.reduce_sum(
        cell_mask, per_column_index)
    # The epsilon keeps columns without any unmasked cell from dividing by 0.
    column_logits = logit_sums / (cells_per_column + EPSILON_ZERO_DIVISION)

    # Push columns that do not appear in the example towards log(0). Column
    # id 0 is exempt here; it is handled separately below.
    has_no_cells = cells_per_column < 0.5
    is_real_column_id = tf.not_equal(column_ids.indices, 0)
    is_padding_column = tf.logical_and(has_no_cells, is_real_column_id)
    column_logits = column_logits + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
        is_padding_column, tf.float32)

    if allow_empty_column_selection:
        return column_logits

    # Selecting "no column" is disallowed: mask the special id 0 as well.
    is_special_column = tf.equal(column_ids.indices, 0)
    return column_logits + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
        is_special_column, tf.float32)
def test_reduce_sum(self):
    """Checks `reduce_sum` over the row, column and cell index maps."""
    values, row_index, col_index = self._prepare_tables()
    cell_index = segmented_tensor.ProductIndexMap(row_index, col_index)

    # (index to reduce over, expected per-segment sums)
    cases = (
        (row_index, [[6.0, 3.0, 8.0], [6.0, 3.0, 8.0]]),
        (col_index, [[9.0, 8.0, 0.0], [4.0, 5.0, 8.0]]),
        (cell_index, [[3.0, 3.0, 0.0, 2.0, 1.0, 0.0, 4.0, 4.0, 0.0],
                      [1.0, 2.0, 3.0, 2.0, 0.0, 1.0, 1.0, 3.0, 4.0]]),
    )

    # Build every reduction before opening the session, then evaluate them
    # one at a time in the same order.
    pending = []
    for index, expected in cases:
        segment_sum, _ = segmented_tensor.reduce_sum(values, index)
        pending.append((segment_sum, expected))

    with self.session() as sess:
        for segment_sum, expected in pending:
            self.assertAllClose(sess.run(segment_sum), expected)
def test_reduce_sum_vectorized(self):
    """Checks `reduce_sum` when every indexed element is itself a vector."""
    vectors = [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]
    # The first two vectors share segment 0; the third is alone in segment 1.
    segment_map = segmented_tensor.IndexMap(
        indices=[0, 0, 1], num_segments=2, batch_dims=0)

    reduced, reduced_index = segmented_tensor.reduce_sum(vectors, segment_map)

    expected_sums = [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]]
    expected_indices = [0, 1]
    with self.session() as sess:
        self.assertAllClose(sess.run(reduced), expected_sums)
        self.assertAllEqual(sess.run(reduced_index.indices), expected_indices)
        self.assertEqual(sess.run(reduced_index.num_segments), 2)
        self.assertEqual(reduced_index.batch_dims, 0)
def call(self, inputs, cell_index, cell_mask):
    """Computes one logit per table column from the encoder output.

    Tokens are projected to scalar logits with `self.column_output_weights`
    and `self.column_output_bias`, averaged per cell, and then averaged over
    the unmasked cells of each column.

    Args:
      inputs: <float>[batch_size, seq_length, hidden_dim] Output of the
        encoder layer.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index
        that groups tokens into cells.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] Input mask
        per cell, 1 for cells that exist in the example and 0 for padding.

    Returns:
      <float>[batch_size, max_num_cols] Logits per column. Columns without
      any unmasked cell get a very low logit; column id 0 gets one as well
      unless `self.allow_empty_column_selection` is set.
    """
    # Project first, average afterwards: the mean is linear, so the result is
    # the same as averaging embeddings, but the reduced tensors are smaller.
    projected = tf.einsum("bsj,j->bs", inputs, self.column_output_weights)
    per_token_logits = projected + self.column_output_bias

    # [batch_size, max_num_cols * max_num_rows]
    per_cell_logits, per_cell_index = segmented_tensor.reduce_mean(
        per_token_logits, cell_index)
    per_column_index = cell_index.project_inner(per_cell_index)

    # [batch_size, max_num_cols]
    masked_cell_logits = per_cell_logits * cell_mask
    logit_sums, column_ids = segmented_tensor.reduce_sum(
        masked_cell_logits, per_column_index)
    cells_per_column, _ = segmented_tensor.reduce_sum(
        cell_mask, per_column_index)
    # The epsilon keeps columns without any unmasked cell from dividing by 0.
    column_logits = logit_sums / (cells_per_column + EPSILON_ZERO_DIVISION)

    # Push columns that do not appear in the example towards log(0). Column
    # id 0 is exempt here; it is handled separately below.
    has_no_cells = cells_per_column < 0.5
    is_real_column_id = tf.not_equal(column_ids.indices, 0)
    is_padding_column = tf.logical_and(has_no_cells, is_real_column_id)
    column_logits = column_logits + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
        is_padding_column, tf.float32)

    if self.allow_empty_column_selection:
        return column_logits

    # Selecting "no column" is disallowed: mask the special id 0 as well.
    is_special_column = tf.equal(column_ids.indices, 0)
    return column_logits + CLOSE_ENOUGH_TO_LOG_ZERO * tf.cast(
        is_special_column, tf.float32)
def test_gather(self):
    """Checks that `gather` spreads per-cell sums back over the table."""
    table, rows, cols = self._prepare_tables()
    cells = segmented_tensor.ProductIndexMap(rows, cols)

    # Reduce to one sum per cell, then gather with the same index: the result
    # must be shaped like the input table, and every element must hold the
    # sum of the values in its cell.
    per_cell_sums, _ = segmented_tensor.reduce_sum(table, cells)
    gathered = segmented_tensor.gather(per_cell_sums, cells)
    gathered.shape.assert_is_compatible_with(table.shape)

    expected = [
        [[3.0, 3.0, 3.0], [2.0, 2.0, 1.0], [4.0, 4.0, 4.0]],
        [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
    ]
    with self.session() as sess:
        self.assertAllClose(sess.run(gathered), expected)
def single_column_cell_selection(token_logits, column_logits, label_ids,
                                 cell_index, col_index, cell_mask):
    """Restricts cell selection to the model's most likely column.

    The column with the highest value in `column_logits` is picked, and the
    logits of every cell outside that column (or outside the table) are set
    to a very low value so that such cells are never selected.

    This function computes no loss and returns only the adjusted logits.

    Args:
      token_logits: <float>[batch_size, seq_length] Logits per token.
      column_logits: <float>[batch_size, max_num_cols] Logits per column.
      label_ids: <int32>[batch_size, seq_length] Labels per token. They are
        reduced per cell only to obtain the cell-level index map; the reduced
        label values themselves are discarded.
      cell_index: segmented_tensor.IndexMap [batch_size, seq_length] Index
        that groups tokens into cells.
      col_index: segmented_tensor.IndexMap [batch_size, seq_length] Index
        that groups tokens into columns. Unused; kept so that existing
        callers keep working.
      cell_mask: <float>[batch_size, max_num_rows * max_num_cols] Input mask
        per cell, 1 for cells that exist in the example and 0 for padding.

    Returns:
      logits: <float>[batch_size, seq_length] New logits which are only
        allowed to select cells in a single column. Logits outside of the
        most likely column according to `column_logits` are set to a very
        low value (such that the probabilities are 0).
    """
    # Reduce the logits from per-token to per-cell. The label reduction is
    # needed only for the cell-level index it returns.
    logits_per_cell, _ = segmented_tensor.reduce_mean(token_logits, cell_index)
    _, labels_index = segmented_tensor.reduce_max(
        tf.cast(label_ids, tf.int32), cell_index)

    # Column id of every cell.
    column_id_for_cells = cell_index.project_inner(labels_index).indices

    # Set the probs outside the selected column (selected by the *model*)
    # to 0. This ensures backwards compatibility with models that select
    # cells from multiple columns.
    selected_column_id = tf.argmax(
        column_logits, axis=-1, output_type=tf.int32)
    selected_column_mask = tf.cast(
        tf.equal(column_id_for_cells,
                 tf.expand_dims(selected_column_id, axis=-1)),
        tf.float32)

    # Never select cells with the special column id 0.
    selected_column_mask = tf.where(
        tf.equal(column_id_for_cells, 0),
        tf.zeros_like(selected_column_mask),
        selected_column_mask)

    # Cells that are padding or outside the selected column get ~log(0).
    logits_per_cell += CLOSE_ENOUGH_TO_LOG_ZERO * (
        1.0 - cell_mask * selected_column_mask)

    # Broadcast the per-cell logits back to one value per token.
    logits = segmented_tensor.gather(logits_per_cell, cell_index)
    return logits