def main(_):
    """Publish the domain-related flags as module globals, then train and
    evaluate an ``Application``.

    Sets the module-level globals ``num_domains`` (count of comma-separated
    entries in ``aFLAGS.domains``), ``layer_indexes`` (ints parsed from
    ``aFLAGS.layer_indexes``), ``domain_weight`` and ``do_sent_pair``
    (presumably read by ``Application`` elsewhere in the file — verify).
    ODPS table readers are used on PAI TensorFlow builds, CSV readers
    otherwise.
    """
    global num_domains, layer_indexes, domain_weight, do_sent_pair

    num_domains = len(aFLAGS.domains.split(","))
    layer_indexes = list(map(int, aFLAGS.layer_indexes.split(",")))
    domain_weight = aFLAGS.domain_weight
    do_sent_pair = aFLAGS.do_sent_pair

    app = Application()

    # PAI-flavoured TensorFlow builds read ODPS tables; anything else reads CSV.
    reader_cls = OdpsTableReader if "PAI" in tf.__version__ else CSVReader
    train_reader = reader_cls(input_glob=app.train_input_fp,
                              is_training=True,
                              input_schema=app.input_schema,
                              batch_size=app.train_batch_size)
    eval_reader = reader_cls(input_glob=app.eval_input_fp,
                             is_training=False,
                             input_schema=app.input_schema,
                             batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
    def test_train_and_eval(self):
        """Build CSV train/eval readers from the application's config and run
        ``run_train_and_evaluate`` with them."""
        application = Application()

        def _csv_reader(input_fp, training, batch_size):
            # Every reader shares the application's input schema.
            return CSVReader(input_glob=input_fp,
                             is_training=training,
                             input_schema=application.input_schema,
                             batch_size=batch_size)

        application.run_train_and_evaluate(
            train_reader=_csv_reader(application.train_input_fp, True,
                                     application.train_batch_size),
            eval_reader=_csv_reader(application.eval_input_fp, False,
                                    application.eval_batch_size))
Exemple #3
0
    def test_dist_preprocess(self):
        """Run a CSV reader -> 'google-bert-base-zh' preprocessor -> CSV writer
        pipeline through a ``ProcessExecutor``, then close the writer."""
        app = Serialization()
        # Passed to ProcessExecutor; presumably the capacity of each
        # inter-stage queue — TODO confirm.
        queue_size = 1

        proc_executor = ProcessExecutor(queue_size)

        # NOTE(review): the stages are chained through the executor's queues,
        # so the ORDER of get_output_queue()/get_input_queue() calls below is
        # significant. Presumably get_output_queue() creates a new queue and
        # get_input_queue() returns the most recently created one, so each
        # stage consumes what the previous stage emits — verify against
        # ProcessExecutor.
        reader = CSVReader(input_glob=app.preprocess_input_fp,
                           input_schema=app.input_schema,
                           is_training=False,
                           batch_size=app.preprocess_batch_size,
                           output_queue=proc_executor.get_output_queue())

        proc_executor.add(reader)

        # Middle stage: the 'google-bert-base-zh' preprocessor with 7 threads,
        # reading the reader's queue and writing to a fresh one.
        feature_process = preprocessors.get_preprocessor(
            'google-bert-base-zh',
            thread_num=7,
            input_queue=proc_executor.get_input_queue(),
            output_queue=proc_executor.get_output_queue())
        proc_executor.add(feature_process)
        # Final stage: write the preprocessor's output to CSV.
        writer = CSVWriter(output_glob=app.preprocess_output_fp,
                           output_schema=app.output_schema,
                           input_queue=proc_executor.get_input_queue())

        proc_executor.add(writer)
        # Presumably run() starts all registered stages and wait() blocks
        # until they finish — confirm in ProcessExecutor.
        proc_executor.run()
        proc_executor.wait()
        writer.close()
def do_train():
    """Train and evaluate a ``TextClassification`` app on the fly.

    The app is configured from the module-level ``config_json`` in
    ``"train_and_evaluate_on_the_fly"`` mode; both splits are read from CSV.
    """
    app = TextClassification(
        user_defined_config=Config(mode="train_and_evaluate_on_the_fly",
                                   config_json=config_json))
    readers = {
        "train_reader": CSVReader(input_glob=app.train_input_fp,
                                  is_training=True,
                                  input_schema=app.input_schema,
                                  batch_size=app.train_batch_size),
        "eval_reader": CSVReader(input_glob=app.eval_input_fp,
                                 is_training=False,
                                 input_schema=app.input_schema,
                                 batch_size=app.eval_batch_size),
    }
    app.run_train_and_evaluate(**readers)
    def test_preprocess_for_finetune(self):
        """Serialize the fine-tune CSV input and check the first three rows.

        Each expected row holds four tab-separated fields: comma-joined token
        ids, two comma-joined 0/1 sequences (presumably the input mask and
        the segment ids — confirm against SerializationFinetuneModel), and
        the label.
        """
        app = SerializationFinetuneModel()

        reader = CSVReader(input_glob=app.preprocess_input_fp,
                           is_training=False,
                           input_schema=app.input_schema,
                           batch_size=app.preprocess_batch_size)

        writer = CSVWriter(output_glob=app.preprocess_output_fp,
                           output_schema=app.output_schema)

        app.run_preprocess(reader=reader, writer=writer)

        output_fp = 'output/preprocess_output_for_finetune.csv'
        self.assertTrue(os.path.exists(output_fp))
        # Read through a context manager so the file handle is always closed.
        with open(output_fp, 'r') as f:
            lines = f.readlines()

        expected_rows = [
            "101,1352,1282,671,5709,1446,2990,7583,102,7027,1377,809,2990,5709,1446,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0",
            "101,5709,1446,3118,2898,7770,7188,4873,102,711,784,720,1351,802,2140,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0",
            "101,2769,4638,6010,6009,5709,1446,3118,102,2769,1168,3118,802,2140,2141,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t1",
        ]
        # A short output file fails here with a clear message, not IndexError.
        self.assertGreaterEqual(len(lines), len(expected_rows))
        for line, expected in zip(lines, expected_rows):
            # assertEqual shows both values on mismatch, unlike assertTrue(a == b).
            self.assertEqual(line.strip(), expected)
Exemple #6
0
    def test_train_and_eval(self):
        """Evaluate the checkpoint at ``app.eval_ckpt_path`` on the eval set.

        Despite the name, only ``run_evaluate`` is exercised here.
        """
        app = Application()
        # Evaluation data is read in inference mode (is_training=False), like
        # every other eval reader in this file. A training-mode reader may
        # shuffle/repeat its input (reader-dependent — confirm), which would
        # skew the evaluation pass or keep it from terminating.
        eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.eval_batch_size)

        app.run_evaluate(eval_reader, checkpoint_path=app.eval_ckpt_path)
