Example #1
    def make_dataset_iterator(
            self,
            replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
        """Create a dataset iterator to use for learning/updating the agent."""
        dataset = reverb.TrajectoryDataset.from_table_signature(
            server_address=replay_client.server_address,
            table=self._config.replay_table_name,
            max_in_flight_samples_per_worker=1)
        return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])
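These snippets appear to be `make_dataset_iterator` methods from agent builders in the style of DeepMind's Acme library. They rely on module-level imports such as `import jax`, `import reverb`, `from acme import datasets`, and `from acme.jax import utils`, plus a `self._config` object on the enclosing builder. The final `utils.device_put` call wraps the iterator so that each sample is moved onto the given JAX device as it is consumed; a minimal pure-Python sketch of that behavior (the helper name here is hypothetical, not Acme's actual implementation):

    import jax

    def device_put_iterator(iterator, device):
        # Hypothetical stand-in for acme.jax.utils.device_put: move each
        # yielded sample (a pytree of arrays) onto the given device.
        for sample in iterator:
            yield jax.device_put(sample, device)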
Example #2
    def make_dataset_iterator(
            self,
            replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
        """Creates a dataset iterator."""
        dataset = datasets.make_reverb_dataset(
            table=self._config.replay_table_name,
            server_address=replay_client.server_address,
            batch_size=self._config.batch_size,
            num_parallel_calls=None)
        return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])
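Calling one of these builders requires a Reverb client pointed at a running server that hosts the configured replay table. A hedged usage sketch; the address and the `builder` object are assumptions for illustration:

    import reverb

    # Assumes a Reverb server is running at this address and that `builder`
    # is an agent builder exposing the make_dataset_iterator method above.
    client = reverb.Client('localhost:8000')
    iterator = builder.make_dataset_iterator(client)
    sample = next(iterator)  # one batched reverb.ReplaySample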
Example #3
    def make_dataset_iterator(
            self,
            replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
        """Create a dataset iterator to use for learning/updating the agent."""
        dataset = datasets.make_reverb_dataset(
            table=self._config.replay_table_name,
            server_address=replay_client.server_address,
            batch_size=(self._config.batch_size *
                        self._config.num_sgd_steps_per_step),
            prefetch_size=self._config.prefetch_size)
        return utils.device_put(dataset.as_numpy_iterator(), jax.devices()[0])
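Example #3 fetches `batch_size * num_sgd_steps_per_step` items at a time, a common pattern when the learner performs several SGD steps per learner step and slices each fetched batch into minibatches. A sketch of that split (a hypothetical helper, not Acme's actual code):

    import numpy as np

    def split_into_minibatches(batch, num_sgd_steps_per_step):
        # Hypothetical helper: split a [B * K, ...] array into K
        # minibatches of size B, one per SGD step.
        return np.split(batch, num_sgd_steps_per_step, axis=0)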
Example #4
    def make_dataset_iterator(
            self,
            replay_client: reverb.Client) -> Iterator[learning.AILSample]:
        """Creates an iterator of AILSamples combining discriminator, RL, and demo data."""
        batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size(
            self._config)

        iterator_demonstration = self._make_demonstrations(
            batch_size_per_learner_step)

        direct_iterator = self._rl_agent.make_dataset_iterator(replay_client)

        if self._config.share_iterator:
            # In order to reuse the iterator return values and not lose a 2x factor on
            # sample efficiency, we need to use itertools.tee().
            discriminator_iterator, direct_iterator = itertools.tee(
                direct_iterator)
        else:
            discriminator_iterator = datasets.make_reverb_dataset(
                table=self._config.replay_table_name,
                server_address=replay_client.server_address,
                batch_size=batch_size_per_learner_step,
                prefetch_size=self._config.prefetch_size).as_numpy_iterator()

        if self._config.policy_to_expert_data_ratio is not None:
            iterator_demonstration, iterator_demonstration2 = itertools.tee(
                iterator_demonstration)
            direct_iterator = _generate_samples_with_demonstrations(
                iterator_demonstration2, direct_iterator,
                self._config.policy_to_expert_data_ratio,
                self._config.direct_rl_batch_size)

        is_sequence_based = self._config.is_sequence_based

        # Don't flatten the discriminator batch if the iterator is not shared.
        process_discriminator_sample = functools.partial(
            reverb_utils.replay_sample_to_sars_transition,
            is_sequence=is_sequence_based and self._config.share_iterator,
            flatten_batch=is_sequence_based and self._config.share_iterator,
            strip_last_transition=is_sequence_based
            and self._config.share_iterator)

        discriminator_iterator = (
            # Remove the extras to have the same nested structure as demonstrations.
            process_discriminator_sample(sample)._replace(extras=())
            for sample in discriminator_iterator)

        return utils.device_put((learning.AILSample(*sample) for sample in zip(
            discriminator_iterator, direct_iterator, iterator_demonstration)),
                                jax.devices()[0])
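`_generate_samples_with_demonstrations` is defined elsewhere; from its call site it mixes demonstration data into the direct RL stream according to `policy_to_expert_data_ratio`. A simplified, per-sample sketch of that kind of mixing (the real helper works on whole batches and respects `direct_rl_batch_size`):

    def mix_policy_and_expert(demo_iterator, policy_iterator, ratio):
        # Simplified sketch: yield `ratio` policy samples for every
        # demonstration sample, approximating a policy-to-expert data
        # ratio of `ratio` to 1.
        while True:
            yield next(demo_iterator)
            for _ in range(ratio):
                yield next(policy_iterator)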
Example #5
    def make_dataset_iterator(
        self, replay_client: reverb.Client
    ) -> Optional[Iterator[reverb.ReplaySample]]:
        """Returns an iterator over batches containing both expert and policy data.

        Batch items alternate between expert data and policy data.

        Args:
          replay_client: Reverb client.

        Returns:
          The replay sample iterator.
        """
        # TODO(eorsini): Make sure we have the exact same format as the rl_agent's
        # adder writes in.
        demonstration_iterator = self._make_demonstrations(
            self._rl_agent_batch_size)

        rb_iterator = self._rl_agent.make_dataset_iterator(replay_client)

        return utils.device_put(
            _generate_sqil_samples(demonstration_iterator, rb_iterator),
            jax.devices()[0])
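`_generate_sqil_samples` is also defined elsewhere; per the docstring it produces batches whose items alternate between expert and policy data (SQIL additionally relabels rewards, giving expert transitions reward 1 and policy transitions reward 0). A simplified interleaving sketch, assuming two equally shaped numpy batches:

    import numpy as np

    def interleave_batches(demo_batch, policy_batch):
        # Simplified sketch: alternate items from two same-shaped arrays,
        # producing [d0, p0, d1, p1, ...]. The real helper also handles
        # nested sample structures.
        stacked = np.stack([demo_batch, policy_batch], axis=1)
        return stacked.reshape((-1,) + demo_batch.shape[1:])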