Exemplo n.º 1
0
def prox_second_order_cone(expr):
    """Match a second-order cone constraint and build its prox function.

    Two input shapes are recognized:

    * an INDICATOR expression over a SECOND_ORDER cone, whose arguments are
      used directly; or
    * an epigraph (as decomposed by ``get_epigraph``) whose function part is
      a NORM_P with ``p == 2``. It is rewritten to the pair ``(t, x)``, with
      ``x`` reshaped to a ``1 x dim(x)`` row.

    Returns ``MatchResult(False)`` when neither shape applies. Otherwise it
    returns a successful ``MatchResult`` holding a SECOND_ORDER_CONE prox
    function of the two ``convert_scalar``-converted arguments, plus the
    constraints produced by those conversions.
    """
    cone_args = []
    is_soc_indicator = (expr.expression_type == Expression.INDICATOR and
                        expr.cone.cone_type == Cone.SECOND_ORDER)
    if is_soc_indicator:
        cone_args = expr.arg
    else:
        f_expr, t_expr = get_epigraph(expr)
        is_norm_2 = (f_expr and
                     f_expr.expression_type == Expression.NORM_P and
                     f_expr.p == 2)
        if is_norm_2:
            x_expr = f_expr.arg[0]
            # The second argument is passed on as a row vector.
            x_row = expression.reshape(x_expr, 1, dim(x_expr))
            cone_args = [t_expr, x_row]

    if not cone_args:
        return MatchResult(False)

    first, second = cone_args[0], cone_args[1]
    scalar_first, first_constrs = convert_scalar(first)
    scalar_second, second_constrs = convert_scalar(second)
    soc_prox = create_prox(
        prox_function_type=ProxFunction.SECOND_ORDER_CONE,
        arg_size=[Size(dim=dims(first)), Size(dim=dims(second))])
    prox_expr = expression.prox_function(soc_prox, scalar_first, scalar_second)
    return MatchResult(True, prox_expr, first_constrs + second_constrs)
Exemplo n.º 2
0
def ones(*dims):
    """Return a CONSTANT expression of shape ``dims`` filled with ones.

    The dense array ``np.ones(dims)`` is registered through
    ``_constant.store``. That call receives a fresh ``data`` dict, and the
    same dict is attached to the returned expression.
    """
    expr_data = {}
    stored_constant = _constant.store(np.ones(dims), expr_data)
    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        size=Size(dim=dims),
        data=expr_data,
        constant=stored_constant,
        func_curvature=CONSTANT)
Exemplo n.º 3
0
def neg_log_det_epigraph(expr):
    """Match the epigraph form of a negative log-determinant.

    ``expr.arg[0]`` must have exactly two arguments, one of which is a
    LOG_DET expression. The argument of that LOG_DET becomes the
    (``convert_scalar``-converted) first prox argument. The other term is
    passed through unchanged as the second prox argument. The prox function
    is NEG_LOG_DET with ``alpha=1`` and ``epigraph=True``.

    Returns ``MatchResult(False)`` if the pattern does not match.
    """
    terms = expr.arg[0].arg
    if len(terms) != 2:
        return MatchResult(False)

    # Use the first term (in argument order) that is a LOG_DET.
    log_det_idx = next(
        (i for i in range(2)
         if terms[i].expression_type == Expression.LOG_DET),
        None)
    if log_det_idx is None:
        return MatchResult(False)

    log_det_expr = terms[log_det_idx]
    other_expr = terms[1 - log_det_idx]

    matrix_arg = log_det_expr.arg[0]
    scalar_arg, constrs = convert_scalar(matrix_arg)

    epi_function = create_prox(
        alpha=1,
        prox_function_type=ProxFunction.NEG_LOG_DET,
        arg_size=[Size(dim=dims(matrix_arg))])
    epi_function.epigraph = True

    return MatchResult(
        True,
        expression.prox_function(epi_function, scalar_arg, other_expr),
        constrs)
