def model_fn(features, labels, mode):
    model = bert_encoder(model_config, features, labels,
                         mode, target, reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    # temperature-scaled log-probabilities for distillation
    temperature_log_prob = tf.nn.log_softmax(
        logits / kargs.get("temperature", 2))
    return [loss, per_example_loss, logits, temperature_log_prob]
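# A minimal NumPy sketch of the temperature trick above: dividing logits
# by T > 1 flattens the distribution before log-softmax, the standard
# first step of knowledge distillation. The helper name and numbers are
# illustrative, not part of this repo:
import numpy as np

def soft_log_probs(logits, T=2.0):
    """Temperature-scaled log-softmax (numerically stable)."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

# soft_log_probs(np.array([[2.0, 0.5, -1.0]]))  # flatter than with T=1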
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")

    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope='teacher')

    return_dict = {
        "loss": loss,
        "logits": logits,
        "tvars": tvars,
        "model": model,
        "per_example_loss": per_example_loss
    }
    return return_dict
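# Every model_fn in this file closes over names such as model_config,
# opt_config, model_reuse, output_type and kargs. A hedged sketch of the
# kind of builder assumed to supply them (the name and argument list are
# illustrative, not the repo's exact signature):
def classifier_model_fn_builder(model_config, num_labels, init_checkpoint,
                                load_pretrained="yes", model_reuse=None,
                                model_io_config=None, opt_config=None,
                                exclude_scope="", not_storage_params=None,
                                target="a", output_type="estimator",
                                **kargs):
    def model_fn(features, labels, mode):
        ...  # a body like the ones in this file
    return model_fn

# usage sketch:
# estimator = tf.estimator.Estimator(
#     model_fn=classifier_model_fn_builder(...), model_dir=checkpoint_dir)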
def adversarial_loss(model_config, feature, adv_ids, dropout_prob,
                     model_reuse, **kargs):
    '''train an adversarial head so that the task classifier cannot
    reliably predict the task from the shared feature
    '''
    # input = tf.stop_gradient(input)
    feature = tf.nn.dropout(feature, 1 - dropout_prob)
    with tf.variable_scope(model_config.scope + "/adv_classifier",
                           reuse=model_reuse):
        (adv_loss, adv_per_example_loss, adv_logits) = classifier.classifier(
            model_config, feature, kargs.get('adv_num_labels', 7),
            adv_ids, dropout_prob)
    return (adv_loss, adv_per_example_loss, adv_logits)
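# adversarial_loss trains a task discriminator on the shared feature; to
# push the encoder toward task-invariant features, such a loss is
# usually combined with a gradient reversal layer (GRL). A minimal
# sketch (flip_gradient is an illustrative name, not a repo function):
def flip_gradient(x, lam=1.0):
    """Identity in the forward pass; scales the gradient by -lam."""
    return tf.stop_gradient((1.0 + lam) * x) - lam * x

# usage sketch:
# (adv_loss, _, _) = adversarial_loss(
#     model_config, flip_gradient(shared_feature), adv_ids,
#     dropout_prob, model_reuse)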
def build_discriminator(model, scope, reuse):
    with tf.variable_scope(scope, reuse=reuse):
        try:
            label_ratio_table = tf.get_variable(
                name="label_ratio",
                initializer=tf.constant(label_tensor),
                trainable=False)
            ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                  label_ids)
            print("==applying class weight==")
        except Exception:
            ratio_weight = None
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob, ratio_weight)
    return loss, per_example_loss, logits
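# A hedged sketch of how a label_tensor class-weight table like the one
# consumed above could be derived from label counts; the inverse-
# frequency scheme is an assumption, not repo code:
import numpy as np

def make_label_tensor(label_counts):
    """Inverse-frequency class weights, normalized to mean 1."""
    counts = np.asarray(label_counts, dtype=np.float32)
    return counts.sum() / (len(counts) * counts)

# make_label_tensor([900, 90, 10]) -> [0.37, 3.70, 33.33]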
def model_fn(features, labels, mode):
    print(features)
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = 0.0
        attention_probs_dropout_prob = 0.0
        dropout_prob = 0.0

    model = bert.Bert(model_config)
    model.build_embedder(input_ids, segment_ids,
                         hidden_dropout_prob,
                         attention_probs_dropout_prob,
                         reuse=reuse)
    model.build_encoder(input_ids, input_mask,
                        hidden_dropout_prob,
                        attention_probs_dropout_prob,
                        reuse=reuse)
    model.build_pooler(reuse=reuse)

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    # model_io_fn = model_io.ModelIO(model_io_config)
    pretrained_tvars = model_io_fn.get_params(
        model_config.scope, not_storage_params=not_storage_params)
    if load_pretrained:
        model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)
    tvars = pretrained_tvars
    model_io_fn.set_saver(var_lst=tvars)

    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer_fn = optimizer.Optimizer(opt_config)
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)
        return [train_op, loss, per_example_loss, logits]
    else:
        model_io_fn.print_params(tvars, string=", trainable params")
        return [loss, loss, per_example_loss, logits]
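# The update_ops / control_dependencies pairing above is the canonical
# TF1 pattern for running ops registered in tf.GraphKeys.UPDATE_OPS
# (e.g. batch-norm moving-average updates) together with each train
# step. A self-contained sketch, separate from the repo's classes:
def _update_ops_demo():
    x = tf.placeholder(tf.float32, [None, 8])
    y = tf.placeholder(tf.int32, [None])
    h = tf.layers.batch_normalization(x, training=True)  # adds UPDATE_OPS
    logits = tf.layers.dense(h, 2)
    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y,
                                                       logits=logits))
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):  # BN updates run per step
        train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)
    return train_op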
def model_fn(features, labels, mode):
    # model = bert_encoder(model_config, features, labels,
    #                      mode, target, reuse=model_reuse)
    model = albert_encoder(model_config, features, labels,
                           mode, target, reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)
        # train_op, hooks = model_io_fn.get_ema_hooks(
        #     train_op, tvars,
        #     kargs.get('params_moving_average_decay', 0.99),
        #     scope, mode,
        #     first_stage_steps=opt_config.num_warmup_steps,
        #     two_stage=True)
        model_io_fn.set_saver()
        estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                    loss=loss,
                                                    train_op=train_op)
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        # _, hooks = model_io_fn.get_ema_hooks(
        #     None, None,
        #     kargs.get('params_moving_average_decay', 0.99),
        #     scope, mode)
        hooks = None

        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {
                "f1": sentence_f,
                "acc": sentence_accuracy
            }
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        eval_hooks = [hooks] if hooks else []
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            loss=loss,
            eval_metric_ops=eval_metric_ops,
            evaluation_hooks=eval_hooks)
        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
        else:
            raise NotImplementedError()
def model_fn(features, labels, mode):
    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    label_ids = features["label_ids"]

    model_lst = []
    for index, name in enumerate(target):
        if index > 0:
            reuse = True
        else:
            reuse = model_reuse
        model_lst.append(bert_encoding(model_config, features, labels,
                                       mode, name, scope, dropout_rate,
                                       reuse=reuse))

    [input_mask_a, repres_a] = model_lst[0]
    [input_mask_b, repres_b] = model_lst[1]

    output_a, output_b = alignment_aggerate(model_config,
                                            repres_a, repres_b,
                                            input_mask_a, input_mask_b,
                                            scope, reuse=model_reuse)

    if model_config.pooling == "ave_max_pooling":
        pooling_fn = ave_max_pooling
    elif model_config.pooling == "multihead_pooling":
        pooling_fn = multihead_pooling

    repres_a = pooling_fn(model_config, output_a, input_mask_a,
                          scope, dropout_prob, reuse=model_reuse)
    repres_b = pooling_fn(model_config, output_b, input_mask_b,
                          scope, dropout_prob, reuse=True)

    pair_repres = tf.concat([repres_a, repres_b,
                             tf.abs(repres_a - repres_b),
                             repres_b * repres_a], axis=-1)

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, pair_repres, num_labels, label_ids, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    if load_pretrained:
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)
        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []
        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks)

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)
        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'pred_label': pred_label,
                "max_prob": max_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'pred_label': pred_label,
                    "max_prob": max_prob
                })
            })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {
                "f1": sentence_f,
                "loss": sentence_mean_loss,
                "acc": sentence_accuracy
            }
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
        else:
            raise NotImplementedError()
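# The pair feature above follows the common sentence-matching recipe of
# concatenating both sentence vectors with their absolute difference and
# element-wise product. A NumPy shape illustration (values arbitrary):
import numpy as np

a = np.random.randn(4, 8)  # batch of sentence-A vectors
b = np.random.randn(4, 8)  # batch of sentence-B vectors
pair = np.concatenate([a, b, np.abs(a - b), a * b], axis=-1)
assert pair.shape == (4, 32)  # 4x the encoder width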
def model_fn(features, labels, mode):
    # model_api = model_zoo(model_config)
    # model = model_api(model_config, features, labels,
    #                   mode, target, reuse=model_reuse)
    task_type = kargs.get("task_type", "cls")
    label_ids = features["{}_label_ids".format(task_type)]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                           reuse=task_layer_reuse):
        (_, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    task_mask = tf.cast(features["{}_loss_multiplier".format(task_type)],
                        tf.float32)
    masked_per_example_loss = task_mask * per_example_loss
    loss = tf.reduce_sum(masked_per_example_loss) / (
        1e-10 + tf.reduce_sum(task_mask))
    logits = tf.expand_dims(task_mask, axis=-1) * logits

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    # try:
    #     params_size = model_io_fn.count_params(model_config.scope)
    #     print("==total params==", params_size)
    # except Exception:
    #     print("==not count params==")
    # print(tvars)
    # if load_pretrained == "yes":
    #     model_io_fn.load_pretrained(tvars, init_checkpoint,
    #                                 exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        return {
            "loss": loss,
            "logits": logits,
            "task_num": tf.reduce_sum(task_mask),
            "tvars": tvars
        }
    elif mode == tf.estimator.ModeKeys.EVAL:
        return {
            "loss": loss,
            "logits": logits,
            "feature": model.get_pooled_output()
        }
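# The masked average above is how examples from several tasks share one
# batch: per-example losses are zeroed for rows belonging to other
# tasks, then averaged over surviving rows only (the 1e-10 guards
# against an all-zero mask). A NumPy illustration:
import numpy as np

per_example_loss = np.array([0.7, 1.2, 0.3, 2.0], dtype=np.float32)
task_mask = np.array([1.0, 0.0, 1.0, 0.0], dtype=np.float32)
loss = (task_mask * per_example_loss).sum() / (1e-10 + task_mask.sum())
# loss == (0.7 + 0.3) / 2 == 0.5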
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    label_loss = tf.reduce_sum(
        per_example_loss * features["label_ratio"]) / (
            1e-10 + tf.reduce_sum(features["label_ratio"]))

    if mode == tf.estimator.ModeKeys.TRAIN:
        distillation_api = distill.KnowledgeDistillation(
            kargs.get("disitllation_config",
                      Bunch({
                          "logits_ratio_decay": "constant",
                          "logits_ratio": 0.5,
                          "logits_decay_rate": 0.999,
                          "distillation": ['relation_kd', 'logits'],
                          "feature_ratio": 0.5,
                          "feature_ratio_decay": "constant",
                          "feature_decay_rate": 0.999,
                          "kd_type": "kd",
                          "scope": scope
                      })))
        # teacher logits: log of the teacher's label probabilities,
        # temperature-scaled
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
            "temperature", 2.0)
        # student logits: temperature-scaled log-softmax
        student_logit = tf.nn.log_softmax(
            logits / kargs.get("temperature", 2.0))

        distillation_features = {
            "student_logits_tensor": student_logit,
            "teacher_logits_tensor": teacher_logit,
            "student_feature_tensor": model.get_pooled_output(),
            "teacher_feature_tensor": features["distillation_feature"],
            "student_label": tf.ones_like(label_ids, dtype=tf.int32),
            "teacher_label": tf.zeros_like(label_ids, dtype=tf.int32),
            "logits_ratio": kargs.get("logits_ratio", 0.5),
            # original read the "logits_ratio" key here as well,
            # presumably a copy-paste slip
            "feature_ratio": kargs.get("feature_ratio", 0.5),
            "distillation_ratio": features["distillation_ratio"],
            "src_f_logit": logits,
            "tgt_f_logit": logits,
            "src_tensor": model.get_pooled_output(),
            "tgt_tensor": features["distillation_feature"]
        }
        distillation_loss = distillation_api.distillation(
            distillation_features, 2, dropout_prob, model_reuse,
            opt_config.num_train_steps,
            feature_ratio=1.0,
            logits_ratio_decay="constant",
            feature_ratio_decay="constant",
            feature_decay_rate=0.999,
            logits_decay_rate=0.999,
            logits_ratio=0.5,
            scope=scope + "/adv_classifier",
            num_classes=num_labels,
            gamma=kargs.get("gamma", 4))
        loss = label_loss + distillation_loss["distillation_loss"]

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)
        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []
        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)
        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op,
                    "cross_entropy": label_loss,
                    "distillation_loss":
                        distillation_loss["distillation_loss"],
                    "kd_num": tf.reduce_sum(features["distillation_ratio"]),
                    "ce_num": tf.reduce_sum(features["label_ratio"]),
                    "label_ratio": features["label_ratio"],
                    "distillation_logits_loss":
                        distillation_loss["distillation_logits_loss"],
                    "distillation_feature_loss":
                        distillation_loss["distillation_feature_loss"],
                    "rkd_loss": distillation_loss["rkd_loss"]
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'pred_label': pred_label,
                "max_prob": max_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'pred_label': pred_label,
                    "max_prob": max_prob
                })
            })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
        else:
            raise NotImplementedError()
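# A hedged, self-contained sketch of the Hinton-style logits
# distillation that distill.KnowledgeDistillation is assumed to perform
# above; the repo's actual objective may differ in weighting and
# masking:
def kd_logits_loss(student_logits, teacher_probs, T=2.0):
    """KL(teacher || student) on temperature-softened distributions,
    scaled by T**2 so gradient magnitudes are preserved."""
    teacher_logit = tf.log(teacher_probs + 1e-10) / T
    teacher_p = tf.nn.softmax(teacher_logit)
    kl = tf.reduce_sum(
        teacher_p * (tf.nn.log_softmax(teacher_logit) -
                     tf.nn.log_softmax(student_logits / T)),
        axis=-1)
    return tf.pow(T, 2.0) * tf.reduce_mean(kl)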
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=model_reuse, **kargs)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")
    print(tvars)

    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)
        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []
        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)
        print(tf.global_variables(), "==global_variables==")
        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        # single-label export that also returned the arg-max label:
        # if model_config.get('label_type', 'single_label') == 'single_label':
        #     print(logits.get_shape(), "===logits shape===")
        #     pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        #     prob = tf.nn.softmax(logits)
        #     max_prob = tf.reduce_max(prob, axis=-1)
        #     estimator_spec = tf.estimator.EstimatorSpec(
        #         mode=mode,
        #         predictions={
        #             'pred_label': pred_label,
        #             "max_prob": max_prob
        #         },
        #         export_outputs={
        #             "output": tf.estimator.export.PredictOutput({
        #                 'pred_label': pred_label,
        #                 "max_prob": max_prob
        #             })
        #         })
        if model_config.get('label_type', 'single_label') == 'multi_label':
            prob = tf.nn.sigmoid(logits)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': prob,
                    "max_prob": prob
                },
                export_outputs={
                    "output": tf.estimator.export.PredictOutput({
                        'pred_label': prob,
                        "max_prob": prob
                    })
                })
        elif model_config.get('label_type', 'single_label') == "single_label":
            prob = tf.nn.softmax(logits)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': prob,
                    "max_prob": prob
                },
                export_outputs={
                    "output": tf.estimator.export.PredictOutput({
                        'pred_label': prob,
                        "max_prob": prob
                    })
                })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss),
                    "feature": model.get_pooled_output()
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
            return estimator_spec
        else:
            raise NotImplementedError()
def model_fn(features, labels, mode):
    task_type = kargs.get("task_type", "cls")
    label_ids = features["{}_label_ids".format(task_type)]
    num_task = kargs.get('num_task', 1)

    model_io_fn = model_io.ModelIO(model_io_config)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    if kargs.get("get_pooled_output", "pooled_output") == "pooled_output":
        pooled_feature = model.get_pooled_output()
    elif kargs.get("get_pooled_output", "task_output") == "task_output":
        pooled_feature_dict = model.get_task_output()
        pooled_feature = pooled_feature_dict['pooled_feature']

    loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                        tf.float32)
    loss = tf.constant(0.0)

    params_size = model_io_fn.count_params(model_config.scope)
    print("==total encoder params==", params_size)

    if kargs.get("feature_distillation", True):
        universal_feature_a = features.get("input_ids_a_features", None)
        universal_feature_b = features.get("input_ids_b_features", None)
        if universal_feature_a is None or universal_feature_b is None:
            tf.logging.info("****** not applying feature distillation *******")
            feature_loss = tf.constant(0.0)
        else:
            feature_a = pooled_feature_dict['feature_a']
            feature_a_shape = bert_utils.get_shape_list(
                feature_a, expected_rank=[2, 3])
            pretrain_feature_a_shape = bert_utils.get_shape_list(
                universal_feature_a, expected_rank=[2, 3])
            if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                with tf.variable_scope(scope + "/feature_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_feature_a = tf.layers.dense(
                        feature_a, pretrain_feature_a_shape[-1])
                # with tf.variable_scope(scope + "/feature_rec",
                #                        reuse=tf.AUTO_REUSE):
                #     proj_feature_a_rec = tf.layers.dense(
                #         proj_feature_a, feature_a_shape[-1])
                # loss += tf.reduce_mean(
                #     tf.reduce_sum(tf.square(proj_feature_a_rec - feature_a),
                #                   axis=-1)) / float(num_task)
                tf.logging.info(
                    "****** applying feature projection for compression *******")
            else:
                proj_feature_a = feature_a
            # L2-normalize with the norm treated as a constant
            feature_a_norm = tf.stop_gradient(
                tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_a, 2),
                                      axis=-1, keepdims=True)) + 1e-20)
            proj_feature_a /= feature_a_norm

            feature_b = pooled_feature_dict['feature_b']
            if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                with tf.variable_scope(scope + "/feature_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_feature_b = tf.layers.dense(
                        feature_b, pretrain_feature_a_shape[-1])
                # with tf.variable_scope(scope + "/feature_rec",
                #                        reuse=tf.AUTO_REUSE):
                #     proj_feature_b_rec = tf.layers.dense(
                #         proj_feature_b, feature_a_shape[-1])
                # loss += tf.reduce_mean(
                #     tf.reduce_sum(tf.square(proj_feature_b_rec - feature_b),
                #                   axis=-1)) / float(num_task)
                tf.logging.info(
                    "****** applying feature projection for compression *******")
            else:
                proj_feature_b = feature_b
            feature_b_norm = tf.stop_gradient(
                tf.sqrt(tf.reduce_sum(tf.pow(proj_feature_b, 2),
                                      axis=-1, keepdims=True)) + 1e-20)
            proj_feature_b /= feature_b_norm

            feature_a_distillation = tf.reduce_mean(
                tf.square(universal_feature_a - proj_feature_a), axis=-1)
            feature_b_distillation = tf.reduce_mean(
                tf.square(universal_feature_b - proj_feature_b), axis=-1)
            feature_loss = tf.reduce_mean(
                (feature_a_distillation + feature_b_distillation) / 2.0
            ) / float(num_task)
            loss += feature_loss
            tf.logging.info(
                "****** applying pretrained feature distillation *******")

    if kargs.get("embedding_distillation", True):
        word_embed = model.emb_mat
        random_embed_shape = bert_utils.get_shape_list(
            word_embed, expected_rank=[2, 3])
        print("==random_embed_shape==", random_embed_shape)
        pretrained_embed = kargs.get('pretrained_embed', None)
        if pretrained_embed is None:
            tf.logging.info(
                "****** not applying pretrained embedding distillation *******")
            embed_loss = tf.constant(0.0)
        else:
            pretrain_embed_shape = bert_utils.get_shape_list(
                pretrained_embed, expected_rank=[2, 3])
            print("==pretrain_embed_shape==", pretrain_embed_shape)
            if random_embed_shape[-1] != pretrain_embed_shape[-1]:
                with tf.variable_scope(scope + "/embedding_proj",
                                       reuse=tf.AUTO_REUSE):
                    proj_embed = tf.layers.dense(word_embed,
                                                 pretrain_embed_shape[-1])
            else:
                proj_embed = word_embed
            embed_loss = tf.reduce_mean(
                tf.reduce_mean(tf.square(proj_embed - pretrained_embed),
                               axis=-1)) / float(num_task)
            loss += embed_loss
            tf.logging.info(
                "****** applying pretrained embedding distillation *******")

    with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                           reuse=task_layer_reuse):
        (_, per_example_loss, logits) = classifier.classifier(
            model_config, pooled_feature, num_labels,
            label_ids, dropout_prob)

    masked_per_example_loss = per_example_loss * loss_mask
    task_loss = tf.reduce_sum(masked_per_example_loss) / (
        1e-10 + tf.reduce_sum(loss_mask))
    loss += task_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        multi_task_config = kargs.get("multi_task_config", {})
        if multi_task_config[task_type].get("lm_augumentation", False):
            print("==applying lm_augumentation==")
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss,
             masked_lm_log_probs) = pretrain.get_masked_lm_output(
                 model_config, model.get_sequence_output(),
                 model.get_embedding_table(), masked_lm_positions,
                 masked_lm_ids, masked_lm_weights, reuse=model_reuse)

            masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
                (1, multi_task_config[task_type]["max_predictions_per_seq"]))
            masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))
            masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_loss_mask *= tf.cast(masked_lm_label_weights,
                                           tf.float32)
            masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
            masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
                1e-10 + tf.reduce_sum(masked_lm_loss_mask))
            loss += multi_task_config[task_type][
                "masked_lm_loss_ratio"] * masked_lm_loss

            masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])
            print(masked_lm_log_probs.get_shape(),
                  "===masked lm log probs===")
            print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
            print(masked_lm_label_weights.get_shape(), "===masked lm mask===")
            lm_acc = build_accuracy(masked_lm_log_probs,
                                    masked_lm_label_ids,
                                    masked_lm_loss_mask)

    if kargs.get("task_invariant", "no") == "yes":
        print("==applying task adversarial training==")
        with tf.variable_scope(scope + "/dann_task_invariant",
                               reuse=model_reuse):
            (_, task_example_loss,
             task_logits) = distillation_utils.feature_distillation(
                 model.get_pooled_output(), 1.0, features["task_id"],
                 kargs.get("num_task", 7), dropout_prob, True)
            masked_task_example_loss = loss_mask * task_example_loss
            masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                1e-10 + tf.reduce_sum(loss_mask))
            loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    if mode == tf.estimator.ModeKeys.TRAIN:
        multi_task_config = kargs.get("multi_task_config", {})
        if multi_task_config[task_type].get("lm_augumentation", False):
            print("==applying lm_augumentation==")
            masked_lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)
            tvars.extend(masked_lm_pretrain_tvars)

    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")
    # print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        acc = build_accuracy(logits, label_ids, loss_mask)
        return_dict = {
            "loss": loss,
            "logits": logits,
            "task_num": tf.reduce_sum(loss_mask),
            "tvars": tvars
        }
        return_dict["{}_acc".format(task_type)] = acc
        if kargs.get("task_invariant", "no") == "yes":
            return_dict["{}_task_loss".format(task_type)] = masked_task_loss
            task_acc = build_accuracy(task_logits, features["task_id"],
                                      loss_mask)
            return_dict["{}_task_acc".format(task_type)] = task_acc
        if multi_task_config[task_type].get("lm_augumentation", False):
            return_dict["{}_masked_lm_loss".format(task_type)] = masked_lm_loss
            return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
        if kargs.get("embedding_distillation", True):
            return_dict["embed_loss"] = embed_loss * float(num_task)
        else:
            return_dict["embed_loss"] = task_loss
        if kargs.get("feature_distillation", True):
            return_dict["feature_loss"] = feature_loss * float(num_task)
        else:
            return_dict["feature_loss"] = task_loss
        return_dict["task_loss"] = task_loss
        return return_dict
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_dict = {
            "loss": loss,
            "logits": logits,
            "feature": model.get_pooled_output()
        }
        if kargs.get("adversarial", "no") == "adversarial":
            eval_dict["task_logits"] = task_logits
        return eval_dict
def model_fn(features, labels, mode):
    print(features)
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = 0.0
        attention_probs_dropout_prob = 0.0
        dropout_prob = 0.0

    model = bert.Bert(model_config)
    model.build_embedder(input_ids, segment_ids,
                         hidden_dropout_prob,
                         attention_probs_dropout_prob,
                         reuse=reuse)
    model.build_encoder(input_ids, input_mask,
                        hidden_dropout_prob,
                        attention_probs_dropout_prob,
                        reuse=reuse)
    model.build_pooler(reuse=reuse)

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    # model_io_fn = model_io.ModelIO(model_io_config)
    pretrained_tvars = model_io_fn.get_params(
        model_config.scope, not_storage_params=not_storage_params)
    if load_pretrained:
        model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)
    tvars = pretrained_tvars
    model_io_fn.set_saver(var_lst=tvars)

    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            # optimizer_fn = optimizer.Optimizer(opt_config)
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            # Generate predictions (for PREDICT and EVAL mode)
            "classes": tf.argmax(input=logits, axis=1),
            # Add `softmax_tensor` to the graph. It is used for PREDICT
            # and by the `logging_hook`.
            "probabilities":
                tf.exp(tf.nn.log_softmax(logits, name="softmax_tensor"))
        }
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    elif mode == tf.estimator.ModeKeys.EVAL:
        # EVAL needs manually written metric ops; see
        # https://github.com/google/seq2seq/blob/7f485894d412e8d81ce0e07977831865e44309ce/seq2seq/metrics/metric_specs.py
        return None
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels,
                      mode, target, reuse=tf.AUTO_REUSE)
    # model_adv_config = copy.deepcopy(model_config)
    # model_adv_config.scope = model_config.scope + "/adv_encoder"
    # model_adv_adaptation = model_api(model_adv_config, features, labels,
    #                                  mode, target, reuse=tf.AUTO_REUSE)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    common_feature = model.get_pooled_output()
    task_feature = get_task_feature(model_config, common_feature,
                                    dropout_prob, scope + "/task_residual",
                                    if_grl=False)
    adv_task_feature = get_task_feature(model_config, common_feature,
                                        dropout_prob, scope + "/adv_residual",
                                        if_grl=True)

    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        concat_feature = task_feature
        # concat_feature = tf.concat([task_feature,
        #                             adv_task_feature],
        #                            axis=-1)
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, concat_feature, num_labels,
            label_ids, dropout_prob, **kargs)  # was *kargs, a dict-unpacking bug

    with tf.variable_scope(scope + "/adv_classifier", reuse=tf.AUTO_REUSE):
        adv_ids = features["adv_ids"]
        (adv_loss, adv_per_example_loss, adv_logits) = classifier.classifier(
            model_config, adv_task_feature,
            kargs.get('adv_num_labels', 12), adv_ids,
            dropout_prob, **kargs)

    if mode == tf.estimator.ModeKeys.TRAIN:
        loss_diff = tf.constant(0.0)
        # adv_task_feature_no_grl = get_task_feature(
        #     model_config, common_feature, dropout_prob,
        #     scope + "/adv_residual")
        # loss_diff = diff_loss(task_feature, adv_task_feature_no_grl)

        print(kargs.get("temperature", 0.5),
              kargs.get("distillation_ratio", 0.5),
              "==distillation hyperparameter==")
        # teacher logits: log of the teacher's label probabilities,
        # temperature-scaled
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
            "temperature", 2.0)
        # student logits: temperature-scaled log-softmax
        student_logit = tf.nn.log_softmax(
            logits / kargs.get("temperature", 2.0))

        distillation_loss = kd_distance(
            teacher_logit, student_logit,
            kargs.get("distillation_distance", "kd"))
        distillation_loss *= features["distillation_ratio"]
        distillation_loss = tf.reduce_sum(distillation_loss) / (
            1e-10 + tf.reduce_sum(features["distillation_ratio"]))
        label_loss = loss
        # label_loss = tf.reduce_sum(
        #     per_example_loss * features["label_ratio"]) / (
        #         1e-10 + tf.reduce_sum(features["label_ratio"]))
        print("==distillation loss ratio==",
              kargs.get("distillation_ratio", 0.9) *
              tf.pow(kargs.get("temperature", 2.0), 2))
        # loss = label_loss + kargs.get("distillation_ratio", 0.9) * \
        #     tf.pow(kargs.get("temperature", 2.0), 2) * distillation_loss
        loss = (1 - kargs.get("distillation_ratio", 0.9)) * label_loss + \
            kargs.get("distillation_ratio", 0.9) * distillation_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        loss += kargs.get("adv_ratio", 0.1) * adv_loss + loss_diff

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr,
                opt_config.num_train_steps, **kargs)
        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []
        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)
        if output_type == "sess":
            adv_pred_label = tf.argmax(adv_logits, axis=-1,
                                       output_type=tf.int32)
            adv_correct = tf.equal(tf.cast(adv_pred_label, tf.int32),
                                   tf.cast(adv_ids, tf.int32))
            adv_accuracy = tf.reduce_mean(tf.cast(adv_correct, tf.float32))
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op,
                    "cross_entropy": label_loss,
                    "kd_loss": distillation_loss,
                    "kd_num": tf.reduce_sum(features["distillation_ratio"]),
                    "ce_num": tf.reduce_sum(features["label_ratio"]),
                    "teacher_logit": teacher_logit,
                    "student_logit": student_logit,
                    "label_ratio": features["label_ratio"],
                    "loss_diff": loss_diff,
                    "adv_loss": adv_loss,
                    "adv_accuracy": adv_accuracy
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        task_prob = tf.exp(tf.nn.log_softmax(logits))
        adv_prob = tf.exp(tf.nn.log_softmax(adv_logits))
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'adv_prob': adv_prob,
                "task_prob": task_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'adv_prob': adv_prob,
                    "task_prob": task_prob
                })
            })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {
                "f1": sentence_f,
                "acc": sentence_accuracy
            }
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
        else:
            raise NotImplementedError()
def model_fn(features, labels, mode):
    input_ids = features["input_ids"]
    input_mask = features["input_mask"]
    segment_ids = features["segment_ids"]
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = 0.0
        attention_probs_dropout_prob = 0.0
        dropout_prob = 0.0

    model = bert.Bert(model_config)
    model.build_embedder(input_ids, segment_ids,
                         hidden_dropout_prob,
                         attention_probs_dropout_prob,
                         reuse=reuse)
    model.build_encoder(input_ids, input_mask,
                        hidden_dropout_prob,
                        attention_probs_dropout_prob,
                        reuse=reuse)
    model.build_pooler(reuse=reuse)

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    # model_io_fn = model_io.ModelIO(model_io_config)
    pretrained_tvars = model_io_fn.get_params(model_config.scope)
    if load_pretrained:
        model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)
    tvars = model_io_fn.get_params(scope)

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn.print_params(tvars, string=", trainable params")
        with tf.control_dependencies(update_ops):
            optimizer_fn = optimizer.Optimizer(opt_config)
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)

    # note: even in TRAIN mode this model_fn returns a prediction-only
    # EstimatorSpec; the train_op built above is not attached to it
    print(logits.get_shape(), "===logits shape===")
    pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
    prob = tf.nn.softmax(logits)
    max_prob = tf.reduce_max(prob, axis=-1)
    output_spec = tf.estimator.EstimatorSpec(
        mode=mode,
        predictions={
            'pred_label': pred_label,
            "label_ids": label_ids,
            "max_prob": max_prob
        })
    return output_spec
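# A hedged usage sketch for a prediction-flavored model_fn like the one
# above; the Estimator construction and input_fn are assumptions:
# est = tf.estimator.Estimator(model_fn=model_fn, model_dir=ckpt_dir)
# for pred in est.predict(input_fn=eval_input_fn):
#     print(pred["pred_label"], pred["max_prob"], pred["label_ids"])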
def model_fn(features, labels, mode):
    label_ids = features["label_ids"]

    model_lst = []
    for index, name in enumerate(input_name):
        if index > 0:
            reuse = True
        else:
            reuse = model_reuse
        model_lst.append(base_model(model_config, features, labels,
                                    mode, name, reuse=reuse))

    if mode == tf.estimator.ModeKeys.TRAIN:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = 0.0
        attention_probs_dropout_prob = 0.0
        dropout_prob = 0.0

    assert len(model_lst) == len(input_name)

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        try:
            label_ratio_table = tf.get_variable(
                name="label_ratio",
                shape=[num_labels, ],
                initializer=tf.constant(label_tensor),
                trainable=False)
            ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                  label_ids)
        except Exception:
            ratio_weight = None

        seq_output_lst = [model.get_pooled_output() for model in model_lst]
        repres = seq_output_lst[0] + seq_output_lst[1]
        final_hidden_shape = bert_utils.get_shape_list(repres,
                                                       expected_rank=2)
        z_mean = tf.layers.dense(repres, final_hidden_shape[1],
                                 name="z_mean")
        z_log_var = tf.layers.dense(repres, final_hidden_shape[1],
                                    name="z_log_var")
        print("=======applying vib============")
        if mode == tf.estimator.ModeKeys.TRAIN:
            print("====applying vib====")
            vib_connector = vib.VIB(vib_config)
            [kl_loss, latent_vector] = vib_connector.build_regularizer(
                [z_mean, z_log_var])
            [loss, per_example_loss, logits] = classifier.classifier(
                model_config, latent_vector, num_labels,
                label_ids, dropout_prob, ratio_weight)
            loss += tf.reduce_mean(kl_loss)
        else:
            print("====applying z_mean for prediction====")
            [loss, per_example_loss, logits] = classifier.classifier(
                model_config, z_mean, num_labels,
                label_ids, dropout_prob, ratio_weight)

    # model_io_fn = model_io.ModelIO(model_io_config)
    pretrained_tvars = model_io_fn.get_params(model_config.scope)
    if load_pretrained:
        model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)
    tvars = model_io_fn.get_params(scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer_fn = optimizer.Optimizer(opt_config)
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)
        return [train_op, loss, per_example_loss, logits]
    else:
        model_io_fn.print_params(tvars, string=", trainable params")
        return [loss, loss, per_example_loss, logits]
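# A hedged sketch of the VIB machinery assumed behind vib.VIB above:
# sample z = mu + sigma * eps with the reparameterization trick and
# penalize KL(q(z|x) || N(0, I)). Illustrative, not the repo's API:
def vib_regularizer(z_mean, z_log_var):
    eps = tf.random_normal(tf.shape(z_mean))
    latent = z_mean + tf.exp(0.5 * z_log_var) * eps
    kl = -0.5 * tf.reduce_sum(
        1.0 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var), axis=-1)
    return kl, latent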
def model_fn(features, labels, mode):
    model = bert_encoder(model_config, features, labels,
                         mode, target, reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    if load_pretrained:
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)
    model_io_fn.set_saver(var_lst=tvars)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                loss, tvars, opt_config.init_lr, opt_config.num_train_steps)
        estimator_spec = tf.estimator.EstimatorSpec(mode=mode, loss=loss,
                                                    train_op=train_op)
        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                }
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                'pred_label': pred_label,
                "max_prob": max_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    'pred_label': pred_label,
                    "max_prob": max_prob
                })
            })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
        else:
            raise NotImplementedError()
def model_fn(features, labels, mode):
    task_type = kargs.get("task_type", "cls")
    label_ids = features["{}_label_ids".format(task_type)]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                           reuse=task_layer_reuse):
        (_, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(),
            num_labels, label_ids, dropout_prob)

    loss_mask = tf.cast(features["{}_loss_multiplier".format(task_type)],
                        tf.float32)
    masked_per_example_loss = per_example_loss * loss_mask
    loss = tf.reduce_sum(masked_per_example_loss) / (
        1e-10 + tf.reduce_sum(loss_mask))

    if mode == tf.estimator.ModeKeys.TRAIN:
        multi_task_config = kargs.get("multi_task_config", {})
        if multi_task_config[task_type].get("lm_augumentation", False):
            print("==applying lm_augumentation==")
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss,
             masked_lm_log_probs) = pretrain.get_masked_lm_output(
                 model_config, model.get_sequence_output(),
                 model.get_embedding_table(), masked_lm_positions,
                 masked_lm_ids, masked_lm_weights, reuse=model_reuse)

            masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
                (1, multi_task_config[task_type]["max_predictions_per_seq"]))
            masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))
            masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_loss_mask *= tf.cast(masked_lm_label_weights,
                                           tf.float32)
            masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
            masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
                1e-10 + tf.reduce_sum(masked_lm_loss_mask))
            loss += multi_task_config[task_type][
                "masked_lm_loss_ratio"] * masked_lm_loss

            masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])
            print(masked_lm_log_probs.get_shape(),
                  "===masked lm log probs===")
            print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
            print(masked_lm_label_weights.get_shape(), "===masked lm mask===")
            lm_acc = build_accuracy(masked_lm_log_probs,
                                    masked_lm_label_ids,
                                    masked_lm_loss_mask)

    if kargs.get("task_invariant", "no") == "yes":
        print("==applying task adversarial training==")
        with tf.variable_scope(scope + "/dann_task_invariant",
                               reuse=model_reuse):
            (_, task_example_loss,
             task_logits) = distillation_utils.feature_distillation(
                 model.get_pooled_output(), 1.0, features["task_id"],
                 kargs.get("num_task", 7), dropout_prob, True)
            masked_task_example_loss = loss_mask * task_example_loss
            masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                1e-10 + tf.reduce_sum(loss_mask))
            loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

    model_io_fn = model_io.ModelIO(model_io_config)
    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    if mode == tf.estimator.ModeKeys.TRAIN:
        multi_task_config = kargs.get("multi_task_config", {})
        if multi_task_config[task_type].get("lm_augumentation", False):
            print("==applying lm_augumentation==")
            masked_lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)
            tvars.extend(masked_lm_pretrain_tvars)

    try:
        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)
    except Exception:
        print("==not count params==")
    # print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        acc = build_accuracy(logits, label_ids, loss_mask)
        return_dict = {
            "loss": loss,
            "logits": logits,
            "task_num": tf.reduce_sum(loss_mask),
            "tvars": tvars
        }
        return_dict["{}_acc".format(task_type)] = acc
        if kargs.get("task_invariant", "no") == "yes":
            return_dict["{}_task_loss".format(task_type)] = masked_task_loss
            task_acc = build_accuracy(task_logits, features["task_id"],
                                      loss_mask)
            return_dict["{}_task_acc".format(task_type)] = task_acc
        if multi_task_config[task_type].get("lm_augumentation", False):
            return_dict["{}_masked_lm_loss".format(task_type)] = masked_lm_loss
            return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
        return return_dict
    elif mode == tf.estimator.ModeKeys.EVAL:
        eval_dict = {
            "loss": loss,
            "logits": logits,
            "feature": model.get_pooled_output()
        }
        if kargs.get("adversarial", "no") == "adversarial":
            eval_dict["task_logits"] = task_logits
        return eval_dict
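# build_accuracy is used throughout but not defined in this file; a
# hedged sketch consistent with its call sites (logits or log-probs,
# integer labels, float mask), under an assumed masked-mean convention
# and deliberately given a different name to avoid shadowing:
def _build_accuracy_sketch(logits, labels, mask):
    pred = tf.argmax(logits, axis=-1, output_type=tf.int32)
    correct = tf.cast(tf.equal(pred, tf.cast(labels, tf.int32)),
                      tf.float32)
    return tf.reduce_sum(correct * mask) / (1e-10 + tf.reduce_sum(mask))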
def model_fn(features, labels, mode):
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = 0.0
        attention_probs_dropout_prob = 0.0
        dropout_prob = 0.0

    model = base_model(model_config, features, labels, mode, reuse=reuse)

    with tf.variable_scope(model_config.scope, reuse=reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(), num_labels, label_ids,
            dropout_prob)

    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = pretrain.get_masked_lm_output(
             model_config, model.get_sequence_output(),
             model.get_embedding_table(), masked_lm_positions,
             masked_lm_ids, masked_lm_weights, reuse=reuse)
        # joint objective: weighted sum of the masked-LM loss and the
        # sentence-classification loss
        total_loss = model_config.lm_ratio * masked_lm_loss + \
            model_config.task_ratio * loss
    else:
        total_loss = loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        # the MLM head lives under "cls/predictions" and must also be
        # restored from the checkpoint
        masked_lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)
        pretrained_tvars.extend(masked_lm_pretrain_tvars)

        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint,
                                        exclude_scope=exclude_scope)

        tvars = pretrained_tvars
        model_io_fn.print_params(tvars, string=", trainable params")

        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(total_loss, tvars,
                                                 opt_config.init_lr,
                                                 opt_config.num_train_steps)

        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=total_loss,
                                          train_op=train_op)
    elif mode == tf.estimator.ModeKeys.PREDICT:
        def prediction_fn(logits):
            predictions = {
                "classes": tf.argmax(input=logits, axis=1),
                "probabilities": tf.exp(
                    tf.nn.log_softmax(logits, name="softmax_tensor"))
            }
            return predictions

        predictions = prediction_fn(logits)
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                      masked_lm_ids, masked_lm_weights, per_example_loss,
                      logits, label_ids):
            """Computes the loss and accuracy of the model."""
            masked_lm_log_probs = tf.reshape(
                masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
            masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                              axis=-1,
                                              output_type=tf.int32)
            masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
            masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
            masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
            masked_lm_accuracy = tf.metrics.accuracy(
                labels=masked_lm_ids,
                predictions=masked_lm_predictions,
                weights=masked_lm_weights)
            masked_lm_mean_loss = tf.metrics.mean(
                values=masked_lm_example_loss, weights=masked_lm_weights)

            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            return {
                "masked_lm_accuracy": masked_lm_accuracy,
                "masked_lm_loss": masked_lm_mean_loss,
                "sentence_f": sentence_f,
                "sentence_loss": sentence_mean_loss
            }

        eval_metric_ops = metric_fn(masked_lm_example_loss,
                                    masked_lm_log_probs, masked_lm_ids,
                                    masked_lm_weights, per_example_loss,
                                    logits, label_ids)
        return tf.estimator.EstimatorSpec(mode=mode,
                                          loss=total_loss,
                                          eval_metric_ops=eval_metric_ops)
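pretrain.get_masked_lm_output comes from the repo's pretrain module and is not shown here. The core step in any such implementation is gathering the encoder states at the masked positions before projecting them through the embedding table; a standard BERT-style sketch of that gather (helper name and dtypes assumed):

import tensorflow as tf

def gather_indexes(sequence_tensor, positions):
    # sequence_tensor: [batch, seq_len, hidden]; positions: [batch, num_masked]
    # Flatten the batch and sequence dims, then gather the masked slots.
    positions = tf.cast(positions, tf.int32)
    batch_size = tf.shape(sequence_tensor)[0]
    seq_length = tf.shape(sequence_tensor)[1]
    hidden = sequence_tensor.get_shape().as_list()[-1]
    flat_offsets = tf.reshape(
        tf.range(batch_size, dtype=tf.int32) * seq_length, [-1, 1])
    flat_positions = tf.reshape(positions + flat_offsets, [-1])
    flat_sequence = tf.reshape(sequence_tensor, [-1, hidden])
    # [batch * num_masked, hidden] states at the masked positions
    return tf.gather(flat_sequence, flat_positions)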
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    label_ids = features["label_ids"]

    model = model_api(model_config, features, labels, mode, target,
                      reuse=model_reuse)

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope, reuse=model_reuse):
        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, model.get_pooled_output(), num_labels, label_ids,
            dropout_prob)

    print(kargs.get("temperature", 0.5),
          kargs.get("distillation_ratio", 0.5),
          "==distillation hyperparameter==")

    # teacher log-probabilities, smoothed and temperature-scaled
    teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
        "temperature", 2.0)
    # student log-probabilities at the same temperature
    student_logit = tf.nn.log_softmax(logits / kargs.get("temperature", 2.0))

    distillation_loss = kd_distance(teacher_logit, student_logit,
                                    kargs.get("distillation_distance", "kd"))
    # only distill on examples flagged by distillation_ratio
    distillation_loss *= features["distillation_ratio"]
    distillation_loss = tf.reduce_sum(distillation_loss) / (
        1e-10 + tf.reduce_sum(features["distillation_ratio"]))

    # only apply the hard-label cross-entropy on examples flagged by label_ratio
    label_loss = tf.reduce_sum(per_example_loss * features["label_ratio"]) / (
        1e-10 + tf.reduce_sum(features["label_ratio"]))

    print("==distillation loss ratio==",
          kargs.get("distillation_ratio", 0.9) *
          tf.pow(kargs.get("temperature", 2.0), 2))

    # convex combination of hard- and soft-target losses; the soft-target
    # term is rescaled by T^2 as in Hinton et al.'s distillation objective
    loss = (1 - kargs.get("distillation_ratio", 0.9)) * label_loss + \
        tf.pow(kargs.get("temperature", 2.0), 2) * \
        kargs.get("distillation_ratio", 0.9) * distillation_loss

    model_io_fn = model_io.ModelIO(model_io_config)

    params_size = model_io_fn.count_params(model_config.scope)
    print("==total params==", params_size)

    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(loss, tvars,
                                                 opt_config.init_lr,
                                                 opt_config.num_train_steps,
                                                 **kargs)

        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op,
                    "cross_entropy": label_loss,
                    "kd_loss": distillation_loss,
                    "kd_num": tf.reduce_sum(features["distillation_ratio"]),
                    "ce_num": tf.reduce_sum(features["label_ratio"]),
                    "teacher_logit": teacher_logit,
                    "student_logit": student_logit,
                    "label_ratio": features["label_ratio"]
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={
                "pred_label": pred_label,
                "max_prob": max_prob
            },
            export_outputs={
                "output": tf.estimator.export.PredictOutput({
                    "pred_label": pred_label,
                    "max_prob": max_prob
                })
            })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            return estimator_spec
    else:
        raise NotImplementedError()
def model_fn(features, labels, mode):
    if model_io_config.fix_lm == True:
        scope = model_config.scope + "/task"
    else:
        scope = model_config.scope

    if mode == tf.estimator.ModeKeys.TRAIN:
        hidden_dropout_prob = model_config.hidden_dropout_prob
        attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
        dropout_prob = model_config.dropout_prob
    else:
        hidden_dropout_prob = 0.0
        attention_probs_dropout_prob = 0.0
        dropout_prob = 0.0

    label_ids = features["label_ids"]
    target = kargs["target"]

    [a_mask, a_repres, b_mask, b_repres] = bert_lstm_encoding(
        model_config, features, labels, mode, target, max_len, scope,
        dropout_prob, reuse=model_reuse)

    a_repres = lstm_model(model_config, a_repres, a_mask, dropout_prob,
                          scope, model_reuse)
    b_repres = lstm_model(model_config, b_repres, b_mask, dropout_prob,
                          scope, True)

    a_output, b_output = alignment(model_config, a_repres, b_repres,
                                   a_mask, b_mask, scope, reuse=model_reuse)

    repres_a = bert_multihead_pooling(model_config, a_output, a_mask,
                                      scope, dropout_prob, reuse=model_reuse)
    repres_b = bert_multihead_pooling(model_config, b_output, b_mask,
                                      scope, dropout_prob, reuse=True)

    # symmetric matching features: [a; b; |a - b|; a * b]
    pair_repres = tf.concat([
        repres_a, repres_b,
        tf.abs(repres_a - repres_b), repres_b * repres_a
    ], axis=-1)
    print(pair_repres.get_shape(), "==repres shape==")

    with tf.variable_scope(scope, reuse=model_reuse):
        try:
            # tf.get_variable must not be given an explicit shape when the
            # initializer is already a Tensor; the shape is inferred from
            # label_tensor
            label_ratio_table = tf.get_variable(
                name="label_ratio",
                initializer=tf.constant(label_tensor),
                trainable=False)
            ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                  label_ids)
            print("==applying class weight==")
        except:
            ratio_weight = None

        (loss, per_example_loss, logits) = classifier.classifier(
            model_config, pair_repres, num_labels, label_ids, dropout_prob,
            ratio_weight)

    if mode == tf.estimator.ModeKeys.TRAIN:
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        if load_pretrained:
            model_io_fn.load_pretrained(
                pretrained_tvars, init_checkpoint,
                exclude_scope=exclude_scope_dict["task"])

    trainable_params = model_io_fn.get_params(
        scope, not_storage_params=not_storage_params)
    tvars = trainable_params

    storage_params = model_io_fn.get_params(
        model_config.scope, not_storage_params=not_storage_params)

    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer_fn = optimizer.Optimizer(opt_config)
            train_op = optimizer_fn.get_train_op(loss, tvars,
                                                 opt_config.init_lr,
                                                 opt_config.num_train_steps)
        return [train_op, loss, per_example_loss, logits]
    else:
        model_io_fn.print_params(tvars, string=", trainable params")
        return [loss, loss, per_example_loss, logits]
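alignment above is also not defined in this section. A hypothetical ESIM-style co-attention sketch that matches the call signature; the repo's actual implementation may differ:

import tensorflow as tf

def alignment(model_config, a_repres, b_repres, a_mask, b_mask,
              scope, reuse=None):
    # Soft alignment between two sequences: each token in a attends over
    # b and vice versa. a_repres: [batch, len_a, hidden],
    # b_repres: [batch, len_b, hidden]; masks are 0/1 over tokens.
    with tf.variable_scope(scope + "/alignment", reuse=reuse):
        # [batch, len_a, len_b] dot-product similarity matrix
        attention = tf.matmul(a_repres, b_repres, transpose_b=True)
        mask_b = tf.expand_dims(tf.cast(b_mask, tf.float32), axis=1)
        mask_a = tf.expand_dims(tf.cast(a_mask, tf.float32), axis=2)
        # mask out padded positions with a large negative bias
        attention_a = tf.nn.softmax(
            attention + (1.0 - mask_b) * -1e30, axis=-1)
        attention_b = tf.nn.softmax(
            attention + (1.0 - mask_a) * -1e30, axis=1)
        a_output = tf.matmul(attention_a, b_repres)   # b aligned to a
        b_output = tf.matmul(attention_b, a_repres,   # a aligned to b
                             transpose_a=True)
        return a_output, b_output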
def model_fn(features, labels, mode):
    model_api = model_zoo(model_config)
    model = model_api(model_config, features, labels, mode, target,
                      reuse=model_reuse)
    label_ids = features["label_ids"]

    if mode == tf.estimator.ModeKeys.TRAIN:
        dropout_prob = model_config.dropout_prob
    else:
        dropout_prob = 0.0

    if model_io_config.fix_lm == True:
        scope = model_config.scope + "_finetuning"
    else:
        scope = model_config.scope

    with tf.variable_scope(scope + "/feature_output", reuse=model_reuse):
        hidden_size = bert_utils.get_shape_list(model.get_pooled_output(),
                                                expected_rank=2)[-1]
        # two-layer projection head with a residual connection on top of
        # the pooled [CLS] representation
        feature_output = tf.layers.dense(
            model.get_pooled_output(),
            hidden_size,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
        feature_output = tf.nn.dropout(feature_output,
                                       keep_prob=1 - dropout_prob)
        feature_output += model.get_pooled_output()
        feature_output = tf.layers.dense(
            feature_output,
            hidden_size,
            kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            activation=tf.tanh)

    if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss, logits) = classifier.classifier(
                model_config, feature_output, num_labels, label_ids,
                dropout_prob)

    model_io_fn = model_io.ModelIO(model_io_config)

    tvars = model_io_fn.get_params(model_config.scope,
                                   not_storage_params=not_storage_params)
    print(tvars)
    if load_pretrained == "yes":
        model_io_fn.load_pretrained(tvars, init_checkpoint,
                                    exclude_scope=exclude_scope)

    if mode == tf.estimator.ModeKeys.TRAIN:
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(loss, tvars,
                                                 opt_config.init_lr,
                                                 opt_config.num_train_steps,
                                                 **kargs)

        model_io_fn.set_saver()

        if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
            training_hooks = []
        elif kargs.get("task_index", 1) == 0:
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        else:
            training_hooks = []

        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op,
            training_hooks=training_hooks)

        if output_type == "sess":
            return {
                "train": {
                    "loss": loss,
                    "logits": logits,
                    "train_op": train_op
                },
                "hooks": training_hooks
            }
        elif output_type == "estimator":
            return estimator_spec
    elif mode == tf.estimator.ModeKeys.PREDICT:
        # this head exports the feature embedding rather than class labels
        estimator_spec = tf.estimator.EstimatorSpec(
            mode=mode,
            predictions={"embedding": feature_output},
            export_outputs={
                "output": tf.estimator.export.PredictOutput(
                    {"embedding": feature_output})
            })
        return estimator_spec
    elif mode == tf.estimator.ModeKeys.EVAL:
        def metric_fn(per_example_loss, logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_predictions = tf.argmax(logits, axis=-1,
                                             output_type=tf.int32)
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst,
                                       average="macro")
            eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "per_example_loss": per_example_loss,
                    "logits": logits,
                    "loss": tf.reduce_mean(per_example_loss)
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
            return estimator_spec
    else:
        raise NotImplementedError()
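Since the PREDICT branch above exports an "embedding" head via tf.estimator.export.PredictOutput, a SavedModel produced by estimator.export_savedmodel can be queried offline. A usage sketch; the export path, sequence length, and feed keys depend on the serving_input_receiver_fn and are assumptions:

import numpy as np
from tensorflow.contrib import predictor

# Load the exported SavedModel; the path is a placeholder.
embed_fn = predictor.from_saved_model("/path/to/export_dir")

# Feed keys assume a serving_input_receiver_fn exposing the raw features.
batch = {
    "input_ids": np.zeros([1, 128], dtype=np.int32),
    "input_mask": np.ones([1, 128], dtype=np.int32),
    "segment_ids": np.zeros([1, 128], dtype=np.int32),
    "label_ids": np.zeros([1], dtype=np.int32),
}
# Returns the pooled feature vector from the "embedding" output.
print(embed_fn(batch)["embedding"].shape)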