Example #1
0
 def __init__(self,
              model_config,
              num_labels,
              init_checkpoint,
              load_pretrained=True,
              model_io_config=None,
              opt_config=None,
              exclude_scope="",
              not_storage_params=None,
              target="a",
              label_lst=None,
              output_type="sess",
              **kargs):
     """Store classifier configuration and build the IO/optimizer helpers.

     Args:
         model_config: backbone/model configuration object.
         num_labels: number of output classes for the classifier head.
         init_checkpoint: checkpoint path used to warm-start variables.
         load_pretrained: whether to restore from init_checkpoint.
             NOTE(review): sibling model_fns in this file test
             `load_pretrained == "yes"` (a string) — confirm the
             expected type with callers.
         model_io_config: config for model_io.ModelIO (defaults to {}).
         opt_config: config for optimizer.Optimizer (defaults to {}).
         exclude_scope: scope prefix handling used when restoring.
         not_storage_params: parameter names excluded from storage
             (defaults to []).
         target: target key forwarded to the model api.
         label_lst: optional label list used by eval metrics.
         output_type: "sess" for raw tensor dicts, "estimator" for
             tf.estimator.EstimatorSpec outputs.
         **kargs: extra hyper-parameters consumed downstream.
     """
     # BUG FIX: the mutable default arguments ({} / []) were shared
     # across all instances; use None sentinels instead.
     self.model_config = model_config
     self.num_labels = num_labels
     self.init_checkpoint = init_checkpoint
     self.load_pretrained = load_pretrained
     self.model_io_config = {} if model_io_config is None else model_io_config
     self.opt_config = {} if opt_config is None else opt_config
     self.exclude_scope = exclude_scope
     self.not_storage_params = [] if not_storage_params is None else not_storage_params
     self.target = target
     self.label_lst = label_lst
     self.output_type = output_type
     self.kargs = kargs
     self.model_io_fn = model_io.ModelIO(self.model_io_config)
     self.optimizer_fn = optimizer.Optimizer(self.opt_config)
    def model_fn(features, labels, mode):
        """Build a soft-label distillation classifier for one Estimator mode.

        Training loss mixes the supervised cross-entropy ("label_loss")
        with a KD distance between temperature-scaled teacher log-probs
        (from features["label_probs"]) and student logits, each term
        masked per example by features["distillation_ratio"] /
        features["label_ratio"].

        NOTE(review): depends on closure variables from the enclosing
        scope (model_config, num_labels, kargs, model_reuse, target,
        model_io_config, opt_config, load_pretrained, ...) — confirm
        all are bound where this function is created.
        """

        model_api = model_zoo(model_config)
        label_ids = features["label_ids"]

        # Backbone forward pass.
        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        # Dropout is only active during training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # With a frozen LM the task head lives in its own
        # "<scope>_finetuning" variable scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        print(kargs.get("temperature", 0.5),
              kargs.get("distillation_ratio", 0.5),
              "==distillation hyparameter==")

        # anneal_fn = anneal_strategy.AnnealStrategy(kargs.get("anneal_config", {}))

        # get teacher logits
        # Teacher targets arrive as probabilities; convert to
        # temperature-scaled log-probs (1e-10 guards against log(0)).
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
            "temperature", 2.0)  # log_softmax logits
        student_logit = tf.nn.log_softmax(
            logits / kargs.get("temperature", 2.0))  # log_softmax logits

        # Per-example KD distance, masked and normalized by the number
        # of examples that actually carry a distillation signal.
        distillation_loss = kd_distance(
            teacher_logit, student_logit,
            kargs.get("distillation_distance", "kd"))
        distillation_loss *= features["distillation_ratio"]
        distillation_loss = tf.reduce_sum(distillation_loss) / (
            1e-10 + tf.reduce_sum(features["distillation_ratio"]))

        # Supervised CE, masked/normalized by the labeled-example ratio.
        label_loss = tf.reduce_sum(
            per_example_loss * features["label_ratio"]) / (
                1e-10 + tf.reduce_sum(features["label_ratio"]))

        print(
            "==distillation loss ratio==",
            kargs.get("distillation_ratio", 0.9) *
            tf.pow(kargs.get("temperature", 2.0), 2))

        # loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
        # Hinton-style mixing: CE weighted by (1 - ratio), KD scaled by
        # T^2 to keep gradient magnitudes comparable across temperatures.
        loss = (1 -
                kargs.get("distillation_ratio", 0.9)) * label_loss + tf.pow(
                    kargs.get("temperature", 2.0), 2) * kargs.get(
                        "distillation_ratio", 0.9) * distillation_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        # NOTE(review): compared against the string "yes" although the
        # class __init__ defaults load_pretrained to True — confirm.
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            # Run UPDATE_OPS (e.g. batch norm stats) before the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                # Only the chief (task_index == 0) installs checkpoint
                # hooks; with a run_config the Estimator saves itself.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    # Raw tensors/ops for a hand-rolled session loop.
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op,
                            "cross_entropy": label_loss,
                            "kd_loss": distillation_loss,
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num": tf.reduce_sum(features["label_ratio"]),
                            "teacher_logit": teacher_logit,
                            "student_logit": student_logit,
                            "label_ratio": features["label_ratio"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Example #3
0
	def model_fn(features, labels, mode, params):
		"""Build the EBM/noise-distribution joint training graph.

		Returns a tf.contrib.tpu.TPUEstimatorSpec when
		kargs['use_tpu'] is truthy, otherwise a plain
		tf.estimator.EstimatorSpec. Only TRAIN and EVAL are supported.

		NOTE(review): relies on closure variables (model_config_dict,
		num_labels_dict, opt_config, kargs, ...) bound in the
		enclosing builder.
		"""

		ebm_noise_fce = EBM_NOISE_FCE(model_config_dict,
									num_labels_dict,
									init_checkpoint_dict,
									load_pretrained_dict,
									model_io_config=model_io_config,
									opt_config=opt_config,
									exclude_scope_dict=exclude_scope_dict,
									not_storage_params_dict=not_storage_params_dict,
									target_dict=target_dict,
									**kargs)

		model_io_fn = model_io.ModelIO(model_io_config)

		if mode == tf.estimator.ModeKeys.TRAIN:

			# Off-TPU runs use the distributed optimizer wrapper.
			if kargs.get('use_tpu', False):
				optimizer_fn = optimizer.Optimizer(opt_config)
				use_tpu = 1
			else:
				optimizer_fn = distributed_optimizer.Optimizer(opt_config)
				use_tpu = 0

			train_op, loss, var_checkpoint_dict_list = get_train_op(
								optimizer_fn, opt_config,
								model_config_dict['ebm_dist'],
								model_config_dict['noise_dist'],
								features, labels, mode, params,
								ebm_noise_fce,
								use_tpu=use_tpu)

			use_tpu = 1 if kargs.get('use_tpu', False) else 0

			# Restore every sub-model's checkpoint if any are configured.
			if len(var_checkpoint_dict_list) >= 1:
				scaffold_fn = model_io_fn.load_multi_pretrained(
												var_checkpoint_dict_list,
												use_tpu=use_tpu)
			else:
				scaffold_fn = None

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
								mode=mode,
								loss=loss,
								train_op=train_op,
								scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(
								mode=mode,
								loss=loss,
								train_op=train_op)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			ebm_noise_fce.get_loss(features, labels, mode, params, **kargs)

			# BUG FIX: use_tpu was previously assigned only in the
			# TRAIN branch, so the load_multi_pretrained call below
			# raised NameError whenever EVAL ran on its own.
			use_tpu = 1 if kargs.get('use_tpu', False) else 0

			# TPU wants (metric_fn, args) lazily; GPU wants the ops now.
			tpu_eval_metrics = (ebm_noise_eval_metric,
								[
								ebm_noise_fce.true_ebm_dist_dict['logits'],
								ebm_noise_fce.noise_dist_dict['true_logits'],
								ebm_noise_fce.fake_ebm_dist_dict['logits'],
								ebm_noise_fce.noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								ebm_noise_fce.noise_dist_dict["true_seq_logits"]
								])
			gpu_eval_metrics = ebm_noise_eval_metric(
								ebm_noise_fce.true_ebm_dist_dict['logits'],
								ebm_noise_fce.noise_dist_dict['true_logits'],
								ebm_noise_fce.fake_ebm_dist_dict['logits'],
								ebm_noise_fce.noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								ebm_noise_fce.noise_dist_dict["true_seq_logits"]
								)

			loss = ebm_noise_fce.ebm_loss + ebm_noise_fce.noise_loss
			var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list

			if len(var_checkpoint_dict_list) >= 1:
				scaffold_fn = model_io_fn.load_multi_pretrained(
												var_checkpoint_dict_list,
												use_tpu=use_tpu)
			else:
				scaffold_fn = None

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							  mode=mode,
							  loss=loss,
							  eval_metrics=tpu_eval_metrics,
							  scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
								loss=loss,
								eval_metric_ops=gpu_eval_metrics)

			return estimator_spec
		else:
			raise NotImplementedError()
    def model_fn(features, labels, mode):
        """Build a classifier trained through the KnowledgeDistillation API.

        Unlike the soft-label-only variant, the training loss here is
        label_loss plus the combined logits/feature/relational KD terms
        returned by distill.KnowledgeDistillation.distillation().

        NOTE(review): depends on closure variables from the enclosing
        scope (model_config, num_labels, kargs, model_reuse, target,
        model_io_config, opt_config, load_pretrained, ...) — confirm
        all are bound where this function is created.
        """

        model_api = model_zoo(model_config)

        # Backbone forward pass.
        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        # Dropout is only active during training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # With a frozen LM the task head lives in its own
        # "<scope>_finetuning" variable scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)
            # Supervised CE, masked/normalized by the labeled-example ratio.
            label_loss = tf.reduce_sum(
                per_example_loss * features["label_ratio"]) / (
                    1e-10 + tf.reduce_sum(features["label_ratio"]))

        if mode == tf.estimator.ModeKeys.TRAIN:

            # NOTE(review): the kwarg key "disitllation_config" is
            # misspelled — confirm callers pass the same spelling
            # before renaming it anywhere.
            distillation_api = distill.KnowledgeDistillation(
                kargs.get(
                    "disitllation_config",
                    Bunch({
                        "logits_ratio_decay": "constant",
                        "logits_ratio": 0.5,
                        "logits_decay_rate": 0.999,
                        "distillation": ['relation_kd', 'logits'],
                        "feature_ratio": 0.5,
                        "feature_ratio_decay": "constant",
                        "feature_decay_rate": 0.999,
                        "kd_type": "kd",
                        "scope": scope
                    })))
            # get teacher logits
            # Teacher probabilities -> temperature-scaled log-probs
            # (1e-10 guards against log(0)).
            teacher_logit = tf.log(features["label_probs"] +
                                   1e-10) / kargs.get(
                                       "temperature",
                                       2.0)  # log_softmax logits
            student_logit = tf.nn.log_softmax(
                logits / kargs.get("temperature", 2.0))  # log_softmax logits

            distillation_features = {
                "student_logits_tensor": student_logit,
                "teacher_logits_tensor": teacher_logit,
                "student_feature_tensor": model.get_pooled_output(),
                "teacher_feature_tensor": features["distillation_feature"],
                "student_label": tf.ones_like(label_ids, dtype=tf.int32),
                "teacher_label": tf.zeros_like(label_ids, dtype=tf.int32),
                "logits_ratio": kargs.get("logits_ratio", 0.5),
                # NOTE(review): reads "logits_ratio" for the feature
                # ratio as well — looks like a copy-paste; confirm
                # whether "feature_ratio" was intended.
                "feature_ratio": kargs.get("logits_ratio", 0.5),
                "distillation_ratio": features["distillation_ratio"],
                "src_f_logit": logits,
                "tgt_f_logit": logits,
                "src_tensor": model.get_pooled_output(),
                "tgt_tensor": features["distillation_feature"]
            }

            distillation_loss = distillation_api.distillation(
                distillation_features,
                2,
                dropout_prob,
                model_reuse,
                opt_config.num_train_steps,
                feature_ratio=1.0,
                logits_ratio_decay="constant",
                feature_ratio_decay="constant",
                feature_decay_rate=0.999,
                logits_decay_rate=0.999,
                logits_ratio=0.5,
                scope=scope + "/adv_classifier",
                num_classes=num_labels,
                gamma=kargs.get("gamma", 4))

            loss = label_loss + distillation_loss["distillation_loss"]

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        # NOTE(review): compared against the string "yes" although the
        # class __init__ defaults load_pretrained to True — confirm.
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            # Run UPDATE_OPS (e.g. batch norm stats) before the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                # Only the chief (task_index == 0) installs checkpoint
                # hooks; with a run_config the Estimator saves itself.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":

                    # Raw tensors/ops for a hand-rolled session loop.
                    return {
                        "train": {
                            "loss":
                            loss,
                            "logits":
                            logits,
                            "train_op":
                            train_op,
                            "cross_entropy":
                            label_loss,
                            "distillation_loss":
                            distillation_loss["distillation_loss"],
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num":
                            tf.reduce_sum(features["label_ratio"]),
                            "label_ratio":
                            features["label_ratio"],
                            "distilaltion_logits_loss":
                            distillation_loss["distillation_logits_loss"],
                            "distilaltion_feature_loss":
                            distillation_loss["distillation_feature_loss"],
                            "rkd_loss":
                            distillation_loss["rkd_loss"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Example #5
0
    def model_fn(features, labels, mode):
        """Build a plain (non-distillation) classifier for one Estimator mode.

        Prediction supports both multi-label (sigmoid) and single-label
        (softmax) heads, selected by model_config['label_type'].

        NOTE(review): depends on closure variables from the enclosing
        scope (model_config, num_labels, kargs, model_reuse, target,
        model_io_config, opt_config, load_pretrained, ...) — confirm
        all are bound where this function is created.
        """

        model_api = model_zoo(model_config)

        # Backbone forward pass.
        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse,
                          **kargs)

        label_ids = features["label_ids"]

        # Dropout is only active during training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # With a frozen LM the task head lives in its own
        # "<scope>_finetuning" variable scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        # Param counting is best-effort diagnostics only.
        # BUG FIX: was a bare `except:`, which also swallows
        # KeyboardInterrupt/SystemExit — narrowed to Exception.
        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except Exception:
            print("==not count params==")
        print(tvars)
        # NOTE(review): compared against the string "yes" although the
        # class __init__ defaults load_pretrained to True — confirm.
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            # Run UPDATE_OPS (e.g. batch norm stats) before the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                # Only the chief (task_index == 0) installs checkpoint
                # hooks; with a run_config the Estimator saves itself.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                print(tf.global_variables(), "==global_variables==")
                if output_type == "sess":
                    # Raw tensors/ops for a hand-rolled session loop.
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # NOTE(review): both branches export the raw probability
            # tensor as 'pred_label' (no argmax) — confirm downstream
            # consumers expect probabilities here. A label_type other
            # than 'multi_label'/'single_label' would leave
            # estimator_spec unbound (NameError at the return below).
            if model_config.get('label_type', 'single_label') == 'multi_label':
                prob = tf.nn.sigmoid(logits)
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'pred_label': prob,
                        "max_prob": prob
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'pred_label': prob,
                            "max_prob": prob
                        })
                    })
            elif model_config.get('label_type',
                                  'single_label') == "single_label":
                prob = tf.nn.softmax(logits)
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'pred_label': prob,
                        "max_prob": prob
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'pred_label': prob,
                            "max_prob": prob
                        })
                    })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss),
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(per_example_loss, logits,
                                            label_ids)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #6
