def test_invoke_adder_2_plus_3_eq_5(self):
  """Checks the package IR for a `main` that invokes an adder on 2 and 3.

  Builds `add_wrapper(x, y)` (a 32-bit add), then a `main` that invokes it
  with the literals 2 and 3 and compares the result against the literal 5.
  """
  pkg = ir_package.Package('test_package')

  # Callee: a two-parameter wrapper around a single add node.
  callee_fb = function_builder.FunctionBuilder('add_wrapper', pkg)
  u32 = pkg.get_bits_type(32)
  lhs = callee_fb.add_param('x', u32)
  rhs = callee_fb.add_param('y', u32)
  callee_fb.add_add(lhs, rhs)
  add_wrapper = callee_fb.build()

  # Caller: nodes are added in the same order as the ids in the expected dump.
  main_fb = function_builder.FunctionBuilder('main', pkg)
  literal_two = main_fb.add_literal_bits(
      bits_mod.UBits(value=2, bit_count=32))
  literal_three = main_fb.add_literal_bits(
      bits_mod.UBits(value=3, bit_count=32))
  invoked = main_fb.add_invoke([literal_two, literal_three], add_wrapper)
  literal_five = main_fb.add_literal_bits(
      bits_mod.UBits(value=5, bit_count=32))
  main_fb.add_eq(invoked, literal_five)
  main_fb.build()

  expected_ir = """\
package test_package

fn add_wrapper(x: bits[32], y: bits[32]) -> bits[32] {
  ret add.3: bits[32] = add(x, y, id=3)
}

fn main() -> bits[1] {
  literal.4: bits[32] = literal(value=2, id=4)
  literal.5: bits[32] = literal(value=3, id=5)
  invoke.6: bits[32] = invoke(literal.4, literal.5, to_apply=add_wrapper, id=6)
  literal.7: bits[32] = literal(value=5, id=7)
  ret eq.8: bits[1] = eq(invoke.6, literal.7, id=8)
}
"""
  self.assertMultiLineEqual(pkg.dump_ir(), expected_ir)
def test_standard_pipeline(self):
  """Runs the standard pass pipeline over a package with a one-literal `main`."""
  pkg = package.Package('pname')
  main_builder = function_builder.FunctionBuilder('main', pkg)
  main_builder.add_literal_bits(bits_mod.UBits(value=2, bit_count=32))
  main_builder.build()

  # The pipeline's return value is expected to be falsy for this package.
  pipeline_result = standard_pipeline.run_standard_pass_pipeline(pkg)
  self.assertFalse(pipeline_result)
def test_simple_build_and_dump_package(self):
  """Builds a function with located, named nodes and checks both IR dumps.

  The same source location (file 'my_file.x', line 42, column 64) is attached
  to an `or` and a `not` node; the function-level and package-level dumps are
  then compared against their expected text.
  """
  pkg = ir_package.Package('test_package')
  file_number = pkg.get_or_create_fileno('my_file.x')
  builder = function_builder.FunctionBuilder('test_function', pkg)
  operand = builder.add_param('x', pkg.get_bits_type(32))
  position = source_location.SourceLocation(
      file_number, fileno_mod.Lineno(42), fileno_mod.Colno(64))
  builder.add_or(operand, operand, loc=position, name='my_or')
  builder.add_not(operand, loc=position, name='why_not')
  built = builder.build()

  self.assertEqual(built.name, 'test_function')

  expected_function_ir = """\
fn test_function(x: bits[32]) -> bits[32] {
  my_or: bits[32] = or(x, x, id=2, pos=0,42,64)
  ret why_not: bits[32] = not(x, id=3, pos=0,42,64)
}
"""
  self.assertEqual(built.dump_ir(), expected_function_ir)

  expected_package_ir = """\
package test_package

fn test_function(x: bits[32]) -> bits[32] {
  my_or: bits[32] = or(x, x, id=2, pos=0,42,64)
  ret why_not: bits[32] = not(x, id=3, pos=0,42,64)
}
"""
  self.assertMultiLineEqual(pkg.dump_ir(), expected_package_ir)
