def get_mfvi_model_fn(net_fn, params, net_state, seed=0, sigma_init=0.):
    """Convert model, parameters and net state to use MFVI.

    Convert the model to fit a Gaussian distribution to each of the weights
    following the Mean Field Variational Inference (MFVI) procedure.

    Args:
      net_fn: neural network function, called here as
        `net_fn(params, net_state, None, batch, is_training)` and expected to
        return `(predictions, new_net_state)`.
      params: parameters of the network; we initialize the mean in MFVI with
        params.
      net_state: state of the network.
      seed: random seed; used for generating random samples when computing MFVI
        predictions (default: 0).
      sigma_init: initial value of the standard deviation of the per-parameter
        Gaussians.

    Returns:
      A tuple `(mfvi_apply_fn, mfvi_apply_mean_fn, sample_parms_fn, mfvi_params,
      mfvi_state)`. `mfvi_params` is `{"mean": ..., "inv_softplus_std": ...}`
      and `mfvi_state` is `{"net_state": ..., "mfvi_key": ...}`.
    """
    # `jax.tree_multimap` has been removed from JAX and the top-level
    # `jax.tree_map` alias is deprecated; `jax.tree_util.tree_map` accepts one
    # or several trees and is available across JAX versions.
    tree_map = jax.tree_util.tree_map

    mean_params = tree_map(lambda p: p.copy(), params)
    # The standard deviation is stored through an inverse softplus so that the
    # softplus applied at sampling time recovers a non-negative value.
    sigma_isp = inv_softplus(sigma_init)
    std_params = tree_map(lambda p: jnp.ones_like(p) * sigma_isp, params)
    mfvi_params = {"mean": mean_params, "inv_softplus_std": std_params}
    mfvi_state = {
        "net_state": copy.deepcopy(net_state),
        "mfvi_key": jax.random.PRNGKey(seed)
    }

    def sample_parms_fn(params, state):
        """Draw one parameter sample `mean + noise * std`; returns a fresh key."""
        mean = params["mean"]
        std = tree_map(jax.nn.softplus, params["inv_softplus_std"])
        noise, new_key = tree_utils.normal_like_tree(mean, state["mfvi_key"])
        params_sampled = tree_map(lambda m, s, n: m + n * s, mean, std, noise)
        new_mfvi_state = {
            "net_state": copy.deepcopy(state["net_state"]),
            "mfvi_key": new_key
        }
        return params_sampled, new_mfvi_state

    def mfvi_apply_fn(params, state, _, batch, is_training):
        """Predict with parameters sampled from the variational posterior."""
        params_sampled, new_mfvi_state = sample_parms_fn(params, state)
        predictions, new_net_state = net_fn(
            params_sampled, state["net_state"], None, batch, is_training)
        # Keep the advanced PRNG key but replace the network state with the one
        # produced by this forward pass.
        new_mfvi_state = {
            "net_state": copy.deepcopy(new_net_state),
            "mfvi_key": new_mfvi_state["mfvi_key"]
        }
        return predictions, new_mfvi_state

    def mfvi_apply_mean_fn(params, state, _, batch, is_training):
        """Predict with the variational mean."""
        mean = params["mean"]
        predictions, new_net_state = net_fn(
            mean, state["net_state"], None, batch, is_training)
        # No sampling happened, so the PRNG key is passed through unchanged.
        new_mfvi_state = {
            "net_state": copy.deepcopy(new_net_state),
            "mfvi_key": state["mfvi_key"]
        }
        return predictions, new_mfvi_state

    return (mfvi_apply_fn, mfvi_apply_mean_fn, sample_parms_fn, mfvi_params,
            mfvi_state)
def test_top_k_fraction(
        self, top_k_fraction, scaled_advantages, expected_top_k_weights):
    """Compare `get_top_k_weights` (with unit base weights) to the expectation."""
    unit_weights = jnp.ones_like(scaled_advantages)
    actual_weights = mpo_ops.get_top_k_weights(
        top_k_fraction, unit_weights, scaled_advantages)
    np.testing.assert_allclose(actual_weights, expected_top_k_weights)
