def load_data(data_dir):
    """Read the PTB splits under *data_dir*, convert them to parse trees, and
    build frozen tag/word/label vocabularies from the training parses.

    Returns:
        (word_vocab, tag_vocab, label_vocab, train_parse, valid_parse, test_parse)
    """
    train_data = os.path.join(data_dir, '02-21.10way.clean')
    # train_data = os.path.join(data_dir, '22.auto.clean')
    valid_data = os.path.join(data_dir, '22.auto.clean')
    test_data = os.path.join(data_dir, '23.auto.clean')

    print("Reading trees...")
    split_trees = [trees.load_trees(path)
                   for path in (train_data, valid_data, test_data)]

    print("Converting trees...")
    train_parse, valid_parse, test_parse = (
        [tree.convert() for tree in split] for split in split_trees)

    # Tag and word vocabularies start with the same special symbols,
    # indexed in the same order so their ids line up.
    tag_vocab = vocabulary.Vocabulary()
    word_vocab = vocabulary.Vocabulary()
    for vocab in (tag_vocab, word_vocab):
        for special in (vocabulary.PAD, vocabulary.START,
                        vocabulary.STOP, vocabulary.UNK):
            vocab.index(special)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(vocabulary.PAD)
    label_vocab.index(())  # empty tuple = "no label" entry

    print("Getting vocabulary...")
    # Depth-first walk over every training parse; internal nodes contribute
    # labels, leaves contribute tags and words.
    for tree in train_parse:
        pending = [tree]
        while pending:
            node = pending.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                pending.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)

    label_vocab.freeze()
    word_vocab.freeze()
    tag_vocab.freeze()

    print("Tag vocab: ", tag_vocab.size)
    print("Label vocab: ", label_vocab.size)
    print("Word vocab: ", word_vocab.size)

    return (word_vocab, tag_vocab, label_vocab,
            train_parse, valid_parse, test_parse)
def dict_to_data(self, dict):
    """Deserialize one record into parser-ready structures.

    NOTE: the parameter name shadows the builtin ``dict``; it is kept for
    backward compatibility with existing callers.

    Returns:
        ((sent, syntree, synparse, srlspan, srldep), dict_new) where dict_new
        is a JSON-re-serialized copy of the incoming record.
    """
    synconst = dict['synconst']
    syndep_head = dict['syndep_head']
    syndep_type = dict['syndep_type']
    srlspan_str = dict['srlspan']
    srldep_str = dict['srldep']

    # JSON keys arrive as strings: cast predicate ids and span boundaries
    # back to int while rebuilding the SRL structures.
    srlspan = {
        int(pred_id): [(int(arg[0]), int(arg[1]), arg[2]) for arg in args]
        for pred_id, args in srlspan_str.items()
    }
    srldep = {
        int(pred_id): [(int(arg[0]), arg[1]) for arg in args]
        for pred_id, args in srldep_str.items()
    }

    syntree = trees.load_trees(
        synconst,
        [[int(head) for head in syndep_head]],
        [syndep_type],
        strip_top=False,
    )[0]
    sent = [(leaf.tag, leaf.word) for leaf in syntree.leaves()]
    synparse = syntree.convert()

    dict_new = {
        'synconst': synconst,
        'syndep_head': json.dumps(syndep_head),
        'syndep_type': json.dumps(syndep_type),
        'srlspan': json.dumps(srlspan_str),
        'srldep': json.dumps(srldep_str),
    }
    return (sent, syntree, synparse, srlspan, srldep), dict_new
def run_test(args):
    """Evaluate a torch-saved parser on the test treebank, sentence by sentence."""
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    # model = dy.ParameterCollection()
    # [parser] = dy.load(args.model_path_base, model)
    parser = torch.load(args.model_path_base)

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    for gold_tree in test_treebank:
        # dy.renew_cg()
        parser.eval()
        tagged_words = [(leaf.tag, leaf.word) for leaf in gold_tree.leaves()]
        prediction, _ = parser.parse(tagged_words)
        test_predicted.append(prediction.convert())

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def run_parse_extra(args):
    """Re-parse a treebank with a saved NKChartParser and report EVALB F1.

    Optionally writes gold/pred tree pairs as JSON lines to args.write_parse.
    Refuses to run if args.output_path already exists (unless it is '-').

    Fix: the JSONL output file was opened with bare open()/close(); it now
    uses a `with` block so the handle is released even if linearize() or
    json.dumps() raises mid-loop.
    """
    if args.output_path != '-' and os.path.exists(args.output_path):
        print("Error: output file already exists:", args.output_path)
        return

    print("Loading parse trees from {}...".format(args.input_path))
    treebank = trees.load_trees(args.input_path)
    if args.max_len_eval > 0:
        # Keep only sentences short enough to evaluate.
        treebank = [
            tree for tree in treebank
            if len(list(tree.leaves())) <= args.max_len_eval
        ]
    print("Loaded {:,} parse tree examples.".format(len(treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])

    print("Parsing test sentences...")
    start_time = time.time()

    new_treebank = []
    for start_index in range(0, len(treebank), args.eval_batch_size):
        subbatch_trees = treebank[start_index:start_index + args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        predicted, _ = parser.parse_batch(subbatch_sentences)
        del _
        new_treebank.extend([p.convert() for p in predicted])
    assert len(treebank) == len(new_treebank), (len(treebank), len(new_treebank))

    if args.write_parse is not None:
        print('writing to {}'.format(args.write_parse))
        # Context manager guarantees the file is closed on any exit path.
        with open(args.write_parse, 'w') as f:
            for x, y in zip(new_treebank, treebank):
                gold = '(ROOT {})'.format(y.linearize())
                pred = '(ROOT {})'.format(x.linearize())
                ex = dict(gold=gold, pred=pred)
                f.write(json.dumps(ex) + '\n')

    test_fscore = evaluate.evalb(args.evalb_dir, treebank, new_treebank,
                                 ref_gold_path=None)
    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def run_test(args):
    """Evaluate a saved SAChartParser on the test set and save predictions.

    Parses in batches, scores with EVALB (optionally against a raw gold file),
    and writes linearized predictions to ./results/<model_name>.txt.

    Fixes: ensures ./results exists before writing (previously crashed on a
    fresh checkout), and derives the model name via os.path instead of
    manual rfind() slicing.
    """
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = SAPar_model.SAChartParser.from_spec(info['spec'],
                                                 info['state_dict'])

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    for start_index in tqdm(range(0, len(test_treebank),
                                  args.eval_batch_size)):
        subbatch_trees = test_treebank[start_index:start_index +
                                       args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        predicted, _ = parser.parse_batch(subbatch_sentences)
        del _
        test_predicted.extend([p.convert() for p in predicted])

    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). We compare with the input file
    # directly to be extra careful about not corrupting the evaluation. We also
    # allow specifying a separate "raw" file for the gold trees: the inputs to
    # our parser have traces removed and may have predicted tags substituted,
    # and we may wish to compare against the raw gold trees to make sure we
    # haven't made a mistake. As far as we can tell all of these variations give
    # equivalent results.
    ref_gold_path = args.test_path
    if args.test_path_raw is not None:
        print("Comparing with raw trees from", args.test_path_raw)
        ref_gold_path = args.test_path_raw

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted,
                                 ref_gold_path=ref_gold_path)

    # Name the output after the model file (robust to paths with no '/').
    model_name = os.path.splitext(os.path.basename(args.model_path_base))[0]
    output_dir = './results'
    os.makedirs(output_dir, exist_ok=True)  # don't crash if ./results is missing
    output_file = os.path.join(output_dir, model_name + '.txt')
    with open(output_file, "w") as outfile:
        for tree in test_predicted:
            outfile.write("{}\n".format(tree.linearize()))

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def produce_elmo_for_treebank(args):
    """Tokenize every sentence of the input treebank and precompute ELMo
    embeddings into args.experiment_directory."""
    treebank = trees.load_trees(args.input_file,
                                strip_top=True,
                                filter_none=True)

    # Sanity check: converting a tree must not change its word sequence.
    for original in treebank:
        converted = original.convert()
        words_before = [leaf.word for leaf in original.leaves]
        words_after = [leaf.word for leaf in converted.leaves]
        assert words_before == words_after, (words_before, words_after)

    tokenized_lines = [
        ' '.join(leaf.word for leaf in tree.leaves) for tree in treebank
    ]
    compute_elmo_embeddings(tokenized_lines, args.experiment_directory)
def run_ensemble(args):
    """Evaluate an ensemble of NKChartParser models on the test set.

    Each model in args.model_path_base is loaded, the per-sentence label
    score charts from all models are averaged, and the first parser decodes
    trees from the averaged charts. Requires all models to share the same
    label vocabulary.
    """
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    # args.model_path_base is a list of checkpoint paths here.
    parsers = []
    for model_path_base in args.model_path_base:
        print("Loading model from {}...".format(model_path_base))
        assert model_path_base.endswith(".pt"), "Only pytorch savefiles supported"
        info = torch_load(model_path_base)
        assert 'hparams' in info['spec'], "Older savefiles not supported"
        parser = parse_nk.NKChartParser.from_spec(info['spec'],
                                                  info['state_dict'])
        parsers.append(parser)

    # Ensure that label scores charts produced by the models can be combined
    # using simple averaging
    ref_label_vocab = parsers[0].label_vocab
    for parser in parsers:
        assert parser.label_vocab.indices == ref_label_vocab.indices

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    # Ensemble by averaging label score charts from different models
    # We did not observe any benefits to doing weighted averaging, probably
    # because all our parsers output label scores of around the same magnitude
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index+args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]

        # One list of per-sentence charts per model.
        chart_lists = []
        for parser in parsers:
            charts = parser.parse_batch(subbatch_sentences,
                                        return_label_scores_charts=True)
            chart_lists.append(charts)

        # zip(*chart_lists) regroups charts per sentence; average over models.
        subbatch_charts = [np.mean(list(sentence_charts), 0)
                           for sentence_charts in zip(*chart_lists)]
        predicted, _ = parsers[0].decode_from_chart_batch(subbatch_sentences,
                                                          subbatch_charts)
        del _
        test_predicted.extend([p.convert() for p in predicted])

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted,
                                 ref_gold_path=args.test_path)

    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )
