Exemplo n.º 1
0
 def build(self, strategy, task, model):
     """Initializes the criterion validator.

     Builds the evaluation criterion and the validation dataset from
     ``self.args``. If either builder returns None, criterion validation is
     disabled by setting ``self._validate_criterion = False`` and the method
     returns early. Otherwise the evaluation model and EVAL datasets are
     built inside the strategy scope, and status recorders are created: one
     per sub-dataset plus an averaged and a mixed recorder when the dataset
     is a ``MultipleDataset``, or a single recorder otherwise.

     Args:
         strategy: Distribution strategy; used for the scope returned by
             ``training_utils.get_strategy_scope`` and for building datasets.
         task: The task used to build the evaluation model and datasets.
         model: The model being validated.

     Returns:
         ``self``, so the call can be chained.
     """
     self._strategy = strategy
     self._criterion: Criterion = build_criterion(
         self.args["eval_criterion.class"],
         **self.args["eval_criterion.params"])
     # Check for a missing criterion BEFORE using it: calling ``set_model``
     # on None would raise AttributeError, making this fallback unreachable.
     if self._criterion is None:
         logging.info(
             "WARNING: no criterion is provided in CriterionValidator "
             "for validation process.")
         self._validate_criterion = False
         # NOTE: returning here also skips building the validation dataset
         # below, so ``self._custom_dataset`` is not assigned by this call.
         return self
     self._criterion.set_model(model)
     self._custom_dataset = build_dataset(
         self.args["eval_dataset.class"],
         **self.args["eval_dataset.params"])
     if self._custom_dataset is None:
         logging.info("WARNING: no validation dataset is provided "
                      "in CriterionValidator for validation process.")
         self._validate_criterion = False
         return self
     # Function-scope import; presumably avoids an import cycle with
     # neurst.exps -- TODO confirm.
     from neurst.exps.evaluator import Evaluator
     with training_utils.get_strategy_scope(strategy):
         self._criterion_model = Evaluator.build_evaluation_model(
             task, model, self._criterion)
         self._eval_tfds = training_utils.build_datasets(
             compat.ModeKeys.EVAL, strategy, self._custom_dataset, task,
             True, self._eval_task_args)
     self._criterion_metric = self._criterion.as_metric()
     if isinstance(self._custom_dataset, MultipleDataset):
         # One recorder per named sub-dataset, plus aggregate recorders.
         self._criterion_recorder = {
             name: training_utils.TrainingStatusRecorder(
                 model=model, task=task, metric=self._criterion_metric)
             for name in self._custom_dataset.datasets
         }
         self._avg_criterion_recorder = training_utils.TrainingStatusRecorder(
             model=model, task=task, metric=self._criterion_metric)
         self._mixed_criterion_recorder = training_utils.TrainingStatusRecorder(
             model=model, task=task, metric=self._criterion_metric)
     else:
         self._criterion_recorder = training_utils.TrainingStatusRecorder(
             model=model, task=task, metric=self._criterion_metric)
     self._criterion_start_time = time.time()
     return self
Exemplo n.º 2
0
 def build(self, strategy, task, model):
     """Initializes the sequence-generation validator.

     Calls the parent ``build`` first, then prepares generation-based
     validation: the eval metric (from ``task.get_eval_metric``), the search
     layer, the generation model and the INFER datasets. Whenever a piece is
     missing (no validation dataset, no metric, no search method, or no
     ground-truth targets), generation validation is disabled by setting
     ``self._validate_gen = False`` and ``self`` is returned early.

     For a ``MultipleDataset``, sub-datasets without targets are dropped from
     ``self._gen_tfds``. One recorder is created per remaining sub-dataset,
     plus a mixed and an averaged recorder. The early-stopping / checkpoint
     options read from ``self.args`` are passed only to the averaged recorder
     (multiple datasets) or to the single recorder (one dataset).

     Args:
         strategy: Distribution strategy; used for the scope returned by
             ``training_utils.get_strategy_scope`` and for building datasets.
         task: The task providing the eval metric, generation model and
             datasets.
         model: The model being validated.

     Returns:
         ``self``, so the call can be chained.
     """
     super(SeqGenerationValidator, self).build(strategy, task, model)
     if self._custom_dataset is None:
         logging.info("WARNING: no validation dataset is provided "
                      "in SeqGenerationValidator for validation process.")
         self._validate_gen = False
         return self
     self._gen_metric = task.get_eval_metric(self.args,
                                             name="eval_metric",
                                             ds=self._custom_dataset)
     if self._gen_metric is None:
         logging.info("WARNING: no metric is provided "
                      "in SeqGenerationValidator for validation process.")
         self._validate_gen = False
         return self
     self._gen_metric.flag = self.args["eval_metric.class"]
     search_layer = build_search_layer(
         self.args["eval_search_method.class"],
         **self.args["eval_search_method.params"])
     if search_layer is None:
         logging.info("WARNING: no search method is provided "
                      "in SeqGenerationValidator for validation process.")
         self._validate_gen = False
         return self
     # Imported once here; the same name is reused below for
     # post-processing, so no second import is needed.
     from neurst.exps.sequence_generator import SequenceGenerator
     is_multiple = isinstance(self._custom_dataset, MultipleDataset)
     with training_utils.get_strategy_scope(strategy):
         self._gen_model = SequenceGenerator.build_generation_model(
             task, model, search_layer)
         self._gen_tfds = training_utils.build_datasets(
             compat.ModeKeys.INFER, strategy, self._custom_dataset, task,
             True, self._eval_task_args)
         if is_multiple:
             # Iterate over a snapshot of the keys because entries are
             # popped from ``self._gen_tfds`` inside the loop.
             for name in list(self._gen_tfds):
                 if self._custom_dataset.datasets[name].targets is None:
                     logging.info(
                         f"WARNING: no ground truth found for validation dataset {name}."
                     )
                     self._gen_tfds.pop(name)
             if not self._gen_tfds:
                 logging.info(
                     "WARNING: no ground truth found for all validation datasets and "
                     "no validation will be applied.")
                 self._validate_gen = False
                 return self
         elif self._custom_dataset.targets is None:
             logging.info(
                 "WARNING: no ground truth found for validation dataset and "
                 "no validation will be applied.")
             self._validate_gen = False
             return self

     # Early-stopping / checkpointing options shared by whichever recorder
     # receives them (see the docstring).
     checkpoint_kwargs = {
         "estop_patience": self.args["eval_estop_patience"],
         "best_checkpoint_path": self.args["eval_best_checkpoint_path"],
         "auto_average_checkpoints":
             self.args["eval_auto_average_checkpoints"],
         "best_avg_checkpoint_path":
             self.args["eval_best_avg_checkpoint_path"],
         "top_checkpoints_to_keep":
             self.args["eval_top_checkpoints_to_keep"],
     }

     def new_recorder(**extra_kwargs):
         """Creates a status recorder bound to this model, task and metric."""
         return training_utils.TrainingStatusRecorder(
             model=model, task=task, metric=self._gen_metric, **extra_kwargs)

     if is_multiple:
         self._gen_recorder = {name: new_recorder() for name in self._gen_tfds}
         self._mixed_gen_recorder = new_recorder()
         self._avg_gen_recorder = new_recorder(**checkpoint_kwargs)
     else:
         self._gen_recorder = new_recorder(**checkpoint_kwargs)
     self._postprocess_fn = lambda y: SequenceGenerator.postprocess_generation(
         task, y)
     self._gen_start_time = time.time()
     return self