def ones_like(self, tensor):
    """Return an array of ones built by `jnp.ones_like` from `tensor`."""
    filled = jnp.ones_like(tensor)
    return filled
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting.

    Args:
        X: rank-3 array unpacked as `(n_stores, n_days, n_features + 1)`. The
            last channel `X[..., -1]` is the target; the remaining channels are
            the features. NaN targets are treated as unobserved, and NaN
            features are replaced through `jnp.nan_to_num`.

    Returns:
        The value of the `Site.days` sample site (a `GammaPoisson` observed on
        the NaN-cleaned targets).
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # the last channel holds the target, not a feature
    eps = 1e-12  # tiny mean used for unobserved entries

    # Features and days share dim=-1, so they are never entered together.
    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    # Global hyper-priors for the per-store dispersion parameters.
    disp_param_mu = numpyro.sample(Site.disp_param_mu, dist.Normal(loc=4.0, scale=1.0))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma, dist.HalfNormal(scale=1.0))

    with plate_stores:
        # Non-centred parameterisation: sample a base Normal and push it through
        # an affine transform; TransformReparam samples in the base space.
        with numpyro.handlers.reparam(
                config={Site.disp_params: TransformReparam()}):
            disp_params = numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1),
                    dist.transforms.AffineTransform(disp_param_mu, disp_param_sigma),
                ),
            )

    with plate_features:
        # Per-feature hyper-priors for the regression coefficients.
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)),
        )
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas,
            dist.HalfNormal(scale=2.0 * jnp.ones(n_features)))

        with plate_stores:
            # One coefficient per (store, feature), again non-centred.
            with numpyro.handlers.reparam(
                    config={Site.coefs: TransformReparam()}):
                coefs = numpyro.sample(
                    Site.coefs,
                    dist.TransformedDistribution(
                        dist.Normal(loc=jnp.zeros((n_stores, n_features)), scale=1.0),
                        dist.transforms.AffineTransform(coef_mus, coef_sigmas),
                    ),
                )

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # NaN (padded) features become 0
        # 1.0 where a target is present, 0.0 where it is NaN.
        is_observed = jnp.where(jnp.isnan(targets), jnp.zeros_like(targets), jnp.ones_like(targets))
        not_observed = 1 - is_observed
        # Log-linear mean for observed entries; unobserved entries fall back to
        # `eps` so the distribution parameters stay finite.
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2)) + not_observed * eps)
        # Rate is exp(-dispersion) when observed and 1 otherwise.
        betas = is_observed * jnp.exp(-disp_params) + not_observed
        alphas = means * betas
        # NaN targets are passed as 0 to `obs`.
        return numpyro.sample(Site.days, dist.GammaPoisson(alphas, betas), obs=jnp.nan_to_num(targets))
def g(x, y):
    """Sum `5 * cos(x) * sin(y)` over all elements and tile it to `x`'s shape."""
    integrand = 5. * np.cos(x) * np.sin(y)
    total = np.sum(integrand)
    # broadcast to map like pjit does
    return total * np.ones_like(x)
def value_and_grad(fn, args):
    """Evaluate `fn` and pull an all-ones cotangent back to `args`.

    Args:
      fn: callable `(args) -> (out, extra)`; `extra` is auxiliary output that
        is not differentiated.
      args: the single argument passed to `fn`.

    Returns:
      `(out, extra, grad)` where `grad` is the VJP of `out` with respect to
      `args`, seeded with ones shaped like `out`.
    """
    primal_out, pullback, aux = jax.vjp(fn, args, has_aux=True)
    seed = np.ones_like(primal_out)
    cotangents = pullback(seed)
    return primal_out, aux, cotangents[0]
def _mix_with_uniform(probs, epsilon):
    """Blend `probs` with the uniform distribution over its last axis.

    The result is `(1 - epsilon) * probs + epsilon * uniform`.
    """
    uniform = jnp.ones_like(probs) / probs.shape[-1]
    keep_fraction = 1 - epsilon
    return keep_fraction * probs + epsilon * uniform
molmassCO = molinfo.molmass('CO') mmw = 2.33 # mean molecular weight mmrH2 = 0.74 molmassH2 = molinfo.molmass('H2') vmrH2 = (mmrH2*mmw/molmassH2) # VMR Mp = 33.2 # fixing mass... # Loading the molecular database of CO and the CIA # In[8]: # reference pressure for a T-P model Pref = 1.0 # bar ONEARR = np.ones_like(Parr) ONEWAV = jnp.ones_like(nflux) # In[14]: smalla = 1.0 smalldiag = smalla**2*jnp.identity(NP) # Now we write the model, which is used in HMC-NUTS. # In[15]: def modelcov(t, tau, a):
def fun(x, y):
    """Return `x @ y` computed under a backend-pinned jit, plus ones like `x`."""

    def infun(x, y):
        return np.matmul(x, y)

    # Same as decorating `infun` with `@partial(api.jit, backend=inner)`.
    compiled = partial(api.jit, backend=inner)(infun)
    product = compiled(x, y)
    return product + np.ones_like(x)
def safe_sqrt(x, eps=1e-7):
    """Square root in which exact zeros are replaced by `eps` beforehand."""
    floor = jnp.ones_like(x) * eps
    nonzero_x = jnp.where(x == 0, floor, x)
    return jnp.sqrt(nonzero_x)
