Example #1
0
 def __new__(self, name, base, dic):
     """Build the new class and attach array behaviour to it.

     The class is always created through ``container_mateclass`` (the first
     parameter, although named ``self``, is not used).  The new class then
     has the ndarray/scalar types registered on it, receives the listed
     ``anp`` functions as attributes, and gets VJP/JVP rules for indexing.

     Returns:
         The newly created class.
     """
     new_cls = type.__new__(container_mateclass, name, base, dic)

     # Type registration: the array type first, then the scalar types.
     new_cls.register(_np.ndarray)
     scalar_types = (
         float,
         _np.float64,
         _np.float32,
         _np.float16,
         complex,
         _np.complex64,
         _np.complex128,
     )
     for scalar_type in scalar_types:
         new_cls.register(scalar_type)

     # Expose the anp functions under the same names on the class.
     for attr in nondiff_methods + diff_methods:
         setattr(new_cls, attr, anp.__dict__[attr])
     # ``flatten`` is served by anp's ``ravel``.
     new_cls.flatten = anp.__dict__['ravel']

     def getitem_vjp(ans, A, idx):
         # Gradient of A[idx]: hand g to untake together with idx and A's vspace.
         return lambda g: untake(g, idx, vspace(A))

     defvjp(func(new_cls.__getitem__), getitem_vjp)
     defjvp(func(new_cls.__getitem__), 'same')
     defjvp(untake, 'same')

     new_cls.reshape = wrapped_reshape
     return new_cls
Example #2
0
def _is_vspace_add(node):
  """Tell whether ``node.fun`` is the function behind a VSpace addition.

  Returns True when ``node.fun`` is identical to
  ``ag_util.func(ag_extend.VSpace.add)`` or to
  ``ag_util.func(ag_extend.VSpace.mut_add)``, and False otherwise.
  """
  if node.fun is ag_util.func(ag_extend.VSpace.add):
    return True
  return node.fun is ag_util.func(ag_extend.VSpace.mut_add)
Example #3
0
    @primitive
    def __call__(self, x):
        """Evaluate this function at ``x``.

        Each entry ``x_repr -> a`` of ``self.alphas`` contributes
        ``a * self.kernel(x, x_repr)``; the contributions are summed
        starting from ``0.0``.
        """
        contributions = [
            coeff * self.kernel(x, center)
            for center, coeff in self.alphas.items()
        ]
        # Keep the builtin ``sum`` so the accumulation matches exactly.
        return sum(contributions, 0.0)

    def __add__(self, f):
        """Return ``self + f`` by delegating to ``self.vs.add``."""
        space = self.vs
        return space.add(self, f)

    def __mul__(self, a):
        """Return ``self * a`` by delegating to ``self.vs.scalar_mul``."""
        space = self.vs
        return space.scalar_mul(self, a)


# TODO: add vjp of __call__ wrt x (and show it in action)
# Gradient rule for RKHSFun.__call__.  A single rule is supplied: given the
# call's (ans, f, x) it maps the incoming gradient g to
# RKHSFun(f.kernel, {x: 1}) * g.  As the TODO says, there is no rule for x yet.
defvjp(func(RKHSFun.__call__),
       lambda ans, f, x: lambda g: RKHSFun(f.kernel, {x: 1}) * g)


class RKHSFunBox(Box, RKHSFun):
    """Box for RKHSFun values; forwards ``kernel`` to the wrapped value."""

    @property
    def kernel(self):
        """Kernel of the value held in ``self._value``."""
        wrapped = self._value
        return wrapped.kernel


# Register RKHSFun with RKHSFunBox (presumably so that autograd wraps RKHSFun
# values in this box type -- confirm against autograd's Box.register).
RKHSFunBox.register(RKHSFun)


class RKHSFunVSpace(VSpace):
    def __init__(self, value):
        """Remember the kernel of ``value`` on this vector space."""
        kernel = value.kernel
        self.kernel = kernel
Example #4
0
    else:
        return lambda g: g


# Attach array_from_scalar_or_array_gradmaker as the VJP maker of
# acp._array_from_scalar_or_array, with argnums=(2, 3) passed through.
# NOTE(review): one maker is given for two argnums -- confirm which
# argument positions actually receive it.
defvjp(
    acp._array_from_scalar_or_array,
    array_from_scalar_or_array_gradmaker,
    argnums=(2, 3),
)


