def log_prob(self, value):
    """Return an all-zero log density shaped like the broadcast batch shape.

    The batch part of ``value`` (every dimension to the left of
    ``self.event_shape``) is broadcast against ``self.batch_shape``, and a
    zeros array of the resulting shape is returned.
    """
    n_batch_dims = jnp.ndim(value) - len(self.event_shape)
    value_batch_shape = jnp.shape(value)[:n_batch_dims]
    out_shape = lax.broadcast_shapes(value_batch_shape, self.batch_shape)
    return jnp.zeros(out_shape)
def testScalarInstantiation(self):
    """Calling each jnp scalar type on 1 must give a 0-d xla.DeviceArray of that dtype."""
    scalar_types = (jnp.bool_, jnp.int32, jnp.bfloat16, jnp.float32, jnp.complex64)
    for scalar_type in scalar_types:
        instance = scalar_type(1)
        self.assertEqual(instance.dtype, jnp.dtype(scalar_type))
        self.assertIsInstance(instance, xla.DeviceArray)
        self.assertEqual(0, jnp.ndim(instance))
def log_prob(self, value):
    """Evaluate the base distribution's log_prob and broadcast it to the batch shape.

    The target shape is ``self.batch_shape`` broadcast with the leading
    (non-event) dimensions of ``value``; the number of event dimensions is
    ``self.event_dim`` and is clamped so the batch part never has negative rank.
    """
    n_batch_dims = max(jnp.ndim(value) - self.event_dim, 0)
    value_batch_shape = jnp.shape(value)[:n_batch_dims]
    target_shape = lax.broadcast_shapes(self.batch_shape, value_batch_shape)
    base_log_prob = self.base_dist.log_prob(value)
    return jnp.broadcast_to(base_log_prob, target_shape)
def _is_batched(arg): return np.ndim(arg) > 0
def _validate_sample(self, value):
    """Validate ``value`` and reduce the parent's mask to one entry per batch element.

    If the mask returned by the parent class has more dimensions than the
    batch part of ``value``, all trailing (event) dimensions are flattened
    and combined with a logical AND.
    """
    mask = super(ImproperUniform, self)._validate_sample(value)
    n_batch_dims = jnp.ndim(value) - len(self.event_shape)
    if n_batch_dims >= jnp.ndim(mask):
        return mask
    flat_mask = jnp.reshape(mask, jnp.shape(mask)[:n_batch_dims] + (-1,))
    return jnp.all(flat_mask, -1)
