Example #1
0
def train(
    config_file: pathlib.Path,
    dev_file: Optional[pathlib.Path],
    device: str,
    fasttext: Optional[pathlib.Path],
    max_tree_length: Optional[int],
    output_dir: pathlib.Path,
    overwrite: bool,
    rand_seed: int,
    test_file: Optional[pathlib.Path],
    train_file: pathlib.Path,
):
    """Train a parser, then parse and score the optional dev and test treebanks.

    The model is trained through ``graph_parser.train`` with its model path set
    to ``output_dir / "model"``. For each of ``dev_file`` and ``test_file`` that
    is not ``None``, the treebank is parsed with the freshly trained model into
    ``output_dir / "<stem>.parsed.conllu"`` and compared against the original
    file with ``evaluator``. The UPOS and LAS F1 scores are collected under the
    keys ``"UPOS (dev)"``, ``"LAS (dev)"``, ``"UPOS (test)"`` and
    ``"LAS (test)"`` and, if at least one treebank was scored, echoed as a
    table built by ``make_metrics_table``.

    ``device`` is forwarded to every ``graph_parser`` call as the ``"device"``
    override.
    """
    model_path = output_dir / "model"
    graph_parser.train(
        config_file=config_file,
        dev_file=dev_file,
        train_file=train_file,
        fasttext=fasttext,
        max_tree_length=max_tree_length,
        model_path=model_path,
        overrides={"device": device},
        overwrite=overwrite,
        rand_seed=rand_seed,
    )

    scores = {}
    # Dev first, then test: this fixes the order of the rows in the table.
    for split, treebank in (("dev", dev_file), ("test", test_file)):
        if treebank is None:
            continue
        parsed_path = output_dir / f"{treebank.stem}.parsed.conllu"
        graph_parser.parse(
            model_path,
            treebank,
            parsed_path,
            overrides={"device": device},
        )
        gold = evaluator.load_conllu_file(treebank)
        system = evaluator.load_conllu_file(parsed_path)
        results = evaluator.evaluate(gold, system)
        for metric_name in ("UPOS", "LAS"):
            scores[f"{metric_name} ({split})"] = results[metric_name].f1

    if scores:
        click.echo(make_metrics_table(scores))
Example #2
0
def make_csv_summary(
    syst_files: Iterable[pathlib.Path],
    gold_file: pathlib.Path,
    out_file: TextIO,
    onlyf: bool,
    metrics: List[str],
):
    """Write a CSV table scoring several system outputs against one gold file.

    The first line written to ``out_file`` is a header; each following line
    holds the scores of one file of ``syst_files``, identified by its stem.

    Args:
        syst_files: System output files to evaluate, each loaded with
            ``evaluator.load_conllu_file``.
        gold_file: The reference file every system file is compared to.
        out_file: Text stream receiving the comma-separated lines.
        onlyf: If true, emit one column per metric holding its F1 score.
            Otherwise emit three columns per metric (``<metric>_P``,
            ``<metric>_R``, ``<metric>_F``) for precision, recall and F1.
        metrics: Names of the metrics to report, in column order.
    """
    gold_conllu = evaluator.load_conllu_file(gold_file)

    if onlyf:
        header = ["name", *metrics]
    else:
        # The header must be derived from `metrics`, exactly like the rows
        # below, otherwise the column titles and the cells do not line up
        # whenever only a subset of the available metrics is requested.
        header = ["name", *(f"{m}_{p}" for m in metrics for p in ("P", "R", "F"))]
    print(",".join(header), file=out_file)

    for syst_file in syst_files:
        syst_conllu = evaluator.load_conllu_file(syst_file)
        eval_metrics = evaluator.evaluate(gold_conllu, syst_conllu)
        row: List[str] = [syst_file.stem]
        for m in metrics:
            mres = eval_metrics[m]
            if onlyf:
                row.append(str(mres.f1))
            else:
                row.extend((str(mres.precision), str(mres.recall), str(mres.f1)))
        print(",".join(row), file=out_file)
