def test_create_from_recsim_interest_evolution(self):
    """Smoke-test adding one interest-evolution RecSim transition to a ReplayBuffer.

    Takes a single random step in the environment, unpacks the per-slate
    response dicts into arrays, and passes everything to
    ``ReplayBuffer.add``. There are no explicit assertions: the test passes
    as long as environment construction, stepping and ``add`` do not raise.
    """
    env = RecSim(num_candidates=20, slate_size=3, resample_documents=False)
    replay_buffer = ReplayBuffer(replay_capacity=100, batch_size=10)
    obs = env.reset()
    observation = obs["user"]
    action = env.action_space.sample()
    log_prob = -1.0
    # One row per candidate document, in the dict's iteration order.
    doc_features = np.stack(list(obs["doc"].values()), axis=0)

    # The fourth value returned by env.step is not used by this test.
    next_obs, reward, terminal, _ = env.step(action)
    response = next_obs["response"]
    # `response` is iterated as a sequence of dicts (presumably one per slate
    # position); each field is gathered into its own array.
    click = np.array([r["click"] for r in response])
    response_quality = np.stack([r["quality"] for r in response], axis=0)
    response_cluster_id = np.array([r["cluster_id"] for r in response])
    response_watch_time = np.stack([r["watch_time"] for r in response], axis=0)
    response_liked = np.array([r["liked"] for r in response])
    replay_buffer.add(
        observation=observation,
        action=action,
        reward=reward,
        terminal=terminal,
        mdp_id=0,
        sequence_number=0,
        doc=doc_features,
        response_click=click,
        response_cluster_id=response_cluster_id,
        response_quality=response_quality,
        response_liked=response_liked,
        response_watch_time=response_watch_time,
        log_prob=log_prob,
    )
def test_recsim_interest_exploration(self):
    """Check the obs preprocessor output for interest-exploration RecSim.

    The user vector must come through unchanged as a (1, user_dim) float32
    CPU tensor. Each candidate document must be encoded as a one-hot
    cluster id followed by its scalar quality.
    """
    num_candidate = 10
    env = RecSim(
        num_candidates=num_candidate,
        slate_size=3,
        resample_documents=False,
        is_interest_exploration=True,
    )
    obs_preprocessor = env.get_obs_preprocessor()
    obs = env.reset()
    state = obs_preprocessor(obs)
    self.assertFalse(state.has_float_features_only)
    self.assertEqual(state.float_features.shape, (1, obs["user"].shape[0]))
    self.assertEqual(state.float_features.dtype, torch.float32)
    self.assertEqual(state.float_features.device, torch.device("cpu"))
    npt.assert_array_almost_equal(obs["user"], state.float_features.squeeze(0))

    doc_float_features = state.candidate_docs.float_features
    self.assertIsNotNone(doc_float_features)
    # Read the cluster count from the observation space instead of
    # hard-coding it, so the expected width and the one-hot encoding below
    # always agree.
    num_clusters = env.observation_space["doc"]["0"]["cluster_id"].n
    quality_len = 1
    expected_doc_feature_length = num_clusters + quality_len
    self.assertEqual(
        doc_float_features.shape, (1, num_candidate, expected_doc_feature_length)
    )
    self.assertEqual(doc_float_features.dtype, torch.float32)
    self.assertEqual(doc_float_features.device, torch.device("cpu"))
    for i, v in enumerate(obs["doc"].values()):
        expected_doc_feature = torch.cat(
            [
                F.one_hot(torch.tensor(v["cluster_id"]), num_clusters).float(),
                # Quality is a scalar; unsqueeze makes it a length-1 tensor
                # so it can be concatenated with the one-hot vector.
                torch.tensor(v["quality"]).unsqueeze(0).float(),
            ],
            dim=0,
        )
        npt.assert_array_almost_equal(
            expected_doc_feature, doc_float_features[0, i]
        )
