Beispiel #1
0
    def test_create_from_recsim_interest_evolution(self):
        """Smoke-test adding one RecSim transition to a ReplayBuffer.

        Resets the environment, takes one sampled action, and stores the
        resulting transition together with the document features and the
        per-item response fields. The test only checks that ``add`` accepts
        these inputs; it makes no assertions on the buffer contents.
        """
        env = RecSim(num_candidates=20, slate_size=3, resample_documents=False)
        replay_buffer = ReplayBuffer(replay_capacity=100, batch_size=10)

        initial_obs = env.reset()
        user_features = initial_obs["user"]
        action = env.action_space.sample()
        # Stack the candidate documents before stepping the environment.
        doc_features = np.stack(list(initial_obs["doc"].values()), axis=0)

        next_obs, reward, terminal, _info = env.step(action)
        responses = next_obs["response"]

        # One entry per response item; keys match the names given to `add`.
        response_fields = {
            "response_click": np.array([r["click"] for r in responses]),
            "response_cluster_id": np.array([r["cluster_id"] for r in responses]),
            "response_quality": np.stack(
                [r["quality"] for r in responses], axis=0
            ),
            "response_liked": np.array([r["liked"] for r in responses]),
            "response_watch_time": np.stack(
                [r["watch_time"] for r in responses], axis=0
            ),
        }

        replay_buffer.add(
            observation=user_features,
            action=action,
            reward=reward,
            terminal=terminal,
            mdp_id=0,
            sequence_number=0,
            doc=doc_features,
            **response_fields,
            log_prob=-1.0,
        )
Beispiel #2
0
def setup_buffer(buffer_size, trajectory_lengths, stack_size=None, multi_steps=None):
    """Fill a fresh ReplayBuffer with trajectories and return all valid samples.

    One transition is added per step of every trajectory listed in
    ``trajectory_lengths``; the last step of each trajectory is marked
    terminal. Transitions come from ``get_add_transition`` using a counter
    that keeps increasing across trajectories.

    Args:
        buffer_size: Passed to the buffer as ``replay_capacity``.
        trajectory_lengths: Iterable with the number of steps per trajectory.
        stack_size: Buffer stack size; ``None`` means 1.
        multi_steps: Buffer update horizon; ``None`` means 1. Timeline format
            is requested only when a value is given.

    Returns:
        The result of ``sample_all_valid_transitions()`` on the filled buffer.
    """
    timeline_requested = multi_steps is not None
    if stack_size is None:
        stack_size = 1
    horizon = 1 if multi_steps is None else multi_steps

    memory = ReplayBuffer(
        stack_size=stack_size,
        replay_capacity=buffer_size,
        batch_size=1,
        update_horizon=horizon,
        # NOTE(review): stack_size has already been defaulted above, so this
        # expression is always True. Kept as-is to preserve behavior; confirm
        # whether the caller-supplied value was meant to be tested instead.
        return_everything_as_stack=stack_size is not None,
        return_as_timeline_format=timeline_requested,
    )

    step = 0
    for traj_len in trajectory_lengths:
        final_pos = traj_len - 1
        for pos in range(traj_len):
            transition = get_add_transition(step)
            memory.add(
                observation=transition["state"],
                action=transition["action"],
                reward=transition["reward"],
                terminal=bool(pos == final_pos),
                extra1=transition["extra1"],
            )
            step += 1
    return memory.sample_all_valid_transitions()
 def __call__(
     self,
     replay_buffer: ReplayBuffer,
     obs: Any,
     action: Any,
     reward: float,
     terminal: bool,
     log_prob: float,
 ):
     """Add a single transition to ``replay_buffer`` unchanged.

     The observation, action, reward and terminal flag are forwarded
     positionally; ``log_prob`` is forwarded as a keyword argument.
     """
     extras = {"log_prob": log_prob}
     replay_buffer.add(obs, action, reward, terminal, **extras)
