Example 1
def train_input_fn(FLAGS):

    multi_task_config = Bunch(
        json.load(open(os.path.join(FLAGS.buckets, FLAGS.multitask_dict))))

    vocab_path = FLAGS.vocab_file

    train_file_dict = {}
    test_file_dict = {}
    dev_file_dict = {}
    train_result_dict = {}
    test_result_dict = {}
    dev_result_dict = {}
    label_id_dict = {}
    for task in multi_task_config:
        train_file_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["train_file"])

        test_file_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["test_file"])

        dev_file_dict[task] = os.path.join(FLAGS.buckets,
                                           multi_task_config[task]["dev_file"])

        label_id_dict[task] = os.path.join(FLAGS.buckets,
                                           multi_task_config[task]["label_id"])

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                           do_lower_case=FLAGS.lower_case)

    index = 0
    task_type_id = OrderedDict()
    label2id_dict = {}
    total_examples = []

    for task in (FLAGS.multi_task_type.split(",")):
        if task not in multi_task_config:
            continue
        task_type_id[task] = multi_task_config[task]
        index += 1
        data_type = multi_task_config[task]["data_type"]
        if data_type == "single_sentence":
            classifier_data_api = classifier_processor.SentenceProcessor()
            classifier_data_api.get_labels(label_id_dict[task])
        elif data_type == "sentence_pair":
            classifier_data_api = classifier_processor.SentencePairProcessor()
            classifier_data_api.get_labels(label_id_dict[task])

        train_examples = classifier_data_api.get_train_examples(
            train_file_dict[task], is_shuffle=True)
        label2id_dict[task] = classifier_data_api.label2id

        for item in train_examples:
            tmp = {"example": item, "task": task}
            total_examples.append(tmp)

    print(task_type_id.keys())
    print("==total data==", len(total_examples))
Example 2
def main(_):
	tf.logging.set_verbosity(tf.logging.INFO)

	print(FLAGS.do_whole_word_mask, FLAGS.do_lower_case)

	if FLAGS.tokenizer_type == "spm":
		word_piece_model = os.path.join(FLAGS.buckets, FLAGS.word_piece_model)
		tokenizer = tokenization.SPM(config={
			"word_dict":FLAGS.vocab_file,
			"word_piece_model":word_piece_model
			})
		tokenizer.load_dict()
		tokenizer.load_model()
		tokenizer.add_extra_word()
		tokenizer.build_word_id()
	elif FLAGS.tokenizer_type == "word_piece":
		tokenizer = tokenization.FullTokenizer(
			vocab_file=FLAGS.vocab_file, 
			do_lower_case=FLAGS.do_lower_case,
			do_whole_word_mask=FLAGS.do_whole_word_mask)

	input_files = []
	for input_pattern in FLAGS.input_file.split(","):
		input_files.extend(tf.gfile.Glob(input_pattern))

	tf.logging.info("*** Reading from input files ***")
	for input_file in input_files:
		tf.logging.info("  %s", input_file)

	rng = random.Random(FLAGS.random_seed)

	output_files = FLAGS.output_file.split(",")
	tf.logging.info("*** Writing to output files ***")
	for output_file in output_files:
		tf.logging.info("  %s", output_file)

	start = time.time()

	multi_process(
			input_files=input_files, 
			tokenizer=tokenizer,
			max_seq_length=FLAGS.max_seq_length,
			masked_lm_prob=FLAGS.masked_lm_prob, 
			max_predictions_per_seq=FLAGS.max_predictions_per_seq, 
			short_seq_prob=FLAGS.short_seq_prob,
			output_file=output_file,  # note: bound to the last entry left over from the logging loop above
			process_num=1,
			dupe_factor=FLAGS.dupe_factor,
			random_seed=1234567
		)
	print(time.time()-start, "==total time==")
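As flagged in the inline comment, `output_file` here is just the last value bound by the logging loop, so only one file is ever written. If one call per output file is intended, a sketch of the likely fix (assuming multi_process writes a single output file per call):

	for output_file in output_files:
		multi_process(
			input_files=input_files,
			tokenizer=tokenizer,
			max_seq_length=FLAGS.max_seq_length,
			masked_lm_prob=FLAGS.masked_lm_prob,
			max_predictions_per_seq=FLAGS.max_predictions_per_seq,
			short_seq_prob=FLAGS.short_seq_prob,
			output_file=output_file,
			process_num=1,
			dupe_factor=FLAGS.dupe_factor,
			random_seed=1234567)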
Example 3
    def init_model(self):

        self.graph = tf.Graph()
        with self.graph.as_default():

            init_checkpoint = self.config["init_checkpoint"]
            bert_config = json.load(open(self.config["bert_config"], "r"))

            self.model_config = Bunch(bert_config)
            self.model_config.use_one_hot_embeddings = True
            self.model_config.scope = "bert"
            self.model_config.dropout_prob = 0.1
            self.model_config.label_type = "single_label"

            self.input_queue = Queue(maxsize=self.config.get("batch_size", 20))
            self.output_queue = Queue(
                maxsize=self.config.get("batch_size", 20))

            opt_config = Bunch({
                "init_lr": 2e-5,
                "num_train_steps": 1e30,
                "cycle": False
            })
            model_io_config = Bunch({"fix_lm": False})

            self.num_classes = len(self.label_dict["id2label"])
            self.max_seq_length = self.config["max_length"]

            self.tokenizer = tokenization.FullTokenizer(
                vocab_file=self.config["bert_vocab"], do_lower_case=True)

            self.sess = tf.Session()
            self.model_io_fn = model_io.ModelIO(model_io_config)

            model_fn = bert_classifier_estimator.classifier_model_fn_builder(
                self.model_config,
                self.num_classes,
                init_checkpoint,
                reuse=None,
                load_pretrained=True,
                model_io_fn=self.model_io_fn,
                model_io_config=model_io_config,
                opt_config=opt_config)

            self.estimator = tf.estimator.Estimator(
                model_fn=model_fn, model_dir=self.config["model_dir"])
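A minimal sketch of the config dict that init_model reads (paths and values are hypothetical; the keys are exactly the ones accessed above, and self.label_dict with an "id2label" entry must be set elsewhere in the class):

config = {
    "init_checkpoint": "chinese_L-12_H-768_A-12/bert_model.ckpt",  # hypothetical path
    "bert_config": "chinese_L-12_H-768_A-12/bert_config.json",
    "bert_vocab": "chinese_L-12_H-768_A-12/vocab.txt",
    "model_dir": "classifier_ckpt",  # hypothetical
    "max_length": 128,
    "batch_size": 20,
}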
Example 4
def get_tokenizer(FLAGS, vocab_path, **kargs):
    if FLAGS.tokenizer == "bert":
        tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path, do_lower_case=FLAGS.do_lower_case)
    elif FLAGS.tokenizer == "jieba_char":
        tokenizer = tokenization.Jieba_CHAR(config=kargs.get("config", {}))

        with tf.gfile.Open(vocab_path, "r") as f:
            lines = f.read().splitlines()
            vocab_lst = []
            for line in lines:
                vocab_lst.append(line)
            print(len(vocab_lst))

        tokenizer.load_vocab(vocab_lst)

    return tokenizer
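A hedged usage sketch for get_tokenizer; the SimpleNamespace stub stands in for the real FLAGS object, and its values and the vocab path are hypothetical:

from types import SimpleNamespace

flags_stub = SimpleNamespace(tokenizer="bert", do_lower_case=True)
bert_tokenizer = get_tokenizer(flags_stub, "vocab.txt")  # hypothetical vocab path
tokens = bert_tokenizer.tokenize("a short example sentence")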
Example 5
def main(_):

	# vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
	vocab_path = FLAGS.vocab_file
	train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
	test_file = os.path.join(FLAGS.buckets, FLAGS.test_file)
	dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)

	train_result_file = os.path.join(FLAGS.buckets, FLAGS.train_result_file)
	test_result_file = os.path.join(FLAGS.buckets, FLAGS.test_result_file)
	dev_result_file = os.path.join(FLAGS.buckets, FLAGS.dev_result_file)

	tokenizer = tokenization.FullTokenizer(
		vocab_file=vocab_path, 
		do_lower_case=FLAGS.lower_case)

	if FLAGS.data_type == "lcqmc":
		classifier_data_api = classifier_processor.LCQMCProcessor()

	classifier_data_api.get_labels(FLAGS.label_id)

	train_examples = classifier_data_api.get_train_examples(
		train_file, is_shuffle=False)
	dev_examples = classifier_data_api.get_train_examples(
		dev_file, is_shuffle=False)
	test_examples = classifier_data_api.get_train_examples(
		test_file, is_shuffle=False)

	write_to_tfrecords.convert_pair_order_classifier_examples_to_features(
		train_examples, classifier_data_api.label2id, FLAGS.max_length,
		tokenizer, train_result_file)

	write_to_tfrecords.convert_pair_order_classifier_examples_to_features(
		dev_examples, classifier_data_api.label2id, FLAGS.max_length,
		tokenizer, dev_result_file)

	write_to_tfrecords.convert_pair_order_classifier_examples_to_features(
		test_examples, classifier_data_api.label2id, FLAGS.max_length,
		tokenizer, test_result_file)
Example 6
def main(_):

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.lower_case)

    classifier_data_api = classifier_processor.PornClassifierProcessor()
    classifier_data_api.get_labels(FLAGS.label_id)

    train_examples = classifier_data_api.get_train_examples(FLAGS.train_file)

    write_to_tfrecords.convert_classifier_examples_to_features(
        train_examples, classifier_data_api.label2id, FLAGS.max_length,
        tokenizer, FLAGS.train_result_file)

    test_examples = classifier_data_api.get_train_examples(FLAGS.test_file)
    write_to_tfrecords.convert_classifier_examples_to_features(
        test_examples, classifier_data_api.label2id, FLAGS.max_length,
        tokenizer, FLAGS.test_result_file)
Example 7
def read_data_fn(FLAGS, multi_task_config, task, mode):

    train_file = os.path.join(FLAGS.buckets,
                              multi_task_config[task]["train_file"])

    test_file = os.path.join(FLAGS.buckets,
                             multi_task_config[task]["test_file"])

    dev_file = os.path.join(FLAGS.buckets, multi_task_config[task]["dev_file"])

    label_id = os.path.join(FLAGS.buckets, multi_task_config[task]["label_id"])

    print(train_file, test_file, dev_file, label_id, task, "======")

    tokenizer = tokenization.FullTokenizer(
        vocab_file=multi_task_config[task]["vocab_file"],
        do_lower_case=multi_task_config[task]["do_lower_case"])

    data_type = multi_task_config[task]["data_type"]
    if data_type == "single_sentence":
        classifier_data_api = SentenceProcessor()
    elif data_type == "sentence_pair":
        classifier_data_api = SentencePairProcessor()
    classifier_data_api.get_labels(label_id)

    if mode == "train":
        examples = classifier_data_api.get_train_examples(train_file,
                                                          is_shuffle=True)
    elif mode == "eval":
        examples = classifier_data_api.get_train_examples(dev_file,
                                                          is_shuffle=False)
    elif mode == "test":
        examples = classifier_data_api.get_train_examples(test_file,
                                                          is_shuffle=False)
    else:
        examples = None

    return {
        "examples": examples,
        "tokenizer": tokenizer,
        "label2id": classifier_data_api.label2id
    }
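A usage sketch for read_data_fn (the task name is hypothetical; multi_task_config follows the JSON sketch from Example 1, extended with per-task "vocab_file" and "do_lower_case" keys, which this function also reads):

train_data = read_data_fn(FLAGS, multi_task_config, task="task_a", mode="train")
examples = train_data["examples"]
tokenizer = train_data["tokenizer"]
label2id = train_data["label2id"]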
Example 8
def main(_):
	tf.logging.set_verbosity(tf.logging.INFO)

	print(FLAGS.do_whole_word_mask, FLAGS.do_lower_case)

	if FLAGS.tokenizer_type == "spm":
		word_piece_model = os.path.join(FLAGS.buckets, FLAGS.word_piece_model)
		tokenizer = tokenization.SPM(config={
			"word_dict":FLAGS.vocab_file,
			"word_piece_model":word_piece_model
			})
		tokenizer.load_dict()
		tokenizer.load_model()
		tokenizer.add_extra_word()
		tokenizer.build_word_id()
	elif FLAGS.tokenizer_type == "word_piece"::
		tokenizer = tokenization.FullTokenizer(
			vocab_file=FLAGS.vocab_file, 
			do_lower_case=FLAGS.do_lower_case,
			do_whole_word_mask=FLAGS.do_whole_word_mask)

	input_files = []
	for input_pattern in FLAGS.input_file.split(","):
		input_files.extend(tf.gfile.Glob(input_pattern))

	tf.logging.info("*** Reading from input files ***")
	for input_file in input_files:
		tf.logging.info("  %s", input_file)

	rng = random.Random(FLAGS.random_seed)
	instances = create_training_instances(
			input_files, tokenizer, FLAGS.max_seq_length, FLAGS.dupe_factor,
			FLAGS.short_seq_prob, FLAGS.masked_lm_prob, FLAGS.max_predictions_per_seq,
			rng)

	output_files = FLAGS.output_file.split(",")
	tf.logging.info("*** Writing to output files ***")
	for output_file in output_files:
		tf.logging.info("  %s", output_file)

	write_instance_to_example_files(instances, tokenizer, FLAGS.max_seq_length,
																	FLAGS.max_predictions_per_seq, output_files)
Example 9
def main(_):

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.lower_case)

    classifier_data_api = classifier_processor.ClassificationProcessor()
    classifier_data_api.get_labels(FLAGS.label_id)

    train_examples = classifier_data_api.get_train_examples(FLAGS.train_file)

    write_to_records_pretrain.multi_process(
        examples=train_examples,
        process_num=FLAGS.num_threads,
        label_dict=classifier_data_api.label2id,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_length,
        masked_lm_prob=FLAGS.masked_lm_prob,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        output_file=FLAGS.train_result_file,
        dupe=FLAGS.dupe,
        random_seed=2018,
        feature_type=FLAGS.feature_type,
        log_cycle=FLAGS.log_cycle)

    test_examples = classifier_data_api.get_train_examples(FLAGS.test_file)
    write_to_records_pretrain.multi_process(
        examples=test_examples,
        process_num=FLAGS.num_threads,
        label_dict=classifier_data_api.label2id,
        tokenizer=tokenizer,
        max_seq_length=FLAGS.max_length,
        masked_lm_prob=FLAGS.masked_lm_prob,
        max_predictions_per_seq=FLAGS.max_predictions_per_seq,
        output_file=FLAGS.test_result_file,
        dupe=FLAGS.dupe,
        random_seed=2018,
        feature_type=FLAGS.feature_type,
        log_cycle=FLAGS.log_cycle)

    print(
        "==Succeeded in preparing masked lm with finetuning data for task-finetuning with masked lm regularization"
    )
Example 10
def main(_):

    # vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
    vocab_path = FLAGS.vocab_file
    train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)

    train_result_file = os.path.join(FLAGS.buckets, FLAGS.train_result_file)

    tokenizer = tokenization.FullTokenizer(
        vocab_file=vocab_path,
        do_lower_case=(FLAGS.lower_case == "true"),
        do_whole_word_mask=(FLAGS.do_whole_word_mask == "true"))

    # map the dataset's placeholder token onto a reserved slot in the BERT vocab
    token_mapping = {'<code_num>': '[unused1]'}

    processor = sequence_processor.ProductTitleLanguageModelProcessor()
    processor._read_write(train_file,
                          train_result_file,
                          tokenizer,
                          max_length=FLAGS.max_length,
                          bos='<S>',
                          eos='<T>',
                          token_mapping=token_mapping)
Example 11
flags.DEFINE_integer(
	"max_length", 100,
	"Maximum sequence length.")

flags.DEFINE_string(
	"model_type", None,
	"Model type.")

graph = tf.Graph()
# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
with graph.as_default():
	import json

	tokenizer = tokenization.FullTokenizer(
		vocab_file=FLAGS.vocab_file,
		do_lower_case=True)

	classifier_data_api = classifier_processor.PiarOrderProcessor()

	eval_examples = classifier_data_api.get_test_examples(
		FLAGS.eval_data_file, FLAGS.lang)

	print(eval_examples[0].guid)

	label_tensor = np.asarray([0.18987, 0.20253, 0.60759]).astype(np.float32)

	label_id = json.load(open(FLAGS.label_id, "r"))

	num_choice = 3
	max_seq_length = FLAGS.max_length
Example 12
def main(_):

    import json
    multi_task_config = Bunch(
        json.load(open(os.path.join(FLAGS.buckets, FLAGS.multitask_dict))))

    vocab_path = FLAGS.vocab_file
    # os.path.join(FLAGS.buckets, FLAGS.vocab_file)

    train_file_dict = {}
    test_file_dict = {}
    dev_file_dict = {}
    train_result_dict = {}
    test_result_dict = {}
    dev_result_dict = {}
    label_id_dict = {}
    for task in multi_task_config:
        train_file_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["train_file"])

        test_file_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["test_file"])

        dev_file_dict[task] = os.path.join(FLAGS.buckets,
                                           multi_task_config[task]["dev_file"])

        train_result_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["train_result_file"])

        test_result_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["test_result_file"])

        dev_result_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["dev_result_file"])
        label_id_dict[task] = os.path.join(FLAGS.buckets,
                                           multi_task_config[task]["label_id"])

    print(train_file_dict)

    if FLAGS.lower_case == "True":
        lower_case = True
    else:
        lower_case = False

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                           do_lower_case=lower_case)

    for task in (FLAGS.multi_task_type.split(",")):
        if task not in multi_task_config:
            continue
        data_type = multi_task_config[task]["data_type"]
        if data_type == "single_sentence":
            classifier_data_api = classifier_processor.SentenceProcessor()
            classifier_data_api.get_labels(label_id_dict[task])
        elif data_type == "sentence_pair":
            classifier_data_api = classifier_processor.SentencePairProcessor()
            classifier_data_api.get_labels(label_id_dict[task])

        train_examples = classifier_data_api.get_train_examples(
            train_file_dict[task], is_shuffle=True)
        test_examples = classifier_data_api.get_train_examples(
            test_file_dict[task], is_shuffle=False)
        dev_examples = classifier_data_api.get_train_examples(
            dev_file_dict[task], is_shuffle=False)

        print(classifier_data_api.label2id, task)

        write_to_tfrecords_multitask.convert_multitask_classifier_examples_to_features(
            train_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, train_result_dict[task], task, multi_task_config)

        write_to_tfrecords_multitask.convert_multitask_classifier_examples_to_features(
            test_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, test_result_dict[task], task, multi_task_config)

        write_to_tfrecords_multitask.convert_multitask_classifier_examples_to_features(
            dev_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, dev_result_dict[task], task, multi_task_config)
Example 13
def main(_):

    # with tf.gfile.Open(FLAGS.vocab_file, "r") as f:
    # 	vocab_lst = []
    # 	for line in f:
    # 		vocab_lst.append(line.strip())

    vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
    train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
    test_file = os.path.join(FLAGS.buckets, FLAGS.test_file)
    dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)

    train_result_file = os.path.join(FLAGS.buckets, FLAGS.train_result_file)
    test_result_file = os.path.join(FLAGS.buckets, FLAGS.test_result_file)
    dev_result_file = os.path.join(FLAGS.buckets, FLAGS.dev_result_file)

    corpus_vocab_path = os.path.join(FLAGS.buckets, FLAGS.corpus_vocab_path)
    unsupervised_distillation_file = os.path.join(
        FLAGS.buckets, FLAGS.unsupervised_distillation_file)
    supervised_distillation_file = os.path.join(
        FLAGS.buckets, FLAGS.supervised_distillation_file)

    if FLAGS.tokenizer_type == "jieba":
        tokenizer = tokenization.Jieba_CHAR(config=FLAGS.config)
    elif FLAGS.tokenizer_type == "full_bpe":
        tokenizer = tokenization.FullTokenizer(
            vocab_file=vocab_path,
            do_lower_case=(FLAGS.lower_case == "true"))

    if FLAGS.tokenizer_type == "jieba":
        print(FLAGS.with_char)
        with tf.gfile.Open(vocab_path, "r") as f:
            lines = f.read().splitlines()
            vocab_lst = []
            for line in lines:
                vocab_lst.append(line)
            print(len(vocab_lst))

        tokenizer.load_vocab(vocab_lst)

    print("==not apply rule==")

    if FLAGS.distillation_type == "prob":
        classifier_data_api = classifier_processor.FasttextDistillationProcessor(
        )
    elif FLAGS.distillation_type == "structure":
        classifier_data_api = classifier_processor.FasttextStructureDistillationProcessor(
        )
    classifier_data_api.get_labels(FLAGS.label_id)

    train_examples = classifier_data_api.get_supervised_distillation_examples(
        train_file, supervised_distillation_file, is_shuffle=True)

    if FLAGS.tokenizer_type == "jieba":
        vocab_filter.vocab_filter(train_examples, vocab_lst, tokenizer,
                                  FLAGS.predefined_vocab_size,
                                  corpus_vocab_path)

        tokenizer_corpus = tokenization.Jieba_CHAR(config=FLAGS.config)

        with tf.gfile.Open(corpus_vocab_path, "r") as f:
            lines = f.read().splitlines()
            vocab_lst = []
            for line in lines:
                vocab_lst.append(line)
            print(len(vocab_lst))

        tokenizer_corpus.load_vocab(vocab_lst)
    elif FLAGS.tokenizer_type == "full_bpe":
        tokenizer_corpus = tokenizer

    dev_examples = classifier_data_api.get_unsupervised_distillation_examples(
        dev_file, unsupervised_distillation_file, is_shuffle=False)

    import random
    if FLAGS.if_add_unlabeled_distillation == "yes":
        total_train_examples = train_examples + dev_examples
    else:
        total_train_examples = train_examples
    random.shuffle(total_train_examples)

    if FLAGS.tokenizer_type == "jieba":

        write_to_tfrecords.convert_distillation_classifier_examples_to_features(
            total_train_examples, classifier_data_api.label2id,
            FLAGS.max_length, tokenizer_corpus, train_result_file,
            FLAGS.with_char, FLAGS.char_len)

        write_to_tfrecords.convert_distillation_classifier_examples_to_features(
            dev_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, dev_result_file, FLAGS.with_char, FLAGS.char_len)

        test_examples = classifier_data_api.get_train_examples(
            test_file, is_shuffle=False)
        write_to_tfrecords.convert_distillation_classifier_examples_to_features(
            test_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, test_result_file, FLAGS.with_char,
            FLAGS.char_len)
    elif FLAGS.tokenizer_type == "full_bpe":
        write_to_tfrecords.convert_bert_distillation_classifier_examples_to_features(
            total_train_examples, classifier_data_api.label2id,
            FLAGS.max_length, tokenizer_corpus, train_result_file,
            FLAGS.with_char, FLAGS.char_len)

        write_to_tfrecords.convert_bert_distillation_classifier_examples_to_features(
            dev_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, dev_result_file, FLAGS.with_char, FLAGS.char_len)

        test_examples = classifier_data_api.get_train_examples(
            test_file, is_shuffle=False)
        write_to_tfrecords.convert_bert_distillation_classifier_examples_to_features(
            test_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, test_result_file, FLAGS.with_char,
            FLAGS.char_len)
Example 14
def main(_):
    graph = tf.Graph()
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    with graph.as_default():

        tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                               do_lower_case=True)

        classifier_data_api = classifier_processor.PiarOrderProcessor()

        eval_examples = classifier_data_api.get_test_examples(
            FLAGS.eval_data_file, FLAGS.lang)
        print(len(eval_examples), eval_examples[0:10])

        label_id = json.load(open(FLAGS.label_id, "r"))

        num_choice = FLAGS.num_classes
        max_seq_length = FLAGS.max_length

        write_to_tfrecords.convert_pair_order_classifier_examples_to_features(
            eval_examples, label_id["label2id"], max_seq_length, tokenizer,
            FLAGS.output_file)

        os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.gpu_id
        sess = tf.Session()

        config = json.load(open(FLAGS.config_file, "r"))

        student_config = json.load(open(FLAGS.student_config_file, "r"))

        student_config = Bunch(student_config)
        # student_config.use_one_hot_embeddings = True
        # student_config.scope = "student/bert"
        # student_config.dropout_prob = 0.1
        # student_config.label_type = "single_label"
        # student_config.init_checkpoint = FLAGS.student_init_checkpoint

        temperature = student_config.temperature
        distill_ratio = student_config.distill_ratio

        # json.dump(student_config, open(FLAGS.model_output+"/student_config.json", "w"))

        teacher_config = Bunch(config)
        teacher_config.use_one_hot_embeddings = True
        teacher_config.scope = "teacher/bert"
        teacher_config.dropout_prob = 0.1
        teacher_config.label_type = "single_label"
        teacher_config.init_checkpoint = FLAGS.teacher_init_checkpoint

        # json.dump(teacher_config, open(FLAGS.model_output+"/teacher_config.json", "w"))

        model_config_dict = {
            "student": student_config,
            "teacher": teacher_config
        }
        init_checkpoint_dict = {
            "student": FLAGS.student_init_checkpoint,
            "teacher": FLAGS.teacher_init_checkpoint
        }

        num_train_steps = int(FLAGS.train_size / FLAGS.batch_size *
                              FLAGS.epoch)
        num_warmup_steps = int(num_train_steps * 0.1)

        num_storage_steps = int(FLAGS.train_size / FLAGS.batch_size)

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": 1e-5,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps
        })

        model_io_config = Bunch({"fix_lm": False})

        model_io_fn = model_io.ModelIO(model_io_config)

        model_eval_fn = distillation.distillation_model_fn(
            model_config_dict=model_config_dict,
            num_labels=num_choice,
            init_checkpoint_dict=init_checkpoint_dict,
            model_reuse=None,
            load_pretrained={
                "teacher": True,
                "student": True
            },
            model_io_fn=model_io_fn,
            model_io_config=model_io_config,
            opt_config=opt_config,
            student_input_name=["a", "b"],
            teacher_input_name=["a", "b"],
            unlabel_input_name=["ua", "ub"],
            temperature=temperature,
            exclude_scope_dict={
                "student": "",
                "teacher": "teacher"
            },
            not_storage_params=["adam_m", "adam_v"],
            distillation_weight={
                "label": distill_ratio,
                "unlabel": distill_ratio
            },
            if_distill_unlabeled=False)

        def metric_fn(features, logits):
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.exp(tf.nn.log_softmax(logits))
            return {
                "pred_label": pred_label,
                "qas_id": features["qas_id"],
                "prob": prob
            }

        name_to_features = {
            "input_ids_a": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask_a": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids_a": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_ids_b": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask_b": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids_b": tf.FixedLenFeature([max_seq_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
            "qas_id": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example.
			"""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size
        # train_features = tf_data_utils.train_input_fn("/data/xuht/wsdm19/data/train.tfrecords",
        #                             _decode_record, name_to_features, params)
        # eval_features = tf_data_utils.eval_input_fn("/data/xuht/wsdm19/data/dev.tfrecords",
        #                             _decode_record, name_to_features, params)

        eval_features = tf_data_utils.eval_input_fn(FLAGS.output_file,
                                                    _decode_record,
                                                    name_to_features, params)

        [_, eval_loss, eval_per_example_loss,
         eval_logits] = model_eval_fn(eval_features, [],
                                      tf.estimator.ModeKeys.EVAL)
        result = metric_fn(eval_features, eval_logits)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        def eval_fn(result):
            i = 0
            pred_label, qas_id, prob = [], [], []
            while True:
                try:
                    eval_result = sess.run(result)
                    pred_label.extend(eval_result["pred_label"].tolist())
                    qas_id.extend(eval_result["qas_id"].tolist())
                    prob.extend(eval_result["prob"].tolist())
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            return pred_label, qas_id, prob

        print("===========begin to eval============")
        [pred_label, qas_id, prob] = eval_fn(result)
        result = dict(zip(qas_id, pred_label))

        print(FLAGS.result_file.split("."))
        tmp_output = FLAGS.result_file.split(".")[0] + ".json"
        print(tmp_output, "===temp output===")
        json.dump({
            "id": qas_id,
            "label": pred_label,
            "prob": prob
        }, open(tmp_output, "w"))

        print(len(result), "=====valid result======")

        import pandas as pd
        df = pd.read_csv(FLAGS.eval_data_file)

        output = {}
        for index in range(df.shape[0]):
            output[df.loc[index]["id"]] = ""

        final_output = []

        cnt = 0
        for key in output:
            if key in result:
                final_output.append({
                    "Id":
                    key,
                    "Category":
                    label_id["id2label"][str(result[key])]
                })
                cnt += 1
            else:
                final_output.append({"Id": key, "Category": "unrelated"})

        df_out = pd.DataFrame(final_output)
        df_out.to_csv(FLAGS.result_file)

        print(len(output), cnt, len(final_output),
              "======num of results from model==========")
Example 15
def main(_):

	graph = tf.Graph()
	# from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
	with graph.as_default():
		import json

		tokenizer = tokenization.FullTokenizer(
			vocab_file=FLAGS.vocab_file,
			do_lower_case=True)

		classifier_data_api = classifier_processor.PiarInteractionProcessor()

		eval_examples = classifier_data_api.get_test_examples(
			FLAGS.eval_data_file, FLAGS.lang)

		print(eval_examples[0].guid)

		label_tensor = None

		label_id = json.load(open(FLAGS.label_id, "r"))

		num_choice = 3

		write_to_tfrecords.convert_interaction_classifier_examples_to_features_v1(
			eval_examples, label_id["label2id"], FLAGS.max_length,
			tokenizer, FLAGS.output_file)

		config = json.load(open(FLAGS.config_file, "r"))
		init_checkpoint = FLAGS.init_checkpoint

		max_seq_length = FLAGS.max_length * 2 + 3  # [CLS] + a + [SEP] + b + [SEP]

		print("===init checkoutpoint==={}".format(init_checkpoint))

		config = Bunch(config)
		config.use_one_hot_embeddings = True
		config.scope = "esim/bert"
		config.dropout_prob = 0.2
		config.label_type = "single_label"
		config.lstm_dim = 128
		config.num_heads = 12
		config.num_units = 768
		
		# os.environ["CUDA_VISIBLE_DEVICES"] = "2"
		sess = tf.Session()
		
		opt_config = Bunch({"init_lr":(5e-5), 
							"num_train_steps":0,
							"num_warmup_steps":0,
							"train_op":"adam"})
		model_io_config = Bunch({"fix_lm":False})
		
		model_io_fn = model_io.ModelIO(model_io_config)

		model_function = bert_esim.classifier_attn_model_fn_builder
		model_eval_fn = model_function(
									config, 
									num_choice, 
									init_checkpoint, 
									model_reuse=None, 
									load_pretrained=True,
									model_io_fn=model_io_fn,
									model_io_config=model_io_config, 
									opt_config=opt_config,
									input_name=["a", "b"],
									label_tensor=label_tensor,
									not_storage_params=["adam", "adam_1"],
									exclude_scope_dict={"task":"esim"})
		
		# def metric_fn(features, logits):
		#     print(logits.get_shape(), "===logits shape===")
		#     pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
		#     return {"pred_label":pred_label, "qas_id":features["qas_id"]}

		def metric_fn(features, logits):
			print(logits.get_shape(), "===logits shape===")
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			prob = tf.exp(tf.nn.log_softmax(logits))
			return {"pred_label":pred_label, 
					"qas_id":features["qas_id"],
					"prob":prob}
		
		name_to_features = {
				"input_ids_a":
						tf.FixedLenFeature([max_seq_length], tf.int64),
				"input_mask_a":
						tf.FixedLenFeature([max_seq_length], tf.int64),
				"segment_ids_a":
						tf.FixedLenFeature([max_seq_length], tf.int64),
				"input_ids_b":
						tf.FixedLenFeature([max_seq_length], tf.int64),
				"input_mask_b":
						tf.FixedLenFeature([max_seq_length], tf.int64),
				"segment_ids_b":
						tf.FixedLenFeature([max_seq_length], tf.int64),
				"label_ids":
						tf.FixedLenFeature([], tf.int64),
				"qas_id":
						tf.FixedLenFeature([], tf.int64),
		}
		
		def _decode_record(record, name_to_features):
			"""Decodes a record to a TensorFlow example.
			"""
			example = tf.parse_single_example(record, name_to_features)

			# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
			# So cast all int64 to int32.
			for name in list(example.keys()):
				t = example[name]
				if t.dtype == tf.int64:
					t = tf.to_int32(t)
				example[name] = t
			
			return example 

		params = Bunch({})
		params.epoch = 2
		params.batch_size = 32

		eval_features = tf_data_utils.eval_input_fn(FLAGS.output_file,
									_decode_record, name_to_features, params)
		
		[_, eval_loss, eval_per_example_loss, eval_logits] = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL)
		result = metric_fn(eval_features, eval_logits)
		
		model_io_fn.set_saver()
		
		init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
		sess.run(init_op)

		model_io_fn.load_model(sess, init_checkpoint)
		print(" ==succeeded in loading model== ")

		def eval_fn(result):
			i = 0
			pred_label, qas_id, prob = [], [], []
			while True:
				try:
					eval_result = sess.run(result)
					pred_label.extend(eval_result["pred_label"].tolist())
					qas_id.extend(eval_result["qas_id"].tolist())
					prob.extend(eval_result["prob"].tolist())
					i += 1
				except tf.errors.OutOfRangeError:
					print("End of dataset")
					break
			return pred_label, qas_id, prob
		
		print("===========begin to eval============")
		[pred_label, qas_id, prob] = eval_fn(result)
		result = dict(zip(qas_id, pred_label))

		print(FLAGS.result_file.split("."))
		tmp_output = FLAGS.result_file.split(".")[0] + ".json"
		print(tmp_output, "===temp output===")
		json.dump({"id":qas_id,
					"label":pred_label,
					"prob":prob},
					open(tmp_output, "w"))

		print(len(result), "=====valid result======")

		import pandas as pd
		df = pd.read_csv(FLAGS.eval_data_file)

		output = {}
		for index in range(df.shape[0]):
			output[df.loc[index]["id"]] = ""

		final_output = []

		cnt = 0
		for key in output:
			if key in result:
				final_output.append({"Id":key, 
					"Category":label_id["id2label"][str(result[key])]})
				cnt += 1
			else:
				final_output.append({"Id":key, "Category":"unrelated"})
		
		df_out = pd.DataFrame(final_output)
		df_out.to_csv(FLAGS.result_file)

		print(len(output), cnt, len(final_output), "======num of results from model==========")
Example 16
import sys, os
sys.path.append("..")

import numpy as np
import tensorflow as tf
from example import bert_classifier
from bunch import Bunch
from example import feature_writer, write_to_tfrecords, classifier_processor
from porn_classification import classifier_processor  # note: shadows the classifier_processor imported above
from data_generator import tokenization
from data_generator import tf_data_utils
from model_io import model_io
import json

tokenizer = tokenization.FullTokenizer(
    vocab_file="/data/xuht/chinese_L-12_H-768_A-12/vocab.txt",
    do_lower_case=True)

with open(
        "/data/xuht/websiteanalyze-data-seqing20180821/data/rule/mined_porn_domain_adaptation_v2.txt",
        "r") as frobj:
    lines = frobj.read().splitlines()
    freq_dict = []
    for line in lines:
        content = line.split("&&&&")
        word = "".join(content[0].split("&"))
        label = "rule"
        tmp = {}
        tmp["word"] = word
        tmp["label"] = "rule"
        freq_dict.append(tmp)
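For reference, how one hypothetical line of the mined rule file is parsed by the loop above:

# line = "a&b&c&&&&source_info"            (hypothetical)
# content = line.split("&&&&")             -> ["a&b&c", "source_info"]
# word = "".join(content[0].split("&"))    -> "abc"
# appended to freq_dict as {"word": "abc", "label": "rule"}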
Example 17
def main(_):

    import json
    multi_task_config = Bunch(
        json.load(open(os.path.join(FLAGS.buckets, FLAGS.multitask_dict))))

    vocab_path = FLAGS.vocab_file
    # os.path.join(FLAGS.buckets, FLAGS.vocab_file)

    train_file_dict = {}
    test_file_dict = {}
    dev_file_dict = {}
    train_result_dict = {}
    test_result_dict = {}
    dev_result_dict = {}
    label_id_dict = {}
    for task in multi_task_config:
        train_file_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["train_file"])

        test_file_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["test_file"])

        dev_file_dict[task] = os.path.join(FLAGS.buckets,
                                           multi_task_config[task]["dev_file"])

        train_result_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["train_result_file"])

        test_result_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["test_result_file"])

        dev_result_dict[task] = os.path.join(
            FLAGS.buckets, multi_task_config[task]["dev_result_file"])
        label_id_dict[task] = os.path.join(FLAGS.buckets,
                                           multi_task_config[task]["label_id"])

    print(train_file_dict)

    if FLAGS.lower_case == "True":
        lower_case = True
    else:
        lower_case = False

    tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                           do_lower_case=lower_case)

    total_examples = []

    task_type_id = OrderedDict()
    label2id_dict = {}

    task_dataset = {}

    index = 0
    for task in (FLAGS.multi_task_type.split(",")):
        if task not in multi_task_config:
            continue
        task_type_id[task] = multi_task_config[task]
        index += 1
        data_type = multi_task_config[task]["data_type"]
        if data_type == "single_sentence":
            classifier_data_api = classifier_processor.SentenceProcessor()
            classifier_data_api.get_labels(label_id_dict[task])
        elif data_type == "sentence_pair":
            classifier_data_api = classifier_processor.SentencePairProcessor()
            classifier_data_api.get_labels(label_id_dict[task])

        train_examples = classifier_data_api.get_train_examples(
            train_file_dict[task], is_shuffle=True)
        label2id_dict[task] = classifier_data_api.label2id

        for item in train_examples:
            tmp = {"example": item, "task": task}
            total_examples.append(tmp)

    print(task_type_id.keys())
    print("==total data==", len(total_examples))

    # write 10 independently shuffled copies of the merged multi-task training set
    for i in range(10):
        random.shuffle(total_examples)
        write_to_tfrecords_multitask.convert_multitask_classifier_merged_examples_to_features(
            total_examples, label2id_dict, FLAGS.max_length, tokenizer,
            os.path.join(FLAGS.buckets, FLAGS.output_path,
                         "train_tfrecords_{}".format(i)), task_type_id)
Example 18
def main(_):

    tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                           do_lower_case=FLAGS.lower_case)

    if FLAGS.if_rule != "rule":
        print("==not apply rule==")
        classifier_data_api = classifier_processor.PornClassifierProcessor()
        classifier_data_api.get_labels(FLAGS.label_id)

        train_examples = classifier_data_api.get_train_examples(
            FLAGS.train_file, is_shuffle=False)

        write_to_tfrecords.convert_classifier_examples_to_features(
            train_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, FLAGS.train_result_file)

        test_examples = classifier_data_api.get_train_examples(
            FLAGS.test_file, is_shuffle=False)
        write_to_tfrecords.convert_classifier_examples_to_features(
            test_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, FLAGS.test_result_file)
    elif FLAGS.if_rule == "rule":
        print("==apply rule==")
        with open(FLAGS.rule_word_path, "r") as frobj:
            lines = frobj.read().splitlines()
            freq_dict = []
            for line in lines:
                content = line.split("&&&&")
                word = "".join(content[0].split("&"))
                label = "rule"
                tmp = {}
                tmp["word"] = word
                tmp["label"] = "rule"
                freq_dict.append(tmp)
            print(len(freq_dict))
            json.dump(freq_dict, open(FLAGS.rule_word_dict, "w"))
        from data_generator import rule_detector

        # label_dict = {"label2id":{"正常":0,"rule":1}, "id2label":{0:"正常", 1:"rule"}}
        # json.dump(label_dict, open("/data/xuht/websiteanalyze-data-seqing20180821/data/rule/rule_label_dict.json", "w"))

        rule_config = {
            "keyword_path": FLAGS.rule_word_dict,
            "background_label": "正常",
            "label_dict": FLAGS.rule_label_dict
        }
        rule_api = rule_detector.RuleDetector(rule_config)
        rule_api.load(tokenizer)

        classifier_data_api = classifier_processor.PornClassifierProcessor()
        classifier_data_api.get_labels(FLAGS.label_id)

        train_examples = classifier_data_api.get_train_examples(
            FLAGS.train_file, is_shuffle=True)

        write_to_tfrecords.convert_classifier_examples_with_rule_to_features(
            train_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, rule_api, FLAGS.train_result_file)

        test_examples = classifier_data_api.get_train_examples(
            FLAGS.test_file, is_shuffle=False)
        write_to_tfrecords.convert_classifier_examples_with_rule_to_features(
            test_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer, rule_api, FLAGS.test_result_file)
Example 19
File: eval.py Project: P79N6A/BERT
def main(_):

	hvd.init()

	sess_config = tf.ConfigProto()
	sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

	graph = tf.Graph()
	with graph.as_default():
		import json
				
		config = json.load(open(FLAGS.config_file, "r"))
		init_checkpoint = FLAGS.init_checkpoint

		config = Bunch(config)
		config.use_one_hot_embeddings = True
		config.scope = "bert"
		config.dropout_prob = 0.1
		config.label_type = "single_label"
		
		if FLAGS.if_shard == "0":
			train_size = FLAGS.train_size
			epoch = int(FLAGS.epoch / hvd.size())
		elif FLAGS.if_shard == "1":
			train_size = int(FLAGS.train_size/hvd.size())
			epoch = FLAGS.epoch

		tokenizer = tokenization.FullTokenizer(
			vocab_file=FLAGS.vocab_file,
			do_lower_case=FLAGS.lower_case)

		classifier_data_api = classifier_processor.EvaluationProcessor()
		classifier_data_api.get_labels(FLAGS.label_id)

		train_examples = classifier_data_api.get_train_examples(FLAGS.train_file)

		write_to_tfrecords.convert_classifier_examples_to_features(
			train_examples, classifier_data_api.label2id, FLAGS.max_length,
			tokenizer, FLAGS.eval_data_file)

		init_lr = 2e-5

		num_train_steps = int(
			train_size / FLAGS.batch_size * epoch)
		num_warmup_steps = int(num_train_steps * 0.1)

		num_storage_steps = int(train_size / FLAGS.batch_size)

		print(" model type {}".format(FLAGS.model_type))

		print(num_train_steps, num_warmup_steps, "=============")
		
		opt_config = Bunch({"init_lr":init_lr/hvd.size(), 
							"num_train_steps":num_train_steps,
							"num_warmup_steps":num_warmup_steps})

		sess = tf.Session(config=sess_config)

		model_io_config = Bunch({"fix_lm":False})
		
		model_io_fn = model_io.ModelIO(model_io_config)

		optimizer_fn = optimizer.Optimizer(opt_config)
		
		num_classes = FLAGS.num_classes
		
		model_eval_fn = bert_classifier.classifier_model_fn_builder(config, num_classes, init_checkpoint, 
												reuse=False, 
												load_pretrained=True,
												model_io_fn=model_io_fn,
												optimizer_fn=optimizer_fn,
												model_io_config=model_io_config, 
												opt_config=opt_config)
		
		def metric_fn(features, logits, loss):
			print(logits.get_shape(), "===logits shape===")
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			prob = tf.nn.softmax(logits)
			correct = tf.equal(
				tf.cast(pred_label, tf.int32),
				tf.cast(features["label_ids"], tf.int32)
			)
			accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
			return {"accuracy":accuracy, "loss":loss, "pred_label":pred_label, 
				"label_ids":features["label_ids"],
				"prob":prob}
		
		name_to_features = {
				"input_ids":
						tf.FixedLenFeature([FLAGS.max_length], tf.int64),
				"input_mask":
						tf.FixedLenFeature([FLAGS.max_length], tf.int64),
				"segment_ids":
						tf.FixedLenFeature([FLAGS.max_length], tf.int64),
				"label_ids":
						tf.FixedLenFeature([], tf.int64),
		}

		def _decode_record(record, name_to_features):
			"""Decodes a record to a TensorFlow example.
			"""
			example = tf.parse_single_example(record, name_to_features)

			# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
			# So cast all int64 to int32.
			for name in list(example.keys()):
				t = example[name]
				if t.dtype == tf.int64:
					t = tf.to_int32(t)
				example[name] = t
			return example 

		params = Bunch({})
		params.epoch = FLAGS.epoch
		params.batch_size = FLAGS.batch_size

		eval_features = tf_data_utils.eval_input_fn(FLAGS.eval_data_file,
									_decode_record, name_to_features, params, if_shard=FLAGS.if_shard)
		
		[_, eval_loss, eval_per_example_loss, eval_logits] = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL)
		result = metric_fn(eval_features, eval_logits, eval_loss)
		
		init_op = tf.group(tf.global_variables_initializer(), 
					tf.local_variables_initializer())
		sess.run(init_op)

		sess.run(hvd.broadcast_global_variables(0))

		print("===horovod rank==={}".format(hvd.rank()))
		
		def eval_fn(result):
			i = 0
			total_accuracy = 0
			label, label_id, prob = [], [], []
			while True:
				try:
					eval_result = sess.run(result)
					total_accuracy += eval_result["accuracy"]
					label_id.extend(eval_result["label_ids"])
					label.extend(eval_result["pred_label"])
					prob.extend(eval_result["prob"])
					i += 1
				except tf.errors.OutOfRangeError:
					print("End of dataset")
					break
			macro_f1 = f1_score(label_id, label, average="macro")
			micro_f1 = f1_score(label_id, label, average="micro")
			macro_precision = precision_score(label_id, label, average="macro")
			micro_precision = precision_score(label_id, label, average="micro")
			macro_recall = recall_score(label_id, label, average="macro")
			micro_recall = recall_score(label_id, label, average="micro")
			accuracy = accuracy_score(label_id, label)
			print("test accuracy {} macro_f1 score {} micro_f1 {} accuracy {}".format(total_accuracy/ i, 
																					macro_f1,  micro_f1, accuracy))
			return total_accuracy/ i, label_id, label, prob
		
		import time
		start = time.time()
		
		acc, true_label, pred_label, prob = eval_fn(result)
		end = time.time()
		print("==total time {} numbers of devices {}".format(end - start, hvd.size()))
		if hvd.rank() == 0:
			import _pickle as pkl
			pkl.dump({"true_label":true_label, 
						"pred_label":pred_label,
						"prob":prob}, 
						open(FLAGS.model_output+"/predict.pkl", "wb"))
Example 20
def main(_):

    # tokenizer = tokenization.Jieba_CHAR(
    # 	config=FLAGS.config)

    # with tf.gfile.Open(FLAGS.vocab_file, "r") as f:
    # 	vocab_lst = []
    # 	for line in f:
    # 		vocab_lst.append(line.strip())

    # vocab_path = os.path.join(FLAGS.buckets, FLAGS.vocab_file)
    vocab_path = FLAGS.vocab_file
    train_file = os.path.join(FLAGS.buckets, FLAGS.train_file)
    test_file = os.path.join(FLAGS.buckets, FLAGS.test_file)
    dev_file = os.path.join(FLAGS.buckets, FLAGS.dev_file)

    train_result_file = os.path.join(FLAGS.buckets, FLAGS.train_result_file)
    test_result_file = os.path.join(FLAGS.buckets, FLAGS.test_result_file)
    dev_result_file = os.path.join(FLAGS.buckets, FLAGS.dev_result_file)

    corpus_vocab_path = os.path.join(FLAGS.buckets, FLAGS.corpus_vocab_path)

    if FLAGS.tokenizer_type == "jieba":
        tokenizer = tokenization.Jieba_CHAR(config=FLAGS.config)
    elif FLAGS.tokenizer_type == "full_bpe":
        tokenizer = tokenization.FullTokenizer(vocab_file=vocab_path,
                                               do_lower_case=FLAGS.lower_case)

    if FLAGS.tokenizer_type == "jieba":
        print(FLAGS.with_char)
        with tf.gfile.Open(vocab_path, "r") as f:
            lines = f.read().splitlines()
            vocab_lst = []
            for line in lines:
                vocab_lst.append(line)
            print(len(vocab_lst))

        tokenizer.load_vocab(vocab_lst)

    print("==not apply rule==")
    if FLAGS.data_type == "fasttext":
        classifier_data_api = classifier_processor.FasttextClassifierProcessor(
        )

    classifier_data_api.get_labels(FLAGS.label_id)

    train_examples = classifier_data_api.get_train_examples(train_file,
                                                            is_shuffle=True)
    print("==total train examples==", len(train_examples))

    test_examples = classifier_data_api.get_train_examples(test_file,
                                                           is_shuffle=False)
    print("==total test examples==", len(test_examples))

    dev_examples = classifier_data_api.get_train_examples(dev_file,
                                                          is_shuffle=False)
    print("==total dev examples==", len(dev_examples))

    if FLAGS.tokenizer_type == "jieba":
        vocab_filter.vocab_filter(
            train_examples + test_examples + dev_examples, vocab_lst,
            tokenizer, FLAGS.predefined_vocab_size, corpus_vocab_path)

        tokenizer_corpus = tokenization.Jieba_CHAR(config=FLAGS.config)

        with tf.gfile.Open(corpus_vocab_path, "r") as f:
            lines = f.read().splitlines()
            vocab_lst = []
            for line in lines:
                vocab_lst.append(line)
            print(len(vocab_lst))
            # print(vocab_lst)

        tokenizer_corpus.load_vocab(vocab_lst)
    elif FLAGS.tokenizer_type == "full_bpe":
        tokenizer_corpus = tokenizer

    if FLAGS.tokenizer_type == "jieba":
        write_to_tfrecords.convert_distillation_classifier_examples_to_features(
            train_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, train_result_file, FLAGS.with_char,
            FLAGS.char_len)

        write_to_tfrecords.convert_distillation_classifier_examples_to_features(
            test_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, test_result_file, FLAGS.with_char,
            FLAGS.char_len)

        write_to_tfrecords.convert_distillation_classifier_examples_to_features(
            dev_examples, classifier_data_api.label2id, FLAGS.max_length,
            tokenizer_corpus, dev_result_file, FLAGS.with_char, FLAGS.char_len)
    elif FLAGS.tokenizer_type == "full_bpe":
        write_to_tfrecords.convert_bert_distillation_classifier_examples_to_features(
            train_examples,
            classifier_data_api.label2id,
            FLAGS.max_length,
            tokenizer_corpus,
            train_result_file,
            FLAGS.with_char,
            FLAGS.char_len,
            label_type=FLAGS.label_type)

        write_to_tfrecords.convert_bert_distillation_classifier_examples_to_features(
            dev_examples,
            classifier_data_api.label2id,
            FLAGS.max_length,
            tokenizer_corpus,
            dev_result_file,
            FLAGS.with_char,
            FLAGS.char_len,
            label_type=FLAGS.label_type)

        test_examples = classifier_data_api.get_train_examples(
            test_file, is_shuffle=False)
        write_to_tfrecords.convert_bert_distillation_classifier_examples_to_features(
            test_examples,
            classifier_data_api.label2id,
            FLAGS.max_length,
            tokenizer_corpus,
            test_result_file,
            FLAGS.with_char,
            FLAGS.char_len,
            label_type=FLAGS.label_type)