def convert_chosen_actions_from_str_to_node(x: ATree, c=0):
    """Replace string-valued chosen actions by the matching entry of ``gold_actions``.

    The tree is walked recursively (node first, then its children). For every
    node that is open and has a non-None ``_chosen_action`` (which must be a
    string), the first element of ``x.gold_actions`` that matches it is stored
    as the new ``_chosen_action``:

    * an ``ATree`` gold action matches when its ``label()`` equals the string;
    * a string gold action (e.g. "@CLOSE@") matches when it equals the string.

    If nothing matches, the chosen action is left as the original string.

    :param x: tree to update in place (its child list is reassigned).
    :param c: running count of converted nodes, threaded through the recursion.
    :return: ``(x, count)`` where ``count`` is ``c`` plus the number of open
             nodes with a chosen action found in this subtree.
    :raises AssertionError: if a chosen action is not a string, or if the
             running count exceeds 1 (at most one such node is expected).
    """
    count = c
    if x.is_open and x._chosen_action is not None:
        wanted = x._chosen_action
        assert isinstance(wanted, str)
        for candidate in x.gold_actions:
            if isinstance(candidate, ATree):
                hit = candidate.label() == wanted
            elif isinstance(candidate, str):
                hit = candidate == wanted
            else:
                hit = False  # other kinds of gold actions are never matched
            if hit:
                x._chosen_action = candidate
                break
        count += 1

    new_children = []
    for sub in x:
        sub, count = convert_chosen_actions_from_str_to_node(sub, count)
        new_children.append(sub)
    x[:] = new_children

    assert count <= 1
    return x, count
# Example no. 2
# 0
def build_atree(x:Iterable[str], open:Iterable[bool]=None, chosen_actions:Iterable[str]=None, entropies=None):
    """Build an ATree from a flat, bracketed token sequence.

    Every token other than "(" and ")" becomes an ``ATree`` leaf. A bracketed
    group ``( head c1 c2 ... )`` is reduced into ``head`` with children
    ``c1, c2, ...``, and each child's ``parent`` attribute is set to ``head``.

    :param x:              tokens; any iterable. It is materialized once, so
                           generators and other one-shot iterators work.
    :param open:           per-token ``is_open`` flags, default all False.
                           Bracket tokens must not be marked open. (The name
                           shadows the builtin; it is kept for callers.)
    :param chosen_actions: per-token value stored as ``_chosen_action``,
                           default None.
    :param entropies:      per-token value stored as ``_entropy``, default 0.
    :return: the root ``ATree``, or None if the tokens do not describe exactly
             one well-formed tree or if anything fails while building it.

    NOTE: tokens are paired with ``open``/``chosen_actions``/``entropies`` via
    ``zip``. If any of those is shorter than ``x``, the extra tokens are
    silently dropped.
    """
    try:
        # Materialize once. x is traversed again for the default lists below,
        # which would otherwise exhaust a generator before the main loop.
        x = list(x)
        open = [False for _ in x] if open is None else open
        chosen_actions = [None for _ in x] if chosen_actions is None else chosen_actions
        entropies = [0 for _ in x] if entropies is None else entropies

        # 1) Tokens -> bracket strings and ATree leaves.
        nodes = []
        for xe, opene, chosen_action, entropy in zip(x, open, chosen_actions, entropies):
            if xe == "(" or xe == ")":
                # Explicit check (formerly an ``assert``) so it also holds under ``python -O``.
                if opene != False:
                    return None
                nodes.append(xe)
            else:
                a = ATree(xe, [], is_open=opene)
                a._chosen_action = chosen_action
                a._entropy = entropy
                nodes.append(a)

        # 2) Shift-reduce. Every item is shifted; a ")" immediately closes the
        #    group opened by the nearest "(" below it on the stack.
        stack = []
        for item in nodes:
            if not (isinstance(item, str) and item == ")"):
                stack.append(item)
                continue
            group = []
            while True:
                if not stack:
                    return None  # ")" without a matching "(" or an empty "( )" group
                top = stack.pop()
                if group and isinstance(top, str) and top == "(":
                    break
                group.append(top)
            group.reverse()  # back to left-to-right order: head, c1, c2, ...
            if any(isinstance(e, str) for e in group):
                return None  # a stray "(" would end up as head or child
            head, children = group[0], group[1:]
            head[:] = children
            for child in children:
                child.parent = head
            stack.append(head)

        # Exactly one finished tree must remain. A leftover bracket string
        # (e.g. input ["("]) is not a tree.
        if len(stack) == 1 and not isinstance(stack[0], str):
            return stack[0]
        return None
    except Exception:
        # Deliberate best-effort parse: any other unexpected error is also
        # reported to the caller as None.
        return None
