def test_get_new_states_probs(self):
    """Checks the states and prior probabilities produced by one append step.

    The rules applied so far form the parse tree

        S
        |
        S '+' T
        |
        T

    The rule sequence is a preorder traversal of that tree, so the next
    nonterminal to expand is the leftmost 'T'. Only production rules whose
    left-hand side is T are valid at this point. self.grammar is expected to
    hold these rules, in this order:
        'S -> S "+" T'
        'S -> T'
        'T -> "(" S ")"'
        'T -> "x"'
    Appending either of the first two rules yields an invalid state (None,
    with prior probability nan). Appending either of the last two yields a
    new state, and both get equal prior probabilities.
    """
    applied_rules = ['S -> S "+" T', 'S -> T']
    state = states.ProductionRulesState(
        self._strings_to_production_rules(applied_rules))

    # The two S-rules are invalid next steps; only the T-rules extend the state.
    expected_new_states = [None, None] + [
        states.ProductionRulesState(
            self._strings_to_production_rules(applied_rules + [next_rule]))
        for next_rule in ('T -> "(" S ")"', 'T -> "x"')
    ]

    policy = policies.ProductionRuleAppendPolicy(grammar=self.grammar)
    new_states, action_probs = policy.get_new_states_probs(state)

    # assert_allclose treats nan == nan as equal by default (equal_nan=True).
    np.testing.assert_allclose(action_probs, [np.nan, np.nan, 0.5, 0.5])
    self.assertEqual(len(new_states), len(expected_new_states))
    for actual, expected in zip(new_states, expected_new_states):
        self.assertEqual(actual, expected)
def test_get_new_states_probs_type_error(self):
    """get_new_states_probs rejects a state that is not a ProductionRulesState.

    A bare states.StateBase() is passed, and the test expects a TypeError
    whose message matches the pattern below.
    """
    policy = policies.ProductionRuleAppendPolicy(grammar=self.grammar)
    # assertRaisesRegexp is a deprecated alias that was removed in Python
    # 3.12; assertRaisesRegex is the supported name.
    # NOTE(review): the pattern keeps the 'shoud' spelling, presumably to
    # match the policy's error text verbatim. Fix both together if the
    # message is corrected (TODO confirm against policies module).
    with self.assertRaisesRegex(
        TypeError,
        r'Input state shoud be an instance of '
        r'states\.ProductionRulesState'):
        policy.get_new_states_probs(states.StateBase())