def train():
    """Run the offline training pipeline and print the validation NDCG.

    Steps: preprocess the raw data, train (or load) the MF model, build the
    item-to-item model, generate train/validation candidates, cache those
    candidates to ``puke.pkl`` (best effort), then train, save and validate
    the ranker.
    """
    import pickle

    print('Preprocessing raw data')
    preprocessor = Preprocessor()
    preprocessor.preprocess()

    dataset = Dataset(preprocessor)

    print('Training MF')
    mf = MF(preprocessor, dataset)
    mf.train_or_load_if_exists()

    print('Building I2I')
    i2i = Item2Item(dataset)

    print('Generating candidates')
    candidate_generator = CandidateGenerator(preprocessor, dataset, mf, i2i)
    X_train, y_train, q_train, q_train_reader = candidate_generator.generate_train()
    X_val, y_val, q_val, q_val_reader = candidate_generator.generate_val()

    # Cache the generated candidates so later experiments can skip candidate
    # generation. This is best effort: a failure is reported but does not stop
    # training. `except Exception` (not a bare except) lets KeyboardInterrupt
    # and SystemExit propagate.
    try:
        with open('puke.pkl', 'wb') as f:
            pickle.dump((X_train, y_train, q_train, q_train_reader,
                         X_val, y_val, q_val, q_val_reader), f)
    except Exception as e:
        print("Couldn't save puke: %s" % (e,))

    print('Training ranker')
    ranker = Ranker()
    ranker.train(X_train, y_train, q_train, X_val, y_val, q_val)
    ranker.save()

    print('Validating ranker')
    rank_scores = ranker.rank(X_val)
    print('ndcg', dataset.validate_ndcg(y_val, q_val, q_val_reader, rank_scores))
예제 #2
0
  def add_all_matching(self, hits, query, flt_tuple, max_hits):
    """
    Add every basename in this index that matches a filter regex to `hits`.

    hits is the dictionary to put results in (basename -> rank); an entry
      that is already present keeps the larger of its old and new rank
    query is the query string originally entered by user, used by ranking
    flt_tuple is [filter_regex, case_sensitive_bool]
    max_hits is largest hits should grow before matching terminates.

    Raises Exception if a match contains a newline, i.e. the filter regex
    matched across more than one basename.
    """
    flt, case_sensitive = flt_tuple

    regex = re.compile(flt)
    base = 0
    ranker = Ranker()
    # Case-insensitive filters are run against the lowercased index.
    if not case_sensitive:
      index = self.lower_basenames_unsplit
    else:
      index = self.basenames_unsplit
    while True:
      m = regex.search(index, base)
      if not m:
        break
      # Drop the first and last matched characters (presumably the '\n'
      # delimiters surrounding one basename).
      hit = m.group(0)[1:-1]
      if '\n' in hit:
        raise Exception("Something is messed up with flt=[%s] query=[%s] hit=[%s]" % (flt, query, hit))
      rank = ranker.rank(query, hit)
      # Normalize keys to lowercase; hits from the case-insensitive path are
      # presumably lowercase already since they come from the lowered index.
      if case_sensitive:
        hit = hit.lower()
      if hit in hits:
        hits[hit] = max(hits[hit], rank)
      else:
        hits[hit] = rank
      # Resume at the last matched character so it can also open the next
      # entry's match.
      base = m.end() - 1
      if len(hits) >= max_hits:
        break
예제 #3
0
 def add_all_wordstarts_matching(self, hits, query, max_hits):
     """Rank the basenames stored under `query` in basenames_by_wordstarts.

     The lookup key is `query` lowercased, while ranking uses `query` as
     given. Each ranked basename is written into `hits` (basename -> rank),
     replacing any previous value. Stops once `hits` holds `max_hits` entries.
     """
     key = query.lower()
     if key not in self.basenames_by_wordstarts:
         return
     scorer = Ranker()
     for name in self.basenames_by_wordstarts[key]:
         hits[name] = scorer.rank(query, name)
         if len(hits) >= max_hits:
             return
예제 #4
0
 def add_all_wordstarts_matching(self, hits, query, max_hits):
   """Score each basename registered for `query` in basenames_by_wordstarts.

   The index is keyed by the lowercased query; the ranker sees the original
   query. Scores overwrite any existing entry in `hits`, and the scan ends
   as soon as `hits` reaches `max_hits` entries.
   """
   needle = query.lower()
   if needle in self.basenames_by_wordstarts:
     ranker = Ranker()
     candidates = self.basenames_by_wordstarts[needle]
     for candidate in candidates:
       hits[candidate] = ranker.rank(query, candidate)
       if len(hits) >= max_hits:
         break
