def __init__(self, arch_fc, *, path_constraints=None):
    """Build the architecture-side mapper and validate optional input constraints.

    Args:
        arch_fc: family closure of the architecture; forwarded to the
            base-class constructor.
        path_constraints: optional mapping from an input ADT leaf path to
            either a single allowed value or a tuple of allowed values.
            ``None`` (the default) means "no constraints".

    Raises:
        NotImplementedError: if the architecture has more than one output form.
        ValueError: if a constrained path is not an ADT leaf of the input.
        Exception: whatever the leaf's assembled-ADT constructor raises for a
            value it cannot convert; re-raised after printing the path.
    """
    super().__init__(arch_fc)
    if self.num_output_forms > 1:
        raise NotImplementedError("Multiple ir output forms")

    # A `None` default avoids sharing one mutable dict across calls.
    if path_constraints is None:
        path_constraints = {}

    # Normalise every entry to a tuple so single values and tuples are
    # handled uniformly below.
    path_constraints = {
        path: (c if isinstance(c, tuple) else (c, ))
        for path, c in path_constraints.items()
    }

    path_to_adt = self.path_to_adt(input=True,
                                   family=family.SMTFamily(),
                                   strip=True)

    # Build the validated mapping separately instead of mutating the dict
    # that is being iterated.
    validated = {}
    for path, constraints in path_constraints.items():
        if path not in path_to_adt:
            raise ValueError(
                f"{path} is either invalid or not an adt leaf")
        assert path in self.input_varmap
        adt = path_to_adt[path]
        aadt = family.SMTFamily().get_adt_t(adt)
        try:
            # Convert each raw value into the assembled-ADT type of this leaf.
            validated[path] = tuple(aadt(c) for c in constraints)
        except Exception:
            # Identify the offending path, then propagate the original error.
            print(f"Invalid constraints for {path}")
            raise
    self.path_constraints = validated
def verify(self, solver_name: str = "z3") -> tp.Union[None, "CounterExample"]:
    """Search for an input on which the IR and the architecture disagree.

    Every bound IR input path gets a fresh variable. Both the IR and the
    architecture are evaluated on those variables, and the solver is asked
    whether any output pair in ``self.obinding`` can differ.

    Args:
        solver_name: solver name passed to ``smt.Solver``.

    Returns:
        ``None`` when no distinguishing input exists; otherwise a dict mapping
        each bound IR input path to the value extracted from the solver.

    Raises:
        ValueError: if an output binding names a path missing from the parsed
            IR or architecture outputs.
    """
    stripped_input_t = strip_modifiers(self.ir_fc(family.SMTFamily()).input_t)
    leaf_types = _create_path_to_adt(stripped_input_t)
    free_vars = {p: _free_var_from_t(leaf_types[p]) for p in self.ir_bounded}

    ir_args = self.build_ir_input(free_vars, family.SMTFamily())
    arch_args = self.build_arch_input(free_vars, family.SMTFamily())

    ir_inst = self.ir_fc(family.SMTFamily())()
    arch_inst = self.arch_fc(family.SMTFamily())()
    ir_results = self.parse_ir_output(ir_inst(**ir_args))
    arch_results = self.parse_arch_output(arch_inst(**arch_args))

    # One "these two outputs differ" term per bound output pair.
    differences = []
    for ir_path, arch_path in self.obinding:
        for path, results in ((ir_path, ir_results),
                              (arch_path, arch_results)):
            if path not in results:
                raise ValueError(f"{path} is not valid")
        differences.append(ir_results[ir_path] != arch_results[arch_path])
    any_difference = or_reduce(differences)

    with smt.Solver(solver_name, logic=BV) as solver:
        solver.add_assertion(any_difference.value)
        if not solver.solve():
            # Unsatisfiable: no input makes any bound output pair differ.
            return None
        return {p: solved_to_bv(v, solver) for p, v in free_vars.items()}
