def is_extension_supported(arch, ext):
    """Report whether extension ``ext`` is available on backend ``arch``.

    Thin pass-through to ``_ti_core.is_extension_supported``.

    Args:
        arch (taichi_core.Arch): The backend to query.
        ext (taichi_core.Extension): The extension to look up.

    Returns:
        bool: Whether ``arch`` supports ``ext``.
    """
    query = _ti_core.is_extension_supported
    return query(arch, ext)
def wrapped(*args, **kwargs):
    """Run ``foo`` once for every usable (arch, feature-option) combination.

    Relies on free variables presumably bound by the enclosing decorator
    (not visible here): ``arch`` (archs to try), ``exclude`` (archs to
    skip), ``require`` (extensions every arch must support), ``options``
    (keyword options for ``ti.init``) and ``foo`` (the wrapped test).

    For each combination of an arch with one value per entry of
    ``_test_features``, the combination is skipped when the arch is
    excluded, lacks a required extension, a feature value conflicts with an
    explicitly passed option, or the arch lacks that feature's required
    extensions. Otherwise Taichi is initialized for the arch, ``foo`` is
    called with the given arguments, and Taichi is reset afterwards.
    """
    if not arch:
        print('No supported arch found. Skipping.')
        return
    arch_params_sets = [arch, *_test_features.values()]
    # itertools.product is consumed lazily; there is no need to build a list.
    for req_arch, *req_params in itertools.product(*arch_params_sets):
        # req_arch is drawn from ``arch``, so only the exclusion list matters.
        if req_arch in exclude:
            continue
        if not all(
                _ti_core.is_extension_supported(req_arch, e)
                for e in require):
            continue
        current_options = copy.deepcopy(options)
        for feature, param in zip(_test_features, req_params):
            value = param.value
            # An explicitly passed option must agree with this feature value,
            # and the arch must support the feature's extensions.
            if current_options.get(feature, value) != value or any(
                    not _ti_core.is_extension_supported(req_arch, e)
                    for e in param.required_extensions):
                break
            # Fill in the missing feature
            current_options[feature] = value
        else:  # no break: this combination is runnable
            ti.init(arch=req_arch, enable_fallback=False, **current_options)
            try:
                foo(*args, **kwargs)
            finally:
                # Always tear down, even if the test body raised, so that a
                # failing combination does not leak runtime state.
                ti.reset()
def build_For(ctx, node):
    """Transform a Python ``for`` statement into the matching Taichi loop.

    The loop iterator is inspected for a decorator call (via
    ``ASTTransformer.get_decorator``) and, when that call has exactly one
    argument, for a second decorator on that argument. Dispatch:

    * ``ti.static(...)``                 -> ``build_static_for``
    * ``ti.ndrange(...)``                -> ``build_ndrange_for``
    * ``ti.grouped(ti.ndrange(...))``    -> ``build_grouped_ndrange_for``
    * ``ti.grouped(...)``                -> ``build_struct_for(is_grouped=True)``
    * ``range(...)``                     -> ``build_range_for``
    * iterator is a ``MeshElementField`` -> ``build_mesh_for``
    * iterator is a ``MeshRelationAccessProxy`` -> ``build_nested_mesh_for``
    * anything else                      -> ``build_struct_for(is_grouped=False)``

    Args:
        ctx: Builder context; must provide ``loop_scope_guard``.
        node (ast.For): The ``for`` statement being transformed.

    Returns:
        The result of the selected ``ASTTransformer.build_*_for`` helper.

    Raises:
        TaichiSyntaxError: For a ``for ... else`` clause or an illegal
            combination of nested decorators.
        Exception: If the iterator is a mesh element field and the
            configured backend does not support the mesh extension.
    """
    if node.orelse:
        raise TaichiSyntaxError(
            "'else' clause for 'for' not supported in Taichi kernels")
    decorator = ASTTransformer.get_decorator(ctx, node.iter)
    # Decorator applied to the sole argument of the outer decorator call,
    # e.g. the ``ndrange`` in ``ti.grouped(ti.ndrange(...))``.
    double_decorator = ''
    if decorator != '' and len(node.iter.args) == 1:
        double_decorator = ASTTransformer.get_decorator(
            ctx, node.iter.args[0])

    if decorator == 'static':
        if double_decorator == 'static':
            raise TaichiSyntaxError("'ti.static' cannot be nested")
        with ctx.loop_scope_guard(is_static=True):
            # The third argument flags ti.static(ti.grouped(...)).
            return ASTTransformer.build_static_for(
                ctx, node, double_decorator == 'grouped')

    with ctx.loop_scope_guard():
        if decorator == 'ndrange':
            if double_decorator != '':
                raise TaichiSyntaxError(
                    "No decorator is allowed inside 'ti.ndrange'")
            return ASTTransformer.build_ndrange_for(ctx, node)

        if decorator == 'grouped':
            if double_decorator == 'static':
                raise TaichiSyntaxError(
                    "'ti.static' is not allowed inside 'ti.grouped'")
            if double_decorator == 'ndrange':
                return ASTTransformer.build_grouped_ndrange_for(ctx, node)
            if double_decorator == 'grouped':
                raise TaichiSyntaxError("'ti.grouped' cannot be nested")
            return ASTTransformer.build_struct_for(ctx, node,
                                                   is_grouped=True)

        if (isinstance(node.iter, ast.Call)
                and isinstance(node.iter.func, ast.Name)
                and node.iter.func.id == 'range'):
            return ASTTransformer.build_range_for(ctx, node)

        # Evaluate the iterator so its ``ptr`` can be inspected below.
        build_stmt(ctx, node.iter)
        if isinstance(node.iter.ptr, mesh.MeshElementField):
            backend_arch = impl.default_cfg().arch
            if not _ti_core.is_extension_supported(
                    backend_arch, _ti_core.Extension.mesh):
                raise Exception(f"Backend {backend_arch!s} doesn't support "
                                "MeshTaichi extension")
            return ASTTransformer.build_mesh_for(ctx, node)
        if isinstance(node.iter.ptr, mesh.MeshRelationAccessProxy):
            return ASTTransformer.build_nested_mesh_for(ctx, node)
        # Struct for
        return ASTTransformer.build_struct_for(ctx, node, is_grouped=False)
