def norm_inf_canon(expr, args):
    """Canonicalize an infinity-norm atom into a bound variable and constraints.

    A fresh variable with the atom's own shape is created and expanded to
    the shape of the argument; the argument is then boxed between the
    negated and the positive expanded bound.

    Args:
        expr: The atom being canonicalized; its ``axis`` and ``shape``
            attributes are read.
        args: The atom's arguments; only ``args[0]`` is used.

    Returns:
        A ``(bound, constraints)`` pair, where ``constraints`` holds
        ``arg <= expanded_bound`` and ``arg + expanded_bound >= 0``.
    """
    arg = args[0]
    reduce_axis = expr.axis
    bound = Variable(expr.shape)

    # Expand `bound` to the argument's shape so the comparisons below are
    # elementwise. NOTE(review): the two non-None branches presumably rely
    # on `*` acting as a matrix product (outer product with a ones vector);
    # confirm against the expression library's operator semantics.
    if reduce_axis is None:  # bound has shape (1, 1)
        bound_full = promote(bound, arg.shape)
    elif reduce_axis == 0:  # bound has shape (1, n)
        ones_col = Constant(np.ones((arg.shape[0], 1)))
        bound_row = reshape(bound, (1, arg.shape[1]))
        bound_full = ones_col * bound_row
    else:  # bound has shape (m, 1)
        bound_col = reshape(bound, (arg.shape[0], 1))
        ones_row = Constant(np.ones((1, arg.shape[1])))
        bound_full = bound_col * ones_row

    upper = arg <= bound_full
    lower = arg + bound_full >= 0
    return bound, [upper, lower]
# Example #2
0
def max_canon(expr, args):
    """Canonicalize a max atom into an epigraph variable and one constraint.

    A fresh variable with the atom's own shape is created, expanded to the
    shape of the argument, and constrained to dominate the argument.

    Args:
        expr: The atom being canonicalized; its ``shape`` and ``axis``
            attributes are read.
        args: The atom's arguments; only ``args[0]`` is used.

    Returns:
        A ``(epigraph_variable, constraints)`` pair, where ``constraints``
        is the single-element list ``[arg <= expanded_variable]``.
    """
    arg = args[0]
    out_shape = expr.shape
    reduce_axis = expr.axis
    upper = Variable(out_shape)

    # NOTE(review): the axis branches presumably rely on `*` being a matrix
    # product so that multiplying by a ones vector tiles `upper` across the
    # reduced dimension; confirm against the library's operator semantics.
    if reduce_axis is None:  # upper has shape (1, 1)
        upper_full = promote(upper, arg.shape)
    elif reduce_axis == 0:  # upper has shape (1, n)
        tall_ones = Constant(np.ones((arg.shape[0], 1)))
        upper_as_row = reshape(upper, (1, arg.shape[1]))
        upper_full = tall_ones * upper_as_row
    else:  # upper has shape (m, 1)
        upper_as_col = reshape(upper, (arg.shape[0], 1))
        wide_ones = Constant(np.ones((1, arg.shape[1])))
        upper_full = upper_as_col * wide_ones

    return upper, [arg <= upper_full]
# Example #3
0
def log_sum_exp_canon(expr, args):
    """Canonicalize a log-sum-exp atom via the exponential canonicalizer.

    Uses the equivalence
    ``log(sum(exp(x))) <= t  <=>  sum(exp(x - t)) <= 1``:
    a fresh variable ``t`` with the atom's shape is expanded to the shape
    of ``x``, ``exp(x - t)`` is canonicalized with ``exp_canon``, and the
    result is summed along ``expr.axis`` and bounded above by ones.

    Args:
        expr: The atom being canonicalized; its ``shape`` and ``axis``
            attributes are read.
        args: The atom's arguments; only ``args[0]`` is used.

    Returns:
        A ``(t, constraints)`` pair: the constraints produced by
        ``exp_canon`` followed by the ``sum(...) <= ones`` constraint.
    """
    arg = args[0]
    out_shape = expr.shape
    reduce_axis = expr.axis
    epi = Variable(out_shape)

    # NOTE(review): the axis branches presumably rely on `*` being a matrix
    # product (outer product with a ones vector); confirm against the
    # expression library's operator semantics.
    if reduce_axis is None:  # epi has shape (1, 1)
        epi_full = promote(epi, arg.shape)
    elif reduce_axis == 0:  # epi has shape (1, n)
        ones_col = Constant(np.ones((arg.shape[0], 1)))
        epi_row = reshape(epi, (1, ) + arg.shape[1:])
        epi_full = ones_col * epi_row
    else:  # epi has shape (m, 1)
        epi_col = reshape(epi, arg.shape[:-1] + (1, ))
        ones_row = Constant(np.ones((1, arg.shape[1])))
        epi_full = epi_col * ones_row

    shifted_exp = exp(arg - epi_full)
    exp_vars, constraints = exp_canon(shifted_exp, shifted_exp.args)
    # `sum` is called with `axis=`, so it must be a reducer other than the
    # builtin -- presumably the expression library's sum atom (TODO confirm).
    reduced = sum(exp_vars, axis=reduce_axis)
    constraints.append(reduced <= Constant(np.ones(out_shape)))
    return epi, constraints
# Example #4
0
def exp_canon(expr, args):
    """Canonicalize an exp atom into a variable and an ExpCone constraint.

    Args:
        expr: The atom being canonicalized; only its ``shape`` is read.
        args: The atom's arguments; ``args[0]`` is promoted to
            ``expr.shape`` before use.

    Returns:
        A ``(t, constraints)`` pair, where ``t`` is a fresh variable of
        shape ``expr.shape`` and ``constraints`` is the single-element list
        ``[ExpCone(promoted_arg, ones, t)]``.
    """
    out_shape = expr.shape
    operand = promote(args[0], out_shape)
    epi = Variable(out_shape)
    unit = Constant(np.ones(out_shape))
    # NOTE(review): with an all-ones middle argument this presumably encodes
    # exp(operand) <= epi; confirm against ExpCone's argument convention.
    return epi, [ExpCone(operand, unit, epi)]