def test_unordered_twolevel(self):
    """Two-level and/or trees must compare equal regardless of child order."""
    left = lisp_to_tree(
        "(and (or (wife BO) (spouse BO)) (or (wife BR) (spouse BR)) ) ")
    right = lisp_to_tree(
        "(and (or (spouse BR) (wife BR)) (or (spouse BO) (wife BO)) )")
    print(are_equal_trees(left, right))
    self.assertTrue(are_equal_trees(left, right))
def test_it_unk(self):
    """A tree containing @UNK@ must not equal anything — not even itself."""
    plain = lisp_to_tree("(and (wife BO) (spouse BO))")
    with_unk = lisp_to_tree("(and (wife BO) (spouse @UNK@))")
    print(are_equal_trees(plain, with_unk))
    print(are_equal_trees(with_unk, with_unk))
    self.assertFalse(are_equal_trees(plain, with_unk))
    self.assertFalse(are_equal_trees(with_unk, with_unk))
def test_forward(self, x: torch.Tensor, gold: torch.Tensor = None
                 ):  # --> implement how decoder operates end-to-end
    """Run free decoding on ``x`` and score predictions against ``gold``.

    Args:
        x: input token-id tensor (batch of sequences).
        gold: gold token-id tensor, aligned with ``x``.

    Returns:
        A tuple ``(ret, pred_trees)`` where ``ret`` maps "treeacc" to a
        per-example 0/1 tensor on ``x``'s device and "stepsused" to the
        step counts from ``get_prediction``.
    """
    preds, stepsused = self.get_prediction(x)

    def tensor_to_trees(x, vocab: Vocab):
        """Decode id tensors into trees; unparseable sequences become None."""
        # Strip BOS/EOS markers from the detokenized strings.
        xstrs = [
            vocab.tostr(x[i]).replace("@BOS@", "").replace("@EOS@", "")
            for i in range(len(x))
        ]
        # FIX: raw string for the regex — "\d" in a plain string is an
        # invalid escape sequence (SyntaxWarning on modern Python).
        xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
        trees = []
        for xstr in xstrs:
            # drop everything after @END@, if present
            xstr = xstr.split("@END@")[0]
            xstr = xstr.strip()
            # add an opening parenthesis if not there
            if len(xstr) == 0 or xstr[0] != "(":
                xstr = "(" + xstr
            # balance parentheses on both ends
            parenthese_imbalance = xstr.count("(") - xstr.count(")")
            # append missing closing parentheses
            xstr = xstr + ")" * max(0, parenthese_imbalance)
            # prepend missing opening parentheses
            xstr = "(" * -min(0, parenthese_imbalance) + xstr
            try:
                tree = lisp_to_tree(xstr)
                # lisp_to_tree may signal failure as a (None, ...) pair
                if isinstance(tree, tuple) and len(tree) == 2 \
                        and tree[0] is None:
                    tree = None
            except Exception:
                # FIX: dropped unused exception variable; a prediction that
                # cannot be parsed simply counts as no tree (scored wrong).
                tree = None
            trees.append(tree)
        return trees

    # compute loss and metrics
    gold_trees = tensor_to_trees(gold, vocab=self.vocab)
    pred_trees = tensor_to_trees(preds, vocab=self.vocab)
    treeaccs = [
        float(
            are_equal_trees(gold_tree,
                            pred_tree,
                            orderless=ORDERLESS,
                            unktoken="@UNK@"))
        for gold_tree, pred_tree in zip(gold_trees, pred_trees)
    ]
    ret = {
        "treeacc": torch.tensor(treeaccs).to(x.device),
        "stepsused": stepsused
    }
    return ret, pred_trees
