def __new__(self, name, base, dic):
    """Metaclass __new__: create the container class, register numpy array and
    scalar types with it, graft autograd-aware numpy methods onto it, and
    attach derivative rules for indexing.
    """
    # NOTE(review): 'container_mateclass' is defined elsewhere in this file;
    # the spelling looks like a typo for 'container_metaclass' — confirm it
    # matches the actual definition before renaming anything.
    cls = type.__new__(container_mateclass, name, base, dic)
    cls.register(_np.ndarray)
    # Register plain float/complex scalar types as members of this container.
    for type_ in [float, _np.float64, _np.float32, _np.float16,
                  complex, _np.complex64, _np.complex128]:
        cls.register(type_)
    # Forward both differentiable and non-differentiable numpy methods from
    # the autograd-wrapped numpy namespace.
    for method_name in nondiff_methods + diff_methods:
        setattr(cls, method_name, anp.__dict__[method_name])
    # numpy's `flatten` is implemented here via `ravel`.
    setattr(cls, 'flatten', anp.__dict__['ravel'])
    # Reverse mode for indexing: scatter the cotangent back into the slot
    # selected by `idx` inside A's vector space.
    defvjp(func(cls.__getitem__),
           lambda ans, A, idx: lambda g: untake(g, idx, vspace(A)))
    # Indexing is linear, so forward mode applies the same operation ('same').
    defjvp(func(cls.__getitem__), 'same')
    defjvp(untake, 'same')
    setattr(cls, 'reshape', wrapped_reshape)
    return cls
# NOTE(review): this chunk begins mid-way through make_grad_logsumexp's inner
# vjp closure — the enclosing `def`s sit above this view; the indentation below
# is reconstructed from the sibling implementation of the same function.
        ans_repeated, _ = repeat_to_match_shape(ans, shape, dtype, axis, keepdims)
        # VJP of logsumexp: cotangent spread by the softmax b * exp(x - lse(x)).
        return g_repeated * b * np.exp(x - ans_repeated)
    return vjp


defvjp(logsumexp, make_grad_logsumexp)


def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """Forward-mode (JVP) rule for logsumexp.

    When the primal reduced away `axis` (keepdims=False), re-expand `ans` so
    it broadcasts against `x` before forming the softmax weights.
    """
    if not keepdims:
        if isinstance(axis, int):
            ans = np.expand_dims(ans, axis)
        elif isinstance(axis, tuple):
            for ax in sorted(axis):
                ans = np.expand_dims(ans, ax)
    # d logsumexp = sum(softmax * dx), with softmax = b * exp(x - ans).
    return np.sum(g * b * np.exp(x - ans), axis=axis, keepdims=keepdims)


defjvp(logsumexp, fwd_grad_logsumexp)

## ========================== Assoc Legendre function ==========================

#### LEGENDRE FUNCTION IMPLEMENTATION
# declaring a black box
legendre = primitive(scipy.special.lpmv)


def vjp_legendre(ans, m, n, x):
    ''' TODO: implement abs(x==1) cases '''
    # VJP w.r.t. x via the associated-Legendre recurrence; the expression is
    # singular at |x| = 1 (hence the TODO above).
    def vjp(g):
        return g * ((n+1-m)*legendre(m, n+1, x) - (n+1)*x*ans) / (x*x - 1)
    # NOTE(review): the chunk ends here — the `return vjp` presumably follows
    # outside this view.
def vjp_maker_spdot(b, A, x):
    """ Gives vjp for b = spdot(A, x) w.r.t. x"""
    def vjp(v):
        # Adjoint of the linear map x -> A @ x is v -> A.T @ v.
        return spdot(A.T, v)
    return vjp


def jvp_spdot(g, b, A, x):
    """ Gives jvp for b = spdot(A, x) w.r.t. x"""
    # spdot is linear in x, so the tangent just passes through the same map.
    return spdot(A, g)


# Derivatives registered w.r.t. x only; argument 0 (the sparse matrix A)
# carries no derivative rule here.
defvjp(spdot, None, vjp_maker_spdot)
defjvp(spdot, None, jvp_spdot)

""" =================== PLOTTING AND MEASUREMENT =================== """

import matplotlib.pylab as plt


