Exemple #1
0
class ExperimentalOptimizersEquivalenceTest(chex.TestCase):
    """Compares optax optimizer aliases with `experimental/optimizers.py`.

    Both implementations are driven with the same fixed per-step updates for
    `STEPS` iterations and the resulting parameters must agree within `rtol`.
    """

    def setUp(self):
        super().setUp()
        # Two-leaf pytrees shared by every parameterized case.
        self.init_params = (jnp.array([1., 2.]), jnp.array([3., 4.]))
        self.per_step_updates = (jnp.array([500., 5.]), jnp.array([300., 3.]))

    @chex.all_variants()
    @parameterized.named_parameters(
        # Cases where `LR` is passed to both implementations.
        ('sgd', alias.sgd(LR, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum',
         alias.rmsprop(LR, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
        # Same cases with `LR_SCHED` passed to the optax side only.
        ('sgd', alias.sgd(LR_SCHED, 0.0), optimizers.sgd(LR), 1e-5),
        ('adam', alias.adam(LR_SCHED, 0.9, 0.999, 1e-8),
         optimizers.adam(LR, 0.9, 0.999), 1e-4),
        ('rmsprop', alias.rmsprop(LR_SCHED, decay=.9, eps=0.1),
         optimizers.rmsprop(LR, .9, 0.1), 1e-5),
        ('rmsprop_momentum',
         alias.rmsprop(LR_SCHED, decay=.9, eps=0.1, momentum=0.9),
         optimizers.rmsprop_momentum(LR, .9, 0.1, 0.9), 1e-5),
        ('adagrad', alias.adagrad(LR_SCHED, 0., 0.),
         optimizers.adagrad(LR, 0.), 1e-5),
    )
    def test_jax_optimizer_equivalent(self, optax_optimizer, jax_optimizer,
                                      rtol):
        # Reference run: experimental/optimizers.py
        opt_init, opt_update, get_params = jax_optimizer
        jax_params = self.init_params
        jax_state = opt_init(jax_params)
        for step_index in range(STEPS):
            jax_state = opt_update(step_index, self.per_step_updates,
                                   jax_state)
            jax_params = get_params(jax_state)

        # optax run, executed under the chex variant selected for this case.
        optax_params = self.init_params
        optax_state = optax_optimizer.init(optax_params)

        @self.variant
        def optax_step(updates, state):
            return optax_optimizer.update(updates, state)

        for _ in range(STEPS):
            deltas, optax_state = optax_step(self.per_step_updates,
                                             optax_state)
            optax_params = update.apply_updates(optax_params, deltas)

        # Both runs must land on the same parameters.
        chex.assert_tree_all_close(jax_params, optax_params, rtol=rtol)
Exemple #2
0
def get_model_and_optimizer(model_params, hyper_params, rng):
    """Build the model and an Adam optimizer initialised on its parameters.

    Args:
        model_params: Model description forwarded to `serial_InvNet`.
        hyper_params (dict): Must provide the keys 'batch_size', 'z_latent'
            and 'lr'.
        rng: Random key handed to the model's init function.

    Returns:
        A 6-tuple `(psi, opt_update, opt_state, get_params_from_opt,
        init_params, opt_init)` where `psi` is the second element returned by
        `serial_InvNet`, `init_params` the freshly initialised parameters,
        `opt_state` the optimizer state built from them, and the remaining
        entries are the functions returned by `adam`.
    """
    # Model: (init function, forward function).
    init_fun, psi = serial_InvNet(model_params, hyper_params)

    # Initial parameters for an input of shape (batch_size, z_latent).
    input_shape = (hyper_params['batch_size'], hyper_params['z_latent'])
    _, init_params = init_fun(rng, input_shape)

    # Optimizer and its initial state.
    opt_init, opt_update, get_params_from_opt = adam(hyper_params['lr'])
    opt_state = opt_init(init_params)

    return (psi, opt_update, opt_state, get_params_from_opt, init_params,
            opt_init)
Exemple #3
0
def test_beta_bernoulli(auto_class, rtol):
    # Eight ones and two zeros; the test expects a fitted median of 0.8.
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample('beta', dist.Beta(1., 1.))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.08)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    observed = (data, )
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=observed,
                                       guide_args=observed)

    def train_step(i, carry):
        # The loss returned by svi_update is not needed here.
        state, key = carry
        _, state, key = svi_update(i,
                                   key,
                                   state,
                                   model_args=observed,
                                   guide_args=observed)
        return state, key

    opt_state, _ = lax.fori_loop(0, 300, train_step, (opt_state, rng_train))
    median = guide.median(opt_state)
    assert_allclose(median['beta'], 0.8, rtol=rtol)
def train(rng: jnp.ndarray, bij_params: Sequence[jnp.ndarray], deq_params: Sequence[jnp.ndarray], num_steps: int, lr: float, num_samples: int) -> Tuple:
    """Fit the bijector and dequantizer parameters with Adam.

    Note that `loss`, `bij_fns` and `deq_fn` are not parameters of this
    function; they are looked up from the enclosing scope.

    Args:
        rng: Pseudo-random number generator seed; a per-step key is derived
            from it with `random.fold_in`.
        bij_params: List of arrays parameterizing the RealNVP bijectors.
        deq_params: Parameters of the mean and scale functions used in
            the log-normal dequantizer.
        num_steps: Number of gradient descent iterations.
        lr: Gradient descent learning rate.
        num_samples: Number of dequantization samples.

    Returns:
        out: A tuple `((bij_params, deq_params), trace)` holding the fitted
            parameters and the per-step loss values collected by the scan.

    """
    opt_init, opt_update, get_params = optimizers.adam(lr)
    # Differentiate the loss w.r.t. its bijector and dequantizer parameters
    # (positional arguments 1 and 3).
    loss_and_grad = value_and_grad(loss, (1, 3))
    sanitize = partial(put.clip_and_zero_nans, clip_value=1.)

    def step(state, it):
        key = random.fold_in(rng, it)
        cur_bij, cur_deq = get_params(state)
        value, grads = loss_and_grad(key, cur_bij, bij_fns, cur_deq, deq_fn, num_samples)
        grads = tree_util.tree_map(sanitize, grads)
        return opt_update(it, grads, state), value

    initial_state = opt_init((bij_params, deq_params))
    final_state, trace = lax.scan(step, initial_state, jnp.arange(num_steps))
    fitted_bij, fitted_deq = get_params(final_state)
    return (fitted_bij, fitted_deq), trace
Exemple #5
0
def test_beta_bernoulli(auto_class):
    # Two columns of ten Bernoulli observations (8/10 and 4/10 ones).
    data = np.array([[1.0] * 8 + [0.0] * 2, [1.0] * 4 + [0.0] * 6]).T

    def model(data):
        f = sample('beta', dist.Beta(np.ones(2), np.ones(2)))
        sample('obs', dist.Bernoulli(f), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = auto_class(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    observed = (data, )
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=observed,
                                       guide_args=observed)

    def train_step(i, carry):
        # The loss returned by svi_update is not needed here.
        state, key = carry
        _, state, key = svi_update(i,
                                   key,
                                   state,
                                   model_args=observed,
                                   guide_args=observed)
        return state, key

    opt_state, _ = fori_loop(0, 1000, train_step, (opt_state, rng_train))

    # Reference value: (column sum + 1) / (number of rows + 2).
    n_obs = data.shape[0]
    true_coefs = (np.sum(data, axis=0) + 1) / (n_obs + 2)

    # test .sample_posterior method
    posterior_samples = guide.sample_posterior(random.PRNGKey(1),
                                               opt_state,
                                               sample_shape=(1000, ))
    posterior_mean = np.mean(posterior_samples['beta'], 0)
    assert_allclose(posterior_mean, true_coefs, atol=0.04)
Exemple #6
0
def main(args):
    """Run SVI on synthetic data and check the fitted guide location.

    Reads `args.learning_rate` and `args.num_steps`.
    """
    # Synthetic data: 100 standard-normal draws shifted by 3.0.
    data = random.normal(PRNGKey(0), shape=(100,)) + 3.0

    # Build the SVI init/update functions for the model/guide pair.
    opt_init, opt_update, get_params = optimizers.adam(args.learning_rate)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update, get_params)
    rng = PRNGKey(0)
    opt_state = svi_init(rng, model_args=(data,))

    # Fresh key for the training loop.
    rng, = random.split(rng, 1)

    def train_step(i, carry):
        # The loss returned by svi_update is discarded.
        state, key = carry
        _, state, key = svi_update(i, state, key, model_args=(data,))
        return state, key

    opt_state, _ = lax.fori_loop(0, args.num_steps, train_step, (opt_state, rng))

    # Print the variational parameters of the guide after training.
    params = get_params(opt_state)
    for name, value in params.items():
        print("{} = {}".format(name, value))

    # The data are centred on 3.0, so the fitted guide location is expected
    # to end up close to that value.
    loc_error = np.abs(params["guide_loc"] - 3.0)
    assert loc_error < 0.1
Exemple #7
0
def get_initial_state(system,
                      rng,
                      generate_x_obs_seq_init,
                      dim_q,
                      tol,
                      adam_step_size=2e-1,
                      reg_coeff=5e-2,
                      coarse_tol=1e-1,
                      max_iters=1000,
                      max_num_tries=10):
    """Find an initial constraint satisfying state.

    Uses a heuristic combination of gradient-based minimisation of the norm
    of a modified constraint function plus a subsequent projection step using a
    quasi-Newton method, to try to find an initial point `q` such that
    `max(abs(constr(q))) < tol`.

    Args:
        system: Object providing `value_and_grad_init_objective`, `constr`
            and `sample_momentum`.
        rng: Random generator; must offer `standard_normal` and is also passed
            to `generate_x_obs_seq_init` and `system.sample_momentum`.
        generate_x_obs_seq_init: Callable taking `rng` and returning the
            initial observation sequence used for one try.
        dim_q: Size of the randomly drawn initial position vector.
        tol: Final tolerance on the maximum absolute constraint value.
        adam_step_size: Step size handed to the Adam optimizer.
        reg_coeff: Coefficient forwarded to
            `system.value_and_grad_init_objective`.
        coarse_tol: Threshold on the maximum constraint norm below which a
            projection is attempted.
        max_iters: Maximum number of Adam iterations per try.
        max_num_tries: Number of random restarts before giving up.

    Returns:
        A `ConditionedDiffusionHamiltonianState` whose constraint values are
        all below `tol` in absolute value, with its momentum sampled.

    Raises:
        RuntimeError: If no valid state was found in `max_num_tries` tries.
    """

    # Use optimizers to set optimizer initialization and update functions
    opt_init, opt_update, get_params = opt.adam(adam_step_size)

    # Define a compiled update step
    @api.jit
    def step(i, opt_state, x_obs_seq_init):
        # The objective and constraint are evaluated at the parameters held
        # in `opt_state` *before* the update is applied.
        q, = get_params(opt_state)
        (obj, constr), grad = system.value_and_grad_init_objective(
            q, x_obs_seq_init, reg_coeff)
        opt_state = opt_update(i, grad, opt_state)
        return opt_state, obj, constr

    # NOTE(review): this function logs through both `logging` and `logger`;
    # both names must be defined at module level - confirm this is intended.
    for t in range(max_num_tries):
        logging.info(f'Starting try {t+1}')
        # Each try restarts from a fresh random position and observations.
        q_init = rng.standard_normal(dim_q)
        x_obs_seq_init = generate_x_obs_seq_init(rng)
        opt_state = opt_init((q_init, ))
        for i in range(max_iters):
            opt_state_next, norm, constr = step(i, opt_state, x_obs_seq_init)
            if not np.isfinite(norm):
                # Abandon this try and move on to the next restart.
                logger.info('Adam iteration diverged')
                break
            max_abs_constr = maximum_norm(constr)
            if max_abs_constr < coarse_tol:
                logging.info('Within coarse_tol attempting projection.')
                # Use the pre-update state: `constr` was computed from it.
                q_init, = get_params(opt_state)
                state = ConditionedDiffusionHamiltonianState(
                    q_init, x_obs_seq=x_obs_seq_init)
                try:
                    state = jitted_solve_projection_onto_manifold_quasi_newton(
                        state, state, 1., system, tol)
                except ConvergenceError:
                    # `state` keeps its unprojected value; the check below
                    # then decides whether it is acceptable anyway.
                    logger.info('Quasi-Newton iteration diverged.')
                if np.max(np.abs(system.constr(state))) < tol:
                    logging.info('Found constraint satisfying state.')
                    state.mom = system.sample_momentum(state, rng)
                    return state
            if i % 100 == 0:
                logging.info(f'Iteration {i: >6}: mean|constr|^2 = {norm:.3e} '
                             f'max|constr| = {max_abs_constr:.3e}')
            # Only now advance to the updated optimizer state.
            opt_state = opt_state_next
    raise RuntimeError(f'Did not find valid state in {max_num_tries} tries.')
def train(params, data, gradients,
          epochs = 100, batch_size = 10, lr = 1e-3, shuffle = True, start = 0,
          lr_decay = 0, lr_decay_steps = 1,
          loggers = [], log_every = 10):
    """Run Adam over mini-batches of `data` for epochs `start`..`epochs - 1`.

    Args:
        params: Initial parameters handed to the optimizer.
        data: Training data; re-shuffled every epoch when `shuffle` is truthy.
        gradients: Gradient function forwarded to `get_opt_update`.
        epochs: Exclusive upper bound of the epoch index.
        batch_size: Batch size forwarded to `batchify`.
        lr: Learning rate (initial rate when `lr_decay` is set).
        shuffle: Whether to call `shuffle_perm` on the data each epoch.
        start: First epoch index.
        lr_decay: When truthy, an inverse-time-decay schedule with this decay
            rate is used instead of a constant learning rate.
        lr_decay_steps: Decay steps for the inverse-time-decay schedule.
        loggers: Callables `log(params, epoch)`; the first two items of each
            result are printed.
        log_every: Logging happens on epochs divisible by this value.

    Returns:
        Always 0.
    """
    schedule = (opt.inverse_time_decay(lr, lr_decay_steps, lr_decay, staircase=False)
                if lr_decay else opt.constant(lr))

    opt_init, opt_update, get_params = opt.adam(schedule)
    update = get_opt_update(get_params, opt_update, gradients)
    opt_state = opt_init(params)

    for epoch in range(start, epochs):
        # TODO: shuffle in-place to reduce memory allocations (first, copy data)
        if shuffle:
            data, _ = shuffle_perm(data)
        for data_batch in batchify(data, batch_size):
            # The epoch index (not a per-batch counter) is the step argument.
            opt_state = update(epoch, opt_state, data_batch)

        if epoch % log_every != 0:
            continue

        params = get_params(opt_state)
        # Run every logger before anything is printed for this epoch.
        entries = [logger_fn(params, epoch) for logger_fn in loggers]
        print(f"Epoch {epoch}", end=" ")
        for entry in entries:
            print(f"[{entry[0]} {float_log(entry[1])}]", end=" ")
        print()

    return 0
Exemple #9
0
def train_model(lr, iters, train_data, test_data, name='', plot_groups=None):
    """Train the module-level model with Adam and record PSNR curves.

    Args:
        lr: Learning rate for Adam.
        iters: Number of optimizer steps.
        train_data: Tuple unpacked into `model_grad_loss` / `model_psnr`; the
            last dimension of its first element sizes the model input.
        test_data: Tuple unpacked into `model_psnr` for evaluation.
        name: Prefix for the plot series names.
        plot_groups: Optional dict with 'Test PSNR' and 'Train PSNR' lists;
            when given, the series names are appended to it and live plots
            are updated through `plotlosses_model`.

    Returns:
        Dict with keys 'state' (final parameters), 'train_psnrs',
        'test_psnrs' and 'xs' (iterations at which PSNR was measured).
    """
    opt_init, opt_update, get_params = optimizers.adam(lr)
    opt_update = jit(opt_update)

    _, params = init_fn(rand_key, (-1, train_data[0].shape[-1]))
    opt_state = opt_init(params)

    plotting = plot_groups is not None
    train_psnrs, test_psnrs, xs = [], [], []
    if plotting:
        plot_groups['Test PSNR'].append(f'{name}_test')
        plot_groups['Train PSNR'].append(f'{name}_train')

    for i in tqdm(range(iters), desc='train iter', leave=False):
        grads = model_grad_loss(get_params(opt_state), *train_data)
        opt_state = opt_update(i, grads, opt_state)

        # Evaluate every 25 iterations.
        if i % 25 == 0:
            current = get_params(opt_state)
            train_psnr = model_psnr(current, *train_data)
            test_psnr = model_psnr(current, *test_data)
            train_psnrs.append(train_psnr)
            test_psnrs.append(test_psnr)
            xs.append(i)
            if plotting:
                plotlosses_model.update({f'{name}_train':train_psnr, f'{name}_test':test_psnr}, current_step=i)

        # Refresh the plot every 100 iterations, skipping iteration 0.
        if plotting and i != 0 and i % 100 == 0:
            plotlosses_model.send()

    if plotting:
        plotlosses_model.send()

    return {
        'state': get_params(opt_state),
        'train_psnrs': train_psnrs,
        'test_psnrs': test_psnrs,
        'xs': xs
    }
Exemple #10
0
def test_svi():
    num_steps = 500

    def model(key):
        prior = dist.normal(jnp.array(10.), jnp.array(10.))
        return func.sample('n1', prior, key)

    def q(params, key):
        variational = dist.normal(params['n1_mean'], params['n1_std'])
        return {'n1': func.sample('n1', variational, key)}

    start_params = {
        'n1_mean': jnp.array(0.),
        'n1_std': jnp.array(1.)
    }
    optimizer = jax_optim.adam(0.05)
    svi = func.svi(model, q, func.elbo, optimizer,
                   initial_params=start_params)

    # One key per optimisation step.
    keys = jax.random.split(jax.random.PRNGKey(123), num_steps)
    for i in range(num_steps):
        loss = svi(keys[i])
        if i % 100 == 0:
            print(f"Step {i}: {loss}")

    # Compare the fitted parameters with the recorded reference values.
    tu.check_close(svi.get_param('n1_mean'), 10.611818)
    tu.check_close(svi.get_param('n1_std'), 9.024648)
Exemple #11
0
def load_model(url):
    """Restore a training state and rebuild the ResNet it belongs to.

    Returns:
        `(model, hparams, params)`: the ResNet built from the restored
        hyper-parameters, those hyper-parameters, and the parameters extracted
        from the restored optimizer state.
    """
    restored = deepx.optimise.TrainState.restore(url)
    hparams = restored.hparams
    model = deepx.resnet.ResNet(hparams.hidden_channels, 1, hparams.depth)
    # The optimizer is built only for its `params_fn`, which unpacks the
    # parameters from the stored optimizer state.
    params = optimizers.adam(0.001).params_fn(restored.opt_state)
    return model, hparams, params
def main():
    """Train a small Softplus MLP with Adam and plot the outcome."""
    hidden_units = 128
    layers = [
        stax.Dense(hidden_units),
        stax.Softplus,
        stax.Dense(hidden_units),
        stax.Softplus,
        stax.Dense(2),
    ]
    net_init, net_apply = stax.serial(*layers)

    opt_init, opt_update, get_params = optimizers.adam(1e-3)

    out_shape, net_params = net_init(jax.random.PRNGKey(seed=42),
                                     input_shape=(-1, 2))
    opt_state = opt_init(net_params)

    print("Training...")

    train_step = get_train_step(opt_update, get_params, net_apply)

    loss_history = []
    for iteration in range(2000):
        batch = sample_batch(size=128)
        loss, opt_state = train_step(iteration, opt_state, batch)
        loss_history.append(loss.item())

    print("Training Finished...")

    # `net_params` passed here are still the initial (untrained) parameters.
    plot_gradients(loss_history, opt_state, get_params, net_params, net_apply)
Exemple #13
0
def fit(vae, rng_key, data, data_vari, step_size=1e-3, max_iter=1000):
    '''
    Optimise the VAE parameters with Adam and keep the lowest-loss ones.

    Args:
      vae: VAE tuple providing `params`, `raw_encode`, `raw_decode` and
        `n_latent_dims`.
      rng_key: random key; split once per iteration.
      data: array like (obs, features)
      data_vari: forwarded to `objective` and stored in the returned VAE.
      step_size: Adam step size.
      max_iter: number of optimisation steps.

    Returns:
      (vae, history): a new VAE built from the parameters that produced the
      lowest loss, and the list of per-iteration loss values as floats.
    '''
    opt_init, update_params, get_params = adam(step_size)
    opt_state = opt_init(vae.params)

    history = []
    # Track the best (lowest) loss seen so far and the parameters behind it.
    best_loss, best_params = 1e10, None
    for i in trange(max_iter, smoothing=0):
        params = get_params(opt_state)
        rng_key, subkey = random.split(rng_key)
        loss, grads = value_and_grad(objective)(params, vae, subkey, data,
                                                data_vari)
        opt_state = update_params(i, grads, opt_state)
        if loss < best_loss:
            best_loss, best_params = loss, params
        history.append(float(loss))

    # Bind the best encoder/decoder parameters into the raw functions.
    fitted = VAE(partial(vae.raw_encode, best_params[0]), vae.raw_encode,
                 partial(vae.raw_decode, best_params[1]), vae.raw_decode,
                 None, None, best_params, vae.n_latent_dims, data_vari)

    fitted = fitted._replace(generate=partial(generate_samples, fitted),
                             fit=partial(fit, fitted))
    return fitted, history
def fit(params: Dict,
        sequences: List[str],
        n: int,
        step_size: float = 0.001) -> Dict:
    """
    Return weights fitted to predict the next letter in each sequence.

    The training loop is as follows.
    Per step in the training loop,
    we loop over each "length batch" of sequences and tune weights
    in order of the length of each sequence.
    For example, if we have sequences of length 302, 305, and 309,
    over K training epochs,
    we will perform 3xK updates,
    one step of updates for each length.

    To get batching of sequences by length done,
    we call on ``length_batch_input_outputs``,
    which is expected to return matching lists of inputs and outputs,
    one entry per sequence length.

    :param params: mLSTM1900 parameters.
    :param sequences: List of sequences to evotune on.
    :param n: The number of iterations to evotune on.
    :param step_size: Step size passed to the Adam optimizer.
    :returns: The parameters after the final update.
    """
    xs, ys = length_batch_input_outputs(sequences)

    init, update, get_params = adam(step_size=step_size)

    @jit
    def step(i, state, x, y):
        """
        Perform one step of evolutionary updating.

        This function is closed inside `fit` because we need access
        to the ``update`` and ``get_params`` functions.

        The batch ``x``/``y`` is passed in explicitly rather than read from
        the enclosing scope: values a jitted function closes over are baked
        into the compiled trace on the first call, so every later call would
        silently keep training on the first batch.
        Passing them as arguments also lets JAX compile one version per
        batch shape.

        :param i: The current iteration of the training loop.
        :param state: Current state of parameters from jax.
        :param x: Inputs of the current length batch.
        :param y: Targets of the current length batch.
        """
        params = get_params(state)
        g = grad(evotune_loss)(params, x, y)
        return update(i, g, state)

    state = init(params)

    for i in range(n):
        for x, y in zip(xs, ys):
            state = step(i, state, x, y)
    return get_params(state)
Exemple #15
0
    def fit(self, X):
        """Train the encoder/decoder pair on `X` with Adam.

        Runs 5 epochs over mini-batches of 32 rows, re-permuting `X` at the
        start of each epoch, and stores the trained parameters back on
        `self.encoder_params` / `self.decoder_params`.

        Args:
            X: Training data; batches are taken along the first axis.
        """
        opt_init, opt_update, get_params = optimizers.adam(step_size=1e-3)
        opt_state = opt_init((self.encoder_params, self.decoder_params))

        def loss(params, inputs):
            # Reconstruction error of the *current batch*: encode the batch,
            # then decode that encoding. The previous version encoded and
            # decoded the closed-over full data set `X` and never fed the
            # encoding to the decoder.
            encoder_params, decoder_params = params
            enc = self.encoder_apply(encoder_params, inputs)
            dec = self.decoder_apply(decoder_params, enc)
            # NOTE(review): `np.abs(params)` is applied to the whole
            # (encoder_params, decoder_params) tuple - this only works if the
            # tuple converts to a single array; confirm the parameter layout.
            return np.square(inputs - dec).sum() + 1e-3 * np.abs(params).sum()

        @jit
        def step(i, opt_state, inputs):
            params = get_params(opt_state)
            gradient = grad(loss)(params, inputs)
            return opt_update(i, gradient, opt_state)

        print('Training autoencoder...')

        # `itercount` supplies a global step index across all epochs.
        batch_size, itercount = 32, itertools.count()
        key = random.PRNGKey(0)
        for epoch in range(5):
            temp_key, key = random.split(key)
            X = random.permutation(temp_key, X)
            for batch_index in range(0, X.shape[0], batch_size):
                opt_state = step(next(itercount), opt_state,
                                 X[batch_index:batch_index + batch_size])

        self.encoder_params, self.decoder_params = get_params(opt_state)
Exemple #16
0
def test_uniform_normal():
    true_coef = 0.9
    # 1000 standard-normal draws shifted by the true coefficient.
    noise = random.normal(random.PRNGKey(0), (1000, ))
    data = true_coef + noise

    def model(data):
        alpha = sample('alpha', dist.Uniform(0, 1))
        loc = sample('loc', dist.Uniform(0, alpha))
        sample('obs', dist.Normal(loc, 0.1), obs=data)

    opt_init, opt_update, get_params = optimizers.adam(0.01)
    rng_guide, rng_init, rng_train = random.split(random.PRNGKey(1), 3)
    guide = AutoDiagonalNormal(rng_guide, model, get_params)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    observed = (data, )
    opt_state, constrain_fn = svi_init(rng_init,
                                       model_args=observed,
                                       guide_args=observed)

    def train_step(i, carry):
        # The loss returned by svi_update is not needed here.
        state, key = carry
        _, state, key = svi_update(i,
                                   key,
                                   state,
                                   model_args=observed,
                                   guide_args=observed)
        return state, key

    opt_state, _ = fori_loop(0, 1000, train_step, (opt_state, rng_train))

    median = guide.median(opt_state)
    assert_allclose(median['loc'], true_coef, rtol=0.05)

    # test .quantile method: index 1 corresponds to the 0.5 quantile.
    quantiles = guide.quantiles(opt_state, [0.2, 0.5])
    assert_allclose(quantiles['loc'][1], true_coef, rtol=0.1)
Exemple #17
0
def test_beta_bernoulli():
    # Eight ones and two zeros; the fitted Beta mean is expected near 0.8.
    data = np.array([1.0] * 8 + [0.0] * 2)

    def model(data):
        f = sample("beta", dist.Beta(1., 1.))
        sample("obs", dist.Bernoulli(f), obs=data)

    def guide():
        alpha_q = param("alpha_q", 1.0, constraint=constraints.positive)
        beta_q = param("beta_q", 1.0, constraint=constraints.positive)
        sample("beta", dist.Beta(alpha_q, beta_q))

    opt_init, opt_update, get_params = optimizers.adam(0.05)
    svi_init, svi_update, _ = svi(model, guide, elbo, opt_init, opt_update,
                                  get_params)
    rng_init, rng_train = random.split(random.PRNGKey(1))
    observed = (data, )
    opt_state, constrain_fn = svi_init(rng_init, model_args=observed)

    def train_step(i, carry):
        # The loss returned by svi_update is not needed here.
        state, key = carry
        _, state, key = svi_update(i, key, state, model_args=observed)
        return state, key

    opt_state, _ = fori_loop(0, 300, train_step, (opt_state, rng_train))

    # Map the optimised values through the constraint transform, then compare
    # the Beta mean alpha / (alpha + beta) with the expected 0.8.
    params = constrain_fn(get_params(opt_state))
    beta_mean = params['alpha_q'] / (params['alpha_q'] + params['beta_q'])
    assert_allclose(beta_mean, 0.8, atol=0.05, rtol=0.05)
Exemple #18
0
def train(rng, params, predict, X, y):
  """Fine-tune `params` on `(X, y)` with a fixed Adam configuration.

  Runs 65 Adam steps with step size 0.01 on batches of 32 drawn from
  `data_stream`, and returns the resulting parameters.
  """
  iterations, batch_size, step_size = 65, 32, 0.01

  opt_init, opt_update, get_params = optimizers.adam(step_size)

  @jit
  def update(_, step_index, state, batch):
    # The first argument (a random key) is accepted but not used.
    grads = grad(loss)(get_params(state), predict, batch)
    return opt_update(step_index, grads, state)

  opt_state = opt_init(params)

  temp, rng = random.split(rng)
  batches = data_stream(rng, batch_size, X, y)
  for step_index in range(iterations):
    temp, rng = random.split(rng)
    batch = next(batches)
    opt_state = update(temp, step_index, opt_state, batch)

  return get_params(opt_state)
Exemple #19
0
def args_to_op(optimizer_string, lr, mom=0.9, var=0.999, eps=1e-7):
    """Build the optimizer named by `optimizer_string` (case-insensitive).

    Supported names: "gd" / "sgd" (uses `lr`), "momentum" (uses `lr`, `mom`)
    and "adam" (uses `lr`, `mom`, `var`, `eps`).

    Raises:
        KeyError: If the lower-cased name is not one of the supported names.
    """
    kind = optimizer_string.lower()
    if kind in ("gd", "sgd"):
        return op.sgd(lr)
    if kind == "momentum":
        return op.momentum(lr, mom)
    if kind == "adam":
        return op.adam(lr, mom, var, eps)
    raise KeyError(kind)
    def fit(self,
            X,
            y,
            batch_size=5,
            n_iter=10000,
            lr=0.001,
            lr_type='constant'):
        """Fit `self.params` to `(X, y)` with mini-batch Adam on the MSE loss.

        Args:
            X: 2-D array-like of inputs; converted to float32.
            y: Targets; reshaped to a float32 column vector.
            batch_size: Number of rows per mini-batch.
            n_iter: Approximate total number of optimizer steps; converted to
                a whole number of epochs.
            lr: Adam step size.
            lr_type: Accepted for interface compatibility; not used here.

        Side effects:
            Updates `self.params` after every batch and prints the full-data
            cost after every epoch.
        """
        X = np.array(X).astype(np.float32)
        y = np.array(y).reshape(-1, 1).astype(np.float32)

        m, _ = X.shape

        opt_init, opt_update, get_params = optimizers.adam(lr)
        opt_state = opt_init(self.params)

        # At least one batch per epoch, so data sets smaller than one batch do
        # not trigger a division by zero.
        batches_per_epoch = max(1, floor(m / batch_size))
        epochs = ceil(n_iter / batches_per_epoch)

        # Global step counter: Adam uses the step index for its bias
        # correction, so it must advance with every update instead of staying
        # at 0.
        step = 0
        for i in range(epochs):

            for j in range(0, m, batch_size):

                X_batch = X[j:j + batch_size]
                y_batch = y[j:j + batch_size]

                _, grads = value_and_grad(self.mse_loss)(self.params,
                                                         X_batch,
                                                         y_batch)
                opt_state = opt_update(step, grads, opt_state)
                step += 1
                self.params = get_params(opt_state)

            print("Epoch: ", i, end=" ")
            print("cost: ", self.mse_loss(self.params, X, y))
Exemple #21
0
def main():
    """Train the fully convolutional network on MNIST with Adam."""
    dataset = tfds.load('mnist',  shuffle_files=True)
    tr = iter(dataset['train'].batch(128).prefetch(1).repeat())
    ts = iter(dataset['test'].batch(1).repeat())

    ## Container that holds the network weights.
    weights_container = []

    ## Build the CNN.
    # When building the network, avoid `if` statements inside subroutines
    # where possible: an `if` may break the computation graph. For that
    # reason weight initialization and the forward pass are kept as two
    # separate functions, which also makes jit acceleration usable.
    feature_map_nos = [128, 256, 512, 10]
    input_shape = [1, 28, 28, 1]
    jax_fcn_init(weights_container, input_shape,feature_map_nos, 3)
    print('weights initialized ...')

    ##### Creating the optimizer #####
    # 1) jax.experimental.optimizer has to be imported into the main
    #    namespace, hence the `from ... import` form.
    # 2) Calling the optimizer returns three functions:
    #    2.1) opt_init: initializes the optimizer. It has to be told which
    #         variables to optimize, hence the `opt_init(param)` usage, where
    #         `param` are the parameters this optimizer works on. It creates
    #         the matching optimizer-side state (e.g. momentum or the current
    #         beta); that state has to be kept so the next update can use it.
    #    2.2) opt_update: takes the gradients and the optimizer state and
    #         updates the weights. The state includes the remembered values
    #         such as momentum.
    #    2.3) opt_param: returns the weights being optimized.

    ## Create the optimizer and get its three functions.
    learning_rate = 1e-4
    opt_init, opt_update, opt_get_params = optimizers.adam(learning_rate)

    ## Initialize the optimizer state from the weights.
    opt_status = opt_init(weights_container)

    ##### Training loop #####
    for training_step in range(5000):
        batch = next(tr)
        image = batch['image'].numpy().astype('float32')
        label = batch['label'].numpy().astype('float32')

        ## Compute the loss and gradients, then apply one optimizer step.
        current_loss, gradients = jax.value_and_grad(loss)(weights_container, image, one_hot(label))
        opt_status = opt_update(training_step, gradients, opt_status)
        weights_container = opt_get_params(opt_status)

        batch_accuracy = accuracy(image, weights_container, label)
        print('step {} loss:{} accuracy:{}'.format(training_step, current_loss, batch_accuracy))
Exemple #22
0
def initialize_optimizers(policy_net_params, value_net_params):
    """Initialize optimizers for the policy and value params.

    Returns:
        `((ppo_opt_state, ppo_opt_update), (value_opt_state,
        value_opt_update))`.
    """
    # Both networks use Adam with identical hyper-parameters.
    adam_kwargs = dict(step_size=1e-3, b1=0.9, b2=0.999, eps=1e-08)

    # NOTE(review): two values are unpacked from `optimizers.adam` - this
    # assumes an API version that returns an (init, update) pair; confirm.
    ppo_opt_init, ppo_opt_update = optimizers.adam(**adam_kwargs)
    value_opt_init, value_opt_update = optimizers.adam(**adam_kwargs)

    policy_side = (ppo_opt_init(policy_net_params), ppo_opt_update)
    value_side = (value_opt_init(value_net_params), value_opt_update)
    return policy_side, value_side
Exemple #23
0
def test_optim_multi_params():
    # Two named parameter vectors, optimised together as one dict.
    params = {'x': np.array([1., 1., 1.]), 'y': np.array([-1, -1., -1.])}
    opt_init, opt_update, get_params = optimizers.adam(step_size=1e-2)
    opt_state = opt_init(params)
    for iteration in range(1000):
        opt_state = step(iteration, opt_state, opt_update, get_params)
    # Every parameter vector is expected to have reached zero.
    final_params = get_params(opt_state)
    for param in final_params.values():
        assert np.allclose(param, np.zeros(3))
Exemple #24
0
def objective_funcs(func, x0):
    """Set up Adam (step size 1e-2) for minimising `func` starting at `x0`.

    Returns:
        `(state, myupdate, get_params)`: the initial optimizer state, a
        function `myupdate(i, state)` performing one gradient step on `func`,
        and the function extracting parameters from a state.
    """
    init, update, get_params = optimizers.adam(1e-2)

    def myupdate(i, state):
        current = get_params(state)
        gradient = jax.grad(func)(current)
        return update(i, gradient, state)

    initial_state = init(x0)
    return initial_state, myupdate, get_params
Exemple #25
0
 def __init__(self, activation=Relu):
     """Build the network, an Adam optimizer and the initial parameters."""
     self.activation = activation

     # Network init/apply pair for the chosen activation.
     net_init, net_apply = network(activation)
     self.net_init = net_init
     self.net_apply = net_apply

     self.optimizer = optimizers.adam(step_size=1e-3)

     # Fixed seed, so the initial parameters are reproducible.
     init_key = random.PRNGKey(0)
     self.out_shape, self.net_params = self.net_init(init_key, (-1, 1,))
Exemple #26
0
def minimize_ADAM(fun, x0, data, step_size=0.01, f_tol=1e-5, n_iter_max=1000):
    """Minimize function with ADAM.

    Each iteration applies one Adam update and then re-evaluates `fun` on a
    freshly permuted copy of `data`. The loop stops once the absolute change
    of the loss between consecutive iterations is no larger than `f_tol`, or
    after `n_iter_max` iterations.

    Args:
    fun: function to minimize, which takes (x, data) as input and returns a
        scalar
    x0: ndarray: initial solution
    data: ndarray: full data; its rows are permuted before every evaluation

    #Optimizer hyperparams
    step_size: positive scalar step size of the Adam optimizer
    f_tol: positive scalar; terminate when the change of f drops to this value
    n_iter_max: maximum number of Adam steps before termination

    Returns:
    x_opt: Optimized x
    loss: Loss value at x_opt (on the last permutation of the data)
    n_iter: Number of iterations at termination
    delta_f: Value of change of f at termination (inf if no step was taken)
    """
    value_and_grad_fun = value_and_grad(fun)
    opt_init, opt_update, get_params = adam(step_size=step_size)

    # Index with traced permutations below, so make sure this is a JAX array.
    data = jnp.asarray(data)
    # Static Python int; kept out of the loop carry so it is never traced.
    n = jnp.shape(data)[0]

    def step(carry):
        i, prev_loss, _, g, opt_state, key = carry
        # Use a fresh subkey for the permutation and carry the new key on.
        key, subkey = random.split(key)
        opt_state = opt_update(i, g, opt_state)
        ind = random.permutation(subkey, n)
        loss, g = value_and_grad_fun(get_params(opt_state), data[ind])
        return i + 1, loss, jnp.abs(loss - prev_loss), g, opt_state, key

    def keep_going(carry):
        i, _, delta_f, _, _, _ = carry
        return jnp.logical_and(delta_f > f_tol, i < n_iter_max)

    # Loss and gradient at the starting point seed the carry; delta_f starts
    # at infinity so that at least one step is attempted.
    loss0, g0 = value_and_grad_fun(x0, data)
    carry = (jnp.zeros((), dtype=jnp.int32), loss0,
             jnp.full_like(loss0, jnp.inf), g0, opt_init(x0),
             random.PRNGKey(0))

    n_iter, loss, delta_f, _, opt_state, _ = while_loop(keep_going, step, carry)

    return get_params(opt_state), loss, n_iter, delta_f
Exemple #27
0
def optimiseCrossover(highFileName='data/[email protected]',
                      highDriverName='AN25',
                      lowFileName='data/[email protected]',
                      lowDriverName='TCP115',
                      dataDir='data',
                      plotName='opt',
                      learningRate=1E-2,
                      epochs=25):
    """Tune three crossover parameters with Adam and plot the result.

    The parameters (res, cap, highRes) are optimised as logarithms against
    ``crossover.flatness``. After ``epochs`` updates they are exponentiated
    and looked up in the component catalogues found under ``dataDir``;
    element ``[1]`` of each lookup result is passed to
    ``crossover.applyCrossover``. The three returned curves are plotted
    against ``crossover.frequencies`` and saved as ``<plotName>.pdf``.
    """
    crossover = setupDriverCrossover(highFileName, highDriverName,
                                     lowFileName, lowDriverName)

    # Gradient of the flatness objective w.r.t. all three positional arguments.
    gradFn = jax.grad(crossover.flatness, argnums=[0, 1, 2])
    init_fun, update_fun, get_params = adam(learningRate)

    # Starting point, kept in log space: (res, cap, highRes).
    logParams = (np.log(5E6), np.log(2E-12), np.log(1))
    state = init_fun(logParams)

    losses = []
    for epoch in tqdm(range(epochs)):
        state = update_fun(epoch, gradFn(*logParams), state)
        logParams = get_params(state)
        losses.append(crossover.flatness(*logParams))

    res, cap, highRes = logParams

    catalogue = AvailableComponents(f'{dataDir}/resistors.json',
                                    f'{dataDir}/capacitors.json')

    # Leave log space before querying the catalogue.
    nearestRes = catalogue.nearestRes(np.exp(res))
    nearestHighRes = catalogue.nearestRes(np.exp(highRes))
    nearestCap = catalogue.nearestCap(np.exp(cap))

    co, lo, hi = crossover.applyCrossover(nearestRes[1], nearestCap[1],
                                          nearestHighRes[1])

    for curve, label in ((co, 'Total'), (hi, 'High'), (lo, 'Low')):
        plt.plot(crossover.frequencies, curve, label=label)

    plt.xscale('log')
    plt.legend(loc=0, fontsize=18)
    plt.xlabel('Frequency [Hz]', fontsize=16)
    plt.ylabel('Sound pressure level [dB]', fontsize=16)

    plt.savefig(f'{plotName}.pdf')
    plt.clf()
    def __init__(self):
        """Create an empty ``data`` list, the network and its Adam optimizer."""
        super(Policy, self).__init__()
        self.data = []

        # Flatten -> Linear(128) + ReLU -> Linear(2) + softmax
        layers = [
            hk.Flatten(),
            hk.Linear(128),
            jax.nn.relu,
            hk.Linear(2),
            jax.nn.softmax,
        ]
        self.net = hk.Sequential(layers)

        # learning_rate is resolved from the enclosing scope, not an argument.
        self.optimizer = adam(step_size=learning_rate)
Exemple #29
0
def get_model_and_optimizer(model_params, hyper_params, rng):
    """Build the model via ``serial`` and an Adam optimizer for its parameters.

    Args:
        model_params: passed unchanged to ``serial``.
        hyper_params: mapping read for the keys 'batch_size', 'z_latent'
            and 'lr'.
        rng: random key handed to the model's init function.

    Returns:
        Tuple ``(psi, g, opt_update, opt_state, get_params_from_opt,
        init_params, opt_init)``.
    """
    init_fun, psi, g = serial(model_params)

    # Input shape used for initialisation: (batch, z_latent + 4).
    input_shape = (hyper_params['batch_size'], hyper_params['z_latent'] + 4)
    _, init_params = init_fun(rng, input_shape)

    learning_rate = hyper_params['lr']
    opt_init, opt_update, get_params_from_opt = adam(learning_rate)
    opt_state = opt_init(init_params)

    return (psi, g, opt_update, opt_state, get_params_from_opt, init_params,
            opt_init)
Exemple #30
0
    def optimize_params(self, initial_params, num_iters, step_size, tolerance,
                        verbal):
        """Minimise ``self.cost`` with Adam, stopping early when progress stalls.

        Args:
            initial_params: starting parameters handed to the optimizer.
            num_iters: maximum number of Adam steps. If it is zero or
                negative, ``initial_params`` is returned unchanged.
            step_size: Adam step size.
            tolerance: number of most recent cost changes inspected by the
                two early-stopping rules.
            verbal: falsy for silent operation; otherwise the cost is
                printed every ``int(verbal)`` iterations.

        Returns:
            The selected parameters: the oldest entry of the inspected window
            if the cost rose on every step of it, otherwise the most recent
            parameters.
        """
        # Nothing to optimise: without this guard the for/else branch below
        # would index an empty history and raise IndexError.
        if num_iters <= 0:
            return initial_params

        opt_init, opt_update, get_params = optimizers.adam(step_size=step_size)
        opt_state = opt_init(initial_params)

        @jit
        def step(i, opt_state):
            p = get_params(opt_state)
            g = grad(self.cost)(p)
            return opt_update(i, g, opt_state)

        # Sliding window over the most recent costs and parameters.
        cost_list = []
        params_list = []

        if verbal:
            print('{0}\t{1}\t'.format('Iter', 'Cost'))

        for i in range(num_iters):

            opt_state = step(i, opt_state)
            params_list.append(get_params(opt_state))
            cost_list.append(self.cost(params_list[-1]))

            if verbal:
                if i % int(verbal) == 0:
                    print('{0}\t{1:.3f}\t'.format(i, cost_list[-1]))

            if len(params_list) > tolerance:
                costs = np.array(cost_list)
                # Positive entries mean the cost went up on that step.
                increase = costs[1:] - costs[:-1]

                if np.all(increase > 0):
                    # Cost only got worse inside the window: roll back to
                    # the parameters at the start of the window.
                    params = params_list[0]
                    if verbal:
                        print(
                            'Stop at {} steps: cost has been monotonically increasing for {} steps.'
                            .format(i, tolerance))
                    break
                elif np.all(-increase < 1e-5):
                    # No step in the window lowered the cost by 1e-5 or more.
                    params = params_list[-1]
                    if verbal:
                        print(
                            'Stop at {} steps: cost has been changing less than 1e-5 for {} steps.'
                            .format(i, tolerance))
                    break
                else:
                    # Keep the window at a fixed length.
                    params_list.pop(0)
                    cost_list.pop(0)
        else:
            # Loop finished without triggering either stopping rule.
            params = params_list[-1]
            if verbal:
                print('Stop: reached {} steps, final cost={}.'.format(
                    num_iters, cost_list[-1]))

        return params