def standardize_oia_repo(output_file_path, old_standard):
    """
    Standardize every OIA graph in the default source repo and write the
    standardized graphs into a new repo at ``output_file_path``.

    Iterates the "train", "dev" and "test" splits of the source repo,
    deep-copies each graph (so the source data is never mutated), optionally
    upgrades graphs recorded in the old annotation standard, runs the
    standardizer, and inserts the results split-by-split into the target repo.

    @param output_file_path: path of the target OIA repo to create/fill
    @type output_file_path: str
    @param old_standard: if truthy, run ``upgrade_to_new_standard`` on each
        graph before standardizing (input repo uses the old standard)
    @type old_standard: bool
    @return: None; results are persisted via ``target_oia_repo.insert``
    @rtype: NoneType
    """
    source_oia_repo = OIARepo()
    target_oia_repo = OIARepo(output_file_path)
    standardizer = OIAStandardizer()
    context = UD2OIAContext()
    options = WorkFlowOptions()
    for source in ["train", "dev", "test"]:
        standardized_graphs = []
        for oia in tqdm.tqdm(source_oia_repo.all(source),
                             "Standardize {}:".format(source)):
            origin_oia_graph = OIAGraph.parse(oia)
            uri = origin_oia_graph.meta["uri"]
            # Work on a copy so the source repo's graph stays untouched.
            oia_graph = copy.deepcopy(origin_oia_graph)
            logger.info("Standardizing {}:{}".format(source, uri))
            if old_standard:
                upgrade_to_new_standard(oia_graph)
                logger.info("Update to new standard {}:{}".format(source, uri))
            try:
                standardizer.standardize(None, oia_graph, context, options)
            except Exception:
                logger.error("Sentence {0} standardize error".format(uri))
                logger.error("Sentence = " + " ".join(oia_graph.words))
                # Bare raise re-raises with the original traceback intact.
                raise
            standardized_graphs.append(oia_graph)
        target_oia_repo.insert(source, standardized_graphs)
def _prf1(match_num, true_num, pred_num):
    """Return (recall, precision, f1) for match/true/pred counts.

    An empty ground truth (or empty prediction) counts as perfect recall
    (or precision); f1 is 1 only when both sides are empty, and 0 whenever
    there are no matches at all.
    """
    recall = float(match_num) / true_num if true_num > 0 else 1
    prec = float(match_num) / pred_num if pred_num > 0 else 1
    if true_num == 0 and pred_num == 0:
        f1 = 1
    elif match_num == 0:
        f1 = 0
    else:
        f1 = 2 * recall * prec / (recall + prec)
    return recall, prec, f1


def _safe_ratio(num, den):
    """float(num) / den, but 0.0 when den is 0 (avoids ZeroDivisionError
    in the summary report when every sample failed)."""
    return float(num) / den if den > 0 else 0.0


def _report_worst_items(item_measure, results, oia_visualizer, output_path,
                        source, metric_index, threshold, file_prefix):
    """Sort items by the metric at ``metric_index`` (1 = node f1, 2 = edge f1)
    and, for each item below ``threshold``, save a side-by-side image of the
    predicted vs. ground-truth graph and log its scores."""
    item_measure.sort(key=lambda x: x[metric_index])
    for index, node_f1, edge_f1 in item_measure:
        if (index, node_f1, edge_f1)[metric_index] < threshold:
            pred_oia_graph, ground_oia_graph = results[index]
            pred_img = oia_visualizer.visualize(pred_oia_graph,
                                                return_img=True)
            ground_img = oia_visualizer.visualize(ground_oia_graph,
                                                  return_img=True)
            image = make_pair_image(pred_img, ground_img)
            image.save(
                os.path.join(
                    output_path,
                    "{0}_{1}_{2}.png".format(file_prefix, source, index)))
            logger.info("{0}\t{1}\t{2}".format(index, node_f1, edge_f1))


