def run_test(args):
    """Evaluate a saved joint constituency/dependency parser on the test set.

    Loads the checkpoint at ``args.model_path_base`` (only ``.pt`` files are
    accepted), reads the gold dependency test data with ``CoNLLXReader`` and the
    gold constituency trees with ``trees.load_trees``. It then parses every test
    sentence in batches of ``args.eval_batch_size`` and prints:

    * the EVALB result from ``evaluate.evalb``;
    * dependency metrics from ``dep_eval.eval``, with and without punctuation;
    * root accuracy.

    The PTB paths are used unless ``args.dataset == 'ctb'``, in which case the
    CTB paths are used.

    Args:
        args: command-line namespace. Attributes read here:
            ``consttest_ptb_path``, ``deptest_ptb_path``, ``consttest_ctb_path``,
            ``deptest_ctb_path``, ``dataset``, ``model_path_base``,
            ``eval_batch_size`` and ``evalb_dir``.
    """
    const_test_path = args.consttest_ptb_path
    dep_test_path = args.deptest_ptb_path
    if args.dataset == 'ctb':
        const_test_path = args.consttest_ctb_path
        dep_test_path = args.deptest_ctb_path

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch savefiles supported"

    info = torch_load(args.model_path_base)
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])
    parser.eval()

    dep_test_reader = CoNLLXReader(dep_test_path, parser.type_vocab)
    print('Reading dependency parsing data from %s' % dep_test_path)
    test_lengths = []      # inst.length() per sentence
    test_gold_heads = []   # first inst.length() gold heads per sentence
    dep_test_type = []
    dep_test_word = []
    dep_test_pos = []
    test_inst = dep_test_reader.getNext()
    while test_inst is not None:
        inst_size = test_inst.length()
        sent = test_inst.sentence
        test_lengths.append(inst_size)
        test_gold_heads.append([test_inst.heads[i] for i in range(inst_size)])
        dep_test_type.append(test_inst.types)
        dep_test_word.append(sent.words)
        dep_test_pos.append(sent.postags)
        test_inst = dep_test_reader.getNext()
    dep_test_reader.close()

    # Copy gold heads/lengths into zero-padded numpy buffers. The historical
    # fixed size (40000 rows x 300 columns) is kept as a minimum so the shapes
    # passed to trees.load_trees / dep_eval.eval are unchanged for normal data.
    # The buffers grow instead of raising IndexError on larger inputs.
    num_rows = max(40000, len(test_lengths))
    num_cols = max([300] + test_lengths)
    dep_test_headid = np.zeros([num_rows, num_cols], dtype=int)
    dep_test_lengs = np.zeros(num_rows, dtype=int)
    for row, (inst_size, heads) in enumerate(zip(test_lengths, test_gold_heads)):
        dep_test_lengs[row] = inst_size
        dep_test_headid[row, :inst_size] = heads

    print("Loading test trees from {}...".format(const_test_path))
    test_treebank = trees.load_trees(const_test_path, dep_test_headid, dep_test_type, dep_test_word)
    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Parsing test sentences...")
    start_time = time.time()

    # Implicit string-literal concatenation: this is the single string ".``'':,".
    punct_set = '.' '``' "''" ':' ','

    test_predicted = []
    for start_index in range(0, len(test_treebank), args.eval_batch_size):
        subbatch_trees = test_treebank[start_index:start_index + args.eval_batch_size]
        subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]
        predicted, _ = parser.parse_batch(subbatch_sentences)
        test_predicted.extend([p.convert() for p in predicted])

    test_fscore = evaluate.evalb(args.evalb_dir, test_treebank, test_predicted)

    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )

    # Predicted dependency heads/types are read off the leaves of the predicted trees.
    test_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in test_predicted]
    test_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in test_predicted]
    assert len(test_pred_head) == len(test_pred_type)
    assert len(test_pred_type) == len(dep_test_type)
    stats, stats_nopunc, stats_root, test_total_inst = dep_eval.eval(
        len(test_pred_head), dep_test_word, dep_test_pos, test_pred_head,
        test_pred_type, dep_test_headid, dep_test_type, dep_test_lengs,
        punct_set=punct_set, symbolic_root=False)

    test_ucorrect, test_lcorrect, test_total, test_ucomlpete_match, test_lcomplete_match = stats
    test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc, test_ucomlpete_match_nopunc, test_lcomplete_match_nopunc = stats_nopunc
    test_root_correct, test_total_root = stats_root

    print(
        'best test W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
            test_ucorrect, test_lcorrect, test_total,
            test_ucorrect * 100 / test_total,
            test_lcorrect * 100 / test_total,
            test_ucomlpete_match * 100 / test_total_inst,
            test_lcomplete_match * 100 / test_total_inst
        ))
    print(
        'best test Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%% ' % (
            test_ucorrect_nopunc, test_lcorrect_nopunc, test_total_nopunc,
            test_ucorrect_nopunc * 100 / test_total_nopunc,
            test_lcorrect_nopunc * 100 / test_total_nopunc,
            test_ucomlpete_match_nopunc * 100 / test_total_inst,
            test_lcomplete_match_nopunc * 100 / test_total_inst))
    print('best test Root: corr: %d, total: %d, acc: %.2f%%' % (
        test_root_correct, test_total_root, test_root_correct * 100 / test_total_root))
    print(
        '============================================================================================================================')
