# Example #1
# 0
import numpy as np

import index

from pympler import muppy, summary
import gc

def print_mem():
    """Print a pympler summary of live Python objects, then the process RSS.

    The object summary comes from ``muppy.get_objects()`` and
    ``summary.summarize``. The final printed number is the resident set
    size reported by ``psutil``, divided by 1e9 (i.e. decimal gigabytes).
    """
    # psutil was referenced here but never imported anywhere in the file,
    # so every call raised NameError; import it locally to keep the helper
    # self-contained.
    import psutil

    all_objects = muppy.get_objects()
    sum1 = summary.summarize(all_objects)
    summary.print_(sum1)
    print(psutil.Process().memory_full_info().rss / 1e9)

# Memory-profiling script: load a saved Faiss span index, print object/RSS
# memory usage, move the index to a GPU, print memory again, then reset it.

# Resolve the on-disk prefix of the saved index for this model / nn_prefix
# combination (the path layout is decided by index.get_index_prefix).
prefix = index.get_index_prefix(
    index_base_path = "../index",
    full_model_path = "models/en_multibert_empty_l2_dev=94.94.pt",
    nn_prefix = "nl2-multi",
)
# NOTE(review): num_labels is hard-coded to 112; presumably the label
# vocabulary size of the model above — confirm before reusing.
span_index = index.FaissIndex(num_labels = 112, metric="l2")
span_index.load(prefix)

# `shape` is not used later in this script; reading it only touches
# span_index.keys after loading.
shape = span_index.keys.shape

print("initialized index")
print_mem()
# Device id 1 is hard-coded (the script labels this step "to gpu").
span_index.to(1)
print("to gpu")
print_mem()