Exemplo n.º 4
0
def epigraph(expr):
    """Transform an expression in epigraph form ``f <= t``.

    ``get_epigraph`` splits ``expr`` into ``(f, t)``. Each rule in
    ``BASE_RULES`` is then tried on ``f`` in order. On the first match:

    * the matched prox function gets ``epigraph = True``;
    * an extra argument size for ``t`` is appended;
    * ``t`` (after ``linear.transform_expr``) is appended as its last
      argument. If ``t`` is not scalar, it is first passed through
      ``epi_transform(..., "scalar")``, which may add constraints.

    If no base rule matches, ``f`` goes through ``conic.transform_expr``. The
    result is then expressed purely as constraints (``prox_expr`` is None).

    Returns ``MatchResult(False)`` if ``expr`` is not in epigraph form.
    """
    f_expr, t_expr = get_epigraph(expr)
    if not f_expr:
        # Not in epigraph form
        return MatchResult(False)

    for rule in BASE_RULES:
        result = rule(f_expr)
        if not result.match:
            continue

        epi_function = result.prox_expr.prox_function
        epi_function.epigraph = True
        epi_function.arg_size.add().CopyFrom(Size(dim=dims(t_expr)))

        linear_t_expr = linear.transform_expr(t_expr)
        if linear_t_expr.affine_props.scalar:
            t_constrs = []
        else:
            linear_t_expr, t_constrs = epi_transform(linear_t_expr, "scalar")

        prox_args = result.prox_expr.arg + [linear_t_expr]
        return MatchResult(
            True,
            expression.prox_function(epi_function, *prox_args),
            result.raw_exprs + t_constrs)

    # No epigraph transform found, do conic transformation
    obj, constrs = conic.transform_expr(f_expr)
    return MatchResult(
        True,
        None,
        [expression.leq_constraint(obj, t_expr)] + constrs)
Exemplo n.º 5
0
def variable(m, n, variable_id):
    """Return an ``m`` x ``n`` VARIABLE expression identified by ``variable_id``.

    The curvature is AFFINE and flagged both elementwise and scalar-multiple.
    """
    affine_curvature = Curvature(
        curvature_type=Curvature.AFFINE,
        elementwise=True,
        scalar_multiple=True)
    var_proto = expression_pb2.Variable(variable_id=variable_id)
    return Expression(
        expression_type=expression_pb2.Expression.VARIABLE,
        size=Size(dim=[m, n]),
        variable=var_proto,
        func_curvature=affine_curvature)
Exemplo n.º 6
0
def vstack(*args):
    """Vertically stack ``args`` into a single AFFINE VSTACK expression.

    The result size is found by folding ``stack_dims(a, b, 0)`` over the
    dimensions of the arguments, from left to right.

    Raises:
        ValueError: if called with no arguments.
    """
    # Without this check, reduce() over an empty sequence raises an opaque
    # TypeError; add/_multiply reject empty input the same way.
    if not args:
        raise ValueError("stacking null args")

    stacked_dims = reduce(lambda a, b: stack_dims(a, b, 0),
                          (dims(a) for a in args))
    return Expression(
        expression_type=expression_pb2.Expression.VSTACK,
        func_curvature=AFFINE,
        size=Size(dim=stacked_dims),
        arg=args)
Exemplo n.º 7
0
def diag_vec(x):
    """Return the ``n`` x ``n`` AFFINE DIAG_VEC expression built from ``x``.

    ``n`` is the number of rows of ``x``.

    Raises:
        ExpressionError: if ``x`` does not have exactly one column.
    """
    if dim(x, 1) == 1:
        n = dim(x, 0)
        return Expression(
            expression_type=expression_pb2.Expression.DIAG_VEC,
            size=Size(dim=[n, n]),
            func_curvature=AFFINE,
            arg=[x])
    raise ExpressionError("diag_vec on non vector")