Example #3
0
def evaluate(
    device: str,
    intermediary_dir: str,
    model_path: pathlib.Path,
    out_format: Literal["md", "json"],
    treebank_path: str,
):
    """Parse a treebank with a trained model and report UPOS/UAS/LAS F1 scores.

    Args:
        device: Passed to ``graph_parser.parse`` as the ``"device"`` override.
        intermediary_dir: Handed to ``dir_manager``, which yields the directory
            where ``parsed.conllu`` (and ``input.conllu`` when reading from
            standard input) are written.
        model_path: The trained model given to ``graph_parser.parse``.
        out_format: ``"md"`` echoes a table built by ``make_metrics_table``;
            ``"json"`` dumps a JSON object to standard output.
        treebank_path: Path of the gold treebank, or ``"-"`` to read it from
            standard input.

    Raises:
        ValueError: If ``out_format`` is neither ``"md"`` nor ``"json"``.
    """
    # Reject a bad format up front instead of after the (costly) parsing run.
    if out_format not in ("md", "json"):
        raise ValueError(f"Unknown format {out_format!r}.")

    input_file: pathlib.Path
    with dir_manager(intermediary_dir) as intermediary_path:
        if treebank_path == "-":
            # Materialize stdin as a file so it can be both parsed and reloaded
            # as the gold reference.
            input_file = intermediary_path / "input.conllu"
            input_file.write_text(sys.stdin.read())
        else:
            input_file = pathlib.Path(treebank_path)

        output_file = intermediary_path / "parsed.conllu"
        graph_parser.parse(model_path,
                           input_file,
                           output_file,
                           overrides={"device": device})
        # Both files are loaded while still inside the `with` block
        # (presumably because dir_manager may clean the directory on exit —
        # TODO confirm).
        gold_set = evaluator.load_conllu_file(str(input_file))
        syst_set = evaluator.load_conllu_file(str(output_file))

    metrics = evaluator.evaluate(gold_set, syst_set)
    output_metrics = {n: metrics[n].f1 for n in ("UPOS", "UAS", "LAS")}
    if out_format == "md":
        click.echo(make_metrics_table(output_metrics))
    else:
        json.dump(output_metrics, sys.stdout)
Example #4
0
def train_single_model(
    train_file: pathlib.Path,
    dev_file: pathlib.Path,
    pred_file: pathlib.Path,
    out_dir: pathlib.Path,
    config_path: pathlib.Path,
    device: str,
    additional_args: Dict[str, str],
) -> TrainResults:
    """Train one model via the ``graph_parser`` command and score its outputs.

    The ``graph_parser`` executable is run as a subprocess (a non-zero exit
    status raises ``subprocess.CalledProcessError`` because of ``check=True``).
    Entries of ``additional_args`` with a truthy value are appended as
    ``--<key> <value>`` options; the configuration path is always the last
    argument.

    After the run, ``out_dir / "<name>.parsed"`` is read for both ``dev_file``
    and ``pred_file`` and evaluated against the respective input file. The
    UPOS and LAS F1 scores of both are returned in a ``TrainResults``.
    """
    command = [
        "graph_parser",
        "--train_file",
        str(train_file),
        "--dev_file",
        str(dev_file),
        "--pred_file",
        str(pred_file),
        "--out_dir",
        str(out_dir),
        "--device",
        device,
    ]
    for option, setting in additional_args.items():
        if not setting:
            # Unset/empty options are simply left off the command line.
            continue
        command.append(f"--{option}")
        command.append(setting)
    command.append(str(config_path))
    subprocess.run(command, check=True)

    dev_metrics = evaluator.evaluate(
        evaluator.load_conllu_file(dev_file),
        evaluator.load_conllu_file(out_dir / f"{dev_file.name}.parsed"),
    )
    test_metrics = evaluator.evaluate(
        evaluator.load_conllu_file(pred_file),
        evaluator.load_conllu_file(out_dir / f"{pred_file.name}.parsed"),
    )

    return TrainResults(
        dev_upos=dev_metrics["UPOS"].f1,
        dev_las=dev_metrics["LAS"].f1,
        test_upos=test_metrics["UPOS"].f1,
        test_las=test_metrics["LAS"].f1,
    )