0
    def model_fn(features, labels, mode):

        train_ops = []
        train_hooks = []
        logits_dict = {}
        losses_dict = {}
        features_dict = {}
        tvars = []
        task_num_dict = {}
        multi_task_config = kargs.get('multi_task_config', {})

        total_loss = tf.constant(0.0)

        task_num = 0

        encoder = {}
        hook_dict = {}

        print(task_type_dict.keys(), "==task type dict==")
        num_task = len(task_type_dict)

        from data_generator import load_w2v
        flags = kargs.get('flags', Bunch({}))
        print(flags.pretrained_w2v_path, "===pretrain vocab path===")
        w2v_path = os.path.join(flags.buckets, flags.pretrained_w2v_path)
        vocab_path = os.path.join(flags.buckets, flags.vocab_file)

        # [w2v_embed, token2id,
        # id2token, is_extral_symbol, use_pretrained] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)

        # pretrained_embed = tf.cast(tf.constant(w2v_embed), tf.float32)
        pretrained_embed = None

        for index, task_type in enumerate(task_type_dict.keys()):
            if model_config_dict[task_type].model_type in model_type_lst:
                reuse = True
            else:
                reuse = None
                model_type_lst.append(model_config_dict[task_type].model_type)

            if model_config_dict[task_type].model_type not in encoder:
                model_api = model_zoo(model_config_dict[task_type])

                model = model_api(model_config_dict[task_type],
                                  features,
                                  labels,
                                  mode,
                                  target_dict[task_type],
                                  reuse=reuse,
                                  cnn_type=model_config_dict[task_type].get(
                                      'cnn_type', 'bi_dgcnn'))
                encoder[model_config_dict[task_type].model_type] = model

                # vae_kl_model = vae_model_fn(encoder[model_config_dict[task_type].model_type],
                # 			model_config_dict[task_type],
                # 			num_labels_dict[task_type],
                # 			init_checkpoint_dict[task_type],
                # 			reuse,
                # 			load_pretrained_dict[task_type],
                # 			model_io_config,
                # 			opt_config,
                # 			exclude_scope=exclude_scope_dict[task_type],
                # 			not_storage_params=not_storage_params_dict[task_type],
                # 			target=target_dict[task_type],
                # 			label_lst=None,
                # 			output_type=output_type,
                # 			task_layer_reuse=task_layer_reuse,
                # 			task_type=task_type,
                # 			num_task=num_task,
                # 			task_adversarial=1e-2,
                # 			get_pooled_output='task_output',
                # 			feature_distillation=False,
                # 			embedding_distillation=False,
                # 			pretrained_embed=pretrained_embed,
                # 			**kargs)
                # vae_result_dict = vae_kl_model(features, labels, mode)
                # tvars.extend(vae_result_dict['tvars'])
                # total_loss += vae_result_dict["loss"]
                # for key in vae_result_dict:
                # 	if key in ['perplexity', 'token_acc', 'kl_div']:
                # 		hook_dict[key] = vae_result_dict[key]
            print(encoder, "==encode==")

            if task_type_dict[task_type] == "cls_task":
                task_model_fn = cls_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)
            elif task_type_dict[task_type] == "embed_task":
                task_model_fn = embed_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    loss='contrastive_loss',
                    apply_head_proj=False,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)
                # cpc_model_fn = embed_cpc_model_fn(encoder[model_config_dict[task_type].model_type],
                # 								model_config_dict[task_type],
                # 								num_labels_dict[task_type],
                # 								init_checkpoint_dict[task_type],
                # 								reuse,
                # 								load_pretrained_dict[task_type],
                # 								model_io_config,
                # 								opt_config,
                # 								exclude_scope=exclude_scope_dict[task_type],
                # 								not_storage_params=not_storage_params_dict[task_type],
                # 								target=target_dict[task_type],
                # 								label_lst=None,
                # 								output_type=output_type,
                # 								task_layer_reuse=task_layer_reuse,
                # 								task_type=task_type,
                # 								num_task=num_task,
                # 								task_adversarial=1e-2,
                # 								get_pooled_output='task_output',
                # 								feature_distillation=False,
                # 								embedding_distillation=False,
                # 								pretrained_embed=pretrained_embed,
                # 								loss='contrastive_loss',
                # 								apply_head_proj=False,
                # 								**kargs)

                # cpc_result_dict = cpc_model_fn(features, labels, mode)
                # result_dict['loss'] += cpc_result_dict['loss']
                # result_dict['tvars'].extend(cpc_result_dict['tvars'])
                # hook_dict["{}_all_neg_loss".format(task_type)] = cpc_result_dict['loss']
                # hook_dict["{}_all_neg_num".format(task_type)] = cpc_result_dict['task_num']

            elif task_type_dict[task_type] == "cpc_task":
                task_model_fn = embed_cpc_v1_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    loss='contrastive_loss',
                    apply_head_proj=False,
                    task_seperate_proj=True,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)

            elif task_type_dict[task_type] == "regression_task":
                task_model_fn = regression_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    loss='contrastive_loss',
                    apply_head_proj=False,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)
            else:
                continue
            print("==SUCCEEDED IN LODING==", task_type)

            # result_dict = task_model_fn(features, labels, mode)
            logits_dict[task_type] = result_dict["logits"]
            losses_dict[task_type] = result_dict["loss"]  # task loss
            for key in [
                    "pos_num", "neg_num", "masked_lm_loss", "task_loss", "acc",
                    "task_acc", "masked_lm_acc"
            ]:
                name = "{}_{}".format(task_type, key)
                if name in result_dict:
                    hook_dict[name] = result_dict[name]
            hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
            hook_dict["{}_num".format(task_type)] = result_dict["task_num"]
            print("==loss ratio==", task_type,
                  multi_task_config[task_type].get('loss_ratio', 1.0))
            total_loss += result_dict["loss"] * multi_task_config[
                task_type].get('loss_ratio', 1.0)
            hook_dict['embed_loss'] = result_dict["embed_loss"]
            hook_dict['feature_loss'] = result_dict["feature_loss"]
            hook_dict["{}_task_loss".format(
                task_type)] = result_dict["task_loss"]
            if 'positive_label' in result_dict:
                hook_dict["{}_task_positive_label".format(
                    task_type)] = result_dict["positive_label"]
            if mode == tf.estimator.ModeKeys.TRAIN:
                tvars.extend(result_dict["tvars"])
                task_num += result_dict["task_num"]
                task_num_dict[task_type] = result_dict["task_num"]
            elif mode == tf.estimator.ModeKeys.EVAL:
                features[task_type] = result_dict["feature"]

        hook_dict["total_loss"] = total_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn = model_io.ModelIO(model_io_config)

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(list(set(tvars)),
                                     string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)

            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    total_loss, list(set(tvars)), opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver(optimizer_fn.opt)

                if kargs.get("task_index", 1) == 1 and kargs.get(
                        "run_config", None):
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                elif kargs.get("task_index", 1) == 1:
                    training_hooks = []
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

            if output_type == "sess":
                return {
                    "train": {
                        "total_loss": total_loss,
                        "loss": losses_dict,
                        "logits": logits_dict,
                        "train_op": train_op,
                        "task_num_dict": task_num_dict
                    },
                    "hooks": train_hooks
                }
            elif output_type == "estimator":

                hook_dict['learning_rate'] = optimizer_fn.learning_rate
                logging_hook = tf.train.LoggingTensorHook(hook_dict,
                                                          every_n_iter=100)
                training_hooks.append(logging_hook)

                print("==hook_dict==")

                print(hook_dict)

                for key in hook_dict:
                    tf.summary.scalar(key, hook_dict[key])
                    for index, task_type in enumerate(task_type_dict.keys()):
                        tmp = "{}_loss".format(task_type)
                        if tmp == key:
                            tf.summary.scalar(
                                "loss_gap_{}".format(task_type),
                                hook_dict["total_loss"] - hook_dict[key])
                for key in task_num_dict:
                    tf.summary.scalar(key + "_task_num", task_num_dict[key])

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:  # eval execute for each class solo

            def metric_fn(logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "logits": logits_dict,
                        "total_loss": total_loss,
                        "feature": features,
                        "loss": losses_dict
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = {}
                for key in logits_dict:
                    eval_dict = metric_fn(logits_dict[key],
                                          features_task_dict[key]["label_ids"])
                    for sub_key in eval_dict.keys():
                        eval_key = "{}_{}".format(key, sub_key)
                        eval_metric_ops[eval_key] = eval_dict[sub_key]
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss / task_num,
                    eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
# Example #7
# 0
    def model_fn(features, labels, mode):
        """Two-tower BERT encoder + multi-position CRF classifier model_fn.

        NOTE(review): this is a closure — `model_config`, `model_io_config`,
        `opt_config`, `model_reuse`, `num_labels`, `load_pretrained`,
        `init_checkpoint`, `exclude_scope`, `not_storage_params`,
        `output_type`, `kargs`, and the helpers `bert_encoder`,
        `multi_position_crf_classifier`, `eval_logtis` are all free
        variables captured from the enclosing (not visible) scope; confirm
        their definitions against the outer function.

        Args:
            features: dict of tensors; reads 'input_ids_a/b',
                'segment_ids_a/b' and, in PREDICT mode, 'label_weights'.
                Mutated in place below (reshaped ids, derived masks,
                stored batch/length scalars).
            labels: forwarded to the encoder unchanged.
            mode: tf.estimator.ModeKeys value selecting which graph to build.

        Returns:
            output_type == "sess": a dict of fetch tensors/ops;
            output_type == "estimator": a tf.estimator.EstimatorSpec.
        """

        # Record the original (pre-reshape) batch size and sequence length
        # of both input sides before flattening them below.
        shape_lst_a = bert_utils.get_shape_list(features['input_ids_a'])
        batch_size_a = shape_lst_a[0]
        total_length_a = shape_lst_a[1]

        shape_lst_b = bert_utils.get_shape_list(features['input_ids_b'])
        batch_size_b = shape_lst_b[0]
        total_length_b = shape_lst_b[1]

        # Side "a": flatten to [-1, max_length]; the attention mask is
        # derived as "not PAD" (PAD id looked up in kargs under the literal
        # key '[PAD]', defaulting to 0).
        features['input_ids_a'] = tf.reshape(features['input_ids_a'],
                                             [-1, model_config.max_length])
        features['segment_ids_a'] = tf.reshape(features['segment_ids_a'],
                                               [-1, model_config.max_length])
        features['input_mask_a'] = tf.cast(
            tf.not_equal(features['input_ids_a'], kargs.get('[PAD]', 0)),
            tf.int64)

        # Side "b": flattened against max_predictions_per_seq rather than
        # max_length — presumably per-position prediction slots; TODO confirm.
        features['input_ids_b'] = tf.reshape(
            features['input_ids_b'],
            [-1, model_config.max_predictions_per_seq])
        features['segment_ids_b'] = tf.reshape(
            features['segment_ids_b'],
            [-1, model_config.max_predictions_per_seq])
        features['input_mask_b'] = tf.cast(
            tf.not_equal(features['input_ids_b'], kargs.get('[PAD]', 0)),
            tf.int64)

        # Stash original dims so downstream layers can undo the flattening.
        features['batch_size'] = batch_size_a
        features['total_length_a'] = total_length_a
        features['total_length_b'] = total_length_b

        # Encode both sides with shared weights (reuse=tf.AUTO_REUSE), keyed
        # by target name.
        model_dict = {}
        for target in ["a", "b"]:
            model = bert_encoder(model_config,
                                 features,
                                 labels,
                                 mode,
                                 target,
                                 reuse=tf.AUTO_REUSE)
            model_dict[target] = model

        # Dropout is active only while training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # With a frozen LM, task-layer variables live in a separate
        # "<scope>_finetuning" variable scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss, logits,
             transition_params) = multi_position_crf_classifier(
                 model_config, features, model_dict, num_labels, dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        # Trainable/storable variables under the model scope.
        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            # Best-effort parameter count; failure is non-fatal by design.
            print("==not count params==")
        print(tvars)
        # NOTE(review): compares against the string "yes", not a bool —
        # callers passing True would silently skip checkpoint loading.
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)
            # Ensure collected update ops (e.g. moving stats) run before the
            # optimizer step.
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

            # Wrap the train op with an EMA of the parameters; two_stage=True
            # with first_stage_steps means EMA starts after warmup.
            train_op, hooks = model_io_fn.get_ema_hooks(
                train_op,
                tvars,
                kargs.get('params_moving_average_decay', 0.99),
                scope,
                mode,
                first_stage_steps=opt_config.num_warmup_steps,
                two_stage=True)

            model_io_fn.set_saver()

            # Checkpoint hooks only on worker 0 when no run_config is given;
            # with a run_config the Estimator handles checkpointing itself.
            if kargs.get("task_index", 1) == 0 and kargs.get(
                    "run_config", None):
                training_hooks = []
            elif kargs.get("task_index", 1) == 0:
                model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                      kargs.get("num_storage_steps", 1000))

                training_hooks = model_io_fn.checkpoint_hook
            else:
                training_hooks = []

            if len(optimizer_fn.distributed_hooks) >= 1:
                training_hooks.extend(optimizer_fn.distributed_hooks)
            print(training_hooks, "==training_hooks==", "==task_index==",
                  kargs.get("task_index", 1))

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                training_hooks=training_hooks)
            print(tf.global_variables(), "==global_variables==")
            if output_type == "sess":
                return {
                    "train": {
                        "loss": loss,
                        "logits": logits,
                        "train_op": train_op
                    },
                    "hooks": training_hooks
                }
            elif output_type == "estimator":
                return estimator_spec
            # NOTE(review): any other output_type falls through and
            # implicitly returns None.

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")

            # Viterbi-decode the best tag sequence per example; per-example
            # sequence lengths come from summing the label_weights mask.
            label_weights = tf.cast(features['label_weights'], tf.int32)
            label_seq_length = tf.reduce_sum(label_weights, axis=-1)

            decode_tags, best_score = tf.contrib.crf.crf_decode(
                logits, transition_params, label_seq_length)

            # EMA-restore hook for inference (no train op to wrap here).
            _, hooks = model_io_fn.get_ema_hooks(
                None, None, kargs.get('params_moving_average_decay', 0.99),
                scope, mode)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'decode_tags': decode_tags,
                    "best_score": best_score,
                    "transition_params": transition_params,
                    "logits": logits
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'decode_tags': decode_tags,
                        "best_score": best_score,
                        "transition_params": transition_params,
                        "logits": logits
                    })
                },
                prediction_hooks=[hooks])
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            # EMA-restore hook for evaluation.
            _, hooks = model_io_fn.get_ema_hooks(
                None, None, kargs.get('params_moving_average_decay', 0.99),
                scope, mode)
            eval_hooks = []

            if output_type == "sess":
                # NOTE(review): `model` here is whichever encoder the loop
                # above built last (side "b") — confirm that is intended.
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss),
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":

                eval_metric_ops = eval_logtis(logits, features, num_labels,
                                              transition_params)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops,
                    evaluation_hooks=eval_hooks)
                return estimator_spec
        else:
            raise NotImplementedError()
    def model_fn(features, labels, mode, params):
        """EBM + noise-distribution NCE training/evaluation model_fn.

        NOTE(review): this is a closure — `model_config_dict`,
        `num_labels_dict`, `init_checkpoint_dict`, `load_pretrained_dict`,
        `model_io_config`, `opt_config`, `exclude_scope_dict`,
        `not_storage_params_dict`, `target_dict`, `kargs`, and the helpers
        `EBM_NOISE_NCE`, `get_train_op`, `ebm_train_metric`,
        `ebm_eval_metric` are free variables captured from the enclosing
        (not visible) scope; confirm their definitions there.

        Args:
            features, labels, params: standard (TPU)Estimator inputs.
            mode: tf.estimator.ModeKeys; only TRAIN and EVAL are handled.

        Returns:
            tf.contrib.tpu.TPUEstimatorSpec when running on TPU, otherwise
            tf.estimator.EstimatorSpec.

        Raises:
            NotImplementedError: for any mode other than TRAIN/EVAL.
        """

        # 'joint' presumably trains EBM and generator together; other values
        # come in via kargs — TODO confirm supported values in get_train_op.
        train_op_type = kargs.get('train_op_type', 'joint')

        # Container object bundling the EBM, noise and generator sub-models
        # plus their checkpoints/losses.
        ebm_noise_fce = EBM_NOISE_NCE(
            model_config_dict,
            num_labels_dict,
            init_checkpoint_dict,
            load_pretrained_dict,
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            **kargs)

        model_io_fn = model_io.ModelIO(model_io_config)
        # use_tpu is carried as a 0/1 int flag (re-derived in TRAIN below).
        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if mode == tf.estimator.ModeKeys.TRAIN:

            # Plain optimizer on TPU; distributed optimizer otherwise.
            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            # Builds losses and the (possibly alternating) train op over the
            # ebm_dist / noise_dist / generator sub-configs.
            train_op = get_train_op(ebm_noise_fce,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['ebm_dist'],
                                    model_config_dict['noise_dist'],
                                    model_config_dict['generator'],
                                    features,
                                    labels,
                                    mode,
                                    params,
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    alternate_order=['ebm', 'generator'])

            # Populate var->checkpoint assignments, then read out the
            # aggregate loss and trainable variables computed above.
            ebm_noise_fce.load_pretrained_model(**kargs)
            var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
            loss = ebm_noise_fce.loss
            tvars = ebm_noise_fce.tvars

            # On TPU, multi-checkpoint init must go through a scaffold_fn.
            if len(var_checkpoint_dict_list) >= 1:
                scaffold_fn = model_io_fn.load_multi_pretrained(
                    var_checkpoint_dict_list, use_tpu=use_tpu)
            else:
                scaffold_fn = None

            # Training metrics compare EBM logits on true vs generated data.
            metric_dict = ebm_train_metric(
                ebm_noise_fce.true_ebm_dist_dict['logits'],
                ebm_noise_fce.fake_ebm_dist_dict['logits'])

            # Summaries are only written off-TPU.
            if not kargs.get('use_tpu', False):
                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("ebm_loss",
                                  ebm_noise_fce.ebm_opt_dict['ebm_loss'])
                tf.summary.scalar("mlm_loss",
                                  ebm_noise_fce.ebm_opt_dict['mlm_loss'])
                tf.summary.scalar("all_loss",
                                  ebm_noise_fce.ebm_opt_dict['all_loss'])

            model_io_fn.print_params(tvars, string=", trainable params")

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            # In EVAL the loss graph is built explicitly here (in TRAIN it is
            # presumably built inside get_train_op — TODO confirm).
            ebm_noise_fce.get_loss(features, labels, mode, params, **kargs)
            ebm_noise_fce.load_pretrained_model(**kargs)
            var_checkpoint_dict_list = ebm_noise_fce.var_checkpoint_dict_list
            loss = ebm_noise_fce.loss

            if len(var_checkpoint_dict_list) >= 1:
                scaffold_fn = model_io_fn.load_multi_pretrained(
                    var_checkpoint_dict_list, use_tpu=use_tpu)
            else:
                scaffold_fn = None

            # TPU expects eval_metrics as a (metric_fn, args) tuple; GPU/CPU
            # takes the already-computed metric-ops dict.
            tpu_eval_metrics = (ebm_eval_metric, [
                ebm_noise_fce.true_ebm_dist_dict['logits'],
                ebm_noise_fce.fake_ebm_dist_dict['logits']
            ])
            gpu_eval_metrics = ebm_eval_metric(
                ebm_noise_fce.true_ebm_dist_dict['logits'],
                ebm_noise_fce.fake_ebm_dist_dict['logits'])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=model_reuse)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = pretrain.get_masked_lm_output(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=model_reuse)
        loss = model_config.lm_ratio * masked_lm_loss + model_config.nsp_ratio * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            lm_pretrain_tvars = model_io_fn.get_params(
                "cls", not_storage_params=not_storage_params)

            pretrained_tvars.extend(lm_pretrain_tvars)

            optimizer_fn = optimizer.Optimizer(opt_config)

            if load_pretrained:
                model_io_fn.load_pretrained(pretrained_tvars,
                                            init_checkpoint,
                                            exclude_scope=exclude_scope)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):

                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                train_op, hooks = model_io_fn.get_ema_hooks(
                    train_op,
                    tvars,
                    kargs.get('params_moving_average_decay', 0.99),
                    scope,
                    mode,
                    first_stage_steps=opt_config.num_warmup_steps,
                    two_stage=True)

                model_io_fn.set_saver()

                train_metric_dict = train_metric_fn(
                    masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                    masked_lm_weights, nsp_per_example_loss, nsp_log_prob,
                    features['next_sentence_labels'])

                for key in train_metric_dict:
                    tf.summary.scalar(key, train_metric_dict[key])
                tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "nsp_log_pro": nsp_log_prob,
                            "train_op": train_op,
                            "masked_lm_loss": masked_lm_loss,
                            "next_sentence_loss": nsp_loss,
                            "masked_lm_log_pro": masked_lm_log_probs
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:

            def prediction_fn(logits):

                predictions = {
                    "nsp_classes":
                    tf.argmax(input=nsp_log_prob, axis=1),
                    "nsp_probabilities":
                    tf.exp(nsp_log_prob, name="nsp_softmax"),
                    "masked_vocab_classes":
                    tf.argmax(input=masked_lm_log_probs, axis=1),
                    "masked_probabilities":
                    tf.exp(masked_lm_log_probs, name='masked_softmax')
                }
                return predictions

            predictions = prediction_fn(logits)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={
                    "output": tf.estimator.export.PredictOutput(predictions)
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss
                }

            if output_type == "sess":
                return {
                    "eval": {
                        "nsp_log_prob": nsp_log_prob,
                        "masked_lm_log_prob": masked_lm_log_probs,
                        "nsp_loss": nsp_loss,
                        "masked_lm_loss": masked_lm_loss,
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(masked_lm_example_loss,
                                            masked_lm_log_probs, masked_lm_ids,
                                            masked_lm_weights,
                                            nsp_per_example_loss, nsp_log_prob,
                                            features['next_sentence_labels'])
                _, hooks = model_io_fn.get_ema_hooks(
                    None, None, kargs.get('params_moving_average_decay', 0.99),
                    scope, mode)

                eval_hooks = [hooks] if hooks else []

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops,
                    evaluation_hooks=eval_hooks)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #10
0
    def model_fn(features, labels, mode, params):
        """ELECTRA-style joint model_fn: a generator proposes sampled token
        ids and a discriminator is trained to score them.

        Flow: build + run the generator on the raw features, route its
        sampled outputs into the discriminator, optionally sum both losses
        (kargs['joint_train'] == '1'), restore pretrained weights, and
        return a TPUEstimatorSpec / EstimatorSpec for TRAIN or EVAL.
        PREDICT is not supported and raises NotImplementedError.
        """

        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            **kargs)
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        # The discriminator consumes the generator's sampled sequence as its
        # "input_ids" and the generator's (uncorrupted) sampled_input_ids as
        # the reference "input_ori_ids".
        discriminator_features = {}
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)

        model_io_fn = model_io.ModelIO(model_io_config)

        # Collect trainable variables and the total loss.  The discriminator
        # is always trained; the generator is included only in joint mode.
        tvars = []
        loss = discriminator_dict['loss']
        print(loss)
        tvars.extend(discriminator_dict['tvars'])
        if kargs.get('joint_train', '0') == '1':
            # BUG FIX: the original indexed the callable `generator_fn`
            # ("generator_fn['tvars']"), which raises TypeError; the variable
            # list lives in the output dict `generator_dict`.
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']

        # Build per-sub-model checkpoint-restore specs for every sub-model
        # whose load_pretrained flag is the string "yes".
        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars": generator_dict['tvars'],
                        "init_checkpoint": init_checkpoint_dict['generator'],
                        "exclude_scope": exclude_scope_dict[key]
                    }
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars": discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope": exclude_scope_dict[key]
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            metric_dict = discriminator_metric_train(
                discriminator_dict['per_example_loss'],
                discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])

            for key in metric_dict:
                tf.summary.scalar(key, metric_dict[key])

            # TPU runs use the plain optimizer; otherwise the distributed one.
            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            # Run batch-norm style update ops before the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    list(set(tvars)),  # dedupe shared variables
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':
                generator_metric = generator_metric_fn_eval(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None))
            else:
                generator_metric = {}

            discriminator_metric = discriminator_metric_eval(
                discriminator_dict['per_example_loss'],
                discriminator_dict['logits'],
                generator_dict['sampled_input_ids'],
                generator_dict['sampled_ids'],
                generator_dict['sampled_input_mask'])

            metric_dict = discriminator_metric
            if len(generator_metric):
                # BUG FIX: the original updated with `discriminator_metric`
                # again, silently dropping the generator eval metrics.
                metric_dict.update(generator_metric)

            # NOTE(review): TPUEstimatorSpec's `eval_metrics` normally takes
            # a (metric_fn, tensors) tuple rather than a plain dict — confirm
            # against the TPU path actually used.
            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=metric_dict,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=metric_dict)

            return estimator_spec
        else:
            raise NotImplementedError()
    def model_fn(features, labels, mode, params):
        """NCE-style generator/discriminator model_fn.

        Builds a generator (variant chosen by kargs['optimization_type']:
        'grl' gradient-reversal, 'minmax' normal generator, or default),
        runs the discriminator twice — once on the true sampled inputs and
        once on the generator's fake sampled ids — and trains on an NCE loss
        over the two discriminator passes, optionally joined with the
        generator loss.  Returns (TPU)EstimatorSpec for TRAIN/EVAL; PREDICT
        raises NotImplementedError.
        """

        train_op_type = kargs.get('train_op_type', 'joint')
        if kargs.get('optimization_type', 'grl') == 'grl':
            # Gradient-reversal training: generator and discriminator share
            # one joint train op.
            generator_fn = generator(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            generator_fn = generator_normal(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
        else:
            generator_fn = generator(
                model_config_dict['generator'],
                num_labels_dict['generator'],
                init_checkpoint_dict['generator'],
                model_reuse=None,
                load_pretrained=load_pretrained_dict['generator'],
                model_io_config=model_io_config,
                opt_config=opt_config,
                exclude_scope=exclude_scope_dict.get('generator', ""),
                not_storage_params=not_storage_params_dict.get(
                    'generator', []),
                target=target_dict['generator'],
                **kargs)
        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        # "True" discriminator pass: input_ids are the uncorrupted
        # sampled_input_ids (positive examples for the NCE loss).
        tf.logging.info("****** true sampled_ids of discriminator *******")
        true_distriminator_features = {}
        true_distriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_distriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_distriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_distriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_distriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_distriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_distriminator_dict = discriminator_fn(true_distriminator_features,
                                                   labels, mode, params)

        # "Fake" discriminator pass: input_ids are the generator's sampled
        # (corrupted) ids — the negative examples.
        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            tf.logging.info("****** conditioanl sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

        # Noise-contrastive loss over the true vs. fake discriminator passes.
        nce_loss = nce_loss_fn(true_distriminator_dict,
                               true_distriminator_features,
                               fake_discriminator_dict,
                               fake_discriminator_features)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

        loss = kargs.get('dis_loss', 1.0) * nce_loss

        tvars.extend(fake_discriminator_dict['tvars'])

        # NOTE(review): joint_train defaults to '1' here but to '0' in the
        # EVAL branch below — confirm the intended default.
        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        # Build per-sub-model checkpoint-restore specs.
        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    # BUG FIX: the original referenced `discriminator_dict`,
                    # which is not defined in this function (only the
                    # true/fake variants exist) and would raise NameError.
                    tmp = {
                        "tvars":
                        fake_discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('summary_debug', False):
                metric_dict = discriminator_metric_train(
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    list(set(tvars)),
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    """Merge generator and discriminator eval metrics."""
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                # TPU wants (metric_fn, tensor_list); GPU wants the dict.
                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    fake_discriminator_dict['per_example_loss'],
                    fake_discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
	def model_fn(features, labels, mode):
		"""Adversarial-adaptation + knowledge-distillation model_fn.

		Builds the encoder, splits the pooled feature into a task branch and
		a gradient-reversed adversarial branch, classifies both, and during
		TRAIN mixes the task cross-entropy with a teacher-distillation loss
		(features['label_probs'] are the teacher probabilities) plus a
		weighted adversarial loss.  Returns an EstimatorSpec or a sess-style
		dict depending on `output_type`.
		"""

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE)

		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		common_feature = model.get_pooled_output()

		# Task branch (plain) and adversarial branch (gradient-reversal).
		task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/task_residual", if_grl=False)
		adv_task_feature = get_task_feature(model_config, common_feature, dropout_prob, scope+"/adv_residual", if_grl=True)

		with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
			concat_feature = task_feature
			# BUG FIX: the original passed `*kargs`, which unpacks the dict
			# KEYS as positional arguments; the parallel adv_classifier call
			# below correctly uses `**kargs`.
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											concat_feature,
											num_labels,
											label_ids,
											dropout_prob,
											**kargs)

		with tf.variable_scope(scope+"/adv_classifier", reuse=tf.AUTO_REUSE):
			adv_ids = features["adv_ids"]
			(adv_loss, 
				adv_per_example_loss, 
				adv_logits) = classifier.classifier(model_config,
											adv_task_feature,
											kargs.get('adv_num_labels', 12),
											adv_ids,
											dropout_prob,
											**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			loss_diff = tf.constant(0.0)

			print(kargs.get("temperature", 0.5), kargs.get("distillation_ratio", 0.5), "==distillation hyparameter==")

			# Teacher/student log-softmax logits at the same temperature.
			teacher_logit = tf.log(features["label_probs"]+1e-10)/kargs.get("temperature", 2.0) # log_softmax logits
			student_logit = tf.nn.log_softmax(logits /kargs.get("temperature", 2.0)) # log_softmax logits

			# Distillation loss masked/normalized by the per-example
			# distillation_ratio weights.
			distillation_loss = kd_distance(teacher_logit, student_logit, kargs.get("distillation_distance", "kd")) 
			distillation_loss *= features["distillation_ratio"]
			distillation_loss = tf.reduce_sum(distillation_loss) / (1e-10+tf.reduce_sum(features["distillation_ratio"]))

			label_loss = loss
			print("==distillation loss ratio==", kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2))

			# Convex mix of hard-label loss and distillation loss, plus the
			# weighted adversarial loss.
			loss = (1-kargs.get("distillation_ratio", 0.9))*label_loss + kargs.get("distillation_ratio", 0.9) * distillation_loss
			if mode == tf.estimator.ModeKeys.TRAIN:
				loss += kargs.get("adv_ratio", 0.1) * adv_loss + loss_diff

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		print(tvars)
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				# Checkpoint hooks only on the chief worker (task_index 0)
				# when no run_config is supplied.
				if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)
				if output_type == "sess":

					# Extra monitoring tensors for sess-style training loops.
					adv_pred_label = tf.argmax(adv_logits, axis=-1, output_type=tf.int32)
					adv_correct = tf.equal(
						tf.cast(adv_pred_label, tf.int32),
						tf.cast(adv_ids, tf.int32)
					)
					adv_accuracy = tf.reduce_mean(tf.cast(adv_correct, tf.float32))
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op,
										"cross_entropy":label_loss,
										"kd_loss":distillation_loss,
										"kd_num":tf.reduce_sum(features["distillation_ratio"]),
										"ce_num":tf.reduce_sum(features["label_ratio"]),
										"teacher_logit":teacher_logit,
										"student_logit":student_logit,
										"label_ratio":features["label_ratio"],
										"loss_diff":loss_diff,
										"adv_loss":adv_loss,
										"adv_accuracy":adv_accuracy
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec

		elif mode == tf.estimator.ModeKeys.PREDICT:
			task_prob = tf.exp(tf.nn.log_softmax(logits))
			adv_prob = tf.exp(tf.nn.log_softmax(adv_logits))
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'adv_prob':adv_prob,
												"task_prob":task_prob
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'adv_prob':adv_prob,
														"task_prob":task_prob
													}
												)
									}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)

			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #13
