def mapper(x):
    """Encode one example as a pair ``(nl_ids, fl_ids)``.

    ``x[0]`` is the natural-language input and ``x[1]`` the logical form.
    The logical form is tokenized with ``extract_info(..., onlytokens=True)``
    and converted to a tensor by ``seqenc``; the natural-language input is
    encoded with ``nl_tokenizer`` and the first row of the result is kept.
    """
    question, logical_form = x[0], x[1]
    fl_tokens = extract_info(logical_form, onlytokens=True)
    fl_ids = seqenc.convert(fl_tokens, return_what="tensor")
    nl_ids = nl_tokenizer.encode(question, return_tensors="pt")[0]
    return (nl_ids, fl_ids)
    def forward(self,
                inpseqs: torch.Tensor = None,
                gold: torch.Tensor = None,
                mode: str = None,
                maxsteps: int = None,
                **kw):
        """
        Decode one tree per example, step by step, and collect metrics.

        Starting from the trees returned by ``self.tagger.get_init_state``,
        every step:

        * serializes each tree with ``extract_info`` and pads the batch,
        * runs ``self.tagger`` on the batch (its output is treated as
          unnormalized scores: softmax is applied here),
        * picks one action per position -- sampled among the gold actions
          when ``self.training`` is set, the arg-max otherwise,
        * attaches the chosen actions and per-position entropies to the trees
          and executes them on every tree smaller than ``self.max_tree_size``.

        Decoding stops once ``all_terminated`` holds for every tree or after
        ``maxsteps`` steps.

        Args:
            inpseqs: forwarded to ``self.tagger.get_init_state``.
            gold: iterated along its first dimension; every element is turned
                into a gold tree with ``tensors_to_tree``. When given, the gold
                trees are attached to the decoded trees as ``.align`` and
                ``"treeacc"`` is reported.
                NOTE(review): training mode reads gold actions from the trees,
                so it presumably requires ``gold`` -- confirm with callers.
            mode: overrides ``self.mode``; passed on to ``assign_gold_actions``,
                ``mark_for_execution`` and ``execute_chosen_actions``.
            maxsteps: overrides ``self.maxsteps``.
            **kw: accepted and ignored.

        Returns:
            A tuple ``(ret, trees)``. ``trees`` are the trees after the last
            step. ``ret`` maps:

            * ``"seqlens"``, ``"treesizes"``, ``"numsteps"``: float tensors with
              one entry per tree. Lengths and sizes are the values measured at
              the start of the last executed step; ``numsteps`` counts the
              steps at whose start the tree was not yet terminated.
            * in training mode, ``"loss"``, ``"ce"``, ``"elemrecall"``,
              ``"allrecall"``, ``"anyrecall"``, ``"lowestentropyrecall"``:
              per-example averages over the steps in which the example had at
              least one gold action (assumes ``self.ce`` / ``self.recall``
              return one value per example -- TODO confirm).
            * if ``gold`` is given, ``"treeacc"``: 1.0 / 0.0 per example from
              ``are_equal_trees`` on the simplified trees.
        """
        maxsteps = maxsteps if maxsteps is not None else self.maxsteps
        mode = mode if mode is not None else self.mode
        device = next(self.parameters()).device

        trees, context = self.tagger.get_init_state(inpseqs=inpseqs, y_in=None)
        if gold is not None:
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in gold
            ]
            goldtrees = [
                add_descendants_ancestors(goldtree) for goldtree in goldtrees
            ]
            # Every decoded tree carries the gold tree it is aligned against.
            for idx, tree in enumerate(trees):
                tree.align = goldtrees[idx]

        if self.training:
            trees = [assign_gold_actions(tree, mode=mode) for tree in trees]

        batsize = len(trees)
        numsteps = [0] * batsize
        treesizes = [0] * batsize
        seqlens = [0] * batsize
        metric_names = ("loss", "ce", "elemrecall", "allrecall", "anyrecall",
                        "lowestentropyrecall")
        allmetrics = {name: [] for name in metric_names}
        examplemasks = []

        step = 0
        while not all(all_terminated(tree)
                      for tree in trees) and step < maxsteps:
            # go from tree to tensors,
            seqs = []
            openmasks = []
            stepgolds = []
            for j, tree in enumerate(trees):
                if not all_terminated(tree):
                    numsteps[j] += 1
                treesizes[j] = tree_size(tree)

                fltoks, openmask_e, stepgold_e = extract_info(tree)
                seqlens[j] = len(fltoks)
                seq_e = self.seqenc.convert(fltoks, return_what="tensor")
                seqs.append(seq_e)
                openmasks.append(torch.tensor(openmask_e))
                if self.training:
                    # multi-hot gold: one row per position, one column per
                    # vocabulary id, 1 where the action is a gold action
                    stepgold_e_tensor = torch.zeros(
                        seq_e.size(0), self.seqenc.vocab.number_of_ids())
                    for pos, gold_actions in enumerate(stepgold_e):
                        for gold_action in gold_actions:
                            stepgold_e_tensor[
                                pos, self.seqenc.vocab[gold_action]] = 1
                    stepgolds.append(stepgold_e_tensor)
            seq = torch.stack(q.pad_tensors(seqs, 0, 0), 0).to(device)
            openmask = torch.stack(q.pad_tensors(openmasks, 0, False),
                                   0).to(device)
            if self.training:
                stepgold = torch.stack(q.pad_tensors(stepgolds, 0, 0),
                                       0).to(device)
                # True for examples that still have any gold action this step
                examplemasks.append((stepgold != 0).any(-1).any(-1))

            #  feed to tagger,
            probs = self.tagger(seq, openmask=openmask, **context)

            if self.training:
                # stepgold is (batsize, seqlen, vocsize) with zeros and ones (ones for good actions)
                ce = self.ce(probs, stepgold, mask=openmask)
                elemrecall, allrecall, anyrecall, lowestentropyrecall = self.recall(
                    probs, stepgold, mask=openmask)
                allmetrics["loss"].append(ce)
                allmetrics["ce"].append(ce)
                allmetrics["elemrecall"].append(elemrecall)
                allmetrics["allrecall"].append(allrecall)
                allmetrics["anyrecall"].append(anyrecall)
                allmetrics["lowestentropyrecall"].append(lowestentropyrecall)

            #  per-position entropy of the predicted distribution,
            clamped = torch.softmax(probs, -1).clamp_min(1e-6)
            entropies = -(clamped * torch.log(clamped)).sum(-1)

            #  choose the actions to take,
            if self.training:
                # predicted distribution restricted to gold actions, renormalized
                goldprobs = torch.softmax(probs, -1) * stepgold
                goldprobs = goldprobs / goldprobs.sum(-1)[:, :, None].clamp_min(
                    1e-6)
                # uniform distribution over the gold actions
                uniform = stepgold / stepgold.sum(-1)[:, :, None].clamp_min(
                    1e-6)
                mix = q.v(self.uniformfactor)
                newprobs = goldprobs * (1 - mix) + uniform * mix

                # Positions without any gold action would give an all-zero
                # row; put all mass on vocabulary index 0 there so that
                # Categorical receives a valid distribution.
                hasgold = (newprobs != 0).any(-1, keepdim=True).float()
                fallback = torch.zeros_like(newprobs)
                fallback[:, :, 0] = 1
                newprobs = newprobs * hasgold + (1 - hasgold) * fallback

                flatprobs = newprobs.view(-1, newprobs.size(-1))
                taken_actions = torch.distributions.categorical.Categorical(
                    probs=flatprobs).sample().view(newprobs.size(0),
                                                   newprobs.size(1))
            else:
                _, taken_actions = probs.max(-1)

            #  attach chosen actions and entropies to existing trees,
            for tree, actions_e, entropies_e in zip(trees, taken_actions,
                                                    entropies):
                actions_e = list(actions_e.detach().cpu().numpy())
                actions_e = [self.seqenc.vocab(xe) for xe in actions_e]
                entropies_e = list(entropies_e.detach().cpu().numpy())
                self.attach_info_to_tree(tree,
                                         _chosen_action=actions_e,
                                         _entropy=entropies_e)

            #  and execute,
            next_trees = []
            for tree in trees:
                # trees that reached the size limit are carried over untouched
                if tree_size(tree) < self.max_tree_size:
                    markmode = "train-" + mode if self.training else mode
                    tree = mark_for_execution(tree,
                                              mode=markmode,
                                              entropylimit=self.entropylimit)
                    budget = [self.max_tree_size - tree_size(tree)]
                    if self.training:
                        tree, _ = convert_chosen_actions_from_str_to_node(tree)
                    tree = execute_chosen_actions(tree,
                                                  _budget=budget,
                                                  mode=mode)
                    if self.training:
                        tree = assign_gold_actions(tree, mode=mode)
                next_trees.append(tree)

            trees = next_trees
            step += 1
            #  then repeat until all terminated

        # after done decoding, if gold is given, run losses, else return just predictions
        ret = {
            "seqlens": torch.tensor(seqlens).float(),
            "treesizes": torch.tensor(treesizes).float(),
            "numsteps": torch.tensor(numsteps).float(),
        }

        if self.training:
            assert (len(examplemasks) > 0)
            assert (len(allmetrics["loss"]) > 0)
            examplemasks = torch.stack(examplemasks, 1).float()
            # number of contributing steps per example (guarded against zero)
            normalizer = examplemasks.sum(1).clamp_min(1e-6)
            for name, values in allmetrics.items():
                stacked = torch.stack(values, 1)
                ret[name] = (stacked * examplemasks).sum(1) / normalizer

        if gold is not None:
            # Gold trees are rebuilt from the tensors here instead of reusing
            # the ones that were attached to the decoded trees as `.align`.
            goldtrees = [
                tensors_to_tree(seqe, D=self.seqenc.vocab) for seqe in gold
            ]
            goldtrees = [simplify_tree_for_eval(x) for x in goldtrees]
            predtrees = [simplify_tree_for_eval(x) for x in trees]
            treeacc = [
                float(
                    are_equal_trees(gold_tree,
                                    pred_tree,
                                    orderless=ORDERLESS,
                                    unktoken="@UNK@"))
                for gold_tree, pred_tree in zip(goldtrees, predtrees)
            ]
            ret["treeacc"] = torch.tensor(treeacc).to(device)

        return ret, trees
