コード例 #1
0
ファイル: testing.py プロジェクト: wjgaas/taichi
        def ti_init(arch=None, exclude=None, require=None, **options):
            """Initialise taichi on ``req_arch`` or skip the current test.

            ``arch``, ``exclude`` and ``require`` each take a single value or a
            list/tuple of values; ``None`` is treated as empty. An empty
            ``arch`` selection means every arch in ``ti.supported_archs()``.
            ``req_arch`` is read from the enclosing scope.

            The test is skipped via ``pytest.skip`` when ``req_arch`` is not
            selected, is excluded, or lacks one of the required extensions;
            otherwise ``ti.init(arch=req_arch, **options)`` is called.
            """
            def as_list(value):
                # None -> [], list/tuple -> unchanged, anything else -> [value].
                if value is None:
                    return []
                if isinstance(value, (list, tuple)):
                    return value
                return [value]

            arch = as_list(arch)
            exclude = as_list(exclude)
            require = as_list(require)
            if not arch:
                arch = ti.supported_archs()

            if req_arch not in arch or req_arch in exclude:
                raise pytest.skip(f'Arch={req_arch} not included in this test')

            if any(not ti.core.is_extension_supported(req_arch, e)
                   for e in require):
                raise pytest.skip(
                    f'Arch={req_arch} some extension(s) not satisfied')

            ti.init(arch=req_arch, **options)
コード例 #2
0
ファイル: testing.py プロジェクト: quadpixels/taichi
        def wrapped(*args, **kwargs):
            """Run ``foo`` once per eligible (arch, test-feature) combination.

            Combinations come from crossing ``ti.supported_archs()`` with every
            entry of ``_test_features``. A combination is skipped when its arch
            is not selected / is excluded, lacks a required extension, or
            conflicts with a feature value pinned in ``options``. Before each
            run taichi is initialised with the merged options; afterwards it is
            reset, even if ``foo`` raises.
            """
            arch_params_sets = [ti.supported_archs(), *_test_features.values()]

            for arch_params in itertools.product(*arch_params_sets):
                req_arch, req_params = arch_params[0], arch_params[1:]

                if (req_arch not in arch) or (req_arch in exclude):
                    continue

                if not all(
                        _ti_core.is_extension_supported(req_arch, e)
                        for e in require):
                    continue

                skip = False
                # Deep copy so the per-combination fills below never mutate
                # the decorator-level ``options``.
                current_options = copy.deepcopy(options)
                for feature, param in zip(_test_features, req_params):
                    value = param.value
                    required_extensions = param.required_extensions
                    # Skip if the caller pinned this feature to another value,
                    # or the arch lacks an extension this value needs.
                    if current_options.get(feature, value) != value or any(
                            not _ti_core.is_extension_supported(req_arch, e)
                            for e in required_extensions):
                        skip = True
                        break
                    # Fill in the missing feature
                    current_options[feature] = value
                if skip:
                    continue

                ti.init(arch=req_arch, **current_options)
                try:
                    foo(*args, **kwargs)
                finally:
                    # Reset even when ``foo`` fails so its runtime state does
                    # not carry over into whatever runs next.
                    ti.reset()
コード例 #3
0
def grad_test(tifunc, npfunc=None, default_fp=ti.f32):
    """Compare a taichi function and its gradient against a reference.

    For each arch in ``ti.supported_archs()``:
    - build one-element fields ``x`` and ``y`` (plus their gradients);
    - run ``y[i] = tifunc(x[i])`` forward and backward at ``x = 0.234``;
    - assert ``y[0]`` matches ``npfunc`` and ``x.grad[0]`` matches
      ``grad(npfunc)``.

    ``npfunc`` defaults to ``tifunc`` itself. NOTE(review): ``approx`` and
    ``grad`` are imported elsewhere (presumably pytest.approx and an autodiff
    ``grad`` such as autograd's) — confirm.
    """
    if npfunc is None:
        npfunc = tifunc
    sample = 0.234

    for backend in ti.supported_archs():
        ti.init(arch=backend, default_fp=default_fp)

        x = ti.var(default_fp)
        y = ti.var(default_fp)

        @ti.layout
        def place():
            ti.root.dense(ti.i, 1).place(x, x.grad, y, y.grad)

        @ti.kernel
        def func():
            for i in x:
                y[i] = tifunc(x[i])

        # Seed dL/dy = 1, set the input, then run forward and backward passes.
        y.grad[0] = 1
        x[0] = sample
        func()
        func.grad()

        expected_value = npfunc(sample)
        expected_grad = grad(npfunc)(sample)
        assert y[0] == approx(expected_value)
        assert x.grad[0] == approx(expected_grad)