def solve_max(
    self,
    inner_dual_vars: Any,
    opt_instance: InnerVerifInstance,
    key: jnp.array,
    step: int,
) -> jnp.array:
    """Solve the maximisation problem of `opt_instance` with projected gradient ascent.

    The objective built by `self.build_spec` is maximised (by minimising its
    negation) over the box given by `opt_instance.bounds[0]`, optionally after
    rescaling that box to [-1, 1] when `self._normalize` is set.

    Args:
      inner_dual_vars: Dual variables for the inner maximisation (not read by
        this implementation).
      opt_instance: Verification instance that defines the optimization
        problem to be solved.
      key: Jax PRNG key.
      step: outer optimization iteration number.

    Returns:
      final_value: value of the objective at the point found by PGA; the point
        itself is wrapped in `stop_gradient`, so gradients flow only through
        the Lagrange multipliers.

    Raises:
      ValueError: if `opt_instance` uses different Lagrangian forms on its
        inputs and outputs.
    """
    if not opt_instance.same_lagrangian_form_pre_post:
        # Note the trailing space: adjacent literals are concatenated verbatim.
        raise ValueError(
            'Different lagrangian forms on inputs and outputs not '
            'supported')
    # only supporting adversarial robustness specification for now
    # when affine_before_relu and logits layer.
    affine_before_relu = opt_instance.affine_before_relu
    assert not (opt_instance.spec_type == verify_utils.SpecType.UNCERTAINTY and
                opt_instance.is_last and affine_before_relu)

    # some renaming to simplify variable names
    if affine_before_relu:
        lower_bound = opt_instance.bounds[0].lb
        upper_bound = opt_instance.bounds[0].ub
    else:
        lower_bound = opt_instance.bounds[0].lb_pre
        upper_bound = opt_instance.bounds[0].ub_pre
    assert lower_bound.shape[0] == 1, 'Batching across samples not supported'

    if self._normalize:
        # Optimise over [-1, 1] and map back to [lower, upper] via normalize_fn.
        center = .5 * (upper_bound + lower_bound)
        radius = .5 * (upper_bound - lower_bound)
        normalize_fn = lambda x: x * radius + center
        lower_bound = -jnp.ones_like(lower_bound)
        upper_bound = jnp.ones_like(lower_bound)
    else:
        normalize_fn = lambda x: x

    duals_pre = opt_instance.lagrange_params_pre
    duals_post = opt_instance.lagrange_params_post
    # dual variables never used for grad tracing
    duals_pre_nondiff = jax.lax.stop_gradient(duals_pre)
    duals_post_nondiff = jax.lax.stop_gradient(duals_post)

    # Define the loss function.
    if (opt_instance.spec_type == verify_utils.SpecType.UNCERTAINTY and
            opt_instance.is_last):
        # Last layer here isn't the final spec layer, treat like other layers
        new_opt_instance = dataclasses.replace(opt_instance, is_last=False)
    else:
        new_opt_instance = opt_instance
    softmax = (opt_instance.spec_type == verify_utils.SpecType.ADVERSARIAL_SOFTMAX
               and opt_instance.is_last)
    obj = self.build_spec(new_opt_instance, step, softmax)

    def loss_pgd(x):
        # Expects x without batch dimension, as vmap adds batch-dimension.
        x = jnp.reshape(x, lower_bound.shape)
        x = normalize_fn(x)
        v = obj(x, duals_pre_nondiff, duals_post_nondiff)
        # Negated because the optimizer minimises.
        return -v

    loss_pgd = jax.vmap(loss_pgd)

    # Compute shape for compatibility with blackbox 'square' attack.
    if jnp.ndim(lower_bound) == 2 and self._method == 'square':
        d = lower_bound.shape[1]
        # Largest h <= round(sqrt(d)) dividing d, so d reshapes to an h x w image.
        max_h = int(np.round(np.sqrt(d)))
        for h in range(max_h, 0, -1):
            w, ragged = divmod(d, h)
            if ragged == 0:
                break
        assert d == h * w
        shape = [1, h, w, 1]
    else:
        shape = lower_bound.shape

    # Optimization.
    init_x = (upper_bound + lower_bound) / 2
    init_x = jnp.reshape(init_x, shape)
    optimizer = self._build_optimizer(self._method, self._n_iter, self._lr,
                                      jnp.reshape(lower_bound, shape),
                                      jnp.reshape(upper_bound, shape))
    if self._n_restarts > 1:
        optimizer = optimizer_module.Restarted(
            optimizer, restarts_using_tiling=self._n_restarts)
    key, next_key = jax.random.split(key)
    x = optimizer(loss_pgd, key, init_x)
    if self._finetune_n_iter > 0:
        optimizer = self._build_optimizer(self._finetune_method,
                                          self._finetune_n_iter,
                                          self._finetune_lr,
                                          jnp.reshape(lower_bound, shape),
                                          jnp.reshape(upper_bound, shape))
        x = optimizer(loss_pgd, next_key, x)

    # compute final value and return it
    x = normalize_fn(jnp.reshape(x, lower_bound.shape))
    final_value = obj(
        jax.lax.stop_gradient(x),  # non-differentiable
        duals_pre, duals_post  # differentiable
    )
    return final_value
def __init__(self, probs, validate_args=None):
    """Store ``probs`` and initialise with batch shape ``probs.shape[:-1]``.

    Raises:
        ValueError: if ``probs`` is a scalar (the last axis indexes categories).
    """
    if jnp.ndim(probs) == 0:
        raise ValueError("`probs` parameter must be at least one-dimensional.")
    self.probs = probs
    batch_shape = jnp.shape(self.probs)[:-1]
    super(CategoricalProbs, self).__init__(batch_shape=batch_shape,
                                           validate_args=validate_args)
