Beispiel #1
0
    def parse_from_annotations(self, fencepost_annotations_start, fencepost_annotations_end, sentence, gold=None):
        """Score spans from fencepost annotations and decode them.

        The label score chart is computed by
        ``self.label_scores_from_annotations`` and moved to a NumPy array
        before decoding.

        * With ``gold`` (training): the chart decoder is run twice, once
          unconstrained and once forced to the gold tree, and the method
          returns ``(p_i, p_j, p_label, p_augment, g_i, g_j, g_label)``.
          The decoder scores themselves are discarded.
        * Without ``gold`` (inference): returns whatever
          ``self.decode_from_chart(sentence, chart)`` returns.
        """
        chart_tensor = self.label_scores_from_annotations(
            fencepost_annotations_start, fencepost_annotations_end)
        chart_np = chart_tensor.cpu().data.numpy()

        # Inference path: no gold tree, so hand the chart straight to the
        # tree-building decoder.
        if gold is None:
            return self.decode_from_chart(sentence, chart_np)

        # Training path: both decoder calls share the same keyword arguments
        # and differ only in the leading "force gold" flag.
        shared_kwargs = {
            "sentence_len": len(sentence),
            "label_scores_chart": chart_np,
            "gold": gold,
            "label_vocab": self.label_vocab,
            "is_train": True,
        }
        _, p_i, p_j, p_label, p_augment = chart_helper.decode(False, **shared_kwargs)
        _, g_i, g_j, g_label, _ = chart_helper.decode(True, **shared_kwargs)
        return p_i, p_j, p_label, p_augment, g_i, g_j, g_label
Beispiel #2
0
    def decode_from_chart(self, sentence, chart_np, gold=None):
        """Decode a label score chart into a parse tree.

        Args:
            sentence: sequence of ``(tag, word)`` pairs, one per token.
            chart_np: label score chart, passed to ``chart_helper.decode`` as
                ``label_scores_chart``.
            gold: optional gold tree. When given, the decoder is forced to
                follow it (``force_gold=True``).

        Returns:
            ``(tree, score)``: the reconstructed tree and the decoder score.
        """
        force_gold = gold is not None

        # The optimized cython decoder implementation doesn't actually
        # generate trees, only scores and span indices. When converting to a
        # tree, we assume that the indices follow a preorder traversal.
        score, p_i, p_j, p_label, _ = chart_helper.decode(
            force_gold,
            sentence_len=len(sentence),
            label_scores_chart=chart_np,
            gold=gold,
            label_vocab=self.label_vocab,
            is_train=False,
        )

        # Cursor into the preorder span arrays; advanced once per make_tree call.
        idx = -1

        def make_tree():
            """Consume the next span and return the list of nodes it yields.

            A list is returned (rather than a single node) because a span
            with an empty label contributes its children directly to its
            parent instead of wrapping them in a node of its own.
            """
            nonlocal idx
            idx += 1
            i, j, label_idx = p_i[idx], p_j[idx], p_label[idx]
            label = self.label_vocab.value(label_idx)

            if (i + 1) >= j:
                # Single-token span: build the leaf, optionally wrapped in a
                # labelled unary node.
                tag, word = sentence[i]
                tree = trees.LeafParseNode(int(i), tag, word)
                if label:
                    tree = trees.InternalParseNode(label, [tree])
                return [tree]

            # Multi-token span: the next spans in preorder are the left
            # subtree followed by the right subtree.
            left_trees = make_tree()
            right_trees = make_tree()
            children = left_trees + right_trees
            if label:
                return [trees.InternalParseNode(label, children)]
            return children

        tree = make_tree()[0]
        return tree, score
