def test_checkpoint_correctness():
    """Wrapping a function in ``checkpoint`` must not change its value or gradient.

    Two cases are exercised: a fixed-arity function called twice inside a
    larger expression, and a variadic (``*args``) function.
    """

    def bar(x, y):
        return 2 * x + y + 5

    def baz(*args):
        return sum(args)

    bar_ckpt = checkpoint(bar)
    baz_ckpt = checkpoint(baz)

    # Same wrapped function used twice with different second arguments.
    def foo(x):
        return bar(x, x / 3.) + bar(x, x**2)

    def foo_ckpt(x):
        return bar_ckpt(x, x / 3.) + bar_ckpt(x, x**2)

    # Variadic function: checkpoint must forward *args correctly.
    def foobaz(x):
        return baz(x, x / 3.)

    def foobaz_ckpt(x):
        return baz_ckpt(x, x / 3.)

    for plain, ckpt in ((foo, foo_ckpt), (foobaz, foobaz_ckpt)):
        assert np.allclose(plain(3.), ckpt(3.))
        assert np.allclose(grad(plain)(3.), grad(ckpt)(3.))
def test_checkpoint_correctness():
    """Checkpointed and plain versions of a function agree in value and gradient.

    NOTE(review): a function with this exact name is also defined earlier in
    this file; only the last definition survives at import time, so pytest
    collects just this one. Consider removing or renaming the duplicate.
    """

    def _assert_same_value_and_grad(plain, checkpointed, x=3.):
        # Compare the forward value first, then the first derivative at x.
        assert np.allclose(plain(x), checkpointed(x))
        assert np.allclose(grad(plain)(x), grad(checkpointed)(x))

    def bar(x, y):
        return 2*x + y + 5

    cbar = checkpoint(bar)
    _assert_same_value_and_grad(
        lambda x: bar(x, x/3.) + bar(x, x**2),
        lambda x: cbar(x, x/3.) + cbar(x, x**2),
    )

    # Variadic (*args) functions must be checkpointable too.
    def baz(*args):
        return sum(args)

    cbaz = checkpoint(baz)
    _assert_same_value_and_grad(
        lambda x: baz(x, x/3.),
        lambda x: cbaz(x, x/3.),
    )
def wrapped_fun_helper(xdict, dummy):
    """Evaluate ``fun`` at ``xdict``, return its value and stash its gradient.

    The gradient is stored on ``dummy.cache`` as a side effect, presumably so
    the caller can read it back without another evaluation (verify against
    the caller). ``fun`` is a free name, resolved from the enclosing scope.
    """
    # value_and_grad yields both results from a single forward pass;
    # wrapping it in checkpoint keeps the hessian properly checkpointed.
    value_and_gradient = ag.checkpoint(ag.value_and_grad(fun))
    value, gradient = value_and_gradient(xdict)
    # fun is expected to be scalar-valued (0-d output).
    assert len(value.shape) == 0
    dummy.cache = gradient
    return value
def checkpoint_memory():
    """Check that ``checkpoint`` at least halves peak memory for a gradient.

    Meant to be run by hand: it needs the optional ``memory_profiler``
    package (returns immediately if it cannot be imported), and measured
    memory can vary between machines.
    """
    try:
        from memory_profiler import memory_usage
    except ImportError:
        return

    def chain_of_sines(arr):
        # Ten successive elementwise ops, each producing a new array.
        for _ in range(10):
            arr = np.sin(arr**2 + 1)
        return arr

    chain_ckpt = checkpoint(chain_of_sines)

    def repeat_and_sum(step, arr):
        for _ in range(5):
            arr = step(arr)
        return np.sum(arr)

    # Differentiate with respect to the array argument (position 1).
    grad_wrt_input = grad(repeat_and_sum, 1)
    data = npr.RandomState(0).randn(100000)

    plain_peak = max(memory_usage((grad_wrt_input, (chain_of_sines, data))))
    ckpt_peak = max(memory_usage((grad_wrt_input, (chain_ckpt, data))))
    assert ckpt_peak < plain_peak / 2.
def _raw_log_lik(cache, G, data, truncate_probs, folded, error_matrices,
                 vector=False):
    """Evaluate ``_composite_log_likelihood`` on a ``Demography`` built from ``G``.

    The evaluation is wrapped for autodiff. When ``vector`` is true it is run
    under ``ag.checkpoint``; otherwise it goes through ``rearrange_dict_grad``.

    Args:
        cache: passed to ``Demography(G, cache=...)``.
        G, data, truncate_probs, folded, error_matrices, vector: forwarded
            to ``Demography`` / ``_composite_log_likelihood`` unchanged.
    """
    def log_lik_of(demo_cache):
        demography = Demography(G, cache=demo_cache)
        return _composite_log_likelihood(
            data, demography,
            truncate_probs=truncate_probs,
            folded=folded,
            error_matrices=error_matrices,
            vector=vector,
        )

    if not vector:
        # rearrange_dict_grad avoids a second forward pass and checkpoints
        # properly for the hessian, but doesn't work for vectorized output.
        return rearrange_dict_grad(log_lik_of)(cache)
    return ag.checkpoint(log_lik_of)(cache)