def test_nested():
    """make_coefficients_positive should flip the negative outer factor.

    ``(2x + 5.125) * -1.25`` is expected to become ``(2x + 5.125) * 1.25``;
    the two are compared through their Polynomial forms.
    """
    x = smt.Symbol("x", smt.REAL)
    inner = x * smt.Real(2) + smt.Real(5.125)
    formula = inner * smt.Real(-1.25)
    positive = inner * smt.Real(1.25)
    result = make_coefficients_positive(formula)
    assert Polynomial.from_smt(positive) == Polynomial.from_smt(result)
Beispiel #2
0
def find_hl(data, domain, active_indices, solver):
    """Search for one half-space ``a . x <= b`` consistent with the active rows.

    Each entry of ``data`` is a ``(row, label)`` pair. For every index in
    ``active_indices``, a true label requires ``a . x + DELTA <= b`` and a
    false label requires ``a . x - DELTA > b``.

    Returns the learned half-space as a pysmt formula over the domain's real
    variables, or None when the solver finds the constraints unsatisfiable.
    """
    real_vars = domain.real_vars
    n_r = len(real_vars)

    real_features = [[row[v] for v in real_vars] for row, _ in data]
    labels = [entry[1] for entry in data]

    # Unknowns: one coefficient per real variable plus the threshold.
    coefficients = [smt.Symbol("a_r[{}]".format(r), REAL) for r in range(n_r)]
    threshold = smt.Symbol("b", REAL)

    for i in active_indices:
        point = real_features[i]
        weighted_sum = smt.Plus(
            [coefficients[r] * smt.Real(point[r]) for r in range(n_r)])
        if labels[i]:
            constraint = weighted_sum + DELTA <= threshold
        else:
            constraint = weighted_sum - DELTA > threshold
        solver.add_assertion(constraint)

    if not solver.solve():
        return None
    model = solver.get_model()

    x_vars = [domain.get_symbol(var_name) for var_name in real_vars]
    return smt.Plus([model.get_value(coefficients[r]) * x_vars[r]
                     for r in range(n_r)]) <= model.get_value(threshold)
Beispiel #3
0
def example4(domain):
    """Return ``domain`` with a fixed two-clause formula over reals x and y.

    Each clause is a disjunction of two linear comparisons. The literals of
    the second clause are the negations of the first clause's literals.
    """
    x = smt.Symbol("x", REAL)
    y = smt.Symbol("y", REAL)
    lhs_1 = 58.3305562428 * x + 162.172448357 * y
    lhs_2 = -121.782718841 * x + -45.7311195244 * y
    first = (106.452209182 < lhs_1) | (-82.1173457701 < lhs_2)
    second = (lhs_1 <= 106.452209182) | (lhs_2 <= -82.1173457701)
    return domain, first & second
Beispiel #4
0
def example3(domain):
    """Return ``domain`` with a fixed formula over reals x and y.

    NOTE(review): both conjuncts are the same clause, so the formula is
    equivalent to that single clause. Check whether a different second
    clause was intended (in example4 the second clause negates the first).
    """
    x = smt.Symbol("x", REAL)
    y = smt.Symbol("y", REAL)
    clause = ((5.03100425089 < 4.72202520763 * x + 4.11473198213 * y) |
              (-4.6261635019 < -5.93640712709 * x + -5.87100650773 * y))
    return domain, clause & clause
Beispiel #5
0
    def test_simpleSymbol_true(self):
        """A symbol bound to True, as a plain bool or smt.Bool, walks to true."""
        name = "b"
        for bound in (True, smt.Bool(True)):
            checker = SmtChecker({name: bound})
            self.assertTrue(checker.walk_smt(smt.Symbol(name)))
Beispiel #6
0
    def test_simpleSymbol_false(self):
        """A symbol bound to False, as a plain bool or smt.Bool, walks to false."""
        name = "b"
        for bound in (False, smt.Bool(False)):
            checker = SmtChecker({name: bound})
            self.assertFalse(checker.walk_smt(smt.Symbol(name)))
Beispiel #7
0
def test_convert_weight():
    """Build an Ite-based weight over reals x, y and the Boolean a.

    NOTE(review): the weight is constructed but nothing is asserted about
    it. The test looks truncated; confirm the intended checks.
    """
    x = smt.Symbol("x", smt.REAL)
    y = smt.Symbol("y", smt.REAL)
    a = smt.Symbol("a", smt.BOOL)
    first_region = a & (x > 0) & (x < 10) & (x * 2 <= 20) & (y > 0) & (y < 10)
    second_region = (x > 0) & (x < y) & (y < 20)
    weight_function = smt.Ite(first_region | second_region, x + y, x * y) + 2
