def test_create_from_recsim_interest_evolution(self):
    """Smoke-test storing one RecSim interest-evolution step in a ReplayBuffer.

    Resets the environment, samples one action, steps once, and adds the
    resulting transition (user observation, stacked document features and the
    per-slot response fields) to a small buffer.
    """
    env = RecSim(num_candidates=20, slate_size=3, resample_documents=False)
    replay_buffer = ReplayBuffer(replay_capacity=100, batch_size=10)

    initial_obs = env.reset()
    user_obs = initial_obs["user"]
    action = env.action_space.sample()
    # Document features stacked along axis 0, one entry per value of obs["doc"].
    doc_features = np.stack(list(initial_obs["doc"].values()), axis=0)

    next_obs, reward, terminal, _info = env.step(action)
    responses = next_obs["response"]

    replay_buffer.add(
        observation=user_obs,
        action=action,
        reward=reward,
        terminal=terminal,
        mdp_id=0,
        sequence_number=0,
        doc=doc_features,
        response_click=np.array([r["click"] for r in responses]),
        response_cluster_id=np.array([r["cluster_id"] for r in responses]),
        response_quality=np.stack([r["quality"] for r in responses], axis=0),
        response_liked=np.array([r["liked"] for r in responses]),
        response_watch_time=np.stack(
            [r["watch_time"] for r in responses], axis=0
        ),
        log_prob=-1.0,
    )
def setup_buffer(buffer_size, trajectory_lengths, stack_size=None, multi_steps=None):
    """Fill a fresh ReplayBuffer with back-to-back trajectories and sample them.

    Args:
        buffer_size: replay capacity of the buffer.
        trajectory_lengths: iterable of trajectory lengths. Trajectories are
            added one after another; the last transition of each is terminal.
            Transition ``i`` (counted across all trajectories) is built with
            ``get_add_transition(i)``.
        stack_size: stack size passed to the buffer; 1 when None.
        multi_steps: update horizon passed to the buffer; 1 when None. When
            given, the buffer is configured with ``return_as_timeline_format``.

    Returns:
        The result of ``memory.sample_all_valid_transitions()``.
    """
    effective_stack = 1 if stack_size is None else stack_size
    horizon = 1 if multi_steps is None else multi_steps
    memory = ReplayBuffer(
        stack_size=effective_stack,
        replay_capacity=buffer_size,
        batch_size=1,
        update_horizon=horizon,
        # NOTE(review): this used to be `stack_size is not None`, evaluated
        # after stack_size had already been defaulted to 1, so it was always
        # True. Kept as True to preserve behavior; confirm whether it was meant
        # to depend on whether the caller passed stack_size.
        return_everything_as_stack=True,
        return_as_timeline_format=multi_steps is not None,
    )

    transition_index = 0
    for traj_len in trajectory_lengths:
        last_step = traj_len - 1
        for step in range(traj_len):
            trans = get_add_transition(transition_index)
            memory.add(
                observation=trans["state"],
                action=trans["action"],
                reward=trans["reward"],
                terminal=step == last_step,
                extra1=trans["extra1"],
            )
            transition_index += 1
    return memory.sample_all_valid_transitions()
def __call__(
    self,
    replay_buffer: ReplayBuffer,
    obs: Any,
    action: Any,
    reward: float,
    terminal: bool,
    log_prob: float,
):
    """Add one transition to ``replay_buffer`` with ``obs`` stored unchanged.

    The observation, action, reward and terminal flag are passed positionally;
    ``log_prob`` is passed by keyword.
    """
    core = (obs, action, reward, terminal)
    replay_buffer.add(*core, log_prob=log_prob)
