import math

from keras import backend as K
from keras.layers import Dense, Dropout, Embedding, Layer, Softmax

# BERTLayerNorm (used below) is this implementation's custom layer
# normalization layer, defined elsewhere in the codebase.


class BERTOutput(Layer):
    def __init__(self, config, **kwargs):
        self.config = config
        self.trainable = False
        super().__init__(**kwargs)
        self.dense = Dense(input_shape=(config.intermediate_size, ),
                           units=config.hidden_size,
                           trainable=False)
        self.LayerNorm = BERTLayerNorm(config, trainable=False)
        self.dropout = Dropout(config.hidden_dropout_prob, trainable=False)

    def build(self, input_shape):
        self.dense.build(
            (self.config.hidden_size, self.config.intermediate_size))
        self.LayerNorm.build(self.config.hidden_size)
        self.dropout.build(self.config.hidden_size)
        super(BERTOutput, self).build(self.config.hidden_size)

    def call(self, x, **kwargs):
        input_tensor, hidden_states = x
        original_shape = hidden_states.shape
        # Collapse (batch, seq_len, intermediate) to 2-D for the projection.
        hidden_states_r = K.reshape(hidden_states,
                                    (-1, hidden_states.shape[-1]))
        hidden_states = self.dense(hidden_states_r)
        hidden_states = self.dropout(hidden_states)
        # Restore the sequence axis before the residual connection.
        hidden_states_r = K.reshape(
            hidden_states, (-1, original_shape[1], hidden_states.shape[-1]))
        hidden_states = self.LayerNorm(hidden_states_r + input_tensor)
        return hidden_states
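
# To make the data flow of BERTOutput concrete, here is a minimal numpy
# sketch of the same pattern: project intermediate_size back to hidden_size,
# apply dropout (the identity at inference), add the residual, and
# layer-normalize. All names and sizes below are illustrative, not taken
# from the original code.

import numpy as np

# Toy sizes standing in for the config values (illustrative only).
batch, seq_len, hidden, intermediate = 2, 4, 8, 32

rng = np.random.default_rng(0)
input_tensor = rng.normal(size=(batch, seq_len, hidden))       # residual branch
hidden_states = rng.normal(size=(batch, seq_len, intermediate))
W = rng.normal(size=(intermediate, hidden)) * 0.02             # dense kernel
b = np.zeros(hidden)                                           # dense bias

# Project back to hidden_size (dropout is skipped, i.e. identity at inference).
projected = hidden_states.reshape(-1, intermediate) @ W + b
projected = projected.reshape(batch, seq_len, hidden)

# Residual connection followed by layer normalization over the last axis.
summed = projected + input_tensor
mean = summed.mean(-1, keepdims=True)
var = summed.var(-1, keepdims=True)
output = (summed - mean) / np.sqrt(var + 1e-12)
print(output.shape)  # (2, 4, 8)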

def modify_model(model):
    # Freeze the embedding layer, splice an extra Dropout in after layer
    # index 2, and invalidate Keras's cached layer topology so the new
    # layer is picked up.
    model.layers[0].trainable = False
    layer = Dropout(0.2)
    model.layers.insert(3, layer)
    layer.build(model.layers[2].output_shape)
    model._flattened_layers = None
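
# A hedged usage sketch for modify_model, assuming the old standalone Keras
# (pre tf.keras) where Sequential.layers is a mutable list and
# `_flattened_layers` caches the flattened topology. The toy model and its
# layer sizes are illustrative only; they merely satisfy the hard-coded
# indices (0, 2, 3) inside modify_model.

from keras.models import Sequential

_toy_model = Sequential([
    Dense(64, input_shape=(16, )),
    Dense(32),
    Dense(32),
    Dense(10),
])

modify_model(_toy_model)  # freezes layers[0], inserts Dropout(0.2) at index 3
print([type(l).__name__ for l in _toy_model.layers])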

class BERTEmbeddings(Layer):
    def __init__(self, config, **kwargs):
        self.trainable = False
        super(BERTEmbeddings, self).__init__(**kwargs)
        self.config = config
        self.token_type_embeddings = Embedding(config.type_vocab_size,
                                               config.hidden_size,
                                               name='token_type_embeddings',
                                               trainable=False)
        self.position_embeddings = Embedding(config.max_position_embeddings,
                                             config.hidden_size,
                                             name='position_embeddings',
                                             trainable=False)
        self.word_embeddings = Embedding(config.vocab_size,
                                         config.hidden_size,
                                         name='token_embeddings',
                                         trainable=False)
        self.LayerNorm = BERTLayerNorm(config, trainable=False)
        self.dropout = Dropout(config.hidden_dropout_prob,
                               name='EmbeddingDropOut',
                               trainable=False)

    def build(self, input_shape):
        self.token_type_embeddings.build(input_shape)
        self.position_embeddings.build(input_shape)
        self.word_embeddings.build(input_shape)
        self.LayerNorm.build(self.config.hidden_size)
        self.dropout.build(self.config.hidden_size)
        super(BERTEmbeddings, self).build(self.config.hidden_size)

    def call(self, x, **kwargs):
        input_ids, token_type_ids, position_ids = x
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)
        # BERT sums the three embeddings, then normalizes and applies dropout.
        embeddings = words_embeddings + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
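
# What the three-way embedding sum computes, as a self-contained numpy
# sketch. An embedding lookup is just a row gather from a table; BERT adds
# the word, position, and token-type lookups elementwise before LayerNorm
# and dropout. All sizes and table contents below are illustrative.

import numpy as np

# Toy sizes standing in for the config values (illustrative only).
vocab_size, max_pos, type_vocab, hidden = 100, 16, 2, 8
seq_len = 5

rng = np.random.default_rng(0)
word_table = rng.normal(size=(vocab_size, hidden))
pos_table = rng.normal(size=(max_pos, hidden))
type_table = rng.normal(size=(type_vocab, hidden))

input_ids = np.array([2, 7, 42, 7, 3])         # token ids
position_ids = np.arange(seq_len)              # 0 .. seq_len-1
token_type_ids = np.zeros(seq_len, dtype=int)  # single-segment input

embeddings = (word_table[input_ids] + pos_table[position_ids] +
              type_table[token_type_ids])
print(embeddings.shape)  # (5, 8)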

class BERTSelfOutput(Layer):
    def __init__(self, config, **kwargs):
        self.trainable = False
        super().__init__(**kwargs)
        self.config = config
        self.dense = Dense(input_shape=(self.config.hidden_size, ),
                           units=self.config.hidden_size,
                           trainable=False)
        self.LayerNorm = BERTLayerNorm(self.config, trainable=False)
        self.dropout = Dropout(self.config.hidden_dropout_prob,
                               trainable=False)

    def build(self, input_shape):
        if isinstance(input_shape, tuple) and input_shape[0] is None:
            dense_input_shape = (self.config.hidden_size, input_shape[1])
        else:
            dense_input_shape = (self.config.hidden_size, input_shape)
        self.dense.build(dense_input_shape)
        self.LayerNorm.build(self.config.hidden_size)
        self.dropout.build(self.config.hidden_size)
        super(BERTSelfOutput, self).build(input_shape)

    def call(self, x, **kwargs):
        input_tensor, hidden_states = x
        original_shape = hidden_states.shape
        # Collapse to 2-D so the Dense projection sees one token per row.
        hidden_states_r = K.reshape(hidden_states,
                                    (-1, hidden_states.shape[-1]))
        hidden_states = self.dense(hidden_states_r)
        hidden_states = self.dropout(hidden_states)
        # Restore (batch, seq_len, hidden) before the residual connection.
        hidden_states_r = K.reshape(hidden_states,
                                    (-1, original_shape[1], original_shape[2]))
        hidden_states = self.LayerNorm(hidden_states_r + input_tensor)
        return hidden_states
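
# Both BERTSelfOutput and BERTOutput rely on the same reshape round-trip:
# collapse (batch, seq_len, hidden) to (batch*seq_len, hidden) so the Dense
# layer receives a 2-D input, then restore the original layout. A quick
# numpy check of that round-trip (sizes illustrative):

import numpy as np

batch, seq_len, hidden = 2, 4, 8
x = np.arange(batch * seq_len * hidden).reshape(batch, seq_len, hidden)

flat = x.reshape(-1, hidden)              # (batch*seq_len, hidden)
restored = flat.reshape(-1, seq_len, hidden)

assert (restored == x).all()              # the round-trip is lossless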

class BERTSelfAttention(Layer):
    def __init__(self, config, **kwargs):
        self.trainable = False
        super().__init__(**kwargs)
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of "
                "attention heads (%d)" %
                (config.hidden_size, config.num_attention_heads))
        self.config = config
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size /
                                       config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size
        self.query = Dense(input_shape=(config.hidden_size, ),
                           units=self.all_head_size,
                           trainable=False)
        self.key = Dense(input_shape=(config.hidden_size, ),
                         units=self.all_head_size,
                         trainable=False)
        self.value = Dense(input_shape=(config.hidden_size, ),
                           units=self.all_head_size,
                           trainable=False)
        self.dropout = Dropout(config.attention_probs_dropout_prob,
                               trainable=False)

    def transpose_for_scores(self, x, k: bool = False):
        # Split the last axis into (heads, head_size) and move the head axis
        # forward; keys are additionally transposed for the dot product.
        x_shape = list(x.shape)
        new_x_shape = [-1] + x_shape[-2:-1] + [
            self.num_attention_heads, self.attention_head_size
        ]
        new_x = K.reshape(x, new_x_shape)
        if k:
            return K.permute_dimensions(new_x, [0, 2, 3, 1])
        else:
            return K.permute_dimensions(new_x, [0, 2, 1, 3])

    def build(self, input_shape):
        self.query.build((self.all_head_size, self.config.hidden_size))
        self.key.build((self.all_head_size, self.config.hidden_size))
        self.value.build((self.all_head_size, self.config.hidden_size))
        self.dropout.build(input_shape)
        super(BERTSelfAttention, self).build(input_shape)

    def call(self, x, **kwargs):
        hidden_states, attention_mask = x
        hidden_states_r = K.reshape(hidden_states,
                                    (-1, hidden_states.shape[-1]))
        # `mixed_query_layer` = [B*F, N*H]
        mixed_query_layer = self.query(hidden_states_r)
        # `mixed_key_layer` = [B*T, N*H]
        mixed_key_layer = self.key(hidden_states_r)
        # `mixed_value_layer` = [B*T, N*H]
        mixed_value_layer = self.value(hidden_states_r)
        mixed_query_layer_r = K.reshape(
            mixed_query_layer,
            (-1, self.config.max_seq_len, self.config.hidden_size))
        mixed_key_layer_r = K.reshape(
            mixed_key_layer,
            (-1, self.config.max_seq_len, self.config.hidden_size))
        mixed_value_layer_r = K.reshape(
            mixed_value_layer,
            (-1, self.config.max_seq_len, self.config.hidden_size))
        # `query_layer` = [B, N, F, H]
        query_layer = self.transpose_for_scores(mixed_query_layer_r, k=False)
        # `key_layer` = [B, N, H, T] (pre-transposed for the dot product)
        key_layer = self.transpose_for_scores(mixed_key_layer_r, k=True)
        value_layer = self.transpose_for_scores(mixed_value_layer_r, k=False)
        # Take the dot product between "query" and "key" to get the raw
        # attention scores. `attention_scores` = [B, N, F, T]
        attention_scores = K.batch_dot(query_layer, key_layer)
        attention_scores = attention_scores / math.sqrt(
            self.attention_head_size)
        # Apply the attention mask (precomputed for all layers in the call
        # to BertModel).
        attention_scores = attention_scores + attention_mask
        # Normalize the attention scores to probabilities.
        attention_probs = Softmax(axis=-1)(attention_scores)
        # This is actually dropping out entire tokens to attend to, which
        # might seem a bit unusual, but is taken from the original
        # Transformer paper.
        attention_probs = self.dropout(attention_probs)
        context_layer = K.batch_dot(attention_probs, value_layer)
        # Merge the heads back: [B, N, F, H] -> [B, F, N*H].
        context_layer = K.permute_dimensions(context_layer, [0, 2, 1, 3])
        new_context_layer_shape = [
            -1, self.config.max_seq_len, self.all_head_size
        ]
        context_layer = K.reshape(context_layer, new_context_layer_shape)
        return context_layer
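
# The core of BERTSelfAttention is scaled dot-product attention over
# per-head tensors. Here is a self-contained numpy sketch of just that
# computation, following the [B, N, F/T, H] shape comments in the code
# above (B = batch, N = heads, F = "from" positions, T = "to" positions,
# H = head size). All sizes below are illustrative.

import numpy as np

B, N, F, T, H = 2, 3, 4, 4, 5

rng = np.random.default_rng(0)
query = rng.normal(size=(B, N, F, H))   # [B, N, F, H]
key = rng.normal(size=(B, N, T, H))     # [B, N, T, H]
value = rng.normal(size=(B, N, T, H))   # [B, N, T, H]
mask = np.zeros((B, 1, 1, T))           # additive mask: 0 = keep, -10000 = drop

# Raw scores, scaled by sqrt(head size), plus the additive mask.
scores = query @ key.transpose(0, 1, 3, 2) / np.sqrt(H) + mask

# Softmax over the "to" axis yields the attention probabilities.
probs = np.exp(scores - scores.max(-1, keepdims=True))
probs /= probs.sum(-1, keepdims=True)

context = probs @ value                 # [B, N, F, H]
print(context.shape)                    # (2, 3, 4, 5)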