def testRecomputeGradXla(self):
  device_type = self._get_device_type()
  device_name = f"{device_type}:0"
  # Necessary for TFRT tests.
  if device_type == "TPU":
    tpu_strategy_util.initialize_tpu_system()

  n = 500
  with ops.device(device_name):
    # XLA:TPU converts f32 matmuls to bf16, and XLA:CPU converts bf16/f16
    # matmuls to f32 after cl/461262189. Use a type that doesn't get
    # converted.
    if device_type == "TPU":
      dtype = dtypes.bfloat16
      elem_size = 2
    else:
      dtype = dtypes.float32
      elem_size = 4
    a = array_ops.zeros((n, n), dtype=dtype)  # elem_size * n * n bytes

  def f(x):
    for _ in range(5):
      # matmul can not be fused by XLA.
      x = math_ops.matmul(x, x)
    return x

  def g(f, x):
    for _ in range(5):
      x = f(x)
    return x[0][0]

  def get_peak_memory(test_func):
    test_func = def_function.function(
        self._grad(test_func), jit_compile=True)
    # The hlo_proto contains statically allocated memory info of HLO values.
    hlo_proto_serialized = test_func.experimental_get_compiler_ir(a)(
        stage="optimized_hlo_proto_serialized", device_name=device_name)
    hlo_proto = hlo_pb2.HloProto.FromString(hlo_proto_serialized)
    allocations = hlo_proto.buffer_assignment.buffer_allocations
    return sum(allocation.size for allocation in allocations)

  f_no_recompute = functools.partial(g, f)
  f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))

  peak_memory_no_recompute = get_peak_memory(f_no_recompute)
  peak_memory_recompute = get_peak_memory(f_recompute)

  # elem_size * n * n (size of `a`) * 5 (loop of g) * 5 (loop of f)
  self.assertGreaterEqual(peak_memory_no_recompute,
                          elem_size * n * n * 5 * 5)
  # elem_size * n * n (size of `a`) * (5 (loop of g) + 5 (recompute in f))
  self.assertGreaterEqual(peak_memory_recompute, elem_size * n * n * 5 * 2)
  # peak_memory_recompute should be less than peak_memory_no_recompute.
  self.assertLess(peak_memory_recompute, elem_size * n * n * 5 * 3)

  with ops.device(device_name):
    res_recompute = f_recompute(a)
    res_no_recompute = f_no_recompute(a)
  self.assertAllClose(res_recompute, res_no_recompute)
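# Illustrative sketch, not part of the test suite: the static-memory
# accounting pattern used by get_peak_memory above, pulled out on its own.
# A jit-compiled tf.function exposes its optimized HLO module as a serialized
# proto; summing the buffer-allocation sizes gives XLA's statically planned
# memory for the computation. This assumes the hlo_pb2 and def_function
# imports already present in this file and an XLA-capable device; the helper
# name is hypothetical.
def _example_static_peak_memory_bytes(fn, *args):
  """Sketch: sum XLA's statically allocated buffer sizes for `fn(*args)`."""
  jitted = def_function.function(fn, jit_compile=True)
  serialized = jitted.experimental_get_compiler_ir(*args)(
      stage="optimized_hlo_proto_serialized")
  hlo_proto = hlo_pb2.HloProto.FromString(serialized)
  return sum(
      alloc.size for alloc in hlo_proto.buffer_assignment.buffer_allocations)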
def testRecomputeGradNonXla(self, mode):
  device_type = self._get_device_type()
  device_name = f"{device_type}:0"
  if device_type == "TPU":
    self.skipTest("XLA is required for TPU.")
  if device_type == "CPU":
    self.skipTest("b/185371422: get_memory_info doesn't support CPU yet.")

  config.reset_memory_stats(device_name)
  base_memory = config.get_memory_info(device_name)["current"]

  n = 500
  with ops.device(device_name):
    a = array_ops.ones((n, n), dtype=dtypes.float16)

  def f(x):
    for _ in range(5):
      x = math_ops.matmul(x, x)
    return x

  def g(f, x):
    for _ in range(5):
      x = f(x)
    return x[0][0]

  def run(test_func):
    with ops.device(device_name):
      if mode == "eager":
        return self._grad(test_func)(a)
      else:
        return def_function.function(self._grad(test_func))(a)

  f_no_recompute = functools.partial(g, f)
  f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))

  # The result is not saved so the base memory will stay the same.
  run(f_no_recompute)
  peak_memory_no_recompute = (
      config.get_memory_info(device_name)["peak"] - base_memory)

  config.reset_memory_stats(device_name)
  run(f_recompute)
  peak_memory_recompute = (
      config.get_memory_info(device_name)["peak"] - base_memory)

  # 2 * n * n (size of `a`) * 5 (loop of f) * 5 (loop of g)
  self.assertGreaterEqual(peak_memory_no_recompute, 2 * n * n * 5 * 5)
  # 2 * n * n (size of `a`) * (5 (loop of g) + 5 (recompute in f))
  self.assertGreaterEqual(peak_memory_recompute, 2 * n * n * 5 * 2)
  # peak_memory_recompute should be less than peak_memory_no_recompute.
  self.assertLess(peak_memory_recompute, 2 * n * n * 5 * 3)

  res_no_recompute = run(f_no_recompute)
  res_recompute = run(f_recompute)
  self.assertAllClose(res_no_recompute, res_recompute)
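# Illustrative sketch, not part of the test suite: the runtime peak-memory
# measurement pattern used above, written against the public TF API
# (tf.config.experimental exposes the same reset_memory_stats/get_memory_info
# calls this file reaches through `config`). The helper name and the default
# device string are assumptions for illustration only; CPU devices are not
# supported by get_memory_info, matching the skip above.
def _example_runtime_peak_memory_bytes(fn, device_name="GPU:0"):
  """Sketch: peak device memory (in bytes) used while running `fn()`."""
  import tensorflow as tf  # Local import keeps the sketch self-contained.
  tf.config.experimental.reset_memory_stats(device_name)
  base = tf.config.experimental.get_memory_info(device_name)["current"]
  with tf.device(device_name):
    fn()
  return tf.config.experimental.get_memory_info(device_name)["peak"] - base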
def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
  """Returns gradients of `test_fn` with respect to `vars_to_grad`."""
  test_fn_re = custom_gradient.recompute_grad(test_fn)

  with backprop.GradientTape(persistent=True) as tape:
    tape.watch(vars_to_grad)
    out_re = test_fn_re(inputs, vars_to_grad)
    out = test_fn(inputs, vars_to_grad)

  grads_re = tape.gradient(out_re, vars_to_grad)
  grads = tape.gradient(out, vars_to_grad)

  return grads_re, grads
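# Illustrative sketch, not part of the test suite: the basic recompute_grad
# usage pattern the tests above exercise. Wrapping a function with
# custom_gradient.recompute_grad makes backprop re-run the function instead of
# keeping its intermediate activations, trading compute for memory, and the
# resulting gradients should match those of the unwrapped function (this is
# what _TestFnVariablesGradient checks). The helper name is hypothetical and
# the module aliases are assumed from this file's imports.
def _example_recompute_grad_matches_plain_gradients():
  """Sketch: gradients through a recomputed block match the plain block."""
  def block(x):
    for _ in range(3):
      x = math_ops.matmul(x, x)
    return x

  block_re = custom_gradient.recompute_grad(block)
  x = array_ops.ones((4, 4), dtype=dtypes.float32)
  with backprop.GradientTape(persistent=True) as tape:
    tape.watch(x)
    y = math_ops.reduce_sum(block(x))
    y_re = math_ops.reduce_sum(block_re(x))
  return tape.gradient(y, x), tape.gradient(y_re, x)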