Beispiel #8
0
    def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC):
        """Wrap ``value`` as a pysmt bit-vector term of width ``self.size``.

        ``value`` may be:

        - left symbolic, giving a fresh symbol, or one called ``name``;
        - an FNode of exactly this bit-vector type;
        - another instance of this class;
        - a Python int.

        ``name`` is only accepted for symbolic values. It must be a string
        not yet present in ``_name_table``, where it is then registered.
        """
        named = name is not AUTOMATIC
        symbolic = value is SMYBOLIC

        if named:
            if not symbolic:
                raise TypeError('Can only name symbolic variables')
            if not isinstance(name, str):
                raise TypeError('Name must be string')
            if name in _name_table:
                raise ValueError(f'Name {name} already in use')
            _name_table[name] = self

        bv_type = BVType(self.size)

        if symbolic:
            if named:
                value = shortcuts.Symbol(name, bv_type)
            else:
                value = shortcuts.FreshSymbol(bv_type)
        elif isinstance(value, pysmt.fnode.FNode):
            actual = value.get_type()
            if actual is not bv_type:
                raise TypeError(f'Expected {bv_type} not {actual}')
        elif isinstance(value, type(self)):
            value = value._value
        elif isinstance(value, int):
            value = shortcuts.BV(value, self.size)
        else:
            # NOTE(review): the message names SMTFPVector; confirm it matches
            # the enclosing class.
            raise TypeError(f"Can't coerce {value} to SMTFPVector")

        self._name = name
        self._value = value
Beispiel #9
0
    def get_variable_groups_poly(
        cls, weight: Polynomial, real_vars: List[str]
    ) -> List[Tuple[Set[str], Polynomial]]:
        """Split a polynomial weight into (variable set, factor) groups.

        Groups are formed as follows:

        - Each name in ``real_vars`` that does not occur in ``weight`` gets
          its own group with constant factor 1.
        - A polynomial with more than one term is one group over all of its
          variables.
        - The empty polynomial yields ``(set(), 0)``.
        - A single monomial is factored per variable. Each variable
          occurrence multiplies into that variable's singleton group, and
          the coefficient goes into the ``frozenset()`` group.

        Raises:
            NotImplementedError: ``weight`` is not a Polynomial.
        """
        if not isinstance(weight, Polynomial):
            raise NotImplementedError

        if len(real_vars) > 0:
            present = weight.variables
            missing = [({v}, Polynomial.from_constant(1))
                       for v in real_vars if v not in present]
            return missing + cls.get_variable_groups_poly(weight, [])

        term_count = len(weight.poly_dict)
        if term_count > 1:
            return [(weight.variables, weight)]
        if term_count == 0:
            return [(set(), Polynomial.from_constant(0))]

        groups = defaultdict(lambda: Polynomial.from_constant(1))
        for monomial, coefficient in weight.poly_dict.items():
            # An empty monomial (a constant term) skips this loop entirely.
            for var in monomial:
                groups[frozenset((var,))] *= Polynomial.from_smt(
                    smt.Symbol(var, smt.REAL)
                )
            groups[frozenset()] *= Polynomial.from_constant(coefficient)
        return list(groups.items())
Beispiel #10
0
def sympy2pysmt(expr, expr_type=None):
    """Convert a sympy expression into the equivalent pysmt formula.

    Args:
        expr: the sympy expression. ``Poly`` instances are first turned into
            plain expressions.
        expr_type: pysmt type for symbols and numeric constants. When
            ``None``, an operator node takes its type from ``OP_TYPE`` and
            passes it down to its arguments.

    Returns:
        The pysmt formula (FNode).

    Raises:
        ValueError: ``expr`` is a symbol and no type is known for it.
        NotImplementedError: ``expr`` is not a closed Boolean or number, a
            symbol, or an operator listed in ``SYM2SMT``.
    """
    from fractions import Fraction

    if isinstance(expr, Poly):  # turn Poly instances into generic sympy expressions
        expr = expr.as_expr()

    op = type(expr)

    if len(expr.free_symbols) == 0:
        if expr.is_Boolean:
            return smt.Bool(bool(expr))
        elif expr.is_number:
            if (expr_type is not None and expr_type.is_int_type()
                    and expr.is_Integer):
                # In an INT context emit an Int constant. pysmt does not
                # coerce between Int and Real operands.
                return smt.Int(int(expr))
            if expr.is_Rational:
                # Keep p/q exact instead of rounding through a float.
                return smt.Real(Fraction(int(expr.p), int(expr.q)))
            return smt.Real(float(expr))

    elif op == sym.Symbol:
        if expr_type is None:
            raise ValueError(
                "Can't create a pysmt Symbol without type information")

        return smt.Symbol(expr.name, expr_type)

    elif op in SYM2SMT:
        if expr_type is None:
            expr_type = OP_TYPE[op]

        smtargs = [sympy2pysmt(c, expr_type) for c in expr.args]
        return SYM2SMT[op](*smtargs)

    raise NotImplementedError(f"SYMPY -> PYSMT Not implemented for op: {op}")
Beispiel #11
0
    def query_oracle(self, dis_formula):
        """Evaluate the oracle circuit on a distinguishing input sequence.

        The oracle solver is reset and loaded with the frame-0 oracle circuit
        constraints. Then, for each frame ``d`` in ``1..self.unroll_depth``:

        - that frame's circuit constraints and ``And(dis_formula[d - 1])``
          are added;
        - the model values of the ``<wire>@<d>`` output symbols are read.

        Args:
            dis_formula: indexable by frame, with ``dis_formula[d - 1]`` used
                for frame ``d``. Each entry is passed to ``pystm.And``.

        Returns:
            One list of output values per frame, ordered like
            ``self.obf_cir.output_wires``.

        Exits the process with status 1 if the constraints become
        unsatisfiable.
        """
        import sys

        solver = self.solver_oracle
        solver.reset_assertions()
        for constraint in self.attack_formulas.oracle_ckt_at_frame(0):
            solver.add_assertion(constraint)

        dis_out = []
        for d in range(1, self.unroll_depth + 1):
            for constraint in self.attack_formulas.oracle_ckt_at_frame(d):
                solver.add_assertion(constraint)

            solver.add_assertion(pystm.And(dis_formula[d - 1]))
            if not solver.is_sat(pystm.TRUE()):
                logging.critical('something is wrong in oracle query')
                # Exit with a non-zero status so the failure is visible to
                # the calling process (a bare exit() reports status 0).
                sys.exit(1)

            # NOTE(review): outputs are read through the obfuscated circuit's
            # wire names. Presumably these match the oracle's output wires;
            # confirm.
            dip_out = [
                solver.get_value(pystm.Symbol(w + '@{}'.format(d)))
                for w in self.obf_cir.output_wires
            ]
            dis_out.append(dip_out)
        logging.info(dis_out)
        return dis_out
