def test_set_default_value(self):
  """Checks that set_default_value() overwrites the stored default value."""
  base_reward = rewards.RewardBase()
  # A freshly constructed reward has no default value.
  self.assertIsNone(base_reward._default_value)

  # The default value can be changed, and changed again afterwards; each
  # call is expected to replace whatever was stored before.
  for new_value, expected in ((42, 42.), (-1.5, -1.5)):
    base_reward.set_default_value(new_value)
    self.assertAlmostEqual(base_reward._default_value, expected)
def test_evaluate_not_terminal_without_default_value(self):
  """Checks evaluate() on a non-terminal state when no default value is set.

  With allow_nonterminal=False and default_value=None, evaluate() is expected
  to raise ValueError. Once a default value has been set, evaluate() is
  expected to return that value instead.
  """
  not_terminal_state = states.ProductionRulesState(
      production_rules_sequence=[])
  # Force the state to report itself as non-terminal.
  not_terminal_state.is_terminal = mock.MagicMock(return_value=False)
  reward = rewards.RewardBase(allow_nonterminal=False, default_value=None)

  # assertRaisesRegex is used instead of the assertRaisesRegexp alias, which
  # was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(
      ValueError,
      'allow_nonterminal is False and '
      'default_value is None, but state is not '
      'terminal'):
    reward.evaluate(not_terminal_state)

  # ValueError will not be raised if default value is set.
  reward.set_default_value(42)
  self.assertAlmostEqual(reward.evaluate(not_terminal_state), 42.)
def test_evaluate_not_terminal_with_default_value(self):
  """Checks evaluate() returns the default value for a non-terminal state."""
  state = states.ProductionRulesState(production_rules_sequence=[])
  # Force the state to report itself as non-terminal.
  state.is_terminal = mock.MagicMock(return_value=False)

  reward_fn = rewards.RewardBase(allow_nonterminal=False, default_value=42)
  evaluated = reward_fn.evaluate(state)

  # The default value supplied at construction time is what comes back.
  self.assertAlmostEqual(evaluated, 42)
def test_evaluate_not_implemented(self):
  """Checks that RewardBase.evaluate() raises NotImplementedError."""
  state = states.ProductionRulesState(production_rules_sequence=[])
  reward = rewards.RewardBase()
  # assertRaisesRegex is used instead of the assertRaisesRegexp alias, which
  # was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(NotImplementedError,
                              'Must be implemented by subclass'):
    reward.evaluate(state)
def test_set_post_transformer_not_callable(self):
  """Checks that set_post_transformer() rejects a non-callable argument."""
  # Construct the reward outside the assertion context so that a TypeError
  # raised by the constructor cannot be mistaken for the expected failure of
  # set_post_transformer().
  reward = rewards.RewardBase()
  # assertRaisesRegex is used instead of the assertRaisesRegexp alias, which
  # was removed from unittest in Python 3.12.
  with self.assertRaisesRegex(TypeError, 'post_transformer is not callable'):
    reward.set_post_transformer(post_transformer=42)