def are_equal_trees(self, x: Tree, y: Tree, use_terminator=False):
    """Compare two trees for equality in several passes.

    Pass 1 compares the trees with entities and variables replaced by generic
    placeholders (a cheap necessary condition); pass 2 compares the normalized
    trees with concrete (non-generic) labels; if pass 2 fails, every label
    reassignment of one tree is tried before giving up. Token-level comparison
    is delegated to the module-level ``are_equal_trees`` function.
    """
    # Missing trees never compare equal.
    if x is None or y is None:
        return False
    # Pass 1: generic entity/variable placeholders.
    _x = self.normalize_entities_and_variables(x, generic=True)
    _y = self.normalize_entities_and_variables(y, generic=True)
    # print(_x)
    # print(_y)
    ret = are_equal_trees(_x,
                          _y,
                          orderless=self.orderless,
                          unktoken=self.outD[self.outD.unktoken],
                          use_terminator=use_terminator)
    if ret is False:
        # Shapes already differ — no reassignment can fix that.
        return False
    # Pass 2: normalized trees with concrete entity/variable identities.
    _x = x
    _y = y
    _x = self.normalize_tree(x)[1]
    _y = self.normalize_tree(y)[1]
    _x = self.normalize_entities_and_variables(_x, generic=False)
    _y = self.normalize_entities_and_variables(_y, generic=False)
    # print(_x)
    # print(_y)
    ret = are_equal_trees(_x,
                          _y,
                          orderless=self.orderless,
                          unktoken=self.outD[self.outD.unktoken],
                          use_terminator=use_terminator)
    if ret is False:
        # return False
        # print(_x)
        # print(_y)
        # iterate over possible reassignments of one tree
        for reassigned_y in self.reassignments(_y):
            # reassigned_y = self.normalize_entities_and_variables(reassigned_y, generic=False)
            if are_equal_trees(reassigned_y,
                               _x,
                               orderless=self.orderless,
                               unktoken=self.outD[self.outD.unktoken],
                               use_terminator=use_terminator):
                return True
        # No reassignment matched.
        return False
    # Pass 2 succeeded (ret is truthy here).
    return ret
def compare(_gold_trees, _predactions):
    """Per-example tree accuracy: decode each action sequence and match gold."""
    decoded = [self.tensor2tree(acts) for acts in _predactions]
    scores = []
    for gold_tree, pred_tree in zip(_gold_trees, decoded):
        same = are_equal_trees(gold_tree,
                               pred_tree,
                               orderless=self.orderless,
                               unktoken=self.unktoken)
        scores.append(float(same))
    return scores
def try_tree_permutations():
    """Enumerate permutations of a small tree and verify each equals the original."""
    tree = Tree("x", [
        Tree("a", [Tree("1", []), Tree("2", []), Tree("3", [])]),
        Tree("b", [Tree("1", []), Tree("2", [])]),
    ])
    print(tree)
    print("")
    perms = []
    unique_perms = set()
    for candidate in get_tree_permutations(tree, orderless={"a", "x"}):
        print(candidate)
        # every permutation must still be "equal" under orderless comparison
        assert are_equal_trees(tree, candidate, orderless={"a", "x"})
        unique_perms.add(str(candidate))
        perms.append(str(candidate))
    print(len(unique_perms), len(perms))
def test_it(self):
    """Reflexive equality, order-insensitive equality, and inequality."""
    t_spaced = lisp_to_tree("( and ( wife BO ) ( spouse BO ) )")
    t_swapped = lisp_to_tree("(and (spouse BO) (wife BO))")
    t_other = lisp_to_tree("(and (wife BO) (wife BO))")
    print(are_equal_trees(t_spaced, t_spaced))  # should be True
    print(are_equal_trees(t_spaced, t_swapped))  # should be True
    print(are_equal_trees(t_spaced, t_other))  # should be False
    self.assertTrue(are_equal_trees(t_spaced, t_spaced))
    self.assertTrue(are_equal_trees(t_spaced, t_swapped))
    self.assertFalse(are_equal_trees(t_spaced, t_other))
