def prox_second_order_cone(expr):
    """Match a second-order-cone form and map it to the SECOND_ORDER_CONE prox.

    Two forms are recognised:
      * an INDICATOR expression whose cone is Cone.SECOND_ORDER, whose own
        arguments are used directly;
      * an epigraph (as split by ``get_epigraph``) whose function part is a
        2-norm (NORM_P with p == 2), giving arguments ``[t, x]`` where ``x`` is
        reshaped to a single row.

    Returns:
        MatchResult(False) when neither form applies, otherwise a MatchResult
        carrying the prox expression and the constraints produced by
        ``convert_scalar`` for both arguments.
    """
    is_soc_indicator = (expr.expression_type == Expression.INDICATOR and
                        expr.cone.cone_type == Cone.SECOND_ORDER)
    if is_soc_indicator:
        cone_args = expr.arg
    else:
        cone_args = []
        f_expr, t_expr = get_epigraph(expr)
        is_two_norm = (f_expr and
                       f_expr.expression_type == Expression.NORM_P and
                       f_expr.p == 2)
        if is_two_norm:
            norm_input = f_expr.arg[0]
            # The prox expects the norm argument laid out as a 1 x k row.
            cone_args = [t_expr,
                         expression.reshape(norm_input, 1, dim(norm_input))]

    if not cone_args:
        return MatchResult(False)

    scalar_arg0, constrs0 = convert_scalar(cone_args[0])
    scalar_arg1, constrs1 = convert_scalar(cone_args[1])
    soc_prox = create_prox(
        prox_function_type=ProxFunction.SECOND_ORDER_CONE,
        arg_size=[Size(dim=dims(cone_args[0])),
                  Size(dim=dims(cone_args[1]))])
    prox_expr = expression.prox_function(soc_prox, scalar_arg0, scalar_arg1)
    return MatchResult(True, prox_expr, constrs0 + constrs1)
def ones(*dims):
    """Return a CONSTANT expression of shape ``dims`` filled with ones.

    The dense array is stored via ``_constant.store`` into a fresh data dict,
    which is attached to the returned expression. Note that the ``*dims``
    parameter shadows the module-level ``dims`` helper inside this function.
    """
    payload = {}
    ones_constant = _constant.store(np.ones(dims), payload)
    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        size=Size(dim=dims),
        data=payload,
        constant=ones_constant,
        func_curvature=CONSTANT)
def neg_log_det_epigraph(expr):
    """Match a two-term argument list containing a LOG_DET as a NEG_LOG_DET epigraph.

    ``expr.arg[0]`` must have exactly two arguments, at least one of which
    is a LOG_DET expression; the first such term is used. The other term is
    passed through unchanged as the final prox argument, while the LOG_DET's
    own argument goes through ``convert_scalar``.

    Returns:
        MatchResult(False) if the pattern does not match, otherwise a
        MatchResult with an epigraph NEG_LOG_DET prox function (alpha=1) and
        the constraints from ``convert_scalar``.
    """
    terms = expr.arg[0].arg
    if len(terms) != 2:
        return MatchResult(False)

    log_det_pos = next(
        (pos for pos in range(2)
         if terms[pos].expression_type == Expression.LOG_DET),
        None)
    if log_det_pos is None:
        return MatchResult(False)

    log_det_term = terms[log_det_pos]
    other_term = terms[1 - log_det_pos]
    matrix_arg = log_det_term.arg[0]

    scalar_arg, constrs = convert_scalar(matrix_arg)
    epi_function = create_prox(
        alpha=1,
        prox_function_type=ProxFunction.NEG_LOG_DET,
        arg_size=[Size(dim=dims(matrix_arg))])
    epi_function.epigraph = True
    prox_expr = expression.prox_function(epi_function, scalar_arg, other_term)
    return MatchResult(True, prox_expr, constrs)
