def test_dist_preprocess(self):
    """Exercise the multi-process preprocess pipeline end to end.

    Wires a three-stage pipeline through a ProcessExecutor:
    CSV read -> BERT feature preprocessing -> CSV write, then runs it
    and waits for completion before closing the writer.
    """
    app = Serialization()
    executor = ProcessExecutor(1)  # queue size of 1 between stages

    # Stage 1: read input rows and push them onto the executor's queue.
    source = CSVReader(input_glob=app.preprocess_input_fp,
                       input_schema=app.input_schema,
                       is_training=False,
                       batch_size=app.preprocess_batch_size,
                       output_queue=executor.get_output_queue())
    executor.add(source)

    # Stage 2: tokenize/convert rows into BERT features.
    preprocessor = preprocessors.get_preprocessor(
        'google-bert-base-zh',
        thread_num=7,
        input_queue=executor.get_input_queue(),
        output_queue=executor.get_output_queue())
    executor.add(preprocessor)

    # Stage 3: drain the queue and serialize features to CSV.
    sink = CSVWriter(output_glob=app.preprocess_output_fp,
                     output_schema=app.output_schema,
                     input_queue=executor.get_input_queue())
    executor.add(sink)

    executor.run()
    executor.wait()
    sink.close()
def test_preprocess_for_finetune(self):
    """Run CSV preprocessing for finetuning and verify serialized output.

    Reads app.preprocess_input_fp, writes the serialized features to
    app.preprocess_output_fp, then checks the first three output lines
    against known-good serializations (tab-separated fields:
    input_ids, input_mask, segment_ids, label).
    """
    app = SerializationFinetuneModel()
    reader = CSVReader(input_glob=app.preprocess_input_fp,
                       is_training=False,
                       input_schema=app.input_schema,
                       batch_size=app.preprocess_batch_size)
    writer = CSVWriter(output_glob=app.preprocess_output_fp,
                       output_schema=app.output_schema)
    app.run_preprocess(reader=reader, writer=writer)

    output_path = 'output/preprocess_output_for_finetune.csv'
    self.assertTrue(os.path.exists(output_path))
    # Use a context manager so the file handle is closed deterministically
    # (the original leaked it via a bare open(...).readlines()).
    with open(output_path, 'r') as f:
        lines = f.readlines()

    expected = [
        "101,1352,1282,671,5709,1446,2990,7583,102,7027,1377,809,2990,5709,1446,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0",
        "101,5709,1446,3118,2898,7770,7188,4873,102,711,784,720,1351,802,2140,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0",
        "101,2769,4638,6010,6009,5709,1446,3118,102,2769,1168,3118,802,2140,2141,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t1",
    ]
    for line, gt in zip(lines[:3], expected):
        self.assertTrue(line.strip() == gt)
def test_dist_preprocess(self):
    """Serialize pretrain data: read the input CSV, preprocess, write CSV."""
    app = PretrainSerialization()
    source = CSVReader(input_glob=app.preprocess_input_fp,
                       is_training=False,
                       input_schema=app.input_schema,
                       batch_size=app.preprocess_batch_size)
    sink = CSVWriter(output_glob=app.preprocess_output_fp,
                     output_schema=app.output_schema)
    app.run_preprocess(reader=source, writer=sink)
def test_predict(self):
    """Run prediction from a checkpoint, streaming CSV in and CSV out."""
    app = Application()
    source = CSVReader(input_glob=app.predict_input_fp,
                       is_training=False,
                       input_schema=app.input_schema,
                       batch_size=app.predict_batch_size)
    sink = CSVWriter(output_glob=app.predict_output_fp,
                     output_schema=app.output_schema)
    app.run_predict(reader=source,
                    writer=sink,
                    checkpoint_path=app.predict_checkpoint_path)
def main(_):
    """Dispatch on FLAGS.mode: train-and-evaluate or predict.

    On PAI (detected via the TF version string) data flows through ODPS
    table readers/writers; otherwise plain CSV readers/writers are used.
    """
    app = Application()
    # "PAI" in the TF version string indicates the Alibaba PAI platform.
    on_pai = "PAI" in tf.__version__

    if FLAGS.mode == "train_and_evaluate_on_the_fly":
        reader_cls = OdpsTableReader if on_pai else CSVReader
        train_reader = reader_cls(input_glob=app.train_input_fp,
                                  is_training=True,
                                  input_schema=app.input_schema,
                                  batch_size=app.train_batch_size)
        eval_reader = reader_cls(input_glob=app.eval_input_fp,
                                 is_training=False,
                                 input_schema=app.input_schema,
                                 batch_size=app.eval_batch_size)
        app.run_train_and_evaluate(train_reader=train_reader,
                                   eval_reader=eval_reader)

    if FLAGS.mode == "predict_on_the_fly":
        if on_pai:
            pred_reader = OdpsTableReader(input_glob=app.predict_input_fp,
                                          input_schema=app.input_schema,
                                          is_training=False,
                                          batch_size=app.predict_batch_size)
            pred_writer = OdpsTableWriter(output_glob=app.predict_output_fp,
                                          output_schema=app.output_schema,
                                          slice_id=0,
                                          input_queue=None)
        else:
            pred_reader = CSVReader(input_glob=app.predict_input_fp,
                                    is_training=False,
                                    input_schema=app.input_schema,
                                    batch_size=app.predict_batch_size)
            pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                                    output_schema=app.output_schema)
        app.run_predict(reader=pred_reader,
                        writer=pred_writer,
                        checkpoint_path=app.predict_checkpoint_path)
def do_predict():
    """Run text-classification prediction and return a preview of results.

    Builds a predict-mode Config, runs prediction from the configured
    checkpoint writing to app.predict_output_fp, then loads
    './data/predict.csv' with pandas and returns its first 10 rows
    (columns: true_label, pred_label_id).

    Returns:
        pandas.DataFrame: the first 10 prediction rows.
    """
    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)
    pred_reader = CSVReader(input_glob=app.predict_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.predict_batch_size)
    pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                            output_schema=app.output_schema)
    app.run_predict(reader=pred_reader,
                    writer=pred_writer,
                    checkpoint_path=app.predict_checkpoint_path)

    pred = pd.read_csv('./data/predict.csv',
                       header=None,
                       delimiter='\t',
                       encoding='utf8')
    pred.columns = ['true_label', 'pred_label_id']
    # The original bare `pred.head(10)` discarded its result (a no-op
    # outside a notebook); return it so callers can inspect predictions.
    return pred.head(10)