Esempio n. 1
0
    def test_bundle_and_unbundle(self):
        """Checkpoint bundles survive an unbundle / re-bundle round trip.

        The agent takes one step before checkpointing. Its bundle is then
        restored both into the same agent and into a freshly constructed
        agent. The fresh-agent check is the meaningful one: restoring into
        the agent that produced the bundle could pass even if `unbundle`
        restored nothing.
        """
        # Initialize agent
        slate_size = 2
        num_candidates = 5
        action_space = spaces.MultiDiscrete(num_candidates * np.ones(
            (slate_size, )))

        agent = cluster_bandit_agent.ClusterBanditAgent(
            self.dummy_observation_space(), action_space)

        # Create a set of documents. The sampler is seeded so the test is
        # deterministic from run to run.
        document_sampler = ie.IETopicDocumentSampler(seed=1)
        documents = {}
        for i in range(num_candidates):
            video = document_sampler.sample_document()
            documents[i] = video.create_observation()

        # Take one step so the agent has processed an observation before it
        # is checkpointed (first argument is presumably the reward — confirm
        # against the agent's step() signature).
        sufficient_stats_observation = self.doc_user_to_sufficient_stats(
            documents, np.array([0, 0, 0, 0]))

        agent.step(1, sufficient_stats_observation)

        bundle_dict = agent.bundle_and_checkpoint('', 0)

        # Restoring into the same agent must succeed and leave its bundle
        # unchanged.
        self.assertTrue(agent.unbundle('', 0, bundle_dict))
        self.assertEqual(bundle_dict, agent.bundle_and_checkpoint('', 0))

        # Restoring into a brand-new agent must reproduce the same bundle,
        # which proves unbundle actually transfers the checkpointed state.
        restored_agent = cluster_bandit_agent.ClusterBanditAgent(
            self.dummy_observation_space(), action_space)
        self.assertTrue(restored_agent.unbundle('', 0, bundle_dict))
        self.assertEqual(bundle_dict,
                         restored_agent.bundle_and_checkpoint('', 0))
Esempio n. 2
0
 def test_bundle_and_unbundle_trivial(self):
     """Unbundling an empty dict fails and leaves the initial bundles intact."""
     two_slot_space = spaces.MultiDiscrete(2 * np.ones((2, )))
     agent = cluster_bandit_agent.ClusterBanditAgent(
         self.dummy_observation_space(), two_slot_space)
     # An empty bundle holds no per-base-agent entries, so unbundle must
     # report failure.
     self.assertFalse(agent.unbundle('', 0, {}))
     # The agent's checkpoint still shows every base agent at episode 0.
     # NOTE(review): two base agents are expected here, presumably one per
     # cluster defined by dummy_observation_space — confirm.
     expected_bundle = {
         'base_agent_bundle_%d' % idx: {'episode_num': 0} for idx in range(2)
     }
     self.assertEqual(expected_bundle, agent.bundle_and_checkpoint('', 0))
Esempio n. 3
0
    def test_step_with_bigger_slate(self):
        """A slate as large as the candidate set contains every document."""
        # Initialize agent: every slot of the slate can pick any candidate.
        num_candidates = 5
        slate_size = 5
        action_space = spaces.MultiDiscrete(
            num_candidates * np.ones((slate_size, )))
        agent = cluster_bandit_agent.ClusterBanditAgent(
            self.dummy_observation_space(), action_space)

        # Seeded sampler so the expected ordering asserted below is stable.
        sampler = ie.IETopicDocumentSampler(seed=1)
        documents = {
            doc_id: sampler.sample_document().create_observation()
            for doc_id in range(num_candidates)
        }

        # Past observation shows Topic 1 is better.
        observation = self.doc_user_to_sufficient_stats(
            documents, np.array([1, 1, 0, 1]))
        slate = agent.step(0, observation)
        # Topic 1 documents come first, sorted by quality (0, 4, 3), followed
        # by Topic 0 documents sorted by quality (1, 2).
        self.assertAllEqual(slate, [0, 4, 3, 1, 2])