Example no. 1
def evaluate(opt_state, images):
  # test_rng, elbo, image_sample, nrow, ncol come from the enclosing script.
  params = minmax.get_params(opt_state)
  elbo_rng, data_rng, image_rng = random.split(test_rng, 3)
  # Stochastically binarize the test images, using pixel intensities as probabilities.
  binarized_test = random.bernoulli(data_rng, images)
  # Report the ELBO per test example.
  test_elbo = elbo(elbo_rng, params, binarized_test) / images.shape[0]
  # Draw an nrow x ncol grid of samples from the model.
  sampled_images = image_sample(image_rng, params, nrow, ncol)
  return test_elbo, sampled_images
Example no. 2
def body_fun(i, rng__opt_state__images):
  (rng, opt_state, images) = rng__opt_state__images
  rng, elbo_rng, data_rng = random.split(rng, 3)
  # binarize_batch is a helper from the enclosing script: the i-th minibatch, binarized.
  batch = binarize_batch(data_rng, i, images)
  # The negative per-example ELBO is the loss to minimize.
  loss = lambda params: -elbo(elbo_rng, params, batch) / batch_size
  g = grad(loss)(minmax.get_params(opt_state))
  return rng, opt_update(i, g, opt_state), images
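The (index, carry) signature of body_fun matches jax.lax.fori_loop. A minimal sketch of how such a body could be driven for one epoch, assuming a num_batches variable and the (rng, opt_state, images) carry from the surrounding program:

from jax import lax

def run_epoch(rng, opt_state, images):
  # Fold body_fun over all minibatches; (rng, opt_state, images) is the loop carry.
  _, opt_state, _ = lax.fori_loop(0, num_batches, body_fun, (rng, opt_state, images))
  return opt_state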
Example no. 3
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
    opt_init, opt_update = minmax.momentum(step_size, mass)

    @jit
    def update(i, opt_state):
        x = minmax.get_params(opt_state)
        return opt_update(i, grad(f)(x), opt_state)

    opt_state = opt_init(x)
    for i in range(num_steps):  # range, not the Python 2-only xrange
        opt_state = update(i, opt_state)
    return minmax.get_params(opt_state)
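A minimal usage sketch of the helper above on a simple quadratic, assuming the same early jax.experimental.minmax module (later renamed jax.experimental.optimizers) and jax imports used throughout these snippets:

import jax.numpy as jnp
from jax import grad, jit
import jax.experimental.minmax as minmax

def quadratic(x):
    return jnp.sum((x - 3.0) ** 2)

x_min = minimize(quadratic, jnp.zeros(3), num_steps=500, step_size=0.1)
print(x_min)  # expected to approach [3. 3. 3.]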
Example no. 4
from scipy.optimize import OptimizeResult  # used for the scipy-style result below


def sgd(fun,
        params_init,
        args=(),
        jac=None,
        callback=None,
        maxiter=100,
        **options):
    if jac is None:
        jac = grad(fun)

    # The original references opt_init/opt_update from enclosing scope; here we
    # build a plain SGD optimizer from a step_size passed via **options
    # (the option name is an assumption, not part of the original snippet).
    step_size = options.get("step_size", 1e-3)
    opt_init, opt_update = optimizers.sgd(step_size)

    def step(i, opt_state):
        params = optimizers.get_params(opt_state)
        g = jac(params)
        return opt_update(i, g, opt_state)

    opt_state = opt_init(params_init)
    for i in range(maxiter):
        opt_state = step(i, opt_state)
        if callback is not None:
            params = optimizers.get_params(opt_state)
            if callback(params):  # a truthy return value stops optimization early
                break

    # Package the final parameters in a scipy-style result object.
    result = OptimizeResult(x=optimizers.get_params(opt_state))
    return result
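A usage sketch for this scipy-style wrapper, assuming step_size is accepted through **options as shown above:

import jax.numpy as jnp

def rosenbrock(w):
    # Classic 2-D Rosenbrock test function.
    return (1.0 - w[0]) ** 2 + 100.0 * (w[1] - w[0] ** 2) ** 2

res = sgd(rosenbrock, jnp.zeros(2), maxiter=500, step_size=1e-3)
print(res.x)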
Example no. 5
 def update(i, opt_state):
     params = minmax.get_params(opt_state)
     gradient = grad(objective)(params, i)
     return opt_update(i, gradient, opt_state)
Example no. 6
        # Here we clone the rng used in computing the objective
        # so that we can show exactly the same samples.
        rngs = random.split(random.PRNGKey(t), num_samples)
        samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs,
                                                                      *params)
        ax.plot(samples[:, 0], samples[:, 1], 'b.')

        plt.draw()
        plt.pause(1.0 / 60.0)

    # Set up optimizer.
    D = 2
    init_mean = np.zeros(D)
    init_std = np.zeros(D)
    init_params = (init_mean, init_std)
    opt_init, opt_update = minmax.momentum(step_size=0.1, mass=0.9)
    opt_state = opt_init(init_params)

    @jit
    def update(i, opt_state):
        params = minmax.get_params(opt_state)
        gradient = grad(objective)(params, i)
        return opt_update(i, gradient, opt_state)

    # Main loop.
    print("Optimizing variational parameters...")
    for t in range(100):
        opt_state = update(t, opt_state)
        params = minmax.get_params(opt_state)
        callback(params, t)
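The loop above relies on a reparameterized diagonal-Gaussian sampler that is not shown. A plausible definition, consistent with the vmap(..., in_axes=(0, None, None))(rngs, *params) call above, is sketched below; treating the second parameter as a log standard deviation is an assumption, not taken from the original file:

import jax.numpy as jnp
from jax import random

def diag_gaussian_sample(rng, mean, log_std):
    # Reparameterization trick: eps ~ N(0, I), then shift and scale.
    return mean + jnp.exp(log_std) * random.normal(rng, mean.shape)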
Example no. 7
 def update(i, opt_state):
     x = minmax.get_params(opt_state)
     return opt_update(i, grad(f)(x), opt_state)
Example no. 8
 def update(opt_state, step_size):
     _, update_fun = minmax.sgd(step_size)
     x = minmax.get_params(opt_state)
     g = grad(loss)(x, None)
     return update_fun(0, g, opt_state)
Example no. 9
 def update(i, opt_state, batch):
   params = minmax.get_params(opt_state)
   return opt_update(i, grad(loss)(params, batch), opt_state)
Example no. 10
 def step(i, opt_state):
     params = optimizers.get_params(opt_state)
     g = jac(params)
     return opt_update(i, g, opt_state)  #opt_update
Example no. 11
def step(i, opt_state):
    weights = optimizers.get_params(opt_state)
    g = train_gradient_fun(weights)
    return opt_update(i, g, opt_state)
Example no. 12
import itertools  # used for the iteration counter below

from jax import jit
import jax.experimental.minmax as optimizers


@jit
def step(i, opt_state):
    weights = optimizers.get_params(opt_state)
    g = train_gradient_fun(weights)
    return opt_update(i, g, opt_state)


opt_init, opt_update = optimizers.sgd(step_size=lr)
opt_state = opt_init(weights_init)
print("jax SGD")
for i in range(max_iter):
    opt_state = step(i, opt_state)
weights_final2 = optimizers.get_params(opt_state)
print("Trained loss2: {:0.2f}".format(loss(weights_final2, train_data)))

################
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.optimize.minimize.html
from scipy.optimize import minimize

lbfgs_memory = 10
opts = {"maxiter": max_iter, "maxcor": lbfgs_memory}
iter = itertools.count()


def callback(x):
    print("iter {}, params {}".format(next(iter), x))
    return False
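The script breaks off after setting up the L-BFGS options and the callback. A hedged sketch of the call it appears to be building toward, reusing loss, train_data, weights_init, and train_gradient_fun from earlier in the script (the method and keyword choices here are assumptions):

result = minimize(
    lambda w: loss(w, train_data),  # same objective as the SGD run above
    weights_init,
    jac=train_gradient_fun,
    method="L-BFGS-B",              # maxcor is an L-BFGS-B option, hence this guess
    callback=callback,
    options=opts,
)
print("Trained loss (L-BFGS): {:0.2f}".format(loss(result.x, train_data)))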