コード例 #1
0
    def test_dist_preprocess(self):
        """Run the 'google-bert-base-zh' preprocessing pipeline via a ProcessExecutor.

        Wires three stages onto one ``ProcessExecutor``: a non-training
        ``CSVReader`` over ``app.preprocess_input_fp``, the preprocessor
        returned by ``preprocessors.get_preprocessor('google-bert-base-zh')``,
        and a ``CSVWriter`` to ``app.preprocess_output_fp``. It then runs the
        executor, waits for it, and closes the writer.

        The writer is closed in a ``finally`` block so its output is released
        even if running or waiting on the executor raises.
        """
        app = Serialization()
        queue_size = 1

        proc_executor = ProcessExecutor(queue_size)

        # Stage 1: reader pushes batches onto the executor's output queue.
        reader = CSVReader(input_glob=app.preprocess_input_fp,
                           input_schema=app.input_schema,
                           is_training=False,
                           batch_size=app.preprocess_batch_size,
                           output_queue=proc_executor.get_output_queue())
        proc_executor.add(reader)

        # Stage 2: feature preprocessor. NOTE(review): presumably
        # get_input_queue() hands back the queue the previous stage produces
        # into — confirm against ProcessExecutor.
        feature_process = preprocessors.get_preprocessor(
            'google-bert-base-zh',
            thread_num=7,
            input_queue=proc_executor.get_input_queue(),
            output_queue=proc_executor.get_output_queue())
        proc_executor.add(feature_process)

        # Stage 3: writer consumes the preprocessor's output.
        writer = CSVWriter(output_glob=app.preprocess_output_fp,
                           output_schema=app.output_schema,
                           input_queue=proc_executor.get_input_queue())
        proc_executor.add(writer)

        try:
            proc_executor.run()
            proc_executor.wait()
        finally:
            # Always release the writer, even when the pipeline fails.
            writer.close()
    def test_preprocess_for_finetune(self):
        """Check fine-tuning preprocess output against known golden rows.

        Runs ``SerializationFinetuneModel`` preprocessing from CSV to CSV, then
        compares the first three lines of
        ``output/preprocess_output_for_finetune.csv`` with golden strings.
        Each line has four tab-separated fields (apparently input ids, input
        mask, segment ids and label — inferred from the values, not verified).
        """
        app = SerializationFinetuneModel()

        reader = CSVReader(input_glob=app.preprocess_input_fp,
                           is_training=False,
                           input_schema=app.input_schema,
                           batch_size=app.preprocess_batch_size)

        writer = CSVWriter(output_glob=app.preprocess_output_fp,
                           output_schema=app.output_schema)

        app.run_preprocess(reader=reader, writer=writer)

        output_fp = 'output/preprocess_output_for_finetune.csv'
        self.assertTrue(os.path.exists(output_fp))
        # Context manager closes the handle deterministically instead of
        # leaking it until garbage collection.
        with open(output_fp, 'r') as f:
            lines = f.readlines()

        expected_lines = [
            "101,1352,1282,671,5709,1446,2990,7583,102,7027,1377,809,2990,5709,1446,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0",
            "101,5709,1446,3118,2898,7770,7188,4873,102,711,784,720,1351,802,2140,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t0",
            "101,2769,4638,6010,6009,5709,1446,3118,102,2769,1168,3118,802,2140,2141,102\t1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1\t0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1\t1",
        ]
        # Report a short output as a test failure, not an IndexError.
        self.assertGreaterEqual(len(lines), len(expected_lines))
        for i, (actual, expected) in enumerate(zip(lines, expected_lines)):
            # assertEqual shows both strings on mismatch; assertTrue(a == b)
            # would only report "False is not true".
            self.assertEqual(actual.strip(), expected, msg='line %d' % i)
コード例 #3
0
    def test_dist_preprocess(self):
        """Smoke-test the pretraining serialization preprocess (CSV -> CSV).

        Builds a non-training ``CSVReader`` over the app's preprocess input
        and a ``CSVWriter`` over its preprocess output, then hands both to
        ``run_preprocess``. The output is not inspected; the test passes as
        long as the pipeline runs without raising.
        """
        serializer = PretrainSerialization()

        source = CSVReader(
            input_glob=serializer.preprocess_input_fp,
            is_training=False,
            input_schema=serializer.input_schema,
            batch_size=serializer.preprocess_batch_size,
        )
        sink = CSVWriter(
            output_glob=serializer.preprocess_output_fp,
            output_schema=serializer.output_schema,
        )

        serializer.run_preprocess(reader=source, writer=sink)
    def test_predict(self):
        """Smoke-test prediction from CSV input to CSV output.

        Runs ``Application.run_predict`` against ``predict_checkpoint_path``
        with a non-training ``CSVReader`` and a ``CSVWriter``. Only checks
        that prediction completes without raising.
        """
        predictor = Application()
        source = CSVReader(
            input_glob=predictor.predict_input_fp,
            is_training=False,
            input_schema=predictor.input_schema,
            batch_size=predictor.predict_batch_size,
        )
        sink = CSVWriter(
            output_glob=predictor.predict_output_fp,
            output_schema=predictor.output_schema,
        )
        predictor.run_predict(
            reader=source,
            writer=sink,
            checkpoint_path=predictor.predict_checkpoint_path,
        )
