Example #1
File: env.py Project: mathy/mathy_envs
    def __init__(
        self,
        *,
        rules: Optional[List[BaseRule]] = None,
        max_moves: int = 20,
        verbose: bool = False,
        invalid_action_response: InvalidActionResponses = "raise",
        reward_discount: float = 0.99,
        max_seq_len: int = 128,
        previous_state_penalty: bool = True,
        preferred_term_commute: bool = False,
    ):
        self.discount = reward_discount
        self.previous_state_penalty = previous_state_penalty
        self.verbose = verbose
        self.max_moves = max_moves
        self.max_seq_len = max_seq_len
        self.invalid_action_response = invalid_action_response
        self.parser = ExpressionParser()
        if rules is None:
            self.rules = MathyEnv.core_rules(
                preferred_term_commute=preferred_term_commute)
        else:
            self.rules = rules
        self.valid_actions_mask_cache = dict()
        self.valid_rules_cache = dict()

        if self.invalid_action_response not in INVALID_ACTION_RESPONSES:
            raise ValueError(
                f"Unknown invalid action behavior: {self.invalid_action_response}\n"
                f"Expected one of: {', '.join(INVALID_ACTION_RESPONSES)}")
Example #2
def test_problems_variable_sharing_unlike_terms() -> None:
    """Verify that the polynomial generation functions return matches that include
    shared variables for terms that are not like, e.g. "4x + x^3 + 2x"
    """
    parser = ExpressionParser()
    problem, _ = gen_simplify_multiple_terms(2,
                                             share_var_probability=1.0,
                                             noise_probability=0.0,
                                             op="+")
    expression: MathExpression = parser.parse(problem)
    term_nodes: List[MathExpression] = get_terms(expression)
    found_var: Optional[str] = None
    found_exp: bool = False
    for term_node in term_nodes:
        ex: Optional[TermEx] = get_term_ex(term_node)
        assert ex is not None, f"invalid expression {term_node}"
        if found_var is None:
            found_var = ex.variable
        assert found_var == ex.variable, "expected only one variable"
        if ex.exponent is not None:
            found_exp = True

    # Assert there are terms with and without exponents for this var
    assert found_var is not None
    assert found_exp is True
Example #3
def test_parser_mult_exp_precedence() -> None:
    """should respect order of operations with factor parsing"""
    parser = ExpressionParser()
    expression = parser.parse("4x^2")
    val = expression.evaluate({"x": 2})
    # 4x^2 should evaluate to 16 with x=2
    assert val == 16

    expression = parser.parse("7 * 10 * 6x * 3x + 5x")
    assert expression is not None
Example #4
def test_util_has_like_terms():
    examples = [
        ["14 + 6y + 7x + x * (3y)", False],
        ["b * (44b^2)", False],
        ["z * (1274z^2)", False],
        ["100y * x + 2", False],
    ]
    parser = ExpressionParser()
    for input_text, expected in examples:
        expr = parser.parse(input_text)
        assert has_like_terms(expr) == expected, input_text
Example #5
def test_rules_rule_can_apply_to():
    parser = ExpressionParser()
    expression = parser.parse("7 + 4x - 2")

    available_actions = [
        CommutativeSwapRule(),
        DistributiveFactorOutRule(),
        DistributiveMultiplyRule(),
        AssociativeSwapRule(),
    ]
    for action in available_actions:
        assert isinstance(action.can_apply_to(expression), bool)
Example #6
def test_parser_to_string() -> None:
    parser = ExpressionParser()
    expects = [
        {
            "input": "(-2.257893300159429e+16h^2 * v) * j^4",
            "output": "(-2.257893300159429e + 16h^2 * v) * j^4",
        },
        {
            "input": "1f + 98i + 3f + 14t",
            "output": "1f + 98i + 3f + 14t"
        },
        {
            "input": "4x * p^(1 + 3) * 12x^2",
            "output": "4x * p^(1 + 3) * 12x^2"
        },
        {
            "input": "(5 * 3) * (32 / 7)",
            "output": "(5 * 3) * 32 / 7"
        },
        {
            "input": "7 - 5 * 3 * (2^7)",
            "output": "7 - 5 * 3 * 2^7"
        },
        {
            "input": "(8x^2 * 9b) * 7",
            "output": "(8x^2 * 9b) * 7"
        },
        {
            "input": "(8 * 9b) * 7",
            "output": "(8 * 9b) * 7"
        },
        {
            "input": "7 - (5 * 3) * (32 / 7)",
            "output": "7 - (5 * 3) * 32 / 7"
        },
        {
            "input": "7 - (5 - 3) * (32 - 7)",
            "output": "7 - (5 - 3) * (32 - 7)"
        },
        {
            "input": "(7 - (5 * 3)) * (32 - 7)",
            "output": "(7 - 5 * 3) * (32 - 7)"
        },
    ]
    # Test to make sure parens are preserved in output when they are meaningful
    for expect in expects:
        expression = parser.parse(expect["input"])
        out_str = str(expression)
        assert out_str == expect["output"]
Example #7
def test_util_is_preferred_term_form():
    examples = [
        ["b * (44b^2)", False],
        ["z * (1274z^2)", False],
        ["4x * z", True],
        ["z * 4x", True],
        ["2x * x", False],
        ["29y", True],
        ["z", True],
        ["z * 10", False],
        ["4x^2", True],
    ]
    parser = ExpressionParser()
    for input_text, expected in examples:
        expr = parser.parse(input_text)
        assert is_preferred_term_form(expr) == expected, input_text
Example #8
def test_util_get_term_ex():
    examples = [
        ["-y", TermEx(-1, "y", None)],
        ["-x^3", TermEx(-1, "x", 3)],
        ["-2x^3", TermEx(-2, "x", 3)],
        ["4x^2", TermEx(4, "x", 2)],
        ["4x", TermEx(4, "x", None)],
        ["x", TermEx(None, "x", None)],
        # TODO: non-natural term forms? If this is supported we can drop the other
        #       get_term impl maybe?
        # ["x * 2", TermEx(2, "x", None)],
    ]
    parser = ExpressionParser()
    for input_text, expected in examples:
        expr = parser.parse(input_text)
        assert get_term_ex(expr) == expected, input_text
Example #9
def test_util_terms_are_like():
    parser = ExpressionParser()
    expr = parser.parse("10 + (7x + 6x)")
    terms = get_terms(expr)
    assert len(terms) == 3
    assert not terms_are_like(terms[0], terms[1])
    assert terms_are_like(terms[1], terms[2])

    expr = parser.parse("10 + 7x + 6")
    terms = get_terms(expr)
    assert len(terms) == 3
    assert not terms_are_like(terms[0], terms[1])
    assert terms_are_like(terms[0], terms[2])

    expr = parser.parse("6x + 6 * 5")
    terms = get_terms(expr)
    assert len(terms) == 2
    assert not terms_are_like(terms[0], terms[1])

    expr = parser.parse("360y^1")
    terms = get_terms(expr)
    assert len(terms) == 1

    expr = parser.parse("4z")
    terms = get_terms(expr)
    assert len(terms) == 1
Example #10
    def to_observation(
        self,
        move_mask: Optional[NodeMaskIntList] = None,
        hash_type: Optional[ProblemTypeIntList] = None,
        parser: Optional[ExpressionParser] = None,
        normalize: bool = True,
        max_seq_len: Optional[int] = None,
    ) -> MathyObservation:
        """Convert a state into an observation"""
        if parser is None:
            parser = ExpressionParser()
        if hash_type is None:
            hash_type = self.get_problem_hash()
        expression = parser.parse(self.agent.problem)
        nodes: List[MathExpression] = expression.to_list()
        vectors: NodeIntList = []
        values: NodeValuesFloatList = []
        if move_mask is None:
            move_mask = np.zeros(len(nodes))
        assert move_mask is not None
        for node in nodes:
            vectors.append(node.type_id)
            if isinstance(node, ConstantExpression):
                assert node.value is not None
                values.append(float(node.value))
            else:
                values.append(0.0)

        # The "types" and "values" can be normalized 0-1
        if normalize is True:
            # https://bit.ly/3irAalH
            x = np.asfarray(values)
            if x.sum() != 0.0:
                x = (x - min(x)) / (max(x) - min(x) + 1e-32)
            values = x.tolist()
            x = np.asfarray(vectors)
            if x.sum() != 0.0:
                x = (x - min(x)) / (max(x) - min(x) + 1e-32)
            vectors = x.tolist()

        # Pass a coarse integer in [0, 10] indicating the relative episode
        # time, where 0 is the episode start and 10 is the episode end as
        # indicated by the maximum allowed number of actions.
        step = int(self.max_moves - self.agent.moves_remaining)
        time = int(step / self.max_moves * 10)

        # Pad observations to max_seq_len if specified
        if max_seq_len is not None:
            values = pad_array(values, max_seq_len, 0.0)
            vectors = pad_array(vectors, max_seq_len, 0.0)
            move_mask = [pad_array(m, max_seq_len, 0) for m in move_mask]

        return MathyObservation(nodes=vectors,
                                mask=move_mask,
                                values=values,
                                type=hash_type,
                                time=[time])
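
The normalization block above is a plain min-max rescale into [0, 1], with a tiny epsilon guarding against division by zero. A standalone sketch of the same computation (illustrative helper, not part of the library):

import numpy as np
from typing import List

def min_max_normalize(values: List[float]) -> List[float]:
    # Mirrors to_observation: leave all-zero inputs alone, otherwise rescale
    # into [0, 1]; the epsilon keeps constant inputs from dividing by zero.
    x = np.asarray(values, dtype=float)
    if x.sum() == 0.0:
        return x.tolist()
    return ((x - x.min()) / (x.max() - x.min() + 1e-32)).tolist()

print(min_max_normalize([4.0, 0.0, 2.0]))  # [1.0, 0.0, 0.5]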
Example #11
def test_parser_exceptions() -> None:
    parser = ExpressionParser()
    expectations = [
        ["1=5+-", InvalidSyntax, "parse_unary not expected"],
        ["x+4^-", InvalidSyntax, "parse_unary not expected"],
        ["4^4/.", ValueError, "parse_unary coerce_to_number"],
        ["4*/", InvalidSyntax, "parse_mult not expected"],
        ["4+3+3     3", TrailingTokens, "_parse trailing tokens check"],
        ["4^+", InvalidSyntax, "parse_exponent check unary"],
        ["4+", UnexpectedBehavior, "parse_add not expected and not right"],
        ["", InvalidExpression, "_parse initial next check"],
        ["4=+", UnexpectedBehavior, "parse_equal not expected"],
        ["+!", InvalidSyntax, "parse_equal first check"],
        ["=+", InvalidSyntax, "parse_equal first check"],
    ]
    for in_str, out_err, meta in expectations:
        with pytest.raises(out_err):
            parser.parse(in_str)
        assert meta != "", "add note about which parser fn throws for this case"
Example #12
def test_util_get_sub_terms():
    expectations = [
        ["-f", 1],
        ["70656 * (x^2 * z^6)", 2],
        ["4x^2 * z^6 * y", 3],
        ["2x^2", 1],
        ["x^2", 1],
        ["2", 1],
    ]
    invalid_expectations = [
        # can't have more than one term
        ["4 + 4", False]
    ]
    parser = ExpressionParser()
    for text, output in expectations + invalid_expectations:
        exp = parser.parse(text)
        sub_terms = get_sub_terms(exp)
        if output is False:
            assert sub_terms == output, text
        else:
            assert isinstance(sub_terms, list)
            assert len(sub_terms) == output, text
Example #13
File: env.py Project: mathy/mathy_envs
class MathyEnv:
    """Implement a math solving game where a player wins by executing the
    right sequence of actions to reduce a math expression to an agreeable
    basic representation in as few moves as possible."""

    rules: List[BaseRule]
    max_moves: int
    max_seq_len: int
    verbose: bool
    reward_discount: float
    parser: ExpressionParser
    valid_actions_mask_cache: Dict[str, List[List[int]]]
    valid_rules_cache: Dict[str, List[int]]
    invalid_action_response: InvalidActionResponses
    previous_state_penalty: bool
    preferred_term_commute: bool

    def __init__(
        self,
        *,
        rules: Optional[List[BaseRule]] = None,
        max_moves: int = 20,
        verbose: bool = False,
        invalid_action_response: InvalidActionResponses = "raise",
        reward_discount: float = 0.99,
        max_seq_len: int = 128,
        previous_state_penalty: bool = True,
        preferred_term_commute: bool = False,
    ):
        self.discount = reward_discount
        self.previous_state_penalty = previous_state_penalty
        self.verbose = verbose
        self.max_moves = max_moves
        self.max_seq_len = max_seq_len
        self.invalid_action_response = invalid_action_response
        self.parser = ExpressionParser()
        if rules is None:
            self.rules = MathyEnv.core_rules(
                preferred_term_commute=preferred_term_commute)
        else:
            self.rules = rules
        self.valid_actions_mask_cache = dict()
        self.valid_rules_cache = dict()

        if self.invalid_action_response not in INVALID_ACTION_RESPONSES:
            raise ValueError(
                f"Unknown invalid action behavior: {self.invalid_action_response}\n"
                f"Expected one of: {', '.join(INVALID_ACTION_RESPONSES)}")

    @classmethod
    def core_rules(cls,
                   preferred_term_commute: bool = False) -> List[BaseRule]:
        """Return the mathy core agent actions"""
        return [
            ConstantsSimplifyRule(),
            CommutativeSwapRule(preferred=preferred_term_commute),
            DistributiveMultiplyRule(),
            DistributiveFactorOutRule(),
            AssociativeSwapRule(),
            VariableMultiplyRule(),
        ]

    @property
    def action_size(self) -> int:
        """Return the number of available actions"""
        return len(self.rules) * self.max_seq_len

    def finalize_state(self, state: MathyEnvState) -> None:
        """Perform final checks on a problem state, to ensure the episode yielded
        results that were uncorrupted by transformation errors."""
        from_timestep: MathyEnvStateStep = state.agent.history[0]
        to_timestep: MathyEnvStateStep = state.agent.history[-1]
        compare_expression_string_values(str(from_timestep.raw),
                                         str(to_timestep.raw),
                                         state.agent.history)

    def get_env_namespace(self) -> str:
        """Return a unique dot namespaced string representing the current
        environment. e.g. mycompany.envs.differentiate"""
        raise NotImplementedError("subclass must implement this")

    def get_rewarding_actions(self,
                              state: MathyEnvState) -> List[Type[BaseRule]]:
        """Get the list of rewarding action types. When these actions
        are selected, the agent gets a positive reward."""
        # NOTE: by default we give a positive reward for most actions taken. Reward
        #       values are only applied AFTER penalties, so things like reentrant
        #       states become negative reward even if their action is otherwise
        #       rewarding.
        return [
            ConstantsSimplifyRule,
            DistributiveFactorOutRule,
            VariableMultiplyRule,
        ]

    def get_penalizing_actions(self,
                               state: MathyEnvState) -> List[Type[BaseRule]]:
        """Get the list of penalizing action types. When these actions
        are selected, the agent gets a negative reward."""
        return [
            AssociativeSwapRule,
            DistributiveMultiplyRule,
        ]

    def max_moves_fn(self, problem: MathyEnvProblem,
                     config: MathyEnvProblemArgs) -> int:
        """Return the environment specific maximum move count for a given prolem."""
        return problem.complexity * 3

    def transition_fn(
        self,
        env_state: MathyEnvState,
        expression: MathExpression,
        features: MathyObservation,
    ) -> Optional[time_step.TimeStep]:
        """Provide environment-specific transitions per timestep."""
        return None

    def problem_fn(self, params: MathyEnvProblemArgs) -> MathyEnvProblem:
        """Return a problem for the environment given a set of parameters
        to control problem generation.

        This is implemented per environment so each environment can
        generate its own dataset with no required configuration."""
        raise NotImplementedError("This must be implemented in a subclass")

    def state_to_observation(
            self,
            state: MathyEnvState,
            max_seq_len: Optional[int] = None) -> MathyObservation:
        """Convert an environment state into an observation that can be used
        by a training agent."""

        action_mask = self.get_valid_moves(state)
        observation = state.to_observation(move_mask=action_mask,
                                           parser=self.parser,
                                           max_seq_len=max_seq_len)
        return observation

    def get_win_signal(self, env_state: MathyEnvState) -> float:
        """Calculate the reward value for completing the episode. This is done
        so that the reward signal can be scaled based on the time it took to
        complete the episode."""
        tiny = 3e-10
        total_moves = max(tiny, env_state.max_moves)
        # guard against divide by zero with max and a small value
        current_move = max(tiny, total_moves - env_state.agent.moves_remaining)
        bonus = (total_moves / current_move) / total_moves
        # If the episode is not very short, and the agent completes in half
        # the number of allowed steps, double the bonus signal
        if total_moves > 10 and current_move < total_moves / 2:
            bonus *= 2
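        # Worked example (illustrative numbers): with total_moves=20 and a win
        # at current_move=5, bonus = (20 / 5) / 20 = 0.2, doubled to 0.4 by the
        # branch above, so the reward is min(2.0, EnvRewards.WIN + 0.4).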
        return min(2.0, EnvRewards.WIN + bonus)

    def get_lose_signal(self, env_state: MathyEnvState) -> float:
        """Calculate the reward value for failing to complete the episode. This is done
        so that the reward signal can be problem-type dependent."""
        return EnvRewards.LOSE

    def get_state_transition(self,
                             env_state: MathyEnvState) -> time_step.TimeStep:
        """Given an input state calculate the transition value of the timestep.

        # Parameters
            env_state: current env_state

        # Returns
            transition: the current state value transition
        """
        agent = env_state.agent
        expression = self.parser.parse(agent.problem)
        features = env_state.to_observation(self.get_valid_moves(env_state),
                                            parser=self.parser)
        root = expression.get_root()
        assert isinstance(root, MathExpression)

        # Subclass specific win conditions happen here. Custom win-conditions
        # outside of that can override this method entirely.
        result = self.transition_fn(env_state, root, features)
        if result is not None:
            return result

        # Check the turn count last because if the previous move that incremented
        # the turn over the count resulted in a win-condition, we want it to be honored.
        if env_state.agent.moves_remaining <= 0:
            return time_step.termination(features,
                                         self.get_lose_signal(env_state))

        # The agent is penalized for returning to a previous state.
        if self.previous_state_penalty is True:
            for key, group in groupby(
                    sorted([f"{h.raw}" for h in env_state.agent.history])):
                list_count = len(list(group))
                if list_count <= 1 or key != expression.raw:
                    continue

                # After more than (n) visits to the same state, you lose.
                if list_count > 3:
                    return time_step.termination(
                        features, self.get_lose_signal(env_state))

                # NOTE: the reward is scaled by # of times this state has been visited
                return time_step.transition(
                    features,
                    reward=EnvRewards.PREVIOUS_LOCATION * list_count,
                    discount=self.discount,
                )

        if len(agent.history) > 0:
            last_timestep = agent.history[-1]
            rule = self.get_rule_from_timestep(last_timestep)
            reward_actions = self.get_rewarding_actions(env_state)
            # The rewarding_actions can be user specified
            for rewarding_class in reward_actions:
                if isinstance(rule, rewarding_class):
                    return time_step.transition(
                        features,
                        reward=EnvRewards.HELPFUL_MOVE,
                        discount=self.discount,
                    )

            penalty_actions = self.get_penalizing_actions(env_state)
            # The penalizing_actions can be user specified
            for penalty_class in penalty_actions:
                if isinstance(rule, penalty_class):
                    return time_step.transition(
                        features,
                        reward=EnvRewards.UNHELPFUL_MOVE,
                        discount=self.discount,
                    )

        # We're in a new state, and the agent is a little older.
        return time_step.transition(features,
                                    reward=EnvRewards.TIMESTEP,
                                    discount=self.discount)

    def get_next_state(
        self, env_state: MathyEnvState, action: Union[int, ActionType]
    ) -> Tuple[MathyEnvState, time_step.TimeStep, ExpressionChangeRule]:
        """
        # Parameters
        env_state: current env_state
        action:    a tuple of two integers representing the rule and node to act on

        # Returns
        next_state: env_state after applying action

        transition: the timestep that represents the state transition

        change: the change descriptor describing the change that happened
        """
        action = self.to_action(action)
        agent = env_state.agent
        expression = self.parser.parse(agent.problem)
        assert isinstance(
            action, (tuple, list)
        ), f"Expected tuple action, but received: {type(action)} {action}"
        action_index, token_index = action
        token = self.get_token_at_index(expression, token_index)
        operation = self.rules[action_index]

        op_not_rule = not isinstance(operation, BaseRule)
        op_cannot_apply = token is None or operation.can_apply_to(
            token) is False
        if token is None or op_not_rule or op_cannot_apply:
            if self.invalid_action_response == "raise":
                steps = int(env_state.max_moves - agent.moves_remaining)
                msg = "Step: {} - Invalid action({}) '{}' for expression '{}'.".format(
                    steps, action, type(operation), expression)
                raise_with_history("Invalid Action", msg, agent.history)
            elif self.invalid_action_response == "penalize":
                out_env = MathyEnvState.copy(env_state)
                obs = out_env.to_observation(self.get_valid_moves(out_env),
                                             parser=self.parser)
                transition = time_step.transition(obs, EnvRewards.INVALID_MOVE)
                return out_env, transition, ExpressionChangeRule(BaseRule())
            elif self.invalid_action_response == "terminal":
                out_env = MathyEnvState.copy(env_state)
                obs = out_env.to_observation(self.get_valid_moves(out_env),
                                             parser=self.parser)
                transition = time_step.termination(
                    obs, self.get_lose_signal(env_state))
                return out_env, transition, ExpressionChangeRule(BaseRule())

        assert token is not None
        change = operation.apply_to(token.clone_from_root())
        assert change.result is not None
        root = change.result.get_root()
        change_name = operation.name
        out_problem = str(root)
        out_env = env_state.get_out_state(
            problem=out_problem,
            action=action,
            moves_remaining=agent.moves_remaining - 1,
        )

        transition = self.get_state_transition(out_env)
        if self.verbose:
            self.print_state(out_env, change_name[:25].lower(), token_index,
                             change, transition.reward)
        return out_env, transition, change

    def print_state(
        self,
        env_state: MathyEnvState,
        action_name: str,
        token_index: int = -1,
        change: Optional[ExpressionChangeRule] = None,
        change_reward: float = 0.0,
        pretty: bool = False,
    ) -> None:
        """Render the given state to stdout for visualization"""
        print(
            self.render_state(env_state, action_name, token_index, change,
                              change_reward, pretty))

    def is_terminal_state(self, env_state: MathyEnvState) -> bool:
        """Determine if a given state is terminal or not.

        # Arguments
        env_state (MathyEnvState): The state to inspect

        # Returns
        (bool): A boolean indicating if the state is terminal or not.
        """
        return is_terminal_transition(self.get_state_transition(env_state))

    def print_history(self,
                      env_state: MathyEnvState,
                      pretty: bool = True) -> None:
        """Render the history of an episode from a given state.

        # Arguments
        env_state (MathyEnvState): The state to render the history of.
        """
        history: List[MathyEnvStateStep] = env_state.agent.history[:]
        initial_step: MathyEnvStateStep = history.pop(0)
        curr_state: MathyEnvState = MathyEnvState(
            problem=initial_step.raw,
            max_moves=env_state.max_moves,
        )
        self.print_state(curr_state, "initial-state", pretty=pretty)
        while len(history) > 0:
            step: MathyEnvStateStep = history.pop(0)
            curr_state, transition, change = self.get_next_state(
                curr_state, step.action)
            rule_idx, token_idx = step.action
            rule: BaseRule = self.rules[rule_idx]
            rule_name: str = rule.name[:25].lower()
            self.print_state(
                pretty=pretty,
                env_state=curr_state,
                action_name=rule_name,
                token_index=token_idx,
                change=change,
                change_reward=transition.reward,
            )

    def render_state(
        self,
        env_state: MathyEnvState,
        action_name: str,
        token_index: int = -1,
        change: Optional[ExpressionChangeRule] = None,
        change_reward: float = 0.0,
        pretty: bool = False,
    ) -> str:
        """Render the given state to a string suitable for printing to a log"""
        changed_problem = env_state.agent.problem
        if change is not None and change.result is not None:
            root = change.result.get_root()
            assert isinstance(root, MathExpression)
            changed_problem = root.terminal_text

        action_name = f"{action_name.lower()}({token_index})"
        output = """{:<25} | {}""".format(action_name.lower(), changed_problem)

        def get_move_shortname(index: int, move: int) -> str:
            if move == 0:
                return "--"
            if move >= len(self.rules):
                return "xx"
            return self.rules[index].code.lower()

        moves_left = str(env_state.agent.moves_remaining).zfill(2)
        valid_rules = self.get_valid_rules(env_state)
        valid_moves = self.get_valid_moves(env_state)
        num_moves = "{}".format(len(np.nonzero(valid_moves)[0])).zfill(3)
        move_codes = [
            get_move_shortname(i, m) for i, m in enumerate(valid_rules)
        ]
        moves = " ".join(move_codes)
        reward = f"{change_reward:.2}"
        reward = f"{reward:<5}"
        if pretty:
            return output
        return f"{num_moves} | {moves} | {moves_left} | {reward} | {output}"

    def random_action(
        self,
        expression: MathExpression,
        rule: Optional[Type[BaseRule]] = None,
    ) -> Tuple[int, int]:
        """Get a random action index that represents a particular rule"""

        if rule is not None:
            found = -1
            for rule_idx, r in enumerate(self.rules):
                if isinstance(r, rule):  # type:ignore
                    found = rule_idx
                    break
            if found == -1:
                raise ValueError(
                    f"The action {rule} does not exist in the environment rule list"
                )
            all_actions = self.get_actions_for_node(expression, [rule])
            valid_actions = np.nonzero(all_actions[found])
            action = random.choice(valid_actions[0])
            return (found, int(action))

        all_actions = self.get_actions_for_node(expression)
        valid_rules = [i for i, r in enumerate(all_actions) if 1 in r]
        if len(valid_rules) == 0:
            raise ValueError(f"no valid actions for expression: {expression}")
        chosen_rule = random.choice(valid_rules)
        valid_actions = [
            i for i, r in enumerate(all_actions[chosen_rule]) if r == 1
        ]
        action = random.choice(valid_actions)
        return chosen_rule, action

    def get_initial_state(
            self,
            params: Optional[MathyEnvProblemArgs] = None,
            print_problem: bool = True
    ) -> Tuple[MathyEnvState, MathyEnvProblem]:
        """Generate an initial MathyEnvState for an episode"""
        config = params if params is not None else MathyEnvProblemArgs()
        prob: MathyEnvProblem = self.problem_fn(config)
        self.valid_actions_mask_cache = dict()
        self.valid_rules_cache = dict()
        self.parser.clear_cache()
        self.max_moves = self.max_moves_fn(prob, config)

        # Build and return the initial state
        env_state = MathyEnvState(
            problem=prob.text,
            problem_type=self.get_env_namespace(),
            max_moves=self.max_moves,
            num_rules=len(self.rules),
        )
        if print_problem and self.verbose:
            self.print_state(env_state, "initial-state")
        return env_state, prob

    def get_agent_actions_count(self, env_state: MathyEnvState) -> int:
        """Return number of all possible actions"""
        node_count = len(self.parser.parse(env_state.agent.problem).to_list())
        return self.action_size * node_count

    def get_token_at_index(self, expression: MathExpression,
                           index: int) -> Optional[MathExpression]:
        """Get the token that is `index` from the left of the expression"""
        count = 0
        result: Optional[MathExpression] = None

        def visit_fn(node: BinaryTreeNode, depth: int,
                     data: Any) -> Optional[VisitStop]:
            nonlocal result, count
            result = node  # type:ignore
            if count == index:
                return STOP
            count = count + 1
            return None

        expression.visit_inorder(visit_fn)
        return result

    def get_valid_moves(self, env_state: MathyEnvState) -> List[List[int]]:
        """Get a 2d list describing the valid moves for the current state.

        The first dimension contains the list of known rules in the order that
        they're registered, and the second dimension contains one 1/0 entry per
        expression node, indicating whether the rule can be applied to the node
        at that index.
        """
        agent = env_state.agent
        expression: Optional[MathExpression] = None
        try:
            expression = self.parser.parse(agent.problem)
        except InvalidSyntax as err:
            raise_with_history(self.get_env_namespace(), err.message,
                               env_state.agent.history)
        assert expression is not None
        return self.get_actions_for_node(expression)

    def get_valid_rules(self, env_state: MathyEnvState) -> List[int]:
        """Get a vector the length of the number of valid rules that is
        filled with 0/1 based on whether the rule has any nodes in the
        expression that it can be applied to.

        !!! note

            If you want to get a list of which nodes each rule can be
            applied to, prefer to use the `get_valid_moves` method.
        """
        key = self.to_hash_key(env_state)
        if key in self.valid_rules_cache:
            return self.valid_rules_cache[key]
        expression = self.parser.parse(env_state.agent.problem)
        actions = [0] * len(self.rules)
        for rule_index, rule in enumerate(self.rules):
            nodes = rule.find_nodes(expression)
            actions[rule_index] = 0 if len(nodes) == 0 else 1
        self.valid_rules_cache[key] = actions[:]
        return actions

    def get_rule_from_timestep(self, time_step: MathyEnvStateStep) -> BaseRule:
        return self.rules[time_step.action[0]]

    def get_actions_for_node(
        self,
        expression: MathExpression,
        rule_list: Optional[List[Type[BaseRule]]] = None,
    ) -> List[List[int]]:
        """Return a valid actions mask for the given expression and rule list.

        Action masks are 2d lists of shape (num_rules, node_count) where a 0 indicates
        the action is not valid in the current state, and a 1 indicates that it is
        a valid action to take."""
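        # Illustrative: with the 6 core rules and a 5-node expression this
        # builds a 6 x 5 grid of 0/1 flags; masks are padded out to
        # max_seq_len later, when states are converted to observations.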
        key = str(expression)
        if rule_list is None and key in self.valid_actions_mask_cache:
            return self.valid_actions_mask_cache[key][:]
        node_count = len(expression.to_list())
        rule_count = len(self.rules)
        actions = [[0] * node_count for _ in range(rule_count)]
        for rule_index, rule in enumerate(self.rules):
            if rule_list is not None:
                if not isinstance(rule, tuple(rule_list)):  # type:ignore
                    continue
            nodes = rule.find_nodes(expression)
            for node in nodes:
                assert node.r_index is not None
                actions[rule_index][node.r_index] = 1
        if rule_list is None:
            self.valid_actions_mask_cache[key] = actions[:]
        return actions

    def to_hash_key(self, env_state: MathyEnvState) -> str:
        """Convert env_state to a string for MCTS cache"""
        return env_state.agent.problem

    def to_action(self, action: Union[int, ActionType]) -> ActionType:
        """Resolve a given action input to a tuple of (rule_index, node_index).

        When given an int, it is treated as an index into the flattened 2d action
        space. When given a tuple, it is assumed to be (rule, node)"""
        if isinstance(action, (tuple, list)):
            return action
        token_index = action % self.max_seq_len
        action_index = int((action - token_index) / self.max_seq_len)
        return action_index, token_index
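
The flattened encoding in to_action is simple divmod arithmetic over max_seq_len. A standalone sketch of the round trip (assumed helper names and an assumed max_seq_len of 128, for illustration only):

from typing import Tuple

def encode_action(rule_index: int, node_index: int, max_seq_len: int = 128) -> int:
    # Inverse of MathyEnv.to_action: flatten (rule, node) into one integer.
    return rule_index * max_seq_len + node_index

def decode_action(action: int, max_seq_len: int = 128) -> Tuple[int, int]:
    # Same arithmetic as MathyEnv.to_action.
    node_index = action % max_seq_len
    rule_index = (action - node_index) // max_seq_len
    return rule_index, node_index

flat = encode_action(2, 5)  # rule 2, node 5 -> 2 * 128 + 5 = 261
assert decode_action(flat) == (2, 5)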
Example #14
def test_parser_factorials() -> None:
    """should parse factorials"""
    parser = ExpressionParser()
    expression = parser.parse("5!")
    # 5! = 5 * 4 * 3 * 2 * 1 = 120
    assert expression.evaluate() == 120