Exemplo n.º 1
0
    def transition_fn(
        self,
        env_state: MathyEnvState,
        expression: MathExpression,
        features: MathyObservation,
    ) -> Optional[time_step.TimeStep]:
        """Terminate with a win once every group of like terms is contiguous.

        Terms are visited in order, and each gets a likeness key built from its
        variable and exponent only (the coefficient is ignored). A group that
        has been left behind must never reappear. For example, the trailing
        ``x`` in "4x + 2y + x" revisits a closed group, so no termination is
        returned. Terms that ``get_term_ex`` cannot decompose are skipped.

        Returns ``None`` while the expression is not yet grouped. Otherwise it
        returns a termination step carrying ``self.get_win_signal(env_state)``.
        """
        closed_groups: set = set()
        open_key = ""
        for node in get_terms(expression):
            parts: Optional[TermEx] = get_term_ex(node)
            if parts is None:
                continue
            likeness = f"{parts.variable}{parts.exponent}"
            if likeness in closed_groups:
                # A previously finished group shows up again: not grouped yet.
                return None
            if likeness == open_key:
                continue
            # Moving to a new group closes the current one (the initial ""
            # placeholder is closed too, which is harmless).
            closed_groups.add(open_key)
            open_key = likeness
        return time_step.termination(features, self.get_win_signal(env_state))
Exemplo n.º 2
0
def test_problems_variable_sharing_unlike_terms() -> None:
    """Generated multi-term problems can share one variable across unlike terms.

    With ``share_var_probability=1.0`` and no noise, every term produced by
    ``gen_simplify_multiple_terms`` should use the same variable, and at least
    one term should carry an exponent, e.g. "4x + x^3 + 2x".
    """
    problem, _ = gen_simplify_multiple_terms(
        2, share_var_probability=1.0, noise_probability=0.0, op="+"
    )
    expression: MathExpression = ExpressionParser().parse(problem)
    shared_var: Optional[str] = None
    saw_exponent = False
    for node in get_terms(expression):
        term = get_term_ex(node)
        assert term is not None, f"invalid expression {node}"
        # The first decomposed term fixes the variable all others must match.
        if shared_var is None:
            shared_var = term.variable
        assert shared_var == term.variable, "expected only one variable"
        saw_exponent = saw_exponent or term.exponent is not None

    # A variable was found, and at least one term raises it to a power.
    assert shared_var is not None
    assert saw_exponent is True
Exemplo n.º 3
0
def test_util_get_term_ex():
    """get_term_ex decomposes natural-form terms into TermEx parts.

    Each example pairs a term string with its expected
    ``TermEx(coefficient, variable, exponent)``; ``None`` marks an absent part
    (no explicit coefficient, or no exponent).
    """
    examples = [
        ["-y", TermEx(-1, "y", None)],
        ["-x^3", TermEx(-1, "x", 3)],
        ["-2x^3", TermEx(-2, "x", 3)],
        ["4x^2", TermEx(4, "x", 2)],
        ["4x", TermEx(4, "x", None)],
        ["x", TermEx(None, "x", None)],
        # TODO: non-natural term forms? If this is supported we can drop the other
        #       get_term impl maybe?
        # ["x * 2", TermEx(2, "x", None)],
    ]
    parser = ExpressionParser()
    for text, expected in examples:
        actual = get_term_ex(parser.parse(text))
        # Name the offending input so a failure is diagnosable without a debugger.
        assert actual == expected, f"{text!r}: expected {expected}, got {actual}"
Exemplo n.º 4
0
    def transition_fn(
        self,
        env_state: MathyEnvState,
        expression: MathExpression,
        features: MathyObservation,
    ) -> Optional[time_step.TimeStep]:
        """Win if the last action touched a term that has a like term elsewhere.

        The previous state's expression is re-parsed (the ``expression``
        argument is not used), and the node acted on is located by token index.

        - If that node is a term and at least one other term shares its
          variable/exponent key, the node is marked changed and a win
          termination is returned.
        - Otherwise, once no moves remain, a loss is returned. Its penalty is
          increased by the normalized token distance from the acted-on node to
          the nearest member of a like-term group (1.0 if none can be measured).
        - In every other case ``None`` is returned and the episode continues.
        """
        agent = env_state.agent
        # History gets pushed before this fn, so history[-1] is the current state
        # and history[-2] is the previous one. Both are required below.
        if len(agent.history) < 2:
            return None
        curr_timestep: MathyEnvStateStep = agent.history[-1]
        last_timestep: MathyEnvStateStep = agent.history[-2]
        # The action's token index refers to the expression it was applied to,
        # i.e. the previous state, so resolve the node against that expression.
        prev_expression = self.parser.parse(last_timestep.raw)
        action_node = self.get_token_at_index(
            prev_expression, curr_timestep.action[1]
        )
        touched_term = get_term_ex(action_node)

        term_nodes = get_terms(prev_expression)
        # find_nodes updates the `r_index` value on each node (the token index).
        BaseRule().find_nodes(prev_expression)

        def like_key(term: TermEx) -> str:
            # Shared key for counting and for the touched term; constant terms
            # (empty variable/exponent string) all map to "const".
            key = mathy_term_string(variable=term.variable, exponent=term.exponent)
            return key if key != "" else "const"

        # Token indices of every term, grouped by like-term key.
        all_indices: Dict[str, List[int]] = {}
        max_index = 0
        for term_node in term_nodes:
            assert term_node is not None and term_node.r_index is not None
            max_index = max(max_index, term_node.r_index)
            ex: Optional[TermEx] = get_term_ex(term_node)
            if ex is None:
                continue
            all_indices.setdefault(like_key(ex), []).append(term_node.r_index)

        if action_node is not None and touched_term is not None:
            # The touched term itself is in its group, so > 1 means a like sibling.
            if len(all_indices.get(like_key(touched_term), [])) > 1:
                action_node.all_changed()
                return time_step.termination(features, self.get_win_signal(env_state))

        if agent.moves_remaining <= 0:
            # Use the last-seen group (insertion order) that has like terms.
            like_indices: Optional[List[int]] = None
            for indices in all_indices.values():
                if len(indices) > 1:
                    like_indices = indices
            action_index = None if action_node is None else action_node.r_index
            if like_indices is not None and action_index is not None:
                nearest = min(abs(index - action_index) for index in like_indices)
                # max(..., 1) only guards the degenerate all-zero-index case.
                loss_magnitude = nearest / max(max_index, 1)
            else:
                loss_magnitude = 1.0
            lose_signal = EnvRewards.LOSE - loss_magnitude
            return time_step.termination(features, lose_signal)
        return None