Ejemplo n.º 1
0
 def __init__(self, args, **kwargs):
     """Set up the sequence-generation utility.

     Args:
         args: Mapping of options. Reads "output_file" and "save_metric"
             directly, and is forwarded to metric and search-layer
             construction.
         **kwargs: Passed through unchanged to the parent constructor.
     """
     super(SequenceGenerator, self).__init__(**kwargs)
     # Pull the plain options out of the mapping first.
     output_file = args["output_file"]
     save_metric = args["save_metric"]
     self._output_file = output_file
     self._save_metric = save_metric
     # Metric comes from the task (parent __init__ sets up task/dataset),
     # the search layer is built from the full option mapping.
     self._metric = self.task.get_eval_metric(args, ds=self.custom_dataset)
     self._search_layer = build_search_layer(args)
Ejemplo n.º 2
0
 def __init__(self, args, **kwargs):
     """Initializes a util class for masked sequence generation.

     Args:
         args: Mapping of options. Reads "output_file", "save_metric",
             "apply_mask", and "mask_dir" (optional sequence whose first
             entry is the path of a pickled mask file); the full mapping
             is also forwarded to metric and search-layer construction.
         **kwargs: Passed through unchanged to the parent constructor.
     """
     super(MaskSequenceGenerator, self).__init__(**kwargs)
     self._output_file = args["output_file"]
     self._save_metric = args["save_metric"]
     self._metric = self.task.get_eval_metric(args, ds=self.custom_dataset)
     self._search_layer = build_search_layer(args)
     self._apply_mask = args["apply_mask"]
     if args["mask_dir"]:
         self.mask_dir = args["mask_dir"][0]
         # NOTE(review): pickle.load executes arbitrary code on load;
         # only point mask_dir at trusted files.
         with open(self.mask_dir, 'rb') as f:
             self.load_mask = pickle.load(f)
     else:
         # Bug fix: define these attributes unconditionally so later
         # accesses do not raise AttributeError when no mask_dir is given.
         self.mask_dir = None
         self.load_mask = None
Ejemplo n.º 3
0
 def build(self, strategy, task, model):
     """Builds everything needed to validate by sequence generation.

     Sets up, in order: the generation eval metric, the search layer,
     the inference model and datasets (under the distribution-strategy
     scope), and the training-status recorders. Whenever a required
     piece is missing (no dataset, no metric, no search method, or no
     ground-truth targets), generation validation is disabled by setting
     ``self._validate_gen = False`` and the method returns early.

     Args:
         strategy: The distribution strategy used when building the
             generation model and datasets.
         task: The task object providing the eval metric and dataset
             pipelines.
         model: The model under validation.

     Returns:
         self, to allow call chaining.
     """
     super(SeqGenerationValidator, self).build(strategy, task, model)
     # No validation dataset -> nothing to generate against.
     if self._custom_dataset is None:
         logging.info("WARNING: no validation dataset is provided "
                      "in SeqGenerationValidator for validation process.")
         self._validate_gen = False
         return self
     # Metric used to score generated sequences against references.
     self._gen_metric = task.get_eval_metric(self.args,
                                             name="eval_metric",
                                             ds=self._custom_dataset)
     if self._gen_metric is None:
         logging.info("WARNING: no metric is provided "
                      "in SeqGenerationValidator for validation process.")
         self._validate_gen = False
         return self
     # Tag the metric with its configured class name (used downstream;
     # exact consumer not visible from here).
     self._gen_metric.flag = self.args["eval_metric.class"]
     search_layer = build_search_layer(
         self.args["eval_search_method.class"],
         **self.args["eval_search_method.params"])
     if search_layer is None:
         logging.info("WARNING: no search method is provided "
                      "in SeqGenerationValidator for validation process.")
         self._validate_gen = False
         return self
     # Imported locally, presumably to avoid a circular import between
     # the validator and the generator experiment module.
     from neurst.exps.sequence_generator import SequenceGenerator
     with training_utils.get_strategy_scope(strategy):
         self._gen_model = SequenceGenerator.build_generation_model(
             task, model, search_layer)
         self._gen_tfds = training_utils.build_datasets(
             compat.ModeKeys.INFER, strategy, self._custom_dataset, task,
             True, self._eval_task_args)
         if isinstance(self._custom_dataset, MultipleDataset):
             # Drop every sub-dataset that has no ground-truth targets;
             # iterate over a snapshot of the keys since we pop inside.
             for name in list(self._gen_tfds.keys()):
                 if self._custom_dataset.datasets[name].targets is None:
                     logging.info(
                         f"WARNING: no ground truth found for validation dataset {name}."
                     )
                     self._gen_tfds.pop(name)
             if len(self._gen_tfds) == 0:
                 logging.info(
                     "WARNING: no ground truth found for all validation datasets and "
                     "no validation will be applied.")
                 self._validate_gen = False
                 return self
         else:
             if self._custom_dataset.targets is None:
                 logging.info(
                     "WARNING: no ground truth found for validation dataset and "
                     "no validation will be applied.")
                 self._validate_gen = False
                 return self
     if isinstance(self._custom_dataset, MultipleDataset):
         # One plain recorder per sub-dataset ...
         self._gen_recorder = {
             name:
             training_utils.TrainingStatusRecorder(model=model,
                                                   task=task,
                                                   metric=self._gen_metric)
             for name in self._gen_tfds
         }
         # ... plus one for the mixed result, and an averaged recorder
         # that additionally owns early stopping and checkpoint keeping.
         self._mixed_gen_recorder = training_utils.TrainingStatusRecorder(
             model=model, task=task, metric=self._gen_metric)
         self._avg_gen_recorder = training_utils.TrainingStatusRecorder(
             model=model,
             task=task,
             metric=self._gen_metric,
             estop_patience=self.args["eval_estop_patience"],
             best_checkpoint_path=self.args["eval_best_checkpoint_path"],
             auto_average_checkpoints=self.
             args["eval_auto_average_checkpoints"],
             best_avg_checkpoint_path=self.
             args["eval_best_avg_checkpoint_path"],
             top_checkpoints_to_keep=self.
             args["eval_top_checkpoints_to_keep"])
     else:
         # Single dataset: one recorder handles metric tracking, early
         # stopping and checkpoint management together.
         self._gen_recorder = training_utils.TrainingStatusRecorder(
             model=model,
             task=task,
             metric=self._gen_metric,
             estop_patience=self.args["eval_estop_patience"],
             best_checkpoint_path=self.args["eval_best_checkpoint_path"],
             auto_average_checkpoints=self.
             args["eval_auto_average_checkpoints"],
             best_avg_checkpoint_path=self.
             args["eval_best_avg_checkpoint_path"],
             top_checkpoints_to_keep=self.
             args["eval_top_checkpoints_to_keep"])
     from neurst.exps.sequence_generator import SequenceGenerator
     # Bind the task into the generator's postprocessing for later use.
     self._postprocess_fn = lambda y: SequenceGenerator.postprocess_generation(
         task, y)
     self._gen_start_time = time.time()
     return self