def _build_networks_and_optimizer(self):
  """Builds the SAC network, its target copy, and the alpha optimizer."""
  self._rng, key = jax.random.split(self._rng)
  # Passing `key` twice is harmless here: the second (action-selection) key
  # is consumed only for shape inference during initialization.
  self.network_params = self.network_def.init(key, self.state, key)
  self.network_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.optimizer_state = self.network_optimizer.init(self.network_params)
  # TODO(joshgreaves): Find a way to just copy the critic params.
  # The target network starts as an exact copy of the online parameters.
  self.target_params = self.network_params
  # Entropy temperature (\alpha) is a learned scalar with its own optimizer.
  self.log_alpha = jnp.zeros(1)
  self.alpha_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.alpha_optimizer_state = self.alpha_optimizer.init(self.log_alpha)
def _build_networks_and_optimizer(self):
  """Initializes the online/target networks and the optimizer state."""
  self._rng, init_key = jax.random.split(self._rng)
  self.online_params = self.network_def.init(
      init_key, x=self.state, support=self._support)
  self.optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.optimizer_state = self.optimizer.init(self.online_params)
  # Target parameters begin as a copy of the online parameters.
  self.target_network_params = self.online_params
def _build_networks_and_optimizer(self):
  """Initializes network parameters and wraps them in a flax optimizer."""
  self._rng, init_key = jax.random.split(self._rng)
  params = self.network_def.init(init_key, x=self.state, rng=self._rng)
  # Legacy flax.optim API: the optimizer definition owns the parameters.
  optimizer_def = dqn_agent.create_optimizer(self._optimizer_name)
  self.optimizer = optimizer_def.create(params)
  self.target_network_params = copy.deepcopy(params)
def unbundle(self, checkpoint_dir, iteration_number, bundle_dictionary):
  """Restores the agent from a checkpoint.

  Restores the agent's Python objects to those specified in
  bundle_dictionary, and restores the TensorFlow objects to those specified
  in the checkpoint_dir. If the checkpoint_dir does not exist, will not
  reset the agent's state.

  Args:
    checkpoint_dir: str, path to the checkpoint saved.
    iteration_number: int, checkpoint version, used when restoring the
      replay buffer.
    bundle_dictionary: dict, containing additional Python objects owned by
      the agent.

  Returns:
    bool, True if unbundling was successful.
  """
  try:
    # `load` raises NotFoundError when any required replay file is missing.
    self._replay.load(checkpoint_dir, iteration_number)
  except tf.errors.NotFoundError:
    if not self.allow_partial_reload:
      # Partial reloads are disallowed: a missing replay buffer is fatal.
      return False
    logging.warning('Unable to reload replay buffer!')
  if bundle_dictionary is None:
    if not self.allow_partial_reload:
      return False
    logging.warning("Unable to reload the agent's parameters!")
    return True
  # Restore the agent's Python-side state from the bundle.
  self.state = bundle_dictionary['state']
  self.training_steps = bundle_dictionary['training_steps']
  self.network_params = bundle_dictionary['network_params']
  # Optimizers are recreated fresh; only their states come from the bundle.
  self.network_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.optimizer_state = bundle_dictionary['optimizer_state']
  self.target_params = bundle_dictionary['target_params']
  self.log_alpha = bundle_dictionary['log_alpha']
  self.alpha_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.alpha_optimizer_state = bundle_dictionary['alpha_optimizer_state']
  return True
def _build_networks_and_optimizer(self):
  """Builds the paired active and passive networks with their optimizers."""
  self._rng, key_active, key_passive = jax.random.split(self._rng, 3)
  # Active network: trained online, with its own target copy.
  self.active_online_params = self.network_def.init(key_active, x=self.state)
  self.active_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.active_optimizer_state = self.active_optimizer.init(
      self.active_online_params)
  self.active_target_params = self.active_online_params
  # Passive network: same architecture, independently initialized.
  self.passive_online_params = self.network_def.init(
      key_passive, x=self.state)
  self.passive_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.passive_optimizer_state = self.passive_optimizer.init(
      self.passive_online_params)
  self.passive_target_params = self.passive_online_params
