def _visit_map(self, node: ast.Invocation) -> BValue:
    """Converts a `map` invocation into an IR map operation.

    All arguments except the last are visited, and the IR value of the first
    argument becomes the map operand. The second argument names the function
    to apply: parametric builtins are delegated to `_def_map_with_builtin`;
    otherwise the function is looked up in this module (name reference) or in
    the referenced import (module reference), and the package function with
    the corresponding mangled name is mapped over the operand.

    Args:
      node: The `map` invocation being converted.

    Returns:
      The value produced by defining the map operation for `node`.

    Raises:
      NotImplementedError: If the mapped function is neither a name reference
        nor a module reference.
    """
    for operand in node.args[:-1]:
        visit(operand, self)
    mapped_input = self._use(node.args[0])
    callee_node = node.args[1]

    if isinstance(callee_node, ast.NameRef):
        callee_name = callee_node.name_def.identifier
        if callee_name in dslx_builtins.PARAMETRIC_BUILTIN_NAMES:
            return self._def_map_with_builtin(
                node, callee_node, node.args[0],
                self._get_invocation_bindings(node))
        owner_module = self.module
    elif isinstance(callee_node, ast.ModRef):
        callee_name = callee_node.value
        owner_module, _ = self.type_info.get_imports()[callee_node.mod]
    else:
        raise NotImplementedError(
            'Unhandled function mapping: {!r}'.format(callee_node))

    # Both supported reference kinds resolve the callee the same way once the
    # owning module is known.
    callee = owner_module.get_function(callee_name)
    invocation_bindings = self._get_invocation_bindings(node)
    mangled = mangle_dslx_name(callee.name, callee.get_free_parametric_keys(),
                               owner_module, invocation_bindings)
    return self._def(node, self.fb.add_map, mapped_input,
                     self.package.get_function(mangled))
def _visit_matcher(self, matcher: ast.NameDefTree, index: Tuple[int, ...],
                   matched_value: BValue,
                   matched_type: ConcreteType) -> BValue:
    """Emits IR that tells whether `matched_value` satisfies `matcher`.

    Args:
      matcher: Pattern tree to test against.
      index: Tuple of indices; in this routine it is only extended with the
        element position and forwarded to recursive calls.
      matched_value: IR value being matched.
      matched_type: Type of `matched_value`; its unnamed members are paired
        with the sub-patterns when `matcher` is not a leaf.

    Returns:
      An IR value for the match condition. Wildcards and name definitions
      yield a literal one-bit 1; numbers, enum references and name references
      yield an equality test; non-leaf patterns yield the AND of their
      element conditions.
    """
    if not matcher.is_leaf():
        all_ok = self.fb.add_literal_bits(bits_mod.UBits(value=1, bit_count=1))
        member_types = matched_type.get_unnamed_members()  # pytype: disable=attribute-error
        for i, (sub_pattern, sub_type) in enumerate(
                zip(matcher.tree, member_types)):
            # Pull out the i-th element and match it against its sub-pattern.
            element = self.fb.add_tuple_index(matched_value, i)
            element_ok = self._visit_matcher(sub_pattern, index + (i,),
                                             element, sub_type)
            all_ok = self.fb.add_and(all_ok, element_ok)
        return all_ok

    leaf = matcher.get_leaf()
    logging.vlog(5, 'Matcher is leaf: %s (%s)', leaf, leaf.__class__.__name__)

    if isinstance(leaf, ast.WildcardPattern):
        return self._def(matcher, self.fb.add_literal_bits,
                         bits_mod.UBits(1, 1))

    if isinstance(leaf, (ast.Number, ast.EnumRef)):
        visit(leaf, self)
        return self._def(matcher, self.fb.add_eq, self._use(leaf),
                         matched_value)

    if isinstance(leaf, ast.NameRef):
        equal = self._def(matcher, self.fb.add_eq, self._use(leaf.name_def),
                          matched_value)
        self._def_alias(leaf.name_def, to=leaf)
        return equal

    assert isinstance(
        leaf, ast.NameDef
    ), 'Expected leaf to be wildcard, number, or name; got: {!r}'.format(
        leaf)
    always_true = self._def(leaf, self.fb.add_literal_bits,
                            bits_mod.UBits(1, 1))
    # A name definition binds the matched value: both the pattern and the
    # name now map to it.
    self.node_to_ir[matcher] = matched_value
    self.node_to_ir[leaf] = matched_value
    return always_true
