def build_encoder(self, input_ids, input_mask, hidden_dropout_prob,
                  attention_probs_dropout_prob, **kargs):
    """Build the attention mask and run the stacked transformer encoder.

    The transformer outputs are stored on ``self.all_encoder_layers``,
    ``self.all_attention_scores`` and ``self.all_value_outputs``; the method
    itself returns nothing. ``self.embedding_output`` is read and must
    already be set when this is called.

    Args:
        input_ids: rank-2 or rank-3 tensor (checked through
            ``bert_utils.get_shape_list``). A rank-3 input is reduced with
            ``tf.argmax`` over its last axis before the mask is built.
        input_mask: mask handed to
            ``bert_modules.create_attention_mask_from_input_mask``. When
            ``None`` an all-ones int32 mask of shape
            ``[batch_size, seq_length]`` is used.
        hidden_dropout_prob: forwarded to the transformer.
        attention_probs_dropout_prob: forwarded to the transformer.
        **kargs: recognised keys are
            ``reuse`` (required) - passed to the outer variable scope;
            ``seq_type`` - ``"seq2seq"`` replaces the bidirectional mask
            with the result of ``bert_utils.generate_seq2seq_mask``;
            ``mask_type`` - ``"left2right"`` (default) or ``"seq2seq"``,
            only consulted when ``seq_type == "seq2seq"``;
            ``token_type_ids`` - mask sequence for ``mask_type="seq2seq"``
            (zeros are used when absent);
            ``attention_type`` - ``"normal_attention"``,
            ``"efficient_attention"`` (default) or ``"rezero_transformer"``;
            any other value falls back to the normal transformer.

    Raises:
        KeyError: if ``reuse`` is not supplied.
        ValueError: if ``seq_type == "seq2seq"`` and ``mask_type`` is not
            one of the supported values.
    """
    reuse = kargs["reuse"]
    input_shape = bert_utils.get_shape_list(input_ids, expected_rank=[2, 3])
    batch_size, seq_length = input_shape[0], input_shape[1]

    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
        with tf.variable_scope("encoder"):
            # Turn the 2D mask [batch_size, seq_length] into the 3D mask
            # [batch_size, seq_length, seq_length] used for the attention
            # scores. Rank-3 ids are collapsed to ids with argmax first.
            if len(input_shape) == 3:
                tmp_input_ids = tf.argmax(input_ids, axis=-1)
            else:
                tmp_input_ids = input_ids

            attention_mask = (
                bert_modules.create_attention_mask_from_input_mask(
                    tmp_input_ids, input_mask))

            seq_type = kargs.get('seq_type', "None")
            if seq_type == "seq2seq":
                mask_type = kargs.get("mask_type", "left2right")
                if mask_type == "left2right":
                    mask_sequence = input_mask
                    tf.logging.info(
                        "==apply left2right LM model with casual mask==")
                elif mask_type == "seq2seq":
                    token_type_ids = kargs.get("token_type_ids", None)
                    tf.logging.info(
                        "==apply left2right LM model with conditional casual mask=="
                    )
                    if token_type_ids is None:
                        token_type_ids = tf.zeros(
                            shape=[batch_size, seq_length], dtype=tf.int32)
                        tf.logging.info(
                            "==conditional mask is set to 0 and degenerate to left2right LM model=="
                        )
                    mask_sequence = token_type_ids
                else:
                    # Previously this fell through with `mask_sequence`
                    # unbound and crashed with a NameError below.
                    raise ValueError(
                        "Unsupported mask_type %r for seq_type 'seq2seq'; "
                        "expected 'left2right' or 'seq2seq'" % (mask_type,))
                attention_mask = bert_utils.generate_seq2seq_mask(
                    attention_mask, mask_sequence, seq_type, **kargs)
            else:
                tf.logging.info(
                    "==apply bi-directional LM model with bi-directional mask=="
                )

            # Pick the transformer implementation; unknown values fall back
            # to the normal attention implementation.
            attention_type = kargs.get('attention_type',
                                       'efficient_attention')
            if attention_type == 'efficient_attention':
                tf.logging.info("****** efficient attention *******")
                transformer_model = bert_modules.transformer_efficient_model
            elif attention_type == 'rezero_transformer':
                transformer_model = bert_modules.transformer_rezero_model
                tf.logging.info("****** rezero_transformer *******")
            else:
                tf.logging.info("****** normal attention *******")
                transformer_model = bert_modules.transformer_model

            # Run the stacked transformer with all layers returned.
            [
                self.all_encoder_layers,
                self.all_attention_scores,
                self.all_value_outputs,
            ] = transformer_model(
                input_tensor=self.embedding_output,
                attention_mask=attention_mask,
                hidden_size=self.config.hidden_size,
                num_hidden_layers=self.config.num_hidden_layers,
                num_attention_heads=self.config.num_attention_heads,
                intermediate_size=self.config.intermediate_size,
                intermediate_act_fn=bert_modules.get_activation(
                    self.config.hidden_act),
                hidden_dropout_prob=hidden_dropout_prob,
                attention_probs_dropout_prob=attention_probs_dropout_prob,
                initializer_range=self.config.initializer_range,
                do_return_all_layers=True,
                attention_fixed_size=self.config.get(
                    'attention_fixed_size', None))
def model_fn(features, labels, mode):
    """Estimator-style model function for the two-input CRF classifier.

    Reshapes the "a"/"b" inputs, derives their padding masks, builds one
    encoder per input and a CRF classifier on top, then assembles the
    TRAIN / PREDICT / EVAL outputs. Configuration comes from names bound
    outside this function (``model_config``, ``model_io_config``,
    ``opt_config``, ``kargs``, ``output_type``, ...).

    Args:
        features: dict of input tensors. It is updated in place with the
            reshaped ids, the derived masks and ``batch_size`` /
            ``total_length_a`` / ``total_length_b``.
        labels: forwarded to ``bert_encoder``.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        TRAIN / EVAL: a plain dict when ``output_type == "sess"``, an
        ``EstimatorSpec`` when ``output_type == "estimator"``, and ``None``
        for any other ``output_type``. PREDICT: an ``EstimatorSpec``.

    Raises:
        NotImplementedError: for a ``mode`` other than TRAIN/PREDICT/EVAL.
    """
    shape_lst_a = bert_utils.get_shape_list(features['input_ids_a'])
    batch_size_a, total_length_a = shape_lst_a[0], shape_lst_a[1]

    shape_lst_b = bert_utils.get_shape_list(features['input_ids_b'])
    total_length_b = shape_lst_b[1]

    pad_id = kargs.get('[PAD]', 0)

    features['input_ids_a'] = tf.reshape(features['input_ids_a'],
                                         [-1, model_config.max_length])
    features['segment_ids_a'] = tf.reshape(features['segment_ids_a'],
                                           [-1, model_config.max_length])
    features['input_mask_a'] = tf.cast(
        tf.not_equal(features['input_ids_a'], pad_id), tf.int64)

    features['input_ids_b'] = tf.reshape(
        features['input_ids_b'], [-1, model_config.max_predictions_per_seq])
    features['segment_ids_b'] = tf.reshape(
        features['segment_ids_b'],
        [-1, model_config.max_predictions_per_seq])
    features['input_mask_b'] = tf.cast(
        tf.not_equal(features['input_ids_b'], pad_id), tf.int64)

    features['batch_size'] = batch_size_a
    features['total_length_a'] = total_length_a
    features['total_length_b'] = total_length_b

    # One encoder per input; AUTO_REUSE is passed so both share a reuse mode.
    model_dict = {
        target: bert_encoder(model_config, features, labels, mode, target,
                             reuse=tf.AUTO_REUSE)
        for target in ["a", "b"]
    }

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    # `== True` is kept on purpose: a truthy non-boolean `fix_lm` (for
    # example a non-empty string) must not switch the scope.
    if model_io_config.fix_lm == True:  # noqa: E712
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits,
         transition_params) = multi_position_crf_classifier(
             model_config, features, model_dict, num_labels, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)

    # Parameter counting is best-effort diagnostics only. `Exception` is
    # caught instead of a bare `except` so KeyboardInterrupt / SystemExit
    # still propagate.
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")
    print(tvars)

    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    ema_decay = kargs.get('params_moving_average_decay', 0.99)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)

        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)

        # Only the wrapped train op is used here; the returned hook is not.
        train_op, _ = model_io_fn.get_ema_hooks(
            train_op, tvars, ema_decay, scope, mode,
            first_stage_steps=opt_config.num_warmup_steps,
            two_stage=True)

        model_io_fn.set_saver()

        task_index = kargs.get("task_index", 1)
        if task_index == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif task_index == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              task_index)

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            training_hooks=training_hooks)
        print(tf.global_variables(), "==global_variables==")
        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
        # Any other output_type falls through and returns None.

    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")

        label_weights = tf.cast(features['label_weights'], tf.int32)
        label_seq_length = tf.reduce_sum(label_weights, axis=-1)

        decode_tags, best_score = tf.contrib.crf.crf_decode(
            logits, transition_params, label_seq_length)

        _, hooks = model_io_fn.get_ema_hooks(None, None, ema_decay, scope,
                                             mode)

        predictions = {
            'decode_tags': decode_tags,
            "best_score": best_score,
            "transition_params": transition_params,
            "logits": logits
        }
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                "output":
                tf.estimator.export.PredictOutput(dict(predictions))
            },
            prediction_hooks=[hooks])
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        # The call is kept for whatever it sets up; its hook is not
        # attached, and evaluation runs with an empty hook list.
        model_io_fn.get_ema_hooks(None, None, ema_decay, scope, mode)
        eval_hooks = []

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss),
                    # The last encoder built above ("b") supplies the
                    # pooled feature.
                    "feature": model_dict["b"].get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = eval_logtis(logits, features, num_labels,
                                          transition_params)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops=eval_metric_ops,
                evaluation_hooks=eval_hooks)
            return estimator_spec
        # Any other output_type falls through and returns None.
    else:
        raise NotImplementedError()
def embedding_postprocessor(input_tensor,
                            use_token_type=False,
                            token_type_ids=None,
                            token_type_vocab_size=16,
                            token_type_embedding_name="token_type_embeddings",
                            use_position_embeddings=True,
                            position_embedding_name="position_embeddings",
                            initializer_range=0.02,
                            max_position_embeddings=512,
                            dropout_prob=0.1):
    """Performs various post-processing on a word embedding tensor.

    Optionally adds token-type and position embeddings to `input_tensor`,
    then applies layer normalisation and dropout.

    Args:
        input_tensor: float Tensor of shape
            [batch_size, seq_length, embedding_size].
        use_token_type: bool. Whether to add embeddings for
            `token_type_ids`.
        token_type_ids: (optional) int32 Tensor of shape
            [batch_size, seq_length]. Must be specified if `use_token_type`
            is True.
        token_type_vocab_size: int. The vocabulary size of `token_type_ids`.
        token_type_embedding_name: string. The name of the embedding table
            variable for token type ids.
        use_position_embeddings: bool. Whether to add position embeddings
            for the position of each token in the sequence.
        position_embedding_name: string. The name of the embedding table
            variable for positional embeddings.
        initializer_range: float. Range of the weight initialization.
        max_position_embeddings: int. Maximum sequence length that might
            ever be used with this model. This can be longer than the
            sequence length of input_tensor, but cannot be shorter.
        dropout_prob: float. Dropout probability applied to the final
            output tensor.

    Returns:
        float tensor with same shape as `input_tensor`.

    Raises:
        ValueError: if the statically known sequence length exceeds
            `max_position_embeddings`, or if `use_token_type` is True but
            `token_type_ids` is None.
    """
    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    width = input_shape[2]

    # A Tensor here means the sequence length is only known at run time
    # (presumably how `get_shape_list` reports dynamic dimensions - verify
    # against bert_utils). Python comparisons are only valid for the static
    # case; a Tensor in an `if` would raise.
    dynamic_length = isinstance(seq_length, tf.Tensor)

    if not dynamic_length and seq_length > max_position_embeddings:
        raise ValueError("The seq length (%d) cannot be greater than "
                         "`max_position_embeddings` (%d)" %
                         (seq_length, max_position_embeddings))

    output = input_tensor

    if use_token_type:
        if token_type_ids is None:
            raise ValueError("`token_type_ids` must be specified if "
                             "`use_token_type` is True.")
        token_type_table = tf.get_variable(
            name=token_type_embedding_name,
            shape=[token_type_vocab_size, width],
            initializer=create_initializer(initializer_range))
        # This vocab will be small so we always do one-hot here, since it
        # is always faster for a small vocabulary.
        flat_token_type_ids = tf.reshape(token_type_ids, [-1])
        one_hot_ids = tf.one_hot(flat_token_type_ids,
                                 depth=token_type_vocab_size)
        token_type_embeddings = tf.matmul(one_hot_ids, token_type_table)
        token_type_embeddings = tf.reshape(token_type_embeddings,
                                           [batch_size, seq_length, width])
        output += token_type_embeddings

    if use_position_embeddings:
        full_position_embeddings = tf.get_variable(
            name=position_embedding_name,
            shape=[max_position_embeddings, width],
            initializer=create_initializer(initializer_range))
        # The position table is a learned variable created for the longest
        # supported sequence. It covers positions
        # [0, ..., max_position_embeddings - 1] while the current sequence
        # has positions [0, ..., seq_length - 1], so a leading slice gives
        # the rows we need. With a dynamic length we always slice.
        if dynamic_length or seq_length < max_position_embeddings:
            position_embeddings = tf.slice(full_position_embeddings, [0, 0],
                                           [seq_length, -1])
        else:
            position_embeddings = full_position_embeddings

        num_dims = len(output.shape.as_list())

        # Only the last two dimensions are relevant (`seq_length` and
        # `width`), so we broadcast among the first dimensions, which is
        # typically just the batch size.
        position_broadcast_shape = [1] * (num_dims - 2)
        position_broadcast_shape.extend([seq_length, width])
        position_embeddings = tf.reshape(position_embeddings,
                                         position_broadcast_shape)
        output += position_embeddings

    output = layer_norm_and_dropout(output, dropout_prob)
    return output
