Example 1
  def __call__(
      self,
      value,
      update_stats: bool = True,
      error_on_non_matrix: bool = False,
  ) -> jnp.ndarray:
    """Performs Spectral Normalization and returns the new value.
    Args:
      value: The array-like object for which you would like to perform an
        spectral normalization on.
      update_stats: A boolean defaulting to True. Regardless of this arg, this
        function will return the normalized input. When
        `update_stats` is True, the internal state of this object will also be
        updated to reflect the input value. When `update_stats` is False the
        internal stats will remain unchanged.
      error_on_non_matrix: Spectral normalization is only defined on matrices.
        By default, this module will return scalars unchanged and flatten
        higher-order tensors in their leading dimensions. Setting this flag to
        True will instead throw errors in those cases.
    Returns:
      The input value normalized by it's first singular value.
    Raises:
      ValueError: If `error_on_non_matrix` is True and `value` has ndims > 2.
    """
    value = jnp.asarray(value)
    value_shape = value.shape

    # Handle scalars.
    if value.ndim <= 1:
      raise ValueError("Spectral normalization is not well defined for "
                       "scalar or vector inputs.")
    # Handle higher-order tensors.
    elif value.ndim > 2:
      if error_on_non_matrix:
        raise ValueError(
            f"Input is {value.ndim}D but error_on_non_matrix is True")
      else:
        value = jnp.reshape(value, [-1, value.shape[-1]])

    u0 = hk.get_state("u0", [1, value.shape[-1]], value.dtype,
                      init=hk.initializers.RandomNormal())

    # Power iteration for the weight's singular value.
    for _ in range(self.n_steps):
      v0 = _l2_normalize(jnp.matmul(u0, value.transpose([1, 0])), eps=self.eps)
      u0 = _l2_normalize(jnp.matmul(v0, value), eps=self.eps)

    u0 = jax.lax.stop_gradient(u0)
    v0 = jax.lax.stop_gradient(v0)

    sigma = jnp.matmul(jnp.matmul(v0, value), jnp.transpose(u0))[0, 0]

    value /= sigma
    value *= self.val
    value_bar = value.reshape(value_shape)

    if update_stats:
      hk.set_state("u0", u0)
      hk.set_state("sigma", sigma)

    return value_bar
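
A minimal usage sketch (assuming this `__call__` belongs to a Haiku module along the lines of `hk.SpectralNorm`, whose `eps` and `n_steps` arguments mirror the attributes used above; shapes are illustrative):

import jax
import jax.numpy as jnp
import haiku as hk

def forward(w):
    # Normalize a weight matrix by an estimate of its largest singular value.
    return hk.SpectralNorm(eps=1e-4, n_steps=1)(w)

forward = hk.transform_with_state(forward)
w = jnp.ones([3, 4])
params, state = forward.init(jax.random.PRNGKey(42), w)
w_sn, state = forward.apply(params, state, None, w)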
Example 2
 def new_pars(par):
     return jnp.asarray(
         init_fun(rng.take(1)[0], shape=par.shape, dtype=par.dtype),
         dtype=par.dtype,
     )
Example 3
def modified_bessel_first_kind(v, z):
    v = jnp.asarray(v, dtype=float)
    return jnp.exp(jnp.abs(z)) * tfp.math.bessel_ive(v, z)
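
A minimal check sketch, assuming `tfp` above refers to `tensorflow_probability.substrates.jax`: since `bessel_ive` is the exponentially scaled I_v, multiplying by exp(|z|) recovers the unscaled modified Bessel function, which can be compared against SciPy:

import scipy.special

print(modified_bessel_first_kind(0.0, 1.0))  # ~1.2660659
print(scipy.special.iv(0.0, 1.0))            # ~1.2660659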
Example 4
 def init_state(self, params):
     param_states = jax.tree_map(self.init_param_state, params)
     state = OptimizerState(jnp.asarray(0, dtype=jnp.int32), param_states)
     return state
Example 5
    def apply(self,
              inputs,
              features,
              bias=True,
              scale=True,
              dtype=jnp.float32,
              precision=None,
              bias_init=initializers.zeros,
              scale_init=None,
              softplus=True,
              norm_grad_block=False):
        if scale_init is None:
            if softplus:
                scale_init = new_initializers.init_softplus_ones
            else:
                scale_init = initializers.ones

        norm = jax.scipy.stats.norm
        erf = jax.scipy.special.erf  # Error function.

        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
        else:
            bias = 0.

        if scale:
            scale = self.param('scale', (features, ), scale_init)
            scale = jnp.asarray(scale, dtype)
        else:
            scale = float(
                new_initializers.inv_softplus(1.0)) if softplus else 1.0

        if softplus:
            scale = nn.softplus(scale)

        pre = inputs
        pre *= scale
        pre = pre + bias
        y = jax.nn.selu(pre)

        # Compute moments based in learned scale/bias.
        if norm_grad_block:
            scale = jax.lax.stop_gradient(scale)
            bias = jax.lax.stop_gradient(bias)
        std = scale
        mean = bias
        var = std**2

        # SELU magic numbers from the SELU paper [2] and jax.nn.selu.
        alpha = 1.6732632423543772848170429916717
        selu_scale = 1.0507009873554804934193349852946
        selu_threshold = 0

        # Compute moments of the left and right sides of the split Gaussian
        # (x <= 0 and x > 0).
        t = (selu_threshold - mean) / std
        # Clip t to [-3, 3] to limit how far the threshold can sit from the
        # mean (keeps the tail mass z away from 0 and 1).
        t = jnp.maximum(-3, jnp.minimum(3, t))
        z = 1 - norm.cdf(t)
        new_mean_right = (mean + (std * norm.pdf(t)) / z)
        new_var_right = (var) * (1 + t * norm.pdf(t) / z -
                                 (norm.pdf(t) / z)**2)

        l_scale = jnp.exp(mean)  # Log normal scale parameter = exp(mean)
        log_scale = mean
        min_log = -5

        # Compute truncated log normal statistics for left part of SELU.
        # TODO(basv): improve numerical errors with jnp.expm1?
        a1 = .5 * (1. /
                   (std + 1e-5)) * jnp.sqrt(2) * (-var + min_log - log_scale)
        a2 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (var + log_scale -
                                                       selu_threshold)
        a3 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (min_log - log_scale)
        a4 = .5 * (1. /
                   (std + 1e-5)) * jnp.sqrt(2) * (-selu_threshold + log_scale)
        a5 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (-2 * var + min_log -
                                                       log_scale)
        a6 = .5 * (1. / (std + 1e-5)) * jnp.sqrt(2) * (2 * var + log_scale -
                                                       selu_threshold)
        e_a1 = erf(a1)
        e_a2 = erf(a2)
        e_a3 = erf(a3)
        e_a4 = erf(a4)
        e_a5 = erf(a5)
        e_a6 = erf(a6)
        exp_var = jnp.exp(var)

        # Equation 18 [1].
        trunc_lognorm_mean = (l_scale * jnp.exp(.5 * var) *
                              (e_a1 + e_a2)) / (e_a3 + e_a4 + 1e-5)
        trunc_lognorm_mean_m1 = trunc_lognorm_mean - 1  # selu uses e^x - 1
        # Equation 20 [1].
        n = exp_var * (e_a3 * e_a5 * exp_var + e_a3 * e_a6 * exp_var +
                       e_a4 * e_a5 * exp_var + e_a4 * e_a6 * exp_var -
                       e_a1**2 - 2 * e_a1 * e_a2 - e_a2**2) * l_scale**2
        # Equation 19 [1].
        trunc_lognorm_var = n / ((e_a3 + e_a4 + 1e-5)**2)

        selu_mean = alpha * trunc_lognorm_mean_m1
        selu_var = alpha**2 * trunc_lognorm_var

        # Compute mixture mean multiplied by selu_scale.
        new_mean = (selu_mean * (1 - z) + new_mean_right * z)

        # Compute mixture variance.
        new_var = z * (new_var_right + new_mean_right**2 - new_mean**2)
        new_var += (1 - z) * (selu_var + selu_mean**2 - new_mean**2)
        new_mean = selu_scale * new_mean
        new_std = jnp.sqrt(new_var + 1e-5) * selu_scale
        new_var *= selu_scale**2

        if norm_grad_block:
            new_mean = jax.lax.stop_gradient(new_mean)
            new_std = jax.lax.stop_gradient(new_std)

        new_std = jnp.maximum(1e-3, new_std)

        # Normalize y.
        y_norm = y
        y_norm -= new_mean
        y_norm /= new_std
        return y_norm
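
The right-branch statistics above (`new_mean_right`, `new_var_right`) are the standard moments of a lower-truncated normal distribution; a standalone sketch of those formulas, with illustrative names not taken from the original module:

import jax.numpy as jnp
from jax.scipy.stats import norm

def lower_truncated_normal_moments(mean, std, threshold):
    # Moments of X | X > threshold, for X ~ N(mean, std**2).
    t = (threshold - mean) / std
    z = 1.0 - norm.cdf(t)                      # probability mass above the threshold
    trunc_mean = mean + std * norm.pdf(t) / z
    trunc_var = std**2 * (1.0 + t * norm.pdf(t) / z - (norm.pdf(t) / z) ** 2)
    return trunc_mean, trunc_var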
Example 6
 def func(rng, X):
     X = jnp.asarray(X, dtype=space.dtype)   # ensure ndarray
     X = jnp.reshape(X, (-1, *space.shape))  # ensure batch axis
     X = jnp.clip(X, space.low, space.high)  # clip to be safe
     return X
Example 7
import matplotlib.pyplot as plt
import numpy as np
import fenics as fn
from jax import config

config.update("jax_enable_x64", True)
fn.set_log_level(fn.LogLevel.ERROR)

# Create mesh
n = 30
mesh = fn.UnitIntervalMesh(n)

# Define discrete function spaces and functions
V = fn.FunctionSpace(mesh, "CG", 2)

v = fn.TestFunction(V)
nu = fn.Constant(0.0001)
a = fn.Constant(0.4)
timestep = np.asarray([0.05])
bcs = [fn.DirichletBC(V, 0.0, "on_boundary")]


def Dt(u, u_prev, timestep):
    return (u - u_prev) / timestep


solve_templates = (fn.Function(V), fn.Constant(0.0))
assemble_templates = (fn.Function(V),)


@build_jax_solve_eval(solve_templates)
def fenics_solve(u_prev, timestep):
    # Define and solve one step of the Burgers equation
    u = fn.Function(V)
Example 8
 def push(self, arr: Array) -> int:
   self._buffers.append(jnp.asarray(arr))  # type: ignore
   return len(self._buffers) - 1
Example 9
 def f3(x, y):
   r1 = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1]])
   r2 = jnp.asarray([4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
   return r1, r2
Example 10
import time
from functools import partial

import numpy as np
import tables as tb
from jax import device_count

#######################################################################
# Reading in Data
#######################################################################

NS = 32
N = 100
lr = 0.0001
theta = np.linspace(0, 2 * np.pi, NS + 1)
nn = np.load("nn.npy")
sg = np.load("sg.npy")
r_surf = np.load("r_surf.npy")
with tb.open_file("coils.hdf5", "r") as f:
    p = np.asarray(f.root.coilSeries[:, :, :])

num_devices = device_count()
assert nn.shape[0] % num_devices == 0
nn = np.reshape(
    nn, (num_devices, nn.shape[0] // num_devices, nn.shape[1], nn.shape[2]))
r_surf = np.reshape(r_surf, (num_devices, r_surf.shape[0] // num_devices,
                             r_surf.shape[1], r_surf.shape[2]))
sg = np.reshape(sg, (num_devices, sg.shape[0] // num_devices, sg.shape[1]))

#######################################################################
# Calculating Objective Function
#######################################################################


def biot_savart(r_eval, dl, l):
Example 11
    def __init__(
        self,
        vocab_size: Optional[int] = None,
        embed_dim: Optional[int] = None,
        embedding_matrix: Optional[jnp.ndarray] = None,
        w_init: Optional[hk.initializers.Initializer] = None,
        lookup_style: Union[str, hk.EmbedLookupStyle] = "ARRAY_INDEX",
        name: Optional[str] = None,
    ):
        """Constructs an Embed module.

    Args:
      vocab_size: The number of unique tokens to embed. If not provided, an
        existing vocabulary matrix from which ``vocab_size`` can be inferred
        must be provided as ``embedding_matrix``.
      embed_dim: Number of dimensions to assign to each embedding. If an
        existing vocabulary matrix initializes the module, this should not be
        provided as it will be inferred.
      embedding_matrix: A matrix-like object equivalent in size to
        ``[vocab_size, embed_dim]``. If given, it is used as the initial value
        for the embedding matrix and neither ``vocab_size`` nor ``embed_dim``
        need be given. If they are given, their values are checked to be
        consistent with the dimensions of ``embedding_matrix``.
      w_init: An initializer for the embeddings matrix. As a default,
        embeddings are initialized via a truncated normal distribution.
      lookup_style: One of the enum values of :class:`EmbedLookupStyle`
        determining how to access the value of the embeddings given an ID.
        Regardless of the style, the input should be a dense array of integer
        values representing ids. This setting only changes how the module
        internally maps those ids to embeddings; the result is the same, but
        the speed and memory tradeoffs differ. It defaults to numpy-style
        array indexing. This value is only the default for the module, and can
        be overridden at any given invocation in :meth:`__call__`.
      name: Optional name for this module.

    Raises:
      ValueError: If none of ``embed_dim``, ``embedding_matrix`` and
        ``vocab_size`` are supplied, or if ``embedding_matrix`` is supplied
        and ``embed_dim`` or ``vocab_size`` is not consistent with the
        supplied matrix.
    """
        super().__init__(name=name)
        if embedding_matrix is None and not (vocab_size and embed_dim):
            raise ValueError(
                "hk.Embed must be supplied either with an initial `embedding_matrix` "
                "or with `embed_dim` and `vocab_size`.")

        if embedding_matrix is not None:
            embedding_matrix = jnp.asarray(embedding_matrix)
            if vocab_size and embedding_matrix.shape[0] != vocab_size:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `vocab_size` of "
                    f"{vocab_size} was not consistent with its shape "
                    f"{embedding_matrix.shape}.")
            if embed_dim and embedding_matrix.shape[1] != embed_dim:
                raise ValueError(
                    "An `embedding_matrix` was supplied but the `embed_dim` of "
                    f"{embed_dim} was not consistent with its shape "
                    f"{embedding_matrix.shape}.")
            w_init = lambda _, __: embedding_matrix
            vocab_size = embedding_matrix.shape[0]
            embed_dim = embedding_matrix.shape[1]

        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.lookup_style = lookup_style
        self.w_init = w_init or hk.initializers.TruncatedNormal()
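
A minimal usage sketch, assuming this constructor is Haiku's `hk.Embed`:

import jax
import jax.numpy as jnp
import haiku as hk

def forward(ids):
    return hk.Embed(vocab_size=100, embed_dim=16)(ids)

forward = hk.without_apply_rng(hk.transform(forward))
ids = jnp.array([1, 2, 3])
params = forward.init(jax.random.PRNGKey(0), ids)
embeddings = forward.apply(params, ids)  # shape (3, 16)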
Example 12
    def apply(self,
              inputs,
              filters,
              kernel_size,
              block_size,
              strides=None,
              padding='SAME',
              input_dilation=None,
              kernel_dilation=None,
              feature_group_count=1,
              bias=True,
              dtype=jnp.float32,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros):
        """Applies a convolution to the inputs.

    Args:
      inputs: input data with dimensions (batch, spatial_dims..., features).
      filters: number of convolution filters.
      kernel_size: shape of the convolutional kernel.
      block_size: shape of space-to-depth blocks.
      strides: a sequence of `n` integers, representing the inter-window
        strides.
      padding: either the string `'SAME'`, the string `'VALID'`, or a sequence
        of `n` `(low, high)` integer pairs that give the padding to apply before
        and after each spatial dimension.
      input_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of `inputs`.
        Convolution with input dilation `d` is equivalent to transposed
        convolution with stride `d`.
      kernel_dilation: `None`, or a sequence of `n` integers, giving the
        dilation factor to apply in each spatial dimension of the convolution
        kernel. Convolution with kernel dilation is also known as 'atrous
        convolution'.
      feature_group_count: integer, default 1. If specified divides the input
        features into groups.
      bias: whether to add a bias to the output (default: True).
      dtype: the dtype of the computation (default: float32).
      precision: numerical precision of the computation; see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the convolutional kernel.
      bias_init: initializer for the bias.
    Returns:
      The convolved data.
    """
        inputs = jnp.asarray(inputs, dtype)

        if strides is None:
            strides = block_size
        assert strides[0] % block_size[0] == 0
        assert strides[1] % block_size[1] == 0
        strides = tuple(s // b for s, b in zip(strides, block_size))

        # create kernel as if there were no space to depth
        batch_size, h, w, features = inputs.shape
        original_input_shape = (batch_size, h * block_size[0],
                                w * block_size[1],
                                features // block_size[0] // block_size[1])
        in_features = original_input_shape[-1]
        assert in_features % feature_group_count == 0
        kernel_shape = kernel_size + (in_features // feature_group_count,
                                      filters)
        kernel = self.param('kernel', kernel_shape, kernel_init)
        kernel = jnp.asarray(kernel, dtype)

        # zero-pad kernel to multiple of block size (e.g. 7x7 --> 8x8)
        h_blocks, h_ragged = divmod(kernel_size[0], block_size[0])
        h_blocks = h_blocks + 1
        if h_ragged != 0:
            kernel = jnp.pad(kernel,
                             pad_width=[[block_size[0] - h_ragged, 0], [0, 0],
                                        [0, 0], [0, 0]],
                             mode='constant',
                             constant_values=0.)
        w_blocks, w_ragged = divmod(kernel_size[1], block_size[1])
        w_blocks = w_blocks + 1
        if w_ragged != 0:
            kernel = jnp.pad(kernel,
                             pad_width=[[0, 0], [block_size[1] - w_ragged, 0],
                                        [0, 0], [0, 0]],
                             mode='constant',
                             constant_values=0.)

        # transform kernel following space-to-depth logic: http://shortn/_9YvHW96xPJ
        kernel = jnp.reshape(kernel, [
            h_blocks, block_size[0], w_blocks, block_size[1],
            in_features // feature_group_count, filters
        ])
        kernel = jnp.transpose(kernel, [0, 2, 1, 3, 4, 5])
        kernel = jnp.reshape(kernel, [h_blocks, w_blocks, features, filters])
        kernel = kernel.astype(inputs.dtype)

        dimension_numbers = nn.linear._conv_dimension_numbers(inputs.shape)  # pylint: disable=protected-access

        y = lax.conv_general_dilated(lhs=inputs,
                                     rhs=kernel,
                                     window_strides=strides,
                                     padding=padding,
                                     lhs_dilation=input_dilation,
                                     rhs_dilation=kernel_dilation,
                                     dimension_numbers=dimension_numbers,
                                     feature_group_count=feature_group_count,
                                     precision=precision)
        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
            y = y + bias
        return y
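
For context, a sketch of the space-to-depth transform this convolution assumes was already applied to its inputs (the helper name is illustrative, not part of the original module):

import jax.numpy as jnp

def space_to_depth(x, block_size):
    # (N, H, W, C) -> (N, H/bh, W/bw, bh*bw*C)
    n, h, w, c = x.shape
    bh, bw = block_size
    x = jnp.reshape(x, (n, h // bh, bh, w // bw, bw, c))
    x = jnp.transpose(x, (0, 1, 3, 2, 4, 5))
    return jnp.reshape(x, (n, h // bh, w // bw, bh * bw * c))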
Example 13
 def near_singular_inverse(N=5, eps=1E-40):
     X = rng((N, N), dtype='float64')
     X = jnp.asarray(X)
     X = X.at[-1].mul(eps)
     return jnp.linalg.inv(X)
Example 14
 def nan_err(args):
     return jnp.asarray(jnp.nan, dtype=stat_dtype), jnp.asarray(
         jnp.nan, dtype=stat_dtype
     )
Example 15
def _check_synced(pytree):
    mins = jax.lax.pmin(pytree, axis_name='batch')
    equals = jax.tree_multimap(jnp.array_equal, pytree, mins)
    return jnp.all(jnp.asarray(jax.tree_leaves(equals)))
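
A minimal usage sketch with hypothetical shapes: `_check_synced` is meant to be called inside `jax.pmap` with the matching axis name (note the snippet uses the older `jax.tree_multimap`/`jax.tree_leaves` API):

import jax
import jax.numpy as jnp

replicated = jnp.zeros((jax.local_device_count(), 3))
all_synced = jax.pmap(_check_synced, axis_name='batch')(replicated)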
Example 16
 def f1(x, y):
   r = jnp.asarray([x[0] * y[0], 5 * x[2] * y[1], 4 * x[1] ** 2 - 2 * x[2], x[2] * jnp.sin(x[0])])
   return r
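
A minimal sketch, with hypothetical inputs, of how such a vector-valued function is typically differentiated in JAX:

import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([4.0, 5.0])
jac_x = jax.jacobian(f1, argnums=0)(x, y)  # 4x3 Jacobian of f1 with respect to x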
Example 17
 def func(rng, X):
     X = jnp.asarray(X)
     X = jax.nn.one_hot(X, space.n)     # one-hot encoding
     X = jnp.reshape(X, (-1, space.n))  # ensure batch axis
     return X
Example 18
 def loss_fn(p):
     return jnp.asarray(p, jnp.float16)**2
Example 19
 def func(rng, X):
     X = jnp.asarray(X, dtype=jnp.float32)  # ensure ndarray
     X = jnp.reshape(X, (-1, space.n))      # ensure batch axis
     return X
Example 20
 def __init__(self, input_dim, active_dims=None):
     self.input_dim = input_dim
     if active_dims is None:
         self.active_dims = jnp.arange(input_dim)
     else:
         self.active_dims = jnp.asarray(active_dims, dtype=int)  # np.int was removed from NumPy
Example 21
def ndim_at_least(x, num_dims):
  if x is None:
    return False
  x = jnp.asarray(x)
  return len(x.shape) >= num_dims
Example 22
 def grad_fn(g):
     g_logits = jnp.expand_dims(g,
                                axis=-1) * (exp_shifted / sum_exp - targets)
     return jnp.asarray(g_logits,
                        logits.dtype), jnp.asarray(g, targets.dtype)
Example 23
def rolled_loop_step(i, state):
  x, ks, rotations = state
  for r in rotations[0]:
    x = apply_round(x, r)
  new_x = [x[0] + ks[0], x[1] + ks[1] + jnp.asarray(i + 1, dtype=np.uint32)]
  return new_x, rotate_list(ks), rotate_list(rotations)
Example 24
 def f(*args):
     return jnp.asarray(args).sum()
Example 25
    def apply(self,
              inputs,
              features,
              bias=True,
              scale=False,
              dtype=jnp.float32,
              precision=None,
              bias_init=initializers.zeros,
              scale_init=None,
              softplus=True):
        if scale_init is None:
            if softplus:
                scale_init = new_initializers.init_softplus_ones
            else:
                scale_init = initializers.ones
        if bias:
            bias = self.param('bias', (features, ), bias_init)
            bias = jnp.asarray(bias, dtype)
        else:
            bias = 0.

        if scale:
            scale = self.param('scale', (features, ), scale_init)
            scale = jnp.asarray(scale, dtype)
        else:
            scale = float(
                new_initializers.inv_softplus(1.0)) if softplus else 1.0

        if softplus:
            scale = nn.softplus(scale)

        y = inputs
        y *= scale
        y = y + bias
        relu_threshold = 0.0
        y = jnp.maximum(relu_threshold, y)

        # Normalize y analytically.
        mean = bias
        std = scale
        var = std**2
        # Kaiming initialized weights + bias + TLU
        # = mixture of delta peak + left-truncated gaussian
        # https://en.wikipedia.org/wiki/Mixture_distribution#Moments
        # https://en.wikipedia.org/wiki/Truncated_normal_distribution#One_sided_truncation_(of_lower_tail)[4]
        norm = jax.scipy.stats.norm
        t = (relu_threshold - mean) / std

        # If the distribution lies 4 stdev below the threshold, cap at t=4.
        t = jnp.minimum(4, t)
        z = 1 - norm.cdf(t)

        new_mean_non_cut = mean + (std * norm.pdf(t)) / z
        new_var_non_cut = (var) * (1 + t * norm.pdf(t) / z -
                                   (norm.pdf(t) / z)**2)

        # Psi function.
        # Compute mixture mean.
        new_mean = new_mean_non_cut * z + relu_threshold * norm.cdf(t)
        # Compute mixture variance.
        new_var = z * (new_var_non_cut + new_mean_non_cut**2 - new_mean**2)
        new_var += (1 - z) * (0 + relu_threshold**2 - new_mean**2)
        new_std = jnp.sqrt(new_var + 1e-8)
        new_std = jnp.maximum(0.01, new_std)

        # Normalize y.
        y_norm = y
        y_norm -= new_mean
        y_norm /= new_std
        return y_norm
Example 26
def step(state, t, params, D, stimuli, dt, dx):
    # v, w, u, at, max_du = state
    v, w, u = state

    # apply stimulus
    u = np.where(params["current_stimulus"], u, stimulate(t, u, stimuli))

    # apply boundary conditions
    v = neumann(v)
    w = neumann(w)
    u = neumann(u)

    # gate variables
    p = np.greater_equal(u, params["V_c"])
    q = np.greater_equal(u, params["V_v"])
    tau_v_minus = (1 - q) * params["tau_v1_minus"] + q * params["tau_v2_minus"]

    d_v = ((1 - p) * (1 - v) / tau_v_minus) - ((p * v) / params["tau_v_plus"])
    d_w = ((1 - p) * (1 - w) / params["tau_w_minus"]) - ((p * w) / params["tau_w_plus"])

    # currents
    J_fi = - v * p * (u - params["V_c"]) * (1 - u) / params["tau_d"]
    J_so = (u * (1 - p) / params["tau_0"]) + (p / params["tau_r"])
    J_si = - (w * (1 + np.tanh(params["k"] * (u - params["V_csi"])))) / (2 * params["tau_si"])

    I_ion = -(J_fi + J_so + J_si) / params["Cm"]

    # voltage01
#     u_x, u_y = np.gradient(u)
    u_x, u_y = gradient(u, 0), gradient(u, 1)
    u_x /= dx
    u_y /= dx
#     u_xx = np.gradient(u_x, axis=0)
#     u_yy = np.gradient(u_y, axis=1)
    u_xx = gradient(u_x, 0)
    u_yy = gradient(u_y, 1)
    u_xx /= dx
    u_yy /= dx
#     D_x, D_y = np.gradient(D)
#     D_x /= dx
#     D_y /= dx

    # Kostas ---------
    D_x, D_y = gradient(D, 0), gradient(D, 1)
    D_x /= dx
    D_y /= dx
    extra_term = D_x * u_x + D_y * u_y

    current_stimuli = np.zeros(u.shape)
    current_stimuli = np.where(params["current_stimulus"],
                               stimulate(t, current_stimuli, stimuli),
                               current_stimuli)
    # Kostas ---------

    d_u = D * (u_xx + u_yy) + extra_term + I_ion + current_stimuli
    
    # checking du for activation time update
    # at = np.where(np.greater_equal(d_u,max_du), t, at)
    # max_du = np.where(np.greater_equal(d_u,max_du), d_u, max_du)

    # euler update
    v += d_v * dt
    w += d_w * dt
    u += d_u * dt

    
    return np.asarray((v, w, u))
Example 27
    def __call__(self, inputs: Array) -> Array:
        """Applies a linear transformation to the inputs along multiple dimensions.

    Args:
      inputs: The nd-array to be transformed.

    Returns:
      The transformed input.
    """
        inputs = jnp.asarray(inputs, self.dtype)

        ndim = inputs.ndim
        n_batch_dims = len(self.batch_dims)
        axis = _normalize_axes(self.axis, ndim)
        batch_dims = _normalize_axes(self.batch_dims, ndim)
        n_axis, n_features = len(axis), len(self.features)

        def kernel_init_wrap(rng, shape, dtype=jnp.float32):
            size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
            flat_shape = (
                np.prod(shape[n_batch_dims:n_axis + n_batch_dims]),
                np.prod(shape[-n_features:]),
            )
            kernel = jnp.concatenate([
                self.kernel_init(rng, flat_shape, dtype)
                for _ in range(size_batch_dims)
            ],
                                     axis=0)
            return jnp.reshape(kernel, shape)

        batch_shape = tuple([inputs.shape[ax] for ax in batch_dims])
        kernel_shape = tuple([inputs.shape[ax] for ax in axis]) + self.features
        kernel = self.param('kernel', kernel_init_wrap,
                            batch_shape + kernel_shape)
        kernel = jnp.asarray(kernel, self.dtype)

        batch_ind = tuple(range(n_batch_dims))
        contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims))
        out = lax.dot_general(inputs,
                              kernel,
                              ((axis, contract_ind), (batch_dims, batch_ind)),
                              precision=self.precision)
        if self.use_bias:

            def bias_init_wrap(rng, shape, dtype=jnp.float32):
                size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32)
                flat_shape = (np.prod(shape[-n_features:]), )
                bias = jnp.concatenate([
                    self.bias_init(rng, flat_shape, dtype)
                    for _ in range(size_batch_dims)
                ],
                                       axis=0)
                return jnp.reshape(bias, shape)

            bias = self.param('bias', bias_init_wrap,
                              batch_shape + self.features)

            # Reshape bias for broadcast.
            expand_dims = sorted(
                set(range(inputs.ndim)) - set(axis) - set(batch_dims))
            for ax in expand_dims:
                bias = jnp.expand_dims(bias, ax)
            bias = jnp.asarray(bias, self.dtype)
            out = out + bias
        return out
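
A minimal usage sketch, assuming this `__call__` belongs to `flax.linen.DenseGeneral`:

import jax
import jax.numpy as jnp
import flax.linen as nn

model = nn.DenseGeneral(features=(4, 5), axis=(-2, -1))
x = jnp.ones((2, 3, 6))
variables = model.init(jax.random.PRNGKey(0), x)
y = model.apply(variables, x)  # shape (2, 4, 5)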
Example 28
def global_norm(pytree):
    return jnp.sqrt(
        jnp.sum(
            jnp.asarray(
                [jnp.sum(jnp.square(x)) for x in jax.tree_leaves(pytree)])))
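
A minimal usage sketch with a hypothetical parameter pytree:

import jax.numpy as jnp

params = {"w": jnp.ones((2, 3)), "b": jnp.zeros(3)}
print(global_norm(params))  # sqrt(6.0) ~ 2.449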
Example 29
 def make_from_covariance(
         diagonal: Union[hints.Array,
                         Sequence[float]]) -> "DiagonalGaussian":
     return DiagonalGaussian(sqrt_precision_diagonal=1.0 /
                             jnp.sqrt(jnp.asarray(diagonal)))
Example 30
    def __call__(self, x, training: bool):
        """Normalizes the input using batch statistics.
        Args:
            x: the input to be normalized.
        Returns:
            Normalized inputs (the same shape as inputs).
        """
        x = jnp.asarray(x, self.dtype)
        axis = self.axis if isinstance(self.axis, tuple) else (self.axis, )
        axis = _absolute_dims(x.ndim, axis)
        feature_shape = tuple(d if i in axis else 1
                              for i, d in enumerate(x.shape))
        reduced_feature_shape = tuple(d for i, d in enumerate(x.shape)
                                      if i in axis)
        reduction_axis = tuple(i for i in range(x.ndim) if i not in axis)

        # we detect if we're in initialization via empty variable tree.
        initializing = not self.has_variable('batch_stats', 'mean')

        ra_mean = self.variable('batch_stats', 'mean',
                                lambda s: jnp.zeros(s, jnp.float32),
                                reduced_feature_shape)
        ra_var = self.variable('batch_stats', 'var',
                               lambda s: jnp.ones(s, jnp.float32),
                               reduced_feature_shape)

        if not training:
            mean, var = ra_mean.value, ra_var.value
        else:
            mean = jnp.mean(x, axis=reduction_axis, keepdims=False)
            var = jnp.mean(lax.abs(x - mean),
                           axis=reduction_axis,
                           keepdims=False) * jnp.sqrt(jnp.pi / 2)
            if self.axis_name is not None and not initializing:
                concatenated_mean = jnp.concatenate([mean, var])
                mean, var = jnp.split(
                    lax.pmean(concatenated_mean,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups), 2)

            if not initializing:
                ra_mean.value = self.momentum * ra_mean.value + (
                    1 - self.momentum) * mean
                ra_var.value = self.momentum * ra_var.value + (
                    1 - self.momentum) * var

        mean = jnp.asarray(mean, self.dtype)
        var = jnp.asarray(var, self.dtype)
        y = x - mean.reshape(feature_shape)
        mul = lax.reciprocal(var + self.epsilon)
        if self.use_scale:
            scale = self.param('scale', self.scale_init,
                               reduced_feature_shape).reshape(feature_shape)
            scale = jnp.asarray(scale, self.dtype)
            mul = mul * scale
        y = y * mul
        if self.use_bias:
            bias = self.param('bias', self.bias_init,
                              reduced_feature_shape).reshape(feature_shape)
            bias = jnp.asarray(bias, self.dtype)
            y = y + bias
        return jnp.asarray(y, self.dtype)