@primitive
def untake(x, idx, vs):
    """Wrap "add ``x`` at ``idx``" as a ``SparseObject`` over ``vs``.

    Args:
        x: values to add.
        idx: index passed to ``scatter_add``.
        vs: vector space handed to ``SparseObject``.

    Returns:
        ``SparseObject(vs, callback)``, where ``callback(target)`` calls
        ``ocpx.scatter_add(target, idx, x)`` and returns ``target``.
    """

    def scatter_into(target):
        # The numpy version of this code used ``onp.add.at(target, idx, x)``.
        # According to
        # https://docs-cupy.chainer.org/en/stable/reference/ufunc.html?highlight=ufunc.at,
        # ``scatter_add`` is the cupy function to use instead.
        # TODO: PR into cupy codebase the ability to use scatter_add with float64?
        ocpx.scatter_add(target, idx, x)
        return target

    sparse = SparseObject(vs, scatter_into)
    return sparse


# VJP of container indexing: the gradient g of A[idx] is passed to
# untake(g, idx, vspace(A)).
defvjp(
    func(container.__getitem__),
    lambda ans, A, idx: lambda g: untake(g, idx, vspace(A)),  # noqa: E501
)
# VJP of untake for its first argument x: index the incoming gradient g
# with the same idx.
defvjp(untake, lambda ans, x, idx, _: lambda g: g[idx])
def recast(op):
    """Return a wrapper around ``op`` that casts its arguments first.

    The wrapper accepts any number of positional arguments, converts each
    one with ``tf.cast(..., dtype=tf.float32)``, and then calls ``op`` on
    the converted values.
    """

    def new_op(*args):
        to_float32 = partial(tf.cast, dtype=tf.float32)
        converted = [to_float32(arg) for arg in args]
        return op(*converted)

    return new_op


# Lookup table from numpy / autograd functions to the TensorFlow callable
# used in their place.  The binary arithmetic entries go through ``recast``,
# which casts every operand to tf.float32 before applying the operator.
np2tf = {
    np.add: recast(operator.add),
    np.subtract: recast(operator.sub),
    np.multiply: recast(operator.mul),
    # ``operator.div`` exists only on Python 2 and raises AttributeError on
    # Python 3.  ``operator.truediv`` exists on both, and the operands are
    # already float32 at this point, so true division is what is wanted.
    np.divide: recast(operator.truediv),
    np.true_divide: recast(operator.truediv),
    np.einsum: einsum,
    # The vspace argument is not needed to add two values.
    func(VSpace.add_not_none): lambda unused_vs, x, y: x + y,
    cast: tf.cast,
    np.power: tf.pow,
    np.log: tf.log,
    one_hot: tf.one_hot,
    np.sum: tf.reduce_sum,
    special.gammaln: tf.lgamma,
    special.psi: tf.digamma,
    np.reshape: tf.reshape,
    misc.logsumexp: tf.reduce_logsumexp,
    np.exp: tf.exp,
    np.negative: tf.negative,
}


class TFNode(Node):
Example #6
0
    else:
        return lambda g: g


# Attach array_from_scalar_or_array_gradmaker as the VJP maker of
# acp._array_from_scalar_or_array, with argnums=(2, 3) passed through.
# NOTE(review): one maker is given for two argnums -- confirm which
# argument positions actually receive it.
defvjp(
    acp._array_from_scalar_or_array,
    array_from_scalar_or_array_gradmaker,
    argnums=(2, 3),
)


@primitive
def untake(x, idx, vs):
    """Wrap "add ``x`` at ``idx``" as a ``SparseObject`` over ``vs``.

    Args:
        x: values to add.
        idx: index passed to ``scatter_add``.
        vs: vector space handed to ``SparseObject``.

    Returns:
        ``SparseObject(vs, callback)``, where ``callback(target)`` calls
        ``ocpx.scatter_add(target, idx, x)`` and returns ``target``.
    """

    def scatter_into(target):
        # The numpy version of this code used ``onp.add.at(target, idx, x)``.
        # According to
        # https://docs-cupy.chainer.org/en/stable/reference/ufunc.html?highlight=ufunc.at,
        # ``scatter_add`` is the cupy function to use instead.
        # TODO: PR into cupy codebase the ability to use scatter_add with float64?
        ocpx.scatter_add(target, idx, x)
        return target

    sparse = SparseObject(vs, scatter_into)
    return sparse


