def parse_from_annotations(self, fencepost_annotations_start, fencepost_annotations_end, sentence, gold=None):
    """Build a label-score chart from fencepost annotations and decode it.

    When ``gold`` is omitted, the chart is handed to
    ``self.decode_from_chart`` and that method's result is returned as-is.

    When ``gold`` is supplied (training), ``chart_helper.decode`` is called
    twice on the same chart: first with its leading argument ``False``
    (the ``p_*`` outputs) and then with it ``True`` (the ``g_*`` outputs).

    Returns:
        Training: ``(p_i, p_j, p_label, p_augment, g_i, g_j, g_label)``.
        Otherwise: the return value of ``self.decode_from_chart``.
    """
    scores_tensor = self.label_scores_from_annotations(
        fencepost_annotations_start, fencepost_annotations_end
    )
    # The decoder consumes a NumPy array, so move the chart off the device.
    scores_np = scores_tensor.cpu().data.numpy()

    # Inference path: no gold tree, delegate tree construction.
    if gold is None:
        return self.decode_from_chart(sentence, scores_np)

    # Training path: identical keyword arguments for both decoder runs.
    shared_kwargs = {
        "sentence_len": len(sentence),
        "label_scores_chart": scores_np,
        "gold": gold,
        "label_vocab": self.label_vocab,
        "is_train": True,
    }
    _, p_i, p_j, p_label, p_augment = chart_helper.decode(False, **shared_kwargs)
    _, g_i, g_j, g_label, _ = chart_helper.decode(True, **shared_kwargs)
    return p_i, p_j, p_label, p_augment, g_i, g_j, g_label
def decode_from_chart(self, sentence, chart_np, gold=None):
    """Decode a label-score chart into a parse tree and its score.

    Args:
        sentence: Indexable sequence of ``(tag, word)`` pairs, one per token.
        chart_np: Label-score chart, forwarded to ``chart_helper.decode`` as
            ``label_scores_chart``.
        gold: Optional gold tree. When it is not ``None`` the decoder is
            called with its first argument (``force_gold``) set to ``True``.

    Returns:
        ``(tree, score)`` where ``tree`` is the first node produced for the
        root span and ``score`` is the score reported by the decoder.
    """
    force_gold = gold is not None

    # The optimized cython decoder implementation doesn't actually
    # generate trees, only scores and span indices. When converting to a
    # tree, we assume that the indices follow a preorder traversal.
    score, p_i, p_j, p_label, _ = chart_helper.decode(
        force_gold,
        sentence_len=len(sentence),
        label_scores_chart=chart_np,
        gold=gold,
        label_vocab=self.label_vocab,
        is_train=False,
    )

    # Position of the most recently consumed span in the preorder arrays.
    idx = -1

    def make_tree():
        """Consume the next span (and its descendants); return a list of nodes."""
        nonlocal idx
        idx += 1
        i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
        label = self.label_vocab.value(label_idx)

        if (i + 1) >= j:
            # Single-token span: build the leaf, wrapping it if labelled.
            tag, word = sentence[i]
            tree = trees.LeafParseNode(int(i), tag, word)
            if label:
                tree = trees.InternalParseNode(label, [tree])
            return [tree]

        # Multi-token span: preorder means the left subtree's spans come
        # immediately next, followed by the right subtree's spans.
        left_trees = make_tree()
        right_trees = make_tree()
        children = left_trees + right_trees
        if label:
            return [trees.InternalParseNode(label, children)]
        # A falsy label means this span has no node of its own, so its
        # children are spliced into the parent.
        return children

    tree = make_tree()[0]
    return tree, score
def _read_stripped_lines(path):
    """Return every line of the text file at *path*, stripped of surrounding whitespace."""
    with open(path, 'r') as handle:
        return [line.strip() for line in handle]


