    def test_skip_retraining_fn(self):
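        """Agent never trains when skip_retraining_fn always returns True."""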
        env = test_util.DummyEnv()
        burnin = 10

        def _skip_retraining(action, observation):
            """Always skip retraining."""
            del action, observation
            return True

        params = classifier_agents.ScoringAgentParams(
            burnin=burnin,
            freeze_classifier_after_burnin=False,
            default_action_fn=env.action_space.sample,
            feature_keys=["x"],
            skip_retraining_fn=_skip_retraining,
        )

        agent = classifier_agents.ThresholdAgent(
            action_space=env.action_space,
            observation_space=env.observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
        )

        for _ in range(burnin + 1):
            self.assertFalse(agent.frozen)
            _ = agent.act(env.observation_space.sample(), False)

        self.assertFalse(agent.frozen)  # Agent is not frozen.
        self.assertFalse(agent.global_threshold)  # Agent has not learned.

    def test_threshold_history_is_recorded(self):
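        """Global and group-specific threshold histories are recorded each step."""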
        observation_space = gym.spaces.Dict(
            {
                "x": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
                "group": gym.spaces.MultiDiscrete([1]),
            }
        )
        observation_space.seed(100)

        params = classifier_agents.ScoringAgentParams(
            default_action_fn=lambda: 0,
            feature_keys=["x"],
            group_key="group",
            burnin=0,
            threshold_policy=threshold_policies.ThresholdPolicy.EQUALIZE_OPPORTUNITY,
        )

        agent = classifier_agents.ThresholdAgent(
            observation_space=observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
        )

        for _ in range(10):
            agent.act(observation_space.sample(), False)

        self.assertLen(agent.global_threshold_history, 10)
        self.assertTrue(agent.group_specific_threshold_history)
        for history in agent.group_specific_threshold_history.values():
            # Takes 2 extra steps (one to observe features and one to observe label)
            # before any learned group-specific threshold is available.
            self.assertLen(history, 8)

    def test_agent_can_learn_different_thresholds(self):
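        """Agent learns distinct thresholds for groups with shifted label boundaries."""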

        observation_space = gym.spaces.Dict(
            {
                "x": gym.spaces.Box(low=0, high=1, shape=(1,), dtype=np.float32),
                "group": gym.spaces.Discrete(2),
            }
        )

        params = classifier_agents.ScoringAgentParams(
            default_action_fn=lambda: 0,
            feature_keys=["x"],
            group_key="group",
            threshold_policy=threshold_policies.ThresholdPolicy.EQUALIZE_OPPORTUNITY,
        )

        rng = np.random.RandomState(100)

        agent = classifier_agents.ThresholdAgent(
            observation_space=observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
            rng=rng,
        )

        # Train over the whole range of observations. Expect slightly different
        # thresholds to be learned.
        for observation in rng.rand(100):
            for group in [0, 1]:
                agent._act_impl(
                    {"x": np.array([observation]), "group": np.array([group])},
                    reward=observation > 0.5 + 0.1 * group,
                    done=False,
                )

        agent.frozen = True

        actions = {}
        for group in [0, 1]:
            actions[group] = []
            for observation in np.linspace(0, 1, 1000):
                actions[group].append(
                    agent.act(
                        {"x": np.array([observation]), "group": np.array([group])}, done=False
                    )
                )

        # The two groups are classified with different thresholds, so their
        # actions are not identical.
        self.assertNotEqual(actions[0], actions[1])
        self.assertLen(agent.group_specific_thresholds, 2)

    def test_one_hot_conversion(self):
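        """One-hot feature vectors are converted to their integer index."""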
        observation_space = gym.spaces.Dict({"x": multinomial.Multinomial(10, 1)})

        params = classifier_agents.ScoringAgentParams(
            default_action_fn=lambda: 0,
            feature_keys=["x"],
            convert_one_hot_to_integer=True,
            threshold_policy=threshold_policies.ThresholdPolicy.SINGLE_THRESHOLD,
        )

        agent = classifier_agents.ThresholdAgent(
            observation_space=observation_space, reward_fn=rewards.NullReward(), params=params
        )

        self.assertEqual(agent._get_features({"x": _one_hot(5)}), [5])

    def test_interact_with_env_replicable(self):
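        """Agent-environment interaction is replicable under run_test_simulation."""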
        env = test_util.DummyEnv()
        params = classifier_agents.ScoringAgentParams(
            burnin=10,
            freeze_classifier_after_burnin=False,
            default_action_fn=env.action_space.sample,
            feature_keys=["x"],
        )

        agent = classifier_agents.ThresholdAgent(
            action_space=env.action_space,
            observation_space=env.observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
        )
        test_util.run_test_simulation(env=env, agent=agent)

    def test_frozen_classifier_never_trains(self):
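        """A frozen agent's global_threshold never changes while acting."""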
        env = test_util.DummyEnv()
        params = classifier_agents.ScoringAgentParams(
            burnin=0, default_action_fn=env.action_space.sample, feature_keys=["x"]
        )

        agent = classifier_agents.ThresholdAgent(
            action_space=env.action_space,
            observation_space=env.observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
            frozen=True,
        )
        # Initialize global_threshold with a distinctive value.
        agent.global_threshold = 0.123

        # Run for some number of steps, global_threshold should not change.
        for _ in range(10):
            agent.act(env.observation_space.sample(), False)
        self.assertEqual(agent.global_threshold, 0.123)

    def test_agent_seed(self):
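        """Reseeding with the same value reproduces the same random draws."""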
        env = test_util.DummyEnv()

        params = classifier_agents.ScoringAgentParams(
            burnin=10,
            freeze_classifier_after_burnin=False,
            default_action_fn=env.action_space.sample,
            feature_keys=["x"],
        )

        agent = classifier_agents.ThresholdAgent(
            action_space=env.action_space,
            observation_space=env.observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
        )

        agent.seed(100)
        a = agent.rng.randint(0, 1000)
        agent.seed(100)
        b = agent.rng.randint(0, 1000)
        self.assertEqual(a, b)

    def test_freeze_after_burnin(self):
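        """Agent freezes itself once the burnin period has elapsed."""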
        env = test_util.DummyEnv()
        burnin = 10
        params = classifier_agents.ScoringAgentParams(
            burnin=burnin,
            freeze_classifier_after_burnin=True,
            default_action_fn=env.action_space.sample,
            feature_keys=["x"],
        )

        agent = classifier_agents.ThresholdAgent(
            action_space=env.action_space,
            observation_space=env.observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
        )

        for _ in range(burnin + 1):
            self.assertFalse(agent.frozen)
            _ = agent.act(env.observation_space.sample(), False)

        self.assertTrue(agent.frozen)
        self.assertTrue(agent.global_threshold)  # Agent has learned something.

    def test_agent_on_one_hot_vectors(self):
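        """Agent learns a threshold at 3 from one-hot observations."""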

        # Space of 1-hot vectors of length 10.
        observation_space = gym.spaces.Dict({"x": multinomial.Multinomial(10, 1)})

        params = classifier_agents.ScoringAgentParams(
            default_action_fn=lambda: 0,
            feature_keys=["x"],
            convert_one_hot_to_integer=True,
            burnin=999,
            threshold_policy=threshold_policies.ThresholdPolicy.SINGLE_THRESHOLD,
        )

        agent = classifier_agents.ThresholdAgent(
            observation_space=observation_space, reward_fn=rewards.NullReward(), params=params
        )

        observation_space.seed(100)
        # Train a decision boundary at 3 using 1-hot vectors.
        observation = observation_space.sample()
        agent._act_impl(observation, reward=None, done=False)
        for _ in range(1000):
            last_observation = observation
            observation = observation_space.sample()
            agent._act_impl(
                observation, reward=int(np.argmax(last_observation["x"]) >= 3), done=False
            )
            if agent._training_corpus.examples:
                assert (
                    int(agent._training_corpus.examples[-1].features[0] >= 3)
                    == agent._training_corpus.examples[-1].label
                )

        agent.frozen = True

        self.assertTrue(agent.act({"x": _one_hot(3)}, done=False))
        self.assertFalse(agent.act({"x": _one_hot(2)}, done=False))

    def test_agent_trains(self):
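        """Agent learns a monotone threshold on nearly separable data."""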
        env = test_util.DummyEnv()
        params = classifier_agents.ScoringAgentParams(
            burnin=200, default_action_fn=env.action_space.sample, feature_keys=["x"]
        )

        agent = classifier_agents.ThresholdAgent(
            action_space=env.action_space,
            observation_space=env.observation_space,
            reward_fn=rewards.BinarizedScalarDeltaReward("x"),
            params=params,
        )

        # Train with points that are nearly separable but have some overlap between
        # 0.3 and 0.4.
        for observation in np.linspace(0, 0.4, 100):
            agent._act_impl({"x": np.array([observation])}, reward=0, done=False)

        for observation in np.linspace(0.3, 0.8, 100):
            agent._act_impl({"x": np.array([observation])}, reward=1, done=False)

        # Add a negative point at the top of the range so that the training labels
        # are not fit perfectly by a threshold.
        agent._act_impl({"x": np.array([0.9])}, reward=0, done=False)

        agent.frozen = True
        actions = [
            agent.act({"x": np.array([obs])}, done=False) for obs in np.linspace(0, 0.95, 100)
        ]

        # Assert some actions are 0 and some are 1.
        self.assertSameElements(actions, {0, 1})
        # Assert actions are sorted - i.e., 0s followed by 1s.
        self.assertSequenceEqual(actions, sorted(actions))

        self.assertGreater(agent.global_threshold, 0)
        self.assertFalse(agent.group_specific_thresholds)