def test_class_methods(self):
  """Inline and out-of-line member functions are both translated and callable."""
  cpp_source = """
    class Rectangle {
      int width;
      int height;
      int getArea(){
        return width * height;
      }
    };
    void Rectangle::set_values(int w, int h) {
      width = w;
      height = h;
    }
    int Rectangle::get_width(){
      return width;
    }
    int Rectangle::get_height(){
      return height;
    }
    int test(int a, int b) {
      Rectangle rect;
      rect.width = 0;
      rect.height = 0;
      rect.set_values(a, b);
      return rect.getArea() + rect.get_height() + rect.get_width();
    }
  """
  fn = self.parse_and_get_function(cpp_source)
  kwargs = {
      "a": ir_value.Value(bits_mod.SBits(value=6, bit_count=32)),
      "b": ir_value.Value(bits_mod.SBits(value=3, bit_count=32)),
  }
  raw = ir_interpreter.run_function_kwargs(fn, kwargs)
  # The interpreter prints an unsigned magnitude; reinterpret it as int32.
  as_int32 = int(ctypes.c_int32(int(str(raw))).value)
  # area (6 * 3) + height (3) + width (6)
  self.assertEqual(27, as_int32)
def test_array_equivalence_fail(self):
  """compare_values rejects arrays that differ in length or in one element."""

  def ir_u4(v):
    # Builds a 4-bit unsigned IR value.
    return ir_value.Value(ir_bits.UBits(value=v, bit_count=4))

  dslx_array = dslx_value.Value.make_array(
      tuple(dslx_value.Value.make_ubits(bit_count=4, value=i)
            for i in range(5)))
  ir_elements = [ir_u4(i) for i in range(5)]

  # One extra trailing element.
  too_long = ir_value.Value.make_array(tuple(ir_elements + [ir_u4(5)]))
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_array, too_long)

  # Same length, but index 2 holds 5 instead of 2.
  swapped = list(ir_elements)
  swapped[2] = ir_u4(5)
  wrong_element = ir_value.Value.make_array(swapped)
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_array, wrong_element)
def test_tuple_equivalence_fail(self):
  """compare_values rejects tuples that differ in arity or in one member."""

  def ir_u4(v):
    # Builds a 4-bit unsigned IR value.
    return ir_value.Value(ir_bits.UBits(value=v, bit_count=4))

  dslx_tuple = dslx_value.Value.make_tuple(
      tuple(dslx_value.Value.make_ubits(bit_count=4, value=i)
            for i in range(5)))
  ir_members = [ir_u4(i) for i in range(5)]

  # One extra trailing member.
  too_many = ir_value.Value.make_tuple(ir_members + [ir_u4(5)])
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_tuple, too_many)

  # Same arity, but member 2 holds 5 instead of 2.
  swapped = list(ir_members)
  swapped[2] = ir_u4(5)
  wrong_member = ir_value.Value.make_tuple(swapped)
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_tuple, wrong_member)
def test_mandelbrot(self):
  """Checks the translated `mandelbrot` IR against the natively built testbench.

  The C++ test binary prints its per-pixel results as a Python literal. The
  same grid of inputs is run through the IR interpreter, and the two result
  grids must be equal.
  """
  import subprocess  # pylint: disable=g-import-not-at-top

  translator = xlscc_translator.Translator("mypackage")
  my_parser = ext_c_parser.XLSccParser()
  source_path = runfiles.get_path(
      "xls/contrib/xlscc_obsolete/translate/testdata/mandelbrot_test.cc")
  binary_path = runfiles.get_path(
      "xls/contrib/xlscc_obsolete/translate/mandelbrot_test")
  nx = 48
  ny = 32

  with open(source_path) as source_file:
    content = source_file.read()

  # Hackily do the preprocessing, since in build environments we do not have a
  # cpp binary available (and if we did relying on it here without a
  # build-system-noted dependency would be non-hermetic).
  content = re.sub("^#if !NO_TESTBENCH$.*^#endif$", "", content, 0,
                   re.MULTILINE | re.DOTALL)
  content = re.sub(re.compile("//.*?\n"), "", content)
  temp_file = self.create_tempfile(content=content)
  c_ast = pycparser.parse_file(
      temp_file.full_path, use_cpp=False, parser=my_parser)

  # Run the native testbench directly (no shell) and fail loudly if it exits
  # with a non-zero status instead of silently comparing partial output.
  cpp_out = subprocess.run(
      [binary_path], stdout=subprocess.PIPE, check=True, text=True).stdout
  # The binary is a build-produced artifact whose output is a Python literal;
  # eval is tolerable only because that input is trusted.
  parsed_cpp_out = eval(cpp_out)  # pylint: disable=eval-used

  translator.parse(c_ast)
  p = translator.gen_ir()
  mandelbrot_fn = p.get_function("mandelbrot")

  # Fill a fresh grid rather than writing into parsed_cpp_out: aliasing it
  # would make the final assertEqual compare an object with itself and pass
  # unconditionally.
  result_arr = [[None] * nx for _ in range(ny)]
  for y in range(ny):
    for x in range(nx):
      xx = float(x) / nx
      yy = float(y) / ny
      # NOTE(review): int() truncates before scaling, so xi/yi only take whole
      # multiples of 1 << 16 -- confirm this mirrors the C++ testbench inputs.
      xi = int(xx * 2.5 - 1.8) * (1 << 16)
      yi = int(yy * 2.2 - 1.1) * (1 << 16)
      args = dict(
          c_r=ir_value.Value(bits_mod.SBits(value=int(xi), bit_count=32)),
          c_i=ir_value.Value(bits_mod.SBits(value=int(yi), bit_count=32)))
      result = ir_interpreter.run_function_kwargs(mandelbrot_fn, args)
      result_arr[y][x] = int(str(result))

  self.assertEqual(parsed_cpp_out, result_arr)