def aniplot(F, source, steps, component='Ez', num_panels=10):
    """ Animate an FDTD (F) with `source` for `steps` time steps.
    display the `component` field components at `num_panels` equally spaced.
    """
    F.initialize_fields()
    # initialize the plot
    f, ax_list = plt.subplots(1, num_panels, figsize=(20 * num_panels, 20))
    Nx, Ny, _ = F.eps_r.shape
    ax_index = 0
    # NOTE(review): aniplot continues past the end of this chunk.
# NOTE(review): this chunk starts mid-way through grad_solve_triangular — the
# enclosing definition is above this view.
    return vjp


defvjp(solve_triangular,
       grad_solve_triangular,
       # VJP w.r.t. b: solve the transposed triangular system (trans flipped).
       lambda ans, a, b, trans=0, lower=False, **kwargs:
           lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower))


def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    """JVP of the matrix square root: with X = sqrtm(A), differentiating
    X @ X = A gives the Sylvester equation X dX + dX X = dA, solved for dX."""
    assert disp, "sqrtm jvp not implemented for disp=False"
    return solve_sylvester(ans, ans, dA)


defjvp(sqrtm, _jvp_sqrtm)


def _jvp_sylvester(argnums, dms, ans, args, _):
    """JVP for solve_sylvester(a, b, q).

    `argnums` selects which of (a, b, q) carry tangents in `dms`; missing
    tangents are treated as zero.
    """
    a, b, q = args
    if 0 in argnums:
        da = dms[0]
        db = dms[1] if 1 in argnums else 0
    else:
        da = 0
        db = dms[0] if 1 in argnums else 0
    dq = dms[-1] if 2 in argnums else 0
    # Differentiate a X + X b = q:  a dX + dX b = dq - da X - X db.
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)
# NOTE(review): this chunk starts mid-way through grad_solve_triangular — the
# enclosing def and the helpers `tri` and `_flip` are above this view.
    transpose = lambda x: x if _flip(a, trans) != 'N' else x.T
    al2d = lambda x: x if x.ndim > 1 else x[..., None]

    def vjp(g):
        # Back-substitute the cotangent through the (transposed) triangular
        # system, then form the triangular-masked outer product w.r.t. a.
        v = al2d(solve_triangular(a, g, trans=_flip(a, trans), lower=lower))
        return -transpose(tri(anp.dot(v, al2d(ans).T)))
    return vjp


defvjp(solve_triangular,
       grad_solve_triangular,
       # VJP w.r.t. b: solve the transposed triangular system (trans flipped).
       lambda ans, a, b, trans=0, lower=False, **kwargs:
           lambda g: solve_triangular(a, g, trans=_flip(a, trans), lower=lower))


def _jvp_sqrtm(dA, ans, A, disp=True, blocksize=64):
    """JVP of sqrtm: solve the Sylvester equation X dX + dX X = dA for dX."""
    assert disp, "sqrtm jvp not implemented for disp=False"
    return solve_sylvester(ans, ans, dA)


defjvp(sqrtm, _jvp_sqrtm)


def _jvp_sylvester(argnums, dms, ans, args, _):
    """JVP for solve_sylvester(a, b, q); tangents absent from `argnums`
    are treated as zero."""
    a, b, q = args
    if 0 in argnums:
        da = dms[0]
        db = dms[1] if 1 in argnums else 0
    else:
        da = 0
        db = dms[0] if 1 in argnums else 0
    dq = dms[-1] if 2 in argnums else 0
    # a dX + dX b = dq - da X - X db.
    rhs = dq - anp.dot(da, ans) - anp.dot(ans, db)
    return solve_sylvester(a, b, rhs)


defjvp_argnums(solve_sylvester, _jvp_sylvester)


def _vjp_sylvester(argnums, ans, args, _):
    # NOTE(review): this definition is cut off at the end of the chunk.
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1,
                         tensordot_adjoint_0, tensordot_adjoint_1,
                         nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Functions with no gradient are excluded from forward-mode tracing.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing is linear, so its JVP is the same operation applied to the tangent.
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# array_from_args: the tangent of positional argument `argnum` lands at slot
# argnum - 2 of the result (the first two parameters are not array elements).
defjvp_argnum(
    anp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)))
