def update_sums(metrics,
                metrics_sums,
                predicted_sequence,
                flat_sequence,
                gold_query,
                original_gold_query,
                gold_forcing=False,
                loss=None,
                token_accuracy=0.,
                database_username="",
                database_password="",
                database_timeout=0,
                gold_table=None):
    """Add one example's values to the running metric sums.

    Only the metrics listed in ``metrics`` are touched. ``metrics_sums`` is
    updated in place and nothing is returned.

    TODO: don't use sums, just keep the raw value.

    Args:
        metrics: Collection of ``Metrics`` members that should be updated.
        metrics_sums: Mapping from ``Metrics`` member to its running sum.
        predicted_sequence: Predicted tokens, compared position by position
            with ``gold_query`` when token accuracy is recomputed.
        flat_sequence: Predicted query tokens. Compared with
            ``original_gold_query`` for string accuracy and joined with
            single spaces before being executed against the database.
        gold_query: Gold tokens used for token accuracy.
        original_gold_query: Gold tokens compared against ``flat_sequence``.
        gold_forcing: When true, ``token_accuracy`` is added as given instead
            of being recomputed from the sequences.
        loss: Value added to the LOSS sum; must be supplied when LOSS is
            requested.
        token_accuracy: Precomputed accuracy used when ``gold_forcing`` is
            true.
        database_username: Passed to ``sql_util.execution_results``.
        database_password: Passed to ``sql_util.execution_results``.
        database_timeout: Passed to ``sql_util.execution_results``; must be
            greater than zero when CORRECT_TABLES is requested.
        gold_table: Expected execution result the predicted table is compared
            against.

    Raises:
        AssertionError: If CORRECT_TABLES is requested without a username,
            a password, or a positive timeout.
    """
    if Metrics.LOSS in metrics:
        metrics_sums[Metrics.LOSS] += loss

    if Metrics.TOKEN_ACCURACY in metrics:
        if gold_forcing:
            metrics_sums[Metrics.TOKEN_ACCURACY] += token_accuracy
        elif len(gold_query) > 0:
            # zip stops at the shorter sequence, so gold positions beyond the
            # end of the prediction simply count as incorrect.
            num_tokens_correct = float(sum(
                1 for token, predicted in zip(gold_query, predicted_sequence)
                if predicted == token))
            metrics_sums[Metrics.TOKEN_ACCURACY] += num_tokens_correct / \
                len(gold_query)
        else:
            # An empty gold query used to raise ZeroDivisionError. Count it as
            # fully correct only when the prediction is empty as well.
            # NOTE(review): confirm this is the desired convention.
            metrics_sums[Metrics.TOKEN_ACCURACY] += \
                1. if len(predicted_sequence) == 0 else 0.

    if Metrics.STRING_ACCURACY in metrics:
        metrics_sums[Metrics.STRING_ACCURACY] += int(
            flat_sequence == original_gold_query)

    if Metrics.CORRECT_TABLES in metrics:
        assert database_username, "You did not provide a database username"
        assert database_password, "You did not provide a database password"
        assert database_timeout > 0, "Database timeout is 0 seconds"

        # Evaluate SQL. An exact match with the gold query is not executed;
        # it is assumed to be valid and to produce the gold table.
        if flat_sequence != original_gold_query:
            syntactic, semantic, table = sql_util.execution_results(
                " ".join(flat_sequence),
                database_username,
                database_password,
                database_timeout)
        else:
            syntactic = True
            semantic = True
            table = gold_table

        tables_match = table == gold_table
        metrics_sums[Metrics.CORRECT_TABLES] += int(tables_match)

        # The metrics below depend on the execution result computed above, so
        # they are only updated when CORRECT_TABLES is requested as well.
        if Metrics.SYNTACTIC_QUERIES in metrics:
            metrics_sums[Metrics.SYNTACTIC_QUERIES] += int(syntactic)
        if Metrics.SEMANTIC_QUERIES in metrics:
            metrics_sums[Metrics.SEMANTIC_QUERIES] += int(semantic)
        if Metrics.STRICT_CORRECT_TABLES in metrics:
            metrics_sums[Metrics.STRICT_CORRECT_TABLES] += int(
                tables_match and syntactic)