def epigraph(expr):
    """Match an epigraph-form expression against the base prox rules.

    ``get_epigraph`` splits ``expr`` into ``(f_expr, t_expr)``. ``f_expr`` is
    falsy when ``expr`` is not in epigraph form (presumably ``None`` — confirm
    against ``get_epigraph``).

    Returns:
        * MatchResult(False) if ``expr`` is not in epigraph form.
        * MatchResult(True, prox_expr, constrs) for the first rule in
          ``BASE_RULES`` that matches ``f_expr``. The rule's prox function is
          flagged as an epigraph, gains an extra arg_size entry for
          ``t_expr``, and receives the linearised ``t_expr`` as its last
          argument.
        * MatchResult(True, None, constrs) when no rule matches. ``f_expr``
          is passed to ``conic.transform_expr`` and the constraint
          ``leq_constraint(obj, t_expr)`` is prepended to the conic
          constraints.
    """
    f_expr, t_expr = get_epigraph(expr)
    if f_expr:
        for rule in BASE_RULES:
            result = rule(f_expr)
            if result.match:
                # Mutate the prox function the rule just built: mark it as an
                # epigraph and record the size of the extra t argument.
                epi_function = result.prox_expr.prox_function
                epi_function.epigraph = True
                epi_function.arg_size.add().CopyFrom(Size(dim=dims(t_expr)))
                linear_t_expr = linear.transform_expr(t_expr)
                if linear_t_expr.affine_props.scalar:
                    constrs = []
                else:
                    # Otherwise t is rewritten by epi_transform(..., "scalar"),
                    # which also yields extra constraints (presumably to keep
                    # the replacement equivalent — confirm in epi_transform).
                    linear_t_expr, constrs = epi_transform(
                        linear_t_expr, "scalar")
                return MatchResult(
                    True,
                    expression.prox_function(
                        epi_function,
                        *(result.prox_expr.arg + [linear_t_expr])),
                    result.raw_exprs + constrs)
        # No epigraph transform found, do conic transformation
        obj, constrs = conic.transform_expr(f_expr)
        return MatchResult(
            True,
            None,
            [expression.leq_constraint(obj, t_expr)] + constrs)
    # Not in epigraph form
    return MatchResult(False)
def variable(m, n, variable_id):
    """Create an (m, n) VARIABLE expression bound to ``variable_id``.

    The curvature is affine and is flagged as elementwise and scalar-multiple.
    """
    var_proto = expression_pb2.Variable(variable_id=variable_id)
    affine_curvature = Curvature(
        curvature_type=Curvature.AFFINE,
        elementwise=True,
        scalar_multiple=True)
    return Expression(
        expression_type=expression_pb2.Expression.VARIABLE,
        size=Size(dim=[m, n]),
        variable=var_proto,
        func_curvature=affine_curvature)
def vstack(*args):
    """Vertically stack ``args`` into a single affine VSTACK expression.

    The result size is obtained by folding the operands' dims with
    ``stack_dims(a, b, 0)``.

    Raises:
        ValueError: if called with no operands. This matches ``add`` and
            ``_multiply``, and replaces the opaque TypeError that ``reduce``
            raises on an empty sequence.
    """
    if not args:
        raise ValueError("stacking null args")
    return Expression(
        expression_type=expression_pb2.Expression.VSTACK,
        func_curvature=AFFINE,
        size=Size(
            dim=reduce(lambda a, b: stack_dims(a, b, 0),
                       (dims(a) for a in args))),
        arg=args)
def diag_vec(x):
    """Build an n x n DIAG_VEC expression from an n x 1 column vector ``x``.

    Raises:
        ExpressionError: if ``x`` has more than one column.
    """
    is_column = dim(x, 1) == 1
    if not is_column:
        raise ExpressionError("diag_vec on non vector")
    length = dim(x, 0)
    return Expression(
        expression_type=expression_pb2.Expression.DIAG_VEC,
        size=Size(dim=[length, length]),
        func_curvature=AFFINE,
        arg=[x])
def scalar_constant(scalar, size=None):
    """Create a SCALAR constant expression holding ``scalar``.

    ``size`` gives the expression dims and defaults to (1, 1). No sign is
    attached to the result.
    """
    shape = (1, 1) if size is None else size
    scalar_proto = expression_pb2.Constant(
        constant_type=expression_pb2.Constant.SCALAR,
        scalar=scalar)
    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        size=Size(dim=shape),
        constant=scalar_proto,
        func_curvature=CONSTANT)
def linear_map(A, x):
    """Apply linear operator ``A`` to the column vector ``x``.

    The result is an (A.m, 1) LINEAR_MAP expression that carries ``A``'s
    proto and data.

    Raises:
        ExpressionError: if ``x`` is not a single column, or if ``A.n`` does
            not equal the number of entries in ``x``.
    """
    if not dim(x, 1) == 1:
        raise ExpressionError("applying linear map to non vector", x)
    size_matches = A.n == dim(x)
    if not size_matches:
        raise ExpressionError("linear map has wrong size: %s" % A, x)
    out_size = Size(dim=[A.m, 1])
    return Expression(
        expression_type=expression_pb2.Expression.LINEAR_MAP,
        size=out_size,
        func_curvature=AFFINE,
        linear_map=A.proto,
        data=A.data,
        arg=[x])
