def main(unused_argv):
    """Rewrite the source column of every example with `format_source`.

    Reads (source, target) rows from FLAGS.input, passes each source through
    `string_utils.format_source`, leaves the target untouched, and writes the
    resulting rows to FLAGS.output.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    formatted = [(string_utils.format_source(src), tgt) for src, tgt in rows]
    tsv_utils.write_tsv(formatted, FLAGS.output)
示例#2
0
def main(unused_argv):
    """Tokenize both columns of every example with the NQG tokenizers.

    Each (source, target) row from FLAGS.input becomes
    (process_source(source), process_target(target)); the result is written
    to FLAGS.output.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    tokenized = [
        (nqg_tokenization.process_source(src),
         nqg_tokenization.process_target(tgt))
        for src, tgt in rows
    ]
    tsv_utils.write_tsv(tokenized, FLAGS.output)
示例#3
0
def main(unused_argv):
    """Randomly split the input examples into two TSV files.

    Seeds the global `random` module with FLAGS.seed and shuffles the rows.
    The first FLAGS.num_examples_1 rows go to FLAGS.output_1 and the remainder
    go to FLAGS.output_2.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    random.seed(FLAGS.seed)
    random.shuffle(rows)
    cutoff = FLAGS.num_examples_1
    head, tail = rows[:cutoff], rows[cutoff:]
    tsv_utils.write_tsv(head, FLAGS.output_1)
    tsv_utils.write_tsv(tail, FLAGS.output_2)
示例#4
0
def main(unused_argv):
    """Split examples into two sets by template, using `funql_template_fn`.

    Delegates the split to `template_utils.split_by_template`, capped at
    FLAGS.max_num_examples_1 for the first set and seeded with FLAGS.seed.
    The two sets are written to FLAGS.output_1 and FLAGS.output_2.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    first, second = template_utils.split_by_template(
        rows,
        template_fn=funql_template_fn,
        max_num_examples_1=FLAGS.max_num_examples_1,
        seed=FLAGS.seed)
    for subset, path in ((first, FLAGS.output_1), (second, FLAGS.output_2)):
        tsv_utils.write_tsv(subset, path)
def main(unused_argv):
    """Split examples by length; the FLAGS.num_examples shortest go first.

    Length is the number of pieces produced by splitting on a single space,
    measured on the target column when FLAGS.use_target is set and on the
    source column otherwise. `sorted` is stable, so ties keep input order.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    column = 1 if FLAGS.use_target else 0

    def _num_tokens(row):
        return len(row[column].split(" "))

    ordered = sorted(rows, key=_num_tokens)
    cutoff = FLAGS.num_examples
    shortest, longest = ordered[:cutoff], ordered[cutoff:]
    tsv_utils.write_tsv(shortest, FLAGS.output_1)
    tsv_utils.write_tsv(longest, FLAGS.output_2)
def main(unused_argv):
    """Keep only examples that pass both `check_person_name` and `check_year`.

    `check_year` is only consulted when `check_person_name` passes. Surviving
    rows are written to FLAGS.output, then the original, kept, and dropped
    counts are printed.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    kept = [
        (src, tgt)
        for src, tgt in rows
        if check_person_name(src, tgt) and check_year(src, tgt)
    ]
    tsv_utils.write_tsv(kept, FLAGS.output)

    total, remaining = len(rows), len(kept)
    print("original examples: %d." % total)
    print("new examples: %d." % remaining)
    print("filtered examples: %d." % (total - remaining))
def main(_):
  """Rewrite COGS targets into function-call form.

  Reads three-column rows (source, target, category) from FLAGS.input.
  - "primitive" rows must have a single whitespace-separated source token,
    and that token becomes the new target.
  - All other rows have their target converted with
    `cogs_converter.cogs_lf_to_funcall`.

  Writes the (source, new_target) pairs to FLAGS.output.

  Raises:
    ValueError: if a primitive row's source is not exactly one token.
  """
  rows = tsv_utils.read_tsv(FLAGS.input, expected_num_columns=3)
  converted = []
  for source, target, category in rows:
    if category != "primitive":
      converted.append((source, cogs_converter.cogs_lf_to_funcall(target)))
      continue
    if len(source.split()) != 1:
      raise ValueError(f"Invalid primitive: {source}")
    converted.append((source, source))
  tsv_utils.write_tsv(converted, FLAGS.output)