def inference():
    """Rank the submission candidates and write the recommendation file.

    Loads the preprocessed data, the saved MF model and the saved ranker,
    scores the submission candidates, and writes one line per reader to
    ``result_path``: the reader id followed by its recommended article ids.

    Each list is built from three sources, in order:
    1. the ``cut`` best-scored candidates;
    2. the reader's followable articles, until the list holds ``cut + 15``
       items;
    3. random non-heavy articles (seeded with ``random.seed(0)``), until the
       list holds 100 items, or fewer if not enough distinct articles exist.

    Finally prints the entropy of all recommended articles.
    """
    import pickle

    preprocessor = Preprocessor(first_time=False)
    preprocessor.preprocess()
    dataset = Dataset(preprocessor)
    mf = MF(preprocessor, dataset)
    mf.load()
    i2i = Item2Item(dataset)
    candidate_generator = CandidateGenerator(preprocessor, dataset, mf, i2i)
    ranker = Ranker()
    ranker.load()

    X_submit, X_article_nums, q_submit, q_reader = candidate_generator.generate_submit()
    # Best-effort cache of the submission candidates; reload it with
    # pickle.load(open('submit_puke.pkl', 'rb')) to skip generation.
    # Failures are reported instead of being silently swallowed.
    try:
        with open('submit_puke.pkl', 'wb') as f:
            pickle.dump((X_submit, X_article_nums, q_submit, q_reader), f)
    except Exception as e:
        print("Couldn't save submit_puke: %s" % (e,))

    rank_scores = ranker.rank(X_submit)
    base = 0
    entire_articles = []
    not_heavy_items = sorted(
        set(range(1, article_count + 1)) - set(preprocessor.heavy_items))
    not_heavy_set = set(not_heavy_items)
    cut = 50
    total = 100

    random.seed(0)
    with result_path.open('w') as fout:
        for group_size, reader in tqdm(zip(q_submit, q_reader), total=len(q_submit)):
            # Each reader owns the next `group_size` rows of the flat arrays.
            articles = X_article_nums[base:base + group_size]
            scores = rank_scores[base:base + group_size]

            # Keep the `cut` highest-scored candidates.
            articles = [a for _, a in sorted(zip(scores, articles), key=lambda x: x[0], reverse=True)]
            articles = articles[:cut]

            # Add followable articles that are not already recommended.
            from_followable = candidate_generator.get_readers_followable_articles(reader)
            for item in from_followable:
                if len(articles) >= cut + 15:
                    break
                if item in articles:
                    continue
                articles.append(item)

            # Pad with random non-heavy articles. The target is capped by the
            # number of distinct articles still available; otherwise the loop
            # would spin forever (or random.choice would raise on an empty
            # pool) when the pool is too small.
            target = min(total, len(articles) + len(not_heavy_set.difference(articles)))
            while len(articles) < target:
                item = random.choice(not_heavy_items)
                if item not in articles:
                    articles.append(item)
            entire_articles.extend(articles)

            reader_str = preprocessor.num2reader[reader]
            article_strs = map(preprocessor.num2article.get, articles)

            fout.write('%s %s\n' % (reader_str, ' '.join(article_strs)))

            base += group_size
    print('Entropy of candidates = ', entropy(entire_articles))
예제 #6
0
    def add_all_matching(self, hits, query, flt_tuple, max_hits):
        """
        Add every basename in this index that matches a filter regex to `hits`.

        hits is the dictionary to put results in (basename -> rank); an entry
            that is already present keeps the larger of its old and new rank
        query is the query string originally entered by user, used by ranking
        flt_tuple is [filter_regex, case_sensitive_bool]
        max_hits is largest hits should grow before matching terminates.

        Raises Exception if a match contains a newline, i.e. the filter regex
        matched across more than one basename.
        """
        flt, case_sensitive = flt_tuple

        regex = re.compile(flt)
        base = 0
        ranker = Ranker()
        # Case-insensitive filters are run against the lowercased index.
        if not case_sensitive:
            index = self.lower_basenames_unsplit
        else:
            index = self.basenames_unsplit
        while True:
            m = regex.search(index, base)
            if not m:
                break
            # Drop the first and last matched characters (presumably the
            # '\n' delimiters surrounding one basename).
            hit = m.group(0)[1:-1]
            if '\n' in hit:
                raise Exception(
                    "Something is messed up with flt=[%s] query=[%s] hit=[%s]"
                    % (flt, query, hit))
            rank = ranker.rank(query, hit)
            # Normalize keys to lowercase; hits from the case-insensitive path
            # are presumably lowercase already (they come from the lowered index).
            if case_sensitive:
                hit = hit.lower()
            if hit in hits:
                hits[hit] = max(hits[hit], rank)
            else:
                hits[hit] = rank
            # Resume at the last matched character so it can also open the
            # next entry's match.
            base = m.end() - 1
            if len(hits) >= max_hits:
                break
예제 #7
0
argprse.add_argument("-c", "--hsv", required=True,
                     help="File Path where the computed hsv index is saved")
argprse.add_argument("-t", "--texture", required=True,
                     help="File Path where the computed texture index is saved")
argprse.add_argument("-b", "--btree", required=True,
                     help="File Path where the computed tree index is saved")
argprse.add_argument("-q", "--query", required=True,
                     help="File Path to the query image")
argmnts = vars(argprse.parse_args())

# load the query image and describe its color, texture and tree features
query_img = cv2.imread(argmnts["query"])
cfeats = cdes.describe_color(copy.copy(query_img))
texture = txdes.describe_texture(copy.copy(query_img))
tree = tdes.color_tree(copy.copy(query_img))

# rank the images in the dataset against the query image
ranker = Ranker(argmnts["hsv"], argmnts["texture"], argmnts["btree"])
final_results = ranker.rank(cfeats, texture, tree)

current_path = os.path.dirname(os.path.abspath(__file__))
result_dir = os.path.join(current_path, "result")
# shutil.copy2 fails when the destination folder is missing, so create it.
if not os.path.isdir(result_dir):
    os.makedirs(result_dir)

# Print each ranked image id with its score (in the order returned by
# ranker.rank) and copy the image into result/.
# NOTE(review): argmnts["dataset"] needs a "--dataset" option registered on
# argprse before this point; it is not among the options above — confirm.
for (score, resID) in final_results:
    print(resID + "    " + str(score))
    source_path = os.path.join(argmnts["dataset"], resID)
    dest_path = os.path.join(result_dir, resID)
    shutil.copy2(source_path, dest_path)

예제 #8
0
    help="File Path where the computed texture index is saved")
argprse.add_argument("-b",
                     "--btree",
                     required=True,
                     help="File Path where the computed tree index is saved")
argprse.add_argument("-q",
                     "--query",
                     required=True,
                     help="File Path to the query image")
argmnts = vars(argprse.parse_args())

# load the query image and describe its color, texture and tree features
query_img = cv2.imread(argmnts["query"])
cfeats = cdes.describe_color(copy.copy(query_img))
texture = txdes.describe_texture(copy.copy(query_img))
tree = tdes.color_tree(copy.copy(query_img))

# rank the images in the dataset against the query image
ranker = Ranker(argmnts["hsv"], argmnts["texture"], argmnts["btree"])
final_results = ranker.rank(cfeats, texture, tree)

