def main():
    """
    Command line entry point: run process_treebank on every requested treebank.

    Each treebank named on the command line is passed to process_treebank
    together with the default paths and the --output_dir argument.
    """
    args = build_argparse().parse_args()
    paths = default_paths.get_default_paths()
    for name in args.treebanks:
        process_treebank(name, paths, args.output_dir)
def main():
    """
    Convert a single dataset, selected by the first command line argument.

    The name in sys.argv[1] is matched against an ordered list of rules, and the
    first rule that matches decides which converter runs.

    Raises:
        ValueError: if no rule matches the dataset name.
    """
    paths = default_paths.get_default_paths()
    dataset_name = sys.argv[1]
    # Fixed seed; presumably so that converters drawing from the module-level
    # `random` state produce reproducible output (TODO confirm in converters).
    random.seed(1234)

    # Checked top to bottom; order matters because some rules are
    # prefix/suffix patterns rather than exact names.
    rules = (
        (lambda name: name == 'fi_turku',
         lambda: process_turku(paths)),
        (lambda name: name in ('uk_languk', 'Ukranian_languk', 'Ukranian-languk'),
         lambda: process_languk(paths)),
        (lambda name: name == 'hi_ijc',
         lambda: process_ijc(paths, dataset_name)),
        (lambda name: name.endswith("FIRE2013"),
         lambda: process_fire_2013(paths, dataset_name)),
        (lambda name: name.endswith('WikiNER'),
         lambda: process_wikiner(paths, dataset_name)),
        (lambda name: name.startswith('hu_rgai'),
         lambda: process_rgai(paths, dataset_name)),
        (lambda name: name == 'hu_nytk',
         lambda: process_nytk(paths)),
        (lambda name: name == 'hu_combined',
         lambda: process_hu_combined(paths)),
        (lambda name: name.endswith("_bsnlp19"),
         lambda: process_bsnlp(paths, dataset_name)),
    )

    for matches, convert in rules:
        if matches(dataset_name):
            convert()
            return

    raise ValueError(f"dataset {dataset_name} currently not handled")
def main(run_treebank, model_dir, model_name, add_specific_args=None):
    """
    Shared entry point for the run_xyz training/eval scripts.

    Parses the command line and expands 'ud_all' / 'all_ud' into every UD
    treebank under paths["UDBASE"]. It then calls run_treebank once per
    treebank.

    In TRAIN mode without --force, a treebank is skipped when
    saved_models/<model_dir>/<short_name>_<model_name>.pt already exists.
    This check is not done when model_name is 'ete'.

    Args:
        run_treebank: callable invoked as
            run_treebank(mode, paths, treebank, short_name, temp_output_name,
                         command_args, extra_args)
            where temp_output_name is a temporary file name (when
            --temp_output is set and model_name != 'ete') or None.
        model_dir: subdirectory of saved_models/ checked for an existing model.
        model_name: suffix used in the saved model filename.
        add_specific_args: optional callable that adds script-specific
            arguments to the parser.

    Everything after --extra_args on the command line is forwarded to
    run_treebank verbatim as extra_args. Without that marker, the arguments
    the parser does not recognize are forwarded instead.
    """
    logger.info("Training program called with:\n" + " ".join(sys.argv))

    paths = default_paths.get_default_paths()
    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    if '--extra_args' in sys.argv:
        idx = sys.argv.index('--extra_args')
        extra_args = sys.argv[idx + 1:]
        # Skip sys.argv[0]: it is the program path, not an argument.
        # Passing it would make argparse treat it as a real argument
        # (presumably an extra treebank positional). parse_known_args()
        # below also parses sys.argv[1:] by default.
        command_args = parser.parse_args(sys.argv[1:idx])
    else:
        command_args, extra_args = parser.parse_known_args()

    mode = command_args.mode
    treebanks = []

    for treebank in command_args.treebanks:
        # this is a really annoying typo to make if you copy/paste a
        # UD directory name on the cluster and your job dies 30s after
        # being queued for an hour
        if treebank.endswith("/"):
            treebank = treebank[:-1]
        if treebank.lower() in ('ud_all', 'all_ud'):
            ud_treebanks = common.get_ud_treebanks(paths["UDBASE"])
            treebanks.extend(ud_treebanks)
        else:
            treebanks.append(treebank)

    for treebank in treebanks:
        if SHORTNAME_RE.match(treebank):
            short_name = treebank
        else:
            short_name = treebank_to_short_name(treebank)
        logger.debug("%s: %s", treebank, short_name)

        if mode == Mode.TRAIN and not command_args.force and model_name != 'ete':
            model_path = "saved_models/%s/%s_%s.pt" % (model_dir, short_name, model_name)
            if os.path.exists(model_path):
                logger.info("%s: %s exists, skipping!", treebank, model_path)
                continue
            else:
                logger.info("%s: %s does not exist, training new model", treebank, model_path)

        if command_args.temp_output and model_name != 'ete':
            # the temporary file is removed when the with block exits
            with tempfile.NamedTemporaryFile() as temp_output_file:
                run_treebank(mode, paths, treebank, short_name,
                             temp_output_file.name, command_args, extra_args)
        else:
            run_treebank(mode, paths, treebank, short_name,
                         None, command_args, extra_args)
