def main(unused_argv):
  examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
  examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
  divergence = mcd_utils.measure_example_divergence(
      examples_1, examples_2,
      get_compounds_fn=tmcd_utils.get_example_compounds)
  print("Compound divergence: %s" % divergence)

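# For reference, a minimal self-contained sketch of how compound divergence is
# defined in Keysers et al. (2020): one minus the Chernoff coefficient (with
# alpha = 0.1) between the normalized compound distributions of the two sets.
# The helper names below are illustrative; mcd_utils.measure_example_divergence
# may differ in its exact implementation (e.g. in how compounds are weighted).
import collections


def chernoff_coefficient(counts_p, counts_q, alpha=0.1):
  """Computes sum_k p_k^alpha * q_k^(1 - alpha) over normalized counts."""
  total_p = float(sum(counts_p.values())) or 1.0
  total_q = float(sum(counts_q.values())) or 1.0
  keys = set(counts_p) | set(counts_q)
  return sum((counts_p.get(k, 0) / total_p) ** alpha *
             (counts_q.get(k, 0) / total_q) ** (1 - alpha) for k in keys)


def compound_divergence(examples_1, examples_2, get_compounds_fn):
  """Returns 1 - C_0.1(P || Q) for the two compound distributions."""
  counts_1 = collections.Counter()
  counts_2 = collections.Counter()
  for example in examples_1:
    counts_1.update(get_compounds_fn(example))
  for example in examples_2:
    counts_2.update(get_compounds_fn(example))
  return 1.0 - chernoff_coefficient(counts_1, counts_2, alpha=0.1)
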
def main(unused_argv):
  config = json_utils.json_file_to_dict(FLAGS.config)
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = qcfg_file.read_rules(FLAGS.rules)
  slice_start = FLAGS.offset
  slice_end = FLAGS.limit if FLAGS.limit else None
  examples = examples[slice(slice_start, slice_end)]

  def _convert_examples(pipeline):
    _ = (
        pipeline
        | "ImportExamples" >> beam.Create(examples)
        | "ConvertExamples" >> beam.ParDo(ConvertExampleFn(rules, config))
        | "WriteExamples" >> beam.io.tfrecordio.WriteToTFRecord(
            FLAGS.output, coder=beam.coders.ProtoCoder(tf.train.Example)))

  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      FLAGS.pipeline_options)
  with beam.Pipeline(pipeline_options) as pipeline:
    _convert_examples(pipeline)

  metrics = pipeline.result.metrics().query()
  for distribution in metrics["distributions"]:
    logging.info("max %s: %s", distribution.key.metric.name,
                 distribution.committed.max)
  for counter in metrics["counters"]:
    logging.info("count %s: %s", counter.key.metric.name, counter.committed)

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    new_examples.append((nqg_tokenization.process_source(source),
                         nqg_tokenization.process_target(target)))
  tsv_utils.write_tsv(new_examples, FLAGS.output)

def main(unused_argv): examples = tsv_utils.read_tsv(FLAGS.input) rules = qcfg_file.read_rules(FLAGS.rules) print("Rules: %s" % rules) num_examples = 0 num_covered = 0 for idx, example in enumerate(examples): if FLAGS.offset and idx < FLAGS.offset: continue if FLAGS.limit and idx >= FLAGS.limit: break print("Processing example %s." % idx) print("Source: %s" % example[0]) print("Target: %s" % example[1]) source = example[0] gold_target = example[1] can_parse = qcfg_parser.can_parse(source, gold_target, rules, verbose=False) num_examples += 1 if can_parse: num_covered += 1 else: print("Output set does not contain gold target.") print("%s covered out of %s" % (num_covered, num_examples))
def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    new_source = string_utils.format_source(source)
    new_examples.append((new_source, target))
  tsv_utils.write_tsv(new_examples, FLAGS.output)

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = exact_match_utils.get_exact_match_rules(examples)
  # Sort by target.
  rules = list(rules)
  rules.sort(key=lambda x: x.target)
  qcfg_file.write_rules(rules, FLAGS.output)

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  random.seed(FLAGS.seed)
  random.shuffle(examples)
  examples_1 = examples[:FLAGS.num_examples_1]
  examples_2 = examples[FLAGS.num_examples_1:]
  tsv_utils.write_tsv(examples_1, FLAGS.output_1)
  tsv_utils.write_tsv(examples_2, FLAGS.output_2)

def main(unused_argv):
  examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
  examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
  if examples_1 == examples_2:
    print("Examples are the same.")
  else:
    print("Examples are different.")
    if len(examples_1) != len(examples_2):
      print("Number of examples is different.")
    else:
      for idx, (example_1, example_2) in enumerate(
          zip(examples_1, examples_2)):
        if example_1 != example_2:
          print("First different example pair at idx %s:" % idx)
          print(example_1)
          print(example_2)
          break

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  examples_1, examples_2 = template_utils.split_by_template(
      examples,
      template_fn=funql_template_fn,
      max_num_examples_1=FLAGS.max_num_examples_1,
      seed=FLAGS.seed)
  tsv_utils.write_tsv(examples_1, FLAGS.output_1)
  tsv_utils.write_tsv(examples_2, FLAGS.output_2)

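# For reference, a template split assigns all examples that share the same
# target template to the same side of the split, so templates in the second
# set are unseen in the first (Finegan-Dollak et al., 2018). A minimal sketch,
# assuming template_fn maps a target string to its template; the actual
# template_utils.split_by_template may differ, e.g. in how max_num_examples_1
# is enforced or how ties are broken.
import collections
import random


def split_by_template_sketch(examples, template_fn, max_num_examples_1,
                             seed=0):
  """Splits examples by template so no template appears in both outputs."""
  template_to_examples = collections.defaultdict(list)
  for source, target in examples:
    template_to_examples[template_fn(target)].append((source, target))
  templates = sorted(template_to_examples)
  random.Random(seed).shuffle(templates)
  examples_1, examples_2 = [], []
  for template in templates:
    group = template_to_examples[template]
    # Keep adding whole template groups to the first split until it is full.
    if len(examples_1) + len(group) <= max_num_examples_1:
      examples_1.extend(group)
    else:
      examples_2.extend(group)
  return examples_1, examples_2
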
def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  if FLAGS.use_target:
    sorted_examples = sorted(examples, key=lambda x: len(x[1].split(" ")))
  else:
    sorted_examples = sorted(examples, key=lambda x: len(x[0].split(" ")))
  examples_1 = sorted_examples[:FLAGS.num_examples]
  examples_2 = sorted_examples[FLAGS.num_examples:]
  tsv_utils.write_tsv(examples_1, FLAGS.output_1)
  tsv_utils.write_tsv(examples_2, FLAGS.output_2)

def main(_):
  examples = tsv_utils.read_tsv(FLAGS.input, expected_num_columns=3)
  new_examples = []
  for source, target, category in examples:
    if category == "primitive":
      if len(source.split()) != 1:
        raise ValueError(f"Invalid primitive: {source}")
      new_target = source
    else:
      new_target = cogs_converter.cogs_lf_to_funcall(target)
    new_examples.append((source, new_target))
  tsv_utils.write_tsv(new_examples, FLAGS.output)

def main(unused_argv):
  formatted_db_id_to_db_id = {}
  for db_id in database_constants.DATABASES:
    formatted_db_id_to_db_id[db_id.lower()] = db_id
    formatted_db_id_to_db_id[db_id] = db_id
  examples = tsv_utils.read_tsv(FLAGS.input)
  with gfile.GFile(FLAGS.output, "w") as txt_file:
    for example in examples:
      db_id = example[0].split()[0].rstrip(":")
      db_id = formatted_db_id_to_db_id[db_id]
      txt_file.write("%s\t%s\n" % (example[1], db_id))

def main(unused_argv):
  examples_1 = tsv_utils.read_tsv(FLAGS.input_1)
  examples_2 = tsv_utils.read_tsv(FLAGS.input_2)
  atoms_1 = mcd_utils.get_all_atoms(
      examples_1, get_atoms_fn=tmcd_utils.get_example_atoms)
  num_examples = 0
  num_examples_with_unseen_atom = 0
  for example in examples_2:
    atoms = tmcd_utils.get_example_atoms(example)
    num_examples += 1
    for atom in atoms:
      if atom not in atoms_1:
        print("New atom: %s" % atom)
        num_examples_with_unseen_atom += 1
        break
  print("num_examples: %s" % num_examples)
  print("num_examples_with_unseen_atom: %s" % num_examples_with_unseen_atom)
  print("pct: %s" % (float(num_examples_with_unseen_atom) / num_examples))

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    if check_person_name(source, target) and check_year(source, target):
      new_examples.append((source, target))
  tsv_utils.write_tsv(new_examples, FLAGS.output)
  num_examples = len(examples)
  num_new_examples = len(new_examples)
  print("original examples: %d." % num_examples)
  print("new examples: %d." % num_new_examples)
  print("filtered examples: %d." % (num_examples - num_new_examples))

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  single_domain_examples = []
  cross_domain_examples = []
  for source, target in examples:
    if is_cross_domain(target):
      cross_domain_examples.append((source, target))
    else:
      single_domain_examples.append((source, target))
  print("len(cross_domain_examples): %s" % len(cross_domain_examples))
  print("len(single_domain_examples): %s" % len(single_domain_examples))
  tsv_utils.write_tsv(cross_domain_examples, FLAGS.output_cross)
  tsv_utils.write_tsv(single_domain_examples, FLAGS.output_single)

def merge_predictions(examples, filename):
  """Merges multiple prediction files into one."""
  source_to_prediction = {}
  output_files = gfile.glob("%s-*-of-*" % filename)
  for output_file in output_files:
    predictions = tsv_utils.read_tsv(output_file)
    for prediction in predictions:
      source, predicted_target = prediction
      source_to_prediction[source] = predicted_target
  new_predictions = []
  for example in examples:
    new_predictions.append(source_to_prediction[example[0]])
  txt_utils.write_txt(new_predictions, filename)

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = set()
  for source, target in examples:
    rules |= get_number_rules(source, target)
    rules |= get_string_rules(source, target)
    rules |= get_datetime_exact_match(source, target)
  # Sort by target.
  rules = list(rules)
  rules.sort(key=lambda x: x.target)
  qcfg_file.write_rules(rules, FLAGS.output)

def main(unused_argv):
  splits = load_splits()
  examples = tsv_utils.read_tsv(FLAGS.input)
  example_id_to_example = {
      example_id: example for example_id, example in enumerate(examples)
  }
  for split, split_ids in splits.items():
    examples = []
    for split_id in split_ids:
      examples.append(example_id_to_example[split_id])
    filename = os.path.join(FLAGS.output_dir, "%s.tsv" % split)
    tsv_utils.write_tsv(examples, filename)

def main(unused_argv):
  config = json_utils.json_file_to_dict(FLAGS.config)
  wrapper = inference_utils.get_inference_wrapper(config, FLAGS.rules,
                                                  FLAGS.target_grammar,
                                                  FLAGS.verbose)
  writer = None
  if FLAGS.write:
    write_dir = os.path.join(FLAGS.model_dir, FLAGS.subdir)
    writer = writer_utils.get_summary_writer(write_dir)

  examples = tsv_utils.read_tsv(FLAGS.input)
  fallback_predictions = [None] * len(examples)
  if FLAGS.fallback_predictions:
    fallback_predictions = txt_utils.read_txt(FLAGS.fallback_predictions)
    if len(examples) != len(fallback_predictions):
      raise ValueError("len(examples) != len(fallback_predictions).")

  slice_start = FLAGS.offset
  slice_end = FLAGS.limit if FLAGS.limit else None
  examples = examples[slice(slice_start, slice_end)]
  fallback_predictions = fallback_predictions[slice(slice_start, slice_end)]

  if FLAGS.poll:
    last_checkpoint = None
    while True:
      checkpoint, step = inference_utils.get_checkpoint(
          wrapper, FLAGS.model_dir, FLAGS.checkpoint)
      if checkpoint == last_checkpoint:
        logging.info("Waiting for new checkpoint...\nLast checkpoint: %s",
                     last_checkpoint)
      else:
        run_inference(writer, wrapper, examples, fallback_predictions,
                      step=step)
        last_checkpoint = checkpoint
        if step and step >= config["training_steps"]:
          # Stop eval job after completing eval for last training step.
          break
      time.sleep(10)
  else:
    checkpoint, step = inference_utils.get_checkpoint(
        wrapper, FLAGS.model_dir, FLAGS.checkpoint)
    run_inference(writer, wrapper, examples, fallback_predictions, step=step)

def main(unused_argv):
  tables_json = load_json(FLAGS.tables)
  db_id_to_schema_string = {}
  for table_json in tables_json:
    db_id = table_json["db_id"].lower()
    db_id_to_schema_string[db_id] = _get_schema_string(table_json)
  examples = tsv_utils.read_tsv(FLAGS.input)
  new_examples = []
  for source, target in examples:
    db_id = source.split()[0].rstrip(":")
    schema_string = db_id_to_schema_string[db_id]
    new_source = "%s%s" % (source, schema_string)
    new_examples.append((new_source.lower(), target.lower()))
  tsv_utils.write_tsv(new_examples, FLAGS.output)

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  for idx, (_, target) in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    print("Processing example %s." % idx)
    try:
      _ = sql_parser.parse_sql(target)
    except ValueError as e:
      print(e)
      # Retry parsing with verbose debugging.
      _ = sql_parser.parse_sql(target, verbose=True)

def induce_and_write_rules():
  """Induces and writes a set of rules."""
  examples = tsv_utils.read_tsv(FLAGS.input)
  config = induction_utils.InductionConfig(
      sample_size=FLAGS.sample_size,
      max_iterations=FLAGS.max_iterations,
      min_delta=FLAGS.min_delta,
      terminal_codelength=FLAGS.terminal_codelength,
      non_terminal_codelength=FLAGS.non_terminal_codelength,
      parse_sample=FLAGS.parse_sample,
      allow_repeated_target_nts=FLAGS.allow_repeated_target_nts,
      seed_exact_match=FLAGS.seed_exact_match,
      balance_parens=FLAGS.balance_parens,
  )
  induced_rules = induction_utils.induce_rules(examples, config)
  qcfg_file.write_rules(induced_rules, FLAGS.output)

def main(unused_argv): """Induce and write set of rules.""" examples = tsv_utils.read_tsv(FLAGS.input) config = json_utils.json_file_to_dict(FLAGS.config) if not config.get("allow_duplicate_examples", True): examples = set([tuple(ex) for ex in examples]) examples = sorted(examples, key=lambda e: (len(e[0]), e)) seed_rules = set() # Add mannual seed rules. if FLAGS.seed_rules_file: for seed_rules_file in FLAGS.seed_rules_file: seed_rules |= set(qcfg_file.read_rules(seed_rules_file)) target_grammar_rules = (target_grammar.load_rules_from_file( FLAGS.target_grammar) if FLAGS.target_grammar else None) num_partitions = config.get("num_partitions", 1) partition_to_examples = np.array_split(examples, num_partitions) induction_state = induction_utils.InductionState(FLAGS.output, config) if FLAGS.restore_partition or FLAGS.restore_step: # Restore from an existing induction state. induction_state.restore_state(FLAGS.restore_partition, FLAGS.restore_step) else: # Initialize the induction state with manual seed rules. induction_state.current_rules = seed_rules.copy() while induction_state.current_partition < num_partitions: current_examples = induction_utils.get_examples_up_to_partition( partition_to_examples, induction_state.current_partition) logging.info("Partition: %s, number of examples: %s.", induction_state.current_partition, len(current_examples)) # At the first step of each partition, we add a rule corresponding to # each example in the partition. if induction_state.current_step_in_partition == 0: induction_state.current_rules |= induction_utils.get_example_rules( partition_to_examples[induction_state.current_partition]) policy = greedy_policy.GreedyPolicy(config, current_examples, seed_rules, target_grammar_rules, verbose=FLAGS.verbose) induce_rules_for_partition(policy, induction_state) induction_state.current_partition += 1 induction_state.current_step_in_partition = 0 qcfg_file.write_rules(induction_state.current_rules, FLAGS.output)
def main(unused_argv):
  config = json_utils.json_file_to_dict(FLAGS.config)
  wrapper = inference_utils.get_inference_wrapper(config, FLAGS.rules,
                                                  FLAGS.target_grammar,
                                                  FLAGS.verbose)
  _ = inference_utils.get_checkpoint(wrapper, FLAGS.model_dir,
                                     FLAGS.checkpoint)
  examples = tsv_utils.read_tsv(FLAGS.input)
  num_predictions_match = 0
  predictions = []
  for idx, example in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    if FLAGS.verbose:
      print("Processing example %s: (%s, %s)" % (idx, example[0], example[1]))
    source = example[0]
    original_target = example[1]
    predicted_target = inference_parser.get_top_output(source, wrapper)
    if FLAGS.verbose:
      print("predicted_target: %s" % predicted_target)
    if predicted_target == original_target:
      num_predictions_match += 1
    else:
      if FLAGS.verbose:
        print("predictions do not match.")
    predictions.append(predicted_target)
  print("num_predictions_match: %s" % num_predictions_match)
  with gfile.GFile(FLAGS.output, "w") as txt_file:
    for prediction in predictions:
      txt_file.write("%s\n" % prediction)

def main(unused_argv):
  gold_examples = tsv_utils.read_tsv(FLAGS.gold)
  preds = []
  with gfile.GFile(FLAGS.predictions, "r") as f:
    for line in f:
      preds.append(line.rstrip())
  correct = 0
  incorrect = 0
  for pred, gold_example in zip(preds, gold_examples):
    if pred == gold_example[1]:
      correct += 1
    else:
      incorrect += 1
      print("Incorrect for example %s.\nTarget: %s\nPrediction: %s" %
            (gold_example[0], gold_example[1], pred))
  print("correct: %s" % correct)
  print("incorrect: %s" % incorrect)
  print("pct: %s" % str(float(correct) / float(correct + incorrect)))

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  # First, randomly split examples.
  random.seed(FLAGS.seed)
  random.shuffle(examples)
  examples_1 = examples[:FLAGS.num_examples_1]
  examples_2 = examples[FLAGS.num_examples_1:]
  # Initialize cache.
  cache = AtomAndCompoundCache()
  # Swap examples to meet atom constraint and maximize compound divergence.
  examples_1, examples_2 = mcd_utils.swap_examples(
      examples_1,
      examples_2,
      get_compounds_fn=cache.get_compounds,
      get_atoms_fn=cache.get_atoms,
      max_iterations=10000,
      max_divergence=None)
  tsv_utils.write_tsv(examples_1, FLAGS.output_1)
  tsv_utils.write_tsv(examples_2, FLAGS.output_2)

def main(unused_argv):
  config = json_utils.json_file_to_dict(FLAGS.config)
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = qcfg_file.read_rules(FLAGS.rules)
  converter = example_converter.ExampleConverter(rules, config)
  slice_start = FLAGS.offset
  slice_end = FLAGS.limit if FLAGS.limit else None
  examples = examples[slice(slice_start, slice_end)]
  num_examples = 0
  writer = tf.io.TFRecordWriter(FLAGS.output)
  for idx, example in enumerate(examples):
    tf_example = converter.convert(example)
    writer.write(tf_example.SerializeToString())
    num_examples += 1
    if FLAGS.verbose:
      print("Processing example %s." % idx)
      print("(%s, %s)" % (example[0], example[1]))
  converter.print_max_sizes()
  print("Wrote %d examples." % num_examples)

def main(unused_argv):
  examples = tsv_utils.read_tsv(FLAGS.input)
  # First, randomly split examples.
  random.seed(FLAGS.seed)
  random.shuffle(examples)
  examples_1 = examples[:FLAGS.num_examples_1]
  examples_2 = examples[FLAGS.num_examples_1:]
  # Swap examples to meet atom constraint and maximize compound divergence.
  get_atoms_fn = (
      tmcd_utils.get_example_atoms_with_num_arguments
      if FLAGS.get_atoms_with_num_arguments else tmcd_utils.get_example_atoms)
  examples_1, examples_2 = mcd_utils.swap_examples(
      examples_1,
      examples_2,
      get_compounds_fn=tmcd_utils.get_example_compounds,
      get_atoms_fn=get_atoms_fn,
      max_iterations=1000,
      max_divergence=None,
      min_atom_count=FLAGS.min_atom_count)
  tsv_utils.write_tsv(examples_1, FLAGS.output_1)
  tsv_utils.write_tsv(examples_2, FLAGS.output_2)

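# For reference, a simplified sketch of the divergence-maximizing swap loop
# used by splits like the above: propose random swaps between the two sets,
# and keep a swap only if the second set still contains no atoms unseen in
# the first set and the divergence does not decrease. divergence_fn is assumed
# to behave like the compound_divergence sketch above; the actual
# mcd_utils.swap_examples differs (e.g. in candidate selection and in how
# min_atom_count and max_divergence are handled).
import random


def greedy_swap_sketch(examples_1, examples_2, divergence_fn, get_atoms_fn,
                       max_iterations=1000, seed=0):
  """Randomly proposes swaps and keeps those that help."""
  rng = random.Random(seed)

  def atom_set(examples):
    atoms = set()
    for example in examples:
      atoms |= set(get_atoms_fn(example))
    return atoms

  best = divergence_fn(examples_1, examples_2)
  for _ in range(max_iterations):
    i = rng.randrange(len(examples_1))
    j = rng.randrange(len(examples_2))
    examples_1[i], examples_2[j] = examples_2[j], examples_1[i]
    new = divergence_fn(examples_1, examples_2)
    # Atom constraint: every atom in the second set must also occur in the
    # first set, so the held-out split contains no unseen atoms.
    if atom_set(examples_2) <= atom_set(examples_1) and new >= best:
      best = new
    else:
      # Revert the swap.
      examples_1[i], examples_2[j] = examples_2[j], examples_1[i]
  return examples_1, examples_2
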
def main(unused_argv):
  config = config_utils.json_file_to_dict(FLAGS.config)
  examples = tsv_utils.read_tsv(FLAGS.input)
  rules = qcfg_file.read_rules(FLAGS.rules)
  tokenizer = tokenization_utils.get_tokenizer(
      os.path.join(FLAGS.bert_dir, "vocab.txt"))
  converter = example_converter.ExampleConverter(rules, tokenizer, config)
  total_written = 0
  writer = tf.io.TFRecordWriter(FLAGS.output)
  for idx, example in enumerate(examples):
    if FLAGS.offset and idx < FLAGS.offset:
      continue
    if FLAGS.limit and idx >= FLAGS.limit:
      break
    print("Processing example %s." % idx)
    tf_example = converter.convert(example)
    writer.write(tf_example.SerializeToString())
    total_written += 1
  converter.print_max_sizes()
  print("Wrote %d examples." % total_written)

def main(unused_argv):
  config = config_utils.json_file_to_dict(FLAGS.config)
  wrapper = get_inference_wrapper(config)
  examples = tsv_utils.read_tsv(FLAGS.input)
  writer = get_summary_writer()
  if FLAGS.poll:
    last_checkpoint = None
    while True:
      checkpoint, step = get_checkpoint()
      if checkpoint == last_checkpoint:
        print("Waiting for new checkpoint...\nLast checkpoint: %s" %
              last_checkpoint)
      else:
        run_inference(writer, wrapper, examples, checkpoint, step=step)
        last_checkpoint = checkpoint
        if step and step >= config["training_steps"]:
          # Stop eval job after completing eval for last training step.
          break
      time.sleep(10)
  else:
    checkpoint, _ = get_checkpoint()
    run_inference(writer, wrapper, examples, checkpoint)