def prox_lambda_max(expr):
    """Map a LAMBDA_MAX expression to the LAMBDA_MAX prox function.

    Returns MatchResult(False) for any other expression type.
    """
    if expr.expression_type != Expression.LAMBDA_MAX:
        return MatchResult(False)
    matrix_arg = expr.arg[0]
    scalar_arg, constrs = convert_scalar(matrix_arg)
    lambda_max = create_prox(
        prox_function_type=ProxFunction.LAMBDA_MAX,
        arg_size=[Size(dim=dims(matrix_arg))])
    prox_expr = expression.prox_function(lambda_max, scalar_arg)
    return MatchResult(True, prox_expr, constrs)
def parameter(m, n, parameter_id, constant_type, sign):
    """Create an (m, n) CONSTANT expression referring to ``parameter_id``.

    The constant proto records ``constant_type``, the parameter id and the
    shape; ``sign`` is attached to the expression unchanged.
    """
    # NOTE(mwytock): we assume all parameters are dense matrices for purposes of
    # symbolic transformation.
    param_constant = expression_pb2.Constant(
        constant_type=constant_type,
        parameter_id=parameter_id,
        m=m,
        n=n)
    shape = Size(dim=[m, n])
    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        size=shape,
        func_curvature=CONSTANT,
        constant=param_constant,
        sign=sign)
def add(*args):
    """Build an affine ADD expression over ``args``.

    The result size is obtained by folding operand dims with
    ``elementwise_dims``; every operand is marked INCREASING.

    Raises:
        ValueError: if no operands are given.
    """
    if len(args) == 0:
        raise ValueError("adding null args")
    operand_dims = [dims(a) for a in args]
    return Expression(
        expression_type=expression_pb2.Expression.ADD,
        arg=args,
        size=Size(dim=reduce(elementwise_dims, operand_dims)),
        arg_monotonicity=[INCREASING] * len(args),
        func_curvature=AFFINE)
def _multiply(args, elemwise=False):
    """Build a MULTIPLY (or MULTIPLY_ELEMENTWISE) expression over ``args``.

    With ``elemwise`` the operand dims are folded with ``elementwise_dims``;
    otherwise they are folded with ``matrix_multiply_dims``.

    Raises:
        ValueError: if ``args`` is empty.
    """
    if not args:
        raise ValueError("multiplying null args")
    if elemwise:
        expr_type = expression_pb2.Expression.MULTIPLY_ELEMENTWISE
        combine_dims = elementwise_dims
    else:
        expr_type = expression_pb2.Expression.MULTIPLY
        combine_dims = matrix_multiply_dims
    result_dims = reduce(combine_dims, [dims(a) for a in args])
    return Expression(
        expression_type=expr_type,
        arg=args,
        size=Size(dim=result_dims),
        func_curvature=AFFINE)
def prox_norm_nuclear(expr):
    """Map a NORM_NUC expression to the NORM_NUCLEAR prox function.

    Returns MatchResult(False) for any other expression type.
    """
    if expr.expression_type != Expression.NORM_NUC:
        return MatchResult(False)
    matrix_arg = expr.arg[0]
    scalar_arg, constrs = convert_scalar(matrix_arg)
    nuclear = create_prox(
        prox_function_type=ProxFunction.NORM_NUCLEAR,
        arg_size=[Size(dim=dims(matrix_arg))])
    prox_expr = expression.prox_function(nuclear, scalar_arg)
    return MatchResult(True, prox_expr, constrs)
def prox_semidefinite(expr):
    """Map an INDICATOR over Cone.SEMIDEFINITE to the SEMIDEFINITE prox.

    Returns MatchResult(False) for any other expression.
    """
    is_psd_indicator = (expr.expression_type == Expression.INDICATOR and
                        expr.cone.cone_type == Cone.SEMIDEFINITE)
    if not is_psd_indicator:
        return MatchResult(False)
    matrix_arg = expr.arg[0]
    scalar_arg, constrs = convert_scalar(matrix_arg)
    psd = create_prox(
        prox_function_type=ProxFunction.SEMIDEFINITE,
        arg_size=[Size(dim=dims(matrix_arg))])
    prox_expr = expression.prox_function(psd, scalar_arg)
    return MatchResult(True, prox_expr, constrs)
