def test_array_equivalence_fail(self):
  """compare_values must reject IR arrays that differ in size or in content."""

  def ir_u4(v):
    # Fresh 4-bit unsigned IR value each call.
    return ir_value.Value(ir_bits.UBits(value=v, bit_count=4))

  dslx_array = dslx_value.Value.make_array(
      tuple(dslx_value.Value.make_ubits(bit_count=4, value=v)
            for v in range(5)))
  ir_items = [ir_u4(v) for v in range(5)]

  # Case 1: one extra trailing element, so the sizes disagree (5 vs. 6).
  too_long = ir_value.Value.make_array(tuple(ir_items + [ir_u4(5)]))
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_array, too_long)

  # Case 2: same size, but element 2 holds 5 instead of 2.
  corrupted = list(ir_items)
  corrupted[2] = ir_u4(5)
  mismatched = ir_value.Value.make_array(corrupted)
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_array, mismatched)
def test_invoke_adder_2_plus_3_eq_5(self):
  """Invokes add_wrapper(2, 3) from main(), compares it to 5, checks the IR."""
  pkg = ir_package.Package('test_package')

  def u32_literal(builder, n):
    # Adds a 32-bit unsigned literal node to `builder`.
    return builder.add_literal_bits(bits_mod.UBits(value=n, bit_count=32))

  # Callee: add_wrapper(x, y) returns x + y.
  adder_fb = function_builder.FunctionBuilder('add_wrapper', pkg)
  u32 = pkg.get_bits_type(32)
  lhs = adder_fb.add_param('x', u32)
  rhs = adder_fb.add_param('y', u32)
  adder_fb.add_add(lhs, rhs)
  adder = adder_fb.build()

  # Caller: main() returns add_wrapper(2, 3) == 5. The node ids in the
  # expected dump below follow this exact creation order.
  main_fb = function_builder.FunctionBuilder('main', pkg)
  two = u32_literal(main_fb, 2)
  three = u32_literal(main_fb, 3)
  total = main_fb.add_invoke([two, three], adder)
  five = u32_literal(main_fb, 5)
  main_fb.add_eq(total, five)
  main_fb.build()

  expected_ir = """\
package test_package

fn add_wrapper(x: bits[32], y: bits[32]) -> bits[32] {
  ret add.3: bits[32] = add(x, y, id=3)
}

fn main() -> bits[1] {
  literal.4: bits[32] = literal(value=2, id=4)
  literal.5: bits[32] = literal(value=3, id=5)
  invoke.6: bits[32] = invoke(literal.4, literal.5, to_apply=add_wrapper, id=6)
  literal.7: bits[32] = literal(value=5, id=7)
  ret eq.8: bits[1] = eq(invoke.6, literal.7, id=8)
}
"""
  self.assertMultiLineEqual(pkg.dump_ir(), expected_ir)
def test_tuple_equivalence_fail(self):
  """compare_values must reject IR tuples that differ in arity or in a member."""

  def ir_u4(v):
    # Fresh 4-bit unsigned IR value each call.
    return ir_value.Value(ir_bits.UBits(value=v, bit_count=4))

  dslx_tuple = dslx_value.Value.make_tuple(
      tuple(dslx_value.Value.make_ubits(bit_count=4, value=v)
            for v in range(5)))
  ir_items = [ir_u4(v) for v in range(5)]

  # Case 1: an extra trailing member, so the arities disagree (5 vs. 6).
  too_long = ir_value.Value.make_tuple(ir_items + [ir_u4(5)])
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_tuple, too_long)

  # Case 2: same arity, but member 2 holds 5 instead of 2.
  corrupted = list(ir_items)
  corrupted[2] = ir_u4(5)
  mismatched = ir_value.Value.make_tuple(corrupted)
  with self.assertRaises(jit_comparison.JitMiscompareError):
    jit_comparison.compare_values(dslx_tuple, mismatched)
