Code example #1
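A train_and_evaluate method, apparently from an EasyTransfer application class: it builds a training reader and an evaluation reader via get_reader_fn, runs training plus evaluation, and then tries to export the latest checkpoint to a SavedModel by switching the config to export mode and rebuilding the estimator. Code example #11 below is a longer variant of the same method.
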
    def train_and_evaluate(self):
        shuffle_buffer_size = self.config.shuffle_buffer_size if \
            hasattr(self.config, "shuffle_buffer_size") else None
        train_reader = get_reader_fn()(input_glob=self.config.train_input_fp,
                                       is_training=True,
                                       input_schema=self.config.input_schema,
                                       batch_size=self.config.train_batch_size,
                                       distribution_strategy=self.config.distribution_strategy,
                                       shuffle_buffer_size=shuffle_buffer_size)

        eval_reader = get_reader_fn()(input_glob=self.config.eval_input_fp,
                                      is_training=False,
                                      input_schema=self.config.input_schema,
                                      batch_size=self.config.eval_batch_size)

        self.run_train_and_evaluate(train_reader=train_reader, eval_reader=eval_reader)

        # Export the last checkpoints to saved_model
        try:
            checkpoint_path = tf.train.latest_checkpoint(self.config.model_dir, latest_filename=None)
            self.config.export_dir_base = self.config.model_dir
            self.config.checkpoint_path = checkpoint_path
            self.config.input_tensors_schema = self.get_input_tensor_schema()
            self.config.input_schema = self.config.input_tensors_schema
            self.config.receiver_tensors_schema = self.get_received_tensor_schema()
            self.config.mode = "export"
            tf.reset_default_graph()
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())
            self.export()
        except Exception as e:
            tf.logging.info(str(e))
Code example #2
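An evaluate method: it builds an evaluation reader and runs evaluation against the checkpoint given by eval_ckpt_path.
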
    def evaluate(self):
        eval_reader = get_reader_fn()(input_glob=self.config.eval_input_fp,
                                      input_schema=self.config.input_schema,
                                      is_training=False,
                                      batch_size=self.config.eval_batch_size)

        self.run_evaluate(reader=eval_reader, checkpoint_path=self.config.eval_ckpt_path)
Code example #3
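A train method: it builds a training reader with the configured distribution strategy and runs training.
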
    def train(self):
        train_reader = get_reader_fn()(input_glob=self.config.train_input_fp,
                                       is_training=True,
                                       input_schema=self.config.input_schema,
                                       batch_size=self.config.train_batch_size,
                                       distribution_strategy=self.config.distribution_strategy)
        self.run_train(reader=train_reader)
Code example #4
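A preprocessing run method: each worker takes its slice of the input, then a reader, a preprocessor (built from the configured tokenizer), and a writer are chained through the queues of a ProcessExecutor, and the method blocks until the pipeline finishes.
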
    def run(self):
        self.proc_executor = distribution.ProcessExecutor(self.queue_size)
        worker_id = self.config.task_index
        num_workers = len(self.config.worker_hosts.split(","))
        # Reuse the executor created above rather than constructing a second one.
        proc_executor = self.proc_executor

        reader = get_reader_fn(self.config.preprocess_input_fp)(
            input_glob=self.config.preprocess_input_fp,
            input_schema=self.config.input_schema,
            is_training=False,
            batch_size=self.config.preprocess_batch_size,
            slice_id=worker_id,
            slice_count=num_workers,
            output_queue=proc_executor.get_output_queue())

        proc_executor.add(reader)
        preprocessor = preprocessors.get_preprocessor(
            self.config.tokenizer_name_or_path,
            thread_num=self.thread_num,
            input_queue=proc_executor.get_input_queue(),
            output_queue=proc_executor.get_output_queue(),
            preprocess_batch_size=self.config.preprocess_batch_size,
            user_defined_config=self.config,
            app_model_name=self.config.app_model_name)
        proc_executor.add(preprocessor)
        writer = get_writer_fn(self.config.preprocess_output_fp)(
            output_glob=self.config.preprocess_output_fp,
            output_schema=self.config.output_schema,
            slice_id=worker_id,
            input_queue=proc_executor.get_input_queue())

        proc_executor.add(writer)
        proc_executor.run()
        proc_executor.wait()
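
As a rough mental model of the reader → preprocessor → writer chain above (not EasyTransfer's actual ProcessExecutor API; every name below is illustrative), the same producer/consumer shape can be sketched with plain Python queues and threads:

# Generic producer/consumer sketch of the preprocessing pipeline.
import queue
import threading

def reader(out_q):
    # Producer: emit a few fake records, then a sentinel.
    for i in range(10):
        out_q.put({"text": "example %d" % i})
    out_q.put(None)

def preprocessor(in_q, out_q):
    # Transformer: tokenize each record and forward it.
    while True:
        item = in_q.get()
        if item is None:
            out_q.put(None)
            break
        item["tokens"] = item["text"].split()
        out_q.put(item)

def writer(in_q):
    # Consumer: drain the queue until the sentinel arrives.
    while True:
        item = in_q.get()
        if item is None:
            break
        print(item["tokens"])

q1, q2 = queue.Queue(), queue.Queue()
workers = [threading.Thread(target=reader, args=(q1,)),
           threading.Thread(target=preprocessor, args=(q1, q2)),
           threading.Thread(target=writer, args=(q2,))]
for t in workers:
    t.start()
for t in workers:
    t.join()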
Code example #5
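A _build_vocab helper: if no vocabulary file exists yet, it streams the training data through the estimator, feeds the first (and, if present, second) text column into a DeepTextVocab, truncates the vocabulary to max_vocab_size, and writes train_vocab.txt into model_dir.
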
    def _build_vocab(self):
        # Reuse an existing vocabulary file if one is configured; otherwise
        # build one from the training data and store it under model_dir.
        if hasattr(self.config, "vocab_path") and tf.gfile.Exists(self.config.vocab_path):
            return
        else:
            import os
            from easytransfer.preprocessors.deeptext_preprocessor import DeepTextVocab

            vocab = DeepTextVocab()
            reader = get_reader_fn()(input_glob=self.config.train_input_fp,
                                     is_training=False,
                                     input_schema=self.config.input_schema,
                                     batch_size=self.config.train_batch_size,
                                     distribution_strategy=self.config.distribution_strategy)
            for batch_idx, outputs in enumerate(self.estimator.predict(input_fn=reader.get_input_fn(),
                                                                       yield_single_examples=False,
                                                                       checkpoint_path=None)):
                if self.config.first_sequence in outputs:
                    for line in outputs[self.config.first_sequence]:
                        vocab.add_line(line)
                if self.config.second_sequence in outputs:
                    for line in outputs[self.config.second_sequence]:
                        vocab.add_line(line)
            vocab.filter_vocab_to_fix_length(self.config.max_vocab_size)
            self.config.vocab_path = os.path.join(self.config.model_dir, "train_vocab.txt")
            vocab.export_to_file(self.config.vocab_path)
Code example #6
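A standalone train_and_evaluate function driving a StudentNetwork application: it builds the training and evaluation readers from the app's config and calls run_train_and_evaluate.
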
def train_and_evaluate():
    app = StudentNetwork()

    train_reader = get_reader_fn()(input_glob=app.config.train_input_fp,
                                   input_schema=app.config.input_schema,
                                   is_training=True,
                                   batch_size=app.config.train_batch_size)

    eval_reader = get_reader_fn()(input_glob=app.config.eval_input_fp,
                                  input_schema=app.config.input_schema,
                                  is_training=False,
                                  batch_size=app.config.eval_batch_size)

    app.run_train_and_evaluate(train_reader=train_reader,
                               eval_reader=eval_reader)

    tf.logging.info("Finished training")
Code example #7
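A get_default_reader method that builds a prediction reader sliced by worker and backed by an in-memory queue; get_reader_fn is given the input path here, presumably so it can pick a reader class that matches the file type.
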
    def get_default_reader(self):
        return get_reader_fn(self.config.predict_input_fp)(
            input_glob=self.config.predict_input_fp,
            input_schema=self.config.input_schema,
            is_training=False,
            batch_size=self.config.predict_batch_size,
            output_queue=queue.Queue(),
            slice_id=self.worker_id,
            slice_count=self.num_workers)
Code example #8
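A standalone predict function for StudentNetwork: it pairs a prediction reader with a writer for predict_output_fp and runs prediction from predict_checkpoint_path.
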
def predict():
    app = StudentNetwork()
    reader = get_reader_fn()(input_glob=app.config.predict_input_fp,
                             input_schema=app.config.input_schema,
                             is_training=False,
                             batch_size=app.config.predict_batch_size)
    writer = get_writer_fn()(output_glob=app.config.predict_output_fp,
                             output_schema=app.config.output_schema)
    app.run_predict(reader=reader,
                    writer=writer,
                    checkpoint_path=app.config.predict_checkpoint_path)
Code example #9
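A standalone evaluate function for StudentNetwork, the module-level counterpart of code example #2.
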
def evaluate():
    app = StudentNetwork()

    eval_reader = get_reader_fn()(input_glob=app.config.eval_input_fp,
                                  input_schema=app.config.input_schema,
                                  is_training=False,
                                  batch_size=app.config.eval_batch_size)

    app.run_evaluate(reader=eval_reader,
                     checkpoint_path=app.config.eval_ckpt_path)

    tf.logging.info("Finished evaluation")
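
The module-level helpers for StudentNetwork (train_and_evaluate, evaluate, and predict in code examples #6, #9, and #8) are naturally wired to a command-line mode switch. The sketch below is only an illustration of how such a driver could look; the --mode flag and the use of tf.app.run are assumptions, not part of the original snippets.

# Hypothetical driver for the StudentNetwork helpers defined above.
# Only train_and_evaluate, evaluate, and predict come from the snippets;
# the flag handling is illustrative.
import tensorflow as tf

flags = tf.app.flags
flags.DEFINE_string("mode", "train_and_evaluate",
                    "one of: train_and_evaluate, evaluate, predict")
FLAGS = flags.FLAGS


def main(_):
    if FLAGS.mode == "train_and_evaluate":
        train_and_evaluate()
    elif FLAGS.mode == "evaluate":
        evaluate()
    elif FLAGS.mode == "predict":
        predict()
    else:
        raise ValueError("Unknown mode: %s" % FLAGS.mode)


if __name__ == "__main__":
    tf.app.run()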
Code example #10
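A predict method with two paths: if predict_checkpoint_path points at a raw .ckpt file, it runs estimator prediction with a reader/writer pair; otherwise it switches the config to predict_on_the_fly mode and hands off to run_app_predictor.
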
    def predict(self):
        if ".ckpt" in self.config.predict_checkpoint_path:
            predict_reader = get_reader_fn()(input_glob=self.config.predict_input_fp,
                                             batch_size=self.config.predict_batch_size,
                                             is_training=False,
                                             input_schema=self.config.input_schema)

            predict_writer = get_writer_fn()(output_glob=self.config.predict_output_fp,
                                             output_schema=self.config.output_schema)

            self.run_predict(predict_reader,
                             predict_writer,
                             checkpoint_path=self.config.predict_checkpoint_path)
        else:
            self.config.mode = "predict_on_the_fly"
            self.mode = "predict_on_the_fly"
            run_app_predictor(self.config)
Code example #11
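The full train_and_evaluate method, a longer variant of code example #1. With export_best_checkpoint enabled it first trains, then evaluates every checkpoint listed in the checkpoint state file and keeps the one with the best export_best_checkpoint_metric value (the score is negated for mse so that lower is better); otherwise it runs train-and-evaluate and takes the latest checkpoint. Either way, the chosen checkpoint is then exported to a SavedModel.
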
    def train_and_evaluate(self):
        shuffle_buffer_size = self.config.shuffle_buffer_size if \
            hasattr(self.config, "shuffle_buffer_size") else None
        train_reader = get_reader_fn(self.config.train_input_fp)(
            input_glob=self.config.train_input_fp,
            is_training=True,
            input_schema=self.config.input_schema,
            batch_size=self.config.train_batch_size,
            distribution_strategy=self.config.distribution_strategy,
            shuffle_buffer_size=shuffle_buffer_size)

        eval_reader = get_reader_fn(self.config.eval_input_fp)(
            input_glob=self.config.eval_input_fp,
            is_training=False,
            input_schema=self.config.input_schema,
            batch_size=self.config.eval_batch_size)

        if hasattr(self.config, "export_best_checkpoint") and \
                self.config.export_best_checkpoint:
            tf.logging.info("First train, then search for best checkpoint...")
            self.run_train(reader=train_reader)

            # Collect the global steps of every checkpoint listed in the
            # `checkpoint` state file under model_dir.
            ckpts = set()
            with tf.gfile.GFile(os.path.join(self.config.model_dir, "checkpoint"),
                                mode='r') as reader:
                for line in reader:
                    line = line.strip().replace("oss://", "")
                    ckpt_name = line.split(":")[1].strip().replace("\"", "")
                    step = ckpt_name.split("/")[-1].replace("model.ckpt-", "")
                    ckpts.add(int(step))

            best_score = float('-inf')
            best_ckpt = None
            best_eval_results = None
            best_metric_name = self.config.export_best_checkpoint_metric
            eval_results_list = list()
            for ckpt in sorted(ckpts):
                checkpoint_path = os.path.join(self.config.model_dir,
                                               "model.ckpt-" + str(ckpt))
                eval_results = self.run_evaluate(
                    reader=eval_reader, checkpoint_path=checkpoint_path)
                eval_results.pop('global_step')
                eval_results.pop('loss')
                score = eval_results[best_metric_name]
                _score = -1 * score if self.config.export_best_checkpoint_metric == "mse" else score
                if _score > best_score:
                    best_ckpt = ckpt
                    best_score = _score
                    best_eval_results = eval_results
                tf.logging.info(
                    "Ckpt {} 's {}: {:.4f}; Best ckpt {} 's {}: {:.4f}".format(
                        ckpt, best_metric_name, score, best_ckpt,
                        best_metric_name, best_score))
                eval_results_list.append((ckpt, eval_results))
            for ckpt, eval_results in eval_results_list:
                tf.logging.info("Checkpoint-%d: " % ckpt)
                for metric_name, score in eval_results.items():
                    tf.logging.info("\t{}: {:.4f}".format(metric_name, score))
            tf.logging.info("Best checkpoint: {}".format(best_ckpt))
            for metric_name, score in best_eval_results.items():
                tf.logging.info("\t{}: {:.4f}".format(metric_name, score))

            # Export best checkpoints to saved_model
            checkpoint_path = os.path.join(self.config.model_dir,
                                           "model.ckpt-" + str(best_ckpt))

        else:
            self.run_train_and_evaluate(train_reader=train_reader,
                                        eval_reader=eval_reader)

            # Export the last checkpoints to saved_model
            checkpoint_path = tf.train.latest_checkpoint(self.config.model_dir,
                                                         latest_filename=None)

        try:
            tf.logging.info("Export checkpoint {}".format(checkpoint_path))
            self.config.export_dir_base = self.config.model_dir
            self.config.checkpoint_path = checkpoint_path
            self.config.input_tensors_schema = self.get_input_tensor_schema()
            self.config.input_schema = self.config.input_tensors_schema
            self.config.receiver_tensors_schema = self.get_received_tensor_schema()
            self.config.mode = "export"
            tf.reset_default_graph()
            self.estimator = tf.estimator.Estimator(
                model_fn=self._build_model_fn(),
                config=self._get_run_predict_config())
            self.export()
        except Exception as e:
            tf.logging.info(str(e))