def main(unused_argv):
    """Partition examples by whether `is_cross_domain(target)` is truthy.

    Prints the size of each partition, then writes cross-domain rows to
    FLAGS.output_cross and the rest to FLAGS.output_single.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)
    cross, single = [], []
    for src, tgt in rows:
        bucket = cross if is_cross_domain(tgt) else single
        bucket.append((src, tgt))

    print("len(cross_domain_examples): %s" % len(cross))
    print("len(single_domain_examples): %s" % len(single))
    tsv_utils.write_tsv(cross, FLAGS.output_cross)
    tsv_utils.write_tsv(single, FLAGS.output_single)
示例#9
0
def main(unused_argv):
    """Write one TSV per split, selecting rows by their 0-based position.

    `load_splits()` is iterated via `.items()` as split name -> example ids.
    Each id indexes a row of FLAGS.input by its position, and an unknown id
    raises KeyError. Each split is written to
    `<FLAGS.output_dir>/<split>.tsv`.
    """
    splits = load_splits()
    all_rows = tsv_utils.read_tsv(FLAGS.input)
    row_by_id = dict(enumerate(all_rows))

    for split_name, split_ids in splits.items():
        selected = [row_by_id[split_id] for split_id in split_ids]
        output_path = os.path.join(FLAGS.output_dir, "%s.tsv" % split_name)
        tsv_utils.write_tsv(selected, output_path)
示例#10
0
def main(unused_argv):
  """Append each example's database schema string to its source.

  Builds a map from lower-cased `db_id` (read from FLAGS.tables) to the
  string returned by `_get_schema_string`. For every (source, target) row in
  FLAGS.input:
  - the database id is the first whitespace-separated token of the source,
    with a trailing ":" stripped;
  - the schema string is concatenated directly onto the source;
  - both columns are lower-cased.

  The result is written to FLAGS.output.

  Raises:
    KeyError: if a source's database id has no entry in FLAGS.tables.
  """
  tables_json = load_json(FLAGS.tables)
  db_id_to_schema_string = {}
  for table_json in tables_json:
    db_id = table_json["db_id"].lower()
    db_id_to_schema_string[db_id] = _get_schema_string(table_json)

  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    # The map keys are lower-cased above, so normalize the lookup key the
    # same way. Otherwise a source such as "Concert_Singer: ..." raises
    # KeyError even though its schema was loaded.
    db_id = source.split()[0].rstrip(":").lower()
    schema_string = db_id_to_schema_string[db_id]
    new_source = "%s%s" % (source, schema_string)
    new_examples.append((new_source.lower(), target.lower()))
  tsv_utils.write_tsv(new_examples, FLAGS.output)
示例#11
0
def main(unused_argv):
    """Convert JSON examples with `db_id`/`question`/`query` fields to TSV.

    The source becomes "<db_id>: <question>" and the target is the query
    after `normalize_whitespace`; both are lower-cased. When
    FLAGS.filter_by_database is set, examples whose database is not in
    `database_constants.DATABASES` are skipped.
    """
    rows = []
    for record in load_json(FLAGS.examples):
        database = record["db_id"]
        question = record["question"]
        query = record["query"]

        # Optionally keep only databases in the >= 50-examples set.
        if FLAGS.filter_by_database and database not in database_constants.DATABASES:
            continue

        prefixed = "%s: %s" % (database, question)
        rows.append((prefixed.lower(), normalize_whitespace(query).lower()))

    tsv_utils.write_tsv(rows, FLAGS.output)
示例#12
0
def main(unused_argv):
    """Random split refined by MCD-style example swapping.

    1. Shuffle the rows with the global `random` module seeded by FLAGS.seed.
    2. Take the first FLAGS.num_examples_1 rows as the first split and the
       rest as the second.
    3. Refine the splits with `mcd_utils.swap_examples` (at most 10000
       iterations, `max_divergence=None`), getting atoms and compounds from
       an `AtomAndCompoundCache`.
    4. Write the results to FLAGS.output_1 and FLAGS.output_2.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)

    # Random initial partition.
    random.seed(FLAGS.seed)
    random.shuffle(rows)
    cutoff = FLAGS.num_examples_1
    initial_1, initial_2 = rows[:cutoff], rows[cutoff:]

    cache = AtomAndCompoundCache()

    # Swap examples to meet the atom constraint and maximize compound
    # divergence.
    final_1, final_2 = mcd_utils.swap_examples(
        initial_1,
        initial_2,
        get_compounds_fn=cache.get_compounds,
        get_atoms_fn=cache.get_atoms,
        max_iterations=10000,
        max_divergence=None)

    tsv_utils.write_tsv(final_1, FLAGS.output_1)
    tsv_utils.write_tsv(final_2, FLAGS.output_2)