def rebind_value(val, _family):
    """Re-create the concrete Bit/BitVector ``val`` inside ``_family``.

    Python-family values are converted directly. SMT-family values must be
    constant; their constant is extracted and rebuilt in the target family.

    Raises:
        ValueError: for a non-constant SMT value, or for any other type.
    """
    py_family = family.PyFamily()
    if isinstance(val, py_family.BitVector):
        return _family.BitVector[val.size](val.value)
    if isinstance(val, py_family.Bit):
        return _family.Bit(val)

    smt_family = family.SMTFamily()
    if isinstance(val, smt_family.BitVector):
        term = val._value_
        if not term.is_constant():
            raise ValueError("Cannot convert non-const SMT var to Py")
        return _family.BitVector[val.size](term.constant_value())
    if isinstance(val, smt_family.Bit):
        term = val._value_
        if not term.is_constant():
            raise ValueError("Cannot convert non-const SMT var to Py")
        return _family.Bit(term.constant_value())

    raise ValueError(f"Cannot rebind value: {val}")
def __init__(self, peak_fc: tp.Callable):
    """Enumerate the SMT input and output forms of the peak built by ``peak_fc``.

    The peak class is instantiated once per input form of its (modifier-
    stripped) input ADT. For each input form, the output forms of the result
    are enumerated.

    Sets:
        peak_fc, input_forms, input_varmap, num_input_forms,
        output_forms (indexed ``[input_form_idx][output_form_idx]``),
        num_output_forms, and the ``SBV`` selector variables
        input_form_var / output_form_var, whose widths equal the form counts.

    Raises:
        ValueError: if ``peak_fc`` is not a ``family_closure`` or the peak class
            lacks ``input_t`` / ``output_t``.
    """
    if not isinstance(peak_fc, family_closure):
        raise ValueError(
            f"family closure {peak_fc} needs to be decorated with @family_closure"
        )
    peak_cls = _get_peak_cls(peak_fc(family.SMTFamily()))
    try:
        input_t, output_t = peak_cls.input_t, peak_cls.output_t
    except AttributeError:
        raise ValueError("Need to use gen_input_t and gen_output_t")

    bare_input_t = strip_modifiers(input_t)
    bare_output_t = strip_modifiers(output_t)
    in_aadt = family.SMTFamily().get_adt_t(bare_input_t)
    out_aadt = family.SMTFamily().get_adt_t(bare_output_t)

    input_forms, input_varmap = SMTForms()(in_aadt)

    # Indexed as output_forms[input_form_idx][output_form_idx].
    output_forms = []
    for in_form in input_forms:
        kwargs = aadt_product_to_dict(in_form.value)
        raw_outputs = peak_cls()(**kwargs)
        wrapped = wrap_outputs(raw_outputs, out_aadt)
        forms_for_input, _unused_varmap = SMTForms()(out_aadt, value=wrapped)
        # Consistency check: each output form carries the value it came from.
        for out_form in forms_for_input:
            assert out_form.value == wrapped
        output_forms.append(forms_for_input)

    num_input_forms = len(output_forms)
    num_output_forms = len(output_forms[0])
    # Every input form must yield the same number of output forms.
    assert all(len(forms) == num_output_forms for forms in output_forms)

    self.peak_fc = peak_fc
    self.input_form_var = SBV[num_input_forms]()
    self.output_form_var = SBV[num_output_forms]()
    self.input_forms = input_forms
    self.output_forms = output_forms
    self.num_output_forms = num_output_forms
    self.num_input_forms = num_input_forms
    self.input_varmap = input_varmap
def _free_var_from_t(T):
    """Return a new no-argument instance of a bitvector type able to hold ``T``.

    ``SBV`` subclasses are instantiated directly. Any other type is mapped to
    its assembled ADT, and a bitvector of the assembler's width is created.
    """
    if not issubclass(T, SBV):
        adt_t, assembler_t, bv_t = family.SMTFamily().get_adt_t(T).fields
        return bv_t[assembler_t(adt_t).width]()
    return T()
