def test_skip_retraining_fn(self):
    """An always-true skip_retraining_fn keeps the agent from learning."""
    env = test_util.DummyEnv()
    burnin_steps = 10

    def _always_skip(action, observation):
        """Always skip retraining."""
        del action, observation
        return True

    agent_params = classifier_agents.ScoringAgentParams(
        burnin=burnin_steps,
        freeze_classifier_after_burnin=False,
        default_action_fn=env.action_space.sample,
        feature_keys=["x"],
        skip_retraining_fn=_always_skip,
    )
    agent = classifier_agents.ThresholdAgent(
        action_space=env.action_space,
        observation_space=env.observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=agent_params,
    )

    # The agent must be unfrozen before each of the first burnin + 1 actions.
    for _step in range(burnin_steps + 1):
        self.assertFalse(agent.frozen)
        agent.act(env.observation_space.sample(), False)

    # freeze_classifier_after_burnin=False, so the agent never freezes; and
    # since retraining was always skipped, no threshold has been learned.
    self.assertFalse(agent.frozen)
    self.assertFalse(agent.global_threshold)
def test_threshold_history_is_recorded(self):
    """Global and group-specific threshold histories are recorded per step."""
    observation_space = gym.spaces.Dict(
        {
            "x": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
            "group": gym.spaces.MultiDiscrete([1]),
        }
    )
    observation_space.seed(100)
    params = classifier_agents.ScoringAgentParams(
        default_action_fn=lambda: 0,
        feature_keys=["x"],
        group_key="group",
        burnin=0,
        threshold_policy=threshold_policies.ThresholdPolicy.EQUALIZE_OPPORTUNITY,
    )
    agent = classifier_agents.ThresholdAgent(
        observation_space=observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=params,
    )

    num_steps = 10
    for _ in range(num_steps):
        agent.act(observation_space.sample(), False)

    # One global threshold entry is recorded for every step taken.
    self.assertLen(agent.global_threshold_history, num_steps)
    self.assertTrue(agent.group_specific_threshold_history)
    # Only the per-group histories matter here; the group keys are unused.
    for history in agent.group_specific_threshold_history.values():
        # Takes 2 extra steps (one to observe features and one to observe label)
        # before any learned group-specific threshold is available.
        self.assertLen(history, num_steps - 2)
def test_agent_can_learn_different_thresholds(self):
    """Training on group-dependent labels yields distinct per-group policies."""
    groups = [0, 1]
    observation_space = gym.spaces.Dict(
        {
            "x": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
            "group": gym.spaces.Discrete(2),
        }
    )
    scoring_params = classifier_agents.ScoringAgentParams(
        default_action_fn=lambda: 0,
        feature_keys=["x"],
        group_key="group",
        threshold_policy=threshold_policies.ThresholdPolicy.EQUALIZE_OPPORTUNITY,
    )
    rng = np.random.RandomState(100)
    agent = classifier_agents.ThresholdAgent(
        observation_space=observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=scoring_params,
        rng=rng,
    )

    def _make_obs(value, group):
        """Build a single observation dict for the given score and group."""
        return {"x": np.array([value]), "group": np.array([group])}

    # Train over the whole range of observations. Expect slightly different
    # thresholds to be learned. The same rng feeds both agent and data, so it
    # is drawn from only after the agent is constructed.
    training_values = rng.rand(100)
    for value in training_values:
        for group in groups:
            agent._act_impl(
                _make_obs(value, group),
                reward=value > 0.5 + 0.1 * group,
                done=False,
            )

    agent.frozen = True
    actions = {
        group: [
            agent.act(_make_obs(value, group), done=False)
            for value in np.linspace(0, 1, 1000)
        ]
        for group in groups
    }

    # The two groups are classified with different policies so they are not
    # exactly equal.
    self.assertNotEqual(actions[0], actions[1])
    self.assertLen(agent.group_specific_thresholds, 2)
def test_one_hot_conversion(self):
    """With convert_one_hot_to_integer, a one-hot feature becomes its index."""
    # Space of 1-hot vectors of length 10.
    one_hot_space = multinomial.Multinomial(10, 1)
    observation_space = gym.spaces.Dict({"x": one_hot_space})
    scoring_params = classifier_agents.ScoringAgentParams(
        default_action_fn=lambda: 0,
        feature_keys=["x"],
        convert_one_hot_to_integer=True,
        threshold_policy=threshold_policies.ThresholdPolicy.SINGLE_THRESHOLD,
    )
    agent = classifier_agents.ThresholdAgent(
        observation_space=observation_space,
        reward_fn=rewards.NullReward(),
        params=scoring_params,
    )

    extracted = agent._get_features({"x": _one_hot(5)})
    self.assertEqual(extracted, [5])
def test_interact_with_env_replicable(self):
    """Drive a ThresholdAgent through test_util.run_test_simulation."""
    dummy_env = test_util.DummyEnv()
    scoring_params = classifier_agents.ScoringAgentParams(
        burnin=10,
        freeze_classifier_after_burnin=False,
        default_action_fn=dummy_env.action_space.sample,
        feature_keys=["x"],
    )
    threshold_agent = classifier_agents.ThresholdAgent(
        action_space=dummy_env.action_space,
        observation_space=dummy_env.observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=scoring_params,
    )
    # NOTE(review): any replicability check presumably lives inside
    # run_test_simulation; this test only drives it — confirm in test_util.
    test_util.run_test_simulation(env=dummy_env, agent=threshold_agent)
