Exemplo n.º 1
0
 def write_outputs(self, tasks, trial, split):
     """Predict on `split` for `tasks` and pickle the per-task outputs.

     For every predicted example (other than those whose task_id equals
     len(self._tasks), which are skipped), the model's "logits" are stored
     (falling back to "predictions" when no "logits" entry exists), keyed by
     task name and then by example id ("eid").

     Args:
         tasks: tasks to predict on; forwarded to the preprocessor.
         trial: current trial index. Outputs are written only while it is
             <= config.n_writes_distill (train split) or
             <= config.n_writes_test (any other split).
         split: dataset split name. "train" output goes to
             config.distill_outputs(...); any other split goes to
             config.test_outputs(...).
     """
     utils.log("Writing out predictions for", tasks, split)
     predict_fn, _unused_a, _unused_b = self._preprocessor.prepare_predict(
         tasks, split)
     predictions = self._estimator.predict(input_fn=predict_fn,
                                           yield_single_examples=True)

     # task name -> eid -> model output (logits, or predictions if absent)
     outputs_by_task = collections.defaultdict(dict)
     for raw in predictions:
         # task_id == len(self._tasks) marks examples to ignore
         # (presumably padding/dummy examples -- TODO confirm).
         if raw["task_id"] == len(self._tasks):
             continue
         nested = utils.nest_dict(raw, self._config.task_names)
         task_name = self._config.task_names[nested["task_id"]]
         task_out = nested[task_name]
         if "logits" in task_out:
             value = task_out["logits"]
         else:
             value = task_out["predictions"]
         outputs_by_task[task_name][task_out["eid"]] = value

     for task_name, task_outputs in outputs_by_task.items():
         utils.log("Pickling predictions for {:} {:} examples ({:})".format(
             len(task_outputs), task_name, split))
         is_train = split == "train"
         write_limit = (self._config.n_writes_distill if is_train
                        else self._config.n_writes_test)
         if trial <= write_limit:
             # The destination is only computed when a write actually happens.
             if is_train:
                 destination = self._config.distill_outputs(task_name, trial)
             else:
                 destination = self._config.test_outputs(
                     task_name, split, trial)
             utils.write_pickle(task_outputs, destination)
Exemplo n.º 2
0
 def _evaluate_task(self, task):
     """Score the current model on `task`'s evaluation data.

     Predictions are produced one example at a time. Each example whose
     task_id differs from len(self._tasks) is regrouped by task-name prefix
     and its `task.name` entry is fed to the task's scorer.

     Args:
         task: the task to evaluate; must expose `name` and `get_scorer()`.

     Returns:
         A dict constructed from `scorer.get_results()`.
     """
     utils.log("Evaluating", task.name)
     eval_fn, _unused_a, _unused_b = self._preprocessor.prepare_eval(task)
     predictions = self._estimator.predict(input_fn=eval_fn,
                                           yield_single_examples=True)
     task_scorer = task.get_scorer()
     for raw in predictions:
         # task_id == len(self._tasks) marks examples to ignore
         # (presumably padding/dummy examples -- TODO confirm).
         if raw["task_id"] == len(self._tasks):
             continue
         nested = utils.nest_dict(raw, self._config.task_names)
         task_scorer.update(nested[task.name])
     summary = task.name + ": " + task_scorer.results_str()
     utils.log(summary)
     utils.log()
     return dict(task_scorer.get_results())