def build_encoder(self, input_ids, input_mask, hidden_dropout_prob,
                  attention_probs_dropout_prob, **kargs):
    reuse = kargs["reuse"]
    with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
        with tf.variable_scope("encoder"):
            # This converts a 2D mask of shape [batch_size, seq_length] to a 3D
            # mask of shape [batch_size, seq_length, seq_length] which is used
            # for the attention scores.
            attention_mask = albert_modules.create_attention_mask_from_input_mask(
                input_ids, input_mask)

            # Run the stacked transformer.
            # `sequence_output` shape = [batch_size, seq_length, hidden_size].
            print("===number of hidden layers===", self.config.num_hidden_layers)
            if self.config.ln_type == 'postln':
                print('==apply post layer==')
                transformer_model = albert_modules.transformer_model
            elif self.config.ln_type == 'preln':
                transformer_model = albert_modules.prelln_transformer_model
                print('==apply pre layer==')

            [self.all_encoder_layers,
             self.all_attention_scores] = transformer_model(
                 input_tensor=self.embedding_output,
                 attention_mask=attention_mask,
                 hidden_size=self.config.hidden_size,
                 num_hidden_layers=self.config.num_hidden_layers,
                 num_attention_heads=self.config.num_attention_heads,
                 intermediate_size=self.config.intermediate_size,
                 intermediate_act_fn=albert_modules.get_activation(
                     self.config.hidden_act),
                 hidden_dropout_prob=hidden_dropout_prob,
                 attention_probs_dropout_prob=attention_probs_dropout_prob,
                 initializer_range=self.config.initializer_range,
                 do_return_all_layers=True,
                 shared_type=self.config.get("shared_type", None),
                 adapter_fn=bert_adapter_modules.get_adapter(
                     self.config.get('adapter_fn', None)))
def token_generator_gumbel_normal(config, input_tensor, output_weights,
                                  input_ids, input_ori_ids, input_mask,
                                  **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        #     projection_width = config.hidden_size
        # else:
        #     projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info("==using embedding factorized: embedding size: %s==",
                            str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal', "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                kargs.get("num_train_steps", 10000),
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor, 1, activation=tf.nn.softplus) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if config.get('gen_sample', 1) > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******")
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 1.0
            tf.logging.info(
                "****** not apply annealed temperature with fixed temp ******* %s",
                str(annealed_temp))

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        sampled_logprob_temp, sampled_logprob = gumbel_softmax(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=config.get('gen_sample', 1))

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get("straight_through", True):
            tf.logging.info(
                "****** apply straight_through_estimator without grl *******")
            sampled_id = tf.stop_gradient(sampled_hard_id -
                                          sampled_logprob_temp) + (
                                              sampled_logprob_temp)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax probs without grl *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)

        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids, tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids, axis=[-1])  # batch x seq x 1
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
        else:
            # note: `sampled_id` carries the gumbel samples here
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====",
                        str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(input_mask, tf.float32)
            sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
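

# NOTE: `gumbel_softmax` (and `flip_gradient`) are imported from elsewhere in
# this repo and are only called above. The snippet below is a minimal
# single-sample sketch (an assumption, not the repo's implementation) of what
# the call sites rely on: Gumbel(0, 1) noise is added to the (log-)logits, the
# tempered softmax gives the differentiable relaxation, and argmax over the
# noised logits gives an exact categorical sample for the straight-through path.
def _gumbel_softmax_sketch(flat_logits, temperature=1.0):
    """Sketch only. flat_logits: [batch * seq, vocab]; returns (soft_sample, noised_logits)."""
    uniform_noise = tf.random_uniform(tf.shape(flat_logits),
                                      minval=1e-10, maxval=1.0)
    gumbel_noise = -tf.log(-tf.log(uniform_noise))   # Gumbel(0, 1) noise
    noised_logits = flat_logits + gumbel_noise       # argmax == categorical sample
    soft_sample = tf.nn.softmax(noised_logits / temperature, axis=-1)
    return soft_sample, noised_logits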
def token_generator_gumbel(config, input_tensor, output_weights, input_ids,
                           input_ori_ids, input_mask, **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        #     projection_width = config.hidden_size
        # else:
        #     projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info("==using embedding factorized: embedding size: %s==",
                            str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        # it seems no need for logits to be normalized
        logits_tempered = logits  # tf.nn.log_softmax(logits, axis=-1)

        # width=config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal', "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            temperature_warmup_steps = int(num_train_steps) * 0.1
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                temperature_warmup_steps,
                end_learning_rate=0.01,
                power=1.0,
                cycle=False)
            gumbel_samples = None
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('annealed_temp', annealed_temp)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor, 1, activation=tf.nn.softplus) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('softplus temperature',
                                  tf.reduce_mean(annealed_temp))
            if config.get('gen_sample', 1) > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******")
                annealed_temp = tf.expand_dims(annealed_temp, -1)
            gumbel_samples = None
        elif kargs.get('gumbel_anneal', 'vqvae') == 'vqvae':
            temperature_warmup_steps = kargs.get("num_train_steps", 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            gumbel_samples *= inverse_exp_decay(steps // 5) * 0.5
            annealed_temp_decay = 1.01 - inverse_lin_decay(
                steps)  # minimum temperature is set 0.2
            annealed_temp = tf.cond(
                tf.less(tf.random_uniform([]), 0.9),
                lambda: annealed_temp_decay,
                lambda: tf.random_uniform([], minval=0.5, maxval=1.0))  # 10% step for
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method ******* ")
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif kargs.get('gumbel_anneal', 'vqvae_v1') == 'vqvae_v1':
            temperature_warmup_steps = kargs.get("num_train_steps", 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            # gumbel_samples *= inverse_exp_decay(steps)
            annealed_temp_decay = 1.01 - inverse_exp_decay(
                kargs.get("num_train_steps", 10000))  # minimum temperature is set 0.2
            annealed_temp = annealed_temp_decay
            # annealed_temp = tf.cond(
            #     tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay,
            #     lambda: tf.random_uniform([], minval=0.5, maxval=1.0))  # 10% step for
            tf.logging.info(
                "****** apply sel-gan gumbel-softmax temperature annealing method ******* ")
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        elif kargs.get('gumbel_anneal', 'vqvae_v2') == 'vqvae_v2':
            temperature_warmup_steps = kargs.get("num_train_steps", 10000) * 0.1
            tf.logging.info(
                "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ",
                str(kargs.get("num_train_steps", 10000) * 0.1))
            steps = temperature_warmup_steps
            gumbel_samples = sample_gumbel(bert_utils.get_shape_list(
                flat_logits_tempered, expected_rank=2),
                                           samples=config.get('gen_sample', 1))
            # gumbel_samples *= inverse_exp_decay(steps)
            annealed_temp_decay = inverse_temp_exp_decay(
                kargs.get("num_train_steps", 10000), kargs.get("max_temp", 100))
            annealed_temp = annealed_temp_decay
            # annealed_temp = tf.cond(
            #     tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay,
            #     lambda: tf.random_uniform([], minval=0.5, maxval=1.0))  # 10% step for
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax temperature annealing method ******* ")
            tf.logging.info(
                "****** apply sel-gan-v2 gumbel-softmax num_train_steps:%s annealing method, temp:%s ******* ",
                str(kargs.get("num_train_steps", 10000)),
                str(kargs.get("max_temp", 100)))
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp)
                tf.summary.scalar('t2t_vqvae_stgs temperature decay',
                                  annealed_temp_decay)
        else:
            annealed_temp = 0.01
            gumbel_samples = None
            tf.logging.info(
                "****** not apply annealed temperature with fixed temp ******* %s",
                str(annealed_temp))
            if not kargs.get('use_tpu', True):
                tf.summary.scalar('gumbel_temperature', annealed_temp)

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        if kargs.get('stable_gradient', True):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))
            tf.logging.info(
                "****** apply normal derivative for gradient calculation *******")
        else:
            sampled_logprob_temp, sampled_logprob = gumbel_softmax_custom_grad(
                flat_logits_tempered,
                temperature=annealed_temp,
                gumbel_samples=gumbel_samples,
                samples=config.get('gen_sample', 1))
            tf.logging.info(
                "****** apply log derivative for stable gradient calculation *******")

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', False):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get('if_flip_grad', True):
            tf.logging.info("****** apply gradient flipping *******")
            sampled_logprob_temp_1 = flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** not apply gradient flipping *******")
            sampled_logprob_temp_1 = sampled_logprob_temp
        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator *******")
            sampled_id = tf.stop_gradient(sampled_hard_id -
                                          sampled_logprob_temp) + (
                                              sampled_logprob_temp_1)
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = sampled_logprob_temp_1

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)

        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(
                sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids, tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)

        label_diff_ids = tf.expand_dims(label_diff_ids, axis=[-1])  # batch x seq x 1
        input_ori_ids_1 = input_ori_ids
        input_ori_ids = tf.one_hot(input_ori_ids,
                                   config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast(
                    input_mask, tf.float32)
                unsampled_mask = tf.expand_dims(unsampled_mask,
                                                axis=[-1])  # batch x seq x 1
                ori_input_mask = tf.expand_dims(input_mask,
                                                axis=[-1])  # batch x seq x 1
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - unsampled_mask) * tf.cast(
                            ori_input_mask, tf.float32) * tf.cast(
                                input_ori_ids, tf.float32)
        else:
            # note: `sampled_id` carries the gumbel samples here
            sampled_input_id = tf.reshape(sampled_id, [
                batch_size, seq_length, config.vocab_size,
                config.get('gen_sample', 1)
            ])
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1
            input_ori_ids = tf.expand_dims(input_ori_ids,
                                           axis=-1)  # batch x seq x vocab x 1

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====",
                        str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_equal_id = tf.equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))

            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(input_mask, tf.float32)

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
                label_diff_ids_my = tf.cast(label_diff_ids, tf.float32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(sampled_equal_id, tf.float32) * tf.cast(
                    tf.squeeze(unsampled_mask, axis=-1), tf.float32)
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
                label_diff_ids_my = tf.cast(unsampled_mask, tf.float32)
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))

            tf.summary.scalar('generator_sample_acc', sampled_not_equal)
            tf.summary.scalar("generator_valid_token", tf.reduce_sum(input_mask))

            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
            sampled_hard_id = tf.cast(sampled_hard_id, tf.float32)
            sampled_hard_id = tf.reshape(
                sampled_hard_id, [batch_size, seq_length, config.vocab_size])
            sampled_soft_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
            sampled_hard_id *= label_diff_ids_my
            sampled_soft_id *= label_diff_ids_my

            hard_soft_bias = tf.reduce_sum(
                tf.sqrt(
                    tf.reduce_sum(tf.pow(sampled_hard_id - sampled_soft_id, 2),
                                  axis=-1))) / (1e-10 + tf.reduce_sum(
                                      tf.cast(label_diff_ids_my, tf.float32)))
            tf.summary.scalar('soft_hard_bias', hard_soft_bias)

        return sampled_input_id
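

# NOTE: `sample_gumbel`, `inverse_exp_decay`, `inverse_lin_decay` and
# `inverse_temp_exp_decay` are imported from elsewhere in this repo. The two
# sketches below are assumptions that mirror tensor2tensor-style schedules,
# which the 'vqvae' branches above appear to rely on: both ramp from
# `min_value` up to 1.0 as the global step reaches `max_step`, so
# `1.01 - inverse_lin_decay(steps)` anneals the temperature downwards.
def _inverse_lin_decay_sketch(max_step, min_value=0.01):
    """Sketch only: linear ramp from min_value to 1.0, reached at max_step."""
    step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
    progress = tf.minimum(step / float(max_step), 1.0)
    return progress * (1.0 - min_value) + min_value


def _inverse_exp_decay_sketch(max_step, min_value=0.01):
    """Sketch only: exponential ramp from min_value to 1.0, reached at max_step."""
    step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)
    inv_base = tf.exp(tf.log(min_value) / float(max_step))
    return inv_base ** tf.maximum(float(max_step) - step, 0.0)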
def get_masked_lm_output(config, input_tensor, output_weights, positions,
                         label_ids, label_weights, **kargs):
    """Get loss and log probs for the masked LM."""
    reuse = kargs.get('reuse', False)
    embedding_projection = kargs.get('embedding_projection', None)

    input_tensor = tf.cast(input_tensor, tf.float32)
    positions = tf.cast(positions, tf.int32)
    label_ids = tf.cast(label_ids, tf.int32)
    label_weights = tf.cast(label_weights, tf.float32)

    input_tensor = bert_utils.gather_indexes(input_tensor, positions)
    """ flatten masked lm ids with positions """

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=reuse):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # We apply one more non-linear transformation before the output layer.
        # This matrix is not used after pre-training.
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            input_tensor = tf.matmul(input_tensor,
                                     embedding_projection,
                                     transpose_b=True)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        # logits = tf.multiply(logits,
        #                      1.0 / math.sqrt(float(config.hidden_size)))
        # logits *= 2
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        label_ids = tf.reshape(label_ids, [-1])
        label_weights = tf.reshape(label_weights, [-1])

        # one_hot_labels = tf.one_hot(
        #     label_ids, depth=config.vocab_size, dtype=tf.float32)

        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=tf.stop_gradient(label_ids), logits=logits)
        # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])

        numerator = tf.reduce_sum(label_weights * per_example_loss)
        denominator = tf.reduce_sum(label_weights) + 1e-5

        # The `positions` tensor might be zero-padded (if the sequence is too
        # short to have the maximum number of predictions). The `label_weights`
        # tensor has a value of 1.0 for every real prediction and 0.0 for the
        # padding predictions.
        # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1])
        # numerator = tf.reduce_sum(label_weights * per_example_loss)
        # denominator = tf.reduce_sum(label_weights) + 1e-5

        loss = numerator / denominator

        return (loss, per_example_loss, log_probs, label_weights)
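

# NOTE: `bert_utils.gather_indexes` is defined elsewhere in this repo; the
# sketch below (an assumption, mirroring the helper of the same name in the
# original BERT code) shows what the call in `get_masked_lm_output` relies on:
# it gathers the hidden vectors at the masked `positions` from a
# [batch, seq, width] tensor and flattens them to [batch * num_positions, width].
def _gather_indexes_sketch(sequence_tensor, positions):
    """Sketch only. Gathers the vectors at `positions` over a minibatch."""
    sequence_shape = bert_utils.get_shape_list(sequence_tensor, expected_rank=3)
    batch_size = sequence_shape[0]
    seq_length = sequence_shape[1]
    width = sequence_shape[2]

    # offset of every sequence inside the flattened [batch * seq, width] tensor
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence_tensor = tf.reshape(sequence_tensor,
                                      [batch_size * seq_length, width])
    return tf.gather(flat_sequence_tensor, flat_positions)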
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        #     projection_width = config.hidden_size
        # else:
        #     projection_width = config.embedding_size

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info("==using embedding factorized: embedding size: %s==",
                            str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        if not kargs.get("apply_valid_vocab", False):
            logits = logits
            tf.logging.info("****** normal logits *******")
        elif kargs.get("apply_valid_vocab", False) == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            invalid_size = kargs.get("invalid_size", 106)
            invalid_mask = tf.cast(
                tf.ones((1, invalid_size)) * (-10000), tf.float32)
            valid_mask = tf.cast(
                tf.zeros((1, config.vocab_size - invalid_size)), tf.float32)
            invalid_mask = tf.concat([invalid_mask, valid_mask], axis=-1)
            # invalid_mask = tf.expand_dims(invalid_mask, axis=1)
            # batch x seq x vocab
            logits += tf.cast(invalid_mask, tf.float32)
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(logits /
                                            config.get("temperature", 1.0))

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        # flat_logits_tempered_topk = top_k_logits(flat_logits_tempered,
        #                                          int(config.vocab_size / 2))

        if not kargs.get("greedy", False):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))
            samples = tf.argmax(sampled_logprob, axis=1)  # batch x seq
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        # samples = tf.multinomial(flat_logits_tempered,
        #                          num_samples=config.get('gen_sample', 1),
        #                          output_dtype=tf.int32)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)

        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids, tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")

        if not kargs.get('use_tpu', True):
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) /
                tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                input_ori_ids_1 = input_ori_ids
                unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast(
                    input_mask, tf.float32)
                # unsampled_mask = tf.expand_dims(unsampled_mask, axis=[-1])
                # batch x seq x 1
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id, tf.float32) + (
                        1 - unsampled_mask) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + (1 - input_ori_ids) * label_diff_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

            input_mask = tf.expand_dims(input_mask, axis=-1)
            input_mask = tf.einsum('abc,cd->abd', input_mask,
                                   tf.ones((1, config.get('gen_sample', 1))))
            input_mask = tf.cast(input_mask, tf.float32)

        if not kargs.get('use_tpu', True):
            sampled_not_equal_id = tf.not_equal(
                tf.cast(sampled_input_id, tf.int32),
                tf.cast(input_ori_ids, tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id,
                                        tf.float32) * tf.cast(input_mask, tf.float32)
            sampled_equal_id = tf.equal(tf.cast(sampled_input_id, tf.int32),
                                        tf.cast(input_ori_ids, tf.int32))

            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            elif kargs.get('mask_method', 'only_mask') == 'all_mask':
                sampled_equal = tf.cast(sampled_equal_id, tf.float32) * tf.cast(
                    unsampled_mask, tf.float32)
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))

            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        # sampled_not_equal_id = tf.not_equal(
        #     tf.cast(sampled_input_id, tf.int32),
        #     tf.cast(input_ori_ids, tf.int32)
        # )
        # sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast(input_mask, tf.float32)
        # sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
        # if not kargs.get('use_tpu', True):
        #     tf.summary.scalar('generator_sample_acc',
        #                       sampled_not_equal)

        return sampled_input_id
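

# NOTE: `top_k_logits` is defined elsewhere in this repo (it is referenced in
# the commented-out line above and used by the multinomial-based
# `token_generator` further below). The sketch below is an assumption of the
# usual top-k filtering trick: everything outside the k largest logits of each
# row is pushed to a large negative value so that sampling is restricted to
# the k most likely tokens.
def _top_k_logits_sketch(logits, k):
    """Sketch only. logits: [batch * seq, vocab]; keeps only the per-row top-k logits."""
    values, _ = tf.nn.top_k(logits, k=k)
    min_values = values[:, -1, tf.newaxis]          # k-th largest logit per row
    return tf.where(logits < min_values,
                    tf.ones_like(logits) * (-1e10),
                    logits)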
def seq_mask_masked_lm_output(config, input_tensor, output_weights, input_mask,
                              input_ori_ids, input_ids, sampled_binary_mask,
                              **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        """ if input_ori_ids[i] is randomly perturbed, sampled_binary_mask[i]=1 """
        sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)

        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.stop_gradient(input_ori_ids),
        )
        per_example_loss *= sampled_binary_mask
        loss = tf.reduce_sum(per_example_loss) / (
            1e-10 + tf.reduce_sum(sampled_binary_mask))

        return (loss, per_example_loss, logits, sampled_binary_mask)
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))

            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = logits / config.get("temperature", 1.0)

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])
        flat_logits_tempered_topk = top_k_logits(flat_logits_tempered,
                                                 int(config.vocab_size / 2))

        samples = tf.multinomial(flat_logits_tempered_topk,
                                 num_samples=config.get('gen_sample', 1),
                                 output_dtype=tf.int32)

        label_diff_ids = tf.equal(
            tf.cast(input_ids, tf.int32),
            tf.cast(input_ori_ids, tf.int32))  # 1 for original and 0 for replace
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        tf.summary.scalar(
            'label_diff_ids',
            tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) /
            tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'all') == 'only_mask':
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                # keep the original ids and only replace the perturbed positions
                sampled_input_id = (1 - label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32) + label_diff_ids * tf.cast(
                        input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                # batch x seq_length x 1
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = (1 - label_diff_ids) * tf.cast(
                    sampled_input_id,
                    tf.float32) + input_ori_ids * label_diff_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        return sampled_input_id