def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
    """Flatten the observation inside ``transition`` and add it to ``replay_buffer``.

    The ``observation`` field (a dict) is removed from the transition and split:

    * ``obs["user"]`` is passed as ``observation``;
    * ``obs["doc"]`` becomes one ``doc_<key>`` array per key in ``box_keys`` /
      ``discrete_keys``, or a single stacked ``doc`` array when both are empty;
    * ``obs["augmentation"]`` becomes ``augmentation_<key>`` arrays, and is
      read only when augmentation keys are configured;
    * ``obs["response"]`` becomes ``response_<key>`` arrays. When it is None
      (the first state has no response), zero placeholders with
      ``num_responses`` rows are used instead: float32 for box keys, int64
      for discrete keys.

    Every other transition field is forwarded to ``replay_buffer.add`` unchanged.
    """
    fields = transition.asdict()
    obs = fields.pop("observation")
    user = obs["user"]

    def stacked(items, key):
        # np.stack: join each item's value for `key` along a new leading axis.
        return np.stack([item[key] for item in items])

    def gathered(items, key):
        # np.array: collect each item's value for `key` into one array.
        return np.array([item[key] for item in items])

    flat = {}

    # Documents
    if not (self.box_keys or self.discrete_keys):
        flat["doc"] = np.stack(list(obs["doc"].values()))
    else:
        docs = obs["doc"].values()
        flat.update({f"doc_{key}": stacked(docs, key) for key in self.box_keys})
        flat.update(
            {f"doc_{key}": gathered(docs, key) for key in self.discrete_keys}
        )

    # Augmentation
    if self.augmentation_box_keys or self.augmentation_discrete_keys:
        augs = obs["augmentation"].values()
        flat.update(
            {
                f"augmentation_{key}": stacked(augs, key)
                for key in self.augmentation_box_keys
            }
        )
        flat.update(
            {
                f"augmentation_{key}": gathered(augs, key)
                for key in self.augmentation_discrete_keys
            }
        )

    # Responses: the first state has no response, so placeholders stand in.
    response = obs["response"]
    if response is None:
        for key, shape in self.response_box_keys:
            flat[f"response_{key}"] = np.zeros(
                (self.num_responses, *shape), dtype=np.float32
            )
        for key, _n in self.response_discrete_keys:
            flat[f"response_{key}"] = np.zeros(
                (self.num_responses,), dtype=np.int64
            )
    else:
        for key, _shape in self.response_box_keys:
            flat[f"response_{key}"] = stacked(response, key)
        for key, _n in self.response_discrete_keys:
            flat[f"response_{key}"] = gathered(response, key)

    fields.update(flat)
    replay_buffer.add(observation=user, **fields)
def __call__(
    self,
    replay_buffer: ReplayBuffer,
    obs: Any,
    action: Any,
    reward: float,
    terminal: bool,
    log_prob: float,
):
    """Flatten a RecSim-style observation dict and add one transition to the buffer.

    Args:
        replay_buffer: buffer receiving the transition.
        obs: observation dict with ``"user"``, ``"doc"`` and ``"response"``
            entries. ``obs["response"]`` may be None, because the first state
            has no response.
        action: forwarded to ``replay_buffer.add``.
        reward: forwarded to ``replay_buffer.add``.
        terminal: forwarded to ``replay_buffer.add``.
        log_prob: forwarded to ``replay_buffer.add``.

    ``obs["user"]`` is stored as ``observation``. Document features go to
    ``doc_<key>`` (one array per configured key), or to a single stacked
    ``doc`` array when no doc keys are configured. Response fields go to
    ``response_<key>``. When there is no response, zero placeholders with
    ``num_responses`` rows are used: float32 for box keys (shape
    ``(num_responses, *d)``) and int64 for discrete keys.
    """
    user = obs["user"]

    kwargs = {}
    if self.box_keys or self.discrete_keys:
        doc_obs = obs["doc"]
        for k in self.box_keys:
            kwargs[f"doc_{k}"] = np.stack([v[k] for v in doc_obs.values()])
        for k in self.discrete_keys:
            kwargs[f"doc_{k}"] = np.array([v[k] for v in doc_obs.values()])
    else:
        kwargs["doc"] = np.stack(list(obs["doc"].values()))

    # Responses
    response = obs["response"]
    # We need to handle None below because the first state won't have response.
    for k, d in self.response_box_keys:
        if response is not None:
            kwargs[f"response_{k}"] = np.stack([v[k] for v in response])
        else:
            # np.zeros defaults to float64. Use float32 explicitly, matching the
            # Transition-based adapter in this file.
            kwargs[f"response_{k}"] = np.zeros(
                (self.num_responses, *d), dtype=np.float32
            )
    for k, _n in self.response_discrete_keys:
        if response is not None:
            kwargs[f"response_{k}"] = np.array([v[k] for v in response])
        else:
            # Without an explicit dtype this placeholder would be float64 for
            # a discrete key. Use int64, as the Transition-based adapter does.
            kwargs[f"response_{k}"] = np.zeros(
                (self.num_responses,), dtype=np.int64
            )

    replay_buffer.add(
        observation=user,
        action=action,
        reward=reward,
        terminal=terminal,
        log_prob=log_prob,
        **kwargs,
    )
