Esempio n. 1
0
def create_attention_mask_from_input_mask(from_tensor, to_mask):
    """Builds a 3D attention mask from a 2D `to_mask`.

    Args:
      from_tensor: 2D or 3D Tensor of shape [batch_size, from_seq_length, ...].
        Only its shape and dtype are used; its values are ignored.
      to_mask: int32 Tensor of shape [batch_size, to_seq_length] (presumably
        1 for real tokens and 0 for padding -- confirm against callers).

    Returns:
      Tensor of shape [batch_size, from_seq_length, to_seq_length] whose dtype
      is `from_tensor.dtype` (not necessarily float). Every from-position of
      example b holds a copy of row b of `to_mask`.
    """
    from_shape = tf_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    batch_size, from_seq_length = from_shape[0], from_shape[1]
    to_seq_length = tf_utils.get_shape_list(to_mask, expected_rank=2)[1]
    mask_dtype = from_tensor.dtype

    # Insert a singleton from-axis so the mask broadcasts over from-positions:
    # [batch_size, to_seq_length] -> [batch_size, 1, to_seq_length].
    expanded_to_mask = tf.cast(
        tf.reshape(to_mask, [batch_size, 1, to_seq_length]),
        dtype=mask_dtype)

    # `from_tensor` is not treated as a mask. Attending *from* padding tokens
    # is harmless (only attention *to* padding is masked), so the from-side
    # factor is simply all ones: [batch_size, from_seq_length, 1].
    from_side_ones = tf.ones(
        shape=[batch_size, from_seq_length, 1], dtype=mask_dtype)

    # [B, F, 1] * [B, 1, T] broadcasts to [B, F, T].
    return from_side_ones * expanded_to_mask
Esempio n. 2
0
    def call(self, inputs, **kwargs):
        """Implements call() for the layer.

        Adds token-type and position embeddings to the word embeddings, then
        passes the sum through `self.output_layer_norm` and
        `self.output_dropout`.

        Args:
          inputs: packed inputs. After `tf_utils.unpack_inputs`, elements 0-7
            are, in order: word embeddings (rank 3, [batch_size, seq_length,
            width]), segment ids, column ids, row ids, previous label ids,
            column ranks, inverse column ranks and numeric relations. Each id
            tensor is flattened and its embeddings reshaped to
            [batch_size, seq_length, width], so it must hold
            batch_size * seq_length ids.
          **kwargs: only `training` is read (default False); it is forwarded
            to the output dropout.

        Returns:
          The summed embeddings after layer normalization and dropout.
        """
        unpacked = tf_utils.unpack_inputs(inputs)
        (word_embeddings, segment_ids, column_ids, row_ids, prev_label_ids,
         column_ranks, inv_column_ranks, numeric_relations) = [
             unpacked[k] for k in range(8)]
        batch_size, seq_length, width = tf_utils.get_shape_list(
            word_embeddings, expected_rank=3)

        output = word_embeddings

        # Id tensors and embedding tables are paired by position; the same
        # position k selects the one-hot depth in `token_type_vocab_size`.
        id_tensors = [segment_ids, column_ids, row_ids, prev_label_ids,
                      column_ranks, inv_column_ranks, numeric_relations]
        tables = [self.segment_embeddings, self.column_embeddings,
                  self.row_embeddings, self.prev_label_embeddings,
                  self.column_ranks_embeddings,
                  self.inv_column_ranks_embeddings,
                  self.numeric_relations_embeddings]

        if self.use_type_embeddings:
            for k, (ids, table) in enumerate(zip(id_tensors, tables)):
                flat_ids = tf.reshape(ids, [-1])
                one_hot = tf.one_hot(flat_ids,
                                     depth=self.token_type_vocab_size[k],
                                     dtype=self.dtype)
                # One-hot matmul acts as an embedding lookup:
                # [batch*seq, depth] x [depth, width] -> [batch*seq, width].
                looked_up = tf.matmul(one_hot, table)
                output += tf.reshape(looked_up,
                                     [batch_size, seq_length, width])

        if self.use_position_embeddings:
            if self.reset_position_index_per_cell:
                # Position indices restart at the first token of each
                # (column, row) cell: subtract that cell's minimum position.
                cell_index = segmented_tensor.ProductIndexMap(
                    segmented_tensor.IndexMap(
                        column_ids, self.token_type_vocab_size[1],
                        batch_dims=1),
                    segmented_tensor.IndexMap(
                        row_ids, self.token_type_vocab_size[2],
                        batch_dims=1))
                positions = tf.expand_dims(tf.range(seq_length), axis=0)
                tiled_positions = tf.repeat(
                    positions, repeats=batch_size, axis=0)
                cell_start = segmented_tensor.reduce_min(
                    tiled_positions, cell_index)[0]
                token_cell_start = segmented_tensor.gather(cell_start,
                                                           cell_index)
                position_embeddings = tf.nn.embedding_lookup(
                    self.position_embeddings, positions - token_cell_start)
            else:
                # Absolute positions: the first seq_length rows of the table,
                # with a leading axis of size 1 that broadcasts over batch.
                position_embeddings = tf.expand_dims(
                    tf.slice(self.position_embeddings, [0, 0],
                             [seq_length, width]),
                    axis=0)
            output += position_embeddings

        normalized = self.output_layer_norm(output)
        return self.output_dropout(
            normalized, training=kwargs.get('training', False))
Esempio n. 3
0
 def call(self, inputs):
     """Implements call() for the layer.

     Args:
       inputs: integer ids used to index rows of `self.embeddings`.

     Returns:
       Tensor reshaped to `get_shape_list(inputs) + [self.embedding_size]`,
       i.e. one embedding vector per input id.
     """
     target_shape = tf_utils.get_shape_list(inputs) + [self.embedding_size]
     return tf.reshape(tf.nn.embedding_lookup(self.embeddings, inputs),
                       target_shape)