Example #1
  def testResetMemoryStatsCPU(self):
    if test_util.IsMklEnabled():
      # TODO(gzmkl) work with Google team to address design issue in allocator.h
      self.skipTest('MklCPUAllocator does not throw exception. So skip test.')

    with self.assertRaisesRegex(ValueError, 'Cannot reset memory stats'):
      config.reset_memory_stats('CPU:0')
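Outside the test, the same behavior can be guarded against in user code. A minimal sketch, assuming TensorFlow 2.5+ where these calls live under tf.config.experimental; the helper name try_reset_memory_stats is made up for illustration:

import tensorflow as tf

def try_reset_memory_stats(device='GPU:0'):
  # Illustrative helper, not a TensorFlow API. reset_memory_stats raises
  # ValueError for devices whose allocator does not track stats (e.g. plain
  # CPU), matching the assertion in the test above.
  try:
    tf.config.experimental.reset_memory_stats(device)
    return tf.config.experimental.get_memory_info(device)
  except ValueError as err:
    print(f'Could not reset memory stats for {device}: {err}')
    return None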
Example #2
    def testRecomputeGradNonXla(self, mode):
        device_type = self._get_device_type()
        device_name = f"{device_type}:0"

        if device_type == "TPU":
            self.skipTest("XLA is required for TPU.")

        if device_type == "CPU":
            self.skipTest(
                "b/185371422: get_memory_info does't support CPU yet.")

        config.reset_memory_stats(device_name)
        base_memory = config.get_memory_info(device_name)["current"]
        n = 500
        with ops.device(device_name):
            a = array_ops.ones((n, n), dtype=dtypes.float16)

        def f(x):
            for _ in range(5):
                x = math_ops.matmul(x, x)
            return x

        def g(f, x):
            for _ in range(5):
                x = f(x)
            return x[0][0]

        def run(test_func):
            with ops.device(device_name):
                if mode == "eager":
                    return self._grad(test_func)(a)
                else:
                    return def_function.function(self._grad(test_func))(a)

        f_no_recompute = functools.partial(g, f)
        f_recompute = functools.partial(g, custom_gradient.recompute_grad(f))

        # The result is not saved so the base memory will stay the same.
        run(f_no_recompute)
        peak_memory_no_recompute = (
            config.get_memory_info(device_name)["peak"] - base_memory)

        config.reset_memory_stats(device_name)
        run(f_recompute)
        peak_memory_recompute = (config.get_memory_info(device_name)["peak"] -
                                 base_memory)

        # 2 * n * n (size of `a`) * 5 (loop of f) * 5 (loop of g)
        self.assertGreaterEqual(peak_memory_no_recompute, 2 * n * n * 5 * 5)
        # 2 * n * n (size of `a`) * (5 (loop of g) + 5 (recompute in f))
        self.assertGreaterEqual(peak_memory_recompute, 2 * n * n * 5 * 2)
        # Together with the lower bound above, this upper bound implies that
        # peak_memory_recompute is less than peak_memory_no_recompute.
        self.assertLess(peak_memory_recompute, 2 * n * n * 5 * 3)

        res_no_recompute = run(f_no_recompute)
        res_recompute = run(f_recompute)
        self.assertAllClose(res_no_recompute, res_recompute)
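The thresholds in the assertions are plain arithmetic over float16 storage (2 bytes per element) of an n x n tensor; a quick back-of-the-envelope check of the numbers used above:

# Illustrative recalculation of the asserted bounds (not part of the test).
n = 500
tensor_bytes = 2 * n * n                    # one (n, n) float16 tensor: 2 bytes per element
no_recompute_bound = tensor_bytes * 5 * 5   # 5 matmul intermediates kept per f, for 5 calls in g
recompute_bound = tensor_bytes * (5 + 5)    # g's 5 kept outputs plus f's 5 recomputed intermediates
assert no_recompute_bound == 12_500_000 and recompute_bound == 5_000_000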
Example #3
    def testResetMemoryStats(self):
        x = array_ops.zeros((1000, 1000), dtype=dtypes.float32)
        config.reset_memory_stats('GPU:0')
        info1 = config.get_memory_info('GPU:0')
        self.assertGreaterEqual(info1['peak'], 4 * 1000 * 1000)
        self.assertGreaterEqual(info1['peak'], info1['current'])
        self.assertGreater(info1['current'], 0)

        del x  # With CPython, causes tensor memory to be immediately freed
        config.reset_memory_stats('GPU:0')
        info2 = config.get_memory_info('GPU:0')
        self.assertLess(info2['peak'], info1['peak'])
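The drop in info2['peak'] relies on documented behavior: reset_memory_stats re-seeds the tracked peak from the device's current usage, so the peak only falls if memory was actually freed first. A small sketch, assuming at least one visible GPU:

import tensorflow as tf

tf.config.experimental.reset_memory_stats('GPU:0')
info = tf.config.experimental.get_memory_info('GPU:0')
# Right after a reset, the peak is re-seeded from the current usage.
assert info['peak'] == info['current']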
Example #4
  def testResetMemoryStatsUnknownDevice(self):
    with self.assertRaisesRegex(ValueError, 'No matching devices found'):
      config.reset_memory_stats('unknown_device:0')
Example #5
  def testResetMemoryStatsCPU(self):
    with self.assertRaisesRegex(ValueError, 'Cannot reset memory stats'):
      config.reset_memory_stats('CPU:0')
Example #6
  def testResetMemoryStatsAmbiguousDevice(self):
    if len(config.list_physical_devices('GPU')) < 2:
      self.skipTest('Need at least 2 GPUs')
    with self.assertRaisesRegex(ValueError, 'Multiple devices'):
      config.reset_memory_stats('GPU')
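When more than one GPU is visible, a bare 'GPU' is ambiguous, as asserted above; each device has to be addressed by its fully qualified index. A sketch of resetting stats on every GPU, assuming TensorFlow 2.5+:

import tensorflow as tf

for dev in tf.config.list_logical_devices('GPU'):
  # dev.name looks like '/device:GPU:0'; strip the prefix to get the short form.
  short_name = dev.name.replace('/device:', '')
  tf.config.experimental.reset_memory_stats(short_name)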
Example #7
  def testResetMemoryStatsUnknownDevice(self):
    with self.assertRaisesRegex(ValueError, 'Failed parsing device name'):
      config.reset_memory_stats('unknown_device')