Exemple #1
0
 def testLuGrad(self, shape, dtype, rng):
     """Run jtu.check_grads on jsp.linalg.lu up to order 2 (atol=5e-2, rtol=1e-1)."""
     _skip_if_unsupported_type(dtype)
     matrix = rng(shape, dtype)
     jtu.check_grads(jsp.linalg.lu, (matrix,), order=2, atol=5e-2, rtol=1e-1)
Exemple #2
0
 def testTopKGrad(self, shape, dtype, k):
   """Gradient-check the values output (index 0) of lax.top_k up to order 2."""
   # A shuffled arange has no repeated entries, so top-k never sees ties.
   ordered = np.arange(prod(shape), dtype=dtype)
   values = self.rng().permutation(ordered).reshape(shape)

   def top_k_values(vs):
     return lax.top_k(vs, k=k)[0]

   check_grads(top_k_values, (values,), 2, ["fwd", "rev"], eps=1e-2)
Exemple #3
0
 def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
   """Gradient-check lax.broadcast_in_dim (order 2, fwd and rev)."""
   sample = jtu.rand_default(self.rng())
   x = sample(inshape, dtype)

   def broadcast_in_dim(arr):
     return lax.broadcast_in_dim(arr, outshape, dimensions)

   check_grads(broadcast_in_dim, (x,), 2, ["fwd", "rev"], eps=1.)
Exemple #4
0
 def testSliceGrad(self, shape, dtype, starts, limits, strides):
   """Gradient-check lax.slice (order 2, fwd and rev) on a random operand.

   The closure is named ``slice_fn`` so it does not shadow the ``slice``
   builtin inside this method.
   """
   rng = jtu.rand_default(self.rng())
   operand = rng(shape, dtype)

   def slice_fn(x):
     return lax.slice(x, starts, limits, strides)

   check_grads(slice_fn, (operand,), 2, ["fwd", "rev"], eps=1.)
Exemple #5
0
 def testReluGrad(self):
     """Third-order gradient checks of nn.relu at +1 and -1, plus a jaxpr check at 0.

     A looser rtol (1e-2) is used when the device under test is a TPU.
     """
     if jtu.device_under_test() == "tpu":
         rtol = 1e-2
     else:
         rtol = None
     for point in (1., -1.):
         check_grads(nn.relu, (point, ), order=3, rtol=rtol)
     # Trace grad(relu) at the kink x == 0 and require at least two equations.
     grad_jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
     self.assertGreaterEqual(len(grad_jaxpr.jaxpr.eqns), 2)
Exemple #6
0
 def testOpGradSpecialValue(self, op, special_value, tol):
   """Gradient-check ``op`` to order 2 (fwd and rev) at one special input."""
   modes = ["fwd", "rev"]
   check_grads(op, (special_value,), 2, modes, atol=tol, rtol=tol)
Exemple #7
0
 def testSoftplusGrad(self):
   """Fourth-order gradient check of nn.softplus at a tiny positive input."""
   tiny_input = 1e-8
   check_grads(nn.softplus, (tiny_input,), order=4)
Exemple #8
0
 def testSoftplusGradZero(self):
     """First-order gradient check of nn.softplus at exactly 0 (rtol=1e-2 on TPU)."""
     on_tpu = jtu.device_under_test() == "tpu"
     check_grads(nn.softplus, (0., ), order=1, rtol=1e-2 if on_tpu else None)
Exemple #9
0
 def testTransposeGrad(self, shape, dtype, perm, rng_factory):
   """Gradient-check lax.transpose with permutation ``perm`` (order 2)."""
   make_operand = rng_factory(self.rng())
   x = make_operand(shape, dtype)

   def transpose(arr):
     return lax.transpose(arr, perm)

   check_grads(transpose, (x,), 2, ["fwd", "rev"], eps=1.)
Exemple #10
0
 def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
   """Gradient-check lax.index_take with ``idxs`` over ``axes`` (order 2)."""
   make_source = rng_factory(self.rng())
   source = make_source(shape, dtype)

   def index_take(arr):
     return lax.index_take(arr, idxs, axes)

   check_grads(index_take, (source,), 2, ["fwd", "rev"], eps=1.)
Exemple #11
0
 def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
   """Gradient-check lax.broadcast with ``broadcast_sizes`` (order 2)."""
   sample = rng_factory(self.rng())
   operand = sample(shape, dtype)

   def broadcast(x):
     return lax.broadcast(x, broadcast_sizes)

   check_grads(broadcast, (operand,), 2, ["fwd", "rev"], eps=1.)
