def rerank(fm, args, network=None):
    """Rerank k-best candidate lists with a network and report F-scores.

    For every sentence, each candidate of its k-best list is scored with
    ``force_decoding(network, candidate, fm)`` and the highest-scoring
    candidate (first one on ties, via ``np.argmax``) is kept.  Two F-scores
    against the gold trees are printed: the reranked choices, then the
    baseline that always takes the first candidate of each list.

    Args:
        fm: feature mapper, forwarded to ``PhraseTree.load_kbests`` and
            ``force_decoding``.
        args: options object; reads ``args.gold`` and ``args.kbest``, and
            ``args.model`` only when ``network`` is None.
        network: an already-loaded model; when None it is loaded with
            ``torch.load(args.model)``.

    Returns:
        The FScore accumulated over the reranked selections.
    """
    if network is None:
        network = torch.load(args.model)
        print('Loaded model from: {}'.format(args.model))

    gold = PhraseTree.load_trees(args.gold)
    kbest = PhraseTree.load_kbests(args.kbest, fm)

    res = []
    print('reranking')
    # zip() is kept so that selections stop at the shorter of kbest/gold.
    for candidates, _unused_gold in zip(kbest, gold):
        scores = np.zeros(len(candidates), dtype='float32')
        for idx, candidate in enumerate(candidates):
            scores[idx] = force_decoding(network, candidate, fm)
        res.append(candidates[np.argmax(scores)])

    accuracy = FScore()
    baseline = FScore()
    for candidates, gold_tree in zip(kbest, gold):
        baseline += candidates[0]['tree'].compare(gold_tree)
    for chosen, gold_tree in zip(res, gold):
        accuracy += chosen['tree'].compare(gold_tree)

    print(accuracy)
    print(baseline)
    return accuracy
def test(fm, args):
    """Evaluate a saved model on a test treebank and print its accuracy.

    Reads the tree file ``args.test`` and the model path ``args.model``
    (loaded with ``torch.load``), then prints the result of
    ``Parser.evaluate_corpus`` on those trees.
    """
    trees_path = args.test
    test_trees = PhraseTree.load_trees(trees_path)
    print('Loaded test trees from {}'.format(trees_path))

    model_path = args.model
    network = torch.load(model_path)
    print('Loaded model from: {}'.format(model_path))

    print('Accuracy: {}'.format(
        Parser.evaluate_corpus(test_trees, fm, network)
    ))
def vocab_init(tree_file, verbose=True):
    """Build word/tag/label index dictionaries from a treebank file.

    Every tree in ``tree_file`` contributes its ``(word, tag)`` pairs and
    the labels of its gold ``'label-...'`` actions.

    Args:
        tree_file: path handed to ``PhraseTree.load_trees``.
        verbose: when true, print per-tree progress and a summary.

    Returns:
        A dict with:
          * ``'wdict'``: OrderedDict word -> index.  ``Vocab.UNK``,
            ``Vocab.START`` and ``Vocab.STOP`` come first, then the sorted
            words.
          * ``'word_freq'``: defaultdict(int) of word counts.
          * ``'tdict'``: OrderedDict tag -> index, with the same three
            special entries first.
          * ``'ldict'``: OrderedDict label -> index over the sorted labels.
    """
    def _index(items):
        # Map each item to its position; a repeated item keeps its last one.
        return OrderedDict((item, pos) for pos, item in enumerate(items))

    word_freq = defaultdict(int)
    tag_freq = defaultdict(int)
    label_freq = defaultdict(int)
    prefix = 'label-'

    for n, tree in enumerate(PhraseTree.load_trees(tree_file)):
        for word, tag in tree.sentence:
            word_freq[word] += 1
            tag_freq[tag] += 1
        for action in Parser.gold_actions(tree):
            if action.startswith(prefix):
                label_freq[action[len(prefix):]] += 1
        if verbose:
            print('\rTree {}'.format(n), end='')
            sys.stdout.flush()

    if verbose:
        print('\r', end='')

    specials = [Vocab.UNK, Vocab.START, Vocab.STOP]
    wdict = _index(specials + sorted(word_freq))
    tdict = _index(specials + sorted(tag_freq))
    ldict = _index(sorted(label_freq))

    if verbose:
        print('Loading features from {}'.format(tree_file))
        print('({} words, {} tags, {} nonterminal-chains)'.format(
            len(wdict), len(tdict), len(ldict),
        ))

    return {
        'wdict': wdict,
        'word_freq': word_freq,
        'tdict': tdict,
        'ldict': ldict,
    }
