Ejemplo n.º 1
0
                             SparseObject, VJPNode, register_notrace)

# ----- Non-differentiable functions -----

# NumPy functions whose outputs are piecewise constant, integer/boolean valued,
# or determined only by shape/dtype metadata, so they carry no useful gradient.
# Each one is registered as "notrace" for reverse mode (VJPNode) below; the
# same list is also imported by the forward-mode JVP module.
nograd_functions = [
    # rounding
    anp.floor, anp.ceil, anp.round, anp.rint, anp.around, anp.fix, anp.trunc,
    # boolean reductions
    anp.all, anp.any,
    # index-valued results
    anp.argmax, anp.argmin, anp.argpartition, anp.argsort, anp.argwhere,
    anp.nonzero, anp.flatnonzero, anp.count_nonzero, anp.searchsorted,
    # sign
    anp.sign,
    # shape queries
    anp.ndim, anp.shape,
    # floor division (piecewise constant)
    anp.floor_divide,
    # logical operations
    anp.logical_and, anp.logical_or, anp.logical_not, anp.logical_xor,
    # floating-point classification and closeness tests
    anp.isfinite, anp.isinf, anp.isnan, anp.isneginf, anp.isposinf,
    anp.allclose, anp.isclose, anp.array_equal, anp.array_equiv,
    # comparisons
    anp.greater, anp.greater_equal, anp.less, anp.less_equal,
    anp.equal, anp.not_equal,
    # type and size queries
    anp.iscomplexobj, anp.iscomplex, anp.size, anp.isscalar, anp.isreal,
    # constructors/helpers that depend only on shape and dtype
    anp.zeros_like, anp.ones_like, anp.result_type,
]

for fun in nograd_functions:
    register_notrace(VJPNode, fun)

# ----- Functions that are the identity on finite inputs -----

# nan_to_num leaves finite entries unchanged and replaces NaN/inf entries with
# constants, so the cotangent passes through where x is finite and is zero
# elsewhere.
# Extra positional/keyword arguments (copy, nan, posinf, neginf) are accepted
# and ignored: the VJP maker receives the primal call's full argument list, so
# a bare `lambda ans, x:` raises TypeError for calls such as
# anp.nan_to_num(x, nan=0.0).
defvjp(anp.nan_to_num,
       lambda ans, x, *args, **kwargs:
           lambda g: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs -----

# Reverse-mode rules for elementwise binary arithmetic. Each VJP maker takes
# (ans, x, y) and returns the cotangent map for one operand. unbroadcast_f
# presumably reduces that map's output back to the operand's own shape when
# x and y were broadcast against each other.
defvjp(anp.add,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g),    # d(x + y)/dx = 1
       lambda ans, x, y: unbroadcast_f(y, lambda g: g))    # d(x + y)/dy = 1
defvjp(anp.multiply,
       lambda ans, x, y: unbroadcast_f(x, lambda g: y * g),  # d(x * y)/dx = y
       lambda ans, x, y: unbroadcast_f(y, lambda g: x * g))  # d(x * y)/dy = x
defvjp(anp.subtract,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g),    # d(x - y)/dx = 1
       lambda ans, x, y: unbroadcast_f(y, lambda g: -g))   # d(x - y)/dy = -1
defvjp(anp.divide,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),           # 1/y
       lambda ans, x, y: unbroadcast_f(y, lambda g: - g * x / y**2))  # -x/y**2
Ejemplo n.º 2
0
                             SparseObject, VJPNode, register_notrace)

# ----- Non-differentiable functions -----

# NumPy functions whose outputs are piecewise constant, integer/boolean valued,
# or determined only by shape/dtype metadata, so they carry no useful gradient.
# Each one is registered as "notrace" for reverse mode (VJPNode) below; the
# same list is also imported by the forward-mode JVP module.
nograd_functions = [
    # rounding
    anp.floor, anp.ceil, anp.round, anp.rint, anp.around, anp.fix, anp.trunc,
    # boolean reductions
    anp.all, anp.any,
    # index-valued results
    anp.argmax, anp.argmin, anp.argpartition, anp.argsort, anp.argwhere,
    anp.nonzero, anp.flatnonzero, anp.count_nonzero, anp.searchsorted,
    # sign
    anp.sign,
    # shape queries
    anp.ndim, anp.shape,
    # floor division (piecewise constant)
    anp.floor_divide,
    # logical operations
    anp.logical_and, anp.logical_or, anp.logical_not, anp.logical_xor,
    # floating-point classification and closeness tests
    anp.isfinite, anp.isinf, anp.isnan, anp.isneginf, anp.isposinf,
    anp.allclose, anp.isclose, anp.array_equal, anp.array_equiv,
    # comparisons
    anp.greater, anp.greater_equal, anp.less, anp.less_equal,
    anp.equal, anp.not_equal,
    # type and size queries
    anp.iscomplexobj, anp.iscomplex, anp.size, anp.isscalar, anp.isreal,
    # constructors/helpers that depend only on shape and dtype
    anp.zeros_like, anp.ones_like, anp.result_type,
]