Beispiel #12
0
    def to_sbv(self, size: int) -> SMTBitVector:
        """Convert this value to ``SMTBitVector[size]`` via an uninterpreted function.

        One function symbol ``<ClassName>.to_sbv[<size>]`` of type
        ``BV(self.size) -> BV(size)`` is created on first use per target
        width. It is cached in ``_uf_table[cls]['to_sbv']``.
        """
        cls = type(self)
        # The cache key must match the one registered by __init_subclass__
        # ('to_sbv'). No 'to_usbv' entry exists there, so that key raises
        # KeyError.
        ufs = _uf_table[cls]['to_sbv']
        if size not in ufs:
            name = '.'.join((cls.__name__, f'to_sbv[{size}]'))
            ufs[size] = shortcuts.Symbol(
                name, FunctionType(BVType(size), (BVType(self.size), )))

        return SMTBitVector[size](ufs[size](self._value))
Beispiel #13
0
    def ast_to_smt(self, node):
        """Recursively convert a parsed AST node into a pysmt formula.

        :type node: Node
        :param node: node whose ``name`` selects the construct and whose
            ``children`` are its operands. For ``const`` and ``var`` nodes,
            the two children's names give the type (``"bool"`` or ``"real"``)
            and the literal value or variable name.
        :return: the corresponding pysmt formula.
        :raises Exception: on an arity mismatch or an unknown const/var type.
        :raises RuntimeError: on an unrecognized node name.
        """
        def convert_children(number=None):
            # Convert every child. When ``number`` is given, require exactly
            # that many children.
            if number is not None and len(node.children) != number:
                raise Exception(
                    "The number of children ({}) differed from {}".format(
                        len(node.children), number))
            return [self.ast_to_smt(child) for child in node.children]

        def parse_bool(text):
            # bool("false") is True, so textual literals are parsed
            # explicitly: "false", "0" and blank strings map to False, other
            # strings to True. Non-string values keep plain truthiness.
            if isinstance(text, str):
                return text.strip().lower() not in ("", "false", "0")
            return bool(text)

        if node.name == "ite":
            return smt.Ite(*convert_children(3))
        elif node.name == "~":
            return smt.Not(*convert_children(1))
        elif node.name == "^":
            return smt.Pow(*convert_children(2))
        elif node.name == "&":
            return smt.And(*convert_children())
        elif node.name == "|":
            return smt.Or(*convert_children())
        elif node.name == "*":
            return smt.Times(*convert_children())
        elif node.name == "+":
            return smt.Plus(*convert_children())
        elif node.name == "-":
            return smt.Minus(*convert_children(2))
        elif node.name == "<=":
            return smt.LE(*convert_children(2))
        elif node.name == ">=":
            return smt.GE(*convert_children(2))
        elif node.name == "<":
            return smt.LT(*convert_children(2))
        elif node.name == ">":
            return smt.GT(*convert_children(2))
        elif node.name == "=":
            return smt.Equals(*convert_children(2))
        elif node.name == "const":
            c_type, c_value = [child.name for child in node.children]
            if c_type == "bool":
                return smt.Bool(parse_bool(c_value))
            elif c_type == "real":
                return smt.Real(float(c_value))
            else:
                raise Exception("Unknown constant type {}".format(c_type))
        elif node.name == "var":
            v_type, v_name = [child.name for child in node.children]
            if v_type == "bool":
                v_smt_type = smt.BOOL
            elif v_type == "real":
                v_smt_type = smt.REAL
            else:
                raise Exception("Unknown variable type {}".format(v_type))
            return smt.Symbol(v_name, v_smt_type)
        else:
            raise RuntimeError("Unrecognized node type '{}'".format(node.name))
Beispiel #14
0
def test_convert_support():
    """Round-trip a mixed Boolean/real formula through SDD compilation.

    The formula's inequalities are abstracted into literals, compiled to an
    SDD and recovered. The solver must find no model on which the original
    and recovered formulas disagree.
    """
    x = smt.Symbol("x", smt.REAL)
    y = smt.Symbol("y", smt.REAL)
    a = smt.Symbol("a", smt.BOOL)
    formula = (x < 0) | (~a & (x < -1)) | smt.Ite(a, x < 4, x < 8)

    env, repl_formula, literal_info = extract_and_replace_literals(formula)
    sdd = compile_to_sdd(formula=repl_formula, literals=literal_info,
                         vtree=None)
    recovered = recover_formula(sdd_node=sdd, literals=literal_info, env=env)

    with smt.Solver() as solver:
        solver.add_assertion(~smt.Iff(formula, recovered))
        satisfiable = solver.solve()
        assert not satisfiable, \
            f"Expected UNSAT but found model {solver.get_model()}"
