Exemplo n.º 1
0
def flatten_tree(x: Tree):
    """Flatten a ``@START@``-rooted tree into a single-level tree.

    The only child of ``x`` is serialized with ``tree_to_lisp_tokens`` and
    every resulting token becomes a childless node under a new root carrying
    ``x``'s label. The bracket tokens ``"("`` and ``")"`` are renamed to
    ``"|("`` and ``"|)"``; all other tokens are kept verbatim.

    :param x: tree whose root label is ``"@START@"`` and which has exactly
        one child (checked with ``assert``).
    :return: a new ``Tree`` with the same root label and one childless node
        per token.
    """
    assert x.label() == "@START@"
    assert len(x) == 1
    xstr = tree_to_lisp_tokens(x[0])
    # Compare against the bracket tokens exactly: a substring test such as
    # `xe in "()"` would also match tokens like "" or "()".
    brackets = ("(", ")")
    nodes = [Tree("|" + xe if xe in brackets else xe, []) for xe in xstr]
    return Tree(x.label(), nodes)
Exemplo n.º 2
0
 def tokenize_and_add_start(t, _domain, lexical=False):
     """Linearize ``t`` into lisp tokens and prepend a start marker.

     The marker is ``@LEX@`` when ``lexical`` is truthy and ``@START@``
     otherwise. When ``add_domain_start`` (a free name, not defined in this
     function) is truthy, the domain is embedded, e.g. ``@START/<_domain>@``.

     :param t: tree to linearize with ``tree_to_lisp_tokens``.
     :param _domain: domain name inserted into the marker.
     :param lexical: choose the ``@LEX...@`` marker instead of ``@START...@``.
     :return: list of tokens with the marker first.
     """
     lisp_tokens = tree_to_lisp_tokens(t)
     if lexical:
         marker = f"@LEX/{_domain}@" if add_domain_start else "@LEX@"
     else:
         marker = f"@START/{_domain}@" if add_domain_start else "@START@"
     return [marker] + lisp_tokens
Exemplo n.º 3
0
 def tokenize_and_add_start(t):
     """Return the lisp tokens of ``t`` preceded by the ``@START@`` marker."""
     return ["@START@"] + tree_to_lisp_tokens(t)