for fun in nograd_functions:
    register_notrace(VJPNode, fun)

# ----- Functions that are the identity on finite inputs -----

# nan_to_num leaves finite entries unchanged and replaces NaN/inf entries with
# constants, so the cotangent passes through where x is finite and is zero
# elsewhere.
# Extra positional/keyword arguments (copy, nan, posinf, neginf) are accepted
# and ignored: the VJP maker receives the primal call's full argument list, so
# a bare `lambda ans, x:` raises TypeError for calls such as
# anp.nan_to_num(x, nan=0.0).
defvjp(anp.nan_to_num,
       lambda ans, x, *args, **kwargs:
           lambda g: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs -----

# Reverse-mode rules for elementwise binary arithmetic. Each VJP maker takes
# (ans, x, y) and returns the cotangent map for one operand. unbroadcast_f
# presumably reduces that map's output back to the operand's own shape when
# x and y were broadcast against each other.
defvjp(anp.add,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g),    # d(x + y)/dx = 1
       lambda ans, x, y: unbroadcast_f(y, lambda g: g))    # d(x + y)/dy = 1
defvjp(anp.multiply,
       lambda ans, x, y: unbroadcast_f(x, lambda g: y * g),  # d(x * y)/dx = y
       lambda ans, x, y: unbroadcast_f(y, lambda g: x * g))  # d(x * y)/dy = x
defvjp(anp.subtract,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g),    # d(x - y)/dx = 1
       lambda ans, x, y: unbroadcast_f(y, lambda g: -g))   # d(x - y)/dy = -1
defvjp(anp.divide,
       lambda ans, x, y: unbroadcast_f(x, lambda g: g / y),           # 1/y
       lambda ans, x, y: unbroadcast_f(y, lambda g: - g * x / y**2))  # -x/y**2
Ejemplo n.º 3
0
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace, JVPNode,
                             register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Forward mode skips the same non-differentiable functions that the
# reverse-mode module lists in nograd_functions.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing a boxed array, and untake, are linear in their data argument;
# 'same' presumably reuses the primitive itself as the JVP (confirm against
# autograd.extend.defjvp).
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# array_from_args: the tangent of one element argument is scattered into an
# output-shaped zero via untake. The `- 2` offset presumably skips the first
# two positional arguments, which are not array elements -- TODO confirm.
defjvp_argnum(
    anp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)))

# _array_from_scalar_or_array: the first two arguments get no JVP (None);
# the third argument's tangent is run through the same constructor.
defjvp(
    anp._array_from_scalar_or_array, None, None,
    lambda g, ans, args, kwargs, _: anp._array_from_scalar_or_array(
        args, kwargs, g))

