コード例 #1
0
ファイル: trainer.py プロジェクト: tvishrut/sling
    def evaluate(self):
        """Evaluate the model and checkpoint it if the metric improved.

        Nothing happens when no evaluator is configured or when
        ``self.count`` has not changed since the previous evaluation.
        Otherwise ``self.evaluator(self.model)`` is run with the EMA
        parameters swapped in via ``_swap_with_ema_parameters``. The returned
        metrics dict, augmented with ``num_examples_seen``, is appended to
        ``self.checkpoint_metrics``. If ``self.output_file_prefix`` is set,
        ``eval_metric`` is also written to ``<prefix>.evals``. Whenever it
        beats ``self.best_metric``, the model is additionally saved to
        ``<prefix>.best.flow``.

        The parameters are always swapped back before returning, including
        when evaluation, logging, or saving raises.
        """
        if self.evaluator is None or self.count == self.last_eval_count:
            return

        # Use average parameters if available.
        self._swap_with_ema_parameters()
        try:
            metrics = self.evaluator(self.model)
            metrics["num_examples_seen"] = self.count
            self.checkpoint_metrics.append(metrics)
            eval_metric = metrics["eval_metric"]
            print("Eval metric after", self.count, " examples:",
                  eval_metric)

            if self.output_file_prefix is not None:
                # Record the evaluation metric to a separate file. The file is
                # truncated when last_eval_count is 0 (presumably the first
                # evaluation of this run) and appended to otherwise.
                mode = "w" if self.last_eval_count == 0 else "a"
                with open(self.output_file_prefix + ".evals", mode) as f:
                    f.write("Slot_F1 after " + str(self.count) + " examples " +
                            str(eval_metric) + "\n")

                if self.best_metric is None or self.best_metric < eval_metric:
                    self.best_metric = eval_metric

                    best_flow_file = self.output_file_prefix + ".best.flow"
                    fl = flow.Flow()
                    self.model.to_flow(fl)
                    self.save_training_details(fl)
                    fl.save(best_flow_file)
                    print("Updating best flow at", best_flow_file)

            self.last_eval_count = self.count
        finally:
            # Swap back even on failure, so the model is never left holding
            # the averaged parameters after this call.
            self._swap_with_ema_parameters()
コード例 #2
0
ファイル: viewmodel.py プロジェクト: yespon/sling
               default='',
               type=str,
               metavar='FLOW')
  flags.define('--training_details',
               help='Print training details or not',
               default=False,
               action='store_true')
  flags.define('--output_commons',
               help='Output file to store commons',
               default='',
               type=str,
               metavar='FILE')
  flags.parse()
  assert os.path.exists(flags.arg.flow), flags.arg.flow
 
  f = flow.Flow()
  f.load(flags.arg.flow)

  if flags.arg.training_details:
    details = f.blobs.get('training_details', None)
    if not details:
      print 'No training details in the flow file.'
    else:
      dictionary = pickle.loads(details.data)
      print 'Hyperparams:\n', dictionary['hyperparams'], '\n'
      print 'Number of examples seen:', dictionary['num_examples_seen']

      (final_loss, final_count) = dictionary['losses'][-1]['total']
      print 'Final loss', (final_loss / final_count)

      metrics = dictionary['checkpoint_metrics']