Example #1
  def __init__(self,
               *,
               validation_batched,
               fever_dev_path,
               max_evidence,
               max_recall_n=15,
               checkpoint_dir,
               debug=False):
    """Compute the fever metrics from the batched dataset.

    Args:
      validation_batched: Batched dataset to compute metrics from.
      fever_dev_path: Path to the FEVER dev data.
      max_evidence: Max evidence to use.
      max_recall_n: Stop computing recall after this position.
      checkpoint_dir: If not None, write validation predictions to disk here
        every epoch.
      debug: Whether to enable debug handling of metrics, since some fail or
        error without full data.
    """
    super().__init__()
    self._validation_batched = validation_batched
    self._fever_dev_path = fever_dev_path
    self._max_recall_n = max_recall_n
    self._max_evidence = max_evidence
    self._checkpoint_dir = checkpoint_dir
    self._debug = debug
    self._validation_flat = list(self._validation_batched.unbatch())
    self._dev = util.read_jsonlines(fever_dev_path)
    self._verifiable_dev = [
        claim for claim in self._dev
        if claim['label'] != constants.NOT_ENOUGH_INFO]
    self._verifiable_dev_lookup = {
        claim['id']: claim for claim in self._verifiable_dev}
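
For orientation, a minimal sketch of how this constructor might be invoked. The class name FeverMetricsCallback, the dataset variable, and the path are illustrative assumptions; only the keyword arguments come from the snippet above.

# Hypothetical usage sketch; assumes the __init__ above belongs to a
# Keras-style callback named FeverMetricsCallback and that `validation`
# is an already-batched tf.data.Dataset.
callback = FeverMetricsCallback(
    validation_batched=validation,
    fever_dev_path='fever/dev.jsonl',  # assumed path
    max_evidence=5,                    # assumed value
    checkpoint_dir=None,               # skip writing predictions each epoch
)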
Example #2
  def setup(self):
    self._claim_to_fold = {}
    train = util.read_jsonlines(self._fever_train_path)
    for claim in train:
      self._claim_to_fold[claim['id']] = 'train'
    dev = util.read_jsonlines(self._fever_dev_path)
    for claim in dev:
      self._claim_to_fold[claim['id']] = 'dev'

    test = util.read_jsonlines(self._fever_test_path)
    for claim in test:
      self._claim_to_fold[claim['id']] = 'test'

    self._wiki_db = wiki_db.WikiDatabase.from_local(self._wiki_db_path)
    self._wiki_titles = set(self._wiki_db.get_wikipedia_urls())
    self._matcher = text_matcher.TextMatcher()
    self._matcher.load(self._text_matcher_params_path)

    self._drqa_scrape_table = scrape_db.ScrapeDatabase.from_local(
        self._drqa_db_path)
    # Assumption: the Lucene scrapes have their own DB path
    # (self._lucene_db_path), parallel to the DRQA one.
    self._lucene_scrape_table = scrape_db.ScrapeDatabase.from_local(
        self._lucene_db_path)
    self._name_to_scrape = {
        constants.DRQA: self._drqa_scrape_table,
        constants.LUCENE: self._lucene_scrape_table
    }

    ukp_claim_docs = (
        util.read_jsonlines(self._ukp_docs_train_path) +
        util.read_jsonlines(self._ukp_docs_dev_path) +
        util.read_jsonlines(self._ukp_docs_test_path))
    self._ukp_docs = {claim['id']: claim for claim in ukp_claim_docs}
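
After setup() runs, fold membership and the scrape tables can be looked up directly from another method of the same class. A small sketch; the claim id and these particular lookups are illustrative assumptions:

# Hypothetical usage sketch built on the lookups created in setup();
# the claim id is invented.
fold = self._claim_to_fold[137334]                # 'train', 'dev', or 'test'
scrapes = self._name_to_scrape[constants.LUCENE]  # dispatch by retrieval system
ukp_doc = self._ukp_docs.get(137334)              # UKP docs keyed by claim id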
Example #3
def main(_):
    tf.enable_v2_behavior()
    flags.mark_flag_as_required('model_root')
    flags.mark_flag_as_required('report_dir')

    root = pathlib.Path(FLAGS.model_root)
    report_dir = pathlib.Path(FLAGS.report_dir)
    logging.info('Reading predictions from model_root: %s', root)
    logging.info('Will write analysis to: %s', report_dir)

    # Config() contains non-model-specific configuration, which is why it's
    # fine to use this instead of the model's configuration.
    conf = config.Config()
    dev = {c['id']: c for c in util.read_jsonlines(conf.fever_dev)}
    logging.info('Reading fever TFDS examples')
    builder = fever_tfds.FeverEvidence(
        data_dir=util.readahead(conf.fever_evidence_tfds_data),
        n_similar_negatives=FLAGS.n_similar_negatives,
        n_background_negatives=FLAGS.n_background_negatives,
        train_scrape_type=FLAGS.train_scrape_type,
        include_not_enough_info=True,
        title_in_scoring=True)
    val = builder.as_dataset(split='validation')
    val_tfds_examples = list(tqdm.tqdm(val, mininterval=10))

    logging.info('Reading model predictions')
    model_predictions = read_model_predictions(root / 'val_predictions.json')
    val_df = parse_fold(fold_name='val',
                        model_predictions=model_predictions,
                        tfds_examples=val_tfds_examples)
    # Only the validation fold is analyzed, so concat is a no-op here.
    df = pd.concat([val_df])

    logging.info('Writing analysis to disk')
    write_summary(report_dir, df)
    write_per_claim_analysis(
        output_path=report_dir / 'claim_evidence_predictions.pickle',
        df=df,
        claim_lookup=dev)
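
The script presumably ends with the standard absl entry point, which is what gives main its unused argument; a sketch (not shown in the snippet):

# Standard absl-py entry point, assumed from the main(_) signature and the
# FLAGS usage above.
from absl import app

if __name__ == '__main__':
  app.run(main)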
Example #4
  def _generate_examples(self, filepath, **kwargs):
    fever_claims = util.read_jsonlines(filepath)
    for claim in fever_claims:
      claim_id = claim["id"]
      claim_text = claim["claim"]
      claim_label = claim["label"]
      example_id = f"{claim_id}"
      yield claim_id, {
          "example_id": example_id,
          "claim_text": claim_text,
          "evidence_text": "",
          "wikipedia_url": "",
          # Ordinarily, this would (possibly) be concatenated to the evidence,
          # but since this is claim-only, use a null integer value.
          "sentence_id": "-1",
          # This label doesn't matter here since it's claim-only.
          "evidence_label": constants.NOT_MATCHING,
          "claim_label": claim_label,
          "scrape_type": "",
          "metadata": json.dumps({
              "claim_id": claim_id,
          })
      }
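
The generator reads only three fields from each FEVER claim. An illustrative input record, with invented values:

# Illustrative FEVER claim record (values invented) showing the three fields
# the generator reads; each JSONL line parses to a dict like this.
claim = {
    "id": 137334,
    "claim": "Fox 2000 Pictures released the film Soul Food.",
    "label": "SUPPORTS",
}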
Example #5
  def _generate_examples(self, boolq_filepath, fold):
    boolq_claims = util.read_jsonlines(boolq_filepath)
    for idx, claim in enumerate(boolq_claims):
      example_id = f'{fold}-{idx}'
      example = {
          'example_id': example_id,
          'claim_text': claim['question'],
          'evidence_text': claim['passage'],
          'wikipedia_url': claim['title'],
          'sentence_id': '0',
          # This is effectively gold evidence.
          'evidence_label': constants.MATCHING,
          'claim_label':
              constants.SUPPORTS if claim['answer'] else constants.REFUTES,
          'metadata': json.dumps({})
      }
      yield example_id, example
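
Here BoolQ fields are remapped onto the same example schema: the question becomes the claim, the passage becomes gold evidence, and the yes/no answer becomes SUPPORTS/REFUTES. An illustrative input record, with invented values:

# Illustrative BoolQ record (values invented) with the four fields the
# generator reads; each JSONL line parses to a dict like this.
claim = {
    "question": "is the sky blue during the day",
    "title": "Diffuse sky radiation",
    "passage": "During daylight, the sky appears blue because ...",
    "answer": True,
}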
Example #6
def main(_):
    flags.mark_flag_as_required('out_path')
    flags.mark_flag_as_required('wiki_embedding_dir')
    flags.mark_flag_as_required('claim_id_path')
    flags.mark_flag_as_required('claim_embedding_path')
    flags.mark_flag_as_required('n_shards')

    tf.enable_v2_behavior()

    conf = config.Config()
    logging.info('wiki_embedding_dir: %s', FLAGS.wiki_embedding_dir)
    logging.info('n_shards: %s', FLAGS.n_shards)
    logging.info('l2_norm: %s', FLAGS.l2_norm)
    logging.info('claim_id_path: %s', FLAGS.claim_id_path)
    logging.info('claim_embedding_path: %s', FLAGS.claim_embedding_path)
    logging.info('copy_to_tmp: %s', FLAGS.copy_to_tmp)
    logging.info('batch_size: %s', FLAGS.batch_size)

    with util.log_time('Building index'):
        index = Index(
            wiki_embedding_dir=FLAGS.wiki_embedding_dir,
            n_shards=FLAGS.n_shards,
            l2_norm=FLAGS.l2_norm,
            claim_id_path=FLAGS.claim_id_path,
            claim_embedding_path=FLAGS.claim_embedding_path,
            copy_to_tmp=FLAGS.copy_to_tmp,
            batch_size=FLAGS.batch_size,
            device=FLAGS.device,
        )
        index.build()

    logging.info('Reading claims from: %s', conf.fever_dev)
    dev = [
        c for c in util.read_jsonlines(conf.fever_dev)
        if c['label'] != constants.NOT_ENOUGH_INFO
    ]

    logging.info('Making predictions')
    claim_id_to_scored_keys = index.score_claim_to_wiki(n=5)

    formatted_predictions = []
    actual = []
    for claim in tqdm.tqdm(dev):
        claim_id = claim['id']
        predicted_evidence = []
        scored_keys = claim_id_to_scored_keys[claim_id]
        for index_key in scored_keys['wiki_keys']:
            # sentence_id is a numpy int, and the FEVER scoring script only
            # accepts Python int.
            predicted_evidence.append(
                [index_key.wikipedia_url,
                 int(index_key.sentence_id)])

        formatted_predictions.append({
            'id': claim_id,
            # Hardcoded label: this script only evaluates evidence retrieval.
            'predicted_label': constants.SUPPORTS,
            'predicted_evidence': predicted_evidence,
        })
        actual.append({'evidence': claim['evidence'], 'label': claim['label']})

    logging.info('FEVER Metrics')
    strict_score, accuracy_score, precision, recall, f1 = fever_score(
        formatted_predictions, actual)
    logging.info('Strict Score: %s', strict_score)
    logging.info('Accuracy Score: %s', accuracy_score)
    logging.info('Precision: %s', precision)
    logging.info('Recall: %s', recall)
    logging.info('F1: %s', f1)

    logging.info('Saving predictions and metrics to: %s', FLAGS.out_path)
    util.write_json(
        {
            'predictions': formatted_predictions,
            'metrics': {
                'strict_score': strict_score,
                'accuracy_score': accuracy_score,
                'precision': precision,
                'recall': recall,
                'f1': f1,
            }
        }, FLAGS.out_path)
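
The resulting JSON bundles the formatted predictions with the five FEVER metrics. A trimmed sketch of the output shape; all values are invented:

# Illustrative shape of the JSON written to FLAGS.out_path (values invented).
output = {
    'predictions': [{
        'id': 137334,
        'predicted_label': 'SUPPORTS',  # hardcoded by the script above
        'predicted_evidence': [['Soul_Food_(film)', 0]],
    }],
    'metrics': {
        'strict_score': 0.41,
        'accuracy_score': 0.55,
        'precision': 0.38,
        'recall': 0.67,
        'f1': 0.49,
    },
}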