def evaluate(lang, input_path, source, output_path, debug, use_ud):
    """
    Evaluate generated OIA graphs against the labeled OIA repo.

    For each labeled graph in the given split, regenerates an OIA graph
    (from UD data or from the raw sentence, per ``use_ud``), computes
    node/edge precision-recall/F1 via ``graph_match_metric``, logs aggregate
    metrics, and saves side-by-side visualizations of the worst items
    (node f1 < 0.8, edge f1 < 0.6) into ``output_path``.

    :param lang: language key used to load the UD/OIA repos
    :param input_path: path of the labeled OIA repo
    :param source: data split to evaluate ("train"/"dev"/"test")
    :param output_path: directory that receives the error-case images
    :param debug: if falsy, reduce logging to INFO on stderr
    :param use_ud: generate from UD annotation if truthy, else from the
        raw sentence
    :return: None; results are logged and images written to disk
    """
    if not debug:
        logger.remove()
        logger.add(sys.stderr, level="INFO")

    ud_repo = get_udrepo(lang)
    oia_repo = get_oiarepo(lang, input_path)

    total_node_pred_num = 0
    total_node_true_num = 0
    total_node_match_num = 0
    total_edge_pred_num = 0
    total_edge_true_num = 0
    total_edge_match_num = 0
    total_exact_same = 0
    total_graph_num = 0
    # Counts samples that could not be evaluated (missing UD data,
    # generation failure, or metric failure).
    total_empty_num = 0

    item_measure = []
    samples = getattr(ud_repo, source)
    logger.info("evaluation sample num = {0}".format(len(samples)))
    oia_visualizer = OIAGraphVisualizer()
    results = dict()

    for oia_data in oia_repo.all(source):
        ground_oia_graph = OIAGraph.parse(oia_data)
        uri = ground_oia_graph.meta['uri']

        result = ud_repo.get(int(uri), source)
        if len(result) == 0:
            logger.error("cannot find the original ud data ")
            total_empty_num += 1
            continue

        index, sentence, ud_data = result[0]
        logger.info("evaluating uri {0}".format(index))
        try:
            if use_ud:
                pred_oia_graph, _ = ud2oia(lang, index, ud_data, debug=debug)
            else:
                pred_oia_graph, _ = sentence2oia(lang, index, sentence,
                                                 debug=debug)
        except Exception:
            logger.error("error in generating oia for {0}_{1}".format(
                source, index))
            traceback.print_exc()
            total_empty_num += 1
            continue

        results[index] = (pred_oia_graph, ground_oia_graph)

        try:
            eval_result = graph_match_metric(pred_oia_graph, ground_oia_graph)
        except Exception:
            logger.error("error in evaluating {0}_{1}".format(source, index))
            traceback.print_exc()
            total_empty_num += 1
            continue

        ((node_pred_num, node_true_num, node_match_num),
         (edge_pred_num, edge_true_num, edge_match_num),
         exact_same) = eval_result

        _, _, node_f1 = _prf1(node_match_num, node_true_num, node_pred_num)
        _, _, edge_f1 = _prf1(edge_match_num, edge_true_num, edge_pred_num)

        item_measure.append((index, node_f1, edge_f1))

        total_node_pred_num += node_pred_num
        total_node_true_num += node_true_num
        total_node_match_num += node_match_num
        total_edge_pred_num += edge_pred_num
        total_edge_true_num += edge_true_num
        total_edge_match_num += edge_match_num
        total_exact_same += exact_same
        total_graph_num += 1

    logger.info("Evaluation Results:")
    logger.info("Error Graph Num = {0}".format(total_empty_num))
    logger.info("Type\tTrueNum\tPredNum\tMatchNum\tRecall\tPrecision")
    logger.info("Node\t{0}\t{1}\t{2}\t{3}\t{4}".format(
        total_node_true_num, total_node_pred_num, total_node_match_num,
        _safe_ratio(total_node_match_num, total_node_true_num),
        _safe_ratio(total_node_match_num, total_node_pred_num),
    ))
    logger.info("Edge\t{0}\t{1}\t{2}\t{3}\t{4}".format(
        total_edge_true_num, total_edge_pred_num, total_edge_match_num,
        _safe_ratio(total_edge_match_num, total_edge_true_num),
        _safe_ratio(total_edge_match_num, total_edge_pred_num),
    ))
    logger.info("Graph\t{0}\t{1}\t{2}\t{3}\t{4}".format(
        total_graph_num, total_graph_num, total_exact_same,
        _safe_ratio(total_exact_same, total_graph_num),
        _safe_ratio(total_exact_same, total_graph_num),
    ))

    logger.info("Worst Items by node f1")
    _report_worst_items(item_measure, results, oia_visualizer, output_path,
                        source, 1, 0.8, "node_error")

    logger.info("Worst Items by edge f1")
    _report_worst_items(item_measure, results, oia_visualizer, output_path,
                        source, 2, 0.6, "label_error")