Exemple #7
0
def main(_):
    """Dispatch on ``FLAGS.mode``.

    * ``"train_and_evaluate_on_the_fly"``: train and evaluate, with ODPS table
      readers on PAI TensorFlow builds and CSV readers otherwise.
    * ``"predict_on_the_fly"``: predict from ``app.predict_checkpoint_path``
      using the matching ODPS or CSV reader/writer pair.

    Any other mode does nothing.
    """
    app = Application()

    if FLAGS.mode == "train_and_evaluate_on_the_fly":
        reader_cls = (OdpsTableReader if "PAI" in tf.__version__
                      else CSVReader)
        app.run_train_and_evaluate(
            train_reader=reader_cls(input_glob=app.train_input_fp,
                                    is_training=True,
                                    input_schema=app.input_schema,
                                    batch_size=app.train_batch_size),
            eval_reader=reader_cls(input_glob=app.eval_input_fp,
                                   is_training=False,
                                   input_schema=app.input_schema,
                                   batch_size=app.eval_batch_size))

    if FLAGS.mode == "predict_on_the_fly":
        on_pai = "PAI" in tf.__version__
        reader_cls = OdpsTableReader if on_pai else CSVReader
        pred_reader = reader_cls(input_glob=app.predict_input_fp,
                                 is_training=False,
                                 input_schema=app.input_schema,
                                 batch_size=app.predict_batch_size)
        # The two writers take different constructor arguments.
        if on_pai:
            pred_writer = OdpsTableWriter(output_glob=app.predict_output_fp,
                                          output_schema=app.output_schema,
                                          slice_id=0,
                                          input_queue=None)
        else:
            pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                                    output_schema=app.output_schema)
        app.run_predict(reader=pred_reader,
                        writer=pred_writer,
                        checkpoint_path=app.predict_checkpoint_path)