Beispiel #4
0
    def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
        """Flatten a dict observation and add the transition to the buffer.

        The transition's ``observation`` is split into the ``user`` entry
        (stored as the buffer observation) and per-key arrays for documents,
        optional augmentation, and responses. All remaining transition fields
        are forwarded to ``replay_buffer.add`` as they are.
        """
        fields = transition.asdict()
        obs = fields.pop("observation")
        user = obs["user"]

        extras = {}

        # Documents: one array per configured key, else a single "doc" array.
        if not (self.box_keys or self.discrete_keys):
            extras["doc"] = np.stack(list(obs["doc"].values()))
        else:
            docs = list(obs["doc"].values())
            for key in self.box_keys:
                extras[f"doc_{key}"] = np.stack([doc[key] for doc in docs])
            for key in self.discrete_keys:
                extras[f"doc_{key}"] = np.array([doc[key] for doc in docs])

        # Augmentation: only read when augmentation keys are configured.
        if self.augmentation_box_keys or self.augmentation_discrete_keys:
            augmentations = list(obs["augmentation"].values())
            for key in self.augmentation_box_keys:
                extras[f"augmentation_{key}"] = np.stack(
                    [aug[key] for aug in augmentations]
                )
            for key in self.augmentation_discrete_keys:
                extras[f"augmentation_{key}"] = np.array(
                    [aug[key] for aug in augmentations]
                )

        # Responses: the first state has no response, so zero placeholders
        # with the configured shape are stored instead.
        response = obs["response"]
        if response is None:
            for key, dims in self.response_box_keys:
                extras[f"response_{key}"] = np.zeros(
                    (self.num_responses, *dims), dtype=np.float32
                )
            for key, _n in self.response_discrete_keys:
                extras[f"response_{key}"] = np.zeros(
                    (self.num_responses,), dtype=np.int64
                )
        else:
            for key, _dims in self.response_box_keys:
                extras[f"response_{key}"] = np.stack([r[key] for r in response])
            for key, _n in self.response_discrete_keys:
                extras[f"response_{key}"] = np.array([r[key] for r in response])

        # Derived arrays override any transition field with the same name.
        replay_buffer.add(observation=user, **{**fields, **extras})
Beispiel #5
0
    def __call__(
        self,
        replay_buffer: ReplayBuffer,
        obs: Any,
        action: Any,
        reward: float,
        terminal: bool,
        log_prob: float,
    ):
        """Flatten a dict observation and add the transition to the buffer.

        Args:
            replay_buffer: Buffer receiving the transition via ``add``.
            obs: Mapping with ``"user"``, ``"doc"`` and ``"response"`` entries.
                ``obs["response"]`` may be ``None`` (first state of an episode).
            action: Forwarded unchanged.
            reward: Forwarded unchanged.
            terminal: Forwarded unchanged.
            log_prob: Forwarded unchanged.

        The ``user`` entry is stored as the buffer observation. Documents and
        responses are turned into one array per configured key.
        """
        user = obs["user"]

        kwargs = {}

        if self.box_keys or self.discrete_keys:
            doc_obs = obs["doc"]
            for k in self.box_keys:
                kwargs[f"doc_{k}"] = np.stack([v[k] for v in doc_obs.values()])
            for k in self.discrete_keys:
                kwargs[f"doc_{k}"] = np.array([v[k] for v in doc_obs.values()])
        else:
            kwargs["doc"] = np.stack(list(obs["doc"].values()))

        # Responses

        response = obs["response"]
        # We need to handle None below because the first state won't have response.
        # The placeholders carry explicit dtypes: np.zeros defaults to float64,
        # which would give discrete responses a float dtype on the first step
        # and an integer dtype on every later step.
        for k, d in self.response_box_keys:
            if response is not None:
                kwargs[f"response_{k}"] = np.stack([v[k] for v in response])
            else:
                kwargs[f"response_{k}"] = np.zeros(
                    (self.num_responses, *d), dtype=np.float32
                )
        for k, _n in self.response_discrete_keys:
            if response is not None:
                kwargs[f"response_{k}"] = np.array([v[k] for v in response])
            else:
                kwargs[f"response_{k}"] = np.zeros(
                    (self.num_responses,), dtype=np.int64
                )

        replay_buffer.add(
            observation=user,
            action=action,
            reward=reward,
            terminal=terminal,
            log_prob=log_prob,
            **kwargs,
        )
    def __call__(
        self,
        replay_buffer: ReplayBuffer,
        obs: Any,
        action: Any,
        reward: float,
        terminal: bool,
        log_prob: float,
    ):
        """Flatten a dict observation and add the transition to the buffer.

        Args:
            replay_buffer: Buffer receiving the transition via ``add``.
            obs: Mapping with ``"user"``, ``"doc"`` and ``"response"`` entries.
            action: Forwarded unchanged.
            reward: Forwarded unchanged.
            terminal: Forwarded unchanged.
            log_prob: Forwarded unchanged.

        The ``user`` entry is stored as the buffer observation. Each configured
        document/response key is stored under its own ``doc_<key>`` or
        ``response_<key>`` name.
        """
        user = obs["user"]

        kwargs = {}

        if self.box_keys or self.discrete_keys:
            doc_obs = obs["doc"]
            # The key names must be f-strings; a plain "doc_{k}" literal made
            # every key overwrite the same entry.
            for k in self.box_keys:
                kwargs[f"doc_{k}"] = np.vstack([v[k] for v in doc_obs.values()])
            for k in self.discrete_keys:
                kwargs[f"doc_{k}"] = np.array([v[k] for v in doc_obs.values()])
        else:
            # Store one stacked array rather than the raw per-document mapping.
            kwargs["doc"] = np.stack(list(obs["doc"].values()))

        # Responses
        # NOTE(review): obs["response"] is iterated directly, so a None
        # response is not handled here — confirm callers never pass one.
        response = obs["response"]
        for k in self.response_box_keys:
            kwargs[f"response_{k}"] = np.vstack([v[k] for v in response])
        for k in self.response_discrete_keys:
            kwargs[f"response_{k}"] = np.array([v[k] for v in response])

        replay_buffer.add(
            observation=user,
            action=action,
            reward=reward,
            terminal=terminal,
            log_prob=log_prob,
            **kwargs,
        )
