Example #1
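# Note: these examples are excerpts from one training/evaluation script.
# They assume module-level imports roughly like the following (standard
# library and third-party only; the script additionally uses its own
# project modules -- treebanks, parse_chart, decode_chart, char_lstm,
# learning_rates, evaluate -- and helpers such as format_elapsed, Parser,
# and inputs_from_treebank, whose import paths are not shown here):
import functools
import itertools
import os
import sys
import time
from pprint import pprint

import numpy as np
import torch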
def run_test(args):
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = treebanks.load_trees(args.test_path, args.test_path_text,
                                         args.text_processing)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    if len(args.model_path) != 1:
        raise NotImplementedError("Ensembling multiple parsers is not "
                                  "implemented in this version of the code.")

    model_path = args.model_path[0]
    print("Loading model from {}...".format(model_path))
    parser = parse_chart.ChartParser.from_trained(model_path)
    if args.no_predict_tags and parser.f_tag is not None:
        print("Removing part-of-speech tagging head...")
        parser.f_tag = None
    if args.parallelize:
        parser.parallelize()
    elif torch.cuda.is_available():
        parser.cuda()

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = parser.parse(
        test_treebank.without_gold_annotations(),
        subbatch_max_tokens=args.subbatch_max_tokens,
    )

    if args.output_path == "-":
        for tree in test_predicted:
            print(tree.pformat(margin=1e100))
    elif args.output_path:
        with open(args.output_path, "w") as outfile:
            for tree in test_predicted:
                outfile.write("{}\n".format(tree.pformat(margin=1e100)))

    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). We compare with the input file
    # directly to be extra careful about not corrupting the evaluation. We also
    # allow specifying a separate "raw" file for the gold trees: the inputs to
    # our parser have traces removed and may have predicted tags substituted,
    # and we may wish to compare against the raw gold trees to make sure we
    # haven't made a mistake. As far as we can tell, all of these variations
    # give equivalent results.
    ref_gold_path = args.test_path
    if args.test_path_raw is not None:
        print("Comparing with raw trees from", args.test_path_raw)
        ref_gold_path = args.test_path_raw

    test_fscore = evaluate.evalb(args.evalb_dir,
                                 test_treebank.trees,
                                 test_predicted,
                                 ref_gold_path=ref_gold_path)

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
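
For context, here is a minimal sketch of how this entry point could be wired to a command line. The flag names and defaults are assumptions inferred from the attributes run_test() reads, not the project's actual argument definitions:

import argparse

def make_test_arg_parser():
    # Hypothetical CLI wiring; every flag below is inferred from an
    # attribute accessed in run_test() and may differ from the real script.
    ap = argparse.ArgumentParser()
    ap.add_argument("--test-path", required=True)
    ap.add_argument("--test-path-raw", default=None)
    ap.add_argument("--test-path-text", default=None)
    ap.add_argument("--text-processing", default="default")
    # run_test() expects a list here and rejects ensembles of size > 1.
    ap.add_argument("--model-path", nargs="+", required=True)
    ap.add_argument("--no-predict-tags", action="store_true")
    ap.add_argument("--parallelize", action="store_true")
    ap.add_argument("--output-path", default="")
    ap.add_argument("--subbatch-max-tokens", type=int, default=500)
    ap.add_argument("--evalb-dir", default="EVALB/")
    return ap

# Usage: run_test(make_test_arg_parser().parse_args())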
Example #2
def run_test(args):
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = treebanks.load_trees(
        args.test_path, args.test_path_text, args.text_processing
    )
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path))
    parser = Parser(args.model_path, batch_size=args.batch_size)

    print("Parsing test sentences...")
    start_time = time.time()

    if args.output_path == "-":
        output_file = sys.stdout
    elif args.output_path:
        output_file = open(args.output_path, "w")
    else:
        output_file = None

    test_predicted = []
    for predicted_tree in parser.parse_sents(
        inputs_from_treebank(test_treebank, predict_tags=args.predict_tags)
    ):
        test_predicted.append(predicted_tree)
        if output_file is not None:
            print(predicted_tree.pformat(margin=1e100), file=output_file)

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank.trees, test_predicted)

    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )
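
This variant drives the high-level Parser API instead of loading a ChartParser directly. A minimal usage sketch, assuming parse_sents() accepts an iterable of pre-tokenized sentences and yields one predicted tree per input, as the loop above suggests; the model path and tokens are placeholders:

# Minimal sketch, assuming Parser.parse_sents() accepts pre-tokenized
# sentences and yields nltk-style trees. Placeholders only.
parser = Parser("path/to/model", batch_size=32)
for tree in parser.parse_sents([["The", "parser", "works", "."]]):
    print(tree.pformat(margin=1e100))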
Example #3
def run_train(args, hparams):
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)
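    # Sketch (an addition, not in the original script): for fully
    # deterministic runs one would typically also seed Python's own RNG
    # and the CUDA generators, e.g.:
    #   import random
    #   random.seed(seed_from_numpy)
    #   torch.cuda.manual_seed_all(seed_from_numpy)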

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()
    print()
    pprint(vars(args))
    print()

    print("Loading training trees from {}...".format(args.train_path))
    train_treebank = treebanks.load_trees(args.train_path,
                                          args.train_path_text,
                                          args.text_processing)
    print("Loaded {:,} training examples.".format(len(train_treebank)))
    if hparams.max_len_train > 0:
        train_treebank = train_treebank.filter_by_length(hparams.max_len_train)
        print("len after filtering {:,}".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = treebanks.load_trees(args.dev_path, args.dev_path_text,
                                        args.text_processing)
    print("Loaded {:,} development examples.".format(len(dev_treebank)))
    if hparams.max_len_dev > 0:
        dev_treebank = dev_treebank.filter_by_length(hparams.max_len_dev)
        print("len after filtering {:,}".format(len(dev_treebank)))

    print("Constructing vocabularies...")
    label_vocab = decode_chart.ChartDecoder.build_vocab(train_treebank.trees)
    if hparams.use_chars_lstm:
        char_vocab = char_lstm.RetokenizerForCharLSTM.build_vocab(
            train_treebank.sents)
    else:
        char_vocab = None

    tag_vocab = set()
    for tree in train_treebank.trees:
        for _, tag in tree.pos():
            tag_vocab.add(tag)
    tag_vocab = ["UNK"] + sorted(tag_vocab)
    tag_vocab = {label: i for i, label in enumerate(tag_vocab)}

    if hparams.force_root_constituent.lower() in ("true", "yes", "1"):
        hparams.force_root_constituent = True
    elif hparams.force_root_constituent.lower() in ("false", "no", "0"):
        hparams.force_root_constituent = False
    elif hparams.force_root_constituent.lower() == "auto":
        hparams.force_root_constituent = (
            decode_chart.ChartDecoder.infer_force_root_constituent(
                train_treebank.trees))
        print("Set hparams.force_root_constituent to",
              hparams.force_root_constituent)
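    # Note that any other string value falls through this chain unchanged;
    # presumably it is rejected further downstream when the hyperparameters
    # are consumed.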

    print("Initializing model...")
    parser = parse_chart.ChartParser(
        tag_vocab=tag_vocab,
        label_vocab=label_vocab,
        char_vocab=char_vocab,
        hparams=hparams,
    )
    if args.parallelize:
        parser.parallelize()
    elif torch.cuda.is_available():
        parser.cuda()
    else:
        print("Not using CUDA!")

    print("Initializing optimizer...")
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    optimizer = torch.optim.Adam(trainable_parameters,
                                 lr=hparams.learning_rate,
                                 betas=(0.9, 0.98),
                                 eps=1e-9)

    scheduler = learning_rates.WarmupThenReduceLROnPlateau(
        optimizer,
        hparams.learning_rate_warmup_steps,
        mode="max",
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience * hparams.checks_per_epoch,
        verbose=True,
    )

    clippable_parameters = trainable_parameters
    grad_clip_threshold = (np.inf if hparams.clip_grad_norm == 0 else
                           hparams.clip_grad_norm)

    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_treebank) / hparams.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = parser.parse(
            dev_treebank.without_gold_annotations(),
            subbatch_max_tokens=args.subbatch_max_tokens,
        )
        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank.trees,
                                    dev_predicted)

        print("dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ))

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print(
                            "Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_dev_model_path))
            torch.save(
                {
                    "config": parser.config,
                    "state_dict": parser.state_dict(),
                    "optimizer": optimizer.state_dict(),
                },
                best_dev_model_path + ".pt",
            )

    data_loader = torch.utils.data.DataLoader(
        train_treebank,
        batch_size=hparams.batch_size,
        shuffle=True,
        collate_fn=functools.partial(
            parser.encode_and_collate_subbatches,
            subbatch_max_tokens=args.subbatch_max_tokens,
        ),
    )
    for epoch in itertools.count(start=1):
        epoch_start_time = time.time()

        for batch_num, batch in enumerate(data_loader, start=1):
            optimizer.zero_grad()
            parser.train()

            batch_loss_value = 0.0
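            # Gradients from each subbatch accumulate in .grad; clipping and
            # the optimizer step below then happen once per full batch.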
            for subbatch_size, subbatch in batch:
                loss = parser.compute_loss(subbatch)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += subbatch_size
                current_processed += subbatch_size

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)

            optimizer.step()

            print("epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "batch-loss {:.4f} "
                  "grad-norm {:.4f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      batch_num,
                      int(np.ceil(len(train_treebank) / hparams.batch_size)),
                      total_processed,
                      batch_loss_value,
                      grad_norm,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()
                scheduler.step(metrics=best_dev_fscore)
            else:
                scheduler.step()

        if (total_processed - best_dev_processed) > (
            (hparams.step_decay_patience + 1) *
                hparams.max_consecutive_decays * len(train_treebank)):
            print("Terminating due to lack of improvement in dev fscore.")
            break
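
The checkpoint written by check_dev() bundles the model config, model weights, and optimizer state. A minimal reload sketch, assuming ChartParser.from_trained() (used in Example #1) accepts checkpoints in this format; the file name is illustrative:

# Illustrative file name; real checkpoints are saved as
# "<model_path_base>_dev=<fscore>.pt" by check_dev() above.
ckpt_path = "parser_dev=95.50.pt"
checkpoint = torch.load(ckpt_path, map_location="cpu")
print(sorted(checkpoint.keys()))  # ['config', 'optimizer', 'state_dict']

# For inference-only use, the optimizer state can be dropped to shrink
# the file (run_export() below gives the same advice for CharLSTM models):
del checkpoint["optimizer"]
torch.save(checkpoint, ckpt_path)

parser = parse_chart.ChartParser.from_trained(ckpt_path)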
Example #4
def run_export(args):
    if args.test_path is not None:
        print("Loading test trees from {}...".format(args.test_path))
        test_treebank = treebanks.load_trees(
            args.test_path, args.test_path_text, args.text_processing
        )
        print("Loaded {:,} test examples.".format(len(test_treebank)))
    else:
        test_treebank = None

    print("Loading model from {}...".format(args.model_path))
    parser = Parser(args.model_path, batch_size=args.batch_size)
    model = parser._parser
    if model.pretrained_model is None:
        raise ValueError(
            "Exporting is only defined when using a pre-trained transformer "
            "encoder. For CharLSTM-based model, just distribute the pytorch "
            "checkpoint directly. You may manually delete the 'optimizer' "
            "field to reduce file size by discarding the optimizer state."
        )

    if test_treebank is not None:
        print("Parsing test sentences (predicting tags)...")
        start_time = time.time()
        test_inputs = inputs_from_treebank(test_treebank, predict_tags=True)
        test_predicted = list(parser.parse_sents(test_inputs))
        test_fscore = evaluate.evalb(args.evalb_dir, test_treebank.trees, test_predicted)
        test_elapsed = format_elapsed(start_time)
        print("test-fscore {} test-elapsed {}".format(test_fscore, test_elapsed))

        print("Parsing test sentences (not predicting tags)...")
        start_time = time.time()
        test_inputs = inputs_from_treebank(test_treebank, predict_tags=False)
        notags_test_predicted = list(parser.parse_sents(test_inputs))
        notags_test_fscore = evaluate.evalb(
            args.evalb_dir, test_treebank.trees, notags_test_predicted
        )
        notags_test_elapsed = format_elapsed(start_time)
        print(
            "test-fscore {} test-elapsed {}".format(notags_test_fscore, notags_test_elapsed)
        )

    print("Exporting tokenizer...")
    model.retokenizer.tokenizer.save_pretrained(args.output_dir)

    print("Exporting config...")
    config = model.pretrained_model.config
    config.benepar = model.config
    config.save_pretrained(args.output_dir)

    if args.compress:
        print("Compressing weights...")
        state_dict = get_compressed_state_dict(model.cpu())
        print("Saving weights...")
    else:
        print("Exporting weights...")
        state_dict = model.cpu().state_dict()
    torch.save(state_dict, os.path.join(args.output_dir, "benepar_model.bin"))

    del model, parser, state_dict

    print("Loading exported model from {}...".format(args.output_dir))
    exported_parser = Parser(args.output_dir, batch_size=args.batch_size)

    if test_treebank is None:
        print()
        print("Export complete.")
        print("Did not verify model accuracy because no treebank was provided.")
        return

    print("Parsing test sentences (predicting tags)...")
    start_time = time.time()
    test_inputs = inputs_from_treebank(test_treebank, predict_tags=True)
    exported_predicted = list(exported_parser.parse_sents(test_inputs))
    exported_fscore = evaluate.evalb(
        args.evalb_dir, test_treebank.trees, exported_predicted
    )
    exported_elapsed = format_elapsed(start_time)
    print(
        "exported-fscore {} exported-elapsed {}".format(
            exported_fscore, exported_elapsed
        )
    )

    print("Parsing test sentences (not predicting tags)...")
    start_time = time.time()
    test_inputs = inputs_from_treebank(test_treebank, predict_tags=False)
    notags_exported_predicted = list(exported_parser.parse_sents(test_inputs))
    notags_exported_fscore = evaluate.evalb(
        args.evalb_dir, test_treebank.trees, notags_exported_predicted
    )
    notags_exported_elapsed = format_elapsed(start_time)
    print(
        "exported-fscore {} exported-elapsed {}".format(
            notags_exported_fscore, notags_exported_elapsed
        )
    )

    print()
    print("Export and verification complete.")
    fscore_delta = evaluate.FScore(
        recall=notags_exported_fscore.recall - notags_test_fscore.recall,
        precision=notags_exported_fscore.precision - notags_test_fscore.precision,
        fscore=notags_exported_fscore.fscore - notags_test_fscore.fscore,
        complete_match=(
            notags_exported_fscore.complete_match - notags_test_fscore.complete_match
        ),
        tagging_accuracy=(
            exported_fscore.tagging_accuracy - test_fscore.tagging_accuracy
        ),
    )
    print("delta-fscore {}".format(fscore_delta))