def get_init_state(self, inpseqs=None, y_in=None) -> Tuple[ATree, Dict]:
        """
        Encode the input sequences and build the initial decoding state.

        :param inpseqs: batch of input token ids; positions equal to 0 are
                        masked out in the returned "encmask"
                        (assumes a (batch, seqlen) tensor -- TODO confirm)
        :param y_in:    must be None; starting from an existing decoding
                        state is not supported yet
        :return:        (trees, state) where ``trees`` is a list holding one
                        fresh "@START@" tree per batch element and ``state``
                        is a dict with the encodings under "enc" and the
                        non-zero mask under "encmask"
        """
        # starting decoding from non-vanilla is not supported yet
        assert y_in is None

        # encode inpseqs: mask of non-zero ids, first output of the BERT model
        encmask = inpseqs != 0
        bert_outputs = self.bert_model(inpseqs)
        encs = bert_outputs[0]
        if self.adapter is not None:
            encs = self.adapter(encs)

        # create trees: one empty "@START@" tree per batch element
        trees = []
        for _ in range(inpseqs.size(0)):
            trees.append(ATree("@START@", []))

        state = {"enc": encs, "encmask": encmask}
        return trees, state
def convert_chosen_actions_from_str_to_node(x: ATree, c=0):
    """
    Replace string-valued chosen actions by the matching entry of ``gold_actions``.

    For every open node whose ``_chosen_action`` is set, the first gold action
    that matches it (an ATree whose label equals the string, or an equal
    string) becomes the new ``_chosen_action``. The tree is processed
    recursively and modified in place.

    :param x:   tree whose chosen actions are strings
    :param c:   number of open nodes with a chosen action counted so far
    :return:    tuple (x, c) with the same tree and the updated count;
                an assertion fails if the count exceeds 1
    """
    if x.is_open and x._chosen_action is not None:
        wanted = x._chosen_action
        assert isinstance(wanted, str)
        for candidate in x.gold_actions:
            # gold actions are compared by label (trees) or by value (strings)
            if isinstance(candidate, ATree):
                key = candidate.label()
            elif isinstance(candidate, str):
                key = candidate
            else:
                continue
            if key == wanted:
                x._chosen_action = candidate
                break
        # assert(isinstance(x._chosen_action, ATree))
        c = c + 1

    # convert the children, threading the running count through the recursion
    converted = []
    for subtree in x:
        subtree, c = convert_chosen_actions_from_str_to_node(subtree, c)
        converted.append(subtree)
    x[:] = converted

    # at most one open node with a chosen action is allowed
    assert c <= 1
    return x, c
def assign_gold_actions(x: ATree, mode="default"):
    """
    Assigns the actions that can be taken at every node of the given tree.

    Children are processed before their parent. Every node gets a
    ``gold_actions`` attribute and a ``_chosen_action`` that is drawn at random
    from ``gold_actions`` (or None when there are no gold actions).

    :param x:       tree to annotate; modified in place
    :param mode:    "default" (all) or "ltr" (only first one)
    :return:        the same tree ``x``
    """

    def keep_only_first_if_ltr():
        # in "ltr" mode only the first available gold action is kept
        if mode == "ltr" and len(x.gold_actions) > 0:
            x.gold_actions = [x.gold_actions[0]]

    for subtree in x:
        assign_gold_actions(subtree, mode=mode)

    if not x.is_open:
        x.gold_actions = []
    else:
        label = x.label()
        if label == ")" or label == "(":
            # parentheses never get actions and are closed here
            x.is_open = False
            x.gold_actions = []
        elif label == "@SLOT@":
            siblings = x.parent
            if len(siblings) == 1:
                raise Exception()
            x.gold_actions = []

            # locate this slot's neighbours and the positions of their aligned nodes
            pos = child_number_of(x)
            has_left = pos != 0
            has_right = pos != len(siblings) - 1
            left = siblings[pos - 1] if has_left else None
            left_nr = child_number_of(left.align) if has_left else None
            right = siblings[pos + 1] if has_right else None
            right_nr = child_number_of(right.align) if has_right else None

            if left is None and right is None:
                # slot is only child, can use any descendant
                x.gold_actions = siblings.align.descendants
                keep_only_first_if_ltr()
                # should not happen if deletion actions are not used
                assert False
            else:
                # candidates are the aligned nodes lying strictly between the neighbours
                if left is not None:
                    p = left.align.parent
                else:
                    p = right.align.parent
                start = left_nr + 1 if left_nr is not None else None
                x.gold_actions = p[slice(start, right_nr)]
                keep_only_first_if_ltr()
            if len(x.gold_actions) == 0:
                x.gold_actions = ["@CLOSE@"]
        else:  # not a sibling slot ("@SLOT@"), not a "(" or ")"
            x.gold_actions = []
            if len(x) == 0:
                x.gold_actions = list(x.align._descendants)
                keep_only_first_if_ltr()
            else:
                real_children = [kid for kid in x if kid.label() != "@SLOT@"]
                shared_ancestors = real_children[0].align._ancestors[::-1]
                # all real children must have the same ancestors in the aligned tree
                assert all(shared_ancestors == kid.align._ancestors[::-1]
                           for kid in real_children)
                # collect the reversed ancestor list up to (excluding) this node's aligned node
                for ancestor in shared_ancestors:
                    if ancestor is x.align:
                        break
                    x.gold_actions.append(ancestor)
                keep_only_first_if_ltr()
        if len(x.gold_actions) == 0 and x.is_open:
            x.gold_actions = ["@CLOSE@"]

    # x._chosen_action = x.gold_actions[0]
    if len(x.gold_actions) > 0:
        x._chosen_action = random.choice(x.gold_actions)
    else:
        x._chosen_action = None
    return x