def run_viz(args):
    """Visualize self-attention for a few short sentences.

    Monkey-patches MultiHeadAttention.forward so every call stashes its
    attention weights into `stowed_values`, parses up to one short sentence
    from args.viz_path, then renders each attention layer with viz_attention.
    """
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

    print("Loading test trees from {}...".format(args.viz_path))
    viz_treebank = trees.load_trees(args.viz_path)
    print("Loaded {:,} test examples.".format(len(viz_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Only self-attentive models are supported"
    parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])

    from viz import viz_attention

    # The wrapper closes over the *name* `stowed_values`, so rebinding it
    # below (per batch) also resets what the wrapper writes into.
    stowed_values = {}
    orig_multihead_forward = parse_nk.MultiHeadAttention.forward

    def wrapped_multihead_forward(self, inp, batch_idxs, **kwargs):
        # Delegate to the real forward, then record the attention weights
        # under a key numbered by how many attention calls happened so far.
        res, attns = orig_multihead_forward(self, inp, batch_idxs, **kwargs)
        stowed_values[f'attns{stowed_values["stack"]}'] = attns.cpu().data.numpy()
        stowed_values['stack'] += 1
        return res, attns

    parse_nk.MultiHeadAttention.forward = wrapped_multihead_forward

    # Select the sentences we will actually be visualizing
    max_len_viz = 15
    if max_len_viz > 0:
        viz_treebank = [tree for tree in viz_treebank
                        if len(list(tree.leaves())) <= max_len_viz]
    viz_treebank = viz_treebank[:1]

    print("Parsing viz sentences...")

    for start_index in range(0, len(viz_treebank), args.eval_batch_size):
        subbatch_trees = viz_treebank[start_index:start_index+args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        # Reset the capture dict; 'stack' counts attention layers seen.
        stowed_values = dict(stack=0)
        predicted, _ = parser.parse_batch(subbatch_sentences)
        del _
        predicted = [p.convert() for p in predicted]
        stowed_values['predicted'] = predicted

        for snum, sentence in enumerate(subbatch_sentences):
            # Words as the model saw them, including START/STOP sentinels.
            sentence_words = [tokens.START] + [x[1] for x in sentence] + [tokens.STOP]

            for stacknum in range(stowed_values['stack']):
                attns_padded = stowed_values[f'attns{stacknum}']
                # Heads for sentence `snum` are strided across the batch dim;
                # crop padding down to the true sentence length.
                attns = attns_padded[snum::len(subbatch_sentences),
                                     :len(sentence_words), :len(sentence_words)]
                viz_attention(sentence_words, attns)
def run_test(args):
    """Evaluate a saved Lparser.ChartParser on the PTB (or CTB) test set.

    Fixes: removed the duplicate parser.eval() call, and removed the unused
    `punct_set` local — it was written as adjacent string literals
    ("." "``" "''" ":" ","), which Python concatenates into one string rather
    than building a set, and nothing in this function read it.
    """
    test_path = args.test_ptb_path
    if args.dataset == "ctb":
        test_path = args.test_ctb_path

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert "hparams" in info["spec"], "Older savefiles not supported"
    parser = Lparser.ChartParser.from_spec(info["spec"], info["state_dict"])
    parser.eval()  # inference mode; calling once is sufficient

    print("Loading test trees from {}...".format(test_path))
    test_treebank = trees.load_trees(test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index +
                                       args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        predicted, _ = parser.parse_batch(subbatch_sentences)
        del _
        test_predicted.extend([p.convert() for p in predicted])

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def run_test(args):
    """Evaluate a saved DyNet latent-tree parser with chunk-level F1."""
    #args.test_path = args.test_path.replace('[*]', args.treetype)
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path, args.normal)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    model = dy.ParameterCollection()
    [parser] = dy.load(args.model_path_base, model)

    # Rebuild the label vocabulary exactly as at training time: plain labels,
    # then the EMPTY label, then primed ("X'") variants of every label.
    label_vocab = vocabulary.Vocabulary()
    label_list = util.load_label_list('../data/labels.txt')
    for name in label_list:
        label_vocab.index((name, ))
    label_vocab.index((parse.EMPTY, ))
    for name in label_list:
        label_vocab.index((name + "'", ))
    label_vocab.freeze()

    latent_tree = latent.latent_tree_builder(label_vocab, args.RBTlabel)

    print("Parsing test sentences...")
    start_time = time.time()

    test_gold = latent_tree.build_latent_trees(test_treebank)
    test_predicted = []
    for words, chunks in test_treebank:
        dy.renew_cg()
        #sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()]
        tagged = [(parse.XX, token) for token in words]
        prediction, _ = parser.parse(tagged)
        test_predicted.append(prediction.convert())

    #test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted, args.expname + '.test.')
    test_fscore = evaluate.eval_chunks(
        args.evalb_dir, test_gold, test_predicted,
        output_filename=args.expname + '.finaltest.txt')  # evalb
    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def run_test(args):
    """Precompute ELMo embeddings for the test set, then score a saved
    DyNet parser via check_performance."""
    if not os.path.exists(args.experiment_directory):
        os.mkdir(args.experiment_directory)

    print("Loading test trees from {}...".format(args.input_file))
    test_treebank = trees.load_trees(args.input_file)
    test_tokenized_lines = parse_trees_to_string_lines(test_treebank)
    embeddings_path = os.path.join(args.experiment_directory,
                                   'test_embeddings')
    test_embeddings_file = compute_elmo_embeddings(test_tokenized_lines,
                                                   embeddings_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path))
    model = dy.ParameterCollection()
    [parser] = dy.load(args.model_path, model)

    print("Parsing test sentences...")
    check_performance(parser, test_treebank, test_embeddings_file, args)
def run_test(args):
    """Evaluate a torch-saved parser on the test treebank (GPU if available).

    Loads the whole pickled parser with torch.load, moves it to CUDA when a
    GPU is present, delegates batching/prediction to run_eval, and reports
    the EVALB fscore.
    """
    print("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    # model = dy.ParameterCollection()
    # [parser] = dy.load(args.model_path_base, model)
    parser = torch.load(args.model_path_base)
    if torch.cuda.is_available():
        parser = parser.cuda()

    print("Parsing test sentences...")
    start_time = time.time()

    # run_eval handles batching and tree conversion.
    test_predicted = run_eval(parser, test_treebank)

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
def run_train(args, hparams):
    """Train an Lparser.ChartParser on PTB or CTB.

    Loads train/dev treebanks, builds tag/word/label/char vocabularies from
    the training parses, then runs the Adam training loop with learning-rate
    warmup and ReduceLROnPlateau step decay, checkpointing the model with the
    best dev fscore.
    """
    # Seed numpy first (if requested), then derive the pytorch seed from
    # numpy so both RNGs are reproducible from a single --numpy-seed.
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    # Default to PTB paths; switch to CTB when requested.
    train_path = args.train_ptb_path
    dev_path = args.dev_ptb_path
    if hparams.dataset == "ctb":
        train_path = args.train_ctb_path
        dev_path = args.dev_ctb_path

    print("Loading training trees from {}...".format(train_path))
    train_treebank = trees.load_trees(train_path)
    if hparams.max_len_train > 0:
        # Drop overly long training sentences.
        train_treebank = [
            tree for tree in train_treebank
            if len(list(tree.leaves())) <= hparams.max_len_train
        ]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(dev_path))
    dev_treebank = trees.load_trees(dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [
            tree for tree in dev_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]
    # NOTE(review): dev_parse is built but not referenced again below.
    dev_parse = [tree.convert() for tree in dev_treebank]

    print("Constructing vocabularies...")
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(Lparser.START)
    tag_vocab.index(Lparser.STOP)
    tag_vocab.index(Lparser.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(Lparser.START)
    word_vocab.index(Lparser.STOP)
    word_vocab.index(Lparser.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())  # empty tuple = "no label" entry

    # Depth-first walk over every training parse: internal nodes contribute
    # labels; leaves contribute tags, words, and their characters.
    char_set = set()
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        # Large codepoints (e.g. CJK): index special char tokens first,
        # then the observed characters in sorted order.
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        # Print special symbols first, then the rest, both sorted.
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name,
            vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special),
        ))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)
        print_vocabulary("Char", char_vocab)

    print("Initializing model...")
    # load_path is a hard-coded resume hook; set it to a checkpoint path to
    # continue training from a saved model.
    load_path = None
    if load_path is not None:
        print("Loading parameters from {}".format(load_path))
        info = torch_load(load_path)
        parser = Lparser.ChartParser.from_spec(info["spec"], info["state_dict"])
    else:
        parser = Lparser.ChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    # lr=1.0 is a placeholder: the actual rate is driven by set_lr() during
    # warmup and by the plateau scheduler afterwards.
    trainer = torch.optim.Adam(trainable_parameters,
                               lr=1.0,
                               betas=(0.9, 0.98),
                               eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info["trainer"])

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group["lr"] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    # Linear warmup slope: reach hparams.learning_rate after warmup_steps.
    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer,
        "max",
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        # Warmup only; after warmup the plateau scheduler takes over.
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = (np.inf if hparams.clip_grad_norm == 0 else
                           hparams.clip_grad_norm)

    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_model_path = None
    model_name = hparams.model_name
    best_dev_processed = 0
    print("This is ", model_name)

    start_time = time.time()

    def check_dev(epoch_num):
        # epoch_num is currently unused; kept for call-site compatibility.
        nonlocal best_dev_fscore
        nonlocal best_model_path
        nonlocal best_dev_processed

        dev_start_time = time.time()

        parser.eval()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank),
                                     args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index +
                                          args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word)
                                   for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            (
                predicted,
                _,
            ) = parser.parse_batch(subbatch_sentences)
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)

        print("\n"
              "dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(dev_fscore,
                                        format_elapsed(dev_start_time),
                                        format_elapsed(start_time)))

        if dev_fscore.fscore > best_dev_fscore:
            # New best: delete the previous checkpoint, then save this one.
            if best_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_model_path + ext
                    if os.path.exists(path):
                        print(
                            "Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            best_model_path = "{}_best_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_model_path))
            torch.save(
                {
                    "spec": parser.spec,
                    "state_dict": parser.state_dict(),
                    "trainer": trainer.state_dict(),
                },
                best_model_path + ".pt",
            )

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            parser.train()

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index +
                                      args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word)
                                for leaf in tree.leaves()]
                               for tree in batch_trees]

            # Subbatching caps memory by limiting tokens per forward pass;
            # losses are normalized by the full batch size.
            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences,
                                             subbatch_trees)

                loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)

            trainer.step()

            print(
                "\r"
                "epoch {:,} "
                "batch {:,}/{:,} "
                "processed {:,} "
                "batch-loss {:.4f} "
                "grad-norm {:.4f} "
                "epoch-elapsed {} "
                "total-elapsed {}".format(
                    epoch,
                    start_index // args.batch_size + 1,
                    int(np.ceil(len(train_parse) / args.batch_size)),
                    total_processed,
                    batch_loss_value,
                    grad_norm,
                    format_elapsed(epoch_start_time),
                    format_elapsed(start_time),
                ),
                end="",
            )
            sys.stdout.flush()

            # Evaluate on dev several times per epoch.
            if current_processed >= check_every:
                current_processed -= check_every
                check_dev(epoch)

        # adjust learning rate at the end of an epoch
        if hparams.step_decay:
            if (total_processed // args.batch_size +
                    1) > hparams.learning_rate_warmup_steps:
                scheduler.step(best_dev_fscore)
