def compute_michigan_coverage():
    """Prints statistics for asql conversions of a Michigan dataset.

    For every gold SQL query of ``FLAGS.dataset_name`` in ``FLAGS.splits``:
    the query is converted to SQL spans, its FROM clause is replaced, and the
    FROM clause is then restored from foreign keys. Each outcome is printed as
    it happens. At the end, the function prints the exception buckets and the
    counts of examples, conversion failures, parse failures, reconstruction
    failures and successes.
    """
    # Read data files.
    schema_csv_path = os.path.join(FLAGS.michigan_data_dir,
                                   '%s-schema.csv' % FLAGS.dataset_name)
    examples_json_path = os.path.join(FLAGS.michigan_data_dir,
                                      '%s.json' % FLAGS.dataset_name)
    schema = michigan_preprocessing.read_schema(schema_csv_path)
    if FLAGS.use_oracle_foriegn_keys:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
            FLAGS.dataset_name)
    else:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
            schema)
    table_schema = abstract_sql_converters.michigan_db_to_table_tuples(schema)
    nl_sql_pairs = michigan_preprocessing.get_nl_sql_pairs(
        examples_json_path, FLAGS.splits)

    # Iterate through examples and generate counts.
    num_examples = 0
    num_conversion_failures = 0
    num_successes = 0
    num_parse_failures = 0
    num_reconstruction_failures = 0
    exception_counts = collections.defaultdict(int)

    def _record_exception(exc):
        """Counts `exc` under the first 100 characters of its message."""
        # Truncation groups exceptions whose messages differ only in a long
        # tail (e.g. the offending query text) into one bucket.
        exception_counts[str(exc)[:100]] += 1

    for _, gold_sql_query in nl_sql_pairs:
        num_examples += 1
        print('Parsing example number %s.' % num_examples)
        try:
            sql_spans = abstract_sql.sql_to_sql_spans(gold_sql_query,
                                                      table_schema)
            sql_spans = abstract_sql.replace_from_clause(sql_spans)
        except abstract_sql.UnsupportedSqlError as e:
            print('Error converting:\n%s\n%s' % (gold_sql_query, e))
            num_conversion_failures += 1
            _record_exception(e)
            continue
        except abstract_sql.ParseError as e:
            print('Error parsing:\n%s\n%s' % (gold_sql_query, e))
            num_parse_failures += 1
            _record_exception(e)
            continue
        try:
            sql_spans = abstract_sql.restore_from_clause(
                sql_spans, foreign_keys)
        except abstract_sql.UnsupportedSqlError as e:
            print('Error reconstructing:\n%s\n%s' % (gold_sql_query, e))
            _record_exception(e)
            num_reconstruction_failures += 1
            continue
        print('Success:\n%s\n%s' %
              (gold_sql_query, abstract_sql.sql_spans_to_string(sql_spans)))
        num_successes += 1
    print('exception_counts: %s' % exception_counts)
    print('Examples: %s' % num_examples)
    print('Failed conversions: %s' % num_conversion_failures)
    print('Failed parses: %s' % num_parse_failures)
    print('Failed reconstructions: %s' % num_reconstruction_failures)
    print('Successes: %s' % num_successes)
# Example #2
# 0
def _get_restored_predictions(predictions_dict,
                              utterance_to_db_map=None,
                              schema=None,
                              dataset_name=None,
                              use_oracle_foreign_keys=False):
    """Returns new predictions dict with FROM clauses restored.

    Predictions whose FROM clause cannot be restored are dropped, together
    with their scores. The failing query and the error message are printed.

    Args:
      predictions_dict: Dict with keys 'utterance', 'predictions' and
        'scores'; predictions and scores are paired positionally.
      utterance_to_db_map: Optional mapping from utterance to a Spider db. When
        truthy, Spider foreign keys and table schemas are used.
      schema: Michigan schema, used when `utterance_to_db_map` is falsy.
      dataset_name: Michigan dataset name; only used for oracle foreign keys.
      use_oracle_foreign_keys: If True (Michigan path only), use the oracle
        foreign keys for `dataset_name` instead of deriving them from `schema`.

    Returns:
      A dict with the same 'utterance' and only the successfully restored
      'predictions' and their matching 'scores'.
    """
    utterance = predictions_dict['utterance']
    if utterance_to_db_map:
        db = utterance_to_db_map[utterance]
        foreign_keys = abstract_sql_converters.spider_db_to_foreign_key_tuples(
            db)
        table_schemas = abstract_sql_converters.spider_db_to_table_tuples(db)

    else:
        if use_oracle_foreign_keys:
            foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
                dataset_name)
        else:
            foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
                schema)
        table_schemas = abstract_sql_converters.michigan_db_to_table_tuples(
            schema)

    restored_predictions = []
    restored_scores = []
    for prediction, score in zip(predictions_dict['predictions'],
                                 predictions_dict['scores']):
        # Some predictions have repeated single quotes around values.
        prediction = prediction.replace("''", "'")

        try:
            restored_prediction = abstract_sql_converters.restore_predicted_sql(
                prediction, table_schemas, foreign_keys)
        except abstract_sql.UnsupportedSqlError as e:
            # Remove predictions that fail conversion.
            print('For query %s' % prediction)
            print('Unsupport Error: ' + str(e))
        except abstract_sql.ParseError as e:
            # Remove predictions that fail to parse. Report the query and the
            # parser's message, as for unsupported SQL, so failures can be
            # diagnosed instead of being silently dropped.
            print('For query %s' % prediction)
            print('Parse Error: ' + str(e))
        else:
            restored_predictions.append(restored_prediction)
            restored_scores.append(score)
    restored_predictions_dict = {
        'utterance': utterance,
        'predictions': restored_predictions,
        'scores': restored_scores,
    }
    return restored_predictions_dict