Exemple #8
0
    def test_dist_preprocess(self):
        """Run pre-training serialization from CSV input to CSV output."""
        serializer = PretrainSerialization()
        serializer.run_preprocess(
            reader=CSVReader(input_glob=serializer.preprocess_input_fp,
                             is_training=False,
                             input_schema=serializer.input_schema,
                             batch_size=serializer.preprocess_batch_size),
            writer=CSVWriter(output_glob=serializer.preprocess_output_fp,
                             output_schema=serializer.output_schema))
def main(_):
    """Serialize the fine-tune CSV input into TFRecord output."""
    serializer = FinetuneSerialization()
    csv_reader = CSVReader(input_glob=serializer.preprocess_input_fp,
                           is_training=False,
                           input_schema=serializer.input_schema,
                           batch_size=serializer.preprocess_batch_size)
    tfrecord_writer = TFRecordWriter(
        output_glob=serializer.preprocess_output_fp,
        output_schema=serializer.output_schema)
    serializer.run_preprocess(reader=csv_reader, writer=tfrecord_writer)
    def test_predict(self):
        """Predict from ``predict_checkpoint_path``, reading and writing CSV."""
        app = Application()
        io_kwargs = dict(
            reader=CSVReader(input_glob=app.predict_input_fp,
                             is_training=False,
                             input_schema=app.input_schema,
                             batch_size=app.predict_batch_size),
            writer=CSVWriter(output_glob=app.predict_output_fp,
                             output_schema=app.output_schema))
        app.run_predict(checkpoint_path=app.predict_checkpoint_path,
                        **io_kwargs)
Exemple #11
0
def main(_):
    """Train and evaluate ``Application``. ``FLAGS.usePAI`` selects ODPS
    table readers; otherwise CSV readers are used."""
    app = Application()
    reader_factory = OdpsTableReader if FLAGS.usePAI else CSVReader

    def build(input_fp, training, batch_size):
        # Both splits share the application's input schema.
        return reader_factory(input_glob=input_fp,
                              is_training=training,
                              input_schema=app.input_schema,
                              batch_size=batch_size)

    train_reader = build(app.train_input_fp, True, app.train_batch_size)
    eval_reader = build(app.eval_input_fp, False, app.eval_batch_size)

    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
def do_predict():
    """Run on-the-fly prediction with ``TextClassification`` and preview it.

    The app is configured from the module-level ``config_json`` in
    ``"predict_on_the_fly"`` mode and writes predictions through a CSV
    writer. Afterwards ``./data/predict.csv`` (tab-separated, no header) is
    loaded with columns ``true_label`` and ``pred_label_id``. Its first 10
    rows are printed and the DataFrame is returned.

    Returns:
        pandas.DataFrame: the loaded predictions.
    """
    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)
    pred_reader = CSVReader(input_glob=app.predict_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.predict_batch_size)

    pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                            output_schema=app.output_schema)

    app.run_predict(reader=pred_reader, writer=pred_writer,
                    checkpoint_path=app.predict_checkpoint_path)

    # NOTE(review): hard-coded path; presumably the same file as
    # app.predict_output_fp in config_json — confirm.
    pred = pd.read_csv('./data/predict.csv', header=None, delimiter='\t',
                       encoding='utf8')

    pred.columns = ['true_label', 'pred_label_id']

    # A bare `pred.head(10)` only displays inside a notebook cell; in a
    # function its result is discarded, so print it explicitly.
    print(pred.head(10))
    return pred
