def global_discriminator_logits(config, input_tensor, reuse=None, **kargs):
    """Return the raw 2-way logits of the global (sequence-level) head.

    No loss or probabilities are computed here; callers receive only the
    logits of a linear layer created under "<scope>/cls/seq_global".

    Args:
      config: model config; `hidden_size` and `initializer_range` are read.
      input_tensor: features multiplied against the [2, hidden_size] weight
        matrix; presumably [batch, hidden_size] -- TODO confirm with callers.
      reuse: forwarded to `tf.variable_scope`.
      **kargs: optional `scope` string prepended to "cls/seq_global".

    Returns:
      Logits tensor whose last dimension is 2.
    """
    prefix = kargs.get('scope', None)
    scope = prefix + '/' + 'cls/seq_global' if prefix else 'cls/seq_global'
    tf.logging.info("**** nsp scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=reuse):
        weight_init = albert_modules.create_initializer(config.initializer_range)
        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, config.hidden_size],
            initializer=weight_init)
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())
        # x @ W^T + b  ->  one score per class.
        projected = tf.matmul(input_tensor, output_weights, transpose_b=True)
        return tf.nn.bias_add(projected, output_bias)
def global_feature_discriminator(config, input_tensor, labels, reuse=None, **kargs):
    """Binary sequence-level discriminator head with cross-entropy loss.

    Builds (or reuses) a 2-way linear classifier under
    "<scope>/cls/seq_global" and scores `input_tensor` against `labels`.

    Args:
      config: model config; `hidden_size` and `initializer_range` are read.
      input_tensor: features fed to `tf.matmul` against a [2, hidden_size]
        weight matrix; presumably [batch, hidden_size] -- TODO confirm.
      labels: integer class ids; flattened to 1-D and one-hot encoded with
        depth 2.
      reuse: forwarded to `tf.variable_scope`.
      **kargs: optional `scope` string prepended to "cls/seq_global".

    Returns:
      Tuple (loss, per_example_loss, log_probs): the mean loss, the per-example
      negative log-likelihood, and the log-softmax over the 2 classes.
    """
    scope_prefix = kargs.get('scope', None)
    if not scope_prefix:
        scope = 'cls/seq_global'
    else:
        scope = scope_prefix + '/' + 'cls/seq_global'
    tf.logging.info("**** nsp scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=reuse):
        weights = tf.get_variable(
            "output_weights",
            shape=[2, config.hidden_size],
            initializer=albert_modules.create_initializer(config.initializer_range))
        bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        scores = tf.nn.bias_add(
            tf.matmul(input_tensor, weights, transpose_b=True), bias)
        log_probs = tf.nn.log_softmax(scores, axis=-1)

        # Cross-entropy against one-hot targets.
        target = tf.one_hot(tf.reshape(labels, [-1]), depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(target * log_probs, axis=-1)
        return tf.reduce_mean(per_example_loss), per_example_loss, log_probs
def build_pooler(self, *args, **kargs):
    """Set `self.sequence_output` and `self.pooled_output`.

    `self.sequence_output` is taken from `self.get_encoder_layers(layer_num)`.
    `self.pooled_output` is a tanh dense projection (width
    `self.config.hidden_size`) of the features at sequence position 0, created
    under "<config scope or 'bert'>/pooler".

    Keyword Args:
      reuse: required; forwarded to `tf.variable_scope`.
      layer_num: encoder layer index passed to `get_encoder_layers`
        (default -1).
    """
    reuse = kargs["reuse"]
    layer_num = kargs.get("layer_num", -1)
    outer_scope = self.config.get("scope", "bert")

    with tf.variable_scope(outer_scope, reuse=reuse):
        self.sequence_output = self.get_encoder_layers(layer_num)
        encoded = self.sequence_output

        # Pool by selecting the hidden state of the first position:
        # slice [:, 0:1, :] then drop the singleton axis.
        with tf.variable_scope("pooler"):
            first_position = tf.squeeze(encoded[:, 0:1, :], axis=1)
            kernel_init = albert_modules.create_initializer(
                self.config.initializer_range)
            self.pooled_output = tf.layers.dense(
                first_position,
                self.config.hidden_size,
                activation=tf.tanh,
                kernel_initializer=kernel_init)
def token_generator_gumbel_normal(config, input_tensor, output_weights,
                                  input_ids, input_ori_ids, input_mask,
                                  **kargs):
    """Sample replacement tokens from the MLM head with Gumbel-softmax.

    The features go through the "cls/predictions" head (optional pre/post
    layer norm, a dense "transform" layer, an optional embedding projection,
    then the output embedding + bias). The head's log-softmax is sampled with
    `gumbel_softmax`, and a one-hot (optionally straight-through) encoding of
    the sampled ids is returned.

    Args:
      config: model config. Reads `ln_type`, `embedding`, `embedding_size`,
        `hidden_size`, `hidden_act`, `initializer_range`, `vocab_size`,
        `gumbel_temperature` and `gen_sample` (samples per position).
      input_tensor: rank-3 features [batch_size, seq_length, hidden].
      output_weights: output embedding matrix contracted with the transformed
        features via einsum "abc,dc->abd" (so [vocab_size, width]).
      input_ids: (possibly corrupted) token ids compared against
        `input_ori_ids` to find replaced positions.
      input_ori_ids: original token ids, one-hot encoded to vocab_size.
      input_mask: padding mask; only used for the non-TPU accuracy summary.
      **kargs: optional `embedding_projection`, `scope`, `num_train_steps`,
        `gumbel_anneal` ('anneal' needs `num_train_steps`; 'softplus' learns a
        per-position temperature; anything else uses a fixed temperature of
        1.0), `sampled_prob_id`, `straight_through`, `sampled_binary_mask`,
        `mask_method`, `use_tpu`.

    Returns:
      Float tensor [batch_size, seq_length, vocab_size] when `gen_sample` == 1,
      otherwise [batch_size, seq_length, vocab_size, gen_sample].
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    gen_sample = config.get('gen_sample', 1)
    use_tpu = kargs.get('use_tpu', True)
    mask_method = kargs.get('mask_method', 'only_mask')

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        ln_type = config.get('ln_type', 'postln')
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            # Pre-LN models normalised before the transform; every other
            # ln_type normalises after it.
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]
        logits_tempered = tf.nn.log_softmax(logits, axis=-1)
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        # ---- Gumbel-softmax temperature ----
        num_train_steps = kargs.get('num_train_steps', None)
        gumbel_anneal = kargs.get('gumbel_anneal', "anneal")
        if num_train_steps and gumbel_anneal == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                num_train_steps,
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif gumbel_anneal == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # Learned per-position temperature 1 / (1 + softplus(dense(x))).
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
            annealed_temp = 1. / annealed_temp
            annealed_temp = tf.reshape(annealed_temp,
                                       [batch_size * seq_length, 1])
            if gen_sample > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 1.0
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))

        # [batch x seq] x vocab_size x gen_sample
        sampled_logprob_temp, sampled_logprob = gumbel_softmax(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=gen_sample)

        # argmax over the vocabulary (always axis=1) picks a categorical sample.
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size, axis=1)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(
                tf.argmax(sampled_logprob_temp, axis=1),
                config.vocab_size, axis=1)

        if kargs.get("straight_through", True):
            tf.logging.info(
                "****** apply straight_through_estimator without grl *******")
            # Forward value is the hard one-hot; the gradient is that of the
            # soft sample (stop_gradient hides the difference).
            sampled_id = tf.stop_gradient(
                sampled_hard_id - sampled_logprob_temp) + sampled_logprob_temp
        else:
            tf.logging.info(
                "****** apply gumbel-softmax probs without grl *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        # 1.0 marks positions to replace, 0.0 keeps the original token.
        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(sampled_binary_mask)
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)  # batch x seq x 1

        input_ori_ids = tf.one_hot(input_ori_ids, config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if gen_sample == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
        else:
            # Fixed: this previously reshaped the undefined name `samples`.
            sampled_input_id = tf.reshape(
                sampled_id,
                [batch_size, seq_length, config.vocab_size, gen_sample])
            label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)  # batch x seq x 1 x 1
            input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)  # batch x seq x vocab x 1

        if mask_method == 'only_mask':
            tf.logging.info("****** only mask sample *******")
            # Sampled tokens at replaced positions, original tokens elsewhere.
            # (The multi-sample branch previously had these operands swapped.)
            sampled_input_id = (
                label_diff_ids * tf.cast(sampled_input_id, tf.float32)
                + (1 - label_diff_ids) * input_ori_ids)

        tf.logging.info("====generator use_tpu %s ====", str(use_tpu))
        if not use_tpu:
            tf.logging.info("====logging generator loss ====")
            # argmax over the vocab axis: [B, S] (or [B, S, gen_sample]).
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            token_mask = tf.cast(input_mask, tf.float32)
            if gen_sample > 1:
                # Align the [B, S] mask with the trailing gen_sample axis.
                token_mask = tf.expand_dims(token_mask, axis=-1)
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * token_mask
            # Mismatches are counted once per sample, so normalise by the
            # number of replaced positions times gen_sample.
            sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                1e-10 + float(gen_sample) *
                tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
def token_generator_gumbel(config, input_tensor, output_weights, input_ids,
                           input_ori_ids, input_mask, **kargs):
    """Sample replacement tokens from the MLM head with Gumbel-softmax.

    The features go through the "cls/predictions" head (optional pre/post
    layer norm, a dense "transform" layer, an optional embedding projection,
    then the output embedding + bias). The head's unnormalised logits are then
    sampled with a scheduled temperature. The result is a one-hot
    (straight-through, optionally gradient-flipped) encoding of the sampled
    ids, merged with the original tokens according to `mask_method`.

    Args:
      config: model config. Reads `ln_type`, `embedding`, `embedding_size`,
        `hidden_size`, `hidden_act`, `initializer_range`, `vocab_size`,
        `gumbel_temperature` and `gen_sample` (samples per position).
      input_tensor: rank-3 features [batch_size, seq_length, hidden].
      output_weights: output embedding matrix contracted with the transformed
        features via einsum "abc,dc->abd" (so [vocab_size, width]).
      input_ids: (possibly corrupted) token ids compared against
        `input_ori_ids` to find replaced positions.
      input_ori_ids: original token ids.
      input_mask: padding mask (used by 'all_mask' and by the summaries).
      **kargs: optional `embedding_projection`, `scope`, `num_train_steps`,
        `gumbel_anneal` ('anneal', 'softplus', 'vqvae', 'vqvae_v1',
        'vqvae_v2'; any other value uses a fixed temperature of 0.01),
        `max_temp`, `stable_gradient`, `greedy`, `sampled_prob_id`,
        `if_flip_grad`, `straight_through`, `sampled_binary_mask`,
        `mask_method` ('only_mask' or 'all_mask'), `use_tpu`.

    Returns:
      Float tensor [batch_size, seq_length, vocab_size] when `gen_sample` == 1,
      otherwise [batch_size, seq_length, vocab_size, gen_sample].
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    gen_sample = config.get('gen_sample', 1)
    use_tpu = kargs.get('use_tpu', True)
    mask_method = kargs.get('mask_method', 'only_mask')

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        ln_type = config.get('ln_type', 'postln')
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            # Pre-LN models normalised before the transform; every other
            # ln_type normalises after it.
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]
        # The logits are sampled without log-softmax normalisation.
        flat_logits_tempered = tf.reshape(logits,
                                          [batch_size * seq_length, width])

        # ---- Gumbel-softmax temperature schedule ----
        num_train_steps = kargs.get('num_train_steps', None)
        # Steps used by the vqvae* schedules; an explicit None falls back to
        # the same default as a missing key instead of crashing on None * 0.1.
        total_steps = kargs.get("num_train_steps", 10000)
        if total_steps is None:
            total_steps = 10000
        # NOTE: when `gumbel_anneal` is absent, the schedule is 'anneal' if
        # `num_train_steps` is truthy and otherwise 'vqvae' (legacy behaviour
        # kept for backward compatibility).
        anneal_given = 'gumbel_anneal' in kargs
        gumbel_anneal = kargs.get('gumbel_anneal')

        if num_train_steps and (not anneal_given or gumbel_anneal == 'anneal'):
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            temperature_warmup_steps = int(num_train_steps) * 0.1
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                temperature_warmup_steps,
                end_learning_rate=0.01,
                power=1.0,
                cycle=False)
            gumbel_samples = None
            if not use_tpu:
                tf.summary.scalar('annealed_temp', annealed_temp)
        elif gumbel_anneal == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # Learned per-position temperature 1 / (1 + softplus(dense(x))).
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
            annealed_temp = 1. / annealed_temp
            annealed_temp = tf.reshape(annealed_temp,
                                       [batch_size * seq_length, 1])
            if not use_tpu:
                tf.summary.scalar('softplus temperature',
                                  tf.reduce_mean(annealed_temp))
            if gen_sample > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
            gumbel_samples = None
        elif not anneal_given or gumbel_anneal == 'vqvae':
            steps = total_steps * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(steps))
            gumbel_samples = sample_gumbel(
                bert_utils.get_shape_list(flat_logits_tempered, expected_rank=2),
                samples=gen_sample)
            gumbel_samples *= inverse_exp_decay(steps // 5) * 0.5
            # The lower bound depends on inverse_lin_decay (defined elsewhere);
            # the original note claimed a minimum of 0.2 -- unverified.
            annealed_temp_decay = 1.01 - inverse_lin_decay(steps)
            # 90% of steps use the decayed temperature, 10% a random one.
            annealed_temp = tf.cond(
                tf.less(tf.random_uniform([]), 0.9),
                lambda: annealed_temp_decay,
                lambda: tf.random_uniform([], minval=0.5, maxval=1.0))
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method ******* "
            )
            if not use_tpu:
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif gumbel_anneal == 'vqvae_v1':
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(total_steps * 0.1))
            gumbel_samples = sample_gumbel(
                bert_utils.get_shape_list(flat_logits_tempered, expected_rank=2),
                samples=gen_sample)
            annealed_temp_decay = 1.01 - inverse_exp_decay(total_steps)
            annealed_temp = annealed_temp_decay
            tf.logging.info(
                "****** apply sel-gan gumbel-softmax temperature annealing method ******* "
            )
            if not use_tpu:
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif gumbel_anneal == 'vqvae_v2':
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(total_steps * 0.1))
            gumbel_samples = sample_gumbel(
                bert_utils.get_shape_list(flat_logits_tempered, expected_rank=2),
                samples=gen_sample)
            annealed_temp_decay = inverse_temp_exp_decay(
                total_steps, kargs.get("max_temp", 100))
            annealed_temp = annealed_temp_decay
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax temperature annealing method ******* "
            )
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax num_train_steps:%s annealing method, temp:%s ******* ",
                str(total_steps), str(kargs.get("max_temp", 100)))
            if not use_tpu:
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        else:
            annealed_temp = 0.01
            gumbel_samples = None
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))
            if not use_tpu:
                tf.summary.scalar('gumbel_temperature', annealed_temp)

        # [batch x seq] x vocab_size x gen_sample
        if kargs.get('stable_gradient', True):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=gen_sample,
                greedy=kargs.get("greedy", False))
            tf.logging.info(
                "****** apply normal derivate for gradient calculation *******"
            )
        else:
            sampled_logprob_temp, sampled_logprob = gumbel_softmax_custom_grad(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=gen_sample)
            tf.logging.info(
                "****** apply log deriviate for stable gradient calculation *******"
            )

        # argmax over the vocabulary (always axis=1) picks a categorical sample.
        if kargs.get('sampled_prob_id', False):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size, axis=1)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(
                tf.argmax(sampled_logprob_temp, axis=1),
                config.vocab_size, axis=1)

        if kargs.get('if_flip_grad', True):
            tf.logging.info("****** apply gradient flipping *******")
            soft_for_grad = flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** not apply gradient flipping *******")
            soft_for_grad = sampled_logprob_temp

        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator *******")
            # Forward value is the hard one-hot; gradients go through
            # `soft_for_grad` (stop_gradient hides the difference).
            sampled_id = tf.stop_gradient(
                sampled_hard_id - sampled_logprob_temp) + soft_for_grad
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = soft_for_grad

        # 1.0 marks positions to replace, 0.0 keeps the original token.
        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(sampled_binary_mask)
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)  # batch x seq x 1

        input_ori_token_ids = input_ori_ids
        input_ori_ids = tf.one_hot(input_ori_ids, config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if gen_sample == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (
                    label_diff_ids * tf.cast(sampled_input_id, tf.float32)
                    + (1 - label_diff_ids) * input_ori_ids)
            elif mask_method == 'all_mask':
                # Ids 100/101/102 are never replaced (unk/cls/sep per the
                # original comments; assumes a BERT-style vocab), nor is padding.
                unk_mask = tf.cast(tf.math.equal(input_ori_token_ids, 100), tf.float32)
                cls_mask = tf.cast(tf.math.equal(input_ori_token_ids, 101), tf.float32)
                sep_mask = tf.cast(tf.math.equal(input_ori_token_ids, 102), tf.float32)
                unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast(
                    input_mask, tf.float32)
                unsampled_mask = tf.expand_dims(unsampled_mask, axis=-1)  # batch x seq x 1
                ori_input_mask = tf.expand_dims(input_mask, axis=-1)  # batch x seq x 1
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = (
                    unsampled_mask * tf.cast(sampled_input_id, tf.float32)
                    + (1 - unsampled_mask) * tf.cast(ori_input_mask, tf.float32)
                    * input_ori_ids)
        else:
            # Fixed: this previously reshaped the undefined name `samples`.
            sampled_input_id = tf.reshape(
                sampled_id,
                [batch_size, seq_length, config.vocab_size, gen_sample])
            label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)  # batch x seq x 1 x 1
            input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)  # batch x seq x vocab x 1
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # Previously had these operands swapped; now mirrors the
                # single-sample formula.
                sampled_input_id = (
                    label_diff_ids * tf.cast(sampled_input_id, tf.float32)
                    + (1 - label_diff_ids) * input_ori_ids)

        tf.logging.info("====generator use_tpu %s ====", str(use_tpu))
        if not use_tpu:
            tf.logging.info("====logging generator loss ====")
            sampled_argmax = tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32)
            ori_argmax = tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32)

            if gen_sample > 1:
                # Per-token diagnostics below assume one sample per position;
                # only the sample accuracy is reported, averaged over samples.
                token_mask = tf.expand_dims(tf.cast(input_mask, tf.float32), axis=-1)
                sampled_not_equal = tf.cast(
                    tf.not_equal(sampled_argmax, ori_argmax), tf.float32) * token_mask
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + float(gen_sample) * tf.reduce_sum(label_diff_ids))
                tf.summary.scalar('generator_sample_acc', sampled_not_equal)
                tf.summary.scalar("generator_valid_token", tf.reduce_sum(input_mask))
            else:
                sampled_not_equal = tf.cast(
                    tf.not_equal(sampled_argmax, ori_argmax),
                    tf.float32) * tf.cast(input_mask, tf.float32)
                if mask_method == 'all_mask':
                    sampled_equal = tf.cast(
                        tf.equal(sampled_argmax, ori_argmax), tf.float32) * tf.cast(
                            tf.squeeze(unsampled_mask, axis=-1), tf.float32)
                    tf.summary.scalar('generator_equal_sample_acc',
                                      tf.reduce_sum(sampled_equal))
                    label_diff_ids_my = tf.cast(unsampled_mask, tf.float32)
                    sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                        1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                else:
                    # 'only_mask' (and any other method) is normalised by the
                    # number of replaced positions.
                    sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                        1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
                    label_diff_ids_my = tf.cast(label_diff_ids, tf.float32)
                tf.summary.scalar('generator_sample_acc', sampled_not_equal)
                tf.summary.scalar("generator_valid_token", tf.reduce_sum(input_mask))

                # Mean L2 distance between the hard one-hot and the
                # straight-through sample over the counted positions.
                diag_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                          config.vocab_size, axis=1)
                diag_hard_id = tf.reshape(
                    tf.cast(diag_hard_id, tf.float32),
                    [batch_size, seq_length, config.vocab_size])
                diag_soft_id = tf.reshape(
                    sampled_id, [batch_size, seq_length, config.vocab_size])
                diag_hard_id *= label_diff_ids_my
                diag_soft_id *= label_diff_ids_my
                hard_soft_bias = tf.reduce_sum(
                    tf.sqrt(tf.reduce_sum(tf.pow(diag_hard_id - diag_soft_id, 2),
                                          axis=-1))) / (
                        1e-10 + tf.reduce_sum(tf.cast(label_diff_ids_my, tf.float32)))
                tf.summary.scalar('soft_hard_bias', hard_soft_bias)

        return sampled_input_id
