Example #1
def main(unused_argv):
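  """Apply NQG tokenization to each (source, target) pair and write a new TSV."""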
  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    new_examples.append((nqg_tokenization.process_source(source),
                         nqg_tokenization.process_target(target)))
  tsv_utils.write_tsv(new_examples, FLAGS.output)
Example #2
def main(unused_argv):
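  """Report how many gold targets are reachable by parsing with the QCFG rules."""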
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = qcfg_file.read_rules(FLAGS.rules)
  print("Rules: %s" % rules)

  num_examples = 0
  num_covered = 0

  for idx, example in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    print("Processing example %s." % idx)
    print("Source: %s" % example[0])
    print("Target: %s" % example[1])

    source = example[0]
    gold_target = example[1]

    can_parse = qcfg_parser.can_parse(source, gold_target, rules, verbose=False)

    num_examples += 1

    if can_parse:
      num_covered += 1
    else:
      print("Output set does not contain gold target.")

  print("%s covered out of %s" % (num_covered, num_examples))
Example #3
def main(unused_argv):
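    """Shuffle examples with a fixed seed and split them into two TSV files."""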
    examples = tsv_utils.read_tsv(FLAGS.input)
    random.seed(FLAGS.seed)
    random.shuffle(examples)
    examples_1 = examples[:FLAGS.num_examples_1]
    examples_2 = examples[FLAGS.num_examples_1:]
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #4
def main(unused_argv):
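    """Check whether two TSV files are identical, reporting the first mismatch."""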
    examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
    examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
    if examples_1 == examples_2:
        print("Examples are the same.")
    else:
        print("Examples are different.")
        if len(examples_1) != len(examples_2):
            print("Number of examples is different.")
        else:
            for idx, (example_1,
                      example_2) in enumerate(zip(examples_1, examples_2)):
                if example_1 != example_2:
                    print("First different example pair at idx %s:" % idx)
                    print(example_1)
                    print(example_2)
                    break
Example #5
def main(unused_argv):
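    """Split examples into two sets by template, using spider_template_fn."""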
    examples = tsv_utils.read_tsv(FLAGS.input)
    examples_1, examples_2 = template_utils.split_by_template(
        examples,
        template_fn=spider_template_fn,
        max_num_examples_1=FLAGS.max_num_examples_1,
        seed=FLAGS.seed)
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #6
def main(unused_argv):
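    """Sort examples by source (or target) token length and split at num_examples."""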
    examples = tsv_utils.read_tsv(FLAGS.input)
    if FLAGS.use_target:
        sorted_examples = sorted(examples, key=lambda x: len(x[1].split(" ")))
    else:
        sorted_examples = sorted(examples, key=lambda x: len(x[0].split(" ")))
    examples_1 = sorted_examples[:FLAGS.num_examples]
    examples_2 = sorted_examples[FLAGS.num_examples:]
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #7
def main(unused_argv):
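    """Count examples in the second set containing an atom unseen in the first set."""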
    examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
    examples_2 = tsv_utils.read_tsv(FLAGS.input_2)

    atoms_1 = mcd_utils.get_all_atoms(
        examples_1, get_atoms_fn=tmcd_utils.get_example_atoms)

    num_examples = 0
    num_examples_with_unseen_atom = 0
    for example in examples_2:
        atoms = tmcd_utils.get_example_atoms(example)
        num_examples += 1
        for atom in atoms:
            if atom not in atoms_1:
                print("New atom: %s" % atom)
                num_examples_with_unseen_atom += 1
                break

    print("num_examples: %s" % num_examples)
    print("num_examples_with_unseen_atom: %s" % num_examples_with_unseen_atom)
    print("pct: %s" % (float(num_examples_with_unseen_atom) / num_examples))
Example #8
def main(unused_argv):
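    """Write each target and its original-cased database id to a tab-separated file."""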
    formatted_db_id_to_db_id = {}
    for db_id in database_constants.DATABASES:
        formatted_db_id_to_db_id[db_id.lower()] = db_id
        formatted_db_id_to_db_id[db_id] = db_id

    examples = tsv_utils.read_tsv(FLAGS.input)
    with gfile.GFile(FLAGS.output, "w") as txt_file:
        for example in examples:
            db_id = example[0].split()[0].rstrip(":")
            db_id = formatted_db_id_to_db_id[db_id]
            txt_file.write("%s\t%s\n" % (example[1], db_id))
Example #9
def main(unused_argv):
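    """Write examples into per-split TSV files according to predefined split ids."""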
    splits = load_splits()
    examples = tsv_utils.read_tsv(FLAGS.input)
    example_id_to_example = {
        example_id: example
        for example_id, example in enumerate(examples)
    }

    for split, split_ids in splits.items():
        examples = []
        for split_id in split_ids:
            examples.append(example_id_to_example[split_id])
        filename = os.path.join(FLAGS.output_dir, "%s.tsv" % split)
        tsv_utils.write_tsv(examples, filename)
Example #10
def main(unused_argv):
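  """Check that every target parses as SQL, re-parsing verbosely on failure."""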
  examples = tsv_utils.read_tsv(FLAGS.input)
  for idx, (_, target) in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    print("Processing example %s." % idx)

    try:
      _ = sql_parser.parse_sql(target)
    except ValueError as e:
      print(e)
      # Retry parsing with verbose debugging.
      _ = sql_parser.parse_sql(target, verbose=True)
Example #11
def main(unused_argv):
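    """Append the serialized database schema to each source and lowercase examples."""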
    tables_json = load_json(FLAGS.tables)
    db_id_to_schema_string = {}
    for table_json in tables_json:
        db_id = table_json["db_id"].lower()
        db_id_to_schema_string[db_id] = _get_schema_string(table_json)

    examples = tsv_utils.read_tsv(FLAGS.input)
    new_examples = []
    for source, target in examples:
        db_id = source.split()[0].rstrip(":")
        schema_string = db_id_to_schema_string[db_id]
        new_source = "%s%s" % (source, schema_string)
        new_examples.append((new_source.lower(), target.lower()))
    tsv_utils.write_tsv(new_examples, FLAGS.output)
Example #12
def induce_and_write_rules():
    """Induce and write set of rules."""
    examples = tsv_utils.read_tsv(FLAGS.input)
    config = induction_utils.InductionConfig(
        sample_size=FLAGS.sample_size,
        max_iterations=FLAGS.max_iterations,
        min_delta=FLAGS.min_delta,
        terminal_codelength=FLAGS.terminal_codelength,
        non_terminal_codelength=FLAGS.non_terminal_codelength,
        parse_sample=FLAGS.parse_sample,
        allow_repeated_target_nts=FLAGS.allow_repeated_target_nts,
        seed_exact_match=FLAGS.seed_exact_match,
        balance_parens=FLAGS.balance_parens,
    )
    induced_rules = induction_utils.induce_rules(examples, config)
    qcfg_file.write_rules(induced_rules, FLAGS.output)
Example #13
def main(unused_argv):
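    """Split examples randomly, then swap to maximize compound divergence (TMCD)."""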
    examples = tsv_utils.read_tsv(FLAGS.input)

    # First, randomly split examples.
    random.seed(FLAGS.seed)
    random.shuffle(examples)
    examples_1 = examples[:FLAGS.num_examples_1]
    examples_2 = examples[FLAGS.num_examples_1:]

    # Swap examples to meet atom constraint and maximize compound divergence.
    examples_1, examples_2 = mcd_utils.swap_examples(
        examples_1,
        examples_2,
        get_compounds_fn=tmcd_utils.get_example_compounds,
        get_atoms_fn=tmcd_utils.get_example_atoms,
        max_iterations=1000,
        max_divergence=None)
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #14
def main(unused_argv):
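  """Compute exact-match accuracy of predictions against gold targets."""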
  gold_examples = tsv_utils.read_tsv(FLAGS.gold)

  preds = []
  with gfile.GFile(FLAGS.predictions, "r") as f:
    for line in f:
      preds.append(line.rstrip())

  correct = 0
  incorrect = 0
  for pred, gold_example in zip(preds, gold_examples):
    if pred == gold_example[1]:
      correct += 1
    else:
      incorrect += 1
      print("Incorrect for example %s.\nTarget: %s\nPrediction: %s" %
            (gold_example[0], gold_example[1], pred))

  print("correct: %s" % correct)
  print("incorrect: %s" % incorrect)
  print("pct: %s" % (float(correct) / float(correct + incorrect)))
Example #15
def main(unused_argv):
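  """Convert examples to tf.Example protos and write them to a TFRecord file."""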
  config = config_utils.json_file_to_dict(FLAGS.config)
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = qcfg_file.read_rules(FLAGS.rules)
  tokenizer = tokenization_utils.get_tokenizer(
      os.path.join(FLAGS.bert_dir, "vocab.txt"))
  converter = example_converter.ExampleConverter(rules, tokenizer, config)

  total_written = 0
  writer = tf.io.TFRecordWriter(FLAGS.output)
  for idx, example in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    print("Processing example %s." % idx)

    tf_example = converter.convert(example)
    writer.write(tf_example.SerializeToString())
    total_written += 1

  # Close the writer to flush any buffered records before reporting.
  writer.close()
  converter.print_max_sizes()
  print("Wrote %d examples." % total_written)
Example #16
def main(unused_argv):
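    """Run inference once, or poll for new checkpoints until training completes."""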
    config = config_utils.json_file_to_dict(FLAGS.config)
    wrapper = get_inference_wrapper(config)
    examples = tsv_utils.read_tsv(FLAGS.input)
    writer = get_summary_writer()

    if FLAGS.poll:
        last_checkpoint = None
        while True:
            checkpoint, step = get_checkpoint()
            if checkpoint == last_checkpoint:
                print("Waiting for new checkpoint...\nLast checkpoint: %s" %
                      last_checkpoint)
            else:
                run_inference(writer, wrapper, examples, checkpoint, step=step)
                last_checkpoint = checkpoint
            if step and step >= config["training_steps"]:
                # Stop eval job after completing eval for last training step.
                break
            time.sleep(10)
    else:
        checkpoint, _ = get_checkpoint()
        run_inference(writer, wrapper, examples, checkpoint)
Example #17
def main(unused_argv):
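    """Write each source, prepended with FLAGS.prefix, to a text file."""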
    examples = tsv_utils.read_tsv(FLAGS.input)
    with gfile.GFile(FLAGS.output, "w") as txt_file:
        for example in examples:
            txt_file.write("%s%s\n" % (FLAGS.prefix, example[0]))
Example #18
def main(unused_argv):
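  """Measure compound divergence between two sets of examples."""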
  examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
  examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
  divergence = mcd_utils.measure_example_divergence(
      examples_1, examples_2, get_compounds_fn=tmcd_utils.get_example_compounds)
  print("Compound divergence: %s" % divergence)