def query_const_range_call() -> int:
    """Returns the trip count of a `for ... in range(0, CONST)` construct.

    Takes no parameters; it reads the for-loop `node` and the converter `self`
    as free variables.

    Returns:
      The constant value of the second `range` argument.

    Raises:
      ConversionError: If the iterable is not a call to `range` by name, does
        not have exactly two arguments, has a first argument that is not a
        constant zero, or has a second argument that is not constant.
    """
    iterable = node.iterable
    is_range_call = (
        isinstance(iterable, ast.Invocation) and
        isinstance(iterable.callee, ast.NameRef) and
        iterable.callee.identifier == 'range')
    if not is_range_call:
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, found non-range callee.",
            node.span)
    if len(iterable.args) != 2:
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, found inappropriate number "
            'of arguments.', node.span)
    if not self._is_constant_zero(iterable.args[0]):
        # Previously this reported an argument-count problem, which is not
        # what this check detects.
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, found a non-zero start "
            'argument.', node.span)
    limit = iterable.args[1]
    visit(limit, self)
    if not self._is_const(limit):
        raise ConversionError(
            'For-loop is of an unsupported form for IR conversion; only a '
            "'range(0, const)' call is supported, did not find a const value "
            f'for {limit} ({limit!r}).', node.span)
    return self._get_const(limit)
def test_simple_number(self):
    """A Number node has no children and is the only node visited."""
    module = ast.Module('test')
    number = ast.Number(module, self.fake_span, '42')
    self.assertEmpty(number.children)

    recorder = Collector()
    cpp_ast_visitor.visit(number, recorder)
    self.assertEqual(recorder.visited, [number])
def visit_StructInstance(self, node: ast.StructInstance) -> None:
    """Converts a struct instance into an IR tuple of its member values.

    Members are visited and used in the order given by
    `node.get_ordered_members` for the dereferenced struct definition.
    """
    struct_def = self._deref_struct(node.struct)
    member_values = []
    for _, member_expr in node.get_ordered_members(struct_def):
        visit(member_expr, self)
        member_values.append(self._use(member_expr))
    self._def(node, self.fb.add_tuple, tuple(member_values))
def test_array_of_numbers(self):
    """Array members are visited in order, followed by the array itself."""
    module = ast.Module('test')
    first = ast.Number(module, self.fake_span, '42')
    second = ast.Number(module, self.fake_span, '64')
    array = ast.Array(module, self.fake_span, [first, second], False)

    recorder = Collector()
    cpp_ast_visitor.visit(array, recorder)
    self.assertEqual(recorder.visited, [first, second, array])
def test_visit_unop(self):
    """The operand of a unary op is collected before the op itself."""
    span = self.fake_span
    module = self.m
    name_def = ast.NameDef(module, span, 'i')
    name_ref = ast.NameRef(module, span, 'i', name_def)
    negation = ast.Unop(module, span, ast.UnopKind.NEG, name_ref)

    collector = _Collector()
    visit(negation, collector)
    self.assertEqual(collector.collected, [name_ref, negation])
def visit_Array(self, node: ast.Array) -> None:
    """Converts an array literal into an IR array.

    Each member is visited and its IR value collected. When the literal ends
    with an ellipsis, the last value is repeated until the collected values
    reach the size of the resolved array type. The element type passed to the
    builder is taken from the first value.
    """
    array_type = self._resolve_type(node)
    elements = []
    for member_expr in node.members:
        visit(member_expr, self)
        elements.append(self._use(member_expr))

    if node.has_ellipsis:
        # Pad with copies of the final element up to the declared size.
        while len(elements) < array_type.size:  # pytype: disable=attribute-error
            elements.append(elements[-1])

    element_type = elements[0].get_type()
    self._def(node, self.fb.add_array, elements, element_type)