def test_recsim_interest_evolution(self):
    """Verify the interest-evolution obs preprocessor output.

    The preprocessed state should hold the raw user vector as a batch of
    one, and the candidate document vectors stacked unchanged along the
    second axis. All are float32 tensors on the CPU.
    """
    # NOTE(review): another test in this file is also named
    # test_recsim_interest_evolution; if both are defined in the same class,
    # the later definition replaces this one. Confirm they live in
    # different classes.
    n_docs = 10
    env = RecSim(num_candidates=n_docs, slate_size=3, resample_documents=False)
    preprocess = env.get_obs_preprocessor()
    raw_obs = env.reset()
    state = preprocess(raw_obs)
    cpu = torch.device("cpu")

    self.assertFalse(state.has_float_features_only)

    user_vec = raw_obs["user"]
    user_tensor = state.float_features
    self.assertEqual(user_tensor.shape, (1, user_vec.shape[0]))
    self.assertEqual(user_tensor.dtype, torch.float32)
    self.assertEqual(user_tensor.device, cpu)
    npt.assert_array_almost_equal(user_vec, user_tensor.squeeze(0))

    doc_tensor = state.candidate_docs.float_features
    self.assertIsNotNone(doc_tensor)
    doc_dim = raw_obs["doc"]["0"].shape[0]
    self.assertEqual(doc_tensor.shape, (1, n_docs, doc_dim))
    self.assertEqual(doc_tensor.dtype, torch.float32)
    self.assertEqual(doc_tensor.device, cpu)
    for slot, doc_vec in enumerate(raw_obs["doc"].values()):
        npt.assert_array_almost_equal(doc_vec, doc_tensor[0, slot])
def test_recsim_interest_evolution(self):
    """Check a sampled replay-buffer transition against the inserted data.

    The transition at buffer index 0 must take its state from the first
    inserted observation and its next state from the second. Docs, actions
    and next-step responses must round-trip exactly, while the current-step
    response fields are all zeros.
    """
    # NOTE(review): another test in this file is also named
    # test_recsim_interest_evolution; if both are defined in the same class,
    # the later definition replaces the earlier one. Confirm they live in
    # different classes.
    n_docs = 10
    n_slots = 3
    env = RecSim(
        num_candidates=n_docs,
        slate_size=n_slots,
        resample_documents=False,
    )
    replay_buffer, inserted = _create_replay_buffer_and_insert(env)
    batch = replay_buffer.sample_transition_batch(indices=torch.tensor([0]))
    first, second = inserted[0], inserted[1]

    npt.assert_array_almost_equal(
        first["observation"]["user"], batch.state.squeeze(0)
    )
    npt.assert_array_almost_equal(
        second["observation"]["user"], batch.next_state.squeeze(0)
    )

    cur_docs = list(first["observation"]["doc"].values())
    nxt_docs = list(second["observation"]["doc"].values())
    batch_docs = batch.doc.squeeze(0)
    batch_next_docs = batch.next_doc.squeeze(0)
    for idx in range(n_docs):
        npt.assert_array_equal(cur_docs[idx], batch_docs[idx])
        npt.assert_array_equal(nxt_docs[idx], batch_next_docs[idx])

    npt.assert_array_equal(first["action"], batch.action.squeeze(0))
    npt.assert_array_equal(second["action"], batch.next_action.squeeze(0))

    # The current-step response fields of this transition are expected to be
    # all zeros (integer-valued fields first, then float-valued ones).
    for int_field in ("response_click", "response_cluster_id", "response_liked"):
        npt.assert_array_equal([0, 0, 0], getattr(batch, int_field).squeeze(0))
    for float_field in ("response_quality", "response_watch_time"):
        npt.assert_array_equal(
            [0.0, 0.0, 0.0], getattr(batch, float_field).squeeze(0)
        )

    # The next-step responses must match what the second insert observed.
    next_resp = second["observation"]["response"]
    nr_click = batch.next_response_click.squeeze(0)
    nr_cluster = batch.next_response_cluster_id.squeeze(0)
    nr_liked = batch.next_response_liked.squeeze(0)
    nr_quality = batch.next_response_quality.squeeze(0)
    nr_watch = batch.next_response_watch_time.squeeze(0)
    for slot in range(n_slots):
        cell = next_resp[slot]
        npt.assert_array_equal(cell["click"], nr_click[slot])
        npt.assert_array_equal(cell["cluster_id"], nr_cluster[slot])
        npt.assert_array_equal(cell["liked"], nr_liked[slot])
        npt.assert_array_almost_equal(cell["quality"], nr_quality[slot])
        npt.assert_array_almost_equal(cell["watch_time"], nr_watch[slot])