示例#1
0
def np_cov(m, rowvar=False):
    """Return the (squeezed) covariance-style matrix of the data in ``m``.

    Args:
        m: Array-like object exposing ``.ndim`` with at most two dimensions.
            It is copied into a new array of at least 2 dimensions whose
            dtype is the promotion of ``m`` with ``float64``, so the caller's
            data is never modified.
        rowvar: When false (the default) and the promoted array has more
            than one row, the array is transposed before processing, so each
            row of the working array is treated as one variable.

    Returns:
        ``dot(c, c.T) / (n - 1)`` with singleton axes squeezed out, where
        ``c`` is the working array with each row's mean subtracted and ``n``
        is its number of columns. An empty ``(0, 0)`` array is returned when
        the working array has no rows.

    Raises:
        ValueError: If ``m`` has more than two dimensions.

    Warns:
        RuntimeWarning: If ``n - 1 <= 0``; the normaliser is then ``0.0``.

    NOTE(review): the product uses the plain transpose, not the conjugate
    transpose, so complex input is not conjugated — confirm this is intended.
    """
    if m.ndim > 2:
        raise ValueError('m has more than 2 dimensions')

    out_dtype = np.result_type(m, np.float64)
    data = np.array(m, ndmin=2, dtype=out_dtype)
    if data.shape[0] != 1 and not rowvar:
        data = data.T

    if data.shape[0] == 0:
        return np.array([], dtype=out_dtype).reshape(0, 0)

    # Degrees of freedom used to normalise the scatter matrix.
    dof = data.shape[1] - 1
    if dof <= 0:
        warnings.warn("Degrees of freedom <= 0 for slice",
                      RuntimeWarning,
                      stacklevel=2)
        dof = 0.0

    centered = data - np.mean(data, axis=1, keepdims=True)
    scatter = np.dot(centered, centered.T)
    scaled = scatter * np.true_divide(1, dof)
    return scaled.squeeze()
示例#2
0
def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Build the vector-Jacobian product closure for ``x``.

    The returned ``vjp(g)`` evaluates ``g * b * exp(x - ans)`` after ``g``
    and ``ans`` have each been passed through ``repeat_to_match_shape``
    together with the shape and dtype of ``x``, ``axis`` and ``keepdims``.
    ``ans`` is presumably the forward result for ``x`` — confirm with caller.
    """
    x_shape = np.shape(x)
    x_dtype = np.result_type(x)

    def vjp(g):
        def expand(value):
            # Only the first element of the helper's result is used.
            expanded, _ = repeat_to_match_shape(
                value, x_shape, x_dtype, axis, keepdims)
            return expanded

        g_full = expand(g)
        ans_full = expand(ans)
        return g_full * b * np.exp(x - ans_full)

    return vjp
示例#3
0
文件: misc.py 项目: j-towns/autograd
def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Return a ``vjp`` function mapping ``g`` to ``g * b * exp(x - ans)``.

    Before the product is formed, both ``g`` and ``ans`` are run through
    ``repeat_to_match_shape`` using the shape and result dtype of ``x`` plus
    the ``axis`` / ``keepdims`` arguments given here.
    """
    target_shape = anp.shape(x)
    target_dtype = anp.result_type(x)
    repeat_args = (target_shape, target_dtype, axis, keepdims)

    def vjp(g):
        # The helper returns a pair; the second element is not needed here.
        g_matched = repeat_to_match_shape(g, *repeat_args)[0]
        ans_matched = repeat_to_match_shape(ans, *repeat_args)[0]
        weights = anp.exp(x - ans_matched)
        return g_matched * b * weights

    return vjp
示例#4
0
def maybe_multiply(x, y):
    """Multiply ``x`` and ``y``, short-circuiting constant zero/one operands.

    * If either operand is a constant zero, return zeros with the broadcast
      shape and result dtype of the two operands.
    * If one operand is a constant one and the other already has the
      broadcast shape, return the other operand unchanged.
    * Otherwise defer to ``_multiply_as_einsum``.
    """
    def broadcast_shape():
        # Evaluated lazily so it only runs on the branches that need it.
        return np.broadcast(x, y).shape

    if _is_constant_zero(x) or _is_constant_zero(y):
        return np.zeros(broadcast_shape(), dtype=np.result_type(x, y))

    for unit_candidate, other in ((x, y), (y, x)):
        if _is_constant_one(unit_candidate) and np.shape(other) == broadcast_shape():
            return other

    return _multiply_as_einsum(x, y)