def __call__(
    self,
    replay_buffer: ReplayBuffer,
    obs: Any,
    action: Any,
    reward: float,
    terminal: bool,
    log_prob: float,
):
    """Flatten a RecSim-style observation dict and add one transition to the buffer.

    ``obs["user"]`` is stored as ``observation``. When doc keys are configured,
    each gets its own ``doc_<key>`` array; otherwise ``obs["doc"]`` is passed
    through as ``doc``. Each configured response key gets a
    ``response_<key>`` array built from ``obs["response"]``. ``action``,
    ``reward``, ``terminal`` and ``log_prob`` are forwarded unchanged.
    """
    user = obs["user"]

    kwargs = {}
    if self.box_keys or self.discrete_keys:
        doc_obs = obs["doc"]
        # f-strings are required here. The plain literal "doc_{k}" gave every
        # key the same name, so each key overwrote the previous one.
        for k in self.box_keys:
            kwargs[f"doc_{k}"] = np.vstack([v[k] for v in doc_obs.values()])
        for k in self.discrete_keys:
            kwargs[f"doc_{k}"] = np.array([v[k] for v in doc_obs.values()])
    else:
        # NOTE(review): the raw mapping is passed through here, whereas the
        # other adapters in this file np.stack its values. Confirm the buffer
        # accepts a mapping.
        kwargs["doc"] = obs["doc"]

    # Responses
    # NOTE(review): iterating obs["response"] raises TypeError if it is None.
    # The other adapters in this file treat None as "first state, no
    # response". Keys are also iterated here as plain names, whereas those
    # adapters unpack (key, shape) pairs. Confirm both against the caller.
    for k in self.response_box_keys:
        kwargs[f"response_{k}"] = np.vstack([v[k] for v in obs["response"]])
    for k in self.response_discrete_keys:
        kwargs[f"response_{k}"] = np.array([v[k] for v in obs["response"]])

    replay_buffer.add(
        observation=user,
        action=action,
        reward=reward,
        terminal=terminal,
        log_prob=log_prob,
        **kwargs,
    )
def test_sparse_input(self):
    """Check that sparse id_list / id_score_list features round-trip through the buffer.

    Adds ``replay_capacity // 2`` transitions whose sparse features depend on
    the transition index, then samples every transition except the last by
    index. The sampled sparse features and their ``next_`` counterparts are
    compared against offsets/ids(/scores) rebuilt here from the same
    generator.
    """
    replay_capacity = 100
    num_transitions = replay_capacity // 2
    memory = ReplayBuffer(
        stack_size=1, replay_capacity=replay_capacity, update_horizon=1
    )

    def make_transition(i):
        cut4 = i % 4
        cut7 = i % 7
        return {
            "observation": np.ones(OBS_SHAPE, dtype=OBS_TYPE),
            "action": int(2 * i),
            "reward": float(3 * i),
            "terminal": cut4,
            "id_list": {
                "sparse_feat1": list(range(0, cut4)),
                "sparse_feat2": list(range(cut4, 4)),
            },
            "id_score_list": {
                "sparse_feat3": (
                    list(range(0, cut7)),
                    [k + 0.5 for k in range(0, cut7)],
                ),
                "sparse_feat4": (
                    list(range(cut7, 7)),
                    [k + 0.5 for k in range(cut7, 7)],
                ),
            },
        }

    for i in range(num_transitions):
        memory.add(**make_transition(i))

    sampled = list(range(num_transitions - 1))
    batch = memory.sample_transition_batch(len(sampled), torch.tensor(sampled))

    # Expected columns per feature: (offsets, ids) for id_list features and
    # (offsets, ids, scores) for id_score_list features. Each offset is the
    # number of ids already accumulated before that transition's ids.
    expected = {
        "id_list": {"sparse_feat1": ([], []), "sparse_feat2": ([], [])},
        "id_score_list": {
            "sparse_feat3": ([], [], []),
            "sparse_feat4": ([], [], []),
        },
        "next_id_list": {"sparse_feat1": ([], []), "sparse_feat2": ([], [])},
        "next_id_score_list": {
            "sparse_feat3": ([], [], []),
            "sparse_feat4": ([], [], []),
        },
    }
    for i in sampled:
        current, following = make_transition(i), make_transition(i + 1)
        for key, features in expected.items():
            is_next = key.startswith("next_")
            source = following if is_next else current
            base_key = key[len("next_"):] if is_next else key
            for feat_id, cols in features.items():
                cols[0].append(len(cols[1]))
                raw = source[base_key][feat_id]
                if base_key == "id_list":
                    cols[1].extend(raw)
                else:
                    cols[1].extend(raw[0])
                    cols[2].extend(raw[1])

    for key, features in expected.items():
        actual = getattr(batch, key)
        for feat_id, cols in features.items():
            for pos, want in enumerate(cols):
                npt.assert_array_equal(want, actual[feat_id][pos])

    # sample random
    _ = memory.sample_transition_batch(10)
