コード例 #1
0
    def multi_distinguish(self, regexes):
        """Find one input string that separates as many regex pairs as possible.

        Builds a z3 ``Optimize`` problem over a single string ``dist``.  For
        each selected regex, a Boolean ``ro_i`` is constrained to equal
        "``dist`` is in that regex's language".  At least one pair of regexes
        must disagree on ``dist`` (hard constraint).  Every disagreeing pair is
        rewarded (soft constraint), so the solver maximises the number of
        separated pairs.

        Maximising over many regexes at once is very slow.  When more than
        four are given, four are chosen by a deterministic shuffle (seeded
        with ``'regex'``) and the remainder is returned untouched in
        ``others``.  The caller's ``regexes`` sequence is not reordered and
        the global ``random`` state is not reseeded.

        Args:
            regexes: Regex ASTs accepted by ``self._toz3.eval``.

        Returns:
            ``(dist_input, keep_if_valid, keep_if_invalid, others)``:

            * ``dist_input`` is the solver's string with surrounding quotes
              stripped.
            * ``keep_if_valid`` and ``keep_if_invalid`` partition the
              selected regexes by whether they accept ``dist_input``.
            * ``others`` are the regexes left out of the optimisation.

            Returns ``(None, None, None, None)`` when the check is not
            ``sat`` or fewer than two regexes are given (no pair exists).
        """
        start = time.time()
        if len(regexes) <= 4:
            selected_regexes = regexes
            others = []
        else:
            # A private generator seeded with the same value yields the same
            # permutation as random.seed('regex') + random.shuffle.  It does so
            # without resetting the global RNG.  Shuffling a copy leaves the
            # caller's sequence intact.
            shuffled = list(regexes)
            random.Random('regex').shuffle(shuffled)
            selected_regexes = shuffled[:4]
            others = shuffled[4:]

        if len(selected_regexes) < 2:
            # With no pair, "at least one pair disagrees" can never hold.
            return None, None, None, None

        solver = z3.Optimize()
        z3_regexes = [self._toz3.eval(regex) for regex in selected_regexes]
        dist = z3.String("distinguishing")

        # ro_z3[i] is true exactly when `dist` matches selected_regexes[i].
        ro_z3 = []
        for i, z3_regex in enumerate(z3_regexes):
            ro = z3.Bool(f"ro_{i}")
            ro_z3.append(ro)
            solver.add(ro == z3.InRe(dist, z3_regex))

        # A pair is separated when exactly one of its two regexes matches.
        pair_xors = [z3.Xor(ro_i, ro_j) for ro_i, ro_j in combinations(ro_z3, 2)]
        for pair_xor in pair_xors:
            solver.add_soft(pair_xor)
        solver.add(z3.Or(pair_xors))  # at least one regex is distinguished

        if solver.check() != z3.sat:
            return None, None, None, None

        print("took", round(time.time() - start, 2), "seconds")
        model = solver.model()
        dist_input = str(model[dist]).strip('"')
        keep_if_valid = []
        keep_if_invalid = []
        for regex, ro in zip(selected_regexes, ro_z3):
            if z3.is_true(model[ro]):
                keep_if_valid.append(regex)
            else:
                keep_if_invalid.append(regex)
        return dist_input, keep_if_valid, keep_if_invalid, others
コード例 #2
0
    def distinguish(self, models: List[List[Tuple]]):
        """Find a distinguishing input for the sets of conditions in ``models``.

        ``models`` is first deduplicated with ``keep_distinct``.  For each
        distinct model ``i``, a Boolean ``sat_m<i>`` and the capture-group
        integers ``c0, c1, ...`` are linked by
        ``self._get_sat_m_constraint`` (presumably ``sat_m<i>`` holds iff the
        capture values satisfy model ``i`` — see that helper).  At least one
        pair of models must disagree (hard constraint), and every disagreeing
        pair is rewarded (soft constraint).

        Each capture value found by the solver is spliced into
        ``self._valid_example`` at the position of its capture group.  It is
        left-padded with ``'0'`` to the width of the text it replaces.

        Args:
            models: Condition sets to tell apart.

        Returns:
            ``(valid_ex, keep_if_valid, keep_if_invalid)``.  The two lists
            partition the *distinct* models by the solver's value of their
            ``sat_m`` flag.  Returns ``(None, None, None)`` if the check is
            not ``sat``.

        Raises:
            ValueError: if the example no longer fully matches a capture
                regex, or the capture group did not participate in the match.
        """
        solver = z3.Optimize()
        distinct_models = keep_distinct(models)
        logger.info(f"Distinguishing {distinct_models}")
        cs = [z3.Int(f"c{cap_idx}") for cap_idx in range(len(self._capture_groups))]

        sat_ms = []
        for m_idx, model in enumerate(distinct_models):
            sat_m = z3.Bool(f"sat_m{m_idx}")
            sat_ms.append(sat_m)
            solver.add(self._get_sat_m_constraint(cs, model, sat_m))

        # Maximisation objective from multi-distinguish: reward every pair of
        # models that disagree, and require at least one such pair.
        pair_xors = [z3.Xor(m_i, m_j) for m_i, m_j in combinations(sat_ms, 2)]
        for pair_xor in pair_xors:
            solver.add_soft(pair_xor)
        solver.add(z3.Or(pair_xors))

        if solver.check() != z3.sat:
            logger.info("Indistinguishable")
            return None, None, None

        z3_model = solver.model()
        valid_ex = self._valid_example
        for c_idx, c_var in enumerate(cs):
            c_value = z3_model[c_var]
            if c_value is None:
                continue
            regex_str = self._interpreter.eval(
                self._regex, captures=[self._capture_groups[c_idx]])
            match = re.fullmatch(regex_str, valid_ex)
            if match is None or match.start(1) == -1:
                raise ValueError(
                    f"capture group {c_idx} not found in example {valid_ex!r}")
            # Splice at the group's own span.  str.replace would instead hit
            # the first equal substring anywhere in the example.
            cap_start, cap_end = match.span(1)
            c_val = str(c_value).rjust(cap_end - cap_start, '0')
            valid_ex = valid_ex[:cap_start] + c_val + valid_ex[cap_end:]

        # sat_ms[i] was built from distinct_models[i], so pair them directly
        # (indexing the undeduplicated `models` misattributes entries).
        keep_if_valid = []
        keep_if_invalid = []
        for sat_m, model in zip(sat_ms, distinct_models):
            if z3.is_true(z3_model[sat_m]):
                keep_if_valid.append(model)
            else:
                keep_if_invalid.append(model)
        logger.info(f"Dist. input: {valid_ex}")
        return valid_ex, keep_if_valid, keep_if_invalid
