Example #1
def model_config_parser(FLAGS):

    print(FLAGS.model_type)

    if FLAGS.model_type in [
            "bert", "bert_rule", "albert", "electra_gumbel_encoder",
            "albert_official", "electra_gumbel_albert_official_encoder",
            "bert_seq"
    ]:
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        print(config, '==model config==')
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = FLAGS.model_scope  # e.g. "bert"
        tf.logging.info("****** original scope ******* %s", str(config.scope))
        config.dropout_prob = 0.1
        try:
            config.label_type = FLAGS.label_type
        except AttributeError:  # label_type flag not defined
            config.label_type = "single_label"
        tf.logging.info("****** label type ******* %s", str(config.label_type))
        config.model_type = FLAGS.model_type
        config.ln_type = FLAGS.ln_type
        if FLAGS.task_type in ['bert_pretrain']:
            # the same schedule is used whether or not a pretrained checkpoint is loaded
            config.init_lr = FLAGS.init_lr
            config.warmup = 0.1
            print('==apply bert pretrain==', config.init_lr)
        else:
            if FLAGS.model_type in ['albert']:
                try:
                    config.init_lr = FLAGS.init_lr
                except AttributeError:  # init_lr flag not defined
                    config.init_lr = 1e-4
            else:
                config.init_lr = FLAGS.init_lr
            print('==apply finetuning==', config.init_lr)
        print("===learning rate===", config.init_lr)
        try:
            if FLAGS.attention_type in ['rezero_transformer']:
                config.warmup = 0.0
                tf.logging.info("****** warmup ******* %s", str(config.warmup))
        except AttributeError:  # attention_type flag not defined
            tf.logging.info("****** normal attention ******* ")
        tf.logging.info("****** learning rate ******* %s", str(config.init_lr))
        # config.loss = "dmi_loss"

        try:
            config.loss = FLAGS.loss
        except AttributeError:  # loss flag not defined
            config.loss = "entropy"
        tf.logging.info("****** loss type ******* %s", str(config.loss))

        # config.loss = "focal_loss"
        config.rule_type_size = 2
        config.lm_ratio = 1.0
        config.max_length = FLAGS.max_length
        config.nsp_ratio = 1.0
        config.max_predictions_per_seq = FLAGS.max_predictions_per_seq
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier

    elif FLAGS.model_type in ["bert_small"]:
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        config.model_type = FLAGS.model_type
        config.init_lr = 3e-5
        config.num_hidden_layers = FLAGS.num_hidden_layers
        config.loss = "entropy"
        config.rule_type_size = 2
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer

    elif FLAGS.model_type in [
            "textcnn", 'textcnn_distillation',
            'textcnn_distillation_adv_adaptation', 'textcnn_interaction'
    ]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

        print(w2v_path, vocab_path)
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)

        [w2v_embed, token2id, id2token, is_extral_symbol, use_pretrained
         ] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path,
                                          config.get('emb_size', 64))

        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "textcnn"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.use_pretrained = use_pretrained
        config.label_type = FLAGS.label_type
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")

        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer

    elif FLAGS.model_type in ["textlstm", "textlstm_distillation"]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

        print(w2v_path, vocab_path)

        [w2v_embed, token2id, id2token, is_extral_symbol
         ] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "textlstm"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 5.0
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")

        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer

    elif FLAGS.model_type in ["match_pyramid", "match_pyramid_distillation"]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

        print(w2v_path, vocab_path)

        [w2v_embed, token2id, id2token, is_extral_symbol
         ] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "match_pyramid"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 5.0
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        config.max_seq_len = FLAGS.max_length
        if FLAGS.task_type in ["interaction_pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer

        if config.compress_emb:
            config.embedding_dim_compressed = config.cnn_num_filters

    elif FLAGS.model_type in ["dan", 'dan_distillation']:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

        print(w2v_path, vocab_path)

        [w2v_embed, token2id, id2token, is_extral_symbol
         ] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "dan"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")

        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer

    elif FLAGS.model_type in ['gpt']:
        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        config = Bunch(config)
        config.dropout_prob = 0.1
        config.init_lr = 1e-4

    elif FLAGS.model_type in ["gated_cnn_seq"]:

        with open(FLAGS.config_file, "r") as f:
            config = json.load(f)
        config = Bunch(config)
        config.token_emb_mat = None
        config.char_emb_mat = None
        config.max_length = FLAGS.max_length
        config.scope = "textcnn"
        config.char_dim = config.emb_char_size
        config.char_vocab_size = config.vocab_size
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = FLAGS.init_lr
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 10.0
        config.max_seq_len = FLAGS.max_length

    else:
        raise ValueError("unsupported model_type: %s" % FLAGS.model_type)

    return config
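A minimal driver sketch for the parser above, assuming the module-level imports it relies on (json, os, tensorflow as tf) and the usual attribute-style Bunch helper; every flag value below is illustrative, not taken from the original project:

class Bunch(dict):
    # dict with attribute access; a common minimal definition
    def __init__(self, *args, **kwargs):
        super(Bunch, self).__init__(*args, **kwargs)
        self.__dict__ = self

FLAGS = Bunch(
    model_type="bert",
    config_file="./bert_config.json",  # hypothetical config path
    model_scope="bert",
    label_type="single_label",
    ln_type="postln",
    task_type="single_sentence_classification",
    init_lr=2e-5,
    loss="entropy",
    max_length=128,
    max_predictions_per_seq=20,
)
config = model_config_parser(FLAGS)
print(config.init_lr, config.loss, config.scope)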
Example #2
def model_config_parser(FLAGS):

	print(FLAGS.model_type)

	if FLAGS.model_type in ["bert", "bert_rule"]:
		with open(FLAGS.config_file, "r") as f:
			config = json.load(f)
		config = Bunch(config)
		config.use_one_hot_embeddings = True
		config.scope = "bert"
		config.dropout_prob = 0.1
		config.label_type = "single_label"
		config.model_type = FLAGS.model_type
		config.init_lr = 2e-5
		config.loss = "entropy"
		config.rule_type_size = 2
		if FLAGS.task_type in ["pair_sentence_classification"]:
			config.classifier = FLAGS.classifier

	elif FLAGS.model_type in ["bert_small"]:
		with open(FLAGS.config_file, "r") as f:
			config = json.load(f)
		config = Bunch(config)
		config.use_one_hot_embeddings = True
		config.scope = "bert"
		config.dropout_prob = 0.1
		config.label_type = "single_label"
		config.model_type = FLAGS.model_type
		config.init_lr = 2e-5
		config.num_hidden_layers = FLAGS.num_hidden_layers
		config.loss = "entropy"
		config.rule_type_size = 2
		if FLAGS.task_type in ["pair_sentence_classification"]:
			config.classifier = FLAGS.classifier
			config.output_layer = FLAGS.output_layer

	elif FLAGS.model_type in ["textcnn", 'textcnn_distillation']:
		from data_generator import load_w2v
		w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
		vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

		print(w2v_path, vocab_path)

		[w2v_embed, token2id, id2token, is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
		with open(FLAGS.config_file, "r") as f:
			config = json.load(f)
		config = Bunch(config)
		config.token_emb_mat = w2v_embed
		config.char_emb_mat = None
		config.vocab_size = w2v_embed.shape[0]
		config.max_length = FLAGS.max_length
		config.emb_size = w2v_embed.shape[1]
		config.scope = "textcnn"
		config.char_dim = w2v_embed.shape[1]
		config.char_vocab_size = w2v_embed.shape[0]
		config.char_embedding = None
		config.model_type = FLAGS.model_type
		config.dropout_prob = config.dropout_rate
		config.init_lr = config.learning_rate
		if is_extral_symbol == 1:
			config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
			print("==need extra_symbol==")

		if FLAGS.task_type in ["pair_sentence_classification"]:
			config.classifier = FLAGS.classifier
			config.output_layer = FLAGS.output_layer

	elif FLAGS.model_type in ["textlstm", "textlstm_distillation"]:
		from data_generator import load_w2v
		w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
		vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

		print(w2v_path, vocab_path)

		[w2v_embed, token2id, id2token, is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
		with open(FLAGS.config_file, "r") as f:
			config = json.load(f)
		config = Bunch(config)
		config.token_emb_mat = w2v_embed
		config.char_emb_mat = None
		config.vocab_size = w2v_embed.shape[0]
		config.max_length = FLAGS.max_length
		config.emb_size = w2v_embed.shape[1]
		config.scope = "textlstm"
		config.char_dim = w2v_embed.shape[1]
		config.char_vocab_size = w2v_embed.shape[0]
		config.char_embedding = None
		config.model_type = FLAGS.model_type
		config.dropout_prob = config.dropout_rate
		config.init_lr = config.learning_rate
		config.grad_clip = "gloabl_norm"
		config.clip_norm = 5.0
		if is_extral_symbol == 1:
			config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
			print("==need extra_symbol==")
		
		if FLAGS.task_type in ["pair_sentence_classification"]:
			config.classifier = FLAGS.classifier
			config.output_layer = FLAGS.output_layer

	elif FLAGS.model_type in ["match_pyramid", "match_pyramid_distillation"]:
		from data_generator import load_w2v
		w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
		vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)

		print(w2v_path, vocab_path)

		[w2v_embed, token2id, id2token, is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
		with open(FLAGS.config_file, "r") as f:
			config = json.load(f)
		config = Bunch(config)
		config.token_emb_mat = w2v_embed
		config.char_emb_mat = None
		config.vocab_size = w2v_embed.shape[0]
		config.max_length = FLAGS.max_length
		config.emb_size = w2v_embed.shape[1]
		config.scope = "match_pyramid"
		config.char_dim = w2v_embed.shape[1]
		config.char_vocab_size = w2v_embed.shape[0]
		config.char_embedding = None
		config.model_type = FLAGS.model_type
		config.dropout_prob = config.dropout_rate
		config.init_lr = config.learning_rate
		config.grad_clip = "gloabl_norm"
		config.clip_norm = 5.0
		if is_extral_symbol == 1:
			config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
			print("==need extra_symbol==")
		config.max_seq_len = FLAGS.max_length
		if FLAGS.task_type in ["interaction_pair_sentence_classification"]:
			config.classifier = FLAGS.classifier
			config.output_layer = FLAGS.output_layer

		if config.compress_emb:
			config.embedding_dim_compressed = config.cnn_num_filters

	else:
		raise ValueError("unsupported model_type: %s" % FLAGS.model_type)

	return config
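Note that the model_fn below is a separate nested function, apparently the closure built inside a multi-task model constructor elsewhere in the same codebase: it reads task_type_dict, model_config_dict, num_labels_dict, init_checkpoint_dict, opt_config, model_io_config, output_type, kargs and related names from its enclosing scope rather than defining them itself.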
	def model_fn(features, labels, mode):

		train_ops = []
		train_hooks = []
		logits_dict = {}
		losses_dict = {}
		features_dict = {}
		tvars = []
		task_num_dict = {}

		total_loss = tf.constant(0.0)

		task_num = 0

		encoder = {}
		hook_dict = {}

		print(task_type_dict.keys(), "==task type dict==")
		num_task = len(task_type_dict)

		from data_generator import load_w2v
		flags = kargs.get('flags', Bunch({}))
		print(flags.pretrained_w2v_path, "===pretrain vocab path===")
		w2v_path = os.path.join(flags.buckets, flags.pretrained_w2v_path)
		vocab_path = os.path.join(flags.buckets, flags.vocab_file)

		[w2v_embed, token2id, id2token, is_extral_symbol, use_pretrained] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)

		pretrained_embed = tf.cast(tf.constant(w2v_embed), tf.float32)

		for index, task_type in enumerate(task_type_dict.keys()):
			if model_config_dict[task_type].model_type in model_type_lst:
				reuse = True
			else:
				reuse = None
				model_type_lst.append(model_config_dict[task_type].model_type)
			if task_type_dict[task_type] == "cls_task":

				if model_config_dict[task_type].model_type not in encoder:
					model_api = model_zoo(model_config_dict[task_type])

					model = model_api(model_config_dict[task_type], features, labels,
							mode, target_dict[task_type], reuse=reuse,
														cnn_type='multilayer_textcnn')
					encoder[model_config_dict[task_type].model_type] = model

				print(encoder, "==encode==")

				task_model_fn = cls_model_fn(encoder[model_config_dict[task_type].model_type],
												model_config_dict[task_type],
												num_labels_dict[task_type],
												init_checkpoint_dict[task_type],
												reuse,
												load_pretrained_dict[task_type],
												model_io_config,
												opt_config,
												exclude_scope=exclude_scope_dict[task_type],
												not_storage_params=not_storage_params_dict[task_type],
												target=target_dict[task_type],
												label_lst=None,
												output_type=output_type,
												task_layer_reuse=task_layer_reuse,
												task_type=task_type,
												num_task=num_task,
												task_adversarial=1e-2,
												get_pooled_output='task_output',
												feature_distillation=False,
												embedding_distillation=True,
												pretrained_embed=pretrained_embed,
												**kargs)
				print("==SUCCEEDED IN LODING==", task_type)

				result_dict = task_model_fn(features, labels, mode)
				logits_dict[task_type] = result_dict["logits"]
				losses_dict[task_type] = result_dict["loss"] # task loss
				for key in ["masked_lm_loss", "task_loss", "acc", "task_acc", "masked_lm_acc"]:
					name = "{}_{}".format(task_type, key)
					if name in result_dict:
						hook_dict[name] = result_dict[name]
				hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
				hook_dict["{}_num".format(task_type)] = result_dict["task_num"]
				total_loss += result_dict["loss"]
				hook_dict['embed_loss'] = result_dict["embed_loss"]
				hook_dict['feature_loss'] = result_dict["feature_loss"]
				hook_dict["{}_task_loss".format(task_type)] = result_dict["task_loss"]
				if mode == tf.estimator.ModeKeys.TRAIN:
					tvars.extend(result_dict["tvars"])
					task_num += result_dict["task_num"]
					task_num_dict[task_type] = result_dict["task_num"]
				elif mode == tf.estimator.ModeKeys.EVAL:
					features[task_type] = result_dict["feature"]
			else:
				continue

		hook_dict["total_loss"] = total_loss

		if mode == tf.estimator.ModeKeys.TRAIN:
			model_io_fn = model_io.ModelIO(model_io_config)

			optimizer_fn = optimizer.Optimizer(opt_config)

			model_io_fn.print_params(list(set(tvars)), string=", trainable params")
			update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
			print("==update_ops==", update_ops)

			with tf.control_dependencies(update_ops):
				train_op = optimizer_fn.get_train_op(total_loss, list(set(tvars)), 
								opt_config.init_lr, 
								opt_config.num_train_steps,
								**kargs)

				model_io_fn.set_saver(optimizer_fn.opt)

				if kargs.get("task_index", 1) == 1 and kargs.get("run_config", None):
					model_io_fn.get_hooks(kargs.get("checkpoint_dir", None), 
														kargs.get("num_storage_steps", 1000))

					training_hooks = model_io_fn.checkpoint_hook
				elif kargs.get("task_index", 1) == 1:
					training_hooks = []
				else:
					training_hooks = []

				if len(optimizer_fn.distributed_hooks) >= 1:
					training_hooks.extend(optimizer_fn.distributed_hooks)
				print(training_hooks, "==training_hooks==", "==task_index==", kargs.get("task_index", 1))

			if output_type == "sess":
				return {
					"train":{
							"total_loss":total_loss, 
							"loss":losses_dict,
							"logits":logits_dict,
							"train_op":train_op,
							"task_num_dict":task_num_dict
					},
					"hooks":train_hooks
				}
			elif output_type == "estimator":

				hook_dict['learning_rate'] = optimizer_fn.learning_rate
				logging_hook = tf.train.LoggingTensorHook(
					hook_dict, every_n_iter=100)
				training_hooks.append(logging_hook)

				print("==hook_dict==")

				print(hook_dict)

				for key in hook_dict:
					tf.summary.scalar(key, hook_dict[key])
					for index, task_type in enumerate(task_type_dict.keys()):
						tmp = "{}_loss".format(task_type)
						if tmp == key:
							tf.summary.scalar("loss_gap_{}".format(task_type), 
												hook_dict["total_loss"]-hook_dict[key])
				for key in task_num_dict:
					tf.summary.scalar(key+"_task_num", task_num_dict[key])
				

				estimator_spec = tf.estimator.EstimatorSpec(mode=mode, 
								loss=total_loss,
								train_op=train_op,
								training_hooks=training_hooks)
				return estimator_spec

		elif mode == tf.estimator.ModeKeys.EVAL:  # evaluation runs separately for each task
			def metric_fn(logits, label_ids):
				"""Computes the accuracy and macro-F1 of the model."""
				sentence_predictions = tf.argmax(
					logits, axis=-1, output_type=tf.int32)
				sentence_accuracy = tf.metrics.accuracy(
					labels=label_ids, predictions=sentence_predictions)
				sentence_f = tf_metrics.f1(label_ids, 
										sentence_predictions, 
										num_labels, 
										label_lst, average="macro")

				eval_metric_ops = {
									"f1": sentence_f,
									"acc":sentence_accuracy
								}

				return eval_metric_ops

			if output_type == "sess":
				return {
					"eval":{
							"logits":logits_dict,
							"total_loss":total_loss,
							"feature":features,
							"loss":losses_dict
						}
				}
			elif output_type == "estimator":
				eval_metric_ops = {}
				for key in logits_dict:
					eval_dict = metric_fn(
							logits_dict[key],
							features_task_dict[key]["label_ids"]
						)
					for sub_key in eval_dict.keys():
						eval_key = "{}_{}".format(key, sub_key)
						eval_metric_ops[eval_key] = eval_dict[sub_key]
				estimator_spec = tf.estimator.EstimatorSpec(mode=mode,
								loss=total_loss/max(task_num, 1),  # task_num is only accumulated in TRAIN mode
								eval_metric_ops=eval_metric_ops)
				return estimator_spec
		else:
			raise NotImplementedError()
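For context, a closure like model_fn is normally handed to tf.estimator.Estimator (TF 1.x). A minimal wiring sketch, assuming the surrounding builder returns model_fn and that the input functions and step counts are defined elsewhere:

import tensorflow as tf

run_config = tf.estimator.RunConfig(
    model_dir="./checkpoints",  # hypothetical output directory
    save_checkpoints_steps=1000)
estimator = tf.estimator.Estimator(
    model_fn=model_fn,  # the closure defined above
    config=run_config)
# estimator.train(input_fn=train_input_fn, max_steps=num_train_steps)
# estimator.evaluate(input_fn=eval_input_fn, steps=num_eval_steps)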