def run_test(args):
    """Parse a test set with a saved model, write the output, and print the F-score.

    Steps, in order:
      1. Load the test data: whitespace-tokenised text/label files when
         ``args.test_lbls`` is set, otherwise trees via
         ``trees.load_trees_with_idx``.
      2. Load the model from ``args.model_path_base`` (must end in ``.pt``)
         and any speech features listed in the saved spec.
      3. Parse in batches of ``args.eval_batch_size`` under ``torch.no_grad()``.
         With ``args.get_scores`` the charts are decoded again to collect
         predicted and gold scores per sentence.
      4. Write predictions to ``args.output_path``; optionally write
         ``<output_path>.scores`` and the gold trees / sentence ids.
      5. Evaluate (``evaluate.seg_fscore`` or ``evaluate.evalb``) and print
         the result with the elapsed time.

    Args:
        args: Parsed command-line namespace. Attributes read here:
            ``test_path``, ``test_lbls``, ``test_sent_id_path``, ``new_set``,
            ``feature_path``, ``test_prefix``, ``model_path_base``,
            ``eval_batch_size``, ``sp_off``, ``get_scores``, ``output_path``,
            ``write_gold``, ``test_path_raw``, ``evalb_dir``.

    Raises:
        AssertionError: If the model path does not end in ``.pt`` or the
            save file has no ``hparams`` entry in its spec.
    """
    print("Loading test trees from {}...".format(args.test_path))
    if args.test_lbls:
        # Text and labels are parallel files with one sentence per line.
        test_txt = [line.split() for line in _read_stripped_lines(args.test_path)]
        test_lbls = [line.split() for line in _read_stripped_lines(args.test_lbls)]
        test_sent_ids = _read_stripped_lines(args.test_sent_id_path)
        test_treebank = list(zip(test_txt, test_lbls))
    else:
        test_treebank, test_sent_ids = trees.load_trees_with_idx(
            args.test_path, args.test_sent_id_path, strip_top=False)

    if not args.new_set:
        test_pause_path = os.path.join(
            args.feature_path, args.test_prefix + '_pause.pickle')
        with open(test_pause_path, 'rb') as f:
            # NOTE: pickle can execute arbitrary code while loading; only
            # point this at feature files you produced yourself.
            # The loaded data is not used further in this function, but the
            # load still fails early if the pause file is missing or unreadable.
            test_pause_data = pickle.load(f, encoding='latin1')

    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch files supported"

    info = torch_load(args.model_path_base)
    print(info.keys())
    assert 'hparams' in info['spec'], "Older savefiles not supported"
    parser = parse_model.SpeechParser.from_spec(
        info['spec'], info['state_dict'])

    parser.eval()  # turn off dropout at evaluation time
    label_vocab = parser.label_vocab

    test_feat_dict = {}
    if info['spec']['speech_features'] is not None:
        speech_features = info['spec']['speech_features']
        print("Loading speech features for test set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(
                args.feature_path,
                args.test_prefix + '_' + feat_type + '.pickle')
            with open(feat_path, 'rb') as f:
                # NOTE: same pickle trust caveat as for the pause file above.
                feat_data = pickle.load(f, encoding='latin1')
            test_feat_dict[feat_type] = feat_data

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    test_scores = []
    pscores = []
    gscores = []
    with torch.no_grad():
        for start_index in range(0, len(test_treebank), args.eval_batch_size):
            end_index = start_index + args.eval_batch_size
            subbatch_treebank = test_treebank[start_index:end_index]
            subbatch_sent_ids = test_sent_ids[start_index:end_index]

            if args.test_lbls:
                # Keyed on args.test_lbls rather than the seg flag because
                # the seg flag is an hparam. Each item is (text, labels);
                # the parser gets (label, token) pairs.
                subbatch_sentences = [
                    list(zip(sent_lbl, sent_txt))
                    for sent_txt, sent_lbl in subbatch_treebank
                ]
                # NOTE(review): subbatch_trees is not defined on this path,
                # so combining test_lbls with get_scores looks unsupported
                # -- confirm before relying on it.
            else:
                subbatch_sentences = [
                    [(leaf.tag, leaf.word) for leaf in tree.leaves()]
                    for tree in subbatch_treebank
                ]
                subbatch_trees = [t.convert() for t in subbatch_treebank]

            subbatch_features = load_features(
                subbatch_sent_ids, test_feat_dict, args.sp_off)
            predicted, scores = parser.parse_batch(
                subbatch_sentences, subbatch_sent_ids, subbatch_features)

            if args.get_scores:
                # Second pass returns the raw charts so both the predicted
                # and the gold-forced decode can be scored per sentence.
                charts = parser.parse_batch(
                    subbatch_sentences, subbatch_sent_ids,
                    subbatch_features, subbatch_trees, True)
                for i, chart in enumerate(charts):
                    decoder_args = dict(
                        sentence_len=len(subbatch_sentences[i]),
                        label_scores_chart=chart,
                        gold=subbatch_trees[i],
                        label_vocab=label_vocab,
                        is_train=False,
                        backoff=True,
                    )
                    p_score, _, _, _, _ = chart_helper.decode(
                        False, **decoder_args)
                    g_score, _, _, _, _ = chart_helper.decode(
                        True, **decoder_args)
                    pscores.append(p_score)
                    gscores.append(g_score)
                test_scores.extend(scores)

            if args.test_lbls:
                test_predicted.extend(predicted)
            else:
                test_predicted.extend([p.convert() for p in predicted])

    with open(args.output_path, 'w') as output_file:
        for tree in test_predicted:
            if args.test_lbls:
                # Label mode: each prediction is a sequence of label strings.
                output_file.write("{}\n".format(' '.join(tree)))
            else:
                output_file.write("{}\n".format(tree.linearize()))
    print("Output written to:", args.output_path)

    if args.get_scores:
        with open(args.output_path + '.scores', 'w') as output_file:
            for score1, score2, score3 in zip(test_scores, pscores, gscores):
                output_file.write("{}\t{}\t{}\n".format(
                    score1, score2, score3))
        print("Output scores written to:", args.output_path + '.scores')

    if args.write_gold:
        with open(args.test_prefix + '_sent_ids.txt', 'w') as sid_file:
            for sent_id in test_sent_ids:
                sid_file.write("{}\n".format(sent_id))
        print("Sent ids written to:", args.test_prefix + '_sent_ids.txt')

        with open(args.test_prefix + '_gold.txt', 'w') as gold_file:
            for tree in test_treebank:
                gold_file.write("{}\n".format(tree.linearize()))
        print("Gold trees written to:", args.test_prefix + '_gold.txt')

    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). We compare with the input file
    # directly to be extra careful about not corrupting the evaluation. We also
    # allow specifying a separate "raw" file for the gold trees: the inputs to
    # our parser have traces removed and may have predicted tags substituted,
    # and we may wish to compare against the raw gold trees to make sure we
    # haven't made a mistake. As far as we can tell all of these variations give
    # equivalent results.
    if args.test_path_raw is not None:
        print("Comparing with raw trees from", args.test_path_raw)
        ref_gold_path = args.test_path_raw
    else:
        # No reference file: needed when evaluating on a subset of the data.
        ref_gold_path = None

    if args.test_lbls:
        test_fscore = evaluate.seg_fscore(
            test_treebank, test_predicted, is_train=False)
    else:
        test_fscore = evaluate.evalb(
            args.evalb_dir, test_treebank, test_predicted,
            ref_gold_path=ref_gold_path, is_train=False)

    print(
        "test-fscore {} "
        "test-elapsed {}".format(
            test_fscore,
            format_elapsed(start_time),
        )
    )
N = 16 T = 5 num_labels = 3 chart = torch.randn(N, T + 1, T + 1, num_labels) * 5 chart_np = chart.cpu().numpy() spans0 = [] for n in range(N): decoder_args = dict( sentence_len=T, label_scores_chart=chart_np[n], gold=None, label_vocab=None, is_train=False, ) score, p_i, p_j, p_label, _ = chart_helper.decode(False, **decoder_args) spans0.append(sorted(list(zip(p_i, p_j, p_label)))) model = ts.CKY_CRF max_struct = model(ts.MaxSemiring) def gs(chart, struct): # don't allow root symbol to be empty chart[:, 0, -1, 0].fill_(-1e8) spans = struct.marginals(chart).nonzero() spans[:, 2] += 1 return spans def cat(chart, dim):