    def test_shift_elements_right_2d(self):
        tensor = tf.constant([
            [1, 2, 3, 4],  #
            [5, 6, 7, 8],  #
            [9, 10, 11, 12],  #
        ])

        self.assertAllEqual(
            [
                [0, 1, 2, 3],  #
                [0, 5, 6, 7],  #
                [0, 9, 10, 11],  #
            ],
            tensor_utils.shift_elements_right(tensor))

        self.assertAllEqual(
            [
                [3, 4, -1, -1],  #
                [7, 8, -1, -1],  #
                [11, 12, -1, -1],  #
            ],
            tensor_utils.shift_elements_right(tensor, amount=-2, pad_value=-1))

        self.assertAllEqual(
            [
                [0, 0, 0, 0],  #
                [0, 0, 0, 0],  #
                [1, 2, 3, 4],  #
            ],
            tensor_utils.shift_elements_right(tensor, axis=-2, amount=2))

        self.assertAllEqual(
            tensor, tensor_utils.shift_elements_right(tensor, amount=0))

    def test_shift_elements_right_1d(self):
        tensor = tf.constant([5, 4, 3, 2, 1])

        self.assertAllEqual([0, 5, 4, 3, 2],
                            tensor_utils.shift_elements_right(tensor))

        self.assertAllEqual([0, 0, 0, 5, 4],
                            tensor_utils.shift_elements_right(tensor,
                                                              amount=3))

        self.assertAllEqual([3, 2, 1, 0, 0],
                            tensor_utils.shift_elements_right(tensor,
                                                              amount=-2))

        self.assertAllEqual([3, 2, 1, -1, -1],
                            tensor_utils.shift_elements_right(tensor,
                                                              amount=-2,
                                                              pad_value=-1))

        self.assertAllEqual([0, 0, 0, 0, 0],
                            tensor_utils.shift_elements_right(tensor,
                                                              amount=10))

        self.assertAllEqual(
            tensor, tensor_utils.shift_elements_right(tensor, amount=0))

        with self.assertRaises(ValueError):
            tensor_utils.shift_elements_right(tensor, axis=1)
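
The tests above pin down the full contract: positive `amount` shifts toward higher indices, negative `amount` shifts the other way, vacated positions take `pad_value`, shifting past the end yields all padding, and an out-of-range `axis` raises `ValueError`. A minimal pad-and-slice sketch consistent with that contract, assuming statically known shapes (illustrative only, not the library's actual implementation):

import tensorflow as tf


def shift_elements_right(tensor, axis=-1, amount=1, pad_value=0):
  """Shifts elements along `axis` by `amount`, filling with `pad_value`."""
  tensor = tf.convert_to_tensor(tensor)
  rank = tensor.shape.rank
  if axis >= rank or axis < -rank:
    raise ValueError(f'`axis` {axis} is out of range for a rank-{rank} tensor.')
  axis %= rank
  length = tensor.shape.as_list()[axis]
  # Clamp so shifting past either end produces an all-`pad_value` result.
  shift = max(-length, min(amount, length))
  paddings = [[0, 0]] * rank
  paddings[axis] = [max(shift, 0), max(-shift, 0)]
  padded = tf.pad(tensor, paddings, constant_values=pad_value)
  # Slice back to the original length, dropping the elements shifted out.
  begin = [0] * rank
  begin[axis] = max(-shift, 0)
  size = [-1] * rank
  size[axis] = length
  return tf.slice(padded, begin, size)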
Example #3
def make_att_mask_from_breakpoints(att_breakpoints: tf.Tensor,
                                   use_starting_breakpoints: bool = False,
                                   name: Optional[Text] = None) -> tf.Tensor:
    """Makes self-attention mask from attention breakpoints.

  Each attention breakpoint marks the end of a segment by default (or the
  start if `use_starting_breakpoints` is True), and the resulting
  mask prevents attention across different segments.

  Args:
    att_breakpoints: <int32>[batch_size, seq_len] Tensor containing only 0 and 1
      values, where each "1" marks the end of a segment (or the start, depending
      on `use_starting_breakpoints`).
    use_starting_breakpoints: If True, breakpoints represent starts of segments
      rather than ends of segments. Default False.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, seq_len] attention mask.
  """
  with tf.name_scope(name or 'make_att_mask_from_breakpoints'):
    att_breakpoints = tf.convert_to_tensor(att_breakpoints)

    if att_breakpoints.shape.rank != 2:
      raise ValueError('`att_breakpoints` must be a 2-D tensor.')

    if not use_starting_breakpoints:
      # An ending breakpoint at position i means position i + 1 starts a new
      # segment, so shift the markers right by one before the cumulative sum.
      att_breakpoints = tensor_utils.shift_elements_right(
          att_breakpoints, axis=-1, amount=1)

    segment_ids = tf.cumsum(att_breakpoints, axis=1)
    return make_segmented_att_mask(segment_ids)
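
`make_segmented_att_mask` is not shown in this snippet; presumably it marks position pairs that share a segment id. A sketch under that assumption, followed by a small worked example (helper body and example values are illustrative, not the library's code):

def make_segmented_att_mask(segment_ids):
  # mask[b, i, j] = 1 iff positions i and j carry the same segment id.
  return tf.cast(
      tf.equal(segment_ids[:, :, tf.newaxis], segment_ids[:, tf.newaxis, :]),
      tf.int32)

# Ending breakpoints: the "1" at position 1 closes the first segment.
breakpoints = tf.constant([[0, 1, 0, 0]])
# Shift right by 1 -> [[0, 0, 1, 0]]; cumsum -> segment_ids [[0, 0, 1, 1]].
mask = make_att_mask_from_breakpoints(breakpoints)
# mask[0] ==
# [[1, 1, 0, 0],
#  [1, 1, 0, 0],
#  [0, 0, 1, 1],
#  [0, 0, 1, 1]]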
Example #4
def make_local_att_mask_from_breakpoints(
        att_breakpoints: tf.Tensor,
        local_radius: int,
        use_starting_breakpoints: bool = False,
        name: Optional[Text] = None) -> tf.Tensor:
    """Makes local self-attention mask from attention breakpoints.

  Each attention breakpoint marks the end of a segment by default (or the
  start if `use_starting_breakpoints` is True), and the resulting
  mask prevents attention across different segments. The result can be used as
  `l2l_att_mask` in `layers.GlobalLocalTransformerLayers` for example.

  Args:
    att_breakpoints: <int32>[batch_size, seq_len] Tensor containing only 0 and 1
      values, where each "1" marks the end of a segment (or the start, depending
      on `use_starting_breakpoints`).
    local_radius: The local radius as expected by
      `layers.GlobalLocalTransformerLayers`. Must be positive.
    use_starting_breakpoints: If True, breakpoints represent starts of segments
      rather than ends of segments. Default False.
    name: A name for the operation (optional).

  Returns:
    <int32>[batch_size, seq_len, 2*local_radius + 1] attention mask.
  """
  with tf.name_scope(name or 'make_local_att_mask_from_breakpoints'):
    att_breakpoints = tf.convert_to_tensor(att_breakpoints)

    if att_breakpoints.shape.rank != 2:
      raise ValueError('`att_breakpoints` must be a 2-D tensor.')

    if not use_starting_breakpoints:
      # An ending breakpoint at position i means position i + 1 starts a new
      # segment, so shift the markers right by one before the cumulative sum.
      att_breakpoints = tensor_utils.shift_elements_right(
          att_breakpoints, axis=-1, amount=1)

    # [batch_size, seq_len]
    segment_ids = tf.cumsum(att_breakpoints, axis=1)

    return make_local_segmented_att_mask(segment_ids, local_radius)
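
`make_local_segmented_att_mask` is likewise not shown; it presumably compares each position's segment id against those of its 2*local_radius + 1 neighbors. A sketch under that assumption, using a sentinel id so out-of-range neighbors never match (illustrative only):

def make_local_segmented_att_mask(segment_ids, local_radius):
  # Pad with a sentinel id (-1) that no real segment uses, so positions
  # outside the sequence never match.
  padded_ids = tf.pad(
      segment_ids, [[0, 0], [local_radius, local_radius]],
      constant_values=-1)
  # [batch_size, seq_len, 2*local_radius + 1]: window entry r of position i
  # corresponds to absolute position i - local_radius + r.
  neighbor_ids = tf.signal.frame(
      padded_ids, frame_length=2 * local_radius + 1, frame_step=1, axis=-1)
  return tf.cast(
      tf.equal(segment_ids[:, :, tf.newaxis], neighbor_ids), tf.int32)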
def _build_model(model_config, features, is_training, flags):
  """Build an ETC model for OpenKP."""

  global_embedding_adder = None
  long_embedding_adder = None

  # Create `global_embedding_adder` if using visual features.
  if flags.use_visual_features_in_global or flags.use_visual_features_in_long:
    global_embedding_adder = _create_global_visual_feature_embeddings(
        model_config, features, flags)

  if flags.use_visual_features_in_long:
    # Create `long_embedding_adder` based on `global_embedding_adder`
    long_embedding_adder = gather_global_embeddings_to_long(
        global_embedding_adder, features['long_vdom_idx'])

  if not flags.use_visual_features_in_global:
    global_embedding_adder = None

  model = modeling.EtcModel(
      config=model_config,
      is_training=is_training,
      use_one_hot_relative_embeddings=flags.use_tpu)

  model_inputs = dict(
      token_ids=features['long_token_ids'],
      global_token_ids=features['global_token_ids'],
      long_embedding_adder=long_embedding_adder,
      global_embedding_adder=global_embedding_adder)
  for field in attr.fields(input_utils.GlobalLocalTransformerSideInputs):
    model_inputs[field.name] = features[field.name]

  long_output, _ = model(**model_inputs)

  word_embeddings_unnormalized = batch_segment_sum_embeddings(
      long_embeddings=long_output,
      long_word_idx=features['long_word_idx'],
      long_input_mask=features['long_input_mask'])
  word_emb_layer_norm = tf.keras.layers.LayerNormalization(
      axis=-1, epsilon=1e-12, name='word_emb_layer_norm')
  word_embeddings = word_emb_layer_norm(word_embeddings_unnormalized)

  ngram_logit_list = []
  for i in range(flags.kp_max_length):
    conv = tf.keras.layers.Conv1D(
        filters=model_config.hidden_size,
        kernel_size=i + 1,
        padding='valid',
        activation=tensor_utils.get_activation('gelu'),
        kernel_initializer=tf.keras.initializers.TruncatedNormal(
            stddev=0.02 / math.sqrt(i + 1)),
        name=f'{i + 1}gram_conv')
    layer_norm = tf.keras.layers.LayerNormalization(
        axis=-1, epsilon=1e-12, name=f'{i + 1}gram_layer_norm')

    logit_dense = tf.keras.layers.Dense(
        units=1,
        activation=None,
        use_bias=False,
        kernel_initializer=tf.keras.initializers.TruncatedNormal(stddev=0.02),
        name=f'logit_dense{i}')
    # [batch_size, long_max_length - i]
    unpadded_logits = tf.squeeze(
        logit_dense(layer_norm(conv(word_embeddings))), axis=-1)

    # Pad to the right to get back to `long_max_length`.
    padded_logits = tf.pad(unpadded_logits, paddings=[[0, 0], [0, i]])

    # Padding logits should be ignored, so we make a large negative mask adder
    # for them.
    shifted_word_mask = tf.cast(
        tensor_utils.shift_elements_right(
            features['long_word_input_mask'], axis=-1, amount=-i),
        dtype=padded_logits.dtype)
    mask_adder = -10000.0 * (1.0 - shifted_word_mask)

    ngram_logit_list.append(padded_logits * shifted_word_mask + mask_adder)

  # [batch_size, kp_max_length, long_max_length]
  ngram_logits = tf.stack(ngram_logit_list, axis=1)

  extra_model_losses = model.losses

  return ngram_logits, extra_model_losses
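
The `shift_elements_right(..., amount=-i)` call above left-shifts the word mask so that a position stays valid only if a whole (i + 1)-gram starting there fits inside the real words. A tiny worked example with a hypothetical mask:

word_mask = tf.constant([1, 1, 1, 0, 0])  # three real words, two pad slots
shifted = tensor_utils.shift_elements_right(word_mask, axis=-1, amount=-1)
# shifted == [1, 1, 0, 0, 0]: for bigrams (i = 1), only positions 0 and 1
# can start a bigram made entirely of real words; the logits at the
# remaining positions receive the -10000.0 mask adder.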
    def test_shift_elements_right_3d(self):
        tensor = tf.constant([
            [
                [1, -1],  #
                [2, -2],  #
                [3, -3],  #
            ],  #
            [
                [4, -4],  #
                [5, -5],  #
                [6, -6],  #
            ],  #
        ])

        self.assertAllEqual(
            [
                [
                    [-1, 0],  #
                    [-2, 0],  #
                    [-3, 0],  #
                ],  #
                [
                    [-4, 0],  #
                    [-5, 0],  #
                    [-6, 0],  #
                ],  #
            ],
            tensor_utils.shift_elements_right(tensor, amount=-1))

        self.assertAllEqual(
            [
                [
                    [-1, -1],  #
                    [-1, -1],  #
                    [-1, -1],  #
                ],  #
                [
                    [1, -1],  #
                    [2, -2],  #
                    [3, -3],  #
                ],  #
            ],
            tensor_utils.shift_elements_right(tensor, axis=0, pad_value=-1))

        self.assertAllEqual(
            [
                [
                    [0, 0],  #
                    [0, 0],  #
                    [1, -1],  #
                ],  #
                [
                    [0, 0],  #
                    [0, 0],  #
                    [4, -4],  #
                ],  #
            ],
            tensor_utils.shift_elements_right(tensor, axis=1, amount=2))

        self.assertAllEqual(
            tensor, tensor_utils.shift_elements_right(tensor, axis=2,
                                                      amount=0))
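
The assertions above rely on `tf.test.TestCase` helpers such as `assertAllEqual`, so these methods presumably live in such a subclass. A minimal harness to run them (the `tensor_utils` import path is an assumption):

import tensorflow as tf

from etcmodel import tensor_utils  # assumed import path


class TensorUtilsTest(tf.test.TestCase):
  # ... the three test_shift_elements_right_* methods shown above ...
  pass


if __name__ == '__main__':
  tf.test.main()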