Example #5
0
def main():
    """Command-line entry point: train a parser and/or parse a file.

    Arguments are read from ``sys.argv``. Two independent modes exist and may
    both run in a single invocation:

    * Train mode, when both ``--train_file`` and ``--dev_file`` are given: the
      model resources are (re)generated if needed, the parser is trained, the
      dev file is parsed to ``<dev basename>.parsed`` and its UPOS/UAS/LAS F1
      scores are printed to standard error.
    * Test mode, when ``--pred_file`` is given: the file is parsed to
      ``<pred basename>.parsed``.

    Parsed files go to ``--out_dir`` when it is given, otherwise next to the
    file that was parsed.
    """
    parser = argparse.ArgumentParser(
        description="Graph based Attention based dependency parser/tagger")
    parser.add_argument("config_file",
                        metavar="CONFIG_FILE",
                        type=str,
                        help="the configuration file")
    parser.add_argument("--train_file",
                        metavar="TRAIN_FILE",
                        type=str,
                        help="the conll training file")
    parser.add_argument("--dev_file",
                        metavar="DEV_FILE",
                        type=str,
                        help="the conll development file")
    parser.add_argument("--pred_file",
                        metavar="PRED_FILE",
                        type=str,
                        help="the conll file to parse")
    parser.add_argument(
        "--out_dir",
        metavar="OUT_DIR",
        type=str,
        help="the path of the output directory (defaults to the config dir)",
    )
    parser.add_argument(
        "--fasttext",
        metavar="PATH",
        help=
        "The path to either an existing FastText model or a raw text file to train one. If this option is absent, a model will be trained from the parsing train set.",
    )
    parser.add_argument(
        "--device",
        metavar="DEVICE",
        type=str,
        help=
        "the (torch) device to use for the parser. Supersedes configuration if given",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help=
        "If a model already exists, restart training from scratch instead of continuing.",
    )

    args = parser.parse_args()
    # Only override the device when one was explicitly requested.
    if args.device is not None:
        overrides = {"device": args.device}
    else:
        overrides = dict()

    # TODO: warn about unused parameters in config
    config_file = os.path.abspath(args.config_file)
    if args.train_file and args.out_dir:
        # Training into a dedicated output dir: the model lives in
        # <out_dir>/model and works from its own copy of the configuration.
        model_dir = os.path.join(args.out_dir, "model")
        os.makedirs(model_dir, exist_ok=True)
        config_file = shutil.copy(args.config_file, model_dir)
    else:
        # Otherwise the directory holding the config file is the model dir.
        model_dir = os.path.dirname(config_file)

    with open(config_file) as in_stream:
        hp = yaml.load(in_stream, Loader=yaml.SafeLoader)

    if args.train_file and args.dev_file:
        # TRAIN MODE
        weights_file = os.path.join(model_dir, "model.pt")
        # `overwrite` ends up True when training starts from scratch (no
        # weights yet, or --overwrite given) and False when resuming.
        if os.path.exists(weights_file):
            print(f"Found existing trained model in {model_dir}",
                  file=sys.stderr)
            overwrite = args.overwrite
            if args.overwrite:
                print("Erasing it since --overwrite was asked",
                      file=sys.stderr)
                # Ensure the parser won't load existing weights
                os.remove(weights_file)
                overwrite = True
            else:
                print("Continuing training", file=sys.stderr)
                overwrite = False
        else:
            overwrite = True
        # Only the training treebank is read with a length limit.
        traintrees = DependencyDataset.read_conll(args.train_file,
                                                  max_tree_length=150)
        devtrees = DependencyDataset.read_conll(args.dev_file)

        if overwrite:
            # Starting from scratch: (re)build the FastText model and the
            # vocabulary/character/label/tag lists stored in the model dir.
            fasttext_model_path = os.path.join(model_dir, "fasttext_model.bin")
            if args.fasttext is None:
                # No --fasttext: reuse a model already in the model dir, but
                # only when no --out_dir was given; otherwise train a new one
                # from the training sentences.
                if os.path.exists(fasttext_model_path) and not args.out_dir:
                    print(f"Using the FastText model at {fasttext_model_path}")
                else:
                    if os.path.exists(fasttext_model_path):
                        # NOTE(review): this branch is also reachable when the
                        # weights file was simply missing, i.e. without
                        # --overwrite, so the message can be inaccurate.
                        print(
                            f"Erasing the FastText model at {fasttext_model_path} since --overwrite was asked",
                            file=sys.stderr,
                        )
                        os.remove(fasttext_model_path)
                    print(
                        f"Generating a FastText model from {args.train_file}")
                    # words[1:] skips the first item of each tree (presumably
                    # the [ROOT] token mentioned in the NOTE below).
                    FastTextTorch.train_model_from_sents(
                        [tree.words[1:] for tree in traintrees],
                        fasttext_model_path)
            elif os.path.exists(args.fasttext):
                # --fasttext points to an existing file: it replaces whatever
                # FastText model the model dir currently holds.
                if os.path.exists(fasttext_model_path):
                    os.remove(fasttext_model_path)
                try:
                    # ugly, but we have no better way of checking if a file is a valid model
                    FastTextTorch.loadmodel(args.fasttext)
                    print(f"Using the FastText model at {args.fasttext}")
                    shutil.copy(args.fasttext, fasttext_model_path)
                except ValueError:
                    # FastText couldn't load it, so it should be raw text
                    print(f"Generating a FastText model from {args.fasttext}")
                    FastTextTorch.train_model_from_raw(args.fasttext,
                                                       fasttext_model_path)
            else:
                raise ValueError(f"{args.fasttext} not found")

            # NOTE: these include the [ROOT] token, which will thus automatically have a dedicated
            # word embeddings in layers based on this vocab
            ordered_vocab = make_vocab(
                [word for tree in traintrees for word in tree.words],
                0,
                unk_word=DependencyDataset.UNK_WORD,
                pad_token=DependencyDataset.PAD_TOKEN,
            )
            savelist(ordered_vocab, os.path.join(model_dir, "vocab.lst"))

            # FIXME: A better save that can restore special tokens is probably a good idea
            ordered_charset = CharDataSet.from_words(
                ordered_vocab,
                special_tokens=[DepGraph.ROOT_TOKEN],
            )
            savelist(ordered_charset.i2c,
                     os.path.join(model_dir, "charcodes.lst"))

            itolab = gen_labels(traintrees)
            savelist(itolab, os.path.join(model_dir, "labcodes.lst"))

            itotag = gen_tags(traintrees)
            savelist(itotag, os.path.join(model_dir, "tagcodes.lst"))

        # From here on `parser` is the BiAffineParser; the argparse parser
        # bound to the same name above is no longer needed.
        parser = BiAffineParser.from_config(config_file, overrides)

        ft_dataset = FastTextDataSet(parser.ft_lexer,
                                     special_tokens=[DepGraph.ROOT_TOKEN])
        trainset = DependencyDataset(
            traintrees,
            parser.lexer,
            parser.charset,
            ft_dataset,
            use_labels=parser.labels,
            use_tags=parser.tagset,
        )
        devset = DependencyDataset(
            devtrees,
            parser.lexer,
            parser.charset,
            ft_dataset,
            use_labels=parser.labels,
            use_tags=parser.tagset,
        )

        # Epochs, batch size and learning rate come from the YAML config.
        parser.train_model(
            trainset,
            devset,
            hp["epochs"],
            hp["batch_size"],
            hp["lr"],
            modelpath=weights_file,
        )
        print("training done.", file=sys.stderr)
        # Load final params
        parser.load_params(weights_file)
        parser.eval()
        # The parsed dev file goes to --out_dir if given, else next to the
        # dev file itself.
        if args.out_dir is not None:
            parsed_devset_path = os.path.join(
                args.out_dir, f"{os.path.basename(args.dev_file)}.parsed")
        else:
            parsed_devset_path = os.path.join(
                os.path.dirname(args.dev_file),
                f"{os.path.basename(args.dev_file)}.parsed",
            )
        with open(parsed_devset_path, "w") as ostream:
            parser.predict_batch(devset,
                                 ostream,
                                 hp["batch_size"],
                                 greedy=False)
        # Score the parsed dev file against the original dev file.
        gold_devset = evaluator.load_conllu_file(args.dev_file)
        syst_devset = evaluator.load_conllu_file(parsed_devset_path)
        dev_metrics = evaluator.evaluate(gold_devset, syst_devset)
        print(
            f"Dev-best results: {dev_metrics['UPOS'].f1} UPOS\t{dev_metrics['UAS'].f1} UAS\t{dev_metrics['LAS'].f1} LAS",
            file=sys.stderr,
        )

    if args.pred_file:
        # TEST MODE
        # The parser is always rebuilt from the config here, even right after
        # a training run in the same invocation.
        parser = BiAffineParser.from_config(config_file, overrides)
        parser.eval()
        testtrees = DependencyDataset.read_conll(args.pred_file)
        # FIXME: the special tokens should be saved somewhere instead of hardcoded
        ft_dataset = FastTextDataSet(parser.ft_lexer,
                                     special_tokens=[DepGraph.ROOT_TOKEN])
        testset = DependencyDataset(
            testtrees,
            parser.lexer,
            parser.charset,
            ft_dataset,
            use_labels=parser.labels,
            use_tags=parser.tagset,
        )
        # Same output location rule as for the dev file above.
        if args.out_dir is not None:
            parsed_testset_path = os.path.join(
                args.out_dir, f"{os.path.basename(args.pred_file)}.parsed")
        else:
            parsed_testset_path = os.path.join(
                os.path.dirname(args.pred_file),
                f"{os.path.basename(args.pred_file)}.parsed",
            )
        with open(parsed_testset_path, "w") as ostream:
            parser.predict_batch(testset,
                                 ostream,
                                 hp["batch_size"],
                                 greedy=False)
        print("parsing done.", file=sys.stderr)
