示例#1
0
def test_env_action_masks():
    """The valid-move mask has one row per rule and one column per tree node."""
    input_problem = "4x + 2x"
    env = MathyEnv(invalid_action_response="raise")
    state = MathyEnvState(problem=input_problem, max_moves=35)
    node_count = len(env.parser.parse(input_problem).to_list())
    mask = env.get_valid_moves(state)
    assert len(mask) == len(env.rules)
    assert len(mask[0]) == node_count
示例#2
0
def test_mathy_env_invalid_action_behaviors():
    """Apply three random valid actions and verify an observation is produced."""
    env = MathyEnv()
    assert env is not None
    state = MathyEnvState(
        problem="5y * 9x + 8z + 8x + 3z * 10y * 11x + 10y", max_moves=35
    )
    for _ in range(3):
        moves = env.get_valid_moves(state)
        # Rules that can be applied somewhere in the current tree.
        applicable = [idx for idx, nodes in enumerate(moves) if 1 in nodes]
        random.shuffle(applicable)
        rule_index = applicable[0]
        # First node where the chosen rule applies.
        node_index = [n for n, flag in enumerate(moves[rule_index]) if flag == 1][0]
        state, value, changed = env.get_next_state(state, (rule_index, node_index))
    assert state.to_observation([]) is not None
示例#3
0
def test_mathy_env_previous_state_penalty():
    """With previous_state_penalty=True, revisiting an already-seen problem
    state yields a negative reward that grows with each visit, and the episode
    terminates after the state is revisited too many times.

    The input "x * y" parses to just three nodes ["x", "*", "y"], so commuting
    the same node repeatedly toggles between x * y and y * x — an easy way to
    revisit the initial state over and over.
    """
    env = MathyEnv(previous_state_penalty=True)
    rule_idx = 1
    node_idx = 1
    assert isinstance(env.rules[rule_idx],
                      CommutativeSwapRule), "update rule_idx"
    action = (rule_idx, node_idx)
    state = MathyEnvState(problem="x * y", max_moves=10)
    # Commute once so that re-applying the same action revisits the
    # initial state.
    state, _, _ = env.get_next_state(state, action)

    # After three visits to the same state, the game ends.
    prior_reward = 0.0
    saw_terminal = False
    for visit in range(3):
        state, transition, changed = env.get_next_state(state, action)
        assert transition.reward < 0.0
        # The penalty scales up with the number of visits to the state.
        assert transition.reward < prior_reward
        prior_reward = transition.reward

        if visit < 2:
            # Step through the opposite state and ignore its transition; only
            # revisits of the initial state matter for this test.
            state, _, _ = env.get_next_state(state, action)
        else:
            # The third revisit should produce a terminal transition.
            assert is_terminal_transition(transition) is True
            saw_terminal = True

    assert saw_terminal is True, "did not receive expected terminal transition"
示例#4
0
def test_env_random_actions():
    """random_action samples a usable action with or without a rule filter."""
    env = MathyEnv(invalid_action_response="raise")

    # Filtered: sample an action restricted to a given rule type.
    state = MathyEnvState(problem="4x + 2x + 7 + y")
    tree = env.parser.parse(state.agent.problem)
    env.get_next_state(state, env.random_action(tree, AssociativeSwapRule))

    # Unfiltered: sample from actions of every rule type.
    state = MathyEnvState(problem="4x + 2x + 7 + y")
    tree = env.parser.parse(state.agent.problem)
    env.get_next_state(state, env.random_action(tree))
示例#5
0
def test_env_init():
    """The base MathyEnv constructs fine but its abstract hooks must raise."""
    env = MathyEnv()
    assert env is not None
    # The default env is abstract and cannot be used for problem solving
    # directly: both problem-definition hooks raise NotImplementedError.
    for abstract_hook in (env.get_initial_state, env.get_env_namespace):
        with pytest.raises(NotImplementedError):
            abstract_hook()
示例#6
0
def test_mathy_env_preferred_term_commute():
    """preferred_term_commute toggles commuting of already-ordered terms."""
    rule_idx = 1
    state = MathyEnvState(problem="5y", max_moves=1)

    # Disabled: terms already in preferred order cannot be commuted.
    env = MathyEnv(preferred_term_commute=False)
    assert isinstance(env.rules[rule_idx],
                      CommutativeSwapRule), "update rule_idx"
    commute_nodes = env.get_valid_moves(state)[rule_idx]
    assert 1 not in commute_nodes, "shouldn't be able to commute preferred order terms"

    # Enabled: the same terms become valid commute targets.
    env = MathyEnv(preferred_term_commute=True)
    commute_nodes = env.get_valid_moves(state)[rule_idx]
    assert 1 in commute_nodes, "should be able to commute preferred order terms"
示例#7
0
def test_env_invalid_action_behaviors():
    """Each invalid_action_response mode handles a bad action differently:
    "raise" raises ValueError, "penalize" returns a negative reward without
    ending the episode, and "terminal" ends the episode.
    """
    problem = "4x + 2x"
    env = MathyEnv(invalid_action_response="raise")
    state = MathyEnvState(problem=problem, max_moves=35)
    moves = env.get_valid_moves(state)
    # Build an invalid action: a rule with no applicable nodes, applied to a
    # node where it cannot apply.
    unusable_rules = [idx for idx, nodes in enumerate(moves) if 1 not in nodes]
    random.shuffle(unusable_rules)
    bad_rule = unusable_rules[0]
    bad_node = [n for n, flag in enumerate(moves[bad_rule]) if flag == 0][0]
    action = (bad_rule, bad_node)

    # "raise": selecting an invalid action raises an error.
    state = MathyEnvState(problem=problem, max_moves=35)
    with pytest.raises(ValueError):
        env.get_next_state(state, action)

    # "penalize": the agent receives a negative reward but the episode
    # continues.
    env = MathyEnv(invalid_action_response="penalize")
    state = MathyEnvState(problem=problem, max_moves=35)
    _, transition, _ = env.get_next_state(state, action)
    assert transition.reward == EnvRewards.INVALID_MOVE
    assert is_terminal_transition(transition) is False

    # "terminal": choosing an invalid action ends the episode.
    env = MathyEnv(invalid_action_response="terminal")
    state = MathyEnvState(problem=problem, max_moves=35)
    _, transition, _ = env.get_next_state(state, action)
    assert is_terminal_transition(transition) is True
示例#8
0
def test_env_init_check_invalid_action_response():
    """The constructor validates invalid_action_response against known options."""
    # Unknown values are rejected at construction time.
    with pytest.raises(ValueError):
        MathyEnv(invalid_action_response="something_wrong")  # type:ignore
    # Every documented option constructs successfully.
    choice: Any
    for choice in INVALID_ACTION_RESPONSES:
        assert MathyEnv(invalid_action_response=choice) is not None
示例#9
0
 def __init__(self, **kwargs):
     """Initialize the env, extending the core rule set with PlusNegationRule."""
     super().__init__(**kwargs)
     self.rules = [*MathyEnv.core_rules(), PlusNegationRule()]