def efficient_attention_layer(from_tensor,
                              to_tensor,
                              attention_mask=None,
                              num_attention_heads=1,
                              size_per_head=512,
                              query_act=None,
                              key_act=None,
                              value_act=None,
                              attention_probs_dropout_prob=0.0,
                              initializer_range=0.02,
                              do_return_2d_tensor=False,
                              batch_size=None,
                              from_seq_length=None,
                              to_seq_length=None,
                              attention_fixed_size=None):
    """Multi-headed attention with separately normalised queries and keys.

    `from_tensor` is projected to queries and `to_tensor` to keys and values,
    split into `num_attention_heads` heads. Instead of softmax(Q K^T) V, this
    computes

        softmax_H(Q) . (softmax_T(K)^T . V)

    The key softmax runs over the `to` sequence axis and the query softmax
    over the head-size axis. The [F, T] attention matrix is never formed;
    only an [H, H] matrix per head is.

    Args:
      from_tensor: float Tensor [batch_size, from_seq_length, from_width]
        (or rank 2, see `batch_size`).
      to_tensor: float Tensor [batch_size, to_seq_length, to_width]
        (or rank 2).
      attention_mask: (optional) Tensor [batch_size, from_seq_length,
        to_seq_length] of 1/0. Only its first `from` row is used, as a
        per-key mask; masked keys get -10000 added before the key softmax.
        When None, no masking is applied.
      num_attention_heads: int. Number of attention heads.
      size_per_head: int. Size of each head (unless `attention_fixed_size`).
      query_act: (optional) activation for the query transform.
      key_act: (optional) activation for the key transform.
      value_act: (optional) activation for the value transform.
      attention_probs_dropout_prob: float. Dropout rate applied to the
        normalised key weights.
      initializer_range: float. Range of the weight initializer.
      do_return_2d_tensor: bool. If True, the context is returned as
        [batch_size * from_seq_length, num_attention_heads * head_size].
      batch_size: (optional) int; required with rank-2 inputs.
      from_seq_length: (optional) int; required with rank-2 inputs.
      to_seq_length: (optional) int; required with rank-2 inputs.
      attention_fixed_size: (optional) int. If set, overrides `size_per_head`.

    Returns:
      Tuple (context_layer, attention_scores, value_layer):
        context_layer: [batch_size, from_seq_length, N * head_size], or
          [batch_size * from_seq_length, N * head_size] when
          `do_return_2d_tensor`.
        attention_scores: key log-probabilities [B, N, T, head_size]
          (log_softmax over T).
        value_layer: [B, N, T, head_size].

    Raises:
      ValueError: if input ranks differ, or rank-2 inputs are given without
        `batch_size`, `from_seq_length` and `to_seq_length`.
    """

    def transpose_for_scores(input_tensor, batch_size, num_attention_heads,
                             seq_length, width):
        output_tensor = tf.reshape(
            input_tensor, [batch_size, seq_length, num_attention_heads, width])
        return tf.transpose(output_tensor, [0, 2, 1, 3])

    from_shape = bert_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = bert_utils.get_shape_list(to_tensor, expected_rank=[2, 3])

    if len(from_shape) != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if len(from_shape) == 3:
        batch_size = from_shape[0]
        from_seq_length = from_shape[1]
        to_seq_length = to_shape[1]
    elif len(from_shape) == 2:
        if (batch_size is None or from_seq_length is None or
                to_seq_length is None):
            raise ValueError(
                "When passing in rank 2 tensors to attention_layer, the values "
                "for `batch_size`, `from_seq_length`, and `to_seq_length` "
                "must all be specified.")

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = per-head size (`attention_head_size`)

    # Fixed: the format strings lacked a placeholder for their argument,
    # which makes the logging module report a formatting error.
    if attention_fixed_size:
        attention_head_size = attention_fixed_size
        tf.logging.info("==apply attention_fixed_size== %s",
                        str(attention_head_size))
    else:
        attention_head_size = size_per_head
        tf.logging.info("==apply attention_original_size== %s",
                        str(attention_head_size))

    from_tensor_2d = bert_utils.reshape_to_matrix(from_tensor)
    to_tensor_2d = bert_utils.reshape_to_matrix(to_tensor)

    # `query_layer` = [B*F, N*H]
    query_layer = tf.layers.dense(
        from_tensor_2d,
        num_attention_heads * attention_head_size,
        activation=query_act,
        name="query",
        kernel_initializer=albert_modules.create_initializer(initializer_range))

    # `key_layer` = [B*T, N*H]
    key_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * attention_head_size,
        activation=key_act,
        name="key",
        kernel_initializer=albert_modules.create_initializer(initializer_range))

    # `value_layer` = [B*T, N*H]
    value_layer = tf.layers.dense(
        to_tensor_2d,
        num_attention_heads * attention_head_size,
        activation=value_act,
        name="value",
        kernel_initializer=albert_modules.create_initializer(initializer_range))

    # `query_layer` = [B, N, F, H]
    query_layer = transpose_for_scores(query_layer, batch_size,
                                       num_attention_heads, from_seq_length,
                                       attention_head_size)
    # `key_layer` = [B, N, T, H]
    key_layer = transpose_for_scores(key_layer, batch_size,
                                     num_attention_heads, to_seq_length,
                                     attention_head_size)
    # `value_layer` = [B, N, T, H]
    value_layer = transpose_for_scores(value_layer, batch_size,
                                       num_attention_heads, to_seq_length,
                                       attention_head_size)

    masked_keys = key_layer
    if attention_mask is not None:
        # First `from` row of the [B, F, T] mask -> per-key mask [B, 1, T, 1],
        # broadcast over heads and the head-size axis.
        key_mask = tf.reshape(
            tf.cast(attention_mask[:, 0:1, :], tf.float32),
            [batch_size, 1, to_seq_length, 1])
        masked_keys = key_layer + (1.0 - key_mask) * -10000.0

    # Softmax of the keys over the `to` sequence axis: [B, N, T, H].
    attention_scores = tf.nn.log_softmax(masked_keys, axis=2)
    attention_probs = tf.exp(attention_scores)
    attention_probs = albert_modules.dropout(attention_probs,
                                             attention_probs_dropout_prob)

    # [B, N, H, H] <- [B, N, T, H]^T x [B, N, T, H]
    key_value_scores = tf.matmul(attention_probs, value_layer, transpose_a=True)

    # [B, N, F, H] x [B, N, H, H] -> [B, N, F, H]
    context_layer = tf.matmul(tf.exp(tf.nn.log_softmax(query_layer, axis=-1)),
                              key_value_scores)

    # `context_layer` = [B, F, N, H]
    context_layer = tf.transpose(context_layer, [0, 2, 1, 3])

    if do_return_2d_tensor:
        # `context_layer` = [B*F, N*H]
        context_layer = tf.reshape(context_layer, [
            batch_size * from_seq_length,
            num_attention_heads * attention_head_size
        ])
    else:
        # `context_layer` = [B, F, N*H]
        context_layer = tf.reshape(context_layer, [
            batch_size, from_seq_length,
            num_attention_heads * attention_head_size
        ])

    return context_layer, attention_scores, value_layer