コード例 #3
0
def _convert_expr(e, variables_dict):
    if isinstance(e, (bool, int)):
        return e
    if not isinstance(e, Expr):
        raise TypeError()
    if isinstance(e, (BoolVar, IntVar)):
        return variables_dict[e.id]
    else:
        operands = list(
            map(lambda x: _convert_expr(x, variables_dict), e.operands))
        if e.op == Op.NEG:
            return -operands[0]
        elif e.op == Op.ADD:
            ret = operands[0]
            for i in range(1, len(operands)):
                ret = ret + operands[i]
            return ret
        elif e.op == Op.SUB:
            ret = operands[0]
            for i in range(1, len(operands)):
                ret = ret - operands[i]
            return ret
        elif e.op == Op.MUL:
            ret = operands[0]
            for i in range(1, len(operands)):
                ret = ret * operands[i]
            return ret
        elif e.op == Op.MOD:
            ret = operands[0]
            for i in range(1, len(operands)):
                ret = ret % operands[i]
            return ret
        elif e.op == Op.EQ:
            return operands[0] == operands[1]
        elif e.op == Op.NE:
            return operands[0] != operands[1]
        elif e.op == Op.LE:
            return operands[0] <= operands[1]
        elif e.op == Op.LT:
            return operands[0] < operands[1]
        elif e.op == Op.GE:
            return operands[0] >= operands[1]
        elif e.op == Op.GT:
            return operands[0] > operands[1]
        elif e.op == Op.NOT:
            return z3.Not(operands[0])
        elif e.op == Op.AND:
            return z3.And(operands)
        elif e.op == Op.OR:
            return z3.Or(operands)
        elif e.op == Op.XOR:
            return z3.Xor(operands[0], operands[1])
        elif e.op == Op.IFF:
            return operands[0] == operands[1]
        elif e.op == Op.IMP:
            return z3.Or(z3.Not(operands[0]), operands[1])
        elif e.op == Op.IF:
            return z3.If(operands[0], operands[1], operands[2])
        elif e.op == Op.ALLDIFF:
            return z3.Distinct(operands)
コード例 #4
0
ファイル: z3eval.py プロジェクト: elefthei/hyperkernel-hash
    def xor(self, ctx, return_type, a, atype, b, btype):
        # Exclusive-or of two operands that must share `return_type`.
        assert atype == return_type
        assert atype == btype

        if not (z3.is_bool(a) or z3.is_bool(b)):
            # Neither side is a z3 Bool (presumably bit-vectors): use `^`.
            return a ^ b

        # At least one side is a z3 Bool: coerce both and build a Bool Xor,
        # which only makes sense for a 1-wide result type.
        lhs, rhs = util.as_bool(a), util.as_bool(b)
        assert return_type.size() == 1
        return z3.Xor(lhs, rhs)
コード例 #5
0
def solve_lightbulb_problem(N, K):
    """Decide whether the lightbulb pairing constraints for (N, K) are satisfiable.

    For every bulb i, exactly one of bulb i and bulb (i + K) mod N is
    selected (an XOR constraint).  Returns whatever ``result_to_str`` maps
    the solver's check result to.
    """
    # b<i> is true when bulb i is chosen as the *second* bulb of a pair.
    bulbs = [z3.Bool(f"b{i}") for i in range(N)]

    solver = z3.Solver()
    # Every lightbulb is turned on exactly once.
    for idx, bulb in enumerate(bulbs):
        partner = bulbs[(idx + K) % N]
        solver.add(z3.Xor(bulb, partner))

    outcome = solver.check()
    return result_to_str[str(outcome)]