0
    def model_fn(features, labels, mode):
        """Build the TRAIN / PREDICT / EVAL graph for a multi-position BERT classifier.

        Closure over the enclosing scope: model_config, model_io_config,
        opt_config, num_labels, target, model_reuse, load_pretrained,
        init_checkpoint, exclude_scope, not_storage_params, output_type,
        kargs, bert_encoder, multi_position_classifier, train_metric_fn,
        eval_logtis.

        Returns:
            When output_type == "sess": a plain dict of tensors/ops
            (plus "hooks" for TRAIN). When output_type == "estimator":
            a tf.estimator.EstimatorSpec.

        Raises:
            NotImplementedError: for any mode other than TRAIN/PREDICT/EVAL.
        """

        # Rebuild the attention mask from the ids: 1 wherever the token
        # differs from the pad id (looked up via kargs['[PAD]'], default 0).
        features['input_mask'] = tf.cast(
            tf.not_equal(features['input_ids'], kargs.get('[PAD]', 0)),
            tf.int64)

        # for key in ['input_mask', 'input_ids', 'segment_ids']:
        # 	features[key] = features[key][:, :274]

        model = bert_encoder(model_config,
                             features,
                             labels,
                             mode,
                             target,
                             reuse=tf.AUTO_REUSE)

        # Dropout is active only during training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # When the language model is frozen (fix_lm), the classifier head
        # lives in a separate "<scope>_finetuning" variable scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = multi_position_classifier(model_config, features,
                                                 model.get_sequence_output(),
                                                 num_labels, dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        # Parameter counting is best-effort and purely informational.
        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            # Run collected UPDATE_OPS (e.g. batch-norm moving averages)
            # before the optimizer step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

            # train_op, hooks = model_io_fn.get_ema_hooks(train_op,
            # 				tvars,
            # 				kargs.get('params_moving_average_decay', 0.99),
            # 				scope, mode,
            # 				first_stage_steps=opt_config.num_warmup_steps,
            # 				two_stage=True)

            model_io_fn.set_saver()

            train_metric_dict = train_metric_fn(logits, features, num_labels)

            for key in train_metric_dict:
                tf.summary.scalar(key, train_metric_dict[key])
            tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

            # Checkpoint hooks: only the chief (task_index == 0) creates a
            # checkpoint hook, and only when no run_config is supplied
            # (otherwise checkpointing is left to the estimator run_config).
            if kargs.get("task_index", 1) == 0 and kargs.get(
                    "run_config", None):
                training_hooks = []
            elif kargs.get("task_index", 1) == 0:
                model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                      kargs.get("num_storage_steps", 1000))

                training_hooks = model_io_fn.checkpoint_hook
            else:
                training_hooks = []

            if len(optimizer_fn.distributed_hooks) >= 1:
                training_hooks.extend(optimizer_fn.distributed_hooks)
            print(training_hooks, "==training_hooks==", "==task_index==",
                  kargs.get("task_index", 1))

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=loss,
                train_op=train_op,
                training_hooks=training_hooks)
            print(tf.global_variables(), "==global_variables==")
            if output_type == "sess":
                return {
                    "train": {
                        "loss": loss,
                        "logits": logits,
                        "train_op": train_op
                    },
                    "hooks": training_hooks
                }
            elif output_type == "estimator":
                return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            # Hard prediction plus its softmax confidence.
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            # _, hooks = model_io_fn.get_ema_hooks(None,
            # 							None,
            # 							kargs.get('params_moving_average_decay', 0.99),
            # 							scope, mode)

            hooks = []

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                },
                prediction_hooks=hooks)
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            # NOTE(review): the EMA hooks returned here are discarded and
            # eval_hooks is reset to [] below — confirm whether EMA restore
            # was meant to be attached to evaluation.
            _, hooks = model_io_fn.get_ema_hooks(
                None, None, kargs.get('params_moving_average_decay', 0.99),
                scope, mode)
            eval_hooks = []

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss),
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":

                eval_metric_ops = eval_logtis(logits, features, num_labels)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops,
                    evaluation_hooks=eval_hooks)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #14
