Esempio n. 1
0
    def verify(
        self,
        solver_name: str = "z3",
    ) -> tp.Union[None, "CounterExample"]:
        """Check that every bound IR/architecture output pair always agrees.

        A fresh SMT free variable is created for each bounded IR input path.
        Both the IR and the architecture are evaluated symbolically on those
        variables. The solver is then asked whether any output pair listed in
        ``self.obinding`` can differ.

        Args:
            solver_name: name of the pysmt solver backend to use.

        Returns:
            ``None`` when no input makes any bound output pair differ.
            Otherwise, a dict mapping each bounded IR input path to the value
            extracted with ``solved_to_bv`` from the satisfying model.

        Raises:
            ValueError: if an output binding names a path missing from the
                parsed IR or architecture outputs.
        """
        # One free SMT variable per bounded IR input path.
        path_to_type = _create_path_to_adt(
            strip_modifiers(self.ir_fc(family.SMTFamily()).input_t))
        free_vars = {
            p: _free_var_from_t(path_to_type[p]) for p in self.ir_bounded
        }

        ir_kwargs = self.build_ir_input(free_vars, family.SMTFamily())
        arch_kwargs = self.build_arch_input(free_vars, family.SMTFamily())
        ir_inst = self.ir_fc(family.SMTFamily())()
        arch_inst = self.arch_fc(family.SMTFamily())()
        ir_results = self.parse_ir_output(ir_inst(**ir_kwargs))
        arch_results = self.parse_arch_output(arch_inst(**arch_kwargs))

        # One "these two outputs differ" term per output binding.
        mismatches = []
        for ir_key, arch_key in self.obinding:
            if ir_key not in ir_results:
                raise ValueError(f"{ir_key} is not valid")
            if arch_key not in arch_results:
                raise ValueError(f"{arch_key} is not valid")
            mismatches.append(ir_results[ir_key] != arch_results[arch_key])
        any_mismatch = or_reduce(mismatches)

        with smt.Solver(solver_name, logic=BV) as solver:
            solver.add_assertion(any_mismatch.value)
            # UNSAT: no assignment makes any bound output pair differ.
            if not solver.solve():
                return None
            return {p: solved_to_bv(v, solver) for p, v in free_vars.items()}
Esempio n. 2
0
    def learn(self, domain, data, initial_indices=None):
        """Incrementally learn a formula over a growing set of active examples.

        Starts from ``initial_indices`` (or every index of ``data`` when
        ``None``). Each round learns a partial formula on the active indices
        with ``self.learn_partial``. It then asks ``self.selection_strategy``
        for new indices, and stops when the strategy returns none.

        Args:
            domain: problem domain forwarded to the learner and the strategy.
            data: the examples; active examples are identified by index.
            initial_indices: optional starting indices. The caller's object
                is not modified.

        Returns:
            The formula from the last ``learn_partial`` call. Returns ``None``
            if the initial index set is empty.
        """
        active_indices = list(range(
            len(data))) if initial_indices is None else initial_indices
        # Independent copy: this accumulator grows below and must not alias
        # the caller's ``initial_indices`` nor the list given to the observer.
        all_active_indices = list(active_indices)

        self.observer.observe("initial", active_indices)

        formula = None

        with smt.Solver() as solver:
            while len(active_indices) > 0:
                solving_start = time.time()
                formula = self.learn_partial(solver, domain, data,
                                             active_indices)
                solving_time = time.time() - solving_start

                selection_start = time.time()
                new_active_indices = list(
                    self.selection_strategy.select_active(
                        domain, data, formula, all_active_indices))
                active_indices = new_active_indices
                all_active_indices += active_indices
                selection_time = time.time() - selection_start
                self.observer.observe("iteration", formula, active_indices,
                                      solving_time, selection_time)

        return formula
Esempio n. 3
0
def search_in_node(data, domain, node, h):
    """Try to build a disjunctive clause for ``node`` using at most ``h`` parts.

    ``data[i]`` is indexed as ``(values, label)``: ``data[i][0][var]`` gives a
    Boolean variable's value and ``data[i][1]`` is truthy for positive examples.
    ``node.to_gen()`` yields the (negative) example indices owned by the node.

    On first visit the node's single-variable ``literals`` are computed. These
    are literals that are false on every negative example of the node. Positive
    examples left uncovered by those literals are delegated to a lower-level
    hierarchy (``node.next_level_node``). That hierarchy is then searched
    best-first, and each sub-node is solved with ``find_iter``/``find_hl``.

    Returns:
        ``smt.Or`` of the found sub-clauses plus ``node.literals``, or ``None``
        when the node is marked unsolvable, when the number of parts would
        exceed ``h``, or when a leaf sub-node has no solution.
    """
    if node.solvable == False:
        return None
    if node.literals is None:
        neg_data_gen = node.to_gen()
        # Candidate positive literals: variables false in every negative example.
        # Candidate negative literals: variables true in every negative example.
        possible_pos_lits = domain.bool_vars.copy()
        possible_neg_lits = domain.bool_vars.copy()
        for i in neg_data_gen:
            assert(not data[i][1])
            possible_pos_lits = [l for l in possible_pos_lits if not data[i][0][l]]
            possible_neg_lits = [l for l in possible_neg_lits if data[i][0][l]]
            # No candidates left; the remaining negatives cannot change that.
            if not possible_neg_lits and not possible_pos_lits:
                break
        node.literals = [domain.get_symbol(b) for b in possible_pos_lits] + [~domain.get_symbol(b) for b in possible_neg_lits]
        # Positive examples on which every literal is false, i.e. not covered
        # by ``node.literals``; these must be covered by the next level.
        pos_data = [i for i in range(len(data)) if data[i][1]]
        for l in possible_pos_lits:
            pos_data = [i for i in pos_data if not data[i][0][l]]
        for l in possible_neg_lits:
            pos_data = [i for i in pos_data if data[i][0][l]]
        if pos_data:
            node.next_level_node = build_hierarchy(data, domain, pos_data, False)
        # otherwise next_level_node is left untouched (the original author
        # notes it "stays None")


    if node.next_level_node is None:
        # Literals alone cover all positive examples.
        return smt.Or(node.literals)
    # Best-first search over the lower-level hierarchy (ordering comes from
    # the node objects' comparison, which is defined elsewhere).
    q = PriorityQueue()
    q.put(node.next_level_node)
    solutions = []
    neg_idx = list(node.to_gen())
    while not q.empty():
        # Every found solution and every pending node will need at least one
        # part; give up as soon as that exceeds the budget ``h``.
        if len(solutions) + len(q.queue) > h:
            return None
        n = q.get()
        if n.clause is not None:
            # Reuse a clause found by an earlier call.
            solutions += [n.clause]
            continue
        if n.hl_min > 1:
            # Known not to be solvable as a single part: split into children.
            if n.left is None:
                # A leaf cannot be split further, so the whole node is unsolvable.
                node.solvable = False
                return None
            q.put(n.left)
            q.put(n.right)
            continue
        with smt.Solver() as solver:
            solution = find_iter(data, domain, list(n.to_gen()) + neg_idx,find_hl, solver)
            if solution is None:
                # Remember that one part is not enough for this sub-node.
                n.hl_min = 2
                if n.left is None:
                    # NOTE(review): unlike the hl_min > 1 branch above, this path
                    # does not set node.solvable = False — confirm intended.
                    #print("unsatisfiable node!")
                    return None
                q.put(n.left)
                q.put(n.right)
            else:
                solutions += [solution]
                n.clause = solution
                saturate(data, domain, n)
    return smt.Or(solutions + node.literals)
            
        