Exemplo n.º 8
0
def scalar_constant(scalar, size=None):
    """Return a CONSTANT expression that holds a SCALAR constant value.

    Args:
        scalar: value stored in the ``Constant`` proto.
        size: dimensions of the expression; ``(1, 1)`` when omitted.
    """
    shape = (1, 1) if size is None else size
    scalar_proto = expression_pb2.Constant(
        constant_type=expression_pb2.Constant.SCALAR,
        scalar=scalar)
    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        size=Size(dim=shape),
        constant=scalar_proto,
        func_curvature=CONSTANT)
Exemplo n.º 9
0
def linear_map(A, x):
    """Apply the linear map ``A`` to the column vector ``x``.

    Returns an AFFINE LINEAR_MAP expression of size ``A.m`` x 1 that carries
    ``A.proto`` and ``A.data``.

    Raises:
        ExpressionError: if ``x`` does not have exactly one column, or if
            ``A.n`` differs from the total dimension of ``x``.
    """
    if dim(x, 1) != 1:
        raise ExpressionError("applying linear map to non vector", x)
    if A.n != dim(x):
        raise ExpressionError("linear map has wrong size: %s" % A, x)

    out_size = Size(dim=[A.m, 1])
    return Expression(
        expression_type=expression_pb2.Expression.LINEAR_MAP,
        size=out_size,
        func_curvature=AFFINE,
        linear_map=A.proto,
        data=A.data,
        arg=[x])
Exemplo n.º 10
0
def prox_lambda_max(expr):
    """Match a LAMBDA_MAX expression and build a LAMBDA_MAX prox function.

    Any other expression type gives ``MatchResult(False)``. Otherwise the
    result is a successful match over the ``convert_scalar``-converted
    argument, together with the constraints that conversion produced.
    """
    if expr.expression_type != Expression.LAMBDA_MAX:
        return MatchResult(False)

    inner = expr.arg[0]
    converted, constrs = convert_scalar(inner)
    prox = create_prox(prox_function_type=ProxFunction.LAMBDA_MAX,
                       arg_size=[Size(dim=dims(inner))])
    return MatchResult(True, expression.prox_function(prox, converted), constrs)
Exemplo n.º 11
0
def parameter(m, n, parameter_id, constant_type, sign):
    """Return an ``m`` x ``n`` CONSTANT expression backed by a parameter.

    The ``Constant`` proto records ``constant_type``, ``parameter_id`` and
    the ``m``/``n`` shape. ``sign`` is attached to the expression as given.
    """
    # NOTE(mwytock): we assume all parameters are dense matrices for purposes of
    # symbolic transformation.
    param_constant = expression_pb2.Constant(
        constant_type=constant_type,
        parameter_id=parameter_id,
        m=m,
        n=n)
    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        size=Size(dim=[m, n]),
        func_curvature=CONSTANT,
        constant=param_constant,
        sign=sign)
Exemplo n.º 12
0
def add(*args):
    """Return the AFFINE ADD expression summing ``args``.

    The result size is ``elementwise_dims`` folded over the argument
    dimensions, and every argument is marked INCREASING.

    Raises:
        ValueError: if no arguments are given.
    """
    if not args:
        raise ValueError("adding null args")

    arg_dims = (dims(a) for a in args)
    return Expression(
        expression_type=expression_pb2.Expression.ADD,
        arg=args,
        size=Size(dim=reduce(elementwise_dims, arg_dims)),
        arg_monotonicity=[INCREASING] * len(args),
        func_curvature=AFFINE)
