def _get_callee_identifier(self, node: ast.Invocation) -> Text:
    """Returns the name under which the callee of `node` should be referenced.

    The callee may be a plain name in the current module or a reference into an
    imported module. When the referenced module has no function by that name,
    e.g. for builtins, the bare callee name is returned unchanged. Otherwise the
    name is mangled with `mangle_dslx_name`. Parametric functions are mangled
    together with the invocation's symbolic bindings, which must be non-empty.

    Raises:
      NotImplementedError: if the callee is neither a NameRef nor a ModRef.
    """
    logging.vlog(3, 'Getting callee identifier for invocation: %s', node)
    callee = node.callee
    if isinstance(callee, ast.NameRef):
        callee_name = callee.identifier
        callee_module = self.module
    elif isinstance(callee, ast.ModRef):
        callee_module = self.type_info.get_imports()[callee.mod][0]
        callee_name = callee.value
    else:
        raise NotImplementedError('Callee not currently supported @ {}'.format(
            node.span))

    try:
        function = callee_module.get_function(callee_name)
    except KeyError:
        # Names the module does not define (such as builtins) pass through
        # verbatim.
        return callee_name

    if function.is_parametric():
        bindings = self._get_invocation_bindings(node)
        logging.vlog(2, 'Node %s @ %s symbolic bindings %r', node, node.span,
                     bindings)
        assert bindings, node
    else:
        bindings = None
    return mangle_dslx_name(function.name.identifier,
                            function.get_free_parametric_keys(), callee_module,
                            bindings)
def _visit_map(self, node: ast.Invocation) -> BValue:
    """Converts a `map` invocation into an IR map node.

    `node.args[0]` is the operand being mapped over. `node.args[1]` names the
    function to apply. That name is either a NameRef in the current module or a
    ModRef into an imported module. Parametric builtins are delegated to
    `_def_map_with_builtin`.

    Args:
      node: The `map` invocation to convert.

    Returns:
      The value defined for `node` via `self.fb.add_map`.

    Raises:
      NotImplementedError: if the mapped function reference is neither a
        NameRef nor a ModRef.
    """
    # Visit every argument except the last one (the function reference).
    for operand in node.args[:-1]:
        operand.accept(self)
    arg = self._use(node.args[0])
    fn_node = node.args[1]
    if isinstance(fn_node, ast.NameRef):
        map_fn_name = fn_node.name_def.identifier
        if map_fn_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
            return self._def_map_with_builtin(node, fn_node, node.args[0],
                                              self._get_invocation_bindings(node))
        lookup_module = self.module
    elif isinstance(fn_node, ast.ModRef):
        map_fn_name = fn_node.value
        lookup_module, _ = self.type_info.get_imports()[fn_node.mod]
    else:
        raise NotImplementedError(
            'Unhandled function mapping: {!r}'.format(fn_node))
    fn = lookup_module.get_function(map_fn_name)

    node_sym_bindings = self._get_invocation_bindings(node)
    # Mangle using the string identifier (fn.name is a NameDef), matching how
    # the function was named when its IR was built.
    mangled_name = mangle_dslx_name(fn.name.identifier,
                                    fn.get_free_parametric_keys(),
                                    lookup_module, node_sym_bindings)
    return self._def(node, self.fb.add_map, arg,
                     self.package.get_function(mangled_name))