0
	def model_fn(features, labels, mode, params):
		"""Build a noise-contrastive-estimation graph for an energy-based LM.

		Pairs an energy-based model ("ebm_dist") against a noise
		distribution ("noise_dist"); fake samples come either from the
		noise model itself (GPT-style) or from a separate BERT-MLM
		generator. Computes the EBM loss, a log-Z length-conditioned
		regularizer, and the noise-model loss, optionally mixing noise
		samples into the true data (DNCE), then returns a TRAIN or EVAL
		(TPU)EstimatorSpec.

		Closure over the enclosing scope: model_config_dict,
		num_labels_dict, init_checkpoint_dict, load_pretrained_dict,
		exclude_scope_dict, not_storage_params_dict, target_dict,
		model_io_config, opt_config, kargs, plus the helper builders
		(ebm_dist, noise_dist, mlm_noise_dist, get_ebm_loss, ...).
		"""

		train_op_type = kargs.get('train_op_type', 'joint')
		print("==input shape==", features["input_ids"].get_shape())

		# Builder for the energy-based model's scoring function.
		ebm_dist_fn = ebm_dist(model_config_dict['ebm_dist'],
					num_labels_dict['ebm_dist'],
					init_checkpoint_dict['ebm_dist'],
					model_reuse=None,
					load_pretrained=load_pretrained_dict['ebm_dist'],
					model_io_config=model_io_config,
					opt_config=opt_config,
					exclude_scope=exclude_scope_dict.get('ebm_dist', ""),
					not_storage_params=not_storage_params_dict.get('ebm_dist', []),
					target=target_dict['ebm_dist'],
					prob_ln=False,
					transform=False,
					transformer_activation="linear",
					logz_mode='standard',
					normalized_constant="length_linear",
					energy_pooling="mi",
					softplus_features=False,
					**kargs)

		noise_prob_ln = False
		noise_sample = kargs.get("noise_sample", 'mlm')

		# Sample source for fakes: 'gpt' lets the noise model sample
		# directly (sample_noise_dist=True); 'mlm' delegates sampling to a
		# separate BERT-MLM generator (sample_noise_dist=False). Any other
		# value falls back to the GPT behaviour.
		if kargs.get("noise_sample", 'mlm') == 'gpt':
			tf.logging.info("****** using gpt for noise dist sample *******")
			sample_noise_dist = True
		elif kargs.get("noise_sample", 'mlm') == 'mlm':
			tf.logging.info("****** using bert mlm for noise dist sample *******")
			sample_noise_dist = False
		else:
			tf.logging.info("****** using gpt for noise dist sample *******")
			sample_noise_dist = True

		# Builder for the noise distribution used as the NCE contrast.
		noise_dist_fn = noise_dist(model_config_dict['noise_dist'],
					num_labels_dict['noise_dist'],
					init_checkpoint_dict['noise_dist'],
					model_reuse=None,
					load_pretrained=load_pretrained_dict['noise_dist'],
					model_io_config=model_io_config,
					opt_config=opt_config,
					exclude_scope=exclude_scope_dict.get('noise_dist', ""),
					not_storage_params=not_storage_params_dict.get('noise_dist', []),
					target=target_dict['noise_dist'],
					noise_true_distribution=True,
					sample_noise_dist=sample_noise_dist,
					noise_estimator_type=kargs.get("noise_estimator_type", "stop_gradient"),
					prob_ln=noise_prob_ln,
					if_bp=True,
					**kargs)

		if not sample_noise_dist:
			tf.logging.info("****** using bert mlm for noise dist sample *******")

			# The MLM generator's mask probability is annealed linearly
			# from 0.20 down to 0.1 over num_train_steps.
			global_step = tf.train.get_or_create_global_step()
			noise_sample_ratio = tf.train.polynomial_decay(
													0.20,
													global_step,
													opt_config.num_train_steps,
													end_learning_rate=0.1,
													power=1.0,
													cycle=False)

			mlm_noise_dist_fn = mlm_noise_dist(model_config_dict['generator'],
						num_labels_dict['generator'],
						init_checkpoint_dict['generator'],
						model_reuse=None,
						load_pretrained=load_pretrained_dict['generator'],
						model_io_config=model_io_config,
						opt_config=opt_config,
						exclude_scope=exclude_scope_dict.get('generator', ""),
						not_storage_params=not_storage_params_dict.get('generator', []),
						target=target_dict['generator'],
						mask_probability=noise_sample_ratio,
						replace_probability=0.2,
						original_probability=0.0,
						**kargs)
		else:
			mlm_noise_dist_fn = None

		# "True" data features for the EBM/noise models: original ids as
		# input_ids, everything cast to int32.
		true_features = {}

		for key in features:
			if key == 'input_ori_ids':
				true_features["input_ids"] = tf.cast(features['input_ori_ids'], tf.int32)
			if key in ['input_mask', 'segment_ids']:
				true_features[key] = tf.cast(features[key], tf.int32)

		# DNCE: replace a (possibly annealed, 0.10 -> 0.05) fraction of the
		# true batch with MLM-generated samples.
		if kargs.get("dnce", False):

			if kargs.get("anneal_dnce", False):
				global_step = tf.train.get_or_create_global_step()
				noise_sample_ratio = tf.train.polynomial_decay(
														0.10,
														global_step,
														opt_config.num_train_steps,
														end_learning_rate=0.05,
														power=1.0,
														cycle=False)
				tf.logging.info("****** anneal dnce mix ratio *******")
			else:
				noise_sample_ratio = 0.10
				tf.logging.info("****** not anneal dnce mix ratio *******")

			mlm_noise_noise_dist_fn = mlm_noise_dist(model_config_dict['generator'],
						num_labels_dict['generator'],
						init_checkpoint_dict['generator'],
						model_reuse=None,
						load_pretrained=load_pretrained_dict['generator'],
						model_io_config=model_io_config,
						opt_config=opt_config,
						exclude_scope=exclude_scope_dict.get('generator', ""),
						not_storage_params=not_storage_params_dict.get('generator', []),
						target=target_dict['generator'],
						mask_probability=noise_sample_ratio,
						replace_probability=0.0,
						original_probability=0.0,
						**kargs)

			mlm_noise_dist_dict_noise = mlm_noise_noise_dist_fn(features, labels, mode, params)

			mixed_mask = mixed_sample(features, mix_ratio=noise_sample_ratio)
			tf.logging.info("****** apply dnce *******")
			mixed_mask = tf.expand_dims(mixed_mask, axis=-1) # batch_size x 1
			mixed_mask = tf.cast(mixed_mask, tf.int32)
			# Per-example convex switch: keep the true ids where mask==0,
			# take sampled ids where mask==1.
			true_features["input_ids"] = (1-mixed_mask)*true_features["input_ids"] + mixed_mask * mlm_noise_dist_dict_noise['sampled_ids']

		if not sample_noise_dist:
			mlm_noise_dist_dict = mlm_noise_dist_fn(features, labels, mode, params)
		else:
			mlm_noise_dist_dict = {}

		# first get noise dict
		noise_dist_dict = noise_dist_fn(true_features, labels, mode, params)

		# third, get fake ebm dict
		fake_features = {}

		if noise_sample == 'gpt':
			# GPT fakes: either stop-gradient hard samples or gumbel
			# soft samples that keep a gradient path.
			if kargs.get("training_mode", "stop_gradient") == 'stop_gradient':
				fake_features["input_ids"] = noise_dist_dict['fake_samples']
				tf.logging.info("****** using samples stop gradient *******")
			elif kargs.get("training_mode", "stop_gradient") == 'adv_gumbel':
				fake_features["input_ids"] = noise_dist_dict['gumbel_probs']
				tf.logging.info("****** using samples with gradient *******")
			fake_features['input_mask'] = tf.cast(noise_dist_dict['fake_mask'], tf.int32)
			fake_features['segment_ids'] = tf.zeros_like(fake_features['input_mask'])
		elif noise_sample == 'mlm':
			fake_features["input_ids"] = mlm_noise_dist_dict['sampled_ids']
			fake_features['input_mask'] = tf.cast(features['input_mask'], tf.int32)
			fake_features['segment_ids'] = tf.zeros_like(features['input_mask'])
			tf.logging.info("****** using bert mlm stop gradient *******")

		# second, get true ebm dict
		true_ebm_dist_dict = ebm_dist_fn(true_features, labels, mode, params)
		fake_ebm_dist_dict = ebm_dist_fn(fake_features, labels, mode, params)
		if not sample_noise_dist:
			# MLM generator can't score its own samples under the noise
			# model, so re-score the fakes with the noise dist here.
			fake_noise_dist_dict = noise_dist_fn(fake_features, labels, mode, params)
			noise_dist_dict['fake_logits'] = fake_noise_dist_dict['true_logits']

		[ebm_loss, 
		ebm_all_true_loss,
		ebm_all_fake_loss] = get_ebm_loss(true_ebm_dist_dict['logits'], 
								noise_dist_dict['true_logits'], 
								fake_ebm_dist_dict['logits'], 
								noise_dist_dict['fake_logits'], 
								use_tpu=kargs.get('use_tpu', False),
								valid_mask=mlm_noise_dist_dict.get("valid_mask", None))

		# Length-conditioned log-partition regularizer, applied to both the
		# true and the fake branches.
		logz_length_true_loss = ebm_logz_length_cond_loss(model_config_dict['ebm_dist'],
															true_features,
															ebm_all_true_loss,
															valid_mask=mlm_noise_dist_dict.get("valid_mask", None))

		logz_length_fake_loss = ebm_logz_length_cond_loss(model_config_dict['ebm_dist'],
															fake_features,
															ebm_all_fake_loss,
															valid_mask=mlm_noise_dist_dict.get("valid_mask", None))
		true_ebm_dist_dict['logz_loss'] = logz_length_true_loss + logz_length_fake_loss

		noise_loss = get_noise_loss(true_ebm_dist_dict['logits'], 
									noise_dist_dict['true_logits'], 
									fake_ebm_dist_dict['logits'], 
									noise_dist_dict['fake_logits'], 
									noise_loss_type=kargs.get('noise_loss_type', 'jsd_noise'),
									num_train_steps=opt_config.num_train_steps,
									num_warmup_steps=opt_config.num_warmup_steps,
									use_tpu=kargs.get('use_tpu', False),
									loss_mask=features['input_mask'],
									prob_ln=noise_prob_ln)

		model_io_fn = model_io.ModelIO(model_io_config)

		# Trainable variables and the reported loss; the noise model is
		# included only under joint training.
		tvars = []
		loss = ebm_loss
		tvars.extend(true_ebm_dist_dict['tvars'])

		if kargs.get('joint_train', '1') == '1':
			tf.logging.info("****** joint generator and discriminator training *******")
			tvars.extend(noise_dist_dict['tvars'])
			loss += noise_loss
		tvars = list(set(tvars))

		ebm_opt_dict = {
			"loss":ebm_loss,
			"tvars":true_ebm_dist_dict['tvars'],
			"logz_tvars":true_ebm_dist_dict['logz_tvars'],
			"logz_loss":true_ebm_dist_dict['logz_loss']
		}

		noise_opt_dict = {
			"loss":noise_loss,
			"tvars":noise_dist_dict['tvars']
		}

		# Assemble per-submodel checkpoint-restore specs; sharing_mode
		# clears the exclude_scope for ebm/generator restores.
		var_checkpoint_dict_list = []
		for key in init_checkpoint_dict:
			if load_pretrained_dict[key] == "yes":
				if key == 'ebm_dist':
					tmp = {
							"tvars":ebm_opt_dict['tvars']+ebm_opt_dict['logz_tvars'],
							"init_checkpoint":init_checkpoint_dict['ebm_dist'],
							"exclude_scope":exclude_scope_dict[key],
							"restore_var_name":model_config_dict['ebm_dist'].get('restore_var_name', [])
					}
					if kargs.get("sharing_mode", "none") != "none":
						tmp['exclude_scope'] = ''
					var_checkpoint_dict_list.append(tmp)
				elif key == 'noise_dist':
					tmp = {
							"tvars":noise_opt_dict['tvars'],
							"init_checkpoint":init_checkpoint_dict['noise_dist'],
							"exclude_scope":exclude_scope_dict[key],
							"restore_var_name":model_config_dict['noise_dist'].get('restore_var_name', [])
					}
					var_checkpoint_dict_list.append(tmp)
				elif key == 'generator':
					if not sample_noise_dist:
						tmp = {
								"tvars":mlm_noise_dist_dict['tvars'],
								"init_checkpoint":init_checkpoint_dict['generator'],
								"exclude_scope":exclude_scope_dict[key],
								"restore_var_name":model_config_dict['generator'].get('restore_var_name', [])
						}
						if kargs.get("sharing_mode", "none") != "none":
							tmp['exclude_scope'] = ''
						var_checkpoint_dict_list.append(tmp)

		use_tpu = 1 if kargs.get('use_tpu', False) else 0
			
		if len(var_checkpoint_dict_list) >= 1:
			scaffold_fn = model_io_fn.load_multi_pretrained(
											var_checkpoint_dict_list,
											use_tpu=use_tpu)
		else:
			scaffold_fn = None

		if mode == tf.estimator.ModeKeys.TRAIN:

			metric_dict = ebm_noise_train_metric(
										true_ebm_dist_dict['logits'], 
										noise_dist_dict['true_logits'], 
										fake_ebm_dist_dict['logits'], 
										noise_dist_dict['fake_logits'],
										features['input_ori_ids'],
										tf.cast(features['input_mask'], tf.float32),
										noise_dist_dict["true_seq_logits"],
										prob_ln=noise_prob_ln,
										)

			# Summaries are unsupported on TPU, so only emit them off-TPU.
			if not kargs.get('use_tpu', False):
				for key in metric_dict:
					tf.summary.scalar(key, metric_dict[key])
				tf.summary.scalar("ebm_loss", ebm_opt_dict['loss'])
				tf.summary.scalar("noise_loss", noise_opt_dict['loss'])
	
			if kargs.get('use_tpu', False):
				optimizer_fn = optimizer.Optimizer(opt_config)
				use_tpu = 1
			else:
				optimizer_fn = distributed_optimizer.Optimizer(opt_config)
				use_tpu = 0

			model_io_fn.print_params(tvars, string=", trainable params")

			train_op = get_train_op(ebm_opt_dict, noise_opt_dict, 
								optimizer_fn, opt_config,
								model_config_dict['ebm_dist'], 
								model_config_dict['noise_dist'],
								use_tpu=use_tpu, 
								train_op_type=train_op_type,
								fce_acc=metric_dict['all_accuracy'])
			
			# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			# with tf.control_dependencies(update_ops):
			# 	train_op = optimizer_fn.get_train_op(loss, list(set(tvars)),
			# 					opt_config.init_lr, 
			# 					opt_config.num_train_steps,
			# 					use_tpu=use_tpu)

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
								mode=mode,
								loss=loss,
								train_op=train_op,
								scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(
								mode=mode, 
								loss=loss, 
								train_op=train_op)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			# TPU wants (metric_fn, args); GPU/CPU wants the resolved
			# eval_metric_ops dict.
			tpu_eval_metrics = (ebm_noise_eval_metric, 
								[
								true_ebm_dist_dict['logits'], 
								noise_dist_dict['true_logits'], 
								fake_ebm_dist_dict['logits'], 
								noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								noise_dist_dict["true_seq_logits"]
								])
			gpu_eval_metrics = ebm_noise_eval_metric(
								true_ebm_dist_dict['logits'], 
								noise_dist_dict['true_logits'], 
								fake_ebm_dist_dict['logits'], 
								noise_dist_dict['fake_logits'],
								features['input_ori_ids'],
								tf.cast(features['input_mask'], tf.float32),
								noise_dist_dict["true_seq_logits"]
								)

			if kargs.get('use_tpu', False):
				estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							  mode=mode,
							  loss=loss,
							  eval_metrics=tpu_eval_metrics,
							  scaffold_fn=scaffold_fn)
			else:
				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=gpu_eval_metrics)

			return estimator_spec
		else:
			raise NotImplementedError()
Example #15
0
	def model_fn(features, labels, mode):
		"""Build the graph for a siamese BERT sentence-pair classifier.

		Encodes each entry of `target` with a shared BERT encoder, aligns
		and pools the two representations, classifies the concatenated
		pair features, and returns sess-dict or EstimatorSpec outputs
		depending on `output_type`.

		Closure over the enclosing scope: model_config, model_io_config,
		opt_config, num_labels, target, model_reuse, load_pretrained,
		init_checkpoint, exclude_scope, not_storage_params, label_lst,
		output_type, kargs, plus helpers (bert_encoding,
		alignment_aggerate, ave_max_pooling, multihead_pooling).

		Raises:
			NotImplementedError: for any mode other than TRAIN/PREDICT/EVAL.
		"""

		# When the language model is frozen (fix_lm), new variables live in
		# a separate "<scope>_finetuning" variable scope.
		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		# Dropout is active only during training.
		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		label_ids = features["label_ids"]

		# Encode each side; weights are reused after the first encoder.
		# Only model_lst[0] and model_lst[1] are consumed below, so the
		# code effectively assumes exactly two targets.
		model_lst = []
		for index, name in enumerate(target):
			if index > 0:
				reuse = True
			else:
				reuse = model_reuse
			# NOTE(review): `dropout_rate` is not defined in this function
			# (only `dropout_prob` above) — presumably captured from the
			# enclosing scope, or a typo for dropout_prob; confirm.
			model_lst.append(bert_encoding(model_config, features, labels, 
												mode, name,
												scope, dropout_rate, 
												reuse=reuse))

		[input_mask_a, repres_a] = model_lst[0]
		[input_mask_b, repres_b] = model_lst[1]

		# Cross-attention style alignment between the two sequences.
		output_a, output_b = alignment_aggerate(model_config, 
				repres_a, repres_b, 
				input_mask_a, 
				input_mask_b, 
				scope, 
				reuse=model_reuse)

		# NOTE(review): any pooling value other than these two leaves
		# pooling_fn undefined (NameError) — confirm config is restricted.
		if model_config.pooling == "ave_max_pooling":
			pooling_fn = ave_max_pooling
		elif model_config.pooling == "multihead_pooling":
			pooling_fn = multihead_pooling

		repres_a = pooling_fn(model_config, output_a, 
					input_mask_a, 
					scope, 
					dropout_prob, 
					reuse=model_reuse)

		repres_b = pooling_fn(model_config, output_b,
					input_mask_b,
					scope, 
					dropout_prob,
					reuse=True)

		# Matching features: [a, b, |a-b|, a*b] concatenated.
		pair_repres = tf.concat([repres_a, repres_b,
					tf.abs(repres_a-repres_b),
					repres_b*repres_a], axis=-1)

		with tf.variable_scope(scope, reuse=model_reuse):
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											pair_repres,
											num_labels,
											label_ids,
											dropout_prob)

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		if load_pretrained:
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			# Run collected UPDATE_OPS before the optimizer step.
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			# NOTE(review): the saver/hook setup and the returns below all
			# sit inside this control_dependencies block; harmless at graph
			# construction time, but differs from the sibling model_fns.
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps)

				model_io_fn.set_saver()

				# Only the chief (task_index == 0) installs checkpoint hooks.
				if kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
										kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks)

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)

				if output_type == "sess":
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec
		elif mode == tf.estimator.ModeKeys.PREDICT:
			print(logits.get_shape(), "===logits shape===")
			# Hard prediction plus its softmax confidence.
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			prob = tf.nn.softmax(logits)
			max_prob = tf.reduce_max(prob, axis=-1)
			
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'pred_label':pred_label,
												"max_prob":max_prob
								  	},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'pred_label':pred_label,
														"max_prob":max_prob
													}
												)
								  	}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				# (unused below, kept for parity with sibling metric_fns)
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				# (unused below)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"loss": sentence_mean_loss,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)
			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #16
0
    def model_fn(features, labels, mode, params):

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        print(train_op_type, "===train op type===", gen_disc_type,
              "===generator loss type===")
        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            # mask_method="all_mask",
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            **kargs)

        tf.logging.info("****** true sampled_ids of discriminator *******")
        true_distriminator_features = {}
        true_distriminator_features['input_ids'] = generator_dict[
            'sampled_input_ids']
        true_distriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        true_distriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        true_distriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        true_distriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        true_distriminator_features['ori_input_ids'] = generator_dict[
            'sampled_input_ids']

        true_distriminator_dict = discriminator_fn(true_distriminator_features,
                                                   labels, mode, params)

        fake_discriminator_features = {}
        if kargs.get('minmax_mode', 'corrupted') == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif kargs.get('minmax_mode', 'corrupted') == 'masked':
            fake_discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditioanl sampled_ids *******")
        fake_discriminator_features['input_ids'] = generator_dict[
            'sampled_ids']
        fake_discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        fake_discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        fake_discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        fake_discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        fake_discriminator_features['ori_input_ids'] = generator_dict[
            'sampled_ids']

        fake_discriminator_dict = discriminator_fn(fake_discriminator_features,
                                                   labels, mode, params)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        output_dict = get_losses(true_distriminator_dict["logits"],
                                 fake_discriminator_dict["logits"],
                                 use_tpu=use_tpu,
                                 gan_type=kargs.get('gan_type', "JS"))

        discriminator_dict = {}
        discriminator_dict['gen_loss'] = output_dict['gen_loss']
        discriminator_dict['disc_loss'] = output_dict['disc_loss']
        discriminator_dict['tvars'] = fake_discriminator_dict['tvars']
        discriminator_dict['fake_logits'] = fake_discriminator_dict['logits']
        discriminator_dict['true_logits'] = true_distriminator_dict['logits']

        model_io_fn = model_io.ModelIO(model_io_config)

        loss = discriminator_dict['disc_loss']
        tvars = []
        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(discriminator_dict)

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_true_loss",
                                  discriminator_dict['disc_loss'])
                tf.summary.scalar("discriminator_fake_loss",
                                  discriminator_dict['gen_loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    discriminator_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 discriminator_dict):
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        discriminator_dict)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels',
                                       None), discriminator_dict
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels',
                                       None), discriminator_dict)
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict)
                tpu_eval_metrics = (discriminator_metric_eval,
                                    [discriminator_dict])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #17