def _visit_matcher(self, matcher: ast.NameDefTree, index: Tuple[int, ...],
                   matched_value: BValue,
                   matched_type: ConcreteType) -> BValue:
  """Lowers one match pattern to a 1-bit IR predicate "does this match?".

  Leaf patterns are handled directly:
    * wildcard: always matches (literal 1);
    * number / enum reference: equality against the lowered constant;
    * reference to an existing name: equality against that name's value;
    * fresh name definition: always matches, and the name is bound to
      `matched_value`.
  Non-leaf patterns are treated as tuples: each element is extracted with a
  tuple_index and matched recursively, and the per-element predicates are
  AND-ed together starting from literal 1.

  Args:
    matcher: The pattern tree node being lowered.
    index: Path of tuple indices from the root pattern to `matcher`. Here it is
      only extended and passed down in the recursion.
    matched_value: IR value the pattern is matched against.
    matched_type: Concrete type of `matched_value`. For tuple patterns its
      get_unnamed_members() supplies the element types.

  Returns:
    The BValue holding the match predicate for `matcher`.
  """
  if matcher.is_leaf():
    leaf = matcher.get_leaf()
    if isinstance(leaf, ast.WildcardPattern):
      # A wildcard matches anything: the predicate is the constant 1.
      return self._def(matcher, self.fb.add_literal_bits, bits_mod.UBits(1, 1))
    elif isinstance(leaf, (ast.Number, ast.EnumRef)):
      # Lower the constant first so that _use(leaf) has an IR node to return.
      leaf.accept(self)
      return self._def(matcher, self.fb.add_eq, self._use(leaf), matched_value)
    elif isinstance(leaf, ast.NameRef):
      # Compare against the value already bound to the referenced name, then
      # make the NameRef node an alias of that definition.
      result = self._def(matcher, self.fb.add_eq, self._use(leaf.name_def),
                         matched_value)
      self._def_alias(leaf.name_def, to=leaf)
      return result
    else:
      assert isinstance(
          leaf, ast.NameDef
      ), 'Expected leaf to be wildcard, number, or name; got: {!r}'.format(
          leaf)
      # A fresh binding always matches: the predicate is literal 1.
      ok = self._def(leaf, self.fb.add_literal_bits, bits_mod.UBits(1, 1))
      # Then rebind both the matcher and the new name to the matched value.
      # This overwrites the entry _def just made for `leaf`, presumably so that
      # later uses of the name resolve to the matched value.
      self.node_to_ir[matcher] = self.node_to_ir[leaf] = matched_value
      return ok
  else:
    # Tuple pattern: start from "true" and AND in each element's predicate.
    ok = self.fb.add_literal_bits(bits_mod.UBits(value=1, bit_count=1))
    # NOTE(review): zip() truncates silently if the pattern and type arity
    # differ. Presumably type checking rules that out; confirm.
    for i, (element, element_type) in enumerate(
        zip(matcher.tree, matched_type.get_unnamed_members())):  # pytype: disable=attribute-error
      # Extract the element.
      member = self.fb.add_tuple_index(matched_value, i)
      cond = self._visit_matcher(element, index + (i,), member, element_type)
      ok = self.fb.add_and(ok, cond)
    return ok
def _visit_concat(self, node: ast.Binop):
  """Lowers a concat binop to IR.

  A bits-typed result maps directly onto an IR concat of the two operands.
  An array-typed result is built by indexing out every element of the lhs and
  then the rhs, and forming a new IR array from them. This is needed because
  the IR has no array_concat op (see https://github.com/google/xls/issues/72).
  """
  output_type = self._resolve_type(node)
  lhs, rhs = self._use(node.lhs), self._use(node.rhs)

  if isinstance(output_type, BitsType):
    self._def(node, self.fb.add_concat, (lhs, rhs))
    return

  assert isinstance(output_type, ArrayType), output_type
  element_type = output_type.get_element_type()
  # Operands paired with their resolved array types, in lhs-then-rhs order.
  operands = (
      (lhs, self._resolve_type(node.lhs)),
      (rhs, self._resolve_type(node.rhs)),
  )

  flattened = []
  for source, source_type in operands:
    for position in range(source_type.size):
      # Array indices are emitted as 32-bit literals.
      idx = self.fb.add_literal_bits(
          bits_mod.UBits(value=position, bit_count=32))
      flattened.append(self.fb.add_array_index(source, idx))

  self._def(node, self.fb.add_array, flattened,
            self._type_to_ir(element_type))
