Example #1
def add_batch(nest, batch_size):
    broadcast = lambda x: jnp.broadcast_to(x, (batch_size, ) + x.shape)
    return jax.tree_map(broadcast, nest)
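
A minimal usage sketch of the pattern above (the pytree contents are illustrative; `jax.tree_util.tree_map` is the non-deprecated spelling of `jax.tree_map`):

import jax
import jax.numpy as jnp

params = {"w": jnp.zeros((3, 4)), "b": jnp.zeros((4,))}
# Prepend a batch axis of size 32 to every leaf.
batched = jax.tree_util.tree_map(
    lambda x: jnp.broadcast_to(x, (32,) + x.shape), params)
print(batched["w"].shape, batched["b"].shape)  # (32, 3, 4) (32, 4)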
Example #2
 def pre_pmap(xs):
     return jax.tree_map(lambda x: jnp.broadcast_to(x, (1, ) + x.shape), xs)
Example #3
 def s(keys):
     keys = np.broadcast_to(keys, (N_DEVICES, ) + keys.shape)
     return g(keys)
Example #4
 def mean(self):
     return jnp.broadcast_to(self.base_dist.mean,
                             self.batch_shape + self.event_shape)
Example #5
 def log_prob(self, value):
     shape = lax.broadcast_shapes(self.batch_shape, jnp.shape(value)[:-1])
     return jnp.broadcast_to(self.log_factor, shape)
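
A small sketch of how `lax.broadcast_shapes` resolves the result shape (the shapes below are illustrative, not taken from the original distribution):

from jax import lax

batch_shape = (4, 1)
value_batch_shape = (3,)  # jnp.shape(value)[:-1], i.e. the value shape minus the event dim
# Shapes are broadcast right-aligned, exactly like array broadcasting.
print(lax.broadcast_shapes(batch_shape, value_batch_shape))  # (4, 3)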
Example #6
    def __call__(
        self,
        inputs: jnp.ndarray,
        attention_mask: Optional[jnp.ndarray] = None,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        freeze_feature_encoder: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel, BartTokenizer

        >>> # load a fine-tuned wav2vec2-2-bart model
        >>> model = FlaxSpeechEncoderDecoderModel.from_pretrained("patrickvonplaten/wav2vec2-2-bart-large")
        >>> # load output tokenizer
        >>> tokenizer_output = BartTokenizer.from_pretrained("facebook/bart-large")

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)

        >>> # use bart's special bos, pad and eos tokens
        >>> model.config.decoder_start_token_id = model.decoder.config.bos_token_id
        >>> model.config.pad_token_id = model.decoder.config.pad_token_id
        >>> model.config.eos_token_id = model.decoder.config.eos_token_id

        >>> outputs = model.generate(inputs)
        ```
        """

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs
        if attention_mask is None:
            attention_mask = jnp.ones_like(inputs)

        # prepare decoder inputs
        if decoder_input_ids is None:
            raise ValueError(
                "`decoder_input_ids` cannot be `None`. For sequence to sequence training, `decoder_position_ids` must be specified as an input argument."
            )
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            inputs=jnp.array(inputs, dtype="i4"),
            attention_mask=jnp.array(attention_mask, dtype="i4"),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask,
                                             dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            freeze_feature_encoder=freeze_feature_encoder,
            rngs=rngs,
        )
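
The `decoder_position_ids` construction above in isolation: one row of positions is broadcast across the batch (batch_size=2 and sequence_length=5 are illustrative values):

import jax.numpy as jnp

batch_size, sequence_length = 2, 5
decoder_position_ids = jnp.broadcast_to(
    jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))
print(decoder_position_ids)
# [[0 1 2 3 4]
#  [0 1 2 3 4]]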
Example #7
    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
    ):
        r"""
        Args:
            input_ids (`numpy.ndarray` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using [`PreTrainedTokenizer`]. See
                [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
                for details.

                [What are input IDs?](../glossary#input-ids)

        Returns:
            text_features (`jnp.ndarray` of shape `(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of the text model.
        """
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
                input_ids.shape)

        if token_type_ids is None:
            token_type_ids = jnp.zeros_like(input_ids)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, input_ids, attention_mask, position_ids,
                          token_type_ids, deterministic):
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                token_type_ids=token_type_ids,
                deterministic=deterministic,
            )
            pooled_output = text_outputs[1]
            text_features = module.text_projection(pooled_output)
            return text_features

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            jnp.array(token_type_ids, dtype="i4"),
            not train,
            method=_get_features,
            rngs=rngs,
        )
Example #8
def polygamma(n, x):
    assert jnp.issubdtype(lax.dtype(n), jnp.integer)
    n, x = _promote_args_inexact("polygamma", n, x)
    shape = lax.broadcast_shapes(n.shape, x.shape)
    return _polygamma(jnp.broadcast_to(n, shape), jnp.broadcast_to(x, shape))
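
A usage sketch of the broadcasting behaviour (calling `jax.scipy.special.polygamma`, which provides this function; the orders and points are illustrative):

import jax.numpy as jnp
from jax.scipy import special

n = jnp.array([[0], [1]])        # integer orders, shape (2, 1)
x = jnp.array([1.0, 2.0, 3.0])   # evaluation points, shape (3,)
# The (2, 1) and (3,) shapes broadcast to a common (2, 3) result.
print(special.polygamma(n, x).shape)  # (2, 3)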
Example #9
def _gen_associated_legendre(l_max: int, x: jnp.ndarray,
                             is_normalized: bool) -> jnp.ndarray:
    r"""Computes associated Legendre functions (ALFs) of the first kind.

  The ALFs of the first kind are used in spherical harmonics. The spherical
  harmonic of degree `l` and order `m` can be written as
  `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
  normalization factor and θ and φ are the colatitude and longitude,
  respectively. `N_l^m` is chosen such that the spherical harmonics form
  a set of orthonormal basis functions of L^2(S^2). For the computational
  efficiency of spherical harmonics transform, the normalization factor is
  used in the computation of the ALFs. In addition, normalizing `P_l^m`
  avoids overflow/underflow and achieves better numerical stability. Three
  recurrence relations are used in the computation.

  Args:
    l_max: The maximum degree of the associated Legendre function. Both the
      degrees and orders are `[0, 1, 2, ..., l_max]`.
    x: A vector of type `float32`, `float64` containing the sampled points in
      spherical coordinates, at which the ALFs are computed; `x` is essentially
      `cos(θ)`. For the numerical integration used by the spherical harmonics
      transforms, `x` contains the quadrature points in the interval of
      `[-1, 1]`. There are several approaches to provide the quadrature points:
      Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
      method (`scipy.special.roots_chebyu`), and Driscoll & Healy
      method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
      transforms and convolutions on the 2-sphere." Advances in applied
      mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
      points are nearly equal-spaced along θ and provide exact discrete
      orthogonality, (P^m)^T W P^m = I, where `T` represents the transpose
      operation, `W` is a diagonal matrix containing the quadrature weights,
      and `I` is the identity matrix. The Gauss-Chebyshev points are equally
      spaced, which only provide approximate discrete orthogonality. The
      Driscoll & Healy qudarture points are equally spaced and provide the
      exact discrete orthogonality. The number of sampling points is required to
      be twice as the number of frequency points (modes) in the Driscoll & Healy
      approach, which enables FFT and achieves a fast spherical harmonics
      transform.
    is_normalized: True if the associated Legendre functions are normalized.
      With normalization, `N_l^m` is applied such that the spherical harmonics
      form a set of orthonormal basis functions of L^2(S^2).

  Returns:
    The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
    of the ALFs at `x`; the dimensions correspond to order, degree, and
    evaluation points, respectively.
  """
    p = jnp.zeros((l_max + 1, l_max + 1, x.shape[0]))

    a_idx = jnp.arange(1, l_max + 1)
    b_idx = jnp.arange(l_max)
    if is_normalized:
        initial_value = 0.5 / jnp.sqrt(jnp.pi)  # The initial value p(0,0).
        f_a = jnp.cumprod(-1 * jnp.sqrt(1.0 + 0.5 / a_idx))
        f_b = jnp.sqrt(2.0 * b_idx + 3.0)
    else:
        initial_value = 1.0  # The initial value p(0,0).
        f_a = jnp.cumprod(1.0 - 2.0 * a_idx)
        f_b = 2.0 * b_idx + 1.0

    p = p.at[(0, 0)].set(initial_value)

    # Compute the diagonal entries p(l,l) with recurrence.
    y = jnp.cumprod(jnp.broadcast_to(jnp.sqrt(1.0 - x * x),
                                     (l_max, x.shape[0])),
                    axis=0)
    p_diag = initial_value * jnp.einsum('i,ij->ij', f_a, y)
    diag_indices = jnp.diag_indices(l_max + 1)
    p = p.at[(diag_indices[0][1:], diag_indices[1][1:])].set(p_diag)

    # Compute the off-diagonal entries with recurrence.
    p_offdiag = jnp.einsum('ij,ij->ij', jnp.einsum('i,j->ij', f_b, x),
                           p[jnp.diag_indices(l_max)])
    offdiag_indices = (diag_indices[0][:l_max], diag_indices[1][:l_max] + 1)
    p = p.at[offdiag_indices].set(p_offdiag)

    # Compute the remaining entries with recurrence.
    d0_mask_3d, d1_mask_3d = _gen_recurrence_mask(l_max,
                                                  is_normalized=is_normalized)

    def body_fun(i, p_val):
        coeff_0 = d0_mask_3d[i]
        coeff_1 = d1_mask_3d[i]
        h = (jnp.einsum(
            'ij,ijk->ijk', coeff_0,
            jnp.einsum('ijk,k->ijk', jnp.roll(p_val, shift=1, axis=1), x)) -
             jnp.einsum('ij,ijk->ijk', coeff_1, jnp.roll(
                 p_val, shift=2, axis=1)))
        p_val = p_val + h
        return p_val

    # TODO(jakevdp): use some sort of fixed-point procedure here instead?
    p = p.astype(jnp.result_type(p, x, d0_mask_3d))
    if l_max > 1:
        p = lax.fori_loop(lower=2,
                          upper=l_max + 1,
                          body_fun=body_fun,
                          init_val=p)

    return p
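
The diagonal recurrence's cumprod trick in isolation: row `l` of `y` holds `(1 - x**2) ** ((l + 1) / 2)`, i.e. `sin(theta) ** (l + 1)`, which feeds the `p(l, l)` diagonal (`l_max` and `x` below are illustrative):

import jax.numpy as jnp

l_max = 3
x = jnp.array([0.0, 0.5, 0.9])
# Broadcast sqrt(1 - x^2) to l_max rows, then take a running product down
# the rows so that row l is the (l + 1)-th power.
y = jnp.cumprod(jnp.broadcast_to(jnp.sqrt(1.0 - x * x),
                                 (l_max, x.shape[0])), axis=0)
print(y.shape)  # (3, 3); y[l] == (1.0 - x * x) ** ((l + 1) / 2)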
Example #10
 def variance(self):
     """ Computes circular variance of distribution """
     return jnp.broadcast_to(
         1. - lax.bessel_i1e(self.concentration) /
         lax.bessel_i0e(self.concentration), self.batch_shape)
Example #11
    def call(
        self,
        inputs: jnp.ndarray,
        scale: Optional[jnp.ndarray] = None,
        offset: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        """Connects the layer norm.

        Args:
          inputs: An array, where the data format is ``[N, ..., C]``.
          scale: An array up to n-D. The shape of this tensor must be broadcastable
            to the shape of ``inputs``. This is the scale applied to the normalized
            inputs. This cannot be passed in if the module was constructed with
            ``create_scale=True``.
          offset: An array up to n-D. The shape of this tensor must be broadcastable
            to the shape of ``inputs``. This is the offset applied to the normalized
            inputs. This cannot be passed in if the module was constructed with
            ``create_offset=True``.

        Returns:
          The array, normalized.
        """
        if self.create_scale and scale is not None:
            raise ValueError("Cannot pass `scale` at call time if `create_scale=True`.")
        if self.create_offset and offset is not None:
            raise ValueError(
                "Cannot pass `offset` at call time if `create_offset=True`."
            )

        axis = self.axis
        if isinstance(axis, slice):
            axis = tuple(range(inputs.ndim)[axis])

        mean = jnp.mean(inputs, axis=axis, keepdims=True)
        variance = jnp.var(inputs, axis=axis, keepdims=True)

        param_shape = inputs.shape[-1:]
        if self.create_scale:
            scale = self.add_parameter(
                "scale",
                lambda: self.scale_init(
                    param_shape,
                    jnp.float32,
                ),
            )
        elif scale is None:
            scale = np.array(1.0, dtype=inputs.dtype)

        if self.create_offset:
            offset = self.add_parameter(
                "offset",
                lambda: self.offset_init(
                    param_shape,
                    jnp.float32,
                ),
            )
        elif offset is None:
            offset = np.array(0.0, dtype=inputs.dtype)

        scale = jnp.broadcast_to(scale, inputs.shape)
        offset = jnp.broadcast_to(offset, inputs.shape)
        mean = jnp.broadcast_to(mean, inputs.shape)

        inv = scale * jax.lax.rsqrt(variance + self.eps)
        return inv * (inputs - mean) + offset
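
A free-standing functional sketch of the same normalization (a hypothetical helper, not part of the module above), showing why the explicit broadcasts are safe:

import jax
import jax.numpy as jnp

def layer_norm(inputs, scale, offset, eps=1e-5, axis=-1):
    # Moments are computed with keepdims so they broadcast against inputs.
    mean = jnp.mean(inputs, axis=axis, keepdims=True)
    variance = jnp.var(inputs, axis=axis, keepdims=True)
    scale = jnp.broadcast_to(scale, inputs.shape)
    offset = jnp.broadcast_to(offset, inputs.shape)
    mean = jnp.broadcast_to(mean, inputs.shape)
    inv = scale * jax.lax.rsqrt(variance + eps)
    return inv * (inputs - mean) + offset

x = jnp.arange(12.0).reshape(3, 4)
y = layer_norm(x, jnp.ones((4,)), jnp.zeros((4,)))
print(y.mean(axis=-1))  # approximately zero per row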
Example #12
 def mean(self):
     """ Computes circular mean of distribution. NOTE: same as location when mapped to support [-pi, pi] """
     return jnp.broadcast_to((self.loc + jnp.pi) % (2. * jnp.pi) - jnp.pi,
                             self.batch_shape)
Example #13
    def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    Args:
      inputs_q: input queries of shape `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape `[batch_sizes..., length, features]`.
      mask: attention mask of shape `[batch_sizes..., num_heads, query_length,
        key/value_length]`. Attention weights are masked out if their
        corresponding mask value is `False`.
      deterministic: if false, the attention weight is masked randomly using
        dropout, whereas if true, the attention weights are deterministic.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
        if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
            deterministic = merge_param('deterministic', self.deterministic,
                                        deterministic)
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(DenseGeneral,
                                  axis=-1,
                                  dtype=self.dtype,
                                  param_dtype=self.param_dtype,
                                  features=(self.num_heads, head_dim),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  precision=self.precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (dense(name='query')(inputs_q),
                             dense(name='key')(inputs_kv),
                             dense(name='value')(inputs_kv))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                *batch_dims, max_length, num_heads, depth_per_head = (
                    cached_key.value.shape)
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads,
                                                      depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        'expected query shape %s instead got %s.' %
                        (expected_shape, query.shape))
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length)))

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        toeplitz_params = self.param('toeplitz_params', ones, (
            query.shape[-2],
            4 * self.nb_x_patches * self.nb_y_patches,
        ))

        # apply attention
        x = self.attention_fn(query,
                              key,
                              value,
                              toeplitz_params,
                              mask=mask,
                              dropout_rng=dropout_rng,
                              dropout_rate=self.dropout_rate,
                              broadcast_dropout=self.broadcast_dropout,
                              deterministic=deterministic,
                              dtype=self.dtype,
                              precision=self.precision,
                              nb_x_patches=self.nb_x_patches,
                              nb_y_patches=self.nb_y_patches)  # pytype: disable=wrong-keyword-args
        # back to the original inputs dimensions
        out = DenseGeneral(features=features,
                           axis=(-2, -1),
                           kernel_init=self.kernel_init,
                           bias_init=self.bias_init,
                           use_bias=self.use_bias,
                           dtype=self.dtype,
                           param_dtype=self.param_dtype,
                           precision=self.precision,
                           name='out')(x)
        return out
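
The cached-decoding causal mask above in isolation: after `cur_index` generated positions, the single query may only attend to cache slots at or before `cur_index` (the sizes are illustrative):

import jax.numpy as jnp

max_length, cur_index, batch_dims = 6, 2, (1,)
mask = jnp.broadcast_to(jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length))
print(mask.astype(jnp.int32))  # [[[[1 1 1 0 0 0]]]]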
Example #14
def top_k(x, k):
  """Select the top k slices from the last dimension."""
  bcast_idxs = jnp.broadcast_to(np.arange(x.shape[-1]), x.shape)
  sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
  # TODO(levskaya): use lax.slice here instead to benefit from XLA optimization
  return sorted_vals[..., -k:], sorted_idxs[..., -k:]
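
A usage sketch reusing `top_k` as defined above: `lax.sort_key_val` sorts ascending, so the last `k` slices are the `k` largest values paired with their original indices:

import numpy as np
import jax.numpy as jnp
from jax import lax

x = jnp.array([[3.0, 1.0, 4.0, 1.0, 5.0]])
vals, idxs = top_k(x, 2)
print(vals)  # [[4. 5.]]
print(idxs)  # [[2 4]]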
Example #15
    def __call__(
        self,
        pixel_values: jnp.ndarray,
        decoder_input_ids: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Examples:

        ```python
        >>> from transformers import FlaxVisionEncoderDecoderModel, ViTFeatureExtractor, GPT2Tokenizer
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # load output tokenizer
        >>> tokenizer_output = GPT2Tokenizer.from_pretrained("gpt2")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "gpt2"
        ... )

        >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

        >>> # use GPT2's eos_token as the pad as well as eos token
        >>> model.config.eos_token_id = model.config.decoder.eos_token_id
        >>> model.config.pad_token_id = model.config.eos_token_id

        >>> # generation
        >>> sequences = model.generate(pixel_values, num_beams=4, max_length=12).sequences

        >>> captions = tokenizer_output.batch_decode(sequences, skip_special_tokens=True)
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        # prepare encoder inputs

        # `FlaxViTModel` expects channel first format, but `FlaxViTModule` expects channel last format.
        # Currently, we assume this holds for all Flax vision models, and perform a transpose here.
        pixel_values = jnp.transpose(pixel_values, (0, 2, 3, 1))

        # prepare decoder inputs
        if decoder_input_ids is None:
            raise ValueError("`decoder_input_ids` can't be `None`.")
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones_like(decoder_input_ids)
        if decoder_position_ids is None:
            batch_size, sequence_length = decoder_input_ids.shape
            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}

        return self.module.apply(
            {"params": params or self.params},
            pixel_values=jnp.array(pixel_values, dtype=self.dtype),
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
        )
Example #16
  def log_prob(self, value):
    """Computes the log prob of the discrete mixture."""
    if value.ndim != 4:
      raise ValueError(f'Value is assumed to have 4 dims, got {value.ndim}.')
    num_channels = value.shape[-1]

    if num_channels != self._nb_channels:
      raise ValueError(f'value should have {self._nb_channels} channels, '
                       f'got {num_channels}.')

    # [bs, height, width, num_channels]
    # bring to [-1, 1]
    value = value.astype(jnp.float32)
    value = value / 255.
    value = 2 * value - 1
    shape = value.shape
    value = value[Ellipsis, None]  # add a mixture axis

    # [bs, height, width, num_channels, num_mix]
    value = jnp.broadcast_to(value, shape + (self._num_mixtures,))

    if num_channels == 3:
      # RGB channels are produced autoregressively, thus the mean of the second
      # channel depends on the first channel of the value, and the mean of
      # the third depends on the 2 first channels.
      mean_r = self._means[:, :, :, 0, :]
      mean_g = (
          self._means[:, :, :, 1, :] +
          self._rgb_mix_coeffs[:, :, :, 0, :] * value[:, :, :, 0, :])
      mean_b = (
          self._means[:, :, :, 2, :] +
          self._rgb_mix_coeffs[:, :, :, 1, :] * value[:, :, :, 0, :] +
          self._rgb_mix_coeffs[:, :, :, 2, :] * value[:, :, :, 1, :])
      # [bs, height, width, num_channels, num_mix]
      means = jnp.stack([mean_r, mean_g, mean_b], axis=3)
    else:
      means = self._means

    centered = value - means

    # Each element of the mixture uses
    # sigmoid((x-m+1/max_val)/scale) - sigmoid((x-m-1/max_val)/scale)
    # as the probability for x, except for extremal xs, i.e. x = -1 and x = 1.
    # For x = 1, 1 - sigmoid((x - m - 1 / max_val) / scale) is taken,
    # For x = -1, sigmoid((x - m + 1 / max_val) / scale) is taken. One can
    # see that the final quantities sum to 1 when x ranges over -1, 1 with
    # steps 2. / max_val. See https://arxiv.org/pdf/1701.05517.pdf for a
    # more in depth discussion.
    inv_std = jnp.exp(-self._log_scales)
    plus_in = inv_std * (centered + 1. / self._max_val)
    cdf_plus = jax.nn.sigmoid(plus_in)
    minus_in = inv_std * (centered - 1. / self._max_val)
    cdf_min = jax.nn.sigmoid(minus_in)
    # For numerical stability, we only take x - softplus(x) when softplus(x)
    # is far from x, i.e. when x < 0. Otherwise we may end up computing
    # inaccurately small quantities.
    log_cdf_plus = jnp.where(plus_in > 0, -jax.nn.softplus(-plus_in),
                             plus_in - jax.nn.softplus(plus_in))
    log_one_minus_cdf_min = -jax.nn.softplus(minus_in)
    cdf_delta = cdf_plus - cdf_min

    # When the difference of CDFs get too small, numerical instabilities
    # can happen. To prevent this, a first order approximation of the
    # difference is taken in such cases.
    mid_in = inv_std * centered
    log_pdf_mid = mid_in - self._log_scales - 2 * jax.nn.softplus(mid_in)
    log_prob_mid_safe = jnp.where(cdf_delta > 1e-5,
                                  jnp.log(jnp.clip(cdf_delta, a_min=1e-10)),
                                  log_pdf_mid - jnp.log(self._max_val / 2))

    # Discrete values go from -1 to 1 with a 2 / max_val step.
    # To confidently catch all -1 values, we compare to the midpoint
    # between -1 and the closest discrete value -1 + 2 / max_val, i.e.
    # 1. / self.max_val - 1. We do the same for the max value.
    log_probs = jnp.where(
        value < 1. / self._max_val - 1.,
        log_cdf_plus,
        jnp.where(
            value > 1. - 1. / self._max_val,
            log_one_minus_cdf_min,
            log_prob_mid_safe,
        ),
    )

    # Compute the mixture log probs, i.e. sum log probs
    # along the channel axis, then add the mixture log probs
    # (each pixel has its own mixture, but not each channel)
    # [bs, height, width, num_mix]
    log_probs = (
        jnp.sum(log_probs, axis=3) +
        jax.nn.log_softmax(self._logits_probs, axis=-1))
    # [bs, height, width]
    log_probs = jax.nn.logsumexp(log_probs, axis=-1)
    log_probs = jnp.sum(log_probs, axis=(-1, -2))
    return log_probs
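
A worked check of the binning scheme described above (max_val, the mean m, and the scale are illustrative): with the edge handling at x = -1 and x = 1, the bin probabilities of one logistic component telescope to 1 over the discrete grid:

import jax
import jax.numpy as jnp

max_val, m, scale = 255.0, 0.1, 0.2
x = 2.0 * jnp.arange(max_val + 1) / max_val - 1.0  # grid on [-1, 1], step 2 / max_val
plus = jax.nn.sigmoid((x - m + 1.0 / max_val) / scale)
minus = jax.nn.sigmoid((x - m - 1.0 / max_val) / scale)
probs = jnp.where(x < 1.0 / max_val - 1.0, plus,      # leftmost bin
                  jnp.where(x > 1.0 - 1.0 / max_val,  # rightmost bin
                            1.0 - minus,
                            plus - minus))            # interior bins
print(probs.sum())  # ~1.0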
Example #17
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import FlaxSpeechEncoderDecoderModel
        >>> import jax.numpy as jnp

        >>> # initialize a wav2vec2-2-bart from pretrained wav2vec2 and bart models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxSpeechEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "facebook/wav2vec2-large-lv60", "facebook/bart-large"
        ... )

        >>> inputs = jnp.ones((2, 5000), dtype=jnp.float32)
        >>> encoder_outputs = model.encode(inputs)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((inputs.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]
        if encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `decoder_position_ids` when passing `past_key_values`."
                )

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        params = {"params": params or self.params}

        # If `past_key_values` are passed, the cache is already initialized; a private flag `init_cache`
        # has to be passed down to ensure the cache is used. The cache also has to be marked as mutable
        # so that it can be changed by the FlaxBartAttention module.
        if past_key_values:
            params["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(module, decoder_input_ids, decoder_attention_mask,
                             decoder_position_ids, encoder_hidden_states,
                             **kwargs):

            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # optionally project encoder_hidden_states
            if projection_module is not None:
                encoder_hidden_states = projection_module(
                    encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            params,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask,
                                             dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask,
                                             dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]), ) + outputs[1:]

        return outputs
Example #18
def _new_eye(x, shape):
    n = shape[-1]
    return np.broadcast_to(np.eye(n), shape + (n, ))
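
A usage sketch (the first argument is unused by the helper; the shape is illustrative):

import numpy as np

eyes = _new_eye(None, (2, 3))
print(eyes.shape)                       # (2, 3, 3)
print(np.allclose(eyes[0], np.eye(3)))  # True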
Example #19
 def sample(self, key, sample_shape=()):
     shape = sample_shape + self.batch_shape + self.event_shape
     return jnp.broadcast_to(self.v, shape)
Example #20
    def get_text_features(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        params: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train=False,
    ):
        r"""
        Args:
            input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you
                provide it.

                Indices can be obtained using :class:`~transformers.CLIPTokenizer`. See
                :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
                for details.

                `What are input IDs? <../glossary.html#input-ids>`__

        Returns:
            text_features (:obj:`jnp.ndarray` of shape :obj:`(batch_size, output_dim)`): The text embeddings obtained by
            applying the projection layer to the pooled output of :class:`~transformers.FlaxCLIPTextModel`.

        Examples::

            >>> from transformers import CLIPTokenizer, FlaxCLIPModel

            >>> model = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
            >>> tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

            >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"],  padding=True, return_tensors="np")
            >>> text_features = model.get_text_features(**inputs)
        """
        if position_ids is None:
            position_ids = jnp.broadcast_to(
                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]),
                input_ids.shape)

        if attention_mask is None:
            attention_mask = jnp.ones_like(input_ids)

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        def _get_features(module, input_ids, attention_mask, position_ids,
                          deterministic):
            text_outputs = module.text_model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                position_ids=position_ids,
                deterministic=deterministic,
            )
            pooled_output = text_outputs[1]
            text_features = module.text_projection(pooled_output)
            return text_features

        return self.module.apply(
            {"params": params or self.params},
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            not train,
            method=_get_features,
            rngs=rngs,
        )
Example #21
 def log_prob(self, value):
     shape = lax.broadcast_shapes(
         self.batch_shape,
         jnp.shape(value)[:max(jnp.ndim(value) - self.event_dim, 0)])
     log_prob = self.base_dist.log_prob(value)
     return jnp.broadcast_to(log_prob, shape)
Example #22
    def __call__(
        self,
        hidden_states,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        # If `key_value_states` are provided, this layer is used as a cross-attention
        # layer for the decoder.
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        if not is_cross_attention:
            qkv_out = self.c_attn(hidden_states)
            query, key, value = jnp.split(qkv_out, 3, axis=2)
        else:
            q_out = self.q_attn(hidden_states)
            (query, ) = jnp.split(q_out, 1, axis=2)
            kv_out = self.c_attn(key_value_states)
            key, value = jnp.split(kv_out, 2, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.causal:
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :
                                               key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example #23
 def variance(self):
     return jnp.broadcast_to(self.base_dist.variance,
                             self.batch_shape + self.event_shape)
Example #24
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (output_hidden_states
                                if output_hidden_states is not None else
                                self.config.output_hidden_states)
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError(
                    "Make sure to provide `position_ids` when passing `past_key_values`."
                )

            position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :],
                (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If `past_key_values` are passed, the cache is already initialized; a private flag `init_cache` has to be
        # passed down to ensure the cache is used. The cache also has to be marked as mutable so that it can be
        # changed by the FlaxGPT2Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states,
            encoder_attention_mask,
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(
                past_key_values["cache"]), ) + outputs[1:]

        return outputs
Example #25
def volumetric_rendering(rgb,
                         sigma,
                         z_vals,
                         dirs,
                         use_white_background,
                         sample_at_infinity=True,
                         return_weights=False,
                         eps=1e-10):
    """Volumetric Rendering Function.

  Args:
    rgb: an array of size (B,S,3) containing the RGB color values.
    sigma: an array of size (B,S,1) containing the densities.
    z_vals: an array of size (B,S) containing the z-coordinate of the samples.
    dirs: an array of size (B,3) containing the directions of rays.
    use_white_background: whether to assume a white background or not.
    sample_at_infinity: if True adds a sample at infinity.
    return_weights: if True returns the weights in the dictionary.
    eps: a small number to prevent numerical issues.

  Returns:
    A dictionary containing:
      rgb: an array of size (B,3) containing the rendered colors.
      depth: an array of size (B,) containing the rendered depth.
      acc: an array of size (B,) containing the accumulated density.
      weights: an array of size (B,S) containing the weight of each sample.
  """
    # TODO(keunhong): remove this hack.
    last_sample_z = 1e10 if sample_at_infinity else 1e-19
    dists = jnp.concatenate([
        z_vals[..., 1:] - z_vals[..., :-1],
        jnp.broadcast_to([last_sample_z], z_vals[..., :1].shape)
    ], -1)
    dists = dists * jnp.linalg.norm(dirs[..., None, :], axis=-1)
    alpha = 1.0 - jnp.exp(-sigma * dists)
    # Prepend a 1.0 to make this an 'exclusive' cumprod as in `tf.math.cumprod`.
    accum_prod = jnp.concatenate([
        jnp.ones_like(alpha[..., :1], alpha.dtype),
        jnp.cumprod(1.0 - alpha[..., :-1] + eps, axis=-1),
    ],
                                 axis=-1)
    weights = alpha * accum_prod

    rgb = (weights[..., None] * rgb).sum(axis=-2)
    exp_depth = (weights * z_vals).sum(axis=-1)
    med_depth = compute_depth_map(weights, z_vals)
    acc = weights.sum(axis=-1)
    if use_white_background:
        rgb = rgb + (1. - acc[..., None])

    if sample_at_infinity:
        acc = weights[..., :-1].sum(axis=-1)

    out = {
        'rgb': rgb,
        'depth': exp_depth,
        'med_depth': med_depth,
        'acc': acc,
    }
    if return_weights:
        out['weights'] = weights
    return out
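
The 'exclusive' cumprod in isolation: prepending a 1.0 shifts the running product by one sample, so each weight uses the transmittance of the samples *before* it (the alpha values are illustrative; eps is dropped for clarity):

import jax.numpy as jnp

alpha = jnp.array([0.1, 0.5, 0.9])
accum_prod = jnp.concatenate([jnp.ones_like(alpha[..., :1]),
                              jnp.cumprod(1.0 - alpha[..., :-1], axis=-1)],
                             axis=-1)
print(accum_prod)          # [1.   0.9  0.45]
print(alpha * accum_prod)  # per-sample weights: [0.1  0.45 0.405]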
Example #26
  def __call__(self,
               inputs_q,
               inputs_kv,
               mask = None,
               custom_relative_position = None,
               deterministic = None):
    """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    Args:
      inputs_q: input queries of shape
        `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape
        `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.
        Attention weights are masked out if their corresponding mask value
        is `False`.
      custom_relative_position: relative positions tensor
        `[batch_sizes..., query_length, key/value_length]`.
      deterministic: if false, the attention weight is masked randomly
        using dropout, whereas if true, the attention weights
        are deterministic.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
    if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
      deterministic = module.merge_param('deterministic', self.deterministic,
                                         deterministic)
    features = self.out_features or inputs_q.shape[-1]
    qkv_features = self.qkv_features or inputs_q.shape[-1]
    assert qkv_features % self.num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // self.num_heads

    dense = functools.partial(linear.DenseGeneral,
                              axis=-1,
                              features=(self.num_heads, head_dim),
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init,
                              use_bias=self.use_bias,
                              precision=self.precision)
    relative_attention_embed = linear.Embed(
        num_embeddings=self.num_relative_position_buckets,
        features=self.num_heads,
        embedding_init=initializers.normal(stddev=1.0),
        dtype=self.dtype)

    # project inputs_q to multi-headed q/k/v
    # dimensions are then [batch..., length, n_heads, n_features_per_head]
    query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                         dense(dtype=self.dtype, name='key')(inputs_kv),
                         dense(dtype=self.dtype, name='value')(inputs_kv))

    if custom_relative_position is None:
      query_length = inputs_q.shape[-2]
      key_length = inputs_kv.shape[-2]
      context_position = jnp.arange(query_length, dtype=jnp.int32)[:, None]
      memory_position = jnp.arange(key_length, dtype=jnp.int32)[None, :]

      if self.mod_position > 0:
        memory_position %= self.mod_position
        context_position %= self.mod_position

      relative_position = memory_position - context_position
      relative_position_bucket = make_relative_position_bucket(
          relative_position,
          bidirectional=self.bidirectional,
          num_buckets=self.num_relative_position_buckets,
          max_distance=self.max_distance)

      bias = relative_attention_embed(relative_position_bucket)
      bias = bias.transpose((2, 0, 1))
      # Expand batch dimensions.
      bias = jnp.broadcast_to(
          bias, (1,) * len(inputs_q.shape[:-2]) + bias.shape)

    else:
      relative_position = custom_relative_position
      relative_position_bucket = make_relative_position_bucket(
          relative_position,
          bidirectional=self.bidirectional,
          num_buckets=self.num_relative_position_buckets,
          max_distance=self.max_distance)

      bias = relative_attention_embed(relative_position_bucket)
      permute = tuple(map(lambda i: len(inputs_q.shape) + 1 + i, (-1, -3, -2)))
      bias = bias.transpose(
          tuple(range(len(inputs_q.shape[:-2]))) + permute)

    # During fast autoregressive decoding, we feed one position at a time,
    # and cache the keys and values step by step.
    if self.decode:
      # detect if we're initializing by absence of existing cache data.
      is_initialized = self.has_variable('cache', 'cached_key')
      cached_key = self.variable('cache', 'cached_key',
                                 jnp.zeros, key.shape, key.dtype)
      cached_value = self.variable('cache', 'cached_value',
                                   jnp.zeros, value.shape, value.dtype)
      cache_index = self.variable('cache', 'cache_index',
                                  lambda: jnp.array(0, dtype=jnp.int32))
      if is_initialized:
        *batch_dims, max_length, num_heads, depth_per_head = (
            cached_key.value.shape)
        # shape check of cached keys against query input
        expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head)
        if expected_shape != query.shape:
          raise ValueError('Autoregressive cache shape error, '
                           'expected query shape %s instead got %s.' %
                           (expected_shape, query.shape))
        # update key, value caches with our new 1d spatial slices
        cur_index = cache_index.value
        indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
        key = lax.dynamic_update_slice(cached_key.value, key, indices)
        value = lax.dynamic_update_slice(cached_value.value, value, indices)
        cached_key.value = key
        cached_value.value = value
        cache_index.value = cache_index.value + 1
        # causal mask for cached decoder self-attention:
        # our single query position should only attend to those key
        # positions that have already been generated and cached,
        # not the remaining zero elements.
        mask = attention.combine_masks(
            mask,
            jnp.broadcast_to(jnp.arange(max_length) <= cur_index,
                             tuple(batch_dims) + (1, 1, max_length)))

        bias = lax.dynamic_slice(
            bias,
            (0, 0, cur_index, 0),
            (1, self.num_heads, 1, max_length))

    # Convert the boolean attention mask to an attention bias.
    if mask is not None:
      # attention mask in the form of attention bias
      bias += lax.select(
          mask > 0,
          jnp.full(mask.shape, 0.).astype(self.dtype),
          jnp.full(mask.shape, -1e10).astype(self.dtype))

    dropout_rng = None
    if not deterministic and self.dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')

    # apply attention
    x = attention.dot_product_attention(
        query,
        key,
        value,
        bias=bias,
        dropout_rng=dropout_rng,
        dropout_rate=self.dropout_rate,
        broadcast_dropout=self.broadcast_dropout,
        deterministic=deterministic,
        dtype=self.dtype,
        precision=self.precision)  # pytype: disable=wrong-keyword-args
    # back to the original inputs dimensions
    out = linear.DenseGeneral(features=features,
                              axis=(-2, -1),
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init,
                              use_bias=self.use_bias,
                              dtype=self.dtype,
                              precision=self.precision,
                              name='out')(x)
    return out
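
The relative-position bias reshaping above in isolation: the embedding lookup produces `(query_length, key_length, num_heads)`, which is transposed to `(num_heads, query_length, key_length)` and given size-1 batch dims so it broadcasts against the attention logits (the sizes are illustrative):

import jax.numpy as jnp

q_len, k_len, num_heads, batch_shape = 4, 4, 2, (3,)
bias = jnp.zeros((q_len, k_len, num_heads))
bias = bias.transpose((2, 0, 1))
bias = jnp.broadcast_to(bias, (1,) * len(batch_shape) + bias.shape)
print(bias.shape)  # (1, 2, 4, 4), broadcastable against (3, 2, 4, 4) logits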
Example #27
 def __call__(self, shape: Sequence[int], dtype) -> jnp.ndarray:
     return jnp.broadcast_to(self.constant, shape).astype(dtype)
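
A usage sketch as a free-standing function (a hypothetical equivalent of the initializer above):

import jax.numpy as jnp

def constant_init(constant, shape, dtype=jnp.float32):
    # Broadcast the stored scalar to the requested shape and dtype.
    return jnp.broadcast_to(constant, shape).astype(dtype)

print(constant_init(0.5, (2, 3)))  # a (2, 3) array filled with 0.5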
Example #28
    def decode(
        self,
        decoder_input_ids,
        encoder_outputs,
        decoder_attention_mask: Optional[jnp.ndarray] = None,
        decoder_position_ids: Optional[jnp.ndarray] = None,
        past_key_values: dict = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        train: bool = False,
        params: dict = None,
        dropout_rng: PRNGKey = None,
    ):
        r"""
        Returns:

        Example:

        ```python
        >>> from transformers import ViTFeatureExtractor, FlaxVisionEncoderDecoderModel
        >>> import jax.numpy as jnp
        >>> from PIL import Image
        >>> import requests

        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")

        >>> # initialize a vit-gpt2 from pretrained ViT and GPT2 models. Note that the cross-attention layers will be randomly initialized
        >>> model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        ...     "google/vit-base-patch16-224-in21k", "gpt2"
        ... )

        >>> pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values
        >>> encoder_outputs = model.encode(pixel_values)

        >>> decoder_start_token_id = model.config.decoder.bos_token_id
        >>> decoder_input_ids = jnp.ones((pixel_values.shape[0], 1), dtype="i4") * decoder_start_token_id

        >>> outputs = model.decode(decoder_input_ids, encoder_outputs)
        >>> logits = outputs.logits
        ```"""
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        encoder_hidden_states = encoder_outputs[0]

        batch_size, sequence_length = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = decoder_input_ids.shape
        if decoder_attention_mask is None:
            decoder_attention_mask = jnp.ones((batch_size, sequence_length))

        if decoder_position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")

            decoder_position_ids = jnp.broadcast_to(
                jnp.arange(sequence_length)[None, :], (batch_size, sequence_length)
            )

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If `past_key_values` are passed, the cache has already been initialized, and the private
        # flag `init_cache` has to be passed down to ensure the cache is used. The cache must also
        # be marked as mutable so that it can be updated by the decoder's attention modules.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        def _decoder_forward(
            module, decoder_input_ids, decoder_attention_mask, decoder_position_ids, encoder_hidden_states, **kwargs
        ):
            projection_module = module._get_projection_module()
            decoder_module = module._get_decoder_module()

            # optionally project encoder_hidden_states
            if projection_module is not None:
                encoder_hidden_states = projection_module(encoder_hidden_states)

            return decoder_module(
                decoder_input_ids,
                decoder_attention_mask,
                decoder_position_ids,
                encoder_hidden_states,
                **kwargs,
            )

        outputs = self.module.apply(
            inputs,
            decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
            decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
            decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=not train,
            rngs=rngs,
            mutable=mutable,
            method=_decoder_forward,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past = outputs
            outputs["past_key_values"] = unfreeze(past["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past = outputs
            outputs = outputs[:1] + (unfreeze(past["cache"]),) + outputs[1:]

        return outputs
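A hedged sketch of calling `decode` incrementally, continuing the docstring example above. It assumes the model exposes `init_cache(batch_size, max_length, encoder_outputs)`, as Flax seq2seq models in transformers typically do; note that `decoder_position_ids` must be given explicitly once `past_key_values` is passed:

# Assumes `model`, `pixel_values`, `encoder_outputs`, and `decoder_input_ids`
# from the docstring example above; `init_cache` is an assumption here.
past_key_values = model.init_cache(pixel_values.shape[0], max_length=16, encoder_outputs=encoder_outputs)
position_ids = jnp.zeros((pixel_values.shape[0], 1), dtype="i4")
outputs = model.decode(
    decoder_input_ids,
    encoder_outputs,
    decoder_position_ids=position_ids,
    past_key_values=past_key_values,
)
past_key_values = outputs.past_key_values  # updated cache for the next step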
Example #29
0
def second(x, y):
    return jnp.broadcast_to(y, x.shape)
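A quick illustrative call (not from the source), using the `second` helper above:

import jax.numpy as jnp

x = jnp.zeros((2, 3))
y = jnp.arange(3)
print(second(x, y))  # [[0 1 2], [0 1 2]] -- y broadcast to x's shape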
        def beam_search_body_fn(state):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along the length axis to feed the fast
            # autoregressive decoder model, flattening the beam dimension into the batch
            # dimension for the model call.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)
            # Unflatten the beam dimension in the logits and the attention cache arrays.
            logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))
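            # e.g. with num_beams = 4 and a 50k-token vocabulary this is a
            # (batch_size, 200_000) array of candidate continuations (illustrative figures).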

            # 3. Retrieve top-K
            # Each item in the batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*K candidates with the highest log
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather the 2*K top beams.
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            # Recover the token id by modulo division and expand the id array for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update the sequences for the 2*K new top candidates.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))
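            # Worked example (assumed figures): with vocab_size = 50_000, flat index
            # 100_007 decodes to beam 100_007 // 50_000 = 2 and token id 100_007 % 50_000 = 7.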

            # 4. Check which sequences have ended
            # Did any of the new top-2*K sequences just reach an end marker?
            # To prevent just-finished sequences from remaining in the set of active
            # beam search sequences, set their log probs to a very large negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

            # 5. Get the running sequences and scores for the next step
            # Determine the top k beam indices (from the top 2*k beams) from the log probs
            # and gather those top k beams (from the top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process top-k log probs
            # Further process the log probs:
            # - add the length penalty
            # - make sure no scores can be added anymore if the beam is full
            # - make sure still-running sequences cannot be chosen as a finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
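            # Note: log probs are negative, so dividing by cur_len ** length_penalty boosts
            # longer sequences when length_penalty > 0 and penalizes them when it is < 0.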
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get the scores, sequences, and finished flags for the next step.
            # Combine sequences, scores, and flags along the beam dimension, compare the
            # newly finished sequence scores to the existing finished scores, and select
            # the best beams from the combined set.
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
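The body above relies on helpers (`flatten_beam_dim`, `unflatten_beam_dim`, `gather_beams`) defined elsewhere in the generation code; a rough sketch of what they do, not the exact transformers implementation:

import jax
import jax.numpy as jnp

def flatten_beam_dim(tensor):
    # (batch, beams, ...) -> (batch * beams, ...)
    return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

def unflatten_beam_dim(tensor, batch_size, num_beams):
    # (batch * beams, ...) -> (batch, beams, ...)
    return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

def gather_beams(nested, beam_indices, batch_size, new_num_beams):
    # Select beams along axis 1 for every array in a pytree.
    batch_indices = jnp.reshape(
        jnp.arange(batch_size * new_num_beams) // new_num_beams,
        (batch_size, new_num_beams))
    return jax.tree_map(lambda t: t[batch_indices, beam_indices], nested)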