def inference_wrapper(inference_fn, sharded=False):
    """Runs inference, optionally restores AbSQL output, then matches/saves.

    Inference is skipped when files matching ``FLAGS.predictions_path + '*'``
    already exist, so reruns reuse previously generated predictions.

    Args:
        inference_fn: Callable invoked as ``inference_fn(FLAGS.input,
            FLAGS.predictions_path, FLAGS.checkpoint_filepath, dataset_name)``.
        sharded: Forwarded unchanged to ``match_and_save``.

    Raises:
        ValueError: If ``FLAGS.predictions_path`` is unset; if
            ``FLAGS.restore_preds_from_asql`` is set without
            ``FLAGS.restored_predictions_path``; or if
            ``FLAGS.match_and_save`` is requested for WikiSQL.
    """
    dataset_name = FLAGS.dataset_name

    if not FLAGS.predictions_path:
        raise ValueError('Predictions path must be set.')

    # Glob pattern for the predictions output; the trailing '*' presumably
    # also covers sharded outputs (see `sharded`) -- TODO confirm.
    predictions = FLAGS.predictions_path + '*'
    # Don't run inference if predictions have already been generated.
    # Use tf.io.gfile (as below) rather than the TF1-only tf.gfile alias.
    if not tf.io.gfile.glob(predictions):
        inference_fn(FLAGS.input, FLAGS.predictions_path,
                     FLAGS.checkpoint_filepath, dataset_name)

    # If using Abstract SQL, need to restore under-specified FROM clauses
    # output above.
    if FLAGS.restore_preds_from_asql:
        spider = dataset_name.lower() == 'spider'

        if not FLAGS.restored_predictions_path:
            raise ValueError('Restored predictions path must be set '
                             'if restoring predictions from AbSQL.')

        # Only restore when the restored file does not exist yet.
        if not tf.io.gfile.exists(FLAGS.restored_predictions_path):
            restore_from_asql.restore_from_clauses(
                predictions,
                FLAGS.restored_predictions_path,
                spider_examples_json=FLAGS.spider_examples_json
                if spider else '',
                spider_tables_json=FLAGS.spider_tables_json if spider else '',
                # Michigan-style datasets need their CSV schema; Spider does
                # not (its table information comes from the JSON files above).
                michigan_schema=None if spider else read_schema(
                    os.path.join(FLAGS.data_filepath,
                                 dataset_name + '_schema.csv')),
                dataset_name=dataset_name,
                use_oracle_foriegn_keys=FLAGS.use_oracle_foriegn_keys)
        # Downstream matching reads the restored predictions instead.
        predictions = FLAGS.restored_predictions_path

    if FLAGS.match_and_save:
        # Load the database tables.
        lowered_name = dataset_name.lower()
        if lowered_name == 'spider':
            schema_obj = load_spider_tables(
                os.path.join(FLAGS.data_filepath, 'tables.json'))
        elif lowered_name == 'wikisql':
            raise ValueError('WikiSQL inference is not supported yet')
        else:
            schema_obj = read_schema(
                os.path.join(FLAGS.data_filepath,
                             dataset_name + '_schema.csv'))

        # Now match with the original data and save
        match_and_save(predictions, FLAGS.output, lowered_name,
                       FLAGS.splits, FLAGS.data_filepath, schema_obj, sharded)
def process_michigan_datasets(output_file, debugging_file, tokenizer):
  """Converts Michigan NL/SQL pairs and writes them in the standard format.

  Args:
    output_file: Writable text file; receives one JSON-encoded example per
      line.
    debugging_file: Writable text file; receives each example's original
      utterance and, when FLAGS.generate_sql is set, its gold query.
    tokenizer: Tokenizer forwarded to convert_michigan.

  Returns:
    A (created, failed) tuple, where a failure is a pair for which
    convert_michigan returned None.

  Raises:
    NotImplementedError: If FLAGS.generate_sql and FLAGS.abstract_sql are
      both set.
  """
  # TODO(alanesuhr,petershaw): Support asql for this dataset.
  if FLAGS.generate_sql and FLAGS.abstract_sql:
    raise NotImplementedError(
        'Abstract SQL currently only supported for SPIDER.')

  schema_path = os.path.join(FLAGS.input_filepath,
                             FLAGS.dataset_name + '_schema.csv')
  examples_path = os.path.join(FLAGS.input_filepath,
                               FLAGS.dataset_name + '.json')

  print('Loading from ' + examples_path)
  pairs = get_nl_sql_pairs(examples_path, set(FLAGS.splits))
  print('Loaded %d examples.' % len(pairs))

  schema = read_schema(schema_path)

  created = 0
  failed = 0
  for utterance, query in pairs:
    example = convert_michigan(utterance, query, schema, tokenizer,
                               FLAGS.generate_sql)
    if example is None:
      failed += 1
      continue

    output_file.write(json.dumps(example.to_json()) + '\n')
    created += 1

    debugging_file.write(example.model_input.original_utterance + '\n')
    if FLAGS.generate_sql:
      debugging_file.write(example.gold_query_string() + '\n\n')

  return created, failed
