def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
    """Flatten a nested observation and add the transition to the buffer.

    The transition is converted with ``asdict()`` and its ``"observation"``
    entry is removed and unpacked:

    * ``obs["user"]`` is passed to ``replay_buffer.add`` as ``observation``.
    * ``obs["doc"]`` becomes either one ``doc_<key>`` array per configured
      box/discrete key, or a single stacked ``doc`` array when no keys are
      configured.
    * ``obs["augmentation"]`` is read only when augmentation keys are
      configured, producing ``augmentation_<key>`` arrays.
    * ``obs["response"]`` produces ``response_<key>`` arrays; when it is
      ``None`` zero-filled placeholders are stored instead.

    Args:
        replay_buffer: Destination; its ``add`` method receives every field
            as a keyword argument.
        transition: Object whose ``asdict()`` result contains an
            ``"observation"`` mapping with the entries described above.
    """
    fields = transition.asdict()
    observation = fields.pop("observation")
    user_obs = observation["user"]

    extras = {}

    # Documents
    if not (self.box_keys or self.discrete_keys):
        # No per-key layout configured: stack the document values directly.
        extras["doc"] = np.stack(list(observation["doc"].values()))
    else:
        docs = observation["doc"]
        for key in self.box_keys:
            extras[f"doc_{key}"] = np.stack(
                [doc[key] for doc in docs.values()]
            )
        for key in self.discrete_keys:
            extras[f"doc_{key}"] = np.array(
                [doc[key] for doc in docs.values()]
            )

    # Augmentation
    if self.augmentation_box_keys or self.augmentation_discrete_keys:
        augmentations = observation["augmentation"]
        for key in self.augmentation_box_keys:
            extras[f"augmentation_{key}"] = np.stack(
                [aug[key] for aug in augmentations.values()]
            )
        for key in self.augmentation_discrete_keys:
            extras[f"augmentation_{key}"] = np.array(
                [aug[key] for aug in augmentations.values()]
            )

    # Responses
    responses = observation["response"]
    # The first state won't have a response, so None must be handled by
    # substituting zero-filled arrays of the expected shape.
    has_responses = responses is not None

    for key, shape in self.response_box_keys:
        if has_responses:
            value = np.stack([resp[key] for resp in responses])
        else:
            value = np.zeros((self.num_responses, *shape), dtype=np.float32)
        extras[f"response_{key}"] = value

    for key, _n in self.response_discrete_keys:
        if has_responses:
            value = np.array([resp[key] for resp in responses])
        else:
            value = np.zeros((self.num_responses, ), dtype=np.int64)
        extras[f"response_{key}"] = value

    fields.update(extras)
    replay_buffer.add(observation=user_obs, **fields)
def __call__(self, replay_buffer: ReplayBuffer, transition: Transition):
    """Add ``transition`` to ``replay_buffer`` without any preprocessing.

    The mapping returned by ``transition.asdict()`` is expanded into keyword
    arguments for ``replay_buffer.add``.
    """
    insert = replay_buffer.add
    insert(**transition.asdict())