Example #1
0
def main(_):
	"""Horovod-distributed evaluation entry point for a BERT classifier.

	Builds the eval graph from FLAGS (config file, checkpoint, tfrecords),
	runs the evaluation loop to dataset exhaustion and, on rank 0 only,
	dumps predictions to ``FLAGS.model_output + "/predict.pkl"``.

	Relies on module-level names (hvd, tf, FLAGS, Bunch, tokenization,
	classifier_processor, write_to_tfrecords, model_io, optimizer,
	bert_classifier, tf_data_utils, sklearn metric functions).
	"""
	hvd.init()

	# Pin one GPU per Horovod process.
	sess_config = tf.ConfigProto()
	sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

	graph = tf.Graph()
	with graph.as_default():
		import json

		config = json.load(open(FLAGS.config_file, "r"))
		init_checkpoint = FLAGS.init_checkpoint

		config = Bunch(config)
		config.use_one_hot_embeddings = True
		config.scope = "bert"
		config.dropout_prob = 0.1
		config.label_type = "single_label"

		# if_shard == "0": every worker reads the full dataset, so split epochs;
		# if_shard == "1": the dataset is sharded across workers, so split size.
		if FLAGS.if_shard == "0":
			train_size = FLAGS.train_size
			epoch = int(FLAGS.epoch / hvd.size())
		elif FLAGS.if_shard == "1":
			train_size = int(FLAGS.train_size/hvd.size())
			epoch = FLAGS.epoch
		else:
			# Fix: previously fell through and crashed later with an
			# UnboundLocalError on `train_size`/`epoch`.
			raise ValueError("unsupported FLAGS.if_shard: {}".format(FLAGS.if_shard))

		tokenizer = tokenization.FullTokenizer(
		vocab_file=FLAGS.vocab_file, 
		do_lower_case=FLAGS.lower_case)

		classifier_data_api = classifier_processor.EvaluationProcessor()
		classifier_data_api.get_labels(FLAGS.label_id)

		train_examples = classifier_data_api.get_train_examples(FLAGS.train_file)

		# NOTE(review): examples read from FLAGS.train_file are serialized to
		# FLAGS.eval_data_file — presumably this eval-only script featurizes
		# the eval set through the "train" processor API; confirm flag wiring.
		write_to_tfrecords.convert_classifier_examples_to_features(train_examples,
																classifier_data_api.label2id,
																FLAGS.max_length,
																tokenizer,
																FLAGS.eval_data_file)

		init_lr = 2e-5

		num_train_steps = int(
			train_size / FLAGS.batch_size * epoch)
		num_warmup_steps = int(num_train_steps * 0.1)

		num_storage_steps = int(train_size / FLAGS.batch_size)

		print(" model type {}".format(FLAGS.model_type))

		print(num_train_steps, num_warmup_steps, "=============")

		# NOTE(review): learning rate is scaled *down* by world size; the usual
		# Horovod convention is to scale it up. Kept as-is to preserve behavior.
		opt_config = Bunch({"init_lr":init_lr/hvd.size(), 
							"num_train_steps":num_train_steps,
							"num_warmup_steps":num_warmup_steps})

		sess = tf.Session(config=sess_config)

		model_io_config = Bunch({"fix_lm":False})
		
		model_io_fn = model_io.ModelIO(model_io_config)

		optimizer_fn = optimizer.Optimizer(opt_config)
		
		num_classes = FLAGS.num_classes
		
		model_eval_fn = bert_classifier.classifier_model_fn_builder(config, num_classes, init_checkpoint, 
												reuse=False, 
												load_pretrained=True,
												model_io_fn=model_io_fn,
												optimizer_fn=optimizer_fn,
												model_io_config=model_io_config, 
												opt_config=opt_config)
		
		def metric_fn(features, logits, loss):
			"""Builds eval metric tensors: accuracy, loss, predictions, probs."""
			print(logits.get_shape(), "===logits shape===")
			pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
			prob = tf.nn.softmax(logits)
			# Fix: dropped the dead `accuracy = correct = ...` double
			# assignment; `accuracy` was immediately overwritten below.
			correct = tf.equal(
				tf.cast(pred_label, tf.int32),
				tf.cast(features["label_ids"], tf.int32)
			)
			accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
			return {"accuracy":accuracy, "loss":loss, "pred_label":pred_label, 
				"label_ids":features["label_ids"],
				"prob":prob}
		
		name_to_features = {
				"input_ids":
						tf.FixedLenFeature([FLAGS.max_length], tf.int64),
				"input_mask":
						tf.FixedLenFeature([FLAGS.max_length], tf.int64),
				"segment_ids":
						tf.FixedLenFeature([FLAGS.max_length], tf.int64),
				"label_ids":
						tf.FixedLenFeature([], tf.int64),
		}

		def _decode_record(record, name_to_features):
			"""Decodes a serialized tf.Example record into a feature dict."""
			example = tf.parse_single_example(record, name_to_features)

			# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
			# So cast all int64 to int32.
			for name in list(example.keys()):
				t = example[name]
				if t.dtype == tf.int64:
					t = tf.to_int32(t)
				example[name] = t
			return example 

		params = Bunch({})
		params.epoch = FLAGS.epoch
		params.batch_size = FLAGS.batch_size

		eval_features = tf_data_utils.eval_input_fn(FLAGS.eval_data_file,
									_decode_record, name_to_features, params, if_shard=FLAGS.if_shard)
		
		[_, eval_loss, eval_per_example_loss, eval_logits] = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL)
		result = metric_fn(eval_features, eval_logits, eval_loss)
		
		init_op = tf.group(tf.global_variables_initializer(), 
					tf.local_variables_initializer())
		sess.run(init_op)

		# Ensure every worker starts from identical (rank-0) weights.
		sess.run(hvd.broadcast_global_variables(0))

		print("===horovod rank==={}".format(hvd.rank()))
		
		def eval_fn(result):
			"""Runs the eval loop until OutOfRangeError; returns aggregates.

			Returns (mean batch accuracy, true labels, predicted labels, probs).
			"""
			i = 0
			total_accuracy = 0
			label, label_id, prob = [], [], []
			while True:
				try:
					eval_result = sess.run(result)
					total_accuracy += eval_result["accuracy"]
					label_id.extend(eval_result["label_ids"])
					label.extend(eval_result["pred_label"])
					prob.extend(eval_result["prob"])
					i += 1
				except tf.errors.OutOfRangeError:
					print("End of dataset")
					break
			macro_f1 = f1_score(label_id, label, average="macro")
			micro_f1 = f1_score(label_id, label, average="micro")
			macro_precision = precision_score(label_id, label, average="macro")
			micro_precision = precision_score(label_id, label, average="micro")
			macro_recall = recall_score(label_id, label, average="macro")
			micro_recall = recall_score(label_id, label, average="micro")
			accuracy = accuracy_score(label_id, label)
			print("test accuracy {} macro_f1 score {} micro_f1 {} accuracy {}".format(total_accuracy/ i, 
																					macro_f1,  micro_f1, accuracy))
			return total_accuracy/ i, label_id, label, prob
		
		import time  # fix: was imported twice back to back
		start = time.time()
		
		acc, true_label, pred_label, prob = eval_fn(result)
		end = time.time()
		print("==total time {} numbers of devices {}".format(end - start, hvd.size()))
		# Only rank 0 persists predictions, avoiding concurrent file writes.
		if hvd.rank() == 0:
			import _pickle as pkl
			pkl.dump({"true_label":true_label, 
						"pred_label":pred_label,
						"prob":prob}, 
						open(FLAGS.model_output+"/predict.pkl", "wb"))
