Example #1
def main(unused_argv):
    examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
    examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
    divergence = mcd_utils.measure_example_divergence(
        examples_1,
        examples_2,
        get_compounds_fn=tmcd_utils.get_example_compounds)
    print("Compound divergence: %s" % divergence)
Example #2
def main(unused_argv):
    config = json_utils.json_file_to_dict(FLAGS.config)
    examples = tsv_utils.read_tsv(FLAGS.input)
    rules = qcfg_file.read_rules(FLAGS.rules)

    slice_start = FLAGS.offset
    slice_end = FLAGS.limit if FLAGS.limit else None
    examples = examples[slice(slice_start, slice_end)]

    def _convert_examples(pipeline):
        _ = (pipeline
             | "ImportExamples" >> beam.Create(examples)
             | "ConvertExamples" >> beam.ParDo(ConvertExampleFn(rules, config))
             | "WriteExamples" >> beam.io.tfrecordio.WriteToTFRecord(
                 FLAGS.output, coder=beam.coders.ProtoCoder(tf.train.Example)))

    pipeline_options = beam.options.pipeline_options.PipelineOptions(
        FLAGS.pipeline_options)
    with beam.Pipeline(pipeline_options) as pipeline:
        _convert_examples(pipeline)

    metrics = pipeline.result.metrics().query()
    for distribution in metrics["distributions"]:
        logging.info("max %s: %s", distribution.key.metric.name,
                     distribution.committed.max)
    for counter in metrics["counters"]:
        logging.info("count %s: %s", counter.key.metric.name,
                     counter.committed)
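ConvertExampleFn is defined elsewhere in the codebase. The following is a minimal sketch of the DoFn shape this pipeline requires, assuming one tf.train.Example is produced per (source, target) pair; the real class presumably encodes examples using the rules and config (for instance via an ExampleConverter like the one in Example #27) rather than just storing raw strings.

import apache_beam as beam
from apache_beam.metrics import Metrics
import tensorflow as tf

class ConvertExampleFn(beam.DoFn):
    """Illustrative sketch: converts (source, target) pairs to tf.train.Example."""

    def __init__(self, rules, config):
        self.rules = rules
        self.config = config
        self.examples_counter = Metrics.counter("convert_examples", "examples_processed")

    def process(self, example):
        source, target = example
        # A real implementation would encode the example using self.rules and
        # self.config; this sketch only stores the raw strings.
        features = {
            "source": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[source.encode("utf-8")])),
            "target": tf.train.Feature(
                bytes_list=tf.train.BytesList(value=[target.encode("utf-8")])),
        }
        self.examples_counter.inc()
        yield tf.train.Example(features=tf.train.Features(feature=features))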
Example #3
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    new_examples = []
    for source, target in examples:
        new_examples.append((nqg_tokenization.process_source(source),
                             nqg_tokenization.process_target(target)))
    tsv_utils.write_tsv(new_examples, FLAGS.output)
Example #4
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    rules = qcfg_file.read_rules(FLAGS.rules)
    print("Rules: %s" % rules)

    num_examples = 0
    num_covered = 0

    for idx, example in enumerate(examples):
        if FLAGS.offset and idx < FLAGS.offset:
            continue
        if FLAGS.limit and idx >= FLAGS.limit:
            break
        print("Processing example %s." % idx)
        print("Source: %s" % example[0])
        print("Target: %s" % example[1])

        source = example[0]
        gold_target = example[1]

        can_parse = qcfg_parser.can_parse(source,
                                          gold_target,
                                          rules,
                                          verbose=False)

        num_examples += 1

        if can_parse:
            num_covered += 1
        else:
            print("Output set does not contain gold target.")

    print("%s covered out of %s" % (num_covered, num_examples))
Example #5
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    new_examples = []
    for source, target in examples:
        new_source = string_utils.format_source(source)
        new_examples.append((new_source, target))
    tsv_utils.write_tsv(new_examples, FLAGS.output)
Example #6
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    rules = exact_match_utils.get_exact_match_rules(examples)
    # Sort by target.
    rules = list(rules)
    rules.sort(key=lambda x: x.target)
    qcfg_file.write_rules(rules, FLAGS.output)
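exact_match_utils.get_exact_match_rules is not shown. Here is a minimal sketch of the idea, assuming a rule is essentially a (source, target) pair and that an exact match is a token appearing verbatim on both sides; the QcfgRule namedtuple is a hypothetical stand-in, and the real QCFG rule representation and matching logic are likely richer.

import collections

# Hypothetical stand-in for the rule type written by qcfg_file.
QcfgRule = collections.namedtuple("QcfgRule", ["source", "target"])

def get_exact_match_rules_sketch(examples):
    """Collects rules for tokens that appear verbatim in both source and target."""
    rules = set()
    for source, target in examples:
        shared_tokens = set(source.split()) & set(target.split())
        for token in shared_tokens:
            rules.add(QcfgRule(source=token, target=token))
    return rules

The sort by the target attribute in the script above would work unchanged on rules of this shape.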
Example #7
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    random.seed(FLAGS.seed)
    random.shuffle(examples)
    examples_1 = examples[:FLAGS.num_examples_1]
    examples_2 = examples[FLAGS.num_examples_1:]
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #8
def main(unused_argv):
    examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
    examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
    if examples_1 == examples_2:
        print("Examples are the same.")
    else:
        print("Examples are different.")
        if len(examples_1) != len(examples_2):
            print("Number of examples is different.")
        else:
            for idx, (example_1,
                      example_2) in enumerate(zip(examples_1, examples_2)):
                if example_1 != example_2:
                    print("First different example pair at idx %s:" % idx)
                    print(example_1)
                    print(example_2)
                    break
Example #9
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    examples_1, examples_2 = template_utils.split_by_template(
        examples,
        template_fn=funql_template_fn,
        max_num_examples_1=FLAGS.max_num_examples_1,
        seed=FLAGS.seed)
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #10
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    if FLAGS.use_target:
        sorted_examples = sorted(examples, key=lambda x: len(x[1].split(" ")))
    else:
        sorted_examples = sorted(examples, key=lambda x: len(x[0].split(" ")))
    examples_1 = sorted_examples[:FLAGS.num_examples]
    examples_2 = sorted_examples[FLAGS.num_examples:]
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #11
def main(_):
  examples = tsv_utils.read_tsv(FLAGS.input, expected_num_columns=3)
  new_examples = []
  for source, target, category in examples:
    if category == "primitive":
      if len(source.split()) != 1:
        raise ValueError(f"Invalid primitive: {source}")
      new_target = source
    else:
      new_target = cogs_converter.cogs_lf_to_funcall(target)
    new_examples.append((source, new_target))
  tsv_utils.write_tsv(new_examples, FLAGS.output)
Example #12
def main(unused_argv):
    formatted_db_id_to_db_id = {}
    for db_id in database_constants.DATABASES:
        formatted_db_id_to_db_id[db_id.lower()] = db_id
        formatted_db_id_to_db_id[db_id] = db_id

    examples = tsv_utils.read_tsv(FLAGS.input)
    with gfile.GFile(FLAGS.output, "w") as txt_file:
        for example in examples:
            db_id = example[0].split()[0].rstrip(":")
            db_id = formatted_db_id_to_db_id[db_id]
            txt_file.write("%s\t%s\n" % (example[1], db_id))
Example #13
def main(unused_argv):
    examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
    examples_2 = tsv_utils.read_tsv(FLAGS.input_2)

    atoms_1 = mcd_utils.get_all_atoms(
        examples_1, get_atoms_fn=tmcd_utils.get_example_atoms)

    num_examples = 0
    num_examples_with_unseen_atom = 0
    for example in examples_2:
        atoms = tmcd_utils.get_example_atoms(example)
        num_examples += 1
        for atom in atoms:
            if atom not in atoms_1:
                print("New atom: %s" % atom)
                num_examples_with_unseen_atom += 1
                break

    print("num_examples: %s" % num_examples)
    print("num_examples_with_unseen_atom: %s" % num_examples_with_unseen_atom)
    print("pct: %s" % (float(num_examples_with_unseen_atom) / num_examples))
Example #14
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    new_examples = []
    for source, target in examples:
        if check_person_name(source, target) and check_year(source, target):
            new_examples.append((source, target))
    tsv_utils.write_tsv(new_examples, FLAGS.output)
    num_examples = len(examples)
    num_new_examples = len(new_examples)
    print("original examples: %d." % num_examples)
    print("new examples: %d." % num_new_examples)
    print("filtered examples: %d." % (num_examples - num_new_examples))
Example #15
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    single_domain_examples = []
    cross_domain_examples = []
    for source, target in examples:
        if is_cross_domain(target):
            cross_domain_examples.append((source, target))
        else:
            single_domain_examples.append((source, target))
    print("len(cross_domain_examples): %s" % len(cross_domain_examples))
    print("len(single_domain_examples): %s" % len(single_domain_examples))
    tsv_utils.write_tsv(cross_domain_examples, FLAGS.output_cross)
    tsv_utils.write_tsv(single_domain_examples, FLAGS.output_single)
Example #16
def merge_predictions(examples, filename):
    """Merge multiple predcition files into one."""
    source_to_prediction = {}
    output_files = gfile.glob("%s-*-of-*" % filename)
    for output_file in output_files:
        predictions = tsv_utils.read_tsv(output_file)
        for prediction in predictions:
            source, predicted_target = prediction
            source_to_prediction[source] = predicted_target
    new_predictions = []
    for example in examples:
        new_predictions.append(source_to_prediction[example[0]])
    txt_utils.write_txt(new_predictions, filename)
Example #17
def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = set()

  for source, target in examples:
    rules |= get_number_rules(source, target)
    rules |= get_string_rules(source, target)
    rules |= get_datetime_exact_match(source, target)

  # Sort by target.
  rules = list(rules)
  rules.sort(key=lambda x: x.target)
  qcfg_file.write_rules(rules, FLAGS.output)
Example #18
def main(unused_argv):
    splits = load_splits()
    examples = tsv_utils.read_tsv(FLAGS.input)
    example_id_to_example = {
        example_id: example
        for example_id, example in enumerate(examples)
    }

    for split, split_ids in splits.items():
        examples = []
        for split_id in split_ids:
            examples.append(example_id_to_example[split_id])
        filename = os.path.join(FLAGS.output_dir, "%s.tsv" % split)
        tsv_utils.write_tsv(examples, filename)
Example #19
def main(unused_argv):
    config = json_utils.json_file_to_dict(FLAGS.config)
    wrapper = inference_utils.get_inference_wrapper(config, FLAGS.rules,
                                                    FLAGS.target_grammar,
                                                    FLAGS.verbose)
    writer = None
    if FLAGS.write:
        write_dir = os.path.join(FLAGS.model_dir, FLAGS.subdir)
        writer = writer_utils.get_summary_writer(write_dir)

    examples = tsv_utils.read_tsv(FLAGS.input)
    fallback_predictions = [None] * len(examples)
    if FLAGS.fallback_predictions:
        fallback_predictions = txt_utils.read_txt(FLAGS.fallback_predictions)
    if len(examples) != len(fallback_predictions):
        raise ValueError("len(examples) != len(fallback_predictions).")

    slice_start = FLAGS.offset
    slice_end = FLAGS.limit if FLAGS.limit else None
    examples = examples[slice(slice_start, slice_end)]
    fallback_predictions = fallback_predictions[slice(slice_start, slice_end)]

    if FLAGS.poll:
        last_checkpoint = None
        while True:
            checkpoint, step = inference_utils.get_checkpoint(
                wrapper, FLAGS.model_dir, FLAGS.checkpoint)
            if checkpoint == last_checkpoint:
                logging.info(
                    "Waiting for new checkpoint...\nLast checkpoint: %s",
                    last_checkpoint)
            else:
                run_inference(writer,
                              wrapper,
                              examples,
                              fallback_predictions,
                              step=step)
                last_checkpoint = checkpoint
            if step and step >= config["training_steps"]:
                # Stop eval job after completing eval for last training step.
                break
            time.sleep(10)
    else:
        checkpoint, step = inference_utils.get_checkpoint(
            wrapper, FLAGS.model_dir, FLAGS.checkpoint)
        run_inference(writer,
                      wrapper,
                      examples,
                      fallback_predictions,
                      step=step)
Example #20
def main(unused_argv):
  tables_json = load_json(FLAGS.tables)
  db_id_to_schema_string = {}
  for table_json in tables_json:
    db_id = table_json["db_id"].lower()
    db_id_to_schema_string[db_id] = _get_schema_string(table_json)

  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    db_id = source.split()[0].rstrip(":")
    schema_string = db_id_to_schema_string[db_id]
    new_source = "%s%s" % (source, schema_string)
    new_examples.append((new_source.lower(), target.lower()))
  tsv_utils.write_tsv(new_examples, FLAGS.output)
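_get_schema_string is defined elsewhere. Below is a plausible sketch of serializing a Spider-style tables.json entry into a string that can be appended to the source, using the standard column_names_original and table_names_original fields; the exact separators used by the real implementation are an assumption.

import collections

def _get_schema_string(table_json):
    """Serializes a schema as ' | table: col1, col2 | ...' (illustrative sketch)."""
    table_id_to_column_names = collections.defaultdict(list)
    for table_id, name in table_json["column_names_original"]:
        if table_id < 0:
            continue  # Skip the '*' column, which is assigned table index -1.
        table_id_to_column_names[table_id].append(name.lower())
    parts = []
    for table_id, table_name in enumerate(table_json["table_names_original"]):
        columns = ", ".join(table_id_to_column_names[table_id])
        parts.append("%s: %s" % (table_name.lower(), columns))
    return " | " + " | ".join(parts)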
Example #21
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)
    for idx, (_, target) in enumerate(examples):
        if FLAGS.offset and idx < FLAGS.offset:
            continue
        if FLAGS.limit and idx >= FLAGS.limit:
            break
        print("Processing example %s." % idx)

        try:
            _ = sql_parser.parse_sql(target)
        except ValueError as e:
            print(e)
            # Retry parsing with verbose debugging.
            _ = sql_parser.parse_sql(target, verbose=True)
Example #22
def induce_and_write_rules():
  """Induce and write set of rules."""
  examples = tsv_utils.read_tsv(FLAGS.input)
  config = induction_utils.InductionConfig(
      sample_size=FLAGS.sample_size,
      max_iterations=FLAGS.max_iterations,
      min_delta=FLAGS.min_delta,
      terminal_codelength=FLAGS.terminal_codelength,
      non_terminal_codelength=FLAGS.non_terminal_codelength,
      parse_sample=FLAGS.parse_sample,
      allow_repeated_target_nts=FLAGS.allow_repeated_target_nts,
      seed_exact_match=FLAGS.seed_exact_match,
      balance_parens=FLAGS.balance_parens,
  )
  induced_rules = induction_utils.induce_rules(examples, config)
  qcfg_file.write_rules(induced_rules, FLAGS.output)
Example #23
def main(unused_argv):
    """Induce and write set of rules."""
    examples = tsv_utils.read_tsv(FLAGS.input)
    config = json_utils.json_file_to_dict(FLAGS.config)
    if not config.get("allow_duplicate_examples", True):
        examples = set([tuple(ex) for ex in examples])
    examples = sorted(examples, key=lambda e: (len(e[0]), e))

    seed_rules = set()
    # Add manual seed rules.
    if FLAGS.seed_rules_file:
        for seed_rules_file in FLAGS.seed_rules_file:
            seed_rules |= set(qcfg_file.read_rules(seed_rules_file))

    target_grammar_rules = (target_grammar.load_rules_from_file(
        FLAGS.target_grammar) if FLAGS.target_grammar else None)

    num_partitions = config.get("num_partitions", 1)
    partition_to_examples = np.array_split(examples, num_partitions)
    induction_state = induction_utils.InductionState(FLAGS.output, config)
    if FLAGS.restore_partition or FLAGS.restore_step:
        # Restore from an existing induction state.
        induction_state.restore_state(FLAGS.restore_partition,
                                      FLAGS.restore_step)
    else:
        # Initialize the induction state with manual seed rules.
        induction_state.current_rules = seed_rules.copy()

    while induction_state.current_partition < num_partitions:
        current_examples = induction_utils.get_examples_up_to_partition(
            partition_to_examples, induction_state.current_partition)
        logging.info("Partition: %s, number of examples: %s.",
                     induction_state.current_partition, len(current_examples))
        # At the first step of each partition, we add a rule corresponding to
        # each example in the partition.
        if induction_state.current_step_in_partition == 0:
            induction_state.current_rules |= induction_utils.get_example_rules(
                partition_to_examples[induction_state.current_partition])
        policy = greedy_policy.GreedyPolicy(config,
                                            current_examples,
                                            seed_rules,
                                            target_grammar_rules,
                                            verbose=FLAGS.verbose)
        induce_rules_for_partition(policy, induction_state)
        induction_state.current_partition += 1
        induction_state.current_step_in_partition = 0
    qcfg_file.write_rules(induction_state.current_rules, FLAGS.output)
Example #24
def main(unused_argv):
    config = json_utils.json_file_to_dict(FLAGS.config)
    wrapper = inference_utils.get_inference_wrapper(config, FLAGS.rules,
                                                    FLAGS.target_grammar,
                                                    FLAGS.verbose)
    _ = inference_utils.get_checkpoint(wrapper, FLAGS.model_dir,
                                       FLAGS.checkpoint)
    examples = tsv_utils.read_tsv(FLAGS.input)

    num_predictions_match = 0
    predictions = []
    for idx, example in enumerate(examples):
        if FLAGS.offset and idx < FLAGS.offset:
            continue
        if FLAGS.limit and idx >= FLAGS.limit:
            break

        if FLAGS.verbose:
            print("Processing example %s: (%s, %s)" %
                  (idx, example[0], example[1]))

        source = example[0]
        original_target = example[1]

        predicted_target = inference_parser.get_top_output(source, wrapper)
        if FLAGS.verbose:
            print("predicted_target: %s" % predicted_target)

        if predicted_target == original_target:
            num_predictions_match += 1
        else:
            if FLAGS.verbose:
                print("predictions do not match.")

        predictions.append(predicted_target)

    print("num_predictions_match: %s" % num_predictions_match)

    with gfile.GFile(FLAGS.output, "w") as txt_file:
        for prediction in predictions:
            txt_file.write("%s\n" % prediction)
Example #25
def main(unused_argv):
    gold_examples = tsv_utils.read_tsv(FLAGS.gold)

    preds = []
    with gfile.GFile(FLAGS.predictions, "r") as f:
        for line in f:
            preds.append(line.rstrip())

    correct = 0
    incorrect = 0
    for pred, gold_example in zip(preds, gold_examples):
        if pred == gold_example[1]:
            correct += 1
        else:
            incorrect += 1
            print("Incorrect for example %s.\nTarget: %s\nPrediction: %s" %
                  (gold_example[0], gold_example[1], pred))

    print("correct: %s" % correct)
    print("incorrect: %s" % incorrect)
    print("pct: %s" % str(float(correct) / float(correct + incorrect)))
Example #26
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)

    # First, randomly split examples.
    random.seed(FLAGS.seed)
    random.shuffle(examples)
    examples_1 = examples[:FLAGS.num_examples_1]
    examples_2 = examples[FLAGS.num_examples_1:]

    # Initialize cache.
    cache = AtomAndCompoundCache()

    # Swap examples to meet atom constraint and maximize compound divergence.
    examples_1, examples_2 = mcd_utils.swap_examples(
        examples_1,
        examples_2,
        get_compounds_fn=cache.get_compounds,
        get_atoms_fn=cache.get_atoms,
        max_iterations=10000,
        max_divergence=None)
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
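mcd_utils.swap_examples is not shown. A minimal sketch of a greedy swapping loop follows, under the assumption that it proposes random pairwise swaps, keeps a swap only if it raises compound divergence, and enforces the atom constraint that every atom in the second set also appears in the first; the real function's signature and acceptance criteria (e.g. max_divergence, min_atom_count) differ. The divergence_fn argument can be something like the compound_divergence sketch shown after Example #1.

import random

def swap_examples_sketch(examples_1, examples_2, get_compounds_fn, get_atoms_fn,
                         divergence_fn, max_iterations=1000):
    """Greedy random swaps that raise compound divergence (illustrative sketch)."""

    def all_atoms(examples):
        atoms = set()
        for ex in examples:
            atoms |= set(get_atoms_fn(ex))
        return atoms

    def all_compounds(examples):
        return [c for ex in examples for c in get_compounds_fn(ex)]

    current = divergence_fn(all_compounds(examples_1), all_compounds(examples_2))
    for _ in range(max_iterations):
        i = random.randrange(len(examples_1))
        j = random.randrange(len(examples_2))
        new_1 = examples_1[:i] + [examples_2[j]] + examples_1[i + 1:]
        new_2 = examples_2[:j] + [examples_1[i]] + examples_2[j + 1:]
        # Atom constraint: every atom in the second set must also appear in the first.
        if not all_atoms(new_2) <= all_atoms(new_1):
            continue
        new_divergence = divergence_fn(all_compounds(new_1), all_compounds(new_2))
        if new_divergence > current:
            examples_1, examples_2, current = new_1, new_2, new_divergence
    return examples_1, examples_2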
Example #27
def main(unused_argv):
    config = json_utils.json_file_to_dict(FLAGS.config)
    examples = tsv_utils.read_tsv(FLAGS.input)
    rules = qcfg_file.read_rules(FLAGS.rules)
    converter = example_converter.ExampleConverter(rules, config)

    slice_start = FLAGS.offset
    slice_end = FLAGS.limit if FLAGS.limit else None
    examples = examples[slice(slice_start, slice_end)]

    num_examples = 0
    writer = tf.io.TFRecordWriter(FLAGS.output)
    for idx, example in enumerate(examples):
        tf_example = converter.convert(example)
        writer.write(tf_example.SerializeToString())
        num_examples += 1
        if FLAGS.verbose:
            print("Processing example %s." % idx)
            print("(%s, %s)" % (example[0], example[1]))

    writer.close()
    converter.print_max_sizes()
    print("Wrote %d examples." % num_examples)
Example #28
def main(unused_argv):
    examples = tsv_utils.read_tsv(FLAGS.input)

    # First, randomly split examples.
    random.seed(FLAGS.seed)
    random.shuffle(examples)
    examples_1 = examples[:FLAGS.num_examples_1]
    examples_2 = examples[FLAGS.num_examples_1:]

    # Swap examples to meet atom constraint and maximize compound divergence.
    get_atoms_fn = (tmcd_utils.get_example_atoms_with_num_arguments
                    if FLAGS.get_atoms_with_num_arguments else
                    tmcd_utils.get_example_atoms)
    examples_1, examples_2 = mcd_utils.swap_examples(
        examples_1,
        examples_2,
        get_compounds_fn=tmcd_utils.get_example_compounds,
        get_atoms_fn=get_atoms_fn,
        max_iterations=1000,
        max_divergence=None,
        min_atom_count=FLAGS.min_atom_count)
    tsv_utils.write_tsv(examples_1, FLAGS.output_1)
    tsv_utils.write_tsv(examples_2, FLAGS.output_2)
Example #29
def main(unused_argv):
    config = config_utils.json_file_to_dict(FLAGS.config)
    examples = tsv_utils.read_tsv(FLAGS.input)
    rules = qcfg_file.read_rules(FLAGS.rules)
    tokenizer = tokenization_utils.get_tokenizer(
        os.path.join(FLAGS.bert_dir, "vocab.txt"))
    converter = example_converter.ExampleConverter(rules, tokenizer, config)

    total_written = 0
    writer = tf.io.TFRecordWriter(FLAGS.output)
    for idx, example in enumerate(examples):
        if FLAGS.offset and idx < FLAGS.offset:
            continue
        if FLAGS.limit and idx >= FLAGS.limit:
            break
        print("Processing example %s." % idx)

        tf_example = converter.convert(example)
        writer.write(tf_example.SerializeToString())
        total_written += 1

    writer.close()
    converter.print_max_sizes()
    print("Wrote %d examples." % total_written)
Example #30
def main(unused_argv):
    config = config_utils.json_file_to_dict(FLAGS.config)
    wrapper = get_inference_wrapper(config)
    examples = tsv_utils.read_tsv(FLAGS.input)
    writer = get_summary_writer()

    if FLAGS.poll:
        last_checkpoint = None
        while True:
            checkpoint, step = get_checkpoint()
            if checkpoint == last_checkpoint:
                print("Waiting for new checkpoint...\nLast checkpoint: %s" %
                      last_checkpoint)
            else:
                run_inference(writer, wrapper, examples, checkpoint, step=step)
                last_checkpoint = checkpoint
            if step and step >= config["training_steps"]:
                # Stop eval job after completing eval for last training step.
                break
            time.sleep(10)
    else:
        checkpoint, _ = get_checkpoint()
        run_inference(writer, wrapper, examples, checkpoint)