SparseObject, VJPNode, register_notrace)

# ----- Non-differentiable functions -----
# Functions whose outputs are discrete / boolean / integer-valued (or shape
# queries), so they carry no gradient signal with respect to continuous
# inputs.  They are collected here and registered below so the VJP tracer
# skips them entirely.
nograd_functions = [
    # Rounding / truncation
    anp.floor, anp.ceil, anp.round, anp.rint, anp.around, anp.fix, anp.trunc,
    # Reductions and index queries with discrete results
    anp.all, anp.any, anp.argmax, anp.argmin, anp.argpartition, anp.argsort,
    anp.argwhere, anp.nonzero, anp.flatnonzero, anp.count_nonzero,
    anp.searchsorted, anp.sign, anp.ndim, anp.shape, anp.floor_divide,
    # Logical / comparison operations (boolean outputs)
    anp.logical_and, anp.logical_or, anp.logical_not, anp.logical_xor,
    anp.isfinite, anp.isinf, anp.isnan, anp.isneginf, anp.isposinf,
    anp.allclose, anp.isclose, anp.array_equal, anp.array_equiv,
    anp.greater, anp.greater_equal, anp.less, anp.less_equal,
    anp.equal, anp.not_equal,
    # Type / structure predicates and constant-valued constructors
    anp.iscomplexobj, anp.iscomplex, anp.size, anp.isscalar, anp.isreal,
    anp.zeros_like, anp.ones_like, anp.result_type]

# Tell the reverse-mode tracer not to build graph nodes for these.
for fun in nograd_functions:
    register_notrace(VJPNode, fun)

# ----- Functions that are constant w.r.t. continuous inputs -----
# nan_to_num is the identity wherever x is finite, and replaces non-finite
# entries with constants, so the gradient passes through only where
# isfinite(x) holds.
defvjp(anp.nan_to_num, lambda ans, x: lambda g: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs -----
# defvjp takes one gradient-maker per positional argument; each maker has the
# signature (ans, x, y) and returns a function of the output cotangent g.
# unbroadcast_f (defined earlier in this file) presumably sums g over the
# dimensions that NumPy broadcasting added for that argument — TODO confirm
# against its definition above this excerpt.
defvjp(anp.add,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g),
       lambda ans, x, y: unbroadcast_f(y, lambda g: g))
defvjp(anp.multiply,
       lambda ans, x, y: unbroadcast_f(x, lambda g: y * g),
       lambda ans, x, y: unbroadcast_f(y, lambda g: x * g))
defvjp(anp.subtract,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g),
       lambda ans, x, y: unbroadcast_f(y, lambda g: -g))
defvjp(anp.divide,
       # d(x/y)/dx = 1/y ; d(x/y)/dy = -x/y**2
       lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),
       lambda ans, x, y: unbroadcast_f(y, lambda g: - g * x / y**2))
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Reuse the list of discrete-valued functions from the VJP module: the
# forward-mode tracer skips them as well.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing and its adjoint are linear, so their JVP is 'same': apply the
# function itself to the tangent.
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# array_from_args packs trailing *args into an array; the tangent of the
# argnum-th input lands at position argnum-2 (the first two parameters are
# presumably non-array bookkeeping arguments — confirm against
# anp.array_from_args' signature).
defjvp_argnum(anp.array_from_args,
              lambda argnum, g, ans, args, kwargs: untake(g, argnum-2, vspace(ans)))
