Example #1
0
 def __new__(self, name, base, dic):
     """Create a container class and hook it into autograd.

     ``self`` is the metaclass being invoked, but the class object is always
     built with ``container_mateclass`` as its metaclass. The new class then:

     * gets ``register`` called for ``_np.ndarray`` and a fixed list of real
       and complex scalar types;
     * receives every function named in ``nondiff_methods + diff_methods``
       from ``anp`` as an attribute, plus ``flatten`` (``anp.ravel``) and
       ``reshape`` (``wrapped_reshape``);
     * has reverse- and forward-mode rules registered for its
       ``__getitem__`` and a 'same' forward rule for ``untake``.
     """
     # NOTE(review): hard-coding container_mateclass (instead of ``self``)
     # means subclasses of this metaclass are ignored; also the rules below
     # are re-registered every time a class is created. Confirm intended.
     new_cls = type.__new__(container_mateclass, name, base, dic)

     registered_types = (
         _np.ndarray,
         float, _np.float64, _np.float32, _np.float16, complex,
         _np.complex64, _np.complex128,
     )
     for registered_type in registered_types:
         new_cls.register(registered_type)

     anp_namespace = anp.__dict__
     for attr_name in nondiff_methods + diff_methods:
         setattr(new_cls, attr_name, anp_namespace[attr_name])
     new_cls.flatten = anp_namespace['ravel']

     def _getitem_vjp_maker(ans, A, idx):
         # Scatter the cotangent back into a zero container shaped like A.
         return lambda g: untake(g, idx, vspace(A))

     defvjp(func(new_cls.__getitem__), _getitem_vjp_maker)
     defjvp(func(new_cls.__getitem__), 'same')
     defjvp(untake, 'same')
     new_cls.reshape = wrapped_reshape
     return new_cls
Example #2
0
        ans_repeated, _ = repeat_to_match_shape(ans, shape, dtype, axis, keepdims)
        return g_repeated * b * np.exp(x - ans_repeated)
    return vjp

# Register make_grad_logsumexp as the reverse-mode (VJP) rule for logsumexp.
defvjp(logsumexp, make_grad_logsumexp)

def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """Forward-mode derivative (JVP) of ``logsumexp``.

    The derivative of ``log(sum(b * exp(x)))`` with respect to ``x`` is the
    weight ``b * exp(x - ans)``. The JVP in direction ``g`` is therefore the
    weighted tangent summed over the reduced axes.

    Args:
        g: tangent of ``x``.
        ans: primal output ``logsumexp(x, axis, b, keepdims)``.
        x: primal input.
        axis: ``None``, a single integer, or a tuple of integers naming the
            reduced axes. Negative values are allowed.
        b: weights applied inside the exponential sum.
        keepdims: whether ``ans`` kept the reduced axes as size-1 dims.

    Returns:
        The tangent of ``ans`` (same shape as ``ans``).
    """
    if not keepdims and axis is not None:
        # Re-insert the reduced axes so ``ans`` broadcasts against ``x``.
        # Normalise to non-negative positions first: inserting in ascending
        # order of *negative* indices puts the new dims in the wrong place
        # (e.g. axis=(-2, -1) on a 3-D input).
        axes = axis if isinstance(axis, tuple) else (axis,)
        ndim = np.ndim(x)
        for ax in sorted(a % ndim for a in axes):
            ans = np.expand_dims(ans, ax)
    return np.sum(g * b * np.exp(x - ans), axis=axis, keepdims=keepdims)

# Register fwd_grad_logsumexp as the forward-mode (JVP) rule for logsumexp.
defjvp(logsumexp, fwd_grad_logsumexp)


## ======================== Associated Legendre function ========================
#### LEGENDRE FUNCTION IMPLEMENTATION

# Wrap scipy.special.lpmv (called as lpmv(m, n, x)) as an autograd primitive:
# autograd treats it as a black box and differentiates it only through rules
# registered for `legendre` explicitly.
legendre = primitive(scipy.special.lpmv)

def vjp_legendre(ans, m, n, x):
    '''
    TODO: implement abs(x==1) cases
    '''
    def vjp( g ):
        return g * ( (n+1-m)*legendre(m,n+1,x) - (n+1)*x*ans ) / (x*x - 1)