def write_prediction(fileptr,
                     identifier,
                     input_seq,
                     probability,
                     prediction,
                     flat_prediction,
                     gold_query,
                     flat_gold_queries,
                     gold_tables,
                     index_in_interaction,
                     database_username,
                     database_password,
                     database_timeout,
                     compute_metrics=True):
    """Write one prediction, and optionally its metrics, as a JSON line.

    NOTE(review): this module defines ``write_prediction`` a second time
    further down; at import time that later definition replaces this one.

    Args:
        fileptr: Writable text stream; receives one JSON object plus a
            newline.
        identifier: Example identifier. ``"<database>/<interaction>"`` is
            split into its two parts; any other shape is filed under the
            ``'atis'`` database with the whole identifier as interaction id.
        input_seq: Stored as-is under ``"input_seq"``.
        probability: Stored as-is under ``"probability"``.
        prediction: Stored as-is under ``"prediction"``.
        flat_prediction: Predicted query tokens; stored as-is and, when
            metrics are computed, joined with spaces for comparison and
            execution.
        gold_query: Stored as-is under ``"gold_query"``.
        flat_gold_queries: Gold token sequences; the prediction is a string
            match if its joined form equals any of them.
        gold_tables: Gold execution results; stored as ``str(gold_tables)``
            and compared with the predicted table.
        index_in_interaction: Stored as-is under ``"index_in_interaction"``.
        database_username: Passed to ``sql_util.execution_results``.
        database_password: Passed to ``sql_util.execution_results``.
        database_timeout: Passed to ``sql_util.execution_results``.
        compute_metrics: When false, only the fields above are written.
    """
    # Split once and reuse the parts instead of splitting twice.
    id_parts = identifier.split('/')
    if len(id_parts) == 2:
        database_id, interaction_id = id_parts
    else:
        database_id = 'atis'
        interaction_id = identifier

    # Key order is kept stable because it determines the emitted JSON.
    pred_obj = {
        "identifier": identifier,
        "database_id": database_id,
        "interaction_id": interaction_id,
        "input_seq": input_seq,
        "probability": probability,
        "prediction": prediction,
        "flat_prediction": flat_prediction,
        "gold_query": gold_query,
        "flat_gold_queries": flat_gold_queries,
        "index_in_interaction": index_in_interaction,
        "gold_tables": str(gold_tables),
    }

    # Now compute the metrics we want.
    if compute_metrics:
        # First metric: whether flat predicted query is in the gold query set.
        predicted_string = " ".join(flat_prediction)
        correct_string = predicted_string in [
            " ".join(q) for q in flat_gold_queries]
        pred_obj["correct_string"] = correct_string

        # Database metrics
        if not correct_string:
            syntactic, semantic, pred_table = sql_util.execution_results(
                predicted_string,
                database_username,
                database_password,
                database_timeout)
            pred_table = sorted(pred_table)
            pred_rows = set(pred_table)

            best_prec = 0.
            best_rec = 0.
            best_f1 = 0.
            for gold_table in gold_tables:
                gold_rows = set(gold_table)
                num_overlap = float(len(pred_rows & gold_rows))

                # Precision is measured against the predicted rows and recall
                # against the gold rows (these were previously swapped). An
                # empty denominator counts as a perfect score.
                prec = num_overlap / len(pred_rows) if pred_rows else 1.
                rec = num_overlap / len(gold_rows) if gold_rows else 1.

                # F1 is zero whenever either component is zero; it used to be
                # reported as 1.0, i.e. perfect, for tables with no overlap.
                if prec > 0. and rec > 0.:
                    f1 = (2 * (prec * rec)) / (prec + rec)
                else:
                    f1 = 0.

                best_prec = max(best_prec, prec)
                best_rec = max(best_rec, rec)
                best_f1 = max(best_f1, f1)
        else:
            # An exact string match is not executed; it is scored as perfect
            # and recorded with an empty predicted table.
            syntactic = True
            semantic = True
            pred_table = []
            best_prec = 1.
            best_rec = 1.
            best_f1 = 1.

        assert best_prec <= 1.
        assert best_rec <= 1.
        assert best_f1 <= 1.

        pred_obj["syntactic"] = syntactic
        pred_obj["semantic"] = semantic
        correct_table = (pred_table in gold_tables) or correct_string
        pred_obj["correct_table"] = correct_table
        pred_obj["strict_correct_table"] = correct_table and syntactic
        pred_obj["pred_table"] = str(pred_table)
        pred_obj["table_prec"] = best_prec
        pred_obj["table_rec"] = best_rec
        pred_obj["table_f1"] = best_f1

    fileptr.write(json.dumps(pred_obj) + "\n")
