def reduce_eval_results(criterion: Criterion, custom_dataset: Dataset, eval_results) -> Tuple[dict, dict, dict]:
    """Reduces raw prediction results into evaluation metric dicts.

    Args:
        criterion: A criterion instance.
        custom_dataset: The custom dataset object.
        eval_results: The prediction results from `make_predictions`.

    Returns:
        A tuple of dicts of evaluation results:
            - result dict of each dataset
            - averaged result
            - mixed result
    """
    if not isinstance(custom_dataset, MultipleDataset):
        # Single dataset: per-dataset, averaged and mixed results coincide.
        per_dataset_res = averaged_res = mixed_res = criterion.reduce_metrics(eval_results)
    else:
        per_dataset_res = {}
        averaged_res = {}
        pooled_samples = []
        for ds_name, ds_samples in eval_results.items():
            per_dataset_res[ds_name] = criterion.reduce_metrics(ds_samples)
            pooled_samples.extend(ds_samples)
            # Accumulate a weighted average across datasets using the
            # dataset-level sample weights.
            for metric_name, metric_value in per_dataset_res[ds_name].items():
                averaged_res[metric_name] = (
                    averaged_res.get(metric_name, 0.)
                    + metric_value * custom_dataset.sample_weights[ds_name])
        # "Mixed" result: reduce over all samples pooled together.
        mixed_res = criterion.reduce_metrics(pooled_samples)
    return (to_numpy_or_python_type(per_dataset_res),
            to_numpy_or_python_type(averaged_res),
            to_numpy_or_python_type(mixed_res))
def reduce_sample_metrics(self, eval_res):
    """Reduces the metrics at sample level.

    Args:
        eval_res: A tuple of numpy.ndarray or tensors generated by
            `self.__call__`.

    Returns:
        A list of dict of reduced metrics (scalar) for evaluation.
    """
    sum_nll, _, token_counts = eval_res
    sum_nll = to_numpy_or_python_type(sum_nll)
    token_counts = to_numpy_or_python_type(token_counts)
    per_sample = []
    for nll, n_tokens in zip(sum_nll, token_counts):
        nll_per_token = nll / n_tokens
        per_sample.append({"nll": nll,
                           "ppl": 2. ** nll_per_token,
                           "nll_per_token": nll_per_token})
    return per_sample
def gen():
    # Stream records for this shard only; shuffling and auto-sharding are
    # disabled so that manual sharding stays deterministic.
    records = load_tfrecords(self._data_path,
                             shuffle=False,
                             auto_shard=False,
                             name_to_features=self.fields,
                             sharding_index=shard_id,
                             num_shards=total_shards)
    for record in records:
        example = to_numpy_or_python_type(record, bytes_as_str=True)
        yield example if map_func is None else map_func(example)
def gen():
    # Optionally shuffled stream; a deterministic read order is only
    # requested when shuffling is off.
    dataset = load_tfrecords(self._data_path,
                             shuffle=self._shuffle_dataset and shuffle,
                             deterministic=(not shuffle),
                             auto_shard=auto_shard,
                             name_to_features=self.fields,
                             feature_name_mapping={
                                 self._feature_key: "audio",
                                 self._transcript_key: "transcript"})
    for record in dataset:
        example = to_numpy_or_python_type(record, bytes_as_str=True)
        yield example if map_func is None else map_func(example)
def gen():
    # Stream records for this shard only; shuffling and auto-sharding are
    # disabled so that manual sharding stays deterministic.
    records = load_tfrecords(self._data_path,
                             shuffle=False,
                             auto_shard=False,
                             name_to_features=self.fields,
                             sharding_index=shard_id,
                             num_shards=total_shards,
                             feature_name_mapping={
                                 self._feature_key: "audio",
                                 self._transcript_key: "transcript",
                                 self._translation_key: "translation"})
    for record in records:
        example = to_numpy_or_python_type(record, bytes_as_str=True)
        yield example if map_func is None else map_func(example)
def record(self, step, metric_result):
    """ Records the metrics and keep the best. """
    metric_result = to_numpy_or_python_type(metric_result)
    # "Better or equal" counts as an improvement, so ties reset the bad count.
    if (self._best_metric_result is None
            or self._metric.greater_or_eq(metric_result, self._best_metric_result)):
        self._bad_count = 0
        self._best_metric_result = metric_result
    else:
        self._bad_count += 1
    # re-save the best checkpoint
    if self._keep_best_ckpt_saver is not None:
        start_time = time.time()
        stat = self._keep_best_ckpt_saver.save(step, metric_result)
        logging.info(
            "Checking the best checkpoints kept and %s. Elapsed %.2fs",
            "a new checkpoint was saved" if stat else "no checkpoint was saved.",
            time.time() - start_time)
    # Optionally maintain a weight-averaged checkpoint as well.
    if self._average_ckpt_saver is not None:
        start_time = time.time()
        stat = self._average_ckpt_saver.save(step, metric_result)
        if stat:
            logging.info("An averaged checkpoint was saved. Elapsed %.2fs",
                         time.time() - start_time)
    if self._estop_patience is not None:
        logging.info(
            f"Evaluating {self._metric.flag} at step={step} with bad count={self._bad_count} "
            f"(early_stop_patience={self._estop_patience}).")
    # Early stop: fires only when patience is a positive number and the bad
    # count has reached it (None/0 disables early stopping).
    if self._estop_patience and self._bad_count >= self._estop_patience > 0:
        logging.info("Hit maximum patience! Early Stop!!!")

        # kill self and exit with code=0
        def handler(*args):
            sys.exit(0)
        # register for signal
        # NOTE(review): SIGUSR1 is unavailable on Windows — confirm the
        # supported platforms for this early-stop mechanism.
        signal.signal(signal.SIGUSR1, handler)
        os.kill(os.getpid(), signal.SIGUSR1)
def postprocess_generation(task, generations):
    """Flattens generation batches and applies the task's postprocess fn.

    Args:
        task: The task object providing the INFER-mode postprocess function.
        generations: A list of per-batch generation results.

    Returns:
        A list of post-processed generation outputs.
    """
    flattened = numpy.concatenate(to_numpy_or_python_type(generations), 0)
    postprocess_fn = task.get_data_postprocess_fn(compat.ModeKeys.INFER)
    return [postprocess_fn(item) for item in flattened]