def test_visit_index(self):
    """Indexed value and index are collected before the Index node."""
    span = self.fake_span
    module = self.m
    # Build a `t[i]` index node.
    indexed = ast.NameRef(module, span, 't', ast.NameDef(module, span, 't'))
    subscript = ast.NameRef(module, span, 'i', ast.NameDef(module, span, 'i'))
    index_node = ast.Index(module, span, indexed, subscript)

    collector = _Collector()
    visit(index_node, collector)
    self.assertEqual(collector.collected, [indexed, subscript, index_node])
def test_visit_match_multi_pattern(self):
    """A two-pattern match arm sharing one number collects it three times."""
    pos = self.fake_pos
    span = Span(pos, pos)
    module = self.m
    number = ast.Number(module, span, u'0xf00')
    # Both patterns and the arm expression wrap the same Number node.
    patterns = (ast.NameDefTree(module, span, number),
                ast.NameDefTree(module, span, number))
    arm = ast.MatchArm(module, span, patterns=patterns, expr=number)

    collector = _Collector()
    visit(arm, collector)
    self.assertEqual(collector.collected, [number, number, number])
def test_visit_type(self):
    """The dimension of a builtin array type is collected before the type."""
    span = self.fake_span
    dim = self.five
    # Build a uN[5] type annotation node.
    bits_token = Token(value=Keyword.BITS, span=span)
    annotation = ast_helpers.make_builtin_type_annotation(
        self.m, span, bits_token, dims=(dim,))
    assert isinstance(annotation, ast.ArrayTypeAnnotation), annotation

    collector = _Collector()
    visit(annotation, collector)
    self.assertEqual(collector.collected, [dim, annotation])
def visit_SplatStructInstance(self, node: ast.SplatStructInstance) -> None:
    """Converts a splatted struct instance into an IR tuple.

    Members explicitly given on `node` use their own converted value; every
    other member of the struct is extracted by position from the splatted
    value.
    """
    visit(node.splatted, self)
    base_value = self._use(node.splatted)

    overrides = {}
    for member_name, member_expr in node.members:
        visit(member_expr, self)
        overrides[member_name] = self._use(member_expr)

    struct_def = self._deref_struct(node.struct)
    member_values = [
        overrides[name] if name in overrides
        else self.fb.add_tuple_index(base_value, position)
        for position, name in enumerate(struct_def.member_names)
    ]
    self._def(node, self.fb.add_tuple, member_values)
def visit_Cast(self, node: ast.Cast) -> None:
    """Converts a cast expression to IR.

    Casts to an array type are delegated to `_cast_to_array`, and casts from
    an array type to `_cast_from_array`. Otherwise the operand is bit-sliced
    from bit 0 when the output is narrower than the input, or sign-/zero-
    extended (chosen by the input type's signedness) when it is not.

    Args:
      node: The cast node being converted.

    Raises:
      NotImplementedError: If the output type is not an array, bits or enum
        type.
    """
    visit(node.expr, self)
    output_type = self._resolve_type(node)
    if isinstance(output_type, ArrayType):
        return self._cast_to_array(node, output_type)
    if not isinstance(output_type, (BitsType, EnumType)):
        raise NotImplementedError(
            'Cast can only handle bits output types; got: '
            f'{output_type} @ {node.span} ({output_type!r})')

    # Resolved once and reused below (the original resolved it a second time
    # before the width comparison).
    input_type = self._resolve_type(node.expr)
    if isinstance(input_type, ArrayType):
        return self._cast_from_array(node, output_type)

    new_bit_count = output_type.get_total_bit_count()
    if new_bit_count < input_type.get_total_bit_count():
        # Narrowing: keep the low `new_bit_count` bits.
        self._def(node, self.fb.add_bit_slice, self._use(node.expr), 0,
                  new_bit_count)
    else:
        extend = (self.fb.add_signext
                  if input_type.get_signedness() else self.fb.add_zeroext)
        self._def(node, extend, self._use(node.expr), new_bit_count)