def process_michigan_datasets(output_file, debugging_file,
                              tokenizer) -> Tuple[int, int]:
    """Loads, converts, and writes Michigan examples to the standard format.

    Args:
      output_file: Writable text file; receives one JSON-encoded example per
        line.
      debugging_file: Writable text file; receives each converted example's
        original utterance and, when ``FLAGS.generate_sql`` is set, its gold
        query string.
      tokenizer: Passed through unchanged to ``convert_michigan``.

    Returns:
      A ``(num_examples_created, num_examples_failed)`` tuple. An example
      counts as failed when ``convert_michigan`` returns ``None`` for it.
    """
    # TODO(alanesuhr,petershaw): Support asql for this dataset.
    # if FLAGS.generate_sql and FLAGS.abstract_sql:
    #     raise NotImplementedError("Abstract SQL currently only supported for SPIDER.")

    schema_csv: str = os.path.join(FLAGS.input_filepath,
                                   FLAGS.dataset_name + "_schema.csv")
    data_filepath: str = os.path.join(FLAGS.input_filepath,
                                      FLAGS.dataset_name + ".json")

    # Don't actually provide table entities.
    num_examples_created = 0
    num_examples_failed = 0

    print("Loading from " + data_filepath)
    paired_data = get_nl_sql_pairs(data_filepath, frozenset(FLAGS.splits))
    print("Loaded %d examples." % len(paired_data))

    schema = read_schema(schema_csv)
    # The only input to this conversion is `schema`, which does not change
    # inside the loop, so compute the table tuples once instead of once per
    # example (assumes the converter has no side effects -- TODO confirm).
    table_tuples = abstract_sql_converters.michigan_db_to_table_tuples(schema)

    for nl, sql in paired_data:
        example = convert_michigan(
            nl,
            sql,
            schema,
            tokenizer,
            FLAGS.generate_sql,
            FLAGS.anonymize_values,
            FLAGS.abstract_sql,
            table_tuples,
            FLAGS.allow_value_generation,
        )
        if example is None:
            num_examples_failed += 1
            continue

        output_file.write(json.dumps(example.to_json()) + "\n")
        num_examples_created += 1

        debugging_file.write(example.model_input.original_utterance + "\n")
        if FLAGS.generate_sql:
            debugging_file.write(example.gold_query_string() + "\n\n")

    return num_examples_created, num_examples_failed
def _get_restored_predictions(
    predictions_dict: Prediction,
    utterance_to_db_map=None,
    schema=None,
    dataset_name=None,
    use_oracle_foreign_keys: bool = False,
) -> Tuple[Prediction, int]:
    """Restores the FROM clause of every prediction in `predictions_dict`.

    A prediction whose FROM clause cannot be restored is dropped together with
    its score. The failing query and error message are printed, and the
    failure is counted.

    Args:
      predictions_dict: Mapping with keys "utterance", "predictions" and
        "scores"; predictions and scores are paired positionally.
      utterance_to_db_map: Optional mapping from utterance to a Spider db.
        When truthy, Spider foreign keys and table schemas are used.
      schema: Michigan schema, used when `utterance_to_db_map` is falsy.
      dataset_name: Michigan dataset name; only used for oracle foreign keys.
      use_oracle_foreign_keys: On the Michigan path, use the oracle foreign
        keys for `dataset_name` instead of deriving them from `schema`.

    Returns:
      A pair of (new predictions dict with the same keys holding only the
      restored predictions and their scores, number of dropped predictions).
    """
    utterance = predictions_dict["utterance"]

    if utterance_to_db_map:
        spider_db = utterance_to_db_map[utterance]
        foreign_keys = abstract_sql_converters.spider_db_to_foreign_key_tuples(spider_db)
        table_schemas = abstract_sql_converters.spider_db_to_table_tuples(spider_db)
    elif use_oracle_foreign_keys:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
            dataset_name
        )
        table_schemas = abstract_sql_converters.michigan_db_to_table_tuples(schema)
    else:
        foreign_keys = abstract_sql_converters.michigan_db_to_foreign_key_tuples(
            schema
        )
        table_schemas = abstract_sql_converters.michigan_db_to_table_tuples(schema)

    kept_predictions = []
    kept_scores = []
    num_failures = 0

    scored_predictions = zip(predictions_dict["predictions"], predictions_dict["scores"])
    for raw_prediction, score in scored_predictions:
        # Collapse the doubled single quotes some predictions put around values.
        cleaned = raw_prediction.replace("''", "'")

        try:
            restored = abstract_sql_converters.restore_predicted_sql(
                cleaned, table_schemas, foreign_keys
            )
        except abstract_sql.UnsupportedSqlError as err:
            # Drop predictions that fail conversion.
            print("For query %s" % cleaned)
            print("Unsupport Error: " + str(err))
            num_failures += 1
            continue
        except abstract_sql.ParseError as err:
            print("For query %s" % cleaned)
            print("Parse Error: " + str(err))
            num_failures += 1
            continue

        kept_predictions.append(restored)
        kept_scores.append(score)

    restored_dict = {
        "utterance": utterance,
        "predictions": kept_predictions,
        "scores": kept_scores,
    }
    return restored_dict, num_failures