def one_in_one_out(self, source, a_input, b_input, expected_output):
  """Runs `source` as the body of `int test(sai32 a, sai32 b)` and checks it.

  Args:
    source: C++ statements forming the body of `test`.
    a_input: integer passed as `a` (32-bit signed).
    b_input: integer passed as `b` (32-bit signed).
    expected_output: expected return value, compared as a signed 32-bit int.
  """
  program = """
    int test(sai32 a, sai32 b) {
    """ + source + """
    }
  """
  fn = self.parse_and_get_function(program)
  kwargs = {
      "a": ir_value.Value(bits_mod.SBits(value=int(a_input), bit_count=32)),
      "b": ir_value.Value(bits_mod.SBits(value=int(b_input), bit_count=32)),
  }
  raw = ir_interpreter.run_function_kwargs(fn, kwargs)
  signed = int(ctypes.c_int32(int(str(raw))).value)
  self.assertEqual(expected_output, signed)
def test_bits_equivalence_fail(self):
  """compare_values rejects bits that differ in width or in value."""
  expected = dslx_value.Value.make_ubits(bit_count=4, value=4)
  mismatches = (
      # Right value, wrong width.
      ir_value.Value(ir_bits.UBits(value=4, bit_count=5)),
      # Right width, wrong value.
      ir_value.Value(ir_bits.UBits(value=3, bit_count=4)),
  )
  for actual in mismatches:
    with self.assertRaises(jit_comparison.JitMiscompareError):
      jit_comparison.compare_values(expected, actual)
def test_enum_autocount(self):
  """Enumerators without initializers continue counting from the previous one."""
  cpp_source = """
    enum states{w, x, y=5, z};
    int test(int a, int b){
      return a+z+b+x;
    }
  """
  fn = self.parse_and_get_function(cpp_source)
  kwargs = dict(
      a=ir_value.Value(bits_mod.SBits(value=2, bit_count=32)),
      b=ir_value.Value(bits_mod.SBits(value=3, bit_count=32)))
  raw = ir_interpreter.run_function_kwargs(fn, kwargs)
  # With x == 1 and z == 6: 2 + 6 + 3 + 1 == 12.
  self.assertEqual(12, int(ctypes.c_int32(int(str(raw))).value))
