def verify(self, solver_name: str = "z3") -> tp.Union[None, "CounterExample"]:
    """Check that the IR and the architecture agree on every bound output.

    A free SMT variable is created for each bounded IR input path. The
    variables are fed, through ``build_ir_input`` / ``build_arch_input``,
    into both the IR and the architecture models. The solver then looks for
    an assignment under which at least one output pair listed in
    ``self.obinding`` differs.

    :param solver_name: name of the pysmt solver backend to use.
    :returns: ``None`` if no such assignment exists (the mapping is
        verified). Otherwise a dict mapping each bounded IR path to the
        value ``solved_to_bv`` extracts for its variable from the model.
    :raises ValueError: if a path in ``self.obinding`` is missing from the
        parsed IR or architecture outputs.
    """
    # One free variable per bounded IR input path, typed from the IR input.
    path_types = _create_path_to_adt(
        strip_modifiers(self.ir_fc(family.SMTFamily()).input_t))
    free_vars = {}
    for path in self.ir_bounded:
        free_vars[path] = _free_var_from_t(path_types[path])

    ir_args = self.build_ir_input(free_vars, family.SMTFamily())
    arch_args = self.build_arch_input(free_vars, family.SMTFamily())

    ir_model = self.ir_fc(family.SMTFamily())()
    arch_model = self.arch_fc(family.SMTFamily())()
    ir_results = self.parse_ir_output(ir_model(**ir_args))
    arch_results = self.parse_arch_output(arch_model(**arch_args))

    # One "these outputs differ" predicate per bound output pair.
    mismatches = []
    for ir_path, arch_path in self.obinding:
        if ir_path not in ir_results:
            raise ValueError(f"{ir_path} is not valid")
        if arch_path not in arch_results:
            raise ValueError(f"{arch_path} is not valid")
        mismatches.append(ir_results[ir_path] != arch_results[arch_path])
    any_mismatch = or_reduce(mismatches)

    with smt.Solver(solver_name, logic=BV) as solver:
        solver.add_assertion(any_mismatch.value)
        if not solver.solve():
            # No distinguishing input exists: the outputs always agree.
            return None
        # Extract the counterexample while the solver is still open.
        return {
            path: solved_to_bv(var, solver)
            for path, var in free_vars.items()
        }
def learn(self, domain, data, initial_indices=None):
    """Learn a formula by alternating partial solving and active selection.

    The process starts from ``initial_indices``, or from every index of
    ``data`` when that is None. Each round runs two steps, until the
    selection returns no new indices:

    1. ``self.learn_partial`` is called on the current active indices.
    2. ``self.selection_strategy.select_active`` is asked for the next batch,
       given every index made active so far.

    Progress is reported through ``self.observer``.

    :param initial_indices: starting active indices. Never mutated by this
        method.
    :returns: the formula from the last ``learn_partial`` call, or None if
        no round ran (empty starting index set).
    """
    if initial_indices is None:
        active_indices = list(range(len(data)))
    else:
        active_indices = initial_indices
    # Copy rather than alias: ``active_indices`` may be the caller's
    # ``initial_indices`` list, and the ``+=`` below must not grow it.
    all_active_indices = list(active_indices)
    self.observer.observe("initial", active_indices)
    formula = None
    with smt.Solver() as solver:
        while len(active_indices) > 0:
            # perf_counter is monotonic, so it is suited to measuring durations.
            solving_start = time.perf_counter()
            formula = self.learn_partial(solver, domain, data, active_indices)
            solving_time = time.perf_counter() - solving_start

            selection_start = time.perf_counter()
            active_indices = list(
                self.selection_strategy.select_active(
                    domain, data, formula, all_active_indices))
            all_active_indices += active_indices
            selection_time = time.perf_counter() - selection_start

            self.observer.observe("iteration", formula, active_indices,
                                  solving_time, selection_time)
    return formula