# ----- Functions that are the identity on finite inputs -----
# nan_to_num leaves finite entries unchanged and replaces NaN/inf entries with
# constants, so the tangent passes through where x is finite and is zero
# elsewhere.
# Extra positional/keyword arguments (copy, nan, posinf, neginf) are accepted
# and ignored: the JVP function receives the primal call's full argument list,
# so a bare `lambda g, ans, x:` raises TypeError for calls such as
# anp.nan_to_num(x, nan=0.0).
defjvp(anp.nan_to_num,
       lambda g, ans, x, *args, **kwargs: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
# multiply is linear in each argument separately (bilinear), which is what
# def_linear needs in order to build its per-argument JVPs.
def_linear(anp.multiply)

# ----- Binary ufuncs -----
# add/subtract: each operand's tangent passes straight through (negated for
# the subtrahend); broadcast(..., ans) presumably expands it to the output's
# shape.
defjvp(anp.add,
       lambda g, ans, x, y: broadcast(g, ans),
       lambda g, ans, x, y: broadcast(g, ans))
defjvp(anp.subtract,
       lambda g, ans, x, y: broadcast(g, ans),
       lambda g, ans, x, y: broadcast(-g, ans))
defjvp(anp.divide,     'same',
Ejemplo n.º 4
0
from . import numpy_wrapper as anp
from .numpy_vjps import (untake, balanced_eq, match_complex, replace_zero,
                         dot_adjoint_0, dot_adjoint_1, tensordot_adjoint_0,
                         tensordot_adjoint_1, nograd_functions)
from autograd.extend import (defjvp, defjvp_argnum, def_linear, vspace,
                             JVPNode, register_notrace)
from ..util import func
from .numpy_boxes import ArrayBox

# Forward mode skips the same non-differentiable functions that the
# reverse-mode module lists in nograd_functions.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Linear structural primitives: 'same' presumably reuses the primitive as its
# own JVP (confirm against autograd.extend.defjvp).
defjvp(func(ArrayBox.__getitem__), 'same')
defjvp(untake, 'same')

# array_from_args: scatter the element tangent into an output-shaped zero.
# The `- 2` offset presumably skips two leading non-element arguments --
# TODO confirm.
defjvp_argnum(anp.array_from_args,
              lambda argnum, g, ans, args, kwargs:
                  untake(g, argnum - 2, vspace(ans)))

# _array_from_scalar_or_array: no JVP for the first two arguments; the
# third argument's tangent goes through the same constructor.
defjvp(anp._array_from_scalar_or_array,
       None,
       None,
       lambda g, ans, args, kwargs, _:
           anp._array_from_scalar_or_array(args, kwargs, g))

# ----- Functions that are the identity on finite inputs -----
# nan_to_num leaves finite entries unchanged and replaces NaN/inf entries with
# constants, so the tangent passes through where x is finite and is zero
# elsewhere.
# Extra positional/keyword arguments (copy, nan, posinf, neginf) are accepted
# and ignored: the JVP function receives the primal call's full argument list,
# so a bare `lambda g, ans, x:` raises TypeError for calls such as
# anp.nan_to_num(x, nan=0.0).
defjvp(anp.nan_to_num,
       lambda g, ans, x, *args, **kwargs: anp.where(anp.isfinite(x), g, 0.))

# ----- Binary ufuncs (linear) -----
# multiply is linear in each argument separately (bilinear), which is
# presumably the property def_linear relies on when it builds one JVP per
# argument.
def_linear(anp.multiply)

# ----- Binary ufuncs -----
defjvp(anp.add, lambda g, ans, x, y: broadcast(g, ans),
Ejemplo n.º 5
0
        def apply_node(node, arg_strs):
            """Emit DOT text for one traced node and return its DOT name.

            Appends to ``fragment[0]``. For env_lookup nodes the chosen name
            is stored in ``node_names``; every other node must already have
            an entry in ``node_names``.
            """
            if node.fun is env_lookup:
                # Variable lookup: label it with env's value for the name when
                # env has one, otherwise with the name itself.
                (var_name,) = arg_strs
                label = env[var_name] if var_name in env else var_name
                node_names[node] = label
                fragment[0] += dot_variable_node(label, label)
                return label

            label = node_names[node]
            fragment[0] += dot_function_node(label, node.fun.__name__)
            for position, value in enumerate(node.args):
                if position not in node.parent_argnums:
                    # Not a parent node: draw the raw argument value as its
                    # own variable node feeding into this call.
                    const_name = '{}_arg_{}'.format(label, position)
                    fragment[0] += dot_edge(const_name, label)
                    fragment[0] += dot_variable_node(const_name, value)
                else:
                    fragment[0] += dot_edge(node_names[value], label)
            return label

        name = _eval_graph(expr.expr_node, eval_args, apply_node, cse=False)
        fragment[0] += dot_variable_node('output', 'output')
        fragment[0] += dot_edge(name, 'output')
        return dot_graph(fragment[0])
    else:
        raise TypeError("Can't draw expression type: {}".format(type(expr)))


# ones_like/zeros_like produce values that depend only on their argument's
# shape and dtype, never its contents, so they are registered as notrace for
# ExprNode (presumably meaning the tracer does not record them -- confirm).
notrace_functions = [
    np.ones_like,
    np.zeros_like,
]
for fun in notrace_functions:
    register_notrace(ExprNode, fun)