Exemple #13
0
def main(_):
    """Compute per-example meta-weights for meta fine-tuning.

    1. Run prediction over ``app.predict_input_fp`` and collect every
       example's ``pool_output`` embedding, grouped by "domain<TAB>class".
    2. Average each group's embeddings into a centroid with ``np.mean``.
    3. Score each example with ``compute_weight`` against the centroids.
       Write the example, its domain's index in ``aFLAGS.domains``, its label
       and the weight (rounded to 5 decimals) to ``FLAGS.outputs``. The
       output is an ODPS table when ``FLAGS.usePAI`` is set and a
       tab-separated local file otherwise.

    ``aFLAGS.do_sent_pair`` selects sentence-pair (``text1``/``text2``) or
    single-sentence (``text``) inputs. An example whose domain/label pair is
    not among ``aFLAGS.domains`` x ``aFLAGS.classes`` raises ``KeyError``.
    """
    # load domain and class lists
    domains = aFLAGS.domains.split(",")
    classes = aFLAGS.classes.split(",")
    # Read the flag once from aFLAGS and use it everywhere below, so the
    # model choice, the inference loop and the output layout always agree.
    do_sent_pair = aFLAGS.do_sent_pair

    if do_sent_pair:
        app = SentPairClsApplication()
    else:
        app = SentClsApplication()

    # Text columns carried through to the output, in output order.
    text_keys = ("text1", "text2") if do_sent_pair else ("text",)

    # create empty centroids: one embedding list per "domain\tclass" key
    domain_class_embeddings = {
        domain_name + "\t" + class_name: []
        for domain_name in domains for class_name in classes}

    # for training data
    reader_cls = OdpsTableReader if FLAGS.usePAI else CSVReader
    predict_reader = reader_cls(input_glob=app.predict_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.predict_batch_size)

    # do inference for training data; each collected row is
    # (*texts, domain, label, pool_output)
    temp_output_data = []
    for output in app.run_predict(reader=predict_reader,
                                  checkpoint_path=app.predict_checkpoint_path):
        for i, pool_output in enumerate(output["pool_output"]):
            texts = tuple(output[key][i] for key in text_keys)
            domain = output["domain"][i]
            label = output["label"][i]
            key_name = domain.decode('utf-8') + "\t" + label.decode('utf-8')
            domain_class_embeddings[key_name].append(pool_output)
            temp_output_data.append(texts + (domain, label, pool_output))

    # compute centroids (a key with no examples yields NaN together with a
    # NumPy "Mean of empty slice" RuntimeWarning)
    centroid_embeddings = {
        key_name: np.mean(np.array(embeddings), axis=0)
        for key_name, embeddings in domain_class_embeddings.items()}

    # Weight every example: (texts, domain index as str, label, weight).
    # NOTE(review): `domain` is the raw model-output value (it is .decode()d
    # above) while `domains` holds str. domains.index(domain) and the text
    # joining below only work where those compare equal (e.g. Python 2 byte
    # strings) — confirm the target runtime.
    weighted_rows = []
    for row in temp_output_data:
        texts, domain, label, embeddings = row[:-3], row[-3], row[-2], row[-1]
        weight = compute_weight(domain, label, embeddings,
                                centroid_embeddings)
        weighted_rows.append((texts, str(domains.index(domain)), label,
                              np.around(weight, decimals=5)))

    # output files for meta fine-tune
    if FLAGS.usePAI:
        # write odps tables: texts..., domain index, label, weight
        records = [texts + (domain_idx, label, weight)
                   for texts, domain_idx, label, weight in weighted_rows]
        with tf.python_io.TableWriter(FLAGS.outputs) as writer:
            writer.write(records, list(range(len(text_keys) + 3)))
    else:
        # write to local file: one tab-separated line per example
        with open(FLAGS.outputs, 'w') as f:
            for texts, domain_idx, label, weight in weighted_rows:
                f.write('\t'.join(texts + (domain_idx, label,
                                           weight.astype('str'))) + '\n')

# # (6) Start training
#
# Notebook-exported cells ("# In[N]:" markers): build the config, create the
# application, build CSV readers for both splits, train, then start reading
# the checkpoint list from model_dir.

# In[13]:

# `config_json` is presumably defined in an earlier cell (not visible here).
config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)

# In[14]:

app = Application(user_defined_config=config)

# In[15]:

train_reader = CSVReader(input_glob=app.train_input_fp,
                         is_training=True,
                         input_schema=app.input_schema,
                         batch_size=app.train_batch_size)

