Example #1
0
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Expand a 2D `to` mask into a 3D attention mask.

    Args:
      from_tensor: Tensor of expected rank 2 or 3 whose leading dimensions
        are [batch_size, from_seq_length, ...]. Only its shape and dtype
        are used.
      to_mask: Tensor of expected rank 2, shape [batch_size, to_seq_length].

    Returns:
      Tensor of shape [batch_size, from_seq_length, to_seq_length] with the
      dtype of `from_tensor`.
    """
    from_dims = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size, from_seq_length = from_dims[0], from_dims[1]
    to_seq_length = tf_utils.get_shape_list(to_mask, expected_rank=2)[1]

    # [batch_size, 1, to_seq_length], cast so it can be multiplied below.
    to_mask_3d = tf.reshape(to_mask, [batch_size, 1, to_seq_length])
    to_mask_3d = tf.cast(to_mask_3d, dtype=from_tensor.dtype)

    # The values of `from_tensor` are never inspected: attending *from* a
    # padding position is harmless, only attending *to* one must be masked.
    # A column of ones, [batch_size, from_seq_length, 1], therefore stands
    # in for the `from` side.
    ones_column = tf.ones(shape=[batch_size, from_seq_length, 1],
                          dtype=from_tensor.dtype)

    # Broadcasting the two singleton axes yields the full 3D mask.
    return ones_column * to_mask_3d
Example #2
0
    def call(self, inputs, **kwargs):
        """Add type and position embeddings, then normalize and apply dropout.

        Args:
          inputs: Packed inputs; after unpacking, element 0 is the word
            embeddings (expected rank 3) and element 1 the token type ids.
          **kwargs: Only `training` is read (default False); it is forwarded
            to the dropout layer.

        Returns:
          The result of the layer norm followed by dropout.
        """
        unpacked = tf_utils.unpack_inputs(inputs)
        word_embeddings, token_type_ids = unpacked[0], unpacked[1]
        dims = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size, seq_length, width = dims[0], dims[1], dims[2]

        output = word_embeddings

        if self.use_type_embeddings:
            # Type embeddings are looked up with a one-hot matmul over the
            # flattened ids, then restored to the input shape.
            type_one_hot = tf.one_hot(tf.reshape(token_type_ids, [-1]),
                                      depth=self.token_type_vocab_size,
                                      dtype=self.dtype)
            type_vectors = tf.matmul(type_one_hot, self.type_embeddings)
            output += tf.reshape(type_vectors,
                                 [batch_size, seq_length, width])

        if self.use_position_embeddings:
            # Take the first `seq_length` rows of the position table and add
            # a leading axis so it broadcasts across the batch.
            position_rows = tf.slice(self.position_embeddings, [0, 0],
                                     [seq_length, width])
            output += tf.expand_dims(position_rows, axis=0)

        normalized = self.output_layer_norm(output)
        return self.output_dropout(normalized,
                                   training=kwargs.get('training', False))
Example #3
0
 def call(self, inputs):
     """Gather one row of `self.embeddings` for every id in `inputs`.

     The ids are flattened for the gather and the result is reshaped to the
     input shape with `self.embedding_size` appended as the last dimension.
     """
     ids_shape = tf_utils.get_shape_list(inputs)
     flat_ids = tf.reshape(inputs, [-1])
     gathered = tf.gather(self.embeddings, flat_ids)
     return tf.reshape(gathered, ids_shape + [self.embedding_size])
Example #4
0
def gather_indexes(sequence_tensor, positions):
    """Select the vectors found at the given per-sequence positions.

    Args:
      sequence_tensor: Tensor of shape (batch_size, seq_length, num_hidden),
        e.g. the sequence output of a `BertModel` layer.
      positions: Integer position ids of shape
        (batch_size, max_predictions_per_seq); each id indexes into the
        `seq_length` axis of the matching batch row.

    Returns:
      Tensor of shape (batch_size * max_predictions_per_seq, num_hidden)
      holding the selected vectors.
    """
    dims = tf_utils.get_shape_list(sequence_tensor,
                                   name='sequence_output_tensor')
    batch_size, seq_length, width = dims[0], dims[1], dims[2]

    # Once the batch and sequence axes are merged, batch row `b` starts at
    # index `b * seq_length`; shape the offsets as a column so they
    # broadcast over each row of `positions`.
    row_starts = tf.range(0, batch_size, dtype=tf.int32) * seq_length
    row_starts = tf.keras.backend.reshape(row_starts, [-1, 1])
    flat_indices = tf.keras.backend.reshape(positions + row_starts, [-1])

    flat_vectors = tf.keras.backend.reshape(
        sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flat_vectors, flat_indices)
Example #5
0
    def call(self, inputs):
        """Add type and position embeddings, then normalize and apply dropout.

        Args:
          inputs: Packed inputs; after unpacking, element 0 is the word
            embeddings (expected rank 3) and element 1 the token type ids.

        Returns:
          The result of the layer norm followed by dropout.
        """
        unpacked = tf_utils.unpack_inputs(inputs)
        word_embeddings, token_type_ids = unpacked[0], unpacked[1]
        dims = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size, seq_length, width = dims[0], dims[1], dims[2]

        output = word_embeddings

        if self.use_type_embeddings:
            # Gather one type-embedding row per flattened id, then restore
            # the input shape before adding.
            type_vectors = tf.gather(self.type_embeddings,
                                     tf.reshape(token_type_ids, [-1]))
            output += tf.reshape(type_vectors,
                                 [batch_size, seq_length, width])

        if self.use_position_embeddings:
            # Take the first `seq_length` rows of the position table and add
            # a leading axis so it broadcasts across the batch.
            position_rows = tf.slice(self.position_embeddings, [0, 0],
                                     [seq_length, width])
            output += tf.expand_dims(position_rows, axis=0)

        normalized = self.output_layer_norm(output)
        return self.output_dropout(normalized)
Example #6
0
 def call(self, inputs):
     """Gather embeddings for `inputs` and project them to the hidden size.

     The ids are flattened, the matching rows of `self.embeddings` are
     multiplied by `self.project_variable`, and the result is reshaped to
     the input shape with `self.hidden_size` appended as the last dimension.
     """
     ids_shape = tf_utils.get_shape_list(inputs)
     flat_ids = tf.reshape(inputs, [-1])
     gathered = tf.gather(self.embeddings, flat_ids)
     projected = tf.matmul(gathered, self.project_variable)
     return tf.reshape(projected, ids_shape + [self.hidden_size])
    def call(self, inputs):
        """Return the position embeddings with a leading axis of size 1.

        When dynamic slicing is enabled, the table is cut down to the
        sequence length and width read from `inputs` (expected rank 3);
        otherwise the whole table is returned and `inputs` is not used.
        """
        if not self._use_dynamic_slicing:
            return tf.expand_dims(self._position_embeddings, axis=0)

        dims = tf_utils.get_shape_list(inputs, expected_rank=3)
        seq_length, width = dims[1], dims[2]
        sliced = tf.slice(self._position_embeddings, [0, 0],
                          [seq_length, width])
        return tf.expand_dims(sliced, axis=0)
Example #8
0
    def call(self, inputs):
        """Map a rank-2 tensor of ids to their embedding vectors.

        The lookup is done either with a one-hot matmul or with a gather,
        depending on `self._use_one_hot`. The result has the input shape
        with `self._embedding_width` appended as the last dimension.
        """
        output_shape = tf_utils.get_shape_list(inputs, expected_rank=2)
        output_shape.append(self._embedding_width)

        flat_ids = tf.reshape(inputs, [-1])
        if not self._use_one_hot:
            looked_up = tf.gather(self.embeddings, flat_ids)
        else:
            one_hot_ids = tf.one_hot(flat_ids,
                                     depth=self._vocab_size,
                                     dtype=self._dtype)
            looked_up = tf.matmul(one_hot_ids, self.embeddings)

        return tf.reshape(looked_up, output_shape)
Example #9
0
    def call(self, inputs):
        """Expand a 2D `to` mask into a 3D attention mask.

        Args:
          inputs: A pair `(from_tensor, to_mask)`. `from_tensor` has expected
            rank 2 or 3 with leading dimensions [batch_size, from_seq_length];
            only its shape and dtype are used. `to_mask` has expected rank 2,
            shape [batch_size, to_seq_length].

        Returns:
          Tensor of shape [batch_size, from_seq_length, to_seq_length] with
          the dtype of `from_tensor`.
        """
        from_tensor, to_mask = inputs
        from_dims = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
        batch_size, from_seq_length = from_dims[0], from_dims[1]
        to_seq_length = tf_utils.get_shape_list(to_mask, expected_rank=2)[1]

        # [batch_size, 1, to_seq_length], cast so it can be multiplied below.
        to_mask_3d = tf.reshape(to_mask, [batch_size, 1, to_seq_length])
        to_mask_3d = tf.cast(to_mask_3d, dtype=from_tensor.dtype)

        # The values of `from_tensor` are never inspected: attending *from* a
        # padding position is harmless, only attending *to* one must be
        # masked. A column of ones, [batch_size, from_seq_length, 1],
        # therefore stands in for the `from` side.
        ones_column = tf.ones(shape=[batch_size, from_seq_length, 1],
                              dtype=from_tensor.dtype)

        # Broadcasting the two singleton axes yields the full 3D mask.
        return ones_column * to_mask_3d