def test_frozen_classifier_never_trains(self):
    """An agent constructed frozen never updates its global threshold."""
    sentinel_threshold = 0.123
    env = test_util.DummyEnv()
    scoring_params = classifier_agents.ScoringAgentParams(
        burnin=0, default_action_fn=env.action_space.sample, feature_keys=["x"]
    )
    agent = classifier_agents.ThresholdAgent(
        action_space=env.action_space,
        observation_space=env.observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=scoring_params,
        frozen=True,
    )
    # Seed the threshold with a distinctive value so any update is visible.
    agent.global_threshold = sentinel_threshold

    # Run for some number of steps; global_threshold should not change.
    for _step in range(10):
        agent.act(env.observation_space.sample(), False)
    self.assertEqual(agent.global_threshold, sentinel_threshold)
def test_agent_seed(self):
    """Re-seeding the agent with the same seed replays its RNG stream."""
    env = test_util.DummyEnv()
    scoring_params = classifier_agents.ScoringAgentParams(
        burnin=10,
        freeze_classifier_after_burnin=False,
        default_action_fn=env.action_space.sample,
        feature_keys=["x"],
    )
    agent = classifier_agents.ThresholdAgent(
        action_space=env.action_space,
        observation_space=env.observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=scoring_params,
    )

    # Seed, draw once; repeat. Both draws must match.
    draws = []
    for _ in range(2):
        agent.seed(100)
        draws.append(agent.rng.randint(0, 1000))
    first_draw, second_draw = draws
    self.assertEqual(first_draw, second_draw)
def test_freeze_after_burnin(self):
    """With freeze_classifier_after_burnin, the agent freezes after burn-in."""
    env = test_util.DummyEnv()
    burnin_steps = 10
    scoring_params = classifier_agents.ScoringAgentParams(
        burnin=burnin_steps,
        freeze_classifier_after_burnin=True,
        default_action_fn=env.action_space.sample,
        feature_keys=["x"],
    )
    agent = classifier_agents.ThresholdAgent(
        action_space=env.action_space,
        observation_space=env.observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=scoring_params,
    )

    # The agent must be unfrozen before each of the first burnin + 1 actions.
    for _step in range(burnin_steps + 1):
        self.assertFalse(agent.frozen)
        agent.act(env.observation_space.sample(), False)

    # Once burn-in is over the agent is frozen and has a learned threshold.
    self.assertTrue(agent.frozen)
    self.assertTrue(agent.global_threshold)
def test_agent_on_one_hot_vectors(self):
    """The agent learns a decision boundary at index 3 from one-hot features."""
    # Space of 1-hot vectors of length 10.
    observation_space = gym.spaces.Dict({"x": multinomial.Multinomial(10, 1)})
    params = classifier_agents.ScoringAgentParams(
        default_action_fn=lambda: 0,
        feature_keys=["x"],
        convert_one_hot_to_integer=True,
        burnin=999,
        threshold_policy=threshold_policies.ThresholdPolicy.SINGLE_THRESHOLD,
    )
    agent = classifier_agents.ThresholdAgent(
        observation_space=observation_space,
        reward_fn=rewards.NullReward(),
        params=params,
    )
    observation_space.seed(100)

    # Train a boundary at 3 using 1-hot vectors.
    observation = observation_space.sample()
    agent._act_impl(observation, reward=None, done=False)
    for _ in range(1000):
        last_observation = observation
        observation = observation_space.sample()
        # The reward is derived from the *previous* observation's hot index.
        agent._act_impl(
            observation,
            reward=int(np.argmax(last_observation["x"]) >= 3),
            done=False,
        )
        examples = agent._training_corpus.examples
        if examples:
            newest = examples[-1]
            # Use the TestCase assertion rather than a bare `assert`: a bare
            # assert is stripped under `python -O` and reports no values.
            self.assertEqual(int(newest.features[0] >= 3), newest.label)

    agent.frozen = True
    self.assertTrue(agent.act({"x": _one_hot(3)}, done=False))
    self.assertFalse(agent.act({"x": _one_hot(2)}, done=False))
def test_agent_trains(self):
    """On nearly separable 1-D data the agent learns a monotone threshold."""
    env = test_util.DummyEnv()
    scoring_params = classifier_agents.ScoringAgentParams(
        burnin=200, default_action_fn=env.action_space.sample, feature_keys=["x"]
    )
    agent = classifier_agents.ThresholdAgent(
        action_space=env.action_space,
        observation_space=env.observation_space,
        reward_fn=rewards.BinarizedScalarDeltaReward("x"),
        params=scoring_params,
    )

    # Train with points that are nearly separable but have some overlap between
    # 0.3 and 0.4. Each segment is (start, stop, reward) over 100 points.
    segments = ((0, 0.4, 0), (0.3, 0.8, 1))
    for start, stop, segment_reward in segments:
        for value in np.linspace(start, stop, 100):
            agent._act_impl(
                {"x": np.array([value])}, reward=segment_reward, done=False
            )
    # Add a negative point at the top of the range so that the training labels
    # are not fit perfectly by a threshold.
    agent._act_impl({"x": np.array([0.9])}, reward=0, done=False)

    agent.frozen = True
    probe_points = np.linspace(0, 0.95, 100)
    actions = [agent.act({"x": np.array([value])}, done=False) for value in probe_points]

    # Both actions appear...
    self.assertSameElements(actions, {0, 1})
    # ...and in sorted order, i.e. all 0s precede all 1s.
    self.assertSequenceEqual(actions, sorted(actions))
    self.assertGreater(agent.global_threshold, 0)
    self.assertFalse(agent.group_specific_thresholds)