def testLuGrad(self, shape, dtype, rng):
  _skip_if_unsupported_type(dtype)
  a = rng(shape, dtype)
  jtu.check_grads(jsp.linalg.lu, (a,), 2, atol=5e-2, rtol=1e-1)
def testTopKGrad(self, shape, dtype, k):
  flat_values = np.arange(prod(shape), dtype=dtype)
  values = self.rng().permutation(flat_values).reshape(shape)
  fun = lambda vs: lax.top_k(vs, k=k)[0]
  check_grads(fun, (values,), 2, ["fwd", "rev"], eps=1e-2)
def testBroadcastInDimGrad(self, inshape, dtype, outshape, dimensions):
  rng = jtu.rand_default(self.rng())
  operand = rng(inshape, dtype)
  broadcast_in_dim = lambda x: lax.broadcast_in_dim(x, outshape, dimensions)
  check_grads(broadcast_in_dim, (operand,), 2, ["fwd", "rev"], eps=1.)
def testSliceGrad(self, shape, dtype, starts, limits, strides):
  rng = jtu.rand_default(self.rng())
  operand = rng(shape, dtype)
  slice = lambda x: lax.slice(x, starts, limits, strides)
  check_grads(slice, (operand,), 2, ["fwd", "rev"], eps=1.)
def testReluGrad(self):
  rtol = 1e-2 if jtu.device_under_test() == "tpu" else None
  check_grads(nn.relu, (1.,), order=3, rtol=rtol)
  check_grads(nn.relu, (-1.,), order=3, rtol=rtol)
  jaxpr = jax.make_jaxpr(jax.grad(nn.relu))(0.)
  self.assertGreaterEqual(len(jaxpr.jaxpr.eqns), 2)
def testOpGradSpecialValue(self, op, special_value, tol):
  check_grads(op, (special_value,), 2, ["fwd", "rev"], rtol=tol, atol=tol)
def testSoftplusGrad(self):
  check_grads(nn.softplus, (1e-8,), 4)
def testSoftplusGradZero(self):
  check_grads(nn.softplus, (0.,), order=1,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
def testTransposeGrad(self, shape, dtype, perm, rng_factory):
  rng = rng_factory(self.rng())
  operand = rng(shape, dtype)
  transpose = lambda x: lax.transpose(x, perm)
  check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)
def testIndexTakeGrad(self, shape, dtype, idxs, axes, rng_factory):
  rng = rng_factory(self.rng())
  src = rng(shape, dtype)
  index_take = lambda src: lax.index_take(src, idxs, axes)
  check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)
def testBroadcastGrad(self, shape, dtype, broadcast_sizes, rng_factory):
  rng = rng_factory(self.rng())
  args = (rng(shape, dtype),)
  broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
  check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)
def testCumulativeReduceGrad(self, op, shape, dtype, axis, rng_factory):
  rng = rng_factory(self.rng())
  check_grads(partial(op, axis=axis), (rng(shape, dtype),), order=2)
def testSoftplusGrad(self):
  check_grads(nn.softplus, (1e-8,), 4,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
def testIndexTakeGrad(self, shape, dtype, idxs, axes):
  rng = jtu.rand_default(self.rng())
  src = rng(shape, dtype)
  index_take = lambda src: lax.index_take(src, idxs, axes)
  check_grads(index_take, (src,), 2, ["fwd", "rev"], eps=1.)
def testSlogdetGrad(self, shape, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  a = rng(shape, dtype)
  jtu.check_grads(np.linalg.slogdet, (a,), 2, atol=1e-1, rtol=1e-1)
def testLuGrad(self, shape, dtype, rng):
  a = rng(shape, dtype)
  jtu.check_grads(jsp.linalg.lu, (a,), 2, rtol=1e-1)
def testLuGrad(self, shape, dtype, rng_factory):
  rng = rng_factory()
  _skip_if_unsupported_type(dtype)
  a = rng(shape, dtype)
  lu = vmap(jsp.linalg.lu) if len(shape) > 2 else jsp.linalg.lu
  jtu.check_grads(lu, (a,), 2, atol=5e-2, rtol=1e-1)
def testSoftplusGradNan(self):
  check_grads(nn.softplus, (float('nan'),), order=1,
              rtol=1e-2 if jtu.device_under_test() == "tpu" else None)
def testScanRnn(self):
  r = npr.RandomState(0)

  n_in = 4
  n_hid = 2
  n_out = 1
  length = 3

  W_trans = r.randn(n_hid, n_hid + n_in)
  W_out = r.randn(n_out, n_hid + n_in)
  params = W_trans, W_out

  inputs = r.randn(length, n_in)
  targets = r.randn(length, n_out)

  def step(params, state, input):
    W_trans, W_out = params
    stacked = np.concatenate([state, input])
    output = np.tanh(np.dot(W_out, stacked))
    next_state = np.tanh(np.dot(W_trans, stacked))
    return next_state, output

  def rnn(params, inputs):
    init_state = np.zeros(n_hid)
    _, outputs = lax.scan(partial(step, params), init_state, inputs)
    return outputs

  def loss(params, inputs, targets):
    predictions = rnn(params, inputs)
    return np.sum((predictions - targets)**2)

  # evaluation doesn't crash
  loss(params, inputs, targets)

  # jvp evaluation doesn't crash
  api.jvp(lambda params: loss(params, inputs, targets), (params,), (params,))

  # jvp numerical check passes
  jtu.check_grads(loss, (params, inputs, targets), order=2, modes=["fwd"])

  # linearize works
  _, expected = api.jvp(loss, (params, inputs, targets),
                        (params, inputs, targets))
  _, linfun = api.linearize(loss, params, inputs, targets)
  ans = linfun(params, inputs, targets)
  self.assertAllClose(ans, expected, check_dtypes=False)

  # gradient evaluation doesn't crash
  api.grad(loss)(params, inputs, targets)

  # gradient check passes
  jtu.check_grads(loss, (params, inputs, targets), order=2)

  # we can vmap to batch things
  batch_size = 7
  batched_inputs = r.randn(batch_size, length, n_in)
  batched_targets = r.randn(batch_size, length, n_out)
  batched_loss = api.vmap(lambda x, y: loss(params, x, y))
  losses = batched_loss(batched_inputs, batched_targets)
  expected = onp.stack(list(map(lambda x, y: loss(params, x, y),
                                batched_inputs, batched_targets)))
  self.assertAllClose(losses, expected, check_dtypes=False)
def testEluGrad(self):
  check_grads(nn.elu, (1e4,), order=4, eps=1.)
def testTransposeGrad(self, shape, dtype, perm):
  rng = jtu.rand_default(self.rng())
  operand = rng(shape, dtype)
  transpose = lambda x: lax.transpose(x, perm)
  check_grads(transpose, (operand,), 2, ["fwd", "rev"], eps=1.)
def testBroadcastGrad(self, shape, dtype, broadcast_sizes):
  rng = jtu.rand_default(self.rng())
  args = (rng(shape, dtype),)
  broadcast = lambda x: lax.broadcast(x, broadcast_sizes)
  check_grads(broadcast, args, 2, ["fwd", "rev"], eps=1.)
def testCumulativeReduceGrad(self, op, shape, dtype, axis, reverse):
  rng_factory = (jtu.rand_default if dtypes.issubdtype(dtype, np.integer)
                 else jtu.rand_small)
  rng = rng_factory(self.rng())
  check_grads(partial(op, axis=axis, reverse=reverse), (rng(shape, dtype),),
              order=2)
def testReshapeGrad(self, arg_shape, out_shape, permutation, dtype):
  rng = jtu.rand_default(self.rng())
  operand = rng(arg_shape, dtype)
  reshape = lambda x: lax.reshape(x, out_shape, permutation)
  check_grads(reshape, (operand,), 2, ["fwd", "rev"], eps=1.)
def testSortGrad(self, shape, dtype, axis, is_stable):
  rng = jtu.rand_default(self.rng())
  operand = rng(shape, dtype)
  sort = lambda x: lax.sort(x, dimension=axis, is_stable=is_stable)
  check_grads(sort, (operand,), 2, ["fwd", "rev"], eps=1e-2)
def testDynamicSliceGrad(self, shape, dtype, start_indices, size_indices):
  rng = jtu.rand_default(self.rng())
  operand = rng(shape, dtype)
  dynamic_slice = lambda x: lax.dynamic_slice(x, start_indices, size_indices)
  check_grads(dynamic_slice, (operand,), 2, ["fwd", "rev"], eps=1.)
def assert_potential_invariance(energy_fn, x0, params, box=None):
  # note: all potentials must support non-periodic implementations,
  # defined by box being None

  # explicitly check without box even if box is not None
  check_grads(energy_fn, (x0, params, None), order=1, eps=1e-5)
  check_grads(energy_fn, (x0, params, None), order=2, eps=1e-7)

  # check with box if present
  if box is not None:
    check_grads(energy_fn, (x0, params, box), order=1, eps=1e-5)
    check_grads(energy_fn, (x0, params, box), order=2, eps=1e-7)

  # test translational and rotational invariance of
  # energy and its derivatives, with and without box
  energy = energy_fn(x0, params, None)
  force_fn = jax.grad(energy_fn, argnums=(0,))
  forces = force_fn(x0, params, None)[0]
  dEdp = jax.jacfwd(energy_fn, argnums=(1,))
  d2Edxdp = jax.jacfwd(dEdp, argnums=(0,))
  dparam = dEdp(x0, params, None)[0]
  mixed = d2Edxdp(x0, params, None)[0][0]

  for _ in range(3):
    trans_vector = np.random.rand(3).astype(dtype=np.float64)
    trans_x = x0 + trans_vector
    trans_energy = energy_fn(trans_x, params, None)
    trans_forces = force_fn(trans_x, params, None)[0]
    trans_dEdp = dEdp(trans_x, params, None)[0]
    trans_mixed = d2Edxdp(trans_x, params, None)[0][0]
    np.testing.assert_allclose(trans_energy, energy, rtol=1e-10)
    np.testing.assert_allclose(trans_forces, forces, rtol=1e-10)
    np.testing.assert_allclose(trans_dEdp, dparam, rtol=1e-10)
    np.testing.assert_allclose(trans_mixed, mixed, rtol=1e-10)

  for _ in range(3):
    rot_matrix = special_ortho_group.rvs(3).astype(dtype=np.float64)
    rot_x = np.matmul(x0, rot_matrix)
    rot_energy = energy_fn(rot_x, params, None)
    rot_forces = force_fn(rot_x, params, None)[0]
    rot_dEdp = dEdp(rot_x, params, None)[0]
    rot_mixed = d2Edxdp(rot_x, params, None)[0][0]
    np.testing.assert_allclose(rot_energy, energy, rtol=1e-10)
    np.testing.assert_allclose(rot_forces, np.matmul(forces, rot_matrix),
                               rtol=1e-10)
    np.testing.assert_allclose(rot_dEdp, dparam, rtol=1e-10)
    for i in range(rot_mixed.shape[0]):
      np.testing.assert_allclose(rot_mixed[i],
                                 np.matmul(mixed[i], rot_matrix),
                                 rtol=1e-10)

  for _ in range(3):
    trans_vector = np.random.rand(3).astype(dtype=np.float64)
    rot_matrix = special_ortho_group.rvs(3).astype(dtype=np.float64)
    comp_x = np.matmul(x0, rot_matrix) + trans_vector
    comp_energy = energy_fn(comp_x, params, None)
    comp_forces = force_fn(comp_x, params, None)[0]
    comp_dEdp = dEdp(comp_x, params, None)[0]
    comp_mixed = d2Edxdp(comp_x, params, None)[0][0]
    np.testing.assert_allclose(comp_energy, energy, rtol=1e-10)
    np.testing.assert_allclose(comp_forces, np.matmul(forces, rot_matrix),
                               rtol=1e-10)
    np.testing.assert_allclose(comp_dEdp, dparam, rtol=1e-10)
    for i in range(comp_mixed.shape[0]):
      np.testing.assert_allclose(comp_mixed[i],
                                 np.matmul(mixed[i], rot_matrix),
                                 rtol=1e-10)

  if box is not None:
    energy = energy_fn(x0, params, box)
    force_fn = jax.grad(energy_fn, argnums=(0,))
    forces = force_fn(x0, params, box)[0]
    dEdp = jax.jacfwd(energy_fn, argnums=(1,))
    d2Edxdp = jax.jacfwd(dEdp, argnums=(0,))
    dparam = dEdp(x0, params, box)[0]
    mixed = d2Edxdp(x0, params, box)[0][0]

    for _ in range(3):
      trans_vector = np.random.rand(3).astype(dtype=np.float64)
      trans_x = x0 + trans_vector
      trans_energy = energy_fn(trans_x, params, box)
      trans_forces = force_fn(trans_x, params, box)[0]
      trans_dEdp = dEdp(trans_x, params, box)[0]
      trans_mixed = d2Edxdp(trans_x, params, box)[0][0]
      np.testing.assert_allclose(trans_energy, energy, rtol=1e-10)
      np.testing.assert_allclose(trans_forces, forces, rtol=1e-10)
      np.testing.assert_allclose(trans_dEdp, dparam, rtol=1e-10)
      np.testing.assert_allclose(trans_mixed, mixed, rtol=1e-10)
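
# A minimal usage sketch for assert_potential_invariance, assuming a toy
# two-particle harmonic energy. The helper _harmonic_pair_energy and the
# inputs _x0/_params below are hypothetical and only illustrate the expected
# energy_fn(coords, params, box) signature; the tight rtol=1e-10 checks above
# also assume float64 (jax 64-bit mode) is enabled.
import jax.numpy as jnp

def _harmonic_pair_energy(conf, params, box):
  # box is accepted but ignored so the non-periodic (box=None) path works
  k, r0 = params
  d = jnp.linalg.norm(conf[0] - conf[1])
  return k * (d - r0) ** 2

_x0 = np.random.rand(2, 3).astype(np.float64)
_params = np.array([10.0, 1.0], dtype=np.float64)
assert_potential_invariance(_harmonic_pair_energy, _x0, _params, box=None)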