def run_train(args, hparams):
    """Train an SAChartParser with an n-gram-augmented vocabulary.

    Sets up file+console logging, seeds all RNGs from args.seed, loads
    train/dev/test treebanks, builds tag/word/label/char/ngram vocabularies,
    then runs the Adam training loop (warmup + plateau decay), checkpointing
    on the best dev fscore and early-stopping after prolonged stagnation.
    """
    # if args.numpy_seed is not None:
    #     print("Setting numpy random seed to {}...".format(args.numpy_seed))
    #     np.random.seed(args.numpy_seed)
    #
    #     # Make sure that pytorch is actually being initialized randomly.
    #     # On my cluster I was getting highly correlated results from multiple
    #     # runs, but calling reset_parameters() changed that. A brief look at the
    #     # pytorch source code revealed that pytorch initializes its RNG by
    #     # calling std::random_device, which according to the C++ spec is allowed
    #     # to be deterministic.
    #     seed_from_numpy = np.random.randint(2147483648)
    #     print("Manual seed for pytorch:", seed_from_numpy)
    #     torch.manual_seed(seed_from_numpy)

    # Log to a timestamped file and mirror to the console.
    now_time = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
    log_file_name = os.path.join(args.log_dir, 'log-' + now_time)
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        filename=log_file_name,
        filemode='w',
        level=logging.INFO)
    logger = logging.getLogger(__name__)
    console_handler = logging.StreamHandler()
    logger.addHandler(console_handler)
    logger = logging.getLogger(__name__)

    # Seed python, numpy, and torch from the single --seed argument.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    hparams.set_from_args(args)
    logger.info("Hyperparameters:")
    logger.info(hparams.print())

    logger.info("Loading training trees from {}...".format(args.train_path))
    if hparams.predict_tags and args.train_path.endswith('10way.clean'):
        logger.info(
            "WARNING: The data distributed with this repository contains "
            "predicted part-of-speech tags only (not gold tags!) We do not "
            "recommend enabling predict_tags in this configuration.")

    train_treebank = trees.load_trees(args.train_path)
    if hparams.max_len_train > 0:
        # Drop overly long training sentences.
        train_treebank = [
            tree for tree in train_treebank
            if len(list(tree.leaves())) <= hparams.max_len_train
        ]
    logger.info("Loaded {:,} training examples.".format(len(train_treebank)))

    logger.info("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [
            tree for tree in dev_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    logger.info("Loaded {:,} development examples.".format(len(dev_treebank)))

    logger.info("Loading test trees from {}...".format(args.test_path))
    test_treebank = trees.load_trees(args.test_path)
    # NOTE(review): the test split is filtered with max_len_dev (there is no
    # separate max_len_test hyperparameter).
    if hparams.max_len_dev > 0:
        test_treebank = [
            tree for tree in test_treebank
            if len(list(tree.leaves())) <= hparams.max_len_dev
        ]
    logger.info("Loaded {:,} test examples.".format(len(test_treebank)))

    logger.info("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]
    dev_parse = [tree.convert() for tree in dev_treebank]
    # NOTE(review): test_parse is built but only train/dev sentences feed the
    # n-gram statistics below.
    test_parse = [tree.convert() for tree in test_treebank]

    logger.info("Constructing vocabularies...")
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())  # empty tuple = "no label" entry

    # Depth-first walk over the training parses: internal nodes contribute
    # labels; leaves contribute tags, words, and their characters.
    char_set = set()
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        # Large codepoints (e.g. CJK): index special char tokens first,
        # then the observed characters in sorted order.
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    # -------- ngram vocab ------------
    ngram_vocab = vocabulary.Vocabulary()
    ngram_vocab.index(())
    ngram_finder = FindNgrams(min_count=hparams.ngram_threshold)

    def get_sentence(parse):
        # Extract the word sequence of every tree in *parse*.
        sentences = []
        for tree in parse:
            sentence = []
            for leaf in tree.leaves():
                sentence.append(leaf.word)
            sentences.append(sentence)
        return sentences

    sentence_list = get_sentence(train_parse)
    if not args.cross_domain:
        # In-domain: dev sentences also contribute n-gram statistics.
        sentence_list.extend(get_sentence(dev_parse))
        # sentence_list.extend(get_sentence(test_parse))

    if hparams.ngram_type == 'freq':
        logger.info('ngram type: freq')
        ngram_finder.count_ngram(sentence_list, hparams.ngram)
    elif hparams.ngram_type == 'pmi':
        logger.info('ngram type: pmi')
        ngram_finder.find_ngrams_pmi(sentence_list, hparams.ngram,
                                     hparams.ngram_freq_threshold)
    else:
        raise ValueError()

    ngram_type_count = [0 for _ in range(hparams.ngram)]
    for w, c in ngram_finder.ngrams.items():
        ngram_type_count[len(list(w)) - 1] += 1
        # NOTE(review): each n-gram is indexed once per occurrence count c —
        # presumably so the vocabulary tracks frequencies; verify against
        # vocabulary.Vocabulary.index semantics.
        for _ in range(c):
            ngram_vocab.index(w)
    logger.info(str(ngram_type_count))
    ngram_vocab.freeze()

    # Count token-level n-gram hits per order, for logging only.
    ngram_count = [0 for _ in range(hparams.ngram)]
    for sentence in sentence_list:
        for n in range(len(ngram_count)):
            length = n + 1
            for i in range(len(sentence)):
                gram = tuple(sentence[i:i + length])
                if gram in ngram_finder.ngrams:
                    ngram_count[n] += 1
    logger.info(str(ngram_count))
    # -------- ngram vocab ------------

    def print_vocabulary(name, vocab):
        # Log special symbols first, then the rest, both sorted.
        special = {tokens.START, tokens.STOP, tokens.UNK}
        logger.info("{} ({:,}): {}".format(
            name,
            vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)
        print_vocabulary("Ngram", ngram_vocab)

    logger.info("Initializing model...")
    # load_path is a hard-coded resume hook; set it to a checkpoint path to
    # continue training from a saved model.
    load_path = None
    if load_path is not None:
        logger.info(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = SAPar_model.SAChartParser.from_spec(info['spec'],
                                                     info['state_dict'])
    else:
        parser = SAPar_model.SAChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            ngram_vocab,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [
        param for param in parser.parameters() if param.requires_grad
    ]
    # lr=1. is a placeholder: the actual rate is driven by set_lr() during
    # warmup and by the plateau scheduler afterwards.
    trainer = torch.optim.Adam(trainable_parameters,
                               lr=1.,
                               betas=(0.9, 0.98),
                               eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    pytorch_total_params = sum(p.numel() for p in parser.parameters()
                               if p.requires_grad)
    logger.info('# of trainable parameters: %d' % pytorch_total_params)

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    # Linear warmup slope: reach hparams.learning_rate after warmup_steps.
    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer,
        'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        # Warmup only; after warmup the plateau scheduler takes over.
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    logger.info("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_eval_fscore = -np.inf
    test_fscore_on_dev = -np.inf
    best_eval_scores = None
    best_eval_model_path = None
    best_eval_processed = 0

    start_time = time.time()

    def check_eval(eval_treebank, ep, flag='dev'):
        """Parse *eval_treebank* in batches and return its EVALB fscore."""
        # nonlocal best_eval_fscore
        # nonlocal best_eval_model_path
        # nonlocal best_eval_processed

        dev_start_time = time.time()

        eval_predicted = []
        for dev_start_index in range(0, len(eval_treebank),
                                     args.eval_batch_size):
            subbatch_trees = eval_treebank[dev_start_index:dev_start_index +
                                           args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word)
                                   for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _
            eval_predicted.extend([p.convert() for p in predicted])

        eval_fscore = evaluate.evalb(args.evalb_dir, eval_treebank,
                                     eval_predicted)

        logger.info(flag + ' eval '
                    'epoch {} '
                    "fscore {} "
                    "elapsed {} "
                    "total-elapsed {}".format(
                        ep,
                        eval_fscore,
                        format_elapsed(dev_start_time),
                        format_elapsed(start_time),
                    ))
        return eval_fscore

    def save_model(eval_fscore, remove_model):
        """Record *eval_fscore* as the new best and checkpoint the parser.

        When remove_model is True the previous best checkpoint is deleted
        first.
        """
        nonlocal best_eval_fscore
        nonlocal best_eval_model_path
        nonlocal best_eval_processed
        nonlocal best_eval_scores

        if best_eval_model_path is not None:
            extensions = [".pt"]
            for ext in extensions:
                path = best_eval_model_path + ext
                if os.path.exists(path) and remove_model:
                    logger.info(
                        "Removing previous model file {}...".format(path))
                    os.remove(path)

        best_eval_fscore = eval_fscore.fscore
        best_eval_scores = eval_fscore
        best_eval_model_path = "{}_eval={:.2f}_{}".format(
            args.model_path_base, eval_fscore.fscore, now_time)
        best_eval_processed = total_processed
        logger.info(
            "Saving new best model to {}...".format(best_eval_model_path))
        torch.save(
            {
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                # 'trainer' : trainer.state_dict(),
            }, best_eval_model_path + ".pt")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index +
                                      args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word)
                                for leaf in tree.leaves()]
                               for tree in batch_trees]
            batch_num_tokens = sum(
                len(sentence) for sentence in batch_sentences)

            # Subbatching caps memory by limiting tokens per forward pass;
            # losses are normalized over the full batch.
            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences,
                                             subbatch_trees)

                if hparams.predict_tags:
                    # loss is (parse_loss, tag_loss): normalize the parse loss
                    # per tree and the tagging loss per token.
                    loss = loss[0] / len(
                        batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)

            trainer.step()

            print("epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "batch-loss {:.4f} "
                  "grad-norm {:.4f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      start_index // args.batch_size + 1,
                      int(np.ceil(len(train_parse) / args.batch_size)),
                      total_processed,
                      batch_loss_value,
                      grad_norm,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            # Evaluate several times per epoch; checkpoint when dev improves,
            # remembering the matching test score.
            if current_processed >= check_every:
                current_processed -= check_every
                dev_fscore = check_eval(dev_treebank, epoch, flag='dev')
                test_fscore = check_eval(test_treebank, epoch, flag='test')
                if dev_fscore.fscore > best_eval_fscore:
                    save_model(dev_fscore, remove_model=True)
                    test_fscore_on_dev = test_fscore

        # adjust learning rate at the end of an epoch
        if (total_processed // args.batch_size +
                1) > hparams.learning_rate_warmup_steps:
            scheduler.step(best_eval_fscore)

        # Early stopping: give up once no dev improvement has been seen for
        # args.patients plus a decay-schedule-dependent grace period.
        if (total_processed - best_eval_processed) > args.patients \
                + ((hparams.step_decay_patience + 1) *
                   hparams.max_consecutive_decays * len(train_parse)):
            logger.info(
                "Terminating due to lack of improvement in eval fscore.")
            logger.info("best dev {} test {}".format(
                best_eval_scores,
                test_fscore_on_dev,
            ))
            break
def evaluate_on_brown_corpus(args):
    """Evaluate a saved DyNet span parser on each section of the Brown corpus.

    For every Brown section (cf, cg, ...): cleans the raw .mrg file, loads its
    trees, optionally computes ELMo embeddings, parses every sentence, and
    writes EVALB scores (several label/flattening variants) plus POS accuracy
    under ``args.experiment_directory/<section>``.

    Args:
        args: parsed CLI namespace; uses experiment_directory, model_path_base,
            and use_elmo.
    """
    if not os.path.exists(args.experiment_directory):
        os.mkdir(args.experiment_directory)
    model = dy.ParameterCollection()
    # dy.load returns a list of saved objects; exactly one parser is expected.
    [parser] = dy.load(args.model_path_base, model)
    # The checkpoint must have been trained with the same ELMo setting.
    assert parser.use_elmo == args.use_elmo, (parser.use_elmo, args.use_elmo)
    directories = ['cf', 'cg', 'ck', 'cl', 'cm', 'cn', 'cp', 'cr']
    for directory in directories:
        print('-' * 100)
        print(directory)
        # NOTE(review): Brown corpus location is hard-coded relative to cwd.
        input_file = '../brown/' + directory + '/' + directory + '.all.mrg'
        expt_name = args.experiment_directory + '/' + directory
        if not os.path.exists(expt_name):
            os.mkdir(expt_name)
        cleaned_corpus_path = trees.cleanup_text(input_file)
        treebank = trees.load_trees(cleaned_corpus_path, strip_top=True, filter_none=True)
        # Here `leaves` is accessed as an attribute (elsewhere in this file it
        # is a method) — presumably a different trees module variant; verify.
        sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves] for tree in treebank]
        tokenized_lines = [' '.join([word for pos, word in sentence]) for sentence in sentences]
        if args.use_elmo:
            embedding_file = compute_elmo_embeddings(tokenized_lines, expt_name)
        else:
            embedding_file = None
        dev_predicted = []
        num_correct = 0  # correctly predicted POS tags across the section
        total = 0        # total POS tags across the section
        for tree_index, tree in enumerate(treebank):
            if tree_index % 100 == 0:
                print(tree_index)
            # Fresh computation graph per sentence (DyNet requirement).
            dy.renew_cg()
            sentence = sentences[tree_index]
            if args.use_elmo:
                # assumes embeddings are stored as (layers, tokens, dim) per
                # sentence index — TODO confirm against compute_elmo_embeddings.
                embeddings_np = embedding_file[str(tree_index)][:, :, :]
                assert embeddings_np.shape[1] == len(sentence), (
                    embeddings_np.shape[1], len(sentence))
                embeddings = dy.inputTensor(embeddings_np)
            else:
                embeddings = None
            predicted, (additional_info, c, t) = parser.span_parser(
                sentence, is_train=False, elmo_embeddings=embeddings)
            num_correct += c
            total += t
            dev_predicted.append(predicted.convert())
        # Score the section several ways: unlabeled, unlabeled+flattened,
        # labeled+flattened, and the regular labeled EVALB score.
        dev_fscore_without_labels = evaluate.evalb('EVALB/', treebank, dev_predicted,
                                                   args=args,
                                                   erase_labels=True,
                                                   name="without-labels",
                                                   expt_name=expt_name)
        print("dev-fscore without labels", dev_fscore_without_labels)
        dev_fscore_without_labels = evaluate.evalb('EVALB/', treebank, dev_predicted,
                                                   args=args,
                                                   erase_labels=True,
                                                   flatten=True,
                                                   name="without-label-flattened",
                                                   expt_name=expt_name)
        print("dev-fscore without labels and flattened", dev_fscore_without_labels)
        dev_fscore_without_labels = evaluate.evalb('EVALB/', treebank, dev_predicted,
                                                   args=args,
                                                   erase_labels=False,
                                                   flatten=True,
                                                   name="flattened",
                                                   expt_name=expt_name)
        print("dev-fscore with labels and flattened", dev_fscore_without_labels)
        test_fscore = evaluate.evalb('EVALB/', treebank, dev_predicted, args=args,
                                     name="regular", expt_name=expt_name)
        print("regular", test_fscore)
        pos_fraction = num_correct / total
        print('pos fraction', pos_fraction)
        with open(expt_name + '/pos_accuracy.txt', 'w') as f:
            f.write(str(pos_fraction))