def search_in_node(data, domain, node, h):
    """Search for a clause for ``node`` using at most ``h`` sub-solutions.

    ``data[i]`` is read as a pair: ``data[i][0][l]`` is tested as the truth
    value of variable ``l`` (an element of ``domain.bool_vars``), and
    ``data[i][1]`` is the label. Every index yielded by ``node.to_gen()``
    must refer to an example labelled False (asserted below).

    On the first visit the node's ``literals`` are computed and cached. A
    child hierarchy is then built over the positive examples that those
    literals do not satisfy, and its nodes are explored in priority-queue
    order.

    :returns: ``smt.Or`` of the sub-solutions found and the cached literals,
        or None in three cases: the node is marked unsolvable, a sub-node
        cannot be solved, or the solutions plus pending sub-nodes exceed
        ``h``.
    """
    # An earlier call marked this node unsolvable (see the hl_min > 1 branch).
    if node.solvable == False:
        return None
    if node.literals is None:
        # First visit: keep only literals that are false on every example
        # yielded by node.to_gen(). A positive literal l survives when
        # data[i][0][l] is falsy; a negated one (~l) survives when it is truthy.
        neg_data_gen = node.to_gen()
        possible_pos_lits = domain.bool_vars.copy()
        possible_neg_lits = domain.bool_vars.copy()
        for i in neg_data_gen:
            assert(not data[i][1])
            possible_pos_lits = [l for l in possible_pos_lits if not data[i][0][l]]
            possible_neg_lits = [l for l in possible_neg_lits if data[i][0][l]]
            # No candidate literal is left, so later examples cannot add any.
            if not possible_neg_lits and not possible_pos_lits:
                break
        node.literals = [domain.get_symbol(b) for b in possible_pos_lits] + [~domain.get_symbol(b) for b in possible_neg_lits]
        # Positive examples that none of the surviving literals satisfies.
        pos_data = [i for i in range(len(data)) if data[i][1]]
        for l in possible_pos_lits:
            pos_data = [i for i in pos_data if not data[i][0][l]]
        for l in possible_neg_lits:
            pos_data = [i for i in pos_data if data[i][0][l]]
        if pos_data:
            node.next_level_node = build_hierarchy(data, domain, pos_data, False)
        #else it stays None
    # No uncovered positives: the literals alone form the clause.
    if node.next_level_node is None:
        return smt.Or(node.literals)
    # NOTE(review): the exploration order depends on how hierarchy nodes
    # compare inside PriorityQueue; that ordering is defined elsewhere.
    q = PriorityQueue()
    q.put(node.next_level_node)
    solutions = []
    neg_idx = list(node.to_gen())
    while not q.empty():
        # Abandon the search once the solutions found plus the sub-nodes
        # still queued exceed the budget h.
        if len(solutions) + len(q.queue) > h:
            return None
        n = q.get()
        # Reuse a clause already stored on this sub-node.
        if n.clause is not None:
            solutions += [n.clause]
            continue
        # Sub-nodes with hl_min > 1 are not solved directly. Split them into
        # their children; if n has no children, mark ``node`` unsolvable and fail.
        if n.hl_min > 1:
            if n.left is None:
                node.solvable = False
                return None
            q.put(n.left)
            q.put(n.right)
            continue
        # Solve n directly: find_iter with find_hl over n's example indices
        # plus this node's negative indices, using a fresh solver.
        with smt.Solver() as solver:
            solution = find_iter(data, domain, list(n.to_gen()) + neg_idx,find_hl, solver)
        if solution is None:
            # Raise n's lower bound and split it. Unlike the branch above, a
            # childless n fails without updating node.solvable.
            n.hl_min = 2
            if n.left is None:
                #print("unsatisfiable node!")
                return None
            q.put(n.left)
            q.put(n.right)
        else:
            solutions += [solution]
            n.clause = solution
            saturate(data, domain, n)
    return smt.Or(solutions + node.literals)
def learn(self, domain, data, labels, initial_indices=None):
    """Run ``self.incremental_loop``, optionally backed by an SMT solver.

    When ``self.smt_solver`` is truthy, one pysmt solver is opened for the
    whole loop and handed to ``incremental_loop``. Otherwise ``None`` is
    passed in its place.

    :returns: ``(data, labels, formula)``. Note that ``incremental_loop``
        returns ``(data, formula, labels)``; the order is swapped here.
    """
    if not self.smt_solver:
        new_data, learned, new_labels = self.incremental_loop(
            domain, data, labels, initial_indices, None)
    else:
        with smt.Solver() as solver:
            new_data, learned, new_labels = self.incremental_loop(
                domain, data, labels, initial_indices, solver)
    return new_data, new_labels, learned
def search_cnf(data, problem):
    """Return the first CNF found by ``find_iter``/``find_cnf``, by increasing cost.

    For each total ``cost`` (starting at 1), every split into ``n_clauses``
    in ``1 .. cost - 1`` is tried. The remaining ``cost - n_clauses`` is
    passed as the last argument to ``find_iter``. Each attempt uses a fresh
    solver. The search never terminates if no solution exists.
    """
    domain = problem.domain
    indices = range(len(data))
    cost = 1
    while True:
        n_clauses = 1
        while n_clauses < cost:
            with smt.Solver() as solver:
                candidate = find_iter(data, domain, indices, find_cnf, solver,
                                      n_clauses, cost - n_clauses)
            if candidate is not None:
                return candidate
            n_clauses += 1
        cost += 1
def reduce(self, node_id, variables=None, fast=None):
    """Reduce the diagram rooted at ``node_id`` with a fresh SMT solver.

    :param node_id: id of the root node in ``self.pool``.
    :param variables: variables to use; defaults to
        ``self._get_variables(node_id)``. Stored on ``self.variables``.
    :param fast: overrides ``self.fast`` for this call when not None. In
        fast mode ``self.consistent`` is set to an empty set for the
        duration of the reduction and reset to None afterwards.
    :returns: node id of the reduced diagram.
    """
    fast = self.fast if fast is None else fast
    if variables is None:
        variables = self._get_variables(node_id)
    self.variables = variables
    if fast:
        self.consistent = set()
    try:
        with smt.Solver() as solver:
            result_id = self._reduce(self.pool.get_node(node_id), solver, fast).node_id
    finally:
        # Reset even if _reduce raises, so a stale set never leaks into a
        # later call.
        if fast:
            self.consistent = None
    return result_id
def check_equal(solver_name, smt_vars, expr1, expr2, name_binding):
    """Return True iff ``expr1`` and ``expr2`` are equal under every assignment.

    The solver (QF_BV logic) is asked for an assignment that makes the two
    expressions differ. If none exists, they are equivalent. Otherwise the
    counterexample is logged at level ``DEBUG - 1`` together with
    ``name_binding``, and False is returned. The logged counterexample is
    restricted to ``smt_vars`` and keyed by the same names.
    """
    with smt.Solver(solver_name, logic=QF_BV) as solver:
        solver.add_assertion((expr1 != expr2).value)
        if solver.solve():
            raw_model = solver.get_model()
            counterexample = {}
            for name, var in smt_vars.items():
                counterexample[name] = raw_model[var.value]
            trace_level = logging.DEBUG - 1
            logging.log(trace_level, name_binding)
            logging.log(trace_level, counterexample)
            return False
        return True