current_path = os.path.dirname(os.path.abspath(__file__))
result_dir = os.path.join(current_path, "result")
# shutil.copy2 fails when the destination folder is missing, so create it.
if not os.path.isdir(result_dir):
    os.makedirs(result_dir)

# Print each ranked image id with its score (in the order returned by
# ranker.rank) and copy the image into result/.
# NOTE(review): argmnts["dataset"] needs a "--dataset" option registered on
# argprse before this point; confirm it exists.
for (score, resID) in final_results:
    print(resID + "    " + str(score))
    source_path = os.path.join(argmnts["dataset"], resID)
    dest_path = os.path.join(result_dir, resID)
    shutil.copy2(source_path, dest_path)
예제 #9
0
def evaluate_ranking(seed_file,
                     candidate_file,
                     negative_file,
                     data_dir,
                     rankings,
                     max_cand,
                     representation,
                     test_ratio,
                     online,
                     selection=None,
                     max_pages=1,
                     prf=False,
                     seednumbs=None):
    """
    test_ratio: percentage of test urls splitted from seed urls
    """
    t = time.time()
    seed_urls = URLUtility.load_urls(seed_file)
    cand_urls = URLUtility.load_urls(candidate_file)
    neg_urls = URLUtility.load_urls(negative_file)

    # Split train and test urls
    split = int((1 - test_ratio) * len(seed_urls))
    test_urls = seed_urls[split:]
    train_urls = seed_urls[:split]

    # Fetch the train, test and candidate sites
    print "Loading the cache"
    fetcher = Fetcher(data_dir)
    if selection == "mix":
        # This is to prove the yet ineffectiveness of multipages representation
        train_selection = test_selection = "search"
        cand_selection = "random"
    else:
        train_selection = test_selection = cand_selection = selection

    print "\nFetching train sites"
    train_sites = fetcher.fetch_sites(train_urls, max_pages, train_selection,
                                      online)

    print "Time to fetch train sites: ", time.time() - t
    t = time.time()

    if seednumbs:
        seednumbs = get_seednumbs(seednumbs[0], len(train_sites), seednumbs[1])
    else:
        seednumbs = [len(train_sites)]
    print "seednumbs", seednumbs
    for seednumb in seednumbs:
        train_sites = train_sites[:seednumb + 1]
        #for s in train_sites:
        #    for p in s:
        #        print p.get_url()
        print "\nFetching cand sites"
        cand_sites = fetcher.fetch_sites(cand_urls, max_pages, cand_selection,
                                         online)
        print "\nFetching test sites"
        test_sites = fetcher.fetch_sites(test_urls, max_pages, test_selection,
                                         online)
        print "\nFetching negative sites"
        neg_sites = fetcher.fetch_sites(neg_urls, 1, None, online)
        print "Time to fetch cand, test, neg sites: ", time.time() - t

        cand_sites = cand_sites[:max_cand]
        max_cand -= len(test_sites)
        cand_sites.extend(test_sites)
        print "Number of seed sites: ", len(train_sites)
        print "Number of test sites: ", len(test_sites)
        print "Number of candidate sites: ", len(cand_sites)
        print "Ranking methods: ", rankings
        if online:
            print "Running online mode"
        else:
            print "Running offline mode"

        # Initialize the ranking models
        for ranking in rankings:
            # Train
            print "Ranking..."
            t = time.time()
            ranker = Ranker(
                copy.deepcopy(train_sites), representation, ranking, neg_sites
            )  # train_sites might be changed in the object initialization
            print "Time to initialize ranker: ", time.time() - t
            t = time.time()
            top_sites = ranker.rank(cand_sites, prf)
            print "Time to rank: ", time.time() - t

            # Evaluate
            print "Evaluating ranking results"
            site2rank = {}
            site2website = {}
            for i, site_score in enumerate(top_sites):
                site = site_score[0].get_host()
                if site not in site2rank:
                    site2rank[site] = i
                    site2website[site] = site_score[0]
            test_scores = []
            #test_count = 0
            for url in test_urls:
                site = URLUtility.get_host(url)
                if site in site2rank:
                    #test_count += 1
                    print site, site2rank[site]
                    print[p.get_url() for p in site2website[site]]
                    test_scores.append(site2rank[site])
            test_scores = sorted(test_scores)
            mean = sum(test_scores) / float(len(test_scores))
            mean = round(mean, 2)
            median = test_scores[(len(test_scores) - 1) / 2]
            #prec_at_k = round(len([s for s in test_scores if s<=len(test_urls)])/float(test_count), 4)*100
            prec_at_k = round(
                len([s for s in test_scores if s < len(test_scores)]) /
                float(len(test_scores)), 4) * 100
            precs = compute_prec(test_scores)
            print "RESULTS_SEEDNUMB", len(train_sites)
            print "RESULTS_RAW," + ranking + ',' + ','.join(
                [str(s) for s in test_scores])
            print "RESULTS_AGGREGATION," + ranking + ',' + str(
                mean) + ',' + str(median) + ',' + str(prec_at_k)
            print "RESULTS_PRECS", ranking + ',' + ','.join(
                [str(p) for p in precs])

            # Debug: print top 10 urls
            print "Top 10 urls: "
            for item in top_sites[:20]:
                print item[0].get_host(), item[1]
                print[p.get_url() for p in item[0]]

            # Clear the pre-computed vectorization from previous runs
            clear(train_sites)
            clear(cand_sites)
            clear(test_sites)
            clear(neg_sites)
