def transition_fn(
    self,
    env_state: MathyEnvState,
    expression: MathExpression,
    features: MathyObservation,
) -> Optional[time_step.TimeStep]:
    """If all like terms are siblings.

    Walks the terms of ``expression`` from left to right. Each term is reduced
    to a likeness key made from its variable and exponent; the coefficient is
    ignored. If a key shows up again after a different key has interrupted
    it, some like terms are not adjacent, and no transition happens.

    Returns:
        ``None`` when any group of like terms is split. Otherwise, a
        termination step carrying ``self.get_win_signal(env_state)``.
    """
    closed_keys: set = set()
    active_key = ""
    for node in get_terms(expression):
        term_ex: Optional[TermEx] = get_term_ex(node)
        if term_ex is None:
            continue
        likeness = f"{term_ex.variable}{term_ex.exponent}"
        # A key we already moved past reappears, e.g. the trailing "x" in
        # "4x + 2y + x", so the like terms are not siblings.
        if likeness in closed_keys:
            return None
        if likeness == active_key:
            continue
        # Moving on to a new key: the previous one may never appear again.
        closed_keys.add(active_key)
        active_key = likeness
    return time_step.termination(features, self.get_win_signal(env_state))
def test_problems_variable_sharing_unlike_terms() -> None:
    """Verify that generated polynomials can share one variable across unlike terms.

    With ``share_var_probability=1.0``, every term should use the same
    variable while the terms differ in exponent, e.g. "4x + x^3 + 2x".
    """
    problem, _ = gen_simplify_multiple_terms(
        2, share_var_probability=1.0, noise_probability=0.0, op="+"
    )
    expression: MathExpression = ExpressionParser().parse(problem)
    shared_var: Optional[str] = None
    saw_exponent: bool = False
    for node in get_terms(expression):
        term_ex: Optional[TermEx] = get_term_ex(node)
        assert term_ex is not None, f"invalid expression {node}"
        if shared_var is None:
            shared_var = term_ex.variable
        assert shared_var == term_ex.variable, "expected only one variable"
        saw_exponent = saw_exponent or term_ex.exponent is not None
    # Every term used the single shared variable, and at least one of them
    # carried an explicit exponent.
    assert shared_var is not None
    assert saw_exponent is True
def test_util_get_term_ex():
    """Check that ``get_term_ex`` decomposes natural term forms into TermEx."""
    examples = [
        ["-y", TermEx(-1, "y", None)],
        ["-x^3", TermEx(-1, "x", 3)],
        ["-2x^3", TermEx(-2, "x", 3)],
        ["4x^2", TermEx(4, "x", 2)],
        ["4x", TermEx(4, "x", None)],
        ["x", TermEx(None, "x", None)],
        # TODO: non-natural term forms? If this is supported we can drop the other
        # get_term impl maybe?
        # ["x * 2", TermEx(2, "x", None)],
    ]
    parser = ExpressionParser()
    # Use `text` rather than `input` to avoid shadowing the builtin. Assert
    # only the meaningful comparison (the old `input == input` was always
    # true) and report the failing case.
    for text, expected in examples:
        actual = get_term_ex(parser.parse(text))
        assert actual == expected, f"{text}: expected {expected}, got {actual}"
def transition_fn(
    self,
    env_state: MathyEnvState,
    expression: MathExpression,
    features: MathyObservation,
) -> Optional[time_step.TimeStep]:
    """Win when the agent acted on a term that has at least one like sibling.

    Re-parses the previous state (the one the agent acted on) and finds the
    node at the current step's action token index. It then checks whether
    another term in that expression shares the node's variable/exponent key.
    If so, it calls ``action_node.all_changed()`` and ends the episode with
    the win signal.

    When no moves remain, the episode ends with ``EnvRewards.LOSE`` minus a
    penalty. The penalty is the distance from the touched node to the nearest
    like term, divided by the largest term index. It is 1.0 when no like
    terms exist.

    Returns:
        A termination step on a win or when out of moves, otherwise ``None``.
    """
    agent = env_state.agent
    # Both the current and previous steps are needed. Indexing history[-2]
    # with fewer than two entries would raise IndexError.
    if len(agent.history) < 2:
        return None
    # History gets pushed before this fn, so history[-1] is the current state,
    # and history[-2] is the previous state. Find the previous state node we
    # acted on, and compare to that.
    curr_timestep: MathyEnvStateStep = agent.history[-1]
    last_timestep: MathyEnvStateStep = agent.history[-2]
    prev_expression = self.parser.parse(last_timestep.raw)
    action_node = self.get_token_at_index(prev_expression, curr_timestep.action[1])
    touched_term = get_term_ex(action_node)
    term_nodes = get_terms(prev_expression)
    # find_nodes updates the `r_index` value on each node, which is the token
    # index used below to locate terms.
    BaseRule().find_nodes(prev_expression)

    def like_key(ex: TermEx) -> str:
        # Constants have no variable/exponent, so mathy_term_string yields "".
        # Map them to one explicit key so grouping and the touched-term lookup
        # agree.
        key = mathy_term_string(variable=ex.variable, exponent=ex.exponent)
        return key if key != "" else "const"

    # Token indices of every term, grouped by likeness key.
    all_indices: Dict[str, List[int]] = {}
    max_index = 0
    for term_node in term_nodes:
        assert term_node is not None and term_node.r_index is not None
        max_index = max(max_index, term_node.r_index)
        ex: Optional[TermEx] = get_term_ex(term_node)
        if ex is None:
            continue
        all_indices.setdefault(like_key(ex), []).append(term_node.r_index)

    # The last group with more than one member is used for the lose penalty.
    like_indices: Optional[List[int]] = None
    for indices in all_indices.values():
        if len(indices) > 1:
            like_indices = indices

    if action_node is not None and touched_term is not None:
        if len(all_indices.get(like_key(touched_term), [])) > 1:
            action_node.all_changed()
            return time_step.termination(features, self.get_win_signal(env_state))

    if agent.moves_remaining <= 0:
        if like_indices is not None:
            assert action_node is not None and action_node.r_index is not None
            nearest = min(abs(index - action_node.r_index) for index in like_indices)
            # Assumes r_index values are distinct token indices, so a group of
            # >= 2 like terms implies max_index > 0. TODO confirm.
            loss_magnitude = nearest / max_index
        else:
            loss_magnitude = 1.0
        lose_signal = EnvRewards.LOSE - loss_magnitude
        return time_step.termination(features, lose_signal)
    return None