def solve(
    self,
    solver_name: str = 'z3',
    custom_enumeration: tp.Optional[tp.Mapping[type, tp.Callable]] = None,
) -> tp.Union[None, RewriteRule]:
    """Solve ``self.formula`` and build a rewrite rule from the model.

    :param solver_name: pysmt solver backend to use (BV logic).
    :param custom_enumeration: accepted for API compatibility; not used by
        this method. Defaults to None rather than a shared mutable ``{}``.
    :returns: None if there are no bindings or the formula is
        unsatisfiable. Otherwise the result of
        ``rr_from_solver(solver, self)``.
    """
    if not self.has_bindings:
        return None
    with smt.Solver(solver_name, logic=BV) as solver:
        solver.add_assertion(self.formula)
        if not solver.solve():
            return None
        # Build the rule while the solver (and its model) is still open.
        return rr_from_solver(solver, self)
def reduce(self, node_id, variables=None, fast=None):
    """Reduce the diagram rooted at ``node_id``, memoised in ``self.reduce_cache``.

    :param node_id: id of the root node in ``self.pool``. Also the cache key.
    :param variables: accepted for interface compatibility; not used here.
    :param fast: overrides ``self.fast`` when not None. In fast mode
        ``self.consistent`` is an empty set during the reduction and is
        reset to None afterwards.
    :returns: node id of the reduced diagram.

    NOTE(review): the cache is keyed on ``node_id`` only, so a result
    computed with one ``fast`` setting is returned for the other as well.
    """
    if node_id in self.reduce_cache:
        return self.reduce_cache[node_id]
    fast = self.fast if fast is None else fast
    if fast:
        self.consistent = set()
    try:
        with smt.Solver(name="msat") as solver:
            result_id = self._reduce(node_id, self.pool.get_node(node_id), solver, fast).node_id
    finally:
        # Reset even if _reduce raises, so a stale set never leaks into a
        # later call.
        if fast:
            self.consistent = None
    self.reduce_cache[node_id] = result_id
    return result_id
def test_convert_support():
    """Round-trip a mixed real/boolean formula through an SDD.

    The inequalities are abstracted, the result is compiled to an SDD and a
    formula is recovered from it. An SMT solver then checks that the
    recovered formula is equivalent to the original.
    """
    x, _y = smt.Symbol("x", smt.REAL), smt.Symbol("y", smt.REAL)
    a = smt.Symbol("a", smt.BOOL)
    support = (x < 0) | (~a & (x < -1)) | smt.Ite(a, x < 4, x < 8)

    # Replace the inequalities by abstract literals, compile, then recover.
    env, abstracted, literal_info = extract_and_replace_literals(support)
    sdd = compile_to_sdd(formula=abstracted, literals=literal_info, vtree=None)
    roundtrip = recover_formula(sdd_node=sdd, literals=literal_info, env=env)

    with smt.Solver() as solver:
        # Equivalent iff "not (support <-> roundtrip)" is unsatisfiable.
        solver.add_assertion(~smt.Iff(support, roundtrip))
        assert not solver.solve(), f"Expected UNSAT but found model {solver.get_model()}"
def __init__(self, env, n_states, solver_name):
    """Set up a SAT-backed policy for a gym algorithmic environment.

    Apart from "Copy-v0", the algorithmic tasks need some memory: a policy
    that sees only the current input character cannot solve them. States
    provide that memory generically. At every timestep the model makes three
    decisions, based on the current input character *and* its current state:
    which way to move, which character to write, and which state to enter
    next.

    Required states per task:

    * Copy-v0: 1 state (no memory) is enough.
    * RepeatCopy, DuplicatedInput and Reverse: 2 states are enough.
    * The addition environments: the exact number is unclear, but it is at
      least 3. They use ternary numbers, and the model must at least know
      which of the 3 possible digits it is adding to the digit it is
      currently looking at.

    :param env: must be an ``AlgorithmicEnv`` instance (asserted).
    :param n_states: number of states available to the policy.
    :param solver_name: solver backend name passed to ``sc.Solver``.
    """
    assert isinstance(env, gym.envs.algorithmic.algorithmic_env.AlgorithmicEnv)
    # QF_BOOL (quantifier-free boolean logic) is the simplest logic available
    # and all this encoding needs.
    self.helper = BoolSatHelper(
        sc.Solver(name=solver_name, logic=logics.QF_BOOL), env, n_states)
    self.runner = AlgorithmicPolicyRunner(self.helper, env)