def visit_Match(self, node: ast.Match):
    """Converts a match expression into an IR match_true operation.

    The last arm must hold a single irrefutable pattern and acts as the
    default. Every other arm contributes one selector (the OR of its pattern
    conditions when it has several patterns) and one value.

    Raises:
      ConversionError: If there are no arms or the first pattern of the last
        arm is not irrefutable.
    """
    if not (node.arms and node.arms[-1].patterns[0].is_irrefutable()):
        raise ConversionError(
            'Only matches with trailing irrefutable patterns are currently handled.',
            node.span)

    visit(node.matched, self)
    matched = self._use(node.matched)
    matched_type = self._resolve_type(node.matched)

    default_arm = node.arms[-1]
    assert len(default_arm.patterns) == 1, (
        'Multiple patterns in default arm is not yet implemented for IR '
        'conversion.')
    default_index = (len(node.arms) - 1,)
    self._visit_matcher(default_arm.patterns[0], default_index, matched,
                        matched_type)
    visit(default_arm.expr, self)

    selectors = []
    values = []
    for arm_number, arm in enumerate(node.arms[:-1]):
        # One condition per pattern of this arm.
        pattern_conditions = [
            self._visit_matcher(pattern, (arm_number,), matched, matched_type)
            for pattern in arm.patterns
        ]
        # The arm is selected when any of its patterns matches; a lone
        # pattern's condition is used directly.
        selectors.append(
            self.fb.add_nary_or(pattern_conditions)
            if len(pattern_conditions) > 1 else pattern_conditions[0])
        visit(arm.expr, self)
        values.append(self._use(arm.expr))

    # At this point a match such as:
    #
    #   match x {
    #     42 => blah
    #     64 => snarf
    #     128 => yep
    #     _ => burp
    #   }
    #
    # is represented as:
    #
    #   selectors: [x==42, x==64, x==128]
    #   values: [blah, snarf, yep]
    #   default value: burp
    self.node_to_ir[node] = self.fb.add_match_true(
        selectors, values, self._use(default_arm.expr))
    self.last_expression = node
def _visit_Function(
        self, node: ast.Function,
        symbolic_bindings: Optional[SymbolicBindings]) -> ir_function.Function:
    """Converts a function AST node into a built and verified IR function.

    Args:
      node: The function to convert.
      symbolic_bindings: Parametric bindings for this instantiation, or None
        (treated as no bindings).

    Returns:
      The IR function produced by the function builder.
    """
    if symbolic_bindings is None:
        self.symbolic_bindings = {}
    else:
        self.symbolic_bindings = dict(symbolic_bindings)
    self._extract_module_level_constants(self.module)

    # The builder lives on self.fb only while this function is converted; the
    # `finally` clause below drops the reference again.
    ir_name = mangle_dslx_name(node.name.identifier,
                               node.get_free_parametric_keys(), self.module,
                               symbolic_bindings)
    self.fb = function_builder.FunctionBuilder(ir_name, self.package)
    try:
        for param in node.params:
            visit(param, self)

        for binding in node.parametric_bindings:
            logging.vlog(4, 'Resolving parametric binding %s', binding)
            bound = self.symbolic_bindings[binding.name.identifier]
            value = self._resolve_dim(bound)
            assert isinstance(value, int), \
                'Expect integral parametric binding; got {!r}'.format(value)
            bit_count = self._resolve_type(binding.type_).get_total_bit_count()
            self._def_const(binding, value, bit_count)
            self._def_alias(binding, to=binding.name)

        for dep in self._constant_deps:
            visit(dep, self)
        del self._constant_deps[:]

        visit(node.body, self)

        result_expr = self.last_expression or node.body
        if isinstance(result_expr, ast.NameRef):
            # A bare name as the result gets an explicit identity operation.
            self._def(result_expr, self.fb.add_identity,
                      self._use(result_expr))

        built = self.fb.build()
        logging.vlog(3, 'Built function: %s', built.name)
        verifier_mod.verify_function(built)
        return built
    finally:
        self.fb = None
