Пример #1
0
    def add(self, stacked_datapoints):
        """Push each datapoint of a stacked batch into ``self.memory``.

        The batch is split with ``data.nested_unstack`` and every resulting
        datapoint is passed, in order, to ``self.memory.add``.

        Args:
            stacked_datapoints (pytree): Transition object whose leaves hold
                the datapoints stacked along axis 0.
        """
        unstacked = data.nested_unstack(stacked_datapoints)
        for single in unstacked:
            self.memory.add(single)
Пример #2
0
 def _log_top_priority_transitions(self):
     """Log images of the top-priority transitions from the replay buffer.

     Does nothing unless ``self._log_n_top_transitions`` is set and the
     model trainer's ``replay_sample_mode`` is ``'priority'``. Otherwise,
     for every bucket returned by
     ``self._model_trainer.top_priority_samples``, it does the following:

     - runs the model network on the bucket's stacked inputs;
     - asks ``self._logging_env.visualize_replay_buffer`` to render the
       ``(input, target, prediction)`` triples together with their
       priorities;
     - logs each returned visualization with ``metric_logging.log_image``.

     Empty buckets are skipped.
     """
     if self._log_n_top_transitions is None:
         return
     if self._model_trainer.replay_sample_mode != 'priority':
         return

     samples = self._model_trainer.top_priority_samples(
         self._log_n_top_transitions
     )
     for bucket_name, (batch, priorities) in samples.items():
         # Materialize once so the emptiness check works for any iterable.
         transitions = list(batch)
         if not transitions:
             # An empty bucket would make the unpacking below raise
             # ValueError (zip(*[]) yields no columns), and np.stack rejects
             # an empty sequence. There is nothing to visualize anyway.
             continue
         # Each element of the batch is an (input, target) pair.
         inputs, targets = zip(*transitions)
         inputs_batch = np.stack(inputs)
         preds = self._model_network.predict(inputs_batch)
         preds_unstacked = nested_unstack(preds)
         batch_with_preds = zip(inputs, targets, preds_unstacked)
         visualizations = self._logging_env.visualize_replay_buffer(
             batch_with_preds, priorities
         )
         for i, transition_viz in enumerate(visualizations):
             metric_logging.log_image(
                 f'model/top_priority_{bucket_name}_epoch_'
                 f'{self._epoch}',
                 i, transition_viz
             )
Пример #3
0
 def add_episode(self, episode):
     """Store one episode in the replay buffer as a list of tuples.

     Every transition obtained by unstacking ``episode.transition_batch``
     with ``data.nested_unstack`` becomes an
     ``(observation, action, reward, next_observation, done)`` tuple. The
     whole list is handed to ``self._replay_buffer.add`` in a single call.
     """
     transitions = []
     for t in data.nested_unstack(episode.transition_batch):
         transitions.append((
             t.observation,
             t.action,
             t.reward,
             t.next_observation,
             t.done,
         ))
     self._replay_buffer.add(transitions)
Пример #4
0
    def unbatch_responses(self, x):
        """Split a flat batch of responses along a new leading agent axis.

        Every leaf of the pytree ``x`` has its leading axis reshaped from
        ``n_agents * n_requests`` into ``(n_agents, n_requests)`` with
        ``np.reshape``. The reshaped pytree is then split with
        ``data.nested_unstack``.
        """
        def _split_leading_axis(leaf):
            # (n_agents * n_requests, ...) -> (n_agents, n_requests, ...)
            target_shape = (self._n_agents, -1) + leaf.shape[1:]
            return np.reshape(leaf, target_shape)

        per_agent = data.nested_map(_split_leading_axis, x)
        return data.nested_unstack(per_agent)
Пример #5
0
 def unbatch_responses(self, response):
     """Turn a flat batch of responses into per-agent responses.

     ``self._unflatten_first_2_dims`` is applied to every leaf of
     ``response`` via ``data.nested_map``. The reshaped pytree is then
     split with ``data.nested_unstack``.
     """
     # (n_agents * n_requests, ...) -> (n_agents, n_requests, ...)
     return data.nested_unstack(
         data.nested_map(self._unflatten_first_2_dims, response)
     )