Exemplo n.º 13
0
def _multiply(args, elemwise=False):
    """Build a MULTIPLY or MULTIPLY_ELEMENTWISE expression over ``args``.

    Args:
        args: non-empty sequence of operand expressions.
        elemwise: if True, build MULTIPLY_ELEMENTWISE and fold
            ``elementwise_dims`` for the size. Otherwise build MULTIPLY
            and fold ``matrix_multiply_dims``.

    Raises:
        ValueError: if ``args`` is empty.
    """
    if not args:
        raise ValueError("multiplying null args")

    if elemwise:
        expr_type = expression_pb2.Expression.MULTIPLY_ELEMENTWISE
        combine = elementwise_dims
    else:
        expr_type = expression_pb2.Expression.MULTIPLY
        combine = matrix_multiply_dims

    result_dims = reduce(combine, (dims(a) for a in args))
    return Expression(
        expression_type=expr_type,
        arg=args,
        size=Size(dim=result_dims),
        func_curvature=AFFINE)
Exemplo n.º 14
0
def prox_norm_nuclear(expr):
    """Match a NORM_NUC expression and build a NORM_NUCLEAR prox function.

    Any other expression type gives ``MatchResult(False)``. Otherwise the
    argument is converted with ``convert_scalar``, and the conversion's
    constraints are returned with the match.
    """
    if expr.expression_type != Expression.NORM_NUC:
        return MatchResult(False)

    inner = expr.arg[0]
    converted, constrs = convert_scalar(inner)
    prox = create_prox(prox_function_type=ProxFunction.NORM_NUCLEAR,
                       arg_size=[Size(dim=dims(inner))])
    return MatchResult(True, expression.prox_function(prox, converted), constrs)
Exemplo n.º 15
0
def prox_semidefinite(expr):
    """Match a SEMIDEFINITE cone indicator and build its prox function.

    Only INDICATOR expressions over a SEMIDEFINITE cone match. The
    indicator's first argument is converted with ``convert_scalar`` and
    wrapped in a SEMIDEFINITE prox function.
    """
    is_psd_indicator = (expr.expression_type == Expression.INDICATOR and
                        expr.cone.cone_type == Cone.SEMIDEFINITE)
    if not is_psd_indicator:
        return MatchResult(False)

    inner = expr.arg[0]
    converted, constrs = convert_scalar(inner)
    prox = create_prox(prox_function_type=ProxFunction.SEMIDEFINITE,
                       arg_size=[Size(dim=dims(inner))])
    return MatchResult(True, expression.prox_function(prox, converted), constrs)
Exemplo n.º 16
0
def prox_log_det(expr):
    """Match a LOG_DET expression and build its prox function.

    ``log_det`` is represented as a NEG_LOG_DET prox function scaled by
    ``alpha=-1``. Any other expression type gives ``MatchResult(False)``.
    """
    if expr.expression_type != Expression.LOG_DET:
        return MatchResult(False)

    inner = expr.arg[0]
    converted, constrs = convert_scalar(inner)
    prox = create_prox(alpha=-1,
                       prox_function_type=ProxFunction.NEG_LOG_DET,
                       arg_size=[Size(dim=dims(inner))])
    return MatchResult(True, expression.prox_function(prox, converted), constrs)
Exemplo n.º 17
0
def prox_norm_1(expr):
    """Match an l1 norm (NORM_P with ``p == 1``) and build a NORM_1 prox function.

    The argument is converted with ``convert_diagonal``. Any other
    expression gives ``MatchResult(False)``.
    """
    is_l1 = expr.expression_type == Expression.NORM_P and expr.p == 1
    if not is_l1:
        return MatchResult(False)

    inner = expr.arg[0]
    diag_arg, constrs = convert_diagonal(inner)
    prox = create_prox(prox_function_type=ProxFunction.NORM_1,
                       arg_size=[Size(dim=dims(inner))])
    return MatchResult(True, expression.prox_function(prox, diag_arg), constrs)