defjvp(
    anp._array_from_scalar_or_array, None, None,
    lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(
        args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
# Tangents pass through only where x is finite; elsewhere they are zeroed.
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)

# ----- Binary ufuncs -----
# NOTE(review): this registration is cut off at the end of the chunk.
defjvp(anp.add, lambda g, ans, x, y: broadcast(g, ans),
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1,
                         tensordot_adjoint_0, tensordot_adjoint_1,
                         nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Functions with no gradient are excluded from forward-mode tracing.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing is linear, so its JVP is the same operation applied to the tangent.
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# array_from_args: the tangent of positional argument `argnum` lands at slot
# argnum - 2 of the result (the first two parameters are not array elements).
defjvp_argnum(anp.array_from_args,
              lambda argnum, g, ans, args, kwargs:
                  untake(g, argnum-2, vspace(ans)))
defjvp(anp._array_from_scalar_or_array, None, None,
       lambda g, ans, args, kwargs, _:
           anp._array_from_scalar_or_array(args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
# Tangents pass through only where x is finite; elsewhere they are zeroed.
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y: broadcast(g, ans),
       lambda g, ans, x, y: broadcast(g, ans))
defjvp(anp.subtract, lambda g, ans, x, y: broadcast(g, ans),
       lambda g, ans, x, y: broadcast(-g, ans))
# NOTE(review): this registration is cut off at the end of the chunk.
defjvp(anp.divide, 'same',
defvjp(pinv, grad_pinv)


def fwd_grad_pinv(g, ans, A):
    # ans is pinv(A)
    # NOTE(review): the commented-out block below appears to be an alternative
    # form of this JVP kept for reference — confirm before deleting.
    #return (-_dot(_dot(ans, g), ans) +
    #        _dot(_dot(_dot(ans, T(ans)), T(g)), (anp.eye(A.shape[-2]) - _dot(A, ans))) +
    #        _dot(_dot(_dot((anp.eye(A.shape[-1]) - _dot(ans, A)), T(g)), T(ans)), ans))
    return (-_dot(_dot(ans, g), ans) +
            _dot(_dot(ans, T(ans)), T(g)) -
            _dot(_dot(_dot(_dot(ans, T(ans)), T(g)), A), ans))
    # + remaining terms of the reference expression, also kept commented out:
    # _dot(_dot(T(g), T(ans)), ans) -
    # _dot(_dot(_dot(_dot(ans, A), T(g)), T(ans)), ans))


defjvp(pinv, fwd_grad_pinv)


def grad_solve(argnum, ans, a, b):
    """VJP maker for solve(a, b); `argnum` selects a (0) or b (1)."""
    # Promote to a.ndim so the batched outer product below broadcasts.
    updim = lambda x: x if x.ndim == a.ndim else x[..., None]
    if argnum == 0:
        # d/da: -(a^-T g) @ ans^T
        return lambda g: -_dot(updim(solve(T(a), g)), T(updim(ans)))
    else:
        # d/db: a^-T g
        return lambda g: solve(T(a), g)


defvjp(solve, partial(grad_solve, 0), partial(grad_solve, 1))


def fwd_grad_solve_0(g, ans, a, b):
    # Tangent w.r.t. a: d(a^-1 b) = -a^-1 da (a^-1 b) = -solve(a, da @ ans).
    return -solve(a, anp.dot(g, ans))
def jvp_solve_Ez_source(g, Ez, info_dict, eps_vec_zz, source, iterative=False,
                        method=DEFAULT_SOLVER):
    """ Gives jvp for solve_Ez with respect to source """
    A = make_A_Ez(info_dict, eps_vec_zz)
    # The field is linear in the source, so the JVP is the same sparse solve
    # applied to the tangent g, scaled by 1j * omega.
    return 1j * info_dict['omega'] * sparse_solve(
        A, g, iterative=iterative, method=method)


# No derivative rules for argument 0 (info_dict); eps and source get rules.
defvjp(solve_Ez, None, vjp_maker_solve_Ez, vjp_maker_solve_Ez_source)
defjvp(solve_Ez, None, jvp_solve_Ez, jvp_solve_Ez_source)

# Linear Hz


@primitive
def solve_Hz(info_dict, eps_vec_zz, source, iterative=False,
             method=DEFAULT_SOLVER):
    """ solve `Hz = A^-1 b` where A is constructed from the FDFD `info_dict`
    and 'eps_vec' is a (1D) array of the relative permittivity
    """
    A = make_A_Hz(info_dict, eps_vec_zz)
    # NOTE(review): solve_Hz continues past the end of this chunk.
from __future__ import absolute_import
import scipy.misc
from autograd.extend import primitive, defvjp, defjvp
import autograd.numpy as anp
from autograd.numpy.numpy_vjps import repeat_to_match_shape

# Wrap scipy's logsumexp as an autograd primitive; its reverse-mode (VJP)
# and forward-mode (JVP) rules are attached below.
logsumexp = primitive(scipy.misc.logsumexp)


def make_grad_logsumexp(ans, x, axis=None, b=1.0, keepdims=False):
    """Build the VJP of logsumexp: the cotangent spread by the softmax of x."""
    x_shape = anp.shape(x)
    x_dtype = anp.result_type(x)

    def vjp(g):
        # Broadcast both the cotangent and the primal result back to x's
        # shape, then weight by the softmax b * exp(x - logsumexp(x)).
        g_full, _ = repeat_to_match_shape(g, x_shape, x_dtype, axis, keepdims)
        ans_full, _ = repeat_to_match_shape(ans, x_shape, x_dtype, axis,
                                            keepdims)
        return g_full * b * anp.exp(x - ans_full)

    return vjp


defvjp(logsumexp, make_grad_logsumexp)


def fwd_grad_logsumexp(g, ans, x, axis=None, b=1.0, keepdims=False):
    """JVP of logsumexp: softmax-weighted reduction of the tangent g."""
    if not keepdims:
        # Re-insert any reduced axes so `ans` broadcasts against `x`.
        reduced_axes = (axis,) if isinstance(axis, int) else axis
        if isinstance(reduced_axes, tuple):
            for ax in sorted(reduced_axes):
                ans = anp.expand_dims(ans, ax)
    return anp.sum(g * b * anp.exp(x - ans), axis=axis, keepdims=keepdims)


defjvp(logsumexp, fwd_grad_logsumexp)
# NOTE(review): this chunk starts mid-way through norm_jvp — the enclosing def
# and the helpers (contract, roll, unroll, check_implemented) are above this
# view.
    check_implemented()
    if ord in (None, 2, 'fro'):
        # 2 / Frobenius norm: d||x|| = <x, g> / ||x||.
        return contract(g * x) / ans
    elif ord == 'nuc':
        # Nuclear norm: direction of steepest ascent is U @ V^T from the
        # thin SVD of the (rolled) matrix.
        x_rolled = roll(x)
        u, s, vt = svd(x_rolled, full_matrices=False)
        uvt_rolled = _dot(u, vt)
        # Roll the matrix axes back to their correct positions
        uvt = unroll(uvt_rolled)
        return contract(g * uvt)
    else:
        # see https://en.wikipedia.org/wiki/Norm_(mathematics)#p-norm
        return contract(g * x * anp.abs(x)**(ord - 2)) / ans**(ord - 1)


defjvp(norm, norm_jvp)


def grad_eigh(ans, x, UPLO='L'):
    """Gradient for eigenvalues and vectors of a symmetric matrix."""
    N = x.shape[-1]
    w, v = ans  # Eigenvalues, eigenvectors.
    vc = anp.conj(v)

    def vjp(g):
        wg, vg = g  # Gradient w.r.t. eigenvalues, eigenvectors.
        w_repeated = anp.repeat(w[..., anp.newaxis], N, axis=-1)
        # Eigenvalue part
        vjp_temp = _dot(vc * wg[..., anp.newaxis, :], T(v))
        # NOTE(review): the inner vjp continues past the end of this chunk.
# NOTE(review): this chunk starts mid-way through my_fft — its def line (and
# the @primitive decoration implied by defjvp below) is above this view.
    return np.fft.fft(x)


def fft_grad(g, ans, x):
    """ Define the jacobian-vector product of my_fft(x)

        The gradient of FFT times g is the vjp
        ans = fft(x) = D @ x
        jvp(fft(x))(g) = d{fft}/d{x} @ g = D @ g
        Therefore, it looks like the FFT of g
    """
    # The FFT is linear, so its JVP is simply the FFT of the tangent.
    return np.fft.fft(g)


defjvp(my_fft, fft_grad)


def get_spectrum(series, dt):
    """ Get FFT of series """
    steps = len(series)
    times = np.arange(steps) * dt
    # reshape to be able to multiply by hamming window
    series = series.reshape((steps, -1))
    # multiply with hamming window to get rid of numerical errors
    hamming_window = np.hamming(steps).reshape((steps, 1))
    signal_f = my_fft(hamming_window * series)
    # NOTE(review): get_spectrum continues past the end of this chunk.