Code example #1: Transformer layer call()
    def call(self, inputs, **kwargs):
        """Implements call() for the layer."""
        (input_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)
        attention_output = self.attention_layer(
            from_tensor=input_tensor,
            to_tensor=input_tensor,
            attention_mask=attention_mask, **kwargs)
        attention_output = self.attention_output_dense(attention_output)
        attention_output = self.attention_dropout(
            attention_output, training=kwargs.get('training', False))
        # Use float32 in keras layer norm and the gelu activation in the
        # intermediate dense layer for numeric stability
        attention_output = self.attention_layer_norm(input_tensor +
                                                     attention_output)
        if self.float_type == tf.float16:
            attention_output = tf.cast(attention_output, tf.float16)
        intermediate_output = self.intermediate_dense(attention_output)
        if self.float_type == tf.float16:
            intermediate_output = tf.cast(intermediate_output, tf.float16)
        layer_output = self.output_dense(intermediate_output)
        layer_output = self.output_dropout(
            layer_output, training=kwargs.get('training', False))
        # Use float32 in keras layer norm for numeric stability
        layer_output = self.output_layer_norm(layer_output + attention_output)
        if self.float_type == tf.float16:
            layer_output = tf.cast(layer_output, tf.float16)
        return layer_output
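The float16/float32 handling above follows a common mixed-precision pattern: the residual add and layer norm run in float32 for numeric stability, and the result is cast back to float16 before the next half-precision layer. A minimal standalone sketch of that pattern; the layer names and sizes below are illustrative, not taken from the model above:

import tensorflow as tf

# Illustrative sizes only.
batch, seq_len, hidden = 2, 8, 16

# Keep the layer norm in float32; run the following dense layer in float16.
layer_norm = tf.keras.layers.LayerNormalization(axis=-1, dtype='float32')
next_dense = tf.keras.layers.Dense(hidden, dtype='float16')

x = tf.random.normal([batch, seq_len, hidden], dtype=tf.float16)
residual = tf.random.normal([batch, seq_len, hidden], dtype=tf.float16)

# Residual add + layer norm in float32 for stability...
normed = layer_norm(tf.cast(x + residual, tf.float32))
# ...then cast back to float16 before the next half-precision layer.
normed = tf.cast(normed, tf.float16)
out = next_dense(normed)
print(out.dtype)  # float16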
Code example #2: Attention layer call()
    def call(self, inputs, **kwargs):
        """Implements call() for the layer."""
        (from_tensor, to_tensor, attention_mask) = tf_utils.unpack_inputs(inputs)

        # Scalar dimensions referenced here:
        #   B = batch size (number of sequences)
        #   F = `from_tensor` sequence length
        #   T = `to_tensor` sequence length
        #   N = `num_attention_heads`
        #   H = `size_per_head`
        # `query_tensor` = [B, F, N, H]
        query_tensor = self.query_dense(from_tensor)

        # `key_tensor` = [B, T, N, H]
        key_tensor = self.key_dense(to_tensor)

        # `value_tensor` = [B, T, N, H]
        value_tensor = self.value_dense(to_tensor)

        # Take the dot product between "query" and "key" to get the raw
        # attention scores.
        attention_scores = tf.einsum(
            "BTNH,BFNH->BNFT", key_tensor, query_tensor)
        attention_scores = attention_scores / self.softmax_temperature
        attention_scores = tf.multiply(attention_scores,
                                       1.0 / math.sqrt(float(self.size_per_head)))

        if attention_mask is not None:
            # `attention_mask` = [B, 1, F, T]
            attention_mask = tf.expand_dims(attention_mask, axis=[1])

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            adder = (1.0 - tf.cast(attention_mask,
                                   attention_scores.dtype)) * -10000.0

            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_scores += adder

        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]
        attention_probs = tf.nn.softmax(attention_scores)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.attention_probs_dropout(
            attention_probs, training=kwargs.get('training', False))

        # `context_tensor` = [B, F, N, H]
        context_tensor = tf.einsum(
            "BNFT,BTNH->BFNH", attention_probs, value_tensor)

        return context_tensor
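As a shape check, the same einsum attention computation can be run on dummy tensors. The sizes and the mask below are made up for illustration only:

import math
import tensorflow as tf

# Illustrative sizes: B=2 sequences, F=T=4 tokens, N=2 heads, H=8 per head.
B, F, T, N, H = 2, 4, 4, 2, 8

query = tf.random.normal([B, F, N, H])
key = tf.random.normal([B, T, N, H])
value = tf.random.normal([B, T, N, H])

# Raw scores contract the per-head dimension H: result is [B, N, F, T].
scores = tf.einsum("BTNH,BFNH->BNFT", key, query)
scores *= 1.0 / math.sqrt(float(H))

# A [B, F, T] mask (1.0 = attend, 0.0 = masked); here the last key position is masked.
attention_mask = tf.constant([[[1.0, 1.0, 1.0, 0.0]] * F] * B)
adder = (1.0 - tf.expand_dims(attention_mask, axis=1)) * -10000.0  # [B, 1, F, T]
scores += adder

probs = tf.nn.softmax(scores)                         # [B, N, F, T]
context = tf.einsum("BNFT,BTNH->BFNH", probs, value)  # [B, F, N, H]
print(context.shape)  # (2, 4, 2, 8)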
Code example #3: Model call() (embeddings, encoder, pooler)
    def call(self, inputs, mode="bert", **kwargs):
        """Implements call() for the layer.

        Args:
          inputs: packed input tensors.
          mode: string, `bert` or `encoder`.
        Returns:
          Output tensor of the last layer for BERT training (mode=`bert`) which
          is a float Tensor of shape [batch_size, seq_length, hidden_size] or
          a list of output tensors for encoder usage (mode=`encoder`).
        """
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        input_word_ids = unpacked_inputs[0]
        input_mask = unpacked_inputs[1]
        segment_ids = unpacked_inputs[2]
        column_ids = unpacked_inputs[3]
        row_ids = unpacked_inputs[4]
        prev_label_ids = unpacked_inputs[5]
        column_ranks = unpacked_inputs[6]
        inv_column_ranks = unpacked_inputs[7]
        numeric_relations = unpacked_inputs[8]

        word_embeddings = self.embedding_lookup(input_word_ids)
        embedding_tensor = self.embedding_postprocessor(
            word_embeddings=word_embeddings,
            segment_ids=segment_ids,
            column_ids=column_ids,
            row_ids=row_ids,
            prev_label_ids=prev_label_ids,
            column_ranks=column_ranks,
            inv_column_ranks=inv_column_ranks,
            numeric_relations=numeric_relations)
        if self.float_type == tf.float16:
            embedding_tensor = tf.cast(embedding_tensor, tf.float16)
        attention_mask = None
        if input_mask is not None:
            attention_mask = create_attention_mask_from_input_mask(
                input_word_ids, input_mask)

        if mode == "encoder":
            return self.encoder(
                embedding_tensor, attention_mask, return_all_layers=True)

        sequence_output = self.encoder(embedding_tensor, attention_mask)
        first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)
        pooled_output = self.pooler_transform(first_token_tensor)

        return (pooled_output, sequence_output)
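The pooled output is taken from the first token of the sequence output and passed through a dense pooler transform. A standalone illustration of that slicing; the sizes and the tanh pooler below are stand-ins, not pulled from the model above:

import tensorflow as tf

batch_size, seq_length, hidden_size = 2, 6, 16  # illustrative sizes
sequence_output = tf.random.normal([batch_size, seq_length, hidden_size])

# [B, S, H] -> [B, 1, H] via slicing, then squeeze to [B, H].
first_token_tensor = tf.squeeze(sequence_output[:, 0:1, :], axis=1)

# Stand-in for self.pooler_transform; BERT-style poolers use dense + tanh.
pooler_transform = tf.keras.layers.Dense(hidden_size, activation='tanh')
pooled_output = pooler_transform(first_token_tensor)
print(pooled_output.shape)  # (2, 16)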
Code example #4: Transformer encoder call()
    def call(self, inputs, return_all_layers=False, **kwargs):
        """Implements call() for the layer.

        Args:
          inputs: packed inputs.
          return_all_layers: bool, whether to return outputs of all layers inside
            encoders.
        Returns:
          Output tensor of the last layer or a list of output tensors.
        """
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        input_tensor = unpacked_inputs[0]
        attention_mask = unpacked_inputs[1]
        output_tensor = input_tensor

        all_layer_outputs = []
        for layer in self.layers:
            output_tensor = layer(output_tensor, attention_mask, **kwargs)
            all_layer_outputs.append(output_tensor)

        if return_all_layers:
            return all_layer_outputs

        return all_layer_outputs[-1]
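The loop collects every layer's output so the same code path serves both the return_all_layers case and the usual last-layer case. A tiny functional stand-in, with plain Dense layers in place of transformer blocks, purely for illustration:

import tensorflow as tf

# Stand-ins for the transformer blocks: three Dense layers of width 16.
blocks = [tf.keras.layers.Dense(16, activation='relu') for _ in range(3)]

def run_encoder(input_tensor, return_all_layers=False):
    output_tensor = input_tensor
    all_layer_outputs = []
    for block in blocks:
        output_tensor = block(output_tensor)
        all_layer_outputs.append(output_tensor)
    return all_layer_outputs if return_all_layers else all_layer_outputs[-1]

x = tf.random.normal([2, 4, 16])
print(run_encoder(x).shape)                         # (2, 4, 16)
print(len(run_encoder(x, return_all_layers=True)))  # 3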
Code example #5: Embedding postprocessor call()
    def call(self, inputs, **kwargs):
        """Implements call() for the layer."""
        unpacked_inputs = tf_utils.unpack_inputs(inputs)
        word_embeddings = unpacked_inputs[0]
        segment_ids = unpacked_inputs[1]
        column_ids = unpacked_inputs[2]
        row_ids = unpacked_inputs[3]
        prev_label_ids = unpacked_inputs[4]
        column_ranks = unpacked_inputs[5]
        inv_column_ranks = unpacked_inputs[6]
        numeric_relations = unpacked_inputs[7]
        input_shape = tf_utils.get_shape_list(word_embeddings, expected_rank=3)
        batch_size = input_shape[0]
        seq_length = input_shape[1]
        width = input_shape[2]

        output = word_embeddings
        token_type_ids_list = [segment_ids, column_ids, row_ids, prev_label_ids,
                               column_ranks, inv_column_ranks, numeric_relations]
        token_type_embeddings_list = [
            self.segment_embeddings, self.column_embeddings,
            self.row_embeddings, self.prev_label_embeddings,
            self.column_ranks_embeddings, self.inv_column_ranks_embeddings,
            self.numeric_relations_embeddings]
        if self.use_type_embeddings:
            for i, (token_type_ids, type_embeddings) in enumerate(
                    zip(token_type_ids_list, token_type_embeddings_list)):
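                # Multiplying one-hot ids by the embedding table is equivalent
                # to an embedding lookup, done here once per token type feature.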
                flat_token_type_ids = tf.reshape(token_type_ids, [-1])
                one_hot_ids = tf.one_hot(
                    flat_token_type_ids,
                    depth=self.token_type_vocab_size[i],
                    dtype=self.dtype)
                token_type_embeddings = tf.matmul(
                    one_hot_ids, type_embeddings)
                token_type_embeddings = tf.reshape(token_type_embeddings,
                                                   [batch_size, seq_length, width])
                output += token_type_embeddings

        if self.use_position_embeddings:
            if not self.reset_position_index_per_cell:
                position_embeddings = tf.expand_dims(
                    tf.slice(self.position_embeddings, [0, 0],
                             [seq_length, width]),
                    axis=0)
            else:
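                # Reset the position index at the start of every table cell:
                # build a per-cell index from the column and row ids, find the
                # first absolute position inside each cell, and embed each
                # token's offset from that first position.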
                col_index = segmented_tensor.IndexMap(
                    token_type_ids_list[1], self.token_type_vocab_size[1], batch_dims=1)
                row_index = segmented_tensor.IndexMap(
                    token_type_ids_list[2], self.token_type_vocab_size[2], batch_dims=1)
                full_index = segmented_tensor.ProductIndexMap(
                    col_index, row_index)
                position = tf.expand_dims(tf.range(seq_length), axis=0)
                batched_position = tf.repeat(
                    position, repeats=batch_size, axis=0)
                first_position_per_segment = segmented_tensor.reduce_min(
                    batched_position, full_index)[0]
                first_position = segmented_tensor.gather(first_position_per_segment,
                                                         full_index)
                position_embeddings = tf.nn.embedding_lookup(self.position_embeddings,
                                                             position - first_position)

            output += position_embeddings

        output = self.output_layer_norm(output)
        output = self.output_dropout(
            output, training=kwargs.get('training', False))

        return output
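For reference, the one-hot/matmul form used in the type-embedding loop matches a direct embedding lookup. A small check with made-up sizes; the table and ids below are illustrative only:

import tensorflow as tf

batch_size, seq_length, width, vocab_size = 2, 4, 8, 5   # illustrative sizes
type_embeddings = tf.random.normal([vocab_size, width])  # stand-in embedding table
token_type_ids = tf.random.uniform(
    [batch_size, seq_length], maxval=vocab_size, dtype=tf.int32)

# One-hot ids times the table, as in the loop above.
flat_token_type_ids = tf.reshape(token_type_ids, [-1])
one_hot_ids = tf.one_hot(flat_token_type_ids, depth=vocab_size, dtype=tf.float32)
via_matmul = tf.reshape(tf.matmul(one_hot_ids, type_embeddings),
                        [batch_size, seq_length, width])

# Direct gather of the same rows.
via_lookup = tf.nn.embedding_lookup(type_embeddings, token_type_ids)
print(float(tf.reduce_max(tf.abs(via_matmul - via_lookup))))  # 0.0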