def learner(
    self,
    random_key: networks_lib.PRNGKey,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  # Counter and logger.
  counter = counting.Counter(counter, 'learner')
  logger = loggers.make_default_logger(
      'learner',
      self._save_logs,
      time_delta=self._log_every,
      asynchronous=True,
      serialize_fn=utils.fetch_devicearray,
      steps_key='learner_steps')

  # Create the learner.
  networks = self._network_factory()
  learner = self._make_learner(random_key, networks, counter, logger)

  kwargs = {
      'directory': self._workdir,
      'add_uid': self._workdir == '~/acme'
  }
  # Return the learning agent.
  return savers.CheckpointingRunner(
      learner,
      subdirectory='learner',
      time_delta_minutes=5,
      **kwargs)
def counter(self):
  """The counter node of the agent, checkpointed periodically."""
  kwargs = {
      'directory': self._workdir,
      'add_uid': self._workdir == '~/acme'
  }
  return savers.CheckpointingRunner(
      counting.Counter(),
      subdirectory='counter',
      time_delta_minutes=5,
      **kwargs)
def counter(self):
  """The counter node of the agent, checkpointed periodically."""
  return savers.CheckpointingRunner(
      counting.Counter(),
      key='counter',
      subdirectory='counter',
      time_delta_minutes=5,
      directory=self._checkpointing_config.directory,
      add_uid=self._checkpointing_config.add_uid,
      max_to_keep=self._checkpointing_config.max_to_keep)
def counter(self):
  """The counter node of the agent, checkpointed periodically."""
  kwargs = {}
  if self._checkpointing_config:
    kwargs = vars(self._checkpointing_config)
  return savers.CheckpointingRunner(
      counting.Counter(),
      key='counter',
      subdirectory='counter',
      time_delta_minutes=5,
      **kwargs)
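# Illustrative sketch only (not part of the snippets above): a minimal config
# object whose vars() yields exactly the keyword arguments that
# savers.CheckpointingRunner consumes in this section. The field names mirror
# the attributes accessed elsewhere here (directory, add_uid, max_to_keep);
# the real checkpointing config in the codebase may define additional fields.
import dataclasses


@dataclasses.dataclass
class _SketchCheckpointingConfig:
  directory: str = '~/acme'
  add_uid: bool = True
  max_to_keep: int = 1


# vars() on a (non-slots) dataclass instance returns its field dict, so
# **vars(config) expands to directory=..., add_uid=..., max_to_keep=...
# just like the explicit keyword arguments passed in the other variants.
assert vars(_SketchCheckpointingConfig()) == {
    'directory': '~/acme', 'add_uid': True, 'max_to_keep': 1}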
def learner(
    self,
    random_key: networks_lib.PRNGKey,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  iterator = self._builder.make_dataset_iterator(replay)

  dummy_seed = 1
  environment_spec = (
      self._environment_spec or
      specs.make_environment_spec(self._environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = self._network_factory(environment_spec)

  if self._prefetch_size > 1:
    # When working with single GPU we should prefetch to device for
    # efficiency. If running on TPU this isn't necessary as the computation
    # and input placement can be done automatically. For multi-gpu currently
    # the best solution is to pre-fetch to host although this may change in
    # the future.
    device = jax.devices()[0] if self._device_prefetch else None
    iterator = utils.prefetch(
        iterator, buffer_size=self._prefetch_size, device=device)
  else:
    logging.info('Not prefetching the iterator.')

  counter = counting.Counter(counter, 'learner')
  learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                       counter)

  return savers.CheckpointingRunner(
      learner,
      key='learner',
      subdirectory='learner',
      time_delta_minutes=5,
      directory=self._checkpointing_config.directory,
      add_uid=self._checkpointing_config.add_uid,
      max_to_keep=self._checkpointing_config.max_to_keep)
def build_learner(
    random_key: networks_lib.PRNGKey,
    replay: reverb.Client,
    counter: Optional[counting.Counter] = None,
    primary_learner: Optional[core.Learner] = None,
):
  """The Learning part of the agent."""

  dummy_seed = 1
  spec = (
      experiment.environment_spec or
      specs.make_environment_spec(experiment.environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = experiment.network_factory(spec)

  iterator = experiment.builder.make_dataset_iterator(replay)
  # make_dataset_iterator is responsible for putting data onto appropriate
  # training devices, so here we apply prefetch, so that data is copied over
  # in the background.
  iterator = utils.prefetch(iterable=iterator, buffer_size=1)
  counter = counting.Counter(counter, 'learner')
  learner = experiment.builder.make_learner(random_key, networks, iterator,
                                            experiment.logger_factory, spec,
                                            replay, counter)

  if primary_learner is None:
    learner = savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        directory=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid,
        max_to_keep=checkpointing_config.max_to_keep)
  else:
    learner.restore(primary_learner.save())
    # NOTE: This initially synchronizes secondary learner states with the
    # primary one. Further synchronization should be handled by the learner
    # properly doing a pmap/pmean on the loss/gradients, respectively.

  return learner
def build_learner(
    random_key: networks_lib.PRNGKey,
    counter: Optional[counting.Counter] = None,
):
  """The Learning part of the agent."""

  dummy_seed = 1
  spec = (
      experiment.environment_spec or
      specs.make_environment_spec(experiment.environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = experiment.network_factory(spec)

  dataset_key, random_key = jax.random.split(random_key)
  iterator = experiment.demonstration_dataset_factory(dataset_key)
  # make_demonstrations is responsible for putting data onto appropriate
  # training devices, so here we apply prefetch, so that data is copied over
  # in the background.
  iterator = utils.prefetch(iterable=iterator, buffer_size=1)
  counter = counting.Counter(counter, 'learner')
  learner = experiment.builder.make_learner(
      random_key=random_key,
      networks=networks,
      dataset=iterator,
      logger_fn=experiment.logger_factory,
      environment_spec=spec,
      counter=counter)

  learner = savers.CheckpointingRunner(
      learner,
      key='learner',
      subdirectory='learner',
      time_delta_minutes=5,
      directory=checkpointing_config.directory,
      add_uid=checkpointing_config.add_uid,
      max_to_keep=checkpointing_config.max_to_keep)

  return learner
def learner(
    self,
    random_key,
    replay,
    counter,
):
  """The Learning part of the agent."""
  if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
    # Offline variant: populate the replay buffer with the environment's
    # offline dataset before training starts. Episode boundaries are derived
    # from the 'timeouts' flags, and the goal is appended to each observation.
    adder = self._builder.make_adder(replay)
    env = self._environment_factory(0)
    dataset = env.get_dataset()  # pytype: disable=attribute-error
    for t in tqdm.trange(dataset['observations'].shape[0]):
      discount = 1.0
      if t == 0 or dataset['timeouts'][t - 1]:
        step_type = dm_env.StepType.FIRST
      elif dataset['timeouts'][t]:
        step_type = dm_env.StepType.LAST
        discount = 0.0
      else:
        step_type = dm_env.StepType.MID

      ts = dm_env.TimeStep(
          step_type=step_type,
          reward=dataset['rewards'][t],
          discount=discount,
          observation=np.concatenate([dataset['observations'][t],
                                      dataset['infos/goal'][t]]),
      )
      if t == 0 or dataset['timeouts'][t - 1]:
        adder.add_first(ts)  # pytype: disable=attribute-error
      else:
        adder.add(action=dataset['actions'][t - 1], next_timestep=ts)  # pytype: disable=attribute-error

      if self._builder._config.local and t > 10_000:  # pytype: disable=attribute-error, pylint: disable=protected-access
        break

  iterator = self._builder.make_dataset_iterator(replay)

  dummy_seed = 1
  environment_spec = (
      self._environment_spec or
      specs.make_environment_spec(self._environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = self._network_factory(environment_spec)

  if self._prefetch_size > 1:
    # When working with single GPU we should prefetch to device for
    # efficiency. If running on TPU this isn't necessary as the computation
    # and input placement can be done automatically. For multi-gpu currently
    # the best solution is to pre-fetch to host although this may change in
    # the future.
    device = jax.devices()[0] if self._device_prefetch else None
    iterator = utils.prefetch(
        iterator, buffer_size=self._prefetch_size, device=device)
  else:
    logging.info('Not prefetching the iterator.')

  counter = counting.Counter(counter, 'learner')
  learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                       counter)

  kwargs = {}
  if self._checkpointing_config:
    kwargs = vars(self._checkpointing_config)
  # Return the learning agent.
  return savers.CheckpointingRunner(
      learner,
      key='learner',
      subdirectory='learner',
      time_delta_minutes=5,
      **kwargs)