Exemplo n.º 18
0
def prox_sum_hinge(expr):
    """Match a sum-of-hinge expression and build a SUM_HINGE prox function.

    The hinge argument is located with ``get_hinge_arg``; if there is none,
    the match fails. The argument is converted with ``convert_diagonal``.
    The prox function copies ``has_axis``/``axis`` from ``expr``, and the
    resulting expression takes the size of ``expr``.
    """
    hinge_arg = get_hinge_arg(expr)
    if not hinge_arg:
        return MatchResult(False)

    diag_arg, constrs = convert_diagonal(hinge_arg)
    hinge_prox = create_prox(
        prox_function_type=ProxFunction.SUM_HINGE,
        arg_size=[Size(dim=dims(hinge_arg))],
        has_axis=expr.has_axis,
        axis=expr.axis)
    prox_expr = expression.prox_function(hinge_prox, diag_arg, size=dims(expr))
    return MatchResult(True, prox_expr, constrs)
Exemplo n.º 19
0
def prox_log_sum_exp(expr):
    """Match a LOG_SUM_EXP expression and build its prox function.

    The argument is converted with ``convert_scalar``. The prox function
    copies ``has_axis``/``axis`` from ``expr``, and the resulting expression
    takes the size of ``expr``. Any other expression type gives
    ``MatchResult(False)``.
    """
    if expr.expression_type != Expression.LOG_SUM_EXP:
        return MatchResult(False)

    inner = expr.arg[0]
    converted, constrs = convert_scalar(inner)
    lse_prox = create_prox(
        prox_function_type=ProxFunction.LOG_SUM_EXP,
        arg_size=[Size(dim=dims(inner))],
        has_axis=expr.has_axis,
        axis=expr.axis)
    prox_expr = expression.prox_function(lse_prox, converted, size=dims(expr))
    return MatchResult(True, prox_expr, constrs)
Exemplo n.º 20
0
def reshape(arg, m, n):
    """Return ``arg`` reshaped to ``m`` x ``n`` as an AFFINE RESHAPE expression.

    Shortcuts:
      * if ``arg`` already has shape ``(m, n)``, it is returned unchanged;
      * if ``arg`` is itself a RESHAPE whose input has shape ``(m, n)``, that
        input is returned, so the two reshapes cancel out.

    The sign of ``arg`` is carried over to the result.

    Raises:
        ExpressionError: if ``m * n`` differs from the total dimension of ``arg``.
    """
    same_shape = dim(arg, 0) == m and dim(arg, 1) == n
    if same_shape:
        return arg

    if dim(arg) != m * n:
        raise ExpressionError("cant reshape to %d x %d" % (m, n), arg)

    # Two reshapes that "undo" each other cancel out.
    if arg.expression_type == expression_pb2.Expression.RESHAPE:
        inner = arg.arg[0]
        if dim(inner, 0) == m and dim(inner, 1) == n:
            return inner

    return Expression(
        expression_type=expression_pb2.Expression.RESHAPE,
        arg=[arg],
        size=Size(dim=[m, n]),
        func_curvature=AFFINE,
        sign=arg.sign)
Exemplo n.º 21
0
def index(x, start_i, stop_i, start_j=None, stop_j=None):
    """Return the AFFINE INDEX expression ``x[start_i:stop_i, start_j:stop_j]``.

    Args:
        x: expression to slice.
        start_i, stop_i: row bounds (step 1).
        start_j: first column; defaults to 0.
        stop_j: end column (exclusive); defaults to the number of columns of
            ``x``.

    If the requested slice has the same shape as ``x``, ``x`` itself is
    returned.
    """
    # Default each missing column bound on its own. Previously the defaults
    # applied only when BOTH were None, so giving just one of them left a
    # None in the subtraction below and raised TypeError.
    if start_j is None:
        start_j = 0
    if stop_j is None:
        stop_j = x.size.dim[1]

    rows = stop_i - start_i
    cols = stop_j - start_j
    if dim(x, 0) == rows and dim(x, 1) == cols:
        return x

    return Expression(
        expression_type=expression_pb2.Expression.INDEX,
        size=Size(dim=[rows, cols]),
        func_curvature=AFFINE,
        key=[
            expression_pb2.Slice(start=start_i, stop=stop_i, step=1),
            expression_pb2.Slice(start=start_j, stop=stop_j, step=1),
        ],
        arg=[x])
