def creation_wrapper(process_dataset_fn):
    """Run `process_dataset_fn` over every file named in FLAGS.filenames.

    The TFExamples output directory (FLAGS.tf_examples_dir) is created first
    if it is not already a directory, and the model config is loaded once
    from FLAGS.config. Empty entries in FLAGS.filenames are skipped.

    Args:
        process_dataset_fn: Callable invoked once per file as
            `process_dataset_fn(input_path, model_config, permute,
            num_repeats, output_path)`.
    """
    tf_examples_dir = FLAGS.tf_examples_dir

    # Make sure the destination directory is there before anything is written.
    if not tf.gfile.IsDirectory(tf_examples_dir):
        print('Creating TFExamples directory at ' + tf_examples_dir)
        tf.gfile.MkDir(tf_examples_dir)

    # One model config is shared by all of the files.
    model_config = load_config(FLAGS.config)

    for filename in (name for name in FLAGS.filenames if name):
        input_path = os.path.join(FLAGS.examples_dir, filename)

        # Output name: last '/'-separated component, cut at its first '.'.
        stem = filename.rsplit('/', 1)[-1].partition('.')[0]
        output_path = os.path.join(FLAGS.tf_examples_dir, stem + '.tfrecords')

        # Permutation is only considered for outputs whose path mentions
        # 'spider_train'; every other file is processed exactly once.
        if 'spider_train' in output_path:
            permute = FLAGS.permute
        else:
            permute = False

        if permute:
            num_repeats = FLAGS.num_spider_repeats
        else:
            num_repeats = 1

        status = 'Processing %s. Permute: %r with %d repetitions' % (
            filename, permute, num_repeats)
        print(status)
        print('Writing to ' + output_path)

        process_dataset_fn(
            input_path, model_config, permute, num_repeats, output_path)
def creation_wrapper(process_dataset_fn) -> None:
    """Convert each file listed in FLAGS.filenames via `process_dataset_fn`.

    Creates FLAGS.tf_examples_dir when it is not already a directory, loads
    the model config from FLAGS.config, and then calls `process_dataset_fn`
    once for every non-empty filename.

    Args:
        process_dataset_fn: Callable invoked as
            `process_dataset_fn(input_path, model_config, permute,
            num_repeats, output_path)`.
    """
    destination_dir = FLAGS.tf_examples_dir
    directory_exists = tf.gfile.IsDirectory(destination_dir)
    if not directory_exists:
        print("Creating TFExamples directory at " + destination_dir)
        tf.gfile.MkDir(destination_dir)

    model_config = load_config(FLAGS.config)

    for filename in FLAGS.filenames:
        if filename:
            input_path = os.path.join(FLAGS.examples_dir, filename)

            # Keep only the final path component, truncated at its first dot.
            base_name = filename.rsplit("/", 1)[-1]
            record_name = base_name.partition(".")[0] + ".tfrecords"
            output_path = os.path.join(FLAGS.tf_examples_dir, record_name)

            # Repetitions only apply when permutation is switched on.
            permute = FLAGS.permute
            if permute:
                num_repeats = FLAGS.num_spider_repeats
            else:
                num_repeats = 1

            progress = "Processing %s. Permute: %r with %d repetition(s)" % (
                filename, permute, num_repeats)
            print(progress)
            print("Writing to " + output_path)

            process_dataset_fn(
                input_path, model_config, permute, num_repeats, output_path)
def setup_graph():
    """Build the TensorFlow graph used for inference.

    Loads the model config from FLAGS.config_filepath, creates placeholder
    inputs according to the config's model parameters, and runs the model
    function in PREDICT mode.

    Returns:
        A `(saver, placeholder, predictions)` tuple: a `tf.train.Saver`
        created after the model has been built, the placeholder returned by
        `input_pipeline.create_placeholder_inputs`, and the `predictions`
        attribute of the model function's output.
    """
    model_config = load_config(os.path.join(FLAGS.config_filepath))
    model_parameters = model_config.model_parameters

    # Placeholder-backed inputs, shaped by the optional feature switches.
    placeholder, features, labels = input_pipeline.create_placeholder_inputs(
        model_parameters.use_segment_ids,
        model_parameters.use_foreign_key_features,
        model_parameters.use_alignment_features,
    )

    model_fn = model_builder.build_model_fn(
        model_config,
        FLAGS.output_vocab_filepath,
        beam_size=FLAGS.beam_size,
    )
    model_output = model_fn(features, labels, tf.estimator.ModeKeys.PREDICT)
    predictions = model_output.predictions

    # The saver is created last, once the model has been built.
    saver = tf.train.Saver()
    return saver, placeholder, predictions