def visit_Index(self, node: ast.Index) -> None:
    """Converts an index expression to IR based on the indexed value's type.

    Tuple operands become a tuple_index with a constant index. Bits operands
    become a dynamic slice (width slices, via `_visit_width_slice`) or a
    static bit_slice whose start/width come from the type information. Any
    other operand type becomes an array_index.
    """
    visit(node.lhs, self)
    lhs_type = self.type_info[node.lhs]

    if isinstance(lhs_type, TupleType):
        visit(node.index, self)
        self._def(node, self.fb.add_tuple_index, self._use(node.lhs),
                  self._get_const(node.index))
        return

    if not isinstance(lhs_type, BitsType):
        visit(node.index, self)
        self._def(node, self.fb.add_array_index, self._use(node.lhs),
                  self._use(node.index))
        return

    # Bits operand: the "index" is a slice of some kind.
    slice_node = node.index
    if isinstance(slice_node, ast.WidthSlice):
        return self._visit_width_slice(node, slice_node, lhs_type)
    assert isinstance(slice_node, ast.Slice), slice_node
    start, width = self.type_info.get_slice_start_width(
        slice_node, self._get_symbolic_bindings_tuple())
    self._def(node, self.fb.add_bit_slice, self._use(node.lhs), start, width)
def visit_Let(self, node: ast.Let):
    """Converts a let binding and its body to IR.

    A single-name binding aliases the name to the right-hand side. A tuple
    pattern is destructured with tuple_index operations, aliasing each leaf
    name to its extracted element. The let node itself is aliased to its
    body, and the body is recorded as the last expression if none is set yet.
    """
    visit(node.rhs, self)
    pattern = node.name_def_tree

    if pattern.is_leaf():
        self._def_alias(node.rhs, to=pattern.get_leaf())
    else:
        # Stack of values along the current path through the pattern tree;
        # entry 0 is the whole right-hand side.
        path_values = [self._use(node.rhs)]  # List[BValue]

        def walk(x: ast.NameDefTree, level: int, index: int) -> None:
            """Binds one node of the name def tree to its matched value.

            Args:
              x: The current node of the NameDefTree.
              level: Depth in the NameDefTree (root is 0).
              index: Position of the node within its tree level (leftmost
                is 0).
            """
            # Discard values from deeper or sibling paths visited earlier.
            del path_values[level:]
            span = x.get_leaf().span if x.is_leaf() else x.span
            extracted = self._def(x, self.fb.add_tuple_index, path_values[-1],
                                  index, span=span)
            path_values.append(extracted)
            if x.is_leaf():
                self._def_alias(x, x.get_leaf())

        ast_helpers.do_preorder(pattern, walk)

    visit(node.body, self)
    self._def_alias(node.body, to=node)

    if self.last_expression is None:
        self.last_expression = node.body
def visit_EnumRef(self, node: ast.EnumRef) -> None:
    """Converts an enum reference by aliasing it to the member's value."""
    enum_def = self._deref_enum(node.enum)
    member_value = enum_def.get_value(node.value)
    visit(member_value, self)
    self._def_alias(member_value, to=node)
def visit_XlsTuple(self, node: ast.XlsTuple) -> None:
    """Converts a tuple expression into an IR tuple of its members.

    All members are visited first; their IR values are gathered afterwards.
    """
    for member in node.members:
        visit(member, self)
    member_values = tuple(map(self._use, node.members))
    self._def(node, self.fb.add_tuple, member_values)