def test_non_const_constraint():
    """Check constraints on the non-Const ``in0`` input of an identity IR.

    Both instruction fields are pinned to Const values. Each case then varies
    the constraint on ``in0`` and states the expected ``solved`` flag for
    ``run_constraint_test``.
    """
    @family_closure
    def ir_fc(family):
        Data = family.BitVector[8]

        @family.assemble(locals(), globals())
        class IR(Peak):
            @name_outputs(out=Data)
            def __call__(self, in0: Data):
                return in0

        return IR

    isa = ISA_fc(family.SMTFamily())
    OpT = isa.Op
    expectations = [
        (-5, True),
        (-4, False),
        ((-5, -4), False),
    ]
    for in0_constraint, solved in expectations:
        constraints = {
            ("inst", isa.ArithOp, 0): OpT.A,  # Const
            ("inst", isa.ArithOp, 1): 5,  # Const
        }
        constraints[("in0", )] = in0_constraint  # Not Const
        run_constraint_test(ir_fc, constraints=constraints, solved=solved)
def test_const_constraint():
    """Check constraints on single Const fields of the ``ArithOp`` instruction.

    The IR computes ``in0 + in1 + 4``. Each case constrains one instruction
    field and gives the expected ``solved`` flag for ``run_constraint_test``.
    """
    @family_closure
    def ir_fc(family):
        Data = family.BitVector[8]

        @family.assemble(locals(), globals())
        class IR(Peak):
            @name_outputs(out=Data)
            def __call__(self, in0: Data, in1: Data):
                return in0 + in1 + 4  # inst.offset should be 4

        return IR

    isa = ISA_fc(family.SMTFamily())

    def run_cases(field_idx, cases):
        # Constrain only ("inst", ArithOp, field_idx) in every case.
        for constraint, solved in cases:
            run_constraint_test(
                ir_fc,
                constraints={("inst", isa.ArithOp, field_idx): constraint},
                solved=solved,
            )

    # Field 1: per the IR's comment, this offset should be 4.
    run_cases(1, (
        (4, True),
        ((4, 5), True),
        (5, False),
        ((3, 5), False),
    ))

    # Field 0: the opcode.
    OpT = isa.Op
    run_cases(0, (
        (OpT.A, True),
        ((OpT.A, OpT.B), True),
        (OpT.B, False),
    ))
def test_assemble():
    """Check a one-gate AND peak in the Py, SMT and magma families."""
    @family_closure
    def PE_fc(family):
        Bit = family.Bit

        @family.assemble(locals(), globals())
        class PESimple(Peak, typecheck=True):
            def __call__(self, in0: Bit, in1: Bit) -> Bit:
                return in0 & in1

        return PESimple

    # Python family: exhaustive over concrete bits.
    PE_bv = PE_fc(family.PyFamily())
    py_bits = [Bit(0), Bit(1)]
    for i0, i1 in itertools.product(py_bits, repeat=2):
        got = PE_bv()(i0, i1)
        assert got == i0 & i1

    # SMT family: constants plus two argument-less SMTBit() values.
    PE_smt = PE_fc(family.SMTFamily())
    smt_bits = [SMTBit(0), SMTBit(1), SMTBit(), SMTBit()]
    for i0, i1 in itertools.product(smt_bits, repeat=2):
        got = PE_smt()(i0, i1)
        assert got == i0 & i1

    # Magma family: drive the generated circuit through fault + verilator.
    PE_magma = PE_fc(family.MagmaFamily())
    tester = fault.Tester(PE_magma)
    for i0, i1 in itertools.product([0, 1], repeat=2):
        tester.circuit.in0 = i0
        tester.circuit.in1 = i1
        tester.eval()
        tester.circuit.O.expect(i0 & i1)
    tester.compile_and_run("verilator", flags=["-Wno-fatal"])
