def flatten_tree(x: Tree):
    """Flatten a "@START@"-rooted tree into a depth-one tree.

    The single subtree under the root is linearized to lisp tokens and each
    token becomes a leaf child of the new root; bracket tokens are escaped
    with a leading '|' so they survive as node labels.
    """
    assert x.label() == "@START@"
    assert len(x) == 1
    tokens = tree_to_lisp_tokens(x[0])
    children = []
    for tok in tokens:
        # escape "(" / ")" so the label is not mistaken for structure
        label = "|" + tok if tok in "()" else tok
        children.append(Tree(label, []))
    return Tree(x.label(), children)
def tokenize_and_add_start(t, _domain, lexical=False):
    """Linearize tree ``t`` to lisp tokens and prepend a start token.

    The start token is "@LEX…@" when ``lexical`` is true, else "@START…@";
    the module-level flag ``add_domain_start`` decides whether the domain
    name is baked into the token.
    """
    tokens = tree_to_lisp_tokens(t)
    if lexical:
        starttok = f"@LEX/{_domain}@" if add_domain_start else "@LEX@"
    else:
        starttok = f"@START/{_domain}@" if add_domain_start else "@START@"
    return [starttok] + tokens
def tokenize_and_add_start(t):
    """Linearize tree ``t`` to lisp tokens, prefixed with "@START@"."""
    return ["@START@"] + tree_to_lisp_tokens(t)