Example #2
0
    def model_fn(features, labels, mode):
        """Model function for a siamese BERT+LSTM sentence-pair classifier.

        Encodes both sentences, runs LSTMs and cross-alignment, pools each
        side, and classifies the concatenated pair representation.

        NOTE(review): relies on closure variables from the enclosing builder
        (model_config, model_io_config, model_io_fn, kargs, max_len,
        model_reuse, num_labels, label_tensor, not_storage_params,
        exclude_scope_dict, load_pretrained, init_checkpoint, opt_config);
        none are visible in this block.

        Returns [train_op, loss, per_example_loss, logits] in TRAIN mode,
        otherwise [loss, loss, per_example_loss, logits].
        """
        if model_io_config.fix_lm == True:
            scope = model_config.scope + "/task"
        else:
            scope = model_config.scope

        # Dropout is active only during training; eval/predict are deterministic.
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        label_ids = features["label_ids"]

        target = kargs["target"]

        [a_mask, a_repres, b_mask,
         b_repres] = bert_lstm_encoding(model_config,
                                        features,
                                        labels,
                                        mode,
                                        target,
                                        max_len,
                                        scope,
                                        dropout_prob,
                                        reuse=model_reuse)

        # Side B reuses (reuse=True) the LSTM variables built for side A.
        a_repres = lstm_model(model_config, a_repres, a_mask, dropout_prob,
                              scope, model_reuse)

        b_repres = lstm_model(model_config, b_repres, b_mask, dropout_prob,
                              scope, True)

        a_output, b_output = alignment(model_config,
                                       a_repres,
                                       b_repres,
                                       a_mask,
                                       b_mask,
                                       scope,
                                       reuse=model_reuse)

        repres_a = bert_multihead_pooling(model_config,
                                          a_output,
                                          a_mask,
                                          scope,
                                          dropout_prob,
                                          reuse=model_reuse)

        repres_b = bert_multihead_pooling(model_config,
                                          b_output,
                                          b_mask,
                                          scope,
                                          dropout_prob,
                                          reuse=True)

        # Standard matching features: [a, b, |a-b|, a*b].
        pair_repres = tf.concat([
            repres_a, repres_b,
            tf.abs(repres_a - repres_b), repres_b * repres_a
        ],
                                axis=-1)

        print(pair_repres.get_shape(), "==repres shape==")

        with tf.variable_scope(scope, reuse=model_reuse):

            try:
                label_ratio_table = tf.get_variable(
                    name="label_ratio",
                    shape=[
                        num_labels,
                    ],
                    initializer=tf.constant(label_tensor),
                    trainable=False)

                ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                      label_ids)
                print("==applying class weight==")
            # Fix: bare `except:` also swallowed SystemExit/KeyboardInterrupt.
            # On any failure (e.g. no usable label_tensor) fall back to an
            # unweighted loss.
            except Exception:
                ratio_weight = None

            (loss, per_example_loss,
             logits) = classifier.classifier(model_config, pair_repres,
                                             num_labels, label_ids,
                                             dropout_prob, ratio_weight)
        if mode == tf.estimator.ModeKeys.TRAIN:
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            if load_pretrained:
                model_io_fn.load_pretrained(
                    pretrained_tvars,
                    init_checkpoint,
                    exclude_scope=exclude_scope_dict["task"])

        trainable_params = model_io_fn.get_params(
            scope, not_storage_params=not_storage_params)

        tvars = trainable_params

        storage_params = model_io_fn.get_params(
            model_config.scope, not_storage_params=not_storage_params)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            # Run UPDATE_OPS (e.g. batch-norm moving averages) before the step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