Beispiel #15
0
    def check(self, sys):
        """Recursively walk the system ``sys``, asserting guards and testing checks.

        Node types are matched exactly:

        - ``BD.VarIntro``: bind ``sys.name`` to an INT symbol named
          ``repr(sys.name)``, then continue with ``sys.cont``.
        - ``BD.RelIntro``: bind ``sys.name`` to a symbol of type
          ``(INT, ..., INT) -> BOOL`` with ``sys.n_args`` arguments, then
          continue.
        - ``BD.Guard``: assert the translated predicate, then continue.
        - ``BD.Both``: check ``sys.lhs`` between a solver/context push and
          pop so its effects are discarded, then check ``sys.rhs``.
        - ``BD.Check``: if the negated predicate is satisfiable, report an
          error through ``self._err``. Continue either way.
        - ``BD.NullSys``: stop.

        Any other type is ignored.
        """
        node_type = type(sys)

        if node_type is BD.NullSys:
            return

        if node_type is BD.VarIntro:
            symbol = SMT.Symbol(repr(sys.name), SMT.INT)
            self._ctxt.set(sys.name, symbol)
            self.check(sys.cont)
        elif node_type is BD.RelIntro:
            rel_type = SMT.FunctionType(SMT.BOOL, [SMT.INT] * sys.n_args)
            symbol = SMT.Symbol(repr(sys.name), rel_type)
            self._ctxt.set(sys.name, symbol)
            self.check(sys.cont)
        elif node_type is BD.Guard:
            self._slv.add_assertion(self.formula(sys.pred))
            self.check(sys.cont)
        elif node_type is BD.Both:
            # Explore the left branch in a scope that is popped afterwards.
            self._slv.push()
            self._ctxt.push()
            self.check(sys.lhs)
            self._ctxt.pop()
            self._slv.pop()
            # The right branch continues in the current scope.
            self.check(sys.rhs)
        elif node_type is BD.Check:
            negated = SMT.Not(self.formula(sys.pred))
            if self._slv.is_sat(negated):
                mapping = self._get_solution(negated)
                self._err(sys, f"Out of Bounds Access:\n{mapping}")
            self.check(sys.cont)
Beispiel #16
0
 def __encodeTerminal(symbolicExpression, type):
     """Encode a sympy symbol or constant as a pysmt term of the given type.

     Args:
         symbolicExpression: a ``sympy.Symbol``, encoded as a pysmt symbol of
             the same name, or a constant value.
         type: descriptor whose ``literal`` is ``'Integer'``, ``'Real'``,
             ``'Bool'``, or anything else. Anything else is treated as a
             bit-vector of width ``type.size``.

     Returns:
         The pysmt symbol or constant.
     """
     from fractions import Fraction

     literal = type.literal
     if isinstance(symbolicExpression, sympy.Symbol):
         if literal == 'Integer':
             return pysmt.Symbol(symbolicExpression.name, pysmt.INT)
         elif literal == 'Real':
             return pysmt.Symbol(symbolicExpression.name, pysmt.REAL)
         elif literal == 'Bool':
             return pysmt.Symbol(symbolicExpression.name, pysmt.BOOL)
         else:  # type.literal == 'BitVector'
             return pysmt.Symbol(symbolicExpression.name, pysmt.BVType(type.size))

     if literal == 'Integer':
         return pysmt.Int(symbolicExpression.p)
     elif literal == 'Real':
         if isinstance(symbolicExpression, sympy.Rational):
             # Exact p/q. True division would round through a float.
             return pysmt.Real(Fraction(int(symbolicExpression.p),
                                        int(symbolicExpression.q)))
         if isinstance(symbolicExpression, sympy.Float):
             # pysmt.Real takes Python numbers/Fractions, not sympy objects.
             return pysmt.Real(float(symbolicExpression))
         return pysmt.Real(symbolicExpression)
     elif literal == 'Bool':
         if isinstance(symbolicExpression, sympy.logic.boolalg.BooleanAtom):
             # pysmt.Bool expects a Python bool, not sympy's true/false.
             return pysmt.Bool(bool(symbolicExpression))
         return pysmt.Bool(symbolicExpression)
     else:  # type.literal == 'BitVector'
         if isinstance(symbolicExpression, sympy.Integer):
             # Hand pysmt.BV a plain Python int.
             return pysmt.BV(int(symbolicExpression), type.size)
         return pysmt.BV(symbolicExpression, type.size)
Beispiel #17
0
    def __init_subclass__(cls):
        """Register this subclass's uninterpreted functions in ``_uf_table``.

        Each ``(method_name, *types)`` entry of ``_SIGS`` becomes a function
        symbol named ``<ClassName>.<method_name>``. ``None`` entries stand for
        this subclass's bit-vector type, and the last entry is the return
        type. Empty per-width caches for ``to_sbv`` and ``to_ubv`` are also
        added.
        """
        table = {}
        _uf_table[cls] = table
        bv_type = BVType(cls.size)
        for method_name, *signature in _SIGS:
            resolved = [bv_type if t is None else t for t in signature]
            return_type = resolved[-1]
            param_types = resolved[:-1]
            symbol_name = '.'.join((cls.__name__, method_name))
            table[method_name] = shortcuts.Symbol(
                symbol_name, FunctionType(return_type, param_types))

        table['to_sbv'] = {}
        table['to_ubv'] = {}
