Example 1
  def testUnpackPackRoundTrip(self):
    opt_init, _, _ = optimizers.momentum(0.1, mass=0.9)
    params = [{'w': onp.random.randn(1, 2), 'bias': onp.random.randn(2)}]
    expected = opt_init(params)
    ans = optimizers.pack_optimizer_state(
        optimizers.unpack_optimizer_state(expected))
    self.assertEqual(ans, expected)
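Most of the snippets in this collection unpack the (init_fun, update_fun, get_params) triple returned by optimizers.momentum. As a point of reference, below is a minimal, self-contained sketch of the imports they appear to assume: optimizers is taken to be jax.example_libraries.optimizers (formerly jax.experimental.optimizers) and onp is plain NumPy. It also spells out the pack/unpack round trip that the test above checks.

import numpy as onp
from jax.example_libraries import optimizers

# momentum(step_size, mass) returns an (init_fun, update_fun, get_params) triple.
opt_init, opt_update, get_params = optimizers.momentum(0.1, mass=0.9)
params = [{'w': onp.random.randn(1, 2), 'bias': onp.random.randn(2)}]
state = opt_init(params)

# unpack_optimizer_state turns the opaque optimizer state into a pytree of
# named tuples that can be serialized; pack_optimizer_state is its inverse.
unpacked = optimizers.unpack_optimizer_state(state)
repacked = optimizers.pack_optimizer_state(unpacked)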
Example 2
def args_to_op(optimizer_string, lr, mom=0.9, var=0.999, eps=1e-7):
    return {
        "gd": lambda lr, *unused: op.sgd(lr),
        "sgd": lambda lr, *unused: op.sgd(lr),
        "momentum": lambda lr, mom, *unused: op.momentum(lr, mom),
        "adam": lambda lr, mom, var, eps: op.adam(lr, mom, var, eps),
    }[optimizer_string.lower()](lr, mom, var, eps)
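A hedged usage sketch for the dispatcher above, assuming op is an alias for the JAX optimizers module (e.g. import jax.example_libraries.optimizers as op); arguments an optimizer does not need are swallowed by the *unused catch-all.

import jax.example_libraries.optimizers as op

# Hypothetical call: yields the (init_fun, update_fun, get_params) triple for momentum.
opt_init, opt_update, get_params = args_to_op("momentum", lr=0.1, mom=0.9)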
Example 3
def main(unused_argv):
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            permute_train=True)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(
        FLAGS.learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    steps_per_epoch = 50000 // FLAGS.batch_size

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1

    # Print out summary data comparing the linear / nonlinear model.
    x, y = x_train[:10000], y_train[:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', y_test, f(params, x_test),
                       f_lin(params_lin, x_test), loss)
Example 4
  def testNTKGDPrediction(self, train_shape, test_shape, network, out_logits,
                          fn_and_kernel, momentum, learning_rate, t, loss):
    key, x_test, x_train, y_train = self._get_inputs(out_logits, test_shape,
                                                     train_shape)

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)

    g_dd = ntk(x_train, None, 'ntk')
    g_td = ntk(x_test, x_train, 'ntk')

    # Regress to an MSE loss.
    loss_fn = lambda y, y_hat: 0.5 * np.mean((y - y_hat)**2)
    grad_loss = jit(grad(lambda params, x: loss_fn(f(params, x), y_train)))

    trace_axes = () if g_dd.ndim == 4 else (-1,)
    if loss == 'mse_analytic':
      if momentum is not None:
        raise absltest.SkipTest(momentum)
      predictor = predict.gradient_descent_mse(g_dd, y_train,
                                               learning_rate=learning_rate,
                                               trace_axes=trace_axes)
    elif loss == 'mse':
      predictor = predict.gradient_descent(loss_fn, g_dd, y_train,
                                           learning_rate=learning_rate,
                                           momentum=momentum,
                                           trace_axes=trace_axes)
    else:
      raise NotImplementedError(loss)

    predictor = jit(predictor)

    fx_train_0 = f(params, x_train)
    fx_test_0 = f(params, x_test)

    self._test_zero_time(predictor, fx_train_0, fx_test_0, g_td, momentum)
    self._test_multi_step(predictor, fx_train_0, fx_test_0, g_td, momentum)
    if loss == 'mse_analytic':
      self._test_inf_time(predictor, fx_train_0, fx_test_0, g_td, y_train)

    if momentum is None:
      opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
    else:
      opt_init, opt_update, get_params = optimizers.momentum(learning_rate,
                                                             momentum)

    opt_state = opt_init(params)
    for i in range(t):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    params = get_params(opt_state)

    fx_train_nn, fx_test_nn = f(params, x_train), f(params, x_test)
    fx_train_t, fx_test_t = predictor(t, fx_train_0, fx_test_0, g_td)

    self.assertAllClose(fx_train_nn, fx_train_t, rtol=RTOL, atol=ATOL)
    self.assertAllClose(fx_test_nn, fx_test_t, rtol=RTOL, atol=ATOL)
Example 5
def minimize(f, x, num_steps=10000, step_size=0.000001, mass=0.9):
  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass)

  @jit
  def update(i, opt_state):
    x = get_params(opt_state)
    return opt_update(i, grad(f)(x), opt_state)

  opt_state = opt_init(x)
  for i in range(num_steps):
    opt_state = update(i, opt_state)
  return get_params(opt_state)
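A brief usage sketch for minimize above, assuming jax.numpy is imported as jnp and that grad, jit and optimizers are in scope as in the other examples: momentum descent on a simple quadratic should drive the parameters toward zero.

import jax.numpy as jnp

x0 = jnp.array([1.0, -2.0, 3.0])
x_min = minimize(lambda x: jnp.sum(x ** 2), x0, num_steps=2000, step_size=1e-2)
# x_min is expected to end up close to the zero vector.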
Example 6
def get_optimizer(optimizer, sched, b1=0.9, b2=0.999):
  if optimizer.lower() == 'adagrad':
    return optimizers.adagrad(sched)
  elif optimizer.lower() == 'adam':
    return optimizers.adam(sched, b1, b2)
  elif optimizer.lower() == 'rmsprop':
    return optimizers.rmsprop(sched)
  elif optimizer.lower() == 'momentum':
    return optimizers.momentum(sched, 0.9)
  elif optimizer.lower() == 'sgd':
    return optimizers.sgd(sched)
  else:
    raise Exception('Invalid optimizer: {}'.format(optimizer))
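A usage sketch under the same assumptions as the examples above: sched can be a constant step size or a schedule, for instance one built with optimizers.exponential_decay, since the underlying constructors accept either.

sched = optimizers.exponential_decay(step_size=1e-2, decay_steps=1000,
                                     decay_rate=0.9)
opt_init, opt_update, get_params = get_optimizer('momentum', sched)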
Example 7
  def testMaxLearningRate(self, train_shape, network, out_logits,
                          fn_and_kernel, lr_factor, momentum):

    key = random.PRNGKey(0)

    key, split = random.split(key)
    if len(train_shape) == 2:
      train_shape = (train_shape[0] * 5, train_shape[1] * 10)
    else:
      train_shape = (16, 8, 8, 3)
    x_train = random.normal(split, train_shape)

    key, split = random.split(key)
    y_train = np.array(
        random.bernoulli(split, shape=(train_shape[0], out_logits)), np.float32)

    # Regress to an MSE loss.
    loss = lambda params, x: 0.5 * np.mean((f(params, x) - y_train) ** 2)
    grad_loss = jit(grad(loss))

    def get_loss(opt_state):
      return loss(get_params(opt_state), x_train)

    steps = 30

    params, f, ntk = fn_and_kernel(key, train_shape[1:], network, out_logits)
    g_dd = ntk(x_train, None, 'ntk')

    step_size = predict.max_learning_rate(
        g_dd, y_train_size=y_train.size, momentum=momentum) * lr_factor
    opt_init, opt_update, get_params = optimizers.momentum(step_size,
                                                           mass=momentum)

    opt_state = opt_init(params)

    init_loss = get_loss(opt_state)

    for i in range(steps):
      params = get_params(opt_state)
      opt_state = opt_update(i, grad_loss(params, x_train), opt_state)

    trained_loss = get_loss(opt_state)
    loss_ratio = trained_loss / (init_loss + 1e-12)
    if lr_factor < 1.:
      self.assertLess(loss_ratio, 0.1)
    elif lr_factor == 1:
      # At the threshold, the loss decays slowly
      self.assertLess(loss_ratio, 1.)
    if lr_factor > 2.:
      if not math.isnan(loss_ratio):
        self.assertGreater(loss_ratio, 10.)
Example 8
    def omniglot():
        n_way, n_support, n_query = 50, 15, 5
        net_init, f = conv_net(n_output=n_way,
                               n_conv_layer=4,
                               n_filter=64,
                               bias_coef=1,
                               activation='relu',
                               norm='None')
        _, params_init = net_init(rng=random.PRNGKey(42),
                                  input_shape=(-1, 28, 28, 1))

        def loss(params, batch):
            inputs, targets = batch
            logits = f(params, inputs)
            outputs = logsoftmax(logits)
            return -np.sum(outputs * targets) / targets.shape[0]

        def accuracy(params, batch):
            inputs, targets = batch
            target_class = np.argmax(targets, axis=-1)
            predicted_class = np.argmax(f(params, inputs), axis=-1)
            return np.mean(predicted_class == target_class)

        splits = load_omniglot(n_support=n_support, n_query=n_query)
        task = omniglot_task(splits['train'],
                             n_way=n_way,
                             n_support=n_support,
                             n_query=n_query)

        opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-0,
                                                               mass=0.9)

        @jit
        def update(i, opt_state, batch):
            params = get_params(opt_state)
            return opt_update(i, grad(loss)(params, batch), opt_state)

        opt_state = opt_init(params_init)

        n_update = 10000
        for i in range(n_update):
            opt_state = update(i, opt_state,
                               (task['x_train'], task['y_train']))
            if i == 0 or (i + 1) % (n_update // 100) == 0:
                print(
                    i,
                    f"train loss: {loss(get_params(opt_state), (task['x_train'], task['y_train']))},"
                    f"\ttest loss: {loss(get_params(opt_state), (task['x_test'], task['y_test']))}"
                )
        trained_params = get_params(opt_state)
Example 9
def subset_train(seed, subset_ratio):
  jrng = random.PRNGKey(seed)
  
  step_size = 0.1
  num_epochs = 10
  batch_size = 128
  momentum_mass = 0.9

  num_train_total = mnist_data['train_images'].shape[0]
  num_train = int(num_train_total * subset_ratio)
  num_batches = int(np.ceil(num_train / batch_size))

  rng = npr.RandomState(seed)
  subset_idx = rng.choice(num_train_total, size=num_train, replace=False)
  train_images = mnist_data['train_images'][subset_idx]
  train_labels = mnist_data['train_labels'][subset_idx]

  def data_stream(shuffle=True):
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]

  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params(jrng, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  for epoch in range(num_epochs):
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))

  params = get_params(opt_state)
  trainset_correctness = batch_correctness(
      params, (mnist_data['train_images'], mnist_data['train_labels']))
  testset_correctness = batch_correctness(
      params, (mnist_data['test_images'], mnist_data['test_labels']))

  trainset_mask = np.zeros(num_train_total, dtype=bool)
  trainset_mask[subset_idx] = True
  return trainset_mask, np.asarray(trainset_correctness), np.asarray(testset_correctness)
Example 10
def optimizer(name="adam",
              momentum_mass=0.9, rmsprop_gamma=0.9, rmsprop_eps=1e-8,
              adam_b1=0.9, adam_b2=0.997, adam_eps=1e-8):
  """Return the optimizer, by name."""
  if name == "sgd":
    return optimizers.sgd(learning_rate)
  if name == "momentum":
    return optimizers.momentum(learning_rate, mass=momentum_mass)
  if name == "rmsprop":
    return optimizers.rmsprop(
        learning_rate, gamma=rmsprop_gamma, eps=rmsprop_eps)
  if name == "adam":
    return optimizers.adam(learning_rate, b1=adam_b1, b2=adam_b2, eps=adam_eps)
  raise ValueError("Unknown optimizer %s" % str(name))
Example 11
def main():
    rng = random.PRNGKey(0)

    batch_size = 128
    step_size = 0.001
    num_epochs = 10
    momentum_mass = 0.9

    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    # define data stream
    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_indices = perm[i*batch_size:(i+1)*batch_size]
                yield train_images[batch_indices], train_labels[batch_indices]
    batches = data_stream()

    # define optimizer
    opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(rng, (-1, 28*28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print('\nStarting training...')
    for epoch in range(num_epochs):
        start_tm = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
        epoch_tm = time.time() - start_tm
        
        params = get_params(opt_state)
        train_acc = accuracy(params, (train_images, train_labels))
        test_acc = accuracy(params, (test_images, test_labels))
        print(f'Epoch {epoch} in {epoch_tm:0.2f} sec')
        print(f'Training set accuracy {train_acc}')
        print(f'Test set accuracy {test_acc}')
    print('DONE')
Example 12
def train_opt(loss_fn_xy, size, initial_params, lr, momentum):
    opt_init, opt_update, get_params = optimizers.momentum(lr, momentum)

    def step(step, opt_state):
        loss, grads = jax.value_and_grad(loss_fn_xy)(get_params(opt_state))
        opt_state = opt_update(step, grads, opt_state)
        return loss, opt_state

    def scan_fn(opt_state, i):
        loss, opt_state = step(i, opt_state)
        return opt_state, {"loss": loss, "params": get_params(opt_state)}

    def train(initial_params):
        init_opt_state = opt_init(initial_params)
        opt_state, memo = jax.lax.scan(scan_fn, init_opt_state,
                                       jnp.arange(size))
        return get_params(opt_state), memo

    return train(initial_params)
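A hedged usage sketch for train_opt, assuming jax and jax.numpy (as jnp) are imported as in the snippet: the scan returns the optimized parameters together with a per-step record of the loss and parameter iterates.

import jax.numpy as jnp

quadratic = lambda p: jnp.sum(p ** 2)
final_params, memo = train_opt(quadratic, size=500,
                               initial_params=jnp.array([3.0, -1.0]),
                               lr=1e-1, momentum=0.9)
# memo['loss'] has shape (500,); memo['params'] stacks the parameter iterates.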
Example 13
    def sinusoid():

        net_init, net_fn = mlp(n_output=1,
                               n_hidden_layer=2,
                               bias_coef=1.0,
                               n_hidden_unit=40,
                               activation='relu',
                               norm='batch_norm')

        rng = random.PRNGKey(42)
        in_shape = (-1, 1)
        out_shape, net_params = net_init(rng, in_shape)

        def loss(params, batch):
            inputs, targets = batch
            predictions = net_fn(params, inputs)
            return np.mean((predictions - targets)**2)

        opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-2,
                                                               mass=0.9)
        opt_update = jit(opt_update)

        @jit
        def step(i, opt_state, batch):
            params = get_params(opt_state)
            g = grad(loss)(params, batch)
            return opt_update(i, g, opt_state)

        task = sinusoid_task(n_support=1000, n_query=100)

        opt_state = opt_init(net_params)
        for i, (x, y) in enumerate(
                minibatch(task['x_train'],
                          task['y_train'],
                          batch_size=256,
                          train_epochs=1000)):
            opt_state = step(i, opt_state, batch=(x, y))
            if i == 0 or (i + 1) % 100 == 0:
                print(
                    f"train loss: {loss(get_params(opt_state), (task['x_train'], task['y_train']))},"
                    f"\ttest loss: {loss(get_params(opt_state), (task['x_test'], task['y_test']))}"
                )
Example 14
    def __init__(self, CNN="true", L=40, step_size=0.001, seed=0):
        """Defines neural network architecture, parameter initialization, and optimizer"""

        self.CNN = CNN
        if CNN:
            self.input_shape = (-1, L, L, 1)
            # the following network gave the highest test accuracy of all the ones I tried
            self.init_random_params, self.predict = stax.serial(
                PeriodicConv(out_chan=10,
                             filter_shape=(2, 2),
                             strides=(1, 1),
                             padding='VALID'),
                Relu,
                #MaxPool(window_shape=(2, 2), strides=(2, 2), padding='VALID'),
                Flatten,
                Dense(100),
                Relu,
                # Dropout(0.4), # doesn't work yet since prng key has to be passed to predict()
                Dense(1),
                Sigmoid)

        else:
            self.input_shape = (-1, L * L)
            self.init_random_params, self.predict = stax.serial(
                Dense(100),
                Relu,
                Dense(100),
                Relu,
                # Dropout(0.4), # doesn't work yet since prng key has to be passed to predict()
                Dense(1),
                Sigmoid)

        momentum_mass = 0.9
        self.opt_init, self.opt_update, self.get_params = optimizers.momentum(
            step_size, mass=momentum_mass)
        #self.opt_init, self.opt_update, self.get_params = optimizers.adam(0.0001)

        rng = random.PRNGKey(seed)
        _, self.init_params = self.init_random_params(rng, self.input_shape)
        self.opt_state = self.opt_init(self.init_params)
        self.params = self.init_params
Example 15
    def get_optimizer(self, optim=None, stage='learn', step_size=None):

        if optim is None:
            if stage == 'learn':
                optim = self.optim_learn
            else:
                optim = self.optim_proj
        if step_size is None:
            step_size = self.step_size

        if optim == 1:
            if self.verb > 2:
                print("With momentum optimizer")
            opt_init, opt_update, get_params = momentum(step_size=step_size,
                                                        mass=0.95)
        elif optim == 2:
            if self.verb > 2:
                print("With rmsprop optimizer")
            opt_init, opt_update, get_params = rmsprop(step_size,
                                                       gamma=0.9,
                                                       eps=1e-8)
        elif optim == 3:
            if self.verb > 2:
                print("With adagrad optimizer")
            opt_init, opt_update, get_params = adagrad(step_size, momentum=0.9)
        elif optim == 4:
            if self.verb > 2:
                print("With Nesterov optimizer")
            opt_init, opt_update, get_params = nesterov(step_size, 0.9)
        elif optim == 5:
            if self.verb > 2:
                print("With SGD optimizer")
            opt_init, opt_update, get_params = sgd(step_size)
        else:
            if self.verb > 2:
                print("With adam optimizer")
            opt_init, opt_update, get_params = adam(step_size)

        return opt_init, opt_update, get_params
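The method above refers to momentum, rmsprop, adagrad, nesterov, sgd and adam as bare names, so they are presumably imported at module level from the JAX optimizers package, along the lines of the assumed import below.

from jax.example_libraries.optimizers import (adagrad, adam, momentum,
                                              nesterov, rmsprop, sgd)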
Example 16
def ssvm_loss(params,
              x,
              y,
              lamb=0.01,
              max_steps=80,
              step_size=0.1,
              pretrain_global_energy=False):
    prediction = y is None
    x_hat = compute_feature_energy(params, x)
    if pretrain_global_energy:
        x_hat = lax.stop_gradient(x_hat)
    grad_fun = inference_step if prediction else cost_augmented_inference_step

    opt_init, opt_update, get_params = momentum(0.01, 0.95)
    # opt_state = opt_init(np.full(x.shape[:-1] + (LABELS,), 1. / LABELS))
    opt_state = opt_init(np.zeros(x.shape[:-1] + (LABELS, )))
    prev_energy = None
    for step in range(max_steps):
        y_hat = project(get_params(opt_state))
        g, energy = grad_fun(y_hat, y, x_hat, params)
        opt_state = opt_update(step, g, opt_state)
        if step > 0 and check_saddle_point(step, get_params(opt_state), y_hat,
                                           energy, prev_energy):
            break
        prev_energy = energy

    y_hat = lax.stop_gradient(project(get_params(opt_state)))
    if prediction:
        return y_hat

    y = lax.stop_gradient(y)
    pred_energy = compute_global_energy(params, x_hat, y_hat)
    true_energy = compute_global_energy(params, x_hat, y)
    delta = np.square(y_hat - y).sum(axis=1)
    loss = np.mean(np.maximum(delta + true_energy - pred_energy, 0))
    return loss + lamb * l2_norm(params)
Example 17
  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, batch_size)
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * batch_size:(i + 1) * batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params(rng, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  print("\nStarting training...")
  for epoch in range(num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
Example 18
        # Here we clone the rng used in computing the objective
        # so that we can show exactly the same samples.
        rngs = random.split(random.PRNGKey(t), num_samples)
        samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs,
                                                                      *params)
        ax.plot(samples[:, 0], samples[:, 1], 'b.')

        plt.draw()
        plt.pause(1.0 / 60.0)

    # Set up optimizer.
    D = 2
    init_mean = np.zeros(D)
    init_std = np.zeros(D)
    init_params = (init_mean, init_std)
    opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1, mass=0.9)
    opt_state = opt_init(init_params)

    @jit
    def update(i, opt_state):
        params = get_params(opt_state)
        gradient = grad(objective)(params, i)
        return opt_update(i, gradient, opt_state)

    # Main loop.
    print("Optimizing variational parameters...")
    for t in range(100):
        opt_state = update(t, opt_state)
        params = get_params(opt_state)
        callback(params, t)
    plt.show(block=True)
Example 19
    train_images, train_labels, test_images, test_labels = datasets.mnist()
    num_train = train_images.shape[0]
    num_complete_batches, leftover = divmod(num_train, batch_size)
    num_batches = num_complete_batches + bool(leftover)

    def data_stream():
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * batch_size:(i + 1) * batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]

    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(step_size,
                                                           mass=momentum_mass)

    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)

    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()
        for _ in range(num_batches):
            opt_state = update(next(itercount), opt_state, next(batches))
Example 20
def weight_space(train_embedding, test_embedding, data_set):
    init_fn, f, _ = stax.serial(
        stax.Dense(512, 1., 0.05),
        stax.Erf(),
        # 2 is the number of output classes
        stax.Dense(2, 1., 0.05))

    key = random.PRNGKey(0)
    # (-1, 135): 135 is the feature length, i.e. 9 * 15 = 135
    _, params = init_fn(key, (-1, 135))

    # Linearize the network about its initial parameters.
    f_lin = nt.linearize(f, params)

    # Create and initialize an optimizer for both f and f_lin.
    opt_init, opt_apply, get_params = optimizers.momentum(1.0, 0.9)
    opt_apply = jit(opt_apply)

    state = opt_init(params)
    state_lin = opt_init(params)

    # Create a cross-entropy loss function.
    loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print('Training.')
    print('Epoch\tLoss\tLinearized Loss')
    print('------------------------------------------')

    epoch = 0
    # Use whole batch
    batch_size = 64
    train_epochs = 10
    steps_per_epoch = 100

    for i, (x, y) in enumerate(
            datasets.mini_batch(train_embedding, data_set['Y_train'],
                                batch_size, train_epochs)):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        if i % steps_per_epoch == 0:
            print('{}\t{:.4f}\t{:.4f}'.format(epoch, loss(f(params, x), y),
                                              loss(f_lin(params_lin, x), y)))
            epoch += 1
        if i / steps_per_epoch == train_epochs:
            break

    # Print out summary data comparing the linear / nonlinear model.
    x, y = train_embedding[:10000], data_set['Y_train'][:10000]
    util.print_summary('train', y, f(params, x), f_lin(params_lin, x), loss)
    util.print_summary('test', data_set['Y_test'], f(params, test_embedding),
                       f_lin(params_lin, test_embedding), loss)
Example 21
        # Here we clone the rng used in computing the objective
        # so that we can show exactly the same samples.
        rngs = random.split(random.PRNGKey(t), num_samples)
        samples = vmap(diag_gaussian_sample, in_axes=(0, None, None))(rngs,
                                                                      *params)
        ax.plot(samples[:, 0], samples[:, 1], 'b.')

        plt.draw()
        plt.pause(1.0 / 60.0)

    # Set up optimizer.
    D = 2
    init_mean = jnp.zeros(D)
    init_std = jnp.zeros(D)
    init_params = (init_mean, init_std)
    opt_init, opt_update, get_params = optimizers.momentum(step_size=0.1,
                                                           mass=0.9)
    opt_state = opt_init(init_params)

    @jit
    def update(i, opt_state):
        params = get_params(opt_state)
        gradient = grad(objective)(params, i)
        return opt_update(i, gradient, opt_state)

    # Main loop.
    print("Optimizing variational parameters...")
    for t in range(100):
        opt_state = update(t, opt_state)
        params = get_params(opt_state)
        callback(params, t)
    plt.show(block=True)
Example 22
def main(unused_argv):
    from jax import grad, jit, vmap, pmap, device_put
    "The following is required to use TPU Driver as JAX's backend."

    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    pmap = partial(pmap, axis_name='i')
    """Setup some experiment parameters."""
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    "Filename from FLAGS"

    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    if not os.path.exists(filedir):
        os.makedirs(filedir)
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))
    """Load CIFAR10 data and create a minimal pipeline."""

    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)
    """Create a Wide Resnet and replicate its parameters across the devices."""

    initparams, f, _ = utils.WideResnetnt(N, K, k)

    "Loss and optimizer definitions"

    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    "Define optimizer"

    if FLAGS.std_wrn_sch:
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    "Initialization and loading"

    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    idel0 = i0
    start = time.time()

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    "Start training loop"
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1]
                                        and minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1]
                                        and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period, decay if the loss or error have increased in the last two measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the minimum values.
                try:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)
                except:
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
            pass
Example 23
def _JaxMomentum(machine, learning_rate, beta=0.9, l2reg=0):
    return Wrap(machine, jaxopt.momentum(learning_rate, beta))
Example 24
  train_images, train_labels, test_images, test_labels = datasets.mnist()
  num_train = train_images.shape[0]
  num_complete_batches, leftover = divmod(num_train, wandb.config.batch_size)
  num_batches = num_complete_batches + bool(leftover)

  def data_stream():
    rng = npr.RandomState(0)
    while True:
      perm = rng.permutation(num_train)
      for i in range(num_batches):
        batch_idx = perm[i * wandb.config.batch_size:(i + 1) * wandb.config.batch_size]
        yield train_images[batch_idx], train_labels[batch_idx]
  batches = data_stream()

  opt_init, opt_update, get_params = optimizers.momentum(wandb.config.step_size, mass=wandb.config.momentum_mass)

  @jit
  def update(i, opt_state, batch):
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)

  _, init_params = init_random_params(rng, (-1, 28 * 28))
  opt_state = opt_init(init_params)
  itercount = itertools.count()

  print("\nStarting training...")
  for epoch in range(wandb.config.num_epochs):
    start_time = time.time()
    for _ in range(num_batches):
      opt_state = update(next(itercount), opt_state, next(batches))
Example 25
def main(unused_argv):
    # print(f'Available GPU memory: {util.get_gpu_memory()}')
    # Load and normalize data
    print('Loading data...')
    x_train, y_train, x_test, y_test = datasets.get_dataset('mnist',
                                                            n_train=60000,
                                                            n_test=10000,
                                                            permute_train=True)
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Reformat MNIST data to 28x28x1 pictures
    x_train = np.asarray(x_train.reshape(-1, 28, 28, 1))
    x_test = np.asarray(x_test.reshape(-1, 28, 28, 1))
    print('Data loaded and reshaped')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Set random seed
    key = random.PRNGKey(0)

    # # Add random translation to images
    # x_train = util.add_translation(x_train, FLAGS.max_pixel)
    # x_test = util.add_translation(x_test, FLAGS.max_pixel)
    # print(f'Random translation by up to {FLAGS.max_pixel} pixels added')

    # # Add random translations with padding
    # x_train = util.add_padded_translation(x_train, 10)
    # x_test = util.add_padded_translation(x_test, 10)
    # print(f'Random translations with additional padding up to 10 pixels added')

    # Build the LeNet network with NTK parameterization
    init_fn, f, kernel_fn = util.build_le_net(FLAGS.network_width)
    print(f'Network of width x{FLAGS.network_width} built.')

    # # Construct the kernel function
    # kernel_fn = nt.batch(kernel_fn, device_count=-1, batch_size=FLAGS.batch_size_kernel)
    # print('Kernel constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # Compute random initial parameters
    _, params = init_fn(key, (-1, 28, 28, 1))
    params_lin = params

    print('Initial parameters constructed')
    # print(f'Available GPU memory: {util.get_gpu_memory()}')

    # # Save initial parameters
    # with open('init_params.npy', 'wb') as file:
    #     np.save(file, params)

    # Linearize the network about its initial parameters.
    # Use jit for faster GPU computation (only feasible for width < 25)
    f_lin = nt.linearize(f, params)
    if FLAGS.network_width <= 10:
        f_jit = jit(f)
        f_lin_jit = jit(f_lin)
    else:
        f_jit = f
        f_lin_jit = f_lin

    # Create a callable function for dynamic learning rates
    # Starts with learning_rate, divided by 10 after learning_decline epochs.
    dynamic_learning_rate = lambda iteration_step: FLAGS.learning_rate / 10**(
        (iteration_step //
         (x_train.shape[0] // FLAGS.batch_size)) // FLAGS.learning_decline)

    # Create and initialize an optimizer for both f and f_lin.
    # Use momentum with coefficient 0.9 and jit
    opt_init, opt_apply, get_params = optimizers.momentum(
        dynamic_learning_rate, 0.9)
    opt_apply = jit(opt_apply)

    # Compute the initial states
    state = opt_init(params)
    state_lin = opt_init(params)

    # Define the accuracy function
    accuracy = lambda fx, y_hat: np.mean(
        np.argmax(fx, axis=1) == np.argmax(y_hat, axis=1))

    # Define mean square error loss function
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)

    # # Create a cross-entropy loss function.
    # loss = lambda fx, y_hat: -np.mean(logsoftmax(fx) * y_hat)

    # Specialize the loss function to compute gradients for both linearized and
    # full networks.
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))
    grad_loss_lin = jit(grad(lambda params, x, y: loss(f_lin(params, x), y)))

    # Train the network.
    print(
        f'Training with dynamic learning decline after {FLAGS.learning_decline} epochs...'
    )
    print(
        'Epoch\tTime\tAccuracy\tLin. Accuracy\tLoss\tLin. Loss\tAccuracy Train\tLin.Accuracy Train'
    )
    print(
        '----------------------------------------------------------------------------------------------------------'
    )

    # Initialize training
    epoch = 0
    steps_per_epoch = x_train.shape[0] // FLAGS.batch_size

    # Set start time (total and 100 epochs)
    start = time.time()
    start_epoch = time.time()

    for i, (x, y) in enumerate(
            datasets.minibatch(x_train, y_train, FLAGS.batch_size,
                               FLAGS.train_epochs)):

        # Update the parameters
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x, y), state)

        params_lin = get_params(state_lin)
        state_lin = opt_apply(i, grad_loss_lin(params_lin, x, y), state_lin)

        # Print information after each 100 epochs
        if (i + 1) % (steps_per_epoch * 100) == 0:
            time_point = time.time() - start_epoch

            # Update epoch
            epoch += 100

            # Accuracy in batches
            f_x = util.output_in_batches(x_train, params, f_jit,
                                         FLAGS.batch_count_accuracy)
            f_x_test = util.output_in_batches(x_test, params, f_jit,
                                              FLAGS.batch_count_accuracy)
            f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                             FLAGS.batch_count_accuracy)
            f_x_lin_test = util.output_in_batches(x_test, params_lin,
                                                  f_lin_jit,
                                                  FLAGS.batch_count_accuracy)
            # time_point = time.time() - start_epoch

            # Print information about past 100 epochs
            print(
                '{}\t{:.3f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}\t{:.4f}\t\t{:.4f}\t\t{:.4f}'
                .format(epoch, time_point,
                        accuracy(f_x, y_train) * 100,
                        accuracy(f_x_lin, y_train) * 100, loss(f_x, y_train),
                        loss(f_x_lin, y_train),
                        accuracy(f_x_test, y_test) * 100,
                        accuracy(f_x_lin_test, y_test) * 100))

            # # Save params if epoch is multiple of learning decline or multiple of fixed value
            # if epoch % FLAGS.learning_decline == 0:
            #     filename = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}.npy'
            #     with open(filename, 'wb') as file:
            #         np.save(file, params)
            #     filename_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_pmod_{epoch}_{FLAGS.learning_decline}_lin.npy'
            #     with open(filename_lin, 'wb') as file_lin:
            #         np.save(file_lin, params_lin)

            # Reset timer
            start_epoch = time.time()

    duration = time.time() - start
    print(
        '----------------------------------------------------------------------------------------------------------'
    )
    print(f'Training complete in {duration} seconds.')

    # # Save final params in file
    # filename_final = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}.npy '
    # with open(filename_final, 'wb') as final:
    #     np.save(final, params)
    # filename_final_lin = FLAGS.default_path + f'LinLeNetx{FLAGS.network_width}_final_pmod_{FLAGS.train_epochs}_{FLAGS.learning_decline}_lin.npy'
    # with open(filename_final_lin, 'wb') as final_lin:
    #     np.save(final_lin, params_lin)

    # Compute output in batches
    f_x = util.output_in_batches(x_train, params, f_jit,
                                 FLAGS.batch_count_accuracy)
    f_x_lin = util.output_in_batches(x_train, params_lin, f_lin_jit,
                                     FLAGS.batch_count_accuracy)

    f_x_test = util.output_in_batches(x_test, params, f_jit,
                                      FLAGS.batch_count_accuracy)
    f_x_lin_test = util.output_in_batches(x_test, params_lin, f_lin_jit,
                                          FLAGS.batch_count_accuracy)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f_x, f_x_lin, loss)
    util.print_summary('test', y_test, f_x_test, f_x_lin_test, loss)
Example 26
  def __init__(self, learning_rate, mass=0.9):
    super().__init__(learning_rate)
    self.mass = mass
    self.opt_init, self.opt_update, self.get_params = momentum(
        step_size=self.lr, mass=self.mass)
Example 27
pi = jnp.array([1, 1]) / 2

casino = HMMJax(A, B, pi)
num_hidden, num_obs = 2, 6

seed = 0
rng_key = PRNGKey(seed)
rng_key, rng_sample = split(rng_key)

n_obs_seq, max_len = 4, 5000
num_epochs = 400

observations, lens = pad_sequences(
    *hmm_sample_n(casino, hmm_sample_jax, n_obs_seq, max_len, rng_sample))
optimizer = optimizers.momentum(step_size=1e-3, mass=0.95)

# Mini Batch Gradient Descent
batch_size = 2
params_mbgd, losses_mbgd = fit(observations,
                               lens,
                               num_hidden,
                               num_obs,
                               batch_size,
                               optimizer,
                               rng_key=None,
                               num_epochs=num_epochs)

# Full Batch Gradient Descent
batch_size = n_obs_seq
params_fbgd, losses_fbgd = fit(observations,
                               lens,
                               num_hidden,
                               num_obs,
                               batch_size,
                               optimizer,
                               rng_key=None,
                               num_epochs=num_epochs)
Example 28
def main(_):
    logging.info('Starting experiment.')
    configs = FLAGS.config

    # Create model folder for outputs
    try:
        gfile.MakeDirs(FLAGS.exp_dir)
    except gfile.GOSError:
        pass
    stdout_log = gfile.Open('{}/stdout.log'.format(FLAGS.exp_dir), 'w+')

    logging.info('Loading data.')
    tic = time.time()

    train_images, train_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'train')
    n_train = len(train_images)
    train_mu, train_std = onp.mean(train_images), onp.std(train_images)
    train = data.DataChunk(X=(train_images - train_mu) / train_std,
                           Y=train_labels,
                           image_size=32,
                           image_channels=3,
                           label_dim=1,
                           label_format='numeric')

    test_images, test_labels, _ = datasets.get_dataset_split(
        FLAGS.dataset, 'test')
    test = data.DataChunk(
        X=(test_images - train_mu) / train_std,  # normalize w train mean/std
        Y=test_labels,
        image_size=32,
        image_channels=3,
        label_dim=1,
        label_format='numeric')

    # Data augmentation
    if configs.augment_data:
        augmentation = data.chain_transforms(data.RandomHorizontalFlip(0.5),
                                             data.RandomCrop(4), data.ToDevice)
    else:
        augmentation = None
    batch = data.minibatcher(train, configs.batch_size, transform=augmentation)

    # Model architecture
    if configs.architect == 'wrn':
        init_random_params, predict = wide_resnet(configs.block_size,
                                                  configs.channel_multiplier,
                                                  10)
    elif configs.architect == 'cnn':
        init_random_params, predict = cnn()
    else:
        raise ValueError('Model architecture not implemented.')

    if configs.seed is not None:
        key = random.PRNGKey(configs.seed)
    else:
        key = random.PRNGKey(int(time.time()))
    _, params = init_random_params(key, (-1, 32, 32, 3))

    # count params of JAX model
    def count_parameters(params):
        return tree_util.tree_reduce(
            operator.add, tree_util.tree_map(lambda x: np.prod(x.shape),
                                             params))

    logging.info('Number of parameters: %d', count_parameters(params))
    stdout_log.write('Number of params: {}\n'.format(count_parameters(params)))

    # loss functions
    def cross_entropy_loss(params, x_img, y_lbl):
        return -np.mean(stax.logsoftmax(predict(params, x_img)) * y_lbl)

    def mse_loss(params, x_img, y_lbl):
        return 0.5 * np.mean((y_lbl - predict(params, x_img))**2)

    def accuracy(y_lbl_hat, y_lbl):
        target_class = np.argmax(y_lbl, axis=1)
        predicted_class = np.argmax(y_lbl_hat, axis=1)
        return np.mean(predicted_class == target_class)

    # Loss and gradient
    if configs.loss == 'xent':
        loss = cross_entropy_loss
    elif configs.loss == 'mse':
        loss = mse_loss
    else:
        raise ValueError('Loss function not implemented.')
    grad_loss = jit(grad(loss))

    # learning rate schedule and optimizer
    def cosine(initial_step_size, train_steps):
        k = np.pi / (2.0 * train_steps)

        def schedule(i):
            return initial_step_size * np.cos(k * i)

        return schedule

    if configs.optimization == 'sgd':
        lr_schedule = optimizers.make_schedule(configs.learning_rate)
        opt_init, opt_update, get_params = optimizers.sgd(lr_schedule)
    elif configs.optimization == 'momentum':
        lr_schedule = cosine(configs.learning_rate, configs.train_steps)
        opt_init, opt_update, get_params = optimizers.momentum(
            lr_schedule, 0.9)
    else:
        raise ValueError('Optimizer not implemented.')

    opt_state = opt_init(params)

    def private_grad(params, batch, rng, l2_norm_clip, noise_multiplier,
                     batch_size):
        """Return differentially private gradients of params, evaluated on batch."""
        def _clipped_grad(params, single_example_batch):
            """Evaluate gradient for a single-example batch and clip its grad norm."""
            grads = grad_loss(params, single_example_batch[0].reshape(
                (-1, 32, 32, 3)), single_example_batch[1])

            nonempty_grads, tree_def = tree_util.tree_flatten(grads)
            total_grad_norm = np.linalg.norm(
                [np.linalg.norm(neg.ravel()) for neg in nonempty_grads])
            divisor = stop_gradient(
                np.amax((total_grad_norm / l2_norm_clip, 1.)))
            normalized_nonempty_grads = [
                neg / divisor for neg in nonempty_grads
            ]
            return tree_util.tree_unflatten(tree_def,
                                            normalized_nonempty_grads)

        px_clipped_grad_fn = vmap(partial(_clipped_grad, params))
        std_dev = l2_norm_clip * noise_multiplier
        noise_ = lambda n: n + std_dev * random.normal(rng, n.shape)
        normalize_ = lambda n: n / float(batch_size)
        sum_ = lambda n: np.sum(n, 0)  # aggregate
        aggregated_clipped_grads = tree_util.tree_map(
            sum_, px_clipped_grad_fn(batch))
        noised_aggregated_clipped_grads = tree_util.tree_map(
            noise_, aggregated_clipped_grads)
        normalized_noised_aggregated_clipped_grads = (tree_util.tree_map(
            normalize_, noised_aggregated_clipped_grads))
        return normalized_noised_aggregated_clipped_grads

    # summarize measurements
    steps_per_epoch = n_train // configs.batch_size

    def summarize(step, params):
        """Compute measurements in a zipped way."""
        set_entries = [train, test]
        set_bsizes = [configs.train_eval_bsize, configs.test_eval_bsize]
        set_names, loss_dict, acc_dict = ['train', 'test'], {}, {}

        for set_entry, set_bsize, set_name in zip(set_entries, set_bsizes,
                                                  set_names):
            temp_loss, temp_acc, points = 0.0, 0.0, 0
            for b in data.batch(set_entry, set_bsize):
                temp_loss += loss(params, b.X, b.Y) * b.X.shape[0]
                temp_acc += accuracy(predict(params, b.X), b.Y) * b.X.shape[0]
                points += b.X.shape[0]
            loss_dict[set_name] = temp_loss / float(points)
            acc_dict[set_name] = temp_acc / float(points)

        logging.info('Step: %s', str(step))
        logging.info('Train acc : %.4f', acc_dict['train'])
        logging.info('Train loss: %.4f', loss_dict['train'])
        logging.info('Test acc  : %.4f', acc_dict['test'])
        logging.info('Test loss : %.4f', loss_dict['test'])

        stdout_log.write('Step: {}\n'.format(step))
        stdout_log.write('Train acc : {}\n'.format(acc_dict['train']))
        stdout_log.write('Train loss: {}\n'.format(loss_dict['train']))
        stdout_log.write('Test acc  : {}\n'.format(acc_dict['test']))
        stdout_log.write('Test loss : {}\n'.format(loss_dict['test']))

        return acc_dict['test']

    toc = time.time()
    logging.info('Elapsed SETUP time: %s', str(toc - tic))
    stdout_log.write('Elapsed SETUP time: {}\n'.format(toc - tic))

    # BEGIN: training steps
    logging.info('Training network.')
    tic = time.time()
    t = time.time()

    for s in range(configs.train_steps):
        b = next(batch)
        params = get_params(opt_state)

        # t0 = time.time()
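        # With DP-SGD, reshape the batch to (batch, 1, 32, 32, 3) so that
        # private_grad's vmap sees one example per micro-batch; otherwise
        # take an ordinary minibatch gradient.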
        if FLAGS.dpsgd:
            key = random.fold_in(key, s)  # get new key for new random numbers
            opt_state = opt_update(
                s,
                private_grad(params, (b.X.reshape(
                    (-1, 1, 32, 32, 3)), b.Y), key, configs.l2_norm_clip,
                             configs.noise_multiplier, configs.batch_size),
                opt_state)
        else:
            opt_state = opt_update(s, grad_loss(params, b.X, b.Y), opt_state)
        # t1 = time.time()
        # logging.info('batch update time: %s', str(t1 - t0))

        if s % steps_per_epoch == 0:
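            # unpack_optimizer_state converts the optimizer state into a
            # picklable pytree; pack_optimizer_state restores it on load.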
            with gfile.Open(
                    '{}/ckpt_{}'.format(FLAGS.exp_dir,
                                        int(s / steps_per_epoch)),
                    'wb') as fckpt:
                pickle.dump(optimizers.unpack_optimizer_state(opt_state),
                            fckpt)

            if FLAGS.dpsgd:
                eps = compute_epsilon(s, configs.batch_size, n_train,
                                      configs.target_delta,
                                      configs.noise_multiplier)
                stdout_log.write(
                    'For delta={:.0e}, current epsilon is: {:.2f}\n'.format(
                        configs.target_delta, eps))

            logging.info('Elapsed EPOCH time: %s', str(time.time() - t))
            stdout_log.write('Elapsed EPOCH time: {}\n'.format(time.time() - t))
            stdout_log.flush()
            t = time.time()

    toc = time.time()
    summarize(configs.train_steps, params)
    logging.info('Elapsed TRAIN time: %s', str(toc - tic))
    stdout_log.write('Elapsed TRAIN time: {}\n'.format(toc - tic))
    stdout_log.close()
Esempio n. 29
0
def run():
    """
    Run the experiment.
    """
    ds_train, ds_train_eval, meta = init_data()
    num_batches = meta["num_batches"]
    num_test_batches = meta["num_test_batches"]

    forward, model = init_model()
    forward_all = model["model"]["forward_all"]
    grad_fn = jax.grad(lambda *args: loss_fn(forward, *args))

    def lr_schedule(train_itr):
        """
        The learning rate schedule.
        """
        _epoch = train_itr // num_batches
        id = lambda x: x
        return lax.cond(
            _epoch < 60, 1e-1, id, 0, lambda _: lax.cond(
                _epoch < 100, 1e-2, id, 0, lambda _: lax.cond(
                    _epoch < 140, 1e-3, id, 1e-4, id)))

    opt_init, opt_update, get_params = optimizers.momentum(
        step_size=lr_schedule, mass=0.9)
    if parse_args.load_ckpt:
        file_ = open(parse_args.load_ckpt, 'rb')
        init_params = pickle.load(file_)
        file_.close()

        # parse itr from the checkpoint
        load_itr = int(os.path.basename(parse_args.load_ckpt).split("_")[-2])
    else:
        init_params = model["params"]
        load_itr = 0
    opt_state = opt_init(init_params)

    # @jax.jit
    def update(_itr, _opt_state, _key, _batch):
        """
        Update the params based on grad for current batch.
        """
        images, labels = _batch
        return opt_update(
            _itr, grad_fn(get_params(_opt_state), images, labels, _key),
            _opt_state)

    # @jax.jit
    def sep_losses(_opt_state, _batch, key):
        """
        Convenience function for calculating losses separately.
        """
        params = get_params(_opt_state)
        images, labels = _batch
        logits, r2_regs, fro_regs, kin_regs = forward_all(key, params, images)
        loss_ = _loss_fn(logits, labels)
        r2_reg_ = _reg_loss_fn(r2_regs)
        fro_reg_ = _reg_loss_fn(fro_regs)
        kin_reg_ = _reg_loss_fn(kin_regs)
        total_loss_ = loss_ + lam * r2_reg_ + lam_fro * fro_reg_ + lam_kin * kin_reg_
        acc_ = _acc_fn(logits, labels)
        return acc_, total_loss_, loss_, r2_reg_, fro_reg_, kin_reg_

    def evaluate_loss(opt_state, _key, ds_train_eval):
        """
        Convenience function for evaluating loss over train set in smaller batches.
        """
        sep_acc_, sep_loss_aug_, sep_loss_, \
        sep_loss_r2_reg_, sep_loss_fro_reg_, sep_loss_kin_reg_, nfe = [], [], [], [], [], [], []

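        # Average each metric over num_test_batches minibatches; nfe records
        # model["nfe"] (number of function evaluations) when count_nfe is
        # enabled, and 0 otherwise.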
        for test_batch_num in range(num_test_batches):
            test_batch = next(ds_train_eval)
            _key, = jax.random.split(_key, num=1)

            test_batch_acc_, test_batch_loss_aug_, test_batch_loss_, \
            test_batch_loss_r2_reg_, test_batch_loss_fro_reg_, test_batch_loss_kin_reg_ = \
                sep_losses(opt_state, test_batch, _key)

            if count_nfe:
                nfe.append(model["nfe"](get_params(opt_state), *test_batch))
            else:
                nfe.append(0)

            sep_acc_.append(test_batch_acc_)
            sep_loss_aug_.append(test_batch_loss_aug_)
            sep_loss_.append(test_batch_loss_)
            sep_loss_r2_reg_.append(test_batch_loss_r2_reg_)
            sep_loss_fro_reg_.append(test_batch_loss_fro_reg_)
            sep_loss_kin_reg_.append(test_batch_loss_kin_reg_)

        sep_acc_ = jnp.array(sep_acc_)
        sep_loss_aug_ = jnp.array(sep_loss_aug_)
        sep_loss_ = jnp.array(sep_loss_)
        sep_loss_r2_reg_ = jnp.array(sep_loss_r2_reg_)
        sep_loss_fro_reg_ = jnp.array(sep_loss_fro_reg_)
        sep_loss_kin_reg_ = jnp.array(sep_loss_kin_reg_)
        nfe = jnp.array(nfe)

        return jnp.mean(sep_acc_), jnp.mean(sep_loss_aug_), jnp.mean(sep_loss_), \
               jnp.mean(sep_loss_r2_reg_), jnp.mean(sep_loss_fro_reg_), jnp.mean(sep_loss_kin_reg_), jnp.mean(nfe)

    itr = 0
    info = collections.defaultdict(dict)

    key = rng
    # Create the data iterator.
    iterator = iter(ds_train)
    for epoch in range(parse_args.nepochs):
        for i in range(num_batches):
            batch = next(iterator)

            key, = jax.random.split(key, num=1)

            itr += 1

            if parse_args.load_ckpt:
                if itr <= load_itr:
                    continue

            update_start = time.time()
            opt_state = update(itr, opt_state, key, batch)
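            # block_until_ready() forces JAX's asynchronous dispatch to finish
            # so the wall-clock timing below measures the actual update.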
            tree_flatten(opt_state)[0][0].block_until_ready()
            update_end = time.time()
            time_str = "%d %.18f %d\n" % (itr, update_end - update_start,
                                          load_itr)
            outfile = open(
                "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_time.txt" %
                (dirname, reg, reg_type, lam, lam_fro, lam_kin), "a")
            outfile.write(time_str)
            outfile.close()

            if itr % parse_args.test_freq == 0:
                acc_, loss_aug_, loss_, \
                loss_r2_reg_, loss_fro_reg_, loss_kin_reg_, nfe_ = evaluate_loss(opt_state, key, ds_train_eval)

                print_str = 'Iter {:04d} | Total (Regularized) Loss {:.6f} | Loss {:.6f} | ' \
                            'r {:.6f} | fro {:.6f} | kin {:.6f} | ' \
                            'NFE {:.6f}'.format(itr, loss_aug_, loss_, loss_r2_reg_, loss_fro_reg_, loss_kin_reg_, nfe_)

                print(print_str)

                outfile = open(
                    "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_info.txt"
                    % (dirname, reg, reg_type, lam, lam_fro, lam_kin), "a")
                outfile.write(print_str + "\n")
                outfile.close()

                info[itr]["acc"] = acc_
                info[itr]["loss_aug"] = loss_aug_
                info[itr]["loss"] = loss_
                info[itr]["loss_r2_reg"] = loss_r2_reg_
                info[itr]["loss_fro_reg"] = loss_fro_reg_
                info[itr]["loss_kin_reg"] = loss_kin_reg_
                info[itr]["nfe"] = nfe_

            if itr % parse_args.save_freq == 0:
                param_filename = "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_%d_fargs.pickle" \
                             % (dirname, reg, reg_type, lam, lam_fro, lam_kin, itr)
                fargs = get_params(opt_state)
                outfile = open(param_filename, "wb")
                pickle.dump(fargs, outfile)
                outfile.close()

    meta = {"info": info, "args": parse_args}
    outfile = open(
        "%s/reg_%s_%s_lam_%.18e_lam_fro_%.18e_lam_kin_%.18e_%d_meta.pickle" %
        (dirname, reg, reg_type, lam, lam_fro, lam_kin, itr), "wb")
    pickle.dump(meta, outfile)
    outfile.close()
    num_complete_batches, leftover = divmod(num_train, config.batch_size)
    num_batches = num_complete_batches + bool(leftover)


    def data_stream():
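        # Infinite generator: reshuffle the training indices every epoch and
        # yield minibatches of size config.batch_size.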
        rng = npr.RandomState(0)
        while True:
            perm = rng.permutation(num_train)
            for i in range(num_batches):
                batch_idx = perm[i * config.batch_size:(i + 1) * config.batch_size]
                yield train_images[batch_idx], train_labels[batch_idx]


    batches = data_stream()

    opt_init, opt_update, get_params = optimizers.momentum(config.learning_rate, mass=config.momentum_mass)


    @jit
    def update(i, opt_state, batch):
        params = get_params(opt_state)
        return opt_update(i, grad(loss)(params, batch), opt_state)


    _, init_params = init_random_params(rng, (-1, 28 * 28))
    opt_state = opt_init(init_params)
    itercount = itertools.count()

    print("\nStarting training...")
    for epoch in range(num_epochs):
        start_time = time.time()