def main(_):
    """Horovod-distributed training entry point (hard-coded JD-comment run).

    Builds the train graph from hard-coded config/checkpoint/data paths,
    trains until the input pipeline is exhausted, and prints a running loss
    every `valid_step` steps.

    NOTE(review): paths, batch size, epochs and class count are hard-coded;
    presumably an experiment script rather than a reusable entry point.
    """
    hvd.init()

    sess_config = tf.ConfigProto()

    # Each Horovod process owns two consecutive GPUs (model is split across
    # gpu_id 0/1 below).
    sess_config.gpu_options.visible_device_list = \
           '%d,%d' % (hvd.local_rank() * 2, hvd.local_rank() * 2 + 1)

    graph = tf.Graph()
    with graph.as_default():
        import json

        config = json.load(
            open("/data/xuht/bert/chinese_L-12_H-768_A-12/bert_config.json",
                 "r"))
        init_checkpoint = "/data/xuht/bert/chinese_L-12_H-768_A-12/bert_model.ckpt"
        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"
        config.loss = "focal_loss"

        # Data is sharded across workers, so split the train size.
        num_train = int(33056 / hvd.size())

        batch_size = 32

        # Report the running loss roughly once per (per-worker) epoch.
        valid_step = int(num_train / batch_size)

        epoch = 2
        num_train_steps = int(num_train / (batch_size) * epoch)

        decay_train_steps = num_train_steps

        num_warmup_steps = int(num_train_steps * 0.01)

        sess = tf.Session(config=sess_config)

        opt_config = Bunch({
            "init_lr": float(1e-5 / hvd.size()),
            "num_train_steps": decay_train_steps,
            "cycle": False,
            "num_warmup_steps": num_warmup_steps,
            "lr_decay": "polynomial_decay"
        })
        model_io_config = Bunch({"fix_lm": False})

        model_io_fn = model_io.ModelIO(model_io_config)

        optimizer_fn = optimizer.Optimizer(opt_config)

        num_classes = 2  # fix: was typo'd local `num_calsses`

        model_train_fn = bert_classifier.classifier_model_fn_builder(
            config,
            num_classes,
            init_checkpoint,
            reuse=None,
            load_pretrained=True,
            model_io_fn=model_io_fn,
            optimizer_fn=optimizer_fn,
            model_io_config=model_io_config,
            opt_config=opt_config,
            gpu_id=0,
            gpu_nums=2)

        def metric_fn(features, logits, loss):
            """Builds eval metric tensors: accuracy, loss, predictions."""
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            # Fix: dropped the dead `accuracy = correct = ...` double
            # assignment; `accuracy` was immediately overwritten below.
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {
                "accuracy": accuracy,
                "loss": loss,
                "pred_label": pred_label,
                "label_ids": features["label_ids"]
            }

        name_to_features = {
            "input_ids": tf.FixedLenFeature([128], tf.int64),
            "input_mask": tf.FixedLenFeature([128], tf.int64),
            "segment_ids": tf.FixedLenFeature([128], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        params = Bunch({})
        params.epoch = epoch
        params.batch_size = 32
        # Alternative dataset paths kept for reference; only jd_train/jd_test
        # are actually consumed below.
        train_file = "/data/xuht/eventy_detection/event/model/train.tfrecords"
        train_file1 = "/data/xuht/eventy_detection/sentiment/model/sentiment_11_14/train.tfrecords"
        title_sentiment = "/data/xuht/eventy_detection/sentiment/model/test/train.tfrecords"
        sentiment = "/data/xuht/eventy_detection/sentiment/model/bert/train_11_15.tfrecords"
        jd_train = "/data/xuht/jd_comment/train.tfrecords"
        train_features = tf_data_utils.train_input_fn(
            jd_train, tf_data_utils._decode_record, name_to_features, params)

        test_file = [
            "/data/xuht/eventy_detection/sentiment/model/sentiment_11_14/test.tfrecords"
        ]
        test_file1_1 = [
            "/data/xuht/eventy_detection/sentiment/model/test/train.tfrecords",
            "/data/xuht/eventy_detection/sentiment/model/test/test.tfrecords"
        ]
        test_file2 = "/data/xuht/eventy_detection/event/model/test.tfrecords"
        title_test = "/data/xuht/eventy_detection/sentiment/model/test/test.tfrecords"
        jd_test = "/data/xuht/jd_comment/test.tfrecords"
        sentiment_test = "/data/xuht/eventy_detection/sentiment/model/bert/test_11_15.tfrecords"

        eval_features = tf_data_utils.eval_input_fn(
            jd_test, tf_data_utils._decode_record, name_to_features, params)

        [train_op, train_loss, train_per_example_loss,
         train_logits] = model_train_fn(train_features, [],
                                        tf.estimator.ModeKeys.TRAIN)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # Ensure every worker starts from identical (rank-0) weights.
        sess.run(hvd.broadcast_global_variables(0))

        model_io_fn.set_saver()

        print("===horovod rank==={}".format(hvd.rank()))

        def eval_fn(result):
            """Runs the eval loop until OutOfRangeError; returns aggregates."""
            i = 0
            total_accuracy = 0
            label, label_id = [], []
            while True:
                try:
                    eval_result = sess.run(result)
                    total_accuracy += eval_result["accuracy"]
                    label_id.extend(eval_result["label_ids"])
                    label.extend(eval_result["pred_label"])
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            macro_f1 = f1_score(label_id, label, average="macro")
            micro_f1 = f1_score(label_id, label, average="micro")
            accuracy = accuracy_score(label_id, label)
            print("test accuracy {} macro_f1 score {} micro_f1 {} accuracy {}".
                  format(total_accuracy / i, macro_f1, micro_f1, accuracy))
            return total_accuracy / i, label_id, label

        def train_fn(op, loss):
            """Trains until the input pipeline raises OutOfRangeError,
            printing the mean loss every `valid_step` steps."""
            i = 0
            total_loss = 0
            cnt = 0
            while True:
                try:
                    [_, train_loss] = sess.run([op, loss])
                    i += 1
                    cnt += 1
                    total_loss += train_loss
                    if np.mod(i, valid_step) == 0:
                        print(total_loss / cnt)
                        cnt = 0
                        total_loss = 0
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break

        import time
        start = time.time()
        train_fn(train_op, train_loss)
        end = time.time()
        print("==total time {} numbers of devices {}".format(
            end - start, hvd.size()))
Example #4
0
def main(_):
    """Horovod-distributed train-then-eval entry point for a BERT classifier.

    Builds separate (variable-sharing) train and eval graphs from FLAGS,
    trains until the input pipeline is exhausted — checkpointing on rank 0
    once per epoch — then runs a final evaluation and dumps per-run metric
    pickles under FLAGS.model_output.
    """
    hvd.init()

    # Pin one GPU per Horovod process.
    sess_config = tf.ConfigProto()
    sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

    graph = tf.Graph()
    with graph.as_default():
        import json

        config = json.load(open(FLAGS.config_file, "r"))
        init_checkpoint = FLAGS.init_checkpoint

        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"

        # if_shard == "0": every worker reads the full dataset, so split epochs;
        # if_shard == "1": the dataset is sharded across workers, so split size.
        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / hvd.size())
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / hvd.size())
            epoch = FLAGS.epoch
        else:
            # Fix: previously fell through and crashed later with an
            # UnboundLocalError on `train_size`/`epoch`.
            raise ValueError(
                "unsupported FLAGS.if_shard: {}".format(FLAGS.if_shard))

        init_lr = 2e-5

        # Maps label names <-> ids; "label2id" keys are used as report names.
        label_dict = json.load(open(FLAGS.label_id))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)

        # One "storage" interval == roughly one per-worker epoch.
        num_storage_steps = int(train_size / FLAGS.batch_size)

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": init_lr / hvd.size(),
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps
        })

        sess = tf.Session(config=sess_config)

        model_io_config = Bunch({"fix_lm": False})

        model_io_fn = model_io.ModelIO(model_io_config)

        optimizer_fn = optimizer.Optimizer(opt_config)

        num_classes = FLAGS.num_classes

        model_train_fn = bert_classifier.classifier_model_fn_builder(
            config,
            num_classes,
            init_checkpoint,
            reuse=None,
            load_pretrained=True,
            model_io_fn=model_io_fn,
            optimizer_fn=optimizer_fn,
            model_io_config=model_io_config,
            opt_config=opt_config)

        # Eval graph shares the train graph's variables (reuse=True).
        model_eval_fn = bert_classifier.classifier_model_fn_builder(
            config,
            num_classes,
            init_checkpoint,
            reuse=True,
            load_pretrained=True,
            model_io_fn=model_io_fn,
            optimizer_fn=optimizer_fn,
            model_io_config=model_io_config,
            opt_config=opt_config)

        def metric_fn(features, logits, loss):
            """Builds eval metric tensors: accuracy, loss, predictions."""
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            # Fix: dropped the dead `accuracy = correct = ...` double
            # assignment; `accuracy` was immediately overwritten below.
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {
                "accuracy": accuracy,
                "loss": loss,
                "pred_label": pred_label,
                "label_ids": features["label_ids"]
            }

        name_to_features = {
            "input_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "input_mask": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decodes a serialized tf.Example record into a feature dict."""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size

        train_features = tf_data_utils.train_input_fn(FLAGS.train_file,
                                                      _decode_record,
                                                      name_to_features,
                                                      params,
                                                      if_shard=FLAGS.if_shard)
        eval_features = tf_data_utils.eval_input_fn(FLAGS.dev_file,
                                                    _decode_record,
                                                    name_to_features,
                                                    params,
                                                    if_shard=FLAGS.if_shard)

        [train_op, train_loss, train_per_example_loss,
         train_logits] = model_train_fn(train_features, [],
                                        tf.estimator.ModeKeys.TRAIN)
        train_dict = {"train_op": train_op, "train_loss": train_loss}
        [_, eval_loss, eval_per_example_loss,
         eval_logits] = model_eval_fn(eval_features, [],
                                      tf.estimator.ModeKeys.EVAL)
        eval_dict = metric_fn(eval_features, eval_logits, eval_loss)

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        # Ensure every worker starts from identical (rank-0) weights.
        sess.run(hvd.broadcast_global_variables(0))

        model_io_fn.set_saver()

        print("===horovod rank==={}".format(hvd.rank()))

        def run_eval(steps):
            """Runs a full evaluation pass; rank 0 pickles the result dict."""
            import _pickle as pkl
            eval_finial_dict = eval_fn(eval_dict)
            if hvd.rank() == 0:
                pkl.dump(
                    eval_finial_dict,
                    open(
                        FLAGS.model_output + "/eval_dict_{}.pkl".format(steps),
                        "wb"))
            return eval_finial_dict

        def eval_fn(result):
            """Accumulates eval metrics until OutOfRangeError and builds a
            sklearn classification report keyed by label_dict["label2id"]."""
            i = 0
            total_accuracy = 0
            eval_total_dict = {}

            while True:
                try:
                    eval_result = sess.run(result)
                    # Lists (labels/predictions) are extended; scalars
                    # (accuracy/loss) are summed across batches.
                    for key in eval_result:
                        if key not in eval_total_dict:
                            if key in ["pred_label", "label_ids"]:
                                eval_total_dict[key] = []
                                eval_total_dict[key].extend(eval_result[key])
                            if key in ["accuracy", "loss"]:
                                eval_total_dict[key] = 0.0
                                eval_total_dict[key] += eval_result[key]
                        else:
                            if key in ["pred_label", "label_ids"]:
                                eval_total_dict[key].extend(eval_result[key])
                            if key in ["accuracy", "loss"]:
                                eval_total_dict[key] += eval_result[key]

                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break

            label_id = eval_total_dict["label_ids"]
            pred_label = eval_total_dict["pred_label"]

            result = classification_report(label_id,
                                           pred_label,
                                           target_names=list(
                                               label_dict["label2id"].keys()))

            print(result)
            eval_total_dict["classification_report"] = result
            return eval_total_dict

        def train_fn(op_dict):
            """Trains until OutOfRangeError; every num_storage_steps steps it
            prints the mean losses and (rank 0 only) saves a checkpoint. On
            exhaustion, rank 0 pickles the monitoring history."""
            i = 0
            cnt = 0
            loss_dict = {}
            monitoring_train = []
            monitoring_eval = []
            while True:
                try:
                    train_result = sess.run(op_dict)
                    for key in train_result:
                        if key == "train_op":
                            continue
                        else:
                            if np.isnan(train_result[key]):
                                # Fix: was printing the graph tensor
                                # `train_loss` instead of the NaN value.
                                # NOTE(review): this break only exits the
                                # key loop, not the training loop — confirm
                                # whether NaN should abort training.
                                print(train_result[key], "get nan loss")
                                break
                            else:
                                if key in loss_dict:
                                    loss_dict[key] += train_result[key]
                                else:
                                    loss_dict[key] = train_result[key]

                    i += 1
                    cnt += 1

                    if np.mod(i, num_storage_steps) == 0:
                        string = ""
                        for key in loss_dict:
                            tmp = key + " " + str(loss_dict[key] / cnt) + "\t"
                            string += tmp
                        print(string)
                        monitoring_train.append(loss_dict)

                        if hvd.rank() == 0:
                            model_io_fn.save_model(
                                sess,
                                FLAGS.model_output + "/model_{}.ckpt".format(
                                    int(i / num_storage_steps)))

                        print("==successful storing model=={}".format(
                            int(i / num_storage_steps)))
                        cnt = 0

                        for key in loss_dict:
                            loss_dict[key] = 0.0

                except tf.errors.OutOfRangeError:
                    if hvd.rank() == 0:
                        import _pickle as pkl
                        pkl.dump(
                            {
                                "train": monitoring_train,
                                "eval": monitoring_eval
                            },
                            open(FLAGS.model_output + "/monitoring.pkl", "wb"))

                    break

        print("===========begin to train============")
        train_fn(train_dict)
        if hvd.rank() == 0:
            model_io_fn.save_model(sess, FLAGS.model_output + "/model.ckpt")
            print("===========begin to eval============")
            eval_finial_dict = run_eval("final")
Example #5
0
    def model_fn(features, labels, mode):
        """Model function for a multi-input BERT order classifier.

        Encodes each named input with a shared base model, pools each
        encoding, and classifies the ordered set of pooled outputs.

        NOTE(review): relies on closure variables from the enclosing builder
        (model_config, model_io_config, model_io_fn, input_name, model_reuse,
        num_labels, label_tensor, load_pretrained, init_checkpoint,
        opt_config); none are visible in this block.

        Returns [train_op, loss, per_example_loss, logits] in TRAIN mode,
        otherwise [loss, loss, per_example_loss, logits].
        """
        label_ids = features["label_ids"]
        model_lst = []
        # The first input builds the encoder variables (per model_reuse);
        # every subsequent input reuses them.
        for index, name in enumerate(input_name):
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                base_model(model_config,
                           features,
                           labels,
                           mode,
                           name,
                           reuse=reuse))

        # Dropout is active only during training; eval/predict are deterministic.
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        assert len(model_lst) == len(input_name)

        if model_io_config.fix_lm == True:
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            try:
                label_ratio_table = tf.get_variable(
                    name="label_ratio",
                    initializer=tf.constant(label_tensor),
                    trainable=False)

                ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                      label_ids)
                print("==applying class weight==")
            # Fix: bare `except:` also swallowed SystemExit/KeyboardInterrupt.
            # On any failure (e.g. no usable label_tensor) fall back to an
            # unweighted loss.
            except Exception:
                ratio_weight = None

            seq_output_lst = [model.get_pooled_output() for model in model_lst]

            [loss, per_example_loss,
             logits] = classifier.order_classifier(model_config,
                                                   seq_output_lst, num_labels,
                                                   label_ids, dropout_prob,
                                                   ratio_weight)

        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            # Run UPDATE_OPS (e.g. batch-norm moving averages) before the step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
