Example #1
def get_mfvi_model_fn(net_fn, params, net_state, seed=0, sigma_init=0.):
    """Convert model, parameters and net state to use MFVI.

  Convert the model to fit a Gaussian distribution to each of the weights
  following the Mean Field Variational Inference (MFVI) procedure.

  Args:
    net_fn: neural network function.
    params: parameters of the network; we intialize the mean in MFVI with
      params.
    net_state: state of the network.
    seed: random seed; used for generating random samples when computing MFVI
      predictions (default: 0).
    sigma_init: initial value of the standard deviation of the per-prarameter
      Gaussians.
  """
    # net_fn is expected to have the signature:
    #   net_fn(params, net_state, rng, batch, is_training)
    mean_params = jax.tree_map(lambda p: p.copy(), params)
    sigma_isp = inv_softplus(sigma_init)
    std_params = jax.tree_map(lambda p: jnp.ones_like(p) * sigma_isp, params)
    mfvi_params = {"mean": mean_params, "inv_softplus_std": std_params}
    mfvi_state = {
        "net_state": copy.deepcopy(net_state),
        "mfvi_key": jax.random.PRNGKey(seed)
    }

    def sample_parms_fn(params, state):
        mean = params["mean"]
        std = jax.tree_map(jax.nn.softplus, params["inv_softplus_std"])
        noise, new_key = tree_utils.normal_like_tree(mean, state["mfvi_key"])
        params_sampled = jax.tree_multimap(lambda m, s, n: m + n * s, mean,
                                           std, noise)
        new_mfvi_state = {
            "net_state": copy.deepcopy(state["net_state"]),
            "mfvi_key": new_key
        }
        return params_sampled, new_mfvi_state

    def mfvi_apply_fn(params, state, _, batch, is_training):
        params_sampled, new_mfvi_state = sample_parms_fn(params, state)
        predictions, new_net_state = net_fn(params_sampled, state["net_state"],
                                            None, batch, is_training)
        new_mfvi_state = {
            "net_state": copy.deepcopy(new_net_state),
            "mfvi_key": new_mfvi_state["mfvi_key"]
        }
        return predictions, new_mfvi_state

    def mfvi_apply_mean_fn(params, state, _, batch, is_training):
        """Predict with the variational mean."""
        mean = params["mean"]
        predictions, new_net_state = net_fn(mean, state["net_state"], None,
                                            batch, is_training)
        new_mfvi_state = {
            "net_state": copy.deepcopy(new_net_state),
            "mfvi_key": state["mfvi_key"]
        }
        return predictions, new_mfvi_state

    return (mfvi_apply_fn, mfvi_apply_mean_fn, sample_parms_fn, mfvi_params,
            mfvi_state)
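
A minimal usage sketch of the returned functions (added for illustration; `net_fn`, `params`, `net_state`, and `batch` are assumed placeholders from the surrounding training code, not part of the original example):

# Hypothetical setup: net_fn follows the signature noted in the comment above.
apply_fn, apply_mean_fn, sample_fn, mfvi_params, mfvi_state = get_mfvi_model_fn(
    net_fn, params, net_state, seed=0, sigma_init=0.01)

# Stochastic prediction: weights are drawn from the per-parameter Gaussians and
# the PRNG key stored in mfvi_state is advanced.
preds, mfvi_state = apply_fn(mfvi_params, mfvi_state, None, batch, is_training=False)

# Deterministic prediction using only the variational means.
mean_preds, _ = apply_mean_fn(mfvi_params, mfvi_state, None, batch, is_training=False)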
Example #2
 def test_top_k_fraction(
     self, top_k_fraction, scaled_advantages, expected_top_k_weights):
   """Test that only the top k fraction are used."""
   top_k_weights = mpo_ops.get_top_k_weights(
       top_k_fraction, jnp.ones_like(scaled_advantages), scaled_advantages)
   np.testing.assert_allclose(top_k_weights, expected_top_k_weights)
Example #3
 def ones_like(self, tensor):
     return jnp.ones_like(tensor)
Example #4
def model(X: DeviceArray) -> DeviceArray:
    """Gamma-Poisson hierarchical model for daily sales forecasting

    Args:
        X: input data

    Returns:
        output data
    """
    n_stores, n_days, n_features = X.shape
    n_features -= 1  # remove one dim for target
    eps = 1e-12  # epsilon

    plate_features = numpyro.plate(Plate.features, n_features, dim=-1)
    plate_stores = numpyro.plate(Plate.stores, n_stores, dim=-2)
    plate_days = numpyro.plate(Plate.days, n_days, dim=-1)

    disp_param_mu = numpyro.sample(Site.disp_param_mu,
                                   dist.Normal(loc=4.0, scale=1.0))
    disp_param_sigma = numpyro.sample(Site.disp_param_sigma,
                                      dist.HalfNormal(scale=1.0))

    with plate_stores:
        with numpyro.handlers.reparam(
                config={Site.disp_params: TransformReparam()}):
            disp_params = numpyro.sample(
                Site.disp_params,
                dist.TransformedDistribution(
                    dist.Normal(loc=jnp.zeros((n_stores, 1)), scale=0.1),
                    dist.transforms.AffineTransform(disp_param_mu,
                                                    disp_param_sigma),
                ),
            )

    with plate_features:
        coef_mus = numpyro.sample(
            Site.coef_mus,
            dist.Normal(loc=jnp.zeros(n_features), scale=jnp.ones(n_features)),
        )
        coef_sigmas = numpyro.sample(
            Site.coef_sigmas,
            dist.HalfNormal(scale=2.0 * jnp.ones(n_features)))

        with plate_stores:
            with numpyro.handlers.reparam(
                    config={Site.coefs: TransformReparam()}):
                coefs = numpyro.sample(
                    Site.coefs,
                    dist.TransformedDistribution(
                        dist.Normal(loc=jnp.zeros((n_stores, n_features)),
                                    scale=1.0),
                        dist.transforms.AffineTransform(coef_mus, coef_sigmas),
                    ),
                )

    with plate_days, plate_stores:
        targets = X[..., -1]
        features = jnp.nan_to_num(X[..., :-1])  # padded features to 0
        is_observed = jnp.where(jnp.isnan(targets), jnp.zeros_like(targets),
                                jnp.ones_like(targets))
        not_observed = 1 - is_observed
        means = (is_observed * jnp.exp(
            jnp.sum(jnp.expand_dims(coefs, axis=1) * features, axis=2)) +
                 not_observed * eps)

        betas = is_observed * jnp.exp(-disp_params) + not_observed
        alphas = means * betas
        return numpyro.sample(Site.days,
                              dist.GammaPoisson(alphas, betas),
                              obs=jnp.nan_to_num(targets))
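
A hedged inference sketch for the model above (added; the `rng_key`, sampler settings, and toy `X_train` array are illustrative, and the module's `Plate`/`Site` constants are assumed to be in scope):

import jax
import jax.numpy as jnp
from numpyro.infer import MCMC, NUTS

# Toy stand-in data: 3 stores, 10 days, 4 features plus the target column.
X_train = jnp.ones((3, 10, 5))

rng_key = jax.random.PRNGKey(0)
mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(rng_key, X_train)
posterior_samples = mcmc.get_samples()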
Example #5
 def g(x, y):
   tot = np.sum(5. * np.cos(x) * np.sin(y))
   return tot * np.ones_like(x)  # broadcast to map like pjit does
Example #6
def value_and_grad(fn, args):
    """Given `fn: (args) -> out, extra`, returns `dout/dargs`."""
    output, vjp_fn, extra = jax.vjp(fn, args, has_aux=True)
    grad = vjp_fn(np.ones_like(output))[0]
    return output, extra, grad
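
A hedged check of the wrapper above (the toy `loss_fn` and input are illustrative additions; for a scalar output it should agree with `jax.value_and_grad(..., has_aux=True)`):

import jax
import jax.numpy as jnp

def loss_fn(w):
    # Scalar loss plus auxiliary data, matching `fn: (args) -> (out, extra)`.
    return jnp.sum(w ** 2), {"l1": jnp.sum(jnp.abs(w))}

w = jnp.arange(3.0)
out, extra, grad = value_and_grad(loss_fn, w)
(ref_out, ref_extra), ref_grad = jax.value_and_grad(loss_fn, has_aux=True)(w)
# out == ref_out == 5.0 and grad == ref_grad == [0., 2., 4.]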
Example #7
def _mix_with_uniform(probs, epsilon):
    """Mix an arbitrary categorical distribution with a uniform distribution."""
    num_actions = probs.shape[-1]
    uniform_probs = jnp.ones_like(probs) / num_actions
    return (1 - epsilon) * probs + epsilon * uniform_probs
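
A quick worked example of the mixture (illustrative numbers, not from the original source):

import jax.numpy as jnp

probs = jnp.array([1.0, 0.0, 0.0, 0.0])
mixed = _mix_with_uniform(probs, epsilon=0.1)
# (1 - 0.1) * probs + 0.1 * 0.25 per action -> [0.925, 0.025, 0.025, 0.025]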
Example #8
molmassCO = molinfo.molmass('CO')
mmw = 2.33  # mean molecular weight
mmrH2 = 0.74
molmassH2 = molinfo.molmass('H2')
vmrH2 = (mmrH2*mmw/molmassH2)  # VMR

Mp = 33.2  # fixing mass...

# Loading the molecular database of CO and the CIA
# In[8]:


# reference pressure for a T-P model
Pref = 1.0  # bar
ONEARR = np.ones_like(Parr)
ONEWAV = jnp.ones_like(nflux)


# In[14]:


smalla = 1.0
smalldiag = smalla**2*jnp.identity(NP)


# Now we write the model, which is used in HMC-NUTS.

# In[15]:


def modelcov(t, tau, a):
Example #9
 def fun(x, y):
     @partial(api.jit, backend=inner)
     def infun(x, y):
         return np.matmul(x, y)
     return infun(x, y) + np.ones_like(x)
Example #10
def safe_sqrt(x, eps=1e-7):
    safe_x = jnp.where(x == 0, jnp.ones_like(x) * eps, x)
    return jnp.sqrt(safe_x)
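
An added illustration of why the `jnp.where` guard matters: the gradient of `jnp.sqrt` is infinite at zero, while the guarded version stays finite because, at `x == 0`, the cotangent is routed to the constant `eps` branch.

import jax
import jax.numpy as jnp

print(jax.grad(jnp.sqrt)(0.0))   # inf
print(jax.grad(safe_sqrt)(0.0))  # 0.0: the selected eps branch carries no gradient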
Example #11
def general_loss_with_squared_residual(squared_x, alpha, scale):
    r"""The general loss that takes a squared residual.

  This fuses the sqrt operation done to compute many residuals while preserving
  the square in the loss formulation.

  This implements the rho(x, \alpha, c) function described in "A General and
  Adaptive Robust Loss Function", Jonathan T. Barron,
  https://arxiv.org/abs/1701.03077.

  Args:
    squared_x: The residual for which the loss is being computed. x can have
      any shape, and alpha and scale will be broadcasted to match x's shape if
      necessary.
    alpha: The shape parameter of the loss (\alpha in the paper), where more
      negative values produce a loss with more robust behavior (outliers "cost"
      less), and more positive values produce a loss with less robust behavior
      (outliers are penalized more heavily). Alpha can be any value in
      [-infinity, infinity], but the gradient of the loss with respect to alpha
      is 0 at -infinity, infinity, 0, and 2. Varying alpha allows for smooth
      interpolation between several discrete robust losses:
        alpha=-Infinity: Welsch/Leclerc Loss.
        alpha=-2: Geman-McClure loss.
        alpha=0: Cauchy/Lorentzian loss.
        alpha=1: Charbonnier/pseudo-Huber loss.
        alpha=2: L2 loss.
    scale: The scale parameter of the loss. When |x| < scale, the loss is an
      L2-like quadratic bowl, and when |x| > scale the loss function takes on a
      different shape according to alpha.

  Returns:
    The losses for each element of x, in the same shape as x.
  """
    eps = jnp.finfo(jnp.float32).eps

    # This will be used repeatedly.
    squared_scaled_x = squared_x / (scale**2)

    # The loss when alpha == 2.
    loss_two = 0.5 * squared_scaled_x
    # The loss when alpha == 0.
    loss_zero = log1p_safe(0.5 * squared_scaled_x)
    # The loss when alpha == -infinity.
    loss_neginf = -jnp.expm1(-0.5 * squared_scaled_x)
    # The loss when alpha == +infinity.
    loss_posinf = expm1_safe(0.5 * squared_scaled_x)

    # The loss when not in one of the above special cases.
    # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
    beta_safe = jnp.maximum(eps, jnp.abs(alpha - 2.))
    # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
    alpha_safe = jnp.where(jnp.greater_equal(alpha, 0.), jnp.ones_like(alpha),
                           -jnp.ones_like(alpha)) * jnp.maximum(
                               eps, jnp.abs(alpha))
    loss_otherwise = (beta_safe / alpha_safe) * (
        jnp.power(squared_scaled_x / beta_safe + 1., 0.5 * alpha) - 1.)

    # Select which of the cases of the loss to return.
    loss = jnp.where(
        alpha == -jnp.inf, loss_neginf,
        jnp.where(
            alpha == 0, loss_zero,
            jnp.where(alpha == 2, loss_two,
                      jnp.where(alpha == jnp.inf, loss_posinf,
                                loss_otherwise))))

    return scale * loss
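
A small sweep over the named special cases (illustrative addition; assumes the module's `log1p_safe`/`expm1_safe` helpers used above are in scope):

import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
for alpha in (-2.0, 0.0, 1.0, 2.0):  # Geman-McClure, Cauchy, Charbonnier, L2
    loss = general_loss_with_squared_residual(x**2, alpha=alpha, scale=1.0)
    print(alpha, loss)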
Example #12
def vmpo_loss(
    sample_log_probs: Array,
    advantages: Array,
    temperature_constraint: LagrangePenalty,
    kl_constraints: Sequence[Tuple[Array, LagrangePenalty]],
    projection_operator: Callable[[Numeric], Numeric] = functools.partial(
        jnp.clip, a_min=_EPSILON),
    restarting_weights: Optional[Array] = None,
    importance_weights: Optional[Array] = None,
    top_k_fraction: float = 0.5,
    policy_loss_weight: float = 1.0,
    temperature_loss_weight: float = 1.0,
    kl_loss_weight: float = 1.0,
    alpha_loss_weight: float = 1.0,
    axis_name: Optional[str] = None,
    use_stop_gradient: bool = True,
) -> Tuple[Array, MpoOutputs]:
  """Calculates the V-MPO policy improvement loss.

  Note: This is a per-example loss which works on any shape inputs as long as
  they are consistent. We denote the shape of the examples E* for ease of
  reference.

  Args:
    sample_log_probs: Log probabilities of actions for each example. Shape E*.
    advantages: Advantages for the E-step. Shape E*.
    temperature_constraint: Lagrange constraint for the E-step temperature
      optimization.
    kl_constraints: KL and variables for applying Lagrangian penalties to bound
      them in the M-step, KLs are E* or [E*, A]. Here A is the action dimension
      in the case of per-dimension KL constraints.
    projection_operator: Function to project dual variables (temperature and kl
      constraint alphas) into the positive range.
    restarting_weights: Optional restarting weights, shape E*, 0 means that this
      step is the start of a new episode and we ignore losses at this step
      because the agent cannot influence these.
    importance_weights: Optional importance weights, shape E*.
    top_k_fraction: Fraction of samples to use in the E-step.
    policy_loss_weight: Weight for the policy loss.
    temperature_loss_weight: Weight for the temperature loss.
    kl_loss_weight: Weight for the KL loss.
    alpha_loss_weight: Weight for the alpha loss.
    axis_name: Optional axis name for `pmap`. If `None`, computations
      are performed locally on each device.
    use_stop_gradient: bool indicating whether or not to apply stop gradient.

  Returns:
    Per-example `loss` with the same shape E* as the array inputs, and additional
    data including the components of this loss and the normalized weights in the
    `MpoOutputs` tuple.
  """
  # Define default restarting weights and importance weights.
  if restarting_weights is None:
    restarting_weights = jnp.ones_like(sample_log_probs)
  if importance_weights is None:
    importance_weights = jnp.ones_like(sample_log_probs)

  # Check shapes.
  chex.assert_equal_shape(
      [advantages, sample_log_probs, restarting_weights, importance_weights])

  chex.assert_rank(temperature_constraint.epsilon, 0)
  chex.assert_type([
      sample_log_probs, advantages, restarting_weights, importance_weights,
      temperature_constraint.alpha, temperature_constraint.epsilon], float)

  for kl, penalty in kl_constraints:
    chex.assert_rank(penalty.epsilon, 0)
    chex.assert_type([kl, penalty.alpha, penalty.epsilon], float)
    if penalty.per_dimension:
      chex.assert_rank(kl, advantages.ndim + 1)
      chex.assert_equal_shape_prefix([kl, advantages], advantages.ndim)
    else:
      chex.assert_equal_shape([kl, advantages])

  # E-step: Calculate the reweighting and the temperature loss.
  temperature_loss, norm_weights, num_samples = (
      vmpo_compute_weights_and_temperature_loss(
          advantages, restarting_weights, importance_weights,
          temperature_constraint, projection_operator, top_k_fraction,
          axis_name=axis_name, use_stop_gradient=use_stop_gradient))

  # M-step: Supervised learning of reweighted trajectories using the weights
  # from the E-step, with additional KL constraints.
  # The weights are normalized so that the sum is 1. We multiply by the number
  # of examples so that we can give a policy loss per example and take the mean,
  # and we assume `restarting_weights` are already included.
  if axis_name:
    num_examples = jax.lax.all_gather(
        sample_log_probs, axis_name=axis_name).size
  else:
    num_examples = sample_log_probs.size
  policy_loss = -sample_log_probs * norm_weights * num_examples

  kl_loss, alpha_loss = compute_parametric_kl_penalty_and_dual_loss(
      kl_constraints, projection_operator, use_stop_gradient)

  chex.assert_equal_shape([policy_loss, kl_loss, alpha_loss])

  # Calculate the total policy improvement loss.
  loss = (policy_loss_weight * policy_loss +
          temperature_loss_weight * temperature_loss +
          kl_loss_weight * kl_loss +
          alpha_loss_weight * alpha_loss)

  return loss, MpoOutputs(
      temperature_loss=temperature_loss, policy_loss=policy_loss,
      kl_loss=kl_loss, alpha_loss=alpha_loss, normalized_weights=norm_weights,
      num_samples=num_samples)
Example #13
#     that the mapping parameter c = (zf-z0)/(xf-x0) is also constant

## construct univariate tfc class: *****************************************************************
tfc = utfc(N + 1, nC, int(m + 1), basis=basis, x0=0, xf=xstep)
x = tfc.x
# !!! notice I am using N+1 for the number of points. This is because I will be using the last point
#     of a segment 'n' for the initial conditions of the 'n+1' segment

H = tfc.H
dH = tfc.dH
H0 = H(x[0:1])
H0p = dH(x[0:1])

## define tfc constrained expression and derivatives: **********************************************
# switching function
phi1 = lambda x: np.ones_like(x)
phi2 = lambda x: x

# tfc constrained expression
y = lambda x,xi,IC: np.dot(H(x),xi) + phi1(x)*(IC['y0']  - np.dot(H0,xi)) \
                                    + phi2(x)*(IC['y0p'] - np.dot(H0p,xi))
# !!! notice here that the initial conditions are passed as a dictionary (i.e. IC['y0'])
#     this will be important so that the least-squares does not need to be re-JITed

yp = egrad(y)
ypp = egrad(yp)

## define the loss function: ***********************************************************************
#   yₓₓ + δ yₓ + α y + β y^3 - γ cos(ω x) = 0
L = jit(lambda xi,IC: ypp(x,xi,IC) + delta*yp(x,xi,IC) + alpha*y(x,xi,IC) + beta*y(x,xi,IC)**3 \
                                   - gamma*np.cos(omega*x))
Example #14
 def model(data):
     mean = numpyro.sample(
         "mean", dist.Normal(ref_params, jnp.ones_like(ref_params)))
     with numpyro.plate("N", data.shape[0], subsample_size=100,
                        dim=-2) as idx:
         numpyro.sample("obs", dist.Normal(mean, sigma), obs=data[idx])
Example #15
File: jax.py Project: yibit/eagerpy
 def ones_like(self: TensorType) -> TensorType:
     return type(self)(np.ones_like(self.raw))
Example #16
 def add_fxy_coord(X, Y, coefficient):
     X_fxy_agg.append(jnp.ravel(X))
     Y_fxy_agg.append(jnp.ravel(Y))
     fxy_coefficient.append(coefficient * jnp.ones_like(Y_fxy_agg[-1]))
Example #17
 def __call__(self, x, t):
     tt = jnp.ones_like(x[:, :, :, :1]) * t
     ttx = jnp.concatenate([tt, x], axis=-1)
     return self._layer(ttx)
Example #18
    def encode(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer

        >>> # initialize a bert2gpt2 from pretrained BERT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxEncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")

        >>> tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

        >>> text = "My friends are cool but they eat too many carbs."
        >>> input_ids = tokenizer.encode(text, return_tensors="np")
        >>> encoder_outputs = model.encode(input_ids)
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _encoder_forward(module, input_ids, attention_mask, position_ids, **kwargs):
            encode_module = module._get_encoder_module()
            return encode_module(input_ids, attention_mask, position_ids, **kwargs)

        outputs = self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            method=_encoder_forward,
        )

        if return_dict:
            outputs = FlaxBaseModelOutput(
                last_hidden_state=outputs.last_hidden_state,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        return outputs
Example #19
    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
    ):
        r"""
        Args:
            input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`CLIPTokenizer`]. See [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)

        Returns:
            text_features (`jnp.ndarray` of shape `(batch_size, output_dim`): The text embeddings obtained by applying
            the projection layer to the pooled output of [`FlaxCLIPTextModel`].

        Examples:

        ```python
        >>> from transformers import CLIPTokenizer, FlaxCLIPModel

        >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

        >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="np")
        >>> text_features = model.get_text_features(**inputs)
        ```"""
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
                input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, input_ids, attention_mask, position_ids,
                          deterministic):
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
            )
            pooled_output = text_outputs[1]
            text_features = module.text_projection(pooled_output)
            return text_features

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            method=_get_features,
            rngs=rngs,
        )
Example #20
    def __call__(
        self,
        input_ids: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        position_ids: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import FlaxEncoderDecoderModel, BertTokenizer, GPT2Tokenizer

        >>> # load a fine-tuned bert2gpt2 model
        >>> model = FlaxEncoderDecoderModel.from_pretrained("patrickvonplaten/bert2gpt2-cnn_dailymail-fp16")
        >>> # load input & output tokenizer
        >>> tokenizer_input = BertTokenizer.from_pretrained("bert-base-cased")
        >>> tokenizer_output = GPT2Tokenizer.from_pretrained("gpt2")

        >>> article = '''Sigma Alpha Epsilon is under fire for a video showing party-bound fraternity members
        >>> singing a racist chant. SAE's national chapter suspended the students,
        >>> but University of Oklahoma President David Boren took it a step further,
        >>> saying the university's affiliation with the fraternity is permanently done.'''

        >>> input_ids = tokenizer_input(article, add_special_tokens=True, return_tensors="np").input_ids

        >>> # use GPT2's eos_token as the pad as well as eos token
        >>> model.config.eos_token_id = model.config.decoder.eos_token_id
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> sequences = model.generate(input_ids, num_beams=4, max_length=12).sequences

        >>> summary = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)[0]
        >>> assert summary == "SAS Alpha Epsilon suspended Sigma Alpha Epsilon members"
        ```
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)
        if position_ids is None:
            batch_size, sequence_length = input_ids.shape
            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        # prepare decoder inputs
        if decoder_input_ids is None:
            raise ValueError(
                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            input_ids=jnp.array(input_ids, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            position_ids=jnp.array(position_ids, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
Example #21
 def kl_to_standard_normal_fn(mu: Array, sigma: Array = sigma):
     v = jnp.clip(sigma**2, 1e-6, 1e6)
     return 0.5 * (jnp.sum(v) + jnp.sum(mu**2) -
                   jnp.sum(jnp.ones_like(mu)) - jnp.sum(jnp.log(v)))
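
The closed form being evaluated above is KL(N(mu, diag(sigma^2)) || N(0, I)) = 0.5 * sum(sigma^2 + mu^2 - 1 - log sigma^2); a quick sanity check (added; assumes the function is callable directly with an explicit `sigma`):

import jax.numpy as jnp

mu = jnp.zeros(4)
sigma = jnp.ones(4)
print(kl_to_standard_normal_fn(mu, sigma))  # 0.0, since q already equals N(0, I)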
Example #22
  def apply(self,
            inputs,
            vocab_size,
            emb_dim=512,
            num_heads=8,
            num_layers=6,
            qkv_dim=512,
            mlp_dim=2048,
            max_len=2048,
            train=False,
            dropout_rate=0.1,
            attention_dropout_rate=0.1,
            causal=True,
            cache=None,
            positional_encoding_module=AddLearnedPositionalEncodings,
            self_attention_module=nn.SelfAttention,
            attention_fn=None,
            pad_token=None,
            output_head='logits'):
    """Applies Transformer model on the inputs.

    Args:
      inputs: An array of shape (batch_size, length) or (batch_size, length,
        vocab_size) with the input sequences. When 2-dimensional, the array
        contains sequences of int tokens. Otherwise, the array contains
        next-token distributions over tokens (e.g. one-hot representations).
      vocab_size: An int with the size of the vocabulary.
      emb_dim: An int with the token embedding dimension.
      num_heads: An int with the number of attention heads.
      num_layers: An int with the number of transformer encoder layers.
      qkv_dim: An int with the dimension of the query/key/value vectors.
      mlp_dim: An int with the inner dimension of the feed-forward network which
        follows the attention block.
      max_len: An int with the maximum training sequence length.
      train: A bool denoting whether we are currently training.
      dropout_rate: A float with the dropout rate.
      attention_dropout_rate: A float with a dropout rate for attention weights.
      causal: Whether to apply causal masking.
      cache: Cache for decoding.
      positional_encoding_module: A module used for adding positional encodings.
      self_attention_module: Self attention module.
      attention_fn: Method to use in place of dot product attention.
      pad_token: Token to ignore in attention.
      output_head: String or iterable over strings containing the model's output
        head(s) to return.

    Returns:
      Output of a transformer decoder. If output_head is a string, we return a
        single output head output; if output_head is an iterable, we return a
        dict with (output head name, output head output) key-value pairs.
    """
    if inputs.ndim != 2 and inputs.ndim != 3:
      raise ValueError('Expected 2 or 3 dimensions, found %d.' % inputs.ndim)

    if inputs.ndim == 3:
      padding_mask = jnp.ones_like(inputs[Ellipsis, 0])
    elif pad_token is None:
      padding_mask = jnp.ones_like(inputs)
    else:
      # Mask out padding tokens.
      padding_mask = jnp.where(inputs != pad_token, 1, 0).astype(jnp.float32)
    padding_mask = padding_mask[Ellipsis, None]  # Add embedding dimension.

    heads = dict()
    x = inputs
    if inputs.ndim == 2:
      x = x.astype('int32')
    x = Embed(x, num_embeddings=vocab_size, num_features=emb_dim, name='embed')

    if positional_encoding_module == AddLearnedPositionalEncodings:
      x = positional_encoding_module(
          x,
          max_len=max_len,
          cache=cache,
          posemb_init=sinusoidal_init(max_len=max_len))
    else:
      x = positional_encoding_module(x, max_len=max_len)
    x = nn.dropout(x, rate=dropout_rate, deterministic=not train)
    heads['input_emb'] = x
    for i in range(num_layers):
      x = Transformer1DBlock(
          x,
          qkv_dim=qkv_dim,
          mlp_dim=mlp_dim,
          num_heads=num_heads,
          causal_mask=causal,
          padding_mask=padding_mask,
          dropout_rate=dropout_rate,
          attention_dropout_rate=attention_dropout_rate,
          self_attention_module=self_attention_module,
          deterministic=not train,
          attention_fn=attention_fn,
          cache=cache,
      )
      heads['layer_%s' % i] = x
    x = nn.LayerNorm(x)
    heads['output_emb'] = x * padding_mask  # Zero out PAD positions.
    if 'logits' in output_head:
      logits = nn.Dense(
          x,
          vocab_size,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      heads['logits'] = logits

    if 'regression' in output_head:
      regression = nn.Dense(
          x,
          1,
          kernel_init=nn.initializers.xavier_uniform(),
          bias_init=nn.initializers.normal(stddev=1e-6))
      regression = jnp.squeeze(regression, axis=-1)
      heads['regression'] = regression

    if isinstance(output_head, (tuple, list)):
      return {head: heads[head] for head in output_head}
    return heads[output_head]
Example #23
0
 def splitjvp(x):
   _, jvp = linearize(f, x)
   return jvp(np.ones_like(x))
Example #24
def _reset_head_kernel(params, value):
    params = flax.core.unfreeze(params)
    params["head"]["kernel"] = value * jnp.ones_like(params["head"]["kernel"])
    return flax.core.freeze(params)
Example #25
def vmpo_e_step_without_restarting_or_importance_weights(advantages, **kwargs):
  restarting_weights = jnp.ones_like(advantages)
  importance_weights = jnp.ones_like(advantages)
  return mpo_ops.vmpo_compute_weights_and_temperature_loss(
      advantages=advantages, restarting_weights=restarting_weights,
      importance_weights=importance_weights, **kwargs)
Example #26
class JaxBox(qml.math.TensorBox):
    """Implements the :class:`~.TensorBox` API for ``numpy.ndarray``.

    For more details, please refer to the :class:`~.TensorBox` documentation.
    """

    abs = wrap_output(lambda self: jnp.abs(self.data))
    angle = wrap_output(lambda self: jnp.angle(self.data))
    arcsin = wrap_output(lambda self: jnp.arcsin(self.data))
    cast = wrap_output(lambda self, dtype: jnp.array(self.data, dtype=dtype))
    expand_dims = wrap_output(
        lambda self, axis: jnp.expand_dims(self.data, axis=axis))
    ones_like = wrap_output(lambda self: jnp.ones_like(self.data))
    sqrt = wrap_output(lambda self: jnp.sqrt(self.data))
    sum = wrap_output(lambda self, axis=None, keepdims=False: jnp.sum(
        self.data, axis=axis, keepdims=keepdims))
    T = wrap_output(lambda self: self.data.T)
    take = wrap_output(lambda self, indices, axis=None: jnp.take(
        self.data, indices, axis=axis, mode="wrap"))

    def __init__(self, tensor):
        tensor = jnp.asarray(tensor)

        super().__init__(tensor)

    @staticmethod
    def astensor(tensor):
        return jnp.asarray(tensor)

    @staticmethod
    @wrap_output
    def concatenate(values, axis=0):
        return jnp.concatenate(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def dot(x, y):
        x, y = JaxBox.unbox_list([x, y])
        x = jnp.asarray(x)
        y = jnp.asarray(y)

        if x.ndim == 0 and y.ndim == 0:
            return x * y

        if x.ndim == 2 and y.ndim == 2:
            return x @ y

        return jnp.dot(x, y)

    @property
    def interface(self):
        return "jax"

    def numpy(self):
        return self.data

    @property
    def requires_grad(self):
        return True

    @property
    def shape(self):
        return self.data.shape

    @staticmethod
    @wrap_output
    def stack(values, axis=0):
        return jnp.stack(JaxBox.unbox_list(values), axis=axis)

    @staticmethod
    @wrap_output
    def where(condition, x, y):
        return jnp.where(condition, *JaxBox.unbox_list([x, y]))
Example #27
def collapse_and_remove_blanks(labels: jnp.ndarray,
                               seq_length: jnp.ndarray,
                               blank_id: int = 0):
  """Merge repeated labels into single labels and remove the designated blank symbol.

  Args:
    labels: Array of shape (batch, seq_length).
    seq_length: Array of shape (batch), sequence length of each batch element.
    blank_id: Optional id of the blank symbol.

  Returns:
    Tuple of an array of shape (batch, seq_length) with repeated labels
    collapsed, e.g.: [[A, A, B, B, A],
                      [A, B, C, D, E]] => [[A, B, A],
                                           [A, B, C, D, E]]
    and an int array of shape (batch,) with the new sequence lengths.
  """
  b, t = labels.shape
  # Zap out blank
  blank_mask = 1 - jnp.equal(labels, blank_id)
  labels = (labels * blank_mask).astype(labels.dtype)

  # Mask labels that don't equal previous label.
  label_mask = jnp.concatenate([
      jnp.ones_like(labels[:, :1], dtype=jnp.int32),
      jnp.not_equal(labels[:, 1:], labels[:, :-1])
  ],
                               axis=1)

  # Filter labels that aren't in the original sequence.
  maxlen = labels.shape[1]
  seq_mask = sequence_mask(seq_length, maxlen=maxlen)
  label_mask = label_mask * seq_mask

  # remove repetitions from the labels
  ulabels = label_mask * labels

  # Count masks for new sequence lengths.
  label_mask = jnp.not_equal(ulabels, 0).astype(labels.dtype)
  new_seq_len = jnp.sum(label_mask, axis=1)

  # Mask indexes based on sequence length mask.
  new_maxlen = maxlen
  idx_mask = sequence_mask(new_seq_len, maxlen=new_maxlen)

  # Flatten everything and mask out labels to keep and sparse indices.
  flat_labels = jnp.reshape(ulabels, [-1])
  flat_idx_mask = jnp.reshape(idx_mask, [-1])

  indices = jnp.nonzero(flat_idx_mask, size=b * t)[0]
  values = jnp.nonzero(flat_labels, size=b * t)[0]
  updates = jnp.take_along_axis(flat_labels, values, axis=-1)

  # Scatter to flat shape.
  flat = jnp.zeros(flat_idx_mask.shape).astype(labels.dtype)
  flat = flat.at[indices].set(updates)
  # 0'th position in the flat array gets clobbered by later padded updates,
  # so reset it here to its original value
  flat = flat.at[0].set(updates[0])

  # Reshape back to square batch.
  batch_size = labels.shape[0]
  new_shape = [batch_size, new_maxlen]
  return (jnp.reshape(flat, new_shape).astype(labels.dtype),
          new_seq_len.astype(seq_length.dtype))
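
An illustrative call matching the docstring's example, using integer ids (added; `sequence_mask` is assumed to be the module's standard length-mask helper):

import jax.numpy as jnp

labels = jnp.array([[1, 1, 2, 2, 1],
                    [1, 2, 3, 4, 5]])
seq_length = jnp.array([5, 5])
collapsed, new_len = collapse_and_remove_blanks(labels, seq_length, blank_id=0)
# collapsed -> [[1, 2, 1, 0, 0], [1, 2, 3, 4, 5]], new_len -> [3, 5]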
Example #28
    def fit_STC(self,
                prewhiten=False,
                n_repeats=10,
                percentile=100.,
                random_seed=2046,
                verbose=5):
        """

        Spike-triggered Covariance Analysis.

        Parameters
        ==========

        prewhiten: bool

        n_repeats: int
            Number of repeats for STC significance test.

        percentile: float
            Valid range of STC significance test.

        verbose: int
        random_seed: int
        """
        def get_stc(_X, _y, _w):

            n = len(_X)
            ste = _X[_y != 0]
            proj = ste - ste * _w * _w.T
            stc = proj.T @ proj / (n - 1)

            _eigvec, _eigval, _ = jnp.linalg.svd(stc)

            return _eigvec, _eigval

        key = random.PRNGKey(random_seed)

        y = self.y

        if prewhiten:

            if self.compute_mle is False:
                self.XtX = self.X.T @ self.X
                self.w_mle = jnp.linalg.solve(self.XtX, self.XtY)

            X = jnp.linalg.solve(self.XtX, self.X.T).T
            w = uvec(self.w_mle)

        else:
            X = self.X
            w = uvec(self.w_sta)

        eigvec, eigval = get_stc(X, y, w)

        self.w_stc = dict()
        if n_repeats:
            print('STC significance test: ')
            eigval_null = []
            for counter in range(n_repeats):
                if verbose:
                    if counter % int(verbose) == 0:
                        print(f'  {counter + 1:}/{n_repeats}')

                y_randomize = random.permutation(key, y)
                _, eigval_randomize = get_stc(X, y_randomize, w)
                eigval_null.append(eigval_randomize)
            else:
                if verbose:
                    print(f'Done.')
            eigval_null = jnp.vstack(eigval_null)
            max_null = jnp.percentile(eigval_null, percentile)
            min_null = jnp.percentile(eigval_null, 100 - percentile)
            mask_sig_pos = eigval > max_null
            mask_sig_neg = eigval < min_null
            mask_sig = jnp.logical_or(mask_sig_pos, mask_sig_neg)

            self.w_stc['eigvec'] = eigvec
            self.w_stc['pos'] = eigvec[:, mask_sig_pos]
            self.w_stc['neg'] = eigvec[:, mask_sig_neg]

            self.w_stc['eigval'] = eigval
            self.w_stc['eigval_mask'] = mask_sig
            self.w_stc['eigval_pos_mask'] = mask_sig_pos
            self.w_stc['eigval_neg_mask'] = mask_sig_neg

            self.w_stc['max_null'] = max_null
            self.w_stc['min_null'] = min_null

        else:
            self.w_stc['eigvec'] = eigvec
            self.w_stc['eigval'] = eigval
            self.w_stc['eigval_mask'] = jnp.ones_like(eigval).astype(bool)
Example #29
 def init(x0):
   avg_sq_grad = np.ones_like(x0)
   return x0, avg_sq_grad
Example #30
def vector_DP_inv_pd(v):
    return np.exp(v - np.ones_like(v))