def compute_michigan_coverage():
    """Prints Abstract SQL round-trip statistics for a Michigan dataset.

    Every gold SQL query in ``FLAGS.splits`` is converted to SQL spans, has
    its FROM clause replaced, and then has the FROM clause restored from
    foreign-key tuples. Each outcome is printed as it happens. A summary of
    per-category counts and truncated exception messages is printed at the
    end.
    """
    data_dir = FLAGS.michigan_data_dir
    dataset = FLAGS.dataset_name

    schema = michigan_preprocessing.read_schema(
        os.path.join(data_dir, '%s-schema.csv' % dataset))
    # Foreign keys come either from the oracle lookup (keyed by dataset name)
    # or are derived from the loaded schema.
    if FLAGS.use_oracle_foriegn_keys:
        foreign_keys = (
            abstract_sql_converters.michigan_db_to_foreign_key_tuples_orcale(
                dataset))
    else:
        foreign_keys = (
            abstract_sql_converters.michigan_db_to_foreign_key_tuples(schema))
    table_schema = abstract_sql_converters.michigan_db_to_table_tuples(schema)
    nl_sql_pairs = michigan_preprocessing.get_nl_sql_pairs(
        os.path.join(data_dir, '%s.json' % dataset), FLAGS.splits)

    conversion_failures = 0
    parse_failures = 0
    reconstruction_failures = 0
    successes = 0
    exception_counts = collections.defaultdict(int)

    def record_exception(err):
        # Bucket errors by the first 100 characters of their message.
        exception_counts[str(err)[:100]] += 1

    example_total = 0
    for example_total, (_, gold_sql) in enumerate(nl_sql_pairs, start=1):
        print('Parsing example number %s.' % example_total)

        # Step 1: convert to spans and strip the FROM clause.
        try:
            spans = abstract_sql.sql_to_sql_spans(gold_sql, table_schema)
            spans = abstract_sql.replace_from_clause(spans)
        except abstract_sql.UnsupportedSqlError as err:
            print('Error converting:\n%s\n%s' % (gold_sql, err))
            conversion_failures += 1
            record_exception(err)
            continue
        except abstract_sql.ParseError as err:
            print('Error parsing:\n%s\n%s' % (gold_sql, err))
            parse_failures += 1
            record_exception(err)
            continue

        # Step 2: rebuild the FROM clause from the foreign keys.
        try:
            spans = abstract_sql.restore_from_clause(spans, foreign_keys)
        except abstract_sql.UnsupportedSqlError as err:
            print('Error recontructing:\n%s\n%s' % (gold_sql, err))
            record_exception(err)
            reconstruction_failures += 1
            continue

        print('Success:\n%s\n%s' %
              (gold_sql, abstract_sql.sql_spans_to_string(spans)))
        successes += 1

    print('exception_counts: %s' % exception_counts)
    print('Examples: %s' % example_total)
    print('Failed conversions: %s' % conversion_failures)
    print('Failed parses: %s' % parse_failures)
    print('Failed reconstructions: %s' % reconstruction_failures)
    print('Successes: %s' % successes)
Example #4
0
def load_schema_obj(dataset_name: str, data_dir: str) -> Dict[Any, Any]:
    """Loads the schema object for ``dataset_name`` from ``data_dir``.

    The dataset name is matched case-insensitively. Spider schemas are read
    from ``tables.json`` via ``load_spider_tables``. Every other dataset
    except WikiSQL is read from ``<dataset_name>_schema.csv`` via
    ``read_schema``.

    Raises:
        ValueError: If ``dataset_name`` is WikiSQL, which is unsupported.
    """
    key = dataset_name.lower()
    if key == "wikisql":
        raise ValueError("WikiSQL inference is not supported yet")
    if key == "spider":
        return load_spider_tables(os.path.join(data_dir, "tables.json"))
    return read_schema(os.path.join(data_dir, dataset_name + "_schema.csv"))
def process_michigan_datasets(output_file, debugging_file,
                              tokenizer) -> Tuple[int, int]:
    """Loads, converts, and writes Michigan examples to the standard format.

    Reads ``<dataset_name>.json`` and ``<dataset_name>_schema.csv`` from
    ``FLAGS.input_filepath``. Each NL/SQL pair from ``FLAGS.splits`` is
    converted with ``convert_michigan``, and every successful conversion is
    written to ``output_file`` as one JSON line.

    Args:
        output_file: Writable text file; receives one JSON example per line.
        debugging_file: Writable text file; receives each example's original
            utterance and, if ``FLAGS.generate_sql`` is set, its gold query.
        tokenizer: Tokenizer forwarded to ``convert_michigan``.

    Returns:
        A ``(num_examples_created, num_examples_failed)`` tuple, where a
        failure is a pair for which ``convert_michigan`` returned ``None``.
    """
    # NOTE: the earlier guard rejecting FLAGS.abstract_sql for this dataset
    # was disabled; the flag is forwarded to convert_michigan below.

    schema_csv: str = os.path.join(FLAGS.input_filepath,
                                   FLAGS.dataset_name + "_schema.csv")
    data_filepath: str = os.path.join(FLAGS.input_filepath,
                                      FLAGS.dataset_name + ".json")

    num_examples_created = 0
    num_examples_failed = 0

    print("Loading from " + data_filepath)
    paired_data = get_nl_sql_pairs(data_filepath, frozenset(FLAGS.splits))
    print("Loaded %d examples." % len(paired_data))

    schema = read_schema(schema_csv)
    # The table tuples depend only on `schema`, so derive them once instead
    # of once per example.
    # NOTE(review): assumes convert_michigan does not mutate this value.
    table_schema = abstract_sql_converters.michigan_db_to_table_tuples(schema)

    for nl, sql in paired_data:
        example = convert_michigan(
            nl,
            sql,
            schema,
            tokenizer,
            FLAGS.generate_sql,
            FLAGS.anonymize_values,
            FLAGS.abstract_sql,
            table_schema,
            FLAGS.allow_value_generation,
        )
        if example is None:
            num_examples_failed += 1
            continue

        output_file.write(json.dumps(example.to_json()) + "\n")
        num_examples_created += 1

        debugging_file.write(example.model_input.original_utterance + "\n")
        if FLAGS.generate_sql:
            debugging_file.write(example.gold_query_string() + "\n\n")

    return num_examples_created, num_examples_failed