示例#1
0
文件: builder.py 项目: zerocurve/acme
 def make_dataset_iterator(
     self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
   """Builds a numpy iterator over batched replay samples for the learner."""
   # The fetched batch covers num_sgd_steps_per_step updates, so one learner
   # step pulls everything it needs from replay in a single request.
   per_step_batch = (
       self._config.batch_size * self._config.num_sgd_steps_per_step)
   replay_dataset = datasets.make_reverb_dataset(
       table=self._config.replay_table_name,
       server_address=replay_client.server_address,
       batch_size=per_step_batch,
       prefetch_size=self._config.prefetch_size)
   return replay_dataset.as_numpy_iterator()
示例#2
0
文件: builder.py 项目: deepmind/acme
 def make_dataset_iterator(
         self,
         replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
     """Returns a learner iterator whose batches are placed on a device."""
     replay_dataset = datasets.make_reverb_dataset(
         table=self._config.replay_table_name,
         server_address=replay_client.server_address,
         batch_size=self._config.batch_size,
         num_parallel_calls=None)
     # Ship each numpy batch to the first JAX device ahead of consumption.
     numpy_iterator = replay_dataset.as_numpy_iterator()
     return utils.device_put(numpy_iterator, jax.devices()[0])
示例#3
0
文件: builder.py 项目: deepmind/acme
 def make_dataset_iterator(
         self,
         replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
     """Builds a device-resident replay iterator for learning/updating."""
     # One fetched batch spans num_sgd_steps_per_step SGD updates so a single
     # learner step does not need several round-trips to replay.
     per_step_batch = (self._config.batch_size *
                       self._config.num_sgd_steps_per_step)
     replay_dataset = datasets.make_reverb_dataset(
         table=self._config.replay_table_name,
         server_address=replay_client.server_address,
         batch_size=per_step_batch,
         prefetch_size=self._config.prefetch_size)
     # Pin each numpy batch to the first JAX device before handing it out.
     return utils.device_put(replay_dataset.as_numpy_iterator(),
                             jax.devices()[0])
示例#4
0
 def make_dataset_iterator(
         self,
         replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
     """Returns a numpy iterator over batched replay samples."""
     cfg = self._config
     # All tuning knobs (batch size, prefetch depth, parallelism) come from
     # the builder config; only the server address comes from the client.
     replay_dataset = datasets.make_reverb_dataset(
         table=cfg.replay_table_name,
         server_address=replay_client.server_address,
         batch_size=cfg.batch_size,
         prefetch_size=cfg.prefetch_size,
         num_parallel_calls=cfg.num_parallel_calls)
     return replay_dataset.as_numpy_iterator()
示例#5
0
    def make_dataset_iterator(
            self,
            replay_client: reverb.Client) -> Iterator[learning.AILSample]:
        """Builds the learner iterator of (discriminator, RL, demo) samples.

        Args:
          replay_client: Reverb client used by the direct RL agent's iterator
            and, when the iterator is not shared, by the discriminator's own
            replay dataset.

        Returns:
          An iterator of `learning.AILSample` triples zipped from the
          discriminator, direct RL, and demonstration streams, placed on the
          first JAX device.
        """
        batch_size_per_learner_step = ail_config.get_per_learner_step_batch_size(
            self._config)

        # Expert demonstrations, batched to match one learner step.
        iterator_demonstration = self._make_demonstrations(
            batch_size_per_learner_step)

        # Replay samples for the underlying (direct) RL agent.
        direct_iterator = self._rl_agent.make_dataset_iterator(replay_client)

        if self._config.share_iterator:
            # In order to reuse the iterator return values and not lose a 2x factor on
            # sample efficiency, we need to use itertools.tee().
            discriminator_iterator, direct_iterator = itertools.tee(
                direct_iterator)
        else:
            # Separate replay stream for the discriminator.
            discriminator_iterator = datasets.make_reverb_dataset(
                table=self._config.replay_table_name,
                server_address=replay_client.server_address,
                batch_size=ail_config.get_per_learner_step_batch_size(
                    self._config),
                prefetch_size=self._config.prefetch_size).as_numpy_iterator()

        if self._config.policy_to_expert_data_ratio is not None:
            # Mix demonstrations into the direct RL stream at the configured
            # ratio; tee() so the demonstrations remain available as the third
            # element of each AILSample as well.
            iterator_demonstration, iterator_demonstration2 = itertools.tee(
                iterator_demonstration)
            direct_iterator = _generate_samples_with_demonstrations(
                iterator_demonstration2, direct_iterator,
                self._config.policy_to_expert_data_ratio,
                self._config.direct_rl_batch_size)

        is_sequence_based = self._config.is_sequence_based

        # Don't flatten the discriminator batch if the iterator is not shared.
        process_discriminator_sample = functools.partial(
            reverb_utils.replay_sample_to_sars_transition,
            is_sequence=is_sequence_based and self._config.share_iterator,
            flatten_batch=is_sequence_based and self._config.share_iterator,
            strip_last_transition=is_sequence_based
            and self._config.share_iterator)

        discriminator_iterator = (
            # Remove the extras to have the same nested structure as demonstrations.
            process_discriminator_sample(sample)._replace(extras=())
            for sample in discriminator_iterator)

        # Zip the three streams and ship each AILSample to the first device.
        return utils.device_put((learning.AILSample(*sample) for sample in zip(
            discriminator_iterator, direct_iterator, iterator_demonstration)),
                                jax.devices()[0])
示例#6
0
  def make_dataset_iterator(
      self, replay_client: reverb.Client) -> Iterator[reverb.ReplaySample]:
    """Builds a multi-device iterator over replay samples for the learner."""
    total_batch = self._config.batch_size
    # Per-host share of the global batch; used to size the in-flight samples.
    per_learner = total_batch // jax.process_count()
    if total_batch % jax.device_count() != 0:
      raise ValueError(
          'Learner batch size must be divisible by total number of devices!')
    per_device = total_batch // jax.device_count()

    replay_dataset = datasets.make_reverb_dataset(
        table=self._config.replay_table_name,
        server_address=replay_client.server_address,
        batch_size=per_device,
        num_parallel_calls=None,
        max_in_flight_samples_per_worker=2 * per_learner)

    # Shard the numpy batches across this host's local devices.
    return utils.multi_device_put(replay_dataset.as_numpy_iterator(),
                                  jax.local_devices())
示例#7
0
文件: builder.py 项目: deepmind/acme
    def make_dataset_iterator(
        self, replay_client: reverb.Client
    ) -> Iterator[r2d2_learning.R2D2ReplaySample]:
        """Builds a multi-device replay iterator for learning/updating."""
        replay_dataset = datasets.make_reverb_dataset(
            table=self._config.replay_table_name,
            server_address=replay_client.server_address,
            batch_size=self._batch_size_per_device,
            prefetch_size=self._config.prefetch_size,
            num_parallel_calls=self._config.num_parallel_calls)

        # int64 arrays are not supported on TPUs, so keep the sample keys on
        # host while the full sample is shipped to devices for sgd_step.
        def keep_keys_on_host(
                sample: reverb.ReplaySample) -> utils.PrefetchingSplit:
            return utils.PrefetchingSplit(host=sample.info.key, device=sample)

        return utils.multi_device_put(replay_dataset.as_numpy_iterator(),
                                      devices=jax.local_devices(),
                                      split_fn=keep_keys_on_host)
示例#8
0
def main(_):
  """Benchmarks reverb dataset throughput over batch/prefetch size grid."""
  environment = fakes.ContinuousEnvironment(action_dim=8,
                                            observation_dim=87,
                                            episode_length=10000000)
  env_spec = specs.make_environment_spec(environment)
  replay_server = reverb.Server(make_replay_tables(env_spec), port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')
  adder = make_adder(replay_client)

  # Seed the table with 10k random-action transitions.
  adder.add_first(environment.reset())
  # TODO(raveman): Consider also filling the table to say 1M (too slow).
  for steps in range(10000):
    if steps % 1000 == 0:
      logging.info('Processed %s steps', steps)
    action = np.asarray(np.random.uniform(-1, 1, (8,)), dtype=np.float32)
    adder.add(action, environment.step(action), extras=())

  # Time 3 runs of 1000 batches for each (batch_size, prefetch_size) pair.
  for batch_size in [256, 256 * 8, 256 * 64]:
    for prefetch_size in [0, 1, 4]:
      print(f'Processing batch_size={batch_size} prefetch_size={prefetch_size}')
      it = datasets.make_reverb_dataset(
          table='default',
          server_address=replay_client.server_address,
          batch_size=batch_size,
          prefetch_size=prefetch_size,
      ).as_numpy_iterator()

      for iteration in range(3):
        t = time.time()
        for _ in range(1000):
          next(it)
        print(f'Iteration {iteration} finished in {time.time() - t}s')
示例#9
0
replay_server_address = 'localhost:%d' % replay_server.port

# Create a 5-step transition adder where in between those steps a discount of
# 0.99 is used (which should be the same discount used for learning).
adder = adders.NStepTransitionAdder(
    priority_fns={replay_table_name: lambda x: 1.},
    client=reverb.Client(replay_server_address),
    n_step=5,
    discount=0.99)

# Connect a batched dataset to the created reverb server.
# NOTE(review): `prefetch_size` is an element count; the previous value
# `True` only worked because bool is an int subtype (True == 1). Use the
# equivalent explicit integer so the intent is clear.
dataset = datasets.make_reverb_dataset(
    table=replay_table_name,
    server_address=replay_server_address,
    batch_size=256,
    prefetch_size=1)

# Make sure observation network is a Sonnet Module.
observation_network = tf2_utils.batch_concat
observation_network = tf2_utils.to_sonnet_module(observation_network)

# Create the target networks as deep copies of the online networks.
target_policy_network = copy.deepcopy(policy_network)
target_critic_network = copy.deepcopy(critic_network)
target_observation_network = copy.deepcopy(observation_network)

# Get observation and action specs.
act_spec = environment_spec.actions
obs_spec = environment_spec.observations