def testParseProblemName(self):
  # Table of (problem name, expected parse_problem_name result). The second
  # element of the result is set by a "_rev" suffix and the third by a
  # "_copy" suffix; both suffix orders must set both flags.
  cases = [
      ("base", ("base", False, False)),
      ("base_rev", ("base", True, False)),
      ("base_copy", ("base", False, True)),
      ("base_copy_rev", ("base", True, True)),
      ("base_rev_copy", ("base", True, True)),
  ]
  # Checked in order; the first mismatch fails the test, as before.
  for name, expected in cases:
    self.assertEqual(problem_hparams.parse_problem_name(name), expected)
def get_datasets(problems, data_dir, mode):
  """Return glob patterns for the data shards of the given problems.

  Args:
    problems: "-"-separated string of problem names. Each name is reduced to
      the base name returned by problem_hparams.parse_problem_name
      (e.g. "base_rev" -> "base").
    data_dir: directory holding the data files.
    mode: a ModeKeys value. TRAIN selects "<data_dir>/<base>-train*"; any
      other mode selects "<data_dir>/<base>-dev*".

  Returns:
    A list of glob pattern strings, one per problem, in input order.
  """
  # tf.contrib.learn.ModeKeys is deprecated; tf.estimator.ModeKeys.TRAIN has
  # the same value ("train") and is what get_data_filepatterns compares to.
  # The train/dev choice does not depend on the problem, so decide it once.
  if mode == tf.estimator.ModeKeys.TRAIN:
    pattern = "%s-train*"
  else:
    pattern = "%s-dev*"
  datasets = []
  for problem in problems.split("-"):
    base_name, _, _ = problem_hparams.parse_problem_name(problem)
    datasets.append(pattern % os.path.join(data_dir, base_name))
  return datasets
def get_data_filepatterns(problems, data_dir, mode):
  """Build one data-file glob pattern per problem in `problems`.

  Args:
    problems: "-"-separated string of problem names.
    data_dir: directory holding the data files.
    mode: a tf.estimator.ModeKeys value. TRAIN selects "-train*" files;
      every other mode selects "-dev*" files.

  Returns:
    A list of glob pattern strings, one per problem, in input order.
  """
  is_train = mode == tf.estimator.ModeKeys.TRAIN
  template = "%s-train*" if is_train else "%s-dev*"
  patterns = []
  for name in problems.split("-"):
    try:
      # Registered problems supply their own dataset file name.
      file_stem = registry.problem(name).dataset_filename()
    except ValueError:
      # Unregistered name: fall back to parsing the base name out of it.
      file_stem, _, _ = problem_hparams.parse_problem_name(name)
    patterns.append(template % os.path.join(data_dir, file_stem))
  return patterns