Example #6
0
def main(_):
    """Distributed (Horovod) evaluation entry point for a BERT classifier.

    Builds the eval graph, restores ``FLAGS.init_checkpoint``, streams the
    dev set (``FLAGS.dev_file``) through the model and reports mean batch
    accuracy plus macro-F1 on rank 0.
    """
    graph = tf.Graph()
    from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
    with graph.as_default():
        import json

        hvd.init()

        sess_config = tf.ConfigProto()
        # Pin each worker to its own local GPU.
        sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

        config = json.load(open(FLAGS.config_file, "r"))

        init_checkpoint = FLAGS.init_checkpoint
        print("===init checkoutpoint==={}".format(init_checkpoint))

        # NOTE(review): loaded but never used below; kept so a missing label
        # file still fails fast, matching the original behavior.
        label_dict = json.load(open(FLAGS.label_id))

        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.1
        config.label_type = "single_label"

        sess = tf.Session(config=sess_config)

        # Per-worker shard of the nominal train size (only feeds LR schedule math).
        train_size = int(FLAGS.train_size / hvd.size())

        num_train_steps = int(train_size / FLAGS.batch_size * FLAGS.epoch)
        num_warmup_steps = int(num_train_steps * 0.01)

        num_storage_steps = int(train_size / FLAGS.batch_size)

        print(num_train_steps, num_warmup_steps, "=============")

        # Scale the base LR down by world size (LR is effectively multiplied
        # back up by synchronized data-parallel training).
        opt_config = Bunch({
            "init_lr": (2e-5 / hvd.size()),
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps
        })

        model_io_config = Bunch({"fix_lm": False})

        model_io_fn = model_io.ModelIO(model_io_config)
        optimizer_fn = optimizer.Optimizer(opt_config)

        num_choice = FLAGS.num_classes
        max_seq_length = FLAGS.max_length

        model_eval_fn = bert_classifier.classifier_model_fn_builder(
            config,
            num_choice,
            init_checkpoint,
            reuse=None,
            load_pretrained=True,
            model_io_fn=model_io_fn,
            optimizer_fn=optimizer_fn,
            model_io_config=model_io_config,
            opt_config=opt_config)

        def metric_fn(features, logits, loss):
            """Return per-batch eval tensors: accuracy, loss, predictions, labels."""
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["label_ids"], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))
            return {
                "accuracy": accuracy,
                "loss": loss,
                "pred_label": pred_label,
                "label_ids": features["label_ids"]
            }

        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decode a serialized tf.Example; cast int64 features to int32
            (tf.Example only supports int64, TPUs only support int32)."""
            example = tf.parse_single_example(record, name_to_features)
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t
            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size

        eval_features = tf_data_utils.eval_input_fn(FLAGS.dev_file,
                                                    _decode_record,
                                                    name_to_features, params)

        [_, eval_loss, eval_per_example_loss,
         eval_logits] = model_eval_fn(eval_features, [],
                                      tf.estimator.ModeKeys.EVAL)
        result = metric_fn(eval_features, eval_logits, eval_loss)

        model_io_fn.set_saver()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        model_io_fn.load_model(sess, init_checkpoint)
        # Make sure all workers start from identical weights.
        sess.run(hvd.broadcast_global_variables(0))

        def eval_fn(result):
            """Drain the eval iterator; return (mean batch accuracy, macro F1)."""
            i = 0
            total_accuracy = 0
            label, label_id = [], []
            while True:
                try:
                    eval_result = sess.run(result)
                    total_accuracy += eval_result["accuracy"]
                    label_id.extend(eval_result["label_ids"])
                    label.extend(eval_result["pred_label"])
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            # Guard against an empty eval set (the original divided by zero).
            mean_accuracy = total_accuracy / i if i else 0.0
            f1 = f1_score(label_id, label, average="macro")
            accuracy = accuracy_score(label_id, label)
            print("test accuracy accuracy {} {} f1 {}".format(
                mean_accuracy, accuracy, f1))
            return mean_accuracy, f1

        if hvd.rank() == 0:
            print("===========begin to eval============")
            accuracy, f1 = eval_fn(result)
            print("==accuracy {} f1 {}==".format(accuracy, f1))
Example #7
0
    def model_fn(features, labels, mode):
        """Multi-tower BERT + VIB classifier model_fn.

        Encodes each input listed in ``input_name`` with a shared BERT tower,
        sums the pooled outputs, projects them to a variational information
        bottleneck (train only adds the KL regularizer) and classifies.

        Returns ``[train_op, loss, per_example_loss, logits]`` in TRAIN mode,
        otherwise ``[loss, loss, per_example_loss, logits]``.
        """
        label_ids = features["label_ids"]
        model_lst = []
        for index, name in enumerate(input_name):
            # The first tower honors the caller's reuse flag; subsequent
            # towers always reuse its variables.
            if index > 0:
                reuse = True
            else:
                reuse = model_reuse
            model_lst.append(
                base_model(model_config,
                           features,
                           labels,
                           mode,
                           name,
                           reuse=reuse))

        # Dropout is active only while training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        assert len(model_lst) == len(input_name)

        if model_io_config.fix_lm == True:
            # Keep the LM frozen and train under a separate scope.
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        with tf.variable_scope(scope, reuse=model_reuse):
            try:
                # BUGFIX: the original also passed shape=[num_labels,] here.
                # tf.get_variable raises ValueError when both a shape and a
                # Tensor initializer are given, so the bare except always
                # fired and class weighting was silently disabled.
                label_ratio_table = tf.get_variable(
                    name="label_ratio",
                    initializer=tf.constant(label_tensor),
                    trainable=False)

                ratio_weight = tf.nn.embedding_lookup(label_ratio_table,
                                                      label_ids)
            except Exception:
                # e.g. label_tensor is None -> no class weighting.
                ratio_weight = None

            seq_output_lst = [model.get_pooled_output() for model in model_lst]
            repres = seq_output_lst[0] + seq_output_lst[1]

            final_hidden_shape = bert_utils.get_shape_list(repres,
                                                           expected_rank=2)

            # VIB posterior parameters.
            z_mean = tf.layers.dense(repres,
                                     final_hidden_shape[1],
                                     name="z_mean")
            z_log_var = tf.layers.dense(repres,
                                        final_hidden_shape[1],
                                        name="z_log_var")
            print("=======applying vib============")
            if mode == tf.estimator.ModeKeys.TRAIN:
                # Train on the sampled latent and add the KL penalty.
                print("====applying vib====")
                vib_connector = vib.VIB(vib_config)
                [kl_loss, latent_vector
                 ] = vib_connector.build_regularizer([z_mean, z_log_var])

                [loss, per_example_loss,
                 logits] = classifier.classifier(model_config, latent_vector,
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)

                loss += tf.reduce_mean(kl_loss)
            else:
                # Deterministic prediction: use the posterior mean.
                print("====applying z_mean for prediction====")
                [loss, per_example_loss,
                 logits] = classifier.classifier(model_config, z_mean,
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)

        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars, init_checkpoint)

        tvars = model_io_fn.get_params(scope)

        if mode == tf.estimator.ModeKeys.TRAIN:
            model_io_fn.print_params(tvars, string=", trainable params")
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                optimizer_fn = optimizer.Optimizer(opt_config)
                train_op = optimizer_fn.get_train_op(
                    loss, tvars, opt_config.init_lr,
                    opt_config.num_train_steps)

                return [train_op, loss, per_example_loss, logits]
        else:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]
Example #8
0
    def model_fn(features, labels, mode):
        """Adversarial-training classifier model_fn.

        Builds the base classifier, and in TRAIN mode additionally derives an
        input-embedding perturbation from the clean loss, re-runs the model on
        the perturbed embeddings and optimizes the sum of clean and adversarial
        losses. Returns ``[train_op, total_loss, per_example_loss, logits]``
        in TRAIN mode, otherwise ``[loss, loss, per_example_loss, logits]``.
        """
        label_ids = features["label_ids"]

        # Clean (unperturbed) forward pass.
        model = base_model(model_config,
                           features,
                           labels,
                           mode,
                           input_name,
                           reuse=model_reuse,
                           perturbation=None)

        # Dropout is active only while training.
        if mode == tf.estimator.ModeKeys.TRAIN:
            hidden_dropout_prob = model_config.hidden_dropout_prob
            attention_probs_dropout_prob = model_config.attention_probs_dropout_prob
            dropout_prob = model_config.dropout_prob
        else:
            hidden_dropout_prob = 0.0
            attention_probs_dropout_prob = 0.0
            dropout_prob = 0.0

        # assert len(model_lst) == len(input_name)

        if model_io_config.fix_lm == True:
            # Keep the LM frozen and train under a separate scope.
            scope = model_config.scope + "_finetuning"
        else:
            scope = model_config.scope

        def build_discriminator(model, scope, reuse):
            """Attach the classification head to ``model``'s pooled output."""

            with tf.variable_scope(scope, reuse=reuse):

                try:
                    # NOTE(review): the bare except presumably guards against
                    # label_tensor being None/invalid — if so, class weighting
                    # is silently skipped. Confirm against the builder closure.
                    label_ratio_table = tf.get_variable(
                        name="label_ratio",
                        initializer=tf.constant(label_tensor),
                        trainable=False)

                    ratio_weight = tf.nn.embedding_lookup(
                        label_ratio_table, label_ids)
                    print("==applying class weight==")
                except:
                    ratio_weight = None

                (loss, per_example_loss,
                 logits) = classifier.classifier(model_config,
                                                 model.get_pooled_output(),
                                                 num_labels, label_ids,
                                                 dropout_prob, ratio_weight)
                return loss, per_example_loss, logits

        [loss, per_example_loss,
         logits] = build_discriminator(model, scope, model_reuse)

        if mode == tf.estimator.ModeKeys.TRAIN:
            pretrained_tvars = model_io_fn.get_params(
                model_config.scope, not_storage_params=not_storage_params)

            if load_pretrained:
                tf.logging.info(" load pre-trained base model ")
                print(" load pre-trained base model ")
                model_io_fn.load_pretrained(pretrained_tvars,
                                            init_checkpoint,
                                            exclude_scope=exclude_scope)
            tvars = pretrained_tvars
            model_io_fn.set_saver(var_lst=tvars)

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                # The optimizer must exist before the perturbation is built:
                # get_perturbation reads optimizer state/gradients.
                optimizer_fn = optimizer.Optimizer(opt_config)
                optimizer_fn.get_opt(opt_config.init_lr,
                                     opt_config.num_train_steps)

                # Perturbation of the word embeddings derived from the clean
                # loss (adversarial-training style — semantics depend on
                # get_perturbation, defined elsewhere).
                perturb = get_perturbation(model_config, optimizer_fn.opt,
                                           model.embedding_output_word, loss,
                                           tvars)

                # Second forward pass on perturbed embeddings, sharing weights.
                adv_model = base_model(model_config,
                                       features,
                                       labels,
                                       mode,
                                       input_name,
                                       reuse=True,
                                       perturbation=perturb)

                [adv_loss, adv_per_example_loss,
                 adv_logits] = build_discriminator(adv_model, scope, True)

                # Optimize clean + adversarial loss jointly.
                total_loss = adv_loss + loss
                total_train_op = optimizer_fn.get_train_op_v1(
                    total_loss, tvars)

            return [total_train_op, total_loss, per_example_loss, logits]
        else:
            model_io_fn.set_saver()
            return [loss, loss, per_example_loss, logits]
    def model_fn(features, labels, mode):
        """Build a multi-choice BERT classifier.

        Flattens the per-choice axis into the batch axis, encodes every
        choice with BERT and scores them jointly; returns train tensors
        (with a train_op) or eval tensors depending on ``mode``.
        """
        label_ids = features["label_ids"]

        # Collapse [batch, choice, seq] inputs to [batch * choice, seq].
        dims = bert_utils.get_shape_list(features["input_ids"],
                                         expected_rank=3)
        flat_shape = [dims[0] * dims[1], dims[2]]
        input_ids = tf.reshape(features["input_ids"], flat_shape)
        input_mask = tf.reshape(features["input_mask"], flat_shape)
        segment_ids = tf.reshape(features["segment_ids"], flat_shape)

        # Dropout is active only while training.
        is_training = mode == tf.estimator.ModeKeys.TRAIN
        hidden_dropout_prob = (model_config.hidden_dropout_prob
                               if is_training else 0.0)
        attention_probs_dropout_prob = (
            model_config.attention_probs_dropout_prob if is_training else 0.0)
        dropout_prob = model_config.dropout_prob if is_training else 0.0

        model = bert.Bert(model_config)
        model.build_embedder(input_ids,
                             segment_ids,
                             hidden_dropout_prob,
                             attention_probs_dropout_prob,
                             reuse=reuse)
        model.build_encoder(input_ids,
                            input_mask,
                            hidden_dropout_prob,
                            attention_probs_dropout_prob,
                            reuse=reuse)
        model.build_pooler(reuse=reuse)

        # When the LM is frozen, the head trains under a separate scope.
        scope = (model_config.scope + "_finetuning"
                 if model_io_config.fix_lm == True
                 else model_config.scope)

        with tf.variable_scope(scope, reuse=reuse):
            (loss, per_example_loss,
             logits) = classifier.multi_choice_classifier(
                 model_config, model.get_pooled_output(), num_labels,
                 label_ids, dropout_prob)

        pretrained_tvars = model_io_fn.get_params(model_config.scope)
        if load_pretrained:
            model_io_fn.load_pretrained(pretrained_tvars,
                                        init_checkpoint,
                                        exclude_scope=exclude_scope)

        tvars = model_io_fn.get_params(scope,
                                       not_storage_params=not_storage_params)
        model_io_fn.set_saver(var_lst=tvars)

        if not is_training:
            model_io_fn.print_params(tvars, string=", trainable params")
            return [loss, loss, per_example_loss, logits]

        model_io_fn.print_params(tvars, string=", trainable params")
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimizer_fn = optimizer.Optimizer(opt_config)
            train_op = optimizer_fn.get_train_op(loss, tvars,
                                                 opt_config.init_lr,
                                                 opt_config.num_train_steps)

            return [train_op, loss, per_example_loss, logits]