def assign_gold_actions(x: ATree, mode="default"):
    """Assign the actions that can be taken at every node of the given tree.

    Children are processed before their parent. For every node this sets
    ``gold_actions``, then sets ``_chosen_action`` to an element drawn with
    ``random.choice(gold_actions)``, or to None when there are no actions.

    Closed nodes get ``[]``. For open nodes, actions are derived from the
    node's ``.align`` partner:

    * "(" / ")" nodes are closed and get ``[]``.
    * "@SLOT@" nodes get the children of a neighbour's aligned parent that lie
      strictly between the aligned positions of the slot's left and right
      neighbours (open-ended on a missing side).
    * other leaves get ``list(x.align._descendants)``.
    * other internal nodes get the elements of their non-slot children's
      ``align._ancestors[::-1]``, stopping at ``x.align`` (exclusive).

    Any open node left without actions gets ``["@CLOSE@"]``.

    :param x:    tree to annotate in place. Presumably its nodes carry
                 ``.align`` links into a gold tree (TODO confirm with callers).
    :param mode: "default" keeps all actions; "ltr" keeps only the first one.
    :return:     ``x`` itself.
    :raises Exception: if an open "@SLOT@" node is the only child of its parent.
    """
    # Post-order: annotate all children before their parent.
    for xe in x:
        assign_gold_actions(xe, mode=mode)
    if not x.is_open:
        x.gold_actions = []
    else:
        if x.label() == ")" or x.label() == "(":
            # Bracket nodes are closed and offer no actions.
            x.is_open = False
            x.gold_actions = []
        elif x.label() == "@SLOT@":
            if len(x.parent) == 1:
                raise Exception("@SLOT@ node has no siblings; cannot derive its gold actions")
            # Locate this slot's siblings and their positions under their aligned parent.
            x.gold_actions = []
            xpos = child_number_of(x)
            if xpos == 0:
                leftsibling = None
                leftsibling_nr = None
            else:
                leftsibling = x.parent[xpos - 1]
                leftsibling_nr = child_number_of(leftsibling.align)
            if xpos == len(x.parent) - 1:
                rightsibling = None
                rightsibling_nr = None
            else:
                rightsibling = x.parent[xpos + 1]
                rightsibling_nr = child_number_of(rightsibling.align)

            # The slot is not an only child (rejected above), so at least one
            # neighbour exists and the "no neighbours" case cannot occur.
            p = leftsibling.align.parent if leftsibling is not None else rightsibling.align.parent
            slicefrom = leftsibling_nr + 1 if leftsibling_nr is not None else None
            x.gold_actions = p[slice(slicefrom, rightsibling_nr)]
            if mode == "ltr" and len(x.gold_actions) > 0:
                x.gold_actions = [x.gold_actions[0]]
        else:  # neither an "@SLOT@" nor a bracket
            x.gold_actions = []
            if len(x) == 0:
                # Leaf: all descendants of the aligned node are candidate actions.
                x.gold_actions = list(x.align._descendants)
                if mode == "ltr" and len(x.gold_actions) > 0:
                    x.gold_actions = [x.gold_actions[0]]
            else:
                # NOTE(review): raises IndexError if every child is an "@SLOT@".
                realchildren = [xe for xe in x if xe.label() != "@SLOT@"]
                childancestors = realchildren[0].align._ancestors[::-1]
                # All real children must share the same aligned ancestor chain.
                for child in realchildren:
                    assert (childancestors == child.align._ancestors[::-1])
                # Collect ancestors until x's own aligned node is reached.
                for ancestor in childancestors:
                    if ancestor is x.align:
                        break
                    x.gold_actions.append(ancestor)
                if mode == "ltr" and len(x.gold_actions) > 0:
                    x.gold_actions = [x.gold_actions[0]]
        # Open nodes without candidate actions get the "@CLOSE@" action.
        if len(x.gold_actions) == 0 and x.is_open:
            x.gold_actions = ["@CLOSE@"]

    if len(x.gold_actions) > 0:
        x._chosen_action = random.choice(x.gold_actions)
    else:
        x._chosen_action = None
    return x