def load_documents(path):
  """Loads Documents from a GZIP-ed TFRecords file into a Python list."""
  gzip_option = tf.python_io.TFRecordOptions(
      tf.python_io.TFRecordCompressionType.GZIP)

  def get_bytes_feature(ex, name):
    return list(ex.features.feature[name].bytes_list.value)

  def get_ints_feature(ex, name):
    # 32-bit Numpy arrays are more memory-efficient than Python lists.
    return np.array(ex.features.feature[name].int64_list.value, dtype=np.int32)

  docs = []
  for val in tf.python_io.tf_record_iterator(path, gzip_option):
    ex = tf.train.Example.FromString(val)
    title = get_bytes_feature(ex, 'title')[0]
    body = get_bytes_feature(ex, 'body')[0]

    doc_uid = featurization.get_document_uid(title, body)
    title_token_ids = get_ints_feature(ex, 'title_token_ids')
    body_token_ids = get_ints_feature(ex, 'body_token_ids')

    doc = featurization.Document(
        uid=doc_uid,
        title_token_ids=title_token_ids,
        body_token_ids=body_token_ids)
    docs.append(doc)

  return docs
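# A minimal usage sketch for `load_documents`, assuming a GZIP-compressed
# TFRecords file of document Examples exists at a hypothetical path:
#
#   docs = load_documents('/path/to/documents.tfrecord.gz')
#   print('loaded %d documents' % len(docs))
#   print('first doc uid: %s' % docs[0].uid)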
def text_features_to_query(ex, featurizer):
  """Converts a dict of text features to a Query.

  Args:
    ex: a TF Example containing the features described below.
    featurizer: an instance of featurization.Featurizer

  Returns:
    a Query, or None if the example should be skipped.

  Each Example has the following features:
  - title: title of the document (just a bytes string).
  - text: raw text of the document (just a bytes string).
  - sentence_byte_start: byte offset for the start of each sentence (inclusive).
  - sentence_byte_limit: byte offset for the end of each sentence (exclusive).
  - span_byte_start: byte offset for the start of each salient span (inclusive).
  - span_byte_limit: byte offset for the end of each salient span (exclusive).
  """
  title = get_bytes_feature(ex, 'title')[0]
  body_text = get_bytes_feature(ex, 'text')[0]

  sentence_starts = get_ints_feature(ex, 'sentence_byte_start')
  sentence_limits = get_ints_feature(ex, 'sentence_byte_limit')
  span_starts = get_ints_feature(ex, 'span_byte_start')
  span_limits = get_ints_feature(ex, 'span_byte_limit')

  # List of (start, stop) byte offsets for each sentence (right-exclusive).
  sentence_boundaries = list(zip(sentence_starts, sentence_limits))

  # List of (start, stop) byte offsets for each salient span (right-exclusive).
  spans = list(zip(span_starts, span_limits))

  # Map spans to sentences. Spans that do not strictly fall within a single
  # sentence are omitted.
  span_to_sentence_boundaries = {}
  for span_start, span_stop in spans:
    for sent_start, sent_stop in sentence_boundaries:
      if span_start >= sent_start and span_stop <= sent_stop:
        span_to_sentence_boundaries[(span_start, span_stop)] = (sent_start,
                                                                sent_stop)
        break

  if not span_to_sentence_boundaries:
    # If there are no valid spans, skip this example.
    STATS['no_valid_spans'] += 1
    return None

  # Randomly sample a span.
  selected_span, selected_sentence_boundaries = random.choice(
      list(span_to_sentence_boundaries.items()))

  # Shift the span offsets to be relative to the sentence.
  selected_span = [
      offset - selected_sentence_boundaries[0] for offset in selected_span
  ]

  # Extract the sentence from the passage.
  sentence_text = body_text[
      selected_sentence_boundaries[0]:selected_sentence_boundaries[1]]

  try:
    sentence_tokens = featurizer.tokenizer.tokenize(sentence_text)
  except featurization.TokenizationError:
    # Tokenization errors can occur if we are unable to recover the byte
    # offset of a token in the original string. If so, skip this query.
    STATS['tokenization_error'] += 1
    return None

  doc_uid = featurization.get_document_uid(title, body_text)
  query = featurization.Query(
      text=sentence_text,
      tokens=sentence_tokens,
      mask_spans=[selected_span],
      orig_doc_uid=doc_uid)

  try:
    featurizer.mask_query(query)
  except featurization.MaskingError:
    # If the masks cannot be appropriately applied, skip this query.
    STATS['masking_error'] += 1
    return None

  return query
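# A minimal sketch of how `text_features_to_query` might be driven over the
# same kind of GZIP-ed TFRecords file read by `load_documents`. The `path`
# and `featurizer` values are assumptions, not defined by this module:
#
#   gzip_option = tf.python_io.TFRecordOptions(
#       tf.python_io.TFRecordCompressionType.GZIP)
#   queries = []
#   for val in tf.python_io.tf_record_iterator(path, gzip_option):
#     ex = tf.train.Example.FromString(val)
#     query = text_features_to_query(ex, featurizer)
#     if query is not None:  # None means the example was skipped (see STATS).
#       queries.append(query)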
def get_bytes_feature(ex, name):
  return list(ex.features.feature[name].bytes_list.value)


def get_ints_feature(ex, name):
  return list(ex.features.feature[name].int64_list.value)


NULL_DOCUMENT_UID = featurization.get_document_uid(b'', b'')
NULL_DOCUMENT = featurization.Document(NULL_DOCUMENT_UID, [], [])


def postprocess_candidates(candidates, query):
  """Perform additional processing on the candidates retrieved for a query.

  Args:
    candidates (list[Document]): a list of retrieved documents.
    query (Query): the query used to retrieve the documents.

  Returns:
    new_candidates (list[Document]): a list of the same size as candidates.
  """
  # If the query's originating document appears among candidates, remove it.
  new_candidates = [c for c in candidates if c.uid != query.orig_doc_uid]

  # Pad with null documents so the result has the same size as `candidates`.
  # NOTE: this padding step is inferred from the docstring contract above.
  while len(new_candidates) < len(candidates):
    new_candidates.append(NULL_DOCUMENT)
  return new_candidates
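# Usage sketch for `postprocess_candidates` (uids are hypothetical; real uids
# come from featurization.get_document_uid):
#
#   own_doc = featurization.Document('uid-query', [], [])
#   other_doc = featurization.Document('uid-other', [], [])
#   # Suppose query.orig_doc_uid == 'uid-query'; then
#   #   postprocess_candidates([own_doc, other_doc], query)
#   # drops own_doc and pads, returning [other_doc, NULL_DOCUMENT], so the
#   # candidate count stays fixed for downstream batching.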