コード例 #6
0
def projection_test(n=5, m=7):
    """Cross-check two quantifier-elimination pipelines on a random polyhedron.

    ``n`` and ``m`` are forwarded to ``polyedre`` (presumably dimension and
    number of constraints — TODO confirm).

    * ``a``: existentially quantifies ``x0`` over the polyhedron's
      conjunction.  It round-trips the formula through text
      (``str`` -> ``get_lexema_list`` -> ``Parse``) and eliminates the
      quantifier with this project's ``cooper()`` step.
    * ``b``: projects the polyhedron with ISL (``isl_intersection`` +
      ``project``; presumably eliminating the same variable — confirm).
      It re-parses ISL's textual result and runs the same ``cooper()`` step.

    The two formulas are equivalent iff ``Xor(a, b)`` is unsatisfiable.

    Returns:
        bool: ``True`` when z3 reports ``unsat`` for ``Xor(a, b)``; ``False``
        for ``sat`` or ``unknown``.
    """
    P = polyedre(n=n, m=m)
    F1 = conjunction(P)
    print("Random polyhedron", F1)
    # Reference formula: Exists x0. F1
    E1 = z3.Exists(z3.Int("x0"), F1)
    #print("Expected formula after projection", E1)
    #t=z3.Tactic("qe")
    #Formula1 = t(E1)
    #print("Expected formula after projection, with z3 quantifier elimination", Formula1)
    # Pipeline 1: our own quantifier elimination applied to the reference.
    a = Parse(get_lexema_list(str(E1))).cooper().toZ3()
    #print("Expected formula after projection, with my implementation of quantifier elimination", a)

    # Pipeline 2: ISL's projection, serialized, parenthesized and re-parsed.
    F2 = isl_intersection(P)
    Fp = project(F2)
    Str = "(" + get_formula(Fp) + ")"

    LL = get_lexema_list(Str)
    E2 = Parse(LL)
    #print("Formula given by ISL after projection", E2.toZ3())
    #Formula2=t(E2.toZ3())

    #print("Formula given by ISL after projection, after z3 quantifier elimination", Formula2)
    b = E2.cooper().toZ3()
    #print("Formula given by ISL after projection, after my implementation of quantifier elimination", b)

    my_solver = z3.Solver()

    # a and b are equivalent iff no assignment makes exactly one of them true.
    my_solver.add(z3.Xor(a, b))
    r = my_solver.check() == z3.unsat
    print("Result with my implementation of quantifier elimination", r)
    return r

    # NOTE: everything below this point is unreachable (after `return r`);
    # kept only as the author's reference material.
    #z3_solver = z3.Then("smt","qe").solver()
    #z3_solver.add(z3.Xor(E1,E2.toZ3()))
    #print("Result with z3 quantifier elimination", z3_solver.check()==z3.unsat)

    #z3_solver = z3.Then("smt","qe").solver()
    #z3_solver.add(z3.Not(z3.Implies(E1,E2.toZ3())))
    #print("A => B", z3_solver.check()==z3.unsat)
    #z3_solver = z3.Then("smt", "qe").solver()
    #z3_solver.add(z3.Not(z3.Implies(E2.toZ3(),E1)))
    #print("B => A", z3_solver.check()==z3.unsat)
    '''SE=str(E)
	LLE=get_lexema_list(SE)
	Lol = Parse(LLE)
	print(Lol.toString())
	Unch = Lol.cooper()
	print("unch",Unch.toString())'''
    '''Fp=isl_intersection(P)
コード例 #7
0
def traverse(value: Value, vars: dict[str, Any]):
    """Evaluate a ``Value`` expression tree into a z3 term.

    * ``Variable`` leaves are looked up in ``vars`` by ``value.var``.
    * ``Unop``: ``__invert__`` becomes ``z3.Not``.  Any other operator is
      applied by evaluating its ``UNARY_OPS`` symbol in front of the operand.
    * ``Binop``: ``&``/``|``/``^`` become ``z3.And``/``Or``/``Xor``.
      ``a >> b`` means "a implies b" and ``a << b`` means "b implies a".
      Any other operator is applied via its ``BIN_OPS`` symbol.
    * ``Call``: the traversed callee is applied to the traversed arguments.
    * Anything else is returned unchanged (e.g. a literal).

    NOTE(review): the generic branches ``eval`` symbols taken from
    ``UNARY_OPS``/``BIN_OPS``.  This is only safe while those tables are
    fixed operator tokens — confirm they are never user-supplied.
    """
    if isinstance(value, Variable):
        return vars[value.var]

    if isinstance(value, Unop):
        if value.op == "__invert__":
            return z3.Not(traverse(value.arg, vars))
        source = f"{UNARY_OPS[value.op]} (__arg)"
        return eval(source, {"__arg": traverse(value.arg, vars)})

    if isinstance(value, Binop):
        op = value.op
        # Connectives whose left operand is traversed first.
        connectives = {
            "__and__": z3.And,
            "__or__": z3.Or,
            "__xor__": z3.Xor,
            "__rshift__": z3.Implies,  # a >> b  ==  a implies b
        }
        if op in connectives:
            lhs = traverse(value.left, vars)
            rhs = traverse(value.right, vars)
            return connectives[op](lhs, rhs)
        # Don't mix this up with `>>`: here the RIGHT side is the premise
        # (and is traversed first).
        if op == "__lshift__":
            premise = traverse(value.right, vars)
            return z3.Implies(premise, traverse(value.left, vars))
        source = f"(__left) {BIN_OPS[op]} (__right)"
        return eval(source, {
            "__left": traverse(value.left, vars),
            "__right": traverse(value.right, vars),
        })

    if isinstance(value, Call):
        callee = traverse(value.fn, vars)
        return callee(*(traverse(arg, vars) for arg in value.args))

    return value
コード例 #8
0
def keep_trying(spec, example_gen, code_maker, max_len):
    """Repeatedly synthesize candidate code until ``example_gen`` stops.

    Each round does the following:

    1. Asks ``code_maker`` for a candidate from the examples gathered so far.
    2. Converts the candidate to a z3 query ``prog``.
    3. Sends ``(z3.Xor(spec, prog), code_maker.using_examples)`` into
       ``example_gen``.  ``Xor(spec, prog)`` is satisfiable exactly where the
       spec and the candidate disagree.  Presumably the generator yields such
       a disagreeing input as the next example (TODO confirm against the
       generator).

    Args:
        spec: z3 formula describing the desired behavior.
        example_gen: Generator primed with ``next()`` and then driven with
            ``send()``.  Its ``StopIteration`` is treated as "no
            disagreeing example exists".
        code_maker: Callable taking the example list and returning a
            candidate (falsy when it gives up).  It must also support
            ``len()`` and expose ``using_examples``.
        max_len: The loop runs while ``len(code_maker) <= max_len``.

    Returns:
        The accepted candidate code, or ``False`` if ``code_maker`` gives up
        or the length bound is exceeded.
    """
    examples = [next(example_gen)]  # prime the generator; first yield is the first example
    curr_code = []
    while len(code_maker) <= max_len:
        example = examples[-1]
        curr_code = code_maker(examples)
        if not curr_code:
            return False
        try:
            meaning = sem.read_from_parsed(curr_code)
            meaning = meaning[-1]
            # NOTE(review): `in_state` is never used afterwards.  Building it
            # can still raise Z3Exception (caught below), so removing it
            # would change control flow.
            in_state = z3.And(*(i() == example[0][i] for i in example[0] if i() in meaning[2]))
            prog = sem.env_to_query(meaning)
            # print('cc', curr_code)
            # print('specs', spec, 'XOR', prog)
            examples.append(example_gen.send((z3.Xor(spec, prog), code_maker.using_examples)))
        except StopIteration:
            # The example generator could not produce an example on which
            # spec and candidate disagree: the code matches the spec.
            return curr_code
        except z3.z3types.Z3Exception:
            # Building the z3 terms failed (e.g. a typing/sort error).  Skip
            # to the next round without adding an example.
            # NOTE(review): unless code_maker or len(code_maker) changes
            # between calls, this retries with identical input and may loop
            # forever — confirm.
            continue
    return False
コード例 #9
0
ファイル: symbols.py プロジェクト: katelaan/sloth
import z3

from .. import consts, config
from ..utils import logger, utils
from ..z3api import z3utils
from . import struct

###############################################################################
# Core SL-independent symbols
###############################################################################

# Function declarations of the core Boolean connectives.  Each is obtained by
# building a throwaway application over concrete truth values and taking its
# top-level declaration.
if_decl, or_decl, and_decl, implies_decl, xor_decl, not_decl = (
    z3.If(True, True, True).decl(),
    z3.Or(True, True).decl(),
    z3.And(True, True).decl(),
    z3.Implies(True, True).decl(),
    z3.Xor(True, True).decl(),
    z3.Not(True).decl(),
)

# Shared z3 truth constants.
Z3True = z3.BoolVal(True)
Z3False = z3.BoolVal(False)


def LAnd(ls):
    """Smart conjunction over a sequence of :class:`z3.ExprRef`.

    Only introduces an `And` for sequences of at least two expressions.

    >>> x, y = z3.Ints("x y")
    >>> LAnd([x == y])
    x == y
    >>> LAnd([x == y, x!= y, x > y])
    And(x == y, x != y, x > y)
コード例 #10
0
def get_z3_formula(sketch_ir: str, input_bits: int) -> z3.QuantifierRef:
    """Translate a Sketch intermediate representation into a z3 formula.

    Every non-header line of ``sketch_ir`` defines one node:

    * ``records[0]`` is the node id (bound as ``_n<id>``).
    * ``records[2]`` is the operation.
    * The remaining fields hold the type and operand ids.

    Handled operations are:

    * ``S``: integer source variable.
    * ``CONST``: ``INT``/``BOOL`` constant.
    * ``NEG``/``NOT``: unary operators.
    * ``AND``/``OR``/``XOR``/``PLUS``/``TIMES``/``DIV``/``MOD``/``LT``/``EQ``:
      binary operators.
    * ``ARRACC``/``ARRASS``: selectors.
    * ``ASSERT``.

    The result states that all asserted nodes hold for every assignment of
    the source variables within ``[0, 2**input_bits)``::

        ForAll(srcs, Implies(range_constraints(srcs), And(True, assert_1, ...)))

    Args:
        sketch_ir: IR text, one node per line.  Blank lines and lines whose
            first field is ``dag`` or ``TUPLE_DEF`` are skipped.
        input_bits: Bit width bounding every source variable.

    Returns:
        The quantified formula.  NOTE(review): when the IR has no source
        variables, the ``ForAll`` binds no variables.  Confirm that callers
        expecting a ``QuantifierRef`` handle that case.

    Raises:
        AssertionError: for unsupported types or unknown operations.
        KeyError: if a node references an operand id not defined earlier.
    """
    binary_ops = ('AND', 'OR', 'XOR', 'PLUS', 'TIMES', 'DIV', 'MOD', 'LT', 'EQ')

    z3_vars = collections.OrderedDict()
    z3_asserts = []
    z3_srcs = []
    for line in sketch_ir.splitlines():
        records = line.split()
        # Skip blank lines (previously an IndexError) and header records.
        if not records or records[0] in ('dag', 'TUPLE_DEF'):
            continue

        # Common processing across all nodes.
        output_var = '_n' + records[0]
        operation = records[2]
        if operation in ('NEG', 'NOT'):
            operand1 = z3_vars['_n' + records[4]]
            check_sort(operand1)
        elif operation in binary_ops:
            operand1 = z3_vars['_n' + records[4]]
            operand2 = z3_vars['_n' + records[5]]
            check_sort(operand1)
            check_sort(operand2)

        # Node-specific processing.
        if operation == 'ASSERT':
            z3_asserts.append('_n' + records[3])
        elif operation == 'S':
            var_type = records[3]
            source_name = records[4]
            assert var_type == 'INT', (
                'Unexpected variable type found in sketch IR:', line)
            z3_vars[source_name] = z3.Int(source_name)
            z3_vars[output_var] = z3.Int(source_name)
            z3_srcs.append(source_name)
        elif operation == 'NEG':
            z3_vars[output_var] = -make_int(operand1)
        elif operation == 'NOT':
            z3_vars[output_var] = z3.Not(make_bool(operand1))
        elif operation == 'AND':
            z3_vars[output_var] = z3.And(make_bool(operand1), make_bool(operand2))
        elif operation == 'OR':
            z3_vars[output_var] = z3.Or(make_bool(operand1), make_bool(operand2))
        elif operation == 'XOR':
            z3_vars[output_var] = z3.Xor(make_bool(operand1), make_bool(operand2))
        elif operation == 'PLUS':
            z3_vars[output_var] = make_int(operand1) + make_int(operand2)
        elif operation == 'TIMES':
            z3_vars[output_var] = make_int(operand1) * make_int(operand2)
        elif operation == 'DIV':
            z3_vars[output_var] = make_int(operand1) / make_int(operand2)
        elif operation == 'MOD':
            z3_vars[output_var] = make_int(operand1) % make_int(operand2)
        elif operation == 'LT':
            z3_vars[output_var] = make_int(operand1) < make_int(operand2)
        elif operation == 'EQ':
            z3_vars[output_var] = make_int(operand1) == make_int(operand2)
        # ARRACC and ARRASS can be read as array access and assignment. For
        # details see this sketchusers mailing-list thread:
        # https://lists.csail.mit.edu/pipermail/sketchusers/2019-August/000104.html
        elif operation == 'ARRACC':
            predicate = make_bool(z3_vars['_n' + records[4]])
            yes_val = z3_vars['_n' + records[7]]
            no_val = z3_vars['_n' + records[6]]
            z3_vars[output_var] = z3.If(predicate, yes_val, no_val)
        elif operation == 'ARRASS':
            selector = z3_vars['_n' + records[4]]
            # isinstance rather than exact type equality: literal nodes are
            # subclasses (z3.IntVal returns an IntNumRef, a subclass of
            # ArithRef).  An exact `type(...) ==` check rejected them.
            if isinstance(selector, z3.BoolRef):
                assert records[6] in ['0', '1']
                cmp_constant = records[6] == '1'
            elif isinstance(selector, z3.ArithRef):
                cmp_constant = int(records[6])
            else:
                assert False, ('Variable type', type(selector), 'not supported')
            predicate = selector == cmp_constant
            yes_val = z3_vars['_n' + records[8]]
            no_val = z3_vars['_n' + records[7]]
            z3_vars[output_var] = z3.If(predicate, yes_val, no_val)
        elif operation == 'CONST':
            var_type = records[3]
            if var_type == 'INT':
                z3_vars[output_var] = z3.IntVal(int(records[4]))
            elif var_type == 'BOOL':
                assert records[4] in ['0', '1']
                z3_vars[output_var] = z3.BoolVal(records[4] == '1')
            else:
                assert False, ('Constant type', var_type, 'not supported')
        else:
            assert False, ('Unknown operation:', line)

    # Conjunction of all asserted nodes.  Starting from True keeps the formula
    # well-formed when the IR contains no ASSERT lines.
    constraints = z3.BoolVal(True)
    for var in z3_asserts:
        constraints = z3.And(constraints, z3_vars[var])

    # Restrict every source variable to [0, 2**input_bits).  Starting from
    # True likewise covers the case of no source variables.
    variable_range = z3.BoolVal(True)
    for var in z3_srcs:
        variable_range = z3.And(
            variable_range,
            z3.And(0 <= z3_vars[var], z3_vars[var] < 2**input_bits))

    final_assert = z3.ForAll([z3_vars[x] for x in z3_srcs],
                             z3.Implies(variable_range, constraints))
    # We could use z3.simplify on the final assert.  However, that could
    # oversimplify it into a formula that is no longer a QuantifierRef, which
    # negated_body() expects.
    return final_assert
コード例 #11
0
ファイル: bool.py プロジェクト: cpstdhs/mythril-docker
def Xor(a: Bool, b: Bool) -> Bool:
    """Create an Xor expression.

    The result wraps ``z3.Xor`` of the operands' raw z3 terms.  It carries
    the union of both operands' annotations.
    """
    union = a.annotations.union(b.annotations)
    return Bool(z3.Xor(a.raw, b.raw), union)
コード例 #12
0
ファイル: smtlibv2.py プロジェクト: liumuqing/symbbl
 def __rxor__(self, other):
     """Return ``other ^ self`` as a symbolic ``Bool``.

     Python only calls this reflected method when the left operand did not
     handle ``^`` itself, for example a plain Python ``bool``, which has no
     ``.symbol`` attribute.  Such operands are passed to ``z3.Xor``
     directly, which coerces Python booleans.  ``Bool``-like operands
     contribute their ``.symbol``.  XOR is symmetric, so operand order does
     not change the result.
     """
     other_symbol = getattr(other, 'symbol', other)
     return Bool(z3.Xor(self.symbol, other_symbol))
コード例 #13
0
 def __xor__(self, other: 'z3Bit') -> 'z3Bit':
     # Symbolic exclusive-or, wrapped in the same class as `self`.
     cls = type(self)
     combined = z3.Xor(self.value, other.value)
     return cls(combined)
コード例 #14
0
 def Xor(self, other: Bool):
     # Symbolic exclusive-or of both underlying z3 terms, wrapped as a BoolExpr.
     lhs = self.z3obj
     rhs = other.z3obj
     return BoolExpr(z3.Xor(lhs, rhs))
コード例 #15
0
 def Xor(self, other: Bool):
     # Symbolic operand: build the z3 term.
     if not isinstance(other, BoolV):
         return BoolExpr(z3.Xor(self.z3obj, other.z3obj))
     # `other` is a concrete BoolV: fold the XOR of the two values now.
     either = self.value or other.value
     both = self.value and other.value
     return BoolV(either and not both)
コード例 #16
0
ファイル: combi_expressions.py プロジェクト: Jakob-Bach/CFFS
 def __init__(self, bool_expression1: expr.BooleanExpression, bool_expression2: expr.BooleanExpression):
     # Pass both operands to the base initializer, then cache the z3 XOR of
     # their encodings in `self.z3_expr`.
     super().__init__(bool_expression1, bool_expression2)
     left_z3 = bool_expression1.get_z3()
     right_z3 = bool_expression2.get_z3()
     self.z3_expr = z3.Xor(left_z3, right_z3)
コード例 #17
0
ファイル: z3_helpers.py プロジェクト: arey0pushpa/syndra
def Iff(a, b):
    """Return the z3 term for ``a <-> b``, encoded as ``Not(Xor(a, b))``."""
    exclusive = z3.Xor(a, b)
    return z3.Not(exclusive)
コード例 #18
0
def wildcard_trace(solver, symbols: Dict[int, List[str]], use_priming=True) -> Tuple[z3.CheckSatResult, Dict]:
    """Return the result of the attack (sat means attack was successful) and associated data.

    If the attack was successful, associated data includes what attack was executed, witness
    information, and the solution found.  The dictionary keys are "strategy", "solution", "witness".

    Otherwise, associated data includes the attack that was executed, plus debug info under the key
    "debug_info".

    Exceptions: the associated data is ``None`` (not a dict) in two cases:

    * the priming check is not sat (the solver is reset first);
    * ``DEBUG`` is on and the intermediate check is not sat (the solver is
      NOT reset).

    Args:
        solver: z3 solver the constraints below are added to (it is mutated).
        symbols: Forwarded to ``concretizations`` when turning solver strings
            into concrete outputs.
        use_priming: Whether to run the cheap preliminary check first.
    """

    if use_priming:
        # prime the solver with an easy question.
        # In `test_optional_dollar` it seems to give several folds speedup, if you used the same
        # solver for `test_dot` previously.
        prime_result, prime_model = check_formulae(solver, z3.Not(RegexStringExpr.ignore_wildcards))
        logger.info('check %s', prime_result)
        if prime_result == z3.sat:
            logger.debug(public_vars(prime_model))
        else:
            solver.reset()
            return prime_result, None

    # `proto`, `proto_delimiter`, `fqdn` and `unknown_string` are z3 string
    # terms defined outside this function.
    base = proto + proto_delimiter + fqdn

    solver.add(
        # The delimiter is either "//" or two wildcard markers.
        z3.Or(proto_delimiter == z3.StringVal('//'), proto_delimiter == z3.StringVal(WILDCARD_MARKER * 2)),
        # Exactly one holds: unknown_string starts with base + '/', or equals base.
        z3.Xor(z3.PrefixOf(base + '/', unknown_string), base == unknown_string),
        # Neither proto nor fqdn may contain '/'; both must be non-empty.
        z3.Not(z3.Contains(proto, '/')),
        z3.Not(z3.Contains(fqdn, '/')),
        z3.Length(proto) > 0,
        z3.Length(fqdn) > 0,
        )

    if DEBUG:
        #debug_result, debug_model = check_formulae(solver, unknown_string == debug_model[unknown_string])
        #logger.debug(debug_model)

        debug_result = solver.check()
        logger.debug(debug_result)
        if debug_result == z3.sat:
            debug_model = solver.model()
            logger.debug(public_vars(debug_model))
            # NOTE(review): `ans` is not used after this; it is recomputed
            # from the final model below.
            ans = z3.simplify(debug_model[proto] + debug_model[proto_delimiter] + debug_model[fqdn])
        else:
            # NOTE(review): unlike the priming failure path, the solver is
            # not reset before returning here.
            return debug_result, None


        #debug_result, debug_model = check_formulae(solver, (base_url != ans))

    # Main query: the priming condition again, plus a wildcard marker that
    # must appear somewhere in proto + fqdn.
    result, model = check_formulae(solver,
                                   z3.Not(RegexStringExpr.ignore_wildcards),
                                   z3.Contains(proto + fqdn, z3.StringVal(WILDCARD_MARKER)))

    if result == z3.sat:
        # First concretization of a z3 string value under `symbols`.
        _conc1 = lambda zs: tz.first(concretizations(z3_str_to_bytes(zs), symbols))
        logger.info(public_vars(model))
        # Witness: the URL prefix proto + delimiter + fqdn from the model.
        ans = z3.simplify(model[proto] + model[proto_delimiter] + model[fqdn])
        return result, {
            'solution': _conc1(model[unknown_string]).replace(WILDCARD_MARKER, OUTPUT_WILDCARD_MARKER),
            'strategy': 'wildcard_trace',
            'witness': _conc1(ans).replace(WILDCARD_MARKER, OUTPUT_WILDCARD_MARKER)}
    else:
        return result, {'strategy': 'wildcard_trace',
                        'debug_info': None}
コード例 #19
0
def find_best_intervals(videos: List[Video], goal_intervals: List[Interval], target_codec, target_resolution):
    """Choose which stored video fragments to read to cover ``goal_intervals`` cheaply.

    The choice is encoded as a z3 ``Optimize`` problem:

    * Every fragment has a 0/1 integer "used" indicator.
    * Every goal interval must be covered by exactly one used fragment whose
      interval matches it exactly.
    * The minimised objective is the sum of three costs: the encode costs of
      used fragments, their decode costs, and the decode/lookback costs of
      GOPs only partially covered by used fragments.

    Args:
        videos: Candidate videos.  Their ``fragments``, ``gops``,
            ``fragment_decode_dependencies`` and ``gop_to_partial_fragments``
            are read but not modified.
        goal_intervals: Intervals that must all be covered.
        target_codec: Codec the output should be in.
        target_resolution: Resolution the output should be in.

    Returns:
        The ids of the selected fragments (ordered by video, then fragment)
        when the problem is satisfiable.  Otherwise the unsat core is printed
        and an empty dict is returned (kept for backward compatibility).

    Raises:
        AssertionError: if a goal interval has no exactly matching fragment.
    """
    opt = z3.Optimize()

    # Each goal interval has to be covered.
    cover_vars = []
    for i in range(len(goal_intervals)):
        is_covered = z3.Bool(f'interval-{i}')
        opt.add(is_covered)
        cover_vars.append(is_covered)

    # Each video fragment can be used (1) or not (0).  An Int indicator rather
    # than a Bool lets it multiply costs directly instead of needing If().
    goal_interval_to_video_fragments = defaultdict(list)
    video_fragment_indicators: List[List] = []
    for i, video in enumerate(videos):
        fragment_indicators = []
        for j, fragment in enumerate(video.fragments):
            # Requires that the fragment intervals exactly match some goal interval.
            goal_interval_to_video_fragments[(fragment.interval.start, fragment.interval.end)].append((i, j))
            fragment_is_used = z3.Int(f'fragment-{i}-{j}')
            opt.add(fragment_is_used >= 0, fragment_is_used <= 1)
            fragment_indicators.append(fragment_is_used)
        video_fragment_indicators.append(fragment_indicators)

    # Each interval is covered iff exactly one matching fragment is used.
    for i, interval in enumerate(goal_intervals):
        possible_fragments = goal_interval_to_video_fragments[(interval.start, interval.end)]
        if not possible_fragments:
            # Shouldn't really happen, but useful for checking logic.  Raised
            # explicitly so the check also runs under `python -O`.
            raise AssertionError("This usually happens for a request outside of a logical video's duration")
        candidates = [video_fragment_indicators[v][f] for v, f in possible_fragments]
        # Indicators are 0/1, so Sum == 1 means "exactly one is used".  A
        # chained Xor only checks parity and would also accept 3, 5, ... picks.
        opt.add(cover_vars[i] == (z3.Sum(candidates) == 1))

    # Encode cost per fragment: zero if it is already in the target codec and
    # resolution, otherwise proportional to its length when it is used.
    # If the target format is raw, then switching formats will still be necessary.
    video_fragment_encode_costs: List[List] = []
    for v, video in enumerate(videos):
        fragment_encode_costs = []
        for f, fragment in enumerate(video.fragments):
            if fragment.target == target_codec and fragment.resolution == target_resolution:
                encode_cost = z3.IntVal(0)
            else:
                # There is no lookback cost associated with encoding because it assumes the starting point is raw.
                encode_cost = video_fragment_indicators[v][f] * fragment.interval.length() * estimate_encode_cost(target_codec, target_resolution)
            fragment_encode_costs.append(encode_cost)
        video_fragment_encode_costs.append(fragment_encode_costs)

    # Decode cost has two parts: the decode cost of every used fragment, plus
    # a cost for each GOP that used fragments only partially cover.  No decode
    # is needed when the data is raw or already in the target format.
    video_decode_and_lookback_costs: List[List] = []
    for v, video in enumerate(videos):
        fragment_decode_costs = []
        for f, fragment in enumerate(video.fragments):
            if is_raw(fragment.target) or (fragment.target == target_codec and fragment.resolution == target_resolution):
                fragment_decode_cost = z3.IntVal(0)
            else:
                deps = video.fragment_decode_dependencies[f]
                decode_cost = (deps.num_p_frames * non_keyframe_cost(target_codec, target_resolution)
                               + deps.num_keyframes * keyframe_cost(target_codec, target_resolution))
                fragment_decode_cost = video_fragment_indicators[v][f] * decode_cost
            fragment_decode_costs.append(fragment_decode_cost)
        video_decode_and_lookback_costs.append(fragment_decode_costs)

        gop_decode_costs = []
        for gop_idx, fragment_idxs in video.gop_to_partial_fragments.items():
            gop = video.gops[gop_idx]
            if is_raw(gop.target):
                # If the GOP is raw, then there is no decode or lookback cost.
                continue
            # Latest fragment first.  This assumes fragment indexes ascend with
            # time.  sorted() is used so the video's own list is not reordered.
            latest_first = sorted(fragment_idxs, reverse=True)
            # True while no later fragment of this GOP is used.  Only the
            # latest used fragment therefore contributes a cost.
            not_later_fragments = z3.And()
            gop_costs = []
            if gop.target != target_codec:
                # Different format: decode from the GOP start up to the end of
                # the latest used fragment.
                for f_idx in latest_first:
                    frag = video.fragments[f_idx]
                    used = video_fragment_indicators[v][f_idx]
                    num_p_frames_to_decode = min(frag.interval.end, gop.interval.end) - gop.interval.start
                    # TODO: Update to use actual cost.
                    cost = num_p_frames_to_decode * non_keyframe_cost(target_codec, target_resolution) + 1 * keyframe_cost(target_codec, target_resolution)
                    gop_costs.append(z3.If(z3.And(not_later_fragments, used > 0), cost, 0))
                    not_later_fragments = z3.And(not_later_fragments, used < 1)
            else:
                # Same format as the goal: only frames before the start of the
                # latest used fragment must be decoded (lookback).  A fragment
                # ending partway through the GOP costs nothing, since the GOP
                # can be truncated without decoding.
                for f_idx in latest_first:
                    frag = video.fragments[f_idx]
                    used = video_fragment_indicators[v][f_idx]
                    if frag.interval.start > gop.interval.start:
                        num_leading_p_frames = frag.interval.start - gop.interval.start
                        # TODO: Update to use actual cost.
                        cost = num_leading_p_frames * non_keyframe_cost(target_codec, target_resolution) + 1 * keyframe_cost(target_codec, target_resolution)
                        # I think this should actually be the cost of all of the frames not in fragments but before and
                        # between included fragments.
                        gop_costs.append(z3.If(z3.And(not_later_fragments, used > 0), cost, 0))
                    not_later_fragments = z3.And(not_later_fragments, used < 1)
            gop_decode_costs.append(z3.Sum(gop_costs))
        video_decode_and_lookback_costs.append(gop_decode_costs)

        # TODO: seek costs are not modeled.  An earlier draft charged, for the
        # latest used fragment, the total length of the GOPs ending before
        # that fragment's start.

    flat = list(itertools.chain.from_iterable(video_fragment_encode_costs + video_decode_and_lookback_costs))
    opt.minimize(z3.Sum(flat))

    result = opt.check()
    if result != z3.sat:
        print(opt.unsat_core())
        return {}

    model = opt.model()
    # Collect the fragments that should be read, in video/fragment order.
    fragment_ids = []
    for v, video in enumerate(videos):
        for f, fragment in enumerate(video.fragments):
            used = model.eval(video_fragment_indicators[v][f], model_completion=True)
            if used.as_long() == 1:
                fragment_ids.append(fragment.id)
    return fragment_ids
コード例 #20
0
ファイル: smtpnras.py プロジェクト: leonardt/SMT-PNR
def run_test(adj,
             fab_dims,
             wire_lengths=None,
             debug_prints=True,
             constraints_gen=place_constraints_2d,
             model_checker=None,
             model_printer=print_model_2d):
    """Place the components described by *adj* on a fabric with z3.

    Two modes are supported:

    * *wire_lengths* non-empty: hard constraints come from
      ``constraints_gen(comps, fab_dims, wire_lengths)``.  Soft constraints
      then reward, for every component/input pair, (a) sharing exactly one of
      the x / y coordinates, and (b) being offset by one of the allowed wire
      lengths in x and in y.  Shorter wire lengths get larger weights.
    * otherwise: ``place_constraints_opt`` supplies the constraints and the
      total Manhattan (L1) distance it returns is minimized.

    Args:
        adj: graph description passed to ``build_graph``.
        fab_dims: ``(frows, fcols)`` pair describing the fabric.
        wire_lengths: allowed wire lengths (used as shift amounts on the
            position bit-vectors); ``None`` or empty selects L1 minimization.
        debug_prints: print sat/unsat status (and the L1 norm in the second
            mode).
        constraints_gen: builds the hard constraints in the first mode.
        model_checker: currently unused; kept for API compatibility.
        model_printer: prints the model found in the first mode.

    Returns:
        The ``z3.Optimize`` solver, whether or not it was satisfiable.
    """
    comps = build_graph(adj)

    if wire_lengths:  # use provided wire lengths
        print('Finding satisfying model with given wire lenths')
        frows, fcols = fab_dims

        def _get_x(bv):
            # Upper fcols bits of a position bit-vector.
            return z3.Extract(frows + fcols - 1, frows, bv)

        def _get_y(bv):
            # Lower frows bits of a position bit-vector.
            return z3.Extract(frows - 1, 0, bv)

        constraints = constraints_gen(comps, fab_dims, wire_lengths)
        s = z3.Optimize()
        s.add(constraints)
        wire_lengths = sorted(wire_lengths)
        for comp in comps:
            # `src` rather than `adj` so the parameter is not shadowed.
            for src in comp.inputs:
                # TODO: should dispatch to the fabric to get the rules.
                w = len(wire_lengths)
                # Prefer sharing exactly one coordinate (same row XOR same
                # column); weighted above every wire-length preference.
                s.add_soft(
                    z3.Xor(
                        _get_x(comp.pos) == _get_x(src.pos),
                        _get_y(comp.pos) == _get_y(src.pos)), str(w + 1))
                for wl in wire_lengths:
                    # Should probably be comp.pos & (shifted src.pos) != 0,
                    # so that components can have variable size.
                    zx = z3.Or(
                        _get_x(comp.pos) == z3.LShR(_get_x(src.pos), wl),
                        _get_x(comp.pos) == _get_x(src.pos) << wl)
                    zy = z3.Or(
                        _get_y(comp.pos) == z3.LShR(_get_y(src.pos), wl),
                        _get_y(comp.pos) == _get_y(src.pos) << wl)
                    s.add_soft(zx, str(w))
                    s.add_soft(zy, str(w))
                    w -= 1  # longer wires are preferred less
        if s.check() != z3.sat:
            if debug_prints:
                print('test is unsat')
            return s

        if debug_prints:
            print('test is sat')

        model_printer(s.model(), comps, fab_dims, wire_lengths)
        # Return the solver on success too, matching every other path.
        return s

    # No provided wire lengths: optimize the Manhattan distance.
    print('No provided wire lengths. Minimizing total L1 norm')
    constraints, manhattan_dist = place_constraints_opt(comps, fab_dims)
    s = z3.Optimize()
    s.add(constraints)
    h = s.minimize(manhattan_dist)
    s.set('enable_sls', True)
    if s.check() != z3.sat:
        if debug_prints:
            print('test is unsat')
        return s

    if debug_prints:
        print('test is sat')
        print('Total L1 Norm = ', s.lower(h))

    print_model_opt(s.model(), comps, fab_dims)
    return s
コード例 #21
0
 def _op_raw_Xor(self, *args):
     """Return ``z3.Xor`` over *args*, with this object's z3 context appended.

     The context is passed as the final positional argument. With two
     operands, z3py's ``Xor(a, b, ctx=None)`` receives it as ``ctx``.
     """
     operands = (*args, self._context)
     return z3.Xor(*operands)
コード例 #22
0
class Logic:
    """
    rules:
        class1:
            out = func(ins)
        class2:
            X1 X2 | ddd
            not by input N
            out = X2(X1(group1), X1(group2), ...)
            len(group1)len(group2)... = ddd
        class3
            MX MXT MXIT
            out = (not) in[s]
        class4
            out = in
        class5
            CO = in1 in2 + in2 in3 + in1 in3
            out = xor(in)
    """
    LogicPrefix = [
        'XOR',
        'OR',
        'NAND',
        'AND',
        'XNOR',
        'NOR',  # class 1
        'AOI',
        'AO',
        'OAI',
        'OA',  # class 2
        'MXIT',
        'MXT',
        'MX',  # class 3 
        'BUFH',
        'BUF',
        'INV',  # class 4
        'ADDF',
        'ADDH'  # class 5
    ]

    LogicClass = {
        'XOR': 1,
        'OR': 1,
        'NAND': 1,
        'AND': 1,
        'XNOR': 1,
        'NOR': 1,
        'AOI': 2,
        'AO': 2,
        'OAI': 2,
        'OA': 2,  # class 2
        'MXIT': 3,
        'MXT': 3,
        'MX': 3,  # class 3 
        'BUFH': 4,
        'BUF': 4,
        'INV': 4,  # class 4
        'ADDH': 5,
        'ADDF': 6  # class 5/6
    }

    BasicP = {
        'XOR': lambda I1, I2: 1 - (I1) * (I2) - (1 - (I1)) * (1 - (I2)),
        'OR': lambda I1, I2: 1 - (1 - (I1)) * (1 - (I2)),
        'AND': lambda I1, I2: (I1) * (I2),
        'INV': lambda I1: 1 - (I1),
        'BUF': lambda I1: (I1),
        'MX': lambda I1, I2, SW: (SW) * (I2) + (1 - (SW)) *
        (I1)  # SW == 0 => I1
    }

    BasicA = {
        'XOR': lambda I1, I2: (I1) ^ (I2),
        'OR': lambda I1, I2: (I1) or (I2),
        'AND': lambda I1, I2: (I1) and (I2),
        'INV': lambda I1: not (I1),
        'BUF': lambda I1: (I1),
        'MX': lambda I1, I2, SW: (I2) if SW else (I1)
    }

    BasicZ3 = {
        'XOR':
        lambda I1, I2: z3.Xor((I1), (I2)),
        'OR':
        lambda I1, I2: z3.Or((I1), (I2)),
        'AND':
        lambda I1, I2: z3.And((I1), (I2)),
        'INV':
        lambda I1: z3.Not((I1)),
        'BUF':
        lambda I1: (I1),
        'MX':
        lambda I1, I2, SW: z3.Or(z3.And((SW),
                                        (I2)), z3.And(z3.Not((SW)), (I1)))
    }

    def parseFunc(self, fn):
        for i in Logic.LogicPrefix:
            if fn.startswith(i):
                self.func = i
                self.para = fn[len(i):]
                self.para = re.sub('\D', '', self.para)
                self.lclass = Logic.LogicClass[self.func]
                return
        assert False  # unknown func type !!

    def statement(self):
        a1 = self.func + self.para + " "
        a1 += self.define['name'] + ' ('

        def processArg(arg):
            sub = []
            for i in arg:
                for j in arg[i]:
                    sub.append("." + j['group'] + str(j['id']) +
                               ('N' if j['inv'] else '') + "(" + j['name'] +
                               ')')
            return ', '.join(sub)

        a1 += processArg(self.define['argsIn']) + ', ' + processArg(
            self.define['argsOut']) + ' );'
        return a1

    def __init__(self, define):
        self.name = define['name']
        # Define {func:string, name:strig, argsIn: Args, argsOut: Args}
        # Args {A: [Arg], B: [Arg], C:[Arg], ...}
        # Arg {group:string, id:int, name: name (name in net), inv: bool, type:input/output}
        self.define = define
        self.inputs = []
        for i in define['argsIn']:
            for j in define['argsIn'][i]:
                self.inputs.append(j['name'])
        self.outputs = []
        for i in define['argsOut']:
            for j in define['argsOut'][i]:
                self.outputs.append(j['name'])
        self.parseFunc(define['func'])
        #if self.name == 'U5069':
        #    print(self.define['argsIn'])
        #    self.define['argsIn']['D'][0].inv = False
        self.acc_mode_cache = {k: [] for k in self.outputs}
        if (MODE_ACC):
            for i in range(0, 2**len(self.inputs)):
                inx = []
                k = i
                for j in range(0, len(self.inputs)):
                    inx.append((k % 2) == 1)
                    k //= 2
                inmap = {k: v for k, v in zip(self.inputs, inx)}
                rst = self.eval(lambda I: inmap[I])
                for k in rst:
                    if rst[k]:
                        self.acc_mode_cache[k].append(inx)

    def __str__(self):
        outs = ""
        outs += " > Logic: " + self.name + "(" + self.func + ", " + self.para + ")" + "\n"
        outs += "  > Inputs: \n"
        innames = []
        for gsi in self.define['argsIn']:
            for ii in self.define['argsIn'][gsi]:
                nn = "N" if ii['inv'] else ""
                outs += "   >" + ii['group'] + str(
                    ii['id']) + " " + (nn) + ": " + ii['name'] + "\n"
                innames.append(ii['name'])
        outs += "  > Outputs: \n"
        statP = {}
        for gsi in self.define['argsOut']:
            for ii in self.define['argsOut'][gsi]:
                nn = "N" if ii['inv'] else ""
                outs += "   >" + ii['group'] + str(
                    ii['id']) + " " + (nn) + ": " + ii['name'] + "\n"
                statP[ii['name']] = 0
        outs += "  Input = " + ", ".join(innames) + "\n"

        for i in range(0, 2**len(innames)):
            k = i
            outs += "  Input = "
            iis = []
            invals = {}
            for j in range(0, len(innames)):
                invals[innames[j]] = ((k % 2) == 1)
                iis.append(str(k % 2))
                k //= 2
            outs += ", ".join(iis)
            evals = self.eval(lambda I: invals[I])
            for ii in evals:
                if evals[ii]:
                    statP[ii] = statP[ii] + 1
            outs += " => " + str(evals) + "\n"
        outP = self.getPossible(lambda I: 0.5)
        outPa = [statP[i] / (2**len(innames)) for i in statP]
        outs += "  OutputP = " + str(outP) + ", OutputPACC = " + str(outPa)
        return outs

    def getRunFunc(self):
        # class 1, 3, 4 only
        assert self.lclass == 1 or self.lclass == 3 or self.lclass == 4
        runFunc = self.func
        postNot = False
        if runFunc == 'XNOR':
            runFunc = 'XOR'
            postNot = True
        if runFunc == 'NAND':
            runFunc = 'AND'
            postNot = True
        if runFunc == 'NOR':
            runFunc = 'OR'
            postNot = True
        if runFunc == 'MXT':
            runFunc = 'MX'
        if runFunc == 'BUFH':
            runFunc = 'BUF'
        if runFunc == 'MXIT':
            runFunc = 'MX'
            postNot = True

        return runFunc, postNot

    def getRunSteps(self):
        # class == 2
        assert self.lclass == 2
        decode = {'A': 'AND', 'O': 'OR'}
        s1 = decode[self.func[0]]
        s2 = decode[self.func[1]]
        postNot = len(self.func) == 3 and self.func[2] == 'I'
        return s1, s2, postNot

    # some input maybe invert input
    def getInputWithINV(self, state, fNOT, arg):
        val = state(arg['name'])
        if (arg['inv']):
            val = fNOT(val)
        return val

    # class2 is grouped in groups
    def applyClass2(self, state, fX1, fX2, fNOT, inv):
        s2_inputs = []
        xlogs = ""
        for gsi in self.define['argsIn']:
            gs = self.define['argsIn'][gsi]
            if len(gs) == 1:
                s2_inputs.append(self.getInputWithINV(state, fNOT, gs[0]))
            else:
                s1_grouped_inputs = [
                    self.getInputWithINV(state, fNOT, ga) for ga in gs
                ]
                xlogs += str(s1_grouped_inputs)
                s1_val = reduce(fX1, s1_grouped_inputs)
                s2_inputs.append(s1_val)
        val = reduce(fX2, s2_inputs)
        if inv:
            val = fNOT(val)
        return val

    def getAllInputs(self, state, fNOT):
        ins = []
        for i in self.define['argsIn']:
            argsi = self.define['argsIn'][i]
            for j in argsi:
                ins.append(self.getInputWithINV(state, fNOT, j))
        return ins

    def applyGen(self, state, fGroup):
        if self.lclass == 1:
            runFunc, postNot = self.getRunFunc()

            in_vals = self.getAllInputs(state, fGroup['INV'])
            val = reduce(fGroup[runFunc], in_vals)
            if postNot:
                val = fGroup['INV'](val)
            return {self.outputs[0]: val}
        if self.lclass == 3:
            runFunc, postNot = self.getRunFunc()
            in_vals = ['A', 'B', 'S']
            in_vals = [
                self.getInputWithINV(state, fGroup['INV'],
                                     (self.define['argsIn'][i][0]))
                for i in in_vals
            ]
            val = fGroup[runFunc](*in_vals)
            if postNot:
                val = fGroup['INV'](val)
            return {self.outputs[0]: val}
        if self.lclass == 2:
            x1, x2, postNot = self.getRunSteps()
            val = self.applyClass2(state, fGroup[x1], fGroup[x2],
                                   fGroup['INV'], postNot)
            return {self.outputs[0]: val}
        if self.lclass == 4:
            runFunc, _ = self.getRunFunc()
            val = fGroup[runFunc](*self.getAllInputs(state, fGroup['INV']))
            return {self.outputs[0]: val}
        if self.lclass == 5:  # ADDH
            in_vals = ['A', 'B']
            x1, x2 = [
                self.getInputWithINV(state, fGroup['INV'],
                                     (self.define['argsIn'][i][0]))
                for i in in_vals
            ]
            s = fGroup['XOR'](x1, x2)
            co = fGroup['AND'](x1, x2)
            rst = {'S': s, 'CO': co}
            rst = {
                self.define['argsOut'][k][0]['name']: rst[k]
                for k in self.define['argsOut']
            }
            return rst

        if self.lclass == 6:  # ADDF
            in_vals = ['A', 'B', 'CI']
            x1, x2, ci = [
                self.getInputWithINV(state, fGroup['INV'],
                                     (self.define['argsIn'][i][0]))
                for i in in_vals
            ]
            x1xx2 = fGroup['XOR'](x1, x2)
            s = fGroup['XOR'](x1xx2, ci)
            co = fGroup['OR'](fGroup['AND'](x1xx2, ci), fGroup['AND'](x1, x2))
            rst = {'S': s, 'CO': co}
            rst = {
                self.define['argsOut'][k][0]['name']: rst[k]
                for k in self.define['argsOut']
            }
            return rst

    def getPossible(self, state):
        if (MODE_ACC):
            inP = [state(i) for i in self.inputs]
            ouP = {}
            for o in self.outputs:
                sop = 0
                for j in self.acc_mode_cache[o]:
                    op = 1
                    for val, p in zip(self.inputs, inP):
                        if val:
                            op *= p
                        else:
                            op *= 1 - p
                    sop += op
                ouP[o] = sop
            return ouP
        if (self.lclass == 6):
            # special workaround for adder
            in_vals = ['A', 'B', 'CI']
            x1, x2, ci = [
                state(self.define['argsIn'][i][0]['name']) for i in in_vals
            ]
            x1xx2 = Logic.BasicP['XOR'](x1, x2)
            s = Logic.BasicP['XOR'](x1xx2, ci)
            co = 0
            for v1 in range(0, 2):
                for v2 in range(0, 2):
                    for v3 in range(0, 2):
                        if v1 + v2 + v3 > 1:
                            p1 = x1 if v1 == 1 else 1 - x1
                            p2 = x2 if v2 == 1 else 1 - x2
                            p3 = ci if v3 == 1 else 1 - ci
                            co += p1 * p2 * p3
            rst = {'S': s, 'CO': co}
            rst = {
                self.define['argsOut'][k][0]['name']: rst[k]
                for k in self.define['argsOut']
            }
            return rst
        else:
            return self.applyGen(state, Logic.BasicP)

    def eval(self, state):
        return self.applyGen(state, Logic.BasicA)

    def z3Interface(self, state):
        return self.applyGen(state, Logic.BasicZ3)