def load_ds(domain="restaurants",
            nl_mode="bert-base-uncased",
            trainonvalid=False,
            noreorder=False):
    """
    Load one domain through ``OvernightDatasetLoader`` and encode it as tensors.

    Every loaded example is a triple ``(nl, fl_tree, split)``. The logical-form
    tree is wrapped under an ``"@START@"`` root and, unless ``noreorder`` is
    set, passed through ``reorder_tree`` with the orderless operators below.
    The examples are then split on ``split`` (``"train"`` / ``"valid"`` /
    ``"test"``) and each example is mapped to a pair
    ``(nl_ids, fl_ids)``:

    * ``nl_ids``: first row of ``nl_tokenizer.encode(nl, return_tensors="pt")``
    * ``fl_ids``: the tree's tokens (``extract_info(..., onlytokens=True)``)
      converted to a tensor by the sequence encoder

    The vocabulary is built from the trees of all three splits; only training
    trees are registered with ``seen=True``.

    Args:
        domain: forwarded to ``OvernightDatasetLoader.load``.
        nl_mode: name passed to ``BertTokenizer.from_pretrained``.
        trainonvalid: forwarded to ``OvernightDatasetLoader.load``.
        noreorder: if true, skip the ``reorder_tree`` pass.

    Returns:
        ``(tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless)`` -- the
        encoded train / valid / test datasets, the NL tokenizer, the sequence
        encoder for logical forms and the set of orderless operator names.
    """
    orderless = {"op:and", "SW:concat"}  # only use in eval!!

    ds = OvernightDatasetLoader().load(domain=domain,
                                       trainonvalid=trainonvalid)
    # put every logical form under an explicit "@START@" root
    ds = ds.map(lambda x: (x[0], ATree("@START@", [x[1]]), x[2]))

    if not noreorder:
        ds = ds.map(lambda x:
                    (x[0], reorder_tree(x[1], orderless=orderless), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    # Registration order is kept as-is because it determines the token ids.
    special_tokens = (
        "@START@",
        "@CLOSE@",  # only here for the action of closing an open position, will not be seen at input
        "@OPEN@",  # only here for the action of opening a closed position, will not be seen at input
        "@REMOVE@",  # only here for deletion operations, won't be seen at input
        "@REMOVESUBTREE@",  # only here for deletion operations, won't be seen at input
        "@SLOT@",  # will be seen at input, can't be produced!
    )
    for token in special_tokens:
        # np.inf rather than np.infty: the latter alias was removed in NumPy 2.0
        vocab.add_token(token, seen=np.inf)

    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)

    tds = ds[lambda x: x[2] == "train"]
    vds = ds[lambda x: x[2] == "valid"]
    xds = ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(
        vocab=vocab,
        tokenizer=lambda x: extract_info(x, onlytokens=True),
        add_start_token=False,
        add_end_token=False)
    # Only training trees are registered as seen.
    for split, seen in ((tds, True), (vds, False), (xds, False)):
        for example in split.examples:
            seqenc.inc_build_vocab(example[1], seen=seen)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        """Encode one ``(nl, fl_tree, ...)`` example as ``(nl_ids, fl_ids)``."""
        nl = x[0]
        fl = x[1]
        fltoks = extract_info(fl, onlytokens=True)
        seq = seqenc.convert(fltoks, return_what="tensor")
        return (nl_tokenizer.encode(nl, return_tensors="pt")[0], seq)

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless