Example #1
def onehot(labels, num_classes, on_value=1.0, off_value=0.0):
    x = (labels[..., None] == jnp.arange(num_classes)[None])
    x = lax.select(x, jnp.full(x.shape, on_value),
                   jnp.full(x.shape, off_value))
    return x.astype(jnp.float32)
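# A minimal usage sketch for `onehot` above (hypothetical values; assumes
# `jax.numpy` is imported as `jnp` and `jax.lax` as `lax`, which the function
# requires):
import jax.numpy as jnp
from jax import lax

labels = jnp.array([0, 2, 1])
print(onehot(labels, num_classes=3))
# Expected result:
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]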
Example #2
def soft_maze_values(
    actions,
    target_state_index,
    estimates=None,
    temperature=1.0,
):
    """Compute the value function for a maximum-entropy-optimal policy in a maze.

  This function assumes we have an un-discounted reward function that is -1
  everywhere except for a specific goal location, which has 0 reward and
  instantly terminates the episode. It then returns the value of being at each
  state for an optimal entropy-regularized policy that attempts to maximize

    E[ sum_t r_t + temperature * H(policy) ].

  It does this by iterating the soft Bellman equations

    V(s) = t * log(mean(exp(Q(s, a)/t)))  (arithmetic mean over actions)
    Q(s, a) = r_{s,a} + E[V(s')]          (expectation over env dynamics)

  See https://arxiv.org/abs/1702.08165 ("RL with Deep Energy-Based Policies")

  Args:
    actions: <float32[num_states, num_states, num_actions]> which determines a
      distribution over where you end up if you take each given action. The
      policy will choose a distribution over actions to take at each state by
      considering the expected value of taking each action, but does not get to
      decide where it ends up within the selected action's distribution.
    target_state_index: State that terminates the episode.
    estimates: Initial estimates of values.
    temperature: Temperature of the soft updates, which controls how much
      implicit entropy regularization there is. For temperatures near zero, the
      policy becomes a deterministic one and the values represent shortest
      paths. For larger temperatures, the policy will have more entropy and
      paths may be longer than necessary. Larger temperatures are likely to be
      easier to differentiate through. (note: currently the output is NOT
        differentiable w.r.t. the temperature hyperparam.)

  Returns:
    values: <float[num_states]> giving the expected future reward of being in
      each state under the max-ent-optimal policy.
    q_values: <float[num_states, num_actions]> giving q values of each action.
    policy: <float[num_states, num_actions]> giving the probability of taking
      each action at each state.
  """
    num_states, _, num_actions = actions.shape
    if estimates is None:
        # Worst case values, assuming a connected graph: suppose the graph was a
        # linear chain, and the policy walks completely randomly from the start to
        # the end. Then the expected number of steps that it would take would be
        # `num_states**2` (see
        # https://en.wikipedia.org/wiki/Random_walk#One-dimensional_random_walk,
        # since this is equivalent to starting at 0 and walking to either
        # -num_states or num_states)
        estimates = jnp.full([num_states], -float(num_states)**2)

    def soft_bellman_backup(stuff, current_values):
        # Need to pass target_state_index because custom gradient machinery doesn't
        # understand it when closing over batched values.
        actions, target_state_index = stuff
        # Compute Q values for taking each action at each state
        # (every non-goal state has immediate reward -1).
        q_values = -1 + jnp.einsum("sta,t->sa", actions, current_values)
        # Compute value of current states by max-ent Bellman.
        new_v_values = (jax.scipy.special.logsumexp(q_values / temperature, -1)
                        - jnp.log(num_actions)) * temperature
        # Goal state is fixed to value 0 (episode terminates before the agent takes
        # any action).
        new_v_values = jax.ops.index_update(new_v_values,
                                            jax.ops.index[target_state_index],
                                            0.0)
        return new_v_values

    # Must iterate at least num_states to guarantee everything is reachable;
    # we iterate a bit longer to make sure things converge.
    soft_v_values = iterative_fixed_point(soft_bellman_backup,
                                          (actions, target_state_index),
                                          estimates,
                                          iterations=num_states * 2)

    # Also extract the final Q values and entropy-regularized optimal policy.
    soft_q_values = -1 + jnp.einsum("sta,t->sa", actions, soft_v_values)
    policy = jax.nn.softmax(soft_q_values / temperature, -1)

    return soft_v_values, soft_q_values, policy
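# `iterative_fixed_point` is not shown above. A minimal sketch of such a
# helper, assuming it simply re-applies the backup a fixed number of times
# (the real helper may additionally define a custom fixed-point gradient):
from jax import lax

def iterative_fixed_point(fn, stuff, initial_estimate, iterations):
    """Repeatedly applies `fn(stuff, x)` starting from `initial_estimate`."""
    def body(_, x):
        return fn(stuff, x)
    return lax.fori_loop(0, iterations, body, initial_estimate)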
Example #3
    def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and projects the results to an output vector.
    Args:
      inputs_q: input queries of shape
        `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape
        `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.
        Attention weights are masked out if their corresponding mask value
        is `False`.
      deterministic: if false, the attention weights are masked randomly
        using dropout, whereas if true, the attention weights
        are deterministic.
    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
        if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
            deterministic = module.merge_param('deterministic',
                                               self.deterministic,
                                               deterministic)
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = functools.partial(linear.DenseGeneral,
                                  axis=-1,
                                  features=(self.num_heads, head_dim),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  precision=self.precision)
        relative_attention_embed = linear.Embed(
            num_embeddings=self.num_relative_position_buckets,
            features=self.num_heads,
            embedding_init=initializers.normal(stddev=1.0),
            dtype=self.dtype)

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (dense(dtype=self.dtype, name='query')(inputs_q),
                             dense(dtype=self.dtype, name='key')(inputs_kv),
                             dense(dtype=self.dtype, name='value')(inputs_kv))

        query_length = inputs_q.shape[-2]
        key_length = inputs_kv.shape[-2]
        context_position = jnp.arange(query_length, dtype=jnp.int32)[:, None]
        memory_position = jnp.arange(key_length, dtype=jnp.int32)[None, :]

        relative_position = memory_position - context_position
        relative_position_bucket = make_relative_position_bucket(
            relative_position,
            causal=self.causal,
            num_buckets=self.num_relative_position_buckets)

        bias = relative_attention_embed(relative_position_bucket)
        bias = bias.transpose((2, 0, 1))[None, :, :, :]

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.decode:
            # detect if we're initializing by absence of existing cache data.
            is_initialized = self.has_variable('cache', 'cached_key')
            cached_key = self.variable('cache', 'cached_key', jnp.zeros,
                                       key.shape, key.dtype)
            cached_value = self.variable('cache', 'cached_value', jnp.zeros,
                                         value.shape, value.dtype)
            cache_index = self.variable('cache', 'cache_index',
                                        lambda: jnp.array(0, dtype=jnp.int32))
            if is_initialized:
                *batch_dims, max_length, num_heads, depth_per_head = (
                    cached_key.value.shape)
                # shape check of cached keys against query input
                expected_shape = tuple(batch_dims) + (1, num_heads,
                                                      depth_per_head)
                if expected_shape != query.shape:
                    raise ValueError(
                        'Autoregressive cache shape error, '
                        'expected query shape %s instead got %s.' %
                        (expected_shape, query.shape))
                # update key, value caches with our new 1d spatial slices
                cur_index = cache_index.value
                indices = (0, ) * len(batch_dims) + (cur_index, 0, 0)
                key = lax.dynamic_update_slice(cached_key.value, key, indices)
                value = lax.dynamic_update_slice(cached_value.value, value,
                                                 indices)
                cached_key.value = key
                cached_value.value = value
                cache_index.value = cache_index.value + 1
                # causal mask for cached decoder self-attention:
                # our single query position should only attend to those key
                # positions that have already been generated and cached,
                # not the remaining zero elements.
                mask = attention.combine_masks(
                    mask,
                    jnp.broadcast_to(
                        jnp.arange(max_length) <= cur_index,
                        tuple(batch_dims) + (1, 1, max_length)))

                bias = lax.dynamic_slice(bias, (0, 0, cur_index, 0),
                                         (1, self.num_heads, 1, max_length))

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            bias += lax.select(mask > 0,
                               jnp.full(mask.shape, 0.).astype(self.dtype),
                               jnp.full(mask.shape, -1e10).astype(self.dtype))

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # apply attention
        x = attention.dot_product_attention(
            query,
            key,
            value,
            bias=bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout_rate,
            broadcast_dropout=self.broadcast_dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=self.precision)  # pytype: disable=wrong-keyword-args
        # back to the original inputs dimensions
        out = linear.DenseGeneral(features=features,
                                  axis=(-2, -1),
                                  kernel_init=self.kernel_init,
                                  bias_init=self.bias_init,
                                  use_bias=self.use_bias,
                                  dtype=self.dtype,
                                  precision=self.precision,
                                  name='out')(x)
        return out
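# `make_relative_position_bucket` is not defined above. A sketch of the
# T5-style bucketing it presumably implements (the signature mirrors the call
# site; `max_distance` and the exact numerics are assumptions):
import numpy as np
import jax.numpy as jnp

def make_relative_position_bucket(relative_position, causal=False,
                                  num_buckets=32, max_distance=128):
    """Maps signed relative positions to bucket indices."""
    ret = 0
    n = -relative_position
    if not causal:
        # Half the buckets encode positive offsets, half negative offsets.
        num_buckets //= 2
        ret += (n < 0).astype(jnp.int32) * num_buckets
        n = jnp.abs(n)
    else:
        n = jnp.maximum(n, 0)
    # Half of the remaining buckets are exact small offsets; the rest grow
    # logarithmically up to `max_distance`.
    max_exact = num_buckets // 2
    is_small = n < max_exact
    val_if_large = max_exact + (
        jnp.log(n.astype(jnp.float32) / max_exact + 1e-6)
        / np.log(max_distance / max_exact)
        * (num_buckets - max_exact)).astype(jnp.int32)
    val_if_large = jnp.minimum(val_if_large, num_buckets - 1)
    return ret + jnp.where(is_small, n, val_if_large)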
    def __call__(
        self,
        hidden_states,
        attention_mask,
        sinusoidal_pos,
        layer_head_mask,
        deterministic=True,
        output_attentions: bool = False,
    ):
        head_dim = self.config.hidden_size // self.config.num_attention_heads

        query_states = self.query(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        value_states = self.value(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )
        key_states = self.key(hidden_states).reshape(
            hidden_states.shape[:2] + (self.config.num_attention_heads, head_dim)
        )

        if sinusoidal_pos is not None:
            if self.rotary_value:
                query_states, key_states, value_states = self.apply_rotary_position_embeddings(
                    sinusoidal_pos, query_states, key_states, value_states
                )
            else:
                query_states, key_states = self.apply_rotary_position_embeddings(
                    sinusoidal_pos, query_states, key_states
                )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e10).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.config.attention_probs_dropout_prob > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = jnp.einsum("...hqk,h->...hqk", attn_weights, layer_head_mask)

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        attn_output = attn_output.reshape(attn_output.shape[:2] + (-1,))

        outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
        return outputs
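# `apply_rotary_position_embeddings` is a method of the module above and is
# not shown. A standalone sketch of what such a rotary-embedding helper
# typically does (the broadcasting layout of `sinusoidal_pos` is an assumption):
import jax.numpy as jnp

def apply_rotary_position_embeddings(sinusoidal_pos, *layers):
    """Rotates query/key(/value) layers by position-dependent angles.

    `sinusoidal_pos` is assumed to hold its sin and cos halves along the last
    axis and to broadcast against each `[..., length, heads, head_dim]` layer.
    """
    sin, cos = jnp.split(sinusoidal_pos, 2, axis=-1)
    # Interleave so that consecutive feature pairs share one rotation angle.
    sin_pos = jnp.stack([sin, sin], axis=-1).reshape(sinusoidal_pos.shape)
    cos_pos = jnp.stack([cos, cos], axis=-1).reshape(sinusoidal_pos.shape)

    def rotate(layer):
        # (x0, x1, x2, x3, ...) -> (-x1, x0, -x3, x2, ...)
        rotate_half = jnp.stack([-layer[..., 1::2], layer[..., ::2]],
                                axis=-1).reshape(layer.shape)
        return layer * cos_pos + rotate_half * sin_pos

    return tuple(rotate(layer) for layer in layers)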
Example #5
    def _sample(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        prng_key: Optional[jax_xla.DeviceArray] = None,
        logits_warper: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

        batch_size, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch-item holding current token in loop.
        sequences = jnp.full((batch_size, max_length),
                             pad_token_id,
                             dtype=jnp.int32)
        sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))

        # per batch-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, ), dtype=jnp.bool_)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            input_ids, max_length, **model_kwargs)

        # initialize state
        state = SampleState(
            cur_len=cur_len,
            sequences=sequences,
            current_token=input_ids,
            is_sent_finished=is_sent_finished,
            prng_key=prng_key,
            model_kwargs=model_kwargs,
        )

        def sample_search_cond_fn(state):
            """state termination condition fn."""
            has_reached_max_length = state.cur_len == max_length
            all_sequence_finished = jnp.all(state.is_sent_finished)
            finish_generation = jnp.logical_or(has_reached_max_length,
                                               all_sequence_finished)
            return ~finish_generation

        def sample_search_body_fn(state):
            """state update fn."""
            prng_key, prng_key_next = jax.random.split(state.prng_key)
            model_outputs = model(state.current_token,
                                  params=params,
                                  **state.model_kwargs)

            logits = model_outputs.logits[:, -1]

            # apply top_k, top_p, temperature
            logits = logits_warper(state.sequences, logits)

            next_token = jax.random.categorical(prng_key, logits, axis=-1)

            next_is_sent_finished = state.is_sent_finished | (next_token
                                                              == eos_token_id)
            next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
            next_token = next_token[:, None]

            next_sequences = lax.dynamic_update_slice(state.sequences,
                                                      next_token,
                                                      (0, state.cur_len))
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return SampleState(
                cur_len=state.cur_len + 1,
                sequences=next_sequences,
                current_token=next_token,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
                prng_key=prng_key_next,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[1] > 1:
            state = sample_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(sample_search_cond_fn,
                                            sample_search_body_fn, state)
        else:
            state = lax.while_loop(sample_search_cond_fn,
                                   sample_search_body_fn, state)

        return FlaxSampleOutput(sequences=state.sequences)
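# The loop above threads its state through a `SampleState` container. A
# plausible definition, sketched as a `flax.struct.dataclass` (field names are
# taken from the code above; the base class choice is an assumption):
from typing import Any, Dict
import jax.numpy as jnp
from flax import struct

@struct.dataclass
class SampleState:
    cur_len: jnp.ndarray           # scalar int32: current generated length
    sequences: jnp.ndarray         # [batch, max_length] generated token ids
    current_token: jnp.ndarray     # [batch, 1] tokens fed at the next step
    is_sent_finished: jnp.ndarray  # [batch] bool: EOS already emitted
    prng_key: jnp.ndarray          # PRNG key threaded through the loop
    model_kwargs: Dict[str, Any]   # cache, encoder outputs, attention mask, ...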
Example #6
 def sample_eval_batch(key, k, num_data_points):
     return gmm_dist.sample_batch_fixed_ks2(key, model_name,
                                            jnp.full([eval_batch_size], k),
                                            max_k, max_num_data_points,
                                            data_dim, mode_var, cov_dof,
                                            separation_mult)
Example #7
    def __call__(
        self,
        x,
        *,
        mask,
    ):
        """Applies a tag to track distributions.

    Args:
      x: the array to compute statistics distributions over.
      mask: boolean array indicating which elements of 'x' should be
        included in the stats calculation ('True' means to include).

    Returns:
      x unchanged. The return value can also be ignored.
    """
        if mask is None:
            mask = jnp.full(x.shape, True)
        shape_utils.assert_shapes_compatible(x.shape, mask.shape)
        mask = jnp.broadcast_to(mask, x.shape)
        channel_axis = self.channel_axis
        if channel_axis is not None:
            if not isinstance(channel_axis, Iterable):
                channel_axis = (channel_axis, )
            channel_axis = normalize_axes(channel_axis, x.ndim)
            x = _take_subset_of_axes(
                x,
                axis=channel_axis,
                num_indices_per_ax=self.num_indices_per_ax)
            mask = _take_subset_of_axes(
                mask,
                axis=channel_axis,
                num_indices_per_ax=self.num_indices_per_ax)
            reduction_axis = tuple(
                [ax for ax in range(x.ndim) if ax not in channel_axis])
        else:
            reduction_axis = None

        distr_shape = ()
        if channel_axis:
            distr_shape = tuple(d for i, d in enumerate(x.shape)
                                if i in channel_axis)

        # TODO(wanglisa): Consider adding configurability to specify which
        # statistics are collected.
        init_with_zeros = lambda shape: jnp.zeros(shape, dtype=jnp.float32)
        is_initializing = not self.has_variable('stats_tag', 'min_per_ch')
        min_per_ch = self.variable(
            'stats_tag',
            'min_per_ch',
            init_with_zeros,
            distr_shape,
        )
        max_per_ch = self.variable('stats_tag', 'max_per_ch', init_with_zeros,
                                   distr_shape)
        mean_per_ch = self.variable(
            'stats_tag',
            'mean_per_ch',
            init_with_zeros,
            distr_shape,
        )
        stddev_per_ch = self.variable(
            'stats_tag',
            'stddev_per_ch',
            init_with_zeros,
            distr_shape,
        )
        absdev_per_ch = self.variable(
            'stats_tag',
            'absdev_per_ch',
            init_with_zeros,
            distr_shape,
        )
        stddev_per_ch_uncentered = self.variable(
            'stats_tag',
            'stddev_per_ch_uncentered',
            init_with_zeros,
            distr_shape,
        )
        absdev_per_ch_uncentered = self.variable(
            'stats_tag',
            'absdev_per_ch_uncentered',
            init_with_zeros,
            distr_shape,
        )
        if self.update_stats and not is_initializing:
            min_per_ch.value = jnp.min(jnp.where(mask, x, math.inf),
                                       axis=reduction_axis)
            max_per_ch.value = jnp.max(jnp.where(mask, x, -math.inf),
                                       axis=reduction_axis)
            mean_per_ch_keepdims = stats.masked_mean(x,
                                                     mask=mask,
                                                     axis=reduction_axis,
                                                     paxis_name=None,
                                                     keepdims=True)
            mean_per_ch.value = mean_per_ch_keepdims.squeeze(
                axis=reduction_axis)
            stddev_per_ch.value = jnp.sqrt(
                stats.masked_mean((x - mean_per_ch_keepdims)**2,
                                  mask=mask,
                                  axis=reduction_axis,
                                  paxis_name=None,
                                  keepdims=False))
            absdev_per_ch.value = stats.masked_mean(
                jnp.abs(x - mean_per_ch_keepdims),
                mask=mask,
                axis=reduction_axis,
                paxis_name=None,
                keepdims=False)
            stddev_per_ch_uncentered.value = jnp.sqrt(
                stats.masked_mean(jnp.square(x),
                                  mask=mask,
                                  axis=reduction_axis,
                                  paxis_name=None,
                                  keepdims=False))
            absdev_per_ch_uncentered.value = stats.masked_mean(
                jnp.abs(x),
                mask=mask,
                axis=reduction_axis,
                paxis_name=None,
                keepdims=False)
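# `stats.masked_mean` is not shown above. A minimal sketch of the masked mean
# it presumably computes (the `paxis_name` cross-device averaging is omitted):
import jax.numpy as jnp

def masked_mean(x, *, mask, axis, paxis_name=None, keepdims=False):
    """Mean of `x` over `axis`, counting only elements where `mask` is True."""
    del paxis_name  # cross-device reduction not handled in this sketch
    total = jnp.sum(jnp.where(mask, x, 0.0), axis=axis, keepdims=keepdims)
    count = jnp.sum(mask, axis=axis, keepdims=keepdims)
    return total / count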
Example #8
 def full(self: TensorType, shape: ShapeOrScalar,
          value: float) -> TensorType:
     if not isinstance(shape, Iterable):
         shape = (shape, )
     return type(self)(np.full(shape, value, dtype=self.raw.dtype))
Example #9
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     return np.full(np.shape(x) if self.event_dim == 0 else np.shape(x)[:-1], 0.)
Example #10
 def hk_fn(x):
     return layer_stack.layer_stack(
         stack_height, with_per_layer_inputs=True)(f_with_multi_args)(
             x, jnp.full([stack_height], 2.), jnp.ones([stack_height]))
Example #11
 def model_classify(params, inputs, batch_size):
     return model.classify(params, inputs,
                           jnp.full([batch_size], k * data_points_per_mode),
                           jnp.full([batch_size], k))
Example #12
  def predict_fn(
      t: ArrayOrScalar = None,
      fx_train_or_state_0: Union[ArrayOrScalar, ODEState] = 0.,
      fx_test_0: ArrayOrScalar = None,
      k_test_train: np.ndarray = None
  ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray], ODEState]:
    """Return output predictions on train [and test] set[s] at time[s] `t`.

    Args:
      t:
        a scalar or array of scalars of any shape in strictly increasing order.
        `t=None` is equivalent to `t=np.inf` and may not converge. Equivalent of
        training steps (but can be fractional).
      fx_train_or_state_0:
        either (a) output of the network at `t == 0` on the training set or (b)
        complete ODE state (`predict.ODEState`). Pass an ODE state if you want
        to operate on the full ODE state instead of output variables only
        (useful for inspecting auxiliary variables or resuming an optimizer with
        auxiliary variables from a specific state. Note that only
        `momentum != None` optimizer currently has auxiliary variables. To
        initialize an ODE state from scratch, call
        `predict.ODEState(fx_train_0, fx_test_0)`. If an ODE state is passed, an
        ODE state is returned. `fx_train_0=None` means to not compute
        predictions on the training set.
      fx_test_0:
        output of the network at `t == 0` on the test set. `fx_test_0=None`
        means to not compute predictions on the test set.
      k_test_train:
        kernel relating test data with training data. Must have the shape of
        `zip(y_test.shape, y_train.shape)` with `trace_axes` absent. Pass
        `k_test_train=None` if you only need predictions on the training set.

    Returns:
      `fx_train_t` or `(fx_train_t, fx_test_t)` if `fx_test_0 != None` with
      potentially additional leading time dimensions matching `t.shape`.
      Alternatively can return an `ODEState` at time[s] `t`.

    Raises:
      ValueError: if `fx_test_0` is not `None`, but `k_test_train` is `None`.
    """
    _check_inputs(fx_train_or_state_0, fx_test_0, k_test_train)

    t = np.array(t if t is not None else np.inf, dtype) * learning_rate
    t_shape = t.shape
    t = t.reshape((-1,))

    # ODE solver requires `t[0]` to be the time where `fx_train_0` [and
    # `fx_test_0`] are evaluated, but also a strictly increasing sequence of
    # timesteps, so we always temporarily append an [almost] `0` at the start.
    t0 = np.where(t[0] == 0,
                  np.full((1,), -1e-24, t.dtype),
                  np.zeros((1,), t.dtype))
    t = np.concatenate([t0, t])

    # Solve the ODE.
    fx_test_shape = _get_fx_test_shape(y_train, k_test_train, trace_axes)
    state_0 = get_state_0(fx_train_or_state_0, fx_test_0, fx_test_shape)
    state_t = ode.odeint(get_dstate_dt(k_test_train), state_0, t)

    # Remove the added `t0`.
    trim = lambda x: x[1:].reshape(t_shape + x.shape[1:])
    trim_tree = lambda tree: tree_map(trim, tree)
    state_t = trim_tree(state_t)

    # `ODEState` -> `ODEState`
    if isinstance(fx_train_or_state_0, ODEState):
      return state_t

    # `np.ndarray` -> `np.ndarray`
    fx_train_t, fx_test_t = state_t.fx_train, state_t.fx_test

    if fx_train_or_state_0 is not None and fx_test_0 is None:
      return fx_train_t
    if fx_test_0 is not None and fx_train_or_state_0 is None:
      return fx_test_t
    return fx_train_t, fx_test_t
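# A hypothetical call pattern for `predict_fn`. The inputs (`fx_train_0`,
# `fx_test_0`, `k_test_train`) are placeholders for values produced elsewhere
# by the surrounding predictor; only the calling convention is illustrated.
import numpy as np

ts = np.array([1.0, 10.0, 100.0])  # "training time" checkpoints to evaluate
fx_train_t, fx_test_t = predict_fn(
    t=ts,
    fx_train_or_state_0=fx_train_0,  # network output on the train set at t=0
    fx_test_0=fx_test_0,             # network output on the test set at t=0
    k_test_train=k_test_train)       # kernel between test and train inputs
# fx_train_t / fx_test_t gain a leading time axis matching ts.shape.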
Example #13
def test_warmup_adapter(jitted):
    def find_reasonable_step_size(step_size, m_inv, z, rng_key):
        return jnp.where(step_size < 1, step_size * 4, step_size / 4)

    num_steps = 150
    adaptation_schedule = build_adaptation_schedule(num_steps)
    init_step_size = 1.0
    mass_matrix_size = 3

    wa_init, wa_update = warmup_adapter(num_steps, find_reasonable_step_size)
    wa_update = jit(wa_update) if jitted else wa_update

    rng_key = random.PRNGKey(0)
    z = jnp.ones(3)
    wa_state = wa_init(
        (z, None, None, None),
        rng_key,
        init_step_size,
        mass_matrix_size=mass_matrix_size,
    )
    step_size, inverse_mass_matrix, _, _, _, _, window_idx, _ = wa_state
    assert step_size == find_reasonable_step_size(
        init_step_size, inverse_mass_matrix, z, rng_key
    )
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))
    assert window_idx == 0

    window = adaptation_schedule[0]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(
            t, 0.7 + 0.1 * t / (window.end - window.start), z, wa_state
        )
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, _, window_idx, _ = wa_state
    assert window_idx == 1
    # step_size is decreased because accept_prob < target_accept_prob
    assert step_size < last_step_size
    # inverse_mass_matrix does not change at the end of the first window
    assert_allclose(inverse_mass_matrix, jnp.ones(mass_matrix_size))

    window = adaptation_schedule[1]
    window_len = window.end - window.start
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(
            t, 0.8 + 0.1 * (t - window.start) / window_len, 2 * z, wa_state
        )
    last_step_size = step_size
    step_size, inverse_mass_matrix, _, _, _, _, window_idx, _ = wa_state
    assert window_idx == 2
    # step_size is increased because accept_prob > target_accept_prob
    assert step_size > last_step_size
    # Verifies that inverse_mass_matrix changes at the end of the second window.
    # Because z_flat is constant during the second window, covariance will be 0
    # and only regularize_term of welford scheme is involved.
    # This also verifies that z_flat terms in the first window does not affect
    # the second window.
    welford_regularize_term = 1e-3 * (5 / (window.end + 1 - window.start + 5))
    assert_allclose(
        inverse_mass_matrix,
        jnp.full((mass_matrix_size,), welford_regularize_term),
        atol=1e-7,
    )

    window = adaptation_schedule[2]
    for t in range(window.start, window.end + 1):
        wa_state = wa_update(t, 0.8, t * z, wa_state)
    last_step_size = step_size
    step_size, final_inverse_mass_matrix, _, _, _, _, window_idx, _ = wa_state
    assert window_idx == 3
    # during the last window, because target_accept_prob=0.8,
    # log_step_size will be equal to the constant prox_center=log(10*last_step_size)
    assert_allclose(step_size, last_step_size * 10, atol=1e-6)
    # Verifies that inverse_mass_matrix does not change during the last window
    # despite z_flat changes w.r.t time t,
    assert_allclose(final_inverse_mass_matrix, inverse_mass_matrix)
def test_standard_gamma_stats(alpha):
    rng = random.PRNGKey(0)
    z = standard_gamma(rng, np.full((1000,), alpha))
    assert_allclose(np.mean(z), alpha, rtol=0.06)
    assert_allclose(np.var(z), alpha, rtol=0.2)
Example #15
    def __call__(
        self,
        hidden_states,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        if not is_cross_attention:
            qkv_out = self.c_attn(hidden_states)
            query, key, value = jnp.split(qkv_out, 3, axis=2)
        else:
            q_out = self.q_attn(hidden_states)
            (query, ) = jnp.split(q_out, 1, axis=2)
            kv_out = self.c_attn(key_value_states)
            key, value = jnp.split(kv_out, 2, axis=2)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.causal:
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :
                                               key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        dropout_rng = None
        if not deterministic and self.config.attn_pdrop > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attn_pdrop,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.c_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
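# `self._split_heads` / `self._merge_heads` are not shown above. They
# presumably wrap reshapes like the following sketch (written as free
# functions; on the module they would read num_heads/head_dim from `self`):
def split_heads(hidden_states, num_heads, head_dim):
    # [batch, length, embed_dim] -> [batch, length, num_heads, head_dim]
    return hidden_states.reshape(hidden_states.shape[:2] + (num_heads, head_dim))

def merge_heads(hidden_states, num_heads, head_dim):
    # [batch, length, num_heads, head_dim] -> [batch, length, embed_dim]
    return hidden_states.reshape(hidden_states.shape[:2] + (num_heads * head_dim,))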
Example #16
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        output_attentions: bool = False,
    ):
        query = self.q_proj(hidden_states)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        causal_attention_mask = None
        if self.causal:
            query_length, key_length = query.shape[1], key.shape[1]
            causal_attention_mask = self.causal_mask[:, :,
                                                     key_length - query_length:
                                                     key_length, :key_length]

        if attention_mask is not None and causal_attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))
            attention_mask = combine_masks(attention_mask,
                                           causal_attention_mask,
                                           dtype="i4")
        elif causal_attention_mask is not None:
            attention_mask = causal_attention_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        if attention_mask is not None:
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -1e4).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
Example #17
 def full(self, shape, fill_value, type_as=None):
     if type_as is None:
         return jnp.full(shape, fill_value)
     else:
         return jnp.full(shape, fill_value, dtype=type_as.dtype)
Example #18
    def __call__(
        self,
        hidden_states,
        attention_mask=None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
    ):
        query = self.q_proj(hidden_states) * jnp.sqrt(self.head_dim).astype(
            self.dtype)
        key = self.k_proj(hidden_states)
        value = self.v_proj(hidden_states)

        query = self._split_heads(query)
        key = self._split_heads(key)
        value = self._split_heads(value)

        query_length, key_length = query.shape[1], key.shape[1]

        if self.has_variable("cache", "cached_key"):
            mask_shift = self.variables["cache"]["cache_index"]
            max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
            causal_mask = lax.dynamic_slice(
                self.causal_mask, (0, 0, mask_shift, 0),
                (1, 1, query_length, max_decoder_length))
        else:
            causal_mask = self.causal_mask[:, :, :query_length, :key_length]

        batch_size = hidden_states.shape[0]
        causal_mask = jnp.broadcast_to(causal_mask,
                                       (batch_size, ) + causal_mask.shape[1:])

        attention_mask = jnp.broadcast_to(
            jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
        attention_mask = combine_masks(attention_mask, causal_mask)

        dropout_rng = None
        if not deterministic and self.config.attention_dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.has_variable("cache", "cached_key") or init_cache:
            key, value, attention_mask = self._concatenate_to_cache(
                key, value, query, attention_mask)

        # transform boolean mask into float mask
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
            jnp.full(attention_mask.shape, -1e9).astype(self.dtype),
        )

        # usual dot product attention
        attn_weights = dot_product_attention_weights(
            query,
            key,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.config.attention_dropout,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)
        attn_output = self.resid_dropout(attn_output,
                                         deterministic=deterministic)

        outputs = (attn_output,
                   attn_weights) if output_attentions else (attn_output, )
        return outputs
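# `self._concatenate_to_cache` is not shown above. A sketch of the standard
# Flax autoregressive-cache pattern it presumably follows (variable names and
# the exact mask handling are assumptions):
import jax.numpy as jnp
from jax import lax
import flax.linen as nn
from flax.linen import combine_masks

class CachedAttentionMixin(nn.Module):

    @nn.compact
    def _concatenate_to_cache(self, key, value, query, attention_mask):
        """Inserts the new key/value slice into the decoding cache."""
        is_initialized = self.has_variable("cache", "cached_key")
        cached_key = self.variable("cache", "cached_key", jnp.zeros,
                                   key.shape, key.dtype)
        cached_value = self.variable("cache", "cached_value", jnp.zeros,
                                     value.shape, value.dtype)
        cache_index = self.variable("cache", "cache_index",
                                    lambda: jnp.array(0, dtype=jnp.int32))
        if is_initialized:
            *batch_dims, max_length, _, _ = cached_key.value.shape
            # Write the new slice at the current decoding position.
            cur_index = cache_index.value
            indices = (0,) * len(batch_dims) + (cur_index, 0, 0)
            key = lax.dynamic_update_slice(cached_key.value, key, indices)
            value = lax.dynamic_update_slice(cached_value.value, value, indices)
            cached_key.value = key
            cached_value.value = value
            num_updated = query.shape[1]
            cache_index.value = cache_index.value + num_updated
            # The query may only attend to already-filled cache positions,
            # not the remaining zero padding.
            pad_mask = jnp.broadcast_to(
                jnp.arange(max_length) < cur_index + num_updated,
                tuple(batch_dims) + (1, num_updated, max_length))
            attention_mask = combine_masks(pad_mask, attention_mask)
        return key, value, attention_mask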
Example #19
 def model_classify(params, inputs, input_length, k):
     return gmm_models.classify_with_defaults(
         model, params, inputs, eval_batch_size,
         jnp.full([eval_batch_size], input_length, dtype=jnp.int32),
         jnp.full([eval_batch_size], k, dtype=jnp.int32), max_k,
         jnp.eye(data_dim) * mode_var)
Example #20
def fullsigval(inputs: Signal, fill_value=1):
    x, t = inputs
    full_shape = (x.shape[0] + t.start - t.stop, ) + x.shape[1:]
    return jnp.full(full_shape, fill_value, dtype=x.dtype)
Example #21
 def test_pool_custom_reduce(self):
     x = jnp.full((1, 3, 3, 1), 2.)
     mul_reduce = lambda x, y: x * y
     y = nn.pooling.pool(x, 1., mul_reduce, (2, 2), (1, 1), 'VALID')
     onp.testing.assert_allclose(y, onp.full((1, 2, 2, 1), 2.**4))
Example #22
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=default_kernel_init,
              bias_init=initializers.zeros,
              bias=True):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention and project the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied ( 'None' means
        attention over all axes, but batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad tokens.
      key_padding_mask: boolean specifying key-value tokens that are pad tokens.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = DenseGeneral.partial(axis=-1,
                                     features=(num_heads, head_dim),
                                     kernel_init=kernel_init,
                                     bias_init=bias_init,
                                     bias=bias,
                                     precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))

        if cache:
            assert isinstance(cache,
                              Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(lambda: (key.ndim, key.shape[-2:]))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, _CacheEntry):
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                indices = [0] * len(cshape)
                i = cache_entry.i
                attn_size = onp.prod(onp.take(cshape, attention_axis))
                for attn_dim in attention_axis:
                    attn_size //= cshape[attn_dim]
                    indices[attn_dim] = i // attn_size
                    i = i % attn_size

                key = lax.dynamic_update_slice(cache_entry.key, key, indices)
                value = lax.dynamic_update_slice(cache_entry.value, value,
                                                 indices)
                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(jnp.float32)[...,
                                                                        None]

        # create attention masks
        mask_components = []

        if causal_mask:
            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(onp.take(key.shape, attention_axis))
                attn_size = onp.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))

        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                             padding_mask_key=key_padding_mask,
                                             query_shape=query.shape,
                                             key_shape=key.shape,
                                             attention_axis=attention_axis)
            mask_components.append(padding_mask)

        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)

        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)

            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # apply attention
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

        # back to the original inputs dimensions
        out = DenseGeneral(x,
                           features=features,
                           axis=(-2, -1),
                           kernel_init=kernel_init,
                           bias_init=bias_init,
                           bias=bias,
                           dtype=dtype,
                           precision=precision,
                           name='out')

        return out
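# `_make_causal_mask` (used above) is not shown. For the common case of a
# single length axis it amounts to a lower-triangular comparison like this
# sketch; the real helper also handles multi-dimensional `attention_axis`:
import jax.numpy as jnp

def make_causal_mask_1d(key):
    """Causal mask sketch: query position q may attend to key position k <= q."""
    length = key.shape[1]                  # key: [batch, length, heads, head_dim]
    idxs = jnp.arange(length)
    mask = idxs[:, None] >= idxs[None, :]  # [length, length], lower-triangular
    # Leading singleton dims let the mask broadcast over batch and head dims.
    return mask[None, None, :, :]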
    def _beam_search(
        self,
        input_ids: None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jnp.ndarray]] = None,
        model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """
        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1], ) +
                                  tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams,
                (batch_size, new_num_beams))

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length),
                             pad_token_id,
                             dtype=jnp.int32)
        running_sequences = jnp.full((batch_size, num_beams, max_length),
                                     pad_token_id,
                                     dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids,
                                                     (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(
            jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)),
            [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"][
                "last_hidden_state"] = flatten_beam_dim(
                    model_kwargs["encoder_outputs"]["last_hidden_state"])
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(
                model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(
            flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (
                max_length**length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished,
                jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7))
            improvement_still_possible = jnp.all(
                worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished)
                                & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            # unflatten beam dimension
            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            # Unflatten beam dimension in attention cache arrays
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size,
                                                  num_beams),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences),
                flatten_beam_dim(log_probs), state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather 2*k top beams.
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            # Recover token id by modulo division and expand Id array for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update sequences for the 2*K top-k new sequences.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent just-finished sequences from being added to the set of
            # still-active beam search sequences, set their log probs to a very
            # large negative value.
            did_topk_just_finished = (
                topk_sequences[:, :, state.cur_len] == eos_token_id)
            running_topk_log_probs = (
                topk_log_probs + did_topk_just_finished * np.array(-1.0e7))
            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape) & early_stopping
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        if input_ids.shape[-1] > 1:
            state = partial(beam_search_body_fn,
                            input_ids_length=input_ids.shape[-1])(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn,
                                            beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn,
                                   state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences,
                              state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores,
                           state.running_scores)

        # take best beam for each batch (scores are sorted in ascending order, so the last beam is the best)
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
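The top-2*K retrieval in step 3 above hinges on one indexing trick: after reshaping the log probs to (batch, num_beams * vocab_size), a flat candidate index encodes both the originating beam (floor division by vocab_size) and the appended token (modulo vocab_size). A minimal standalone sketch of just that step, with toy, invented sizes and values that are not taken from any real model:

import jax.numpy as jnp
from jax import lax

num_beams, vocab_size = 2, 5                        # toy sizes, illustration only
log_probs = jnp.array([[-1.2, -0.3, -2.0, -0.9, -4.0,    # candidates from beam 0
                        -0.1, -3.0, -1.5, -2.2, -0.7]])  # candidates from beam 1
topk_log_probs, topk_indices = lax.top_k(log_probs, k=2 * num_beams)
topk_beam_indices = topk_indices // vocab_size      # which beam each candidate extends
topk_token_ids = topk_indices % vocab_size          # which token it appends
print(topk_beam_indices)                            # [[1 0 1 0]]
print(topk_token_ids)                               # [[0 1 4 3]]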
Example #24
0
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel"""

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache: prepare the causal attention mask
        if self.causal:
            query_length = query_states.shape[1]
            key_length = key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"][
                    "cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask, (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length))
            else:
                causal_mask = self.causal_mask[
                    :, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size, ) +
                                           causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(
                jnp.expand_dims(attention_mask, axis=(-3, -2)),
                causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key")
                            or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask)

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape,
                         jnp.finfo(self.dtype).min).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights,
                                 value_states)
        attn_output = self._merge_heads(attn_output)
        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
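The mask handling at the end of this layer is a recurring jnp.full pattern: build a zero bias where attention is allowed and a large negative bias where it is not, then add that bias to the attention logits before the softmax. A minimal sketch with a hypothetical padding mask (values invented for illustration):

import jax.numpy as jnp
from jax import lax

dtype = jnp.float32
attention_mask = jnp.array([[1, 1, 1, 0]])           # 1 = attend, 0 = padding
attention_bias = lax.select(
    attention_mask > 0,
    jnp.full(attention_mask.shape, 0.0).astype(dtype),
    jnp.full(attention_mask.shape, jnp.finfo(dtype).min).astype(dtype),
)
# Adding this bias to the raw attention logits drives the softmax probability
# of masked positions to (numerically) zero.
print(attention_bias)                                # roughly [[0. 0. 0. -3.4e+38]]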
Example #25
0
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     return jnp.full(jnp.shape(x)[:-1], 0.)
Example #26
0
def modelSelection(
    X, y, models=None, prior="normal", family="logistic", groups=None, method="post"
):
    """Bayesian model selection for generalized linear models.

    ``modelSelection`` enumerates all the models or iterates over ``models`` to
    perform Bayesian model selection for linear, logistic or Poisson regression
    using local or non-local priors.

    Parameters
    ----------
    X : array_like
        Design matrix
    y : array_like
        Observations
    models : array_like, optional
        If ``None``, all possible models are enumerated (ignoring ``groups``).
    prior : {mom, normal}, optional
    family : {logistic, poisson}, optional
    groups : array_like, optional
    method : {post, lik}, optional

    Returns
    -------
    models : array_like
        Models considered in enumeration.
    modelprobs : array_like
        Posterior probabilities assigned to each model.
    """
    if prior not in ("normal", "mom"):
        raise ValueError("prior not recognized")
    _, p = X.shape
    if models is None:
        n_models = 2 ** p
        model_i = 0
    else:
        n_models = models.shape[0]
    if groups is None:
        groups = jnp.arange(p)
    W, Winv = get_group_zellner(groups, X, prior == "mom")
    p_j = get_p_j(groups)
    modelprobs = jnp.empty(n_models)
    fact_y = jnp.sum(gammaln(y + 1))
    ytX = jnp.dot(y, X)
    if prior == "mom":
        XtX = jnp.dot(X.T, X)
    ## define vmapped functions ##
    # likelihood helpers
    if family == "poisson":
        _loglik = lambda b, ytx, x: poisson_log_lik(b, ytx, x, fact_y=fact_y)
    elif family == "logistic":
        _loglik = logistic_log_lik
    else:
        raise ValueError("family not recognized")
    if method == "post":
        _logpost = lambda b, ytx, x, w: (_loglik(b, ytx, x) + normalprior(b, 1, 1, w))
        vmaparg = (0, 0, 0, 0)
        logpost = jit(vmap(_logpost, vmaparg, 0))
        glogpost = jit(vmap(grad(_logpost, argnums=0), vmaparg, 0))
        hlogpost = jit(vmap(hessian(_logpost, argnums=0), vmaparg, 0))
        jitted_ala = jit(
            Partial(
                marghood_ala_post, logpost=logpost, glogpost=glogpost, hlogpost=hlogpost
            )
        )
    elif method == "lik":
        logpr = jit(vmap(lambda b, w: normalprior(b, 1, 1, w), (0, 0), 0))
        loglik = jit(vmap(_loglik, (0, 0, 0), 0))
        gloglik = jit(vmap(grad(_loglik, argnums=0), (0, 0, 0), 0))
        hloglik = jit(vmap(hessian(_loglik, argnums=0), (0, 0, 0), 0))
        jitted_ala = jit(
            Partial(
                marghood_ala_lik,
                logl=loglik,
                glogl=gloglik,
                hlogl=hloglik,
                logpr=logpr,
            )
        )
    else:
        raise ValueError("method not recognized")
    ## main loop: vmap needs equal shapes, so models are processed in batches ##
    ## grouped by the number of selected variables ##
    for n_vars in range(1, p + 1):
        if models is None:
            models_iter = jnp.array(list(combinations(jnp.arange(p), n_vars)))
            model_mask = jnp.full((n_models,), False)
            model_mask = model_mask.at[model_i : model_i + models_iter.shape[0]].set(True)
            model_i += models_iter.shape[0]
        else:
            model_mask = models.sum(axis=1) == n_vars
            if model_mask.sum() == 0:
                continue
            models_iter = (
                jnp.arange(p)
                .reshape((1, -1))[models[model_mask, :]]
                .reshape((-1, n_vars))
            )
        b0 = jnp.zeros(models_iter.shape)
        X_iter = apply_mask_2d(X, models_iter)
        ytX_iter = apply_mask_1d(ytX, models_iter)
        W_iter = apply_mask_matrix(W, models_iter)
        if method == "post":
            args = [ytX_iter, X_iter, W_iter]
            margs = jitted_ala(b0=b0, args=args)
        elif method == "lik":
            argspr = (W_iter,)
            argsl = (ytX_iter, X_iter)
            margs = jitted_ala(b0=b0, argsl=argsl, argspr=argspr)
        if prior == "mom":
            p_j_iter = apply_mask_1d(p_j, models_iter)
            Winv_iter = apply_mask_matrix(Winv, models_iter)
            XtX_iter = apply_mask_matrix(XtX, models_iter)
            margs += vmap(gmomprior_correction, (None, 0, 0, 0, 0), 0)(
                1, Winv_iter, p_j_iter, XtX_iter, ytX_iter
            )
        modelprobs = modelprobs.at[model_mask].set(margs)
    return models, modelprobs
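A hedged usage sketch for the function above, assuming its helpers (get_group_zellner, the apply_mask_* utilities and the ALA marginal-likelihood routines) are importable from the same module; the design matrix and response below are synthetic and purely illustrative:

import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 3))                            # synthetic design matrix
y = (X @ np.array([1.5, 0.0, -2.0]) > 0).astype(float)   # synthetic binary outcome

# Enumerate every non-empty subset of the 3 predictors under a normal prior
# for a logistic model; modelprobs holds one score per enumerated model
# (see the docstring above).
models, modelprobs = modelSelection(X, y, prior="normal",
                                    family="logistic", method="post")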
Example #27
0
 def _get_transform(self):
     loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
     scale = numpyro.param('{}_scale'.format(self.prefix),
                           np.full(self.latent_size, self._init_scale),
                           constraint=constraints.positive)
     return AffineTransform(loc, scale, domain=constraints.real_vector)
Example #28
0
 def _get_posterior(self):
     loc = numpyro.param('{}_loc'.format(self.prefix), self._init_latent)
     scale = numpyro.param('{}_scale'.format(self.prefix),
                           jnp.full(self.latent_dim, self._init_scale),
                           constraint=constraints.positive)
     return dist.Normal(loc, scale)
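Examples #27 and #28 use `full` identically: to start every coordinate of the variational scale at the same positive constant before optimization. A minimal sketch of just that initialization, with an invented latent size and initial scale:

import jax.numpy as jnp

latent_dim, init_scale = 4, 0.1                      # hypothetical values
scale_init = jnp.full(latent_dim, init_scale)        # constant positive start
print(scale_init)                                    # [0.1 0.1 0.1 0.1]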
Example #29
0
 def _counter_init(rng, shape, dtype, state):
     del rng, dtype
     state['counter'] += 1.
     return jnp.full(shape, state['counter'])
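A small usage sketch of the stateful initializer above, assuming `state` is a plain dict and ignoring the framework that would normally supply `rng` and `dtype`:

state = {'counter': 0.}
first = _counter_init(rng=None, shape=(2, 2), dtype=None, state=state)   # filled with 1.
second = _counter_init(rng=None, shape=(2, 2), dtype=None, state=state)  # filled with 2.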
Example #30
0
 def testFull(self, shape, fill_value_dtype, out_dtype, rng):
   onp_fun = lambda fill_value: onp.full(shape, fill_value, dtype=out_dtype)
   lnp_fun = lambda fill_value: lnp.full(shape, fill_value, dtype=out_dtype)
   args_maker = lambda: [rng((), fill_value_dtype)]
   self._CheckAgainstNumpy(onp_fun, lnp_fun, args_maker, check_dtypes=True)
   self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True)
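The test above asserts that `lnp.full` (JAX's NumPy clone, imported as `lnp` in the old test suite) agrees with `onp.full` (original NumPy) for a scalar fill value and an explicit output dtype; a minimal direct check under those assumed import aliases:

import numpy as onp
import jax.numpy as lnp

fill_value = onp.float32(3.0)
expected = onp.full((2, 3), fill_value, dtype=onp.int32)
actual = lnp.full((2, 3), fill_value, dtype=lnp.int32)
assert onp.array_equal(expected, actual)             # both are 2x3 arrays of 3s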