コード例 #1
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_env_invalid_action_behaviors():
    """Exercise each ``invalid_action_response`` mode with one invalid action.

    The action is built from a rule whose valid-move mask row contains no 1
    (no node it can act on), targeting a node whose mask entry is 0. That
    same action is then applied by environments configured to "raise",
    "penalize" and "terminal".
    """
    problem = "4x + 2x"

    def fresh_state():
        return MathyEnvState(problem=problem, max_moves=35)

    env = MathyEnv(invalid_action_response="raise")
    valid_moves = env.get_valid_moves(fresh_state())
    # Rules with no 1 anywhere in their mask row cannot act on any node.
    unusable_rules = [
        rule_idx for rule_idx, mask in enumerate(valid_moves) if 1 not in mask
    ]
    random.shuffle(unusable_rules)
    chosen_rule = unusable_rules[0]
    blocked_nodes = [
        node_idx
        for node_idx, flag in enumerate(valid_moves[chosen_rule])
        if flag == 0
    ]
    action = (chosen_rule, blocked_nodes[0])

    # "raise": applying the invalid action raises ValueError.
    with pytest.raises(ValueError):
        env.get_next_state(fresh_state(), action)

    # "penalize": the invalid-move reward is given and the episode continues.
    env = MathyEnv(invalid_action_response="penalize")
    _, transition, _ = env.get_next_state(fresh_state(), action)
    assert transition.reward == EnvRewards.INVALID_MOVE
    assert is_terminal_transition(transition) is False

    # "terminal": the invalid action ends the episode.
    env = MathyEnv(invalid_action_response="terminal")
    _, transition, _ = env.get_next_state(fresh_state(), action)
    assert is_terminal_transition(transition) is True
コード例 #2
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_env_print_history(pretty: bool):
    """Build a ten-step history alternating "4+2"/"2+4" and print it."""
    env = PolySimplify()
    env_state = MathyEnvState(problem="4+2")
    for step in range(10):
        # Odd steps use "2+4", even steps "4+2".
        text = "2+4" if step % 2 else "4+2"
        env_state = env_state.get_out_state(
            problem=text, moves_remaining=10 - step, action=(1, 1)
        )
    env.print_history(env_state, pretty=pretty)
コード例 #3
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_env_random_actions():
    """``random_action`` yields appliable actions, with or without a rule filter."""
    env = MathyEnv(invalid_action_response="raise")
    # First pass restricts to AssociativeSwapRule; second pass allows any rule.
    for restrict_rule in (True, False):
        state = MathyEnvState(problem="4x + 2x + 7 + y")
        expression = env.parser.parse(state.agent.problem)
        if restrict_rule:
            action = env.random_action(expression, AssociativeSwapRule)
        else:
            action = env.random_action(expression)
        # The env raises on invalid actions, so this must not raise.
        env.get_next_state(state, action)
コード例 #4
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_env_terminal_conditions():
    """Check PolySimplify's terminal (win) detection against a table of problems.

    Each expectation is ``(problem_text, is_win)``. The state must be terminal
    exactly when ``is_win`` is true, both according to ``is_terminal_state``
    and according to the transition returned by ``get_state_transition``.
    """
    expectations = [
        ("70656 * (x^2 * z^6)", True),
        ("b * (44b^2)", False),
        ("z * (1274z^2)", False),
        ("4x^2", True),
        ("100y * x + 2", True),
        ("10y * 10x + 2", False),
        ("10y + 1000y * (y * z)", False),
        ("4 * (5y + 2)", False),
        ("2", True),
        ("4x * 2", False),
        ("4x * 2x", False),
        ("4x + 2x", False),
        ("4 + 2", False),
        ("3x + 2y + 7", True),
        ("3x^2 + 2x + 7", True),
        ("3x^2 + 2x^2 + 7", False),
    ]

    # Valid solutions but out of scope so they aren't counted as wins.
    #
    # This works because the problem sets exclude this type of > 2 term
    # polynomial expressions
    out_of_scope_valid = []

    env = PolySimplify()
    for text, is_win in expectations + out_of_scope_valid:
        env_state = MathyEnvState(problem=text)
        transition = env.get_state_transition(env_state)
        # The problem text is the assertion message so a failure names the
        # offending case (the old `text == text and ...` was always true and
        # did nothing).
        assert env.is_terminal_state(env_state) == bool(is_win), text
        assert is_terminal_transition(transition) == bool(is_win), text