0
    def model_fn(features, labels, mode, params):
        """Estimator model_fn for joint generator/discriminator training.

        Builds the generator, feeds its sampled outputs into the
        discriminator, combines the two losses, and returns a TPU or GPU
        EstimatorSpec for TRAIN/EVAL. PREDICT is not implemented.

        Relies on closure variables (model_config_dict, num_labels_dict,
        init_checkpoint_dict, load_pretrained_dict, exclude_scope_dict,
        not_storage_params_dict, target_dict, model_io_config, opt_config,
        kargs, ...) defined in the enclosing builder function.
        """

        train_op_type = kargs.get('train_op_type', 'joint')
        gen_disc_type = kargs.get('gen_disc_type', 'all_disc')
        mask_method = kargs.get('mask_method', 'only_mask')
        use_tpu = 1 if kargs.get('use_tpu', False) else 0
        print(train_op_type, "===train op type===", gen_disc_type,
              "===generator loss type===")
        if mask_method == 'only_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)
        elif mask_method == 'all_mask':
            tf.logging.info(
                "****** generator token generation mask type:%s with all token *******",
                mask_method)
        else:
            # Unknown value: fall back to the default masking strategy.
            mask_method = 'only_mask'
            tf.logging.info(
                "****** generator token generation mask type:%s with only masked token *******",
                mask_method)

        # 'grl' (gradient reversal) trains jointly with flipped gradients;
        # 'minmax' alternates without flipping. Anything else falls back to grl.
        if kargs.get('optimization_type', 'grl') == 'grl':
            if_flip_grad = True
            train_op_type = 'joint'
        elif kargs.get('optimization_type', 'grl') == 'minmax':
            if_flip_grad = False
        else:
            if_flip_grad = True
            train_op_type = 'joint'
        generator_fn = generator(
            model_config_dict['generator'],
            num_labels_dict['generator'],
            init_checkpoint_dict['generator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['generator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('generator', ""),
            not_storage_params=not_storage_params_dict.get('generator', []),
            target=target_dict['generator'],
            if_flip_grad=if_flip_grad,
            # mask_method='only_mask',
            **kargs)

        tf.logging.info("****** train_op_type:%s *******", train_op_type)
        tf.logging.info("****** optimization_type:%s *******",
                        kargs.get('optimization_type', 'grl'))
        generator_dict = generator_fn(features, labels, mode, params)

        discriminator_fn = discriminator_generator(
            model_config_dict['discriminator'],
            num_labels_dict['discriminator'],
            init_checkpoint_dict['discriminator'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['discriminator'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('discriminator', ""),
            not_storage_params=not_storage_params_dict.get(
                'discriminator', []),
            target=target_dict['discriminator'],
            loss='cross_entropy',
            **kargs)

        # Build discriminator input features from the generator's samples.
        discriminator_features = {}
        # minmax_mode in ['masked', 'corrupted']
        minmax_mode = kargs.get('minmax_mode', 'corrupted')
        tf.logging.info("****** minmax mode for discriminator: %s *******",
                        minmax_mode)
        if minmax_mode == 'corrupted':
            tf.logging.info("****** gumbel 3-D sampled_ids *******")
        elif minmax_mode == 'masked':
            # 'masked' mode additionally exposes the raw generated ids and
            # the binary mask of positions that were actually sampled.
            discriminator_features['ori_sampled_ids'] = generator_dict[
                'output_ids']
            discriminator_features['sampled_binary_mask'] = generator_dict[
                'sampled_binary_mask']
            tf.logging.info("****** conditional sampled_ids *******")
        discriminator_features['input_ids'] = generator_dict['sampled_ids']
        discriminator_features['input_mask'] = generator_dict[
            'sampled_input_mask']
        discriminator_features['segment_ids'] = generator_dict[
            'sampled_segment_ids']
        discriminator_features['input_ori_ids'] = generator_dict[
            'sampled_input_ids']
        discriminator_features['next_sentence_labels'] = features[
            'next_sentence_labels']
        # NOTE(review): 'ori_input_ids' is set to sampled_ids (same as
        # 'input_ids'), not sampled_input_ids — confirm this is intended,
        # since modified_loss compares input_ori_ids against ori_input_ids.
        discriminator_features['ori_input_ids'] = generator_dict['sampled_ids']

        discriminator_dict = discriminator_fn(discriminator_features, labels,
                                              mode, params)

        [disc_loss, disc_logits, disc_per_example_loss
         ] = optimal_discriminator(model_config_dict['discriminator'],
                                   generator_dict,
                                   features,
                                   discriminator_dict,
                                   discriminator_features,
                                   use_tpu=use_tpu)

        [
            equal_per_example_loss, equal_loss_all, equal_loss_self,
            not_equal_per_example_loss, not_equal_loss_all, not_equal_loss_self
        ] = modified_loss(disc_per_example_loss,
                          disc_logits,
                          discriminator_features['input_ori_ids'],
                          discriminator_features['ori_input_ids'],
                          discriminator_features['input_mask'],
                          sampled_binary_mask=discriminator_features.get(
                              'sampled_binary_mask', None),
                          **kargs)
        output_dict = {}
        output_dict['logits'] = disc_logits
        output_dict['per_example_loss'] = disc_per_example_loss
        # 0.0 * loss keeps the discriminator graph connected without
        # contributing to the objective.
        output_dict['loss'] = disc_loss + 0.0 * discriminator_dict["loss"]
        # Bug fix: the original assignments ended with trailing commas,
        # which wrapped each loss tensor in a 1-tuple.
        output_dict["equal_per_example_loss"] = equal_per_example_loss
        output_dict["equal_loss_all"] = equal_loss_all
        output_dict["equal_loss_self"] = equal_loss_self
        output_dict["not_equal_per_example_loss"] = not_equal_per_example_loss
        output_dict["not_equal_loss_all"] = not_equal_loss_all
        output_dict["not_equal_loss_self"] = not_equal_loss_self
        output_dict['tvars'] = discriminator_dict['tvars']

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = []

        loss = kargs.get('dis_loss', 1.0) * output_dict['loss']

        tvars.extend(discriminator_dict['tvars'])

        if kargs.get('joint_train', '1') == '1':
            tf.logging.info(
                "****** joint generator and discriminator training *******")
            tvars.extend(generator_dict['tvars'])
            loss += generator_dict['loss']
        tvars = list(set(tvars))

        # Collect per-submodel checkpoint-restoration specs.
        var_checkpoint_dict_list = []
        for key in init_checkpoint_dict:
            if load_pretrained_dict[key] == "yes":
                if key == 'generator':
                    tmp = {
                        "tvars":
                        generator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['generator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['generator'].get(
                            'restore_var_name', [])
                    }
                    # When generator/discriminator share weights, restore
                    # without scope exclusion.
                    if kargs.get("sharing_mode", "none") != "none":
                        tmp['exclude_scope'] = ''
                    var_checkpoint_dict_list.append(tmp)
                elif key == 'discriminator':
                    tmp = {
                        "tvars":
                        discriminator_dict['tvars'],
                        "init_checkpoint":
                        init_checkpoint_dict['discriminator'],
                        "exclude_scope":
                        exclude_scope_dict[key],
                        "restore_var_name":
                        model_config_dict['discriminator'].get(
                            'restore_var_name', [])
                    }
                    var_checkpoint_dict_list.append(tmp)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if len(var_checkpoint_dict_list) >= 1:
            scaffold_fn = model_io_fn.load_multi_pretrained(
                var_checkpoint_dict_list, use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            # Summaries are GPU-only; TPU does not support tf.summary here.
            if not kargs.get('use_tpu', False):
                metric_dict = discriminator_metric_train(
                    output_dict['per_example_loss'], output_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])

                for key in metric_dict:
                    tf.summary.scalar(key, metric_dict[key])
                tf.summary.scalar("generator_loss", generator_dict['loss'])
                tf.summary.scalar("discriminator_loss",
                                  discriminator_dict['loss'])

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0

            model_io_fn.print_params(tvars, string=", trainable params")

            train_op = get_train_op(generator_dict,
                                    output_dict,
                                    optimizer_fn,
                                    opt_config,
                                    model_config_dict['generator'],
                                    model_config_dict['discriminator'],
                                    use_tpu=use_tpu,
                                    train_op_type=train_op_type,
                                    gen_disc_type=gen_disc_type)

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn
                    # training_hooks=[logging_hook]
                )
            else:
                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            # NOTE: default here is '0' while the training branch above
            # defaults joint_train to '1' — eval metrics only include the
            # generator when joint_train is explicitly '1'.
            if kargs.get('joint_train', '0') == '1':

                def joint_metric(masked_lm_example_loss, masked_lm_log_probs,
                                 masked_lm_ids, masked_lm_weights,
                                 next_sentence_example_loss,
                                 next_sentence_log_probs, next_sentence_labels,
                                 per_example_loss, logits, input_ori_ids,
                                 input_ids, input_mask):
                    # Merge generator and discriminator eval metrics.
                    generator_metric = generator_metric_fn_eval(
                        masked_lm_example_loss, masked_lm_log_probs,
                        masked_lm_ids, masked_lm_weights,
                        next_sentence_example_loss, next_sentence_log_probs,
                        next_sentence_labels)
                    discriminator_metric = discriminator_metric_eval(
                        per_example_loss, logits, input_ori_ids, input_ids,
                        input_mask)
                    generator_metric.update(discriminator_metric)
                    return generator_metric

                tpu_eval_metrics = (joint_metric, [
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])
                gpu_eval_metrics = joint_metric(
                    generator_dict['masked_lm_example_loss'],
                    generator_dict['masked_lm_log_probs'],
                    generator_dict['masked_lm_ids'],
                    generator_dict['masked_lm_weights'],
                    generator_dict.get('next_sentence_example_loss', None),
                    generator_dict.get('next_sentence_log_probs', None),
                    generator_dict.get('next_sentence_labels', None),
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
            else:
                gpu_eval_metrics = discriminator_metric_eval(
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask'])
                tpu_eval_metrics = (discriminator_metric_eval, [
                    discriminator_dict['per_example_loss'],
                    discriminator_dict['logits'],
                    generator_dict['sampled_input_ids'],
                    generator_dict['sampled_ids'],
                    generator_dict['sampled_input_mask']
                ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #18
0
    def model_fn(features, labels, mode, params):
        """Estimator model_fn for a (causal or masked) LM.

        TRAIN/EVAL compute a sequence cross-entropy loss (left-to-right when
        `is_casual` is set in model_config, masked-LM otherwise); PREDICT
        either samples sequences or scores given inputs (perplexity),
        depending on kargs['predict_type'].

        Relies on closure variables (model_config, model_io_config,
        opt_config, target, kargs, hmm_tran_prob_list, mask_prior,
        load_pretrained, init_checkpoint, exclude_scope,
        not_storage_params, ...) from the enclosing builder function.
        """

        model_api = model_zoo(model_config)

        # Optionally redirect the input ids to a target-specific feature
        # column (e.g. 'input_ids_a').
        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]
        # Non-PAD positions form the attention/input mask.
        sequence_mask = tf.cast(
            tf.not_equal(features['input_ori_ids'], kargs.get('[PAD]', 0)),
            tf.int32)
        features['input_mask'] = sequence_mask

        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        if 'input_ori_ids' in features:
            seq_features['input_ids'] = features["input_ori_ids"]
        else:
            features['input_ori_ids'] = seq_features['input_ids']

        # loss_mask flags examples that contain at least one non-zero token.
        not_equal = tf.cast(
            tf.not_equal(features["input_ori_ids"],
                         tf.zeros_like(features["input_ori_ids"])), tf.int32)
        not_equal = tf.reduce_sum(not_equal, axis=-1)
        loss_mask = tf.cast(tf.not_equal(not_equal, tf.zeros_like(not_equal)),
                            tf.float32)

        if not kargs.get('use_tpu', False):
            tf.summary.scalar('loss_mask', tf.reduce_sum(loss_mask))

        casual_flag = model_config.get('is_casual', True)
        # Bug fix: the original message had no %s placeholder, so the flag
        # value was never formatted into the log line.
        tf.logging.info("***** is casual flag: %s *****", str(casual_flag))

        if not casual_flag:
            # Masked-LM path: corrupt the inputs with HMM-driven masking.
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'], [
                    tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                    for hmm_tran_prob in hmm_tran_prob_list
                ],
                mask_probability=0.02,
                replace_probability=0.01,
                original_probability=0.01,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)
            tf.logging.info("***** apply random sampling *****")
            seq_features['input_ids'] = output_ids

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          "",
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # if mode == tf.estimator.ModeKeys.TRAIN:
        # Choose the loss mask: left2right masks PAD; seq2seq masks by
        # segment ids (only the target segment contributes to the loss).
        if kargs.get('mask_type', 'left2right') == 'left2right':
            tf.logging.info("***** using left2right mask and loss *****")
            sequence_mask = tf.to_float(
                tf.not_equal(features['input_ori_ids'][:, 1:],
                             kargs.get('[PAD]', 0)))
        elif kargs.get('mask_type', 'left2right') == 'seq2seq':
            tf.logging.info("***** using seq2seq mask and loss *****")
            sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
            if not kargs.get('use_tpu', False):
                tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask))

        # batch x seq_length
        if casual_flag:
            print(model.get_sequence_output_logits().get_shape(),
                  "===logits shape===")
            # Next-token prediction: logits at t predict token t+1.
            seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ori_ids'][:, 1:],
                logits=model.get_sequence_output_logits()[:, :-1])

            per_example_loss = tf.reduce_sum(
                seq_loss * sequence_mask,
                axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
            loss = tf.reduce_mean(per_example_loss)

            # Bidirectional CNN variants add a right-to-left LM loss.
            if model_config.get("cnn_type",
                                "dgcnn") in ['bi_dgcnn', 'bi_light_dgcnn']:
                seq_backward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ori_ids'][:, :-1],
                    logits=model.get_sequence_backward_output_logits()[:, 1:])

                per_backward_example_loss = tf.reduce_sum(
                    seq_backward_loss * sequence_mask,
                    axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
                backward_loss = tf.reduce_mean(per_backward_example_loss)
                loss += backward_loss
                tf.logging.info("***** using backward loss *****")
        else:
            # Masked-LM loss over the HMM-sampled positions.
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = pretrain.seq_mask_masked_lm_output(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 seq_features['input_mask'],
                 seq_features['input_ori_ids'],
                 seq_features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            loss = masked_lm_loss
            tf.logging.info("***** using masked lm loss *****")
        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using tpu with tpu-captiable optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using gpu with gpu-captiable optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            # Run batch-norm style update ops before the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                # train_metric_dict = train_metric(features['input_ori_ids'],
                # 								model.get_sequence_output_logits(),
                # 								seq_features,
                # 								**kargs)

                # if not kargs.get('use_tpu', False):
                # 	for key in train_metric_dict:
                # 		tf.summary.scalar(key, train_metric_dict[key])
                # 	tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)
                # 	tf.logging.info("***** logging metric *****")
                # 	tf.summary.scalar("causal_attenion_mask_length", tf.reduce_sum(sequence_mask))
                # tf.summary.scalar("bi_attenion_mask_length", tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits(),
                                           seq_features, **kargs)
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits(), seq_features,
                kargs.get('mask_type', 'left2right')
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            if kargs.get('predict_type',
                         'sample_sequence') == 'sample_sequence':
                # Autoregressive sampling from the LM.
                results = bert_seq_sample_utils.sample_sequence(
                    model_api,
                    model_config,
                    mode,
                    features,
                    target="",
                    start_token=kargs.get("start_token_id", 101),
                    batch_size=None,
                    context=features.get("context", None),
                    temperature=kargs.get("sample_temp", 1.0),
                    n_samples=kargs.get("n_samples", 1),
                    top_k=0,
                    end_token=kargs.get("end_token_id", 102),
                    greedy_or_sample="greedy",
                    gumbel_temp=0.01,
                    estimator="stop_gradient",
                    back_prop=True,
                    swap_memory=True,
                    seq_type=kargs.get("seq_type", "seq2seq"),
                    mask_type=kargs.get("mask_type", "seq2seq"),
                    attention_type=kargs.get('attention_type',
                                             'normal_attention'))
                # stop_gradient output:
                # samples, mask_sequence, presents, logits, final

                sampled_token = results['samples']
                sampled_token_logits = results['logits']
                mask_sequence = results['mask_sequence']

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': sampled_token,
                        "logits": sampled_token_logits,
                        "mask_sequence": mask_sequence
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            sampled_token,
                            "logits":
                            sampled_token_logits,
                            "mask_sequence":
                            mask_sequence
                        })
                    })

                return estimator_spec

            elif kargs.get('predict_type',
                           'sample_sequence') == 'infer_inputs':
                # Score the provided inputs: per-example perplexity.

                sequence_mask = tf.to_float(
                    tf.not_equal(features['input_ids'][:, 1:],
                                 kargs.get('[PAD]', 0)))

                if kargs.get('mask_type', 'left2right') == 'left2right':
                    tf.logging.info(
                        "***** using left2right mask and loss *****")
                    sequence_mask = tf.to_float(
                        tf.not_equal(features['input_ori_ids'][:, 1:],
                                     kargs.get('[PAD]', 0)))
                elif kargs.get('mask_type', 'left2right') == 'seq2seq':
                    tf.logging.info("***** using seq2seq mask and loss *****")
                    sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
                    if not kargs.get('use_tpu', False):
                        tf.summary.scalar("loss mask",
                                          tf.reduce_mean(sequence_mask))

                output_logits = model.get_sequence_output_logits()[:, :-1]
                # output_logits = tf.nn.log_softmax(output_logits, axis=-1)

                output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ids'][:, 1:], logits=output_logits)

                per_example_perplexity = tf.reduce_sum(output_id_logits *
                                                       sequence_mask,
                                                       axis=-1)  # batch
                per_example_perplexity /= tf.reduce_sum(sequence_mask,
                                                        axis=-1)  # batch

                perplexity = tf.exp(per_example_perplexity)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': features['input_ids'][:, 1:],
                        "logits": output_id_logits,
                        'perplexity': perplexity,
                        # "all_logits":output_logits
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            features['input_ids'][:, 1:],
                            "logits":
                            output_id_logits,
                            'perplexity':
                            perplexity,
                            # "all_logits":output_logits
                        })
                    })

                return estimator_spec
            # NOTE(review): an unrecognized predict_type falls through and
            # returns None implicitly — confirm whether this should raise.
        else:
            raise NotImplementedError()