def build_data(self, examples: Iterable[dict], splits: Iterable[str]):
    """Convert raw examples into ``RankState`` objects, bucketed per split.

    For each example: encodes the input sentence and the gold query, parses
    the gold tree, then encodes every candidate query together with its
    alignment tensors and a flag saying whether the candidate tree equals
    the gold tree. Also tracks the maximum input/output lengths seen.

    NOTE(review): assumes each example dict has keys "sentence", "gold" and
    "candidates", with each candidate carrying "tokens", "alignments" and
    "align_entropies" — verify against the data loader.
    """
    maxlen_in, maxlen_out = 0, 0
    for example, split in zip(examples, splits):
        inp, out = " ".join(example["sentence"]), " ".join(example["gold"])
        inp_tensor, inp_tokens = self.sentence_encoder.convert(
            inp, return_what="tensor,tokens")
        # gold tokens' last element is dropped before parsing
        # (presumably an end marker — TODO confirm).
        gold_tree = lisp_to_tree(" ".join(example["gold"][:-1]))
        if not isinstance(gold_tree, Tree):
            assert (gold_tree is not None)
        gold_tensor, gold_tokens = self.query_encoder.convert(
            out, return_what="tensor,tokens")
        candidate_tensors, candidate_tokens, candidate_align_tensors = [], [], []
        candidate_align_entropies = []
        candidate_trees = []
        candidate_same = []
        for cand in example["candidates"]:
            # this lisp_to_tree overload returns a (tree, error) pair
            cand_tree, _ = lisp_to_tree(" ".join(cand["tokens"][:-1]), None)
            if cand_tree is None:
                # unparseable candidate becomes a dummy @UNK@ tree
                cand_tree = Tree("@UNK@", [])
            assert (cand_tree is not None)
            cand_tensor, cand_tokens = self.query_encoder.convert(
                " ".join(cand["tokens"]), return_what="tensor,tokens")
            candidate_tensors.append(cand_tensor)
            candidate_tokens.append(cand_tokens)
            candidate_align_tensors.append(torch.tensor(
                cand["alignments"]))
            candidate_align_entropies.append(
                torch.tensor(cand["align_entropies"]))
            candidate_trees.append(cand_tree)
            # unktoken set to a sentinel that never occurs, so @UNK@ nodes
            # in the candidate do not auto-fail the comparison
            candidate_same.append(
                are_equal_trees(cand_tree,
                                gold_tree,
                                orderless={"and", "or"},
                                unktoken="@NOUNKTOKENHERE@"))
        # pad candidates to a common length and stack into one tensor each
        candidate_tensor = torch.stack(q.pad_tensors(candidate_tensors, 0), 0)
        candidate_align_tensor = torch.stack(
            q.pad_tensors(candidate_align_tensors, 0), 0)
        candidate_align_entropy = torch.stack(
            q.pad_tensors(candidate_align_entropies, 0), 0)
        candidate_same = torch.tensor(candidate_same)
        # leading [None, ...] indexes add a batch dimension of size 1
        state = RankState(
            inp_tensor[None, :],
            gold_tensor[None, :],
            candidate_tensor[None, :, :],
            candidate_same[None, :],
            candidate_align_tensor[None, :],
            candidate_align_entropy[None, :],
            self.sentence_encoder.vocab,
            self.query_encoder.vocab,
        )
        if split not in self.data:
            self.data[split] = []
        self.data[split].append(state)
        maxlen_in = max(maxlen_in, len(inp_tokens))
        maxlen_out = max(maxlen_out, candidate_tensor.size(-1),
                         gold_tensor.size(-1))
    self.maxlen_input = maxlen_in
    self.maxlen_output = maxlen_out