def token_generator_igr(config, input_tensor, output_weights, input_ids,
                        input_ori_ids, input_mask, **kargs):
    """Sample replacement tokens from the generator's LM head.

    Projects `input_tensor` through the "cls/predictions" head to vocabulary
    logits, draws samples with `iso_gaussian_sample` at a temperature chosen
    by `gumbel_anneal`, and optionally keeps the original token (as one-hot)
    at every position that was not replaced.

    Args:
        config: model config; `ln_type`, `embedding`, `embedding_size`,
            `hidden_size`, `hidden_act`, `initializer_range`, `vocab_size`,
            `gumbel_temperature` and `gen_sample` are read.
        input_tensor: rank-3 tensor [batch_size, seq_length, hidden].
        output_weights: output embedding matrix, contracted with the
            transformed hidden states as "abc,dc->abd".
        input_ids: possibly corrupted ids; compared with `input_ori_ids`
            to find replaced positions when no `sampled_binary_mask` is
            given.
        input_ori_ids: original ids; one-hot encoded with
            `config.vocab_size`.
        input_mask: accepted for interface compatibility; not used here.
        **kargs: `embedding_projection`, `scope`, `num_train_steps`,
            `gumbel_anneal` ("anneal" | "softplus" | other),
            `sampled_prob_id`, `straight_through`, `sampled_binary_mask`,
            `mask_method`.

    Returns:
        Tensor of shape [batch_size, seq_length, vocab_size] when
        `gen_sample == 1`, otherwise
        [batch_size, seq_length, vocab_size, gen_sample].
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor,
                                                 expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)
    gen_sample = config.get('gen_sample', 1)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Pre-LN normalises before the transform, every other setting
        # normalises after it.
        use_preln = config.get('ln_type', 'postln') == 'preln'
        if use_preln:
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size',
                                          config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if not use_preln:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        width = bert_utils.get_shape_list(logits, expected_rank=3)[2]
        flat_logits_tempered = tf.reshape(logits,
                                          [batch_size * seq_length, width])

        # Choose the sampling temperature.
        num_train_steps = kargs.get('num_train_steps', None)
        anneal_mode = kargs.get('gumbel_anneal', "anneal")
        if num_train_steps and anneal_mode == 'anneal':
            tf.logging.info("****** apply annealed temperature ******* %s",
                            str(num_train_steps))
            annealed_temp = tf.train.polynomial_decay(
                config.get('gumbel_temperature', 1.0),
                tf.train.get_or_create_global_step(),
                num_train_steps,
                end_learning_rate=0.1,
                power=1.0,
                cycle=False)
        elif anneal_mode == 'softplus':
            tf.logging.info("****** apply auto-scale temperature *******")
            # batch x seq x dim
            with tf.variable_scope("gumbel_auto_scaling_temperature"):
                annealed_temp = tf.layers.dense(
                    input_tensor,
                    1,
                    activation=tf.nn.softplus,
                ) + 1.0
                annealed_temp = 1. / annealed_temp
                annealed_temp = tf.reshape(annealed_temp,
                                           [batch_size * seq_length, 1])
            if gen_sample > 1:
                tf.logging.info(
                    "****** apply auto-scale temperature for multi-sampling *******"
                )
                annealed_temp = tf.expand_dims(annealed_temp, -1)
        else:
            annealed_temp = 0.01
            tf.logging.info(
                "****** not apply annealed tenperature with fixed temp ******* %s",
                str(annealed_temp))

        # [batch x seq] x config.vocab_size x gen_sample
        sampled_logprob_temp, sampled_logprob = iso_gaussian_sample(
            flat_logits_tempered,
            temperature=annealed_temp,
            samples=gen_sample)

        # argmax over the vocabulary axis (always axis=1) turns the noisy
        # scores into a categorical sample.
        if kargs.get('sampled_prob_id', True):
            tf.logging.info(
                "****** apply categorical sampled id of original logits *******"
            )
            sampled_id = tf.one_hot(tf.argmax(sampled_logprob, axis=1),
                                    config.vocab_size,
                                    axis=1)
        else:
            tf.logging.info(
                "****** apply gumbel-softmax logprob for logits *******")
            sampled_id = tf.one_hot(tf.argmax(sampled_logprob_temp, axis=1),
                                    config.vocab_size,
                                    axis=1)

        if kargs.get("straight_through", True):
            # Straight-through estimator: forward pass uses the hard
            # one-hot, backward pass uses the (flipped) soft sample.
            tf.logging.info("****** apply straight_through_estimator *******")
            sampled_id = tf.stop_gradient(
                sampled_id -
                sampled_logprob_temp) + flip_gradient(sampled_logprob_temp)
        else:
            tf.logging.info("****** apply gumbel-softmax probs *******")
            sampled_id = flip_gradient(sampled_logprob_temp)

        # 0 for original and 1 for replace
        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = tf.identity(sampled_binary_mask)
        else:
            label_diff_ids = tf.not_equal(tf.cast(input_ids, tf.int32),
                                          tf.cast(input_ori_ids, tf.int32))
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        label_diff_ids = tf.expand_dims(label_diff_ids,
                                        axis=-1)  # batch x seq x 1

        # batch x seq x vocab
        input_ori_ids = tf.cast(
            tf.one_hot(input_ori_ids, config.vocab_size), tf.float32)

        only_mask = kargs.get('mask_method', 'only_mask') == 'only_mask'

        if gen_sample == 1:
            sampled_input_id = tf.reshape(
                sampled_id, [batch_size, seq_length, config.vocab_size])
        else:
            # The original reshaped an undefined name (`samples`) here.
            sampled_input_id = tf.reshape(
                sampled_id,
                [batch_size, seq_length, config.vocab_size, gen_sample])
            # Add a trailing sample axis so both broadcast against the
            # 4-D samples.
            label_diff_ids = tf.expand_dims(label_diff_ids,
                                            axis=-1)  # batch x seq x 1 x 1
            input_ori_ids = tf.expand_dims(
                input_ori_ids, axis=-1)  # batch x seq x vocab x 1

        if only_mask:
            tf.logging.info("****** only mask sample *******")
            # Replaced positions take the sample, every other position
            # keeps the original one-hot token. The same blend is used for
            # single- and multi-sample outputs; the original multi-sample
            # branch had the operands of the second term swapped.
            sampled_input_id = (
                label_diff_ids * tf.cast(sampled_input_id, tf.float32) +
                (1 - label_diff_ids) * input_ori_ids)

        return sampled_input_id
def model_fn(features, labels, mode, params):
    """Build the generator graph and collect its losses and samples.

    Optionally corrupts the inputs with `hmm_input_ids_generation` (TRAIN
    only), builds the generator model, its NSP and masked-LM heads, samples
    replacement tokens with `token_generator`, and gathers the generator's
    trainable variables. Configuration comes from names bound outside this
    function (`model_config`, `model_io_config`, `kargs`, `target`,
    `generator_scope_prefix`, ...).

    Args:
        features: dict of input tensors; `features['input_ids']` is
            overwritten when the random generator is applied.
        labels: forwarded to the model constructor.
        mode: a `tf.estimator.ModeKeys` value.
        params: accepted for the Estimator signature; not used here.

    Returns:
        dict with the generator loss, variables, model, sampled ids and the
        masked-LM / next-sentence tensors (see `return_dict` below).
    """
    model_api = model_zoo(model_config)

    sampled_binary_mask = None
    # NOTE(review): the original membership list was [TRAIN, TRAIN];
    # possibly EVAL was meant for the second entry - confirm.
    if (kargs.get('random_generator', '1') == '1'
            and mode == tf.estimator.ModeKeys.TRAIN):
        [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
            model_config,
            features['input_ori_ids'],
            features['input_mask'],
            [
                tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                for hmm_tran_prob in hmm_tran_prob_list
            ],
            mask_probability=0.2,
            replace_probability=0.0,
            original_probability=0.0,
            mask_prior=tf.constant(mask_prior, tf.float32),
            **kargs)

        features['input_ids'] = output_ids
        tf.logging.info("****** do random generator *******")

    model = model_api(model_config,
                      features,
                      labels,
                      mode,
                      target,
                      reuse=tf.AUTO_REUSE,
                      **kargs)

    (nsp_loss, nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
         model_config,
         model.get_pooled_output(),
         features['next_sentence_labels'],
         reuse=tf.AUTO_REUSE,
         scope=generator_scope_prefix)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        # NOTE(review): this fallback mixes the bert token-level head with
        # the albert sequence-level head, as in the original - confirm.
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             features['input_mask'],
             features['input_ori_ids'],
             features['input_ids'],
             sampled_binary_mask,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope=generator_scope_prefix)
        masked_lm_ids = features['input_ori_ids']
    else:
        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope=generator_scope_prefix)
    print(model_config.lm_ratio, '==mlm lm_ratio==')
    # The NSP loss is multiplied by zero, so it adds nothing to the total.
    loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

    if kargs.get("resample_discriminator", False):
        [output_ids, sampled_binary_mask] = random_input_ids_generation(
            model_config,
            features['input_ori_ids'],
            features['input_mask'],
            mask_probability=0.2,
            replace_probability=0.1,
            original_probability=0.1)

        resample_features = dict(features)
        resample_features['input_ids'] = tf.identity(output_ids)
        model_resample = model_api(model_config,
                                   resample_features,
                                   labels,
                                   mode,
                                   target,
                                   reuse=tf.AUTO_REUSE,
                                   **kargs)

        tf.logging.info("**** apply discriminator resample **** ")
    else:
        model_resample = model
        resample_features = features
        tf.logging.info("**** not apply discriminator resample **** ")

    sampled_ids = token_generator(
        model_config,
        model_resample.get_sequence_output(),
        model_resample.get_embedding_table(),
        resample_features['input_ids'],
        resample_features['input_ori_ids'],
        resample_features['input_mask'],
        embedding_projection=model_resample.get_embedding_projection_table(),
        scope=generator_scope_prefix,
        mask_method='only_mask',
        use_tpu=kargs.get('use_tpu', True))

    num_samples = model_config.get('gen_sample', 1)
    if num_samples == 1:
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
    else:
        sample_axis = tf.ones((1, num_samples))

        def _repeat_per_sample(tensor):
            """Copy a [batch, seq] tensor along a new trailing sample axis."""
            # Cast to float first: einsum needs both operands to share a
            # dtype and `sample_axis` is float32.
            expanded = tf.cast(tf.expand_dims(tensor, axis=-1), tf.float32)
            repeated = tf.einsum('abc,cd->abd', expanded, sample_axis)
            return tf.cast(repeated, tf.int32)

        # batch x seq_length x gen_sample
        input_ids = _repeat_per_sample(features['input_ori_ids'])

        input_shape_list = bert_utils.get_shape_list(input_ids,
                                                     expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        gen_sample = input_shape_list[2]

        # NOTE(review): this reshapes [batch, seq, gen] straight to
        # [batch * gen, seq] without moving the sample axis next to the
        # batch axis first - confirm the layout consumers expect.
        flat_shape = [batch_size * gen_sample, seq_length]

        sampled_ids = tf.reshape(sampled_ids, flat_shape)
        input_ids = tf.reshape(input_ids, flat_shape)
        input_mask = tf.reshape(_repeat_per_sample(features['input_mask']),
                                flat_shape)
        segment_ids = tf.reshape(
            _repeat_per_sample(features['segment_ids']), flat_shape)

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope, not_storage_params=not_storage_params)

    # Head variables live under "<prefix>/cls/..." when a generator scope
    # prefix is set, e.g. "generator/cls/predictions".
    if generator_scope_prefix:
        head_prefix = generator_scope_prefix + "/"
    else:
        head_prefix = ""
    lm_pretrain_tvars = model_io_fn.get_params(
        head_prefix + "cls/predictions",
        not_storage_params=not_storage_params)
    nsp_pretrain_vars = model_io_fn.get_params(
        head_prefix + "cls/seq_relationship",
        not_storage_params=not_storage_params)

    if model_config.get('embedding_scope', None) is not None:
        embedding_tvars = model_io_fn.get_params(
            model_config.get('embedding_scope', 'bert') + "/embeddings",
            not_storage_params=not_storage_params)
        pretrained_tvars.extend(embedding_tvars)

    pretrained_tvars.extend(lm_pretrain_tvars)
    pretrained_tvars.extend(nsp_pretrain_vars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        # The returned scaffold function is not used by this function.
        model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu,
            restore_var_name=model_config.get('restore_var_name', []))

    tf.add_to_collection("generator_loss", loss)
    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
        "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length
        "sampled_input_mask": input_mask,
        "sampled_segment_ids": segment_ids,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": masked_lm_mask,
        "masked_lm_log_probs": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss,
        "next_sentence_example_loss": nsp_per_example_loss,
        "next_sentence_log_probs": nsp_log_prob,
        "next_sentence_labels": features['next_sentence_labels'],
        "sampled_binary_mask": sampled_binary_mask
    }
    return return_dict
def get_losses(d_out_real, d_out_fake, **kargs):
    """Compute generator and discriminator losses for a GAN objective.

    Label convention: 1 = original, 0 = fake.

    Args:
        d_out_real: discriminator output for real inputs (rank 1, 2 or 3).
        d_out_fake: discriminator output for generated inputs.
        **kargs: `gan_type` selects the objective ('standard', 'JS', 'KL',
            'hinge', 'tv', 'LS' or 'RSGAN'; default 'standard'). `use_tpu`
            (default True) disables the scalar summaries when truthy.

    Returns:
        dict with keys "gen_loss" and "disc_loss".

    Raises:
        NotImplementedError: for an unknown `gan_type`.
    """
    batch_size = bert_utils.get_shape_list(d_out_real,
                                           expected_rank=[1, 2, 3])[0]

    gan_type = kargs.get('gan_type', 'standard')
    tf.logging.info("**** gan type **** %s", str(gan_type))

    def _mean_xent(logits, make_labels):
        """Mean sparse softmax cross-entropy against a constant label."""
        targets = tf.cast(make_labels(batch_size), tf.int32)
        return tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits=logits,
                                                           labels=targets))

    if gan_type in ('standard', 'JS', 'KL'):
        # These three share the discriminator loss and differ only in the
        # generator loss.
        d_loss_real = _mean_xent(d_out_real, tf.ones)
        d_loss_fake = _mean_xent(d_out_fake, tf.zeros)
        d_loss = d_loss_real + d_loss_fake
        if gan_type == 'standard':
            # the non-satuating GAN loss
            g_loss = _mean_xent(d_out_fake, tf.ones)
        elif gan_type == 'JS':
            # the vanilla GAN loss
            g_loss = -d_loss_fake
        else:
            # the GAN loss implicitly minimizing KL-divergence
            g_loss = tf.reduce_mean(-d_out_fake)
    elif gan_type == 'hinge':
        # the hinge loss
        d_loss_real = tf.reduce_mean(tf.nn.relu(1.0 - d_out_real))
        d_loss_fake = tf.reduce_mean(tf.nn.relu(1.0 + d_out_fake))
        d_loss = d_loss_real + d_loss_fake
        g_loss = -tf.reduce_mean(d_out_fake)
    elif gan_type == 'tv':
        # the total variation distance
        d_loss = tf.reduce_mean(tf.tanh(d_out_fake) - tf.tanh(d_out_real))
        g_loss = tf.reduce_mean(-tf.tanh(d_out_fake))
    elif gan_type == 'LS':
        # LS-GAN
        d_loss_real = tf.reduce_mean(tf.squared_difference(d_out_real, 1.0))
        d_loss_fake = tf.reduce_mean(tf.square(d_out_fake))
        d_loss = d_loss_real + d_loss_fake
        g_loss = tf.reduce_mean(tf.squared_difference(d_out_fake, 1.0))
    elif gan_type == 'RSGAN':
        # relativistic standard GAN
        d_loss = _mean_xent(d_out_real - d_out_fake, tf.ones)
        g_loss = _mean_xent(d_out_fake - d_out_real, tf.ones)
    else:
        raise NotImplementedError("Divergence '%s' is not implemented" %
                                  gan_type)

    # Every supported branch logs the type once more after building its
    # losses.
    tf.logging.info("**** gan type **** %s", str(gan_type))

    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator global loss ====")
        tf.summary.scalar('disc_loss', d_loss)
        tf.summary.scalar('gen_loss', g_loss)

    return {"gen_loss": g_loss, "disc_loss": d_loss}
def classifier(config, seq_output, input_ids, sampled_ids, input_mask,
               num_labels, dropout_prob, **kargs):
    """Token-level replaced-token-detection head.

    Projects `seq_output` to `num_labels` logits per token and scores them
    against labels derived by comparing `input_ids` with `sampled_ids`
    (original: 0, replaced: 1).

    Args:
        config: model config; `ln_type` is read, and the config is passed
            to the focal loss helper.
        seq_output: encoder output, contracted as "abc,dc->abd" with the
            output weights.
        input_ids: original input ids.
        sampled_ids: generated fake ids; rank 2 (ids) or rank 3 (reduced
            with argmax over the last axis).
        input_mask: token mask. Positions whose original id is 100, 101 or
            102 (commented in the code as unk / cls / sep) are removed from
            it.
        num_labels: number of output classes.
        dropout_prob: dropout probability applied before the projection.
        **kargs: `ori_sampled_ids` (optional, restricts the loss mask to
            positions where it differs from `input_ids`), `loss`
            ('cross_entropy' | 'focal_loss'), `use_tpu`.

    Returns:
        Tuple `(loss, logits, per_example_loss)`.

    Raises:
        ValueError: if `loss` is not a supported value.
    """
    output_layer = seq_output
    hidden_size = output_layer.shape[-1].value

    # unk / cls / sep are considered neither replaced nor original.
    unk_mask = tf.cast(tf.math.equal(input_ids, 100), tf.float32)
    cls_mask = tf.cast(tf.math.equal(input_ids, 101), tf.float32)
    sep_mask = tf.cast(tf.math.equal(input_ids, 102), tf.float32)
    none_replace_mask = unk_mask + cls_mask + sep_mask

    input_mask = tf.cast(input_mask, tf.int32)
    input_mask *= tf.cast(1 - none_replace_mask, tf.int32)

    output_weights = tf.get_variable(
        "output_weights", [num_labels, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))
    output_bias = tf.get_variable("output_bias", [num_labels],
                                  initializer=tf.zeros_initializer())

    ln_type = config.get('ln_type', 'postln')
    if ln_type == 'preln':
        output_layer = albert_modules.layer_norm(output_layer)
        print('====preln transformer====')
    elif ln_type == 'postln':
        print('====postln transformer====')
    else:
        print('====no layer layer_norm====')

    output_layer = tf.nn.dropout(output_layer, keep_prob=1 - dropout_prob)

    logits = tf.einsum("abc,dc->abd", output_layer, output_weights)
    logits = tf.nn.bias_add(logits, output_bias)  # batch x seq_length x 2

    input_ids = tf.cast(input_ids, tf.int32)

    def _as_token_ids(ids):
        """Return int32 ids from rank-2 ids or a rank-3 per-vocab tensor."""
        rank = len(bert_utils.get_shape_list(ids, expected_rank=[2, 3]))
        if rank == 3:
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
            # batch x seq x vocab -> batch x seq
            return tf.cast(tf.argmax(ids, axis=-1), tf.int32)
        tf.logging.info("****** normal 2-D sampled_ids *******")
        return tf.cast(ids, tf.int32)

    tmp_sampled_ids = _as_token_ids(sampled_ids)

    ori_sampled_ids = kargs.get('ori_sampled_ids', None)
    if ori_sampled_ids is not None:
        # The original rank-3 path cast an undefined name here and raised
        # NameError; both ranks now go through the same helper.
        tmp_ori_sampled_ids = _as_token_ids(ori_sampled_ids)
        masked_not_equal_mask = tf.cast(
            tf.not_equal(input_ids, tmp_ori_sampled_ids), tf.int32)
        masked_not_equal_mask *= tf.cast(input_mask, tf.int32)
        tf.logging.info(
            "****** loss mask using masked token mask for masked tokens *******"
        )
        loss_mask = masked_not_equal_mask
    else:
        tf.logging.info(
            "****** loss mask using input_mask for all tokens *******")
        loss_mask = input_mask

    # original:0, replace:1
    not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids),
                                  tf.int32)
    not_equal_label_ids *= tf.cast(input_mask, tf.int32)

    loss_type = kargs.get('loss', 'cross_entropy')
    if loss_type == 'cross_entropy':
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits, labels=tf.stop_gradient(not_equal_label_ids))
    elif loss_type == 'focal_loss':
        input_shape_list = bert_utils.get_shape_list(input_ids,
                                                     expected_rank=2)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        not_equal_label_ids_ = tf.reshape(not_equal_label_ids,
                                          [batch_size * seq_length])
        logits_ = tf.reshape(logits, [batch_size * seq_length, -1])
        per_example_loss, _ = loss_utils.focal_loss_binary_v2(
            config, logits_, not_equal_label_ids_)
        per_example_loss = tf.reshape(per_example_loss,
                                      [batch_size, seq_length])
    else:
        # Previously `per_example_loss` stayed unbound and the function
        # crashed with a NameError further down.
        raise ValueError(
            "Unsupported loss %r; expected 'cross_entropy' or 'focal_loss'"
            % (loss_type,))

    loss_mask_total = tf.reduce_sum(tf.cast(loss_mask, tf.float32))
    not_equal_weights = tf.cast(not_equal_label_ids, tf.float32)

    # Split the loss into "original" (equal) and "replaced" (not equal)
    # positions; both are normalised by the size of the loss mask.
    equal_label_ids = (1 - not_equal_weights) * tf.cast(
        loss_mask, tf.float32)
    equal_loss = tf.reduce_sum(per_example_loss * equal_label_ids)
    not_equal_loss = tf.reduce_sum(per_example_loss * not_equal_weights)

    loss = (equal_loss + not_equal_loss) / (1e-10 + loss_mask_total)

    tf.logging.info("====discriminator classifier use_tpu %s ====",
                    str(kargs.get('use_tpu', True)))

    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator loss ====")
        input_mask_total = tf.reduce_sum(tf.cast(input_mask, tf.float32))
        tf.summary.scalar('mask_based_loss', loss)
        tf.summary.scalar('equal_loss',
                          equal_loss / (1e-10 + input_mask_total))
        tf.summary.scalar('not_equal_loss',
                          not_equal_loss / (1e-10 + input_mask_total))
        tf.summary.scalar(
            'loss_decomposition',
            loss - (equal_loss + not_equal_loss) /
            (1e-10 + input_mask_total))

    return (loss, logits, per_example_loss)
def optimal_discriminator(config, true_model_dict, true_features_dict,
                          fake_model_dict, fake_features_dict, **kargs):
    """Score sampled tokens with a closed-form discriminator and compute its loss.

    The probability that a sampled token is "original" is derived from the
    probabilities the true and the fake masked-LM assign to that token,
    combined with the fixed ratio ``alpha = (1 - 0.15) / 0.15``. No trainable
    variables are created in this function.

    Args:
        config: model config; only ``config.vocab_size`` is read.
        true_model_dict: dict whose ``'masked_lm_log_probs'`` entry is
            reshaped to ``[-1, vocab_size]``.
        true_features_dict: unused; kept for interface compatibility.
        fake_model_dict: dict with the same ``'masked_lm_log_probs'`` entry
            for the fake model.
        fake_features_dict: dict providing ``'input_ids'`` (sampled ids, rank
            2, or rank 3 in which case the argmax over the last axis is
            taken), ``'input_ori_ids'`` and ``'input_mask'``.
        **kargs: ``sampled_binary_mask`` (optional loss mask; defaults to the
            input mask with ids 100/101/102 removed) and ``use_tpu`` (default
            True; summaries are only written when it is falsy).

    Returns:
        Tuple ``(loss, logits, per_example_loss)``. ``logits`` has shape
        ``[batch_size, seq_length, 2]`` with index 0 holding the log of the
        discriminator probability and index 1 the log of its complement.
    """
    alpha = (1 - 0.15) / 0.15

    sampled_ids = fake_features_dict['input_ids']
    input_shape_list = bert_utils.get_shape_list(
        fake_features_dict["input_ori_ids"], expected_rank=[2, 3])
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    # Normalise the sampled ids to int32 [batch, seq] *before* they are used
    # as gather indices. The original indexed with the raw `sampled_ids`,
    # which fails for the rank-3 (one-hot) input and for non-int32 ids
    # because `tf.range` yields int32.
    sampled_shape_list = bert_utils.get_shape_list(sampled_ids,
                                                   expected_rank=[2, 3])
    if len(sampled_shape_list) == 3:
        tmp_sampled_ids = tf.argmax(sampled_ids, axis=-1)  # batch x seq x vocab
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)
        tf.logging.info("****** gumbel 3-D sampled_ids *******")
    elif len(sampled_shape_list) == 2:
        tmp_sampled_ids = sampled_ids
        tmp_sampled_ids = tf.cast(tmp_sampled_ids, tf.int32)

    true_logits = tf.exp(tf.nn.log_softmax(
        tf.reshape(true_model_dict['masked_lm_log_probs'],
                   [-1, config.vocab_size])))
    fake_logits = tf.exp(tf.nn.log_softmax(
        tf.reshape(fake_model_dict['masked_lm_log_probs'],
                   [-1, config.vocab_size])))

    labels = tf.reshape(tmp_sampled_ids, [-1, 1])  # [batch x seq, 1]
    batch_idxs = tf.range(0, tf.shape(labels)[0])
    batch_idxs = tf.expand_dims(batch_idxs, 1)
    idxs = tf.concat([batch_idxs, labels], 1)

    # Probability each model assigns to the sampled token at every position.
    y_true_pred = tf.gather_nd(true_logits, idxs)
    y_fake_pred = tf.gather_nd(fake_logits, idxs)

    disc_probs = (y_true_pred * (alpha + y_fake_pred) + 1e-10) / (
        y_fake_pred + alpha * y_true_pred + 1e-10)  # batch x seq
    disc_probs = tf.expand_dims(disc_probs, axis=-1)  # [batch x seq, 1]
    neg_probs = 1 - disc_probs + 1e-10

    logits = tf.log(tf.concat([disc_probs, neg_probs], axis=-1) + 1e-10)
    logits = tf.reshape(logits, [batch_size, seq_length, -1])

    input_ids = tf.cast(fake_features_dict['input_ori_ids'], tf.int32)

    unk_mask = tf.cast(tf.equal(input_ids, 100), tf.float32)  # not replace unk
    cls_mask = tf.cast(tf.equal(input_ids, 101), tf.float32)  # not replace cls
    sep_mask = tf.cast(tf.equal(input_ids, 102), tf.float32)  # not replace sep
    none_replace_mask = unk_mask + cls_mask + sep_mask

    input_mask = fake_features_dict['input_mask']
    input_mask = tf.cast(input_mask, tf.int32)
    # cls, unk, sep are not considered as replace or original
    input_mask *= tf.cast(1 - none_replace_mask, tf.int32)

    sampled_binary_mask = kargs.get('sampled_binary_mask', None)
    if sampled_binary_mask is not None:
        tf.logging.info("****** loss mask using masked token mask for masked tokens *******")
        loss_mask = sampled_binary_mask
    else:
        tf.logging.info("****** loss mask using input_mask for all tokens *******")
        loss_mask = input_mask

    # original:0, replace:1
    not_equal_label_ids = tf.cast(tf.not_equal(input_ids, tmp_sampled_ids),
                                  tf.int32)
    not_equal_label_ids *= tf.cast(loss_mask, tf.int32)

    print(logits.get_shape(), "===disc logits shape==",
          not_equal_label_ids.get_shape(), "==label ids shape==")

    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits,
        labels=tf.stop_gradient(not_equal_label_ids))

    float_loss_mask = tf.cast(loss_mask, tf.float32)
    float_not_equal = tf.cast(not_equal_label_ids, tf.float32)
    mask_total = 1e-10 + tf.reduce_sum(float_loss_mask)

    equal_label_ids = (1 - float_not_equal) * float_loss_mask
    equal_loss = tf.reduce_sum(per_example_loss * equal_label_ids)
    not_equal_loss = tf.reduce_sum(per_example_loss * float_not_equal)  # not equal:1, equal:0

    loss = (equal_loss + not_equal_loss) / mask_total

    tf.logging.info("====discriminator classifier use_tpu %s ====",
                    str(kargs.get('use_tpu', True)))
    if not kargs.get('use_tpu', True):
        tf.logging.info("====logging discriminator loss ====")
        tf.summary.scalar('mask_based_loss', loss)

        # Kept from the original: on this path the returned loss is
        # recomputed directly from the masked per-example loss.
        loss = per_example_loss * float_loss_mask
        loss = tf.reduce_sum(loss) / mask_total

        tf.summary.scalar('equal_loss', equal_loss / mask_total)
        tf.summary.scalar('not_equal_loss', not_equal_loss / mask_total)
        tf.summary.scalar('loss_decomposition',
                          loss - (equal_loss + not_equal_loss) / mask_total)

    return (loss, logits, per_example_loss)
def model_fn(features, labels, mode):
    """Build the VAE-style reconstruction graph with optional distillation.

    Reads a pair of sentence features from ``model.get_task_output()``,
    samples a latent vector from them, decodes it with a dilated CNN and
    combines the reconstruction and KL losses with optional feature
    distillation, embedding distillation, masked-LM augmentation and
    task-adversarial terms.

    NOTE(review): ``kargs``, ``model``, ``model_config``, ``model_io_config``,
    ``opt_config``, ``model_reuse``, ``not_storage_params``,
    ``load_pretrained`` and ``init_checkpoint`` are not defined in this
    function; presumably they come from the enclosing builder - confirm.

    Args:
        features: dict of input tensors; keys read here include
            ``"<task>_label_ids"``, ``"<task>_loss_multipiler"``,
            ``input_ids_a`` and ``input_ids_b``.
        labels: unused.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        In TRAIN mode a dict with ``loss``, ``tvars`` and metric tensors; in
        EVAL mode a dict with ``loss``, ``logits`` and ``feature``. Any other
        mode falls through and returns None.
    """
    task_type = kargs.get("task_type", "cls")
    num_task = kargs.get('num_task', 1)
    temp = kargs.get('temp', 0.1)

    print("==task_type==", task_type)

    model_io_fn = model_io.ModelIO(model_io_config)
    label_ids = tf.cast(features["{}_label_ids".format(task_type)],
                        dtype=tf.int32)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
        is_training = True
    else:
        dropout_prob = 0.0
        is_training = False

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    # NOTE(review): only the "task_output" path binds `pooled_feature_dict`,
    # which everything below requires; the "pooled_output" path would fail
    # on the next statement.
    if kargs.get("get_pooled_output", "pooled_output") == "pooled_output":
        pooled_feature = model.get_pooled_output()
    elif kargs.get("get_pooled_output", "task_output") == "task_output":
        pooled_feature_dict = model.get_task_output()
        pooled_feature = pooled_feature_dict['pooled_feature']

    shape_list = bert_utils.get_shape_list(pooled_feature_dict['feature_a'],
                                           expected_rank=[2])
    batch_size = shape_list[0]

    if kargs.get('apply_head_proj', False):
        with tf.variable_scope(scope + "/head_proj", reuse=tf.AUTO_REUSE):
            feature_a = simclr_utils.projection_head(
                pooled_feature_dict['feature_a'],
                is_training,
                head_proj_dim=128,
                num_nlh_layers=1,
                head_proj_mode='nonlinear',
                name='head_contrastive')
            pooled_feature_dict['feature_a'] = feature_a

        with tf.variable_scope(scope + "/head_proj", reuse=tf.AUTO_REUSE):
            feature_b = simclr_utils.projection_head(
                pooled_feature_dict['feature_b'],
                is_training,
                head_proj_dim=128,
                num_nlh_layers=1,
                head_proj_mode='nonlinear',
                name='head_contrastive')
            pooled_feature_dict['feature_b'] = feature_b
        tf.logging.info("****** apply contrastive feature projection *******")
    else:
        feature_a = pooled_feature_dict['feature_a']
        feature_b = pooled_feature_dict['feature_b']
        tf.logging.info("****** not apply projection *******")

    loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                        tf.float32)

    if kargs.get('merge_mode', 'all') == 'all':
        input_ids = tf.concat([features['input_ids_a'],
                               features['input_ids_b']], axis=0)
        hidden_repres = tf.concat([feature_a, feature_b], axis=0)
        sent_repres = tf.concat([pooled_feature_dict['sent_repres_a'],
                                 pooled_feature_dict['sent_repres_b']], axis=0)
        tf.logging.info("****** double batch *******")
    else:
        input_ids = features['input_ids_b']
        hidden_repres = feature_b
        sent_repres = pooled_feature_dict['sent_repres_b']
        tf.logging.info("****** single batch b *******")

    sequence_mask = tf.to_float(tf.not_equal(input_ids,
                                             kargs.get('[PAD]', 0)))

    # NOTE(review): the nesting of the variable scopes below was
    # reconstructed from whitespace-collapsed source; verify against a
    # checkpoint's variable names before relying on it.
    with tf.variable_scope("vae/connect", reuse=tf.AUTO_REUSE):
        with tf.variable_scope("z_mean"):
            z_mean = tf.layers.dense(
                hidden_repres,
                128,
                use_bias=None,
                activation=None,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
            bn_z_mean = vae_utils.mean_normalize_scale(
                z_mean, is_training, "bn_mean", tau=0.5,
                reuse=tf.AUTO_REUSE, **kargs)

        with tf.variable_scope("z_std"):
            z_std = tf.layers.dense(
                hidden_repres,
                128,
                use_bias=True,
                activation=tf.nn.relu,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
            bn_z_std = vae_utils.std_normalize_scale(
                z_std, is_training, "bn_std", tau=0.5,
                reuse=tf.AUTO_REUSE, **kargs)

        gaussian_noise = vae_utils.hidden_sampling(bn_z_mean, bn_z_std,
                                                   **kargs)
        sent_repres_shape = bert_utils.get_shape_list(sent_repres,
                                                      expected_rank=[3])
        with tf.variable_scope("vae/projection"):
            gaussian_noise = tf.layers.dense(
                gaussian_noise,
                sent_repres_shape[-1],
                use_bias=None,
                activation=None,
                kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

        # Add the projected latent sample to every position of the sequence.
        sent_repres += tf.expand_dims(gaussian_noise, axis=1)

    with tf.variable_scope("vae/decoder", reuse=tf.AUTO_REUSE):
        sequence_output = dgcnn_utils.dgcnn(
            sent_repres,
            sequence_mask,
            num_layers=model_config['cnn_num_layers'],
            dilation_rates=model_config.get('cnn_dilation_rates', [1,2]),
            # NOTE(review): `strides` reads the 'cnn_dilation_rates' key,
            # same as the argument above - possibly a copy/paste slip.
            strides=model_config.get('cnn_dilation_rates', [1,1]),
            num_filters=model_config.get('cnn_num_filters', [128,128]),
            kernel_sizes=model_config.get('cnn_filter_sizes', [3,3]),
            is_training=is_training,
            scope_name="textcnn_encoder/textcnn/forward",
            reuse=tf.AUTO_REUSE,
            activation=tf.nn.relu,
            is_casual=model_config['is_casual'],
            padding=model_config.get('padding', 'same')
        )
        sequence_output_logits = model.build_other_output_logits(
            sequence_output, reuse=tf.AUTO_REUSE)

    resc_loss = vae_utils.reconstruction_loss(sequence_output_logits,
                                              input_ids,
                                              name="decoder_resc",
                                              use_tpu=False)
    kl_loss = vae_utils.kl_loss(bn_z_mean, bn_z_std,
                                opt_config.get('num_train_steps', 10000),
                                name="kl_div",
                                use_tpu=False,
                                kl_anneal="kl_anneal")

    loss = resc_loss + kl_loss
    task_loss = loss
    params_size = model_io_fn.count_params(model_config.scope)
    print("==total encoder params==", params_size)

    if kargs.get("feature_distillation", True):
        universal_feature_a = features.get("input_ids_a_features", None)
        universal_feature_b = features.get("input_ids_b_features", None)
        if universal_feature_a is None or universal_feature_b is None:
            tf.logging.info("****** not apply feature distillation *******")
            feature_loss = tf.constant(0.0)
        else:
            feature_a = pooled_feature_dict['feature_a']
            feature_a_shape = bert_utils.get_shape_list(feature_a,
                                                        expected_rank=[2,3])
            pretrain_feature_a_shape = bert_utils.get_shape_list(
                universal_feature_a, expected_rank=[2,3])
            # Project to the teacher's width only when the widths differ.
            if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                with tf.variable_scope(scope + "/feature_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_feature_a = tf.layers.dense(
                        feature_a, pretrain_feature_a_shape[-1])
                tf.logging.info("****** apply auto-encoder for feature compression *******")
            else:
                proj_feature_a = feature_a
            # Scale to unit L2 norm; the norm itself receives no gradient.
            feature_a_norm = tf.stop_gradient(
                tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_a, 2),
                                      axis=-1, keepdims=True)) + 1e-20)
            proj_feature_a /= feature_a_norm

            feature_b = pooled_feature_dict['feature_b']
            if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                with tf.variable_scope(scope + "/feature_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_feature_b = tf.layers.dense(
                        feature_b, pretrain_feature_a_shape[-1])
                tf.logging.info("****** apply auto-encoder for feature compression *******")
            else:
                proj_feature_b = feature_b
            feature_b_norm = tf.stop_gradient(
                tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_b, 2),
                                      axis=-1, keepdims=True)) + 1e-20)
            proj_feature_b /= feature_b_norm

            feature_a_distillation = tf.reduce_mean(
                tf.square(universal_feature_a - proj_feature_a), axis=-1)
            feature_b_distillation = tf.reduce_mean(
                tf.square(universal_feature_b - proj_feature_b), axis=-1)
            feature_loss = tf.reduce_mean(
                (feature_a_distillation + feature_b_distillation) / 2.0
            ) / float(num_task)
            loss += feature_loss
            tf.logging.info("****** apply prertained feature distillation *******")

    if kargs.get("embedding_distillation", True):
        word_embed = model.emb_mat
        random_embed_shape = bert_utils.get_shape_list(word_embed,
                                                       expected_rank=[2,3])
        print("==random_embed_shape==", random_embed_shape)
        pretrained_embed = kargs.get('pretrained_embed', None)
        if pretrained_embed is None:
            tf.logging.info("****** not apply prertained feature distillation *******")
            embed_loss = tf.constant(0.0)
        else:
            pretrain_embed_shape = bert_utils.get_shape_list(
                pretrained_embed, expected_rank=[2,3])
            print("==pretrain_embed_shape==", pretrain_embed_shape)
            if random_embed_shape[-1] != pretrain_embed_shape[-1]:
                with tf.variable_scope(scope + "/embedding_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_embed = tf.layers.dense(word_embed,
                                                 pretrain_embed_shape[-1])
            else:
                proj_embed = word_embed

            embed_loss = tf.reduce_mean(
                tf.reduce_mean(tf.square(proj_embed - pretrained_embed),
                               axis=-1)) / float(num_task)
            loss += embed_loss
            tf.logging.info("****** apply prertained feature distillation *******")

    if mode == tf.estimator.ModeKeys.TRAIN:
        multi_task_config = kargs.get("multi_task_config", {})
        if multi_task_config.get(task_type, {}).get("lm_augumentation", False):
            print("==apply lm_augumentation==")
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss,
             masked_lm_example_loss,
             masked_lm_log_probs) = pretrain.get_masked_lm_output(
                model_config,
                model.get_sequence_output(),
                model.get_embedding_table(),
                masked_lm_positions,
                masked_lm_ids,
                masked_lm_weights,
                reuse=model_reuse)

            # Repeat the per-example task mask for every prediction slot.
            masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
                (1, multi_task_config[task_type]["max_predictions_per_seq"]))
            masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

            masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_loss_mask *= tf.cast(masked_lm_label_weights,
                                           tf.float32)

            masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
            masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
                1e-10 + tf.reduce_sum(masked_lm_loss_mask))
            loss += multi_task_config[task_type]["masked_lm_loss_ratio"] * masked_lm_loss

            masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])

            print(masked_lm_log_probs.get_shape(), "===masked lm log probs===")
            print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
            print(masked_lm_label_weights.get_shape(), "===masked lm mask===")

            lm_acc = build_accuracy(masked_lm_log_probs,
                                    masked_lm_label_ids,
                                    masked_lm_loss_mask)

    if kargs.get("task_invariant", "no") == "yes":
        print("==apply task adversarial training==")
        with tf.variable_scope(scope + "/dann_task_invariant",
                               reuse=model_reuse):
            (_,
             task_example_loss,
             task_logits) = distillation_utils.feature_distillation(
                model.get_pooled_output(),
                1.0,
                features["task_id"],
                kargs.get("num_task", 7),
                dropout_prob,
                True)
            masked_task_example_loss = loss_mask * task_example_loss
            masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                1e-10 + tf.reduce_sum(loss_mask))
            loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    vae_tvars = model_io_fn.get_params("vae",
                                       not_storage_params=not_storage_params)

    if mode == tf.estimator.ModeKeys.TRAIN:
        multi_task_config = kargs.get("multi_task_config", {})
        if multi_task_config.get(task_type, {}).get("lm_augumentation", False):
            print("==apply lm_augumentation==")
            masked_lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions",
                not_storage_params=not_storage_params)
            tvars.extend(masked_lm_pretrain_tvars)

    # Counting is best-effort diagnostics; `except Exception` (rather than a
    # bare `except`) lets KeyboardInterrupt / SystemExit propagate.
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")

    if load_pretrained == "yes":
        [assignment_map,
         initialized_variable_names] = model_io_utils.get_assigment_map_from_checkpoint(
            tvars,
            init_checkpoint,
            exclude_scope="")

        [assignment_map_vae,
         initialized_variable_names_vae] = model_io_utils.get_assigment_map_from_checkpoint(
            vae_tvars,
            init_checkpoint,
            exclude_scope="vae/decoder")

        assignment_map.update(assignment_map_vae)
        initialized_variable_names.update(initialized_variable_names_vae)

        model_io_utils.init_pretrained(assignment_map,
                                       initialized_variable_names,
                                       tvars + vae_tvars,
                                       init_checkpoint)

    if mode == tf.estimator.ModeKeys.TRAIN:
        train_metric_dict = train_metric(input_ids,
                                         sequence_output_logits,
                                         **kargs)
        return_dict = {
            "loss":loss,
            "tvars":tvars + vae_tvars
        }
        return_dict["perplexity"] = train_metric_dict['perplexity']
        return_dict["token_acc"] = train_metric_dict['token_acc']
        return_dict["kl_div"] = kl_loss
        if kargs.get("task_invariant", "no") == "yes":
            return_dict["{}_task_loss".format(task_type)] = masked_task_loss
            task_acc = build_accuracy(task_logits, features["task_id"],
                                      loss_mask)
            return_dict["{}_task_acc".format(task_type)] = task_acc
        if multi_task_config.get(task_type, {}).get("lm_augumentation", False):
            return_dict["{}_masked_lm_loss".format(task_type)] = masked_lm_loss
            return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
        # The distillation terms were divided by num_task when added to the
        # loss; they are scaled back here for reporting.
        if kargs.get("embedding_distillation", True):
            return_dict["embed_loss"] = embed_loss * float(num_task)
        else:
            return_dict["embed_loss"] = task_loss
        if kargs.get("feature_distillation", True):
            return_dict["feature_loss"] = feature_loss * float(num_task)
        else:
            return_dict["feature_loss"] = task_loss
        return_dict["task_loss"] = task_loss
        return return_dict
    elif mode == tf.estimator.ModeKeys.EVAL:
        # NOTE(review): `logits` is not assigned anywhere in this function;
        # unless it is provided by the enclosing scope this raises NameError
        # (perhaps `sequence_output_logits` was intended) - confirm.
        eval_dict = {
            "loss":loss,
            "logits":logits,
            "feature":model.get_pooled_output()
        }
        if kargs.get("adversarial", "no") == "adversarial":
            eval_dict["task_logits"] = task_logits
        return eval_dict
def token_generator(config, input_tensor, output_weights, input_ids,
                    input_ori_ids, input_mask, **kargs):
    """Sample replacement token ids from a masked-LM prediction head.

    Applies the transform + output projection of the MLM head to
    ``input_tensor``, optionally restricts the vocabulary, samples token ids
    (gumbel-softmax or greedy argmax) and merges the samples with
    ``input_ori_ids`` according to ``mask_method``.

    Args:
        config: model config supporting ``.get`` and the attributes
            ``hidden_act``, ``hidden_size``, ``initializer_range`` and
            ``vocab_size``.
        input_tensor: rank-3 tensor (batch x seq x hidden).
        output_weights: matrix contracted with the transformed features via
            ``"abc,dc->abd"`` to produce the logits.
        input_ids: corrupted input ids; compared with ``input_ori_ids`` when
            no ``sampled_binary_mask`` is given.
        input_ori_ids: original (uncorrupted) ids.
        input_mask: mask multiplied into the sampling statistics.
        **kargs: options read here - ``embedding_projection``, ``scope``,
            ``apply_valid_vocab``, ``topk``, ``invalid_size``, ``greedy``,
            ``sampled_binary_mask``, ``mask_method`` and ``use_tpu``.

    Returns:
        Tensor of sampled ids: ``[batch, seq]`` when ``gen_sample`` is 1,
        otherwise ``[batch, seq, gen_sample]``.
    """
    input_shape_list = bert_utils.get_shape_list(input_tensor,
                                                 expected_rank=3)
    batch_size = input_shape_list[0]
    seq_length = input_shape_list[1]

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # NOTE(review): how much of the body sits inside this scope was
    # reconstructed from whitespace-collapsed source - confirm.
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        ln_type = config.get('ln_type', 'postln')
        if ln_type == 'preln':
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = config.get('embedding_size',
                                          config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            # Every ln_type other than 'preln' normalises after the dense.
            if ln_type != 'preln':
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        input_shape_list = bert_utils.get_shape_list(logits, expected_rank=3)
        width = input_shape_list[2]

        apply_valid_vocab = kargs.get("apply_valid_vocab", False)
        if not apply_valid_vocab:
            tf.logging.info("****** normal logits *******")
        elif apply_valid_vocab == 'topk':
            prob, _ = top_k_softmax(logits, kargs.get('topk', 10))
            logits = tf.log(prob + 1e-10)
            tf.logging.info("****** topk logits *******")
        else:
            # Push the first `invalid_size` vocabulary entries to -10000 so
            # they are effectively never sampled.
            invalid_size = kargs.get("invalid_size", 106)
            invalid_mask = tf.cast(
                tf.ones((1, invalid_size)) * (-10000), tf.float32)
            valid_mask = tf.cast(
                tf.zeros((1, config.vocab_size - invalid_size)), tf.float32)
            vocab_mask = tf.concat([invalid_mask, valid_mask], axis=-1)
            # batch x seq x vocab
            logits += tf.cast(vocab_mask, tf.float32)
            tf.logging.info(
                "****** only valid logits ******* , invalid size: %s",
                str(invalid_size))

        logits_tempered = tf.nn.log_softmax(
            logits / config.get("temperature", 1.0))

        flat_logits_tempered = tf.reshape(logits_tempered,
                                          [batch_size * seq_length, width])

        if not kargs.get("greedy", False):
            sampled_logprob_temp, sampled_logprob = gumbel_softmax(
                flat_logits_tempered,
                temperature=1.0,
                samples=config.get('gen_sample', 1),
                greedy=kargs.get("greedy", False))
            samples = tf.argmax(sampled_logprob, axis=1)  # batch x seq
            tf.logging.info("****** normal sample *******")
        else:
            samples = tf.argmax(flat_logits_tempered, axis=-1)
            tf.logging.info("****** greedy sample *******")

        sampled_binary_mask = kargs.get('sampled_binary_mask', None)
        if sampled_binary_mask is not None:
            label_diff_ids = sampled_binary_mask  # 0 for original and 1 for replace
        else:
            label_diff_ids = tf.not_equal(
                tf.cast(input_ids, tf.int32),
                tf.cast(input_ori_ids, tf.int32)  # 0 for original and 1 for replace
            )
        label_diff_ids = tf.cast(label_diff_ids, tf.float32)
        print(label_diff_ids, "===label diff ids===")
        if not kargs.get('use_tpu', True):
            tf.summary.scalar(
                'label_diff_ids',
                tf.reduce_sum(label_diff_ids * tf.cast(input_mask, tf.float32))
                / tf.reduce_sum(tf.cast(input_mask, tf.float32)))

        mask_method = kargs.get('mask_method', 'only_mask')
        num_samples = config.get('gen_sample', 1)

        if num_samples == 1:
            sampled_input_id = tf.reshape(samples, [batch_size, seq_length])
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # Take the sample at changed positions, the original elsewhere.
                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
            elif mask_method == 'all_mask':
                unk_mask = tf.cast(tf.math.equal(input_ori_ids, 100),
                                   tf.float32)  # not replace unk
                cls_mask = tf.cast(tf.math.equal(input_ori_ids, 101),
                                   tf.float32)  # not replace cls
                sep_mask = tf.cast(tf.math.equal(input_ori_ids, 102),
                                   tf.float32)  # not replace sep
                unsampled_mask = (1 - (unk_mask + cls_mask + sep_mask)) * tf.cast(
                    input_mask, tf.float32)
                tf.logging.info("****** all mask sample *******")
                sampled_input_id = unsampled_mask * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - unsampled_mask) * tf.cast(input_ori_ids, tf.float32)
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)
        else:
            sampled_input_id = tf.reshape(
                samples, [batch_size, seq_length, num_samples])
            if mask_method == 'only_mask':
                tf.logging.info("****** only mask sample *******")
                # Fixes relative to the original branch:
                #  * it read `model_config`, which is not a name in this
                #    function - the parameter is `config`;
                #  * integer tensors were passed to einsum together with a
                #    float operand; they are now cast to float32 first;
                #  * the merge used `(1 - input_ori_ids) * label_diff_ids`,
                #    now `(1 - label_diff_ids) * input_ori_ids` to match the
                #    single-sample branch.
                tile = tf.ones((1, num_samples))

                # batch x seq_length x 1 -> batch x seq_length x num_samples
                label_diff_ids = tf.expand_dims(label_diff_ids, axis=-1)
                label_diff_ids = tf.einsum('abc,cd->abd', label_diff_ids, tile)

                input_ori_ids = tf.cast(
                    tf.expand_dims(input_ori_ids, axis=-1), tf.float32)
                input_ori_ids = tf.einsum('abc,cd->abd', input_ori_ids, tile)

                sampled_input_id = (label_diff_ids) * tf.cast(
                    sampled_input_id, tf.float32
                ) + (1 - label_diff_ids) * input_ori_ids
                sampled_input_id = tf.cast(sampled_input_id, tf.int32)

                input_mask = tf.cast(
                    tf.expand_dims(input_mask, axis=-1), tf.float32)
                input_mask = tf.einsum('abc,cd->abd', input_mask, tile)

        if not kargs.get('use_tpu', True):
            sampled_not_equal_id = tf.not_equal(
                tf.cast(sampled_input_id, tf.int32),
                tf.cast(input_ori_ids, tf.int32))
            sampled_not_equal = tf.cast(sampled_not_equal_id, tf.float32) * tf.cast(
                input_mask, tf.float32)

            sampled_equal_id = tf.equal(tf.cast(sampled_input_id, tf.int32),
                                        tf.cast(input_ori_ids, tf.int32))

            if mask_method == 'only_mask':
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(label_diff_ids, tf.float32)))
            elif mask_method == 'all_mask':
                # NOTE(review): `unsampled_mask` only exists on the
                # single-sample 'all_mask' path above.
                sampled_equal = tf.cast(sampled_equal_id, tf.float32) * tf.cast(
                    unsampled_mask, tf.float32)
                # NOTE(review): this summary records the raw count; the
                # normalised value computed below is never used.
                tf.summary.scalar('generator_equal_sample_acc',
                                  tf.reduce_sum(sampled_equal))
                sampled_not_equal = 1 - tf.reduce_sum(sampled_not_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))
                sampled_equal = tf.reduce_sum(sampled_equal) / (
                    1e-10 + tf.reduce_sum(tf.cast(unsampled_mask, tf.float32)))

            tf.summary.scalar('generator_sample_acc', sampled_not_equal)

        return sampled_input_id
def apply_gradients(self, grads_and_vars, global_step=None, name=None,
                    learning_rate=None):
    """Build the grouped update op for ``grads_and_vars``. See base class.

    Rank-2 parameters for which ``self._use_factored`` is true keep separate
    row/column second-moment accumulators (``adafactor_vr`` /
    ``adafactor_vc``); all others keep one full accumulator
    (``adafactor_v``). The update is clipped by its RMS and optionally
    combined with decoupled weight decay before the learning rate is applied.

    Args:
        grads_and_vars: iterable of ``(gradient, variable)`` pairs; pairs
            with a ``None`` member are skipped.
        global_step: not used by this implementation.
        name: optional name for the returned group op.
        learning_rate: step size to use; falls back to
            ``self.learning_rate`` when None.

    Returns:
        A ``tf.group`` op bundling all accumulator and parameter assignments.
    """
    # The original log calls passed `learning_rate` without a `%s`
    # placeholder, which makes the logging module report a formatting error.
    if learning_rate is None:
        learning_rate = self.learning_rate
        tf.logging.info("***** use default learning rate ***** %s",
                        learning_rate)
    else:
        tf.logging.info("***** use provided learning rate ***** %s",
                        learning_rate)

    assignments = []
    for (grad, param) in grads_and_vars:
        if grad is None or param is None:
            continue

        param_name = self._get_variable_name(param.name)
        tf.logging.info("***** apply gradients parameter name ***** %s",
                        param_name)
        tf.logging.info("***** param: %s learning rate: %s ***** ",
                        param_name, str(learning_rate))

        shape_list = bert_utils.get_shape_list(param, expected_rank=[1, 2])

        decay_rate = self.beta_2
        grad_squared = tf.square(grad) + self.epsilon1

        # Use the resolved learning rate. The original always read
        # `self.learning_rate` here, silently ignoring a caller-supplied
        # value even though it was logged as being in use.
        update_scale = learning_rate

        # HACK: Make things dependent on grad.
        # This confounds the XLA rewriter and keeps it from fusing
        # computations across different variables. This fusion is bad for
        # HBM usage, since it causes the gradients to persist in memory.
        grad_squared_mean = tf.reduce_mean(grad_squared)
        decay_rate += grad_squared_mean * 1e-30
        update_scale += grad_squared_mean * 1e-30
        # END HACK

        if self._use_factored(shape_list):
            num_rows, num_columns = shape_list

            vr = tf.get_variable(name=param_name + "/adafactor_vr",
                                 shape=[num_rows],
                                 dtype=tf.float32,
                                 trainable=False,
                                 initializer=tf.zeros_initializer())
            vc = tf.get_variable(name=param_name + "/adafactor_vc",
                                 shape=[num_columns],
                                 dtype=tf.float32,
                                 trainable=False,
                                 initializer=tf.zeros_initializer())

            next_vr = decay_rate * vr + (1 - decay_rate) * tf.reduce_mean(
                grad_squared, 1)
            next_vc = decay_rate * vc + (1 - decay_rate) * tf.reduce_mean(
                grad_squared, 0)

            long_term_mean = tf.reduce_mean(next_vr, -1, keepdims=True)
            r_factor = tf.rsqrt(next_vr / long_term_mean + self.epsilon1)
            c_factor = tf.rsqrt(next_vc + self.epsilon1)
            update = grad * tf.expand_dims(r_factor, -1) * tf.expand_dims(
                c_factor, -2)

            assignments.append(
                vr.assign(next_vr, use_locking=self.use_locking))
            assignments.append(
                vc.assign(next_vc, use_locking=self.use_locking))
        else:
            v = tf.get_variable(name=param_name + "/adafactor_v",
                                shape=shape_list,
                                dtype=tf.float32,
                                trainable=False,
                                initializer=tf.zeros_initializer())
            next_v = decay_rate * v + (1 - decay_rate) * grad_squared

            assignments.append(
                v.assign(next_v, use_locking=self.use_locking))
            update = grad * tf.rsqrt(next_v + self.epsilon1)

        # Scale the update down when its RMS exceeds `clipping_rate`.
        clipping_denom = tf.maximum(
            1.0, reduce_rms(update) / self.clipping_rate)
        update /= clipping_denom

        # Do weight decay.
        # Adding the square of the weights to the loss is *not* the correct
        # way to apply weight decay with an adaptive optimizer, because it
        # would interact with the second-moment accumulators. Instead the
        # weights are decayed directly in the update, which does not touch
        # the accumulators.
        if self._do_use_weight_decay(param_name):
            update += self.weight_decay_rate * param

        update_with_lr = update_scale * update
        next_param = param - update_with_lr

        assignments.append(
            param.assign(next_param, use_locking=self.use_locking))
    return tf.group(*assignments, name=name)
def random_input_ids_generation_v1(config, input_ori_ids, input_mask, **kargs):
    """Corrupt ``input_ori_ids`` by random masking and token replacement.

    Positions are selected with probability ``mask_probability`` among tokens
    that are inside ``input_mask`` and are not ids 100/101/102. Each selected
    position is then, in order: replaced by a randomly sampled vocabulary id
    with probability 0.1; otherwise kept unchanged with probability 0.1;
    otherwise replaced by ``mask_id``.

    Args:
        config: config supporting ``.get`` and exposing ``vocab_size``.
        input_ori_ids: rank-2 tensor of original token ids.
        input_mask: tensor multiplied element-wise with the ids mask.
        **kargs: ``mask_id`` (default 103), ``annealed_mask_prob`` (default
            False) and ``num_train_steps`` (default 10000, used only for the
            annealed schedule).

    Returns:
        ``[output_input_ids, output_sampled_binary_mask]``, both int32; the
        mask is 1 at every selected position, including the kept ones.
    """
    mask_token_id = kargs.get('mask_id', 103)

    def _draw_bernoulli(probabilities):
        # One float32 0/1 draw per entry of `probabilities`.
        distribution = tf.distributions.Bernoulli(probs=probabilities,
                                                  dtype=tf.float32)
        return tf.cast(distribution.sample(), tf.float32)

    input_ori_ids = tf.cast(input_ori_ids, tf.int32)
    input_mask = tf.cast(input_mask, tf.int32)

    # unk (100), cls (101) and sep (102) are never selected for corruption.
    special_flags = [
        tf.cast(tf.math.equal(input_ori_ids, token_id), tf.float32)
        for token_id in (100, 101, 102)
    ]
    none_replace_mask = special_flags[0] + special_flags[1] + special_flags[2]

    ids_shape = bert_utils.get_shape_list(input_ori_ids, expected_rank=2)
    batch_size, seq_length = ids_shape[0], ids_shape[1]

    if kargs.get('annealed_mask_prob', False):
        keep_schedule = tf.train.polynomial_decay(
            0.95,
            tf.train.get_or_create_global_step(),
            kargs.get("num_train_steps", 10000) * 0.1,
            end_learning_rate=0.85,
            power=1.0,
            cycle=False)
        mask_probability = 1 - keep_schedule
        tf.logging.info("**** apply annealed_mask_prob **** ")
    else:
        mask_probability = 0.15
        tf.logging.info("**** apply fixed_mask_prob %s **** ",
                        str(mask_probability))

    # 1 where a token may be selected, 0 for padding and special tokens.
    candidate_positions = tf.ones_like(input_ori_ids) * input_mask * (
        1 - tf.cast(none_replace_mask, tf.int32))
    selection_probs = mask_probability * tf.cast(candidate_positions,
                                                 tf.float32)

    selected = _draw_bernoulli(selection_probs)
    # 10% of the selected positions get a random vocabulary id ...
    randomized = _draw_bernoulli(0.1 * selected)
    # ... 10% of the remaining ones keep their original id ...
    kept = _draw_bernoulli(0.1 * (selected - randomized))
    # ... and everything else that was selected becomes the mask token.
    masked = selected - randomized - kept

    uniform_logits = tf.random.uniform(
        [batch_size, seq_length, config.vocab_size],
        minval=0.0,
        maxval=1.0,
        dtype=tf.float32)
    flat_log_probs = tf.reshape(tf.nn.log_softmax(uniform_logits),
                                [batch_size * seq_length, -1])

    _, gumbel_sample = gumbel_softmax(
        flat_log_probs,
        temperature=0.1,
        samples=config.get('gen_sample', 1))

    random_ids = tf.argmax(gumbel_sample, axis=1)  # batch x seq
    random_ids = tf.cast(tf.reshape(random_ids, [batch_size, seq_length]),
                         tf.float32)

    # Mix the three sources in float32, then cast back to ids.
    float_ori_ids = tf.cast(input_ori_ids, tf.float32)
    output_input_ids = mask_token_id * masked * tf.ones_like(float_ori_ids)
    output_input_ids += random_ids * randomized
    output_input_ids += (1 - (masked + randomized)) * float_ori_ids

    output_sampled_binary_mask = tf.cast(masked + randomized + kept, tf.int32)

    return [tf.cast(output_input_ids, tf.int32), output_sampled_binary_mask]
def model_fn(features, labels, mode):
    """Build a PREDICT-mode ``EstimatorSpec`` exposing L2-normalised sentence vectors.

    The encoder is always instantiated with
    ``tf.estimator.ModeKeys.PREDICT``; the ``mode`` argument is accepted but
    never consulted, and dropout-dependent helpers receive
    ``is_training=False``.

    Args:
        features: Feature dict, forwarded unchanged to the model returned by
            ``model_zoo(model_config)``.
        labels: Forwarded unchanged to the model.
        mode: Ignored by this function.

    Returns:
        A ``tf.estimator.EstimatorSpec`` whose ``predictions`` and
        ``export_outputs["output"]`` both carry the key ``'sentence_pres'``.
    """
    model_api = model_zoo(model_config)
    model = model_api(model_config,
                      features,
                      labels,
                      tf.estimator.ModeKeys.PREDICT,
                      target,
                      reuse=model_reuse,
                      cnn_type=model_config.get('cnn_type', 'bi_dgcnn'),
                      **kargs)

    is_training = False

    with tf.variable_scope(model_config.scope + "/feature_output",
                           reuse=tf.AUTO_REUSE):
        # Kept only for its rank validation: the pooled output must be 2-D.
        bert_utils.get_shape_list(model.get_pooled_output(), expected_rank=2)
        sentence_pres = model.get_pooled_output()
        sentence_pres = tf.layers.dense(
            sentence_pres,
            128,
            use_bias=True,
            activation=tf.tanh,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

    if kargs.get('apply_head_proj', False):
        with tf.variable_scope(model_config.scope + "/head_proj",
                               reuse=tf.AUTO_REUSE):
            sentence_pres = simclr_utils.projection_head(
                sentence_pres,
                is_training,
                head_proj_dim=128,
                num_nlh_layers=1,
                head_proj_mode='nonlinear',
                name='head_contrastive')

    # The tiny offset keeps an all-zero vector from producing a 0/0 norm.
    l2_sentence_pres = tf.nn.l2_normalize(sentence_pres + 1e-20, axis=-1)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)

    # Parameter counting is purely informational, so a failure must not
    # abort graph construction. Catch ``Exception`` rather than using a bare
    # ``except`` so KeyboardInterrupt / SystemExit still propagate.
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")
    print(tvars)

    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars,
                                    init_checkpoint,
                                    exclude_scope=exclude_scope)

    estimator_spec = tf.estimator.EstimatorSpec(
        mode=tf.estimator.ModeKeys.PREDICT,
        predictions={
            'sentence_pres': l2_sentence_pres,
        },
        export_outputs={
            "output":
            tf.estimator.export.PredictOutput({
                'sentence_pres': l2_sentence_pres,
            })
        })
    return estimator_spec
def model_fn(features, labels, mode):
    """Build the multiple-choice classification graph on top of BERT.

    ``features["input_ids"]`` must be rank 3; its three dimensions are used
    as (batch, choice, sequence) and the first two are merged before the
    tensors are fed to the encoder.

    Args:
        features: Dict providing ``"input_ids"``, ``"input_mask"``,
            ``"segment_ids"`` and ``"label_ids"``.
        labels: Unused.
        mode: A ``tf.estimator.ModeKeys`` value. Dropout is only enabled for
            ``TRAIN``.

    Returns:
        ``[train_op, loss, per_example_loss, logits]`` in ``TRAIN`` mode,
        otherwise ``[loss, loss, per_example_loss, logits]`` (the loss fills
        the slot the train op would occupy).
    """
    training = mode == tf.estimator.ModeKeys.TRAIN

    raw_input_ids = features["input_ids"]
    raw_input_mask = features["input_mask"]
    raw_segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    batch_size, choice_num, seq_length = bert_utils.get_shape_list(
        raw_input_ids, expected_rank=3)

    def _merge_choices(tensor):
        # Collapse the batch and choice axes into a single leading axis.
        return tf.reshape(tensor, [batch_size * choice_num, seq_length])

    input_ids = _merge_choices(raw_input_ids)
    input_mask = _merge_choices(raw_input_mask)
    segment_ids = _merge_choices(raw_segment_ids)

    if training:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = attention_probs_dropout_prob = dropout_prob = 0.0

    model = bert.Bert(model_config)
    model.build_embedder(input_ids,
                         segment_ids,
                         hidden_dropout_prob,
                         attention_probs_dropout_prob,
                         reuse=reuse)
    model.build_encoder(input_ids,
                        input_mask,
                        hidden_dropout_prob,
                        attention_probs_dropout_prob,
                        reuse=reuse)
    model.build_pooler(reuse=reuse)

    # The explicit comparison with True is kept on purpose: only a value
    # equal to True moves the classifier into the "_finetuning" scope.
    scope = (model_config.scope + "_finetuning"
             if model_io_config.fix_lm == True else model_config.scope)

    with tf.variable_scope(scope, reuse=reuse):
        loss, per_example_loss, logits = classifier.multi_choice_classifier(
            model_config, model.get_pooled_output(), num_labels, label_ids,
            dropout_prob)

    # `model_io_fn` comes from the enclosing scope.
    pretrained_tvars = model_io_fn.get_params(model_config.scope)
    if load_pretrained:
        model_io_fn.load_pretrained(pretrained_tvars,
                                    init_checkpoint,
                                    exclude_scope=exclude_scope)

    tvars = model_io_fn.get_params(scope,
                                   not_storage_params=not_storage_params)
    model_io_fn.set_saver(var_lst=tvars)

    # Both modes report the same variable list first.
    model_io_fn.print_params(tvars, string=", trainable params")

    if not training:
        return [loss, loss, per_example_loss, logits]

    pending_updates = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(pending_updates):
        optimizer_fn = optimizer.Optimizer(opt_config)
        train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr,
                                             opt_config.num_train_steps)
    return [train_op, loss, per_example_loss, logits]
def build_output_logits(self, **kargs):
    """Project encoder states onto the vocabulary and store them in ``self.logits``.

    Keyword Args:
        layer_num: Index passed to ``self.get_encoder_layers`` (default -1).
        embedding_projection: Optional matrix; when given, the transformed
            states are multiplied by it (``"abc,dc->abd"``) before the final
            product with ``self.embedding_table``.
        scope: Optional prefix; variables live under
            ``<scope>/cls/predictions`` or ``cls/predictions`` when empty.

    Side effects:
        Sets ``self.sequence_output`` and ``self.logits``. Returns ``None``.
    """
    self.sequence_output = self.get_encoder_layers(kargs.get("layer_num", -1))

    # Called for its rank validation: the encoder output must be rank 3.
    bert_utils.get_shape_list(self.sequence_output, expected_rank=3)

    embedding_projection = kargs.get('embedding_projection', None)

    scope_prefix = kargs.get('scope', None)
    scope = (scope_prefix + '/' + 'cls/predictions'
             if scope_prefix else 'cls/predictions')
    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        pre_ln = self.config.get('ln_type', 'postln') == 'preln'

        # Pre-LN normalises before the transform; every other setting is
        # treated as post-LN and normalises after it.
        if pre_ln:
            hidden = bert_modules.layer_norm(self.sequence_output)
            tf.logging.info("**** pre ln doing layer norm ****")
        else:
            hidden = self.sequence_output
            tf.logging.info("**** post ln ****")

        if self.config.get("embedding", "none_factorized") == "none_factorized":
            projection_width = self.config.hidden_size
            tf.logging.info("==not using embedding factorized==")
        else:
            projection_width = self.config.get('embedding_size',
                                               self.config.hidden_size)
            tf.logging.info(
                "==using embedding factorized: embedding size: %s==",
                str(projection_width))

        with tf.variable_scope("transform"):
            hidden = tf.layers.dense(
                hidden,
                units=projection_width,
                activation=bert_modules.get_activation(self.config.hidden_act),
                kernel_initializer=bert_modules.create_initializer(
                    self.config.initializer_range))
            if pre_ln:
                tf.logging.info("**** pre ln ****")
            else:
                hidden = bert_modules.layer_norm(hidden)
                tf.logging.info("**** post ln doing layer norm ****")

        if embedding_projection is None:
            print("==no need for embedding projection==")
        else:
            # batch x seq x hidden, embedding x hidden
            print(hidden.get_shape(), embedding_projection.get_shape())
            hidden = tf.einsum("abc,dc->abd", hidden, embedding_projection)

        output_bias = tf.get_variable("output_bias",
                                      shape=[self.config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        vocab_scores = tf.einsum("abc,dc->abd", hidden, self.embedding_table)
        self.logits = tf.nn.bias_add(vocab_scores, output_bias)
def model_fn(features, labels, mode):
    """Build one task head of the multi-task graph, with optional distillation.

    The total loss is the sum of, when enabled: feature distillation,
    embedding distillation, the masked task classification loss, masked-LM
    augmentation (TRAIN only) and the task-invariant adversarial loss.

    Fixes relative to the previous revision:
      * ``pooled_feature_dict`` was undefined when ``get_pooled_output`` was
        ``"pooled_output"`` (the default) yet was read by feature
        distillation; distillation is now skipped with a warning instead.
      * An unsupported ``get_pooled_output`` value now raises ``ValueError``
        up front instead of a late ``NameError``.
      * A missing ``multi_task_config`` entry for the task no longer raises
        ``KeyError``; it simply disables LM augmentation.
      * ``task_logits`` is only exported in EVAL when it was actually built.
      * The bare ``except`` around parameter counting is narrowed.

    Args:
        features: Dict providing ``"<task>_label_ids"``,
            ``"<task>_loss_multipiler"`` and, depending on the enabled
            options, ``"input_ids_a_features"``, ``"input_ids_b_features"``,
            ``"masked_lm_positions"``, ``"masked_lm_ids"``,
            ``"masked_lm_weights"`` and ``"task_id"``.
        labels: Unused.
        mode: A ``tf.estimator.ModeKeys`` value.

    Returns:
        In TRAIN mode a dict with ``"loss"``, ``"logits"``, ``"task_num"``,
        ``"tvars"``, ``"<task>_acc"``, ``"embed_loss"``, ``"feature_loss"``,
        ``"task_loss"`` and optional per-feature entries. In EVAL mode a dict
        with ``"loss"``, ``"logits"``, ``"feature"`` and optionally
        ``"task_logits"``. Any other mode returns ``None``.

    Raises:
        ValueError: If ``kargs["get_pooled_output"]`` is neither
            ``"pooled_output"`` nor ``"task_output"``.
    """
    task_type = kargs.get("task_type", "cls")
    label_ids = features["{}_label_ids".format(task_type)]
    num_task = kargs.get('num_task', 1)
    is_training = mode == tf.estimator.ModeKeys.TRAIN

    # Tolerant lookup: a task without an entry just has no LM augmentation.
    task_config = kargs.get("multi_task_config", {}).get(task_type, {})
    use_lm_augmentation = bool(
        is_training and task_config.get("lm_augumentation", False))
    use_task_invariant = kargs.get("task_invariant", "no") == "yes"

    model_io_fn = model_io.ModelIO(model_io_config)

    dropout_prob = model_config.dropout_prob if is_training else 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    pooled_output_type = kargs.get("get_pooled_output", "pooled_output")
    # Only the "task_output" path yields the per-sentence features that
    # feature distillation needs.
    pooled_feature_dict = None
    if pooled_output_type == "pooled_output":
        pooled_feature = model.get_pooled_output()
    elif pooled_output_type == "task_output":
        pooled_feature_dict = model.get_task_output()
        pooled_feature = pooled_feature_dict['pooled_feature']
    else:
        raise ValueError(
            "unsupported get_pooled_output: {!r}".format(pooled_output_type))

    loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                        tf.float32)
    loss = tf.constant(0.0)

    params_size = model_io_fn.count_params(model_config.scope)
    print("==total encoder params==", params_size)

    def _project_and_normalize(feature, target_dim, needs_projection):
        # Map a student feature to the teacher width (if they differ) and
        # scale it to unit L2 norm; the norm itself receives no gradient.
        if needs_projection:
            with tf.variable_scope(scope + "/feature_proj",
                                   reuse=tf.AUTO_REUSE):
                projected = tf.layers.dense(feature, target_dim)
            tf.logging.info(
                "****** apply auto-encoder for feature compression *******")
        else:
            projected = feature
        norm = tf.stop_gradient(
            tf.sqrt(
                tf.reduce_sum(tf.pow(projected, 2), axis=-1, keepdims=True)) +
            1e-20)
        return projected / norm

    if kargs.get("feature_distillation", True):
        universal_feature_a = features.get("input_ids_a_features", None)
        universal_feature_b = features.get("input_ids_b_features", None)
        if universal_feature_a is None or universal_feature_b is None:
            tf.logging.info("****** not apply feature distillation *******")
            feature_loss = tf.constant(0.0)
        elif pooled_feature_dict is None:
            tf.logging.warning(
                "feature distillation needs get_pooled_output='task_output'; "
                "skipping it")
            tf.logging.info("****** not apply feature distillation *******")
            feature_loss = tf.constant(0.0)
        else:
            feature_a = pooled_feature_dict['feature_a']
            feature_a_shape = bert_utils.get_shape_list(feature_a,
                                                        expected_rank=[2, 3])
            pretrain_feature_a_shape = bert_utils.get_shape_list(
                universal_feature_a, expected_rank=[2, 3])
            teacher_dim = pretrain_feature_a_shape[-1]
            # The width check on feature_a also decides for feature_b.
            needs_projection = feature_a_shape[-1] != teacher_dim

            proj_feature_a = _project_and_normalize(feature_a, teacher_dim,
                                                    needs_projection)
            feature_b = pooled_feature_dict['feature_b']
            proj_feature_b = _project_and_normalize(feature_b, teacher_dim,
                                                    needs_projection)

            feature_a_distillation = tf.reduce_mean(
                tf.square(universal_feature_a - proj_feature_a), axis=-1)
            feature_b_distillation = tf.reduce_mean(
                tf.square(universal_feature_b - proj_feature_b), axis=-1)
            feature_loss = tf.reduce_mean(
                (feature_a_distillation + feature_b_distillation) /
                2.0) / float(num_task)
            loss += feature_loss
            tf.logging.info(
                "****** apply prertained feature distillation *******")

    if kargs.get("embedding_distillation", True):
        word_embed = model.emb_mat
        random_embed_shape = bert_utils.get_shape_list(word_embed,
                                                       expected_rank=[2, 3])
        print("==random_embed_shape==", random_embed_shape)
        pretrained_embed = kargs.get('pretrained_embed', None)
        if pretrained_embed is None:
            tf.logging.info(
                "****** not apply prertained feature distillation *******")
            embed_loss = tf.constant(0.0)
        else:
            pretrain_embed_shape = bert_utils.get_shape_list(
                pretrained_embed, expected_rank=[2, 3])
            print("==pretrain_embed_shape==", pretrain_embed_shape)
            if random_embed_shape[-1] != pretrain_embed_shape[-1]:
                with tf.variable_scope(scope + "/embedding_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_embed = tf.layers.dense(word_embed,
                                                 pretrain_embed_shape[-1])
            else:
                proj_embed = word_embed
            embed_loss = tf.reduce_mean(
                tf.reduce_mean(tf.square(proj_embed - pretrained_embed),
                               axis=-1)) / float(num_task)
            loss += embed_loss
            tf.logging.info(
                "****** apply prertained feature distillation *******")

    with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                           reuse=task_layer_reuse):
        (_, per_example_loss,
         logits) = classifier.classifier(model_config, pooled_feature,
                                         num_labels, label_ids, dropout_prob)

    # Examples that do not belong to this task carry a zero multiplier.
    masked_per_example_loss = per_example_loss * loss_mask
    task_loss = tf.reduce_sum(masked_per_example_loss) / (
        1e-10 + tf.reduce_sum(loss_mask))
    loss += task_loss

    if use_lm_augmentation:
        print("==apply lm_augumentation==")
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = pretrain.get_masked_lm_output(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=model_reuse)

        # Broadcast the per-example task mask over every prediction slot.
        masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
            (1, task_config["max_predictions_per_seq"]))
        masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

        masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
        masked_lm_loss_mask *= tf.cast(masked_lm_label_weights, tf.float32)

        masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
        masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
            1e-10 + tf.reduce_sum(masked_lm_loss_mask))
        loss += task_config["masked_lm_loss_ratio"] * masked_lm_loss

        masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])

        print(masked_lm_log_probs.get_shape(), "===masked lm log probs===")
        print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
        print(masked_lm_label_weights.get_shape(), "===masked lm mask===")

        lm_acc = build_accuracy(masked_lm_log_probs, masked_lm_label_ids,
                                masked_lm_loss_mask)

    task_logits = None
    if use_task_invariant:
        print("==apply task adversarial training==")
        with tf.variable_scope(scope + "/dann_task_invariant",
                               reuse=model_reuse):
            (_, task_example_loss,
             task_logits) = distillation_utils.feature_distillation(
                 model.get_pooled_output(), 1.0, features["task_id"],
                 kargs.get("num_task", 7), dropout_prob, True)
            masked_task_example_loss = loss_mask * task_example_loss
            masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                1e-10 + tf.reduce_sum(loss_mask))
            loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)

    if use_lm_augmentation:
        print("==apply lm_augumentation==")
        masked_lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)
        tvars.extend(masked_lm_pretrain_tvars)

    # Informational only; never let a counting failure abort the build, but
    # do let KeyboardInterrupt / SystemExit through.
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")

    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars,
                                    init_checkpoint,
                                    exclude_scope=exclude_scope)

    if is_training:
        acc = build_accuracy(logits, label_ids, loss_mask)
        return_dict = {
            "loss": loss,
            "logits": logits,
            "task_num": tf.reduce_sum(loss_mask),
            "tvars": tvars
        }
        return_dict["{}_acc".format(task_type)] = acc
        if use_task_invariant:
            return_dict["{}_task_loss".format(task_type)] = masked_task_loss
            task_acc = build_accuracy(task_logits, features["task_id"],
                                      loss_mask)
            return_dict["{}_task_acc".format(task_type)] = task_acc
        if use_lm_augmentation:
            return_dict["{}_masked_lm_loss".format(
                task_type)] = masked_lm_loss
            return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
        # Distillation losses were divided by num_task when added to the
        # total; report them un-scaled. When a term is disabled the task
        # loss is reported in its slot, as before.
        if kargs.get("embedding_distillation", True):
            return_dict["embed_loss"] = embed_loss * float(num_task)
        else:
            return_dict["embed_loss"] = task_loss
        if kargs.get("feature_distillation", True):
            return_dict["feature_loss"] = feature_loss * float(num_task)
        else:
            return_dict["feature_loss"] = task_loss
        return_dict["task_loss"] = task_loss
        return return_dict
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_dict = {
            "loss": loss,
            "logits": logits,
            "feature": model.get_pooled_output()
        }
        # NOTE(review): this is gated on the "adversarial" option while the
        # logits are built under "task_invariant"; confirm which key is
        # intended. The guard avoids a NameError when only one is set.
        if (kargs.get("adversarial", "no") == "adversarial"
                and task_logits is not None):
            eval_dict["task_logits"] = task_logits
        return eval_dict
def discriminator_metric_eval(input_dict):
    """Add real-vs-fake discriminator metrics to ``output_dict`` and return it.

    Real examples are labelled 1 and fake examples 0; predictions are the
    argmax of the corresponding logits. ``output_dict`` and ``kargs`` come
    from the enclosing scope.

    Args:
        input_dict: Dict with rank-2 ``'true_logits'`` and ``'fake_logits'``.

    Returns:
        ``output_dict`` extended with overall and per-class F1, precision
        and recall entries.
    """
    real_logits = input_dict['true_logits']
    fake_logits = input_dict['fake_logits']

    batch_size = bert_utils.get_shape_list(real_logits, expected_rank=[2])[0]

    true_labels = tf.cast(tf.ones(batch_size), tf.int32)
    fake_labels = tf.cast(tf.zeros(batch_size), tf.int32)

    pred_true_label = tf.argmax(real_logits, axis=-1)
    pred_fake_label = tf.argmax(fake_logits, axis=-1)

    all_pred_label = tf.concat([pred_true_label, pred_fake_label], axis=0)
    all_true_label = tf.concat([true_labels, fake_labels], axis=0)

    def _macro(metric_fn, **extra):
        # Two-class macro-averaged metric over the concatenated batch.
        return metric_fn(all_true_label,
                         all_pred_label,
                         2,
                         average="macro",
                         **extra)

    def _per_class_entries():
        # Class 0 = "original", class 1 = "replaced"; built in a fixed order.
        table = (
            ('discriminator_f1_original', tf_metrics.f1, [0]),
            ('discriminator_f1_replaced', tf_metrics.f1, [1]),
            ('discriminator_precision_original', tf_metrics.precision, [0]),
            ('discriminator_precision_replaced', tf_metrics.precision, [1]),
            ('discriminator_recall_original', tf_metrics.recall, [0]),
            ('discriminator_recall_replaced', tf_metrics.recall, [1]),
        )
        return [(key, _macro(metric_fn, pos_indices=pos))
                for key, metric_fn, pos in table]

    if not kargs.get('use_tpu', True):
        overall = [
            ('discriminator_f1', _macro(tf_metrics.f1)),
            ('discriminator_precison', _macro(tf_metrics.precision)),
            ('discriminator_recall', _macro(tf_metrics.recall)),
        ]
        entries = overall + _per_class_entries()
    else:
        # On this path overall recall/precision come from the built-in
        # streaming metrics applied to one-hot encodings.
        discriminator_recall = tf.compat.v1.metrics.recall(
            tf.one_hot(all_true_label, 2), tf.one_hot(all_pred_label, 2))
        discriminator_precison = tf.compat.v1.metrics.precision(
            tf.one_hot(all_true_label, 2), tf.one_hot(all_pred_label, 2))
        discriminator_f1 = _macro(tf_metrics.f1)
        entries = _per_class_entries() + [
            ('discriminator_f1', discriminator_f1),
            ('discriminator_precison', discriminator_precison),
            ('discriminator_recall', discriminator_recall),
        ]

    for key, value in entries:
        output_dict[key] = value
    return output_dict
def sample_sequence_without_cache(model_api,
                                  model_config,
                                  mode,
                                  features,
                                  target="",
                                  start_token=101,
                                  batch_size=None,
                                  seq_length=None,
                                  context=None,
                                  temperature=1,
                                  n_samples=1,
                                  top_k=0,
                                  end_token=102,
                                  greedy_or_sample="sample",
                                  gumbel_temp=0.01,
                                  estimator="straight_through",
                                  back_prop=True,
                                  swap_memory=True,
                                  max_seq_length=512,
                                  **kargs):
    """Generate a token sequence position by position with a fully unrolled loop.

    Every generation step re-instantiates ``model_api`` on the whole
    partially filled sequence (variables shared via ``tf.AUTO_REUSE``) and
    reads the logits of the previous position.

    Args:
        model_api: Callable building the model; its result must provide
            ``get_sequence_output_logits()``.
        model_config: Model configuration; ``vocab_size`` is read here.
        mode: Forwarded to ``model_api``.
        features: Dict; only the shape of ``features["input_ids"]`` is used.
        target: Forwarded to ``model_api``.
        start_token: Token used as the one-token context when ``context`` is
            ``None``.
        batch_size: Ignored; overwritten from ``features["input_ids"]``.
        seq_length: Ignored; overwritten from ``features["input_ids"]``.
        context: Optional rank-2 tensor of prefix token ids.
        temperature: Divisor applied to the logits before sampling.
        n_samples: Not used in this function.
        top_k: Not used in this function.
        end_token: Token written to the last position before generation.
        greedy_or_sample: ``"sample"`` draws with ``tf.multinomial``; any
            other value takes the argmax. Only the plain ``body`` reads it.
        gumbel_temp: Temperature passed to ``gumbel_softmax``.
        estimator: ``"straight_through"`` or ``"soft"`` select the Gumbel
            bodies; any other value selects the plain ``body``.
        back_prop: Not used (only referenced by the removed while_loop).
        swap_memory: Not used (only referenced by the removed while_loop).
        max_seq_length: Upper bound of the unrolled Python loop.
        **kargs: ``"mask_type"`` is read here; everything is forwarded to
            ``model_api``.

    Returns:
        Dict with ``"samples"``, ``"mask_sequence"``, ``"logits"`` and
        ``"input_mask"``; for the two Gumbel estimators also
        ``"gumbel_probs"``.
    """
    input_shape = bert_utils.get_shape_list(features["input_ids"],
                                            expected_rank=[2, 3])
    # The batch_size / seq_length arguments are discarded here.
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    actual_length = seq_length

    if context is None:
        assert start_token is not None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
        context = tf.cast(context, tf.int32)
        context_shape = bert_utils.get_shape_list(context, expected_rank=[2])
        print(context.get_shape(), "===init context shape===")
    else:
        context = tf.cast(context, tf.int32)
        context_shape = bert_utils.get_shape_list(context, expected_rank=[2])
        batch_size = input_shape[0]

    samples = tf.cast(tf.zeros((batch_size, actual_length)), tf.int32)
    end_mask = tf.expand_dims(tf.one_hot(actual_length - 1, actual_length),
                              axis=(0))
    samples += end_token * tf.cast(
        end_mask, tf.int32)  # make sure last token is end token

    # Scatter the context tokens into the leading positions of `samples`.
    start_mask = tf.one_hot(tf.range(0, context_shape[1]), actual_length)
    samples += tf.cast(
        tf.einsum("ab,bc->ac", tf.cast(context, tf.float32),
                  tf.cast(start_mask, tf.float32)), tf.int32)

    # Segment ids: zeros for the generated part; the context part is 0 for
    # "left2right" and 1 for "seq2seq".
    # NOTE(review): any other mask_type leaves segment_ids shorter than
    # actual_length - confirm this cannot happen.
    segment_ids = tf.cast(
        tf.zeros((batch_size, actual_length - context_shape[1])), tf.int32)
    if kargs.get("mask_type", "left2right") == 'left2right':
        segment_ids = tf.concat([
            tf.cast(tf.zeros(
                (batch_size, context_shape[1])), tf.int32), segment_ids
        ],
                                axis=-1)
    elif kargs.get("mask_type", "left2right") == 'seq2seq':
        segment_ids = tf.concat([
            tf.cast(tf.ones(
                (batch_size, context_shape[1])), tf.int32), segment_ids
        ],
                                axis=-1)

    logits = tf.cast(tf.zeros((batch_size, actual_length)), tf.float32)

    # Input mask: ones over the context, zeros over positions still to fill.
    input_mask = tf.cast(
        tf.zeros((batch_size, actual_length - context_shape[1])), tf.int32)
    input_mask = tf.concat([
        tf.cast(tf.ones((batch_size, context_shape[1])), tf.int32), input_mask
    ],
                           axis=-1)

    if estimator in ["straight_through", "soft"]:
        # Per-position vocabulary distributions: one-hot over the context,
        # zeros elsewhere.
        gumbel_probs = tf.zeros((batch_size, actual_length - context_shape[1],
                                 model_config.vocab_size))
        start_probs = context
        start_one_hot = tf.one_hot(start_probs, model_config.vocab_size)
        gumbel_probs = tf.concat(
            [tf.cast(start_one_hot, tf.float32), gumbel_probs], axis=1)

    def step(step, tokens, input_mask, segment_ids):
        """Run the model on the current sequence and return its logits.

        The `step` argument and `token_shape` are not used.
        """
        token_shape = bert_utils.get_shape_list(tokens, expected_rank=[2, 3])

        features = {}
        features['input_ids'] = tokens
        features['segment_ids'] = segment_ids
        features['input_mask'] = input_mask

        inference_model = model_api(model_config,
                                    features, [],
                                    mode,
                                    target,
                                    reuse=tf.AUTO_REUSE,
                                    **kargs)

        logits = inference_model.get_sequence_output_logits()
        return {'logits': logits}

    with tf.name_scope('sample_sequence'):

        def get_samples_logits(samples, logits):
            """Gather, per batch row, the logit of that row's sampled id."""
            batch_idxs = tf.range(0, tf.shape(samples)[0])
            batch_idxs = tf.expand_dims(tf.cast(batch_idxs, tf.int32), 1)
            samples = tf.expand_dims(tf.cast(samples, tf.int32), 1)

            idxs = tf.concat([batch_idxs, samples], 1)
            sample_logits = tf.gather_nd(logits, idxs)
            return sample_logits

        def body(i, samples, input_mask, segment_ids, logits):
            """Plain step: sample (or argmax) a token for position `i`."""
            next_outputs = step(i, samples, input_mask, segment_ids)
            # Select the logits of position i - 1 with a one-hot mask.
            logits_mask = tf.expand_dims(tf.one_hot(i - 1, actual_length),
                                         axis=(0))  # [1, seq]
            next_logits = tf.reduce_sum(
                next_outputs['logits'] *
                tf.cast(tf.expand_dims(logits_mask, axis=-1), tf.float32),
                axis=1)
            next_logits = next_logits / tf.to_float(temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)

            if greedy_or_sample == "sample":
                next_samples = tf.multinomial(next_logits,
                                              num_samples=1,
                                              output_dtype=tf.int32)
                next_samples = tf.squeeze(next_samples, axis=-1)
            elif greedy_or_sample == "greedy":
                next_samples = tf.argmax(next_logits, axis=-1)
            else:
                next_samples = tf.argmax(next_logits, axis=-1)
            next_samples = tf.cast(next_samples, tf.int32)
            print(next_samples.get_shape(), "==sample shape==")

            print(tf.one_hot(i, actual_length).get_shape(), "====shhhhape===")
            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            # Values are ADDED at position i, so that slot is expected to
            # be zero beforehand.
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            print(next_sample_logits.get_shape(),
                  "===next sampleslogis shape==")
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            input_mask += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(tf.ones_like(next_samples), axis=-1), tf.int32)

            return [i + 1, samples, input_mask, segment_ids, logits]

        def gumbel_st_body(i, samples, gumbel_probs, input_mask, segment_ids,
                           logits):
            """Straight-through step: the model input is `gumbel_probs`."""
            next_outputs = step(i, gumbel_probs, input_mask, segment_ids)
            logits_mask = tf.expand_dims(tf.one_hot(i - 1, actual_length),
                                         axis=(0))  # [1, seq]
            next_logits = tf.reduce_sum(
                next_outputs['logits'] *
                tf.cast(tf.expand_dims(logits_mask, axis=-1), tf.float32),
                axis=1)
            next_logits = next_logits / tf.to_float(temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)

            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)
            next_samples_onehot = tf.one_hot(
                next_samples, model_config.vocab_size,
                axis=1)  # one-hot of the sampled id
            # Forward value is the hard one-hot; the gradient flows through
            # the soft probabilities.
            straight_through_onehot = tf.stop_gradient(
                next_samples_onehot - next_gumbel_probs) + next_gumbel_probs

            print(next_gumbel_probs.get_shape(), "=====gumbel====",
                  straight_through_onehot.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length),
                                                        axis=0),
                                         axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                straight_through_onehot, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)
            input_mask += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(tf.ones_like(next_samples), axis=-1), tf.int32)

            return [
                i + 1, samples, gumbel_probs, input_mask, segment_ids, logits
            ]

        def gumbel_soft_body(i, samples, gumbel_probs, input_mask,
                             segment_ids, logits):
            """Soft step: accumulates the soft Gumbel probabilities.

            NOTE(review): unlike `gumbel_st_body`, this feeds `samples`
            (not `gumbel_probs`) to the model, skips the log_softmax and
            never updates `input_mask` - confirm this is intended.
            """
            next_outputs = step(i, samples, input_mask, segment_ids)
            logits_mask = tf.expand_dims(tf.one_hot(i - 1, actual_length),
                                         axis=(0))  # [1, seq]
            next_logits = tf.reduce_sum(
                next_outputs['logits'] *
                tf.cast(tf.expand_dims(logits_mask, axis=-1), tf.float32),
                axis=1)
            next_logits = next_logits / tf.to_float(temperature)

            # gumbel sample
            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)

            print(next_gumbel_probs.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length),
                                                        axis=0),
                                         axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                next_gumbel_probs, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)
            return [
                i + 1, samples, gumbel_probs, input_mask, segment_ids, logits
            ]

        # NOTE(review): init_i (the context length) is computed but unused.
        # The loops below always start at position 1, so a context longer
        # than one token would have later positions added to again - confirm
        # callers only pass single-token contexts.
        init_i = tf.cast(
            bert_utils.get_shape_list(context, expected_rank=[2, 3])[1],
            tf.int32)

        # Each branch used to be a tf.while_loop starting at init_i; it is
        # now a Python loop that unrolls max_seq_length - 2 model copies
        # into the graph. `final` is never read.
        if estimator == "straight_through":
            for i in range(1, max_seq_length - 1):
                [
                    final, samples, gumbel_probs, input_mask, segment_ids,
                    logits
                ] = gumbel_st_body(i, samples, gumbel_probs, input_mask,
                                   segment_ids, logits)
        elif estimator == "soft":
            for i in range(1, max_seq_length - 1):
                [
                    final, samples, gumbel_probs, input_mask, segment_ids,
                    logits
                ] = gumbel_soft_body(i, samples, gumbel_probs, input_mask,
                                     segment_ids, logits)
        else:
            for i in range(1, max_seq_length - 1):
                [final, samples, input_mask, segment_ids,
                 logits] = body(i, samples, input_mask, segment_ids, logits)

        # presumably masks out positions after the first end_token; verify
        # against get_finised_pos_v1
        mask_sequence = get_finised_pos_v1(samples, end_token, actual_length)
        print(mask_sequence.get_shape(), "==mask shape==")
        samples *= tf.cast(mask_sequence, tf.int32)
        logits *= tf.cast(mask_sequence, tf.float32)
        if estimator in ["straight_through", "soft"]:
            gumbel_probs *= tf.expand_dims(tf.cast(mask_sequence, tf.float32),
                                           axis=-1)
            return {
                "samples": samples,
                "mask_sequence": mask_sequence,
                "gumbel_probs": gumbel_probs,
                "logits": logits,
                "input_mask": input_mask
            }
        else:
            return {
                "samples": samples,
                "mask_sequence": mask_sequence,
                "logits": logits,
                "input_mask": input_mask
            }