Example #19
0
	def model_fn(features, labels, mode):
		"""ALBERT sentence-classification `model_fn` for `tf.estimator`.

		Closes over builder-scope configuration (`model_config`, `num_labels`,
		`model_io_config`, `opt_config`, `load_pretrained`, `output_type`,
		`kargs`, ...) — none of it is visible in this block, so see the
		enclosing builder for the exact contract.

		Args:
			features: dict of input tensors; must contain "label_ids".
			labels: unused directly (labels are read from `features`).
			mode: a `tf.estimator.ModeKeys` value.

		Returns:
			TRAIN: a `tf.estimator.EstimatorSpec` with loss and train_op.
			EVAL: either a raw tensor dict (`output_type == "sess"`) or an
			EstimatorSpec (`output_type == "estimator"`); any other
			`output_type` falls through and returns None — presumably never
			used that way, TODO confirm.

		Raises:
			NotImplementedError: for modes other than TRAIN/EVAL (e.g. PREDICT).
		"""

		# model = bert_encoder(model_config, features, labels,
		# 					mode, target, reuse=model_reuse)

		# Encode the input sequence with ALBERT (shared-parameter BERT variant).
		model = albert_encoder(model_config, features, labels,
							mode, target, reuse=model_reuse)

		label_ids = features["label_ids"]

		# Dropout is active only while training; eval/predict run deterministically.
		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		# When the LM body is frozen, the classifier head lives in a separate
		# "*_finetuning" scope so only head variables are created/trained there.
		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		with tf.variable_scope(scope, reuse=model_reuse):
			(loss, 
				per_example_loss, 
				logits) = classifier.classifier(model_config,
											model.get_pooled_output(),
											num_labels,
											label_ids,
											dropout_prob)

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		# NOTE: `load_pretrained` is compared against the *string* "yes" here,
		# while other model_fns in this file treat it as a boolean — confirm
		# which convention the callers use.
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			
			# Batch-norm style update ops must run before the optimizer step.
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			print("==update_ops==", update_ops)

			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)
			# train_op, hooks = model_io_fn.get_ema_hooks(train_op, 
			# 					tvars,
			# 					kargs.get('params_moving_average_decay', 0.99),
			# 					scope, mode, 
			# 					first_stage_steps=opt_config.num_warmup_steps,
			# 					two_stage=True)

			model_io_fn.set_saver()

			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op)
				
			return estimator_spec
		elif mode == tf.estimator.ModeKeys.EVAL:
			
			# _, hooks = model_io_fn.get_ema_hooks(None,
			# 							None,
			# 							kargs.get('params_moving_average_decay', 0.99), 
			# 							scope, mode)

			# EMA evaluation hooks are disabled (see commented block above).
			hooks = None

			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				# NOTE: sentence_log_probs / sentence_labels / sentence_mean_loss
				# are computed but never returned; tf.metrics.mean still creates
				# metric variables as a side effect.
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)

			# `hooks` is always None above, so this is currently always [].
			eval_hooks = [hooks] if hooks else []
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops,
								evaluation_hooks=eval_hooks
								)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss)
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #20
0
	def model_fn(features, labels, mode):
		"""Two-tower (sentence-pair) classification `model_fn`.

		Encodes the two comma-separated targets with a weight-shared encoder,
		pools each tower, and classifies the pair with either an order
		classifier or a siamese interaction classifier. Closes over
		builder-scope configuration (`model_config`, `num_labels`, `target`,
		`output_type`, `kargs`, ...).

		Fixes over the previous revision:
		  * PREDICT branch referenced `max_prob` without ever defining it,
		    raising NameError at graph-construction time; it is now computed
		    as the max softmax probability (matching the sibling model_fns).
		  * An unrecognized `model_config.classifier` value now fails fast
		    with ValueError instead of a later NameError on `loss`.

		Args:
			features: dict of input tensors; must contain "label_ids".
			labels: unused directly (labels are read from `features`).
			mode: a `tf.estimator.ModeKeys` value.

		Returns:
			Mode- and `output_type`-dependent: an EstimatorSpec, or a raw
			tensor dict when `output_type == "sess"`.

		Raises:
			ValueError: unknown classifier type in `model_config`.
			NotImplementedError: unsupported estimator mode.
		"""

		model_api = model_zoo(model_config)

		model_lst = []

		# Exactly two comma-separated targets: one per tower.
		assert len(target.split(",")) == 2
		target_name_lst = target.split(",")
		print(target_name_lst)
		for index, name in enumerate(target_name_lst):
			# The second tower always reuses the first tower's variables.
			if index > 0:
				reuse = True
			else:
				reuse = model_reuse
			model_lst.append(model_api(model_config, features, labels,
							mode, name, reuse=reuse))

		label_ids = features["label_ids"]

		# Dropout only during training.
		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		# Frozen-LM fine-tuning keeps the head in a separate scope.
		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		with tf.variable_scope(scope, reuse=model_reuse):
			seq_output_lst = [model.get_pooled_output() for model in model_lst]
			classifier_type = model_config.get("classifier", "order_classifier")
			if classifier_type == "order_classifier":
				[loss, 
					per_example_loss, 
					logits] = classifier.order_classifier(
								model_config, seq_output_lst, 
								num_labels, label_ids,
								dropout_prob, ratio_weight=None)
			elif classifier_type == "siamese_interaction_classifier":
				[loss, 
					per_example_loss, 
					logits] = classifier.siamese_classifier(
								model_config, seq_output_lst, 
								num_labels, label_ids,
								dropout_prob, ratio_weight=None)
			else:
				# Fail fast: otherwise `loss`/`logits` would be undefined below.
				raise ValueError(
					"unknown classifier type: {}".format(classifier_type))

		model_io_fn = model_io.ModelIO(model_io_config)

		params_size = model_io_fn.count_params(model_config.scope)
		print("==total params==", params_size)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		print(tvars)
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			# Run collected update ops (e.g. batch norm) before the train step.
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):

				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				# Only the chief (task_index == 0) installs checkpoint hooks,
				# and only when no run_config-managed checkpointing exists.
				if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)
				if output_type == "sess":
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec

		elif mode == tf.estimator.ModeKeys.PREDICT:
			print(logits.get_shape(), "===logits shape===")
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			# BUGFIX: `max_prob` was previously used without being defined
			# (NameError at graph build). Compute it as in the sibling
			# single-tower model_fn: max softmax probability per example.
			prob = tf.nn.softmax(logits)
			max_prob = tf.reduce_max(prob, axis=-1)

			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'pred_label':pred_label,
												"max_prob":max_prob
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'pred_label':pred_label,
														"max_prob":max_prob
													}
												)
									}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss),
							# Mean of the two pooled tower outputs as a pair feature.
							"feature":(seq_output_lst[0]+seq_output_lst[1])/2
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example #21
0
    def model_fn(features, labels, mode):
        """Multi-task `model_fn`: builds one classification head per task.

        Iterates `task_type_dict`, builds (or reuses) one encoder per model
        type, wraps each "cls_task" with `cls_model_fn`, and accumulates the
        per-task losses into `total_loss`. Closes over builder-scope state
        (`task_type_dict`, `model_config_dict`, `num_labels_dict`,
        `model_type_lst`, `output_type`, `kargs`, ...).

        Fix over the previous revision: the EVAL "estimator" branch iterated
        `for key in logits:` but no `logits` name exists in this scope (only
        `logits_dict`), which raised NameError; it now iterates `logits_dict`.

        Args:
            features: dict of input tensors shared by all tasks.
            labels: unused directly.
            mode: a `tf.estimator.ModeKeys` value.

        Returns:
            Mode- and `output_type`-dependent: an EstimatorSpec or a raw
            tensor dict ("sess" mode).

        Raises:
            NotImplementedError: unsupported estimator mode.
        """

        train_ops = []
        train_hooks = []
        logits_dict = {}      # task_type -> logits tensor
        losses_dict = {}      # task_type -> scalar task loss
        features_dict = {}
        tvars = []            # union of trainable vars across tasks
        task_num_dict = {}    # task_type -> number of examples for that task

        total_loss = tf.constant(0.0)

        task_num = 0

        encoder = {}          # model_type -> shared encoder instance
        hook_dict = {}        # tensors surfaced via LoggingTensorHook/summaries

        print(task_type_dict.keys(), "==task type dict==")
        num_task = len(task_type_dict)

        for index, task_type in enumerate(task_type_dict.keys()):
            # Reuse encoder variables if this model type was already built.
            if model_config_dict[task_type].model_type in model_type_lst:
                reuse = True
            else:
                reuse = None
                model_type_lst.append(model_config_dict[task_type].model_type)
            if task_type_dict[task_type] == "cls_task":

                if model_config_dict[task_type].model_type not in encoder:
                    model_api = model_zoo(model_config_dict[task_type])

                    model = model_api(model_config_dict[task_type],
                                      features,
                                      labels,
                                      mode,
                                      target_dict[task_type],
                                      reuse=reuse)
                    encoder[model_config_dict[task_type].model_type] = model

                print(encoder, "==encode==")

                task_model_fn = cls_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    **kargs)
                print("==SUCCEEDED IN LODING==", task_type)

                result_dict = task_model_fn(features, labels, mode)
                logits_dict[task_type] = result_dict["logits"]
                losses_dict[task_type] = result_dict["loss"]  # task loss
                # Surface any per-task diagnostics the task model exposes.
                for key in [
                        "masked_lm_loss", "task_loss", "acc", "task_acc",
                        "masked_lm_acc"
                ]:
                    name = "{}_{}".format(task_type, key)
                    if name in result_dict:
                        hook_dict[name] = result_dict[name]
                hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
                total_loss += result_dict["loss"]

                if mode == tf.estimator.ModeKeys.TRAIN:
                    tvars.extend(result_dict["tvars"])
                    task_num += result_dict["task_num"]
                    task_num_dict[task_type] = result_dict["task_num"]
                elif mode == tf.estimator.ModeKeys.EVAL:
                    features[task_type] = result_dict["feature"]
            else:
                # Non-classification tasks are not handled by this model_fn.
                continue

        hook_dict["total_loss"] = total_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn = model_io.ModelIO(model_io_config)

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(list(set(tvars)),
                                     string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)

            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    total_loss, list(set(tvars)), opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver(optimizer_fn.opt)

                # Chief worker installs checkpoint hooks when a run_config
                # is supplied; other workers get no hooks.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                elif kargs.get("task_index", 1) == 1:
                    training_hooks = []
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

            if output_type == "sess":
                return {
                    "train": {
                        "total_loss": total_loss,
                        "loss": losses_dict,
                        "logits": logits_dict,
                        "train_op": train_op,
                        "task_num_dict": task_num_dict
                    },
                    "hooks": train_hooks
                }
            elif output_type == "estimator":

                hook_dict['learning_rate'] = optimizer_fn.learning_rate
                logging_hook = tf.train.LoggingTensorHook(hook_dict,
                                                          every_n_iter=100)
                training_hooks.append(logging_hook)

                print("==hook_dict==")

                print(hook_dict)

                for key in hook_dict:
                    tf.summary.scalar(key, hook_dict[key])
                    # Also log total_loss minus each task loss to show how
                    # much the other tasks contribute.
                    for index, task_type in enumerate(task_type_dict.keys()):
                        tmp = "{}_loss".format(task_type)
                        if tmp == key:
                            tf.summary.scalar(
                                "loss_gap_{}".format(task_type),
                                hook_dict["total_loss"] - hook_dict[key])
                for key in task_num_dict:
                    tf.summary.scalar(key + "_task_num", task_num_dict[key])

                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=total_loss,
                                                            train_op=train_op)
                # training_hooks=training_hooks)
                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:  # eval execute for each class solo

            def metric_fn(logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "logits": logits_dict,
                        "total_loss": total_loss,
                        "feature": features,
                        "loss": losses_dict
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = {}
                # BUGFIX: this loop previously iterated an undefined name
                # `logits`; the per-task logits live in `logits_dict`.
                # NOTE(review): `features_task_dict` is not defined in this
                # function — it is presumably captured from the builder scope;
                # confirm against the enclosing definition.
                for key in logits_dict:
                    eval_dict = metric_fn(logits_dict[key],
                                          features_task_dict[key]["label_ids"])
                    for sub_key in eval_dict.keys():
                        eval_key = "{}_{}".format(key, sub_key)
                        eval_metric_ops[eval_key] = eval_dict[sub_key]
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss / task_num,
                    eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #22
0
	def model_fn(features, labels, mode, params):
		"""Student/teacher distillation `model_fn` (teacher path disabled).

		Builds the student model via `st_model_fn` and trains it on its own
		loss scaled by `distillation_config['ce_loss']`. All teacher-model and
		distillation-loss machinery (KL on logits, attention/hidden matching,
		embedding distillation) is commented out, so `distilled_loss` stays 0
		and `total_loss` equals the weighted student loss. Closes over
		builder-scope state (`model_config_dict`, `distillation_config`,
		`opt_config`, `kargs`, ...).

		Args:
			features: dict of input tensors for the student model.
			labels: forwarded to the student model_fn.
			mode: a `tf.estimator.ModeKeys` value.
			params: forwarded to the student model_fn.

		Returns:
			An EstimatorSpec for TRAIN or EVAL.

		Raises:
			NotImplementedError: for any other mode (e.g. PREDICT).
		"""

		original_loss = tf.constant(0.0)
		distilled_loss = tf.constant(0.0)

		# Build the student model and run it to get loss/logits/masks.
		st_model = st_model_fn(model_config_dict['student'],
		 			num_labels_dict['student'],
					init_checkpoint_dict['student'],
					model_reuse=None,
					load_pretrained=load_pretrained_dict['student'],
					model_io_config=model_io_config,
					opt_config=opt_config,
					exclude_scope=exclude_scope_dict.get('student', ""),
					not_storage_params=not_storage_params_dict.get('student', []),
					target=target_dict['student'],
					**kargs)
		st_dict = st_model(features, labels, mode, params)

		# ta_model = ta_model_fn(model_config_dict['teacher'],
		#  			num_labels_dict['teacher'],
		# 			init_checkpoint_dict['teacher'],
		# 			model_reuse=None,
		# 			load_pretrained=load_pretrained_dict['teacher'],
		# 			model_io_config=model_io_config,
		# 			opt_config=opt_config,
		# 			exclude_scope=exclude_scope_dict.get('teacher', ""),
		# 			not_storage_params=not_storage_params_dict.get('teacher', []),
		# 			target=target_dict['teacher'],
		# 			**kargs)
		# ta_features = {}
		# for key in features:
		# 	ta_features[key] = features[key]
		# ta_features['masked_lm_mask'] = st_dict['masked_lm_mask']
		# ta_features['input_ids'] = st_dict['output_ids']
		# ta_features['input_ori_ids'] = features['input_ids']
		# ta_dict = ta_model(ta_features, labels, mode, params)

		# studnet_logit = st_dict['logits']
		# teacher_logit = ta_dict['logits']

		model_io_fn = model_io.ModelIO(model_io_config)

		# Cross-entropy (task) loss from the student, scaled by config weight.
		original_loss += st_dict['loss'] * (distillation_config.get('ce_loss', 1.0))
		print(distillation_config.get('ce_loss', 1.0), '===ce_loss===')
		if not kargs.get('use_tpu', False):
			tf.summary.scalar("ce_loss", st_dict['loss'])

		hook_dict = {}

		# if 'kl_logits' in distillation_config.get('distillation_type', ['kl_logits']):
		# 	temperature = distillation_config.get('kl_temperature', 2.0)
		# 	distilled_teacher_logit = tf.nn.log_softmax((teacher_logit+1e-10) / temperature) # log_softmax logits
		# 	distilled_student_logit = tf.nn.log_softmax((studnet_logit+1e-10) / temperature) # log_softmax logits

		# 	logits_mask = tf.cast(st_dict['masked_lm_mask'], tf.float32)
		# 	kl_distilled_loss = distillation_utils.kd(distilled_teacher_logit, 
		# 												distilled_student_logit)
		# 	kl_distilled_loss = tf.reduce_sum(logits_mask*kl_distilled_loss) / tf.reduce_sum(logits_mask)

		# 	if not kargs.get('use_tpu', False):
		# 		tf.summary.scalar("kl_logits_loss", kl_distilled_loss)
		# 		tf.summary.scalar("kl_logits_mask", tf.reduce_mean(logits_mask))
		# 	tf.logging.info("***** with knowledge distillation %s tenperature *****", str(temperature))
		# 	hook_dict['kl_logits_loss'] = kl_distilled_loss
		# 	# kl_distilled_loss *= np.power(temperature, 2)
		# 	distilled_loss += kl_distilled_loss * distillation_config.get('kl_logits', 0.9)
		# 	print(distillation_config.get('kl_logits_ratio', 0.9), '===kl_logits_ratio===')

		# if "attention_score_uniform" in distillation_config.get('distillation_type', ['kl_logits']):
		# 	source_attention_score = ta_dict['model'].get_multihead_attention()
		# 	target_attention_score = st_dict['model'].get_multihead_attention()

		# 	print("==apply attention_score_uniform==")

		# 	with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):  
		# 		attention_loss = uniform_mapping.attention_score_matching(source_attention_score, 
		# 																target_attention_score,
		# 																features['input_mask'],
		# 																0)
		# 	tf.summary.scalar("attention_score_uniform_loss", attention_loss)
		# 	distilled_loss += attention_loss * distillation_config.get("attention_score_uniform", 0.1)
		# 	hook_dict['attention_mse_loss'] = attention_loss
		# 	print(distillation_config.get('attention_score_uniform', 0.1), '===attention_score_uniform===')
			
		# if "hidden_uniform" in distillation_config.get('distillation_type', ['kl_logits']):
		# 	source_hidden = ta_dict['model'].get_all_encoder_layers()
		# 	target_hidden = st_dict['model'].get_all_encoder_layers()

		# 	print("==apply hidden_uniform==")

		# 	with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
		# 		hidden_loss = uniform_mapping.hidden_matching(source_hidden, target_hidden, 
		# 													features['input_mask'],
		# 													0)
		# 	if not kargs.get('use_tpu', False):
		# 		tf.summary.scalar("hidden_uniform_loss", hidden_loss)
		# 	distilled_loss += hidden_loss * distillation_config.get("hidden_uniform", 0.1)
		# 	hook_dict['hidden_loss'] = hidden_loss
		# 	print(distillation_config.get('hidden_uniform', 0.1), '===hidden_uniform===')

		# if "embedding_distillation" in distillation_config.get('distillation_type', ['embedding_distillation']):
		# 	st_word_embed = st_dict['model'].get_embedding_table()
		# 	ta_word_embed = ta_dict['model'].get_embedding_table()
		# 	st_word_embed_shape = bert_utils.get_shape_list(st_word_embed, expected_rank=[2,3])
		# 	print("==random_embed_shape==", st_word_embed_shape)
		# 	ta_word_embed_shape = bert_utils.get_shape_list(ta_word_embed, expected_rank=[2,3])
		# 	print("==pretrain_embed_shape==", ta_word_embed_shape)
		# 	if st_word_embed_shape[-1] != ta_word_embed_shape[-1]:
		# 		with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
		# 			with tf.variable_scope("embedding_proj"):
		# 				proj_embed = tf.layers.dense(ta_word_embed, st_word_embed_shape[-1])
		# 	else:
		# 		proj_embed = ta_word_embed
			
		# 	embed_loss = tf.reduce_mean(tf.reduce_mean(tf.square(proj_embed-st_word_embed), axis=-1))
		# 	distilled_loss += embed_loss
		# 	hook_dict['embed_loss'] = embed_loss
		# 	tf.logging.info("****** apply prertained feature distillation *******")

		# With all distillation terms disabled, this is just the student loss.
		total_loss = distilled_loss + original_loss
		tvars = []
		tvars.extend(st_dict['tvars'])

		# Variables under the "distillation" scope (empty while the
		# distillation blocks above are commented out).
		distillation_vars = model_io_fn.get_params('distillation', 
							not_storage_params=[])
		tvars.extend(distillation_vars)

		# if kargs.get('update_ta', False):
		# 	total_loss += ta_dict['loss']
		# 	tvars.extend(ta_dict['tvars'])

		# TPU does not support these host-side metrics/summaries.
		if not kargs.get('use_tpu', False):
			student_eval_metrics = train_metric_fn(
						  st_dict['masked_lm_example_loss'], 
						  st_dict['logits'], 
						  st_dict["masked_lm_ids"],
						  st_dict['masked_lm_mask'],
						  'student')

			# teacher_eval_metric =  train_metric_fn( 
			# 			  ta_dict['masked_lm_example_loss'], 
			# 			  ta_dict['logits'], 
			# 			  ta_dict["masked_lm_ids"],
			# 			  ta_dict['masked_lm_mask'],
			# 			  'teacher')
			# student_eval_metrics.update(teacher_eval_metric)
			for key in student_eval_metrics:
				hook_dict[key] = student_eval_metrics[key]

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			print("==update_ops==", update_ops)

			print('==total trainable vars==', list(tvars))

			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(total_loss, list(set(tvars)), 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				# NOTE(review): hooks are gated on task_index == 1 here, while
				# sibling model_fns in this file use task_index == 0 for the
				# chief — confirm which worker is intended to checkpoint.
				if kargs.get("task_index", 1) == 1 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 1:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				logging_hook = tf.train.LoggingTensorHook(
					hook_dict, every_n_iter=100)
				training_hooks.append(logging_hook)

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=total_loss, train_op=train_op,
								training_hooks=training_hooks)
				
				return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
		
			# `metric_fn` is captured from the enclosing builder scope.
			student_eval_metrics = metric_fn(
								  st_dict['masked_lm_example_loss'], 
								  st_dict['logits'], 
								  st_dict["masked_lm_ids"],
								  st_dict['masked_lm_mask'],
								  'student')

			# teacher_eval_metric =  metric_fn( 
			# 					  ta_dict['masked_lm_example_loss'], 
			# 					  ta_dict['logits'], 
			# 					  ta_dict["masked_lm_ids"],
			# 					  ta_dict['masked_lm_mask'],
			# 					  'teacher')

			# student_eval_metrics.update(teacher_eval_metric)

			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
							loss=total_loss,
							eval_metric_ops=student_eval_metrics)
			return estimator_spec
		else:
			raise NotImplementedError()