import trees train_treebank = trees.load_trees("data/02-21.10way.clean") print("Loaded {:,} training examples.".format(len(train_treebank))) s = train_treebank[11] print(s) t = s.sentencify() print(t)
def run_test(args):
    """Evaluate a saved joint constituency/dependency parser on the test set.

    Loads a PyTorch checkpoint, reads the CoNLL-X dependency test data and
    the constituency test treebank, parses all sentences in batches, and
    reports the EVALB constituency F-score plus UAS/LAS/UCM/LCM dependency
    scores (with and without punctuation) and root accuracy.
    """
    const_test_path = args.consttest_ptb_path
    dep_test_path = args.deptest_ptb_path
    if args.dataset == 'ctb':
        # Chinese Treebank paths instead of PTB.
        const_test_path = args.consttest_ctb_path
        dep_test_path = args.deptest_ctb_path
    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])
    parser.eval()
    dep_test_reader = CoNLLXReader(dep_test_path, parser.type_vocab)
    print('Reading dependency parsing data from %s' % dep_test_path)
    dep_test_data = []
    test_inst = dep_test_reader.getNext()
    # Fixed-size buffers: at most 40000 sentences of at most 300 tokens.
    # NOTE(review): silently assumes the corpus fits these bounds — confirm.
    dep_test_headid = np.zeros([40000, 300], dtype=int)
    dep_test_type = []
    dep_test_word = []
    dep_test_pos = []
    dep_test_lengs = np.zeros(40000, dtype=int)
    cun = 0  # number of sentences read so far
    while test_inst is not None:
        inst_size = test_inst.length()
        dep_test_lengs[cun] = inst_size
        sent = test_inst.sentence
        dep_test_data.append((sent.words, test_inst.postags, test_inst.heads, test_inst.types))
        for i in range(inst_size):
            dep_test_headid[cun][i] = test_inst.heads[i]
        dep_test_type.append(test_inst.types)
        dep_test_word.append(sent.words)
        dep_test_pos.append(sent.postags)
        # dep_sentences.append([(tag, word) for i, (word, tag) in enumerate(zip(sent.words, sent.postags))])
        test_inst = dep_test_reader.getNext()
        cun = cun + 1
    dep_test_reader.close()
    print("Loading test trees from {}...".format(const_test_path))
    test_treebank = trees.load_trees(const_test_path, dep_test_headid, dep_test_type, dep_test_word)
    print("Loaded {:,} test examples.".format(len(test_treebank)))
    print("Parsing test sentences...")
    start_time = time.time()
    # Adjacent string literals concatenate: punct_set is the single string
    # ".``'':," used for substring punctuation checks in dep_eval.
    punct_set = '.' '``' "''" ':' ','
    parser.eval()
    test_predicted = []
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index + args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                              for tree in subbatch_trees]
        predicted, _, = parser.parse_batch(subbatch_sentences)
        del _
        test_predicted.extend([p.convert() for p in predicted])
    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)
    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )
    # Predicted dependency heads/labels come from the leaves of the
    # predicted constituency trees (joint model).
    test_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in test_predicted]
    test_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in test_predicted]
    assert len(test_pred_head) == len(test_pred_type)
    assert len(test_pred_type) == len(dep_test_type)
    stats, stats_nopunc, stats_root, test_total_inst = dep_eval.eval(
        len(test_pred_head), dep_test_word, dep_test_pos, test_pred_head,
        test_pred_type, dep_test_headid, dep_test_type, dep_test_lengs,
        punct_set=punct_set, symbolic_root=False)
    test_ucorrect, test_lcorrect, test_total, test_ucomlpete_match, test_lcomplete_match = stats
    test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc, \
        test_ucomlpete_match_nopunc, test_lcomplete_match_nopunc = stats_nopunc
    test_root_correct, test_total_root = stats_root
    print(
        'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            test_ucorrect, test_lcorrect, test_total,
            test_ucorrect * 100 / test_total, test_lcorrect * 100 / test_total,
            test_ucomlpete_match * 100 / test_total_inst,
            test_lcomplete_match * 100 / test_total_inst))
    print(
        'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% ' % (
            test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
            test_ucorrect_nopunc * 100 / test_total_nopunc,
            test_lcorrect_nopunc * 100 / test_total_nopunc,
            test_ucomlpete_match_nopunc * 100 / test_total_inst,
            test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('best test Root: corr: %d, total: %d, acc: %.2f%%' % (
        test_root_correct, test_total_root, test_root_correct * 100 / test_total_root))
    print(
        '============================================================================================================================')
def run_train_question_bank(args):
    """Train a DyNet span parser on the Question Bank (optionally mixed with WSJ).

    Splits the 4000 QB trees into the Stanford train/dev/test index ranges,
    drops train trees whose sentences also occur in dev/test, optionally
    subsamples the training indices (persisted to train_tree_indices.txt for
    reproducibility), and runs an endless epoch loop, checkpointing via
    check_performance_and_save every 10 epochs.
    """
    if not os.path.exists(args.experiment_directory):
        os.mkdir(args.experiment_directory)
    all_trees = trees.load_trees(args.question_bank_trees_path)
    all_parses = [tree.convert() for tree in all_trees]
    print("Loaded {:,} trees.".format(len(all_parses)))
    # Stanford split of the 4000-sentence Question Bank.
    tentative_stanford_train_indices = list(range(0, 1000)) + list(range(2000, 3000))
    stanford_dev_indices = list(range(1000, 1500)) + list(range(3000, 3500))
    stanford_test_indices = list(range(1500, 2000)) + list(range(3500, 4000))
    # Collect dev/test sentences so duplicate sentences can be excluded
    # from training (prevents train/eval leakage).
    dev_and_test_sentences = set()
    for index in stanford_dev_indices:
        parse = all_parses[index]
        sentence = [(leaf.tag, leaf.word) for leaf in parse.leaves]
        dev_and_test_sentences.add(tuple(sentence))
    for index in stanford_test_indices:
        parse = all_parses[index]
        sentence = [(leaf.tag, leaf.word) for leaf in parse.leaves]
        dev_and_test_sentences.add(tuple(sentence))
    stanford_train_indices = []
    for index in tentative_stanford_train_indices:
        parse = all_parses[index]
        sentence = [(leaf.tag, leaf.word) for leaf in parse.leaves]
        if tuple(sentence) not in dev_and_test_sentences:
            stanford_train_indices.append(index)
    # Precomputed ELMo embeddings keyed by tree index (as strings).
    qb_embeddings_file = h5py.File(args.question_bank_elmo_embeddings_path, 'r')
    print("We have {:,} train trees.".format(len(stanford_train_indices)))
    wsj_train = load_parses(args.wsj_train_trees_path)
    qb_train_parses = [all_parses[index] for index in stanford_train_indices]
    qb_dev_treebank = [all_trees[index] for index in stanford_dev_indices]
    parser, model = load_or_create_model(args, qb_train_parses + wsj_train)
    trainer = dy.AdamTrainer(model)
    total_processed = 0
    current_processed = 0
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    start_time = time.time()
    if args.train_on_wsj == 'true':
        print('training on wsj')
        wsj_embeddings_file = h5py.File(args.wsj_train_elmo_embeddings_path, 'r')
        # assumes the WSJ training set has 39832 trees — TODO confirm.
        wsj_indices = list(range(39832))
    else:
        print('not training on wsj')
    # Reuse a previously persisted subsample if present; otherwise create one.
    indices_file_path = args.experiment_directory + '/train_tree_indices.txt'
    if os.path.exists(indices_file_path):
        with open(indices_file_path, 'r') as f:
            tree_indices = [int(x) for x in f.read().splitlines()]
        print('loaded', len(tree_indices), 'indices from file', indices_file_path)
    elif args.num_samples != 'false':
        print('restricting to', args.num_samples, 'samples')
        random.shuffle(stanford_train_indices)
        tree_indices = stanford_train_indices[:int(args.num_samples)]
    else:
        print('training on original data')
        tree_indices = stanford_train_indices
    if args.num_samples != 'false':
        assert int(args.num_samples) == len(tree_indices), (args.num_samples, len(tree_indices))
    with open(indices_file_path, 'w') as f:
        f.write('\n'.join([str(x) for x in tree_indices]))
    # Endless training loop; termination is external (or via checkpointing).
    for epoch in itertools.count(start=1):
        np.random.shuffle(tree_indices)
        epoch_start_time = time.time()
        for start_index in range(0, len(tree_indices), args.batch_size):
            dy.renew_cg()  # fresh DyNet computation graph per batch
            batch_losses = []
            for tree_index in tree_indices[start_index: start_index + args.batch_size]:
                tree = all_parses[tree_index]
                sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves]
                embeddings_np = qb_embeddings_file[str(tree_index)][:, :, :]
                # Embedding token count must match the sentence length.
                assert embeddings_np.shape[1] == len(sentence), (
                    embeddings_np.shape, len(sentence), sentence)
                embeddings = dy.inputTensor(embeddings_np)
                loss = parser.span_parser(sentence, is_train=True, gold=tree,
                                          elmo_embeddings=embeddings)
                batch_losses.append(loss)
                total_processed += 1
                current_processed += 1
            if args.train_on_wsj == 'true':
                # Mix in 100 random WSJ trees per batch as auxiliary data.
                random.shuffle(wsj_indices)
                for tree_index in wsj_indices[:100]:
                    tree = wsj_train[tree_index]
                    sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves]
                    embeddings = dy.inputTensor(wsj_embeddings_file[str(tree_index)][:, :, :])
                    loss = parser.span_parser(sentence, is_train=True, gold=tree,
                                              elmo_embeddings=embeddings)
                    batch_losses.append(loss)
            batch_loss = dy.average(batch_losses)
            batch_loss_value = batch_loss.scalar_value()
            batch_loss.backward()
            trainer.update()
            print(
                "epoch {:,} "
                "batch {:,}/{:,} "
                "processed {:,} "
                "batch-loss {:.4f} "
                "epoch-elapsed {} "
                "total-elapsed {}".format(
                    epoch,
                    start_index // args.batch_size,
                    int(np.ceil(len(tree_indices) / args.batch_size)),
                    total_processed,
                    batch_loss_value,
                    format_elapsed(epoch_start_time),
                    format_elapsed(start_time),
                )
            )
        if epoch % 10 == 0:
            check_performance_and_save(parser, best_dev_fscore, best_dev_model_path,
                                       qb_dev_treebank, qb_embeddings_file, args)
