def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
    """Minimize ``f`` starting from ``x`` using momentum gradient descent.

    Args:
      f: objective function; it is differentiated with ``grad(f)``, so it
        must return a scalar.
      x: initial parameters, passed to the optimizer's ``opt_init``.
      num_steps: number of optimizer updates to run.
      step_size: learning rate forwarded to ``minmax.momentum``.
      mass: momentum coefficient forwarded to ``minmax.momentum``.

    Returns:
      The parameters stored in the optimizer state after ``num_steps`` updates.
    """
    opt_init, opt_update = minmax.momentum(step_size, mass)

    @jit
    def update(i, opt_state):
        # One compiled step: gradient of f at the current parameters, then a
        # momentum update of the optimizer state.
        params = minmax.get_params(opt_state)
        return opt_update(i, grad(f)(params), opt_state)

    opt_state = opt_init(x)
    # `range` rather than the Python-2-only `xrange`, matching the rest of
    # the file; it iterates identically on both interpreters.
    for i in range(num_steps):
        opt_state = update(i, opt_state)
    return minmax.get_params(opt_state)
# NOTE(review): this line begins inside a callback body whose `def` lies
# above this chunk. The statements up to `plt.pause(1.0 / 60.0)` read `t`,
# `params`, `ax` and `num_samples` from that enclosing scope (presumably
# `callback(params, t)`, which the main loop below calls -- confirm).
# Plotting part: split `random.PRNGKey(t)` into `num_samples` keys, then draw
# one sample per key with `vmap(diag_gaussian_sample, in_axes=(0, None, None))`.
# Only the keys are mapped; the two entries of `params` are shared. Columns 0
# and 1 of the samples are plotted, and the figure is redrawn with a
# ~1/60 s pause.
# Optimizer setup: D = 2, zero `init_mean` and zero `init_std`. Whether
# `init_std` is a standard deviation or a log-std depends on
# `diag_gaussian_sample` (TODO confirm). A `minmax.momentum` optimizer is
# used with step_size=0.1 and mass=0.9.
# Main loop: 100 jitted steps, each applying `grad(objective)(params, i)` to
# the optimizer state, followed by `callback(params, t)` after every step.
# Here we clone the rng used in computing the objective # so that we can show exactly the same samples. rngs = random.split(random.PRNGKey(t), num_samples) samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs, *params) ax.plot(samples[:, 0], samples[:, 1], 'b.') plt.draw() plt.pause(1.0 / 60.0) # Set up optimizer. D = 2 init_mean = np.zeros(D) init_std = np.zeros(D) init_params = (init_mean, init_std) opt_init, opt_update = minmax.momentum(step_size=0.1, mass=0.9) opt_state = opt_init(init_params) @jit def update(i, opt_state): params = minmax.get_params(opt_state) gradient = grad(objective)(params, i) return opt_update(i, gradient, opt_state) # Main loop. print("Optimizing variational parameters...") for t in range(100): opt_state = update(t, opt_state) params = minmax.get_params(opt_state) callback(params, t)
train_images, train_labels, test_images, test_labels = datasets.mnist()
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
# A trailing partial batch counts as one extra batch per pass.
num_batches = num_complete_batches + (1 if leftover else 0)


def data_stream():
    """Yield ``(images, labels)`` training minibatches forever.

    Each pass draws a fresh permutation of the training indices from a
    ``npr.RandomState`` seeded with 0, so the batch sequence is reproducible.
    The final batch of a pass is shorter when ``num_train`` is not a
    multiple of ``batch_size``.
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for batch_num in range(num_batches):
            lo = batch_num * batch_size
            chosen = order[lo:lo + batch_size]
            yield train_images[chosen], train_labels[chosen]


batches = data_stream()
opt_init, opt_update = minmax.momentum(step_size, mass=momentum_mass)


@jit
def update(i, opt_state, batch):
    """Run one compiled momentum step on ``batch`` at step count ``i``."""
    current_params = minmax.get_params(opt_state)
    batch_grads = grad(loss)(current_params, batch)
    return opt_update(i, batch_grads, opt_state)


# Input shape: leading -1 presumably leaves the batch dimension open;
# images are flattened to 28 * 28 features.
_, init_params = init_random_params((-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
        opt_state = update(next(itercount), opt_state, next(batches))