Beispiel #7
0
    def test_sparse_input(self):
        """Check sparse id-list / id-score-list features round-trip the buffer.

        Half a buffer of transitions with variable-length sparse features is
        added, then all but the last are sampled by explicit index. The batch
        is compared against expected (offsets, values[, scores]) lists built
        here for both the current and the ``next_`` features.
        """
        replay_capacity = 100
        num_transitions = replay_capacity // 2
        memory = ReplayBuffer(
            stack_size=1, replay_capacity=replay_capacity, update_horizon=1
        )

        def trans(i):
            cut4 = i % 4
            cut7 = i % 7

            def scored(start, stop):
                # ids plus a score of id + 0.5 for each id
                ids = list(range(start, stop))
                return (ids, [k + 0.5 for k in ids])

            return {
                "observation": np.ones(OBS_SHAPE, dtype=OBS_TYPE),
                "action": int(2 * i),
                "reward": float(3 * i),
                "terminal": i % 4,
                "id_list": {
                    "sparse_feat1": list(range(0, cut4)),
                    "sparse_feat2": list(range(cut4, 4)),
                },
                "id_score_list": {
                    "sparse_feat3": scored(0, cut7),
                    "sparse_feat4": scored(cut7, 7),
                },
            }

        for i in range(num_transitions):
            memory.add(**trans(i))

        indices = list(range(num_transitions - 1))
        batch = memory.sample_transition_batch(len(indices), torch.tensor(indices))

        # calculate expected
        list_feats = ("sparse_feat1", "sparse_feat2")
        score_feats = ("sparse_feat3", "sparse_feat4")
        expected = {
            "id_list": {f: ([], []) for f in list_feats},
            "id_score_list": {f: ([], [], []) for f in score_feats},
            "next_id_list": {f: ([], []) for f in list_feats},
            "next_id_score_list": {f: ([], [], []) for f in score_feats},
        }

        def append_row(bucket, feature):
            # The offset of a row is the number of values collected so far.
            bucket[0].append(len(bucket[1]))
            if len(bucket) == 2:
                bucket[1].extend(feature)
            else:
                bucket[1].extend(feature[0])
                bucket[2].extend(feature[1])

        for i in range(num_transitions - 1):
            sources = {"": trans(i), "next_": trans(i + 1)}
            for prefix, feats in sources.items():
                for base in ("id_list", "id_score_list"):
                    for feat_id, bucket in expected[prefix + base].items():
                        append_row(bucket, feats[base][feat_id])

        # Compare every component (offsets, values and, if present, scores).
        for key, buckets in expected.items():
            actual = getattr(batch, key)
            for feat_id, bucket in buckets.items():
                for pos, expected_part in enumerate(bucket):
                    npt.assert_array_equal(expected_part, actual[feat_id][pos])

        # sample random
        _ = memory.sample_transition_batch(10)
