    def learner(
        self,
        random_key: networks_lib.PRNGKey,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""
        # Counter and logger.
        counter = counting.Counter(counter, 'learner')
        logger = loggers.make_default_logger(
            'learner',
            self._save_logs,
            time_delta=self._log_every,
            asynchronous=True,
            serialize_fn=utils.fetch_devicearray,
            steps_key='learner_steps')

        # Create the learner.
        networks = self._network_factory()
        learner = self._make_learner(random_key, networks, counter, logger)

        # Checkpoint under the configured workdir; only add a UID when it is
        # still the default '~/acme', so explicit workdirs keep stable paths.
        kwargs = {
            'directory': self._workdir,
            'add_uid': self._workdir == '~/acme'
        }
        # Return the learning agent.
        return savers.CheckpointingRunner(learner,
                                          subdirectory='learner',
                                          time_delta_minutes=5,
                                          **kwargs)
    def counter(self):
        """The counter node of the agent, checkpointed periodically."""
        kwargs = {
            'directory': self._workdir,
            'add_uid': self._workdir == '~/acme'
        }
        return savers.CheckpointingRunner(counting.Counter(),
                                          subdirectory='counter',
                                          time_delta_minutes=5,
                                          **kwargs)
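For orientation: CheckpointingRunner periodically snapshots whatever object it wraps, which is why both a learner and a bare counting.Counter can be wrapped the same way; the wrapped object only needs to expose a save/restore pair (Acme's Saveable contract). Below is a minimal, illustrative sketch of that contract; the class and field names are made up for the example and are not part of Acme.

import dataclasses


@dataclasses.dataclass
class _ToyState:
    num_steps: int = 0


class ToySaveable:
    """A stand-in for a learner or counter that a checkpointing wrapper could wrap."""

    def __init__(self):
        self._state = _ToyState()

    def step(self):
        # A real learner would run one update step here.
        self._state.num_steps += 1

    def save(self) -> _ToyState:
        # Called periodically by the checkpointing wrapper.
        return self._state

    def restore(self, state: _ToyState):
        # Called when a checkpoint is loaded at startup.
        self._state = state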
Example #3
    def counter(self):
        return savers.CheckpointingRunner(
            counting.Counter(),
            key='counter',
            subdirectory='counter',
            time_delta_minutes=5,
            directory=self._checkpointing_config.directory,
            add_uid=self._checkpointing_config.add_uid,
            max_to_keep=self._checkpointing_config.max_to_keep)
    def counter(self):
        # Forward checkpointing options (directory, add_uid, max_to_keep, ...)
        # from the config object, if one was provided.
        kwargs = {}
        if self._checkpointing_config:
            kwargs = vars(self._checkpointing_config)
        return savers.CheckpointingRunner(counting.Counter(),
                                          key='counter',
                                          subdirectory='counter',
                                          time_delta_minutes=5,
                                          **kwargs)
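The vars(self._checkpointing_config) pattern above works when the config is a simple object (e.g. a dataclass) whose attribute names match CheckpointingRunner's keyword arguments. A hypothetical config with just the fields used in these examples (the real CheckpointingConfig in Acme may carry more or different fields):

import dataclasses


@dataclasses.dataclass
class CheckpointingConfig:
    directory: str = '~/acme'
    add_uid: bool = True
    max_to_keep: int = 1


config = CheckpointingConfig(directory='/tmp/acme', add_uid=False)
kwargs = vars(config)
# kwargs == {'directory': '/tmp/acme', 'add_uid': False, 'max_to_keep': 1}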
Example #5
    def learner(
        self,
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: counting.Counter,
    ):
        """The Learning part of the agent."""

        iterator = self._builder.make_dataset_iterator(replay)

        dummy_seed = 1
        environment_spec = (self._environment_spec
                            or specs.make_environment_spec(
                                self._environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = self._network_factory(environment_spec)

        if self._prefetch_size > 1:
            # When working with a single GPU we should prefetch to device for
            # efficiency. If running on TPU this isn't necessary as the computation
            # and input placement can be done automatically. For multi-GPU, the
            # current best solution is to prefetch to host, although this may
            # change in the future.
            device = jax.devices()[0] if self._device_prefetch else None
            iterator = utils.prefetch(iterator,
                                      buffer_size=self._prefetch_size,
                                      device=device)
        else:
            logging.info('Not prefetching the iterator.')

        counter = counting.Counter(counter, 'learner')
        learner = self._builder.make_learner(random_key, networks, iterator,
                                             replay, counter)

        return savers.CheckpointingRunner(
            learner,
            key='learner',
            subdirectory='learner',
            time_delta_minutes=5,
            directory=self._checkpointing_config.directory,
            add_uid=self._checkpointing_config.add_uid,
            max_to_keep=self._checkpointing_config.max_to_keep)
    def build_learner(
        random_key: networks_lib.PRNGKey,
        replay: reverb.Client,
        counter: Optional[counting.Counter] = None,
        primary_learner: Optional[core.Learner] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        iterator = experiment.builder.make_dataset_iterator(replay)
        # make_dataset_iterator is responsible for putting data onto the
        # appropriate training devices, so here we only apply a prefetch to
        # copy data over in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(random_key, networks,
                                                  iterator,
                                                  experiment.logger_factory,
                                                  spec, replay, counter)

        if primary_learner is None:
            learner = savers.CheckpointingRunner(
                learner,
                key='learner',
                subdirectory='learner',
                time_delta_minutes=5,
                directory=checkpointing_config.directory,
                add_uid=checkpointing_config.add_uid,
                max_to_keep=checkpointing_config.max_to_keep)
        else:
            learner.restore(primary_learner.save())
            # NOTE: This initially synchronizes secondary learner states with the
            # primary one. Further synchronization should be handled by the learner
            # properly doing a pmap/pmean on the loss/gradients, respectively.

        return learner
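A rough usage sketch for the build_learner factory above, showing one primary learner and secondary learners that start from its state. The surrounding names (num_learners, replay_client) are illustrative and assumed to exist in the enclosing experiment setup.

import jax

keys = jax.random.split(jax.random.PRNGKey(0), num_learners)
primary = build_learner(keys[0], replay_client, counter=None)
secondaries = [
    build_learner(key, replay_client, primary_learner=primary)
    for key in keys[1:]
]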
    def build_learner(
        random_key: networks_lib.PRNGKey,
        counter: Optional[counting.Counter] = None,
    ):
        """The Learning part of the agent."""

        dummy_seed = 1
        spec = (experiment.environment_spec or specs.make_environment_spec(
            experiment.environment_factory(dummy_seed)))

        # Creates the networks to optimize (online) and target networks.
        networks = experiment.network_factory(spec)

        dataset_key, random_key = jax.random.split(random_key)
        iterator = experiment.demonstration_dataset_factory(dataset_key)
        # The demonstration dataset factory is responsible for putting data
        # onto the appropriate training devices, so here we only apply a
        # prefetch to copy data over in the background.
        iterator = utils.prefetch(iterable=iterator, buffer_size=1)
        counter = counting.Counter(counter, 'learner')
        learner = experiment.builder.make_learner(
            random_key=random_key,
            networks=networks,
            dataset=iterator,
            logger_fn=experiment.logger_factory,
            environment_spec=spec,
            counter=counter)

        learner = savers.CheckpointingRunner(
            learner,
            key='learner',
            subdirectory='learner',
            time_delta_minutes=5,
            directory=checkpointing_config.directory,
            add_uid=checkpointing_config.add_uid,
            max_to_keep=checkpointing_config.max_to_keep)

        return learner
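The demonstration_dataset_factory used above is only assumed to take a PRNGKey and return an iterator of training batches. A toy, hypothetical factory with that shape might look like the following; `demonstrations` and the sampling scheme are placeholders, not Acme APIs.

import jax
import numpy as np


def demonstration_dataset_factory(key):
    # `demonstrations` is assumed to be a sequence of pre-batched samples
    # loaded elsewhere; the PRNGKey only seeds the sampling order.
    seed = int(jax.random.randint(key, (), 0, 2**31 - 1))
    rng = np.random.default_rng(seed)
    while True:
        yield demonstrations[rng.integers(len(demonstrations))]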
Example #8
  def learner(
      self,
      random_key,
      replay,
      counter,
  ):
    """The Learning part of the agent."""

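    # For offline variants, first replay the stored dataset (a D4RL-style dict
    # of arrays) through the adder so the transitions land in the replay
    # buffer before training starts.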
    if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
      adder = self._builder.make_adder(replay)
      env = self._environment_factory(0)
      dataset = env.get_dataset()  # pytype: disable=attribute-error
      for t in tqdm.trange(dataset['observations'].shape[0]):
        discount = 1.0
        if t == 0 or dataset['timeouts'][t - 1]:
          step_type = dm_env.StepType.FIRST
        elif dataset['timeouts'][t]:
          step_type = dm_env.StepType.LAST
          discount = 0.0
        else:
          step_type = dm_env.StepType.MID

        ts = dm_env.TimeStep(
            step_type=step_type,
            reward=dataset['rewards'][t],
            discount=discount,
            observation=np.concatenate([dataset['observations'][t],
                                        dataset['infos/goal'][t]]),
        )
        if t == 0 or dataset['timeouts'][t - 1]:
          adder.add_first(ts)  # pytype: disable=attribute-error
        else:
          adder.add(action=dataset['actions'][t-1], next_timestep=ts)  # pytype: disable=attribute-error

        if self._builder._config.local and t > 10_000:  # pytype: disable=attribute-error, pylint: disable=protected-access
          break

    iterator = self._builder.make_dataset_iterator(replay)

    dummy_seed = 1
    environment_spec = (
        self._environment_spec or
        specs.make_environment_spec(self._environment_factory(dummy_seed)))

    # Creates the networks to optimize (online) and target networks.
    networks = self._network_factory(environment_spec)

    if self._prefetch_size > 1:
      # When working with a single GPU we should prefetch to device for
      # efficiency. If running on TPU this isn't necessary as the computation
      # and input placement can be done automatically. For multi-GPU, the
      # current best solution is to prefetch to host, although this may change
      # in the future.
      device = jax.devices()[0] if self._device_prefetch else None
      iterator = utils.prefetch(
          iterator, buffer_size=self._prefetch_size, device=device)
    else:
      logging.info('Not prefetching the iterator.')

    counter = counting.Counter(counter, 'learner')

    learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                         counter)
    kwargs = {}
    if self._checkpointing_config:
      kwargs = vars(self._checkpointing_config)
    # Return the learning agent.
    return savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        **kwargs)
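Finally, a rough sketch of how factory methods like learner() and counter() from these examples are typically handed to Launchpad to form the distributed program. The structure loosely follows Acme's distributed layouts, but the exact node types and the assumed agent.replay() factory are illustrative and vary between versions; treat this as an assumption rather than the canonical wiring.

import jax
import launchpad as lp


def build_program(agent) -> lp.Program:
    program = lp.Program(name='agent')

    with program.group('counter'):
        counter = program.add_node(lp.CourierNode(agent.counter))

    with program.group('replay'):
        # Assumes the agent also exposes a replay() factory returning the
        # Reverb tables for this node.
        replay = program.add_node(lp.ReverbNode(agent.replay))

    with program.group('learner'):
        key = jax.random.PRNGKey(0)
        program.add_node(lp.CourierNode(agent.learner, key, replay, counter))

    return program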