コード例 #4
0
        def test(*func_args, **func_kwargs):
            """Run ``func`` on every supported arch accepted by ``can_run_on``.

            The arch checker object is popped from ``func_kwargs`` (defaulting
            to a fresh ``_ArchCheckers()``). Before each run ``ti.init`` is
            called with the enclosing ``kwargs``. ``func`` receives the
            remaining ``func_args`` / ``func_kwargs``.
            """
            import taichi as ti
            can_run_on = func_kwargs.pop(_tests_arch_checkers_argname,
                                         _ArchCheckers())
            runtime = ti.get_runtime()

            def wants_64bit(opts):
                # Missing keys fall back to the runtime's current defaults.
                fp = opts.get('default_fp', runtime.default_fp)
                ip = opts.get('default_ip', runtime.default_ip)
                return fp == ti.f64 or ip == ti.i64

            # Filter away archs that don't support 64-bit data. ``kwargs`` is
            # what ``ti.init`` actually receives below, so it must be checked;
            # the ``func_kwargs`` check is kept so that no arch filtered out
            # before is let through now.
            if wants_64bit(kwargs) or wants_64bit(func_kwargs):
                can_run_on.register(
                    lambda arch: is_supported(arch, extension.data64))

            for arch in ti.supported_archs():
                if can_run_on(arch):
                    ti.init(arch=arch, **kwargs)
                    func(*func_args, **func_kwargs)
コード例 #5
0
    def wrapped(*test_args, **test_kwargs):
      # Run ``test`` on each supported arch accepted by the arch checkers,
      # re-initialising taichi with the enclosing ``kwargs`` every time and
      # printing whether each arch was run or skipped.
      import taichi as ti
      can_run_on = test_kwargs.pop(
          _tests_arch_checkers_argname, _ArchCheckers())
      # Filter away archs that don't support 64-bit data.
      needs_data64 = (kwargs.get('default_fp', ti.f32) == ti.f64
                      or kwargs.get('default_ip', ti.i32) == ti.i64)
      if needs_data64:
        can_run_on.register(lambda a: is_supported(a, extension.data64))

      for arch in ti.supported_archs():
        if not can_run_on(arch):
          print('Skipped test on arch={}'.format(arch))
          continue
        print('Running test on arch={}'.format(arch))
        ti.init(arch=arch, **kwargs)
        test(*test_args, **test_kwargs)
コード例 #6
0
ファイル: testing.py プロジェクト: wjgaas/taichi
def _get_taichi_archs_fixture():
    """Build the ``taichi_archs`` pytest fixture.

    The fixture is parametrised over ``ti.supported_archs()`` (ids from
    ``ti.core.arch_name``). For each parameter it reads the closest
    ``taichi`` marker on the test. It then either skips the test or calls
    ``ti.init`` on that arch with the marker's arguments.
    """
    import pytest

    @pytest.fixture(params=ti.supported_archs(), ids=ti.core.arch_name)
    def taichi_archs(request):
        marker = request.node.get_closest_marker('taichi')
        req_arch = request.param

        def ti_init(arch=None, exclude=None, require=None, **options):
            # ``arch`` / ``exclude`` / ``require`` accept a single value or a
            # list/tuple; None means empty, and an empty ``arch`` means all
            # supported archs.
            if arch is None:
                arch = []
            if exclude is None:
                exclude = []
            if require is None:
                require = []
            if not isinstance(arch, (list, tuple)):
                arch = [arch]
            if not isinstance(exclude, (list, tuple)):
                exclude = [exclude]
            if not isinstance(require, (list, tuple)):
                require = [require]
            if not arch:
                arch = ti.supported_archs()

            if (req_arch not in arch) or (req_arch in exclude):
                raise pytest.skip(f'Arch={req_arch} not included in this test')

            if not all(
                    ti.core.is_extension_supported(req_arch, e)
                    for e in require):
                raise pytest.skip(
                    f'Arch={req_arch} some extension(s) not satisfied')

            ti.init(arch=req_arch, **options)

        # ``get_closest_marker`` returns None when the test uses the fixture
        # without ``@pytest.mark.taichi``; treat that as "no constraints"
        # instead of crashing on ``None.args``.
        if marker is None:
            ti_init()
        else:
            ti_init(*marker.args, **marker.kwargs)
        yield

    return taichi_archs
コード例 #7
0
        # test if specified in argument:
        for value in values:
            kwargs = {key: value}
            test_arg(key, value, kwargs)

    # test if specified in environment:
    env_key = 'TI_' + key.upper()
    for value in values:
        env_value = str(int(value) if isinstance(value, bool) else value)
        environ = {env_key: env_value}
        with patch_os_environ_helper(environ, excludes=env_configs):
            test_arg(key, value)