Exemple #12
0
 def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
   """Gradient-check cumulative reduction ``op`` along ``axis`` to order 2."""
   sample = rng_factory(self.rng())
   reduce_along_axis = partial(op, axis=axis)
   operand = sample(shape, dtype)
   check_grads(reduce_along_axis, (operand,), order=2)
Exemple #13
0
 def testSoftplusGrad(self):
     """Fourth-order gradient check of nn.softplus at 1e-8 (rtol=1e-2 on TPU)."""
     if jtu.device_under_test() == "tpu":
         tolerance = 1e-2
     else:
         tolerance = None
     check_grads(nn.softplus, (1e-8, ), 4, rtol=tolerance)
Exemple #14
0
 def testIndexTakeGrad(self, shape, dtype, idxs, axes):
   """Gradient-check lax.index_take with ``idxs`` over ``axes`` (order 2)."""
   sample = jtu.rand_default(self.rng())
   source = sample(shape, dtype)
   check_grads(lambda arr: lax.index_take(arr, idxs, axes), (source,), 2,
               ["fwd", "rev"], eps=1.)
Exemple #15
0
 def testSlogdetGrad(self, shape, dtype, rng_factory):
   """Gradient-check np.linalg.slogdet to order 2 with atol = rtol = 1e-1."""
   sample = rng_factory()
   _skip_if_unsupported_type(dtype)
   matrix = sample(shape, dtype)
   tol = 1e-1
   jtu.check_grads(np.linalg.slogdet, (matrix,), 2, atol=tol, rtol=tol)
Exemple #16
0
    def testLuGrad(self, shape, dtype, rng):
        """Gradient-check jsp.linalg.lu to order 2 with rtol=1e-1."""
        matrix = rng(shape, dtype)
        jtu.check_grads(jsp.linalg.lu, (matrix, ), order=2, rtol=1e-1)
Exemple #17
0
 def testLuGrad(self, shape, dtype, rng_factory):
   """Gradient-check jsp.linalg.lu to order 2; rank > 2 inputs go through vmap."""
   sample = rng_factory()
   _skip_if_unsupported_type(dtype)
   matrices = sample(shape, dtype)
   # Inputs with more than two dimensions are mapped over their leading axis.
   if len(shape) > 2:
     lu_fn = vmap(jsp.linalg.lu)
   else:
     lu_fn = jsp.linalg.lu
   jtu.check_grads(lu_fn, (matrices,), 2, atol=5e-2, rtol=1e-1)
Exemple #18
0
 def testSoftplusGradNan(self):
     """First-order gradient check of nn.softplus at a NaN input (rtol=1e-2 on TPU)."""
     nan_input = float('nan')
     rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
     check_grads(nn.softplus, (nan_input, ), order=1, rtol=rtol)
    def testScanRnn(self):
        """Exercise lax.scan through a tiny RNN loss.

        Checks that the loss evaluates, that jvp / linearize / grad work and
        agree numerically (via check_grads), and that vmap over a batch of
        sequences matches a Python loop over the same batch.
        """
        r = npr.RandomState(0)

        n_in = 4
        n_hid = 2
        n_out = 1
        length = 3

        W_trans = r.randn(n_hid, n_hid + n_in)
        W_out = r.randn(n_out, n_hid + n_in)
        params = W_trans, W_out

        inputs = r.randn(length, n_in)
        targets = r.randn(length, n_out)

        def step(params, state, x_t):
            # One RNN cell: concatenate the hidden state with the current
            # input and apply two tanh-activated linear maps. (Named x_t so
            # the builtin ``input`` is not shadowed.)
            W_trans, W_out = params
            stacked = np.concatenate([state, x_t])
            output = np.tanh(np.dot(W_out, stacked))
            next_state = np.tanh(np.dot(W_trans, stacked))
            return next_state, output

        def rnn(params, inputs):
            init_state = np.zeros(n_hid)
            _, outputs = lax.scan(partial(step, params), init_state, inputs)
            return outputs

        def loss(params, inputs, targets):
            predictions = rnn(params, inputs)
            return np.sum((predictions - targets)**2)

        # evaluation doesn't crash
        loss(params, inputs, targets)

        # jvp evaluation doesn't crash
        api.jvp(lambda params: loss(params, inputs, targets), (params, ),
                (params, ))

        # jvp numerical check passes
        jtu.check_grads(loss, (params, inputs, targets),
                        order=2,
                        modes=["fwd"])

        # linearize works and agrees with jvp when tangents equal primals
        _, expected = api.jvp(loss, (params, inputs, targets),
                              (params, inputs, targets))
        _, linfun = api.linearize(loss, params, inputs, targets)
        ans = linfun(params, inputs, targets)
        self.assertAllClose(ans, expected, check_dtypes=False)

        # gradient evaluation doesn't crash
        api.grad(loss)(params, inputs, targets)

        # gradient check passes
        jtu.check_grads(loss, (params, inputs, targets), order=2)

        # we can vmap to batch things; compare against an explicit loop
        batch_size = 7
        batched_inputs = r.randn(batch_size, length, n_in)
        batched_targets = r.randn(batch_size, length, n_out)
        batched_loss = api.vmap(lambda x, y: loss(params, x, y))
        losses = batched_loss(batched_inputs, batched_targets)
        expected = onp.stack([
            loss(params, x, y)
            for x, y in zip(batched_inputs, batched_targets)
        ])
        self.assertAllClose(losses, expected, check_dtypes=False)