예제 #10
0
    relevances = dict()
    for queryid, pmid, relevance in lines:
        relevance = float(relevance)
        query = relevances.setdefault(queryid, {})
        query.setdefault(pmid, relevance + 1 - int(relevance == 2) * 2)

    tokenizer.make_rule(rules.stopping, 'stopwords.txt')
    index = Index.segment_on_load('output', tokenizer=tokenizer)
    ranker = Ranker(index)

    with open('queries.txt', 'r') as fin:
        queries = [tuple(line.strip().split('\t')) for line in fin]

    length = max(len(query) for qid, query in queries)
    queries = {query: ranker.rank(query[1]) for query in queries}

    formatter_header = '{:2}   {:>14}   {:>14}   {:>14}   {:>14}   {:>14}'
    formatter_body = '{:02d}   {:>14.3f}   {:>14.3f}   {:>14.3f}   {:>14.3f}   {:>14.3f}'

    avg_formatter_header = '{:>14}   {:>14}   {:>14}   {:>14}   {:>14}'
    avg_formatter_body = '{:>14.3f}   {:>14.3f}   {:>14.3f}   {:>14.3f}   {:>14.3f}'

    for i in [10, 20, 50]:
        print('Top-{}'.format(i))
        print(
            formatter_header.format('', 'Precision', 'Recall', 'F-Measure',
                                    'Avg. Precision', 'NDCG'))
        ap = ar = af = aa = an = 0
        for query, ranking in sorted(queries.items(),
                                     key=lambda item: int(item[0][0])):
예제 #11
0
class DialogManager:
    """Answer a user query with the stored answer of the closest known question.

    Candidate questions are recalled by self.retriever and re-ranked by
    self.ranker. self.tracker keeps the previous turn so that follow-up
    queries with missing information can be completed.
    """

    def __init__(self):
        print('Loading dialog manager...')

        self.retriever = Retrival('data/qa_processed')
        self.ranker = Ranker('localhost')

        # Debug variant that also shows the distance:
        # "You might want to ask: %s\nBest answer: %s \n distance: %f"
        # Template used when the best match is not confident, roughly:
        # "Did you mean: '<question>'? Friday thought about it: <answer>"
        self.ANSWER_TEMPLATE = "您是说:“%s” 吗?  Friday想了想:%s"

        # simple_state_tracker
        self.tracker = Tracker()

        # Distance threshold: a best match whose distance is not below it is
        # treated as uncertain and answered through ANSWER_TEMPLATE.
        self.threshold = 0.02

        self.spider = None  # TODO
        print('DialogManager established.')

    def get_answer(self, query):
        """Return the answer text for a single user query.

        Returns an apology message when no candidate question is recalled.
        Otherwise returns the best match's answer as-is when its distance is
        below self.threshold, or wrapped in ANSWER_TEMPLATE when it is not.
        The processed query is remembered in self.tracker for follow-ups.
        """
        # Normalize the query text and run the boolean search.
        query = simple_process(query)
        bool_state = self.retriever.input_query(query)
        query = self.retriever.query  # the retriever may have spell-corrected it

        # Use the previous turn to complete queries with missing information,
        # e.g. follow-ups like "what about xx?".
        # NOTE(review): eval() passes retriever.words to check/fill_query while
        # this passes the query string — confirm which the Tracker expects.
        if self.tracker.check(query):
            tmp = self.tracker.fill_query()
            if len(tmp) > 0:
                query = tmp
                bool_state = self.retriever.input_query(query)

        # Re-rank the recalled candidates.
        candi_idx, candi_q = self.recall_candidates(bool_state)
        if len(candi_idx) == 0:
            # DialogManager never defines its own allowed_words; the list
            # lives on the retriever.
            allowed_words = self.retriever.allowed_words
            if len(allowed_words) == 0:
                return '非常抱歉,我不明白您的意思'
            else:
                return '非常抱歉,我不明白您的意思。你可以问问其他关于“{}”的问题'.format(
                    ' '.join(allowed_words))
        results = self.ranker.rank(query, candi_idx, candi_q, top=1)  # (q_id, distance) pairs

        # Read the best result.
        best_row = self.retriever.data.iloc[results[0][0]]
        best_match = best_row.question
        answer = best_row.answer
        distance = results[0][1]
        # Remember this turn for follow-up questions.
        filtered_words = self.retriever.allowed_words
        self.tracker.previous_cache = (query, filtered_words)

        if distance < self.threshold:
            return str(answer)
        else:
            return self.ANSWER_TEMPLATE % (best_match, answer)

    def recall_candidates(self, bool_state):
        """Recall candidate question indices and their texts.

        When the boolean search succeeded (bool_state == True), the union of
        the tfidf, bm25, edit-distance and fasttext top-10 results is used.
        Otherwise (input_query has already tried spelling correction and
        relaxed constraints), only the fasttext word-vector top-20 is used.

        Returns (candi_idx, candi_q): parallel lists of row indices into
        self.retriever.data and the corresponding question texts.
        """
        result_candidates = set()
        if bool_state == True:
            result_candidates.update(self.retriever.search_tfidf(top_k=10))
            result_candidates.update(self.retriever.search_bm25(top_k=10))
            result_candidates.update(self.retriever.search_editdist(top_k=10))
            result_candidates.update(self.retriever.search_fasttext(top_k=10))
        else:
            result_candidates.update(self.retriever.search_fasttext(top_k=20))

        candi_idx = list(result_candidates)
        candi_q = [self.retriever.data.iloc[idx].question for idx in candi_idx]

        return candi_idx, candi_q

    def eval(self, query, topn=10):
        """Return the question texts of the top `topn` matches for `query`.

        Returns `topn` copies of an apology message when nothing is recalled.
        The processed query is remembered in self.tracker for follow-ups.
        """
        # Normalize the query text and run the boolean search.
        query = simple_process(query)
        bool_state = self.retriever.input_query(query)

        # Use the previous turn to complete queries with missing information.
        cur_words = self.retriever.words
        if self.tracker.check(cur_words):
            query = self.tracker.fill_query(cur_words)
            bool_state = self.retriever.input_query(query)

        # Re-rank the recalled candidates.
        candi_idx, candi_q = self.recall_candidates(bool_state)
        if len(candi_idx) == 0:
            return ['非常抱歉,我不明白您的意思' for _ in range(topn)]
        results = self.ranker.rank(query, candi_idx, candi_q, top=topn)  # (q_id, distance) pairs
        q_list = [self.retriever.data.iloc[idx].question for idx, _ in results]

        # Remember this turn for follow-up questions.
        filtered_words = self.retriever.allowed_words
        self.tracker.previous_cache = (query, filtered_words)

        return q_list