def train_on_parses(args):
    """Train a DyNet span parser on trees stored in an experiment directory.

    Expects train_trees.txt / dev_trees.txt (and optionally
    additional_trees.txt) under args.experiment_directory, computes ELMo
    embeddings for each set unless args.no_elmo, and runs an endless epoch
    loop, mixing up to 100 random additional trees into every batch and
    checkpointing via check_performance_and_save.
    """
    args.model_path_base = os.path.join(args.experiment_directory, 'model')
    train_parses = load_parses(os.path.join(args.experiment_directory, 'train_trees.txt'))
    train_indices = list(range(len(train_parses)))
    if not args.no_elmo:
        train_tokenized_lines = parse_trees_to_string_lines(train_parses)
        train_embeddings = compute_elmo_embeddings(
            train_tokenized_lines,
            os.path.join(args.experiment_directory, 'train_embeddings'))
    else:
        train_embeddings = None
    dev_trees = trees.load_trees(os.path.join(args.experiment_directory, 'dev_trees.txt'))
    if not args.no_elmo:
        dev_tokenized_lines = parse_trees_to_string_lines(dev_trees)
        dev_embeddings = compute_elmo_embeddings(
            dev_tokenized_lines,
            os.path.join(args.experiment_directory, 'dev_embeddings'))
    else:
        dev_embeddings = None
    additional_trees_path = os.path.join(args.experiment_directory, 'additional_trees.txt')
    if os.path.exists(additional_trees_path):
        print('Training on', additional_trees_path)
        additional_train_trees = load_parses(additional_trees_path)
        additional_trees_indices = list(range(len(additional_train_trees)))
        if not args.no_elmo:
            additional_tokenized_lines = parse_trees_to_string_lines(additional_train_trees)
            additional_embeddings_file = compute_elmo_embeddings(
                additional_tokenized_lines,
                os.path.join(args.experiment_directory, 'additional_embeddings'))
        else:
            additional_embeddings_file = None
    else:
        print('No additional training trees.')
        additional_train_trees = []
        additional_trees_indices = []
        # NOTE(review): additional_embeddings_file is left unbound here; safe
        # only because additional_trees_indices is empty so the loop that
        # reads it never runs.
    parser, model = load_or_create_model(args, train_parses + additional_train_trees)
    trainer = dy.AdamTrainer(model)
    total_processed = 0
    current_processed = 0
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    start_time = time.time()
    # Endless training loop; periodic dev checks handle checkpointing.
    for epoch in itertools.count(start=1):
        np.random.shuffle(train_indices)
        epoch_start_time = time.time()
        for start_index in range(0, len(train_indices), args.batch_size):
            dy.renew_cg()  # fresh DyNet computation graph per batch
            batch_losses = []
            for tree_index in train_indices[start_index: start_index + args.batch_size]:
                tree = train_parses[tree_index]
                sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves]
                if train_embeddings is not None:
                    embeddings_np = train_embeddings[str(tree_index)][:, :, :]
                    assert embeddings_np.shape[1] == len(sentence), (
                        embeddings_np.shape, len(sentence), sentence)
                    embeddings = dy.inputTensor(embeddings_np)
                else:
                    embeddings = None
                loss = parser.span_parser(sentence, is_train=True, gold=tree,
                                          elmo_embeddings=embeddings)
                batch_losses.append(loss)
                total_processed += 1
                current_processed += 1
            # Mix in up to 100 random additional trees per batch.
            random.shuffle(additional_trees_indices)
            for tree_index in additional_trees_indices[:100]:
                tree = additional_train_trees[tree_index]
                sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves]
                if additional_embeddings_file is not None:
                    embeddings = dy.inputTensor(
                        additional_embeddings_file[str(tree_index)][:, :, :])
                else:
                    embeddings = None
                loss = parser.span_parser(sentence, is_train=True, gold=tree,
                                          elmo_embeddings=embeddings)
                batch_losses.append(loss)
            batch_loss = dy.average(batch_losses)
            batch_loss_value = batch_loss.scalar_value()
            batch_loss.backward()
            trainer.update()
            print(
                "epoch {:,} "
                "batch {:,}/{:,} "
                "processed {:,} "
                "batch-loss {:.4f} "
                "epoch-elapsed {} "
                "total-elapsed {}".format(
                    epoch,
                    start_index // args.batch_size,
                    int(np.ceil(len(train_indices) / args.batch_size)),
                    total_processed,
                    batch_loss_value,
                    format_elapsed(epoch_start_time),
                    format_elapsed(start_time),
                )
            )
        if epoch % int(args.num_epochs_per_check) == 0:
            check_performance_and_save(parser, best_dev_fscore, best_dev_model_path,
                                       dev_trees, dev_embeddings, args)
def run_test_qbank(args):
    """Evaluate a saved DyNet span parser on a Question Bank split.

    Loads the QB trees, selects the index range for args.split (Stanford
    split or the simple sequential split), parses each sentence using
    precomputed ELMo embeddings, and prints several EVALB score variants
    (unlabeled, flattened, regular).
    """
    if not os.path.exists(args.experiment_directory):
        os.mkdir(args.experiment_directory)
    print("Loading model from {}...".format(args.model_path_base))
    model = dy.ParameterCollection()
    # dy.load returns a list of saved objects; exactly one parser expected.
    [parser] = dy.load(args.model_path_base, model)
    all_trees = trees.load_trees(args.question_bank_trees_path)
    if args.stanford_split == 'true':
        print('using stanford split')
        split_to_indices = {
            'train': list(range(0, 1000)) + list(range(2000, 3000)),
            'dev': list(range(1000, 1500)) + list(range(3000, 3500)),
            'test': list(range(1500, 2000)) + list(range(3500, 4000))
        }
    else:
        print('not using stanford split')
        split_to_indices = {
            'train': range(0, 2000),
            'dev': range(2000, 3000),
            'test': range(3000, 4000)
        }
    test_indices = split_to_indices[args.split]
    # NOTE(review): embeddings path is hard-coded relative to cwd.
    qb_embeddings_file = h5py.File('../question-bank.hdf5', 'r')
    dev_predicted = []
    for test_index in test_indices:
        # Reset the DyNet graph every 100 sentences to bound memory use.
        if len(dev_predicted) % 100 == 0:
            dy.renew_cg()
        tree = all_trees[test_index]
        sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves]
        test_embeddings_np = qb_embeddings_file[str(test_index)][:, :, :]
        assert test_embeddings_np.shape[1] == len(sentence)
        test_embeddings = dy.inputTensor(test_embeddings_np)
        predicted, _ = parser.span_parser(sentence, is_train=False,
                                          elmo_embeddings=test_embeddings)
        dev_predicted.append(predicted.convert())
    test_treebank = [all_trees[index] for index in test_indices]
    # Score several ways: unlabeled, unlabeled+flattened, labeled+flattened,
    # and the regular labeled EVALB score.
    dev_fscore_without_labels = evaluate.evalb(args.evalb_dir, test_treebank, dev_predicted,
                                               args=args,
                                               erase_labels=True,
                                               name="without-labels")
    print("dev-fscore without labels", dev_fscore_without_labels)
    dev_fscore_without_labels = evaluate.evalb(args.evalb_dir, test_treebank, dev_predicted,
                                               args=args,
                                               erase_labels=True,
                                               flatten=True,
                                               name="without-label-flattened")
    print("dev-fscore without labels and flattened", dev_fscore_without_labels)
    dev_fscore_without_labels = evaluate.evalb(args.evalb_dir, test_treebank, dev_predicted,
                                               args=args,
                                               erase_labels=False,
                                               flatten=True,
                                               name="flattened")
    print("dev-fscore with labels and flattened", dev_fscore_without_labels)
    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, dev_predicted,
                                 args=args, name="regular")
    print("regular", test_fscore)