def model_fn(features, labels, mode, params):
    """Build the generator graph: masked-LM loss plus sampled replacement ids.

    This function is a closure: ``model_zoo``, ``model_config``, ``target``,
    ``model_io_config``, ``not_storage_params``, ``load_pretrained``,
    ``init_checkpoint`` and ``kargs`` are read from the enclosing scope.

    Args:
        features: dict of input tensors. Reads ``next_sentence_labels``,
            ``masked_lm_positions``, ``masked_lm_ids``, ``masked_lm_weights``,
            ``input_ids``, ``input_ori_ids``, ``input_mask`` and
            ``segment_ids``.
        labels: forwarded unchanged to the model constructor.
        mode: a ``tf.estimator.ModeKeys`` value, forwarded to the model.
        params: unused; kept for the estimator ``model_fn`` signature.

    Returns:
        dict with the masked-LM ``loss``, the generator ``tvars``, the
        ``model`` object, the sampled ids and the (optionally repeated)
        inputs, and the masked-LM tensors.
    """

    def _repeat_per_sample(tensor, num_samples):
        """Repeat a [batch, seq] id tensor along a new trailing axis.

        ``tf.ones`` is float32, so the ids are cast to float32 before the
        einsum (mixed int/float operands are rejected) and cast back to
        int32 afterwards.
        """
        expanded = tf.expand_dims(tf.cast(tensor, tf.float32), axis=-1)
        repeated = tf.einsum('abc,cd->abd', expanded,
                             tf.ones((1, num_samples)))
        return tf.cast(repeated, tf.int32)

    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels, mode, target,
                      reuse=tf.AUTO_REUSE)

    # The next-sentence head is still built (it creates its variables under
    # the 'generator' scope) but its loss is not added to `loss` below.
    (nsp_loss, nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
         model_config,
         model.get_pooled_output(),
         features['next_sentence_labels'],
         reuse=tf.AUTO_REUSE,
         scope='generator')

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    # Every model type other than 'albert' uses the BERT masked-LM head.
    if model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        print("==apply bert masked lm==")

    (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
     masked_lm_mask) = masked_lm_fn(
         model_config,
         model.get_sequence_output(),
         model.get_embedding_table(),
         masked_lm_positions,
         masked_lm_ids,
         masked_lm_weights,
         reuse=tf.AUTO_REUSE,
         embedding_projection=model.get_embedding_projection_table(),
         scope='generator')
    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss

    sampled_ids = token_generator(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        features['input_ids'],
        features['input_ori_ids'],
        features['input_mask'],
        embedding_projection=model.get_embedding_projection_table(),
        scope='generator',
        mask_method='only_mask')

    num_gen_samples = model_config.get('gen_sample', 1)
    if num_gen_samples == 1:
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
    else:
        # Each tensor becomes [batch, seq_length, gen_sample] and is then
        # flattened to [batch * gen_sample, seq_length].
        input_ids = _repeat_per_sample(features['input_ori_ids'],
                                       num_gen_samples)
        input_mask = _repeat_per_sample(features['input_mask'],
                                        num_gen_samples)
        # Fixed: the key was misspelled 'segmnet_ids'.
        segment_ids = _repeat_per_sample(features['segment_ids'],
                                         num_gen_samples)

        input_shape_list = bert_utils.get_shape_list(input_ids,
                                                     expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        gen_sample = input_shape_list[2]

        # Fixed: the original referenced an undefined name `batch` here.
        # NOTE(review): tf.reshape does not transpose; confirm that the
        # layout of `sampled_ids` matches this flattening order.
        flat_shape = [batch_size * gen_sample, seq_length]
        sampled_ids = tf.reshape(sampled_ids, flat_shape)
        input_ids = tf.reshape(input_ids, flat_shape)
        segment_ids = tf.reshape(segment_ids, flat_shape)
        input_mask = tf.reshape(input_mask, flat_shape)

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope, not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "generator/cls/predictions", not_storage_params=not_storage_params)
    pretrained_tvars.extend(lm_pretrain_tvars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(tvars,
                                                  init_checkpoint,
                                                  exclude_scope="generator",
                                                  use_tpu=use_tpu)
    else:
        scaffold_fn = None

    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "sampled_ids": sampled_ids,  # batch x gen_sample, seq_length
        "sampled_input_ids": input_ids,  # batch x gen_sample, seq_length
        "sampled_input_mask": input_mask,
        "sampled_segment_ids": segment_ids,
        "masked_lm_positions": masked_lm_positions,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": masked_lm_weights,
        "masked_lm_log_probs": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss
    }
    return return_dict
