예제 #1
0
    def build(self, unused_input_shapes):
        """Creates the variables and sublayers of the pretraining heads.

        Args:
            unused_input_shapes: Input shapes passed in by the caller; not
                inspected here, only forwarded to the parent `build()`.
        """
        cfg = self.config

        # Zero-initialized bias with one entry per vocabulary item.
        self.output_bias = self.add_weight(
            name='predictions/output_bias',
            shape=[cfg.vocab_size],
            initializer=tf.keras.initializers.Zeros(),
        )

        # Transform applied before prediction: dense projection of width
        # hidden_size with the configured activation, then layer norm over
        # the last axis.
        transform_activation = modeling.get_activation(cfg.hidden_act)
        self.lm_dense = tf.keras.layers.Dense(
            cfg.hidden_size,
            name='predictions/transform/dense',
            activation=transform_activation,
            kernel_initializer=self.initializer,
        )
        self.lm_layer_norm = tf.keras.layers.LayerNormalization(
            name='predictions/transform/LayerNorm',
            axis=-1,
            epsilon=1e-12,
        )

        # Next-sentence classification head. It is kept as an explicit
        # weight/bias pair (instead of a Dense layer) so that the variable
        # shapes match TF1.x BERT.
        with tf.name_scope('seq_relationship'):
            self.next_seq_weights = self.add_weight(
                name='output_weights',
                shape=[self.num_next_sentence_label, cfg.hidden_size],
                initializer=self.initializer,
            )
            self.next_seq_bias = self.add_weight(
                name='output_bias',
                shape=[self.num_next_sentence_label],
                initializer=tf.keras.initializers.Zeros(),
            )

        super(BertPretrainLayer, self).build(unused_input_shapes)
예제 #2
0
 def build(self, unused_input_shapes):
     """Creates the sublayers and bias variable of the pretraining heads.

     Args:
         unused_input_shapes: Input shapes passed in by the caller; not
             inspected here, only forwarded to the parent `build()`.
     """
     cfg = self.config

     # Dense projection of width hidden_size using the configured activation.
     transform_activation = modeling.get_activation(cfg.hidden_act)
     self.lm_dense = tf.keras.layers.Dense(
         cfg.hidden_size,
         activation=transform_activation,
         kernel_initializer=self.initializer,
     )

     # Zero-initialized bias with one entry per vocabulary item.
     self.lm_bias = self.add_weight(
         name='lm_bias',
         shape=[cfg.vocab_size],
         initializer=tf.keras.initializers.Zeros(),
     )

     # Layer normalization over the last axis.
     self.lm_layer_norm = tf.keras.layers.LayerNormalization(
         axis=-1,
         epsilon=1e-12,
     )

     # Dense head with one output unit per next-sentence label.
     self.next_sentence_dense = tf.keras.layers.Dense(
         self.num_next_sentence_label,
         kernel_initializer=self.initializer,
     )

     super(BertPretrainLayer, self).build(unused_input_shapes)