Exemplo n.º 1
0
    def test_create_from_recsim_interest_evolution(self):
        """Smoke-test adding a single RecSim transition to a ReplayBuffer.

        Resets the environment, takes one randomly sampled action, splits the
        per-slate-item response dicts into one array per field, and stores
        everything with ``replay_buffer.add``. There are no assertions: the
        test passes as long as nothing raises.
        """
        env = RecSim(num_candidates=20, slate_size=3, resample_documents=False)
        replay_buffer = ReplayBuffer(replay_capacity=100, batch_size=10)

        first_obs = env.reset()
        user_features = first_obs["user"]
        action = env.action_space.sample()
        # Candidate documents are captured before stepping the environment.
        doc_features = np.stack(list(first_obs["doc"].values()), axis=0)

        next_obs, reward, terminal, _info = env.step(action)
        response = next_obs["response"]

        def column(key):
            # Collect one response field across every item of the slate.
            return [item[key] for item in response]

        replay_buffer.add(
            observation=user_features,
            action=action,
            reward=reward,
            terminal=terminal,
            mdp_id=0,
            sequence_number=0,
            doc=doc_features,
            response_click=np.array(column("click")),
            response_cluster_id=np.array(column("cluster_id")),
            response_quality=np.stack(column("quality"), axis=0),
            response_liked=np.array(column("liked")),
            response_watch_time=np.stack(column("watch_time"), axis=0),
            log_prob=-1.0,
        )
    def test_recsim_interest_exploration(self):
        """Check the preprocessed state for an interest-exploration RecSim env.

        Asserts that the user features come back as a ``(1, user_dim)``
        float32 CPU tensor equal to ``obs["user"]``, and that every candidate
        document is encoded as the one-hot of its ``cluster_id`` followed by
        its scalar ``quality``.
        """
        num_candidate = 10
        env = RecSim(
            num_candidates=num_candidate,
            slate_size=3,
            resample_documents=False,
            is_interest_exploration=True,
        )
        obs_preprocessor = env.get_obs_preprocessor()
        obs = env.reset()
        state = obs_preprocessor(obs)
        self.assertFalse(state.has_float_features_only)
        self.assertEqual(state.float_features.shape, (1, obs["user"].shape[0]))
        self.assertEqual(state.float_features.dtype, torch.float32)
        self.assertEqual(state.float_features.device, torch.device("cpu"))
        npt.assert_array_almost_equal(obs["user"], state.float_features.squeeze(0))
        doc_float_features = state.candidate_docs.float_features
        self.assertIsNotNone(doc_float_features)

        # Read the number of clusters from the observation space once and use
        # it for both the expected feature length and the one-hot width, so
        # the two cannot drift apart.
        num_clusters = int(env.observation_space["doc"]["0"]["cluster_id"].n)
        quality_len = 1
        expected_doc_feature_length = num_clusters + quality_len

        self.assertEqual(
            doc_float_features.shape, (1, num_candidate, expected_doc_feature_length)
        )
        self.assertEqual(doc_float_features.dtype, torch.float32)
        self.assertEqual(doc_float_features.device, torch.device("cpu"))
        for i, v in enumerate(obs["doc"].values()):
            expected_doc_feature = torch.cat(
                [
                    F.one_hot(torch.tensor(v["cluster_id"]), num_clusters).float(),
                    # This needs unsqueeze because it's a scalar
                    torch.tensor(v["quality"]).unsqueeze(0).float(),
                ],
                dim=0,
            )
            npt.assert_array_almost_equal(
                expected_doc_feature, doc_float_features[0, i]
            )
 def test_recsim_interest_evolution(self):
     """Verify the preprocessed state built from a default RecSim observation.

     The user features must come back as a ``(1, user_dim)`` float32 CPU
     tensor equal to ``obs["user"]``. The candidate documents must come back
     as a ``(1, num_candidates, doc_dim)`` float32 CPU tensor whose rows
     equal the values of ``obs["doc"]`` in iteration order.
     """
     n_docs = 10
     cpu = torch.device("cpu")
     env = RecSim(num_candidates=n_docs, slate_size=3, resample_documents=False)
     preprocess = env.get_obs_preprocessor()
     raw_obs = env.reset()
     state = preprocess(raw_obs)

     # User-side features.
     self.assertFalse(state.has_float_features_only)
     user_vec = raw_obs["user"]
     user_tensor = state.float_features
     self.assertEqual(user_tensor.shape, (1, user_vec.shape[0]))
     self.assertEqual(user_tensor.dtype, torch.float32)
     self.assertEqual(user_tensor.device, cpu)
     npt.assert_array_almost_equal(user_vec, user_tensor.squeeze(0))

     # Candidate-document features.
     doc_tensor = state.candidate_docs.float_features
     self.assertIsNotNone(doc_tensor)
     doc_dim = raw_obs["doc"]["0"].shape[0]
     self.assertEqual(doc_tensor.shape, (1, n_docs, doc_dim))
     self.assertEqual(doc_tensor.dtype, torch.float32)
     self.assertEqual(doc_tensor.device, cpu)
     for idx, doc_vec in enumerate(raw_obs["doc"].values()):
         npt.assert_array_almost_equal(doc_vec, doc_tensor[0, idx])