def inference_wrapper(sharded: bool = False):
    """
    Run inference on FLAGS.input_tfrecords and write the matched results.

    Steps, in order:
      1. Check the output path and split flags.
      2. Build an `inference.Config` from the command-line flags.
      3. Load the examples and run inference through `cached_fn_call`
         with FLAGS.inference_cache_filepath.
      4. When FLAGS.using_abstract_sql is set, restore the FROM clauses of
         the predictions through `cached_fn_call` with
         FLAGS.restored_sql_cache_filepath.
      5. Load the schema and match the predictions with the dataset through
         `cached_fn_call` with FLAGS.output_filepath.

    `cached_fn_call` is defined elsewhere; it presumably caches the
    callable's result at the given path (confirm against its definition).

    Args:
        sharded: Not referenced anywhere in the body.
    """
    assert FLAGS.output_filepath, (
        "You must provide a filepath to write evaluation instructions to.")
    assert FLAGS.output_filepath.endswith(".json"), "Output file must be .json"
    assert isinstance(FLAGS.splits, list), (
        f"Expected a list; got {FLAGS.splits}")
    assert len(FLAGS.splits) > 0, (
        f"Expected a non-empty list; got {FLAGS.splits}")

    config = inference.Config(
        dataset_name=FLAGS.dataset_name,
        splits=FLAGS.splits,
        output_vocab_filepath=FLAGS.output_vocab_filepath,
        clean_output_vocab_filepath=FLAGS.clean_output_vocab_filepath,
        beam_size=FLAGS.beam_size,
        using_abstract_sql=FLAGS.using_abstract_sql,
        database_directory=FLAGS.database_directory,
        empty_database_directory=FLAGS.empty_database_directory,
        original_data_directory=FLAGS.original_data_directory,
        model_config=model_config.load_config(FLAGS.config_filepath),
    )

    examples = inference.load_tf_examples(FLAGS.input_tfrecords)
    print(f"Performing inference on {len(examples)} examples.")

    def _run_inference():
        return inference.inference(
            examples, FLAGS.checkpoint_filepath, config)

    predictions = cached_fn_call(
        _run_inference, FLAGS.inference_cache_filepath)

    # With Abstract SQL the predictions above have under-specified FROM
    # clauses that must be restored before matching.
    if FLAGS.using_abstract_sql:
        is_spider = FLAGS.dataset_name.lower() == "spider"

        # A schema object is only loaded here for non-Spider datasets.
        michigan_schema = (
            None if is_spider else inference.load_schema_obj(
                FLAGS.dataset_name, FLAGS.original_data_directory))

        print(f"Restoring FROM clauses for {len(predictions)} predictions")

        def _restore_from_clauses():
            # The Spider JSON flags are only read for the Spider dataset.
            if is_spider:
                examples_json = FLAGS.spider_examples_json
                tables_json = FLAGS.spider_tables_json
            else:
                examples_json = ""
                tables_json = ""
            return restore_from_asql.restore_from_clauses(
                predictions,
                spider_examples_json=examples_json,
                spider_tables_json=tables_json,
                michigan_schema=michigan_schema,
                dataset_name=FLAGS.dataset_name,
                use_oracle_foreign_keys=FLAGS.use_oracle_foreign_keys,
            )

        predictions = cached_fn_call(
            _restore_from_clauses, FLAGS.restored_sql_cache_filepath)

    # Load the database tables.
    schema_obj = inference.load_schema_obj(
        FLAGS.dataset_name, FLAGS.original_data_directory)

    # Match the predictions with the original data and save the result.
    def _match_with_dataset():
        return inference.match_with_dataset(config, predictions, schema_obj)

    cached_fn_call(_match_with_dataset, FLAGS.output_filepath)