def get_masked_lm_output(config, input_tensor, output_weights, positions,
                         label_ids, label_weights, **kargs):
    """Compute the masked-LM loss over the gathered prediction positions.

    Args:
      config: model config. `.get(...)` supplies `ln_type` (default 'postln')
        and `embedding` (default 'factorized'). Attribute access supplies
        `hidden_size`, `embedding_size` (only when not factorized),
        `hidden_act`, `initializer_range` and `vocab_size`.
      input_tensor: encoder output. Rows at `positions` are selected with
        `bert_utils.gather_indexes`.
      output_weights: output embedding matrix, multiplied with
        `transpose_b=True` to obtain vocabulary logits.
      positions: indices of the predicted tokens (cast to int32).
      label_ids: target token ids (cast to int32, flattened to 1-D).
      label_weights: per-prediction weights (cast to float32, flattened to
        1-D). Padding predictions are expected to carry weight 0.0.
      **kargs: optional `embedding_projection` (matmul'd with
        `transpose_b=True` when given) and `scope` (prefix for
        'cls/predictions'). A `reuse` key is accepted but ignored; variables
        are always created with `tf.AUTO_REUSE`.

    Returns:
      Tuple `(loss, per_example_loss, log_probs, label_weights)`. `loss` is
      the label-weighted mean of `per_example_loss`, with 1e-5 added to the
      weight sum. `label_weights` is the flattened float32 copy.
    """
    embedding_projection = kargs.get('embedding_projection', None)

    input_tensor = tf.cast(input_tensor, tf.float32)
    positions = tf.cast(positions, tf.int32)
    label_ids = tf.cast(label_ids, tf.int32)
    label_weights = tf.cast(label_weights, tf.float32)

    # Keep only the hidden states at the masked positions.
    input_tensor = bert_utils.gather_indexes(input_tensor, positions)

    prefix = kargs.get('scope', None)
    scope = prefix + '/' + 'cls/predictions' if prefix else 'cls/predictions'
    tf.logging.info("**** mlm scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        is_preln = config.get('ln_type', 'postln') == 'preln'

        # Pre-LN: normalise before the extra non-linear transform.
        if is_preln:
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            # Post-LN (and any unrecognised ln_type): normalise afterwards.
            if not is_preln:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is None:
            print("==no need for embedding projection==")
        else:
            input_tensor = tf.matmul(
                input_tensor, embedding_projection, transpose_b=True)

        # The output weights are shared with the input embeddings; this head
        # owns only a per-token output bias.
        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(
            tf.matmul(input_tensor, output_weights, transpose_b=True),
            output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        flat_label_ids = tf.reshape(label_ids, [-1])
        flat_label_weights = tf.reshape(label_weights, [-1])

        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(flat_label_ids), logits=logits)

        # Zero-weight (padding) predictions drop out of both sums.
        weighted_loss_sum = tf.reduce_sum(flat_label_weights * per_example_loss)
        weight_total = tf.reduce_sum(flat_label_weights) + 1e-5
        loss = weighted_loss_sum / weight_total

    return (loss, per_example_loss, log_probs, flat_label_weights)
def transformer_cell(input_tensor,
                     attention_mask=None,
                     hidden_size=768,
                     num_hidden_layers=12,
                     num_attention_heads=12,
                     intermediate_size=3072,
                     intermediate_act_fn=gelu,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     initializer_range=0.02,
                     do_return_all_layers=False,
                     shared_type=None,
                     adapter_fn=None,
                     layer_idx=0):
    """Apply one shared-weight Transformer layer (self-attention + FFN).

    All variables are created under `layer_shared` with `tf.AUTO_REUSE`, so
    repeated calls reuse a single set of weights.

    Args:
      input_tensor: rank-3 tensor whose last dimension must equal
        `hidden_size`, because it is added residually to `hidden_size`-wide
        projections. The rank is checked by `bert_utils.get_shape_list`.
      attention_mask: forwarded unchanged to `albert_modules.attention_layer`.
      hidden_size: width of the attention output and FFN output projections.
      num_attention_heads: number of heads. Must divide `hidden_size`.
      intermediate_size: width of the FFN hidden layer.
      intermediate_act_fn: activation of the FFN hidden layer.
      hidden_dropout_prob: dropout applied to both sub-layer outputs.
      attention_probs_dropout_prob: forwarded to the attention layer.
      initializer_range: stddev for `albert_modules.create_initializer`.
      num_hidden_layers, do_return_all_layers, shared_type: accepted for
        signature compatibility; not used by this single-layer cell.
      adapter_fn: optional callable `adapter_fn(x, layer_idx=...)` applied
        to each sub-layer output before its residual add and LayerNorm.
      layer_idx: index forwarded to `adapter_fn` (new, defaults to 0).

    Returns:
      Tuple `(layer_output, attention_scores)`. `layer_output` is reshaped
      back to the shape of `input_tensor`, so the cell can be called
      repeatedly. `attention_scores` is whatever `attention_layer` returned.

    Raises:
      ValueError: if `hidden_size` is not a multiple of `num_attention_heads`.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))
    attention_head_size = hidden_size // num_attention_heads

    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    # Work on a [batch * seq, hidden] matrix, as attention_layer expects with
    # do_return_2d_tensor=True.
    layer_input = bert_utils.reshape_to_matrix(input_tensor)

    with tf.variable_scope("layer_shared", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
            with tf.variable_scope("self"):
                [attention_output, attention_scores] = albert_modules.attention_layer(
                    from_tensor=layer_input,
                    to_tensor=layer_input,
                    attention_mask=attention_mask,
                    num_attention_heads=num_attention_heads,
                    size_per_head=attention_head_size,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    initializer_range=initializer_range,
                    do_return_2d_tensor=True,
                    batch_size=batch_size,
                    from_seq_length=seq_length,
                    to_seq_length=seq_length)

            # Linear projection back to `hidden_size`, then a residual with
            # `layer_input`.
            with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
                attention_output = tf.layers.dense(
                    attention_output,
                    hidden_size,
                    kernel_initializer=albert_modules.create_initializer(
                        initializer_range))
                attention_output = albert_modules.dropout(
                    attention_output, hidden_dropout_prob)
                if adapter_fn:
                    attention_output = adapter_fn(
                        attention_output, layer_idx=layer_idx)
                attention_output = albert_modules.layer_norm(
                    attention_output + layer_input)

        # The activation is only applied to the "intermediate" hidden layer.
        with tf.variable_scope('intermediate', reuse=tf.AUTO_REUSE):
            intermediate_output = tf.layers.dense(
                attention_output,
                intermediate_size,
                activation=intermediate_act_fn,
                kernel_initializer=albert_modules.create_initializer(
                    initializer_range))

        # Down-project back to `hidden_size`, then add the residual.
        with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
            layer_output = tf.layers.dense(
                intermediate_output,
                hidden_size,
                kernel_initializer=albert_modules.create_initializer(
                    initializer_range))
            layer_output = albert_modules.dropout(
                layer_output, hidden_dropout_prob)
            if adapter_fn:
                # Adapt the FFN output itself (previously the attention output
                # was passed here, discarding the FFN result).
                layer_output = adapter_fn(layer_output, layer_idx=layer_idx)
            layer_output = albert_modules.layer_norm(
                layer_output + attention_output)

    return tf.reshape(layer_output, input_shape), attention_scores
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    """Sample replacement token ids from the MLM generator head.

    NOTE(review): this module contains a second `def token_generator` further
    down. Python rebinds the name at import time, so the later definition is
    the one callers get through the module attribute. Confirm which one is
    intended.

    Args:
      config: model config read via `.get(...)` (`ln_type`, `embedding`,
        `embedding_size`, `temperature`, `gen_sample`) and attributes
        (`hidden_size`, `hidden_act`, `initializer_range`, `vocab_size`).
      input_tensor: rank-3 encoder output `[batch, seq, hidden]` (rank is
        checked by `bert_utils.get_shape_list`).
      output_weights: `[vocab, width]` matrix contracted with the last axis.
      input_ids: (possibly masked) input ids. Used only when
        `sampled_binary_mask` is not supplied.
      input_ori_ids: original ids `[batch, seq]`.
      input_mask: padding mask `[batch, seq]` (1 for real tokens).
      **kargs:
        - `scope`, `embedding_projection`: as in the other MLM heads.
        - `apply_valid_vocab`: falsy, 'topk', or other (mask the first
          `invalid_size` vocab entries).
        - `topk`, `invalid_size`: settings for `apply_valid_vocab`.
        - `greedy`: use argmax instead of Gumbel sampling.
        - `sampled_binary_mask`: 1 = replace, 0 = keep.
        - `mask_method`: 'only_mask' (replace only flagged positions) or
          'all_mask' (replace every non-padding token except ids 100/101/102).
        - `use_tpu`: summaries are written only when this is falsy.

    Returns:
      Sampled ids `[batch, seq]`, or `[batch, seq, gen_sample]` when
      `gen_sample > 1`. They are int32 when a mask method is applied;
      otherwise the raw `tf.argmax` output is returned.
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    gen_sample = config.get('gen_sample', 1)
    mask_method = kargs.get('mask_method', 'only_mask')
    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    def _tile_over_samples(tensor):
        # [batch, seq] -> float32 [batch, seq, gen_sample].
        return tf.tile(
            tf.expand_dims(tf.cast(tensor, tf.float32), axis=-1),
            [1, 1, gen_sample])

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        is_preln = config.get('ln_type', 'postln') == 'preln'
        if is_preln:
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if not is_preln:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum(
                "abc,dc->abd", input_tensor, embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        # batch x seq x vocab
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)
        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]

        apply_valid_vocab = kargs.get("apply_valid_vocab", False)
        if not apply_valid_vocab:
            tf.logging.info("****** normal logits *******")
        elif apply_valid_vocab == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            # Push the first `invalid_size` vocabulary entries to -10000 so
            # they are effectively never sampled.
            invalid_size = kargs.get("invalid_size", 106)
            invalid_part = tf.cast(
                tf.ones((1, invalid_size)) * (-10000), tf.float32)
            valid_part = tf.cast(
                tf.zeros((1, config.vocab_size - invalid_size)), tf.float32)
            vocab_mask = tf.concat([invalid_part, valid_part], axis=-1)
            # [1, vocab] broadcasts over batch x seq x vocab.
            logits += vocab_mask
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(
            logits / config.get("temperature", 1.0))
        flat_logits_tempered = tf.reshape(
            logits_tempered, [batch_size * seq_length, width])

        if not kargs.get("greedy", False):
            _, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=gen_sample,
                greedy=kargs.get("greedy", False))
            samples = tf.argmax(sampled_logprob, axis=1)
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(  # 0 for original and 1 for replace
                tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")

        input_mask_float = tf.cast(input_mask, tf.float32)
        if not kargs.get('use_tpu', True):
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * input_mask_float) /
                tf.reduce_sum(input_mask_float))

        ori_ids_float = tf.cast(input_ori_ids, tf.float32)
        unsampled_mask = None
        if mask_method == 'all_mask':
            # Do not replace [UNK] (100), [CLS] (101), [SEP] (102) or padding.
            special_mask = (
                tf.cast(tf.math.equal(input_ori_ids, 100), tf.float32) +
                tf.cast(tf.math.equal(input_ori_ids, 101), tf.float32) +
                tf.cast(tf.math.equal(input_ori_ids, 102), tf.float32))
            unsampled_mask = (1 - special_mask) * input_mask_float

        if gen_sample == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, gen_sample])
            # Broadcast every per-token tensor over the sample axis so the
            # blend below (and the summaries) are shape-consistent.
            label_diff_ids = _tile_over_samples(label_diff_ids)
            ori_ids_float = _tile_over_samples(ori_ids_float)
            input_mask_float = _tile_over_samples(input_mask_float)
            if unsampled_mask is not None:
                unsampled_mask = _tile_over_samples(unsampled_mask)

        if mask_method == 'only_mask':
            tf.logging.info("****** only mask sample *******")
            replace_mask = label_diff_ids
        elif mask_method == 'all_mask':
            tf.logging.info("****** all mask sample *******")
            replace_mask = unsampled_mask
        else:
            replace_mask = None

        if replace_mask is not None:
            # Take the sampled id where replace_mask == 1, else keep the
            # original id.
            sampled_input_id = (
                replace_mask * tf.cast(sampled_input_id, tf.float32) +
                (1 - replace_mask) * ori_ids_float)
            sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        # Accuracy summaries are only meaningful when a mask method defines
        # which positions were sampled.
        if not kargs.get('use_tpu', True) and replace_mask is not None:
            sampled_ids_int = tf.cast(sampled_input_id, tf.int32)
            ori_ids_int = tf.cast(ori_ids_float, tf.int32)
            num_changed = tf.reduce_sum(
                tf.cast(tf.not_equal(sampled_ids_int, ori_ids_int), tf.float32) *
                input_mask_float)
            if mask_method == 'only_mask':
                sample_acc = 1 - num_changed / (
                    1e-10 + tf.reduce_sum(label_diff_ids))
            else:
                num_equal = tf.reduce_sum(
                    tf.cast(tf.equal(sampled_ids_int, ori_ids_int), tf.float32) *
                    unsampled_mask)
                unsampled_total = 1e-10 + tf.reduce_sum(unsampled_mask)
                # Log the normalised ratio (previously a raw count was logged).
                tf.summary.scalar('generator_equal_sample_acc',
                                  num_equal / unsampled_total)
                sample_acc = 1 - num_changed / unsampled_total
            tf.summary.scalar('generator_sample_acc', sample_acc)

        return sampled_input_id