Beispiel #18
0
    def query_dip_generator(self):
        """Read the distinguishing input sequence from the obfuscated-circuit solver.

        For each frame ``d`` in ``1..self.unroll_depth`` and each input wire
        ``w`` of ``self.obf_cir``, the value of symbol ``<w>@<d>`` is fetched
        with ``self.solver_obf.get_py_value``. It is mapped to a pysmt TRUE or
        FALSE constant.

        Returns:
            One list of TRUE/FALSE constants per frame, ordered like
            ``self.obf_cir.input_wires``.
        """
        def as_constant(wire, frame):
            symbol = pystm.Symbol(wire + '@{}'.format(frame))
            if self.solver_obf.get_py_value(symbol):
                return pystm.TRUE()
            return pystm.FALSE()

        return [
            [as_constant(w, d) for w in self.obf_cir.input_wires]
            for d in range(1, self.unroll_depth + 1)
        ]
Beispiel #19
0
    def _exp_to_smt(self, expression):
        """Translate a sympy integer expression into a pysmt INT term.

        Handled inputs:

        - ``sympy.Add`` becomes ``Plus``;
        - ``sympy.Mul`` becomes ``Times``;
        - ``sympy.Symbol`` becomes an INT symbol named ``str(symbol)``;
        - anything ``int()`` accepts becomes an ``Int`` constant.

        Raises:
            RuntimeError: for any other expression. This includes
                non-integer rationals such as ``Rational(1, 2)`` and symbolic
                terms that ``int()`` rejects.
        """
        if isinstance(expression, sympy.Add):
            return smt.Plus([self._exp_to_smt(arg) for arg in expression.args])
        elif isinstance(expression, sympy.Mul):
            return smt.Times(*[self._exp_to_smt(arg) for arg in expression.args])
        elif isinstance(expression, sympy.Symbol):
            return smt.Symbol(str(expression), INT)

        # int() truncates non-integer rationals (int(Rational(1, 2)) == 0),
        # which would silently change the value, so reject them.
        if isinstance(expression, sympy.Rational) and not expression.is_Integer:
            raise RuntimeError("Could not parse {} of type {}".format(expression, type(expression)))

        try:
            # NOTE(review): floats are still truncated here (2.5 -> 2);
            # confirm whether non-integral floats should be rejected too.
            value = int(expression)
        except (TypeError, ValueError):
            # e.g. int() raises TypeError for symbolic sympy terms like x**2.
            pass
        else:
            return smt.Int(value)
        raise RuntimeError("Could not parse {} of type {}".format(expression, type(expression)))
Beispiel #20
0
    def __init__(self, value=SMYBOLIC, *, name=AUTOMATIC, prefix=AUTOMATIC):
        """Create a Bit backed by a pysmt Boolean term.

        ``value`` may be:

        - left symbolic;
        - a Boolean FNode;
        - another ``SMTBit``;
        - a Python bool, or the ints 0/1;
        - any object defining ``__bool__``.

        A symbolic bit gets a variable named ``name`` (validated and
        registered in ``_name_table``), one generated from ``prefix``, or an
        automatically generated one. ``name`` and ``prefix`` are rejected for
        non-symbolic values and cannot both be set. The stored term is passed
        through ``smt.simplify``.
        """
        has_name = name is not AUTOMATIC
        has_prefix = prefix is not AUTOMATIC
        symbolic = value is SMYBOLIC

        if (has_name or has_prefix) and not symbolic:
            raise TypeError('Can only name symbolic variables')
        if has_name and has_prefix:
            raise ValueError('Can only set either name or prefix not both')

        if has_name:
            if not isinstance(name, str):
                raise TypeError('Name must be string')
            if name in _name_table:
                raise ValueError(f'Name {name} already in use')
            if _name_re.fullmatch(name):
                warnings.warn(
                    'Name looks like an auto generated name, this might break things'
                )
            _name_table[name] = self
        elif has_prefix:
            name = _gen_name(prefix)
            _name_table[name] = self
        elif symbolic:
            name = _gen_name()
            _name_table[name] = self

        if symbolic:
            term = smt.Symbol(name, BOOL)
        elif isinstance(value, pysmt.fnode.FNode):
            if not value.get_type().is_bool_type():
                raise TypeError(f'Expected bool type not {value.get_type()}')
            term = value
        elif isinstance(value, SMTBit):
            # Naming a non-symbolic value raises above, so ``name`` is always
            # AUTOMATIC here and this warning cannot fire as written.
            if name is not AUTOMATIC and name != value.name:
                warnings.warn(
                    'Changing the name of a SMTBit does not cause a new underlying smt variable to be created'
                )
            term = value._value
        elif isinstance(value, bool):
            term = smt.Bool(value)
        elif isinstance(value, int):
            if value not in {0, 1}:
                raise ValueError(
                    'Bit must have value 0 or 1 not {}'.format(value))
            term = smt.Bool(bool(value))
        elif hasattr(value, '__bool__'):
            term = smt.Bool(bool(value))
        else:
            raise TypeError("Can't coerce {} to Bit".format(type(value)))

        self._value = term
        self._name = name
        self._value = smt.simplify(term)