def test_verify_function(self):
  """Passes a freshly built one-param, one-literal function to the verifier."""
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder('f_name', pkg)
  fb.add_param('x', pkg.get_bits_type(32))
  seven = ir_value.Value(bits.UBits(7, 8))
  fb.add_literal_value(seven)
  built = fb.build()

  # The verifier's result is not inspected; the call itself is the check.
  verifier.verify_function(built)
def test_bvalue_methods(self):
  """Exercises str() and get_type() on the value returned by add_param."""
  # The main purpose here is to confirm that pybind11 maps parameter and
  # return types correctly, so the methods being callable without throwing
  # is what matters most.
  pkg = ir_package.Package('test_package')
  builder = function_builder.FunctionBuilder('test_function', pkg)
  param = builder.add_param('param_name', pkg.get_bits_type(32))

  rendered = str(param)
  self.assertIn('param_name', rendered)
  width = param.get_type().get_bit_count()
  self.assertEqual(32, width)
def test_pipeline_generator_with_clock_period(self):
  """Generates a pipelined module 'bar' with clock-period argument 100."""
  pkg = package.Package('pname')
  main_builder = function_builder.FunctionBuilder('main', pkg)
  main_builder.add_param('x', pkg.get_bits_type(32))
  main_builder.build()

  generated = pipeline_generator.generate_pipelined_module_with_clock_period(
      pkg, 100, 'bar')

  # The module header and the stage-0 / stage-1 names for `x` must appear.
  self.assertIn('module bar', generated.verilog_text)
  for stage_name in ('p0_x', 'p1_x'):
    self.assertIn(stage_name, generated.verilog_text)
def test_function_type(self):
  """Checks the FunctionType reported by a built function."""
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder('f_name', pkg)
  fb.add_param('x', pkg.get_bits_type(32))
  fb.add_literal_value(ir_value.Value(bits.UBits(7, 8)))
  built = fb.build()

  signature = built.get_type()
  self.assertIsInstance(signature, ir_type.FunctionType)
  self.assertIn('bits[32]', str(signature))

  # Return type comes from the 8-bit literal; the only parameter is 32 bits.
  return_width = signature.return_type().get_bit_count()
  self.assertEqual(8, return_width)
  self.assertEqual(1, signature.get_parameter_count())
  first_param_width = signature.get_parameter_type(0).get_bit_count()
  self.assertEqual(32, first_param_width)
def test_package_methods(self):
  """Checks the return types of several Package methods."""
  pkg = package.Package('pkg')
  builder = function_builder.FunctionBuilder('f', pkg)
  builder.add_literal_value(ir_value.Value(bits.UBits(7, 8)))
  builder.build()

  self.assertIn('pkg', pkg.dump_ir())

  # Type factories.
  bits_type = pkg.get_bits_type(4)
  self.assertIsInstance(bits_type, ir_type.BitsType)
  array_type = pkg.get_array_type(4, pkg.get_bits_type(4))
  self.assertIsInstance(array_type, ir_type.ArrayType)
  tuple_type = pkg.get_tuple_type([pkg.get_bits_type(4)])
  self.assertIsInstance(tuple_type, ir_type.TupleType)

  # File numbers and function lookup.
  file_number = pkg.get_or_create_fileno('file')
  self.assertIsInstance(file_number, fileno.Fileno)
  looked_up = pkg.get_function('f')
  self.assertIsInstance(looked_up, function.Function)
  self.assertEqual(['f'], pkg.get_function_names())
def test_pipeline_generator_with_n_stages(self):
  """Generates a pipelined module 'foo' with a stage-count argument of 5."""
  pkg = package.Package('pname')
  main_builder = function_builder.FunctionBuilder('main', pkg)
  main_builder.add_param('x', pkg.get_bits_type(32))
  main_builder.build()

  generated = pipeline_generator.generate_pipelined_module_with_n_stages(
      pkg, 5, 'foo')

  self.assertIn('module foo', generated.verilog_text)
  # Expect a per-stage name for `x` for each of p0 through p4.
  for stage in range(5):
    self.assertIn(f'p{stage}_x', generated.verilog_text)