def test_literal_array(self):
  """An array literal is emitted as a single `literal` node in the IR dump."""
  pkg = ir_package.Package('test_package')
  builder = function_builder.FunctionBuilder('f', pkg)
  elements = [
      ir_value.Value(bits_mod.UBits(value=n, bit_count=32)) for n in (5, 6)
  ]
  builder.add_literal_value(ir_value.Value.make_array(elements))
  builder.build()
  expected_ir = (
      'package test_package\n'
      '\n'
      'fn f() -> bits[32][2] {\n'
      '  ret literal.1: bits[32][2] = literal(value=[5, 6], id=1)\n'
      '}\n')
  self.assertMultiLineEqual(pkg.dump_ir(), expected_ir)
def test_arrayref(self):
  """Writes through an array parameter are visible to the caller."""
  fn = self.parse_and_get_function("""
    void array_update(int o[2], int x) {
      o[0] = x;
    }
    int test(int a, int b) {
      int s[2] = {0,b};
      array_update(s, a);
      return s[0] + s[1];
    }
  """)
  kwargs = {
      name: ir_value.Value(bits_mod.SBits(value=v, bit_count=32))
      for name, v in (("a", 22), ("b", 10))
  }
  raw = ir_interpreter.run_function_kwargs(fn, kwargs)
  # s becomes {22, 10}.
  self.assertEqual(32, int(ctypes.c_int32(int(str(raw))).value))
def test_verify_function(self):
  """A function with one param and a literal return passes the verifier."""
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder('f_name', pkg)
  fb.add_param('x', pkg.get_bits_type(32))
  seven_u8 = ir_value.Value(bits.UBits(7, 8))
  fb.add_literal_value(seven_u8)
  built = fb.build()
  verifier.verify_function(built)
def test_structref(self):
  """Struct fields updated through a reference parameter are written back."""
  # Signed 18-bit integer type shared by both struct fields.
  int_type = hls_types_pb2.HLSIntType()
  int_type.signed = True
  int_type.width = 18
  sai18 = hls_types_pb2.HLSType()
  sai18.as_int.CopyFrom(int_type)

  somestruct = hls_types_pb2.HLSStructType()
  for field_name in ("x", "y"):
    named = hls_types_pb2.HLSNamedType()
    named.name = field_name
    named.hls_type.CopyFrom(sai18)
    somestruct.fields.add().CopyFrom(named)

  somestructtype = hls_types_pb2.HLSType()
  somestructtype.as_struct.CopyFrom(somestruct)

  fn = self.parse_and_get_function(
      """
    void struct_update(SomeStruct &o, sai18 x) {
      o.x.set_slc(0, x);
    }
    int test(sai32 a, int b) {
      SomeStruct s;
      s.x = 0;
      s.y = b;
      struct_update(s, a);
      return s.x + s.y;
    }
  """, {"SomeStruct": somestructtype})
  kwargs = {
      "a": ir_value.Value(bits_mod.SBits(value=22, bit_count=32)),
      "b": ir_value.Value(bits_mod.SBits(value=10, bit_count=32)),
  }
  raw = ir_interpreter.run_function_kwargs(fn, kwargs)
  self.assertEqual(32, int(ctypes.c_int32(int(str(raw))).value))
def test_ref_params(self):
  """Scalar reference parameters mutate the caller's locals."""
  fn = self.parse_and_get_function("""
    void add_in_place(int &o, int x) {
      o += x;
    }
    int add_in_place_2(int x, int &o) {
      o += x;
      return o*2;
    }
    int test(int a, int b){
      add_in_place(a, b);
      return add_in_place_2(b, a) + a;
    }
  """)
  kwargs = dict(
      a=ir_value.Value(bits_mod.SBits(value=11, bit_count=32)),
      b=ir_value.Value(bits_mod.SBits(value=3, bit_count=32)))
  raw = ir_interpreter.run_function_kwargs(fn, kwargs)
  # a: 11 -> 14 -> 17; add_in_place_2 returns 34; 34 + 17 == 51.
  self.assertEqual(51, int(ctypes.c_int32(int(str(raw))).value))
