def learner(
    self,
    random_key: networks_lib.PRNGKey,
    replay: reverb.Client,
    counter: counting.Counter,
):
  """The Learning part of the agent."""
  iterator = self._builder.make_dataset_iterator(replay)

  dummy_seed = 1
  environment_spec = (
      self._environment_spec or
      specs.make_environment_spec(self._environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = self._network_factory(environment_spec)

  if self._prefetch_size > 1:
    # When working with a single GPU we should prefetch to device for
    # efficiency. If running on TPU this isn't necessary as the computation
    # and input placement can be done automatically. For multi-GPU currently
    # the best solution is to prefetch to host, although this may change in
    # the future.
    device = jax.devices()[0] if self._device_prefetch else None
    iterator = utils.prefetch(
        iterator, buffer_size=self._prefetch_size, device=device)
  else:
    logging.info('Not prefetching the iterator.')

  counter = counting.Counter(counter, 'learner')

  learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                       counter)

  return savers.CheckpointingRunner(
      learner,
      key='learner',
      subdirectory='learner',
      time_delta_minutes=5,
      directory=self._checkpointing_config.directory,
      add_uid=self._checkpointing_config.add_uid,
      max_to_keep=self._checkpointing_config.max_to_keep)

def build_learner(
    random_key: networks_lib.PRNGKey,
    replay: reverb.Client,
    counter: Optional[counting.Counter] = None,
    primary_learner: Optional[core.Learner] = None,
):
  """The Learning part of the agent."""

  dummy_seed = 1
  spec = (
      experiment.environment_spec or
      specs.make_environment_spec(experiment.environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = experiment.network_factory(spec)

  iterator = experiment.builder.make_dataset_iterator(replay)
  # make_dataset_iterator is responsible for putting data onto appropriate
  # training devices, so here we apply prefetch so that data is copied over
  # in the background.
  iterator = utils.prefetch(iterable=iterator, buffer_size=1)
  counter = counting.Counter(counter, 'learner')
  learner = experiment.builder.make_learner(random_key, networks, iterator,
                                            experiment.logger_factory, spec,
                                            replay, counter)

  if primary_learner is None:
    learner = savers.CheckpointingRunner(
        learner,
        key='learner',
        subdirectory='learner',
        time_delta_minutes=5,
        directory=checkpointing_config.directory,
        add_uid=checkpointing_config.add_uid,
        max_to_keep=checkpointing_config.max_to_keep)
  else:
    learner.restore(primary_learner.save())
    # NOTE: This initially synchronizes secondary learner states with the
    # primary one. Further synchronization should be handled by the learner
    # properly doing a pmap/pmean on the loss/gradients, respectively.

  return learner

def build_learner(
    random_key: networks_lib.PRNGKey,
    counter: Optional[counting.Counter] = None,
):
  """The Learning part of the agent."""

  dummy_seed = 1
  spec = (
      experiment.environment_spec or
      specs.make_environment_spec(experiment.environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = experiment.network_factory(spec)

  dataset_key, random_key = jax.random.split(random_key)
  iterator = experiment.demonstration_dataset_factory(dataset_key)
  # make_demonstrations is responsible for putting data onto appropriate
  # training devices, so here we apply prefetch so that data is copied over
  # in the background.
  iterator = utils.prefetch(iterable=iterator, buffer_size=1)
  counter = counting.Counter(counter, 'learner')
  learner = experiment.builder.make_learner(
      random_key=random_key,
      networks=networks,
      dataset=iterator,
      logger_fn=experiment.logger_factory,
      environment_spec=spec,
      counter=counter)

  learner = savers.CheckpointingRunner(
      learner,
      key='learner',
      subdirectory='learner',
      time_delta_minutes=5,
      directory=checkpointing_config.directory,
      add_uid=checkpointing_config.add_uid,
      max_to_keep=checkpointing_config.max_to_keep)

  return learner

def _initialize_train(self):
  """Initialize train.

  This includes initializing the input pipeline and Byol's state.
  """
  self._train_input = acme_utils.prefetch(self._build_train_input())

  # Check we haven't already restored params.
  if self._byol_state is None:
    logging.info(
        'Initializing parameters rather than restoring from checkpoint.')

    # Initialize Byol and set up the optimizer state.
    inputs = next(self._train_input)
    init_byol = jax.pmap(self._make_initial_state, axis_name='i')

    # Init uses the same RNG key on all hosts+devices to ensure everyone
    # computes the same initial state and parameters.
    init_rng = jax.random.PRNGKey(self._random_seed)
    init_rng = helpers.bcast_local_devices(init_rng)

    self._byol_state = init_byol(rng=init_rng, dummy_input=inputs)

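# For context, a plausible sketch of what `helpers.bcast_local_devices` could
# look like. This is an assumption about the helper, not its actual
# implementation: it replicates a host value once per local device so the
# value can be fed to a pmapped function such as `init_byol` above, keeping
# every device on the same RNG key.
def bcast_local_devices_sketch(value):
  devices = jax.local_devices()
  # Place one copy of each leaf on every local device (leading axis = devices).
  return jax.tree_map(
      lambda v: jax.device_put_sharded([v] * len(devices), devices), value)
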
def learner(
    self,
    random_key,
    replay,
    counter,
):
  """The Learning part of the agent."""
  if self._builder._config.env_name.startswith('offline_ant'):  # pytype: disable=attribute-error, pylint: disable=protected-access
    adder = self._builder.make_adder(replay)
    env = self._environment_factory(0)
    dataset = env.get_dataset()  # pytype: disable=attribute-error
    for t in tqdm.trange(dataset['observations'].shape[0]):
      discount = 1.0
      if t == 0 or dataset['timeouts'][t - 1]:
        step_type = dm_env.StepType.FIRST
      elif dataset['timeouts'][t]:
        step_type = dm_env.StepType.LAST
        discount = 0.0
      else:
        step_type = dm_env.StepType.MID

      ts = dm_env.TimeStep(
          step_type=step_type,
          reward=dataset['rewards'][t],
          discount=discount,
          observation=np.concatenate([dataset['observations'][t],
                                      dataset['infos/goal'][t]]),
      )
      if t == 0 or dataset['timeouts'][t - 1]:
        adder.add_first(ts)  # pytype: disable=attribute-error
      else:
        adder.add(action=dataset['actions'][t - 1], next_timestep=ts)  # pytype: disable=attribute-error

      if self._builder._config.local and t > 10_000:  # pytype: disable=attribute-error, pylint: disable=protected-access
        break

  iterator = self._builder.make_dataset_iterator(replay)

  dummy_seed = 1
  environment_spec = (
      self._environment_spec or
      specs.make_environment_spec(self._environment_factory(dummy_seed)))

  # Creates the networks to optimize (online) and target networks.
  networks = self._network_factory(environment_spec)

  if self._prefetch_size > 1:
    # When working with a single GPU we should prefetch to device for
    # efficiency. If running on TPU this isn't necessary as the computation
    # and input placement can be done automatically. For multi-GPU currently
    # the best solution is to prefetch to host, although this may change in
    # the future.
    device = jax.devices()[0] if self._device_prefetch else None
    iterator = utils.prefetch(
        iterator, buffer_size=self._prefetch_size, device=device)
  else:
    logging.info('Not prefetching the iterator.')

  counter = counting.Counter(counter, 'learner')

  learner = self._builder.make_learner(random_key, networks, iterator, replay,
                                       counter)

  kwargs = {}
  if self._checkpointing_config:
    kwargs = vars(self._checkpointing_config)
  # Return the learning agent.
  return savers.CheckpointingRunner(
      learner,
      key='learner',
      subdirectory='learner',
      time_delta_minutes=5,
      **kwargs)

def __init__(self,
             network: networks.QNetwork,
             obs_spec: specs.Array,
             discount: float,
             importance_sampling_exponent: float,
             target_update_period: int,
             iterator: Iterator[reverb.ReplaySample],
             optimizer: optix.InitUpdate,
             rng: hk.PRNGSequence,
             max_abs_reward: float = 1.,
             huber_loss_parameter: float = 1.,
             replay_client: reverb.Client = None,
             counter: counting.Counter = None,
             logger: loggers.Logger = None):
  """Initializes the learner."""

  # Transform network into a pure function.
  network = hk.transform(network)

  def loss(params: hk.Params, target_params: hk.Params,
           sample: reverb.ReplaySample):
    o_tm1, a_tm1, r_t, d_t, o_t = sample.data
    keys, probs = sample.info[:2]

    # Forward pass.
    q_tm1 = network.apply(params, o_tm1)
    q_t_value = network.apply(target_params, o_t)
    q_t_selector = network.apply(params, o_t)

    # Cast and clip rewards.
    d_t = (d_t * discount).astype(jnp.float32)
    r_t = jnp.clip(r_t, -max_abs_reward, max_abs_reward).astype(jnp.float32)

    # Compute double Q-learning n-step TD-error.
    batch_error = jax.vmap(rlax.double_q_learning)
    td_error = batch_error(q_tm1, a_tm1, r_t, d_t, q_t_value, q_t_selector)
    batch_loss = rlax.huber_loss(td_error, huber_loss_parameter)

    # Importance weighting.
    importance_weights = (1. / probs).astype(jnp.float32)
    importance_weights **= importance_sampling_exponent
    importance_weights /= jnp.max(importance_weights)

    # Reweight.
    mean_loss = jnp.mean(importance_weights * batch_loss)  # []

    priorities = jnp.abs(td_error).astype(jnp.float64)

    return mean_loss, (keys, priorities)

  def sgd_step(
      state: TrainingState,
      samples: reverb.ReplaySample) -> Tuple[TrainingState, LearnerOutputs]:
    grad_fn = jax.grad(loss, has_aux=True)
    gradients, (keys, priorities) = grad_fn(state.params, state.target_params,
                                            samples)
    updates, new_opt_state = optimizer.update(gradients, state.opt_state)
    new_params = optix.apply_updates(state.params, updates)

    new_state = TrainingState(
        params=new_params,
        target_params=state.target_params,
        opt_state=new_opt_state,
        step=state.step + 1)

    outputs = LearnerOutputs(keys=keys, priorities=priorities)

    return new_state, outputs

  def update_priorities(outputs: LearnerOutputs):
    for key, priority in zip(outputs.keys, outputs.priorities):
      replay_client.mutate_priorities(
          table=adders.DEFAULT_PRIORITY_TABLE,
          updates={key: priority})

  # Internalise agent components (replay buffer, networks, optimizer).
  self._replay_client = replay_client
  self._iterator = utils.prefetch(iterator)

  # Internalise the hyperparameters.
  self._target_update_period = target_update_period

  # Internalise logging/counting objects.
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Initialise parameters and optimiser state.
  initial_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_target_params = network.init(
      next(rng), utils.add_batch_dim(utils.zeros_like(obs_spec)))
  initial_opt_state = optimizer.init(initial_params)

  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=initial_opt_state,
      step=0)

  self._forward = jax.jit(network.apply)
  self._sgd_step = jax.jit(sgd_step)
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)

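# For reference, a minimal sketch of the per-transition double Q-learning
# TD error that the vmapped rlax.double_q_learning call above computes. This
# is an illustrative re-derivation, not the rlax implementation, and the
# function name is ours; it assumes `jnp` is jax.numpy as in the code above.
def double_q_learning_td_error_sketch(q_tm1, a_tm1, r_t, d_t, q_t_value,
                                      q_t_selector):
  # The online network (q_t_selector) picks the next action while the target
  # network (q_t_value) evaluates it; decoupling selection from evaluation
  # reduces the overestimation bias of vanilla Q-learning.
  a_t = jnp.argmax(q_t_selector)
  target = r_t + d_t * q_t_value[a_t]
  return target - q_tm1[a_tm1]
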
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             loss_fn: LossFn,
             optimizer: optax.GradientTransformation,
             data_iterator: Iterator[reverb.ReplaySample],
             target_update_period: int,
             random_key: networks_lib.PRNGKey,
             replay_client: Optional[reverb.Client] = None,
             replay_table_name: str = adders.DEFAULT_PRIORITY_TABLE,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None,
             num_sgd_steps_per_step: int = 1):
  """Initialize the SGD learner."""
  self.network = network

  # Internalize the loss_fn with network.
  self._loss = jax.jit(functools.partial(loss_fn, self.network))

  # SGD performs the loss, optimizer update and periodic target net update.
  def sgd_step(
      state: TrainingState,
      batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
    next_rng_key, rng_key = jax.random.split(state.rng_key)
    # Implements one SGD step of the loss and updates training state.
    (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
        state.params, state.target_params, batch, rng_key)
    extra.metrics.update({'total_loss': loss})

    # Apply the optimizer updates.
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, target_update_period)

    new_training_state = TrainingState(new_params, target_params,
                                       new_opt_state, steps, next_rng_key)
    return new_training_state, extra

  def postprocess_aux(extra: LossExtra) -> LossExtra:
    reverb_update = jax.tree_map(
        lambda a: jnp.reshape(a, (-1, *a.shape[2:])), extra.reverb_update)
    return extra._replace(
        metrics=jax.tree_map(jnp.mean, extra.metrics),
        reverb_update=reverb_update)

  self._num_sgd_steps_per_step = num_sgd_steps_per_step
  sgd_step = utils.process_multiple_batches(sgd_step, num_sgd_steps_per_step,
                                            postprocess_aux)
  self._sgd_step = jax.jit(sgd_step)

  # Internalise agent components.
  self._data_iterator = utils.prefetch(data_iterator)
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Do not record timestamps until after the first learning step is done.
  # This is to avoid including the time it takes for actors to come online
  # and fill the replay buffer.
  self._timestamp = None

  # Initialize the network parameters.
  key_params, key_target, key_state = jax.random.split(random_key, 3)
  initial_params = self.network.init(key_params)
  initial_target_params = self.network.init(key_target)
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=optimizer.init(initial_params),
      steps=0,
      rng_key=key_state,
  )

  # Update replay priorities.
  def update_priorities(reverb_update: ReverbUpdate) -> None:
    if replay_client is None:
      return
    keys, priorities = tree.map_structure(
        utils.fetch_devicearray,
        (reverb_update.keys, reverb_update.priorities))
    replay_client.mutate_priorities(
        table=replay_table_name,
        updates=dict(zip(keys, priorities)))

  self._replay_client = replay_client
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)

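# A minimal sketch of a `loss_fn` compatible with this learner, inferred from
# how it is used above: the learner partially applies it with the network and
# then calls it as (loss, extra) = loss_fn(network, params, target_params,
# batch, key). The transition field names and the assumption that `LossExtra`
# can be constructed from just a metrics dict are illustrative, not taken from
# the source.
def example_td_loss(network: networks_lib.FeedForwardNetwork,
                    params,
                    target_params,
                    batch: reverb.ReplaySample,
                    key: networks_lib.PRNGKey):
  del key  # This example loss is deterministic.
  transitions = batch.data  # Assumed to carry standard transition fields.
  q_tm1 = network.apply(params, transitions.observation)
  q_t = network.apply(target_params, transitions.next_observation)
  # One-step TD error with a max backup over the target network.
  batch_indices = jnp.arange(q_tm1.shape[0])
  td_error = (transitions.reward +
              transitions.discount * jnp.max(q_t, axis=-1) -
              q_tm1[batch_indices, transitions.action])
  loss = jnp.mean(jnp.square(td_error))
  extra = LossExtra(metrics={'mean_abs_td_error': jnp.mean(jnp.abs(td_error))})
  return loss, extra
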
def run_experiment(experiment: config.ExperimentConfig,
                   eval_every: int = 100,
                   num_eval_episodes: int = 1):
  """Runs a simple, single-threaded training loop using the default evaluators.

  It favors simplicity of the code, so only the basic features of the
  ExperimentConfig are supported.

  Arguments:
    experiment: Definition and configuration of the agent to run.
    eval_every: After how many actor steps to perform evaluation.
    num_eval_episodes: How many evaluation episodes to execute at each
      evaluation step.
  """

  key = jax.random.PRNGKey(experiment.seed)

  # Create the environment and get its spec.
  environment = experiment.environment_factory(experiment.seed)
  environment_spec = experiment.environment_spec or specs.make_environment_spec(
      environment)

  # Create the networks and policy.
  networks = experiment.network_factory(environment_spec)
  policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=False)

  # Create the replay server and grab its address.
  replay_tables = experiment.builder.make_replay_tables(environment_spec,
                                                        policy)

  # Disable blocking of inserts by tables' rate limiters, as this function
  # runs learning (sampling from the table) and data generation (inserting
  # into the table) sequentially from the same thread, so a blocked insert
  # would make the algorithm hang.
  replay_tables, rate_limiters_max_diff = _disable_insert_blocking(
      replay_tables)

  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')

  # The parent counter shares step counts between the train and eval loops and
  # the learner, so that it is possible to plot, for example, the evaluator's
  # return as a function of the number of training episodes.
  parent_counter = counting.Counter(time_delta=0.)

  # Create dataset, learner, and actor for storing, consuming, and generating
  # data respectively.
  dataset = experiment.builder.make_dataset_iterator(replay_client)
  # We always use prefetch, as it provides an iterator with an additional
  # 'ready' method.
  dataset = utils.prefetch(dataset, buffer_size=1)
  learner_key, key = jax.random.split(key)
  learner = experiment.builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      logger_fn=experiment.logger_factory,
      environment_spec=environment_spec,
      replay_client=replay_client,
      counter=counting.Counter(parent_counter, prefix='learner',
                               time_delta=0.))

  adder = experiment.builder.make_adder(replay_client, environment_spec,
                                        policy)

  actor_key, key = jax.random.split(key)
  actor = experiment.builder.make_actor(
      actor_key, policy, environment_spec, variable_source=learner,
      adder=adder)

  # Create the environment loop used for training.
  train_counter = counting.Counter(
      parent_counter, prefix='train', time_delta=0.)
  train_logger = experiment.logger_factory('train',
                                           train_counter.get_steps_key(), 0)

  # Replace the actor with a _LearningActor. This makes sure that every time
  # `update` is called on the actor it checks whether there is any new data to
  # learn from and, if so, runs a learner step. The rate at which new data is
  # released is controlled by the replay table's rate_limiter, which is created
  # by the builder.make_replay_tables call above.
  actor = _LearningActor(actor, learner, dataset, replay_tables,
                         rate_limiters_max_diff)

  train_loop = acme.EnvironmentLoop(
      environment,
      actor,
      counter=train_counter,
      logger=train_logger,
      observers=experiment.observers)

  if num_eval_episodes == 0:
    # No evaluation. Just run the training loop.
    train_loop.run(num_steps=experiment.max_num_actor_steps)
    return

  # Create the evaluation actor and loop.
  eval_counter = counting.Counter(
      parent_counter, prefix='eval', time_delta=0.)
  eval_logger = experiment.logger_factory('eval',
                                          eval_counter.get_steps_key(), 0)
  eval_policy = config.make_policy(
      experiment=experiment,
      networks=networks,
      environment_spec=environment_spec,
      evaluation=True)
  eval_actor = experiment.builder.make_actor(
      random_key=jax.random.PRNGKey(experiment.seed),
      policy=eval_policy,
      environment_spec=environment_spec,
      variable_source=learner)
  eval_loop = acme.EnvironmentLoop(
      environment,
      eval_actor,
      counter=eval_counter,
      logger=eval_logger,
      observers=experiment.observers)

  steps = 0
  while steps < experiment.max_num_actor_steps:
    eval_loop.run(num_episodes=num_eval_episodes)
    steps += train_loop.run(num_steps=eval_every)
  eval_loop.run(num_episodes=num_eval_episodes)

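# A minimal sketch of how run_experiment might be driven. The builder and the
# environment/network factories passed in are placeholders for illustration;
# only the ExperimentConfig fields read by run_experiment above (builder,
# environment_factory, network_factory, seed, max_num_actor_steps) are set,
# assuming the remaining fields have usable defaults.
def example_train(builder, environment_factory, network_factory,
                  seed: int = 0, num_steps: int = 100_000):
  experiment = config.ExperimentConfig(
      builder=builder,
      environment_factory=environment_factory,
      network_factory=network_factory,
      seed=seed,
      max_num_actor_steps=num_steps)
  # Evaluate every 10k actor steps with 5 evaluation episodes.
  run_experiment(experiment, eval_every=10_000, num_eval_episodes=5)
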
def __init__(self,
             network: networks_lib.FeedForwardNetwork,
             obs_spec: specs.Array,
             loss_fn: LossFn,
             optimizer: optax.GradientTransformation,
             data_iterator: Iterator[reverb.ReplaySample],
             target_update_period: int,
             random_key: networks_lib.PRNGKey,
             replay_client: Optional[reverb.Client] = None,
             counter: Optional[counting.Counter] = None,
             logger: Optional[loggers.Logger] = None):
  """Initialize the SGD learner."""
  self.network = network

  # Internalize the loss_fn with network.
  self._loss = jax.jit(functools.partial(loss_fn, self.network))

  # SGD performs the loss, optimizer update and periodic target net update.
  def sgd_step(
      state: TrainingState,
      batch: reverb.ReplaySample) -> Tuple[TrainingState, LossExtra]:
    next_rng_key, rng_key = jax.random.split(state.rng_key)
    # Implements one SGD step of the loss and updates training state.
    (loss, extra), grads = jax.value_and_grad(self._loss, has_aux=True)(
        state.params, state.target_params, batch, rng_key)
    extra.metrics.update({'total_loss': loss})

    # Apply the optimizer updates.
    updates, new_opt_state = optimizer.update(grads, state.opt_state)
    new_params = optax.apply_updates(state.params, updates)

    # Periodically update target networks.
    steps = state.steps + 1
    target_params = rlax.periodic_update(new_params, state.target_params,
                                         steps, target_update_period)

    new_training_state = TrainingState(new_params, target_params,
                                       new_opt_state, steps, next_rng_key)
    return new_training_state, extra

  self._sgd_step = jax.jit(sgd_step)

  # Internalise agent components.
  self._data_iterator = utils.prefetch(data_iterator)
  self._target_update_period = target_update_period
  self._counter = counter or counting.Counter()
  self._logger = logger or loggers.TerminalLogger('learner', time_delta=1.)

  # Initialize the network parameters.
  dummy_obs = utils.add_batch_dim(utils.zeros_like(obs_spec))
  key_params, key_target, key_state = jax.random.split(random_key, 3)
  initial_params = self.network.init(key_params, dummy_obs)
  initial_target_params = self.network.init(key_target, dummy_obs)
  self._state = TrainingState(
      params=initial_params,
      target_params=initial_target_params,
      opt_state=optimizer.init(initial_params),
      steps=0,
      rng_key=key_state,
  )

  # Update replay priorities.
  def update_priorities(reverb_update: Optional[ReverbUpdate]) -> None:
    if reverb_update is None or replay_client is None:
      return
    replay_client.mutate_priorities(
        table=adders.DEFAULT_PRIORITY_TABLE,
        updates=dict(zip(reverb_update.keys, reverb_update.priorities)))

  self._replay_client = replay_client
  self._async_priority_updater = async_utils.AsyncExecutor(update_priorities)

def __init__(
    self,
    seed: int,
    environment_spec: specs.EnvironmentSpec,
    builder: builders.GenericActorLearnerBuilder,
    networks: Any,
    policy_network: Any,
    workdir: Optional[str] = '~/acme',
    min_replay_size: int = 1000,
    samples_per_insert: float = 256.0,
    batch_size: int = 256,
    num_sgd_steps_per_step: int = 1,
    prefetch_size: int = 1,
    device_prefetch: bool = True,
    counter: Optional[counting.Counter] = None,
    checkpoint: bool = True,
):
  """Initialize the agent.

  Args:
    seed: A random seed to use for this layout instance.
    environment_spec: description of the actions, observations, etc.
    builder: builder defining an RL algorithm to train.
    networks: network objects to be passed to the learner.
    policy_network: function that given an observation returns actions.
    workdir: if provided saves the state of the learner and the counter
      (if the counter is not None) into workdir.
    min_replay_size: minimum replay size before updating.
    samples_per_insert: number of samples to take from replay for every
      insert that is made.
    batch_size: batch size for updates.
    num_sgd_steps_per_step: how many sgd steps a learner does per 'step'
      call. For performance reasons (especially to reduce TPU host-device
      transfer times) it is performance-beneficial to do multiple sgd updates
      at once, provided that it does not hurt the training, which needs to be
      verified empirically for each environment.
    prefetch_size: size of the prefetch buffer for the dataset iterator;
      values greater than 1 enable prefetching.
    device_prefetch: whether prefetching should happen to a device.
    counter: counter object used to keep track of steps.
    checkpoint: boolean indicating whether to checkpoint the learner and the
      counter (if the counter is not None).
  """
  if prefetch_size < 0:
    raise ValueError(f'Prefetch size={prefetch_size} should be non-negative')

  key = jax.random.PRNGKey(seed)

  # Create the replay server and grab its address.
  replay_tables = builder.make_replay_tables(environment_spec)
  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')

  # Create actor, dataset, and learner for generating, storing, and consuming
  # data respectively.
  adder = builder.make_adder(replay_client)

  def _is_reverb_queue(reverb_table: reverb.Table,
                       reverb_client: reverb.Client) -> bool:
    """Returns True iff the Reverb Table is actually a queue."""
    # TODO(sinopalnikov): make it more generic and check for a table that
    # needs special handling on update.
    info = reverb_client.server_info()
    table_info = info[reverb_table.name]
    is_queue = (
        table_info.max_times_sampled == 1 and
        table_info.sampler_options.fifo and
        table_info.remover_options.fifo)
    return is_queue

  is_reverb_queue = any(
      _is_reverb_queue(table, replay_client) for table in replay_tables)

  dataset = builder.make_dataset_iterator(replay_client)
  if prefetch_size > 1:
    device = jax.devices()[0] if device_prefetch else None
    dataset = utils.prefetch(dataset, buffer_size=prefetch_size, device=device)
  learner_key, key = jax.random.split(key)
  learner = builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      replay_client=replay_client,
      counter=counter)

  if not checkpoint or workdir is None:
    self._checkpointer = None
  else:
    objects_to_save = {'learner': learner}
    if counter is not None:
      objects_to_save.update({'counter': counter})
    self._checkpointer = savers.Checkpointer(
        objects_to_save,
        time_delta_minutes=30,
        subdirectory='learner',
        directory=workdir,
        add_uid=(workdir == '~/acme'))

  actor_key, key = jax.random.split(key)
  actor = builder.make_actor(
      actor_key, policy_network, adder, variable_source=learner)

  self._custom_update_fn = None
  if is_reverb_queue:
    # A Reverb queue requires special handling on update: custom logic to
    # decide when it is safe to make a learner step. This is only needed for
    # the local agent, where the actor and the learner run synchronously and
    # the learner will deadlock if it makes a step with no data available.
    def custom_update():
      should_update_actor = False
      # Run a number of learner steps (usually gradient steps).
      # TODO(raveman): This is wrong. When running multi-level learners,
      # different levels might have different batch sizes. Find a solution.
      while all(table.can_sample(batch_size) for table in replay_tables):
        learner.step()
        should_update_actor = True

      if should_update_actor:
        # "wait=True" to make it more on-policy.
        actor.update(wait=True)

    self._custom_update_fn = custom_update

  effective_batch_size = batch_size * num_sgd_steps_per_step
  super().__init__(
      actor=actor,
      learner=learner,
      min_observations=max(effective_batch_size, min_replay_size),
      observations_per_step=float(effective_batch_size) / samples_per_insert)

  # Save the replay so we don't garbage collect it.
  self._replay_server = replay_server

def _initialize_train(self, rng):
  """BYOL's _ExperimentState initialization.

  Args:
    rng: random number generator used to initialize parameters. If working
      in a multi device setup, this needs to be a ShardedArray.

  Raises:
    RuntimeError: invalid or empty checkpoint.
    ValueError: no checkpoint specified while `allow_train_from_scratch` is
      False.
  """
  self._train_input = acme_utils.prefetch(self._build_train_input())

  # Check we haven't already restored params.
  if self._experiment_state is None:

    inputs = next(self._train_input)

    if self._checkpoint_to_evaluate is not None:
      # Load params from checkpoint.
      checkpoint_data = checkpointing.load_checkpoint(
          self._checkpoint_to_evaluate)
      if checkpoint_data is None:
        raise RuntimeError('Invalid checkpoint.')
      backbone_params = checkpoint_data['experiment_state'].online_params
      backbone_state = checkpoint_data['experiment_state'].online_state
      backbone_params = helpers.bcast_local_devices(backbone_params)
      backbone_state = helpers.bcast_local_devices(backbone_state)
    else:
      if not self._allow_train_from_scratch:
        raise ValueError(
            'No checkpoint specified, but `allow_train_from_scratch` '
            'set to False')
      # Initialize with random parameters.
      logging.info(
          'No checkpoint specified, initializing the networks from scratch '
          '(dry run mode)')
      backbone_params, backbone_state = jax.pmap(
          functools.partial(self.forward_backbone.init, is_training=True),
          axis_name='i')(rng=rng, inputs=inputs)

    init_experiment = jax.pmap(self._make_initial_state, axis_name='i')

    # Init uses the same RNG key on all hosts+devices to ensure everyone
    # computes the same initial state and parameters.
    init_rng = jax.random.PRNGKey(self._random_seed)
    init_rng = helpers.bcast_local_devices(init_rng)

    self._experiment_state = init_experiment(
        rng=init_rng,
        dummy_input=inputs,
        backbone_params=backbone_params,
        backbone_state=backbone_state)

    # Clear the backbone optimizer's state when the backbone is frozen.
    if self._freeze_backbone:
      self._experiment_state = _EvalExperimentState(
          backbone_params=self._experiment_state.backbone_params,
          classif_params=self._experiment_state.classif_params,
          backbone_state=self._experiment_state.backbone_state,
          backbone_opt_state=None,
          classif_opt_state=self._experiment_state.classif_opt_state,
      )

def __init__(
    self,
    seed: int,
    environment_spec: specs.EnvironmentSpec,
    builder: builders.ActorLearnerBuilder,
    networks: Any,
    policy_network: Any,
    learner_logger: Optional[loggers.Logger] = None,
    workdir: Optional[str] = '~/acme',
    batch_size: int = 256,
    num_sgd_steps_per_step: int = 1,
    prefetch_size: int = 1,
    counter: Optional[counting.Counter] = None,
    checkpoint: bool = True,
):
  """Initialize the agent.

  Args:
    seed: A random seed to use for this layout instance.
    environment_spec: description of the actions, observations, etc.
    builder: builder defining an RL algorithm to train.
    networks: network objects to be passed to the learner.
    policy_network: function that given an observation returns actions.
    learner_logger: logger used by the learner.
    workdir: if provided saves the state of the learner and the counter
      (if the counter is not None) into workdir.
    batch_size: batch size for updates.
    num_sgd_steps_per_step: how many sgd steps a learner does per 'step'
      call. For performance reasons (especially to reduce TPU host-device
      transfer times) it is performance-beneficial to do multiple sgd updates
      at once, provided that it does not hurt the training, which needs to be
      verified empirically for each environment.
    prefetch_size: size of the prefetch buffer for the dataset iterator.
    counter: counter object used to keep track of steps.
    checkpoint: boolean indicating whether to checkpoint the learner and the
      counter (if the counter is not None).
  """
  if prefetch_size < 0:
    raise ValueError(f'Prefetch size={prefetch_size} should be non-negative')

  key = jax.random.PRNGKey(seed)

  # Create the replay server and grab its address.
  replay_tables = builder.make_replay_tables(environment_spec, policy_network)

  # Disable blocking of inserts by tables' rate limiters, as LocalLayout
  # agents run inserts and sampling from the same thread and a blocked insert
  # would result in a hang.
  new_tables = []
  for table in replay_tables:
    rl_info = table.info.rate_limiter_info
    rate_limiter = reverb.rate_limiters.RateLimiter(
        samples_per_insert=rl_info.samples_per_insert,
        min_size_to_sample=rl_info.min_size_to_sample,
        min_diff=rl_info.min_diff,
        max_diff=sys.float_info.max)
    new_tables.append(table.replace(rate_limiter=rate_limiter))
  replay_tables = new_tables

  replay_server = reverb.Server(replay_tables, port=None)
  replay_client = reverb.Client(f'localhost:{replay_server.port}')

  # Create actor, dataset, and learner for generating, storing, and consuming
  # data respectively.
  adder = builder.make_adder(replay_client, environment_spec, policy_network)

  dataset = builder.make_dataset_iterator(replay_client)
  # We always use prefetch, as it provides an iterator with an additional
  # 'ready' method.
  dataset = utils.prefetch(dataset, buffer_size=prefetch_size)
  learner_key, key = jax.random.split(key)
  learner = builder.make_learner(
      random_key=learner_key,
      networks=networks,
      dataset=dataset,
      logger_fn=(
          lambda label, steps_key=None, task_instance=None: learner_logger),
      environment_spec=environment_spec,
      replay_client=replay_client,
      counter=counter)

  if not checkpoint or workdir is None:
    self._checkpointer = None
  else:
    objects_to_save = {'learner': learner}
    if counter is not None:
      objects_to_save.update({'counter': counter})
    self._checkpointer = savers.Checkpointer(
        objects_to_save,
        time_delta_minutes=30,
        subdirectory='learner',
        directory=workdir,
        add_uid=(workdir == '~/acme'))

  actor_key, key = jax.random.split(key)
  actor = builder.make_actor(
      actor_key,
      policy_network,
      environment_spec,
      variable_source=learner,
      adder=adder)

  super().__init__(
      actor=actor,
      learner=learner,
      iterator=dataset,
      replay_tables=replay_tables)

  # Save the replay so we don't garbage collect it.
  self._replay_server = replay_server