def broadcast_to(self, shape):
    """Broadcast the key array to ``shape``; a scalar ``shape`` is treated as ``(shape,)``.

    The trailing ``self.impl.key_shape`` dimensions of the underlying key
    data are kept after the requested shape.
    """
    dims = (shape,) if jnp.ndim(shape) == 0 else shape
    full_shape = (*dims, *self.impl.key_shape)
    broadcast_keys = jnp.broadcast_to(self._keys, full_shape)
    return PRNGKeyArray(self.impl, broadcast_keys)
def __post_init__(self):
    """Run the parent hook, then flatten ``samples`` to 2-D (keeping its last axis).

    ``object.__setattr__`` is used, presumably because the dataclass is frozen.
    """
    super().__post_init__()
    samples = self.samples
    if jnp.ndim(samples) == 2:
        return
    flattened = samples.reshape((-1, samples.shape[-1]))
    object.__setattr__(self, "samples", flattened)
def _update_batched(self, x): axis = self.info.axis axis_diff = np.ndim(x) - len(self.info.shape) axis = tuple(range(axis_diff)) + tuple(a + axis_diff for a in axis) return self._update_axis(x, axis)
def grad_expect_operator_Lrho2(
    model_apply_fun: Callable,
    mutable: bool,
    parameters: PyTree,
    model_state: PyTree,
    σ: jnp.ndarray,
    σp: jnp.ndarray,
    mels: jnp.ndarray,
) -> Tuple[Stats, PyTree, PyTree]:
    """Compute statistics of ⟨L†L⟩ and its gradient with respect to `parameters`.

    Args:
        model_apply_fun: function applying the model, called as
            ``model_apply_fun({"params": w, **model_state}, σ, ...)``.
        mutable: passed through to ``model_apply_fun``; any value other than
            ``False`` means the model returns ``(output, state)``, and only the
            output is kept.
        parameters: model parameters (pytree).
        model_state: non-parameter model variables (pytree).
        σ: samples; anything that is not 2-D is flattened to
            ``(-1, σ.shape[-1])``.
        σp: connected configurations, forwarded to the local-value kernel.
        mels: matrix elements, forwarded to the local-value kernel.

    Returns:
        ``(LdagL_stats, LdagL_grad, model_state)``: the statistics of
        ``|Lρ|²``, the gradient pytree (cast to each parameter's dtype, with
        the imaginary part dropped for real parameters), and the unchanged
        ``model_state``.
    """
    σ_shape = σ.shape
    if jnp.ndim(σ) != 2:
        σ = σ.reshape((-1, σ_shape[-1]))

    n_samples_node = σ.shape[0]

    has_aux = mutable is not False
    if not has_aux:
        logpsi = lambda w, σ: model_apply_fun({"params": w, **model_state}, σ)
    else:
        # TODO: output the mutable state
        logpsi = lambda w, σ: model_apply_fun(
            {"params": w, **model_state}, σ, mutable=mutable
        )[0]

    Lρ, der_loc_vals = _local_values_and_grads_notcentered_kernel(
        logpsi, parameters, σp, mels, σ
    )
    # _local_values_and_grads_notcentered_kernel returns a loc_val that is conjugated
    Lρ = jnp.conjugate(Lρ)

    LdagL_stats = statistics((jnp.abs(Lρ) ** 2).T)
    LdagL_mean = LdagL_stats.mean

    _logpsi_ave, d_logpsi = nkjax.vjp(lambda w: logpsi(w, σ), parameters)

    # Pulling back a constant 1/N_total cotangent yields the sample-averaged
    # log-derivative; it is then summed across MPI ranks.
    # TODO: this ones_like might produce a complexXX type but we only need floatXX
    # and we cut in 1/2 the # of operations to do.
    der_logs_ave = d_logpsi(
        jnp.ones_like(_logpsi_ave).real / (n_samples_node * mpi.n_nodes)
    )[0]
    der_logs_ave = jax.tree_util.tree_map(
        lambda x: mpi.mpi_sum_jax(x)[0], der_logs_ave
    )

    def gradfun(der_loc_vals, der_logs_ave):
        # Reshape Lρ so it broadcasts against the per-sample parameter axes.
        par_dims = der_loc_vals.ndim - 1
        _lloc_r = Lρ.reshape((n_samples_node,) + (1,) * par_dims)
        grad = mean(der_loc_vals.conjugate() * _lloc_r, axis=0) - (
            der_logs_ave.conjugate() * LdagL_mean
        )
        return grad

    # tree_map with several trees replaces the removed tree_multimap.
    LdagL_grad = jax.tree_util.tree_map(gradfun, der_loc_vals, der_logs_ave)

    # ⟨L†L⟩ ∈ R, so if the parameters are real we should cast away
    # the imaginary part of the gradient.
    # we do this also for standard gradient of energy.
    # this avoid errors in #867, #789, #850
    LdagL_grad = jax.tree_util.tree_map(
        lambda x, target: (x if jnp.iscomplexobj(target) else x.real).astype(
            target.dtype
        ),
        LdagL_grad,
        parameters,
    )

    return (
        LdagL_stats,
        LdagL_grad,
        model_state,
    )
