def get_features(inp):
    """Sample one positive and one negative training pair per element of *inp*.

    For each iteration a random query id is drawn, then a relevant document
    (from QREL) and a non-relevant one (top-100 pool minus QREL) are featurized.
    Labels are binary (1/0) or, with ``args.graded``, randomly drawn graded
    levels (4/3 for relevant, 1/2 for non-relevant).

    Returns a tuple of numpy arrays: (query features, doc features, labels).
    Failures for a single sample are reported and skipped (best-effort).
    """
    qfeat_rows = []
    doc_rows = []
    label_rows = []

    for _ in inp:
        try:
            qid = random.choice(list(QUERYDICT.keys()))
            tokens = [tok.text for tok in PREPROCESS(QUERYDICT[qid])]
            qvec = get_query_features(tokens)

            # relevant (positive) document
            pos_doc = random.choice(QREL[qid])
            doc_rows.append(get_doc_feats(pos_doc, tokens, FILE))
            if args.graded:
                label_rows.append(4 if random.random() < 0.5 else 3)
            else:
                label_rows.append(1)
            qfeat_rows.append(qvec)

            # non-relevant (negative) document from the top-100 pool
            neg_doc = random.choice(list(set(TOP100[qid]) - set(QREL[qid])))
            doc_rows.append(get_doc_feats(neg_doc, tokens, FILE))
            if args.graded:
                label_rows.append(1 if random.random() > 0.5 else 2)
            else:
                label_rows.append(0)
            qfeat_rows.append(qvec)

        except Exception as err:
            # deliberate best-effort: a bad sample is logged and skipped
            print("ERROR:", err)

    return np.array(qfeat_rows), np.array(doc_rows), np.array(label_rows)
def _collect_doc_feats(qid, query):
    """Featurize every top-100 document for *qid*; returns (docids, feature matrix)."""
    docids = []
    features = []
    for docid in top100[qid]:
        features.append(get_doc_feats(docid, query, FILE))
        docids.append(docid)
    return docids, np.array(features)


def _ranked_rows(qid, docids, relevance):
    """Emit run-file rows ordered by descending relevance score.

    Rows with a zero score are dropped.
    NOTE(review): because the filter sits inside the enumerate, skipped docs
    leave gaps in the rank numbering — preserved as original behavior; confirm
    whether consecutive ranks are required by the evaluation format.
    """
    rows = []
    ordering = np.argsort(-relevance)
    for rank, idx in enumerate(ordering):
        if relevance[idx] != 0:
            rows.append([qid, docids[idx], rank + 1, relevance[idx], run_id])
    return rows


def predict(inp):
    """Rank documents for one query under the model selected by ``args.model``.

    *inp* is a (qid, query-string) pair.  Returns a list of run-file rows
    ``[qid, docid, rank, score, run_id]``; an unsupported model name yields an
    empty list after printing an error.
    """
    qid, query = inp
    ret = []

    if args.model in ("okapi", "bm25"):
        # direct index search; scores come from the searcher itself
        results = SEARCHER.search(qp.parse(query), limit=args.limit)
        for rank, hit in enumerate(results):
            ret.append([qid, hit["docid"], rank + 1, results.score(rank), run_id])

    elif args.model == "clusvm":
        query = [token.text for token in PREPROCESS(query)]
        queryfeats = get_query_features(query)
        # replicate the query vector once per candidate document
        queryfeats = np.concatenate([queryfeats[None, :]] * len(top100[qid]))

        docids, features = _collect_doc_feats(qid, query)
        relevance = clusvm.predict(queryfeats, features)
        ret.extend(_ranked_rows(qid, docids, relevance))

    elif "sv" in args.model or args.model == "adarank":
        query = [token.text for token in PREPROCESS(query)]
        docids, features = _collect_doc_feats(qid, query)

        if args.model == "adarank":
            # AdaRank: relevance is a learned linear combination of features
            relevance = np.dot(features, alpha)
        else:
            if args.binary or args.model == "svr":
                relevance = svm.predict(features)
            else:
                relevance = svm.decision_function(features)
            if args.add_bm25:
                # blend in the BM25 feature (last column) at reduced weight
                relevance += features[:, -1] / 100
        ret.extend(_ranked_rows(qid, docids, relevance))

    else:
        print("ERROR: unsupported model")

    return ret
def process_graded(inp):
    """Featurize one graded qrel row.

    *inp* is (qid, <unused>, docid, relevance); returns the triple
    (query features, document features, relevance) for training.
    """
    qid, _, docid, relevance = inp
    tokens = [tok.text for tok in PREPROCESS(QUERYDICT[qid])]
    return get_query_features(tokens), get_doc_feats(docid, tokens, FILE), relevance