def test_bits_equivalence_fail(self):
  """compare_values rejects IR bits that differ in width or in value."""
  reference = dslx_value.Value.make_ubits(bit_count=4, value=4)
  mismatches = (
      # Same value, wrong width (5 bits instead of 4).
      ir_value.Value(ir_bits.UBits(value=4, bit_count=5)),
      # Same width, wrong value (3 instead of 4).
      ir_value.Value(ir_bits.UBits(value=3, bit_count=4)),
  )
  for candidate in mismatches:
    with self.assertRaises(jit_comparison.JitMiscompareError):
      jit_comparison.compare_values(reference, candidate)
def _int_to_bits(value: int, bit_count: int) -> bits_mod.Bits:
  """Builds a Bits of width `bit_count` from an unbounded Python int.

  Widths of at most 64 bits use the direct constructors: SBits for negative
  values and UBits otherwise. Wider values go through a hex-string round
  trip, presumably because the direct constructors are limited to 64 bits.
  """
  if bit_count > 64:
    hex_text = bit_helpers.to_hex_string(value, bit_count)
    return number_parser.bits_from_string(hex_text, bit_count=bit_count)
  if value < 0:
    return bits_mod.SBits(value, bit_count)
  return bits_mod.UBits(value, bit_count)
def test_standard_pipeline(self):
  """Runs the standard pass pipeline on a single-literal function.

  The test expects a falsy result, presumably meaning "nothing changed";
  confirm against run_standard_pass_pipeline's contract.
  """
  pkg = package.Package('pname')
  builder = function_builder.FunctionBuilder('main', pkg)
  builder.add_literal_bits(bits_mod.UBits(value=2, bit_count=32))
  builder.build()
  result = standard_pipeline.run_standard_pass_pipeline(pkg)
  self.assertFalse(result)
def test_verify_function(self):
  """verify_function must not raise for a bits[32] param / bits[8] literal fn."""
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder('f_name', pkg)
  fb.add_param('x', pkg.get_bits_type(32))
  seven = ir_value.Value(bits.UBits(7, 8))
  fb.add_literal_value(seven)
  verifier.verify_function(fb.build())
def test_literal_array(self):
  """A literal array of two u32 values dumps as one bits[32][2] literal."""
  pkg = ir_package.Package('test_package')
  builder = function_builder.FunctionBuilder('f', pkg)
  elements = [
      ir_value.Value(bits_mod.UBits(value=n, bit_count=32)) for n in (5, 6)
  ]
  builder.add_literal_value(ir_value.Value.make_array(elements))
  builder.build()

  expected_ir = """\
package test_package

fn f() -> bits[32][2] {
  ret literal.1: bits[32][2] = literal(value=[5, 6], id=1)
}
"""
  self.assertMultiLineEqual(pkg.dump_ir(), expected_ir)
def _cast_from_array(self, node: ast.Cast, output_type: ConcreteType) -> None:
  """Lowers a cast whose operand is an array into an IR concat of its elements.

  Each element is extracted with array_index, using 32-bit literal indices in
  ascending order, and the element nodes are concatenated in that order.
  `output_type` is not consulted here.
  """
  source = self._use(node.expr)
  source_ir_type = self._resolve_type_to_ir(node.expr)
  elements = [
      self.fb.add_array_index(source,
                              self.fb.add_literal_bits(bits_mod.UBits(k, 32)))
      for k in range(source_ir_type.get_size())
  ]
  self._def(node, self.fb.add_concat, elements)
def int_to_bits(value: int, bit_count: int) -> ir_bits.Bits:
  """Builds an IR Bits of width `bit_count` from an unbounded Python int.

  Widths up to WORD_SIZE use the direct constructors: SBits for negative
  values and UBits otherwise. Wider widths are parsed back from the value's
  hex-string rendering.
  """
  if bit_count > WORD_SIZE:
    return number_parser.bits_from_string(
        bit_helpers.to_hex_string(value, bit_count), bit_count=bit_count)
  constructor = ir_bits.SBits if value < 0 else ir_bits.UBits
  return constructor(value, bit_count)