예제 #12
0
    def run_mix_search(self,
                       ranking,
                       selection=None,
                       online=True,
                       max_results=50,
                       seed_keyword="gun",
                       search="kw",
                       iters=5,
                       representation='body',
                       negative_file=None):
        """
        Iteratively expand a set of seed sites: search, fetch the new urls,
        re-rank all results, and pick new seeds from the top-ranked ones,
        choosing the search operator on every iteration.

        seed_sites: urls that are used for search
        selected_urls: urls that were used for search (tracked per operator)

        Only top-ranked urls will become seed urls

        Important Args:
            ranking: a ranking method
            selection: page selection strategy passed to fetcher.fetch_sites
            online: passed to fetcher.fetch_sites
            max_results: Maximum number of results to return in related and keyword search
            seed_keyword: passed to searcher.search
            search: operator-selection strategy passed to select_searchop;
                'bandit' additionally feeds rewards back into the UCB1 selector
            iters: NOTE(review) overwritten with 2000 in the body, so the
                passed value currently has no effect
            representation: passed to Ranker
            negative_file: optional url file; its first 200 urls are fetched
                as negative examples for the Ranker

        Returns None; new sites are written through self.save_urls on each
        iteration and the final scores through self.save_scores.
        """
        max_pages = 1  # Always use single page to represent a website
        train_sites = self.fetcher.fetch_sites(self.train_urls, max_pages,
                                               selection, online)

        if negative_file:  # (random) reliably negative examples
            neg_urls = URLUtility.load_urls(negative_file)
            neg_urls = neg_urls[:200]
        else:
            neg_urls = []
        print "neg_urls: ", len(neg_urls)
        neg_sites = self.fetcher.fetch_sites(neg_urls, 1, None, online)
        ranker = Ranker(train_sites, representation, ranking, neg_sites)

        # Data
        scores = []  # Avoid exception when iters=0
        #seed_sites = self.train_urls # topk urls from each search batch
        seed_sites = train_sites  # topk urls from each search batch
        selected_urls = {}  # avoid searching with these urls again
        selected_urls['kw'] = set()
        selected_urls['bl'] = set()
        selected_urls['rl'] = set()
        selected_urls['fw'] = set()
        results = []  # Search results for ranking
        urls = set()  # Avoid fetch and rank these urls again
        sites = set()  # used to compute reward

        # Hyperparameters
        #max_numb_pages = 12000 # stop condition
        max_numb_pages = 51000  # stop condition
        #iters = 500
        # NOTE(review): this overrides the `iters` argument — confirm intended.
        iters = 2000
        k = 20  #  number of pages from the newly discovered pages to be added to the seed list
        max_kw = 20  # maximum number of keywords to select from the seed pages
        self.searcher.set_max_keywords(max_kw)

        # Initialize Search Operator Selection Strategy
        count = {}  # Count number of results yielded by each search operator
        count['bl'] = count['kw'] = count['rl'] = count['fw'] = 0
        count['bl'] = 20000  # never choose this
        #ucb = UCB1(['rl', 'bl', 'kw'])
        ucb = UCB1(['rl', 'bl', 'kw', 'fw'])

        site_mode = False  # used in get_top_ranked_urls function

        for i in xrange(iters):
            t = time.time()

            print "Searching... ", len(seed_sites), "  seed urls"
            searchop = self.select_searchop(count, search, ucb)

            # Backlink ('bl') and related ('rl') search only use host names,
            # so seeds are chosen per site for those operators.
            if searchop == 'rl' or searchop == 'bl':
                site_mode = True
            else:
                site_mode = False

            print "\n Iteration ", i, searchop
            new_urls = self.searcher.search(seed_sites, \
                                            searchop, seed_keyword=seed_keyword, \
                                            max_results=max_results)
            # Drop urls that were already fetched and ranked.
            new_urls = [url for url in new_urls if url not in urls]

            if len(new_urls) == 0:
                # NOTE(review): count/ucb are not updated on this path.
                print "Searcher found 0 url"
                seed_sites = self.get_top_ranked_urls(
                    scores, k, selected_urls[searchop], site_mode
                )  # Backlink search and related search only use host name to form the query. searchop!='kw' <-> searchop=='bl' or searchop=='rl'
                if len(seed_sites) == 0:
                    print "Stop. Running out of seeds"
                    break
                else:
                    continue

            urls.update(new_urls)

            print "Time to search ", i, ": ", time.time() - t
            t = time.time()

            new_sites = self.fetcher.fetch_sites(new_urls, max_pages,
                                                 selection, online)

            print "Time to fetch ", i, ": ", time.time() - t
            t = time.time()

            temp = len(results)
            results.extend(new_sites)
            print "Size of candidates (after): ", len(results)
            print "Number of new candidates (after): ", len(results) - temp
            # Re-rank the whole accumulated result pool on every iteration.
            scores = ranker.rank(results)
            if len(scores) >= max_numb_pages:
                print "Stop. Retrieved ", max_numb_pages, " pages"
                break
            #seed_sites = self.get_top_ranked_urls(scores, k, selected_urls[searchop])
            seed_sites = self.get_top_ranked_urls(
                scores, k, selected_urls[searchop], site_mode
            )  # Backlink search and related search only use host name to form the query. searchop!='kw' <-> searchop=='bl' or searchop=='rl'
            if len(seed_sites) == 0:
                print "Stop. Running out of seeds"
                break
            self.save_urls(new_sites, i)

            # Update information from the search results to the operation selector
            count[searchop] += len(new_urls)
            if (search == 'bandit') and new_sites:
                reward = self.get_reward(scores, new_sites, sites)
                print "UCB Rewards", searchop, reward
                ucb.update(searchop, reward, len(new_sites))
                sites.update([s.get_host() for s in new_sites])
            print "Time to rank ", i, ": ", time.time() - t

        self.save_scores(scores)