Esempio n. 4
0
    def learn(self, domain, data, labels, initial_indices=None):
        """Run ``self.incremental_loop``, inside an SMT solver when enabled.

        When ``self.smt_solver`` is truthy, a pysmt solver is opened for the
        duration of the loop and passed to it. Otherwise ``None`` is passed.

        Returns:
            ``(data, labels, formula)``. The loop itself returns
            ``(data, formula, labels)``; the order is swapped here.
        """
        if not self.smt_solver:
            data, formula, labels = self.incremental_loop(
                domain, data, labels, initial_indices, None)
            return data, labels, formula

        with smt.Solver() as solver:
            data, formula, labels = self.incremental_loop(
                domain, data, labels, initial_indices, solver)
        return data, labels, formula
Esempio n. 5
0
def search_cnf(data, problem):
    """Search for a CNF over ``problem.domain`` consistent with ``data``.

    Candidate shapes are tried by increasing total ``cost``. For each cost,
    every split into ``n_clauses`` (1 .. cost - 1) and the remaining
    ``cost - n_clauses`` budget is handed to ``find_iter`` with a fresh solver.
    The first non-``None`` result is returned. The loop never terminates if
    no solution exists.
    """
    domain = problem.domain
    all_indices = range(len(data))
    cost = 1
    while True:
        for n_clauses in range(1, cost):
            with smt.Solver() as solver:
                candidate = find_iter(data, domain, all_indices, find_cnf,
                                      solver, n_clauses, cost - n_clauses)
            if candidate is not None:
                return candidate
        cost += 1
Esempio n. 6
0
 def reduce(self, node_id, variables=None, fast=None):
     """Reduce the node ``node_id`` from ``self.pool`` and return the result's id.

     Args:
         node_id: id of the node to reduce, looked up in ``self.pool``.
         variables: stored on ``self.variables`` before reducing; defaults to
             ``self._get_variables(node_id)``.
         fast: overrides ``self.fast`` when not ``None``. When truthy,
             ``self.consistent`` is an empty set for the duration of the
             call and is reset to ``None`` afterwards.

     Returns:
         ``node_id`` of the node returned by ``self._reduce``.
     """
     fast = self.fast if fast is None else fast
     if variables is None:
         variables = self._get_variables(node_id)
     self.variables = variables
     if fast:
         self.consistent = set()
     try:
         with smt.Solver() as solver:
             return self._reduce(self.pool.get_node(node_id), solver, fast).node_id
     finally:
         # Reset the per-call set even if _reduce raises, so a failed call
         # does not leave stale state behind for the next one.
         if fast:
             self.consistent = None
Esempio n. 7
0
def check_equal(solver_name, smt_vars, expr1, expr2, name_binding):
    """Return True iff ``expr1`` equals ``expr2`` under every assignment.

    The disequality ``expr1 != expr2`` is asserted in a QF_BV solver. If it is
    unsatisfiable, the expressions are equivalent. Otherwise the model values
    of ``smt_vars`` and the ``name_binding`` are logged one level below DEBUG,
    and False is returned.
    """
    with smt.Solver(solver_name, logic=QF_BV) as solver:
        disequality = expr1 != expr2
        solver.add_assertion(disequality.value)
        if solver.solve():
            raw_model = solver.get_model()
            counterexample = {
                name: raw_model[var.value] for name, var in smt_vars.items()
            }
            trace_level = logging.DEBUG - 1
            logging.log(trace_level, name_binding)
            logging.log(trace_level, counterexample)
            return False
        return True
Esempio n. 8
0
    def solve(
        self,
        solver_name: str = 'z3',
        custom_enumeration: tp.Mapping[type, tp.Callable] = {}
    ) -> tp.Union[None, RewriteRule]:
        """Search for a rewrite rule that satisfies ``self.formula``.

        Args:
            solver_name: pysmt solver backend to use.
            custom_enumeration: accepted for interface compatibility; this
                method body does not reference it.

        Returns:
            ``None`` when there are no bindings or the formula is
            unsatisfiable. Otherwise, the rule built by ``rr_from_solver``
            from the satisfying solver state.
        """
        if self.has_bindings:
            with smt.Solver(solver_name, logic=BV) as solver:
                solver.add_assertion(self.formula)
                if solver.solve():
                    return rr_from_solver(solver, self)
        return None
