import numpy as np
import index
import psutil  # used by print_mem below; was missing from the imports
from pympler import muppy, summary
import gc


def print_mem():
    # Summarize live Python objects, then report process RSS in GB.
    all_objects = muppy.get_objects()
    sum1 = summary.summarize(all_objects)
    summary.print_(sum1)
    print(psutil.Process().memory_full_info().rss / 1e9)


prefix = index.get_index_prefix(
    index_base_path="../index",
    full_model_path="models/en_multibert_empty_l2_dev=94.94.pt",
    nn_prefix="nl2-multi",
)
span_index = index.FaissIndex(num_labels=112, metric="l2")
span_index.load(prefix)
shape = span_index.keys.shape
print("initialized index")
print_mem()
span_index.to(1)
print("to gpu")
print_mem()
span_index.reset()
print("reset index")
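# If reset() actually releases the index buffers, a final measurement should
# drop back toward the pre-load numbers; print memory once more to verify.
print_mem()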
def run_index(args):
    print("Saving span representations")
    print()
    print("Loading train trees from {}...".format(args.train_path))
    train_treebank = trees.load_trees(args.train_path)
    print("Loaded {:,} train examples.".format(len(train_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_jc.NKChartParser.from_spec(info['spec'],
                                              info['state_dict'])
    parser.no_mlp = args.no_mlp
    parser.no_relu = args.no_relu
    if args.no_relu:
        parser.remove_relu()

    print("Getting labelled span representations")
    start_time = time.time()
    if args.redo_vocab:
        parser.label_vocab = gen_label_vocab(
            [tree.convert() for tree in train_treebank])
    num_labels = len(parser.label_vocab.values)
    """
    span_index = index.SpanIndex(
        num_indices = num_labels,
        library = args.library,
    )
    """
    span_index = (index.FaissIndex(num_labels=num_labels,
                                   metric=parser.metric)
                  if args.library == "faiss" else
                  index.AnnoyIndex(num_indices=num_labels,
                                   metric=parser.metric))

    rep_time = time.time()
    span_reps, span_infos = index.get_span_reps_infos(parser, train_treebank,
                                                      args.batch_size)
    print(f"rep-time: {format_elapsed(rep_time)}")

    # clean up later, refactor back into index.py
    build_time = time.time()
    #use_gpu = True
    use_gpu = False
    print(f"Using gpu: {use_gpu}")
    if args.library == "faiss":
        if use_gpu:
            span_index.to(0)
        span_index.add(span_reps, span_infos)
        span_index.build()
    else:
        for rep, info in zip(span_reps, span_infos):
            span_index.add_item(rep, info)
        span_index.build()
    #span_index.build()
    print(f"build-time {format_elapsed(build_time)}")
    if use_gpu:
        span_index.to(-1)

    save_time = time.time()
    prefix = index.get_index_prefix(
        index_base_path=args.index_path,
        full_model_path=args.model_path_base,
        nn_prefix=args.nn_prefix,
    )
    print(f"Saving index to {prefix}")
    span_index.save(prefix)
    print(f"save-time {format_elapsed(save_time)}")
    print(f"index-elapsed {format_elapsed(start_time)}")
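

# A minimal sketch of the CLI wiring run_index expects. The flag names and
# defaults below are assumptions inferred from the attributes accessed above;
# the real subcommand setup may differ.
def make_index_subparser(subparsers):
    subparser = subparsers.add_parser("index")
    subparser.set_defaults(callback=run_index)
    subparser.add_argument("--model-path-base", required=True)
    subparser.add_argument("--train-path", default="data/02-21.10way.clean")
    subparser.add_argument("--index-path", default="index")
    subparser.add_argument("--nn-prefix", default="nn")
    subparser.add_argument("--library", default="faiss",
                           choices=["faiss", "annoy"])
    subparser.add_argument("--batch-size", type=int, default=32)
    subparser.add_argument("--no-mlp", action="store_true")
    subparser.add_argument("--no-relu", action="store_true")
    subparser.add_argument("--redo-vocab", action="store_true")
    return subparser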
def run_train(args, hparams):
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly. On my
    # cluster I was getting highly correlated results from multiple runs, but
    # calling reset_parameters() changed that. A brief look at the pytorch
    # source code revealed that pytorch initializes its RNG by calling
    # std::random_device, which according to the C++ spec is allowed to be
    # deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    print("Loading training trees from {}...".format(args.train_path))
    if hparams.predict_tags and args.train_path.endswith('10way.clean'):
        print("WARNING: The data distributed with this repository contains "
              "predicted part-of-speech tags only (not gold tags!) We do not "
              "recommend enabling predict_tags in this configuration.")
    train_treebank = trees.load_trees(args.train_path)
    if hparams.max_len_train > 0:
        train_treebank = [
            tree for tree in train_treebank
            if len(list(tree.leaves())) <= hparams.max_len_train
        ]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [
            tree for tree in dev_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...")
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())

    char_set = set()

    for idx, tree in enumerate(train_parse):
        tree.idx = idx  # augment each node with index?
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint
    # directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...")
    load_path = args.model_path_base if args.model_path_base.endswith(
        ".pt") else None
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = parse_jc.NKChartParser.from_spec(info['spec'],
                                                  info['state_dict'])
    else:
        parser = parse_jc.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )
    parser.no_relu = args.no_relu
    if args.no_relu:
        parser.remove_relu()
        print("Removing ReLU from chart MLP")
    if args.override_use_label_weights:
        # override loaded model
        parser.use_label_weights = args.override_use_label_weights
        print(
            f"Overriding use_label_weights: {args.override_use_label_weights}")

    span_index, K = None, None
    if args.use_neighbours:
        index_const = (index.FaissIndex
                       if args.library == "faiss" else index.AnnoyIndex)
        # assert index loaded has the same metric
        span_index = index_const(
            num_labels=len(parser.label_vocab.values),
            metric=parser.metric,
        )
        prefix = index.get_index_prefix(
            index_base_path=args.index_path,
            full_model_path=args.model_path_base,
            nn_prefix=args.nn_prefix,
        )
        span_index.load(prefix)
        K = args.k
        assert K > 0
        if parse_jc.use_cuda:  # hack!
            # use CUDA_VISIBLE_DEVICES={0},{1}
            print(f"Using gpu {args.index_devid} for index")
            span_index.to(args.index_devid)
            #pass

    if args.label_weights_only:
        # freeze everything except "label_weights"
        for name, param in parser.named_parameters():
            if name != "label_weights":
                param.requires_grad = False
    else:
        parser.label_weights.requires_grad = False

    print("Initializing optimizer...")
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    trainer = torch.optim.Adam(trainable_parameters,
                               lr=1.,
                               betas=(0.9, 0.98),
                               eps=1e-9)
    if load_path is not None:
        try:
            trainer.load_state_dict(info['trainer'])
        except:
            print("Couldn't load optim state.")

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer,
        'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = (np.inf if hparams.clip_grad_norm == 0 else
                           hparams.clip_grad_norm)

    print("Training...")
    total_processed = 0
    current_processed = 0
    current_index_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    reindex_every = len(train_parse) / args.reindexes_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0

    start_time = time.time()

    def check_dev():
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank),
                                     args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index +
                                          args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word)
                                   for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(
                subbatch_sentences,
                span_index=span_index,
                k=K,
                zero_empty=parser.zero_empty,
                train_nn=args.train_through_nn,
            )
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                    dev_predicted)

        print("dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ))

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(
                            path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(
                best_dev_model_path))
            torch.save(
                {
                    'spec': parser.spec,
                    'state_dict': parser.state_dict(),
                    'trainer': trainer.state_dict(),
                }, best_dev_model_path + ".pt")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index +
                                      args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word)
                                for leaf in tree.leaves()]
                               for tree in batch_trees]
            batch_num_tokens = sum(
                len(sentence) for sentence in batch_sentences)

            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(
                    subbatch_sentences,
                    subbatch_trees,
                    span_index=span_index,
                    k=K,
                    zero_empty=parser.zero_empty,
                )

                if hparams.predict_tags:
                    loss = loss[0] / len(
                        batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)
                current_index_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)
            trainer.step()

            print("epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "batch-loss {:.4f} "
                  "grad-norm {:.4f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      start_index // args.batch_size + 1,
                      int(np.ceil(len(train_parse) / args.batch_size)),
                      total_processed,
                      batch_loss_value,
                      grad_norm,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()

            if current_index_processed >= reindex_every:
                current_index_processed -= reindex_every
                if span_index is not None:
                    # recompute span_index
                    reindex_time = time.time()
                    span_index.reset()
                    span_index.to(args.index_devid)
                    span_reps, span_infos = index.get_span_reps_infos(
                        parser,
                        train_treebank,
                        128,
                    )
                    span_index.add(span_reps, span_infos)
                    span_index.build()
                    print(f"reindex-elapsed: {format_elapsed(reindex_time)}")

                    save_time = time.time()
                    prefix = index.get_index_prefix(
                        index_base_path=args.index_path,
                        full_model_path=args.model_path_base,
                        nn_prefix=args.save_nn_prefix,
                    )
                    span_index.to(-1)
                    print("Saving recomputed index")
                    span_index.save(prefix)
                    span_index.to(args.index_devid)
                    print(f"save-elapsed: {format_elapsed(save_time)}")

        # adjust learning rate at the end of an epoch
        if (total_processed // args.batch_size +
                1) > hparams.learning_rate_warmup_steps:
            scheduler.step(best_dev_fscore)

        if (total_processed - best_dev_processed) > (
                (hparams.step_decay_patience + 1) *
                hparams.max_consecutive_decays * len(train_parse)):
            print("Terminating due to lack of improvement in dev fscore.")
            break
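

# For intuition: the schedule above is linear warmup from ~0 up to
# hparams.learning_rate over learning_rate_warmup_steps optimizer steps
# (set_lr(iteration * warmup_coeff)), after which ReduceLROnPlateau decays
# the rate whenever dev fscore stops improving. A standalone sketch, with
# example values that are assumptions rather than the repository defaults:
#
#     learning_rate = 8e-4
#     warmup_steps = 160
#     warmup_coeff = learning_rate / warmup_steps
#     for it in range(1, warmup_steps + 1):
#         lr = it * warmup_coeff  # reaches learning_rate at it == warmup_steps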
def run_test(args):
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_jc.NKChartParser.from_spec(info['spec'],
                                              info['state_dict'])
    if args.redo_vocab:
        print("Loading memory bank trees from {} for generating label "
              "vocab...".format(args.train_path))
        train_treebank = trees.load_trees(args.train_path)
        parser.label_vocab = gen_label_vocab(
            [tree.convert() for tree in train_treebank])

    print("Parsing test sentences...")
    start_time = time.time()

    span_index = None  # stays None unless --use-neighbours is set
    if args.use_neighbours:
        index_const = (index.FaissIndex
                       if args.library == "faiss" else index.AnnoyIndex)
        span_index = index_const(
            num_labels=len(parser.label_vocab.values),
            metric=parser.metric,
        )
        prefix = index.get_index_prefix(
            index_base_path=args.index_path,
            full_model_path=args.model_path_base,
            nn_prefix=args.nn_prefix,
        )
        span_index.load(prefix)

    # also remove relu
    parser.no_relu = args.no_relu
    if args.no_relu:
        parser.remove_relu()

    test_predicted = []
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index +
                                       args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word)
                               for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        predicted, _ = parser.parse_batch(
            subbatch_sentences,
            span_index=span_index if args.use_neighbours else None,
            k=args.k,
            zero_empty=args.zero_empty,
        )
        del _
        test_predicted.extend([p.convert() for p in predicted])

    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). We compare with the input file
    # directly to be extra careful about not corrupting the evaluation. We
    # also allow specifying a separate "raw" file for the gold trees: the
    # inputs to our parser have traces removed and may have predicted tags
    # substituted, and we may wish to compare against the raw gold trees to
    # make sure we haven't made a mistake. As far as we can tell all of these
    # variations give equivalent results.
    ref_gold_path = args.test_path
    if args.test_path_raw is not None:
        print("Comparing with raw trees from", args.test_path_raw)
        ref_gold_path = args.test_path_raw

    test_fscore = evaluate.evalb(
        args.evalb_dir,
        test_treebank,
        test_predicted,
        ref_gold_path=ref_gold_path,
    )

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
import time

import index
import torch
import scatter
import numpy as np

index_path = "index"
model_base_path = "models/en_bert_empty_nl2_dev=95.36.pt"

prefix = index.get_index_prefix(
    index_base_path=index_path,
    full_model_path=model_base_path,
    nn_prefix="all_spans_empty_nl2",
)

annoy_index = index.AnnoyIndex(metric="l2")
faiss_index = index.FaissIndex(metric="l2")
faiss_index_gpu = index.FaissIndex(metric="l2")

t = time.time()
annoy_index.load(prefix)
print(f"loaded annoy {time.time() - t}")

t = time.time()
faiss_index.load(prefix)
print(f"loaded faiss {time.time() - t}")
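# faiss_index_gpu is constructed above but never exercised; a sketch of the
# missing comparison, assuming the same load()/to(device) API used elsewhere
# in this repo (device 0 for GPU, -1 for CPU).
t = time.time()
faiss_index_gpu.load(prefix)
faiss_index_gpu.to(0)
print(f"loaded faiss + moved to gpu {time.time() - t}")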