def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Create 3D attention mask from a 2D tensor mask.
  
    Args:
      from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
      to_mask: int32 Tensor of shape [batch_size, to_seq_length].
  
    Returns:
      float Tensor of shape [batch_size, from_seq_length, to_seq_length].
    """
    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
                      dtype=from_tensor.dtype)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(shape=[batch_size, from_seq_length, 1],
                             dtype=from_tensor.dtype)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask
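
# Usage sketch (not from the original source; assumes `import tensorflow as tf` and the
# `tf_utils` helper used above): build the [batch, from_seq, to_seq] attention mask from a
# padding mask. The token ids and mask values below are hypothetical.
token_ids = tf.constant([[5, 7, 9, 0], [3, 0, 0, 0]], dtype=tf.int32)
padding_mask = tf.constant([[1, 1, 1, 0], [1, 0, 0, 0]], dtype=tf.int32)
attention_mask = create_attention_mask_from_input_mask(token_ids, padding_mask)
# attention_mask has shape [2, 4, 4]; every query row repeats the key padding mask.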
Example #2
def scatter_update(sequence, updates, positions):
    """Scatter-update a sequence.

    Args:
      sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth] tensor
      updates: A tensor of size batch_size*seq_len(*depth)
      positions: A [batch_size, n_positions] tensor

    Returns:
      updated_sequence: A [batch_size, seq_len] or [batch_size, seq_len, depth]
        tensor of "sequence" with elements at "positions" replaced by the values
        at "updates". Updates to index 0 are ignored. If there are duplicated
        positions the update is only applied once.
      updates_mask: A [batch_size, seq_len] mask tensor of which inputs were
        updated.
    """
    shape = tf_utils.get_shape_list(sequence, expected_rank=[2, 3])
    depth_dimension = (len(shape) == 3)
    if depth_dimension:
        batch_size, seq_len, depth = shape
    else:
        batch_size, seq_len = shape
        depth = 1
        sequence = tf.expand_dims(sequence, -1)
    n_positions = tf_utils.get_shape_list(positions)[1]

    shift = tf.expand_dims(seq_len * tf.range(batch_size), -1)
    flat_positions = tf.reshape(positions + shift, [-1, 1])
    flat_updates = tf.reshape(updates, [-1, depth])
    updates = tf.scatter_nd(flat_positions, flat_updates,
                            [batch_size * seq_len, depth])
    updates = tf.reshape(updates, [batch_size, seq_len, depth])

    flat_updates_mask = tf.ones([batch_size * n_positions], tf.int32)
    updates_mask = tf.scatter_nd(flat_positions, flat_updates_mask,
                                 [batch_size * seq_len])
    updates_mask = tf.reshape(updates_mask, [batch_size, seq_len])
    not_first_token = tf.concat([
        tf.zeros((batch_size, 1), tf.int32),
        tf.ones((batch_size, seq_len - 1), tf.int32)
    ], -1)
    updates_mask *= not_first_token
    updates_mask_3d = tf.expand_dims(updates_mask, -1)

    # account for duplicate positions
    if sequence.dtype == tf.float32:
        updates_mask_3d = tf.cast(updates_mask_3d, tf.float32)
        updates /= tf.maximum(1.0, updates_mask_3d)
    else:
        assert sequence.dtype == tf.int32
        updates = tf.math.floordiv(updates, tf.maximum(1, updates_mask_3d))
    updates_mask = tf.minimum(updates_mask, 1)
    updates_mask_3d = tf.minimum(updates_mask_3d, 1)

    updated_sequence = (((1 - updates_mask_3d) * sequence) +
                        (updates_mask_3d * updates))
    if not depth_dimension:
        updated_sequence = tf.squeeze(updated_sequence, -1)

    return updated_sequence, updates_mask
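
# Usage sketch with hypothetical values (not from the original source; assumes
# `import tensorflow as tf` and the `tf_utils` helper used above): overwrite the tokens at
# `positions` with `updates`; duplicated positions are applied only once.
sequence = tf.constant([[10, 11, 12, 13], [20, 21, 22, 23]], dtype=tf.int32)
positions = tf.constant([[1, 3], [2, 2]], dtype=tf.int32)  # note the duplicate in row 1
updates = tf.constant([[101, 103], [202, 202]], dtype=tf.int32)
updated, mask = scatter_update(sequence, updates, positions)
# updated == [[10, 101, 12, 103], [20, 21, 202, 23]]
# mask    == [[ 0,   1,  0,   1], [ 0,  0,   1,  0]]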
Example #3
 def call(self, inputs):
     sources = inputs["inputs"]
     targets = inputs["targets"]
     pos_embed = inputs["pos_embed"]
     mask = inputs["mask"]
     input_shape = tf_utils.get_shape_list(sources)
     source_attention_mask = tf.tile(tf.expand_dims(mask, axis=1),
                                     [1, input_shape[1], 1])
     memory = self._encoder(sources,
                            attention_mask=source_attention_mask,
                            pos_embed=pos_embed)
     target_shape = tf_utils.get_shape_list(targets)
     cross_attention_mask = tf.tile(tf.expand_dims(mask, axis=1),
                                    [1, target_shape[1], 1])
     target_shape = tf.shape(targets)
     decoded = self._decoder(
         tf.zeros_like(targets),
         memory,
         # TODO(b/199545430): self_attention_mask could be set to None when this
         # bug is resolved. Passing ones for now.
         self_attention_mask=tf.ones(
             (target_shape[0], target_shape[1], target_shape[1])),
         cross_attention_mask=cross_attention_mask,
         return_all_decoder_outputs=True,
         input_pos_embed=targets,
         memory_pos_embed=pos_embed)
     return decoded
Example #4
  def call(self, inputs):
    from_tensor = inputs[0]
    to_mask = inputs[1]
    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size = from_shape[0]
    from_seq_length = from_shape[1]

    to_shape = tf_utils.get_shape_list(to_mask, expected_rank=2)
    to_seq_length = to_shape[1]

    to_mask = tf.cast(
        tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
        dtype=from_tensor.dtype)

    # We don't assume that `from_tensor` is a mask (although it could be). We
    # don't actually care if we attend *from* padding tokens (only *to* padding
    # tokens), so we create a tensor of all ones.
    #
    # `broadcast_ones` = [batch_size, from_seq_length, 1]
    broadcast_ones = tf.ones(
        shape=[batch_size, from_seq_length, 1], dtype=from_tensor.dtype)

    # Here we broadcast along two dimensions to create the mask.
    mask = broadcast_ones * to_mask

    return mask
    def _chunk(hidden_states, window_overlap):
        """convert into overlapping chunks. Chunk size = 2w, overlap size = w."""
        batch_size, seq_length, hidden_dim = get_shape_list(hidden_states)
        num_output_chunks = 2 * (seq_length // (2 * window_overlap)) - 1

        # define frame size and frame stride (similar to convolution)
        frame_hop_size = window_overlap * hidden_dim
        frame_size = 2 * frame_hop_size
        hidden_states = tf.reshape(hidden_states,
                                   (batch_size, seq_length * hidden_dim))

        # chunk with overlap
        chunked_hidden_states = tf.signal.frame(hidden_states, frame_size,
                                                frame_hop_size)

        if tf.executing_eagerly():
            tf.debugging.assert_equal(
                get_shape_list(chunked_hidden_states),
                [batch_size, num_output_chunks, frame_size],
                message=(
                    "Make sure chunking is correctly applied. Chunked hidden "
                    "states should have output dimension "
                    f"{[batch_size, num_output_chunks, frame_size]}, but got "
                    f"{get_shape_list(chunked_hidden_states)}."),
            )

        chunked_hidden_states = tf.reshape(
            chunked_hidden_states,
            (batch_size, num_output_chunks, 2 * window_overlap, hidden_dim),
        )

        return chunked_hidden_states
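
# Standalone shape sketch (not from the original source; assumes `import tensorflow as tf`):
# the framing step above, repeated on a hypothetical [1, 8, 3] tensor with window_overlap
# w = 2, yields 2 * (8 // (2 * w)) - 1 = 3 overlapping chunks of length 2 * w = 4.
hidden_states = tf.random.normal([1, 8, 3])
frames = tf.signal.frame(tf.reshape(hidden_states, (1, 8 * 3)),
                         frame_length=2 * 2 * 3, frame_step=2 * 3)
chunks = tf.reshape(frames, (1, 3, 4, 3))  # [batch, num_output_chunks, 2 * w, hidden_dim]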
Example #6
    def call(self, query: tf.Tensor, key: tf.Tensor):
        """Implements the forward pass.

        Args:
          query: query input tensor shape [batch, query length, hidden size].
          key: key input tensor shape [batch, key length, hidden size].

        Returns:
          A tensor in shape of [batch, heads, query length, key length].
        """
        batch_size, qlen = tf_utils.get_shape_list(query)[:2]
        klen = tf_utils.get_shape_list(key)[1]
        context_position = tf.range(qlen)[:, None]
        memory_position = tf.range(klen)[None, :]
        relative_position = memory_position - context_position
        rp_bucket = _relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance)
        values = tf.nn.embedding_lookup(self._relative_attention_bias,
                                        rp_bucket)
        values = tf.expand_dims(tf.transpose(values, [2, 0, 1]),
                                axis=0)  # shape (1, num_heads, qlen, klen)
        values = tf.tile(values, [batch_size, 1, 1, 1])
        return values
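
# Standalone sketch (not from the original source; assumes `import tensorflow as tf`) of
# the relative-position matrix built above: entry [i, j] is j - i, i.e. how far key
# position j sits from query position i, which is then bucketed by
# _relative_position_bucket.
qlen, klen = 3, 4
context_position = tf.range(qlen)[:, None]
memory_position = tf.range(klen)[None, :]
relative_position = memory_position - context_position
# relative_position == [[0, 1, 2, 3], [-1, 0, 1, 2], [-2, -1, 0, 1]]
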
    def _mask_invalid_locations(input_tensor, window_overlap):
        # create correct upper triangle bool mask
        mask_2d_upper = tf.reverse(
            tf.linalg.band_part(
                tf.ones(shape=(window_overlap, window_overlap + 1)), -1, 0),
            axis=[0],
        )

        # pad to full matrix
        padding = tf.convert_to_tensor(
            [[0, get_shape_list(input_tensor)[1] - window_overlap],
             [0, get_shape_list(input_tensor)[3] - window_overlap - 1]])

        # create lower mask
        mask_2d = tf.pad(mask_2d_upper, padding)

        # combine with upper mask
        mask_2d = mask_2d + tf.reverse(mask_2d, axis=[0, 1])

        # broadcast to full matrix
        mask_4d = tf.tile(mask_2d[None, :, None, :],
                          (get_shape_list(input_tensor)[0], 1, 1, 1))

        # inf tensor used for masking
        inf_tensor = -float("inf") * tf.ones_like(input_tensor)

        # mask
        input_tensor = tf.where(tf.math.greater(mask_4d, 0), inf_tensor,
                                input_tensor)

        return input_tensor
    def _concat_with_global_key_attn_probs(
        self,
        attn_scores,
        key_vectors,
        query_vectors,
        max_num_global_attn_indices,
        is_index_global_attn_nonzero,
        is_local_index_global_attn_nonzero,
        is_local_index_no_global_attn_nonzero,
    ):
        batch_size = get_shape_list(key_vectors)[0]

        # select global key vectors
        global_key_vectors = tf.gather_nd(key_vectors,
                                          is_index_global_attn_nonzero)

        # create only global key vectors
        key_vectors_only_global = tf.scatter_nd(
            is_local_index_global_attn_nonzero,
            global_key_vectors,
            shape=(
                batch_size,
                max_num_global_attn_indices,
                self._num_heads,
                self._key_dim,
            ),
        )

        # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
        attn_probs_from_global_key = tf.einsum("blhd,bshd->blhs",
                                               query_vectors,
                                               key_vectors_only_global)

        # (batch_size, max_num_global_attn_indices, seq_len, num_heads)
        attn_probs_from_global_key_trans = tf.transpose(
            attn_probs_from_global_key, (0, 3, 1, 2))
        mask_shape = (get_shape_list(
            is_local_index_no_global_attn_nonzero)[0], ) + tuple(
                get_shape_list(attn_probs_from_global_key_trans)[-2:])
        mask = tf.ones(mask_shape) * -10000.0
        mask = tf.cast(mask, dtype=attn_probs_from_global_key_trans.dtype)

        # scatter mask
        attn_probs_from_global_key_trans = tf.tensor_scatter_nd_update(
            attn_probs_from_global_key_trans,
            is_local_index_no_global_attn_nonzero,
            mask,
        )

        # (batch_size, seq_len, num_heads, max_num_global_attn_indices)
        attn_probs_from_global_key = tf.transpose(
            attn_probs_from_global_key_trans, (0, 2, 3, 1))

        # concat to attn_probs
        # (batch_size, seq_len, num_heads, extra attention count + 2*window+1)
        attn_scores = tf.concat((attn_probs_from_global_key, attn_scores),
                                axis=-1)
        return attn_scores
    def _pad_to_window_size(
        self,
        word_ids,
        mask,
        type_ids,
        word_embeddings,
        pad_token_id,
    ):
        # padding
        attention_window = max(self._attention_window)

        assert (attention_window %
                2 == 0), ('`attention_window` should be an even value. '
                          f'Given {attention_window}')

        input_shape = get_shape_list(
            word_ids) if word_ids is not None else get_shape_list(
                word_embeddings)
        batch_size, seq_len = input_shape[:2]

        if seq_len is not None:
            padding_len = (attention_window -
                           seq_len % attention_window) % attention_window
        else:
            padding_len = 0

        paddings = tf.convert_to_tensor([[0, 0], [0, padding_len]])

        if word_ids is not None:
            word_ids = tf.pad(word_ids, paddings, constant_values=pad_token_id)

        if word_embeddings is not None:

            def pad_embeddings():
                word_ids_padding = tf.fill((batch_size, padding_len),
                                           pad_token_id)
                word_embeddings_padding = self._embedding_layer(
                    word_ids_padding)
                return tf.concat([word_embeddings, word_embeddings_padding],
                                 axis=-2)

            word_embeddings = tf.cond(tf.math.greater(padding_len, 0),
                                      pad_embeddings, lambda: word_embeddings)

        mask = tf.pad(
            mask, paddings,
            constant_values=False)  # no attention on the padding tokens
        token_type_ids = tf.pad(
            type_ids, paddings,
            constant_values=0)  # pad with token_type_id = 0

        return (
            padding_len,
            word_ids,
            mask,
            token_type_ids,
            word_embeddings,
        )
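
# Worked example (hypothetical numbers, not from the original source) of the padding
# arithmetic above: the sequence length is padded up to the next multiple of the
# attention window.
attention_window = 4
seq_len = 10
padding_len = (attention_window - seq_len % attention_window) % attention_window
# padding_len == 2, so a length-10 batch is padded to 12 == 3 * attention_window.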
Example #10
    def _gather_indexes(self, sequence_tensor, positions):
        """Gathers the vectors at the specific positions.

        Args:
          sequence_tensor: Sequence output of shape
            (`batch_size`, `seq_length`, `num_hidden`) where `num_hidden` is the
            number of hidden units.
          positions: Position ids of tokens in batched sequences.

        Returns:
          Sequence tensor of shape (batch_size * num_predictions,
          num_hidden).
        """
        sequence_shape = tf_utils.get_shape_list(sequence_tensor,
                                                 name='sequence_output_tensor')
        batch_size, seq_length, width = sequence_shape

        flat_offsets = tf.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        flat_positions = tf.reshape(positions + flat_offsets, [-1])
        flat_sequence_tensor = tf.reshape(sequence_tensor,
                                          [batch_size * seq_length, width])
        output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

        return output_tensor
Example #11
    def call(self, input_embeddings: tf.Tensor,
             input_mask: tf.Tensor) -> Dict[str, tf.Tensor]:
        batch_size, seq_len, embedding_dim = tf_utils.get_shape_list(
            input_embeddings, expected_rank=3)
        example_ids = None
        reduced_batch_size = batch_size // self.pack_sequences
        packed_seq_len = self.pack_sequences * seq_len
        packed_embeddings = tf.reshape(
            input_embeddings,
            [reduced_batch_size, packed_seq_len, embedding_dim])
        input_mask = tf.reshape(input_mask,
                                [reduced_batch_size, packed_seq_len])
        example_ids = 1 + tf.range(self.pack_sequences)
        # Shape: [reduced_batch_size, pack_sequences, seq_len].
        example_ids = tf.tile(example_ids[None, :, None],
                              [reduced_batch_size, 1, seq_len])
        example_ids = tf.reshape(example_ids,
                                 [reduced_batch_size, packed_seq_len])
        example_ids = tf.where(tf.math.equal(input_mask, 0),
                               tf.zeros_like(example_ids), example_ids)
        packing_mask = _packing_mask(example_ids, example_ids, dtype=tf.bool)

        attention_mask = self_attention_mask.get_mask(packed_embeddings,
                                                      input_mask,
                                                      dtype=tf.bool)

        combined_attention_mask = tf.cast(
            tf.math.logical_and(attention_mask, packing_mask), tf.float32)

        return dict(packed_embeddings=packed_embeddings,
                    combined_attention_mask=combined_attention_mask)
Example #12
    def build_losses(self,
                     labels,
                     model_outputs,
                     aux_losses=None) -> tf.Tensor:
        """Interface to compute losses. Refer to base_task.Task.build_losses."""
        del labels

        left_logits = model_outputs['left_logits']
        right_logits = model_outputs['right_logits']

        batch_size = tf_utils.get_shape_list(left_logits, name='batch_size')[0]

        ranking_labels = tf.range(batch_size)

        loss = tf_utils.safe_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=ranking_labels, logits=left_logits))

        if self.task_config.model.bidirectional:
            right_rank_loss = tf_utils.safe_mean(
                tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=ranking_labels, logits=right_logits))

            loss += right_rank_loss
        return tf.reduce_mean(loss)
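
# Minimal sketch (hypothetical logits, not from the original source; assumes
# `import tensorflow as tf`) of the in-batch ranking objective above: row i of the
# similarity logits should score its own pair (column i) highest, so the labels are
# simply tf.range(batch_size).
logits = tf.constant([[5.0, 0.1, 0.2],
                      [0.3, 4.0, 0.1],
                      [0.2, 0.0, 6.0]])
labels = tf.range(3)
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits))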
Example #13
 def call(self, inputs):
   """Implements call() for the layer."""
   length = self._length
   if inputs is None and length is None:
     raise ValueError(
         "If inputs is None, `length` must be set in "
         "RelativePositionEmbedding().")
   if inputs is not None:
     input_shape = tf_utils.get_shape_list(inputs)
     if length is not None and length != input_shape[1]:
       raise ValueError(
           "If inputs is not None, `length` must equal to input_shape[1]."
       )
     length = input_shape[1]
   position = tf.cast(tf.range(length), tf.float32)
   num_timescales = self._hidden_size // 2
   min_timescale, max_timescale = self._min_timescale, self._max_timescale
   log_timescale_increment = (
       math.log(float(max_timescale) / float(min_timescale)) /
       (tf.cast(num_timescales, tf.float32) - 1))
   inv_timescales = min_timescale * tf.exp(
       tf.cast(tf.range(num_timescales), tf.float32) *
       -log_timescale_increment)
   scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales,
                                                              0)
   position_embeddings = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)],
                                   axis=1)
   return position_embeddings
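
# Standalone sketch (hypothetical length and hidden size, not from the original source) of
# the sinusoidal position encoding computed above, with the class attributes replaced by
# local constants.
import math

import tensorflow as tf

length, hidden_size = 4, 8
min_timescale, max_timescale = 1.0, 1.0e4
position = tf.cast(tf.range(length), tf.float32)
num_timescales = hidden_size // 2
log_timescale_increment = (
    math.log(max_timescale / min_timescale) / (num_timescales - 1))
inv_timescales = min_timescale * tf.exp(
    tf.cast(tf.range(num_timescales), tf.float32) * -log_timescale_increment)
scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(inv_timescales, 0)
position_embeddings = tf.concat([tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)
# position_embeddings has shape [length, hidden_size] == [4, 8].
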
 def call(self, inputs):
     """Implements call() for the layer."""
     input_shape = tf_utils.get_shape_list(inputs)
     flat_input = tf.reshape(inputs, [-1])
     output = tf.gather(self.embeddings, flat_input)
     output = tf.reshape(output, input_shape + [self.embedding_size])
     return output
Example #15
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at the specific positions.

    Args:
      sequence_tensor: Sequence output of `BertModel` layer of shape
        (`batch_size`, `seq_length`, num_hidden) where num_hidden is the number
        of hidden units of the `BertModel` layer.
      positions: Position ids of tokens to mask for pretraining, with shape
        (batch_size, max_predictions_per_seq) where `max_predictions_per_seq`
        is the maximum number of tokens to mask out and predict per sequence.

    Returns:
      Sequence tensor gathered at the masked positions, of shape
      (batch_size * max_predictions_per_seq, num_hidden).
    """
    sequence_shape = tf_utils.get_shape_list(sequence_tensor,
                                             name='sequence_output_tensor')
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    flat_offsets = tf.keras.backend.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.keras.backend.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.keras.backend.reshape(
        sequence_tensor, [batch_size * seq_length, width])
    output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

    return output_tensor
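
# Usage sketch with hypothetical values (not from the original source; assumes
# `import tensorflow as tf` and the `tf_utils` helper used above): gather the hidden
# vectors at the masked positions into a [batch_size * max_predictions, num_hidden] tensor.
sequence_output = tf.reshape(tf.range(24, dtype=tf.float32), [2, 4, 3])
masked_positions = tf.constant([[1, 3], [0, 2]], dtype=tf.int32)
gathered = gather_indexes(sequence_output, masked_positions)  # shape [4, 3]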
Example #16
    def _gather_indexes(self, sequence_tensor, positions):
        """Gathers the vectors at the specific positions.
    
        Args:
            sequence_tensor: Sequence output of `BertModel` layer of shape
              (`batch_size`, `seq_length`, num_hidden) where num_hidden is the
              number of hidden units of the `BertModel` layer.
            positions: Position ids of tokens to mask for pretraining, with
              dimension (batch_size, num_predictions) where `num_predictions` is
              the maximum number of tokens to mask out and predict per sequence.

        Returns:
            Sequence tensor gathered at the masked positions, of shape
            (batch_size * num_predictions, num_hidden).
        """
        sequence_shape = tf_utils.get_shape_list(sequence_tensor,
                                                 name='sequence_output_tensor')
        batch_size, seq_length, width = sequence_shape

        # `positions` holds the ids of the masked tokens, shape [batch, num_predictions].
        # Compute the flat indices of those tokens once the batched sequences are flattened.
        flat_offsets = tf.reshape(
            tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
        flat_positions = tf.reshape(positions + flat_offsets, [-1])

        # Flatten the input.
        flat_sequence_tensor = tf.reshape(sequence_tensor,
                                          [batch_size * seq_length, width])
        output_tensor = tf.gather(flat_sequence_tensor, flat_positions)

        # Gather the masked tokens: [batch * num_predictions, width].
        return output_tensor
Example #17
    def call(self, inputs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        word_embeddings = unpacked_inputs[0]
        token_type_ids = unpacked_inputs[1]
        input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = word_embeddings
        if self.use_type_embeddings:
            flat_token_type_ids = tf.reshape(token_type_ids, [-1])
            token_type_embeddings = tf.gather(self.type_embeddings,
                                              flat_token_type_ids)
            token_type_embeddings = tf.reshape(token_type_embeddings,
                                               [batch_size, seq_length, width])
            output += token_type_embeddings

        if self.use_position_embeddings:
            position_embeddings = tf.expand_dims(tf.slice(
                self.position_embeddings, [0, 0], [seq_length, width]),
                                                 axis=0)

            output += position_embeddings

        output = self.output_layer_norm(output)
        output = self.output_dropout(output)

        return output
Example #18
def sample_from_softmax(logits, disallow=None):
    """Implement softmax sampling using gumbel softmax trick.

    Args:
      logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
        the generator output logits for each masked position.
      disallow: If `None`, we directly sample tokens from the logits. Otherwise,
        this is a tensor of size [batch_size, num_token_predictions, vocab_size]
        indicating the true word id in each masked position.

    Returns:
      sampled_tokens: A [batch_size, num_token_predictions, vocab_size] one hot
        tensor indicating the sampled word id in each masked position.
    """
    if disallow is not None:
        logits -= 1000.0 * disallow
    uniform_noise = tf.random.uniform(tf_utils.get_shape_list(logits),
                                      minval=0,
                                      maxval=1)
    gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)

    # Here we essentially follow the original paper and use temperature 1.0 for
    # generator output logits.
    sampled_tokens = tf.one_hot(
        tf.argmax(tf.nn.softmax(logits + gumbel_noise),
                  -1,
                  output_type=tf.int32), logits.shape[-1])
    return sampled_tokens
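
# Usage sketch with hypothetical shapes (not from the original source; assumes
# `import tensorflow as tf` and the `tf_utils` helper used above): draw a one-hot sample
# per masked position via the Gumbel-max trick implemented above.
logits = tf.random.normal([2, 3, 5])       # [batch, num_token_predictions, vocab_size]
sampled = sample_from_softmax(logits)      # one-hot, shape [2, 3, 5]
sampled_ids = tf.argmax(sampled, axis=-1)  # integer ids, shape [2, 3]
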
  def call(self, input_tensor, unpooled_len=0):
    if self.pool_size == 1:
      return input_tensor

    batch_size, seq_len = tf_utils.get_shape_list(input_tensor, expected_rank=2)
    # reshape tensor in order to use tf.nn.pool
    reshaped_tensor = tf.reshape(input_tensor, [batch_size, seq_len, 1])
    if self.nocls:
      tensor_to_pool = reshaped_tensor[:, 1:, :]
    else:
      tensor_to_pool = reshaped_tensor

    if unpooled_len > 0:
      tensor_to_pool = tensor_to_pool[:, :-unpooled_len, :]

    pooled_tensor = tf.nn.max_pool(
        tensor_to_pool,
        ksize=self.pool_size,
        strides=self.pool_size,
        padding='SAME')

    if self.nocls:
      pooled_tensor = tf.concat([reshaped_tensor[:, 0:1, :], pooled_tensor],
                                axis=1)
    if unpooled_len > 0:
      pooled_tensor = tf.concat(
          [pooled_tensor, reshaped_tensor[:, -unpooled_len:, :]], axis=1)

    pooled_tensor = tf.reshape(pooled_tensor, [batch_size, -1])
    return pooled_tensor
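
# Minimal sketch (hypothetical values, not from the original source; assumes
# `import tensorflow as tf`) of the 1-D pooling used above: a [batch, seq_len] tensor is
# reshaped to [batch, seq_len, 1] so tf.nn.max_pool can downsample the sequence axis.
x = tf.constant([[1.0, 4.0, 2.0, 8.0, 5.0, 3.0]])
pooled = tf.nn.max_pool(tf.reshape(x, [1, 6, 1]), ksize=2, strides=2, padding='SAME')
# tf.reshape(pooled, [1, -1]) == [[4., 8., 5.]]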
Example #20
        def symbols_to_logits_fn(ids, i, cache):
            """Generate logits for next potential IDs.

            Args:
              ids: Current decoded sequences. int tensor with shape `(batch_size *
                beam_size, i + 1)`.
              i: Loop index.
              cache: Dictionary of values storing the encoder output, encoder-decoder
                attention bias, and previous decoder attention values.

            Returns:
              Tuple of
                (logits with shape `(batch_size * beam_size, vocab_size)`,
                 updated cache values)
            """
            # Set decoder input to the last generated IDs
            decoder_input = ids[:, -1:]

            # Preprocess decoder input by getting embeddings and adding timing signal.
            # decoder_input = self.embedding_softmax_layer(decoder_input)
            source_decoder_input = decoder_input
            decoder_input = self.embedding_lookup(decoder_input)
            embedding_mask = tf.cast(tf.not_equal(source_decoder_input, 0),
                                     decoder_input.dtype)
            decoder_input *= tf.expand_dims(embedding_mask, -1)
            decoder_input += timing_signal[i]
            if self._padded_decode:
                # indexing does not work on TPU.
                bias_shape = decoder_self_attention_mask.shape.as_list()
                self_attention_mask = tf.slice(
                    decoder_self_attention_mask, [0, i, 0],
                    [bias_shape[0], 1, bias_shape[2]])
            else:
                self_attention_mask = decoder_self_attention_mask[:, i:i +
                                                                  1, :i + 1]
            decoder_shape = tf_utils.get_shape_list(decoder_input,
                                                    expected_rank=3)
            batch_size = decoder_shape[0]
            decoder_length = decoder_shape[1]

            self_attention_mask = tf.tile(self_attention_mask,
                                          [batch_size, 1, 1])
            attention_mask = cache.get("encoder_decoder_attention_mask")
            attention_mask = tf.tile(attention_mask, [1, decoder_length, 1])

            decoder_outputs = self.decoder_layer(
                decoder_input,
                cache.get("encoder_outputs"),
                self_attention_mask=self_attention_mask,
                cross_attention_mask=attention_mask,
                cache=cache,
                decode_loop_step=i if self._padded_decode else None)

            decoder_outputs = tf.cast(decoder_outputs,
                                      dtype=self.compute_dtype)
            logits = self._embedding_linear(self.embedding_lookup.embeddings,
                                            decoder_outputs)
            logits = tf.squeeze(logits, axis=[1])
            return logits, cache
Example #21
def sample_k_from_softmax(logits, k, disallow=None, use_topk=False):
    """Implement softmax sampling using gumbel softmax trick to select k items.

    Args:
      logits: A [batch_size, num_token_predictions, vocab_size] tensor indicating
        the generator output logits for each masked position.
      k: Number of samples.
      disallow: If `None`, we directly sample tokens from the logits. Otherwise,
        this is a tensor of size [batch_size, num_token_predictions, vocab_size]
        indicating the true word id in each masked position.
      use_topk: Whether to use tf.nn.top_k or the iterative approach; the latter
        is empirically faster.

    Returns:
      sampled_tokens: A [batch_size, num_token_predictions, k] tensor indicating
        the sampled word id in each masked position.
    """
    if use_topk:
        if disallow is not None:
            logits -= 10000.0 * disallow
        uniform_noise = tf.random.uniform(tf_utils.get_shape_list(logits),
                                          minval=0,
                                          maxval=1)
        gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
        _, sampled_tokens = tf.nn.top_k(logits + gumbel_noise,
                                        k=k,
                                        sorted=False)
    else:
        sampled_tokens_list = []
        vocab_size = tf_utils.get_shape_list(logits)[-1]
        if disallow is not None:
            logits -= 10000.0 * disallow

        uniform_noise = tf.random.uniform(tf_utils.get_shape_list(logits),
                                          minval=0,
                                          maxval=1)
        gumbel_noise = -tf.math.log(-tf.math.log(uniform_noise + 1e-9) + 1e-9)
        logits += gumbel_noise
        for _ in range(k):
            token_ids = tf.argmax(logits, -1, output_type=tf.int32)
            sampled_tokens_list.append(token_ids)
            logits -= 10000.0 * tf.one_hot(
                token_ids, depth=vocab_size, dtype=tf.float32)
        sampled_tokens = tf.stack(sampled_tokens_list, -1)
    return sampled_tokens
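
# Usage sketch with hypothetical shapes (not from the original source; assumes
# `import tensorflow as tf` and the `tf_utils` helper used above): sample k = 2 distinct
# token ids per masked position via the iterative Gumbel-based procedure above.
logits = tf.random.normal([2, 3, 5])              # [batch, num_token_predictions, vocab]
sampled_ids = sample_k_from_softmax(logits, k=2)  # shape [2, 3, 2]
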
    def call(self, inputs):
        """Implements call() for the layer."""
        input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
        if self._use_dynamic_slicing:
            position_embeddings = self._position_embeddings[:input_shape[1], :]
        else:
            position_embeddings = self._position_embeddings

        return tf.broadcast_to(position_embeddings, input_shape)
Example #23
def remove_sos_from_seq(seq, pad_token_id):
  """Remove the start sequence token while keeping seq length."""
  batch_size, seq_len = tf_utils.get_shape_list(seq, expected_rank=2)
  # remove <s>
  targets = seq[:, 1:]
  # pad
  pad_ids = tf.ones([batch_size], tf.int32) * pad_token_id
  targets = tf.concat([targets, tf.expand_dims(pad_ids, axis=1)], axis=1)
  tf.assert_equal(tf.shape(targets), (batch_size, seq_len))
  return targets
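
# Usage sketch with hypothetical ids (not from the original source; assumes
# `import tensorflow as tf` and the `tf_utils` helper used above): drop the leading <s>
# token and right-pad so the sequence length stays the same.
seq = tf.constant([[1, 7, 8, 9], [1, 5, 6, 0]], dtype=tf.int32)  # 1 = <s>, 0 = <pad>
targets = remove_sos_from_seq(seq, pad_token_id=0)
# targets == [[7, 8, 9, 0], [5, 6, 0, 0]]
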
    def call(self, inputs):
        """Implements call() for the layer."""
        input_shape = tf_utils.get_shape_list(inputs)

        # Flatten the [batch, seq_len] input to make the lookup easier.
        flat_input = tf.reshape(inputs, [-1])
        output = tf.gather(self.embeddings, flat_input)

        # Reshape back to the batched layout.
        output = tf.reshape(output, input_shape + [self.embedding_size])
        return output
Example #25
    def call(self, inputs, length=None):
        """Implements call() for the layer.
    
        Args:
          inputs: A tensor whose second dimension will be used as `length`. If
            `None`, the other `length` argument must be specified.
          length: An optional integer specifying the number of positions. If both
            `inputs` and `length` are specified, `length` must be equal to the
            second dimension of `inputs`.
    
        Returns:
          A tensor in shape of [length, hidden_size].
        """
        if inputs is None and length is None:
            raise ValueError("If inputs is None, `length` must be set in "
                             "RelativePositionEmbedding().")
        if inputs is not None:
            input_shape = tf_utils.get_shape_list(inputs)
            if length is not None and length != input_shape[1]:
                raise ValueError(
                    "If inputs is not None, `length` must equal to input_shape[1]."
                )
            length = input_shape[1]

        # positions 0 .. length - 1, e.g. range(10) for length = 10
        position = tf.cast(tf.range(length), tf.float32)

        # e.g. : 8 // 2
        num_timescales = self._hidden_size // 2

        # 1.0, 1.0e4
        min_timescale, max_timescale = self._min_timescale, self._max_timescale

        # log(1.e4 / 1.) / (4 - 1) = 3.07
        log_timescale_increment = (
            math.log(float(max_timescale) / float(min_timescale)) /
            (tf.cast(num_timescales, tf.float32) - 1))

        # 1.0 * exp( [0.0, 1.0, 2.0, 3.0 ] * -3.07 )
        inv_timescales = min_timescale * tf.exp(
            tf.cast(tf.range(num_timescales), tf.float32) *
            -log_timescale_increment)

        # (length,1) * (1,num_timescale)
        scaled_time = tf.expand_dims(position, 1) * tf.expand_dims(
            inv_timescales, 0)
        # Apply sin and cos separately, then concatenate into a vector of size hidden_size.
        position_embeddings = tf.concat(
            [tf.sin(scaled_time), tf.cos(scaled_time)], axis=1)

        # r = log(max_timescale / min_timescale) / (hidden_size / 2 - 1)
        # a = [[0], [1], ..., [length - 1]] * exp(-r * [0, 1, ..., hidden_size / 2 - 1])
        # o = concat(sin(a), cos(a))  -->  [length, hidden_size]
        return position_embeddings
    def _pad_and_transpose_last_two_dims(hidden_states_padded, paddings):
        """Pads rows and then flips rows and columns."""
        hidden_states_padded = tf.pad(
            hidden_states_padded, paddings
        )  # padding value is not important because it will be overwritten
        batch_size, chunk_size, seq_length, hidden_dim = get_shape_list(
            hidden_states_padded)
        hidden_states_padded = tf.reshape(
            hidden_states_padded,
            (batch_size, chunk_size, hidden_dim, seq_length))

        return hidden_states_padded
  def call(self, input_positions):
    """Implements call() for the layer."""
    batch_size, seq_len = tf_utils.get_shape_list(
        input_positions, expected_rank=2)
    flat_positions = tf.reshape(input_positions, [-1])
    position_embeddings = tf.gather(self._position_embeddings, flat_positions)
    position_embeddings = tf.reshape(position_embeddings,
                                     [batch_size, seq_len, self.embed_dim])

    if self._use_dynamic_slicing:
      position_embeddings = position_embeddings[:, :seq_len, :]

    return position_embeddings
Example #28
    def call(self, target_embedding):
        lm_data = self.dense(target_embedding)
        lm_data = self.layer_norm(lm_data)
        lm_data = tf.matmul(lm_data, self.embedding_table, transpose_b=True)
        logits = tf.nn.bias_add(lm_data, self.bias)

        masked_positions_shape = tf_utils.get_shape_list(
            target_embedding, name='masked_positions_tensor')
        logits = tf.reshape(logits,
                            [-1, masked_positions_shape[1], self._vocab_size])
        if self._output_type == 'logits':
            return logits
        return tf.nn.log_softmax(logits)
    def _parse_inputs(self, inputs):
        """Parses the `call` inputs and returns an uniformed output."""
        sources = inputs.get("inputs", None)
        input_mask = inputs.get("input_masks", None)
        embedded = inputs.get("embedded_inputs", None)

        if sources is None and embedded is not None:
            embedded_inputs = embedded
            boolean_mask = input_mask
            input_shape = tf_utils.get_shape_list(embedded, expected_rank=3)
            source_dtype = embedded.dtype
        elif sources is not None:
            embedded_inputs = self.embedding_lookup(sources)
            boolean_mask = tf.not_equal(sources, 0)
            input_shape = tf_utils.get_shape_list(sources, expected_rank=2)
            source_dtype = sources.dtype
        else:
            raise KeyError(
                "The call method expects either `inputs` or `embedded_inputs` and "
                "`input_masks` as input features.")

        return embedded_inputs, boolean_mask, input_shape, source_dtype
  def call(self, inputs):
    """Implements call() for the layer."""
    if self._use_dynamic_slicing:
      input_shape = tf_utils.get_shape_list(inputs, expected_rank=3)
      seq_length = input_shape[1]
      width = input_shape[2]

      position_embeddings = tf.expand_dims(
          tf.slice(self._position_embeddings, [0, 0], [seq_length, width]),
          axis=0)
    else:
      position_embeddings = tf.expand_dims(self._position_embeddings, axis=0)

    return position_embeddings