def test_forward(self, x: torch.Tensor, gold: torch.Tensor = None
                 ):  # --> implement how decoder operates end-to-end
    """Decode ``x``, score tree accuracy against ``gold``, and attach
    uncertainty metrics — recomputed via MC-dropout when ``self.mcdropout > 0``,
    otherwise taken from ``get_prediction`` directly.
    """
    preds, prednll, maxmaxnll, entropy, total, avgconf, sumnll, stepsused = self.get_prediction(
        x)

    def tensor_to_trees(x, vocab: Vocab):
        # Detokenize and strip the start marker.
        xstrs = [
            vocab.tostr(x[i]).replace("@START@", "") for i in range(len(x))
        ]
        # Remove "::<number>" disambiguation suffixes from tokens.
        xstrs = [re.sub("::\d+", "", xstr) for xstr in xstrs]
        trees = []
        for xstr in xstrs:
            # drop everything after @END@, if present
            xstr = xstr.split("@END@")
            xstr = xstr[0]
            # add an opening parentheses if not there
            xstr = xstr.strip()
            if len(xstr) == 0 or xstr[0] != "(":
                xstr = "(" + xstr
            # balance closing parentheses
            parenthese_imbalance = xstr.count("(") - xstr.count(")")
            xstr = xstr + ")" * max(
                0, parenthese_imbalance)  # append missing closing parentheses
            xstr = "(" * -min(
                0, parenthese_imbalance
            ) + xstr  # prepend missing opening parentheses
            try:
                tree = taglisp_to_tree(xstr)
                # taglisp_to_tree may signal failure as a (None, ...) pair
                if isinstance(
                        tree, tuple) and len(tree) == 2 and tree[0] is None:
                    tree = None
            except Exception as e:
                # unparseable prediction counts as no tree (scored wrong)
                tree = None
            trees.append(tree)
        return trees

    # compute loss and metrics
    gold_trees = tensor_to_trees(gold, vocab=self.vocab)
    pred_trees = tensor_to_trees(preds, vocab=self.vocab)
    treeaccs = [
        float(
            are_equal_trees(gold_tree,
                            pred_tree,
                            orderless=ORDERLESS,
                            unktoken="@UNK@"))
        for gold_tree, pred_tree in zip(gold_trees, pred_trees)
    ]
    ret = {
        "treeacc": torch.tensor(treeaccs).to(x.device),
        "stepsused": stepsused
    }
    if self.mcdropout > 0:
        # MC-dropout: rerun the model in train mode (dropout active) over
        # the decoded predictions and average the softmax distributions.
        probses = []
        preds = preds[:, 1:]  # drop the leading start token
        self.train()
        for i in range(self.mcdropout):
            d, logits = self.train_forward(x, preds)
            probses.append(torch.softmax(logits, -1))
        self.eval()
        probses = sum(probses) / len(probses)
        probses = probses[:, :-1]  # align probs with target positions
        probs = probses
        mask = preds > 0  # non-padding positions
        # per-position probability assigned to the predicted token
        confs = torch.gather(probs, 2, preds[:, :, None])[:, :, 0]
        nlls = -torch.log(confs)
        # NOTE(review): despite the name this is the product of per-token
        # confidences (sequence probability), with padding contributing 1.
        avgconf = (confs + (1 - mask.float())).prod(-1)
        avgnll = (nlls * mask).sum(-1) / mask.float().sum(-1).clamp(1e-6)
        sumnll = (nlls * mask).sum(-1)
        # padding positions pushed to -1e6 so max picks a real token
        maxnll, _ = (nlls + (1 - mask.float()) * -1e6).max(-1)
        # mean per-position entropy of the averaged distribution
        entropy = (-torch.log(probs.clamp_min(1e-7)) * probs).sum(-1)
        entropy = (entropy * mask).sum(-1) / mask.float().sum(-1).clamp(1e-6)
        ret["decnll"] = avgnll
        ret["sumnll"] = sumnll
        ret["maxmaxnll"] = maxnll
        ret["entropy"] = entropy
        ret["avgconf"] = avgconf
    else:
        # no MC-dropout: use the metrics from get_prediction as-is
        ret["decnll"] = prednll
        ret["sumnll"] = sumnll
        ret["maxmaxnll"] = maxmaxnll
        ret["entropy"] = entropy
        ret["avgconf"] = avgconf
    return ret, pred_trees