Exemplo n.º 4
0
    def forward(self, inpseqs:torch.Tensor=None, gold:torch.Tensor=None,
                mode:str=None, maxsteps:int=None, **kw):
        """
        Iteratively decode a batch of trees.

        Each step does the following:
          1. Linearize every tree with ``extract_info``.
          2. Pad the batch and run ``self.tagger``.
          3. Choose one action per position.
          4. Attach the actions and their entropies to the trees.
          5. Execute the actions.

        Decoding stops when every tree satisfies ``all_terminated`` or after
        ``maxsteps`` steps.

        In training mode the taken actions are sampled from the tagger's
        softmax restricted to gold actions. That distribution is mixed with a
        uniform distribution over gold actions, weighted by
        ``self.uniformfactor``. Outside training the argmax action is taken.

        :param inpseqs: input sequences, passed to ``self.tagger.get_init_state``.
        :param gold: optional gold output sequences. Each row is converted with
            ``tensors_to_tree``; the results are attached to the decoding trees
            as ``.align`` and later used to compute ``"treeacc"``.
        :param mode: decoding mode; defaults to ``self.mode``.
        :param maxsteps: maximum number of decoding steps; defaults to
            ``self.maxsteps``.
        :param kw: accepted but unused.
        :return: ``(ret, trees)``.

            ``ret`` always holds float tensors ``"seqlens"``, ``"treesizes"``
            and ``"numsteps"`` with one entry per tree.

            In training mode it also holds ``"loss"``, ``"ce"``,
            ``"elemrecall"``, ``"allrecall"``, ``"anyrecall"`` and
            ``"lowestentropyrecall"``. Each is averaged over decoding steps,
            counting only steps where an example had at least one gold action.

            When ``gold`` is given it also holds ``"treeacc"``.

            ``trees`` are the final decoded trees.
        """
        maxsteps = maxsteps if maxsteps is not None else self.maxsteps
        mode = mode if mode is not None else self.mode
        device = next(self.parameters()).device

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
        if gold is not None:
            # attach each gold tree (with descendant/ancestor info) to its decoding tree
            goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in list(gold)]
            goldtrees = [add_descendants_ancestors(goldtree) for goldtree in goldtrees]
            for tree_idx, tree in enumerate(trees):
                tree.align = goldtrees[tree_idx]

        step = 0

        if self.training:
            trees = [assign_gold_actions(tree, mode=mode) for tree in trees]
        numsteps = [0 for _ in range(len(trees))]
        treesizes = [0 for _ in range(len(trees))]
        seqlens = [0 for _ in range(len(trees))]
        metricnames = ["loss", "ce", "elemrecall", "allrecall", "anyrecall", "lowestentropyrecall"]
        allmetrics = {metricname: [] for metricname in metricnames}
        examplemasks = []
        while not all(all_terminated(tree) for tree in trees) and step < maxsteps:
            # go from tree to tensors
            seq = []
            openmask = []
            stepgold = []
            for j, tree in enumerate(trees):
                if not all_terminated(tree):
                    numsteps[j] += 1
                treesizes[j] = tree_size(tree)

                fltoks, openmask_e, stepgold_e = extract_info(tree)
                seqlens[j] = len(fltoks)
                seq_e = self.seqenc.convert(fltoks, return_what="tensor")
                seq.append(seq_e)
                openmask.append(torch.tensor(openmask_e))
                if self.training:
                    # multi-hot gold matrix: one row per position of seq_e,
                    # a 1 in the column of every gold action at that position.
                    # `pos` must not reuse `j`, which indexes the current tree.
                    stepgold_e_tensor = torch.zeros(seq_e.size(0), self.seqenc.vocab.number_of_ids())
                    for pos, stepgold_e_i in enumerate(stepgold_e):
                        for golde in stepgold_e_i:
                            stepgold_e_tensor[pos, self.seqenc.vocab[golde]] = 1
                    stepgold.append(stepgold_e_tensor)
            seq = torch.stack(q.pad_tensors(seq, 0, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(openmask, 0, False), 0).to(device)
            if self.training:
                stepgold = torch.stack(q.pad_tensors(stepgold, 0, 0), 0).to(device)
                # an example only counts towards this step's metrics if it has any gold action
                examplemask = (stepgold != 0).any(-1).any(-1)
                examplemasks.append(examplemask)

            # feed to tagger
            probs = self.tagger(seq, openmask=openmask, **context)

            if self.training:
                # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
                ce = self.ce(probs, stepgold, mask=openmask)
                elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(probs, stepgold, mask=openmask)
                allmetrics["loss"].append(ce)
                allmetrics["ce"].append(ce)
                allmetrics["elemrecall"].append(elemrecall)
                allmetrics["allrecall"].append(allrecall)
                allmetrics["anyrecall"].append(anyrecall)
                allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

            # get best predictions and the per-position entropy of the softmax
            # (probabilities clamped at 1e-6 so log() stays finite).
            # NOTE(review): `probs` is passed through softmax here, so it looks
            # like unnormalized scores despite its name — confirm.
            _, best_actions = probs.max(-1)
            entropies = torch.softmax(probs, -1).clamp_min(1e-6)
            entropies = - (entropies * torch.log(entropies)).sum(-1)

            if self.training:
                newprobs = torch.softmax(probs, -1) * stepgold              # mask using gold
                newprobs = newprobs / newprobs.sum(-1)[:, :, None].clamp_min(1e-6)  # renormalize
                uniform = stepgold
                uniform = uniform / uniform.sum(-1)[:, :, None].clamp_min(1e-6)
                newprobs = newprobs * (1 - q.v(self.uniformfactor)) + uniform * q.v(self.uniformfactor)

                # positions without any gold action get all mass on id 0 so
                # that Categorical receives a valid distribution everywhere
                noprobsmask = (newprobs != 0).any(-1, keepdim=True).float()
                zeroprobs = torch.zeros_like(newprobs)
                zeroprobs[:, :, 0] = 1
                newprobs = newprobs * noprobsmask + (1-noprobsmask) * zeroprobs

                sampled_gold_actions = torch.distributions.categorical.Categorical(probs=newprobs.view(-1, newprobs.size(-1)))\
                    .sample().view(newprobs.size(0), newprobs.size(1))
                taken_actions = sampled_gold_actions
            else:
                taken_actions = best_actions

            # attach chosen actions (as vocabulary strings) and entropies to existing trees
            for tree, actions_e, entropies_e in zip(trees, taken_actions, entropies):
                actions_e = list(actions_e.detach().cpu().numpy())
                actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
                entropies_e = list(entropies_e.detach().cpu().numpy())
                self.attach_info_to_tree(tree,
                    _chosen_action=actions_e,
                    _entropy=entropies_e)

            # and execute; trees that reached max_tree_size are kept unchanged
            trees_ = []
            for tree in trees:
                if tree_size(tree) < self.max_tree_size:
                    markmode = mode if not self.training else "train-"+mode
                    tree = mark_for_execution(tree, mode=markmode, entropylimit=self.entropylimit)
                    budget = [self.max_tree_size - tree_size(tree)]
                    if self.training:
                        tree, _ = convert_chosen_actions_from_str_to_node(tree)
                    tree = execute_chosen_actions(tree, _budget=budget, mode=mode)
                    if self.training:
                        tree = assign_gold_actions(tree, mode=mode)
                trees_.append(tree)

            trees = trees_
            step += 1
            # then repeat until all terminated

        # after done decoding, if gold is given, run losses, else return just predictions

        ret = {}

        ret["seqlens"] = torch.tensor(seqlens).float()
        ret["treesizes"] = torch.tensor(treesizes).float()
        ret["numsteps"] = torch.tensor(numsteps).float()

        if self.training:
            assert(len(examplemasks) > 0)
            assert(len(allmetrics["loss"]) > 0)
            # stack per-step metrics along dim 1, then average each metric over
            # the steps where the example had gold (mask from examplemasks)
            allmetrics = {k: torch.stack(v, 1) for k, v in allmetrics.items()}
            examplemasks = torch.stack(examplemasks, 1).float()
            _allmetrics = {}
            for k, v in allmetrics.items():
                _allmetrics[k] = (v * examplemasks).sum(1) / examplemasks.sum(1).clamp_min(1e-6)
            allmetrics = _allmetrics
            ret.update(allmetrics)

        if gold is not None:
            # compare gold and predicted trees after simplification and removal of "|" markers
            goldtrees = [tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in list(gold)]
            goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
            goldtree_tokens = [tree_to_lisp_tokens(xe) for xe in goldtrees]
            goldtree_tokens = [[xei.replace("|", "") for xei in xe] for xe in goldtree_tokens]
            goldtrees = [build_atree(xe) for xe in goldtree_tokens]
            predtrees = [simplify_tree_for_eval(x) for x in trees]
            predtree_tokens = [tree_to_lisp_tokens(xe) for xe in predtrees]
            predtree_tokens = [[xei.replace("|", "") for xei in xe] for xe in predtree_tokens]
            predtrees = [build_atree(xe) for xe in predtree_tokens]
            ret["treeacc"] = [float(are_equal_trees(gold_tree, pred_tree,
                            orderless=ORDERLESS, unktoken="@UNK@"))
                   for gold_tree, pred_tree in zip(goldtrees, predtrees)]
            ret["treeacc"] = torch.tensor(ret["treeacc"]).to(device)

        return ret, trees
 def tokenize_and_add_start(t, _domain):
     """Linearize ``t`` into lisp tokens and prepend a start marker.

     The marker is ``@START/<_domain>@`` when ``add_domain_start`` (a free
     name, not defined in this function) is truthy, else ``@START@``.
     """
     lisp_tokens = tree_to_lisp_tokens(t)
     if add_domain_start:
         marker = f"@START/{_domain}@"
     else:
         marker = "@START@"
     return [marker] + lisp_tokens
Exemplo n.º 6
0
def tree_to_seq(x: Tree):
    """Serialize ``x`` into its flat list of lisp tokens.

    No begin/end-of-sequence markers are added.
    """
    return tree_to_lisp_tokens(x)
Exemplo n.º 7
0
 def tree_to_str(x:Tree):
     """Render ``x`` as a single string of square-bracketed lisp tokens.

     Tokens are space-joined. Every ``"[ "`` becomes ``"<["`` (the space after
     the bracket is dropped), and every ``"]"`` becomes ``"]>"``.
     """
     bracketed = tree_to_lisp_tokens(x, brackets="[]")
     joined = " ".join(bracketed)
     opened = joined.replace("[ ", "<[")
     return opened.replace("]", "]>")