Esempio n. 9
0
    def reduce(self, node_id, variables=None, fast=None):
        """Reduce node ``node_id`` with an MathSAT solver, memoising by node id.

        Args:
            node_id: id of the node to reduce; also the cache key.
            variables: accepted for interface compatibility; this body does
                not use it (and it is not part of the cache key).
            fast: overrides ``self.fast`` when not ``None``. When truthy,
                ``self.consistent`` is an empty set for the duration of the
                call and is reset to ``None`` afterwards.

        Returns:
            ``node_id`` of the node returned by ``self._reduce``. A cached
            result is returned directly, without consulting ``fast``.
        """
        key = node_id
        if key in self.reduce_cache:
            return self.reduce_cache[key]

        fast = self.fast if fast is None else fast
        if fast:
            self.consistent = set()
        try:
            with smt.Solver(name="msat") as solver:
                result_id = self._reduce(node_id, self.pool.get_node(node_id),
                                         solver, fast).node_id
        finally:
            # Reset the per-call set even if _reduce raises, so a failed call
            # does not leave stale state behind for the next one.
            if fast:
                self.consistent = None
        self.reduce_cache[key] = result_id
        return result_id
Esempio n. 10
0
def test_convert_support():
    """Round-trip a mixed Boolean/real formula through SDD compilation.

    Real inequalities are abstracted into literals, and the abstraction is
    compiled to an SDD. The formula recovered from the SDD must be logically
    equivalent to the original. This is checked by asserting that their
    non-equivalence is UNSAT.
    """
    x = smt.Symbol("x", smt.REAL)
    y = smt.Symbol("y", smt.REAL)  # declared alongside x; not used below
    a = smt.Symbol("a", smt.BOOL)
    original = (x < 0) | (~a & (x < -1)) | smt.Ite(a, x < 4, x < 8)

    # Replace the inequalities by abstract literals, compile, then recover.
    env, abstracted, literal_info = extract_and_replace_literals(original)
    sdd = compile_to_sdd(formula=abstracted, literals=literal_info, vtree=None)
    recovered = recover_formula(sdd_node=sdd, literals=literal_info, env=env)

    with smt.Solver() as solver:
        solver.add_assertion(~smt.Iff(original, recovered))
        is_sat = solver.solve()
        assert not is_sat, f"Expected UNSAT but found model {solver.get_model()}"
Esempio n. 11
0
    def __init__(self, env, n_states, solver_name):
        """Set up a Boolean-SAT helper and a policy runner for ``env``.

        Except for "Copy-v0", all the algorithmic problems require some
        memory: they cannot be solved by a policy that only sees the current
        input character.

        States are a generic way to add the necessary expressivity. The model
        decides which way to move and which character to write based on the
        current input character *and* the current state. At each timestep it
        also makes a third decision, based on the current character/state:
        which state to move to next.

        One state (that is, no memory) is enough for Copy-v0. Two states are
        enough for RepeatCopy, DuplicatedInput and Reverse. It is not clear how
        many are needed for the addition environments, but it is at least 3.
        Those environments use ternary numbers, and the model at least needs
        to know which of 3 possible digits it is adding to the digit it is
        currently looking at.

        Args:
            env: must be an ``AlgorithmicEnv`` instance (checked by assert).
            n_states: number of internal states available to the policy.
            solver_name: pysmt solver backend name.
        """
        assert isinstance(env, gym.envs.algorithmic.algorithmic_env.AlgorithmicEnv)
        # Quantifier-free Boolean logic: the simplest available, and all we need.
        bool_solver = sc.Solver(name=solver_name, logic=logics.QF_BOOL)
        self.helper = BoolSatHelper(bool_solver, env, n_states)
        self.runner = AlgorithmicPolicyRunner(self.helper, env)