def main(unused_argv: Any) -> None:
    """Train the model in fixed-size rounds, evaluating after every round.

    When FLAGS.do_train is set, training proceeds in rounds of
    FLAGS.steps_between_saves steps until
    config.training_options.training_steps is reached. Each round restores
    the previous checkpoint in a fresh session, trains, saves a new
    checkpoint, runs inference on FLAGS.eval_filename, executes the
    predictions, reports the metrics to the keepsake experiment, and deletes
    the checkpoints returned by `checkpoints_to_delete`.

    NOTE(review): when only FLAGS.do_eval is set, the model function is built
    but nothing else in this function runs.

    Args:
        unused_argv: Not used.

    Raises:
        ValueError: If neither FLAGS.do_train nor FLAGS.do_eval is set.
    """
    tf.logging.info("Saving model saves and results to " + FLAGS.model_dir)

    # Fixed seed; presumably for reproducibility (see global_seed).
    global_seed(42)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train`, `do_eval` must be True.")

    config = model_config.load_config(FLAGS.config)

    if FLAGS.do_train:
        tf.logging.info("Training with train filenames: " +
                        str(FLAGS.training_filename))

    # Training allows noisy examples so do not use clean output vocab
    model_fn = model_builder.build_model_fn(config,
                                            FLAGS.output_vocab_filepath,
                                            clean_output_vocab_path="")

    # region training
    if FLAGS.do_train:
        # for keepsake CLI (helps track experiment results)
        experiment = keepsake.init(params={
            "learning_rate": config.training_options.optimizer_learning_rate,
            "batch_size": config.training_options.batch_size,
            "training_steps": config.training_options.training_steps,
            "eval_batch_size": FLAGS.eval_batch_size,
            "training_data": FLAGS.training_filename,
            "eval_data": FLAGS.eval_filename,
        }, )

        # Empty entries in FLAGS.training_filename are dropped.
        train_input_fn = input_pipeline.create_training_input_fn(
            config,
            FLAGS.tf_examples_dir,
            [name for name in FLAGS.training_filename if name],
        )

        train_features, train_labels = train_input_fn()
        train_model = model_fn(train_features, train_labels,
                               tf.estimator.ModeKeys.TRAIN)

        # Presumably lets graphs built later (e.g. for inference) share the
        # variables created above -- TODO confirm.
        tf.get_variable_scope().reuse_variables()

        # Positional arguments; the config file is loaded a second time here.
        inference_config = inference.Config(
            FLAGS.eval_dataset_name,
            FLAGS.eval_splits.split(","),
            FLAGS.output_vocab_filepath,
            FLAGS.clean_output_vocab_filepath,
            FLAGS.eval_beam_size,
            FLAGS.using_abstract_sql,
            FLAGS.database_directory,
            FLAGS.empty_database_directory,
            FLAGS.original_data_directory,
            model_config.load_config(FLAGS.config),
        )

        # max_to_keep=None: the saver never deletes checkpoint files itself;
        # old checkpoints are removed in the disk-management step below.
        saver = tf.train.Saver(max_to_keep=None)

        global_step = 0
        checkpoint = checkpoint_path(FLAGS.model_dir, global_step)
        # Passed to and replaced by execute_predictions on every round.
        validation_query_cache: Dict[str, Any] = {}

        # Save freshly initialised variables as the step-0 checkpoint so the
        # first training round has something to restore.
        with tf.Session() as init_sess:
            init_sess.run(tf.global_variables_initializer())
            saver.save(init_sess, checkpoint)

        while global_step < config.training_options.training_steps:

            # region training loop
            # Every round uses a new session restored from the last checkpoint.
            with tf.Session() as train_sess:
                tf.logging.info(
                    "Training from step %s to step %s",
                    global_step,
                    global_step + FLAGS.steps_between_saves,
                )
                saver.restore(train_sess, checkpoint)

                train_losses = []
                for step in range(FLAGS.steps_between_saves):
                    _, train_loss = train_sess.run(
                        [train_model.train_op, train_model.loss])
                    train_losses.append(train_loss)
                    # Log the loss every 100 steps within the round.
                    if step % 100 == 0:
                        tf.logging.info(
                            "Step %s's training loss: %s",
                            global_step + step,
                            train_loss,
                        )

                # From here on train_loss is the mean over the whole round,
                # not the loss of the last step.
                train_loss = statistics.mean(train_losses)

                global_step += FLAGS.steps_between_saves
                checkpoint = checkpoint_path(FLAGS.model_dir, global_step)
                saver.save(train_sess, checkpoint)
            # endregion

            # region eval loop
            tf.logging.info("Evaluating checkpoint %s", checkpoint)
            examples = inference.load_tf_examples(
                os.path.join(FLAGS.tf_examples_dir, FLAGS.eval_filename))
            # In-place shuffle of the evaluation examples.
            random.shuffle(examples)

            tf.logging.info("Running inference on %s", FLAGS.eval_filename)
            predictions = inference.inference(
                examples,
                checkpoint,
                inference_config,
            )

            examples_to_execute = get_examples_to_execute(
                predictions, inference_config)

            # Only update cache when it's empty
            should_update_cache = len(validation_query_cache) == 0

            # case_sensitive is False only when the eval dataset name contains
            # "scholar"; every other dataset is evaluated case-sensitively.
            # NOTE(review): the previous comment here read "only scholar is
            # case sensitive", the opposite of this expression -- confirm
            # which one is intended.
            case_sensitive = "scholar" not in FLAGS.eval_dataset_name.lower()

            results, validation_query_cache = official_evaluation.execute_predictions(
                instructions=examples_to_execute,
                cache_dict=validation_query_cache,
                case_sensitive=case_sensitive,
                verbose=False,
                update_cache=should_update_cache,
            )

            metrics = official_evaluation.aggregate_metrics(
                results, FLAGS.use_empty_tables)
            tf.logging.info("Validation Results:\n\tExecution F1: %s",
                            metrics.execution_f1)
            # endregion

            # Record this round with keepsake; execution F1 is the primary
            # metric and is to be maximized.
            experiment.checkpoint(
                step=global_step,
                metrics={
                    "train_loss": train_loss,
                    "eval_execution_f1": metrics.execution_f1,
                    "eval_string_match": metrics.string_same,
                },
                primary_metric=("eval_execution_f1", "maximize"),
            )

            # region disk management
            for step in checkpoints_to_delete(experiment):
                # The newest checkpoint is restored at the start of the next
                # round, so it must never be deleted.
                assert (
                    step != global_step
                ), f"Can't delete step {step}; need it for next training epoch starting at step {global_step}"
                print(f"Deleting checkpoint {step}")
                delete_checkpoint(FLAGS.model_dir, step)
def main(unused_argv):
    """Train and/or evaluate the model with a TPUEstimator.

    Builds a `tf.contrib.tpu.TPUEstimator` from the model config and flags.
    When FLAGS.do_train is set, trains up to
    config.training_options.training_steps. When FLAGS.do_eval is set,
    evaluates each checkpoint yielded by `checkpoints_iterator` for
    FLAGS.model_dir and copies every checkpoint whose accuracy beats the best
    seen so far to a "model_max_" file. Evaluation stops once the checkpoint
    numbered `training_steps` has been evaluated.

    Args:
        unused_argv: Not used.

    Raises:
        ValueError: If neither FLAGS.do_train nor FLAGS.do_eval is set.
    """
    tf.logging.info("Saving model saves and results to " + FLAGS.model_dir)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train`, `do_eval` must be True.")

    config = model_config.load_config(FLAGS.config)

    if FLAGS.do_train:
        tf.logging.info("Training with train filenames: " +
                        str(FLAGS.training_filename))

    training_options = config.training_options
    use_tpu = FLAGS.use_tpu

    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        save_summary_steps=1,
        save_checkpoints_steps=FLAGS.steps_between_saves,
        keep_checkpoint_max=KEEP_CHECKPOINTS_MAX,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=training_options.tpu_iterations_per_loop,
            num_shards=FLAGS.num_tpu_shards))

    # Set up estimator
    model_fn = model_builder.build_model_fn(config, FLAGS.output_vocab,
                                            use_tpu)
    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=use_tpu,
        config=run_config,
        train_batch_size=training_options.batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:
        # Empty entries in FLAGS.training_filename are dropped.
        train_input_fn = input_pipeline.create_training_input_fn(
            config, FLAGS.tf_examples_dir,
            [name for name in FLAGS.training_filename if name], use_tpu)
        estimator.train(input_fn=train_input_fn,
                        max_steps=training_options.training_steps)

    if FLAGS.do_eval:
        # Best accuracy seen so far across the evaluated checkpoints.
        max_acc = 0.

        eval_input_fn = input_pipeline.create_eval_input_fn(
            config, FLAGS.tf_examples_dir, [FLAGS.eval_filename], use_tpu)

        # When FLAGS.init_checkpoint = None, the latest checkpoint will be evaluated
        num_train_steps = int(training_options.training_steps)

        for ckpt in tf.contrib.training.checkpoints_iterator(FLAGS.model_dir):
            acc = evaluate(estimator, eval_input_fn, ckpt)
            ckpt_number = get_ckpt_number(ckpt)

            if acc > max_acc:
                # Record the new best. Without this update every checkpoint
                # with non-zero accuracy was copied as a "max" checkpoint.
                max_acc = acc
                copy_checkpoint(
                    ckpt,
                    os.path.join(
                        FLAGS.model_dir,
                        str(ckpt_number) + "model_max_" +
                        FLAGS.eval_filename.split(".")[0] + ".ckpt"))

            # The checkpoint written at the final training step is the last
            # one to evaluate; stop waiting for new checkpoints after it.
            if ckpt_number == num_train_steps:
                break