def recall(e: trectools.TrecEval, per_query=False):
    """Compute recall as the relevant-retrieved count over the relevant count.

    Args:
        e: Evaluator queried via ``get_relevant_retrieved_documents`` and
            ``get_relevant_documents``.
        per_query: Forwarded unchanged to both evaluator calls. When truthy,
            NaN entries of the element-wise ratio are replaced with 0.

    Returns:
        The quotient of the two evaluator results (per query when requested).
    """
    numerator = e.get_relevant_retrieved_documents(per_query=per_query)
    denominator = e.get_relevant_documents(per_query=per_query)
    ratio = numerator / denominator
    # Per-query division may yield NaN entries (presumably queries missing
    # from one side after index alignment — confirm); they count as 0 recall.
    return ratio.fillna(0) if per_query else ratio
def _rp_relevant_docs(qrels_df, topics):
    """Return a dict mapping each topic in ``topics`` to its set of relevant docids.

    A judgment counts as relevant when its ``rel`` grade is > 0. Judgments for
    queries that are not in ``topics`` are ignored.
    """
    relevant = {topic: set() for topic in topics}
    judged = qrels_df[qrels_df["rel"] > 0]
    for query, docid in zip(judged["query"], judged["docid"]):
        # qrels may judge queries outside ``topics``; skip them rather than
        # raising KeyError, since only ``topics`` are plotted.
        if query in relevant:
            relevant[query].add(docid)
    return relevant


def _rp_keep_point(index):
    """Return True if the curve point at ``index`` survives thinning.

    Kept: every index up to 100, every 10th index below 1000, index 1000,
    and every 100th index above 1000.
    """
    if 100 < index < 1000:
        return index % 10 == 0
    if index > 1000:
        return index % 100 == 0
    return True


def _rp_thin(values):
    """Drop points rejected by ``_rp_keep_point`` to limit plotting memory."""
    return [value for index, value in enumerate(values) if _rp_keep_point(index)]


def _rp_interpolate(precisions):
    """Return interpolated precision: at each rank, the max precision at that rank or deeper."""
    from itertools import accumulate

    # Running max over the reversed list == suffix max of the original.
    return list(accumulate(reversed(precisions), max))[::-1]


def plot_rp_curve(qrels, topics, runs_file, results, model):
    """Plot and save a recall-precision curve for every topic.

    Args:
        qrels: Relevance judgments, passed to ``TrecEval`` together with the run.
        topics: Topic ids; ``results[i]`` holds the ranking for ``topics[i]``.
        runs_file: Path of the run file loaded with ``TrecRun``.
        results: Per-topic rankings; ``results[i][j][0]`` is the docid at rank ``j``.
        model: Sub-directory of ``eval`` the PNG files are written to
            (created if missing).

    For each topic, the interpolated precision is drawn as a red step curve and
    the raw precision as a blue dashed line, then saved as
    ``eval/<model>/R<topic>.png``. Topics with no relevant documents are
    skipped with a warning because recall is undefined for them.

    NOTE(review): the rank cutoff ``p`` is not a parameter; it resolves to a
    name defined outside this function (presumably a module-level constant) —
    confirm. Up to ``p + 1`` documents are read per topic.
    """
    import warnings

    runs = TrecRun(runs_file)
    ev = TrecEval(runs, qrels)

    relevant_docs = _rp_relevant_docs(ev.qrels.qrels_data, topics)
    # Series.items() instead of Series.iteritems(), which pandas 2.0 removed.
    num_relevant_docs = dict(ev.get_relevant_documents(per_query=True).items())

    out_dir = os.path.join("eval", model)
    os.makedirs(out_dir, exist_ok=True)

    # TrecTools' precision calculations are very slow, so precision and recall
    # at every rank are computed directly from each ranking.
    for i, topic in enumerate(topics):
        total_relevant = num_relevant_docs.get(topic, 0)
        if not total_relevant:
            warnings.warn(
                f"Topic {topic} has no relevant documents; skipping its curve."
            )
            continue

        ranking = results[i]
        relevant = relevant_docs[topic]
        # hits[k] = number of relevant documents among the top k (hits[0] = 0).
        hits = [0]
        for j in range(min(p + 1, len(ranking))):
            hits.append(hits[-1] + (1 if ranking[j][0] in relevant else 0))

        recalls = [count / total_relevant for count in hits]
        # Precision at rank 0 (nothing retrieved yet) is set to 1.
        precisions = [count / k if k > 0 else 1 for k, count in enumerate(hits)]
        interpolated = _rp_interpolate(precisions)

        recalls = _rp_thin(recalls)
        precisions = _rp_thin(precisions)
        interpolated = _rp_thin(interpolated)

        fig, ax = plt.subplots()
        # where="pre": at each recalls[j], a vertical step to interpolated[j+1],
        # then a horizontal run to recalls[j+1]. This is one artist for the
        # whole curve and includes the final segment.
        ax.step(recalls, interpolated, where="pre", color="red")
        ax.plot(recalls, precisions, "--", color="blue")
        ax.title.set_text("R" + str(topic))
        ax.set_xlabel("recall")
        ax.set_ylabel("precision")
        fig.savefig(os.path.join(out_dir, f"R{topic}.png"))
        # Close this specific figure, not whatever pyplot considers current.
        plt.close(fig)
def recall(e: "trectools.TrecEval", per_query=False):
    """Return recall: relevant-retrieved documents divided by relevant documents.

    Args:
        e: Evaluator exposing ``get_relevant_retrieved_documents`` and
            ``get_relevant_documents``, both accepting ``per_query``.
        per_query: When false (the default), return the single overall ratio.
            When true, divide the per-query results element-wise and replace
            any NaN in the quotient with 0.

    Returns:
        The overall recall, or a per-query result (presumably a pandas Series
        indexed by query — confirm) when ``per_query`` is true.

    Note:
        ``per_query`` mirrors the other ``recall`` variant in this source, so
        callers that pass it keep working whichever definition is bound.
    """
    # String annotation: ``trectools`` is not evaluated when the def executes.
    relevant_retrieved = e.get_relevant_retrieved_documents(per_query=per_query)
    relevant = e.get_relevant_documents(per_query=per_query)
    ratio = relevant_retrieved / relevant
    if per_query:
        # Queries missing from one side after alignment come out as NaN.
        return ratio.fillna(0)
    return ratio