Exemplo n.º 4
0
 def test_recsim_interest_evolution(self):
     """Check that a sampled replay-buffer batch mirrors what was inserted.

     Samples index 0 and asserts that the batch's state, docs and action
     match the first inserted transition, that the ``next_*`` fields match
     the second one, that the first step's response fields are all zero, and
     that the ``next_response_*`` fields match the second observation's
     per-item response.
     """
     num_candidate = 10
     slate_size = 3
     env = RecSim(
         num_candidates=num_candidate,
         slate_size=slate_size,
         resample_documents=False,
     )
     replay_buffer, inserted = _create_replay_buffer_and_insert(env)
     batch = replay_buffer.sample_transition_batch(indices=torch.tensor([0]))

     current, following = inserted[0], inserted[1]
     cur_obs = current["observation"]
     nxt_obs = following["observation"]

     # User state for the sampled step and its successor.
     npt.assert_array_almost_equal(cur_obs["user"], batch.state.squeeze(0))
     npt.assert_array_almost_equal(nxt_obs["user"], batch.next_state.squeeze(0))

     # Candidate documents, compared one candidate at a time.
     docs = list(cur_obs["doc"].values())
     next_docs = list(nxt_obs["doc"].values())
     batch_docs = batch.doc.squeeze(0)
     batch_next_docs = batch.next_doc.squeeze(0)
     for i in range(num_candidate):
         npt.assert_array_equal(docs[i], batch_docs[i])
         npt.assert_array_equal(next_docs[i], batch_next_docs[i])

     npt.assert_array_equal(current["action"], batch.action.squeeze(0))
     npt.assert_array_equal(following["action"], batch.next_action.squeeze(0))

     # The sampled (first) step is expected to carry an all-zero response.
     zero_responses = (
         ("click", [0, 0, 0]),
         ("cluster_id", [0, 0, 0]),
         ("liked", [0, 0, 0]),
         ("quality", [0.0, 0.0, 0.0]),
         ("watch_time", [0.0, 0.0, 0.0]),
     )
     for field, zeros in zero_responses:
         npt.assert_array_equal(
             zeros, getattr(batch, "response_" + field).squeeze(0)
         )

     # The next-step response must match the second observation's response;
     # float-valued fields are compared approximately.
     field_checks = (
         ("click", npt.assert_array_equal),
         ("cluster_id", npt.assert_array_equal),
         ("liked", npt.assert_array_equal),
         ("quality", npt.assert_array_almost_equal),
         ("watch_time", npt.assert_array_almost_equal),
     )
     resp = nxt_obs["response"]
     for i in range(slate_size):
         for field, check in field_checks:
             check(
                 resp[i][field],
                 getattr(batch, "next_response_" + field).squeeze(0)[i],
             )