def seq_mask_masked_lm_output(config, input_tensor, output_weights, input_mask,
                              input_ori_ids, input_ids, sampled_binary_mask,
                              **kargs):
    """Masked-LM loss computed at every position, weighted by a binary mask.

    Args:
        config: model config; reads ``ln_type``, ``embedding``,
            ``hidden_size``, ``embedding_size``, ``hidden_act``,
            ``initializer_range`` and ``vocab_size``.
        input_tensor: rank-3 sequence output of the encoder.
        output_weights: output embedding table used to produce the logits.
        input_mask: unused; kept for interface compatibility.
        input_ori_ids: original token ids, used as the labels.
        input_ids: unused; kept for interface compatibility.
        sampled_binary_mask: 1 where the original token was perturbed, 0
            elsewhere. Only positions with 1 contribute to the loss.
        **kargs: optional ``embedding_projection`` (extra projection applied
            before the output embedding) and ``scope`` (variable-scope
            prefix).

    Returns:
        Tuple ``(loss, per_example_loss, logits, sampled_binary_mask)`` where
        the mask has been cast to float32.
    """
    # Kept for the rank check only (presumably raises on a rank mismatch).
    bert_utils.get_shape_list(input_tensor, expected_rank=3)

    embedding_projection = kargs.get('embedding_projection', None)

    scope = kargs.get('scope', None)
    if scope:
        scope = scope + '/' + 'cls/predictions'
    else:
        scope = 'cls/predictions'

    tf.logging.info("**** mlm generator scope **** %s", str(scope))

    # Any value other than 'preln' behaves like 'postln'.
    use_pre_layer_norm = config.get('ln_type', 'postln') == 'preln'

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        if use_pre_layer_norm:
            input_tensor = albert_modules.layer_norm(input_tensor)

        if config.get("embedding", "factorized") == "factorized":
            projection_width = config.hidden_size
        else:
            projection_width = config.embedding_size

        with tf.variable_scope("transform"):
            input_tensor = tf.layers.dense(
                input_tensor,
                units=projection_width,
                activation=albert_modules.get_activation(config.hidden_act),
                kernel_initializer=albert_modules.create_initializer(
                    config.initializer_range))
            if not use_pre_layer_norm:
                input_tensor = albert_modules.layer_norm(input_tensor)

        if embedding_projection is not None:
            # batch x seq x hidden, embedding x hidden
            print(input_tensor.get_shape(), embedding_projection.get_shape())
            input_tensor = tf.einsum("abc,dc->abd", input_tensor,
                                     embedding_projection)
        else:
            print("==no need for embedding projection==")

        output_bias = tf.get_variable("output_bias",
                                      shape=[config.vocab_size],
                                      initializer=tf.zeros_initializer())
        # batch x seq x embedding
        logits = tf.einsum("abc,dc->abd", input_tensor, output_weights)
        logits = tf.nn.bias_add(logits, output_bias)

        # sampled_binary_mask[i] == 1 iff input_ori_ids[i] was randomly
        # perturbed.
        sampled_binary_mask = tf.cast(sampled_binary_mask, tf.float32)
        per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=logits,
            labels=tf.stop_gradient(input_ori_ids),
        )
        per_example_loss *= sampled_binary_mask
        # The epsilon keeps the loss finite (0 rather than NaN) when no
        # position is selected; it is below float32 resolution for any
        # non-empty mask, so those results are unchanged.
        loss = tf.reduce_sum(per_example_loss) / (
            tf.reduce_sum(sampled_binary_mask) + 1e-10)

        return (loss, per_example_loss, logits, sampled_binary_mask)
