Example 1: export a trained classifier Estimator as a SavedModel, building the raw serving placeholders by hand.
def export_model(FLAGS,
				init_checkpoint,
				checkpoint_dir,
				export_dir,
				**kargs):

	config = model_config_parser(FLAGS)
	opt_config = Bunch({})
	anneal_config = Bunch({})
	model_io_config = Bunch({"fix_lm":False})

	# with tf.gfile.Open(FLAGS.label_id, "r") as frobj:
	# 	label_dict = json.load(frobj)

	num_classes = int(FLAGS.num_classes)

	def get_receiver_features():
		receiver_tensors = {
			"input_ids":tf.placeholder(tf.int32, [None, FLAGS.max_length], name='input_ids'),
			"segment_ids":tf.placeholder(tf.int32, [None, FLAGS.max_length], name='segment_ids'),
			"input_mask":tf.placeholder(tf.int32, [None, FLAGS.max_length], name='input_mask'),
			"input_ori_ids":tf.placeholder(tf.int32, [None, FLAGS.max_length], name='input_ori_ids'),
			"context":tf.placeholder(tf.int32, [None, None], name='context'),
		}
		return receiver_tensors

	def serving_input_receiver_fn():
		receiver_features = get_receiver_features()
		print(receiver_features, "==input receiver_features==")
		input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(receiver_features)()
		return input_fn

	
	model_fn = model_fn_builder(config, num_classes, init_checkpoint, 
										model_reuse=None, 
										load_pretrained="yes",
										opt_config=opt_config,
										model_io_config=model_io_config,
										exclude_scope="",
										not_storage_params=[],
										target=kargs.get("input_target", ""),
										output_type="estimator",
										checkpoint_dir=checkpoint_dir,
										num_storage_steps=100,
										task_index=0,
										anneal_config=anneal_config,
										**kargs)

	estimator = tf.estimator.Estimator(
				model_fn=model_fn,
				model_dir=checkpoint_dir)

	export_dir = estimator.export_savedmodel(export_dir, 
									serving_input_receiver_fn,
									checkpoint_path=init_checkpoint)
	print("===Succeeded in exporting saved model==={}".format(export_dir))
Example 2: the same export flow, but the receiver features come from data_interface_server(FLAGS) and num_classes is read from the label map.
def export_model(FLAGS, init_checkpoint, checkpoint_dir, export_dir, **kargs):

    config = model_config_parser(FLAGS)
    opt_config = Bunch({})
    anneal_config = Bunch({})
    model_io_config = Bunch({"fix_lm": False})

    with tf.gfile.Open(FLAGS.label_id, "r") as frobj:
        label_dict = json.load(frobj)

    num_classes = len(label_dict["id2label"])

    def serving_input_receiver_fn():
        receiver_features = data_interface_server(FLAGS)
        print(receiver_features, "==input receiver_features==")
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
            receiver_features)()
        return input_fn

    model_fn_builder = model_fn_interface(FLAGS)
    model_fn = model_fn_builder(config,
                                num_classes,
                                init_checkpoint,
                                model_reuse=None,
                                load_pretrained=FLAGS.load_pretrained,
                                opt_config=opt_config,
                                model_io_config=model_io_config,
                                exclude_scope="",
                                not_storage_params=[],
                                target=kargs.get("input_target", ""),
                                output_type="estimator",
                                checkpoint_dir=checkpoint_dir,
                                num_storage_steps=100,
                                task_index=0,
                                anneal_config=anneal_config,
                                **kargs)

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=checkpoint_dir)

    export_dir = estimator.export_savedmodel(export_dir,
                                             serving_input_receiver_fn,
                                             checkpoint_path=init_checkpoint)
    print("===Succeeded in exporting saved model==={}".format(export_dir))
Example 3: export variant with a hard-coded binary num_classes.
def export_model(FLAGS, init_checkpoint, checkpoint_dir, export_dir, **kargs):

    config = model_config_parser(FLAGS)
    opt_config = Bunch({})
    anneal_config = Bunch({})
    model_io_config = Bunch({"fix_lm": False})

    def serving_input_receiver_fn():
        receiver_features = data_interface_server(FLAGS)
        print(receiver_features, "==input receiver_features==")
        input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
            receiver_features)()
        return input_fn

    model_fn = model_fn_builder(config,
                                2,
                                init_checkpoint,
                                model_reuse=None,
                                load_pretrained=FLAGS.load_pretrained,
                                opt_config=opt_config,
                                model_io_config=model_io_config,
                                exclude_scope="",
                                not_storage_params=[],
                                target=kargs.get("input_target", ""),
                                output_type="estimator",
                                checkpoint_dir=checkpoint_dir,
                                num_storage_steps=100,
                                task_index=0,
                                anneal_config=anneal_config,
                                **kargs)

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir=checkpoint_dir)

    export_dir = estimator.export_savedmodel(export_dir,
                                             serving_input_receiver_fn,
                                             checkpoint_path=init_checkpoint)
    print("===Succeeded in exporting saved model==={}".format(export_dir))
Example 4: multi-task train/eval loop driven directly through a MonitoredTrainingSession (output_type="sess").
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target,
                  init_checkpoint, train_file, dev_file, checkpoint_dir,
                  is_debug, **kargs):

    graph = tf.Graph()
    with graph.as_default():
        import json

        # config = model_config_parser(FLAGS)

        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch
        else:
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        multi_task_config = Bunch(
            json.load(tf.gfile.Open(FLAGS.multi_task_config)))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)

        num_storage_steps = int(train_size / FLAGS.batch_size)

        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)

        if is_debug == "0":
            num_storage_steps = 190
            num_eval_steps = 100
            num_train_steps = 200
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".
              format(num_train_steps, num_eval_steps, num_storage_steps))

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": FLAGS.init_lr,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "worker_count": worker_count,
            "opt_type": FLAGS.opt_type,
            "is_chief": is_chief,
            "train_op": kargs.get("train_op", "adam"),
            "decay": kargs.get("decay", "no"),
            "warmup": kargs.get("warmup", "no"),
            "grad_clip": kargs.get("grad_clip", "global_norm"),
            "clip_norm": kargs.get("clip_norm", 1.0)
        })

        anneal_config = Bunch({
            "initial_value": 1.0,
            "num_train_steps": num_train_steps
        })

        model_io_config = Bunch({"fix_lm": False})

        if FLAGS.opt_type == "hvd" and hvd:
            checkpoint_dir = checkpoint_dir if task_index == 0 else None
        else:
            checkpoint_dir = checkpoint_dir
        print("==checkpoint_dir==", checkpoint_dir, is_chief)

        model_config_dict = {}
        num_labels_dict = {}
        init_checkpoint_dict = {}
        load_pretrained_dict = {}
        exclude_scope_dict = {}
        not_storage_params_dict = {}
        target_dict = {}
        task_type_dict = {}
        model_type_lst = []
        label_dict = {}

        for task_type in FLAGS.multi_task_type.split(","):
            print("==task type==", task_type)
            model_config_dict[task_type] = model_config_parser(
                Bunch(multi_task_config[task_type]))
            num_labels_dict[task_type] = multi_task_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"])
            load_pretrained_dict[task_type] = multi_task_config[task_type][
                "load_pretrained"]
            exclude_scope_dict[task_type] = multi_task_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = multi_task_config[task_type][
                "not_storage_params"]
            target_dict[task_type] = multi_task_config[task_type]["target"]
            task_type_dict[task_type] = multi_task_config[task_type][
                "task_type"]
            label_dict[task_type] = json.load(
                tf.gfile.Open(
                    os.path.join(FLAGS.buckets,
                                 multi_task_config[task_type]["label_id"])))

        model_train_fn = multitask_model_fn(
            model_config_dict,
            num_labels_dict,
            task_type_dict,
            init_checkpoint_dict,
            load_pretrained_dict=load_pretrained_dict,
            opt_config=opt_config,
            model_io_config=model_io_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            output_type="sess",
            checkpoint_dir=checkpoint_dir,
            num_storage_steps=num_storage_steps,
            anneal_config=anneal_config,
            task_layer_reuse=None,
            model_type_lst=model_type_lst,
            **kargs)

        eval_model_fn = {}

        for task_type in FLAGS.multi_task_type.split(","):
            eval_task_type_dict = {}
            model_config_dict[task_type] = model_config_parser(
                Bunch(multi_task_config[task_type]))
            num_labels_dict[task_type] = multi_task_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"])
            load_pretrained_dict[task_type] = multi_task_config[task_type][
                "load_pretrained"]
            exclude_scope_dict[task_type] = multi_task_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = multi_task_config[task_type][
                "not_storage_params"]
            target_dict[task_type] = multi_task_config[task_type]["target"]
            eval_task_type_dict[task_type] = multi_task_config[task_type][
                "task_type"]

            eval_model_fn[task_type] = multitask_model_fn(
                model_config_dict,
                num_labels_dict,
                eval_task_type_dict,
                init_checkpoint_dict,
                load_pretrained_dict=load_pretrained_dict,
                opt_config=opt_config,
                model_io_config=model_io_config,
                exclude_scope_dict=exclude_scope_dict,
                not_storage_params_dict=not_storage_params_dict,
                target_dict=target_dict,
                output_type="sess",
                checkpoint_dir=checkpoint_dir,
                num_storage_steps=num_storage_steps,
                anneal_config=anneal_config,
                task_layer_reuse=True,
                model_type_lst=model_type_lst,
                multi_task_config=multi_task_config,
                **kargs)

        print("==succeeded in building model==")

        def eval_metric_fn(features, eval_op_dict, task_type):
            logits = eval_op_dict["logits"][task_type]
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["{}_label_ids".format(task_type)], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

            return {
                "accuracy": accuracy,
                "loss": eval_op_dict["loss"][task_type],
                "pred_label": pred_label,
                "label_ids": features["{}_label_ids".format(task_type)]
            }

        def train_metric_fn(features, train_op_dict):
            return train_op_dict

        name_to_features = data_interface(FLAGS, multi_task_config,
                                          FLAGS.multi_task_type.split(","))

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example.
			"""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t

            return example

        def _decode_batch_record(record, name_to_features):
            example = tf.parse_example(record, name_to_features)
            return example

        params = Bunch({})
        params.epoch = epoch
        params.batch_size = FLAGS.batch_size

        if kargs.get("parse_type", "parse_single") == "parse_single":

            train_file_lst = [
                multi_task_config[task_type]["train_result_file"]
                for task_type in FLAGS.multi_task_type.split(",")
            ]

            print(train_file_lst)

            train_features = tf_data_utils.train_input_fn(
                train_file_lst,
                _decode_record,
                name_to_features,
                params,
                if_shard=FLAGS.if_shard,
                worker_count=worker_count,
                task_index=task_index)

            eval_features_dict = {}
            for task_type in FLAGS.multi_task_type.split(","):
                name_to_features = data_interface(
                    FLAGS, {task_type: multi_task_config[task_type]})
                eval_features_dict[task_type] = tf_data_utils.eval_input_fn(
                    multi_task_config[task_type]["dev_result_file"],
                    _decode_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)

        elif kargs.get("parse_type", "parse_single") == "parse_batch":

            train_file_lst = [
                multi_task_config[task_type]["train_result_file"]
                for task_type in FLAGS.multi_task_type.split(",")
            ]
            train_file_path_lst = [
                os.path.join(FLAGS.buckets, train_file)
                for train_file in train_file_lst
            ]

            train_features = tf_data_utils.train_batch_input_fn(
                train_file_path_lst,
                _decode_batch_record,
                name_to_features,
                params,
                if_shard=FLAGS.if_shard,
                worker_count=worker_count,
                task_index=task_index)

            eval_features_dict = {}
            for task_type in FLAGS.multi_task_type.split(","):
                name_to_features = data_interface(
                    FLAGS, {task_type: multi_task_config[task_type]},
                    [task_type])

                dev_file_path = os.path.join(
                    FLAGS.buckets,
                    multi_task_config[task_type]["dev_result_file"])
                eval_features_dict[
                    task_type] = tf_data_utils.eval_batch_input_fn(
                        dev_file_path,
                        _decode_batch_record,
                        name_to_features,
                        params,
                        if_shard=FLAGS.if_shard,
                        worker_count=worker_count,
                        task_index=task_index)

        train_op_dict = model_train_fn(train_features, [],
                                       tf.estimator.ModeKeys.TRAIN)
        train_dict = train_metric_fn(train_features, train_op_dict["train"])

        eval_dict = {}
        for task_type in eval_features_dict:
            eval_features = eval_features_dict[task_type]
            eval_op_dict = eval_model_fn[task_type](eval_features, [],
                                                    tf.estimator.ModeKeys.EVAL)
            eval_dict_tmp = eval_metric_fn(eval_features, eval_op_dict["eval"],
                                           task_type)
            eval_dict[task_type] = eval_dict_tmp

        print("==succeeded in building data and model==")

        print(train_op_dict)

        def task_eval(eval_dict, sess, eval_total_dict):
            eval_result = sess.run(eval_dict)
            for key in eval_result:
                if key not in eval_total_dict:
                    if key in ["pred_label", "label_ids"]:
                        eval_total_dict[key] = []
                        eval_total_dict[key].extend(eval_result[key])
                    if key in ["accuracy", "loss"]:
                        eval_total_dict[key] = 0.0
                        eval_total_dict[key] += eval_result[key]
                else:
                    if key in ["pred_label", "label_ids"]:
                        eval_total_dict[key].extend(eval_result[key])
                    if key in ["accuracy", "loss"]:
                        eval_total_dict[key] += eval_result[key]

        def task_metric(eval_dict, label_dict):
            label_id = eval_dict["label_ids"]
            pred_label = eval_dict["pred_label"]

            label_dict_id = sorted(list(label_dict["id2label"].keys()))

            print(len(label_id), len(pred_label), len(set(label_id)))

            accuracy = accuracy_score(label_id, pred_label)
            print("==accuracy==", accuracy)
            if len(label_dict["id2label"]) < 10:
                result = classification_report(label_id,
                                               pred_label,
                                               target_names=[
                                                   label_dict["id2label"][key]
                                                   for key in label_dict_id
                                               ],
                                               digits=4)
                print(result, task_index)
                eval_total_dict["classification_report"] = result
                print("==classification report==")

        def eval_fn(eval_dict, sess):
            i = 0
            total_accuracy = 0
            eval_total_dict = {}
            for task_type in eval_dict:
                eval_total_dict[task_type] = {}
            while True:
                try:
                    for task_type in eval_dict:
                        task_eval(eval_dict[task_type], sess,
                                  eval_total_dict[task_type])

                    i += 1
                    if np.mod(i, num_eval_steps) == 0:
                        break
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break

            for task_type in eval_total_dict:
                task_metric(eval_total_dict[task_type], label_dict[task_type])
            return eval_total_dict

        def train_fn(train_op_dict, sess):
            i = 0
            cnt = 0
            loss_dict = {}
            monitoring_train = []
            monitoring_eval = []
            while True:
                try:
                    [train_result] = sess.run([train_op_dict])
                    for key in train_result:
                        if key == "train_op":
                            continue
                        else:
                            if key == "loss":
                                for task_type in train_result[key]:
                                    loss_dict[task_type][
                                        "loss"] += train_result[key][task_type]
                            else:
                                try:
                                    if np.isnan(train_result[key]):
                                        print(train_loss, "get nan loss")
                                        break
                                    else:
                                        if key in loss_dict:
                                            loss_dict[key] += train_result[key]
                                        else:
                                            loss_dict[key] = train_result[key]
                                except:
                                    continue

                    i += 1
                    cnt += 1

                    if np.mod(i, num_storage_steps) == 0:
                        string = ""
                        for key in loss_dict:
                            tmp = key + " " + str(loss_dict[key] / cnt) + "\t"
                            string += tmp
                        print(string)
                        monitoring_train.append(dict(loss_dict))

                        eval_finial_dict = eval_fn(eval_dict, sess)
                        monitoring_eval.append(eval_finial_dict)

                        for key in loss_dict:
                            loss_dict[key] = 0.0
                        cnt = 0

                    if is_debug == "0":
                        if i == num_train_steps:
                            break

                except tf.errors.OutOfRangeError:
                    print("==Succeeded in training model==")
                    break
            return {"eval": monitoring_eval, "train": monitoring_train}

        print("start training")

        hooks = []
        hooks.extend(train_op_dict["hooks"])
        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            print("==create monitored training session==", FLAGS.opt_type,
                  is_chief)
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=kargs.get("sess_config", sess_config),
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        elif FLAGS.opt_type == "pai_soar" and pai:
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            sess = tf.train.MonitoredTrainingSession(
                master=target,
                is_chief=is_chief,
                config=kargs.get("sess_config", sess_config),
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)
        elif FLAGS.opt_type == "hvd" and hvd:
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            sess_config.gpu_options.allow_growth = False
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            sess = tf.train.MonitoredTrainingSession(
                checkpoint_dir=checkpoint_dir,
                hooks=hooks,
                config=sess_config,
                save_checkpoint_steps=num_storage_steps)
        else:
            print("==single sess==")
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            sess = tf.train.MonitoredTrainingSession(
                config=sess_config,
                hooks=hooks,
                checkpoint_dir=checkpoint_dir,
                save_checkpoint_steps=num_storage_steps)

        print("==begin to train and eval==")
        monitoring_info = train_fn(train_dict, sess)

        if task_index == 0:
            start_time = time.time()
            print("===========begin to eval============")
            eval_finial_dict = eval_fn(eval_dict, sess)
            end_time = time.time()
            print("==total forward time==", end_time - start_time)
Example 5: distillation training through tf.estimator with distribution-strategy support.
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target,
                  init_checkpoint, train_file, dev_file, checkpoint_dir,
                  is_debug, **kargs):

    graph = tf.Graph()
    with graph.as_default():
        import json

        # config = json.load(open(FLAGS.config_file, "r"))

        # config = Bunch(config)
        # config.use_one_hot_embeddings = True
        # config.scope = "bert"
        # config.dropout_prob = 0.1
        # config.label_type = "single_label"

        # config.model = FLAGS.model_type

        config = model_config_parser(FLAGS)

        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            print("==number of gpus==", kargs.get('num_gpus', 1))
            train_size = int(FLAGS.train_size / worker_count /
                             kargs.get('num_gpus', 1))
            # train_size = int(FLAGS.train_size)
            epoch = FLAGS.epoch
        else:
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        init_lr = FLAGS.init_lr

        distillation_dict = json.load(tf.gfile.Open(FLAGS.distillation_config))
        distillation_config = Bunch(
            json.load(tf.gfile.Open(FLAGS.multi_task_config)))

        warmup_ratio = config.get('warmup', 0.1)

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        if config.get('ln_type', 'postln') == 'postln':
            num_warmup_steps = int(num_train_steps * warmup_ratio)
        elif config.get('ln_type', 'postln') == 'preln':
            num_warmup_steps = 0
        else:
            num_warmup_steps = int(num_train_steps * warmup_ratio)
        print('==num warmup steps==', num_warmup_steps)

        num_storage_steps = min([int(train_size / FLAGS.batch_size), 10000])
        if num_storage_steps <= 100:
            num_storage_steps = 500

        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)

        if is_debug == "0":
            num_storage_steps = 2
            num_eval_steps = 10
            num_train_steps = 10
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".
              format(num_train_steps, num_eval_steps, num_storage_steps))

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============",
              kargs.get('num_gpus', 1), '==number of gpus==')

        if worker_count * kargs.get("num_gpus", 1) >= 2:
            clip_norm_scale = 1.0
            lr_scale = 0.8
        else:
            clip_norm_scale = 1.0
            lr_scale = 1.0
        lr = init_lr * worker_count * kargs.get("num_gpus", 1) * lr_scale
        if lr >= 1e-3:
            lr = 1e-3
        print('==init lr==', lr)

        opt_config = Bunch({
            "init_lr": lr,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "worker_count": worker_count,
            "gpu_count": worker_count * kargs.get("num_gpus", 1),
            "opt_type": FLAGS.opt_type,
            "is_chief": is_chief,
            "train_op": kargs.get("train_op", "adam"),
            "decay": kargs.get("decay", "no"),
            "warmup": kargs.get("warmup", "no"),
            "clip_norm": config.get("clip_norm", 1.0),
            "grad_clip": config.get("grad_clip", "global_norm"),
            "epoch": FLAGS.epoch,
            "strategy": FLAGS.distribution_strategy
        })

        anneal_config = Bunch({
            "initial_value": 1.0,
            "num_train_steps": num_train_steps
        })

        model_io_config = Bunch({"fix_lm": False})
        model_io_fn = model_io.ModelIO(model_io_config)

        num_classes = FLAGS.num_classes

        if FLAGS.opt_type == "hvd" and hvd:
            checkpoint_dir = checkpoint_dir if task_index == 0 else None
        elif FLAGS.opt_type == "all_reduce":
            checkpoint_dir = checkpoint_dir
        elif FLAGS.opt_type == "collective_reduce":
            checkpoint_dir = checkpoint_dir if task_index == 0 else None
        elif FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            checkpoint_dir = checkpoint_dir if task_index == 0 else None
        print("==checkpoint_dir==", checkpoint_dir, is_chief)

        model_config_dict = {}
        num_labels_dict = {}
        init_checkpoint_dict = {}
        load_pretrained_dict = {}
        exclude_scope_dict = {}
        not_storage_params_dict = {}
        target_dict = {}

        for task_type in FLAGS.multi_task_type.split(","):
            print("==task type==", task_type)
            model_config_dict[task_type] = model_config_parser(
                Bunch(distillation_config[task_type]))
            print(task_type, distillation_config[task_type],
                  '=====task model config======')
            num_labels_dict[task_type] = distillation_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets,
                distillation_config[task_type]["init_checkpoint"])
            load_pretrained_dict[task_type] = distillation_config[task_type][
                "load_pretrained"]
            exclude_scope_dict[task_type] = distillation_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = distillation_config[
                task_type]["not_storage_params"]
            target_dict[task_type] = distillation_config[task_type]["target"]

        model_fn = distillation_model_fn(
            model_config_dict,
            num_labels_dict,
            init_checkpoint_dict,
            load_pretrained_dict,
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            output_type="estimator",
            distillation_config=distillation_dict,
            **kargs)

        name_to_features = data_interface(FLAGS)

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example.
			"""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t

            return example

        def _decode_batch_record(record, name_to_features):
            example = tf.parse_example(record, name_to_features)
            # for name in list(example.keys()):
            # 	t = example[name]
            # 	if t.dtype == tf.int64:
            # 		t = tf.to_int32(t)
            # 	example[name] = t

            return example

        params = Bunch({})
        params.epoch = FLAGS.epoch
        params.batch_size = FLAGS.batch_size

        if kargs.get("run_config", None):
            if kargs.get("parse_type", "parse_single") == "parse_single":
                train_features = lambda: tf_data_utils.all_reduce_train_input_fn(
                    train_file,
                    _decode_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)
                eval_features = lambda: tf_data_utils.all_reduce_eval_input_fn(
                    dev_file,
                    _decode_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)
            elif kargs.get("parse_type", "parse_single") == "parse_batch":
                print("==apply parse example==")
                train_features = lambda: tf_data_utils.all_reduce_train_batch_input_fn(
                    train_file,
                    _decode_batch_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)
                eval_features = lambda: tf_data_utils.all_reduce_eval_batch_input_fn(
                    dev_file,
                    _decode_batch_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)

        else:
            train_features = lambda: tf_data_utils.train_input_fn(
                train_file,
                _decode_record,
                name_to_features,
                params,
                if_shard=FLAGS.if_shard,
                worker_count=worker_count,
                task_index=task_index)

            eval_features = lambda: tf_data_utils.eval_input_fn(
                dev_file,
                _decode_record,
                name_to_features,
                params,
                if_shard=FLAGS.if_shard,
                worker_count=worker_count,
                task_index=task_index)

        train_hooks = []
        eval_hooks = []

        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)
        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            print("==no need for hook==")
        elif FLAGS.opt_type == "pai_soar" and pai:
            print("no need for hook")
        elif FLAGS.opt_type == "hvd" and hvd:
            sess_config.gpu_options.allow_growth = True
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            print("==no need fo hook==")
        else:
            print("==no need for hooks==")

        if kargs.get("run_config", None):
            run_config = kargs.get("run_config", None)
            run_config = run_config.replace(
                save_checkpoints_steps=num_storage_steps)
            print("==run config==", run_config.save_checkpoints_steps)
        else:
            run_config = tf.estimator.RunConfig(
                model_dir=checkpoint_dir,
                save_checkpoints_steps=num_storage_steps,
                session_config=sess_config)

        if kargs.get("profiler", "profiler") == "profiler":
            if checkpoint_dir:
                hooks = tf.train.ProfilerHook(
                    save_steps=100,
                    save_secs=None,
                    output_dir=os.path.join(checkpoint_dir, "profiler"),
                )
                train_hooks.append(hooks)
                print("==add profiler hooks==")

        model_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                                 model_dir=checkpoint_dir,
                                                 config=run_config)

        train_begin_time = time.time()
        tf.logging.info("==training distribution_strategy=={}".format(
            kargs.get("distribution_strategy", "MirroredStrategy")))
        if kargs.get("distribution_strategy",
                     "MirroredStrategy") == "MirroredStrategy":
            print("==apply single machine multi-card training==")

            train_spec = tf.estimator.TrainSpec(input_fn=train_features,
                                                max_steps=num_train_steps)

            eval_spec = tf.estimator.EvalSpec(input_fn=eval_features,
                                              steps=num_eval_steps)

            model_estimator.train(input_fn=train_features,
                                  max_steps=num_train_steps,
                                  hooks=train_hooks)
            # tf.estimator.train(model_estimator, train_spec)

            train_end_time = time.time()
            print("==training time==", train_end_time - train_begin_time)
            tf.logging.info("==training time=={}".format(train_end_time -
                                                         train_begin_time))
            eval_results = model_estimator.evaluate(input_fn=eval_features,
                                                    steps=num_eval_steps)
            print(eval_results)

        elif kargs.get("distribution_strategy", "MirroredStrategy") in [
                "ParameterServerStrategy", "CollectiveAllReduceStrategy"
        ]:
            print("==apply multi-machine machine multi-card training==")
            try:
                print(os.environ['TF_CONFIG'], "==tf_run_config==")
            except KeyError:
                print("==no TF_CONFIG set==")
            train_spec = tf.estimator.TrainSpec(input_fn=train_features,
                                                max_steps=num_train_steps)

            eval_spec = tf.estimator.EvalSpec(input_fn=eval_features,
                                              steps=num_eval_steps)

            # tf.estimator.train(model_estimator, train_spec) # tf 1.12 doesn't need evaluate

            tf.estimator.train_and_evaluate(model_estimator, train_spec,
                                            eval_spec)
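The distributed branches above all reduce to tf.estimator.train_and_evaluate over a TrainSpec/EvalSpec pair. A minimal sketch of that flow under TF 1.x, with a toy model_fn and an in-memory input_fn standing in for the repo's distillation model and TFRecord pipelines:

import tensorflow as tf

def toy_model_fn(features, labels, mode):
    # tiny linear classifier; a stand-in, not the repo's model_fn
    logits = tf.layers.dense(features["x"], 2)
    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    train_op = None
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer(1e-3).minimize(
            loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

def toy_input_fn():
    ds = tf.data.Dataset.from_tensor_slices(
        ({"x": tf.random_uniform([64, 4])}, tf.zeros([64], dtype=tf.int32)))
    return ds.repeat().batch(8)

estimator = tf.estimator.Estimator(toy_model_fn, model_dir="/tmp/toy_est")
train_spec = tf.estimator.TrainSpec(input_fn=toy_input_fn, max_steps=50)
eval_spec = tf.estimator.EvalSpec(input_fn=toy_input_fn, steps=5)
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)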
Example 6: multi-task training through tf.estimator with distribution-strategy support.
def train_eval_fn(FLAGS, worker_count, task_index, is_chief, target,
                  init_checkpoint, train_file, dev_file, checkpoint_dir,
                  is_debug, **kargs):

    graph = tf.Graph()
    with graph.as_default():
        import json

        # config = model_config_parser(FLAGS)

        print(FLAGS.train_size)

        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch
        else:
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        multi_task_config = Bunch(
            json.load(tf.gfile.Open(FLAGS.multi_task_config)))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)

        num_storage_steps = int(train_size / FLAGS.batch_size)

        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)

        if is_debug == "0":
            num_storage_steps = 190
            num_eval_steps = 100
            num_train_steps = 200
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".
              format(num_train_steps, num_eval_steps, num_storage_steps))

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": kargs.get("init_lr", 5e-5) / worker_count,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "worker_count": worker_count,
            "opt_type": FLAGS.opt_type,
            "is_chief": is_chief,
            "train_op": kargs.get("train_op", "adam"),
            "decay": kargs.get("decay", "no"),
            "warmup": kargs.get("warmup", "no"),
            "grad_clip": kargs.get("grad_clip", "global_norm"),
            "clip_norm": kargs.get("clip_norm", 1.0)
        })

        anneal_config = Bunch({
            "initial_value": 1.0,
            "num_train_steps": num_train_steps
        })

        model_io_config = Bunch({"fix_lm": False})

        if FLAGS.opt_type == "hvd" and hvd:
            checkpoint_dir = checkpoint_dir if task_index == 0 else None
        else:
            checkpoint_dir = checkpoint_dir
        print("==checkpoint_dir==", checkpoint_dir, is_chief)

        model_config_dict = {}
        num_labels_dict = {}
        init_checkpoint_dict = {}
        load_pretrained_dict = {}
        exclude_scope_dict = {}
        not_storage_params_dict = {}
        target_dict = {}
        task_type_dict = {}
        model_type_lst = []
        label_dict = {}

        for task_type in FLAGS.multi_task_type.split(","):
            print("==task type==", task_type)
            model_config_dict[task_type] = model_config_parser(
                Bunch(multi_task_config[task_type]))
            num_labels_dict[task_type] = multi_task_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"])
            load_pretrained_dict[task_type] = multi_task_config[task_type][
                "load_pretrained"]
            exclude_scope_dict[task_type] = multi_task_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = multi_task_config[task_type][
                "not_storage_params"]
            target_dict[task_type] = multi_task_config[task_type]["target"]
            task_type_dict[task_type] = multi_task_config[task_type][
                "task_type"]
            label_dict[task_type] = json.load(
                tf.gfile.Open(
                    os.path.join(FLAGS.buckets,
                                 multi_task_config[task_type]["label_id"])))

        model_fn = multitask_model_fn(
            model_config_dict,
            num_labels_dict,
            task_type_dict,
            init_checkpoint_dict,
            load_pretrained_dict=load_pretrained_dict,
            opt_config=opt_config,
            model_io_config=model_io_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            output_type="estimator",
            checkpoint_dir=checkpoint_dir,
            num_storage_steps=num_storage_steps,
            anneal_config=anneal_config,
            task_layer_reuse=None,
            model_type_lst=model_type_lst,
            **kargs)

        print("==succeeded in building model==")

        name_to_features = data_interface(FLAGS, multi_task_config,
                                          FLAGS.multi_task_type.split(","))

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example.
			"""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t

            return example

        def _decode_batch_record(record, name_to_features):
            example = tf.parse_example(record, name_to_features)
            return example

        params = Bunch({})
        params.epoch = epoch
        params.batch_size = FLAGS.batch_size

        if kargs.get("parse_type", "parse_single") == "parse_single":

            train_file_lst = [
                multi_task_config[task_type]["train_result_file"]
                for task_type in FLAGS.multi_task_type.split(",")
            ]

            print(train_file_lst)

            train_features = lambda: tf_data_utils.all_reduce_multitask_train_input_fn(
                train_file_lst,
                _decode_record,
                name_to_features,
                params,
                if_shard=FLAGS.if_shard,
                worker_count=worker_count,
                task_index=task_index)

        elif kargs.get("parse_type", "parse_single") == "parse_batch":

            train_file_lst = [
                multi_task_config[task_type]["train_result_file"]
                for task_type in FLAGS.multi_task_type.split(",")
            ]
            train_file_path_lst = [
                os.path.join(FLAGS.buckets, train_file)
                for train_file in train_file_lst
            ]

            print(train_file_path_lst)
            print("==apply train batch==")

            train_features = lambda: tf_data_utils.all_reduce_train_batch_input_fn(
                train_file_path_lst,
                _decode_batch_record,
                name_to_features,
                params,
                if_shard=FLAGS.if_shard,
                worker_count=worker_count,
                task_index=task_index)

        print("==succeeded in building data and model==")
        print("start training")

        train_hooks = []

        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)
        if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
            print("==no need for hook==")
        elif FLAGS.opt_type == "pai_soar" and pai:
            print("no need for hook")
        elif FLAGS.opt_type == "hvd" and hvd:
            sess_config.gpu_options.allow_growth = True
            sess_config.gpu_options.visible_device_list = str(hvd.local_rank())
            print("==no need fo hook==")
        else:
            print("==no need for hooks==")

        if kargs.get("run_config", None):
            run_config = kargs.get("run_config", None)
            run_config = run_config.replace(
                save_checkpoints_steps=num_storage_steps)
            print("==run config==", run_config.save_checkpoints_steps)
        else:
            run_config = tf.estimator.RunConfig(
                model_dir=checkpoint_dir,
                save_checkpoints_steps=num_storage_steps,
                session_config=sess_config)

        if kargs.get("profiler", "profiler") == "profiler":
            if checkpoint_dir:
                hooks = tf.train.ProfilerHook(
                    save_steps=100,
                    save_secs=None,
                    output_dir=os.path.join(checkpoint_dir, "profiler"),
                )
                train_hooks.append(hooks)
                print("==add profiler hooks==")

        model_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                                 config=run_config)

        train_begin_time = time.time()
        tf.logging.info("==training distribution_strategy=={}".format(
            kargs.get("distribution_strategy", "MirroredStrategy")))
        if kargs.get("distribution_strategy",
                     "MirroredStrategy") == "MirroredStrategy":
            print("==apply single machine multi-card training==")
            model_estimator.train(input_fn=train_features,
                                  max_steps=num_train_steps,
                                  hooks=train_hooks)

            train_end_time = time.time()
            print("==training time==", train_end_time - train_begin_time)
            tf.logging.info("==training time=={}".format(train_end_time -
                                                         train_begin_time))

        elif kargs.get("distribution_strategy", "MirroredStrategy") in [
                "ParameterServerStrategy", "CollectiveAllReduceStrategy"
        ]:
            print("==apply multi-machine machine multi-card training==")
            try:
                print(os.environ['TF_CONFIG'], "==tf_run_config==")
            except KeyError:
                print("==no TF_CONFIG set==")
            train_spec = tf.estimator.TrainSpec(input_fn=train_features,
                                                max_steps=num_train_steps)

            # note: an eval_features input fn analogous to train_features is
            # assumed here; this snippet only constructs the training inputs
            eval_spec = tf.estimator.EvalSpec(input_fn=eval_features,
                                              steps=num_eval_steps)

            tf.estimator.train_and_evaluate(model_estimator, train_spec,
                                            eval_spec)
            train_end_time = time.time()
            print("==training time==", train_end_time - train_begin_time)
def train_eval_fn(FLAGS, init_checkpoint, train_file, dev_file, checkpoint_dir,
                  **kargs):

    graph = tf.Graph()
    with graph.as_default():
        import json

        config = model_config_parser(FLAGS)

        train_size = int(FLAGS.train_size)
        init_lr = FLAGS.init_lr

        distillation_config = Bunch(
            json.load(tf.gfile.Open(FLAGS.multi_task_config)))

        if FLAGS.use_tpu:
            warmup_ratio = config.get('warmup', 0.1)

            num_train_steps = int(train_size / FLAGS.batch_size * FLAGS.epoch)

            num_warmup_steps = int(num_train_steps * warmup_ratio)

            print('==num warmup steps==', num_warmup_steps)

            print(" model type {}".format(FLAGS.model_type))

            print(num_train_steps, num_warmup_steps, "=============",
                  kargs.get('num_gpus', 1), '==number of gpus==')
            tf.logging.info("***** Running evaluation *****")
            tf.logging.info("***** train steps : %d", num_train_steps)
            max_eval_steps = int(int(FLAGS.eval_size) / FLAGS.batch_size)

            clip_norm_scale = 1.0
            lr_scale = 1.0
            lr = init_lr

            checkpoint_dir = checkpoint_dir

            opt_config = Bunch({
                "init_lr": lr,
                "num_train_steps": num_train_steps,
                "num_warmup_steps": num_warmup_steps,
                "train_op": kargs.get("train_op", "adam"),
                "decay": kargs.get("decay", "no"),
                "warmup": kargs.get("warmup", "no"),
                "clip_norm": config.get("clip_norm", 1.0),
                "grad_clip": config.get("grad_clip", "global_norm"),
                "use_tpu": 1
            })

        else:
            warmup_ratio = config.get('warmup', 0.1)
            worker_count = kargs.get('worker_count', 1)
            task_index = kargs.get('task_index', 0)
            is_chief = kargs.get('is_chief', 0)

            if FLAGS.if_shard == "0":
                train_size = FLAGS.train_size
                epoch = int(FLAGS.epoch / worker_count)
            elif FLAGS.if_shard == "1":
                print("==number of gpus==", kargs.get('num_gpus', 1))
                train_size = int(FLAGS.train_size / worker_count /
                                 kargs.get('num_gpus', 1))
                # train_size = int(FLAGS.train_size)
                epoch = FLAGS.epoch
            else:
                train_size = int(FLAGS.train_size / worker_count)
                epoch = FLAGS.epoch

            num_train_steps = int(train_size / FLAGS.batch_size * epoch)
            if config.get('ln_type', 'postln') == 'postln':
                num_warmup_steps = int(num_train_steps * warmup_ratio)
            elif config.get('ln_type', 'postln') == 'preln':
                num_warmup_steps = 0
            else:
                num_warmup_steps = int(num_train_steps * warmup_ratio)
            print('==num warmup steps==', num_warmup_steps)

            num_storage_steps = min(
                [int(train_size / FLAGS.batch_size), 10000])
            if num_storage_steps <= 100:
                num_storage_steps = 500

            num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)

            print(
                "num_train_steps {}, num_eval_steps {}, num_storage_steps {}".
                format(num_train_steps, num_eval_steps, num_storage_steps))

            print(" model type {}".format(FLAGS.model_type))

            print(num_train_steps, num_warmup_steps, "=============",
                  kargs.get('num_gpus', 1), '==number of gpus==')

            if worker_count * kargs.get("num_gpus", 1) >= 2:
                clip_norm_scale = 1.0
                lr_scale = 0.8
            else:
                clip_norm_scale = 1.0
                lr_scale = 1.0
            lr = init_lr * worker_count * kargs.get("num_gpus", 1) * lr_scale
            if lr >= 1e-3:
                lr = 1e-3
            print('==init lr==', lr)
            if FLAGS.opt_type == "hvd" and hvd:
                checkpoint_dir = checkpoint_dir if task_index == 0 else None
            elif FLAGS.opt_type == "all_reduce":
                checkpoint_dir = checkpoint_dir
            elif FLAGS.opt_type == "collective_reduce":
                checkpoint_dir = checkpoint_dir if task_index == 0 else None
            elif FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
                checkpoint_dir = checkpoint_dir if task_index == 0 else None

            opt_config = Bunch({
                "init_lr": lr,
                "num_train_steps": num_train_steps,
                "num_warmup_steps": num_warmup_steps,
                "worker_count": worker_count,
                "gpu_count": worker_count * kargs.get("num_gpus", 1),
                "opt_type": FLAGS.opt_type,
                "is_chief": is_chief,
                "train_op": kargs.get("train_op", "adam"),
                "decay": kargs.get("decay", "no"),
                "warmup": kargs.get("warmup", "no"),
                "clip_norm": config.get("clip_norm", 1.0),
                "grad_clip": config.get("grad_clip", "global_norm"),
                "epoch": FLAGS.epoch,
                "strategy": FLAGS.distribution_strategy,
                "use_tpu": 0
            })

        model_io_config = Bunch({"fix_lm": False})
        model_io_fn = model_io.ModelIO(model_io_config)

        num_classes = FLAGS.num_classes

        model_config_dict = {}
        num_labels_dict = {}
        init_checkpoint_dict = {}
        load_pretrained_dict = {}
        exclude_scope_dict = {}
        not_storage_params_dict = {}
        target_dict = {}

        for task_type in FLAGS.multi_task_type.split(","):
            print("==task type==", task_type)
            model_config_dict[task_type] = model_config_parser(
                Bunch(distillation_config[task_type]))
            print(task_type, distillation_config[task_type],
                  '=====task model config======')
            num_labels_dict[task_type] = distillation_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets,
                distillation_config[task_type]["init_checkpoint"])
            load_pretrained_dict[task_type] = distillation_config[task_type][
                "load_pretrained"]
            exclude_scope_dict[task_type] = distillation_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = distillation_config[
                task_type]["not_storage_params"]
            target_dict[task_type] = distillation_config[task_type]["target"]

        tf.logging.info("***** use tpu ***** %s", str(FLAGS.use_tpu))
        model_fn = classifier_model_fn_builder(
            model_config_dict,
            num_labels_dict,
            init_checkpoint_dict,
            load_pretrained_dict,
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            use_tpu=FLAGS.use_tpu,
            **kargs)

        if FLAGS.use_tpu:
            from data_generator import tf_data_utils
            estimator = tf.contrib.tpu.TPUEstimator(
                use_tpu=True,
                model_fn=model_fn,
                config=kargs.get('run_config', {}),
                train_batch_size=FLAGS.batch_size,
                eval_batch_size=FLAGS.batch_size)
            tf.logging.info("****** do train ******* %s", str(FLAGS.do_train))
            if FLAGS.do_train:
                tf.logging.info("***** Running training *****")
                tf.logging.info("  Batch size = %d", FLAGS.batch_size)
                input_features = tf_data_utils.electra_input_fn_builder(
                    train_file,
                    FLAGS.max_length,
                    FLAGS.max_predictions_per_seq,
                    True,
                    num_cpu_threads=4)
                estimator.train(input_fn=input_features,
                                max_steps=num_train_steps)
            else:
                tf.logging.info("***** Running evaluation *****")
                tf.logging.info("  Batch size = %d", FLAGS.batch_size)
                eval_input_fn = tf_data_utils.electra_input_fn_builder(
                    input_files=dev_file,
                    max_seq_length=FLAGS.max_length,
                    max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                    is_training=False)
                tf.logging.info("***** Begining Running evaluation *****")
                result = estimator.evaluate(input_fn=eval_input_fn,
                                            steps=max_eval_steps)
                output_eval_file = os.path.join(checkpoint_dir,
                                                "eval_results.txt")
                with tf.gfile.GFile(output_eval_file, "w") as writer:
                    tf.logging.info("***** Eval results *****")
                    for key in sorted(result.keys()):
                        tf.logging.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))
        else:
            from data_generator import distributed_tf_data_utils as tf_data_utils
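            # TFRecord schema for the masked-LM pretraining features; lengths
            # must match the values used when the records were written.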
            name_to_features = {
                "input_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
                "input_mask": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
                "segment_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
                "input_ori_ids": tf.FixedLenFeature([FLAGS.max_length], tf.int64),
                "masked_lm_positions": tf.FixedLenFeature(
                    [FLAGS.max_predictions_per_seq], tf.int64),
                "masked_lm_ids": tf.FixedLenFeature(
                    [FLAGS.max_predictions_per_seq], tf.int64),
                "masked_lm_weights": tf.FixedLenFeature(
                    [FLAGS.max_predictions_per_seq], tf.float32),
                "next_sentence_labels": tf.FixedLenFeature([], tf.int64),
            }

            def _decode_record(record, name_to_features):
                """Decodes a record to a TensorFlow example.
				"""
                example = tf.parse_single_example(record, name_to_features)

                # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
                # So cast all int64 to int32.
                for name in list(example.keys()):
                    t = example[name]
                    if t.dtype == tf.int64:
                        t = tf.to_int32(t)
                    example[name] = t

                return example

            def _decode_batch_record(record, name_to_features):
                example = tf.parse_example(record, name_to_features)
                # for name in list(example.keys()):
                # 	t = example[name]
                # 	if t.dtype == tf.int64:
                # 		t = tf.to_int32(t)
                # 	example[name] = t

                return example

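            # Hyper-parameters threaded through the tf_data_utils input fns.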
            params = Bunch({})
            params.epoch = FLAGS.epoch
            params.batch_size = FLAGS.batch_size

            if kargs.get("run_config", None):
                if kargs.get("parse_type", "parse_single") == "parse_single":
                    train_features = lambda: tf_data_utils.all_reduce_train_input_fn(
                        train_file,
                        _decode_record,
                        name_to_features,
                        params,
                        if_shard=FLAGS.if_shard,
                        worker_count=worker_count,
                        task_index=task_index)
                    eval_features = lambda: tf_data_utils.all_reduce_eval_input_fn(
                        dev_file,
                        _decode_record,
                        name_to_features,
                        params,
                        if_shard=FLAGS.if_shard,
                        worker_count=worker_count,
                        task_index=task_index)
                elif kargs.get("parse_type", "parse_single") == "parse_batch":
                    print("==apply parse example==")
                    train_features = lambda: tf_data_utils.all_reduce_train_batch_input_fn(
                        train_file,
                        _decode_batch_record,
                        name_to_features,
                        params,
                        if_shard=FLAGS.if_shard,
                        worker_count=worker_count,
                        task_index=task_index)
                    eval_features = lambda: tf_data_utils.all_reduce_eval_batch_input_fn(
                        dev_file,
                        _decode_batch_record,
                        name_to_features,
                        params,
                        if_shard=FLAGS.if_shard,
                        worker_count=worker_count,
                        task_index=task_index)

            else:
                train_features = lambda: tf_data_utils.train_input_fn(
                    train_file,
                    _decode_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)

                eval_features = lambda: tf_data_utils.eval_input_fn(
                    dev_file,
                    _decode_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)

            train_hooks = []
            eval_hooks = []

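            # Backend-specific session config: Horovod pins each process to
            # its local GPU; the other opt_types keep the defaults.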
            sess_config = tf.ConfigProto(allow_soft_placement=False,
                                         log_device_placement=False)
            if FLAGS.opt_type == "ps" or FLAGS.opt_type == "ps_sync":
                print("==no need for hook==")
            elif FLAGS.opt_type == "pai_soar" and pai:
                print("no need for hook")
            elif FLAGS.opt_type == "hvd" and hvd:
                sess_config.gpu_options.allow_growth = True
                sess_config.gpu_options.visible_device_list = str(
                    hvd.local_rank())
                print("==no need fo hook==")
            else:
                print("==no need for hooks==")

            if kargs.get("run_config", None):
                run_config = kargs.get("run_config", None)
                run_config = run_config.replace(
                    save_checkpoints_steps=num_storage_steps)
                print("==run config==", run_config.save_checkpoints_steps)
            else:
                run_config = tf.estimator.RunConfig(
                    model_dir=checkpoint_dir,
                    save_checkpoints_steps=num_storage_steps,
                    session_config=sess_config)

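            # Optionally attach a ProfilerHook that writes profiling traces
            # every 100 steps under <checkpoint_dir>/profiler.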
            if kargs.get("profiler", "profiler") == "profiler":
                if checkpoint_dir:
                    hooks = tf.train.ProfilerHook(
                        save_steps=100,
                        save_secs=None,
                        output_dir=os.path.join(checkpoint_dir, "profiler"),
                    )
                    train_hooks.append(hooks)
                    print("==add profiler hooks==")

            model_estimator = tf.estimator.Estimator(model_fn=model_fn,
                                                     model_dir=checkpoint_dir,
                                                     config=run_config)

            train_begin_time = time.time()
            tf.logging.info("==training distribution_strategy=={}".format(
                kargs.get("distribution_strategy", "MirroredStrategy")))
            if kargs.get("distribution_strategy",
                         "MirroredStrategy") == "MirroredStrategy":
                print("==apply single machine multi-card training==")

                train_spec = tf.estimator.TrainSpec(input_fn=train_features,
                                                    max_steps=num_train_steps)

                eval_spec = tf.estimator.EvalSpec(input_fn=eval_features,
                                                  steps=num_eval_steps)

                model_estimator.train(input_fn=train_features,
                                      max_steps=num_train_steps,
                                      hooks=train_hooks)
                # tf.estimator.train(model_estimator, train_spec)

                train_end_time = time.time()
                print("==training time==", train_end_time - train_being_time)
                tf.logging.info("==training time=={}".format(train_end_time -
                                                             train_being_time))
                eval_results = model_estimator.evaluate(input_fn=eval_features,
                                                        steps=num_eval_steps)
                print(eval_results)

            elif kargs.get("distribution_strategy", "MirroredStrategy") in [
                    "ParameterServerStrategy", "CollectiveAllReduceStrategy"
            ]:
                print("==apply multi-machine machine multi-card training==")
                try:
                    print(os.environ['TF_CONFIG'], "==tf_run_config==")
                except KeyError:
                    print("==no TF_CONFIG set==")
                train_spec = tf.estimator.TrainSpec(input_fn=train_features,
                                                    max_steps=num_train_steps)

                eval_spec = tf.estimator.EvalSpec(input_fn=eval_features,
                                                  steps=num_eval_steps)

                # tf.estimator.train(model_estimator, train_spec) # tf 1.12 doesn't need evaluate

                tf.estimator.train_and_evaluate(model_estimator, train_spec,
                                                eval_spec)
Example no. 8
0
def eval_fn(FLAGS, worker_count, task_index, is_chief, target, init_checkpoint,
            train_file, dev_file, checkpoint_dir, is_debug, **kargs):

    graph = tf.Graph()
    with graph.as_default():
        import json

        # config = model_config_parser(FLAGS)

        if FLAGS.if_shard == "0":
            train_size = FLAGS.train_size
            epoch = int(FLAGS.epoch / worker_count)
        elif FLAGS.if_shard == "1":
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch
        else:
            train_size = int(FLAGS.train_size / worker_count)
            epoch = FLAGS.epoch

        multi_task_config = Bunch(
            json.load(tf.gfile.Open(FLAGS.multi_task_config)))

        num_train_steps = int(train_size / FLAGS.batch_size * epoch)
        num_warmup_steps = int(num_train_steps * 0.1)

        num_storage_steps = int(train_size / FLAGS.batch_size)

        num_eval_steps = int(FLAGS.eval_size / FLAGS.batch_size)

        if is_debug == "0":
            num_storage_steps = 190
            num_eval_steps = 100
            num_train_steps = 200
        print("num_train_steps {}, num_eval_steps {}, num_storage_steps {}".
              format(num_train_steps, num_eval_steps, num_storage_steps))

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============")

        opt_config = Bunch({
            "init_lr": kargs.get("init_lr", 1e-5) / worker_count,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "worker_count": worker_count,
            "opt_type": FLAGS.opt_type,
            "is_chief": is_chief,
            "train_op": kargs.get("train_op", "adam"),
            "decay": kargs.get("decay", "no"),
            "warmup": kargs.get("warmup", "no"),
            "grad_clip": kargs.get("grad_clip", "global_norm"),
            "clip_norm": kargs.get("clip_norm", 1.0)
        })

        anneal_config = Bunch({
            "initial_value": 1.0,
            "num_train_steps": num_train_steps
        })

        model_io_config = Bunch({"fix_lm": False})

        if FLAGS.opt_type == "hvd" and hvd:
            checkpoint_dir = checkpoint_dir if task_index == 0 else None
        else:
            checkpoint_dir = checkpoint_dir
        print("==checkpoint_dir==", checkpoint_dir, is_chief)

        model_config_dict = {}
        num_labels_dict = {}
        init_checkpoint_dict = {}
        load_pretrained_dict = {}
        exclude_scope_dict = {}
        not_storage_params_dict = {}
        target_dict = {}
        task_type_dict = {}
        model_type_lst = []
        label_dict = {}

        eval_model_fn = {}

        for task_type in FLAGS.multi_task_type.split(","):
            eval_task_type_dict = {}
            model_config_dict[task_type] = model_config_parser(
                Bunch(multi_task_config[task_type]))
            num_labels_dict[task_type] = multi_task_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets, multi_task_config[task_type]["init_checkpoint"])
            print(
                init_checkpoint_dict[task_type], task_type, "===",
                os.path.join(FLAGS.buckets,
                             multi_task_config[task_type]["init_checkpoint"]))
            load_pretrained_dict[task_type] = multi_task_config[task_type][
                "load_pretrained"]
            exclude_scope_dict[task_type] = multi_task_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = multi_task_config[task_type][
                "not_storage_params"]
            target_dict[task_type] = multi_task_config[task_type]["target"]
            eval_task_type_dict[task_type] = multi_task_config[task_type][
                "task_type"]
            label_dict[task_type] = json.load(
                tf.gfile.Open(
                    os.path.join(FLAGS.buckets,
                                 multi_task_config[task_type]["label_id"])))

            eval_model_fn[task_type] = multitask_model_fn(
                model_config_dict,
                num_labels_dict,
                eval_task_type_dict,
                init_checkpoint_dict,
                load_pretrained_dict=load_pretrained_dict,
                opt_config=opt_config,
                model_io_config=model_io_config,
                exclude_scope_dict=exclude_scope_dict,
                not_storage_params_dict=not_storage_params_dict,
                target_dict=target_dict,
                output_type="sess",
                checkpoint_dir=checkpoint_dir,
                num_storage_steps=num_storage_steps,
                anneal_config=anneal_config,
                task_layer_reuse=False,
                model_type_lst=model_type_lst,
                **kargs)

        print(init_checkpoint_dict, "==init_checkpoint==")

        print("==succeeded in building model==")

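        # Per-task eval metrics: argmax predictions over the task logits plus
        # batch accuracy against the task's label ids.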
        def eval_metric_fn(features, eval_op_dict, task_type):
            logits = eval_op_dict["logits"][task_type]
            print(logits.get_shape(), "===logits shape===")
            pred_label = tf.argmax(logits, axis=-1, output_type=tf.int32)
            prob = tf.nn.softmax(logits)
            correct = tf.equal(
                tf.cast(pred_label, tf.int32),
                tf.cast(features["{}_label_ids".format(task_type)], tf.int32))
            accuracy = tf.reduce_mean(tf.cast(correct, tf.float32))

            return {
                "accuracy": accuracy,
                "loss": eval_op_dict["loss"][task_type],
                "pred_label": pred_label,
                "label_ids": features["{}_label_ids".format(task_type)]
            }

        name_to_features = data_interface(FLAGS, multi_task_config,
                                          FLAGS.multi_task_type.split(","))

        def _decode_record(record, name_to_features):
            """Decodes a record to a TensorFlow example.
			"""
            example = tf.parse_single_example(record, name_to_features)

            # tf.Example only supports tf.int64, but the TPU only supports tf.int32.
            # So cast all int64 to int32.
            for name in list(example.keys()):
                t = example[name]
                if t.dtype == tf.int64:
                    t = tf.to_int32(t)
                example[name] = t

            return example

        def _decode_batch_record(record, name_to_features):
            example = tf.parse_example(record, name_to_features)
            return example

        params = Bunch({})
        params.epoch = 0
        params.batch_size = FLAGS.batch_size

        if kargs.get("parse_type", "parse_single") == "parse_single":

            eval_features_dict = {}
            for task_type in FLAGS.multi_task_type.split(","):
                name_to_features = data_interface(
                    FLAGS, {task_type: multi_task_config[task_type]},
                    [task_type])
                eval_features_dict[task_type] = tf_data_utils.eval_input_fn(
                    multi_task_config[task_type]["dev_result_file"],
                    _decode_record,
                    name_to_features,
                    params,
                    if_shard=FLAGS.if_shard,
                    worker_count=worker_count,
                    task_index=task_index)

        elif kargs.get("parse_type", "parse_single") == "parse_batch":

            eval_features_dict = {}
            for task_type in FLAGS.multi_task_type.split(","):
                name_to_features = data_interface(
                    FLAGS, {task_type: multi_task_config[task_type]},
                    [task_type])

                dev_file_path = os.path.join(
                    FLAGS.buckets,
                    multi_task_config[task_type]["test_result_file"])
                eval_features_dict[
                    task_type] = tf_data_utils.eval_batch_input_fn(
                        dev_file_path,
                        _decode_batch_record,
                        name_to_features,
                        params,
                        if_shard=FLAGS.if_shard,
                        worker_count=worker_count,
                        task_index=task_index)

        eval_dict = {}
        for task_type in eval_features_dict:
            eval_features = eval_features_dict[task_type]
            eval_op_dict = eval_model_fn[task_type](eval_features, [],
                                                    tf.estimator.ModeKeys.EVAL)
            eval_dict_tmp = eval_metric_fn(eval_features, eval_op_dict["eval"],
                                           task_type)
            eval_dict[task_type] = eval_dict_tmp
        print(eval_dict)

        print("==succeeded in building data and model==")

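        # Run the per-task eval ops for one batch and accumulate the results:
        # predictions/labels are concatenated, accuracy/loss are summed.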
        def task_eval(eval_dict, sess, eval_total_dict):
            eval_result = sess.run(eval_dict)
            for key in eval_result:
                if key not in eval_total_dict:
                    if key in ["pred_label", "label_ids"]:
                        eval_total_dict[key] = []
                        eval_total_dict[key].extend(eval_result[key])
                    if key in ["accuracy", "loss"]:
                        eval_total_dict[key] = 0.0
                        eval_total_dict[key] += eval_result[key]
                else:
                    if key in ["pred_label", "label_ids"]:
                        eval_total_dict[key].extend(eval_result[key])
                    if key in ["accuracy", "loss"]:
                        eval_total_dict[key] += eval_result[key]

        def task_metric(eval_dict, label_dict, eval_total_dict):
            label_id = eval_dict["label_ids"]
            pred_label = eval_dict["pred_label"]

            label_dict_id = sorted(list(label_dict["id2label"].keys()))

            print(len(label_id), len(pred_label), len(set(label_id)))

            accuracy = accuracy_score(label_id, pred_label)
            print("==accuracy==", accuracy)
            if len(label_dict["id2label"]) < 10:
                result = classification_report(label_id,
                                               pred_label,
                                               target_names=[
                                                   label_dict["id2label"][key]
                                                   for key in label_dict_id
                                               ],
                                               digits=4)
                print(result, task_index)
                eval_total_dict["classification_report"] = result
                print("==classification report==")

        def eval_fn(eval_dict, sess):
            i = 0
            eval_total_dict = {}
            for task_type in eval_dict:
                eval_total_dict[task_type] = {}
            while True:
                try:
                    for task_type in eval_dict:
                        task_eval(eval_dict[task_type], sess,
                                  eval_total_dict[task_type])
                    i += 1
                except tf.errors.OutOfRangeError:
                    print("End of dataset")
                    break

            for task_type in eval_total_dict:
                task_metric(eval_total_dict[task_type], label_dict[task_type],
                            eval_total_dict[task_type])
            return eval_total_dict

        print("start evaluating")
        sess_config = tf.ConfigProto(allow_soft_placement=False,
                                     log_device_placement=False)

        sess = tf.Session(config=sess_config)
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())

        sess.run(init_op)

        print("==begin to train and eval==")
        start_time = time.time()
        eval_final_dict = eval_fn(eval_dict, sess)
        end_time = time.time()
        print("==forward time==", end_time - start_time)
        return eval_final_dict
Example no. 9
0
def train_eval_fn(FLAGS, init_checkpoint, train_file, dev_file, checkpoint_dir,
                  **kargs):

    graph = tf.Graph()
    with graph.as_default():
        import json

        config = model_config_parser(FLAGS)

        train_size = int(FLAGS.train_size)
        init_lr = FLAGS.init_lr

        warmup_ratio = config.get('warmup', 0.1)

        num_train_steps = int(train_size / FLAGS.batch_size * FLAGS.epoch)

        num_warmup_steps = int(num_train_steps * warmup_ratio)

        print('==num warmup steps==', num_warmup_steps)

        print(" model type {}".format(FLAGS.model_type))

        print(num_train_steps, num_warmup_steps, "=============",
              kargs.get('num_gpus', 1), '==number of gpus==')
        tf.logging.info("***** Running evaluation *****")
        tf.logging.info("***** train steps : %d", num_train_steps)
        max_eval_steps = int(int(FLAGS.eval_size) / FLAGS.batch_size)

        clip_norm_scale = 1.0
        lr_scale = 1.0
        lr = init_lr

        opt_config = Bunch({
            "init_lr": lr,
            "num_train_steps": num_train_steps,
            "num_warmup_steps": num_warmup_steps,
            "train_op": kargs.get("train_op", "adam"),
            "decay": kargs.get("decay", "no"),
            "warmup": kargs.get("warmup", "no"),
            "clip_norm": config.get("clip_norm", 1.0),
            "grad_clip": config.get("grad_clip", "global_norm"),
            "use_tpu": 1
        })

        model_io_config = Bunch({"fix_lm": False})
        model_io_fn = model_io.ModelIO(model_io_config)

        num_classes = FLAGS.num_classes

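        # FLAGS.model_type selects the model_fn builder; unrecognized types
        # fall back to the bert mlm builder.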
        if FLAGS.model_type == 'bert':
            model_fn_builder = classifier_model_fn_builder
            tf.logging.info("****** bert mlm ******")
        elif FLAGS.model_type == 'bert_seq':
            model_fn_builder = classifier_seq_model_fn_builder
            tf.logging.info("****** bert seq as gpt ******")
        elif FLAGS.model_type == 'gated_cnn_seq':
            model_fn_builder = gatedcnn_model_fn_builder
            tf.logging.info("****** gated cnn seq ******")
        else:
            model_fn_builder = classifier_model_fn_builder
            tf.logging.info("****** bert mlm ******")

        model_fn = model_fn_builder(
            config,
            num_classes,
            init_checkpoint,
            model_reuse=None,
            load_pretrained=FLAGS.load_pretrained,
            model_io_config=model_io_config,
            opt_config=opt_config,
            model_io_fn=model_io_fn,
            # exclude_scope=kargs.get('exclude_scope', ""),
            not_storage_params=[],
            target=kargs.get("input_target", ""),
            num_train_steps=num_train_steps,
            use_tpu=True,
            **kargs)

        estimator = tf.contrib.tpu.TPUEstimator(
            use_tpu=True,
            model_fn=model_fn,
            config=kargs.get('run_config', {}),
            train_batch_size=FLAGS.batch_size,
            eval_batch_size=FLAGS.batch_size)
        tf.logging.info("****** do train ******* %s", str(FLAGS.do_train))

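        # FLAGS.random_generator picks the input pipeline: "1" electra-style
        # random sampling, "2" bert seq, "3" bert mnli, "4" gated-cnn
        # pretraining; anything else falls back to the fixed-sample builder.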
        if FLAGS.random_generator == "1":
            input_fn_builder = tf_data_utils.electra_input_fn_builder
            tf.logging.info(
                "***** Running random sample input fn builder *****")
        elif FLAGS.random_generator == "2":
            input_fn_builder = tf_data_utils.bert_seq_input_fn_builder
            tf.logging.info("***** Running bert seq input fn builder *****")
        elif FLAGS.random_generator == "3":
            input_fn_builder = tf_data_utils.bert_mnli_input_fn_builder
            tf.logging.info("***** Running bert seq input fn builder *****")
        elif FLAGS.random_generator == "4":
            input_fn_builder = tf_data_utils.gatedcnn_pretrain_input_fn_builder_v1
            tf.logging.info("***** Running gatedcnn input fn builder *****")
        else:
            input_fn_builder = tf_data_utils.input_fn_builder
            tf.logging.info(
                "***** Running fixed sample input fn builder *****")

        if FLAGS.do_train:
            tf.logging.info("***** Running training *****")
            tf.logging.info("  Batch size = %d", FLAGS.batch_size)
            input_features = input_fn_builder(train_file,
                                              FLAGS.max_length,
                                              FLAGS.max_predictions_per_seq,
                                              True,
                                              num_cpu_threads=4)
            estimator.train(input_fn=input_features, max_steps=num_train_steps)
        else:
            tf.logging.info("***** Running evaluation *****")
            tf.logging.info("  Batch size = %d", FLAGS.batch_size)
            eval_input_fn = input_fn_builder(
                input_files=dev_file,
                max_seq_length=FLAGS.max_length,
                max_predictions_per_seq=FLAGS.max_predictions_per_seq,
                is_training=False)
            tf.logging.info("***** Begining Running evaluation *****")
            result = estimator.evaluate(input_fn=eval_input_fn,
                                        steps=max_eval_steps)
            output_eval_file = os.path.join(checkpoint_dir, "eval_results.txt")
            with tf.gfile.GFile(output_eval_file, "w") as writer:
                tf.logging.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    tf.logging.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
Example no. 10
0
def export_api(FLAGS, init_checkpoint, checkpoint_dir, export_dir, **kargs):
    graph = tf.Graph()
    with graph.as_default():
        import json

        config = model_config_parser(FLAGS)

        train_size = int(FLAGS.train_size)
        init_lr = FLAGS.init_lr

        distillation_config = Bunch(
            json.load(tf.gfile.Open(FLAGS.multi_task_config)))

        opt_config = Bunch({})  # model_fn_builder below expects an opt_config
        model_io_config = Bunch({"fix_lm": False})
        model_io_fn = model_io.ModelIO(model_io_config)

        num_classes = FLAGS.num_classes

        model_config_dict = {}
        num_labels_dict = {}
        init_checkpoint_dict = {}
        load_pretrained_dict = {}
        exclude_scope_dict = {}
        not_storage_params_dict = {}
        target_dict = {}

        for task_type in FLAGS.multi_task_type.split(","):
            print("==task type==", task_type)
            model_config_dict[task_type] = model_config_parser(
                Bunch(distillation_config[task_type]))
            print(task_type, distillation_config[task_type],
                  '=====task model config======')
            num_labels_dict[task_type] = distillation_config[task_type][
                "num_labels"]
            init_checkpoint_dict[task_type] = os.path.join(
                FLAGS.buckets,
                distillation_config[task_type]["init_checkpoint"])
            load_pretrained_dict[task_type] = "yes"
            exclude_scope_dict[task_type] = distillation_config[task_type][
                "exclude_scope"]
            not_storage_params_dict[task_type] = distillation_config[
                task_type]["not_storage_params"]
            target_dict[task_type] = distillation_config[task_type]["target"]

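        # NOTE: receiver_features is left empty below, so the exported
        # SavedModel would accept no inputs. A working export needs one
        # tf.placeholder entry per serving input; names/shapes here are
        # illustrative, e.g.:
        #   receiver_features = {
        #       "input_ids": tf.placeholder(
        #           tf.int32, [None, FLAGS.max_length], name="input_ids"),
        #   }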
        def serving_input_receiver_fn():
            receiver_features = {}
            print(receiver_features, "==input receiver_features==")
            input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(
                receiver_features)()
            return input_fn

        model_fn = model_fn_builder(
            model_config_dict,
            num_labels_dict,
            init_checkpoint_dict,
            load_pretrained_dict,
            model_io_config=model_io_config,
            opt_config=opt_config,
            exclude_scope_dict=exclude_scope_dict,
            not_storage_params_dict=not_storage_params_dict,
            target_dict=target_dict,
            use_tpu=FLAGS.use_tpu,
            **kargs)

        estimator = tf.estimator.Estimator(model_fn=model_fn,
                                           model_dir=checkpoint_dir)

        export_dir = estimator.export_savedmodel(
            export_dir,
            serving_input_receiver_fn,
            checkpoint_path=init_checkpoint)
        print("===Succeeded in exporting saved model==={}".format(export_dir))