def test_literal_array(self):
  """Checks the IR dumped for a function returning a two-element array literal."""
  pkg = ir_package.Package('test_package')
  builder = function_builder.FunctionBuilder('f', pkg)

  elements = [
      ir_value.Value(bits_mod.UBits(value=5, bit_count=32)),
      ir_value.Value(bits_mod.UBits(value=6, bit_count=32)),
  ]
  builder.add_literal_value(ir_value.Value.make_array(elements))
  builder.build()

  expected_ir = """\
package test_package

fn f() -> bits[32][2] {
  ret literal.1: bits[32][2] = literal(value=[5, 6], id=1)
}
"""
  self.assertMultiLineEqual(pkg.dump_ir(), expected_ir)
def _visit_Function(
    self, node: ast.Function,
    symbolic_bindings: Optional[SymbolicBindings]) -> ir_function.Function:
  """Converts the function AST `node` into an IR function in `self.package`.

  Args:
    node: The function AST node to convert.
    symbolic_bindings: Bindings for the function's parametric names, or None.
      They are copied into `self.symbolic_bindings` as a dict (an empty dict
      when None) and are also passed to `mangle_dslx_name` to name the
      function.

  Returns:
    The built function, after it has been handed to
    `verifier_mod.verify_function`.

  Raises:
    AssertionError: If a parametric binding does not resolve to an `int`.
  """
  self.symbolic_bindings = {} if symbolic_bindings is None else dict(
      symbolic_bindings)
  self._extract_module_level_constants(self.module)
  # We use a function builder for the duration of converting this
  # ast.Function. When it's done being built, we drop the reference to it (by
  # setting self.fb to None).
  self.fb = function_builder.FunctionBuilder(
      mangle_dslx_name(node.name.identifier, node.get_free_parametric_keys(),
                       self.module, symbolic_bindings), self.package)
  try:
    # Visit the declared parameters first.
    for param in node.params:
      param.accept(self)

    # Define each parametric binding as a constant and alias it to its name.
    for parametric_binding in node.parametric_bindings:
      logging.vlog(4, 'Resolving parametric binding %s', parametric_binding)

      sb_value = self.symbolic_bindings[parametric_binding.name.identifier]
      value = self._resolve_dim(sb_value)
      assert isinstance(value, int), \
          'Expect integral parametric binding; got {!r}'.format(value)
      # The third argument is the total bit count of the binding's resolved
      # type.
      self._def_const(
          parametric_binding, value,
          self._resolve_type(parametric_binding.type_).get_total_bit_count())
      self._def_alias(parametric_binding, to=parametric_binding.name)

    # Visit the accumulated constant dependencies, then empty the list in
    # place.
    for dep in self._constant_deps:
      dep.accept(self)
    del self._constant_deps[:]

    node.body.accept(self)

    # When the final expression is a bare name reference, define it via an
    # identity node over the referenced value.
    # NOTE(review): presumably so the function's last-added node is the
    # intended return value -- confirm against the builder's semantics.
    last_expression = self.last_expression or node.body
    if isinstance(last_expression, ast.NameRef):
      self._def(last_expression, self.fb.add_identity,
                self._use(last_expression))
    f = self.fb.build()
    logging.vlog(3, 'Built function: %s', f.name)
    verifier_mod.verify_function(f)
    return f
  finally:
    # Drop the builder reference whether or not conversion succeeded.
    self.fb = None
