Example #1
  def _prepare_tables(self):
    """Prepares two tables, both with three distinct rows.

    The first table has two columns:
      1.0, 2.0 | 3.0
      2.0, 0.0 | 1.0
      1.0, 3.0 | 4.0

    The second table has three columns:
      1.0 | 2.0 | 3.0
      2.0 | 0.0 | 1.0
      1.0 | 3.0 | 4.0

    Returns:
      The table values and the row and column IndexMaps for the tables.
    """
    values = tf.constant([
        [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
        [[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]],
    ])
    row_index = segmented_tensor.IndexMap(
        indices=[
            [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
            [[0, 0, 0], [1, 1, 1], [2, 2, 2]],
        ],
        num_segments=3,
        batch_dims=1)
    col_index = segmented_tensor.IndexMap(
        indices=[
            [[0, 0, 1], [0, 0, 1], [0, 0, 1]],
            [[0, 1, 2], [0, 1, 2], [0, 1, 2]],
        ],
        num_segments=3,
        batch_dims=1)
    return values, row_index, col_index
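
The IndexMap objects above store per-token segment ids with one batch dimension. As a minimal sketch of the reduction semantics they enable, here is the row-wise sum for the first table using plain TensorFlow ops instead of the segmented_tensor API (single table, no batch dimension, for brevity):

import tensorflow as tf

# Flattening the values and row ids reduces each row segment with a
# standard unsorted segment sum.
values = tf.constant([[1.0, 2.0, 3.0], [2.0, 0.0, 1.0], [1.0, 3.0, 4.0]])
row_ids = tf.constant([[0, 0, 0], [1, 1, 1], [2, 2, 2]])
row_sums = tf.math.unsorted_segment_sum(
    tf.reshape(values, [-1]), tf.reshape(row_ids, [-1]), num_segments=3)
# row_sums == [6.0, 3.0, 8.0]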
Example #2
def _get_relative_position_embeddings(
    full_position_embeddings,
    token_type_ids,
    token_type_vocab_size,
    seq_length,
    batch_size,
    max_position_embeddings,
):
  """Create position embeddings that restart at every cell."""
  col_index = segmented_tensor.IndexMap(
      token_type_ids[1], token_type_vocab_size[1], batch_dims=1)
  row_index = segmented_tensor.IndexMap(
      token_type_ids[2], token_type_vocab_size[2], batch_dims=1)
  full_index = segmented_tensor.ProductIndexMap(col_index, row_index)
  position = tf.expand_dims(tf.range(seq_length), axis=0)
  logging.info("position: %s", position)
  batched_position = tf.repeat(position, repeats=batch_size, axis=0)
  logging.info("batched_position: %s", batched_position)
  logging.info("token_type_ids: %s", token_type_ids[1])
  first_position_per_segment = segmented_tensor.reduce_min(
      batched_position, full_index)[0]
  first_position = segmented_tensor.gather(first_position_per_segment,
                                           full_index)
  position_embeddings = tf.nn.embedding_lookup(
      full_position_embeddings,
      tf.math.minimum(max_position_embeddings - 1, position - first_position))
  return position_embeddings
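
A hypothetical toy run of the restart-at-cell logic, in NumPy for clarity (the cell ids below are made up):

import numpy as np

cell_ids = np.array([0, 0, 1, 1, 1, 2])  # flat cell id per token
position = np.arange(len(cell_ids))      # [0, 1, 2, 3, 4, 5]
# reduce_min per cell, then gather back per token:
first_per_cell = np.array(
    [position[cell_ids == s].min() for s in np.unique(cell_ids)])
relative = position - first_per_cell[cell_ids]
# relative == [0, 1, 0, 1, 2, 0]: positions restart at every cell.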
Example #3
def get_token_scores_from_column_scores(
    column_ids,
    column_probs,
    input_mask,
    max_num_columns,
):
    """Given the columns scores in [0,1] extracts the tokens scores.

  It also gives a score of 1.0 for the question's tokens and padding.

  Args:
    column_ids: <int32>[batch_size, seq_length] additional to the columns' ids
      [1, max_num_columns] the value 0 refers to question tokens and padding.
    column_probs: <float32>[batch_size, max_column_id]: contains only the
      columns' scores: question score or padding not included. The expected
        values are in [0,1].
    input_mask: <float32>[batch_size, seq_length] used to zero-out the padding.
    max_num_columns: the maximum number of columns.

  Returns:
    <float32>[batch_size, seq_length]: The tokens' scores.
  """
    col_index = segmented_tensor.IndexMap(indices=column_ids,
                                          num_segments=max_num_columns + 1,
                                          batch_dims=1)
    # <float32>[batch_size, max_num_columns+1]: contains the question at pos 0.
    # The score for the question and padding is 1.0.
    padded_column_scores = tf.pad(column_probs,
                                  paddings=[[0, 0], [1, 0]],
                                  constant_values=1.0)
    # <float32>[batch_size, seq_length]
    return segmented_tensor.gather(index=col_index,
                                   values=padded_column_scores) * tf.cast(
                                       input_mask, dtype=tf.float32)
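
The pad-then-gather step can be checked in isolation. A minimal sketch with made-up scores, substituting tf.gather with batch_dims for segmented_tensor.gather:

import tensorflow as tf

column_probs = tf.constant([[0.2, 0.9]])        # two columns
column_ids = tf.constant([[0, 1, 1, 2, 0]])     # 0 = question or padding
input_mask = tf.constant([[1.0, 1.0, 1.0, 1.0, 0.0]])
padded = tf.pad(column_probs, paddings=[[0, 0], [1, 0]], constant_values=1.0)
# padded == [[1.0, 0.2, 0.9]]
token_scores = tf.gather(padded, column_ids, batch_dims=1) * input_mask
# token_scores == [[1.0, 0.2, 0.2, 0.9, 0.0]]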
Example #4
 def _select_columns(self, mode, features):
     input_mask = features["input_mask"]
     column_ids = features["column_ids"]
     with tf.variable_scope("bert"):
         with tf.variable_scope("embeddings",
                                reuse=tf.compat.v1.AUTO_REUSE):
             input_embeddings, _ = modeling.embedding_lookup(
                 input_ids=features["input_ids"],
                 vocab_size=self._vocab_size,
                 embedding_size=self._hidden_size,
                 initializer_range=self._initializer_range,
                 word_embedding_name="word_embeddings")
             if self._use_positional_embeddings:
                 token_type_ids = []
                 token_type_features = [
                     "segment_ids", "column_ids", "row_ids",
                     "prev_label_ids", "column_ranks", "inv_column_ranks",
                     "numeric_relations"
                 ]
                 for key in token_type_features:
                     if self._disabled_features is not None and key in self._disabled_features:
                         token_type_ids.append(tf.zeros_like(features[key]))
                     else:
                         token_type_ids.append(features[key])
                 input_embeddings = modeling.embedding_postprocessor(
                     input_tensor=input_embeddings,
                     use_token_type=True,
                     token_type_ids=token_type_ids,
                     token_type_vocab_size=self._type_vocab_size,
                     token_type_embedding_name="token_type_embeddings",
                     use_position_embeddings=self._use_position_embeddings,
                     position_embedding_name="position_embeddings",
                     initializer_range=self._initializer_range,
                     max_position_embeddings=self._max_position_embeddings,
                     extra_embeddings=None,
                     dropout_prob=0.0)
              # Maps every zero value in input_mask to index (max_num_columns + 1).
              # Index 0 is for the question; indices 1 to max_num_columns
              # inclusive are for the columns.
             masked_col_ids = column_ids * input_mask + (1 - input_mask) * (
                 self._max_num_columns + 1)
             col_index = segmented_tensor.IndexMap(
                 indices=masked_col_ids,
                 num_segments=self._max_num_columns + 2,
                 batch_dims=1)
             average_embeddings, _ = segmented_tensor.reduce_mean(
                 input_embeddings, col_index)
              # Removes the last index, as it contains the average of the
              # non-selected (padding) values.
             average_embeddings = average_embeddings[:, :-1]
             normalize_average_embeddings = tf.math.l2_normalize(
                 average_embeddings, axis=2)
             questions_embeddings = normalize_average_embeddings[:, :1]
             columns_embeddings = normalize_average_embeddings[:, 1:]
             multiply = columns_embeddings * questions_embeddings
             multiply = tf.where(tf.is_nan(multiply),
                                 tf.zeros_like(multiply), multiply)
             column_scores = tf.math.reduce_sum(multiply,
                                                axis=-1,
                                                name="column_scores")
             return column_scores
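
The final scoring step computes a cosine similarity: after L2 normalization, the sum over the elementwise product of a column embedding and the question embedding is their cosine. A minimal sketch with made-up vectors:

import tensorflow as tf

question = tf.math.l2_normalize(tf.constant([[1.0, 2.0, 2.0]]), axis=-1)
column = tf.math.l2_normalize(tf.constant([[2.0, 4.0, 4.0]]), axis=-1)
score = tf.reduce_sum(question * column, axis=-1)
# score == [1.0] since the two vectors are parallel.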
Example #5
 def test_gather_vectorized(self):
   values = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
   index = segmented_tensor.IndexMap(
       indices=[[0, 1], [1, 0]], num_segments=2, batch_dims=1)
   result = segmented_tensor.gather(values, index)
   with self.session() as sess:
     self.assertAllEqual(
         sess.run(result), [[[1, 2], [3, 4]], [[7, 8], [5, 6]]])
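
With batch_dims=1 each batch element gathers from its own rows using its own index vector; an equivalent check with plain tf.gather (assuming a TensorFlow version that supports batch_dims):

import tensorflow as tf

values = tf.constant([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
indices = tf.constant([[0, 1], [1, 0]])
result = tf.gather(values, indices, batch_dims=1)
# result == [[[1, 2], [3, 4]], [[7, 8], [5, 6]]]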
Example #6
 def test_reduce_sum_vectorized(self):
   values = [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]]
   index = segmented_tensor.IndexMap(
       indices=[0, 0, 1], num_segments=2, batch_dims=0)
   sums, new_index = segmented_tensor.reduce_sum(values, index)
   with self.session() as sess:
     self.assertAllClose(sess.run(sums), [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]])
     self.assertAllEqual(sess.run(new_index.indices), [0, 1])
     self.assertEqual(sess.run(new_index.num_segments), 2)
     self.assertEqual(new_index.batch_dims, 0)
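
With batch_dims=0 the reduction collapses rows that share a segment id; tf.math.unsorted_segment_sum over the first axis behaves the same way:

import tensorflow as tf

values = tf.constant([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0], [3.0, 4.0, 5.0]])
sums = tf.math.unsorted_segment_sum(
    values, segment_ids=tf.constant([0, 0, 1]), num_segments=2)
# sums == [[3.0, 5.0, 7.0], [3.0, 4.0, 5.0]]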
Example #7
    def _compute_column_scores_from_token_scores(self, mode, output_layer,
                                                 features):
        """Gets the columns scores by avereging the tokens scores."""
        with tf.variable_scope(PRUNING_SCOPE, reuse=tf.AUTO_REUSE):
            if mode == tf_estimator.ModeKeys.TRAIN:
                output_layer = tf.nn.dropout(
                    output_layer, keep_prob=_SEQUENCE_OUTPUT_KEEP_PROB)
            input_mask = features["input_mask"]
            row_ids = features["row_ids"]
            column_ids = features["column_ids"]

            # Construct indices for the table.
            row_index = segmented_tensor.IndexMap(
                indices=tf.minimum(row_ids, self._max_num_rows - 1),
                num_segments=self._max_num_rows,
                batch_dims=1)
            col_index = segmented_tensor.IndexMap(
                indices=tf.minimum(column_ids, self._max_num_columns),
                num_segments=self._max_num_columns + 1,
                batch_dims=1)
            cell_index = segmented_tensor.ProductIndexMap(row_index, col_index)

            # Masks.
            # <float32>[batch_size, seq_length]
            input_mask_float = tf.cast(input_mask, tf.float32)
            # Mask for cells that exist in the table (i.e. that are not padding).
            cell_mask, _ = segmented_tensor.reduce_mean(
                input_mask_float, cell_index)

            # Compute logits per column which can be used to select a column.
            # <float32>[batch_size, max_num_columns]
            column_scores = utils.compute_column_logits(
                output_layer=output_layer,
                cell_index=cell_index,
                cell_mask=cell_mask,
                init_cell_selection_weights_to_zero=False,
                allow_empty_column_selection=False)[:, 1:]
            column_scores = tf.debugging.assert_all_finite(
                column_scores, "column_scores contains nan values.")
            return column_scores
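
The cell mask is the per-cell mean of the token mask, so it is 1.0 for cells backed entirely by real tokens and 0.0 for padding-only cells. A toy sketch with made-up flat cell ids:

import tensorflow as tf

input_mask = tf.constant([1.0, 1.0, 1.0, 0.0])
cell_ids = tf.constant([0, 0, 1, 2])  # hypothetical flat cell id per token
cell_mask = tf.math.unsorted_segment_mean(input_mask, cell_ids, num_segments=3)
# cell_mask == [1.0, 1.0, 0.0]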
Example #8
  def test_flatten(self):
    _, row_index, col_index = self._prepare_tables()
    row_index_flat = segmented_tensor.flatten(row_index)
    col_index_flat = segmented_tensor.flatten(col_index)

    shape = [3, 4, 5]
    batched_index = segmented_tensor.IndexMap(
        indices=tf.fill(shape, tf.constant(0, dtype=tf.int32)),
        num_segments=1,
        batch_dims=3)
    batched_index_flat = segmented_tensor.flatten(batched_index)

    with self.session() as sess:
      self.assertAllEqual(
          sess.run(row_index_flat.indices),
          [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5])
      self.assertAllEqual(
          sess.run(col_index_flat.indices),
          [0, 0, 1, 0, 0, 1, 0, 0, 1, 3, 4, 5, 3, 4, 5, 3, 4, 5])
      self.assertEqual(
          sess.run(batched_index_flat.num_segments), np.prod(shape))
      self.assertAllEqual(
          sess.run(batched_index_flat.indices), range(np.prod(shape)))
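
The expected values follow from flatten() building one global index space: each batch element's indices are offset by its batch position times num_segments. A NumPy sketch of that offsetting on a smaller made-up index:

import numpy as np

indices = np.array([[0, 0, 1], [0, 1, 2]])  # batch of 2, num_segments = 3
offsets = np.arange(indices.shape[0])[:, None] * 3
flat = (indices + offsets).reshape(-1)
# flat == [0, 0, 1, 3, 4, 5]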
Example #9
 def test_reduce_max(self):
   values = [2, 1, 0, 3]
   index = segmented_tensor.IndexMap(indices=[0, 1, 0, 1], num_segments=2)
   maximum, _ = segmented_tensor.reduce_max(values, index)
   with self.session() as sess:
     self.assertAllEqual(sess.run(maximum), [2, 3])
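
Without batch dimensions this reduction matches a plain segment max:

import tensorflow as tf

maximum = tf.math.unsorted_segment_max(
    tf.constant([2, 1, 0, 3]), segment_ids=tf.constant([0, 1, 0, 1]),
    num_segments=2)
# maximum == [2, 3]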
Example #10
    def call(self, input_token_ids, input_mask, segment_ids, column_ids,
             row_ids, prev_label_ids, column_ranks, inv_column_ranks,
             numeric_relations, label_ids, **kwargs):

        # Construct indices for the table.
        row_index = segmented_tensor.IndexMap(
            indices=tf.minimum(tf.cast(row_ids, tf.int32),
                               self.tapas_classifier_config.max_num_rows - 1),
            num_segments=self.tapas_classifier_config.max_num_rows,
            batch_dims=1)
        col_index = segmented_tensor.IndexMap(
            indices=tf.minimum(
                tf.cast(column_ids, tf.int32),
                self.tapas_classifier_config.max_num_columns - 1),
            num_segments=self.tapas_classifier_config.max_num_columns,
            batch_dims=1)
        cell_index = segmented_tensor.ProductIndexMap(row_index, col_index)

        # Masks.
        # <float32>[batch_size, seq_length]
        table_mask = tf.where(row_ids > 0, tf.ones_like(row_ids),
                              tf.zeros_like(row_ids))
        input_mask_float = tf.cast(input_mask, tf.float32)
        table_mask_float = tf.cast(table_mask, tf.float32)

        # Mask for cells that exist in the table (i.e. that are not padding).
        cell_mask, _ = segmented_tensor.reduce_mean(input_mask_float,
                                                    cell_index)

        pooled_output, sequence_output = self.bert([
            input_token_ids, input_mask, segment_ids, column_ids, row_ids,
            prev_label_ids, column_ranks, inv_column_ranks, numeric_relations
        ], **kwargs)
        # Compute logits per token. These are used to select individual cells.
        logits = self.compute_token_logits(sequence_output)
        # Compute logits per column. These are used to select a column.
        if self.tapas_classifier_config.select_one_column:
            column_logits = self.compute_column_logits(sequence_output,
                                                       cell_index, cell_mask)

        logits_cls = None
        if self.do_model_classification:
            logits_cls = self.compute_classification_logits(pooled_output)

        if self.tapas_classifier_config.average_logits_per_cell:
            logits_per_cell, _ = segmented_tensor.reduce_mean(
                logits, cell_index)
            logits = segmented_tensor.gather(logits_per_cell, cell_index)
        dist_per_token = tfp.distributions.Bernoulli(logits=logits)

        if self.tapas_classifier_config.select_one_column:
            logits = single_column_cell_selection(logits, column_logits,
                                                  label_ids, cell_index,
                                                  col_index, cell_mask)
            dist_per_token = tfp.distributions.Bernoulli(logits=logits)

        logits_aggregation = None
        if self.do_model_aggregation:
            logits_aggregation = self.calculate_aggregation_logits(
                pooled_output)

        probs = _get_probs(dist_per_token) * input_mask_float

        return logits, probs, logits_aggregation, logits_cls
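
The returned probs are the Bernoulli means, i.e. sigmoid(logits), zeroed out on padding tokens. A minimal sketch with made-up logits:

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.0]])
input_mask_float = tf.constant([[1.0, 1.0, 0.0]])
probs = tf.sigmoid(logits) * input_mask_float
# probs ~= [[0.881, 0.269, 0.0]]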
Example #11
def _get_classification_outputs(
    config,
    is_training,
    output_layer,
    output_layer_aggregation,
    label_ids,
    input_mask,
    table_mask,
    aggregation_function_id,
    answer,
    numeric_values,
    numeric_values_scale,
    row_ids,
    column_ids,
    classification_class_index,
):
    """Creates a classification model.

  Args:
    config: Configuration for Tapas model.
    is_training: Whether the model is training.
    output_layer: <float32>[batch_size, seq_length, hidden_size]
    output_layer_aggregation: <float32>[batch_size, hidden_size]
    label_ids: <int32>[batch_size, seq_length]
    input_mask: <int32>[batch_size, seq_length]
    table_mask: <int32>[batch_size, seq_length]
    aggregation_function_id: <int32>[batch_size]
    answer: <float32>[batch_size]
    numeric_values: <float32>[batch_size, seq_length]
    numeric_values_scale: <float32>[batch_size, seq_length]
    row_ids: <int32>[batch_size, seq_length]
    column_ids: <int32>[batch_size, seq_length]
    classification_class_index: <int32>[batch]

  Returns:
    Outputs
  """
    if is_training:
        # I.e., 0.1 dropout
        output_layer = tf.nn.dropout(output_layer, keep_prob=0.9)

    # Construct indices for the table.
    row_index = segmented_tensor.IndexMap(indices=tf.minimum(
        row_ids, config.max_num_rows - 1),
                                          num_segments=config.max_num_rows,
                                          batch_dims=1)
    col_index = segmented_tensor.IndexMap(indices=tf.minimum(
        column_ids, config.max_num_columns - 1),
                                          num_segments=config.max_num_columns,
                                          batch_dims=1)
    cell_index = segmented_tensor.ProductIndexMap(row_index, col_index)

    # Masks.
    # <float32>[batch_size, seq_length]
    input_mask_float = tf.cast(input_mask, tf.float32)
    table_mask_float = tf.cast(table_mask, tf.float32)
    # Mask for cells that exist in the table (i.e. that are not padding).
    cell_mask, _ = segmented_tensor.reduce_mean(input_mask_float, cell_index)

    # Compute logits per token. These are used to select individual cells.
    logits = utils.compute_token_logits(
        output_layer=output_layer,
        temperature=config.temperature,
        init_cell_selection_weights_to_zero=(
            config.init_cell_selection_weights_to_zero))

    # Compute logits per column. These are used to select a column.
    if config.select_one_column:
        column_logits = utils.compute_column_logits(
            output_layer=output_layer,
            cell_index=cell_index,
            cell_mask=cell_mask,
            init_cell_selection_weights_to_zero=(
                config.init_cell_selection_weights_to_zero),
            allow_empty_column_selection=config.allow_empty_column_selection)

    # TODO(pawelnow): Extract this into a function.
    # Compute aggregation function logits.
    do_model_aggregation = config.num_aggregation_labels > 0
    if do_model_aggregation:
        hidden_size_agg = output_layer_aggregation.shape[-1].value
        output_weights_agg = tf.get_variable(
            "output_weights_agg",
            shape=[config.num_aggregation_labels, hidden_size_agg],
            initializer=_classification_initializer())
        output_bias_agg = tf.get_variable(
            "output_bias_agg",
            shape=[config.num_aggregation_labels],
            initializer=tf.zeros_initializer())

    do_model_classification = config.num_classification_labels > 0
    logits_cls = None
    if do_model_classification:
        logits_cls = compute_classification_logits(
            config.num_classification_labels, output_layer_aggregation)

    with tf.variable_scope("loss"):
        total_loss = 0.0
        is_supervised = (not do_model_aggregation
                         or not config.use_answer_as_supervision)

        ### Semi-supervised cell selection in case of no aggregation
        #############################################################

        # If the answer (the denotation) appears directly in the table we might
        # select the answer without applying any aggregation function. There are
        # some ambiguous cases, see _calculate_aggregate_mask for more info.
        # `aggregate_mask` is 1 for examples where we chose to aggregate and 0
        #  for examples where we chose to select the answer directly.
        # `label_ids` encodes the positions of the answer appearing in the table.
        if is_supervised:
            aggregate_mask = None
        else:
            # <float32>[batch_size]
            aggregate_mask = _calculate_aggregate_mask(
                answer=answer,
                output_layer_aggregation=output_layer_aggregation,
                output_bias_agg=output_bias_agg,
                output_weights_agg=output_weights_agg,
                cell_select_pref=config.cell_select_pref,
                label_ids=label_ids)

        ### Cell selection log-likelihood
        ###################################

        if config.average_logits_per_cell:
            logits_per_cell, _ = segmented_tensor.reduce_mean(
                logits, cell_index)
            logits = segmented_tensor.gather(logits_per_cell, cell_index)
        dist_per_token = tfp.distributions.Bernoulli(logits=logits)

        selection_loss_per_example = None
        if config.select_one_column:
            selection_loss_per_example, logits = _single_column_cell_selection_loss(
                token_logits=logits,
                column_logits=column_logits,
                label_ids=label_ids,
                cell_index=cell_index,
                col_index=col_index,
                cell_mask=cell_mask)
            dist_per_token = tfp.distributions.Bernoulli(logits=logits)
        else:
            weight = tf.where(
                tf.equal(label_ids, 0),
                tf.ones_like(label_ids, dtype=tf.float32),
                config.positive_weight *
                tf.ones_like(label_ids, dtype=tf.float32))
            selection_loss_per_token = -dist_per_token.log_prob(
                label_ids) * weight
            selection_loss_per_example = (
                tf.reduce_sum(selection_loss_per_token * input_mask_float,
                              axis=1) /
                (tf.reduce_sum(input_mask_float, axis=1) +
                 _EPSILON_ZERO_DIVISION))

        ### Logits for the aggregation function
        #########################################

        logits_aggregation = None
        if do_model_aggregation:
            logits_aggregation = _calculate_aggregation_logits(
                output_layer_aggregation, output_weights_agg, output_bias_agg)

        ### Classification loss
        ###############################
        if do_model_classification:
            one_hot_labels = tf.one_hot(classification_class_index,
                                        depth=config.num_classification_labels,
                                        dtype=tf.float32)
            if config.classification_label_weight:
                label_weights = [
                    config.classification_label_weight.get(i, 1.0)
                    for i in range(config.num_classification_labels)
                ]
                one_hot_labels *= tf.constant(label_weights, dtype=tf.float32)
            log_probs = tf.nn.log_softmax(logits_cls, axis=-1)
            # <float32>[batch_size]
            per_example_classification_intermediate = -tf.reduce_sum(
                one_hot_labels * log_probs, axis=-1)

            cls_loss = tf.reduce_mean(per_example_classification_intermediate)
            total_loss += cls_loss

        ### Supervised cell selection
        ###############################

        span_indexes = None
        span_logits = None
        if config.span_prediction != SpanPredictionMode.NONE:
            (
                span_indexes,
                span_logits,
                span_loss,
            ) = span_prediction_utils.get_span_logits_by_mode(
                config.span_prediction,
                output_layer,
                label_ids,
                column_ids,
                row_ids,
                max_span_length=10,
            )
            total_loss += span_loss
        elif config.disable_per_token_loss:
            pass
        elif config.mask_examples_without_labels:
            total_loss += tf.reduce_mean(
                span_prediction_utils.compute_masked_example_loss(
                    label_ids,
                    selection_loss_per_example,
                ))
        elif is_supervised:
            total_loss += tf.reduce_mean(selection_loss_per_example)
        else:
            # For the unsupervised case, do not assign a loss for cell selection.
            total_loss += tf.reduce_mean(selection_loss_per_example *
                                         (1.0 - aggregate_mask))

        ### Semi-supervised regression loss and supervised loss for aggregations
        #########################################################################

        if do_model_aggregation:
            # Note that `aggregate_mask` is None if the setting is supervised.
            per_example_additional_loss = _calculate_aggregation_loss(
                logits_aggregation, aggregate_mask, aggregation_function_id,
                config)

            if config.use_answer_as_supervision:
                # Add regression loss for numeric answers which require aggregation.
                answer_loss, large_answer_loss_mask = _calculate_regression_loss(
                    answer, aggregate_mask, dist_per_token, numeric_values,
                    numeric_values_scale, table_mask_float, logits_aggregation,
                    config)
                per_example_additional_loss += answer_loss
                # Zero loss for examples with answer_loss > cutoff.
                per_example_additional_loss *= large_answer_loss_mask

            total_loss += tf.reduce_mean(per_example_additional_loss)

        return Outputs(
            total_loss=total_loss,
            logits=logits,
            probs=_get_probs(dist_per_token) * input_mask_float,
            logits_aggregation=logits_aggregation,
            logits_cls=logits_cls,
            span_indexes=span_indexes,
            span_logits=span_logits,
        )
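
The else branch of the cell-selection loss above is a weighted per-token Bernoulli negative log-likelihood averaged over real tokens. A minimal standalone sketch of that branch with made-up logits and labels (sigmoid cross-entropy equals the Bernoulli NLL; positive_weight stands in for config.positive_weight):

import tensorflow as tf

logits = tf.constant([[2.0, -1.0, 0.0]])
label_ids = tf.constant([[1, 0, 0]])
input_mask_float = tf.constant([[1.0, 1.0, 0.0]])
positive_weight = 10.0  # hypothetical value
nll = tf.nn.sigmoid_cross_entropy_with_logits(
    labels=tf.cast(label_ids, tf.float32), logits=logits)
weight = tf.where(tf.equal(label_ids, 0),
                  tf.ones_like(nll), positive_weight * tf.ones_like(nll))
loss_per_example = tf.reduce_sum(
    nll * weight * input_mask_float, axis=1) / (
        tf.reduce_sum(input_mask_float, axis=1) + 1e-10)
# Up-weights positive labels and ignores padding positions.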
Example #12
    def call(self, inputs, **kwargs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        word_embeddings = unpacked_inputs[0]
        segment_ids = unpacked_inputs[1]
        column_ids = unpacked_inputs[2]
        row_ids = unpacked_inputs[3]
        prev_label_ids = unpacked_inputs[4]
        column_ranks = unpacked_inputs[5]
        inv_column_ranks = unpacked_inputs[6]
        numeric_relations = unpacked_inputs[7]
        input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = word_embeddings
        token_type_ids_list = [segment_ids, column_ids, row_ids, prev_label_ids,
                               column_ranks, inv_column_ranks, numeric_relations]
        token_type_embeddings_list = [self.segment_embeddings, self.column_embeddings, self.row_embeddings, self.prev_label_embeddings,
                                      self.column_ranks_embeddings, self.inv_column_ranks_embeddings, self.numeric_relations_embeddings]
        if self.use_type_embeddings:
            for i, (token_type_ids, type_embeddings) in enumerate(zip(token_type_ids_list, token_type_embeddings_list)):
                flat_token_type_ids = tf.reshape(token_type_ids, [-1])
                one_hot_ids = tf.one_hot(
                    flat_token_type_ids,
                    depth=self.token_type_vocab_size[i],
                    dtype=self.dtype)
                token_type_embeddings = tf.matmul(
                    one_hot_ids, type_embeddings)
                token_type_embeddings = tf.reshape(token_type_embeddings,
                                                   [batch_size, seq_length, width])
                output += token_type_embeddings

        if self.use_position_embeddings:
            if not self.reset_position_index_per_cell:
                position_embeddings = tf.expand_dims(
                    tf.slice(self.position_embeddings, [
                        0, 0], [seq_length, width]),
                    axis=0)
            else:
                col_index = segmented_tensor.IndexMap(
                    token_type_ids_list[1], self.token_type_vocab_size[1], batch_dims=1)
                row_index = segmented_tensor.IndexMap(
                    token_type_ids_list[2], self.token_type_vocab_size[2], batch_dims=1)
                full_index = segmented_tensor.ProductIndexMap(
                    col_index, row_index)
                position = tf.expand_dims(tf.range(seq_length), axis=0)
                batched_position = tf.repeat(
                    position, repeats=batch_size, axis=0)
                first_position_per_segment = segmented_tensor.reduce_min(
                    batched_position, full_index)[0]
                first_position = segmented_tensor.gather(first_position_per_segment,
                                                         full_index)
                position_embeddings = tf.nn.embedding_lookup(self.position_embeddings,
                                                             position - first_position)

            output += position_embeddings

        output = self.output_layer_norm(output)
        output = self.output_dropout(
            output, training=kwargs.get('training', False))

        return output
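
The one-hot matmul used for the type embeddings above is equivalent to an embedding lookup; a quick sketch of that equivalence with a made-up table:

import tensorflow as tf

table = tf.random.normal([5, 8])  # hypothetical vocab of 5, width 8
ids = tf.constant([3, 0, 4])
via_matmul = tf.matmul(tf.one_hot(ids, depth=5), table)
via_lookup = tf.nn.embedding_lookup(table, ids)
# Both select the same rows of the table (up to float rounding).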