def get_callees(func: Union[ast.Function, ast.Test], m: ast.Module,
                type_info: type_info_mod.TypeInfo,
                imports: Dict[ast.Import, ImportedInfo],
                bindings: SymbolicBindings) -> Tuple[Callee, ...]:
    """Traverses the definition of `func` to find its callees.

    Args:
      func: Function/test construct to inspect for calls.
      m: Module that `func` resides in.
      type_info: Node to type mapping that should be used with `func`.
      imports: Mapping of modules imported by `m`.
      bindings: Bindings used in instantiation of `func`.

    Returns:
      Callee functions invoked by `func`, and the parametric bindings used in
      each of those invocations.

    Raises:
      NotImplementedError: If an invocation's callee is neither a colon
        reference nor a name reference.
      KeyError: If a named callee cannot be found in its module and the
        invoked name is not a parametric builtin.
    """
    assert isinstance(bindings, SymbolicBindings), bindings
    callees = []

    class InvocationVisitor(cpp_ast_visitor.AstVisitor):
        """Visits invocation nodes to build up the callees list."""

        @cpp_ast_visitor.AstVisitor.no_auto_traverse
        def visit_ParametricBinding(self,
                                    node: ast.ParametricBinding) -> None:
            pass

        def visit_Invocation(self, node: ast.Invocation) -> None:
            if isinstance(node.callee, ast.ColonRef):
                this_m, _ = imports[node.callee.subject.name_def.definer]
                f = this_m.get_function(node.callee.attr)
            elif isinstance(node.callee, ast.NameRef):
                this_m = m
                fn_identifier = node.callee.identifier
                if fn_identifier == 'map':
                    # We need to make sure we convert the mapped function!
                    fn_node = node.args[1]
                    if isinstance(fn_node, ast.ColonRef):
                        fn_identifier = fn_node.attr
                        import_node = fn_node.subject.name_def.definer
                        this_m = imports[import_node][0]
                    else:
                        fn_identifier = fn_node.name_def.identifier
                try:
                    f = this_m.get_function(fn_identifier)
                except KeyError:
                    # Parametric builtins have no function definition in the
                    # module, so there is no callee to record for them.
                    if (node.callee.identifier
                            in dslx_builtins.PARAMETRIC_BUILTIN_NAMES):
                        return
                    raise
            else:
                raise NotImplementedError(
                    'Only calls to named functions are currently supported, got callee: {!r}'
                    .format(node.callee))

            node_symbolic_bindings = (
                type_info.get_invocation_symbolic_bindings(node, bindings))
            # Either use the global type_info or the child type_info chained
            # off of this invocation.
            try:
                invocation_type_info = type_info.get_instantiation(
                    node, node_symbolic_bindings)
            except KeyError:
                invocation_type_info = type_info
            assert invocation_type_info is not None
            assert isinstance(node_symbolic_bindings,
                              SymbolicBindings), node_symbolic_bindings
            callees.append(
                Callee(f, this_m, invocation_type_info,
                       node_symbolic_bindings))

    cpp_ast_visitor.visit(func, InvocationVisitor())
    logging.vlog(3, 'Callees for %s: %s', func,
                 [(cr.m.name, cr.f.identifier, cr.sym_bindings)
                  for cr in callees])
    return tuple(callees)
