Example #1
from jax import jit, grad
from jax.experimental import minmax


def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
    # Run momentum gradient descent on f, starting from x.
    opt_init, opt_update = minmax.momentum(step_size, mass)

    @jit
    def update(i, opt_state):
        x = minmax.get_params(opt_state)
        return opt_update(i, grad(f)(x), opt_state)

    opt_state = opt_init(x)
    for i in range(num_steps):
        opt_state = update(i, opt_state)
    return minmax.get_params(opt_state)
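A minimal usage sketch, not part of the original snippet: the quadratic objective and its minimum below are purely illustrative, and it assumes a JAX version old enough to still ship jax.experimental.minmax.

import jax.numpy as np

def quadratic(x):
    # Illustrative objective with its minimum at (1.0, -2.0).
    return np.sum((x - np.array([1.0, -2.0])) ** 2)

x_min = minimize(quadratic, np.zeros(2), num_steps=1000, step_size=0.01)
print(x_min)  # approaches [ 1. -2.]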
Example #2
        # Re-create the rngs used in computing the objective
        # so that we can plot exactly the same samples.
        rngs = random.split(random.PRNGKey(t), num_samples)
        samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs,
                                                                      *params)
        ax.plot(samples[:, 0], samples[:, 1], 'b.')

        plt.draw()
        plt.pause(1.0 / 60.0)

    # Set up optimizer. The variational parameters are a mean vector and a
    # second vector that diag_gaussian_sample presumably treats as a log-std
    # (see the sketch after this example); zeros then give unit standard
    # deviation rather than a degenerate distribution.
    D = 2
    init_mean = np.zeros(D)
    init_std = np.zeros(D)
    init_params = (init_mean, init_std)
    opt_init, opt_update = minmax.momentum(step_size=0.1, mass=0.9)
    opt_state = opt_init(init_params)

    @jit
    def update(i, opt_state):
        params = minmax.get_params(opt_state)
        gradient = grad(objective)(params, i)
        return opt_update(i, gradient, opt_state)

    # Main loop.
    print("Optimizing variational parameters...")
    for t in range(100):
        opt_state = update(t, opt_state)
        params = minmax.get_params(opt_state)
        callback(params, t)
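This fragment relies on the original script's imports (jax.numpy as np, matplotlib.pyplot as plt, and jit, grad, vmap, random from jax) plus a reparameterized sampler named diag_gaussian_sample. A minimal sketch of that sampler, under the assumption that its second parameter is a log standard deviation:

import jax.numpy as np
from jax import random

def diag_gaussian_sample(rng, mean, log_std):
    # Reparameterization trick: draw eps ~ N(0, I), then shift and scale.
    return mean + np.exp(log_std) * random.normal(rng, mean.shape)

Sampling through mean and log_std this way keeps each draw differentiable with respect to the variational parameters, which is what lets grad(objective) work in the update step above.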
Example #3
  # This fragment relies on the original script's imports: time, itertools,
  # numpy.random as npr, a datasets module with an mnist() loader, and
  # jit, grad, and minmax from jax / jax.experimental.
  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    # Yield reshuffled minibatches forever; the training loop below
    # draws num_batches of them per epoch.
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  opt_init, opt_update = minmax.momentum(step_size, mass=momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = minmax.get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params((-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  print("\nStarting training...")
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
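The snippet also presupposes a model and a loss defined earlier in the original script. A minimal sketch under the assumption of a stax-style MLP with a cross-entropy loss; the layer sizes here are illustrative, not taken from the snippet:

import jax.numpy as np
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

# Hypothetical model: a small MLP over flattened 28*28 images.
init_random_params, predict = stax.serial(
    Dense(1024), Relu,
    Dense(1024), Relu,
    Dense(10), LogSoftmax)

def loss(params, batch):
    # Mean cross-entropy between predicted log-probabilities
    # and one-hot targets.
    inputs, targets = batch
    preds = predict(params, inputs)
    return -np.mean(np.sum(preds * targets, axis=1))

Note that predict returns log-probabilities because the final layer is LogSoftmax, which is why the loss is a plain dot product with the one-hot targets rather than an explicit log.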