Example #1
        def beam_search_body_fn(state, input_ids_length=1):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model. Flatten the beam dimension into the
            # batch dimension for feeding into the model.
            input_token = flatten_beam_dim(
                lax.dynamic_slice(
                    state.running_sequences,
                    (0, 0, state.cur_len - input_ids_length),
                    (batch_size, num_beams, input_ids_length),
                ))
            model_outputs = model(input_token,
                                  params=params,
                                  **state.model_kwargs)

            # Unflatten the beam dimension in the logits and in the attention
            # cache arrays.
            logits = unflatten_beam_dim(model_outputs.logits[:, -1],
                                        batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size,
                                                  num_beams),
                model_outputs.past_key_values)

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences),
                flatten_beam_dim(log_probs), state.cur_len)
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores,
                                                    axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*K candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if
            # the best K sequences reach EOS simultaneously, we have another K
            # sequences remaining to continue the live beam search.
            beams_to_keep = 2 * num_beams
            # Gather the top 2*K scores from _all_ beams.
            topk_log_probs, topk_indices = lax.top_k(log_probs,
                                                     k=beams_to_keep)
            # Recover the beam index by floor division.
            topk_beam_indices = topk_indices // vocab_size
            # Gather the 2*K top beams.
            topk_running_sequences = gather_beams(state.running_sequences,
                                                  topk_beam_indices,
                                                  batch_size, beams_to_keep)
            # Recover the token id by modulo division and expand the id array
            # for broadcasting.
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            # Update sequences for the 2*K new top-k sequences.
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences,
                                                      topk_ids,
                                                      (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Did any of these sequences reach an end marker? To prevent just
            # finished sequences from being added to the set of active beam
            # search sequences, set their log probs to a very large negative
            # value.
            did_topk_just_finished = (
                topk_sequences[:, :, state.cur_len] == eos_token_id)
            running_topk_log_probs = (
                topk_log_probs + did_topk_just_finished * np.array(-1.0e7))
            # 5. Get running sequences and scores for the next iteration.
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(running_topk_log_probs,
                                                   k=num_beams)[1],
                                         axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, running_topk_log_probs], next_topk_indices,
                batch_size, num_beams)

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len**length_penalty)
            beams_in_batch_are_full = (jnp.broadcast_to(
                state.is_sent_finished.all(axis=-1, keepdims=True),
                did_topk_just_finished.shape)
                                       & early_stopping)
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, and is-sentence-finished flags for the next iteration.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate(
                [state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs],
                                            axis=1)
            merged_is_sent_finished = jnp.concatenate(
                [state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores,
                                                     k=num_beams)[1],
                                           axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished],
                topk_merged_indices, batch_size, num_beams)

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices,
                                                next_topk_indices, batch_size,
                                                num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size,
                                      num_beams)
            model_outputs["past_key_values"] = jax.tree_map(
                lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(
                model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )
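
A note on the jnp.flip(lax.top_k(...)[1], axis=1) pattern in steps 5 and 7 above: lax.top_k returns indices sorted from highest to lowest score, and the flip stores the kept beams in ascending score order instead, so the best beam sits in the last slot. A minimal sketch with hypothetical values:

import jax.numpy as jnp
from jax import lax

scores = jnp.array([[0.1, 0.9, 0.4, 0.7]])    # [batch=1, num_beams * vocab_size]
top_scores, top_idx = lax.top_k(scores, k=2)  # descending: [[0.9 0.7]], [[1 3]]
asc_idx = jnp.flip(top_idx, axis=1)           # ascending score order: [[3 1]]
print(top_scores, top_idx, asc_idx)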
Example #2
File: jax.py Project: yibit/eagerpy
 def flip(self: TensorType, axis: Optional[AxisAxes] = None) -> TensorType:
     return type(self)(np.flip(self.raw, axis=axis))
Example #3
 def flip(self, a, axis=None):
     return jnp.flip(a, axis)
Example #4
 def sort(tensor, axis, descending=False):
     if descending:
         return np.flip(np.sort(tensor, axis=axis), axis=axis)
     else:
         return np.sort(tensor, axis=axis)
Example #5
def flip(x, axis=None):
  if isinstance(x, JaxArray): x = x.value
  return JaxArray(jnp.flip(x, axis=axis))
Example #6
def flip(x, axis=None, _=None):
    if isinstance(axis, list) or isinstance(axis, tuple):
        raise Exception('Jax does not support flip() across multiple indices')
    return _jnp.flip(x, axis)
Example #7
 def get_index(x):
     x = 2 * (grid.lower - x.flatten()) / grid.size + 1
     idx = (grid.shape * np.arccos(x)) // np.pi
     idx = np.nan_to_num(idx, nan=grid.shape)
     return (*np.flip(np.uint32(idx)), )
Example #8
def rewards_to_go(rewards, mask, gamma=0.99):
    r"""Computes rewards to go.

  Reward to go is defined as the discounted reward that we have yet to
  collect, going forward from this point, i.e.:

  r2g_t = \sum_{l=0}^{\infty} (\gamma^{l} * reward_{t+l})

  Args:
    rewards: np.ndarray of shape (B, T) of rewards.
    mask: np.ndarray of shape (B, T) of mask for the rewards.
    gamma: float, discount factor.

  Returns:
    rewards to go, np.ndarray of shape (B, T).
  """
    B, T = rewards.shape  # pylint: disable=invalid-name,unused-variable

    masked_rewards = rewards * mask  # (B, T)

    # The lax.scan version of this is slow, but we still show it here for
    # completeness.
    #   rewards_rev = np.flip(masked_rewards, axis=1)  # (B, T) flipped on time.
    #   rrt = np.transpose(rewards_rev)  # (T, B) transpose to scan over time.
    #
    #   def discounting_add(carry, reward):
    #     x = reward + (gamma * carry)
    #     return x, x
    #
    #   _, ys = lax.scan(discounting_add,
    #                    np.zeros_like(rrt[0], dtype=np.float32),
    #                    rrt.astype(np.float32))
    #
    #   # ys is (T, B) and T is in reverse order.
    #   return np.flip(np.transpose(ys), axis=1)

    # We use the following recurrence relation, derived from the equation above:
    #
    # r2g[t+1] = (r2g[t] - r[t]) / gamma
    #
    # This means we'll need to calculate r2g[0] first and then r2g[1] and so on ..
    #
    # **However** this leads to overflows for long sequences: r2g[t] - r[t] > 0
    # and gamma < 1.0, so the division keeps increasing.
    #
    # So we just run the recurrence in reverse, i.e.
    #
    # r2g[t] = r[t] + (gamma*r2g[t+1])
    #
    # This is much better, but may lose precision, since the (small) rewards
    # at earlier time-steps get added to a (very?) large running sum.

    # Compute r2g_{T-1} at the start and then compute backwards in time.
    r2gs = [masked_rewards[:, -1]]

    # Go from T-2 down to 0.
    for t in reversed(range(T - 1)):
        r2gs.append(masked_rewards[:, t] + (gamma * r2gs[-1]))

    # The list should have length T.
    assert T == len(r2gs)

    # First we stack them in the correct way to make it (B, T), but these are
    # still from newest (T-1) to oldest (0), so then we flip it on time axis.
    return np.flip(np.stack(r2gs, axis=1), axis=1)
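
A quick numeric check of the recurrence above, using hypothetical values and the rewards_to_go function from this example: the backward loop should match the direct discounted sum.

import numpy as np

rewards = np.array([[1.0, 2.0, 3.0]])
mask = np.ones_like(rewards)
# Direct sums with gamma = 0.5: r2g[0] = 1 + 0.5*2 + 0.25*3 = 2.75,
# r2g[1] = 2 + 0.5*3 = 3.5, r2g[2] = 3.
print(rewards_to_go(rewards, mask, gamma=0.5))  # [[2.75 3.5  3.  ]]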
Example #9
 def reverse_line(self, line, b):
     return jax.lax.cond(b==1, lambda z : z, lambda z : jnp.flip(z,0), line) 
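
Usage sketch for the snippet above, with hypothetical values: jax.lax.cond selects a branch at run time, and both branches must return arrays of the same shape and dtype, which identity and flip satisfy.

import jax
import jax.numpy as jnp

line = jnp.arange(4)
b = jnp.array(0)  # any value other than 1 reverses the line
out = jax.lax.cond(b == 1, lambda z: z, lambda z: jnp.flip(z, 0), line)
print(out)  # [3 2 1 0]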
Example #10
 def get_index(x):
     h = grid.size / grid.shape
     idx = (x.flatten() - grid.lower) // h
     idx = idx % grid.shape
     return (*np.flip(np.uint32(idx)), )
Example #11
 def not_0_or_2():
     perm = list(range(m.ndim))
     perm[ax1], perm[ax2] = perm[ax2], perm[ax1]
     return jnp.where(k == 1, jnp.transpose(jnp.flip(m, ax2), perm),
                      jnp.flip(jnp.transpose(m, perm), ax2))
Example #12
 def not_0():
     return jnp.where(k == 2, jnp.flip(jnp.flip(m, ax1), ax2), not_0_or_2())
Example #13
def _reverse(tensor, axis, name=None):  # pylint: disable=unused-argument
    if np.array(axis).ndim == 0:
        return np.flip(tensor, axis)
    for ax in axis:
        tensor = np.flip(tensor, ax)
    return tensor
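
A quick check with hypothetical values: composing single-axis flips as above matches np.flip called with a tuple of axes, which NumPy supports natively since 1.15.

import numpy as np

x = np.arange(6).reshape(2, 3)
print(np.array_equal(_reverse(x, axis=[0, 1]), np.flip(x, axis=(0, 1))))  # True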
Example #14
def argsort(input, dim, descending):
    _sorted = jnp.argsort(input, axis=dim)
    if descending is True:
        _sorted = jnp.flip(_sorted, axis=dim)
    return _sorted
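
Usage sketch with hypothetical values. Flipping an ascending argsort yields indices in descending value order, but note that for tied values this reverses their relative order rather than keeping the sort stable.

import jax.numpy as jnp

x = jnp.array([3, 1, 2])
print(argsort(x, dim=0, descending=False))  # [1 2 0]
print(argsort(x, dim=0, descending=True))   # [0 2 1]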
Example #15
 def get_index(x):
     h = grid.size / grid.shape
     idx = (x.flatten() - grid.lower) // h
     idx = np.where((idx < 0) | (idx > grid.shape), grid.shape, idx)
     return (*np.flip(np.uint32(idx)), )
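
In the get_index helpers above, (*np.flip(np.uint32(idx)),) casts the per-dimension indices to integers and reverses them into a tuple, apparently converting an (x, y, z) coordinate order into the (z, y, x) order used to index a row-major array. A sketch with hypothetical values:

import numpy as np

idx = np.array([2.0, 0.0, 1.0])      # (x, y, z) indices as floats
print((*np.flip(np.uint32(idx)),))   # (1, 0, 2), usable as arr[z, y, x]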
Example #16
 def testFlip(self, shape, dtype, axis, rng):
   args_maker = self._GetArgsMaker(rng, [shape], [dtype])
   lnp_op = lambda x: lnp.flip(x, axis)
   onp_op = lambda x: onp.flip(x, axis)
   self._CheckAgainstNumpy(onp_op, lnp_op, args_maker, check_dtypes=True)
   self._CompileAndCheck(lnp_op, args_maker, check_dtypes=True)
Example #17
def _flip_and_average_on_center(array):
    """Flips and averages array on the center."""
    return (array + jnp.flip(array)) / 2
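
Usage sketch with hypothetical values: averaging an array with its flipped copy yields an array symmetric about its center.

import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 6.0])
print(_flip_and_average_on_center(x))  # [3.5 2.  3.5]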
Example #18
def tei_array(geom, basis):
    """
    Build two electron integral array from a jax.numpy array of the cartesian geometry in Bohr, 
    and a basis dictionary as defined by basis_utils.build_basis_set
    We have to loop over primitives rather than shells because JAX needs intermediates of consistent
    size in order to compile.
    """
    # Smush primitive data together into vectors
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    max_am = jnp.max(ams)
    max_am_idx = max_am * 4 + 1 
    #TODO add exception raise if angular momentum is too high
    B_vals = jnp.zeros(4*max_am+1)  
    nprim = coeffs.shape[0]
    # Obtain all possible primitive quartet index combinations 
    primitive_quartets = cartesian_product(jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim), jnp.arange(nprim))

    #print("Number of basis functions: ", nbf)
    #print("Number of primitve quartets: ", primitive_quartets.shape[0])

    #TODO Experimental: precompute quantities and lookup inside loop
    # Compute all possible Gaussian products for this basis set
    aa_plus_bb = jnp.broadcast_to(exps, (nprim,nprim)) + jnp.transpose(jnp.broadcast_to(exps, (nprim,nprim)), (1,0))
    aa_times_A = jnp.einsum('i,ij->ij', exps, geom[atoms])
    aaxA_plus_bbxB = aa_times_A[:,None,:] + aa_times_A[None,:,:]
    gaussian_products = jnp.einsum('ijk,ij->ijk', aaxA_plus_bbxB, 1/aa_plus_bb)  

    # Compute all rab2 (rcd2), every possible jnp.dot(A-B,A-B)
    natom = geom.shape[0]
    tmpA = jnp.broadcast_to(geom, (natom,natom,3))
    AminusB = (tmpA - jnp.transpose(tmpA, (1,0,2)))
    AmBdot = jnp.einsum('ijk,ijk->ij', AminusB, AminusB) # shape: (natom,natom)

    # Compute all differences between gaussian product centers with all atom centers
    tmpP = jnp.tile(gaussian_products, natom).reshape(nprim,nprim,natom,3)
    PminusA = tmpP - jnp.broadcast_to(geom, tmpP.shape)

    # Compute all powers (up to max_am) of differences between gaussian product centers and atom centers
    # Shape: (nprim, nprim, natom, 3, max_am+1). In loop index PA_pow as [p1,p2,atoms[p1],:,:]
    PminusA_pow = jnp.power(jnp.transpose(jnp.broadcast_to(PminusA, (max_am+1,nprim,nprim,natom,3)), (1,2,3,4,0)), jnp.arange(max_am+1))

    with loops.Scope() as s:
      s.G = jnp.zeros((nbf,nbf,nbf,nbf))
      s.a = 0  # center A angular momentum iterator 
      s.b = 0  # center B angular momentum iterator 
      s.c = 0  # center C angular momentum iterator 
      s.d = 0  # center D angular momentum iterator 

      # Loop over primitive quartets, compute integral, add to appropriate index in G
      for prim_quar in s.range(primitive_quartets.shape[0]):
        # Load in primitive indices, coeffs, exponents, centers, angular momentum index, and leading placement index in TEI array
        p1,p2,p3,p4 = primitive_quartets[prim_quar] 
        coef = coeffs[p1] * coeffs[p2] * coeffs[p3] * coeffs[p4]
        aa, bb, cc, dd = exps[p1], exps[p2], exps[p3], exps[p4]
        ld1, ld2, ld3, ld4 = am_leading_indices[ams[p1]],am_leading_indices[ams[p2]],am_leading_indices[ams[p3]],am_leading_indices[ams[p4]]
        idx1, idx2, idx3, idx4 = indices[p1], indices[p2], indices[p3], indices[p4]
        #A, B, C, D = geom[atoms[p1]], geom[atoms[p2]], geom[atoms[p3]], geom[atoms[p4]]

        # Compute common intermediates before looping over AM distributions.
        # Avoids redundant recomputations/reassignment for all classes other than (ss|ss).
        #AB = A - B
        #CD = C - D
        #rab2 = jnp.dot(AB,AB)
        #rcd2 = jnp.dot(CD,CD)
        #P = (aa * A + bb * B) / gamma1
        #Q = (cc * C + dd * D) / gamma2
        gamma1 = aa + bb
        gamma2 = cc + dd

        #TODO
        P = gaussian_products[p1,p2]
        Q = gaussian_products[p3,p4]
        rab2 = AmBdot[atoms[p1],atoms[p2]]
        rcd2 = AmBdot[atoms[p3],atoms[p4]]
        #PA = PminusA[p1,p2,atoms[p1]]
        #PB = PminusA[p1,p2,atoms[p2]]
        #QC = PminusA[p3,p4,atoms[p3]]
        #QD = PminusA[p3,p4,atoms[p4]]
        #TODO

        PQ = P - Q
        rpq2 = jnp.dot(PQ,PQ)
        delta = 0.25*(1/gamma1+1/gamma2)

        boys_arg = 0.25 * rpq2 / delta
        boys_eval = boys(jnp.arange(max_am_idx), boys_arg) 

        # Need all powers of Pi-Ai,Pi-Bi,Qi-Ci,Qi-Di (i=x,y,z) up to max_am and Qi-Pi up to max_am_idx
        # note: this computes unnecessary quantities for lower angular momentum,
        # but avoids repeated computation of the same quantities in loops for higher angular momentum

        #PA_pow = jnp.power(jnp.broadcast_to(P-A, (max_am+1,3)).T, jnp.arange(max_am+1))
        #PB_pow = jnp.power(jnp.broadcast_to(P-B, (max_am+1,3)).T, jnp.arange(max_am+1))
        #QC_pow = jnp.power(jnp.broadcast_to(Q-C, (max_am+1,3)).T, jnp.arange(max_am+1))
        #QD_pow = jnp.power(jnp.broadcast_to(Q-D, (max_am+1,3)).T, jnp.arange(max_am+1))

        PA_pow = PminusA_pow[p1,p2,atoms[p1],:,:]
        PB_pow = PminusA_pow[p1,p2,atoms[p2],:,:]
        QC_pow = PminusA_pow[p3,p4,atoms[p3],:,:]
        QD_pow = PminusA_pow[p3,p4,atoms[p4],:,:]

        QP_pow = jnp.power(jnp.broadcast_to(Q-P, (max_am_idx,3)).T, jnp.arange(max_am_idx))
        # Gamma powers are negative, up to -(l1+l2). 
        # Make array such that the given negative index returns the same negative power.
        g1_pow = jnp.power(4*gamma1, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1)) 
        g2_pow = jnp.power(4*gamma2, -jnp.roll(jnp.flip(jnp.arange(2*max_am+1)),1)) 
        oodelta_pow = jnp.power(1 / delta, jnp.arange(max_am_idx))  # l1 + l2 + l3 + l4 + 1

        prefactor = 34.986836655249726 / (gamma1*gamma2*jnp.sqrt(gamma1+gamma2)) \
                    * jnp.exp(-aa*bb*rab2/gamma1 + -cc*dd*rcd2/gamma2) * coef

        # TODO is there symmetry here?
        s.a = 0
        for _ in s.while_range(lambda: s.a < dims[p1]):
          s.b = 0
          for _ in s.while_range(lambda: s.b < dims[p2]):
            s.c = 0
            for _ in s.while_range(lambda: s.c < dims[p3]):
              s.d = 0
              for _ in s.while_range(lambda: s.d < dims[p4]):
                # Collect angular momentum and index in G
                la, ma, na = angular_momentum_combinations[s.a + ld1]
                lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                lc, mc, nc = angular_momentum_combinations[s.c + ld3]
                ld, md, nd = angular_momentum_combinations[s.d + ld4]
                i = idx1 + s.a
                j = idx2 + s.b
                k = idx3 + s.c
                l = idx4 + s.d
                # Compute the primitive quartet tei and add to appropriate index in G
                Bx = B_array(la,lb,lc,ld,PA_pow[0],PB_pow[0],QC_pow[0],QD_pow[0],QP_pow[0],g1_pow,g2_pow,oodelta_pow,B_vals)
                By = B_array(ma,mb,mc,md,PA_pow[1],PB_pow[1],QC_pow[1],QD_pow[1],QP_pow[1],g1_pow,g2_pow,oodelta_pow,B_vals)
                Bz = B_array(na,nb,nc,nd,PA_pow[2],PB_pow[2],QC_pow[2],QD_pow[2],QP_pow[2],g1_pow,g2_pow,oodelta_pow,B_vals)

                with loops.Scope() as S:
                  S.primitive = 0.
                  S.I = 0
                  S.J = 0
                  S.K = 0
                  for _ in S.while_range(lambda: S.I < la + lb + lc + ld + 1):
                    S.J = 0 
                    tmp = Bx[S.I] 
                    for _ in S.while_range(lambda: S.J < ma + mb + mc + md + 1):
                      S.K = 0 
                      tmp *= By[S.J] 
                      for _ in S.while_range(lambda: S.K < na + nb + nc + nd + 1):
                        tmp *= Bz[S.K] * boys_eval[S.I + S.J + S.K]
                        S.primitive += tmp
                        S.K += 1
                      S.J += 1
                    S.I += 1
                tei = prefactor * S.primitive
                s.G = jax.ops.index_add(s.G, jax.ops.index[i,j,k,l], tei) 

                s.d += 1
              s.c += 1
            s.b += 1
          s.a += 1
      return s.G
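
The jnp.flip/jnp.roll trick used for g1_pow and g2_pow above builds a power lookup table in which a negative index -n returns x**(-n) (and index 0 returns x**0). A minimal sketch with hypothetical values:

import jax.numpy as jnp

x, max_power = 2.0, 4
exponents = -jnp.roll(jnp.flip(jnp.arange(max_power + 1)), 1)  # [ 0 -4 -3 -2 -1]
table = jnp.power(x, exponents)
print(table[0], table[-1], table[-2])  # 1.0 0.5 0.25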
Example #19
    def beam_search_loop_body_fn(state):
        """Beam search loop state update function."""
        # Collect the current position slice along length to feed the fast
        # autoregressive decoder model.  Flatten the beam dimension into batch
        # dimension for feeding into the model.
        # --> [batch * beam, 1]
        flat_ids = flatten_beam_dim(
            lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index),
                              (batch_size, beam_size, 1)))
        # Flatten beam dimension into batch to be compatible with model.
        # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...}
        flat_cache = jax.tree_map(flatten_beam_dim, state.cache)

        # Call fast-decoder model on current tokens to get next-position logits.
        # --> [batch * beam, vocab]
        flat_logits, new_flat_cache = tokens_to_logits(flat_ids, flat_cache)

        # unflatten beam dimension
        # [batch * beam, vocab] --> [batch, beam, vocab]
        logits = unflatten_beam_dim(flat_logits, batch_size, beam_size)
        # Unflatten beam dimension in attention cache arrays
        # {[batch * beam, ...], ...} --> {[batch, beam, ...], ...}
        new_cache = jax.tree_map(
            lambda x: unflatten_beam_dim(x, batch_size, beam_size),
            new_flat_cache)

        # Gather log probabilities from logits
        candidate_log_probs = jax.nn.log_softmax(logits)
        # Add new logprobs to existing prefix logprobs.
        # --> [batch, beam, vocab]
        log_probs = (candidate_log_probs +
                     jnp.expand_dims(state.live_logprobs, axis=2))

        # We'll need the vocab size, gather it from the log probability dimension.
        vocab_size = log_probs.shape[2]

        # Each item in batch has beam_size * vocab_size candidate sequences.
        # For each item, get the top 2*k candidates with the highest log-
        # probabilities. We gather the top 2*K beams here so that even if the best
        # K sequences reach EOS simultaneously, we have another K sequences
        # remaining to continue the live beam search.
        beams_to_keep = 2 * beam_size
        # Flatten beam and vocab dimensions.
        flat_log_probs = log_probs.reshape(
            (batch_size, beam_size * vocab_size))
        # Gather the top 2*K scores from _all_ beams.
        # --> [batch, 2*beams], [batch, 2*beams]
        topk_log_probs, topk_indices = lax.top_k(flat_log_probs,
                                                 k=beams_to_keep)
        # Recover the beam index by floor division.
        topk_beam_indices = topk_indices // vocab_size
        # Gather 2*k top beams.
        # --> [batch, 2*beams, length]
        topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size,
                                beams_to_keep)

        # Append the most probable 2*K token IDs to the top 2*K sequences
        # Recover token id by modulo division and expand Id array for broadcasting.
        # --> [batch, 2*beams, 1]
        topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
        # Update sequences for the 2*K top-k new sequences.
        # --> [batch, 2*beams, length]
        topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids,
                                            (0, 0, state.cur_index + 1))

        # Update LIVE (in-progress) sequences:
        # Did any of these sequences reach an end marker?
        # --> [batch, 2*beams]
        newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker)
        # To prevent these newly finished sequences from being added to the LIVE
        # set of active beam search sequences, set their log probs to a very large
        # negative value.
        new_log_probs = topk_log_probs + newly_finished * NEG_INF
        # Determine the top k beam indices (from top 2*k beams) from log probs.
        # --> [batch, beams]
        _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size)
        new_topk_indices = jnp.flip(new_topk_indices, axis=1)
        # Gather the top k beams (from top 2*k beams).
        # --> [batch, beams, length], [batch, beams]
        top_alive_seq, top_alive_log_probs = gather_beams(
            [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size)

        # Determine the top k beam indices from the original set of all beams.
        # --> [batch, beams]
        top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices,
                                         batch_size, beam_size)
        # With these, gather the top k beam-associated caches.
        # --> {[batch, beams, ...], ...}
        top_alive_cache = gather_beams(new_cache, top_alive_indices,
                                       batch_size, beam_size)

        # Update FINISHED (reached end of sentence) sequences:
        # Calculate new seq scores from log probabilities.
        new_scores = topk_log_probs / brevity_penalty(alpha,
                                                      state.cur_index + 1)
        # Mask out the still unfinished sequences by adding large negative value.
        # --> [batch, 2*beams]
        new_scores += (~newly_finished) * NEG_INF

        # Combine sequences, scores, and flags along the beam dimension and compare
        # new finished sequence scores to existing finished scores and select the
        # best from the new set of beams.
        finished_seqs = jnp.concatenate(  # --> [batch, 3*beams, length]
            [state.finished_seqs, topk_seq],
            axis=1)
        finished_scores = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_scores, new_scores],
            axis=1)
        finished_flags = jnp.concatenate(  # --> [batch, 3*beams]
            [state.finished_flags, newly_finished],
            axis=1)
        # --> [batch, beams, length], [batch, beams], [batch, beams]
        top_finished_seq, top_finished_scores, top_finished_flags = (
            gather_topk_beams([finished_seqs, finished_scores, finished_flags],
                              finished_scores, batch_size, beam_size))

        return BeamState(cur_index=state.cur_index + 1,
                         live_logprobs=top_alive_log_probs,
                         finished_scores=top_finished_scores,
                         live_seqs=top_alive_seq,
                         finished_seqs=top_finished_seq,
                         finished_flags=top_finished_flags,
                         cache=top_alive_cache)
Example #20
if __name__ == '__main__':  # Test
    key = random.PRNGKey(42)
    from glm_py import GLMPy

    N = 2
    M = 100
    dh = 2
    ds = 8
    p = {'N': N, 'M': M, 'dh': dh, 'ds': ds, 'dt': 0.1, 'n': 0, 'N_lim': N, 'M_lim': M}

    w = random.normal(key, shape=(N, N)) * 0.001
    h = random.normal(key, shape=(N, dh)) * 0.001
    k = random.normal(key, shape=(N, ds)) * 0.001
    b = random.normal(key, shape=(N, 1)) * 0.001

    theta = {'h': np.flip(h, axis=1), 'w': w, 'b': b, 'k': k}
    model = GLMJax(p, theta)

    sN = 8
    data = onp.random.randn(sN, 2)  # onp.zeros((8, 50))
    stim = onp.random.randn(ds, 2)
    print(model.ll(data, stim))


    def gen_ref():
        ow = onp.asarray(w)[:sN, :sN]
        oh = onp.asarray(h)[:sN, ...]
        ok = onp.asarray(k)[:sN, ...]
        ob = onp.asarray(b)[:sN, ...]

        p = {'numNeurons': sN, 'hist_dim': dh, 'numSamples': M, 'dt': 0.1, 'stim_dim': ds}
Example #21
    def __call__(
        self,
        input_values,
        attention_mask=None,
        mask_time_indices=None,
        deterministic=True,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """

        Returns:

        Example::

            >>> from transformers import Wav2Vec2Processor, FlaxWav2Vec2Model
            >>> from datasets import load_dataset
            >>> import soundfile as sf

            >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
            >>> model = FlaxWav2Vec2Model.from_pretrained("facebook/wav2vec2-base-960h")

            >>> def map_to_array(batch):
            >>>     speech, _ = sf.read(batch["file"])
            >>>     batch["speech"] = speech
            >>>     return batch

            >>> ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
            >>> ds = ds.map(map_to_array)

            >>> input_values = processor(ds["speech"][0], return_tensors="np").input_values  # Batch size 1
            >>> hidden_states = model(input_values).last_hidden_state

        """
        extract_features = self.feature_extractor(input_values)

        if attention_mask is not None:
            # compute real output lengths according to convolution formula
            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1).astype("i4"))

            attention_mask = jnp.zeros(extract_features.shape[:2], dtype=self.dtype)

            # these two operations make sure that all values
            # before the output lengths indices are attended to
            attention_mask = jax.ops.index_update(
                attention_mask, jax.ops.index[jnp.arange(attention_mask.shape[0]), output_lengths - 1], 1
            )
            attention_mask = jnp.flip(jnp.flip(attention_mask, -1).cumsum(-1), -1).astype("bool")

        hidden_states, extract_features = self.feature_projection(extract_features, deterministic=deterministic)
        if mask_time_indices is not None:  # apply SpecAugment along time axis with given indices
            hidden_states = jnp.where(
                jnp.broadcast_to(mask_time_indices[:, :, None], hidden_states.shape),
                jnp.broadcast_to(self.masked_spec_embed[None, None, :], hidden_states.shape),
                hidden_states,
            )

        encoder_outputs = self.encoder(
            hidden_states,
            attention_mask=attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = encoder_outputs[0]

        if not return_dict:
            return (hidden_states, extract_features) + encoder_outputs[1:]

        return FlaxWav2Vec2BaseModelOutput(
            last_hidden_state=hidden_states,
            extract_features=extract_features,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
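
The flip/cumsum/flip line above turns a one-hot marker at index output_length - 1 into a boolean mask that is True at every position up to and including that index. A minimal sketch with hypothetical values:

import jax.numpy as jnp

marker = jnp.array([[0.0, 1.0, 0.0, 0.0]])  # 1 at output_length - 1 == 1
mask = jnp.flip(jnp.flip(marker, -1).cumsum(-1), -1).astype("bool")
print(mask)  # [[ True  True False False]]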