def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    if kargs.get('random_generator', '1') == '1':
        if mode in [tf.estimator.ModeKeys.TRAIN,
                    tf.estimator.ModeKeys.EVAL]:
            input_ori_ids = features['input_ori_ids']

            # [output_ids,
            #  sampled_binary_mask] = random_input_ids_generation(
            #     model_config,
            #     features['input_ori_ids'],
            #     features['input_mask'],
            #     mask_probability=0.2,
            #     replace_probability=0.1,
            #     original_probability=0.1,
            #     **kargs)

            [output_ids,
             sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'],
                [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                 for hmm_tran_prob in hmm_tran_prob_list],
                mask_probability=0.2,
                replace_probability=0.0,
                original_probability=0.0,
                mask_prior=tf.constant(mask_prior, tf.float32),
                **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("****** do random generator *******")
        else:
            sampled_binary_mask = None
    else:
        sampled_binary_mask = None

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE,
        scope=generator_scope_prefix)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_mask'],
            features['input_ori_ids'],
            features['input_ids'],
            sampled_binary_mask,
            reuse=tf.AUTO_REUSE,
            embedding_projection=model.get_embedding_projection_table(),
            scope=generator_scope_prefix)
        masked_lm_ids = features['input_ori_ids']
    else:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            masked_lm_positions,
            masked_lm_ids,
            masked_lm_weights,
            reuse=tf.AUTO_REUSE,
            embedding_projection=model.get_embedding_projection_table(),
            scope=generator_scope_prefix)

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

    if kargs.get("resample_discriminator", False):
        input_ori_ids = features['input_ori_ids']

        [output_ids,
         sampled_binary_mask] = random_input_ids_generation(
            model_config,
            features['input_ori_ids'],
            features['input_mask'],
            mask_probability=0.2,
            replace_probability=0.1,
            original_probability=0.1)

        resample_features = {}
        for key in features:
            resample_features[key] = features[key]

        resample_features['input_ids'] = tf.identity(output_ids)
        model_resample = model_api(model_config, resample_features, labels,
                                   mode, target, reuse=tf.AUTO_REUSE,
                                   **kargs)
        tf.logging.info("**** apply discriminator resample ****")
    else:
        model_resample = model
        resample_features = features
        tf.logging.info("**** not apply discriminator resample ****")

    sampled_ids = token_generator(
        model_config,
        model_resample.get_sequence_output(),
        model_resample.get_embedding_table(),
        resample_features['input_ids'],
        resample_features['input_ori_ids'],
        resample_features['input_mask'],
        embedding_projection=model_resample.get_embedding_projection_table(),
        scope=generator_scope_prefix,
        mask_method='only_mask',
        use_tpu=kargs.get('use_tpu', True))

    if model_config.get('gen_sample', 1) == 1:
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
    else:
        # Tile each tensor gen_sample times, then fold the sample axis
        # into the batch axis.
        input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)  # batch x seq_length x 1
        input_ids = tf.einsum('abc,cd->abd', input_ids,
                              tf.ones((1, model_config.get('gen_sample', 1))))
        input_ids = tf.cast(input_ids, tf.int32)

        input_shape_list = bert_utils.get_shape_list(input_ids, expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        gen_sample = input_shape_list[2]

        sampled_ids = tf.reshape(sampled_ids, [batch_size * gen_sample, seq_length])
        input_ids = tf.reshape(input_ids, [batch_size * gen_sample, seq_length])

        input_mask = tf.expand_dims(features['input_mask'], axis=-1)
        input_mask = tf.einsum('abc,cd->abd', input_mask,
                               tf.ones((1, model_config.get('gen_sample', 1))))
        input_mask = tf.cast(input_mask, tf.int32)

        segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
        segment_ids = tf.einsum('abc,cd->abd', segment_ids,
                                tf.ones((1, model_config.get('gen_sample', 1))))
        segment_ids = tf.cast(segment_ids, tf.int32)

        segment_ids = tf.reshape(segment_ids, [batch_size * gen_sample, seq_length])
        input_mask = tf.reshape(input_mask, [batch_size * gen_sample, seq_length])

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    if generator_scope_prefix:
        # e.g. "generator/cls/predictions"
        lm_pretrain_tvars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/seq_relationship",
            not_storage_params=not_storage_params)
    else:
        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            "cls/seq_relationship",
            not_storage_params=not_storage_params)

    if model_config.get('embedding_scope', None) is not None:
        embedding_tvars = model_io_fn.get_params(
            model_config.get('embedding_scope', 'bert') + "/embeddings",
            not_storage_params=not_storage_params)
        pretrained_tvars.extend(embedding_tvars)

    pretrained_tvars.extend(lm_pretrain_tvars)
    pretrained_tvars.extend(nsp_pretrain_vars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu,
            restore_var_name=model_config.get('restore_var_name', []))
    else:
        scaffold_fn = None

    tf.add_to_collection("generator_loss", loss)

    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "sampled_ids": sampled_ids,        # [batch x gen_sample, seq_length]
        "sampled_input_ids": input_ids,    # [batch x gen_sample, seq_length]
        "sampled_input_mask": input_mask,
        "sampled_segment_ids": segment_ids,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": masked_lm_mask,
        "masked_lm_log_probs": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss,
        "next_sentence_example_loss": nsp_per_example_loss,
        "next_sentence_log_probs": nsp_log_prob,
        "next_sentence_labels": features['next_sentence_labels'],
        "sampled_binary_mask": sampled_binary_mask
    }
    return return_dict
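
# Illustration only: a toy version of the dynamic masking that
# random_input_ids_generation / hmm_input_ids_generation perform above; the
# HMM variant additionally takes span priors (hmm_tran_prob_list, mask_prior).
# The [MASK] id 103 and vocab_size 30522 are assumptions for the standard
# uncased BERT vocab, not values taken from this file.
def toy_dynamic_mask(input_ids, input_mask, mask_id=103,
                     mask_probability=0.2, replace_probability=0.1,
                     original_probability=0.1, vocab_size=30522):
    """Returns (output_ids, sampled_binary_mask) like the helpers above."""
    shape = tf.shape(input_ids)
    # Pick MLM positions among real (non-padding) tokens.
    sampled = tf.logical_and(
        tf.random.uniform(shape) < mask_probability,
        tf.cast(input_mask, tf.bool))
    # For each picked position: random-id replacement, keep-original, or
    # (the remaining probability mass) the [MASK] id.
    roll = tf.random.uniform(shape)
    random_ids = tf.random.uniform(shape, maxval=vocab_size, dtype=tf.int32)
    corrupted = tf.fill(shape, mask_id)
    corrupted = tf.where(roll < replace_probability, random_ids, corrupted)
    keep = tf.logical_and(roll >= replace_probability,
                          roll < replace_probability + original_probability)
    corrupted = tf.where(keep, input_ids, corrupted)
    output_ids = tf.where(sampled, corrupted, input_ids)
    return output_ids, tf.cast(sampled, tf.int32)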
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        print("==apply bert masked lm==")

    (masked_lm_loss,
     masked_lm_example_loss,
     masked_lm_log_probs,
     masked_lm_mask) = masked_lm_fn(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        masked_lm_positions,
        masked_lm_ids,
        masked_lm_weights,
        reuse=tf.AUTO_REUSE,
        embedding_projection=model.get_embedding_projection_table())

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss  # + model_config.nsp_ratio * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/predictions",
        not_storage_params=not_storage_params)
    pretrained_tvars.extend(lm_pretrain_tvars)

    if load_pretrained == "yes":
        scaffold_fn = model_io_fn.load_pretrained(
            pretrained_tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=1)
    else:
        scaffold_fn = None

    print("******* scaffold fn *******", scaffold_fn)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        # with tf.control_dependencies(update_ops):
        print('==gpu count==', opt_config.get('gpu_count', 1))

        train_op = optimizer_fn.get_train_op(
            loss, tvars,
            opt_config.init_lr,
            opt_config.num_train_steps,
            use_tpu=opt_config.use_tpu)

        train_metric_dict = train_metric_fn(
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids,
            masked_lm_weights,
            nsp_per_example_loss,
            nsp_log_prob,
            features['next_sentence_labels'],
            masked_lm_mask=masked_lm_mask)

        # for key in train_metric_dict:
        #     tf.summary.scalar(key, train_metric_dict[key])
        # tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                      masked_lm_ids, masked_lm_weights,
                      next_sentence_example_loss, next_sentence_log_probs,
                      next_sentence_labels):
            """Computes the loss and accuracy of the model."""
            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(
                masked_lm_log_probs, axis=-1, output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss,
                weights=masked_lm_weights)

            next_sentence_log_probs = tf.reshape(
                next_sentence_log_probs,
                [-1, next_sentence_log_probs.shape[-1]])
            next_sentence_predictions = tf.argmax(
                next_sentence_log_probs, axis=-1, output_type=tf.int32)
            next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
            next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels,
                predictions=next_sentence_predictions)
            next_sentence_mean_loss = tf.metrics.mean(
                values=next_sentence_example_loss)

            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "next_sentence_accuracy": next_sentence_accuracy,
                "next_sentence_loss": next_sentence_mean_loss
            }

        eval_metrics = (metric_fn, [
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids, masked_lm_weights,
            nsp_per_example_loss, nsp_log_prob,
            features['next_sentence_labels']
        ])

        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
        return estimator_spec
    else:
        raise NotImplementedError()
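
# Sketch only: wiring a TPUEstimatorSpec-returning model_fn like the one above
# into tf.contrib.tpu.TPUEstimator (TF 1.x API). The model_dir path, batch
# sizes, and train_input_fn are placeholders, not values from this file.
run_config = tf.contrib.tpu.RunConfig(
    master=None,
    model_dir="gs://my-bucket/pretrain",  # placeholder
    tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000))

estimator = tf.contrib.tpu.TPUEstimator(
    model_fn=model_fn,
    config=run_config,
    use_tpu=False,  # set True when running on an actual TPU worker
    train_batch_size=32,
    eval_batch_size=32)

estimator.train(input_fn=train_input_fn, max_steps=opt_config.num_train_steps)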
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE,
        scope='generator')

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        print("==apply bert masked lm==")

    (masked_lm_loss,
     masked_lm_example_loss,
     masked_lm_log_probs,
     masked_lm_mask) = masked_lm_fn(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        masked_lm_positions,
        masked_lm_ids,
        masked_lm_weights,
        reuse=tf.AUTO_REUSE,
        embedding_projection=model.get_embedding_projection_table(),
        scope='generator')

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss  # + model_config.nsp_ratio * nsp_loss

    sampled_ids = token_generator(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        features['input_ids'],
        features['input_ori_ids'],
        features['input_mask'],
        embedding_projection=model.get_embedding_projection_table(),
        scope='generator',
        mask_method='only_mask')

    if model_config.get('gen_sample', 1) == 1:
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']
    else:
        # Tile each tensor gen_sample times, then fold the sample axis
        # into the batch axis.
        input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)  # batch x seq_length x 1
        input_ids = tf.einsum('abc,cd->abd', input_ids,
                              tf.ones((1, model_config.get('gen_sample', 1))))
        input_ids = tf.cast(input_ids, tf.int32)

        input_shape_list = bert_utils.get_shape_list(input_ids, expected_rank=3)
        batch_size = input_shape_list[0]
        seq_length = input_shape_list[1]
        gen_sample = input_shape_list[2]

        sampled_ids = tf.reshape(sampled_ids, [batch_size * gen_sample, seq_length])
        input_ids = tf.reshape(input_ids, [batch_size * gen_sample, seq_length])

        input_mask = tf.expand_dims(features['input_mask'], axis=-1)
        input_mask = tf.einsum('abc,cd->abd', input_mask,
                               tf.ones((1, model_config.get('gen_sample', 1))))
        input_mask = tf.cast(input_mask, tf.int32)

        segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
        segment_ids = tf.einsum('abc,cd->abd', segment_ids,
                                tf.ones((1, model_config.get('gen_sample', 1))))
        segment_ids = tf.cast(segment_ids, tf.int32)

        segment_ids = tf.reshape(segment_ids, [batch_size * gen_sample, seq_length])
        input_mask = tf.reshape(input_mask, [batch_size * gen_sample, seq_length])

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "generator/cls/predictions",
        not_storage_params=not_storage_params)
    pretrained_tvars.extend(lm_pretrain_tvars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope="generator",
            use_tpu=use_tpu)
    else:
        scaffold_fn = None

    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "sampled_ids": sampled_ids,        # [batch x gen_sample, seq_length]
        "sampled_input_ids": input_ids,    # [batch x gen_sample, seq_length]
        "sampled_input_mask": input_mask,
        "sampled_segment_ids": segment_ids,
        "masked_lm_positions": masked_lm_positions,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": masked_lm_weights,
        "masked_lm_log_probs": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss
    }
    return return_dict
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    if kargs.get('random_generator', '1') == '1':
        if mode in [tf.estimator.ModeKeys.TRAIN,
                    tf.estimator.ModeKeys.EVAL]:
            input_ori_ids = features['input_ori_ids']

            [output_ids,
             sampled_binary_mask] = random_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'])

            features['input_ids'] = tf.identity(output_ids)
            tf.logging.info("****** do random generator *******")
        else:
            sampled_binary_mask = None
            output_ids = tf.identity(features['input_ids'])
    else:
        sampled_binary_mask = None
        output_ids = tf.identity(features['input_ids'])

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE,
        scope=generator_scope_prefix)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_mask'],
            features['input_ori_ids'],
            features['input_ids'],
            sampled_binary_mask,
            reuse=tf.AUTO_REUSE,
            embedding_projection=model.get_embedding_projection_table(),
            scope=generator_scope_prefix)
        masked_lm_ids = features['input_ori_ids']
    else:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            masked_lm_positions,
            masked_lm_ids,
            masked_lm_weights,
            reuse=tf.AUTO_REUSE,
            embedding_projection=model.get_embedding_projection_table(),
            scope=generator_scope_prefix)

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

    sampled_ids = token_generator_igr(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        features['input_ids'],
        features['input_ori_ids'],
        features['input_mask'],
        embedding_projection=model.get_embedding_projection_table(),
        scope=generator_scope_prefix,
        mask_method='only_mask',
        **kargs)

    # NOTE: unlike the other generators in this file, only gen_sample == 1
    # is handled here; there is no multi-sample tiling branch.
    if model_config.get('gen_sample', 1) == 1:
        input_ids = features['input_ori_ids']
        input_mask = features['input_mask']
        segment_ids = features['segment_ids']

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    if generator_scope_prefix:
        # e.g. "generator/cls/predictions"
        lm_pretrain_tvars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/seq_relationship",
            not_storage_params=not_storage_params)
    else:
        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            "cls/seq_relationship",
            not_storage_params=not_storage_params)

    pretrained_tvars.extend(lm_pretrain_tvars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu)
    else:
        scaffold_fn = None

    # tf.add_to_collection("generator_loss", masked_lm_loss)

    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "sampled_ids": sampled_ids,        # [batch x gen_sample, seq_length]
        "sampled_input_ids": input_ids,    # [batch x gen_sample, seq_length]
        "sampled_input_mask": input_mask,
        "sampled_segment_ids": segment_ids,
        "masked_lm_ids": masked_lm_ids,
        "masked_lm_weights": masked_lm_mask,
        "masked_lm_log_probs": masked_lm_log_probs,
        "masked_lm_example_loss": masked_lm_example_loss,
        "next_sentence_example_loss": nsp_per_example_loss,
        "next_sentence_log_probs": nsp_log_prob,
        "next_sentence_labels": features['next_sentence_labels'],
        "output_ids": output_ids
    }
    return return_dict
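
# Conceptual sketch of the sampling step a token_generator-style helper
# performs: draw a token id from the MLM distribution at each masked position
# and keep the original id elsewhere. The function name, `logits`
# [batch, seq, vocab], and `is_masked` (bool) are assumptions; the real
# helpers above also handle embedding projections, scopes, and TPU paths.
# tf.random.categorical requires TF 1.13+.
def toy_token_sampler(logits, input_ids, is_masked):
    vocab = tf.shape(logits)[-1]
    flat_logits = tf.reshape(logits, [-1, vocab])
    # One categorical draw per position.
    samples = tf.random.categorical(flat_logits, num_samples=1, dtype=tf.int32)
    samples = tf.reshape(samples, tf.shape(input_ids))
    return tf.where(is_masked, samples, input_ids)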
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE)

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    # Score every real token: input_mask is passed both as the attention mask
    # and as the sampled-positions mask.
    (masked_lm_loss,
     masked_lm_example_loss,
     masked_lm_log_probs,
     masked_lm_mask) = seq_masked_lm_fn(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        features['input_mask'],
        features['input_ori_ids'],
        features['input_ids'],
        features['input_mask'],
        reuse=tf.AUTO_REUSE,
        embedding_projection=model.get_embedding_projection_table())
    masked_lm_ids = features['input_ori_ids']

    loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/predictions",
        not_storage_params=not_storage_params)
    nsp_pretrain_vars = model_io_fn.get_params(
        "cls/seq_relationship",
        not_storage_params=not_storage_params)

    if model_config.get('embedding_scope', None) is not None:
        embedding_tvars = model_io_fn.get_params(
            model_config.get('embedding_scope', 'bert') + "/embeddings",
            not_storage_params=not_storage_params)
        pretrained_tvars.extend(embedding_tvars)

    pretrained_tvars.extend(lm_pretrain_tvars)
    pretrained_tvars.extend(nsp_pretrain_vars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu,
            restore_var_name=model_config.get('restore_var_name', []))
    else:
        scaffold_fn = None

    tf.add_to_collection("discriminator_loss", loss)

    return_dict = {
        "loss": loss,
        "tvars": tvars,
        "model": model,
        "per_example_loss": masked_lm_example_loss,
        "masked_lm_weights": masked_lm_mask,
        "masked_lm_log_probs": masked_lm_log_probs,
        "next_sentence_example_loss": nsp_per_example_loss,
        "next_sentence_log_probs": nsp_log_prob,
        "next_sentence_labels": features['next_sentence_labels']
    }
    return return_dict
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (_,
     _,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE)

    with tf.variable_scope('cls/seq_predictions', reuse=tf.AUTO_REUSE):
        (_,
         logits,
         _) = classifier(
            model_config,
            model.get_sequence_output(),
            features['input_ori_ids'],
            features['input_ids'],
            features['input_mask'],
            2,
            dropout_prob)
        # loss='focal_loss')

    # loss += 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)
    lm_seq_prediction_tvars = model_io_fn.get_params(
        "cls/seq_predictions",
        not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/seq_relationship",
        not_storage_params=not_storage_params)
    pretrained_tvars.extend(lm_seq_prediction_tvars)
    pretrained_tvars.extend(lm_pretrain_tvars)
    tvars = pretrained_tvars

    print('==discriminator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.PREDICT:
        mask = tf.cast(tf.expand_dims(features['input_mask'], axis=-1),
                       tf.float32)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                "probs": tf.nn.softmax(logits) * mask
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    "probs": tf.nn.softmax(logits) * mask
                })
            })
        return estimator_spec
def next_sentence_prediction(model_config, model, features, reuse=None):
    next_sentence_labels = features["next_sentence_label"]

    (next_sentence_loss,
     next_sentence_example_loss,
     next_sentence_log_probs) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        next_sentence_labels,
        reuse=reuse)

    return (next_sentence_loss,
            next_sentence_example_loss,
            next_sentence_log_probs)
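
# Hypothetical call site for the helper above; `model` is any object exposing
# get_pooled_output(), as produced by the model_fn functions in this file.
nsp_loss, nsp_example_loss, nsp_log_prob = next_sentence_prediction(
    model_config, model, features, reuse=tf.AUTO_REUSE)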
def model_fn(features, labels, mode):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=model_reuse)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    (masked_lm_loss,
     masked_lm_example_loss,
     masked_lm_log_probs) = pretrain.get_masked_lm_output(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        masked_lm_positions,
        masked_lm_ids,
        masked_lm_weights,
        reuse=model_reuse)

    loss = (model_config.lm_ratio * masked_lm_loss +
            model_config.nsp_ratio * nsp_loss)

    model_io_fn = model_io.ModelIO(model_io_config)

    if mode == tf.estimator.ModeKeys.TRAIN:
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope,
            not_storage_params=not_storage_params)
        lm_pretrain_tvars = model_io_fn.get_params(
            "cls",
            not_storage_params=not_storage_params)
        pretrained_tvars.extend(lm_pretrain_tvars)

        optimizer_fn = optimizer.Optimizer(opt_config)

        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars,
                opt_config.init_lr,
                opt_config.num_train_steps)

            train_op, hooks = model_io_fn.get_ema_hooks(
                train_op, tvars,
                kargs.get('params_moving_average_decay', 0.99),
                scope, mode,
                first_stage_steps=opt_config.num_warmup_steps,
                two_stage=True)

        model_io_fn.set_saver()

        train_metric_dict = train_metric_fn(
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids, masked_lm_weights,
            nsp_per_example_loss, nsp_log_prob,
            features['next_sentence_labels'])

        for key in train_metric_dict:
            tf.summary.scalar(key, train_metric_dict[key])
        tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)

        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=loss,
                                                    train_op=train_op)

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "nsp_log_pro": nsp_log_prob,
                    "train_op": train_op,
                    "masked_lm_loss": masked_lm_loss,
                    "next_sentence_loss": nsp_loss,
                    "masked_lm_log_pro": masked_lm_log_probs
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:

        def prediction_fn():
            predictions = {
                "nsp_classes": tf.argmax(input=nsp_log_prob, axis=1),
                "nsp_probabilities": tf.exp(nsp_log_prob,
                                            name="nsp_softmax"),
                "masked_vocab_classes": tf.argmax(input=masked_lm_log_probs,
                                                  axis=1),
                "masked_probabilities": tf.exp(masked_lm_log_probs,
                                               name='masked_softmax')
            }
            return predictions

        predictions = prediction_fn()

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(predictions)
            })
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                      masked_lm_ids, masked_lm_weights,
                      next_sentence_example_loss, next_sentence_log_probs,
                      next_sentence_labels):
            """Computes the loss and accuracy of the model."""
            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                              axis=-1,
                                              output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss,
                weights=masked_lm_weights)

            next_sentence_log_probs = tf.reshape(
                next_sentence_log_probs,
                [-1, next_sentence_log_probs.shape[-1]])
            next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
            next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
            next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels,
                predictions=next_sentence_predictions)
            next_sentence_mean_loss = tf.metrics.mean(
                values=next_sentence_example_loss)

            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "next_sentence_accuracy": next_sentence_accuracy,
                "next_sentence_loss": next_sentence_mean_loss
            }

        if output_type == "sess":
            return {
                "eval": {
                    "nsp_log_prob": nsp_log_prob,
                    "masked_lm_log_prob": masked_lm_log_probs,
                    "nsp_loss": nsp_loss,
                    "masked_lm_loss": masked_lm_loss,
                    "feature": model.get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(masked_lm_example_loss,
                                        masked_lm_log_probs,
                                        masked_lm_ids,
                                        masked_lm_weights,
                                        nsp_per_example_loss,
                                        nsp_log_prob,
                                        features['next_sentence_labels'])

            _, hooks = model_io_fn.get_ema_hooks(
                None, None,
                kargs.get('params_moving_average_decay', 0.99),
                scope, mode)
            eval_hooks = [hooks] if hooks else []

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metric_ops=eval_metric_ops,
                evaluation_hooks=eval_hooks)
            return estimator_spec
    else:
        raise NotImplementedError()
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE)

    with tf.variable_scope('cls/seq_predictions', reuse=tf.AUTO_REUSE):
        (loss,
         logits,
         per_example_loss) = classifier(
            model_config,
            model.get_sequence_output(),
            features['input_ori_ids'],
            features['input_ids'],
            features['input_mask'],
            2,
            dropout_prob,
            ori_sampled_ids=features.get('ori_sampled_ids', None),
            use_tpu=kargs.get('use_tpu', True))

    tf.add_to_collection("discriminator_loss", loss)
    loss += 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)
    lm_seq_prediction_tvars = model_io_fn.get_params(
        "cls/seq_predictions",
        not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/seq_relationship",
        not_storage_params=not_storage_params)
    pretrained_tvars.extend(lm_seq_prediction_tvars)
    pretrained_tvars.extend(lm_pretrain_tvars)
    tvars = pretrained_tvars

    print('==discriminator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu,
            restore_var_name=model_config.get('restore_var_name', []))
    else:
        scaffold_fn = None

    return_dict = {
        "loss": loss,
        "logits": logits,
        "tvars": tvars,
        "model": model,
        "per_example_loss": per_example_loss
    }
    return return_dict
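
# A minimal sketch of how a generator return_dict and this discriminator
# return_dict could be combined into an ELECTRA-style joint objective.
# generator_model_fn, discriminator_model_fn, and the weight 50.0 (the
# discriminator weight used in the ELECTRA paper) are illustrative
# assumptions, not names defined in this file.
def joint_model_fn(features, labels, mode, params):
    gen_dict = generator_model_fn(features, labels, mode, params)

    # Feed the generator's sampled tokens to the discriminator as its inputs.
    disc_features = dict(features)
    disc_features['input_ids'] = gen_dict['sampled_ids']
    disc_dict = discriminator_model_fn(disc_features, labels, mode, params)

    total_loss = gen_dict['loss'] + 50.0 * disc_dict['loss']
    return total_loss, gen_dict['tvars'] + disc_dict['tvars']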
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE,
        scope=generator_scope_prefix)

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    (_,
     _,
     masked_lm_log_probs,
     _) = seq_masked_lm_fn(
        model_config,
        model.get_sequence_output(),
        model.get_embedding_table(),
        features['input_mask'],
        features['input_ori_ids'],
        features['input_ids'],
        features['input_mask'],
        reuse=tf.AUTO_REUSE,
        embedding_projection=model.get_embedding_projection_table(),
        scope=generator_scope_prefix)

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    # loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)

    if generator_scope_prefix:
        # e.g. "generator/cls/predictions"
        lm_pretrain_tvars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            generator_scope_prefix + "/cls/seq_relationship",
            not_storage_params=not_storage_params)
    else:
        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions",
            not_storage_params=not_storage_params)
        nsp_pretrain_vars = model_io_fn.get_params(
            "cls/seq_relationship",
            not_storage_params=not_storage_params)

    pretrained_tvars.extend(lm_pretrain_tvars)
    pretrained_tvars.extend(nsp_pretrain_vars)
    tvars = pretrained_tvars

    print('==generator parameters==', tvars)

    if load_pretrained == "yes":
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        scaffold_fn = model_io_fn.load_pretrained(
            tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=use_tpu)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.PREDICT:
        mask = tf.expand_dims(tf.cast(features['input_mask'], tf.float32),
                              axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                "probs": mask * tf.exp(tf.nn.log_softmax(masked_lm_log_probs))
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    "probs": mask * tf.exp(tf.nn.log_softmax(masked_lm_log_probs))
                })
            })
        return estimator_spec
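
# Sketch only: exporting the PREDICT branch above as a SavedModel (TF >= 1.13
# estimator API). The receiver below is an assumption; it must provide every
# features[...] key this model_fn reads, and `estimator` is assumed to be an
# Estimator already built from this model_fn.
def serving_input_receiver_fn():
    receivers = {
        "input_ids": tf.placeholder(tf.int32, [None, None], name="input_ids"),
        "input_ori_ids": tf.placeholder(tf.int32, [None, None], name="input_ori_ids"),
        "input_mask": tf.placeholder(tf.int32, [None, None], name="input_mask"),
        "segment_ids": tf.placeholder(tf.int32, [None, None], name="segment_ids"),
        "next_sentence_labels": tf.placeholder(tf.int32, [None], name="next_sentence_labels"),
    }
    return tf.estimator.export.ServingInputReceiver(receivers, receivers)

estimator.export_saved_model("export_dir", serving_input_receiver_fn)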
def model_fn(features, labels, mode, params):

    model_api = model_zoo(model_config)

    input_ori_ids = features.get('input_ori_ids', None)

    if mode == tf.estimator.ModeKeys.TRAIN:
        if input_ori_ids is not None:
            # [output_ids,
            #  sampled_binary_mask] = random_input_ids_generation(
            #     model_config,
            #     input_ori_ids,
            #     features['input_mask'],
            #     **kargs)

            [output_ids,
             sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'],
                [tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                 for hmm_tran_prob in hmm_tran_prob_list],
                mask_probability=0.2,
                replace_probability=0.1,
                original_probability=0.1,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)

            features['input_ids'] = output_ids
            tf.logging.info("***** Running random sample input generation *****")
        else:
            sampled_binary_mask = None
    else:
        sampled_binary_mask = None

    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE,
                      **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss,
     nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config,
        model.get_pooled_output(),
        features['next_sentence_labels'],
        reuse=tf.AUTO_REUSE)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    if model_config.model_type == 'bert':
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
        print("==apply bert masked lm==")
    elif model_config.model_type == 'albert':
        masked_lm_fn = pretrain_albert.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply albert masked lm==")
    else:
        masked_lm_fn = pretrain.get_masked_lm_output
        seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
        print("==apply bert masked lm==")

    if sampled_binary_mask is not None:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_mask'],
            features['input_ori_ids'],
            features['input_ids'],
            sampled_binary_mask,
            reuse=tf.AUTO_REUSE,
            embedding_projection=model.get_embedding_projection_table())
        masked_lm_ids = input_ori_ids
    else:
        (masked_lm_loss,
         masked_lm_example_loss,
         masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            masked_lm_positions,
            masked_lm_ids,
            masked_lm_weights,
            reuse=tf.AUTO_REUSE,
            embedding_projection=model.get_embedding_projection_table())

    print(model_config.lm_ratio, '==mlm lm_ratio==')
    loss = model_config.lm_ratio * masked_lm_loss  # + 0.0 * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    pretrained_tvars = model_io_fn.get_params(
        model_config.scope,
        not_storage_params=not_storage_params)
    lm_pretrain_tvars = model_io_fn.get_params(
        "cls/predictions",
        not_storage_params=not_storage_params)
    pretrained_tvars.extend(lm_pretrain_tvars)

    if load_pretrained == "yes":
        scaffold_fn = model_io_fn.load_pretrained(
            pretrained_tvars,
            init_checkpoint,
            exclude_scope=exclude_scope,
            use_tpu=1)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars,
                opt_config.init_lr,
                opt_config.num_train_steps,
                use_tpu=opt_config.use_tpu)

        train_metric_dict = train_metric_fn(
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids,
            masked_lm_mask,
            nsp_per_example_loss,
            nsp_log_prob,
            features['next_sentence_labels'],
            masked_lm_mask=masked_lm_mask)

        # for key in train_metric_dict:
        #     tf.summary.scalar(key, train_metric_dict[key])
        # tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            train_op=train_op,
            scaffold_fn=scaffold_fn)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                      masked_lm_ids, masked_lm_weights,
                      next_sentence_example_loss, next_sentence_log_probs,
                      next_sentence_labels):
            """Computes the loss and accuracy of the model."""
            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(
                masked_lm_log_probs, axis=-1, output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss,
                weights=masked_lm_weights)

            next_sentence_log_probs = tf.reshape(
                next_sentence_log_probs,
                [-1, next_sentence_log_probs.shape[-1]])
            next_sentence_predictions = tf.argmax(
                next_sentence_log_probs, axis=-1, output_type=tf.int32)
            next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
            next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels,
                predictions=next_sentence_predictions)
            next_sentence_mean_loss = tf.metrics.mean(
                values=next_sentence_example_loss)

            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "next_sentence_accuracy": next_sentence_accuracy,
                "next_sentence_loss": next_sentence_mean_loss
            }

        eval_metrics = (metric_fn, [
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids, masked_lm_mask,
            nsp_per_example_loss, nsp_log_prob,
            features['next_sentence_labels']
        ])

        estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metrics=eval_metrics,
            scaffold_fn=scaffold_fn)
        return estimator_spec
    else:
        raise NotImplementedError()