# Not used by the statements visible here; presumably consumed when the
# saved checkpoints are evaluated in a later cell — confirm.
eval_reader = CSVReader(input_glob=app.eval_input_fp,
                        is_training=False,
                        input_schema=app.input_schema,
                        batch_size=app.eval_batch_size)

# In[16]:

# Only the training reader is passed; any evaluation happens separately.
app.run_train(reader=train_reader)

# In[17]:

# Collect the checkpoints listed in model_dir/checkpoint.
# NOTE(review): the `with` statement below is truncated in this view.
ckpts = set()
with tf.gfile.GFile(os.path.join(app.config.model_dir, "checkpoint"),
Exemple #15
0
def train_and_evaluate_on_the_fly():
    """Fine-tune a classifier, then report the best checkpoint on the eval set.

    Command-line arguments of the form ``--key=value`` override entries of
    the default ``config_json``. ``--task_name=<TASK>`` fills in the input
    schema, sequence fields, label set and label count for one of the tasks
    in ``_TRAIN_EVAL_TASK_CONFIGS``.

    After training, every checkpoint listed in ``model_dir/checkpoint`` is
    evaluated and the best one is logged. The selection metric is
    ``matthew_corr`` when ``_APP_FLAGS.task_name`` is ``"CoLA"`` and
    ``py_accuracy`` otherwise.

    Raises:
        ValueError: if ``--task_name`` names an unsupported task.
    """
    config_json = {
        "worker_hosts": "localhost",
        "task_index": 1,
        "job_name": "chief",
        "num_gpus": 1,
        "num_workers": 1,
        "preprocess_config": {
            "input_schema": None,
            "sequence_length": 128,
            "first_sequence": None,
            "second_sequence": None,
            "label_name": "label",
            "label_enumerate_values": None,
        },
        "model_config": {
            "pretrain_model_name_or_path": None,
            "num_labels": None
        },
        "train_config": {
            "keep_checkpoint_max": 11,
            "save_steps": None,
            "optimizer_config": {
                "optimizer": "adam",
                "weight_decay_ratio": 0.01,
                "warmup_ratio": 0.1,
            },
            "distribution_config": {
                "distribution_strategy": None,
            }
        },
        "evaluate_config": {
            "eval_batch_size": 8
        }
    }

    _apply_train_eval_argv(config_json, sys.argv[1:])

    config = Config(mode="train_and_evaluate_on_the_fly",
                    config_json=config_json)
    app = Application(user_defined_config=config)
    train_reader = CSVReader(input_glob=app.train_input_fp,
                             is_training=True,
                             input_schema=app.input_schema,
                             batch_size=app.train_batch_size)

    app.run_train(reader=train_reader)

    eval_reader = CSVReader(input_glob=app.eval_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.eval_batch_size)
    ckpts = _list_checkpoint_steps(app.config.model_dir)

    # early stopping: evaluate every saved checkpoint, keep the best one
    if _APP_FLAGS.task_name != "CoLA":
        best_ckpt, best_acc = _select_best_checkpoint(
            app, eval_reader, ckpts, 'py_accuracy')
        tf.logging.info("best ckpt {}, best acc {}".format(
            best_ckpt, best_acc))
    else:
        best_ckpt, best_matthew_corr = _select_best_checkpoint(
            app, eval_reader, ckpts, 'matthew_corr')
        tf.logging.info("best ckpt {}, best matthew_corr {}".format(
            best_ckpt, best_matthew_corr))


# Per-task overrides applied by ``--task_name``:
#   task -> (input_schema, has_second_sequence, label_enumerate_values,
#            num_labels)
# The first sequence is always the "sent1" column; paired tasks also use
# "sent2" as the second sequence.
_TRAIN_EVAL_TASK_CONFIGS = {
    "TNEWS": ("label:str:1,sent1:str:1", False,
              "115,114,108,109,116,110,113,112,102,103,100,101,106,107,104",
              15),
    "AFQMC": ("label:str:1,sent1:str:1,sent2:str:1", True, "0,1", 2),
    "IFLYTEK": ("label:str:1,sent1:str:1", False,
                ",".join([str(idx) for idx in range(119)]), 119),
    "CMNLI": ("label:str:1,sent1:str:1,sent2:str:1", True,
              "entailment,neutral,contradiction", 3),
    "CSL": ("label:str:1,sent1:str:1,sent2:str:1", True, "0,1", 2),
    "QQP": ("idx:str:1,xx1:str:1,xx2:str:1,sent1:str:1,sent2:str:1,label:str:1",
            True, "0,1", 2),
    "SST-2": ("sent1:str:1,label:str:1", False, "0,1", 2),
    "CoLA": ("idx:str:1,label:str:1,xx:str:1,sent1:str:1", False, "0,1", 2),
    "MRPC": ("label:str:1,xx:str:1,xx2:str:1,sent1:str:1,sent2:str:1", True,
             "0,1", 2),
    "RTE": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", True,
            "not_entailment,entailment", 2),
    "BoolQ": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", True,
              "True,False", 2),
    "WiC": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", True,
            "True,False", 2),
    "WSC": ("idx:str:1,sent1:str:1,label:str:1", False, "True,False", 2),
    "CLUEWSC": ("idx:str:1,sent1:str:1,label:str:1", False, "True,False", 2),
    "COPA": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", True, "0,1", 2),
    "CB": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", True,
           "neutral,entailment,contradiction", 3),
}