Example #3
0
def vjp_maker_spdot(b, A, x):
    """ Build the vector-Jacobian product of b = spdot(A, x) w.r.t. x.

    The product is linear in x with Jacobian A, so a cotangent ``v`` is
    pulled back as spdot(A.T, v). ``b`` and ``x`` are accepted to match the
    VJP-maker signature but are not used.
    """
    def pull_back(v):
        return spdot(A.T, v)

    return pull_back


def jvp_spdot(g, b, A, x):
    """ Push the tangent ``g`` of x forward through b = spdot(A, x).

    x -> spdot(A, x) is linear, so its JVP is the same product applied to
    the tangent. ``b`` and ``x`` are unused.
    """
    tangent_b = spdot(A, g)
    return tangent_b


# Derivative rules for spdot(A, x): None for argument 0 (A), so only
# argument 1 (x) gets reverse- and forward-mode rules.
defvjp(spdot, None, vjp_maker_spdot)
defjvp(spdot, None, jvp_spdot)
""" =================== PLOTTING AND MEASUREMENT =================== """

import matplotlib.pylab as plt


def aniplot(F, source, steps, component='Ez', num_panels=10):
    """ Animate an FDTD (F) with `source` for `steps` time steps.
    display the `component` field components at `num_panels` equally spaced.
    """
    F.initialize_fields()

    # initialize the plot
    f, ax_list = plt.subplots(1, num_panels, figsize=(20 * num_panels, 20))
    Nx, Ny, _ = F.eps_r.shape
    ax_index = 0
Example #4
0
    return vjp


# Reverse-mode rules for solve_triangular(a, b, ...): grad_solve_triangular
# handles argument a; for argument b the VJP is another triangular solve
# against a with the transpose flag flipped by _flip.
defvjp(solve_triangular,
       grad_solve_triangular,
       lambda ans, a, b, trans=0, lower=False, **kwargs: lambda g:
       solve_triangular(a, g, trans=_flip(a, trans), lower=lower))


def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    assert disp, "sqrtm jvp not implemented for disp=False"
    return solve_sylvester(ans, ans, dA)


# Register _jvp_sqrtm as the forward-mode (JVP) rule for sqrtm.
defjvp(sqrtm, _jvp_sqrtm)


def _jvp_sylvester(argnums, dms, ans, args, _):
    a, b, q = args
    if 0 in argnums:
        da = dms[0]
        db = dms[1] if 1 in argnums else 0
    else:
        da = 0
        db = dms[0] if 1 in argnums else 0
    dq = dms[-1] if 2 in argnums else 0
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)

Example #5
0
    transpose = lambda x: x if _flip(a, trans) != 'N' else x.T
    al2d = lambda x: x if x.ndim > 1 else x[...,None]
    def vjp(g):
        v = al2d(solve_triangular(a, g, trans=_flip(a, trans), lower=lower))
        return -transpose(tri(anp.dot(v, al2d(ans).T)))
    return vjp

# Reverse-mode rules for solve_triangular(a, b, ...): grad_solve_triangular
# handles argument a; for argument b the VJP is another triangular solve
# against a with the transpose flag flipped by _flip.
defvjp(solve_triangular,
       grad_solve_triangular,
       lambda ans, a, b, trans=0, lower=False, **kwargs:
       lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower))

def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    assert disp, "sqrtm jvp not implemented for disp=False"
    return solve_sylvester(ans, ans, dA)
# Register _jvp_sqrtm as the forward-mode (JVP) rule for sqrtm.
defjvp(sqrtm, _jvp_sqrtm)

def _jvp_sylvester(argnums, dms, ans, args, _):
    a, b, q = args
    if 0 in argnums:
        da = dms[0]
        db = dms[1] if 1 in argnums else 0
    else:
        da = 0
        db = dms[0] if 1 in argnums else 0
    dq = dms[-1] if 2 in argnums else 0
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)
# Register _jvp_sylvester as a joint forward-mode rule: it receives the
# differentiated argnums and their tangents together in one call.
defjvp_argnums(solve_sylvester, _jvp_sylvester)