コード例 #5
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_env_action_masks():
    """The valid-move mask has one row per rule and one entry per parsed node."""
    problem = "4x + 2x"
    env = MathyEnv(invalid_action_response="raise")
    mask = env.get_valid_moves(MathyEnvState(problem=problem, max_moves=35))
    assert len(mask) == len(env.rules)
    node_count = len(env.parser.parse(problem).to_list())
    assert len(mask[0]) == node_count
コード例 #6
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_mathy_env_invalid_action_behaviors():
    """Apply three randomly chosen valid actions, then build an observation.

    Despite the name, every action taken here comes from the valid-move mask.
    """
    env = MathyEnv()
    assert env is not None
    problem = "5y * 9x + 8z + 8x + 3z * 10y * 11x + 10y"
    env_state = MathyEnvState(problem=problem, max_moves=35)
    for _ in range(3):
        moves = env.get_valid_moves(env_state)
        # Candidate rules are those with at least one appliable node.
        usable = [idx for idx, row in enumerate(moves) if 1 in row]
        random.shuffle(usable)
        rule_idx = usable[0]
        targets = [idx for idx, flag in enumerate(moves[rule_idx]) if flag == 1]
        env_state, _, _ = env.get_next_state(env_state, (rule_idx, targets[0]))
    assert env_state.to_observation([]) is not None
コード例 #7
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_env_finalize_state():
    """``finalize_state`` raises ValueError for these end states of "4x + 2x".

    Each end state has zero moves remaining and a final problem that is not
    equivalent to the starting one (presumably why it is rejected).
    """
    env = PolySimplify()
    for final_problem in ("1337", "4x + 2", "4x + 2y"):
        env_state = MathyEnvState(problem="4x + 2x").get_out_state(
            problem=final_problem, action=(1, 1), moves_remaining=0
        )
        with pytest.raises(ValueError):
            env.finalize_state(env_state)
コード例 #8
0
def render_features_from_text(input_text: str):
    """Render the observation features of ``input_text`` as an SVG string.

    Each expression node is drawn as a column of three boxes, from top to
    bottom: the node's name, its observation value, and its observation
    node entry. The top box's fill color reflects the node kind: binary
    operator, variable, root, or other.

    Returns:
        Pretty-printed SVG markup. If parsing, feature extraction or drawing
        raises an ``Exception``, returns a ``"Failed to parse: ..."`` message
        instead.
    """
    # ``parser`` is a module-level object that is only read here, so no
    # ``global`` declaration is needed.
    try:
        expression: MathExpression = parser.parse(input_text)
        state = MathyEnvState(problem=input_text)
        observation: MathyObservation = state.to_observation(
            hash_type=[13, 37])

        length = len(observation.nodes)
        types = observation.nodes
        values = observation.values
        nodes = expression.to_list()
        chars = [n.name for n in nodes]
        assert len(types) == len(values) == len(chars)

        view_x = 0
        view_y = 0
        view_w = BOX_SIZE * length
        view_h = BOX_SIZE * 3 + BORDER_WIDTH * 2

        tree = svgwrite.Drawing(size=(view_w, view_h))
        tree.viewbox(minx=view_x, miny=view_y, width=view_w, height=view_h)

        curr_x = BORDER_WIDTH
        for n, t, v, c in zip(nodes, types, values, chars):

            color = svgwrite.rgb(180, 200, 255)
            if isinstance(n, BinaryExpression):
                color = svgwrite.rgb(230, 230, 230)
            elif isinstance(n, VariableExpression):
                color = svgwrite.rgb(150, 250, 150)
            # The root color overrides the kind-based color above.
            if n == n.get_root():
                color = svgwrite.rgb(250, 220, 200)

            box_with_char(tree, c, x=curr_x, y=BORDER_WIDTH, fill=color)
            box_with_char(tree, v, x=curr_x, y=BOX_SIZE + BORDER_WIDTH)
            box_with_char(tree, t, x=curr_x, y=BOX_SIZE * 2 + BORDER_WIDTH)
            # Advance by one box minus a border width; presumably so that
            # neighboring box borders overlap instead of doubling up.
            curr_x += BOX_SIZE - BORDER_WIDTH

        return svgwrite.utils.pretty_xml(tree.tostring(), indent=2)
    except Exception as error:
        # Catch ordinary errors only: BaseException would also swallow
        # KeyboardInterrupt / SystemExit and turn them into a return value.
        return f"Failed to parse: '{input_text}' with error: {error}"
