Пример #1
0
def load_documents(path):
  """Reads every Document stored in a GZIP-compressed TFRecords file.

  Each record is parsed as a `tf.train.Example`. The first value of its
  'title' and 'body' bytes features is used to compute the document UID, and
  its 'title_token_ids' and 'body_token_ids' int64 features become the token
  id arrays of the resulting Document.

  Args:
    path: location of the GZIP-ed TFRecords file to read.

  Returns:
    A Python list of `featurization.Document`, in the order the records are
    read from the file.
  """

  def first_bytes_value(example, key):
    return list(example.features.feature[key].bytes_list.value)[0]

  def int32_values(example, key):
    # Store token ids as 32-bit NumPy arrays: they take less memory than
    # Python lists.
    raw_values = example.features.feature[key].int64_list.value
    return np.array(raw_values, dtype=np.int32)

  compression = tf.python_io.TFRecordCompressionType.GZIP
  record_options = tf.python_io.TFRecordOptions(compression)

  loaded = []
  for serialized in tf.python_io.tf_record_iterator(path, record_options):
    example = tf.train.Example.FromString(serialized)

    # The raw title/body bytes are only needed to derive the UID.
    raw_title = first_bytes_value(example, 'title')
    raw_body = first_bytes_value(example, 'body')
    uid = featurization.get_document_uid(raw_title, raw_body)

    title_ids = int32_values(example, 'title_token_ids')
    body_ids = int32_values(example, 'body_token_ids')

    loaded.append(
        featurization.Document(
            uid=uid,
            title_token_ids=title_ids,
            body_token_ids=body_ids))

  return loaded
Пример #2
0
  def retrieve(self, query_batch):
    """Returns placeholder retrievals for every query in the batch.

    The queries themselves are never embedded: an all-zero embedding matrix
    is scored against `self._doc_embeds`, and the resulting top-k ids are only
    used to decide how many documents each query receives. Every returned
    Document is built from the same constant arguments (0, a 10-element zero
    array, a 280-element zero array), and all of them share those two arrays.

    Args:
      query_batch: a sized collection of queries; only its length is used.

    Returns:
      A list with one entry per row of the top-k result, each entry being a
      list of `featurization.Document` placeholders.
    """
    batch_size = len(query_batch)
    # [batch_size, embed_dim]
    zero_query_embeds = tf.zeros((batch_size, self.embed_dim))

    with tf.device('/CPU:0'):
      # [batch_size, total_candidates]
      scores = tf.matmul(zero_query_embeds, self._doc_embeds, transpose_b=True)
      _, neighbor_ids_batch = tf.math.top_k(scores, k=self._num_neighbors)

    placeholder_title = np.zeros(10, dtype=np.int32)
    placeholder_body = np.zeros(280, dtype=np.int32)

    # One placeholder Document per retrieved id; the id value itself is unused.
    return [
        [
            featurization.Document(0, placeholder_title, placeholder_body)
            for _ in neighbor_ids
        ]
        for neighbor_ids in neighbor_ids_batch
    ]
Пример #3
0
        STATS['masking_error'] += 1
        return None

    return query


def get_bytes_feature(ex, name):
    """Return the bytes values stored under feature `name` of `ex` as a list.

    Args:
      ex: an object exposing `features.feature[name].bytes_list.value`
        (e.g. a parsed example proto).
      name: key of the feature to read.

    Returns:
      A new Python list holding the feature's bytes values.
    """
    feature = ex.features.feature[name]
    return [*feature.bytes_list.value]


def get_ints_feature(ex, name):
    """Return the int64 values stored under feature `name` of `ex` as a list.

    Args:
      ex: an object exposing `features.feature[name].int64_list.value`
        (e.g. a parsed example proto).
      name: key of the feature to read.

    Returns:
      A new Python list holding the feature's integer values.
    """
    feature = ex.features.feature[name]
    return [*feature.int64_list.value]


# UID that featurization.get_document_uid produces for two empty byte strings
# (presumably an empty title and an empty body -- confirm against
# featurization).
NULL_DOCUMENT_UID = featurization.get_document_uid(b'', b'')
# Placeholder Document carrying that UID and two empty lists (presumably the
# title and body token ids -- confirm against featurization.Document).
NULL_DOCUMENT = featurization.Document(NULL_DOCUMENT_UID, [], [])


def postprocess_candidates(candidates, query):
    """Perform additional processing on the candidates retrieved for a query.

  Args:
    candidates (list[Document]): a list of retrieved documents.
    query (Query): the query used to retrieve the documents.

  Returns:
    new_candidates (list[Document]): a list of the same size as candidates.
  """
    # If the query's originating document appears among candidates, remove it.
    candidates = [c for c in candidates if c.uid != query.orig_doc_uid]