def __call__(self, input_ids, token_type_ids, position_ids, attention_mask):
    # Embed
    w_emb = FlaxRobertaEmbedding(self.vocab_size, self.hidden_size, name="word_embeddings")(
        jnp.atleast_2d(input_ids.astype("i4"))
    )
    p_emb = FlaxRobertaEmbedding(self.max_length, self.hidden_size, name="position_embeddings")(
        jnp.atleast_2d(position_ids.astype("i4"))
    )
    t_emb = FlaxRobertaEmbedding(self.type_vocab_size, self.hidden_size, name="token_type_embeddings")(
        jnp.atleast_2d(token_type_ids.astype("i4"))
    )

    # Sum all embeddings
    summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

    # Layer Norm
    layer_norm = FlaxRobertaLayerNorm(name="layer_norm")(summed_emb)
    return layer_norm
def update(self, likelihood, y, post_mean, post_cov, hyp=None, site_params=None):
    """
    The update function takes a likelihood as input and uses CVI to update the site parameters.
    """
    if site_params is None:
        # First pass: initialise the sites from the variational expectations.
        _, dE_dm, dE_dv = likelihood.variational_expectation(y, post_mean, post_cov, hyp, self.cubature_func)
        dE_dm, dE_dv = np.atleast_2d(dE_dm), np.atleast_2d(dE_dv)
        site_cov = -0.5 * inv_any(dE_dv + 1e-10 * np.eye(dE_dv.shape[0]))
        site_mean = post_mean + site_cov @ dE_dm
        site_cov = ensure_positive_variance(site_cov)
    else:
        site_mean, site_cov = site_params
        log_marg_lik, dE_dm, dE_dv = likelihood.variational_expectation(y, post_mean, post_cov, hyp, self.cubature_func)
        dE_dm, dE_dv = np.atleast_2d(dE_dm), np.atleast_2d(dE_dv)
        dE_dv = -ensure_positive_variance(-dE_dv)
        # Damped update of the site natural parameters (lambda_1, lambda_2).
        lambda_t_2 = inv_any(site_cov + 1e-10 * np.eye(site_cov.shape[0]))
        lambda_t_1 = lambda_t_2 @ site_mean
        lambda_t_1 = (1 - self.damping) * lambda_t_1 + self.damping * (dE_dm - 2 * dE_dv @ post_mean)
        lambda_t_2 = (1 - self.damping) * lambda_t_2 + self.damping * (-2 * dE_dv)
        site_cov = inv_any(lambda_t_2 + 1e-10 * np.eye(site_cov.shape[0]))
        site_mean = site_cov @ lambda_t_1
    log_marg_lik, _, _ = likelihood.moment_match(y, post_mean, post_cov, hyp, 1.0, self.cubature_func)
    return log_marg_lik, site_mean, site_cov
def __do_rank_regression(self):
    f = jnp.hstack((jnp.atleast_2d(self.failures).T, jnp.zeros((self.failures.shape[0], 1))))
    f = f[f[:, 0].argsort()]
    f = jnp.hstack((f, jnp.reshape(jnp.arange(self.failures.shape[0]), (self.failures.shape[0], -1))))
    # censored items carry the flag '1'
    c = jnp.hstack((jnp.atleast_2d(self.censored).T, jnp.ones((self.censored.shape[0], 1))))
    c = jnp.hstack((c, jnp.reshape(jnp.empty(self.censored.shape[0]), (self.censored.shape[0], -1))))
    d = jnp.concatenate((c, f), axis=0)
    d = d[d[:, 0].argsort()]
    df = pd.DataFrame(data=d, columns=['time', 'is_cens', 'fo'])
    self.N = len(df.index)
    df['new_increment'] = (self.N + 1 - df['fo']) / (self.N + 2 - df.index.values)
    m = 1.0 - df['new_increment'].min()
    df['new_increment'] = df['new_increment'] + m
    # drop the censored rows before accumulating the adjusted order numbers
    df = df.drop(df[df['is_cens'] == 1].index)
    df['new_order_num'] = df['new_increment'].cumsum()
    df['cdf'] = util.median_rank(self.N, df['new_order_num'], 0.5)
    self.est_data = df
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
    # Embed
    w_emb = FlaxRobertaEmbedding(
        self.vocab_size,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="word_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(input_ids.astype("i4")))
    p_emb = FlaxRobertaEmbedding(
        self.max_length,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="position_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(position_ids.astype("i4")))
    t_emb = FlaxRobertaEmbedding(
        self.type_vocab_size,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="token_type_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(token_type_ids.astype("i4")))

    # Sum all embeddings
    summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

    # Layer Norm
    layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
    embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
    return embeddings
def test_discrete_barycenter_pointcloud(self, lse_mode, epsilon):
    """Tests the discrete barycenters on pointclouds.

    Two measures are supported on the same set of points (a 1D grid); the
    barycenter is evaluated on a different set of points (still in 1D).

    Args:
      lse_mode: bool, lse or scaling computations
      epsilon: float
    """
    n = 50
    ma = 0.2
    mb = 0.8
    # define two narrow Gaussian bumps in segment [0,1]
    a = jnp.exp(-(jnp.arange(0, n) / (n - 1) - ma)**2 / .01) + 1e-10
    b = jnp.exp(-(jnp.arange(0, n) / (n - 1) - mb)**2 / .01) + 1e-10
    a = a / jnp.sum(a)
    b = b / jnp.sum(b)

    # positions on the real line where weights are supported.
    x = jnp.atleast_2d(jnp.arange(0, n) / (n - 1)).T

    # choose a different support, half the size, for the barycenter.
    # note this is the reason why we do not use debiasing in this case.
    x_support_bar = jnp.atleast_2d((jnp.arange(0, (n / 2)) / (n / 2 - 1) - .5) * .9 + .5).T

    geom = pointcloud.PointCloud(x, x_support_bar, epsilon=epsilon)
    bar = db.discrete_barycenter(geom, a=jnp.stack((a, b)), lse_mode=lse_mode).histogram
    # check the barycenter has a bump in the middle.
    self.assertGreater(bar[n // 4], 0.1)
def conditional_moments(self, f, hyp=None):
    """
    Compute the conditional moments of the observation model: the mean is the
    sum of the subband components weighted by their modulators, and the
    variance is the observation noise.
    """
    obs_noise_var = hyp if hyp is not None else self.hyp
    num_components = int(f.shape[0] / 2)
    subbands, modulators = f[:num_components], self.link_fn(f[num_components:])
    return np.atleast_2d(np.sum(subbands * modulators, axis=0)), np.atleast_2d(obs_noise_var)
def transform_data(x):
    if isinstance(x, np.ndarray):
        if x.ndim == 1:
            return np.atleast_2d(x.astype(np.float64)).T
        else:
            return np.atleast_2d(x.astype(np.float64))
    elif isinstance(x, list):
        return transform_data(np.array(x))
    else:
        raise ValueError("Cannot convert to numpy.array")
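# Usage sketch for transform_data above (illustrative, not part of the original
# API): 1-D arrays are promoted to column vectors, 2-D arrays pass through
# unchanged, and lists are converted to float64 arrays first.
import numpy as np

x1 = transform_data(np.array([1.0, 2.0, 3.0]))  # -> shape (3, 1)
x2 = transform_data(np.array([[1.0, 2.0]]))     # -> shape (1, 2), unchanged
x3 = transform_data([[1, 2], [3, 4]])           # list -> float64 array, shape (2, 2)
assert x1.shape == (3, 1) and x2.shape == (1, 2) and x3.dtype == np.float64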
def __init__(
    self,
    w_in: np.ndarray,
    w_out: np.ndarray,
    tau: np.ndarray,
    bias: np.ndarray,
    noise_std: float = 0.0,
    activation_func: Callable[[FloatVector], FloatVector] = H_ReLU,
    dt: Optional[float] = None,
    name: Optional[str] = None,
    rng_key: Optional[int] = None,
):
    """
    ``JAX``-backed firing rate reservoir, used for reservoir transfer

    :param np.ndarray w_in: Input weights [IxN]
    :param np.ndarray w_out: Output weights [NxO]
    :param np.ndarray tau: Time constants [N]
    :param np.ndarray bias: Bias values [N]
    :param float noise_std: White noise standard deviation applied to reservoir neurons. Default: ``0.0``
    :param Callable[[FloatVector], FloatVector] activation_func: Neuron transfer function f(x: float) -> float. Must be vectorised. Default: ``H_ReLU``
    :param Optional[float] dt: Reservoir time step. Default: ``np.min(tau) / 10.0``
    :param Optional[str] name: Name of the layer. Default: ``None``
    :param Optional[Jax RNG key] rng_key: Jax RNG key to use for noise. Default: internally generated
    """
    # - Everything should be 2D
    w_in = np.atleast_2d(w_in)
    w_out = np.atleast_2d(w_out)

    # - Get information about network size
    self._size_in = w_in.shape[0]
    self._size = w_in.shape[1]
    self._size_out = w_out.shape[1]

    # - Call super-class initialisation
    super().__init__(
        w_in,
        np.zeros((self._size, self._size)),
        w_out,
        tau,
        bias,
        noise_std,
        activation_func,
        dt,
        name,
        rng_key,
    )

    # - Correct layer size
    self._size_in = w_in.shape[0]
    self._size_out = w_out.shape[1]

    # - Get compiled evolution function for forced reservoir
    self._evolve_jit = _get_force_evolve_jit(activation_func)
def analytical_linearisation(self, m, sigma=None, hyp=None):
    """
    Compute the Jacobians of the observation model w.r.t. the function value
    and the noise term, evaluated analytically for the subband/modulator model.
    """
    obs_noise_var = hyp if hyp is not None else self.hyp
    num_components = int(m.shape[0] / 2)
    subbands, modulators = m[:num_components], self.link_fn(m[num_components:])
    Jf = np.block([[modulators], [subbands * self.dlink_fn(m[num_components:])]])
    Jsigma = np.array([[np.sqrt(obs_noise_var)]])
    return np.atleast_2d(Jf).T, np.atleast_2d(Jsigma).T
def __init__(self, name, pi, mu, gamma, tracked=True):
    if not isinstance(pi, PriorTransform):
        pi = DeltaPrior('_{}_pi'.format(name), pi, False)
    if not isinstance(mu, PriorTransform):
        mu = DeltaPrior('_{}_mu'.format(name), jnp.atleast_2d(mu), False)
    if not isinstance(gamma, PriorTransform):
        gamma = DeltaPrior('_{}_gamma'.format(name), jnp.atleast_2d(gamma), False)
    assert (get_shape(pi)[0] == get_shape(mu)[0]) and (get_shape(pi)[0] == get_shape(gamma)[0]) \
           and (get_shape(mu)[1] == get_shape(gamma)[1])
    # replaces mu and gamma when parents injected
    U_dims = 1 + broadcast_shapes(get_shape(mu)[-1:], get_shape(gamma)[-1:])[0]
    super(GMMDiagPrior, self).__init__(name, U_dims, [pi, mu, gamma], tracked)
def __init__(self, name, pi, low, high, tracked=True):
    if not isinstance(pi, PriorTransform):
        pi = DeltaPrior('_{}_pi'.format(name), pi, False)
    if not isinstance(low, PriorTransform):
        low = DeltaPrior('_{}_low'.format(name), jnp.atleast_2d(low), False)
    if not isinstance(high, PriorTransform):
        high = DeltaPrior('_{}_high'.format(name), jnp.atleast_2d(high), False)
    assert (get_shape(pi)[0] == get_shape(low)[0]) and (get_shape(pi)[0] == get_shape(high)[0]) \
           and (get_shape(low)[1] == get_shape(high)[1])
    # replaces low and high when parents injected
    U_dims = 1 + broadcast_shapes(get_shape(low)[-1:], get_shape(high)[-1:])[0]
    super(UniformMixturePrior, self).__init__(name, U_dims, [pi, low, high], tracked)
def analytical_linearisation(self, m, sigma=None, hyp=None):
    """
    Compute the Jacobian of the state space observation model w.r.t. the
    function fₙ and the noise term σₙ.
    The implicit observation model is:
        h(fₙ,rₙ) = E[yₙ|fₙ] + √Cov[yₙ|fₙ] σₙ
    The Jacobians are evaluated at the means, fₙ=m, σₙ=0, to be used during
    Extended Kalman filtering and Extended EP.
    """
    sigma = np.array([[0.0]]) if sigma is None else sigma
    Jf, Jsigma = jacrev(self.observation_model, argnums=(0, 1))(m, sigma, hyp)
    return np.atleast_2d(np.squeeze(Jf)), np.atleast_2d(np.squeeze(Jsigma))
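# Minimal self-contained sketch of the jacrev pattern used above, with a
# hypothetical observation model h(f, sigma) = f**2 + 0.1*sigma standing in
# for self.observation_model (the extra hyp argument is omitted here).
import jax.numpy as jnp
from jax import jacrev

def h(f, sigma):
    return f ** 2 + 0.1 * sigma

m = jnp.array([1.5])
Jf, Jsigma = jacrev(h, argnums=(0, 1))(m, jnp.array([0.0]))
# Jf == [[3.0]] since dh/df = 2f at f = 1.5, and Jsigma == [[0.1]].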
def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
    # Embed
    inputs_embeds = self.word_embeddings(jnp.atleast_2d(input_ids.astype("i4")))
    position_embeds = self.position_embeddings(jnp.atleast_2d(position_ids.astype("i4")))
    token_type_embeddings = self.token_type_embeddings(jnp.atleast_2d(token_type_ids.astype("i4")))

    # Sum all embeddings
    hidden_states = inputs_embeds + jnp.broadcast_to(position_embeds, inputs_embeds.shape) + token_type_embeddings

    # Layer Norm
    hidden_states = self.LayerNorm(hidden_states)
    hidden_states = self.dropout(hidden_states, deterministic=deterministic)
    return hidden_states
def __init__(self, dataset, bw_method=None, weights=None):
    _check_arraylike("gaussian_kde", dataset)
    dataset = jnp.atleast_2d(dataset)
    if jnp.issubdtype(lax.dtype(dataset), jnp.complexfloating):
        raise NotImplementedError("gaussian_kde does not support complex data")
    if not dataset.size > 1:
        raise ValueError("`dataset` input should have multiple elements.")

    d, n = dataset.shape
    if weights is not None:
        _check_arraylike("gaussian_kde", weights)
        dataset, weights = _promote_dtypes_inexact(dataset, weights)
        weights = jnp.atleast_1d(weights)
        weights /= jnp.sum(weights)
        if weights.ndim != 1:
            raise ValueError("`weights` input should be one-dimensional.")
        if len(weights) != n:
            raise ValueError("`weights` input should be of length n")
    else:
        dataset, = _promote_dtypes_inexact(dataset)
        weights = jnp.full(n, 1.0 / n, dtype=dataset.dtype)

    self._setattr("dataset", dataset)
    self._setattr("weights", weights)
    neff = self._setattr("neff", 1 / jnp.sum(weights**2))

    bw_method = "scott" if bw_method is None else bw_method
    if bw_method == "scott":
        factor = jnp.power(neff, -1. / (d + 4))
    elif bw_method == "silverman":
        factor = jnp.power(neff * (d + 2) / 4.0, -1. / (d + 4))
    elif jnp.isscalar(bw_method) and not isinstance(bw_method, str):
        factor = bw_method
    elif callable(bw_method):
        factor = bw_method(self)
    else:
        raise ValueError("`bw_method` should be 'scott', 'silverman', a scalar, or a callable.")

    data_covariance = jnp.atleast_2d(jnp.cov(dataset, rowvar=1, bias=False, aweights=weights))
    data_inv_cov = jnp.linalg.inv(data_covariance)
    covariance = data_covariance * factor**2
    inv_cov = data_inv_cov / factor**2
    self._setattr("covariance", covariance)
    self._setattr("inv_cov", inv_cov)
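# Usage sketch, assuming this constructor is jax.scipy.stats.gaussian_kde:
# fit a KDE on 1-D samples and evaluate the estimated density on a grid.
import jax
import jax.numpy as jnp
from jax.scipy.stats import gaussian_kde

samples = jax.random.normal(jax.random.PRNGKey(0), (200,))
kde = gaussian_kde(samples)          # Scott's bandwidth factor by default
grid = jnp.linspace(-3.0, 3.0, 50)
density = kde.evaluate(grid)         # estimated pdf at the grid points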
def __matmul__(self, inp: CombT) -> RkhsObject:
    if isinstance(inp, FiniteMap):
        G = inner(self.inp_feat, inp.outp_feat)
        if not inp.debias_outp:
            matr = self.matr @ G @ inp.matr
            inp_bias = (matr @ inp.bias.T).T
        else:
            matr = self.matr @ (G - G @ inp.bias.T) @ inp.matr
            inp_bias = (self.matr @ G @ inp.bias.T).T
        rval = FiniteMap(inp.inp_feat, self.outp_feat, matr, outp_bias=self.bias + inp_bias)
        rval.mean_center_inp = inp.mean_center_inp
        return rval
    else:
        if isinstance(inp, DeviceArray):
            inp = FiniteVec(self.inp_feat.k, np.atleast_2d(inp))
        lin_map = (self.matr @ inner(self.inp_feat, inp)).T
        if self.debias_outp:
            r = [DecenterOutFeat(lin_map)]
        else:
            if self._normalize:
                lin_map = lin_map / lin_map.sum(1, keepdims=True)
            r = [LinearReduce(lin_map + self.bias)]
        if len(inp) == 1:
            r.append(Sum())
        rval = self.outp_feat.extend_reduce(r)
        return rval
def make_classification(prng_key, num_features, num_examples, num_classes=2, sep=2.):
    """Generates random classification problems."""
    num_examples_per_class = num_examples // num_classes
    features, labels = [], []
    for label in range(num_classes):
        # Class mean is on a vertex of the hypercube, e.g.
        #   mu0 = [sep, 0, 0, 0]
        #   mu1 = [  0, sep, 0, 0]
        # and so on.
        mu = sep * jnp.eye(num_features)[:, label]
        # Sample data from a multivariate normal centered at the class mean.
        keys = jax.random.split(prng_key, 2)
        sqrt_sigma = jax.random.normal(keys[0], (num_features, num_features))
        samples = jax.random.normal(keys[1], (num_examples_per_class, num_features))
        features.append(jnp.atleast_2d(mu) + jnp.dot(samples, sqrt_sigma))
        labels.append(label * jnp.ones(num_examples_per_class))
    return jnp.vstack(features), jnp.hstack(labels)
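# Usage sketch for make_classification: draw a small two-class problem and
# check the shapes of the generated features and labels.
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
X, y = make_classification(key, num_features=4, num_examples=100)
assert X.shape == (100, 4) and y.shape == (100,)
assert jnp.unique(y).shape[0] == 2  # labels 0.0 and 1.0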
def welfare_path(par, sim, wgt, disc, long_run, T, K):
    # params
    ψ = np.atleast_2d(par['ψ'])  # optional county dependence
    wgt1 = wgt / np.sum(wgt)  # normalize weights to a distribution

    # discounting
    ydelt = 1 / year_days
    ytvec = np.arange(T) / year_days
    down = np.exp(-disc * ytvec)

    # input factors
    out = sim['out'][:T, :]
    irate = np.diff(sim['ka'][:T, :], axis=0)
    irate = np.concatenate([irate, irate[-1:, :]], axis=0)

    # immediate welfare
    util = out - ψ * irate
    eutil = np.sum(util * wgt1[None, :], axis=1)
    welf0 = (ydelt * down * eutil).sum()

    # total welfare
    if long_run:
        welf1 = (down[-1] * eutil[-1]) / disc
        welf = disc * (welf0 + welf1)
    else:
        Ty = T / year_days
        welf = disc * welf0 / (1 - np.exp(-disc * Ty))

    return welf
def get_text_features(
    self, input_ids, attention_mask=None, position_ids=None, dropout_rng: jax.random.PRNGKey = None, train=False
):
    r"""
    Args:
        input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using :class:`~transformers.CLIPTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
            for details.

            `What are input IDs? <../glossary.html#input-ids>`__

    Returns:
        text_features (:obj:`jax_xla.DeviceArray` of shape :obj:`(batch_size, output_dim)`): The text embeddings
        obtained by applying the projection layer to the pooled output of
        :class:`~transformers.FlaxCLIPTextModel`.

    Examples::

        >>> from transformers import CLIPTokenizer, FlaxCLIPModel

        >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
        >>> text_features = model.get_text_features(**inputs)
    """
    if position_ids is None:
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)

    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    def _get_features(module, input_ids, attention_mask, position_ids, deterministic):
        text_outputs = module.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
        )
        pooled_output = text_outputs[1]
        text_features = module.text_projection(pooled_output)
        return text_features

    return self.module.apply(
        {"params": self.params},
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        not train,
        method=_get_features,
        rngs=rngs,
    )
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    random_params = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]

    if params is not None:
        random_params = flatten_dict(unfreeze(random_params))
        params = flatten_dict(unfreeze(params))
        for missing_key in self._missing_keys:
            params[missing_key] = random_params[missing_key]
        self._missing_keys = set()
        return freeze(unflatten_dict(params))
    else:
        return random_params
def log_density_cubature(likelihood, y, mean, cov, cubature=None):
    """
    logZₙ = log ∫ p(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ

    :param likelihood: the likelihood model
    :param y: observed data (yₙ) [scalar]
    :param mean: cavity mean (mₙ) [scalar]
    :param cov: cavity covariance (cₙ) [scalar]
    :param cubature: the function to compute sigma points and weights to use during cubature
    :return:
        lZ: the log density, logZₙ [scalar]
    """
    if cubature is None:
        x, w = gauss_hermite(mean.shape[0], 20)  # Gauss-Hermite sigma points and weights
    else:
        x, w = cubature(mean.shape[0])
    cav_cho, low = cho_factor(cov)
    # fsigᵢ=xᵢ√cₙ + mₙ: scale locations according to cavity dist.
    sigma_points = cav_cho @ np.atleast_2d(x) + mean
    # pre-compute wᵢ p(yₙ|fsigᵢ)
    weighted_likelihood_eval = w * likelihood.evaluate_likelihood(y, sigma_points)
    # Compute partition function via cubature:
    # Zₙ = ∫ p(yₙ|fₙ) 𝓝(fₙ|mₙ,vₙ) dfₙ ≈ ∑ᵢ wᵢ p(yₙ|fsigᵢ)
    Z = np.sum(weighted_likelihood_eval, axis=-1)
    lZ = np.log(np.maximum(Z, 1e-8))
    return lZ
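# Sanity-check sketch for the cubature above, using a hypothetical Gaussian
# likelihood p(y|f) = N(y|f, s2), for which Zₙ has the closed form N(y|m, v+s2).
# Gauss-Hermite with the change of variables f = m + sqrt(2 v) t reproduces it.
import numpy as np
from numpy.polynomial.hermite import hermgauss

m, v, s2, y = 0.3, 0.2, 0.5, 1.2
t, w = hermgauss(20)
f = m + np.sqrt(2.0 * v) * t
lik = np.exp(-0.5 * (y - f) ** 2 / s2) / np.sqrt(2 * np.pi * s2)
Z = np.sum(w * lik) / np.sqrt(np.pi)  # cubature estimate of Zₙ
Z_exact = np.exp(-0.5 * (y - m) ** 2 / (v + s2)) / np.sqrt(2 * np.pi * (v + s2))
assert np.allclose(Z, Z_exact, rtol=1e-6)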
def __call__(
    self,
    input_ids,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    params: dict = None,
    dropout_rng: PRNGKey = None,
    train: bool = False,
):
    # init input tensors if not passed
    if token_type_ids is None:
        token_type_ids = jnp.ones_like(input_ids)
    if position_ids is None:
        position_ids = jnp.arange(jnp.atleast_2d(input_ids).shape[-1])
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    return self.module.apply(
        {"params": params or self.params},
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(token_type_ids, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        not train,
        rngs=rngs,
    )
def predict_cubature(likelihood, mean_f, var_f, cubature=None):
    """
    Predict in data space given the predictive mean and variance of the latent function.
    """
    if cubature is None:
        x, w = gauss_hermite(mean_f.shape[0], 20)  # Gauss-Hermite sigma points and weights
    else:
        x, w = cubature(mean_f.shape[0])
    chol_f, low = cho_factor(var_f)
    # fsigᵢ=xᵢ√cₙ + mₙ: scale locations according to latent dist.
    sigma_points = chol_f @ np.atleast_2d(x) + mean_f
    # Compute moments via cubature:
    # E[y] = ∫ E[yₙ|fₙ] 𝓝(fₙ|mₙ,vₙ) dfₙ ≈ ∑ᵢ wᵢ E[yₙ|fₙ]
    # E[y²] = ∫ (Cov[yₙ|fₙ] + E[yₙ|fₙ]²) 𝓝(fₙ|mₙ,vₙ) dfₙ ≈ ∑ᵢ wᵢ (Cov[yₙ|fₙ] + E[yₙ|fₙ]²)
    conditional_expectation, conditional_covariance = likelihood.conditional_moments(sigma_points)
    expected_y = np.sum(w * conditional_expectation, axis=-1)
    expected_y_squared = np.sum(w * (conditional_covariance + conditional_expectation**2), axis=-1)
    # Cov[y] = E[y²] - E[y]²
    covariance_y = expected_y_squared - expected_y**2
    return expected_y, covariance_y
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    if self.config.add_cross_attention:
        encoder_hidden_states = jnp.zeros(input_shape + (self.config.n_embd,))
        encoder_attention_mask = attention_mask
        module_init_outputs = self.module.init(
            rngs,
            input_ids,
            attention_mask,
            position_ids,
            encoder_hidden_states,
            encoder_attention_mask,
            return_dict=False,
        )
    else:
        module_init_outputs = self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)

    return module_init_outputs["params"]
def __matmul__(self, right_inp: CombT) -> Union[OutVecT, "FiniteOp[RhInpVectT, OutVecT]"]:
    if isinstance(right_inp, FiniteOp):
        # right_inp is an operator: Op1 @ Op2
        matr = self.inp_feat.inner(right_inp.outp_feat) @ right_inp.matr
        if self.matr is not None:
            matr = self.matr @ matr
        return FiniteOp(right_inp.inp_feat, self.outp_feat, matr)
    else:
        if isinstance(right_inp, DeviceArray):
            right_inp = FiniteVec(self.inp_feat.k, np.atleast_2d(right_inp))
        # right_inp is a vector: Op @ vec
        lin_LinOp = inner(self.inp_feat, right_inp)
        if self.matr is not None:
            lin_LinOp = self.matr @ lin_LinOp
        if self._normalize:
            lin_LinOp = lin_LinOp / lin_LinOp.sum(1, keepdims=True)
        lr = LinearReduce(lin_LinOp.T)
        if len(right_inp) != 1:
            return self.outp_feat.extend_reduce([lr])
        else:
            return self.outp_feat.extend_reduce([lr, Sum()])
def jac_reshaped(*points, **kw_points):
    jac = jax.jacobian(fun_raveled, *args, **kwargs)
    jac_eval = np.atleast_2d(jac(*points, **kw_points).T).T
    diff_dim = jac_eval.shape[1]

    logger.debug("Reshaping jacobian with original shape: {}".format(jac_eval.shape))
    jac_eval_tmp = jac_eval.ravel(order='F').reshape(diff_dim, *shape)
    return np.einsum('i...->...i', jac_eval_tmp)  # Transpose for column major
def response_mlp(theta: np.ndarray, _x: np.ndarray) -> np.ndarray:
    _x = np.atleast_2d(_x)
    if _x.shape[0] == 1:
        # (n,) <- (1, k) @ (k, n)
        return (basis_predict(_x) @ theta).squeeze()
    else:
        # (n_constr, n) <- (n_constr, n, k) @ (k, n)
        return np.einsum('ijk,kj->ij', basis_predict(_x[:, :, None]), theta)
def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
    # init input tensors
    input_ids = jnp.zeros(input_shape, dtype="i4")
    attention_mask = jnp.ones_like(input_ids)
    position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)

    params_rng, dropout_rng = jax.random.split(rng)
    rngs = {"params": params_rng, "dropout": dropout_rng}

    return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)["params"]
def misclassification_polytope(a, c, ls):
    """creates misclassification constraints"""
    assert a.ndim == 2
    assert a.shape[0] == 1  # only batch size 1 is supported
    n_classes = a.shape[1]

    u = a[:, ls] - a[:, c]

    c = np.atleast_1d(np.asarray([c]).squeeze())
    ls = np.atleast_1d(np.asarray([ls]).squeeze())

    Av = lambda Vv: Vv[:, c] - Vv[:, ls]  # noqa: E731
    vA = lambda v: (  # noqa: E731
        scatter(c, np.sum(np.atleast_2d(v), axis=-1, keepdims=True), n_classes) +
        scatter(ls, -np.atleast_2d(v), n_classes))

    return Av, vA, u
def __init__(self, name, mu, Gamma, ill_cond=False, tracked=True):
    self._ill_cond = ill_cond
    if not isinstance(mu, PriorTransform):
        mu = DeltaPrior('_{}_mu'.format(name), jnp.atleast_1d(mu), False)
    if not isinstance(Gamma, PriorTransform):
        Gamma = DeltaPrior('_{}_Gamma'.format(name), jnp.atleast_2d(Gamma), False)
    U_dims = broadcast_shapes(get_shape(mu), get_shape(Gamma)[0:1])[0]
    super(MVNPrior, self).__init__(name, U_dims, [mu, Gamma], tracked)
def __call__(
    self,
    input_ids,
    attention_mask=None,
    token_type_ids=None,
    position_ids=None,
    head_mask=None,
    params: dict = None,
    dropout_rng: PRNGKey = None,
    train: bool = False,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
):
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    # init input tensors if not passed
    if token_type_ids is None:
        token_type_ids = jnp.ones_like(input_ids)
    if position_ids is None:
        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if head_mask is None:
        head_mask = jnp.ones((self.config.num_hidden_layers, self.config.num_attention_heads))

    # Handle any PRNG if needed
    rngs = {}
    if dropout_rng is not None:
        rngs["dropout"] = dropout_rng

    return self.module.apply(
        {"params": params or self.params},
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(token_type_ids, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        jnp.array(head_mask, dtype="i4"),
        not train,
        output_attentions,
        output_hidden_states,
        return_dict,
        rngs=rngs,
    )
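# Usage sketch, assuming this __call__ belongs to a FlaxBertModel-style wrapper
# from transformers: only input_ids is required, and the attention mask, token
# type ids, position ids and head mask fall back to the defaults built above.
from transformers import BertTokenizerFast, FlaxBertModel

tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
model = FlaxBertModel.from_pretrained("bert-base-uncased")
inputs = tokenizer("A single example sentence.", return_tensors="np")
outputs = model(**inputs)  # omitted arguments take their defaults
last_hidden_state = outputs.last_hidden_state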