コード例 #9
0
ファイル: solver.py プロジェクト: justindujardin/mathy
def swarm_solve(
    problems: Union[List[str], str],
    config: SwarmConfig,
    max_steps: Union[List[int], int] = 128,
    silent: bool = False,
) -> Swarm:
    """Run a swarm search on one or more problems, one after another.

    Args:
        problems: A single problem string or a sequence of them. A sequence
            passed in is copied, not consumed.
        config: Swarm configuration passed to ``mathy_swarm``.
        max_steps: Step budget per problem. An int applies to every problem.
            A list gives one budget per problem and is copied, not consumed.
        silent: When True, skip the progress spinner and result messages.

    Returns:
        The swarm that was used for the searches.

    Raises:
        AssertionError: (when assertions are enabled) if there are no
            problems, or if the number of budgets differs from the number
            of problems.
    """
    # Work on private copies. Popping from the caller's lists (as this code
    # used to do) silently emptied them.
    pending_problems: List[str] = (
        [problems] if isinstance(problems, str) else list(problems)
    )
    if isinstance(max_steps, int):
        pending_steps: List[int] = [max_steps] * len(pending_problems)
    else:
        pending_steps = list(max_steps)
    assert len(pending_problems) > 0, "no problems to solve"
    assert len(pending_problems) == len(pending_steps)
    current_problem: str = pending_problems.pop(0)
    current_max_moves: int = pending_steps.pop(0)

    def env_callable():
        # Reads current_problem / current_max_moves at call time, so each
        # call builds an env for whichever problem is current then.
        return FragileMathyEnv(
            name="mathy_v0",
            problem=current_problem,
            repeat_problem=True,
            max_steps=current_max_moves,
        )

    mathy_env: MathyEnv = env_callable()._env._env.mathy
    swarm: Swarm = mathy_swarm(config, env_callable)
    while True:
        if silent:
            swarm.run()
        else:
            with msg.loading(f"Solving {current_problem} ..."):
                swarm.run()
            # NOTE(review): strictly greater than the win reward counts as
            # solved; confirm `>=` is not intended.
            if swarm.walkers.best_reward > EnvRewards.WIN:
                last_state = MathyEnvState.from_np(
                    swarm.walkers.states.best_state)
                msg.good(
                    f"Solved! {current_problem} = {last_state.agent.problem}")
                mathy_env.print_history(last_state)
            else:
                msg.fail("Failed to find a solution :(")

        if not pending_steps:
            break
        current_max_moves = pending_steps.pop(0)
        current_problem = pending_problems.pop(0)
    return swarm