def test_replay_overflow(self):
    """Check index validity and sampled contents as a 6-slot buffer wraps around.

    A real stress test is hard here, because tracking which indices get
    replaced would effectively mean building a second replay buffer. This is
    a simple hand-traced scenario instead. With stack_size = 2 there is 1
    slot of padding, shown as the X before s0 in the diagrams below.
    multi_steps = 2 is used as the update horizon, and samples are requested
    in timeline format.
    """
    multi_steps = 2
    stack_size = 2
    memory = ReplayBuffer(
        stack_size=stack_size,
        replay_capacity=6,
        batch_size=1,
        update_horizon=multi_steps,
        return_everything_as_stack=None,
        return_as_timeline_format=True,
    )

    # Transition i has action 2*i and reward 3*i, so actions identify states
    # in the assertions below (s0 -> 0, s1 -> 2, ..., s6 -> 12).
    def trans(i):
        return {
            "observation": np.ones(OBS_SHAPE, dtype=OBS_TYPE),
            "action": int(2 * i),
            "reward": float(3 * i),
        }

    # Contents of RB (X = empty/padding slot)
    # start: [X, X, X, X, X, X]
    npt.assert_array_equal(
        memory._is_index_valid, [False, False, False, False, False, False]
    )

    # t0: [X, s0, X, X, X, X]
    memory.add(**trans(0), terminal=False)
    npt.assert_array_equal(
        memory._is_index_valid, [False, False, False, False, False, False]
    )

    # t1: [X, s0, s1, X, X, X]
    memory.add(**trans(1), terminal=False)
    npt.assert_array_equal(
        memory._is_index_valid, [False, False, False, False, False, False]
    )

    # t2: [X, s0, s1, s2, X, X]
    # s0 finally becomes valid as its next state was added
    # (with update_horizon=2, s0 only became valid once s2 was added).
    memory.add(**trans(2), terminal=False)
    npt.assert_array_equal(
        memory._is_index_valid, [False, True, False, False, False, False]
    )
    batch = memory.sample_all_valid_transitions()
    # Each action row stacks stack_size=2 consecutive actions (previous,
    # current); the padding before s0 shows up as action 0.
    npt.assert_array_equal(batch.action, [[0, 0]])
    # next_action[i] lists the stacked actions of the following states.
    npt.assert_array_equal(batch.next_action[0], [[0, 2], [2, 4]])

    # t3: [X, s0, s1, s2, s3, X]
    # episode termination validates whole episode
    memory.add(**trans(3), terminal=True)
    npt.assert_array_equal(
        memory._is_index_valid, [False, True, True, True, True, False]
    )
    batch = memory.sample_all_valid_transitions()
    npt.assert_array_equal(batch.action, [[0, 0], [0, 2], [2, 4], [4, 6]])
    npt.assert_array_equal(batch.next_action[0], [[0, 2], [2, 4]])
    npt.assert_array_equal(batch.next_action[1], [[2, 4], [4, 6]])
    # batch.next_action[2][1] is garbage
    npt.assert_array_equal(batch.next_action[2][0], [4, 6])
    # batch.next_action[3] is [garbage]

    # t4: [s4, s0, s1, s2, s3, X]
    # s0 invalidated as its previous frame is corrupted
    # (s4 overwrote index 0, the padding slot in front of s0). Presumably
    # index 5 now holds padding for the new episode, which is why s4 wrapped
    # to index 0. TODO confirm against ReplayBuffer.add.
    memory.add(**trans(4), terminal=False)
    npt.assert_array_equal(
        memory._is_index_valid, [False, False, True, True, True, False]
    )
    batch = memory.sample_all_valid_transitions()
    npt.assert_array_equal(batch.action, [[0, 2], [2, 4], [4, 6]])
    npt.assert_array_equal(batch.next_action[0], [[2, 4], [4, 6]])
    npt.assert_array_equal(batch.next_action[1][0], [4, 6])

    # t5: [s4, s5, s1, s2, s3, X]
    # s5 overwrote s0, so s1 (whose previous frame was s0) is now invalid too.
    memory.add(**trans(5), terminal=False)
    npt.assert_array_equal(
        memory._is_index_valid, [False, False, False, True, True, False]
    )
    batch = memory.sample_all_valid_transitions()
    npt.assert_array_equal(batch.action, [[2, 4], [4, 6]])
    npt.assert_array_equal(batch.next_action[0][0], [4, 6])

    # t6: [s4, s5, s6, s2, s3, X]
    # The terminal s6 validates the new episode (s4, s5, s6). s6 overwrote
    # s1, so s2 (whose previous frame was s1) becomes invalid.
    memory.add(**trans(6), terminal=True)
    npt.assert_array_equal(
        memory._is_index_valid, [True, True, True, False, True, False]
    )
    batch = memory.sample_all_valid_transitions()
    npt.assert_array_equal(batch.action, [[0, 8], [8, 10], [10, 12], [4, 6]])
    npt.assert_array_equal(batch.next_action[0], [[8, 10], [10, 12]])
    npt.assert_array_equal(batch.next_action[1][0], [10, 12])
    # batch.next_action[2] is [garbage]
    # batch.next_action[3] is [garbage]

    logger.info("Overflow test passes!")
def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
    """Add ``transition`` to ``replay_buffer``, passing each of its fields by keyword."""
    fields = transition.asdict()
    replay_buffer.add(**fields)