def prox_log_det(expr):
    """Map a LOG_DET expression to the NEG_LOG_DET prox function with alpha=-1.

    Returns MatchResult(False) for any other expression type.
    """
    if expr.expression_type != Expression.LOG_DET:
        return MatchResult(False)
    matrix_arg = expr.arg[0]
    scalar_arg, constrs = convert_scalar(matrix_arg)
    # LOG_DET is expressed via the NEG_LOG_DET prox with alpha=-1
    # (presumably alpha scales the function — confirm in create_prox).
    neg_log_det = create_prox(
        alpha=-1,
        prox_function_type=ProxFunction.NEG_LOG_DET,
        arg_size=[Size(dim=dims(matrix_arg))])
    prox_expr = expression.prox_function(neg_log_det, scalar_arg)
    return MatchResult(True, prox_expr, constrs)
def prox_norm_1(expr):
    """Map a NORM_P expression with p == 1 to the NORM_1 prox function.

    The argument is passed through ``convert_diagonal``. Returns
    MatchResult(False) for any other expression.
    """
    is_l1_norm = expr.expression_type == Expression.NORM_P and expr.p == 1
    if not is_l1_norm:
        return MatchResult(False)
    norm_input = expr.arg[0]
    diagonal_arg, constrs = convert_diagonal(norm_input)
    norm_1 = create_prox(
        prox_function_type=ProxFunction.NORM_1,
        arg_size=[Size(dim=dims(norm_input))])
    prox_expr = expression.prox_function(norm_1, diagonal_arg)
    return MatchResult(True, prox_expr, constrs)
def prox_sum_hinge(expr):
    """Map the pattern recognised by ``get_hinge_arg`` to the SUM_HINGE prox.

    ``get_hinge_arg`` returns a falsy value when ``expr`` does not match, in
    which case MatchResult(False) is returned. The prox inherits ``expr``'s
    axis settings, and the result takes ``expr``'s dims.
    """
    hinge_input = get_hinge_arg(expr)
    if not hinge_input:
        return MatchResult(False)
    diagonal_arg, constrs = convert_diagonal(hinge_input)
    sum_hinge = create_prox(
        prox_function_type=ProxFunction.SUM_HINGE,
        arg_size=[Size(dim=dims(hinge_input))],
        has_axis=expr.has_axis,
        axis=expr.axis)
    prox_expr = expression.prox_function(
        sum_hinge, diagonal_arg, size=dims(expr))
    return MatchResult(True, prox_expr, constrs)
def prox_log_sum_exp(expr):
    """Map a LOG_SUM_EXP expression to the LOG_SUM_EXP prox function.

    The prox inherits ``expr``'s axis settings, and the result takes
    ``expr``'s dims. Returns MatchResult(False) for any other expression type.
    """
    if expr.expression_type != Expression.LOG_SUM_EXP:
        return MatchResult(False)
    lse_input = expr.arg[0]
    scalar_arg, constrs = convert_scalar(lse_input)
    log_sum_exp = create_prox(
        prox_function_type=ProxFunction.LOG_SUM_EXP,
        arg_size=[Size(dim=dims(lse_input))],
        has_axis=expr.has_axis,
        axis=expr.axis)
    prox_expr = expression.prox_function(
        log_sum_exp, scalar_arg, size=dims(expr))
    return MatchResult(True, prox_expr, constrs)
def reshape(arg, m, n):
    """Reshape ``arg`` to m x n, avoiding redundant RESHAPE nodes.

    Returns ``arg`` unchanged if it is already m x n. If ``arg`` is itself a
    RESHAPE whose input is m x n, that input is returned directly.

    Raises:
        ExpressionError: if m * n differs from the number of entries in ``arg``.
    """
    already_shaped = dim(arg, 0) == m and dim(arg, 1) == n
    if already_shaped:
        return arg
    if dim(arg) != m * n:
        raise ExpressionError("cant reshape to %d x %d" % (m, n), arg)
    # A reshape of a reshape that lands back on the inner shape is a no-op.
    if arg.expression_type == expression_pb2.Expression.RESHAPE:
        inner = arg.arg[0]
        if dim(inner, 0) == m and dim(inner, 1) == n:
            return inner
    return Expression(
        expression_type=expression_pb2.Expression.RESHAPE,
        arg=[arg],
        size=Size(dim=[m, n]),
        func_curvature=AFFINE,
        sign=arg.sign)