def _visit_one_hot_sel(self, node: ast.Invocation,
                       args: Tuple[BValue, ...]) -> BValue:
  """Lowers a one_hot_sel invocation to an IR one_hot_sel node.

  `args` is unpacked as (lhs, array). lhs is presumably the selector; confirm.
  The array operand is split into its individual elements, because
  add_one_hot_sel takes a list of cases.
  """
  selector, cases_array = args
  cases = [
      self.fb.add_array_index(cases_array,
                              self.fb.add_literal_bits(bits_mod.UBits(k, 32)))
      for k in range(cases_array.get_type().get_size())
  ]
  return self._def(node, self.fb.add_one_hot_sel, selector, cases)
def test_function_type(self):
  """Checks the FunctionType of a (bits[32]) -> bits[8] function."""
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder('f_name', pkg)
  fb.add_param('x', pkg.get_bits_type(32))
  fb.add_literal_value(ir_value.Value(bits.UBits(7, 8)))
  signature = fb.build().get_type()

  self.assertIsInstance(signature, ir_type.FunctionType)
  # The parameter type should appear in the textual form.
  self.assertIn('bits[32]', str(signature))
  self.assertEqual(8, signature.return_type().get_bit_count())
  self.assertEqual(1, signature.get_parameter_count())
  self.assertEqual(32, signature.get_parameter_type(0).get_bit_count())
def test_package_methods(self):
  """Smoke-tests Package type factories, fileno creation and fn lookup."""
  pkg = package.Package('pkg')
  fb = function_builder.FunctionBuilder('f', pkg)
  fb.add_literal_value(ir_value.Value(bits.UBits(7, 8)))
  fb.build()

  self.assertIn('pkg', pkg.dump_ir())

  # (object returned by a Package accessor, class it must be an instance of)
  expectations = (
      (pkg.get_bits_type(4), ir_type.BitsType),
      (pkg.get_array_type(4, pkg.get_bits_type(4)), ir_type.ArrayType),
      (pkg.get_tuple_type([pkg.get_bits_type(4)]), ir_type.TupleType),
      (pkg.get_or_create_fileno('file'), fileno.Fileno),
      (pkg.get_function('f'), function.Function),
  )
  for obj, expected_cls in expectations:
    self.assertIsInstance(obj, expected_cls)

  self.assertEqual(['f'], pkg.get_function_names())
def test_bits(self):
  """Checks to_uint()/to_int() readings of UBits and SBits values.

  to_int() reads the bit pattern as two's complement. For example,
  UBits(255, 8) reads back as -1.
  """
  cases = (
      # (expected, constructor, value, bit_count, accessor name)
      (43, bits.UBits, 43, 7, 'to_uint'),
      (53, bits.UBits, 53, 7, 'to_int'),
      (33, bits.SBits, 33, 8, 'to_uint'),
      (83, bits.SBits, 83, 8, 'to_int'),
      (255, bits.UBits, 255, 8, 'to_uint'),
      (-1, bits.UBits, 255, 8, 'to_int'),
      (-1, bits.UBits, 2**64 - 1, 64, 'to_int'),
      (-2**63, bits.SBits, -2**63, 64, 'to_int'),
      (-2**31, bits.SBits, -2**31, 32, 'to_int'),
      (2**31 - 1, bits.SBits, 2**31 - 1, 32, 'to_int'),
      (-2**31, bits.UBits, 2**31, 32, 'to_int'),
      (-1, bits.SBits, -1, 1, 'to_int'),
      (-1, bits.SBits, -1, 8, 'to_int'),
      (-1, bits.SBits, -1, 63, 'to_int'),
      (-2, bits.SBits, -2, 64, 'to_int'),
      (-83, bits.SBits, -83, 8, 'to_int'),
  )
  for expected, make, raw, width, accessor in cases:
    self.assertEqual(expected, getattr(make(raw, width), accessor)())