Esempio n. 12
0
def test_op(op: WrappedOp, M: int, N: int, m: int, n: int):
    """Return whether input bit ``m`` of ``op`` can influence output bit ``n``.

    ``x`` is an unconstrained ``M``-bit SMT bit-vector. The inputs used are
    ``x`` with bit ``m`` forced to 0 and ``x`` with bit ``m`` forced to 1. The
    solver checks whether ``op``'s outputs on those two inputs can differ at
    bit ``n`` (or at all, for single-bit outputs).

    Args:
        op: the wrapped operation under test.
        M: input width in bits.
        N: output width in bits (must be 1 when ``op`` returns an ``SMTBit``).
        m: input bit index, ``0 <= m < M``.
        n: output bit index, ``0 <= n < N``.

    Returns:
        The result of ``solver.solve()``. It is true when some input makes the
        two outputs differ, i.e. bit ``m`` affects bit ``n``.

    Raises:
        TypeError: if ``op`` is not a ``WrappedOp``.
        ValueError: if ``m`` or ``n`` is out of range.
    """
    if not isinstance(op, WrappedOp):
        raise TypeError(op)
    if m not in range(M):
        raise ValueError((m, M))
    if n not in range(N):
        raise ValueError((n, N))
    x = ht.SMTBitVector[M]()
    mask = ht.SMTBitVector[M](1 << m)
    # Evaluate op once per input and reuse the results below.
    out_bit_clear = op(x & ~mask)
    out_bit_set = op(x | mask)
    assert type(out_bit_clear) is type(out_bit_set)
    if isinstance(out_bit_clear, ht.SMTBit):
        # n in range(N) was validated above, so N == 1 forces n == 0.
        assert N == 1
        same = out_bit_clear == out_bit_set
    else:
        same = out_bit_clear[n] == out_bit_set[n]
    with smt.Solver("z3", logic=pysmt.logics.BV) as solver:
        solver.add_assertion((~same).value)
        return solver.solve()
    def integrate(self, domain, convex_bounds: List[LinearInequality], polynomial: Polynomial) -> float:
        """Integrate ``polynomial`` over ``convex_bounds`` with the external ``integrate`` tool.

        The bounds are scaled to integers and written (with a ``.hrep.latte``
        suffix) as a row matrix. The polynomial's monomials are written (with a
        ``.poly.latte`` suffix) as a nested list. The ``integrate`` command is
        run through the shell, and its output is parsed with ``self.pattern``
        as a fraction.

        Args:
            domain: supplies ``real_vars``, which fixes the column order.
            convex_bounds: linear inequalities whose conjunction is the region.
            polynomial: integrand; ``poly_dict`` maps monomial keys to
                coefficients.

        Returns:
            The integral as a float. Returns 0.0 when the output does not match
            ``self.pattern``. Also returns 0.0 when the command fails and
            ``get_model()`` on the bounds raises ``InternalSolverError``.

        Raises:
            CalledProcessError: re-raised when the command fails and a model
                of the bounds can be obtained.
        """
        # TODO Use power of linear forms?
        b_geq_a = []
        # SMT conjunction of the integer-scaled bounds; only used below to
        # probe the region if the external command fails.
        formula = smt.TRUE()
        for bound in convex_bounds:
            integer_bound = bound.scale_to_integer()
            formula &= integer_bound.to_smt()
            # Row [b, -a_v for v in real_vars]; the name suggests b >= a . x.
            b_geq_a.append([integer_bound.b()] + [-integer_bound.a(v) for v in domain.real_vars])

        # (coefficient as a Fraction, exponent vector) for each monomial.
        monomials = [(Fraction(value).limit_denominator(), self.key_to_exponents(domain, key))
                     for key, value in polynomial.poly_dict.items()]

        # NOTE(review): bounds_file/poly_file are passed to open() and formatted
        # into the command line, so this TemporaryFile must yield a path. The
        # stdlib tempfile.TemporaryFile yields a file object — presumably this
        # is a project helper; confirm.
        with TemporaryFile(suffix=".hrep.latte") as bounds_file:
            with TemporaryFile(suffix=".poly.latte") as poly_file:
                # Header "<rows> <columns>", then one space-separated row per bound.
                with open(bounds_file, "w") as bounds_ref:
                    print("{} {}".format(len(b_geq_a), len(domain.real_vars) + 1), file=bounds_ref)
                    print(*[" ".join(map(str, e)) for e in b_geq_a], sep="\n", file=bounds_ref)

                # Monomials as [[coef,[e1,...,en]],[coef,[...]],...].
                with open(poly_file, "w") as poly_ref:
                    print("[{}]".format(",".join("[{},[{}]]".format(m[0], ",".join(map(str, m[1])))
                                                 for m in monomials)), file=poly_ref)

                command = "integrate --valuation=integrate {} --monomials={} {}"\
                    .format(self.algorithm, poly_file, bounds_file)
                try:
                    output = check_output(command, shell=True, stderr=DEVNULL).decode()
                except CalledProcessError:
                    # The tool failed. Probe the bounds with an SMT solver. If
                    # get_model() raises InternalSolverError (presumably because
                    # the region is empty), report an integral of 0.0; otherwise
                    # re-raise the CalledProcessError via the bare raise below.
                    with smt.Solver() as solver:
                        solver.add_assertion(formula)
                        solver.solve()
                        try:
                            solver.get_model()
                        except InternalSolverError:
                            return 0.0
                    raise
                match = re.search(self.pattern, output)
                if not match:
                    return 0.0
                # Groups 1 and 2 are read as numerator and denominator.
                return float(Fraction(int(match.group(1)), int(match.group(2))))
Esempio n. 14
0
def search_in_hierarchy_lit(data, domain, nodes, halflines):
    """Conjoin one clause per node, found with ``search_in_node``, within a budget.

    A node whose clause is already fixed (``hl_min == hl_max``) reuses it, and
    ``hl_max`` is charged against ``halflines``. For any other node, budgets
    ``h`` from ``node.hl_min`` up to the remaining ``halflines`` are tried in
    order. The first success fixes ``node.hl_max``/``node.clause``, is charged
    against the budget and is saturated. Every failed budget raises
    ``node.hl_min``.

    Returns:
        ``smt.And`` of the per-node clauses, or ``None`` if some node has no
        clause within the remaining budget.
    """
    clauses = []
    for node in nodes:
        if node.hl_min == node.hl_max:
            clauses.append(node.clause)
            halflines -= node.hl_max
            continue
        clause = None
        for h in range(node.hl_min, halflines + 1):
            # search_in_node opens its own solvers where it needs them, so no
            # solver is created here.
            clause = search_in_node(data, domain, node, h)
            if clause is not None:
                node.hl_max = h
                node.clause = clause
                clauses.append(clause)
                halflines -= h
                saturate(data, domain, node)
                break
            node.hl_min += 1
        if clause is None:
            return None
    return smt.And(clauses)
Esempio n. 15
0
def search_in_hierarchy(data, domain, nodes, halflines):
    """Conjoin one clause per node, found with ``find_clause``, within a budget.

    A node whose clause is already fixed (``hl_min == hl_max``) reuses it, and
    its cost is charged against ``halflines``. Any other node is solved
    against all positive examples plus the node's own examples. Budgets from
    ``node.hl_min`` up to the remaining ``halflines`` are tried in order, each
    with a fresh solver. The first success fixes ``node.hl_max`` and
    ``node.clause`` and is saturated. Every failed budget raises
    ``node.hl_min``.

    Returns:
        ``smt.And`` of the per-node clauses, or ``None`` if some node has no
        clause within the remaining budget.
    """
    found = []
    for node in nodes:
        if node.hl_min == node.hl_max:
            # Clause size already pinned down: reuse it and charge its cost.
            found.append(node.clause)
            halflines -= node.hl_max
            continue

        # Positive examples followed by the examples owned by this node.
        examples = [i for i in range(len(data)) if data[i][1]]
        examples.extend(node.to_gen())

        clause = None
        for budget in range(node.hl_min, halflines + 1):
            with smt.Solver() as solver:
                clause = find_iter(data, domain, examples, find_clause, solver, budget)
                if clause is not None:
                    node.hl_max = budget
                    node.clause = clause
                    found.append(clause)
                    halflines -= budget
                    saturate(data, domain, node)
                    break
            node.hl_min += 1
        if clause is None:
            return None
    return smt.And(found)