def _apply_train_eval_argv(config_json, argv):
    """Apply ``--key=value`` arguments from ``argv`` to ``config_json`` in place."""
    for arg in argv:
        # Split on the first '=' only, so values may themselves contain '='.
        key, sep, val = arg.partition("=")
        if not sep:
            # Not a key=value argument (e.g. a bare flag): nothing to map.
            continue
        key = key.replace("--", "")
        if key in ("train_input_fp", "train_batch_size", "model_dir",
                   "num_epochs"):
            config_json['train_config'][key] = val
        elif key == "eval_input_fp":
            config_json['evaluate_config'][key] = val
        elif key in ("learning_rate", "warmup_ratio", "weight_decay_ratio"):
            config_json['train_config']['optimizer_config'][key] = val
        elif key == 'pretrain_model_name_or_path':
            config_json['model_config'][key] = val
        elif key == 'num_gpus':
            config_json[key] = int(val)
        elif key == 'task_name':
            _apply_train_eval_task_config(config_json, val)


def _apply_train_eval_task_config(config_json, task_name):
    """Fill ``config_json`` with the settings registered for ``task_name``."""
    try:
        input_schema, paired, label_values, num_labels = \
            _TRAIN_EVAL_TASK_CONFIGS[task_name]
    except KeyError:
        raise ValueError("Unsupported task_name {!r}; expected one of: {}".format(
            task_name, ", ".join(sorted(_TRAIN_EVAL_TASK_CONFIGS))))
    preprocess_config = config_json['preprocess_config']
    preprocess_config['input_schema'] = input_schema
    preprocess_config['first_sequence'] = "sent1"
    if paired:
        preprocess_config['second_sequence'] = "sent2"
    preprocess_config['label_enumerate_values'] = label_values
    config_json['model_config']['num_labels'] = num_labels


def _list_checkpoint_steps(model_dir):
    """Return the set of global steps named in ``model_dir/checkpoint``.

    Only ``model_checkpoint_path`` / ``all_model_checkpoint_paths`` entries
    are parsed. Other entries in that file (e.g. the float-valued
    ``all_model_checkpoint_timestamps``) are skipped because they are not
    checkpoint paths.
    """
    steps = set()
    with tf.gfile.GFile(os.path.join(model_dir, "checkpoint"),
                        mode='r') as reader:
        for line in reader:
            # Split on the first ':' only, so "oss://..." paths stay intact.
            field, sep, value = line.strip().partition(":")
            if not sep or field.strip() not in ("model_checkpoint_path",
                                                "all_model_checkpoint_paths"):
                continue
            ckpt_name = value.strip().replace("\"", "").split("/")[-1]
            steps.add(int(ckpt_name.replace("model.ckpt-", "")))
    return steps