Example #6
0
def main(argv=None):
    """Deprecated ``graph_parser`` command-line entry point.

    Emits a ``FutureWarning`` about the deprecation, parses ``argv`` (or
    ``sys.argv`` when it is ``None``) and then, depending on the options:

    * with both ``--train_file`` and ``--dev_file``: trains a model through
      ``train``, parses the dev file and prints its UPOS/UAS/LAS F1 scores to
      standard error;
    * with ``--pred_file``: parses that file.

    Parsed files are named ``<basename>.parsed`` and are written to
    ``--out_dir`` when it is given, otherwise next to the parsed file.

    Returns:
        ``1`` if ``--overwrite`` is used without ``--out_dir`` (nothing else is
        done in that case); ``None`` otherwise.
    """
    cli = argparse.ArgumentParser(
        description="Graph based Attention based dependency parser/tagger"
    )
    cli.add_argument(
        "config_file", metavar="CONFIG_FILE", type=str, help="the configuration file"
    )
    cli.add_argument(
        "--train_file", metavar="TRAIN_FILE", type=str, help="the conll training file"
    )
    cli.add_argument(
        "--dev_file", metavar="DEV_FILE", type=str, help="the conll development file"
    )
    cli.add_argument(
        "--pred_file", metavar="PRED_FILE", type=str, help="the conll file to parse"
    )
    cli.add_argument(
        "--device",
        metavar="DEVICE",
        type=str,
        help="the (torch) device to use for the parser. Supersedes configuration if given",
    )
    cli.add_argument(
        "--fasttext",
        metavar="PATH",
        help="The path to either an existing FastText model or a raw text file to train one. If this option is absent, a model will be trained from the parsing train set.",
    )
    cli.add_argument(
        "--out_dir",
        metavar="OUT_DIR",
        type=str,
        help="the path of the output directory (defaults to the config dir)",
    )
    cli.add_argument(
        "--overwrite",
        action="store_true",
        help="If a model already exists, restart training from scratch instead of continuing.",
    )
    cli.add_argument(
        "--rand_seed",
        metavar="SEED",
        type=int,
        help="Force the random seed fo Python and Pytorch (see <https://pytorch.org/docs/stable/notes/randomness.html> for notes on reproducibility)",
    )
    # The warning is raised before argument parsing, so it shows up even when
    # the command line turns out to be invalid.
    warnings.warn(
        "The `graph_parser` interface is DEPRECATED and will be removed in a future release, use `hopsparser train` instead",
        category=FutureWarning,
    )
    args = cli.parse_args(argv)

    if args.overwrite and not args.out_dir:
        print("ERROR: overwriting is only supported with --out_dir", file=sys.stderr)
        return 1

    overrides = dict() if args.device is None else {"device": args.device}

    def parsed_output_path(treebank):
        """Return where the parsed version of `treebank` must be written."""
        if args.out_dir is not None:
            target_dir = args.out_dir
        else:
            target_dir = os.path.dirname(treebank)
        return os.path.join(target_dir, f"{os.path.basename(treebank)}.parsed")

    # TODO: warn about unused parameters in config
    if not (args.train_file and args.out_dir):
        model_dir = pathlib.Path(args.config_file).parent
        # Work on a temporary copy of the configuration. The temporary file
        # object is kept in a local so that it stays alive until this function
        # returns; this code path is slated for removal anyway.
        _temp_config_file = tempfile.NamedTemporaryFile()
        shutil.copy(args.config_file, _temp_config_file.name)
        config_file = pathlib.Path(_temp_config_file.name)
        trained_config_file = pathlib.Path(args.config_file)
    else:
        model_dir = pathlib.Path(args.out_dir) / "model"
        config_file = pathlib.Path(args.config_file)
        trained_config_file = model_dir / "config.yaml"

    if args.train_file and args.dev_file:
        fasttext_path = None if args.fasttext is None else pathlib.Path(args.fasttext)
        print("Start training", file=sys.stderr)
        train(
            config_file=pathlib.Path(config_file),
            dev_file=pathlib.Path(args.dev_file),
            train_file=pathlib.Path(args.train_file),
            fasttext=fasttext_path,
            max_tree_length=150,
            model_path=model_dir,
            overrides=overrides,
            overwrite=args.overwrite,
            rand_seed=args.rand_seed,
        )
        print("Training done.", file=sys.stderr)
        print("Parsing dev corpus", file=sys.stderr)
        parsed_devset_path = parsed_output_path(args.dev_file)
        parse(model_dir, args.dev_file, parsed_devset_path, overrides=overrides)
        dev_metrics = evaluator.evaluate(
            evaluator.load_conllu_file(args.dev_file),
            evaluator.load_conllu_file(parsed_devset_path),
        )
        print(
            f"Dev-best results: {dev_metrics['UPOS'].f1} UPOS\t{dev_metrics['UAS'].f1} UAS\t{dev_metrics['LAS'].f1} LAS",
            file=sys.stderr,
        )

    if args.pred_file:
        # TEST MODE
        parsed_testset_path = parsed_output_path(args.pred_file)
        print("Parsing test corpus", file=sys.stderr)
        parse(
            trained_config_file,
            args.pred_file,
            parsed_testset_path,
            overrides=overrides,
        )
        print("Parsing done.", file=sys.stderr)
