def main():
    """Parse the command line and process every requested treebank.

    Each treebank listed in ``args.treebanks`` is handed to
    ``process_treebank`` along with the default path configuration and the
    parsed ``output_dir`` value.
    """
    args = build_argparse().parse_args()

    default_path_map = default_paths.get_default_paths()
    for name in args.treebanks:
        process_treebank(name, default_path_map, args.output_dir)
예제 #2
0
def main():
    """Prepare the dataset named by the first command-line argument.

    ``sys.argv[1]`` selects which preparation routine runs.  Every routine
    receives the default path configuration; some also receive the dataset
    name itself (e.g. to pick a language or variant out of it).

    Raises:
        ValueError: if no dataset name is given on the command line, or if
            the given name does not match any known dataset.
    """
    # Without this check a missing argument surfaces as a bare IndexError
    # traceback, which says nothing about what the script expects.
    if len(sys.argv) < 2:
        raise ValueError("no dataset specified: pass the dataset name as the first argument")
    dataset_name = sys.argv[1]

    paths = default_paths.get_default_paths()

    # Fixed seed so any use of the global ``random`` module by the
    # processing routines below is reproducible from run to run.
    random.seed(1234)

    if dataset_name == 'fi_turku':
        process_turku(paths)
    # The misspelled 'Ukranian' aliases are accepted as-is for compatibility
    # with existing invocations.
    elif dataset_name in ('uk_languk', 'Ukranian_languk', 'Ukranian-languk'):
        process_languk(paths)
    elif dataset_name == 'hi_ijc':
        process_ijc(paths, dataset_name)
    elif dataset_name.endswith("FIRE2013"):
        process_fire_2013(paths, dataset_name)
    elif dataset_name.endswith('WikiNER'):
        process_wikiner(paths, dataset_name)
    elif dataset_name.startswith('hu_rgai'):
        process_rgai(paths, dataset_name)
    elif dataset_name == 'hu_nytk':
        process_nytk(paths)
    elif dataset_name == 'hu_combined':
        process_hu_combined(paths)
    elif dataset_name.endswith("_bsnlp19"):
        process_bsnlp(paths, dataset_name)
    else:
        raise ValueError(f"dataset {dataset_name} currently not handled")
예제 #3
0
def main(run_treebank, model_dir, model_name, add_specific_args=None):
    """
    A main program for each of the run_xyz scripts.

    Parses the command line (letting ``add_specific_args`` add script-specific
    options), expands ``ud_all`` / ``all_ud`` into every UD treebank found
    under ``paths["UDBASE"]``, and calls ``run_treebank`` once per treebank.

    In TRAIN mode, unless --force is given or the model is 'ete', a treebank
    whose ``saved_models/<model_dir>/<short_name>_<model_name>.pt`` file
    already exists is skipped.

    Arguments after a literal ``--extra_args`` marker (or, without the
    marker, any arguments the parser does not recognize) are passed through
    to ``run_treebank`` unparsed.
    """
    logger.info("Training program called with:\n" + " ".join(sys.argv))

    paths = default_paths.get_default_paths()

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    if '--extra_args' in sys.argv:
        idx = sys.argv.index('--extra_args')
        extra_args = sys.argv[idx + 1:]
        # sys.argv[0] is the program name, not an argument: argparse must
        # only see the options between it and --extra_args.  Slicing from 0
        # would make the script path itself get parsed as an argument.
        command_args = parser.parse_args(sys.argv[1:idx])
    else:
        command_args, extra_args = parser.parse_known_args()

    mode = command_args.mode
    treebanks = []

    for treebank in command_args.treebanks:
        # this is a really annoying typo to make if you copy/paste a
        # UD directory name on the cluster and your job dies 30s after
        # being queued for an hour
        if treebank.endswith("/"):
            treebank = treebank[:-1]
        if treebank.lower() in ('ud_all', 'all_ud'):
            ud_treebanks = common.get_ud_treebanks(paths["UDBASE"])
            treebanks.extend(ud_treebanks)
        else:
            treebanks.append(treebank)

    for treebank in treebanks:
        # Names already in short form are used directly; full UD names are
        # converted.
        if SHORTNAME_RE.match(treebank):
            short_name = treebank
        else:
            short_name = treebank_to_short_name(treebank)
        logger.debug("%s: %s" % (treebank, short_name))

        # Avoid silently overwriting a previously trained model.
        if mode == Mode.TRAIN and not command_args.force and model_name != 'ete':
            model_path = "saved_models/%s/%s_%s.pt" % (model_dir, short_name,
                                                       model_name)
            if os.path.exists(model_path):
                logger.info("%s: %s exists, skipping!" %
                            (treebank, model_path))
                continue
            else:
                logger.info("%s: %s does not exist, training new model" %
                            (treebank, model_path))

        if command_args.temp_output and model_name != 'ete':
            # The temporary file is deleted when the with-block exits.
            with tempfile.NamedTemporaryFile() as temp_output_file:
                run_treebank(mode, paths, treebank, short_name,
                             temp_output_file.name, command_args, extra_args)
        else:
            run_treebank(mode, paths, treebank, short_name, None, command_args,
                         extra_args)