def _def_map_with_builtin(self, parent_node: ast.Invocation, node: ast.NameRef,
                          arg: ast.AstNode,
                          symbolic_bindings: SymbolicBindings) -> BValue:
  """Defines `parent_node` as a map of a unary builtin over `arg`.

  If the package has no function under the builtin's mangled name yet, a
  one-parameter wrapper function that applies the builtin to an array element
  is built first.

  Args:
    parent_node: The invocation node that the resulting map is defined for.
    node: Name reference to the builtin being mapped.
    arg: AST node whose IR value is the map operand.
    symbolic_bindings: Bindings used when mangling the builtin's name.

  Returns:
    The result of defining `parent_node` via `self.fb.add_map`.
  """
  mangled_name = mangle_dslx_name(node.name_def.identifier, set(), self.module,
                                  symbolic_bindings)
  operand = self._use(arg)

  wrapper_exists = mangled_name in self.package.get_function_names()
  if not wrapper_exists:
    wrapper_fb = function_builder.FunctionBuilder(mangled_name, self.package)
    element = wrapper_fb.add_param('arg',
                                   operand.get_type().get_element_type())
    builtin_name = node.name_def.identifier
    assert builtin_name in dslx_builtins.UNARY_BUILTIN_NAMES, dslx_builtins.UNARY_BUILTIN_NAMES
    # Dispatch table from builtin name to the builder method that emits it;
    # it must stay in sync with the set of known unary builtins.
    emitters = {'clz': wrapper_fb.add_clz, 'ctz': wrapper_fb.add_ctz}
    assert set(emitters.keys()) == dslx_builtins.UNARY_BUILTIN_NAMES, set(
        emitters.keys())
    emit = emitters[builtin_name]
    emit(element)
    wrapper_fb.build()

  return self._def(parent_node, self.fb.add_map, operand,
                   self.package.get_function(mangled_name))
def test_match_true(self):
  """Checks the IR emitted by add_match_true for two predicated cases."""
  pkg = ir_package.Package('test_package')
  builder = function_builder.FunctionBuilder('f', pkg)
  u1 = pkg.get_bits_type(1)
  u32 = pkg.get_bits_type(32)

  # Parameter order matters: it fixes the signature and ids in the dump.
  selector_x = builder.add_param('pred_x', u1)
  case_x = builder.add_param('x', u32)
  selector_y = builder.add_param('pred_y', u1)
  case_y = builder.add_param('y', u32)
  fallback = builder.add_param('default', u32)

  builder.add_match_true([selector_x, selector_y], [case_x, case_y], fallback)
  builder.build()

  expected_ir = """\
package test_package

fn f(pred_x: bits[1], x: bits[32], pred_y: bits[1], y: bits[32], default: bits[32]) -> bits[32] {
  concat.6: bits[2] = concat(pred_y, pred_x, id=6)
  one_hot.7: bits[3] = one_hot(concat.6, lsb_prio=true, id=7)
  ret one_hot_sel.8: bits[32] = one_hot_sel(one_hot.7, cases=[x, y, default], id=8)
}
"""
  self.assertMultiLineEqual(pkg.dump_ir(), expected_ir)