Exemplo n.º 22
0
def prox_sum_deadzone(expr):
    """Match a sum of deadzone penalties and build a SUM_DEADZONE prox function.

    Matches when the hinge argument found by ``get_hinge_arg`` is an ADD of
    two terms: first an ABS, second a term whose value ``c`` (from
    ``get_scalar_constant``) satisfies ``c <= 0``. The argument of the ABS
    is converted with ``convert_diagonal``. The prox function's
    ``ScaledZoneParams`` get ``m = -c``. Otherwise the match fails.
    """
    hinge_arg = get_hinge_arg(expr)
    deadzone_arg = None
    looks_like_deadzone = (
        hinge_arg
        and hinge_arg.expression_type == Expression.ADD
        and len(hinge_arg.arg) == 2
        and hinge_arg.arg[0].expression_type == Expression.ABS)
    if looks_like_deadzone:
        offset = get_scalar_constant(hinge_arg.arg[1])
        if offset <= 0:
            deadzone_arg = hinge_arg.arg[0].arg[0]
    if not deadzone_arg:
        return MatchResult(False)

    diag_arg, constrs = convert_diagonal(deadzone_arg)
    deadzone_prox = create_prox(
        prox_function_type=ProxFunction.SUM_DEADZONE,
        scaled_zone_params=ProxFunction.ScaledZoneParams(m=-offset),
        arg_size=[Size(dim=dims(deadzone_arg))])
    return MatchResult(
        True,
        expression.prox_function(deadzone_prox, diag_arg),
        constrs)
Exemplo n.º 23
0
def constant(m, n, scalar=None, constant=None, sign=None, data=None):
    """Return an ``m`` x ``n`` CONSTANT expression.

    Args:
        m, n: dimensions of the expression.
        scalar: if given, a SCALAR ``Constant`` proto is built from it and
            replaces ``constant``. ``sign`` is derived from its value
            (POSITIVE / NEGATIVE / ZERO) and overrides any ``sign`` argument.
        constant: prebuilt ``Constant`` proto, used when ``scalar`` is None.
        sign: sign attached to the expression when ``scalar`` is not given.
        data: data dict attached to the expression. When omitted, a fresh
            empty dict is used.

    Raises:
        ValueError: if neither ``scalar`` nor ``constant`` is provided.
    """
    if data is None:
        # Build a new dict per call. The former ``data={}`` default was a
        # single dict object shared by every call that omitted ``data``.
        data = {}

    if scalar is not None:
        constant = expression_pb2.Constant(
            constant_type=expression_pb2.Constant.SCALAR, scalar=scalar)
        if scalar > 0:
            sign = Sign(sign_type=expression_pb2.Sign.POSITIVE)
        elif scalar < 0:
            sign = Sign(sign_type=expression_pb2.Sign.NEGATIVE)
        else:
            sign = Sign(sign_type=expression_pb2.Sign.ZERO)
    elif constant is None:
        raise ValueError("need either scalar or constant")

    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        data=data,
        size=Size(dim=[m, n]),
        constant=constant,
        func_curvature=Curvature(curvature_type=Curvature.CONSTANT),
        sign=sign)
