Beispiel #1
0
def tree_search_actions(state: State, depth: int, factor=0.22, occupation_threshold=0.0):
    """
    Return the indices of the best-scoring children of ``state``.

    Every truthy child yielded by ``state.get_children(True)`` is scored as
    ``score + GAMMA * tree_score``. Here ``score`` is the value paired with the child,
    and ``tree_score`` is the result of a ``core`` tree search run on the child's
    field data with ``depth - 1``.

    :param state: Position whose children are evaluated.
    :param depth: Search depth; ``depth - 1`` is handed to the ``core`` search.
    :param factor: Forwarded unchanged to the ``core`` search function.
    :param occupation_threshold: Fraction of ``state.width * state.height``. While the
        current field's popcount is below that product, children whose popcount is
        lower than the current one are kept in a separate fallback pool. They are
        returned only if no other child was scored. With the default ``0.0`` this
        split never triggers, assuming popcount is non-negative.
    :return: Every index (position in the children sequence) that ties for the best
        score in the preferred pool. Otherwise, the same for the fallback pool.
        Otherwise, a one-element list holding a random index in
        ``range(len(state.actions))``.
    """
    # All deals after the first, flattened into a single colour sequence.
    colors: List[int] = [color for deal in state.deals[1:] for color in deal]

    # Bit i is set for each action in state.actions, where i is that action's
    # position in state._validation_actions.
    position_of = state._validation_actions.index
    action_mask: int = 0
    for action in state.actions:
        action_mask |= 1 << position_of(action)

    trailing = [state.has_garbage, action_mask, colors, depth - 1, factor]
    if isinstance(state.field, TallField):
        # The tall search additionally takes the width and the tsu-rules flag,
        # placed right after num_layers.
        search_fun = core.tall_tree_search
        search_args = [state.num_layers, state.width, state.tsu_rules] + trailing
    else:
        search_fun = core.bottom_tree_search
        search_args = [state.num_layers] + trailing

    base_popcount: int = state.field.popcount
    prevent_chains: bool = (base_popcount < occupation_threshold * state.width * state.height)

    # Each pool is [best score seen so far, indices reaching exactly that score].
    preferred = [float("-inf"), []]
    fallback = [float("-inf"), []]

    def _offer(pool, candidate_score, candidate_index):
        # Strictly better resets the pool; an exact tie is accumulated.
        if candidate_score > pool[0]:
            pool[0] = candidate_score
            pool[1] = [candidate_index]
        elif candidate_score == pool[0]:
            pool[1].append(candidate_index)

    for index, (child, score) in enumerate(state.get_children(True)):
        if not child:
            # Falsy entries are skipped but still consume an index, so returned
            # indices stay aligned with the children sequence.
            continue

        tree_score: float = search_fun(child.field.data, *search_args)
        child_score = score + GAMMA * tree_score

        # A lower popcount means the move removed cells from the field.
        removes_cells = prevent_chains and child.field.popcount < base_popcount
        _offer(fallback if removes_cells else preferred, child_score, index)

    # NOTE(review): np.random.randint(0, 0) raises ValueError, so this last-resort
    # branch fails if state.actions is empty — confirm callers never hit that case.
    return preferred[1] or fallback[1] or [np.random.randint(0, len(state.actions))]
Beispiel #2
0
def test_has_moves_tsu():
    """
    A tsu-rules State loaded from a sparse TallField stack must still have children.

    The stack is ``state.field.offset`` rows of eight ``_`` cells, followed by
    thirteen eight-cell rows in which only the first two cells are set.
    """
    state = State(13, 2, 4, 1, tsu_rules=True)

    # First two cells of each eight-cell row, in the order passed to from_list.
    # The other six cells in every row are ``_`` (presumably the empty cell).
    row_heads = [
        (_, R),
        (B, R),
        (Y, B),
        (G, B),
        (G, R),
        (Y, R),
        (B, G),
        (B, R),
        (B, R),
        (Y, B),
        (G, B),
        (G, R),
        (Y, R),
    ]

    # Padding rows are sized from the State's original field, before it is replaced.
    stack = [_] * (8 * state.field.offset)
    for first, second in row_heads:
        stack += [first, second, _, _, _, _, _, _]

    state.field = TallField.from_list(
        stack,
        num_layers=state.num_layers,
        tsu_rules=state.tsu_rules,
    )
    state.render()
    assert state.get_children()
Beispiel #3
0
def test_no_moves():
    """
    A State whose BottomField is loaded from this stack must have no children.

    The stack is eight rows of eight cells. Only the first two cells of each row are
    set; the rest are ``_``.
    """
    state = State(8, 2, 4, 1)

    # First two cells of each eight-cell row, in the order passed to from_list.
    row_heads = (
        (_, R),
        (B, R),
        (Y, B),
        (G, B),
        (G, R),
        (Y, R),
        (B, G),
        (B, R),
    )
    stack = [
        cell
        for first, second in row_heads
        for cell in (first, second, _, _, _, _, _, _)
    ]

    state.field = BottomField.from_list(stack, num_layers=state.num_layers)
    state.render()
    assert not state.get_children()