예제 #4
0
def main(process_treebank, add_specific_args=None):
    """
    A main program for the dataset preparation scripts.

    Parses the command line (letting ``add_specific_args`` add
    script-specific options), expands ``ud_all`` / ``all_ud`` into every UD
    treebank found under ``paths["UDBASE"]``, and calls
    ``process_treebank(treebank, paths, args)`` for each requested treebank.
    """
    logger.info("Datasets program called with:\n" + " ".join(sys.argv))

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    args = parser.parse_args()

    paths = default_paths.get_default_paths()

    treebanks = []
    for treebank in args.treebanks:
        # A UD directory name copy/pasted from a shell often carries a
        # trailing slash; strip it so the name still matches the treebank.
        if treebank.endswith("/"):
            treebank = treebank[:-1]
        if treebank.lower() in ('ud_all', 'all_ud'):
            ud_treebanks = get_ud_treebanks(paths["UDBASE"])
            treebanks.extend(ud_treebanks)
        else:
            treebanks.append(treebank)

    for treebank in treebanks:
        process_treebank(treebank, paths, args)
예제 #5
0
def main(run_treebank, model_dir, model_name, add_specific_args=None):
    """
    Shared main program for the run_xyz scripts.

    Builds the argument parser (letting ``add_specific_args`` register
    script-specific options), expands the requested treebanks, and calls
    ``run_treebank`` for each one.  When training, a treebank whose model
    file already exists is skipped unless --force is provided; models named
    'ete' are exempt from that check.

    Arguments after a literal ``--extra_args`` marker -- or, without the
    marker, whatever the parser does not recognize -- are handed to
    ``run_treebank`` as ``extra_args``.  A ``save_dir`` option is used here
    to locate existing models and is also forwarded through ``extra_args``.
    """
    logger.info("Training program called with:\n" + " ".join(sys.argv))

    paths = default_paths.get_default_paths()

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)

    if '--extra_args' not in sys.argv:
        command_args, extra_args = parser.parse_known_args()
    else:
        marker = sys.argv.index('--extra_args')
        extra_args = sys.argv[marker + 1:]
        # skip sys.argv[0]: it is the program name, not an option
        command_args = parser.parse_args(sys.argv[1:marker])

    # Pass this through to the underlying model as well as use it here
    if command_args.save_dir:
        extra_args.extend(["--save_dir", command_args.save_dir])

    mode = command_args.mode

    treebanks = []
    for requested in command_args.treebanks:
        # this is a really annoying typo to make if you copy/paste a
        # UD directory name on the cluster and your job dies 30s after
        # being queued for an hour
        if requested.endswith("/"):
            requested = requested[:-1]
        if requested.lower() in ('ud_all', 'all_ud'):
            treebanks.extend(common.get_ud_treebanks(paths["UDBASE"]))
        else:
            treebanks.append(requested)

    for position, treebank in enumerate(treebanks):
        # visually separate the log output of consecutive treebanks
        if position > 0:
            logger.info("=========================================")

        short_name = (treebank if SHORTNAME_RE.match(treebank)
                      else treebank_to_short_name(treebank))
        logger.debug("%s: %s" % (treebank, short_name))

        check_existing = (mode == Mode.TRAIN and not command_args.force
                          and model_name != 'ete')
        if check_existing:
            # re-read each iteration: run_treebank receives command_args
            save_dir = command_args.save_dir
            if save_dir:
                model_path = "%s/%s_%s.pt" % (save_dir, short_name, model_name)
            else:
                model_path = "saved_models/%s/%s_%s.pt" % (model_dir, short_name,
                                                           model_name)
            if os.path.exists(model_path):
                logger.info("%s: %s exists, skipping!" % (treebank, model_path))
                continue
            logger.info("%s: %s does not exist, training new model" %
                        (treebank, model_path))

        if not (command_args.temp_output and model_name != 'ete'):
            run_treebank(mode, paths, treebank, short_name, None, command_args,
                         extra_args)
            continue

        with tempfile.NamedTemporaryFile() as temp_output_file:
            run_treebank(mode, paths, treebank, short_name,
                         temp_output_file.name, command_args, extra_args)
import glob
import os

from stanza.models.common.constant import treebank_to_short_name
from stanza.utils import default_paths

# Regenerate short_name_to_treebank.py: a dict literal mapping each UD
# treebank's short name to its UD_* directory name under UDBASE.
paths = default_paths.get_default_paths()
udbase = paths["UDBASE"]

# Sorted so the generated module is identical across runs regardless of the
# order in which the filesystem lists entries.
directories = sorted(glob.glob(os.path.join(udbase, "UD_*")))

# The generated module is written next to this script.
output_name = os.path.join(os.path.dirname(__file__), "short_name_to_treebank.py")
# Explicit encoding so the output does not depend on the platform locale.
with open(output_name, "w", encoding="utf-8") as fout:
    fout.write(
        "# This module is autogenerated by build_short_name_to_treebank.py\n")
    fout.write("# Please do not edit\n")
    fout.write("\n")
    fout.write("SHORT_NAMES = {\n")
    for ud_path in directories:
        ud_name = os.path.basename(ud_path)
        short_name = treebank_to_short_name(ud_name)
        # %r emits a correctly escaped Python string literal, so the
        # generated file stays valid even if a name contains a quote or
        # backslash; for ordinary names it matches '<name>' exactly.
        fout.write("    %r: %r,\n" % (short_name, ud_name))

        # Short names with the zh_ prefix also get a zh-hans_ alias pointing
        # at the same treebank.
        if short_name.startswith("zh_"):
            short_name = "zh-hans_" + short_name[3:]
            fout.write("    %r: %r,\n" % (short_name, ud_name))

    fout.write("}\n")