def gold_data_from_file(self, fname):
    """Run the static oracle over every tree in a file.

    Loads the trees in ``fname`` with ``PhraseTree.load_trees`` and returns
    a list holding ``self.gold_data(tree)`` for each one, in file order.
    """
    return [self.gold_data(tree) for tree in PhraseTree.load_trees(fname)]
def write_raw_predicted(fname, sentences, fm, network):
    """Parse raw sentences and write the predictions to a file.

    Each sentence is parsed with ``Parser.parse``.  The prediction is wrapped
    in a ``'TOP'`` PhraseTree and written as ``str(tree)``, one per line.

    Args:
        fname: output path, opened for writing (truncated).
        sentences: iterable of sentences accepted by ``Parser.parse``.
        fm: feature mapper passed to ``Parser.parse``.
        network: model passed to ``Parser.parse``.
    """
    # The context manager closes the file even if parsing raises midway.
    with open(fname, 'w') as f:
        for sentence in sentences:
            predicted = Parser.parse(sentence, fm, network)
            topped = PhraseTree(
                symbol='TOP',
                children=[predicted],
                sentence=predicted.sentence,
            )
            f.write(str(topped))
            f.write('\n')
def write_predicted(fname, trees, fm, network):
    """Parse the sentences of reference trees, write predictions, and score them.

    Input trees are used only to carry sentences for parsing, plus serving
    as the reference for ``predicted.compare``.  Each prediction is wrapped
    in a ``'TOP'`` PhraseTree and written as ``str(tree)``, one per line.

    Args:
        fname: output path, opened for writing (truncated).
        trees: iterable of PhraseTree objects with a ``.sentence`` attribute.
        fm: feature mapper passed to ``Parser.parse``.
        network: model passed to ``Parser.parse``.

    Returns:
        The FScore accumulated over all predicted-vs-reference comparisons.
    """
    accuracy = FScore()
    # The context manager closes the file even if parsing raises midway.
    with open(fname, 'w') as f:
        for tree in trees:
            predicted = Parser.parse(tree.sentence, fm, network)
            accuracy += predicted.compare(tree)
            topped = PhraseTree(
                symbol='TOP',
                children=[predicted],
                sentence=predicted.sentence,
            )
            f.write(str(topped))
            f.write('\n')
    return accuracy