def seq_mask_masked_lm_output(config, input_tensor, output_weights,
                              input_mask, input_ori_ids, input_ids,
                              sampled_binary_mask, **kargs):
    """Token-level MLM loss restricted to positions flagged in `sampled_binary_mask`.

    If input_ori_ids[i] was randomly perturbed, sampled_binary_mask[i] is 1.
    Only those positions contribute to the loss.

    Args:
      config: model config. `.get(...)` supplies `ln_type` (default
        'postln') and `embedding` (default 'factorized'). Attributes supply
        `hidden_size`, `embedding_size` (only when not factorized),
        `hidden_act`, `initializer_range` and `vocab_size`.
      input_tensor: rank-3 tensor. The rank is checked by
        `bert_utils.get_shape_list(..., expected_rank=3)`.
      output_weights: matrix contracted with the last axis via
        einsum "abc,dc->abd".
      input_mask, input_ids: accepted for interface compatibility; unused.
      input_ori_ids: target ids for the cross-entropy.
      sampled_binary_mask: per-token weights (cast to float32).
      **kargs: optional `embedding_projection` and `scope`.

    Returns:
      Tuple `(loss, per_example_loss, logits, sampled_binary_mask)`.
      `per_example_loss` is already multiplied by the mask, and `loss`
      divides its sum by (1e-10 + mask sum).
    """
    # Called for its rank check; the shape values are not needed here.
    bert_utils.get_shape_list(input_tensor, expected_rank=3)

    projection = kargs.get('embedding_projection', None)
    base_scope = kargs.get('scope', None)
    scope = (base_scope + '/' + 'cls/predictions') if base_scope \
        else 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        is_preln = config.get('ln_type', 'postln') == 'preln'
        if is_preln:
            input_tensor = albert_modules.layer_norm(input_tensor)

        use_hidden_width = config.get("embedding", "factorized") == "factorized"
        projection_width = (config.hidden_size if use_hidden_width
                            else config.embedding_size)

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            # Post-LN and any other ln_type normalise after the transform.
            if not is_preln:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if projection is None:
            print("==no need for embedding projection==")
        else:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor, projection)

        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        logits = tf.nn.bias_add(
            tf.einsum("abc,dc->abd", input_tensor, output_weights),
            output_bias)

        loss_mask = tf.cast(sampled_binary_mask, tf.float32)
        token_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.stop_gradient(input_ori_ids),
        ) * loss_mask
        loss = tf.reduce_sum(token_loss) / (1e-10 + tf.reduce_sum(loss_mask))

    return (loss, token_loss, logits, loss_mask)
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    """Sample replacement token ids via top-k multinomial sampling.

    Args:
      config: model config read via `.get(...)` (`ln_type`, `embedding`,
        `temperature`, `gen_sample`) and attributes (`hidden_size`,
        `embedding_size`, `hidden_act`, `initializer_range`, `vocab_size`).
      input_tensor: rank-3 encoder output `[batch, seq, hidden]` (rank is
        checked by `bert_utils.get_shape_list`).
      output_weights: matrix contracted with the last axis via
        einsum "abc,dc->abd".
      input_ids: current (possibly masked) ids `[batch, seq]`.
      input_ori_ids: original ids `[batch, seq]`.
      input_mask: padding mask `[batch, seq]`. It is used only for the
        `label_diff_ids` summary.
      **kargs:
        - `scope`, `embedding_projection`: as in the other MLM heads.
        - `mask_method` (default 'all'): with 'only_mask', sampled ids
          replace only positions where `input_ids != input_ori_ids`.

    Returns:
      Sampled ids `[batch, seq]`, or `[batch, seq, gen_sample]` when
      `gen_sample > 1`.
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    gen_sample = config.get('gen_sample', 1)
    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        is_preln = config.get('ln_type', 'postln') == 'preln'
        if is_preln:
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if not is_preln:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum(
                "abc,dc->abd", input_tensor, embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable(
            "output_bias",
            shape=[config.vocab_size],
            initializer=tf.zeros_initializer())
        # batch x seq x vocab
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)
        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]

        logits_tempered = logits / config.get("temperature", 1.0)
        flat_logits_tempered = tf.reshape(
            logits_tempered, [batch_size * seq_length, width])
        flat_logits_tempered_topk = top_k_logits(
            flat_logits_tempered, int(config.vocab_size / 2))
        samples = tf.multinomial(
            flat_logits_tempered_topk,
            num_samples=gen_sample,
            output_dtype=tf.int32)

        # NOTE: despite the name, this is 1.0 where input_ids == input_ori_ids
        # (token kept) and 0.0 where the token was masked/replaced.
        label_diff_ids = tf.equal(
            tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        tf.summary.scalar(
            'label_diff_ids',
            tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) /
            tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if gen_sample == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # Sampled id at masked positions, original id elsewhere.
                # Previously this result was assigned to `samples` and lost.
                sampled_input_id = (
                    (1 - label_diff_ids) * tf.cast(sampled_input_id, tf.float32) +
                    label_diff_ids * tf.cast(input_ori_ids, tf.float32))
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, gen_sample])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # Tile [batch, seq] -> [batch, seq, gen_sample]. Cast to
                # float first; the original einsum mixed int ids with float
                # ones.
                keep_mask = tf.tile(
                    tf.expand_dims(label_diff_ids, axis=-1),
                    [1, 1, gen_sample])
                ori_ids = tf.tile(
                    tf.expand_dims(tf.cast(input_ori_ids, tf.float32), axis=-1),
                    [1, 1, gen_sample])
                sampled_input_id = (
                    (1 - keep_mask) * tf.cast(sampled_input_id, tf.float32) +
                    ori_ids * keep_mask)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        return sampled_input_id