import random

import pytest

# Imports assumed for these tests; exact module paths may vary by
# mathy_envs version:
from mathy_core.rules import AssociativeSwapRule, CommutativeSwapRule
from mathy_envs import MathyEnv, MathyEnvState, is_terminal_transition
from mathy_envs.env import EnvRewards
from mathy_envs.envs import PolySimplify


def test_env_invalid_action_behaviors():
    problem = "4x + 2x"
    env = MathyEnv(invalid_action_response="raise")
    env_state = MathyEnvState(problem=problem, max_moves=35)
    # Build an invalid action by picking a rule with no applicable nodes
    # and a node index where that rule cannot be applied.
    rule_actions = env.get_valid_moves(env_state)
    rule_indices = [i for i, value in enumerate(rule_actions) if 1 not in value]
    random.shuffle(rule_indices)
    rule_nodes = rule_actions[rule_indices[0]]
    node_indices = [i for i, value in enumerate(rule_nodes) if value == 0]
    action = (rule_indices[0], node_indices[0])

    # Raise an error when selecting an invalid action
    env_state = MathyEnvState(problem=problem, max_moves=35)
    with pytest.raises(ValueError):
        env.get_next_state(env_state, action)

    # Penalize the agent for choosing an invalid action
    env = MathyEnv(invalid_action_response="penalize")
    env_state = MathyEnvState(problem=problem, max_moves=35)
    _, transition, _ = env.get_next_state(env_state, action)
    assert transition.reward == EnvRewards.INVALID_MOVE
    assert is_terminal_transition(transition) is False

    # End the episode when choosing an invalid action
    env = MathyEnv(invalid_action_response="terminal")
    env_state = MathyEnvState(problem=problem, max_moves=35)
    _, transition, _ = env.get_next_state(env_state, action)
    # A terminal transition is returned instead of an error being raised
    assert is_terminal_transition(transition) is True

@pytest.mark.parametrize("pretty", [True, False])
def test_env_print_history(pretty: bool):
    env = PolySimplify()
    env_state = MathyEnvState(problem="4+2")
    for i in range(10):
        env_state = env_state.get_out_state(
            problem="4+2" if i % 2 == 0 else "2+4",
            moves_remaining=10 - i,
            action=(1, 1),
        )
    env.print_history(env_state, pretty=pretty)

def test_env_random_actions():
    env = MathyEnv(invalid_action_response="raise")
    state = MathyEnvState(problem="4x + 2x + 7 + y")
    expression = env.parser.parse(state.agent.problem)
    # Can select random actions of a given rule type
    action = env.random_action(expression, AssociativeSwapRule)
    env.get_next_state(state, action)

    # Can select random actions from all rule types
    state = MathyEnvState(problem="4x + 2x + 7 + y")
    expression = env.parser.parse(state.agent.problem)
    action = env.random_action(expression)
    env.get_next_state(state, action)

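# random_action returns a (rule_index, node_index) tuple compatible with
# get_next_state. A small added sketch (not from the original suite)
# asserting that shape:
def test_env_random_action_is_pair():
    env = MathyEnv()
    expression = env.parser.parse("4x + 2x")
    action = env.random_action(expression)
    assert isinstance(action, tuple) and len(action) == 2
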
def test_env_terminal_conditions():
    expectations = [
        ("70656 * (x^2 * z^6)", True),
        ("b * (44b^2)", False),
        ("z * (1274z^2)", False),
        ("4x^2", True),
        ("100y * x + 2", True),
        ("10y * 10x + 2", False),
        ("10y + 1000y * (y * z)", False),
        ("4 * (5y + 2)", False),
        ("2", True),
        ("4x * 2", False),
        ("4x * 2x", False),
        ("4x + 2x", False),
        ("4 + 2", False),
        ("3x + 2y + 7", True),
        ("3x^2 + 2x + 7", True),
        ("3x^2 + 2x^2 + 7", False),
    ]
    # Valid solutions that are out of scope, so they aren't counted as wins.
    # This works because the problem sets exclude polynomial expressions
    # with more than two like terms.
    out_of_scope_valid = []

    env = PolySimplify()
    for text, is_win in expectations + out_of_scope_valid:
        env_state = MathyEnvState(problem=text)
        transition = env.get_state_transition(env_state)
        assert env.is_terminal_state(env_state) == bool(is_win), text
        assert is_terminal_transition(transition) == bool(is_win), text

def test_env_action_masks():
    problem = "4x + 2x"
    env = MathyEnv(invalid_action_response="raise")
    env_state = MathyEnvState(problem=problem, max_moves=35)
    valid_mask = env.get_valid_moves(env_state)
    assert len(valid_mask) == len(env.rules)
    assert len(valid_mask[0]) == len(env.parser.parse(problem).to_list())

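# Because the mask is a (num_rules, num_nodes) grid, an agent that predicts a
# single flat action index can recover the (rule, node) pair with divmod.
# A minimal sketch; this helper is hypothetical and not part of mathy_envs:
from typing import List, Tuple


def flat_action_to_pair(
    flat_index: int, valid_mask: List[List[int]]
) -> Tuple[int, int]:
    """Hypothetical helper: map a flat action index to (rule_idx, node_idx)."""
    num_nodes = len(valid_mask[0])
    return divmod(flat_index, num_nodes)
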
def test_mathy_env_valid_action_behaviors():
    env = MathyEnv()
    assert env is not None
    problem = "5y * 9x + 8z + 8x + 3z * 10y * 11x + 10y"
    env_state = MathyEnvState(problem=problem, max_moves=35)
    for _ in range(3):
        rule_actions = env.get_valid_moves(env_state)
        rule_indices = [i for i, value in enumerate(rule_actions) if 1 in value]
        random.shuffle(rule_indices)
        rule_nodes = rule_actions[rule_indices[0]]
        node_indices = [i for i, value in enumerate(rule_nodes) if value == 1]
        env_state, value, changed = env.get_next_state(
            env_state, (rule_indices[0], node_indices[0])
        )
    assert env_state.to_observation([]) is not None

def test_env_finalize_state():
    env = PolySimplify()
    # Output that drops the variable entirely is invalid
    env_state = MathyEnvState(problem="4x + 2x").get_out_state(
        problem="1337", action=(1, 1), moves_remaining=0
    )
    with pytest.raises(ValueError):
        env.finalize_state(env_state)

    # Output that changes the value of the expression is invalid
    env_state = MathyEnvState(problem="4x + 2x").get_out_state(
        problem="4x + 2", action=(1, 1), moves_remaining=0
    )
    with pytest.raises(ValueError):
        env.finalize_state(env_state)

    # Output that introduces a different variable is invalid
    env_state = MathyEnvState(problem="4x + 2x").get_out_state(
        problem="4x + 2y", action=(1, 1), moves_remaining=0
    )
    with pytest.raises(ValueError):
        env.finalize_state(env_state)

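# Conversely, output that is a valid transformation of the input should
# finalize without raising. A sketch (an addition) using the same API:
def test_env_finalize_state_valid():
    env = PolySimplify()
    env_state = MathyEnvState(problem="4x + 2x").get_out_state(
        problem="6x", action=(1, 1), moves_remaining=0
    )
    # No error: the variables and the value of the expression are preserved
    env.finalize_state(env_state)
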
def render_features_from_text(input_text: str):
    global parser
    try:
        expression: MathExpression = parser.parse(input_text)
        state = MathyEnvState(problem=input_text)
        observation: MathyObservation = state.to_observation(hash_type=[13, 37])
        length = len(observation.nodes)
        types = observation.nodes
        values = observation.values
        nodes = expression.to_list()
        chars = [n.name for n in nodes]
        assert len(types) == len(values) == len(chars)

        view_x = 0
        view_y = 0
        view_w = BOX_SIZE * length
        view_h = BOX_SIZE * 3 + BORDER_WIDTH * 2
        tree = svgwrite.Drawing(size=(view_w, view_h))
        tree.viewbox(minx=view_x, miny=view_y, width=view_w, height=view_h)

        curr_x = BORDER_WIDTH
        for n, t, v, c in zip(nodes, types, values, chars):
            color = svgwrite.rgb(180, 200, 255)
            if isinstance(n, BinaryExpression):
                color = svgwrite.rgb(230, 230, 230)
            elif isinstance(n, VariableExpression):
                color = svgwrite.rgb(150, 250, 150)
            if n == n.get_root():
                color = svgwrite.rgb(250, 220, 200)
            # Render three rows per node: character, value, node type
            box_with_char(tree, c, x=curr_x, y=BORDER_WIDTH, fill=color)
            box_with_char(tree, v, x=curr_x, y=BOX_SIZE + BORDER_WIDTH)
            box_with_char(tree, t, x=curr_x, y=BOX_SIZE * 2 + BORDER_WIDTH)
            curr_x += BOX_SIZE - BORDER_WIDTH
        return svgwrite.utils.pretty_xml(tree.tostring(), indent=2)
    except BaseException as error:
        return f"Failed to parse: '{input_text}' with error: {error}"

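# Example usage, assuming the module-level globals referenced above are
# defined elsewhere (e.g. parser = ExpressionParser() from mathy_core,
# plus the BOX_SIZE/BORDER_WIDTH constants and the box_with_char helper):
if __name__ == "__main__":
    print(render_features_from_text("4x + 2x"))
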
def swarm_solve(
    problems: Union[List[str], str],
    config: SwarmConfig,
    max_steps: Union[List[int], int] = 128,
    silent: bool = False,
) -> Swarm:
    single_problem: bool = isinstance(problems, str)
    if single_problem:
        problems = [problems]
    if isinstance(max_steps, int):
        max_steps = [max_steps] if single_problem else [max_steps] * len(problems)
    assert len(problems) > 0, "no problems to solve"
    assert len(problems) == len(max_steps)
    assert isinstance(problems, list)
    current_problem: str = problems.pop(0)
    current_max_moves: int = max_steps.pop(0)

    def env_callable():
        nonlocal current_problem, current_max_moves
        return FragileMathyEnv(
            name="mathy_v0",
            problem=current_problem,
            repeat_problem=True,
            max_steps=current_max_moves,
        )

    mathy_env: MathyEnv = env_callable()._env._env.mathy
    swarm: Swarm = mathy_swarm(config, env_callable)
    while True:
        if not silent:
            with msg.loading(f"Solving {current_problem} ..."):
                swarm.run()
        else:
            swarm.run()
        if not silent:
            if swarm.walkers.best_reward > EnvRewards.WIN:
                last_state = MathyEnvState.from_np(swarm.walkers.states.best_state)
                msg.good(f"Solved! {current_problem} = {last_state.agent.problem}")
                mathy_env.print_history(last_state)
            else:
                msg.fail("Failed to find a solution :(")
        if len(max_steps) > 0:
            current_max_moves = max_steps.pop(0)
            current_problem = problems.pop(0)
        else:
            break
    return swarm

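# A sketch of solving a single problem with swarm_solve, assuming SwarmConfig
# can be constructed with defaults (its required fields may vary by version):
#
#   config = SwarmConfig()
#   swarm = swarm_solve("4x + 2x", config, max_steps=64)
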
def test_mathy_env_preferred_term_commute():
    rule_idx = 1
    problem = "5y"
    env_state = MathyEnvState(problem=problem, max_moves=1)

    env = MathyEnv(preferred_term_commute=False)
    assert isinstance(env.rules[rule_idx], CommutativeSwapRule), "update rule_idx"
    commute_nodes = env.get_valid_moves(env_state)[rule_idx]
    assert 1 not in commute_nodes, "shouldn't be able to commute preferred order terms"

    env = MathyEnv(preferred_term_commute=True)
    commute_nodes = env.get_valid_moves(env_state)[rule_idx]
    assert 1 in commute_nodes, "should be able to commute preferred order terms"

def test_mathy_env_previous_state_penalty():
    """When previous_state_penalty=True, a negative reward is given when
    revisiting already seen problem states. If an agent revisits a state
    too many times, the episode ends."""
    # We define the input problem with 3 nodes for simplicity:
    # "x * y" == ["x", "*", "y"]
    #
    # Because the tree is small and balanced, we can commute the same node
    # over and over to flip back-and-forth between "x * y" and "y * x".
    problem = "x * y"
    env = MathyEnv(previous_state_penalty=True)
    rule_idx = 1
    node_idx = 1
    assert isinstance(env.rules[rule_idx], CommutativeSwapRule), "update rule_idx"
    action = (rule_idx, node_idx)
    env_state = MathyEnvState(problem=problem, max_moves=10)

    # Commute the first time so that applying the same action again
    # revisits the initial state.
    env_state, _, _ = env.get_next_state(env_state, action)

    # After three visits to the same state, the episode ends.
    last_penalty = 0.0
    found_terminal = False
    for i in range(3):
        env_state, transition, changed = env.get_next_state(env_state, action)
        assert transition.reward < 0.0
        # The penalty scales up with the number of visits to the state
        assert transition.reward < last_penalty
        last_penalty = transition.reward
        if i < 2:
            # Visit the opposite state and ignore it (we only care about
            # revisiting the initial state)
            env_state, _, _ = env.get_next_state(env_state, action)
        else:
            # After the third visit, we should receive a terminal transition
            assert is_terminal_transition(transition) is True
            found_terminal = True
    assert found_terminal is True, "did not receive expected terminal transition"

# Imports assumed for this example; exact module paths may vary:
from typing import List, Type

from mathy_core import BaseRule, rules
from mathy_envs import MathyEnv, MathyEnvState


class CustomTimestepRewards(MathyEnv):
    def get_rewarding_actions(self, state: MathyEnvState) -> List[Type[BaseRule]]:
        return [rules.AssociativeSwapRule]

    def get_penalizing_actions(self, state: MathyEnvState) -> List[Type[BaseRule]]:
        return [rules.CommutativeSwapRule]


env = CustomTimestepRewards()
problem = "4x + y + 2x"
expression = env.parser.parse(problem)
state = MathyEnvState(problem=problem)

action = env.random_action(expression, rules.AssociativeSwapRule)
_, transition, _ = env.get_next_state(
    state,
    action,
)
# Expect positive reward
assert transition.reward > 0.0

_, transition, _ = env.get_next_state(
    state,
    env.random_action(expression, rules.CommutativeSwapRule),
)
# Expect negative reward
assert transition.reward < 0.0

from mathy_core.rules import ConstantsSimplifyRule
from mathy_envs import MathyEnvState, envs, is_terminal_transition


class CustomEpisodeRewards(envs.PolySimplify):
    def get_win_signal(self, env_state: MathyEnvState) -> float:
        return 20.0

    def get_lose_signal(self, env_state: MathyEnvState) -> float:
        return -20.0


env = CustomEpisodeRewards()

# Win by simplifying constants and yielding a single simple term form
state = MathyEnvState(problem="(4 + 2) * x")
expression = env.parser.parse(state.agent.problem)
action = env.random_action(expression, ConstantsSimplifyRule)
out_state, transition, _ = env.get_next_state(state, action)
assert is_terminal_transition(transition) is True
assert transition.reward == 20.0
assert out_state.agent.problem == "6x"

# Lose by applying a rule with only one move remaining
state = MathyEnvState(problem="2x + (4 + 2) + 4x", max_moves=1)
expression = env.parser.parse(state.agent.problem)
action = env.random_action(expression, ConstantsSimplifyRule)
out_state, transition, _ = env.get_next_state(state, action)
assert is_terminal_transition(transition) is True
assert transition.reward == -20.0
assert out_state.agent.problem == "2x + 6 + 4x"

def set_state(self, state: np.ndarray):
    assert self._env is not None, "env required to set_state"
    self._env.state = MathyEnvState.from_np(state)
    return state

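# A mirror-image accessor sketch, assuming MathyEnvState.to_np() is the
# serialization counterpart of from_np() (not verified against the API):
def get_state(self) -> np.ndarray:
    assert self._env is not None, "env required to get_state"
    return self._env.state.to_np()
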
# Imports and the class declaration are assumed for this example; the
# original snippet begins inside the class body:
from typing import Optional

from mathy_core import MathExpression, rules
from mathy_envs import (
    MathyEnv,
    MathyEnvState,
    MathyObservation,
    is_terminal_transition,
    time_step,
)


class CustomWinConditions(MathyEnv):
    rule = rules.DistributiveFactorOutRule()

    def transition_fn(
        self,
        env_state: MathyEnvState,
        expression: MathExpression,
        features: MathyObservation,
    ) -> Optional[time_step.TimeStep]:
        # If the rule can find any applicable nodes...
        if self.rule.find_node(expression) is not None:
            # ...return a terminal transition with the win reward
            return time_step.termination(features, self.get_win_signal(env_state))
        # None does nothing; the default transition logic applies
        return None


env = CustomWinConditions()

# This state is not terminal because none of the nodes can have the
# distributive factoring rule applied to them.
state_one = MathyEnvState(problem="4x + y + 2x")
transition = env.get_state_transition(state_one)
assert is_terminal_transition(transition) is False

# This is a terminal state because the nodes representing "4x + 2x" can
# have the distributive factoring rule applied to them.
state_two = MathyEnvState(problem="4x + 2x + y")
transition = env.get_state_transition(state_two)
assert is_terminal_transition(transition) is True

# The class declaration is assumed for this example; the original snippet
# begins inside the class body:
class PlusNegationRule(BaseRule):
    """Convert subtraction into addition of a negation: `a - b` => `a + -b`"""

    @property
    def code(self) -> str:
        return "PN"

    def can_apply_to(self, node) -> bool:
        is_sub = isinstance(node, SubtractExpression)
        is_parent_add = isinstance(node.parent, AddExpression)
        return is_sub and (node.parent is None or is_parent_add)

    def apply_to(self, node):
        change = super().apply_to(node)
        change.save_parent()  # connect result to node.parent
        result = AddExpression(node.left, NegateExpression(node.right))
        result.set_changed()  # mark this node as changed for visualization
        return change.done(result)


class CustomActionEnv(envs.PolySimplify):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.rules = MathyEnv.core_rules() + [PlusNegationRule()]


env = CustomActionEnv()

state = MathyEnvState(problem="4x - 2x")
expression = env.parser.parse(state.agent.problem)
action = env.random_action(expression, PlusNegationRule)
out_state, transition, _ = env.get_next_state(state, action)
assert out_state.agent.problem == "4x + -2x"
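
# A follow-up sanity check (an addition, using only the objects defined
# above): the custom rule is registered in the environment's action space.
assert any(isinstance(r, PlusNegationRule) for r in env.rules)
assert len(env.rules) == len(MathyEnv.core_rules()) + 1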