コード例 #1
0
def create_gym_environment(env_config):
  """Builds a MovieLens RecSim gym environment from an `EnvConfig`.

  Initial embeddings and genre history are loaded from `env_config` and used
  to build a `utils.Dataset`. The dataset then feeds a singleton document
  sampler, a partitioned user-pool sampler, the `UserModel` and the underlying
  `MovieLensEnvironment`. Per-step rewards are aggregated with
  `average_ratings_reward`.

  Args:
    env_config: an `EnvConfig` object.

  Returns:
    A RecSimGymEnv object.
  """
  make_user = functools.partial(User, **attr.asdict(env_config.user_config))

  movie_dataset = utils.Dataset(
      env_config.data_dir,
      user_ctor=make_user,
      movie_ctor=Movie,
      embeddings=utils.load_embeddings(env_config),
      genre_history=utils.load_genre_history(env_config),
      genre_shift=env_config.genre_shift,
      bias_against_unseen=env_config.bias_against_unseen,
  )

  movie_sampler = recsim_samplers.SingletonSampler(
      movie_dataset.get_movies(), Movie)

  pool_sampler = recsim_samplers.UserPoolSampler(
      seed=env_config.seeds.user_sampler,
      users=movie_dataset.get_users(),
      user_ctor=make_user,
      partitions=env_config.train_eval_test,
      partition_seed=env_config.seeds.train_eval_test,
  )

  slate_size = env_config.slate_size
  user_model = UserModel(
      user_sampler=pool_sampler,
      seed=env_config.seeds.user_model,
      slate_size=slate_size,
  )

  # Every document held by the sampler is a candidate, and documents are not
  # resampled (resample_documents=False).
  movielens_env = MovieLensEnvironment(
      user_model,
      movie_sampler,
      num_candidates=movie_sampler.size(),
      slate_size=slate_size,
      resample_documents=False,
  )

  return recsim_gym.RecSimGymEnv(movielens_env, average_ratings_reward)
コード例 #2
0
def create_gym_environment(env_config):
    """Builds a single-slate MovieLens RecSim gym environment.

    Embeddings are loaded from `env_config` and used to build a
    `movie_lens_utils.Dataset`. The environment always uses `slate_size=1`.
    Rewards are aggregated with `multiobjective_reward`, with its
    `lambda_non_violent` argument bound to `env_config.lambda_non_violent`.

    Args:
      env_config: an `EnvConfig` object.

    Returns:
      A RecSimGymEnv object.
    """
    make_user = functools.partial(User, **attr.asdict(env_config.user_config))

    movie_dataset = movie_lens_utils.Dataset(
        env_config.data_dir,
        user_ctor=make_user,
        movie_ctor=Movie,
        embeddings=movie_lens_utils.load_embeddings(env_config),
    )

    movie_sampler = recsim_samplers.SingletonSampler(
        movie_dataset.get_movies(), Movie)

    seeds = env_config.seeds
    pool_sampler = recsim_samplers.UserPoolSampler(
        seed=seeds.user_sampler,
        users=movie_dataset.get_users(),
        user_ctor=make_user,
        partitions=env_config.train_eval_test,
        partition_seed=seeds.train_eval_test,
    )

    user_model = UserModel(user_sampler=pool_sampler, seed=seeds.user_model)

    # One movie per slate; all sampler documents are candidates and are never
    # resampled.
    movielens_env = MovieLensEnvironment(
        user_model,
        movie_sampler,
        num_candidates=movie_sampler.size(),
        slate_size=1,
        resample_documents=False,
    )

    reward_fn = functools.partial(
        multiobjective_reward,
        lambda_non_violent=env_config.lambda_non_violent,
    )
    return recsim_gym.RecSimGymEnv(movielens_env, reward_fn)
