def classifier(config, seq_output, input_ids, sampled_ids, input_mask, num_labels, dropout_prob, **kargs):
    """
    input_ids: original input ids
    sampled_ids: generated fake ids
    """
    output_layer = seq_output
    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    if config.get('ln_type', 'postln') == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
    elif config.get('ln_type', 'postln') == 'postln':
        output_layer = output_layer
    else:
        output_layer = output_layer

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    # Project the normalized / dropped-out representation. (The original applied the
    # projection to `seq_output`, which silently bypassed the layer norm and dropout above.)
    logits = tf.einsum("abc,dc->abd", output_layer, output_weights)
    logits = tf.nn.bias_add(logits, output_bias)  # batch x seq_length x 2

    input_ids = tf.cast(input_ids, tf.int32)
    sampled_ids = tf.cast(sampled_ids, tf.int32)

    # 1 where the token was kept from the original input, 0 where it was replaced
    discriminator_label_ids = tf.cast(tf.equal(input_ids, sampled_ids), tf.int32)

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.stop_gradient(discriminator_label_ids))

    loss = per_example_loss * tf.cast(input_mask, tf.float32)
    loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32)))

    return (loss, logits, per_example_loss)
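# A minimal worked example (not from the repo) of the label convention used above: with
# tf.equal, a label of 1 means "token kept from the original input" and 0 means "token
# replaced by the generator". Toy ids, TF 1.x session style; `tf` is the module-level import.
def _rtd_label_toy_example():
    input_ids = tf.constant([[101, 2054, 2003, 2023, 1029, 102]])    # original sequence
    sampled_ids = tf.constant([[101, 2054, 7592, 2023, 3000, 102]])  # generator re-sampled positions 2 and 4
    labels = tf.cast(tf.equal(input_ids, sampled_ids), tf.int32)
    with tf.Session() as sess:
        print(sess.run(labels))  # [[1 1 0 1 0 1]]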
def get_next_sentence_output(config, input_tensor, labels, reuse=None, **kargs):
    """Get loss and log probs for the next sentence prediction."""
    # Simple binary classification. Note that 0 is "next sentence" and 1 is
    # "random sentence". This weight matrix is not used after pre-training.
    scope = kargs.get('scope', None)  # `**kargs` was missing from the original signature
    if scope:
        scope = scope + '/' + 'cls/seq_relationship'
    else:
        scope = 'cls/seq_relationship'
    tf.logging.info("**** nsp scope **** %s", str(scope))

    # with tf.variable_scope("cls/seq_relationship", reuse=reuse):
    with tf.variable_scope(scope, reuse=reuse):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        output_weights = tf.get_variable(
            "output_weights",
            shape=[2, config.hidden_size],
            initializer=albert_modules.create_initializer(config.initializer_range))
        output_bias = tf.get_variable(
            "output_bias", shape=[2], initializer=tf.zeros_initializer())

        logits = tf.matmul(input_tensor, output_weights, transpose_b=True)
        logits = tf.nn.bias_add(logits, output_bias)
        log_probs = tf.nn.log_softmax(logits, axis=-1)

        labels = tf.reshape(labels, [-1])
        one_hot_labels = tf.one_hot(labels, depth=2, dtype=tf.float32)
        per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, log_probs)
def token_generator_gumbel_normal(config, input_tensor, output_weights, input_ids, input_ori_ids, input_mask, **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        # if config.get("embedding", "factorized") == "factorized":
        #     projection_width = config.hidden_size
        # else:
        #     projection_width = config.embedding_size
        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size', config.hidden_size)
            tf.logging.info("==using embedding factorized: embedding size: %s==",
                            str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(config.initializer_range))
            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable(
            "output_bias", shape=[config.vocab_size], initializer=tf.zeros_initializer())
        # batch x seq x vocab
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = tf.nn.log_softmax(logits, axis=-1)

        # width = config.vocab_size
        flat_logits_tempered = tf.reshape(logits_tempered, [batch_size * seq_length, width])

        num_train_steps = kargs.get('num_train_steps', None)
        if num_train_steps and kargs.get('gumbel_anneal', "anneal") == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s", str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                kargs.get("num_train_steps", 10000),
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif kargs.get('gumbel_anneal', "anneal") == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(input_tensor, 1, activation=tf.nn.softplus) + 1.0
            annealed_temp = 1. / annealed_temp
            annealed_temp = tf.reshape(annealed_temp, [batch_size * seq_length, 1])
            if config.get('gen_sample', 1) > 1:
                tf.logging.info("****** apply auto-scale temperature for multi-sampling *******")
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 1.0
            tf.logging.info("****** not apply annealed temperature with fixed temp ******* %s",
                            str(annealed_temp))

        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        sampled_logprob_temp, sampled_logprob = gumbel_softmax(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=config.get('gen_sample', 1))

        # argmax on config.vocab_size which is always axis=1
        # [batch x seq] x config.vocab_size x config.get('gen_sample', 1)
        # argmax(logits + gumbel_samples) samples from a categorical distribution
        if kargs.get('sampled_prob_id', True):
            tf.logging.info("****** apply categorical sampled id of original logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id
        else:
            tf.logging.info("****** apply gumbel-softmax logprob for logits *******")
            sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                         config.vocab_size,
                                         axis=1)  # sampled multinomial id

        # straight-through gumbel softmax estimator
        if kargs.get("straight_through", True):
            tf.logging.info("****** apply straight_through_estimator without grl *******")
            sampled_id = tf.stop_gradient(sampled_hard_id - sampled_logprob_temp) + sampled_logprob_temp
        else:
            tf.logging.info("****** apply gumbel-softmax probs without grl *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(sampled_binary_mask)  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids, tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        label_diff_ids = tf.expand_dims(label_diff_ids, axis=[-1])  # batch x seq x 1

        input_ori_ids = tf.one_hot(input_ori_ids, config.vocab_size)  # batch x seq x vocab
        input_ori_ids = tf.cast(input_ori_ids, tf.float32)

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(sampled_id, [batch_size, seq_length, config.vocab_size])
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                sampled_input_id = (label_diff_ids) * tf.cast(sampled_input_id, tf.float32) + (
                    1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
        else:
            # NOTE: the original reshaped an undefined `samples`; `sampled_id` is the sampled one-hot tensor
            sampled_input_id = tf.reshape(
                sampled_id,
                [batch_size, seq_length, config.vocab_size, config.get('gen_sample', 1)])
            label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)  # batch x seq x 1 x 1
            input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)  # batch x seq x vocab x 1
            if kargs.get('mask_method', 'only_mask') == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                sampled_input_id = (label_diff_ids) * tf.cast(sampled_input_id, tf.float32) + (
                    1 - input_ori_ids) * label_diff_ids

        tf.logging.info("====generator use_tpu %s ====", str(kargs.get('use_tpu', True)))
        if not kargs.get('use_tpu', True):
            tf.logging.info("====logging generator loss ====")
            sampled_not_equal_id = tf.not_equal(
                tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32),
                tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast(input_mask, tf.float32)
            sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
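# The repo-specific `gumbel_softmax` / `sample_gumbel` / `flip_gradient` helpers are defined
# elsewhere; a minimal sketch of the straight-through Gumbel-Softmax trick they rely on could
# look like this (TF 1.x, names and defaults are assumptions, not the repo's implementation).
def _sample_gumbel_noise(shape, eps=1e-10):
    """Sample standard Gumbel noise: -log(-log(U))."""
    u = tf.random_uniform(shape, minval=0.0, maxval=1.0)
    return -tf.log(-tf.log(u + eps) + eps)

def _straight_through_gumbel_softmax_sketch(logits, temperature=1.0):
    """Returns (soft_probs, hard_one_hot_with_soft_gradient) for logits of shape [N, vocab]."""
    gumbel = _sample_gumbel_noise(tf.shape(logits))
    soft = tf.nn.softmax((logits + gumbel) / temperature, axis=-1)
    hard = tf.one_hot(tf.argmax(soft, axis=-1), tf.shape(logits)[-1], dtype=soft.dtype)
    # Forward pass uses the hard one-hot sample; gradients flow through the soft distribution.
    straight_through = tf.stop_gradient(hard - soft) + soft
    return soft, straight_through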
def token_generator_gumbel(config, input_tensor, output_weights, input_ids, input_ori_ids, input_mask, **kargs): input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3) batch_size = input_shape_list[0] seq_length = input_shape_list[1] hidden_dims = input_shape_list[2] embedding_projection = kargs.get('embedding_projection', None) scope = kargs.get('scope', None) if scope: scope = scope + '/' + 'cls/predictions' else: scope = 'cls/predictions' tf.logging.info("**** mlm generator scope **** %s", str(scope)) # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): if config.get('ln_type', 'postln') == 'preln': input_tensor = albert_modules.layer_norm(input_tensor) elif config.get('ln_type', 'postln') == 'postln': input_tensor = input_tensor else: input_tensor = input_tensor # if config.get("embedding", "factorized") == "factorized": # projection_width = config.hidden_size # else: # projection_width = config.embedding_size if config.get("embedding", "none_factorized") == "none_factorized": projection_width = config.hidden_size tf.logging.info("==not using embedding factorized==") else: projection_width = config.get('embedding_size', config.hidden_size) tf.logging.info( "==using embedding factorized: embedding size: %s==", str(projection_width)) with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=projection_width, activation=albert_modules.get_activation(config.hidden_act), kernel_initializer=albert_modules.create_initializer( config.initializer_range)) if config.get('ln_type', 'postln') == 'preln': input_tensor = input_tensor elif config.get('ln_type', 'postln') == 'postln': input_tensor = albert_modules.layer_norm(input_tensor) else: input_tensor = albert_modules.layer_norm(input_tensor) if embedding_projection is not None: # batch x seq x hidden, embedding x hidden print(input_tensor.get_shape(), embedding_projection.get_shape()) input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection) else: print("==no need for embedding projection==") input_tensor = input_tensor output_bias = tf.get_variable("output_bias", shape=[config.vocab_size], initializer=tf.zeros_initializer()) # batch x seq x embedding logits = tf.einsum("abc,dc->abd", input_tensor, output_weights) logits = tf.nn.bias_add(logits, output_bias) input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3) width = input_shape_list[2] # it seems no need for logits to be normalized logits_tempered = logits #tf.nn.log_softmax(logits, axis=-1) # width=config.vocab_size flat_logits_tempered = tf.reshape(logits_tempered, [batch_size * seq_length, width]) num_train_steps = kargs.get('num_train_steps', None) if num_train_steps and kargs.get('gumbel_anneal', "anneal") == 'anneal': tf.logging.info("****** apply annealed temperature ******* %s", str(num_train_steps)) temperature_warmup_steps = int(num_train_steps) * 0.1 annealed_temp = tf.train.polynomial_decay( config.get('gumbel_temperature', 1.0), tf.train.get_or_create_global_step(), temperature_warmup_steps, end_learning_rate=0.01, power=1.0, cycle=False) gumbel_samples = None if not kargs.get('use_tpu', True): tf.summary.scalar('annealed_temp', annealed_temp) elif kargs.get('gumbel_anneal', "anneal") == 'softplus': tf.logging.info("****** apply auto-scale temperature *******") # batch x seq x dim with tf.variable_scope("gumbel_auto_scaling_temperature"): annealed_temp = tf.layers.dense( input_tensor, 1, activation=tf.nn.softplus, ) + 1.0 
annealed_temp = 1. / annealed_temp annealed_temp = tf.reshape(annealed_temp, [batch_size * seq_length, 1]) if not kargs.get('use_tpu', True): tf.summary.scalar('softplus temperature', tf.reduce_mean(annealed_temp)) if config.get('gen_sample', 1) > 1: tf.logging.info( "****** apply auto-scale temperature for multi-sampling *******" ) annealed_temp = tf.expand_dims(annealed_temp, -1) gumbel_samples = None elif kargs.get('gumbel_anneal', 'vqvae') == 'vqvae': temperature_warmup_steps = kargs.get("num_train_steps", 10000) * 0.1 tf.logging.info( "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ", str(kargs.get("num_train_steps", 10000) * 0.1)) steps = temperature_warmup_steps gumbel_samples = sample_gumbel(bert_utils.get_shape_list( flat_logits_tempered, expected_rank=2), samples=config.get('gen_sample', 1)) gumbel_samples *= inverse_exp_decay(steps // 5) * 0.5 annealed_temp_decay = 1.01 - inverse_lin_decay( steps) # minimum temperature is set 0.2 annealed_temp = tf.cond( tf.less(tf.random_uniform([]), 0.9), lambda: annealed_temp_decay, lambda: tf.random_uniform( [], minval=0.5, maxval=1.0)) # 10% step for tf.logging.info( "****** apply t2t gumbel-softmax temperature annealing method ******* " ) if not kargs.get('use_tpu', True): tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp) tf.summary.scalar('t2t_vqvae_stgs temperature decay', annealed_temp_decay) elif kargs.get('gumbel_anneal', 'vqvae_v1') == 'vqvae_v1': temperature_warmup_steps = kargs.get("num_train_steps", 10000) * 0.1 tf.logging.info( "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ", str(kargs.get("num_train_steps", 10000) * 0.1)) steps = temperature_warmup_steps gumbel_samples = sample_gumbel(bert_utils.get_shape_list( flat_logits_tempered, expected_rank=2), samples=config.get('gen_sample', 1)) # gumbel_samples *= inverse_exp_decay(steps) annealed_temp_decay = 1.01 - inverse_exp_decay( kargs.get("num_train_steps", 10000)) # minimum temperature is set 0.2 annealed_temp = annealed_temp_decay # annealed_temp = tf.cond( # tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay, # lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) # 10% step for tf.logging.info( "****** apply sel-gan gumbel-softmax temperature annealing method ******* " ) if not kargs.get('use_tpu', True): tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp) tf.summary.scalar('t2t_vqvae_stgs temperature decay', annealed_temp_decay) elif kargs.get('gumbel_anneal', 'vqvae_v2') == 'vqvae_v2': temperature_warmup_steps = kargs.get("num_train_steps", 10000) * 0.1 tf.logging.info( "****** apply t2t gumbel-softmax temperature annealing method with warm up steps %s ******* ", str(kargs.get("num_train_steps", 10000) * 0.1)) steps = temperature_warmup_steps gumbel_samples = sample_gumbel(bert_utils.get_shape_list( flat_logits_tempered, expected_rank=2), samples=config.get('gen_sample', 1)) # gumbel_samples *= inverse_exp_decay(steps) annealed_temp_decay = inverse_temp_exp_decay( kargs.get("num_train_steps", 10000), kargs.get("max_temp", 100)) annealed_temp = annealed_temp_decay # annealed_temp = tf.cond( # tf.less(tf.random_uniform([]), 0.95), lambda: annealed_temp_decay, # lambda: tf.random_uniform([], minval=0.5, maxval=1.0)) # 10% step for tf.logging.info( "****** apply sel-gan-v2 gumbel-softmax temperature annealing method ******* " ) tf.logging.info( "****** apply sel-gan-v2 gumbel-softmax num_train_steps:%s annealing method, temp:%s ******* ", 
str(kargs.get("num_train_steps", 10000)), str(kargs.get("max_temp", 100))) if not kargs.get('use_tpu', True): tf.summary.scalar('t2t_vqvae_stgs temperature', annealed_temp) tf.summary.scalar('t2t_vqvae_stgs temperature decay', annealed_temp_decay) else: annealed_temp = 0.01 gumbel_samples = None tf.logging.info( "****** not apply annealed tenperature with fixed temp ******* %s", str(annealed_temp)) if not kargs.get('use_tpu', True): tf.summary.scalar('gumbel_temperature', annealed_temp) # [batch x seq] x config.vocab_size x config.get('gen_sample', 1) if kargs.get('stable_gradient', True): sampled_logprob_temp, sampled_logprob = gumbel_softmax( flat_logits_tempered, temperature=annealed_temp, gumbel_samples=gumbel_samples, samples=config.get('gen_sample', 1), greedy=kargs.get("greedy", False)) tf.logging.info( "****** apply normal derivate for gradient calculation *******" ) else: sampled_logprob_temp, sampled_logprob = gumbel_softmax_custom_grad( flat_logits_tempered, temperature=annealed_temp, gumbel_samples=gumbel_samples, samples=config.get('gen_sample', 1)) tf.logging.info( "****** apply log deriviate for stable gradient calculation *******" ) # argmax on config.vocab_size which is always axis=1 # [batch x seq] x config.vocab_size x config.get('gen_sample', 1) # armax(logits+gumbel_samples) to sample a categoritical distribution if kargs.get('sampled_prob_id', False): tf.logging.info( "****** apply categorical sampled id of original logits *******" ) sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1), config.vocab_size, axis=1) # sampled multiminal id else: tf.logging.info( "****** apply gumbel-softmax logprob for logits *******") sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1), config.vocab_size, axis=1) # sampled multiminal id # straight-through gumbel softmax estimator if kargs.get('if_flip_grad', True): tf.logging.info("****** apply gradient flipping *******") sampled_logprob_temp_1 = flip_gradient(sampled_logprob_temp) else: tf.logging.info("****** not apply gradient flipping *******") sampled_logprob_temp_1 = sampled_logprob_temp if kargs.get("straight_through", True): tf.logging.info("****** apply straight_through_estimator *******") sampled_id = tf.stop_gradient(sampled_hard_id - sampled_logprob_temp) + ( sampled_logprob_temp_1) else: tf.logging.info("****** apply gumbel-softmax probs *******") sampled_id = sampled_logprob_temp_1 sampled_binary_mask = kargs.get('sampled_binary_mask', None) if sampled_binary_mask is not None: label_diff_ids = tf.identity( sampled_binary_mask) # 0 for original and 1 for replace else: label_diff_ids = tf.not_equal( tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32) # 0 for original and 1 for replace ) label_diff_ids = tf.cast(label_diff_ids, tf.float32) label_diff_ids = tf.expand_dims(label_diff_ids, axis=[-1]) # batch x seq x 1 input_ori_ids_1 = input_ori_ids input_ori_ids = tf.one_hot(input_ori_ids, config.vocab_size) # batch x seq x vocab input_ori_ids = tf.cast(input_ori_ids, tf.float32) if config.get('gen_sample', 1) == 1: sampled_input_id = tf.reshape( sampled_id, [batch_size, seq_length, config.vocab_size]) if kargs.get('mask_method', 'only_mask') == 'only_mask': tf.logging.info("****** only mask sample *******") label_diff_ids = tf.cast(label_diff_ids, tf.float32) sampled_input_id = (label_diff_ids) * tf.cast( sampled_input_id, tf.float32 ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32) elif kargs.get('mask_method', 'only_mask') == 'all_mask': unk_mask = 
tf.cast(tf.math.equal(input_ori_ids_1, 100), tf.float32) # not replace unk cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101), tf.float32) # not replace cls sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102), tf.float32) # not replace sep unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast( input_mask, tf.float32) unsampled_mask = tf.expand_dims(unsampled_mask, axis=[-1]) # batch x seq x 1 ori_input_mask = tf.expand_dims(input_mask, axis=[-1]) # batch x seq x 1 tf.logging.info("****** all mask sample *******") sampled_input_id = unsampled_mask * tf.cast( sampled_input_id, tf.float32) + (1 - unsampled_mask) * tf.cast( ori_input_mask, tf.float32) * tf.cast( input_ori_ids, tf.float32) else: sampled_input_id = tf.reshape(samples, [ batch_size, seq_length, config.vocab_size, config.get('gen_sample', 1) ]) label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1) # batch x seq x 1 input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1) # batch x seq x vocab x 1 if kargs.get('mask_method', 'only_mask') == 'only_mask': tf.logging.info("****** only mask sample *******") sampled_input_id = (label_diff_ids) * tf.cast( sampled_input_id, tf.float32) + (1 - input_ori_ids) * label_diff_ids tf.logging.info("====generator use_tpu %s ====", str(kargs.get('use_tpu', True))) if not kargs.get('use_tpu', True): tf.logging.info("====logging generator loss ====") sampled_not_equal_id = tf.not_equal( tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32), tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32)) sampled_equal_id = tf.equal( tf.cast(tf.argmax(sampled_input_id, axis=2), tf.int32), tf.cast(tf.argmax(input_ori_ids, axis=2), tf.int32)) sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast( input_mask, tf.float32) if kargs.get('mask_method', 'only_mask') == 'only_mask': sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / ( 1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32))) label_diff_ids_my = tf.cast(label_diff_ids, tf.float32) elif kargs.get('mask_method', 'only_mask') == 'all_mask': sampled_equal = tf.cast( sampled_equal_id, tf.float32) * tf.cast( tf.squeeze(unsampled_mask, axis=-1), tf.float32) tf.summary.scalar('generator_equal_sample_acc', tf.reduce_sum(sampled_equal)) label_diff_ids_my = tf.cast(unsampled_mask, tf.float32) sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / ( 1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32))) sampled_equal = tf.reduce_sum(sampled_equal) / ( 1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32))) tf.summary.scalar('generator_sample_acc', sampled_not_equal) tf.summary.scalar("generator_valid_token", tf.reduce_sum(input_mask)) sampled_hard_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1), config.vocab_size, axis=1) # sampled multiminal id sampled_hard_id = tf.cast(sampled_hard_id, tf.float32) sampled_hard_id = tf.reshape( sampled_hard_id, [batch_size, seq_length, config.vocab_size]) sampled_soft_id = tf.reshape( sampled_id, [batch_size, seq_length, config.vocab_size]) sampled_hard_id *= label_diff_ids_my sampled_soft_id *= label_diff_ids_my hard_soft_bias = tf.reduce_sum( tf.sqrt( tf.reduce_sum(tf.pow(sampled_hard_id - sampled_soft_id, 2), axis=-1))) / (1e-10 + tf.reduce_sum( tf.cast(label_diff_ids_my, tf.float32))) tf.summary.scalar('soft_hard_bias', hard_soft_bias) return sampled_input_id
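# `inverse_exp_decay` / `inverse_lin_decay` / `inverse_temp_exp_decay` used by the annealing
# branches above come from elsewhere in the repo; the first two mirror tensor2tensor's
# common_layers helpers. A rough sketch of what such schedules compute (an assumption, not the
# repo's exact code), given a global step:
def _inverse_exp_decay_sketch(max_step, min_value=0.01):
    """Rises exponentially from min_value toward 1.0, reaching 1.0 at max_step."""
    step = tf.to_float(tf.train.get_or_create_global_step())
    inv_base = tf.exp(tf.log(min_value) / float(max_step))
    return inv_base ** tf.maximum(float(max_step) - step, 0.0)

def _inverse_lin_decay_sketch(max_step, min_value=0.01):
    """Rises linearly from min_value to 1.0, reaching 1.0 at max_step."""
    step = tf.to_float(tf.train.get_or_create_global_step())
    progress = tf.minimum(step / float(max_step), 1.0)
    return progress * (1.0 - min_value) + min_value
# e.g. the vqvae branch above uses roughly: annealed_temp = 1.01 - inverse_lin_decay(steps)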
def classifier(config, seq_output, input_ids, sampled_ids, input_mask, num_labels, dropout_prob, **kargs):
    """
    input_ids: original input ids
    sampled_ids: generated fake ids
    """
    output_layer = seq_output
    hidden_size = output_layer.shape[-1].value

    unk_mask = tf.cast(tf.math.equal(input_ids, 100), tf.float32)  # not replace unk
    cls_mask = tf.cast(tf.math.equal(input_ids, 101), tf.float32)  # not replace cls
    sep_mask = tf.cast(tf.math.equal(input_ids, 102), tf.float32)  # not replace sep
    none_replace_mask = unk_mask + cls_mask + sep_mask

    input_mask = tf.cast(input_mask, tf.int32)
    # cls, unk, sep are not considered as replace or original
    input_mask *= tf.cast(1 - none_replace_mask, tf.int32)

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    if config.get('ln_type', 'postln') == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
        print('====preln transformer====')
    elif config.get('ln_type', 'postln') == 'postln':
        output_layer = output_layer
        print('====postln transformer====')
    else:
        output_layer = output_layer
        print('====no layer layer_norm====')

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    logits = tf.einsum("abc,dc->abd", output_layer, output_weights)
    logits = tf.nn.bias_add(logits, output_bias)  # batch x seq_length x 2

    input_ids = tf.cast(input_ids, tf.int32)

    input_shape_list = bert_utils.get_shape_list(sampled_ids, expected_rank=[2, 3])
    if len(input_shape_list) == 3:
        tmp_sampled_ids = tf.argmax(sampled_ids, axis=-1)  # batch x seq x vocab
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** gumbel 3-D sampled_ids *******")
    elif len(input_shape_list) == 2:
        tmp_sampled_ids = sampled_ids
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** normal 2-D sampled_ids *******")

    ori_sampled_ids = kargs.get('ori_sampled_ids', None)
    if ori_sampled_ids is not None:
        input_shape_list = bert_utils.get_shape_list(ori_sampled_ids, expected_rank=[2, 3])
        if len(input_shape_list) == 3:
            tmp_ori_sampled_ids = tf.argmax(ori_sampled_ids, axis=-1)  # batch x seq x vocab
            # the original cast an undefined `tmp_sampled_ori_ids` here
            tmp_ori_sampled_ids = tf.cast(tmp_ori_sampled_ids, tf.int32)
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif len(input_shape_list) == 2:
            tmp_ori_sampled_ids = tf.cast(ori_sampled_ids, tf.int32)
            tf.logging.info("****** normal 2-D sampled_ids *******")

        masked_not_equal_mask = tf.cast(tf.not_equal(input_ids, tmp_ori_sampled_ids), tf.int32)
        masked_not_equal_mask *= tf.cast(input_mask, tf.int32)
    else:
        masked_not_equal_mask = None

    if masked_not_equal_mask is not None:
        tf.logging.info("****** loss mask using masked token mask for masked tokens *******")
        loss_mask = masked_not_equal_mask
    else:
        tf.logging.info("****** loss mask using input_mask for all tokens *******")
        loss_mask = input_mask

    # original: 0, replace: 1
    not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids), tf.int32)
    not_equal_label_ids *= tf.cast(input_mask, tf.int32)

    if kargs.get('loss', 'cross_entropy') == 'cross_entropy':
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.stop_gradient(not_equal_label_ids))
    elif kargs.get('loss', 'cross_entropy') == 'focal_loss':
        input_shape_list = bert_utils.get_shape_list(input_ids, expected_rank=2)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        not_equal_label_ids_ = tf.reshape(not_equal_label_ids, [batch_size * seq_length])
        logits_ = tf.reshape(logits, [batch_size * seq_length, -1])
        per_example_loss, _ = loss_utils.focal_loss_binary_v2(config, logits_, not_equal_label_ids_)
        per_example_loss = tf.reshape(per_example_loss, [batch_size, seq_length])

    # loss = per_example_loss * tf.cast(loss_mask, tf.float32)
    # loss = tf.reduce_sum(loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    equal_label_ids = (1 - tf.cast(not_equal_label_ids, tf.float32)) * tf.cast(loss_mask, tf.float32)
    equal_loss = tf.reduce_sum(per_example_loss * equal_label_ids)
    equal_loss_output = equal_loss / (1e-10 + tf.reduce_sum(equal_label_ids))

    # not equal: 1, equal: 0
    not_equal_loss = tf.reduce_sum(per_example_loss * tf.cast(not_equal_label_ids, tf.float32))
    not_equal_loss_output = not_equal_loss / (
        1e-10 + tf.reduce_sum(tf.cast(not_equal_label_ids, tf.float32)))

    loss = (equal_loss + not_equal_loss) / (1e-10 + tf.reduce_sum(tf.cast(loss_mask, tf.float32)))

    tf.logging.info("====discriminator classifier use_tpu %s ====", str(kargs.get('use_tpu', True)))
    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator loss ====")
        tf.summary.scalar('mask_based_loss', loss)
        tf.summary.scalar(
            'equal_loss',
            equal_loss / (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))
        tf.summary.scalar(
            'not_equal_loss',
            not_equal_loss / (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))
        tf.summary.scalar(
            'loss_decomposition',
            loss - (equal_loss + not_equal_loss) /
            (1e-10 + tf.reduce_sum(tf.cast(input_mask, tf.float32))))

    return (loss, logits, per_example_loss)
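# `loss_utils.focal_loss_binary_v2` is project-specific and not reproduced here; a generic
# per-token focal loss over two classes (Lin et al., 2017) could be sketched as follows.
# The gamma/alpha defaults are illustrative assumptions, not the repo's values.
def _focal_loss_binary_sketch(logits, labels, gamma=2.0, alpha=0.25):
    """logits: [N, 2]; labels: [N] in {0, 1}. Returns the per-example focal loss [N]."""
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, labels=tf.stop_gradient(labels))   # -log p_t
    p_t = tf.exp(-ce)                                      # probability of the true class
    alpha_t = tf.where(tf.equal(labels, 1),
                       alpha * tf.ones_like(ce),
                       (1.0 - alpha) * tf.ones_like(ce))
    # down-weight easy examples (p_t close to 1), keep hard examples
    return alpha_t * tf.pow(1.0 - p_t, gamma) * ce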
def get_masked_lm_output(config, input_tensor, output_weights, positions, label_ids, label_weights, **kargs): reuse = kargs.get('reuse', False) embedding_projection = kargs.get('embedding_projection', None) """Get loss and log probs for the masked LM.""" input_tensor = tf.cast(input_tensor, tf.float32) positions = tf.cast(positions, tf.int32) label_ids = tf.cast(label_ids, tf.int32) label_weights = tf.cast(label_weights, tf.float32) input_tensor = bert_utils.gather_indexes(input_tensor, positions) """ flatten masked lm ids with positions """ scope = kargs.get('scope', None) if scope: scope = scope + '/' + 'cls/predictions' else: scope = 'cls/predictions' tf.logging.info("**** mlm scope **** %s", str(scope)) # with tf.variable_scope("cls/predictions", reuse=reuse): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): # We apply one more non-linear transformation before the output layer. # This matrix is not used after pre-training. if config.get('ln_type', 'postln') == 'preln': input_tensor = albert_modules.layer_norm(input_tensor) elif config.get('ln_type', 'postln') == 'postln': input_tensor = input_tensor else: input_tensor = input_tensor if config.get("embedding", "factorized") == "factorized": projection_width = config.hidden_size else: projection_width = config.embedding_size with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=projection_width, activation=albert_modules.get_activation(config.hidden_act), kernel_initializer=albert_modules.create_initializer( config.initializer_range)) if config.get('ln_type', 'postln') == 'preln': input_tensor = input_tensor elif config.get('ln_type', 'postln') == 'postln': input_tensor = albert_modules.layer_norm(input_tensor) else: input_tensor = albert_modules.layer_norm(input_tensor) if embedding_projection is not None: input_tensor = tf.matmul(input_tensor, embedding_projection, transpose_b=True) else: print("==no need for embedding projection==") input_tensor = input_tensor # The output weights are the same as the input embeddings, but there is # an output-only bias for each token. output_bias = tf.get_variable( "output_bias", shape=[config.vocab_size], initializer=tf.zeros_initializer()) logits = tf.matmul(input_tensor, output_weights, transpose_b=True) # logits = tf.multiply(logits, # 1.0 / math.sqrt(float(config.hidden_size))) # logits *= 2 logits = tf.nn.bias_add(logits, output_bias) log_probs = tf.nn.log_softmax(logits, axis=-1) label_ids = tf.reshape(label_ids, [-1]) label_weights = tf.reshape(label_weights, [-1]) # one_hot_labels = tf.one_hot( # label_ids, depth=config.vocab_size, dtype=tf.float32) per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=tf.stop_gradient(label_ids), logits=logits) # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) numerator = tf.reduce_sum(label_weights * per_example_loss) denominator = tf.reduce_sum(label_weights) + 1e-5 # The `positions` tensor might be zero-padded (if the sequence is too # short to have the maximum number of predictions). The `label_weights` # tensor has a value of 1.0 for every real prediction and 0.0 for the # padding predictions. # per_example_loss = -tf.reduce_sum(log_probs * one_hot_labels, axis=[-1]) # numerator = tf.reduce_sum(label_weights * per_example_loss) # denominator = tf.reduce_sum(label_weights) + 1e-5 loss = numerator / denominator return (loss, per_example_loss, log_probs, label_weights)
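# `bert_utils.gather_indexes` is assumed to follow the helper of the same name in
# google-research/bert: flatten the batch and gather the hidden states at the masked
# positions. A sketch (static shapes assumed for brevity):
def _gather_indexes_sketch(sequence_tensor, positions):
    """sequence_tensor: [batch, seq, width]; positions: [batch, num_predict].
    Returns the gathered vectors as [batch * num_predict, width]."""
    batch_size, seq_length, width = sequence_tensor.get_shape().as_list()
    flat_offsets = tf.reshape(tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor, [batch_size * seq_length, width])
    return tf.gather(flat_sequence, flat_positions)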
def transformer_cell(input_tensor,
                     attention_mask=None,
                     hidden_size=768,
                     num_hidden_layers=12,
                     num_attention_heads=12,
                     intermediate_size=3072,
                     intermediate_act_fn=gelu,
                     hidden_dropout_prob=0.1,
                     attention_probs_dropout_prob=0.1,
                     initializer_range=0.02,
                     do_return_all_layers=False,
                     shared_type=None,
                     adapter_fn=None,
                     layer_idx=0):
    # NOTE: the original snippet referenced `attention_head_size`, `batch_size`, `seq_length`,
    # `all_attention_scores` and `layer_idx` without defining them and had no return statement;
    # they are derived / defaulted here so the cell is self-contained.
    attention_head_size = hidden_size // num_attention_heads
    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    all_attention_scores = []

    layer_input = bert_utils.reshape_to_matrix(input_tensor)

    with tf.variable_scope("layer_shared", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("attention", reuse=tf.AUTO_REUSE):
            attention_heads = []
            with tf.variable_scope("self"):
                [attention_head, attention_scores] = albert_modules.attention_layer(
                    from_tensor=layer_input,
                    to_tensor=layer_input,
                    attention_mask=attention_mask,
                    num_attention_heads=num_attention_heads,
                    size_per_head=attention_head_size,
                    attention_probs_dropout_prob=attention_probs_dropout_prob,
                    initializer_range=initializer_range,
                    do_return_2d_tensor=True,
                    batch_size=batch_size,
                    from_seq_length=seq_length,
                    to_seq_length=seq_length)
                attention_heads.append(attention_head)
                all_attention_scores.append(attention_scores)

            attention_output = None
            if len(attention_heads) == 1:
                attention_output = attention_heads[0]
            else:
                # In the case where we have other sequences, we just concatenate
                # them to the self-attention head before the projection.
                attention_output = tf.concat(attention_heads, axis=-1)

            # Run a linear projection of `hidden_size` then add a residual
            # with `layer_input`.
            with tf.variable_scope("output", reuse=tf.AUTO_REUSE):
                attention_output = tf.layers.dense(
                    attention_output,
                    hidden_size,
                    kernel_initializer=albert_modules.create_initializer(initializer_range))
                attention_output = albert_modules.dropout(attention_output, hidden_dropout_prob)
                if adapter_fn:
                    attention_output = adapter_fn(attention_output, layer_idx=layer_idx)
                attention_output = albert_modules.layer_norm(attention_output + layer_input)

        # The activation is only applied to the "intermediate" hidden layer.
        with tf.variable_scope('intermediate', reuse=tf.AUTO_REUSE):
            intermediate_output = tf.layers.dense(
                attention_output,
                intermediate_size,
                activation=intermediate_act_fn,
                kernel_initializer=albert_modules.create_initializer(initializer_range))

        # Down-project back to `hidden_size` then add the residual.
        with tf.variable_scope('output', reuse=tf.AUTO_REUSE):
            layer_output = tf.layers.dense(
                intermediate_output,
                hidden_size,
                kernel_initializer=albert_modules.create_initializer(initializer_range))
            layer_output = albert_modules.dropout(layer_output, hidden_dropout_prob)
            if adapter_fn:
                # apply the adapter to the FFN output (the original applied it to
                # `attention_output`, which would discard the intermediate projection)
                layer_output = adapter_fn(layer_output, layer_idx=layer_idx)
            layer_output = albert_modules.layer_norm(layer_output + attention_output)

    # assumed return value: the 2-D [batch * seq, hidden] output of the shared cell
    return layer_output
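# A hedged usage sketch (not from the repo): because every variable above lives under
# variable_scope("layer_shared", reuse=tf.AUTO_REUSE), applying this single cell repeatedly
# yields ALBERT-style cross-layer parameter sharing. Shapes are assumed static for brevity.
def _shared_cell_usage_sketch(embedding_output, attention_mask,
                              batch_size, seq_length,
                              hidden_size=768, num_layers=12):
    prev_output = embedding_output  # [batch_size, seq_length, hidden_size]
    for _ in range(num_layers):
        out_2d = transformer_cell(prev_output,
                                  attention_mask=attention_mask,
                                  hidden_size=hidden_size,
                                  num_attention_heads=12,
                                  intermediate_size=3072)
        # the cell returns a [batch * seq, hidden] matrix; restore the sequence layout
        prev_output = tf.reshape(out_2d, [batch_size, seq_length, hidden_size])
    return prev_output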
def classifier(config, pooled_output, num_labels, labels, dropout_prob, ratio_weight=None, **kargs):
    output_layer = pooled_output
    hidden_size = output_layer.shape[-1].value

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable(
        "output_bias", [num_labels], initializer=tf.zeros_initializer())

    if config.get('ln_type', 'postln') == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
    elif config.get('ln_type', 'postln') == 'postln':
        output_layer = output_layer
    else:
        output_layer = output_layer

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)

    if config.get("label_type", "single_label") == "single_label":
        if config.get("loss", "entropy") == "entropy":
            print("==standard cross entropy==")
            tf.logging.info("****** loss type ******* %s", "entropy")
            per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                logits=logits, labels=tf.stop_gradient(labels))
        elif config.get("loss", "entropy") == "focal_loss":
            print("==single_label focal loss==")
            per_example_loss, _ = loss_utils.focal_loss_multi_v1(config, logits=logits, labels=labels)
        elif config.get("loss", "entropy") == "dmi_loss":
            tf.logging.info("****** loss type ******* %s", "dmi_loss")
            loss, per_example_loss = loss_utils.dmi_loss(config, logits=logits, labels=labels, **kargs)

        try:
            per_example_loss = loss_utils.weighted_loss_ratio(config, per_example_loss, labels, ratio_weight)
            loss = tf.reduce_sum(per_example_loss)
            print(" == applying weighted loss == ")
        except Exception:  # fall back to the unweighted loss (was a bare `except:`)
            if config.get("loss", "entropy") in ["entropy", "focal_loss"]:
                loss = tf.reduce_mean(per_example_loss)
            elif config.get("loss", "entropy") == "dmi_loss":
                tf.logging.info("****** dmi loss need no further calculation ******* ")
                loss = loss

        if config.get("with_center_loss", "no") == "center_loss":
            print("==apply with center loss==")
            center_loss, _ = loss_utils.center_loss_v2(config, features=pooled_output, labels=labels)
            loss += center_loss * config.get("center_loss_coef", 1e-3)

        return (loss, per_example_loss, logits)

    elif config.get("label_type", "single_label") == "multi_label":
        # logits = tf.log_sigmoid(logits)
        per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(
            logits=logits, labels=tf.stop_gradient(tf.cast(labels, tf.float32)))
        per_example_loss = tf.reduce_sum(per_example_loss, axis=-1)
        loss = tf.reduce_mean(per_example_loss)
        return (loss, per_example_loss, logits)
    else:
        raise NotImplementedError()
def token_generator(config, input_tensor, output_weights, input_ids, input_ori_ids, input_mask, **kargs): input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3) batch_size = input_shape_list[0] seq_length = input_shape_list[1] hidden_dims = input_shape_list[2] embedding_projection = kargs.get('embedding_projection', None) scope = kargs.get('scope', None) if scope: scope = scope + '/' + 'cls/predictions' else: scope = 'cls/predictions' tf.logging.info("**** mlm generator scope **** %s", str(scope)) # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): if config.get('ln_type', 'postln') == 'preln': input_tensor = albert_modules.layer_norm(input_tensor) elif config.get('ln_type', 'postln') == 'postln': input_tensor = input_tensor else: input_tensor = input_tensor # if config.get("embedding", "factorized") == "factorized": # projection_width = config.hidden_size # else: # projection_width = config.embedding_size if config.get("embedding", "none_factorized") == "none_factorized": projection_width = config.hidden_size tf.logging.info("==not using embedding factorized==") else: projection_width = config.get('embedding_size', config.hidden_size) tf.logging.info( "==using embedding factorized: embedding size: %s==", str(projection_width)) with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=projection_width, activation=albert_modules.get_activation(config.hidden_act), kernel_initializer=albert_modules.create_initializer( config.initializer_range)) if config.get('ln_type', 'postln') == 'preln': input_tensor = input_tensor elif config.get('ln_type', 'postln') == 'postln': input_tensor = albert_modules.layer_norm(input_tensor) else: input_tensor = albert_modules.layer_norm(input_tensor) if embedding_projection is not None: # batch x seq x hidden, embedding x hidden print(input_tensor.get_shape(), embedding_projection.get_shape()) input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection) else: print("==no need for embedding projection==") input_tensor = input_tensor output_bias = tf.get_variable("output_bias", shape=[config.vocab_size], initializer=tf.zeros_initializer()) # batch x seq x embedding logits = tf.einsum("abc,dc->abd", input_tensor, output_weights) logits = tf.nn.bias_add(logits, output_bias) input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3) width = input_shape_list[2] if not kargs.get("apply_valid_vocab", False): logits = logits tf.logging.info("****** normal logits *******") elif kargs.get("apply_valid_vocab", False) == 'topk': prob, _ = top_k_softmax(logits, kargs.get('topk', 10)) logits = tf.log(prob + 1e-10) tf.logging.info("****** topk logits *******") else: invalid_size = kargs.get("invalid_size", 106) invalid_mask = tf.cast( tf.ones((1, invalid_size)) * (-10000), tf.float32) valid_mask = tf.cast( tf.zeros((1, config.vocab_size - invalid_size)), tf.float32) invaild_mask = tf.concat([invalid_mask, valid_mask], axis=-1) # invaild_mask = tf.expand_dims(invaild_mask, axis=1) # batch x seq x vocab logits += tf.cast(invaild_mask, tf.float32) tf.logging.info( "****** only valid logits ******* , invalid size: %s", str(invalid_size)) logits_tempered = tf.nn.log_softmax(logits / config.get("temperature", 1.0)) flat_logits_tempered = tf.reshape(logits_tempered, [batch_size * seq_length, width]) # flat_logits_tempered_topk = top_k_logits(flat_logits_tempered, int(config.vocab_size/2 if not kargs.get("greedy", False): sampled_logprob_temp, 
sampled_logprob = gumbel_softmax( flat_logits_tempered, temperature=1.0, samples=config.get('gen_sample', 1), greedy=kargs.get("greedy", False)) samples = tf.argmax(sampled_logprob, axis=1) # batch x seq tf.logging.info("****** normal sample *******") else: samples = tf.argmax(flat_logits_tempered, axis=-1) tf.logging.info("****** greedy sample *******") # samples = tf.multinomial(flat_logits_tempered, # num_samples=config.get('gen_sample', 1), # output_dtype=tf.int32) sampled_binary_mask = kargs.get('sampled_binary_mask', None) if sampled_binary_mask is not None: label_diff_ids = sampled_binary_mask # 0 for original and 1 for replace else: label_diff_ids = tf.not_equal( tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32) # 0 for original and 1 for replace ) label_diff_ids = tf.cast(label_diff_ids, tf.float32) print(label_diff_ids, "===label diff ids===") if not kargs.get('use_tpu', True): tf.summary.scalar( 'label_diff_ids', tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) / tf.reduce_sum(tf.cast(input_mask, tf.float32))) if config.get('gen_sample', 1) == 1: sampled_input_id = tf.reshape(samples, [batch_size, seq_length]) if kargs.get('mask_method', 'only_mask') == 'only_mask': tf.logging.info("****** only mask sample *******") label_diff_ids = tf.cast(label_diff_ids, tf.float32) sampled_input_id = (label_diff_ids) * tf.cast( sampled_input_id, tf.float32 ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32) sampled_input_id = tf.cast(sampled_input_id, tf.int32) elif kargs.get('mask_method', 'only_mask') == 'all_mask': input_ori_ids_1 = input_ori_ids unk_mask = tf.cast(tf.math.equal(input_ori_ids_1, 100), tf.float32) # not replace unk cls_mask = tf.cast(tf.math.equal(input_ori_ids_1, 101), tf.float32) # not replace cls sep_mask = tf.cast(tf.math.equal(input_ori_ids_1, 102), tf.float32) # not replace sep unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast( input_mask, tf.float32) # unsampled_mask = tf.expand_dims(unsampled_mask, axis=[-1]) # batch x seq x 1 tf.logging.info("****** all mask sample *******") sampled_input_id = unsampled_mask * tf.cast( sampled_input_id, tf.float32 ) + (1 - unsampled_mask) * tf.cast(input_ori_ids, tf.float32) sampled_input_id = tf.cast(sampled_input_id, tf.int32) else: sampled_input_id = tf.reshape( samples, [batch_size, seq_length, config.get('gen_sample', 1)]) if kargs.get('mask_method', 'only_mask') == 'only_mask': tf.logging.info("****** only mask sample *******") # batch x seq_length x 1 label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1) label_diff_ids = tf.einsum( 'abc,cd->abd', label_diff_ids, tf.ones((1, model_config.get('gen_sample', 1)))) # batch x seq_length x 1 input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1) input_ori_ids = tf.einsum( 'abc,cd->abd', input_ori_ids, tf.ones((1, model_config.get('gen_sample', 1)))) input_ori_ids = tf.cast(input_ori_ids, tf.float32) sampled_input_id = (label_diff_ids) * tf.cast( sampled_input_id, tf.float32) + (1 - input_ori_ids) * label_diff_ids sampled_input_id = tf.cast(sampled_input_id, tf.int32) input_mask = tf.expand_dims(input_mask, axis=-1) input_mask = tf.einsum( 'abc,cd->abd', input_mask, tf.ones((1, model_config.get('gen_sample', 1)))) input_mask = tf.cast(input_mask, tf.float32) if not kargs.get('use_tpu', True): sampled_not_equal_id = tf.not_equal( tf.cast(sampled_input_id, tf.int32), tf.cast(input_ori_ids, tf.int32)) sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast( input_mask, tf.float32) sampled_equal_id = 
tf.equal(tf.cast(sampled_input_id, tf.int32), tf.cast(input_ori_ids, tf.int32)) if kargs.get('mask_method', 'only_mask') == 'only_mask': sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / ( 1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32))) elif kargs.get('mask_method', 'only_mask') == 'all_mask': sampled_equal = tf.cast(sampled_equal_id, tf.float32) * tf.cast( unsampled_mask, tf.float32) tf.summary.scalar('generator_equal_sample_acc', tf.reduce_sum(sampled_equal)) sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / ( 1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32))) sampled_equal = tf.reduce_sum(sampled_equal) / ( 1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32))) tf.summary.scalar('generator_sample_acc', sampled_not_equal) # sampled_not_equal_id = tf.not_equal( # tf.cast(sampled_input_id, tf.int32), # tf.cast(input_ori_ids, tf.int32) # ) # sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast(input_mask, tf.float32) # sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32))) # if not kargs.get('use_tpu', True): # tf.summary.scalar('generator_sample_acc', # sampled_not_equal) return sampled_input_id
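# `top_k_logits` and `top_k_softmax`, used by the samplers in this file, are defined elsewhere
# in the repo. A common top-k truncation, in the style of the GPT-2 sampling code (a sketch,
# not necessarily the repo's implementation):
def _top_k_logits_sketch(logits, k):
    """Keep the k largest logits per row and push the rest to -1e10 so that
    tf.multinomial effectively never samples them. logits: [N, vocab]."""
    if k == 0:
        return logits  # no truncation
    values, _ = tf.nn.top_k(logits, k=k)
    min_values = values[:, -1, tf.newaxis]  # k-th largest value in each row
    return tf.where(logits < min_values,
                    tf.ones_like(logits) * -1e10,
                    logits)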
def seq_mask_masked_lm_output(config, input_tensor, output_weights, input_mask, input_ori_ids, input_ids, sampled_binary_mask, **kargs): input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3) batch_size = input_shape_list[0] seq_length = input_shape_list[1] hidden_dims = input_shape_list[2] embedding_projection = kargs.get('embedding_projection', None) scope = kargs.get('scope', None) if scope: scope = scope + '/' + 'cls/predictions' else: scope = 'cls/predictions' tf.logging.info("**** mlm generator scope **** %s", str(scope)) # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE): with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): if config.get('ln_type', 'postln') == 'preln': input_tensor = albert_modules.layer_norm(input_tensor) elif config.get('ln_type', 'postln') == 'postln': input_tensor = input_tensor else: input_tensor = input_tensor if config.get("embedding", "factorized") == "factorized": projection_width = config.hidden_size else: projection_width = config.embedding_size with tf.variable_scope("transform"): input_tensor = tf.layers.dense( input_tensor, units=projection_width, activation=albert_modules.get_activation(config.hidden_act), kernel_initializer=albert_modules.create_initializer( config.initializer_range)) if config.get('ln_type', 'postln') == 'preln': input_tensor = input_tensor elif config.get('ln_type', 'postln') == 'postln': input_tensor = albert_modules.layer_norm(input_tensor) else: input_tensor = albert_modules.layer_norm(input_tensor) if embedding_projection is not None: # batch x seq x hidden, embedding x hidden print(input_tensor.get_shape(), embedding_projection.get_shape()) input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection) else: print("==no need for embedding projection==") input_tensor = input_tensor output_bias = tf.get_variable( "output_bias", shape=[config.vocab_size], initializer=tf.zeros_initializer()) # batch x seq x embedding logits = tf.einsum("abc,dc->abd", input_tensor, output_weights) logits = tf.nn.bias_add(logits, output_bias) """ if input_ori_ids[i] is random pertubated, sampled_binary_mask[i]=1 """ sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32) per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( logits=logits, labels=tf.stop_gradient(input_ori_ids), ) per_example_loss *= sampled_binary_mask loss = tf.reduce_sum(per_example_loss) / (1e-10 + tf.reduce_sum(sampled_binary_mask)) return (loss, per_example_loss, logits, sampled_binary_mask)
def token_generator(config, input_tensor, output_weights, input_ids, input_ori_ids, input_mask, **kargs):
    input_shape_list = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]
    hidden_dims = input_shape_list[2]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # with tf.variable_scope("cls/predictions", reuse=tf.AUTO_REUSE):
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if config.get('ln_type', 'postln') == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)
        elif config.get('ln_type', 'postln') == 'postln':
            input_tensor = input_tensor
        else:
            input_tensor = input_tensor

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(config.initializer_range))
            if config.get('ln_type', 'postln') == 'preln':
                input_tensor = input_tensor
            elif config.get('ln_type', 'postln') == 'postln':
                input_tensor = albert_modules.layer_norm(input_tensor)
            else:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor, embedding_projection)
        else:
            print("==no need for embedding projection==")
            input_tensor = input_tensor

        output_bias = tf.get_variable(
            "output_bias", shape=[config.vocab_size], initializer=tf.zeros_initializer())
        # batch x seq x vocab
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        logits_tempered = logits / config.get("temperature", 1.0)
        flat_logits_tempered = tf.reshape(logits_tempered, [batch_size * seq_length, width])
        flat_logits_tempered_topk = top_k_logits(flat_logits_tempered, int(config.vocab_size / 2))

        samples = tf.multinomial(flat_logits_tempered_topk,
                                 num_samples=config.get('gen_sample', 1),
                                 output_dtype=tf.int32)

        # 1 for original (unchanged) positions and 0 for masked/replaced positions
        label_diff_ids = tf.equal(tf.cast(input_ids, tf.int32), tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        tf.summary.scalar(
            'label_diff_ids',
            tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32)) /
            tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        if config.get('gen_sample', 1) == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if kargs.get('mask_method', 'all') == 'only_mask':
                label_diff_ids = tf.cast(label_diff_ids, tf.float32)
                # keep original tokens where label_diff_ids == 1, generator samples elsewhere
                # (the original assigned this mix to `samples` and then discarded it)
                sampled_input_id = (1 - label_diff_ids) * tf.cast(sampled_input_id, tf.float32) + (
                    label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, config.get('gen_sample', 1)])
            if kargs.get('mask_method', 'all') == 'only_mask':
                # batch x seq_length x 1
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum(
                    'abc,cd->abd', label_diff_ids,
                    tf.ones((1, config.get('gen_sample', 1))))  # `model_config` was an undefined name
                # batch x seq_length x 1
                input_ori_ids = tf.expand_dims(input_ori_ids, axis=-1)
                input_ori_ids = tf.einsum(
                    'abc,cd->abd', input_ori_ids,
                    tf.ones((1, config.get('gen_sample', 1))))
                input_ori_ids = tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = (1 - label_diff_ids) * tf.cast(sampled_input_id, tf.float32) + (
                    input_ori_ids) * label_diff_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

        return sampled_input_id
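# Schematic sketch (not from the repo) of how the generator and discriminator defined in this
# file are intended to fit together, ELECTRA-style: the generator re-samples the masked
# positions and the discriminator learns to tell original tokens from replaced ones. All names
# below are assumptions; `generator_seq_output`, `discriminator_seq_output`, `embedding_table`
# and `gen_loss` come from the surrounding model code.
def _electra_style_wiring_sketch(config, generator_seq_output, discriminator_seq_output,
                                 embedding_table, input_ids, input_ori_ids, input_mask,
                                 gen_loss):
    fake_ids = token_generator(config, generator_seq_output, embedding_table,
                               input_ids, input_ori_ids, input_mask,
                               mask_method='only_mask')            # [batch, seq]
    disc_loss, disc_logits, _ = classifier(config, discriminator_seq_output,
                                           input_ori_ids,          # original tokens
                                           fake_ids,               # generator samples
                                           input_mask, num_labels=2, dropout_prob=0.1)
    # ELECTRA weights the replaced-token-detection loss much higher than the MLM loss (lambda ~ 50).
    return gen_loss + config.get('discriminator_weight', 50.0) * disc_loss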