Example #1
def has_answer(answer, doc_id, match):
    """Check if a document contains an answer string.

    If `match` is string, token matching is done between the text and answer.
    If `match` is regex, we search the whole text with the regex.
    If `match` is find, substring matching is done on the whitespace-stripped text and answer.
    """
    global PROCESS_DB, PROCESS_TOK
    text = PROCESS_DB.get_doc_text(doc_id)
    text = utils.normalize(text)
    if match == 'string':
        # Answer is a list of possible strings
        text = PROCESS_TOK.tokenize(text).words(uncased=True)
        for single_answer in answer:
            single_answer = utils.normalize(single_answer)
            single_answer = PROCESS_TOK.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)
            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i:i + len(single_answer)]:
                    return True
    elif match == 'regex':
        # Answer is a regex
        single_answer = utils.normalize(answer[0])
        if regex_match(text, single_answer):
            return True
    elif match == 'find':
        single_answer = utils.normalize(re.sub(r'\s+', '', answer[0]))
        if re.sub(r'\s+', '', text).find(single_answer) != -1:
            return True
    return False
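has_answer relies on the module-level globals PROCESS_DB and PROCESS_TOK. Below is a minimal sketch of how they might be initialized once per worker process before calling has_answer in parallel; the init function and the DocDB / SimpleTokenizer names are assumptions for illustration, not taken from the snippet above.

from multiprocessing import Pool
from functools import partial

PROCESS_DB = None
PROCESS_TOK = None

def init(db_class, db_opts, tok_class, tok_opts):
    """Pool initializer: build one DB handle and one tokenizer per worker."""
    global PROCESS_DB, PROCESS_TOK
    PROCESS_DB = db_class(**db_opts)
    PROCESS_TOK = tok_class(**tok_opts)

# Hypothetical usage: check many (answer, doc_id) pairs in parallel.
# workers = Pool(processes=8, initializer=init,
#                initargs=(DocDB, {'db_path': 'docs.db'}, SimpleTokenizer, {}))
# flags = workers.starmap(partial(has_answer, match='string'), answer_doc_pairs)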
Example #2
    def text2spvec(self, query):
        """Create a sparse tfidf-weighted word vector from query.

        tfidf = log(tf + 1) * log((N - Nt + 0.5) / (Nt + 0.5))
        """
        # Get hashed ngrams
        words = self.parse(utils.normalize(query))
        wids = [utils.hash(w, self.hash_size) for w in words]

        if len(wids) == 0:
            if self.strict:
                raise RuntimeError('No valid word in: %s' % query)
            else:
                logger.warning('No valid word in: %s' % query)
                return sp.csr_matrix((1, self.hash_size))

        # Count TF
        wids_unique, wids_counts = np.unique(wids, return_counts=True)
        tfs = np.log1p(wids_counts)

        # Count IDF
        Ns = self.doc_freqs[wids_unique]
        idfs = np.log((self.num_docs - Ns + 0.5) / (Ns + 0.5))
        idfs[idfs < 0] = 0

        # TF-IDF
        data = np.multiply(tfs, idfs)

        # One row, sparse csr matrix
        indptr = np.array([0, len(wids_unique)])
        spvec = sp.csr_matrix(
            (data, wids_unique, indptr), shape=(1, self.hash_size)
        )

        return spvec
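A hedged sketch of how the sparse query vector returned by text2spvec could be used to rank documents, assuming the ranker also holds a term-document tfidf matrix self.doc_mat of shape (hash_size, num_docs); the closest_docs name and the doc_mat attribute are assumptions, not shown in the excerpt above.

    def closest_docs(self, query, k=5):
        """Rank documents by the dot product of query and document tfidf vectors."""
        spvec = self.text2spvec(query)   # shape (1, hash_size)
        res = spvec * self.doc_mat       # shape (1, num_docs), sparse row of scores

        # Take the top-k nonzero scores without sorting the full row.
        if len(res.data) <= k:
            o_sort = np.argsort(-res.data)
        else:
            o = np.argpartition(-res.data, k)[0:k]
            o_sort = o[np.argsort(-res.data[o])]

        doc_scores = res.data[o_sort]
        doc_indices = res.indices[o_sort]  # column indices; map to doc ids as needed
        return doc_indices, doc_scores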
Example #3
def get_contents(filename):
    """Parse the contents of a file. Each line is a JSON encoded document."""
    global PREPROCESS_FN
    documents = []
    with open(filename) as f:
        for line in f:
            # Parse document
            doc = json.loads(line)
            # Maybe preprocess the document with custom function
            if PREPROCESS_FN:
                doc = PREPROCESS_FN(doc)
            # Skip if it is empty or None
            if not doc:
                continue
            # Add the document
            documents.append((utils.normalize(doc['id']),
                              utils.normalize(doc['text'])))
    return documents
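A minimal sketch of how the (id, text) pairs from get_contents could be written into the SQLite documents table that get_doc_text in the next example queries; the store_contents name and the two-column schema are assumptions.

import sqlite3

def store_contents(filename, save_path):
    """Parse a JSON-lines file and store (id, text) rows in a SQLite table."""
    documents = get_contents(filename)
    conn = sqlite3.connect(save_path)
    c = conn.cursor()
    c.execute("CREATE TABLE documents (id PRIMARY KEY, text);")
    c.executemany("INSERT INTO documents VALUES (?, ?)", documents)
    conn.commit()
    conn.close()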
Example #4
    def get_doc_text(self, doc_id):
        """Fetch the raw text of the doc for 'doc_id'."""
        cursor = self.connection.cursor()
        cursor.execute(
            "SELECT text FROM documents WHERE id = ?",
            (utils.normalize(doc_id),)
        )
        result = cursor.fetchone()
        cursor.close()
        return result if result is None else result[0]
Example #5
    def get_title_scores_by_sim(self, query):
        """Compute all title scores based on similarity between title tokens and query tokens,
        stop words filtered.
           sim = 2 * len(common words) / (len(title words) + len(query_words))
        """
        # get unique word hash ids
        words = self.parse(utils.normalize(query))
        wids = [utils.hash(w, self.hash_size) for w in words]

        if len(wids) == 0:
            if self.strict:
                raise RuntimeError('No valid word in: %s' % query)
            else:
                logger.warning('No valid word in: %s' % query)
                return sp.csr_matrix((1, self.num_docs))

        wids_unique, wids_counts = np.unique(wids, return_counts=True)

        # get query sparse vector
        query_spvec = sp.csr_matrix(
            ([1] * len(wids_unique), wids_unique, [0, len(wids_unique)]), shape=(1, self.hash_size)
        )

        # Lazily build the titles' length vector and the title matrix used for similarity
        if self.title_csc_matrix is None:
            self.get_title_csc_matrix()
        if self.titles_lens is None:
            self.get_titles_lens()

        # Binarize title term weights so the dot product counts shared words
        self.title_tfidf.data = np.ones_like(self.title_tfidf.data)

        titles_scores = query_spvec * self.title_tfidf

        query_len_spvec = sp.csr_matrix(
            ([len(wids_unique)] * self.num_docs,
             list(range(self.num_docs)),
             [0, self.num_docs]), shape=(1, self.num_docs)
        )

        denominator = self.titles_lens + query_len_spvec
        titles_scores = 2 * titles_scores / denominator

        titles_scores = sp.csr_matrix(titles_scores)

        return titles_scores
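For intuition, the sim formula in the docstring is a Dice coefficient over token sets. A toy, purely illustrative example:

title_words = {'roman', 'empire', 'history'}   # 3 title tokens after filtering
query_words = {'roman', 'history', 'decline'}  # 3 query tokens
common = title_words & query_words             # {'roman', 'history'}, 2 tokens
sim = 2 * len(common) / (len(title_words) + len(query_words))  # 4 / 6 ≈ 0.67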
Example #6
args = parser.parse_args()
t0 = time.time()

args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
    torch.cuda.set_device(args.gpu)
    logger.info('CUDA enabled (GPU %d)' % args.gpu)
else:
    logger.info('Running on CPU only.')

if args.candidate_file:
    logger.info('Loading candidates from %s' % args.candidate_file)
    candidates = set()
    with open(args.candidate_file) as f:
        for line in f:
            line = utils.normalize(line.strip()).lower()
            candidates.add(line)
    logger.info('Loaded %d candidates.' % len(candidates))
else:
    candidates = None

logger.info('Initializing pipeline...')
DrQA = pipeline.DrQA(
    reader_model=args.reader_model,
    fixed_candidates=candidates,
    embedding_file=args.embedding_file,
    tokenizer=args.tokenizer,
    batch_size=args.batch_size,
    cuda=args.cuda,
    data_parallel=args.parallel,
    ranker_config={'options': {'tfidf_path': args.retriever_model,