Esempio n. 16
0
File: mapper.py Progetto: Kuree/peak
def gen_mapping(
        peak_class: Peak,
        bv_isa: tp.Type[ISABuilder],  # This is currently a hack that will be removed.
        smt_isa: tp.Type[ISABuilder],
        coreir_module: coreir.ModuleDef,
        coreir_model: tp.Callable,
        max_mappings: int,
        *,
        verbose: bool = False,
        solver_name: str = 'z3',
        constraints=()):
    """Yield mappings of ``coreir_module`` onto instructions of ``peak_class``.

    Every ``smt_isa`` instruction accepted by all ``constraints`` is evaluated
    symbolically on every binding of the module's inputs to same-width peak
    inputs. Unmatched peak inputs are bound to constant 0. Each peak output
    whose SMT type matches ``coreir_model``'s output is checked for equivalence
    with a QF_BV solver, and every equivalent pair yields a mapping.

    Args:
        peak_class: PEAK class; instantiated with no arguments.
        bv_isa: instruction set from which the reported instruction is taken.
        smt_isa: instruction set used for the symbolic evaluation.
        coreir_module: module whose input/output ports are matched.
        coreir_model: callable computing the module's SMT output from inputs.
        max_mappings: stop after this many mappings (none if ``<= 0``).
        verbose: print instruction/binding progress.
        solver_name: pysmt solver backend name.
        constraints: predicates every instruction must satisfy.

    Yields:
        dicts with ``instruction``, ``output_map`` (``{"out": peak output
        name}``) and ``input_map``. ``input_map`` maps a core input name (or
        ``"0"`` for a constant) to a peak input name.
    """
    peak_inst = peak_class()  # This cannot take any args

    peak_inputs = _convert_io_types(peak_class.__call__._peak_inputs_)
    peak_outputs = _convert_io_types(peak_class.__call__._peak_outputs_)

    core_inputs = {
        k if k != 'in' else 'in_': v.size
        for k, v in coreir_module.type.items() if v.is_input()
    }
    core_outputs = {
        k: v.size
        for k, v in coreir_module.type.items() if v.is_output()
    }
    core_smt_vars = {k: SMTBitVector[v]() for k, v in core_inputs.items()}
    core_smt_expr = coreir_model(**core_smt_vars)

    # Generate all possible assignments of core inputs / constants (None) to
    # peak inputs, grouped by bit width.
    core_inputs_by_size = _group_by_value(core_inputs)
    peak_inputs_by_size = _group_by_value(peak_inputs)

    assert core_inputs_by_size.keys() <= peak_inputs_by_size.keys()
    for k in core_inputs_by_size:
        assert len(core_inputs_by_size[k]) <= len(peak_inputs_by_size[k])

    possible_matching = {}
    for size, pi in peak_inputs_by_size.items():
        ci = core_inputs_by_size.setdefault(size, [])
        # Pad with None (bound to constant 0 below) so every peak input pairs up.
        ci = list(it.chain(ci, it.repeat(None, len(pi) - len(ci))))
        assert len(ci) == len(pi)
        for perm in it.permutations(pi):
            possible_matching.setdefault(size, []).append(list(zip(ci, perm)))

    del core_inputs_by_size
    del peak_inputs_by_size

    bindings = [
        list(it.chain(*combo))
        for combo in it.product(*possible_matching.values())
    ]
    found = 0
    if max_mappings <= found:
        return

    def f_fun(inst):
        return all(constraint(inst) for constraint in constraints)

    bv_isa_list = list(filter(f_fun, bv_isa.enumerate()))
    smt_isa_list = list(filter(f_fun, smt_isa.enumerate()))
    smt_isa_len = len(smt_isa_list)

    for ii, smt_inst in enumerate(smt_isa_list):
        if verbose:
            # Progress is reported against the list actually being iterated.
            print(f"inst {ii+1}/{smt_isa_len}")
            print(smt_inst)

        for bi, binding in enumerate(bindings):
            binding_dict = {
                k: core_smt_vars[v]
                if v is not None else SMTBitVector[peak_inputs[k]](0)
                for v, k in binding
            }
            name_binding = {k: v if v is not None else 0 for v, k in binding}
            if verbose:
                print(f"binding {bi+1}/{len(bindings)}")

            rvals = peak_inst(smt_inst, **binding_dict)
            if not isinstance(rvals, tuple):
                rvals = rvals,

            for idx, bv in enumerate(rvals):
                if not (isinstance(bv, (SMTBit, SMTBitVector))
                        and bv.value.get_type() == core_smt_expr.value.get_type()):
                    continue
                # UNSAT of the disequality means the outputs always agree. The
                # solver is closed before yielding, so it is not held open
                # while the consumer runs.
                with smt.Solver(solver_name, logic=QF_BV) as solver:
                    solver.add_assertion((bv != core_smt_expr).value)
                    equivalent = not solver.solve()
                if not equivalent:
                    continue

                # Create output and input map
                output_map = {
                    "out":
                    list(peak_class.__call__._peak_outputs_.items())[idx][0]
                }
                input_map = {}
                for k, v in name_binding.items():
                    if v == 0:
                        v = "0"
                    elif v == "in_":
                        v = "in"
                    input_map[v] = k

                # NOTE(review): assumes bv_isa and smt_isa enumerate (and filter
                # to) instructions in the same order, so index ii lines up.
                yield dict(instruction=bv_isa_list[ii],
                           output_map=output_map,
                           input_map=input_map)
                found += 1
                if found >= max_mappings:
                    return
Esempio n. 17
0
from pywmi import Domain
from pywmi.engines.xsdd.smt_to_sdd import (
    SddConversionWalker,
    recover_formula,
    compile_to_sdd,
)
from pywmi.smt_print import pretty_print
from pywmi.smt_math import PolynomialAlgebra

# pysdd is optional: without it SddManager is None, and every test in this
# module is skipped through ``pytestmark``.
try:
    from pysdd.sdd import SddManager
except ImportError:
    SddManager = None

# Probe once whether pysmt can create a default solver on this machine.
smt_solver_available = True
try:
    with smt.Solver() as solver:
        pass
except NoSolverAvailableError:
    smt_solver_available = False

pytestmark = pytest.mark.skipif(
    SddManager is None, reason="pysdd is not installed")


