Example #1
File: simple_dqn.py  Project: deepmind/rlax
def __init__(self, observation_spec, action_spec, epsilon_cfg,
             target_period, learning_rate):
    self._observation_spec = observation_spec
    self._action_spec = action_spec
    self._target_period = target_period
    # Neural net and optimiser.
    self._network = build_network(action_spec.num_values)
    self._optimizer = optax.adam(learning_rate)
    self._epsilon_by_frame = optax.polynomial_schedule(**epsilon_cfg)
    # Jitting for speed.
    self.actor_step = jax.jit(self.actor_step)
    self.learner_step = jax.jit(self.learner_step)
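The `epsilon_cfg` dict is unpacked straight into `optax.polynomial_schedule`, so it must carry that function's keyword arguments. The config used by the rlax example is not shown here; the snippet below is a minimal sketch with made-up values showing what `_epsilon_by_frame` ends up being:

import optax

# Hypothetical config; the values used by simple_dqn.py may differ.
epsilon_cfg = dict(
    init_value=1.0,          # start fully exploratory
    end_value=0.01,          # final exploration rate
    transition_steps=1_000,
    power=1.0,               # power=1 gives a linear decay
)
epsilon_by_frame = optax.polynomial_schedule(**epsilon_cfg)

# A schedule is a plain step -> value callable.
epsilon_by_frame(0)        # 1.0
epsilon_by_frame(500)      # 0.505
epsilon_by_frame(10_000)   # held at 0.01 once transition_steps have passed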
Example #2
def get_default_schedules(pretraining=False):
    """Get schedules for learning rate, entropy, TDlambda."""
    if pretraining:
        return dict(
            learning_rate=optax.constant_schedule(5e-3),
            entropy=optax.constant_schedule(1e-3),
            td_lambda=optax.constant_schedule(0.2),
        )

    return dict(
        learning_rate=optax.exponential_decay(1e-3, 60_000, decay_rate=0.2),
        entropy=(
            lambda count: 1e-3 * 0.1 ** (count / 80_000) if count < 80_000 else -1e-2
        ),
        td_lambda=optax.polynomial_schedule(0.2, 0.8, power=1, transition_steps=60_000),
    )
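All three entries share the schedule signature step -> value, which is why a hand-written lambda can sit alongside the optax-built schedules. The training loop itself is not part of this example, but as a rough usage sketch, the learning-rate schedule can be handed directly to an optax optimizer while the others are evaluated by hand each update:

import optax

schedules = get_default_schedules(pretraining=False)

# optax optimizers accept a schedule callable in place of a fixed learning rate.
optimizer = optax.adam(learning_rate=schedules["learning_rate"])

# The remaining schedules are evaluated by hand at the current update count.
step = 10_000
entropy_coef = schedules["entropy"](step)
td_lambda = schedules["td_lambda"](step)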
Example #3
def get_optimizer(config):
    warm_up_poly = optax.polynomial_schedule(
        init_value=1 / config['warmup_iter'],
        end_value=1,
        power=1,
        transition_steps=config['warmup_iter'])
    exp_decay = optax.exponential_decay(
        init_value=config['adam_lr'],
        transition_steps=config['decay_steps'],
        decay_rate=config['lr_decay_rate'],
        transition_begin=0)  #config['warmup_iter'])
    opt = optax.chain(
        # clip_by_global_norm(max_norm),
        optax.scale_by_adam(b1=config['adam_beta_1'],
                            b2=config['adam_beta_2'],
                            eps=config['adam_eps']),
        optax.scale_by_schedule(warm_up_poly),
        optax.scale_by_schedule(exp_decay),
        optax.scale(-1))
    return opt
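Because the chain applies two `scale_by_schedule` transforms, the Adam-preconditioned update is multiplied by both the warmup factor and the decayed learning rate, then negated by `scale(-1)` for descent. A minimal usage sketch, with illustrative config values and placeholder params/grads:

import jax.numpy as jnp
import optax

config = dict(warmup_iter=1_000, adam_lr=1e-3, decay_steps=10_000,
              lr_decay_rate=0.2, adam_beta_1=0.9, adam_beta_2=0.999,
              adam_eps=1e-8)               # illustrative values only

opt = get_optimizer(config)
params = {"w": jnp.zeros(3)}
opt_state = opt.init(params)

grads = {"w": jnp.ones(3)}                 # placeholder gradients
updates, opt_state = opt.update(grads, opt_state)
params = optax.apply_updates(params, updates)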
Example #4
def build_optimizer(lr,
                    momentum,
                    steps_per_epoch,
                    n_epochs,
                    nesterov,
                    warmup_epochs=5):
    cosine_schedule = optax.cosine_decay_schedule(
        1, decay_steps=n_epochs * steps_per_epoch, alpha=1e-10)
    warmup_schedule = optax.polynomial_schedule(
        init_value=0.0,
        end_value=1.0,
        power=1,
        transition_steps=warmup_epochs * steps_per_epoch,
    )
    schedule = lambda x: jnp.minimum(cosine_schedule(x), warmup_schedule(x))
    optimizer = optax.sgd(lr, momentum, nesterov=nesterov)
    optimizer = optax.chain(optimizer, optax.scale_by_schedule(schedule))
    return optimizer
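Taking the element-wise minimum of the unit-scale cosine decay and the unit-scale linear warmup gives a linear ramp over the first `warmup_epochs` and a cosine decay afterwards, with the peak magnitude coming from `lr` inside `optax.sgd`. For comparison, optax's `warmup_cosine_decay_schedule` expresses a similar shape in a single call; a sketch with assumed values:

import optax

steps_per_epoch, n_epochs, warmup_epochs = 100, 90, 5   # assumed values
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.1,                        # plays the role of `lr` above
    warmup_steps=warmup_epochs * steps_per_epoch,
    decay_steps=n_epochs * steps_per_epoch,
    end_value=0.0,
)
optimizer = optax.sgd(learning_rate=schedule, momentum=0.9, nesterov=True)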
Example #5
def warm_up_polynomial_schedule(
    base_learning_rate: float,
    end_learning_rate: float,
    decay_steps: int,
    warmup_steps: int,
    decay_power: float,
) -> Callable:
    """Please see uncertainty_baselines.schedules.WarmUpPolynomialSchedule.
  """
    poly_schedule = optax.polynomial_schedule(
        init_value=base_learning_rate,
        end_value=end_learning_rate,
        power=decay_power,
        transition_steps=decay_steps,
    )

    def schedule(step):
        lr = poly_schedule(step)
        indicator = jnp.maximum(0.0, jnp.sign(warmup_steps - step))
        warmup_lr = base_learning_rate * step / warmup_steps
        lr = warmup_lr * indicator + (1 - indicator) * lr
        return lr

    return schedule
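The `sign`/`maximum` indicator is a branch-free switch between linear warmup and polynomial decay, so the schedule stays compatible with `jax.jit` over traced step counts. The same switch can be written with `jnp.where`; the helper below is an equivalent formulation, not part of the original code:

import jax.numpy as jnp

def warmup_or_decay(step, warmup_steps, base_learning_rate, poly_schedule):
    # Branch-free switch: linear warmup below warmup_steps, polynomial decay after.
    warmup_lr = base_learning_rate * step / warmup_steps
    return jnp.where(step < warmup_steps, warmup_lr, poly_schedule(step))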
Example #6
File: dqn.py  Project: omardrwch/rlberry
    def __init__(
        self,
        env: types.Env,
        gamma: float = 0.99,
        batch_size: int = 64,
        chunk_size: int = 8,
        online_update_interval: int = 1,
        target_update_interval: int = 512,
        learning_rate: float = 0.001,
        epsilon_init: float = 1.0,
        epsilon_end: float = 0.05,
        epsilon_steps: int = 5000,
        max_replay_size: int = 100000,
        eval_interval: Optional[int] = None,
        max_episode_length: Optional[int] = None,
        lambda_: Optional[float] = None,
        net_constructor: Optional[Callable[..., hk.Module]] = None,
        net_kwargs: Optional[Mapping[str, Any]] = None,
        max_gradient_norm: float = 100.0,
        **kwargs
    ):
        AgentWithSimplePolicy.__init__(self, env, **kwargs)
        env = self.env
        self.rng_key = jax.random.PRNGKey(self.rng.integers(2**32).item())

        # checks
        if not isinstance(self.env.observation_space, spaces.Box):
            raise ValueError("DQN only implemented for Box observation spaces.")
        if not isinstance(self.env.action_space, spaces.Discrete):
            raise ValueError("DQN only implemented for Discrete action spaces.")

        # params
        self._gamma = gamma
        self._batch_size = batch_size
        self._chunk_size = chunk_size
        self._online_update_interval = online_update_interval
        self._target_update_interval = target_update_interval
        self._max_replay_size = max_replay_size
        self._eval_interval = eval_interval
        self._max_episode_length = max_episode_length or np.inf
        self._lambda = lambda_
        self._max_gradient_norm = max_gradient_norm

        #
        # Setup replay buffer
        #

        # define specs
        # TODO: generalize. Observation is taken from reset() because gym is
        # mixing things up (returning double instead of float)
        sample_obs = env.reset()
        try:
            obs_shape, obs_dtype = sample_obs.shape, sample_obs.dtype
        except AttributeError:  # in case sample_obs has no .shape attribute
            obs_shape, obs_dtype = (
                env.observation_space.shape,
                env.observation_space.dtype,
            )
        action_shape, action_dtype = env.action_space.shape, env.action_space.dtype

        self._replay_buffer = ReplayBuffer(
            self._batch_size,
            self._chunk_size,
            self._max_replay_size,
        )
        self._replay_buffer.setup_entry("actions", action_shape, action_dtype)
        self._replay_buffer.setup_entry("observations", obs_shape, obs_dtype)
        self._replay_buffer.setup_entry("next_observations", obs_shape, obs_dtype)
        self._replay_buffer.setup_entry("rewards", (), np.float32)
        self._replay_buffer.setup_entry("discounts", (), np.float32)
        self._replay_buffer.build()

        # initialize network and params
        net_constructor = net_constructor or nets.MLPQNetwork
        net_kwargs = net_kwargs or dict(
            num_actions=self.env.action_space.n, hidden_sizes=(64, 64)
        )
        net_ctor = functools.partial(net_constructor, **net_kwargs)
        self._q_net = hk.without_apply_rng(hk.transform(lambda x: net_ctor()(x)))

        self._dummy_obs = jnp.ones(self.env.observation_space.shape)

        self.rng_key, subkey1 = jax.random.split(self.rng_key)
        self.rng_key, subkey2 = jax.random.split(self.rng_key)

        self._all_params = AllParams(
            online=self._q_net.init(subkey1, self._dummy_obs),
            target=self._q_net.init(subkey2, self._dummy_obs),
        )

        # initialize optimizer and states
        self._optimizer = optax.chain(
            optax.clip_by_global_norm(self._max_gradient_norm),
            optax.adam(learning_rate),
        )
        self._all_states = AllStates(
            optimizer=self._optimizer.init(self._all_params.online),
            learner_steps=jnp.array(0),
            actor_steps=jnp.array(0),
        )

        # epsilon decay
        self._epsilon_schedule = optax.polynomial_schedule(
            init_value=epsilon_init,
            end_value=epsilon_end,
            transition_steps=epsilon_steps,
            transition_begin=0,
            power=1.0,
        )

        # update functions (jit)
        self.actor_step = jax.jit(self._actor_step)
        self.learner_step = jax.jit(self._learner_step)
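The epsilon schedule built here is a linear decay (`power=1.0`) from `epsilon_init` to `epsilon_end` over `epsilon_steps`. Since `actor_steps` is stored as a `jnp.array` and the actor step is jitted, the schedule gets evaluated on traced step counts; a standalone sketch using the constructor defaults shown above:

import jax
import jax.numpy as jnp
import optax

# Same parameters as the constructor defaults above.
epsilon_schedule = optax.polynomial_schedule(
    init_value=1.0, end_value=0.05, transition_steps=5000,
    transition_begin=0, power=1.0)

@jax.jit
def exploration_rate(actor_steps):
    return epsilon_schedule(actor_steps)

exploration_rate(jnp.array(0))      # 1.0
exploration_rate(jnp.array(2500))   # 0.525, halfway through the decay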
Example #7
def _create_jax_schedule(self):
    import optax
    return optax.polynomial_schedule(init_value=self.initial_rate,
                                     end_value=self.final_rate,
                                     power=self.power,
                                     transition_steps=self.decay_steps)
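The returned schedule is a plain callable, so besides `optax.scale_by_schedule` it can be passed as the `learning_rate` of an optax optimizer, or wrapped with `optax.inject_hyperparams` to keep the current rate inspectable. A sketch with made-up values standing in for `self.initial_rate` and friends:

import jax.numpy as jnp
import optax

# Made-up values standing in for self.initial_rate, self.final_rate, etc.
schedule = optax.polynomial_schedule(
    init_value=1e-3, end_value=1e-5, power=2, transition_steps=10_000)

# Wrapping the optimizer keeps the current learning rate visible in its state.
optimizer = optax.inject_hyperparams(optax.adam)(learning_rate=schedule)
state = optimizer.init({"w": jnp.zeros(2)})
current_lr = state.hyperparams["learning_rate"]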