Exemplo n.º 1
0
    def test_bundle_and_unbundle(self):
        """Check that a RandomAgent bundle survives an unbundle round trip.

        The agent takes one step on a freshly built observation, its state is
        bundled, that bundle is fed back through `unbundle`, and a second
        bundle is required to compare equal to the first.
        """
        # Action space: one slate position chosen among three candidates.
        slate_size = 1
        num_candidates = 3
        nvec = num_candidates * np.ones((slate_size, ))
        action_space = spaces.MultiDiscrete(nvec)

        user_model = ie.IEUserModel(
            slate_size,
            user_state_ctor=ie.IEUserState,
            response_model_ctor=ie.IEResponse,
        )
        agent = random_agent.RandomAgent(action_space, random_seed=0)

        # Sample one document observation per candidate index.
        sampler = ie.IETopicDocumentSampler()
        doc_observations = {
            index: sampler.sample_document().create_observation()
            for index in range(num_candidates)
        }

        # Step the agent once before bundling; the returned slate is not
        # inspected by this test.
        step_input = {
            'user': user_model.create_observation(),
            'doc': doc_observations,
        }
        agent.step(1, step_input)

        # Bundle, restore from that bundle, then bundle again and compare.
        bundle_args = ('', 0)
        first_bundle = agent.bundle_and_checkpoint(*bundle_args)
        restored = agent.unbundle('', 0, first_bundle)
        self.assertTrue(restored)
        second_bundle = agent.bundle_and_checkpoint(*bundle_args)
        self.assertEqual(first_bundle, second_bundle)
Exemplo n.º 2
0
  def test_step(self):
    """Check that GreedyPCTRAgent's slate matches the top-scoring documents.

    Resets the environment, asks the agent for one slate, and compares it
    (ignoring order) with the indices of the `slate_size` highest scores
    returned by `user_model.avg_user_state.score_document` over the candidate
    documents.
    """
    # Create a simple user
    slate_size = 2
    num_candidates = 5
    action_space = spaces.MultiDiscrete(num_candidates * np.ones((slate_size,)))
    user_model = ie.IEUserModel(
        slate_size,
        user_state_ctor=ie.IEUserState,
        response_model_ctor=ie.IEResponse)

    # Create a set of documents
    document_sampler = ie.IETopicDocumentSampler(seed=1)
    ieenv = environment.Environment(
        user_model,
        document_sampler,
        num_candidates,
        slate_size,
        resample_documents=True)

    # Create agent
    agent = greedy_pctr_agent.GreedyPCTRAgent(action_space,
                                              user_model.avg_user_state)

    # This agent doesn't use the previous user response
    observation, documents = ieenv.reset()
    slate = agent.step(1, dict(user=observation, doc=documents))
    scores = [
        user_model.avg_user_state.score_document(doc_obs)
        for doc_obs in documents.values()
    ]
    # Take the last `slate_size` entries of the ascending argsort (the
    # highest-scoring positions) instead of a hard-coded 2, so the expectation
    # keeps tracking the slate size configured at the top of the test.
    expected_slate = sorted(np.argsort(scores)[-slate_size:])
    self.assertAllEqual(sorted(slate), expected_slate)
Exemplo n.º 3
0
def iex_user_model_creator(env_ctx):
    """Build an `iex.IEUserModel` from an environment context mapping.

    Args:
        env_ctx: Mapping that must provide the keys "slate_size" and "seed".

    Returns:
        The `iex.IEUserModel` constructed with `iex.IEUserState` as the user
        state constructor and `iex.IEResponse` as the response model
        constructor.
    """
    slate_size = env_ctx["slate_size"]
    model_options = {
        "user_state_ctor": iex.IEUserState,
        "response_model_ctor": iex.IEResponse,
        "seed": env_ctx["seed"],
    }
    return iex.IEUserModel(slate_size, **model_options)
Exemplo n.º 4
0
 def setUp(self):
     """Create the environment fixture stored on `self._environment`.

     Records the slate size (2) and candidate count (20) on the test case,
     then builds an `environment.Environment` from a fresh interest-evolution
     user model and topic document sampler.
     """
     super(EnvironmentTest, self).setUp()
     self._slate_size, self._num_candidates = 2, 20

     model = ie.IEUserModel(
         self._slate_size,
         user_state_ctor=ie.IEUserState,
         response_model_ctor=ie.IEResponse,
     )
     sampler = ie.IETopicDocumentSampler()

     env_args = (model, sampler, self._num_candidates, self._slate_size)
     self._environment = environment.Environment(*env_args)
Exemplo n.º 5
0
 def setUp(self):
     """Create the multi-user environment fixture on `self._environment`.

     Records the slate size (2), candidate count (20) and user count (100)
     on the test case, builds one interest-evolution user model per user,
     and passes them with a single topic document sampler to
     `environment.MultiUserEnvironment`.
     """
     super(MultiUserEnvironmentTest, self).setUp()
     self._slate_size = 2
     self._num_candidates = 20
     self._num_users = 100

     def new_user_model():
         # Each user gets its own, separately constructed model instance.
         return ie.IEUserModel(
             self._slate_size,
             user_state_ctor=ie.IEUserState,
             response_model_ctor=ie.IEResponse,
         )

     models = [new_user_model() for _ in range(self._num_users)]
     sampler = ie.IETopicDocumentSampler()
     self._environment = environment.MultiUserEnvironment(
         models,
         sampler,
         self._num_candidates,
         self._slate_size,
     )