def index(x, start_i, stop_i, start_j=None, stop_j=None):
    """Select rows [start_i, stop_i) and columns [start_j, stop_j) of ``x``.

    The column bounds default independently: ``start_j`` to 0 and ``stop_j``
    to the column count of ``x``. So ``index(x, i0, i1)`` selects whole rows,
    and supplying only one column bound is also accepted; previously that
    case crashed with a TypeError on ``stop_j - start_j``.

    Returns ``x`` itself when the requested slice has the same dimensions as
    ``x``. Only the extents are compared, not the start offsets. Otherwise
    returns an affine INDEX expression with unit-step slices.
    """
    if start_j is None:
        start_j = 0
    if stop_j is None:
        stop_j = x.size.dim[1]
    rows = stop_i - start_i
    cols = stop_j - start_j
    if dim(x, 0) == rows and dim(x, 1) == cols:
        return x
    return Expression(
        expression_type=expression_pb2.Expression.INDEX,
        size=Size(dim=[rows, cols]),
        func_curvature=AFFINE,
        key=[
            expression_pb2.Slice(start=start_i, stop=stop_i, step=1),
            expression_pb2.Slice(start=start_j, stop=stop_j, step=1),
        ],
        arg=[x])
def prox_sum_deadzone(expr):
    """Match a hinge of ``abs(x) + m`` (constant m <= 0) as SUM_DEADZONE.

    The hinge argument from ``get_hinge_arg`` must be a two-term ADD whose
    first term is ABS. The second term is read with ``get_scalar_constant``,
    and the prox is built with ScaledZoneParams(m=-m).

    NOTE(review): behaviour when the second term is not a scalar constant
    depends on ``get_scalar_constant`` (it may raise or return a
    non-comparable value) — confirm.
    """
    hinge_arg = get_hinge_arg(expr)
    deadzone_input = None
    looks_like_abs_plus_const = (
        hinge_arg and
        hinge_arg.expression_type == Expression.ADD and
        len(hinge_arg.arg) == 2 and
        hinge_arg.arg[0].expression_type == Expression.ABS)
    if looks_like_abs_plus_const:
        offset = get_scalar_constant(hinge_arg.arg[1])
        if offset <= 0:
            deadzone_input = hinge_arg.arg[0].arg[0]

    if not deadzone_input:
        return MatchResult(False)

    diagonal_arg, constrs = convert_diagonal(deadzone_input)
    deadzone = create_prox(
        prox_function_type=ProxFunction.SUM_DEADZONE,
        scaled_zone_params=ProxFunction.ScaledZoneParams(m=-offset),
        arg_size=[Size(dim=dims(deadzone_input))])
    prox_expr = expression.prox_function(deadzone, diagonal_arg)
    return MatchResult(True, prox_expr, constrs)
def constant(m, n, scalar=None, constant=None, sign=None, data=None):
    """Create an (m, n) CONSTANT expression.

    Args:
        m, n: dimensions of the expression.
        scalar: if given, takes precedence over ``constant``. A SCALAR
            constant proto is built from it, and the sign is derived from its
            value (POSITIVE / NEGATIVE / ZERO), overriding ``sign``.
        constant: a pre-built constant proto, used when ``scalar`` is None.
        sign: sign attached when ``constant`` is used.
        data: auxiliary data dict attached to the expression. Defaults to a
            new empty dict on every call, so returned expressions never share
            one mutable default dict.

    Raises:
        ValueError: if neither ``scalar`` nor ``constant`` is given.
    """
    if data is None:
        data = {}
    if scalar is not None:
        constant = expression_pb2.Constant(
            constant_type=expression_pb2.Constant.SCALAR,
            scalar=scalar)
        if scalar > 0:
            sign = Sign(sign_type=expression_pb2.Sign.POSITIVE)
        elif scalar < 0:
            sign = Sign(sign_type=expression_pb2.Sign.NEGATIVE)
        else:
            sign = Sign(sign_type=expression_pb2.Sign.ZERO)
    elif constant is None:
        raise ValueError("need either scalar or constant")

    return Expression(
        expression_type=expression_pb2.Expression.CONSTANT,
        data=data,
        size=Size(dim=[m, n]),
        constant=constant,
        func_curvature=Curvature(curvature_type=Curvature.CONSTANT),
        sign=sign)