コード例 #10
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_mathy_env_preferred_term_commute():
    """``preferred_term_commute`` decides whether "5y" can be commuted."""
    rule_idx = 1
    env_state = MathyEnvState(problem="5y", max_moves=1)

    env = MathyEnv(preferred_term_commute=False)
    assert isinstance(env.rules[rule_idx],
                      CommutativeSwapRule), "update rule_idx"
    blocked = env.get_valid_moves(env_state)[rule_idx]
    assert 1 not in blocked, "shouldn't be able to commute preferred order terms"

    env = MathyEnv(preferred_term_commute=True)
    allowed = env.get_valid_moves(env_state)[rule_idx]
    assert 1 in allowed, "should be able to commute preferred order terms"
コード例 #11
0
ファイル: test_env.py プロジェクト: mathy/mathy_envs
def test_mathy_env_previous_state_penalty():
    """With previous_state_penalty=True, revisiting a seen state is penalized.

    Each revisit of the initial state yields a more negative reward than the
    one before. The third revisit in the loop below yields a terminal
    transition.
    """
    # "x * y" has three nodes, ["x", "*", "y"]. Commuting node 1 flips between
    # "x * y" and "y * x", so repeating one action bounces between two states.
    problem = "x * y"
    env = MathyEnv(previous_state_penalty=True)
    rule_idx, node_idx = 1, 1
    assert isinstance(env.rules[rule_idx],
                      CommutativeSwapRule), "update rule_idx"
    action = (rule_idx, node_idx)
    env_state = MathyEnvState(problem=problem, max_moves=10)
    # Commute once, so that applying the action again revisits the initial
    # state.
    env_state, _, _ = env.get_next_state(env_state, action)

    previous_reward = 0.0
    saw_terminal = False
    for visit in range(3):
        env_state, transition, _ = env.get_next_state(env_state, action)
        assert transition.reward < 0.0
        # The penalty grows with the number of visits to the state.
        assert transition.reward < previous_reward
        previous_reward = transition.reward

        if visit == 2:
            # The third revisit must end the episode.
            assert is_terminal_transition(transition) is True
            saw_terminal = True
        else:
            # Step to the opposite state. Only revisits of the initial state
            # are checked.
            env_state, _, _ = env.get_next_state(env_state, action)

    assert saw_terminal is True, "did not receive expected terminal transition"
コード例 #12
0

class CustomTimestepRewards(MathyEnv):
    """MathyEnv that rewards associative swaps and penalizes commutative swaps."""

    def get_rewarding_actions(
        self, state: MathyEnvState
    ) -> List[Type[BaseRule]]:
        """Return the rule types whose use earns a step reward."""
        rewarded = [rules.AssociativeSwapRule]
        return rewarded

    def get_penalizing_actions(
        self, state: MathyEnvState
    ) -> List[Type[BaseRule]]:
        """Return the rule types whose use incurs a step penalty."""
        penalized = [rules.CommutativeSwapRule]
        return penalized


env = CustomTimestepRewards()
problem = "4x + y + 2x"
expression = env.parser.parse(problem)
state = MathyEnvState(problem=problem)

# AssociativeSwapRule is in the rewarding list, so the step reward is positive.
action = env.random_action(expression, rules.AssociativeSwapRule)
_, transition, _ = env.get_next_state(state, action)
assert transition.reward > 0.0

# CommutativeSwapRule is in the penalizing list, so the step reward is negative.
_, transition, _ = env.get_next_state(
    state, env.random_action(expression, rules.CommutativeSwapRule)
)
assert transition.reward < 0.0
コード例 #13
0
from mathy_core.rules import ConstantsSimplifyRule
from mathy_envs import MathyEnvState, envs, is_terminal_transition


class CustomEpisodeRewards(envs.PolySimplify):
    """PolySimplify variant with fixed +20 / -20 end-of-episode signals."""

    def get_win_signal(self, env_state: MathyEnvState) -> float:
        """Return the signal used when the episode is won: always 20.0."""
        win_signal = 20.0
        return win_signal

    def get_lose_signal(self, env_state: MathyEnvState) -> float:
        """Return the signal used when the episode is lost: always -20.0."""
        lose_signal = -20.0
        return lose_signal


env = CustomEpisodeRewards()