def general_loss_with_squared_residual(squared_x, alpha, scale):
    r"""General robust loss evaluated on an already-squared residual.

    Taking the squared residual directly avoids a sqrt followed by a square
    when many residuals are computed. This implements the
    rho(x, \alpha, c) function described in "A General and Adaptive Robust
    Loss Function", Jonathan T. Barron, https://arxiv.org/abs/1701.03077.

    Args:
      squared_x: The squared residual for which the loss is being computed. It
        can have any shape; alpha and scale are broadcast against it.
      alpha: The shape parameter of the loss (\alpha in the paper). More
        negative values give a more robust loss (outliers "cost" less), more
        positive values penalize outliers more heavily. Any value in
        [-infinity, infinity] is accepted. Notable special cases:
          alpha=-Infinity: Welsch/Leclerc loss.
          alpha=-2: Geman-McClure loss.
          alpha=0: Cauchy/Lorentzian loss.
          alpha=1: Charbonnier/pseudo-Huber loss.
          alpha=2: L2 loss.
      scale: The scale parameter of the loss. When |x| < scale the loss is an
        L2-like quadratic bowl; beyond that its shape depends on alpha.

    Returns:
      The loss for each element, multiplied by `scale`, broadcast to the shape
      of the inputs.
    """
    machine_eps = jnp.finfo(jnp.float32).eps

    # The scale-normalised squared residual feeds every branch below.
    z = squared_x / (scale**2)
    half_z = 0.5 * z

    # Closed forms for the special values of alpha.
    loss_two = half_z                         # alpha == 2
    loss_zero = log1p_safe(half_z)            # alpha == 0
    loss_neginf = -jnp.expm1(-0.5 * z)        # alpha == -infinity
    loss_posinf = expm1_safe(half_z)          # alpha == +infinity

    # General case. |2 - alpha| and |alpha| are clamped to at least machine
    # epsilon so that both are safe to divide by; alpha keeps its sign.
    beta_safe = jnp.maximum(machine_eps, jnp.abs(alpha - 2.))
    ones = jnp.ones_like(alpha)
    alpha_sign = jnp.where(jnp.greater_equal(alpha, 0.), ones, -ones)
    alpha_safe = alpha_sign * jnp.maximum(machine_eps, jnp.abs(alpha))
    loss_otherwise = (beta_safe / alpha_safe) * (
        jnp.power(z / beta_safe + 1., 0.5 * alpha) - 1.)

    # Pick the branch; later assignments take precedence over earlier ones,
    # which reproduces the priority -inf, 0, 2, +inf, general.
    loss = jnp.where(alpha == jnp.inf, loss_posinf, loss_otherwise)
    loss = jnp.where(alpha == 2, loss_two, loss)
    loss = jnp.where(alpha == 0, loss_zero, loss)
    loss = jnp.where(alpha == -jnp.inf, loss_neginf, loss)

    return scale * loss
def vmpo_loss(
    sample_log_probs: Array,
    advantages: Array,
    temperature_constraint: LagrangePenalty,
    kl_constraints: Sequence[Tuple[Array, LagrangePenalty]],
    projection_operator: Callable[[Numeric], Numeric] = functools.partial(
        jnp.clip, a_min=_EPSILON),
    restarting_weights: Optional[Array] = None,
    importance_weights: Optional[Array] = None,
    top_k_fraction: float = 0.5,
    policy_loss_weight: float = 1.0,
    temperature_loss_weight: float = 1.0,
    kl_loss_weight: float = 1.0,
    alpha_loss_weight: float = 1.0,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
) -> Tuple[Array, MpoOutputs]:
    """Calculates the V-MPO policy improvement loss.

    Note: This is a per-example loss which works on any shape inputs as long as
    they are consistent. The shape of the examples is written E* below.

    Args:
      sample_log_probs: Log probabilities of actions for each example. Shape E*.
      advantages: Advantages for the E-step. Shape E*.
      temperature_constraint: Lagrange constraint for the E-step temperature
        optimization.
      kl_constraints: Pairs of KL values and Lagrange penalties used to bound
        the KLs in the M-step. KLs are E* or [E*, A], where A is the action
        dimension for per-dimension KL constraints.
      projection_operator: Function projecting dual variables (temperature and
        KL constraint alphas) into the positive range.
      restarting_weights: Optional weights of shape E*; 0 marks the first step
        of a new episode, whose loss is ignored because the agent cannot
        influence it. Defaults to ones.
      importance_weights: Optional importance weights of shape E*. Defaults to
        ones.
      top_k_fraction: Fraction of samples to use in the E-step.
      policy_loss_weight: Weight for the policy loss.
      temperature_loss_weight: Weight for the temperature loss.
      kl_loss_weight: Weight for the KL loss.
      alpha_loss_weight: Weight for the alpha loss.
      axis_name: Optional axis name for `pmap`. If `None`, computations are
        performed locally on each device.
      use_stop_gradient: Whether or not to apply stop gradient.

    Returns:
      Per-example `loss` with the same shape E* as the array inputs, and an
      `MpoOutputs` holding the loss components, the normalized weights and the
      number of samples.
    """
    # Missing weights default to one everywhere.
    if restarting_weights is None:
        restarting_weights = jnp.ones_like(sample_log_probs)
    if importance_weights is None:
        importance_weights = jnp.ones_like(sample_log_probs)

    # Shape, rank and dtype validation.
    per_example_arrays = [
        advantages, sample_log_probs, restarting_weights, importance_weights]
    chex.assert_equal_shape(per_example_arrays)
    chex.assert_rank(temperature_constraint.epsilon, 0)
    chex.assert_type([
        sample_log_probs, advantages, restarting_weights, importance_weights,
        temperature_constraint.alpha, temperature_constraint.epsilon], float)

    for kl_value, kl_penalty in kl_constraints:
        chex.assert_rank(kl_penalty.epsilon, 0)
        chex.assert_type(
            [kl_value, kl_penalty.alpha, kl_penalty.epsilon], float)
        if not kl_penalty.per_dimension:
            chex.assert_equal_shape([kl_value, advantages])
        else:
            # Per-dimension KLs carry one extra trailing action axis.
            chex.assert_rank(kl_value, advantages.ndim + 1)
            chex.assert_equal_shape_prefix(
                [kl_value, advantages], advantages.ndim)

    # E-step: reweighting of the samples and the temperature loss.
    e_step = vmpo_compute_weights_and_temperature_loss(
        advantages, restarting_weights, importance_weights,
        temperature_constraint, projection_operator, top_k_fraction,
        axis_name=axis_name, use_stop_gradient=use_stop_gradient)
    temperature_loss, norm_weights, num_samples = e_step

    # M-step: supervised learning on the reweighted trajectories, with
    # additional KL constraints. The weights sum to 1, so they are scaled by
    # the number of examples to obtain a per-example policy loss whose mean can
    # be taken; `restarting_weights` are assumed to be included already.
    if not axis_name:
        num_examples = sample_log_probs.size
    else:
        num_examples = jax.lax.all_gather(
            sample_log_probs, axis_name=axis_name).size
    policy_loss = -sample_log_probs * norm_weights * num_examples

    kl_loss, alpha_loss = compute_parametric_kl_penalty_and_dual_loss(
        kl_constraints, projection_operator, use_stop_gradient)

    chex.assert_equal_shape([policy_loss, kl_loss, alpha_loss])

    # Weighted sum of the four components.
    loss = (policy_loss_weight * policy_loss +
            temperature_loss_weight * temperature_loss +
            kl_loss_weight * kl_loss +
            alpha_loss_weight * alpha_loss)

    outputs = MpoOutputs(
        temperature_loss=temperature_loss, policy_loss=policy_loss,
        kl_loss=kl_loss, alpha_loss=alpha_loss, normalized_weights=norm_weights,
        num_samples=num_samples)
    return loss, outputs