def test_all_add_methods(self):
  """Calls each FunctionBuilder add_* method once with a source location."""
  # The point of this test is to confirm that pybind11 can map the parameter
  # and return types of every method. The built result is therefore not
  # inspected; it is sufficient that none of the calls throw.
  pkg = ir_package.Package('test_package')
  file_number = pkg.get_or_create_fileno('my_file.x')
  line = fileno_mod.Lineno(42)
  column = fileno_mod.Colno(64)
  where = source_location.SourceLocation(file_number, line, column)
  builder = function_builder.FunctionBuilder('test_function', pkg)

  # A separate, already-built function to hand to the add_* methods that
  # take a function argument.
  callee_builder = function_builder.FunctionBuilder('fn', pkg)
  callee_builder.add_literal_value(ir_value.Value(bits_mod.UBits(7, 8)))
  callee = callee_builder.build()

  zero_bit = builder.add_literal_value(
      ir_value.Value(bits_mod.UBits(value=0, bit_count=1)))
  u32 = pkg.get_bits_type(32)
  arg = builder.add_param('x', u32)

  # Shifts, bitwise logic and arithmetic.
  builder.add_shra(arg, arg, loc=where)
  builder.add_shra(arg, arg, loc=where)
  builder.add_shrl(arg, arg, loc=where)
  builder.add_shll(arg, arg, loc=where)
  builder.add_or(arg, arg, loc=where)
  builder.add_nary_or([arg], loc=where)
  builder.add_xor(arg, arg, loc=where)
  builder.add_and(arg, arg, loc=where)
  builder.add_smul(arg, arg, loc=where)
  builder.add_umul(arg, arg, loc=where)
  builder.add_udiv(arg, arg, loc=where)
  builder.add_sub(arg, arg, loc=where)
  builder.add_add(arg, arg, loc=where)
  builder.add_concat([arg], loc=where)

  # Comparisons.
  builder.add_ule(arg, arg, loc=where)
  builder.add_ult(arg, arg, loc=where)
  builder.add_uge(arg, arg, loc=where)
  builder.add_ugt(arg, arg, loc=where)
  builder.add_sle(arg, arg, loc=where)
  builder.add_slt(arg, arg, loc=where)
  builder.add_sge(arg, arg, loc=where)
  builder.add_sgt(arg, arg, loc=where)
  builder.add_eq(arg, arg, loc=where)
  builder.add_ne(arg, arg, loc=where)

  # Unary operations and one-hot helpers.
  builder.add_neg(arg, loc=where)
  builder.add_not(arg, loc=where)
  builder.add_clz(arg, loc=where)
  builder.add_one_hot(arg, lsb_or_msb.LsbOrMsb.LSB, loc=where)
  builder.add_one_hot_sel(arg, [arg], loc=where)

  # Literals and selects.
  builder.add_literal_bits(bits_mod.UBits(value=2, bit_count=32), loc=where)
  five = ir_value.Value(bits_mod.UBits(value=5, bit_count=32))
  builder.add_literal_value(five, loc=where)
  builder.add_sel(arg, arg, arg, loc=where)
  builder.add_sel_multi(arg, [arg], arg, loc=where)
  builder.add_match_true([zero_bit], [arg], arg, loc=where)

  # Aggregates, calls into `callee`, and array operations.
  as_tuple = builder.add_tuple([arg], loc=where)
  builder.add_array([arg], u32, loc=where)
  builder.add_tuple_index(as_tuple, 0, loc=where)
  builder.add_counted_for(arg, 1, 1, callee, [arg], loc=where)
  map_input = builder.add_array([arg], u32, loc=where)
  builder.add_map(map_input, callee, loc=where)
  builder.add_invoke([arg], callee, loc=where)
  indexed_array = builder.add_array([arg], u32, loc=where)
  builder.add_array_index(indexed_array, arg, loc=where)
  reversed_input = builder.add_array([arg], u32, loc=where)
  builder.add_reverse(reversed_input, loc=where)

  # Identity, extension and slicing.
  builder.add_identity(arg, loc=where)
  builder.add_signext(arg, 10, loc=where)
  builder.add_zeroext(arg, 10, loc=where)
  builder.add_bit_slice(arg, 4, 2, loc=where)

  builder.build()
def build_function(name='function_name'):
  """Builds and returns a function called `name` in a new package 'pname'.

  The function has no parameters and contains a single literal node built
  from `bits.UBits(7, 8)`.
  """
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder(name, pkg)
  literal = value.Value(bits.UBits(7, 8))
  fb.add_literal_value(literal)
  return fb.build()