def run_test(args):
    """Parse the test set with a saved model, write predictions, and score them.

    Reads from ``args``: ``test_path``, ``test_lbls``, ``test_sent_id_path``,
    ``new_set``, ``feature_path``, ``test_prefix``, ``model_path_base``,
    ``eval_batch_size``, ``sp_off``, ``get_scores``, ``output_path``,
    ``write_gold``, ``test_path_raw`` and ``evalb_dir``.

    Two input modes are supported:

    * label mode (``args.test_lbls`` is truthy): ``test_path`` and
      ``test_lbls`` are whitespace-tokenised text files read line by line, and
      evaluation uses ``evaluate.seg_fscore``;
    * tree mode: trees are loaded with ``trees.load_trees_with_idx`` and
      evaluation uses ``evaluate.evalb``.

    Side effects: writes predictions to ``args.output_path``; optionally a
    ``.scores`` file (``get_scores``) and sentence-id / gold files
    (``write_gold``); prints progress and the final F-score.

    Raises:
        ValueError: if ``get_scores`` is requested in label mode, where no
            gold trees exist to score against.
    """
    print("Loading test trees from {}...".format(args.test_path))
    if args.test_lbls:
        # Context managers make sure each input file is closed once read.
        with open(args.test_path, 'r') as txt_file:
            test_txt = [line.strip().split() for line in txt_file]
        with open(args.test_lbls, 'r') as lbl_file:
            test_lbls = [line.strip().split() for line in lbl_file]
        with open(args.test_sent_id_path, 'r') as id_file:
            test_sent_ids = [line.strip() for line in id_file]
        # In label mode each "tree" is a (tokens, labels) pair.
        test_treebank = list(zip(test_txt, test_lbls))
    else:
        test_treebank, test_sent_ids = trees.load_trees_with_idx(
            args.test_path, args.test_sent_id_path, strip_top=False)

    if not args.new_set:
        test_pause_path = os.path.join(
            args.feature_path, args.test_prefix + '_pause.pickle')
        with open(test_pause_path, 'rb') as f:
            test_pause_data = pickle.load(f, encoding='latin1')

        # NOTE(review): test_pause_data is loaded but not used below; the
        # filtering that consumed it is disabled:
        # to_remove = set(test_sent_ids).difference(set(test_pause_data.keys()))
        # to_remove = sorted([test_sent_ids.index(i) for i in to_remove])
        # for x in to_remove[::-1]:
        #     test_treebank.pop(x)
        #     test_sent_ids.pop(x)

    print("Loaded {:,} test examples.".format(len(test_treebank)))

    print("Loading model from {}...".format(args.model_path_base))
    assert args.model_path_base.endswith(".pt"), "Only pytorch files supported"

    info = torch_load(args.model_path_base)
    print(info.keys())
    assert 'hparams' in info['spec'], "Older savefiles not supported"

    parser = parse_model.SpeechParser.from_spec(
        info['spec'], info['state_dict'])

    # Tally trainable parameters per module.
    # NOTE(review): the table and total are built but never printed here.
    from prettytable import PrettyTable
    total_params = 0
    table = PrettyTable(["Modules", "Parameters"])
    for name, parameter in parser.named_parameters():
        if not parameter.requires_grad:
            continue
        param = parameter.numel()
        table.add_row([name, param])
        total_params += param

    parser.eval()  # turn off dropout at evaluation time
    label_vocab = parser.label_vocab
    #print("{} ({:,}): {}".format("label", label_vocab.size, \
    #        sorted(value for value in label_vocab.values)))

    # One pickle per speech feature type named in the model spec.
    test_feat_dict = {}
    if info['spec']['speech_features'] is not None:
        speech_features = info['spec']['speech_features']
        print("Loading speech features for test set...")
        for feat_type in speech_features:
            print("\t", feat_type)
            feat_path = os.path.join(
                args.feature_path,
                args.test_prefix + '_' + feat_type + '.pickle')
            with open(feat_path, 'rb') as f:
                feat_data = pickle.load(f, encoding='latin1')
            test_feat_dict[feat_type] = feat_data

    print("Parsing test sentences...")
    start_time = time.time()

    test_predicted = []
    test_scores = []
    pscores = []
    gscores = []
    with torch.no_grad():
        for start_index in range(0, len(test_treebank), args.eval_batch_size):
            end_index = start_index + args.eval_batch_size
            subbatch_treebank = test_treebank[start_index:end_index]
            subbatch_sent_ids = test_sent_ids[start_index:end_index]
            if args.test_lbls:  # EKN using this instead of the seg flag bc it's an hparam
                # Each sentence becomes a list of (label, token) pairs.
                subbatch_sentences = [
                    list(zip(sent_lbl, sent_txt))
                    for sent_txt, sent_lbl in subbatch_treebank
                ]
            else:
                subbatch_sentences = [
                    [(leaf.tag, leaf.word) for leaf in tree.leaves()]
                    for tree in subbatch_treebank
                ]
                subbatch_trees = [t.convert() for t in subbatch_treebank]
            subbatch_features = load_features(
                subbatch_sent_ids, test_feat_dict, args.sp_off)
            predicted, scores = parser.parse_batch(
                subbatch_sentences, subbatch_sent_ids, subbatch_features)
            if args.get_scores:
                if args.test_lbls:
                    # No gold trees exist in label mode; fail with a clear
                    # message rather than a NameError on subbatch_trees.
                    raise ValueError(
                        "get_scores requires gold trees and cannot be "
                        "combined with test_lbls")
                charts = parser.parse_batch(
                    subbatch_sentences, subbatch_sent_ids, subbatch_features,
                    subbatch_trees, True)
                for i, chart in enumerate(charts):
                    decoder_args = dict(
                        sentence_len=len(subbatch_sentences[i]),
                        label_scores_chart=chart,
                        gold=subbatch_trees[i],
                        label_vocab=label_vocab,
                        is_train=False,
                        backoff=True)
                    # Score of the best parse vs. score of the gold parse
                    # under the same chart.
                    p_score, _, _, _, _ = chart_helper.decode(
                        False, **decoder_args)
                    g_score, _, _, _, _ = chart_helper.decode(
                        True, **decoder_args)
                    pscores.append(p_score)
                    gscores.append(g_score)
                test_scores += scores
            if args.test_lbls:
                test_predicted.extend(predicted)
            else:
                test_predicted.extend([p.convert() for p in predicted])

    with open(args.output_path, 'w') as output_file:
        for tree in test_predicted:
            if args.test_lbls:
                lbls = ' '.join(tree)
                output_file.write("{}\n".format(lbls))
            else:
                output_file.write("{}\n".format(tree.linearize()))
    print("Output written to:", args.output_path)

    if args.get_scores:
        scores_path = args.output_path + '.scores'
        with open(scores_path, 'w') as output_file:
            for score1, score2, score3 in zip(test_scores, pscores, gscores):
                output_file.write("{}\t{}\t{}\n".format(
                    score1, score2, score3))
        print("Output scores written to:", scores_path)

    if args.write_gold:
        sent_ids_path = args.test_prefix + '_sent_ids.txt'
        with open(sent_ids_path, 'w') as sid_file:
            for sent_id in test_sent_ids:
                sid_file.write("{}\n".format(sent_id))
        print("Sent ids written to:", sent_ids_path)

        # NOTE(review): in label mode test_treebank holds (tokens, labels)
        # tuples, which have no linearize(); write_gold only works in tree mode.
        gold_path = args.test_prefix + '_gold.txt'
        with open(gold_path, 'w') as gold_file:
            for tree in test_treebank:
                gold_file.write("{}\n".format(tree.linearize()))
        print("Gold trees written to:", gold_path)

    # The tree loader does some preprocessing to the trees (e.g. stripping TOP
    # symbols or SPMRL morphological features). A separate "raw" file may be
    # given for the gold trees: the inputs to our parser have traces removed
    # and may have predicted tags substituted, and we may wish to compare
    # against the raw gold trees to make sure we haven't made a mistake.
    # When no raw file is given the reference path stays None (rather than
    # falling back to args.test_path) because evaluation may run on a subset.
    ref_gold_path = args.test_path_raw
    if ref_gold_path is not None:
        print("Comparing with raw trees from", ref_gold_path)

    if args.test_lbls:
        test_fscore = evaluate.seg_fscore(test_treebank,
                                          test_predicted,
                                          is_train=False)
    else:
        test_fscore = evaluate.evalb(
            args.evalb_dir, test_treebank, test_predicted,
            ref_gold_path=ref_gold_path, is_train=False)

    print("test-fscore {} "
          "test-elapsed {}".format(
              test_fscore,
              format_elapsed(start_time),
          ))
