Ejemplo n.º 1
0
def _get_restored_predictions(predictions_dict,
                              utterance_to_db_map=None,
                              schema=None,
                              dataset_name=None,
                              use_oracle_foreign_keys=False):
    """Returns new predictions dict with FROM clauses restored.

    Each prediction is passed to
    ``abstract_sql_converters.restore_predicted_sql`` together with the table
    schemas and foreign keys of the relevant database. A prediction whose
    conversion raises ``abstract_sql.UnsupportedSqlError`` or
    ``abstract_sql.ParseError`` is dropped together with its score, after the
    query and the error message are printed.

    Args:
      predictions_dict: Dict with keys 'utterance', 'predictions' and
        'scores'. 'predictions' and 'scores' are iterated in lockstep via
        ``zip``, so any extra entries in the longer one are ignored.
      utterance_to_db_map: Optional mapping from utterance to a database
        object consumed by the ``spider_db_*`` converters. When truthy, the
        schema information comes from ``utterance_to_db_map[utterance]``.
      schema: Schema consumed by the ``michigan_db_*`` converters. Used only
        when ``utterance_to_db_map`` is falsy.
      dataset_name: Passed to the oracle foreign-key lookup.
      use_oracle_foreign_keys: If True (and ``utterance_to_db_map`` is
        falsy), foreign keys come from the oracle lookup for
        ``dataset_name`` instead of being derived from ``schema``.

    Returns:
      A new dict with the same 'utterance' and only the successfully restored
      'predictions', each paired with its original entry in 'scores'.
    """
    utterance = predictions_dict['utterance']
    if utterance_to_db_map:
        db = utterance_to_db_map[utterance]
        foreign_keys = abstract_sql_converters.spider_db_to_foreign_key_tuples(
            db)
        table_schemas = abstract_sql_converters.spider_db_to_table_tuples(db)
    else:
        if use_oracle_foreign_keys:
            # NOTE: `orcale` is the converter's actual (misspelled) name.
            foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
                dataset_name)
        else:
            foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
                schema)
        table_schemas = abstract_sql_converters.michigan_db_to_table_tuples(
            schema)

    restored_predictions = []
    restored_scores = []
    for prediction, score in zip(predictions_dict['predictions'],
                                 predictions_dict['scores']):
        # Some predictions have repeated single quotes around values.
        prediction = prediction.replace("''", "'")

        try:
            restored_prediction = abstract_sql_converters.restore_predicted_sql(
                prediction, table_schemas, foreign_keys)
        except abstract_sql.UnsupportedSqlError as e:
            # Remove predictions that fail conversion.
            print('For query %s' % prediction)
            print('Unsupport Error: ' + str(e))
        except abstract_sql.ParseError as e:
            # Report the offending query and the parser's message, matching
            # the branch above, so parse failures can actually be diagnosed.
            print('For query %s' % prediction)
            print('Parse Error: ' + str(e))
        else:
            restored_predictions.append(restored_prediction)
            restored_scores.append(score)

    return {
        'utterance': utterance,
        'predictions': restored_predictions,
        'scores': restored_scores,
    }
Ejemplo n.º 2
0
def _get_restored_predictions(
    predictions_dict: Prediction,
    utterance_to_db_map=None,
    schema=None,
    dataset_name=None,
    use_oracle_foreign_keys: bool = False,
) -> Tuple[Prediction, int]:
    """Restore FROM clauses in each prediction and count conversion failures.

    Each prediction (after collapsing doubled single quotes) is passed to
    ``abstract_sql_converters.restore_predicted_sql`` with the table schemas
    and foreign keys of the relevant database. Predictions that raise
    ``abstract_sql.UnsupportedSqlError`` or ``abstract_sql.ParseError`` are
    reported on stdout, dropped together with their score, and counted.

    Args:
        predictions_dict: Dict with "utterance", "predictions" and "scores"
            entries. Predictions and scores are paired via ``zip``.
        utterance_to_db_map: Optional mapping from utterance to a database
            object for the ``spider_db_*`` converters. It is used when truthy.
        schema: Schema for the ``michigan_db_*`` converters. It is used when
            ``utterance_to_db_map`` is falsy.
        dataset_name: Key for the oracle foreign-key lookup.
        use_oracle_foreign_keys: Take foreign keys from the oracle lookup
            instead of deriving them from ``schema`` (Michigan path only).

    Returns:
        A pair ``(restored, failures)``. ``restored`` holds the original
        utterance and the surviving predictions with their scores.
        ``failures`` is the number of predictions that were dropped.
    """
    converters = abstract_sql_converters
    utterance = predictions_dict["utterance"]

    if not utterance_to_db_map:
        # Michigan path: foreign keys are resolved before table schemas.
        foreign_keys = (
            converters.michigan_db_to_foreign_key_tuples_orcale(dataset_name)
            if use_oracle_foreign_keys
            else converters.michigan_db_to_foreign_key_tuples(schema)
        )
        table_schemas = converters.michigan_db_to_table_tuples(schema)
    else:
        spider_db = utterance_to_db_map[utterance]
        foreign_keys = converters.spider_db_to_foreign_key_tuples(spider_db)
        table_schemas = converters.spider_db_to_table_tuples(spider_db)

    kept_queries = []
    kept_scores = []
    failures = 0

    paired = zip(predictions_dict["predictions"], predictions_dict["scores"])
    for raw_query, query_score in paired:
        # Collapse the doubled single quotes some predictions put around values.
        query = raw_query.replace("''", "'")
        try:
            restored = converters.restore_predicted_sql(
                query, table_schemas, foreign_keys
            )
        except abstract_sql.UnsupportedSqlError as err:
            # Predictions that cannot be converted are discarded.
            print("For query %s" % query)
            print("Unsupport Error: " + str(err))
            failures += 1
            continue
        except abstract_sql.ParseError as err:
            print("For query %s" % query)
            print("Parse Error: " + str(err))
            failures += 1
            continue
        kept_queries.append(restored)
        kept_scores.append(query_score)

    restored_dict = {
        "utterance": utterance,
        "predictions": kept_queries,
        "scores": kept_scores,
    }
    return restored_dict, failures