Exemplo n.º 1
0
  def test_gym_environment_builder(self):
    """Checks that stepped observations lie in the declared observation space."""
    gym_env = movie_lens.create_gym_environment(self.env_config)
    gym_env.seed(100)
    gym_env.reset()

    # Feed a few hand-picked single-item slates to the environment and verify
    # that every part of the returned observation is a member of the matching
    # sub-space.
    checked_fields = ('doc', 'response', 'user')
    manual_slates = ([0], [0], [2])
    for manual_slate in manual_slates:
      observation, _, _, _ = gym_env.step(manual_slate)
      for field_name in checked_fields:
        field_space = gym_env.observation_space.spaces[field_name]
        self.assertIn(observation[field_name], field_space)
Exemplo n.º 2
0
  def test_user_order_is_consistent(self):
    """Checks user order repeats after a sampler reset and varies with seed."""

    def collect_user_ids(environment):
      # Each reset() yields an initial observation carrying the user's id.
      return [environment.reset()['user']['user_id'] for _ in range(100)]

    # Two passes, each preceded by a sampler reset, must visit the same users
    # in the same order.
    self.env.reset_sampler()
    baseline_ids = collect_user_ids(self.env)

    self.env.reset_sampler()
    replayed_ids = collect_user_ids(self.env)

    self.assertEqual(baseline_ids, replayed_ids)

    # An environment built with a different user-sampler seed is expected to
    # produce a different ordering.
    reseeded_config = copy.deepcopy(self.env_config)
    reseeded_config.seeds.user_sampler += 1
    reseeded_env = movie_lens.create_gym_environment(reseeded_config)
    reseeded_ids = collect_user_ids(reseeded_env)
    self.assertNotEqual(baseline_ids, reseeded_ids)
Exemplo n.º 3
0
def _select_slate_movies(slate_type, env_config):
    """Return the candidate movies for `slate_type`.

    Raises:
      ValueError: If `slate_type` is not one of the supported values.
    """
    if slate_type == 'all':
        return np.arange(utils.NUM_MOVIES)
    if slate_type == 'top1genre':
        # Most popular movie of each genre, with de-duplication
        return utils.find_top_movies_per_genre(
            movies_per_genre=1, data_path=env_config.data_dir)
    if slate_type == 'top2genre':
        # Top 2 most popular movies of each genre, with de-duplication
        return utils.find_top_movies_per_genre(
            movies_per_genre=2, data_path=env_config.data_dir)
    if slate_type == 'top20':
        # Top 20 most popular movies
        return utils.find_top_movies_overall(
            num_movies=20, data_path=env_config.data_dir)
    if slate_type == 'test':
        return np.arange(5)
    raise ValueError('Slate type not recognized')


def _build_genre_shift(shift):
    """Return a per-genre list holding +shift, -shift or 0. for each genre."""
    # Comedy and drama are the most commonly watched (and therefore least
    # "diverse") genres
    pos_shift = ['Comedy', 'Drama', 'Children\'s']
    # These happen to be more "diverse" genres, and in our top1 from each
    # genre tends to hit Star Wars movies in particular.
    neg_shift = ['Sci-Fi', 'Fantasy', 'War']

    genre_shift = [0.] * len(utils.GENRES)
    for genre in pos_shift:
        genre_shift[utils.GENRE_MAP[genre]] = shift
    for genre in neg_shift:
        genre_shift[utils.GENRE_MAP[genre]] = -shift
    return genre_shift