def forward(self, inpseqs:torch.Tensor=None, gold:torch.Tensor=None, mode:str=None, maxsteps:int=None, **kw):
    """Iteratively decode trees from input sequences with the tagger.

    Starting from the tagger's initial trees, repeatedly (1) linearize every
    unfinished tree to tensors, (2) score token actions with ``self.tagger``,
    (3) attach chosen actions/entropies to the trees and (4) execute those
    actions, until every tree is terminated or ``maxsteps`` is reached.
    In training mode, gold actions are assigned to the trees and actions are
    sampled from a gold-masked, uniform-smoothed distribution; per-step
    CE/recall metrics are accumulated and masked-averaged into the result.

    :param inpseqs: input token-id tensor fed to the tagger's init state.
    :param gold:    gold output sequences; when given, trees are aligned to
                    the gold trees and "treeacc" is computed at the end.
    :param mode:    decoding mode string; defaults to ``self.mode``.
    :param maxsteps: decode-step budget; defaults to ``self.maxsteps``.
    :returns: ``(ret, trees)`` where ``ret`` maps metric names to tensors
              ("seqlens", "treesizes", "numsteps", training metrics, and
              "treeacc" when ``gold`` is given) and ``trees`` are the final
              decoded trees.
    """
    maxsteps = maxsteps if maxsteps is not None else self.maxsteps
    mode = mode if mode is not None else self.mode
    device = next(self.parameters()).device

    # initial per-example trees plus tagger context (reused at every step)
    trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)

    if gold is not None:
        # decode gold tensors back to trees and align each working tree to
        # its gold counterpart (used by gold-action assignment below)
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in list(gold)]
        goldtrees = [add_descendants_ancestors(goldtree) for goldtree in goldtrees]
        for i in range(len(trees)):
            trees[i].align = goldtrees[i]

    i = 0

    if self.training:
        trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
    # choices = [deepcopy(trees)]
    # per-example bookkeeping, overwritten each step for unfinished trees
    numsteps = [0 for _ in range(len(trees))]
    treesizes = [0 for _ in range(len(trees))]
    seqlens = [0 for _ in range(len(trees))]
    allmetrics = {}
    allmetrics["loss"] = []
    allmetrics["ce"] = []
    allmetrics["elemrecall"] = []
    allmetrics["allrecall"] = []
    allmetrics["anyrecall"] = []
    allmetrics["lowestentropyrecall"] = []
    examplemasks = []
    while not all([all_terminated(tree) for tree in trees]) and i < maxsteps:
        # go from tree to tensors,
        seq = []
        openmask = []
        stepgold = []
        for j, tree in enumerate(trees):
            if not all_terminated(tree):
                numsteps[j] += 1
            treesizes[j] = tree_size(tree)
            fltoks, openmask_e, stepgold_e = extract_info(tree)
            seqlens[j] = len(fltoks)
            seq_e = self.seqenc.convert(fltoks, return_what="tensor")
            seq.append(seq_e)
            openmask.append(torch.tensor(openmask_e))
            if self.training:
                # multi-hot gold-action matrix for this example:
                # one row per token position, one column per vocab id
                stepgold_e_tensor = torch.zeros(seq_e.size(0), self.seqenc.vocab.number_of_ids())
                # NOTE(review): this inner loop shadows the outer loop
                # variable ``j``; harmless here because all indexed uses of
                # the outer ``j`` happen above and enumerate reassigns it
                # next iteration, but worth renaming.
                for j, stepgold_e_i in enumerate(stepgold_e):
                    for golde in stepgold_e_i:
                        stepgold_e_tensor[j, self.seqenc.vocab[golde]] = 1
                stepgold.append(stepgold_e_tensor)
        # pad ragged per-example tensors to a batch
        seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
        openmask = torch.stack(q.pad_tensors(openmask, 0, False), 0).to(device)
        if self.training:
            stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0), 0).to(device)
        if self.training:
            # an example counts toward the loss only if it still has any
            # gold action this step
            examplemask = (stepgold != 0).any(-1).any(-1)
            examplemasks.append(examplemask)

        # feed to tagger,
        probs = self.tagger(seq, openmask=openmask, **context)

        if self.training:
            # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
            ce = self.ce(probs, stepgold, mask=openmask)
            elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(probs, stepgold, mask=openmask)
            allmetrics["loss"].append(ce)
            allmetrics["ce"].append(ce)
            allmetrics["elemrecall"].append(elemrecall)
            allmetrics["allrecall"].append(allrecall)
            allmetrics["anyrecall"].append(anyrecall)
            allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

        # get best predictions,
        _, best_actions = probs.max(-1)
        # per-position entropy of the predicted distribution (clamped so
        # log never sees exact zeros)
        entropies = torch.softmax(probs, -1).clamp_min(1e-6)
        entropies = - (entropies * torch.log(entropies)).sum(-1)

        if self.training:
            # sample the taken action from model probs restricted to gold
            # actions, smoothed toward uniform-over-gold by uniformfactor
            newprobs = torch.softmax(probs, -1) * stepgold    # mask using gold
            newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(1e-6)    # renormalize
            uniform = stepgold
            uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
            newprobs = newprobs * (1 - q.v(self.uniformfactor)) + uniform * q.v(self.uniformfactor)
            # positions with no gold action get a degenerate distribution
            # that puts all mass on id 0 so Categorical stays valid
            noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
            zeroprobs = torch.zeros_like(newprobs)
            zeroprobs[:, :, 0] = 1
            newprobs = newprobs * noprobsmask + (1 - noprobsmask) * zeroprobs
            sampled_gold_actions = torch.distributions.categorical.Categorical(probs=newprobs.view(-1, newprobs.size(-1)))\
                .sample().view(newprobs.size(0), newprobs.size(1))
            taken_actions = sampled_gold_actions
        else:
            taken_actions = best_actions

        # attach chosen actions and entropies to existing trees,
        for tree, actions_e, entropies_e in zip(trees, taken_actions, entropies):
            actions_e = list(actions_e.detach().cpu().numpy())
            actions_e = [self.seqenc.vocab(xe) for xe in actions_e]    # ids -> token strings
            entropies_e = list(entropies_e.detach().cpu().numpy())
            self.attach_info_to_tree(tree, _chosen_action=actions_e, _entropy=entropies_e)

        # trees = [tensors_to_tree(seqe, openmask=openmaske, actions=actione, D=self.seqenc.vocab, entropies=entropies_e)
        #          for seqe, openmaske, actione, entropies_e
        #          in zip(list(seq), list(openmask), list(taken_actions), list(entropies))]

        # and execute,
        trees_ = []
        for tree in trees:
            # trees that already hit the size cap are carried over unchanged
            if tree_size(tree) < self.max_tree_size:
                markmode = mode if not self.training else "train-" + mode
                tree = mark_for_execution(tree, mode=markmode, entropylimit=self.entropylimit)
                # remaining node budget, passed as a mutable one-element list
                budget = [self.max_tree_size - tree_size(tree)]
                if self.training:
                    tree, _ = convert_chosen_actions_from_str_to_node(tree)
                tree = execute_chosen_actions(tree, _budget=budget, mode=mode)
                if self.training:
                    tree = assign_gold_actions(tree, mode=mode)
            trees_.append(tree)
        trees = trees_
        i += 1
        # then repeat until all terminated

    # after done decoding, if gold is given, run losses, else return just predictions
    ret = {}
    ret["seqlens"] = torch.tensor(seqlens).float()
    ret["treesizes"] = torch.tensor(treesizes).float()
    ret["numsteps"] = torch.tensor(numsteps).float()

    if self.training:
        assert(len(examplemasks) > 0)
        assert(len(allmetrics["loss"]) > 0)
        # stack per-step metrics along dim 1 and average over steps,
        # counting only steps where the example still had gold actions
        allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
        examplemasks = torch.stack(examplemasks, 1).float()
        _allmetrics = {}
        for k, v in allmetrics.items():
            _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(1).clamp_min(1e-6)
        allmetrics = _allmetrics
        ret.update(allmetrics)

    if gold is not None:
        # rebuild both gold and predicted trees through the same
        # simplify -> tokens -> strip "|" escapes -> build_atree pipeline so
        # the equality check compares like with like
        goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in list(gold)]
        goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
        goldtree_tokens = [tree_to_lisp_tokens(xe) for xe in goldtrees]
        goldtree_tokens = [[xei.replace("|", "") for xei in xe] for xe in goldtree_tokens]
        goldtrees = [build_atree(xe) for xe in goldtree_tokens]
        predtrees = [simplify_tree_for_eval(x) for x in trees]
        predtree_tokens = [tree_to_lisp_tokens(xe) for xe in predtrees]
        predtree_tokens = [[xei.replace("|", "") for xei in xe] for xe in predtree_tokens]
        predtrees = [build_atree(xe) for xe in predtree_tokens]
        ret["treeacc"] = [float(are_equal_trees(gold_tree, pred_tree, orderless=ORDERLESS, unktoken="@UNK@"))
                          for gold_tree, pred_tree in zip(goldtrees, predtrees)]
        ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

    return ret, trees
def tokenize_and_add_start(t, _domain):
    """Linearize tree ``t`` to lisp tokens and prepend a start token,
    domain-qualified when the module-level ``add_domain_start`` flag is set."""
    if add_domain_start:
        starttok = f"@START/{_domain}@"
    else:
        starttok = "@START@"
    return [starttok] + tree_to_lisp_tokens(t)
def tree_to_seq(x: Tree):
    """Linearize tree ``x`` into a flat list of lisp tokens."""
    # BOS/EOS wrapping (["@BOS@"] + toks + ["@EOS@"]) is deliberately disabled
    return tree_to_lisp_tokens(x)
def tree_to_str(x: Tree):
    """Render tree ``x`` as a single string, marking structure with the
    '<[' and ']>' bracket tokens instead of plain square brackets."""
    joined = " ".join(tree_to_lisp_tokens(x, brackets="[]"))
    return joined.replace("[ ", "<[").replace("]", "]>")