# VJP of ArrayBox indexing: the gradient g of A[idx] is passed to
# untake(g, idx, vspace(A)).
defvjp(
    func(ArrayBox.__getitem__),
    lambda ans, A, idx: lambda g: untake(g, idx, vspace(A)),  # noqa: E501
)
# VJP of untake for its first argument x: index the incoming gradient g
# with the same idx.
defvjp(untake, lambda ans, x, idx, _: lambda g: g[idx])
Example #7
0
from .tracers import remake_expr
from .tracers import toposort
from .tracers import env_lookup
from .util import Enum

## canonicalization rule sets

# Maps a function to the rewrite that is applied in its place.  Most entries
# point at helpers in ``rewrites``; a few are small inline lambdas.
eager_simplifications = {
    np.dot: rewrites.dot_as_einsum,
    np.multiply: rewrites.maybe_multiply,
    np.divide: rewrites.maybe_divide,
    np.true_divide: rewrites.maybe_divide,
    np.add: rewrites.maybe_add,
    np.subtract: rewrites.maybe_subtract,
    np.einsum: rewrites.maybe_einsum,
    ag_util.func(ag_extend.VSpace.add): rewrites.maybe_vspace_add,
    ag_util.func(ag_extend.VSpace.mut_add): rewrites.maybe_vspace_add,
    # Reciprocal, square and square root are expressed as powers.
    np.reciprocal: lambda x: x**-1,
    np.square: lambda x: x**2,
    np.sqrt: lambda x: x**0.5,
    np.power: rewrites.maybe_power,
    np.swapaxes: rewrites.swapaxes,
    # A sum of exactly one term is that term.  The check must be
    # ``== 1``: with ``== 0`` the ``args[0]`` lookup raises IndexError on
    # empty input and the single-term case is never simplified.
    add_n: lambda *args: args[0] if len(args) == 1 else add_n(*args),
}

simplification_rules = [
    rewrites.transpose_inside_einsum, rewrites.replace_sum,
    rewrites.combine_einsum_compositions, rewrites.distribute_einsum,
    rewrites.einsum_repeated_one_hot, rewrites.log_behind_onehot_einsum,
    rewrites.log_addn_behind_onehot_einsum, rewrites.replace_log_einsum,
    rewrites.fold_power, rewrites.add_powers_within_einsum,
        input_dimension = Nx * Ny
        output_dimension = Nx_interp * Ny_interp

        interp_weights = np.zeros(4 * output_dimension)
        row_ind = np.zeros(4 * output_dimension, dtype=np.int64)
        col_ind = np.zeros(4 * output_dimension, dtype=np.int64)

        ri = 0
        for rx in rho_x_interp:
            for ry in rho_y_interp:
                # get weights
                weights, interp_idx = self.get_bilinear_row(
                    rx, ry, rho_x, rho_y)

                # populate sparse matrix vectors
                interp_weights[4 * ri:4 * (ri + 1)] = weights
                row_ind[4 * ri:4 * (ri + 1)] = np.array([ri, ri, ri, ri],
                                                        dtype=np.int64)
                col_ind[4 * ri:4 * (ri + 1)] = interp_idx

                ri += 1

        # From matrix vectors, populate the sparse matrix
        A = sparse.coo_matrix((interp_weights, (row_ind, col_ind)),
                              shape=(output_dimension, input_dimension))

        return A


# VJP rules for BilinearInterpolationBasis.__call__, given positionally:
# None, a rule returning f.gradient(p, eps), then None (presumably one slot
# per call argument in order -- confirm against autograd's defvjp).
# NOTE(review): the rule ignores the incoming gradient ``a``; confirm that
# f.gradient(p, eps) is meant to be returned as-is.
defvjp(func(BilinearInterpolationBasis.__call__), None,
       lambda ans, f, p, eps: lambda a: f.gradient(p, eps), None)
Example #9
0
class RKHSFun(object):
    """Function stored as kernel coefficients.

    Attributes are set in ``__init__``:
      * ``kernel`` -- callable used as ``kernel(x, x_repr)``.
      * ``alphas`` -- dict mapping a point ``x_repr`` to its coefficient.
      * ``vs`` -- the ``RKHSFunVSpace`` built from this instance; ``+`` and
        ``*`` delegate to it.
    """

    def __init__(self, kernel, alphas=None):
        """Store ``kernel`` and ``alphas`` (a new empty dict when omitted)."""
        # A literal ``{}`` default is created once and shared by every
        # instance built without explicit alphas; create one per instance.
        self.alphas = {} if alphas is None else alphas
        self.kernel = kernel
        self.vs = RKHSFunVSpace(self)

    @primitive
    def __call__(self, x):
        """Return the sum of ``a * kernel(x, x_repr)`` over ``alphas``."""
        return sum([a * self.kernel(x, x_repr)
                    for x_repr, a in self.alphas.items()], 0.0)

    def __add__(self, f):
        """Return ``self + f`` via ``self.vs.add``."""
        return self.vs.add(self, f)

    def __mul__(self, a):
        """Return ``self * a`` via ``self.vs.scalar_mul``."""
        return self.vs.scalar_mul(self, a)

# TODO: add vjp of __call__ wrt x (and show it in action)
# Gradient rule for RKHSFun.__call__.  A single rule is supplied: given the
# call's (ans, f, x) it maps the incoming gradient g to
# RKHSFun(f.kernel, {x : 1}) * g.  As the TODO says, there is no rule for x yet.
defvjp(func(RKHSFun.__call__),
       lambda ans, f, x: lambda g: RKHSFun(f.kernel, {x : 1}) * g)

class RKHSFunBox(Box, RKHSFun):
    """Box for RKHSFun values; forwards ``kernel`` to the wrapped value."""

    @property
    def kernel(self):
        """Kernel of the value held in ``self._value``."""
        wrapped = self._value
        return wrapped.kernel
# Register RKHSFun with RKHSFunBox (presumably so that autograd wraps RKHSFun
# values in this box type -- confirm against autograd's Box.register).
RKHSFunBox.register(RKHSFun)

class RKHSFunVSpace(VSpace):
    def __init__(self, value):
        """Remember the kernel of ``value`` on this vector space."""
        kernel = value.kernel
        self.kernel = kernel

    def zeros(self):
        """Return an ``RKHSFun`` built from this space's kernel alone."""
        kernel = self.kernel
        return RKHSFun(kernel)
    def randn(self):
        # These arbitrary vectors are not analogous to randn in any meaningful way
        N = npr.randint(1,3)
Example #10
0
)
from autograd.extend import (
    defjvp,
    defjvp_argnum,
    def_linear,
    vspace,
    JVPNode,
    register_notrace,
)
from autograd.util import func
from .cupy_boxes import ArrayBox