def _build_networks_and_optimizer(self):
  """Builds the encoder, SAC network, reward/dynamics models, and alpha."""
  keys = jax.random.split(self._rng, num=5)
  self._rng, key_encoder, key_network, key_reward, key_dynamics = keys
  # One optimizer per trainable component, all of the same flavor.
  self.network_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.encoder_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.reward_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.dynamics_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.alpha_optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  # Encoder network.
  self.encoder_params = self.encoder_network_def.init(
      key_encoder, self.state)
  self.encoder_optimizer_state = self.encoder_optimizer.init(
      self.encoder_params)
  # A dummy latent produced by the encoder fixes the SAC network's input
  # shape for initialization.
  sample_z = jnp.zeros_like(
      self.encoder_network_def.apply(self.encoder_params,
                                     self.state).critic_z)
  # `key_network` is passed twice: the second (action-selection) key is
  # used only for shape inference during initialization.
  self.network_params = self.network_def.init(
      key_network, sample_z, key_network)
  self.optimizer_state = self.network_optimizer.init(self.network_params)
  # Reward and dynamics models over the latent space.
  self.reward_params = self.reward_model_def.init(key_reward, sample_z)
  self.reward_optimizer_state = self.reward_optimizer.init(
      self.reward_params)
  # A zero action and a dummy key are enough to trace the dynamics model.
  self.dynamics_params = self.dynamics_model_def.init(
      key_dynamics, sample_z, jnp.zeros(self.action_shape), key_dynamics)
  self.dynamics_optimizer_state = self.dynamics_optimizer.init(
      self.dynamics_params)
  # Targets begin as exact copies of the online parameters.
  self.encoder_target_params = self.encoder_params
  self.target_params = self.network_params
  # Entropy temperature (\alpha) is a learned scalar.
  self.log_alpha = jnp.zeros(1)
  self.alpha_optimizer_state = self.alpha_optimizer.init(self.log_alpha)
def _build_networks_and_optimizer(self):
  """Initializes the IQN online/target networks and optimizer state."""
  self._rng, init_key = jax.random.split(self._rng)
  # The positional key seeds parameter init; the `rng=` kwarg is the
  # network's own sampling key used while tracing the forward pass.
  self.online_params = self.network_def.init(
      init_key,
      x=self.state,
      num_quantiles=self.num_tau_samples,
      rng=self._rng)
  self.optimizer = dqn_agent.create_optimizer(self._optimizer_name)
  self.optimizer_state = self.optimizer.init(self.online_params)
  # Target parameters begin as a copy of the online parameters.
  self.target_network_params = self.online_params
def reload_jax_checkpoint(agent, bundle_dictionary):
  """Reload variables from a fully specified checkpoint."""
  if bundle_dictionary is None:
    return
  agent.state = bundle_dictionary['state']
  saved_params = bundle_dictionary['online_params']
  if isinstance(saved_params, core.FrozenDict):
    agent.online_params = saved_params
  else:
    # Pre-linen checkpoint: convert the old parameter tree into the
    # linen-style FrozenDict layout.
    agent.online_params = core.FrozenDict({
        'params': flax_checkpoints.convert_pre_linen(
            saved_params).unfreeze()
    })
  # The optimizer is recreated around the restored online weights.
  # pylint: disable=protected-access
  agent.optimizer = dqn_agent.create_optimizer(agent._optimizer_name)
  # pylint: enable=protected-access
  if 'optimizer_state' in bundle_dictionary:
    agent.optimizer_state = bundle_dictionary['optimizer_state']
  else:
    # Older checkpoints lack optimizer state; start it fresh.
    agent.optimizer_state = agent.optimizer.init(agent.online_params)
  logging.info('Done restoring!')