示例#5
0
def _distribute_einsum(formula, op, add_args, args1, args2):
    """Apply ``einsum`` to every term of ``add_args`` and combine with ``op``.

    Each term is first brought to the common broadcast shape of all terms
    (by multiplying with ones of the common result dtype) unless it already
    has exactly that shape, so implicit broadcasting between terms is not
    lost. Every term is then placed between ``args1`` and ``args2`` as the
    operands of ``np.einsum(formula, ...)``, and ``op`` is called with the
    per-term results as positional arguments.
    """
    target_shape = np.broadcast(*add_args).shape
    common_dtype = np.result_type(*add_args)

    def expand(term):
        if hasattr(term, 'shape') and term.shape == target_shape:
            return term
        return term * np.ones(target_shape, dtype=common_dtype)

    expanded_terms = [expand(term) for term in add_args]

    partial_results = []
    for term in expanded_terms:
        operands = args1 + (term, ) + args2
        partial_results.append(np.einsum(formula, *operands))
    return op(*partial_results)
示例#6
0
def _zeros_like_einsum(formula, args1, args2):
    """Return zeros shaped according to the output indices of ``formula``.

    The input formula at position ``len(args1)`` is dropped so that the
    remaining input formulas line up with the operands ``args1 + args2``.
    For each output index, the first remaining operand whose formula
    contains that index supplies the axis length; an index found in no
    remaining formula contributes no axis. The dtype is the result type of
    the remaining operands.
    """
    operands = args1 + args2
    in_specs, out_spec = split_einsum_formula(formula)
    n_before = len(args1)
    in_specs = in_specs[:n_before] + in_specs[n_before + 1:]

    dims = []
    for label in out_spec:
        located = ((idx, spec.find(label)) for idx, spec in enumerate(in_specs))
        first_hit = next(((idx, pos) for idx, pos in located if pos != -1),
                         None)
        if first_hit is not None:
            idx, pos = first_hit
            dims.append(operands[idx].shape[pos])
    return np.zeros(dims, dtype=np.result_type(*operands))
示例#7
0
def _zeros_like_einsum(formula, args1, args2):
    """Return zeros shaped according to the output indices of ``formula``.

    The input formula at position ``len(args1)`` is dropped so that the
    remaining input formulas line up with the operands ``args1 + args2``.
    Formulas are decomposed with ``decompose_formula``; the code relies on
    the result being a sequence in which ``'...'`` is a single element.

    For each output index, the first remaining operand whose formula
    contains it supplies the axis length(s):

    * a regular index contributes one axis length;
    * ``'...'`` contributes every axis the ellipsis spans in that operand.

    An output index found in no remaining formula contributes no axis. The
    dtype is the result type of the remaining operands.
    """
    args = args1 + args2
    input_formulas, output_formula = split_einsum_formula(formula)
    output_formula = decompose_formula(output_formula)
    input_formulas = (input_formulas[:len(args1)] +
                      input_formulas[len(args1) + 1:])
    input_formulas = [
        decompose_formula(input_formula) for input_formula in input_formulas
    ]

    out_shape = []
    for output_index in output_formula:
        for i, input_formula in enumerate(input_formulas):
            # ``.index`` raises ValueError rather than returning -1, so test
            # membership first and move on to the next operand when absent.
            if output_index not in input_formula:
                continue
            position = input_formula.index(output_index)
            arg_shape = args[i].shape
            if output_index == '...':
                # The ellipsis stands for every axis not named explicitly.
                n_ellipsis = args[i].ndim - len(input_formula) + 1
                out_shape.extend(arg_shape[position:position + n_ellipsis])
            elif '...' in input_formula[:position]:
                # Indices after the ellipsis are aligned with the trailing
                # axes, so count from the end of the shape.
                out_shape.append(arg_shape[position - len(input_formula)])
            else:
                out_shape.append(arg_shape[position])
            # Only the first operand containing the index determines the
            # size; without this the ellipsis axes were appended repeatedly.
            break
    return np.zeros(out_shape, dtype=np.result_type(*args))