# Win: folding the constants in "(4 + 2) * x" leaves the single term "6x".
state = MathyEnvState(problem="(4 + 2) * x")
expression = env.parser.parse(state.agent.problem)
action = env.random_action(expression, ConstantsSimplifyRule)
out_state, transition, _ = env.get_next_state(state, action)
assert is_terminal_transition(transition) is True
assert transition.reward == 20.0
assert out_state.agent.problem == "6x"

# Lose: the only remaining move folds "(4 + 2)", but "2x + 6 + 4x" is still
# not in its simplest form when moves run out.
state = MathyEnvState(problem="2x + (4 + 2) + 4x", max_moves=1)
expression = env.parser.parse(state.agent.problem)
action = env.random_action(expression, ConstantsSimplifyRule)
out_state, transition, _ = env.get_next_state(state, action)
assert is_terminal_transition(transition) is True
assert transition.reward == -20.0
assert out_state.agent.problem == "2x + 6 + 4x"
コード例 #14
0
ファイル: solver.py プロジェクト: justindujardin/mathy
 def set_state(self, state: np.ndarray):
     """Load ``state`` (numpy-encoded) into the wrapped env and return it unchanged."""
     wrapped = self._env
     assert wrapped is not None, "env required to set_state"
     wrapped.state = MathyEnvState.from_np(state)
     return state
コード例 #15
0
    rule = rules.DistributiveFactorOutRule()

    def transition_fn(
        self,
        env_state: MathyEnvState,
        expression: MathExpression,
        features: MathyObservation,
    ) -> Optional[time_step.TimeStep]:
        """End the episode as a win once ``self.rule`` applies somewhere.

        Returns a terminal time step carrying the win signal when
        ``self.rule.find_node`` locates an applicable node. Otherwise returns
        ``None``, meaning no custom transition.
        """
        if self.rule.find_node(expression) is None:
            return None
        win_signal = self.get_win_signal(env_state)
        return time_step.termination(features, win_signal)


env = CustomWinConditions()

# Not terminal: no node of "4x + y + 2x" admits distributive factoring.
state_one = MathyEnvState(problem="4x + y + 2x")
transition = env.get_state_transition(state_one)
assert is_terminal_transition(transition) is False

# Terminal: the "4x + 2x" part of this problem can be factored.
state_two = MathyEnvState(problem="4x + 2x + y")
transition = env.get_state_transition(state_two)
assert is_terminal_transition(transition) is True
コード例 #16
0
    @property
    def code(self) -> str:
        """Short identifier string for this rule."""
        rule_code = "PN"
        return rule_code

    def can_apply_to(self, node) -> bool:
        """Match a subtraction that is either the root or a child of an addition."""
        parent = node.parent
        parent_ok = parent is None or isinstance(parent, AddExpression)
        return isinstance(node, SubtractExpression) and parent_ok

    def apply_to(self, node):
        """Rewrite ``a - b`` as ``a + -b`` and return the completed change."""
        change = super().apply_to(node)
        # Record node.parent so the replacement is connected in its place.
        change.save_parent()
        replacement = AddExpression(node.left, NegateExpression(node.right))
        # Flag the new node as changed (used for visualization).
        replacement.set_changed()
        return change.done(replacement)


class CustomActionEnv(envs.PolySimplify):
    """PolySimplify env whose rules are the core rules plus PlusNegationRule."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Assigned after the base initializer runs, so it overrides any rule
        # list set there.
        extra_rules = [PlusNegationRule()]
        self.rules = MathyEnv.core_rules() + extra_rules


env = CustomActionEnv()

# "4x - 2x" is a subtraction with no parent, so PlusNegationRule applies,
# turning "a - b" into "a + -b".
state = MathyEnvState(problem="4x - 2x")
expression = env.parser.parse(state.agent.problem)
action = env.random_action(expression, PlusNegationRule)
out_state, transition, _ = env.get_next_state(state, action)
assert out_state.agent.problem == "4x + -2x"