def forward(self,
            inpseqs: torch.Tensor = None,
            gold: torch.Tensor = None,
            mode: str = None,
            maxsteps: int = None,
            **kw):
    """Iteratively decode trees from ``inpseqs`` by repeated tag-and-execute steps.

    In training mode, gold trees are aligned to the working trees, gold
    actions are assigned/sampled each step and CE/recall metrics accumulated;
    in eval mode the argmax actions are taken. Decoding stops when all trees
    are terminated or ``maxsteps`` is reached.

    Returns:
        ``(ret, trees)`` — ``ret`` holds sequence/tree-size/step statistics,
        training losses (train mode), and "treeacc" when ``gold`` is given.
    """
    maxsteps = maxsteps if maxsteps is not None else self.maxsteps
    mode = mode if mode is not None else self.mode
    device = next(self.parameters()).device
    trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
    if gold is not None:
        goldtrees = [
            tensors_to_tree(seqe, D=self.seqenc.vocab)
            for seqe in list(gold)
        ]
        goldtrees = [
            add_descendants_ancestors(goldtree) for goldtree in goldtrees
        ]
        for i in range(len(trees)):
            trees[i].align = goldtrees[i]
    i = 0
    if self.training:
        trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
    # choices = [deepcopy(trees)]
    numsteps = [0 for _ in range(len(trees))]
    treesizes = [0 for _ in range(len(trees))]
    seqlens = [0 for _ in range(len(trees))]
    allmetrics = {}
    allmetrics["loss"] = []
    allmetrics["ce"] = []
    allmetrics["elemrecall"] = []
    allmetrics["allrecall"] = []
    allmetrics["anyrecall"] = []
    allmetrics["lowestentropyrecall"] = []
    examplemasks = []
    while not all([all_terminated(tree) for tree in trees]) and i < maxsteps:
        # go from tree to tensors,
        seq = []
        openmask = []
        stepgold = []
        for j, tree in enumerate(trees):
            if not all_terminated(tree):
                numsteps[j] += 1
            treesizes[j] = tree_size(tree)
            fltoks, openmask_e, stepgold_e = extract_info(tree)
            seqlens[j] = len(fltoks)
            seq_e = self.seqenc.convert(fltoks, return_what="tensor")
            seq.append(seq_e)
            openmask.append(torch.tensor(openmask_e))
            if self.training:
                # multi-hot gold-action matrix: (seqlen, vocsize)
                stepgold_e_tensor = torch.zeros(
                    seq_e.size(0), self.seqenc.vocab.number_of_ids())
                # FIX: position index renamed to "pos" — it previously
                # reused "j" and shadowed the outer example index.
                for pos, stepgold_e_i in enumerate(stepgold_e):
                    for golde in stepgold_e_i:
                        stepgold_e_tensor[pos, self.seqenc.vocab[golde]] = 1
                stepgold.append(stepgold_e_tensor)
        seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
        openmask = torch.stack(q.pad_tensors(openmask, 0, False),
                               0).to(device)
        if self.training:
            stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0),
                                   0).to(device)
        if self.training:
            # examples with no remaining gold actions are masked out of loss
            examplemask = (stepgold != 0).any(-1).any(-1)
            examplemasks.append(examplemask)
        # feed to tagger,
        probs = self.tagger(seq, openmask=openmask, **context)
        if self.training:
            # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
            ce = self.ce(probs, stepgold, mask=openmask)
            elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                probs, stepgold, mask=openmask)
            allmetrics["loss"].append(ce)
            allmetrics["ce"].append(ce)
            allmetrics["elemrecall"].append(elemrecall)
            allmetrics["allrecall"].append(allrecall)
            allmetrics["anyrecall"].append(anyrecall)
            allmetrics["lowestentropyrecall"].append(lowestentropyrecall)
        # get best predictions,
        _, best_actions = probs.max(-1)
        entropies = torch.softmax(probs, -1).clamp_min(1e-6)
        entropies = -(entropies * torch.log(entropies)).sum(-1)
        if self.training:
            # sample the next action from the gold-masked, renormalized
            # model distribution mixed with a uniform-over-gold distribution
            newprobs = torch.softmax(probs, -1) * stepgold  # mask using gold
            newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(
                1e-6)  # renormalize
            uniform = stepgold
            uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
            newprobs = newprobs * (1 - q.v(
                self.uniformfactor)) + uniform * q.v(self.uniformfactor)
            # positions with all-zero probs fall back to a point mass on id 0
            noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
            zeroprobs = torch.zeros_like(newprobs)
            zeroprobs[:, :, 0] = 1
            newprobs = newprobs * noprobsmask + (1 - noprobsmask) * zeroprobs
            sampled_gold_actions = torch.distributions.categorical.Categorical(probs=newprobs.view(-1, newprobs.size(-1)))\
                .sample().view(newprobs.size(0), newprobs.size(1))
            taken_actions = sampled_gold_actions
        else:
            taken_actions = best_actions
        # attach chosen actions and entropies to existing trees,
        for tree, actions_e, entropies_e in zip(trees, taken_actions,
                                                entropies):
            actions_e = list(actions_e.detach().cpu().numpy())
            actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
            entropies_e = list(entropies_e.detach().cpu().numpy())
            self.attach_info_to_tree(tree,
                                     _chosen_action=actions_e,
                                     _entropy=entropies_e)
        # trees = [tensors_to_tree(seqe, openmask=openmaske, actions=actione, D=self.seqenc.vocab, entropies=entropies_e)
        #          for seqe, openmaske, actione, entropies_e
        #          in zip(list(seq), list(openmask), list(taken_actions), list(entropies))]
        # and execute,
        trees_ = []
        for tree in trees:
            if tree_size(tree) < self.max_tree_size:
                markmode = mode if not self.training else "train-" + mode
                tree = mark_for_execution(tree,
                                          mode=markmode,
                                          entropylimit=self.entropylimit)
                budget = [self.max_tree_size - tree_size(tree)]
                if self.training:
                    tree, _ = convert_chosen_actions_from_str_to_node(tree)
                tree = execute_chosen_actions(tree,
                                              _budget=budget,
                                              mode=mode)
                if self.training:
                    tree = assign_gold_actions(tree, mode=mode)
            trees_.append(tree)
        trees = trees_
        i += 1
        # then repeat until all terminated
    # after done decoding, if gold is given, run losses, else return just predictions
    ret = {}
    ret["seqlens"] = torch.tensor(seqlens).float()
    ret["treesizes"] = torch.tensor(treesizes).float()
    ret["numsteps"] = torch.tensor(numsteps).float()
    if self.training:
        assert (len(examplemasks) > 0)
        assert (len(allmetrics["loss"]) > 0)
        # stack per-step metrics along dim 1 and average over valid steps
        allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
        examplemasks = torch.stack(examplemasks, 1).float()
        _allmetrics = {}
        for k, v in allmetrics.items():
            _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(
                1).clamp_min(1e-6)
        allmetrics = _allmetrics
        ret.update(allmetrics)
    if gold is not None:
        goldtrees = [
            tensors_to_tree(seqe, D=self.seqenc.vocab)
            for seqe in list(gold)
        ]
        goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
        predtrees = [simplify_tree_for_eval(x) for x in trees]
        ret["treeacc"] = [
            float(
                are_equal_trees(gold_tree,
                                pred_tree,
                                orderless=ORDERLESS,
                                unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(goldtrees, predtrees)
        ]
        ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)
    return ret, trees
def test_unordered_duplicate(self):
    """Multisets of children must match: duplicate counts matter even when orderless."""
    two_bo = lisp_to_tree("(and (wife BO) (wife BR) (wife BO))")
    two_br = lisp_to_tree("(and (wife BO) (wife BR) (wife BR))")
    print(are_equal_trees(two_bo, two_br))
    self.assertFalse(are_equal_trees(two_bo, two_br))
def test_ordered(self):
    """Trees under a non-orderless head with different child counts differ."""
    shorter = lisp_to_tree("(nand (wife BO) (spouse BO))")
    longer = lisp_to_tree("(nand (wife BO) (spouse BO) (child BO))")
    print(are_equal_trees(shorter, longer))
    self.assertFalse(are_equal_trees(shorter, longer))
def forward(self,
            inpseqs: torch.Tensor = None,
            y_in: torch.Tensor = None,
            gold: torch.Tensor = None,
            mode: str = None,
            maxsteps: int = None,
            **kw):
    """Iteratively decode trees (inference variant): tag open nodes, take the
    argmax actions, execute them, and repeat until all trees are terminated
    or ``maxsteps`` is reached. When ``gold`` is given, also computes the
    per-example "treeacc" metric.
    """
    maxsteps = maxsteps if maxsteps is not None else self.maxsteps
    mode = mode if mode is not None else self.mode
    device = next(self.parameters()).device
    trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=y_in)
    i = 0
    while not all([all_terminated(tree) for tree in trees]) and i < maxsteps:
        # go from tree to tensors,
        tensors = []
        masks = []
        for tree in trees:
            fltoks, openmask = extract_info(tree, nogold=True)
            seq = self.seqenc.convert(fltoks, return_what="tensor")
            tensors.append(seq)
            masks.append(torch.tensor(openmask))
        seq = torch.stack(q.pad_tensors(tensors, 0), 0).to(device)
        openmask = torch.stack(q.pad_tensors(masks, 0, False),
                               0).to(device)
        # feed to tagger,
        probs = self.tagger(seq, openmask=openmask, **context)
        # get best predictions,
        _, best = probs.max(-1)
        # per-position entropy of the action distribution
        entropies = torch.softmax(probs, -1).clamp_min(1e-6)
        entropies = -(entropies * torch.log(entropies)).sum(-1)
        # convert to trees,
        trees = [
            tensors_to_tree(seqe,
                            openmask=openmaske,
                            actions=beste,
                            D=self.seqenc.vocab,
                            entropies=entropies_e)
            for seqe, openmaske, beste, entropies_e in zip(
                list(seq), list(openmask), list(best), list(entropies))
        ]
        # and execute,
        trees_ = []
        for tree in trees:
            # trees at the size budget are passed through unexecuted
            if tree_size(tree) < self.max_tree_size:
                tree = mark_for_execution(tree, mode=mode)
                budget = [self.max_tree_size - tree_size(tree)]
                tree = execute_chosen_actions(tree,
                                              _budget=budget,
                                              mode=mode)
            trees_.append(tree)
        trees = trees_
        i += 1
        # then repeat until all terminated
    # after done decoding, if gold is given, run losses, else return just predictions
    ret = {}
    if gold is not None:
        goldtrees = [
            tensors_to_tree(seqe, D=self.seqenc.vocab)
            for seqe in list(gold)
        ]
        goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
        predtrees = [simplify_tree_for_eval(x) for x in trees]
        ret["treeacc"] = [
            float(
                are_equal_trees(gold_tree,
                                pred_tree,
                                orderless=ORDERLESS,
                                unktoken="@UNK@"))
            for gold_tree, pred_tree in zip(goldtrees, predtrees)
        ]
        ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)
    return ret, trees
