예제 #1
0
def generate_data_for_problem(problem):
    """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS.

    Looks up the (training, dev, test) generator callables registered for
    `problem`, hands each split's generator to
    generator_utils.generate_files together with that split's unshuffled
    shard filenames under FLAGS.data_dir, and finally passes every filename
    to generator_utils.shuffle_dataset.

    The training split uses FLAGS.num_shards shards (10 when the flag is
    unset or 0) and is capped at FLAGS.max_cases. Dev and test each use a
    tenth of the training shard count, but never fewer than one shard.

    Args:
        problem: Key into _SUPPORTED_PROBLEM_GENERATORS; also used as the
            base name of the generated files.
    """
    training_gen, dev_gen, test_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
    unshuffled_name = problem + generator_utils.UNSHUFFLED_SUFFIX

    num_train_shards = FLAGS.num_shards or 10
    # int(n * 0.1) truncates to 0 for n < 10, which would request zero output
    # files for the dev/test splits; always keep at least one shard.
    num_dev_shards = max(1, int(num_train_shards * 0.1))
    num_test_shards = max(1, int(num_train_shards * 0.1))

    tf.logging.info("Generating training data for %s.", problem)
    train_output_files = generator_utils.train_data_filenames(
        unshuffled_name, FLAGS.data_dir, num_train_shards)
    generator_utils.generate_files(training_gen(), train_output_files,
                                   FLAGS.max_cases)

    tf.logging.info("Generating development data for %s.", problem)
    dev_output_files = generator_utils.dev_data_filenames(
        unshuffled_name, FLAGS.data_dir, num_dev_shards)
    generator_utils.generate_files(dev_gen(), dev_output_files)

    # The test split is optional: a test generator callable that returns
    # None means there is no test data for this problem.
    test_output_files = []
    test_gen_data = test_gen()
    if test_gen_data is not None:
        tf.logging.info("Generating test data for %s.", problem)
        test_output_files = generator_utils.test_data_filenames(
            unshuffled_name, FLAGS.data_dir, num_test_shards)
        generator_utils.generate_files(test_gen_data, test_output_files)

    all_output_files = (
        train_output_files + dev_output_files + test_output_files)
    generator_utils.shuffle_dataset(all_output_files)
예제 #2
0
def generate_data_for_problem(problem):
  """Generate data for a problem in _SUPPORTED_PROBLEM_GENERATORS.

  Looks up the (training, dev, test) generator callables registered for
  `problem`, hands each split's generator to generator_utils.generate_files
  together with that split's unshuffled shard filenames under FLAGS.data_dir,
  and finally passes every filename to generator_utils.shuffle_dataset.

  The training split uses FLAGS.num_shards shards (10 when the flag is unset
  or 0) and is capped at FLAGS.max_cases. Dev and test each use a tenth of
  the training shard count, but never fewer than one shard.

  Args:
    problem: Key into _SUPPORTED_PROBLEM_GENERATORS; also used as the base
      name of the generated files.
  """
  training_gen, dev_gen, test_gen = _SUPPORTED_PROBLEM_GENERATORS[problem]
  unshuffled_name = problem + generator_utils.UNSHUFFLED_SUFFIX

  num_train_shards = FLAGS.num_shards or 10
  # int(n * 0.1) truncates to 0 for n < 10, which would request zero output
  # files for the dev/test splits; always keep at least one shard.
  num_dev_shards = max(1, int(num_train_shards * 0.1))
  num_test_shards = max(1, int(num_train_shards * 0.1))

  tf.logging.info("Generating training data for %s.", problem)
  train_output_files = generator_utils.train_data_filenames(
      unshuffled_name, FLAGS.data_dir, num_train_shards)
  generator_utils.generate_files(training_gen(), train_output_files,
                                 FLAGS.max_cases)

  tf.logging.info("Generating development data for %s.", problem)
  dev_output_files = generator_utils.dev_data_filenames(
      unshuffled_name, FLAGS.data_dir, num_dev_shards)
  generator_utils.generate_files(dev_gen(), dev_output_files)

  # The test split is optional: a test generator callable that returns None
  # means there is no test data for this problem.
  test_output_files = []
  test_gen_data = test_gen()
  if test_gen_data is not None:
    tf.logging.info("Generating test data for %s.", problem)
    test_output_files = generator_utils.test_data_filenames(
        unshuffled_name, FLAGS.data_dir, num_test_shards)
    generator_utils.generate_files(test_gen_data, test_output_files)

  all_output_files = train_output_files + dev_output_files + test_output_files
  generator_utils.shuffle_dataset(all_output_files)
예제 #3
0
 def test_filepaths(self, data_dir, num_shards, shuffled):
   """Return the test-split filenames for this dataset.

   The base name comes from self.dataset_filename(); when `shuffled` is
   falsy, generator_utils.UNSHUFFLED_SUFFIX is appended to it first.

   Args:
     data_dir: Forwarded unchanged to generator_utils.test_data_filenames.
     num_shards: Forwarded unchanged to generator_utils.test_data_filenames.
     shuffled: Whether to name the shuffled (truthy) or unshuffled files.

   Returns:
     Whatever generator_utils.test_data_filenames returns for the
     resulting base name, `data_dir` and `num_shards`.
   """
   dataset_name = self.dataset_filename()
   basename = (dataset_name if shuffled
               else dataset_name + generator_utils.UNSHUFFLED_SUFFIX)
   return generator_utils.test_data_filenames(
       basename, data_dir, num_shards)
예제 #4
0
 def test_filepaths(self, data_dir, num_shards, shuffled):
   """Build the test-split filenames for this dataset.

   Starts from self.dataset_filename() and, unless `shuffled` is truthy,
   appends generator_utils.UNSHUFFLED_SUFFIX before delegating to
   generator_utils.test_data_filenames.

   Args:
     data_dir: Passed straight through to test_data_filenames.
     num_shards: Passed straight through to test_data_filenames.
     shuffled: Truthy for the shuffled files, falsy for the unshuffled ones.

   Returns:
     The result of generator_utils.test_data_filenames for the chosen
     base name.
   """
   name = self.dataset_filename()
   if shuffled:
     chosen_name = name
   else:
     chosen_name = name + generator_utils.UNSHUFFLED_SUFFIX
   return generator_utils.test_data_filenames(
       chosen_name, data_dir, num_shards)