コード例 #1
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_sum(expr):
    x = only_arg(expr)
    m, n = dims(x)

    if not expr.has_axis:
        return expression.linear_map(linear_map.sum(m, n), transform_expr(x))

    if expr.axis == 0:
        return expression.linear_map(linear_map.sum_left(m, n), transform_expr(x))

    if expr.axis == 1:
        return expression.linear_map(linear_map.sum_right(m, n), transform_expr(x))

    raise TransformError("unknown axis attribute", expr)
コード例 #2
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_sum(expr):
    x = only_arg(expr)
    m, n = dims(x)

    if not expr.has_axis:
        return expression.linear_map(linear_map.sum(m, n), transform_expr(x))

    if expr.axis == 0:
        return expression.linear_map(linear_map.sum_left(m, n),
                                     transform_expr(x))

    if expr.axis == 1:
        return expression.linear_map(linear_map.sum_right(m, n),
                                     transform_expr(x))

    raise TransformError("unknown axis attribute", expr)
コード例 #3
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_index(expr):
    return expression.linear_map(
        linear_map.kronecker_product(
            linear_map.index(expr.key[1], dim(only_arg(expr), 1)), linear_map.index(expr.key[0], dim(only_arg(expr), 0))
        ),
        transform_expr(only_arg(expr)),
    )
コード例 #4
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_multiply(expr):
    if len(expr.arg) != 2:
        raise TransformError("wrong number of args", expr)

    m = dim(expr, 0)
    n = dim(expr, 1)
    if expr.arg[0].dcp_props.constant:
        A = multiply_constant(expr.arg[0], m)
        B = promote(transform_expr(expr.arg[1]), n * n)
        return expression.linear_map(linear_map.left_matrix_product(A, n), B)

    if expr.arg[1].dcp_props.constant:
        A = promote(transform_expr(expr.arg[0]), m * m)
        B = multiply_constant(expr.arg[1], n)
        return expression.linear_map(linear_map.right_matrix_product(B, m), A)

    raise TransformError("multiplying non constants", expr)
コード例 #5
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_multiply(expr):
    if len(expr.arg) != 2:
        raise TransformError("wrong number of args", expr)

    m = dim(expr, 0)
    n = dim(expr, 1)
    if expr.arg[0].dcp_props.constant:
        A = multiply_constant(expr.arg[0], m)
        B = promote(transform_expr(expr.arg[1]), n * n)
        return expression.linear_map(linear_map.left_matrix_product(A, n), B)

    if expr.arg[1].dcp_props.constant:
        A = promote(transform_expr(expr.arg[0]), m * m)
        B = multiply_constant(expr.arg[1], n)
        return expression.linear_map(linear_map.right_matrix_product(B, m), A)

    raise TransformError("multiplying non constants", expr)
コード例 #6
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_kron(expr):
    if len(expr.arg) != 2:
        raise TransformError("Wrong number of arguments", expr)

    if not expr.arg[0].dcp_props.constant:
        raise TransformError("First arg is not constant", expr)

    return expression.linear_map(
        linear_map.kronecker_product_single_arg(
            multiply_constant(expr.arg[0], 1), dim(expr.arg[1], 0),
            dim(expr.arg[1], 1)), transform_expr(expr.arg[1]))
コード例 #7
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_kron(expr):
    if len(expr.arg) != 2:
        raise TransformError("Wrong number of arguments", expr)

    if not expr.arg[0].dcp_props.constant:
        raise TransformError("First arg is not constant", expr)

    return expression.linear_map(
        linear_map.kronecker_product_single_arg(
            multiply_constant(expr.arg[0], 1), dim(expr.arg[1], 0), dim(expr.arg[1], 1)
        ),
        transform_expr(expr.arg[1]),
    )
コード例 #8
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_hstack(expr):
    m = dim(expr, 0)
    n = dim(expr, 1)
    offset = 0
    add_args = []
    for arg in expr.arg:
        ni = dim(arg, 1)
        add_args.append(
            expression.linear_map(
                linear_map.right_matrix_product(linear_map.index(slice(offset, offset + ni), n), m), transform_expr(arg)
            )
        )
        offset += ni
    return expression.add(*add_args)
コード例 #9
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_multiply_elementwise(expr):
    if len(expr.arg) != 2:
        raise TransformError("wrong number of args", expr)

    if expr.arg[0].dcp_props.constant:
        c_expr = expr.arg[0]
        x_expr = expr.arg[1]
    elif expr.arg[1].dcp_props.constant:
        c_expr = expr.arg[1]
        x_expr = expr.arg[0]
    else:
        raise TransformError("multiply non constants", expr)

    return expression.linear_map(multiply_elementwise_constant(c_expr), transform_expr(x_expr))