def test_enum():
    """Check a Const(Enum)-selected AND/OR peak in the Py, SMT and magma families."""
    class Op(Enum):
        And = 1
        Or = 2

    @family_closure
    def PE_fc(family):
        Bit = family.Bit

        @family.assemble(locals(), globals())
        class PE_Enum(Peak):
            def __call__(self, op: Const(Op), in0: Bit, in1: Bit) -> Bit:
                if op == Op.And:
                    return in0 & in1
                else: #op == Op.Or
                    return in0 | in1

        return PE_Enum

    # verify BV works
    PE_bv = PE_fc(family.PyFamily())
    vals = [Bit(0), Bit(1)]
    for op in Op.enumerate():
        for i0, i1 in itertools.product(vals, vals):
            res = PE_bv()(op, i0, i1)
            gold = (i0 & i1) if (op is Op.And) else (i0 | i1)
            assert res == gold

    # verify SMT works
    PE_smt = PE_fc(family.SMTFamily())
    Op_aadt = AssembledADT[Op, Assembler, SMTBitVector]
    vals = [SMTBit(0), SMTBit(1), SMTBit(), SMTBit()]
    for op in Op.enumerate():
        # Keep `op` as the plain enum member for the reference model below.
        # The reference compares it with `is Op.And`, which would never match
        # an assembled SMT value, so only the PE receives the assembled form.
        op_smt = Op_aadt(op)
        for i0, i1 in itertools.product(vals, vals):
            res = PE_smt()(op_smt, i0, i1)
            gold = (i0 & i1) if (op is Op.And) else (i0 | i1)
            assert res == gold

    # verify magma works
    asm = Assembler(Op)
    PE_magma = PE_fc(family.MagmaFamily())
    tester = fault.Tester(PE_magma)
    vals = [0, 1]
    for op in (Op.And, Op.Or):
        for i0, i1 in itertools.product(vals, vals):
            gold = (i0 & i1) if (op is Op.And) else (i0 | i1)
            tester.circuit.op = int(asm.assemble(op))
            tester.circuit.in0 = i0
            tester.circuit.in1 = i1
            tester.eval()
            tester.circuit.O.expect(gold)
    tester.compile_and_run("verilator", flags=["-Wno-fatal"])