@pytest.mark.skip(
    reason="Outdated test, SddConversionWalker only supports logical operations"
)
def test_convert_weight():
    x, y = smt.Symbol("x", smt.REAL), smt.Symbol("y", smt.REAL)
    a = smt.Symbol("a", smt.BOOL)
    weight_function = (smt.Ite(
Esempio n. 18
0
    (binding == 1).ite(
        inst.operand_1.match(T),
        (binding == 2).ite(
            inst.operand_1.match(Word),
            (binding == 4).ite(
                inst.operand_1.match(T),
                (binding == 8) & inst.operand_1.match(Word)
            )
        )
    )
)

# Symbolic result of simulating the (symbolic) instruction; shared by all targets.
sim_expr = sim(inst)
for target in targets:
    target_expr = target(x, y)
    with smt.Solver('z3', logic=BV) as solver:
        solver.add_assertion(precondition.value)
        # Require the simulation to equal the target for all x, y and free_bit.
        # The remaining free symbols (opcode_bv, b, tag_bv, binding, read back
        # below) are presumably the instruction encoding being solved for.
        constraint = smt.ForAll([x.value, y.value, free_bit.value], (sim_expr == target_expr).value)
        solver.add_assertion(constraint)
        if solver.solve():
            # Decode and report the satisfying encoding.
            opcode_val = solver.get_value(opcode_bv.value).constant_value()
            b_val = solver.get_value(b.value).constant_value()
            tag_val = solver.get_value(tag_bv.value).constant_value()
            binding_val = solver.get_value(binding.value).constant_value()
            print(f'mapping found for {target.__name__}')
            print(f'binding: {binding_val}')
            print(f'opcode: {opcode_asm.disassemble(opcode_val)}')
            print(f'tag: {operand_1_asm.disassemble_tag(tag_val)}')
            print(f'b: {b_val}')
            print()
        else:
Esempio n. 19
0
 def __init__(self, formula, domain):
     """Store ``domain`` and a pysmt solver preloaded with ``formula``.

     ``formula`` is also forwarded to the base-class constructor. The solver is
     kept on ``self.solver`` and is not closed here.
     """
     super().__init__(formula)
     self.domain = domain
     preloaded = smt.Solver()
     preloaded.add_assertion(formula)
     self.solver = preloaded
Esempio n. 20
0
def _report_counterexample(solver, belts, junctions):
    """Print the lane inputs/outputs of the solver's model and replay them.

    Reads each beltway's first-segment ``rho`` and last-segment ``v`` from the
    current model, prints them, and passes them to ``check_balancer_flow``.
    """
    model = solver.get_model()
    inputs = tuple(model.get_py_value(beltway[0].rho) for beltway in belts)
    outputs = tuple(model.get_py_value(beltway[-1].v) for beltway in belts)
    print(f'Input lane densities: ({", ".join(map(str, inputs))})')
    print(f'Output lane velocities: ({", ".join(map(str, outputs))})')
    check_balancer_flow(junctions, inputs, outputs)


def prove_properties(junctions):
    """Prove or refute throughput and balance properties of a belt balancer.

    ``junctions`` are ``(x, y1, y2)`` triples. The flow constraints come from
    ``balancer_flow_formula``. Each property is checked by asserting its
    negation on top of those constraints. UNSAT means the property always
    holds; otherwise a counterexample is printed and replayed.

    Properties checked, in order:
      * throughput-unlimited: total output flux equals
        ``min(total input density, total output velocity)``;
      * input-balanced: every input lane's flux equals the largest input
        flux or that lane's density ``rho``;
      * output-balanced: every output lane's flux equals the largest output
        flux or that lane's velocity ``v``.
    """
    width = max(max(y1 for _, y1, _ in junctions),
                max(y2 for _, _, y2 in junctions)) + 1
    length = max(x for x, _, _ in junctions)
    formula, belts = balancer_flow_formula(junctions, width, length)

    supply = s.Plus(beltway[0].rho for beltway in belts)
    demand = s.Plus(beltway[-1].v for beltway in belts)
    max_theoretical_throughput = s.Min(supply, demand)
    actual_throughput = s.Plus(beltway[-1].flux for beltway in belts)

    max_out_flux = s.Max(b[-1].flux for b in belts)
    fully_balanced_output = s.And([
        s.Or(s.Equals(b[-1].flux, max_out_flux), s.Equals(b[-1].flux, b[-1].v))
        for b in belts
    ])

    max_in_flux = s.Max(b[0].flux for b in belts)
    fully_balanced_input = s.And([
        s.Or(s.Equals(b[0].flux, max_in_flux), s.Equals(b[0].flux, b[0].rho))
        for b in belts
    ])

    # (negated property, message if it holds, message if refuted)
    checks = (
        (s.Not(s.Equals(max_theoretical_throughput, actual_throughput)),
         "It's throughput-unlimited!",
         "It's throughput-limited; here's an example:"),
        (s.Not(fully_balanced_input),
         "It's input-balanced!",
         "It's input-imbalanced; here's an example:"),
        (s.Not(fully_balanced_output),
         "It's output-balanced!",
         "It's output-imbalanced; here's an example:"),
    )

    with s.Solver() as solver:
        solver.add_assertion(s.And(formula))
        for negated_property, holds_msg, fails_msg in checks:
            # Scope each negation so it does not leak into the next check.
            solver.push()
            solver.add_assertion(negated_property)
            # UNSAT of the negation: the property holds in every model.
            if not solver.solve():
                print(holds_msg)
            else:
                print(fails_msg)
                _report_counterexample(solver, belts, junctions)
            solver.pop()
Esempio n. 21
0
    def perform(self):
        """Run the key-recovery loop on the oracle and obfuscated circuits.

        Only ``.bench`` inputs are supported; anything else logs a critical
        message and exits the process. ``self.solver_obf`` holds two circuit
        copies (keys k0 and k1, constrained to differ). Each loop iteration
        does one of the following:

        * If a DIP (presumably "distinguishing input pattern") exists under the
          current ``assumptions``, it is queried, checked against the oracle
          and added as a constraint.
        * Otherwise the circuit is unrolled one more frame. When the unroll
          depth reaches ``self.boundary``, ``ce_check``/``umc_check`` decide
          whether to stop, continue, or widen the boundary by ``self.step``.
        * Otherwise the search stops (uniqueness reached or ``self.stop``
          hit) and the keys are printed.

        Returns:
            ``True`` on every visible return path.
        """
        # process inputs
        # NOTE(review): only self.args.b is checked for '.bench', but
        # self.args.o is parsed as a .bench file as well.
        if '.bench' in self.args.b:
            self.obf_cir = bench2circuit(self.args.o)
            self.oracle_cir = bench2circuit(self.args.b)
        else:
            logging.critical('verilog input is disabled! use main_formal')
            # NOTE(review): relies on the site-provided exit(); sys.exit() is
            # the portable choice.
            exit()
        #     self.oracle_ast = ASTWrapper(parse_verilog(self.args.b), self.args.b)
        #     self.obf_ast = ASTWrapper(parse_verilog(self.args.o), self.args.o)
        #
        #     self.oracle_cir = self.oracle_ast.get_circuit(check_correctness=False, correct_order=False)
        #     self.obf_cir = self.obf_ast.get_circuit(check_correctness=False, correct_order=False)

        self.oracle_cir.create_ce_circuit()
        self.obf_cir.create_ce_circuit()
        sort_circuits(self.oracle_cir, self.obf_cir)

        # perform attack
        # Only solver_obf is used directly in this method; the other two are
        # presumably used by helper methods (query_oracle etc.) — confirm.
        self.solver_obf = pystm.Solver(name=self.solver_name)
        self.solver_key = pystm.Solver(name=self.solver_name)
        self.solver_oracle = pystm.Solver(name=self.solver_name)

        logging.warning('initial value for boundary={}, step={}, stop={}'.format(self.boundary, self.step, self.stop))
        logging.warning('solver={}'.format(self.solver_name))

        self.attack_formulas = FormulaGenerator(self.oracle_cir, self.obf_cir)

        # add k0 != k1
        self.solver_obf.add_assertion(self.attack_formulas.key_inequality_ckt)

        # assumptions for inequality of dip generator outputs
        assumptions = self.attack_formulas.dip_gen_assumption(1)

        # get initial states and the first copy of the circuit
        # (frames 0 and 1; each frame adds both copies c0/c1)
        for i in range(2):
            c0, c1 = self.attack_formulas.obf_ckt_at_frame(i)
            for j in range(len(c0)):
                self.solver_obf.add_assertion(c0[j])
                self.solver_obf.add_assertion(c1[j])

        while 1:
            # query dip generator
            if self.solver_obf.is_sat(assumptions):
                dis_boolean = self.query_dip_generator()

                # One formula per unrolled frame, tagged '@1', '@2', ...
                dis_formula = []
                for i in range(1, len(dis_boolean) + 1):
                    dis_formula.append(get_formulas(self.obf_cir.input_wires, dis_boolean[i-1], '@{}'.format(i)))
                logging.info(dis_formula)

                # Ask the oracle for the correct outputs and constrain both keys.
                dis_out = self.query_oracle(dis_formula)
                self.add_dip_checker(dis_boolean, dis_out)

                self.iteration += 1
                logging.warning('iteration={}, depth={}'.format(self.iteration, self.unroll_depth))
                self.highest_depth = self.unroll_depth
            else:
                if (self.solver_obf.is_sat(pystm.TRUE()) or self.iteration == 0) and self.unroll_depth < self.stop:
                    # two agreeing keys are found, but no dip can be found
                    # also it should keep unrolling the circuit until at least one dip is found
                    #  then it can decide on uc success
                    if self.unroll_depth == self.boundary:
                        logging.warning('uc failed')

                        # check ce
                        if self.ce_check():
                            return True
                        elif self.umc_check():
                            continue
                        else:
                            # increase boundary
                            # self.unroll_depth += 1
                            self.boundary += self.step
                    else:
                        # increase unroll depth
                        self.unroll_depth += 1
                        assumptions = self.attack_formulas.dip_gen_assumption(self.unroll_depth)

                        # Add both circuit copies for the new frame.
                        c0, c1 = self.attack_formulas.obf_ckt_at_frame(self.unroll_depth)
                        for i in range(len(c0)):
                            self.solver_obf.add_assertion(c0[i])
                            self.solver_obf.add_assertion(c1[i])
                        logging.warning('increasing unroll depth to {}'.format(self.unroll_depth))
                elif self.unroll_depth >= self.stop:
                    logging.warning('stopped at {}'.format(self.stop))
                    self.print_keys()
                    return True
                else:
                    # key is unique
                    logging.warning('uc successful')
                    self.print_keys()
                    return True
Esempio n. 22
0
    else:
        constraints.append(SMT.GT(var, SMT.Int(0)))

# variables is laid out as consecutive triples (a_i, b_i, c_i) =
# (variables[i], variables[i + 1], variables[i + 2]) for i = 0, 3, ..., 3*cap - 3.
for i in range(0, 3 * cap, 3):
    # add requirement that either a_i = 0 or c_i = 0
    constraints.append(
        SMT.Or(SMT.Equals(variables[i], SMT.Int(0)),
               SMT.Equals(variables[i + 2], SMT.Int(0))))
    # add requirement that if a_i = 0 and c_i = 0 then b_i = 1,
    # encoded as (a_i > 0 or c_i > 0) or b_i = 1
    # (equivalent assuming a_i, c_i >= 0 — confirm against the constraints above)
    constraints.append(
        SMT.Or(
            SMT.Or(SMT.GT(variables[i], SMT.Int(0)),
                   SMT.GT(variables[i + 2], SMT.Int(0))),
            SMT.Equals(variables[i + 1], SMT.Int(1))))

solver = SMT.Solver(name="z3")
# Echo every constraint while asserting it in the solver.
print("constraints:")
for c in constraints:
    print(c)
    solver.add_assertion(c)

# add equations

# 24727*a_1 + 75235*b_1 + 50508*c_1 = 75235*a_2 + 125743*b_2 + 176251*c_2
solver.add_assertion(
    SMT.Equals(
        SMT.Plus(
            SMT.Plus(SMT.Times(SMT.Int(125743), variables[4]),
                     SMT.Times(SMT.Int(75235), variables[3])),
            SMT.Times(SMT.Int(176251), variables[5])),
        SMT.Plus(
Esempio n. 23
0
 def reduce(self, node_id, variables=None):
     """Reduce node ``node_id`` of ``self.pool`` and return the result's id.

     ``variables`` defaults to ``self._get_variables(node_id)`` and is stored
     on ``self.variables`` before reducing. A solver with backend name
     ``self.solver`` is open for the duration of ``self._reduce``.
     """
     self.variables = (
         self._get_variables(node_id) if variables is None else variables)
     with smt.Solver(name=self.solver) as solver:
         reduced = self._reduce(self.pool.get_node(node_id), solver)
     return reduced.node_id
Esempio n. 24
0
def gen_mapping(
        peak_class: Peak,
        bv_isa: BoundMeta,  # This is currently a hack that will be removed.
        smt_isa: BoundMeta,
        coreir_module: coreir.ModuleDef,
        coreir_model: tp.Callable,
        max_mappings: int,
        *,
        verbose: int = 0,
        solver_name: str = 'z3',
        constraints=()):
    """Yield mappings of ``coreir_module`` onto instructions of ``peak_class``.

    Every ``smt_isa`` instruction accepted by all ``constraints`` is evaluated
    symbolically on every binding of the module's inputs to same-typed peak
    inputs. Unmatched peak inputs are bound to constant 0. Each peak output
    whose SMT type matches ``coreir_model``'s output is checked with
    ``check_equal``, and every equivalent pair yields a mapping.

    Args:
        peak_class: PEAK class; instantiated with no arguments.
        bv_isa: instruction set from which the reported instruction is taken.
        smt_isa: instruction set used for the symbolic evaluation.
        coreir_module: module whose input/output ports are matched.
        coreir_model: callable computing the module's SMT output from inputs.
        max_mappings: stop after this many mappings (none if ``<= 0``).
        verbose: 1 sets the root logger to DEBUG, 2 to DEBUG - 1 (which also
            enables ``check_equal``'s counterexample logging).
        solver_name: pysmt solver backend name.
        constraints: predicates every instruction must satisfy.

    Yields:
        dicts with ``instruction``, ``output_map`` (``{"out": peak output
        name}``) and ``input_map``. ``input_map`` maps a core input name (or
        ``"0"`` for a constant) to a peak input name.
    """
    if verbose == 1:
        logging.getLogger().setLevel(logging.DEBUG)
    elif verbose == 2:
        logging.getLogger().setLevel(logging.DEBUG - 1)

    peak_inst = peak_class()  # This cannot take any args

    peak_inputs = _filter_io_types(peak_class.__call__._peak_inputs_)
    peak_outputs = _filter_io_types(peak_class.__call__._peak_outputs_)

    core_inputs = {
        k if k != 'in' else 'in_': SMTBitVector[v.size]
        for k, v in coreir_module.type.items() if v.is_input()
    }
    core_outputs = {
        k: SMTBitVector[v.size]
        for k, v in coreir_module.type.items() if v.is_output()
    }
    core_smt_vars = {k: v() for k, v in core_inputs.items()}
    core_smt_expr = coreir_model(**core_smt_vars)

    # Generate all possible assignments of core inputs / constants (None) to
    # peak inputs, grouped by type.
    core_inputs_by_t = _group_by_value(core_inputs)
    peak_inputs_by_t = _group_by_value(peak_inputs)

    assert core_inputs_by_t.keys() <= peak_inputs_by_t.keys()
    for k in core_inputs_by_t:
        assert len(core_inputs_by_t[k]) <= len(peak_inputs_by_t[k])

    possible_matching = {}
    for t, pi in peak_inputs_by_t.items():
        ci = core_inputs_by_t.setdefault(t, [])
        # Pad with None (bound to constant 0 below) so every peak input pairs up.
        ci = list(it.chain(ci, it.repeat(None, len(pi) - len(ci))))
        assert len(ci) == len(pi)
        for perm in it.permutations(pi):
            possible_matching.setdefault(t, []).append(list(zip(ci, perm)))

    del core_inputs_by_t
    del peak_inputs_by_t

    bindings = [
        list(it.chain(*combo))
        for combo in it.product(*possible_matching.values())
    ]
    found = 0
    if max_mappings <= found:
        return

    def f_fun(inst):
        return all(constraint(inst) for constraint in constraints)

    logging.debug("Enumerating bv instructions")
    bv_isa_list = list(filter(f_fun, bv_isa.enumerate()))
    logging.debug("Enumerating smt instructions")
    smt_isa_list = list(filter(f_fun, smt_isa.enumerate()))
    smt_isa_len = len(smt_isa_list)

    logging.debug("Starting search")

    for ii, smt_inst in enumerate(smt_isa_list):
        # Progress is reported against the list actually being iterated.
        logging.debug(f"inst {ii+1}/{smt_isa_len}")
        logging.debug(smt_inst)

        for binding in bindings:
            binding_dict = {
                k: core_smt_vars[v] if v is not None else peak_inputs[k](0)
                for v, k in binding
            }
            name_binding = {k: v if v is not None else 0 for v, k in binding}

            rvals = peak_inst(smt_inst, **binding_dict)
            if not isinstance(rvals, tuple):
                rvals = rvals,

            for idx, bv in enumerate(rvals):
                if not (isinstance(bv, (SMTBit, SMTBitVector))
                        and bv.value.get_type() == core_smt_expr.value.get_type()):
                    continue
                # check_equal opens and closes its own solver; no outer solver
                # is needed here.
                if not check_equal(solver_name, binding_dict, bv, core_smt_expr, name_binding):
                    continue

                # Create output and input map
                output_map = {"out": list(peak_class.__call__._peak_outputs_.items())[idx][0]}
                input_map = {}
                for k, v in name_binding.items():
                    if v == 0:
                        v = "0"
                    elif v == "in_":
                        v = "in"
                    input_map[v] = k

                # NOTE(review): assumes bv_isa and smt_isa enumerate (and filter
                # to) instructions in the same order, so index ii lines up.
                yield dict(
                    instruction=bv_isa_list[ii],
                    output_map=output_map,
                    input_map=input_map
                )
                found += 1
                if found >= max_mappings:
                    return