# that the mapping parameter c = (zf-z0)/(xf-x0) is also constant ## construct univariate tfc class: ***************************************************************** tfc = utfc(N + 1, nC, int(m + 1), basis=basis, x0=0, xf=xstep) x = tfc.x # !!! notice I am using N+1 for the number of points. this is because I will be using the last point # of a segment 'n' for the initial conditons of the 'n+1' segment H = tfc.H dH = tfc.dH H0 = H(x[0:1]) H0p = dH(x[0:1]) ## define tfc constrained expression and derivatives: ********************************************** # switching function phi1 = lambda x: np.ones_like(x) phi2 = lambda x: x # tfc constrained expression y = lambda x,xi,IC: np.dot(H(x),xi) + phi1(x)*(IC['y0'] - np.dot(H0,xi)) \ + phi2(x)*(IC['y0p'] - np.dot(H0p,xi)) # !!! notice here that the initial conditions are passed as a dictionary (i.e. IC['y0']) # this will be important so that the least-squares does not need to be re-JITed yp = egrad(y) ypp = egrad(yp) ## define the loss function: *********************************************************************** # yₓₓ + δ yₓ + α y + β y^3 - γ cos(ω x) = 0 L = jit(lambda xi,IC: ypp(x,xi,IC) + delta*yp(x,xi,IC) + alpha*y(x,xi,IC) + beta*y(x,xi,IC)**3 \ - gamma*np.cos(omega*x))
def model(data):
    """Normal model with a unit-scale prior on `mean` and subsampled likelihood.

    The prior is centred on `ref_params`; observations are drawn in minibatches
    of 100 rows through the "N" plate.
    """
    unit_scale = jnp.ones_like(ref_params)
    mean_prior = dist.Normal(ref_params, unit_scale)
    mean = numpyro.sample("mean", mean_prior)

    data_plate = numpyro.plate("N", data.shape[0], subsample_size=100, dim=-2)
    with data_plate as idx:
        likelihood = dist.Normal(mean, sigma)
        numpyro.sample("obs", likelihood, obs=data[idx])
def ones_like(self: TensorType) -> TensorType:
    """Return a new tensor of the same wrapper type, filled with ones."""
    wrapper_cls = type(self)
    filled_raw = np.ones_like(self.raw)
    return wrapper_cls(filled_raw)
def add_fxy_coord(X, Y, coefficient):
    """Append flattened X/Y coordinates and a matching constant coefficient.

    Extends the enclosing `X_fxy_agg`, `Y_fxy_agg` and `fxy_coefficient` lists
    by one entry each.
    """
    X_fxy_agg.append(jnp.ravel(X))
    flat_y = jnp.ravel(Y)
    Y_fxy_agg.append(flat_y)
    # One coefficient value per flattened Y entry.
    constant = coefficient * jnp.ones_like(flat_y)
    fxy_coefficient.append(constant)
def __call__(self, x, t):
    """Prepend a constant channel holding `t` to `x` and apply the layer."""
    first_channel = x[:, :, :, :1]
    time_channel = jnp.ones_like(first_channel) * t
    stacked = jnp.concatenate([time_channel, x], axis=-1)
    return self._layer(stacked)