def visit_For(self, node: ast.For) -> None:
    """Converts a counted for-loop into an IR counted_for operation.

    The loop body is converted into its own IR function by a fresh
    `_IrConverterFb`. That function's parameters are, in order: the induction
    variable, the loop carry, and one parameter per free variable of the body
    whose type is not a function type. The free variables' values in the
    enclosing function are passed as invariant arguments.

    Args:
      node: The for-loop node being converted.

    Raises:
      ConversionError: If the iterable is not of the form `range(0, const)`.
    """
    visit(node.init, self)

    def query_const_range_call() -> int:
        """Returns the trip count of a `for ... in range(0, CONST)` loop."""
        iterable = node.iterable
        is_range_call = (
            isinstance(iterable, ast.Invocation) and
            isinstance(iterable.callee, ast.NameRef) and
            iterable.callee.identifier == 'range')
        if not is_range_call:
            raise ConversionError(
                'For-loop is of an unsupported form for IR conversion; only a '
                "'range(0, const)' call is supported, found non-range callee.",
                node.span)
        if len(iterable.args) != 2:
            raise ConversionError(
                'For-loop is of an unsupported form for IR conversion; only a '
                "'range(0, const)' call is supported, found inappropriate number "
                'of arguments.', node.span)
        if not self._is_constant_zero(iterable.args[0]):
            # Previously this reported an argument-count problem, which is
            # not what this check detects.
            raise ConversionError(
                'For-loop is of an unsupported form for IR conversion; only a '
                "'range(0, const)' call is supported, found a non-zero start "
                'argument.', node.span)
        limit = iterable.args[1]
        visit(limit, self)
        if not self._is_const(limit):
            raise ConversionError(
                'For-loop is of an unsupported form for IR conversion; only a '
                "'range(0, const)' call is supported, did not find a const value "
                f'for {limit} ({limit!r}).', node.span)
        return self._get_const(limit)

    # TODO(leary): We currently only support counted loops of the form:
    #
    #   for (i, ...): (u32, ...) in range(0, N) {
    #      ...
    #   }
    trip_count = query_const_range_call()

    logging.vlog(3, 'Converting for-loop @ %s', node.span)
    body_converter = _IrConverterFb(
        self.package,
        self.module,
        self.type_info,
        emit_positions=self.emit_positions)
    body_converter.symbolic_bindings = dict(self.symbolic_bindings)
    body_fn_name = ('__' + self.fb.name + '_counted_for_{}_body').format(
        self._next_counted_for_ordinal()).replace('.', '_')
    body_converter.fb = function_builder.FunctionBuilder(
        body_fn_name, self.package)

    flat = node.names.flatten1()
    assert len(
        flat
    ) == 2, 'Expect an induction binding and loop carry binding; got {!r}'.format(
        flat)

    # Add the induction value.
    assert isinstance(
        flat[0], ast.NameDef
    ), 'Induction variable was not a NameDef: {0} ({0!r})'.format(flat[0])
    body_converter.node_to_ir[flat[0]] = body_converter.fb.add_param(
        flat[0].identifier.encode('utf-8'), self._resolve_type_to_ir(flat[0]))

    # Add the loop carry value.
    if isinstance(flat[1], ast.NameDef):
        body_converter.node_to_ir[flat[1]] = body_converter.fb.add_param(
            flat[1].identifier.encode('utf-8'),
            self._resolve_type_to_ir(flat[1]))
    else:
        # For tuple loop carries we have to destructure names on entry.
        carry_type = self._resolve_type_to_ir(flat[1])
        carry = body_converter.node_to_ir[flat[1]] = (
            body_converter.fb.add_param('__loop_carry', carry_type))
        body_converter._visit_matcher(  # pylint: disable=protected-access
            flat[1], (), carry, self._resolve_type(flat[1]))

    # Free variables are suffixes on the function parameters. Function-typed
    # names are skipped. The filtered list is computed once so the body
    # function's parameters and the invariant arguments below come from the
    # same sequence.
    freevars = node.body.get_free_variables(node.span.start)
    freevars = freevars.drop_builtin_defs()
    invariant_name_defs = [
        name_def for name_def in freevars.get_name_defs(self.module)
        if not isinstance(self.type_info[name_def], FunctionType)
    ]
    for name_def in invariant_name_defs:
        logging.vlog(3, 'Converting freevar name: %s', name_def)
        body_converter.node_to_ir[name_def] = body_converter.fb.add_param(
            name_def.identifier.encode('utf-8'),
            self._resolve_type_to_ir(name_def))

    visit(node.body, body_converter)
    body_function = body_converter.fb.build()
    logging.vlog(3, 'Converted body function: %s', body_function.name)

    stride = 1
    invariant_args = tuple(
        self._use(name_def) for name_def in invariant_name_defs)
    self._def(node, self.fb.add_counted_for, self._use(node.init), trip_count,
              stride, body_function, invariant_args)
def _visit_width_slice(self, node: ast.Index, width_slice: ast.WidthSlice,
                       lhs_type: ConcreteType) -> None:
    """Converts a width slice into an IR dynamic_bit_slice.

    Args:
      node: The index node being converted; its resolved type supplies the
        slice width in bits.
      width_slice: The slice whose start expression is converted and used as
        the dynamic start.
      lhs_type: Type of the sliced value; not used by this routine.
    """
    visit(width_slice.start, self)
    operand = self._use(node.lhs)
    start = self._use(width_slice.start)
    width = self._resolve_type(node).get_total_bit_count()
    self._def(node, self.fb.add_dynamic_bit_slice, operand, start, width)
def visit_Constant(self, node: ast.Constant) -> None:
    """Converts a constant's value and aliases the constant's name to it."""
    constant_value = node.value
    visit(constant_value, self)
    logging.vlog(5, 'Aliasing NameDef for constant: %r', node.name)
    self._def_alias(constant_value, to=node.name)
def accept_args() -> Tuple[BValue, ...]:
    """Visits every argument of `node`, then returns their IR values.

    Reads `node` and `self` as free variables. All arguments are visited
    before any value is fetched.
    """
    for argument in node.args:
        visit(argument, self)
    return tuple(map(self._use, node.args))