def run_train(args, hparams):
    """Train an NKChartParser (PyTorch) with warmup + plateau LR decay.

    Loads train/dev treebanks, builds tag/word/label/char vocabularies from
    the training parses, creates (or resumes) the parser and Adam optimizer,
    then trains until args.epochs is reached or dev F-score stops improving.
    Dev evaluation runs args.checks_per_epoch times per epoch; the best
    checkpoint is kept on disk and stale ones are deleted.
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)
    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)
    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()
    print("Loading training trees from {}...".format(args.train_path))
    if hparams.predict_tags and args.train_path.endswith('10way.clean'):
        print("WARNING: The data distributed with this repository contains "
              "predicted part-of-speech tags only (not gold tags!) We do not "
              "recommend enabling predict_tags in this configuration.")
    train_treebank = trees.load_trees(args.train_path)
    if hparams.max_len_train > 0:
        train_treebank = [tree for tree in train_treebank
                          if len(list(tree.leaves())) <= hparams.max_len_train]
    print("Loaded {:,} training examples.".format(len(train_treebank)))
    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    if hparams.max_len_dev > 0:
        dev_treebank = [tree for tree in dev_treebank
                        if len(list(tree.leaves())) <= hparams.max_len_dev]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))
    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]
    print("Constructing vocabularies...")
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(tokens.START)
    tag_vocab.index(tokens.STOP)
    tag_vocab.index(tokens.TAG_UNK)
    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(tokens.START)
    word_vocab.index(tokens.STOP)
    word_vocab.index(tokens.UNK)
    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())  # the empty label (no constituent)
    char_set = set()
    # Iterative depth-first traversal of every training parse to collect
    # labels, tags, words and characters.
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                char_set |= set(node.word)
    char_vocab = vocabulary.Vocabulary()
    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512
        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)
    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()

    def print_vocabulary(name, vocab):
        # Show special tokens first, then the regular entries.
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)
    print("Initializing model...")
    load_path = args.load_path
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict'])
    else:
        parser = parse_nk.NKChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            hparams,
        )
    print("Initializing optimizer...")
    trainable_parameters = [param for param in parser.parameters() if param.requires_grad]
    # lr=1. is a placeholder; the real rate is set by schedule_lr/set_lr below.
    trainer = torch.optim.Adam(trainable_parameters, lr=1., betas=(0.9, 0.98), eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        # Apply a new learning rate to every optimizer parameter group.
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"
    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    # Decay LR when dev F-score ('max' mode) plateaus.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer, 'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        # Linear warmup for the first learning_rate_warmup_steps iterations.
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm
    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None
    best_dev_processed = 0
    start_time = time.time()

    def check_dev():
        # Evaluate on the dev treebank; keep only the best checkpoint on disk.
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path
        nonlocal best_dev_processed
        dev_start_time = time.time()
        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index+args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                                  for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _
            dev_predicted.extend([p.convert() for p in predicted])
        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)
        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )
        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_dev_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)
            best_dev_fscore = dev_fscore.fscore
            best_dev_model_path = "{}_dev={:.2f}".format(
                args.model_path_base, dev_fscore.fscore)
            best_dev_processed = total_processed
            print("Saving new best model to {}...".format(best_dev_model_path))
            torch.save({
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                'trainer' : trainer.state_dict(),
            }, best_dev_model_path + ".pt")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break
        np.random.shuffle(train_parse)
        epoch_start_time = time.time()
        for start_index in tqdm(range(0, len(train_parse), args.batch_size)):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)
            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index + args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()]
                               for tree in batch_trees]
            batch_num_tokens = sum(len(sentence) for sentence in batch_sentences)
            # Split the batch into sub-batches that fit subbatch_max_tokens;
            # gradients accumulate across sub-batches before one step().
            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)
                if hparams.predict_tags:
                    # Separate normalizers: tree loss per tree, tag loss per token.
                    loss = loss[0] / len(batch_trees) + loss[1] / batch_num_tokens
                else:
                    loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)
            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)
            trainer.step()
            # (A per-batch progress print was here; removed as dead,
            # commented-out code — tqdm already reports progress.)
            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()
        # adjust learning rate at the end of an epoch
        if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
            scheduler.step(best_dev_fscore)
            # Stop once no dev improvement has been seen for the full
            # patience * max_consecutive_decays budget.
            if (total_processed - best_dev_processed) > ((hparams.step_decay_patience + 1)
                                                         * hparams.max_consecutive_decays
                                                         * len(train_parse)):
                print("Terminating due to lack of improvement in dev fscore.")
                print("The layer weights are: ", parser.weighted_layer.weight)
                break
def run_index(args):
    """Build and save a nearest-neighbor index of labelled span representations.

    Loads a trained NKChartParser checkpoint, extracts span representations
    for every training tree, inserts them into either a Faiss or an Annoy
    index (chosen by args.library), and saves the index under a prefix
    derived from the model path.
    """
    print("Saving span representations")
    print()
    print("Loading train trees from {}...".format(args.train_path))
    train_treebank = trees.load_trees(args.train_path)
    print("Loaded {:,} train examples.".format(len(train_treebank)))
    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(
        ".pt"), "Only pytorch savefiles supported"
    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_jc.NKChartParser.from_spec(info['spec'], info['state_dict'])
    # Optional ablations: drop the MLP and/or the ReLU from the span scorer.
    parser.no_mlp = args.no_mlp
    parser.no_relu = args.no_relu
    if args.no_relu:
        parser.remove_relu()
    print("Getting labelled span representations")
    start_time = time.time()
    if args.redo_vocab:
        # Rebuild the label vocabulary from the training trees instead of
        # using the one stored in the checkpoint.
        parser.label_vocab = gen_label_vocab(
            [tree.convert() for tree in train_treebank])
    num_labels = len(parser.label_vocab.values)
    """
    span_index = index.SpanIndex(
        num_indices = num_labels,
        library = args.library,
    )
    """
    span_index = (index.FaissIndex(num_labels=num_labels, metric=parser.metric)
                  if args.library == "faiss"
                  else index.AnnoyIndex(
                      num_indices=num_labels, metric=parser.metric))
    rep_time = time.time()
    span_reps, span_infos = index.get_span_reps_infos(parser, train_treebank,
                                                      args.batch_size)
    print(f"rep-time: {format_elapsed(rep_time)}")
    # clean up later, refactor back into index.py
    build_time = time.time()
    #use_gpu = True
    use_gpu = False
    print(f"Using gpu: {use_gpu}")
    if args.library == "faiss":
        # Faiss supports bulk insertion (and optional GPU placement).
        if use_gpu:
            span_index.to(0)
        span_index.add(span_reps, span_infos)
        span_index.build()
    else:
        # Annoy requires item-by-item insertion.
        for rep, info in zip(span_reps, span_infos):
            span_index.add_item(rep, info)
        span_index.build()
    #span_index.build()
    print(f"build-time {format_elapsed(build_time)}")
    if use_gpu:
        # Move the index back to CPU before saving.
        span_index.to(-1)
    save_time = time.time()
    prefix = index.get_index_prefix(
        index_base_path=args.index_path,
        full_model_path=args.model_path_base,
        nn_prefix=args.nn_prefix,
    )
    print(f"Saving index to {prefix}")
    span_index.save(prefix)
    print(f"save-time {format_elapsed(save_time)}")
    print(f"index-elapsed {format_elapsed(start_time)}")
import trees basePath = '../self-attentive-parser/models/de_elmo/' #de_char_1_1/' gold_treebank = trees.load_trees(basePath + 'gold.txt') predicted_treebank = trees.load_trees(basePath + 'predicted.txt') print("Loaded {:,} examples.".format(len(gold_treebank))) print("Loaded {:,} examples.".format(len(predicted_treebank))) for idx in range(len(predicted_treebank)): gold_sentences = [node.word for node in gold_treebank[idx].leaves()] predict_sentences = [ node.word for node in predicted_treebank[idx].leaves() ] if (len(gold_sentences) == len(predict_sentences)): for idx1 in range(len(gold_sentences)): if (gold_sentences[idx1] != predict_sentences[idx1]): print('Sentence ' + str(idx + 1) + ' fails! Not same') break else: print('Sentence ' + str(idx + 1) + ' fails! Not even same length') # If nothing printed then the sentence ordering is perfect failed_noun_count = 0 failed_verb_count = 0 total_noun_count = 0 total_verb_count = 0 for idx in range(len(predicted_treebank)): gold_tree_nodes = [gold_treebank[idx]]
def load_parses(file_path): print("Loading trees from {}...".format(file_path)) treebank = trees.load_trees(file_path) parses = [tree.convert() for tree in treebank] return parses
def run_train(args, hparams):
    """Train the joint constituency/dependency chart parser (Zparser).

    Reads dependency annotations with ``CoNLLXReader``, loads the matching
    constituency treebanks, builds tag/word/label/char/type vocabularies,
    then runs an Adam training loop with linear warmup and plateau-based
    step decay, checkpointing whenever dev F1 + dev LAS improves.

    ``args`` supplies data paths, batch sizes, ``checks_per_epoch``,
    ``epochs`` and ``model_path_base``; ``hparams`` supplies the model
    hyperparameters (dataset choice, length limits, LR schedule, ...).
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)
    # Make sure that pytorch is actually being initialized randomly: pytorch
    # seeds its RNG via std::random_device, which the C++ spec allows to be
    # deterministic, so derive the torch seed from numpy instead.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    # Select PTB or CTB data paths.
    train_path = args.train_ptb_path
    dev_path = args.dev_ptb_path
    dep_train_path = args.dep_train_ptb_path
    dep_dev_path = args.dep_dev_ptb_path
    if hparams.dataset == 'ctb':
        train_path = args.train_ctb_path
        dev_path = args.dev_ctb_path
        dep_train_path = args.dep_train_ctb_path
        dep_dev_path = args.dep_dev_ctb_path

    dep_reader = CoNLLXReader(dep_train_path)
    print('Reading dependency parsing data from %s' % dep_train_path)
    dep_dev_reader = CoNLLXReader(dep_dev_path)
    print('Reading dependency parsing data from %s' % dep_dev_path)

    # --- Read training dependency instances, skipping over-long sentences ---
    counter = 0
    dep_sentences = []
    dep_data = []
    dep_heads = []
    dep_types = []
    inst = dep_reader.getNext()
    while inst is not None:
        inst_size = inst.length()
        if hparams.max_len_train > 0 and inst_size - 1 > hparams.max_len_train:
            inst = dep_reader.getNext()
            continue
        counter += 1
        if counter % 10000 == 0:
            print("reading data: %d" % counter)
        sent = inst.sentence
        dep_data.append((sent.words, inst.postags, inst.heads, inst.types))
        dep_sentences.append(sent.words)
        dep_heads.append(inst.heads)
        dep_types.append(inst.types)
        inst = dep_reader.getNext()
    dep_reader.close()
    print("Total number of data: %d" % counter)

    # --- Read development dependency instances into fixed-capacity arrays ---
    # NOTE(review): capacity is hard-coded at 3000 sentences x 300 tokens;
    # a larger dev set would overflow these arrays.
    dep_dev_data = []
    dev_inst = dep_dev_reader.getNext()
    dep_dev_headid = np.zeros([3000, 300], dtype=int)
    dep_dev_type = []
    dep_dev_word = []
    dep_dev_pos = []
    dep_dev_lengs = np.zeros(3000, dtype=int)
    cun = 0
    while dev_inst is not None:
        inst_size = dev_inst.length()
        if hparams.max_len_dev > 0 and inst_size - 1 > hparams.max_len_dev:
            dev_inst = dep_dev_reader.getNext()
            continue
        dep_dev_lengs[cun] = inst_size
        sent = dev_inst.sentence
        dep_dev_data.append((sent.words, dev_inst.postags, dev_inst.heads,
                             dev_inst.types))
        for i in range(inst_size):
            dep_dev_headid[cun][i] = dev_inst.heads[i]
        dep_dev_type.append(dev_inst.types)
        dep_dev_word.append(sent.words)
        dep_dev_pos.append(sent.postags)
        dev_inst = dep_dev_reader.getNext()
        cun = cun + 1
    dep_dev_reader.close()

    print("Loading training trees from {}...".format(train_path))
    train_treebank = trees.load_trees(train_path, dep_heads, dep_types,
                                      dep_sentences)
    if hparams.max_len_train > 0:
        train_treebank = [tree for tree in train_treebank
                          if len(list(tree.leaves())) <= hparams.max_len_train]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(dev_path))
    dev_treebank = trees.load_trees(dev_path, dep_dev_headid, dep_dev_type,
                                    dep_dev_word)
    if hparams.max_len_dev > 0:
        dev_treebank = [tree for tree in dev_treebank
                        if len(list(tree.leaves())) <= hparams.max_len_dev]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]
    dev_parse = [tree.convert() for tree in dev_treebank]

    count_wh("train data:", train_parse, dep_heads, dep_types)
    count_wh("dev data:", dev_parse, dep_dev_headid, dep_dev_type)

    print("Constructing vocabularies...")
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(Zparser.START)
    tag_vocab.index(Zparser.STOP)
    tag_vocab.index(Zparser.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(Zparser.START)
    word_vocab.index(Zparser.STOP)
    word_vocab.index(Zparser.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())
    sublabels = [Zparser.Sub_Head]
    label_vocab.index(tuple(sublabels))

    type_vocab = vocabulary.Vocabulary()

    char_set = set()
    for i, tree in enumerate(train_parse):
        const_sentences = [leaf.word for leaf in tree.leaves()]
        # Constituency and dependency views must agree on tokenization.
        assert len(const_sentences) == len(dep_sentences[i])
        # Iterative DFS over the tree, collecting labels, tags, words and
        # dependency types.
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                if node.type is not Zparser.ROOT:  # not include root type
                    type_vocab.index(node.type)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                type_vocab.index(node.type)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()
    # If codepoints are small (e.g. Latin alphabet), index by codepoint
    # directly; this also takes care of constants like tokens.CHAR_PAD.
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()
    type_vocab.freeze()

    punctuation = hparams.punctuation
    punct_set = punctuation

    def print_vocabulary(name, vocab):
        # Special tokens first, then the remaining entries, both sorted.
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)
        print_vocabulary("Char", char_vocab)
        print_vocabulary("Type", type_vocab)

    print("Initializing model...")
    load_path = None
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = Zparser.ChartParser.from_spec(info['spec'],
                                               info['state_dict'])
    else:
        parser = Zparser.ChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            type_vocab,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [param for param in parser.parameters()
                            if param.requires_grad]
    # lr=1. is a placeholder: the effective rate is driven by schedule_lr()
    # during warmup and by the plateau scheduler afterwards.
    trainer = torch.optim.Adam(trainable_parameters, lr=1.,
                               betas=(0.9, 0.98), eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer, 'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        # Linear warmup for the first learning_rate_warmup_steps updates.
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    grad_clip_threshold = (np.inf if hparams.clip_grad_norm == 0
                           else hparams.clip_grad_norm)

    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_score = -np.inf
    best_model_path = None
    model_name = hparams.model_name
    print("This is ", model_name)
    start_time = time.time()

    def check_dev(epoch_num):
        """Evaluate constituency F1 and dependency UAS/LAS on dev and save a
        checkpoint whenever F1 + LAS improves."""
        nonlocal best_dev_score
        nonlocal best_model_path

        dev_start_time = time.time()
        parser.eval()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank),
                                     args.eval_batch_size):
            subbatch_trees = dev_treebank[
                dev_start_index:dev_start_index + args.eval_batch_size]
            subbatch_sentences = [
                [(leaf.tag, leaf.word) for leaf in tree.leaves()]
                for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            del _
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                    dev_predicted)
        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )

        # Dependency evaluation from the predicted leaves' head/type fields.
        dev_pred_head = [[leaf.father for leaf in tree.leaves()]
                         for tree in dev_predicted]
        dev_pred_type = [[leaf.type for leaf in tree.leaves()]
                         for tree in dev_predicted]
        assert len(dev_pred_head) == len(dev_pred_type)
        assert len(dev_pred_type) == len(dep_dev_type)
        stats, stats_nopunc, stats_root, num_inst = dep_eval.eval(
            len(dev_pred_head), dep_dev_word, dep_dev_pos, dev_pred_head,
            dev_pred_type, dep_dev_headid, dep_dev_type, dep_dev_lengs,
            punct_set=punct_set, symbolic_root=False)

        dev_ucorr, dev_lcorr, dev_total, dev_ucomlpete, dev_lcomplete = stats
        dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, \
            dev_ucomlpete_nopunc, dev_lcomplete_nopunc = stats_nopunc
        dev_root_corr, dev_total_root = stats_root
        dev_total_inst = num_inst
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr, dev_lcorr, dev_total,
                dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total,
                dev_ucomlpete * 100 / dev_total_inst,
                dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                dev_ucorr_nopunc * 100 / dev_total_nopunc,
                dev_lcorr_nopunc * 100 / dev_total_nopunc,
                dev_ucomlpete_nopunc * 100 / dev_total_inst,
                dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' % (
            dev_root_corr, dev_total_root,
            dev_root_corr * 100 / dev_total_root))

        dev_uas = dev_ucorr_nopunc * 100 / dev_total_nopunc
        dev_las = dev_lcorr_nopunc * 100 / dev_total_nopunc
        # Model selection criterion: constituency F1 + LAS.
        if dev_fscore.fscore + dev_las > best_dev_score:
            if best_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_score = dev_fscore.fscore + dev_las
            best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}".format(
                args.model_path_base, dev_fscore.fscore, dev_uas, dev_las)
            print("Saving new best model to {}...".format(best_model_path))
            # BUG FIX: this previously saved to the undefined name
            # ``besthh_model_path`` and would raise NameError on the first
            # improvement; save to best_model_path instead.
            torch.save({
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                'trainer': trainer.state_dict(),
            }, best_model_path + ".pt")

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)
            parser.train()

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index + args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word)
                                for leaf in tree.leaves()]
                               for tree in batch_trees]

            # Sub-batching keeps peak memory bounded by subbatch_max_tokens;
            # gradients accumulate across sub-batches before one step.
            for subbatch_sentences, subbatch_trees in parser.split_batch(
                    batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)
                loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters,
                                                       grad_clip_threshold)
            trainer.step()

            print(
                "epoch {:,} "
                "batch {:,}/{:,} "
                "processed {:,} "
                "batch-loss {:.4f} "
                "grad-norm {:.4f} "
                "epoch-elapsed {} "
                "total-elapsed {}".format(
                    epoch,
                    start_index // args.batch_size + 1,
                    int(np.ceil(len(train_parse) / args.batch_size)),
                    total_processed,
                    batch_loss_value,
                    grad_norm,
                    format_elapsed(epoch_start_time),
                    format_elapsed(start_time),
                )
            )

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev(epoch)

        # adjust learning rate at the end of an epoch
        if hparams.step_decay:
            if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
                scheduler.step(best_dev_score)