def iso_gaussian_sample(logits, temperature, samples=1):
    """Perturb rank-2 logits with noise and return tempered probabilities.

    Args:
        logits: rank-2 logits tensor.
        temperature: divisor applied to the noisy logits before the softmax.
        samples: number of noise samples; when greater than 1 the logits get
            a trailing singleton axis before the noise is added.

    Returns:
        ``[probs, logits]`` where ``probs`` is the softmax of the noisy,
        temperature-scaled logits and ``logits`` is the (possibly expanded)
        input.
    """
    # The shape is taken before any expansion so it stays rank 2.
    base_shape = bert_utils.get_shape_list(logits, expected_rank=2)
    needs_sample_axis = samples > 1
    if needs_sample_axis:
        logits = tf.expand_dims(logits, -1)

    # NOTE(review): presumably sample_normal returns noise broadcastable
    # against `logits`; confirm its output shape.
    noise = sample_normal(base_shape, samples)
    noisy_logits = logits + noise
    log_probs = tf.nn.log_softmax(noisy_logits / temperature)
    probs = tf.exp(log_probs)
    return [probs, logits]
def distributed_transformer_model(input_tensor,
                                  attention_mask=None,
                                  hidden_size=768,
                                  num_hidden_layers=12,
                                  num_attention_heads=12,
                                  intermediate_size=3072,
                                  intermediate_act_fn=gelu,
                                  hidden_dropout_prob=0.1,
                                  attention_probs_dropout_prob=0.1,
                                  initializer_range=0.02,
                                  do_return_all_layers=False,
                                  gpu_nums=2):
    """Multi-headed, multi-layer Transformer with layers spread over GPUs.

    This is almost an exact implementation of the original Transformer
    encoder ("Attention is All You Need", https://arxiv.org/abs/1706.03762),
    with consecutive layers placed on successive ``/gpu:N`` devices.

    Also see:
    https://github.com/tensorflow/tensor2tensor/blob/master/tensor2tensor/models/transformer.py

    Args:
        input_tensor: float Tensor of shape [batch_size, seq_length,
            hidden_size].
        attention_mask: (optional) int32 Tensor of shape [batch_size,
            seq_length, seq_length], with 1 for positions that can be
            attended to and 0 in positions that should not be.
        hidden_size: int. Hidden size of the Transformer.
        num_hidden_layers: int. Number of layers (blocks) in the Transformer.
        num_attention_heads: int. Number of attention heads.
        intermediate_size: int. The size of the "intermediate" (a.k.a., feed
            forward) layer.
        intermediate_act_fn: function. The non-linear activation function to
            apply to the output of the intermediate/feed-forward layer.
        hidden_dropout_prob: float. Dropout probability for the hidden layers.
        attention_probs_dropout_prob: float. Dropout probability of the
            attention probabilities.
        initializer_range: float. Range of the initializer (stddev of
            truncated normal).
        do_return_all_layers: Whether to return all layers or just the final
            layer.
        gpu_nums: int. Number of GPU devices to spread the layers over.
            Layers are assigned in contiguous groups of
            ``num_hidden_layers // gpu_nums`` (at least 1); any remainder
            layers go to the last device.

    Returns:
        float Tensor of shape [batch_size, seq_length, hidden_size], the final
        hidden layer of the Transformer, or a list of such tensors (one per
        layer) when ``do_return_all_layers`` is true.

    Raises:
        ValueError: A Tensor shape or parameter is invalid.
    """
    if hidden_size % num_attention_heads != 0:
        raise ValueError(
            "The hidden size (%d) is not a multiple of the number of attention "
            "heads (%d)" % (hidden_size, num_attention_heads))
    if gpu_nums < 1:
        raise ValueError("gpu_nums (%d) must be at least 1" % gpu_nums)

    attention_head_size = hidden_size // num_attention_heads
    input_shape = bert_utils.get_shape_list(input_tensor, expected_rank=3)
    batch_size = input_shape[0]
    seq_length = input_shape[1]
    input_width = input_shape[2]

    # The Transformer performs sum residuals on all layers so the input needs
    # to be the same as the hidden size.
    if input_width != hidden_size:
        raise ValueError(
            "The width of the input tensor (%d) != hidden size (%d)" %
            (input_width, hidden_size))

    # We keep the representation as a 2D tensor to avoid re-shaping it back and
    # forth from a 3D tensor to a 2D tensor. Re-shapes are normally free on
    # the GPU/CPU but may not be free on the TPU, so we want to minimize them to
    # help the optimizer.
    prev_output = bert_utils.reshape_to_matrix(input_tensor)

    all_layer_outputs = []

    # At least one layer per device, so the integer division below is safe
    # even when there are more GPUs than layers.
    layers_per_gpu = max(1, num_hidden_layers // gpu_nums)

    for layer_idx in range(num_hidden_layers):
        # Clamp so that remainder layers (when num_hidden_layers is not a
        # multiple of gpu_nums) land on the last device instead of on a
        # device index that does not exist.
        gpu_id = min(layer_idx // layers_per_gpu, gpu_nums - 1)

        with tf.variable_scope("layer_%d" % layer_idx):
            layer_input = prev_output

            with tf.device('/gpu:{}'.format(gpu_id)):
                placement_msg = (
                    " apply transformer attention {}-th layer on device {} ".
                    format(layer_idx, gpu_id))
                tf.logging.info(placement_msg)
                print(placement_msg)

                with tf.variable_scope("attention"):
                    attention_heads = []
                    with tf.variable_scope("self"):
                        attention_head = attention_layer(
                            from_tensor=layer_input,
                            to_tensor=layer_input,
                            attention_mask=attention_mask,
                            num_attention_heads=num_attention_heads,
                            size_per_head=attention_head_size,
                            attention_probs_dropout_prob=
                            attention_probs_dropout_prob,
                            initializer_range=initializer_range,
                            do_return_2d_tensor=True,
                            batch_size=batch_size,
                            from_seq_length=seq_length,
                            to_seq_length=seq_length)
                        attention_heads.append(attention_head)

                    attention_output = None
                    if len(attention_heads) == 1:
                        attention_output = attention_heads[0]
                    else:
                        # In the case where we have other sequences, we just
                        # concatenate them to the self-attention head before
                        # the projection.
                        attention_output = tf.concat(attention_heads, axis=-1)

                    # Run a linear projection of `hidden_size` then add a
                    # residual with `layer_input`.
                    with tf.variable_scope("output"):
                        attention_output = tf.layers.dense(
                            attention_output,
                            hidden_size,
                            kernel_initializer=create_initializer(
                                initializer_range))
                        attention_output = dropout(attention_output,
                                                   hidden_dropout_prob)
                        attention_output = layer_norm(attention_output +
                                                      layer_input)

                # The activation is only applied to the "intermediate" hidden
                # layer.
                with tf.variable_scope("intermediate"):
                    intermediate_output = tf.layers.dense(
                        attention_output,
                        intermediate_size,
                        activation=intermediate_act_fn,
                        kernel_initializer=create_initializer(
                            initializer_range))

                # Down-project back to `hidden_size` then add the residual.
                with tf.variable_scope("output"):
                    layer_output = tf.layers.dense(
                        intermediate_output,
                        hidden_size,
                        kernel_initializer=create_initializer(
                            initializer_range))
                    layer_output = dropout(layer_output, hidden_dropout_prob)
                    layer_output = layer_norm(layer_output + attention_output)
                    prev_output = layer_output
                    all_layer_outputs.append(layer_output)

    if do_return_all_layers:
        return [
            bert_utils.reshape_from_matrix(layer_output, input_shape)
            for layer_output in all_layer_outputs
        ]

    return bert_utils.reshape_from_matrix(prev_output, input_shape)
def hmm_input_ids_generation(config, input_ori_ids, input_mask,
                             hmm_tran_prob_list, **kargs):
    """Corrupt token ids at positions selected by ``mask_method``.

    Positions selected by ``mask_method`` (excluding padding and the ids 100,
    101 and 102) are split three ways: with ``replace_probability`` a position
    receives a random vocabulary id, with ``original_probability`` one of the
    remaining positions keeps its original id, and every other selected
    position receives ``mask_id``.

    Args:
        config: model config; reads ``vocab_size`` and ``gen_sample``.
        input_ori_ids: original token ids, rank 2.
        input_mask: rank-2 mask, multiplied into the selection.
        hmm_tran_prob_list: forwarded to ``mask_method``.
        **kargs: optional ``mask_id`` (103), ``mask_probability`` (0.2, only
            logged here), ``replace_probability`` (0.1),
            ``original_probability`` (0.1) and ``use_tpu`` (True); also
            forwarded to ``mask_method``.

    Returns:
        ``[corrupted_ids, selected_mask]``, both int32.
    """
    mask_id = kargs.get('mask_id', 103)
    mask_probability = kargs.get("mask_probability", 0.2)
    replace_probability = kargs.get("replace_probability", 0.1)
    original_probability = kargs.get("original_probability", 0.1)

    input_ori_ids = tf.cast(input_ori_ids, tf.int32)
    input_mask = tf.cast(input_mask, tf.int32)

    # Ids 100 / 101 / 102 (unk / cls / sep per the original comments) are
    # never corrupted.
    is_unk = tf.cast(tf.math.equal(input_ori_ids, 100), tf.float32)
    is_cls = tf.cast(tf.math.equal(input_ori_ids, 101), tf.float32)
    is_sep = tf.cast(tf.math.equal(input_ori_ids, 102), tf.float32)
    protected = is_unk + is_cls + is_sep

    mask_shape = bert_utils.get_shape_list(input_mask, expected_rank=2)
    batch_size = mask_shape[0]
    seq_length = mask_shape[1]

    tf.logging.info("**** apply fixed_mask_prob %s **** ",
                    str(mask_probability))
    tf.logging.info("**** apply replace_probability %s **** ",
                    str(replace_probability))
    tf.logging.info("**** apply original_probability %s **** ",
                    str(original_probability))

    # Candidate positions, restricted to real and unprotected tokens.
    selected = mask_method(batch_size, seq_length, hmm_tran_prob_list,
                           **kargs)
    selected = input_mask * (1 - tf.cast(protected, tf.int32)) * selected
    selected = tf.cast(selected, tf.float32)

    # First draw: which selected positions get a random vocabulary id.
    replace_probs = replace_probability * (selected)
    replace_dist = tf.distributions.Bernoulli(probs=replace_probs,
                                              dtype=tf.float32)
    replace_mask = tf.cast(replace_dist.sample(), tf.float32)

    # Second draw, over what is left: which positions keep the original id.
    keep_probs = original_probability * (selected - replace_mask)
    keep_dist = tf.distributions.Bernoulli(probs=keep_probs,
                                           dtype=tf.float32)
    keep_mask = tf.cast(keep_dist.sample(), tf.float32)

    # Everything selected but neither replaced nor kept receives mask_id.
    mask_token_mask = tf.cast((selected - replace_mask - keep_mask),
                              tf.float32)

    # Random replacement ids drawn from randomly generated logits.
    random_logits = tf.random.uniform(
        [batch_size, seq_length, config.vocab_size],
        minval=0.0,
        maxval=10.0,
        dtype=tf.float32)
    random_logits = tf.nn.log_softmax(random_logits)
    flat_random_logits = tf.reshape(random_logits,
                                    [batch_size * seq_length, -1])

    random_ids = tf.multinomial(flat_random_logits,
                                num_samples=config.get('gen_sample', 1),
                                output_dtype=tf.int32)
    random_ids = tf.reshape(random_ids, [batch_size, seq_length])
    random_ids = tf.cast(random_ids, tf.float32)

    # The ids are combined in float32 and cast back to int32 on return.
    input_ori_ids = tf.cast(input_ori_ids, tf.float32)
    output_input_ids = mask_id * tf.cast(
        mask_token_mask, tf.float32) * tf.ones_like(input_ori_ids)
    output_input_ids += random_ids * tf.cast(replace_mask, tf.float32)
    output_input_ids += (1 - tf.cast(mask_token_mask + replace_mask,
                                     tf.float32)) * input_ori_ids

    output_sampled_binary_mask = mask_token_mask + replace_mask + keep_mask

    print("===output_input_ids shape===", output_input_ids.get_shape())
    output_shape = bert_utils.get_shape_list(output_input_ids,
                                             expected_rank=2)
    print("==input shape list==", output_shape)

    output_sampled_binary_mask = tf.cast(output_sampled_binary_mask, tf.int32)
    if not kargs.get('use_tpu', True):
        selected_count = tf.reduce_sum(
            tf.cast(output_sampled_binary_mask, tf.float32))
        token_count = tf.cast(tf.reduce_sum(input_mask), dtype=tf.float32)
        tf.summary.scalar('mask_ratio', selected_count / (1e-10 + token_count))

    return [tf.cast(output_input_ids, tf.int32), output_sampled_binary_mask]
def model_fn(features, labels, mode):
    """Build a two-encoder classifier with a VIB bottleneck.

    This function is a closure: ``input_name``, ``model_reuse``,
    ``base_model``, ``model_config``, ``model_io_config``, ``num_labels``,
    ``label_tensor``, ``vib_config``, ``model_io_fn``, ``load_pretrained``,
    ``init_checkpoint`` and ``opt_config`` are read from the enclosing scope.

    The pooled outputs of the first two encoders are summed, so
    ``input_name`` must provide at least two entries.

    Args:
        features: dict of input tensors; reads ``label_ids`` here and is also
            forwarded to every encoder.
        labels: forwarded unchanged to the encoders.
        mode: a ``tf.estimator.ModeKeys`` value.

    Returns:
        ``[train_op, loss, per_example_loss, logits]`` in TRAIN mode, and
        ``[loss, loss, per_example_loss, logits]`` otherwise.
    """
    label_ids = features["label_ids"]

    # Only the first encoder may create variables; the rest always reuse.
    model_lst = []
    for index, name in enumerate(input_name):
        reuse = True if index > 0 else model_reuse
        model_lst.append(
            base_model(model_config, features, labels, mode, name,
                       reuse=reuse))

    is_training = mode == tf.estimator.ModeKeys.TRAIN
    dropout_prob = model_config.dropout_prob if is_training else 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        # The label-ratio weighting is best effort: on any failure the
        # classifier runs unweighted. The failure is now logged rather than
        # silently dropped, and KeyboardInterrupt/SystemExit are no longer
        # swallowed.
        # NOTE(review): get_variable is given both a Tensor initializer and
        # an explicit shape; confirm the TF version in use accepts that,
        # otherwise this always takes the fallback.
        try:
            label_ratio_table = tf.get_variable(
                name="label_ratio",
                shape=[
                    num_labels,
                ],
                initializer=tf.constant(label_tensor),
                trainable=False)
            ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                  label_ids)
        except Exception as err:
            tf.logging.warning(
                "label_ratio table unavailable, using unweighted loss: %s",
                err)
            ratio_weight = None

        seq_output_lst = [model.get_pooled_output() for model in model_lst]
        repres = seq_output_lst[0] + seq_output_lst[1]

        final_hidden_shape = bert_utils.get_shape_list(repres,
                                                       expected_rank=2)

        z_mean = tf.layers.dense(repres, final_hidden_shape[1], name="z_mean")
        z_log_var = tf.layers.dense(repres,
                                    final_hidden_shape[1],
                                    name="z_log_var")
        print("=======applying vib============")
        if is_training:
            # Training classifies the vector returned by the VIB regularizer
            # and adds its KL term to the loss.
            print("====applying vib====")
            vib_connector = vib.VIB(vib_config)
            [kl_loss, latent_vector
             ] = vib_connector.build_regularizer([z_mean, z_log_var])

            [loss, per_example_loss,
             logits] = classifier.classifier(model_config, latent_vector,
                                             num_labels, label_ids,
                                             dropout_prob, ratio_weight)
            loss += tf.reduce_mean(kl_loss)
        else:
            # Outside training z_mean is classified directly.
            print("====applying z_mean for prediction====")
            [loss, per_example_loss,
             logits] = classifier.classifier(model_config, z_mean, num_labels,
                                             label_ids, dropout_prob,
                                             ratio_weight)

    pretrained_tvars = model_io_fn.get_params(model_config.scope)
    if load_pretrained:
        model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

    tvars = model_io_fn.get_params(scope)
    model_io_fn.print_params(tvars, string=", trainable params")

    if not is_training:
        return [loss, loss, per_example_loss, logits]

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        optimizer_fn = optimizer.Optimizer(opt_config)
        train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr,
                                             opt_config.num_train_steps)

    return [train_op, loss, per_example_loss, logits]