Beispiel #21
0
 def print_keys(self):
     """Solve for a key from the frame-0 obfuscated-circuit constraints and log it.

     Both constraint lists from ``obf_ckt_at_frame(0)`` are added pairwise to
     ``self.solver_key``; the assertions stay in the solver. If the solver
     finds a model, a WARNING is logged with:

     - the iteration count and highest depth;
     - the key: the ``'0'``/``'1'`` string of the ``<key_wire>_0`` values,
       in ``self.obf_cir.key_wires`` order.

     Otherwise a failure warning is logged.
     """
     first, second = self.attack_formulas.obf_ckt_at_frame(0)
     for i, constraint in enumerate(first):
         self.solver_key.add_assertion(constraint)
         self.solver_key.add_assertion(second[i])

     if not self.solver_key.solve():
         logging.warning('something is wrong! could not find a correct key')
         return

     bits = []
     for w in self.obf_cir.key_wires:
         is_set = self.solver_key.get_py_value(pystm.Symbol(w + '_0'))
         bits.append('1' if is_set else '0')
     key = ''.join(bits)
     logging.warning('iterations={}, highest depth={}'.format(self.iteration, self.highest_depth))
     logging.warning("key=%s" % key)
Beispiel #22
0
    def test_bv2pysmt(self):
        """bv2pysmt translates each bit-vector operation to its pysmt counterpart."""
        bvx, bvy = Variable("x", 8), Variable("y", 8)
        psx, psy = bv2pysmt(bvx), bv2pysmt(bvy)

        self.assertEqual(bv2pysmt(Constant(0, 8)), sc.BV(0, 8))
        self.assertEqual(psx, sc.Symbol("x", typing.BVType(8)))

        expectations = [
            # bitwise operations
            (~bvx, sc.BVNot(psx)),
            (bvx & bvy, sc.BVAnd(psx, psy)),
            (bvx | bvy, sc.BVOr(psx, psy)),
            (bvx ^ bvy, sc.BVXor(psx, psy)),
            # equality and unsigned comparisons
            (BvComp(bvx, bvy), sc.Equals(psx, psy)),
            (BvNot(BvComp(bvx, bvy)), sc.Not(sc.Equals(psx, psy))),
            (bvx < bvy, sc.BVULT(psx, psy)),
            (bvx <= bvy, sc.BVULE(psx, psy)),
            (bvx > bvy, sc.BVUGT(psx, psy)),
            (bvx >= bvy, sc.BVUGE(psx, psy)),
            # shifts and rotations
            (bvx << bvy, sc.BVLShl(psx, psy)),
            (bvx >> bvy, sc.BVLShr(psx, psy)),
            (RotateLeft(bvx, 1), sc.BVRol(psx, 1)),
            (RotateRight(bvx, 1), sc.BVRor(psx, 1)),
            # extraction, concatenation, repetition
            # (ZeroExtend reduces to Concat, so it is not checked separately.)
            (bvx[4:2], sc.BVExtract(psx, 2, 4)),
            (Concat(bvx, bvy), sc.BVConcat(psx, psy)),
            (Repeat(bvx, 2), psx.BVRepeat(2)),
            # arithmetic (bvsum reduces to add, so bvx - bvy is not checked)
            (-bvx, sc.BVNeg(psx)),
            (bvx + bvy, sc.BVAdd(psx, psy)),
            (bvx * bvy, sc.BVMul(psx, psy)),
            (bvx / bvy, sc.BVUDiv(psx, psy)),
            (bvx % bvy, sc.BVURem(psx, psy)),
        ]
        for bv_expr, expected in expectations:
            self.assertEqual(bv2pysmt(bv_expr), expected)
 def symbol(self, name):
     """Return the pysmt symbol ``name`` of type REAL."""
     return smt.Symbol(name, typename=REAL)