コード例 #5
0
def _build_train_eval_readers(app, use_odps):
    """Return ``(train_reader, eval_reader)`` for *app*'s train/eval inputs.

    Uses ``OdpsTableReader`` when *use_odps* is true, otherwise ``CSVReader``;
    both are constructed with the same keyword arguments.
    """
    reader_cls = OdpsTableReader if use_odps else CSVReader
    train_reader = reader_cls(input_glob=app.train_input_fp,
                              is_training=True,
                              input_schema=app.input_schema,
                              batch_size=app.train_batch_size)
    eval_reader = reader_cls(input_glob=app.eval_input_fp,
                             is_training=False,
                             input_schema=app.input_schema,
                             batch_size=app.eval_batch_size)
    return train_reader, eval_reader


def _build_predict_io(app, use_odps):
    """Return ``(pred_reader, pred_writer)`` for *app*'s prediction I/O.

    With *use_odps*, uses ``OdpsTableReader``/``OdpsTableWriter``; otherwise
    ``CSVReader``/``CSVWriter``.
    """
    reader_cls = OdpsTableReader if use_odps else CSVReader
    pred_reader = reader_cls(input_glob=app.predict_input_fp,
                             is_training=False,
                             input_schema=app.input_schema,
                             batch_size=app.predict_batch_size)
    if use_odps:
        pred_writer = OdpsTableWriter(output_glob=app.predict_output_fp,
                                      output_schema=app.output_schema,
                                      slice_id=0,
                                      input_queue=None)
    else:
        pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                                output_schema=app.output_schema)
    return pred_reader, pred_writer


def main(_):
    """Train/evaluate or predict depending on ``FLAGS.mode``.

    Supported modes are ``"train_and_evaluate_on_the_fly"`` and
    ``"predict_on_the_fly"``. The ODPS table reader/writer classes are used
    when ``"PAI"`` appears in ``tf.__version__``; CSV readers/writers are used
    otherwise.

    Raises:
        ValueError: if ``FLAGS.mode`` is not one of the supported modes.
    """
    app = Application()
    # Decide once which I/O backend to use instead of re-checking per mode.
    use_odps = "PAI" in tf.__version__

    if FLAGS.mode == "train_and_evaluate_on_the_fly":
        train_reader, eval_reader = _build_train_eval_readers(app, use_odps)
        app.run_train_and_evaluate(train_reader=train_reader,
                                   eval_reader=eval_reader)
    elif FLAGS.mode == "predict_on_the_fly":
        pred_reader, pred_writer = _build_predict_io(app, use_odps)
        app.run_predict(reader=pred_reader,
                        writer=pred_writer,
                        checkpoint_path=app.predict_checkpoint_path)
    else:
        # Fail loudly so a typo in --mode is not mistaken for a successful
        # (but empty) run.
        raise ValueError("Unsupported mode: %r" % (FLAGS.mode,))
コード例 #6
0
def do_predict():
    """Run text-classification prediction and preview the written output.

    Builds a ``TextClassification`` app from the module-level ``config_json``
    in ``predict_on_the_fly`` mode. Predicts from ``app.predict_input_fp``
    into ``app.predict_output_fp``, then loads that tab-separated output back
    with pandas and prints its first 10 rows.

    Returns:
        pandas.DataFrame: the prediction output, with columns renamed to
        ``['true_label', 'pred_label_id']``.
    """
    config = Config(mode="predict_on_the_fly", config_json=config_json)
    app = TextClassification(user_defined_config=config)
    pred_reader = CSVReader(input_glob=app.predict_input_fp,
                            is_training=False,
                            input_schema=app.input_schema,
                            batch_size=app.predict_batch_size)

    pred_writer = CSVWriter(output_glob=app.predict_output_fp,
                            output_schema=app.output_schema)

    app.run_predict(reader=pred_reader, writer=pred_writer,
                    checkpoint_path=app.predict_checkpoint_path)

    # Read back exactly the file the writer was pointed at, rather than a
    # hard-coded path that can drift from the configuration.
    pred = pd.read_csv(app.predict_output_fp, header=None, delimiter='\t',
                       encoding='utf8')

    # NOTE(review): assumes the output schema yields exactly two columns —
    # confirm against config_json's output_schema.
    pred.columns = ['true_label', 'pred_label_id']

    # A bare `pred.head(10)` only renders in a notebook cell; inside a
    # function its value was silently discarded, so print it explicitly.
    print(pred.head(10))
    return pred