@pytest.mark.parametrize('arch', ti.supported_archs())
def test_init_arch(arch):
    """Check how ``ti.init`` picks the arch, with and without ``TI_ARCH``."""
    # Without TI_ARCH in the patched environment, the explicit argument wins.
    with patch_os_environ_helper({}, excludes=['TI_ARCH']):
        ti.init(arch=arch)
        assert ti.cfg.arch == arch

    # With TI_ARCH set, the test expects the environment value to be used
    # even though ``ti.cc`` is passed explicitly.
    env_override = {'TI_ARCH': ti.core.arch_name(arch)}
    with patch_os_environ_helper(env_override, excludes=['TI_ARCH']):
        ti.init(arch=ti.cc)
        assert ti.cfg.arch == arch


def test_init_bad_arg():
    """An unrecognised ``ti.init`` option (``foo_bar``) must raise KeyError."""
    bad_options = dict(_test_mode=True, debug=True, foo_bar=233)
    with pytest.raises(KeyError):
        ti.init(**bad_options)

コード例 #8
0
ファイル: test_test.py プロジェクト: wjgaas/taichi
def test_all_archs():
    """The currently configured arch is one of ``ti.supported_archs()``."""
    current = ti.cfg.arch
    assert current in ti.supported_archs()
コード例 #9
0
ファイル: testing.py プロジェクト: quadpixels/taichi
def test(arch=None, exclude=None, require=None, **options):
    '''
.. function:: ti.test(arch=[], exclude=[], require=[], **options)

    Decorator that runs the decorated test once per eligible combination of
    backend and test feature. For each combination it:

    - calls ``ti.init`` with the merged options;
    - runs the test;
    - calls ``ti.reset``, even if the test raises.

    :parameter arch: backend(s) to include; a single arch or a list/tuple.
        ``None`` or empty means every arch in ``ti.supported_archs()``
    :parameter exclude: backend(s) to exclude; a single arch or a list/tuple
    :parameter require: extension(s) the backend must support; a single
        extension or a list/tuple
    :parameter options: other options to be passed into ``ti.init``
    '''

    if arch is None:
        arch = []
    if exclude is None:
        exclude = []
    if require is None:
        require = []
    if not isinstance(arch, (list, tuple)):
        arch = [arch]
    if not isinstance(exclude, (list, tuple)):
        exclude = [exclude]
    if not isinstance(require, (list, tuple)):
        require = [require]
    if not arch:
        arch = ti.supported_archs()

    def decorator(foo):
        import functools

        @functools.wraps(foo)
        def wrapped(*args, **kwargs):
            # Cross every supported arch with every entry of _test_features.
            arch_params_sets = [ti.supported_archs(), *_test_features.values()]

            for arch_params in itertools.product(*arch_params_sets):
                req_arch, req_params = arch_params[0], arch_params[1:]

                if (req_arch not in arch) or (req_arch in exclude):
                    continue

                if not all(
                        _ti_core.is_extension_supported(req_arch, e)
                        for e in require):
                    continue

                skip = False
                # Deep copy so the per-combination fills below never mutate
                # the decorator-level ``options``.
                current_options = copy.deepcopy(options)
                for feature, param in zip(_test_features, req_params):
                    value = param.value
                    required_extensions = param.required_extensions
                    # Skip if the caller pinned this feature to another value,
                    # or the arch lacks an extension this value needs.
                    if current_options.get(feature, value) != value or any(
                            not _ti_core.is_extension_supported(req_arch, e)
                            for e in required_extensions):
                        skip = True
                        break
                    # Fill in the missing feature
                    current_options[feature] = value
                if skip:
                    continue

                ti.init(arch=req_arch, **current_options)
                try:
                    foo(*args, **kwargs)
                finally:
                    # Reset even when ``foo`` fails so its runtime state does
                    # not carry over into whatever runs next.
                    ti.reset()

        return wrapped

    return decorator
コード例 #10
0
ファイル: __init__.py プロジェクト: stantoxt/taichi-1
 def test(*func_args, **func_kwargs):
   # Re-run ``func`` once per supported arch, re-initialising taichi with the
   # enclosing ``kwargs`` before each run.
   archs = ti.supported_archs()
   for current_arch in archs:
     ti.init(arch=current_arch, **kwargs)
     func(*func_args, **func_kwargs)
コード例 #11
0
def test_svd():
  # Exercise _test_svd on every (arch, precision, dimension) combination,
  # printing each combination before it runs.
  combos = ((arch, fp, d)
            for arch in ti.supported_archs()
            for fp in [ti.f32, ti.f64]
            for d in [2, 3])
  for arch, fp, d in combos:
    print(arch, fp, d)
    _test_svd(arch, fp, d)