def __init__(self, archmapper, ir_fc): super().__init__(ir_fc) #For now assume that ir input forms and ir output forms is just 1 if self.num_input_forms > 1: raise NotImplementedError("Multiple ir input forms") if self.num_output_forms > 1: raise NotImplementedError("Multiple ir output forms") ir_input_form = self.input_forms[0] ir_output_form = self.output_forms[0][0] self.archmapper = archmapper arch_output_form = archmapper.output_forms[0] # Create input bindings # binding = [input_form_idx][bidx] input_bindings = [] arch_input_path_to_adt = archmapper.path_to_adt( input=True, family=family.SMTFamily()) #Removes any invalid bindings def constraint_filter(binding): for ir_path, arch_path in binding: if arch_path in archmapper.path_constraints and ir_path is not Unbound: return False return True ir_path_to_adt = self.path_to_adt(input=True, family=family.SMTFamily()) #Verify all paths are the same assert set(ir_path_to_adt.keys()) == set(self.input_varmap.keys()) for af in archmapper.input_forms: #Verify all paths of form is subset of all paths assert set(arch_input_path_to_adt.keys()).issuperset( set(af.varmap.keys())) form_arch_input_path_to_adt = { p: T for p, T in arch_input_path_to_adt.items() if p in af.varmap } bindings = create_bindings(form_arch_input_path_to_adt, ir_path_to_adt) bindings = list(filter(constraint_filter, bindings)) input_bindings.append(bindings) # Check Early out self.has_bindings = max(len(bs) for bs in input_bindings) > 0 if not self.has_bindings: return # Create output bindings arch_output_path_to_adt = archmapper.path_to_adt( input=False, family=family.SMTFamily()) ir_path_to_adt = self.path_to_adt(input=False, family=family.SMTFamily()) #binding = [bidx] output_bindings = create_bindings(arch_output_path_to_adt, ir_path_to_adt) # Check Early out self.has_bindings = len(output_bindings) > 0 if not self.has_bindings: return form_var = archmapper.input_form_var #Create the form_conditions (preconditions) based off of the arch_forms 
#[input_form_idx] form_conditions = [] for fi, form in enumerate(archmapper.input_forms): #form_condition represents the & of all the appropriate matche conditions = [form_var == 2**fi] for path, choice in form.path_dict.items(): match_path = path + (Match, ) assert match_path in archmapper.input_varmap conditions.append(archmapper.input_varmap[match_path][choice]) form_conditions.append(conditions) max_input_bindings = max(len(bindings) for bindings in input_bindings) ib_var = SBV[max_input_bindings]() max_output_bindings = len(output_bindings) ob_var = SBV[max_output_bindings]() constraints = [] #Build the constraint forall_vars = set() for fi, ibindings in enumerate(input_bindings): conditions = list(form_conditions[fi]) for bi, ibinding in enumerate(ibindings): bi_match = (ib_var == 2**bi) #Build substitution map submap = [] for ir_path, arch_path in ibinding: arch_var = archmapper.input_varmap[arch_path] is_unbound = ir_path is Unbound is_constrained = arch_path in self.archmapper.path_constraints is_const = issubclass(arch_input_path_to_adt[arch_path], Const) if is_constrained: assert is_unbound continue if is_unbound and not is_const: #add arch_var to list of forall vars forall_vars.add(arch_var.value) elif not is_unbound and not is_constrained: #substitue arch_var with ir_var add ir_var to forall list ir_var = self.input_varmap[ir_path] submap.append((arch_var, ir_var)) forall_vars.add(ir_var.value) for bo, obinding in enumerate(output_bindings): bo_match = (ob_var == 2**bo) conditions = list( form_conditions[fi]) + [bi_match, bo_match] for ir_path, arch_path in obinding: if ir_path is Unbound: continue ir_out = self.output_forms[0][0].varmap[ir_path] arch_out = archmapper.output_forms[fi][0].varmap[ arch_path] arch_out = arch_out.substitute(*submap) conditions.append(ir_out == arch_out) constraints.append(conditions) formula = or_reduce([and_reduce(conds) for conds in constraints]) # Adding in the constraints: # Start with the non-const constraints. 
# This is basically doing universal qualifier over a limited set of values # To do this, evaluate the formula with each possible value, # then 'and' together the resulting partially-evaluted formulas for path, constraints in archmapper.path_constraints.items(): is_const = issubclass(arch_input_path_to_adt[path], Const) if is_const: continue arch_var = archmapper.input_varmap[path] formula = and_reduce( (formula.substitute((arch_var, c)) for c in constraints)) # Then deal with the const constraints. # This is saying only allow a set of values for the existential quantifier # Thus just create an additional constraint that arch_var is either c1 or c2 or c3 for path, constraints in archmapper.path_constraints.items(): is_const = issubclass(arch_input_path_to_adt[path], Const) if not is_const: continue arch_var = archmapper.input_varmap[path] constraint = or_reduce((arch_var == c for c in constraints)) formula &= constraint self.ib_var = ib_var self.ob_var = ob_var self.input_bindings = input_bindings self.output_bindings = output_bindings self.formula = smt.ForAll(list(forall_vars), formula.value)
def parse_arch_output(self, outputs):
    """Map raw architecture outputs to ``{path: value}`` via ``SMTForms``.

    The outputs are wrapped into the architecture's assembled output ADT
    first; the forms list produced alongside the values is discarded.
    """
    arch_output_t = self.arch_fc(family.SMTFamily()).output_t
    aadt_t = family.SMTFamily().get_adt_t(arch_output_t)
    wrapped = wrap_outputs(outputs, aadt_t)
    _forms, path_values = SMTForms()(aadt_t, value=wrapped)
    return path_values
def check_families(PE_fc):
    """Call ``PE_fc`` with the Py, SMT and Magma families; the results are discarded."""
    for make_family in (f.PyFamily, f.SMTFamily, f.MagmaFamily):
        PE_fc(make_family())