Example #10
0
def main(_):
    """Run test-set inference for a multi-choice BERT classifier.

    Converts ``FLAGS.eval_data_file`` to TFRecords, restores
    ``FLAGS.init_checkpoint``, predicts labels for every example and writes
    both a raw JSON dump and a submission-style CSV (``FLAGS.result_file``).
    """
    hvd.init()

    graph = tf.Graph()

    sess_config = tf.ConfigProto()
    # Pin each worker to its own local GPU.
    sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

    with graph.as_default():
        import json

        tokenizer = tokenization.FullTokenizer(vocab_file=FLAGS.vocab_file,
                                               do_lower_case=True)

        classifier_data_api = classifier_processor.PiarChoiceProcessor()

        eval_examples = classifier_data_api.get_test_examples(
            FLAGS.eval_data_file, FLAGS.lang)

        print(eval_examples[0].guid)

        label_id = json.load(open(FLAGS.label_id, "r"))

        max_seq_length = FLAGS.max_length

        write_to_tfrecords.convert_classifier_examples_to_features(
            eval_examples, label_id["label2id"], max_seq_length, tokenizer,
            FLAGS.output_file)

        config = json.load(open(FLAGS.config_file, "r"))
        init_checkpoint = FLAGS.init_checkpoint

        print("===init checkoutpoint==={}".format(init_checkpoint))

        config = Bunch(config)
        config.use_one_hot_embeddings = True
        config.scope = "bert"
        config.dropout_prob = 0.2
        config.label_type = "single_label"

        sess = tf.Session(config=sess_config)

        # Per-worker shard of the nominal train size (only feeds LR schedule math).
        train_size = int(FLAGS.train_size / hvd.size())

        num_train_steps = int(train_size / FLAGS.batch_size * FLAGS.epoch)
        num_warmup_steps = int(num_train_steps * 0.01)

        num_storage_steps = int(train_size / FLAGS.batch_size)

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": (2e-5 / hvd.size()),
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps
        })

        model_io_config = Bunch({"fix_lm": False})

        model_io_fn = model_io.ModelIO(model_io_config)
        optimizer_fn = optimizer.Optimizer(opt_config)

        num_choice = FLAGS.num_classes

        model_eval_fn = bert_classifier.classifier_model_fn_builder(
            config,
            num_choice,
            init_checkpoint,
            reuse=None,
            load_pretrained=True,
            model_io_fn=model_io_fn,
            optimizer_fn=optimizer_fn,
            model_io_config=model_io_config,
            opt_config=opt_config)

        def metric_fn(features, logits):
            """Return prediction tensors: argmax label, qas_id and softmax prob."""
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.exp(tf.nn.log_softmax(logits))
            return {
                "pred_label": pred_label,
                "qas_id": features["qas_id"],
                "prob": prob
            }

        name_to_features = {
            "input_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "input_mask": tf.FixedLenFeature([max_seq_length], tf.int64),
            "segment_ids": tf.FixedLenFeature([max_seq_length], tf.int64),
            "label_ids": tf.FixedLenFeature([], tf.int64),
            "qas_id": tf.FixedLenFeature([], tf.int64),
        }

        def _decode_record(record, name_to_features):
            """Decode a serialized tf.Example; cast int64 features to int32
            (tf.Example only supports int64, TPUs only support int32)."""
            example = tf.parse_single_example(record, name_to_features)
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t

            return example

        params = Bunch({})
        # NOTE(review): hard-coded rather than FLAGS-driven in the original;
        # kept to preserve behavior.
        params.epoch = 2
        params.batch_size = 32

        eval_features = tf_data_utils.eval_input_fn(FLAGS.output_file,
                                                    _decode_record,
                                                    name_to_features, params)

        [_, eval_loss, eval_per_example_loss,
         eval_logits] = model_eval_fn(eval_features, [],
                                      tf.estimator.ModeKeys.EVAL)
        result = metric_fn(eval_features, eval_logits)

        model_io_fn.set_saver()

        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)

        def eval_fn(result):
            """Drain the eval iterator; return (pred_label, qas_id, prob) lists."""
            i = 0
            pred_label, qas_id, prob = [], [], []
            while True:
                try:
                    eval_result = sess.run(result)
                    pred_label.extend(eval_result["pred_label"].tolist())
                    qas_id.extend(eval_result["qas_id"].tolist())
                    prob.extend(eval_result["prob"].tolist())
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break
            return pred_label, qas_id, prob

        print("===========begin to eval============")
        [pred_label, qas_id, prob] = eval_fn(result)
        result = dict(zip(qas_id, pred_label))

        print(FLAGS.result_file.split("."))
        tmp_output = FLAGS.result_file.split(".")[0] + ".json"
        print(tmp_output, "===temp output===")
        # Raw dump of all predictions (close the handle instead of leaking it).
        with open(tmp_output, "w") as fwobj:
            json.dump({
                "id": qas_id,
                "label": pred_label,
                "prob": prob
            }, fwobj)

        print(len(result), "=====valid result======")

        import pandas as pd
        df = pd.read_csv(FLAGS.eval_data_file)

        # Preserve the original file's id order for the submission.
        output = {}
        for index in range(df.shape[0]):
            output[df.loc[index]["id"]] = ""

        final_output = []

        cnt = 0
        for key in output:
            if key in result:
                final_output.append({
                    "Id":
                    key,
                    "Category":
                    label_id["id2label"][str(result[key])]
                })
                cnt += 1
            else:
                # Ids with no model prediction fall back to "unrelated".
                final_output.append({"Id": key, "Category": "unrelated"})

        df_out = pd.DataFrame(final_output)
        df_out.to_csv(FLAGS.result_file)

        print(len(output), cnt, len(final_output),
              "======num of results from model==========")
