def model_fn(features, labels, mode):

        model = bert_encoder(model_config,
                             features,
                             labels,
                             mode,
                             target,
                             reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

            temperature_log_prob = tf.nn.log_softmax(
                logits / kargs.get("temperature", 2))

            return [loss, per_example_loss, logits, temperature_log_prob]
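A minimal sketch of the temperature-scaled log-softmax used above to build distillation targets; the helper name is an assumption, not part of this repo:

import tensorflow as tf

def soft_log_probs(logits, temperature=2.0):
    # Dividing the logits by a temperature > 1 flattens the distribution
    # before log-softmax, the usual soft-target trick in knowledge distillation.
    return tf.nn.log_softmax(logits / temperature)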
Example #2
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope='teacher')
        return_dict = {
            "loss": loss,
            "logits": logits,
            "tvars": tvars,
            "model": model,
            "per_example_loss": per_example_loss
        }
        return return_dict
def adversarial_loss(model_config, feature, adv_ids, dropout_prob, model_reuse,
					**kargs):
	'''Make the task classifier unable to reliably predict the task
	from the shared feature.
	'''
	# input = tf.stop_gradient(input)
	feature = tf.nn.dropout(feature, 1 - dropout_prob)

	with tf.variable_scope(model_config.scope+"/adv_classifier", reuse=model_reuse):
		(adv_loss, 
			adv_per_example_loss, 
			adv_logits) = classifier.classifier(model_config,
										feature,
										kargs.get('adv_num_labels', 7),
										adv_ids,
										dropout_prob)
	return (adv_loss, adv_per_example_loss, adv_logits)
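A hypothetical call site for adversarial_loss; the names pooled_output, features and dropout_prob are assumed to come from the enclosing model_fn and are not defined in this listing:

adv_loss, adv_per_example_loss, adv_logits = adversarial_loss(
    model_config,
    feature=pooled_output,          # e.g. model.get_pooled_output()
    adv_ids=features["adv_ids"],
    dropout_prob=dropout_prob,
    model_reuse=tf.AUTO_REUSE,
    adv_num_labels=7)               # picked up via **kargs inside adversarial_loss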
Example #4
        def build_discriminator(model, scope, reuse):

            with tf.variable_scope(scope, reuse=reuse):

                try:
                    label_ratio_table = tf.get_variable(
                        name="label_ratio",
                        initializer=tf.constant(label_tensor),
                        trainable=False)

                    ratio_weight = tf.nn.embedding_lookup(
                        label_ratio_table, label_ids)
                    print("==applying class weight==")
                except:
                    ratio_weight = None

                (loss, per_example_loss,
                 logits) = classifier.classifier(model_config,
                                                 model.get_pooled_output(),
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)
                return loss, per_example_loss, logits
Example #5
	def model_fn(features, labels, mode):
		print(features)
		input_ids = features["input_ids"]
		input_mask = features["input_mask"]
		segment_ids = features["segment_ids"]
		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			hidden_dropout_prob = model_config.hidden_dropout_prob
			attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
			dropout_prob = model_config.dropout_prob
		else:
			hidden_dropout_prob = 0.0
			attention_probs_dropout_prob = 0.0
			dropout_prob = 0.0

		model = bert.Bert(model_config)
		model.build_embedder(input_ids, segment_ids,
											hidden_dropout_prob,
											attention_probs_dropout_prob,
											reuse=reuse)
		model.build_encoder(input_ids,
											input_mask,
											hidden_dropout_prob, 
											attention_probs_dropout_prob,
											reuse=reuse)
		model.build_pooler(reuse=reuse)

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		with tf.variable_scope(scope, reuse=reuse):
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											model.get_pooled_output(),
											num_labels,
											label_ids,
											dropout_prob)

		# model_io_fn = model_io.ModelIO(model_io_config)
		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		if load_pretrained:
			model_io_fn.load_pretrained(pretrained_tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		tvars = pretrained_tvars
		model_io_fn.set_saver(var_lst=tvars)

		if mode == tf.estimator.ModeKeys.TRAIN:
			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				optimizer_fn = optimizer.Optimizer(opt_config)
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps)

				return [train_op, loss, per_example_loss, logits]
		else:
			model_io_fn.print_params(tvars, string=", trainable params")
			return [loss, loss, per_example_loss, logits]
Example #6
	def model_fn(features, labels, mode):

		# model = bert_encoder(model_config, features, labels,
		# 					mode, target, reuse=model_reuse)

		model = albert_encoder(model_config, features, labels,
							mode, target, reuse=model_reuse)

		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		with tf.variable_scope(scope, reuse=model_reuse):
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											model.get_pooled_output(),
											num_labels,
											label_ids,
											dropout_prob)

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			print("==update_ops==", update_ops)

			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)
			# train_op, hooks = model_io_fn.get_ema_hooks(train_op, 
			# 					tvars,
			# 					kargs.get('params_moving_average_decay', 0.99),
			# 					scope, mode, 
			# 					first_stage_steps=opt_config.num_warmup_steps,
			# 					two_stage=True)

			model_io_fn.set_saver()

			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op)
				
			return estimator_spec
		elif mode == tf.estimator.ModeKeys.EVAL:
			
			# _, hooks = model_io_fn.get_ema_hooks(None,
			# 							None,
			# 							kargs.get('params_moving_average_decay', 0.99), 
			# 							scope, mode)

			hooks = None

			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)

			eval_hooks = [hooks] if hooks else []
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops,
								evaluation_hooks=eval_hooks
								)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #7
	def model_fn(features, labels, mode):

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		label_ids = features["label_ids"]

		model_lst = []
		for index, name in enumerate(target):
			if index > 0:
				reuse = True
			else:
				reuse = model_reuse
			model_lst.append(bert_encoding(model_config, features, labels, 
												mode, name,
												scope, dropout_rate, 
												reuse=reuse))

		[input_mask_a, repres_a] = model_lst[0]
		[input_mask_b, repres_b] = model_lst[1]

		output_a, output_b = alignment_aggerate(model_config, 
				repres_a, repres_b, 
				input_mask_a, 
				input_mask_b, 
				scope, 
				reuse=model_reuse)

		if model_config.pooling == "ave_max_pooling":
			pooling_fn = ave_max_pooling
		elif model_config.pooling == "multihead_pooling":
			pooling_fn = multihead_pooling

		repres_a = pooling_fn(model_config, output_a, 
					input_mask_a, 
					scope, 
					dropout_prob, 
					reuse=model_reuse)

		repres_b = pooling_fn(model_config, output_b,
					input_mask_b,
					scope, 
					dropout_prob,
					reuse=True)

		pair_repres = tf.concat([repres_a, repres_b,
					tf.abs(repres_a-repres_b),
					repres_b*repres_a], axis=-1)
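		# The concatenation above builds the pair's interaction features: both pooled
		# vectors, their absolute difference, and their element-wise product, which
		# feed the pair classifier below.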

		with tf.variable_scope(scope, reuse=model_reuse):
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											pair_repres,
											num_labels,
											label_ids,
											dropout_prob)

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		if load_pretrained:
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps)

				model_io_fn.set_saver()

				if kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
										kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks)

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)

				if output_type == "sess":
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec
		elif mode == tf.estimator.ModeKeys.PREDICT:
			print(logits.get_shape(), "===logits shape===")
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			prob = tf.nn.softmax(logits)
			max_prob = tf.reduce_max(prob, axis=-1)
			
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'pred_label':pred_label,
												"max_prob":max_prob
								  	},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'pred_label':pred_label,
														"max_prob":max_prob
													}
												)
								  	}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"loss": sentence_mean_loss,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)
			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #8
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config, features, labels,
                          mode, target, reuse=model_reuse)

        task_type = kargs.get("task_type", "cls")

        label_ids = features["{}_label_ids".format(task_type)]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                               reuse=task_layer_reuse):
            (_, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        task_mask = tf.cast(features["{}_loss_multiplier".format(task_type)],
                            tf.float32)

        masked_per_example_loss = task_mask * per_example_loss
        loss = tf.reduce_sum(masked_per_example_loss) / (
            1e-10 + tf.reduce_sum(task_mask))

        logits = tf.expand_dims(task_mask, axis=-1) * logits

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        # try:
        # 	params_size = model_io_fn.count_params(model_config.scope)
        # 	print("==total params==", params_size)
        # except:
        # 	print("==not count params==")
        # print(tvars)
        # if load_pretrained == "yes":
        # 	model_io_fn.load_pretrained(tvars,
        # 								init_checkpoint,
        # 								exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            return {
                "loss": loss,
                "logits": logits,
                "task_num": tf.reduce_sum(task_mask),
                "tvars": tvars
            }
        elif mode == tf.estimator.ModeKeys.EVAL:
            return {
                "loss": loss,
                "logits": logits,
                "feature": model.get_pooled_output()
            }
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)
            label_loss = tf.reduce_sum(
                per_example_loss * features["label_ratio"]) / (
                    1e-10 + tf.reduce_sum(features["label_ratio"]))

        if mode == tf.estimator.ModeKeys.TRAIN:

            distillation_api = distill.KnowledgeDistillation(
                kargs.get(
                    "disitllation_config",
                    Bunch({
                        "logits_ratio_decay": "constant",
                        "logits_ratio": 0.5,
                        "logits_decay_rate": 0.999,
                        "distillation": ['relation_kd', 'logits'],
                        "feature_ratio": 0.5,
                        "feature_ratio_decay": "constant",
                        "feature_decay_rate": 0.999,
                        "kd_type": "kd",
                        "scope": scope
                    })))
            # get teacher logits
            teacher_logit = tf.log(features["label_probs"] +
                                   1e-10) / kargs.get(
                                       "temperature",
                                       2.0)  # log_softmax logits
            student_logit = tf.nn.log_softmax(
                logits / kargs.get("temperature", 2.0))  # log_softmax logits

            distillation_features = {
                "student_logits_tensor": student_logit,
                "teacher_logits_tensor": teacher_logit,
                "student_feature_tensor": model.get_pooled_output(),
                "teacher_feature_tensor": features["distillation_feature"],
                "student_label": tf.ones_like(label_ids, dtype=tf.int32),
                "teacher_label": tf.zeros_like(label_ids, dtype=tf.int32),
                "logits_ratio": kargs.get("logits_ratio", 0.5),
                "feature_ratio": kargs.get("logits_ratio", 0.5),
                "distillation_ratio": features["distillation_ratio"],
                "src_f_logit": logits,
                "tgt_f_logit": logits,
                "src_tensor": model.get_pooled_output(),
                "tgt_tensor": features["distillation_feature"]
            }

            distillation_loss = distillation_api.distillation(
                distillation_features,
                2,
                dropout_prob,
                model_reuse,
                opt_config.num_train_steps,
                feature_ratio=1.0,
                logits_ratio_decay="constant",
                feature_ratio_decay="constant",
                feature_decay_rate=0.999,
                logits_decay_rate=0.999,
                logits_ratio=0.5,
                scope=scope + "/adv_classifier",
                num_classes=num_labels,
                gamma=kargs.get("gamma", 4))

            loss = label_loss + distillation_loss["distillation_loss"]

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":

                    return {
                        "train": {
                            "loss":
                            loss,
                            "logits":
                            logits,
                            "train_op":
                            train_op,
                            "cross_entropy":
                            label_loss,
                            "distillation_loss":
                            distillation_loss["distillation_loss"],
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num":
                            tf.reduce_sum(features["label_ratio"]),
                            "label_ratio":
                            features["label_ratio"],
                            "distilaltion_logits_loss":
                            distillation_loss["distillation_logits_loss"],
                            "distilaltion_feature_loss":
                            distillation_loss["distillation_feature_loss"],
                            "rkd_loss":
                            distillation_loss["rkd_loss"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Example #10
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse,
                          **kargs)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                print(tf.global_variables(), "==global_variables==")
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # if model_config.get('label_type', 'single_label') == 'single_label':
            # 	print(logits.get_shape(), "===logits shape===")
            # 	pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            # 	prob = tf.nn.softmax(logits)
            # 	max_prob = tf.reduce_max(prob, axis=-1)

            # 	estimator_spec = tf.estimator.EstimatorSpec(
            # 							mode=mode,
            # 							predictions={
            # 										'pred_label':pred_label,
            # 										"max_prob":max_prob
            # 							},
            # 							export_outputs={
            # 								"output":tf.estimator.export.PredictOutput(
            # 											{
            # 												'pred_label':pred_label,
            # 												"max_prob":max_prob
            # 											}
            # 										)
            # 							}
            # 				)
            if model_config.get('label_type', 'single_label') == 'multi_label':
                prob = tf.nn.sigmoid(logits)
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'pred_label': prob,
                        "max_prob": prob
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'pred_label': prob,
                            "max_prob": prob
                        })
                    })
            elif model_config.get('label_type',
                                  'single_label') == "single_label":
                prob = tf.nn.softmax(logits)
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'pred_label': prob,
                        "max_prob": prob
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'pred_label': prob,
                            "max_prob": prob
                        })
                    })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss),
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(per_example_loss, logits,
                                            label_ids)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #11
    def model_fn(features, labels, mode):

        task_type = kargs.get("task_type", "cls")

        label_ids = features["{}_label_ids".format(task_type)]

        num_task = kargs.get('num_task', 1)

        model_io_fn = model_io.ModelIO(model_io_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope
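        # NOTE (assumption): `model` used below is expected to be built by the
        # enclosing builder (e.g. model_api = model_zoo(model_config); model = model_api(...)),
        # as in the other examples in this listing; it is not constructed in this snippet.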

        if kargs.get("get_pooled_output", "pooled_output") == "pooled_output":
            pooled_feature = model.get_pooled_output()
        elif kargs.get("get_pooled_output", "task_output") == "task_output":
            pooled_feature_dict = model.get_task_output()
            pooled_feature = pooled_feature_dict['pooled_feature']

        loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                            tf.float32)
        loss = tf.constant(0.0)

        params_size = model_io_fn.count_params(model_config.scope)
        print("==total encoder params==", params_size)

        if kargs.get("feature_distillation", True):
            universal_feature_a = features.get("input_ids_a_features", None)
            universal_feature_b = features.get("input_ids_b_features", None)

            if universal_feature_a is None or universal_feature_b is None:
                tf.logging.info(
                    "****** not apply feature distillation *******")
                feature_loss = tf.constant(0.0)
            else:
                feature_a = pooled_feature_dict['feature_a']
                feature_a_shape = bert_utils.get_shape_list(
                    feature_a, expected_rank=[2, 3])
                pretrain_feature_a_shape = bert_utils.get_shape_list(
                    universal_feature_a, expected_rank=[2, 3])
                if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                    with tf.variable_scope(scope + "/feature_proj",
                                           reuse=tf.AUTO_REUSE):
                        proj_feature_a = tf.layers.dense(
                            feature_a, pretrain_feature_a_shape[-1])
                    # with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
                    # 	proj_feature_a_rec = tf.layers.dense(proj_feature_a, feature_a_shape[-1])
                    # loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_a_rec-feature_a), axis=-1))/float(num_task)
                    tf.logging.info(
                        "****** apply auto-encoder for feature compression *******"
                    )
                else:
                    proj_feature_a = feature_a
                feature_a_norm = tf.stop_gradient(
                    tf.sqrt(
                        tf.reduce_sum(tf.pow(proj_feature_a, 2),
                                      axis=-1,
                                      keepdims=True)) + 1e-20)
                proj_feature_a /= feature_a_norm

                feature_b = pooled_feature_dict['feature_b']
                if feature_a_shape[-1] != pretrain_feature_a_shape[-1]:
                    with tf.variable_scope(scope + "/feature_proj",
                                           reuse=tf.AUTO_REUSE):
                        proj_feature_b = tf.layers.dense(
                            feature_b, pretrain_feature_a_shape[-1])
                    # with tf.variable_scope(scope+"/feature_rec", reuse=tf.AUTO_REUSE):
                    # 	proj_feature_b_rec = tf.layers.dense(proj_feature_b, feature_a_shape[-1])
                    # loss += tf.reduce_mean(tf.reduce_sum(tf.square(proj_feature_b_rec-feature_b), axis=-1))/float(num_task)
                    tf.logging.info(
                        "****** apply auto-encoder for feature compression *******"
                    )
                else:
                    proj_feature_b = feature_b

                feature_b_norm = tf.stop_gradient(
                    tf.sqrt(
                        tf.reduce_sum(tf.pow(proj_feature_b, 2),
                                      axis=-1,
                                      keepdims=True)) + 1e-20)
                proj_feature_b /= feature_b_norm

                feature_a_distillation = tf.reduce_mean(
                    tf.square(universal_feature_a - proj_feature_a), axis=-1)
                feature_b_distillation = tf.reduce_mean(
                    tf.square(universal_feature_b - proj_feature_b), axis=-1)

                feature_loss = tf.reduce_mean(
                    (feature_a_distillation + feature_b_distillation) /
                    2.0) / float(num_task)
                loss += feature_loss
                tf.logging.info(
                    "****** apply prertained feature distillation *******")

        if kargs.get("embedding_distillation", True):
            word_embed = model.emb_mat
            random_embed_shape = bert_utils.get_shape_list(
                word_embed, expected_rank=[2, 3])
            print("==random_embed_shape==", random_embed_shape)
            pretrained_embed = kargs.get('pretrained_embed', None)
            if pretrained_embed is None:
                tf.logging.info(
                    "****** not apply prertained feature distillation *******")
                embed_loss = tf.constant(0.0)
            else:
                pretrain_embed_shape = bert_utils.get_shape_list(
                    pretrained_embed, expected_rank=[2, 3])
                print("==pretrain_embed_shape==", pretrain_embed_shape)
                if random_embed_shape[-1] != pretrain_embed_shape[-1]:
                    with tf.variable_scope(scope + "/embedding_proj",
                                           reuse=tf.AUTO_REUSE):
                        proj_embed = tf.layers.dense(word_embed,
                                                     pretrain_embed_shape[-1])
                else:
                    proj_embed = word_embed

                embed_loss = tf.reduce_mean(
                    tf.reduce_mean(tf.square(proj_embed - pretrained_embed),
                                   axis=-1)) / float(num_task)
                loss += embed_loss
                tf.logging.info(
                    "****** apply prertained feature distillation *******")

        with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                               reuse=task_layer_reuse):
            (_, per_example_loss,
             logits) = classifier.classifier(model_config, pooled_feature,
                                             num_labels, label_ids,
                                             dropout_prob)

        loss_mask = tf.cast(features["{}_loss_multipiler".format(task_type)],
                            tf.float32)
        masked_per_example_loss = per_example_loss * loss_mask
        task_loss = tf.reduce_sum(masked_per_example_loss) / (
            1e-10 + tf.reduce_sum(loss_mask))
        loss += task_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_positions = features["masked_lm_positions"]
                masked_lm_ids = features["masked_lm_ids"]
                masked_lm_weights = features["masked_lm_weights"]
                (masked_lm_loss, masked_lm_example_loss,
                 masked_lm_log_probs) = pretrain.get_masked_lm_output(
                     model_config,
                     model.get_sequence_output(),
                     model.get_embedding_table(),
                     masked_lm_positions,
                     masked_lm_ids,
                     masked_lm_weights,
                     reuse=model_reuse)

                masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
                    (1,
                     multi_task_config[task_type]["max_predictions_per_seq"]))
                masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

                masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_loss_mask *= tf.cast(masked_lm_label_weights,
                                               tf.float32)

                masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
                masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
                    1e-10 + tf.reduce_sum(masked_lm_loss_mask))
                loss += multi_task_config[task_type][
                    "masked_lm_loss_ratio"] * masked_lm_loss

                masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])

                print(masked_lm_log_probs.get_shape(),
                      "===masked lm log probs===")
                print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
                print(masked_lm_label_weights.get_shape(),
                      "===masked lm mask===")

                lm_acc = build_accuracy(masked_lm_log_probs,
                                        masked_lm_label_ids,
                                        masked_lm_loss_mask)

        if kargs.get("task_invariant", "no") == "yes":
            print("==apply task adversarial training==")
            with tf.variable_scope(scope + "/dann_task_invariant",
                                   reuse=model_reuse):
                (_, task_example_loss,
                 task_logits) = distillation_utils.feature_distillation(
                     model.get_pooled_output(), 1.0, features["task_id"],
                     kargs.get("num_task", 7), dropout_prob, True)
                masked_task_example_loss = loss_mask * task_example_loss
                masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                    1e-10 + tf.reduce_sum(loss_mask))
                loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_pretrain_tvars = model_io_fn.get_params(
                    "cls/predictions", not_storage_params=not_storage_params)
                tvars.extend(masked_lm_pretrain_tvars)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        # print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            acc = build_accuracy(logits, label_ids, loss_mask)

            return_dict = {
                "loss": loss,
                "logits": logits,
                "task_num": tf.reduce_sum(loss_mask),
                "tvars": tvars
            }
            return_dict["{}_acc".format(task_type)] = acc
            if kargs.get("task_invariant", "no") == "yes":
                return_dict["{}_task_loss".format(
                    task_type)] = masked_task_loss
                task_acc = build_accuracy(task_logits, features["task_id"],
                                          loss_mask)
                return_dict["{}_task_acc".format(task_type)] = task_acc
            if multi_task_config[task_type].get("lm_augumentation", False):
                return_dict["{}_masked_lm_loss".format(
                    task_type)] = masked_lm_loss
                return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
            if kargs.get("embedding_distillation", True):
                return_dict["embed_loss"] = embed_loss * float(num_task)
            else:
                return_dict["embed_loss"] = task_loss
            if kargs.get("feature_distillation", True):
                return_dict["feature_loss"] = feature_loss * float(num_task)
            else:
                return_dict["feature_loss"] = task_loss
            return_dict["task_loss"] = task_loss
            return return_dict
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_dict = {
                "loss": loss,
                "logits": logits,
                "feature": model.get_pooled_output()
            }
            if kargs.get("adversarial", "no") == "adversarial":
                eval_dict["task_logits"] = task_logits
            return eval_dict
    def model_fn(features, labels, mode):
        print(features)
        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        model = bert.Bert(model_config)
        model.build_embedder(input_ids,
                             segment_ids,
                             hidden_dropout_prob,
                             attention_probs_dropout_prob,
                             reuse=reuse)
        model.build_encoder(input_ids,
                            input_mask,
                            hidden_dropout_prob,
                            attention_probs_dropout_prob,
                            reuse=reuse)
        model.build_pooler(reuse=reuse)

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        # model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        tvars = pretrained_tvars
        model_io_fn.set_saver(var_lst=tvars)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                # optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=loss,
                                                  train_op=train_op)

        elif mode == tf.estimator.ModeKeys.PREDICT:
            predictions = {
                # Generate predictions (for PREDICT and EVAL mode)
                "classes":
                tf.argmax(input=logits, axis=1),
                # Add `softmax_tensor` to the graph. It is used for PREDICT and by the
                # `logging_hook`.
                "probabilities":
                tf.exp(tf.nn.log_softmax(logits, name="softmax_tensor"))
            }
            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions)
        elif mode == tf.estimator.ModeKeys.EVAL:
            """
			needs to manaually write metric ops
			see 
			https://github.com/google/seq2seq/blob/7f485894d412e8d81ce0e07977831865e44309ce/seq2seq/metrics/metric_specs.py
			"""
            return None
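A minimal sketch of the metric ops the EVAL branch above would need, modeled on the metric_fn used in the other examples in this listing (the exact metrics are an assumption):

def metric_fn(per_example_loss, logits, label_ids):
    # Accuracy over argmax predictions plus the mean per-example loss.
    predictions = tf.argmax(logits, axis=-1, output_type=tf.int32)
    accuracy = tf.metrics.accuracy(labels=label_ids, predictions=predictions)
    mean_loss = tf.metrics.mean(values=per_example_loss)
    return {"acc": accuracy, "loss": mean_loss}

# return tf.estimator.EstimatorSpec(mode=mode, loss=loss,
#                                   eval_metric_ops=metric_fn(per_example_loss, logits, label_ids))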
	def model_fn(features, labels, mode):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE)

		# model_adv_config = copy.deepcopy(model_config)
		# model_adv_config.scope = model_config.scope + "/adv_encoder"

		# model_adv_adaptation = model_api(model_adv_config, features, labels,
		# 					mode, target, reuse=tf.AUTO_REUSE)

		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		common_feature = model.get_pooled_output()

		task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/task_residual", if_grl=False)
		adv_task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual", if_grl=True)

		with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
			concat_feature = task_feature
			# concat_feature = tf.concat([task_feature, 
			# 							adv_task_feature], 
			# 							axis=-1)
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											concat_feature,
											num_labels,
											label_ids,
											dropout_prob,
											**kargs)

		with tf.variable_scope(scope+"/adv_classifier", reuse=tf.AUTO_REUSE):
			adv_ids = features["adv_ids"]
			(adv_loss, 
				adv_per_example_loss, 
				adv_logits) = classifier.classifier(model_config,
											adv_task_feature,
											kargs.get('adv_num_labels', 12),
											adv_ids,
											dropout_prob,
											**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			loss_diff = tf.constant(0.0)
			# adv_task_feature_no_grl = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual")

			# loss_diff = diff_loss(task_feature, 
			# 						adv_task_feature_no_grl)

			print(kargs.get("temperature", 0.5), kargs.get("distillation_ratio", 0.5), "==distillation hyparameter==")

			# get teacher logits
			teacher_logit = tf.log(features["label_probs"]+1e-10)/kargs.get("temperature", 2.0) # log_softmax logits
			student_logit = tf.nn.log_softmax(logits /kargs.get("temperature", 2.0)) # log_softmax logits

			distillation_loss = kd_distance(teacher_logit, student_logit, kargs.get("distillation_distance", "kd")) 
			distillation_loss *= features["distillation_ratio"]
			distillation_loss = tf.reduce_sum(distillation_loss) / (1e-10+tf.reduce_sum(features["distillation_ratio"]))

			label_loss = loss #tf.reduce_sum(per_example_loss * features["label_ratio"]) / (1e-10+tf.reduce_sum(features["label_ratio"]))
			print("==distillation loss ratio==", kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2))

			# loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
			loss = (1-kargs.get("distillation_ratio", 0.9))*label_loss + kargs.get("distillation_ratio", 0.9) * distillation_loss
			loss += kargs.get("adv_ratio", 0.1) * adv_loss + loss_diff

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		print(tvars)
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)
				if output_type == "sess":

					adv_pred_label = tf.argmax(adv_logits, axis=-1, output_type=tf.int32)
					adv_correct = tf.equal(
						tf.cast(adv_pred_label, tf.int32),
						tf.cast(adv_ids, tf.int32)
					)
					adv_accuracy = tf.reduce_mean(tf.cast(adv_correct, tf.float32))                 
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op,
										"cross_entropy":label_loss,
										"kd_loss":distillation_loss,
										"kd_num":tf.reduce_sum(features["distillation_ratio"]),
										"ce_num":tf.reduce_sum(features["label_ratio"]),
										"teacher_logit":teacher_logit,
										"student_logit":student_logit,
										"label_ratio":features["label_ratio"],
										"loss_diff":loss_diff,
										"adv_loss":adv_loss,
										"adv_accuracy":adv_accuracy
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec

		elif mode == tf.estimator.ModeKeys.PREDICT:
			task_prob = tf.exp(tf.nn.log_softmax(logits))
			adv_prob = tf.exp(tf.nn.log_softmax(adv_logits))
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'adv_prob':adv_prob,
												"task_prob":task_prob
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'adv_prob':adv_prob,
														"task_prob":task_prob
													}
												)
									}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #14
    def model_fn(features, labels, mode):

        input_ids = features["input_ids"]
        input_mask = features["input_mask"]
        segment_ids = features["segment_ids"]
        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        model = bert.Bert(model_config)
        model.build_embedder(input_ids,
                             segment_ids,
                             hidden_dropout_prob,
                             attention_probs_dropout_prob,
                             reuse=reuse)
        model.build_encoder(input_ids,
                            input_mask,
                            hidden_dropout_prob,
                            attention_probs_dropout_prob,
                            reuse=reuse)
        model.build_pooler(reuse=reuse)

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        # model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)
            # TRAIN mode needs an EstimatorSpec that carries loss and train_op;
            # returning the prediction-only spec below would raise a ValueError
            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=loss,
                                              train_op=train_op)

        print(logits.get_shape(), "===logits shape===")
        pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
        prob = tf.nn.softmax(logits)
        max_prob = tf.reduce_max(prob, axis=-1)
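        # pred_label is the arg-max class id and max_prob its softmax confidence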

        output_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                 predictions={
                                                     'pred_label': pred_label,
                                                     "label_ids": label_ids,
                                                     "max_prob": max_prob
                                                 })
        return output_spec
Example #15
    def model_fn(features, labels, mode):
        label_ids = features["label_ids"]
        model_lst = []
        for index, name in enumerate(input_name):
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                base_model(model_config,
                           features,
                           labels,
                           mode,
                           name,
                           reuse=reuse))

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        assert len(model_lst) == len(input_name)

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):

            try:
                label_ratio_table = tf.get_variable(
                    name="label_ratio",
                    shape=[
                        num_labels,
                    ],
                    initializer=tf.constant_initializer(label_tensor),
                    trainable=False)

                ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                      label_ids)
            except:
                ratio_weight = None

            seq_output_lst = [model.get_pooled_output() for model in model_lst]
            repres = seq_output_lst[0] + seq_output_lst[1]

            final_hidden_shape = bert_utils.get_shape_list(repres,
                                                           expected_rank=2)

            z_mean = tf.layers.dense(repres,
                                     final_hidden_shape[1],
                                     name="z_mean")
            z_log_var = tf.layers.dense(repres,
                                        final_hidden_shape[1],
                                        name="z_log_var")
            print("=======applying vib============")
            if mode == tf.estimator.ModeKeys.TRAIN:
                print("====applying vib====")
                vib_connector = vib.VIB(vib_config)
                [kl_loss, latent_vector
                 ] = vib_connector.build_regularizer([z_mean, z_log_var])

                [loss, per_example_loss,
                 logits] = classifier.classifier(model_config, latent_vector,
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)

                loss += tf.reduce_mean(kl_loss)
            else:
                print("====applying z_mean for prediction====")
                [loss, per_example_loss,
                 logits] = classifier.classifier(model_config, z_mean,
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)

        # model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
Example #16
    def model_fn(features, labels, mode):
        model = bert_encoder(model_config,
                             features,
                             labels,
                             mode,
                             target,
                             reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        if load_pretrained:
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        model_io_fn.set_saver(var_lst=tvars)
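        # the Saver is built over the collected variables, presumably so that
        # checkpoints are limited to this model's parameters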

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        }
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Example #17
    def model_fn(features, labels, mode):

        task_type = kargs.get("task_type", "cls")

        label_ids = features["{}_label_ids".format(task_type)]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope + "/{}/classifier".format(task_type),
                               reuse=task_layer_reuse):
            (_, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        loss_mask = tf.cast(features["{}_loss_multiplier".format(task_type)],
                            tf.float32)
        masked_per_example_loss = per_example_loss * loss_mask
        loss = tf.reduce_sum(masked_per_example_loss) / (
            1e-10 + tf.reduce_sum(loss_mask))

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_positions = features["masked_lm_positions"]
                masked_lm_ids = features["masked_lm_ids"]
                masked_lm_weights = features["masked_lm_weights"]
                (masked_lm_loss, masked_lm_example_loss,
                 masked_lm_log_probs) = pretrain.get_masked_lm_output(
                     model_config,
                     model.get_sequence_output(),
                     model.get_embedding_table(),
                     masked_lm_positions,
                     masked_lm_ids,
                     masked_lm_weights,
                     reuse=model_reuse)

                masked_lm_loss_mask = tf.expand_dims(loss_mask, -1) * tf.ones(
                    (1,
                     multi_task_config[task_type]["max_predictions_per_seq"]))
                masked_lm_loss_mask = tf.reshape(masked_lm_loss_mask, (-1, ))

                masked_lm_label_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_loss_mask *= tf.cast(masked_lm_label_weights,
                                               tf.float32)
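                # after this point, positions from out-of-task examples and padded
                # (zero-weight) LM positions are both masked out of the loss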

                masked_lm_example_loss *= masked_lm_loss_mask  # multiply task_mask
                masked_lm_loss = tf.reduce_sum(masked_lm_example_loss) / (
                    1e-10 + tf.reduce_sum(masked_lm_loss_mask))
                loss += multi_task_config[task_type][
                    "masked_lm_loss_ratio"] * masked_lm_loss

                masked_lm_label_ids = tf.reshape(masked_lm_ids, [-1])

                print(masked_lm_log_probs.get_shape(),
                      "===masked lm log probs===")
                print(masked_lm_label_ids.get_shape(), "===masked lm ids===")
                print(masked_lm_label_weights.get_shape(),
                      "===masked lm mask===")

                lm_acc = build_accuracy(masked_lm_log_probs,
                                        masked_lm_label_ids,
                                        masked_lm_loss_mask)

        if kargs.get("task_invariant", "no") == "yes":
            print("==apply task adversarial training==")
            with tf.variable_scope(scope + "/dann_task_invariant",
                                   reuse=model_reuse):
                (_, task_example_loss,
                 task_logits) = distillation_utils.feature_distillation(
                     model.get_pooled_output(), 1.0, features["task_id"],
                     kargs.get("num_task", 7), dropout_prob, True)
                masked_task_example_loss = loss_mask * task_example_loss
                masked_task_loss = tf.reduce_sum(masked_task_example_loss) / (
                    1e-10 + tf.reduce_sum(loss_mask))
                loss += kargs.get("task_adversarial", 1e-2) * masked_task_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        if mode == tf.estimator.ModeKeys.TRAIN:
            multi_task_config = kargs.get("multi_task_config", {})
            if multi_task_config[task_type].get("lm_augumentation", False):
                print("==apply lm_augumentation==")
                masked_lm_pretrain_tvars = model_io_fn.get_params(
                    "cls/predictions", not_storage_params=not_storage_params)
                tvars.extend(masked_lm_pretrain_tvars)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        # print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            acc = build_accuracy(logits, label_ids, loss_mask)

            return_dict = {
                "loss": loss,
                "logits": logits,
                "task_num": tf.reduce_sum(loss_mask),
                "tvars": tvars
            }
            return_dict["{}_acc".format(task_type)] = acc
            if kargs.get("task_invariant", "no") == "yes":
                return_dict["{}_task_loss".format(
                    task_type)] = masked_task_loss
                task_acc = build_accuracy(task_logits, features["task_id"],
                                          loss_mask)
                return_dict["{}_task_acc".format(task_type)] = task_acc
            if multi_task_config[task_type].get("lm_augumentation", False):
                return_dict["{}_masked_lm_loss".format(
                    task_type)] = masked_lm_loss
                return_dict["{}_masked_lm_acc".format(task_type)] = lm_acc
            return return_dict
        elif mode == tf.estimator.ModeKeys.EVAL:
            eval_dict = {
                "loss": loss,
                "logits": logits,
                "feature": model.get_pooled_output()
            }
            if kargs.get("adversarial", "no") == "adversarial":
                eval_dict["task_logits"] = task_logits
            return eval_dict
Example #18
    def model_fn(features, labels, mode):

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        model = base_model(model_config, features, labels, mode, reuse=reuse)

        with tf.variable_scope(model_config.scope, reuse=reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss,
             masked_lm_log_probs) = pretrain.get_masked_lm_output(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=reuse)
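            # joint objective: weighted sum of the masked-LM loss and the
            # sentence-classification loss, with weights taken from model_config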
            total_loss = model_config.lm_ratio * masked_lm_loss + model_config.task_ratio * loss
        else:
            total_loss = loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            masked_lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

            pretrained_tvars.extend(masked_lm_pretrain_tvars)

            if load_pretrained:
                model_io_fn.load_pretrained(pretrained_tvars,
                                            init_checkpoint,
                                            exclude_scope=exclude_scope)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    total_loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                output_dict = {
                    "train_op": train_op,
                    "total_loss": total_loss,
                    "masked_lm_loss": masked_lm_loss,
                    "sentence_loss": loss
                }

                return tf.estimator.EstimatorSpec(mode=mode,
                                                  loss=total_loss,
                                                  train_op=train_op)

        elif mode == tf.estimator.ModeKeys.PREDICT:

            def prediction_fn(logits):

                predictions = {
                    "classes":
                    tf.argmax(input=logits, axis=1),
                    "probabilities":
                    tf.exp(tf.nn.log_softmax(logits, name="softmax_tensor"))
                }
                return predictions

            predictions = prediction_fn(logits)

            return tf.estimator.EstimatorSpec(mode=mode,
                                              predictions=predictions)

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights, per_example_loss,
                          logits, label_ids):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "sentence_f": sentence_f,
                    "sentence_loss": sentence_mean_loss
                }

            eval_metric_ops = metric_fn(masked_lm_example_loss,
                                        masked_lm_log_probs, masked_lm_ids,
                                        masked_lm_weights, per_example_loss,
                                        logits, label_ids)

            return tf.estimator.EstimatorSpec(mode=mode,
                                              loss=total_loss,
                                              eval_metric_ops=eval_metric_ops)
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)
        label_ids = features["label_ids"]

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        print(kargs.get("temperature", 0.5),
              kargs.get("distillation_ratio", 0.5),
              "==distillation hyparameter==")

        # anneal_fn = anneal_strategy.AnnealStrategy(kargs.get("anneal_config", {}))

        # get teacher logits
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
            "temperature", 2.0)  # log_softmax logits
        student_logit = tf.nn.log_softmax(
            logits / kargs.get("temperature", 2.0))  # log_softmax logits

        distillation_loss = kd_distance(
            teacher_logit, student_logit,
            kargs.get("distillation_distance", "kd"))
        distillation_loss *= features["distillation_ratio"]
        distillation_loss = tf.reduce_sum(distillation_loss) / (
            1e-10 + tf.reduce_sum(features["distillation_ratio"]))

        label_loss = tf.reduce_sum(
            per_example_loss * features["label_ratio"]) / (
                1e-10 + tf.reduce_sum(features["label_ratio"]))

        print(
            "==distillation loss ratio==",
            kargs.get("distillation_ratio", 0.9) *
            tf.pow(kargs.get("temperature", 2.0), 2))

        # loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
        loss = (1 -
                kargs.get("distillation_ratio", 0.9)) * label_loss + tf.pow(
                    kargs.get("temperature", 2.0), 2) * kargs.get(
                        "distillation_ratio", 0.9) * distillation_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op,
                            "cross_entropy": label_loss,
                            "kd_loss": distillation_loss,
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num": tf.reduce_sum(features["label_ratio"]),
                            "teacher_logit": teacher_logit,
                            "student_logit": student_logit,
                            "label_ratio": features["label_ratio"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Example #20
    def model_fn(features, labels, mode):

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "/task"
        else:
            scope = model_config.scope

        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        label_ids = features["label_ids"]

        target = kargs["target"]

        [a_mask, a_repres, b_mask,
         b_repres] = bert_lstm_encoding(model_config,
                                        features,
                                        labels,
                                        mode,
                                        target,
                                        max_len,
                                        scope,
                                        dropout_prob,
                                        reuse=model_reuse)

        a_repres = lstm_model(model_config, a_repres, a_mask, dropout_prob,
                              scope, model_reuse)

        b_repres = lstm_model(model_config, b_repres, b_mask, dropout_prob,
                              scope, True)

        a_output, b_output = alignment(model_config,
                                       a_repres,
                                       b_repres,
                                       a_mask,
                                       b_mask,
                                       scope,
                                       reuse=model_reuse)

        repres_a = bert_multihead_pooling(model_config,
                                          a_output,
                                          a_mask,
                                          scope,
                                          dropout_prob,
                                          reuse=model_reuse)

        repres_b = bert_multihead_pooling(model_config,
                                          b_output,
                                          b_mask,
                                          scope,
                                          dropout_prob,
                                          reuse=True)

        pair_repres = tf.concat([
            repres_a, repres_b,
            tf.abs(repres_a - repres_b), repres_b * repres_a
        ],
                                axis=-1)
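        # standard sentence-pair matching features: both pooled vectors, their
        # absolute difference and their element-wise product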

        print(pair_repres.get_shape(), "==repres shape==")

        with tf.variable_scope(scope, reuse=model_reuse):

            try:
                label_ratio_table = tf.get_variable(
                    name="label_ratio",
                    shape=[
                        num_labels,
                    ],
                    initializer=tf.constant_initializer(label_tensor),
                    trainable=False)

                ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                      label_ids)
                print("==applying class weight==")
            except:
                ratio_weight = None

            (loss, per_example_loss,
             logits) = classifier.classifier(model_config, pair_repres,
                                             num_labels, label_ids,
                                             dropout_prob, ratio_weight)
        if mode == tf.estimator.ModeKeys.TRAIN:
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            if load_pretrained:
                model_io_fn.load_pretrained(
                    pretrained_tvars,
                    init_checkpoint,
                    exclude_scope=exclude_scope_dict["task"])

        trainable_params = model_io_fn.get_params(
            scope, not_storage_params=not_storage_params)

        tvars = trainable_params

        storage_params = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        # for var in storage_params:
        # 	print(var.name, var.get_shape(), "==storage params==")

        # for var in tvars:
        # 	print(var.name, var.get_shape(), "==trainable params==")

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
Example #21
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope + "/feature_output", reuse=model_reuse):
            hidden_size = bert_utils.get_shape_list(model.get_pooled_output(),
                                                    expected_rank=2)[-1]
            feature_output = tf.layers.dense(
                model.get_pooled_output(),
                hidden_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01))
            feature_output = tf.nn.dropout(feature_output,
                                           keep_prob=1 - dropout_prob)
            feature_output += model.get_pooled_output()
            feature_output = tf.layers.dense(
                feature_output,
                hidden_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01),
                activation=tf.tanh)

        if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
            with tf.variable_scope(scope, reuse=model_reuse):
                (loss, per_example_loss,
                 logits) = classifier.classifier(model_config, feature_output,
                                                 num_labels, label_ids,
                                                 dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # print(logits.get_shape(), "===logits shape===")
            # pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            # prob = tf.nn.softmax(logits)
            # max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    # 'pred_label':pred_label,
                    # "max_prob":max_prob
                    "embedding": feature_output
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput(
                        {"embedding": feature_output})
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(per_example_loss, logits,
                                            label_ids)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()