def build_encoder(self, input_ids, input_mask, hidden_dropout_prob, attention_probs_dropout_prob, **kargs):
  reuse = kargs["reuse"]
  with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
    with tf.variable_scope("encoder"):
      # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
      # mask of shape [batch_size, seq_length, seq_length] which is used
      # for the attention scores.
      attention_mask = bert_modules.create_attention_mask_from_input_mask(
          input_ids, input_mask)

      # Run the stacked transformer.
      # `sequence_output` shape = [batch_size, seq_length, hidden_size].
      self.all_encoder_layers = bert_modules.transformer_model(
          input_tensor=self.embedding_output,
          attention_mask=attention_mask,
          hidden_size=self.config.hidden_size,
          num_hidden_layers=self.config.num_hidden_layers,
          num_attention_heads=self.config.num_attention_heads,
          intermediate_size=self.config.intermediate_size,
          intermediate_act_fn=bert_modules.get_activation(self.config.hidden_act),
          hidden_dropout_prob=hidden_dropout_prob,
          attention_probs_dropout_prob=attention_probs_dropout_prob,
          initializer_range=self.config.initializer_range,
          do_return_all_layers=True)
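# A minimal NumPy sketch (illustration only, not the bert_modules
# implementation) of the 2D -> 3D broadcast that
# `create_attention_mask_from_input_mask` performs: every query position may
# attend to every key position whose input_mask entry is 1.
import numpy as np

def broadcast_attention_mask(input_mask):
  """input_mask: [batch_size, seq_length] -> [batch_size, seq_length, seq_length]."""
  input_mask = np.asarray(input_mask, dtype=np.float32)
  batch_size, seq_length = input_mask.shape
  # [batch, seq, 1] of ones broadcast against [batch, 1, seq] of the mask.
  return np.ones((batch_size, seq_length, 1), dtype=np.float32) * input_mask[:, None, :]

# Example: one sequence with 3 real tokens and 1 padding token.
# broadcast_attention_mask([[1, 1, 1, 0]])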
def get_masked_lm_output(config, input_tensor, output_weights, positions,
                         label_ids, label_weights, reuse=None):
  """Get loss and log probs for the masked LM."""
  input_tensor = tf.cast(input_tensor, tf.float32)
  positions = tf.cast(positions, tf.int32)
  label_ids = tf.cast(label_ids, tf.int32)
  label_weights = tf.cast(label_weights, tf.float32)

  # Flatten masked lm ids with positions.
  input_tensor = bert_utils.gather_indexes(input_tensor, positions)

  with tf.variable_scope("cls/predictions", reuse=reuse):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=config.hidden_size,
          activation=bert_modules.get_activation(config.hidden_act),
          kernel_initializer=bert_modules.create_initializer(
              config.initializer_range))
      input_tensor = bert_modules.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])

    # one_hot_labels = tf.one_hot(
    #     label_ids, depth=config.vocab_size, dtype=tf.float32)

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=label_ids, logits=logits)

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5

    # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    # numerator = tf.reduce_sum(label_weights * per_example_loss)
    # denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

  return (loss, per_example_loss, log_probs)
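# Hypothetical NumPy sketch of the masked-LM loss normalization above:
# per-prediction cross-entropy is weighted by `label_weights` (1.0 for real
# predictions, 0.0 for zero-padded prediction slots) and averaged over the
# number of real predictions only.
import numpy as np

def masked_lm_loss(log_probs, label_ids, label_weights):
  """log_probs: [num_predictions, vocab]; label_ids/label_weights: [num_predictions]."""
  per_example_loss = -log_probs[np.arange(len(label_ids)), label_ids]
  numerator = np.sum(label_weights * per_example_loss)
  denominator = np.sum(label_weights) + 1e-5
  return numerator / denominator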
def build_other_output_logits(self, sequence_output, **kargs):
  input_tensor = sequence_output
  input_shape_list = bert_utils.get_shape_list(sequence_output, expected_rank=3)
  batch_size = input_shape_list[0]
  seq_length = input_shape_list[1]
  hidden_dims = input_shape_list[2]

  embedding_projection = kargs.get('embedding_projection', None)

  scope = kargs.get('scope', None)
  if scope:
    scope = scope + '/' + 'cls/predictions'
  else:
    scope = 'cls/predictions'

  tf.logging.info("**** mlm generator scope **** %s", str(scope))

  # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    projection_width = self.config.emb_size

    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=projection_width,
          activation=bert_modules.get_activation(self.config.hidden_act),
          kernel_initializer=bert_modules.create_initializer(
              self.config.initializer_range))

    output_bias = tf.get_variable(
        "output_bias",
        shape=[self.config.vocab_size],
        initializer=tf.zeros_initializer())

    # input_tensor: batch x seq x embedding; logits: batch x seq x vocab
    logits = tf.einsum("abc,dc->abd", input_tensor, self.emb_mat)
    logits = tf.nn.bias_add(logits, output_bias)
    return logits
def build_output_logits(self, **kargs):
  layer_num = kargs.get("layer_num", -1)
  self.sequence_output = self.get_encoder_layers(layer_num)
  input_shape_list = bert_utils.get_shape_list(self.sequence_output, expected_rank=3)
  batch_size = input_shape_list[0]
  seq_length = input_shape_list[1]
  hidden_dims = input_shape_list[2]

  embedding_projection = kargs.get('embedding_projection', None)

  scope = kargs.get('scope', None)
  if scope:
    scope = scope + '/' + 'cls/predictions'
  else:
    scope = 'cls/predictions'

  tf.logging.info("**** mlm generator scope **** %s", str(scope))

  # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    if self.config.get('ln_type', 'postln') == 'preln':
      input_tensor = bert_modules.layer_norm(self.sequence_output)
      tf.logging.info("**** pre ln doing layer norm ****")
    elif self.config.get('ln_type', 'postln') == 'postln':
      input_tensor = self.sequence_output
      tf.logging.info("**** post ln ****")
    else:
      input_tensor = self.sequence_output
      tf.logging.info("**** post ln ****")

    # if config.get("embedding", "factorized") == "factorized":
    #   projection_width = config.hidden_size
    # else:
    #   projection_width = config.embedding_size

    if self.config.get("embedding", "none_factorized") == "none_factorized":
      projection_width = self.config.hidden_size
      tf.logging.info("==not using embedding factorized==")
    else:
      projection_width = self.config.get('embedding_size', self.config.hidden_size)
      tf.logging.info("==using embedding factorized: embedding size: %s==",
                      str(projection_width))

    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=projection_width,
          activation=bert_modules.get_activation(self.config.hidden_act),
          kernel_initializer=bert_modules.create_initializer(
              self.config.initializer_range))

      if self.config.get('ln_type', 'postln') == 'preln':
        input_tensor = input_tensor
        tf.logging.info("**** pre ln ****")
      elif self.config.get('ln_type', 'postln') == 'postln':
        input_tensor = bert_modules.layer_norm(input_tensor)
        tf.logging.info("**** post ln doing layer norm ****")
      else:
        input_tensor = bert_modules.layer_norm(input_tensor)
        tf.logging.info("**** post ln doing layer norm ****")

    if embedding_projection is not None:
      # input_tensor: batch x seq x hidden; embedding_projection: embedding x hidden
      print(input_tensor.get_shape(), embedding_projection.get_shape())
      input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
    else:
      print("==no need for embedding projection==")
      input_tensor = input_tensor

    output_bias = tf.get_variable(
        "output_bias",
        shape=[self.config.vocab_size],
        initializer=tf.zeros_initializer())

    # logits: batch x seq x vocab
    logits = tf.einsum("abc,dc->abd", input_tensor, self.embedding_table)
    self.logits = tf.nn.bias_add(logits, output_bias)
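# Illustrative sketch (plain NumPy, hypothetical shapes) of the factorized
# output head built above: hidden states are first projected down to the
# embedding width, then scored against the tied embedding table through the
# same "abc,dc->abd" contraction used in build_output_logits.
import numpy as np

batch, seq, hidden, emb, vocab = 2, 4, 8, 6, 10
hidden_states = np.random.randn(batch, seq, hidden)
projection = np.random.randn(hidden, emb)          # the "transform" dense layer
embedding_table = np.random.randn(vocab, emb)      # shared input/output embeddings
output_bias = np.zeros(vocab)

projected = hidden_states @ projection                         # batch x seq x emb
logits = np.einsum("abc,dc->abd", projected, embedding_table)  # batch x seq x vocab
logits = logits + output_bias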
def build_encoder(self, input_ids, input_mask, hidden_dropout_prob, attention_probs_dropout_prob, **kargs):
  reuse = kargs["reuse"]

  input_shape = bert_utils.get_shape_list(input_ids, expected_rank=[2, 3])
  batch_size = input_shape[0]
  seq_length = input_shape[1]

  with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
    with tf.variable_scope("encoder"):
      # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
      # mask of shape [batch_size, seq_length, seq_length] which is used
      # for the attention scores.
      attention_mask = bert_modules.create_attention_mask_from_input_mask(
          input_ids, input_mask)

      seq_type = kargs.get('seq_type', "None")

      if seq_type == "seq2seq":
        if kargs.get("mask_type", "left2right") == "left2right":
          mask_sequence = input_mask
          tf.logging.info("==apply left2right LM model with causal mask==")
        elif kargs.get("mask_type", "left2right") == "seq2seq":
          token_type_ids = kargs.get("token_type_ids", None)
          tf.logging.info("==apply left2right LM model with conditional causal mask==")
          if token_type_ids is None:
            token_type_ids = tf.zeros(shape=[batch_size, seq_length], dtype=tf.int32)
            tf.logging.info("==conditional mask is set to 0 and degenerates to a left2right LM model==")
          mask_sequence = token_type_ids
        attention_mask = bert_utils.generate_seq2seq_mask(
            attention_mask, mask_sequence, seq_type, **kargs)
      else:
        tf.logging.info("==apply bi-directional LM model with bi-directional mask==")

      if kargs.get('attention_type', 'efficient_attention') == 'normal_attention':
        tf.logging.info("****** normal attention *******")
        transformer_model = bert_modules.transformer_model
      elif kargs.get('attention_type', 'efficient_attention') == 'efficient_attention':
        tf.logging.info("****** efficient attention *******")
        transformer_model = bert_modules.transformer_efficient_model
      else:
        tf.logging.info("****** normal attention *******")
        transformer_model = bert_modules.transformer_model

      # Run the stacked transformer.
      # `sequence_output` shape = [batch_size, seq_length, hidden_size].
      [self.all_encoder_layers,
       self.all_attention_scores] = transformer_model(
          input_tensor=self.embedding_output,
          attention_mask=attention_mask,
          hidden_size=self.config.hidden_size,
          num_hidden_layers=self.config.num_hidden_layers,
          num_attention_heads=self.config.num_attention_heads,
          intermediate_size=self.config.intermediate_size,
          intermediate_act_fn=bert_modules.get_activation(self.config.hidden_act),
          hidden_dropout_prob=hidden_dropout_prob,
          attention_probs_dropout_prob=attention_probs_dropout_prob,
          initializer_range=self.config.initializer_range,
          do_return_all_layers=True)
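# `bert_utils.generate_seq2seq_mask` is defined elsewhere in the repo; the NumPy
# sketch below only illustrates the general idea of a conditional causal mask
# under an assumed (hypothetical) convention: cond[j] == 1 marks a condition
# token that is always visible, every other position is visible only causally.
# With cond all zeros the mask reduces to a plain left-to-right (causal) mask.
import numpy as np

def conditional_causal_mask(cond):
  """cond: [batch, seq] of 0/1 -> attention mask [batch, seq, seq]."""
  cond = np.asarray(cond, dtype=np.float32)
  batch, seq = cond.shape
  causal = np.tril(np.ones((seq, seq), dtype=np.float32))  # position i sees j <= i
  # allowed(i, j) = causal(i, j) OR cond[j] == 1
  return np.maximum(causal[None, :, :], cond[:, None, :])

# conditional_causal_mask(np.zeros((1, 4)))  # -> lower-triangular causal mask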
def classifier(config, seq_output, input_ids, sampled_ids, input_mask, num_labels, dropout_prob, **kargs):
  """
  input_ids: original input ids
  sampled_ids: generated fake ids
  """
  output_layer = seq_output
  hidden_size = output_layer.shape[-1].value

  unk_mask = tf.cast(tf.equal(input_ids, 100), tf.float32)  # do not replace unk
  cls_mask = tf.cast(tf.equal(input_ids, 101), tf.float32)  # do not replace cls
  sep_mask = tf.cast(tf.equal(input_ids, 102), tf.float32)  # do not replace sep

  none_replace_mask = unk_mask + cls_mask + sep_mask

  input_mask = tf.cast(input_mask, tf.int32)
  # cls, unk and sep are not counted as either replaced or original tokens.
  input_mask *= tf.cast(1 - none_replace_mask, tf.int32)

  hidden = tf.layers.dense(
      seq_output,
      units=config.hidden_size,
      activation=bert_modules.get_activation(config.hidden_act),
      kernel_initializer=bert_modules.create_initializer(config.initializer_range))

  logits = tf.layers.dense(hidden, units=2)  # batch x seq x 2

  # output_weights = tf.get_variable(
  #     "output_weights", [num_labels, hidden_size],
  #     initializer=tf.truncated_normal_initializer(stddev=0.02))
  # output_bias = tf.get_variable(
  #     "output_bias", [num_labels], initializer=tf.zeros_initializer())
  # if config.get('ln_type', 'postln') == 'preln':
  #   output_layer = albert_modules.layer_norm(output_layer)
  #   print('====preln transformer====')
  # elif config.get('ln_type', 'postln') == 'postln':
  #   output_layer = output_layer
  #   print('====postln transformer====')
  # else:
  #   output_layer = output_layer
  #   print('====no layer norm====')
  # output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)
  # logits = tf.einsum("abc,dc->abd", output_layer, output_weights)
  # logits = tf.nn.bias_add(logits, output_bias)  # batch x seq_length x 2

  input_ids = tf.cast(input_ids, tf.int32)

  input_shape_list = bert_utils.get_shape_list(sampled_ids, expected_rank=[2, 3])
  if len(input_shape_list) == 3:
    tmp_sampled_ids = tf.argmax(sampled_ids, axis=-1)  # sampled_ids: batch x seq x vocab
    tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
    tf.logging.info("****** gumbel 3-D sampled_ids *******")
  elif len(input_shape_list) == 2:
    tmp_sampled_ids = sampled_ids
    tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
    tf.logging.info("****** normal 2-D sampled_ids *******")

  sampled_binary_mask = kargs.get('sampled_binary_mask', None)

  if sampled_binary_mask is not None:
    tf.logging.info("****** loss mask uses the masked-token mask (masked tokens only) *******")
    loss_mask = sampled_binary_mask
  else:
    tf.logging.info("****** loss mask uses input_mask (all tokens) *******")
    loss_mask = input_mask

  # ori_sampled_ids = kargs.get('ori_sampled_ids', None)
  # if ori_sampled_ids is not None:
  #   input_shape_list = bert_utils.get_shape_list(ori_sampled_ids, expected_rank=[2, 3])
  #   if len(input_shape_list) == 3:
  #     tmp_ori_sampled_ids = tf.argmax(ori_sampled_ids, axis=-1)  # batch x seq x vocab
  #     tmp_ori_sampled_ids = tf.cast(tmp_ori_sampled_ids, tf.int32)
  #     tf.logging.info("****** gumbel 3-D sampled_ids *******")
  #   elif len(input_shape_list) == 2:
  #     tmp_ori_sampled_ids = tf.cast(ori_sampled_ids, tf.int32)
  #     tf.logging.info("****** normal 2-D sampled_ids *******")
  #   masked_not_equal_mask = tf.cast(tf.not_equal(input_ids, tmp_ori_sampled_ids), tf.int32)
  #   masked_not_equal_mask *= tf.cast(input_mask, tf.int32)
  # else:
  #   masked_not_equal_mask = None
  # if masked_not_equal_mask is not None:
  #   tf.logging.info("****** loss mask uses the masked-token mask (masked tokens only) *******")
  #   loss_mask = masked_not_equal_mask
  # else:
  #   tf.logging.info("****** loss mask uses input_mask (all tokens) *******")
  #   loss_mask = input_mask

  # original: 0, replaced: 1
  not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids), tf.int32)
  not_equal_label_ids *= tf.cast(input_mask, tf.int32)

  if kargs.get('loss', 'cross_entropy') == 'cross_entropy':
    tf.logging.info("====logging discriminator loss using cross entropy ====")
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.stop_gradient(not_equal_label_ids))
  elif kargs.get('loss', 'cross_entropy') == 'focal_loss':
    tf.logging.info("====logging discriminator loss using focal loss ====")
    input_shape_list = bert_utils.get_shape_list(input_ids, expected_rank=2)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    not_equal_label_ids_ = tf.reshape(not_equal_label_ids, [batch_size * seq_length])
    logits_ = tf.reshape(logits, [batch_size * seq_length, -1])
    per_example_loss, _ = loss_utils.focal_loss_binary_v2(config, logits_, not_equal_label_ids_)
    per_example_loss = tf.reshape(per_example_loss, [batch_size, seq_length])
  elif kargs.get('loss', 'cross_entropy') == 'cross_entropy_label_smoothing':
    tf.logging.info("====logging discriminator loss using cross entropy with label smoothing ====")
    per_example_loss = loss_utils.ce_label_smoothing(config, logits, not_equal_label_ids, 2, epsilon=0.1)

  # loss = per_example_loss * tf.cast(loss_mask, tf.float32)
  # loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

  equal_label_ids = (1 - tf.cast(not_equal_label_ids, tf.float32)) * tf.cast(loss_mask, tf.float32)
  equal_per_example_loss = per_example_loss * equal_label_ids
  equal_loss = tf.reduce_sum(equal_per_example_loss)
  equal_loss_all = equal_loss / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))
  equal_loss_output = equal_loss / (1e-10 + tf.reduce_sum(equal_label_ids))

  # not equal: 1, equal: 0
  not_equal_per_example_loss = per_example_loss * tf.cast(not_equal_label_ids, tf.float32)
  not_equal_loss = tf.reduce_sum(not_equal_per_example_loss)
  not_equal_loss_all = not_equal_loss / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))
  not_equal_loss_output = not_equal_loss / (1e-10 + tf.reduce_sum(tf.cast(not_equal_label_ids, tf.float32)))

  loss = (equal_loss + not_equal_loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))
  # loss = equal_loss_output + not_equal_loss_output * 0.1

  tf.logging.info("====discriminator classifier use_tpu %s ====", str(kargs.get('use_tpu', True)))
  if not kargs.get('use_tpu', True):
    tf.logging.info("====logging discriminator loss ====")
    tf.summary.scalar('mask_based_loss', loss)

    loss = per_example_loss * tf.cast(loss_mask, tf.float32)
    loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    tf.summary.scalar('equal_loss',
                      equal_loss / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32))))
    tf.summary.scalar('not_equal_loss',
                      not_equal_loss / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32))))
    tf.summary.scalar('loss_decomposition',
                      loss - (equal_loss + not_equal_loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32))))

  return (loss, logits, per_example_loss)
def multi_position_classifier(config, features, sequence_output, num_labels, dropout_prob):
  final_hidden_shape = bert_utils.get_shape_list(sequence_output, expected_rank=3)
  print(final_hidden_shape, "====multi-choice shape====")

  answer_pos = tf.cast(features['label_positions'], tf.int32)
  cls_pos = tf.zeros_like(answer_pos)
  input_tensor = bert_utils.gather_indexes(sequence_output, answer_pos)
  cls_tensor = bert_utils.gather_indexes(sequence_output, cls_pos)

  answer_cls_tensor = tf.concat([cls_tensor, input_tensor], axis=-1)

  input_tensor = tf.layers.dense(
      answer_cls_tensor,
      units=config.hidden_size,
      activation=bert_modules.get_activation(config.hidden_act),
      kernel_initializer=bert_modules.create_initializer(config.initializer_range))
  input_tensor = bert_modules.layer_norm(input_tensor)

  output_weights = tf.get_variable(
      "output_weights", [num_labels, final_hidden_shape[-1]],
      initializer=tf.truncated_normal_initializer(stddev=0.02))
  output_bias = tf.get_variable(
      "output_bias", shape=[num_labels], initializer=tf.zeros_initializer())

  logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
  logits = tf.nn.bias_add(logits, output_bias)

  label_ids = tf.reshape(tf.cast(features['label_ids'], tf.int32), [-1])
  label_weights = tf.reshape(tf.cast(features['label_weights'], tf.float32), [-1])

  if config.get('class_weights', None):
    class_weights = tf.constant(np.array(config.class_weights).astype(np.float32))

  if config.get("loss", "entropy") == "focal_loss":
    per_example_loss, _ = loss_utils.focal_loss_multi_v1(
        config, logits=logits, labels=tf.stop_gradient(label_ids))
  elif config.get("loss", "smoothed_ce") == 'smoothed_ce':
    per_example_loss = loss_utils.ce_label_smoothing(
        config, logits=logits, labels=tf.stop_gradient(label_ids))
  elif config.get('loss', 'class_balanced_focal') == 'class_balanced_focal':
    per_example_loss, _ = loss_utils.class_balanced_focal_loss_multi_v1(
        config, logits=logits, labels=label_ids, label_weights=class_weights)
  else:
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(label_ids), logits=logits)

  numerator = tf.reduce_sum(label_weights * per_example_loss)
  denominator = tf.reduce_sum(label_weights) + 1e-5
  loss = numerator / denominator

  return (loss, per_example_loss, logits)
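# Hypothetical NumPy version of the position gathering used above: pick the
# hidden state at each labelled position (and at position 0 for the [CLS]
# vector) and concatenate them, mirroring `bert_utils.gather_indexes` followed
# by tf.concat in multi_position_classifier. Shapes below are illustrative.
import numpy as np

def gather_positions(sequence_output, positions):
  """sequence_output: [batch, seq, hidden]; positions: [batch, num_pos]
  -> [batch * num_pos, hidden]."""
  batch = sequence_output.shape[0]
  rows = np.repeat(np.arange(batch), positions.shape[1])
  return sequence_output[rows, positions.reshape(-1)]

sequence_output = np.random.randn(2, 8, 4)
answer_pos = np.array([[3, 5], [1, 2]])
cls_pos = np.zeros_like(answer_pos)
answer_cls = np.concatenate(
    [gather_positions(sequence_output, cls_pos),
     gather_positions(sequence_output, answer_pos)], axis=-1)  # [batch*num_pos, 2*hidden]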
def get_masked_lm_output(config, input_tensor, output_weights, positions,
                         label_ids, label_weights, **kargs):
  """Get loss and log probs for the masked LM."""
  reuse = kargs.get('reuse', False)
  input_tensor = tf.cast(input_tensor, tf.float32)
  positions = tf.cast(positions, tf.int32)
  label_ids = tf.cast(label_ids, tf.int32)
  label_weights = tf.cast(label_weights, tf.float32)

  scope = kargs.get('scope', None)
  if scope:
    scope = scope + '/' + 'cls/predictions'
  else:
    scope = 'cls/predictions'

  tf.logging.info("**** mlm scope **** %s", str(scope))

  # if config.get("embedding", "factorized") == "factorized":
  #   projection_width = config.hidden_size
  # else:
  #   projection_width = config.embedding_size

  if config.get("embedding", "none_factorized") == "none_factorized":
    projection_width = config.hidden_size
    tf.logging.info("==not using embedding factorized==")
  else:
    projection_width = config.get('embedding_size', config.hidden_size)
    tf.logging.info("==using embedding factorized: embedding size: %s==",
                    str(projection_width))

  # Flatten masked lm ids with positions.
  input_tensor = bert_utils.gather_indexes(input_tensor, positions)

  # with tf.variable_scope("cls/predictions", reuse=reuse):
  with tf.variable_scope(scope, reuse=reuse):
    # We apply one more non-linear transformation before the output layer.
    # This matrix is not used after pre-training.
    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=projection_width,
          activation=bert_modules.get_activation(config.hidden_act),
          kernel_initializer=bert_modules.create_initializer(
              config.initializer_range))
      input_tensor = bert_modules.layer_norm(input_tensor)

    # The output weights are the same as the input embeddings, but there is
    # an output-only bias for each token.
    output_bias = tf.get_variable(
        "output_bias",
        shape=[config.vocab_size],
        initializer=tf.zeros_initializer())
    logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    log_probs = tf.nn.log_softmax(logits, axis=-1)

    label_ids = tf.reshape(label_ids, [-1])
    label_weights = tf.reshape(label_weights, [-1])

    one_hot_labels = tf.one_hot(
        label_ids, depth=config.vocab_size, dtype=tf.float32)

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.stop_gradient(label_ids), logits=logits)
    # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

    # The `positions` tensor might be zero-padded (if the sequence is too
    # short to have the maximum number of predictions). The `label_weights`
    # tensor has a value of 1.0 for every real prediction and 0.0 for the
    # padding predictions.
    numerator = tf.reduce_sum(label_weights * per_example_loss)
    denominator = tf.reduce_sum(label_weights) + 1e-5

    # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
    # numerator = tf.reduce_sum(label_weights * per_example_loss)
    # denominator = tf.reduce_sum(label_weights) + 1e-5
    loss = numerator / denominator

  return (loss, per_example_loss, log_probs, label_weights)
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
                              input_mask, input_ori_ids, input_ids,
                              sampled_binary_mask, **kargs):
  input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape_list[0]
  seq_length = input_shape_list[1]
  hidden_dims = input_shape_list[2]

  embedding_projection = kargs.get('embedding_projection', None)

  scope = kargs.get('scope', None)
  if scope:
    scope = scope + '/' + 'cls/predictions'
  else:
    scope = 'cls/predictions'

  tf.logging.info("**** mlm generator scope **** %s", str(scope))

  # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
  with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
    if config.get('ln_type', 'postln') == 'preln':
      input_tensor = bert_modules.layer_norm(input_tensor)
    elif config.get('ln_type', 'postln') == 'postln':
      input_tensor = input_tensor
    else:
      input_tensor = input_tensor

    # if config.get("embedding", "factorized") == "factorized":
    #   projection_width = config.hidden_size
    # else:
    #   projection_width = config.embedding_size

    if config.get("embedding", "none_factorized") == "none_factorized":
      projection_width = config.hidden_size
      tf.logging.info("==not using embedding factorized==")
    else:
      projection_width = config.get('embedding_size', config.hidden_size)
      tf.logging.info("==using embedding factorized: embedding size: %s==",
                      str(projection_width))

    with tf.variable_scope("transform"):
      input_tensor = tf.layers.dense(
          input_tensor,
          units=projection_width,
          activation=bert_modules.get_activation(config.hidden_act),
          kernel_initializer=bert_modules.create_initializer(
              config.initializer_range))

      if config.get('ln_type', 'postln') == 'preln':
        input_tensor = input_tensor
      elif config.get('ln_type', 'postln') == 'postln':
        input_tensor = bert_modules.layer_norm(input_tensor)
      else:
        input_tensor = bert_modules.layer_norm(input_tensor)

    if embedding_projection is not None:
      # input_tensor: batch x seq x hidden; embedding_projection: embedding x hidden
      print(input_tensor.get_shape(), embedding_projection.get_shape())
      input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
    else:
      print("==no need for embedding projection==")
      input_tensor = input_tensor

    output_bias = tf.get_variable(
        "output_bias",
        shape=[config.vocab_size],
        initializer=tf.zeros_initializer())
    # logits: batch x seq x vocab
    logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
    logits = tf.nn.bias_add(logits, output_bias)

    # If input_ori_ids[i] was randomly perturbed, then sampled_binary_mask[i] = 1.
    sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)
    input_mask = tf.cast(input_mask, tf.float32)
    sampled_binary_mask *= input_mask

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.stop_gradient(input_ori_ids))
    per_example_loss *= sampled_binary_mask

    loss = tf.reduce_sum(per_example_loss) / (1e-10 + tf.reduce_sum(sampled_binary_mask))

  return (loss, per_example_loss, logits, sampled_binary_mask)
def emb_score(config, input_tensor, input_ids, output_weights, input_mask, **kargs):
  input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
  batch_size = input_shape_list[0]
  seq_length = input_shape_list[1]
  hidden_dims = input_shape_list[2]

  scope = kargs.get('scope', None)
  if scope:
    lm_scope = scope + '/' + 'cls/predictions'
  else:
    lm_scope = 'cls/predictions'

  tf.logging.info("**** mlm generator scope **** %s", str(lm_scope))

  # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
  with tf.variable_scope(lm_scope, reuse=tf.AUTO_REUSE):
    if config.get('ln_type', 'postln') == 'preln':
      input_tensor = bert_modules.layer_norm(input_tensor)
    elif config.get('ln_type', 'postln') == 'postln':
      input_tensor = input_tensor
    else:
      input_tensor = input_tensor

    if config.get("embedding", "none_factorized") == "none_factorized":
      projection_width = config.hidden_size
      tf.logging.info("==not using embedding factorized==")
    else:
      projection_width = config.get('embedding_size', config.hidden_size)
      tf.logging.info("==using embedding factorized: embedding size: %s==",
                      str(projection_width))

    if kargs.get("energy_pooling", "mi") == "mi":
      with tf.variable_scope("transform"):
        input_tensor = tf.layers.dense(
            input_tensor,
            units=projection_width,
            activation=bert_modules.get_activation(config.hidden_act),
            kernel_initializer=bert_modules.create_initializer(
                config.initializer_range))

        if config.get('ln_type', 'postln') == 'preln':
          input_tensor = input_tensor
        elif config.get('ln_type', 'postln') == 'postln':
          input_tensor = bert_modules.layer_norm(input_tensor)
        else:
          input_tensor = bert_modules.layer_norm(input_tensor)

      output_bias = tf.get_variable(
          "output_bias",
          shape=[config.vocab_size],
          initializer=tf.zeros_initializer())
      tf.logging.info("****** mi using mlm transform *******")
    elif kargs.get("energy_pooling", "mi") == "cls":
      with tf.variable_scope("transform_ebm"):
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token. We assume that this has been pre-trained.
        first_token_tensor = tf.squeeze(input_tensor[:, 0:1, :], axis=1)
        input_tensor = tf.layers.dense(
            first_token_tensor,
            config.hidden_size,
            activation=tf.tanh,  # bert_modules.get_activation(config.hidden_act)
            kernel_initializer=bert_modules.create_initializer(config.initializer_range))
        tf.logging.info("****** using cls pooling *******")
    else:
      with tf.variable_scope("transform_ebm"):
        input_tensor = tf.layers.dense(
            input_tensor,
            units=projection_width,
            activation=tf.tanh,  # bert_modules.get_activation(config.hidden_act)
            kernel_initializer=bert_modules.create_initializer(
                config.initializer_range))
        tf.logging.info("****** using other pooling transform *******")

  # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
  if scope:
    ebm_scope = scope + '/' + 'ebm/predictions'
  else:
    ebm_scope = 'ebm/predictions'

  tf.logging.info("**** ebm generator scope **** %s", str(ebm_scope))
  print(input_tensor.get_shape(), "==input_tensor shape==")

  with tf.variable_scope(ebm_scope, reuse=tf.AUTO_REUSE):
    # Assume the whole model is self-normalizing.
    if kargs.get("normalized_constant", "constant") == 'zero_constant':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[config.max_position_embeddings],
          initializer=tf.zeros_initializer())
      valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32)  # batch_size
      onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
      input_normalized_constant = tf.einsum("ab,b->a",
                                            tf.cast(onehot_length_ids, tf.float32),
                                            normalized_constant)
      tf.logging.info("****** zero_constant logz *******")
    elif kargs.get("normalized_constant", "constant") == 'one_constant':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[config.max_position_embeddings],
          initializer=tf.ones_initializer())
      tf.logging.info("****** one_constant logz *******")
      valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32)  # batch_size
      onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
      input_normalized_constant = tf.einsum("ab,b->a",
                                            tf.cast(onehot_length_ids, tf.float32),
                                            normalized_constant)
    elif kargs.get("normalized_constant", "constant") == 'constant_constant':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[config.max_position_embeddings],
          initializer=tf.constant_initializer(
              np.ones((config.max_position_embeddings)) * 200.0, tf.float32))
      tf.logging.info("****** constant_constant logz *******")
      valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32)  # batch_size
      onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
      input_normalized_constant = tf.einsum("ab,b->a",
                                            tf.cast(onehot_length_ids, tf.float32),
                                            normalized_constant)
    elif kargs.get("normalized_constant", "constant") == 'log9_constant':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[config.max_position_embeddings],
          initializer=tf.constant_initializer(
              np.ones((config.max_position_embeddings)) * np.log(9.0), tf.float32))
      tf.logging.info("****** log9_constant logz *******")
      valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32)  # batch_size
      onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
      input_normalized_constant = tf.einsum("ab,b->a",
                                            tf.cast(onehot_length_ids, tf.float32),
                                            normalized_constant)
    elif kargs.get("normalized_constant", "constant") == 'logv_constant':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[config.max_position_embeddings],
          initializer=tf.constant_initializer(
              np.ones((config.max_position_embeddings)) * np.log(config.vocab_size), tf.float32))
      tf.logging.info("****** logv_constant logz *******")
      valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32)  # batch_size
      onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
      input_normalized_constant = tf.einsum("ab,b->a",
                                            tf.cast(onehot_length_ids, tf.float32),
                                            normalized_constant)
    elif kargs.get("normalized_constant", "constant") == 'logv_constant_ln':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[],
          initializer=tf.constant_initializer(np.log(config.vocab_size), tf.float32))
      input_normalized_constant = normalized_constant
    elif kargs.get("normalized_constant", "length_linear") == 'length_linear':
      normalized_constant = tf.get_variable(
          "ebm_normalized_constant",
          shape=[config.max_position_embeddings],
          initializer=tf.constant_initializer(
              np.arange((config.max_position_embeddings)) + 1, tf.float32),
          trainable=False)
      scale_weights = tf.get_variable(
          "ebm_normalized_constant_scale",
          shape=[config.max_position_embeddings],
          initializer=tf.constant_initializer(
              np.log(config.vocab_size) * np.ones((config.max_position_embeddings)),
              dtype=tf.float32),
          trainable=True)
      scale_bias = tf.get_variable(
          "ebm_normalized_constant_bias",
          shape=[config.max_position_embeddings],
          initializer=tf.zeros_initializer(),
          trainable=True)
      tf.logging.info("****** length linear logz *******")
      # normalized_constant = scale_bias + scale_weights * tf.pow(normalized_constant, 2)
      valid_seq_length = tf.cast(tf.reduce_sum(input_mask, axis=-1), tf.int32)  # batch_size
      onehot_length_ids = tf.one_hot(valid_seq_length, config.max_position_embeddings)
      length_part = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)
      length_scale_part = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), scale_weights)
      length_bias_part = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), scale_bias)
      input_normalized_constant = length_part * length_scale_part + length_bias_part
      # input_normalized_constant = tf.einsum("ab,b->a", tf.cast(onehot_length_ids, tf.float32), normalized_constant)

    # f_input_mask = tf.cast(tf.expand_dims(input_mask, axis=-1), tf.float32)

    if kargs.get("energy_pooling", "mi") == "mean_pooling":
      tf.logging.info("==apply mean pooling to get hidden states projections==")
      # For an input token sequence <start> a b c we only calculate energy on
      # a, b, c: the <start> token does not contribute to the final energy function.
      # pool_features: batch x dim
      pool_features = tf.einsum("abc,ab->ac", input_tensor[:, 1:],
                                tf.cast(input_mask[:, 1:], tf.float32))
      pool_features /= (1e-10 + tf.reduce_sum(tf.cast(input_mask[:, 1:], tf.float32),
                                              axis=1, keepdims=True))
      # tf.reduce_sum(input_tensor*f_input_mask, axis=1) / (1e-10+tf.reduce_sum(f_input_mask, axis=1))
      print(pool_features.get_shape(), "===pool_features shape===")
    elif kargs.get("energy_pooling", "mi") == "mi":
      tf.logging.info("==apply mi to get hidden states projections==")
      # input_tensor_norm = tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.pow(input_tensor, 2), axis=-1))+1e-20, axis=-1)
      # input_tensor = input_tensor / tf.stop_gradient(input_tensor_norm)
      # output_weights_norm = tf.expand_dims(tf.sqrt(tf.reduce_sum(tf.pow(output_weights, 2), axis=-1))+1e-20, axis=-1)
      # output_weights = output_weights / tf.stop_gradient(output_weights_norm)
      # We calculate cosine distance to keep mi bounded by [-1, 1].
      logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)  # batch x seq x vocab
      logits = tf.nn.bias_add(logits, output_bias)
      input_id_shape = bert_utils.get_shape_list(input_ids, [2, 3])
      if len(input_id_shape) == 2:
        onehot_input_ids = tf.cast(
            tf.one_hot(tf.cast(input_ids, tf.int32), config.vocab_size), tf.float32)  # batch x seq x vocab
        input_ori_ids = tf.cast(onehot_input_ids, tf.float32)
        print("==input ori ids shape== 2-dim", input_ori_ids.get_shape())
      else:
        input_ori_ids = tf.cast(input_ids, tf.float32)
        print("==input ori ids shape== 3-dim", input_ori_ids.get_shape())
      logits = tf.einsum("abd,abd->ab", logits, input_ori_ids)
      print(logits.get_shape(), "==pooled logits shape==")
      # With l2-normalization, we can bound logits to 1.
      pool_features = tf.reduce_sum(logits[:, 1:] * tf.cast(input_mask[:, 1:], tf.float32), axis=1)
      # / (1e-10+tf.reduce_sum(tf.cast(input_mask[:, 1:], tf.float32), axis=1))
      pool_features = tf.expand_dims(pool_features, axis=-1)
      print(pool_features.get_shape(), "==pooled feature shape==")
      if kargs.get("softplus_features", False):
        # As pool_features goes to +infinity this converges to 0;
        # as it goes to -infinity it grows without bound.
        pool_features = tf.nn.softplus(-pool_features)
        tf.logging.info("****** apply softplus transformation for pooled_features *******")
    elif kargs.get("energy_pooling", "mi") == "cls":
      with tf.variable_scope("transform"):
        pool_features = tf.layers.dense(
            input_tensor,
            units=1,
            use_bias=False,
            activation=None)
        tf.logging.info("****** apply linear transformation for pooled_features *******")

    # pool_features: batch_size x hidden_dims
    if kargs.get('transform', True):
      if kargs.get("transformer_activation", "none") == 'softplus':
        with tf.variable_scope("transform"):
          ebm_scalar = tf.layers.dense(
              pool_features,
              units=1,
              use_bias=True,
              activation=tf.nn.softplus)  # maps the scalar to [0, infinity)
        tf.logging.info("****** apply softplus *******")
      elif kargs.get("transformer_activation", "none") == 'linear':
        tf.logging.info("****** apply linear projection *******")
        with tf.variable_scope("transform"):
          ebm_scalar = tf.layers.dense(
              pool_features,
              units=1,
              use_bias=True,
              activation=None)  # unconstrained scalar
      else:
        with tf.variable_scope("transform"):
          feature_shape = bert_utils.get_shape_list(pool_features, expected_rank=[1, 2])
          pool_features = tf.layers.dense(
              pool_features,
              units=feature_shape[-1],
              activation=tf.nn.relu)
          output_weights = tf.get_variable(
              "output_weights", [config.max_position_embeddings, feature_shape[-1]],
              initializer=tf.truncated_normal_initializer(stddev=0.02))
          output_bias = tf.get_variable(
              "output_bias", [config.max_position_embeddings],
              initializer=tf.constant_initializer(
                  -np.log(np.arange(config.max_position_embeddings).astype(np.float32) + 1.0),
                  dtype=tf.float32))
          # ebm_scalar_pos: batch x max_position_embeddings
          ebm_scalar_pos = tf.nn.relu(
              tf.matmul(pool_features, output_weights, transpose_b=True)) + output_bias
          pos_tensor = tf.cast(tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1), tf.int32)
          onehot_pos = tf.cast(
              tf.one_hot(tf.cast(pos_tensor, tf.int32), config.max_position_embeddings),
              tf.float32)  # batch x max_position_embeddings
          ebm_scalar = tf.einsum("ab,ab->a", ebm_scalar_pos, onehot_pos)
          ebm_scalar = tf.expand_dims(ebm_scalar, axis=-1)
          tf.logging.info("****** apply linear projection *******")
          print("===ebm_scalar====", ebm_scalar.get_shape())

      ebm_scalar = tf.squeeze(ebm_scalar, axis=-1)
      print("===ebm_scalar====", ebm_scalar.get_shape())
      # ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
      # if kargs.get("energy_pooling", "mi") == "mean_pooling":
      print("===ebm_scalar====", ebm_scalar.get_shape())
      print("===input_normalized_constant====", input_normalized_constant.get_shape())
    else:
      ebm_scalar = tf.squeeze(pool_features, axis=-1)
      # ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
      print("===ebm_scalar====", ebm_scalar.get_shape())
      print("===input_normalized_constant====", input_normalized_constant.get_shape())

    if not kargs.get("prob_ln", False):
      tf.logging.info("****** sum of plogprob as sentence probability *******")
      # ebm_scalar /= (1e-10+tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
    else:
      ebm_scalar /= (1e-10 + tf.reduce_sum(tf.cast(input_mask[:, 1:], tf.float32), axis=-1))
      tf.logging.info("****** sum of plogprob with length normalization as sentence probability *******")

    print("===ebm_scalar====", ebm_scalar.get_shape())
    print("===input_normalized_constant====", input_normalized_constant.get_shape())

    # Original EBM log-likelihood:
    #   log(exp(-E(x)) / Z) = -E(x) - log(Z)
    # Here we use the pooled hidden states of the BERT encoder as the energy
    # function, which has to be negated when used as the actual energy.
    if not kargs.get("use_tpu", False):
      tf.summary.scalar('ebm_scalar', tf.reduce_mean(ebm_scalar))

    if kargs.get("logz_mode", "default") == 'default':
      tf.logging.info("****** default logz *******")
      logits = -ebm_scalar - input_normalized_constant - tf.log(
          1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
    elif kargs.get("logz_mode", "default") == 'standard':
      logits = ebm_scalar - input_normalized_constant
      tf.logging.info("****** standard logz *******")
    elif kargs.get("logz_mode", "default") == 'standard_minus':
      tf.logging.info("****** minus standard logz *******")
      logits = -ebm_scalar - input_normalized_constant
    elif kargs.get("logz_mode", "default") == 'constant':
      logits = -ebm_scalar - tf.log(
          1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32), axis=-1))
      tf.logging.info("****** constant logz *******")
    elif kargs.get("logz_mode", "self_normalizing") == 'self_normalizing':
      logits = -ebm_scalar
      tf.logging.info("****** self_normalizing *******")
    elif kargs.get("logz_mode", "none") == 'none':
      logits = ebm_scalar
      tf.logging.info("****** none logz *******")
    else:
      tf.logging.info("****** linear logz *******")
      logits = ebm_scalar - input_normalized_constant * tf.reduce_sum(
          tf.cast(input_mask, tf.float32), axis=-1)

    print("=ebm logits shape==", logits.get_shape())

  return logits
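# Sketch (NumPy, hypothetical numbers) of the 'standard_minus' branch above:
# an energy model scores a sentence as log p(x) ~= -E(x) - log Z, where E(x)
# is the pooled scalar produced by the encoder head and log Z is looked up
# from a per-length normalization constant via a one-hot over valid lengths.
import numpy as np

max_position_embeddings = 6
normalized_constant = np.log(30000.0) * np.ones(max_position_embeddings)  # e.g. the logv_constant case

input_mask = np.array([[1, 1, 1, 1, 0, 0]])             # one sentence, 4 real tokens
valid_seq_length = input_mask.sum(axis=-1)              # [batch]
onehot_length = np.eye(max_position_embeddings)[valid_seq_length]
input_normalized_constant = onehot_length @ normalized_constant  # per-example log Z

ebm_scalar = np.array([2.5])                            # stand-in pooled energy
logits = -ebm_scalar - input_normalized_constant        # 'standard_minus' logz_mode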