# Example 1
def tensor2tree(x, D:Vocab=None):
    """Decode a 1D tensor of token ids into an nltk Tree.

    The ids are mapped to token strings through the vocab ``D``, padding
    tokens are dropped, the sequence is cut at the first end token, and
    any unbalanced parentheses are patched up before parsing the
    resulting lisp string. Returns ``None`` when parsing fails.
    """
    # ids -> token strings, with padding removed
    tokens = [D(tid) for tid in x.detach().cpu().numpy()]
    tokens = [tok for tok in tokens if tok != D.padtoken]

    # scan up to (and cut at) the first end token, tracking paren balance;
    # a token that *ends* with "(" also counts as an opening paren
    balance = 0
    cut = len(tokens)
    for pos, tok in enumerate(tokens):
        if tok == D.endtoken:
            cut = pos
            break
        if tok == "(" or tok[-1] == "(":
            balance += 1
        elif tok == ")":
            balance -= 1
    tokens = tokens[:cut]

    # too few closers: append the missing ones at the end
    if balance > 0:
        tokens += [")"] * balance
        balance = 0
    # too many closers: drop surplus ")" scanning right-to-left (index 0 is kept)
    pos = len(tokens) - 1
    while balance < 0 and pos > 0:
        if tokens[pos] == ")":
            tokens.pop(pos)
            balance += 1
        pos -= 1

    # parse the repaired lisp string into an nltk.Tree
    try:
        tree, _ = lisp_to_tree(" ".join(tokens), None)
    except Exception:
        tree = None
    return tree
# Example 2
    def build_token_specs(self, outputs: Iterable[str]):
        """Compute per-label arity specs over the gold output trees.

        Every output string is tokenized with ``self.query_encoder``,
        stripped of its trailing "@END@" marker, parsed as a lisp tree,
        and walked recursively. For each tree label the spec records the
        minimum and maximum number of children observed.

        Returns a dict mapping label -> [min_children, max_children].
        """
        token_specs = dict()

        def walk_the_tree(t, _ts):
            # Update the [min, max] child-count range for this node's label.
            label = t.label()
            if label not in _ts:
                # np.inf instead of np.infty: np.infty was removed in NumPy 2.0.
                _ts[label] = [np.inf, -np.inf]
            minc, maxc = _ts[label]
            _ts[label] = [min(minc, len(t)), max(maxc, len(t))]
            for child in t:
                walk_the_tree(child, _ts)

        for out in outputs:
            out_tokens = self.query_encoder.convert(out,
                                                    return_what="tokens")[0]
            # every encoded sequence must terminate with the end marker
            assert (out_tokens[-1] == "@END@")
            out_tokens = out_tokens[:-1]
            out_str = " ".join(out_tokens)
            tree = lisp_to_tree(out_str)
            walk_the_tree(tree, token_specs)

        # "and" is treated as variadic: allow any number of children
        token_specs["and"][1] = np.inf

        return token_specs
# Example 3
    def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
        """Encode examples into RankState objects, grouped per split.

        For each (example, split) pair: encodes the input sentence and the
        gold query, parses gold and candidate lisp strings into trees,
        flags which candidates are tree-equal to the gold, stacks the
        padded candidate tensors, and appends one RankState to
        ``self.data[split]``. Also tracks ``self.maxlen_input`` and
        ``self.maxlen_output`` over all examples.

        NOTE(review): assumes each example dict carries "sentence", "gold"
        and "candidates" (each candidate with "tokens", "alignments",
        "align_entropies") -- confirm against the data loader.
        """
        maxlen_in, maxlen_out = 0, 0
        for example, split in zip(examples, splits):
            inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
            inp_tensor, inp_tokens = self.sentence_encoder.convert(
                inp, return_what="tensor,tokens")
            # [:-1] drops the last gold token before parsing (presumably the
            # "@END@" marker, as asserted elsewhere in this file -- verify)
            gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
            if not isinstance(gold_tree, Tree):
                # parsing may return a non-Tree; only None is considered fatal
                assert (gold_tree is not None)
            gold_tensor, gold_tokens = self.query_encoder.convert(
                out, return_what="tensor,tokens")

            # per-candidate encodings, alignment tensors, and gold-equality flags
            candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
            candidate_align_entropies = []
            candidate_trees = []
            candidate_same = []
            for cand in example["candidates"]:
                cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]),
                                            None)
                # unparseable candidate -> placeholder tree that can never
                # equal the gold (label "@UNK@", no children)
                if cand_tree is None:
                    cand_tree = Tree("@UNK@", [])
                assert (cand_tree is not None)
                cand_tensor, cand_tokens = self.query_encoder.convert(
                    " ".join(cand["tokens"]), return_what="tensor,tokens")
                candidate_tensors.append(cand_tensor)
                candidate_tokens.append(cand_tokens)
                candidate_align_tensors.append(torch.tensor(
                    cand["alignments"]))
                candidate_align_entropies.append(
                    torch.tensor(cand["align_entropies"]))
                candidate_trees.append(cand_tree)
                # "and"/"or" children compare as unordered sets; the unktoken
                # sentinel is chosen so no real token matches it
                candidate_same.append(
                    are_equal_trees(cand_tree,
                                    gold_tree,
                                    orderless={"and", "or"},
                                    unktoken="@NOUNKTOKENHERE@"))

            # pad candidates to a common length, then stack along a new dim 0
            candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0),
                                           0)
            candidate_align_tensor = torch.stack(
                q.pad_tensors(candidate_align_tensors, 0), 0)
            candidate_align_entropy = torch.stack(
                q.pad_tensors(candidate_align_entropies, 0), 0)
            candidate_same = torch.tensor(candidate_same)

            # [None, ...] adds a leading batch dimension of size 1
            state = RankState(
                inp_tensor[None, :],
                gold_tensor[None, :],
                candidate_tensor[None, :, :],
                candidate_same[None, :],
                candidate_align_tensor[None, :],
                candidate_align_entropy[None, :],
                self.sentence_encoder.vocab,
                self.query_encoder.vocab,
            )
            if split not in self.data:
                self.data[split] = []
            self.data[split].append(state)
            maxlen_in = max(maxlen_in, len(inp_tokens))
            maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                             gold_tensor.size(-1))
        self.maxlen_input = maxlen_in
        self.maxlen_output = maxlen_out
# Example 4
 def test_unordered_duplicate(self):
     """Orderless comparison must respect multiplicity: two (wife BO) vs two (wife BR)."""
     left = lisp_to_tree("(and (wife BO) (wife BR) (wife BO))")
     right = lisp_to_tree("(and (wife BO) (wife BR) (wife BR))")
     print(are_equal_trees(left, right))
     self.assertFalse(are_equal_trees(left, right))
# Example 5
 def test_ordered(self):
     """Trees with differing child counts under the same label are unequal."""
     lhs = lisp_to_tree("(nand (wife BO) (spouse BO))")
     rhs = lisp_to_tree("(nand (wife BO) (spouse BO) (child BO))")
     print(are_equal_trees(lhs, rhs))
     self.assertFalse(are_equal_trees(lhs, rhs))