def run(self):
        """Run the read -> preprocess -> write pipeline for this worker's data slice.

        A reader, a preprocessor and a writer are registered on a single
        ``distribution.ProcessExecutor`` and connected through the executor's
        input/output queues. The executor is then started and waited on.

        The executor is also stored on ``self.proc_executor`` so the instance
        refers to the pipeline that is actually running.
        """
        worker_id = self.config.task_index
        # One slice per comma-separated host entry in worker_hosts.
        num_workers = len(self.config.worker_hosts.split(","))

        # Previously two executors were created here: one stored on self and
        # never used, and a second local one that actually ran. Create exactly
        # one and bind both names to it.
        proc_executor = distribution.ProcessExecutor(self.queue_size)
        self.proc_executor = proc_executor

        # Reader: this worker reads slice `worker_id` out of `num_workers`
        # (presumably shards the input across workers - confirm in reader impl).
        reader = get_reader_fn(self.config.preprocess_input_fp)(
            input_glob=self.config.preprocess_input_fp,
            input_schema=self.config.input_schema,
            is_training=False,
            batch_size=self.config.preprocess_batch_size,
            slice_id=worker_id,
            slice_count=num_workers,
            output_queue=proc_executor.get_output_queue())
        proc_executor.add(reader)

        # Preprocessor: consumes the reader's output and feeds the writer.
        preprocessor = preprocessors.get_preprocessor(
            self.config.tokenizer_name_or_path,
            thread_num=self.thread_num,
            input_queue=proc_executor.get_input_queue(),
            output_queue=proc_executor.get_output_queue(),
            preprocess_batch_size=self.config.preprocess_batch_size,
            user_defined_config=self.config,
            app_model_name=self.config.app_model_name)
        proc_executor.add(preprocessor)

        # Writer: drains the final queue into the output location.
        writer = get_writer_fn(self.config.preprocess_output_fp)(
            output_glob=self.config.preprocess_output_fp,
            output_schema=self.config.output_schema,
            slice_id=worker_id,
            input_queue=proc_executor.get_input_queue())
        proc_executor.add(writer)

        proc_executor.run()
        proc_executor.wait()
# Example No. 2
# 0
def predict():
    """Run offline prediction with a freshly constructed StudentNetwork.

    The reader, writer and checkpoint path all come from the network's own
    config (``predict_input_fp``, ``predict_output_fp``,
    ``predict_checkpoint_path``). Input is read with ``is_training=False``.
    """
    network = StudentNetwork()

    make_reader = get_reader_fn()
    predict_reader = make_reader(
        input_glob=network.config.predict_input_fp,
        input_schema=network.config.input_schema,
        is_training=False,
        batch_size=network.config.predict_batch_size,
    )

    make_writer = get_writer_fn()
    predict_writer = make_writer(
        output_glob=network.config.predict_output_fp,
        output_schema=network.config.output_schema,
    )

    network.run_predict(
        reader=predict_reader,
        writer=predict_writer,
        checkpoint_path=network.config.predict_checkpoint_path,
    )
# Example No. 3
# 0
    def predict(self):
        """Predict using either a checkpoint file or the on-the-fly predictor.

        If ``config.predict_checkpoint_path`` contains ".ckpt", build a reader
        and writer from the config and call ``run_predict`` with that
        checkpoint. Otherwise, switch both ``config.mode`` and ``self.mode`` to
        "predict_on_the_fly" and hand the config to ``run_app_predictor``.
        """
        if ".ckpt" not in self.config.predict_checkpoint_path:
            # No checkpoint file: fall back to the on-the-fly predictor.
            self.config.mode = "predict_on_the_fly"
            self.mode = "predict_on_the_fly"
            run_app_predictor(self.config)
            return

        cfg = self.config
        reader_factory = get_reader_fn()
        reader = reader_factory(input_glob=cfg.predict_input_fp,
                                batch_size=cfg.predict_batch_size,
                                is_training=False,
                                input_schema=cfg.input_schema)

        writer_factory = get_writer_fn()
        writer = writer_factory(output_glob=cfg.predict_output_fp,
                                output_schema=cfg.output_schema)

        self.run_predict(reader,
                         writer,
                         checkpoint_path=cfg.predict_checkpoint_path)
 def get_default_writer(self):
     """Build the default prediction writer for this worker.

     The writer class is chosen by ``get_writer_fn`` from
     ``config.predict_output_fp``. It is tagged with this worker's id as its
     slice and is given a new, empty ``queue.Queue`` as its input queue.
     """
     output_path = self.config.predict_output_fp
     writer_cls = get_writer_fn(output_path)
     return writer_cls(
         output_glob=output_path,
         output_schema=self.config.output_schema,
         slice_id=self.worker_id,
         input_queue=queue.Queue(),
     )