Example #1
def expand(observed_value):
    return jnp.tile(observed_value, (num_prediction_samples, 1))
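The fragment above depends on a free variable (num_prediction_samples) from its enclosing scope. As a self-contained illustration of the same jnp.tile pattern (the explicit argument here is a hypothetical rewrite, not the original API):

import jax.numpy as jnp

def expand_standalone(observed_value, num_prediction_samples):
    # (d,) -> (num_prediction_samples, d): each row is a copy of the input.
    return jnp.tile(observed_value, (num_prediction_samples, 1))

print(expand_standalone(jnp.array([1.0, 2.0, 3.0]), 5).shape)  # (5, 3)
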
    def apply(self,
              inputs,
              vocab_size,
              inputs_positions=None,
              inputs_segmentation=None,
              shared_embedding=None,
              use_bfloat16=False,
              emb_dim=512,
              num_heads=8,
              dtype=jnp.float32,
              num_layers=6,
              qkv_dim=512,
              mlp_dim=2048,
              max_len=512,
              train=True,
              dropout_rate=0.1,
              attention_dropout_rate=0.1,
              learn_pos_emb=False,
              classifier=False,
              classifier_pool='CLS',
              num_classes=10):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      vocab_size: size of the vocabulary
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.
      shared_embedding: a shared embedding layer to use.
      use_bfloat16: bool: whether use bfloat16.
      emb_dim: dimension of embedding
      num_heads: number of heads
      dtype: the dtype of the computation (default: float32)
      num_layers: number of layers
      qkv_dim: dimension of the query/key/value
      mlp_dim: dimension of the mlp on top of attention block
      max_len: maximum length.
      train: whether the model is in training mode.
      dropout_rate: dropout rate
      attention_dropout_rate: dropout rate for attention weights
      learn_pos_emb: boolean, if learn the positional embedding or use the
        sinusoidal positional embedding.
      classifier: boolean, for classification mode (output N-class logits)
      classifier_pool: str, supports "CLS", "MEAN", "MAX" pooling.
      num_classes: int, number of classification classes.

    Returns:
      output of a transformer encoder or logits if classifier is true.
    """
        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if shared_embedding is None:
            input_embed = nn.Embed.partial(
                num_embeddings=vocab_size,
                features=emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)

        if classifier and classifier_pool == 'CLS':
            cls = self.param('cls', (1, 1, emb_dim), nn.initializers.zeros)
            cls = jnp.tile(cls, [x.shape[0], 1, 1])
            x = jnp.concatenate([cls, x], axis=1)
            max_len += 1
            src_padding_mask = jnp.concatenate(
                [src_padding_mask[:, :1], src_padding_mask], axis=1)
        pe_init = nn.initializers.normal(
            stddev=0.02) if learn_pos_emb else None
        x = common_layers.AddPositionEmbs(x,
                                          inputs_positions=inputs_positions,
                                          posemb_init=pe_init,
                                          max_len=max_len,
                                          name='posembed_input')
        x = nn.dropout(x, rate=dropout_rate, deterministic=not train)

        if use_bfloat16:
            x = x.astype(jnp.bfloat16)
            dtype = jnp.bfloat16
        else:
            dtype = jnp.float32

        # Input Encoder
        for lyr in range(num_layers):
            x = LinearTransformerBlock(
                x,
                qkv_dim=qkv_dim,
                mlp_dim=mlp_dim,
                num_heads=num_heads,
                dtype=dtype,
                padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation,
                dropout_rate=dropout_rate,
                attention_dropout_rate=attention_dropout_rate,
                deterministic=not train,
                name=f'encoderblock_{lyr}')
        encoded = nn.LayerNorm(x, dtype=dtype, name='encoder_norm')

        if classifier:
            encoded = common_layers.classifier_head(
                encoded, num_classes, mlp_dim, pooling_mode=classifier_pool)
        return encoded
Example #3
import pickle
import copy

from jax import grad
import jax.numpy as jnp

with open('dataautodiff2.pkl', 'rb') as f:  # meas in jax device array format
    meas, platooninfo = pickle.load(f)

platoonobjfn_grad = grad(platoonobjfn_obj_jax)
sim_jax = copy.deepcopy(meas)
pguess_jax = jnp.array([10 * 3.3, .086 / 3.3, 1.545, 2, .175, 5.01])
args_jax = (True, 6)
curplatoon_jax = [[], 581, 611]
n = int(len(curplatoon_jax) - 1)
leadinfo_jax, folinfo_jax, rinfo_jax = makeleadfolinfo_r3(
    curplatoon_jax, platooninfo, meas)
p2 = jnp.tile(pguess_jax, n)

#testobj = platoonobjfn_obj_jax(p2,OVM,OVMadjsys,OVMadj,meas,sim_jax,platooninfo,curplatoon_jax,leadinfo_jax,folinfo_jax,rinfo_jax,*args_jax)

testgrad = platoonobjfn_grad(p2, OVM, OVMadjsys, OVMadj, meas, sim_jax,
                             platooninfo, curplatoon_jax, leadinfo_jax,
                             folinfo_jax, rinfo_jax, *args_jax)
"""
\\ TO DO \\ 
Get the gradient of platoonobjfn_obj using jax. Compare this is the gradient from finite differences, and the gradient from the adjoint method. 

Time for running finite differences is fdertime
Time for running adjoint method is dertime
Accuracy of adjoint method is acc, Relative accuracy of adjoint method is acc2 

For jax, compute the run time to get the gradient. Get the accuracy of jax, and the relative accuracy of jax. 
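A minimal sketch of the measurements the TO DO asks for, reusing the call above; it assumes the finite-difference gradient is available in a variable fder from the earlier experiments (that name, like jaxdertime, is illustrative, not from the original script):

import time

start = time.time()
testgrad = platoonobjfn_grad(p2, OVM, OVMadjsys, OVMadj, meas, sim_jax,
                             platooninfo, curplatoon_jax, leadinfo_jax,
                             folinfo_jax, rinfo_jax, *args_jax)
testgrad.block_until_ready()  # wait for async dispatch before stopping the clock
jaxdertime = time.time() - start

# Accuracy and relative accuracy of jax, defined the same way as acc/acc2
# are for the adjoint method.
jax_acc = jnp.linalg.norm(testgrad - fder)
jax_acc2 = jax_acc / jnp.linalg.norm(fder)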
Example #4
    def exec_op(self, op, input_values, deterministic, training, **_):
        """Executes an op according to the normal concrete semantics."""
        input_kwargs: Dict[str, Any] = op.input_kwargs
        op_kwargs: Dict[str, Any] = op.op_kwargs
        op_type = op.type
        if "name" not in op_kwargs:
            raise ValueError("Op kwargs must contain a name.")
        op_name = op_kwargs["name"]

        if op_type == OpType.NONE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [lax.stop_gradient(input_value)]

        elif op_type == OpType.IDENTITY:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert len(op_kwargs) == 1
            output_values = [input_value]

        # nn.linear

        elif op_type == OpType.DENSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.Dense(**op_kwargs)(input_value)]

        elif op_type == OpType.DENSE_GENERAL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            assert 2 <= len(op_kwargs) <= 7
            output_values = [nn.DenseGeneral(**op_kwargs)(input_value)]

        elif op_type == OpType.CONV:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs

            ks = op_kwargs["kernel_size"]
            if isinstance(ks, int):
                op_kwargs["kernel_size"] = (ks, ) * (input_value.ndim - 2)

            output_values = [nn.Conv(**op_kwargs)(input_value)]

        # others

        elif op_type == OpType.MUL:
            assert len(input_values) == 2
            assert not input_kwargs
            assert len(op_kwargs) == 1  # name
            output_values = [input_values[0] * input_values[1]]

        elif op_type in [OpType.ADD, OpType.STOCH_DEPTH]:
            assert len(op_kwargs) == 1  # name

            input_value = input_values[0]
            if "layer_drop_rate" in input_kwargs:
                assert len(input_kwargs) == 1
                survival_rate = 1 - input_kwargs["layer_drop_rate"]
                if survival_rate == 1.0 or deterministic:
                    pass
                else:
                    # Reuse dropout's rng stream.
                    rng = self.make_rng("dropout")
                    mask_shape = [input_value.shape[0]
                                  ] + [1] * (input_value.ndim - 1)
                    mask = random.bernoulli(rng,
                                            p=survival_rate,
                                            shape=mask_shape)
                    mask = jnp.tile(mask, [1] + list(input_value.shape[1:]))
                    input_value = lax.select(mask, input_value / survival_rate,
                                             jnp.zeros_like(input_value))
            else:
                assert not input_kwargs
                assert op_type == OpType.ADD

            if op_type == OpType.ADD:
                assert len(input_values) == 2
                output_values = [input_value + input_values[1]]
            else:
                assert len(input_values) == 1
                output_values = [input_value]

        elif op_type == OpType.SCALAR_MUL:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            if "const" in input_kwargs:
                c = input_kwargs["const"]
            else:
                c = 1 / jnp.sqrt(input_values[0].shape[-1])
            output_values = [input_values[0] * c]

        elif op_type == OpType.SCALAR_ADD:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            assert len(op_kwargs) == 1  # name
            assert "const" in input_kwargs
            c = input_kwargs["const"]
            output_values = [input_value + c]

        elif op_type == OpType.DOT_GENERAL:
            assert len(input_values) == 2
            assert 0 < len(input_kwargs) <= 3
            assert len(op_kwargs) == 1  # name
            output_values = [
                lax.dot_general(input_values[0], input_values[1],
                                **input_kwargs)
            ]

        elif op_type == OpType.EINSUM:
            assert len(input_values) == 2
            assert len(input_kwargs) == 1
            assert "sum" in input_kwargs
            output_values = [
                jnp.einsum(input_kwargs["sum"], input_values[0],
                           input_values[1])
            ]

        # nn.attention

        elif op_type == OpType.SELF_ATTENTION:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [
                nn.SelfAttention(**op_kwargs,
                                 deterministic=deterministic)(input_value)
            ]

        # nn.activation

        elif op_type in [
                OpType.RELU, OpType.GELU, OpType.SWISH, OpType.SIGMOID
        ]:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            fn = {
                OpType.RELU: nn.relu,
                OpType.GELU: nn.gelu,
                OpType.SWISH: nn.swish,
                OpType.SIGMOID: nn.sigmoid
            }[op_type]
            output_values = [fn(input_value)]

        elif op_type == OpType.SOFTMAX:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [nn.softmax(input_value, **input_kwargs)]

        # nn.normalization

        elif op_type == OpType.BATCH_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            if "use_running_average" not in input_kwargs:
                add_kwargs = {"use_running_average": not training}
            else:
                add_kwargs = {}
            output_values = [
                nn.BatchNorm(**op_kwargs)(input_value, **input_kwargs,
                                          **add_kwargs)
            ]

        elif op_type == OpType.LAYER_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.LayerNorm(**op_kwargs)(input_value)]

        elif op_type == OpType.GROUP_NORM:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            output_values = [nn.GroupNorm(**op_kwargs)(input_value)]

        # reshape operators

        elif op_type == OpType.RESHAPE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert 0 < len(input_kwargs) < 3
            new_shape = input_kwargs.pop("new_shape")
            if new_shape[0] == "B":
                new_shape = (input_value.shape[0], ) + new_shape[1:]
            output_values = [
                jnp.reshape(input_value, new_shape, **input_kwargs)
            ]

        elif op_type == OpType.FLATTEN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert not input_kwargs
            new_shape = (input_value.shape[0], -1)
            output_values = [jnp.reshape(input_value, new_shape)]

        elif op_type == OpType.TRANSPOSE:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) == 1
            assert len(op_kwargs) == 1  # name
            output_values = [jnp.transpose(input_value, **input_kwargs)]

        # nn.stochastic

        elif op_type == OpType.DROPOUT:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert len(input_kwargs) <= 1
            output_values = [
                nn.Dropout(**op_kwargs)(input_value,
                                        deterministic=deterministic,
                                        **input_kwargs)
            ]

        # nn.pooling

        elif op_type == OpType.AVG_POOL or op_type == OpType.MAX_POOL:
            op_fn = nn.avg_pool if op_type == OpType.AVG_POOL else nn.max_pool
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs

            ws = input_kwargs["window_shape"]
            if isinstance(ws, int):
                ws = [ws] * (input_value.ndim - 2)
            new_ws = []
            for window_dim_shape, dim_shape in zip(ws, input_value.shape[1:]):
                if window_dim_shape == 0:
                    new_ws.append(dim_shape)
                else:
                    new_ws.append(window_dim_shape)
            input_kwargs["window_shape"] = tuple(new_ws)

            if "strides" in input_kwargs:
                s = input_kwargs["strides"]
                if isinstance(s, int):
                    input_kwargs["strides"] = (s, ) * (input_value.ndim - 2)

            output_values = [op_fn(input_value, **input_kwargs)]

        elif op_type == OpType.MEAN:
            assert len(input_values) == 1
            input_value = input_values[0]
            assert input_kwargs
            output_values = [jnp.mean(input_value, **input_kwargs)]

        # new param

        elif op_type == OpType.PARAM:
            assert not input_values
            assert 0 < len(input_kwargs) <= 2
            init_fn = input_kwargs.pop("init_fn")

            init_fn_with_kwargs = functools.partial(init_fn, **input_kwargs)
            output_values = [self.param(op_name, init_fn_with_kwargs)]

        else:
            raise ValueError(f"op_type {op_type} not supported...")

        return output_values
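The ADD/STOCH_DEPTH branch above draws one Bernoulli keep/drop decision per batch element, tiles that mask over the remaining axes, and rescales survivors by 1/survival_rate. A stand-alone sketch of the same masking step, with an explicit PRNG key instead of the module's dropout stream:

import jax.numpy as jnp
from jax import lax, random

def stoch_depth(rng, x, layer_drop_rate):
    # Drop whole examples with probability layer_drop_rate; rescale the rest.
    survival_rate = 1.0 - layer_drop_rate
    mask_shape = [x.shape[0]] + [1] * (x.ndim - 1)
    mask = random.bernoulli(rng, p=survival_rate, shape=mask_shape)
    mask = jnp.tile(mask, [1] + list(x.shape[1:]))
    return lax.select(mask, x / survival_rate, jnp.zeros_like(x))

x = jnp.ones((4, 3, 8))
print(stoch_depth(random.PRNGKey(0), x, 0.5).shape)  # (4, 3, 8)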
Example #5
def cartesian_prod(x, y):
    return jnp.stack([jnp.tile(x, len(y)), jnp.repeat(y, len(x))]).T

def naive_upsample_2d(x, factor=2):
    _N, H, W, C = x.shape
    x = jnp.reshape(x, [-1, H, 1, W, 1, C])
    x = jnp.tile(x, [1, 1, factor, 1, factor, 1])
    return jnp.reshape(x, [-1, H * factor, W * factor, C])
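A quick shape check for naive_upsample_2d, which implements nearest-neighbor upsampling by tiling each pixel into a factor x factor block (it reuses the function defined just above):

x = jnp.arange(4, dtype=jnp.float32).reshape(1, 2, 2, 1)
y = naive_upsample_2d(x)
print(y.shape)        # (1, 4, 4, 1)
print(y[0, :, :, 0])  # each original pixel becomes a 2x2 block
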
  def __call__(
      self,
      encoding: Array,
      mention_batch_positions: Array,
      mention_start_positions: Array,
      mention_end_positions: Array,
      mention_mask: Array,
      memory_keys: Array,
      memory_values: Array,
      memory_mask: Array,
      memory_entity_ids: Array,
      deterministic: bool,
  ) -> Tuple[Array, Dict[str, Array], Dict[str, Array]]:
    """Perform attention update over memory table.

    Args:
      encoding: [batch_size, n_tokens, hidden_size] input representation.
      mention_batch_positions: [n_mentions] mention sample position in batch.
      mention_start_positions: [n_mentions] mention start position in input.
      mention_end_positions: [n_mentions] mention end position in input.
      mention_mask: [n_mentions] attention mask to prevent updates from padding.
      memory_keys: [memory_size, memory_key_dim] mention memory keys.
      memory_values: [memory_size, memory_value_dim] mention memory values.
      memory_mask: [memory_size] mask for valid mentions in memory.
      memory_entity_ids: [memory_size] mention memory entity ids.
      deterministic: don't apply dropout if true.

    Returns:
      Updated input, loss and logging helper dicts.
    """
    loss_helpers, logging_helpers = {}, {}

    # We generate mention representations to use as queries for similarity
    # search by concatenating start and end tokens for each mention and
    # projecting the concatenation with a dense layer.
    mention_start_encodings = jut.matmul_2d_index_select(
        encoding, (mention_batch_positions, mention_start_positions))
    mention_end_encodings = jut.matmul_2d_index_select(
        encoding, (mention_batch_positions, mention_end_positions))

    queries = self.query_projector(
        jnp.concatenate((mention_start_encodings, mention_end_encodings),
                        axis=-1))

    n_queries = queries.shape[0]

    # For attention over entire memory table, we do not want to duplicate the
    # entire memory table for each query. Instead, we perform an
    # attention-weighted sum to produce a single value. We then feed this value
    # to the update layer as a set of retrieved values of size 1, with score 1.
    if self.k_top is None:
      loss_helpers['top_entity_ids'] = jnp.tile(memory_entity_ids,
                                                (n_queries, 1))
      scores = jnp.einsum('qd,md->qm', queries, memory_keys)
      scores = scores - (1 - memory_mask) * _LARGE_NUMBER
      true_attention_weights = nn.softmax(scores, axis=-1)
      loss_helpers['memory_attention_weights'] = true_attention_weights
      top_values = jnp.einsum('qm,md->qd', true_attention_weights,
                              memory_values)
      # Expand value as though it were a set of retrieved values for each query.
      # Shape (n_queries, 1, memory_value_dim)
      top_values = jnp.expand_dims(top_values, axis=1)
      # Generate pseudo-score (n_queries, 1).
      attention_weights = jnp.ones_like(top_values, shape=(n_queries, 1))
    else:
      # Reshape memory keys for use in approximate top-k similarity layer.
      memory_keys = memory_keys.reshape(self.rows, -1, self.memory_key_dim)
      # We generate a version of the queries with stop gradient to use as input
      # to the topk similarity layer. We actually do want gradient to flow to
      # the queries, but backward differentiation over the topk layer yields
      # inefficient HLO ops. Instead we use queries with gradient to recompute
      # attention scores later.
      queries_sg = jax.lax.stop_gradient(queries)

      # Perform top-k similarity search over queries, yielding
      #   top_values: (queries, k_top, memory_dim)
      #   top_ids: (queries, k_top)
      top_keys, _, top_ids = self.topk_similarity(queries_sg, memory_keys)

      top_ids = top_ids.reshape(n_queries, self.k_top)
      top_values = memory_values[top_ids]
      loss_helpers['top_entity_ids'] = memory_entity_ids[top_ids]

      # We re-compute top scores using the queries with gradient (wg) to make
      # sure the query projector and the rest of the model receives gradient.
      top_scores_wg = jnp.einsum('qd,qkd->qk', queries, top_keys)
      top_mask = memory_mask[top_ids]
      top_scores_wg = top_scores_wg - (1 - top_mask) * _LARGE_NUMBER

      # We perform dot product attention using retrieved memory vectors as key,
      # dense projection of retrieved vectors as value and value and mention
      # representations as query.
      attention_weights = nn.softmax(top_scores_wg, axis=-1)
      loss_helpers['memory_attention_weights'] = attention_weights
    encoding = self.update_layer(
        encoded_input=encoding,
        retrieval_values=top_values,
        retrieval_scores=attention_weights,
        mention_batch_positions=mention_batch_positions,
        mention_start_positions=mention_start_positions,
        mention_end_positions=mention_end_positions,
        mention_mask=mention_mask,
        deterministic=deterministic,
    )

    return encoding, loss_helpers, logging_helpers
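In the k_top is None branch above, attention over the full memory table reduces to masked dot-product attention. A self-contained sketch of that scoring pattern (the value of _LARGE_NUMBER here is an illustrative stand-in for the module's constant):

import jax
import jax.numpy as jnp

_LARGE_NUMBER = 1e10  # stand-in for the module's constant

queries = jnp.ones((2, 4))          # (n_queries, memory_key_dim)
memory_keys = jnp.ones((8, 4))      # (memory_size, memory_key_dim)
memory_values = jnp.ones((8, 6))    # (memory_size, memory_value_dim)
memory_mask = jnp.array([1., 1., 1., 1., 1., 1., 0., 0.])  # last two slots are padding

scores = jnp.einsum('qd,md->qm', queries, memory_keys)
scores = scores - (1 - memory_mask) * _LARGE_NUMBER  # mask out invalid slots
weights = jax.nn.softmax(scores, axis=-1)            # padding gets ~0 weight
top_values = jnp.einsum('qm,md->qd', weights, memory_values)
print(top_values.shape)  # (2, 6)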
Example #8
def generate_triplets(key,
                      inputs,
                      n_inliers,
                      n_outliers,
                      n_random,
                      weight_temp=0.5,
                      distance='euclidean',
                      verbose=False):
    """Generate triplets.

  Args:
    key: Random key.
    inputs: Input points.
    n_inliers: Number of inliers.
    n_outliers: Number of outliers.
    n_random: Number of random triplets per point.
    weight_temp: Temperature of the log transformation on the weights.
    distance: Distance type.
    verbose: Whether to print progress.

  Returns:
    triplets and weights
  """
    n_points = inputs.shape[0]
    n_extra = min(n_inliers + 50, n_points)
    index = pynndescent.NNDescent(inputs, metric=distance)
    index.prepare()
    neighbors = index.query(inputs, n_extra)[0]
    neighbors = np.concatenate(
        (np.arange(n_points).reshape([-1, 1]), neighbors), 1)
    if verbose:
        logging.info('found nearest neighbors')
    distance_fn = get_distance_fn(distance)
    # compute scaled neighbors and the scale parameter
    knn_distances, neighbors, sig = find_scaled_neighbors(
        inputs, neighbors, distance_fn)
    neighbors = neighbors[:, :n_inliers + 1]
    knn_distances = knn_distances[:, :n_inliers + 1]
    key, use_key = random.split(key)
    triplets = sample_knn_triplets(use_key, neighbors, n_inliers, n_outliers)
    weights = find_triplet_weights(inputs,
                                   triplets,
                                   neighbors[:, 1:n_inliers + 1],
                                   distance_fn,
                                   sig,
                                   distances=knn_distances[:, 1:n_inliers + 1])
    flip = weights < 0
    anchors, pairs = triplets[:, 0].reshape([-1, 1]), triplets[:, 1:]
    pairs = jnp.where(jnp.tile(flip.reshape([-1, 1]), [1, 2]),
                      jnp.fliplr(pairs), pairs)
    triplets = jnp.concatenate((anchors, pairs), 1)

    if n_random > 0:
        key, use_key = random.split(key)
        rand_triplets, rand_weights = sample_random_triplets(
            use_key, inputs, n_random, distance_fn, sig)

        triplets = jnp.concatenate((triplets, rand_triplets), 0)
        weights = jnp.concatenate((weights, 0.1 * rand_weights))

    weights -= jnp.min(weights)
    weights = tempered_log(1. + weights, weight_temp)
    return triplets, weights
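The flip step above reorders any triplet whose weight came out negative by swapping its two non-anchor entries; a tiny demonstration of that jnp.where/jnp.tile pattern:

import jax.numpy as jnp

triplets = jnp.array([[0, 1, 2],
                      [3, 4, 5]])
weights = jnp.array([0.7, -0.2])  # the second triplet has a negative weight

flip = weights < 0
anchors, pairs = triplets[:, 0].reshape([-1, 1]), triplets[:, 1:]
pairs = jnp.where(jnp.tile(flip.reshape([-1, 1]), [1, 2]),
                  jnp.fliplr(pairs), pairs)
print(jnp.concatenate((anchors, pairs), 1))  # [[0 1 2] [3 5 4]]
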
    def __call__(self,
                 inputs,
                 decoder_mask=None,
                 encoder_decoder_mask=None,
                 inputs_kv=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      inputs: input data for decoder.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.
      inputs_kv: key/value inputs for cross-attention; used when is_self_att
        is False.

    Returns:
      output after transformer encoder-decoder block.
    """
        cfg = self.config

        # Decoder block.
        assert inputs.ndim == 3
        input_shape = list(inputs.shape)
        input_shape[-1] *= self.num_repeat

        if cfg.use_layernorm:
            x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        else:
            x = inputs
        if self.is_self_att:
            x = MultiDimSelfAttention(num_heads=cfg.num_heads,
                                      num_repeat=self.num_repeat,
                                      out_features=self.out_features,
                                      dtype=cfg.dtype,
                                      qkv_features=cfg.qkv_dim,
                                      kernel_init=cfg.kernel_init,
                                      bias_init=cfg.bias_init,
                                      use_bias=False,
                                      broadcast_dropout=False,
                                      dropout_rate=cfg.attention_dropout_rate,
                                      deterministic=cfg.deterministic,
                                      decode=cfg.decode)(x, decoder_mask)
        else:
            if cfg.use_layernorm:
                x_kv = nn.LayerNorm(dtype=cfg.dtype)(inputs_kv)
            else:
                x_kv = inputs_kv
            x = MultiDimMultiHeadDotProductAttention(
                num_heads=cfg.num_heads,
                num_repeat=self.num_repeat,
                out_features=self.out_features,
                dtype=cfg.dtype,
                qkv_features=cfg.qkv_dim,
                kernel_init=cfg.kernel_init,
                bias_init=cfg.bias_init,
                use_bias=False,
                broadcast_dropout=False,
                dropout_rate=cfg.attention_dropout_rate,
                deterministic=cfg.deterministic,
                decode=cfg.decode)(x, x_kv, mask=decoder_mask)

        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)
        inputs = jnp.expand_dims(inputs, -2)
        inputs = jnp.tile(inputs, self.to_tile_shape)
        x = x + inputs

        # MLP block.
        if cfg.use_layernorm:
            z = nn.LayerNorm(dtype=cfg.dtype)(x)
        else:
            z = x
        z = MlpBlock(config=cfg)(z)

        return jnp.reshape(x + z, input_shape)
Example #10
def tmp_potential(geom, basis, charges):
    """
    Build the one-electron potential integral array.
    """
    coeffs, exps, atoms, ams, indices, dims = flatten_basis_data(basis)
    nbf = get_nbf(basis)
    nprim = coeffs.shape[0]
    max_am = np.max(ams)
    A_vals = np.zeros(2 * max_am + 1)

    # Save various AM distributions for indexing
    # Obtain all possible primitive duet index combinations
    primitive_duets = cartesian_product(np.arange(nprim), np.arange(nprim))

    with loops.Scope() as s:
        s.V = np.zeros((nbf, nbf))
        s.a = 0  # center A angular momentum iterator
        s.b = 0  # center B angular momentum iterator

        for prim_duet in s.range(primitive_duets.shape[0]):
            p1, p2 = primitive_duets[prim_duet]
            coef = coeffs[p1] * coeffs[p2]
            aa, bb = exps[p1], exps[p2]
            atom1, atom2 = atoms[p1], atoms[p2]
            am1, am2 = ams[p1], ams[p2]
            A, B = geom[atom1], geom[atom2]
            ld1, ld2 = am_leading_indices[am1], am_leading_indices[am2]

            gamma = aa + bb
            prefactor = np.exp(-aa * bb * np.dot(A - B, A - B) / gamma)
            P = (aa * A + bb * B) / gamma
            # Maximum angular momentum: hard coded
            # Precompute all powers up to 2+max_am of Pi-Ai, Pi-Bi.
            # We need 2+max_am since kinetic requires incrementing angular momentum by +2
            PA_pow = np.power(
                np.broadcast_to(P - A, (max_am + 3, 3)).T,
                np.arange(max_am + 3))
            PB_pow = np.power(
                np.broadcast_to(P - B, (max_am + 3, 3)).T,
                np.arange(max_am + 3))

            # For potential integrals, we need the difference between
            # the gaussian product center P and ALL atoms in the molecule,
            # and then take all possible powers up to 2*max_am.
            # We pre-collect this into a 3d array, and then just pull out what we need via indexing in the loops, so they need not be recomputed.
            # The resulting array has dimensions (atom, cartesian component, power) so index (0, 1, 3) would return (Py - atom0_y)^3
            P_minus_geom = np.broadcast_to(P, geom.shape) - geom
            Pgeom_pow = np.power(
                np.transpose(
                    np.broadcast_to(
                        P_minus_geom,
                        (2 * max_am + 1, geom.shape[0], geom.shape[1])),
                    (1, 2, 0)), np.arange(2 * max_am + 1))
            # All possible np.dot(P-atom,P-atom)
            rcp2 = np.einsum('ij,ij->i', P_minus_geom, P_minus_geom)
            # All needed (and unneeded, for am < max_am) boys function evaluations
            boys_arg = np.broadcast_to(rcp2 * gamma,
                                       (2 * max_am + 1, geom.shape[0]))
            boys_nu = np.tile(np.arange(2 * max_am + 1), (geom.shape[0], 1)).T
            boys_eval = boys(boys_nu, boys_arg)

            s.a = 0
            for _ in s.while_range(lambda: s.a < dims[p1]):
                s.b = 0
                for _ in s.while_range(lambda: s.b < dims[p2]):
                    # Gather angular momentum and index
                    la, ma, na = angular_momentum_combinations[s.a + ld1]
                    lb, mb, nb = angular_momentum_combinations[s.b + ld2]
                    # To only create unique indices, need to have separate indices arrays for i and j.
                    i = indices[p1] + s.a
                    j = indices[p2] + s.b
                    # Compute one electron integrals and add to appropriate index
                    potential_int = potential(la, ma, na, lb, mb, nb, aa, bb,
                                              PA_pow, PB_pow, Pgeom_pow,
                                              boys_eval, prefactor, charges,
                                              A_vals) * coef
                    s.V = jax.ops.index_add(s.V, jax.ops.index[i, j],
                                            potential_int)

                    s.b += 1
                s.a += 1
        return s.V
    def __call__(self, inputs_q, inputs_kv, mask=None, deterministic=None):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    Args:
      inputs_q: input queries of shape
        `[batch_sizes..., length, features]`.
      inputs_kv: key/values of shape
        `[batch_sizes..., length, features]`.
      mask: attention mask of shape
        `[batch_sizes..., num_heads, query_length, key/value_length]`.
      deterministic: if false, the attention weight is masked randomly
        using dropout, whereas if true, the attention weights
        are deterministic.

    Returns:
      output of shape `[batch_sizes..., length, features]`.
    """
        assert inputs_q.ndim == 3 and inputs_kv.ndim == 3
        if self.dropout_rate > 0.:  # Require `deterministic` only if using dropout.
            deterministic = merge_param('deterministic', self.deterministic,
                                        deterministic)
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        assert qkv_features % self.num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // self.num_heads

        dense = partial(DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_dim),
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init,
                        use_bias=self.use_bias,
                        precision=self.precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]
        query, key, value = (dense(dtype=self.dtype,
                                   name='query',
                                   features=(self.num_repeat, self.num_heads,
                                             head_dim))(inputs_q),
                             dense(dtype=self.dtype, name='key')(inputs_kv),
                             dense(dtype=self.dtype, name='value')(inputs_kv))
        key = jnp.expand_dims(key, -3)
        value = jnp.expand_dims(value, -3)
        key = jnp.tile(key, self.to_tile_shape)
        value = jnp.tile(value, self.to_tile_shape)
        query = jnp.swapaxes(query, -3, -4)
        key = jnp.swapaxes(key, -3, -4)
        value = jnp.swapaxes(value, -3, -4)
        # query shape: (batch_size, num_repeat, query_seq_len, num_head, emb_dim)
        # kv shape: (batch_size, num_repeat, kv_seq_len, num_head, emb_dim)

        # Convert the boolean attention mask to an attention bias.
        if mask is not None:
            # attention mask in the form of attention bias
            attention_bias = lax.select(
                mask > 0,
                jnp.full(mask.shape, 0.).astype(self.dtype),
                jnp.full(mask.shape, -1e10).astype(self.dtype))
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout_rate > 0.:
            dropout_rng = self.make_rng('dropout')

        # apply attention
        x = self.attention_fn(query,
                              key,
                              value,
                              bias=attention_bias,
                              dropout_rng=dropout_rng,
                              dropout_rate=self.dropout_rate,
                              broadcast_dropout=self.broadcast_dropout,
                              deterministic=deterministic,
                              dtype=self.dtype,
                              precision=self.precision)  # pytype: disable=wrong-keyword-args
        # back to the original inputs dimensions
        out = DenseGeneral(features=features,
                           axis=(-2, -1),
                           kernel_init=self.kernel_init,
                           bias_init=self.bias_init,
                           use_bias=self.use_bias,
                           dtype=self.dtype,
                           precision=self.precision,
                           name='out')(x)

        # Swap out from (batch_size, num_repeat, seq_len, emb_dim) to
        # (batch_size, seq_len, num_repeat, emb_dim).
        out = jnp.swapaxes(out, -2, -3)
        return out
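The expand_dims / tile / swapaxes sequence above shares a single set of keys and values across all num_repeat query groups. A shape-only sketch, assuming to_tile_shape is (1, 1, num_repeat, 1, 1) (the attribute's actual value is not shown in this snippet):

import jax.numpy as jnp

batch, kv_len, num_heads, head_dim, num_repeat = 2, 7, 4, 16, 3
key = jnp.zeros((batch, kv_len, num_heads, head_dim))

to_tile_shape = (1, 1, num_repeat, 1, 1)  # assumed, see note above
key = jnp.expand_dims(key, -3)            # (2, 7, 1, 4, 16)
key = jnp.tile(key, to_tile_shape)        # (2, 7, 3, 4, 16)
key = jnp.swapaxes(key, -3, -4)           # (2, 3, 7, 4, 16)
print(key.shape)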
Example #12
    def predict(self, params, logits, context, target=None):
        context = jnp.expand_dims(context, axis=(1, 2, 3))
        context_bias = params.get('context_bias', 0.0)
        context_index = (params['context_maps'] *
                         context).sum(axis=-1) > context_bias

        context_map_values = jnp.asarray(
            [[[[1 << n for n in range(self.context_map_size)]]]])
        context_index = jnp.where(context_index, context_map_values, 0)
        context_index = context_index.sum(axis=-1, keepdims=True)

        batch_size = logits.shape[0]
        class_neuron_index = jnp.asarray([[[[c, n] for n in range(self.size)]
                                           for c in range(self.num_classes)]])
        class_neuron_index = jnp.tile(class_neuron_index,
                                      reps=(batch_size, 1, 1, 1))
        context_index = jnp.concatenate([class_neuron_index, context_index],
                                        axis=-1)

        dims = lax.GatherDimensionNumbers(offset_dims=(3, ),
                                          collapsed_slice_dims=(0, 1, 2),
                                          start_index_map=(0, 1, 2))
        weights = lax.gather(operand=params['weights'],
                             start_indices=context_index,
                             dimension_numbers=dims,
                             slice_sizes=(1, 1, 1,
                                          self.input_size + int(self.bias)))

        if self.bias:
            bias = jnp.tile(params['bias'], reps=(batch_size, 1, 1))
            logits = jnp.concatenate([logits, bias], axis=-1)
        logits = jnp.expand_dims(logits, axis=-1)

        output_logits = jnp.matmul(weights, logits)
        output_logits = jnp.clip(output_logits,
                                 a_min=jsp.special.logit(self.pred_clipping),
                                 a_max=jsp.special.logit(1.0 -
                                                         self.pred_clipping))

        if target is None:
            return jnp.squeeze(output_logits, axis=-1)

        else:
            logits = jnp.expand_dims(jnp.squeeze(logits, axis=-1), axis=-2)
            output_preds = jnn.sigmoid(output_logits)
            target = jnp.expand_dims(jnp.expand_dims(target, axis=-1), axis=-1)
            params['lr_step'], learning_rate = self.learning_rate.value(
                params['lr_step'])
            delta = learning_rate * (target - output_preds) * logits

            dims = lax.ScatterDimensionNumbers(
                update_window_dims=(3, ),
                inserted_window_dims=(0, 1, 2),
                scatter_dims_to_operand_dims=(0, 1, 2))

            if self.weight_clipping is None:
                params['weights'] = lax.scatter_add(
                    operand=params['weights'],
                    scatter_indices=context_index,
                    updates=delta,
                    dimension_numbers=dims)
            else:
                weights = jnp.clip(weights + delta,
                                   a_min=-self.weight_clipping,
                                   a_max=self.weight_clipping)
                params['weights'] = lax.scatter(operand=params['weights'],
                                                scatter_indices=context_index,
                                                updates=weights,
                                                dimension_numbers=dims)

            return params, jnp.squeeze(output_logits, axis=-1)
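The context_index computation above packs context_map_size boolean half-space tests into one integer per (class, neuron) cell by weighting bit n with 1 << n; a minimal sketch of that packing:

import jax.numpy as jnp

context_map_size = 4
bits = jnp.array([True, False, True, True])  # outcomes of 4 half-space tests
place_values = jnp.asarray([1 << n for n in range(context_map_size)])  # [1 2 4 8]
context_index = jnp.where(bits, place_values, 0).sum()
print(context_index)  # 1 + 4 + 8 = 13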
Example #13
File: tile.py Project: gglin001/onnx-jax
def onnx_tile(a, b):
    return jnp.tile(a, b)
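ONNX's Tile operator has the same semantics as numpy.tile (repeat the tensor along each axis by the given counts), which is why the adapter is a one-liner:

import jax.numpy as jnp

a = jnp.array([[1, 2],
               [3, 4]])
print(onnx_tile(a, (2, 3)).shape)  # (4, 6): 2 copies along rows, 3 along columns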
Example #14
    def __call__(
            self,
            pose_coeffs,
            betas=np.zeros(1),
    ):
        batch_size = pose_coeffs.shape[0]
        # Get axis angle from PCA components and coefficients
        # Remove global rot coeffs
        hand_pose_coeffs = pose_coeffs[:, self.rot:self.rot + self.ncomps]
        full_hand_pose = hand_pose_coeffs

        # Concatenate back global rot
        full_pose = np.concatenate(
            (pose_coeffs[:, :self.rot], self.hands_mean + full_hand_pose), 1)
        # compute rotation matrices from axis-angle while skipping global rotation
        pose_map, rot_map = self._posemap_axisang(full_pose)
        root_rot = rot_map[:, :9].reshape(batch_size, 3, 3)
        rot_map = rot_map[:, 9:]
        pose_map = pose_map[:, 9:]

        # Full axis angle representation with root joint
        v_shaped = np.matmul(self.shapedirs, betas.transpose(
            (1, 0))).transpose((2, 0, 1)) + self.v_template
        j = np.matmul(self.J_regressor, v_shaped)
        # th_pose_map should have shape 20x135
        v_posed = v_shaped + np.matmul(
            self.posedirs,
            pose_map.transpose((1, 0))[np.newaxis, ...]).transpose((2, 0, 1))
        # Final T pose with transformation done!

        # Global rigid transformation

        root_j = j[:, 0, :].reshape(batch_size, 3, 1)
        root_trans = self._with_zeros(np.concatenate((root_rot, root_j), 2))

        all_rots = rot_map.reshape(rot_map.shape[0], 15, 3, 3)
        lev1_idxs = [1, 4, 7, 10, 13]
        lev2_idxs = [2, 5, 8, 11, 14]
        lev3_idxs = [3, 6, 9, 12, 15]
        lev1_rots = all_rots[:, [idx - 1 for idx in lev1_idxs]]
        lev2_rots = all_rots[:, [idx - 1 for idx in lev2_idxs]]
        lev3_rots = all_rots[:, [idx - 1 for idx in lev3_idxs]]
        lev1_j = j[:, lev1_idxs]
        lev2_j = j[:, lev2_idxs]
        lev3_j = j[:, lev3_idxs]

        # From base to tips
        # Get lev1 results
        all_transforms = [root_trans[:, np.newaxis, ...]]
        lev1_j_rel = lev1_j - root_j.transpose((0, 2, 1))
        lev1_rel_transform_flt = self._with_zeros(
            np.concatenate((lev1_rots, lev1_j_rel[..., np.newaxis]),
                           3).reshape(-1, 3, 4))
        root_trans_flt = np.tile(root_trans[:, np.newaxis, ...],
                                 (1, 5, 1, 1)).reshape(root_trans.shape[0] * 5,
                                                       4, 4)
        lev1_flt = np.matmul(root_trans_flt, lev1_rel_transform_flt)
        all_transforms.append(lev1_flt.reshape(all_rots.shape[0], 5, 4, 4))

        # Get lev2 results
        lev2_j_rel = lev2_j - lev1_j
        lev2_rel_transform_flt = self._with_zeros(
            np.concatenate((lev2_rots, lev2_j_rel[..., np.newaxis]),
                           3).reshape(-1, 3, 4))
        lev2_flt = np.matmul(lev1_flt, lev2_rel_transform_flt)
        all_transforms.append(lev2_flt.reshape(all_rots.shape[0], 5, 4, 4))

        # Get lev3 results
        lev3_j_rel = lev3_j - lev2_j
        lev3_rel_transform_flt = self._with_zeros(
            np.concatenate((lev3_rots, lev3_j_rel[..., np.newaxis]),
                           3).reshape(-1, 3, 4))
        lev3_flt = np.matmul(lev2_flt, lev3_rel_transform_flt)
        all_transforms.append(lev3_flt.reshape(all_rots.shape[0], 5, 4, 4))

        reorder_idxs = [0, 1, 6, 11, 2, 7, 12, 3, 8, 13, 4, 9, 14, 5, 10, 15]
        results = np.concatenate(all_transforms, 1)[:, reorder_idxs]
        results_global = results

        joint_js = np.concatenate((j, np.zeros((j.shape[0], 16, 1))), 2)

        tmp2 = np.matmul(results, joint_js[..., np.newaxis])
        results2 = (results - np.concatenate((np.zeros(
            (*tmp2.shape[:2], 4, 3)), tmp2), 3)).transpose((0, 2, 3, 1))

        T = np.matmul(results2, self.weights.transpose((1, 0)))

        rest_shape_h = np.concatenate((v_posed.transpose(
            (0, 2, 1)), np.ones((batch_size, 1, v_posed.shape[1]))), 1)

        verts = (T * rest_shape_h[:, np.newaxis, ...]).sum(2).transpose(
            (0, 2, 1))
        verts = verts[:, :, :3]
        jtr = results_global[:, :, :3, 3]
        # In addition to MANO reference joints we sample vertices on each finger
        # to serve as finger tips
        tips = verts[:, [745, 317, 444, 556, 673]]
        jtr = np.concatenate((jtr, tips), 1)

        # Reorder joints to match visualization utilities
        jtr = jtr[:, [
            0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8,
            9, 20
        ]]

        center_joint = jtr[:, self.center_idx][:, np.newaxis, ...]
        jtr = jtr - center_joint
        verts = verts - center_joint

        return verts, jtr, full_pose
Example #15
    def __call__(self, x, num_quantiles, rng):

        if self.net_conf == 'minatar':
            x = x.squeeze(3)
            x = x.astype(jnp.float32)
            x = nn.Conv(features=16,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((-1))

        elif self.net_conf == 'atari':
            # We need to add a "batch dimension" as nn.Conv expects it, yet vmap will
            # have removed the true batch dimension.
            x = x.astype(jnp.float32) / 255.
            x = nn.Conv(features=32,
                        kernel_size=(8, 8),
                        strides=(4, 4),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(4, 4),
                        strides=(2, 2),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = nn.Conv(features=64,
                        kernel_size=(3, 3),
                        strides=(1, 1),
                        kernel_init=self.initzer)(x)
            x = jax.nn.relu(x)
            x = x.reshape((-1))  # flatten

        elif self.net_conf == 'classic':
            # classic environments
            x = x.astype(jnp.float32)
            x = x.reshape((-1))

        if self.env is not None and self.env in env_inf:
            x = x - env_inf[self.env]['MIN_VALS']
            x /= env_inf[self.env]['MAX_VALS'] - env_inf[self.env]['MIN_VALS']
            x = 2.0 * x - 1.0

        if self.noisy:

            def net(x, features, rng):
                return NoisyNetwork(features, rng=rng, bias_in=True)(x)
        else:

            def net(x, features, rng):
                return nn.Dense(features, kernel_init=self.initzer)(x)

        for _ in range(self.hidden_layer):
            x = net(x, features=self.neurons, rng=rng)
            x = jax.nn.relu(x)

        state_vector_length = x.shape[-1]
        state_net_tiled = jnp.tile(x, [num_quantiles, 1])
        quantiles_shape = [num_quantiles, 1]
        quantiles = jax.random.uniform(rng, shape=quantiles_shape)
        quantile_net = jnp.tile(quantiles, [1, self.quantile_embedding_dim])
        quantile_net = (jnp.arange(1, self.quantile_embedding_dim + 1,
                                   1).astype(jnp.float32) * onp.pi *
                        quantile_net)
        quantile_net = jnp.cos(quantile_net)
        quantile_net = nn.Dense(features=state_vector_length,
                                kernel_init=self.initzer)(quantile_net)
        quantile_net = jax.nn.relu(quantile_net)
        x = state_net_tiled * quantile_net

        adv = net(x, features=self.num_actions, rng=rng)
        val = net(x, features=1, rng=rng)
        dueling_q = val + (adv - (jnp.mean(adv, -1, keepdims=True)))
        non_dueling_q = net(x, features=self.num_actions, rng=rng)
        quantile_values = jnp.where(self.dueling, dueling_q, non_dueling_q)

        return atari_lib.ImplicitQuantileNetworkType(quantile_values,
                                                     quantiles)
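The quantile embedding above follows the IQN recipe: sample tau ~ U(0, 1), tile each tau across the embedding dimension, and encode it as cos(i * pi * tau) for i = 1..quantile_embedding_dim. The embedding step in isolation:

import jax
import jax.numpy as jnp
import numpy as onp

num_quantiles, quantile_embedding_dim = 8, 64
quantiles = jax.random.uniform(jax.random.PRNGKey(0),
                               shape=[num_quantiles, 1])
quantile_net = jnp.tile(quantiles, [1, quantile_embedding_dim])
quantile_net = jnp.cos(
    jnp.arange(1, quantile_embedding_dim + 1).astype(jnp.float32)
    * onp.pi * quantile_net)
print(quantile_net.shape)  # (8, 64)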
Example #16
    def _beam_search(
        self,
        input_ids: jax_xla.DeviceArray,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
        length_penalty: Optional[float] = None,
        early_stopping: Optional[bool] = None,
        logits_processor: Optional[FlaxLogitsProcessorList] = None,
        trace: bool = True,
        params: Optional[Dict[str, jax_xla.DeviceArray]] = None,
        model_kwargs: Optional[Dict[str, jax_xla.DeviceArray]] = None,
    ):
        """
        This beam search function is heavily inspired by Flax's official example:
        https://github.com/google/flax/blob/master/examples/wmt/train.py#L254
        """

        def flatten_beam_dim(tensor):
            """Flattens the first two dimensions of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])

        def unflatten_beam_dim(tensor, batch_size, num_beams):
            """Unflattens the first, flat batch*beam dimension of a non-scalar array."""
            # ignore scalars (e.g. cache index)
            if tensor.ndim == 0:
                return tensor
            return tensor.reshape((batch_size, num_beams) + tensor.shape[1:])

        def gather_beams(nested, beam_indices, batch_size, new_num_beams):
            """
            Gathers the beam slices indexed by beam_indices into new beam array.
            """
            batch_indices = jnp.reshape(
                jnp.arange(batch_size * new_num_beams) // new_num_beams, (batch_size, new_num_beams)
            )

            def gather_fn(tensor):
                # ignore scalars (e.g. cache index)
                if tensor.ndim == 0:
                    return tensor
                else:
                    return tensor[batch_indices, beam_indices]

            return jax.tree_map(gather_fn, nested)

        # init values
        max_length = max_length if max_length is not None else self.config.max_length
        pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
        eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
        length_penalty = length_penalty if length_penalty is not None else self.config.length_penalty
        early_stopping = early_stopping if early_stopping is not None else self.config.early_stopping

        batch_size, num_beams, cur_len = input_ids.shape

        eos_token_id = jnp.array(eos_token_id)
        pad_token_id = jnp.array(pad_token_id)
        cur_len = jnp.array(cur_len)

        # per batch,beam-item holding current token in loop.
        sequences = jnp.full((batch_size, num_beams, max_length), pad_token_id, dtype=jnp.int32)
        running_sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0, 0))

        # per batch,beam-item state bit indicating if sentence has finished.
        is_sent_finished = jnp.zeros((batch_size, num_beams), dtype=jnp.bool_)

        # per batch,beam-item score, logprobs
        running_scores = jnp.tile(jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
        scores = jnp.ones((batch_size, num_beams)) * np.array(-1.0e7)

        # For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
        # and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
        model = self.decode if self.config.is_encoder_decoder else self

        # flatten beam dim
        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = flatten_beam_dim(
                model_kwargs["encoder_outputs"]["last_hidden_state"]
            )
        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = flatten_beam_dim(model_kwargs["attention_mask"])

        # initialize model specific kwargs
        model_kwargs = self.prepare_inputs_for_generation(flatten_beam_dim(input_ids), max_length, **model_kwargs)

        # initialize state
        state = BeamSearchState(
            cur_len=cur_len,
            running_sequences=running_sequences,
            running_scores=running_scores,
            sequences=sequences,
            scores=scores,
            is_sent_finished=is_sent_finished,
            model_kwargs=model_kwargs,
        )

        def beam_search_cond_fn(state):
            """beam search state termination condition fn."""

            # 1. is less than max length?
            not_max_length_yet = state.cur_len < max_length

            # 2. can the new beams still improve?
            best_running_score = state.running_scores[:, -1:] / (max_length ** length_penalty)
            worst_finished_score = jnp.where(
                state.is_sent_finished, jnp.min(state.scores, axis=1, keepdims=True), np.array(-1.0e7)
            )
            improvement_still_possible = jnp.all(worst_finished_score < best_running_score)

            # 3. is there still a beam that has not finished?
            still_open_beam = ~(jnp.all(state.is_sent_finished) & early_stopping)

            return not_max_length_yet & still_open_beam & improvement_still_possible

        def beam_search_body_fn(state):
            """beam search state update fn."""
            # 1. Forward current tokens
            # Collect the current position slice along length to feed the fast
            # autoregressive decoder model.  Flatten the beam dimension into batch
            # dimension for feeding into the model.
            # unflatten beam dimension
            # Unflatten beam dimension in attention cache arrays
            input_token = flatten_beam_dim(
                lax.dynamic_slice(state.running_sequences, (0, 0, state.cur_len - 1), (batch_size, num_beams, 1))
            )
            model_outputs = model(input_token, params=params, **state.model_kwargs)
            logits = unflatten_beam_dim(model_outputs.logits[:, 0], batch_size, num_beams)
            cache = jax.tree_map(
                lambda tensor: unflatten_beam_dim(tensor, batch_size, num_beams), model_outputs.past_key_values
            )

            # adapt logits for FlaxMarianMTModel
            logits = self._adapt_logits_for_beam_search(logits)

            # 2. Compute log probs
            # get log probabilities from logits,
            # process logits with processors (*e.g.* min_length, ...), and
            # add new logprobs to existing running logprobs scores.
            log_probs = jax.nn.log_softmax(logits)
            log_probs = logits_processor(
                flatten_beam_dim(state.running_sequences), flatten_beam_dim(log_probs), state.cur_len
            )
            log_probs = unflatten_beam_dim(log_probs, batch_size, num_beams)
            log_probs = log_probs + jnp.expand_dims(state.running_scores, axis=2)
            vocab_size = log_probs.shape[2]
            log_probs = log_probs.reshape((batch_size, num_beams * vocab_size))

            # 3. Retrieve top-K
            # Each item in batch has num_beams * vocab_size candidate sequences.
            # For each item, get the top 2*k candidates with the highest log-
            # probabilities. We gather the top 2*K beams here so that even if the best
            # K sequences reach EOS simultaneously, we have another K sequences
            # remaining to continue the live beam search.
            # Gather the top 2*K scores from _all_ beams.
            # Gather 2*k top beams.
            # Recover the beam index by floor division.
            # Recover token id by modulo division and expand Id array for broadcasting.
            # Update sequences for the 2*K top-k new sequences.
            beams_to_keep = 2 * num_beams
            topk_log_probs, topk_indices = lax.top_k(log_probs, k=beams_to_keep)
            topk_beam_indices = topk_indices // vocab_size
            topk_running_sequences = gather_beams(
                state.running_sequences, topk_beam_indices, batch_size, beams_to_keep
            )
            topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2)
            topk_sequences = lax.dynamic_update_slice(topk_running_sequences, topk_ids, (0, 0, state.cur_len))

            # 4. Check which sequences have ended
            # Update current sequences:
            # Did any of these sequences reach an end marker?
            # To prevent these just finished sequences from being added to the current sequences
            # set of active beam search sequences, set their log probs to a very large
            # negative value.
            did_topk_just_finished = topk_sequences[:, :, state.cur_len] == eos_token_id
            topk_log_probs = topk_log_probs + did_topk_just_finished * np.array(-1.0e7)

            # 5. Get running sequences scores for next
            # Determine the top k beam indices (from top 2*k beams) from log probs
            # and gather top k beams (from top 2*k beams).
            next_topk_indices = jnp.flip(lax.top_k(topk_log_probs, k=num_beams)[1], axis=1)
            next_running_sequences, next_running_scores = gather_beams(
                [topk_sequences, topk_log_probs], next_topk_indices, batch_size, num_beams
            )

            # 6. Process topk logits
            # Further process log probs:
            # - add length penalty
            # - make sure no scores can be added anymore if beam is full
            # - make sure still running sequences cannot be chosen as finalized beam
            topk_log_probs = topk_log_probs / (state.cur_len ** length_penalty)
            beams_in_batch_are_full = (
                jnp.broadcast_to(state.is_sent_finished.all(axis=-1, keepdims=True), did_topk_just_finished.shape)
                & early_stopping
            )
            add_penalty = ~did_topk_just_finished | beams_in_batch_are_full
            topk_log_probs += add_penalty * np.array(-1.0e7)

            # 7. Get scores, sequences, is sentence finished for next.
            # Combine sequences, scores, and flags along the beam dimension and compare
            # new finished sequence scores to existing finished scores and select the
            # best from the new set of beams
            merged_sequences = jnp.concatenate([state.sequences, topk_sequences], axis=1)
            merged_scores = jnp.concatenate([state.scores, topk_log_probs], axis=1)
            merged_is_sent_finished = jnp.concatenate([state.is_sent_finished, did_topk_just_finished], axis=1)
            topk_merged_indices = jnp.flip(lax.top_k(merged_scores, k=num_beams)[1], axis=1)
            next_sequences, next_scores, next_is_sent_finished = gather_beams(
                [merged_sequences, merged_scores, merged_is_sent_finished], topk_merged_indices, batch_size, num_beams
            )

            # 8. Update model kwargs.
            # Determine the top k beam indices from the original set of all beams.
            # With these, gather the top k beam-associated caches.
            next_running_indices = gather_beams(topk_beam_indices, next_topk_indices, batch_size, num_beams)
            next_cache = gather_beams(cache, next_running_indices, batch_size, num_beams)
            model_outputs["past_key_values"] = jax.tree_map(lambda x: flatten_beam_dim(x), next_cache)
            next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)

            return BeamSearchState(
                cur_len=state.cur_len + 1,
                running_scores=next_running_scores,
                running_sequences=next_running_sequences,
                scores=next_scores,
                sequences=next_sequences,
                is_sent_finished=next_is_sent_finished,
                model_kwargs=next_model_kwargs,
            )

        # The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
        state = beam_search_body_fn(state)

        if not trace:
            state = self._run_loop_in_debug(beam_search_cond_fn, beam_search_body_fn, state)
        else:
            state = lax.while_loop(beam_search_cond_fn, beam_search_body_fn, state)

        # Account for the edge-case where there are no finished sequences for a
        # particular batch item. If so, return running sequences for that batch item.
        none_finished = jnp.any(state.is_sent_finished, axis=1)
        sequences = jnp.where(none_finished[:, None, None], state.sequences, state.running_sequences)
        scores = jnp.where(none_finished[:, None], state.scores, state.running_scores)

        # take best beam for each batch (beams are sorted in ascending score
        # order, so index -1 is the best)
        sequences = sequences[:, -1]
        scores = scores[:, -1]

        return FlaxBeamSearchOutput(sequences=sequences, scores=scores)
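The running_scores initialization above uses jnp.tile to give every batch row the same starting pattern: only the first beam carries a live score (0.0), while the rest start at -1e7, so the first expansion effectively grows a single beam per batch item:

import numpy as np
import jax.numpy as jnp

batch_size, num_beams = 2, 3
running_scores = jnp.tile(
    jnp.array([0.0] + [np.array(-1.0e7)] * (num_beams - 1)), [batch_size, 1])
print(running_scores)
# [[ 0.e+00 -1.e+07 -1.e+07]
#  [ 0.e+00 -1.e+07 -1.e+07]]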
Example #17
def init_fn(position: Array) -> PHopState:
  position_buffer = jnp.tile(position, (window_size, 1, 1))
  assert position_buffer.shape == ((window_size,) + position.shape)
  return PHopState(position_buffer, jnp.zeros((position.shape[0],)))  # pytype: disable=wrong-arg-count
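Here jnp.tile with reps (window_size, 1, 1) prepends a window axis and copies the full position array into every slot, which is exactly what the assert verifies:

import jax.numpy as jnp

window_size = 4
position = jnp.zeros((10, 3))  # e.g. 10 particles in 3 dimensions
position_buffer = jnp.tile(position, (window_size, 1, 1))
print(position_buffer.shape)  # (4, 10, 3)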