def train(fm, args):
    """Train a Network with exploration and checkpoint on dev improvements.

    Reads from ``args``: ``train`` and ``dev`` (tree files), ``epochs``,
    ``batch_size``, ``unk_param``, ``alpha``, ``beta`` (forwarded to
    ``Parser.exploration``), and ``model``.  ``model`` is a path template in
    which ``'.model'`` is replaced by ``'<dev fscore>.model'`` for each
    saved checkpoint.

    Each epoch shuffles the training data and processes it in mini-batches.
    The per-example losses of a batch are summed and back-propagated, then
    one optimizer step is taken.  The dev set is evaluated about four times
    per epoch and after the last batch.  Whenever dev accuracy beats the
    best seen so far, the whole network is saved with ``torch.save``.
    """
    train_data_file = args.train
    dev_data_file = args.dev
    epochs = args.epochs
    batch_size = args.batch_size
    unk_param = args.unk_param
    alpha = args.alpha
    beta = args.beta
    model_save_file = args.model

    print("this is train mode")
    start_time = time.time()

    network = Network(fm, args)
    if GlobalNames.use_gpu:
        network.cuda()
    # Build the optimizer only after the parameters are on their final
    # device, as the torch.optim documentation recommends.
    optimizer = optimize.Adadelta(network.parameters(), eps=1e-7, rho=0.99)

    training_data = fm.gold_data_from_file(train_data_file)
    num_batches = -(-len(training_data) // batch_size)  # ceiling division
    print('Loaded {} training sentences ({} batches of size {})!'.format(
        len(training_data),
        num_batches,
        batch_size,
    ))
    parse_every = -(-num_batches // 4)  # evaluate dev ~4 times per epoch

    dev_trees = PhraseTree.load_trees(dev_data_file)
    print('Loaded {} validation trees!'.format(len(dev_trees)))

    best_acc = FScore()

    for epoch in range(1, epochs + 1):
        print('........... epoch {} ...........'.format(epoch))

        total_cost = 0.0
        total_states = 0
        training_acc = FScore()

        np.random.shuffle(training_data)

        for b in range(num_batches):
            network.zero_grad()

            batch = training_data[(b * batch_size):((b + 1) * batch_size)]
            batch_loss = None
            for example in batch:
                example_loss, example_states, acc = Parser.exploration(
                    example,
                    fm,
                    network,
                    alpha=alpha,
                    beta=beta,
                    unk_param=unk_param)
                total_states += example_states
                if batch_loss is not None:
                    batch_loss += example_loss
                else:
                    batch_loss = example_loss
                training_acc += acc

            loss_data = batch_loss.data
            if GlobalNames.use_gpu:
                loss_data = loss_data.cpu()
            # reshape(-1) handles both the 1-element arrays of old PyTorch
            # and the 0-dim arrays of newer versions, where [0] would raise.
            total_cost += float(loss_data.numpy().reshape(-1)[0])

            batch_loss.backward()
            optimizer.step()

            mean_cost = total_cost / total_states

            print(
                '\rBatch {}  Mean Cost {:.4f} [Train: {}]'.format(
                    b,
                    mean_cost,
                    training_acc,
                ),
                end='',
            )
            sys.stdout.flush()

            if ((b + 1) % parse_every) == 0 or b == (num_batches - 1):
                dev_acc = Parser.evaluate_corpus(
                    dev_trees,
                    fm,
                    network,
                )
                print('  [Dev: {}]'.format(dev_acc))

                if dev_acc > best_acc:
                    best_acc = dev_acc
                    s = round(dev_acc.fscore(), 2)
                    temp_save_file = model_save_file.replace(
                        '.model', '{}.model'.format(s))
                    torch.save(network, temp_save_file)
                    print('    [saved model: {}]'.format(temp_save_file))

        current_time = time.time()
        runmins = (current_time - start_time) / 60.
        print('  Elapsed time: {:.2f}m'.format(runmins))
def label(self, nonterminals=()):
    """Wrap the top stack entry in one new PhraseTree per nonterminal.

    For each ``nt`` in order, the top ``(left, right, trees)`` entry is
    popped from ``self.stack``.  It is pushed back as
    ``(left, right, [PhraseTree(symbol=nt, children=trees)])``.  Later
    nonterminals therefore wrap the trees produced by earlier ones.  An empty
    ``nonterminals`` leaves the stack untouched.

    Args:
        nonterminals: iterable of labels; only iterated, never mutated.
            The default is an immutable empty tuple, avoiding the shared
            mutable-default pitfall.

    Raises:
        IndexError: presumably, if ``self.stack`` (assumed to be a list) is
            empty while ``nonterminals`` is non-empty.
    """
    for nt in nonterminals:
        (left, right, trees) = self.stack.pop()
        tree = PhraseTree(symbol=nt, children=trees)
        self.stack.append((left, right, [tree]))
def shift(self):
    """Push a leaf for the word at index ``self.i``, then advance ``self.i``.

    The pushed stack entry is ``(i, i, [PhraseTree(leaf=i)])``: left and
    right both equal the shifted word's index.
    """
    word_index = self.i  # index of the word being shifted
    leaf = PhraseTree(leaf=word_index)
    self.stack.append((word_index, word_index, [leaf]))
    self.i += 1