Example #1
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        logits = global_discriminator_logits(model_config,
                                             model.get_pooled_output(),
                                             reuse=tf.AUTO_REUSE,
                                             **kargs)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        global_prediction_tvars = model_io_fn.get_params(
            "cls/seq_global", not_storage_params=not_storage_params)

        pretrained_tvars.extend(global_prediction_tvars)
        tvars = pretrained_tvars

        print('==discriminator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.PREDICT:
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={"probs": tf.nn.softmax(logits)},
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput(
                        {"probs": tf.nn.softmax(logits)})
                })
            return estimator_spec
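
This model_fn builds an EstimatorSpec only for PREDICT mode. As a minimal sketch, assuming a standard TF 1.x setup, a function like this is handed to an Estimator and queried for the exported "probs" tensor; the model_dir and predict_input_fn below are illustrative assumptions, not part of the original source:

    # Hypothetical wiring (TF 1.x); predict_input_fn is assumed to exist.
    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir="/tmp/discriminator")
    for prediction in estimator.predict(input_fn=predict_input_fn):
        print(prediction["probs"])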
Example #2
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except Exception:
            print("==not count params==")
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope='teacher')
        return_dict = {
            "loss": loss,
            "logits": logits,
            "tvars": tvars,
            "model": model,
            "per_example_loss": per_example_loss
        }
        return return_dict
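
Unlike Example 1, this model_fn returns a plain dict and leaves optimization to its caller. A minimal sketch of one way the dict might be consumed, assuming TF 1.x; the optimizer choice and learning rate are illustrative assumptions:

    # Hypothetical caller; optimizer choice is an assumption.
    result = model_fn(features, labels, tf.estimator.ModeKeys.TRAIN)
    train_op = tf.train.AdamOptimizer(learning_rate=1e-5).minimize(
        result["loss"], var_list=result["tvars"])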
Example #3
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        logits = global_discriminator_logits(model_config,
                                             model.get_pooled_output(),
                                             reuse=tf.AUTO_REUSE,
                                             **kargs)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)
        global_prediction_tvars = model_io_fn.get_params(
            "cls/seq_global", not_storage_params=not_storage_params)

        pretrained_tvars.extend(global_prediction_tvars)
        tvars = pretrained_tvars

        print('==discriminator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None

        return_dict = {
            "logits": logits,
            "tvars": tvars,
            "model": model
        }
        return return_dict
Example #4
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)
        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE)

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = seq_masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             features['input_mask'],
             features['input_ori_ids'],
             features['input_ids'],
             features['input_mask'],
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table())
        masked_lm_ids = features['input_ori_ids']

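        # The NSP loss is kept in the graph but weighted to zero, so only
        # the masked-LM term contributes to the training objective.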
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)
        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        nsp_pretrain_vars = model_io_fn.get_params(
            "cls/seq_relationship", not_storage_params=not_storage_params)

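        # If the embeddings live under a separate scope (e.g. shared
        # embeddings), include those variables as well.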
        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        pretrained_tvars.extend(nsp_pretrain_vars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None
        tf.add_to_collection("discriminator_loss", loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "per_example_loss": masked_lm_example_loss,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels']
        }
        return return_dict
Example #5
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)
        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)
        input_ori_ids = features['input_ori_ids']

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        sampled_binary_mask = features.get('masked_lm_mask', None)

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

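        # A dynamic whole-sequence mask routes through the seq_* masked-LM
        # head; otherwise fall back to the static masked_lm_positions/ids/
        # weights produced at data-generation time.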
        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 scope=generator_scope_prefix,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            masked_lm_ids = input_ori_ids
        else:
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 scope=generator_scope_prefix,
                 embedding_projection=model.get_embedding_projection_table())
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ 0.0 * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            """
			"generator/cls/predictions"
			"""
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)

        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=generator_scope_prefix,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "logits": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "tvars": tvars,
            "model": model,
            "masked_lm_mask": masked_lm_mask,
            "masked_lm_ids": masked_lm_ids
        }
        return return_dict
Example #6
    def model_fn(features, labels, mode):

        train_ops = []
        train_hooks = []
        logits_dict = {}
        losses_dict = {}
        features_dict = {}
        tvars = []
        task_num_dict = {}
        multi_task_config = kargs.get('multi_task_config', {})

        total_loss = tf.constant(0.0)

        task_num = 0

        encoder = {}
        hook_dict = {}

        print(task_type_dict.keys(), "==task type dict==")
        num_task = len(task_type_dict)

        from data_generator import load_w2v
        flags = kargs.get('flags', Bunch({}))
        print(flags.pretrained_w2v_path, "===pretrain vocab path===")
        w2v_path = os.path.join(flags.buckets, flags.pretrained_w2v_path)
        vocab_path = os.path.join(flags.buckets, flags.vocab_file)

        # [w2v_embed, token2id,
        # id2token, is_extral_symbol, use_pretrained] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)

        # pretrained_embed = tf.cast(tf.constant(w2v_embed), tf.float32)
        pretrained_embed = None

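        # Build one shared encoder per distinct model_type; later tasks
        # with the same model_type reuse its variables.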
        for index, task_type in enumerate(task_type_dict.keys()):
            if model_config_dict[task_type].model_type in model_type_lst:
                reuse = True
            else:
                reuse = None
                model_type_lst.append(model_config_dict[task_type].model_type)

            if model_config_dict[task_type].model_type not in encoder:
                model_api = model_zoo(model_config_dict[task_type])

                model = model_api(model_config_dict[task_type],
                                  features,
                                  labels,
                                  mode,
                                  target_dict[task_type],
                                  reuse=reuse,
                                  cnn_type=model_config_dict[task_type].get(
                                      'cnn_type', 'bi_dgcnn'))
                encoder[model_config_dict[task_type].model_type] = model

                # vae_kl_model = vae_model_fn(encoder[model_config_dict[task_type].model_type],
                # 			model_config_dict[task_type],
                # 			num_labels_dict[task_type],
                # 			init_checkpoint_dict[task_type],
                # 			reuse,
                # 			load_pretrained_dict[task_type],
                # 			model_io_config,
                # 			opt_config,
                # 			exclude_scope=exclude_scope_dict[task_type],
                # 			not_storage_params=not_storage_params_dict[task_type],
                # 			target=target_dict[task_type],
                # 			label_lst=None,
                # 			output_type=output_type,
                # 			task_layer_reuse=task_layer_reuse,
                # 			task_type=task_type,
                # 			num_task=num_task,
                # 			task_adversarial=1e-2,
                # 			get_pooled_output='task_output',
                # 			feature_distillation=False,
                # 			embedding_distillation=False,
                # 			pretrained_embed=pretrained_embed,
                # 			**kargs)
                # vae_result_dict = vae_kl_model(features, labels, mode)
                # tvars.extend(vae_result_dict['tvars'])
                # total_loss += vae_result_dict["loss"]
                # for key in vae_result_dict:
                # 	if key in ['perplexity', 'token_acc', 'kl_div']:
                # 		hook_dict[key] = vae_result_dict[key]
            print(encoder, "==encoder==")

            if task_type_dict[task_type] == "cls_task":
                task_model_fn = cls_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)
            elif task_type_dict[task_type] == "embed_task":
                task_model_fn = embed_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    loss='contrastive_loss',
                    apply_head_proj=False,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)
                # cpc_model_fn = embed_cpc_model_fn(encoder[model_config_dict[task_type].model_type],
                # 								model_config_dict[task_type],
                # 								num_labels_dict[task_type],
                # 								init_checkpoint_dict[task_type],
                # 								reuse,
                # 								load_pretrained_dict[task_type],
                # 								model_io_config,
                # 								opt_config,
                # 								exclude_scope=exclude_scope_dict[task_type],
                # 								not_storage_params=not_storage_params_dict[task_type],
                # 								target=target_dict[task_type],
                # 								label_lst=None,
                # 								output_type=output_type,
                # 								task_layer_reuse=task_layer_reuse,
                # 								task_type=task_type,
                # 								num_task=num_task,
                # 								task_adversarial=1e-2,
                # 								get_pooled_output='task_output',
                # 								feature_distillation=False,
                # 								embedding_distillation=False,
                # 								pretrained_embed=pretrained_embed,
                # 								loss='contrastive_loss',
                # 								apply_head_proj=False,
                # 								**kargs)

                # cpc_result_dict = cpc_model_fn(features, labels, mode)
                # result_dict['loss'] += cpc_result_dict['loss']
                # result_dict['tvars'].extend(cpc_result_dict['tvars'])
                # hook_dict["{}_all_neg_loss".format(task_type)] = cpc_result_dict['loss']
                # hook_dict["{}_all_neg_num".format(task_type)] = cpc_result_dict['task_num']

            elif task_type_dict[task_type] == "cpc_task":
                task_model_fn = embed_cpc_v1_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    loss='contrastive_loss',
                    apply_head_proj=False,
                    task_seperate_proj=True,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)

            elif task_type_dict[task_type] == "regression_task":
                task_model_fn = regression_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    get_pooled_output='task_output',
                    feature_distillation=False,
                    embedding_distillation=False,
                    pretrained_embed=pretrained_embed,
                    loss='contrastive_loss',
                    apply_head_proj=False,
                    **kargs)
                result_dict = task_model_fn(features, labels, mode)
                tf.logging.info("****** task: *******",
                                task_type_dict[task_type], task_type)
            else:
                continue
            print("==SUCCEEDED IN LODING==", task_type)

            # result_dict = task_model_fn(features, labels, mode)
            logits_dict[task_type] = result_dict["logits"]
            losses_dict[task_type] = result_dict["loss"]  # task loss
            for key in [
                    "pos_num", "neg_num", "masked_lm_loss", "task_loss", "acc",
                    "task_acc", "masked_lm_acc"
            ]:
                name = "{}_{}".format(task_type, key)
                if name in result_dict:
                    hook_dict[name] = result_dict[name]
            hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
            hook_dict["{}_num".format(task_type)] = result_dict["task_num"]
            print("==loss ratio==", task_type,
                  multi_task_config[task_type].get('loss_ratio', 1.0))
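            # Combine per-task losses as a weighted sum; loss_ratio
            # defaults to 1.0 when absent from multi_task_config.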
            total_loss += result_dict["loss"] * multi_task_config[
                task_type].get('loss_ratio', 1.0)
            hook_dict['embed_loss'] = result_dict["embed_loss"]
            hook_dict['feature_loss'] = result_dict["feature_loss"]
            hook_dict["{}_task_loss".format(
                task_type)] = result_dict["task_loss"]
            if 'positive_label' in result_dict:
                hook_dict["{}_task_positive_label".format(
                    task_type)] = result_dict["positive_label"]
            if mode == tf.estimator.ModeKeys.TRAIN:
                tvars.extend(result_dict["tvars"])
                task_num += result_dict["task_num"]
                task_num_dict[task_type] = result_dict["task_num"]
            elif mode == tf.estimator.ModeKeys.EVAL:
                features[task_type] = result_dict["feature"]

        hook_dict["total_loss"] = total_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn = model_io.ModelIO(model_io_config)

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(list(set(tvars)),
                                     string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)

            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    total_loss, list(set(tvars)), opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver(optimizer_fn.opt)

                if kargs.get("task_index", 1) == 1 and kargs.get(
                        "run_config", None):
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                elif kargs.get("task_index", 1) == 1:
                    training_hooks = []
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

            if output_type == "sess":
                return {
                    "train": {
                        "total_loss": total_loss,
                        "loss": losses_dict,
                        "logits": logits_dict,
                        "train_op": train_op,
                        "task_num_dict": task_num_dict
                    },
                    "hooks": train_hooks
                }
            elif output_type == "estimator":

                hook_dict['learning_rate'] = optimizer_fn.learning_rate
                logging_hook = tf.train.LoggingTensorHook(hook_dict,
                                                          every_n_iter=100)
                training_hooks.append(logging_hook)

                print("==hook_dict==")

                print(hook_dict)

                for key in hook_dict:
                    tf.summary.scalar(key, hook_dict[key])
                    for index, task_type in enumerate(task_type_dict.keys()):
                        tmp = "{}_loss".format(task_type)
                        if tmp == key:
                            tf.summary.scalar(
                                "loss_gap_{}".format(task_type),
                                hook_dict["total_loss"] - hook_dict[key])
                for key in task_num_dict:
                    tf.summary.scalar(key + "_task_num", task_num_dict[key])

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:  # eval runs separately for each task

            def metric_fn(logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "logits": logits_dict,
                        "total_loss": total_loss,
                        "feature": features,
                        "loss": losses_dict
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = {}
                for key in logits_dict:
                    eval_dict = metric_fn(logits_dict[key],
                                          features_task_dict[key]["label_ids"])
                    for sub_key in eval_dict.keys():
                        eval_key = "{}_{}".format(key, sub_key)
                        eval_metric_ops[eval_key] = eval_dict[sub_key]
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=total_loss / task_num,
                    eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
Example #7
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)
        label_ids = features["label_ids"]

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)

        print(kargs.get("temperature", 0.5),
              kargs.get("distillation_ratio", 0.5),
              "==distillation hyparameter==")

        # anneal_fn = anneal_strategy.AnnealStrategy(kargs.get("anneal_config", {}))

        # get teacher logits
        teacher_logit = tf.log(features["label_probs"] + 1e-10) / kargs.get(
            "temperature", 2.0)  # log_softmax logits
        student_logit = tf.nn.log_softmax(
            logits / kargs.get("temperature", 2.0))  # log_softmax logits

        distillation_loss = kd_distance(
            teacher_logit, student_logit,
            kargs.get("distillation_distance", "kd"))
        distillation_loss *= features["distillation_ratio"]
        distillation_loss = tf.reduce_sum(distillation_loss) / (
            1e-10 + tf.reduce_sum(features["distillation_ratio"]))

        label_loss = tf.reduce_sum(
            per_example_loss * features["label_ratio"]) / (
                1e-10 + tf.reduce_sum(features["label_ratio"]))

        print(
            "==distillation loss ratio==",
            kargs.get("distillation_ratio", 0.9) *
            tf.pow(kargs.get("temperature", 2.0), 2))

        # loss = label_loss + kargs.get("distillation_ratio", 0.9)*tf.pow(kargs.get("temperature", 2.0), 2)*distillation_loss
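        # Standard Hinton-style distillation weighting: the soft-target term
        # is scaled by T^2 so its gradient magnitude stays comparable to the
        # hard-label cross-entropy term.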
        loss = (1 -
                kargs.get("distillation_ratio", 0.9)) * label_loss + tf.pow(
                    kargs.get("temperature", 2.0), 2) * kargs.get(
                        "distillation_ratio", 0.9) * distillation_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        params_size = model_io_fn.count_params(model_config.scope)
        print("==total params==", params_size)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "logits": logits,
                            "train_op": train_op,
                            "cross_entropy": label_loss,
                            "kd_loss": distillation_loss,
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num": tf.reduce_sum(features["label_ratio"]),
                            "teacher_logit": teacher_logit,
                            "student_logit": student_logit,
                            "label_ratio": features["label_ratio"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Example #8
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        return_dict = {}

        if kargs.get("noise_true_distribution", True):
            model = model_api(
                model_config,
                features,
                labels,
                mode,
                target,
                reuse=tf.AUTO_REUSE,
                scope=generator_scope_prefix,  # need to add noise scope to lm
                **kargs)

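            # 1.0 for real (non-[PAD]) positions in the shifted targets,
            # 0.0 elsewhere; masks the per-token losses below.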
            sequence_mask = tf.to_float(
                tf.not_equal(features['input_ids'][:, 1:],
                             kargs.get('[PAD]', 0)))

            # batch x seq_length
            seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ids'][:, 1:],
                logits=model.get_sequence_output_logits()[:, :-1])

            if not kargs.get("prob_ln", False):
                tf.logging.info(
                    "****** sum of token log-probs as sentence probability of true data *******"
                )
                logits = tf.reduce_sum(
                    seq_loss * sequence_mask,
                    axis=-1)  #/ (tf.reduce_sum(sequence_mask, axis=-1)+1e-10)
            else:
                tf.logging.info(
                    "****** length-normalized sum of token log-probs as sentence probability of true data *******"
                )
                logits = tf.reduce_sum(seq_loss * sequence_mask, axis=-1) / (
                    tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
            # sparse_softmax_cross_entropy_with_logits returns the negative
            # log-prob (a loss to be minimized); we want the log-prob itself,
            # so negate it here.
            return_dict['true_logits'] = -logits
            return_dict['true_seq_logits'] = model.get_sequence_output_logits()
            tf.logging.info("****** noise distribution for true data *******")

        noise_estimator_type = kargs.get("noise_estimator_type",
                                         "straight_through")
        tf.logging.info("****** noise estimator for nce: %s *******",
                        noise_estimator_type)

        # with tf.variable_scope("noise", reuse=tf.AUTO_REUSE):
        # 	noise_global_step = tf.get_variable(
        # 						"global_step",
        # 						shape=[],
        # 						initializer=tf.constant_initializer(0, dtype=tf.int64),
        # 						trainable=False,
        # 						dtype=tf.int64)
        # return_dict['global_step'] = noise_global_step

        if kargs.get("sample_noise_dist", True):
            tf.logging.info("****** noise distribution for fake data *******")

            temp_adapt = kargs.get("gumbel_adapt", "exp")
            temper = kargs.get("gumbel_inv_temper", 100)
            num_train_steps = kargs.get("num_train_steps", 100000)

            # step = tf.cast(return_dict['global_step'], tf.float32)
            step = tf.cast(tf.train.get_or_create_global_step(), tf.float32)

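            # Anneal the Gumbel temperature over training steps; the schedule
            # shape is selected by temp_adapt inside get_fixed_temperature.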
            temperature = get_fixed_temperature(temper, step, num_train_steps,
                                                temp_adapt)

            sample_type = kargs.get("sample_type", "cache_sample")

            if sample_type == 'none_cache_sample':
                sample_sequence_api = bert_seq_tpu_utils.sample_sequence_without_cache
                if_bp = False
                if_cache_decode = False
                tf.logging.info("****** noise sample without cache *******")
            elif sample_type == 'cache_sample':
                sample_sequence_api = bert_seq_tpu_utils.sample_sequence
                if_bp = True
                if_cache_decode = True
                tf.logging.info("****** noise sample with cache *******")
            else:
                sample_sequence_api = bert_seq_tpu_utils.sample_sequence_without_cache
                if_bp = False
                if_cache_decode = False
                tf.logging.info("****** noise sample without cache *******")
            tf.logging.info("****** max_length: %s *******",
                            str(kargs.get('max_length', 512)))

            if noise_estimator_type in ["straight_through", "soft"]:
                back_prop = True
                tf.logging.info("****** st or soft with bp: %s *******",
                                str(back_prop))
            else:
                back_prop = False
                tf.logging.info("****** hard without bp: %s *******",
                                str(back_prop))

            results = sample_sequence_api(
                model_api,
                model_config,
                tf.estimator.ModeKeys.TRAIN,
                features,
                target="",
                start_token=kargs.get("start_token_id", 101),
                batch_size=None,
                context=features["input_ids"][:, :10],
                temperature=1.0,
                n_samples=kargs.get("n_samples", 1),
                top_k=0,
                end_token=kargs.get("end_token_id", 102),
                greedy_or_sample="greedy",
                gumbel_temp=temperature,
                estimator=noise_estimator_type,
                back_prop=back_prop,
                swap_memory=True,
                seq_type=kargs.get("seq_type", "seq2seq"),
                mask_type=kargs.get("mask_type", "left2right"),
                attention_type=kargs.get('attention_type', 'normal_attention'),
                scope=generator_scope_prefix,  # need to add noise scope to lm,
                max_length=max(int(kargs.get('max_length', 512) / 6), 42),
                if_bp=if_bp,
                if_cache_decode=if_cache_decode)

            if noise_estimator_type in ["straight_through", "soft"]:
                tf.logging.info("****** using apply gumbel samples *******")
                gumbel_probs = results['gumbel_probs']
            else:
                gumbel_probs = tf.cast(results['samples'], tf.int32)
                tf.logging.info(
                    "****** using apply stop gradient samples *******")
            return_dict['gumbel_probs'] = tf.cast(gumbel_probs, tf.float32)
            sample_mask = results['mask_sequence']
            if not kargs.get("prob_ln", False):
                tf.logging.info(
                    "****** sum of token log-probs as sentence probability of sampled noise data *******"
                )
                return_dict['fake_logits'] = tf.reduce_sum(
                    results['logits'] * tf.cast(sample_mask, tf.float32),
                    axis=-1
                )  #/ tf.reduce_sum(1e-10+tf.cast(sample_mask, tf.float32), axis=-1)
            else:
                tf.logging.info(
                    "****** length-normalized sum of token log-probs as sentence probability of sampled noise data *******"
                )
                return_dict['fake_logits'] = tf.reduce_sum(
                    results['logits'] * tf.cast(sample_mask, tf.float32),
                    axis=-1) / tf.reduce_sum(
                        1e-10 + tf.cast(sample_mask, tf.float32), axis=-1)
            return_dict['fake_samples'] = tf.cast(results['samples'], tf.int32)
            return_dict['fake_mask'] = results['mask_sequence']

            print(return_dict['fake_samples'].get_shape(),
                  return_dict['fake_logits'].get_shape(),
                  results['logits'].get_shape(),
                  "====fake samples, logitss, shape===")

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            """
			"generator/cls/predictions"
			"""
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)
        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        return_dict['tvars'] = pretrained_tvars

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        return return_dict
Example #9
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if 'input_mask' not in features:
            input_mask = tf.cast(
                tf.not_equal(features['input_ids_{}'.format(target)],
                             kargs.get('[PAD]', 0)), tf.int32)

            if target:
                features['input_mask_{}'.format(target)] = input_mask
            else:
                features['input_mask'] = input_mask
        if 'segment_ids' not in features:
            segment_ids = tf.zeros_like(input_mask)
            if target:
                features['segment_ids_{}'.format(target)] = segment_ids
            else:
                features['segment_ids'] = segment_ids

        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_mask'] = features['input_mask_{}'.format(target)]
            features['segment_ids'] = features['segment_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]

        input_ori_ids = features.get('input_ori_ids', None)
        if mode == tf.estimator.ModeKeys.TRAIN:
            if input_ori_ids is not None:
                # [output_ids,
                # sampled_binary_mask] = random_input_ids_generation(
                # 							model_config,
                # 							input_ori_ids,
                # 							features['input_mask'],
                # 							**kargs)

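                # Dynamically corrupt inputs with an HMM-driven span mask
                # (mask 0.2, replace 0.1, keep-original 0.1 here).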
                [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                    model_config,
                    features['input_ori_ids'],
                    features['input_mask'], [
                        tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                        for hmm_tran_prob in hmm_tran_prob_list
                    ],
                    mask_probability=0.2,
                    replace_probability=0.1,
                    original_probability=0.1,
                    mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                    **kargs)

                features['input_ids'] = output_ids
                tf.logging.info(
                    "***** Running random sample input generation *****")
            else:
                sampled_binary_mask = None
        else:
            sampled_binary_mask = None

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        #(nsp_loss,
        # nsp_per_example_loss,
        # nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
        #								model.get_pooled_output(),
        #								features['next_sentence_labels'],
        #								reuse=tf.AUTO_REUSE)

        # masked_lm_positions = features["masked_lm_positions"]
        # masked_lm_ids = features["masked_lm_ids"]
        # masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

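        # Same branching as Example 5: a dynamic mask uses the sequence
        # masked-LM head, a static mask uses positional gathering.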
        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            masked_lm_ids = input_ori_ids
        else:

            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]

            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ 0.0 * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        if load_pretrained == "yes":
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=1)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=opt_config.use_tpu)

                #	train_metric_dict = train_metric_fn(
                #			masked_lm_example_loss, masked_lm_log_probs,
                #			masked_lm_ids,
                #			masked_lm_mask,
                #			nsp_per_example_loss,
                #			nsp_log_prob,
                #			features['next_sentence_labels'],
                #			masked_lm_mask=masked_lm_mask
                #		)

                # for key in train_metric_dict:
                # 	tf.summary.scalar(key, train_metric_dict[key])
                # tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    scaffold_fn=scaffold_fn)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss
                }

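            # TPUEstimatorSpec takes eval metrics as a (metric_fn, tensors)
            # pair; the metric_fn is re-run over these tensors on the host CPU.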
            eval_metrics = (metric_fn, [
                masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                masked_lm_mask, nsp_per_example_loss, nsp_log_prob,
                features['next_sentence_labels']
            ])

            estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                mode=mode,
                loss=loss,
                eval_metrics=eval_metrics,
                scaffold_fn=scaffold_fn)

            return estimator_spec
        else:
            raise NotImplementedError()
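A minimal sketch (not part of the original examples) of how a model_fn with the (features, labels, mode, params) signature is typically handed to tf.contrib.tpu.TPUEstimator in TF 1.x; the builder name and its defaults here are illustrative assumptions.

    import tensorflow as tf

    def build_estimator(model_fn, tpu_cluster=None, model_dir="/tmp/model",
                        train_batch_size=32, eval_batch_size=32):
        # TPUEstimator also runs on CPU/GPU when use_tpu=False, so one
        # model_fn can serve both targets.
        run_config = tf.contrib.tpu.RunConfig(
            cluster=tpu_cluster,
            model_dir=model_dir,
            tpu_config=tf.contrib.tpu.TPUConfig(iterations_per_loop=1000))
        return tf.contrib.tpu.TPUEstimator(
            model_fn=model_fn,
            use_tpu=tpu_cluster is not None,
            train_batch_size=train_batch_size,
            eval_batch_size=eval_batch_size,
            config=run_config)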
Exemplo n.º 10
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=model_reuse)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]
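        # Compute the masked-LM loss over the precomputed mask positions,
        # reusing the input embedding table as the output softmax projection.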
        (masked_lm_loss, masked_lm_example_loss,
         masked_lm_log_probs) = pretrain.get_masked_lm_output(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=model_reuse)
        loss = model_config.lm_ratio * masked_lm_loss + model_config.nsp_ratio * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        if mode == tf.estimator.ModeKeys.TRAIN:
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            lm_pretrain_tvars = model_io_fn.get_params(
                "cls", not_storage_params=not_storage_params)

            pretrained_tvars.extend(lm_pretrain_tvars)

            optimizer_fn = optimizer.Optimizer(opt_config)

            if load_pretrained == "yes":
                model_io_fn.load_pretrained(pretrained_tvars,
                                            init_checkpoint,
                                            exclude_scope=exclude_scope)

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):

                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

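                # Wrap the train op with exponential-moving-average (EMA)
                # weight tracking; two_stage=True with first_stage_steps set to
                # the warmup step count presumably defers averaging until after
                # warmup.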
                train_op, hooks = model_io_fn.get_ema_hooks(
                    train_op,
                    tvars,
                    kargs.get('params_moving_average_decay', 0.99),
                    scope,
                    mode,
                    first_stage_steps=opt_config.num_warmup_steps,
                    two_stage=True)

                model_io_fn.set_saver()

                train_metric_dict = train_metric_fn(
                    masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
                    masked_lm_weights, nsp_per_example_loss, nsp_log_prob,
                    features['next_sentence_labels'])

                for key in train_metric_dict:
                    tf.summary.scalar(key, train_metric_dict[key])
                tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=loss,
                                                            train_op=train_op)

                if output_type == "sess":
                    return {
                        "train": {
                            "loss": loss,
                            "nsp_log_pro": nsp_log_prob,
                            "train_op": train_op,
                            "masked_lm_loss": masked_lm_loss,
                            "next_sentence_loss": nsp_loss,
                            "masked_lm_log_pro": masked_lm_log_probs
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:

            def prediction_fn():

                predictions = {
                    "nsp_classes":
                    tf.argmax(input=nsp_log_prob, axis=1),
                    "nsp_probabilities":
                    tf.exp(nsp_log_prob, name="nsp_softmax"),
                    "masked_vocab_classes":
                    tf.argmax(input=masked_lm_log_probs, axis=1),
                    "masked_probabilities":
                    tf.exp(masked_lm_log_probs, name='masked_softmax')
                }
                return predictions

            predictions = prediction_fn()

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs={
                    "output": tf.estimator.export.PredictOutput(predictions)
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(masked_lm_example_loss, masked_lm_log_probs,
                          masked_lm_ids, masked_lm_weights,
                          next_sentence_example_loss, next_sentence_log_probs,
                          next_sentence_labels):
                """Computes the loss and accuracy of the model."""
                masked_lm_log_probs = tf.reshape(
                    masked_lm_log_probs, [-1, masked_lm_log_probs.shape[-1]])
                masked_lm_predictions = tf.argmax(masked_lm_log_probs,
                                                  axis=-1,
                                                  output_type=tf.int32)
                masked_lm_example_loss = tf.reshape(masked_lm_example_loss,
                                                    [-1])
                masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
                masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
                masked_lm_accuracy = tf.metrics.accuracy(
                    labels=masked_lm_ids,
                    predictions=masked_lm_predictions,
                    weights=masked_lm_weights)
                masked_lm_mean_loss = tf.metrics.mean(
                    values=masked_lm_example_loss, weights=masked_lm_weights)

                next_sentence_log_probs = tf.reshape(
                    next_sentence_log_probs,
                    [-1, next_sentence_log_probs.shape[-1]])
                next_sentence_predictions = tf.argmax(next_sentence_log_probs,
                                                      axis=-1,
                                                      output_type=tf.int32)
                next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
                next_sentence_accuracy = tf.metrics.accuracy(
                    labels=next_sentence_labels,
                    predictions=next_sentence_predictions)
                next_sentence_mean_loss = tf.metrics.mean(
                    values=next_sentence_example_loss)

                return {
                    "masked_lm_accuracy": masked_lm_accuracy,
                    "masked_lm_loss": masked_lm_mean_loss,
                    "next_sentence_accuracy": next_sentence_accuracy,
                    "next_sentence_loss": next_sentence_mean_loss
                }

            if output_type == "sess":
                return {
                    "eval": {
                        "nsp_log_prob": nsp_log_prob,
                        "masked_lm_log_prob": masked_lm_log_probs,
                        "nsp_loss": nsp_loss,
                        "masked_lm_loss": masked_lm_loss,
                        "feature": model.get_pooled_output()
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = metric_fn(masked_lm_example_loss,
                                            masked_lm_log_probs, masked_lm_ids,
                                            masked_lm_weights,
                                            nsp_per_example_loss, nsp_log_prob,
                                            features['next_sentence_labels'])
                _, hooks = model_io_fn.get_ema_hooks(
                    None, None, kargs.get('params_moving_average_decay', 0.99),
                    scope, mode)

                eval_hooks = [hooks] if hooks else []

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops,
                    evaluation_hooks=eval_hooks)
                return estimator_spec
        else:
            raise NotImplementedError()
Exemplo n.º 11
0
	def model_fn(features, labels, mode, params):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		(nsp_loss, 
		 nsp_per_example_loss, 
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		with tf.variable_scope('cls/seq_predictions', reuse=tf.AUTO_REUSE):
			(loss, 
			logits, 
			per_example_loss) = classifier(model_config, 
									model.get_sequence_output(),
									features['input_ori_ids'],
									features['input_ids'],
									features['input_mask'],
									2,
									dropout_prob,
									ori_sampled_ids=features.get('ori_sampled_ids', None),
									use_tpu=kargs.get('use_tpu', True))
	
		tf.add_to_collection("discriminator_loss", loss)
		loss += 0.0 * nsp_loss
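		# The zero-weighted NSP term keeps the NSP head wired into the training
		# graph without contributing gradients to the discriminator objective.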

		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		lm_seq_prediction_tvars = model_io_fn.get_params("cls/seq_predictions", 
									not_storage_params=not_storage_params)
		lm_pretrain_tvars = model_io_fn.get_params("cls/seq_relationship", 
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_seq_prediction_tvars)
		pretrained_tvars.extend(lm_pretrain_tvars)
		tvars = pretrained_tvars

		print('==discriminator parameters==', tvars)

		if load_pretrained == "yes":
			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			scaffold_fn = model_io_fn.load_pretrained(tvars, 
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=use_tpu,
											restore_var_name=model_config.get('restore_var_name', []))
		else:
			scaffold_fn = None
		return_dict = {
					"loss":loss, 
					"logits":logits,
					"tvars":tvars,
					"model":model,
					"per_example_loss":per_example_loss
				}
		return return_dict
Exemplo n.º 12
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)
        model = model_api(model_config,
                          features,
                          labels,
                          tf.estimator.ModeKeys.PREDICT,
                          target,
                          reuse=model_reuse,
                          cnn_type=model_config.get('cnn_type', 'bi_dgcnn'),
                          **kargs)

        dropout_prob = 0.0
        is_training = False

        with tf.variable_scope(model_config.scope + "/feature_output",
                               reuse=tf.AUTO_REUSE):
            hidden_size = bert_utils.get_shape_list(model.get_pooled_output(),
                                                    expected_rank=2)[-1]
            sentence_pres = model.get_pooled_output()

            sentence_pres = tf.layers.dense(
                sentence_pres,
                128,
                use_bias=True,
                activation=tf.tanh,
                kernel_initializer=tf.truncated_normal_initializer(
                    stddev=0.01))

            # sentence_pres = tf.layers.dense(
            # 				model.get_pooled_output(),
            # 				hidden_size,
            # 				use_bias=None,
            # 				activation=tf.nn.relu,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

            # sentence_pres = tf.layers.dense(
            # 				sentence_pres,
            # 				hidden_size,
            # 				use_bias=None,
            # 				activation=None,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))

            # hidden_size = bert_utils.get_shape_list(model.get_pooled_output(), expected_rank=2)[-1]
            # sentence_pres = tf.layers.dense(
            # 			model.get_pooled_output(),
            # 			hidden_size,
            # 			use_bias=True,
            # 			activation=tf.tanh,
            # 			kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
            # feature_output_a = tf.layers.dense(
            # 				model.get_pooled_output(),
            # 				hidden_size,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01))
            # feature_output_a = tf.nn.dropout(feature_output_a, keep_prob=1 - dropout_prob)
            # feature_output_a += model.get_pooled_output()
            # sentence_pres = tf.layers.dense(
            # 				feature_output_a,
            # 				hidden_size,
            # 				kernel_initializer=tf.truncated_normal_initializer(stddev=0.01),
            # 				activation=tf.tanh)

        if kargs.get('apply_head_proj', False):
            with tf.variable_scope(model_config.scope + "/head_proj",
                                   reuse=tf.AUTO_REUSE):
                sentence_pres = simclr_utils.projection_head(
                    sentence_pres,
                    is_training,
                    head_proj_dim=128,
                    num_nlh_layers=1,
                    head_proj_mode='nonlinear',
                    name='head_contrastive')

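        # L2-normalize the sentence vectors (with a tiny epsilon guard against
        # an all-zero vector) so downstream dot products act as cosine
        # similarity.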
        l2_sentence_pres = tf.nn.l2_normalize(sentence_pres + 1e-20, axis=-1)

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)

        try:
            params_size = model_io_fn.count_params(model_config.scope)
            print("==total params==", params_size)
        except:
            print("==not count params==")
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        estimator_spec = tf.estimator.EstimatorSpec(
            mode=tf.estimator.ModeKeys.PREDICT,
            predictions={
                'sentence_pres': l2_sentence_pres,
                # "before_l2":sentence_pres
            },
            export_outputs={
                "output":
                tf.estimator.export.PredictOutput({
                    'sentence_pres':
                    l2_sentence_pres,
                    # "before_l2":sentence_pres
                })
            })
        return estimator_spec
Exemplo n.º 13
    def model_fn(features, labels, mode):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=model_reuse)

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            (loss, per_example_loss,
             logits) = classifier.classifier(model_config,
                                             model.get_pooled_output(),
                                             num_labels, label_ids,
                                             dropout_prob)
            label_loss = tf.reduce_sum(
                per_example_loss * features["label_ratio"]) / (
                    1e-10 + tf.reduce_sum(features["label_ratio"]))

        if mode == tf.estimator.ModeKeys.TRAIN:

            distillation_api = distill.KnowledgeDistillation(
                kargs.get(
                    "disitllation_config",
                    Bunch({
                        "logits_ratio_decay": "constant",
                        "logits_ratio": 0.5,
                        "logits_decay_rate": 0.999,
                        "distillation": ['relation_kd'],
                        "feature_ratio": 0.5,
                        "feature_ratio_decay": "constant",
                        "feature_decay_rate": 0.999,
                        "kd_type": "kd",
                        "scope": scope
                    })))
            # get teacher logits
            teacher_logit = tf.log(features["label_probs"] +
                                   1e-10) / kargs.get(
                                       "temperature",
                                       2.0)  # log_softmax logits
            student_logit = tf.nn.log_softmax(
                logits / kargs.get("temperature", 2.0))  # log_softmax logits
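            # Both distributions are softened by the same temperature so the
            # distillation term compares teacher and student on equal footing.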

            distillation_features = {
                "student_logits_tensor": student_logit,
                "teacher_logits_tensor": teacher_logit,
                "student_feature_tensor": model.get_pooled_output(),
                "teacher_feature_tensor": features["distillation_feature"],
                "student_label": tf.ones_like(label_ids, dtype=tf.int32),
                "teacher_label": tf.zeros_like(label_ids, dtype=tf.int32),
                "logits_ratio": kargs.get("logits_ratio", 0.5),
                "feature_ratio": kargs.get("logits_ratio", 0.5),
                "distillation_ratio": features["distillation_ratio"],
                "src_f_logit": logits,
                "tgt_f_logit": logits,
                "src_tensor": model.get_pooled_output(),
                "tgt_tensor": features["distillation_feature"]
            }

            distillation_loss = distillation_api.distillation(
                distillation_features,
                2,
                dropout_prob,
                model_reuse,
                opt_config.num_train_steps,
                feature_ratio=10,
                logits_ratio_decay="constant",
                feature_ratio_decay="constant",
                feature_decay_rate=0.999,
                logits_decay_rate=0.999,
                logits_ratio=0.5,
                scope=scope + "/adv_classifier",
                num_classes=num_labels,
                gamma=kargs.get("gamma", 4))

            loss = label_loss + distillation_loss["distillation_loss"]

        model_io_fn = model_io.ModelIO(model_io_config)

        tvars = model_io_fn.get_params(model_config.scope,
                                       not_storage_params=not_storage_params)
        print(tvars)
        if load_pretrained == "yes":
            model_io_fn.load_pretrained(tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        if mode == tf.estimator.ModeKeys.TRAIN:

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver()

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    training_hooks = []
                elif kargs.get("task_index", 1) == 0:
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    loss=loss,
                    train_op=train_op,
                    training_hooks=training_hooks)
                if output_type == "sess":

                    return {
                        "train": {
                            "loss":
                            loss,
                            "logits":
                            logits,
                            "train_op":
                            train_op,
                            "cross_entropy":
                            label_loss,
                            "distillation_loss":
                            distillation_loss["distillation_loss"],
                            "kd_num":
                            tf.reduce_sum(features["distillation_ratio"]),
                            "ce_num":
                            tf.reduce_sum(features["label_ratio"]),
                            "label_ratio":
                            features["label_ratio"],
                            "distilaltion_logits_loss":
                            distillation_loss["distillation_logits_loss"],
                            "distilaltion_feature_loss":
                            distillation_loss["distillation_feature_loss"],
                            "rkd_loss":
                            distillation_loss["rkd_loss"]
                        },
                        "hooks": training_hooks
                    }
                elif output_type == "estimator":
                    return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            max_prob = tf.reduce_max(prob, axis=-1)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                predictions={
                    'pred_label': pred_label,
                    "max_prob": max_prob
                },
                export_outputs={
                    "output":
                    tf.estimator.export.PredictOutput({
                        'pred_label': pred_label,
                        "max_prob": max_prob
                    })
                })
            return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            def metric_fn(per_example_loss, logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_mean_loss = tf.metrics.mean(values=per_example_loss)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            eval_metric_ops = metric_fn(per_example_loss, logits, label_ids)

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode, loss=loss, eval_metric_ops=eval_metric_ops)

            if output_type == "sess":
                return {
                    "eval": {
                        "per_example_loss": per_example_loss,
                        "logits": logits,
                        "loss": tf.reduce_mean(per_example_loss)
                    }
                }
            elif output_type == "estimator":
                return estimator_spec
        else:
            raise NotImplementedError()
Exemplo n.º 14
	def model_fn(features, labels, mode, params):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope
		
		(nsp_loss, 
		 nsp_per_example_loss, 
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE,
										scope=generator_scope_prefix)

		if model_config.model_type == 'bert':
			masked_lm_fn = pretrain.get_masked_lm_output
			seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
			print("==apply bert masked lm==")
		elif model_config.model_type == 'albert':
			masked_lm_fn = pretrain_albert.get_masked_lm_output
			seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
			print("==apply albert masked lm==")
		else:
			masked_lm_fn = pretrain.get_masked_lm_output
			seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
			print("==apply bert masked lm==")

		(_,
			_, 
			masked_lm_log_probs,
			_) = seq_masked_lm_fn(model_config, 
										model.get_sequence_output(), 
										model.get_embedding_table(),
										features['input_mask'], 
										features['input_ori_ids'], 
										features['input_ids'],
										features['input_mask'],
										reuse=tf.AUTO_REUSE,
										embedding_projection=model.get_embedding_projection_table(),
										scope=generator_scope_prefix)

		print(model_config.lm_ratio, '==mlm lm_ratio==')
		# loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		if generator_scope_prefix:
			"""
			"generator/cls/predictions"
			"""
			lm_pretrain_tvars = model_io_fn.get_params(generator_scope_prefix+"/cls/predictions", 
										not_storage_params=not_storage_params)

			nsp_pretrain_vars = model_io_fn.get_params(generator_scope_prefix+"/cls/seq_relationship",
										not_storage_params=not_storage_params)
		else:
			lm_pretrain_tvars = model_io_fn.get_params("cls/predictions", 
										not_storage_params=not_storage_params)

			nsp_pretrain_vars = model_io_fn.get_params("cls/seq_relationship",
										not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_pretrain_tvars)
		pretrained_tvars.extend(nsp_pretrain_vars)
		tvars = pretrained_tvars

		print('==generator parameters==', tvars)

		if load_pretrained == "yes":
			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			scaffold_fn = model_io_fn.load_pretrained(tvars, 
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=use_tpu)
		else:
			scaffold_fn = None

		if mode == tf.estimator.ModeKeys.PREDICT:
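			# Zero out padded positions so the exported per-token probabilities
			# cover real tokens only.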
			mask = tf.expand_dims(tf.cast(features['input_mask'], tf.float32), axis=-1)
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												"probs":mask*tf.exp(tf.nn.log_softmax(masked_lm_log_probs))
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														"probs":mask*tf.exp(tf.nn.log_softmax(masked_lm_log_probs))
													}
												)
									}
						)
			return estimator_spec
Exemplo n.º 15
    def model_fn(self, features, labels, mode, model_reuse):
        model_api = model_zoo(self.model_config)

        model_lst = []

        assert len(self.target.split(",")) == 2
        target_name_lst = self.target.split(",")
        print(target_name_lst)
        for index, name in enumerate(target_name_lst):
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                model_api(self.model_config,
                          features,
                          labels,
                          mode,
                          name,
                          reuse=reuse))

        label_ids = features["label_ids"]

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = self.model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if self.model_io_config.fix_lm == True:
            scope = self.model_config.scope + "_finetuning"
        else:
            scope = self.model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            seq_output_lst = [model.get_pooled_output() for model in model_lst]
            if self.model_config.get("classifier",
                                     "order_classifier") == "order_classifier":
                [loss, per_example_loss,
                 logits] = classifier.order_classifier(self.model_config,
                                                       seq_output_lst,
                                                       self.num_labels,
                                                       label_ids,
                                                       dropout_prob,
                                                       ratio_weight=None)
            elif self.model_config.get(
                    "classifier",
                    "order_classifier") == "siamese_interaction_classifier":
                [loss, per_example_loss,
                 logits] = classifier.siamese_classifier(self.model_config,
                                                         seq_output_lst,
                                                         self.num_labels,
                                                         label_ids,
                                                         dropout_prob,
                                                         ratio_weight=None)

        params_size = self.model_io_fn.count_params(self.model_config.scope)
        print("==total params==", params_size)

        self.tvars = self.model_io_fn.get_params(
            self.model_config.scope,
            not_storage_params=self.not_storage_params)
        print(self.tvars)
        if self.load_pretrained == "yes":
            self.model_io_fn.load_pretrained(self.tvars,
                                             self.init_checkpoint,
                                             exclude_scope=self.exclude_scope)
        self.loss = loss
        self.per_example_loss = per_example_loss
        self.logits = logits

        return self.loss
Exemplo n.º 16
	def model_fn(features, labels, mode, params):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE,
							**kargs)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		(_, 
		 _, 
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		with tf.variable_scope('cls/seq_predictions', reuse=tf.AUTO_REUSE):
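			# Token-level discriminator head (ELECTRA-style replaced-token
			# detection): 2-way classification of each position as original
			# vs. replaced, comparing input_ori_ids against the corrupted
			# input_ids.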
			(_, 
			logits, 
			_) = classifier(model_config, 
									model.get_sequence_output(),
									features['input_ori_ids'],
									features['input_ids'],
									features['input_mask'],
									2,
									dropout_prob)
									# ,
									# loss='focal_loss')

		# loss += 0.0 * nsp_loss

		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		lm_seq_prediction_tvars = model_io_fn.get_params("cls/seq_predictions", 
									not_storage_params=not_storage_params)
		lm_pretrain_tvars = model_io_fn.get_params("cls/seq_relationship", 
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_seq_prediction_tvars)
		pretrained_tvars.extend(lm_pretrain_tvars)
		tvars = pretrained_tvars

		print('==discriminator parameters==', tvars)

		if load_pretrained == "yes":
			use_tpu = 1 if kargs.get('use_tpu', False) else 0
			scaffold_fn = model_io_fn.load_pretrained(tvars, 
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=use_tpu)
		else:
			scaffold_fn = None
		
		if mode == tf.estimator.ModeKeys.PREDICT:
			mask = tf.cast(tf.expand_dims(features['input_mask'], axis=-1), tf.float32)
			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												"probs":tf.nn.softmax(logits)*mask
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														"probs":tf.nn.softmax(logits)*mask
													}
												)
									}
						)
			return estimator_spec
Exemplo n.º 17
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope='generator')

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            print("==apply bert masked lm==")

        (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
         masked_lm_mask) = masked_lm_fn(
             model_config,
             model.get_sequence_output(),
             model.get_embedding_table(),
             masked_lm_positions,
             masked_lm_ids,
             masked_lm_weights,
             reuse=tf.AUTO_REUSE,
             embedding_projection=model.get_embedding_projection_table(),
             scope='generator')
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ model_config.nsp_ratio * nsp_loss

        sampled_ids = token_generator(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_ids'],
            features['input_ori_ids'],
            features['input_mask'],
            embedding_projection=model.get_embedding_projection_table(),
            scope='generator',
            mask_method='only_mask')

        if model_config.get('gen_sample', 1) == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:
            input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)
            # batch x seq_length x 1
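            # einsum with a ones vector tiles the ids across the sample axis:
            # (batch, seq, 1) x (1, gen_sample) -> (batch, seq, gen_sample).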
            input_ids = tf.einsum(
                'abc,cd->abd', input_ids,
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_ids = tf.cast(input_ids, tf.int32)

            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            gen_sample = input_shape_list[2]

            sampled_ids = tf.reshape(sampled_ids,
                                     [batch_size * gen_sample, seq_length])
            input_ids = tf.reshape(input_ids,
                                   [batch_size * gen_sample, seq_length])

            input_mask = tf.expand_dims(features['input_mask'], axis=-1)
            input_mask = tf.einsum(
                'abc,cd->abd', input_mask,
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_mask = tf.cast(input_mask, tf.int32)

            segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
            segment_ids = tf.einsum(
                'abc,cd->abd', segment_ids,
                tf.ones((1, model_config.get('gen_sample', 1))))
            segment_ids = tf.cast(segment_ids, tf.int32)

            segment_ids = tf.reshape(segment_ids,
                                     [batch_size * gen_sample, seq_length])
            input_mask = tf.reshape(input_mask,
                                    [batch_size * gen_sample, seq_length])

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "generator/cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope="generator",
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_positions": masked_lm_positions,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_weights,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss
        }
        return return_dict
Exemplo n.º 18
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]
        sequence_mask = tf.cast(
            tf.not_equal(features['input_ori_ids'], kargs.get('[PAD]', 0)),
            tf.int32)
        features['input_mask'] = sequence_mask

        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        if 'input_ori_ids' in features:
            seq_features['input_ids'] = features["input_ori_ids"]
        else:
            features['input_ori_ids'] = seq_features['input_ids']

        not_equal = tf.cast(
            tf.not_equal(features["input_ori_ids"],
                         tf.zeros_like(features["input_ori_ids"])), tf.int32)
        not_equal = tf.reduce_sum(not_equal, axis=-1)
        loss_mask = tf.cast(tf.not_equal(not_equal, tf.zeros_like(not_equal)),
                            tf.float32)

        if not kargs.get('use_tpu', False):
            tf.summary.scalar('loss_mask', tf.reduce_sum(loss_mask))

        causal_flag = model_config.get('is_casual', True)
        tf.logging.info("***** is causal flag: %s *****", str(causal_flag))

        if not causal_flag:
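            # Bidirectional (masked-LM) branch: corrupt the inputs with an
            # HMM-driven sampler; the mask/replace/original probabilities
            # control how each selected token is treated.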
            [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                model_config,
                features['input_ori_ids'],
                features['input_mask'], [
                    tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                    for hmm_tran_prob in hmm_tran_prob_list
                ],
                mask_probability=0.02,
                replace_probability=0.01,
                original_probability=0.01,
                mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                **kargs)
            tf.logging.info("***** apply random sampling *****")
            seq_features['input_ids'] = output_ids

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          "",
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        # if mode == tf.estimator.ModeKeys.TRAIN:
        if kargs.get('mask_type', 'left2right') == 'left2right':
            tf.logging.info("***** using left2right mask and loss *****")
            sequence_mask = tf.to_float(
                tf.not_equal(features['input_ori_ids'][:, 1:],
                             kargs.get('[PAD]', 0)))
        elif kargs.get('mask_type', 'left2right') == 'seq2seq':
            tf.logging.info("***** using seq2seq mask and loss *****")
            sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
            if not kargs.get('use_tpu', False):
                tf.summary.scalar("loss mask", tf.reduce_mean(sequence_mask))

        # batch x seq_length
        if causal_flag:
            print(model.get_sequence_output_logits().get_shape(),
                  "===logits shape===")
            seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                labels=features['input_ori_ids'][:, 1:],
                logits=model.get_sequence_output_logits()[:, :-1])

            per_example_loss = tf.reduce_sum(
                seq_loss * sequence_mask,
                axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
            loss = tf.reduce_mean(per_example_loss)

            if model_config.get("cnn_type",
                                "dgcnn") in ['bi_dgcnn', 'bi_light_dgcnn']:
                seq_backward_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ori_ids'][:, :-1],
                    logits=model.get_sequence_backward_output_logits()[:, 1:])

                per_backward_example_loss = tf.reduce_sum(
                    seq_backward_loss * sequence_mask,
                    axis=-1) / (tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
                backward_loss = tf.reduce_mean(per_backward_example_loss)
                loss += backward_loss
                tf.logging.info("***** using backward loss *****")
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = pretrain.seq_mask_masked_lm_output(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 seq_features['input_mask'],
                 seq_features['input_ori_ids'],
                 seq_features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            loss = masked_lm_loss
            tf.logging.info("***** using masked lm loss *****")
        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using TPU with TPU-compatible optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using GPU with GPU-compatible optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                # train_metric_dict = train_metric(features['input_ori_ids'],
                # 								model.get_sequence_output_logits(),
                # 								seq_features,
                # 								**kargs)

                # if not kargs.get('use_tpu', False):
                # 	for key in train_metric_dict:
                # 		tf.summary.scalar(key, train_metric_dict[key])
                # 	tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)
                # 	tf.logging.info("***** logging metric *****")
                # 	tf.summary.scalar("causal_attenion_mask_length", tf.reduce_sum(sequence_mask))
                # tf.summary.scalar("bi_attenion_mask_length", tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits(),
                                           seq_features, **kargs)
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits(), seq_features,
                kargs.get('mask_type', 'left2right')
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec

        elif mode == tf.estimator.ModeKeys.PREDICT:
            if kargs.get('predict_type',
                         'sample_sequence') == 'sample_sequence':
                results = bert_seq_sample_utils.sample_sequence(
                    model_api,
                    model_config,
                    mode,
                    features,
                    target="",
                    start_token=kargs.get("start_token_id", 101),
                    batch_size=None,
                    context=features.get("context", None),
                    temperature=kargs.get("sample_temp", 1.0),
                    n_samples=kargs.get("n_samples", 1),
                    top_k=0,
                    end_token=kargs.get("end_token_id", 102),
                    greedy_or_sample="greedy",
                    gumbel_temp=0.01,
                    estimator="stop_gradient",
                    back_prop=True,
                    swap_memory=True,
                    seq_type=kargs.get("seq_type", "seq2seq"),
                    mask_type=kargs.get("mask_type", "seq2seq"),
                    attention_type=kargs.get('attention_type',
                                             'normal_attention'))
                # stop_gradient output:
                # samples, mask_sequence, presents, logits, final

                sampled_token = results['samples']
                sampled_token_logits = results['logits']
                mask_sequence = results['mask_sequence']

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': sampled_token,
                        "logits": sampled_token_logits,
                        "mask_sequence": mask_sequence
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            sampled_token,
                            "logits":
                            sampled_token_logits,
                            "mask_sequence":
                            mask_sequence
                        })
                    })

                return estimator_spec

            elif kargs.get('predict_type',
                           'sample_sequence') == 'infer_inputs':

                sequence_mask = tf.to_float(
                    tf.not_equal(features['input_ids'][:, 1:],
                                 kargs.get('[PAD]', 0)))

                if kargs.get('mask_type', 'left2right') == 'left2right':
                    tf.logging.info(
                        "***** using left2right mask and loss *****")
                    sequence_mask = tf.to_float(
                        tf.not_equal(features['input_ori_ids'][:, 1:],
                                     kargs.get('[PAD]', 0)))
                elif kargs.get('mask_type', 'left2right') == 'seq2seq':
                    tf.logging.info("***** using seq2seq mask and loss *****")
                    sequence_mask = tf.to_float(features['segment_ids'][:, 1:])
                    if not kargs.get('use_tpu', False):
                        tf.summary.scalar("loss mask",
                                          tf.reduce_mean(sequence_mask))

                output_logits = model.get_sequence_output_logits()[:, :-1]
                # output_logits = tf.nn.log_softmax(output_logits, axis=-1)

                output_id_logits = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=features['input_ids'][:, 1:], logits=output_logits)

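                # Per-example perplexity: exp of the mean token-level negative
                # log-likelihood over non-padded positions.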
                per_example_perplexity = tf.reduce_sum(output_id_logits *
                                                       sequence_mask,
                                                       axis=-1)  # batch
                per_example_perplexity /= tf.reduce_sum(sequence_mask,
                                                        axis=-1)  # batch

                perplexity = tf.exp(per_example_perplexity)

                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions={
                        'token': features['input_ids'][:, 1:],
                        "logits": output_id_logits,
                        'perplexity': perplexity,
                        # "all_logits":output_logits
                    },
                    export_outputs={
                        "output":
                        tf.estimator.export.PredictOutput({
                            'token':
                            features['input_ids'][:, 1:],
                            "logits":
                            output_id_logits,
                            'perplexity':
                            perplexity,
                            # "all_logits":output_logits
                        })
                    })

                return estimator_spec
        else:
            raise NotImplementedError()
Example no. 19
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if kargs.get('random_generator', '1') == '1':
            if mode in [
                    tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
            ]:
                input_ori_ids = features['input_ori_ids']

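                # Randomly corrupt the original ids (BERT-style dynamic masking);
                # sampled_binary_mask marks the positions selected for prediction.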
                [output_ids, sampled_binary_mask
                 ] = random_input_ids_generation(model_config,
                                                 features['input_ori_ids'],
                                                 features['input_mask'])
                features['input_ids'] = tf.identity(output_ids)
                tf.logging.info("****** do random generator *******")
            else:
                sampled_binary_mask = None
                output_ids = tf.identity(features['input_ids'])
        else:
            sampled_binary_mask = None
            output_ids = tf.identity(features['input_ids'])

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope=generator_scope_prefix)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
            masked_lm_ids = features['input_ori_ids']
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

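        # Sample replacement tokens at the masked positions only
        # ('only_mask'); these generator samples feed the discriminator.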
        sampled_ids = token_generator_igr(
            model_config,
            model.get_sequence_output(),
            model.get_embedding_table(),
            features['input_ids'],
            features['input_ori_ids'],
            features['input_mask'],
            embedding_projection=model.get_embedding_projection_table(),
            scope=generator_scope_prefix,
            mask_method='only_mask',
            **kargs)

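        # NOTE: only the single-sample path is handled here; for
        # gen_sample > 1, input_ids below would be undefined.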
        if model_config.get('gen_sample', 1) == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            """
			"generator/cls/predictions"
			"""
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/seq_relationship",
                not_storage_params=not_storage_params)
        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                "cls/seq_relationship", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
        else:
            scaffold_fn = None

        # tf.add_to_collection("generator_loss", masked_lm_loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels'],
            "output_ids": output_ids
        }
        return return_dict
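Note: the model_fn above returns a plain dict rather than an EstimatorSpec,
so a wrapper is expected to route it into a discriminator. A minimal sketch
of that wiring, assuming a discriminator_fn with the same return-dict
convention (generator_fn, discriminator_fn, joint_model_fn and
create_train_op are illustrative names, not this repo's API):

    import tensorflow as tf

    def joint_model_fn(features, labels, mode, params):
        gen = generator_fn(features, labels, mode, params)  # dict above
        disc_features = dict(features)
        # the discriminator reads the generator's sampled tokens
        disc_features['input_ids'] = gen['sampled_ids']
        disc = discriminator_fn(disc_features, labels, mode, params)
        # ELECTRA-style joint objective: generator MLM loss plus a
        # weighted replaced-token-detection loss (50.0 is a common choice)
        total_loss = gen['loss'] + 50.0 * disc['loss']
        return tf.estimator.EstimatorSpec(
            mode=mode, loss=total_loss,
            train_op=create_train_op(total_loss, gen['tvars'] + disc['tvars']))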
Example no. 20
	def model_fn(features, labels, mode):

		model_api = model_zoo(model_config)

		model_lst = []

		assert len(target.split(",")) == 2
		target_name_lst = target.split(",")
		print(target_name_lst)
		for index, name in enumerate(target_name_lst):
			if index > 0:
				reuse = True
			else:
				reuse = model_reuse
			model_lst.append(model_api(model_config, features, labels,
							mode, name, reuse=reuse))

		label_ids = features["label_ids"]

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope

		with tf.variable_scope(scope, reuse=model_reuse):
			seq_output_lst = [model.get_pooled_output() for model in model_lst]
			if model_config.get("classifier", "order_classifier") == "order_classifier":
				[loss, 
					per_example_loss, 
					logits] = classifier.order_classifier(
								model_config, seq_output_lst, 
								num_labels, label_ids,
								dropout_prob, ratio_weight=None)
			elif model_config.get("classifier", "order_classifier") == "siamese_interaction_classifier":
				[loss, 
					per_example_loss, 
					logits] = classifier.siamese_classifier(
								model_config, seq_output_lst, 
								num_labels, label_ids,
								dropout_prob, ratio_weight=None)

		model_io_fn = model_io.ModelIO(model_io_config)

		params_size = model_io_fn.count_params(model_config.scope)
		print("==total params==", params_size)

		tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)
		print(tvars)
		if load_pretrained == "yes":
			model_io_fn.load_pretrained(tvars, 
										init_checkpoint,
										exclude_scope=exclude_scope)

		if mode == tf.estimator.ModeKeys.TRAIN:

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(tvars, string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			with tf.control_dependencies(update_ops):

				train_op = optimizer_fn.get_train_op(loss, tvars, 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver()

				if kargs.get("task_index", 1) == 0 and kargs.get("run_config", None):
					training_hooks = []
				elif kargs.get("task_index", 1) == 0:
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss, train_op=train_op,
								training_hooks=training_hooks)
				if output_type == "sess":
					return {
						"train":{
										"loss":loss, 
										"logits":logits,
										"train_op":train_op
									},
						"hooks":training_hooks
					}
				elif output_type == "estimator":
					return estimator_spec

		elif mode == tf.estimator.ModeKeys.PREDICT:
			print(logits.get_shape(), "===logits shape===")
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			# probability of the predicted class, used in the outputs below
			max_prob = tf.reduce_max(tf.nn.softmax(logits, axis=-1), axis=-1)

			estimator_spec = tf.estimator.EstimatorSpec(
									mode=mode,
									predictions={
												'pred_label':pred_label,
												"max_prob":max_prob
									},
									export_outputs={
										"output":tf.estimator.export.PredictOutput(
													{
														'pred_label':pred_label,
														"max_prob":max_prob
													}
												)
									}
						)
			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:
			def metric_fn(per_example_loss,
						logits, 
						label_ids):
				"""Computes the loss and accuracy of the model."""
				sentence_log_probs = tf.reshape(
					logits, [-1, logits.shape[-1]])
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_labels = tf.reshape(label_ids, [-1])
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_mean_loss = tf.metrics.mean(
					values=per_example_loss)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			eval_metric_ops = metric_fn( 
							per_example_loss,
							logits, 
							label_ids)
			
			estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=loss,
								eval_metric_ops=eval_metric_ops)

			if output_type == "sess":
				return {
					"eval":{
							"per_example_loss":per_example_loss,
							"logits":logits,
							"loss":tf.reduce_mean(per_example_loss),
							"feature":(seq_output_lst[0]+seq_output_lst[1])/2
						}
				}
			elif output_type == "estimator":
				return estimator_spec
		else:
			raise NotImplementedError()
Example no. 21
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        if kargs.get('random_generator', '1') == '1':
            if mode in [
                    tf.estimator.ModeKeys.TRAIN, tf.estimator.ModeKeys.EVAL
            ]:
                input_ori_ids = features['input_ori_ids']

                # [output_ids,
                # sampled_binary_mask] = random_input_ids_generation(model_config,
                # 							features['input_ori_ids'],
                # 							features['input_mask'],
                # 							mask_probability=0.2,
                # 							replace_probability=0.1,
                # 							original_probability=0.1,
                # 							**kargs)

                [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                    model_config,
                    features['input_ori_ids'],
                    features['input_mask'], [
                        tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                        for hmm_tran_prob in hmm_tran_prob_list
                    ],
                    mask_probability=0.2,
                    replace_probability=0.0,
                    original_probability=0.0,
                    mask_prior=tf.constant(mask_prior, tf.float32),
                    **kargs)

                features['input_ids'] = output_ids
                tf.logging.info("****** do random generator *******")
            else:
                sampled_binary_mask = None
        else:
            sampled_binary_mask = None

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        (nsp_loss, nsp_per_example_loss,
         nsp_log_prob) = pretrain.get_next_sentence_output(
             model_config,
             model.get_pooled_output(),
             features['next_sentence_labels'],
             reuse=tf.AUTO_REUSE,
             scope=generator_scope_prefix)

        masked_lm_positions = features["masked_lm_positions"]
        masked_lm_ids = features["masked_lm_ids"]
        masked_lm_weights = features["masked_lm_weights"]

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
            masked_lm_ids = features['input_ori_ids']
        else:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table(),
                 scope=generator_scope_prefix)
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss + 0.0 * nsp_loss

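        # Optionally corrupt the inputs a second time with an independent
        # mask, so the tokens handed to token_generator come from a freshly
        # corrupted view of the inputs.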
        if kargs.get("resample_discriminator", False):
            input_ori_ids = features['input_ori_ids']

            [output_ids, sampled_binary_mask
             ] = random_input_ids_generation(model_config,
                                             features['input_ori_ids'],
                                             features['input_mask'],
                                             mask_probability=0.2,
                                             replace_probability=0.1,
                                             original_probability=0.1)

            resample_features = {}
            for key in features:
                resample_features[key] = features[key]

            resample_features['input_ids'] = tf.identity(output_ids)
            model_resample = model_api(model_config,
                                       resample_features,
                                       labels,
                                       mode,
                                       target,
                                       reuse=tf.AUTO_REUSE,
                                       **kargs)

            tf.logging.info("**** apply discriminator resample **** ")
        else:
            model_resample = model
            resample_features = features
            tf.logging.info("**** not apply discriminator resample **** ")

        sampled_ids = token_generator(model_config,
                                      model_resample.get_sequence_output(),
                                      model_resample.get_embedding_table(),
                                      resample_features['input_ids'],
                                      resample_features['input_ori_ids'],
                                      resample_features['input_mask'],
                                      embedding_projection=model_resample.
                                      get_embedding_projection_table(),
                                      scope=generator_scope_prefix,
                                      mask_method='only_mask',
                                      use_tpu=kargs.get('use_tpu', True))

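        # With gen_sample == 1 the original tensors pass through unchanged;
        # otherwise inputs are tiled so each example pairs with every sample.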
        if model_config.get('gen_sample', 1) == 1:
            input_ids = features['input_ori_ids']
            input_mask = features['input_mask']
            segment_ids = features['segment_ids']
        else:
            input_ids = tf.expand_dims(features['input_ori_ids'], axis=-1)
            # batch x seq_length x 1
            input_ids = tf.einsum(
                'abc,cd->abd', tf.cast(input_ids, tf.float32),
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_ids = tf.cast(input_ids, tf.int32)

            input_shape_list = bert_utils.get_shape_list(input_ids,
                                                         expected_rank=3)
            batch_size = input_shape_list[0]
            seq_length = input_shape_list[1]
            gen_sample = input_shape_list[2]

            sampled_ids = tf.reshape(sampled_ids,
                                     [batch_size * gen_sample, seq_length])
            input_ids = tf.reshape(input_ids,
                                   [batch_size * gen_sample, seq_length])

            input_mask = tf.expand_dims(features['input_mask'], axis=-1)
            input_mask = tf.einsum(
                'abc,cd->abd', tf.cast(input_mask, tf.float32),
                tf.ones((1, model_config.get('gen_sample', 1))))
            input_mask = tf.cast(input_mask, tf.int32)

            segment_ids = tf.expand_dims(features['segment_ids'], axis=-1)
            segment_ids = tf.einsum(
                'abc,cd->abd', tf.cast(segment_ids, tf.float32),
                tf.ones((1, model_config.get('gen_sample', 1))))
            segment_ids = tf.cast(segment_ids, tf.int32)

            segment_ids = tf.reshape(segment_ids,
                                     [batch_size * gen_sample, seq_length])
            input_mask = tf.reshape(input_mask,
                                    [batch_size * gen_sample, seq_length])

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if generator_scope_prefix:
            """
			"generator/cls/predictions"
			"""
            lm_pretrain_tvars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/predictions",
                not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                generator_scope_prefix + "/cls/seq_relationship",
                not_storage_params=not_storage_params)
        else:
            lm_pretrain_tvars = model_io_fn.get_params(
                "cls/predictions", not_storage_params=not_storage_params)

            nsp_pretrain_vars = model_io_fn.get_params(
                "cls/seq_relationship", not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        pretrained_tvars.extend(nsp_pretrain_vars)
        tvars = pretrained_tvars

        print('==generator parameters==', tvars)

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None
        tf.add_to_collection("generator_loss", loss)
        return_dict = {
            "loss": loss,
            "tvars": tvars,
            "model": model,
            "sampled_ids": sampled_ids,  # batch x gen_sample, seg_length
            "sampled_input_ids": input_ids,  # batch x gen_sample, seg_length,
            "sampled_input_mask": input_mask,
            "sampled_segment_ids": segment_ids,
            "masked_lm_ids": masked_lm_ids,
            "masked_lm_weights": masked_lm_mask,
            "masked_lm_log_probs": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "next_sentence_example_loss": nsp_per_example_loss,
            "next_sentence_log_probs": nsp_log_prob,
            "next_sentence_labels": features['next_sentence_labels'],
            "sampled_binary_mask": sampled_binary_mask
        }
        return return_dict
Example no. 22
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        seq_features = {}
        for key in features:
            seq_features[key] = features[key]
        seq_features['input_ids'] = features["input_ori_ids"]

        model = model_api(model_config,
                          seq_features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

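        # Shifted LM objective: logits at step t predict the token at t+1,
        # hence the [:, 1:] / [:, :-1] slicing; pad positions are masked
        # out of the per-example average via sequence_mask.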
        sequence_mask = tf.to_float(
            tf.not_equal(features['input_ori_ids'][:, 1:],
                         kargs.get('[PAD]', 0)))

        # batch x seq_length
        print(model.get_sequence_output_logits().get_shape(),
              "===logits shape===")
        seq_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=features['input_ori_ids'][:, 1:],
            logits=model.get_sequence_output_logits()[:, :-1])

        per_example_loss = tf.reduce_sum(seq_loss * sequence_mask, axis=-1) / (
            tf.reduce_sum(sequence_mask, axis=-1) + 1e-10)
        loss = tf.reduce_mean(per_example_loss)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)

        use_tpu = 1 if kargs.get('use_tpu', False) else 0

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                pretrained_tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu)
            tf.logging.info("***** using tpu *****")
        else:
            scaffold_fn = None
            tf.logging.info("***** not using tpu *****")

        if mode == tf.estimator.ModeKeys.TRAIN:

            if kargs.get('use_tpu', False):
                optimizer_fn = optimizer.Optimizer(opt_config)
                use_tpu = 1
                tf.logging.info(
                    "***** using tpu with tpu-compatible optimizer *****")
            else:
                optimizer_fn = distributed_optimizer.Optimizer(opt_config)
                use_tpu = 0
                tf.logging.info(
                    "***** using gpu with gpu-compatible optimizer *****")

            tvars = pretrained_tvars
            model_io_fn.print_params(tvars, string=", trainable params")

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    loss,
                    tvars,
                    opt_config.init_lr,
                    opt_config.num_train_steps,
                    use_tpu=use_tpu)

                train_metric_dict = train_metric(
                    features['input_ori_ids'],
                    model.get_sequence_output_logits(), **kargs)

                if not kargs.get('use_tpu', False):
                    for key in train_metric_dict:
                        tf.summary.scalar(key, train_metric_dict[key])
                    tf.summary.scalar('learning_rate',
                                      optimizer_fn.learning_rate)
                    tf.logging.info("***** logging metric *****")
                    tf.summary.scalar("causal_attenion_mask_length",
                                      tf.reduce_sum(model.attention_mask))
                    tf.summary.scalar("bi_attenion_mask_length",
                                      tf.reduce_sum(model.bi_attention_mask))

                if kargs.get('use_tpu', False):
                    estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                        mode=mode,
                        loss=loss,
                        train_op=train_op,
                        scaffold_fn=scaffold_fn)
                else:
                    estimator_spec = tf.estimator.EstimatorSpec(
                        mode=mode, loss=loss, train_op=train_op)

                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:

            gpu_eval_metrics = eval_metric(features['input_ori_ids'],
                                           model.get_sequence_output_logits())
            tpu_eval_metrics = (eval_metric, [
                features['input_ori_ids'],
                model.get_sequence_output_logits()
            ])

            if kargs.get('use_tpu', False):
                estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
                    mode=mode,
                    loss=loss,
                    eval_metrics=tpu_eval_metrics,
                    scaffold_fn=scaffold_fn)
            else:
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode, loss=loss, eval_metric_ops=gpu_eval_metrics)

            return estimator_spec
        else:
            raise NotImplementedError()
Example no. 23
    def model_fn(features, labels, mode):

        train_ops = []
        train_hooks = []
        logits_dict = {}
        losses_dict = {}
        features_dict = {}
        tvars = []
        task_num_dict = {}

        total_loss = tf.constant(0.0)

        task_num = 0

        encoder = {}
        hook_dict = {}

        print(task_type_dict.keys(), "==task type dict==")
        num_task = len(task_type_dict)

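        # One shared encoder is built per model type and reused across tasks;
        # each cls task attaches its own head and contributes to total_loss.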
        for index, task_type in enumerate(task_type_dict.keys()):
            if model_config_dict[task_type].model_type in model_type_lst:
                reuse = True
            else:
                reuse = None
                model_type_lst.append(model_config_dict[task_type].model_type)
            if task_type_dict[task_type] == "cls_task":

                if model_config_dict[task_type].model_type not in encoder:
                    model_api = model_zoo(model_config_dict[task_type])

                    model = model_api(model_config_dict[task_type],
                                      features,
                                      labels,
                                      mode,
                                      target_dict[task_type],
                                      reuse=reuse)
                    encoder[model_config_dict[task_type].model_type] = model

                print(encoder, "==encode==")

                task_model_fn = cls_model_fn(
                    encoder[model_config_dict[task_type].model_type],
                    model_config_dict[task_type],
                    num_labels_dict[task_type],
                    init_checkpoint_dict[task_type],
                    reuse,
                    load_pretrained_dict[task_type],
                    model_io_config,
                    opt_config,
                    exclude_scope=exclude_scope_dict[task_type],
                    not_storage_params=not_storage_params_dict[task_type],
                    target=target_dict[task_type],
                    label_lst=None,
                    output_type=output_type,
                    task_layer_reuse=task_layer_reuse,
                    task_type=task_type,
                    num_task=num_task,
                    task_adversarial=1e-2,
                    **kargs)
                print("==SUCCEEDED IN LODING==", task_type)

                result_dict = task_model_fn(features, labels, mode)
                logits_dict[task_type] = result_dict["logits"]
                losses_dict[task_type] = result_dict["loss"]  # task loss
                for key in [
                        "masked_lm_loss", "task_loss", "acc", "task_acc",
                        "masked_lm_acc"
                ]:
                    name = "{}_{}".format(task_type, key)
                    if name in result_dict:
                        hook_dict[name] = result_dict[name]
                hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
                total_loss += result_dict["loss"]

                if mode == tf.estimator.ModeKeys.TRAIN:
                    tvars.extend(result_dict["tvars"])
                    task_num += result_dict["task_num"]
                    task_num_dict[task_type] = result_dict["task_num"]
                elif mode == tf.estimator.ModeKeys.EVAL:
                    features[task_type] = result_dict["feature"]
            else:
                continue

        hook_dict["total_loss"] = total_loss

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn = model_io.ModelIO(model_io_config)

            optimizer_fn = optimizer.Optimizer(opt_config)

            model_io_fn.print_params(list(set(tvars)),
                                     string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            print("==update_ops==", update_ops)

            with tf.control_dependencies(update_ops):
                train_op = optimizer_fn.get_train_op(
                    total_loss, list(set(tvars)), opt_config.init_lr,
                    opt_config.num_train_steps, **kargs)

                model_io_fn.set_saver(optimizer_fn.opt)

                if kargs.get("task_index", 1) == 0 and kargs.get(
                        "run_config", None):
                    model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                          kargs.get("num_storage_steps", 1000))

                    training_hooks = model_io_fn.checkpoint_hook
                elif kargs.get("task_index", 1) == 1:
                    training_hooks = []
                else:
                    training_hooks = []

                if len(optimizer_fn.distributed_hooks) >= 1:
                    training_hooks.extend(optimizer_fn.distributed_hooks)
                print(training_hooks, "==training_hooks==", "==task_index==",
                      kargs.get("task_index", 1))

            if output_type == "sess":
                return {
                    "train": {
                        "total_loss": total_loss,
                        "loss": losses_dict,
                        "logits": logits_dict,
                        "train_op": train_op,
                        "task_num_dict": task_num_dict
                    },
                    "hooks": train_hooks
                }
            elif output_type == "estimator":

                hook_dict['learning_rate'] = optimizer_fn.learning_rate
                logging_hook = tf.train.LoggingTensorHook(hook_dict,
                                                          every_n_iter=100)
                training_hooks.append(logging_hook)

                print("==hook_dict==")

                print(hook_dict)

                for key in hook_dict:
                    tf.summary.scalar(key, hook_dict[key])
                    for index, task_type in enumerate(task_type_dict.keys()):
                        tmp = "{}_loss".format(task_type)
                        if tmp == key:
                            tf.summary.scalar(
                                "loss_gap_{}".format(task_type),
                                hook_dict["total_loss"] - hook_dict[key])
                for key in task_num_dict:
                    tf.summary.scalar(key + "_task_num", task_num_dict[key])

                estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
                                                            loss=total_loss,
                                                            train_op=train_op)
                # training_hooks=training_hooks)
                return estimator_spec

        elif mode == tf.estimator.ModeKeys.EVAL:  # eval runs separately for each task

            def metric_fn(logits, label_ids):
                """Computes the loss and accuracy of the model."""
                sentence_log_probs = tf.reshape(logits, [-1, logits.shape[-1]])
                sentence_predictions = tf.argmax(logits,
                                                 axis=-1,
                                                 output_type=tf.int32)
                sentence_labels = tf.reshape(label_ids, [-1])
                sentence_accuracy = tf.metrics.accuracy(
                    labels=label_ids, predictions=sentence_predictions)
                sentence_f = tf_metrics.f1(label_ids,
                                           sentence_predictions,
                                           num_labels,
                                           label_lst,
                                           average="macro")

                eval_metric_ops = {"f1": sentence_f, "acc": sentence_accuracy}

                return eval_metric_ops

            if output_type == "sess":
                return {
                    "eval": {
                        "logits": logits_dict,
                        "total_loss": total_loss,
                        "feature": features,
                        "loss": losses_dict
                    }
                }
            elif output_type == "estimator":
                eval_metric_ops = {}
                for key in logits_dict:
                    eval_dict = metric_fn(logits_dict[key],
                                          features_task_dict[key]["label_ids"])
                    for sub_key in eval_dict.keys():
                        eval_key = "{}_{}".format(key, sub_key)
                        eval_metric_ops[eval_key] = eval_dict[sub_key]
                estimator_spec = tf.estimator.EstimatorSpec(
                    mode=mode,
                    # task_num is only accumulated in TRAIN mode, so average
                    # the eval loss over the number of configured tasks
                    loss=total_loss / num_task,
                    eval_metric_ops=eval_metric_ops)
                return estimator_spec
        else:
            raise NotImplementedError()
Example no. 24
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          target,
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

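        # Energy-based scoring head: emb_score returns an unnormalized
        # log-probability per sequence (treated as log p in the comment
        # near the end of this model_fn).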
        logits = pretrain.emb_score(model_config, model.get_sequence_output(),
                                    features['input_ids'],
                                    model.get_embedding_table(),
                                    features['input_mask'], **kargs)

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        ebm_pretrain_tvars = model_io_fn.get_params(
            "ebm/predictions", not_storage_params=not_storage_params)

        if model_config.get('embedding_scope', None) is not None:
            embedding_tvars = model_io_fn.get_params(
                model_config.get('embedding_scope', 'bert') + "/embeddings",
                not_storage_params=not_storage_params)
            pretrained_tvars.extend(embedding_tvars)

        pretrained_tvars.extend(lm_pretrain_tvars)
        # pretrained_tvars.extend(ebm_pretrain_tvars)
        tvars = pretrained_tvars
        logz_tvars = ebm_pretrain_tvars

        print('==ebm parameters==', tvars)
        print('==ebm logz parameters==', logz_tvars)

        load_vars = tvars + logz_tvars

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            print("==load vars==", load_vars)
            scaffold_fn = model_io_fn.load_pretrained(
                load_vars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None

        # logits is log p; to maximize it directly, the loss simply negates it
        # with tf.variable_scope("ebm", reuse=tf.AUTO_REUSE):
        # 	ebm_global_step = tf.get_variable(
        # 						"global_step",
        # 						shape=[],
        # 						initializer=tf.constant_initializer(0, dtype=tf.int64),
        # 						trainable=False,
        # 						dtype=tf.int64)
        return_dict = {
            "tvars": tvars,
            "logits": logits,
            "logz_tvars": logz_tvars
            # "global_step":ebm_global_step
        }
        return return_dict
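Note: this model_fn only hands back parameter lists and the unnormalized
log-probability, so optimization is left to the caller. A hedged sketch of
one plausible split, with separate optimizers for the energy parameters and
the log-partition ("logz") parameters (ebm_fn and both learning rates are
illustrative assumptions, not this repo's API):

    import tensorflow as tf

    ebm = ebm_fn(features, labels, mode, params)
    # logits is log p, so the NLL to minimize is simply its negation
    nll = -tf.reduce_mean(ebm['logits'])
    energy_op = tf.train.AdamOptimizer(1e-5).minimize(nll, var_list=ebm['tvars'])
    logz_op = tf.train.AdamOptimizer(1e-3).minimize(nll, var_list=ebm['logz_tvars'])
    train_op = tf.group(energy_op, logz_op)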
Example no. 25
	def model_fn(features, labels, mode, params):

		model_api = model_zoo(model_config)

		model = model_api(model_config, features, labels,
							mode, target, reuse=tf.AUTO_REUSE)

		if mode == tf.estimator.ModeKeys.TRAIN:
			dropout_prob = model_config.dropout_prob
		else:
			dropout_prob = 0.0

		if model_io_config.fix_lm == True:
			scope = model_config.scope + "_finetuning"
		else:
			scope = model_config.scope
		
		(nsp_loss, 
		 nsp_per_example_loss, 
		 nsp_log_prob) = pretrain.get_next_sentence_output(model_config,
										model.get_pooled_output(),
										features['next_sentence_labels'],
										reuse=tf.AUTO_REUSE)

		masked_lm_positions = features["masked_lm_positions"]
		masked_lm_ids = features["masked_lm_ids"]
		masked_lm_weights = features["masked_lm_weights"]

		if model_config.model_type == 'bert':
			masked_lm_fn = pretrain.get_masked_lm_output
			print("==apply bert masked lm==")
		elif model_config.model_type == 'albert':
			masked_lm_fn = pretrain_albert.get_masked_lm_output
			print("==apply albert masked lm==")
		else:
			masked_lm_fn = pretrain.get_masked_lm_output
			print("==apply bert masked lm==")

		(masked_lm_loss,
		masked_lm_example_loss, 
		masked_lm_log_probs,
		masked_lm_mask) = masked_lm_fn(
										model_config, 
										model.get_sequence_output(), 
										model.get_embedding_table(),
										masked_lm_positions, 
										masked_lm_ids, 
										masked_lm_weights,
										reuse=tf.AUTO_REUSE,
										embedding_projection=model.get_embedding_projection_table())
		print(model_config.lm_ratio, '==mlm lm_ratio==')
		loss = model_config.lm_ratio * masked_lm_loss #+ model_config.nsp_ratio * nsp_loss
		
		model_io_fn = model_io.ModelIO(model_io_config)

		pretrained_tvars = model_io_fn.get_params(model_config.scope, 
										not_storage_params=not_storage_params)

		lm_pretrain_tvars = model_io_fn.get_params("cls/predictions", 
									not_storage_params=not_storage_params)

		pretrained_tvars.extend(lm_pretrain_tvars)

		if load_pretrained == "yes":
			scaffold_fn = model_io_fn.load_pretrained(pretrained_tvars, 
											init_checkpoint,
											exclude_scope=exclude_scope,
											use_tpu=1)
		else:
			scaffold_fn = None
                print("******* scaffold fn *******", scaffold_fn)
		if mode == tf.estimator.ModeKeys.TRAIN:
						
			optimizer_fn = optimizer.Optimizer(opt_config)
						
			tvars = pretrained_tvars
			model_io_fn.print_params(tvars, string=", trainable params")
			
			# update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			# with tf.control_dependencies(update_ops):
			print('==gpu count==', opt_config.get('gpu_count', 1))

			train_op = optimizer_fn.get_train_op(loss, tvars,
							opt_config.init_lr, 
							opt_config.num_train_steps,
							use_tpu=opt_config.use_tpu)

			train_metric_dict = train_metric_fn(
					masked_lm_example_loss, masked_lm_log_probs, 
					masked_lm_ids,
					masked_lm_weights, 
					nsp_per_example_loss,
					nsp_log_prob, 
					features['next_sentence_labels'],
					masked_lm_mask=masked_lm_mask
				)

			# for key in train_metric_dict:
			# 	tf.summary.scalar(key, train_metric_dict[key])
			# tf.summary.scalar('learning_rate', optimizer_fn.learning_rate)

			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
							mode=mode,
							loss=loss,
							train_op=train_op,
							scaffold_fn=scaffold_fn)

			return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:

			def metric_fn(masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
					masked_lm_weights, next_sentence_example_loss,
					next_sentence_log_probs, next_sentence_labels):
				"""Computes the loss and accuracy of the model."""
				masked_lm_log_probs = tf.reshape(masked_lm_log_probs,
												 [-1, masked_lm_log_probs.shape[-1]])
				masked_lm_predictions = tf.argmax(
					masked_lm_log_probs, axis=-1, output_type=tf.int32)
				masked_lm_example_loss = tf.reshape(masked_lm_example_loss, [-1])
				masked_lm_ids = tf.reshape(masked_lm_ids, [-1])
				masked_lm_weights = tf.reshape(masked_lm_weights, [-1])
				masked_lm_accuracy = tf.metrics.accuracy(
					labels=masked_lm_ids,
					predictions=masked_lm_predictions,
					weights=masked_lm_weights)
				masked_lm_mean_loss = tf.metrics.mean(
					values=masked_lm_example_loss, weights=masked_lm_weights)

				next_sentence_log_probs = tf.reshape(
					next_sentence_log_probs, [-1, next_sentence_log_probs.shape[-1]])
				next_sentence_predictions = tf.argmax(
					next_sentence_log_probs, axis=-1, output_type=tf.int32)
				next_sentence_labels = tf.reshape(next_sentence_labels, [-1])
				next_sentence_accuracy = tf.metrics.accuracy(
					labels=next_sentence_labels, predictions=next_sentence_predictions)
				next_sentence_mean_loss = tf.metrics.mean(
					values=next_sentence_example_loss)

				return {
					"masked_lm_accuracy": masked_lm_accuracy,
					"masked_lm_loss": masked_lm_mean_loss,
					"next_sentence_accuracy": next_sentence_accuracy,
					"next_sentence_loss": next_sentence_mean_loss
					}

			eval_metrics = (metric_fn, [
			  masked_lm_example_loss, masked_lm_log_probs, masked_lm_ids,
			  masked_lm_weights, nsp_per_example_loss,
			  nsp_log_prob, features['next_sentence_labels']
			])

			estimator_spec = tf.contrib.tpu.TPUEstimatorSpec(
						  mode=mode,
						  loss=loss,
						  eval_metrics=eval_metrics,
						  scaffold_fn=scaffold_fn)

			return estimator_spec
		else:
			raise NotImplementedError()
Example no. 26
    def model_fn(features, labels, mode, params):

        model_api = model_zoo(model_config)

        input_mask = tf.cast(
            tf.not_equal(features['input_ids_{}'.format(target)],
                         kargs.get('[PAD]', 0)), tf.int32)
        segment_ids = tf.zeros_like(input_mask)

        if target:
            features['input_ori_ids'] = features['input_ids_{}'.format(target)]
            features['input_mask'] = input_mask
            features['segment_ids'] = segment_ids
            # features['input_mask'] = features['input_mask_{}'.format(target)]
            # features['segment_ids'] = features['segment_ids_{}'.format(target)]
            features['input_ids'] = features['input_ids_{}'.format(target)]

        input_ori_ids = features.get('input_ori_ids', None)
        if mode == tf.estimator.ModeKeys.TRAIN:
            if input_ori_ids is not None:

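                # HMM-driven corruption: hmm_tran_prob_list supplies the
                # transition matrices and mask_prior the prior over states,
                # presumably yielding contiguous span masks rather than
                # i.i.d. token masks.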
                [output_ids, sampled_binary_mask] = hmm_input_ids_generation(
                    model_config,
                    features['input_ori_ids'],
                    features['input_mask'], [
                        tf.cast(tf.constant(hmm_tran_prob), tf.float32)
                        for hmm_tran_prob in hmm_tran_prob_list
                    ],
                    mask_probability=0.1,
                    replace_probability=0.1,
                    original_probability=0.1,
                    mask_prior=tf.cast(tf.constant(mask_prior), tf.float32),
                    **kargs)

                features['input_ids'] = output_ids
                tf.logging.info(
                    "***** Running random sample input generation *****")
            else:
                sampled_binary_mask = None
                # keep output_ids defined for the return_dict below
                output_ids = features['input_ids']
        else:
            sampled_binary_mask = None
            output_ids = features['input_ids']

        model = model_api(model_config,
                          features,
                          labels,
                          mode,
                          "",
                          reuse=tf.AUTO_REUSE,
                          **kargs)

        if mode == tf.estimator.ModeKeys.TRAIN:
            dropout_prob = model_config.dropout_prob
        else:
            dropout_prob = 0.0

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        if model_config.model_type == 'bert':
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")
        elif model_config.model_type == 'albert':
            masked_lm_fn = pretrain_albert.get_masked_lm_output
            seq_masked_lm_fn = pretrain_albert.seq_mask_masked_lm_output
            print("==apply albert masked lm==")
        else:
            masked_lm_fn = pretrain.get_masked_lm_output
            seq_masked_lm_fn = pretrain.seq_mask_masked_lm_output
            print("==apply bert masked lm==")

        if sampled_binary_mask is not None:
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = seq_masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 features['input_mask'],
                 features['input_ori_ids'],
                 features['input_ids'],
                 sampled_binary_mask,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
            masked_lm_ids = input_ori_ids
        else:
            masked_lm_positions = features["masked_lm_positions"]
            masked_lm_ids = features["masked_lm_ids"]
            masked_lm_weights = features["masked_lm_weights"]
            (masked_lm_loss, masked_lm_example_loss, masked_lm_log_probs,
             masked_lm_mask) = masked_lm_fn(
                 model_config,
                 model.get_sequence_output(),
                 model.get_embedding_table(),
                 masked_lm_positions,
                 masked_lm_ids,
                 masked_lm_weights,
                 reuse=tf.AUTO_REUSE,
                 embedding_projection=model.get_embedding_projection_table())
        print(model_config.lm_ratio, '==mlm lm_ratio==')
        loss = model_config.lm_ratio * masked_lm_loss  #+ 0.0 * nsp_loss

        model_io_fn = model_io.ModelIO(model_io_config)

        pretrained_tvars = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        lm_pretrain_tvars = model_io_fn.get_params(
            "cls/predictions", not_storage_params=not_storage_params)

        pretrained_tvars.extend(lm_pretrain_tvars)
        tvars = pretrained_tvars

        if load_pretrained == "yes":
            use_tpu = 1 if kargs.get('use_tpu', False) else 0
            scaffold_fn = model_io_fn.load_pretrained(
                tvars,
                init_checkpoint,
                exclude_scope=exclude_scope,
                use_tpu=use_tpu,
                restore_var_name=model_config.get('restore_var_name', []))
        else:
            scaffold_fn = None

        return_dict = {
            "loss": loss,
            "logits": masked_lm_log_probs,
            "masked_lm_example_loss": masked_lm_example_loss,
            "tvars": tvars,
            "model": model,
            "masked_lm_mask": masked_lm_mask,
            "output_ids": output_ids,
            "masked_lm_ids": masked_lm_ids
        }
        return return_dict