def test_op(op: WrappedOp, M: int, N: int, m: int, n: int):
    """Return whether output bit ``n`` of ``op`` can change when input bit ``m`` flips.

    ``op`` is applied to a symbolic ``M``-bit vector twice: once with bit
    ``m`` cleared and once with it set. The solver then searches for an
    input on which output bit ``n`` differs between the two runs.

    :returns: the result of ``solver.solve()``. True means such an input
        exists.
    :raises TypeError: if ``op`` is not a ``WrappedOp``.
    :raises ValueError: if ``m`` is not in ``range(M)`` or ``n`` is not in
        ``range(N)``.
    """
    if not isinstance(op, WrappedOp):
        raise TypeError(op)
    if m not in range(M):
        raise ValueError((m, M))
    if n not in range(N):
        raise ValueError((n, N))

    x = ht.SMTBitVector[M]()
    bit_m = ht.SMTBitVector[M](1 << m)
    x0 = x & ~bit_m  # x with bit m cleared
    x1 = x | bit_m   # x with bit m set
    # Evaluate op once per input and reuse the results below.
    l = op(x0)
    r = op(x1)
    assert type(l) is type(r)

    if isinstance(l, ht.SMTBit):
        # Single-bit output: N must be 1, and the range check above then
        # forces n == 0.
        assert N == 1
        f = (l == r)
    else:
        f = (l[n] == r[n])

    with smt.Solver("z3", logic=pysmt.logics.BV) as solver:
        # SAT iff some x makes the selected output bit differ.
        solver.add_assertion((~f).value)
        return solver.solve()
def integrate(self, domain, convex_bounds: List[LinearInequality], polynomial: Polynomial):
    """Integrate ``polynomial`` over the polytope given by ``convex_bounds`` with LattE.

    Each bound is scaled to integer coefficients. The bounds are written to
    an ``.hrep.latte`` file (one ``b, -a_1, ..., -a_n`` row per bound) and
    the monomials to a ``.poly.latte`` file. LattE's ``integrate`` command
    is then run on the two files.

    :returns: the integral as a float. Returns 0.0 when the command output
        does not match ``self.pattern``, or when the command fails and the
        bounds are unsatisfiable.
    :raises CalledProcessError: if the command fails although the bounds
        are satisfiable.
    """
    # TODO Use power of linear forms?
    b_geq_a = []
    formula = smt.TRUE()
    for bound in convex_bounds:
        integer_bound = bound.scale_to_integer()
        formula &= integer_bound.to_smt()
        b_geq_a.append([integer_bound.b()] + [-integer_bound.a(v) for v in domain.real_vars])

    monomials = [
        (Fraction(value).limit_denominator(), self.key_to_exponents(domain, key))
        for key, value in polynomial.poly_dict.items()
    ]

    with TemporaryFile(suffix=".hrep.latte") as bounds_file:
        with TemporaryFile(suffix=".poly.latte") as poly_file:
            with open(bounds_file, "w") as bounds_ref:
                print("{} {}".format(len(b_geq_a), len(domain.real_vars) + 1), file=bounds_ref)
                print(*[" ".join(map(str, e)) for e in b_geq_a], sep="\n", file=bounds_ref)
            with open(poly_file, "w") as poly_ref:
                print("[{}]".format(",".join("[{},[{}]]".format(m[0], ",".join(map(str, m[1]))) for m in monomials)), file=poly_ref)
            command = "integrate --valuation=integrate {} --monomials={} {}"\
                .format(self.algorithm, poly_file, bounds_file)
            try:
                output = check_output(command, shell=True, stderr=DEVNULL).decode()
            except CalledProcessError:
                # Presumably the command can fail on an empty polytope. If the
                # bounds are unsatisfiable the integral is 0; otherwise
                # re-raise. solve()'s boolean result is checked directly,
                # instead of probing get_model() for a backend-specific error.
                with smt.Solver() as solver:
                    solver.add_assertion(formula)
                    feasible = solver.solve()
                if not feasible:
                    return 0.0
                raise

    match = re.search(self.pattern, output)
    if not match:
        return 0.0
    return float(Fraction(int(match.group(1)), int(match.group(2))))
def search_in_hierarchy_lit(data, domain, nodes, halflines):
    """Find one clause per node with ``search_in_node`` within a shared budget.

    Nodes whose ``hl_min == hl_max`` reuse their stored ``clause`` and
    consume ``hl_max`` from the budget. For every other node, budgets ``h``
    from ``node.hl_min`` up to the remaining ``halflines`` are tried in
    turn. The first clause found is stored on the node (``hl_max``,
    ``clause``), consumes ``h``, and the node is saturated. Each failed
    ``h`` raises ``node.hl_min`` by one.

    :returns: ``smt.And`` of the clauses, or None as soon as some node has
        no clause within the remaining budget.
    """
    clauses = []
    for node in nodes:
        if node.hl_min == node.hl_max:
            clauses.append(node.clause)
            halflines -= node.hl_max
            continue
        clause = None
        for h in range(node.hl_min, halflines + 1):
            # search_in_node opens its own solvers; opening one here was
            # pure overhead (it was never used).
            clause = search_in_node(data, domain, node, h)
            if clause is not None:
                node.hl_max = h
                node.clause = clause
                clauses.append(clause)
                halflines -= h
                saturate(data, domain, node)
                break
            node.hl_min += 1
        if clause is None:
            return None
    return smt.And(clauses)
def search_in_hierarchy(data, domain, nodes, halflines):
    """Find one clause per node with ``find_iter``/``find_clause`` within a shared budget.

    Nodes whose ``hl_min == hl_max`` reuse their stored ``clause`` and
    consume ``hl_max`` from the budget. For every other node, ``find_iter``
    runs on the indices of all positively labelled rows plus the node's own
    indices. Budgets ``h`` from ``node.hl_min`` up to the remaining budget
    are tried, each with a fresh solver. The first clause found is stored on
    the node, consumes ``h``, and the node is saturated. Each failed ``h``
    raises ``node.hl_min`` by one.

    :returns: ``smt.And`` of the clauses, or None as soon as some node has
        no clause within the remaining budget.
    """
    found = []
    budget = halflines
    for node in nodes:
        if node.hl_min == node.hl_max:
            found += [node.clause]
            budget -= node.hl_max
            continue
        # Recomputed per node, after any saturate() call for earlier nodes.
        indices = [i for i, row in enumerate(data) if row[1]] + list(node.to_gen())
        for h in range(node.hl_min, budget + 1):
            with smt.Solver() as solver:
                candidate = find_iter(data, domain, indices, find_clause, solver, h)
            if candidate is None:
                node.hl_min += 1
                continue
            node.hl_max = h
            node.clause = candidate
            found += [candidate]
            budget -= h
            saturate(data, domain, node)
            break
        else:
            # Every budget failed (or none was available for this node).
            return None
    return smt.And(found)
