def __init__(self, model_config, num_labels, init_checkpoint,
             load_pretrained=True, model_io_config={}, opt_config={},
             exclude_scope="", not_storage_params=[], target="a",
             label_lst=None, output_type="sess", **kargs):
    self.model_config = model_config
    self.num_labels = num_labels
    self.init_checkpoint = init_checkpoint
    self.load_pretrained = load_pretrained
    self.model_io_config = model_io_config
    self.opt_config = opt_config
    self.exclude_scope = exclude_scope
    self.not_storage_params = not_storage_params
    self.target = target
    self.label_lst = label_lst
    self.output_type = output_type
    self.kargs = kargs
    self.model_io_fn = model_io.ModelIO(self.model_io_config)
    self.optimizer_fn = optimizer.Optimizer(self.opt_config)
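# Usage sketch for the constructor above. The enclosing class name
# (`ClassifierFnBuilder`) and the Bunch-style configs are illustrative
# assumptions, not names from this file; only the keyword arguments come
# from the signature itself.
from bunch import Bunch  # assumed: the attribute-dict wrapper this repo uses

classifier_fn_builder = ClassifierFnBuilder(  # hypothetical class name
    model_config=Bunch({"scope": "bert", "dropout_prob": 0.1}),
    num_labels=2,
    init_checkpoint="/path/to/bert_model.ckpt",
    load_pretrained="yes",  # downstream code compares against the string "yes"
    model_io_config=Bunch({"fix_lm": False}),
    opt_config=Bunch({"init_lr": 2e-5, "num_train_steps": 10000}),
    output_type="estimator")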
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    label_ids = features["label_ids"]

    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    print(kargs.get("temperature", 0.5),
          kargs.get("distillation_ratio", 0.5),
          "==distillation hyperparameter==")
    # anneal_fn = anneal_strategy.AnnealStrategy(kargs.get("anneal_config", {}))

    # teacher logits: temperature-scaled log of the teacher's soft labels
    teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get("temperature", 2.0)
    # student logits: temperature-scaled log-softmax
    student_logit = tf.nn.log_softmax(logits / kargs.get("temperature", 2.0))

    distillation_loss = kd_distance(teacher_logit, student_logit,
                                    kargs.get("distillation_distance", "kd"))
    distillation_loss *= features["distillation_ratio"]
    distillation_loss = tf.reduce_sum(distillation_loss) / (
        1e-10 + tf.reduce_sum(features["distillation_ratio"]))

    label_loss = tf.reduce_sum(per_example_loss * features["label_ratio"]) / (
        1e-10 + tf.reduce_sum(features["label_ratio"]))

    print("==distillation loss ratio==",
          kargs.get("distillation_ratio", 0.9)
          * tf.pow(kargs.get("temperature", 2.0), 2))

    # standard KD objective: (1 - alpha) * CE + alpha * T^2 * distillation
    # loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
    loss = (1 - kargs.get("distillation_ratio", 0.9)) * label_loss \
        + tf.pow(kargs.get("temperature", 2.0), 2) \
        * kargs.get("distillation_ratio", 0.9) * distillation_loss

    model_io_fn = model_io.ModelIO(model_io_config)
    params_size = model_io_fn.count_params(model_config.scope)
    print("==total params==", params_size)

    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)

        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op,
                    "cross_entropy": label_loss,
                    "kd_loss": distillation_loss,
                    "kd_num": tf.reduce_sum(features["distillation_ratio"]),
                    "ce_num": tf.reduce_sum(features["label_ratio"]),
                    "teacher_logit": teacher_logit,
                    "student_logit": student_logit,
                    "label_ratio": features["label_ratio"]
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'pred_label': pred_label,
                "max_prob": max_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'pred_label': pred_label,
                    "max_prob": max_prob
                })
            })
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst, average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
    else:
        raise NotImplementedError()
def model_fn(features, labels, mode, params):
    ebm_noise_fce = EBM_NOISE_FCE(model_config_dict,
                                  num_labels_dict,
                                  init_checkpoint_dict,
                                  load_pretrained_dict,
                                  model_io_config=model_io_config,
                                  opt_config=opt_config,
                                  exclude_scope_dict=exclude_scope_dict,
                                  not_storage_params_dict=not_storage_params_dict,
                                  target_dict=target_dict,
                                  **kargs)

    model_io_fn = model_io.ModelIO(model_io_config)
    # hoisted so that both the TRAIN and EVAL branches see it
    use_tpu = 1 if kargs.get('use_tpu', False) else 0

    if mode == tf.estimator.ModeKeys.TRAIN:
        if kargs.get('use_tpu', False):
            optimizer_fn = optimizer.Optimizer(opt_config)
        else:
            optimizer_fn = distributed_optimizer.Optimizer(opt_config)

        train_op, loss, var_checkpoint_dict_list = get_train_op(
            optimizer_fn, opt_config,
            model_config_dict['ebm_dist'],
            model_config_dict['noise_dist'],
            features, labels, mode, params,
            ebm_noise_fce, use_tpu=use_tpu)

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, train_op=train_op)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        ebm_noise_fce.get_loss(features, labels, mode, params, **kargs)

        tpu_eval_metrics = (ebm_noise_eval_metric, [
            ebm_noise_fce.true_ebm_dist_dict['logits'],
            ebm_noise_fce.noise_dist_dict['true_logits'],
            ebm_noise_fce.fake_ebm_dist_dict['logits'],
            ebm_noise_fce.noise_dist_dict['fake_logits'],
            features['input_ori_ids'],
            tf.cast(features['input_mask'], tf.float32),
            ebm_noise_fce.noise_dist_dict["true_seq_logits"]
        ])
        gpu_eval_metrics = ebm_noise_eval_metric(
            ebm_noise_fce.true_ebm_dist_dict['logits'],
            ebm_noise_fce.noise_dist_dict['true_logits'],
            ebm_noise_fce.fake_ebm_dist_dict['logits'],
            ebm_noise_fce.noise_dist_dict['fake_logits'],
            features['input_ori_ids'],
            tf.cast(features['input_mask'], tf.float32),
            ebm_noise_fce.noise_dist_dict["true_seq_logits"])

        loss = ebm_noise_fce.ebm_loss + ebm_noise_fce.noise_loss
        var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, eval_metrics=tpu_eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)
        return estimator_spec
    else:
        raise NotImplementedError()
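# Why both tpu_eval_metrics and gpu_eval_metrics are built above:
# TPUEstimatorSpec takes `eval_metrics` as a (metric_fn, tensor_list) pair
# that is re-invoked on the host, while EstimatorSpec takes the already
# materialized `eval_metric_ops` dict. A minimal sketch of the convention
# (toy metric function; names are assumptions):
def _toy_metric_fn(logits, label_ids):
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return {"acc": tf.metrics.accuracy(labels=label_ids,
                                       predictions=predictions)}

# TPU path: tf.contrib.tpu.TPUEstimatorSpec(..., eval_metrics=(_toy_metric_fn, [logits, label_ids]))
# GPU path: tf.estimator.EstimatorSpec(..., eval_metric_ops=_toy_metric_fn(logits, label_ids))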
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    label_loss = tf.reduce_sum(per_example_loss * features["label_ratio"]) / (
        1e-10 + tf.reduce_sum(features["label_ratio"]))

    if mode == tf.estimator.ModeKeys.TRAIN:
        distillation_api = distill.KnowledgeDistillation(
            kargs.get("disitllation_config",
                      Bunch({
                          "logits_ratio_decay": "constant",
                          "logits_ratio": 0.5,
                          "logits_decay_rate": 0.999,
                          "distillation": ['relation_kd', 'logits'],
                          "feature_ratio": 0.5,
                          "feature_ratio_decay": "constant",
                          "feature_decay_rate": 0.999,
                          "kd_type": "kd",
                          "scope": scope
                      })))
        # teacher logits: temperature-scaled log of the teacher's soft labels
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get("temperature", 2.0)
        # student logits: temperature-scaled log-softmax
        student_logit = tf.nn.log_softmax(logits / kargs.get("temperature", 2.0))

        distillation_features = {
            "student_logits_tensor": student_logit,
            "teacher_logits_tensor": teacher_logit,
            "student_feature_tensor": model.get_pooled_output(),
            "teacher_feature_tensor": features["distillation_feature"],
            "student_label": tf.ones_like(label_ids, dtype=tf.int32),
            "teacher_label": tf.zeros_like(label_ids, dtype=tf.int32),
            "logits_ratio": kargs.get("logits_ratio", 0.5),
            "feature_ratio": kargs.get("feature_ratio", 0.5),
            "distillation_ratio": features["distillation_ratio"],
            "src_f_logit": logits,
            "tgt_f_logit": logits,
            "src_tensor": model.get_pooled_output(),
            "tgt_tensor": features["distillation_feature"]
        }

        distillation_loss = distillation_api.distillation(
            distillation_features,
            2, dropout_prob, model_reuse,
            opt_config.num_train_steps,
            feature_ratio=1.0,
            logits_ratio_decay="constant",
            feature_ratio_decay="constant",
            feature_decay_rate=0.999,
            logits_decay_rate=0.999,
            logits_ratio=0.5,
            scope=scope + "/adv_classifier",
            num_classes=num_labels,
            gamma=kargs.get("gamma", 4))

        loss = label_loss + distillation_loss["distillation_loss"]

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)

        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op,
                    "cross_entropy": label_loss,
                    "distillation_loss": distillation_loss["distillation_loss"],
                    "kd_num": tf.reduce_sum(features["distillation_ratio"]),
                    "ce_num": tf.reduce_sum(features["label_ratio"]),
                    "label_ratio": features["label_ratio"],
                    "distilaltion_logits_loss": distillation_loss["distillation_logits_loss"],
                    "distilaltion_feature_loss": distillation_loss["distillation_feature_loss"],
                    "rkd_loss": distillation_loss["rkd_loss"]
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'pred_label': pred_label,
                "max_prob": max_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'pred_label': pred_label,
                    "max_prob": max_prob
                })
            })
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst, average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
    else:
        raise NotImplementedError()
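# The 'relation_kd' entry in the distillation config presumably enables a
# relational-KD term (the "rkd_loss" surfaced above). A hedged sketch of the
# distance-wise RKD loss in the style of Park et al. (2019): pairwise
# distances in the student feature space are regressed onto the teacher's
# with a Huber loss. This is an illustration, not the repo's distill module.
def rkd_distance_loss(student_feat, teacher_feat):
    def pairwise_dist(x):
        # ||x_i - x_j|| for all pairs, shape [batch, batch]
        r = tf.reduce_sum(tf.square(x), axis=-1, keepdims=True)
        d = tf.sqrt(tf.maximum(
            r - 2.0 * tf.matmul(x, x, transpose_b=True) + tf.transpose(r),
            1e-12))
        # normalize by the mean pairwise distance for scale invariance
        return d / (tf.reduce_mean(d) + 1e-12)

    t_d = tf.stop_gradient(pairwise_dist(teacher_feat))
    s_d = pairwise_dist(student_feat)
    return tf.losses.huber_loss(t_d, s_d)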
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse, **kargs)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except:
        print("==not count params==")
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)

        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)
        print(tf.global_variables(), "==global_variables==")

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:
        # NOTE: an earlier single_label PREDICT branch returning argmax
        # pred_label / max_prob was disabled here.
        if model_config.get('label_type', 'single_label') == 'multi_label':
            prob = tf.nn.sigmoid(logits)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': prob,
                    "max_prob": prob
                },
                export_outputs={
                    "output": tf.estimator.export.PredictOutput({
                        'pred_label': prob,
                        "max_prob": prob
                    })
                })
        elif model_config.get('label_type', 'single_label') == "single_label":
            prob = tf.nn.softmax(logits)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': prob,
                    "max_prob": prob
                },
                export_outputs={
                    "output": tf.estimator.export.PredictOutput({
                        'pred_label': prob,
                        "max_prob": prob
                    })
                })
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst, average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss),
                    "feature": model.get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
            return estimator_spec
    else:
        raise NotImplementedError()
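# The PREDICT branch above switches activations by label_type: a multi-label
# head emits independent per-class probabilities (sigmoid), a single-label
# head a distribution over classes (softmax). A tiny illustration with
# assumed toy logits:
toy_logits = tf.constant([[2.0, -1.0, 0.5]])
multi_label_prob = tf.nn.sigmoid(toy_logits)   # entries in [0, 1]; rows need not sum to 1
single_label_prob = tf.nn.softmax(toy_logits)  # rows sum to 1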
def model_fn(features, labels, mode):
    train_ops = []
    train_hooks = []
    logits_dict = {}
    losses_dict = {}
    features_dict = {}
    tvars = []
    task_num_dict = {}
    multi_task_config = kargs.get('multi_task_config', {})
    total_loss = tf.constant(0.0)
    task_num = 0
    encoder = {}
    hook_dict = {}

    print(task_type_dict.keys(), "==task type dict==")
    num_task = len(task_type_dict)

    from data_generator import load_w2v
    flags = kargs.get('flags', Bunch({}))
    print(flags.pretrained_w2v_path, "===pretrain vocab path===")
    w2v_path = os.path.join(flags.buckets, flags.pretrained_w2v_path)
    vocab_path = os.path.join(flags.buckets, flags.vocab_file)
    # [w2v_embed, token2id, id2token, is_extral_symbol, use_pretrained] = \
    #     load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
    # pretrained_embed = tf.cast(tf.constant(w2v_embed), tf.float32)
    pretrained_embed = None

    for index, task_type in enumerate(task_type_dict.keys()):
        if model_config_dict[task_type].model_type in model_type_lst:
            reuse = True
        else:
            reuse = None
            model_type_lst.append(model_config_dict[task_type].model_type)

        if model_config_dict[task_type].model_type not in encoder:
            model_api = model_zoo(model_config_dict[task_type])
            model = model_api(model_config_dict[task_type], features, labels,
                              mode, target_dict[task_type], reuse=reuse,
                              cnn_type=model_config_dict[task_type].get(
                                  'cnn_type', 'bi_dgcnn'))
            encoder[model_config_dict[task_type].model_type] = model
        # NOTE: a VAE-KL auxiliary objective (vae_model_fn, contributing
        # perplexity/token_acc/kl_div hooks) was previously wired in here
        # and is currently disabled.
        print(encoder, "==encode==")

        if task_type_dict[task_type] == "cls_task":
            task_model_fn = cls_model_fn(
                encoder[model_config_dict[task_type].model_type],
                model_config_dict[task_type],
                num_labels_dict[task_type],
                init_checkpoint_dict[task_type],
                reuse,
                load_pretrained_dict[task_type],
                model_io_config,
                opt_config,
                exclude_scope=exclude_scope_dict[task_type],
                not_storage_params=not_storage_params_dict[task_type],
                target=target_dict[task_type],
                label_lst=None,
                output_type=output_type,
                task_layer_reuse=task_layer_reuse,
                task_type=task_type,
                num_task=num_task,
                task_adversarial=1e-2,
                get_pooled_output='task_output',
                feature_distillation=False,
                embedding_distillation=False,
                pretrained_embed=pretrained_embed,
                **kargs)
            result_dict = task_model_fn(features, labels, mode)
            tf.logging.info("****** task: %s %s *******",
                            task_type_dict[task_type], task_type)
        elif task_type_dict[task_type] == "embed_task":
            task_model_fn = embed_model_fn(
                encoder[model_config_dict[task_type].model_type],
                model_config_dict[task_type],
                num_labels_dict[task_type],
                init_checkpoint_dict[task_type],
                reuse,
                load_pretrained_dict[task_type],
                model_io_config,
                opt_config,
                exclude_scope=exclude_scope_dict[task_type],
                not_storage_params=not_storage_params_dict[task_type],
                target=target_dict[task_type],
                label_lst=None,
                output_type=output_type,
                task_layer_reuse=task_layer_reuse,
                task_type=task_type,
                num_task=num_task,
                task_adversarial=1e-2,
                get_pooled_output='task_output',
                feature_distillation=False,
                embedding_distillation=False,
                pretrained_embed=pretrained_embed,
                loss='contrastive_loss',
                apply_head_proj=False,
                **kargs)
            result_dict = task_model_fn(features, labels, mode)
            tf.logging.info("****** task: %s %s *******",
                            task_type_dict[task_type], task_type)
            # NOTE: an additional all-negative CPC objective
            # (embed_cpc_model_fn) was previously accumulated into
            # result_dict here and is currently disabled.
        elif task_type_dict[task_type] == "cpc_task":
            task_model_fn = embed_cpc_v1_model_fn(
                encoder[model_config_dict[task_type].model_type],
                model_config_dict[task_type],
                num_labels_dict[task_type],
                init_checkpoint_dict[task_type],
                reuse,
                load_pretrained_dict[task_type],
                model_io_config,
                opt_config,
                exclude_scope=exclude_scope_dict[task_type],
                not_storage_params=not_storage_params_dict[task_type],
                target=target_dict[task_type],
                label_lst=None,
                output_type=output_type,
                task_layer_reuse=task_layer_reuse,
                task_type=task_type,
                num_task=num_task,
                task_adversarial=1e-2,
                get_pooled_output='task_output',
                feature_distillation=False,
                embedding_distillation=False,
                pretrained_embed=pretrained_embed,
                loss='contrastive_loss',
                apply_head_proj=False,
                task_seperate_proj=True,
                **kargs)
            result_dict = task_model_fn(features, labels, mode)
            tf.logging.info("****** task: %s %s *******",
                            task_type_dict[task_type], task_type)
        elif task_type_dict[task_type] == "regression_task":
            task_model_fn = regression_model_fn(
                encoder[model_config_dict[task_type].model_type],
                model_config_dict[task_type],
                num_labels_dict[task_type],
                init_checkpoint_dict[task_type],
                reuse,
                load_pretrained_dict[task_type],
                model_io_config,
                opt_config,
                exclude_scope=exclude_scope_dict[task_type],
                not_storage_params=not_storage_params_dict[task_type],
                target=target_dict[task_type],
                label_lst=None,
                output_type=output_type,
                task_layer_reuse=task_layer_reuse,
                task_type=task_type,
                num_task=num_task,
                task_adversarial=1e-2,
                get_pooled_output='task_output',
                feature_distillation=False,
                embedding_distillation=False,
                pretrained_embed=pretrained_embed,
                loss='contrastive_loss',
                apply_head_proj=False,
                **kargs)
            result_dict = task_model_fn(features, labels, mode)
            tf.logging.info("****** task: %s %s *******",
                            task_type_dict[task_type], task_type)
        else:
            continue

        print("==SUCCEEDED IN LOADING==", task_type)
        logits_dict[task_type] = result_dict["logits"]
        losses_dict[task_type] = result_dict["loss"]  # task loss

        for key in ["pos_num", "neg_num", "masked_lm_loss", "task_loss",
                    "acc", "task_acc", "masked_lm_acc"]:
            name = "{}_{}".format(task_type, key)
            if name in result_dict:
                hook_dict[name] = result_dict[name]

        hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
        hook_dict["{}_num".format(task_type)] = result_dict["task_num"]
        print("==loss ratio==", task_type,
              multi_task_config[task_type].get('loss_ratio', 1.0))
        total_loss += result_dict["loss"] * multi_task_config[task_type].get(
            'loss_ratio', 1.0)
        hook_dict['embed_loss'] = result_dict["embed_loss"]
        hook_dict['feature_loss'] = result_dict["feature_loss"]
        hook_dict["{}_task_loss".format(task_type)] = result_dict["task_loss"]
        if 'positive_label' in result_dict:
            hook_dict["{}_task_positive_label".format(task_type)] = \
                result_dict["positive_label"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            tvars.extend(result_dict["tvars"])
            task_num += result_dict["task_num"]
            task_num_dict[task_type] = result_dict["task_num"]
        elif mode == tf.estimator.ModeKeys.EVAL:
            features[task_type] = result_dict["feature"]

    hook_dict["total_loss"] = total_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn = model_io.ModelIO(model_io_config)
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(list(set(tvars)),
                                 string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                total_loss, list(set(tvars)), opt_config.init_lr,
                opt_config.num_train_steps, **kargs)

        model_io_fn.set_saver(optimizer_fn.opt)

        if kargs.get("task_index", 1) == 1 and kargs.get("run_config", None):
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        elif kargs.get("task_index", 1) == 1:
            training_hooks = []
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        if output_type == "sess":
            return {
                "train": {
                    "total_loss": total_loss,
                    "loss": losses_dict,
                    "logits": logits_dict,
                    "train_op": train_op,
                    "task_num_dict": task_num_dict
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            hook_dict['learning_rate'] = optimizer_fn.learning_rate
            logging_hook = tf.train.LoggingTensorHook(hook_dict,
                                                      every_n_iter=100)
            training_hooks.append(logging_hook)
            print("==hook_dict==")
            print(hook_dict)

            for key in hook_dict:
                tf.summary.scalar(key, hook_dict[key])
                for index, task_type in enumerate(task_type_dict.keys()):
                    tmp = "{}_loss".format(task_type)
                    if tmp == key:
                        tf.summary.scalar(
                            "loss_gap_{}".format(task_type),
                            hook_dict["total_loss"] - hook_dict[key])
            for key in task_num_dict:
                tf.summary.scalar(key + "_task_num", task_num_dict[key])

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss, train_op=train_op,
                training_hooks=training_hooks)
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        # evaluation runs separately for each task

        def metric_fn(logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst, average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "logits": logits_dict,
                    "total_loss": total_loss,
                    "feature": features,
                    "loss": losses_dict
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = {}
            for key in logits_dict:
                eval_dict = metric_fn(logits_dict[key],
                                      features_task_dict[key]["label_ids"])
                for sub_key in eval_dict.keys():
                    eval_key = "{}_{}".format(key, sub_key)
                    eval_metric_ops[eval_key] = eval_dict[sub_key]
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=total_loss / task_num,
                eval_metric_ops=eval_metric_ops)
            return estimator_spec
    else:
        raise NotImplementedError()
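# The multi-task total loss above is a loss_ratio-weighted sum over tasks.
# A compact sketch of the same accumulation (task names and values assumed):
task_losses = {"cls_task": 0.7, "embed_task": 1.3}    # per-task scalar losses
loss_ratios = {"cls_task": 1.0, "embed_task": 0.5}    # multi_task_config loss_ratio
total = sum(task_losses[t] * loss_ratios.get(t, 1.0) for t in task_losses)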
def model_fn(features, labels, mode):
    shape_lst_a = bert_utils.get_shape_list(features['input_ids_a'])
    batch_size_a = shape_lst_a[0]
    total_length_a = shape_lst_a[1]

    shape_lst_b = bert_utils.get_shape_list(features['input_ids_b'])
    batch_size_b = shape_lst_b[0]
    total_length_b = shape_lst_b[1]

    features['input_ids_a'] = tf.reshape(features['input_ids_a'],
                                         [-1, model_config.max_length])
    features['segment_ids_a'] = tf.reshape(features['segment_ids_a'],
                                           [-1, model_config.max_length])
    features['input_mask_a'] = tf.cast(
        tf.not_equal(features['input_ids_a'], kargs.get('[PAD]', 0)),
        tf.int64)
    features['input_ids_b'] = tf.reshape(
        features['input_ids_b'], [-1, model_config.max_predictions_per_seq])
    features['segment_ids_b'] = tf.reshape(
        features['segment_ids_b'], [-1, model_config.max_predictions_per_seq])
    features['input_mask_b'] = tf.cast(
        tf.not_equal(features['input_ids_b'], kargs.get('[PAD]', 0)),
        tf.int64)
    features['batch_size'] = batch_size_a
    features['total_length_a'] = total_length_a
    features['total_length_b'] = total_length_b

    model_dict = {}
    for target in ["a", "b"]:
        model = bert_encoder(model_config, features, labels,
                             mode, target, reuse=tf.AUTO_REUSE)
        model_dict[target] = model

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits,
         transition_params) = multi_position_crf_classifier(
            model_config, features, model_dict, num_labels, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except:
        print("==not count params==")
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)

        train_op, hooks = model_io_fn.get_ema_hooks(
            train_op, tvars,
            kargs.get('params_moving_average_decay', 0.99),
            scope, mode,
            first_stage_steps=opt_config.num_warmup_steps,
            two_stage=True)

        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)
        print(tf.global_variables(), "==global_variables==")

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        label_weights = tf.cast(features['label_weights'], tf.int32)
        label_seq_length = tf.reduce_sum(label_weights, axis=-1)

        decode_tags, best_score = tf.contrib.crf.crf_decode(
            logits, transition_params, label_seq_length)

        _, hooks = model_io_fn.get_ema_hooks(
            None, None,
            kargs.get('params_moving_average_decay', 0.99),
            scope, mode)

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'decode_tags': decode_tags,
                "best_score": best_score,
                "transition_params": transition_params,
                "logits": logits
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'decode_tags': decode_tags,
                    "best_score": best_score,
                    "transition_params": transition_params,
                    "logits": logits
                })
            },
            prediction_hooks=[hooks])
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        _, hooks = model_io_fn.get_ema_hooks(
            None, None,
            kargs.get('params_moving_average_decay', 0.99),
            scope, mode)
        eval_hooks = []

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss),
                    "feature": model.get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = eval_logtis(logits, features, num_labels,
                                          transition_params)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops,
                evaluation_hooks=eval_hooks)
            return estimator_spec
    else:
        raise NotImplementedError()
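# Self-contained sketch of the CRF decode used in the PREDICT branch above
# (tf.contrib.crf.crf_decode, TF 1.x): unary potentials of shape
# [batch, seq_len, num_tags] plus a [num_tags, num_tags] transition matrix
# yield the Viterbi path. Shapes and values below are toy assumptions.
unary = tf.random.normal([2, 5, 4])      # [batch, seq_len, num_tags]
transitions = tf.random.normal([4, 4])   # [num_tags, num_tags]
seq_len = tf.constant([5, 3])            # true lengths per example
decode_tags, best_score = tf.contrib.crf.crf_decode(unary, transitions, seq_len)
# decode_tags: [batch, seq_len] int32 Viterbi tags; best_score: [batch]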
def model_fn(features, labels, mode, params):
    train_op_type = kargs.get('train_op_type', 'joint')

    ebm_noise_fce = EBM_NOISE_NCE(
        model_config_dict,
        num_labels_dict,
        init_checkpoint_dict,
        load_pretrained_dict,
        model_io_config=model_io_config,
        opt_config=opt_config,
        exclude_scope_dict=exclude_scope_dict,
        not_storage_params_dict=not_storage_params_dict,
        target_dict=target_dict,
        **kargs)

    model_io_fn = model_io.ModelIO(model_io_config)
    use_tpu = 1 if kargs.get('use_tpu', False) else 0

    if mode == tf.estimator.ModeKeys.TRAIN:
        if kargs.get('use_tpu', False):
            optimizer_fn = optimizer.Optimizer(opt_config)
        else:
            optimizer_fn = distributed_optimizer.Optimizer(opt_config)

        train_op = get_train_op(ebm_noise_fce, optimizer_fn, opt_config,
                                model_config_dict['ebm_dist'],
                                model_config_dict['noise_dist'],
                                model_config_dict['generator'],
                                features, labels, mode, params,
                                use_tpu=use_tpu,
                                train_op_type=train_op_type,
                                alternate_order=['ebm', 'generator'])

        ebm_noise_fce.load_pretrained_model(**kargs)
        var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
        loss = ebm_noise_fce.loss
        tvars = ebm_noise_fce.tvars

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        metric_dict = ebm_train_metric(
            ebm_noise_fce.true_ebm_dist_dict['logits'],
            ebm_noise_fce.fake_ebm_dist_dict['logits'])

        if not kargs.get('use_tpu', False):
            for key in metric_dict:
                tf.summary.scalar(key, metric_dict[key])
            tf.summary.scalar("ebm_loss",
                              ebm_noise_fce.ebm_opt_dict['ebm_loss'])
            tf.summary.scalar("mlm_loss",
                              ebm_noise_fce.ebm_opt_dict['mlm_loss'])
            tf.summary.scalar("all_loss",
                              ebm_noise_fce.ebm_opt_dict['all_loss'])

        model_io_fn.print_params(tvars, string=", trainable params")

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, train_op=train_op)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        ebm_noise_fce.get_loss(features, labels, mode, params, **kargs)
        ebm_noise_fce.load_pretrained_model(**kargs)
        var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
        loss = ebm_noise_fce.loss

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        tpu_eval_metrics = (ebm_eval_metric, [
            ebm_noise_fce.true_ebm_dist_dict['logits'],
            ebm_noise_fce.fake_ebm_dist_dict['logits']
        ])
        gpu_eval_metrics = ebm_eval_metric(
            ebm_noise_fce.true_ebm_dist_dict['logits'],
            ebm_noise_fce.fake_ebm_dist_dict['logits'])

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, eval_metrics=tpu_eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)
        return estimator_spec
    else:
        raise NotImplementedError()
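# When train_op_type is not 'joint', the alternate_order argument above
# suggests EBM and generator updates are alternated. A hedged sketch of one
# such schedule in a raw-session setting (the two train ops, each updating
# only its own variable set, are assumptions; the repo's get_train_op builds
# its own schedule inside the graph):
for step in range(num_train_steps):            # assumed training loop
    if step % 2 == 0:
        sess.run(ebm_train_op)                 # assumed: updates EBM vars only
    else:
        sess.run(generator_train_op)           # assumed: updates generator vars only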
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    (nsp_loss, nsp_per_example_loss,
     nsp_log_prob) = pretrain.get_next_sentence_output(
        model_config, model.get_pooled_output(),
        features['next_sentence_labels'], reuse=model_reuse)

    masked_lm_positions = features["masked_lm_positions"]
    masked_lm_ids = features["masked_lm_ids"]
    masked_lm_weights = features["masked_lm_weights"]

    (masked_lm_loss, masked_lm_example_loss,
     masked_lm_log_probs) = pretrain.get_masked_lm_output(
        model_config, model.get_sequence_output(),
        model.get_embedding_table(), masked_lm_positions,
        masked_lm_ids, masked_lm_weights, reuse=model_reuse)

    loss = model_config.lm_ratio * masked_lm_loss \
        + model_config.nsp_ratio * nsp_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    if mode == tf.estimator.ModeKeys.TRAIN:
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        lm_pretrain_tvars = model_io_fn.get_params(
            "cls", not_storage_params=not_storage_params)
        pretrained_tvars.extend(lm_pretrain_tvars)

        optimizer_fn = optimizer.Optimizer(opt_config)

        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint,
                                        exclude_scope=exclude_scope)

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)

        train_op, hooks = model_io_fn.get_ema_hooks(
            train_op, tvars,
            kargs.get('params_moving_average_decay', 0.99),
            scope, mode,
            first_stage_steps=opt_config.num_warmup_steps,
            two_stage=True)

        model_io_fn.set_saver()

        train_metric_dict = train_metric_fn(
            masked_lm_example_loss, masked_lm_log_probs,
            masked_lm_ids, masked_lm_weights,
            nsp_per_example_loss, nsp_log_prob,
            features['next_sentence_labels'])
        for key in train_metric_dict:
            tf.summary.scalar(key, train_metric_dict[key])
        tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                                    train_op=train_op)

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "nsp_log_pro": nsp_log_prob,
                    "train_op": train_op,
                    "masked_lm_loss": masked_lm_loss,
                    "next_sentence_loss": nsp_loss,
                    "masked_lm_log_pro": masked_lm_log_probs
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.PREDICT:

        def prediction_fn():
            predictions = {
                "nsp_classes": tf.argmax(input=nsp_log_prob, axis=1),
                "nsp_probabilities": tf.exp(nsp_log_prob, name="nsp_softmax"),
                "masked_vocab_classes": tf.argmax(input=masked_lm_log_probs,
                                                  axis=1),
                "masked_probabilities": tf.exp(masked_lm_log_probs,
                                               name='masked_softmax')
            }
            return predictions

        predictions = prediction_fn()
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions=predictions,
            export_outputs={
                "output": tf.estimator.export.PredictOutput(predictions)
            })
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:

        def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                      masked_lm_ids, masked_lm_weights,
                      next_sentence_example_loss, next_sentence_log_probs,
                      next_sentence_labels):
            """Computes the loss and accuracy of the model."""
            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(masked_lm_log_probs, axis=-1,
                                              output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss, weights=masked_lm_weights)

            next_sentence_log_probs = tf.reshape(
                next_sentence_log_probs,
                [-1, next_sentence_log_probs.shape[-1]])
            next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
            next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
            next_sentence_accuracy = tf.metrics.accuracy(
                labels=next_sentence_labels,
                predictions=next_sentence_predictions)
            next_sentence_mean_loss = tf.metrics.mean(
                values=next_sentence_example_loss)

            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "next_sentence_accuracy": next_sentence_accuracy,
                "next_sentence_loss": next_sentence_mean_loss
            }

        if output_type == "sess":
            return {
                "eval": {
                    "nsp_log_prob": nsp_log_prob,
                    "masked_lm_log_prob": masked_lm_log_probs,
                    "nsp_loss": nsp_loss,
                    "masked_lm_loss": masked_lm_loss,
                    "feature": model.get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(
                masked_lm_example_loss, masked_lm_log_probs,
                masked_lm_ids, masked_lm_weights,
                nsp_per_example_loss, nsp_log_prob,
                features['next_sentence_labels'])

            _, hooks = model_io_fn.get_ema_hooks(
                None, None,
                kargs.get('params_moving_average_decay', 0.99),
                scope, mode)
            eval_hooks = [hooks] if hooks else []

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops,
                evaluation_hooks=eval_hooks)
            return estimator_spec
    else:
        raise NotImplementedError()
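# The masked-LM head above scores only the positions in masked_lm_positions.
# The standard BERT helper for that gather, reproduced here as a reference
# sketch (the repo's pretrain module presumably has its own equivalent;
# bert_utils.get_shape_list is the shape helper used elsewhere in this file):
def gather_indexes(sequence_tensor, positions):
    """Gathers the vectors at `positions` over a minibatch."""
    shape = bert_utils.get_shape_list(sequence_tensor)
    batch_size, seq_length, width = shape[0], shape[1], shape[2]
    # offset each example's positions into the flattened [B*T, H] tensor
    flat_offsets = tf.reshape(
        tf.range(0, batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor,
                               [batch_size * seq_length, width])
    return tf.gather(flat_sequence, flat_positions)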
def model_fn(features, labels, mode, params):
    generator_fn = generator(
        model_config_dict['generator'],
        num_labels_dict['generator'],
        init_checkpoint_dict['generator'],
        model_reuse=None,
        load_pretrained=load_pretrained_dict['generator'],
        model_io_config=model_io_config,
        opt_config=opt_config,
        exclude_scope=exclude_scope_dict.get('generator', ""),
        not_storage_params=not_storage_params_dict.get('generator', []),
        target=target_dict['generator'],
        **kargs)
    generator_dict = generator_fn(features, labels, mode, params)

    discriminator_fn = discriminator(
        model_config_dict['discriminator'],
        num_labels_dict['discriminator'],
        init_checkpoint_dict['discriminator'],
        model_reuse=None,
        load_pretrained=load_pretrained_dict['discriminator'],
        model_io_config=model_io_config,
        opt_config=opt_config,
        exclude_scope=exclude_scope_dict.get('discriminator', ""),
        not_storage_params=not_storage_params_dict.get('discriminator', []),
        target=target_dict['discriminator'],
        **kargs)

    discriminator_features = {}
    discriminator_features['input_ids'] = generator_dict['sampled_ids']
    discriminator_features['input_mask'] = generator_dict['sampled_input_mask']
    discriminator_features['segment_ids'] = generator_dict['sampled_segment_ids']
    discriminator_features['input_ori_ids'] = generator_dict['sampled_input_ids']
    discriminator_features['next_sentence_labels'] = features['next_sentence_labels']

    discriminator_dict = discriminator_fn(discriminator_features, labels,
                                          mode, params)

    model_io_fn = model_io.ModelIO(model_io_config)

    tvars = []
    loss = discriminator_dict['loss']
    print(loss)
    tvars.extend(discriminator_dict['tvars'])

    if kargs.get('joint_train', '0') == '1':
        tvars.extend(generator_dict['tvars'])
        loss += generator_dict['loss']

    var_checkpoint_dict_list = []
    for key in init_checkpoint_dict:
        if load_pretrained_dict[key] == "yes":
            if key == 'generator':
                tmp = {
                    "tvars": generator_dict['tvars'],
                    "init_checkpoint": init_checkpoint_dict['generator'],
                    "exclude_scope": exclude_scope_dict[key]
                }
                var_checkpoint_dict_list.append(tmp)
            elif key == 'discriminator':
                tmp = {
                    "tvars": discriminator_dict['tvars'],
                    "init_checkpoint": init_checkpoint_dict['discriminator'],
                    "exclude_scope": exclude_scope_dict[key]
                }
                var_checkpoint_dict_list.append(tmp)

    use_tpu = 1 if kargs.get('use_tpu', False) else 0

    if len(var_checkpoint_dict_list) >= 1:
        scaffold_fn = model_io_fn.load_multi_pretrained(
            var_checkpoint_dict_list, use_tpu=use_tpu)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        metric_dict = discriminator_metric_train(
            discriminator_dict['per_example_loss'],
            discriminator_dict['logits'],
            generator_dict['sampled_input_ids'],
            generator_dict['sampled_ids'],
            generator_dict['sampled_input_mask'])
        for key in metric_dict:
            tf.summary.scalar(key, metric_dict[key])

        if kargs.get('use_tpu', False):
            optimizer_fn = optimizer.Optimizer(opt_config)
        else:
            optimizer_fn = distributed_optimizer.Optimizer(opt_config)

        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, list(set(tvars)), opt_config.init_lr,
                opt_config.num_train_steps, use_tpu=use_tpu)

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, train_op=train_op)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        if kargs.get('joint_train', '0') == '1':
            generator_metric = generator_metric_fn_eval(
                generator_dict['masked_lm_example_loss'],
                generator_dict['masked_lm_log_probs'],
                generator_dict['masked_lm_ids'],
                generator_dict['masked_lm_weights'],
                generator_dict.get('next_sentence_example_loss', None),
                generator_dict.get('next_sentence_log_probs', None),
                generator_dict.get('next_sentence_labels', None))
        else:
            generator_metric = {}

        discriminator_metric = discriminator_metric_eval(
            discriminator_dict['per_example_loss'],
            discriminator_dict['logits'],
            generator_dict['sampled_input_ids'],
            generator_dict['sampled_ids'],
            generator_dict['sampled_input_mask'])

        metric_dict = discriminator_metric
        if len(generator_metric):
            metric_dict.update(generator_metric)

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, eval_metrics=metric_dict,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=metric_dict)
        return estimator_spec
    else:
        raise NotImplementedError()
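# The discriminator above is trained ELECTRA-style on replaced-token
# detection: a position is "fake" wherever the generator's sampled id
# differs from the original id, restricted to real (non-padding) positions.
# A sketch of the label construction (tensor names follow generator_dict):
rtd_labels = tf.cast(
    tf.not_equal(generator_dict['sampled_ids'],
                 generator_dict['sampled_input_ids']), tf.int32)
rtd_weights = tf.cast(generator_dict['sampled_input_mask'], tf.float32)
# per-position sigmoid cross-entropy would then be averaged under rtd_weights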
def model_fn(features, labels, mode, params):
    train_op_type = kargs.get('train_op_type', 'joint')

    if kargs.get('optimization_type', 'grl') == 'grl':
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            **kargs)
        train_op_type = 'joint'
    elif kargs.get('optimization_type', 'grl') == 'minmax':
        generator_fn = generator_normal(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            **kargs)
    else:
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            **kargs)

    tf.logging.info("****** train_op_type:%s *******", train_op_type)
    tf.logging.info("****** optimization_type:%s *******",
                    kargs.get('optimization_type', 'grl'))

    generator_dict = generator_fn(features, labels, mode, params)

    discriminator_fn = discriminator(
        model_config_dict['discriminator'],
        num_labels_dict['discriminator'],
        init_checkpoint_dict['discriminator'],
        model_reuse=None,
        load_pretrained=load_pretrained_dict['discriminator'],
        model_io_config=model_io_config,
        opt_config=opt_config,
        exclude_scope=exclude_scope_dict.get('discriminator', ""),
        not_storage_params=not_storage_params_dict.get('discriminator', []),
        target=target_dict['discriminator'],
        **kargs)

    tf.logging.info("****** true sampled_ids of discriminator *******")
    true_discriminator_features = {}
    true_discriminator_features['input_ids'] = generator_dict['sampled_input_ids']
    true_discriminator_features['input_mask'] = generator_dict['sampled_input_mask']
    true_discriminator_features['segment_ids'] = generator_dict['sampled_segment_ids']
    true_discriminator_features['input_ori_ids'] = generator_dict['sampled_input_ids']
    true_discriminator_features['next_sentence_labels'] = features['next_sentence_labels']
    true_discriminator_features['ori_input_ids'] = generator_dict['sampled_input_ids']

    true_discriminator_dict = discriminator_fn(true_discriminator_features,
                                               labels, mode, params)

    fake_discriminator_features = {}
    if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
        tf.logging.info("****** gumbel 3-D sampled_ids *******")
    elif kargs.get('minmax_mode', 'corrupted') == 'masked':
        fake_discriminator_features['ori_sampled_ids'] = generator_dict['output_ids']
        tf.logging.info("****** conditional sampled_ids *******")
    fake_discriminator_features['input_ids'] = generator_dict['sampled_ids']
    fake_discriminator_features['input_mask'] = generator_dict['sampled_input_mask']
    fake_discriminator_features['segment_ids'] = generator_dict['sampled_segment_ids']
    fake_discriminator_features['input_ori_ids'] = generator_dict['sampled_input_ids']
    fake_discriminator_features['next_sentence_labels'] = features['next_sentence_labels']
    fake_discriminator_features['ori_input_ids'] = generator_dict['sampled_ids']

    fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                               labels, mode, params)

    nce_loss = nce_loss_fn(true_discriminator_dict,
                           true_discriminator_features,
                           fake_discriminator_dict,
                           fake_discriminator_features)

    model_io_fn = model_io.ModelIO(model_io_config)

    tvars = []
    loss = kargs.get('dis_loss', 1.0) * nce_loss
    tvars.extend(fake_discriminator_dict['tvars'])

    if kargs.get('joint_train', '1') == '1':
        tf.logging.info(
            "****** joint generator and discriminator training *******")
        tvars.extend(generator_dict['tvars'])
        loss += generator_dict['loss']
    tvars = list(set(tvars))

    var_checkpoint_dict_list = []
    for key in init_checkpoint_dict:
        if load_pretrained_dict[key] == "yes":
            if key == 'generator':
                tmp = {
                    "tvars": generator_dict['tvars'],
                    "init_checkpoint": init_checkpoint_dict['generator'],
                    "exclude_scope": exclude_scope_dict[key],
                    "restore_var_name": model_config_dict['generator'].get(
                        'restore_var_name', [])
                }
                if kargs.get("sharing_mode", "none") != "none":
                    tmp['exclude_scope'] = ''
                var_checkpoint_dict_list.append(tmp)
            elif key == 'discriminator':
                tmp = {
                    "tvars": fake_discriminator_dict['tvars'],
                    "init_checkpoint": init_checkpoint_dict['discriminator'],
                    "exclude_scope": exclude_scope_dict[key],
                    "restore_var_name": model_config_dict['discriminator'].get(
                        'restore_var_name', [])
                }
                var_checkpoint_dict_list.append(tmp)

    use_tpu = 1 if kargs.get('use_tpu', False) else 0

    if len(var_checkpoint_dict_list) >= 1:
        scaffold_fn = model_io_fn.load_multi_pretrained(
            var_checkpoint_dict_list, use_tpu=use_tpu)
    else:
        scaffold_fn = None

    if mode == tf.estimator.ModeKeys.TRAIN:
        if kargs.get('summary_debug', False):
            metric_dict = discriminator_metric_train(
                fake_discriminator_dict['per_example_loss'],
                fake_discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])
            for key in metric_dict:
                tf.summary.scalar(key, metric_dict[key])

        if kargs.get('use_tpu', False):
            optimizer_fn = optimizer.Optimizer(opt_config)
        else:
            optimizer_fn = distributed_optimizer.Optimizer(opt_config)

        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, list(set(tvars)), opt_config.init_lr,
                opt_config.num_train_steps, use_tpu=use_tpu)

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, train_op=train_op,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, train_op=train_op)
        return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        if kargs.get('joint_train', '0') == '1':

            def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                             masked_lm_ids, masked_lm_weights,
                             next_sentence_example_loss,
                             next_sentence_log_probs, next_sentence_labels,
                             per_example_loss, logits, input_ori_ids,
                             input_ids, input_mask):
                generator_metric = generator_metric_fn_eval(
                    masked_lm_example_loss, masked_lm_log_probs,
                    masked_lm_ids, masked_lm_weights,
                    next_sentence_example_loss, next_sentence_log_probs,
                    next_sentence_labels)
                discriminator_metric = discriminator_metric_eval(
                    per_example_loss, logits, input_ori_ids,
                    input_ids, input_mask)
                generator_metric.update(discriminator_metric)
                return generator_metric

            tpu_eval_metrics = (joint_metric, [
                generator_dict['masked_lm_example_loss'],
                generator_dict['masked_lm_log_probs'],
                generator_dict['masked_lm_ids'],
                generator_dict['masked_lm_weights'],
                generator_dict.get('next_sentence_example_loss', None),
                generator_dict.get('next_sentence_log_probs', None),
                generator_dict.get('next_sentence_labels', None),
                fake_discriminator_dict['per_example_loss'],
                fake_discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask']
            ])
            gpu_eval_metrics = joint_metric(
                generator_dict['masked_lm_example_loss'],
                generator_dict['masked_lm_log_probs'],
                generator_dict['masked_lm_ids'],
                generator_dict['masked_lm_weights'],
                generator_dict.get('next_sentence_example_loss', None),
                generator_dict.get('next_sentence_log_probs', None),
                generator_dict.get('next_sentence_labels', None),
                fake_discriminator_dict['per_example_loss'],
                fake_discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])
        else:
            gpu_eval_metrics = discriminator_metric_eval(
                fake_discriminator_dict['per_example_loss'],
                fake_discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])
            tpu_eval_metrics = (discriminator_metric_eval, [
                fake_discriminator_dict['per_example_loss'],
                fake_discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask']
            ])

        if kargs.get('use_tpu', False):
            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode, loss=loss, eval_metrics=tpu_eval_metrics,
                scaffold_fn=scaffold_fn)
        else:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)
        return estimator_spec
    else:
        raise NotImplementedError()
def model_fn(features, labels, mode): model_api = model_zoo(model_config) model = model_api(model_config, features, labels, mode, target, reuse=tf.AUTO_REUSE) # model_adv_config = copy.deepcopy(model_config) # model_adv_config.scope = model_config.scope + "/adv_encoder" # model_adv_adaptation = model_api(model_adv_config, features, labels, # mode, target, reuse=tf.AUTO_REUSE) label_ids = features["label_ids"] if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope common_feature = model.get_pooled_output() task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/task_residual", if_grl=False) adv_task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual", if_grl=True) with tf.variable_scope(scope, reuse=tf.AUTO_REUSE): concat_feature = task_feature # concat_feature = tf.concat([task_feature, # adv_task_feature], # axis=-1) (loss, per_example_loss, logits) = classifier.classifier(model_config, concat_feature, num_labels, label_ids, dropout_prob, **kargs) with tf.variable_scope(scope+"/adv_classifier", reuse=tf.AUTO_REUSE): adv_ids = features["adv_ids"] (adv_loss, adv_per_example_loss, adv_logits) = classifier.classifier(model_config, adv_task_feature, kargs.get('adv_num_labels', 12), adv_ids, dropout_prob, **kargs) if mode == tf.estimator.ModeKeys.TRAIN: loss_diff = tf.constant(0.0) # adv_task_feature_no_grl = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual") # loss_diff = diff_loss(task_feature, # adv_task_feature_no_grl) print(kargs.get("temperature", 0.5), kargs.get("distillation_ratio", 0.5), "==distillation hyperparameter==") # get teacher logits teacher_logit = tf.log(features["label_probs"]+1e-10)/kargs.get("temperature", 2.0) # log_softmax logits student_logit = tf.nn.log_softmax(logits /kargs.get("temperature", 2.0)) # log_softmax logits distillation_loss = kd_distance(teacher_logit, student_logit, kargs.get("distillation_distance", "kd")) distillation_loss *= features["distillation_ratio"] distillation_loss = tf.reduce_sum(distillation_loss) / (1e-10+tf.reduce_sum(features["distillation_ratio"])) label_loss = loss #tf.reduce_sum(per_example_loss * features["label_ratio"]) / (1e-10+tf.reduce_sum(features["label_ratio"])) print("==distillation loss ratio==", kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)) # loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss loss = (1-kargs.get("distillation_ratio", 0.9))*label_loss + kargs.get("distillation_ratio", 0.9) * distillation_loss if mode == tf.estimator.ModeKeys.TRAIN: loss += kargs.get("adv_ratio", 0.1) * adv_loss + loss_diff model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) print(tvars) if load_pretrained == "yes": model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver() if kargs.get("task_index", 1) == 0 and
kargs.get("run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks=training_hooks) if output_type == "sess": adv_pred_label = tf.argmax(adv_logits, axis=-1, output_type=tf.int32) adv_correct = tf.equal( tf.cast(adv_pred_label, tf.int32), tf.cast(adv_ids, tf.int32) ) adv_accuracy = tf.reduce_mean(tf.cast(adv_correct, tf.float32)) return { "train":{ "loss":loss, "logits":logits, "train_op":train_op, "cross_entropy":label_loss, "kd_loss":distillation_loss, "kd_num":tf.reduce_sum(features["distillation_ratio"]), "ce_num":tf.reduce_sum(features["label_ratio"]), "teacher_logit":teacher_logit, "student_logit":student_logit, "label_ratio":features["label_ratio"], "loss_diff":loss_diff, "adv_loss":adv_loss, "adv_accuracy":adv_accuracy }, "hooks":training_hooks } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: task_prob = tf.exp(tf.nn.log_softmax(logits)) adv_prob = tf.exp(tf.nn.log_softmax(adv_logits)) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'adv_prob':adv_prob, "task_prob":task_prob }, export_outputs={ "output":tf.estimator.export.PredictOutput( { 'adv_prob':adv_prob, "task_prob":task_prob } ) } ) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape( logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax( logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions) sentence_mean_loss = tf.metrics.mean( values=per_example_loss) sentence_f = tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = { "f1": sentence_f, "acc":sentence_accuracy } return eval_metric_ops eval_metric_ops = metric_fn( per_example_loss, logits, label_ids) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) if output_type == "sess": return { "eval":{ "per_example_loss":per_example_loss, "logits":logits, "loss":tf.reduce_mean(per_example_loss) } } elif output_type == "estimator": return estimator_spec else: raise NotImplementedError()
def model_fn(features, labels, mode): features['input_mask'] = tf.cast( tf.not_equal(features['input_ids'], kargs.get('[PAD]', 0)), tf.int64) # for key in ['input_mask', 'input_ids', 'segment_ids']: # features[key] = features[key][:, :274] model = bert_encoder(model_config, features, labels, mode, target, reuse=tf.AUTO_REUSE) if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope with tf.variable_scope(scope, reuse=model_reuse): (loss, per_example_loss, logits) = multi_position_classifier(model_config, features, model.get_sequence_output(), num_labels, dropout_prob) model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) try: params_size = model_io_fn.count_params(model_config.scope) print("==total params==", params_size) except: print("==not count params==") print(tvars) if load_pretrained == "yes": model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) print("==update_ops==", update_ops) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( loss, tvars, opt_config.init_lr, opt_config.num_train_steps, **kargs) # train_op, hooks = model_io_fn.get_ema_hooks(train_op, # tvars, # kargs.get('params_moving_average_decay', 0.99), # scope, mode, # first_stage_steps=opt_config.num_warmup_steps, # two_stage=True) model_io_fn.set_saver() train_metric_dict = train_metric_fn(logits, features, num_labels) for key in train_metric_dict: tf.summary.scalar(key, train_metric_dict[key]) tf.summary.scalar('learning_rate', optimizer_fn.learning_rate) if kargs.get("task_index", 1) == 0 and kargs.get( "run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, training_hooks=training_hooks) print(tf.global_variables(), "==global_variables==") if output_type == "sess": return { "train": { "loss": loss, "logits": logits, "train_op": train_op }, "hooks": training_hooks } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) max_prob = tf.reduce_max(prob, axis=-1) # _, hooks = model_io_fn.get_ema_hooks(None, # None, # kargs.get('params_moving_average_decay', 0.99), # scope, mode) hooks = [] estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'pred_label': pred_label, "max_prob": max_prob }, export_outputs={ "output": tf.estimator.export.PredictOutput({ 'pred_label': pred_label, "max_prob": max_prob }) }, prediction_hooks=hooks) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: _, hooks = model_io_fn.get_ema_hooks( None, None, 
kargs.get('params_moving_average_decay', 0.99), scope, mode) eval_hooks = [] if output_type == "sess": return { "eval": { "per_example_loss": per_example_loss, "logits": logits, "loss": tf.reduce_mean(per_example_loss), "feature": model.get_pooled_output() } } elif output_type == "estimator": eval_metric_ops = eval_logtis(logits, features, num_labels) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=eval_metric_ops, evaluation_hooks=eval_hooks) return estimator_spec else: raise NotImplementedError()
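# Editor's sketch: multi_position_classifier above is repo-internal; the core
# of any multi-position head is gathering the sequence-output vectors at the
# target positions before classifying each one. The helper below is the
# standard BERT-style gather, shown under the assumption that positions is an
# int32 [batch, n_pos] tensor; it is not the repo's exact code.
import tensorflow as tf

def gather_positions(sequence_output, positions):
    """Pick [batch, n_pos] position vectors out of [batch, seq, width]."""
    batch_size = tf.shape(sequence_output)[0]
    seq_length = tf.shape(sequence_output)[1]
    width = sequence_output.shape[-1].value
    # offset each batch row so positions index into the flattened tensor
    flat_offsets = tf.reshape(tf.range(batch_size) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence = tf.reshape(sequence_output, [-1, width])
    return tf.gather(flat_sequence, flat_positions)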
def model_fn(features, labels, mode, params): train_op_type = kargs.get('train_op_type', 'joint') print("==input shape==", features["input_ids"].get_shape()) ebm_dist_fn = ebm_dist(model_config_dict['ebm_dist'], num_labels_dict['ebm_dist'], init_checkpoint_dict['ebm_dist'], model_reuse=None, load_pretrained=load_pretrained_dict['ebm_dist'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('ebm_dist', ""), not_storage_params=not_storage_params_dict.get('ebm_dist', []), target=target_dict['ebm_dist'], prob_ln=False, transform=False, transformer_activation="linear", logz_mode='standard', normalized_constant="length_linear", energy_pooling="mi", softplus_features=False, **kargs) noise_prob_ln = False noise_sample = kargs.get("noise_sample", 'mlm') if kargs.get("noise_sample", 'mlm') == 'gpt': tf.logging.info("****** using gpt for noise dist sample *******") sample_noise_dist = True elif kargs.get("noise_sample", 'mlm') == 'mlm': tf.logging.info("****** using bert mlm for noise dist sample *******") sample_noise_dist = False else: tf.logging.info("****** using gpt for noise dist sample *******") sample_noise_dist = True noise_dist_fn = noise_dist(model_config_dict['noise_dist'], num_labels_dict['noise_dist'], init_checkpoint_dict['noise_dist'], model_reuse=None, load_pretrained=load_pretrained_dict['noise_dist'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('noise_dist', ""), not_storage_params=not_storage_params_dict.get('noise_dist', []), target=target_dict['noise_dist'], noise_true_distribution=True, sample_noise_dist=sample_noise_dist, noise_estimator_type=kargs.get("noise_estimator_type", "stop_gradient"), prob_ln=noise_prob_ln, if_bp=True, **kargs) if not sample_noise_dist: tf.logging.info("****** using bert mlm for noise dist sample *******") global_step = tf.train.get_or_create_global_step() noise_sample_ratio = tf.train.polynomial_decay( 0.20, global_step, opt_config.num_train_steps, end_learning_rate=0.1, power=1.0, cycle=False) mlm_noise_dist_fn = mlm_noise_dist(model_config_dict['generator'], num_labels_dict['generator'], init_checkpoint_dict['generator'], model_reuse=None, load_pretrained=load_pretrained_dict['generator'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('generator', ""), not_storage_params=not_storage_params_dict.get('generator', []), target=target_dict['generator'], mask_probability=noise_sample_ratio, replace_probability=0.2, original_probability=0.0, **kargs) else: mlm_noise_dist_fn = None true_features = {} for key in features: if key == 'input_ori_ids': true_features["input_ids"] = tf.cast(features['input_ori_ids'], tf.int32) if key in ['input_mask', 'segment_ids']: true_features[key] = tf.cast(features[key], tf.int32) if kargs.get("dnce", False): if kargs.get("anneal_dnce", False): global_step = tf.train.get_or_create_global_step() noise_sample_ratio = tf.train.polynomial_decay( 0.10, global_step, opt_config.num_train_steps, end_learning_rate=0.05, power=1.0, cycle=False) tf.logging.info("****** anneal dnce mix ratio *******") else: noise_sample_ratio = 0.10 tf.logging.info("****** not anneal dnce mix ratio *******") mlm_noise_noise_dist_fn = mlm_noise_dist(model_config_dict['generator'], num_labels_dict['generator'], init_checkpoint_dict['generator'], model_reuse=None, load_pretrained=load_pretrained_dict['generator'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('generator', 
""), not_storage_params=not_storage_params_dict.get('generator', []), target=target_dict['generator'], mask_probability=noise_sample_ratio, replace_probability=0.0, original_probability=0.0, **kargs) mlm_noise_dist_dict_noise = mlm_noise_noise_dist_fn(features, labels, mode, params) mixed_mask = mixed_sample(features, mix_ratio=noise_sample_ratio) tf.logging.info("****** apply dnce *******") mixed_mask = tf.expand_dims(mixed_mask, axis=-1) # batch_size x 1 mixed_mask = tf.cast(mixed_mask, tf.int32) true_features["input_ids"] = (1-mixed_mask)*true_features["input_ids"] + mixed_mask * mlm_noise_dist_dict_noise['sampled_ids'] if not sample_noise_dist: mlm_noise_dist_dict = mlm_noise_dist_fn(features, labels, mode, params) else: mlm_noise_dist_dict = {} # first get noise dict noise_dist_dict = noise_dist_fn(true_features, labels, mode, params) # third, get fake ebm dict fake_features = {} if noise_sample == 'gpt': if kargs.get("training_mode", "stop_gradient") == 'stop_gradient': fake_features["input_ids"] = noise_dist_dict['fake_samples'] tf.logging.info("****** using samples stop gradient *******") elif kargs.get("training_mode", "stop_gradient") == 'adv_gumbel': fake_features["input_ids"] = noise_dist_dict['gumbel_probs'] tf.logging.info("****** using samples with gradient *******") fake_features['input_mask'] = tf.cast(noise_dist_dict['fake_mask'], tf.int32) fake_features['segment_ids'] = tf.zeros_like(fake_features['input_mask']) elif noise_sample == 'mlm': fake_features["input_ids"] = mlm_noise_dist_dict['sampled_ids'] fake_features['input_mask'] = tf.cast(features['input_mask'], tf.int32) fake_features['segment_ids'] = tf.zeros_like(features['input_mask']) tf.logging.info("****** using bert mlm stop gradient *******") # second, get true ebm dict true_ebm_dist_dict = ebm_dist_fn(true_features, labels, mode, params) fake_ebm_dist_dict = ebm_dist_fn(fake_features, labels, mode, params) if not sample_noise_dist: fake_noise_dist_dict = noise_dist_fn(fake_features, labels, mode, params) noise_dist_dict['fake_logits'] = fake_noise_dist_dict['true_logits'] [ebm_loss, ebm_all_true_loss, ebm_all_fake_loss] = get_ebm_loss(true_ebm_dist_dict['logits'], noise_dist_dict['true_logits'], fake_ebm_dist_dict['logits'], noise_dist_dict['fake_logits'], use_tpu=kargs.get('use_tpu', False), valid_mask=mlm_noise_dist_dict.get("valid_mask", None)) logz_length_true_loss = ebm_logz_length_cond_loss(model_config_dict['ebm_dist'], true_features, ebm_all_true_loss, valid_mask=mlm_noise_dist_dict.get("valid_mask", None)) logz_length_fake_loss = ebm_logz_length_cond_loss(model_config_dict['ebm_dist'], fake_features, ebm_all_fake_loss, valid_mask=mlm_noise_dist_dict.get("valid_mask", None)) true_ebm_dist_dict['logz_loss'] = logz_length_true_loss + logz_length_fake_loss noise_loss = get_noise_loss(true_ebm_dist_dict['logits'], noise_dist_dict['true_logits'], fake_ebm_dist_dict['logits'], noise_dist_dict['fake_logits'], noise_loss_type=kargs.get('noise_loss_type', 'jsd_noise'), num_train_steps=opt_config.num_train_steps, num_warmup_steps=opt_config.num_warmup_steps, use_tpu=kargs.get('use_tpu', False), loss_mask=features['input_mask'], prob_ln=noise_prob_ln) model_io_fn = model_io.ModelIO(model_io_config) tvars = [] loss = ebm_loss tvars.extend(true_ebm_dist_dict['tvars']) if kargs.get('joint_train', '1') == '1': tf.logging.info("****** joint generator and discriminator training *******") tvars.extend(noise_dist_dict['tvars']) loss += noise_loss tvars = list(set(tvars)) ebm_opt_dict = { "loss":ebm_loss, 
"tvars":true_ebm_dist_dict['tvars'], "logz_tvars":true_ebm_dist_dict['logz_tvars'], "logz_loss":true_ebm_dist_dict['logz_loss'] } noise_opt_dict = { "loss":noise_loss, "tvars":noise_dist_dict['tvars'] } var_checkpoint_dict_list = [] for key in init_checkpoint_dict: if load_pretrained_dict[key] == "yes": if key == 'ebm_dist': tmp = { "tvars":ebm_opt_dict['tvars']+ebm_opt_dict['logz_tvars'], "init_checkpoint":init_checkpoint_dict['ebm_dist'], "exclude_scope":exclude_scope_dict[key], "restore_var_name":model_config_dict['ebm_dist'].get('restore_var_name', []) } if kargs.get("sharing_mode", "none") != "none": tmp['exclude_scope'] = '' var_checkpoint_dict_list.append(tmp) elif key == 'noise_dist': tmp = { "tvars":noise_opt_dict['tvars'], "init_checkpoint":init_checkpoint_dict['noise_dist'], "exclude_scope":exclude_scope_dict[key], "restore_var_name":model_config_dict['noise_dist'].get('restore_var_name', []) } var_checkpoint_dict_list.append(tmp) elif key == 'generator': if not sample_noise_dist: tmp = { "tvars":mlm_noise_dist_dict['tvars'], "init_checkpoint":init_checkpoint_dict['generator'], "exclude_scope":exclude_scope_dict[key], "restore_var_name":model_config_dict['generator'].get('restore_var_name', []) } if kargs.get("sharing_mode", "none") != "none": tmp['exclude_scope'] = '' var_checkpoint_dict_list.append(tmp) use_tpu = 1 if kargs.get('use_tpu', False) else 0 if len(var_checkpoint_dict_list) >= 1: scaffold_fn = model_io_fn.load_multi_pretrained( var_checkpoint_dict_list, use_tpu=use_tpu) else: scaffold_fn = None if mode == tf.estimator.ModeKeys.TRAIN: metric_dict = ebm_noise_train_metric( true_ebm_dist_dict['logits'], noise_dist_dict['true_logits'], fake_ebm_dist_dict['logits'], noise_dist_dict['fake_logits'], features['input_ori_ids'], tf.cast(features['input_mask'], tf.float32), noise_dist_dict["true_seq_logits"], prob_ln=noise_prob_ln, ) if not kargs.get('use_tpu', False): for key in metric_dict: tf.summary.scalar(key, metric_dict[key]) tf.summary.scalar("ebm_loss", ebm_opt_dict['loss']) tf.summary.scalar("noise_loss", noise_opt_dict['loss']) if kargs.get('use_tpu', False): optimizer_fn = optimizer.Optimizer(opt_config) use_tpu = 1 else: optimizer_fn = distributed_optimizer.Optimizer(opt_config) use_tpu = 0 model_io_fn.print_params(tvars, string=", trainable params") train_op = get_train_op(ebm_opt_dict, noise_opt_dict, optimizer_fn, opt_config, model_config_dict['ebm_dist'], model_config_dict['noise_dist'], use_tpu=use_tpu, train_op_type=train_op_type, fce_acc=metric_dict['all_accuracy']) # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # with tf.control_dependencies(update_ops): # train_op = optimizer_fn.get_train_op(loss, list(set(tvars)), # opt_config.init_lr, # opt_config.num_train_steps, # use_tpu=use_tpu) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: tpu_eval_metrics = (ebm_noise_eval_metric, [ true_ebm_dist_dict['logits'], noise_dist_dict['true_logits'], fake_ebm_dist_dict['logits'], noise_dist_dict['fake_logits'], features['input_ori_ids'], tf.cast(features['input_mask'], tf.float32), noise_dist_dict["true_seq_logits"] ]) gpu_eval_metrics = ebm_noise_eval_metric( true_ebm_dist_dict['logits'], noise_dist_dict['true_logits'], fake_ebm_dist_dict['logits'], noise_dist_dict['fake_logits'], features['input_ori_ids'], 
tf.cast(features['input_mask'], tf.float32), noise_dist_dict["true_seq_logits"] ) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=tpu_eval_metrics, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics) return estimator_spec else: raise NotImplementedError()
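# Editor's sketch: get_ebm_loss above is repo-internal; the standard EBM/NCE
# objective it is presumed to follow treats log p_ebm(x) - log p_noise(x) as
# the logit that a sequence came from data rather than from the noise
# distribution, with binary cross-entropy on both true and fake samples.
# Illustrative shape only, under that assumption.
import tensorflow as tf

def ebm_nce_loss_sketch(true_ebm_logits, true_noise_logits,
                        fake_ebm_logits, fake_noise_logits):
    true_logits = true_ebm_logits - true_noise_logits
    fake_logits = fake_ebm_logits - fake_noise_logits
    true_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(true_logits), logits=true_logits)
    fake_loss = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_logits), logits=fake_logits)
    return tf.reduce_mean(true_loss + fake_loss)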
def model_fn(features, labels, mode): if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 label_ids = features["label_ids"] model_lst = [] for index, name in enumerate(target): if index > 0: reuse = True else: reuse = model_reuse model_lst.append(bert_encoding(model_config, features, labels, mode, name, scope, dropout_prob, reuse=reuse)) [input_mask_a, repres_a] = model_lst[0] [input_mask_b, repres_b] = model_lst[1] output_a, output_b = alignment_aggerate(model_config, repres_a, repres_b, input_mask_a, input_mask_b, scope, reuse=model_reuse) if model_config.pooling == "ave_max_pooling": pooling_fn = ave_max_pooling elif model_config.pooling == "multihead_pooling": pooling_fn = multihead_pooling repres_a = pooling_fn(model_config, output_a, input_mask_a, scope, dropout_prob, reuse=model_reuse) repres_b = pooling_fn(model_config, output_b, input_mask_b, scope, dropout_prob, reuse=True) pair_repres = tf.concat([repres_a, repres_b, tf.abs(repres_a-repres_b), repres_b*repres_a], axis=-1) with tf.variable_scope(scope, reuse=model_reuse): (loss, per_example_loss, logits) = classifier.classifier(model_config, pair_repres, num_labels, label_ids, dropout_prob) model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) if load_pretrained: model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr, opt_config.num_train_steps) model_io_fn.set_saver() if kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks=training_hooks) if output_type == "sess": return { "train":{ "loss":loss, "logits":logits, "train_op":train_op }, "hooks":training_hooks } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) max_prob = tf.reduce_max(prob, axis=-1) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'pred_label':pred_label, "max_prob":max_prob }, export_outputs={ "output":tf.estimator.export.PredictOutput( { 'pred_label':pred_label, "max_prob":max_prob } ) } ) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape( logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax( logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions) sentence_mean_loss = tf.metrics.mean( values=per_example_loss) sentence_f =
tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = { "f1": sentence_f, "loss": sentence_mean_loss, "acc":sentence_accuracy } return eval_metric_ops eval_metric_ops = metric_fn( per_example_loss, logits, label_ids) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) if output_type == "sess": return { "eval":{ "per_example_loss":per_example_loss, "logits":logits, "loss":tf.reduce_mean(per_example_loss) } } elif output_type == "estimator": return estimator_spec else: raise NotImplementedError()
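# Editor's sketch: ave_max_pooling selected above is repo-internal; it is
# presumed to mean-pool and max-pool the aligned encoder outputs over valid
# (mask == 1) positions and concatenate the two. A minimal version under that
# assumption, with an illustrative signature:
import tensorflow as tf

def ave_max_pooling_sketch(sequence_output, input_mask):
    """[batch, seq, dim] + [batch, seq] mask -> [batch, 2*dim]."""
    mask = tf.cast(tf.expand_dims(input_mask, -1), tf.float32)
    sum_repres = tf.reduce_sum(sequence_output * mask, axis=1)
    ave_repres = sum_repres / (tf.reduce_sum(mask, axis=1) + 1e-10)
    # push padded positions to -inf so they never win the max
    neg_inf = (1.0 - mask) * -1e10
    max_repres = tf.reduce_max(sequence_output * mask + neg_inf, axis=1)
    return tf.concat([ave_repres, max_repres], axis=-1)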
def model_fn(features, labels, mode, params): train_op_type = kargs.get('train_op_type', 'joint') gen_disc_type = kargs.get('gen_disc_type', 'all_disc') print(train_op_type, "===train op type===", gen_disc_type, "===generator loss type===") if kargs.get('optimization_type', 'grl') == 'grl': if_flip_grad = True train_op_type = 'joint' elif kargs.get('optimization_type', 'grl') == 'minmax': if_flip_grad = False generator_fn = generator( model_config_dict['generator'], num_labels_dict['generator'], init_checkpoint_dict['generator'], model_reuse=None, load_pretrained=load_pretrained_dict['generator'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('generator', ""), not_storage_params=not_storage_params_dict.get('generator', []), target=target_dict['generator'], if_flip_grad=if_flip_grad, # mask_method="all_mask", **kargs) tf.logging.info("****** train_op_type:%s *******", train_op_type) tf.logging.info("****** optimization_type:%s *******", kargs.get('optimization_type', 'grl')) generator_dict = generator_fn(features, labels, mode, params) discriminator_fn = discriminator( model_config_dict['discriminator'], num_labels_dict['discriminator'], init_checkpoint_dict['discriminator'], model_reuse=None, load_pretrained=load_pretrained_dict['discriminator'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('discriminator', ""), not_storage_params=not_storage_params_dict.get( 'discriminator', []), target=target_dict['discriminator'], **kargs) tf.logging.info("****** true sampled_ids of discriminator *******") true_distriminator_features = {} true_distriminator_features['input_ids'] = generator_dict[ 'sampled_input_ids'] true_distriminator_features['input_mask'] = generator_dict[ 'sampled_input_mask'] true_distriminator_features['segment_ids'] = generator_dict[ 'sampled_segment_ids'] true_distriminator_features['input_ori_ids'] = generator_dict[ 'sampled_input_ids'] true_distriminator_features['next_sentence_labels'] = features[ 'next_sentence_labels'] true_distriminator_features['ori_input_ids'] = generator_dict[ 'sampled_input_ids'] true_distriminator_dict = discriminator_fn(true_distriminator_features, labels, mode, params) fake_discriminator_features = {} if kargs.get('minmax_mode', 'corrupted') == 'corrupted': tf.logging.info("****** gumbel 3-D sampled_ids *******") elif kargs.get('minmax_mode', 'corrupted') == 'masked': fake_discriminator_features['ori_sampled_ids'] = generator_dict[ 'output_ids'] fake_discriminator_features['sampled_binary_mask'] = generator_dict[ 'sampled_binary_mask'] tf.logging.info("****** conditional sampled_ids *******") fake_discriminator_features['input_ids'] = generator_dict[ 'sampled_ids'] fake_discriminator_features['input_mask'] = generator_dict[ 'sampled_input_mask'] fake_discriminator_features['segment_ids'] = generator_dict[ 'sampled_segment_ids'] fake_discriminator_features['input_ori_ids'] = generator_dict[ 'sampled_input_ids'] fake_discriminator_features['next_sentence_labels'] = features[ 'next_sentence_labels'] fake_discriminator_features['ori_input_ids'] = generator_dict[ 'sampled_ids'] fake_discriminator_dict = discriminator_fn(fake_discriminator_features, labels, mode, params) use_tpu = 1 if kargs.get('use_tpu', False) else 0 output_dict = get_losses(true_distriminator_dict["logits"], fake_discriminator_dict["logits"], use_tpu=use_tpu, gan_type=kargs.get('gan_type', "JS")) discriminator_dict = {} discriminator_dict['gen_loss'] = output_dict['gen_loss']
discriminator_dict['disc_loss'] = output_dict['disc_loss'] discriminator_dict['tvars'] = fake_discriminator_dict['tvars'] discriminator_dict['fake_logits'] = fake_discriminator_dict['logits'] discriminator_dict['true_logits'] = true_distriminator_dict['logits'] model_io_fn = model_io.ModelIO(model_io_config) loss = discriminator_dict['disc_loss'] tvars = [] tvars.extend(discriminator_dict['tvars']) if kargs.get('joint_train', '1') == '1': tf.logging.info( "****** joint generator and discriminator training *******") tvars.extend(generator_dict['tvars']) loss += generator_dict['loss'] tvars = list(set(tvars)) var_checkpoint_dict_list = [] for key in init_checkpoint_dict: if load_pretrained_dict[key] == "yes": if key == 'generator': tmp = { "tvars": generator_dict['tvars'], "init_checkpoint": init_checkpoint_dict['generator'], "exclude_scope": exclude_scope_dict[key], "restore_var_name": model_config_dict['generator'].get( 'restore_var_name', []) } if kargs.get("sharing_mode", "none") != "none": tmp['exclude_scope'] = '' var_checkpoint_dict_list.append(tmp) elif key == 'discriminator': tmp = { "tvars": discriminator_dict['tvars'], "init_checkpoint": init_checkpoint_dict['discriminator'], "exclude_scope": exclude_scope_dict[key], "restore_var_name": model_config_dict['discriminator'].get( 'restore_var_name', []) } var_checkpoint_dict_list.append(tmp) use_tpu = 1 if kargs.get('use_tpu', False) else 0 if len(var_checkpoint_dict_list) >= 1: scaffold_fn = model_io_fn.load_multi_pretrained( var_checkpoint_dict_list, use_tpu=use_tpu) else: scaffold_fn = None if mode == tf.estimator.ModeKeys.TRAIN: if not kargs.get('use_tpu', False): metric_dict = discriminator_metric_train(discriminator_dict) for key in metric_dict: tf.summary.scalar(key, metric_dict[key]) tf.summary.scalar("generator_loss", generator_dict['loss']) tf.summary.scalar("discriminator_true_loss", discriminator_dict['disc_loss']) tf.summary.scalar("discriminator_fake_loss", discriminator_dict['gen_loss']) if kargs.get('use_tpu', False): optimizer_fn = optimizer.Optimizer(opt_config) use_tpu = 1 else: optimizer_fn = distributed_optimizer.Optimizer(opt_config) use_tpu = 0 model_io_fn.print_params(tvars, string=", trainable params") train_op = get_train_op(generator_dict, discriminator_dict, optimizer_fn, opt_config, model_config_dict['generator'], model_config_dict['discriminator'], use_tpu=use_tpu, train_op_type=train_op_type, gen_disc_type=gen_disc_type) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn # training_hooks=[logging_hook] ) else: estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: if kargs.get('joint_train', '0') == '1': def joint_metric(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights, next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels, discriminator_dict): generator_metric = generator_metric_fn_eval( masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights, next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels) discriminator_metric = discriminator_metric_eval( discriminator_dict) generator_metric.update(discriminator_metric) return generator_metric tpu_eval_metrics = (joint_metric, [ generator_dict['masked_lm_example_loss'], generator_dict['masked_lm_log_probs'], generator_dict['masked_lm_ids'], generator_dict['masked_lm_weights'], 
generator_dict.get('next_sentence_example_loss', None), generator_dict.get('next_sentence_log_probs', None), generator_dict.get('next_sentence_labels', None), discriminator_dict ]) gpu_eval_metrics = joint_metric( generator_dict['masked_lm_example_loss'], generator_dict['masked_lm_log_probs'], generator_dict['masked_lm_ids'], generator_dict['masked_lm_weights'], generator_dict.get('next_sentence_example_loss', None), generator_dict.get('next_sentence_log_probs', None), generator_dict.get('next_sentence_labels', None), discriminator_dict) else: gpu_eval_metrics = discriminator_metric_eval( discriminator_dict) tpu_eval_metrics = (discriminator_metric_eval, [discriminator_dict]) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=tpu_eval_metrics, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics) return estimator_spec else: raise NotImplementedError()
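# Editor's sketch: get_losses above is repo-internal; for the default
# gan_type="JS" it is presumed to be the standard (non-saturating) GAN
# objective on sequence-level discriminator logits. Illustrative only, under
# that assumption:
import tensorflow as tf

def js_gan_losses_sketch(true_logits, fake_logits):
    d_true = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(true_logits), logits=true_logits)
    d_fake = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_logits), logits=fake_logits)
    disc_loss = tf.reduce_mean(d_true) + tf.reduce_mean(d_fake)
    # non-saturating generator loss: reward fakes that are scored as real
    gen_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(fake_logits), logits=fake_logits))
    return {"disc_loss": disc_loss, "gen_loss": gen_loss}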
def model_fn(features, labels, mode, params): train_op_type = kargs.get('train_op_type', 'joint') gen_disc_type = kargs.get('gen_disc_type', 'all_disc') mask_method = kargs.get('mask_method', 'only_mask') use_tpu = 1 if kargs.get('use_tpu', False) else 0 print(train_op_type, "===train op type===", gen_disc_type, "===generator loss type===") if mask_method == 'only_mask': tf.logging.info( "****** generator token generation mask type:%s with only masked token *******", mask_method) elif mask_method == 'all_mask': tf.logging.info( "****** generator token generation mask type:%s with all token *******", mask_method) else: mask_method = 'only_mask' tf.logging.info( "****** generator token generation mask type:%s with only masked token *******", mask_method) if kargs.get('optimization_type', 'grl') == 'grl': if_flip_grad = True train_op_type = 'joint' elif kargs.get('optimization_type', 'grl') == 'minmax': if_flip_grad = False else: if_flip_grad = True train_op_type = 'joint' generator_fn = generator( model_config_dict['generator'], num_labels_dict['generator'], init_checkpoint_dict['generator'], model_reuse=None, load_pretrained=load_pretrained_dict['generator'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('generator', ""), not_storage_params=not_storage_params_dict.get('generator', []), target=target_dict['generator'], if_flip_grad=if_flip_grad, # mask_method='only_mask', **kargs) tf.logging.info("****** train_op_type:%s *******", train_op_type) tf.logging.info("****** optimization_type:%s *******", kargs.get('optimization_type', 'grl')) generator_dict = generator_fn(features, labels, mode, params) discriminator_fn = discriminator_generator( model_config_dict['discriminator'], num_labels_dict['discriminator'], init_checkpoint_dict['discriminator'], model_reuse=None, load_pretrained=load_pretrained_dict['discriminator'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('discriminator', ""), not_storage_params=not_storage_params_dict.get( 'discriminator', []), target=target_dict['discriminator'], loss='cross_entropy', **kargs) discriminator_features = {} # minmax_mode in ['masked', 'corrupted'] minmax_mode = kargs.get('minmax_mode', 'corrupted') tf.logging.info("****** minmax mode for discriminator: %s *******", minmax_mode) if minmax_mode == 'corrupted': tf.logging.info("****** gumbel 3-D sampled_ids *******") elif minmax_mode == 'masked': discriminator_features['ori_sampled_ids'] = generator_dict[ 'output_ids'] discriminator_features['sampled_binary_mask'] = generator_dict[ 'sampled_binary_mask'] tf.logging.info("****** conditional sampled_ids *******") discriminator_features['input_ids'] = generator_dict['sampled_ids'] discriminator_features['input_mask'] = generator_dict[ 'sampled_input_mask'] discriminator_features['segment_ids'] = generator_dict[ 'sampled_segment_ids'] discriminator_features['input_ori_ids'] = generator_dict[ 'sampled_input_ids'] discriminator_features['next_sentence_labels'] = features[ 'next_sentence_labels'] discriminator_features['ori_input_ids'] = generator_dict['sampled_ids'] discriminator_dict = discriminator_fn(discriminator_features, labels, mode, params) [disc_loss, disc_logits, disc_per_example_loss ] = optimal_discriminator(model_config_dict['discriminator'], generator_dict, features, discriminator_dict, discriminator_features, use_tpu=use_tpu) [ equal_per_example_loss, equal_loss_all, equal_loss_self, not_equal_per_example_loss, not_equal_loss_all, not_equal_loss_self 
] = modified_loss(disc_per_example_loss, disc_logits, discriminator_features['input_ori_ids'], discriminator_features['ori_input_ids'], discriminator_features['input_mask'], sampled_binary_mask=discriminator_features.get( 'sampled_binary_mask', None), **kargs) output_dict = {} output_dict['logits'] = disc_logits output_dict['per_example_loss'] = disc_per_example_loss output_dict['loss'] = disc_loss + 0.0 * discriminator_dict["loss"] output_dict["equal_per_example_loss"] = equal_per_example_loss output_dict["equal_loss_all"] = equal_loss_all output_dict["equal_loss_self"] = equal_loss_self output_dict["not_equal_per_example_loss"] = not_equal_per_example_loss output_dict["not_equal_loss_all"] = not_equal_loss_all output_dict["not_equal_loss_self"] = not_equal_loss_self output_dict['tvars'] = discriminator_dict['tvars'] model_io_fn = model_io.ModelIO(model_io_config) tvars = [] loss = kargs.get('dis_loss', 1.0) * output_dict['loss'] tvars.extend(discriminator_dict['tvars']) if kargs.get('joint_train', '1') == '1': tf.logging.info( "****** joint generator and discriminator training *******") tvars.extend(generator_dict['tvars']) loss += generator_dict['loss'] tvars = list(set(tvars)) var_checkpoint_dict_list = [] for key in init_checkpoint_dict: if load_pretrained_dict[key] == "yes": if key == 'generator': tmp = { "tvars": generator_dict['tvars'], "init_checkpoint": init_checkpoint_dict['generator'], "exclude_scope": exclude_scope_dict[key], "restore_var_name": model_config_dict['generator'].get( 'restore_var_name', []) } if kargs.get("sharing_mode", "none") != "none": tmp['exclude_scope'] = '' var_checkpoint_dict_list.append(tmp) elif key == 'discriminator': tmp = { "tvars": discriminator_dict['tvars'], "init_checkpoint": init_checkpoint_dict['discriminator'], "exclude_scope": exclude_scope_dict[key], "restore_var_name": model_config_dict['discriminator'].get( 'restore_var_name', []) } var_checkpoint_dict_list.append(tmp) use_tpu = 1 if kargs.get('use_tpu', False) else 0 if len(var_checkpoint_dict_list) >= 1: scaffold_fn = model_io_fn.load_multi_pretrained( var_checkpoint_dict_list, use_tpu=use_tpu) else: scaffold_fn = None if mode == tf.estimator.ModeKeys.TRAIN: if not kargs.get('use_tpu', False): metric_dict = discriminator_metric_train( output_dict['per_example_loss'], output_dict['logits'], generator_dict['sampled_input_ids'], generator_dict['sampled_ids'], generator_dict['sampled_input_mask']) for key in metric_dict: tf.summary.scalar(key, metric_dict[key]) tf.summary.scalar("generator_loss", generator_dict['loss']) tf.summary.scalar("discriminator_loss", discriminator_dict['loss']) if kargs.get('use_tpu', False): optimizer_fn = optimizer.Optimizer(opt_config) use_tpu = 1 else: optimizer_fn = distributed_optimizer.Optimizer(opt_config) use_tpu = 0 model_io_fn.print_params(tvars, string=", trainable params") train_op = get_train_op(generator_dict, output_dict, optimizer_fn, opt_config, model_config_dict['generator'], model_config_dict['discriminator'], use_tpu=use_tpu, train_op_type=train_op_type, gen_disc_type=gen_disc_type) # update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) # with tf.control_dependencies(update_ops): # train_op = optimizer_fn.get_train_op(loss, list(set(tvars)), # opt_config.init_lr, # opt_config.num_train_steps, # use_tpu=use_tpu) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn # training_hooks=[logging_hook] ) else: estimator_spec =
tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: if kargs.get('joint_train', '0') == '1': def joint_metric(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights, next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels, per_example_loss, logits, input_ori_ids, input_ids, input_mask): generator_metric = generator_metric_fn_eval( masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids, masked_lm_weights, next_sentence_example_loss, next_sentence_log_probs, next_sentence_labels) discriminator_metric = discriminator_metric_eval( per_example_loss, logits, input_ori_ids, input_ids, input_mask) generator_metric.update(discriminator_metric) return generator_metric tpu_eval_metrics = (joint_metric, [ generator_dict['masked_lm_example_loss'], generator_dict['masked_lm_log_probs'], generator_dict['masked_lm_ids'], generator_dict['masked_lm_weights'], generator_dict.get('next_sentence_example_loss', None), generator_dict.get('next_sentence_log_probs', None), generator_dict.get('next_sentence_labels', None), discriminator_dict['per_example_loss'], discriminator_dict['logits'], generator_dict['sampled_input_ids'], generator_dict['sampled_ids'], generator_dict['sampled_input_mask'] ]) gpu_eval_metrics = joint_metric( generator_dict['masked_lm_example_loss'], generator_dict['masked_lm_log_probs'], generator_dict['masked_lm_ids'], generator_dict['masked_lm_weights'], generator_dict.get('next_sentence_example_loss', None), generator_dict.get('next_sentence_log_probs', None), generator_dict.get('next_sentence_labels', None), discriminator_dict['per_example_loss'], discriminator_dict['logits'], generator_dict['sampled_input_ids'], generator_dict['sampled_ids'], generator_dict['sampled_input_mask']) else: gpu_eval_metrics = discriminator_metric_eval( discriminator_dict['per_example_loss'], discriminator_dict['logits'], generator_dict['sampled_input_ids'], generator_dict['sampled_ids'], generator_dict['sampled_input_mask']) tpu_eval_metrics = (discriminator_metric_eval, [ discriminator_dict['per_example_loss'], discriminator_dict['logits'], generator_dict['sampled_input_ids'], generator_dict['sampled_ids'], generator_dict['sampled_input_mask'] ]) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=tpu_eval_metrics, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics) return estimator_spec else: raise NotImplementedError()
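# Editor's sketch: the if_flip_grad switch above selects gradient-reversal
# (GRL) style joint training, where the forward pass is the identity but the
# backward pass negates the gradient so generator and discriminator can share
# a single optimizer step. A minimal TF1 sketch; scale (the reversal strength)
# is an illustrative parameter, not the repo's exact interface.
import tensorflow as tf

def flip_gradient_sketch(x, scale=1.0):
    @tf.custom_gradient
    def _flip(x):
        def grad(dy):
            # reverse (and optionally scale) the gradient on the way back
            return -scale * dy
        return tf.identity(x), grad
    return _flip(x)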
def model_fn(features, labels, mode, params): model_api = model_zoo(model_config) if target: features['input_ori_ids'] = features['input_ids_{}'.format(target)] features['input_ids'] = features['input_ids_{}'.format(target)] sequence_mask = tf.cast( tf.not_equal(features['input_ori_ids'], kargs.get('[PAD]', 0)), tf.int32) features['input_mask'] = sequence_mask seq_features = {} for key in features: seq_features[key] = features[key] if 'input_ori_ids' in features: seq_features['input_ids'] = features["input_ori_ids"] else: features['input_ori_ids'] = seq_features['input_ids'] not_equal = tf.cast( tf.not_equal(features["input_ori_ids"], tf.zeros_like(features["input_ori_ids"])), tf.int32) not_equal = tf.reduce_sum(not_equal, axis=-1) loss_mask = tf.cast(tf.not_equal(not_equal, tf.zeros_like(not_equal)), tf.float32) if not kargs.get('use_tpu', False): tf.summary.scalar('loss_mask', tf.reduce_sum(loss_mask)) casual_flag = model_config.get('is_casual', True) tf.logging.info("***** is_casual flag: %s *****", str(casual_flag)) if not casual_flag: [output_ids, sampled_binary_mask] = hmm_input_ids_generation( model_config, features['input_ori_ids'], features['input_mask'], [ tf.cast(tf.constant(hmm_tran_prob), tf.float32) for hmm_tran_prob in hmm_tran_prob_list ], mask_probability=0.02, replace_probability=0.01, original_probability=0.01, mask_prior=tf.cast(tf.constant(mask_prior), tf.float32), **kargs) tf.logging.info("***** apply random sampling *****") seq_features['input_ids'] = output_ids model = model_api(model_config, seq_features, labels, mode, "", reuse=tf.AUTO_REUSE, **kargs) if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope # if mode == tf.estimator.ModeKeys.TRAIN: if kargs.get('mask_type', 'left2right') == 'left2right': tf.logging.info("***** using left2right mask and loss *****") sequence_mask = tf.to_float( tf.not_equal(features['input_ori_ids'][:, 1:], kargs.get('[PAD]', 0))) elif kargs.get('mask_type', 'left2right') == 'seq2seq': tf.logging.info("***** using seq2seq mask and loss *****") sequence_mask = tf.to_float(features['segment_ids'][:, 1:]) if not kargs.get('use_tpu', False): tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask)) # batch x seq_length if casual_flag: print(model.get_sequence_output_logits().get_shape(), "===logits shape===") seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=features['input_ori_ids'][:, 1:], logits=model.get_sequence_output_logits()[:, :-1]) per_example_loss = tf.reduce_sum( seq_loss * sequence_mask, axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10) loss = tf.reduce_mean(per_example_loss) if model_config.get("cnn_type", "dgcnn") in ['bi_dgcnn', 'bi_light_dgcnn']: seq_backward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=features['input_ori_ids'][:, :-1], logits=model.get_sequence_backward_output_logits()[:, 1:]) per_backward_example_loss = tf.reduce_sum( seq_backward_loss * sequence_mask, axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10) backward_loss = tf.reduce_mean(per_backward_example_loss) loss += backward_loss tf.logging.info("***** using backward loss *****") else: (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs, masked_lm_mask) = pretrain.seq_mask_masked_lm_output( model_config, model.get_sequence_output(), model.get_embedding_table(), seq_features['input_mask'], seq_features['input_ori_ids'],
seq_features['input_ids'], sampled_binary_mask, reuse=tf.AUTO_REUSE, embedding_projection=model.get_embedding_projection_table()) loss = masked_lm_loss tf.logging.info("***** using masked lm loss *****") model_io_fn = model_io.ModelIO(model_io_config) pretrained_tvars = model_io_fn.get_params( model_config.scope, not_storage_params=not_storage_params) lm_pretrain_tvars = model_io_fn.get_params( "cls/predictions", not_storage_params=not_storage_params) pretrained_tvars.extend(lm_pretrain_tvars) use_tpu = 1 if kargs.get('use_tpu', False) else 0 if load_pretrained == "yes": use_tpu = 1 if kargs.get('use_tpu', False) else 0 scaffold_fn = model_io_fn.load_pretrained( pretrained_tvars, init_checkpoint, exclude_scope=exclude_scope, use_tpu=use_tpu) else: scaffold_fn = None if mode == tf.estimator.ModeKeys.TRAIN: if kargs.get('use_tpu', False): optimizer_fn = optimizer.Optimizer(opt_config) use_tpu = 1 tf.logging.info( "***** using tpu with tpu-compatible optimizer *****") else: optimizer_fn = distributed_optimizer.Optimizer(opt_config) use_tpu = 0 tf.logging.info( "***** using gpu with gpu-compatible optimizer *****") tvars = pretrained_tvars model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( loss, tvars, opt_config.init_lr, opt_config.num_train_steps, use_tpu=use_tpu) # train_metric_dict = train_metric(features['input_ori_ids'], # model.get_sequence_output_logits(), # seq_features, # **kargs) # if not kargs.get('use_tpu', False): # for key in train_metric_dict: # tf.summary.scalar(key, train_metric_dict[key]) # tf.summary.scalar('learning_rate', optimizer_fn.learning_rate) # tf.logging.info("***** logging metric *****") # tf.summary.scalar("causal_attenion_mask_length", tf.reduce_sum(sequence_mask)) # tf.summary.scalar("bi_attenion_mask_length", tf.reduce_sum(model.bi_attention_mask)) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: gpu_eval_metrics = eval_metric(features['input_ori_ids'], model.get_sequence_output_logits(), seq_features, **kargs) tpu_eval_metrics = (eval_metric, [ features['input_ori_ids'], model.get_sequence_output_logits(), seq_features, kargs.get('mask_type', 'left2right') ]) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=tpu_eval_metrics, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics) return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: if kargs.get('predict_type', 'sample_sequence') == 'sample_sequence': results = bert_seq_sample_utils.sample_sequence( model_api, model_config, mode, features, target="", start_token=kargs.get("start_token_id", 101), batch_size=None, context=features.get("context", None), temperature=kargs.get("sample_temp", 1.0), n_samples=kargs.get("n_samples", 1), top_k=0, end_token=kargs.get("end_token_id", 102), greedy_or_sample="greedy", gumbel_temp=0.01, estimator="stop_gradient", back_prop=True, swap_memory=True, seq_type=kargs.get("seq_type", "seq2seq"), mask_type=kargs.get("mask_type", "seq2seq"), attention_type=kargs.get('attention_type', 'normal_attention')) # stop_gradient output: #
samples, mask_sequence, presents, logits, final sampled_token = results['samples'] sampled_token_logits = results['logits'] mask_sequence = results['mask_sequence'] estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'token': sampled_token, "logits": sampled_token_logits, "mask_sequence": mask_sequence }, export_outputs={ "output": tf.estimator.export.PredictOutput({ 'token': sampled_token, "logits": sampled_token_logits, "mask_sequence": mask_sequence }) }) return estimator_spec elif kargs.get('predict_type', 'sample_sequence') == 'infer_inputs': sequence_mask = tf.to_float( tf.not_equal(features['input_ids'][:, 1:], kargs.get('[PAD]', 0))) if kargs.get('mask_type', 'left2right') == 'left2right': tf.logging.info( "***** using left2right mask and loss *****") sequence_mask = tf.to_float( tf.not_equal(features['input_ori_ids'][:, 1:], kargs.get('[PAD]', 0))) elif kargs.get('mask_type', 'left2right') == 'seq2seq': tf.logging.info("***** using seq2seq mask and loss *****") sequence_mask = tf.to_float(features['segment_ids'][:, 1:]) if not kargs.get('use_tpu', False): tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask)) output_logits = model.get_sequence_output_logits()[:, :-1] # output_logits = tf.nn.log_softmax(output_logits, axis=-1) output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=features['input_ids'][:, 1:], logits=output_logits) per_example_perplexity = tf.reduce_sum(output_id_logits * sequence_mask, axis=-1) # batch per_example_perplexity /= tf.reduce_sum(sequence_mask, axis=-1) # batch perplexity = tf.exp(per_example_perplexity) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'token': features['input_ids'][:, 1:], "logits": output_id_logits, 'perplexity': perplexity, # "all_logits":output_logits }, export_outputs={ "output": tf.estimator.export.PredictOutput({ 'token': features['input_ids'][:, 1:], "logits": output_id_logits, 'perplexity': perplexity, # "all_logits":output_logits }) }) return estimator_spec else: raise NotImplementedError()
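# Editor's sketch: the causal-LM branch above pairs the logits at position t
# with the token at t + 1 and the infer_inputs predict branch turns the masked
# per-example cross-entropy into a perplexity. The helper below restates that
# same shift and normalization in one place; names are illustrative.
import tensorflow as tf

def causal_lm_perplexity_sketch(input_ids, logits, pad_id=0):
    """input_ids: int [batch, seq]; logits: [batch, seq, vocab]."""
    labels = input_ids[:, 1:]          # target is the next token
    shifted_logits = logits[:, :-1]    # prediction comes from the prior step
    mask = tf.to_float(tf.not_equal(labels, pad_id))
    nll = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=shifted_logits)
    per_example_nll = tf.reduce_sum(nll * mask, axis=-1) / (
        tf.reduce_sum(mask, axis=-1) + 1e-10)
    return tf.exp(per_example_nll)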
def model_fn(features, labels, mode): # model = bert_encoder(model_config, features, labels, # mode, target, reuse=model_reuse) model = albert_encoder(model_config, features, labels, mode, target, reuse=model_reuse) label_ids = features["label_ids"] if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope with tf.variable_scope(scope, reuse=model_reuse): (loss, per_example_loss, logits) = classifier.classifier(model_config, model.get_pooled_output(), num_labels, label_ids, dropout_prob) model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) if load_pretrained == "yes": model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) print("==update_ops==", update_ops) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr, opt_config.num_train_steps, **kargs) # train_op, hooks = model_io_fn.get_ema_hooks(train_op, # tvars, # kargs.get('params_moving_average_decay', 0.99), # scope, mode, # first_stage_steps=opt_config.num_warmup_steps, # two_stage=True) model_io_fn.set_saver() estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: # _, hooks = model_io_fn.get_ema_hooks(None, # None, # kargs.get('params_moving_average_decay', 0.99), # scope, mode) hooks = None def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape( logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax( logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions) sentence_mean_loss = tf.metrics.mean( values=per_example_loss) sentence_f = tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = { "f1": sentence_f, "acc":sentence_accuracy } return eval_metric_ops eval_metric_ops = metric_fn( per_example_loss, logits, label_ids) eval_hooks = [hooks] if hooks else [] estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops, evaluation_hooks=eval_hooks ) if output_type == "sess": return { "eval":{ "per_example_loss":per_example_loss, "logits":logits, "loss":tf.reduce_mean(per_example_loss) } } elif output_type == "estimator": return estimator_spec else: raise NotImplementedError()
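# Editor's sketch: the commented-out get_ema_hooks calls above suggest an
# optional exponential moving average over parameters, applied after each
# optimizer step and restored at eval time. A minimal TF1 sketch of that idea,
# assuming decay corresponds to params_moving_average_decay; this is not the
# repo's actual hook machinery.
import tensorflow as tf

def ema_train_op_sketch(train_op, tvars, decay=0.99):
    ema = tf.train.ExponentialMovingAverage(decay=decay)
    with tf.control_dependencies([train_op]):
        # update shadow copies of tvars once the optimizer step has run
        return ema.apply(tvars)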
def model_fn(features, labels, mode): model_api = model_zoo(model_config) model_lst = [] assert len(target.split(",")) == 2 target_name_lst = target.split(",") print(target_name_lst) for index, name in enumerate(target_name_lst): if index > 0: reuse = True else: reuse = model_reuse model_lst.append(model_api(model_config, features, labels, mode, name, reuse=reuse)) label_ids = features["label_ids"] if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope with tf.variable_scope(scope, reuse=model_reuse): seq_output_lst = [model.get_pooled_output() for model in model_lst] if model_config.get("classifier", "order_classifier") == "order_classifier": [loss, per_example_loss, logits] = classifier.order_classifier( model_config, seq_output_lst, num_labels, label_ids, dropout_prob, ratio_weight=None) elif model_config.get("classifier", "order_classifier") == "siamese_interaction_classifier": [loss, per_example_loss, logits] = classifier.siamese_classifier( model_config, seq_output_lst, num_labels, label_ids, dropout_prob, ratio_weight=None) model_io_fn = model_io.ModelIO(model_io_config) params_size = model_io_fn.count_params(model_config.scope) print("==total params==", params_size) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) print(tvars) if load_pretrained == "yes": model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver() if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks=training_hooks) if output_type == "sess": return { "train":{ "loss":loss, "logits":logits, "train_op":train_op }, "hooks":training_hooks } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) max_prob = tf.reduce_max(prob, axis=-1) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'pred_label':pred_label, "max_prob":max_prob }, export_outputs={ "output":tf.estimator.export.PredictOutput( { 'pred_label':pred_label, "max_prob":max_prob } ) } ) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape( logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax( logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions)
sentence_mean_loss = tf.metrics.mean( values=per_example_loss) sentence_f = tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = { "f1": sentence_f, "acc":sentence_accuracy } return eval_metric_ops eval_metric_ops = metric_fn( per_example_loss, logits, label_ids) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) if output_type == "sess": return { "eval":{ "per_example_loss":per_example_loss, "logits":logits, "loss":tf.reduce_mean(per_example_loss), "feature":(seq_output_lst[0]+seq_output_lst[1])/2 } } elif output_type == "estimator": return estimator_spec else: raise NotImplementedError()
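# A sketch of the interaction pattern that `classifier.siamese_classifier`
# is assumed to follow for the two-tower model above: the pooled encodings
# u and v are combined as [u, v, |u - v|, u * v] before the softmax
# projection (the InferSent-style interaction). Illustrative only.
import tensorflow as tf

def siamese_interaction_logits(seq_output_lst, num_labels, dropout_prob):
    u, v = seq_output_lst  # pooled outputs of the two towers
    features = tf.concat([u, v, tf.abs(u - v), u * v], axis=-1)
    features = tf.nn.dropout(features, keep_prob=1.0 - dropout_prob)
    return tf.layers.dense(features, num_labels)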
def model_fn(features, labels, mode): train_ops = [] train_hooks = [] logits_dict = {} losses_dict = {} features_dict = {} tvars = [] task_num_dict = {} total_loss = tf.constant(0.0) task_num = 0 encoder = {} hook_dict = {} print(task_type_dict.keys(), "==task type dict==") num_task = len(task_type_dict) for index, task_type in enumerate(task_type_dict.keys()): if model_config_dict[task_type].model_type in model_type_lst: reuse = True else: reuse = None model_type_lst.append(model_config_dict[task_type].model_type) if task_type_dict[task_type] == "cls_task": if model_config_dict[task_type].model_type not in encoder: model_api = model_zoo(model_config_dict[task_type]) model = model_api(model_config_dict[task_type], features, labels, mode, target_dict[task_type], reuse=reuse) encoder[model_config_dict[task_type].model_type] = model print(encoder, "==encode==") task_model_fn = cls_model_fn( encoder[model_config_dict[task_type].model_type], model_config_dict[task_type], num_labels_dict[task_type], init_checkpoint_dict[task_type], reuse, load_pretrained_dict[task_type], model_io_config, opt_config, exclude_scope=exclude_scope_dict[task_type], not_storage_params=not_storage_params_dict[task_type], target=target_dict[task_type], label_lst=None, output_type=output_type, task_layer_reuse=task_layer_reuse, task_type=task_type, num_task=num_task, task_adversarial=1e-2, **kargs) print("==SUCCEEDED IN LODING==", task_type) result_dict = task_model_fn(features, labels, mode) logits_dict[task_type] = result_dict["logits"] losses_dict[task_type] = result_dict["loss"] # task loss for key in [ "masked_lm_loss", "task_loss", "acc", "task_acc", "masked_lm_acc" ]: name = "{}_{}".format(task_type, key) if name in result_dict: hook_dict[name] = result_dict[name] hook_dict["{}_loss".format(task_type)] = result_dict["loss"] total_loss += result_dict["loss"] if mode == tf.estimator.ModeKeys.TRAIN: tvars.extend(result_dict["tvars"]) task_num += result_dict["task_num"] task_num_dict[task_type] = result_dict["task_num"] elif mode == tf.estimator.ModeKeys.EVAL: features[task_type] = result_dict["feature"] else: continue hook_dict["total_loss"] = total_loss if mode == tf.estimator.ModeKeys.TRAIN: model_io_fn = model_io.ModelIO(model_io_config) optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(list(set(tvars)), string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) print("==update_ops==", update_ops) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( total_loss, list(set(tvars)), opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver(optimizer_fn.opt) if kargs.get("task_index", 1) == 0 and kargs.get( "run_config", None): model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook elif kargs.get("task_index", 1) == 1: training_hooks = [] else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) if output_type == "sess": return { "train": { "total_loss": total_loss, "loss": losses_dict, "logits": logits_dict, "train_op": train_op, "task_num_dict": task_num_dict }, "hooks": train_hooks } elif output_type == "estimator": hook_dict['learning_rate'] = optimizer_fn.learning_rate logging_hook = tf.train.LoggingTensorHook(hook_dict, every_n_iter=100) training_hooks.append(logging_hook) 
print("==hook_dict==") print(hook_dict) for key in hook_dict: tf.summary.scalar(key, hook_dict[key]) for index, task_type in enumerate(task_type_dict.keys()): tmp = "{}_loss".format(task_type) if tmp == key: tf.summary.scalar( "loss_gap_{}".format(task_type), hook_dict["total_loss"] - hook_dict[key]) for key in task_num_dict: tf.summary.scalar(key + "_task_num", task_num_dict[key]) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op) # training_hooks=training_hooks) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: # eval execute for each class solo def metric_fn(logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions) sentence_f = tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy} return eval_metric_ops if output_type == "sess": return { "eval": { "logits": logits_dict, "total_loss": total_loss, "feature": features, "loss": losses_dict } } elif output_type == "estimator": eval_metric_ops = {} for key in logits: eval_dict = metric_fn(logits[key], features_task_dict[key]["label_ids"]) for sub_key in eval_dict.keys(): eval_key = "{}_{}".format(key, sub_key) eval_metric_ops[eval_key] = eval_dict[sub_key] estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=total_loss / task_num, eval_metric_ops=eval_metric_ops) return estimator_spec else: raise NotImplementedError()
def model_fn(features, labels, mode, params): original_loss = tf.constant(0.0) distilled_loss = tf.constant(0.0) st_model = st_model_fn(model_config_dict['student'], num_labels_dict['student'], init_checkpoint_dict['student'], model_reuse=None, load_pretrained=load_pretrained_dict['student'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('student', ""), not_storage_params=not_storage_params_dict.get('student', []), target=target_dict['student'], **kargs) st_dict = st_model(features, labels, mode, params) # ta_model = ta_model_fn(model_config_dict['teacher'], # num_labels_dict['teacher'], # init_checkpoint_dict['teacher'], # model_reuse=None, # load_pretrained=load_pretrained_dict['teacher'], # model_io_config=model_io_config, # opt_config=opt_config, # exclude_scope=exclude_scope_dict.get('teacher', ""), # not_storage_params=not_storage_params_dict.get('teacher', []), # target=target_dict['teacher'], # **kargs) # ta_features = {} # for key in features: # ta_features[key] = features[key] # ta_features['masked_lm_mask'] = st_dict['masked_lm_mask'] # ta_features['input_ids'] = st_dict['output_ids'] # ta_features['input_ori_ids'] = features['input_ids'] # ta_dict = ta_model(ta_features, labels, mode, params) # studnet_logit = st_dict['logits'] # teacher_logit = ta_dict['logits'] model_io_fn = model_io.ModelIO(model_io_config) original_loss += st_dict['loss'] * (distillation_config.get('ce_loss', 1.0)) print(distillation_config.get('ce_loss', 1.0), '===ce_loss===') if not kargs.get('use_tpu', False): tf.summary.scalar("ce_loss", st_dict['loss']) hook_dict = {} # if 'kl_logits' in distillation_config.get('distillation_type', ['kl_logits']): # temperature = distillation_config.get('kl_temperature', 2.0) # distilled_teacher_logit = tf.nn.log_softmax((teacher_logit+1e-10) / temperature) # log_softmax logits # distilled_student_logit = tf.nn.log_softmax((studnet_logit+1e-10) / temperature) # log_softmax logits # logits_mask = tf.cast(st_dict['masked_lm_mask'], tf.float32) # kl_distilled_loss = distillation_utils.kd(distilled_teacher_logit, # distilled_student_logit) # kl_distilled_loss = tf.reduce_sum(logits_mask*kl_distilled_loss) / tf.reduce_sum(logits_mask) # if not kargs.get('use_tpu', False): # tf.summary.scalar("kl_logits_loss", kl_distilled_loss) # tf.summary.scalar("kl_logits_mask", tf.reduce_mean(logits_mask)) # tf.logging.info("***** with knowledge distillation %s tenperature *****", str(temperature)) # hook_dict['kl_logits_loss'] = kl_distilled_loss # # kl_distilled_loss *= np.power(temperature, 2) # distilled_loss += kl_distilled_loss * distillation_config.get('kl_logits', 0.9) # print(distillation_config.get('kl_logits_ratio', 0.9), '===kl_logits_ratio===') # if "attention_score_uniform" in distillation_config.get('distillation_type', ['kl_logits']): # source_attention_score = ta_dict['model'].get_multihead_attention() # target_attention_score = st_dict['model'].get_multihead_attention() # print("==apply attention_score_uniform==") # with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): # attention_loss = uniform_mapping.attention_score_matching(source_attention_score, # target_attention_score, # features['input_mask'], # 0) # tf.summary.scalar("attention_score_uniform_loss", attention_loss) # distilled_loss += attention_loss * distillation_config.get("attention_score_uniform", 0.1) # hook_dict['attention_mse_loss'] = attention_loss # print(distillation_config.get('attention_score_uniform', 0.1), '===attention_score_uniform===') # if 
"hidden_uniform" in distillation_config.get('distillation_type', ['kl_logits']): # source_hidden = ta_dict['model'].get_all_encoder_layers() # target_hidden = st_dict['model'].get_all_encoder_layers() # print("==apply hidden_uniform==") # with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): # hidden_loss = uniform_mapping.hidden_matching(source_hidden, target_hidden, # features['input_mask'], # 0) # if not kargs.get('use_tpu', False): # tf.summary.scalar("hidden_uniform_loss", hidden_loss) # distilled_loss += hidden_loss * distillation_config.get("hidden_uniform", 0.1) # hook_dict['hidden_loss'] = hidden_loss # print(distillation_config.get('hidden_uniform', 0.1), '===hidden_uniform===') # if "embedding_distillation" in distillation_config.get('distillation_type', ['embedding_distillation']): # st_word_embed = st_dict['model'].get_embedding_table() # ta_word_embed = ta_dict['model'].get_embedding_table() # st_word_embed_shape = bert_utils.get_shape_list(st_word_embed, expected_rank=[2,3]) # print("==random_embed_shape==", st_word_embed_shape) # ta_word_embed_shape = bert_utils.get_shape_list(ta_word_embed, expected_rank=[2,3]) # print("==pretrain_embed_shape==", ta_word_embed_shape) # if st_word_embed_shape[-1] != ta_word_embed_shape[-1]: # with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): # with tf.variable_scope("embedding_proj"): # proj_embed = tf.layers.dense(ta_word_embed, st_word_embed_shape[-1]) # else: # proj_embed = ta_word_embed # embed_loss = tf.reduce_mean(tf.reduce_mean(tf.square(proj_embed-st_word_embed), axis=-1)) # distilled_loss += embed_loss # hook_dict['embed_loss'] = embed_loss # tf.logging.info("****** apply prertained feature distillation *******") total_loss = distilled_loss + original_loss tvars = [] tvars.extend(st_dict['tvars']) distillation_vars = model_io_fn.get_params('distillation', not_storage_params=[]) tvars.extend(distillation_vars) # if kargs.get('update_ta', False): # total_loss += ta_dict['loss'] # tvars.extend(ta_dict['tvars']) if not kargs.get('use_tpu', False): student_eval_metrics = train_metric_fn( st_dict['masked_lm_example_loss'], st_dict['logits'], st_dict["masked_lm_ids"], st_dict['masked_lm_mask'], 'student') # teacher_eval_metric = train_metric_fn( # ta_dict['masked_lm_example_loss'], # ta_dict['logits'], # ta_dict["masked_lm_ids"], # ta_dict['masked_lm_mask'], # 'teacher') # student_eval_metrics.update(teacher_eval_metric) for key in student_eval_metrics: hook_dict[key] = student_eval_metrics[key] if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) print("==update_ops==", update_ops) print('==total trainable vars==', list(tvars)) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op(total_loss, list(set(tvars)), opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver() if kargs.get("task_index", 1) == 1 and kargs.get("run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 1: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] logging_hook = tf.train.LoggingTensorHook( hook_dict, every_n_iter=100) training_hooks.append(logging_hook) if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", 
kargs.get("task_index", 1)) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, train_op=train_op, training_hooks=training_hooks) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: student_eval_metrics = metric_fn( st_dict['masked_lm_example_loss'], st_dict['logits'], st_dict["masked_lm_ids"], st_dict['masked_lm_mask'], 'student') # teacher_eval_metric = metric_fn( # ta_dict['masked_lm_example_loss'], # ta_dict['logits'], # ta_dict["masked_lm_ids"], # ta_dict['masked_lm_mask'], # 'teacher') # student_eval_metrics.update(teacher_eval_metric) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=total_loss, eval_metric_ops=student_eval_metrics) return estimator_spec else: raise NotImplementedError()
def model_fn(features, labels, mode): model = bert_encoder(model_config, features, labels, mode, target, reuse=model_reuse) label_ids = features["label_ids"] if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope with tf.variable_scope(scope, reuse=model_reuse): (loss, per_example_loss, logits) = classifier.classifier(model_config, model.get_pooled_output(), num_labels, label_ids, dropout_prob) model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) if load_pretrained: model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) model_io_fn.set_saver(var_lst=tvars) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( loss, tvars, opt_config.init_lr, opt_config.num_train_steps) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op) if output_type == "sess": return { "train": { "loss": loss, "logits": logits, "train_op": train_op } } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) max_prob = tf.reduce_max(prob, axis=-1) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'pred_label': pred_label, "max_prob": max_prob }, export_outputs={ "output": tf.estimator.export.PredictOutput({ 'pred_label': pred_label, "max_prob": max_prob }) }) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions) sentence_mean_loss = tf.metrics.mean(values=per_example_loss) sentence_f = tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy} return eval_metric_ops eval_metric_ops = metric_fn(per_example_loss, logits, label_ids) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) if output_type == "sess": return { "eval": { "per_example_loss": per_example_loss, "logits": logits, "loss": tf.reduce_mean(per_example_loss) } } elif output_type == "estimator": return estimator_spec else: raise NotImplementedError()
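# The PREDICT branch above wires pred_label/max_prob into
# tf.estimator.export.PredictOutput. To export a SavedModel the estimator
# additionally needs a serving input receiver; a minimal sketch, assuming
# the feature names used throughout this file and a hypothetical
# max_length parameter:
import tensorflow as tf

def serving_input_receiver_fn(max_length=128):
    features = {
        "input_ids": tf.placeholder(tf.int32, [None, max_length], name="input_ids"),
        "input_mask": tf.placeholder(tf.int32, [None, max_length], name="input_mask"),
        "segment_ids": tf.placeholder(tf.int32, [None, max_length], name="segment_ids"),
        "label_ids": tf.placeholder(tf.int32, [None], name="label_ids"),
    }
    return tf.estimator.export.ServingInputReceiver(features, features)

# e.g. estimator.export_savedmodel(export_dir,
#                                  lambda: serving_input_receiver_fn(128))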
def model_fn(features, labels, mode, params): model_api = model_zoo(model_config) seq_features = {} for key in features: seq_features[key] = features[key] seq_features['input_ids'] = features["input_ori_ids"] model = model_api(model_config, seq_features, labels, mode, target, reuse=tf.AUTO_REUSE, **kargs) if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope sequence_mask = tf.to_float( tf.not_equal(features['input_ori_ids'][:, 1:], kargs.get('[PAD]', 0))) # batch x seq_length print(model.get_sequence_output_logits().get_shape(), "===logits shape===") seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=features['input_ori_ids'][:, 1:], logits=model.get_sequence_output_logits()[:, :-1]) per_example_loss = tf.reduce_sum(seq_loss * sequence_mask, axis=-1) / ( tf.reduce_sum(sequence_mask, axis=-1) + 1e-10) loss = tf.reduce_mean(per_example_loss) model_io_fn = model_io.ModelIO(model_io_config) pretrained_tvars = model_io_fn.get_params( model_config.scope, not_storage_params=not_storage_params) lm_pretrain_tvars = model_io_fn.get_params( "cls/predictions", not_storage_params=not_storage_params) pretrained_tvars.extend(lm_pretrain_tvars) use_tpu = 1 if kargs.get('use_tpu', False) else 0 if load_pretrained == "yes": use_tpu = 1 if kargs.get('use_tpu', False) else 0 scaffold_fn = model_io_fn.load_pretrained( pretrained_tvars, init_checkpoint, exclude_scope=exclude_scope, use_tpu=use_tpu) tf.logging.info("***** using tpu *****") else: scaffold_fn = None tf.logging.info("***** not using tpu *****") if mode == tf.estimator.ModeKeys.TRAIN: if kargs.get('use_tpu', False): optimizer_fn = optimizer.Optimizer(opt_config) use_tpu = 1 tf.logging.info( "***** using tpu with tpu-captiable optimizer *****") else: optimizer_fn = distributed_optimizer.Optimizer(opt_config) use_tpu = 0 tf.logging.info( "***** using gpu with gpu-captiable optimizer *****") tvars = pretrained_tvars model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( loss, tvars, opt_config.init_lr, opt_config.num_train_steps, use_tpu=use_tpu) train_metric_dict = train_metric( features['input_ori_ids'], model.get_sequence_output_logits(), **kargs) if not kargs.get('use_tpu', False): for key in train_metric_dict: tf.summary.scalar(key, train_metric_dict[key]) tf.summary.scalar('learning_rate', optimizer_fn.learning_rate) tf.logging.info("***** logging metric *****") tf.summary.scalar("causal_attenion_mask_length", tf.reduce_sum(model.attention_mask)) tf.summary.scalar("bi_attenion_mask_length", tf.reduce_sum(model.bi_attention_mask)) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, train_op=train_op, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: gpu_eval_metrics = eval_metric(features['input_ori_ids'], model.get_sequence_output_logits()) tpu_eval_metrics = (eval_metric, [ features['input_ori_ids'], model.get_sequence_output_logits() ]) if kargs.get('use_tpu', False): estimator_spec = tf.contrib.tpu.TPUEstimatorSpec( mode=mode, loss=loss, eval_metrics=tpu_eval_metrics, scaffold_fn=scaffold_fn) else: estimator_spec = tf.estimator.EstimatorSpec( 
                mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)
        return estimator_spec
    else:
        raise NotImplementedError()
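# TPUEstimatorSpec takes eval metrics as a deferred (metric_fn, tensors)
# pair, while the plain EstimatorSpec takes the already-constructed dict,
# which is why the code above builds both gpu_eval_metrics and
# tpu_eval_metrics from the same function. A stripped-down illustration
# with generic labels/logits:
import tensorflow as tf

def eval_metric(labels, logits):
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    return {"acc": tf.metrics.accuracy(labels=labels,
                                       predictions=predictions)}

# GPU/CPU path: dict of metric ops
# spec = tf.estimator.EstimatorSpec(mode, loss=loss,
#                                   eval_metric_ops=eval_metric(labels, logits))
# TPU path: deferred (fn, args) pair, evaluated host-side
# spec = tf.contrib.tpu.TPUEstimatorSpec(mode, loss=loss,
#                                        eval_metrics=(eval_metric, [labels, logits]))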
def model_fn(features, labels, mode): original_loss = tf.constant(0.0) distilled_loss = tf.constant(0.0) st_model = st_model_fn( model_config_dict['student'], num_labels_dict['student'], init_checkpoint_dict['student'], model_reuse=None, load_pretrained=load_pretrained_dict['student'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('student', ""), not_storage_params=not_storage_params_dict.get('student', []), target=target_dict['student'], **kargs) st_dict = st_model(features, labels, mode) ta_model = ta_model_fn( model_config_dict['teacher'], num_labels_dict['teacher'], init_checkpoint_dict['teacher'], model_reuse=None, load_pretrained=load_pretrained_dict['teacher'], model_io_config=model_io_config, opt_config=opt_config, exclude_scope=exclude_scope_dict.get('teacher', ""), not_storage_params=not_storage_params_dict.get('teacher', []), target=target_dict['teacher'], **kargs) ta_dict = ta_model(features, labels, mode) studnet_logit = st_dict['logits'] teacher_logit = ta_dict['logits'] model_io_fn = model_io.ModelIO(model_io_config) feature_flag = False original_loss += st_dict['loss'] * (distillation_config.get( 'ce_loss', 1.0)) print(distillation_config.get('ce_loss', 1.0), '===ce_loss===') tf.summary.scalar("ce_loss", st_dict['loss']) if 'kl_logits' in distillation_config.get('distillation_type', ['kl_logits']): temperature = distillation_config.get('kl_temperature', 2.0) distilled_teacher_logit = tf.nn.log_softmax( (teacher_logit + 1e-10) / temperature) # log_softmax logits distilled_student_logit = tf.nn.log_softmax( (studnet_logit + 1e-10) / temperature) # log_softmax logits kl_distilled_loss = tf.reduce_mean( distillation_utils.kd(distilled_teacher_logit, distilled_student_logit)) tf.summary.scalar("kl_logits_loss", kl_distilled_loss) tf.logging.info( "***** with knowledge distillation %s tenperature *****", str(temperature)) # kl_distilled_loss *= np.power(temperature, 2) distilled_loss += kl_distilled_loss * distillation_config.get( 'kl_logits_ratio', 0.9) print(distillation_config.get('kl_logits_ratio', 0.9), '===kl_logits_ratio===') if 'rkd' in distillation_config.get('distillation_type', ['kl_logits']): source = ta_dict['model'].get_pooled_output() target = st_dict['model'].get_pooled_output() print("==apply rkd==") with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): rkd_loss = repo_distillation_utils.RKD(source, target, l=[25, 50]) tf.summary.scalar("rkd_loss", rkd_loss) distilled_loss += rkd_loss * distillation_config.get( "rkd_ratio", 0.1) if "attention_score_uniform" in distillation_config.get( 'distillation_type', ['kl_logits']): source_attention_score = ta_dict['model'].get_multihead_attention() target_attention_score = st_dict['model'].get_multihead_attention() print("==apply attention_score_uniform==") with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): attention_loss = uniform_mapping.attention_score_matching( source_attention_score, target_attention_score, features['input_mask'], 0) tf.summary.scalar("attention_score_uniform_loss", attention_loss) feature_flag = True distilled_loss += attention_loss * distillation_config.get( "attention_score_uniform", 0.1) print(distillation_config.get('attention_score_uniform', 0.1), '===attention_score_uniform===') if "hidden_uniform" in distillation_config.get('distillation_type', ['kl_logits']): source_hidden = ta_dict['model'].get_all_encoder_layers() target_hidden = st_dict['model'].get_all_encoder_layers() print("==apply hidden_uniform==") with 
tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): hidden_loss = uniform_mapping.hidden_matching( source_hidden, target_hidden, features['input_mask'], 0) tf.summary.scalar("hidden_uniform_loss", hidden_loss) distilled_loss += hidden_loss * distillation_config.get( "hidden_uniform", 0.1) feature_flag = True print(distillation_config.get('hidden_uniform', 0.1), '===hidden_uniform===') if "hidden_cls_uniform" in distillation_config.get( 'distillation_type', ['kl_logits']): source_hidden = ta_dict['model'].get_all_encoder_layers() target_hidden = st_dict['model'].get_all_encoder_layers() print("==apply hidden_cls_uniform==") with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): hidden_cls_loss = uniform_mapping.hidden_cls_matching( source_hidden, target_hidden, 0) tf.summary.scalar("hidden_cls_uniform_loss", hidden_cls_loss) distilled_loss += hidden_cls_loss * distillation_config.get( "hidden_uniform", 0.1) feature_flag = True if "mdd" in distillation_config.get('distillation_type', ['mdd']): source = ta_dict['model'].get_pooled_output() target = st_dict['model'].get_pooled_output() print("==apply mdd==") if "cpc" in distillation_config.get('distillation_type', ['mdd']): source_hidden = ta_dict['model'].get_all_encoder_layers() target_hidden = st_dict['model'].get_all_encoder_layers() with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): cpc_loss = cpc_utils.CPC_Hidden(target_hidden, source_hidden, features['input_mask']) tf.summary.scalar("hidden_cpc_loss", cpc_loss) distilled_loss += cpc_loss + distillation_config.get( "cpc_hidden", 0.1) if "wpc" in distillation_config.get('distillation_type', ['mdd']): source_hidden = ta_dict['model'].get_all_encoder_layers() target_hidden = st_dict['model'].get_all_encoder_layers() with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE): wpc_loss = cpc_utils.WPC_Hidden(target_hidden, source_hidden, features['input_mask']) tf.summary.scalar("hidden_wpc_loss", wpc_loss) distilled_loss += wpc_loss + distillation_config.get( "wpc_hidden", 0.1) total_loss = distilled_loss + original_loss tvars = [] tvars.extend(st_dict['tvars']) if feature_flag: distillation_vars = model_io_fn.get_params('distillation', not_storage_params=[]) tvars.extend(distillation_vars) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) print("==update_ops==", update_ops) print('==total trainable vars==', list(tvars)) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( total_loss, list(set(tvars)), opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver() if kargs.get("task_index", 1) == 0 and kargs.get( "run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=total_loss, train_op=train_op, training_hooks=training_hooks) if output_type == "sess": return { "train": { "loss": total_loss, "logits": studnet_logit, "train_op": train_op }, "hooks": training_hooks } elif output_type == "estimator": return estimator_spec elif mode == 
tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids, model_type):
            """Computes the loss and accuracy of the model."""
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels_dict['student'],
                                       None, average="macro")
            eval_metric_ops = {
                "{}_f1".format(model_type): sentence_f,
                "{}_acc".format(model_type): sentence_accuracy
            }
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": st_dict['per_example_loss'],
                    "logits": studnet_logit,
                    "loss": tf.reduce_mean(st_dict['per_example_loss']),
                    "feature": st_dict['model'].get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(st_dict['per_example_loss'],
                                        studnet_logit,
                                        features['label_ids'],
                                        "student")
            teacher_eval_metric_ops = metric_fn(ta_dict['per_example_loss'],
                                                teacher_logit,
                                                features['label_ids'],
                                                "teacher")
            eval_metric_ops.update(teacher_eval_metric_ops)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                eval_metric_ops=eval_metric_ops)
            return estimator_spec
    else:
        raise NotImplementedError()
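# `uniform_mapping.hidden_matching` above is assumed to align teacher and
# student hidden states layer by layer, with a learned projection inside
# the "distillation" scope when widths differ (which is why those
# variables are appended to tvars when feature_flag is set). A minimal
# masked-MSE sketch of one such layer match, under those assumptions:
import tensorflow as tf

def hidden_layer_match(teacher_hidden, student_hidden, input_mask):
    # project the teacher width to the student width if they differ
    student_dim = student_hidden.get_shape().as_list()[-1]
    if teacher_hidden.get_shape().as_list()[-1] != student_dim:
        teacher_hidden = tf.layers.dense(teacher_hidden, student_dim)
    # zero out padding positions before averaging
    mask = tf.cast(input_mask, tf.float32)[:, :, None]
    squared_error = tf.reduce_sum(
        tf.square(teacher_hidden - student_hidden) * mask)
    return squared_error / (tf.reduce_sum(mask) * student_dim + 1e-10)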
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint, train_file, dev_file, checkpoint_dir, is_debug): graph = tf.Graph() with graph.as_default(): import json config = json.load(open(FLAGS.config_file, "r")) config = Bunch(config) config.use_one_hot_embeddings = True config.scope = "bert" config.dropout_prob = 0.1 config.label_type = "single_label" if FLAGS.if_shard == "0": train_size = FLAGS.train_size epoch = int(FLAGS.epoch / worker_count) elif FLAGS.if_shard == "1": train_size = int(FLAGS.train_size / worker_count) epoch = FLAGS.epoch init_lr = 2e-5 label_dict = json.load(open(FLAGS.label_id)) num_train_steps = int(train_size / FLAGS.batch_size * epoch) num_warmup_steps = int(num_train_steps * 0.1) num_storage_steps = int(train_size / FLAGS.batch_size) num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size) if is_debug == "0": num_storage_steps = 2 num_eval_steps = 10 num_train_steps = 10 print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}". format(num_train_steps, num_eval_steps, num_storage_steps)) print(" model type {}".format(FLAGS.model_type)) print(num_train_steps, num_warmup_steps, "=============") opt_config = Bunch({ "init_lr": init_lr / worker_count, "num_train_steps": num_train_steps, "num_warmup_steps": num_warmup_steps, "worker_count": worker_count, "opt_type": FLAGS.opt_type }) model_io_config = Bunch({"fix_lm": False}) model_io_fn = model_io.ModelIO(model_io_config) optimizer_fn = optimizer.Optimizer(opt_config) num_classes = FLAGS.num_classes model_train_fn = model_fn_builder(config, num_classes, init_checkpoint, model_reuse=None, load_pretrained=True, model_io_fn=model_io_fn, optimizer_fn=optimizer_fn, model_io_config=model_io_config, opt_config=opt_config, exclude_scope="", not_storage_params=[], target="") model_eval_fn = model_fn_builder(config, num_classes, init_checkpoint, model_reuse=True, load_pretrained=True, model_io_fn=model_io_fn, optimizer_fn=optimizer_fn, model_io_config=model_io_config, opt_config=opt_config, exclude_scope="", not_storage_params=[], target="") if FLAGS.opt_type == "ps": sync_replicas_hook = optimizer_fn.opt.make_session_run_hook( is_chief, num_tokens=0) else: sync_replicas_hook = [] def eval_metric_fn(features, eval_op_dict): logits = eval_op_dict["logits"] print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) accuracy = correct = tf.equal( tf.cast(pred_label, tf.int32), tf.cast(features["label_ids"], tf.int32)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) return { "accuracy": accuracy, "loss": eval_op_dict["loss"], "pred_label": pred_label, "label_ids": features["label_ids"] } def train_metric_fn(features, train_op_dict): logits = train_op_dict["logits"] print(logits.get_shape(), "===logits shape===") pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) prob = tf.nn.softmax(logits) accuracy = correct = tf.equal( tf.cast(pred_label, tf.int32), tf.cast(features["label_ids"], tf.int32)) accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) return { "accuracy": accuracy, "loss": train_op_dict["loss"], "train_op": train_op_dict["train_op"] } name_to_features = { "input_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64), "input_mask": tf.FixedLenFeature([FLAGS.max_length], tf.int64), "segment_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64), "label_ids": tf.FixedLenFeature([], tf.int64), } def _decode_record(record, name_to_features): """Decodes a record to a TensorFlow example. 
""" example = tf.parse_single_example(record, name_to_features) # tf.Example only supports tf.int64, but the TPU only supports tf.int32. # So cast all int64 to int32. for name in list(example.keys()): t = example[name] if t.dtype == tf.int64: t = tf.to_int32(t) example[name] = t return example params = Bunch({}) params.epoch = FLAGS.epoch params.batch_size = FLAGS.batch_size train_features = tf_data_utils.train_input_fn( train_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) eval_features = tf_data_utils.eval_input_fn(dev_file, _decode_record, name_to_features, params, if_shard=FLAGS.if_shard, worker_count=worker_count, task_index=task_index) train_op_dict = model_train_fn(train_features, [], tf.estimator.ModeKeys.TRAIN) eval_op_dict = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL) eval_dict = eval_metric_fn(eval_features, eval_op_dict["eval"]) train_dict = train_metric_fn(train_features, train_op_dict["train"]) print("===========begin to train============") sess_config = tf.ConfigProto(allow_soft_placement=False, log_device_placement=False) checkpoint_dir = checkpoint_dir if task_index == 0 else None print("start training") # hooks = [tf.train.StopAtStepHook(last_step=num_train_steps)] hooks = [] if FLAGS.opt_type == "ps": sync_replicas_hook = optimizer_fn.opt.make_session_run_hook( is_chief, num_tokens=0) hooks.append(sync_replicas_hook) sess = tf.train.MonitoredTrainingSession( master=target, is_chief=is_chief, config=sess_config, hooks=hooks, checkpoint_dir=checkpoint_dir, save_checkpoint_steps=num_storage_steps) else: sess = tf.train.MonitoredTrainingSession( config=sess_config, hooks=[], checkpoint_dir=checkpoint_dir, save_checkpoint_steps=num_storage_steps) def eval_fn(eval_dict, sess): i = 0 total_accuracy = 0 eval_total_dict = {} while True: try: eval_result = sess.run(eval_dict) for key in eval_result: if key not in eval_total_dict: if key in ["pred_label", "label_ids"]: eval_total_dict[key] = [] eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] = 0.0 eval_total_dict[key] += eval_result[key] else: if key in ["pred_label", "label_ids"]: eval_total_dict[key].extend(eval_result[key]) if key in ["accuracy", "loss"]: eval_total_dict[key] += eval_result[key] i += 1 if np.mod(i, num_eval_steps) == 0: break except tf.errors.OutOfRangeError: print("End of dataset") break label_id = eval_total_dict["label_ids"] pred_label = eval_total_dict["pred_label"] result = classification_report(label_id, pred_label, target_names=list( label_dict["label2id"].keys())) print(result, task_index) eval_total_dict["classification_report"] = result return eval_total_dict def train_fn(train_op_dict, sess): i = 0 cnt = 0 loss_dict = {} monitoring_train = [] monitoring_eval = [] while True: try: [train_result] = sess.run([train_op_dict]) step = sess.run(tf.train.get_global_step()) for key in train_result: if key == "train_op": continue else: if np.isnan(train_result[key]): print(train_loss, "get nan loss") break else: if key in loss_dict: loss_dict[key] += train_result[key] else: loss_dict[key] = train_result[key] i += 1 cnt += 1 if np.mod(i, num_storage_steps) == 0: string = "" for key in loss_dict: tmp = key + " " + str(loss_dict[key] / cnt) + "\t" string += tmp print(string, step) monitoring_train.append(loss_dict) eval_finial_dict = eval_fn(eval_dict, sess) monitoring_eval.append(eval_finial_dict) for key in loss_dict: loss_dict[key] = 0.0 cnt = 0 if is_debug == "0": if i 
== num_train_steps: break except tf.errors.OutOfRangeError: print("==Succeeded in training model==") # print("===========begin to train============") # sess_config = tf.ConfigProto(allow_soft_placement=False, # log_device_placement=False) # checkpoint_dir = checkpoint_dir if task_index == 0 else None # print("start training") # hooks = [tf.train.StopAtStepHook(last_step=num_train_steps)] # if sync_replicas_hook: # hooks.append(sync_replicas_hook) # sess = tf.train.MonitoredTrainingSession(master=target, # is_chief=is_chief, # config=sess_config, # hooks=[], # checkpoint_dir=checkpoint_dir, # save_checkpoint_steps=num_storage_steps) # with tf.train.MonitoredTrainingSession(master=target, # is_chief=is_chief, # config=sess_config, # hooks=[], # checkpoint_dir=checkpoint_dir, # save_checkpoint_steps=num_storage_steps) as sess: step = sess.run(optimizer_fn.global_step) print(step) train_fn(train_dict, sess) if task_index == 0: print("===========begin to eval============") eval_finial_dict = eval_fn(eval_dict, sess)
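# The step bookkeeping in train_eval_fn is plain arithmetic; for example,
# with hypothetical values train_size=100000, batch_size=32, epoch=3 and
# a single worker:
train_size, batch_size, epoch = 100000, 32, 3
num_train_steps = int(train_size / batch_size * epoch)   # 9375
num_warmup_steps = int(num_train_steps * 0.1)            # 937
num_storage_steps = int(train_size / batch_size)         # 3125, i.e. once per epoch
print(num_train_steps, num_warmup_steps, num_storage_steps)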
def model_fn(features, labels, mode): model = gpt_encoder(model_config, features, labels, mode, target, reuse=tf.AUTO_REUSE) scope = model_config.scope if mode == tf.estimator.ModeKeys.TRAIN: # batch x seq_length sequence_mask = tf.to_float(tf.not_equal(features['input_ids'][:, 1:], kargs.get('[PAD]', 0))) # batch x seq_length seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=features['input_ids'][:, 1:], logits=model.get_sequence_output_logits()[:, :-1]) per_example_loss = tf.reduce_sum(seq_loss*sequence_mask, axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1)+1e-10) loss = tf.reduce_mean(per_example_loss) model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) try: params_size = model_io_fn.count_params(model_config.scope) print("==total params==", params_size) except: print("==not count params==") print(tvars) if load_pretrained == "yes": model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) print("==update_ops==", update_ops) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op(loss, tvars, opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver() if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) train_metric_dict = train_metric(features['input_ids'], model.get_sequence_output_logits(), **kargs) for key in train_metric_dict: tf.summary.scalar(key, train_metric_dict[key]) tf.summary.scalar('learning_rate', optimizer_fn.learning_rate) tf.summary.scalar('seq_length', tf.reduce_mean(tf.reduce_sum(sequence_mask, axis=-1))) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op, training_hooks=training_hooks) print(tf.global_variables(), "==global_variables==") if output_type == "sess": return { "train":{ "loss":loss, "logits":logits, "train_op":train_op }, "hooks":training_hooks } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: if kargs.get('predict_type', 'sample_sequence') == 'sample_sequence': results = sample.sample_sequence( gpt_encoder, hparams=model_config, length=kargs.get('max_length', 64), start_token=None, batch_size=10, context=features['input_ids'], temperature=2, top_k=10) sampled_token = results['tokens'][:, 1:] sampled_token_logits = results['logits'][:, 1:] estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'token':sampled_token, "logits":sampled_token_logits }, export_outputs={ "output":tf.estimator.export.PredictOutput( { 'token':sampled_token, "logits":sampled_token_logits } ) } ) return estimator_spec elif kargs.get('predict_type', 'sample_sequence') == 'infer_inputs': sequence_mask = tf.to_float(tf.not_equal(features['input_ids'][:, 1:], kargs.get('[PAD]', 0))) output_logits = model.get_sequence_output_logits()[:, :-1] # output_logits = tf.nn.log_softmax(output_logits, axis=-1) 
output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits( labels=features['input_ids'][:, 1:], logits=output_logits) per_example_perplexity = tf.reduce_sum(output_id_logits * sequence_mask, axis=-1) # batch per_example_perplexity /= tf.reduce_sum(sequence_mask, axis=-1) # batch perplexity = tf.exp(per_example_perplexity) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ 'token':features['input_ids'][:, 1:], "logits":output_id_logits, 'perplexity':perplexity }, export_outputs={ "output":tf.estimator.export.PredictOutput( { 'token':features['input_ids'][:,1:], "logits":output_id_logits, 'perplexity':perplexity } ) } ) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape( logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax( logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, predictions=sentence_predictions) sentence_mean_loss = tf.metrics.mean( values=per_example_loss) sentence_f = tf_metrics.f1(label_ids, sentence_predictions, num_labels, label_lst, average="macro") eval_metric_ops = { "f1": sentence_f, "acc":sentence_accuracy } return eval_metric_ops if output_type == "sess": return { "eval":{ "per_example_loss":per_example_loss, "logits":logits, "loss":tf.reduce_mean(per_example_loss), "feature":model.get_pooled_output() } } elif output_type == "estimator": eval_metric_ops = metric_fn( per_example_loss, logits, label_ids) estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss, eval_metric_ops=eval_metric_ops) return estimator_spec else: raise NotImplementedError()
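# The 'infer_inputs' branch above computes per-example perplexity as
# exp(mean token negative log-likelihood) over non-PAD positions. The
# same quantity, isolated into a small helper:
import tensorflow as tf

def per_example_perplexity(token_nll, sequence_mask):
    # token_nll: [batch, seq-1] token-level cross-entropy
    # sequence_mask: [batch, seq-1], 1.0 on real tokens, 0.0 on PAD
    mean_nll = tf.reduce_sum(token_nll * sequence_mask, axis=-1) / (
        tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
    return tf.exp(mean_nll)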
def model_fn(features, labels, mode): model_api = model_zoo(model_config) model = model_api(model_config, features, labels, mode, target, reuse=model_reuse) label_ids = features["label_ids"] if mode == tf.estimator.ModeKeys.TRAIN: dropout_prob = model_config.dropout_prob else: dropout_prob = 0.0 if model_io_config.fix_lm == True: scope = model_config.scope + "_finetuning" else: scope = model_config.scope with tf.variable_scope(scope + "/feature_output", reuse=model_reuse): hidden_size = bert_utils.get_shape_list(model.get_pooled_output(), expected_rank=2)[-1] feature_output = tf.layers.dense( model.get_pooled_output(), hidden_size, kernel_initializer=tf.truncated_normal_initializer( stddev=0.01)) feature_output = tf.nn.dropout(feature_output, keep_prob=1 - dropout_prob) feature_output += model.get_pooled_output() feature_output = tf.layers.dense( feature_output, hidden_size, kernel_initializer=tf.truncated_normal_initializer( stddev=0.01), activation=tf.tanh) if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL: with tf.variable_scope(scope, reuse=model_reuse): (loss, per_example_loss, logits) = classifier.classifier(model_config, feature_output, num_labels, label_ids, dropout_prob) model_io_fn = model_io.ModelIO(model_io_config) tvars = model_io_fn.get_params(model_config.scope, not_storage_params=not_storage_params) print(tvars) if load_pretrained == "yes": model_io_fn.load_pretrained(tvars, init_checkpoint, exclude_scope=exclude_scope) if mode == tf.estimator.ModeKeys.TRAIN: optimizer_fn = optimizer.Optimizer(opt_config) model_io_fn.print_params(tvars, string=", trainable params") update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) with tf.control_dependencies(update_ops): train_op = optimizer_fn.get_train_op( loss, tvars, opt_config.init_lr, opt_config.num_train_steps, **kargs) model_io_fn.set_saver() if kargs.get("task_index", 1) == 0 and kargs.get( "run_config", None): training_hooks = [] elif kargs.get("task_index", 1) == 0: model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), kargs.get("num_storage_steps", 1000)) training_hooks = model_io_fn.checkpoint_hook else: training_hooks = [] if len(optimizer_fn.distributed_hooks) >= 1: training_hooks.extend(optimizer_fn.distributed_hooks) print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1)) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, loss=loss, train_op=train_op, training_hooks=training_hooks) if output_type == "sess": return { "train": { "loss": loss, "logits": logits, "train_op": train_op }, "hooks": training_hooks } elif output_type == "estimator": return estimator_spec elif mode == tf.estimator.ModeKeys.PREDICT: # print(logits.get_shape(), "===logits shape===") # pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32) # prob = tf.nn.softmax(logits) # max_prob = tf.reduce_max(prob, axis=-1) estimator_spec = tf.estimator.EstimatorSpec( mode=mode, predictions={ # 'pred_label':pred_label, # "max_prob":max_prob "embedding": feature_output }, export_outputs={ "output": tf.estimator.export.PredictOutput( {"embedding": feature_output}) }) return estimator_spec elif mode == tf.estimator.ModeKeys.EVAL: def metric_fn(per_example_loss, logits, label_ids): """Computes the loss and accuracy of the model.""" sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]]) sentence_predictions = tf.argmax(logits, axis=-1, output_type=tf.int32) sentence_labels = tf.reshape(label_ids, [-1]) sentence_accuracy = tf.metrics.accuracy( labels=label_ids, 
                predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
            return estimator_spec
    else:
        raise NotImplementedError()
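# The "embedding" exported by the PREDICT branch above is a residual
# dense head on the pooled output; a typical downstream use is cosine
# similarity between two exported vectors. A small numpy sketch with
# hypothetical inputs:
import numpy as np

def cosine_similarity(a, b):
    # a, b: [batch, dim] embeddings fetched from the exported model
    a = a / (np.linalg.norm(a, axis=-1, keepdims=True) + 1e-10)
    b = b / (np.linalg.norm(b, axis=-1, keepdims=True) + 1e-10)
    return np.sum(a * b, axis=-1)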