コード例 #1
0
    def test_behler_parrinello_network_neighbor_list(self, N_types, dtype):
        """Neighbor-list Behler-Parrinello network gives NaN-free outputs.

        Three particles in a periodic box of side 1.5 are fed through the
        network; the jitted energy and its gradient w.r.t. the positions must
        contain no NaNs, and the gradient must have shape (3, 3).
        """
        positions = np.array([[0, 0, 0], [1, 1, 1], [1, 1, 0]], dtype)
        # Only pass explicit species when more than one type is requested.
        species = None
        if N_types > 1:
            species = np.array([1, 1, N_types])
        box_size = f32(1.5)
        displacement, _ = space.periodic(box_size)
        neighbor_fn, nn_init, nn_apply = energy.behler_parrinello_neighbor_list(
            displacement, box_size, species)

        neighbors = neighbor_fn(positions)
        params = nn_init(random.PRNGKey(1), positions, neighbors)
        grad_wrt_positions = jit(grad(nn_apply, argnums=1))(params, positions,
                                                            neighbors)
        nn_energy = jit(nn_apply)(params, positions, neighbors)

        for value in (nn_energy, grad_wrt_positions):
            self.assertAllClose(np.any(np.isnan(value)), False)
        self.assertAllClose(grad_wrt_positions.shape, [3, 3])
コード例 #2
0
ファイル: random_test.py プロジェクト: mattwescott/jax
    def testGammaGrad(self, alpha):
        """Gradient of gamma samples w.r.t. alpha matches the implicit formula.

        The expected gradient is -(dCDF/dalpha) / pdf evaluated at the samples,
        with dCDF/dalpha estimated by a central finite difference in alpha.
        """
        key = random.PRNGKey(0)
        alphas = onp.full((100, ), alpha)
        samples = random.gamma(key, alphas)
        total_sample = lambda a: random.gamma(key, a).sum()
        actual_grad = api.grad(total_sample)(alphas)

        step = 0.01 * alpha / (1.0 + onp.sqrt(alpha))
        upper = scipy.stats.gamma.cdf(samples, alpha + step)
        lower = scipy.stats.gamma.cdf(samples, alpha - step)
        cdf_dot = (upper - lower) / (2 * step)
        expected_grad = -cdf_dot / scipy.stats.gamma.pdf(samples, alpha)

        on_tpu = jtu.device_under_test() == "tpu"
        self.assertAllClose(actual_grad,
                            expected_grad,
                            check_dtypes=True,
                            rtol=2e-2 if on_tpu else 5e-4)