def test_forward(self, x: torch.Tensor, gold: torch.Tensor = None
                 ):  # --> implement how decoder operates end-to-end
    """Decode ``x`` and score against ``gold`` with tree accuracy, BLEU and
    LCS-F1, plus teacher-forced loss/accuracy metrics.

    Returns:
        ``(ret, pred_trees)`` — ``ret`` maps metric names to tensors/values;
        ``pred_trees`` are the decoded trees (None where unparseable).
    """
    preds, prednll, maxmaxnll, stepsused = self.get_prediction(x)

    def tensor_to_trees(x, vocab: Vocab):
        """Decode id tensors into trees; unparseable sequences become None."""
        xstrs = [
            vocab.tostr(x[i]).replace("@START@", "") for i in range(len(x))
        ]
        # FIX: raw string for the regex — "\d" in a plain string is an
        # invalid escape sequence (SyntaxWarning on modern Python).
        xstrs = [re.sub(r"::\d+", "", xstr) for xstr in xstrs]
        trees = []
        for xstr in xstrs:
            # drop everything after @END@, if present
            xstr = xstr.split("@END@")[0]
            xstr = xstr.strip()
            # add an opening parenthesis if not there
            if len(xstr) == 0 or xstr[0] != "(":
                xstr = "(" + xstr
            # balance parentheses on both ends
            parenthese_imbalance = xstr.count("(") - xstr.count(")")
            # append missing closing parentheses
            xstr = xstr + ")" * max(0, parenthese_imbalance)
            # prepend missing opening parentheses
            xstr = "(" * -min(0, parenthese_imbalance) + xstr
            try:
                tree = taglisp_to_tree(xstr)
                # taglisp_to_tree may signal failure as a (None, ...) pair
                if isinstance(tree, tuple) and len(tree) == 2 \
                        and tree[0] is None:
                    tree = None
            except Exception:
                tree = None
            trees.append(tree)
        return trees

    # compute loss and metrics
    gold_trees = tensor_to_trees(gold, vocab=self.vocab)
    pred_trees = tensor_to_trees(preds, vocab=self.vocab)
    treeaccs = [
        float(
            are_equal_trees(gold_tree,
                            pred_tree,
                            orderless=ORDERLESS,
                            unktoken="@UNK@"))
        for gold_tree, pred_tree in zip(gold_trees, pred_trees)
    ]
    ret = {
        "treeacc": torch.tensor(treeaccs).to(x.device),
        "stepsused": stepsused
    }
    # compute bleu scores
    bleus = []
    lcsf1s = []
    for gold_tree, pred_tree in zip(gold_trees, pred_trees):
        if pred_tree is None or gold_tree is None:
            bleuscore = 0
            lcsf1 = 0
        else:
            gold_str = tree_to_lisp(gold_tree)
            pred_str = tree_to_lisp(pred_tree)
            bleuscore = sentence_bleu([gold_str.split(" ")],
                                      pred_str.split(" "))
            lcsn = lcs(gold_str, pred_str)
            lcsrec = lcsn / len(gold_str)
            lcsprec = lcsn / len(pred_str)
            # FIX: guard against division by zero when the common
            # subsequence is empty (lcsrec + lcsprec == 0 crashed before).
            denom = lcsrec + lcsprec
            lcsf1 = 2 * lcsrec * lcsprec / denom if denom > 0 else 0.0
        bleus.append(bleuscore)
        lcsf1s.append(lcsf1)
    bleus = torch.tensor(bleus).to(x.device)
    ret["bleu"] = bleus
    ret["lcsf1"] = torch.tensor(lcsf1s).to(x.device)
    # teacher-forced metrics on the gold sequence
    d, logits = self.train_forward(x, gold)
    nll, acc, elemacc = d["loss"], d["acc"], d["elemacc"]
    ret["nll"] = nll
    ret["acc"] = acc
    ret["elemacc"] = elemacc
    # d, logits = self.train_forward(x, preds[:, 1:])
    # decnll = d["loss"]
    # ret["decnll"] = decnll
    ret["decnll"] = prednll
    ret["maxmaxnll"] = maxmaxnll
    return ret, pred_trees