def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    if kargs.get('random_generator', '1') == '1':
        if mode == tf.estimator.ModeKeys.TRAIN:
            input_ori_ids = features['input_ori_ids']

            # [output_ids,
            # sampled_binary_mask] = random_input_ids_generation(model_config,
            #                             features['input_ori_ids'],
            #                             features['input_mask'],
            #                             mask_probability=0.2,
            #                             replace_probability=0.1,
            #                             original_probability=0.1,
            #                             **kargs)

            [output_ids,
             sampled_binary_mask] = hmm_input_ids_generation(
                 model_config,
                 features['input_ori_ids'],
                 features['input_mask'],
                 [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                  for hmm_tran_prob in hmm_tran_prob_list],
                 mask_probability=0.2,
                 replace_probability=0.0,
                 original_probability=0.0,
                 mask_prior=tf.constant(mask_prior, tf.float32),
                 **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("****** do random generator *******")
        else:
            sampled_binary_mask = None
    else:
        sampled_binary_mask = None

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE, **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
         model_config,
         model.get_pooled_output(),
         features['next_sentence_labels'],
         reuse=tf.AUTO_REUSE,
         scope=generator_scope_prefix)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        # default to the bert implementation
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             features['input_mask'],
             features['input_ori_ids'],
             features['input_ids'],
             sampled_binary_mask,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope=generator_scope_prefix)
        masked_lm_ids = features['input_ori_ids']
    else:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope=generator_scope_prefix)

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

    if kargs.get("resample_discriminator", False):
        input_ori_ids = features['input_ori_ids']

        [output_ids,
         sampled_binary_mask] = random_input_ids_generation(
             model_config,
             features['input_ori_ids'],
             features['input_mask'],
             mask_probability=0.2,
             replace_probability=0.1,
             original_probability=0.1)

        resample_features = {}
        for key in features:
            resample_features[key] = features[key]

        resample_features['input_ids'] = tf.identity(output_ids)
        model_resample = model_api(model_config, resample_features, labels,
                                   mode, target, reuse=tf.AUTO_REUSE, **kargs)
        tf.logging.info("**** apply discriminator resample ****")
    else:
        model_resample = model
        resample_features = features
        tf.logging.info("**** not apply discriminator resample ****")

    sampled_ids = token_generator(
        model_config,
        model_resample.get_sequence_output(),
        model_resample.get_embedding_table(),
        resample_features['input_ids'],
        resample_features['input_ori_ids'],
        resample_features['input_mask'],
        embedding_projection=model_resample.get_embedding_projection_table(),
        scope=generator_scope_prefix,
        mask_method='only_mask',
        use_tpu=kargs.get('use_tpu', True))

    if model_config.get('gen_sample', 1) == 1:
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
    else:
        # tile each tensor to [batch, seq_length, gen_sample] and flatten
        # to [batch * gen_sample, seq_length]
        input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)
        input_ids = tf.einsum(
            'abc,cd->abd',
            tf.cast(input_ids, tf.float32),
            tf.ones((1, model_config.get('gen_sample', 1))))
        input_ids = tf.cast(input_ids, tf.int32)

        input_shape_list = bert_utils.get_shape_list(input_ids,
                                                     expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        gen_sample = input_shape_list[2]

        sampled_ids = tf.reshape(sampled_ids,
                                 [batch_size * gen_sample, seq_length])
        input_ids = tf.reshape(input_ids,
                               [batch_size * gen_sample, seq_length])

        input_mask = tf.expand_dims(features['input_mask'], axis=-1)
        input_mask = tf.einsum(
            'abc,cd->abd',
            tf.cast(input_mask, tf.float32),
            tf.ones((1, model_config.get('gen_sample', 1))))
        input_mask = tf.cast(input_mask, tf.int32)

        segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
        segment_ids = tf.einsum(
            'abc,cd->abd',
            tf.cast(segment_ids, tf.float32),
            tf.ones((1, model_config.get('gen_sample', 1))))
        segment_ids = tf.cast(segment_ids, tf.int32)

        segment_ids = tf.reshape(segment_ids,
                                 [batch_size * gen_sample, seq_length])
        input_mask = tf.reshape(input_mask,
                                [batch_size * gen_sample, seq_length])

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    if generator_scope_prefix:
        # e.g. "generator/cls/predictions"
        lm_pretrain_tvars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/seq_relationship",
            not_storage_params=not_storage_params)
    else:
        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            "cls/seq_relationship",
            not_storage_params=not_storage_params)

    if model_config.get('embedding_scope', None) is not None:
        embedding_tvars = model_io_fn.get_params(
            model_config.get('embedding_scope', 'bert') + "/embeddings",
            not_storage_params=not_storage_params)
        pretrained_tvars.extend(embedding_tvars)

    pretrained_tvars.extend(lm_pretrain_tvars)
    pretrained_tvars.extend(nsp_pretrain_vars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu,
            restore_var_name=model_config.get('restore_var_name', []))
    else:
        scaffold_fn = None

    tf.add_to_collection("generator_loss", loss)

    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "sampled_ids": sampled_ids,        # [batch * gen_sample, seq_length]
        "sampled_input_ids": input_ids,    # [batch * gen_sample, seq_length]
        "sampled_input_mask": input_mask,
        "sampled_segment_ids": segment_ids,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": masked_lm_mask,
        "masked_lm_log_probs": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss,
        "next_sentence_example_loss": nsp_per_example_loss,
        "next_sentence_log_probs": nsp_log_prob,
        "next_sentence_labels": features['next_sentence_labels'],
        "sampled_binary_mask": sampled_binary_mask
    }
    return return_dict
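# The token_generator call above draws replacement tokens only at masked
# positions (mask_method='only_mask'). A minimal pure-Python sketch of that
# idea, with hypothetical names (sample_at_masked is not part of this repo),
# assuming per-position vocab logits and greedy selection when no rng is given:

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_at_masked(input_ids, mask_positions, logits_per_pos, rng=None):
    """Replace ids only at masked positions; greedy if rng is None.

    input_ids:      list[int], corrupted input ids
    mask_positions: set[int], positions that were masked out
    logits_per_pos: dict[int, list[float]], vocab logits per masked position
                    (a stand-in for the generator's output logits)
    """
    out = list(input_ids)
    for pos in mask_positions:
        probs = softmax(logits_per_pos[pos])
        if rng is None:
            # greedy: take the argmax token
            out[pos] = max(range(len(probs)), key=probs.__getitem__)
        else:
            out[pos] = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    return out
```

# Unmasked positions pass through untouched, mirroring 'only_mask' behavior.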
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    if 'input_mask' not in features:
        input_mask = tf.cast(
            tf.not_equal(features['input_ids_{}'.format(target)],
                         kargs.get('[PAD]', 0)), tf.int32)
        if target:
            features['input_mask_{}'.format(target)] = input_mask
        else:
            features['input_mask'] = input_mask

    if 'segment_ids' not in features:
        segment_ids = tf.zeros_like(input_mask)
        if target:
            features['segment_ids_{}'.format(target)] = segment_ids
        else:
            features['segment_ids'] = segment_ids

    if target:
        features['input_ori_ids'] = features['input_ids_{}'.format(target)]
        features['input_mask'] = features['input_mask_{}'.format(target)]
        features['segment_ids'] = features['segment_ids_{}'.format(target)]
        features['input_ids'] = features['input_ids_{}'.format(target)]

    input_ori_ids = features.get('input_ori_ids', None)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if input_ori_ids is not None:
            # [output_ids,
            # sampled_binary_mask] = random_input_ids_generation(
            #     model_config,
            #     input_ori_ids,
            #     features['input_mask'],
            #     **kargs)

            [output_ids,
             sampled_binary_mask] = hmm_input_ids_generation(
                 model_config,
                 features['input_ori_ids'],
                 features['input_mask'],
                 [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                  for hmm_tran_prob in hmm_tran_prob_list],
                 mask_probability=0.2,
                 replace_probability=0.1,
                 original_probability=0.1,
                 mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                 **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("***** Running random sample input generation *****")
        else:
            sampled_binary_mask = None
    else:
        sampled_binary_mask = None

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE, **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    # (nsp_loss,
    #  nsp_per_example_loss,
    #  nsp_log_prob) = pretrain.get_next_sentence_output(
    #      model_config,
    #      model.get_pooled_output(),
    #      features['next_sentence_labels'],
    #      reuse=tf.AUTO_REUSE)

    # masked_lm_positions = features["masked_lm_positions"]
    # masked_lm_ids = features["masked_lm_ids"]
    # masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        # default to the bert implementation
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             features['input_mask'],
             features['input_ori_ids'],
             features['input_ids'],
             sampled_binary_mask,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())
        masked_lm_ids = input_ori_ids
    else:
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss  # + 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/predictions",
        not_storage_params=not_storage_params)

    pretrained_tvars.extend(lm_pretrain_tvars)

    if load_pretrained == "yes":
        scaffold_fn = model_io_fn.load_pretrained(
            pretrained_tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=1)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars,
                opt_config.init_lr,
                opt_config.num_train_steps,
                use_tpu=opt_config.use_tpu)

        # train_metric_dict = train_metric_fn(
        #     masked_lm_example_loss, masked_lm_log_probs,
        #     masked_lm_ids,
        #     masked_lm_mask,
        #     nsp_per_example_loss,
        #     nsp_log_prob,
        #     features['next_sentence_labels'],
        #     masked_lm_mask=masked_lm_mask)

        # for key in train_metric_dict:
        #     tf.summary.scalar(key, train_metric_dict[key])
        # tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                      masked_lm_ids, masked_lm_weights):
            """Computes the loss and accuracy of the masked LM."""
            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(
                masked_lm_log_probs, axis=-1, output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss,
                weights=masked_lm_weights)

            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss
            }

        # next-sentence metrics are omitted here: get_next_sentence_output
        # is commented out above, so nsp_per_example_loss / nsp_log_prob
        # are not defined in this model_fn
        eval_metrics = (metric_fn, [
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids, masked_lm_mask
        ])

        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
        return estimator_spec
    else:
        raise NotImplementedError()
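# metric_fn above computes a weight-masked accuracy via
# tf.metrics.accuracy(labels, predictions, weights). The same quantity can be
# sketched in plain Python (masked_lm_accuracy is a hypothetical helper, not
# part of this repo) to make the weighting explicit:

```python
def masked_lm_accuracy(predictions, label_ids, weights):
    """Weighted accuracy over masked positions.

    Positions with weight 0 (padding in the masked_lm_* feature slots)
    contribute neither to the numerator nor to the denominator, matching
    tf.metrics.accuracy with a weights argument.
    """
    correct = sum(w for p, y, w in zip(predictions, label_ids, weights)
                  if p == y)
    total = sum(weights)
    return correct / total if total else 0.0
```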
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    if target:
        features['input_ori_ids'] = features['input_ids_{}'.format(target)]
        features['input_ids'] = features['input_ids_{}'.format(target)]

    sequence_mask = tf.cast(
        tf.not_equal(features['input_ori_ids'], kargs.get('[PAD]', 0)),
        tf.int32)
    features['input_mask'] = sequence_mask

    seq_features = {}
    for key in features:
        seq_features[key] = features[key]
    if 'input_ori_ids' in features:
        seq_features['input_ids'] = features["input_ori_ids"]
    else:
        features['input_ori_ids'] = seq_features['input_ids']

    not_equal = tf.cast(
        tf.not_equal(features["input_ori_ids"],
                     tf.zeros_like(features["input_ori_ids"])),
        tf.int32)
    not_equal = tf.reduce_sum(not_equal, axis=-1)
    loss_mask = tf.cast(
        tf.not_equal(not_equal, tf.zeros_like(not_equal)), tf.float32)

    if not kargs.get('use_tpu', False):
        tf.summary.scalar('loss_mask', tf.reduce_sum(loss_mask))

    casual_flag = model_config.get('is_casual', True)
    tf.logging.info("***** is casual flag: %s *****", str(casual_flag))

    if not casual_flag:
        [output_ids,
         sampled_binary_mask] = hmm_input_ids_generation(
             model_config,
             features['input_ori_ids'],
             features['input_mask'],
             [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
              for hmm_tran_prob in hmm_tran_prob_list],
             mask_probability=0.02,
             replace_probability=0.01,
             original_probability=0.01,
             mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
             **kargs)
        tf.logging.info("***** apply random sampling *****")
        seq_features['input_ids'] = output_ids

    model = model_api(model_config, seq_features, labels,
                      mode, "", reuse=tf.AUTO_REUSE, **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    # if mode == tf.estimator.ModeKeys.TRAIN:
    if kargs.get('mask_type', 'left2right') == 'left2right':
        tf.logging.info("***** using left2right mask and loss *****")
        sequence_mask = tf.to_float(
            tf.not_equal(features['input_ori_ids'][:, 1:],
                         kargs.get('[PAD]', 0)))
    elif kargs.get('mask_type', 'left2right') == 'seq2seq':
        tf.logging.info("***** using seq2seq mask and loss *****")
        sequence_mask = tf.to_float(features['segment_ids'][:, 1:])

    if not kargs.get('use_tpu', False):
        tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask))

    # batch x seq_length
    if casual_flag:
        print(model.get_sequence_output_logits().get_shape(),
              "===logits shape===")
        # shift by one: predict token t+1 from the logits at position t
        seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=features['input_ori_ids'][:, 1:],
            logits=model.get_sequence_output_logits()[:, :-1])

        per_example_loss = tf.reduce_sum(
            seq_loss * sequence_mask, axis=-1) / (
                tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
        loss = tf.reduce_mean(per_example_loss)

        if model_config.get("cnn_type", "dgcnn") in ['bi_dgcnn',
                                                     'bi_light_dgcnn']:
            seq_backward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ori_ids'][:, :-1],
                logits=model.get_sequence_backward_output_logits()[:, 1:])

            per_backward_example_loss = tf.reduce_sum(
                seq_backward_loss * sequence_mask, axis=-1) / (
                    tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
            backward_loss = tf.reduce_mean(per_backward_example_loss)
            loss += backward_loss
            tf.logging.info("***** using backward loss *****")
    else:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = pretrain.seq_mask_masked_lm_output(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             seq_features['input_mask'],
             seq_features['input_ori_ids'],
             seq_features['input_ids'],
             sampled_binary_mask,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())
        loss = masked_lm_loss
        tf.logging.info("***** using masked lm loss *****")

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/predictions",
        not_storage_params=not_storage_params)

    pretrained_tvars.extend(lm_pretrain_tvars)

    use_tpu = 1 if kargs.get('use_tpu', False) else 0

    if load_pretrained == "yes":
        scaffold_fn = model_io_fn.load_pretrained(
            pretrained_tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:

        if kargs.get('use_tpu', False):
            optimizer_fn = optimizer.Optimizer(opt_config)
            use_tpu = 1
            tf.logging.info(
                "***** using tpu with tpu-compatible optimizer *****")
        else:
            optimizer_fn = distributed_optimizer.Optimizer(opt_config)
            use_tpu = 0
            tf.logging.info(
                "***** using gpu with gpu-compatible optimizer *****")

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars,
                opt_config.init_lr,
                opt_config.num_train_steps,
                use_tpu=use_tpu)

        # train_metric_dict = train_metric(features['input_ori_ids'],
        #                                  model.get_sequence_output_logits(),
        #                                  seq_features,
        #                                  **kargs)

        # if not kargs.get('use_tpu', False):
        #     for key in train_metric_dict:
        #         tf.summary.scalar(key, train_metric_dict[key])
        #     tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)
        #     tf.logging.info("***** logging metric *****")
        #     tf.summary.scalar("causal_attenion_mask_length",
        #                       tf.reduce_sum(sequence_mask))
        #     tf.summary.scalar("bi_attenion_mask_length",
        #                       tf.reduce_sum(model.bi_attention_mask))

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                       model.get_sequence_output_logits(),
                                       seq_features,
                                       **kargs)
        tpu_eval_metrics = (eval_metric, [
            features['input_ori_ids'],
            model.get_sequence_output_logits(),
            seq_features,
            kargs.get('mask_type', 'left2right')
        ])

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=tpu_eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops=gpu_eval_metrics)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:
        if kargs.get('predict_type', 'sample_sequence') == 'sample_sequence':
            results = bert_seq_sample_utils.sample_sequence(
                model_api,
                model_config,
                mode,
                features,
                target="",
                start_token=kargs.get("start_token_id", 101),
                batch_size=None,
                context=features.get("context", None),
                temperature=kargs.get("sample_temp", 1.0),
                n_samples=kargs.get("n_samples", 1),
                top_k=0,
                end_token=kargs.get("end_token_id", 102),
                greedy_or_sample="greedy",
                gumbel_temp=0.01,
                estimator="stop_gradient",
                back_prop=True,
                swap_memory=True,
                seq_type=kargs.get("seq_type", "seq2seq"),
                mask_type=kargs.get("mask_type", "seq2seq"),
                attention_type=kargs.get('attention_type',
                                         'normal_attention'))
            # stop_gradient output:
            # samples, mask_sequence, presents, logits, final

            sampled_token = results['samples']
            sampled_token_logits = results['logits']
            mask_sequence = results['mask_sequence']

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'token': sampled_token,
                    "logits": sampled_token_logits,
                    "mask_sequence": mask_sequence
                },
                export_outputs={
                    "output": tf.estimator.export.PredictOutput({
                        'token': sampled_token,
                        "logits": sampled_token_logits,
                        "mask_sequence": mask_sequence
                    })
                })
            return estimator_spec

        elif kargs.get('predict_type', 'sample_sequence') == 'infer_inputs':
            sequence_mask = tf.to_float(
                tf.not_equal(features['input_ids'][:, 1:],
                             kargs.get('[PAD]', 0)))

            if kargs.get('mask_type', 'left2right') == 'left2right':
                tf.logging.info("***** using left2right mask and loss *****")
                sequence_mask = tf.to_float(
                    tf.not_equal(features['input_ori_ids'][:, 1:],
                                 kargs.get('[PAD]', 0)))
            elif kargs.get('mask_type', 'left2right') == 'seq2seq':
                tf.logging.info("***** using seq2seq mask and loss *****")
                sequence_mask = tf.to_float(features['segment_ids'][:, 1:])

            if not kargs.get('use_tpu', False):
                tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask))

            output_logits = model.get_sequence_output_logits()[:, :-1]
            # output_logits = tf.nn.log_softmax(output_logits, axis=-1)

            output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ids'][:, 1:],
                logits=output_logits)

            per_example_perplexity = tf.reduce_sum(
                output_id_logits * sequence_mask, axis=-1)     # batch
            per_example_perplexity /= tf.reduce_sum(
                sequence_mask, axis=-1)                        # batch

            perplexity = tf.exp(per_example_perplexity)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'token': features['input_ids'][:, 1:],
                    "logits": output_id_logits,
                    'perplexity': perplexity,
                    # "all_logits": output_logits
                },
                export_outputs={
                    "output": tf.estimator.export.PredictOutput({
                        'token': features['input_ids'][:, 1:],
                        "logits": output_id_logits,
                        'perplexity': perplexity,
                        # "all_logits": output_logits
                    })
                })
            return estimator_spec
    else:
        raise NotImplementedError()
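# The infer_inputs branch above computes per-example perplexity as
# exp of the length-normalized token cross-entropy, with labels shifted
# against the logits (labels=input_ids[:, 1:] vs logits[:, :-1]). A small
# stdlib sketch of the same arithmetic (the perplexity helper is
# illustrative, assuming already-shifted raw logits rather than TF tensors):

```python
import math


def perplexity(logits_seq, label_ids, mask):
    """exp(mean negative log-likelihood over unmasked positions).

    logits_seq: list of per-position vocab logits (already shifted)
    label_ids:  gold next-token ids for the same positions
    mask:       1.0 for real tokens, 0.0 for padding
    """
    total_nll, total_w = 0.0, 0.0
    for logits, y, w in zip(logits_seq, label_ids, mask):
        m = max(logits)
        # log of the softmax normalizer, computed stably
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total_nll += w * (log_z - logits[y])  # -log softmax(logits)[y]
        total_w += w
    return math.exp(total_nll / total_w)
```

# With uniform logits over a vocab of size V, every token has probability
# 1/V, so the perplexity is exactly V.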
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    input_mask = tf.cast(
        tf.not_equal(features['input_ids_{}'.format(target)],
                     kargs.get('[PAD]', 0)), tf.int32)
    segment_ids = tf.zeros_like(input_mask)

    if target:
        features['input_ori_ids'] = features['input_ids_{}'.format(target)]
        features['input_mask'] = input_mask
        features['segment_ids'] = segment_ids
        # features['input_mask'] = features['input_mask_{}'.format(target)]
        # features['segment_ids'] = features['segment_ids_{}'.format(target)]
        features['input_ids'] = features['input_ids_{}'.format(target)]

    input_ori_ids = features.get('input_ori_ids', None)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if input_ori_ids is not None:
            [output_ids,
             sampled_binary_mask] = hmm_input_ids_generation(
                 model_config,
                 features['input_ori_ids'],
                 features['input_mask'],
                 [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                  for hmm_tran_prob in hmm_tran_prob_list],
                 mask_probability=0.1,
                 replace_probability=0.1,
                 original_probability=0.1,
                 mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                 **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("***** Running random sample input generation *****")
        else:
            # no dynamic masking applied; keep the original inputs so that
            # "output_ids" in return_dict is always defined
            output_ids = features['input_ids']
            sampled_binary_mask = None
    else:
        output_ids = features['input_ids']
        sampled_binary_mask = None

    model = model_api(model_config, features, labels,
                      mode, "", reuse=tf.AUTO_REUSE, **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        # default to the bert implementation
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             features['input_mask'],
             features['input_ori_ids'],
             features['input_ids'],
             sampled_binary_mask,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())
        masked_lm_ids = input_ori_ids
    else:
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss  # + 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/predictions",
        not_storage_params=not_storage_params)

    pretrained_tvars.extend(lm_pretrain_tvars)
    tvars = pretrained_tvars

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu,
            restore_var_name=model_config.get('restore_var_name', []))
    else:
        scaffold_fn = None

    return_dict = {
        "loss": loss,
        "logits": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss,
        "tvars": tvars,
        "model": model,
        "masked_lm_mask": masked_lm_mask,
        "output_ids": output_ids,
        "masked_lm_ids": masked_lm_ids
    }
    return return_dict
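# All four model_fns corrupt the inputs with hmm_input_ids_generation (or
# random_input_ids_generation) using separate mask / replace / keep-original
# probabilities. A hedged stdlib sketch of that BERT-style corruption follows;
# corrupt_inputs is a hypothetical name, and it uses a flat per-position
# sampler rather than the HMM transition probabilities of the real code:

```python
import random


def corrupt_inputs(input_ids, mask_id, vocab_size, rng,
                   mask_probability=0.1,
                   replace_probability=0.1,
                   original_probability=0.1):
    """Return (output_ids, sampled_binary_mask).

    Each position is selected for prediction with total probability
    mask_probability + replace_probability + original_probability.
    Selected positions are masked, replaced by a random id, or kept
    unchanged, in proportion to the three probabilities; the binary mask
    marks every selected position so the MLM loss knows where to apply.
    """
    output_ids, binary_mask = [], []
    for tok in input_ids:
        r = rng.random()
        if r < mask_probability:
            output_ids.append(mask_id)                    # -> [MASK]
            binary_mask.append(1)
        elif r < mask_probability + replace_probability:
            output_ids.append(rng.randrange(vocab_size))  # random replace
            binary_mask.append(1)
        elif r < (mask_probability + replace_probability
                  + original_probability):
            output_ids.append(tok)                        # keep, still predicted
            binary_mask.append(1)
        else:
            output_ids.append(tok)                        # untouched
            binary_mask.append(0)
    return output_ids, binary_mask
```

# Setting replace_probability and original_probability to 0.0, as the first
# model_fn does, makes every selected position a literal [MASK] token.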