Esempio n. 1
0
def create_gym_environment(env_config):
  """Builds the MovieLens environment and wraps it as a RecSimGymEnv.

  The dataset, document sampler, user sampler, user model and environment are
  constructed in that order, each from the values carried by `env_config`.

  Args:
    env_config: an `EnvConfig` object. This function reads its `user_config`,
      `data_dir`, `genre_shift`, `bias_against_unseen`, `seeds`,
      `train_eval_test` and `slate_size` attributes.

  Returns:
    A RecSimGymEnv object whose reward aggregator is
    `average_ratings_reward`.
  """
  # The same partially-applied constructor is handed to both the dataset and
  # the user sampler.
  make_user = functools.partial(User, **attr.asdict(env_config.user_config))

  embeddings = utils.load_embeddings(env_config)
  history = utils.load_genre_history(env_config)
  movie_lens_data = utils.Dataset(
      env_config.data_dir,
      user_ctor=make_user,
      movie_ctor=Movie,
      embeddings=embeddings,
      genre_history=history,
      genre_shift=env_config.genre_shift,
      bias_against_unseen=env_config.bias_against_unseen)

  movies = movie_lens_data.get_movies()
  doc_sampler = recsim_samplers.SingletonSampler(movies, Movie)

  pool_sampler = recsim_samplers.UserPoolSampler(
      seed=env_config.seeds.user_sampler,
      users=movie_lens_data.get_users(),
      user_ctor=make_user,
      partitions=env_config.train_eval_test,
      partition_seed=env_config.seeds.train_eval_test)

  model = UserModel(
      user_sampler=pool_sampler,
      seed=env_config.seeds.user_model,
      slate_size=env_config.slate_size)

  # num_candidates is set to the size reported by the document sampler, and
  # document resampling is switched off.
  raw_env = MovieLensEnvironment(
      model,
      doc_sampler,
      num_candidates=doc_sampler.size(),
      slate_size=env_config.slate_size,
      resample_documents=False)

  return recsim_gym.RecSimGymEnv(raw_env, average_ratings_reward)
Esempio n. 2
0
  def _initialize_from_config(self, env_config):
    """Populates this object's MovieLens state from `env_config`.

    Creates a scratch directory under /tmp, loads the embeddings and the genre
    history, and then builds the dataset, the document and user samplers, the
    user model and the environment. The environment is reset once before it is
    wrapped in a `RecSimGymEnv`, which is stored as `self.env`.

    Sets `working_dir`, `initial_embeddings`, `genre_history`, `dataset`,
    `document_sampler`, `user_sampler`, `user_model` and `env` on `self`.

    Args:
      env_config: configuration object. This method reads its `user_config`,
        `data_dir`, `genre_shift`, `bias_against_unseen`, `seeds` and
        `slate_size` attributes.
    """
    self.working_dir = tempfile.mkdtemp(dir='/tmp')

    embeddings = movie_lens_utils.load_embeddings(env_config)
    self.initial_embeddings = embeddings
    history = movie_lens_utils.load_genre_history(env_config)
    self.genre_history = history

    # The same partially-applied constructor is handed to both the dataset and
    # the user sampler.
    make_user = functools.partial(
        movie_lens.User, **attr.asdict(env_config.user_config))

    data = movie_lens_utils.Dataset(
        env_config.data_dir,
        user_ctor=make_user,
        movie_ctor=movie_lens.Movie,
        genre_history=history,
        embeddings=embeddings,
        genre_shift=env_config.genre_shift,
        bias_against_unseen=env_config.bias_against_unseen)
    self.dataset = data

    doc_sampler = recsim_samplers.SingletonSampler(
        data.get_movies(), movie_lens.Movie)
    self.document_sampler = doc_sampler

    pool_sampler = recsim_samplers.UserPoolSampler(
        seed=env_config.seeds.user_sampler,
        users=data.get_users(),
        user_ctor=make_user)
    self.user_sampler = pool_sampler

    model = movie_lens.UserModel(
        user_sampler=pool_sampler,
        seed=env_config.seeds.user_model,
        slate_size=env_config.slate_size)
    self.user_model = model

    # num_candidates is set to the size reported by the document sampler, and
    # document resampling is switched off.
    raw_env = movie_lens.MovieLensEnvironment(
        model,
        doc_sampler,
        num_candidates=doc_sampler.size(),
        slate_size=env_config.slate_size,
        resample_documents=False)
    raw_env.reset()

    self.env = recsim_gym.RecSimGymEnv(
        raw_env, movie_lens.average_ratings_reward)