def gen_mapping(
        peak_class: Peak,
        bv_isa: tp.Type[ISABuilder],  #This is currently a hack that will be removed.
        smt_isa: tp.Type[ISABuilder],
        coreir_module: coreir.ModuleDef,
        coreir_model: tp.Callable,
        max_mappings: int,
        *,
        verbose: bool = False,
        solver_name: str = 'z3',
        constraints=None):
    """Yield up to ``max_mappings`` mappings of a CoreIR module onto a PEAK ISA.

    Every instruction of ``smt_isa`` accepted by all ``constraints`` is run
    symbolically under every "binding". A binding assigns each PEAK input
    either a CoreIR input of the same size or the constant 0. A dict with
    keys ``instruction``, ``output_map`` and ``input_map`` is yielded when
    both of these hold for some instruction output:

    * it has the same SMT type as ``coreir_model``'s output;
    * the solver proves it equal to that output for all inputs.

    :param bv_isa: the i-th entry of the filtered ``bv_isa`` enumeration is
        reported as the instruction for the i-th filtered ``smt_isa`` entry,
        so the two enumerations must line up.
    :param constraints: predicates on an instruction. An instruction is
        considered only if all of them return true. Defaults to none.
    :param verbose: print instruction/binding progress counters.
    :param solver_name: pysmt solver backend used for the equivalence checks.
    """
    if constraints is None:
        constraints = []

    peak_inst = peak_class()  # This cannot take any args
    peak_inputs = _convert_io_types(peak_class.__call__._peak_inputs_)
    core_inputs = {
        k if k != 'in' else 'in_': v.size
        for k, v in coreir_module.type.items() if v.is_input()
    }
    core_smt_vars = {k: SMTBitVector[v]() for k, v in core_inputs.items()}
    core_smt_expr = coreir_model(**core_smt_vars)

    # Generate all possible assignments of core inputs / constants (None) to
    # peak inputs. Within each size group, pad the core inputs with None and
    # pair them with every permutation of that group's peak inputs.
    core_inputs_by_size = _group_by_value(core_inputs)
    peak_inputs_by_size = _group_by_value(peak_inputs)
    assert core_inputs_by_size.keys() <= peak_inputs_by_size.keys()
    for k in core_inputs_by_size:
        assert len(core_inputs_by_size[k]) <= len(peak_inputs_by_size[k])

    possible_matching = {}
    for size, pi in peak_inputs_by_size.items():
        ci = core_inputs_by_size.setdefault(size, [])
        ci = list(it.chain(ci, it.repeat(None, len(pi) - len(ci))))
        assert len(ci) == len(pi)
        for perm in it.permutations(pi):
            possible_matching.setdefault(size, []).append(list(zip(ci, perm)))
    del core_inputs_by_size
    del peak_inputs_by_size

    bindings = [
        list(it.chain(*combo))
        for combo in it.product(*possible_matching.values())
    ]

    found = 0
    if found >= max_mappings:
        return

    def f_fun(inst):
        return all(constraint(inst) for constraint in constraints)

    bv_isa_list = list(filter(f_fun, bv_isa.enumerate()))
    smt_isa_list = list(filter(f_fun, smt_isa.enumerate()))
    smt_isa_len = len(smt_isa_list)

    for ii, smt_inst in enumerate(smt_isa_list):
        if verbose:
            # smt_isa_len: the length of the list actually being iterated
            # (the old name ``isa_len`` was undefined and raised NameError).
            print(f"inst {ii+1}/{smt_isa_len}")
            print(smt_inst)
        for bi, binding in enumerate(bindings):
            binding_dict = {
                k: core_smt_vars[v] if v is not None else SMTBitVector[peak_inputs[k]](0)
                for v, k in binding
            }
            name_binding = {k: v if v is not None else 0 for v, k in binding}
            if verbose:
                print(f"binding {bi+1}/{len(bindings)}")
            rvals = peak_inst(smt_inst, **binding_dict)
            if not isinstance(rvals, tuple):
                rvals = rvals,
            for idx, bv in enumerate(rvals):
                if not (isinstance(bv, (SMTBit, SMTBitVector))
                        and bv.value.get_type() == core_smt_expr.value.get_type()):
                    continue
                # Decide equivalence and close the solver before yielding, so
                # it is not held open while the consumer runs.
                with smt.Solver(solver_name, logic=QF_BV) as solver:
                    solver.add_assertion((bv != core_smt_expr).value)
                    equivalent = not solver.solve()
                if not equivalent:
                    continue
                #Create output and input map
                output_map = {
                    "out": list(peak_class.__call__._peak_outputs_.items())[idx][0]
                }
                input_map = {}
                for k, v in name_binding.items():
                    if v == 0:
                        v = "0"
                    elif v == "in_":
                        v = "in"
                    input_map[v] = k
                mapping = dict(instruction=bv_isa_list[ii],
                               output_map=output_map,
                               input_map=input_map)
                yield mapping
                found += 1
                if found >= max_mappings:
                    return