# Problem size for the random chart below: N charts, sentence length T,
# and num_labels label scores per span.
N = 16
T = 5
num_labels = 3
chart = torch.randn(N, T + 1, T + 1, num_labels) * 5

# Reference decoding: run chart_helper.decode on every chart and keep the
# resulting (i, j, label) triples in sorted order.
chart_np = chart.cpu().numpy()
spans0 = []
for n in range(N):
    decoder_args = {
        "sentence_len": T,
        "label_scores_chart": chart_np[n],
        "gold": None,
        "label_vocab": None,
        "is_train": False,
    }
    score, p_i, p_j, p_label, _ = chart_helper.decode(False, **decoder_args)
    spans0.append(sorted(zip(p_i, p_j, p_label)))

# Structured model instantiated with the max semiring.
model = ts.CKY_CRF
max_struct = model(ts.MaxSemiring)


def gs(chart, struct):
    """Return the index rows of the non-zero marginals of ``chart``.

    ``chart`` is modified in place: label 0 of the span ``[0, -1]`` is set to
    ``-1e8`` for every batch element so the root cannot take the empty label.
    The third column of the returned index tensor is shifted up by one.
    """
    # don't allow root symbol to be empty
    chart[:, 0, -1, 0] = -1e8

    marginals = struct.marginals(chart)
    nonzero_idx = marginals.nonzero()
    # presumably converts the span end to the convention used by the
    # reference decoder - TODO confirm
    nonzero_idx[:, 2] += 1
    return nonzero_idx


def cat(chart, dim):