def encode(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    position_ids: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Example:

    ```python
    >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer

    >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
    >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

    >>> tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

    >>> text = "My friends are cool but they eat too many carbs."
    >>> input_ids = tokenizer.encode(text, return_tensors="np")
    >>> encoder_outputs = model.encode(input_ids)
    ```"""
    # Unset output flags fall back to the model configuration.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states
    if return_dict is None:
        return_dict = self.config.return_dict

    # Default mask attends everywhere; default positions are 0..seq_len-1 per row.
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if position_ids is None:
        batch_size, sequence_length = input_ids.shape
        position_row = jnp.arange(sequence_length)[None, :]
        position_ids = jnp.broadcast_to(position_row, (batch_size, sequence_length))

    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
        encode_module = module._get_encoder_module()
        return encode_module(input_ids, attention_mask, position_ids, **kwargs)

    variables = {"params": params or self.params}
    outputs = self.module.apply(
        variables,
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        position_ids=jnp.array(position_ids, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        method=_encoder_forward,
    )

    if not return_dict:
        return outputs

    # Re-wrap so that only the encoder fields are exposed.
    return FlaxBaseModelOutput(
        last_hidden_state=outputs.last_hidden_state,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
def get_text_features(
    self,
    input_ids,
    attention_mask=None,
    position_ids=None,
    params: dict = None,
    dropout_rng: jax.random.PRNGKey = None,
    train=False,
):
    r"""
    Args:
        input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
            provide it.

            Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)

    Returns:
        text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying
        the projection layer to the pooled output of [`FlaxCLIPTextModel`].

    Examples:

    ```python
    >>> from transformers import CLIPTokenizer, FlaxCLIPModel

    >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

    >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
    >>> text_features = model.get_text_features(**inputs)
    ```"""
    # Default positions count 0..seq_len-1 along the last axis of the ids.
    if position_ids is None:
        sequence_length = jnp.atleast_2d(input_ids).shape[-1]
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length), input_ids.shape)

    # Default mask attends to every token.
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)

    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    def _get_features(module, input_ids, attention_mask, position_ids, deterministic):
        text_outputs = module.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            deterministic=deterministic,
        )
        # Index 1 of the text model output is used as the pooled representation.
        pooled_output = text_outputs[1]
        return module.text_projection(pooled_output)

    variables = {"params": params or self.params}
    deterministic = not train
    return self.module.apply(
        variables,
        jnp.array(input_ids, dtype="i4"),
        jnp.array(attention_mask, dtype="i4"),
        jnp.array(position_ids, dtype="i4"),
        deterministic,
        method=_get_features,
        rngs=rngs,
    )
def __call__(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_input_ids: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    position_ids: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    r"""
    Returns:

    Examples:

    ```python
    >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer

    >>> # load a fine-tuned bert2gpt2 model
    >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
    >>> # load input & output tokenizer
    >>> tokenizer_input = BertTokenizer.from_pretrained("bert-base-cased")
    >>> tokenizer_output = GPT2Tokenizer.from_pretrained("gpt2")

    >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
    >>> singing a racist chant. SAE's national chapter suspended the students,
    >>> but University of Oklahoma President David Boren took it a step further,
    >>> saying the university's affiliation with the fraternity is permanently done.'''

    >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids

    >>> # use GPT2's eos_token as the pad as well as eos token
    >>> model.config.eos_token_id = model.config.decoder.eos_token_id
    >>> model.config.pad_token_id = model.config.eos_token_id

    >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences

    >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
    >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
    ```
    """
    # Unset output flags fall back to the model configuration.
    output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
    output_hidden_states = (
        output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
    )
    return_dict = return_dict if return_dict is not None else self.config.return_dict

    # prepare encoder inputs: attend everywhere and count positions 0..seq_len-1 by default
    if attention_mask is None:
        attention_mask = jnp.ones_like(input_ids)
    if position_ids is None:
        batch_size, sequence_length = input_ids.shape
        position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

    # prepare decoder inputs
    if decoder_input_ids is None:
        # The message previously named `decoder_position_ids`, which is optional
        # and defaulted below; the argument that is actually required is
        # `decoder_input_ids`.
        raise ValueError(
            "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_input_ids` must be specified as an input argument."
        )
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones_like(decoder_input_ids)
    if decoder_position_ids is None:
        batch_size, sequence_length = decoder_input_ids.shape
        decoder_position_ids = jnp.broadcast_to(
            jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
        )

    # Handle any PRNG if needed
    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

    # All integer inputs are cast to int32 before being handed to the module.
    return self.module.apply(
        {"params": params or self.params},
        input_ids=jnp.array(input_ids, dtype="i4"),
        attention_mask=jnp.array(attention_mask, dtype="i4"),
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        position_ids=jnp.array(position_ids, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
    )
def kl_to_standard_normal_fn(mu: Array, sigma: Array = sigma):
    """KL divergence from a diagonal Gaussian (mu, sigma) to a standard normal.

    The variance `sigma**2` is clipped to [1e-6, 1e6] before use.
    """
    variance = jnp.clip(sigma**2, 1e-6, 1e6)
    trace_term = jnp.sum(variance)
    mean_term = jnp.sum(mu**2)
    # Summing ones counts the number of entries in `mu`.
    dim_term = jnp.sum(jnp.ones_like(mu))
    log_det_term = jnp.sum(jnp.log(variance))
    return 0.5 * (trace_term + mean_term - dim_term - log_det_term)
def apply(self,
          inputs,
          vocab_size,
          emb_dim=512,
          num_heads=8,
          num_layers=6,
          qkv_dim=512,
          mlp_dim=2048,
          max_len=2048,
          train=False,
          dropout_rate=0.1,
          attention_dropout_rate=0.1,
          causal=True,
          cache=None,
          positional_encoding_module=AddLearnedPositionalEncodings,
          self_attention_module=nn.SelfAttention,
          attention_fn=None,
          pad_token=None,
          output_head='logits'):
    """Applies Transformer model on the inputs.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      vocab_size: An int with the size of the vocabulary.
      emb_dim: An int with the token embedding dimension.
      num_heads: An int with the number of attention heads.
      num_layers: An int with the number of transformer encoder layers.
      qkv_dim: An int with the dimension of the query/key/value vectors.
      mlp_dim: An int with the inner dimension of the feed-forward network which
        follows the attention block.
      max_len: An int with the maximum training sequence length.
      train: A bool denoting whether we are currently training.
      dropout_rate: A float with the dropout rate.
      attention_dropout_rate: A float with a dropout rate for attention weights.
      causal: Whether to apply causal masking.
      cache: Cache for decoding.
      positional_encoding_module: A module used for adding positional encodings.
      self_attention_module: Self attention module.
      attention_fn: Method to use in place of dot product attention.
      pad_token: Token to ignore in attention.
      output_head: String or iterable over strings containing the model's output
        head(s) to return. Available heads are 'input_emb', 'layer_<i>' for
        each layer index, 'output_emb', and, when requested, 'logits' and
        'regression'.

    Returns:
      Output of a transformer decoder. If output_head is a string, we return a
      single output head output; if output_head is a tuple or list, we return a
      dict with (output head name, output head output) key-value pairs.

    Raises:
      ValueError: If `inputs` is neither 2- nor 3-dimensional.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
        raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    if inputs.ndim == 3:
        # Distribution inputs carry no pad token, so nothing is masked.
        padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
    elif pad_token is None:
        padding_mask = jnp.ones_like(inputs)
    else:
        # Mask out padding tokens.
        padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

    # Every intermediate representation is recorded here under a head name.
    heads = dict()
    x = inputs
    if inputs.ndim == 2:
        x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

    # Only the default learned encodings receive the cache and the sinusoidal
    # initialiser; any other module is called with `max_len` alone.
    if positional_encoding_module == AddLearnedPositionalEncodings:
        x = positional_encoding_module(
            x,
            max_len=max_len,
            cache=cache,
            posemb_init=sinusoidal_init(max_len=max_len))
    else:
        x = positional_encoding_module(x, max_len=max_len)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    heads['input_emb'] = x
    for i in range(num_layers):
        x = Transformer1DBlock(
            x,
            qkv_dim=qkv_dim,
            mlp_dim=mlp_dim,
            num_heads=num_heads,
            causal_mask=causal,
            padding_mask=padding_mask,
            dropout_rate=dropout_rate,
            attention_dropout_rate=attention_dropout_rate,
            self_attention_module=self_attention_module,
            deterministic=not train,
            attention_fn=attention_fn,
            cache=cache,
        )
        heads['layer_%s' % i] = x
    x = nn.LayerNorm(x)
    heads['output_emb'] = x * padding_mask  # Zero out PAD positions.
    # NOTE: the heads below read the unmasked `x`, not heads['output_emb'].
    # `in` is a substring test when output_head is a str and a membership test
    # when it is a tuple or list.
    if 'logits' in output_head:
        logits = nn.Dense(
            x,
            vocab_size,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))
        heads['logits'] = logits

    if 'regression' in output_head:
        regression = nn.Dense(
            x,
            1,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6))
        # Drop the trailing singleton output dimension.
        regression = jnp.squeeze(regression, axis=-1)
        heads['regression'] = regression

    if isinstance(output_head, (tuple, list)):
        return {head: heads[head] for head in output_head}
    return heads[output_head]
