예제 #1
0
 def testParseProblemName(self):
     """parse_problem_name splits off `_rev` / `_copy` suffixes in any order.

     Each case pairs an input name with the expected 3-tuple: the base name,
     a flag set by the `_rev` suffix, and a flag set by the `_copy` suffix.
     """
     expectations = (
         ("base", ("base", False, False)),
         ("base_rev", ("base", True, False)),
         ("base_copy", ("base", False, True)),
         # Both suffixes present: the order in which they appear must not
         # change the result.
         ("base_copy_rev", ("base", True, True)),
         ("base_rev_copy", ("base", True, True)),
     )
     for name, expected in expectations:
         parsed = problem_hparams.parse_problem_name(name)
         self.assertEqual(parsed, expected)
예제 #2
0
def get_datasets(problems, data_dir, mode):
  """Build one file-glob pattern per problem for the requested mode.

  Args:
    problems: string of problem names joined with "-". Each name is run
      through problem_hparams.parse_problem_name and only the first element
      of its 3-tuple result is used in the pattern.
    data_dir: directory the patterns are rooted at (joined via os.path.join).
    mode: compared with tf.contrib.learn.ModeKeys.TRAIN; any other value
      selects the dev pattern.

  Returns:
    A list of strings, "<data_dir>/<name>-train*" when mode is TRAIN and
    "<data_dir>/<name>-dev*" otherwise, in the same order as `problems`.
  """
  patterns = []
  for name in problems.split("-"):
    base_name, _, _ = problem_hparams.parse_problem_name(name)
    prefix = os.path.join(data_dir, base_name)
    template = ("%s-train*" if mode == tf.contrib.learn.ModeKeys.TRAIN
                else "%s-dev*")
    patterns.append(template % prefix)
  return patterns
예제 #3
0
def get_data_filepatterns(problems, data_dir, mode):
  """Build one file-glob pattern per problem for the requested mode.

  Args:
    problems: string of problem names joined with "-".
    data_dir: directory the patterns are rooted at (joined via os.path.join).
    mode: compared with tf.estimator.ModeKeys.TRAIN; any other value selects
      the dev pattern.

  Returns:
    A list of strings, "<data_dir>/<file>-train*" when mode is TRAIN and
    "<data_dir>/<file>-dev*" otherwise, in the same order as `problems`.
    <file> comes from the registered problem's dataset_filename(), or from
    problem_hparams.parse_problem_name when that lookup raises ValueError.
  """

  def _base_filename(name):
    # Prefer the registry's own dataset filename. If looking the problem up
    # (or asking it for its filename) raises ValueError, fall back to parsing
    # the raw name and keep only the first element of the parsed 3-tuple.
    try:
      return registry.problem(name).dataset_filename()
    except ValueError:
      base, _, _ = problem_hparams.parse_problem_name(name)
      return base

  patterns = []
  for name in problems.split("-"):
    prefix = os.path.join(data_dir, _base_filename(name))
    if mode == tf.estimator.ModeKeys.TRAIN:
      patterns.append("%s-train*" % prefix)
    else:
      patterns.append("%s-dev*" % prefix)
  return patterns