示例#1
0
    def test_get_new_states_probs(self):
        """Check the new states and prior probabilities for a partial parse.

        Starting from the sequence ['S -> S "+" T', 'S -> T'], the policy is
        expected to return None with a nan probability for the first two
        actions, and a new state with probability 0.5 for each of the last
        two actions.
        """
        prefix_rules = [
            'S -> S "+" T',
            'S -> T',
        ]
        state = states.ProductionRulesState(
            self._strings_to_production_rules(prefix_rules))
        # The production rules sequence above is parsed as
        #                 S
        #                 |
        #              S '+' T
        #              |
        #              T
        #
        # The sequence is the preorder traversal of the parsing tree, so the
        # next symbol to parse is the 'T' on the left side of the tree. Only
        # production rules whose left hand side symbol is T are valid.
        # Thus, for a grammar with production rules:
        # 'S -> S "+" T'
        # 'S -> T'
        # 'T -> "(" S ")"'
        # 'T -> "x"'
        # appending either of the first two rules creates an invalid state
        # with prior probability nan. Either of the last two rules can be
        # appended and creates a new state; both get equal prior probability.
        expected_new_states = [None, None] + [
            states.ProductionRulesState(
                self._strings_to_production_rules(prefix_rules + [rule]))
            for rule in ('T -> "(" S ")"', 'T -> "x"')
        ]

        policy = policies.ProductionRuleAppendPolicy(grammar=self.grammar)
        new_states, action_probs = policy.get_new_states_probs(state)

        np.testing.assert_allclose(action_probs, [np.nan, np.nan, 0.5, 0.5])
        self.assertEqual(len(new_states), len(expected_new_states))
        # Compare element-wise so a failure reports the offending state.
        for actual, expected in zip(new_states, expected_new_states):
            self.assertEqual(actual, expected)
示例#2
0
 def test_get_new_states_probs_type_error(self):
     """Check that a plain StateBase input raises TypeError.

     The policy is called with a states.StateBase instance rather than a
     states.ProductionRulesState, and the raised message is matched against
     the expected pattern.
     """
     policy = policies.ProductionRuleAppendPolicy(grammar=self.grammar)
     # assertRaisesRegex is used instead of the deprecated assertRaisesRegexp
     # alias, which was removed from unittest in Python 3.12.
     # NOTE(review): the 'shoud' spelling presumably mirrors the message
     # raised by the policy; it is left unchanged so the pattern keeps
     # matching -- confirm before correcting both together.
     with self.assertRaisesRegex(
             TypeError, r'Input state shoud be an instance of '
             r'states\.ProductionRulesState'):
         policy.get_new_states_probs(states.StateBase())