def _select_best_checkpoint(app, eval_reader, steps, metric_name):
    """Evaluate each checkpoint step in ascending order.

    Returns:
        (best_step, best_value) by ``eval_results[metric_name]``, or
        ``(None, None)`` if ``steps`` is empty. Ties keep the earliest step.
    """
    best_ckpt, best_value = None, None
    for ckpt in sorted(steps):
        checkpoint_path = os.path.join(app.config.model_dir,
                                       "model.ckpt-" + str(ckpt))
        tf.logging.info("checkpoint_path is {}".format(checkpoint_path))
        eval_results = app.run_evaluate(reader=eval_reader,
                                        checkpoint_path=checkpoint_path)
        value = eval_results[metric_name]
        # Start from "no best yet" rather than 0. Matthews correlation can be
        # negative, and a metric of exactly 0 must still select a checkpoint.
        if best_value is None or value > best_value:
            best_ckpt, best_value = ckpt, value
    return best_ckpt, best_value
Exemple #16
0
def _task_entry(input_schema, num_labels, label_enumerate_values=None,
                pair=False):
    """Return ``(preprocess_overrides, num_labels)`` for one known task.

    Every task reads its first text column from ``sent1``; sentence-pair
    tasks (``pair=True``) additionally read ``sent2``.  The
    ``label_enumerate_values`` key is only emitted when a value is given, so
    tasks without one leave the default in ``preprocess_config`` untouched.
    """
    overrides = {"input_schema": input_schema, "first_sequence": "sent1"}
    if pair:
        overrides["second_sequence"] = "sent2"
    if label_enumerate_values is not None:
        overrides["label_enumerate_values"] = label_enumerate_values
    return overrides, num_labels


# Preprocessing / model overrides selected by ``--task_name=<name>``.
_PREDICT_TASK_CONFIGS = {
    "TNEWS": _task_entry(
        "label:str:1,sent1:str:1", 15,
        "115,114,108,109,116,110,113,112,102,103,100,101,106,107,104"),
    "AFQMC": _task_entry(
        "label:str:1,sent1:str:1,sent2:str:1", 2, "0,1", pair=True),
    "IFLYTEK": _task_entry(
        "label:str:1,sent1:str:1", 119,
        ",".join(str(idx) for idx in range(119))),
    "CMNLI": _task_entry(
        "label:str:1,sent1:str:1,sent2:str:1", 3,
        "entailment,neutral,contradiction", pair=True),
    "CSL": _task_entry(
        "label:str:1,sent1:str:1,sent2:str:1", 2, "0,1", pair=True),
    "QQP": _task_entry(
        "idx:str:1,xx1:str:1,xx2:str:1,sent1:str:1,sent2:str:1,label:str:1",
        2, "0,1", pair=True),
    "SST-2": _task_entry("sent1:str:1,label:str:1", 2, "0,1"),
    "CoLA": _task_entry(
        "idx:str:1,label:str:1,xx:str:1,sent1:str:1", 2, "0,1"),
    "MRPC": _task_entry(
        "label:str:1,xx:str:1,xx2:str:1,sent1:str:1,sent2:str:1", 2, "0,1",
        pair=True),
    "RTE": _task_entry(
        "idx:str:1,sent1:str:1,sent2:str:1,label:str:1", 2,
        "not_entailment,entailment", pair=True),
    "BoolQ": _task_entry(
        "idx:str:1,sent1:str:1,sent2:str:1,label:str:1", 2, "True,False",
        pair=True),
    "WiC": _task_entry(
        "idx:str:1,sent1:str:1,sent2:str:1,label:str:1", 2, "True,False",
        pair=True),
    # WSC-style tasks set no label_enumerate_values (kept as in the original).
    "WSC": _task_entry("idx:str:1,sent1:str:1", 2),
    "CLUEWSC": _task_entry("idx:str:1,sent1:str:1", 2),
    "COPA": _task_entry(
        "idx:str:1,sent1:str:1,sent2:str:1,label:str:1", 2, "0,1",
        pair=True),
    "CB": _task_entry(
        "idx:str:1,sent1:str:1,sent2:str:1,label:str:1", 3,
        "neutral,entailment,contradiction", pair=True),
}


