Exemple #1
0
 def test_no_user_id(self):
     """Smoke-test training an agent built without a user-id input.

     The agent is created with `user_id_input=False` and
     `user_embedding_size=0`. Each round simulates several two-step episodes
     one at a time (batch size 1), then performs a single model update over
     all buffered episodes with the batch size temporarily raised to match.
     """
     num_rounds = 3
     episodes_per_batch = 7
     steps_per_episode = 2

     env = movie_lens_dynamic.create_gym_environment(self.env_config)
     agent = batched_movielens_rnn_agent.MovieLensRNNAgent(
         env.observation_space,
         env.action_space,
         stateful=True,
         batch_size=1,
         user_id_input=False,
         user_embedding_size=0,
         max_episode_length=None)

     for _ in range(num_rounds):
         for _ in range(episodes_per_batch):
             observation = env.reset()
             reward = 0
             for _ in range(steps_per_episode):
                 slate = agent.step(reward, observation)
                 observation, reward, _, _ = env.step(slate)
             agent.end_episode(reward, observation, eval_mode=True)

         # The update consumes every buffered episode as one batch...
         agent.set_batch_size(episodes_per_batch)
         agent.model_update(
             learning_rate=0.1,
             lambda_learning_rate=0.1,
             var_learning_rate=0.1)
         agent.empty_buffer()
         # ...while simulation goes back to one episode at a time.
         agent.set_batch_size(1)
    def test_gym_environment_builder(self):
        """Check that a seeded MovieLens env emits observations inside its space.

        After seeding and resetting the environment, a fixed sequence of
        hand-written slates is recommended. For every step, the "doc",
        "response" and "user" parts of the observation must each belong to the
        corresponding sub-space of `env.observation_space`.
        """
        env = movie_lens.create_gym_environment(self.env_config)
        env.seed(100)
        env.reset()

        manual_slates = ([0], [0], [2])
        checked_fields = ("doc", "response", "user")
        for slate in manual_slates:
            observation, _, _, _ = env.step(slate)
            for field in checked_fields:
                value = observation[field]
                expected_space = env.observation_space.spaces[field]
                self.assertIn(value, expected_space)
Exemple #3
0
def _envs_builder(config, num_envs):
    """Build a list of environments from `config['env_config']`.

    The first environment comes from
    `movie_lens_dynamic.create_gym_environment`. Every additional one is a
    deep copy of that first environment whose user sampler has its seed
    cleared and is then reset, so the copies are meant to use user sampler
    seeds different from the first one.

    Args:
      config: Mapping with an 'env_config' entry used to create the first
        environment.
      num_envs: Number of environments requested. One environment is always
        built, even when `num_envs` is less than 1.

    Returns:
      A list of environments. Index 0 is the original; the rest are
      unseeded copies of it.
    """
    base_env = movie_lens_dynamic.create_gym_environment(config['env_config'])
    envs = [base_env]

    for _ in range(num_envs - 1):
        logging.info('Build env')
        clone = copy.deepcopy(base_env)
        # Clear the copied seed and rebuild the sampler so this copy does not
        # reuse the first environment's seeded user sampler.
        user_model = clone._environment._user_model  # pylint: disable=protected-access
        user_model._user_sampler._seed = None  # pylint: disable=protected-access
        user_model.reset_sampler()
        envs.append(clone)
    return envs
Exemple #4
0
 def test_interaction(self):
     """Smoke-test the single-environment agent loop.

     Runs 3 rounds. Each round plays 2 episodes of 2 steps each, then
     performs one model update and clears the agent's buffer.
     """
     env = movie_lens_dynamic.create_gym_environment(self.env_config)
     agent = batched_movielens_rnn_agent.MovieLensRNNAgent(
         env.observation_space, env.action_space, max_episode_length=None)
     update_kwargs = dict(
         learning_rate=0.1,
         lambda_learning_rate=0.1,
         var_learning_rate=0.1)

     for _ in range(3):
         for _ in range(2):
             observation, reward = env.reset(), 0
             for _ in range(2):
                 slate = agent.step(reward, observation)
                 observation, reward, _, _ = env.step(slate)
             agent.end_episode(reward, observation, eval_mode=True)
         agent.model_update(**update_kwargs)
         agent.empty_buffer()
Exemple #5
0
 def test_batch_interaction(self):
     """Smoke-test one agent driving five environments in lockstep.

     The agent is given a collection of rewards and observations (one entry
     per environment) and its returned slates are zipped back onto the
     environments. Each of the 3 rounds plays one two-step episode in every
     environment, then updates the model and clears the buffer.
     """
     num_envs = 5
     envs = []
     for _ in range(num_envs):
         envs.append(movie_lens_dynamic.create_gym_environment(self.env_config))
     agent = batched_movielens_rnn_agent.MovieLensRNNAgent(
         envs[0].observation_space,
         envs[0].action_space,
         max_episode_length=None)

     for _ in range(3):
         observations = [env.reset() for env in envs]
         rewards = [0] * len(envs)
         for _ in range(2):
             slates = agent.step(rewards, observations)
             step_results = [env.step(slate) for env, slate in zip(envs, slates)]
             # Transpose the per-environment 4-tuples returned by env.step
             # into per-field tuples; only observations and rewards are kept.
             observations, rewards, _, _ = zip(*step_results)
         agent.end_episode(rewards, observations, eval_mode=True)
         agent.model_update(learning_rate=0.1,
                            lambda_learning_rate=0.1,
                            var_learning_rate=0.1)
         agent.empty_buffer()
    def test_user_order_is_consistent(self):
        """Check that the user ordering is reproducible and seed-dependent.

        Resetting the sampler of `self.env` must replay the same sequence of
        user ids. An environment built from a copy of the config whose
        user-sampler seed is incremented must produce a different sequence.
        """
        num_users = 100

        def collect_user_ids(environment):
            # Reset `environment` repeatedly, recording the user id that each
            # reset's observation carries.
            return [environment.reset()["user"]["user_id"]
                    for _ in range(num_users)]

        self.env.reset_sampler()
        first_list = collect_user_ids(self.env)

        self.env.reset_sampler()
        other_list = collect_user_ids(self.env)

        self.assertEqual(first_list, other_list)

        # Also check that changing the seed creates a new ordering.
        config = copy.deepcopy(self.env_config)
        config.seeds.user_sampler += 1
        env = movie_lens.create_gym_environment(config)
        self.assertNotEqual(first_list, collect_user_ids(env))