Example #7
0
def main(argv=None):
    """Command-line entry point: train a parser and/or parse a file.

    ``argv`` is the list of command-line arguments (``sys.argv`` is used when
    it is ``None``). Two modes exist and may both run in one invocation:

    * Train mode, when both ``--train_file`` and ``--dev_file`` are given: a
      parser is loaded from the model directory to continue training, or
      initialized from scratch (when the directory does not exist or
      ``--overwrite`` is set), trained, and evaluated on the dev file. The
      UPOS/UAS/LAS F1 scores are printed to standard error.
    * Test mode, when ``--pred_file`` is given: the file is parsed.

    Parsed files are named ``<basename>.parsed`` and are written to
    ``--out_dir`` when it is given, otherwise next to the parsed file.

    Returns:
        ``1`` if training is requested with ``--overwrite`` but without
        ``--out_dir``; ``None`` otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Graph based Attention based dependency parser/tagger"
    )
    parser.add_argument(
        "config_file", metavar="CONFIG_FILE", type=str, help="the configuration file"
    )
    parser.add_argument(
        "--train_file", metavar="TRAIN_FILE", type=str, help="the conll training file"
    )
    parser.add_argument(
        "--dev_file", metavar="DEV_FILE", type=str, help="the conll development file"
    )
    parser.add_argument(
        "--pred_file", metavar="PRED_FILE", type=str, help="the conll file to parse"
    )
    parser.add_argument(
        "--device",
        metavar="DEVICE",
        type=str,
        help="the (torch) device to use for the parser. Supersedes configuration if given",
    )
    parser.add_argument(
        "--fasttext",
        metavar="PATH",
        help="The path to either an existing FastText model or a raw text file to train one. If this option is absent, a model will be trained from the parsing train set.",
    )
    parser.add_argument(
        "--out_dir",
        metavar="OUT_DIR",
        type=str,
        help="the path of the output directory (defaults to the config dir)",
    )
    parser.add_argument(
        "--overwrite",
        action="store_true",
        help="If a model already exists, restart training from scratch instead of continuing.",
    )
    parser.add_argument(
        "--rand_seed",
        metavar="SEED",
        type=int,
        help="Force the random seed fo Python and Pytorch (see <https://pytorch.org/docs/stable/notes/randomness.html> for notes on reproducibility)",
    )

    args = parser.parse_args(argv)
    # Seed both Python's and PyTorch's generators when a seed is requested.
    if args.rand_seed is not None:
        random.seed(args.rand_seed)
        torch.manual_seed(args.rand_seed)

    if args.device is not None:
        overrides = {"device": args.device}
    else:
        overrides = dict()

    # TODO: warn about unused parameters in config
    if args.train_file and args.out_dir:
        model_dir = pathlib.Path(args.out_dir) / "model"
        config_file = pathlib.Path(args.config_file)
        trained_config_file = model_dir / "config.yaml"
    else:
        model_dir = pathlib.Path(args.config_file).parent
        # We need to give the temp file a name to avoid garbage collection before the method exits
        # this is not very clean but this code path will be deprecated soon anyway.
        _temp_config_file = tempfile.NamedTemporaryFile()
        shutil.copy(args.config_file, _temp_config_file.name)
        config_file = pathlib.Path(_temp_config_file.name)
        trained_config_file = pathlib.Path(args.config_file)

    with open(config_file) as in_stream:
        hp = yaml.load(in_stream, Loader=yaml.SafeLoader)
    if "device" in hp:
        warnings.warn(
            "Setting a device directly in a configuration file is deprecated and will be removed in a future version. Use --device instead."
        )

    if args.train_file and args.dev_file:
        # TRAIN MODE
        # Reject the unsupported option combination before doing any costly
        # work such as reading the treebanks.
        if args.overwrite and not args.out_dir:
            print("ERROR: overwriting is only supported with --out_dir", file=sys.stderr)
            return 1

        # Only the training treebank is read with a length limit.
        traintrees = DependencyDataset.read_conll(args.train_file, max_tree_length=150)
        devtrees = DependencyDataset.read_conll(args.dev_file)
        if os.path.exists(model_dir) and not args.overwrite:
            print(f"Continuing training from {model_dir}", file=sys.stderr)
            parser = BiAffineParser.load(model_dir, overrides)
        else:
            # --overwrite may be given on a first run, when there is no model
            # directory yet: only erase (and announce it) if something exists,
            # since shutil.rmtree fails on a missing path.
            if args.overwrite and os.path.exists(model_dir):
                print(
                    f"Erasing existing trained model in {model_dir} since --overwrite was asked",
                    file=sys.stderr,
                )
                shutil.rmtree(model_dir)
            parser = BiAffineParser.initialize(
                config_path=config_file,
                model_path=model_dir,
                overrides=overrides,
                treebank=traintrees,
                fasttext=(
                    pathlib.Path(args.fasttext) if args.fasttext is not None else None
                ),
            )

        trainset = DependencyDataset(
            traintrees,
            parser.lexer,
            parser.char_rnn,
            parser.ft_lexer,
            use_labels=parser.labels,
            use_tags=parser.tagset,
        )
        devset = DependencyDataset(
            devtrees,
            parser.lexer,
            parser.char_rnn,
            parser.ft_lexer,
            use_labels=parser.labels,
            use_tags=parser.tagset,
        )

        # Hyperparameters come from the YAML config; the learning-rate
        # schedule is optional there and has a default.
        parser.train_model(
            train_set=trainset,
            dev_set=devset,
            epochs=hp["epochs"],
            batch_size=hp["batch_size"],
            lr=hp["lr"],
            lr_schedule=hp.get(
                "lr_schedule", {"shape": "exponential", "warmup_steps": 0}
            ),
            model_path=model_dir,
        )
        print("training done.", file=sys.stderr)
        # Load final params
        parser.load_params(model_dir / "model.pt")
        parser.eval()
        if args.out_dir is not None:
            parsed_devset_path = os.path.join(
                args.out_dir, f"{os.path.basename(args.dev_file)}.parsed"
            )
        else:
            parsed_devset_path = os.path.join(
                os.path.dirname(args.dev_file),
                f"{os.path.basename(args.dev_file)}.parsed",
            )
        with open(parsed_devset_path, "w") as ostream:
            parser.predict_batch(devset, ostream, greedy=False)
        # Score the parsed dev file against the original dev file.
        gold_devset = evaluator.load_conllu_file(args.dev_file)
        syst_devset = evaluator.load_conllu_file(parsed_devset_path)
        dev_metrics = evaluator.evaluate(gold_devset, syst_devset)
        print(
            f"Dev-best results: {dev_metrics['UPOS'].f1} UPOS\t{dev_metrics['UAS'].f1} UAS\t{dev_metrics['LAS'].f1} LAS",
            file=sys.stderr,
        )

    if args.pred_file:
        # TEST MODE
        if args.out_dir is not None:
            parsed_testset_path = os.path.join(
                args.out_dir, f"{os.path.basename(args.pred_file)}.parsed"
            )
        else:
            parsed_testset_path = os.path.join(
                os.path.dirname(args.pred_file),
                f"{os.path.basename(args.pred_file)}.parsed",
            )
        parse(trained_config_file, args.pred_file, parsed_testset_path, overrides=overrides)
        print("Parsing done.", file=sys.stderr)