# Register every function in nograd_functions as not traced for JVPNode.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing and untake use the "same" shorthand as their JVP rule (presumably
# the function itself applied to the tangent -- confirm in autograd's defjvp).
defjvp(func(ArrayBox.__getitem__), "same")
defjvp(untake, "same")

# JVP of array_from_args for any argument position: the tangent g is passed
# to untake at index ``argnum - 2`` within the vspace of the result.
# NOTE(review): the offset of 2 presumably skips two leading non-element
# arguments -- confirm against acp.array_from_args.
defjvp_argnum(
    acp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)),
)
# JVP of _array_from_scalar_or_array: no rule in the first two slots; the
# third re-applies the function to the tangent g with the same args/kwargs.
defjvp(
    acp._array_from_scalar_or_array,
    None,
    None,
    lambda g, ans, args, kwargs, _: acp._array_from_scalar_or_array(
        args, kwargs, g),
)

# ----- Functions that are constant w.r.t. continuous inputs -----
Example #11
0
)
from autograd.extend import (
    defjvp,
    defjvp_argnum,
    def_linear,
    vspace,
    JVPNode,
    register_notrace,
)
from autograd.util import func
from .cupy_containers import container

# Register every function in nograd_functions as not traced for JVPNode.
for fun in nograd_functions:
    register_notrace(JVPNode, fun)

# Indexing and untake use the "same" shorthand as their JVP rule (presumably
# the function itself applied to the tangent -- confirm in autograd's defjvp).
defjvp(func(container.__getitem__), "same")
defjvp(untake, "same")

# JVP of array_from_args for any argument position: the tangent g is passed
# to untake at index ``argnum - 2`` within the vspace of the result.
# NOTE(review): the offset of 2 presumably skips two leading non-element
# arguments -- confirm against acp.array_from_args.
defjvp_argnum(
    acp.array_from_args,
    lambda argnum, g, ans, args, kwargs: untake(g, argnum - 2, vspace(ans)),
)
# JVP of _array_from_scalar_or_array: no rule in the first two slots; the
# third re-applies the function to the tangent g with the same args/kwargs.
defjvp(
    acp._array_from_scalar_or_array,
    None,
    None,
    lambda g, ans, args, kwargs, _: acp._array_from_scalar_or_array(
        args, kwargs, g),
)

# ----- Functions that are constant w.r.t. continuous inputs -----