def model_config_parser(FLAGS):
    """Build a model config (Bunch) from FLAGS and the json config file."""
    print(FLAGS.model_type)
    if FLAGS.model_type in [
            "bert", "bert_rule", "albert", "electra_gumbel_encoder",
            "albert_official", "electra_gumbel_albert_official_encoder",
            "bert_seq"]:
        config = json.load(open(FLAGS.config_file, "r"))
        print(config, '==model config==')
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        # if FLAGS.exclude_scope:
        #     config.scope = FLAGS.exclude_scope + "/" + "bert"
        #     tf.logging.info("****** add exclude_scope ******* %s", str(config.scope))
        # else:
        #     config.scope = FLAGS.exclude_scope + "/" + "bert"
        #     tf.logging.info("****** add exclude_scope ******* %s", str(config.scope))
        config.scope = FLAGS.model_scope  # "bert"
        tf.logging.info("****** original scope ******* %s", str(config.scope))
        config.dropout_prob = 0.1
        try:
            config.label_type = FLAGS.label_type
        except:
            config.label_type = "single_label"
        tf.logging.info("****** label type ******* %s", str(config.label_type))
        config.model_type = FLAGS.model_type
        config.ln_type = FLAGS.ln_type
        if FLAGS.task_type in ['bert_pretrain']:
            if FLAGS.load_pretrained == "yes":
                config.init_lr = FLAGS.init_lr
                config.warmup = 0.1
            else:
                config.init_lr = FLAGS.init_lr
                config.warmup = 0.1
            print('==apply bert pretrain==', config.init_lr)
        else:
            if FLAGS.model_type in ['albert']:
                try:
                    config.init_lr = FLAGS.init_lr
                except:
                    config.init_lr = 1e-4
            else:
                # try:
                print(FLAGS)
                config.init_lr = FLAGS.init_lr
                # except:
                #     config.init_lr = 2e-5
            print('==apply albert finetuning==', config.init_lr)
        print("===learning rate===", config.init_lr)
        try:
            if FLAGS.attention_type in ['rezero_transformer']:
                config.warmup = 0.0
                tf.logging.info("****** warmup ******* %s", str(config.warmup))
        except:
            tf.logging.info("****** normal attention ******* ")
        tf.logging.info("****** learning rate ******* %s", str(config.init_lr))
        # config.loss = "dmi_loss"
        try:
            config.loss = FLAGS.loss
        except:
            config.loss = "entropy"
        tf.logging.info("****** loss type ******* %s", str(config.loss))
        # config.loss = "focal_loss"
        config.rule_type_size = 2
        config.lm_ratio = 1.0
        config.max_length = FLAGS.max_length
        config.nsp_ratio = 1.0
        config.max_predictions_per_seq = FLAGS.max_predictions_per_seq
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
    elif FLAGS.model_type in ["bert_small"]:
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        config.model_type = FLAGS.model_type
        config.init_lr = 3e-5
        config.num_hidden_layers = FLAGS.num_hidden_layers
        config.loss = "entropy"
        config.rule_type_size = 2
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in [
            "textcnn", 'textcnn_distillation',
            'textcnn_distillation_adv_adaptation', 'textcnn_interaction']:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        config = json.load(open(FLAGS.config_file, "r"))
        [w2v_embed, token2id, id2token, is_extral_symbol,
         use_pretrained] = load_w2v.load_pretrained_w2v(
            vocab_path, w2v_path, config.get('emb_size', 64))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "textcnn"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.use_pretrained = use_pretrained
        config.label_type = FLAGS.label_type
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in ["textlstm", "textlstm_distillation"]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        [w2v_embed, token2id, id2token,
         is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "textlstm"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 5.0
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in ["match_pyramid", "match_pyramid_distillation"]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        [w2v_embed, token2id, id2token,
         is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "match_pyramid"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 5.0
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        config.max_seq_len = FLAGS.max_length
        if FLAGS.task_type in ["interaction_pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
        if config.compress_emb:
            config.embedding_dim_compressed = config.cnn_num_filters
    elif FLAGS.model_type in ["dan", 'dan_distillation']:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        [w2v_embed, token2id, id2token,
         is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "dan"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in ['gpt']:
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.dropout_prob = 0.1
        config.init_lr = 1e-4
    elif FLAGS.model_type in ["gated_cnn_seq"]:
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = None
        config.char_emb_mat = None
        config.vocab_size = config.vocab_size
        config.max_length = FLAGS.max_length
        config.emb_size = config.emb_size
        config.scope = "textcnn"
        config.char_dim = config.emb_char_size
        config.char_vocab_size = config.vocab_size
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = FLAGS.init_lr
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 10.0
        config.max_seq_len = FLAGS.max_length
    return config
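# Illustrative only: the parser above returns a `Bunch`, an attribute-style dict
# provided elsewhere in this repo. The sketch below (kept commented out so it does
# not shadow the real import) shows a minimal container with the same behavior and
# a hypothetical call, assuming absl-style FLAGS with the fields read above
# (config_file, model_type, model_scope, init_lr, ...):
#
#   class Bunch(dict):
#       """dict whose keys are also readable/writable as attributes."""
#       def __getattr__(self, name):
#           try:
#               return self[name]
#           except KeyError:
#               raise AttributeError(name)
#       def __setattr__(self, name, value):
#           self[name] = value
#
#   config = model_config_parser(FLAGS)
#   print(config.init_lr, config.get("loss", "entropy"))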
def model_config_parser(FLAGS):
    """Build a model config (Bunch) from FLAGS and the json config file."""
    print(FLAGS.model_type)
    if FLAGS.model_type in ["bert", "bert_rule"]:
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        config.model_type = FLAGS.model_type
        config.init_lr = 2e-5
        config.loss = "entropy"
        config.rule_type_size = 2
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
    elif FLAGS.model_type in ["bert_small"]:
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        config.model_type = FLAGS.model_type
        config.init_lr = 2e-5
        config.num_hidden_layers = FLAGS.num_hidden_layers
        config.loss = "entropy"
        config.rule_type_size = 2
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in ["textcnn", 'textcnn_distillation']:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        [w2v_embed, token2id, id2token,
         is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "textcnn"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in ["textlstm", "textlstm_distillation"]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        [w2v_embed, token2id, id2token,
         is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "textlstm"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 5.0
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        if FLAGS.task_type in ["pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
    elif FLAGS.model_type in ["match_pyramid", "match_pyramid_distillation"]:
        from data_generator import load_w2v
        w2v_path = os.path.join(FLAGS.buckets, FLAGS.w2v_path)
        vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
        print(w2v_path, vocab_path)
        [w2v_embed, token2id, id2token,
         is_extral_symbol] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
        config = json.load(open(FLAGS.config_file, "r"))
        config = Bunch(config)
        config.token_emb_mat = w2v_embed
        config.char_emb_mat = None
        config.vocab_size = w2v_embed.shape[0]
        config.max_length = FLAGS.max_length
        config.emb_size = w2v_embed.shape[1]
        config.scope = "match_pyramid"
        config.char_dim = w2v_embed.shape[1]
        config.char_vocab_size = w2v_embed.shape[0]
        config.char_embedding = None
        config.model_type = FLAGS.model_type
        config.dropout_prob = config.dropout_rate
        config.init_lr = config.learning_rate
        config.grad_clip = "gloabl_norm"
        config.clip_norm = 5.0
        if is_extral_symbol == 1:
            config.extra_symbol = ["<pad>", "<unk>", "<s>", "</s>"]
            print("==need extra_symbol==")
        config.max_seq_len = FLAGS.max_length
        if FLAGS.task_type in ["interaction_pair_sentence_classification"]:
            config.classifier = FLAGS.classifier
            config.output_layer = FLAGS.output_layer
        if config.compress_emb:
            config.embedding_dim_compressed = config.cnn_num_filters
    return config
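# Illustrative only: both parsers unpack load_w2v.load_pretrained_w2v as
# [w2v_embed, token2id, id2token, is_extral_symbol] (the newer parser also
# receives a trailing `use_pretrained` element). A hypothetical stand-in with the
# same contract, handy for exercising the parser without a real w2v file; the
# function name and values below are made up for illustration:
#
#   import numpy as np
#
#   def fake_pretrained_w2v(vocab_size=100, emb_size=64):
#       w2v_embed = np.random.randn(vocab_size, emb_size).astype("float32")
#       token2id = {"<pad>": 0, "<unk>": 1, "<s>": 2, "</s>": 3}
#       id2token = {i: t for t, i in token2id.items()}
#       is_extral_symbol = 1  # 1 -> config.extra_symbol gets populated above
#       return [w2v_embed, token2id, id2token, is_extral_symbol]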
def model_fn(features, labels, mode):
    """Multi-task model_fn: builds one shared encoder per model type, a
    classification head per task, and aggregates the per-task losses.

    Note: task_type_dict, model_config_dict, num_labels_dict, kargs, etc.
    come from the enclosing builder function's scope (closure)."""
    train_ops = []
    train_hooks = []
    logits_dict = {}
    losses_dict = {}
    features_dict = {}
    tvars = []
    task_num_dict = {}
    total_loss = tf.constant(0.0)
    task_num = 0
    encoder = {}
    hook_dict = {}
    print(task_type_dict.keys(), "==task type dict==")
    num_task = len(task_type_dict)

    # load the pretrained word embedding used for embedding distillation
    from data_generator import load_w2v
    flags = kargs.get('flags', Bunch({}))
    print(flags.pretrained_w2v_path, "===pretrain vocab path===")
    w2v_path = os.path.join(flags.buckets, flags.pretrained_w2v_path)
    vocab_path = os.path.join(flags.buckets, flags.vocab_file)
    [w2v_embed, token2id, id2token, is_extral_symbol,
     use_pretrained] = load_w2v.load_pretrained_w2v(vocab_path, w2v_path)
    pretrained_embed = tf.cast(tf.constant(w2v_embed), tf.float32)

    for index, task_type in enumerate(task_type_dict.keys()):
        # reuse encoder variables if this model type was already built
        if model_config_dict[task_type].model_type in model_type_lst:
            reuse = True
        else:
            reuse = None
            model_type_lst.append(model_config_dict[task_type].model_type)
        if task_type_dict[task_type] == "cls_task":
            if model_config_dict[task_type].model_type not in encoder:
                model_api = model_zoo(model_config_dict[task_type])
                model = model_api(model_config_dict[task_type], features, labels,
                                  mode, target_dict[task_type], reuse=reuse,
                                  cnn_type='multilayer_textcnn')
                encoder[model_config_dict[task_type].model_type] = model
                print(encoder, "==encode==")
            task_model_fn = cls_model_fn(
                encoder[model_config_dict[task_type].model_type],
                model_config_dict[task_type],
                num_labels_dict[task_type],
                init_checkpoint_dict[task_type],
                reuse,
                load_pretrained_dict[task_type],
                model_io_config,
                opt_config,
                exclude_scope=exclude_scope_dict[task_type],
                not_storage_params=not_storage_params_dict[task_type],
                target=target_dict[task_type],
                label_lst=None,
                output_type=output_type,
                task_layer_reuse=task_layer_reuse,
                task_type=task_type,
                num_task=num_task,
                task_adversarial=1e-2,
                get_pooled_output='task_output',
                feature_distillation=False,
                embedding_distillation=True,
                pretrained_embed=pretrained_embed,
                **kargs)
            print("==SUCCEEDED IN LOADING==", task_type)
            result_dict = task_model_fn(features, labels, mode)
            logits_dict[task_type] = result_dict["logits"]
            losses_dict[task_type] = result_dict["loss"]  # task loss
            for key in ["masked_lm_loss", "task_loss", "acc",
                        "task_acc", "masked_lm_acc"]:
                name = "{}_{}".format(task_type, key)
                if name in result_dict:
                    hook_dict[name] = result_dict[name]
            hook_dict["{}_loss".format(task_type)] = result_dict["loss"]
            hook_dict["{}_num".format(task_type)] = result_dict["task_num"]
            total_loss += result_dict["loss"]
            hook_dict['embed_loss'] = result_dict["embed_loss"]
            hook_dict['feature_loss'] = result_dict["feature_loss"]
            hook_dict["{}_task_loss".format(task_type)] = result_dict["task_loss"]
            if mode == tf.estimator.ModeKeys.TRAIN:
                tvars.extend(result_dict["tvars"])
                task_num += result_dict["task_num"]
                task_num_dict[task_type] = result_dict["task_num"]
            elif mode == tf.estimator.ModeKeys.EVAL:
                features[task_type] = result_dict["feature"]
        else:
            continue

    hook_dict["total_loss"] = total_loss

    if mode == tf.estimator.ModeKeys.TRAIN:
        model_io_fn = model_io.ModelIO(model_io_config)
        optimizer_fn = optimizer.Optimizer(opt_config)
        model_io_fn.print_params(list(set(tvars)), string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        print("==update_ops==", update_ops)
        with tf.control_dependencies(update_ops):
            train_op = optimizer_fn.get_train_op(
                total_loss, list(set(tvars)),
                opt_config.init_lr, opt_config.num_train_steps,
                **kargs)
        model_io_fn.set_saver(optimizer_fn.opt)

        if kargs.get("task_index", 1) == 1 and kargs.get("run_config", None):
            model_io_fn.get_hooks(kargs.get("checkpoint_dir", None),
                                  kargs.get("num_storage_steps", 1000))
            training_hooks = model_io_fn.checkpoint_hook
        elif kargs.get("task_index", 1) == 1:
            training_hooks = []
        else:
            training_hooks = []
        if len(optimizer_fn.distributed_hooks) >= 1:
            training_hooks.extend(optimizer_fn.distributed_hooks)
        print(training_hooks, "==training_hooks==", "==task_index==",
              kargs.get("task_index", 1))

        if output_type == "sess":
            return {
                "train": {
                    "total_loss": total_loss,
                    "loss": losses_dict,
                    "logits": logits_dict,
                    "train_op": train_op,
                    "task_num_dict": task_num_dict
                },
                "hooks": train_hooks
            }
        elif output_type == "estimator":
            hook_dict['learning_rate'] = optimizer_fn.learning_rate
            logging_hook = tf.train.LoggingTensorHook(
                hook_dict, every_n_iter=100)
            training_hooks.append(logging_hook)
            print("==hook_dict==")
            print(hook_dict)
            for key in hook_dict:
                tf.summary.scalar(key, hook_dict[key])
                for index, task_type in enumerate(task_type_dict.keys()):
                    tmp = "{}_loss".format(task_type)
                    if tmp == key:
                        tf.summary.scalar("loss_gap_{}".format(task_type),
                                          hook_dict["total_loss"] - hook_dict[key])
            for key in task_num_dict:
                tf.summary.scalar(key + "_task_num", task_num_dict[key])

            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss,
                train_op=train_op,
                training_hooks=training_hooks)
            return estimator_spec

    elif mode == tf.estimator.ModeKeys.EVAL:
        # eval is executed separately for each task
        def metric_fn(logits, label_ids):
            """Computes the loss and accuracy of the model."""
            sentence_log_probs = tf.reshape(
                logits, [-1, logits.shape[-1]])
            sentence_predictions = tf.argmax(
                logits, axis=-1, output_type=tf.int32)
            sentence_labels = tf.reshape(label_ids, [-1])
            sentence_accuracy = tf.metrics.accuracy(
                labels=label_ids, predictions=sentence_predictions)
            sentence_f = tf_metrics.f1(label_ids, sentence_predictions,
                                       num_labels, label_lst, average="macro")
            eval_metric_ops = {
                "f1": sentence_f,
                "acc": sentence_accuracy
            }
            return eval_metric_ops

        if output_type == "sess":
            return {
                "eval": {
                    "logits": logits_dict,
                    "total_loss": total_loss,
                    "feature": features,
                    "loss": losses_dict
                }
            }
        elif output_type == "estimator":
            eval_metric_ops = {}
            for key in logits_dict:
                eval_dict = metric_fn(
                    logits_dict[key],
                    features_task_dict[key]["label_ids"])
                for sub_key in eval_dict.keys():
                    eval_key = "{}_{}".format(key, sub_key)
                    eval_metric_ops[eval_key] = eval_dict[sub_key]
            estimator_spec = tf.estimator.EstimatorSpec(
                mode=mode,
                loss=total_loss / task_num,
                eval_metric_ops=eval_metric_ops)
            return estimator_spec
    else:
        raise NotImplementedError()