Exemplo n.º 1
0
def main(_):
    """Select the problems requested via flags and generate data for each.

    The candidate list is the union of ``_SUPPORTED_PROBLEM_GENERATORS`` and
    ``registry.list_base_problems()``. It is narrowed by
    ``FLAGS.exclude_problems`` (comma-separated substrings) and then by
    ``FLAGS.problem``, which may be a prefix ending in ``*``, a
    comma-separated list of names, or a single exact name.

    Args:
        _: Unused positional argument (presumably the argv handed over by the
            app runner -- confirm against the caller).

    Raises:
        ValueError: If no problem remains after filtering.
    """
    usr_dir.import_usr_dir(FLAGS.t2t_usr_dir)

    # Calculate the list of problems to generate.
    all_problems = sorted(
        list(_SUPPORTED_PROBLEM_GENERATORS) + registry.list_base_problems())
    problems = all_problems
    for exclude in FLAGS.exclude_problems.split(","):
        if exclude:
            # Substring match: any problem whose name contains it is dropped.
            problems = [p for p in problems if exclude not in p]

    requested = FLAGS.problem
    if requested and requested[-1] == "*":
        prefix = requested[:-1]
        problems = [p for p in problems if p.startswith(prefix)]
    elif requested and "," in requested:
        # Build the lookup once instead of re-splitting for every candidate.
        wanted = set(requested.split(","))
        problems = [p for p in problems if p in wanted]
    elif requested:
        problems = [p for p in problems if p == requested]
    else:
        problems = []

    # Remove TIMIT if paths are not given. getattr() is used because the
    # flag may not be defined at all; a missing flag counts as "not given".
    if not getattr(FLAGS, "timit_paths", None):
        problems = [p for p in problems if "timit" not in p]
    # Remove parsing if paths are not given.
    if not getattr(FLAGS, "parsing_path", None):
        problems = [p for p in problems if "parsing_english_ptb" not in p]

    if not problems:
        problems_str = "\n  * ".join(all_problems)
        error_msg = ("You must specify one of the supported problems to "
                     "generate data for:\n  * " + problems_str + "\n")
        error_msg += ("TIMIT and parsing need data_sets specified with "
                      "--timit_paths and --parsing_path.")
        raise ValueError(error_msg)

    if not FLAGS.data_dir:
        FLAGS.data_dir = tempfile.gettempdir()
        tf.logging.warning(
            "It is strongly recommended to specify --data_dir. "
            "Data will be written to default data_dir=%s.", FLAGS.data_dir)
    FLAGS.data_dir = os.path.expanduser(FLAGS.data_dir)
    tf.gfile.MakeDirs(FLAGS.data_dir)

    tf.logging.info(
        "Generating problems:\n%s",
        registry.display_list_by_prefix(problems, starting_spaces=4))
    if FLAGS.only_list:
        return
    for problem in problems:
        # Re-seed before every problem.
        set_random_seed()

        if problem in _SUPPORTED_PROBLEM_GENERATORS:
            generate_data_for_problem(problem)
        else:
            generate_data_for_registered_problem(problem)
Exemplo n.º 2
0
def available():
    """Return whatever ``registry.list_base_problems()`` reports."""
    base_problems = registry.list_base_problems()
    return base_problems