Example #1
0
def special_constraint(constraint, kind):
    """Replace a special linear constraint by an equivalent quadratic penalty.

    Special constraints are the following constraints on binary variables,
    passed as ``constraint = lhs - rhs`` together with ``kind``.  The right
    column is the returned penalty: for variables in {0, 1} it is 0 exactly
    when the constraint holds and positive when it is violated.

    (1) x + y <= 1          | xy
    (2) x1 + x2 + x3 <= 1   | x1x2 + x1x3 + x2x3
    (3) x <= y              | x - xy
    (4) x = y               | x + y - 2xy
    (5) x + y >= 1          | 1 - x - y + xy
    (6) x + y = 1           | 1 - x - y + 2xy

    Parameters
    ----------
    constraint : sympy.core.expr.Expr or subclasses
        Left-hand side minus right-hand side of the constraint.
    kind : string
        'lt' (``constraint <= 0``), 'gt' (``constraint >= 0``) or
        'eq' (``constraint == 0``).

    Returns
    -------
    sympy.core.expr.Expr, bool
        The penalty and True if the constraint is special, otherwise the
        unchanged ``constraint`` and False.
    """
    from itertools import combinations

    if not constraint.free_symbols:
        # Poly() cannot be built from an expression without generators.
        return constraint, False

    terms = dict(Poly(constraint).as_expr().as_coefficients_dict())
    # Key 1 holds the constant term; it is absent when the constant is 0.
    constant = terms.pop(1, 0)
    variables = list(terms)

    # Only sums of distinct plain symbols with +-1 coefficients qualify;
    # terms such as x*y or x**2 would otherwise be mistaken for variables.
    if not all(var.is_Symbol for var in variables):
        return constraint, False
    if not all(coeff in (1, -1) for coeff in terms.values()):
        return constraint, False

    n_vars = len(variables)
    positive = [var for var, coeff in terms.items() if coeff == 1]
    # Pairwise products of the constraint's own variables (x*y, x*z, ...).
    products = [a * b for a, b in combinations(variables, 2)]

    # (1), (2): x + y - 1 <= 0, x1 + x2 + x3 - 1 <= 0
    if (kind == 'lt' and constant == -1 and n_vars in (2, 3)
            and len(positive) == n_vars):
        return sum(products), True

    # (3), (4): x - y <= 0, x - y == 0
    if (kind in ('lt', 'eq') and constant == 0 and n_vars == 2
            and len(positive) == 1):
        product = products[0]
        if kind == 'lt':
            # Violated only when the +1 variable (x) is 1 and the -1
            # variable (y) is 0, so the penalty is x*(1 - y).
            return positive[0] - product, True
        return sum(variables) - 2 * product, True

    # (5), (6): x + y - 1 >= 0, x + y - 1 == 0
    if (kind in ('gt', 'eq') and constant == -1 and n_vars == 2
            and len(positive) == 2):
        # (1 - x)*(1 - y) for '>=', and (x + y - 1)**2 reduced with
        # x**2 = x, y**2 = y for '='.
        weight = 1 if kind == 'gt' else 2
        return 1 - sum(variables) + weight * products[0], True

    return constraint, False