def initialize_model(rng_key, model,
                     init_strategy=init_to_uniform,
                     dynamic_args=False,
                     model_args=(),
                     model_kwargs=None):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params` under the hood to return a tuple of
    (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters are found. This is
        only checked when `is_valid` is a concrete value (not a JAX tracer).
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # Seed with a single key (the first one when rng_key is batched) and let
    # init_strategy supply the site values used for the prototype trace.
    substituted_model = substitute(seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
                                   substitute_fn=init_strategy)
    inv_transforms, replay_model, model_trace = _get_model_transforms(
        substituted_model, model_args, model_kwargs)
    # Constrained values of latent (unobserved, non-discrete) sample sites.
    constrained_values = {k: v['value'] for k, v in model_trace.items()
                          if v['type'] == 'sample' and not v['is_observed']
                          and not v['fn'].is_discrete}

    potential_fn, postprocess_fn = get_potential_fn(model,
                                                    inv_transforms,
                                                    replay_model=replay_model,
                                                    dynamic_args=dynamic_args,
                                                    model_args=model_args,
                                                    model_kwargs=model_kwargs)

    # Normalize to a functools.partial so `.func` / `.keywords` can be inspected
    # (assumes calling a non-partial init_strategy returns a partial).
    init_strategy = init_strategy if isinstance(init_strategy, partial) else init_strategy()
    if init_strategy.func is init_to_value:
        # init_to_value values are given in constrained space; map them to
        # unconstrained space before searching for initial params.
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms, init_values, invert=True)
        init_strategy = _init_to_unconstrained_value(values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms, constrained_values, invert=True)
    (init_params, pe, grad), is_valid = find_valid_initial_params(rng_key, model,
                                                                  init_strategy=init_strategy,
                                                                  model_args=model_args,
                                                                  model_kwargs=model_kwargs,
                                                                  prototype_params=prototype_params)

    # Only fail eagerly when the validity flag is concrete (not traced).
    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            raise RuntimeError("Cannot find valid initial parameters. Please check your model again.")
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn, postprocess_fn, model_trace)
def _sample(key, state):
    """Draw a sample from a Normal centred at ``state`` with scale ``scale``.

    ``scale`` is not a parameter here; presumably it comes from an enclosing
    scope. All dimensions of ``state`` are reinterpreted as event dimensions.
    """
    normal = bd.Normal(state, scale)  # pytype: disable=module-attr
    independent = bd.Independent(  # pytype: disable=module-attr
        normal,
        reinterpreted_batch_ndims=np.ndim(state))
    return ppl.random_variable(independent)(key)
def __init__(self, logits, validate_args=None):
    """Store ``logits`` and initialise with batch shape ``logits.shape[:-1]``.

    Raises:
        ValueError: if ``logits`` is a scalar (the last axis indexes categories).
    """
    if jnp.ndim(logits) == 0:
        raise ValueError("`logits` parameter must be at least one-dimensional.")
    self.logits = logits
    batch_shape = jnp.shape(logits)[:-1]
    super(CategoricalLogits, self).__init__(batch_shape=batch_shape,
                                            validate_args=validate_args)
def sum_rightmost(x, dim):
    """Sum out the rightmost ``dim`` dimensions of ``x``.

    :param x: array-like with ``.shape`` support for ``x[..., np.newaxis]``.
    :param int dim: number of trailing dimensions to sum. If ``dim`` exceeds
        ``np.ndim(x)``, every dimension is summed and a scalar is returned.
        A non-positive ``dim`` returns the values unchanged.
    :return: ``x`` summed over its last ``dim`` axes.
    """
    # Clamp at 0: a negative bound would make shape[:out_dim] drop leading
    # axes instead of keeping none, giving a partial rather than total sum.
    out_dim = max(np.ndim(x) - dim, 0)
    # The extra trailing axis makes the reshape valid even when dim <= 0.
    x = np.reshape(x[..., np.newaxis], np.shape(x)[:out_dim] + (-1,))
    return np.sum(x, axis=-1)
def _initialize_mass_matrix(z, inverse_mass_matrix, dense_mass): if isinstance(dense_mass, list): if inverse_mass_matrix is None: inverse_mass_matrix = {} # if user specifies an ndarray mass matrix, then we convert it to a dict elif not isinstance(inverse_mass_matrix, dict): inverse_mass_matrix = {tuple(sorted(z)): inverse_mass_matrix} mass_matrix_sqrt = {} mass_matrix_sqrt_inv = {} for site_names in dense_mass: inverse_mm = inverse_mass_matrix.get(site_names) z_block = tuple(z[k] for k in site_names) inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix( z_block, inverse_mm, True) inverse_mass_matrix[site_names] = inverse_mm mass_matrix_sqrt[site_names] = mm_sqrt mass_matrix_sqrt_inv[site_names] = mm_sqrt_inv # NB: this branch only happens when users want to use block diagonal # inverse_mass_matrix, for example, {("a",): jnp.ones(3), ("b",): jnp.ones(3)}. for site_names, inverse_mm in inverse_mass_matrix.items(): if site_names in dense_mass: continue z_block = tuple(z[k] for k in site_names) inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix( z_block, inverse_mm, False) inverse_mass_matrix[site_names] = inverse_mm mass_matrix_sqrt[site_names] = mm_sqrt mass_matrix_sqrt_inv[site_names] = mm_sqrt_inv remaining_sites = tuple( sorted(set(z) - set().union(*inverse_mass_matrix))) if len(remaining_sites) > 0: z_block = tuple(z[k] for k in remaining_sites) inverse_mm, mm_sqrt, mm_sqrt_inv = _initialize_mass_matrix( z_block, None, False) inverse_mass_matrix[remaining_sites] = inverse_mm mass_matrix_sqrt[remaining_sites] = mm_sqrt mass_matrix_sqrt_inv[remaining_sites] = mm_sqrt_inv expected_site_names = sorted(z) actual_site_names = sorted( [k for site_names in inverse_mass_matrix for k in site_names]) assert actual_site_names == expected_site_names, ( "There seems to be a conflict of sites names specified in the initial" " `inverse_mass_matrix` and in `dense_mass` argument.") return inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv mass_matrix_size = 
jnp.size(ravel_pytree(z)[0]) if inverse_mass_matrix is None: if dense_mass: inverse_mass_matrix = jnp.identity(mass_matrix_size) else: inverse_mass_matrix = jnp.ones(mass_matrix_size) mass_matrix_sqrt = mass_matrix_sqrt_inv = inverse_mass_matrix else: if dense_mass: if jnp.ndim(inverse_mass_matrix) == 1: inverse_mass_matrix = jnp.diag(inverse_mass_matrix) mass_matrix_sqrt_inv = jnp.swapaxes( jnp.linalg.cholesky( inverse_mass_matrix[..., ::-1, ::-1])[..., ::-1, ::-1], -2, -1, ) identity = jnp.identity(inverse_mass_matrix.shape[-1]) mass_matrix_sqrt = solve_triangular(mass_matrix_sqrt_inv, identity, lower=True) else: if jnp.ndim(inverse_mass_matrix) == 2: inverse_mass_matrix = jnp.diag(inverse_mass_matrix) mass_matrix_sqrt_inv = jnp.sqrt(inverse_mass_matrix) mass_matrix_sqrt = jnp.reciprocal(mass_matrix_sqrt_inv) return inverse_mass_matrix, mass_matrix_sqrt, mass_matrix_sqrt_inv
def initialize_model(
    rng_key,
    model,
    *,
    init_strategy=init_to_uniform,
    dynamic_args=False,
    model_args=(),
    model_kwargs=None,
    forward_mode_differentiation=False,
    validate_grad=True,
):
    """
    (EXPERIMENTAL INTERFACE) Helper function that calls :func:`~numpyro.infer.util.get_potential_fn` and
    :func:`~numpyro.infer.util.find_valid_initial_params` under the hood to return a tuple of
    (`init_params_info`, `potential_fn`, `postprocess_fn`, `model_trace`).

    :param jax.random.PRNGKey rng_key: random number generator seed to
        sample from the prior. The returned `init_params` will have the
        batch shape ``rng_key.shape[:-1]``.
    :param model: Python callable containing Pyro primitives.
    :param callable init_strategy: a per-site initialization function.
        See :ref:`init_strategy` section for available functions.
    :param bool dynamic_args: if `True`, the `potential_fn` and
        `constraints_fn` are themselves dependent on model arguments.
        When provided a `*model_args, **model_kwargs`, they return
        `potential_fn` and `constraints_fn` callables, respectively.
    :param tuple model_args: args provided to the model.
    :param dict model_kwargs: kwargs provided to the model.
    :param bool forward_mode_differentiation: whether to use forward-mode differentiation
        or reverse-mode differentiation. By default, we use reverse mode but the forward
        mode can be useful in some cases to improve the performance. In addition, some
        control flow utility on JAX such as `jax.lax.while_loop` or `jax.lax.fori_loop`
        only supports forward-mode differentiation. See
        `JAX's The Autodiff Cookbook <https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html>`_
        for more information.
    :param bool validate_grad: whether to validate gradient of the initial params.
        Defaults to True.
    :return: a namedtuple `ModelInfo` which contains the fields
        (`param_info`, `potential_fn`, `postprocess_fn`, `model_trace`), where
        `param_info` is a namedtuple `ParamInfo` containing values from the prior
        used to initiate MCMC, their corresponding potential energy, and their gradients;
        `postprocess_fn` is a callable that uses inverse transforms
        to convert unconstrained HMC samples to constrained values that
        lie within the site's support, in addition to returning values
        at `deterministic` sites in the model.
    :raises RuntimeError: if no valid initial parameters are found. This is only
        checked when `is_valid` is a concrete value (not a JAX tracer); before
        raising, the model is re-run with validation enabled so per-site
        warnings are shown.
    """
    model_kwargs = {} if model_kwargs is None else model_kwargs
    # Seed with a single key (the first one when rng_key is batched) and let
    # init_strategy supply the site values used for the prototype trace.
    substituted_model = substitute(
        seed(model, rng_key if jnp.ndim(rng_key) == 1 else rng_key[0]),
        substitute_fn=init_strategy,
    )
    (
        inv_transforms,
        replay_model,
        has_enumerate_support,
        model_trace,
    ) = _get_model_transforms(substituted_model, model_args, model_kwargs)
    # substitute param sites from model_trace to model so
    # we don't need to generate again parameters of `numpyro.module`
    model = substitute(
        model,
        data={
            k: site["value"]
            for k, site in model_trace.items()
            if site["type"] in ["param"]
        },
    )
    # Constrained values of latent (unobserved, non-discrete) sample sites.
    constrained_values = {
        k: v["value"]
        for k, v in model_trace.items()
        if v["type"] == "sample"
        and not v["is_observed"]
        and not v["fn"].support.is_discrete
    }

    if has_enumerate_support:
        # Local import: funsor-based enumeration is only needed for models
        # with enumerable discrete sites.
        from numpyro.contrib.funsor import config_enumerate, enum

        if not isinstance(model, enum):
            max_plate_nesting = _guess_max_plate_nesting(model_trace)
            _validate_model(model_trace)
            model = enum(config_enumerate(model), -max_plate_nesting - 1)

    potential_fn, postprocess_fn = get_potential_fn(
        model,
        inv_transforms,
        replay_model=replay_model,
        enum=has_enumerate_support,
        dynamic_args=dynamic_args,
        model_args=model_args,
        model_kwargs=model_kwargs,
    )

    # Normalize to a functools.partial so `.func` / `.keywords` can be inspected
    # (assumes calling a non-partial init_strategy returns a partial).
    init_strategy = (init_strategy if isinstance(init_strategy, partial) else
                     init_strategy())
    if (init_strategy.func is init_to_value) and not replay_model:
        # init_to_value values are given in constrained space; map them to
        # unconstrained space before searching for initial params.
        init_values = init_strategy.keywords.get("values")
        unconstrained_values = transform_fn(inv_transforms,
                                            init_values,
                                            invert=True)
        init_strategy = _init_to_unconstrained_value(
            values=unconstrained_values)
    prototype_params = transform_fn(inv_transforms,
                                    constrained_values,
                                    invert=True)
    # Plate sites are substituted from the prototype trace before searching.
    (init_params, pe, grad), is_valid = find_valid_initial_params(
        rng_key,
        substitute(
            model,
            data={
                k: site["value"]
                for k, site in model_trace.items()
                if site["type"] in ["plate"]
            },
        ),
        init_strategy=init_strategy,
        enum=has_enumerate_support,
        model_args=model_args,
        model_kwargs=model_kwargs,
        prototype_params=prototype_params,
        forward_mode_differentiation=forward_mode_differentiation,
        validate_grad=validate_grad,
    )

    # Only fail eagerly when the validity flag is concrete (not traced).
    if not_jax_tracer(is_valid):
        if device_get(~jnp.all(is_valid)):
            # Re-run the model with validation on to surface diagnostics
            # before raising.
            with numpyro.validation_enabled(), trace() as tr:
                # validate parameters
                substituted_model(*model_args, **model_kwargs)
                # validate values
                for site in tr.values():
                    if site["type"] == "sample":
                        with warnings.catch_warnings(record=True) as ws:
                            site["fn"]._validate_sample(site["value"])
                        # Re-emit recorded warnings outside catch_warnings so
                        # they are actually shown.
                        if len(ws) > 0:
                            for w in ws:
                                # add site information to the warning message
                                w.message.args = ("Site {}: {}".format(
                                    site["name"], w.message.args[0]),
                                                  ) + w.message.args[1:]
                                warnings.showwarning(
                                    w.message,
                                    w.category,
                                    w.filename,
                                    w.lineno,
                                    file=w.file,
                                    line=w.line,
                                )
            raise RuntimeError(
                "Cannot find valid initial parameters. Please check your model again."
            )
    return ModelInfo(ParamInfo(init_params, pe, grad), potential_fn,
                     postprocess_fn, model_trace)
def vindex(tensor, args):
    """
    Vectorized advanced indexing with broadcasting semantics.

    See also the convenience wrapper :class:`Vindex`.

    This is useful for writing indexing code that is compatible with batching
    and enumeration, especially for selecting mixture components with discrete
    random variables.

    For example suppose ``x`` is a parameter with ``len(x.shape) == 3`` and we wish
    to generalize the expression ``x[i, :, j]`` from integer ``i,j`` to tensors
    ``i,j`` with batch dims and enum dims (but no event dims). Then we can
    write the generalized version using :class:`Vindex` ::

        xij = Vindex(x)[i, :, j]

        batch_shape = broadcast_shape(i.shape, j.shape)
        event_shape = (x.shape[1],)
        assert xij.shape == batch_shape + event_shape

    To handle the case when ``x`` may also contain batch dimensions (e.g. if
    ``x`` was sampled in a plated context as when using vectorized particles),
    :func:`vindex` uses the special convention that ``Ellipsis`` denotes batch
    dimensions (hence ``...`` can appear only on the left, never in the middle
    or in the right). Suppose ``x`` has event dim 3. Then we can write::

        old_batch_shape = x.shape[:-3]
        old_event_shape = x.shape[-3:]

        xij = Vindex(x)[..., i, :, j]   # The ... denotes unknown batch shape.

        new_batch_shape = broadcast_shape(old_batch_shape, i.shape, j.shape)
        new_event_shape = (x.shape[1],)
        assert xij.shape == new_batch_shape + new_event_shape

    Note that this special handling of ``Ellipsis`` differs from the NEP [1].

    Formally, this function assumes:

    1.  Each arg is either ``Ellipsis``, ``slice(None)``, an integer, or a
        batched integer tensor (i.e. with empty event shape). This
        function does not support nontrivial slices or boolean tensor
        masks. ``Ellipsis`` can only appear on the left as ``args[0]``.
    2.  If ``args[0] is not Ellipsis`` then ``tensor`` is not
        batched, and its event dim is equal to ``len(args)``.
    3.  If ``args[0] is Ellipsis`` then ``tensor`` is batched and
        its event dim is equal to ``len(args[1:])``. Dims of ``tensor``
        to the left of the event dims are considered batch dims and will be
        broadcasted with dims of tensor args.

    Note that if none of the args is a tensor with ``len(shape) > 0``, then
    this function behaves like standard indexing::

        if not any(isinstance(a, np.ndarray) and len(a.shape) > 0 for a in args):
            assert Vindex(x)[args] == x[args]

    **References**

    [1] https://www.numpy.org/neps/nep-0021-advanced-indexing.html
        introduces ``vindex`` as a helper for vectorized indexing.
        This implementation is similar to the proposed notation
        ``x.vindex[]`` except for slightly different handling of ``Ellipsis``.

    :param np.ndarray tensor: A tensor to be indexed.
    :param tuple args: An index, as args to ``__getitem__``.
    :returns: A nonstandard interpretation of ``tensor[args]``.
    :rtype: np.ndarray
    :raises NotImplementedError: for a non-leading ``Ellipsis`` or a slice
        other than ``slice(None)``.
    """
    # Non-tuple and empty indices fall back to ordinary indexing.
    if not isinstance(args, tuple):
        return tensor[args]
    if not args:
        return tensor

    assert np.ndim(tensor) > 0
    # Compute event dim before and after indexing.
    if args[0] is Ellipsis:
        args = args[1:]
        if not args:
            return tensor
        old_event_dim = len(args)
        # Leading batch dims are addressed by full slices.
        args = (slice(None),) * (np.ndim(tensor) - len(args)) + args
    else:
        # Pad with full slices so there is one arg per tensor dim.
        args = args + (slice(None),) * (np.ndim(tensor) - len(args))
        old_event_dim = len(args)
    assert len(args) == np.ndim(tensor)
    if any(a is Ellipsis for a in args):
        raise NotImplementedError("Non-leading Ellipsis is not supported")

    # In simple cases, standard advanced indexing broadcasts correctly.
    is_standard = True
    if np.ndim(tensor) > old_event_dim and _is_batched(args[0]):
        is_standard = False
    elif any(_is_batched(a) for a in args[1:]):
        is_standard = False
    if is_standard:
        return tensor[args]

    # Convert args to use broadcasting semantics.
    # new_event_dim: event dims that survive indexing (those indexed by a slice).
    new_event_dim = sum(isinstance(a, slice) for a in args[-old_event_dim:])
    # new_dim: how many slices have been converted so far, scanning right-to-left.
    new_dim = 0
    args = list(args)
    for i, arg in reversed(list(enumerate(args))):
        if isinstance(arg, slice):
            # Convert slices to arange()s.
            if arg != slice(None):
                raise NotImplementedError("Nontrivial slices are not supported")
            arg = np.arange(tensor.shape[i], dtype=np.int32)
            # Place this arange on its own axis, left of previously converted slices.
            arg = arg.reshape((-1,) + (1,) * new_dim)
            new_dim += 1
        elif _is_batched(arg):
            # Reshape nontrivial tensors.
            # Trailing singleton axes let batched args broadcast against the
            # surviving event dims on the right.
            arg = arg.reshape(arg.shape + (1,) * new_event_dim)
        args[i] = arg
    args = tuple(args)
    return tensor[args]
def testScalarInstantiation(self, scalar_type):
    """Calling ``scalar_type`` on 1 must give a 0-d ``jnp.DeviceArray`` of that dtype."""
    instance = scalar_type(1)
    expected_dtype = jnp.dtype(scalar_type)
    self.assertEqual(instance.dtype, expected_dtype)
    self.assertIsInstance(instance, jnp.DeviceArray)
    self.assertEqual(0, jnp.ndim(instance))
    # The NumPy dtype's own scalar constructor must also produce a scalar_type.
    numpy_scalar = np.dtype(scalar_type).type(1)
    self.assertIsInstance(numpy_scalar, scalar_type)
def general_norm(self, x):
    """Return the 2-norm (Frobenius for matrices) of ``x``, promoting scalars to 1-D first."""
    arr = np.atleast_1d(np.asarray(x))
    return np.linalg.norm(arr)