Пример #1
0
    def test_depthwise_conv2d(self, mode):
        """Test grad ops with depthwise convolution2d graph.

        Builds a depthwise conv2d from `_input([2, 8, 8, 1])` and
        `_weight([3, 3, 1, 4])`, asks a gradient-descent optimizer for the
        gradients w.r.t. both, then runs the graph once via `self._run`.
        It checks (through `_assert_output_f16`) the forward 'depthwise' node
        and both DepthwiseConv2dNative backprop nodes, and compares the
        reference outputs against the rewritten-graph outputs within 2e-3.

        Args:
          mode: execution mode forwarded to `_maybe_skip`, `_run` and
            `_assert_output_f16`.
        """
        self._maybe_skip(mode)
        cudnn_version_str = sysconfig.get_build_info().get(
            'cudnn_version', '0.0')
        cudnn_version = tuple(int(x) for x in cudnn_version_str.split('.'))
        if cudnn_version < (8, ):
            # Depthwise conv2d ops are only enabled in auto_mixed_precision as of
            # cuDNN v8.
            self.skipTest('cuDNN version >= 8 required')
        random_seed.set_random_seed(0)
        x = _input([2, 8, 8, 1])
        f = _weight([3, 3, 1, 4])
        y = _depthwise_conv2d(x, f)
        y = array_ops.identity(y)
        optimizer = gradient_descent.GradientDescentOptimizer(
            learning_rate=0.01)
        g = optimizer.compute_gradients(y, [x, f])
        output = (y, g)

        # A single run yields both the reference values and the rewritten-graph
        # values plus the cost graph; the graph used to be run a second time
        # only to overwrite these identical results.
        output_val_ref, output_val, cost_graph = self._run(mode, output)
        node_map = _build_node_map(cost_graph.node)
        self._assert_output_f16(mode, node_map, 'depthwise')
        self._assert_output_f16(
            mode, node_map,
            'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropInput')
        self._assert_output_f16(
            mode, node_map,
            'gradients/depthwise_grad/DepthwiseConv2dNativeBackpropFilter')

        tol = 2e-3
        self.assertAllClose(output_val_ref, output_val, atol=tol, rtol=tol)
Пример #2
0
    def maybe_skip_test(self, mode):
        """Skip the running test when the backend selected by `mode` is unusable.

        'mkl' requires an MKL-enabled build. 'cuda' requires a non-Windows
        host, an available CUDA GPU, and a build compiled against CUDA 11.4 or
        newer. Any other mode is never skipped.
        """
        if mode == 'mkl':
            if not test_util.IsMklEnabled():
                self.skipTest('MKL is not enabled.')
            return

        if mode != 'cuda':
            return

        # It seems the windows os cannot correctly query the cuda_version.
        # TODO(kaixih@nvidia): Remove this when it works.
        if os.name == 'nt':
            self.skipTest("This test doesn't support Windows")

        if not test.is_gpu_available(cuda_only=True):
            self.skipTest('This test requires GPU.')

        # The cublaslt matmul with gelu epilog is only supported since cuda 11.4.
        build_cuda = sysconfig.get_build_info().get('cuda_version', '0.0')
        if tuple(map(int, build_cuda.split('.'))) < (11, 4):
            self.skipTest('This test requires CUDA >= 11.4.')
Пример #3
0
 def test_rocm_cuda_info_matches(self):
     """The build-info ROCm/CUDA flags must agree with the `test` module queries."""
     info = sysconfig.get_build_info()
     flag_checks = (
         ("is_rocm_build", test.is_built_with_rocm),
         ("is_cuda_build", test.is_built_with_cuda),
     )
     for info_key, built_with in flag_checks:
         self.assertEqual(info[info_key], built_with())
Пример #4
0
 def test_get_build_info_works(self):
     """`sysconfig.get_build_info()` should hand back a plain dict."""
     info = sysconfig.get_build_info()
     self.assertIsInstance(info, dict)