def test_on_parses(args):
    """Evaluate a saved dynet span parser on a treebank, using precomputed
    ELMo embeddings, and report POS accuracy plus several EVALB F-scores.

    Expects ``args`` to provide: experiment_directory, model_path_base,
    input_file, elmo_embeddings_file_path, plus whatever evaluate.evalb
    reads from ``args``.
    """
    if not os.path.exists(args.experiment_directory):
        os.mkdir(args.experiment_directory)
    model = dy.ParameterCollection()
    [parser] = dy.load(args.model_path_base, model)
    treebank = trees.load_trees(args.input_file,
                                strip_top=True,
                                filter_none=True)
    # Write the linearized gold parses alongside the experiment outputs.
    output = [tree.linearize() for tree in treebank]
    with open(os.path.join(args.experiment_directory, 'parses.txt'),
              'w') as f:
        f.write('\n'.join(output))
    # NOTE(review): this HDF5 handle is never closed; keys are assumed to be
    # stringified tree indices — confirm against the embedding dump script.
    sentence_embeddings = h5py.File(args.elmo_embeddings_file_path, 'r')
    test_predicted = []
    start_time = time.time()
    total_log_likelihood = 0
    total_confusion_matrix = {}  # never populated below; printed as-is
    total_turned_off = 0  # never incremented below; printed as-is
    ranks = []
    num_correct = 0
    total = 0
    for tree_index, tree in enumerate(treebank):
        if tree_index % 100 == 0:
            print(tree_index)
        dy.renew_cg()
        # NOTE(review): ``tree.leaves`` is accessed without calling — in this
        # trees API leaves is presumably a property/list; verify.
        sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves]
        elmo_embeddings_np = sentence_embeddings[str(tree_index)][:, :, :]
        # Embedding length along axis 1 must match the sentence length.
        assert elmo_embeddings_np.shape[1] == len(sentence), (
            elmo_embeddings_np.shape[1], len(sentence),
            [word for pos, word in sentence])
        elmo_embeddings = dy.inputTensor(elmo_embeddings_np)
        predicted, (additional_info, c, t) = parser.span_parser(
            sentence, is_train=False, elmo_embeddings=elmo_embeddings)
        num_correct += c
        total += t
        rank = additional_info[3]
        ranks.append(rank)
        total_log_likelihood += additional_info[-1]
        test_predicted.append(predicted.convert())
    print('pos accuracy', num_correct / total)
    print("total time", time.time() - start_time)
    print("total loglikelihood", total_log_likelihood)
    print("total turned off", total_turned_off)
    print(total_confusion_matrix)
    print(ranks)
    print("avg", np.mean(ranks), "median", np.median(ranks))
    # Four EVALB passes: unlabeled, unlabeled+flattened, labeled+flattened,
    # and the regular labeled score.
    dev_fscore_without_labels = evaluate.evalb('EVALB/',
                                               treebank,
                                               test_predicted,
                                               args=args,
                                               erase_labels=True,
                                               name="without-labels")
    print("dev-fscore without labels", dev_fscore_without_labels)
    dev_fscore_without_labels = evaluate.evalb('EVALB/',
                                               treebank,
                                               test_predicted,
                                               args=args,
                                               erase_labels=True,
                                               flatten=True,
                                               name="without-label-flattened")
    print("dev-fscore without labels and flattened",
          dev_fscore_without_labels)
    dev_fscore_without_labels = evaluate.evalb('EVALB/',
                                               treebank,
                                               test_predicted,
                                               args=args,
                                               erase_labels=False,
                                               flatten=True,
                                               name="flattened")
    print("dev-fscore with labels and flattened", dev_fscore_without_labels)
    test_fscore = evaluate.evalb('EVALB/',
                                 treebank,
                                 test_predicted,
                                 args=args,
                                 name="regular")
    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )
    with open(os.path.join(args.experiment_directory,
                           "confusion_matrix.pickle"), "wb") as f:
        pickle.dump(total_confusion_matrix, f)
