Example #1
def creation_wrapper(process_dataset_fn):
  """Wrapper for creating the TFRecords files."""
  # Create the tf examples directory.
  if not tf.gfile.IsDirectory(FLAGS.tf_examples_dir):
    print('Creating TFExamples directory at ' + FLAGS.tf_examples_dir)
    tf.gfile.MkDir(FLAGS.tf_examples_dir)

  # Get the model config.
  model_config = load_config(FLAGS.config)

  for filename in FLAGS.filenames:
    if not filename:
      continue

    input_path = os.path.join(FLAGS.examples_dir, filename)
    output_path = os.path.join(
        FLAGS.tf_examples_dir,
        filename.split('/')[-1].split('.')[0] + '.tfrecords')

    permute = 'spider_train' in output_path and FLAGS.permute
    num_repeats = FLAGS.num_spider_repeats if permute else 1

    print('Processing %s. Permute: %r with %d repetitions' %
          (filename, permute, num_repeats))
    print('Writing to ' + output_path)

    process_dataset_fn(input_path, model_config, permute, num_repeats,
                       output_path)
Example #2
def creation_wrapper(process_dataset_fn) -> None:
    """Wrapper for creating the TFRecords files."""
    # Create the tf examples directory.
    if not tf.gfile.IsDirectory(FLAGS.tf_examples_dir):
        print("Creating TFExamples directory at " + FLAGS.tf_examples_dir)
        tf.gfile.MkDir(FLAGS.tf_examples_dir)

    # Get the model config.
    model_config = load_config(FLAGS.config)

    for filename in FLAGS.filenames:
        if not filename:
            continue

        input_path = os.path.join(FLAGS.examples_dir, filename)
        output_path = os.path.join(
            FLAGS.tf_examples_dir,
            filename.split("/")[-1].split(".")[0] + ".tfrecords")

        permute = FLAGS.permute
        num_repeats = FLAGS.num_spider_repeats if permute else 1

        print("Processing %s. Permute: %r with %d repetition(s)" %
              (filename, permute, num_repeats))
        print("Writing to " + output_path)

        process_dataset_fn(input_path, model_config, permute, num_repeats,
                           output_path)
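For reference, a minimal usage sketch follows; it is not part of the original source. The flag wiring and the `write_dataset` / `convert_examples` helpers are hypothetical stand-ins for any dataset-specific processing function with the signature `creation_wrapper` expects.

from absl import app

def write_dataset(input_path, model_config, permute, num_repeats,
                  output_path):
    """Hypothetical processing fn: writes converted examples as TFRecords."""
    with tf.python_io.TFRecordWriter(output_path) as writer:
        # convert_examples is a hypothetical helper yielding tf.train.Example
        # protos built from the raw dataset.
        for example in convert_examples(input_path, model_config, permute,
                                        num_repeats):
            writer.write(example.SerializeToString())

def main(unused_argv):
    creation_wrapper(write_dataset)

if __name__ == "__main__":
    app.run(main)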
Example #3
def setup_graph():
    """Sets up the Tenorflow graph for inference."""
    # Set up the model for inference
    model_config = load_config(os.path.join(FLAGS.config_filepath))
    placeholder, features, labels = input_pipeline.create_placeholder_inputs(
        model_config.model_parameters.use_segment_ids,
        model_config.model_parameters.use_foreign_key_features,
        model_config.model_parameters.use_alignment_features)

    model_fn = model_builder.build_model_fn(model_config,
                                            FLAGS.output_vocab_filepath,
                                            beam_size=FLAGS.beam_size)
    mode = tf.estimator.ModeKeys.PREDICT
    predictions = model_fn(features, labels, mode).predictions
    saver = tf.train.Saver()

    return saver, placeholder, predictions
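A sketch of how the returned tuple might be consumed (assumed usage, not shown in the original; `serialized_example` is a placeholder for one serialized tf.train.Example):

saver, placeholder, predictions = setup_graph()

with tf.Session() as sess:
    # Restore the trained weights before running the graph.
    saver.restore(sess, FLAGS.checkpoint_filepath)
    outputs = sess.run(predictions,
                       feed_dict={placeholder: serialized_example})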
Example #4
def inference_wrapper(sharded: bool = False):
    """Wrapper for running inference."""

    assert FLAGS.output_filepath, (
        "You must provide a filepath to write evaluation instructions to.")
    assert FLAGS.output_filepath.endswith(".json"), "Output file must be .json"

    assert isinstance(
        FLAGS.splits, list), f"Expected a list; got {FLAGS.splits}"
    assert FLAGS.splits, f"Expected a non-empty list; got {FLAGS.splits}"

    config = inference.Config(
        dataset_name=FLAGS.dataset_name,
        splits=FLAGS.splits,
        output_vocab_filepath=FLAGS.output_vocab_filepath,
        clean_output_vocab_filepath=FLAGS.clean_output_vocab_filepath,
        beam_size=FLAGS.beam_size,
        using_abstract_sql=FLAGS.using_abstract_sql,
        database_directory=FLAGS.database_directory,
        empty_database_directory=FLAGS.empty_database_directory,
        original_data_directory=FLAGS.original_data_directory,
        model_config=model_config.load_config(FLAGS.config_filepath),
    )

    examples = inference.load_tf_examples(FLAGS.input_tfrecords)

    print(f"Performing inference on {len(examples)} examples.")
    predictions = cached_fn_call(
        lambda: inference.inference(examples, FLAGS.checkpoint_filepath,
                                    config),
        FLAGS.inference_cache_filepath,
    )

    # If using Abstract SQL, restore the under-specified FROM clauses in the
    # predictions generated above.
    if FLAGS.using_abstract_sql:
        is_spider = FLAGS.dataset_name.lower() == "spider"

        michigan_schema = None
        if not is_spider:
            michigan_schema = inference.load_schema_obj(
                FLAGS.dataset_name, FLAGS.original_data_directory)
        print(f"Restoring FROM clauses for {len(predictions)} predictions")
        predictions = cached_fn_call(
            lambda: restore_from_asql.restore_from_clauses(
                predictions,
                spider_examples_json=(FLAGS.spider_examples_json
                                      if is_spider else ""),
                spider_tables_json=(FLAGS.spider_tables_json
                                    if is_spider else ""),
                michigan_schema=michigan_schema,
                dataset_name=FLAGS.dataset_name,
                use_oracle_foreign_keys=FLAGS.use_oracle_foreign_keys,
            ),
            FLAGS.restored_sql_cache_filepath,
        )

    # Load the database tables.
    schema_obj = inference.load_schema_obj(FLAGS.dataset_name,
                                           FLAGS.original_data_directory)

    # Now match the predictions with the original data and save the results.
    cached_fn_call(
        lambda: inference.match_with_dataset(config, predictions, schema_obj),
        FLAGS.output_filepath,
    )
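`cached_fn_call` is used above but not defined in this snippet. A plausible minimal sketch, assuming the result is JSON-serializable (the real helper may serialize differently):

import json
import os

def cached_fn_call(fn, cache_filepath):
    """Runs fn() once and caches its result on disk as JSON."""
    if os.path.exists(cache_filepath):
        with open(cache_filepath) as infile:
            return json.load(infile)
    result = fn()
    with open(cache_filepath, "w") as outfile:
        json.dump(result, outfile)
    return result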
Example #5
def main(unused_argv: Any) -> None:
    tf.logging.info("Saving model saves and results to " + FLAGS.model_dir)

    global_seed(42)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train`, `do_eval` must be True.")

    config = model_config.load_config(FLAGS.config)

    if FLAGS.do_train:
        tf.logging.info("Training with train filenames: " +
                        str(FLAGS.training_filename))

    # Training allows noisy examples, so do not use the clean output vocab.
    model_fn = model_builder.build_model_fn(config,
                                            FLAGS.output_vocab_filepath,
                                            clean_output_vocab_path="")

    # region training
    if FLAGS.do_train:
        # for keepsake CLI (helps track experiment results)
        experiment = keepsake.init(params={
            "learning_rate": config.training_options.optimizer_learning_rate,
            "batch_size": config.training_options.batch_size,
            "training_steps": config.training_options.training_steps,
            "eval_batch_size": FLAGS.eval_batch_size,
            "training_data": FLAGS.training_filename,
            "eval_data": FLAGS.eval_filename,
        })

        train_input_fn = input_pipeline.create_training_input_fn(
            config,
            FLAGS.tf_examples_dir,
            [name for name in FLAGS.training_filename if name],
        )

        train_features, train_labels = train_input_fn()
        train_model = model_fn(train_features, train_labels,
                               tf.estimator.ModeKeys.TRAIN)

        tf.get_variable_scope().reuse_variables()

        inference_config = inference.Config(
            FLAGS.eval_dataset_name,
            FLAGS.eval_splits.split(","),
            FLAGS.output_vocab_filepath,
            FLAGS.clean_output_vocab_filepath,
            FLAGS.eval_beam_size,
            FLAGS.using_abstract_sql,
            FLAGS.database_directory,
            FLAGS.empty_database_directory,
            FLAGS.original_data_directory,
            model_config.load_config(FLAGS.config),
        )

        saver = tf.train.Saver(max_to_keep=None)

        global_step = 0
        checkpoint = checkpoint_path(FLAGS.model_dir, global_step)

        validation_query_cache: Dict[str, Any] = {}

        with tf.Session() as init_sess:
            init_sess.run(tf.global_variables_initializer())
            saver.save(init_sess, checkpoint)

        while global_step < config.training_options.training_steps:
            # region training loop
            with tf.Session() as train_sess:
                tf.logging.info(
                    "Training from step %s to step %s",
                    global_step,
                    global_step + FLAGS.steps_between_saves,
                )
                saver.restore(train_sess, checkpoint)

                train_losses = []

                for step in range(FLAGS.steps_between_saves):
                    _, train_loss = train_sess.run(
                        [train_model.train_op, train_model.loss])

                    train_losses.append(train_loss)

                    if step % 100 == 0:
                        tf.logging.info(
                            "Step %s's training loss: %s",
                            global_step + step,
                            train_loss,
                        )

                train_loss = statistics.mean(train_losses)

                global_step += FLAGS.steps_between_saves
                checkpoint = checkpoint_path(FLAGS.model_dir, global_step)
                saver.save(train_sess, checkpoint)
            # endregion

            # region eval loop
            tf.logging.info("Evaluating checkpoint %s", checkpoint)

            examples = inference.load_tf_examples(
                os.path.join(FLAGS.tf_examples_dir, FLAGS.eval_filename))
            random.shuffle(examples)

            tf.logging.info("Running inference on %s", FLAGS.eval_filename)
            predictions = inference.inference(
                examples,
                checkpoint,
                inference_config,
            )

            examples_to_execute = get_examples_to_execute(
                predictions, inference_config)

            # Only update the cache when it is empty.
            should_update_cache = not validation_query_cache

            # Scholar is the only case-insensitive dataset.
            case_sensitive = "scholar" not in FLAGS.eval_dataset_name.lower()

            results, validation_query_cache = official_evaluation.execute_predictions(
                instructions=examples_to_execute,
                cache_dict=validation_query_cache,
                case_sensitive=case_sensitive,
                verbose=False,
                update_cache=should_update_cache,
            )

            metrics = official_evaluation.aggregate_metrics(
                results, FLAGS.use_empty_tables)
            tf.logging.info("Validation Results:\n\tExecution F1: %s",
                            metrics.execution_f1)
            # endregion

            experiment.checkpoint(
                step=global_step,
                metrics={
                    "train_loss": train_loss,
                    "eval_execution_f1": metrics.execution_f1,
                    "eval_string_match": metrics.string_same,
                },
                primary_metric=("eval_execution_f1", "maximize"),
            )

            # region disk management

            for step in checkpoints_to_delete(experiment):
                assert step != global_step, (
                    f"Can't delete step {step}; need it for the next training "
                    f"epoch starting at step {global_step}")
                print(f"Deleting checkpoint {step}")
                delete_checkpoint(FLAGS.model_dir, step)
            # endregion
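`checkpoint_path`, `delete_checkpoint`, and `checkpoints_to_delete` are helpers not shown here. A minimal sketch of the first two, assuming local-disk checkpoints written by tf.train.Saver:

import glob

def checkpoint_path(model_dir, step):
    """Returns the checkpoint file prefix for a given global step."""
    return os.path.join(model_dir, "ckpt-%d" % step)

def delete_checkpoint(model_dir, step):
    """Removes every file belonging to the checkpoint saved at `step`."""
    # A Saver checkpoint is several files sharing one prefix
    # (.index, .meta, .data-*).
    for path in glob.glob(checkpoint_path(model_dir, step) + ".*"):
        tf.gfile.Remove(path)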
Example #6
def main(unused_argv):
    tf.logging.info("Saving model saves and results to " + FLAGS.model_dir)

    if not FLAGS.do_train and not FLAGS.do_eval:
        raise ValueError("At least one of `do_train`, `do_eval` must be True.")

    config = model_config.load_config(FLAGS.config)

    if FLAGS.do_train:
        tf.logging.info("Training with train filenames: " +
                        str(FLAGS.training_filename))

    training_options = config.training_options
    use_tpu = FLAGS.use_tpu
    run_config = tf.contrib.tpu.RunConfig(
        master=FLAGS.master,
        model_dir=FLAGS.model_dir,
        save_summary_steps=1,
        save_checkpoints_steps=FLAGS.steps_between_saves,
        keep_checkpoint_max=KEEP_CHECKPOINTS_MAX,
        tpu_config=tf.contrib.tpu.TPUConfig(
            iterations_per_loop=training_options.tpu_iterations_per_loop,
            num_shards=FLAGS.num_tpu_shards))

    # Set up estimator
    model_fn = model_builder.build_model_fn(config, FLAGS.output_vocab,
                                            use_tpu)

    estimator = tf.contrib.tpu.TPUEstimator(
        model_fn=model_fn,
        use_tpu=use_tpu,
        config=run_config,
        train_batch_size=config.training_options.batch_size,
        eval_batch_size=FLAGS.eval_batch_size)

    if FLAGS.do_train:
        train_input_fn = input_pipeline.create_training_input_fn(
            config, FLAGS.tf_examples_dir,
            [name for name in FLAGS.training_filename if name], use_tpu)

        estimator.train(input_fn=train_input_fn,
                        max_steps=config.training_options.training_steps)

    if FLAGS.do_eval:
        max_acc = 0.

        eval_input_fn = input_pipeline.create_eval_input_fn(
            config, FLAGS.tf_examples_dir, [FLAGS.eval_filename], use_tpu)

        # When FLAGS.init_checkpoint is None, the latest checkpoint is
        # evaluated.
        num_train_steps = int(config.training_options.training_steps)

        for ckpt in tf.contrib.training.checkpoints_iterator(FLAGS.model_dir):
            acc = evaluate(estimator, eval_input_fn, ckpt)
            if acc > max_acc:
                max_acc = acc
                copy_checkpoint(
                    ckpt,
                    os.path.join(
                        FLAGS.model_dir,
                        str(get_ckpt_number(ckpt)) + "model_max_" +
                        FLAGS.eval_filename.split(".")[0] + ".ckpt"))
            if get_ckpt_number(ckpt) == num_train_steps:
                break
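`get_ckpt_number` and `copy_checkpoint` are referenced but not defined in this snippet. A plausible sketch (assumptions, not the original implementations):

import glob
import re
import shutil

def get_ckpt_number(ckpt):
    """Extracts the global step from a path like '/dir/model.ckpt-12345'."""
    match = re.search(r"ckpt-(\d+)", ckpt)
    return int(match.group(1)) if match else -1

def copy_checkpoint(src_prefix, dest_prefix):
    """Copies all files of a checkpoint (.index, .meta, .data-*)."""
    for path in glob.glob(src_prefix + ".*"):
        shutil.copy(path, dest_prefix + path[len(src_prefix):])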