def splitjvp(x):
    """Linearize `f` at `x` and push an all-ones tangent through the result."""
    primal_out, push_forward = linearize(f, x)
    del primal_out  # only the linear map is needed
    tangent = np.ones_like(x)
    return push_forward(tangent)
def _reset_head_kernel(params, value):
    """Return frozen params whose `["head"]["kernel"]` is filled with `value`."""
    mutable = flax.core.unfreeze(params)
    head = mutable["head"]
    # Same shape and dtype as the existing kernel, every entry set to `value`.
    head["kernel"] = value * jnp.ones_like(head["kernel"])
    return flax.core.freeze(mutable)
def vmpo_e_step_without_restarting_or_importance_weights(advantages, **kwargs):
    """Run the V-MPO E-step with unit restarting and importance weights."""
    unit_weights = dict(
        restarting_weights=jnp.ones_like(advantages),
        importance_weights=jnp.ones_like(advantages),
    )
    return mpo_ops.vmpo_compute_weights_and_temperature_loss(
        advantages=advantages, **unit_weights, **kwargs)
class JaxBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for JAX arrays.

    The wrapped value is converted with ``jnp.asarray`` on construction and is
    read back through ``self.data``. For more details, please refer to the
    :class:`~.TensorBox` documentation.
    """

    # Element-wise and shape operations, each forwarding ``self.data`` to the
    # matching ``jax.numpy`` function and passing the result through
    # ``wrap_output``.
    abs = wrap_output(lambda self: jnp.abs(self.data))
    angle = wrap_output(lambda self: jnp.angle(self.data))
    arcsin = wrap_output(lambda self: jnp.arcsin(self.data))
    cast = wrap_output(lambda self, dtype: jnp.array(self.data, dtype=dtype))
    expand_dims = wrap_output(
        lambda self, axis: jnp.expand_dims(self.data, axis=axis))
    ones_like = wrap_output(lambda self: jnp.ones_like(self.data))
    sqrt = wrap_output(lambda self: jnp.sqrt(self.data))
    sum = wrap_output(lambda self, axis=None, keepdims=False: jnp.sum(
        self.data, axis=axis, keepdims=keepdims))
    T = wrap_output(lambda self: self.data.T)
    # Out-of-range indices wrap around (mode="wrap").
    take = wrap_output(lambda self, indices, axis=None: jnp.take(
        self.data, indices, axis=axis, mode="wrap"))

    def __init__(self, tensor):
        """Convert ``tensor`` to a JAX array and hand it to the base class."""
        tensor = jnp.asarray(tensor)
        super().__init__(tensor)

    @staticmethod
    def astensor(tensor):
        """Convert ``tensor`` to a JAX array."""
        return jnp.asarray(tensor)

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        """Concatenate the unboxed ``values`` along ``axis``."""
        return jnp.concatenate(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        """Dot product with explicit scalar-scalar and matrix-matrix cases."""
        x, y = JaxBox.unbox_list([x, y])
        x = jnp.asarray(x)
        y = jnp.asarray(y)

        # Two scalars: plain multiplication.
        if x.ndim == 0 and y.ndim == 0:
            return x * y

        # Two matrices: matrix product.
        if x.ndim == 2 and y.ndim == 2:
            return x @ y

        return jnp.dot(x, y)

    @property
    def interface(self):
        """Name of the interface, always ``"jax"``."""
        return "jax"

    def numpy(self):
        """Return the wrapped data as stored (no conversion is performed here)."""
        return self.data

    @property
    def requires_grad(self):
        """Always ``True`` for this box."""
        return True

    @property
    def shape(self):
        """Shape of the wrapped array."""
        return self.data.shape

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        """Stack the unboxed ``values`` along a new ``axis``."""
        return jnp.stack(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        """Select from unboxed ``x`` where ``condition`` holds, else from ``y``."""
        return jnp.where(condition, *JaxBox.unbox_list([x, y]))
def collapse_and_remove_blanks(labels: jnp.ndarray,
                               seq_length: jnp.ndarray,
                               blank_id: int = 0):
    """Merge repeated labels into single labels and remove the designated blank symbol.

    Args:
      labels: Array of shape (batch, seq_length)
      seq_length: Array of shape (batch), sequence length of each batch element.
      blank_id: Optional id of the blank symbol

    Returns:
      A tuple of two arrays. The first has shape (batch, seq_length) and holds
      the labels with repeats collapsed and blanks removed, packed to the front
      of each row and zero-padded, eg:
      [[A, A, B, B, A],
       [A, B, C, D, E]] => [[A, B, A, 0, 0],
                            [A, B, C, D, E]]
      The second is an int array of shape [batch] with the new sequence lengths,
      cast to the dtype of `seq_length`.

    NOTE(review): blanks are rewritten to 0 and every 0 is later treated as
    "removed", so a genuine label id of 0 would also be dropped when
    `blank_id != 0` — confirm callers never rely on that.
    """
    b, t = labels.shape
    # Zap out blank: positions equal to `blank_id` become 0.
    blank_mask = 1 - jnp.equal(labels, blank_id)
    labels = (labels * blank_mask).astype(labels.dtype)

    # Mask labels that don't equal previous label. The first column is always
    # kept since it has no predecessor.
    label_mask = jnp.concatenate([
        jnp.ones_like(labels[:, :1], dtype=jnp.int32),
        jnp.not_equal(labels[:, 1:], labels[:, :-1])
    ], axis=1)

    # Filter labels that aren't in the original sequence.
    maxlen = labels.shape[1]
    seq_mask = sequence_mask(seq_length, maxlen=maxlen)
    label_mask = label_mask * seq_mask

    # Remove repetitions from the labels; dropped positions become 0.
    ulabels = label_mask * labels

    # Count masks for new sequence lengths.
    label_mask = jnp.not_equal(ulabels, 0).astype(labels.dtype)
    new_seq_len = jnp.sum(label_mask, axis=1)

    # Mask indexes based on sequence length mask.
    new_maxlen = maxlen
    idx_mask = sequence_mask(new_seq_len, maxlen=new_maxlen)

    # Flatten everything and mask out labels to keep and sparse indices.
    flat_labels = jnp.reshape(ulabels, [-1])
    flat_idx_mask = jnp.reshape(idx_mask, [-1])

    # `size=b * t` gives both index arrays a static length; `indices` are the
    # destination slots and `values` the source positions of surviving labels.
    indices = jnp.nonzero(flat_idx_mask, size=b * t)[0]
    values = jnp.nonzero(flat_labels, size=b * t)[0]
    updates = jnp.take_along_axis(flat_labels, values, axis=-1)

    # Scatter to flat shape.
    flat = jnp.zeros(flat_idx_mask.shape).astype(labels.dtype)
    flat = flat.at[indices].set(updates)
    # 0'th position in the flat array gets clobbered by later padded updates,
    # so reset it here to its original value
    flat = flat.at[0].set(updates[0])

    # Reshape back to square batch.
    batch_size = labels.shape[0]
    new_shape = [batch_size, new_maxlen]
    return (jnp.reshape(flat, new_shape).astype(labels.dtype),
            new_seq_len.astype(seq_length.dtype))
def fit_STC(self, prewhiten=False, n_repeats=10, percentile=100., random_seed=2046, verbose=5):
    """
    Spike-triggered Covariance Analysis.

    Results are stored in the dictionary `self.w_stc` (keys `eigvec`, `eigval`,
    `eigval_mask`, and, when the significance test is run, `pos`, `neg`,
    `eigval_pos_mask`, `eigval_neg_mask`, `max_null`, `min_null`).

    Parameters
    ==========

    prewhiten: bool
        If True, the stimulus is whitened with `self.XtX` and the MLE weights
        are used as the reference direction; `self.XtX` and `self.w_mle` are
        computed here when `self.compute_mle is False`. Otherwise `self.X` and
        `self.w_sta` are used.

    n_repeats: int
        Number of repeats for STC significance test. A falsy value skips the
        test and marks every eigenvalue as significant.

    percentile: float
        Valid range of STC significance test. Eigenvalues above the
        `percentile` or below the `100 - percentile` percentile of the null
        distribution are flagged.

    verbose: int
        Progress is printed every `verbose` repeats; 0 silences it.

    random_seed: int
        Seed of the PRNG key used to shuffle the response.
    """

    def get_stc(_X, _y, _w):
        """Singular vectors/values of the covariance of spike-triggered stimuli."""
        n = len(_X)
        ste = _X[_y != 0]  # rows of the stimulus where the response is non-zero
        # NOTE(review): this is an element-wise product, not a matrix
        # projection onto `_w` — confirm this is the intended operation.
        proj = ste - ste * _w * _w.T
        stc = proj.T @ proj / (n - 1)
        _eigvec, _eigval, _ = jnp.linalg.svd(stc)
        return _eigvec, _eigval

    key = random.PRNGKey(random_seed)

    y = self.y

    if prewhiten:
        if self.compute_mle is False:
            self.XtX = self.X.T @ self.X
            self.w_mle = jnp.linalg.solve(self.XtX, self.XtY)

        X = jnp.linalg.solve(self.XtX, self.X.T).T
        w = uvec(self.w_mle)
    else:
        X = self.X
        w = uvec(self.w_sta)

    eigvec, eigval = get_stc(X, y, w)

    self.w_stc = dict()
    if n_repeats:
        print('STC significance test: ')
        eigval_null = []
        for counter in range(n_repeats):
            if verbose and counter % int(verbose) == 0:
                print(f' {counter + 1:}/{n_repeats}')

            # Split the key on every repeat. Reusing one key would produce the
            # very same permutation each time, collapsing the null
            # distribution to a single sample.
            key, subkey = random.split(key)
            y_randomize = random.permutation(subkey, y)
            _, eigval_randomize = get_stc(X, y_randomize, w)
            eigval_null.append(eigval_randomize)

        if verbose:
            print(f'Done.')

        eigval_null = jnp.vstack(eigval_null)
        max_null = jnp.percentile(eigval_null, percentile)
        min_null = jnp.percentile(eigval_null, 100 - percentile)
        mask_sig_pos = eigval > max_null
        mask_sig_neg = eigval < min_null
        mask_sig = jnp.logical_or(mask_sig_pos, mask_sig_neg)

        self.w_stc['eigvec'] = eigvec
        self.w_stc['pos'] = eigvec[:, mask_sig_pos]
        self.w_stc['neg'] = eigvec[:, mask_sig_neg]

        self.w_stc['eigval'] = eigval
        self.w_stc['eigval_mask'] = mask_sig
        self.w_stc['eigval_pos_mask'] = mask_sig_pos
        self.w_stc['eigval_neg_mask'] = mask_sig_neg

        self.w_stc['max_null'] = max_null
        self.w_stc['min_null'] = min_null

    else:
        # Without a null distribution every component is kept.
        self.w_stc['eigvec'] = eigvec
        self.w_stc['eigval'] = eigval
        self.w_stc['eigval_mask'] = jnp.ones_like(eigval).astype(bool)
def init(x0):
    """Return the initial state `(x0, ones shaped like x0)`.

    The second element is the starting running average of squared gradients.
    """
    return x0, np.ones_like(x0)
def vector_DP_inv_pd(v):
    """Return `exp(v - 1)` element-wise."""
    shifted = v - np.ones_like(v)
    return np.exp(shifted)