Beispiel #8
0
    def test_replay_overflow(self):
        """
        Walk a capacity-6 buffer through wrap-around and check index validity.

        It is hard to make a stress test for this, since tracking which
        indices get replaced would effectively mean building a second replay
        buffer, so this opts for a simple step-by-step test instead.
        stack_size = 2, so there is 1 padding slot.

        After each add, the test asserts the exact contents of
        ``memory._is_index_valid`` and, where transitions are valid, the
        stacked ``action`` / ``next_action`` values of all valid transitions.
        """
        multi_steps = 2
        stack_size = 2
        memory = ReplayBuffer(
            stack_size=stack_size,
            replay_capacity=6,
            batch_size=1,
            update_horizon=multi_steps,
            return_everything_as_stack=None,
            return_as_timeline_format=True,
        )

        def trans(i):
            # Transition i has action 2 * i and reward 3 * i, so the actions
            # asserted below identify which transition a slot holds.
            return {
                "observation": np.ones(OBS_SHAPE, dtype=OBS_TYPE),
                "action": int(2 * i),
                "reward": float(3 * i),
            }

        # Contents of RB
        # start: [X, X, X, X, X, X]
        npt.assert_array_equal(
            memory._is_index_valid, [False, False, False, False, False, False]
        )

        # t0: [X, s0, X, X, X, X]
        memory.add(**trans(0), terminal=False)
        npt.assert_array_equal(
            memory._is_index_valid, [False, False, False, False, False, False]
        )

        # t1: [X, s0, s1, X, X, X]
        memory.add(**trans(1), terminal=False)
        npt.assert_array_equal(
            memory._is_index_valid, [False, False, False, False, False, False]
        )

        # t2: [X, s0, s1, s2, X, X]
        # s0 finally becomes valid as its next state was added
        memory.add(**trans(2), terminal=False)
        npt.assert_array_equal(
            memory._is_index_valid, [False, True, False, False, False, False]
        )
        batch = memory.sample_all_valid_transitions()
        npt.assert_array_equal(batch.action, [[0, 0]])
        npt.assert_array_equal(batch.next_action[0], [[0, 2], [2, 4]])

        # t3: [X, s0, s1, s2, s3, X]
        # episode termination validates whole episode
        memory.add(**trans(3), terminal=True)
        npt.assert_array_equal(
            memory._is_index_valid, [False, True, True, True, True, False]
        )
        batch = memory.sample_all_valid_transitions()
        npt.assert_array_equal(batch.action, [[0, 0], [0, 2], [2, 4], [4, 6]])
        npt.assert_array_equal(batch.next_action[0], [[0, 2], [2, 4]])
        npt.assert_array_equal(batch.next_action[1], [[2, 4], [4, 6]])
        # batch.next_action[2][1] is garbage
        npt.assert_array_equal(batch.next_action[2][0], [4, 6])
        # batch.next_action[3] is [garbage]

        # t4: [s4, s0, s1, s2, s3, X]
        # s0 invalidated as its previous frame is corrupted
        memory.add(**trans(4), terminal=False)
        npt.assert_array_equal(
            memory._is_index_valid, [False, False, True, True, True, False]
        )
        batch = memory.sample_all_valid_transitions()
        npt.assert_array_equal(batch.action, [[0, 2], [2, 4], [4, 6]])
        npt.assert_array_equal(batch.next_action[0], [[2, 4], [4, 6]])
        npt.assert_array_equal(batch.next_action[1][0], [4, 6])

        # t5: [s4, s5, s1, s2, s3, X]
        memory.add(**trans(5), terminal=False)
        npt.assert_array_equal(
            memory._is_index_valid, [False, False, False, True, True, False]
        )
        batch = memory.sample_all_valid_transitions()
        npt.assert_array_equal(batch.action, [[2, 4], [4, 6]])
        npt.assert_array_equal(batch.next_action[0][0], [4, 6])

        # t6: [s4, s5, s6, s2, s3, X]
        # the second episode (s4..s6) terminates here
        memory.add(**trans(6), terminal=True)
        npt.assert_array_equal(
            memory._is_index_valid, [True, True, True, False, True, False]
        )
        batch = memory.sample_all_valid_transitions()
        npt.assert_array_equal(batch.action, [[0, 8], [8, 10], [10, 12], [4, 6]])
        npt.assert_array_equal(batch.next_action[0], [[8, 10], [10, 12]])
        npt.assert_array_equal(batch.next_action[1][0], [10, 12])
        # batch.next_action[2] is [garbage]
        # batch.next_action[3] is [garbage]

        logger.info("Overflow test passes!")
Beispiel #9
0
 def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
     """Add ``transition`` to ``replay_buffer``, one keyword per field."""
     fields = transition.asdict()
     replay_buffer.add(**fields)