예제 #13
0
    def run(self,
            ranking,
            selection=None,
            online=True,
            max_results=50,
            seed_keyword="gun",
            searchop="kw",
            iters=5,
            representation='body',
            negative_file=None):
        """
        Iteratively expand a set of seed sites with one fixed search operator:
        search, fetch the new urls, re-rank all results, and pick new seeds
        from the top-ranked ones.

        seed_sites: urls that are used for search
        selected_urls: urls that were used for search

        Only top-ranked urls will become seed urls

        Important Args:
            ranking: a ranking method
            selection: page selection strategy passed to fetcher.fetch_sites
            online: passed to fetcher.fetch_sites
            max_results: Maximum number of results to return in related and keyword search
            seed_keyword: passed to searcher.search
            searchop: search operator used on every iteration; 'rl' and 'bl'
                switch get_top_ranked_urls to site mode
            iters: NOTE(review) overwritten with 2000 in the body, so the
                passed value currently has no effect
            representation: passed to Ranker
            negative_file: optional url file; its first 200 urls are fetched
                as negative examples for the Ranker

        Returns None; new sites are written through self.save_urls on each
        iteration and the final scores through self.save_scores.
        """
        max_pages = 1  # Always use single page to represent a website
        train_sites = self.fetcher.fetch_sites(self.train_urls, max_pages,
                                               selection, online)

        if negative_file:  # (random) reliably negative examples
            neg_urls = URLUtility.load_urls(negative_file)
            neg_urls = neg_urls[:200]
        else:
            neg_urls = []
        print "neg_urls: ", len(neg_urls)
        neg_sites = self.fetcher.fetch_sites(neg_urls, 1, None, online)
        ranker = Ranker(train_sites, representation, ranking, neg_sites)

        # Data
        scores = []  # Avoid exception when iters=0
        #seed_sites = self.train_urls # topk urls from each search batch
        seed_sites = train_sites  # topk urls from each search batch
        selected_urls = set()  # avoid searching with these urls again
        results = []  # Search results for ranking
        urls = set()  # Avoid fetch and rank these urls again

        # Hyperparameters
        #max_numb_pages = 12000 # stop condition
        max_numb_pages = 51000  # stop condition
        #iters = 500
        # NOTE(review): this overrides the `iters` argument — confirm intended.
        iters = 2000

        k = 20  #  number of pages from the newly discovered pages to be added to the seed list
        max_kw = 20  # maximum number of keywords to select from the seed pages
        self.searcher.set_max_keywords(max_kw)
        # Disabled legacy code kept as a string literal; it refers to a
        # `search` variable that does not exist in this method. The string is
        # evaluated and discarded.
        """
        # Search Strategy
        blsearch = kwsearch = rlsearch = fwsearch = False
        if search == 'bl':
            blsearch = True
            print "Backlink search enable"
        elif search == 'rl':
            rlsearch = True
            print "Related search enable"
        elif search == 'kw':
            kwsearch =  True
            print "Keyword search enable"
        """
        # Backlink ('bl') and related ('rl') search only use host names, so
        # seeds are chosen per site for those operators.
        site_mode = False  # used in get_top_ranked_urls function
        if searchop == 'rl' or searchop == 'bl':
            site_mode = True

        for i in xrange(iters):
            t = time.time()

            print "Searching... ", len(seed_sites), "  seed urls"
            print "\n Iteration ", i, searchop
            new_urls = self.searcher.search(seed_sites, searchop, \
                                            seed_keyword=seed_keyword, \
                                            max_results=max_results)
            # Drop urls that were already fetched and ranked.
            new_urls = [url for url in new_urls if url not in urls]
            if len(new_urls) == 0:
                print "Searcher found 0 url"
                seed_sites = self.get_top_ranked_urls(scores, k, selected_urls,
                                                      site_mode)
                if len(seed_sites) == 0:
                    print "Stop. Running out of seeds"
                    break
                else:
                    continue

            urls.update(new_urls)

            print "Time to search ", i, ": ", time.time() - t
            t = time.time()

            new_sites = self.fetcher.fetch_sites(new_urls, max_pages,
                                                 selection, online)

            print "Time to fetch ", i, ": ", time.time() - t
            t = time.time()

            print "Size of candidates (before): ", len(results)
            results.extend(new_sites)
            print "Size of candidates (after): ", len(results)
            # Re-rank the whole accumulated result pool on every iteration.
            scores = ranker.rank(results)
            if len(scores) >= max_numb_pages:
                print "Stop. Retrieved ", max_numb_pages, " pages"
                break
            seed_sites = self.get_top_ranked_urls(scores, k, selected_urls,
                                                  site_mode)
            if len(seed_sites) == 0:
                print "Stop. Running out of seeds"
                break
            self.save_urls(new_sites, i)

            print "Time to rank ", i, ": ", time.time() - t

        self.save_scores(scores)