Exemple #20
0
 def testEluGrad(self):
     """Fourth-order gradient check of nn.elu at a large positive input, eps=1."""
     large_input = 1e4
     check_grads(nn.elu, (large_input, ), order=4, eps=1.)
Exemple #21
0
 def testTransposeGrad(self, shape, dtype, perm):
   """Gradient-check lax.transpose with permutation ``perm`` (order 2)."""
   sample = jtu.rand_default(self.rng())
   x = sample(shape, dtype)
   check_grads(lambda arr: lax.transpose(arr, perm), (x,), 2, ["fwd", "rev"],
               eps=1.)
Exemple #22
0
 def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
   """Gradient-check lax.broadcast with ``broadcast_sizes`` (order 2)."""
   sample = jtu.rand_default(self.rng())
   operand = sample(shape, dtype)

   def broadcast(x):
     return lax.broadcast(x, broadcast_sizes)

   check_grads(broadcast, (operand,), 2, ["fwd", "rev"], eps=1.)
Exemple #23
0
 def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
   """Gradient-check cumulative reduction ``op`` along ``axis`` to order 2.

   Integer dtypes draw inputs with jtu.rand_default; all others use
   jtu.rand_small.
   """
   if dtypes.issubdtype(dtype, np.integer):
     sampler_factory = jtu.rand_default
   else:
     sampler_factory = jtu.rand_small
   sample = sampler_factory(self.rng())
   cumulative_op = partial(op, axis=axis, reverse=reverse)
   check_grads(cumulative_op, (sample(shape, dtype),), order=2)
Exemple #24
0
 def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
   """Gradient-check lax.reshape to ``out_shape`` with ``permutation`` (order 2)."""
   sample = jtu.rand_default(self.rng())
   x = sample(arg_shape, dtype)

   def reshape(arr):
     return lax.reshape(arr, out_shape, permutation)

   check_grads(reshape, (x,), 2, ["fwd", "rev"], eps=1.)
Exemple #25
0
 def testSortGrad(self, shape, dtype, axis, is_stable):
   """Gradient-check lax.sort along ``axis`` (order 2, eps=1e-2)."""
   sample = jtu.rand_default(self.rng())
   x = sample(shape, dtype)

   def sort(arr):
     return lax.sort(arr, dimension=axis, is_stable=is_stable)

   check_grads(sort, (x,), 2, ["fwd", "rev"], eps=1e-2)
Exemple #26
0
 def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
   """Gradient-check lax.dynamic_slice (order 2, fwd and rev)."""
   sample = jtu.rand_default(self.rng())
   x = sample(shape, dtype)

   def dynamic_slice(arr):
     return lax.dynamic_slice(arr, start_indices, size_indices)

   check_grads(dynamic_slice, (x,), 2, ["fwd", "rev"], eps=1.)
Exemple #27
0
def _check_potential_grads(energy_fn, x0, params, box):
    """Run first- and second-order ``check_grads`` on ``energy_fn`` at one point."""
    args = (x0, params, box)
    check_grads(energy_fn, args, order=1, eps=1e-5)
    check_grads(energy_fn, args, order=2, eps=1e-7)


def _make_potential_derivatives(energy_fn, params):
    """Return ``f(x, box) -> (E, dE/dx, dE/dparams, d2E/(dx dparams))``.

    All quantities are evaluated at the fixed ``params``. The second entry is
    the plain gradient dE/dx (no sign flip to a force).
    """
    dEdx_fn = jax.grad(energy_fn, argnums=(0, ))
    dEdp_fn = jax.jacfwd(energy_fn, argnums=(1, ))
    d2Edxdp_fn = jax.jacfwd(dEdp_fn, argnums=(0, ))

    def derivatives(x, box):
        return (
            energy_fn(x, params, box),
            dEdx_fn(x, params, box)[0],
            dEdp_fn(x, params, box)[0],
            # dEdp_fn returns a 1-tuple and jacfwd adds another level for
            # argnums=(0,), hence the double index.
            d2Edxdp_fn(x, params, box)[0][0],
        )

    return derivatives


