示例#1
0
def multiply_elementwise_constant(expr):
    """Build the linear map for elementwise multiplication by a constant leaf.

    Args:
        expr: expression node whose ``expression_type`` must be
            ``Expression.CONSTANT``.

    Returns:
        ``linear_map.diagonal_matrix(expr.constant)`` for a dense-matrix
        constant, or ``linear_map.scalar(expr.constant.scalar, 1)`` for a
        scalar constant.

    Raises:
        TransformError: if ``expr`` is not a constant leaf, or if its
            constant type is neither dense matrix nor scalar (for instance
            a sparse matrix).
    """
    # TODO(mwytock): Handle this case -- presumably the non-constant
    # (non-leaf) operand, which is currently rejected at the bottom.
    if expr.expression_type == Expression.CONSTANT:
        constant = expr.constant
        if constant.constant_type == Constant.DENSE_MATRIX:
            return linear_map.diagonal_matrix(constant)
        if constant.constant_type == Constant.SCALAR:
            return linear_map.scalar(constant.scalar, 1)
        raise TransformError("unknown constant type", expr)
    raise TransformError("multiply constant is not leaf", expr)
示例#2
0
文件: linear.py 项目: mfouda/epsilon
def multiply_elementwise_constant(expr):
    """Return the linear map that multiplies elementwise by the constant ``expr``.

    Args:
        expr: expression node; only ``Expression.CONSTANT`` leaves are
            accepted.

    Returns:
        For a dense-matrix constant, ``linear_map.diagonal_matrix`` built
        from the constant and ``expr.data``. For a scalar constant,
        ``linear_map.scalar`` with the scalar value and a second argument
        of 1.

    Raises:
        TransformError: when ``expr`` is not a constant leaf, or when the
            constant type is anything other than dense matrix or scalar.
    """
    # TODO(mwytock): Handle this case -- presumably non-constant operands.
    if expr.expression_type != Expression.CONSTANT:
        raise TransformError("multiply constant is not leaf", expr)

    constant = expr.constant
    kind = constant.constant_type
    if kind == Constant.DENSE_MATRIX:
        return linear_map.diagonal_matrix(constant, expr.data)
    elif kind == Constant.SCALAR:
        return linear_map.scalar(constant.scalar, 1)
    else:
        raise TransformError("unknown constant type", expr)
示例#3
0
def multiply_constant(expr, n):
    """Build the linear map for multiplication by the constant ``expr``.

    Args:
        expr: expression node, either an ``Expression.CONSTANT`` leaf or an
            ``Expression.TRANSPOSE`` wrapping one (transposes may nest).
        n: forwarded to ``linear_map.scalar`` for scalar constants
            (presumably the dimension of the scalar map -- TODO confirm).

    Returns:
        A ``linear_map.scalar``, ``dense_matrix`` or ``sparse_matrix`` map,
        depending on the constant type. For a transpose, the map of its
        single argument wrapped in ``linear_map.transpose``.

    Raises:
        TransformError: for any other expression type or constant type.
    """
    kind = expr.expression_type

    if kind == Expression.TRANSPOSE:
        # Build the map for the operand, then transpose it.
        inner = multiply_constant(only_arg(expr), n)
        return linear_map.transpose(inner)

    if kind == Expression.CONSTANT:
        value = expr.constant
        value_type = value.constant_type
        if value_type == Constant.SCALAR:
            return linear_map.scalar(value.scalar, n)
        if value_type == Constant.DENSE_MATRIX:
            return linear_map.dense_matrix(value, expr.data)
        if value_type == Constant.SPARSE_MATRIX:
            return linear_map.sparse_matrix(value, expr.data)

    raise TransformError("unknown constant type", expr)
示例#4
0
文件: linear.py 项目: mfouda/epsilon
def multiply_constant(expr, n):
    """Return the linear map that multiplies by the constant expression ``expr``.

    Args:
        expr: a constant leaf (``Expression.CONSTANT``) or a transpose
            (``Expression.TRANSPOSE``) of such an expression.
        n: second argument passed to ``linear_map.scalar`` for scalar
            constants (presumably the size of the operand -- verify).

    Returns:
        The matching ``linear_map`` object: a scalar, dense or sparse
        matrix map, or a ``linear_map.transpose`` of the operand's map.

    Raises:
        TransformError: if the expression is neither a constant nor a
            transpose, or if the constant type is not scalar, dense matrix
            or sparse matrix.
    """
    if expr.expression_type == Expression.TRANSPOSE:
        return linear_map.transpose(multiply_constant(only_arg(expr), n))

    if expr.expression_type != Expression.CONSTANT:
        raise TransformError("unknown constant type", expr)

    const = expr.constant
    const_type = const.constant_type
    if const_type == Constant.SCALAR:
        return linear_map.scalar(const.scalar, n)
    if const_type == Constant.DENSE_MATRIX:
        return linear_map.dense_matrix(const, expr.data)
    if const_type == Constant.SPARSE_MATRIX:
        return linear_map.sparse_matrix(const, expr.data)
    raise TransformError("unknown constant type", expr)