def run_train(args, hparams):
    """Pre-train/fine-tune the joint multi-task model (constituency,
    dependency and SRL objectives combined with BERT-style pretraining).

    Handles single-GPU, multi-GPU (DataParallel) and distributed (apex DDP)
    setups, optional fp16, gradient accumulation, linear warmup, and
    periodic checkpointing every ``args.pre_step_tosave`` optimizer steps,
    with a final checkpoint written after training completes.
    """
    if args.seed is not None:
        print("Setting numpy random seed to {}...".format(args.seed))
        np.random.seed(args.seed)
    # Seed torch from numpy so a given args.seed reproduces the torch RNG.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    # Device / distributed setup.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs.
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    # Per-forward batch; the requested batch size is rebuilt via accumulation.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    if n_gpu > 0:
        torch.cuda.manual_seed_all(seed_from_numpy)

    print("Initializing model...")
    load_path = args.load_path
    if load_path is not None:
        # Resuming: hyperparameters come from the checkpoint, not from args.
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        model = Zmodel.Jointmodel.from_spec(info['spec'], info['state_dict'])
        hparams = model.hparams
        Ptb_dataset = PTBDataset(hparams)
        Ptb_dataset.process_PTB(args)
    else:
        hparams.set_from_args(args)
        Ptb_dataset = PTBDataset(hparams)
        Ptb_dataset.process_PTB(args)
        model = Zmodel.Jointmodel(
            Ptb_dataset.tag_vocab,
            Ptb_dataset.word_vocab,
            Ptb_dataset.label_vocab,
            Ptb_dataset.char_vocab,
            Ptb_dataset.type_vocab,
            Ptb_dataset.srl_vocab,
            hparams,
        )
    print("Hyperparameters:")
    hparams.print()

    print("Loading Train Dataset", args.train_file)
    Ptb_dataset.rand_dataset()
    train_dataset = BERTDataset(args.pre_wiki_line,
                                hparams,
                                Ptb_dataset,
                                args.train_file,
                                model.tokenizer,
                                seq_len=args.max_seq_length,
                                corpus_lines=None,
                                on_memory=args.on_memory)

    task_list = [
        'dev_synconst', 'dev_srlspan', 'dev_srldep', 'test_synconst',
        'test_srlspan', 'test_srldep', 'brown_srlspan', 'brown_srldep'
    ]
    # NOTE(review): device=1 is hard-coded for the evaluator while the model
    # uses the ``device`` chosen above — confirm GPU 1 is intentional.
    evaluator = EvalManyTask(device=1,
                             hparams=hparams,
                             ptb_dataset=Ptb_dataset,
                             task_list=task_list,
                             bert_tokenizer=model.tokenizer,
                             seq_len=args.eval_seq_length,
                             eval_batch_size=args.eval_batch_size,
                             evalb_dir=args.evalb_dir,
                             model_path_base=args.save_model_path_base,
                             log_path="{}_log".format("models_log/" +
                                                      hparams.model_name))

    num_train_steps = int(
        len(train_dataset) / args.train_batch_size /
        args.gradient_accumulation_steps * args.num_train_epochs)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: no weight decay for biases and LayerNorm weights.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay': 0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_steps)
    if load_path is not None:
        optimizer.load_state_dict(info['optimizer'])

    # pre_step allows fast-forwarding through already-trained steps when
    # resuming: batches before pre_step are read but not trained on.
    global_step = args.pre_step
    pre_step = args.pre_step

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", len(train_dataset))
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_steps)
    if args.local_rank == -1:
        train_sampler = RandomSampler(train_dataset)
    else:
        #TODO: check if this works with current data generator from disk that relies on file.__next__
        # (it doesn't return item back by index)
        train_sampler = DistributedSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    hparams.model_name = args.model_name
    print("This is ", hparams.model_name)
    start_time = time.time()

    def save_args(hparams):
        # Persist the hyperparameters next to the training log.
        arg_path = "{}_log".format("models_log/" +
                                   hparams.model_name) + '.arg.json'
        kwargs = hparams.to_dict()
        json.dump({'kwargs': kwargs}, open(arg_path, 'w'), indent=4)

    save_args(hparams)

    cur_ptb_epoch = 0
    for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        epoch_start_time = time.time()
        for step, batch in enumerate(train_dataloader):
            model.train()
            # Split the batch into BERT tensors and gold annotation strings.
            input_ids, origin_ids, input_mask, word_start_mask, word_end_mask, segment_ids, perm_mask, target_mapping, lm_label_ids, lm_label_mask, is_next, \
                synconst_list, syndep_head_list, syndep_type_list, srlspan_str_list, srldep_str_list, is_ptb = batch

            dis_idx = [i for i in range(len(input_ids))]
            dis_idx = torch.tensor(dis_idx)
            batch = dis_idx, input_ids, origin_ids, input_mask, word_start_mask, word_end_mask, segment_ids, perm_mask, target_mapping, lm_label_ids, lm_label_mask, is_next
            bert_data = tuple(t.to(device) for t in batch)

            # Decode the JSON-serialized gold structures for each sentence.
            sentences = []
            gold_syntree = []
            gold_srlspans = []
            gold_srldeps = []
            for synconst, syndep_head_str, syndep_type_str, srlspan_str, srldep_str in zip(
                    synconst_list, syndep_head_list, syndep_type_list,
                    srlspan_str_list, srldep_str_list):
                syndep_head = json.loads(syndep_head_str)
                syndep_type = json.loads(syndep_type_str)
                syntree = trees.load_trees(
                    synconst, [[int(head) for head in syndep_head]],
                    [syndep_type], strip_top=False)[0]
                sentences.append([(leaf.tag, leaf.word)
                                  for leaf in syntree.leaves()])
                gold_syntree.append(syntree.convert())

                # SRL spans: pred_id -> list of (start, end, role).
                srlspan = {}
                srlspan_dict = json.loads(srlspan_str)
                for pred_id, argus in srlspan_dict.items():
                    srlspan[int(pred_id)] = [(int(a[0]), int(a[1]), a[2])
                                             for a in argus]
                # SRL deps: pred_id -> list of (arg_id, role); a "-1" key
                # marks a sentence without dependency-style SRL annotation.
                srldep_dict = json.loads(srldep_str)
                srldep = {}
                if str(-1) in srldep_dict:
                    srldep = None
                else:
                    for pred_id, argus in srldep_dict.items():
                        srldep[int(pred_id)] = [(int(a[0]), a[1])
                                                for a in argus]
                gold_srlspans.append(srlspan)
                gold_srldeps.append(srldep)

            # Fast-forward already-trained steps when resuming.
            if global_step < pre_step:
                if global_step % 1000 == 0:
                    print("global_step:", global_step)
                    print("pre_step:", pre_step)
                    print("Wiki line:", train_dataset.wiki_line)
                    print("total-elapsed {} ".format(
                        format_elapsed(start_time)))
                global_step += 1
                cur_ptb_epoch = train_dataset.ptb_epoch
                continue

            bert_loss, task_loss = model(sentences=sentences,
                                         gold_trees=gold_syntree,
                                         gold_srlspans=gold_srlspans,
                                         gold_srldeps=gold_srldeps,
                                         bert_data=bert_data)
            if n_gpu > 1:
                # DataParallel returns one loss per replica; reduce them.
                bert_loss = bert_loss.sum()
                task_loss = task_loss.sum()
            loss = bert_loss + task_loss
            loss = loss / len(synconst_list)
            bert_loss = bert_loss / len(synconst_list)
            task_loss = task_loss / len(synconst_list)
            tatal_loss = float(loss.data.cpu().numpy())
            if bert_loss > 0:
                bert_loss = float(bert_loss.data.cpu().numpy())
            if task_loss > 0:
                task_loss = float(task_loss.data.cpu().numpy())

            lr_this_step = args.learning_rate * warmup_linear(
                global_step / num_train_steps, args.warmup_proportion)
            print("epoch {:,} "
                  "ptb-epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "PTB line {:,} "
                  "Wiki line {:,} "
                  "total-loss {:.4f} "
                  "bert-loss {:.4f} "
                  "task-loss {:.4f} "
                  "lr_this_step {:.12f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      cur_ptb_epoch,
                      global_step,
                      int(np.ceil(len(train_dataset) / args.train_batch_size)),
                      (global_step + 1) * args.train_batch_size,
                      train_dataset.ptb_cur_line,
                      train_dataset.wiki_line,
                      tatal_loss,
                      bert_loss,
                      task_loss,
                      lr_this_step,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()
            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # modify learning rate with special warm up BERT uses
                lr_this_step = args.learning_rate * warmup_linear(
                    global_step / num_train_steps, args.warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Periodic checkpoint.
                if global_step % args.pre_step_tosave == 0:
                    cur_ptb_epoch = train_dataset.ptb_epoch
                    save_path = "{}_gstep{}_wiki{}_loss={:.4f}.pt".\
                        format(args.save_model_path_base, global_step,
                               train_dataset.wiki_line, tatal_loss)
                    # Unwrap DataParallel/DDP before saving.
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    torch.save(
                        {
                            'spec': model_to_save.spec,
                            'state_dict': model_to_save.state_dict(),
                            'optimizer': optimizer.state_dict(),
                        }, save_path)

    # Save a trained model
    logger.info("** ** * Saving fine - tuned model ** ** * ")
    # BUG FIX: model_to_save was previously only bound inside the periodic
    # checkpoint branch above, so this final save raised NameError whenever
    # no periodic checkpoint had been written; bind it here explicitly.
    model_to_save = model.module if hasattr(model, 'module') else model
    torch.save(
        {
            'spec': model_to_save.spec,
            'state_dict': model_to_save.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, args.save_model_path_base + ".pt")
# %% if True: if parse_nk.use_cuda: info = torch.load(args.model_path_base) else: info = torch.load(args.model_path_base, map_location=lambda storage, location: storage) assert 'hparams' in info['spec'], "Older savefiles not supported" parser = parse_nk.NKChartParser.from_spec(info['spec'], info['state_dict']) bert_model = info['spec']['hparams']['bert_model'] bert_do_lower_case = info['spec']['hparams']['bert_do_lower_case'] #%% print("Loading test trees from {}...".format(args.test_path)) test_treebank = trees.load_trees(args.test_path) print("Loaded {:,} test examples.".format(len(test_treebank))) #%% import tensorflow as tf sess = tf.InteractiveSession() sd = parser.state_dict() LABEL_VOCAB = [x[0] for x in sorted(parser.label_vocab.indices.items(), key=lambda x: x[1])] TAG_VOCAB = [x[0] for x in sorted(parser.tag_vocab.indices.items(), key=lambda x: x[1])] # %%
def run_train(args):
    """Train the top-down constituency parser (pytorch port of an earlier
    dynet implementation — the commented-out dy.* calls mark the original).

    Builds tag/word/label vocabularies from the training trees, trains with
    Adam, and checkpoints whenever dev F-score improves.
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)
        torch.manual_seed(args.numpy_seed)

    print("Loading training trees from {}...".format(args.train_path))
    train_treebank = trees.load_trees(args.train_path)
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(args.dev_path))
    dev_treebank = trees.load_trees(args.dev_path)
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]

    print("Constructing vocabularies...")
    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(parse.START)
    tag_vocab.index(parse.STOP)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(parse.START)
    word_vocab.index(parse.STOP)
    word_vocab.index(parse.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())

    # Iterative DFS over each training tree collecting labels, tags, words.
    for tree in train_parse:
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()

    def print_vocabulary(name, vocab):
        # Special tokens first, then the rest, both sorted.
        special = {parse.START, parse.STOP, parse.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)

    print("Initializing model...")
    # model = dy.ParameterCollection()
    if args.parser_type == "top-down":
        parser = parse.TopDownParser(
            # model,
            tag_vocab,
            word_vocab,
            label_vocab,
            args.tag_embedding_dim,
            args.word_embedding_dim,
            args.lstm_layers,
            args.lstm_dim,
            args.label_hidden_dim,
            args.split_hidden_dim,
            args.dropout,
        )
    # else:
    #     parser = parse.ChartParser(
    #         model,
    #         tag_vocab,
    #         word_vocab,
    #         label_vocab,
    #         args.tag_embedding_dim,
    #         args.word_embedding_dim,
    #         args.lstm_layers,
    #         args.lstm_dim,
    #         args.label_hidden_dim,
    #         args.dropout,
    #     )
    # trainer = dy.AdamTrainer(model)
    optimizer = torch.optim.Adam(parser.parameters())

    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_fscore = -np.inf
    best_dev_model_path = None

    start_time = time.time()

    def check_dev():
        """Parse the dev set, report F-score, and checkpoint on improvement."""
        nonlocal best_dev_fscore
        nonlocal best_dev_model_path

        dev_start_time = time.time()

        dev_predicted = []
        for tree in dev_treebank:
            # dy.renew_cg()
            parser.eval()
            sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()]
            predicted, _ = parser.parse(sentence)
            dev_predicted.append(predicted.convert())

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank,
                                    dev_predicted)

        print("dev-fscore {} "
              "dev-elapsed {} "
              "total-elapsed {}".format(
                  dev_fscore,
                  format_elapsed(dev_start_time),
                  format_elapsed(start_time),
              ))

        if dev_fscore.fscore > best_dev_fscore:
            if best_dev_model_path is not None:
                # for ext in [".data", ".meta"]:
                #     path = best_dev_model_path + ext
                #     if os.path.exists(path):
                #         print("Removing previous model file {}...".format(path))
                #         os.remove(path)
                path = best_dev_model_path
                if os.path.exists(path):
                    print("Removing previous model file {}...".format(path))
                    os.remove(path)

            best_dev_fscore = dev_fscore.fscore
            # best_dev_model_path = "{}_dev={:.2f}".format(
            best_dev_model_path = "{}_dev={:.2f}.pth".format(
                args.model_path_base, dev_fscore.fscore)
            print("Saving new best model to {}...".format(best_dev_model_path))
            # dy.save(best_dev_model_path, [parser])
            # Saves the whole parser object, not just a state_dict.
            torch.save(parser, best_dev_model_path)

    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            # dy.renew_cg()
            optimizer.zero_grad()
            parser.train()

            batch_losses = []
            for tree in train_parse[start_index:start_index + args.batch_size]:
                sentence = [(leaf.tag, leaf.word) for leaf in tree.leaves()]
                if args.parser_type == "top-down":
                    _, loss = parser.parse(sentence, tree, args.explore)
                # else:
                #     _, loss = parser.parse(sentence, tree)
                batch_losses.append(loss)
                total_processed += 1
                current_processed += 1

            # batch_loss = dy.average(batch_losses)
            # batch_loss_value = batch_loss.scalar_value()
            batch_loss = torch.stack(batch_losses).mean()
            assert batch_loss.data.numel() == 1
            # NOTE(review): 0-d indexing works on the old torch this targets;
            # newer torch would need .item() here.
            batch_loss_value = batch_loss.data[0]
            batch_loss.backward()
            # trainer.update()
            optimizer.step()

            print("epoch {:,} "
                  "batch {:,}/{:,} "
                  "processed {:,} "
                  "batch-loss {:.4f} "
                  "epoch-elapsed {} "
                  "total-elapsed {}".format(
                      epoch,
                      start_index // args.batch_size + 1,
                      int(np.ceil(len(train_parse) / args.batch_size)),
                      total_processed,
                      batch_loss_value,
                      format_elapsed(epoch_start_time),
                      format_elapsed(start_time),
                  ))

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev()
'svm': SVM, 'random_forest': RandomForest, 'multilabel': MultiLabel, 'nn_multilabel': NeuralMultiLabel, 'neural': Neural, 'rnn': RNN, 'vote': VoteModel } if __name__ == '__main__': parser = argparse.ArgumentParser(prog=__package__) parser.add_argument('--model', choices=MODELS.keys(), default='multilabel') args = parser.parse_args() print('preprocessing') trees, max_edus = load_trees(TRAINING_DIR) vocab, samples = Vocabulary(trees), get_samples(trees) x_train, y_train, sents_idx = get_features(trees, samples, vocab, max_edus) print('training') model = MODELS[args.model](trees=trees, samples=samples, sents_idx=sents_idx, n_features=len(x_train[0]), models=[SGD, MultiLabel, RandomForest], num_classes=len(ACTIONS), hidden_size=256, batch_size=1024, epochs=100, lr=1e-4, w_decay=1e-5)