コード例 #1
0
    def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
        """Flatten a transition's nested observation and add it to the buffer.

        The "observation" entry of ``transition.asdict()`` is taken apart:
        its "user" part is handed to ``replay_buffer.add`` as ``observation``,
        while the "doc", "augmentation" and "response" parts are converted to
        arrays and passed as ``doc`` / ``doc_<key>``, ``augmentation_<key>``
        and ``response_<key>`` keyword arguments. Every other transition
        field is forwarded as-is.
        """
        fields = transition.asdict()
        observation = fields.pop("observation")
        user_obs = observation["user"]

        flattened = {}

        # Documents: a single stacked array when no per-key layout is
        # configured, otherwise one array per configured key.
        if not (self.box_keys or self.discrete_keys):
            flattened["doc"] = np.stack(list(observation["doc"].values()))
        else:
            docs = observation["doc"]
            flattened.update({
                f"doc_{name}": np.stack([doc[name] for doc in docs.values()])
                for name in self.box_keys
            })
            flattened.update({
                f"doc_{name}": np.array([doc[name] for doc in docs.values()])
                for name in self.discrete_keys
            })

        # Augmentation: only read from the observation when keys are configured.
        if self.augmentation_box_keys or self.augmentation_discrete_keys:
            augmentations = observation["augmentation"]
            flattened.update({
                f"augmentation_{name}": np.stack(
                    [aug[name] for aug in augmentations.values()])
                for name in self.augmentation_box_keys
            })
            flattened.update({
                f"augmentation_{name}": np.array(
                    [aug[name] for aug in augmentations.values()])
                for name in self.augmentation_discrete_keys
            })

        # Responses
        responses = observation["response"]
        if responses is None:
            # The first state has no response yet, so store zero placeholders
            # sized by num_responses (plus the per-key shape for box keys).
            for name, shape in self.response_box_keys:
                flattened[f"response_{name}"] = np.zeros(
                    (self.num_responses, *shape), dtype=np.float32)
            for name, _ in self.response_discrete_keys:
                flattened[f"response_{name}"] = np.zeros(
                    (self.num_responses, ), dtype=np.int64)
        else:
            for name, _ in self.response_box_keys:
                flattened[f"response_{name}"] = np.stack(
                    [item[name] for item in responses])
            for name, _ in self.response_discrete_keys:
                flattened[f"response_{name}"] = np.array(
                    [item[name] for item in responses])

        fields.update(flattened)
        replay_buffer.add(observation=user_obs, **fields)
コード例 #2
0
 def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
     """Add every field of ``transition.asdict()`` to the buffer as a keyword argument."""
     fields = transition.asdict()
     replay_buffer.add(**fields)