def main(process_treebank, add_specific_args=None):
    """
    Shared entry point for the dataset preparation scripts.

    Parses the command line, optionally extended by add_specific_args. It then
    expands 'ud_all' / 'all_ud' (case-insensitive) into every UD treebank
    under paths["UDBASE"]. Finally it calls process_treebank(treebank, paths,
    args) for each resulting treebank. The full list is built before any
    treebank is processed.
    """
    logger.info("Datasets program called with:\n" + " ".join(sys.argv))

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)
    args = parser.parse_args()
    paths = default_paths.get_default_paths()

    treebanks = []
    for requested in args.treebanks:
        if requested.lower() not in ('ud_all', 'all_ud'):
            treebanks.append(requested)
            continue
        treebanks.extend(get_ud_treebanks(paths["UDBASE"]))

    for treebank in treebanks:
        process_treebank(treebank, paths, args)
def main(run_treebank, model_dir, model_name, add_specific_args=None):
    """
    Common driver behind the run_xyz scripts.

    Builds the argument parser, plus any script-specific arguments, and works
    out which treebanks were requested. It then calls run_treebank for each
    of them.

    When training without --force, a treebank whose model file already exists
    is skipped so that the model is not overwritten. The file looked for is
    <save_dir>/<short_name>_<model_name>.pt when --save_dir is given,
    otherwise saved_models/<model_dir>/<short_name>_<model_name>.pt. This
    check is not done for the 'ete' model.

    Arguments after --extra_args are forwarded verbatim to run_treebank as
    extra_args. Without that marker, whatever the parser does not recognize is
    forwarded instead. A non-empty --save_dir is also appended to extra_args.
    """
    logger.info("Training program called with:\n" + " ".join(sys.argv))

    paths = default_paths.get_default_paths()

    parser = build_argparse()
    if add_specific_args is not None:
        add_specific_args(parser)

    if '--extra_args' not in sys.argv:
        command_args, extra_args = parser.parse_known_args()
    else:
        marker = sys.argv.index('--extra_args')
        extra_args = sys.argv[marker + 1:]
        # start at 1 so the program name is not parsed as an argument
        command_args = parser.parse_args(sys.argv[1:marker])

    # the underlying model needs save_dir as well as this driver
    if command_args.save_dir:
        extra_args.extend(["--save_dir", command_args.save_dir])

    mode = command_args.mode

    treebanks = []
    for requested in command_args.treebanks:
        # A copy/pasted UD directory name often carries a trailing "/".
        # Left alone, the job would only fail after waiting in the queue.
        requested = requested[:-1] if requested.endswith("/") else requested
        if requested.lower() in ('ud_all', 'all_ud'):
            treebanks.extend(common.get_ud_treebanks(paths["UDBASE"]))
        else:
            treebanks.append(requested)

    for position, treebank in enumerate(treebanks):
        if position > 0:
            logger.info("=========================================")

        short_name = treebank if SHORTNAME_RE.match(treebank) else treebank_to_short_name(treebank)
        logger.debug("%s: %s" % (treebank, short_name))

        if mode == Mode.TRAIN and not command_args.force and model_name != 'ete':
            if command_args.save_dir:
                model_path = "%s/%s_%s.pt" % (command_args.save_dir, short_name, model_name)
            else:
                model_path = "saved_models/%s/%s_%s.pt" % (model_dir, short_name, model_name)
            if os.path.exists(model_path):
                logger.info("%s: %s exists, skipping!" % (treebank, model_path))
                continue
            logger.info("%s: %s does not exist, training new model" % (treebank, model_path))

        if not (command_args.temp_output and model_name != 'ete'):
            run_treebank(mode, paths, treebank, short_name, None, command_args, extra_args)
            continue

        # run_treebank gets the temp file's name; the file is removed when the with block exits
        with tempfile.NamedTemporaryFile() as temp_output_file:
            run_treebank(mode, paths, treebank, short_name,
                         temp_output_file.name, command_args, extra_args)
# Regenerates short_name_to_treebank.py, written next to this script.
# The file maps the short name of every UD_* directory under
# paths["UDBASE"] to that directory's name. Chinese ("zh_") treebanks
# additionally get a "zh-hans_" alias.
import glob
import os

from stanza.models.common.constant import treebank_to_short_name
from stanza.utils import default_paths

paths = default_paths.get_default_paths()
udbase = paths["UDBASE"]

# sorted so the generated file is stable across runs/filesystems
directories = sorted(glob.glob(os.path.join(udbase, "UD_*")))

output_name = os.path.join(os.path.dirname(__file__), "short_name_to_treebank.py")

# Explicit utf-8 (the Python source default), so the output does not
# depend on the locale's preferred encoding.
with open(output_name, "w", encoding="utf-8") as fout:
    fout.write("# This module is autogenerated by build_short_name_to_treebank.py\n")
    fout.write("# Please do not edit\n")
    fout.write("\n")
    fout.write("SHORT_NAMES = {\n")
    for ud_path in directories:
        ud_name = os.path.basename(ud_path)
        short_name = treebank_to_short_name(ud_name)
        # %r emits properly quoted/escaped literals, so the generated module
        # stays valid Python even if a name contains a quote or backslash
        fout.write(" %r: %r,\n" % (short_name, ud_name))
        if short_name.startswith("zh_"):
            fout.write(" %r: %r,\n" % ("zh-hans_" + short_name[3:], ud_name))
    fout.write("}\n")