Exemple #1
0
  def testRecomputeGradXla(self):
    """Checks that `recompute_grad` lowers XLA's statically assigned memory.

    The gradient of a chain of matmuls is compiled with `jit_compile=True`,
    once as-is and once with the inner function wrapped in
    `custom_gradient.recompute_grad`.  The buffer allocation sizes recorded in
    each optimized HLO proto are summed and compared against multiples of the
    input size.  Finally both variants are evaluated and must agree.
    """
    device_type = self._get_device_type()
    device_name = f"{device_type}:0"
    # Necessary for TFRT tests.
    if device_type == "TPU":
      tpu_strategy_util.initialize_tpu_system()

    # XLA:TPU converts f32 matmuls to bf16, and XLA:CPU converts bf16/f16
    # matmuls to f32 after cl/461262189.  Pick a type that is left untouched.
    on_tpu = device_type == "TPU"
    dtype = dtypes.bfloat16 if on_tpu else dtypes.float32
    elem_size = 2 if on_tpu else 4

    n = 500
    with ops.device(device_name):
      a = array_ops.zeros((n, n), dtype=dtype)
    # Size in bytes of `a`; every bound below is a multiple of this.
    input_bytes = elem_size * n * n

    def f(x):
      for _ in range(5):
        # matmul can not be fused by XLA.
        x = math_ops.matmul(x, x)
      return x

    def g(f, x):
      for _ in range(5):
        x = f(x)
      return x[0][0]

    def get_peak_memory(test_func):
      compiled_grad = def_function.function(
          self._grad(test_func), jit_compile=True)
      # The serialized HLO proto carries the statically assigned buffer sizes.
      serialized = compiled_grad.experimental_get_compiler_ir(a)(
          stage="optimized_hlo_proto_serialized", device_name=device_name)
      buffer_allocations = hlo_pb2.HloProto.FromString(
          serialized).buffer_assignment.buffer_allocations
      return sum(allocation.size for allocation in buffer_allocations)

    f_no_recompute = functools.partial(g, f)
    f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))

    peak_memory_no_recompute = get_peak_memory(f_no_recompute)
    peak_memory_recompute = get_peak_memory(f_recompute)

    # Without recompute: one buffer per matmul, 5 (loop of g) * 5 (loop of f).
    self.assertGreaterEqual(peak_memory_no_recompute, input_bytes * 5 * 5)
    # With recompute: 5 (loop of g) + 5 (recompute in f) buffers at least.
    self.assertGreaterEqual(peak_memory_recompute, input_bytes * 5 * 2)
    # With recompute the total must stay under 15 buffers, which is below the
    # 25-buffer lower bound asserted for the no-recompute variant.
    self.assertLess(peak_memory_recompute, input_bytes * 5 * 3)

    with ops.device(device_name):
      res_recompute = f_recompute(a)
      res_no_recompute = f_no_recompute(a)

    self.assertAllClose(res_recompute, res_no_recompute)
    def testRecomputeGradNonXla(self, mode):
        """Checks that `recompute_grad` lowers peak device memory without XLA.

        The gradient of a chain of float16 matmuls is run with and without
        `custom_gradient.recompute_grad`, and the peak reported by
        `config.get_memory_info` (relative to the usage recorded before the
        input is created) is compared against multiples of the input size.
        Both variants must also produce the same result.

        Args:
          mode: "eager" runs the gradient function directly; any other value
            wraps it in `def_function.function` first.
        """
        device_type = self._get_device_type()
        if device_type == "TPU":
            self.skipTest("XLA is required for TPU.")
        if device_type == "CPU":
            self.skipTest(
                "b/185371422: get_memory_info does't support CPU yet.")
        device_name = f"{device_type}:0"

        config.reset_memory_stats(device_name)
        base_memory = config.get_memory_info(device_name)["current"]

        n = 500
        with ops.device(device_name):
            a = array_ops.ones((n, n), dtype=dtypes.float16)
        # Size in bytes of `a` (float16 -> 2 bytes per element).
        input_bytes = 2 * n * n

        def f(x):
            for _ in range(5):
                x = math_ops.matmul(x, x)
            return x

        def g(f, x):
            for _ in range(5):
                x = f(x)
            return x[0][0]

        def run(test_func):
            with ops.device(device_name):
                grad_fn = self._grad(test_func)
                if mode != "eager":
                    grad_fn = def_function.function(grad_fn)
                return grad_fn(a)

        def peak_above_base():
            # Peak usage since the last stats reset, net of the baseline.
            return config.get_memory_info(device_name)["peak"] - base_memory

        f_no_recompute = functools.partial(g, f)
        f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))

        # The result is not saved so the base memory will stay the same.
        run(f_no_recompute)
        peak_memory_no_recompute = peak_above_base()

        config.reset_memory_stats(device_name)
        run(f_recompute)
        peak_memory_recompute = peak_above_base()

        # Without recompute: 5 (loop of f) * 5 (loop of g) buffers at least.
        self.assertGreaterEqual(peak_memory_no_recompute, input_bytes * 5 * 5)
        # With recompute: 5 (loop of g) + 5 (recompute in f) buffers at least.
        self.assertGreaterEqual(peak_memory_recompute, input_bytes * 5 * 2)
        # With recompute the peak must stay under 15 buffers, which is below
        # the 25-buffer lower bound asserted for the no-recompute variant.
        self.assertLess(peak_memory_recompute, input_bytes * 5 * 3)

        res_no_recompute = run(f_no_recompute)
        res_recompute = run(f_recompute)
        self.assertAllClose(res_no_recompute, res_recompute)
Exemple #3
0
    def _TestFnVariablesGradient(self, inputs, test_fn, vars_to_grad):
        """Computes gradients of `test_fn` with and without recomputation.

        Args:
          inputs: first argument forwarded to `test_fn`.
          test_fn: callable invoked as `test_fn(inputs, vars_to_grad)`.
          vars_to_grad: values watched by the tape and differentiated against.

        Returns:
          A `(grads_re, grads)` tuple: the gradients obtained through the
          `recompute_grad`-wrapped function, followed by the gradients
          obtained through `test_fn` itself.
        """
        recomputing_fn = custom_gradient.recompute_grad(test_fn)

        # The tape is persistent because it is queried twice below.
        with backprop.GradientTape(persistent=True) as tape:
            tape.watch(vars_to_grad)
            out_re = recomputing_fn(inputs, vars_to_grad)
            out = test_fn(inputs, vars_to_grad)

        return (tape.gradient(out_re, vars_to_grad),
                tape.gradient(out, vars_to_grad))