コード例 #3
0
  def _initialize_from_config(self, env_config):
    """Builds the MovieLens environment described by `env_config`.

    Sets `self.working_dir`, `self.initial_embeddings`, `self.genre_history`,
    `self.dataset`, `self.document_sampler`, `self.user_sampler`,
    `self.user_model` and `self.env`. The underlying `MovieLensEnvironment` is
    reset once, then wrapped in a `recsim_gym.RecSimGymEnv` that aggregates
    rewards with `movie_lens.average_ratings_reward`.

    Args:
      env_config: an `EnvConfig`-like object providing `data_dir`, `seeds`,
        `user_config`, `genre_shift`, `bias_against_unseen` and `slate_size`.
    """
    # Let tempfile pick the platform's temporary directory (it honours TMPDIR
    # and tempfile.tempdir) instead of hard-coding '/tmp', which does not
    # exist on every platform.
    # NOTE(review): this directory is not removed here; confirm cleanup
    # happens in tearDown or elsewhere.
    self.working_dir = tempfile.mkdtemp()

    self.initial_embeddings = movie_lens_utils.load_embeddings(env_config)
    self.genre_history = movie_lens_utils.load_genre_history(env_config)

    user_ctor = functools.partial(movie_lens.User,
                                  **attr.asdict(env_config.user_config))
    self.dataset = movie_lens_utils.Dataset(
        env_config.data_dir,
        user_ctor=user_ctor,
        movie_ctor=movie_lens.Movie,
        genre_history=self.genre_history,
        embeddings=self.initial_embeddings,
        genre_shift=env_config.genre_shift,
        bias_against_unseen=env_config.bias_against_unseen)

    self.document_sampler = recsim_samplers.SingletonSampler(
        self.dataset.get_movies(), movie_lens.Movie)

    self.user_sampler = recsim_samplers.UserPoolSampler(
        seed=env_config.seeds.user_sampler,
        users=self.dataset.get_users(),
        user_ctor=user_ctor)

    self.user_model = movie_lens.UserModel(
        user_sampler=self.user_sampler,
        seed=env_config.seeds.user_model,
        slate_size=env_config.slate_size,
    )

    # All sampler documents are candidates; documents are not resampled.
    env = movie_lens.MovieLensEnvironment(
        self.user_model,
        self.document_sampler,
        num_candidates=self.document_sampler.size(),
        slate_size=env_config.slate_size,
        resample_documents=False)
    env.reset()

    reward_aggregator = movie_lens.average_ratings_reward

    self.env = recsim_gym.RecSimGymEnv(env, reward_aggregator)
コード例 #4
0
    def _initialize_from_config(self, env_config):
        """Builds a single-slate MovieLens environment from `env_config`.

        Sets `self.working_dir` (a fresh directory under `FLAGS.test_tmpdir`),
        `self.initial_embeddings`, `self.dataset`, `self.document_sampler`,
        `self.user_sampler`, `self.user_model` and `self.env`. The raw
        `MovieLensEnvironment` is reset once, then wrapped in a
        `recsim_gym.RecSimGymEnv`. That wrapper's reward aggregator is
        `movie_lens.multiobjective_reward`, with `lambda_non_violent` bound to
        `env_config.lambda_non_violent`.

        Args:
            env_config: an `EnvConfig`-like object providing `data_dir`,
                `seeds`, `user_config` and `lambda_non_violent`.
        """
        self.working_dir = tempfile.mkdtemp(dir=FLAGS.test_tmpdir)
        self.initial_embeddings = movie_lens_utils.load_embeddings(env_config)

        make_user = functools.partial(
            movie_lens.User, **attr.asdict(env_config.user_config))
        self.dataset = movie_lens_utils.Dataset(
            env_config.data_dir,
            user_ctor=make_user,
            movie_ctor=movie_lens.Movie,
            embeddings=self.initial_embeddings,
        )

        all_movies = self.dataset.get_movies()
        self.document_sampler = recsim_samplers.SingletonSampler(
            all_movies, movie_lens.Movie)

        seeds = env_config.seeds
        self.user_sampler = recsim_samplers.UserPoolSampler(
            seed=seeds.user_sampler,
            users=self.dataset.get_users(),
            user_ctor=make_user,
        )
        self.user_model = movie_lens.UserModel(
            user_sampler=self.user_sampler, seed=seeds.user_model)

        # One movie per slate; all sampler documents are candidates and are
        # never resampled.
        raw_env = movie_lens.MovieLensEnvironment(
            self.user_model,
            self.document_sampler,
            num_candidates=self.document_sampler.size(),
            slate_size=1,
            resample_documents=False,
        )
        raw_env.reset()

        reward_fn = functools.partial(
            movie_lens.multiobjective_reward,
            lambda_non_violent=env_config.lambda_non_violent,
        )
        self.env = recsim_gym.RecSimGymEnv(raw_env, reward_fn)