def write_prediction(fileptr,
                     identifier,
                     input_seq,
                     probability,
                     prediction,
                     flat_prediction,
                     gold_query,
                     flat_gold_queries,
                     gold_tables,
                     index_in_interaction,
                     database_username,
                     database_password,
                     database_timeout,
                     compute_metrics=True,
                     input_schema=None,
                     step=1):
    """Write one post-processed prediction as a JSON line and return it.

    Every token of ``flat_prediction`` has the substrings ``'table '`` and
    ``'column '`` removed. Unless ``step`` is ``-1`` the result is then passed
    through ``t1_t2_generate`` and ``revise``; if either of them fails the
    stripped tokens are used instead.

    Args:
        fileptr: Writable text stream; receives one JSON object plus a
            newline.
        identifier: Example identifier. ``"<database>/<interaction>"`` is
            split into its two parts; any other shape is filed under the
            ``'atis'`` database with the whole identifier as interaction id.
        input_seq: Stored as-is under ``"input_seq"``.
        probability: Stored as-is under ``"probability"``.
        prediction: Stored as-is under ``"prediction"``.
        flat_prediction: Predicted query tokens (strings).
        gold_query: Stored as-is under ``"gold_query"``.
        flat_gold_queries: Gold token sequences; the prediction is a string
            match if its joined form equals any of them.
        gold_tables: Gold execution results; stored as ``str(gold_tables)``
            and compared with the predicted table.
        index_in_interaction: Stored as-is under ``"index_in_interaction"``.
        database_username: Passed to ``sql_util.execution_results``.
        database_password: Passed to ``sql_util.execution_results``.
        database_timeout: Passed to ``sql_util.execution_results``.
        compute_metrics: When false, no metric fields are written.
        input_schema: Forwarded to ``t1_t2_generate``.
        step: ``-1`` skips the ``t1_t2_generate`` / ``revise`` pass.

    Returns:
        The value stored under ``"flat_prediction"``.
    """
    # Split once and reuse the parts instead of splitting twice.
    id_parts = identifier.split('/')
    if len(id_parts) == 2:
        database_id, interaction_id = id_parts
    else:
        database_id = 'atis'
        interaction_id = identifier

    stripped_prediction = [
        token.replace('table ', '').replace('column ', '')
        for token in flat_prediction
    ]

    final_prediction = stripped_prediction
    if step != -1:
        try:
            final_prediction = revise(
                t1_t2_generate(stripped_prediction, input_schema))
        except Exception:
            # Post-processing is best effort: fall back to the stripped
            # tokens. Catching Exception rather than using a bare except lets
            # KeyboardInterrupt and SystemExit propagate.
            final_prediction = stripped_prediction

    # Key order is kept stable because it determines the emitted JSON.
    pred_obj = {
        "identifier": identifier,
        "database_id": database_id,
        "interaction_id": interaction_id,
        "input_seq": input_seq,
        "probability": probability,
        "prediction": prediction,
        "flat_prediction": final_prediction,
        "gold_query": gold_query,
        "flat_gold_queries": flat_gold_queries,
        "index_in_interaction": index_in_interaction,
        "gold_tables": str(gold_tables),
    }

    if compute_metrics:
        # NOTE(review): metrics are computed on the raw ``flat_prediction``,
        # not on the post-processed tokens stored above - confirm intended.
        predicted_string = " ".join(flat_prediction)
        correct_string = predicted_string in [
            " ".join(q) for q in flat_gold_queries]
        pred_obj["correct_string"] = correct_string

        if not correct_string:
            syntactic, semantic, pred_table = sql_util.execution_results(
                predicted_string,
                database_username,
                database_password,
                database_timeout)
            pred_table = sorted(pred_table)
            pred_rows = set(pred_table)

            best_prec = 0.
            best_rec = 0.
            best_f1 = 0.
            for gold_table in gold_tables:
                gold_rows = set(gold_table)
                num_overlap = float(len(pred_rows & gold_rows))

                # Precision is measured against the predicted rows and recall
                # against the gold rows (these were previously swapped). An
                # empty denominator counts as a perfect score.
                prec = num_overlap / len(pred_rows) if pred_rows else 1.
                rec = num_overlap / len(gold_rows) if gold_rows else 1.

                # F1 is zero whenever either component is zero; it used to be
                # reported as 1.0, i.e. perfect, for tables with no overlap.
                if prec > 0. and rec > 0.:
                    f1 = (2 * (prec * rec)) / (prec + rec)
                else:
                    f1 = 0.

                best_prec = max(best_prec, prec)
                best_rec = max(best_rec, rec)
                best_f1 = max(best_f1, f1)
        else:
            # An exact string match is not executed; it is scored as perfect
            # and recorded with an empty predicted table.
            syntactic = True
            semantic = True
            pred_table = []
            best_prec = 1.
            best_rec = 1.
            best_f1 = 1.

        assert best_prec <= 1.
        assert best_rec <= 1.
        assert best_f1 <= 1.

        pred_obj["syntactic"] = syntactic
        pred_obj["semantic"] = semantic
        correct_table = (pred_table in gold_tables) or correct_string
        pred_obj["correct_table"] = correct_table
        pred_obj["strict_correct_table"] = correct_table and syntactic
        pred_obj["pred_table"] = str(pred_table)
        pred_obj["table_prec"] = best_prec
        pred_obj["table_rec"] = best_rec
        pred_obj["table_f1"] = best_f1

    fileptr.write(json.dumps(pred_obj) + "\n")
    return pred_obj["flat_prediction"]