def __init__(self, observation_spec, action_spec, epsilon_cfg, target_period,
             learning_rate):
    self._observation_spec = observation_spec
    self._action_spec = action_spec
    self._target_period = target_period
    # Neural net and optimiser.
    self._network = build_network(action_spec.num_values)
    self._optimizer = optax.adam(learning_rate)
    self._epsilon_by_frame = optax.polynomial_schedule(**epsilon_cfg)
    # Jitting for speed.
    self.actor_step = jax.jit(self.actor_step)
    self.learner_step = jax.jit(self.learner_step)
def get_default_schedules(pretraining=False):
    """Get schedules for learning rate, entropy and TD-lambda."""
    if pretraining:
        return dict(
            learning_rate=optax.constant_schedule(5e-3),
            entropy=optax.constant_schedule(1e-3),
            td_lambda=optax.constant_schedule(0.2),
        )
    return dict(
        learning_rate=optax.exponential_decay(1e-3, 60_000, decay_rate=0.2),
        entropy=(
            lambda count: 1e-3 * 0.1 ** (count / 80_000)
            if count < 80_000
            else -1e-2
        ),
        td_lambda=optax.polynomial_schedule(
            0.2, 0.8, power=1, transition_steps=60_000),
    )
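# Usage sketch (illustrative, not from the original source): the returned
# schedules are plain callables mapping a step count to a value, so they can
# be queried directly or passed to optax.scale_by_schedule.
schedules = get_default_schedules(pretraining=False)
lr_at_10k = schedules["learning_rate"](10_000)
entropy_at_10k = schedules["entropy"](10_000)
td_lambda_at_10k = schedules["td_lambda"](10_000)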
def get_optimizer(config):
    # Linear warm-up from 1/warmup_iter to 1 over the first warmup_iter steps.
    warm_up_poly = optax.polynomial_schedule(
        init_value=1 / config['warmup_iter'],
        end_value=1,
        power=1,
        transition_steps=config['warmup_iter'])
    exp_decay = optax.exponential_decay(
        init_value=config['adam_lr'],
        transition_steps=config['decay_steps'],
        decay_rate=config['lr_decay_rate'],
        transition_begin=0)  # config['warmup_iter']
    # Adam preconditioning, then the two multiplicative schedules, then a sign
    # flip so the updates perform gradient descent.
    opt = optax.chain(
        # clip_by_global_norm(max_norm),
        optax.scale_by_adam(
            b1=config['adam_beta_1'],
            b2=config['adam_beta_2'],
            eps=config['adam_eps']),
        optax.scale_by_schedule(warm_up_poly),
        optax.scale_by_schedule(exp_decay),
        optax.scale(-1))
    return opt
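# Usage sketch with a hypothetical config and parameters (not from the
# original source): the chained transformation is an ordinary optax
# GradientTransformation, so it follows the usual init / update /
# apply_updates cycle.
config = dict(warmup_iter=1_000, adam_lr=1e-3, decay_steps=10_000,
              lr_decay_rate=0.1, adam_beta_1=0.9, adam_beta_2=0.999,
              adam_eps=1e-8)
optimizer = get_optimizer(config)
params = {'w': jnp.zeros(3)}   # hypothetical parameters
opt_state = optimizer.init(params)
grads = {'w': jnp.ones(3)}     # hypothetical gradients
updates, opt_state = optimizer.update(grads, opt_state)
params = optax.apply_updates(params, updates)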
def build_optimizer(lr, momentum, steps_per_epoch, n_epochs, nesterov,
                    warmup_epochs=5):
    cosine_schedule = optax.cosine_decay_schedule(
        1, decay_steps=n_epochs * steps_per_epoch, alpha=1e-10)
    warmup_schedule = optax.polynomial_schedule(
        init_value=0.0,
        end_value=1.0,
        power=1,
        transition_steps=warmup_epochs * steps_per_epoch,
    )
    # Pointwise minimum: linear warm-up dominates at the start, cosine decay
    # afterwards.
    schedule = lambda step: jnp.minimum(cosine_schedule(step), warmup_schedule(step))
    optimizer = optax.sgd(lr, momentum, nesterov=nesterov)
    optimizer = optax.chain(optimizer, optax.scale_by_schedule(schedule))
    return optimizer
def warm_up_polynomial_schedule(
    base_learning_rate: float,
    end_learning_rate: float,
    decay_steps: int,
    warmup_steps: int,
    decay_power: float,
) -> Callable:
    """Please see uncertainty_baselines.schedules.WarmUpPolynomialSchedule."""
    poly_schedule = optax.polynomial_schedule(
        init_value=base_learning_rate,
        end_value=end_learning_rate,
        power=decay_power,
        transition_steps=decay_steps,
    )

    def schedule(step):
        lr = poly_schedule(step)
        # indicator is 1.0 while step < warmup_steps and 0.0 afterwards.
        indicator = jnp.maximum(0.0, jnp.sign(warmup_steps - step))
        warmup_lr = base_learning_rate * step / warmup_steps
        lr = warmup_lr * indicator + (1 - indicator) * lr
        return lr

    return schedule
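# Usage sketch with hypothetical hyperparameters (not from the original
# source): the returned callable is a standard optax schedule, so it can be
# passed directly as the learning rate of an optax optimizer.
lr_schedule = warm_up_polynomial_schedule(
    base_learning_rate=0.1,
    end_learning_rate=1e-4,
    decay_steps=20_000,
    warmup_steps=1_000,
    decay_power=1.0,
)
optimizer = optax.sgd(learning_rate=lr_schedule, momentum=0.9)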
def __init__(
    self,
    env: types.Env,
    gamma: float = 0.99,
    batch_size: int = 64,
    chunk_size: int = 8,
    online_update_interval: int = 1,
    target_update_interval: int = 512,
    learning_rate: float = 0.001,
    epsilon_init: float = 1.0,
    epsilon_end: float = 0.05,
    epsilon_steps: int = 5000,
    max_replay_size: int = 100000,
    eval_interval: Optional[int] = None,
    max_episode_length: Optional[int] = None,
    lambda_: Optional[float] = None,
    net_constructor: Optional[Callable[..., hk.Module]] = None,
    net_kwargs: Optional[Mapping[str, Any]] = None,
    max_gradient_norm: float = 100.0,
    **kwargs
):
    AgentWithSimplePolicy.__init__(self, env, **kwargs)
    env = self.env
    self.rng_key = jax.random.PRNGKey(self.rng.integers(2**32).item())

    # checks
    if not isinstance(self.env.observation_space, spaces.Box):
        raise ValueError("DQN only implemented for Box observation spaces.")
    if not isinstance(self.env.action_space, spaces.Discrete):
        raise ValueError("DQN only implemented for Discrete action spaces.")

    # params
    self._gamma = gamma
    self._batch_size = batch_size
    self._chunk_size = chunk_size
    self._online_update_interval = online_update_interval
    self._target_update_interval = target_update_interval
    self._max_replay_size = max_replay_size
    self._eval_interval = eval_interval
    self._max_episode_length = max_episode_length or np.inf
    self._lambda = lambda_
    self._max_gradient_norm = max_gradient_norm

    #
    # Setup replay buffer
    #
    # define specs
    # TODO: generalize. Observation is taken from reset() because gym is
    # mixing things up (returning double instead of float)
    sample_obs = env.reset()
    try:
        obs_shape, obs_dtype = sample_obs.shape, sample_obs.dtype
    except AttributeError:  # in case sample_obs has no .shape attribute
        obs_shape, obs_dtype = (
            env.observation_space.shape,
            env.observation_space.dtype,
        )
    action_shape, action_dtype = env.action_space.shape, env.action_space.dtype

    self._replay_buffer = ReplayBuffer(
        self._batch_size,
        self._chunk_size,
        self._max_replay_size,
    )
    self._replay_buffer.setup_entry("actions", action_shape, action_dtype)
    self._replay_buffer.setup_entry("observations", obs_shape, obs_dtype)
    self._replay_buffer.setup_entry("next_observations", obs_shape, obs_dtype)
    self._replay_buffer.setup_entry("rewards", (), np.float32)
    self._replay_buffer.setup_entry("discounts", (), np.float32)
    self._replay_buffer.build()

    # initialize network and params
    net_constructor = net_constructor or nets.MLPQNetwork
    net_kwargs = net_kwargs or dict(
        num_actions=self.env.action_space.n, hidden_sizes=(64, 64)
    )
    net_ctor = functools.partial(net_constructor, **net_kwargs)
    self._q_net = hk.without_apply_rng(hk.transform(lambda x: net_ctor()(x)))
    self._dummy_obs = jnp.ones(self.env.observation_space.shape)

    self.rng_key, subkey1 = jax.random.split(self.rng_key)
    self.rng_key, subkey2 = jax.random.split(self.rng_key)
    self._all_params = AllParams(
        online=self._q_net.init(subkey1, self._dummy_obs),
        target=self._q_net.init(subkey2, self._dummy_obs),
    )

    # initialize optimizer and states
    self._optimizer = optax.chain(
        optax.clip_by_global_norm(self._max_gradient_norm),
        optax.adam(learning_rate),
    )
    self._all_states = AllStates(
        optimizer=self._optimizer.init(self._all_params.online),
        learner_steps=jnp.array(0),
        actor_steps=jnp.array(0),
    )

    # epsilon decay
    self._epsilon_schedule = optax.polynomial_schedule(
        init_value=epsilon_init,
        end_value=epsilon_end,
        transition_steps=epsilon_steps,
        transition_begin=0,
        power=1.0,
    )

    # update functions (jit)
    self.actor_step = jax.jit(self._actor_step)
    self.learner_step = jax.jit(self._learner_step)
def _create_jax_schedule(self):
    import optax

    return optax.polynomial_schedule(
        init_value=self.initial_rate,
        end_value=self.final_rate,
        power=self.power,
        transition_steps=self.decay_steps)