def visit_For(self, node: ast.For) -> None:
  """Converts a `for ... in range(0, CONST)` loop into a counted-for IR node.

  The loop body is converted by a separate `_IrConverterFb` into its own IR
  function whose parameters are, in order: the induction variable, the loop
  carry, and every free variable of the body that is neither a builtin name
  nor function-typed. The loop is then defined on `self.fb` via
  `add_counted_for` with a stride of 1, passing those free variables as
  invariant arguments.

  Args:
    node: The for-loop AST node to convert.

  Raises:
    ConversionError: If the iterable is not a call to `range`, does not have
      exactly two arguments, does not start at a constant zero, or has a
      non-constant upper bound.
  """
  node.init.accept(self)

  def query_const_range_call() -> int:
    """Returns trip count if this is a `for ... in range(0, CONST)` construct."""
    range_callee = (
        isinstance(node.iterable, ast.Invocation) and
        isinstance(node.iterable.callee, ast.NameRef) and
        node.iterable.callee.identifier == 'range')
    if not range_callee:
      raise ConversionError(
          'For-loop is of an unsupported form for IR conversion; only a '
          "'range(0, const)' call is supported, found non-range callee.",
          node.span)
    if len(node.iterable.args) != 2:
      raise ConversionError(
          'For-loop is of an unsupported form for IR conversion; only a '
          "'range(0, const)' call is supported, found inappropriate number "
          'of arguments.', node.span)
    if not self._is_constant_zero(node.iterable.args[0]):
      # This condition is about the start value, not the argument count, so
      # report it as such.
      raise ConversionError(
          'For-loop is of an unsupported form for IR conversion; only a '
          "'range(0, const)' call is supported, found non-zero start "
          'argument.', node.span)
    arg = node.iterable.args[1]
    # Visit the upper bound before asking whether it is a known constant.
    arg.accept(self)
    if not self._is_const(arg):
      raise ConversionError(
          'For-loop is of an unsupported form for IR conversion; only a '
          "'range(0, const)' call is supported, did not find a const value "
          f'for {arg} ({arg!r}).', node.span)
    return self._get_const(arg)

  # TODO(leary): We currently only support counted loops of the form:
  #
  #   for (i, ...): (u32, ...) in range(0, N) {
  #      ...
  #   }
  trip_count = query_const_range_call()

  logging.vlog(3, 'Converting for-loop @ %s', node.span)
  # The body gets its own converter and function builder, sharing this
  # converter's package, module, type info and symbolic bindings.
  body_converter = _IrConverterFb(
      self.package,
      self.module,
      self.type_info,
      emit_positions=self.emit_positions)
  body_converter.symbolic_bindings = dict(self.symbolic_bindings)
  body_fn_name = ('__' + self.fb.name + '_counted_for_{}_body').format(
      self._next_counted_for_ordinal()).replace('.', '_')
  body_converter.fb = function_builder.FunctionBuilder(
      body_fn_name, self.package)
  flat = node.names.flatten1()
  assert len(
      flat
  ) == 2, 'Expect an induction binding and loop carry binding; got {!r}'.format(
      flat)

  # Add the induction value.
  assert isinstance(
      flat[0], ast.NameDef
  ), 'Induction variable was not a NameDef: {0} ({0!r})'.format(flat[0])
  body_converter.node_to_ir[flat[0]] = body_converter.fb.add_param(
      flat[0].identifier.encode('utf-8'), self._resolve_type_to_ir(flat[0]))

  # Add the loop carry value.
  if isinstance(flat[1], ast.NameDef):
    body_converter.node_to_ir[flat[1]] = body_converter.fb.add_param(
        flat[1].identifier.encode('utf-8'), self._resolve_type_to_ir(flat[1]))
  else:
    # For tuple loop carries we have to destructure names on entry.
    carry_type = self._resolve_type_to_ir(flat[1])
    carry = body_converter.node_to_ir[flat[1]] = body_converter.fb.add_param(
        '__loop_carry', carry_type)
    body_converter._visit_matcher(  # pylint: disable=protected-access
        flat[1], (), carry, self._resolve_type(flat[1]))

  # Free variables are suffixes on the function parameters.
  freevars = node.body.get_free_variables(node.span.start)
  freevars = freevars.drop_defs(lambda x: isinstance(x, ast.BuiltinNameDef))
  for name_def in freevars.get_name_defs():
    type_ = self.type_info[name_def]
    # Function-typed free variables are not passed as parameters.
    if isinstance(type_, FunctionType):
      continue
    logging.vlog(3, 'Converting freevar name: %s', name_def)
    body_converter.node_to_ir[name_def] = body_converter.fb.add_param(
        name_def.identifier.encode('utf-8'),
        self._resolve_type_to_ir(name_def))

  node.body.accept(body_converter)
  body_function = body_converter.fb.build()
  logging.vlog(3, 'Converted body function: %s', body_function.name)

  stride = 1
  # Same filter and order as the parameter loop above, so that the arguments
  # line up with the body function's trailing parameters.
  invariant_args = tuple(
      self._use(name_def)
      for name_def in freevars.get_name_defs()
      if not isinstance(self.type_info[name_def], FunctionType))
  self._def(node, self.fb.add_counted_for, self._use(node.init), trip_count,
            stride, body_function, invariant_args)