Exemplo n.º 24
0
def prox_sum_quantile(expr):
    """Match ``sum(max(t1, t2))`` and build a SUM_QUANTILE prox function.

    ``get_quantile_arg`` splits each of the two MAX_ELEMENTWISE terms into a
    coefficient and an argument (presumably ``t = coeff * x`` — confirm
    against ``get_quantile_arg``). A match requires:

    * both arguments to be present and equal; and
    * the two coefficients to have opposite signs (one NEGATIVE, one
      POSITIVE).

    The POSITIVE coefficient becomes ``alpha``, and the negation of the
    NEGATIVE one becomes ``beta``. Both go through ``linear.transform_expr``
    and are embedded in the prox function's ``ScaledZoneParams``. Their
    expression data is merged and attached to the resulting expression.

    Returns ``MatchResult(False)`` when the pattern is not found.
    """
    quantile_arg = None
    is_sum_of_max = (
        expr.expression_type == Expression.SUM and
        expr.arg[0].expression_type == Expression.MAX_ELEMENTWISE and
        len(expr.arg[0].arg) == 2)
    if is_sum_of_max:
        max_terms = expr.arg[0].arg
        alpha, x = get_quantile_arg(max_terms[0])
        beta, y = get_quantile_arg(max_terms[1])
        if x is not None and y is not None and x == y:
            alpha_sign = alpha.sign.sign_type
            if (alpha_sign == Sign.NEGATIVE and
                    beta.sign.sign_type == Sign.POSITIVE):
                # Put the positive coefficient first as alpha.
                alpha, beta = beta, expression.negate(alpha)
                quantile_arg = x
            elif (alpha_sign == Sign.POSITIVE and
                  beta.sign.sign_type == Sign.NEGATIVE):
                beta = expression.negate(beta)
                quantile_arg = x

    if not quantile_arg:
        return MatchResult(False)

    alpha = linear.transform_expr(alpha)
    beta = linear.transform_expr(beta)
    data = alpha.expression_data()
    data.update(beta.expression_data())

    diag_arg, constrs = convert_diagonal(quantile_arg)
    zone_params = ProxFunction.ScaledZoneParams(
        alpha_expr=alpha.proto_with_args,
        beta_expr=beta.proto_with_args)
    quantile_prox = create_prox(
        prox_function_type=ProxFunction.SUM_QUANTILE,
        arg_size=[Size(dim=dims(quantile_arg))],
        scaled_zone_params=zone_params)
    return MatchResult(
        True,
        expression.prox_function(quantile_prox, diag_arg, data=data),
        constrs)
Exemplo n.º 25
0
def prox_function(f, *args, **kwargs):
    """Wrap prox function ``f`` applied to ``args`` in a PROX_FUNCTION expression.

    Recognized keyword arguments:
        data: data dict for the expression (default: a new empty dict).
        size: expression dimensions (default: ``(1, 1)``).

    Any other keyword arguments are ignored.
    """
    expr_data = kwargs.get("data", {})
    expr_size = kwargs.get("size", (1, 1))
    return Expression(
        expression_type=expression_pb2.Expression.PROX_FUNCTION,
        data=expr_data,
        size=Size(dim=expr_size),
        prox_function=f,
        arg=args)
Exemplo n.º 26
0
def sum_largest(x, k):
    """Return a 1 x 1 SUM_LARGEST expression of ``x`` with parameter ``k``.

    Unlike the AFFINE constructors, this one sets no ``func_curvature``.
    """
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.SUM_LARGEST,
        size=scalar_size,
        arg=[x],
        k=k)
Exemplo n.º 27
0
def zero(x):
    """Return a 1 x 1 ZERO expression of ``x`` (no ``func_curvature`` is set)."""
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.ZERO,
        size=scalar_size,
        arg=[x])
Exemplo n.º 28
0
def sum_entries(x):
    """Return the 1 x 1 AFFINE SUM of all entries of ``x``."""
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.SUM,
        size=scalar_size,
        func_curvature=AFFINE,
        arg=[x])
Exemplo n.º 29
0
def trace(X):
    """Return the 1 x 1 AFFINE TRACE expression of ``X``.

    No check is made here that ``X`` is square.
    """
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.TRACE,
        size=scalar_size,
        func_curvature=AFFINE,
        arg=[X])
Exemplo n.º 30
0
def transpose(x):
    """Return the AFFINE TRANSPOSE of ``x``, with its two dimensions swapped.

    Unpacking ``x.size.dim`` requires exactly two dimensions.
    """
    rows, cols = x.size.dim
    swapped = Size(dim=[cols, rows])
    return Expression(
        expression_type=expression_pb2.Expression.TRANSPOSE,
        size=swapped,
        func_curvature=AFFINE,
        arg=[x])