Example #23
0
    def model_fn(features, labels, mode):
        """BERT sentence-classification `model_fn` for `tf.estimator`.

        Encodes the input with `bert_encoder`, classifies the pooled output,
        optionally restores pretrained weights, and returns a mode-dependent
        result. Closes over builder-scope configuration (`model_config`,
        `num_labels`, `output_type`, ...).
        """
        is_training = mode == tf.estimator.ModeKeys.TRAIN

        encoder = bert_encoder(model_config,
                               features,
                               labels,
                               mode,
                               target,
                               reuse=model_reuse)

        label_ids = features["label_ids"]

        # Dropout is only active during training.
        dropout_prob = model_config.dropout_prob if is_training else 0.0

        # A frozen LM keeps its classifier head in a dedicated scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            loss, per_example_loss, logits = classifier.classifier(
                model_config, encoder.get_pooled_output(), num_labels,
                label_ids, dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)
        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        if load_pretrained:
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        model_io_fn.set_saver(var_lst=tvars)

        if mode == tf.estimator.ModeKeys.TRAIN:
            optimizer_fn = optimizer.Optimizer(opt_config)
            model_io_fn.print_params(tvars, string=", trainable params")

            # Make sure collected update ops (e.g. batch norm) run first.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                train_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                        loss=loss,
                                                        train_op=train_op)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        }
                    }
                elif output_type == "estimator":
                    return train_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            max_prob = tf.reduce_max(tf.nn.softmax(logits), axis=-1)

            predictions = {'pred_label': pred_label, "max_prob": max_prob}
            return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={
                    "output": tf.estimator.export.PredictOutput(predictions)
                })

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                flat_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                predicted = tf.argmax(logits,
                                      axis=-1,
                                      output_type=tf.int32)
                flat_labels = tf.reshape(label_ids, [-1])
                accuracy = tf.metrics.accuracy(labels=label_ids,
                                               predictions=predicted)
                mean_loss = tf.metrics.mean(values=per_example_loss)
                macro_f1 = tf_metrics.f1(label_ids,
                                         predicted,
                                         num_labels,
                                         label_lst,
                                         average="macro")
                return {"f1": macro_f1, "acc": accuracy}

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            eval_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return eval_spec
        else:
            raise NotImplementedError()
    def model_fn(features, labels, mode, params):
        """Estimator `model_fn` for sequence-LM (next-token) pre-training.

        Closure dependencies (from the enclosing builder, not visible here):
        `model_config`, `model_io_config`, `opt_config`, `target`, `kargs`,
        `load_pretrained`, `init_checkpoint`, `exclude_scope`,
        `not_storage_params`, plus helpers `model_zoo`, `model_io`,
        `optimizer`, `distributed_optimizer`, `train_metric`, `eval_metric`.

        Builds the model on the original (unmasked) token ids, computes a
        shifted next-token cross-entropy loss over non-PAD positions, loads
        pretrained weights when requested, and returns a TPUEstimatorSpec or
        EstimatorSpec for TRAIN/EVAL.  PREDICT raises NotImplementedError.
        """

        model_api = model_zoo(model_config)

        # Shallow-copy features and alias the original ids as the model input,
        # so the LM consumes the raw token sequence regardless of any masking
        # the pipeline applied to `input_ids`.
        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        seq_features['input_ids'] = features["input_ori_ids"]

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        # Dropout only during training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # NOTE(review): `dropout_prob` and `scope` are computed but not used
        # below in this block — presumably consumed via kargs/closure
        # elsewhere; confirm before removing.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # Mask out padding positions of the shifted targets.  The PAD id is
        # looked up under the literal key '[PAD]' in kargs (default 0).
        sequence_mask = tf.to_float(
            tf.not_equal(features['input_ori_ids'][:, 1:],
                         kargs.get('[PAD]', 0)))

        # batch x seq_length
        print(model.get_sequence_output_logits().get_shape(),
              "===logits shape===")
        # Next-token prediction: logits at position t predict the token at
        # t+1, hence the [:, :-1] / [:, 1:] alignment.
        seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=features['input_ori_ids'][:, 1:],
            logits=model.get_sequence_output_logits()[:, :-1])

        # Per-example mean over valid (non-PAD) positions; 1e-10 guards
        # against division by zero for all-PAD rows.
        per_example_loss = tf.reduce_sum(seq_loss * sequence_mask, axis=-1) / (
            tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
        loss = tf.reduce_mean(per_example_loss)

        model_io_fn = model_io.ModelIO(model_io_config)

        # Collect encoder variables plus the LM prediction head.
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
            # NOTE(review): message is misleading — this branch means
            # "loaded pretrained weights", not "using tpu".
            tf.logging.info("***** using tpu *****")
        else:
            scaffold_fn = None
            tf.logging.info("***** not using tpu *****")

        if mode == tf.estimator.ModeKeys.TRAIN:

            # TPU requires the TPU-compatible optimizer; otherwise use the
            # distributed (GPU) optimizer.
            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using tpu with tpu-captiable optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using gpu with gpu-captiable optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            # Run batch-norm style update ops before the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                train_metric_dict = train_metric(
                    features['input_ori_ids'],
                    model.get_sequence_output_logits(), **kargs)

                # Summaries are not supported on TPU.
                if not kargs.get('use_tpu', False):
                    for key in train_metric_dict:
                        tf.summary.scalar(key, train_metric_dict[key])
                    tf.summary.scalar('learning_rate',
                                      optimizer_fn.learning_rate)
                    tf.logging.info("***** logging metric *****")
                    tf.summary.scalar("causal_attenion_mask_length",
                                      tf.reduce_sum(model.attention_mask))
                    tf.summary.scalar("bi_attenion_mask_length",
                                      tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            # GPU path takes the metric ops directly; TPU path takes the
            # (fn, args) tuple form required by TPUEstimatorSpec.
            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits())
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits()
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example #25
0
    def model_fn(features, labels, mode):
        """Knowledge-distillation `model_fn`: student trained against teacher.

        Closure dependencies (from the enclosing builder): `st_model_fn`,
        `ta_model_fn`, `model_config_dict`, `num_labels_dict`,
        `init_checkpoint_dict`, `load_pretrained_dict`, `model_io_config`,
        `opt_config`, `exclude_scope_dict`, `not_storage_params_dict`,
        `target_dict`, `distillation_config`, `output_type`, `kargs`, plus
        helpers `model_io`, `optimizer`, `distillation_utils`,
        `repo_distillation_utils`, `uniform_mapping`, `cpc_utils`,
        `tf_metrics`.

        Total loss = weighted student CE loss + the sum of the configured
        distillation terms (kl_logits / rkd / attention_score_uniform /
        hidden_uniform / hidden_cls_uniform / cpc / wpc), each scaled by its
        ratio from `distillation_config`.

        Fixes vs. the previous revision:
          * cpc/wpc ratios were ADDED (`+`) to the loss instead of
            multiplied (`*`), leaving those terms unweighted.
          * local name `studnet_logit` -> `student_logit`.
          * log message typo "tenperature" -> "temperature".
        """

        original_loss = tf.constant(0.0)
        distilled_loss = tf.constant(0.0)

        # Build student and teacher graphs; both see the same features/mode.
        st_model = st_model_fn(
            model_config_dict['student'],
            num_labels_dict['student'],
            init_checkpoint_dict['student'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['student'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('student', ""),
            not_storage_params=not_storage_params_dict.get('student', []),
            target=target_dict['student'],
            **kargs)
        st_dict = st_model(features, labels, mode)

        ta_model = ta_model_fn(
            model_config_dict['teacher'],
            num_labels_dict['teacher'],
            init_checkpoint_dict['teacher'],
            model_reuse=None,
            load_pretrained=load_pretrained_dict['teacher'],
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope=exclude_scope_dict.get('teacher', ""),
            not_storage_params=not_storage_params_dict.get('teacher', []),
            target=target_dict['teacher'],
            **kargs)
        ta_dict = ta_model(features, labels, mode)

        student_logit = st_dict['logits']
        teacher_logit = ta_dict['logits']

        model_io_fn = model_io.ModelIO(model_io_config)

        # Set when a distillation term creates trainable variables under the
        # "distillation" scope, so they are added to tvars below.
        feature_flag = False

        # Weighted student cross-entropy term.
        original_loss += st_dict['loss'] * (distillation_config.get(
            'ce_loss', 1.0))
        print(distillation_config.get('ce_loss', 1.0), '===ce_loss===')
        tf.summary.scalar("ce_loss", st_dict['loss'])

        if 'kl_logits' in distillation_config.get('distillation_type',
                                                  ['kl_logits']):
            # Temperature-scaled KL between teacher and student logits.
            temperature = distillation_config.get('kl_temperature', 2.0)
            distilled_teacher_logit = tf.nn.log_softmax(
                (teacher_logit + 1e-10) / temperature)  # log_softmax logits
            distilled_student_logit = tf.nn.log_softmax(
                (student_logit + 1e-10) / temperature)  # log_softmax logits

            kl_distilled_loss = tf.reduce_mean(
                distillation_utils.kd(distilled_teacher_logit,
                                      distilled_student_logit))

            tf.summary.scalar("kl_logits_loss", kl_distilled_loss)
            tf.logging.info(
                "***** with knowledge distillation %s temperature *****",
                str(temperature))

            # kl_distilled_loss *= np.power(temperature, 2)
            distilled_loss += kl_distilled_loss * distillation_config.get(
                'kl_logits_ratio', 0.9)
            print(distillation_config.get('kl_logits_ratio', 0.9),
                  '===kl_logits_ratio===')

        if 'rkd' in distillation_config.get('distillation_type',
                                            ['kl_logits']):
            # Relational knowledge distillation on pooled outputs.
            source = ta_dict['model'].get_pooled_output()
            target = st_dict['model'].get_pooled_output()
            print("==apply rkd==")
            with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
                rkd_loss = repo_distillation_utils.RKD(source,
                                                       target,
                                                       l=[25, 50])
            tf.summary.scalar("rkd_loss", rkd_loss)
            distilled_loss += rkd_loss * distillation_config.get(
                "rkd_ratio", 0.1)

        if "attention_score_uniform" in distillation_config.get(
                'distillation_type', ['kl_logits']):
            source_attention_score = ta_dict['model'].get_multihead_attention()
            target_attention_score = st_dict['model'].get_multihead_attention()

            print("==apply attention_score_uniform==")

            with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
                attention_loss = uniform_mapping.attention_score_matching(
                    source_attention_score, target_attention_score,
                    features['input_mask'], 0)
            tf.summary.scalar("attention_score_uniform_loss", attention_loss)
            feature_flag = True
            distilled_loss += attention_loss * distillation_config.get(
                "attention_score_uniform", 0.1)

            print(distillation_config.get('attention_score_uniform', 0.1),
                  '===attention_score_uniform===')

        if "hidden_uniform" in distillation_config.get('distillation_type',
                                                       ['kl_logits']):
            source_hidden = ta_dict['model'].get_all_encoder_layers()
            target_hidden = st_dict['model'].get_all_encoder_layers()

            print("==apply hidden_uniform==")

            with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
                hidden_loss = uniform_mapping.hidden_matching(
                    source_hidden, target_hidden, features['input_mask'], 0)
            tf.summary.scalar("hidden_uniform_loss", hidden_loss)
            distilled_loss += hidden_loss * distillation_config.get(
                "hidden_uniform", 0.1)
            feature_flag = True

            print(distillation_config.get('hidden_uniform', 0.1),
                  '===hidden_uniform===')

        if "hidden_cls_uniform" in distillation_config.get(
                'distillation_type', ['kl_logits']):
            source_hidden = ta_dict['model'].get_all_encoder_layers()
            target_hidden = st_dict['model'].get_all_encoder_layers()

            print("==apply hidden_cls_uniform==")
            with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
                hidden_cls_loss = uniform_mapping.hidden_cls_matching(
                    source_hidden, target_hidden, 0)
            tf.summary.scalar("hidden_cls_uniform_loss", hidden_cls_loss)
            # NOTE(review): reads the "hidden_uniform" ratio key, not
            # "hidden_cls_uniform" — looks like a copy-paste slip, but kept
            # as-is since existing configs may rely on it; confirm intent.
            distilled_loss += hidden_cls_loss * distillation_config.get(
                "hidden_uniform", 0.1)
            feature_flag = True

        if "mdd" in distillation_config.get('distillation_type', ['mdd']):
            # NOTE(review): mdd branch fetches pooled outputs but adds no
            # loss term — appears unfinished.
            source = ta_dict['model'].get_pooled_output()
            target = st_dict['model'].get_pooled_output()

            print("==apply mdd==")

        if "cpc" in distillation_config.get('distillation_type', ['mdd']):
            source_hidden = ta_dict['model'].get_all_encoder_layers()
            target_hidden = st_dict['model'].get_all_encoder_layers()
            with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
                cpc_loss = cpc_utils.CPC_Hidden(target_hidden, source_hidden,
                                                features['input_mask'])
            tf.summary.scalar("hidden_cpc_loss", cpc_loss)
            # Bug fix: ratio was added (`+`) instead of multiplied (`*`).
            distilled_loss += cpc_loss * distillation_config.get(
                "cpc_hidden", 0.1)

        if "wpc" in distillation_config.get('distillation_type', ['mdd']):
            source_hidden = ta_dict['model'].get_all_encoder_layers()
            target_hidden = st_dict['model'].get_all_encoder_layers()
            with tf.variable_scope("distillation", reuse=tf.AUTO_REUSE):
                wpc_loss = cpc_utils.WPC_Hidden(target_hidden, source_hidden,
                                                features['input_mask'])
            tf.summary.scalar("hidden_wpc_loss", wpc_loss)
            # Bug fix: ratio was added (`+`) instead of multiplied (`*`).
            distilled_loss += wpc_loss * distillation_config.get(
                "wpc_hidden", 0.1)

        total_loss = distilled_loss + original_loss

        # Only student (and distillation-scope) variables are trained; the
        # teacher stays frozen.
        tvars = []
        tvars.extend(st_dict['tvars'])

        if feature_flag:
            distillation_vars = model_io_fn.get_params('distillation',
                                                       not_storage_params=[])
            tvars.extend(distillation_vars)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)

            print('==total trainable vars==', list(tvars))

            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    total_loss, list(set(tvars)), opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                # Checkpoint hooks only on the chief worker without a
                # run_config-managed session.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": total_loss,
                            "logits": student_logit,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids, model_type):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels_dict['student'],
                                           None,
                                           average="macro")

                eval_metric_ops = {
                    "{}_f1".format(model_type): sentence_f,
                    "{}_acc".format(model_type): sentence_accuracy
                }

                return eval_metric_ops

            if output_type == "sess":
                # NOTE(review): reads per_example_loss out of st_dict['logits']
                # — presumably the sess-mode student dict nests it there;
                # confirm against st_model_fn.
                return {
                    "eval": {
                        "per_example_loss":
                        st_dict['logits']['per_example_loss'],
                        "logits":
                        student_logit,
                        "loss":
                        tf.reduce_mean(st_dict['logits']['per_example_loss']),
                        "feature":
                        st_dict['model'].get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(st_dict['per_example_loss'],
                                            student_logit,
                                            features['label_ids'], "student")
                teacher_eval_metric_ops = metric_fn(
                    ta_dict['per_example_loss'], teacher_logit,
                    features['label_ids'], "teacher")

                eval_metric_ops.update(teacher_eval_metric_ops)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #26
0
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target,
                  init_checkpoint, train_file, dev_file, checkpoint_dir,
                  is_debug):
    """Build a BERT classifier graph and run a session-based train/eval loop.

    Args:
        FLAGS: parsed command-line flags (config_file, batch_size, epoch,
            train_size, eval_size, num_classes, opt_type, if_shard, ...).
        worker_count: number of distributed workers.
        task_index: this worker's index; only worker 0 writes checkpoints
            and runs the final evaluation.
        is_chief: whether this worker is the chief (sync-replicas hook).
        target: session master for MonitoredTrainingSession (ps mode).
        init_checkpoint: pretrained checkpoint to warm-start from.
        train_file: training TFRecord path(s).
        dev_file: evaluation TFRecord path(s).
        checkpoint_dir: checkpoint directory (used by worker 0 only).
        is_debug: "0" shrinks step counts for a quick smoke run.

    Fixes vs. the previous revision:
      * `train_fn` printed an undefined name `train_loss` on NaN loss
        (NameError); now prints the offending value.
      * the OutOfRangeError handler in `train_fn` lacked a `break`, so the
        `while True` loop spun forever once the dataset was exhausted.
      * removed a dead first computation of `sync_replicas_hook` that was
        unconditionally recomputed (or unused) later.
    """

    graph = tf.Graph()
    with graph.as_default():
        import json

        config = json.load(open(FLAGS.config_file, "r"))

        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"

        # Shard either the epoch count or the train size across workers.
        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        init_lr = 2e-5

        label_dict = json.load(open(FLAGS.label_id))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)

        # Checkpoint/monitor once per epoch worth of batches.
        num_storage_steps = int(train_size / FLAGS.batch_size)

        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)

        if is_debug == "0":
            # Smoke-test mode: tiny step counts.
            num_storage_steps = 2
            num_eval_steps = 10
            num_train_steps = 10
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".
              format(num_train_steps, num_eval_steps, num_storage_steps))

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": init_lr / worker_count,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "worker_count": worker_count,
            "opt_type": FLAGS.opt_type
        })

        model_io_config = Bunch({"fix_lm": False})

        model_io_fn = model_io.ModelIO(model_io_config)

        optimizer_fn = optimizer.Optimizer(opt_config)

        num_classes = FLAGS.num_classes

        # Train graph creates the variables; eval graph reuses them.
        model_train_fn = model_fn_builder(config,
                                          num_classes,
                                          init_checkpoint,
                                          model_reuse=None,
                                          load_pretrained=True,
                                          model_io_fn=model_io_fn,
                                          optimizer_fn=optimizer_fn,
                                          model_io_config=model_io_config,
                                          opt_config=opt_config,
                                          exclude_scope="",
                                          not_storage_params=[],
                                          target="")

        model_eval_fn = model_fn_builder(config,
                                         num_classes,
                                         init_checkpoint,
                                         model_reuse=True,
                                         load_pretrained=True,
                                         model_io_fn=model_io_fn,
                                         optimizer_fn=optimizer_fn,
                                         model_io_config=model_io_config,
                                         opt_config=opt_config,
                                         exclude_scope="",
                                         not_storage_params=[],
                                         target="")

        def eval_metric_fn(features, eval_op_dict):
            """Attach accuracy/pred ops to the eval sub-graph."""
            logits = eval_op_dict["logits"]
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

            return {
                "accuracy": accuracy,
                "loss": eval_op_dict["loss"],
                "pred_label": pred_label,
                "label_ids": features["label_ids"]
            }

        def train_metric_fn(features, train_op_dict):
            """Attach accuracy to the train sub-graph alongside the train op."""
            logits = train_op_dict["logits"]
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {
                "accuracy": accuracy,
                "loss": train_op_dict["loss"],
                "train_op": train_op_dict["train_op"]
            }

        name_to_features = {
            "input_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "input_mask": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example."""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports
            # tf.int32, so cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t

            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size

        train_features = tf_data_utils.train_input_fn(
            train_file,
            _decode_record,
            name_to_features,
            params,
            if_shard=FLAGS.if_shard,
            worker_count=worker_count,
            task_index=task_index)

        eval_features = tf_data_utils.eval_input_fn(dev_file,
                                                    _decode_record,
                                                    name_to_features,
                                                    params,
                                                    if_shard=FLAGS.if_shard,
                                                    worker_count=worker_count,
                                                    task_index=task_index)

        train_op_dict = model_train_fn(train_features, [],
                                       tf.estimator.ModeKeys.TRAIN)
        eval_op_dict = model_eval_fn(eval_features, [],
                                     tf.estimator.ModeKeys.EVAL)
        eval_dict = eval_metric_fn(eval_features, eval_op_dict["eval"])
        train_dict = train_metric_fn(train_features, train_op_dict["train"])

        print("===========begin to train============")
        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)

        # Only worker 0 writes checkpoints.
        checkpoint_dir = checkpoint_dir if task_index == 0 else None

        print("start training")

        hooks = []
        if FLAGS.opt_type == "ps":
            # Sync-replicas training needs the optimizer's session hook and a
            # master target.
            sync_replicas_hook = optimizer_fn.opt.make_session_run_hook(
                is_chief, num_tokens=0)
            hooks.append(sync_replicas_hook)
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        else:
            sess = tf.train.MonitoredTrainingSession(
                config=sess_config,
                hooks=[],
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)

        def eval_fn(eval_dict, sess):
            """Run `num_eval_steps` eval batches, accumulate metrics/preds."""
            i = 0
            total_accuracy = 0
            eval_total_dict = {}
            while True:
                try:
                    eval_result = sess.run(eval_dict)
                    for key in eval_result:
                        if key not in eval_total_dict:
                            if key in ["pred_label", "label_ids"]:
                                eval_total_dict[key] = []
                                eval_total_dict[key].extend(eval_result[key])
                            if key in ["accuracy", "loss"]:
                                eval_total_dict[key] = 0.0
                                eval_total_dict[key] += eval_result[key]
                        else:
                            if key in ["pred_label", "label_ids"]:
                                eval_total_dict[key].extend(eval_result[key])
                            if key in ["accuracy", "loss"]:
                                eval_total_dict[key] += eval_result[key]

                    i += 1
                    if np.mod(i, num_eval_steps) == 0:
                        break
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break

            label_id = eval_total_dict["label_ids"]
            pred_label = eval_total_dict["pred_label"]

            result = classification_report(label_id,
                                           pred_label,
                                           target_names=list(
                                               label_dict["label2id"].keys()))

            print(result, task_index)
            eval_total_dict["classification_report"] = result
            return eval_total_dict

        def train_fn(train_op_dict, sess):
            """Train until exhaustion/num_train_steps, eval every storage step."""
            i = 0
            cnt = 0
            loss_dict = {}
            monitoring_train = []
            monitoring_eval = []
            while True:
                try:
                    [train_result] = sess.run([train_op_dict])
                    step = sess.run(tf.train.get_global_step())
                    for key in train_result:
                        if key == "train_op":
                            continue
                        else:
                            if np.isnan(train_result[key]):
                                # Bug fix: previously printed the undefined
                                # name `train_loss` (NameError).
                                print(train_result[key], "get nan loss")
                                break
                            else:
                                if key in loss_dict:
                                    loss_dict[key] += train_result[key]
                                else:
                                    loss_dict[key] = train_result[key]

                    i += 1
                    cnt += 1

                    if np.mod(i, num_storage_steps) == 0:
                        string = ""
                        for key in loss_dict:
                            tmp = key + " " + str(loss_dict[key] / cnt) + "\t"
                            string += tmp
                        print(string, step)
                        monitoring_train.append(loss_dict)

                        eval_finial_dict = eval_fn(eval_dict, sess)
                        monitoring_eval.append(eval_finial_dict)

                        for key in loss_dict:
                            loss_dict[key] = 0.0
                        cnt = 0

                    if is_debug == "0":
                        if i == num_train_steps:
                            break

                except tf.errors.OutOfRangeError:
                    print("==Succeeded in training model==")
                    # Bug fix: without this break the while-loop re-ran
                    # sess.run forever after the dataset was exhausted.
                    break

        step = sess.run(optimizer_fn.global_step)
        print(step)
        train_fn(train_dict, sess)

        if task_index == 0:
            print("===========begin to eval============")
            eval_finial_dict = eval_fn(eval_dict, sess)
