Пример #1
0
    def _train_squad(self):
        """Runs BERT SQuAD training.

        Obtains the input metadata and the distribution strategy from the
        instance helpers, then passes both to `run_squad.train_squad` with
        the instance's timer callback as the only custom callback.
        """
        meta_data = self._read_input_meta_data_from_file()
        dist_strategy = self._get_distribution_strategy()
        callbacks = [self.timer_callback]

        run_squad.train_squad(
            strategy=dist_strategy,
            input_meta_data=meta_data,
            custom_callbacks=callbacks,
        )
    def _train_squad(self, use_ds=True, run_eagerly=False):
        """Runs BERT SQuAD training.

        Args:
            use_ds: Forwarded unchanged to `self._get_distribution_strategy`.
            run_eagerly: Forwarded unchanged to `run_squad.train_squad` as
                its `run_eagerly` keyword argument.
        """
        meta_data = self._read_input_meta_data_from_file()
        dist_strategy = self._get_distribution_strategy(use_ds)
        callbacks = [self.timer_callback]

        run_squad.train_squad(
            strategy=dist_strategy,
            input_meta_data=meta_data,
            run_eagerly=run_eagerly,
            custom_callbacks=callbacks,
        )
  def _run_bert_squad(self):
    """Starts BERT SQuAD task.

    Reads the UTF-8 JSON metadata file at `FLAGS.input_meta_data_path`,
    requests a 'mirrored' distribution strategy for `self.num_gpus` GPUs,
    and calls `run_squad.train_squad` with the instance's timer callback.
    """
    with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
      raw_bytes = meta_file.read()
    meta_data = json.loads(raw_bytes.decode('utf-8'))

    dist_strategy = distribution_utils.get_distribution_strategy(
        distribution_strategy='mirrored',
        num_gpus=self.num_gpus,
    )
    callbacks = [self.timer_callback]

    run_squad.train_squad(strategy=dist_strategy,
                          input_meta_data=meta_data,
                          custom_callbacks=callbacks)
Пример #4
0
    def _run_bert_squad(self):
        """Starts BERT SQuAD training and evaluation tasks.

        Reads the UTF-8 JSON metadata file at `FLAGS.input_meta_data_path`,
        requests a 'mirrored' distribution strategy for `self.num_gpus`
        GPUs, then calls `run_squad.train_squad` followed by
        `run_squad.predict_squad`. Finally evaluates
        `<FLAGS.model_dir>/predictions.json` (presumably produced by the
        prediction step -- confirm against `run_squad`) and stores the
        resulting 'f1' value on `self.eval_metrics`.
        """
        with tf.io.gfile.GFile(FLAGS.input_meta_data_path, 'rb') as meta_file:
            raw_bytes = meta_file.read()
        meta_data = json.loads(raw_bytes.decode('utf-8'))

        dist_strategy = distribution_utils.get_distribution_strategy(
            distribution_strategy='mirrored',
            num_gpus=self.num_gpus,
        )
        callbacks = [self.timer_callback]

        run_squad.train_squad(
            strategy=dist_strategy,
            input_meta_data=meta_data,
            custom_callbacks=callbacks,
        )
        run_squad.predict_squad(
            strategy=dist_strategy,
            input_meta_data=meta_data,
        )

        squad_metrics = self._evaluate_squad(
            os.path.join(FLAGS.model_dir, 'predictions.json'))
        # Use F1 score as reported evaluation metric.
        self.eval_metrics = squad_metrics['f1']