def _apply_cli_overrides(config_json, argv):
    """Merge ``--key=value`` arguments from *argv* into *config_json* in place.

    Only a fixed set of keys is recognised; anything else (including arguments
    without ``=``) is ignored.  Values are stored as the raw strings from the
    command line, except ``num_gpus`` which is converted to ``int``
    (presumably ``Config`` converts the rest -- verify).  ``task_name`` must be
    one of ``_PREDICT_TASK_CONFIGS``; an unknown name leaves the preprocess
    and model settings at their defaults.
    """
    for arg in argv:
        # partition() keeps everything after the first '=' so values that
        # themselves contain '=' are not truncated.
        raw_key, sep, val = arg.partition("=")
        if not sep:
            continue
        key = raw_key.replace("--", "")
        if key in ("train_input_fp", "train_batch_size", "model_dir",
                   "num_epochs"):
            config_json["train_config"][key] = val
        elif key == "eval_input_fp":
            config_json["evaluate_config"][key] = val
        elif key in ("learning_rate", "warmup_ratio", "weight_decay_ratio"):
            config_json["train_config"]["optimizer_config"][key] = val
        elif key == "pretrain_model_name_or_path":
            config_json["model_config"][key] = val
        elif key in ("predict_input_fp", "predict_checkpoint_path"):
            config_json["predict_config"][key] = val
        elif key == "num_gpus":
            config_json[key] = int(val)
        elif key == "task_name" and val in _PREDICT_TASK_CONFIGS:
            overrides, num_labels = _PREDICT_TASK_CONFIGS[val]
            config_json["preprocess_config"].update(overrides)
            config_json["model_config"]["num_labels"] = num_labels


def predict_on_the_fly(output_fp="wsc_predict.json"):
    """Run single-example prediction configured from ``sys.argv``.

    Builds a ``predict_on_the_fly`` config from built-in defaults plus
    ``--key=value`` command-line overrides, then reads ``predict_input_fp``
    with a ``CSVReader`` (batch size 1) and predicts with
    ``predict_checkpoint_path``.

    Each prediction is written to *output_fp* as one JSON object per line:
    ``{"id": <running index from 0>, "label": "true"|"false"}``, where a
    prediction of ``0`` maps to ``"true"`` and anything else to ``"false"``.
    NOTE(review): this label mapping (and the default file name) looks
    WSC-specific; confirm before using it for other tasks.

    Args:
        output_fp: Path of the JSON-lines file to write.  Defaults to
            ``"wsc_predict.json"`` (the previously hard-coded name).
    """
    import json

    config_json = {
        "worker_hosts": "localhost",
        "task_index": 1,
        "job_name": "chief",
        "num_gpus": 1,
        "num_workers": 1,
        "preprocess_config": {
            "input_schema": None,
            "output_schema": None,
            "sequence_length": 128,
            "first_sequence": None,
            "second_sequence": None,
            "label_enumerate_values": None,
        },
        "model_config": {
            "pretrain_model_name_or_path": None,
            "num_labels": None
        },
        "train_config": {
            "keep_checkpoint_max": 11,
            "save_steps": None,
            "optimizer_config": {
                "optimizer": "adam",
                "weight_decay_ratio": 0.01,
                "warmup_ratio": 0.1,
            },
            "distribution_config": {
                "distribution_strategy": None,
            }
        },
        "evaluate_config": {
            "eval_batch_size": 8
        },
        "predict_config": {
            "predict_checkpoint_path": None,
            "predict_input_fp": None,
            "predict_output_fp": None,
            "predict_batch_size": 1
        }
    }
    _apply_cli_overrides(config_json, sys.argv[1:])

    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = Application(user_defined_config=config)

    pred_reader = CSVReader(input_glob=app.predict_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=1)

    with open(output_fp, "w") as f:
        results = app.run_predict(reader=pred_reader,
                                  checkpoint_path=app.predict_checkpoint_path,
                                  yield_single_examples=True)
        for example_id, result in enumerate(results):
            label = "true" if result['predictions'] == 0 else "false"
            # json.dumps emits the same text the original concatenation did,
            # but with proper escaping.
            f.write(json.dumps({"id": example_id, "label": label}) + "\n")