def test_all_add_methods(self):
  """Smoke-tests that every FunctionBuilder.add_* binding can be called.

  The target is pybind11's parameter/return type mapping, so the built
  function is never inspected. The test passes as long as no call raises.
  Calls are issued in a fixed order; similar ops are grouped into loops.
  """
  pkg = ir_package.Package('test_package')
  file_number = pkg.get_or_create_fileno('my_file.x')
  loc = source_location.SourceLocation(file_number, fileno_mod.Lineno(42),
                                       fileno_mod.Colno(64))
  fb = function_builder.FunctionBuilder('test_function', pkg)

  # A trivial callee for the counted_for / map / invoke builders.
  callee_fb = function_builder.FunctionBuilder('fn', pkg)
  callee_fb.add_literal_value(ir_value.Value(bits_mod.UBits(7, 8)))
  callee = callee_fb.build()

  zero_bit = fb.add_literal_value(
      ir_value.Value(bits_mod.UBits(value=0, bit_count=1)))
  u32 = pkg.get_bits_type(32)
  x = fb.add_param('x', u32)

  # NOTE(review): add_shra is exercised twice, matching the previous call
  # sequence. The duplicate may have been meant to be a different op; confirm.
  for binop in (fb.add_shra, fb.add_shra, fb.add_shrl, fb.add_shll,
                fb.add_or):
    binop(x, x, loc=loc)
  fb.add_nary_or([x], loc=loc)
  for binop in (fb.add_xor, fb.add_and, fb.add_smul, fb.add_umul,
                fb.add_udiv, fb.add_sub, fb.add_add):
    binop(x, x, loc=loc)
  fb.add_concat([x], loc=loc)
  for compare in (fb.add_ule, fb.add_ult, fb.add_uge, fb.add_ugt,
                  fb.add_sle, fb.add_slt, fb.add_sge, fb.add_sgt,
                  fb.add_eq, fb.add_ne):
    compare(x, x, loc=loc)
  for unop in (fb.add_neg, fb.add_not, fb.add_clz):
    unop(x, loc=loc)

  fb.add_one_hot(x, lsb_or_msb.LsbOrMsb.LSB, loc=loc)
  fb.add_one_hot_sel(x, [x], loc=loc)
  fb.add_literal_bits(bits_mod.UBits(value=2, bit_count=32), loc=loc)
  fb.add_literal_value(
      ir_value.Value(bits_mod.UBits(value=5, bit_count=32)), loc=loc)
  fb.add_sel(x, x, x, loc=loc)
  fb.add_sel_multi(x, [x], x, loc=loc)
  fb.add_match_true([zero_bit], [x], x, loc=loc)

  tup = fb.add_tuple([x], loc=loc)
  fb.add_array([x], u32, loc=loc)
  fb.add_tuple_index(tup, 0, loc=loc)
  fb.add_counted_for(x, 1, 1, callee, [x], loc=loc)
  fb.add_map(fb.add_array([x], u32, loc=loc), callee, loc=loc)
  fb.add_invoke([x], callee, loc=loc)
  fb.add_array_index(fb.add_array([x], u32, loc=loc), x, loc=loc)
  fb.add_reverse(fb.add_array([x], u32, loc=loc), loc=loc)
  fb.add_identity(x, loc=loc)
  fb.add_signext(x, 10, loc=loc)
  fb.add_zeroext(x, 10, loc=loc)
  fb.add_bit_slice(x, 4, 2, loc=loc)
  fb.build()
def v1(n):
  """Returns an IR Value holding int(n) as a 1-bit unsigned Bits."""
  one_bit = bits_mod.UBits(value=int(n), bit_count=1)
  return ir_value.Value(one_bit)
def test_bits(self):
  """Checks to_uint()/to_int() readings of small UBits and SBits values."""
  cases = (
      # (expected, constructor, value, bit_count, accessor name)
      (43, bits.UBits, 43, 7, 'to_uint'),
      (53, bits.UBits, 53, 7, 'to_int'),
      (33, bits.SBits, 33, 8, 'to_uint'),
      (83, bits.SBits, 83, 8, 'to_int'),
      (-83, bits.SBits, -83, 8, 'to_int'),
  )
  for expected, make, raw, width, accessor in cases:
    self.assertEqual(expected, getattr(make(raw, width), accessor)())
def build_function(name='function_name'):
  """Builds a function named `name` whose body is the literal UBits(7, 8).

  The function lives in a fresh package named 'pname'. The built function
  object is returned.
  """
  pkg = package.Package('pname')
  fb = function_builder.FunctionBuilder(name, pkg)
  literal = value.Value(bits.UBits(7, 8))
  fb.add_literal_value(literal)
  return fb.build()