def tensor2tree(x, D: Vocab = None):
    # x: 1D int tensor
    x = list(x.detach().cpu().numpy())
    x = [D(xe) for xe in x]
    x = [xe for xe in x if xe != D.padtoken]

    # find first @END@ and cut off
    parentheses_balance = 0
    for i in range(len(x)):
        if x[i] == D.endtoken:
            x = x[:i]
            break
        elif x[i] == "(" or x[i][-1] == "(":
            parentheses_balance += 1
        elif x[i] == ")":
            parentheses_balance -= 1
        else:
            pass

    # balance parentheses
    while parentheses_balance > 0:
        x.append(")")
        parentheses_balance -= 1
    i = len(x) - 1
    while parentheses_balance < 0 and i > 0:
        if x[i] == ")":
            x.pop(i)
            parentheses_balance += 1
        i -= 1

    # convert to nltk.Tree
    try:
        tree, parsestate = lisp_to_tree(" ".join(x), None)
    except Exception as e:
        tree = None
    return tree
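# Hedged usage sketch (not part of the original module): tensor2tree is typically
# applied per example to decoded output ids. `preds` and `query_vocab` below are
# hypothetical names; the only assumption is that the vocab behaves like the `D`
# argument above (callable id -> token string, with .padtoken and .endtoken).
#
#   preds = model_outputs.argmax(-1)     # (batch_size, seq_len) int tensor
#   trees = [tensor2tree(preds[i], D=query_vocab) for i in range(preds.size(0))]
#   # each entry is an nltk.Tree, or None if the token sequence could not be parsed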
def build_token_specs(self, outputs: Iterable[str]):
    token_specs = dict()

    def walk_the_tree(t, _ts):
        # record the minimum and maximum number of children seen for every label
        l = t.label()
        if l not in _ts:
            _ts[l] = [np.inf, -np.inf]
        minc, maxc = _ts[l]
        _ts[l] = [min(minc, len(t)), max(maxc, len(t))]
        for c in t:
            walk_the_tree(c, _ts)

    for out in outputs:
        out_tokens = self.query_encoder.convert(out, return_what="tokens")[0]
        assert (out_tokens[-1] == "@END@")
        out_tokens = out_tokens[:-1]    # strip the @END@ token before parsing
        out_str = " ".join(out_tokens)
        tree = lisp_to_tree(out_str)
        walk_the_tree(tree, token_specs)

    token_specs["and"][1] = np.inf      # "and" may take any number of arguments
    return token_specs
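# Hedged illustration (assumed, not from the original code): the returned
# `token_specs` maps each tree label to [min_children, max_children] over all
# parsed outputs, e.g. something of the shape
#   {"and": [2, inf], "wife": [1, 1], "BO": [0, 0]}
# The token_specs["and"][1] = np.inf line lifts the upper bound for "and" so
# that conjunctions are allowed an arbitrary number of arguments downstream.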
def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
    maxlen_in, maxlen_out = 0, 0
    for example, split in zip(examples, splits):
        inp, out = " ".join(example["sentence"]), " ".join(example["gold"])

        # encode the input sentence and the gold query
        inp_tensor, inp_tokens = self.sentence_encoder.convert(
            inp, return_what="tensor,tokens")
        gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
        if not isinstance(gold_tree, Tree):
            assert (gold_tree is not None)
        gold_tensor, gold_tokens = self.query_encoder.convert(
            out, return_what="tensor,tokens")

        # encode every candidate query and compare its tree against the gold tree
        candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
        candidate_align_entropies = []
        candidate_trees = []
        candidate_same = []
        for cand in example["candidates"]:
            cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]), None)
            if cand_tree is None:
                cand_tree = Tree("@UNK@", [])
            assert (cand_tree is not None)
            cand_tensor, cand_tokens = self.query_encoder.convert(
                " ".join(cand["tokens"]), return_what="tensor,tokens")
            candidate_tensors.append(cand_tensor)
            candidate_tokens.append(cand_tokens)
            candidate_align_tensors.append(torch.tensor(cand["alignments"]))
            candidate_align_entropies.append(torch.tensor(cand["align_entropies"]))
            candidate_trees.append(cand_tree)
            candidate_same.append(are_equal_trees(
                cand_tree, gold_tree,
                orderless={"and", "or"}, unktoken="@NOUNKTOKENHERE@"))

        # pad and stack the per-candidate tensors
        candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0), 0)
        candidate_align_tensor = torch.stack(
            q.pad_tensors(candidate_align_tensors, 0), 0)
        candidate_align_entropy = torch.stack(
            q.pad_tensors(candidate_align_entropies, 0), 0)
        candidate_same = torch.tensor(candidate_same)

        # wrap everything for this example in a RankState (leading batch dim of 1)
        state = RankState(
            inp_tensor[None, :],
            gold_tensor[None, :],
            candidate_tensor[None, :, :],
            candidate_same[None, :],
            candidate_align_tensor[None, :],
            candidate_align_entropy[None, :],
            self.sentence_encoder.vocab,
            self.query_encoder.vocab,
        )
        if split not in self.data:
            self.data[split] = []
        self.data[split].append(state)

        maxlen_in = max(maxlen_in, len(inp_tokens))
        maxlen_out = max(maxlen_out, candidate_tensor.size(-1), gold_tensor.size(-1))

    self.maxlen_input = maxlen_in
    self.maxlen_output = maxlen_out
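# Note (added, hedged): each `example` dict is expected to provide "sentence",
# "gold", and "candidates", where every candidate carries "tokens", "alignments",
# and "align_entropies"; this is read directly off the accesses above. A
# hypothetical driver would pass parallel lists, e.g.
#   ds.build_data(train_examples, ["train"] * len(train_examples))
# after which self.data["train"] holds one RankState per example and
# self.maxlen_input / self.maxlen_output record the longest encoded sequences.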
def test_unordered_duplicate(self):
    a = lisp_to_tree("(and (wife BO) (wife BR) (wife BO))")
    b = lisp_to_tree("(and (wife BO) (wife BR) (wife BR))")
    print(are_equal_trees(a, b))
    self.assertFalse(are_equal_trees(a, b))
def test_ordered(self):
    a = lisp_to_tree("(nand (wife BO) (spouse BO))")
    b = lisp_to_tree("(nand (wife BO) (spouse BO) (child BO))")
    print(are_equal_trees(a, b))
    self.assertFalse(are_equal_trees(a, b))
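# Hedged sketch (not an original test): the orderless comparison used in
# build_data suggests that reordering the children of "and" should not break
# equality. The exact semantics of are_equal_trees are assumed here, so treat
# this as an illustration rather than part of the test suite.
def test_unordered_reordered(self):
    a = lisp_to_tree("(and (wife BO) (spouse BR))")
    b = lisp_to_tree("(and (spouse BR) (wife BO))")
    self.assertTrue(are_equal_trees(a, b, orderless={"and", "or"}))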