# Coercion of a scalar/array is linear in its data argument (argnum 2),
# so the JVP just re-coerces the tangent; the first two argnums carry no
# gradient (None).
defjvp(anp._array_from_scalar_or_array, None, None,
       lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
# nan_to_num is the identity wherever x is finite, so the tangent passes
# through only there.
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
# multiply is linear in each argument separately; def_linear derives both
# JVPs from that fact.
def_linear(anp.multiply)

# ----- Binary ufuncs -----
# defjvp makers have the signature (g, ans, x, y) where g is the tangent of
# the differentiated argument.  broadcast (defined elsewhere in this module)
# presumably expands g to the broadcast output shape of ans — TODO confirm.
defjvp(anp.add,
       lambda g, ans, x, y: broadcast(g, ans),
       lambda g, ans, x, y: broadcast(g, ans))
defjvp(anp.subtract,
       lambda g, ans, x, y: broadcast(g, ans),
       lambda g, ans, x, y: broadcast(-g, ans))
# NOTE(review): this chunk is truncated mid-statement; the remaining
# arguments of this defjvp call lie outside the excerpt.
defjvp(anp.divide, 'same',
# NOTE(review): this chunk duplicates the preceding numpy_jvps header
# (identical imports and registrations, different original formatting) —
# likely an extraction artifact of the file join; verify against the
# repository before deduplicating.
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Discrete-valued functions carry no tangents; skip tracing them in
# forward mode too.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing and its adjoint are linear maps: JVP 'same' applies the function
# itself to the tangent.
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# Tangent of the argnum-th packed argument is placed at index argnum - 2 of
# the output (first two parameters are presumably non-array bookkeeping —
# confirm against anp.array_from_args).
defjvp_argnum(
    anp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)))
# Linear coercion: re-coerce the tangent; argnums 0 and 1 are
# non-differentiable (None).
defjvp(
    anp._array_from_scalar_or_array, None, None,
    lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(
        args, kwargs, g))

# ----- Functions that are constant w.r.t. continuous inputs -----
# Identity where x is finite; tangent is zeroed elsewhere.
defjvp(anp.nan_to_num, lambda g, ans, x: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
def_linear(anp.multiply)

# ----- Binary ufuncs -----
# NOTE(review): this chunk is truncated mid-statement; the second gradient
# maker of this defjvp call lies outside the excerpt.
defjvp(anp.add,
       lambda g, ans, x, y: broadcast(g, ans),
# NOTE(review): this chunk is the interior of an enclosing drawing function
# whose header (and the `if` matching the dangling `else:` below) lies above
# this excerpt.  `env_lookup`, `node_names`, `env`, `fragment`, `expr`,
# `eval_args`, `_eval_graph`, the `dot_*` helpers, `np`, `ExprNode` and
# `register_notrace` are all defined outside this view.  Indentation has been
# reconstructed from a collapsed source line — in particular the placement of
# the argument-edge loop inside the `else` branch is inferred; confirm
# against the repository.
def apply_node(node, arg_strs):
    # Emit Graphviz dot fragments for one computation-graph node and return
    # the dot identifier chosen for it.  Output is accumulated by appending
    # to the shared one-element list `fragment` (a closure-level mutable
    # cell).
    if node.fun is env_lookup:
        # Leaf variable: its single "argument" is the variable's name,
        # optionally renamed through the `env` mapping.
        name, = arg_strs
        name = node_names[node] = name if name not in env else env[name]
        fragment[0] += dot_variable_node(name, name)
    else:
        # Function node: draw the op itself, then one edge per argument.
        name = node_names[node]
        fragment[0] += dot_function_node(name, node.fun.__name__)
        for argnum, arg in enumerate(node.args):
            if argnum in node.parent_argnums:
                # Traced argument: edge from the producing node.
                fragment[0] += dot_edge(node_names[node.args[argnum]], name)
            else:
                # Constant argument: synthesize a variable node for it.
                argname = '{}_arg_{}'.format(name, argnum)
                fragment[0] += dot_edge(argname, name)
                fragment[0] += dot_variable_node(argname, arg)
    return name

# Enclosing-function continuation: walk the expression graph (cse disabled so
# every node is rendered), attach a distinguished 'output' node, and render
# the accumulated fragments as a complete dot graph.
name = _eval_graph(expr.expr_node, eval_args, apply_node, cse=False)
fragment[0] += dot_variable_node('output', 'output')
fragment[0] += dot_edge(name, 'output')
return dot_graph(fragment[0])
# Pairs with an `if isinstance(expr, ...)`-style test above this excerpt.
else:
    raise TypeError("Can't draw expression type: {}".format(type(expr)))

# Top level: *_like constructors produce constants, so the expression tracer
# skips them.
notrace_functions = [np.ones_like, np.zeros_like]
for fun in notrace_functions:
    register_notrace(ExprNode, fun)