Example #1
  def test_batch(self):
    """Test that batch layer is indeed ignored.

    Code taken from: https://github.com/google/flax/issues/932
    """
    key = jax.random.PRNGKey(0)
    x = jnp.ones((5, 4, 4, 3))
    y = jax.random.uniform(key, (5, 4, 4, 7))

    foo_vars = flax.core.unfreeze(Foo(filters=7, train=True).init(key, x))
    tx = optax.masked(optax.adam(1e-7), create_weight_decay_mask())

    @self.variant
    def train_step(params, x, y):
      y1, new_batch_stats = Foo(
          filters=7, train=True).apply(
              params, x, mutable=['batch_stats'])

      return jnp.abs(y - y1).sum(), new_batch_stats

    state = self.variant(tx.init)(foo_vars['params'])
    grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y)
    updates, state = self.variant(tx.update)(dict(grads['params']), state)

    chex.assert_trees_all_close(updates['BatchNorm_0'],
                                grads['params']['BatchNorm_0'])
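Note: the assertion above relies on optax.masked applying the wrapped transform only where the mask pytree is True and passing the remaining leaves through unchanged. A minimal sketch of that behaviour (the parameter names below are made up for illustration):

import jax.numpy as jnp
import optax

params = {'w': jnp.ones(3), 'b': jnp.ones(3)}
grads = {'w': jnp.ones(3), 'b': jnp.ones(3)}

# Only 'w' is transformed; 'b' is returned untouched, as with BatchNorm above.
tx = optax.masked(optax.scale(-0.1), {'w': True, 'b': False})
state = tx.init(params)
updates, state = tx.update(grads, state)
# updates['b'] is identical to grads['b']; updates['w'] == -0.1 * grads['w']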
Example #2
def train_with_bc(make_demonstrations: Callable[[int],
                                                Iterator[types.Transition]],
                  networks: networks_lib.FeedForwardNetwork,
                  loss: losses.Loss,
                  num_steps: int = 100000) -> networks_lib.Params:
    """Trains the given network with BC and returns the params.

  Args:
    make_demonstrations: A function (batch_size) -> iterator with demonstrations
      to be imitated.
    networks: Network taking (params, obs, is_training, key) as input
    loss: BC loss to use.
    num_steps: number of training steps

  Returns:
    The trained network params.
  """
    demonstration_iterator = make_demonstrations(256)

    learner = learning.BCLearner(network=networks,
                                 random_key=jax.random.PRNGKey(0),
                                 loss_fn=loss,
                                 demonstrations=demonstration_iterator,
                                 optimizer=optax.adam(1e-4),
                                 num_sgd_steps_per_step=1)

    # Train the agent
    for _ in range(num_steps):
        learner.step()

    return learner.get_variables(['policy'])[0]
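A hypothetical call site for the function above; make_demo_iterator, bc_networks and logp_loss are placeholder names, not part of the original example:

# Placeholder names for illustration only.
trained_params = train_with_bc(
    make_demonstrations=make_demo_iterator,  # batch_size -> Iterator[types.Transition]
    networks=bc_networks,
    loss=logp_loss,
    num_steps=10_000)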
Example #3
    def __init__(self, name, param_store=None, tensorboard_dir=None):
        env = make_env(name, tensorboard_dir)

        # function approximator
        self.q = coax.Q(forward_pass, env)
        self.q_targ = self.q.copy()

        # updater
        self.q_updater = coax.td_learning.QLearning(self.q,
                                                    q_targ=self.q_targ,
                                                    optimizer=optax.adam(3e-4))

        # schedule for beta parameter used in PrioritizedReplayBuffer
        self.buffer_beta = coax.utils.StepwiseLinearFunction((0, 0.4),
                                                             (1000000, 1))

        super().__init__(
            env=env,
            param_store=param_store,
            pi=coax.BoltzmannPolicy(self.q, temperature=0.015),
            tracer=coax.reward_tracing.NStep(n=1, gamma=0.99),
            buffer=(coax.experience_replay.PrioritizedReplayBuffer(
                capacity=1000000, alpha=0.6) if param_store is None else None),
            buffer_warmup=50000,
            name=name)
Example #4
  def make_learner(
      self,
      random_key: networks_lib.PRNGKey,
      networks: impala_networks.IMPALANetworks,
      dataset: Iterator[reverb.ReplaySample],
      logger_fn: loggers.LoggerFactory,
      environment_spec: specs.EnvironmentSpec,
      replay_client: Optional[reverb.Client] = None,
      counter: Optional[counting.Counter] = None,
  ) -> core.Learner:
    del environment_spec, replay_client

    optimizer = optax.chain(
        optax.clip_by_global_norm(self._config.max_gradient_norm),
        optax.adam(
            self._config.learning_rate,
            b1=self._config.adam_momentum_decay,
            b2=self._config.adam_variance_decay,
            eps=self._config.adam_eps,
            eps_root=self._config.adam_eps_root))

    return learning.IMPALALearner(
        networks=networks,
        iterator=dataset,
        optimizer=optimizer,
        random_key=random_key,
        discount=self._config.discount,
        entropy_cost=self._config.entropy_cost,
        baseline_cost=self._config.baseline_cost,
        max_abs_reward=self._config.max_abs_reward,
        counter=counter,
        logger=logger_fn('learner'),
    )
Example #5
    def test_optimizer_epoch(self):
        optax_op = optax.adam(1e-3)
        lr_schedule = lambda step, epoch: epoch

        optimizer = elegy.Optimizer(optax_op,
                                    lr_schedule=lr_schedule,
                                    steps_per_epoch=2)

        params = np.random.uniform(size=(3, 4))
        grads = np.random.uniform(size=(3, 4))
        rng = elegy.RNGSeq(42)

        optimizer_states = optimizer.init(
            rng=rng,
            net_params=params,
        )

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 0)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)

        assert jnp.allclose(optimizer.current_lr(optimizer_states), 1)
        params, optimizer_states = optimizer.apply(params, grads,
                                                   optimizer_states, rng)
Example #6
def sparsify_basis(Q,lr=1e-2): #(n,r)
    """ Convenience function to attempt to sparsify a given basis by applying an orthogonal transformation
        W, Q' = QW where Q' has only 1s, 0s and -1s. Notably this method does not have the same convergence
        gauruntees of krylov_constraint_solve and can fail (even silently). Intended to be used only for
        visualization purposes, use at your own risk. """
    W = np.random.randn(Q.shape[-1],Q.shape[-1])
    W,_ = np.linalg.qr(W)
    W = device_put(W.astype(jnp.float32))
    opt_init,opt_update = optax.adam(lr)
    opt_update = jit(opt_update)
    opt_state = opt_init(W)  # init stats

    def loss(W):
        return jnp.abs(Q@W.T).mean() + .1*(jnp.abs(W.T@W-jnp.eye(W.shape[0]))).mean()+.01*jax.numpy.linalg.slogdet(W)[1]**2

    loss_and_grad = jit(jax.value_and_grad(loss))

    for i in tqdm(range(3000),desc=f'sparsifying basis'):
        lossval, grad = loss_and_grad(W)
        updates, opt_state = opt_update(grad, opt_state, W)
        W = optax.apply_updates(W, updates)
        #W,_ = np.linalg.qr(W)
        if lossval>1e2 and i>100: # Solve diverged due to too high learning rate
            logging.warning(f"basis sparsification diverged, trying lower learning rate {lr/3:.2e}")
            return sparsify_basis(Q,lr=lr/3)
    Q = np.copy(Q@W.T)
    Q[np.abs(Q)<1e-2]=0
    Q[np.abs(Q)>1e-2] /= np.abs(Q[np.abs(Q)>1e-2])
    A = Q@(1+np.arange(Q.shape[-1]))
    if len(np.unique(np.abs(A)))!=Q.shape[-1]+1 and len(np.unique(np.abs(A)))!=Q.shape[-1]:
        logging.error(f"Basis elems did not separate: found only {len(np.unique(np.abs(A)))}/{Q.shape[-1]}")
        #raise ConvergenceError(f"Basis elems did not separate: found only {len(np.unique(A))}/{Q.shape[-1]}")
    return Q
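The loss combines an L1 term on Q @ W.T (encouraging entries near zero), an orthogonality penalty on W, and a slogdet term that keeps W well-conditioned. A hypothetical usage sketch, assuming Q comes from the krylov_constraint_solve routine mentioned in the docstring:

# Hypothetical usage; C and krylov_constraint_solve come from the surrounding library.
Q = krylov_constraint_solve(C)         # (n, r) basis of the constraint nullspace
Q_sparse = sparsify_basis(Q, lr=1e-2)  # entries pushed towards {-1, 0, 1}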
Example #7
def main():
    # Create the dataset.
    train_dataset, vocab_size = load(batch_size, sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                  dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip_value),
                            optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = GradientUpdater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
Example #8
def train(  # pylint: disable=invalid-name
    Phi,
    Psi,
    num_epochs,
    learning_rate,
    key,
    estimator,
    alpha,
    optimizer,
    use_l2_reg,
    reg_coeff,
    use_penalty,
    j,
    num_rows,
    skipsize=1):
  """Training function."""
  Phis = [Phi]  # pylint: disable=invalid-name
  grads = []
  if optimizer == 'sgd':
    optim = optax.sgd(learning_rate)
  elif optimizer == 'adam':
    optim = optax.adam(learning_rate)
  opt_state = optim.init(Phi)
  for i in tqdm(range(num_epochs)):
    key, subkey = jax.random.split(key)
    Phi, opt_state, grad = estimates.nabla_phi_analytical(
        Phi, Psi, subkey, optim, opt_state, estimator, alpha, use_l2_reg,
        reg_coeff, use_penalty, j, num_rows)
    if i % skipsize == 0:
      Phis.append(Phi)
      grads.append(grad)
  return jnp.stack(Phis), jnp.stack(grads)
Example #9
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: networks_lib.FeedForwardNetwork,
        dataset: Iterator[reverb.ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: Optional[specs.EnvironmentSpec],
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        del environment_spec

        return learning_lib.SGDLearner(
            network=networks,
            random_key=random_key,
            optimizer=optax.adam(self._config.learning_rate,
                                 eps=self._config.adam_eps),
            target_update_period=self._config.target_update_period,
            data_iterator=dataset,
            loss_fn=self._loss_fn,
            replay_client=replay_client,
            replay_table_name=self._config.replay_table_name,
            counter=counter,
            num_sgd_steps_per_step=self._config.num_sgd_steps_per_step,
            logger=logger_fn('learner'))
Example #10
def test_beta_bernoulli(elbo):
    data = jnp.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = numpyro.sample("beta", dist.Beta(1.0, 1.0))
        numpyro.sample("obs", dist.Bernoulli(f), obs=data)

    def guide(data):
        alpha_q = numpyro.param("alpha_q",
                                1.0,
                                constraint=constraints.positive)
        beta_q = numpyro.param("beta_q", 1.0, constraint=constraints.positive)
        numpyro.sample("beta", dist.Beta(alpha_q, beta_q))

    adam = optax.adam(0.05)
    svi = SVI(model, guide, adam, elbo)
    svi_state = svi.init(random.PRNGKey(1), data)
    assert_allclose(
        svi.optim.get_params(svi_state.optim_state)["alpha_q"], 0.0)

    def body_fn(i, val):
        svi_state, _ = svi.update(val, data)
        return svi_state

    svi_state = fori_loop(0, 2000, body_fn, svi_state)
    params = svi.get_params(svi_state)
    assert_allclose(
        params["alpha_q"] / (params["alpha_q"] + params["beta_q"]),
        0.8,
        atol=0.05,
        rtol=0.05,
    )
Example #11
    def test_fitting_surrogate_posterior_stateless(self):
        if not JAX_MODE:
            self.skipTest('Requires optax.')
        import optax  # pylint: disable=g-import-not-at-top

        prior_dist = self.make_prior_dist()
        observations = self.get_observations(prior_dist)
        init_fn, build_surrogate_posterior_fn = (
            tfp.experimental.vi.build_asvi_surrogate_posterior_stateless(
                prior=prior_dist))
        target_log_prob = self.get_target_log_prob(observations, prior_dist)

        def loss_fn(*params, seed=None):
            surrogate_posterior = build_surrogate_posterior_fn(*params)
            zs, q_lp = surrogate_posterior.experimental_sample_and_log_prob(
                10, seed=seed)
            return tf.reduce_mean(q_lp - target_log_prob(*zs), axis=0)

        # Test vi fit surrogate posterior works
        optimized_params, _ = tfp.math.minimize_stateless(
            loss_fn,
            init=init_fn(seed=test_util.test_seed()),
            num_steps=5,  # Don't optimize to completion.
            optimizer=optax.adam(0.1),
            seed=test_util.test_seed(sampler_type='stateless'))
        surrogate_posterior = build_surrogate_posterior_fn(optimized_params)
        surrogate_posterior.sample(
            100, seed=test_util.test_seed(sampler_type='stateless'))
Example #12
def test_batch_overfit(train_dataset):
    vocab_size, d_model, num_heads, num_layers = 100, 32, 8, 1
    dropout_rate, grad_clip_value, learning_rate = 0.01, 0.25, 2e-2
    max_iter = 100

    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, d_model, num_heads, num_layers,
                                  dropout_rate, max_iter)

    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(grad_clip_value),
                            optax.adam(learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)

    # Initialize parameters.
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    for step in range(100):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)

    assert metrics['loss'] < 0.1
    assert metrics['step'] == 99
Example #13
    def make_learner(
        self,
        random_key: networks_lib.PRNGKey,
        networks: r2d2_networks.R2D2Networks,
        dataset: Iterator[r2d2_learning.R2D2ReplaySample],
        logger_fn: loggers.LoggerFactory,
        environment_spec: specs.EnvironmentSpec,
        replay_client: Optional[reverb.Client] = None,
        counter: Optional[counting.Counter] = None,
    ) -> core.Learner:
        del environment_spec

        # The learner updates the parameters (and initializes them).
        return r2d2_learning.R2D2Learner(
            unroll=networks.unroll,
            initial_state=networks.initial_state,
            batch_size=self._batch_size_per_device,
            random_key=random_key,
            burn_in_length=self._config.burn_in_length,
            discount=self._config.discount,
            importance_sampling_exponent=(
                self._config.importance_sampling_exponent),
            max_priority_weight=self._config.max_priority_weight,
            target_update_period=self._config.target_update_period,
            iterator=dataset,
            optimizer=optax.adam(self._config.learning_rate),
            bootstrap_n=self._config.bootstrap_n,
            tx_pair=self._config.tx_pair,
            clip_rewards=self._config.clip_rewards,
            replay_client=replay_client,
            counter=counter,
            logger=logger_fn('learner'))
Example #14
def main(batch_size: int = 64, k: int = 5, debug: bool = False):

    noise = np.float32(np.random.normal(size=(3000, 1)))  # random noise
    y_train = np.float32(np.random.uniform(-10.5, 10.5, (1, 3000))).T
    X_train = np.float32(
        np.sin(0.75 * y_train) * 7.0 + y_train * 0.5 + noise * 1.0)

    X_train = X_train / np.abs(X_train.max())
    y_train = y_train / np.abs(y_train.max())

    visualize_data(X_train, y_train)

    model = elegy.Model(module=MixtureModel(k=k),
                        loss=MixtureNLL(),
                        optimizer=optax.adam(3e-4))

    model.summary(X_train[:batch_size], depth=1)

    model.fit(
        x=X_train,
        y=y_train,
        epochs=500,
        batch_size=batch_size,
        shuffle=True,
    )

    visualize_model(X_train, y_train, model, k)
Example #15
def optimize_club(num_steps: int):
    """Solves the karte club problem by optimizing the assignments of students."""
    network = hk.without_apply_rng(hk.transform(network_definition))
    zacharys_karate_club = get_zacharys_karate_club()
    labels = get_ground_truth_assignments_for_zacharys_karate_club()
    params = network.init(jax.random.PRNGKey(42), zacharys_karate_club)

    @jax.jit
    def prediction_loss(params):
        decoded_nodes = network.apply(params, zacharys_karate_club)
        # We interpret the decoded nodes as a pair of logits for each node.
        log_prob = jax.nn.log_softmax(decoded_nodes)
        # The only two assignments we know a-priori are those of Mr. Hi (Node 0)
        # and John A (Node 33).
        return -(log_prob[0, 0] + log_prob[33, 1])

    opt_init, opt_update = optax.adam(1e-2)
    opt_state = opt_init(params)

    @jax.jit
    def update(params, opt_state):
        g = jax.grad(prediction_loss)(params)
        updates, opt_state = opt_update(g, opt_state)
        return optax.apply_updates(params, updates), opt_state

    @jax.jit
    def accuracy(params):
        decoded_nodes = network.apply(params, zacharys_karate_club)
        return jnp.mean(jnp.argmax(decoded_nodes, axis=1) == labels)

    for step in range(num_steps):
        logging.info("step %r accuracy %r", step, accuracy(params).item())
        params, opt_state = update(params, opt_state)
Example #16
def subspace_sampler(key, loglikelihood, logprior, params_init_tree, build_sampler, data, batch_size,
                     subspace_dim, nsamples, opt=optax.adam(learning_rate=0.1),
                     nsteps_full=0, nsteps_sub=0, projection_matrix=None, use_cv=True, pbar=True):
    subspace_key, sample_key = split(key)

    if nsteps_full > 0 or nsteps_sub > 0:
        # Find good control variate / starting point in subspace
        params_tree, params_sub, log_post_trace, subspace_fns = subspace_optimizer(
            subspace_key, loglikelihood, logprior, params_init_tree, data, batch_size,
            subspace_dim, nsteps_full, nsteps_sub, opt, pbar=pbar)
    else:
        params_sub = jax.random.normal(subspace_key, (subspace_dim,))
        params_init_flat, _ = jax.flatten_util.ravel_pytree(params_init_tree)
        full_dim = len(params_init_flat)
        if projection_matrix is None:
            projection_matrix = generate_random_basis(subspace_key, full_dim, subspace_dim)
        subspace_fns = make_subspace_fns(loglikelihood, logprior, params_init_tree, projection_matrix)

    loglik_sub, logprior_sub, subspace_to_pytree_fn = subspace_fns
    
    if use_cv:
        sampler_sub = build_sampler(loglikelihood=loglik_sub, logprior=logprior_sub, data=data, batch_size=batch_size,
                                    centering_value=params_sub, pbar=pbar)
    else:
        sampler_sub = build_sampler(loglikelihood=loglik_sub, logprior=logprior_sub, data=data,
                                    batch_size=batch_size, pbar=pbar)

    params_sub_samples = sampler_sub(sample_key, nsamples, params_sub)
    params_tree_samples = vmap(subspace_to_pytree_fn)(params_sub_samples)

    return params_tree_samples, params_sub_samples, subspace_fns
Example #17
    def __init__(self,
                 f,
                 f_targ=None,
                 optimizer=None,
                 loss_function=None,
                 policy_regularizer=None):

        self._f = f
        self._f_targ = f if f_targ is None else f_targ
        self.loss_function = huber if loss_function is None else loss_function

        if not isinstance(policy_regularizer, (Regularizer, type(None))):
            raise TypeError(
                f"policy_regularizer must be a Regularizer, got: {type(policy_regularizer)}"
            )
        self.policy_regularizer = policy_regularizer

        # optimizer
        self._optimizer = optax.adam(1e-3) if optimizer is None else optimizer
        self._optimizer_state = self.optimizer.init(self._f.params)

        def apply_grads_func(opt, opt_state, params, grads):
            updates, new_opt_state = opt.update(grads, opt_state, params)
            new_params = optax.apply_updates(params, updates)
            return new_opt_state, new_params

        self._apply_grads_func = jit(apply_grads_func, static_argnums=0)
Example #18
def subspace_optimizer(key, loglikelihood, logprior, params_init_tree, data, batch_size, subspace_dim, nwarmup,
                       nsteps, opt=optax.adam(learning_rate=0.1), projection_matrix=None, pbar=True):
    opt_key, subspace_key, sub_init_key, sub_opt_key = split(key, 4)

    # Find good anchor in full space during warmup phase
    if nwarmup > 0:
        optimizer = build_optax_optimizer(opt, loglikelihood, logprior, data, batch_size, pbar)
        params_init_tree, _ = optimizer(opt_key, nwarmup, params_init_tree)

    # Make Random subspace
    if projection_matrix is None:
        params_init_flat, _ = jax.flatten_util.ravel_pytree(params_init_tree)
        full_dim = len(params_init_flat)
        projection_matrix = generate_random_basis(subspace_key, full_dim, subspace_dim)
    # TODO: add SVD

    loglik_sub, logprior_sub, subspace_to_pytree_fn = make_subspace_fns(
        loglikelihood, logprior, params_init_tree, projection_matrix)
    subspace_fns = (loglik_sub, logprior_sub, subspace_to_pytree_fn)

    # Do subspace optimization starting from rnd location
    params_subspace = jax.random.normal(sub_init_key, (subspace_dim,))
    optimizer_sub = build_optax_optimizer(opt, loglik_sub, logprior_sub, data, batch_size, pbar)

    params_subspace, log_post_trace = optimizer_sub(sub_opt_key, nsteps, params_subspace)
    params_tree = subspace_to_pytree_fn(params_subspace)

    return params_tree, params_subspace, log_post_trace, subspace_fns
Example #19
def main(_):
    # Create the dataset.
    train_dataset, vocab_size = dataset.load(FLAGS.batch_size,
                                             FLAGS.sequence_length)
    # Set up the model, loss, and updater.
    forward_fn = build_forward_fn(vocab_size, FLAGS.d_model, FLAGS.num_heads,
                                  FLAGS.num_layers, FLAGS.dropout_rate)
    forward_fn = hk.transform(forward_fn)
    loss_fn = functools.partial(lm_loss_fn, forward_fn.apply, vocab_size)

    optimizer = optax.chain(optax.clip_by_global_norm(FLAGS.grad_clip_value),
                            optax.adam(FLAGS.learning_rate, b1=0.9, b2=0.99))

    updater = Updater(forward_fn.init, loss_fn, optimizer)
    updater = CheckpointingUpdater(updater, FLAGS.checkpoint_dir)

    # Initialize parameters.
    logging.info('Initializing parameters...')
    rng = jax.random.PRNGKey(428)
    data = next(train_dataset)
    state = updater.init(rng, data)

    logging.info('Starting train loop...')
    prev_time = time.time()
    for step in range(MAX_STEPS):
        data = next(train_dataset)
        state, metrics = updater.update(state, data)
        # We use JAX runahead to mask data preprocessing and JAX dispatch overheads.
        # Using values from state/metrics too often will block the runahead and can
        # cause these overheads to become more prominent.
        if step % LOG_EVERY == 0:
            steps_per_sec = LOG_EVERY / (time.time() - prev_time)
            prev_time = time.time()
            metrics.update({'steps_per_sec': steps_per_sec})
            logging.info({k: float(v) for k, v in metrics.items()})
Example #20
def train_model(workdir):
    """Train for a fixed number of steps and decode during training."""

    key = jax.random.PRNGKey(0)

    key, init_key = jax.random.split(key)
    model = Seq2seq(teacher_force=False, hidden_size=FLAGS.hidden_size)
    params = get_initial_params(model, init_key)
    tx = optax.adam(FLAGS.learning_rate)
    state = train_state.TrainState.create(apply_fn=model.apply,
                                          params=params,
                                          tx=tx)

    writer = metric_writers.create_default_writer(workdir)
    for step in range(FLAGS.num_train_steps):
        key, lstm_key = jax.random.split(key)
        batch = get_batch(FLAGS.batch_size)
        state, metrics = train_step(state, batch, lstm_key)
        if step % FLAGS.decode_frequency == 0:
            writer.write_scalars(step, metrics)
            key, lstm_key = jax.random.split(key)
            batch = get_batch(5)
            decode_batch(state.params, batch, lstm_key)

    return state
Example #21
def adam(learning_rate: ScalarOrSchedule,
         b1: float = 0.9,
         b2: float = 0.999,
         eps: float = 1e-8,
         eps_root: float = 0.0) -> Optimizer:
  """The classic Adam optimiser.

  Adam is an SGD variant with learning rate adaptation. The `learning_rate`
  used for each weight is computed from estimates of first- and second-order
  moments of the gradients (using suitable exponential moving averages).

  References:
    [Kingma et al, 2014](https://arxiv.org/abs/1412.6980)

  Args:
    learning_rate: This is a fixed global scaling factor.
    b1: The exponential decay rate to track the first moment of past gradients.
    b2: The exponential decay rate to track the second moment of past gradients.
    eps: A small constant applied to denominator outside of the square root (as
      in the Adam paper) to avoid dividing by zero when rescaling.
    eps_root: A small constant applied to denominator inside the square root (as
      in RMSProp), to avoid dividing by zero when rescaling. This is needed for
      example when computing (meta-)gradients through Adam.

  Returns:
    The corresponding `Optimizer`.
  """
  return create_optimizer_from_optax(
      optax.adam(
          learning_rate=learning_rate, b1=b1, b2=b2, eps=eps,
          eps_root=eps_root))
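Outside of this wrapper, the underlying optax.adam is a plain GradientTransformation and can be driven directly with the usual init/update/apply_updates loop. A minimal sketch with toy parameters:

import jax
import jax.numpy as jnp
import optax

params = {'w': jnp.zeros(3)}
tx = optax.adam(learning_rate=1e-3)
opt_state = tx.init(params)

def loss(p):
  return jnp.sum((p['w'] - 1.0) ** 2)

grads = jax.grad(loss)(params)
updates, opt_state = tx.update(grads, opt_state)
params = optax.apply_updates(params, updates)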
Example #22
def main(_):
    optimizer = optax.adam(FLAGS.learning_rate)

    @jax.jit
    def update(params: hk.Params, prng_key: PRNGKey, opt_state: OptState,
               batch: Batch) -> Tuple[hk.Params, OptState]:
        """Single SGD update step."""
        grads = jax.grad(loss_fn)(params, prng_key, batch)
        updates, new_opt_state = optimizer.update(grads, opt_state)
        new_params = optax.apply_updates(params, updates)
        return new_params, new_opt_state

    prng_seq = hk.PRNGSequence(42)
    params = log_prob.init(next(prng_seq), np.zeros((1, *MNIST_IMAGE_SHAPE)))
    opt_state = optimizer.init(params)

    train_ds = load_dataset(tfds.Split.TRAIN, FLAGS.batch_size)
    valid_ds = load_dataset(tfds.Split.TEST, FLAGS.batch_size)

    for step in range(FLAGS.training_steps):
        params, opt_state = update(params, next(prng_seq), opt_state,
                                   next(train_ds))

        if step % FLAGS.eval_frequency == 0:
            val_loss = eval_fn(params, next(valid_ds))
            logging.info("STEP: %5d; Validation loss: %.3f", step, val_loss)
Example #23
    def test_basic_variational_fitting_stateless(self):
        if not JAX_MODE:
            self.skipTest('Uses `optax` for stateless optimization')
        import optax  # pylint: disable=g-import-not-at-top
        batch_shape = [2, 3]
        num_timesteps = 5
        num_inits = 10
        observed_time_series = self._build_tensor(
            np.random.randn(*(batch_shape + [num_timesteps])))

        model = self._build_model(observed_time_series)
        seed = test_util.test_seed(sampler_type='stateless')
        init_seed, fit_seed = tfp.random.split_seed(seed, n=2)

        init_fn, build_surrogate_fn = (
            tfp.sts.build_factored_surrogate_posterior_stateless(
                model, batch_shape=num_inits))
        jd = model.joint_distribution(
            observed_time_series=observed_time_series)
        _, loss_curve = tfp.vi.fit_surrogate_posterior_stateless(
            jd.log_prob,
            build_surrogate_posterior_fn=build_surrogate_fn,
            initial_parameters=init_fn(init_seed),
            sample_size=3,
            num_steps=10,
            optimizer=optax.adam(1e-1),
            jit_compile=True,
            seed=fit_seed)
        self.assertLess(np.mean(loss_curve[-1]), np.mean(loss_curve[0]))
Example #24
  def test_sdp_dual_simple_no_crash(self, model_type):
    verif_instance = test_utils.make_toy_verif_instance(seed=0,
                                                        target_label=1,
                                                        label=2,
                                                        nn=model_type)
    kwargs = {
        'key': jax.random.PRNGKey(0),
        'opt': optax.adam(1e-3),
        'num_steps': 10,
        'eval_every': 5,
        'verbose': False,
        'use_exact_eig_eval': False,
        'use_exact_eig_train': False,
        'n_iter_lanczos': 5,
        'kappa_reg_weight': 1e-5,
        'kappa_zero_after': 8,
        'device_type': None,
    }
    verif_instance = utils.make_sdp_verif_instance(verif_instance)
    # Check all kwargs work.
    dual_val, _ = sdp_verify.solve_sdp_dual_simple(verif_instance,
                                                   **kwargs)
    assert isinstance(dual_val, float)
    # Check code runs without kwargs.
    dual_val, _ = sdp_verify.solve_sdp_dual_simple(verif_instance,
                                                   num_steps=5)
    assert isinstance(dual_val, float)
Example #25
def create_train_state(config, rng, init_samples):
    """Creates the training state."""
    model = create_model(config)
    params = model.init(rng, *init_samples)
    tx = optax.adam(learning_rate=config.learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply,
                                         params=params,
                                         tx=tx)
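The returned TrainState carries the params together with the optax.adam state, so a training step typically computes gradients and calls apply_gradients. A minimal sketch; compute_loss and batch are placeholders, not part of the example:

# Sketch of one update step; compute_loss(params, batch) is assumed to call state.apply_fn.
grads = jax.grad(compute_loss)(state.params, batch)
state = state.apply_gradients(grads=grads)  # applies the optax update and increments state.step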
Example #26
def main(_):
    # Create an environment and grab the spec.
    environment = bc_utils.make_environment()
    environment_spec = specs.make_environment_spec(environment)

    # Unwrap the environment to get the demonstrations.
    dataset = bc_utils.make_demonstrations(environment.environment,
                                           FLAGS.batch_size)
    dataset = dataset.as_numpy_iterator()

    # Create the networks to optimize.
    network = bc_utils.make_network(environment_spec)

    key = jax.random.PRNGKey(FLAGS.seed)
    key, key1 = jax.random.split(key, 2)

    def logp_fn(logits, actions):
        logits_actions = jnp.sum(jax.nn.one_hot(actions, logits.shape[-1]) *
                                 logits,
                                 axis=-1)
        logits_actions = logits_actions - special.logsumexp(logits, axis=-1)
        return logits_actions

    loss_fn = bc.logp(logp_fn=logp_fn)

    learner = bc.BCLearner(network=network,
                           random_key=key1,
                           loss_fn=loss_fn,
                           optimizer=optax.adam(FLAGS.learning_rate),
                           demonstrations=dataset,
                           num_sgd_steps_per_step=1)

    def evaluator_network(params: hk.Params, key: jnp.DeviceArray,
                          observation: jnp.DeviceArray) -> jnp.DeviceArray:
        dist_params = network.apply(params, observation)
        return rlax.epsilon_greedy(FLAGS.evaluation_epsilon).sample(
            key, dist_params)

    actor_core = actor_core_lib.batched_feed_forward_to_actor_core(
        evaluator_network)
    variable_client = variable_utils.VariableClient(learner,
                                                    'policy',
                                                    device='cpu')
    evaluator = actors.GenericActor(actor_core,
                                    key,
                                    variable_client,
                                    backend='cpu')

    eval_loop = acme.EnvironmentLoop(environment=environment,
                                     actor=evaluator,
                                     logger=loggers.TerminalLogger(
                                         'evaluation', time_delta=0.))

    # Run the environment loop.
    while True:
        for _ in range(FLAGS.evaluate_every):
            learner.step()
        eval_loop.run(FLAGS.evaluation_episodes)
Example #27
    def test_chunking(self, relaxer):
        batch_size = 3
        input_size = 2
        hidden_size = 5
        final_size = 4

        input_shape = (batch_size, input_size)
        hidden_lay_weight_shape = (input_size, hidden_size)
        final_lay_weight_shape = (hidden_size, final_size)

        inp_lb, inp_ub = test_utils.sample_bounds(jax.random.PRNGKey(0),
                                                  input_shape,
                                                  minval=-1.,
                                                  maxval=1.)
        inp_bound = jax_verify.IntervalBound(inp_lb, inp_ub)

        hidden_lay_weight = jax.random.uniform(jax.random.PRNGKey(1),
                                               hidden_lay_weight_shape)
        final_lay_weight = jax.random.uniform(jax.random.PRNGKey(2),
                                              final_lay_weight_shape)

        def model_fun(inp):
            hidden = inp @ hidden_lay_weight
            act = jax.nn.relu(hidden)
            final = act @ final_lay_weight
            return final

        if isinstance(relaxer,
                      linear_bound_utils.ParameterizedLinearBoundsRelaxer):
            concretizing_transform = (
                backward_crown.OptimizingLinearBoundBackwardTransform(
                    relaxer,
                    backward_crown.CONCRETIZE_ARGS_PRIMITIVE,
                    optax.adam(1.e-3),
                    num_opt_steps=10))
        else:
            concretizing_transform = backward_crown.LinearBoundBackwardTransform(
                relaxer, backward_crown.CONCRETIZE_ARGS_PRIMITIVE)

        chunked_concretizer = backward_crown.ChunkedBackwardConcretizer(
            concretizing_transform, max_chunk_size=16)
        unchunked_concretizer = backward_crown.ChunkedBackwardConcretizer(
            concretizing_transform, max_chunk_size=0)

        chunked_algorithm = bound_utils.BackwardConcretizingAlgorithm(
            chunked_concretizer)
        unchunked_algorithm = bound_utils.BackwardConcretizingAlgorithm(
            unchunked_concretizer)

        chunked_bound, _ = bound_propagation.bound_propagation(
            chunked_algorithm, model_fun, inp_bound)
        unchunked_bound, _ = bound_propagation.bound_propagation(
            unchunked_algorithm, model_fun, inp_bound)

        np.testing.assert_array_almost_equal(chunked_bound.lower,
                                             unchunked_bound.lower)
        np.testing.assert_array_almost_equal(chunked_bound.upper,
                                             unchunked_bound.upper)
Example #28
    def __init__(
        self,
        environment_spec: specs.EnvironmentSpec,
        network: networks_lib.FeedForwardNetwork,
        config: DQNConfig,
    ):
        """Initialize the agent."""
        # Data is communicated via reverb replay.
        reverb_replay = replay.make_reverb_prioritized_nstep_replay(
            environment_spec=environment_spec,
            n_step=config.n_step,
            batch_size=config.batch_size,
            max_replay_size=config.max_replay_size,
            min_replay_size=config.min_replay_size,
            priority_exponent=config.priority_exponent,
            discount=config.discount,
        )
        self._server = reverb_replay.server

        optimizer = optax.chain(
            optax.clip_by_global_norm(config.max_gradient_norm),
            optax.adam(config.learning_rate),
        )
        key_learner, key_actor = jax.random.split(
            jax.random.PRNGKey(config.seed))
        # The learner updates the parameters (and initializes them).
        learner = learning.DQNLearner(
            network=network,
            obs_spec=environment_spec.observations,
            random_key=key_learner,
            optimizer=optimizer,
            discount=config.discount,
            importance_sampling_exponent=config.importance_sampling_exponent,
            target_update_period=config.target_update_period,
            iterator=reverb_replay.data_iterator,
            replay_client=reverb_replay.client,
        )

        # The actor selects actions according to the policy.
        def policy(params: networks_lib.Params, key: jnp.ndarray,
                   observation: jnp.ndarray) -> jnp.ndarray:
            action_values = network.apply(params, observation)
            return rlax.epsilon_greedy(config.epsilon).sample(
                key, action_values)

        actor = actors.FeedForwardActor(
            policy=policy,
            rng=hk.PRNGSequence(key_actor),
            variable_client=variable_utils.VariableClient(learner, ''),
            adder=reverb_replay.adder)

        super().__init__(
            actor=actor,
            learner=learner,
            min_observations=max(config.batch_size, config.min_replay_size),
            observations_per_step=config.batch_size /
            config.samples_per_insert,
        )
Example #29
def main(logdir: str = "runs",
         steps_per_epoch: tp.Optional[int] = None,
         epochs: int = 10,
         batch_size: int = 32):

    platform = jax.local_devices()[0].platform
    ndevices = len(jax.devices())
    print('devices ', jax.devices())
    print('platform ', platform)

    current_time = datetime.now().strftime("%b%d_%H-%M-%S")
    logdir = os.path.join(logdir, current_time)

    dataset = load_dataset("mnist")
    dataset.set_format("np")
    X_train = dataset["train"]["image"][..., None]
    y_train = dataset["train"]["label"]
    X_test = dataset["test"]["image"][..., None]
    y_test = dataset["test"]["label"]

    accuracies = {}
    # we run distributed=False twice to remove any initial warmup costs
    for distributed in [False, False, True]:
        print(f'Distributed training = {distributed}')
        start_time = time.time()

        model = eg.Model(module=CNN(),
                         loss=eg.losses.Crossentropy(),
                         metrics=eg.metrics.Accuracy(),
                         optimizer=optax.adam(1e-3),
                         seed=42)

        if distributed:
            model = model.distributed()
            bs = batch_size  #int(batch_size / ndevices)
        else:
            bs = batch_size

        #model.summary(X_train[:64], depth=1)

        history = model.fit(inputs=X_train,
                            labels=y_train,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            batch_size=bs,
                            validation_data=(X_test, y_test),
                            shuffle=True,
                            verbose=3)

        ev = model.evaluate(x=X_test, y=y_test, verbose=1)
        print('eval ', ev)
        accuracies[distributed] = ev['accuracy']

        end_time = time.time()
        print(f'time taken {end_time - start_time}')

    print(accuracies)
Example #30
        def create_train_state(key, rng, batch_size, learning_rate):
            init_data = jnp.ones([batch_size, 28, 28, 1], jnp.float32)

            state = train_state.TrainState.create(
                apply_fn=self.model().apply,
                params=self.model().init(key, init_data, rng)['params'],
                tx=optax.adam(learning_rate),
            )
            return state