Beispiel #24
0
def balancer_flow_formula(junctions, width, length):
    """Build pysmt constraints describing flow through a grid of belt junctions.

    Args:
        junctions: iterable of ``(x, y1, y2)`` triples. A junction at column
            ``x`` couples rows ``y1`` and ``y2``. Their column ``x - 1`` belts
            are its inputs and their column ``x`` belts are its outputs.
        width: number of belt rows.
        length: highest column index; every row has belts at columns
            ``0..length``.

    Returns:
        ``(formula, belts)``. ``formula`` is a list of pysmt constraints.
        ``belts[i][x]`` is the ``Belt`` for row ``i`` at column ``x``, built
        from REAL symbols ``b{i}[{x}].rho`` and ``b{i}[{x}].v``.

    Raises:
        Exception: if the per-belt ``domain`` constraints are unsatisfiable
            on their own.
    """
    formula = []
    belts = [[
        Belt(s.Symbol(f'b{i}[{x}].rho', t.REAL),
             s.Symbol(f'b{i}[{x}].v', t.REAL)) for x in range(length + 1)
    ] for i in range(width)]
    for beltway in belts:
        for belt in beltway:
            formula.extend(domain(belt))
    # Sanity check: the per-belt domain constraints must be consistent
    # before any junction rules are added.
    if not s.is_sat(s.And(formula)):
        raise Exception('Domain is not SAT :/')

    # Balancing rules.
    junctions_by_x = [[] for x in range(length + 1)]
    for (x, y1, y2) in junctions:
        junctions_by_x[x].append((y1, y2))
        # NOTE(review): a junction at x == 0 gives inn == -1, which indexes
        # the last column. Presumably callers only pass x >= 1; confirm.
        inn = x - 1
        out = x

        input_rho = s.Plus(belts[y1][inn].rho, belts[y2][inn].rho)
        # We want to put half of the input on each output.
        half_input = s.Div(input_rho, s.Real(2))

        # If output velocity is less than the half_input that we would like to
        # assign to it, we've filled it. Velocity is a hard limit, because it's out
        # of influence of this splitter. We'll set the output density to 1 in that
        # case. Aside: The flux is the min of that and the velocity, so if
        # out-velocity is the limiting factor, it won't change the flux calculation
        # to just assign rho_out = v_out.
        #
        # Now, the excess that we couldn't assign has to go somewhere: (1) to the
        # other output belt; if that's full, (2) feed back up the chain by reducing
        # input velocities.
        excess_from_1 = s.Max(s.Real(0), s.Minus(half_input, belts[y1][out].v))
        excess_from_2 = s.Max(s.Real(0), s.Minus(half_input, belts[y2][out].v))

        # This formula is most accurate for assignment > velocity (density will
        # equal 1), but it doesn't change the flux calculation just to set rho to
        # the velocity when velocity limits flow. (So you should be able to replace
        # the Ite by v_out and be OK.)
        formula.append(
            s.Equals(
                belts[y1][out].rho,
                s.Ite(half_input + excess_from_2 > belts[y1][out].v, s.Real(1),
                      half_input + excess_from_2)))
        formula.append(
            s.Equals(
                belts[y2][out].rho,
                s.Ite(half_input + excess_from_1 > belts[y2][out].v, s.Real(1),
                      half_input + excess_from_1)))

        # Symmetric rule on the input side: split the output velocity evenly
        # and push density the other input cannot use onto this one.
        output_v = s.Plus(belts[y1][out].v, belts[y2][out].v)
        half_output = s.Div(output_v, s.Real(2))
        unused_density_from_1 = s.Max(s.Real(0),
                                      s.Minus(half_output, belts[y1][inn].rho))
        unused_density_from_2 = s.Max(s.Real(0),
                                      s.Minus(half_output, belts[y2][inn].rho))

        formula.append(
            s.Equals(
                belts[y1][inn].v,
                s.Ite(half_output + unused_density_from_2 > belts[y1][inn].rho,
                      s.Real(1), half_output + unused_density_from_2)))
        formula.append(
            s.Equals(
                belts[y2][inn].v,
                s.Ite(half_output + unused_density_from_1 > belts[y2][inn].rho,
                      s.Real(1), half_output + unused_density_from_1)))
        # Conservation of flux at each junction.
        input_flux = s.Plus(belts[y1][inn].flux, belts[y2][inn].flux)
        output_flux = s.Plus(belts[y1][out].flux, belts[y2][out].flux)
        formula.append(s.Equals(input_flux, output_flux))

    # Any belts not involved in a junction are pass-throughs. Their successive
    # values must remain equal.
    # thru_belts[x] lists the rows with no junction at column x.
    thru_belts = [
        list(
            set(range(width)) - {y1
                                 for y1, y2 in junctions_by_x[x]} -
            {y2
             for y1, y2 in junctions_by_x[x]}) for x in range(length + 1)
    ]
    # Column 0 is skipped. For column x + 1, each pass-through row copies rho
    # and v from column x.
    for x, thru in enumerate(thru_belts[1:]):
        for y in thru:
            formula.append(s.Equals(belts[y][x].rho, belts[y][x + 1].rho))
            formula.append(s.Equals(belts[y][x].v, belts[y][x + 1].v))

    return formula, belts
