def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)
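            # Masked mean: label_ratio zeroes out distillation-only examples,
            # so the cross-entropy averages over gold-labeled examples only
            # (the 1e-10 guards the all-masked case).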
            label_loss = tf.reduce_sum(
                per_example_loss * features["label_ratio"]) / (
                    1e-10 + tf.reduce_sum(features["label_ratio"]))

        if mode == tf.estimator.ModeKeys.TRAIN:

            distillation_api = distill.KnowledgeDistillation(
                kargs.get(
                    "disitllation_config",
                    Bunch({
                        "logits_ratio_decay": "constant",
                        "logits_ratio": 0.5,
                        "logits_decay_rate": 0.999,
                        "distillation": ['relation_kd', 'logits'],
                        "feature_ratio": 0.5,
                        "feature_ratio_decay": "constant",
                        "feature_decay_rate": 0.999,
                        "kd_type": "kd",
                        "scope": scope
                    })))
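            # The defaults above weight the logits- and feature-level terms
            # equally (0.5), keep both ratios constant, and enable both
            # relational KD and soft-logit distillation.
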
            # get teacher logits: log of the teacher label probabilities,
            # scaled by the distillation temperature
            teacher_logit = tf.log(features["label_probs"] + 1e-10) \
                / kargs.get("temperature", 2.0)
            # student log-softmax at the same temperature
            student_logit = tf.nn.log_softmax(
                logits / kargs.get("temperature", 2.0))

            distillation_features = {
                "student_logits_tensor": student_logit,
                "teacher_logits_tensor": teacher_logit,
                "student_feature_tensor": model.get_pooled_output(),
                "teacher_feature_tensor": features["distillation_feature"],
                "student_label": tf.ones_like(label_ids, dtype=tf.int32),
                "teacher_label": tf.zeros_like(label_ids, dtype=tf.int32),
                "logits_ratio": kargs.get("logits_ratio", 0.5),
                "feature_ratio": kargs.get("logits_ratio", 0.5),
                "distillation_ratio": features["distillation_ratio"],
                "src_f_logit": logits,
                "tgt_f_logit": logits,
                "src_tensor": model.get_pooled_output(),
                "tgt_tensor": features["distillation_feature"]
            }
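
            # The *_logits tensors feed the soft-label loss, the *_feature
            # tensors feed the feature-matching / relational-KD terms, and the
            # binary student/teacher labels look like inputs to an adversarial
            # feature discriminator (an assumption; KnowledgeDistillation is
            # defined elsewhere in the repo).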

            distillation_loss = distillation_api.distillation(
                distillation_features,
                2,
                dropout_prob,
                model_reuse,
                opt_config.num_train_steps,
                feature_ratio=1.0,
                logits_ratio_decay="constant",
                feature_ratio_decay="constant",
                feature_decay_rate=0.999,
                logits_decay_rate=0.999,
                logits_ratio=0.5,
                scope=scope + "/adv_classifier",
                num_classes=num_labels,
                gamma=kargs.get("gamma", 4))

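            # Total objective: masked supervised cross-entropy plus the
            # combined distillation term.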
            loss = label_loss + distillation_loss["distillation_loss"]

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

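                # Hook selection: the chief (task_index == 0) under a
                # distributed run_config leaves checkpointing to the runner;
                # a standalone chief installs its own checkpoint hook; other
                # workers get none.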
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":

                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op,
                            "cross_entropy": label_loss,
                            "distillation_loss":
                            distillation_loss["distillation_loss"],
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num": tf.reduce_sum(features["label_ratio"]),
                            "label_ratio": features["label_ratio"],
                            "distillation_logits_loss":
                            distillation_loss["distillation_logits_loss"],
                            "distillation_feature_loss":
                            distillation_loss["distillation_feature_loss"],
                            "rkd_loss": distillation_loss["rkd_loss"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
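
# --- Usage sketch (illustrative addition, not from the original source) ---
# Every model_fn in this file follows the TF 1.x Estimator contract: it takes
# (features, labels, mode) and returns a tf.estimator.EstimatorSpec (or a raw
# dict when output_type == "sess"). A minimal, self-contained model_fn obeying
# that contract, with a toy dense classifier standing in for the BERT stack:

import tensorflow as tf


def _toy_model_fn(features, labels, mode):
    # Toy stand-in for model.get_pooled_output() + classifier.classifier.
    logits = tf.layers.dense(features["x"], 2)
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
            mode=mode, predictions={"prob": tf.nn.softmax(logits)})
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(1e-3).minimize(
            loss, global_step=tf.train.get_or_create_global_step())
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=loss, train_op=train_op)
    return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

# An Estimator would then be constructed as
# tf.estimator.Estimator(model_fn=_toy_model_fn, model_dir=...).
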
# ===== Example 2 =====

    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse,
                          **kargs)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except Exception:
            print("==could not count params==")
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                print(tf.global_variables(), "==global_variables==")
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # if model_config.get('label_type', 'single_label') == 'single_label':
            # 	print(logits.get_shape(), "===logits shape===")
            # 	pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            # 	prob = tf.nn.softmax(logits)
            # 	max_prob = tf.reduce_max(prob, axis=-1)

            # 	estimator_spec = tf.estimator.EstimatorSpec(
            # 							mode=mode,
            # 							predictions={
            # 										'pred_label':pred_label,
            # 										"max_prob":max_prob
            # 							},
            # 							export_outputs={
            # 								"output":tf.estimator.export.PredictOutput(
            # 											{
            # 												'pred_label':pred_label,
            # 												"max_prob":max_prob
            # 											}
            # 										)
            # 							}
            # 				)
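            # Multi-label heads score classes independently with a sigmoid;
            # single-label heads use a softmax. Note that both branches export
            # the full probability vector under both prediction keys.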
            if model_config.get('label_type', 'single_label') == 'multi_label':
                prob = tf.nn.sigmoid(logits)
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'pred_label': prob,
                        "max_prob": prob
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'pred_label': prob,
                            "max_prob": prob
                        })
                    })
            elif model_config.get('label_type',
                                  'single_label') == "single_label":
                prob = tf.nn.softmax(logits)
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'pred_label': prob,
                        "max_prob": prob
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'pred_label': prob,
                            "max_prob": prob
                        })
                    })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss),
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(per_example_loss, logits,
                                            label_ids)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
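
# ===== Example 3 =====
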
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model_lst = []

        assert len(target.split(",")) == 2
        target_name_lst = target.split(",")
        print(target_name_lst)
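        # Two towers over the same encoder: the first call creates the
        # variables, later calls reuse them (siamese-style weight sharing).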
        for index, name in enumerate(target_name_lst):
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                model_api(model_config,
                          features,
                          labels,
                          mode,
                          name,
                          reuse=reuse))

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # NOTE (assumption): the original snippet references seq_output_lst
        # without defining it; collecting each tower's pooled output is the
        # natural candidate.
        seq_output_lst = [m.get_pooled_output() for m in model_lst]

        with tf.variable_scope(scope, reuse=model_reuse):
            if config.get("classifier",
                          "order_classifier") == "order_classifier":
                [loss, per_example_loss,
                 logits] = classifier.order_classifier(model_config,
                                                       seq_output_lst,
                                                       num_labels,
                                                       label_ids,
                                                       dropout_prob,
                                                       ratio_weight=None)
            elif config.get(
                    "classifier",
                    "order_classifier") == "siamese_interaction_classifier":
                [loss, per_example_loss,
                 logits] = classifier.siamese_classifier(model_config,
                                                         seq_output_lst,
                                                         num_labels,
                                                         label_ids,
                                                         dropout_prob,
                                                         ratio_weight=None)
            else:
                # Guard added for clarity: any other value would leave
                # loss/logits undefined below.
                raise ValueError("unknown classifier type")

        print(kargs.get("temperature", 0.5),
              kargs.get("distillation_ratio", 0.5),
              "==distillation hyparameter==")

        # get teacher logits
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
            "temperature", 2.0)  # log_softmax logits
        student_logit = tf.nn.log_softmax(
            logits / kargs.get("temperature", 2.0))  # log_softmax logits
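
        # Sketch (assumption): kd_distance is a helper defined elsewhere in
        # this repo. For the default "kd" distance it plausibly reduces to a
        # per-example soft-target cross-entropy (teacher_logit above is a
        # temperature-scaled log-probability, student_logit a log-softmax):
        def _kd_distance_sketch(teacher_log_prob, student_log_prob):
            teacher_prob = tf.exp(teacher_log_prob)
            return -tf.reduce_sum(teacher_prob * student_log_prob, axis=-1)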

        distillation_loss = kd_distance(
            teacher_logit, student_logit,
            kargs.get("distillation_distance", "kd"))
        distillation_loss *= features["distillation_ratio"]
        distillation_loss = tf.reduce_sum(distillation_loss) / (
            1e-10 + tf.reduce_sum(features["distillation_ratio"]))

        label_loss = tf.reduce_sum(
            per_example_loss * features["label_ratio"]) / (
                1e-10 + tf.reduce_sum(features["label_ratio"]))

        print("==distillation loss ratio==",
              kargs.get("distillation_ratio", 0.9) *
              kargs.get("temperature", 2.0)**2)

        # loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
        # Convex mix of the supervised loss and the distillation loss.
        distillation_ratio = kargs.get("distillation_ratio", 0.9)
        loss = (1 - distillation_ratio) * label_loss \
            + distillation_ratio * distillation_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op,
                            "cross_entropy": label_loss,
                            "kd_loss": distillation_loss,
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num": tf.reduce_sum(features["label_ratio"]),
                            "teacher_logit": teacher_logit,
                            "student_logit": student_logit,
                            "label_ratio": features["label_ratio"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
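
# ===== Example 4 =====
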
	def model_fn(features, labels, mode):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE)

		# model_adv_config = copy.deepcopy(model_config)
		# model_adv_config.scope = model_config.scope + "/adv_encoder"

		# model_adv_adaptation = model_api(model_adv_config, features, labels,
		# 					mode, target, reuse=tf.AUTO_REUSE)

		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		common_feature = model.get_pooled_output()

		task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/task_residual", if_grl=False)
		adv_task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual", if_grl=True)
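		# if_grl=True presumably routes the shared feature through a gradient
		# reversal layer, so the adversarial head below learns to predict
		# adv_ids while the encoder is trained to make them unpredictable
		# (domain-adversarial training).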

		with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
			concat_feature = task_feature
			# concat_feature = tf.concat([task_feature, 
			# 							adv_task_feature], 
			# 							axis=-1)
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											concat_feature,
											num_labels,
											label_ids,
											dropout_prob,
											**kargs)

		with tf.variable_scope(scope+"/adv_classifier", reuse=tf.AUTO_REUSE):
			adv_ids = features["adv_ids"]
			(adv_loss, 
				adv_per_example_loss, 
				adv_logits) = classifier.classifier(model_config,
											adv_task_feature,
											kargs.get('adv_num_labels', 12),
											adv_ids,
											dropout_prob,
											**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			loss_diff = tf.constant(0.0)
			# adv_task_feature_no_grl = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual")

			# loss_diff = diff_loss(task_feature, 
			# 						adv_task_feature_no_grl)

			print(kargs.get("temperature", 0.5), kargs.get("distillation_ratio", 0.5), "==distillation hyparameter==")

			# get teacher logits
			teacher_logit = tf.log(features["label_probs"]+1e-10)/kargs.get("temperature", 2.0) # log_softmax logits
			student_logit = tf.nn.log_softmax(logits /kargs.get("temperature", 2.0)) # log_softmax logits

			distillation_loss = kd_distance(teacher_logit, student_logit, kargs.get("distillation_distance", "kd")) 
			distillation_loss *= features["distillation_ratio"]
			distillation_loss = tf.reduce_sum(distillation_loss) / (1e-10+tf.reduce_sum(features["distillation_ratio"]))

			label_loss = loss  # tf.reduce_sum(per_example_loss * features["label_ratio"]) / (1e-10+tf.reduce_sum(features["label_ratio"]))
			print("==distillation loss ratio==", kargs.get("distillation_ratio", 0.9)*kargs.get("temperature", 2.0)**2)

			# loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
			loss = (1-kargs.get("distillation_ratio", 0.9))*label_loss + kargs.get("distillation_ratio", 0.9) * distillation_loss
			# already inside the TRAIN branch, so add the adversarial terms directly
			loss += kargs.get("adv_ratio", 0.1) * adv_loss + loss_diff

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		print(tvars)
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)
				if output_type == "sess":

					adv_pred_label = tf.argmax(adv_logits, axis=-1, output_type=tf.int32)
					adv_correct = tf.equal(
						tf.cast(adv_pred_label, tf.int32),
						tf.cast(adv_ids, tf.int32)
					)
					adv_accuracy = tf.reduce_mean(tf.cast(adv_correct, tf.float32))                 
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op,
										"cross_entropy":label_loss,
										"kd_loss":distillation_loss,
										"kd_num":tf.reduce_sum(features["distillation_ratio"]),
										"ce_num":tf.reduce_sum(features["label_ratio"]),
										"teacher_logit":teacher_logit,
										"student_logit":student_logit,
										"label_ratio":features["label_ratio"],
										"loss_diff":loss_diff,
										"adv_loss":adv_loss,
										"adv_accuracy":adv_accuracy
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec

		elif mode == tf.estimator.ModeKeys.PREDICT:
			task_prob = tf.nn.softmax(logits)
			adv_prob = tf.nn.softmax(adv_logits)
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'adv_prob':adv_prob,
												"task_prob":task_prob
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'adv_prob':adv_prob,
														"task_prob":task_prob
													}
												)
									}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()

# ===== Example 5 =====

    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

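        # Feature head: a dense residual block over the pooled output, then a
        # tanh projection; this is the sentence embedding exported in PREDICT
        # mode below.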
        with tf.variable_scope(scope + "/feature_output", reuse=model_reuse):
            hidden_size = bert_utils.get_shape_list(model.get_pooled_output(),
                                                    expected_rank=2)[-1]
            feature_output = tf.layers.dense(
                model.get_pooled_output(),
                hidden_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01))
            feature_output = tf.nn.dropout(feature_output,
                                           keep_prob=1 - dropout_prob)
            feature_output += model.get_pooled_output()
            feature_output = tf.layers.dense(
                feature_output,
                hidden_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01),
                activation=tf.tanh)

        if mode in (tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL):
            with tf.variable_scope(scope, reuse=model_reuse):
                (loss, per_example_loss,
                 logits) = classifier.classifier(model_config, feature_output,
                                                 num_labels, label_ids,
                                                 dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # print(logits.get_shape(), "===logits shape===")
            # pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            # prob = tf.nn.softmax(logits)
            # max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    # 'pred_label':pred_label,
                    # "max_prob":max_prob
                    "embedding": feature_output
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput(
                        {"embedding": feature_output})
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(per_example_loss, logits,
                                            label_ids)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()