# Clear the index; what reset() frees is defined in index.py.
span_index.reset()
print("reset index")
# Example #2
# 0
def run_index(args):
    """Build and save a nearest-neighbour index of labelled span representations.

    Loads the training trees and a saved ``.pt`` parser, computes span
    representations for every training tree with
    ``index.get_span_reps_infos``, inserts them into a Faiss or Annoy index
    (selected by ``args.library``), builds it and saves it under the prefix
    returned by ``index.get_index_prefix``.

    Args:
        args: parsed command-line namespace. Fields read here:
            ``train_path``, ``model_path_base`` (must end in ``.pt``),
            ``no_mlp``, ``no_relu``, ``redo_vocab``, ``library``
            (``"faiss"`` selects Faiss, anything else Annoy),
            ``batch_size``, ``index_path`` and ``nn_prefix``.

    Raises:
        AssertionError: if ``model_path_base`` is not a ``.pt`` file or the
            checkpoint spec has no ``hparams`` entry.
    """
    print("Saving span representations")
    print()

    print("Loading train trees from {}...".format(args.train_path))
    train_treebank = trees.load_trees(args.train_path)
    print("Loaded {:,} train examples.".format(len(train_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"

    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_jc.NKChartParser.from_spec(info['spec'], info['state_dict'])
    parser.no_mlp = args.no_mlp
    parser.no_relu = args.no_relu
    if args.no_relu:
        parser.remove_relu()

    print("Getting labelled span representations")
    start_time = time.time()

    if args.redo_vocab:
        # Rebuild the label vocabulary from the training trees instead of
        # using the one stored in the checkpoint.
        parser.label_vocab = gen_label_vocab(
            [tree.convert() for tree in train_treebank])

    num_labels = len(parser.label_vocab.values)
    use_faiss = args.library == "faiss"
    # Both index classes are constructed with the same keywords elsewhere in
    # this codebase (``index_const(num_labels=..., metric=...)``); the old
    # ``num_indices`` keyword belonged to the retired SpanIndex API.
    index_cls = index.FaissIndex if use_faiss else index.AnnoyIndex
    span_index = index_cls(num_labels=num_labels, metric=parser.metric)

    rep_time = time.time()
    span_reps, span_infos = index.get_span_reps_infos(parser, train_treebank,
                                                      args.batch_size)
    print(f"rep-time: {format_elapsed(rep_time)}")

    # TODO: clean up later, refactor back into index.py
    build_time = time.time()
    # GPU building is disabled; set to True to build the Faiss index on device 0.
    use_gpu = False
    print(f"Using gpu: {use_gpu}")
    if use_faiss:
        if use_gpu:
            span_index.to(0)
        span_index.add(span_reps, span_infos)
    else:
        # The Annoy wrapper is filled one item at a time.
        for span_rep, span_info in zip(span_reps, span_infos):
            span_index.add_item(span_rep, span_info)
    span_index.build()
    print(f"build-time {format_elapsed(build_time)}")
    if use_gpu:
        # Presumably moves the index back off the GPU before saving
        # (run_train does the same ``to(-1)`` before ``save``).
        span_index.to(-1)

    save_time = time.time()
    prefix = index.get_index_prefix(
        index_base_path=args.index_path,
        full_model_path=args.model_path_base,
        nn_prefix=args.nn_prefix,
    )
    print(f"Saving index to {prefix}")
    span_index.save(prefix)
    print(f"save-time {format_elapsed(save_time)}")

    print(f"index-elapsed {format_elapsed(start_time)}")
# Example #3
# 0
def run_train(args, hparams):
    """Train an NKChartParser, optionally with a nearest-neighbour span index.

    Loads and length-filters the train/dev treebanks and builds tag, word,
    label and character vocabularies from the training trees. It then
    creates the parser, or loads it from ``args.model_path_base`` when that
    ends in ``.pt``, and trains it with Adam. The learning rate uses linear
    warmup followed by ``ReduceLROnPlateau`` step decay on the best dev
    F-score.

    The best dev model is saved to ``"{model_path_base}_dev={fscore:.2f}.pt"``
    and the previous best file is removed. When ``args.use_neighbours`` is
    set, a Faiss/Annoy span index is loaded, periodically rebuilt from the
    current parser, and saved under ``args.save_nn_prefix``.

    Args:
        args: parsed command-line namespace (data/model/index paths, batch
            sizes, seeds, epoch and checking frequencies, neighbour options).
        hparams: hyperparameter object; updated in place from ``args`` via
            ``hparams.set_from_args``.
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path))
    if hparams.predict_tags and args.train_path.endswith('10way.clean'):
        print("WARNING: The data distributed with this repository contains "
              "predicted part-of-speech tags only (not gold tags!) We do not "
              "recommend enabling predict_tags in this configuration.")
    train_treebank = trees.load_trees(args.train_path)
    if hparams.max_len_train > 0:
        train_treebank = [
            tree for tree in train_treebank
            if len(list(tree.leaves())) <= hparams.max_len_train
        ]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [
            tree for tree in dev_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...")

    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())

    char_set = set()

    for idx, tree in enumerate(train_parse):
        # Tag each converted tree with its position in train_parse before
        # the per-epoch in-place shuffles below.
        tree.idx = idx
        # Iterative pre-order walk over the parse tree.
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)
    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        """Print a vocabulary's size and values, special tokens first."""
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...")
    load_path = args.model_path_base if args.model_path_base.endswith(
        ".pt") else None
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = parse_jc.NKChartParser.from_spec(info['spec'],
                                                  info['state_dict'])
    else:
        parser = parse_jc.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )
    parser.no_relu = args.no_relu
    if args.no_relu:
        parser.remove_relu()
        print("Removing ReLU from chart MLP")
    if args.override_use_label_weights:
        # override loaded model
        parser.use_label_weights = args.override_use_label_weights
        print(
            f"Overriding use_label_weights: {args.override_use_label_weights}")

    span_index, K = None, None
    if args.use_neighbours:
        index_const = (index.FaissIndex
                       if args.library == "faiss" else index.AnnoyIndex)
        # assert index loaded has the same metric
        span_index = index_const(
            num_labels=len(parser.label_vocab.values),
            metric=parser.metric,
        )
        prefix = index.get_index_prefix(
            index_base_path=args.index_path,
            full_model_path=args.model_path_base,
            nn_prefix=args.nn_prefix,
        )
        span_index.load(prefix)
        K = args.k
        assert K > 0
        if parse_jc.use_cuda:
            # hack!
            # use CUDA_VISIBLE_DEVICES={0},{1}
            print(f"Using gpu {args.index_devid} for index")
            span_index.to(args.index_devid)

    if args.label_weights_only:
        # freeze everything except "label_weights"
        for name, param in parser.named_parameters():
            if name != "label_weights":
                param.requires_grad = False
    else:
        parser.label_weights.requires_grad = False

    print("Initializing optimizer...")
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    trainer = torch.optim.Adam(trainable_parameters,
                               lr=1.,
                               betas=(0.9, 0.98),
                               eps=1e-9)
    if load_path is not None:
        try:
            trainer.load_state_dict(info['trainer'])
        except Exception as err:
            # Best effort: the checkpoint may lack a 'trainer' entry, or its
            # parameter groups may not match the currently trainable
            # parameters (e.g. after args.label_weights_only froze most of
            # them). Catching Exception instead of a bare except lets
            # KeyboardInterrupt/SystemExit propagate and reports the cause.
            print("Couldn't load optim state: {!r}".format(err))

    def set_lr(new_lr):
        """Set the learning rate of every optimizer parameter group."""
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer,
        'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        """Linearly ramp the LR up during the warmup steps; no-op afterwards."""
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    print("Training...")
    total_processed = 0
    current_processed = 0
    current_index_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    reindex_every = len(train_parse) / args.reindexes_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        """Parse the dev set, report evalb F-score, and save a new best model."""
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank),
                                     args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index +
                                          args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word)
                                   for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(
                subbatch_sentences,
                span_index=span_index,
                k=K,
                zero_empty=parser.zero_empty,
                train_nn=args.train_through_nn,
            )
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                    dev_predicted)

        print("dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ))

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print(
                            "Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_dev_model_path))
            torch.save(
                {
                    'spec': parser.spec,
                    'state_dict': parser.state_dict(),
                    'trainer': trainer.state_dict(),
                }, best_dev_model_path + ".pt")

    def rebuild_span_index():
        """Recompute span_index from the current parser and save it.

        Only called when span_index is not None. The saved prefix uses
        args.save_nn_prefix, so the index loaded at startup is not
        overwritten unless both prefixes are equal.
        """
        reindex_time = time.time()
        span_index.reset()
        # Keep device placement consistent with the initial load above,
        # which only moves the index to a GPU when parse_jc.use_cuda is set;
        # previously the reindex moved it unconditionally.
        if parse_jc.use_cuda:
            span_index.to(args.index_devid)
        span_reps, span_infos = index.get_span_reps_infos(
            parser,
            train_treebank,
            128,  # fixed batch size for computing span representations
        )
        span_index.add(span_reps, span_infos)
        span_index.build()

        print(f"reindex-elapsed: {format_elapsed(reindex_time)}")
        save_time = time.time()
        save_prefix = index.get_index_prefix(
            index_base_path=args.index_path,
            full_model_path=args.model_path_base,
            nn_prefix=args.save_nn_prefix,
        )
        if parse_jc.use_cuda:
            # Move off the GPU for saving, then back for further training.
            span_index.to(-1)
        print("Saving recomputed index")
        span_index.save(save_prefix)
        if parse_jc.use_cuda:
            span_index.to(args.index_devid)
        print(f"save-elapsed: {format_elapsed(save_time)}")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index +
                                      args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word)
                                for leaf in tree.leaves()]
                               for tree in batch_trees]
            batch_num_tokens = sum(
                len(sentence) for sentence in batch_sentences)

            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                # NOTE(review): check_dev passes train_nn=args.train_through_nn
                # but this training call does not — confirm that is intended.
                _, loss = parser.parse_batch(
                    subbatch_sentences,
                    subbatch_trees,
                    span_index=span_index,
                    k=K,
                    zero_empty=parser.zero_empty,
                )

                # Normalise per tree (and per token for the tag loss) so the
                # gradient scale does not depend on how the batch was split.
                if hparams.predict_tags:
                    loss = loss[0] / len(
                        batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = loss.item()
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)
                current_index_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)

            trainer.step()

            print("epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "batch-loss {:.4f} "
                  "grad-norm {:.4f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      start_index // args.batch_size + 1,
                      int(np.ceil(len(train_parse) / args.batch_size)),
                      total_processed,
                      batch_loss_value,
                      grad_norm,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()
            if current_index_processed >= reindex_every:
                current_index_processed -= reindex_every
                if span_index is not None:
                    rebuild_span_index()

        # adjust learning rate at the end of an epoch
        if (total_processed // args.batch_size +
                1) > hparams.learning_rate_warmup_steps:
            scheduler.step(best_dev_fscore)
            if (total_processed - best_dev_processed) > (
                (hparams.step_decay_patience + 1) *
                    hparams.max_consecutive_decays * len(train_parse)):
                print("Terminating due to lack of improvement in dev fscore.")
                break
# Example #4
# 0
def run_test(args):
    """Parse the test treebank with a saved parser and print its evalb score.

    When ``args.use_neighbours`` is set, a Faiss or Annoy span index is
    loaded (selected by ``args.library``) and passed to the parser. The
    ``no_relu`` option is applied only in that case.

    Args:
        args: parsed command-line namespace. Fields read here: ``test_path``,
            ``model_path_base`` (must end in ``.pt``), ``redo_vocab``,
            ``train_path``, ``use_neighbours``, ``library``, ``index_path``,
            ``nn_prefix``, ``no_relu``, ``eval_batch_size``, ``k``,
            ``zero_empty``, ``test_path_raw`` and ``evalb_dir``.

    Raises:
        AssertionError: if ``model_path_base`` is not a ``.pt`` file or the
            checkpoint spec has no ``hparams`` entry.
    """
    test_path = args.test_path
    print("Loading test trees from {}...".format(test_path))
    test_treebank = trees.load_trees(test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    model_path = args.model_path_base
    print("Loading model from {}...".format(model_path))
    assert model_path.endswith(".pt"), "Only pytorch savefiles supported"

    checkpoint = torch_load(model_path)
    model_spec = checkpoint['spec']
    assert 'hparams' in model_spec, "Older savefiles not supported"
    parser = parse_jc.NKChartParser.from_spec(model_spec,
                                              checkpoint['state_dict'])

    if args.redo_vocab:
        memory_bank_path = args.train_path
        print(
            "Loading memory bank trees from {} for generating label vocab...".
            format(memory_bank_path))
        memory_bank = trees.load_trees(memory_bank_path)
        parser.label_vocab = gen_label_vocab(
            [tree.convert() for tree in memory_bank])

    print("Parsing test sentences...")
    start_time = time.time()

    # Stays None unless nearest-neighbour lookup is requested.
    span_index = None
    if args.use_neighbours:
        if args.library == "faiss":
            index_cls = index.FaissIndex
        else:
            index_cls = index.AnnoyIndex
        span_index = index_cls(
            num_labels=len(parser.label_vocab.values),
            metric=parser.metric,
        )
        span_index.load(
            index.get_index_prefix(
                index_base_path=args.index_path,
                full_model_path=model_path,
                nn_prefix=args.nn_prefix,
            ))

        # The ReLU removal is tied to neighbour-based parsing here.
        parser.no_relu = args.no_relu
        if args.no_relu:
            parser.remove_relu()

    batch_size = args.eval_batch_size
    test_predicted = []
    for offset in range(0, len(test_treebank), batch_size):
        batch_trees = test_treebank[offset:offset + batch_size]
        batch_sentences = [
            [(leaf.tag, leaf.word) for leaf in tree.leaves()]
            for tree in batch_trees
        ]
        predicted, _loss = parser.parse_batch(
            batch_sentences,
            span_index=span_index,
            k=args.k,
            zero_empty=args.zero_empty,
        )
        del _loss
        test_predicted.extend(p.convert() for p in predicted)

    # The tree loader preprocesses trees (e.g. strips TOP symbols or SPMRL
    # morphological features), so evalb is pointed at the input file itself
    # to avoid corrupting the evaluation. A separate "raw" gold file may be
    # given instead: parser inputs have traces removed and may carry
    # predicted tags, and comparing against the raw gold trees guards
    # against mistakes there. All of these variants were observed to give
    # equivalent results.
    ref_gold_path = args.test_path_raw
    if ref_gold_path is None:
        ref_gold_path = test_path
    else:
        print("Comparing with raw trees from", ref_gold_path)

    test_fscore = evaluate.evalb(
        args.evalb_dir,
        test_treebank,
        test_predicted,
        ref_gold_path=ref_gold_path,
    )

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
# Example #5
# 0
import time

import index
import torch
import scatter

import numpy as np

# Benchmark: time loading the same saved span index with the Annoy and Faiss
# wrappers.
index_path = "index"
model_base_path = "models/en_bert_empty_nl2_dev=95.36.pt"
# Both wrappers load from the same prefix. Resolve it once, outside the
# timed regions; previously it was recomputed inside the Faiss timing window.
prefix = index.get_index_prefix(
    index_base_path=index_path,
    full_model_path=model_base_path,
    nn_prefix="all_spans_empty_nl2",
)

annoy_index = index.AnnoyIndex(metric="l2")
faiss_index = index.FaissIndex(metric="l2")
# NOTE(review): never loaded or timed below; presumably meant for a GPU run.
faiss_index_gpu = index.FaissIndex(metric="l2")

# perf_counter is monotonic and high-resolution, unlike time.time.
t = time.perf_counter()
annoy_index.load(prefix)
print(f"loaded annoy {time.perf_counter() - t}")

t = time.perf_counter()
faiss_index.load(prefix)
print(f"loaded faiss {time.perf_counter() - t}")