Beispiel #25
0
def find_cnf(data, domain, active_indices, solver, n_c, n_h):
    """Search for a CNF of half-spaces and Boolean literals consistent with the active rows.

    Encoding:

    - ``n_h`` half-spaces ``a_hr[h] . x <= b_h[h]``.
    - ``n_c`` clauses. ``s_ch[c][h]`` selects half-space ``h`` for clause
      ``c``. ``s_cb[c][b]`` selects a Boolean literal: indices below
      ``n_b_original`` are the Boolean variables, the rest are their
      negations.
    - For row ``i``, ``s_ih[i][h]`` is equivalent to ``a . x + DELTA <= b``
      when the row is labelled true, and to ``a . x - DELTA <= b``
      otherwise.
    - ``s_ic[i][c]`` is equivalent to clause ``c`` holding on row ``i``.

    True-labelled rows must satisfy every clause; false-labelled rows must
    violate at least one.

    Returns:
        The learned formula (an And of Ors over the domain's variables), or
        None when the solver finds the constraints unsatisfiable.
    """
    # Constants
    n_b_original = len(domain.bool_vars)
    n_b = n_b_original * 2
    n_r = len(domain.real_vars)
    n_d = len(data)

    # Each entry of data is a (row, label) pair.
    real_features = [[row[v] for v in domain.real_vars] for row, _ in data]
    bool_features = [[row[v] for v in domain.bool_vars] for row, _ in data]
    labels = [row[1] for row in data]

    # Variables
    a_hr = [[
        smt.Symbol("a_hr[{}][{}]".format(h, r), REAL) for r in range(n_r)
    ] for h in range(n_h)]
    b_h = [smt.Symbol("b_h[{}]".format(h), REAL) for h in range(n_h)]
    s_ch = [[smt.Symbol("s_ch[{}][{}]".format(c, h)) for h in range(n_h)]
            for c in range(n_c)]
    s_cb = [[smt.Symbol("s_cb[{}][{}]".format(c, b)) for b in range(n_b)]
            for c in range(n_c)]

    # Aux variables
    s_ih = [[smt.Symbol("s_ih[{}][{}]".format(i, h)) for h in range(n_h)]
            for i in range(n_d)]
    s_ic = [[smt.Symbol("s_ic[{}][{}]".format(i, c)) for c in range(n_c)]
            for i in range(n_d)]

    # Constraints
    for i in active_indices:
        x_r, x_b, label = real_features[i], bool_features[i], labels[i]

        for h in range(n_h):
            sum_coefficients = smt.Plus(
                [a_hr[h][r] * smt.Real(x_r[r]) for r in range(n_r)])
            if label:
                solver.add_assertion(
                    smt.Iff(s_ih[i][h], sum_coefficients + DELTA <= b_h[h]))
            else:
                solver.add_assertion(
                    smt.Iff(s_ih[i][h], sum_coefficients - DELTA <= b_h[h]))

        # Clause c holds on row i iff a selected half-space holds or a
        # selected literal is true. The leading FALSE keeps the Or non-empty
        # when nothing is selected.
        for c in range(n_c):
            solver.add_assertion(
                smt.Iff(
                    s_ic[i][c],
                    smt.Or([smt.FALSE()] + [(s_ch[c][h] & s_ih[i][h])
                                            for h in range(n_h)] +
                           [s_cb[c][b]
                            for b in range(n_b_original) if x_b[b]] + [
                                s_cb[c][b] for b in range(n_b_original, n_b)
                                if not x_b[b - n_b_original]
                            ])))

        if label:
            solver.add_assertion(smt.And([s_ic[i][c] for c in range(n_c)]))
        else:
            solver.add_assertion(smt.Or([~s_ic[i][c] for c in range(n_c)]))

    if not solver.solve():
        return None
    model = solver.get_model()

    x_vars = [domain.get_symbol(domain.real_vars[r]) for r in range(n_r)]
    half_spaces = [
        smt.Plus([model.get_value(a_hr[h][r]) * x_vars[r]
                  for r in range(n_r)]) <= model.get_value(b_h[h])
        for h in range(n_h)
    ]

    b_vars = [
        domain.get_symbol(domain.bool_vars[b]) for b in range(n_b_original)
    ]
    # Same layout as s_cb: positive literals first, then negations.
    bool_literals = [b_vars[b] for b in range(n_b_original)]
    bool_literals += [~b_vars[b] for b in range(n_b - n_b_original)]

    # Despite the name, each entry lists the selected literals of one clause
    # (a disjunction); the clauses are conjoined below.
    conjunctions = [
        [half_spaces[h]
         for h in range(n_h) if model.get_py_value(s_ch[c][h])] + [
             bool_literals[b]
             for b in range(n_b) if model.get_py_value(s_cb[c][b])
         ] for c in range(n_c)
    ]

    return smt.And([smt.Or(conjunction) for conjunction in conjunctions])
Beispiel #26
0
 def to_symbol(s):
     """Return the pysmt INT symbol named ``s``."""
     return smt.Symbol(s, smt.types.INT)
Beispiel #27
0
 def walk_symbol(self, name, v_type):
     """Return the pysmt symbol ``name`` with type ``v_type``."""
     return smt.Symbol(name, typename=v_type)
Beispiel #28
0
def example1(domain):
    """Return ``domain`` with the half-plane ``x + y <= 0.5`` over reals x, y."""
    x = smt.Symbol("x", REAL)
    y = smt.Symbol("y", REAL)
    half_plane = x + y <= 0.5
    return domain, half_plane
Beispiel #29
0
def example6(domain):
    """Return ``domain`` with a fixed formula over reals x and y.

    The second conjunct negates the second disjunct of the first clause, so
    the formula can only be satisfied through the first disjunct.
    """
    x = smt.Symbol("x", REAL)
    y = smt.Symbol("y", REAL)
    lhs_1 = 2.00504448571 * x + -2.40276942762 * y
    lhs_2 = 11.0066321223 * x + -9.72098326672 * y
    either = (-1.27554738321 < lhs_1) | (4.56336137649 < lhs_2)
    bound = lhs_2 <= 4.56336137649
    return domain, either & bound
Beispiel #30
0
def example5(domain):
    """Return ``domain`` with a fixed formula over reals x and y.

    The second conjunct negates the second disjunct of the first clause, so
    the formula can only be satisfied through the first disjunct.
    """
    x = smt.Symbol("x", REAL)
    y = smt.Symbol("y", REAL)
    lhs_1 = 2.82223533496 * x + -2.86421413834 * y
    lhs_2 = 5.75692214636 * x + -5.67797696689 * y
    either = (-1.81491574069 < lhs_1) | (1.74295350642 < lhs_2)
    bound = lhs_2 <= 1.74295350642
    return domain, either & bound