def sample_sequence(model_api,
                    model_config,
                    mode,
                    features,
                    target="",
                    start_token=101,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    n_samples=1,
                    top_k=0,
                    end_token=102,
                    greedy_or_sample="sample",
                    gumbel_temp=0.01,
                    estimator="straight_through",
                    back_prop=True,
                    swap_memory=True,
                    **kargs):
    """Decode a sequence token by token inside a ``tf.while_loop``.

    A ``[batch_size, seq_length]`` id buffer is pre-filled with the context
    at the front and ``end_token`` in the last slot; the loop then writes one
    sampled id per step while carrying the per-layer cache (``presents``)
    returned by the model.

    Args:
        model_api: callable that builds the model; called as
            ``model_api(model_config, features, [], mode, target, ...)``.
        model_config: config; reads ``hidden_size``, ``num_attention_heads``,
            ``num_hidden_layers`` and ``vocab_size``.
        mode: forwarded to ``model_api``.
        features: dict; only the shape of ``features["input_ids"]`` is read,
            to obtain the batch size and sequence length.
        target: forwarded to ``model_api``.
        start_token: id used to build a one-token context. Exactly one of
            ``start_token`` and ``context`` must be given.
        batch_size: overwritten from ``features["input_ids"]``; the passed
            value is not used.
        context: explicit context ids (requires ``start_token=None``).
        temperature: divisor applied to the logits before the log-softmax.
        n_samples: not referenced in this function.
        top_k: not referenced in this function.
        end_token: id written to the last position and passed to
            ``get_finised_pos``.
        greedy_or_sample: ``"sample"`` draws with ``tf.multinomial``; any
            other value takes the argmax. Only used by the plain loop body.
        gumbel_temp: temperature passed to ``gumbel_softmax``.
        estimator: ``"straight_through"`` or ``"soft"`` select the Gumbel
            loop bodies; any other value selects the plain body.
        back_prop: forwarded to ``tf.while_loop``.
        swap_memory: forwarded to ``tf.while_loop``.
        **kargs: forwarded to ``model_api``; ``mask_type`` is also read here.

    Returns:
        ``(samples, gumbel_probs, presents, logits, final)`` for the Gumbel
        estimators, otherwise
        ``(samples, mask_sequence, presents, logits, final)``; ``final`` is
        the loop counter after the last iteration.

    NOTE(review): the source was flattened, so the nesting below is
    reconstructed from syntax and data flow; confirm against the original
    file. ``left_segment_ids`` is only bound when ``mask_type`` is
    ``'left2right'`` or ``'seq2seq'``.
    """
    input_shape = bert_utils.get_shape_list(features["input_ids"],
                                            expected_rank=[2, 3])
    batch_size = input_shape[0]
    seq_length = input_shape[1]

    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
    print(context.get_shape(), "===init context shape===")
    context_shape = bert_utils.get_shape_list(context, expected_rank=[2])
    actual_length = seq_length

    # Scalar dimensions referenced here:
    #   B = batch size (number of sequences)
    #   F = `from_tensor` sequence length
    #   T = `to_tensor` sequence length
    #   N = `num_attention_heads`
    #   H = `size_per_head`

    attention_head_size = int(model_config.hidden_size /
                              model_config.num_attention_heads)

    # single layer present: [B, 2, N, T, H]
    # all layer present: [B, N_layer, 2, N, T, H]
    presents = tf.zeros(
        (batch_size, model_config.num_hidden_layers, 2,
         model_config.num_attention_heads, actual_length,
         attention_head_size))

    # Id buffer: zeros everywhere, end_token in the last slot, context ids
    # scattered into the leading positions through one-hot masks.
    samples = tf.cast(tf.zeros((batch_size, actual_length)), tf.int32)
    end_mask = tf.expand_dims(tf.one_hot(actual_length - 1, actual_length),
                              axis=(0))
    samples += end_token * tf.cast(
        end_mask, tf.int32)  # make sure last token is end token

    start_mask = tf.one_hot(tf.range(0, context_shape[1]), actual_length)
    samples += tf.einsum("ab,bc->ac", context, tf.cast(start_mask, tf.int32))

    # Per-position log-probability of the chosen id, filled in by the loop.
    logits = tf.cast(tf.zeros((batch_size, actual_length)), tf.float32)

    if estimator in ["straight_through", "soft"]:
        # Probability buffer over the vocabulary: one-hot rows for the
        # context followed by zeros for the positions still to be decoded.
        gumbel_probs = tf.zeros((batch_size, actual_length - context_shape[1],
                                 model_config.vocab_size))
        start_probs = context
        start_one_hot = tf.one_hot(start_probs, model_config.vocab_size)
        gumbel_probs = tf.concat(
            [tf.cast(start_one_hot, tf.float32), gumbel_probs], axis=1)

    def step(step, tokens, segment_ids=None, past=None):
        # Runs the model on `tokens` with the first `step` cache entries as
        # `past`, and returns the logits together with the updated cache.
        token_shape = bert_utils.get_shape_list(tokens, expected_rank=[2, 3])

        features = {}
        features['input_ids'] = tokens
        if segment_ids is None:
            features['segment_ids'] = tf.cast(
                tf.zeros((token_shape[0], token_shape[1])), tf.int32)
        else:
            features['segment_ids'] = segment_ids
        if past is None:
            features['input_mask'] = tf.cast(
                tf.ones((token_shape[0], token_shape[1])), tf.int32)
            features['past'] = None
        else:
            past_shape = bert_utils.get_shape_list(past, expected_rank=[6])
            # The mask covers the cached prefix plus the new tokens.
            features['input_mask'] = tf.cast(
                tf.ones((past_shape[0], step + token_shape[1])), tf.int32)
            features['past'] = past[:, :, :, :, :(step), :]

        inference_model = model_api(model_config,
                                    features, [],
                                    mode,
                                    target,
                                    reuse=tf.AUTO_REUSE,
                                    **kargs)

        logits = inference_model.get_sequence_output_logits()
        next_presents = inference_model.get_present()

        next_presents_shape = bert_utils.get_shape_list(next_presents,
                                                        expected_rank=[6])
        print(presents.get_shape())
        # NOTE(review): the extent of this `if` is reconstructed; it is
        # assumed to guard the cache update so that an empty token slice
        # leaves `past` untouched. Confirm against the original.
        if next_presents_shape[-2] > 0:
            print(next_presents_shape)
            print(next_presents.get_shape(), "===next presents shape===")
            # One-hot rows that place the new cache entries at positions
            # step .. step + token_shape[1] - 1 of the full-length cache.
            mask = tf.one_hot(tf.range(step, step + token_shape[1]),
                              actual_length)
            print(mask.get_shape(), "===mask shape===")
            past = tf.einsum("abcdef,eg->abcdgf", next_presents, mask) + past

        return {
            'logits': logits,
            'presents': past,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        print(context[:, :-1].get_shape())
        init_context_shape = bert_utils.get_shape_list(context[:, :-1],
                                                       expected_rank=[2, 3])
        init_segment_ids = tf.cast(
            tf.zeros((init_context_shape[0], init_context_shape[1])),
            tf.int32)
        context_output = step(0,
                              context[:, :-1],
                              segment_ids=init_segment_ids,
                              past=presents)

        def get_samples_logits(samples, logits):
            # Gathers logits[b, samples[b]] for every batch row b.
            batch_idxs = tf.range(0, tf.shape(samples)[0])
            batch_idxs = tf.expand_dims(batch_idxs, 1)
            samples = tf.expand_dims(samples, 1)

            idxs = tf.concat([batch_idxs, samples], 1)
            sample_logits = tf.gather_nd(logits, idxs)
            return sample_logits

        def body(i, past, prev, samples, segment_ids, logits):
            # Plain loop body: feeds the previous id, picks the next id by
            # sampling or argmax, and writes it to position i.
            print(prev.get_shape(), "==prev shape==")
            next_outputs = step(i - 1,
                                prev[:, tf.newaxis],
                                segment_ids=segment_ids,
                                past=past)
            next_logits = next_outputs['logits'][:, -1, :] / tf.to_float(
                temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)
            if greedy_or_sample == "sample":
                next_samples = tf.multinomial(next_logits,
                                              num_samples=1,
                                              output_dtype=tf.int32)
                next_samples = tf.squeeze(next_samples, axis=-1)
            elif greedy_or_sample == "greedy":
                next_samples = tf.argmax(next_logits, axis=-1)
            else:
                next_samples = tf.argmax(next_logits, axis=-1)
            print(next_samples.get_shape(), "==sample shape==")

            print(tf.one_hot(i, seq_length + 1).get_shape(), "====shhhhape===")
            # One-hot row selecting position i, so the additions below only
            # touch that column of the buffers.
            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, next_outputs['presents'], next_samples, samples,
                segment_ids, logits
            ]

        def gumbel_st_body(i, past, prev, samples, gumbel_probs, segment_ids,
                           logits):
            # Straight-through body: the model input is the previous row of
            # `gumbel_probs` rather than a hard token id (`prev` is unused).
            next_outputs = step(i - 1,
                                tf.expand_dims(gumbel_probs[:, i - 1, :],
                                               axis=1),
                                segment_ids=segment_ids,
                                past=past)

            next_logits = next_outputs['logits'][:, -1, :] / tf.to_float(
                temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)

            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)
            next_samples_onehot = tf.one_hot(next_samples,
                                             model_config.vocab_size,
                                             axis=1)  # sampled multinomial id
            # Forward value is the one-hot sample; the gradient flows through
            # next_gumbel_probs only.
            straight_through_onehot = tf.stop_gradient(
                next_samples_onehot - next_gumbel_probs) + next_gumbel_probs

            print(next_gumbel_probs.get_shape(), "=====gumbel====",
                  straight_through_onehot.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length), axis=0), axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                straight_through_onehot, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, next_outputs['presents'], next_samples, samples,
                gumbel_probs, segment_ids, logits
            ]

        def gumbel_soft_body(i, past, prev, samples, gumbel_probs,
                             segment_ids, logits):
            # Soft body: the model input is the previous hard id, while the
            # un-discretized Gumbel-softmax output is stored in
            # `gumbel_probs`.
            next_outputs = step(i - 1,
                                prev[:, tf.newaxis],
                                segment_ids=segment_ids,
                                past=past)

            next_logits = next_outputs['logits'][:, -1, :] / tf.to_float(
                temperature)
            next_logits = tf.nn.log_softmax(next_logits, axis=-1)

            # gumbel sample
            next_gumbel_probs, _ = gumbel_softmax(next_logits,
                                                  gumbel_temp,
                                                  gumbel_samples=None,
                                                  samples=1)
            next_samples = tf.cast(tf.argmax(next_gumbel_probs, axis=1),
                                   tf.int32)
            # Not used further in this body.
            next_samples_onehot = tf.one_hot(next_samples,
                                             model_config.vocab_size,
                                             axis=1)  # sampled multinomial id

            print(next_gumbel_probs.get_shape())
            gumbel_mask = tf.expand_dims(tf.expand_dims(tf.one_hot(
                i, actual_length), axis=0), axis=2)  # [1, seq, 1]
            gumbel_probs += tf.cast(gumbel_mask, tf.float32) * tf.expand_dims(
                next_gumbel_probs, axis=1)  # b x 1 x vocab

            sample_mask = tf.expand_dims(tf.one_hot(i, actual_length),
                                         axis=(0))  # [1, seq]
            print(sample_mask.get_shape(), "==sample mask shape==")
            print(samples.get_shape(), "==samples shape==")
            samples += tf.cast(sample_mask, tf.int32) * tf.cast(
                tf.expand_dims(next_samples, axis=-1), tf.int32)

            next_sample_logits = get_samples_logits(next_samples, next_logits)
            logits += tf.cast(sample_mask, tf.float32) * tf.expand_dims(
                next_sample_logits, axis=-1)

            return [
                i + 1, next_outputs['presents'], next_samples, samples,
                gumbel_probs, segment_ids, logits
            ]

        # The loop starts at the first position after the context.
        init_i = bert_utils.get_shape_list(context[:, :-1],
                                           expected_rank=[2, 3])[1] + 1
        # Segment id fed at every decoding step: 0 for 'left2right',
        # 1 for 'seq2seq'.
        if kargs.get("mask_type", "left2right") == 'left2right':
            left_segment_ids = tf.expand_dims(tf.cast(
                tf.zeros_like(context[:, -1]), tf.int32), axis=-1)
        elif kargs.get("mask_type", "left2right") == 'seq2seq':
            left_segment_ids = tf.expand_dims(tf.cast(
                tf.ones_like(context[:, -1]), tf.int32), axis=-1)

        # The last position is never decoded; it already holds end_token.
        if estimator == "straight_through":
            final, presents, _, samples, gumbel_probs, _, logits = tf.while_loop(
                cond=lambda i, _1, _2, _3, _4, _5, _6: i < seq_length - 1,
                body=gumbel_st_body,
                loop_vars=[
                    init_i,
                    context_output['presents'],
                    context[:, -1],
                    samples,
                    gumbel_probs,
                    left_segment_ids,
                    logits
                ],
                back_prop=back_prop,
                swap_memory=swap_memory)
        elif estimator == "soft":
            final, presents, _, samples, gumbel_probs, _, logits = tf.while_loop(
                cond=lambda i, _1, _2, _3, _4, _5, _6: i < seq_length - 1,
                body=gumbel_soft_body,
                loop_vars=[
                    init_i,
                    context_output['presents'],
                    context[:, -1],
                    samples,
                    gumbel_probs,
                    left_segment_ids,
                    logits
                ],
                back_prop=back_prop,
                swap_memory=swap_memory)
        else:
            final, presents, _, samples, _, logits = tf.while_loop(
                cond=lambda i, _1, _2, _3, _4, _5: i < seq_length - 1,
                body=body,
                loop_vars=[
                    init_i,
                    context_output['presents'],
                    context[:, -1],
                    samples,
                    left_segment_ids,
                    logits
                ],
                back_prop=back_prop,
                swap_memory=swap_memory)

        mask_sequence = get_finised_pos(samples, end_token, actual_length)
        # Only the Gumbel probabilities are masked here; `samples` and
        # `logits` are returned unmasked.
        if estimator in ["straight_through", "soft"]:
            gumbel_probs *= tf.expand_dims(tf.cast(mask_sequence, tf.float32),
                                           axis=-1)
            return samples, gumbel_probs, presents, logits, final
        else:
            return samples, mask_sequence, presents, logits, final