def prox_sum_quantile(expr):
    """Match a SUM of a two-term MAX_ELEMENTWISE as the SUM_QUANTILE prox.

    ``expr`` must be a SUM whose first argument is a MAX_ELEMENTWISE with
    exactly two terms. ``get_quantile_arg`` is applied to each term and is
    assumed to return ``(coefficient_expr, argument_or_None)`` — TODO confirm.
    Both terms must yield the same non-None argument, and their coefficients
    must have opposite signs (one NEGATIVE, one POSITIVE); otherwise
    MatchResult(False) is returned.

    After normalisation ``alpha`` is the positively signed coefficient and
    ``beta`` is the negation of the negatively signed one. Both are passed
    through ``linear.transform_expr`` and embedded in ScaledZoneParams; their
    expression data is merged into the prox expression's data dict.
    """
    arg = None
    if (expr.expression_type == Expression.SUM and
        expr.arg[0].expression_type == Expression.MAX_ELEMENTWISE and
        len(expr.arg[0].arg) == 2):
        alpha, x = get_quantile_arg(expr.arg[0].arg[0])
        beta, y = get_quantile_arg(expr.arg[0].arg[1])
        if (x is not None and y is not None and x == y):
            if (alpha.sign.sign_type == Sign.NEGATIVE and
                beta.sign.sign_type == Sign.POSITIVE):
                # Swap so alpha is the positive coefficient; the old
                # (negative) alpha is negated into beta.
                alpha, beta = beta, expression.negate(alpha)
                arg = x
            elif (alpha.sign.sign_type == Sign.POSITIVE and
                  beta.sign.sign_type == Sign.NEGATIVE):
                # Already ordered; just flip beta's sign.
                beta = expression.negate(beta)
                arg = x

    # arg stays None unless both structural and sign checks passed above.
    if not arg:
        return MatchResult(False)

    alpha = linear.transform_expr(alpha)
    beta = linear.transform_expr(beta)
    # Merge the data referenced by both coefficient expressions so the prox
    # expression carries everything its embedded protos need.
    data = alpha.expression_data()
    data.update(beta.expression_data())
    diagonal_arg, constrs = convert_diagonal(arg)
    return MatchResult(
        True,
        expression.prox_function(
            create_prox(
                prox_function_type=ProxFunction.SUM_QUANTILE,
                arg_size=[Size(dim=dims(arg))],
                scaled_zone_params=ProxFunction.ScaledZoneParams(
                    alpha_expr=alpha.proto_with_args,
                    beta_expr=beta.proto_with_args)),
            diagonal_arg,
            data=data),
        constrs)
def prox_function(f, *args, **kwargs):
    """Wrap prox function proto ``f`` applied to ``args`` in an expression.

    Recognised keyword arguments:
        data: auxiliary data dict (defaults to a new empty dict).
        size: result dims (defaults to (1, 1)).
    Any other keyword arguments are ignored.
    """
    result_data = kwargs.get("data", {})
    result_size = kwargs.get("size", (1, 1))
    return Expression(
        expression_type=expression_pb2.Expression.PROX_FUNCTION,
        data=result_data,
        size=Size(dim=result_size),
        prox_function=f,
        arg=args)
def sum_largest(x, k):
    """Build a 1 x 1 SUM_LARGEST expression over ``x`` with parameter ``k``.

    No curvature is set here.
    """
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.SUM_LARGEST,
        size=scalar_size,
        arg=[x],
        k=k)
def zero(x):
    """Build a 1 x 1 ZERO expression wrapping ``x``.

    No curvature is set here.
    """
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.ZERO,
        size=scalar_size,
        arg=[x])
def sum_entries(x):
    """Build an affine 1 x 1 SUM expression over the entries of ``x``."""
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.SUM,
        size=scalar_size,
        func_curvature=AFFINE,
        arg=[x])
def trace(X):
    """Build an affine 1 x 1 TRACE expression of ``X``.

    The shape of ``X`` is not validated here.
    """
    scalar_size = Size(dim=[1, 1])
    return Expression(
        expression_type=expression_pb2.Expression.TRACE,
        size=scalar_size,
        func_curvature=AFFINE,
        arg=[X])
def transpose(x):
    """Build an affine TRANSPOSE of ``x`` with its two dimensions swapped.

    ``x.size.dim`` must hold exactly two entries (it is unpacked as rows, cols).
    """
    rows, cols = x.size.dim
    flipped = Size(dim=[cols, rows])
    return Expression(
        expression_type=expression_pb2.Expression.TRANSPOSE,
        size=flipped,
        func_curvature=AFFINE,
        arg=[x])