コード例 #10
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_hstack(expr):
    m = dim(expr, 0)
    n = dim(expr, 1)
    offset = 0
    add_args = []
    for arg in expr.arg:
        ni = dim(arg, 1)
        add_args.append(
            expression.linear_map(
                linear_map.right_matrix_product(
                    linear_map.index(slice(offset, offset + ni), n), m),
                transform_expr(arg)))
        offset += ni
    return expression.add(*add_args)
コード例 #11
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_multiply_elementwise(expr):
    if len(expr.arg) != 2:
        raise TransformError("wrong number of args", expr)

    if expr.arg[0].dcp_props.constant:
        c_expr = expr.arg[0]
        x_expr = expr.arg[1]
    elif expr.arg[1].dcp_props.constant:
        c_expr = expr.arg[1]
        x_expr = expr.arg[0]
    else:
        raise TransformError("multiply non constants", expr)

    return expression.linear_map(multiply_elementwise_constant(c_expr),
                                 transform_expr(x_expr))
コード例 #12
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_vstack(expr):
    m = dim(expr, 0)
    n = dim(expr, 1)
    offset = 0
    add_args = []
    for arg in expr.arg:
        mi = dim(arg, 0)

        add_args.append(
            expression.linear_map(
                linear_map.left_matrix_product(
                    linear_map.transpose(
                        linear_map.index(slice(offset, offset + mi), m)), n),
                transform_expr(arg)))
        offset += mi
    return expression.add(*add_args)
コード例 #13
0
ファイル: linear.py プロジェクト: JeroenSoeters/epsilon
def transform_vstack(expr):
    m = dim(expr, 0)
    n = dim(expr, 1)
    offset = 0
    add_args = []
    for arg in expr.arg:
        mi = dim(arg, 0)

        add_args.append(
            expression.linear_map(
                linear_map.left_matrix_product(
                    linear_map.transpose(
                        linear_map.index(slice(offset, offset+mi), m)),
                    n),
                transform_expr(arg)))
        offset += mi
    return expression.add(*add_args)
コード例 #14
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_transpose(expr):
    x = only_arg(expr)
    return expression.linear_map(linear_map.transpose_matrix(*dims(x)), transform_expr(x))
コード例 #15
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_trace(expr):
    return expression.linear_map(linear_map.trace(dim(only_arg(expr), 0)), transform_expr(only_arg(expr)))
コード例 #16
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_upper_tri(expr):
    return expression.linear_map(linear_map.upper_tri(dim(expr, 0)), transform_expr(only_arg(expr)))
コード例 #17
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_diag_vec(expr):
    return expression.linear_map(linear_map.diag_vec(dim(expr, 0)), transform_expr(only_arg(expr)))
コード例 #18
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def promote(expr, new_dim):
    if dim(expr) != 1 or dim(expr) == new_dim:
        return expr
    return expression.linear_map(linear_map.promote(new_dim), expr)
コード例 #19
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_diag_vec(expr):
    return expression.linear_map(linear_map.diag_vec(dim(expr, 0)),
                                 transform_expr(only_arg(expr)))
コード例 #20
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_upper_tri(expr):
    return expression.linear_map(linear_map.upper_tri(dim(expr, 0)),
                                 transform_expr(only_arg(expr)))
コード例 #21
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_trace(expr):
    return expression.linear_map(linear_map.trace(dim(only_arg(expr), 0)),
                                 transform_expr(only_arg(expr)))
コード例 #22
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_index(expr):
    return expression.linear_map(
        linear_map.kronecker_product(
            linear_map.index(expr.key[1], dim(only_arg(expr), 1)),
            linear_map.index(expr.key[0], dim(only_arg(expr), 0))),
        transform_expr(only_arg(expr)))
コード例 #23
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def promote(expr, new_dim):
    if dim(expr) != 1 or dim(expr) == new_dim:
        return expr
    return expression.linear_map(linear_map.promote(new_dim), expr)
コード例 #24
0
ファイル: linear.py プロジェクト: xflash96/epsilon
def transform_negate(expr):
    return expression.linear_map(linear_map.negate(dim(expr)), transform_expr(only_arg(expr)))
コード例 #25
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_transpose(expr):
    x = only_arg(expr)
    return expression.linear_map(linear_map.transpose_matrix(*dims(x)),
                                 transform_expr(x))
コード例 #26
0
ファイル: linear.py プロジェクト: silky/epsilon
def transform_sum(expr):
    return expression.linear_map(
        linear_map.sum(dim(only_arg(expr))),
        transform_expr(only_arg(expr)))
コード例 #27
0
ファイル: linear.py プロジェクト: mfouda/epsilon
def transform_negate(expr):
    return expression.linear_map(linear_map.negate(dim(expr)),
                                 transform_expr(only_arg(expr)))