def _assert_derivatives_match(actual, reference, rot_matrix, rtol):
    """Assert derivatives at a rigidly moved configuration equal the reference.

    ``actual`` and ``reference`` are 4-tuples as produced by the function
    returned from ``_make_potential_derivatives``. The energy and dE/dparams
    must be unchanged. If the coordinates were rotated as ``x @ rot_matrix``,
    the x-derivatives of an invariant energy transform the same way (the
    matrix is orthogonal), so the reference ones are rotated before comparing.
    Pass ``rot_matrix=None`` for a pure translation.
    """
    energy, dEdx, dEdp, d2Edxdp = actual
    ref_energy, ref_dEdx, ref_dEdp, ref_d2Edxdp = reference
    if rot_matrix is not None:
        # np.matmul broadcasts over leading axes, so the trailing spatial axis
        # is rotated for every parameter index at once.
        ref_dEdx = np.matmul(ref_dEdx, rot_matrix)
        ref_d2Edxdp = np.matmul(ref_d2Edxdp, rot_matrix)
    np.testing.assert_allclose(energy, ref_energy, rtol=rtol)
    np.testing.assert_allclose(dEdx, ref_dEdx, rtol=rtol)
    np.testing.assert_allclose(dEdp, ref_dEdp, rtol=rtol)
    np.testing.assert_allclose(d2Edxdp, ref_d2Edxdp, rtol=rtol)


def assert_potential_invariance(energy_fn, x0, params, box=None, *,
                                num_samples=3, rtol=1e-10):
    """Check a potential's derivatives and its invariance under rigid motions.

    First, ``energy_fn`` is gradient-checked numerically (orders 1 and 2),
    always with ``box=None`` and additionally with ``box`` when one is given.

    Then, without a box, four quantities at ``x0`` are compared against the
    same quantities at randomly translated, randomly rotated, and randomly
    rotated-then-translated copies of ``x0``. The quantities are the energy,
    dE/dx, dE/dparams and d2E/(dx dparams). With a box, only random
    translations are compared.

    Random transforms come from NumPy's global random state
    (``np.random.rand`` and ``special_ortho_group.rvs`` with its default
    ``random_state``); seed it for reproducible runs.

    Args:
        energy_fn: callable ``(x, params, box) -> energy``. Every potential
            must support the non-periodic case, signalled by ``box=None``.
        x0: coordinates. Rotations are applied as ``x0 @ R`` with a 3x3 ``R``
            and translations add a length-3 vector, so the last axis must
            have length 3.
        params: potential parameters; derivatives with respect to them are
            checked as well.
        box: optional periodic box. ``None`` skips the periodic checks.
        num_samples: number of random transforms drawn for each kind of check.
        rtol: relative tolerance for the invariance comparisons.

    Raises:
        AssertionError: from ``np.testing.assert_allclose`` when an invariance
            check fails. Failures raised by ``check_grads`` propagate
            unchanged.
    """
    # Explicitly check without a box even if a box is supplied.
    _check_potential_grads(energy_fn, x0, params, None)
    if box is not None:
        _check_potential_grads(energy_fn, x0, params, box)

    derivatives = _make_potential_derivatives(energy_fn, params)

    # Non-periodic: translations, rotations, and rotation followed by translation.
    reference = derivatives(x0, None)
    for _ in range(num_samples):
        trans_vector = np.random.rand(3).astype(dtype=np.float64)
        _assert_derivatives_match(derivatives(x0 + trans_vector, None),
                                  reference, None, rtol)
    for _ in range(num_samples):
        rot_matrix = special_ortho_group.rvs(3).astype(dtype=np.float64)
        _assert_derivatives_match(derivatives(np.matmul(x0, rot_matrix), None),
                                  reference, rot_matrix, rtol)
    for _ in range(num_samples):
        trans_vector = np.random.rand(3).astype(dtype=np.float64)
        rot_matrix = special_ortho_group.rvs(3).astype(dtype=np.float64)
        comp_x = np.matmul(x0, rot_matrix) + trans_vector
        _assert_derivatives_match(derivatives(comp_x, None), reference,
                                  rot_matrix, rtol)

    # Periodic: only translations are checked, presumably because an
    # arbitrary rotation does not map the periodic box onto itself.
    if box is not None:
        reference = derivatives(x0, box)
        for _ in range(num_samples):
            trans_vector = np.random.rand(3).astype(dtype=np.float64)
            _assert_derivatives_match(derivatives(x0 + trans_vector, box),
                                      reference, None, rtol)