def build_encoder(self,
                  input_ids,
                  input_mask,
                  hidden_dropout_prob,
                  attention_probs_dropout_prob,
                  past=None,
                  decode_loop_step=None,
                  max_decode_length=None,
                  if_bp=False,
                  if_cache_decode=None,
                  **kargs):
    """Build the attention masks and run the stacked transformer.

    Reads `self.embedding_output` (it must already be set on the instance)
    and stores the results on the instance:
    `self.bi_attention_mask`, `self.attention_mask`, `self.all_encoder_layers`,
    `self.all_present`, `self.all_attention_scores`, `self.all_value_outputs`.

    Args:
        input_ids: rank-2 or rank-3 tensor; a rank-3 input is reduced with
            `tf.argmax` over the last axis before the mask is built.
        input_mask: optional mask; defaults to all ones of shape
            [batch_size, seq_length]. Replaced by a step-derived mask when
            `decode_loop_step` is given.
        hidden_dropout_prob: forwarded to the transformer.
        attention_probs_dropout_prob: forwarded to the transformer.
        past: forwarded to the transformer.
        decode_loop_step: when not None, the input mask is rebuilt with
            `tf.sequence_mask(decode_loop_step + 1, maxlen=max_decode_length)`
            and no conditional mask sequence is used.
        max_decode_length: mask length used with `decode_loop_step`.
        if_bp: forwarded to the transformer.
        if_cache_decode: forwarded to the transformer.
        **kargs: must contain "reuse"; optionally "seq_type", "mask_type",
            "token_type_ids" and "attention_type".
    """
    reuse = kargs["reuse"]

    ids_shape = bert_utils.get_shape_list(input_ids, expected_rank=[2, 3])
    batch_size, seq_length = ids_shape[0], ids_shape[1]

    if input_mask is None:
        input_mask = tf.ones(shape=[batch_size, seq_length], dtype=tf.int32)

    with tf.variable_scope(self.config.get("scope", "bert"), reuse=reuse):
        with tf.variable_scope("encoder"):
            # The 2D input mask is handed to
            # `create_attention_mask_from_input_mask` to obtain the mask that
            # is applied to the attention scores.
            ids_shape = bert_utils.get_shape_list(input_ids,
                                                  expected_rank=[2, 3])
            tmp_input_ids = (tf.argmax(input_ids, axis=-1)
                             if len(ids_shape) == 3 else input_ids)

            if decode_loop_step is not None:
                if max_decode_length is None:
                    # NOTE(review): read from `self`, not `self.config` --
                    # confirm the instance really carries this attribute.
                    max_decode_length = self.max_position_embeddings
                step_mask = tf.sequence_mask(decode_loop_step + 1,
                                             maxlen=max_decode_length)
                # [max_decode_length, 1]
                step_mask = tf.expand_dims(step_mask, axis=-1)
                # [1, max_decode_length]
                step_mask = tf.transpose(step_mask, perm=[1, 0])
                input_mask = tf.tile(step_mask, [batch_size, 1])

            self.bi_attention_mask = (
                bert_seq_modules.create_attention_mask_from_input_mask(
                    tmp_input_ids, input_mask))

            seq_type = kargs.get('seq_type', "None")
            print(seq_type)

            if seq_type != "seq2seq":
                tf.logging.info(
                    "==apply bi-directional LM model with bi-directional mask=="
                )
                self.attention_mask = self.bi_attention_mask
            else:
                mask_type = kargs.get("mask_type", "left2right")
                mask_sequence = None
                if mask_type == "left2right":
                    tf.logging.info(
                        "==apply left2right LM model with casual mask==")
                elif mask_type == "seq2seq":
                    token_type_ids = kargs.get("token_type_ids", None)
                    tf.logging.info(
                        "==apply left2right LM model with conditional casual mask=="
                    )
                    if token_type_ids is None:
                        token_type_ids = tf.zeros_like(input_mask)
                        tf.logging.info(
                            "==conditional mask is set to 0 and degenerate to left2right LM model=="
                        )
                    mask_sequence = token_type_ids

                # With a loop step the conditional mask sequence is dropped.
                if decode_loop_step is not None:
                    mask_sequence = None
                self.attention_mask = bert_utils.generate_seq2seq_mask(
                    self.bi_attention_mask, mask_sequence, seq_type)

            # Run the stacked transformer; only 'rezero_transformer' selects
            # a different implementation, every other value uses the default.
            attention_type = kargs.get('attention_type', 'normal_attention')
            if attention_type == 'rezero_transformer':
                transformer_model = bert_seq_modules.transformer_rezero_model
                tf.logging.info("****** rezero_transformer *******")
            else:
                tf.logging.info("****** normal attention *******")
                transformer_model = bert_seq_modules.transformer_model

            (self.all_encoder_layers,
             self.all_present,
             self.all_attention_scores,
             self.all_value_outputs) = transformer_model(
                 input_tensor=self.embedding_output,
                 attention_mask=self.attention_mask,
                 hidden_size=self.config.hidden_size,
                 num_hidden_layers=self.config.num_hidden_layers,
                 num_attention_heads=self.config.num_attention_heads,
                 intermediate_size=self.config.intermediate_size,
                 intermediate_act_fn=bert_seq_modules.get_activation(
                     self.config.hidden_act),
                 hidden_dropout_prob=hidden_dropout_prob,
                 attention_probs_dropout_prob=attention_probs_dropout_prob,
                 initializer_range=self.config.initializer_range,
                 do_return_all_layers=True,
                 past=past,
                 decode_loop_step=decode_loop_step,
                 if_bp=if_bp,
                 if_cache_decode=if_cache_decode,
                 attention_fixed_size=self.config.get('attention_fixed_size',
                                                      None))
