def main(_):
    domain_str = aFLAGS.domains
    global num_domains, layer_indexes, domain_weight, do_sent_pair
    num_domains = len(domain_str.split(","))
    layer_indexes_str = aFLAGS.layer_indexes
    layer_indexes = [int(x) for x in layer_indexes_str.split(",")]
    domain_weight = aFLAGS.domain_weight
    do_sent_pair = aFLAGS.do_sent_pair

    app = Application()
    # On PAI, read from ODPS tables; otherwise fall back to local CSV files.
    if "PAI" in tf.__version__:
        train_reader = OdpsTableReader(input_glob=app.train_input_fp,
                                       is_training=True,
                                       input_schema=app.input_schema,
                                       batch_size=app.train_batch_size)
        eval_reader = OdpsTableReader(input_glob=app.eval_input_fp,
                                      is_training=False,
                                      input_schema=app.input_schema,
                                      batch_size=app.eval_batch_size)
    else:
        train_reader = CSVReader(input_glob=app.train_input_fp,
                                 is_training=True,
                                 input_schema=app.input_schema,
                                 batch_size=app.train_batch_size)
        eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
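# main(_) above reads several flags (domains, layer_indexes, domain_weight,
# do_sent_pair) from aFLAGS without showing their definitions. Below is a
# minimal sketch of what those definitions might look like, assuming TF1-style
# tf.app.flags; the default values are illustrative, not from the original
# source.
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("domains", "news,finance", "Comma-separated domain names.")
flags.DEFINE_string("layer_indexes", "11", "Comma-separated encoder layer indexes.")
flags.DEFINE_float("domain_weight", 0.5, "Weight of the domain loss term.")
flags.DEFINE_boolean("do_sent_pair", False, "Whether inputs are sentence pairs.")
aFLAGS = flags.FLAGS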
def test_train_and_eval(self):
    app = Application()
    train_reader = CSVReader(input_glob=app.train_input_fp,
                             is_training=True,
                             input_schema=app.input_schema,
                             batch_size=app.train_batch_size)
    eval_reader = CSVReader(input_glob=app.eval_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
def test_dist_preprocess(self):
    app = Serialization()
    queue_size = 1
    proc_executor = ProcessExecutor(queue_size)
    # Pipeline: reader -> preprocessor -> writer, connected via the
    # executor's queues.
    reader = CSVReader(input_glob=app.preprocess_input_fp,
                       input_schema=app.input_schema,
                       is_training=False,
                       batch_size=app.preprocess_batch_size,
                       output_queue=proc_executor.get_output_queue())
    proc_executor.add(reader)
    feature_process = preprocessors.get_preprocessor(
        'google-bert-base-zh',
        thread_num=7,
        input_queue=proc_executor.get_input_queue(),
        output_queue=proc_executor.get_output_queue())
    proc_executor.add(feature_process)
    writer = CSVWriter(output_glob=app.preprocess_output_fp,
                       output_schema=app.output_schema,
                       input_queue=proc_executor.get_input_queue())
    proc_executor.add(writer)
    proc_executor.run()
    proc_executor.wait()
    writer.close()
def do_train():
    config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)
    train_reader = CSVReader(input_glob=app.train_input_fp,
                             is_training=True,
                             input_schema=app.input_schema,
                             batch_size=app.train_batch_size)
    eval_reader = CSVReader(input_glob=app.eval_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
def test_preprocess_for_finetune(self):
    app = SerializationFinetuneModel()
    reader = CSVReader(input_glob=app.preprocess_input_fp,
                       is_training=False,
                       input_schema=app.input_schema,
                       batch_size=app.preprocess_batch_size)
    writer = CSVWriter(output_glob=app.preprocess_output_fp,
                       output_schema=app.output_schema)
    app.run_preprocess(reader=reader, writer=writer)

    self.assertTrue(
        os.path.exists('output/preprocess_output_for_finetune.csv'))
    with open('output/preprocess_output_for_finetune.csv', 'r') as f:
        lines = f.readlines()
    pred = lines[0].strip()
    gt = "101,1352,1282,671,5709,1446,2990,7583,102,7027,1377,809,2990,5709,1446,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0"
    self.assertEqual(pred, gt)
    pred = lines[1].strip()
    gt = "101,5709,1446,3118,2898,7770,7188,4873,102,711,784,720,1351,802,2140,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0"
    self.assertEqual(pred, gt)
    pred = lines[2].strip()
    gt = "101,2769,4638,6010,6009,5709,1446,3118,102,2769,1168,3118,802,2140,2141,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t1"
    self.assertEqual(pred, gt)
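# Each expected row above is four tab-separated fields. The short sketch
# below splits one of those rows to make the layout explicit; the field
# names follow common BERT preprocessing conventions and are an assumption,
# not taken from the original source.
row = ("101,2769,4638,6010,6009,5709,1446,3118,102,2769,1168,3118,802,2140,2141,102"
       "\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t1")
input_ids, input_mask, segment_ids, label_id = row.split("\t")
assert input_ids.split(",")[0] == "101"  # 101 is BERT's [CLS] token id
assert label_id == "1"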
def test_evaluate(self):
    app = Application()
    # Evaluation data should not be shuffled or repeated, so
    # is_training=False here.
    eval_reader = CSVReader(input_glob=app.eval_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.eval_batch_size)
    app.run_evaluate(eval_reader, checkpoint_path=app.eval_ckpt_path)
def main(_):
    app = Application()
    if FLAGS.mode == "train_and_evaluate_on_the_fly":
        # On PAI, read from ODPS tables; otherwise use local CSV files.
        if "PAI" in tf.__version__:
            train_reader = OdpsTableReader(input_glob=app.train_input_fp,
                                           is_training=True,
                                           input_schema=app.input_schema,
                                           batch_size=app.train_batch_size)
            eval_reader = OdpsTableReader(input_glob=app.eval_input_fp,
                                          is_training=False,
                                          input_schema=app.input_schema,
                                          batch_size=app.eval_batch_size)
        else:
            train_reader = CSVReader(input_glob=app.train_input_fp,
                                     is_training=True,
                                     input_schema=app.input_schema,
                                     batch_size=app.train_batch_size)
            eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                    is_training=False,
                                    input_schema=app.input_schema,
                                    batch_size=app.eval_batch_size)
        app.run_train_and_evaluate(train_reader=train_reader,
                                   eval_reader=eval_reader)
    elif FLAGS.mode == "predict_on_the_fly":
        if "PAI" in tf.__version__:
            pred_reader = OdpsTableReader(input_glob=app.predict_input_fp,
                                          input_schema=app.input_schema,
                                          is_training=False,
                                          batch_size=app.predict_batch_size)
            pred_writer = OdpsTableWriter(output_glob=app.predict_output_fp,
                                          output_schema=app.output_schema,
                                          slice_id=0,
                                          input_queue=None)
        else:
            pred_reader = CSVReader(input_glob=app.predict_input_fp,
                                    is_training=False,
                                    input_schema=app.input_schema,
                                    batch_size=app.predict_batch_size)
            pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                                    output_schema=app.output_schema)
        app.run_predict(reader=pred_reader,
                        writer=pred_writer,
                        checkpoint_path=app.predict_checkpoint_path)
def test_dist_preprocess(self):
    app = PretrainSerialization()
    reader = CSVReader(input_glob=app.preprocess_input_fp,
                       is_training=False,
                       input_schema=app.input_schema,
                       batch_size=app.preprocess_batch_size)
    writer = CSVWriter(output_glob=app.preprocess_output_fp,
                       output_schema=app.output_schema)
    app.run_preprocess(reader=reader, writer=writer)
def main(_):
    app = FinetuneSerialization()
    reader = CSVReader(input_glob=app.preprocess_input_fp,
                       is_training=False,
                       input_schema=app.input_schema,
                       batch_size=app.preprocess_batch_size)
    writer = TFRecordWriter(output_glob=app.preprocess_output_fp,
                            output_schema=app.output_schema)
    app.run_preprocess(reader=reader, writer=writer)
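# After serializing features to TFRecord as above, a quick sanity check can
# count the written examples with TF1's record iterator. This is an optional
# sketch: the file name is hypothetical, and it assumes the output path is a
# single local TFRecord file rather than a glob.
import tensorflow as tf

num_examples = sum(
    1 for _ in tf.python_io.tf_record_iterator("preprocess_output.tfrecord"))
print("wrote {} examples".format(num_examples))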
def test_predict(self):
    app = Application()
    predict_reader = CSVReader(input_glob=app.predict_input_fp,
                               is_training=False,
                               input_schema=app.input_schema,
                               batch_size=app.predict_batch_size)
    predict_writer = CSVWriter(output_glob=app.predict_output_fp,
                               output_schema=app.output_schema)
    app.run_predict(reader=predict_reader,
                    writer=predict_writer,
                    checkpoint_path=app.predict_checkpoint_path)
def main(_):
    app = Application()
    if FLAGS.usePAI:
        train_reader = OdpsTableReader(input_glob=app.train_input_fp,
                                       is_training=True,
                                       input_schema=app.input_schema,
                                       batch_size=app.train_batch_size)
        eval_reader = OdpsTableReader(input_glob=app.eval_input_fp,
                                      is_training=False,
                                      input_schema=app.input_schema,
                                      batch_size=app.eval_batch_size)
    else:
        train_reader = CSVReader(input_glob=app.train_input_fp,
                                 is_training=True,
                                 input_schema=app.input_schema,
                                 batch_size=app.train_batch_size)
        eval_reader = CSVReader(input_glob=app.eval_input_fp,
                                is_training=False,
                                input_schema=app.input_schema,
                                batch_size=app.eval_batch_size)
    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)
def do_predict():
    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)
    pred_reader = CSVReader(input_glob=app.predict_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.predict_batch_size)
    pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                            output_schema=app.output_schema)
    app.run_predict(reader=pred_reader,
                    writer=pred_writer,
                    checkpoint_path=app.predict_checkpoint_path)

    # Inspect the first few predictions.
    pred = pd.read_csv('./data/predict.csv', header=None,
                       delimiter='\t', encoding='utf8')
    pred.columns = ['true_label', 'pred_label_id']
    print(pred.head(10))
def main(_):
    # Load domain and class lists.
    domains = aFLAGS.domains.split(",")
    classes = aFLAGS.classes.split(",")
    if aFLAGS.do_sent_pair:
        app = SentPairClsApplication()
    else:
        app = SentClsApplication()

    # Create an empty embedding bucket for every (domain, class) pair.
    domain_class_embeddings = dict()
    for domain_name in domains:
        for class_name in classes:
            key_name = domain_name + "\t" + class_name
            domain_class_embeddings[key_name] = list()

    # Reader over the training data.
    if FLAGS.usePAI:
        predict_reader = OdpsTableReader(input_glob=app.predict_input_fp,
                                         is_training=False,
                                         input_schema=app.input_schema,
                                         batch_size=app.predict_batch_size)
    else:
        predict_reader = CSVReader(input_glob=app.predict_input_fp,
                                   is_training=False,
                                   input_schema=app.input_schema,
                                   batch_size=app.predict_batch_size)

    # Run inference over the training data, collecting pooled embeddings.
    temp_output_data = list()
    for output in app.run_predict(reader=predict_reader,
                                  checkpoint_path=app.predict_checkpoint_path):
        current_size = len(output["pool_output"])
        for i in range(current_size):
            pool_output = output["pool_output"][i]
            domain = output["domain"][i]
            label = output["label"][i]
            key_name = domain.decode('utf-8') + "\t" + label.decode('utf-8')
            domain_class_embeddings[key_name].append(pool_output)
            if aFLAGS.do_sent_pair:
                text1 = output["text1"][i]
                text2 = output["text2"][i]
                temp_output_data.append(
                    (text1, text2, domain, label, pool_output))
            else:
                text = output["text"][i]
                temp_output_data.append((text, domain, label, pool_output))

    # Compute the centroid embedding of each (domain, class) bucket.
    centroid_embeddings = dict()
    for key_name in domain_class_embeddings:
        domain_class_data_embeddings = np.array(
            domain_class_embeddings[key_name])
        centroid_embeddings[key_name] = np.mean(domain_class_data_embeddings,
                                                axis=0)

    # Output the weighted examples for meta fine-tuning.
    if FLAGS.usePAI:
        # Write ODPS tables.
        records = []
        if aFLAGS.do_sent_pair:
            for text1, text2, domain, label, embeddings in temp_output_data:
                weight = compute_weight(domain, label, embeddings,
                                        centroid_embeddings)
                records.append((text1, text2, str(domains.index(domain)),
                                label, np.around(weight, decimals=5)))
        else:
            for text, domain, label, embeddings in temp_output_data:
                weight = compute_weight(domain, label, embeddings,
                                        centroid_embeddings)
                records.append((text, str(domains.index(domain)), label,
                                np.around(weight, decimals=5)))
        with tf.python_io.TableWriter(FLAGS.outputs) as writer:
            indices = list(range(5)) if aFLAGS.do_sent_pair else list(range(4))
            writer.write(records, indices)
    else:
        # Write to a local file.
        with open(FLAGS.outputs, 'w+') as f:
            if aFLAGS.do_sent_pair:
                for text1, text2, domain, label, embeddings in temp_output_data:
                    weight = compute_weight(domain, label, embeddings,
                                            centroid_embeddings)
                    f.write(text1 + '\t' + text2 + '\t' +
                            str(domains.index(domain)) + '\t' + label + '\t' +
                            np.around(weight, decimals=5).astype('str') + '\n')
            else:
                for text, domain, label, embeddings in temp_output_data:
                    weight = compute_weight(domain, label, embeddings,
                                            centroid_embeddings)
                    f.write(text + '\t' + str(domains.index(domain)) + '\t' +
                            label + '\t' +
                            np.around(weight, decimals=5).astype('str') + '\n')
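# main(_) above calls compute_weight, which is not defined in this section.
# Below is a minimal sketch of one plausible implementation, assuming the
# weight is the mean cosine similarity between an example's pooled embedding
# and the same-class centroids of the *other* domains. The actual formula in
# the source may differ (it may, for instance, involve the domain_weight flag
# or a distance-based score).
import numpy as np

def compute_weight(domain, label, embeddings, centroid_embeddings):
    """Hypothetical example weight based on cross-domain class centroids."""
    domain_str = domain.decode('utf-8')
    label_str = label.decode('utf-8')
    sims = []
    for key_name, centroid in centroid_embeddings.items():
        centroid_domain, centroid_class = key_name.split("\t")
        # Only compare against the same class in other domains.
        if centroid_class != label_str or centroid_domain == domain_str:
            continue
        sims.append(np.dot(embeddings, centroid) /
                    (np.linalg.norm(embeddings) * np.linalg.norm(centroid)
                     + 1e-8))
    return float(np.mean(sims)) if sims else 0.0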
# # (6) Start training

# In[13]:

config = Config(mode="train_and_evaluate_on_the_fly", config_json=config_json)

# In[14]:

app = Application(user_defined_config=config)

# In[15]:

train_reader = CSVReader(input_glob=app.train_input_fp,
                         is_training=True,
                         input_schema=app.input_schema,
                         batch_size=app.train_batch_size)
eval_reader = CSVReader(input_glob=app.eval_input_fp,
                        is_training=False,
                        input_schema=app.input_schema,
                        batch_size=app.eval_batch_size)

# In[16]:

app.run_train(reader=train_reader)

# In[17]:

# Collect the global steps of all saved checkpoints from the "checkpoint"
# index file in the model directory.
ckpts = set()
with tf.gfile.GFile(os.path.join(app.config.model_dir, "checkpoint"),
                    mode='r') as reader:
    for line in reader:
        line = line.strip()
        line = line.replace("oss://", "")
        ckpts.add(
            int(line.split(":")[1].strip().replace("\"", "").split("/")
                [-1].replace("model.ckpt-", "")))
def train_and_evaluate_on_the_fly():
    config_json = {
        "worker_hosts": "localhost",
        "task_index": 1,
        "job_name": "chief",
        "num_gpus": 1,
        "num_workers": 1,
        "preprocess_config": {
            "input_schema": None,
            "sequence_length": 128,
            "first_sequence": None,
            "second_sequence": None,
            "label_name": "label",
            "label_enumerate_values": None,
        },
        "model_config": {
            "pretrain_model_name_or_path": None,
            "num_labels": None
        },
        "train_config": {
            "keep_checkpoint_max": 11,
            "save_steps": None,
            "optimizer_config": {
                "optimizer": "adam",
                "weight_decay_ratio": 0.01,
                "warmup_ratio": 0.1,
            },
            "distribution_config": {
                "distribution_strategy": None,
            }
        },
        "evaluate_config": {
            "eval_batch_size": 8
        }
    }

    # Per-task settings, keyed by --task_name:
    # (input_schema, first_sequence, second_sequence,
    #  label_enumerate_values, num_labels)
    task_settings = {
        "TNEWS": ("label:str:1,sent1:str:1", "sent1", None,
                  "115,114,108,109,116,110,113,112,102,103,100,101,106,107,104",
                  15),
        "AFQMC": ("label:str:1,sent1:str:1,sent2:str:1", "sent1", "sent2",
                  "0,1", 2),
        "IFLYTEK": ("label:str:1,sent1:str:1", "sent1", None,
                    ",".join(str(idx) for idx in range(119)), 119),
        "CMNLI": ("label:str:1,sent1:str:1,sent2:str:1", "sent1", "sent2",
                  "entailment,neutral,contradiction", 3),
        "CSL": ("label:str:1,sent1:str:1,sent2:str:1", "sent1", "sent2",
                "0,1", 2),
        "QQP": ("idx:str:1,xx1:str:1,xx2:str:1,sent1:str:1,sent2:str:1,label:str:1",
                "sent1", "sent2", "0,1", 2),
        "SST-2": ("sent1:str:1,label:str:1", "sent1", None, "0,1", 2),
        "CoLA": ("idx:str:1,label:str:1,xx:str:1,sent1:str:1", "sent1", None,
                 "0,1", 2),
        "MRPC": ("label:str:1,xx:str:1,xx2:str:1,sent1:str:1,sent2:str:1",
                 "sent1", "sent2", "0,1", 2),
        "RTE": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                "sent2", "not_entailment,entailment", 2),
        "BoolQ": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                  "sent2", "True,False", 2),
        "WiC": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                "sent2", "True,False", 2),
        "WSC": ("idx:str:1,sent1:str:1,label:str:1", "sent1", None,
                "True,False", 2),
        "CLUEWSC": ("idx:str:1,sent1:str:1,label:str:1", "sent1", None,
                    "True,False", 2),
        "COPA": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                 "sent2", "0,1", 2),
        "CB": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
               "sent2", "neutral,entailment,contradiction", 3),
    }

    # Parse --key=value arguments into the nested config.
    task_name = None
    for arg in sys.argv[1:]:
        key, val = arg.split("=", 1)
        key = key.replace("--", "")
        if key in ('train_input_fp', 'train_batch_size', 'model_dir',
                   'num_epochs'):
            config_json['train_config'][key] = val
        elif key == 'eval_input_fp':
            config_json['evaluate_config'][key] = val
        elif key in ('learning_rate', 'warmup_ratio', 'weight_decay_ratio'):
            config_json['train_config']['optimizer_config'][key] = val
        elif key == 'pretrain_model_name_or_path':
            config_json['model_config'][key] = val
        elif key == 'num_gpus':
            config_json[key] = int(val)
        elif key == 'task_name':
            task_name = val
            schema, first, second, labels, num_labels = task_settings[val]
            config_json['preprocess_config']['input_schema'] = schema
            config_json['preprocess_config']['first_sequence'] = first
            if second is not None:
                config_json['preprocess_config']['second_sequence'] = second
            config_json['preprocess_config']['label_enumerate_values'] = labels
            config_json['model_config']['num_labels'] = num_labels

    config = Config(mode="train_and_evaluate_on_the_fly",
                    config_json=config_json)
    app = Application(user_defined_config=config)

    train_reader = CSVReader(input_glob=app.train_input_fp,
                             is_training=True,
                             input_schema=app.input_schema,
                             batch_size=app.train_batch_size)
    app.run_train(reader=train_reader)

    eval_reader = CSVReader(input_glob=app.eval_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.eval_batch_size)

    # Collect the global steps of all checkpoints saved during training.
    ckpts = set()
    with tf.gfile.GFile(os.path.join(app.config.model_dir, "checkpoint"),
                        mode='r') as reader:
        for line in reader:
            line = line.strip()
            line = line.replace("oss://", "")
            ckpts.add(
                int(line.split(":")[1].strip().replace("\"", "").split("/")
                    [-1].replace("model.ckpt-", "")))

    # Select the best checkpoint on the dev set: accuracy for most tasks,
    # Matthews correlation for CoLA.
    if task_name != "CoLA":
        best_acc = 0
        best_ckpt = None
        for ckpt in sorted(ckpts):
            checkpoint_path = os.path.join(app.config.model_dir,
                                           "model.ckpt-" + str(ckpt))
            tf.logging.info("checkpoint_path is {}".format(checkpoint_path))
            eval_results = app.run_evaluate(reader=eval_reader,
                                            checkpoint_path=checkpoint_path)
            acc = eval_results['py_accuracy']
            if acc > best_acc:
                best_ckpt = ckpt
                best_acc = acc
        tf.logging.info("best ckpt {}, best acc {}".format(
            best_ckpt, best_acc))
    else:
        best_matthew_corr = 0
        best_ckpt = None
        for ckpt in sorted(ckpts):
            checkpoint_path = os.path.join(app.config.model_dir,
                                           "model.ckpt-" + str(ckpt))
            tf.logging.info("checkpoint_path is {}".format(checkpoint_path))
            eval_results = app.run_evaluate(reader=eval_reader,
                                            checkpoint_path=checkpoint_path)
            matthew_corr = eval_results['matthew_corr']
            if matthew_corr > best_matthew_corr:
                best_ckpt = ckpt
                best_matthew_corr = matthew_corr
        tf.logging.info("best ckpt {}, best matthew_corr {}".format(
            best_ckpt, best_matthew_corr))
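# train_and_evaluate_on_the_fly() parses --key=value pairs from sys.argv, so
# it can be driven from a shell or programmatically. A hypothetical
# invocation follows: the keys match the ones handled by the parser above,
# but the file paths, model name, and hyperparameter values are illustrative.
import sys

sys.argv = [
    "main.py",
    "--task_name=AFQMC",
    "--pretrain_model_name_or_path=google-bert-base-zh",
    "--train_input_fp=train.csv",
    "--eval_input_fp=dev.csv",
    "--train_batch_size=32",
    "--num_epochs=3",
    "--learning_rate=3e-5",
    "--model_dir=./afqmc_model",
]
train_and_evaluate_on_the_fly()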
def predict_on_the_fly():
    config_json = {
        "worker_hosts": "localhost",
        "task_index": 1,
        "job_name": "chief",
        "num_gpus": 1,
        "num_workers": 1,
        "preprocess_config": {
            "input_schema": None,
            "output_schema": None,
            "sequence_length": 128,
            "first_sequence": None,
            "second_sequence": None,
            "label_enumerate_values": None,
        },
        "model_config": {
            "pretrain_model_name_or_path": None,
            "num_labels": None
        },
        "train_config": {
            "keep_checkpoint_max": 11,
            "save_steps": None,
            "optimizer_config": {
                "optimizer": "adam",
                "weight_decay_ratio": 0.01,
                "warmup_ratio": 0.1,
            },
            "distribution_config": {
                "distribution_strategy": None,
            }
        },
        "evaluate_config": {
            "eval_batch_size": 8
        },
        "predict_config": {
            "predict_checkpoint_path": None,
            "predict_input_fp": None,
            "predict_output_fp": None,
            "predict_batch_size": 1
        }
    }

    # Per-task settings, keyed by --task_name:
    # (input_schema, first_sequence, second_sequence,
    #  label_enumerate_values, num_labels)
    # Note: for prediction, the WSC/CLUEWSC input has no label column and no
    # label_enumerate_values.
    task_settings = {
        "TNEWS": ("label:str:1,sent1:str:1", "sent1", None,
                  "115,114,108,109,116,110,113,112,102,103,100,101,106,107,104",
                  15),
        "AFQMC": ("label:str:1,sent1:str:1,sent2:str:1", "sent1", "sent2",
                  "0,1", 2),
        "IFLYTEK": ("label:str:1,sent1:str:1", "sent1", None,
                    ",".join(str(idx) for idx in range(119)), 119),
        "CMNLI": ("label:str:1,sent1:str:1,sent2:str:1", "sent1", "sent2",
                  "entailment,neutral,contradiction", 3),
        "CSL": ("label:str:1,sent1:str:1,sent2:str:1", "sent1", "sent2",
                "0,1", 2),
        "QQP": ("idx:str:1,xx1:str:1,xx2:str:1,sent1:str:1,sent2:str:1,label:str:1",
                "sent1", "sent2", "0,1", 2),
        "SST-2": ("sent1:str:1,label:str:1", "sent1", None, "0,1", 2),
        "CoLA": ("idx:str:1,label:str:1,xx:str:1,sent1:str:1", "sent1", None,
                 "0,1", 2),
        "MRPC": ("label:str:1,xx:str:1,xx2:str:1,sent1:str:1,sent2:str:1",
                 "sent1", "sent2", "0,1", 2),
        "RTE": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                "sent2", "not_entailment,entailment", 2),
        "BoolQ": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                  "sent2", "True,False", 2),
        "WiC": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                "sent2", "True,False", 2),
        "WSC": ("idx:str:1,sent1:str:1", "sent1", None, None, 2),
        "CLUEWSC": ("idx:str:1,sent1:str:1", "sent1", None, None, 2),
        "COPA": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
                 "sent2", "0,1", 2),
        "CB": ("idx:str:1,sent1:str:1,sent2:str:1,label:str:1", "sent1",
               "sent2", "neutral,entailment,contradiction", 3),
    }

    # Parse --key=value arguments into the nested config.
    for arg in sys.argv[1:]:
        key, val = arg.split("=", 1)
        key = key.replace("--", "")
        if key in ('train_input_fp', 'train_batch_size', 'model_dir',
                   'num_epochs'):
            config_json['train_config'][key] = val
        elif key == 'eval_input_fp':
            config_json['evaluate_config'][key] = val
        elif key in ('learning_rate', 'warmup_ratio', 'weight_decay_ratio'):
            config_json['train_config']['optimizer_config'][key] = val
        elif key == 'pretrain_model_name_or_path':
            config_json['model_config'][key] = val
        elif key in ('predict_input_fp', 'predict_checkpoint_path'):
            config_json['predict_config'][key] = val
        elif key == 'num_gpus':
            config_json[key] = int(val)
        elif key == 'task_name':
            schema, first, second, labels, num_labels = task_settings[val]
            config_json['preprocess_config']['input_schema'] = schema
            config_json['preprocess_config']['first_sequence'] = first
            if second is not None:
                config_json['preprocess_config']['second_sequence'] = second
            if labels is not None:
                config_json['preprocess_config']['label_enumerate_values'] = labels
            config_json['model_config']['num_labels'] = num_labels

    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = Application(user_defined_config=config)
    pred_reader = CSVReader(input_glob=app.predict_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=1)

    # Write WSC-style predictions, one JSON object per line.
    example_id = 0
    with open("wsc_predict.json", "w") as f:
        for x in app.run_predict(reader=pred_reader,
                                 checkpoint_path=app.predict_checkpoint_path,
                                 yield_single_examples=True):
            label = "true" if x['predictions'] == 0 else "false"
            f.write("{\"id\": " + str(example_id) + ", \"label\": " + "\"" +
                    label + "\"}" + "\n")
            example_id += 1
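# Each line written above is a standalone JSON object such as
# {"id": 0, "label": "true"}. A short check reads the file back, assuming it
# was produced by the loop above.
import json

with open("wsc_predict.json") as f:
    for line in f:
        record = json.loads(line)
        print(record["id"], record["label"])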