예제 #14
0
class RankerTest(unittest.TestCase):
  def setUp(self):
    """Give every test method its own freshly constructed Ranker."""
    self.ranker = Ranker()

  def test_is_wordstart(self):
    """Check Ranker._is_wordstart against a per-character expectation list."""
    def check(s, expectations):
      # Use the TestCase assertion rather than a bare `assert`, which is
      # stripped when Python runs with -O and would silently skip the check.
      self.assertEqual(len(s), len(expectations),
                       "expectation list length mismatch for %r" % (s,))
      for i, expected in enumerate(expectations):
        self.assertEqual(expected, self.ranker._is_wordstart(s, i),
                         "disagreement on index %i of %r" % (i, s))

    check("foo", [True, False, False])
    check("fooBar", [True, False, False, True, False, False])
    check("o", [True])
    check("_", [True])
    check("F", [True])
    check("FooBar", [True, False, False, True, False, False])
    check("Foo_Bar", [True, False, False, False, True, False, False])
    check("_Bar", [True, True, False, False])
    check("_bar", [True, True, False, False])
    check("foo_bar", [True, False, False, False, True, False, False])

    check(".h", [True, False])
    check("a.h", [True, False, False])
    check("__b", [True, False, True])
    check("foo__bar", [True, False, False, False, False, True, False, False])

    check("Foo3D", [True, False, False, True, True])
    check("Foo33", [True, False, False, True, False])

    check("x3d", [True, True,  False]) # I could be convinced that 'd' is a wordstart.

    check("AAb", [True, True, False])
    check("CCFra", [True, True, True, False, False])

  def test_get_word_starts(self):
    """Check that get_starts() returns the index of every word start."""
    # Expected values are 0-based character indices into the key, e.g.
    # 'abd_def' -> [0, 4] marks the 'a' at index 0 and the 'd' at index 4.
    data = {
      '' : [],
      'abc' : [0],
      'abd_def' : [0, 4],
      'ab_cd_ef' : [0, 3, 6],
      'ab_' : [0],
      'AA': [0, 1],
      'AAbA': [0,1,3],
      'Abc': [0],
      'AbcDef': [0,3],
      'Abc_Def': [0,4],
      }
    for word, expected_starts in data.items():
      starts = self.ranker.get_starts(word)
      self.assertEqual(expected_starts, starts, "for %s, expect %s" % (word, expected_starts))

  def assertBasicRankAndWordHitCountIs(self, expected_rank, expected_word_count, query, candidate):
    """Assert the result of Ranker._get_basic_rank(query, candidate).

    The result's item 0 is compared with expected_rank and its item 1 with
    expected_word_count (the number of word starts the query hit).
    """
    res = self.ranker._get_basic_rank(query, candidate)
    # Name the inputs in the failure message: this helper is called from many
    # tests, and a bare value mismatch does not say which case failed.
    context = "query=%r candidate=%r" % (query, candidate)
    self.assertEqual(expected_rank, res[0], "rank mismatch for " + context)
    self.assertEqual(expected_word_count, res[1],
                     "word hit count mismatch for " + context)

  def test_query_hits_on_word_starts(self):
    # Each row: (expected basic rank, expected word-start hits, query, candidate).
    cases = [
      (8, 4, 'rwhv', 'render_widget_host_view.cc'),  # +1 for hitting every word
      (6, 3, 'rwh', 'render_widget_host_view.cc'),
      (5.5, 2, 'wvi', 'render_widget_host_view_win.cc'),  # known-awkward expectation
      (2, 1, 'w', 'WebViewImpl.cc'),
      (2, 1, 'v', 'WebViewImpl.cc'),
      (4, 2, 'wv', 'WebViewImpl.cc'),
      (5, 2, 'evi', 'WebViewImpl.cc'),
      (4, 2, 'wv', 'eWbViewImpl.cc'),
      (6, 0, 'ebewp', 'WebViewImpl.cc'),
    ]
    for rank, hits, query, candidate in cases:
      self.assertBasicRankAndWordHitCountIs(rank, hits, query, candidate)


  def test_basic_rank_pays_attention_to_case(self):
    # A case transition inside the candidate must not be lost: in 'rWf' the
    # capital W is a second word start for the query "rw", while the
    # all-lowercase 'rwf' offers only one.
    query = "rw"
    for rank, hits, candidate in ((4.5, 1, "rwf"), (4, 2, "rWf")):
      self.assertBasicRankAndWordHitCountIs(rank, hits, query, candidate)

  def test_basic_rank_works_at_all(self):
    # General sanity checks for _get_basic_rank.
    expectations = (
      (8, 4, "rwhv", "render_widget_host_view.h"),
      (10, 5, "rwhvm", "render_widget_host_view_mac.h"),
      (10, 5, "rwhvm", "render_widget_host_view_mac.mm"),
      (29, 4, 'ccframerate', 'CCFrameRateController.cpp'),
    )
    for rank, hits, query, candidate in expectations:
      self.assertBasicRankAndWordHitCountIs(rank, hits, query, candidate)


  def test_basic_rank_query_case_doesnt_influence_rank(self):
    # Moving the capital 'R' from the query to the candidate must not change
    # the basic-rank result.
    a = self.ranker._get_basic_rank("Rwhvm", "render_widget_host_view_mac.h")
    b = self.ranker._get_basic_rank("rwhvm", "Render_widget_host_view_mac.h")
    self.assertEqual(a, b)

  def test_basic_rank_isnt_only_greedy(self):
    # The trailing 'm' of the query should be credited to the word start of
    # 'mac' (after the '_'), not to a later 'm' inside 'macmm'; the expected
    # values therefore match those for 'render_widget_host_view_mac.h'.
    query = "rwhvm"
    candidate = "render_widget_host_view_macmm"
    self.assertBasicRankAndWordHitCountIs(10, 5, query, candidate)

  def test_basic_rank_on_corner_cases(self):
    # Empty inputs, single characters, and a query character absent from the
    # candidate.
    edge_cases = [
      (0, 0, "", ""),
      (0, 0, "", "x"),
      (0, 0, "x", ""),
      (2, 1, "x", "x"),
      (1, 0, "x", "yx"),
      (0, 0, "x", "abcd"),
    ]
    for rank, hits, query, candidate in edge_cases:
      self.assertBasicRankAndWordHitCountIs(rank, hits, query, candidate)

  def test_basic_rank_on_mixed_wordstarts_and_full_words(self):
    # Queries that combine runs of whole-word characters with word-start hits.
    header = "render_widget_host_view.h"
    mac_header = "render_widget_host_view_mac.h"
    expectations = [
      (17, 3, "enderwhv", header),
      (15, 2, "idgethv", header),
      (8, 4, "rwhv", mac_header),
      (14, 5, "rwhvmac", mac_header),
      (10, 5, "rwhvm", mac_header),
    ]
    for rank, hits, query, candidate in expectations:
      self.assertBasicRankAndWordHitCountIs(rank, hits, query, candidate)

  def test_basic_rank_overconditioned_query(self):
    # Here the query is much longer than the candidate it is scored against.
    query, candidate = 'test_thread_tab.py', 'tw'
    self.assertBasicRankAndWordHitCountIs(2, 1, query, candidate)

  def test_basic_rank_on_suffixes_of_same_base(self):
    # render_widget.cpp should be ranked higher than render_widget.h
    # unless the query explicitly matches the .h or .cpp
    pass

  def test_rank_corner_cases(self):
    """Check rank() on empty, undersized and overconditioned inputs."""
    # empty
    self.assertEqual(0, self.ranker.rank('foo', ''))
    self.assertEqual(0, self.ranker.rank('', 'foo'))

    # undersized
    self.assertEqual(0, self.ranker.rank('foo', 'm'))
    self.assertEqual(0, self.ranker.rank('f', 'oom'))

    # overconditioned
    self.assertEqual(2, self.ranker.rank('test_thread_tab.py', 'tw'))

  def test_rank_subclasses_lower_ranked_than_base(self):
    # Hitting every word of the candidate should score higher than hitting
    # only some of its words.
    base_rank = self.ranker.rank("rwhvm", "render_widget_host_view.h")
    subclass_rank = self.ranker.rank("rwhvm", "render_widget_host_view_subclass.h")
    # assertGreater reports both values on failure, unlike assertTrue(a > b).
    self.assertGreater(base_rank, subclass_rank)

  def test_rank_order_for_hierarchy_puts_bases_first(self):
    # Names are listed from most to least preferred for the query "rwhvm".
    # Every element needs its trailing comma: a missing one silently glues
    # two adjacent string literals into a single bogus file name.
    names = ['render_widget_host_view_mac.h',
             'render_widget_host_view_mac.mm',
             'render_widget_host_view_mac_delegate.h',
             'render_widget_host_view_mac_unittest.mm',
             'render_widget_host_view_mac_editcommand_helper.mm',
             'render_widget_host_view_mac_editcommand_helper.h',
             'render_widget_host_view_mac_editcommand_helper_unittest.mm',
             ]
    self._assertRankDecreasesOrStaysTheSame("rwhvm", names)

  def _assertRankDecreasesOrStaysTheSame(self, query, names):
    """
    Makes sure that the first element of names has the highest rank for query
    and that each subsequent element ranks no higher than the one before it.
    """
    ranks = [self.ranker.rank(query, n) for n in names]
    for i in range(1, len(ranks)):
      # Name both neighbours in the message so a failure pinpoints the pair.
      self.assertLessEqual(
          ranks[i], ranks[i - 1],
          "rank(%r, %r)=%r exceeds rank(%r, %r)=%r" % (
              query, names[i], ranks[i], query, names[i - 1], ranks[i - 1]))

  def test_rank_order_prefers_capitals(self):
    # Simple queries should still favour capital-letter word starts. The
    # heuristics behind test_rank_order_puts_tests_second tend to break this.
    query, candidate = 'wvi', 'WebViewImpl.cc'
    self.assertBasicRankAndWordHitCountIs(6, 3, query, candidate)

  def test_rank_order_puts_tests_second(self):
    """Non-test sources should outrank their corresponding test files."""
    q = "ccframerate"
    a1 = self.ranker.rank(q, 'CCFrameRateController.cpp')
    a2 = self.ranker.rank(q, 'CCFrameRateController.h')
    b = self.ranker.rank(q, 'CCFrameRateControllerTest.cpp')

    # This is a hard test to pass because ccframera(te) ties to (Te)st
    # if you weight non-word matches equally.
    self.assertGreater(a1, b)
    self.assertGreater(a2, b)

    q = "chrome_switches"
    a1 = self.ranker.rank(q, 'chrome_switches.cc')
    a2 = self.ranker.rank(q, 'chrome_switches.h')
    b = self.ranker.rank(q, 'chrome_switches_uitest.cc')
    self.assertGreater(a1, b)
    self.assertGreater(a2, b)

  def test_rank_order_for_hierarchy_puts_prefixed_second(self):
    """A plain source file should outrank build artifacts whose names embed it."""
    q = "ccframerate"
    a = self.ranker.rank(q, 'CCFrameRateController.cpp')
    b1 = self.ranker.rank(q, 'webcore_platform.CCFrameRateController.o.d')
    self.assertGreater(a, b1)
    # TODO: 'webkit_unit_tests.CCFrameRateControllerTest.o.d' should also rank
    # below 'CCFrameRateController.cpp', but that check currently fails
    # because ccframera(te) ties to (Te)st.

  def test_rank_order_puts_tests_second_2(self):
    """More cases where a base source must outrank its test or extension."""
    # NOTE(review): 'bassed' has a doubled 's' compared with 'Based';
    # presumably deliberate to exercise fuzzy matching — confirm.
    q = "ccdelaybassedti"
    a1 = self.ranker.rank(q, 'CCDelayBasedTimeSource.cpp')
    a2 = self.ranker.rank(q, 'CCDelayBasedTimeSource.h')
    b = self.ranker.rank(q, 'CCDelayBasedTimeSourceTest.cpp')
    self.assertGreater(a1, b)
    self.assertGreater(a2, b)

    q = "LayerTexture"
    a = self.ranker.rank(q, 'LayerTexture.cpp')
    b = self.ranker.rank(q, 'LayerTextureSubImage.cpp')
    self.assertGreater(a, b)

  def test_refinement_improves_rank(self):
    # Extending the query so it matches more of the candidate should raise the
    # candidate's rank.
    a = self.ranker.rank('render_', 'render_widget.cc')
    b = self.ranker.rank('render_widget', 'render_widget.cc')
    self.assertGreater(b, a)