Example #11
0
def main(_):
	"""Distributed (Horovod) joint classifier/MLM training entry point.

	Trains on sharded TFRecord folders (``FLAGS.train_file``), evaluates on
	``FLAGS.dev_file`` every ``num_storage_steps`` steps, and on rank 0
	checkpoints the model and pickles the monitoring history under
	``FLAGS.model_output``.
	"""
	hvd.init()

	sess_config = tf.ConfigProto()
	# Pin each worker to its own local GPU.
	sess_config.gpu_options.visible_device_list = str(hvd.local_rank())

	graph = tf.Graph()
	from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
	with graph.as_default():
		import json

		config = json.load(open(FLAGS.config_file, "r"))

		init_checkpoint = FLAGS.init_checkpoint
		print("===init checkoutpoint==={}".format(init_checkpoint))

		config = Bunch(config)
		config.use_one_hot_embeddings = True
		config.scope = "bert"
		config.dropout_prob = 0.1
		config.label_type = "single_label"
		# Loss mixing: classification only, no LM loss.
		config.lm_ratio = 0.0
		config.task_ratio = 1.0

		# Persist the resolved config next to the checkpoints.
		with open(FLAGS.model_output+"/config.json", "w") as fwobj:
			json.dump(config, fwobj)

		init_lr = 1e-5

		# Either shard the epochs (mode "0") or the data (mode "1") across workers.
		if FLAGS.if_shard == "0":
			train_size = FLAGS.train_size
			epoch = int(FLAGS.epoch / hvd.size())
		elif FLAGS.if_shard == "1":
			train_size = int(FLAGS.train_size/hvd.size())
			epoch = FLAGS.epoch

		sess = tf.Session(config=sess_config)

		num_train_steps = int(
			train_size / FLAGS.batch_size * epoch)
		num_warmup_steps = int(num_train_steps * 0.1)

		# Checkpoint/eval cadence: roughly once per epoch worth of steps.
		num_storage_steps = int(train_size / FLAGS.batch_size)

		print(num_train_steps, num_warmup_steps, "=============")

		opt_config = Bunch({"init_lr":init_lr/(hvd.size()), 
							"num_train_steps":num_train_steps,
							"num_warmup_steps":num_warmup_steps})

		model_io_config = Bunch({"fix_lm":False})
		
		model_io_fn = model_io.ModelIO(model_io_config)

		optimizer_fn = optimizer.Optimizer(opt_config)
		
		num_choice = FLAGS.num_classes
		max_seq_length = FLAGS.max_length
		max_predictions_per_seq = FLAGS.max_predictions_per_seq

		model_train_fn = classifier_fn.classifier_model_fn_builder(config, 
												num_choice, init_checkpoint, 
												reuse=None, 
												load_pretrained=True,
												model_io_fn=model_io_fn,
												optimizer_fn=optimizer_fn,
												model_io_config=model_io_config, 
												opt_config=opt_config)

		# Eval graph shares the train variables (reuse=True).
		model_eval_fn = classifier_fn.classifier_model_fn_builder(config, 
												num_choice, init_checkpoint, 
												reuse=True, 
												load_pretrained=True,
												model_io_fn=model_io_fn,
												optimizer_fn=optimizer_fn,
												model_io_config=model_io_config, 
												opt_config=opt_config)
		
		name_to_features = {
				"input_ids":
					tf.FixedLenFeature([max_seq_length], tf.int64),
				"input_mask":
					tf.FixedLenFeature([max_seq_length], tf.int64),
				"segment_ids":
					tf.FixedLenFeature([max_seq_length], tf.int64),
				"masked_lm_positions":
					tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
				"masked_lm_ids":
					tf.FixedLenFeature([max_predictions_per_seq], tf.int64),
				"masked_lm_weights":
					tf.FixedLenFeature([max_predictions_per_seq], tf.float32),
				"label_ids":
					tf.FixedLenFeature([], tf.int64),
				}

		def _decode_record(record, name_to_features):
			"""Decode a serialized tf.Example; cast int64 features to int32
			(tf.Example only supports int64, TPUs only support int32)."""
			example = tf.parse_single_example(record, name_to_features)
			for name in list(example.keys()):
				t = example[name]
				if t.dtype == tf.int64:
					t = tf.to_int32(t)
				example[name] = t
			return example

		params = Bunch({})
		params.epoch = epoch
		params.batch_size = FLAGS.batch_size

		def parse_folder(path):
			"""Return shuffled full paths of every file under ``path``."""
			output = [os.path.join(path, file_name)
						for file_name in os.listdir(path)]
			random.shuffle(output)
			return output

		train_features = tf_data_utils.train_input_fn(
									parse_folder(FLAGS.train_file),
									_decode_record, name_to_features, params)
		train_dict = model_train_fn(train_features, [], tf.estimator.ModeKeys.TRAIN)

		# NOTE(review): this eval pipeline is never run directly (run_eval
		# builds its own), but it is kept so variable creation/reuse matches
		# the original graph.
		eval_features = tf_data_utils.eval_input_fn(
										parse_folder(FLAGS.dev_file),
										_decode_record, name_to_features, params)
		eval_dict = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL)

		model_io_fn.set_saver()
		
		init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
		sess.run(init_op)

		# Make sure all workers start from identical weights.
		sess.run(hvd.broadcast_global_variables(0))
		
		def eval_fn(op_dict):
			"""Drain the eval iterator, accumulate probabilities/labels and
			print accuracy plus macro/micro F1; returns the accumulated dict."""
			i = 0
			eval_total_dict = {}
			# Pre-bind so an immediately-empty dataset cannot raise NameError below.
			eval_result = {}
			while True:
				try:
					eval_result = sess.run(op_dict)
					for key in eval_result:
						if key in ["probabilities", "label_ids"]:
							eval_total_dict.setdefault(key, []).extend(eval_result[key])
					i += 1
				except tf.errors.OutOfRangeError:
					print("End of dataset")
					break

			# Scalar metrics: keep the value from the last batch only.
			for key in eval_result:
				if key not in ["probabilities", "label_ids"]:
					eval_total_dict[key] = eval_result[key]

			label_id = eval_total_dict["label_ids"]
			label = np.argmax(np.array(eval_total_dict["probabilities"]), axis=-1)

			macro_f1 = f1_score(label_id, label, average="macro")
			micro_f1 = f1_score(label_id, label, average="micro")
			accuracy = accuracy_score(label_id, label)

			print("test accuracy {} macro_f1 score {} micro_f1 {} masked_lm_accuracy {} sentence_f {}".format(accuracy, 
																		macro_f1,  micro_f1, 
																		eval_total_dict["masked_lm_accuracy"],
																		eval_total_dict["sentence_f"]))
			return eval_total_dict

		def run_eval(steps):
			"""Build a fresh eval pipeline, run it, and pickle results on rank 0."""
			import _pickle as pkl
			eval_features = tf_data_utils.eval_input_fn(
										parse_folder(FLAGS.dev_file),
										_decode_record, name_to_features, params)
			eval_dict = model_eval_fn(eval_features, [], tf.estimator.ModeKeys.EVAL)
			sess.run(tf.local_variables_initializer())
			eval_finial_dict = eval_fn(eval_dict)
			if hvd.rank() == 0:
				with open(FLAGS.model_output+"/eval_dict_{}.pkl".format(steps), "wb") as fwobj:
					pkl.dump(eval_finial_dict, fwobj)
			return eval_finial_dict
		
		def train_fn(op_dict):
			"""Training loop: accumulate losses and periodically eval/checkpoint."""
			i = 0
			cnt = 0
			loss_dict = {}
			monitoring_train = []
			monitoring_eval = []
			while True:
				try:
					train_result = sess.run(op_dict)
					for key in train_result:
						if key == "train_op":
							continue
						if np.isnan(train_result[key]):
							# BUGFIX: the original printed the undefined name
							# ``train_loss`` here, raising NameError on NaN.
							# Note: this break only skips the remaining keys
							# for this step; training continues (original behavior).
							print(train_result[key], "get nan loss")
							break
						if key in loss_dict:
							loss_dict[key] += train_result[key]
						else:
							loss_dict[key] = train_result[key]

					i += 1
					cnt += 1

					if np.mod(i, num_storage_steps) == 0:
						string = ""
						for key in loss_dict:
							tmp = key + " " + str(loss_dict[key]/cnt) + "\t"
							string += tmp
						print(string)
						# BUGFIX: append a snapshot; the original appended the
						# live dict, which is zeroed below, so every stored
						# entry ended up all zeros.
						monitoring_train.append(dict(loss_dict))

						eval_finial_dict = run_eval(int(i/num_storage_steps))
						monitoring_eval.append(eval_finial_dict)

						for key in loss_dict:
							loss_dict[key] = 0.0
						if hvd.rank() == 0:
							model_io_fn.save_model(sess, FLAGS.model_output+"/model_{}.ckpt".format(int(i/num_storage_steps)))
							print("==successful storing model=={}".format(int(i/num_storage_steps)))
						cnt = 0

				except tf.errors.OutOfRangeError:
					if hvd.rank() == 0:
						import _pickle as pkl
						with open(FLAGS.model_output+"/monitoring.pkl", "wb") as fwobj:
							pkl.dump({"train":monitoring_train,
								"eval":monitoring_eval}, fwobj)

					break
		print("===========begin to train============")
		train_fn(train_dict)
		if hvd.rank() == 0:
			model_io_fn.save_model(sess, FLAGS.model_output+"/model.ckpt")
			print("===========begin to eval============")
			eval_finial_dict = run_eval("final")