0
	def model_fn(features, labels, mode):
		"""Estimator model_fn for a GPT-style autoregressive language model.

		Builds the gpt_encoder graph, computes a shifted next-token LM loss
		in TRAIN mode, and supports sequence sampling / perplexity scoring
		in PREDICT mode. Relies on closure variables from the enclosing
		builder (model_config, target, kargs, model_io_config, opt_config,
		not_storage_params, load_pretrained, init_checkpoint, exclude_scope,
		output_type, num_labels, label_lst, train_metric, sample, ...).

		Args:
			features: dict of input tensors; must contain 'input_ids'.
			labels: unused here — next-token targets come from 'input_ids'.
			mode: a tf.estimator.ModeKeys value.

		Returns:
			A dict of tensors/ops when output_type == "sess", or a
			tf.estimator.EstimatorSpec when output_type == "estimator".

		Raises:
			NotImplementedError: for unsupported modes.
		"""
		model = gpt_encoder(model_config, features, labels, 
			mode, target, reuse=tf.AUTO_REUSE)
		scope = model_config.scope

		if mode == tf.estimator.ModeKeys.TRAIN:
			# Mask over target positions (tokens 1..end); positions equal to
			# the pad id (kargs['[PAD]'], default 0) contribute no loss.
			# batch x seq_length
			sequence_mask = tf.to_float(tf.not_equal(features['input_ids'][:, 1:], 
													kargs.get('[PAD]', 0)))

			# Shifted LM loss: logits at position t predict token t+1.
			# batch x seq_length
			seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
						labels=features['input_ids'][:, 1:], 
						logits=model.get_sequence_output_logits()[:, :-1])

			# Length-normalized per-example loss; +1e-10 guards against
			# division by zero for all-padding rows.
			per_example_loss = tf.reduce_sum(seq_loss*sequence_mask, axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1)+1e-10)
			loss = tf.reduce_mean(per_example_loss)

		model_io_fn = model_io.ModelIO(model_io_config)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		try:
			params_size = model_io_fn.count_params(model_config.scope)
			print("==total params==", params_size)
		except:
			print("==not count params==")
		print(tvars)
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			# Ensure collection-registered update ops (e.g. batch norm moving
			# averages) run together with the train op.
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			print("==update_ops==", update_ops)
			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				# Checkpoint hooks are only installed on the chief worker
				# (task_index == 0) when no run_config manages checkpointing.
				if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				train_metric_dict = train_metric(features['input_ids'], 
												model.get_sequence_output_logits(), 
												**kargs)

				for key in train_metric_dict:
					tf.summary.scalar(key, train_metric_dict[key])
				tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)
				tf.summary.scalar('seq_length', tf.reduce_mean(tf.reduce_sum(sequence_mask, axis=-1)))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)
				print(tf.global_variables(), "==global_variables==")
				if output_type == "sess":
					return {
						"train":{
										"loss":loss, 
										# NOTE(review): `logits` is never defined in this
										# function — this raises NameError at graph-build
										# time on the "sess" path; presumably it should be
										# model.get_sequence_output_logits(). Confirm and fix.
										"logits":logits,
										"train_op":train_op
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec

		elif mode == tf.estimator.ModeKeys.PREDICT:
			if kargs.get('predict_type', 'sample_sequence') == 'sample_sequence':
				# Autoregressive sampling continuing from the input context.
				results = sample.sample_sequence(
							gpt_encoder, hparams=model_config, 
							length=kargs.get('max_length', 64), 
							start_token=None, 
							batch_size=10, 
							context=features['input_ids'],
							temperature=2,
							top_k=10)
				
				# Drop position 0 so tokens/logits align with generated steps.
				sampled_token = results['tokens'][:, 1:]
				sampled_token_logits = results['logits'][:, 1:]

				estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'token':sampled_token,
												"logits":sampled_token_logits
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'token':sampled_token,
														"logits":sampled_token_logits
													}
												)
									}
						)

				return estimator_spec

			elif kargs.get('predict_type', 'sample_sequence') == 'infer_inputs':
				# Score the given inputs: per-token NLL and per-example perplexity.
				sequence_mask = tf.to_float(tf.not_equal(features['input_ids'][:, 1:], 
													kargs.get('[PAD]', 0)))
				output_logits = model.get_sequence_output_logits()[:, :-1]
				# output_logits = tf.nn.log_softmax(output_logits, axis=-1)

				output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits(
										labels=features['input_ids'][:, 1:], 
										logits=output_logits)

				# Mean masked NLL per example, exponentiated to perplexity.
				per_example_perplexity = tf.reduce_sum(output_id_logits * sequence_mask, 
												axis=-1) # batch
				per_example_perplexity /= tf.reduce_sum(sequence_mask, axis=-1) # batch

				perplexity = tf.exp(per_example_perplexity)

				estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'token':features['input_ids'][:, 1:],
												"logits":output_id_logits,
												'perplexity':perplexity
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'token':features['input_ids'][:,1:],
														"logits":output_id_logits,
														'perplexity':perplexity
													}
												)
									}
						)

				return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			# NOTE(review): this whole EVAL branch appears copy-pasted from a
			# classifier model_fn: `per_example_loss`, `logits` and `label_ids`
			# are not defined in EVAL mode (only TRAIN computes a loss here),
			# and gpt_encoder likely has no get_pooled_output(). As written it
			# raises NameError — needs a proper LM eval implementation.
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss),
							"feature":model.get_pooled_output()
						}
				}
			elif output_type == "estimator":
				eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)
				return estimator_spec
		else:
			raise NotImplementedError()
Example #28
0
    def model_fn(features, labels, mode):
        """Estimator model_fn: classifier with a residual projection head.

        A dense -> dropout -> residual-add -> dense(tanh) head is applied to
        the encoder's pooled output; a classifier (and its loss) is stacked
        on top for TRAIN/EVAL, while PREDICT exports only the projected
        embedding. Relies on closure variables from the enclosing builder
        (model_config, num_labels, model_reuse, target, kargs,
        model_io_config, opt_config, load_pretrained, init_checkpoint,
        exclude_scope, not_storage_params, label_lst, output_type).

        Args:
            features: dict of input tensors; must contain "label_ids".
            labels: unused — labels are read from features["label_ids"].
            mode: a tf.estimator.ModeKeys value.

        Returns:
            A dict of tensors/ops when output_type == "sess", or a
            tf.estimator.EstimatorSpec when output_type == "estimator".

        Raises:
            NotImplementedError: for unsupported modes.
        """
        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        # Dropout is only active during training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        # When the LM body is frozen, fine-tuned variables live in a
        # dedicated "<scope>_finetuning" variable scope.
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # Residual projection head over the pooled output.
        with tf.variable_scope(scope + "/feature_output", reuse=model_reuse):
            hidden_size = bert_utils.get_shape_list(model.get_pooled_output(),
                                                    expected_rank=2)[-1]
            feature_output = tf.layers.dense(
                model.get_pooled_output(),
                hidden_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01))
            feature_output = tf.nn.dropout(feature_output,
                                           keep_prob=1 - dropout_prob)
            feature_output += model.get_pooled_output()
            feature_output = tf.layers.dense(
                feature_output,
                hidden_size,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01),
                activation=tf.tanh)

        # The classification loss is only needed when labels are consumed.
        if mode == tf.estimator.ModeKeys.TRAIN or mode == tf.estimator.ModeKeys.EVAL:
            with tf.variable_scope(scope, reuse=model_reuse):
                (loss, per_example_loss,
                 logits) = classifier.classifier(model_config, feature_output,
                                                 num_labels, label_ids,
                                                 dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            # Ensure collection-registered update ops (e.g. batch norm moving
            # averages) run together with the train op.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                # Checkpoint hooks are only installed on the chief worker
                # (task_index == 0) when no run_config manages checkpointing.
                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            # Inference exports only the projected embedding.
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    "embedding": feature_output
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput(
                        {"embedding": feature_output})
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Builds macro-F1 and accuracy eval metric ops."""
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(per_example_loss, logits,
                                            label_ids)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()