def load_ds(domain="restaurants",
            nl_mode="bert-base-uncased",
            trainonvalid=False,
            noreorder=False):
    """
    Loads one Overnight domain and converts every example to a pair of tensors.

    Every logical form is wrapped in an "@START@" root and, unless
    ``noreorder`` is set, passed through ``reorder_tree`` with the orderless
    operators below. Each split is then mapped to pairs of
    * the NL question encoded by the BERT tokenizer
    * the logical form's tokens encoded by the sequence encoder

    :param domain:          Overnight domain passed to the dataset loader
    :param nl_mode:         name of the pretrained BERT tokenizer to load
    :param trainonvalid:    passed through to the dataset loader
    :param noreorder:       if True, trees are not passed through ``reorder_tree``
    :return:    tuple (train, valid, test, nl_tokenizer, seqenc, orderless),
                where the first three are the mapped splits
    """
    orderless = {"op:and", "SW:concat"}  # only use in eval!!

    ds = OvernightDatasetLoader().load(domain=domain,
                                       trainonvalid=trainonvalid)
    ds = ds.map(lambda x: (x[0], ATree("@START@", [x[1]]), x[2]))

    if not noreorder:
        ds = ds.map(lambda x:
                    (x[0], reorder_tree(x[1], orderless=orderless), x[2]))

    vocab = Vocab(padid=0, startid=2, endid=3, unkid=1)
    # Special tokens are registered in this fixed order. np.inf is used because
    # the np.infty alias no longer exists in NumPy >= 2.0.
    special_tokens = (
        "@START@",
        "@CLOSE@",          # only here for the action of closing an open position, will not be seen at input
        "@OPEN@",           # only here for the action of opening a closed position, will not be seen at input
        "@REMOVE@",         # only here for deletion operations, won't be seen at input
        "@REMOVESUBTREE@",  # only here for deletion operations, won't be seen at input
        "@SLOT@",           # will be seen at input, can't be produced!
    )
    for special_token in special_tokens:
        vocab.add_token(special_token, seen=np.inf)

    # the wordpieces of the NL tokenizer are not added to `vocab`
    nl_tokenizer = BertTokenizer.from_pretrained(nl_mode)

    tds = ds[lambda x: x[2] == "train"]
    vds = ds[lambda x: x[2] == "valid"]
    xds = ds[lambda x: x[2] == "test"]

    seqenc = SequenceEncoder(
        vocab=vocab,
        tokenizer=lambda x: extract_info(x, onlytokens=True),
        add_start_token=False,
        add_end_token=False)
    # queries of the training split are registered with seen=True,
    # those of the validation and test splits with seen=False
    for split, seen in ((tds, True), (vds, False), (xds, False)):
        for example in split.examples:
            seqenc.inc_build_vocab(example[1], seen=seen)
    seqenc.finalize_vocab(min_freq=0)

    def mapper(x):
        # x[0] is the NL question, x[1] the logical form tree
        fltoks = extract_info(x[1], onlytokens=True)
        seq = seqenc.convert(fltoks, return_what="tensor")
        return nl_tokenizer.encode(x[0], return_tensors="pt")[0], seq

    tds_seq = tds.map(mapper)
    vds_seq = vds.map(mapper)
    xds_seq = xds.map(mapper)
    return tds_seq, vds_seq, xds_seq, nl_tokenizer, seqenc, orderless