def test_globalconst(self):
  """A global const array can be indexed by a runtime argument."""
  fn = self.parse_and_get_function("""
    const int foo[6] = {2,4,5,3,2,1};
    int test(int a) {
      return foo[a];
    }
  """)
  index = ir_value.Value(bits_mod.SBits(value=2, bit_count=32))
  raw = ir_interpreter.run_function_kwargs(fn, {"a": index})
  # foo[2] == 5
  self.assertEqual(5, int(ctypes.c_int32(int(str(raw))).value))
def test_function_type(self):
  """get_type() reports the parameter and return types of a built function."""
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder('f_name', pkg)
  fb.add_param('x', pkg.get_bits_type(32))
  fb.add_literal_value(ir_value.Value(bits.UBits(7, 8)))
  built = fb.build()
  signature = built.get_type()

  self.assertIsInstance(signature, ir_type.FunctionType)
  self.assertIn('bits[32]', str(signature))
  # The literal is 8 bits wide and becomes the return value.
  self.assertEqual(8, signature.return_type().get_bit_count())
  self.assertEqual(1, signature.get_parameter_count())
  self.assertEqual(32, signature.get_parameter_type(0).get_bit_count())
def visit_ConstantArray(self, node: ast.ConstantArray) -> None:
  """Emits an IR array literal for a constant array expression.

  Each member is folded to an integer via `self._get_const` and encoded at
  the element type's bit width. If the node has a trailing ellipsis, the
  last member is repeated until the array reaches its declared size.
  """
  array_type = self._resolve_type(node)
  e_type = array_type.get_element_type()  # pytype: disable=attribute-error
  values = [
      ir_value.Value(
          _int_to_bits(self._get_const(member),
                       e_type.get_total_bit_count()))
      for member in node.members
  ]
  if node.has_ellipsis:
    # Pad by repeating the final member up to the declared array size.
    while len(values) < array_type.size:  # pytype: disable=attribute-error
      values.append(values[-1])
  array_literal = ir_value.Value.make_array(values)
  self._def(node, self.fb.add_literal_value, array_literal)
def test_package_methods(self):
  """Package accessors return objects of the expected wrapper types."""
  pkg = package.Package('pkg')
  fb = function_builder.FunctionBuilder('f', pkg)
  fb.add_literal_value(ir_value.Value(bits.UBits(7, 8)))
  fb.build()

  self.assertIn('pkg', pkg.dump_ir())
  bits4 = pkg.get_bits_type(4)
  self.assertIsInstance(bits4, ir_type.BitsType)
  self.assertIsInstance(pkg.get_array_type(4, bits4), ir_type.ArrayType)
  self.assertIsInstance(pkg.get_tuple_type([bits4]), ir_type.TupleType)
  self.assertIsInstance(pkg.get_or_create_fileno('file'), fileno.Fileno)
  self.assertIsInstance(pkg.get_function('f'), function.Function)
  self.assertEqual(['f'], pkg.get_function_names())
def convert_interpreter_value_to_ir(
    interpreter_value: dslx_value.Value) -> ir_value.Value:
  """Recursively translates a DSLX interpreter Value into an IR Value.

  Bits and enum values become IR bits values. Arrays and tuples are converted
  element by element, preserving order.

  Args:
    interpreter_value: the DSLX value to convert.

  Returns:
    The equivalent IR value.

  Raises:
    UnsupportedJitConversionError: for any other kind of value.
  """
  if interpreter_value.is_bits() or interpreter_value.is_enum():
    return ir_value.Value(interpreter_value.get_bits())

  if interpreter_value.is_array():
    converted = [
        convert_interpreter_value_to_ir(element)
        for element in interpreter_value.get_elements()
    ]
    return ir_value.Value.make_array(converted)

  if interpreter_value.is_tuple():
    converted = [
        convert_interpreter_value_to_ir(element)
        for element in interpreter_value.get_elements()
    ]
    return ir_value.Value.make_tuple(converted)

  raise UnsupportedJitConversionError(
      "Can't convert to JIT value: {}".format(interpreter_value))