def generate_data(env_config,
                  slate_type,
                  n_samples,
                  seed,
                  intercept,
                  rating_coef,
                  div_seek_coef,
                  diversity_coef,
                  recommender_agent=None,
                  shift=0.,
                  user_pool='train'):
    """Generate synthetic data.

    Simulates `n_samples` single-step user interactions and attaches a sampled
    binary long-term reward to each one.

    Note: this function has side effects. It overwrites
    `env_config.slate_size` and `env_config.genre_shift`, and it reseeds the
    global NumPy random state with `seed`.

    Args:
      env_config: Environment Configuration. Modified in place (see above).
      slate_type: Type of slate, should be one of ['all', 'top1genre',
        'top2genre', 'top20', 'test'].
      n_samples: Number of samples to generate. Must be at least 1.
      seed: Random seed.
      intercept: Intercept of the ground truth logReg model P(Y).
      rating_coef: Coefficient for rating in the ground truth model.
      div_seek_coef: Coefficient for diversitySeeker in the ground truth model.
      diversity_coef: Coefficient for diversity in the ground truth model.
      recommender_agent: If None, will make recommendations uniformly at
        random.
      shift: Increase the rating (prior to clipping) by this amount for each
        genre in [Comedy, Drama, Children's] and decrease by this amount for
        each genre in [Sci-Fi, Fantasy, War].
      user_pool: Train/Eval/Test pool of users.

    Returns:
      Pandas DataFrame containing the samples, sorted by 'user', with columns
      'user', 'diversity_seeker', 'rec', 'watched', 'rating', 'diversity',
      'p_ltr' (probability of the long-term reward) and 'ltr' (its sampled
      0/1 value).

    Raises:
      ValueError: If a path in `env_config` has an unsupported extension, the
        accept probability is outside [0, 1], `n_samples` is smaller than 1,
        or `slate_type` is not recognized.
    """
    if not env_config.embeddings_path.endswith(('.json', '.pkl')):
        raise ValueError('Embedding path should end in .json or .pkl')
    if not env_config.genre_history_path.endswith(('.json', '.pkl')):
        raise ValueError('Genre history path should end in .json or .pkl')
    if (env_config.user_config.accept_prob < 0
            or env_config.user_config.accept_prob > 1):
        raise ValueError('Accept probability should be in [0, 1]')
    # With no samples the resulting frame would have no columns and the sort
    # below would fail with an opaque KeyError; reject it up front, before the
    # config is modified or the environment is built.
    if n_samples < 1:
        raise ValueError('Number of samples should be a positive integer')

    top_movie_slate = _select_slate_movies(slate_type, env_config)

    env_config.slate_size = len(top_movie_slate)
    env_config.genre_shift = _build_genre_shift(shift)

    # TODO(moberst): All randomization should be through rngs, but currently the
    # multinomial choice model (imported from RecSim) does not accept an rng,
    # so this is a work-around to ensure consistent user choices.
    np.random.seed(seed)
    rng = np.random.default_rng(seed)
    if recommender_agent is None:
        recommender_agent = agent.RandomAgent(movies=top_movie_slate)

    ml_env = movie_lens_simulator.create_gym_environment(env_config)
    ml_env.environment.set_active_pool(user_pool)
    res = []

    for i in range(n_samples):
        if i % 100 == 0:
            logging.info('Iteration: %d / %d', i, n_samples)

        # Generate data, one row per user interaction
        initial_obs = ml_env.reset()
        user_id = initial_obs['user']['user_id']
        # Read the user's diversity-seeking state after the reset and before
        # the environment is stepped.
        diversity_seeker = (
            ml_env.environment.user_model._user_state.diversity_seeking)  # pylint: disable=protected-access

        slate = recommender_agent.make_recommendation(user_id)
        obs, _, _, _ = ml_env.step(slate)
        # Exactly one response is expected per step; unpacking enforces it.
        [response] = obs['response']

        res.append({
            'user': user_id,
            'diversity_seeker': diversity_seeker,
            'rec': slate[0],
            'watched': response['doc_id'],
            'rating': response['rating'],
            'diversity': response['diversity'],
        })

    df = pd.DataFrame.from_dict(res).sort_values('user')

    logits = (intercept
              + df['rating'] * rating_coef
              + df['diversity_seeker'] * div_seek_coef
              + df['diversity'] * diversity_coef)

    # Calculate probability of the long-term reward (p_ltr) and sample it.
    df['p_ltr'] = scipy.special.expit(logits)
    df['ltr'] = rng.binomial(1, df['p_ltr'])

    return df