コード例 #3
0
def main(unused_argv):
    """Train an Erf network on MNIST with SGD and compare it to the NTK prediction.

    The empirical NTK of the network at initialization drives an analytic
    gradient-descent MSE predictor. The network is then trained with
    full-batch SGD for ``FLAGS.train_time`` (in units of learning-rate time).
    Train/test summaries of the trained network and the analytic prediction
    are printed side by side.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f, _ = stax.serial(stax.Dense(2048, 1., 0.05), stax.Erf(),
                                stax.Dense(10, 1., 0.05))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    ntk = batch(get_ntk_fun_empirical(f), batch_size=4, device_count=0)
    g_dd = ntk(x_train, None, params)
    g_td = ntk(x_test, x_train, params)
    predictor = predict.gradient_descent_mse(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network. The step count is chosen so that
    # train_steps * learning_rate ~= FLAGS.train_time.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Inside the loop `params` is read *before* each update, so it lags
    # `state` by one step; fetch the parameters after the final update.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(FLAGS.train_time, fx_train, fx_test)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
コード例 #4
0
  def test_custom_root_scalar(self):
    """Checks lax.custom_root on y = sqrt(x**3) found by bisection.

    The root of f(y) = y**2 - x**3 is located with a bisection search and
    wrapped in `lax.custom_root`, with `scalar_solve` supplied as the tangent
    solver. The value, first derivative, numerical gradient checks (order 2)
    and the jitted value are compared against x ** 1.5 at x = 5.
    """

    # TODO(shoyer): Figure out why this fails and re-enable it, if possible. My
    # best guess is that TPUs use less stable numerics for pow().
    if jtu.device_under_test() == "tpu":
      raise SkipTest("Test fails on TPU")

    def scalar_solve(f, y):
      # Solves f(t) = y for t assuming f is scalar-linear: f(t) = f(1) * t.
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      # Bisection on the bracket [low, high]. Assumes func is increasing there
      # (true for y**2 - c on y >= 0), so a positive value at the midpoint
      # means the root lies below it.
      del x0  # unused

      def cond(state):
        low, high = state
        # Keep bisecting until the bracket is narrower than `tolerance`.
        return high - low > tolerance

      def body(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        update_upper = func(midpoint) > 0
        low = np.where(update_upper, low, midpoint)
        high = np.where(update_upper, midpoint, high)
        return (low, high)

      solution, _ = lax.while_loop(cond, body, (low, high))
      # The lower end of the final bracket is returned as the root.
      return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      # The non-negative root of y**2 = x**3 is x ** 1.5.
      f = lambda y: y ** 2. - np.array(x) ** 3.
      return lax.custom_root(f, 0.0, binary_search, tangent_solve)

    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False, rtol=1e-6)
    # Reference derivative: d/dx pow(x, 1.5) at x = 5.
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False,
                        rtol=1e-7)
    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    results = api.jit(sqrt_cubed)(5.0)
    # NOTE(review): the rtol override is keyed on float64 only; other dtypes
    # presumably fall back to the default tolerance — confirm.
    self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False,
                        rtol={onp.float64:1e-7})
コード例 #5
0
def main(unused_argv):
    """Train a Tanh network on MNIST with SGD and compare it to the NTK prediction.

    The empirical NTK of the network at initialization drives an analytic MSE
    predictor. The network is then trained with full-batch SGD for
    ``FLAGS.train_time`` (in units of learning-rate time). Train/test
    summaries of the trained network and the analytic prediction are printed
    side by side.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    # Build data pipelines.
    print('Loading data.')
    x_train, y_train, x_test, y_test = \
        datasets.mnist(FLAGS.train_size, FLAGS.test_size)

    # Build the network
    init_fn, f = stax.serial(layers.Dense(4096), stax.Tanh, layers.Dense(10))

    key = random.PRNGKey(0)
    _, params = init_fn(key, (-1, 784))

    # Create and initialize an optimizer.
    opt_init, opt_apply, get_params = optimizers.sgd(FLAGS.learning_rate)
    state = opt_init(params)

    # Create an mse loss function and a gradient function.
    loss = lambda fx, y_hat: 0.5 * np.mean((fx - y_hat)**2)
    grad_loss = jit(grad(lambda params, x, y: loss(f(params, x), y)))

    # Create an MSE predictor to solve the NTK equation in function space.
    theta = tangents.ntk(f, batch_size=32)
    g_dd = theta(params, x_train)
    g_td = theta(params, x_test, x_train)
    predictor = tangents.analytic_mse_predictor(g_dd, y_train, g_td)

    # Get initial values of the network in function space.
    fx_train = f(params, x_train)
    fx_test = f(params, x_test)

    # Train the network. The step count is chosen so that
    # train_steps * learning_rate ~= FLAGS.train_time.
    train_steps = int(FLAGS.train_time // FLAGS.learning_rate)
    print('Training for {} steps'.format(train_steps))

    for i in range(train_steps):
        params = get_params(state)
        state = opt_apply(i, grad_loss(params, x_train, y_train), state)

    # Inside the loop `params` is read *before* each update, so it lags
    # `state` by one step; fetch the parameters after the final update.
    params = get_params(state)

    # Get predictions from analytic computation.
    print('Computing analytic prediction.')
    fx_train, fx_test = predictor(fx_train, fx_test, FLAGS.train_time)

    # Print out summary data comparing the linear / nonlinear model.
    util.print_summary('train', y_train, f(params, x_train), fx_train, loss)
    util.print_summary('test', y_test, f(params, x_test), fx_test, loss)
コード例 #6
0
ファイル: batching_test.py プロジェクト: syyunn/jax
 def testGatherGradBatchedBoth(self, op_axis, idxs_axis, shape, dtype, idxs,
                               dnums, slice_sizes, rng_factory,
                               rng_idx_factory):
     """vmap of grad(sum(sin(gather))) over operand and indices matches a loop."""
     rng = rng_factory()
     rng_idx_factory()  # result not needed: idxs are passed in directly
     gather = partial(lax.gather, dimension_numbers=dnums,
                      slice_sizes=slice_sizes)

     def objective(x, idx):
         return np.sum(np.sin(gather(x, idx)))

     gfun = grad(objective)
     operand = rng(shape, dtype)
     n_batch = idxs.shape[idxs_axis]
     assert operand.shape[op_axis] == n_batch

     ans = vmap(gfun, (op_axis, idxs_axis))(operand, idxs)

     def select(arr, axis, i):
         # Take position i along `axis`, keeping every other axis whole.
         return arr[(slice(None), ) * axis + (i, )]

     expected = onp.stack([
         gfun(select(operand, op_axis, i), select(idxs, idxs_axis, i))
         for i in range(n_batch)
     ])
     self.assertAllClose(ans, expected, check_dtypes=False)
コード例 #7
0
ファイル: batching_test.py プロジェクト: zhangyixun3433/jax
  def testNpMaximumPerExampleGrad(self):
    """Per-example grad of a squared-ReLU loss via vmap matches the closed form.

    For fun(W, x) = sum(max(x @ W, 0) ** 2) the gradient w.r.t. W is
    2 * outer(x, max(x @ W, 0)); each vmapped row is checked against it.
    """
    randn = onp.random.RandomState(0).randn
    x = randn(10, 5)
    W = randn(5, 5)

    def fun(W, x):
      return np.sum(np.maximum(np.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = np.transpose(W)
    for i, row in enumerate(x):
      x_ex = row[None, :]
      hidden = np.maximum(np.dot(W_t, np.transpose(x_ex)), 0.0)
      expected_ans = np.transpose(2.0 * np.dot(hidden, x_ex))
      self.assertAllClose(ans[i], expected_ans, check_dtypes=False)
コード例 #8
0
def hvp(loss, params, batch, v):
  """Computes the Hessian-vector product Hv.

  Uses forward-over-reverse mode: the JVP of the gradient function
  ``grad(loss(., batch))`` at ``params`` in the direction ``v``.

  Args:
    loss: function computing the loss with signature
      loss(params, batch).
    params: pytree for the parameters of the model.
    batch:  A batch of data. Any format is fine as long as it is a valid input
      to loss(params, batch).
    v: pytree of the same structure as params.

  Returns:
    A pytree with the same structure as ``params`` holding Hv, where H is the
    Hessian of ``loss(., batch)`` evaluated at ``params``.
  """
  loss_fn = lambda x: loss(x, batch)
  # jvp returns (grad(loss_fn)(params), H @ v); only the tangent is needed.
  _, hessian_vector = jvp(grad(loss_fn), (params,), (v,))
  return hessian_vector
コード例 #9
0
ファイル: api_test.py プロジェクト: mitghi/jax
  def test_defvjp_all_multiple_arguments(self):
    """defvjp_all on a two-argument primitive, differentiating one or both args.

    Differentiating only w.r.t. the first argument also exercises passing a
    symbolic zero tangent for the second argument.
    """
    foo_p = Primitive('foo')

    def foo(x, y):
      return foo_p.bind(x, y)

    def vjpfun(x, y):
      # The cotangent rule is not the true derivative of x**2 + y**3; the
      # assertions below check that this custom rule is what grad uses.
      def vjp(g):
        return g + x + y, g * x * 9.
      return x**2 + y**3, vjp

    ad.defvjp_all(foo_p, vjpfun)

    x0, y0 = 3., 4.
    val_ans, grad_ans = api.value_and_grad(foo)(x0, y0)
    self.assertAllClose(val_ans, x0**2 + y0**3, check_dtypes=False)
    self.assertAllClose(grad_ans, 1. + x0 + y0, check_dtypes=False)

    both_grads = api.grad(foo, (0, 1))(x0, y0)
    self.assertAllClose(both_grads, (1. + x0 + y0, 1. * x0 * 9.),
                        check_dtypes=False)
コード例 #10
0
  def test_sparse_grad(self):
    """sparse.grad of a BCOO input matches the dense grad at the stored indices."""
    rng_sparse = rand_sparse(self.rng())
    rng = jtu.rand_default(self.rng())

    y = rng(5, "float32")
    X = rng_sparse((10, 5), "float32")
    Xsp = sparse.BCOO.fromdense(X)

    def f(X, y):
      return jnp.sum(X @ y)

    dense_grad = api.grad(f, argnums=0)(X, y)
    sparse_grad = sparse.grad(f, argnums=0)(Xsp, y)

    # The test expects the sparse gradient to carry entries only at Xsp's
    # stored indices, so every other dense-gradient entry is zeroed first.
    idx = tuple(Xsp.indices)
    expected = jnp.zeros_like(dense_grad).at[idx].set(dense_grad[idx])

    self.assertArraysEqual(sparse_grad.todense(), expected)
コード例 #11
0
  def test_remat_grad_python_control_flow(self):
    """remat(concrete=True) allows Python branching on the input's value."""

    @partial(api.remat, concrete=True)
    def g(x):
      if x > 0:
        return lax.sin(x), 3.
      return lax.cos(x), 4.

    def f(x):
      return g(x)[0]

    # At x = 2 the sin branch is taken, so f = sin and f' = cos.
    checks = ((f, onp.sin), (api.grad(f), onp.cos))
    for transformed, reference in checks:
      self.assertAllClose(transformed(2.), reference(2.), check_dtypes=False)
コード例 #12
0
  def testNpMaximumPerExampleGrad(self):
    """Per-example grad of a squared-ReLU loss via vmap matches the closed form.

    For fun(W, x) = sum(max(x @ W, 0) ** 2) the gradient w.r.t. W is
    2 * outer(x, max(x @ W, 0)); each vmapped row is checked against it, with
    a looser float32 tolerance on TPU.
    """
    randn = np.random.RandomState(0).randn
    x = randn(10, 5)
    W = randn(5, 5)

    def fun(W, x):
      return jnp.sum(jnp.maximum(jnp.dot(x, W), 0.0) ** 2)

    ans = vmap(partial(grad(fun), W))(x)

    W_t = jnp.transpose(W)
    tpu_atol = {np.float32:5e-2} if jtu.device_under_test() == "tpu" else None
    for i, row in enumerate(x):
      x_ex = row[None, :]
      hidden = jnp.maximum(jnp.dot(W_t, jnp.transpose(x_ex)), 0.0)
      expected_ans = jnp.transpose(2.0 * jnp.dot(hidden, x_ex))
      self.assertAllClose(ans[i], expected_ans, check_dtypes=False,
                          atol=tpu_atol)
コード例 #13
0
  def test_grad_simple(self):
    def func(x):
      y = hcb.id_print(x * 2., what="x * 2", output_stream=testing_stream)
      return x * hcb.id_print(y * 3., what="y * 3", output_stream=testing_stream)
    grad_func = api.grad(func)
    #assertMultiLineStrippedEqual(self, "", str(api.make_jaxpr(grad_func)(5.)))

    with hcb.outfeed_receiver():
      res_grad = grad_func(jnp.float32(5.))
    self.assertAllClose(2. * 5. * 6., res_grad, check_dtypes=False)
    assertMultiLineStrippedEqual(self, """
what: x * 2
10.00
what: y * 3
30.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: y * 3
5.00
transforms: ({'name': 'jvp'}, {'name': 'transpose'}) what: x * 2
15.00""", testing_stream.output)
    testing_stream.reset()
コード例 #14
0
ファイル: energy_test.py プロジェクト: zizai/jax-md
  def test_soft_sphere(self, spatial_dimension, alpha, dtype):
    """soft_sphere equals epsilon/alpha at dr=0 and vanishes at dr=sigma.

    For alpha == 3 its derivative w.r.t. dr at dr=sigma must also be zero.
    Checked for STOCHASTIC_SAMPLES random (sigma, epsilon) pairs.
    """
    key = random.PRNGKey(0)
    alpha = f32(alpha)
    for _ in range(STOCHASTIC_SAMPLES):
      key, sigma_key, epsilon_key = random.split(key, 3)
      sigma_draw = random.uniform(sigma_key, (1,), minval=0.0, maxval=3.0)[0]
      sigma = np.array(sigma_draw, dtype=dtype)
      epsilon_draw = random.uniform(
          epsilon_key, (1,), minval=0.0, maxval=4.0)[0]
      epsilon = np.array(epsilon_draw, dtype=dtype)

      at_origin = energy.soft_sphere(dtype(0), sigma, epsilon, alpha)
      self.assertAllClose(at_origin, epsilon / alpha, True)
      at_contact = energy.soft_sphere(dtype(sigma), sigma, epsilon, alpha)
      self.assertAllClose(at_contact, np.array(0.0, dtype=dtype), True)

      if alpha == 3.0:
        slope = grad(energy.soft_sphere)(dtype(sigma), sigma, epsilon, alpha)
        self.assertAllClose(slope, np.array(0, dtype=dtype), True)
コード例 #15
0
  def test_root_scalar(self):
    """Checks lax.root on y = sqrt(x**3) found by bisection.

    The root of f(y) = y**2 - x**3 is located with a bisection search and
    wrapped in `lax.root`, with `scalar_solve` supplied as the tangent solver.
    The value, first derivative, numerical gradient checks (order 2) and the
    jitted value are compared against x ** 1.5 at x = 5.
    """

    def scalar_solve(f, y):
      # Solves f(t) = y for t assuming f is scalar-linear: f(t) = f(1) * t.
      return y / f(1.0)

    def binary_search(func, x0, low=0.0, high=100.0, tolerance=1e-6):
      # Bisection on the bracket [low, high]. Assumes func is increasing there
      # (true for y**2 - c on y >= 0), so a positive value at the midpoint
      # means the root lies below it.
      del x0  # unused

      def cond(state):
        low, high = state
        # Keep bisecting until the bracket is narrower than `tolerance`.
        return high - low > tolerance

      def body(state):
        low, high = state
        midpoint = 0.5 * (low + high)
        update_upper = func(midpoint) > 0
        low = np.where(update_upper, low, midpoint)
        high = np.where(update_upper, midpoint, high)
        return (low, high)

      solution, _ = lax.while_loop(cond, body, (low, high))
      # The lower end of the final bracket is returned as the root.
      return solution

    def sqrt_cubed(x, tangent_solve=scalar_solve):
      # The non-negative root of y**2 = x**3 is x ** 1.5.
      f = lambda y: y ** 2 - x ** 3
      return lax.root(f, 0.0, binary_search, tangent_solve)

    value, grad = api.value_and_grad(sqrt_cubed)(5.0)
    self.assertAllClose(value, 5 ** 1.5, check_dtypes=False)
    # Reference derivative: d/dx pow(x, 1.5) at x = 5.
    self.assertAllClose(grad, api.grad(pow)(5.0, 1.5), check_dtypes=False)

    jtu.check_grads(sqrt_cubed, (5.0,), order=2, rtol=1e-3)

    # TODO(shoyer): reenable when batching works
    # inputs = np.array([4.0, 5.0])
    # results = api.vmap(sqrt_cubed)(inputs)
    # self.assertAllClose(results, inputs ** 1.5, check_dtypes=False)

    results = api.jit(sqrt_cubed)(5.0)
    self.assertAllClose(results, 5.0 ** 1.5, check_dtypes=False)
コード例 #16
0
ファイル: energy_test.py プロジェクト: kessel/jax-md
    def test_morse(self, spatial_dimension, dtype):
        """Checks energy.morse at the well minimum and in reduced units.

        For random (sigma, epsilon, alpha): at dr = sigma the energy equals
        -epsilon and its derivative w.r.t. dr is zero. Then, with
        dr = a / alpha + sigma, V / epsilon must equal (1 - exp(-a))**2 - 1
        on a grid of a values, independently of sigma, epsilon and alpha.
        """
        # NOTE(review): spatial_dimension is not used by this test.
        key = random.PRNGKey(0)

        for _ in range(STOCHASTIC_SAMPLES):
            key, split_sigma, split_epsilon, split_alpha = random.split(key, 4)
            sigma = dtype(
                random.uniform(split_sigma, (1, ), minval=0., maxval=3.0)[0])
            epsilon = dtype(
                random.uniform(split_epsilon, (1, ), minval=0.0,
                               maxval=4.0)[0])
            alpha = dtype(
                random.uniform(split_alpha, (1, ), minval=1.0, maxval=30.0)[0])
            dr = dtype(sigma)
            # At dr == sigma the potential must sit at -epsilon with zero slope.
            self.assertAllClose(energy.morse(dr, sigma, epsilon, alpha),
                                np.array(-epsilon, dtype=dtype))
            g = grad(energy.morse)(dr, sigma, epsilon, alpha)
            self.assertAllClose(g, np.array(0, dtype=dtype))

        # if dr = a/alpha + sigma, then V_morse(dr, sigma, epsilon, alpha)/epsilon
        #   should be independent of sigma, epsilon, and alpha, depending only on a.
        key, split_sigma, split_epsilon, split_alpha = random.split(key, 4)
        sigmas = random.uniform(split_sigma, (STOCHASTIC_SAMPLES, ),
                                minval=0.,
                                maxval=3.0)
        # epsilon is kept away from 0 because U below divides by it.
        epsilons = random.uniform(split_epsilon, (STOCHASTIC_SAMPLES, ),
                                  minval=0.1,
                                  maxval=4.0)
        alphas = random.uniform(split_alpha, (STOCHASTIC_SAMPLES, ),
                                minval=1.0,
                                maxval=30.0)
        for sigma, epsilon, alpha in zip(sigmas, epsilons, alphas):
            # a >= -alpha * sigma keeps dr = a / alpha + sigma non-negative
            # (alpha > 0 here).
            a = np.linspace(max(-2.5, -alpha * sigma), 8.0, 100)
            dr = np.array(a / alpha + sigma, dtype=dtype)
            U = energy.morse(dr, sigma, epsilon, alpha) / dtype(epsilon)
            Ucomp = np.array((dtype(1) - np.exp(-a))**dtype(2) - dtype(1),
                             dtype=dtype)
            self.assertAllClose(U, Ucomp)
コード例 #17
0
  def testStopGradient(self):
    """stop_gradient blocks differentiation, including through pytrees."""
    def f(x):
      return lax.sin(x) * lax.cos(lax.stop_gradient(x))

    def f2(x, y):
      return lax.sin(x) * lax.cos(y)

    x = 3.14
    # Stopping the gradient of the second use of x is equivalent to making it
    # an independent argument that grad (argnums=0) does not differentiate.
    first = api.grad
    second = lambda fn: api.grad(api.grad(fn))
    for transform in (first, second):
      self.assertAllClose(transform(f)(x), transform(f2)(x, x))

    pytree_grad = api.grad(lambda x: lax.stop_gradient({'foo':x})['foo'])(3.)
    self.assertAllClose(pytree_grad, onp.array(0.0), check_dtypes=False)

    with core.skipping_checks(), self.assertRaises(TypeError):
      lax.stop_gradient(lambda x: x)
コード例 #18
0
ファイル: jax_wideresnet_exps.py プロジェクト: google/autol2
def main(unused_argv):
    """Train a Wide ResNet on CIFAR-10 with an optionally auto-decayed L2 penalty.

    All configuration is read from ``FLAGS``. The parameters are replicated
    across devices and updated with ``pmap``. Every ``FLAGS.meas_step`` steps
    the test set and one training batch are evaluated, and every
    ``10 * FLAGS.meas_step`` steps the measurements are written to a CSV file
    under ``wrnlogs/``.

    When ``FLAGS.L2_sch`` is set, the L2 coefficient is divided by
    ``FLAGS.L2dec`` once a refractory period has passed and either the
    training bare loss or the training accuracy has failed to beat its running
    best in both of the last two measurements. Checkpoints are optionally
    saved every ``FLAGS.checkpointing`` steps.

    Args:
      unused_argv: Command-line arguments (ignored).
    """
    from jax.api import grad, jit, vmap, pmap, device_put

    # The following is required to use TPU Driver as JAX's backend.
    if FLAGS.TPU:
        config.FLAGS.jax_xla_backend = "tpu_driver"
        config.FLAGS.jax_backend_target = "grpc://" + os.environ[
            'TPU_ADDR'] + ':8470'
        TPU_ADDR = os.environ['TPU_ADDR']
    ndevices = xla_bridge.device_count()
    if not FLAGS.TPU:
        ndevices = 1

    # Every pmapped function below shares the device axis name 'i', which
    # update_step uses for lax.psum.
    pmap = partial(pmap, axis_name='i')

    # Setup some experiment parameters.
    meas_step = FLAGS.meas_step
    training_epochs = int(FLAGS.epochs)

    # In "physical" mode FLAGS.epochs is rescaled by the learning rate
    # (optionally times L2) to get the number of epochs actually run.
    tmult = 1.0
    if FLAGS.physical:
        tmult = FLAGS.lr
        if FLAGS.physicalL2:
            tmult = FLAGS.L2 * tmult
    if FLAGS.physical:
        training_epochs = 1 + int(FLAGS.epochs / tmult)

    print('Evolving for {:}e'.format(training_epochs))
    losst = FLAGS.losst
    learning_rate = FLAGS.lr
    batch_size_per_device = FLAGS.bs
    N = FLAGS.N
    K = FLAGS.K

    # 50000 is presumably the CIFAR-10 training-set size.
    batch_size = batch_size_per_device * ndevices
    steps_per_epoch = 50000 // batch_size
    training_steps = training_epochs * steps_per_epoch

    # Build the log filename from FLAGS.
    filename = 'wrnL2_' + losst + '_n' + str(N) + '_k' + str(K)
    if FLAGS.momentum:
        filename += '_mom'
    if FLAGS.L2_sch:
        filename += '_L2sch' + '_decay' + str(FLAGS.L2dec) + '_del' + str(
            FLAGS.delay)
    if FLAGS.seed != 1:
        filename += 'seed' + str(FLAGS.seed)
    filename += '_L2' + str(FLAGS.L2)
    if FLAGS.std_wrn_sch:
        filename += '_stddec'
        if FLAGS.physical:
            filename += 'phys'
    else:
        filename += '_ctlr'
    if not FLAGS.augment:
        filename += '_noaug'
    if not FLAGS.mix:
        filename += '_nomixup'
    filename += '_bs' + str(batch_size) + '_lr' + str(learning_rate)
    if FLAGS.jobdir is not None:
        filedir = os.path.join('wrnlogs', FLAGS.jobdir)
    else:
        filedir = 'wrnlogs'
    if not os.path.exists(filedir):
        os.makedirs(filedir)
    # From here on `filedir` is the full path of the CSV log file.
    filedir = os.path.join(filedir, filename + '.csv')

    print('Saving log to ', filename)
    print('Found {} cores.'.format(ndevices))

    # Load CIFAR10 data and create a minimal pipeline.
    train_images, train_labels, test_images, test_labels = utils.load_data(
        'cifar10')
    train_images = np.reshape(train_images, (-1, 32, 32 * 3))
    train = (train_images, train_labels)
    test = (test_images, test_labels)
    # Presumably the number of classes: labels are compared via argmax over
    # their last axis (one-hot) in `accuracy` below.
    k = train_labels.shape[-1]
    train = utils.shard_data(train, ndevices)
    test = utils.shard_data(test, ndevices)

    # Create a Wide Resnet and replicate its parameters across the devices.
    initparams, f, _ = utils.WideResnetnt(N, K, k)

    # Loss and optimizer definitions.
    # l2_reg is the sum of squares of all parameters.
    l2_norm = lambda params: tree_map(lambda x: np.sum(x**2), params)
    l2_reg = lambda params: tree_reduce(lambda x, y: x + y, l2_norm(params))
    currL2 = FLAGS.L2
    # Per-device copy of the current L2 coefficient for the pmapped functions.
    L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))

    def xentr(params, images_and_labels):
        """Cross-entropy loss (mean over all output entries)."""
        images, labels = images_and_labels
        return -np.mean(stax.logsoftmax(f(params, images)) * labels)

    def mse(params, data_tuple):
        """MSE loss."""
        x, y = data_tuple
        return 0.5 * np.mean((y - f(params, x))**2)

    if losst == 'xentr':
        print('Using xentr')
        lossm = xentr
    else:
        print('Using mse')
        lossm = mse

    loss = lambda params, data, L2: lossm(params, data) + L2 * l2_reg(params)

    def accuracy(params, images_and_labels):
        """Fraction of examples whose argmax prediction matches the label."""
        images, labels = images_and_labels
        return np.mean(
            np.array(np.argmax(f(params, images), axis=1) == np.argmax(labels,
                                                                       axis=1),
                     dtype=np.float32))

    # Define optimizer.
    if FLAGS.std_wrn_sch:
        # Step schedule: multiply the learning rate by 0.2 at 30%, 60% and
        # 90% of training.
        lr = learning_rate
        first_epoch = int(60 / 200 * training_epochs)
        learning_rate_fn = optimizers.piecewise_constant(
            np.array([1, 2, 3]) * first_epoch * steps_per_epoch,
            np.array([lr, lr * 0.2, lr * 0.2**2, lr * 0.2**3]))
    else:
        learning_rate_fn = optimizers.make_schedule(learning_rate)

    if FLAGS.momentum:
        momentum = 0.9
    else:
        momentum = 0

    @pmap
    def update_step(step, state, batch_state, L2):
        # One training step per device with gradients averaged across devices.
        # batch_fn, get_params, grad_loss and apply_fn are assigned further
        # down; they are looked up when this function first runs.
        batch, batch_state = batch_fn(batch_state)
        params = get_params(state)
        dparams = grad_loss(params, batch, L2)
        dparams = tree_map(lambda x: lax.psum(x, 'i') / ndevices, dparams)
        return step + 1, apply_fn(step, dparams, state), batch_state

    @pmap
    def evaluate(state, data, L2):
        # Returns (regularized loss, accuracy, bare loss, sum of squared
        # parameters) on `data`.
        params = get_params(state)
        lossmm = lossm(params, data)
        l2mm = l2_reg(params)
        return lossmm + L2 * l2mm, accuracy(params, data), lossmm, l2mm

    # Initialization and loading.
    _, params = initparams(random.PRNGKey(0), (-1, 32, 32, 3))
    replicate_array = lambda x: \
        np.broadcast_to(x, (ndevices,) + x.shape)
    replicated_params = tree_map(replicate_array, params)

    grad_loss = jit(grad(loss))
    init_fn, apply_fn, get_params = optimizers.momentum(
        learning_rate_fn, momentum)
    apply_fn = jit(apply_fn)
    key = random.PRNGKey(FLAGS.seed)

    batchinit_fn, batch_fn = utils.sharded_minibatcher(batch_size,
                                                       ndevices,
                                                       transform=FLAGS.augment,
                                                       k=k,
                                                       mix=FLAGS.mix)

    batch_state = pmap(batchinit_fn)(random.split(key, ndevices), train)
    state = pmap(init_fn)(replicated_params)

    if FLAGS.checkpointing:
        ## Loading of checkpoint if available/provided.
        single_state = init_fn(params)
        i0, load_state, load_params, filename0, batch_stateb = utils.load_weights(
            filename,
            single_state,
            params,
            full_file=FLAGS.load_w,
            ndevices=ndevices)
        if i0 is not None:
            filename = filename0
            if batch_stateb is not None:
                batch_state = batch_stateb
            if load_params is not None:
                state = pmap(init_fn)(load_params)
            else:
                state = load_state
        else:
            i0 = 0
    else:
        i0 = 0

    if FLAGS.steps_from_load:
        training_steps = i0 + training_steps

    batch_xs, _ = pmap(batch_fn)(batch_state)

    train_loss = []
    train_accuracy = []
    lrL = []
    test_loss = []
    test_accuracy = []
    test_L2, test_lm, train_lm, train_L2 = [], [], [], []
    L2_t = []
    # Running best training accuracy / bare loss (excluding the latest
    # measurement); None until first set. Only used when FLAGS.L2_sch is on.
    maxacc, minloss = None, None
    idel0 = i0
    start = time.time()

    step = pmap(lambda x: x)(i0 * np.ones((ndevices, )))

    # Start training loop.
    if FLAGS.checkpointing:
        print('Evolving for {:}e and saving every {:}s'.format(
            training_epochs, FLAGS.checkpointing))

    # NOTE(review): the 'Train Error'/'Test Error' columns actually print
    # accuracies (train_accuracy / test_accuracy) below.
    print(
        'Epoch\tLearning Rate\tTrain bareLoss\t L2_norm \tTest Loss\tTrain Error\tTest Error\tTime / Epoch'
    )

    for i in range(i0, training_steps):
        if i % meas_step == 0:
            # Make Measurement
            l, a, lm, L2m = evaluate(state, test, L2p)
            test_loss += [np.mean(l)]
            test_accuracy += [np.mean(a)]
            test_lm += [np.mean(lm)]
            test_L2 += [np.mean(L2m)]
            train_batch, _ = pmap(batch_fn)(batch_state)
            l, a, lm, L2m = evaluate(state, train_batch, L2p)

            train_loss += [np.mean(l)]
            train_accuracy += [np.mean(a)]
            train_lm += [np.mean(lm)]
            train_L2 += [np.mean(L2m)]
            L2_t.append(currL2)
            lrL += [learning_rate_fn(i)]

            if FLAGS.L2_sch and i > FLAGS.delay / currL2 + idel0 and len(
                    train_lm) > 2 and ((minloss <= train_lm[-1]
                                        and minloss <= train_lm[-2]) or
                                       (maxacc >= train_accuracy[-1]
                                        and maxacc >= train_accuracy[-2])):
                # If AutoL2 is on and we are beyond the refractory period,
                # decay if the loss or error have increased in the last two
                # measurements.
                print('Decaying L2 to', currL2 / FLAGS.L2dec)
                currL2 = currL2 / FLAGS.L2dec
                L2p = pmap(lambda x: x)(currL2 * np.ones((ndevices, )))
                idel0 = i

            elif FLAGS.L2_sch and len(train_lm) >= 2:
                # Update the running best values from the second-to-last
                # measurement.
                if maxacc is None:
                    maxacc, minloss = train_accuracy[-2], train_lm[-2]
                else:
                    maxacc = max(train_accuracy[-2], maxacc)
                    minloss = min(train_lm[-2], minloss)

            if i % (meas_step * 10) == 0 or i == i0:
                # Save measurements to csv
                epoch = batch_size * i / 50000
                # Approximate seconds per epoch, assuming meas_step * 10 steps
                # have run since `start`.
                dt = (time.time() - start) / (meas_step * 10) * steps_per_epoch
                print(('{}\t' + ('{: .4f}\t' * 7)).format(
                    epoch, learning_rate_fn(i), train_lm[-1], train_L2[-1],
                    test_loss[-1], train_accuracy[-1], test_accuracy[-1], dt))

                start = time.time()
                data = {
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'train_acc': train_accuracy,
                    'test_acc': test_accuracy
                }
                data['train_bareloss'] = train_lm
                data['train_L2'] = train_L2
                data['test_bareloss'] = test_lm
                data['test_L2'] = test_L2
                data['L2_t'] = L2_t
                df = pd.DataFrame(data)

                df['learning_rate'] = lrL
                df['width'] = K
                df['batch_size'] = batch_size
                df['step'] = i0 + onp.arange(0, len(train_loss)) * meas_step

                df.to_csv(filedir, index=False)

        if FLAGS.checkpointing:
            ### SAVE MODEL
            if i % FLAGS.checkpointing == 0 and i > i0:

                if not os.path.exists('weights/'):
                    os.makedirs('weights/')
                # Flatten the first component of the optimizer state
                # (presumably the parameters) into one vector.
                saveparams = tree_flatten(state[0])[0]
                if ndevices > 1:
                    saveparams = [el[0] for el in saveparams]
                saveparams = np.concatenate(
                    [el.reshape(-1) for el in saveparams])

                step0 = i
                print('Step', i)
                print('saving at', filename, step0, 'size:', saveparams.shape)

                utils.save_weights(filename, step0, saveparams, batch_state)

        ## UPDATE
        step, state, batch_state = update_step(step, state, batch_state, L2p)

    print('Training done')

    if FLAGS.TPU:
        with open('done/' + TPU_ADDR, 'w') as fp:
            fp.write(filedir)
コード例 #19
0
ファイル: advi.py プロジェクト: zhangfeilong/jax
 def update(i, opt_state):
     """Apply one optimizer step using the gradient of `objective` at step i."""
     objective_grad = grad(objective)
     current_params = minmax.get_params(opt_state)
     return opt_update(i, objective_grad(current_params, i), opt_state)
コード例 #20
0
 def apply_carry(carry, _):
     """One gradient-descent step (step size 0.1) on `energy_fn`.

     The carry is (step counter, position). The second argument is returned
     unchanged as the per-step output (presumably a `lax.scan` body).
     """
     counter, x = carry
     x_next = x - 0.1 * api.grad(energy_fn)(x)
     return (counter + 1, x_next), _
コード例 #21
0
 def apply_carry(x, i):
     """Replace x by grad(fn)(x) and pass i through unchanged."""
     (dx,) = api.grad(fn, argnums=(0, ))(x)
     return dx, i
コード例 #22
0
st.header("Comparison against finite SGD-NNs")
"""
Finally, let's bring back the practitioner-loved Neural Networks for a comparison.
"""

learning_rate = st.slider("Learning rate",
                          1e-4,
                          1.0,
                          0.1,
                          step=1e-4,
                          format="%.4f")

# Train the finite network with plain SGD on the MSE loss, recording the
# train and test loss after every step.
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_update = jit(opt_update)
loss = jit(lambda params, x, y: 0.5 * np.mean((apply_fn(params, x) - y)**2))
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

train_losses = []
test_losses = []

opt_state = opt_init(params)

for i in range(training_steps):
    opt_state = opt_update(i, grad_loss(opt_state, *train), opt_state)

    train_losses += [loss(get_params(opt_state), *train)]
    test_losses += [loss(get_params(opt_state), *test)]

# NTK loss: plot the NTK mean train/test loss curves on log-log axes.
plt.loglog(ts, ntk_train_loss_mean, linewidth=3)
plt.loglog(ts, ntk_test_loss_mean, linewidth=3)
コード例 #23
0
    def testScanRnn(self):
        """A lax.scan RNN supports eval, jvp, linearize, grad and vmap.

        A small tanh RNN is unrolled with `lax.scan`. The test checks that
        forward- and reverse-mode numerical gradient checks pass, that
        `linearize` agrees with `jvp`, and that `vmap` over a batch of
        sequences matches a Python loop over the batch.
        """
        r = npr.RandomState(0)

        n_in = 4
        n_hid = 2
        n_out = 1
        length = 3

        # All parameters and data come from one seeded RandomState in this
        # order; reordering the draws would change the test inputs.
        W_trans = r.randn(n_hid, n_hid + n_in)
        W_out = r.randn(n_out, n_hid + n_in)
        params = W_trans, W_out

        inputs = r.randn(length, n_in)
        targets = r.randn(length, n_out)

        def step(params, state, input):
            # One RNN step: the next hidden state and the output are both tanh
            # of a linear map of the concatenated [state, input] vector.
            W_trans, W_out = params
            stacked = np.concatenate([state, input])
            output = np.tanh(np.dot(W_out, stacked))
            next_state = np.tanh(np.dot(W_trans, stacked))
            return next_state, output

        def rnn(params, inputs):
            # Scan `step` over the leading (time) axis of `inputs`, starting
            # from a zero hidden state, and return the per-step outputs.
            init_state = np.zeros(n_hid)
            _, outputs = lax.scan(partial(step, params), init_state, inputs)
            return outputs

        def loss(params, inputs, targets):
            # Sum of squared errors over all time steps.
            predictions = rnn(params, inputs)
            return np.sum((predictions - targets)**2)

        # evaluation doesn't crash
        loss(params, inputs, targets)

        # jvp evaluation doesn't crash
        api.jvp(lambda params: loss(params, inputs, targets), (params, ),
                (params, ))

        # jvp numerical check passes
        jtu.check_grads(loss, (params, inputs, targets),
                        order=2,
                        modes=["fwd"])

        # linearize works: its linear map applied to the same tangents must
        # reproduce the jvp output.
        _, expected = api.jvp(loss, (params, inputs, targets),
                              (params, inputs, targets))
        _, linfun = api.linearize(loss, params, inputs, targets)
        ans = linfun(params, inputs, targets)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # gradient evaluation doesn't crash
        api.grad(loss)(params, inputs, targets)

        # gradient check passes
        jtu.check_grads(loss, (params, inputs, targets), order=2)

        # we can vmap to batch things
        batch_size = 7
        batched_inputs = r.randn(batch_size, length, n_in)
        batched_targets = r.randn(batch_size, length, n_out)
        batched_loss = api.vmap(lambda x, y: loss(params, x, y))
        losses = batched_loss(batched_inputs, batched_targets)
        expected = onp.stack(
            list(
                map(lambda x, y: loss(params, x, y), batched_inputs,
                    batched_targets)))
        self.assertAllClose(losses, expected, check_dtypes=False)
コード例 #24
0
def update(_, i, opt_state, batch):
    """Apply one optimizer step computed from the gradient of `loss` on `batch`.

    The first positional argument is ignored. `i` is the step index forwarded
    to `opt_update`. Returns the optimizer state produced by `opt_update`.
    """
    current_params = get_params(opt_state)
    gradients = grad(loss)(current_params, batch)
    return opt_update(i, gradients, opt_state)
コード例 #25
0
 def update(params, batch):
     """Return new parameters after one plain gradient-descent step.

     `params` is iterated as (w, b) pairs. Each pair is moved against its
     gradient, scaled by `step_size`. A new list of (w, b) tuples is returned,
     and `params` itself is not modified.
     """
     grads = grad(loss)(params, batch)
     new_params = []
     for (w, b), (dw, db) in zip(params, grads):
         new_params.append((w - step_size * dw, b - step_size * db))
     return new_params
コード例 #26
0
def update(params, batch):
    """Return new parameters after one gradient-descent step of `learning_rate`.

    `params` is iterated as (w, b) pairs. The result is a new list of
    (w, b) tuples, and the input is left unchanged.
    """
    grads = grad(loss)(params, batch)

    def sgd_step(layer, layer_grad):
        # Move one (w, b) pair against its gradient.
        (w, b), (dw, db) = layer, layer_grad
        return (w - learning_rate * dw, b - learning_rate * db)

    return [sgd_step(layer, layer_grad)
            for layer, layer_grad in zip(params, grads)]
コード例 #27
0
# Loss used below to compute the gradient with respect to the input image.
def computation(params, inputs, targets):
    """Return the mean cross-entropy between predictions and `targets`.

    The logits from `predict` are turned into log-probabilities with
    `stax.logsoftmax`. These are weighted by `targets` (one-hot labels in the
    usage below), summed over axis 1 and averaged over the remaining axis.
    """
    log_probs = stax.logsoftmax(predict(params, inputs))
    per_example = np.sum(log_probs * targets, axis=1)
    return -np.mean(per_example)


# Locate the first test example whose one-hot label is class 7.
tl = test_labels
index7 = tl.tolist().index([0, 0, 0, 0, 0, 0, 0, 1, 0, 0])
print(test_labels[index7])
# Perturb the image by `hyper` times the sign of the loss gradient with
# respect to the image (argument 1 of `computation`).
input_image, input_label = shape_as_image(test_images[index7],
                                          test_labels[index7])
grad_newx = grad(computation, 1)(params, input_image, input_label)
newx = input_image + hyper * np.sign(grad_newx)
# Compare the true class with the model's prediction on the perturbed input.
target_class = np.argmax(input_label)
# Run the forward pass once and reuse it for both the vector and the class.
predict_vector = predict(params, newx)
predicted_class = np.argmax(predict_vector)
print('the target class is :', target_class)
print('the predict class is :', predicted_class)
print('the predicted vector is :', predict_vector)

# Display the perturbed image as a 28x28 picture.
# NOTE(review): imshow normalizes 2-D float data to its own min/max by default,
# so the *255 scaling does not change the rendered picture.
image = np.array(newx)
image = image * 255
image = image.reshape(28, 28)
plt.imshow(image)
"""## From here is Part 2"""
コード例 #28
0
 def testGradOfXlog1pyAtZero(self):
     """d/dy xlog1py(0, y) at y = -1 is expected to be 0.

     log1p(y) diverges at y = -1, but the x = 0 factor should make the
     gradient zero.
     """
     def xlog1py_with_zero_x(y):
         return lsp_special.xlog1py(0., y)

     grad_at_minus_one = api.grad(xlog1py_with_zero_x)(-1.)
     self.assertAllClose(grad_at_minus_one, 0., check_dtypes=False)
コード例 #29
0
    def testTrainedEnsemblePredCov(self, train_shape, test_shape, network,
                                   out_logits):
        """Compares NTK mean/covariance predictions with a trained ensemble.

        Trains `ensemble_size` independently initialized networks with SGD,
        vmapped over initialization keys. It then checks that:
          * the ensemble-averaged training loss is below 1e-5;
          * the empirical mean and covariance of their test outputs match
            `predict.gp_inference` with the 'ntk' kernel.
        """
        # NOTE(review): `network` is unused here; presumably it comes from the
        # shared test parameterization.
        if xla_bridge.get_backend().platform == 'gpu' and config.read(
                'jax_enable_x64'):
            raise jtu.SkipTest('Not running GPU x64 to save time.')
        training_steps = 5000
        learning_rate = 1.0
        ensemble_size = 50

        init_fn, apply_fn, ker_fn = stax.serial(
            stax.Dense(1024, W_std=1.2, b_std=0.05), stax.Erf(),
            stax.Dense(out_logits, W_std=1.2, b_std=0.05))

        opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
        opt_update = jit(opt_update)

        key = random.PRNGKey(0)
        # Kept as-is: every later random draw derives from this key, so
        # changing the split would change all data and initializations below.
        key, = random.split(key, 1)

        key, split = random.split(key)
        x_train = np.cos(random.normal(split, train_shape))

        key, split = random.split(key)
        y_train = np.array(
            random.bernoulli(split, shape=(train_shape[0], out_logits)),
            np.float32)
        train = (x_train, y_train)
        key, split = random.split(key)
        x_test = np.cos(random.normal(split, test_shape))

        # One initialization key per ensemble member.
        ensemble_key = random.split(key, ensemble_size)

        loss = jit(lambda params, x, y: 0.5 * np.mean(
            (apply_fn(params, x) - y)**2))
        grad_loss = jit(lambda state, x, y: grad(loss)
                        (get_params(state), x, y))

        def train_network(key):
            # Initialize one member from `key` and run full-batch SGD.
            _, params = init_fn(key, (-1, ) + train_shape[1:])
            opt_state = opt_init(params)
            for i in range(training_steps):
                opt_state = opt_update(i, grad_loss(opt_state, *train),
                                       opt_state)

            return get_params(opt_state)

        params = vmap(train_network)(ensemble_key)

        # Outputs and losses of every member; axis 0 indexes the ensemble.
        ensemble_fx = vmap(apply_fn, (0, None))(params, x_test)
        ensemble_loss = vmap(loss, (0, None, None))(params, x_train, y_train)
        ensemble_loss = np.mean(ensemble_loss)
        # The third positional argument of assertLess is `msg`, so no extra
        # flag is passed here.
        self.assertLess(ensemble_loss, 1e-5)

        # Empirical covariance between test points, averaged over ensemble
        # members (axis 0) and output logits (last axis).
        mean_emp = np.mean(ensemble_fx, axis=0)
        mean_subtracted = ensemble_fx - mean_emp
        cov_emp = np.einsum(
            'ijk,ilk->jl', mean_subtracted, mean_subtracted, optimize=True) / (
                mean_subtracted.shape[0] * mean_subtracted.shape[-1])

        reg = 1e-7
        ntk_predictions = predict.gp_inference(ker_fn,
                                               x_train,
                                               y_train,
                                               x_test,
                                               'ntk',
                                               reg,
                                               compute_cov=True)

        self.assertAllClose(mean_emp, ntk_predictions.mean, True, RTOL, ATOL)
        self.assertAllClose(cov_emp, ntk_predictions.covariance, True, RTOL,
                            ATOL)
コード例 #30
0
    def testNTKMomentumPrediction(self, train_shape, test_shape, network,
                                  out_logits, fn_and_kernel):
        """Checks that `predict.momentum` tracks a network trained with momentum.

        The network is trained with the `momentum` optimizer on an MSE loss for
        `train_time`. Its train and test outputs are then compared with the
        NTK-based prediction from `predict.momentum`. Errors are normalized by
        the RMS change of the outputs from their initial values.
        """
        key = random.PRNGKey(0)

        key, subkey = random.split(key)
        x_train = random.normal(subkey, train_shape)

        key, subkey = random.split(key)
        labels = random.bernoulli(subkey, shape=(train_shape[0], out_logits))
        y_train = np.array(labels, np.float32)

        key, subkey = random.split(key)
        x_test = random.normal(subkey, test_shape)

        params, f, ntk = fn_and_kernel(key, train_shape[1:], network,
                                       out_logits)

        # Mean-squared-error regression objective.
        def loss(y, y_hat):
            return 0.5 * np.mean((y - y_hat)**2)

        grad_loss = jit(grad(lambda p, x: loss(f(p, x), y_train)))

        g_dd = ntk(x_train, None, 'ntk')
        g_td = ntk(x_test, x_train, 'ntk')

        # Inputs of rank > 2 (convolutions) get a coarse, deliberately loose
        # tolerance and a smaller step size.
        is_conv = len(train_shape) > 2
        atol = ATOL * 2 if is_conv else ATOL
        rtol = RTOL * 2 if is_conv else RTOL
        step_size = 0.1 if is_conv else 0.5

        train_time = 100.0
        steps = int(train_time / np.sqrt(step_size))

        init, predictor, get = predict.momentum(g_dd, y_train, loss, step_size,
                                                g_td)

        opt_init, opt_update, get_params = momentum(step_size, 0.9)
        opt_state = opt_init(params)

        fx_initial_train = f(params, x_train)
        fx_initial_test = f(params, x_test)

        # At t = 0 the predictor state must reproduce the initial outputs.
        lin_state = init(fx_initial_train, fx_initial_test)
        fx_pred_train, fx_pred_test = get(lin_state)
        self.assertAllClose(fx_initial_train, fx_pred_train, True)
        self.assertAllClose(fx_initial_test, fx_pred_test, True)

        for step in range(steps):
            grads = grad_loss(get_params(opt_state), x_train)
            opt_state = opt_update(step, grads, opt_state)

        trained = get_params(opt_state)
        fx_train = f(trained, x_train)
        fx_test = f(trained, x_test)

        lin_state = predictor(lin_state, train_time)
        fx_pred_train, fx_pred_test = get(lin_state)

        def normalized_error(actual, predicted, initial):
            # Divide by the RMS distance moved so the tolerance is relative.
            displacement = np.sqrt(np.mean((actual - initial)**2))
            return (actual - predicted) / displacement

        fx_error_train = normalized_error(fx_train, fx_pred_train,
                                          fx_initial_train)
        fx_error_test = normalized_error(fx_test, fx_pred_test,
                                         fx_initial_test)

        for err in (fx_error_train, fx_error_test):
            self.assertAllClose(err, np.zeros_like(err), True, rtol, atol)