def _vjp_sylvester(argnums, ans, args, _):
Example #6
0
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Mark every function in nograd_functions via register_notrace for the
# forward-mode node type (JVPNode), presumably so JVP tracing passes through
# them without recording derivatives.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing and `untake` are linear in their array argument; 'same' makes the
# JVP apply the operation itself to the tangent.
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# The tangent of the element at position `argnum` is scattered into an
# output-shaped zero array with `untake`. NOTE(review): `argnum - 2`
# presumably skips two leading non-element arguments of array_from_args;
# confirm against its signature.
defjvp_argnum(
    anp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)))
# Only the third positional argument is differentiated (None for the first
# two); the JVP re-applies the function with the tangent g in that slot.
defjvp(
    anp._array_from_scalar_or_array, None, None,
    lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(
        args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
# nan_to_num passes the tangent through where x is finite and zeroes it at
# nan/inf positions.
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
# multiply is linear in each argument separately; def_linear registers that.
def_linear(anp.multiply)

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y: broadcast(g, ans),
Example #7
0
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode,
                             register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Mark every function in nograd_functions via register_notrace for the
# forward-mode node type (JVPNode), presumably so JVP tracing passes through
# them without recording derivatives.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing and `untake` are linear in their array argument; 'same' makes the
# JVP apply the operation itself to the tangent.
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# Tangent of element `argnum` is scattered into the output with `untake`.
# NOTE(review): `argnum-2` presumably skips two leading non-element args.
defjvp_argnum(anp.array_from_args, lambda argnum, g, ans, args, kwargs: untake(g, argnum-2, vspace(ans)))
# Only the third positional argument is differentiated; the JVP re-applies
# the function with the tangent g in that slot.
defjvp(anp._array_from_scalar_or_array, None, None,
       lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
# nan_to_num passes the tangent through where x is finite, zero elsewhere.
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
# multiply is linear in each argument separately; def_linear registers that.
def_linear(anp.multiply)

# ----- Binary ufuncs -----
# add/subtract: each operand's tangent flows through (negated for the second
# operand of subtract), then `broadcast` presumably matches it to ans's shape.
defjvp(anp.add,        lambda g, ans, x, y : broadcast(g, ans),
                       lambda g, ans, x, y : broadcast(g, ans))
defjvp(anp.subtract,   lambda g, ans, x, y : broadcast(g, ans),
                       lambda g, ans, x, y : broadcast(-g, ans))
defjvp(anp.divide,     'same',
Example #8
0
# Register grad_pinv as the reverse-mode (VJP) rule for pinv.
defvjp(pinv, grad_pinv)


def fwd_grad_pinv(g, ans, A):
    """Forward-mode derivative (JVP) of ``pinv`` at ``A`` in direction ``g``.

    Uses the Golub-Pereyra formula for the derivative of the Moore-Penrose
    pseudo-inverse ``P = pinv(A)``. The formula is valid where the rank of
    ``A`` is locally constant; real-valued inputs are assumed because ``T``
    does not conjugate:

        dP = -P dA P
             + P P^T dA^T (I - A P)
             + (I - P A) dA^T P^T P

    Both projector terms are expanded so no identity matrix is built. The
    middle term vanishes for full-row-rank (wide) ``A`` and the last term
    for full-column-rank (tall) ``A``. Keeping both makes the rule correct
    for either shape; the previous version dropped the last term and
    returned 0 for e.g. ``A = [[1, 0]]``.

    Args:
        g: tangent of ``A`` (same shape as ``A``).
        ans: primal output ``pinv(A)``.
        A: primal input.

    Returns:
        The tangent of ``pinv(A)``.
    """
    gT = T(g)
    pinv_pinvT = _dot(ans, T(ans))   # P P^T
    pinvT_pinv = _dot(T(ans), ans)   # P^T P
    return (-_dot(_dot(ans, g), ans)
            # P P^T dA^T (I - A P)
            + _dot(pinv_pinvT, gT)
            - _dot(_dot(_dot(pinv_pinvT, gT), A), ans)
            # (I - P A) dA^T P^T P
            + _dot(gT, pinvT_pinv)
            - _dot(_dot(_dot(ans, A), gT), pinvT_pinv))


# Register fwd_grad_pinv as the forward-mode (JVP) rule for pinv.
defjvp(pinv, fwd_grad_pinv)


def grad_solve(argnum, ans, a, b):
    """Reverse-mode rule factory for ``ans = solve(a, b)``.

    For argnum 0 (``a``) the VJP is ``-solve(T(a), g) @ T(ans)``. For any
    other argnum (``b``) it is ``solve(T(a), g)``. Arrays with one fewer
    dimension than ``a`` (vector right-hand sides) get a trailing unit axis
    so the outer product is well formed.
    """
    def as_matrix(arr):
        # Add a trailing axis unless arr already has a's rank.
        return arr if arr.ndim == a.ndim else arr[..., None]

    if argnum != 0:
        return lambda g: solve(T(a), g)
    return lambda g: -_dot(as_matrix(solve(T(a), g)), T(as_matrix(ans)))


# Reverse-mode rules for solve(a, b): grad_solve specialised to argnum 0 (a)
# and argnum 1 (b).
defvjp(solve, partial(grad_solve, 0), partial(grad_solve, 1))


def fwd_grad_solve_0(g, ans, a, b):
    """Tangent of ``ans = solve(a, b)`` when ``a`` moves along ``g``.

    Differentiating ``a @ ans = b`` with ``b`` fixed gives
    ``a @ d_ans = -g @ ans``. ``b`` is unused.
    """
    perturbed_rhs = anp.dot(g, ans)
    return -solve(a, perturbed_rhs)
Example #9
0
def jvp_solve_Ez_source(g, Ez, info_dict, eps_vec_zz, source,
                        iterative=False, method=DEFAULT_SOLVER):
    """ Forward-mode derivative of solve_Ez with respect to `source`.

    Returns 1j * info_dict['omega'] * A^-1 g, where A is the system matrix
    built by make_A_Ez and the solve goes through sparse_solve with the
    given `iterative` / `method` options. `Ez` and `source` are unused.
    """
    system_matrix = make_A_Ez(info_dict, eps_vec_zz)
    scale = 1j * info_dict['omega']
    solved = sparse_solve(system_matrix, g, iterative=iterative, method=method)
    return scale * solved


# Derivative rules for solve_Ez. Judging by jvp_solve_Ez_source's signature,
# the arguments are (info_dict, eps_vec_zz, source, ...). None leaves
# info_dict without a rule; the other two rules presumably cover eps_vec_zz
# and source.
defvjp(solve_Ez, None, vjp_maker_solve_Ez, vjp_maker_solve_Ez_source)
defjvp(solve_Ez, None, jvp_solve_Ez, jvp_solve_Ez_source)

# Linear Hz


@primitive
def solve_Hz(info_dict,
             eps_vec_zz,
             source,
             iterative=False,
             method=DEFAULT_SOLVER):
    """ solve `Hz = A^-1 b` where A is constructed from the FDFD `info_dict`
        and 'eps_vec' is a (1D) vecay of the relative permittivity
    """

    A = make_A_Hz(info_dict, eps_vec_zz)
Example #10
0
from __future__ import absolute_import
import scipy.misc
from autograd.extend import primitive, defvjp, defjvp
import autograd.numpy as anp
from autograd.numpy.numpy_vjps import repeat_to_match_shape

# SciPy deprecated scipy.misc.logsumexp in favour of scipy.special.logsumexp
# and later removed it, so prefer the new location and fall back only for
# very old SciPy releases.
try:
    from scipy.special import logsumexp as _scipy_logsumexp
except ImportError:
    from scipy.misc import logsumexp as _scipy_logsumexp
logsumexp = primitive(_scipy_logsumexp)

def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Build the reverse-mode (VJP) closure for ``logsumexp``.

    The gradient w.r.t. ``x`` is ``b * exp(x - ans)``. Both the cotangent
    ``g`` and ``ans`` are first expanded back to ``x``'s shape along the
    reduced axes with ``repeat_to_match_shape``.
    """
    x_shape = anp.shape(x)
    x_dtype = anp.result_type(x)

    def expand(value):
        # Keep only the expanded array; the second return value is unused.
        expanded, _unused = repeat_to_match_shape(value, x_shape, x_dtype, axis, keepdims)
        return expanded

    def vjp(g):
        return expand(g) * b * anp.exp(x - expand(ans))
    return vjp

# Register make_grad_logsumexp as the reverse-mode (VJP) rule for logsumexp.
defvjp(logsumexp, make_grad_logsumexp)

def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """Forward-mode derivative (JVP) of ``logsumexp``.

    d/dx log(sum(b * exp(x))) = b * exp(x - ans). The JVP in direction ``g``
    is that weight times ``g``, summed over the reduced axes.

    Args:
        g: tangent of ``x``.
        ans: primal output ``logsumexp(x, axis, b, keepdims)``.
        x: primal input.
        axis: ``None``, one integer, or a tuple of integers (negative values
            allowed) naming the reduced axes.
        b: weights applied inside the exponential sum.
        keepdims: whether ``ans`` kept the reduced axes as size-1 dims.

    Returns:
        The tangent of ``ans`` (same shape as ``ans``).
    """
    if not keepdims and axis is not None:
        # Restore the reduced axes so ``ans`` broadcasts against ``x``. Axes
        # are normalised to non-negative positions before sorting; sorting
        # raw negative indices (e.g. (-2, -1)) inserts dims in the wrong place.
        axes = axis if isinstance(axis, tuple) else (axis,)
        ndim = anp.ndim(x)
        for ax in sorted(a % ndim for a in axes):
            ans = anp.expand_dims(ans, ax)
    return anp.sum(g * b * anp.exp(x - ans), axis=axis, keepdims=keepdims)

# Register fwd_grad_logsumexp as the forward-mode (JVP) rule for logsumexp.
defjvp(logsumexp, fwd_grad_logsumexp)
Example #11
0
    check_implemented()
    if ord in (None, 2, 'fro'):
        return contract(g * x) / ans
    elif ord == 'nuc':
        x_rolled = roll(x)
        u, s, vt = svd(x_rolled, full_matrices=False)
        uvt_rolled = _dot(u, vt)
        # Roll the matrix axes back to their correct positions
        uvt = unroll(uvt_rolled)
        return contract(g * uvt)
    else:
        # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
        return contract(g * x * anp.abs(x)**(ord - 2)) / ans**(ord - 1)


# Register norm_jvp as the forward-mode (JVP) rule for norm.
defjvp(norm, norm_jvp)


def grad_eigh(ans, x, UPLO='L'):
    """Gradient for eigenvalues and vectors of a symmetric matrix."""
    N = x.shape[-1]
    w, v = ans  # Eigenvalues, eigenvectors.
    vc = anp.conj(v)

    def vjp(g):
        wg, vg = g  # Gradient w.r.t. eigenvalues, eigenvectors.
        w_repeated = anp.repeat(w[..., anp.newaxis], N, axis=-1)

        # Eigenvalue part
        vjp_temp = _dot(vc * wg[..., anp.newaxis, :], T(v))
Example #12
0
    return np.fft.fft(x)


def fft_grad(g, ans, x):
    """Jacobian-vector product (forward-mode derivative) of my_fft.

    The DFT is linear: ans = fft(x) = D @ x for a fixed matrix D. Its
    Jacobian is therefore D itself, and the JVP in direction g is
    D @ g = fft(g). ``ans`` and ``x`` are not needed.
    """
    tangent_out = np.fft.fft(g)
    return tangent_out


# Register fft_grad as the forward-mode (JVP) rule for my_fft.
defjvp(my_fft, fft_grad)


def get_spectrum(series, dt):
    """ Get FFT of series """

    steps = len(series)
    times = np.arange(steps) * dt

    # reshape to be able to multiply by hamming window
    series = series.reshape((steps, -1))

    # multiply with hamming window to get rid of numerical errors
    hamming_window = np.hamming(steps).reshape((steps, 1))
    signal_f = my_fft(hamming_window * series)