Example #1
def build_model(etc_model_config: modeling.EtcConfig,
                features: Dict[str, tf.Tensor], flat_sequence: bool,
                is_training: bool, answer_encoding_method: str, use_tpu: bool,
                use_wordpiece: bool):
    """Build the ETC HotpotQA model."""
    long_token_ids = features["long_token_ids"]
    long_sentence_ids = features["long_sentence_ids"]
    long_paragraph_ids = features["long_paragraph_ids"]
    long_paragraph_breakpoints = features["long_paragraph_breakpoints"]
    long_token_type_ids = features["long_token_type_ids"]
    global_token_ids = features["global_token_ids"]
    global_paragraph_breakpoints = features["global_paragraph_breakpoints"]
    global_token_type_ids = features["global_token_type_ids"]

    model = modeling.EtcModel(config=etc_model_config,
                              is_training=is_training,
                              use_one_hot_relative_embeddings=use_tpu)

    model_inputs = dict(token_ids=long_token_ids,
                        global_token_ids=global_token_ids,
                        segment_ids=long_token_type_ids,
                        global_segment_ids=global_token_type_ids)

    cls_token_id = (generate_tf_examples_lib.
                    SENTENCEPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])
    if use_wordpiece:
        cls_token_id = (generate_tf_examples_lib.
                        WORDPIECE_DEFAULT_GLOBAL_TOKEN_IDS["CLS_TOKEN_ID"])

    model_inputs.update(
        qa_input_utils.make_global_local_transformer_side_inputs(
            long_paragraph_breakpoints=long_paragraph_breakpoints,
            long_paragraph_ids=long_paragraph_ids,
            long_sentence_ids=long_sentence_ids,
            global_paragraph_breakpoints=global_paragraph_breakpoints,
            local_radius=etc_model_config.local_radius,
            relative_pos_max_distance=(
                etc_model_config.relative_pos_max_distance),
            use_hard_g2l_mask=etc_model_config.use_hard_g2l_mask,
            ignore_hard_g2l_mask=tf.cast(tf.equal(global_token_ids,
                                                  cls_token_id),
                                         dtype=long_sentence_ids.dtype),
            flat_sequence=flat_sequence,
            use_hard_l2g_mask=etc_model_config.use_hard_l2g_mask).to_dict(
                exclude_none_values=True))

    long_output, global_output = model(**model_inputs)

    batch_size, long_seq_length, long_hidden_size = tensor_utils.get_shape_list(
        long_output, expected_rank=3)
    _, global_seq_length, global_hidden_size = tensor_utils.get_shape_list(
        global_output, expected_rank=3)

    long_output_matrix = tf.reshape(
        long_output, [batch_size * long_seq_length, long_hidden_size])
    global_output_matrix = tf.reshape(
        global_output, [batch_size * global_seq_length, global_hidden_size])

    # Get the logits for the supporting facts predictions.
    supporting_facts_output_weights = tf.get_variable(
        "supporting_facts_output_weights", [1, global_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    supporting_facts_output_bias = tf.get_variable(
        "supporting_facts_output_bias", [1],
        initializer=tf.zeros_initializer())
    supporting_facts_logits = tf.matmul(global_output_matrix,
                                        supporting_facts_output_weights,
                                        transpose_b=True)
    supporting_facts_logits = tf.nn.bias_add(supporting_facts_logits,
                                             supporting_facts_output_bias)
    supporting_facts_logits = tf.reshape(supporting_facts_logits,
                                         [batch_size, global_seq_length])

    # Get the logits for the answer type prediction.
    num_answer_types = 3  # SPAN, YES, NO
    answer_type_output_weights = tf.get_variable(
        "answer_type_output_weights", [num_answer_types, global_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    answer_type_output_bias = tf.get_variable(
        "answer_type_output_bias", [num_answer_types],
        initializer=tf.zeros_initializer())
    answer_type_logits = tf.matmul(global_output[:, 0, :],
                                   answer_type_output_weights,
                                   transpose_b=True)
    answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                        answer_type_output_bias)

    extra_model_losses = model.losses

    if answer_encoding_method == "span":
        # Get the logits for the begin and end indices.
        answer_span_output_weights = tf.get_variable(
            "answer_span_output_weights", [2, long_hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        answer_span_output_bias = tf.get_variable(
            "answer_span_output_bias", [2], initializer=tf.zeros_initializer())
        answer_span_logits = tf.matmul(long_output_matrix,
                                       answer_span_output_weights,
                                       transpose_b=True)
        answer_span_logits = tf.nn.bias_add(answer_span_logits,
                                            answer_span_output_bias)
        answer_span_logits = tf.reshape(answer_span_logits,
                                        [batch_size, long_seq_length, 2])
        answer_span_logits = tf.transpose(answer_span_logits, [2, 0, 1])
        answer_begin_logits, answer_end_logits = tf.unstack(answer_span_logits,
                                                            axis=0)

        return (supporting_facts_logits, (answer_begin_logits,
                                          answer_end_logits),
                answer_type_logits, extra_model_losses)
    else:
        # Get the logits for the answer BIO encodings.
        answer_bio_output_weights = tf.get_variable(
            "answer_bio_output_weights", [3, long_hidden_size],
            initializer=tf.truncated_normal_initializer(stddev=0.02))
        answer_bio_output_bias = tf.get_variable(
            "answer_bio_output_bias", [3], initializer=tf.zeros_initializer())
        answer_bio_logits = tf.matmul(long_output_matrix,
                                      answer_bio_output_weights,
                                      transpose_b=True)
        answer_bio_logits = tf.nn.bias_add(answer_bio_logits,
                                           answer_bio_output_bias)
        answer_bio_logits = tf.reshape(answer_bio_logits,
                                       [batch_size, long_seq_length, 3])

        return (supporting_facts_logits, answer_bio_logits, answer_type_logits,
                extra_model_losses)
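
For reference, here is a minimal usage sketch of the function above; `etc_config` and `features` are hypothetical placeholders, and only the feature keys and argument names come from the function body.

# Usage sketch (assumption: `etc_config` is an EtcConfig and `features` is the
# parsed feature dict containing the "long_*" and "global_*" keys read above).
(supporting_facts_logits, (answer_begin_logits, answer_end_logits),
 answer_type_logits, extra_losses) = build_model(
     etc_model_config=etc_config,
     features=features,
     flat_sequence=False,
     is_training=True,
     answer_encoding_method="span",
     use_tpu=False,
     use_wordpiece=False)
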
Example #2
def build_model(etc_model_config, features, is_training, flags):
    """Build an ETC model."""
    token_ids = features["token_ids"]
    global_token_ids = features["global_token_ids"]

    model = modeling.EtcModel(config=etc_model_config,
                              is_training=is_training,
                              use_one_hot_relative_embeddings=flags.use_tpu)

    model_inputs = dict(token_ids=token_ids, global_token_ids=global_token_ids)
    for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
        if field.name in features:
            model_inputs[field.name] = features[field.name]

    # Get the logits for the start and end predictions.
    l_final_hidden, _ = model(**model_inputs)

    l_final_hidden_shape = tensor_utils.get_shape_list(l_final_hidden,
                                                       expected_rank=3)

    batch_size = l_final_hidden_shape[0]
    l_seq_length = l_final_hidden_shape[1]
    hidden_size = l_final_hidden_shape[2]

    num_answer_types = 5  # NULL, YES, NO, LONG, SHORT

    # Add a dense layer to the long output, producing 4 logits per position
    # (start/end positions for the long and short answers):
    l_output_weights = tf.get_variable(
        "cls/nq/long_output_weights", [4, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    l_output_bias = tf.get_variable("cls/nq/long_output_bias", [4],
                                    initializer=tf.zeros_initializer())
    l_final_hidden_matrix = tf.reshape(
        l_final_hidden, [batch_size * l_seq_length, hidden_size])
    l_logits = tf.matmul(l_final_hidden_matrix,
                         l_output_weights,
                         transpose_b=True)
    l_logits = tf.nn.bias_add(l_logits, l_output_bias)
    l_logits = tf.reshape(l_logits, [batch_size, l_seq_length, 4])

    if flags.mask_long_output:
        # Mask out invalid SA/LA start/end positions:
        # 1) find the SEP and CLS tokens:
        long_sep = tf.cast(tf.equal(token_ids, flags.sep_tok_id), tf.int32)
        long_not_sep = 1 - long_sep
        long_cls = tf.cast(tf.equal(token_ids, flags.cls_tok_id), tf.int32)

        # 2) cumulative-sum the SEPs; the only valid answer positions are those
        #    where the cumulative sum equals 1 (excluding the SEPs themselves),
        #    plus the CLS position
        l_mask = tf.cast(tf.equal(tf.cumsum(long_sep, axis=-1), 1), tf.int32)
        l_mask = 1 - ((l_mask * long_not_sep) + long_cls)

        # 3) apply the mask to the logits
        l_mask = tf.expand_dims(tf.cast(l_mask, tf.float32) * -10E8, 2)
        l_logits = tf.math.add(l_logits, l_mask)

    # Get the logits for the answer type prediction.
    answer_type_output_layer = l_final_hidden[:, 0, :]
    answer_type_hidden_size = answer_type_output_layer.shape[-1].value

    answer_type_output_weights = tf.get_variable(
        "answer_type_output_weights",
        [num_answer_types, answer_type_hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    answer_type_output_bias = tf.get_variable(
        "answer_type_output_bias", [num_answer_types],
        initializer=tf.zeros_initializer())

    answer_type_logits = tf.matmul(answer_type_output_layer,
                                   answer_type_output_weights,
                                   transpose_b=True)
    answer_type_logits = tf.nn.bias_add(answer_type_logits,
                                        answer_type_output_bias)

    extra_model_losses = model.losses

    l_logits = tf.transpose(l_logits, [2, 0, 1])
    l_unstacked_logits = tf.unstack(l_logits, axis=0)
    return ([l_unstacked_logits[i]
             for i in range(4)], answer_type_logits, extra_model_losses)
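
A minimal usage sketch follows; `etc_config` and `features` are hypothetical placeholders, and the flag attributes mirror the ones read inside the function (101/102 are the standard BERT [CLS]/[SEP] wordpiece ids, shown only as an example).

from types import SimpleNamespace

# Usage sketch with hypothetical config, features, and flag values.
flags = SimpleNamespace(use_tpu=False, mask_long_output=True,
                        sep_tok_id=102, cls_tok_id=101)
start_end_logits, answer_type_logits, extra_losses = build_model(
    etc_config, features, is_training=True, flags=flags)
# `start_end_logits` is a list of 4 [batch_size, seq_len] tensors.
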
Example #3
def process_model_output(model_config,
                         mode,
                         global_output_tensor,
                         global_token_type_ids_tensor,
                         labels,
                         is_real_example,
                         add_final_layer=True,
                         label_smoothing=0.0):
    """Process model output embeddings and computes loss, logits etc."""

    global_output_tensor_shape = tensor_utils.get_shape_list(
        global_output_tensor, expected_rank=3)
    batch_size = global_output_tensor_shape[0]
    global_seq_len = global_output_tensor_shape[1]
    hidden_size = global_output_tensor_shape[2]

    global_output_tensor = tf.reshape(
        global_output_tensor, [batch_size * global_seq_len, hidden_size])

    if add_final_layer:
        with tf.variable_scope("global_output_layer/transform"):
            is_training = mode == tf.estimator.ModeKeys.TRAIN
            final_layer = wrappers.ResidualBlock(
                inner_intermediate_size=model_config.intermediate_size,
                inner_activation=tensor_utils.get_activation(
                    model_config.hidden_act),
                use_pre_activation_order=False,
                dropout_probability=model_config.hidden_dropout_prob)
            global_output_tensor = final_layer(global_output_tensor,
                                               training=is_training)

    output_weights = tf.get_variable(
        "output_weights", [1, model_config.hidden_size],
        initializer=tf.truncated_normal_initializer(
            stddev=model_config.initializer_range))

    output_bias = tf.get_variable("output_bias", [1],
                                  initializer=tf.zeros_initializer())

    with tf.variable_scope("loss"):
        logits = tf.matmul(global_output_tensor,
                           output_weights,
                           transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)

        tf.logging.info("*** logits initial are {} *** ".format(logits))
        logits = tf.reshape(logits, [batch_size, global_seq_len])
        tf.logging.info("*** logits after reshape are {} *** ".format(logits))

        # Consider only candidate global tokens in the global output.
        multiplier_mask = tf.cast(
            tf.equal(global_token_type_ids_tensor,
                     multihop_utils.CANDIDATE_GLOBAL_TOKEN_TYPE_ID),
            dtype=logits.dtype)

        adder_mask = -10000.0 * (1.0 - multiplier_mask)

        logits = (logits * multiplier_mask + adder_mask)

        tf.logging.info("*** global_token_type_ids_tensor is {} *** ".format(
            global_token_type_ids_tensor))
        tf.logging.info("*** adder_mask is {} *** ".format(adder_mask))
        tf.logging.info(
            "*** multiplier_mask is {} *** ".format(multiplier_mask))
        tf.logging.info("*** logits computed are {} *** ".format(logits))

        # probabilities = tf.nn.softmax(logits, axis=-1)
        log_probs = tf.nn.log_softmax(logits, axis=-1)
        one_hot_labels = tf.one_hot(labels,
                                    depth=global_seq_len,
                                    dtype=tf.float32)
        if label_smoothing > 0:
            num_classes = tf.reduce_sum(multiplier_mask, axis=-1)
            num_classes = tf.expand_dims(num_classes, -1)
            one_hot_labels = (1 - label_smoothing) * one_hot_labels
            one_hot_labels += (label_smoothing / num_classes)
            # Ensure smoothing of labels only for applicable global (candidate)
            # tokens.
            one_hot_labels *= multiplier_mask

        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

        numerator = tf.reduce_sum(per_example_loss * is_real_example)
        denominator = tf.reduce_sum(is_real_example) + 1e-5
        loss = numerator / denominator

        return (loss, per_example_loss, logits)
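
As a hedged illustration, the call below shows how the outputs of an ETC model could be fed into this function; every input name here (`model_config`, `global_output`, `global_token_type_ids`, `labels`, `is_real_example`) is a placeholder assumed to be produced elsewhere in the model_fn.

# Usage sketch with placeholder inputs.
loss, per_example_loss, logits = process_model_output(
    model_config=model_config,
    mode=tf.estimator.ModeKeys.TRAIN,
    global_output_tensor=global_output,          # [batch, global_seq_len, hidden]
    global_token_type_ids_tensor=global_token_type_ids,
    labels=labels,                               # [batch] candidate index per example
    is_real_example=is_real_example,             # [batch] 1.0 real / 0.0 padding
    add_final_layer=True,
    label_smoothing=0.1)
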
Example #4
def make_fixed_block_side_inputs(
    input_mask: tf.Tensor,
    num_tokens_per_block: int,
    local_radius: int,
    relative_pos_max_distance: int,
    use_hard_g2l_mask: bool = False,
    use_hard_l2g_mask: bool = False,
    global_token_id: int = 1,
    name: Optional[Text] = None
) -> Tuple[GlobalLocalTransformerSideInputs, tf.Tensor]:
    """Utility for creating side inputs in a "fixed blocks" pattern.

  The "fixed blocks" experiments for NQ and OpenKP are implemented via example
  generation rather than using this function, but we include this function
  to illustrate how side inputs can be generated given just a BERT-style
  `input_mask` feature.  The corresponding global tokens are generated
  as part of this function too, so no global features are required as input.

  Args:
    input_mask: <int32>[batch_size, long_seq_len] Tensor of 1 and 0 values, with
      1 for actual tokens and 0 for padding.  This is the same format as
      original BERT.  `long_seq_len` must be statically known.
    num_tokens_per_block: Positive integer number of long tokens to assign to
      each global token.  For pre-training on the original BERT data (which was
      also used for ETC pre-training), the dataset implied a value of about 27,
      but values like 16 or 32 would also be reasonable.
    local_radius: How many tokens to the left/right for input tokens to locally
      self-attend to.  For example, a value of 1 would allow each token to only
      attend to 1 token to the left and 1 token to the right of it.
    relative_pos_max_distance: Maximum distance to use for relative position
      representations.  All larger distances will be clipped to this value. Use
      0 to skip relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of their
      corresponding block in the long input.  If False, global tokens attend to
      all non-padding long tokens.  False is the default setup.
    use_hard_l2g_mask: If True, long tokens only attend to the global token
      corresponding to their block.  If False, long tokens attend to all the
      non-padding global tokens.  False is the default setup.
    global_token_id: Integer id to use for global tokens.  The default is `1`,
      which was the value used during ETC pre-training.
    name: A name for the operation (optional).

  Returns:
    A tuple with the following 2 elements:
      side_inputs: A `GlobalLocalTransformerSideInputs` object containing all
        side input tensors.
      global_token_ids: <int32>[batch_size, global_seq_len] Tensor of global
        tokens ids suitable to pass into `EtcModel`.  All global tokens will
        use the same `global_token_id`, except for padding tokens.
  """
    if num_tokens_per_block <= 0:
        raise ValueError('`num_tokens_per_block` must be positive.')

    with tf.name_scope(name or 'make_fixed_block_side_inputs'):
        input_mask = tf.convert_to_tensor(input_mask)

        batch_size = tensor_utils.get_shape_list(input_mask)[0]
        long_seq_len = input_mask.shape.as_list()[1]
        if long_seq_len is None:
            raise ValueError('`long_seq_len` must be statically known.')

        global_seq_len = (long_seq_len + num_tokens_per_block -
                          1) // num_tokens_per_block

        # [batch_size, global_seq_len, num_tokens_per_block]
        blocked_input_mask = tensor_utils.split_into_blocks(
            input_mask, block_len=num_tokens_per_block, axis=-1)
        assert blocked_input_mask.shape.as_list()[1] == global_seq_len

        # [batch_size, global_seq_len]
        global_input_mask = tf.minimum(
            tf.reduce_max(blocked_input_mask, axis=-1), 1)

        # [long_seq_len]
        sentence_ids = tf.repeat(tf.range(global_seq_len, dtype=tf.int32),
                                 num_tokens_per_block)[:long_seq_len]

        # [batch_size, long_seq_len]
        sentence_ids = tf.broadcast_to(sentence_ids,
                                       [batch_size, long_seq_len])

        side_inputs = make_global_local_transformer_side_inputs_from_example_ids(
            long_example_ids=input_mask,
            global_example_ids=global_input_mask,
            sentence_ids=sentence_ids,
            local_radius=local_radius,
            relative_pos_max_distance=relative_pos_max_distance,
            use_hard_g2l_mask=use_hard_g2l_mask,
            use_hard_l2g_mask=use_hard_l2g_mask)
        global_token_ids = global_token_id * global_input_mask
        return side_inputs, global_token_ids
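
A small sketch of how the returned side inputs might be passed to `EtcModel`, following the pattern in Example #1; the mask values and `token_ids` are illustrative placeholders.

# Usage sketch: 8 long positions grouped into blocks of 4 -> 2 global tokens.
input_mask = tf.constant([[1, 1, 1, 1, 1, 0, 0, 0]], dtype=tf.int32)
side_inputs, global_token_ids = make_fixed_block_side_inputs(
    input_mask=input_mask,
    num_tokens_per_block=4,
    local_radius=2,
    relative_pos_max_distance=2)
model_inputs = dict(token_ids=token_ids,  # placeholder long token ids
                    global_token_ids=global_token_ids)
model_inputs.update(side_inputs.to_dict(exclude_none_values=True))
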
Example #5
def make_global_local_transformer_side_inputs_from_example_ids(
        long_example_ids: tf.Tensor,
        global_example_ids: tf.Tensor,
        sentence_ids: tf.Tensor,
        local_radius: int,
        relative_pos_max_distance: int,
        use_hard_g2l_mask: bool = False,
        use_hard_l2g_mask: bool = False,
        name: Optional[Text] = None) -> GlobalLocalTransformerSideInputs:
    """Makes side input tensors based on the given example and sentence ids.

  When packing examples (e.g. for pre-training), each example must have a
  unique id for `long_example_ids`/`global_example_ids`, and padding must
  also have a unique id distinct from all the example ids.

  When not packing examples, there will simply be two unique ids: one for
  example tokens, and another for padding.  Note that in this case, the classic
  BERT `input_mask` is a valid special case of `long_example_ids`.

  The other arguments have the same interpretation as in
  `make_global_local_transformer_side_inputs`.

  Args:
    long_example_ids: <int32>[batch_size, long_seq_len] Tensor of example ids of
      different packed examples.
    global_example_ids: <int32>[batch_size, global_seq_len] Tensor of example
      ids of different packed examples.
    sentence_ids: <int32>[batch_size, long_seq_len] Tensor of ids indicating
      which sentence each token belongs to. For this dataset, "sentence" refers
      to a real natural language sentence, not a BERT "sentence" from the "next
      sentence prediction" task.
    local_radius: How many tokens to the left/right for input tokens to locally
      self-attend to. For example, a value of 1 would allow each token to only
      attend to 1 token to the left and 1 token to the right of it.
    relative_pos_max_distance: Maximum distance to use for relative position
      representations. All larger distances will be clipped to this value. Use 0
      to skip relative position representations entirely.
    use_hard_g2l_mask: If True, global tokens only attend to tokens of the
      corresponding sentences in the long input. If False, global tokens attend
      to all sentences within the corresponding global example.
    use_hard_l2g_mask: If True, long tokens only attend to tokens of the
      corresponding global tokens. If False, long tokens attend to all the
      global tokens within the corresponding global example.
    name: A name for the operation (optional).

  Returns:
    A `GlobalLocalTransformerSideInputs` with all relevant tensors set.
  """
    with tf.name_scope(name or 'make_global_local_transformer_side_inputs'):
        long_example_ids = tf.convert_to_tensor(long_example_ids)
        global_example_ids = tf.convert_to_tensor(global_example_ids)
        sentence_ids = tf.convert_to_tensor(sentence_ids)

        long_seq_len = tensor_utils.get_shape_list(long_example_ids)[1]
        global_seq_len = tensor_utils.get_shape_list(global_example_ids)[1]

        l2l_att_mask = feature_utils.make_local_segmented_att_mask(
            long_example_ids, local_radius)
        g2g_att_mask = feature_utils.make_segmented_att_mask(
            global_example_ids)

        l2g_att_mask = tf.cast(
            tf.equal(long_example_ids[:, :, tf.newaxis],
                     global_example_ids[:, tf.newaxis, :]), tf.int32)
        g2l_att_mask = tf.transpose(l2g_att_mask, perm=[0, 2, 1])

        if use_hard_g2l_mask:
            # Have each global token attend to just one sentence instead of having
            # it attend to all the sentences within a global example.
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            hard_g2l_att_mask = tf.cast(
                tf.equal(global_range[tf.newaxis, :, tf.newaxis],
                         sentence_ids[:, tf.newaxis, :]), tf.int32)
            g2l_att_mask *= hard_g2l_att_mask

        if use_hard_l2g_mask:
            # Have each long token attend to just the corresponding global token
            # instead of having it attend to all the global tokens within a
            # global example.
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            hard_l2g_att_mask = tf.cast(
                tf.equal(sentence_ids[:, :, tf.newaxis],
                         global_range[tf.newaxis, tf.newaxis, :]), tf.int32)
            l2g_att_mask *= hard_l2g_att_mask

        batch_size = tf.shape(long_example_ids)[0]

        l2l_relative_att_ids = None
        g2g_relative_att_ids = None
        l2g_relative_att_ids = None
        g2l_relative_att_ids = None

        if relative_pos_max_distance > 0:
            relative_pos_generator = feature_utils.RelativePositionGenerator(
                relative_pos_max_distance)
            l2l_relative_att_ids = relative_pos_generator.make_local_relative_att_ids(
                seq_len=long_seq_len,
                local_radius=local_radius,
                batch_size=batch_size)
            g2g_relative_att_ids = relative_pos_generator.make_relative_att_ids(
                seq_len=global_seq_len, batch_size=batch_size)
            global_range = tf.range(global_seq_len, dtype=sentence_ids.dtype)
            l2g_relative_att_ids = tf.cast(
                tf.equal(sentence_ids[:, :, tf.newaxis],
                         global_range[tf.newaxis, tf.newaxis, :]), tf.int32)
            g2l_relative_att_ids = tf.transpose(l2g_relative_att_ids,
                                                perm=[0, 2, 1])

            # For fused attention, l2l and l2g share the same relative vocabulary, as
            # do g2g and g2l, so we add an offset for l2g and g2l so their original
            # 0/1 ids don't collide with l2l and g2g relative position ids.
            l2g_relative_att_ids += relative_pos_generator.relative_vocab_size
            g2l_relative_att_ids += relative_pos_generator.relative_vocab_size

        return GlobalLocalTransformerSideInputs(
            l2l_att_mask=l2l_att_mask,
            g2g_att_mask=g2g_att_mask,
            l2g_att_mask=l2g_att_mask,
            g2l_att_mask=g2l_att_mask,
            l2l_relative_att_ids=l2l_relative_att_ids,
            g2g_relative_att_ids=g2g_relative_att_ids,
            l2g_relative_att_ids=l2g_relative_att_ids,
            g2l_relative_att_ids=g2l_relative_att_ids)
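
A toy example of the unpacked case mentioned in the docstring, where a BERT-style `input_mask` doubles as `long_example_ids`; all tensor values are made up for illustration.

# Usage sketch: one example of 4 real long tokens split into 2 sentences,
# with 2 real global (sentence) tokens and 1 padding global token.
long_example_ids = tf.constant([[1, 1, 1, 1, 0, 0]], dtype=tf.int32)
global_example_ids = tf.constant([[1, 1, 0]], dtype=tf.int32)
sentence_ids = tf.constant([[0, 0, 1, 1, 0, 0]], dtype=tf.int32)
side_inputs = make_global_local_transformer_side_inputs_from_example_ids(
    long_example_ids=long_example_ids,
    global_example_ids=global_example_ids,
    sentence_ids=sentence_ids,
    local_radius=2,
    relative_pos_max_distance=2,
    use_hard_g2l_mask=True)
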
Example #6
def make_local_segmented_att_mask(segment_ids: tf.Tensor,
                                  local_radius: int,
                                  name: Optional[Text] = None) -> tf.Tensor:
    """Makes local attention mask preventing attention across different segments.

  Restricts local self-attention to attend within segments, such that tokens can
  only attend to local tokens from the same segment id. The tokens in a segment
  do not need to be contiguous, but attention is still constrained by
  `local_radius`. The output can be used as `l2l_att_mask` in
  `layers.GlobalLocalTransformerLayers` for example.

  Args:
    segment_ids: <int32>[batch_size, seq_len] Tensor of segment ids, all of
      which must be non-negative.
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
    with tf.name_scope(name or 'make_local_segmented_att_mask'):
        segment_ids = tf.convert_to_tensor(segment_ids)

        if segment_ids.shape.rank != 2:
            raise ValueError('`segment_ids` must be a 2-D tensor.')

        batch_size, seq_len = tensor_utils.get_shape_list(segment_ids)

        # Add 1 so that segment id `0` doesn't coincide with `0` padding values
        # introduced later by `tensor_utils.concat_3_blocks()` for example.
        segment_ids += 1

        # [batch_size, num_blocks, local_radius]
        blocked_segment_ids = tensor_utils.split_into_blocks(
            segment_ids, block_len=local_radius, axis=1)

        # [batch_size, num_blocks, 3*local_radius]
        concat_blocked_segment_ids = tensor_utils.concat_3_blocks(
            blocked_segment_ids)

        # [batch_size, num_blocks, local_radius, 3*local_radius]
        tiled_segment_ids = tf.tile(
            concat_blocked_segment_ids[:, :, tf.newaxis, :],
            [1, 1, local_radius, 1])

        # [batch_size, num_blocks, local_radius, 2*local_radius + 1]
        blocked_unskewed_segment_ids = tensor_utils.unskew_elements_right(
            tiled_segment_ids, axis=-1)

        # [batch_size, num_blocks * local_radius, 2*local_radius + 1]
        flat_unskewed_segment_ids = tensor_utils.flatten_dims(
            blocked_unskewed_segment_ids, first_dim=1, last_dim=2)

        # [batch_size, seq_len, 2*local_radius + 1]
        unskewed_segment_ids = tf.slice(flat_unskewed_segment_ids,
                                        begin=[0, 0, 0],
                                        size=[-1, seq_len, -1])

        # [batch_size, seq_len, 1]
        center_token_segment_id = unskewed_segment_ids[:, :, local_radius:(
            local_radius + 1)]

        # [batch_size, seq_len, 2*local_radius + 1]
        result = tf.cast(
            tf.equal(unskewed_segment_ids, center_token_segment_id), tf.int32)

        # Use `reshape` to set the static shape when known.
        return tf.reshape(result, [batch_size, seq_len, 2 * local_radius + 1])
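
To make the shape contract concrete, a minimal toy call (values chosen only for illustration):

# Usage sketch: one batch row with two segments and local_radius=1.
# The result has shape [1, 4, 2*1 + 1] = [1, 4, 3]; each position can only
# attend to in-radius neighbours that share its segment id.
segment_ids = tf.constant([[0, 0, 1, 1]], dtype=tf.int32)
l2l_att_mask = make_local_segmented_att_mask(segment_ids, local_radius=1)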