示例#13
0
def main(unused_argv):
    """Random split refined by TMCD-style example swapping.

    1. Shuffle the rows with FLAGS.seed and take the first
       FLAGS.num_examples_1 as the first split.
    2. Refine the splits with `mcd_utils.swap_examples` (at most 1000
       iterations, `max_divergence=None`, FLAGS.min_atom_count), using
       `tmcd_utils` atom/compound extractors. Atoms include argument counts
       when FLAGS.get_atoms_with_num_arguments is set.
    3. Write the results to FLAGS.output_1 and FLAGS.output_2.
    """
    rows = tsv_utils.read_tsv(FLAGS.input)

    # Random initial partition.
    random.seed(FLAGS.seed)
    random.shuffle(rows)
    cutoff = FLAGS.num_examples_1
    initial_1, initial_2 = rows[:cutoff], rows[cutoff:]

    if FLAGS.get_atoms_with_num_arguments:
        atoms_fn = tmcd_utils.get_example_atoms_with_num_arguments
    else:
        atoms_fn = tmcd_utils.get_example_atoms

    # Swap examples to meet the atom constraint and maximize compound
    # divergence.
    final_1, final_2 = mcd_utils.swap_examples(
        initial_1,
        initial_2,
        get_compounds_fn=tmcd_utils.get_example_compounds,
        get_atoms_fn=atoms_fn,
        max_iterations=1000,
        max_divergence=None,
        min_atom_count=FLAGS.min_atom_count)

    tsv_utils.write_tsv(final_1, FLAGS.output_1)
    tsv_utils.write_tsv(final_2, FLAGS.output_2)
示例#14
0
def main(unused_argv):
    """Sample FLAGS.num_examples (source, target) pairs and write them as TSV.

    The sampler is built by `sampler_utils.get_sampler_wrapper` from the
    FLAGS configuration. `sample_example` receives the number of examples
    collected so far.

    When FLAGS.allow_duplicates is false, repeated pairs are discarded and
    sampling continues until enough distinct pairs exist. The output keeps
    the order in which pairs were first sampled.

    If FLAGS.save_sampler is set, the sampler is saved afterwards.
    """
    sampler = sampler_utils.get_sampler_wrapper(
        augment_config=FLAGS.augment_config,
        model_dir=FLAGS.model_dir,
        model_config=FLAGS.model_config,
        rules=FLAGS.rules,
        target_grammar_file=FLAGS.target_grammar,
        checkpoint=FLAGS.checkpoint,
        verbose=FLAGS.verbose)

    examples = []
    if FLAGS.allow_duplicates:
        while len(examples) < FLAGS.num_examples:
            source, target = sampler.sample_example(len(examples))
            examples.append((source, target))
    else:
        # Dedup with a set, but emit the list in sampling order. Listing the
        # set directly yields an order that depends on per-process string
        # hash randomization, so reruns wrote differently-ordered files.
        # NOTE: like before, this loops until FLAGS.num_examples distinct
        # pairs have been drawn.
        seen = set()
        while len(examples) < FLAGS.num_examples:
            source, target = sampler.sample_example(len(examples))
            pair = (source, target)
            if pair not in seen:
                seen.add(pair)
                examples.append(pair)
    tsv_utils.write_tsv(examples, FLAGS.output)
    if FLAGS.save_sampler:
        sampler.save()
示例#15
0
def main(unused_argv):
    """Load examples from FLAGS.input with `load_examples` and write them to FLAGS.output as TSV."""
    loaded = load_examples(FLAGS.input)
    destination = FLAGS.output
    tsv_utils.write_tsv(loaded, destination)
示例#16
0
def main(unused_argv):
    """Mix two TSV files, repeating the first one, then shuffle.

    Rows of FLAGS.input_1 are repeated FLAGS.duplicate_input_1 times and
    followed by the rows of FLAGS.input_2. The combined list is shuffled and
    written to FLAGS.output.
    """
    first = tsv_utils.read_tsv(FLAGS.input_1)
    second = tsv_utils.read_tsv(FLAGS.input_2)
    mixed = first * FLAGS.duplicate_input_1 + second
    # NOTE(review): the global `random` state is not seeded here, so output
    # order differs between runs unless it is seeded elsewhere.
    random.shuffle(mixed)
    tsv_utils.write_tsv(mixed, FLAGS.output)
示例#17
0
def main(unused_argv):
    """Pair line-aligned source and target text files into a TSV.

    Line i of FLAGS.source is paired with line i of FLAGS.target (as read by
    `read_txt`), and the pairs are written to FLAGS.output.

    Raises:
      ValueError: if the two inputs contain a different number of lines.
    """
    # Materialize so the length check works even if read_txt is lazy.
    source = list(read_txt(FLAGS.source))
    target = list(read_txt(FLAGS.target))
    # `zip` silently stops at the shorter input, which would drop trailing
    # lines and hide misaligned files; fail loudly instead.
    if len(source) != len(target):
        raise ValueError(
            "Source and target have different numbers of lines: %d vs %d." %
            (len(source), len(target)))
    examples = list(zip(source, target))
    tsv_utils.write_tsv(examples, FLAGS.output)
示例#18
0
def main(unused_argv):
    """Build examples from FLAGS.source and FLAGS.target via `read_examples` and write them to FLAGS.output."""
    pairs = read_examples(FLAGS.source, FLAGS.target)
    destination = FLAGS.output
    tsv_utils.write_tsv(pairs, destination)
示例#19
0
def main(unused_argv):
    """Write the rows produced by `get_examples()` to FLAGS.output as TSV."""
    generated = get_examples()
    destination = FLAGS.output
    tsv_utils.write_tsv(generated, destination)