from pywmi import Domain from pywmi.engines.xsdd.smt_to_sdd import ( SddConversionWalker, recover_formula, compile_to_sdd, ) from pywmi.smt_print import pretty_print from pywmi.smt_math import PolynomialAlgebra try: from pysdd.sdd import SddManager except ImportError: SddManager = None try: with smt.Solver() as solver: smt_solver_available = True except NoSolverAvailableError: smt_solver_available = False pytestmark = pytest.mark.skipif(SddManager is None, reason="pysdd is not installed") @pytest.mark.skip( reason="Outdated test, SddConversionWalker only supports logical operations" ) def test_convert_weight(): x, y = smt.Symbol("x", smt.REAL), smt.Symbol("y", smt.REAL) a = smt.Symbol("a", smt.BOOL) weight_function = (smt.Ite(
(binding == 1).ite( inst.operand_1.match(T), (binding == 2).ite( inst.operand_1.match(Word), (binding == 4).ite( inst.operand_1.match(T), (binding == 8) & inst.operand_1.match(Word) ) ) ) ) sim_expr = sim(inst) for target in targets: target_expr = target(x, y) with smt.Solver('z3', logic=BV) as solver: solver.add_assertion(precondition.value) constraint = smt.ForAll([x.value, y.value, free_bit.value], (sim_expr == target_expr).value) solver.add_assertion(constraint) if solver.solve(): opcode_val = solver.get_value(opcode_bv.value).constant_value() b_val = solver.get_value(b.value).constant_value() tag_val = solver.get_value(tag_bv.value).constant_value() binding_val = solver.get_value(binding.value).constant_value() print(f'mapping found for {target.__name__}') print(f'binding: {binding_val}') print(f'opcode: {opcode_asm.disassemble(opcode_val)}') print(f'tag: {operand_1_asm.disassemble_tag(tag_val)}') print(f'b: {b_val}') print() else:
def __init__(self, formula, domain):
    """Store ``domain`` and open a solver preloaded with ``formula``.

    ``formula`` is also passed to the superclass constructor. The solver is
    kept on ``self.solver`` and is not closed here.
    """
    super().__init__(formula)
    self.domain = domain
    self.solver = solver = smt.Solver()
    solver.add_assertion(formula)
def _report_counterexample(model, belts, junctions):
    """Print the lane inputs/outputs of ``model`` and replay them through check_balancer_flow."""
    inputs = tuple(model.get_py_value(beltway[0].rho) for beltway in belts)
    outputs = tuple(model.get_py_value(beltway[-1].v) for beltway in belts)
    print(f'Input lane densities: ({", ".join(map(str, inputs))})')
    print(f'Output lane velocities: ({", ".join(map(str, outputs))})')
    check_balancer_flow(junctions, inputs, outputs)


def prove_properties(junctions):
    """Check three properties of a belt balancer and print the verdicts.

    ``junctions`` is a sequence of ``(x, y1, y2)`` triples. Each property is
    checked by asking the solver, under ``balancer_flow_formula``'s
    constraints, for a flow that violates it:

    * throughput: total output flux equals min(total input ``rho``, total
      output ``v``);
    * input balance: every input lane's flux equals the maximum input flux
      or that lane's ``rho``;
    * output balance: every output lane's flux equals the maximum output
      flux or that lane's ``v``.

    For every violated property a counterexample is printed and replayed
    through ``check_balancer_flow``.
    """
    width = max(max(y1 for _, y1, _ in junctions), max(y2 for _, _, y2 in junctions)) + 1
    length = max(x for x, _, _ in junctions)
    formula, belts = balancer_flow_formula(junctions, width, length)

    supply = s.Plus(beltway[0].rho for beltway in belts)
    demand = s.Plus(beltway[-1].v for beltway in belts)
    max_theoretical_throughput = s.Min(supply, demand)
    actual_throughput = s.Plus(beltway[-1].flux for beltway in belts)

    max_out_flux = s.Max(b[-1].flux for b in belts)
    fully_balanced_output = s.And([
        s.Or(s.Equals(b[-1].flux, max_out_flux), s.Equals(b[-1].flux, b[-1].v))
        for b in belts
    ])

    max_in_flux = s.Max(b[0].flux for b in belts)
    fully_balanced_input = s.And([
        s.Or(s.Equals(b[0].flux, max_in_flux), s.Equals(b[0].flux, b[0].rho))
        for b in belts
    ])

    # (negated property, message if it holds, message if it is violated)
    checks = (
        (s.Not(s.Equals(max_theoretical_throughput, actual_throughput)),
         "It's throughput-unlimited!",
         "It's throughput-limited; here's an example:"),
        (s.Not(fully_balanced_input),
         "It's input-balanced!",
         "It's input-imbalanced; here's an example:"),
        (s.Not(fully_balanced_output),
         "It's output-balanced!",
         "It's output-imbalanced; here's an example:"),
    )

    with s.Solver() as solver:
        solver.add_assertion(s.And(formula))
        for violation, holds_msg, violated_msg in checks:
            solver.push()
            solver.add_assertion(violation)
            # UNSAT means no flow violates the property. The output check
            # used to test the bound method ``solver.solve`` (always truthy)
            # instead of calling it, so it always reported "output-balanced".
            if not solver.solve():
                print(holds_msg)
            else:
                print(violated_msg)
                _report_counterexample(solver.get_model(), belts, junctions)
            solver.pop()
def perform(self):
    """Run the attack loop on the oracle / obfuscated circuit pair.

    Both circuits are read from ``.bench`` files: ``self.args.b`` is the
    oracle and ``self.args.o`` the obfuscated circuit. The attack formulas
    are then built, and the loop repeatedly queries the dip generator on
    ``self.solver_obf``. Whenever no dip is found, the loop either:

    * runs the ce/umc checks and widens ``self.boundary``,
    * unrolls the obfuscated circuit by one more frame, or
    * stops.

    :returns: True when the attack stops. That happens when ``ce_check``
        succeeds, the ``self.stop`` depth is reached, or the key is
        reported unique.
    """
    import sys  # only needed for the early exit below

    # process inputs
    if '.bench' not in self.args.b:
        logging.critical('verilog input is disabled! use main_formal')
        # sys.exit rather than the site-module ``exit`` helper, which is
        # meant for the interactive interpreter (and absent under
        # ``python -S``). The exit status is unchanged.
        sys.exit()
    self.obf_cir = bench2circuit(self.args.o)
    self.oracle_cir = bench2circuit(self.args.b)
    self.oracle_cir.create_ce_circuit()
    self.obf_cir.create_ce_circuit()
    sort_circuits(self.oracle_cir, self.obf_cir)

    # perform attack
    self.solver_obf = pystm.Solver(name=self.solver_name)
    self.solver_key = pystm.Solver(name=self.solver_name)
    self.solver_oracle = pystm.Solver(name=self.solver_name)
    logging.warning('initial value for boundary={}, step={}, stop={}'.format(self.boundary, self.step, self.stop))
    logging.warning('solver={}'.format(self.solver_name))
    self.attack_formulas = FormulaGenerator(self.oracle_cir, self.obf_cir)

    def add_frame(frame):
        # Assert both copies (c0, c1) of the obfuscated circuit at ``frame``.
        c0, c1 = self.attack_formulas.obf_ckt_at_frame(frame)
        for j, formula0 in enumerate(c0):
            self.solver_obf.add_assertion(formula0)
            self.solver_obf.add_assertion(c1[j])

    # add k0 != k1
    self.solver_obf.add_assertion(self.attack_formulas.key_inequality_ckt)
    # assumptions for inequality of dip generator outputs
    assumptions = self.attack_formulas.dip_gen_assumption(1)
    # get initial states and the first copy of the circuit
    for frame in range(2):
        add_frame(frame)

    while True:
        # query dip generator
        if self.solver_obf.is_sat(assumptions):
            dis_boolean = self.query_dip_generator()
            dis_formula = [
                get_formulas(self.obf_cir.input_wires, pattern, '@{}'.format(i))
                for i, pattern in enumerate(dis_boolean, start=1)
            ]
            logging.info(dis_formula)
            dis_out = self.query_oracle(dis_formula)
            self.add_dip_checker(dis_boolean, dis_out)
            self.iteration += 1
            logging.warning('iteration={}, depth={}'.format(self.iteration, self.unroll_depth))
            self.highest_depth = self.unroll_depth
        elif (self.solver_obf.is_sat(pystm.TRUE()) or self.iteration == 0) and self.unroll_depth < self.stop:
            # Two agreeing keys are found, but no dip can be found. Keep
            # unrolling the circuit until at least one dip is found; only
            # then can uc success be decided.
            if self.unroll_depth == self.boundary:
                logging.warning('uc failed')
                # check ce
                if self.ce_check():
                    return True
                if self.umc_check():
                    continue
                # increase boundary
                self.boundary += self.step
            else:
                # increase unroll depth
                self.unroll_depth += 1
                assumptions = self.attack_formulas.dip_gen_assumption(self.unroll_depth)
                add_frame(self.unroll_depth)
                logging.warning('increasing unroll depth to {}'.format(self.unroll_depth))
        elif self.unroll_depth >= self.stop:
            logging.warning('stopped at {}'.format(self.stop))
            self.print_keys()
            return True
        else:
            # key is unique
            logging.warning('uc successful')
            self.print_keys()
            return True
else: constraints.append(SMT.GT(var, SMT.Int(0))) for i in range(0, 3 * cap, 3): # add requirement that either a_i =0 or c_i = 0 constraints.append( SMT.Or(SMT.Equals(variables[i], SMT.Int(0)), SMT.Equals(variables[i + 2], SMT.Int(0)))) # add requirement that if a_i = 0 and c_i = 0 then b_i = 1 constraints.append( SMT.Or( SMT.Or(SMT.GT(variables[i], SMT.Int(0)), SMT.GT(variables[i + 2], SMT.Int(0))), SMT.Equals(variables[i + 1], SMT.Int(1)))) solver = SMT.Solver(name="z3") print("constraints:") for c in constraints: print(c) solver.add_assertion(c) # add equations # 24727*a_1 + 75235*b_1 + 50508*c_1 = 75235*a_2 + 125743*b_2 + 176251*c_2 solver.add_assertion( SMT.Equals( SMT.Plus( SMT.Plus(SMT.Times(SMT.Int(125743), variables[4]), SMT.Times(SMT.Int(75235), variables[3])), SMT.Times(SMT.Int(176251), variables[5])), SMT.Plus(
def reduce(self, node_id, variables=None):
    """Reduce the diagram rooted at ``node_id`` using the solver named ``self.solver``.

    :param node_id: id of the root node in ``self.pool``.
    :param variables: variables to use; defaults to
        ``self._get_variables(node_id)``. Stored on ``self.variables``.
    :returns: node id of the reduced diagram.
    """
    self.variables = variables if variables is not None else self._get_variables(node_id)
    with smt.Solver(name=self.solver) as solver:
        root = self.pool.get_node(node_id)
        reduced = self._reduce(root, solver)
        return reduced.node_id
def gen_mapping(
    peak_class : Peak,
    bv_isa : BoundMeta, #This is currently a hack that will be removed.
    smt_isa : BoundMeta,
    coreir_module : coreir.ModuleDef,
    coreir_model : tp.Callable,
    max_mappings : int,
    *,
    verbose : int = 0,
    solver_name : str = 'z3',
    constraints = None
):
    """Yield up to ``max_mappings`` mappings of a CoreIR module onto a PEAK ISA.

    Every instruction of ``smt_isa`` accepted by all ``constraints`` is run
    symbolically under every "binding". A binding assigns each PEAK input
    either a CoreIR input of the same type or the constant 0. A dict with
    keys ``instruction``, ``output_map`` and ``input_map`` is yielded when
    both of these hold for some instruction output:

    * it has the same SMT type as ``coreir_model``'s output;
    * ``check_equal`` proves it equal to that output.

    :param bv_isa: the i-th entry of the filtered ``bv_isa`` enumeration is
        reported as the instruction for the i-th filtered ``smt_isa`` entry,
        so the two enumerations must line up.
    :param verbose: 1 sets the root logger to DEBUG, 2 to DEBUG - 1.
        NOTE: this changes global logging state.
    :param constraints: predicates on an instruction. An instruction is
        considered only if all of them return true. Defaults to none.
    :param solver_name: pysmt solver backend passed to ``check_equal``.
    """
    if constraints is None:
        constraints = []
    if verbose == 1:
        logging.getLogger().setLevel(logging.DEBUG)
    elif verbose == 2:
        logging.getLogger().setLevel(logging.DEBUG - 1)

    peak_inst = peak_class()  #This cannot take any args
    peak_inputs = _filter_io_types(peak_class.__call__._peak_inputs_)
    core_inputs = {
        k if k != 'in' else 'in_': SMTBitVector[v.size]
        for k, v in coreir_module.type.items() if v.is_input()
    }
    core_smt_vars = {k: v() for k, v in core_inputs.items()}
    core_smt_expr = coreir_model(**core_smt_vars)

    # Generate all possible assignments of core inputs / constants (None) to
    # peak inputs. Within each type group, pad the core inputs with None and
    # pair them with every permutation of that group's peak inputs.
    core_inputs_by_t = _group_by_value(core_inputs)
    peak_inputs_by_t = _group_by_value(peak_inputs)
    assert core_inputs_by_t.keys() <= peak_inputs_by_t.keys()
    for k in core_inputs_by_t:
        assert len(core_inputs_by_t[k]) <= len(peak_inputs_by_t[k])

    possible_matching = {}
    for t, pi in peak_inputs_by_t.items():
        ci = core_inputs_by_t.setdefault(t, [])
        ci = list(it.chain(ci, it.repeat(None, len(pi) - len(ci))))
        assert len(ci) == len(pi)
        for perm in it.permutations(pi):
            possible_matching.setdefault(t, []).append(list(zip(ci, perm)))
    del core_inputs_by_t
    del peak_inputs_by_t

    bindings = [
        list(it.chain(*combo))
        for combo in it.product(*possible_matching.values())
    ]

    found = 0
    if found >= max_mappings:
        return

    def f_fun(inst):
        return all(constraint(inst) for constraint in constraints)

    logging.debug("Enumerating bv instructions")
    bv_isa_list = list(filter(f_fun, bv_isa.enumerate()))
    bv_isa_len = len(bv_isa_list)
    logging.debug("Enumerating smt instructions")
    smt_isa_list = list(filter(f_fun, smt_isa.enumerate()))

    logging.debug("Starting search")
    for ii, smt_inst in enumerate(smt_isa_list):
        logging.debug(f"inst {ii+1}/{bv_isa_len}")
        logging.debug(smt_inst)
        for bi, binding in enumerate(bindings):
            binding_dict = {
                k: core_smt_vars[v] if v is not None else peak_inputs[k](0)
                for v, k in binding
            }
            name_binding = {k: v if v is not None else 0 for v, k in binding}
            rvals = peak_inst(smt_inst, **binding_dict)
            if not isinstance(rvals, tuple):
                rvals = rvals,
            for idx, bv in enumerate(rvals):
                if not (isinstance(bv, (SMTBit, SMTBitVector))
                        and bv.value.get_type() == core_smt_expr.value.get_type()):
                    continue
                # check_equal opens (and closes) its own solver. The extra,
                # unused solver previously opened here was pure overhead and
                # stayed open across the ``yield`` below.
                if not check_equal(solver_name, binding_dict, bv, core_smt_expr, name_binding):
                    continue
                #Create output and input map
                output_map = {"out": list(peak_class.__call__._peak_outputs_.items())[idx][0]}
                input_map = {}
                for k, v in name_binding.items():
                    if v == 0:
                        v = "0"
                    elif v == "in_":
                        v = "in"
                    input_map[v] = k
                mapping = dict(
                    instruction=bv_isa_list[ii],
                    output_map=output_map,
                    input_map=input_map
                )
                yield mapping
                found += 1
                if found >= max_mappings:
                    return