def run_train(args, hparams):
    """Train the joint constituency/dependency chart parser.

    Steps:

    1. Seed numpy, then seed pytorch from numpy.
    2. Apply ``hparams.set_from_args(args)``.
    3. Load the dependency (``CoNLLXReader``) and constituency
       (``trees.load_trees``) train/dev data. PTB paths are used, or CTB paths
       when ``hparams.dataset == 'ctb'``. Sentences longer than
       ``max_len_train`` / ``max_len_dev`` are dropped when those limits are > 0.
    4. Build the tag/word/label/char/type vocabularies from the training trees.
    5. Train a ``Zparser.ChartParser`` with Adam, linear warmup and
       ``ReduceLROnPlateau`` step decay.

    Every ``len(train_parse) / args.checks_per_epoch`` training trees, the dev
    set is evaluated. When ``EVALB F-score + no-punct LAS`` improves, the
    previous best ``.pt`` file is removed and a new checkpoint is saved.
    Training stops after ``args.epochs`` epochs; if that is None, it never stops.

    Args:
        args: command-line namespace (data paths, batch sizes, epochs, ...).
        hparams: hyperparameter container, updated from ``args`` before use.
    """
    if args.numpy_seed is not None:
        print("Setting numpy random seed to {}...".format(args.numpy_seed))
        np.random.seed(args.numpy_seed)

    # Make sure that pytorch is actually being initialized randomly.
    # On my cluster I was getting highly correlated results from multiple
    # runs, but calling reset_parameters() changed that. A brief look at the
    # pytorch source code revealed that pytorch initializes its RNG by
    # calling std::random_device, which according to the C++ spec is allowed
    # to be deterministic.
    seed_from_numpy = np.random.randint(2147483648)
    print("Manual seed for pytorch:", seed_from_numpy)
    torch.manual_seed(seed_from_numpy)

    hparams.set_from_args(args)
    print("Hyperparameters:")
    hparams.print()

    train_path = args.train_ptb_path
    dev_path = args.dev_ptb_path
    dep_train_path = args.dep_train_ptb_path
    dep_dev_path = args.dep_dev_ptb_path
    if hparams.dataset == 'ctb':
        train_path = args.train_ctb_path
        dev_path = args.dev_ctb_path
        dep_train_path = args.dep_train_ctb_path
        dep_dev_path = args.dep_dev_ctb_path

    dep_reader = CoNLLXReader(dep_train_path)
    print('Reading dependency parsing data from %s' % dep_train_path)
    dep_dev_reader = CoNLLXReader(dep_dev_path)
    print('Reading dependency parsing data from %s' % dep_dev_path)

    # ---- dependency training data -------------------------------------------
    counter = 0
    dep_sentences = []
    dep_heads = []
    dep_types = []
    inst = dep_reader.getNext()
    while inst is not None:
        inst_size = inst.length()
        if hparams.max_len_train > 0 and inst_size - 1 > hparams.max_len_train:
            inst = dep_reader.getNext()
            continue
        counter += 1
        if counter % 10000 == 0:
            print("reading data: %d" % counter)
        sent = inst.sentence
        dep_sentences.append(sent.words)
        dep_heads.append(inst.heads)
        dep_types.append(inst.types)
        inst = dep_reader.getNext()
    dep_reader.close()
    print("Total number of data: %d" % counter)

    # ---- dependency development data ----------------------------------------
    dev_lengths = []      # inst.length() per kept dev sentence
    dev_gold_heads = []   # first inst.length() gold heads per kept dev sentence
    dep_dev_type = []
    dep_dev_word = []
    dep_dev_pos = []
    dev_inst = dep_dev_reader.getNext()
    while dev_inst is not None:
        inst_size = dev_inst.length()
        if hparams.max_len_dev > 0 and inst_size - 1 > hparams.max_len_dev:
            dev_inst = dep_dev_reader.getNext()
            continue
        sent = dev_inst.sentence
        dev_lengths.append(inst_size)
        dev_gold_heads.append([dev_inst.heads[i] for i in range(inst_size)])
        dep_dev_type.append(dev_inst.types)
        dep_dev_word.append(sent.words)
        dep_dev_pos.append(sent.postags)
        dev_inst = dep_dev_reader.getNext()
    dep_dev_reader.close()

    # Copy dev gold heads/lengths into zero-padded numpy buffers. The historical
    # fixed size (3000 rows x 300 columns) is kept as a minimum so the shapes
    # seen by trees.load_trees / count_wh / dep_eval.eval are unchanged for
    # normal data. The buffers grow instead of raising IndexError on larger dev sets.
    num_rows = max(3000, len(dev_lengths))
    num_cols = max([300] + dev_lengths)
    dep_dev_headid = np.zeros([num_rows, num_cols], dtype=int)
    dep_dev_lengs = np.zeros(num_rows, dtype=int)
    for row, (inst_size, heads) in enumerate(zip(dev_lengths, dev_gold_heads)):
        dep_dev_lengs[row] = inst_size
        dep_dev_headid[row, :inst_size] = heads

    # ---- constituency trees -------------------------------------------------
    print("Loading training trees from {}...".format(train_path))
    train_treebank = trees.load_trees(train_path, dep_heads, dep_types, dep_sentences)
    if hparams.max_len_train > 0:
        train_treebank = [tree for tree in train_treebank if len(list(tree.leaves())) <= hparams.max_len_train]
    print("Loaded {:,} training examples.".format(len(train_treebank)))

    print("Loading development trees from {}...".format(dev_path))
    dev_treebank = trees.load_trees(dev_path, dep_dev_headid, dep_dev_type, dep_dev_word)
    if hparams.max_len_dev > 0:
        dev_treebank = [tree for tree in dev_treebank if len(list(tree.leaves())) <= hparams.max_len_dev]
    print("Loaded {:,} development examples.".format(len(dev_treebank)))

    print("Processing trees for training...")
    train_parse = [tree.convert() for tree in train_treebank]
    dev_parse = [tree.convert() for tree in dev_treebank]
    count_wh("train data:", train_parse, dep_heads, dep_types)
    count_wh("dev data:", dev_parse, dep_dev_headid, dep_dev_type)

    # ---- vocabularies -------------------------------------------------------
    print("Constructing vocabularies...")

    tag_vocab = vocabulary.Vocabulary()
    tag_vocab.index(Zparser.START)
    tag_vocab.index(Zparser.STOP)
    tag_vocab.index(Zparser.TAG_UNK)

    word_vocab = vocabulary.Vocabulary()
    word_vocab.index(Zparser.START)
    word_vocab.index(Zparser.STOP)
    word_vocab.index(Zparser.UNK)

    label_vocab = vocabulary.Vocabulary()
    label_vocab.index(())
    sublabels = [Zparser.Sub_Head]
    label_vocab.index(tuple(sublabels))

    type_vocab = vocabulary.Vocabulary()

    char_set = set()

    # train_parse is still in file order here (it is shuffled later), so index
    # i lines up with dep_sentences[i].
    for i, tree in enumerate(train_parse):
        const_sentences = [leaf.word for leaf in tree.leaves()]
        assert len(const_sentences) == len(dep_sentences[i])
        nodes = [tree]
        while nodes:
            node = nodes.pop()
            if isinstance(node, trees.InternalParseNode):
                label_vocab.index(node.label)
                # Do not include the root type. Compare by value: an identity
                # check (`is not`) misses equal strings that are distinct objects.
                if node.type != Zparser.ROOT:
                    type_vocab.index(node.type)
                nodes.extend(reversed(node.children))
            else:
                tag_vocab.index(node.tag)
                word_vocab.index(node.word)
                type_vocab.index(node.type)
                char_set |= set(node.word)

    char_vocab = vocabulary.Vocabulary()

    # If codepoints are small (e.g. Latin alphabet), index by codepoint directly
    highest_codepoint = max(ord(char) for char in char_set)
    if highest_codepoint < 512:
        if highest_codepoint < 256:
            highest_codepoint = 256
        else:
            highest_codepoint = 512

        # This also takes care of constants like tokens.CHAR_PAD
        for codepoint in range(highest_codepoint):
            char_index = char_vocab.index(chr(codepoint))
            assert char_index == codepoint
    else:
        char_vocab.index(tokens.CHAR_UNK)
        char_vocab.index(tokens.CHAR_START_SENTENCE)
        char_vocab.index(tokens.CHAR_START_WORD)
        char_vocab.index(tokens.CHAR_STOP_WORD)
        char_vocab.index(tokens.CHAR_STOP_SENTENCE)
        for char in sorted(char_set):
            char_vocab.index(char)

    tag_vocab.freeze()
    word_vocab.freeze()
    label_vocab.freeze()
    char_vocab.freeze()
    type_vocab.freeze()

    punct_set = hparams.punctuation

    def print_vocabulary(name, vocab):
        """Print a vocabulary's size and values, special tokens listed first."""
        special = {tokens.START, tokens.STOP, tokens.UNK}
        print("{} ({:,}): {}".format(
            name, vocab.size,
            sorted(value for value in vocab.values if value in special) +
            sorted(value for value in vocab.values if value not in special)))

    if args.print_vocabs:
        print_vocabulary("Tag", tag_vocab)
        print_vocabulary("Word", word_vocab)
        print_vocabulary("Label", label_vocab)
        print_vocabulary("Char", char_vocab)
        print_vocabulary("Type", type_vocab)

    # ---- model and optimizer ------------------------------------------------
    print("Initializing model...")

    # Set to a checkpoint path to resume from saved parameters/optimizer state.
    load_path = None
    if load_path is not None:
        print(f"Loading parameters from {load_path}")
        info = torch_load(load_path)
        parser = Zparser.ChartParser.from_spec(info['spec'], info['state_dict'])
    else:
        parser = Zparser.ChartParser(
            tag_vocab,
            word_vocab,
            label_vocab,
            char_vocab,
            type_vocab,
            hparams,
        )

    print("Initializing optimizer...")
    trainable_parameters = [param for param in parser.parameters() if param.requires_grad]
    # lr=1. is a placeholder; schedule_lr overwrites it during warmup.
    trainer = torch.optim.Adam(trainable_parameters, lr=1., betas=(0.9, 0.98), eps=1e-9)
    if load_path is not None:
        trainer.load_state_dict(info['trainer'])

    def set_lr(new_lr):
        """Set the learning rate of every optimizer parameter group."""
        for param_group in trainer.param_groups:
            param_group['lr'] = new_lr

    assert hparams.step_decay, "Only step_decay schedule is supported"

    warmup_coeff = hparams.learning_rate / hparams.learning_rate_warmup_steps
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        trainer, 'max',
        factor=hparams.step_decay_factor,
        patience=hparams.step_decay_patience,
        verbose=True,
    )

    def schedule_lr(iteration):
        """Linear warmup: lr = (iteration + 1) * warmup_coeff until warmup ends."""
        iteration = iteration + 1
        if iteration <= hparams.learning_rate_warmup_steps:
            set_lr(iteration * warmup_coeff)

    clippable_parameters = trainable_parameters
    # clip_grad_norm == 0 disables clipping (threshold of infinity).
    grad_clip_threshold = np.inf if hparams.clip_grad_norm == 0 else hparams.clip_grad_norm

    print("Training...")
    total_processed = 0
    current_processed = 0
    check_every = len(train_parse) / args.checks_per_epoch
    best_dev_score = -np.inf
    best_model_path = None
    model_name = hparams.model_name

    print("This is ", model_name)
    start_time = time.time()

    def check_dev(epoch_num):
        """Evaluate on the dev set and save a checkpoint if the score improves.

        The model-selection score is EVALB F-score plus no-punctuation LAS.
        """
        nonlocal best_dev_score
        nonlocal best_model_path

        dev_start_time = time.time()

        parser.eval()

        dev_predicted = []
        for dev_start_index in range(0, len(dev_treebank), args.eval_batch_size):
            subbatch_trees = dev_treebank[dev_start_index:dev_start_index + args.eval_batch_size]
            subbatch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in subbatch_trees]
            predicted, _ = parser.parse_batch(subbatch_sentences)
            dev_predicted.extend([p.convert() for p in predicted])

        dev_fscore = evaluate.evalb(args.evalb_dir, dev_treebank, dev_predicted)

        print(
            "dev-fscore {} "
            "dev-elapsed {} "
            "total-elapsed {}".format(
                dev_fscore,
                format_elapsed(dev_start_time),
                format_elapsed(start_time),
            )
        )

        dev_pred_head = [[leaf.father for leaf in tree.leaves()] for tree in dev_predicted]
        dev_pred_type = [[leaf.type for leaf in tree.leaves()] for tree in dev_predicted]
        assert len(dev_pred_head) == len(dev_pred_type)
        assert len(dev_pred_type) == len(dep_dev_type)
        stats, stats_nopunc, stats_root, num_inst = dep_eval.eval(
            len(dev_pred_head), dep_dev_word, dep_dev_pos, dev_pred_head,
            dev_pred_type, dep_dev_headid, dep_dev_type, dep_dev_lengs,
            punct_set=punct_set, symbolic_root=False)
        dev_ucorr, dev_lcorr, dev_total, dev_ucomlpete, dev_lcomplete = stats
        dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc, dev_ucomlpete_nopunc, dev_lcomplete_nopunc = stats_nopunc
        dev_root_corr, dev_total_root = stats_root
        dev_total_inst = num_inst
        print(
            'W. Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr, dev_lcorr, dev_total,
                dev_ucorr * 100 / dev_total, dev_lcorr * 100 / dev_total,
                dev_ucomlpete * 100 / dev_total_inst, dev_lcomplete * 100 / dev_total_inst))
        print(
            'Wo Punct: ucorr: %d, lcorr: %d, total: %d, uas: %.2f%%, las: %.2f%%, ucm: %.2f%%, lcm: %.2f%%' % (
                dev_ucorr_nopunc, dev_lcorr_nopunc, dev_total_nopunc,
                dev_ucorr_nopunc * 100 / dev_total_nopunc,
                dev_lcorr_nopunc * 100 / dev_total_nopunc,
                dev_ucomlpete_nopunc * 100 / dev_total_inst,
                dev_lcomplete_nopunc * 100 / dev_total_inst))
        print('Root: corr: %d, total: %d, acc: %.2f%%' % (
            dev_root_corr, dev_total_root, dev_root_corr * 100 / dev_total_root))
        dev_uas = dev_ucorr_nopunc * 100 / dev_total_nopunc
        dev_las = dev_lcorr_nopunc * 100 / dev_total_nopunc

        if dev_fscore.fscore + dev_las > best_dev_score:
            if best_model_path is not None:
                extensions = [".pt"]
                for ext in extensions:
                    path = best_model_path + ext
                    if os.path.exists(path):
                        print("Removing previous model file {}...".format(path))
                        os.remove(path)

            best_dev_score = dev_fscore.fscore + dev_las
            best_model_path = "{}_best_dev={:.2f}_devuas={:.2f}_devlas={:.2f}".format(
                args.model_path_base, dev_fscore.fscore, dev_uas, dev_las)
            print("Saving new best model to {}...".format(best_model_path))
            # Fixed: this previously referenced the undefined name
            # `besthh_model_path`, raising NameError so no model was ever saved.
            torch.save({
                'spec': parser.spec,
                'state_dict': parser.state_dict(),
                'trainer': trainer.state_dict(),
            }, best_model_path + ".pt")

    # ---- training loop ------------------------------------------------------
    for epoch in itertools.count(start=1):
        if args.epochs is not None and epoch > args.epochs:
            break

        np.random.shuffle(train_parse)
        epoch_start_time = time.time()

        for start_index in range(0, len(train_parse), args.batch_size):
            trainer.zero_grad()
            schedule_lr(total_processed // args.batch_size)

            parser.train()

            batch_loss_value = 0.0
            batch_trees = train_parse[start_index:start_index + args.batch_size]
            batch_sentences = [[(leaf.tag, leaf.word) for leaf in tree.leaves()] for tree in batch_trees]
            # Gradients accumulate over sub-batches; dividing by the full batch
            # size makes the summed gradient correspond to the whole batch.
            for subbatch_sentences, subbatch_trees in parser.split_batch(batch_sentences, batch_trees, args.subbatch_max_tokens):
                _, loss = parser.parse_batch(subbatch_sentences, subbatch_trees)

                loss = loss / len(batch_trees)
                loss_value = float(loss.data.cpu().numpy())
                batch_loss_value += loss_value
                if loss_value > 0:
                    loss.backward()
                del loss
                total_processed += len(subbatch_trees)
                current_processed += len(subbatch_trees)

            grad_norm = torch.nn.utils.clip_grad_norm_(clippable_parameters, grad_clip_threshold)

            trainer.step()

            print(
                "epoch {:,} "
                "batch {:,}/{:,} "
                "processed {:,} "
                "batch-loss {:.4f} "
                "grad-norm {:.4f} "
                "epoch-elapsed {} "
                "total-elapsed {}".format(
                    epoch,
                    start_index // args.batch_size + 1,
                    int(np.ceil(len(train_parse) / args.batch_size)),
                    total_processed,
                    batch_loss_value,
                    grad_norm,
                    format_elapsed(epoch_start_time),
                    format_elapsed(start_time),
                )
            )

            if current_processed >= check_every:
                current_processed -= check_every
                check_dev(epoch)

        # adjust learning rate at the end of an epoch (once warmup is over)
        if hparams.step_decay:
            if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
                scheduler.step(best_dev_score)