Example #1
import numpy as np
import tensorflow as tf  # Uses the TF 1.x tf.python_io API.

import featurization  # Project-local module providing Document and get_document_uid.


def load_documents(path):
  """Loads Documents from a GZIP-compressed TFRecord file into a Python list."""
  gzip_option = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)

  def get_bytes_feature(ex, name):
    return list(ex.features.feature[name].bytes_list.value)

  def get_ints_feature(ex, name):
    # 32-bit Numpy arrays are more memory-efficient than Python lists.
    return np.array(ex.features.feature[name].int64_list.value, dtype=np.int32)

  docs = []
  for val in tf.python_io.tf_record_iterator(path, gzip_option):
    ex = tf.train.Example.FromString(val)
    title = get_bytes_feature(ex, 'title')[0]
    body = get_bytes_feature(ex, 'body')[0]

    doc_uid = featurization.get_document_uid(title, body)
    title_token_ids = get_ints_feature(ex, 'title_token_ids')
    body_token_ids = get_ints_feature(ex, 'body_token_ids')

    doc = featurization.Document(
        uid=doc_uid,
        title_token_ids=title_token_ids,
        body_token_ids=body_token_ids)
    docs.append(doc)

  return docs
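
For context, here is a minimal round-trip sketch of the GZIP-compressed TFRecord format that load_documents expects. The writer below is hypothetical (the pipeline that actually produces these files is not shown here); it only uses the same TF 1.x API and the same feature names that load_documents reads.

def write_documents(path):
  # Illustrative only: serializes one toy document with the features that
  # load_documents reads ('title', 'body', '*_token_ids').
  gzip_option = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)
  with tf.python_io.TFRecordWriter(path, gzip_option) as writer:
    ex = tf.train.Example()
    ex.features.feature['title'].bytes_list.value.append(b'A Title')
    ex.features.feature['body'].bytes_list.value.append(b'Some body text.')
    ex.features.feature['title_token_ids'].int64_list.value.extend([1, 2])
    ex.features.feature['body_token_ids'].int64_list.value.extend([3, 4, 5])
    writer.write(ex.SerializeToString())

write_documents('/tmp/docs.tfrecord.gz')
docs = load_documents('/tmp/docs.tfrecord.gz')
assert len(docs) == 1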
Example #2
import collections
import random

import featurization  # Project-local module; provides Query, Featurizer, and error types.

# Tracks why examples get skipped (assumed to be a module-level counter).
STATS = collections.Counter()

# get_bytes_feature and get_ints_feature are the helpers defined in Example #3.


def text_features_to_query(ex, featurizer):
    """Converts a dict of text features to a Query.

    Each Example has the following features:
    - title: title of the document (just a bytes string).
    - text: raw text of the document (just a bytes string).
    - sentence_byte_start: byte offset for the start of each sentence (inclusive).
    - sentence_byte_limit: byte offset for the end of each sentence (exclusive).
    - span_byte_start: byte offset for the start of each salient span (inclusive).
    - span_byte_limit: byte offset for the end of each salient span (exclusive).

    Args:
      ex: a TF Example containing the features described above.
      featurizer: an instance of featurization.Featurizer.

    Returns:
      a Query, or None if the example should be skipped.
    """
    title = get_bytes_feature(ex, 'title')[0]
    body_text = get_bytes_feature(ex, 'text')[0]
    sentence_starts = get_ints_feature(ex, 'sentence_byte_start')
    sentence_limits = get_ints_feature(ex, 'sentence_byte_limit')
    span_starts = get_ints_feature(ex, 'span_byte_start')
    span_limits = get_ints_feature(ex, 'span_byte_limit')

    # List of (start, stop) byte offsets for each sentence (right-exclusive).
    sentence_boundaries = list(zip(sentence_starts, sentence_limits))

    # List of (start, stop) byte offsets for each salient span (right-exclusive).
    spans = list(zip(span_starts, span_limits))

    # Map spans to sentences.
    # Spans that do not strictly fall within a single sentence are omitted.
    span_to_sentence_boundaries = {}
    for span_start, span_stop in spans:
        for sent_start, sent_stop in sentence_boundaries:
            if span_start >= sent_start and span_stop <= sent_stop:
                span_to_sentence_boundaries[(span_start,
                                             span_stop)] = (sent_start,
                                                            sent_stop)
                break

    if not span_to_sentence_boundaries:
        # If there are no valid spans, skip this example.
        STATS['no_valid_spans'] += 1
        return None

    # Randomly sample a span.
    selected_span, selected_sentence_boundaries = random.choice(
        list(span_to_sentence_boundaries.items()))

    # Shift the span offsets to be relative to the sentence.
    selected_span = [
        offset - selected_sentence_boundaries[0] for offset in selected_span
    ]

    # Extract the sentence from the passage.
    sentence_text = body_text[
        selected_sentence_boundaries[0]:selected_sentence_boundaries[1]]

    try:
        sentence_tokens = featurizer.tokenizer.tokenize(sentence_text)
    except featurization.TokenizationError:
        # Tokenization errors can occur if we are unable to recover the byte offset
        # of a token in the original string. If so, skip this query.
        STATS['tokenization_error'] += 1
        return None

    doc_uid = featurization.get_document_uid(title, body_text)

    query = featurization.Query(text=sentence_text,
                                tokens=sentence_tokens,
                                mask_spans=[selected_span],
                                orig_doc_uid=doc_uid)

    try:
        featurizer.mask_query(query)
    except featurization.MaskingError:
        # If the masks cannot be appropriately applied, skip this query.
        STATS['masking_error'] += 1
        return None

    return query
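
A possible driver loop combining the TF 1.x record iterator from Example #1 with text_features_to_query. This is hypothetical (the file's actual entry point is not shown) and reuses the tensorflow import from Example #1; it simply drops examples that return None.

def load_queries(path, featurizer):
    gzip_option = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.GZIP)
    queries = []
    for val in tf.python_io.tf_record_iterator(path, gzip_option):
        ex = tf.train.Example.FromString(val)
        query = text_features_to_query(ex, featurizer)
        if query is not None:  # None means the example was skipped; see STATS.
            queries.append(query)
    return queries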
Example #3
import featurization  # Project-local module; provides Document and get_document_uid.


def get_bytes_feature(ex, name):
    """Returns the values of a bytes_list feature as a list of bytes strings."""
    return list(ex.features.feature[name].bytes_list.value)


def get_ints_feature(ex, name):
    """Returns the values of an int64_list feature as a list of Python ints."""
    return list(ex.features.feature[name].int64_list.value)


# Null sentinel document (empty title and body), used to pad candidate lists.
NULL_DOCUMENT_UID = featurization.get_document_uid(b'', b'')
NULL_DOCUMENT = featurization.Document(NULL_DOCUMENT_UID, [], [])


def postprocess_candidates(candidates, query):
    """Perform additional processing on the candidates retrieved for a query.

    Args:
      candidates (list[Document]): a list of retrieved documents.
      query (Query): the query used to retrieve the documents.

    Returns:
      new_candidates (list[Document]): a list of the same size as candidates.
    """
    # If the query's originating document appears among the candidates, drop it,
    # then pad with null documents so the result keeps its original size.
    # (The snippet is truncated at the filtering line; the padding step below is
    # inferred from the docstring and the NULL_DOCUMENT definition above.)
    new_candidates = [c for c in candidates if c.uid != query.orig_doc_uid]
    new_candidates.extend([NULL_DOCUMENT] * (len(candidates) - len(new_candidates)))
    return new_candidates
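
A toy check (hypothetical data; constructor arguments mirror how Document and Query are built in Examples #1 and #2) of the size-preserving contract: the query's originating document is replaced by NULL_DOCUMENT instead of simply shrinking the list.

doc_a = featurization.Document('uid-a', [], [])
doc_b = featurization.Document('uid-b', [], [])
query = featurization.Query(text=b'', tokens=[], mask_spans=[],
                            orig_doc_uid='uid-a')

cleaned = postprocess_candidates([doc_a, doc_b], query)
assert len(cleaned) == 2  # Same size as the input list.
assert all(c.uid != query.orig_doc_uid for c in cleaned)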