def multi_position_crf_classifier(config, features, model_dict, num_labels,
                                  dropout_prob):
    """Score gathered positions of model "a" against pooled outputs of "b"
    and compute a CRF negative log-likelihood loss over those scores.

    The previous version returned `loss`, `per_example_loss` and
    `transition_params` although the statements defining them were commented
    out, so every call raised NameError. The CRF computation is restored here.

    Args:
        config: object exposing `max_predictions_per_seq`.
        features: dict read for 'total_length_a', 'label_positions',
            'label_ids' and 'label_weights'.
        model_dict: dict with entries "a" (provides `get_sequence_output()`)
            and "b" (provides `get_pooled_output()`).
        num_labels: number of candidate fields per example; also the number
            of CRF tags.
        dropout_prob: unused; kept for interface compatibility.

    Returns:
        Tuple `(loss, per_example_loss, logits, transition_params)` where
        `loss` is the mean of `per_example_loss` and `transition_params` is
        wrapped in `tf.stop_gradient`.
    """
    total_length_a = features['total_length_a']

    sequence_output_a = model_dict["a"].get_sequence_output()
    shape_lst = bert_utils.get_shape_list(sequence_output_a, expected_rank=3)
    hidden_size = shape_lst[-1]

    # Regroup the per-segment outputs, then pick the labelled positions.
    sequence_output_a = tf.reshape(sequence_output_a,
                                   [-1, total_length_a, hidden_size])
    answer_pos = tf.cast(features['label_positions'], tf.int32)
    sequence_output_a = bert_utils.gather_indexes(sequence_output_a,
                                                  answer_pos)
    sequence_output_a = tf.reshape(
        sequence_output_a,
        [-1, config.max_predictions_per_seq, hidden_size])

    sequence_output_b = model_dict["b"].get_pooled_output()
    sequence_output_b = tf.reshape(sequence_output_b,
                                   [-1, num_labels, hidden_size])
    seq_b_shape = bert_utils.get_shape_list(sequence_output_b,
                                            expected_rank=3)

    cross_matrix = tf.get_variable(
        "output_weights", [hidden_size, hidden_size],
        initializer=tf.truncated_normal_initializer(stddev=0.02))

    # Bilinear score between each gathered position and each candidate:
    # [batch, positions, hidden] x [hidden, hidden] x [batch, labels, hidden]
    # -> [batch, positions, labels], scaled by 1/sqrt(hidden).
    sequence_output_a_proj = tf.einsum("abc,cd->abd", sequence_output_a,
                                       cross_matrix)
    logits = tf.einsum("abd,acd->abc", sequence_output_a_proj,
                       sequence_output_b)
    logits = tf.multiply(
        logits, 1.0 / tf.math.sqrt(tf.cast(hidden_size, tf.float32)))

    label_ids = tf.cast(features['label_ids'], tf.int32)
    label_weights = tf.cast(features['label_weights'], tf.int32)
    # Summing the weights over the last axis gives the per-example length
    # passed to the CRF.
    label_seq_length = tf.reduce_sum(label_weights, axis=-1)

    # NOTE(review): the disabled code obtained this matrix from a
    # `zero_transition(seq_b_shape)` helper that is not visible here; an
    # all-zero [num_tags, num_tags] matrix is assumed -- confirm.
    num_tags = seq_b_shape[1]
    transition = tf.zeros([num_tags, num_tags], dtype=tf.float32)

    log_likelihood, transition_params = tf.contrib.crf.crf_log_likelihood(
        inputs=logits,
        tag_indices=label_ids,
        sequence_lengths=label_seq_length,
        transition_params=transition)
    # The transition matrix is returned for decoding only; no gradient
    # flows through it.
    transition_params = tf.stop_gradient(transition_params)

    per_example_loss = -log_likelihood
    loss = tf.reduce_mean(per_example_loss)

    return (loss, per_example_loss, logits, transition_params)
def build_embedder(self,
                   input_ids,
                   token_type_ids,
                   hidden_dropout_prob,
                   attention_probs_dropout_prob,
                   past=None,
                   decode_loop_step=None,
                   **kargs):
    """Build word embeddings and the post-processed embedding output.

    Sets `self.past_length`, `self.embedding_output_word`,
    `self.embedding_table` and `self.embedding_output` on the instance.

    Args:
        input_ids: rank-2 ids (regular lookup) or rank-3 input (routed to
            `bert_modules.gumbel_embedding_lookup`).
        token_type_ids: forwarded to the embedding post-processor.
        hidden_dropout_prob: dropout probability for the post-processor.
        attention_probs_dropout_prob: unused here; kept for interface
            compatibility.
        past: optional rank-6 cache; when given without `decode_loop_step`
            its second-to-last dimension is used as the position offset.
        decode_loop_step: when given together with `past`, used directly as
            the position offset.
        **kargs: must contain "reuse"; optionally "perturbation" (added to
            the word embeddings) and "use_token_type".
    """
    reuse = kargs["reuse"]

    if self.config.get("embedding", "none_factorized") == "none_factorized":
        projection_width = self.config.hidden_size
        tf.logging.info("==not using embedding factorized==")
    else:
        projection_width = self.config.get('embedding_size',
                                           self.config.hidden_size)
        tf.logging.info(
            "==using embedding factorized: embedding size: %s==",
            str(projection_width))

    # Word embeddings and the remaining embeddings share one scope: the
    # configured 'embedding_scope' if set, otherwise the model scope.
    if self.config.get('embedding_scope', None):
        embedding_scope = self.config['embedding_scope']
    else:
        embedding_scope = self.config.get("scope", "bert")
    other_embedding_scope = embedding_scope
    tf.logging.info(
        "==using embedding scope of original model_config.embedding_scope: %s, other_embedding_scope:%s ==",
        embedding_scope, other_embedding_scope)

    if past is None:
        self.past_length = 0
    elif decode_loop_step is None:
        # `past` is expected to be rank 6; the cache length sits at index -2.
        past_shape = bert_utils.get_shape_list(past, expected_rank=[6])
        self.past_length = past_shape[-2]
    else:
        self.past_length = decode_loop_step

    with tf.variable_scope(embedding_scope, reuse=reuse):
        with tf.variable_scope("embeddings"):
            input_shape = bert_utils.get_shape_list(input_ids,
                                                    expected_rank=[2, 3])
            print(input_shape, "=====input_shape=====")

            # Rank-3 input goes through the gumbel lookup; everything else
            # uses the regular embedding lookup.
            if len(input_shape) == 3:
                tf.logging.info("****** 3D embedding matmul *******")
                lookup_fn = bert_modules.gumbel_embedding_lookup
            else:
                lookup_fn = bert_modules.embedding_lookup

            (self.embedding_output_word, self.embedding_table) = lookup_fn(
                input_ids=input_ids,
                vocab_size=self.config.vocab_size,
                embedding_size=projection_width,
                initializer_range=self.config.initializer_range,
                word_embedding_name="word_embeddings",
                use_one_hot_embeddings=self.config.use_one_hot_embeddings)

            # Compare against None explicitly: truth-testing a tf.Tensor
            # (graph mode) or a multi-element array raises instead of
            # answering "was a perturbation supplied?".
            perturbation = kargs.get("perturbation", None)
            if perturbation is not None:
                self.embedding_output_word += perturbation
                tf.logging.info(
                    " add word pertubation for robust learning ")

    with tf.variable_scope(other_embedding_scope, reuse=reuse):
        with tf.variable_scope("embeddings"):
            # Add positional and token type embeddings, then let the
            # post-processor finish (dropout probability is passed through).
            tf.logging.info("==using segment type embedding ratio: %s==",
                            str(self.config.get("token_type_ratio", 1.0)))
            self.embedding_output = bert_seq_modules.embedding_postprocessor(
                input_tensor=self.embedding_output_word,
                use_token_type=kargs.get('use_token_type', True),
                token_type_ids=token_type_ids,
                token_type_vocab_size=self.config.type_vocab_size,
                token_type_embedding_name="token_type_embeddings",
                use_position_embeddings=True,
                position_embedding_name="position_embeddings",
                initializer_range=self.config.initializer_range,
                max_position_embeddings=self.config.max_position_embeddings,
                dropout_prob=hidden_dropout_prob,
                token_type_ratio=self.config.get("token_type_ratio", 1.0),
                position_offset=self.past_length)
def attention_layer(from_tensor,
                    to_tensor,
                    attention_mask=None,
                    num_attention_heads=1,
                    size_per_head=512,
                    query_act=None,
                    key_act=None,
                    value_act=None,
                    attention_probs_dropout_prob=0.0,
                    initializer_range=0.02,
                    do_return_2d_tensor=False,
                    batch_size=None,
                    from_seq_length=None,
                    to_seq_length=None):
    """Multi-headed attention from `from_tensor` to `to_tensor`.

    `from_tensor` is projected to queries and `to_tensor` to keys and values
    (dense layers named "query", "key" and "value"). Scaled dot-product
    scores are optionally masked, normalized to probabilities, passed through
    dropout and used to mix the values. When both inputs are the same tensor
    this is self-attention. Heads are handled with reshapes and transposes
    rather than separate tensors.

    Args:
        from_tensor: float Tensor, [batch_size, from_seq_length, from_width]
            or the 2D flattening of it.
        to_tensor: float Tensor, [batch_size, to_seq_length, to_width] or the
            2D flattening of it.
        attention_mask: (optional) Tensor of shape
            [batch_size, from_seq_length, to_seq_length] holding 1 for
            positions that may be attended to and 0 otherwise. Masked
            positions get -10000.0 added to their score.
        num_attention_heads: int. Number of attention heads.
        size_per_head: int. Size of each attention head.
        query_act: (optional) activation for the query projection.
        key_act: (optional) activation for the key projection.
        value_act: (optional) activation for the value projection.
        attention_probs_dropout_prob: float. Dropout probability applied to
            the attention probabilities.
        initializer_range: float. Range of the weight initializer.
        do_return_2d_tensor: bool. If True the result has shape
            [batch_size * from_seq_length, num_attention_heads * size_per_head],
            otherwise
            [batch_size, from_seq_length, num_attention_heads * size_per_head].
        batch_size: (optional) int. Required when the inputs are 2D.
        from_seq_length: (optional) int. Required when the inputs are 2D.
        to_seq_length: (optional) int. Required when the inputs are 2D.

    Returns:
        float Tensor holding the attention context, shaped as described under
        `do_return_2d_tensor`.

    Raises:
        ValueError: if the input ranks differ, or if 2D inputs are given
            without `batch_size`, `from_seq_length` and `to_seq_length`.
    """
    # Dimension letters used in the comments below:
    #   B = batch size, F = from sequence length, T = to sequence length,
    #   N = number of heads, H = size per head.
    projection_width = num_attention_heads * size_per_head

    def _split_heads(flat, seq_len):
        # [B*S, N*H] -> [B, S, N, H] -> [B, N, S, H]
        per_head = tf.reshape(
            flat, [batch_size, seq_len, num_attention_heads, size_per_head])
        return tf.transpose(per_head, [0, 2, 1, 3])

    def _project(inputs_2d, activation, name):
        return tf.layers.dense(
            inputs_2d,
            projection_width,
            activation=activation,
            name=name,
            kernel_initializer=create_initializer(initializer_range))

    from_shape = bert_utils.get_shape_list(from_tensor, expected_rank=[2, 3])
    to_shape = bert_utils.get_shape_list(to_tensor, expected_rank=[2, 3])
    rank = len(from_shape)

    if rank != len(to_shape):
        raise ValueError(
            "The rank of `from_tensor` must match the rank of `to_tensor`.")

    if rank == 3:
        batch_size, from_seq_length = from_shape[0], from_shape[1]
        to_seq_length = to_shape[1]
    elif rank == 2 and any(
            dim is None
            for dim in (batch_size, from_seq_length, to_seq_length)):
        raise ValueError(
            "When passing in rank 2 tensors to attention_layer, the values "
            "for `batch_size`, `from_seq_length`, and `to_seq_length` "
            "must all be specified.")

    from_2d = bert_utils.reshape_to_matrix(from_tensor)
    to_2d = bert_utils.reshape_to_matrix(to_tensor)

    queries = _project(from_2d, query_act, "query")  # [B*F, N*H]
    keys = _project(to_2d, key_act, "key")  # [B*T, N*H]
    values = _project(to_2d, value_act, "value")  # [B*T, N*H]

    queries = _split_heads(queries, from_seq_length)  # [B, N, F, H]
    keys = _split_heads(keys, to_seq_length)  # [B, N, T, H]

    # Raw scaled dot-product scores: [B, N, F, T].
    scores = tf.matmul(queries, keys, transpose_b=True)
    scores = tf.multiply(scores, 1.0 / math.sqrt(float(size_per_head)))

    if attention_mask is not None:
        # [B, F, T] -> [B, 1, F, T] so the mask broadcasts over heads.
        attention_mask = tf.expand_dims(attention_mask, axis=[1])
        # 0.0 where attention is allowed, -10000.0 where it is masked; added
        # before normalization this effectively removes the masked positions.
        scores += (1.0 - tf.cast(attention_mask, tf.float32)) * -10000.0

    # Probabilities are obtained as exp(log_softmax(scores)): [B, N, F, T].
    probs = tf.exp(tf.nn.log_softmax(scores))
    # Dropout on the probabilities drops whole attended-to tokens.
    probs = dropout(probs, attention_probs_dropout_prob)

    values = _split_heads(values, to_seq_length)  # [B, N, T, H]
    context = tf.matmul(probs, values)  # [B, N, F, H]
    context = tf.transpose(context, [0, 2, 1, 3])  # [B, F, N, H]

    if do_return_2d_tensor:
        out_shape = [batch_size * from_seq_length, projection_width]
    else:
        out_shape = [batch_size, from_seq_length, projection_width]
    return tf.reshape(context, out_shape)
def hidden_cls_matching(teacher_hidden, student_hidden, match_direction=0):
    """Weighted L1 distance between student and teacher first-token states.

    For every (student layer, teacher layer) pair the first position of each
    hidden state is taken, the student vector is projected with a shared
    "output_weights" matrix, both vectors are L2-normalized, and their
    `l1_distance` is weighted by a learned, normalized layer-pair weight.
    The batch-mean of every weighted distance is accumulated and the sum is
    divided by the number of layer pairs.

    Args:
        teacher_hidden: sequence of rank-3 teacher hidden-state tensors.
        student_hidden: sequence of rank-3 student hidden-state tensors.
        match_direction: 0 normalizes the layer-pair weights over the teacher
            axis; any other value normalizes them over the student axis.

    Returns:
        Scalar loss tensor.
    """
    teacher_shape = bert_utils.get_shape_list(teacher_hidden[0],
                                              expected_rank=[3])
    student_shape = bert_utils.get_shape_list(student_hidden[0],
                                              expected_rank=[3])
    num_student = len(student_hidden)
    num_teacher = len(teacher_hidden)

    # The two directions differ only in the initial scale of the weights and
    # in the axis along which they are normalized.
    if match_direction == 0:
        init_scale, norm_axis = num_teacher, -1
    else:
        print("===apply teacher model to student model==")
        init_scale, norm_axis = num_student, 0

    with tf.variable_scope("attention_weights", reuse=tf.AUTO_REUSE):
        projection_weights = tf.get_variable(
            "attention_score_weights", [num_student, num_teacher],
            initializer=tf.constant_initializer(
                np.ones((num_student, num_teacher)) / init_scale,
                dtype=tf.float32))
        normalized_weights = tf.abs(projection_weights) / tf.reduce_sum(
            tf.abs(projection_weights), axis=norm_axis, keepdims=True)

    def projection_fn(input_tensor):
        # Maps student-width vectors to teacher width; the variable is shared
        # across calls through AUTO_REUSE.
        with tf.variable_scope("uniformal_mapping/projection",
                               reuse=tf.AUTO_REUSE):
            mapping = tf.get_variable(
                "output_weights", [student_shape[-1], teacher_shape[-1]],
                initializer=tf.truncated_normal_initializer(stddev=0.02))
            return tf.einsum("ac,cd->ad", input_tensor, mapping)

    def first_token(hidden):
        # [B, F, H] -> [B, 1, H] -> [B, H]
        return tf.squeeze(hidden[:, 0:1, :], axis=1)

    loss = tf.constant(0.0)
    for i, student_layer in enumerate(student_hidden):
        student_vec = projection_fn(first_token(student_layer))
        student_vec = tf.nn.l2_normalize(student_vec, axis=-1)
        for j, teacher_layer in enumerate(teacher_hidden):
            teacher_vec = tf.nn.l2_normalize(first_token(teacher_layer),
                                             axis=-1)
            pair_weight = normalized_weights[i, j]
            pair_loss = pair_weight * l1_distance(
                student_vec, teacher_vec, axis=-1)
            loss += tf.reduce_mean(pair_loss, axis=0)

    loss /= (num_student * num_teacher)
    return loss