def test_all_add_methods(self):
  """Smoke-tests every FunctionBuilder add_* binding.

  This is mainly about checking that pybind11 maps parameter and return types
  properly, so the built function's contents are not inspected: it is enough
  that none of the calls throw.
  """
  p = ir_package.Package('test_package')
  fileno = p.get_or_create_fileno('my_file.x')
  loc = source_location.SourceLocation(
      fileno, fileno_mod.Lineno(42), fileno_mod.Colno(64))

  fb = function_builder.FunctionBuilder('test_function', p)

  # A trivial callee used by counted_for / map / invoke.
  callee_builder = function_builder.FunctionBuilder('fn', p)
  callee_builder.add_literal_value(ir_value.Value(bits_mod.UBits(7, 8)))
  input_function = callee_builder.build()

  single_zero_bit = fb.add_literal_value(
      ir_value.Value(bits_mod.UBits(value=0, bit_count=1)))
  t = p.get_bits_type(32)
  x = fb.add_param('x', t)

  # The calls below are issued in the same order as a hand-written sequence.
  # NOTE(review): add_shra is exercised twice; one may have been meant to be
  # a different op -- confirm.
  for add_binary in (fb.add_shra, fb.add_shra, fb.add_shrl, fb.add_shll,
                     fb.add_or):
    add_binary(x, x, loc=loc)
  fb.add_nary_or([x], loc=loc)
  for add_binary in (fb.add_xor, fb.add_and, fb.add_smul, fb.add_umul,
                     fb.add_udiv, fb.add_sub, fb.add_add):
    add_binary(x, x, loc=loc)
  fb.add_concat([x], loc=loc)
  for add_compare in (fb.add_ule, fb.add_ult, fb.add_uge, fb.add_ugt,
                      fb.add_sle, fb.add_slt, fb.add_sge, fb.add_sgt,
                      fb.add_eq, fb.add_ne):
    add_compare(x, x, loc=loc)
  for add_unary in (fb.add_neg, fb.add_not, fb.add_clz):
    add_unary(x, loc=loc)

  fb.add_one_hot(x, lsb_or_msb.LsbOrMsb.LSB, loc=loc)
  fb.add_one_hot_sel(x, [x], loc=loc)
  fb.add_literal_bits(bits_mod.UBits(value=2, bit_count=32), loc=loc)
  fb.add_literal_value(
      ir_value.Value(bits_mod.UBits(value=5, bit_count=32)), loc=loc)
  fb.add_sel(x, x, x, loc=loc)
  fb.add_sel_multi(x, [x], x, loc=loc)
  fb.add_match_true([single_zero_bit], [x], x, loc=loc)

  tuple_node = fb.add_tuple([x], loc=loc)
  fb.add_array([x], t, loc=loc)
  fb.add_tuple_index(tuple_node, 0, loc=loc)
  fb.add_counted_for(x, 1, 1, input_function, [x], loc=loc)
  fb.add_map(fb.add_array([x], t, loc=loc), input_function, loc=loc)
  fb.add_invoke([x], input_function, loc=loc)
  fb.add_array_index(fb.add_array([x], t, loc=loc), x, loc=loc)
  fb.add_reverse(fb.add_array([x], t, loc=loc), loc=loc)
  fb.add_identity(x, loc=loc)
  fb.add_signext(x, 10, loc=loc)
  fb.add_zeroext(x, 10, loc=loc)
  fb.add_bit_slice(x, 4, 2, loc=loc)
  fb.build()
def v1(n):
  """Wraps `n` as a 1-bit unsigned IR value."""
  one_bit = bits_mod.UBits(bit_count=1, value=int(n))
  return ir_value.Value(one_bit)
def v32(n):
  """Wraps `n` as a 32-bit signed IR value."""
  word = bits_mod.SBits(bit_count=32, value=int(n))
  return ir_value.Value(word)
def build_function(name='function_name'):
  """Builds a parameterless function named `name` that returns the 8-bit literal 7.

  The function is created in a fresh package named 'pname'.
  """
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder(name, pkg)
  seven_u8 = value.Value(bits.UBits(7, 8))
  fb.add_literal_value(seven_u8)
  built = fb.build()
  return built