def test(arch=None, exclude=None, require=None, **options):
    """Decorator that parametrizes a pytest test over supported Taichi archs.

    The decorated test is parametrized with ``req_arch`` and ``req_options``
    for every arch in ``expected_archs()`` that is listed in ``arch`` (all of
    them when ``arch`` is empty), is not listed in ``exclude``, and supports
    every extension in ``require``. Each entry of ``_test_features`` further
    multiplies the combinations. A combination is dropped when a feature
    value conflicts with an explicitly passed option, or when the arch lacks
    that feature's required extensions. If nothing remains, the test is
    marked as skipped.

    .. function:: ti.test(arch=[], exclude=[], require=[], **options)

    :parameter arch: backends to include (single value or list/tuple)
    :parameter exclude: backends to exclude (single value or list/tuple)
    :parameter require: extensions required (single value or list/tuple)
    :parameter options: other options to be passed into ``ti.init``
    :return: a decorator that applies the resulting pytest marks
    """

    def _as_list(value):
        # Normalize None / a single item / a list-or-tuple into a sequence.
        if value is None:
            return []
        if isinstance(value, (list, tuple)):
            return value
        return [value]

    arch = _as_list(arch)
    exclude = _as_list(exclude)
    require = _as_list(require)

    archs_expected = expected_archs()
    if not arch:
        arch = archs_expected
    else:
        arch = [v for v in arch if v in archs_expected]
    # Apply exclusions up front so that a test whose archs are all excluded
    # is reported as having no supported archs, not as missing extensions.
    arch = [v for v in arch if v not in exclude]

    marks = []  # A list of pytest.marks to apply on the test function
    if not arch:
        marks.append(pytest.mark.skip(reason='No supported archs'))
    else:
        # List of (arch, options) to parametrize the test function
        parameters = []
        for req_arch, *req_params in itertools.product(
                arch, *_test_features.values()):
            if not all(
                    _ti_core.is_extension_supported(req_arch, e)
                    for e in require):
                continue
            current_options = copy.deepcopy(options)
            for feature, param in zip(_test_features, req_params):
                value = param.value
                # setdefault keeps an explicitly passed option; if it differs
                # from this feature value, the combination is dropped.
                if current_options.setdefault(feature, value) != value or any(
                        not _ti_core.is_extension_supported(req_arch, e)
                        for e in param.required_extensions):
                    break
            else:  # no break occurs, required extensions are supported
                parameters.append((req_arch, current_options))

        if not parameters:
            marks.append(
                pytest.mark.skip(
                    reason='Not all required extensions are supported'))
        else:
            # Disambiguate ids with an index only when there are several.
            multiple = len(parameters) > 1
            ids = [
                f"arch={param_arch.name}-{i}"
                if multiple else f"arch={param_arch.name}"
                for i, (param_arch, _) in enumerate(parameters)
            ]
            marks.append(
                pytest.mark.parametrize("req_arch,req_options",
                                        parameters,
                                        ids=ids))

    def decorator(func):
        func.__ti_test__ = True  # Mark the function as a taichi test
        for mark in reversed(marks):  # Apply the marks in reverse order
            func = mark(func)
        return func

    return decorator