def _visit_Function(
        self, node: ast.Function,
        symbolic_bindings: Optional[SymbolicBindings]) -> ir_function.Function:
    """Converts the DSLX function `node` into an IR function.

    Args:
      node: The function definition to convert.
      symbolic_bindings: Values for the function's parametric bindings. The
        values are looked up by `parametric_binding.name.identifier` below. May
        be None, which is treated as an empty mapping. The original argument,
        not the dict copy, is what gets passed to `mangle_dslx_name`.

    Returns:
      The IR function produced by the function builder. It is checked with
      `verifier_mod.verify_function` before being returned.

    Raises:
      KeyError: if a parametric binding of `node` has no entry in the bindings.
      AssertionError: if a parametric binding does not resolve to an int.

    Side effects:
      - Overwrites `self.symbolic_bindings`.
      - Empties `self._constant_deps` after visiting its entries.
      - Always resets `self.fb` to None on exit, including when an exception
        is raised.
    """
    self.symbolic_bindings = {} if symbolic_bindings is None else dict(
        symbolic_bindings)
    self._extract_module_level_constants(self.module)
    # We use a function builder for the duration of converting this
    # ast.Function. When it's done being built, we drop the reference to it (by
    # setting self.fb to None).
    self.fb = function_builder.FunctionBuilder(
        mangle_dslx_name(node.name.identifier, node.get_free_parametric_keys(),
                         self.module, symbolic_bindings), self.package)
    try:
        for param in node.params:
            param.accept(self)

        # Each parametric binding is materialized as a constant whose bit width
        # comes from the binding's declared type. The binding is then aliased
        # to its name.
        for parametric_binding in node.parametric_bindings:
            logging.vlog(4, 'Resolving parametric binding %s', parametric_binding)
            sb_value = self.symbolic_bindings[parametric_binding.name.identifier]
            value = self._resolve_dim(sb_value)
            assert isinstance(value, int), \
                'Expect integral parametric binding; got {!r}'.format(value)
            self._def_const(
                parametric_binding, value,
                self._resolve_type(parametric_binding.type_).get_total_bit_count())
            self._def_alias(parametric_binding, to=parametric_binding.name)

        # NOTE(review): presumably these are module-level constants this
        # function depends on, collected by _extract_module_level_constants
        # above. They are converted once and the list is cleared in place.
        # Confirm.
        for dep in self._constant_deps:
            dep.accept(self)
        del self._constant_deps[:]

        node.body.accept(self)

        # self.last_expression, when set, takes precedence over node.body.
        # NOTE(review): a result that is a bare name reference gets an explicit
        # identity node, presumably so the builder's final node is the function
        # result. Confirm against FunctionBuilder.build.
        last_expression = self.last_expression or node.body
        if isinstance(last_expression, ast.NameRef):
            self._def(last_expression, self.fb.add_identity,
                      self._use(last_expression))

        f = self.fb.build()
        logging.vlog(3, 'Built function: %s', f.name)
        verifier_mod.verify_function(f)
        return f
    finally:
        self.fb = None
def _def_map_with_builtin(self, parent_node: ast.Invocation,
                          node: ast.NameRef, arg: ast.AstNode,
                          symbolic_bindings: SymbolicBindings) -> BValue:
    """Emits a map of a unary builtin over `arg` and defines it for `parent_node`.

    A single-parameter IR function wrapping the builtin named by `node` is added
    to the package under its mangled name. It is added only once; an existing
    function with that name is reused. The builtin must be one of
    `dslx_builtins.UNARY_BUILTIN_NAMES` ('clz' or 'ctz'). The wrapper's
    parameter takes the element type of `arg`'s IR type.

    Args:
      parent_node: The `map` invocation the resulting value is defined for.
      node: Reference to the builtin being mapped.
      arg: AST node of the operand being mapped over.
      symbolic_bindings: Bindings used when mangling the wrapper's name.

    Returns:
      The value defined for `parent_node` via `self.fb.add_map`.
    """
    builtin_name = node.name_def.identifier
    mangled_name = mangle_dslx_name(builtin_name, set(), self.module,
                                    symbolic_bindings)
    arg_value = self._use(arg)
    if mangled_name not in self.package.get_function_names():
        builder = function_builder.FunctionBuilder(mangled_name, self.package)
        element_param = builder.add_param(
            'arg', arg_value.get_type().get_element_type())
        assert builtin_name in dslx_builtins.UNARY_BUILTIN_NAMES, dslx_builtins.UNARY_BUILTIN_NAMES
        emitters = {'clz': builder.add_clz, 'ctz': builder.add_ctz}
        # Keep this table in sync with the set of unary builtins.
        assert set(emitters.keys()) == dslx_builtins.UNARY_BUILTIN_NAMES, set(
            emitters.keys())
        emitters[builtin_name](element_param)
        builder.build()
    mapped_fn = self.package.get_function(mangled_name)
    return self._def(parent_node, self.fb.add_map, arg_value, mapped_fn)