Example #1
  def _evaluate_squad(self):
    """Runs BERT SQuAD evaluation."""
    input_meta_data = self._read_input_meta_data_from_file()
    strategy = self._get_distribution_strategy()

    run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)

    # Load the SQuAD dev data and the predictions written by predict_squad above.
    dataset = self._read_predictions_dataset_from_file()
    predictions = self._read_predictions_from_file()

    eval_metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
    # Use F1 score as reported evaluation metric.
    self.eval_metrics = eval_metrics['f1']
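
The snippet relies on read helpers that are not shown. Below is a minimal sketch of what they might do, written as standalone functions with explicit path arguments; the function names, the `predict_file` argument, and the `predictions.json` filename are assumptions rather than part of the original benchmark code.

import json
import os

import tensorflow as tf


def read_squad_dev_dataset(predict_file):
  """Loads the SQuAD v1.1 dev set JSON and returns its 'data' entries."""
  with tf.io.gfile.GFile(predict_file, 'r') as reader:
    return json.load(reader)['data']


def read_squad_predictions(model_dir):
  """Loads the predictions file written under the model directory."""
  predictions_file = os.path.join(model_dir, 'predictions.json')
  with tf.io.gfile.GFile(predictions_file, 'r') as reader:
    return json.load(reader)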
Example #2
  def _evaluate_squad(self, use_ds=True):
    """Runs BERT SQuAD evaluation."""
    assert tf.version.VERSION.startswith('2.')
    input_meta_data = self._read_input_meta_data_from_file()
    strategy = self._get_distribution_strategy(use_ds)

    run_squad.predict_squad(strategy=strategy, input_meta_data=input_meta_data)

    dataset = self._read_predictions_dataset_from_file()
    predictions = self._read_predictions_from_file()

    eval_metrics = squad_evaluate_v1_1.evaluate(dataset, predictions)
    # Use F1 score as reported evaluation metric.
    self.eval_metrics = eval_metrics['f1']
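
Compared with Example #1, this variant adds a TensorFlow 2 version check and a `use_ds` switch that controls whether a distribution strategy is used. A minimal sketch of what `_get_distribution_strategy(use_ds)` could return, assuming a MirroredStrategy when enabled and a single-device fallback otherwise (the fallback device is an assumption):

import tensorflow as tf


def get_distribution_strategy(use_ds=True):
  """Returns a MirroredStrategy when use_ds is set, else a one-device strategy."""
  if use_ds:
    # Replicate the model across all visible GPUs on this host.
    return tf.distribute.MirroredStrategy()
  # Run everything on a single device; GPU:0 is an assumed default.
  return tf.distribute.OneDeviceStrategy('device:GPU:0')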
Example #3
    def _run_bert_squad(self):
        """Starts BERT SQuAD training and evaluation tasks."""
        with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as reader:
            input_meta_data = json.loads(reader.read().decode('utf-8'))

        strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy='mirrored', num_gpus=self.num_gpus)

        run_squad.train_squad(strategy=strategy,
                              input_meta_data=input_meta_data,
                              custom_callbacks=[self.timer_callback])
        run_squad.predict_squad(strategy=strategy,
                                input_meta_data=input_meta_data)
        # predict_squad writes its output to predictions.json in the model dir.
        predictions_file = os.path.join(FLAGS.model_dir, 'predictions.json')
        eval_metrics = self._evaluate_squad(predictions_file)
        # Use F1 score as reported evaluation metric.
        self.eval_metrics = eval_metrics['f1']
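
Here, unlike Examples #1 and #2, scoring is delegated to `self._evaluate_squad(predictions_file)`, which is expected to return the metrics dict. A plausible sketch of such a helper, assuming it loads the dev set and the predictions file and reuses `squad_evaluate_v1_1.evaluate` as in the earlier snippets (the `predict_file` argument and the import path are assumptions):

import json

import tensorflow as tf

# Import path as used in the TF Model Garden; it may differ between versions.
from official.nlp.bert import squad_evaluate_v1_1


def evaluate_squad(predict_file, predictions_file):
    """Scores a predictions file against the SQuAD v1.1 dev set."""
    with tf.io.gfile.GFile(predict_file, 'r') as reader:
        dataset = json.load(reader)['data']
    with tf.io.gfile.GFile(predictions_file, 'r') as reader:
        predictions = json.load(reader)
    # Returns a dict that includes the 'f1' key used above.
    return squad_evaluate_v1_1.evaluate(dataset, predictions)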