def execute_chosen_actions(x: ATree, _budget=None, mode="full"):
    """
    Executes the chosen action of every open node of the tree, in place.

    * Nodes that are not open or have no chosen action are left unchanged;
      only their children are processed.
    * Open "(" and ")" nodes with a chosen action are left untouched and their
      children are not visited.
    * An open "@SLOT@" is either removed from its parent ("@CLOSE@") or turned
      into the chosen node, with new slots inserted next to it.
    * Any other open node is either closed ("@CLOSE@") or gets a new node
      inserted between itself and its current children.

    :param x:       tree to modify
    :param _budget: single-element list used as a counter shared by all
                    recursive calls: inserting slots/nodes decreases it,
                    closing a slot increases it, and insertions are skipped
                    once it is <= 0. Defaults to a fresh unlimited budget
                    for every top-level call.
    :param mode:    in "ltr" mode no slot is inserted to the left of a new
                    node and a node is closed once it received a child
    :return:        the same tree ``x``
    """
    if _budget is None:
        # Created per call: a list default would be shared between calls, and
        # np.infty is not available in NumPy >= 2.0.
        _budget = [np.inf]

    if x._chosen_action is None or not x.is_open:
        # nothing to execute here, descend (over a copy of the child list)
        for xe in list(x):
            execute_chosen_actions(xe, _budget=_budget, mode=mode)
        return x

    label = x.label()
    if label == "(" or label == ")":
        return x

    if label == "@SLOT@":
        if x._chosen_action == "@CLOSE@":
            del x.parent[child_number_of(x)]
            _budget[0] += 1
            # if parentheses became empty, remove
            if len(x.parent) == 2 and x.parent[0].label() == "(" \
                    and x.parent[1].label() == ")":
                x.parent[:] = []
            return x

        if _budget[0] <= 0:
            return x

        # the slot itself becomes the chosen node
        action = x._chosen_action
        if isinstance(action, Tree):
            x.set_label(action.label())
            x.align = action
        else:
            x.set_label(action)
        x.is_open = True

        if mode != "ltr":
            leftslot = ATree("@SLOT@", [], is_open=True)
            leftslot.parent = x.parent
            x.parent.insert(child_number_of(x), leftslot)
            _budget[0] -= 1

        rightslot = ATree("@SLOT@", [], is_open=True)
        rightslot.parent = x.parent
        x.parent.insert(child_number_of(x) + 1, rightslot)
        _budget[0] -= 1
        return x

    # regular node: children first, then the node's own action
    for xe in list(x):
        execute_chosen_actions(xe, _budget=_budget, mode=mode)
    if _budget[0] <= 0:
        return x

    if x._chosen_action == "@CLOSE@":
        x.is_open = False  # this node can't generate children anymore
        return x

    # X(A, B, C) -> X(@SLOT@, Y(A, B, C), @SLOT@); no left slot in "ltr" mode
    action = x._chosen_action
    if isinstance(action, Tree):
        newnode = ATree(action.label(), [])
    else:
        newnode = ATree(action, [])
    newnode.is_open = True
    if mode == "ltr":
        x.is_open = False
    if isinstance(action, Tree):
        newnode.align = action
    newnode.parent = x

    # the new node takes over all current children of x
    newnode[:] = x[:]
    for xe in newnode:
        xe.parent = newnode

    rightslot = ATree("@SLOT@", [], is_open=True)
    rightslot.parent = x

    if mode != "ltr":
        leftslot = ATree("@SLOT@", [], is_open=True)
        leftslot.parent = x
        x[:] = [leftslot, newnode, rightslot]
        _budget[0] -= 3
    else:
        x[:] = [newnode, rightslot]
        _budget[0] -= 2
    return x