示例#1
0
def test_mathy_env_invalid_action_behaviors():
    """Apply a few randomly chosen valid moves, then build an observation.

    NOTE(review): despite its name, this test only applies moves taken from
    ``env.get_valid_moves`` -- it never submits an invalid action. Confirm
    whether invalid-action coverage was intended here.
    """
    # A dedicated, seeded RNG makes the chosen move sequence reproducible, so a
    # failure can be replayed, without reseeding the global ``random`` module.
    rng = random.Random(1337)
    env = MathyEnv()
    assert env is not None
    problem = "5y * 9x + 8z + 8x + 3z * 10y * 11x + 10y"
    env_state = MathyEnvState(problem=problem, max_moves=35)
    for step in range(3):
        rule_actions = env.get_valid_moves(env_state)
        # Indices of rules whose node mask contains at least one 1.
        rule_indices = [
            rule for rule, node_mask in enumerate(rule_actions) if 1 in node_mask
        ]
        assert rule_indices, f"no valid moves available at step {step}"
        rule_index = rng.choice(rule_indices)
        # Apply the chosen rule to the first node it is valid for.
        node_indices = [
            node for node, flag in enumerate(rule_actions[rule_index]) if flag == 1
        ]
        env_state, _value, _changed = env.get_next_state(
            env_state, (rule_index, node_indices[0])
        )
    assert env_state.to_observation([]) is not None
示例#2
0
def _node_fill(node):
    """Pick the fill colour for ``node``'s name cell in the feature SVG.

    The expression root always gets a peach fill (250, 220, 200). Otherwise a
    binary expression is light grey, a variable is light green, and any other
    node is light blue.
    """
    if node == node.get_root():
        return svgwrite.rgb(250, 220, 200)
    if isinstance(node, BinaryExpression):
        return svgwrite.rgb(230, 230, 230)
    if isinstance(node, VariableExpression):
        return svgwrite.rgb(150, 250, 150)
    return svgwrite.rgb(180, 200, 255)


def render_features_from_text(input_text: str) -> str:
    """Render the per-node features of ``input_text`` as an SVG document.

    Draws one column per node of ``parser.parse(input_text).to_list()`` with
    three rows of cells:

    - top: the node's ``name``;
    - middle: the matching entry of ``observation.values``;
    - bottom: the matching entry of ``observation.nodes``.

    Both ``values`` and ``nodes`` come from
    ``MathyEnvState(problem=input_text).to_observation(...)``.

    Args:
        input_text: the math expression text to parse and render.

    Returns:
        The pretty-printed SVG markup. If any step raises an ``Exception``
        (parse failure, mismatched feature lengths, rendering error), returns
        a ``"Failed to parse: ..."`` message instead. KeyboardInterrupt and
        SystemExit are deliberately not swallowed.
    """
    try:
        expression: MathExpression = parser.parse(input_text)
        state = MathyEnvState(problem=input_text)
        observation: MathyObservation = state.to_observation(hash_type=[13, 37])
        types = observation.nodes
        values = observation.values
        nodes = expression.to_list()
        chars = [n.name for n in nodes]
        # zip() below would silently drop trailing columns on a mismatch, and
        # an ``assert`` would be stripped under ``python -O``, so raise
        # explicitly with a message that ends up in the returned error text.
        if not len(types) == len(values) == len(chars):
            raise ValueError(
                f"feature length mismatch: {len(types)} types, "
                f"{len(values)} values, {len(chars)} nodes"
            )

        length = len(nodes)
        view_w = BOX_SIZE * length
        view_h = BOX_SIZE * 3 + BORDER_WIDTH * 2
        tree = svgwrite.Drawing(size=(view_w, view_h))
        tree.viewbox(minx=0, miny=0, width=view_w, height=view_h)

        curr_x = BORDER_WIDTH
        for n, t, v, c in zip(nodes, types, values, chars):
            box_with_char(tree, c, x=curr_x, y=BORDER_WIDTH, fill=_node_fill(n))
            box_with_char(tree, v, x=curr_x, y=BOX_SIZE + BORDER_WIDTH)
            box_with_char(tree, t, x=curr_x, y=BOX_SIZE * 2 + BORDER_WIDTH)
            # Columns advance by BOX_SIZE - BORDER_WIDTH, so neighbours overlap
            # by one border width (presumably to share the border line --
            # TODO confirm against box_with_char).
            curr_x += BOX_SIZE - BORDER_WIDTH

        return svgwrite.utils.pretty_xml(tree.tostring(), indent=2)
    except Exception as error:
        return f"Failed to parse: '{input_text}' with error: {error}"