def add(self, stacked_datapoints):
    """Store a batch of datapoints in the underlying memory.

    Args:
        stacked_datapoints (pytree): Transition object containing the
            datapoints, stacked along axis 0.
    """
    # Split the stacked pytree into individual datapoints and insert
    # them one by one.
    unstacked = data.nested_unstack(stacked_datapoints)
    for point in unstacked:
        self.memory.add(point)
def _log_top_priority_transitions(self):
    """Log visualizations of the highest-priority replay transitions.

    Only runs when ``self._log_n_top_transitions`` is set and the model
    trainer samples the replay buffer by priority; otherwise it is a
    no-op. For each priority bucket, runs the model on the sampled
    inputs and logs one image per visualized transition.
    """
    # Guard clause: nothing to log unless top-transition logging is
    # enabled and priority sampling is in use.
    if (self._log_n_top_transitions is None
            or self._model_trainer.replay_sample_mode != 'priority'):
        return

    samples = self._model_trainer.top_priority_samples(
        self._log_n_top_transitions
    )
    for bucket_name, (batch, priorities) in samples.items():
        inputs, targets = zip(*batch)
        inputs_batch = np.stack(inputs)
        preds = self._model_network.predict(inputs_batch)
        # Fix: was a bare `nested_unstack(preds)`, which is not defined
        # in this scope — the rest of this module consistently uses
        # `data.nested_unstack`.
        preds_unstacked = data.nested_unstack(preds)
        batch_with_preds = zip(inputs, targets, preds_unstacked)
        visualizations = self._logging_env.visualize_replay_buffer(
            batch_with_preds, priorities
        )
        for i, transition_viz in enumerate(visualizations):
            metric_logging.log_image(
                f'model/top_priority_{bucket_name}_epoch_'
                f'{self._epoch}',
                i, transition_viz
            )
def add_episode(self, episode):
    """Append an episode's transitions to the replay buffer.

    Args:
        episode: Episode object whose `transition_batch` is unstacked
            into individual transitions before insertion.
    """
    transitions = []
    for trans in data.nested_unstack(episode.transition_batch):
        # Flatten each transition into a plain tuple of its fields.
        transitions.append((
            trans.observation,
            trans.action,
            trans.reward,
            trans.next_observation,
            trans.done,
        ))
    self._replay_buffer.add(transitions)
def unbatch_responses(self, x):
    """Split a flat batch of responses into per-agent responses.

    Each leaf of the pytree is reshaped from
    (n_agents * n_requests, ...) to (n_agents, n_requests, ...), then
    the whole tree is unstacked along the agent axis.
    """
    def split_agents(leaf):
        # (n_agents * n_requests, ...) -> (n_agents, n_requests, ...)
        return np.reshape(leaf, (self._n_agents, -1) + leaf.shape[1:])

    reshaped = data.nested_map(split_agents, x)
    return data.nested_unstack(reshaped)
def unbatch_responses(self, response):
    """Regroup flat responses into one entry per agent.

    Reshapes every leaf from (n_agents * n_requests, ...) to
    (n_agents, n_requests, ...) and unstacks the result along axis 0.
    """
    per_agent = data.nested_map(self._unflatten_first_2_dims, response)
    return data.nested_unstack(per_agent)