def test_jit_or_pmap_broadcast(self):
        def kernel_fn(x1,
                      x2,
                      do_flip,
                      keys,
                      do_square,
                      params,
                      _unused=None,
                      p=0.65):
            res = np.abs(np.matmul(x1, x2))
            if do_square:
                res *= res
            if do_flip:
                res = -res

            res *= random.uniform(keys) * p
            return [res, params]

        params = (np.array([1., 0.3]), (np.array([1.2]), np.array([0.5])))
        x2 = np.arange(0, 10).reshape((10, ))
        keys = random.PRNGKey(1)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=0)
        x1 = np.arange(0, 10).reshape((1, 10))
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=0):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=True,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=True)
                    self.assertAllClose(res_1, res_2, True)

        test_utils.stub_out_pmap(batch, 1)
        x1 = np.arange(0, 10).reshape((1, 10))
        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=1)
        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=1):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      _unused=False,
                                      p=0.65)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None)
                    self.assertAllClose(res_1[0], res_2[0], True)
                    self.assertAllClose(
                        tree_map(partial(np.expand_dims, axis=0), res_1[1]),
                        res_2[1], True)

        kernel_fn_pmapped = batch._jit_or_pmap_broadcast(kernel_fn,
                                                         device_count=2)
        x1 = np.arange(0, 20).reshape((2, 10))
        test_utils.stub_out_pmap(batch, 2)

        def broadcast(arg):
            return np.broadcast_to(arg, (2, ) + arg.shape)

        for do_flip in [True, False]:
            for do_square in [True, False]:
                with self.subTest(do_flip=do_flip,
                                  do_square=do_square,
                                  device_count=2):
                    res_1 = kernel_fn(x1,
                                      x2,
                                      do_flip,
                                      keys,
                                      do_square,
                                      params,
                                      p=0.2)
                    res_2 = kernel_fn_pmapped(x1,
                                              x2,
                                              do_flip,
                                              keys,
                                              do_square,
                                              params,
                                              _unused=None,
                                              p=0.2)
                    self.assertAllClose(res_1[0][0], res_2[0][0], True)
                    self.assertAllClose(res_1[0][1], res_2[0][1], True)
                    self.assertAllClose(tree_map(broadcast, res_1[1]),
                                        res_2[1], True)
def loss_fn(params, observations, q_values, actions):
    logits = model.apply(params, observations)
    # We pick the values of taken actions
    logits = logits[np.arange(logits.shape[0]), actions]
    return nn.losses.mse(q_values, logits, reduction='mean')
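# A hedged, self-contained sketch of the action-indexing trick used above
# (toy logits/actions; the names are illustrative and not from the original source):
import jax.numpy as jnp

toy_logits = jnp.array([[1.0, 2.0, 3.0],
                        [4.0, 5.0, 6.0]])
toy_actions = jnp.array([2, 0])
# Pick each row's logit for the taken action -> [3.0, 4.0]
picked = toy_logits[jnp.arange(toy_logits.shape[0]), toy_actions]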
Example #3
 def log_abs_det_jacobian(self, x, y, intermediates=None):
     # Ref: http://web.mit.edu/18.325/www/handouts/handout2.pdf page 13
     n = jnp.shape(x)[-1]
     order = -jnp.arange(n, 0, -1)
     return -n * jnp.log(2) + jnp.sum(order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1)
Example #4
    def _encode(self,
                inputs: JTensor,
                input_paddings: JTensor,
                input_segment_ids: Optional[JTensor] = None,
                input_segment_pos: Optional[JTensor] = None) -> JTensor:
        """Apply the Transformer encoder to the source sequence.

    Args:
      inputs: Input ids. An int32 JTensor of shape [B, S].
      input_paddings: A 0/1 JTensor of shape [B, S] with 1 denoting padding
        corresponding to the input sequence.
      input_segment_ids: A JTensor of shape [B,S]. The segment that each input
        token belongs to.
      input_segment_pos: A JTensor of shape [B, S]. The position of each input
        token within a segment.

    Returns:
      The encoded sequence after applying the Transformer encoder.
    """
        p = self.params
        batch, seq_length = inputs.shape
        if p.encoder_embedding_tpl is not None:
            # Encoder has its own embedding lookup table for source ids.
            input_emb = self.encoder_embedding_lookup.fprop(inputs)
        elif p.decoder_embedding_tpl is not None:
            # Encoder shares the same embedding as the target ids.
            # The embedding lookup for target ids is separate from the softmax.
            input_emb = self.decoder_embedding_lookup.fprop(inputs)
        else:
            # Encoder and decoder share the softmax and embedding params.
            input_emb = self.softmax.emb_lookup(inputs)

        if input_segment_ids is None:
            assert input_segment_pos is None
            # Fold the paddings with the segment mask.
            input_segment_ids = jnp.asarray(1 - input_paddings, jnp.int32)
            input_segment_pos = jnp.tile(
                jnp.arange(seq_length, dtype=jnp.int32)[None, :], [batch, 1])
        assert input_segment_ids is not None
        assert input_segment_pos is not None

        # Add NGrammer to the source embeddings.
        if p.encoder_ngrammer_tpl is not None:
            input_emb = self.encoder_ngrammer.fprop(
                input_ids=inputs,
                input_embs=input_emb,
                paddings=input_paddings,
                segment_pos=input_segment_pos)

        if p.position_emb_tpl is not None:
            position_emb = self.position_emb.fprop(seq_length=seq_length,
                                                   position=input_segment_pos)
            input_emb += position_emb

        inputs_segment_mask = attentions.segment_mask(input_segment_ids,
                                                      dtype=input_emb.dtype)
        encoder_output = self.encoder.fprop(input_emb,
                                            input_paddings,
                                            segment_mask=inputs_segment_mask)

        # Final layer norm for encoder output.
        encoder_output = self.encoder_ln.fprop(encoder_output)
        return encoder_output
                          sigR,
                          gridperdeg=gridperdeg,
                          gridsizedeg=gridsizedeg)

ssn_Ampa = SSN_classes._SSN_AMPAGABA(tau_s, NMDAratio, n, k, Ne, Ni, tau_vec,
                                     W)
r_init = np.zeros([ssn_Ampa.N, len(Contrasts)])
inp_vec = np.vstack((gE * Inp, gI * Inp))

r_fp, CONVG = ssn_Ampa.fixed_point_r(inp_vec,
                                     r_init=r_init,
                                     Tmax=Tmax,
                                     dt=dt,
                                     xtol=xtol)

gen_inds = np.arange(len(Contrasts))
rad_inds = np.arange(
    len(contrasts) - 1,
    len(r_cent) + len(contrasts) -
    1)  #np.where(stimCon[0, :] == np.max(Contrasts), gen_inds, 0)
con_inds = np.hstack(
    (np.arange(0,
               len(contrasts) - 1), len(r_cent) + len(contrasts) -
     2))  #np.where(stimCon[1, :] == np.max(stimCon[1,:]), gen_inds, 0)
gabor_inds = -1

trgt = np.floor(Ne / 2)

con_inds = np.hstack((con_inds, gabor_inds))
cons = len(con_inds)
ssn_Ampa.topos_vec = np.ravel(OMap)
Example #6
 def test_jit_large(self):
     arg = jnp.arange(10000, dtype=jnp.int32).reshape((10, 10, 5, -1))
     with hcb.outfeed_receiver(receiver_name=self._testMethodName):
         api.jit(hcb.id_print)(arg)
Example #7
def vectorized_perturbative_triples(T1, T2, V, fock_Od, fock_Vd):
    Voooo, Vooov, Voovv, Vovov, Vovvv, Vvvvv = V
    # below equations are in chemists, so transpose
    Vvvvo = jnp.transpose(Vovvv, (3, 1, 2, 0))
    Vvooo = jnp.transpose(Vooov, (3, 1, 0, 2))
    Vvovo = jnp.transpose(Voovv, (2, 0, 3, 1))
    o, v = T1.shape
    delta_o = jnp.eye(o)
    delta_v = jnp.eye(v)
    # IDEA: Build up index arrays which mimic loop structure TODO regular numpy probably better here, with int16's
    occ_range = jnp.arange(o)
    vir_range = jnp.arange(v)
    occ_indices = cartesian_product(occ_range, occ_range, occ_range)
    i, j, k = occ_indices[:, 0], occ_indices[:, 1], occ_indices[:, 2]
    occ_cond = (i <= j) & (j <= k)
    vir_indices = cartesian_product(vir_range, vir_range, vir_range)
    a, b, c = vir_indices[:, 0], vir_indices[:, 1], vir_indices[:, 2]
    vir_cond = (a <= b) & (b <= c)
    # Now have all indices prepared
    occ_indices = occ_indices[occ_cond]
    vir_indices = vir_indices[vir_cond]
    i, j, k = occ_indices[:, 0], occ_indices[:, 1], occ_indices[:, 2]
    a, b, c = vir_indices[:, 0], vir_indices[:, 1], vir_indices[:, 2]

    @jax.jit
    def inner_func(i, j, k):
        delta_ij = delta_o[i, j]
        delta_jk = delta_o[j, k]
        W = jnp.einsum('bda,cd', Vvvvo[:, :, :, i], T2[k, j, :, :])
        W -= jnp.einsum('cl,lab', Vvooo[:, k, j, :], T2[i, :, :, :])
        W += jnp.einsum('cda,bd', Vvvvo[:, :, :, i], T2[j, k, :, :])
        W -= jnp.einsum('bl,lac', Vvooo[:, j, k, :], T2[i, :, :, :])
        W += jnp.einsum('adc,bd', Vvvvo[:, :, :, k], T2[j, i, :, :])
        W -= jnp.einsum('bl,lca', Vvooo[:, j, i, :], T2[k, :, :, :])
        W += jnp.einsum('bdc,ad', Vvvvo[:, :, :, k], T2[i, j, :, :])
        W -= jnp.einsum('al,lcb', Vvooo[:, i, j, :], T2[k, :, :, :])
        W += jnp.einsum('cdb,ad', Vvvvo[:, :, :, j], T2[i, k, :, :])
        W -= jnp.einsum('al,lbc', Vvooo[:, i, k, :], T2[j, :, :, :])
        W += jnp.einsum('adb,cd', Vvvvo[:, :, :, j], T2[k, i, :, :])
        W -= jnp.einsum('cl,lba', Vvooo[:, k, i, :], T2[j, :, :, :])
        V  = W + jnp.einsum('bc,a', Vvovo[:,j,:,k], T1[i,:]) \
               + jnp.einsum('ac,b', Vvovo[:,i,:,k], T1[j,:]) \
               + jnp.einsum('ab,c', Vvovo[:,i,:,j], T1[k,:])

        delta_occ = 2 - delta_ij - delta_jk
        Dd = fock_Od[i] + fock_Od[j] + fock_Od[k]

        with loops.Scope() as s:
            s.pT_contribution = 0.0
            # TODO is while looping better here?
            for vir_idx in s.range(vir_indices.shape[0]):
                a, b, c = vir_indices[vir_idx]
                delta_ab = delta_v[a, b]
                delta_bc = delta_v[b, c]
                # Per-(a, b, c) denominator, computed fresh each iteration so Dd is not
                # mutated cumulatively across the virtual-index loop.
                Dd_abc = Dd - (fock_Vd[a] + fock_Vd[b] + fock_Vd[c])
                X = W[a,b,c]*V[a,b,c] + W[a,c,b]*V[a,c,b] + W[b,a,c]*V[b,a,c]  \
                  + W[b,c,a]*V[b,c,a] + W[c,a,b]*V[c,a,b] + W[c,b,a]*V[c,b,a]
                Y = (V[a, b, c] + V[b, c, a] + V[c, a, b])
                Z = (V[a, c, b] + V[b, a, c] + V[c, b, a])
                E = (Y - 2 * Z) * (W[a, b, c] + W[b, c, a] + W[c, a, b]) + (
                    Z - 2 * Y) * (W[a, c, b] + W[b, a, c] + W[c, b, a]) + 3 * X
                s.pT_contribution += E * delta_occ / (
                    Dd_abc * (1 + delta_ab + delta_bc))
            return s.pT_contribution

    with loops.Scope() as S:
        S.pT = 0.0
        for occ_idx in S.range(occ_indices.shape[0]):
            i, j, k = occ_indices[occ_idx]
            S.pT += inner_func(i, j, k)
        return S.pT
    train_time = 0
    epochs = tqdm(range(num_epochs),
                  desc=f"Epoch ... (1/{num_epochs})",
                  position=0)
    for epoch in epochs:
        # ======================== Training ================================
        train_start = time.time()
        train_metrics = []

        # Create sampling rng
        rng, input_rng = jax.random.split(rng)

        # Generate an epoch by shuffling sampling indices from the train dataset
        num_train_samples = len(tokenized_datasets["train"])
        train_samples_idx = jax.random.permutation(
            input_rng, jnp.arange(num_train_samples))
        train_batch_idx = generate_batch_splits(train_samples_idx,
                                                train_batch_size)

        # Gather the indexes for creating the batch and do a training step
        for step, batch_idx in enumerate(
                tqdm(train_batch_idx, desc="Training...", position=1)):
            samples = [
                tokenized_datasets["train"][int(idx)] for idx in batch_idx
            ]
            model_inputs = data_collator(samples, pad_to_multiple_of=16)

            # Model forward
            model_inputs = shard(model_inputs.data)
            state, train_metric, dropout_rngs = p_train_step(
                state, model_inputs, dropout_rngs)
Example #9
# Output file name is out.XXXX.YYYY.jpg
# where XXXX is the current doubling and
#       YYYY is the iteration on the current doubling

# Other things to try:
# The `scatter_nd` line below scatters the initial sand grains.
# It's easily modified.
# The first argument to `scatter_nd` is the list of coordinates
# and the second argument is the list of corresponding numbers of grains.
#
# You can try varying the kernel, eg. for a hexagonal grid
# try [[1, 1, 0], [1, 0, 1], [0, 1, 1]]
# and warping the resulting image.
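
# A hedged sketch of the hexagonal-kernel variation suggested in the comments
# above; the name `hex_kernel` is illustrative and float32 stands in for the
# snippet's float_type.
hex_kernel = np.array([[1, 1, 0], [1, 0, 1], [0, 1, 1]], dtype=np.float32)
hex_kernel = hex_kernel[:, :, None, None]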

background = np.full((h, w), float_type(init_background))
i = np.arange(h)
j = np.arange(w)
sand = np.where(np.logical_and(i[:, None] == h // 2, j[None, :] == w // 2),
                base_size, 0)
init = background + sand
init = init[None, :, :, None]

kernel = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 0]], dtype=float_type)
kernel = kernel[:, :, None, None]
neighbours = np.sum(kernel)

pile = init


@jit
def reduce(pile):
Example #10
 def fn(x, k, dtype='float32'):
     """Create a one-hot encoding of x of size k."""
     return jnp.array(x[:, None] == jnp.arange(k), dtype)
 def testMap(self):
     f = lambda x: x**2
     xs = np.arange(10)
     expected = xs**2
     actual = lax.map(f, xs)
     self.assertAllClose(actual, expected, check_dtypes=True)
    def __call__(
        self,
        input_ids,
        attention_mask=None,
        position_ids=None,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        params: dict = None,
        past_key_values: dict = None,
        dropout_rng: jax.random.PRNGKey = None,
        train: bool = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        if encoder_hidden_states is not None and encoder_attention_mask is None:
            batch_size, sequence_length = encoder_hidden_states.shape[:2]
            encoder_attention_mask = jnp.ones((batch_size, sequence_length))

        batch_size, sequence_length = input_ids.shape

        if position_ids is None:
            if past_key_values is not None:
                raise ValueError("Make sure to provide `position_ids` when passing `past_key_values`.")

            position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))

        if attention_mask is None:
            attention_mask = jnp.ones((batch_size, sequence_length))

        # Handle any PRNG if needed
        rngs = {}
        if dropout_rng is not None:
            rngs["dropout"] = dropout_rng

        inputs = {"params": params or self.params}

        # If past_key_values are passed, then the cache is already initialized and a
        # private flag `init_cache` has to be passed down to ensure the cache is used.
        # The cache must also be marked as mutable so that it can be changed by the
        # FlaxGPT2Attention module.
        if past_key_values:
            inputs["cache"] = past_key_values
            mutable = ["cache"]
        else:
            mutable = False

        outputs = self.module.apply(
            inputs,
            jnp.array(input_ids, dtype="i4"),
            jnp.array(attention_mask, dtype="i4"),
            jnp.array(position_ids, dtype="i4"),
            encoder_hidden_states,
            encoder_attention_mask,
            not train,
            False,
            output_attentions,
            output_hidden_states,
            return_dict,
            rngs=rngs,
            mutable=mutable,
        )

        # add updated cache to model output
        if past_key_values is not None and return_dict:
            outputs, past_key_values = outputs
            outputs["past_key_values"] = unfreeze(past_key_values["cache"])
            return outputs
        elif past_key_values is not None and not return_dict:
            outputs, past_key_values = outputs
            outputs = outputs[:1] + (unfreeze(past_key_values["cache"]),) + outputs[1:]

        return outputs
Example #13
def lengths_to_paddings(lengths: JTensor, maxlength: int) -> JTensor:
  indices = jnp.arange(maxlength).reshape((1,) * lengths.ndim + (maxlength,))
  lengths = jnp.expand_dims(lengths, axis=-1)
  elem_valid = indices < lengths
  return jnp.logical_not(elem_valid).astype(jnp.float32)
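
# A minimal hedged usage sketch (toy values, not from the original source):
import jax.numpy as jnp

lengths_to_paddings(jnp.array([2, 4]), maxlength=4)
# -> [[0., 0., 1., 1.],
#     [0., 0., 0., 0.]]   (positions at or beyond each length are padding)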
Example #14
 def zakharov_fn(x):
   ii = jnp.arange(1, len(x) + 1, step=1)
   answer = zakharovFromIndices(x=x, ii=ii)
   return answer
Example #15
def compute_ssim(img0,
                 img1,
                 max_val,
                 filter_size=11,
                 filter_sigma=1.5,
                 k1=0.01,
                 k2=0.03,
                 return_map=False):
    """Computes SSIM from two images.

  This function was modeled after tf.image.ssim, and should produce comparable
  output.

  Args:
    img0: array. An image of size [..., width, height, num_channels].
    img1: array. An image of size [..., width, height, num_channels].
    max_val: float > 0. The maximum magnitude that `img0` or `img1` can have.
    filter_size: int >= 1. Window size.
    filter_sigma: float > 0. The bandwidth of the Gaussian used for filtering.
    k1: float > 0. One of the SSIM dampening parameters.
    k2: float > 0. One of the SSIM dampening parameters.
    return_map: Bool. If True, will cause the per-pixel SSIM "map" to be returned.

  Returns:
    Each image's mean SSIM, or a tensor of individual values if `return_map`.
  """
    # Construct a 1D Gaussian blur filter.
    hw = filter_size // 2
    shift = (2 * hw - filter_size + 1) / 2
    f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma)**2
    filt = jnp.exp(-0.5 * f_i)
    filt /= jnp.sum(filt)

    # Blur in x and y (faster than the 2D convolution).
    filt_fn1 = lambda z: jsp.signal.convolve2d(z, filt[:, None], mode="valid")
    filt_fn2 = lambda z: jsp.signal.convolve2d(z, filt[None, :], mode="valid")

    # Vmap the blurs to the tensor size, and then compose them.
    num_dims = len(img0.shape)
    map_axes = tuple(list(range(num_dims - 3)) + [num_dims - 1])
    for d in map_axes:
        filt_fn1 = jax.vmap(filt_fn1, in_axes=d, out_axes=d)
        filt_fn2 = jax.vmap(filt_fn2, in_axes=d, out_axes=d)
    filt_fn = lambda z: filt_fn1(filt_fn2(z))

    mu0 = filt_fn(img0)
    mu1 = filt_fn(img1)
    mu00 = mu0 * mu0
    mu11 = mu1 * mu1
    mu01 = mu0 * mu1
    sigma00 = filt_fn(img0**2) - mu00
    sigma11 = filt_fn(img1**2) - mu11
    sigma01 = filt_fn(img0 * img1) - mu01

    # Clip the variances and covariances to valid values.
    # Variance must be non-negative:
    sigma00 = jnp.maximum(0., sigma00)
    sigma11 = jnp.maximum(0., sigma11)
    sigma01 = jnp.sign(sigma01) * jnp.minimum(jnp.sqrt(sigma00 * sigma11),
                                              jnp.abs(sigma01))

    c1 = (k1 * max_val)**2
    c2 = (k2 * max_val)**2
    numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
    denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
    ssim_map = numer / denom
    ssim = jnp.mean(ssim_map, list(range(num_dims - 3, num_dims)))
    return ssim_map if return_map else ssim
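
# A hedged usage sketch (shapes and values assumed, not from the original source):
# two identical images should yield an SSIM of ~1.0.
import jax.numpy as jnp

img = jnp.full((1, 32, 32, 3), 0.5)
score = compute_ssim(img, img, max_val=1.0)  # ~1.0 for the single image in the batch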
Example #16
def onehot(labels, num_classes=10):
    x = (labels[..., None] == jnp.arange(num_classes)[None])
    return x.astype(jnp.float32)
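
# A hedged usage sketch (toy labels, not from the original source):
import jax.numpy as jnp

onehot(jnp.array([0, 2]), num_classes=3)
# -> [[1., 0., 0.],
#     [0., 0., 1.]]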
Example #17
 def test_pmap_error_no_receiver(self):
     # Check for errors if starting pmap without a consumer active
     vargs = 2. + jnp.arange(api.local_device_count(), dtype=jnp.float32)
     with self.assertRaisesRegex(ValueError,
                                 "outfeed_receiver is not started"):
         api.pmap(lambda x: hcb.id_print(x))(vargs)
Example #18
t_grid, x_grid = jnp.meshgrid(t, x, indexing="ij")
u = burgers(x_grid, t_grid, 0.1, 1.0)

X_train = jnp.concatenate(
    [t_grid.reshape(-1, 1), x_grid.reshape(-1, 1)], axis=1)
y_train = u.reshape(-1, 1)

# Instantiating model and optimizers
model = Deepmod(features=[50, 50, 1])
key = random.PRNGKey(42)
params = model.init(key, X_train)
optimizer = optim.Adam(learning_rate=2e-3, beta1=0.99, beta2=0.99)
optimizer = optimizer.create(params)

# Compiling train step
update = create_update(loss_fn_pinn, model=model, x=X_train, y=y_train)
_ = update(optimizer)  # triggering compilation

# Running to convergence
max_epochs = 10001
t_start = time()
for i in jnp.arange(max_epochs):
    optimizer, loss = update(optimizer)
    if i % 1000 == 0:
        print(f"Loss step {i}: {loss}")
t_end = time()
print(t_end - t_start)
theta, coeffs = model.apply(optimizer.target, X_train)[2:]
print(coeffs * jnp.linalg.norm(theta, axis=0, keepdims=True).T)
Example #19
 def test_jit_several_together(self):
     arg = jnp.arange(50, dtype=jnp.int32).reshape((10, 5))
     with hcb.outfeed_receiver(receiver_name=self._testMethodName):
         api.jit(lambda x, y: hcb.id_print((x, y, x * 2.)))(
             arg, jnp.ones(100, dtype=jnp.int32))
Example #20
def lorenz96_dynamics(x: jnp.ndarray,
                      t: float,
                      forcing_constant: float) -> jnp.ndarray:
    d = len(x)
    return (x[(jnp.arange(d) + 1) % d] - x[(jnp.arange(d) - 2) % d]) * x[(jnp.arange(d) - 1) % d] \
           - x + forcing_constant
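
# A hedged sanity check (toy values, not from the original source): at the
# uniform fixed point x_i == forcing_constant the Lorenz-96 derivative is zero.
import jax.numpy as jnp

lorenz96_dynamics(jnp.full(5, 8.0), 0.0, forcing_constant=8.0)  # all entries ~0.0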
Example #21
    def fprop(self,
              inputs: JTensor,
              paddings: JTensor,
              labels: Optional[NestedMap] = None,
              segment_ids: Optional[JTensor] = None,
              segment_pos: Optional[JTensor] = None) -> NestedMap:
        """Computes xent loss given the language model inputs.

    Args:
      inputs: Input ids. An int32 JTensor of shape [B, T].
      paddings: A 0/1 JTensor of shape [B, T] with 1 denoting padding.
      labels: A `.NestedMap` containing the following fields: class_weights, a
        JTensor with shape [batch, seqlen] containing weights for each target
        word. class_ids, a JTensor with shape [B, T] of int32 dtype containing
        the target class labels. class_probabilities, a JTensor with shape [B,
        T, V] of float values indicating class-membership probabilities.
      segment_ids: A JTensor of shape [B, T]. The segment that each token
        belongs to.
      segment_pos: A JTensor of shape [B, T]. The position of each token in a
        segment.

    Returns:
      Returns xent_output, where
      `xent_output` is a `.NestedMap` as defined by `SoftmaxLayer`'s return. In
      addition, per_sequence_xent is added, which is equal to the sum of the xent loss
      for tokens in a sequence.
    """
        p = self.params
        # reentrant=True, to enable scan-local context override.
        with py_utils.AuxLossContext(reentrant=True) as aux_loss_ctx:
            assert aux_loss_ctx is not None
            # Get the input embeddings.
            if self.params.separate_embedding_tpl is not None:
                input_emb = self.embedding_lookup.fprop(inputs)
            else:
                input_emb = self.softmax.emb_lookup(inputs)
            batch, seq_length = inputs.shape

            if segment_ids is None:
                assert segment_pos is None
                # Fold the paddings with the segment mask
                segment_ids = jnp.asarray(1 - paddings, jnp.int32)
                segment_pos = jnp.tile(
                    jnp.arange(seq_length, dtype=jnp.int32)[None, :],
                    [batch, 1])

            # Add NGrammer to the source embeddings.
            if p.ngrammer_tpl is not None:
                input_emb = self.ngrammer.fprop(input_ids=inputs,
                                                input_embs=input_emb,
                                                paddings=paddings,
                                                segment_pos=segment_pos)

            if p.position_emb_tpl is not None:
                position_emb = self.position_emb.fprop(seq_length=seq_length,
                                                       position=segment_pos)
                inputs = input_emb + position_emb
            else:
                inputs = input_emb

            if p.masked_lm:
                segment_mask = attentions.segment_mask(segment_ids,
                                                       segment_ids,
                                                       inputs.dtype)
            else:
                segment_mask = attentions.causal_segment_mask(
                    segment_ids, inputs.dtype)

            output = self.transformer.fprop(inputs,
                                            paddings,
                                            segment_mask=segment_mask,
                                            segment_pos=segment_pos)

            # Final layer norm
            if p.final_ln_tpl is not None:
                output = self.final_ln.fprop(output)
            return self.compute_loss(output, labels)
Example #22
def main(args):
    X, Y, expected_thetas, expected_pairwise = get_data(
        N=args.num_data, P=args.num_dimensions, S=args.active_dimensions)

    # setup hyperparameters
    hypers = {
        "expected_sparsity": max(1.0, args.num_dimensions / 10),
        "alpha1": 3.0,
        "beta1": 1.0,
        "alpha2": 3.0,
        "beta2": 1.0,
        "alpha3": 1.0,
        "c": 1.0,
    }

    # do inference
    rng_key = random.PRNGKey(0)
    samples = run_inference(model, args, rng_key, X, Y, hypers)

    # compute the mean and square root variance of each coefficient theta_i
    means, stds = vmap(
        lambda dim: analyze_dimension(samples, X, Y, dim, hypers))(jnp.arange(
            args.num_dimensions))

    print(
        "Coefficients theta_1 to theta_%d used to generate the data:" %
        args.active_dimensions,
        expected_thetas,
    )
    print(
        "The single quadratic coefficient theta_{1,2} used to generate the data:",
        expected_pairwise,
    )
    active_dimensions = []

    for dim, (mean, std) in enumerate(zip(means, stds)):
        # we mark the dimension as inactive if the interval [mean - 3 * std, mean + 3 * std] contains zero
        lower, upper = mean - 3.0 * std, mean + 3.0 * std
        inactive = "inactive" if lower < 0.0 and upper > 0.0 else "active"
        if inactive == "active":
            active_dimensions.append(dim)
        print("[dimension %02d/%02d]  %s:\t%.2e +- %.2e" %
              (dim + 1, args.num_dimensions, inactive, mean, std))

    print("Identified a total of %d active dimensions; expected %d." %
          (len(active_dimensions), args.active_dimensions))

    # Compute the mean and square root variance of coefficients theta_ij for i,j active dimensions.
    # Note that the resulting numbers are only meaningful for i != j.
    if len(active_dimensions) > 0:
        dim_pairs = jnp.array(
            list(itertools.product(active_dimensions, active_dimensions)))
        means, stds = vmap(lambda dim_pair: analyze_pair_of_dimensions(
            samples, X, Y, dim_pair[0], dim_pair[1], hypers))(dim_pairs)
        for dim_pair, mean, std in zip(dim_pairs, means, stds):
            dim1, dim2 = dim_pair
            if dim1 >= dim2:
                continue
            lower, upper = mean - 3.0 * std, mean + 3.0 * std
            if not (lower < 0.0 and upper > 0.0):
                format_str = "Identified pairwise interaction between dimensions %d and %d: %.2e +- %.2e"
                print(format_str % (dim1 + 1, dim2 + 1, mean, std))

        # Draw a single sample of coefficients theta from the posterior, where we return all singleton
        # coefficients theta_i and pairwise coefficients theta_ij for i, j active dimensions. We use the
        # final MCMC sample obtained from the HMC sampler.
        thetas = sample_theta_space(
            X,
            Y,
            active_dimensions,
            samples["msq"][-1],
            samples["lambda"][-1],
            samples["eta1"][-1],
            samples["xisq"][-1],
            hypers["c"],
            samples["sigma"][-1],
        )
        print("Single posterior sample theta:\n", thetas)
Example #23
    def fprop(
        self,
        inputs: JTensor,
        input_paddings: JTensor,
        targets: JTensor,
        target_paddings: JTensor,
        labels: Optional[NestedMap] = None,
        input_segment_ids: Optional[JTensor] = None,
        input_segment_pos: Optional[JTensor] = None,
        target_segment_ids: Optional[JTensor] = None,
        target_segment_pos: Optional[JTensor] = None,
    ) -> NestedMap:
        """Computes xent loss given the sequence model inputs.

    Args:
      inputs: Input ids. An int32 JTensor of shape [B, S].
      input_paddings: A 0/1 JTensor of shape [B, S] with 1 denoting padding
        corresponding to the input sequence.
      targets: Target ids. An int32 JTensor of shape [B, T].
      target_paddings: A 0/1 JTensor of shape [B, T] with 1 denoting padding
        corresponding to the target sequence.
      labels: A `.NestedMap` containing the following fields: class_weights, a
        JTensor with shape [batch, seqlen] containing weights for each target
        word. class_ids, a JTensor with shape [B, T] of int32 dtype containing
        the target class labels. class_probabilities, a JTensor with shape [B,
        T, V] of float values indicating class-membership probabilities.
      input_segment_ids: A JTensor of shape [B,S]. The segment that each input
        token belongs to.
      input_segment_pos: A JTensor of shape [B, S]. The position of each input
        token within a segment.
      target_segment_ids: A JTensor of shape [B,T]. The segment that each target
        token belongs to.
      target_segment_pos: A JTensor of shape [B, T]. The position of each target
        token within a segment.

    Returns:
      Returns xent_output, where
      `xent_output` is a `.NestedMap` as defined by `SoftmaxLayer`'s return. In
      addition, per_sequence_xent is added, which is equal to the sum of the xent loss
      for tokens in a sequence.
    """
        # Get the input embeddings.
        p = self.params
        batch, seq_length = inputs.shape
        _, target_seq_length = targets.shape

        encoder_output = self._encode(inputs, input_paddings,
                                      input_segment_ids, input_segment_pos)

        if p.decoder_embedding_tpl is not None:
            # Targets have separate embedding params.
            target_emb = self.decoder_embedding_lookup.fprop(targets)
        else:
            # Embedding parameters are shared with targets and softmax.
            target_emb = self.softmax.emb_lookup(targets)

        if p.decoder_ngrammer_tpl is not None:
            target_emb = self.decoder_ngrammer.fprop(
                input_ids=targets,
                input_embs=target_emb,
                paddings=target_paddings,
                segment_pos=target_segment_pos)

        if p.position_emb_tpl is not None:
            targets_position_emb = self.position_emb.fprop(
                seq_length=target_seq_length, position=target_segment_pos)
            target_emb += targets_position_emb

        if input_segment_ids is None:
            assert input_segment_pos is None
            # Fold the paddings with the segment mask.
            input_segment_ids = jnp.asarray(1 - input_paddings, jnp.int32)
            input_segment_pos = jnp.tile(
                jnp.arange(seq_length, dtype=jnp.int32)[None, :], [batch, 1])

        if target_segment_ids is None:
            assert target_segment_pos is None
            # Fold the paddings with the segment mask.
            target_segment_ids = jnp.asarray(1 - target_paddings, jnp.int32)
            target_segment_pos = jnp.tile(
                jnp.arange(target_seq_length, dtype=jnp.int32)[None, :],
                [batch, 1])

        # Cross attention.
        cross_segment_mask = attentions.segment_mask(target_segment_ids,
                                                     input_segment_ids,
                                                     target_emb.dtype)
        target_segment_mask = attentions.causal_segment_mask(
            target_segment_ids, target_emb.dtype)
        output = self.decoder.fprop(target_emb,
                                    target_paddings,
                                    target_segment_mask,
                                    cross_inputs=encoder_output,
                                    cross_paddings=input_paddings,
                                    cross_segment_mask=cross_segment_mask)

        # Final layer norm for decoder.
        output = self.decoder_ln.fprop(output)

        return self.compute_loss(output, labels)
Example #24
 def select_tril(x):
     mask = jnp.arange(x.shape[0])[:, None] > jnp.arange(x.shape[1])
     return jnp.where(mask, x, jnp.zeros_like(x))
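 # A hedged worked example (values assumed, not from the original source): for
 # x = jnp.arange(1, 10).reshape(3, 3), select_tril(x) keeps only the strictly
 # lower-triangular entries:
 # [[0, 0, 0],
 #  [4, 0, 0],
 #  [7, 8, 0]]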
Example #25
    def apply(self,
              inputs_q,
              inputs_kv,
              num_heads,
              dtype=jnp.float32,
              qkv_features=None,
              out_features=None,
              attention_axis=None,
              causal_mask=False,
              padding_mask=None,
              key_padding_mask=None,
              segmentation=None,
              key_segmentation=None,
              cache=None,
              broadcast_dropout=True,
              dropout_rng=None,
              dropout_rate=0.,
              deterministic=False,
              precision=None,
              kernel_init=nn.linear.default_kernel_init,
              bias_init=nn.initializers.zeros,
              bias=True,
              num_partitions=2):
        """Applies multi-head dot product attention on the input data.

    Projects the inputs into multi-headed query, key, and value vectors,
    applies dot-product attention, and projects the results to an output vector.

    This can be used for encoder-decoder attention by specifying both `inputs_q`
    and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
    setting `inputs_kv` to None.

    Args:
      inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
      inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
        or None for self-attention, in which case key/values will be derived
        from inputs_q.
      num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
        should be divisible by the number of heads.
      dtype: the dtype of the computation (default: float32)
      qkv_features: dimension of the key, query, and value.
      out_features: dimension of the last projection
      attention_axis: axes over which the attention is applied ('None' means
        attention over all axes except batch, heads, and features).
      causal_mask: boolean specifying whether to apply a causal mask on the
        attention weights. If True, the output at timestep `t` will not depend
        on inputs at timesteps strictly greater than `t`.
      padding_mask: boolean specifying query tokens that are pad token.
      key_padding_mask: boolean specifying key-value tokens that are pad token.
      segmentation: segment indices for packed inputs_q data.
      key_segmentation: segment indices for packed inputs_kv data.
      cache: an instance of `flax.nn.attention.Cache` used for efficient
        autoregressive decoding.
      broadcast_dropout: bool: use a broadcasted dropout along batch dims.
      dropout_rng: JAX PRNGKey: to be used for dropout
      dropout_rate: dropout rate
      deterministic: bool, deterministic or not (to apply dropout)
      precision: numerical precision of the computation see `jax.lax.Precision`
        for details.
      kernel_init: initializer for the kernel of the Dense layers.
      bias_init: initializer for the bias of the Dense layers.
      bias: bool: whether pointwise QKVO dense transforms use bias.
      num_partitions: number of ways to partition (i.e. how many devices to run
        across).

    Returns:
      output of shape `[bs, dim1, dim2, ..., dimN, features]`.
    """

        assert causal_mask or not cache, (
            'Caching is only supported for causal attention.')

        if inputs_kv is None:
            inputs_kv = inputs_q

        if attention_axis is None:
            attention_axis = tuple(range(1, inputs_q.ndim - 1))

        features = out_features or inputs_q.shape[-1]
        qkv_features = qkv_features or inputs_q.shape[-1]

        assert qkv_features % num_heads == 0, (
            'Memory dimension must be divisible by number of heads.')
        head_dim = qkv_features // num_heads

        dense = nn.DenseGeneral.partial(axis=-1,
                                        features=(num_heads, head_dim),
                                        kernel_init=kernel_init,
                                        bias_init=bias_init,
                                        bias=bias,
                                        precision=precision)
        # project inputs_q to multi-headed q/k/v
        # dimensions are then [bs, dims..., n_heads, n_features_per_head]
        query, key, value = (dense(inputs_q, dtype=dtype, name='query'),
                             dense(inputs_kv, dtype=dtype, name='key'),
                             dense(inputs_kv, dtype=dtype, name='value'))
        if num_partitions > 1:
            partitions = P(1, 1, num_partitions, 1)
            query = with_sharding_constraint(query, partitions)
            key = with_sharding_constraint(key, partitions)
            value = with_sharding_constraint(value, partitions)

        if cache:
            assert isinstance(cache,
                              Cache), 'cache must be an instance of Cache'
            if self.is_initializing():
                cache.store(
                    np.array((key.ndim, ) + key.shape[-2:], dtype=np.int32))
            else:
                cache_entry = cache.retrieve(None)
                expected_shape = list(cache_entry.key.shape[:-2])
                for attn_dim in attention_axis:
                    expected_shape[attn_dim] = 1
                expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
                if expected_shape != inputs_q.shape:
                    raise ValueError('Invalid shape provided, '
                                     'expected shape %s instead got %s.' %
                                     (expected_shape, inputs_q.shape))

                if not isinstance(cache_entry, _CacheEntry):
                    raise ValueError('Cache is not initialized.')

                cshape = cache_entry.key.shape
                i = cache_entry.i
                one_hot_indices = jax.nn.one_hot(i, cshape[3],
                                                 dtype=key.dtype).reshape(
                                                     (1, 1, 1, cshape[3]))
                key = key.transpose((0, 2, 3, 1))
                key = cache_entry.key + key * one_hot_indices
                value = value.transpose((0, 2, 3, 1))
                value = cache_entry.value + value * one_hot_indices

                one = jnp.array(1, jnp.uint32)
                cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                                  key=key,
                                                  value=value)
                cache.store(cache_entry)

                key = key.transpose((0, 3, 1, 2))
                value = value.transpose((0, 3, 1, 2))
                cshape = (cshape[0], cshape[3], cshape[1], cshape[2])

                # TODO(levskaya): verify this is still needed in translation decoding.
                key_padding_mask = jnp.broadcast_to(
                    (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
                key_padding_mask = key_padding_mask.astype(
                    jnp.float32)[Ellipsis, None]

        # create attention masks
        mask_components = []

        if causal_mask:
            if cache and not self.is_initializing():
                bias_pre_shape = (1, ) * (key.ndim - 1)
                attn_shape = tuple(np.take(key.shape, attention_axis))
                attn_size = np.prod(attn_shape)
                ii = jnp.arange(attn_size, dtype=jnp.uint32)
                mask = ii < cache_entry.i
                mask_components.append(
                    mask.reshape(bias_pre_shape + attn_shape))
            else:
                mask_components.append(_make_causal_mask(key, attention_axis))

        if padding_mask is not None:
            if key_padding_mask is None:
                key_padding_mask = padding_mask
            padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                             padding_mask_key=key_padding_mask,
                                             query_shape=query.shape,
                                             key_shape=key.shape,
                                             attention_axis=attention_axis)
            mask_components.append(padding_mask)

        if segmentation is not None:
            if key_segmentation is None:
                key_segmentation = segmentation
            segmentation_mask = make_padding_mask(
                padding_mask_query=segmentation,
                padding_mask_key=key_segmentation,
                query_shape=query.shape,
                key_shape=key.shape,
                attention_axis=attention_axis,
                segmentation_mask=True)
            mask_components.append(segmentation_mask)

        if mask_components:
            attention_mask = mask_components[0]
            for component in mask_components[1:]:
                attention_mask = jnp.logical_and(attention_mask, component)

            # attention mask in the form of attention bias
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.).astype(dtype),
                jnp.full(attention_mask.shape, -1e10).astype(dtype))
        else:
            attention_bias = None

        # apply attention
        x = dot_product_attention(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

        # back to the original inputs dimensions
        out = nn.DenseGeneral(x,
                              features=features,
                              axis=(-2, -1),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              dtype=dtype,
                              precision=precision,
                              name='out')
        if num_partitions > 1:
            x = with_sharding_constraint(x, None)

        return out
Example #26
def vtrace(
    v_tm1: Array,
    v_t: Array,
    r_t: Array,
    discount_t: Array,
    rho_t: Array,
    lambda_: float = 1.0,
    clip_rho_threshold: float = 1.0,
    stop_target_gradients: bool = True,
) -> Array:
    """Calculates V-Trace errors from importance weights.

  V-trace computes TD-errors from multistep trajectories by applying
  off-policy corrections based on clipped importance sampling ratios.

  See "IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor
  Learner Architectures" by Espeholt et al. (https://arxiv.org/abs/1802.01561).

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance sampling ratios.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_t],
                     [float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_t])

    # Clip importance sampling ratios.
    c_t = jnp.minimum(1.0, rho_t) * lambda_
    clipped_rhos = jnp.minimum(clip_rho_threshold, rho_t)

    # Compute the temporal difference errors.
    td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in jnp.arange(v_t.shape[0] - 1, -1, -1):
        err = td_errors[i] + discount_t[i] * c_t[i] * err
        errors.insert(0, err)

    # Return errors.
    if not stop_target_gradients:
        return jnp.array(errors)
    # In TD-like algorithms, we want gradients to only flow in the predictions,
    # and not in the values used to bootstrap. In this case, add the value of the
    # initial state value to get the implied estimates of the returns, stop
    # gradient around such target and then subtract again the initial state value.
    else:
        target_tm1 = jnp.array(errors) + v_tm1
        target_tm1 = jax.lax.stop_gradient(target_tm1)
    return target_tm1 - v_tm1
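
# A hedged usage sketch with toy rank-1 float arrays (values assumed, not from
# the original source):
import jax.numpy as jnp

v_tm1 = jnp.array([1.0, 2.0, 3.0])
v_t = jnp.array([1.5, 2.5, 3.5])
r_t = jnp.array([0.0, 1.0, 0.0])
discount_t = jnp.array([0.9, 0.9, 0.0])
rho_t = jnp.array([1.0, 0.5, 2.0])
vtrace_errors = vtrace(v_tm1, v_t, r_t, discount_t, rho_t)  # shape (3,)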
Example #27
def multi_head_dot_product_attention(scope: Scope,
                                     inputs_q,
                                     inputs_kv,
                                     num_heads,
                                     dtype=jnp.float32,
                                     qkv_features=None,
                                     out_features=None,
                                     attention_axis=None,
                                     causal_mask=False,
                                     padding_mask=None,
                                     key_padding_mask=None,
                                     segmentation=None,
                                     key_segmentation=None,
                                     cache=False,
                                     broadcast_dropout=True,
                                     dropout_rng=None,
                                     dropout_rate=0.,
                                     deterministic=False,
                                     precision=None,
                                     kernel_init=default_kernel_init,
                                     bias_init=initializers.zeros,
                                     bias=True,
                                     attention_fn=dot_product_attention):
    """Applies multi-head dot product attention on the input data.

  Projects the inputs into multi-headed query, key, and value vectors,
  applies dot-product attention, and projects the results to an output vector.

  This can be used for encoder-decoder attention by specifying both `inputs_q`
  and `inputs_kv`, or for self-attention by only specifying `inputs_q` and
  setting `inputs_kv` to None.

  Args:
    inputs_q: input queries of shape `[bs, dim1, dim2, ..., dimN, features]`.
    inputs_kv: key/values of shape `[bs, dim1, dim2, ..., dimN, features]`
      or None for self-attention, in which case key/values will be derived
      from inputs_q.
    num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1])
      should be divisible by the number of heads.
    dtype: the dtype of the computation (default: float32)
    qkv_features: dimension of the key, query, and value.
    out_features: dimension of the last projection
    attention_axis: axes over which the attention is applied ('None' means
      attention over all axes except batch, heads, and features).
    causal_mask: boolean specifying whether to apply a causal mask on the
      attention weights. If True, the output at timestep `t` will not depend
      on inputs at timesteps strictly greater than `t`.
    padding_mask: boolean specifying query tokens that are pad token.
    key_padding_mask: boolean specifying key-value tokens that are pad token.
    segmentation: segment indices for packed inputs_q data.
    key_segmentation: segment indices for packed inputs_kv data.
    cache: an instance of `flax.nn.attention.Cache` used for efficient
      autoregressive decoding.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey: to be used for dropout
    dropout_rate: dropout rate
    deterministic: bool, deterministic or not (to apply dropout)
    precision: numerical precision of the computation see `jax.lax.Precision`
      for details.
    kernel_init: initializer for the kernel of the Dense layers.
    bias_init: initializer for the bias of the Dense layers.
    bias: bool: whether pointwise QKVO dense transforms use bias.
    attention_fn: dot_product_attention or compatible function. Accepts
      query, key, value, and returns output of shape
      `[bs, dim1, dim2, ..., dimN, num_heads, value_channels]`.

  Returns:
    output of shape `[bs, dim1, dim2, ..., dimN, features]`.
  """

    assert causal_mask or not cache, (
        'Caching is only supported for causal attention.')

    if inputs_kv is None:
        inputs_kv = inputs_q

    if attention_axis is None:
        attention_axis = tuple(range(1, inputs_q.ndim - 1))

    features = out_features or inputs_q.shape[-1]
    qkv_features = qkv_features or inputs_q.shape[-1]

    assert qkv_features % num_heads == 0, (
        'Memory dimension must be divisible by number of heads.')
    head_dim = qkv_features // num_heads

    dense = functools.partial(dense_general,
                              axis=-1,
                              dtype=dtype,
                              features=(num_heads, head_dim),
                              kernel_init=kernel_init,
                              bias_init=bias_init,
                              bias=bias,
                              precision=precision)
    # project inputs_q to multi-headed q/k/v
    # dimensions are then [bs, dims..., n_heads, n_features_per_head]
    query = scope.child(dense, 'query')(inputs_q)
    key = scope.child(dense, 'key')(inputs_kv)
    value = scope.child(dense, 'value')(inputs_kv)

    if cache:
        if not scope.has_variable('cache', 'entry'):
            ndim, tail_shape = (key.ndim, key.shape[-2:])

            def init_fn(shape, dtype=jnp.float32):
                full_shape = shape + tail_shape
                if len(full_shape) != ndim:
                    raise ValueError(
                        'Shape should be a tuple with the shape of the batch'
                        'and attention dims.')
                return CacheEntry(key=jnp.zeros(full_shape, dtype),
                                  value=jnp.zeros(full_shape, dtype),
                                  i=jnp.zeros((), jnp.uint32))

            cache_entry = init_fn
        else:
            cache_entry = scope.get_variable('cache', 'entry')
            if not isinstance(cache_entry, CacheEntry):
                raise ValueError('Cache is not initialized.')

            expected_shape = list(cache_entry.key.shape[:-2])
            for attn_dim in attention_axis:
                expected_shape[attn_dim] = 1
            expected_shape = tuple(expected_shape) + inputs_q.shape[-1:]
            if expected_shape != inputs_q.shape:
                raise ValueError('Invalid shape provided, '
                                 'expected shape %s instead got %s.' %
                                 (expected_shape, inputs_q.shape))

            cshape = cache_entry.key.shape
            indices = [0] * len(cshape)
            i = cache_entry.i
            attn_size = onp.prod(onp.take(cshape, attention_axis))
            for attn_dim in attention_axis:
                attn_size //= cshape[attn_dim]
                indices[attn_dim] = i // attn_size
                i = i % attn_size

            key = lax.dynamic_update_slice(cache_entry.key, key, indices)
            value = lax.dynamic_update_slice(cache_entry.value, value, indices)
            one = jnp.array(1, jnp.uint32)
            cache_entry = cache_entry.replace(i=cache_entry.i + one,
                                              key=key,
                                              value=value)

            # TODO(levskaya): verify this is still needed in translation decoding.
            key_padding_mask = jnp.broadcast_to(
                (jnp.arange(cshape[1]) < cache_entry.i), cshape[:2])
            key_padding_mask = key_padding_mask.astype(jnp.float32)[..., None]
        scope.put_variable('cache', 'entry', cache_entry)

    # create attention masks
    mask_components = []

    if causal_mask:
        if cache and isinstance(cache_entry, CacheEntry):
            bias_pre_shape = (1, ) * (key.ndim - 1)
            attn_shape = tuple(onp.take(key.shape, attention_axis))
            attn_size = onp.prod(attn_shape)
            ii = jnp.arange(attn_size, dtype=jnp.uint32)
            mask = ii < cache_entry.i
            mask_components.append(mask.reshape(bias_pre_shape + attn_shape))
        else:
            mask_components.append(_make_causal_mask(key, attention_axis))

    if padding_mask is not None:
        if key_padding_mask is None:
            key_padding_mask = padding_mask
        padding_mask = make_padding_mask(padding_mask_query=padding_mask,
                                         padding_mask_key=key_padding_mask,
                                         query_shape=query.shape,
                                         key_shape=key.shape,
                                         attention_axis=attention_axis)
        mask_components.append(padding_mask)

    if segmentation is not None:
        if key_segmentation is None:
            key_segmentation = segmentation
        segmentation_mask = make_padding_mask(
            padding_mask_query=segmentation,
            padding_mask_key=key_segmentation,
            query_shape=query.shape,
            key_shape=key.shape,
            attention_axis=attention_axis,
            segmentation_mask=True)
        mask_components.append(segmentation_mask)

    if mask_components:
        attention_mask = mask_components[0]
        for component in mask_components[1:]:
            attention_mask = jnp.logical_and(attention_mask, component)

        # attention mask in the form of attention bias
        attention_bias = lax.select(
            attention_mask > 0,
            jnp.full(attention_mask.shape, 0.).astype(dtype),
            jnp.full(attention_mask.shape, -1e10).astype(dtype))
    else:
        attention_bias = None

    # apply attention
    x = scope.child(attention_fn)(query,
                                  key,
                                  value,
                                  dtype=dtype,
                                  axis=attention_axis,
                                  bias=attention_bias,
                                  precision=precision,
                                  dropout_rng=dropout_rng,
                                  dropout_rate=dropout_rate,
                                  broadcast_dropout=broadcast_dropout,
                                  deterministic=deterministic)

    # back to the original inputs dimensions
    out = scope.child(dense_general, name='out')(x,
                                                 features=features,
                                                 axis=(-2, -1),
                                                 kernel_init=kernel_init,
                                                 bias_init=bias_init,
                                                 bias=bias,
                                                 dtype=dtype,
                                                 precision=precision)

    return out
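
A minimal, self-contained sketch of the mask-to-bias conversion used above: the boolean attention mask is turned into an additive bias that is 0. where attention is allowed and -1e10 where it is masked, so masked logits are effectively removed by the softmax. The causal 4x4 mask and the leading singleton batch/head dimensions are illustrative, not taken from the code above.

import jax.numpy as jnp
from jax import lax

# Illustrative causal mask over 4 positions, broadcast to (batch, heads, q, k).
mask2d = jnp.tril(jnp.ones((4, 4)))
attention_mask = mask2d[None, None, :, :]

# Same conversion as above: 0. where attention is allowed, -1e10 where masked.
attention_bias = lax.select(
    attention_mask > 0,
    jnp.full(attention_mask.shape, 0.).astype(jnp.float32),
    jnp.full(attention_mask.shape, -1e10).astype(jnp.float32))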
Example #28
def leaky_vtrace(v_tm1: Array,
                 v_t: Array,
                 r_t: Array,
                 discount_t: Array,
                 rho_t: Array,
                 alpha_: float = 1.0,
                 lambda_: float = 1.0,
                 clip_rho_threshold: float = 1.0,
                 stop_target_gradients: bool = True):
    """Calculates Leaky V-Trace errors from importance weights.

  Leaky-Vtrace is a combination of Importance sampling and V-trace, where the
  degree of mixing is controlled by a scalar `alpha` (that may be meta-learnt).

  See "Self-Tuning Deep Reinforcement Learning"
  by Zahavy et al. (https://arxiv.org/abs/2002.12928)

  Args:
    v_tm1: values at time t-1.
    v_t: values at time t.
    r_t: reward at time t.
    discount_t: discount at time t.
    rho_t: importance weights at time t.
    alpha_: mixing parameter for Importance Sampling and V-trace.
    lambda_: scalar mixing parameter lambda.
    clip_rho_threshold: clip threshold for importance weights.
    stop_target_gradients: whether or not to apply stop gradient to targets.

  Returns:
    Leaky V-Trace error.
  """
    chex.assert_rank([v_tm1, v_t, r_t, discount_t, rho_t], [1, 1, 1, 1, 1])
    chex.assert_type([v_tm1, v_t, r_t, discount_t, rho_t],
                     [float, float, float, float, float])
    chex.assert_equal_shape([v_tm1, v_t, r_t, discount_t, rho_t])

    # Mix clipped and unclipped importance sampling ratios.
    c_t = ((1 - alpha_) * rho_t + alpha_ * jnp.minimum(1.0, rho_t)) * lambda_
    clipped_rhos = ((1 - alpha_) * rho_t +
                    alpha_ * jnp.minimum(clip_rho_threshold, rho_t))

    # Compute the temporal difference errors.
    td_errors = clipped_rhos * (r_t + discount_t * v_t - v_tm1)

    # Work backwards computing the td-errors.
    err = 0.0
    errors = []
    for i in range(v_t.shape[0] - 1, -1, -1):
        err = td_errors[i] + discount_t[i] * c_t[i] * err
        errors.insert(0, err)

    # Return errors.
    if not stop_target_gradients:
        return jnp.array(errors)
    # In TD-like algorithms, we want gradients to only flow in the predictions,
    # and not in the values used to bootstrap. In this case, add the value of the
    # initial state value to get the implied estimates of the returns, stop
    # gradient around such target and then subtract again the initial state value.
    else:
        target_tm1 = jnp.array(errors) + v_tm1
        return jax.lax.stop_gradient(target_tm1) - v_tm1
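
A small usage sketch for leaky_vtrace above, with illustrative array values: setting alpha_=1.0 recovers the usual clipped V-trace ratios, while alpha_=0.0 leaves the importance weights unclipped.

import jax.numpy as jnp

# Illustrative trajectory of length 4.
v_tm1 = jnp.array([1.0, 0.5, 0.2, 0.1])
v_t = jnp.array([0.5, 0.2, 0.1, 0.0])
r_t = jnp.array([0.0, 1.0, 0.0, 1.0])
discount_t = jnp.array([0.99, 0.99, 0.99, 0.0])
rho_t = jnp.array([1.2, 0.8, 1.5, 1.0])  # importance weights

vtrace_errors = leaky_vtrace(v_tm1, v_t, r_t, discount_t, rho_t, alpha_=1.0)
is_errors = leaky_vtrace(v_tm1, v_t, r_t, discount_t, rho_t, alpha_=0.0)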
Example #29
    def log_abs_det_jacobian(self, x, y, intermediates=None):
        # NB: see derivation in LKJCholesky implementation
        n = jnp.shape(x)[-1]
        order = -jnp.arange(n - 1, -1, -1)
        return jnp.sum(order * jnp.log(jnp.diagonal(y, axis1=-2, axis2=-1)), axis=-1)
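
Written out, the returned quantity restates the code above: with n the trailing dimension of x and y_{ii} the diagonal entries of y,

\[
\log\left|\det \frac{\partial y}{\partial x}\right| \;=\; \sum_{i=1}^{n} -(n - i)\,\log y_{ii},
\]

i.e. the i-th diagonal entry is weighted by -(n - i), matching order = [-(n-1), ..., -1, 0]. The derivation of why this equals the log-abs-det-Jacobian is the one referenced in the LKJCholesky comment and is not reproduced here.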
Example #30
    def _finalise_results(self, state: NestedSamplerState, collect_samples: bool, stochastic_uncertainty: bool,
                          max_samples: int):
        collect = ['logZ',
                   'logZerr',
                   'ESS',
                   'ESS_err',
                   'H',
                   'H_err',
                   'num_likelihood_evaluations',
                   'efficiency',
                   'marginalised',
                   'marginalised_uncert',
                   'log_L_samples',
                   'n_per_sample',
                   'log_p',
                   'log_X',
                   'sampler_efficiency',
                   'num_samples'
                   ]

        if collect_samples:
            collect.append('samples')
        NestedSamplerResults = namedtuple('NestedSamplerResults', collect)
        tracked_expectations = TrackedExpectation(self.marginalised, self.marginalised_shapes,
                                                  state=state.tracked_expectations_state)
        live_update_results = tracked_expectations.update_from_live_points(state.live_points, state.log_L_live)
        if self.marginalised is not None:
            marginalised = tracked_expectations.marg_mean()
            marginalised_uncert = None  # tracked_expectations.marg_variance()
        else:
            marginalised = None
            marginalised_uncert = None

        num_live_points = state.log_L_live.shape[0]
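        # n_per_sample: the number of live points active when each sample was drawn.
        # The first `num_dead` entries use the full live set; the block written at
        # index `num_dead` below counts the sorted live points with a shrinking set
        # (num_live_points, num_live_points - 1, ..., 1); the unused tail stays inf.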
        n_per_sample = jnp.where(jnp.arange(max_samples) < state.num_dead, num_live_points, jnp.inf)
        n_per_sample = dynamic_update_slice(n_per_sample,
                                            num_live_points - jnp.arange(num_live_points, dtype=n_per_sample.dtype),
                                            [state.num_dead])
        sampler_efficiency = dynamic_update_slice(state.sampler_efficiency,
                                                  jnp.ones(num_live_points),
                                                  [state.num_dead])
        log_w = dynamic_update_slice(state.log_w,
                                     live_update_results[3],
                                     [state.num_dead])
        log_p = log_w - logsumexp(log_w)
        log_X = dynamic_update_slice(state.log_X,
                                     live_update_results[2],
                                     [state.num_dead])
        log_L_samples = dynamic_update_slice(state.log_L_dead,
                                             live_update_results[1],
                                             [state.num_dead])
        num_samples = state.num_dead + num_live_points
        data = dict(
            logZ=tracked_expectations.evidence_mean(),
            logZerr=jnp.sqrt(tracked_expectations.evidence_variance()),
            ESS=tracked_expectations.effective_sample_size(),
            ESS_err=None,
            H=tracked_expectations.information_gain_mean(),
            H_err=jnp.sqrt(tracked_expectations.information_gain_variance()),
            num_likelihood_evaluations=state.num_likelihood_evaluations,
            efficiency=(state.num_dead + state.log_L_live.shape[0]) / state.num_likelihood_evaluations,
            marginalised=marginalised,
            marginalised_uncert=marginalised_uncert,
            n_per_sample=n_per_sample,
            log_p=log_p,
            log_X=log_X,
            log_L_samples=log_L_samples,
            num_samples=num_samples,
            sampler_efficiency=sampler_efficiency
        )

        if collect_samples:

            # log_t = jnp.where(jnp.isinf(n_per_sample), 0., jnp.log(n_per_sample) - jnp.log(n_per_sample + 1.))
            # log_X = jnp.cumsum(log_t)
            ar = jnp.argsort(state.log_L_live)
            samples = dict_multimap(lambda dead_points, live_points:
                                    dynamic_update_slice(dead_points,
                                                         live_points.astype(dead_points.dtype)[ar, ...],
                                                         [state.num_dead] + [0] * (len(dead_points.shape) - 1)),
                                    state.dead_points, state.live_points)
            # log_L_samples = dynamic_update_slice(state.log_L_dead, state.log_L_live[ar], [state.num_dead])

            # sampler_efficiency = dynamic_update_slice(state.sampler_efficiency,
            #                                           jnp.ones(num_live_points),
            #                                           [state.num_dead])

            # num_samples = state.num_dead + num_live_points
            data['samples'] = samples
            # data['log_L_samples'] = log_L_samples
            # data['n_per_sample'] = n_per_sample
            # data['log_X'] = log_X
            #
            # data['sampler_efficiency'] = sampler_efficiency
            # data['num_samples'] = num_samples

            if stochastic_uncertainty:
                S = 200
                logZ, m, cov, ESS, H = vmap(lambda key: stochastic_result_computation(n_per_sample,
                                                                                      key, samples, log_L_samples))(
                    random.split(state.key, S))
                data['logZ'] = jnp.mean(logZ, axis=0)
                data['logZerr'] = jnp.std(logZ, axis=0)
                data['H'] = jnp.mean(H, axis=0)
                data['H_err'] = jnp.std(H, axis=0)
                data['ESS'] = jnp.mean(ESS, axis=0)
                data['ESS_err'] = jnp.std(ESS, axis=0)

            # build mean weights
            # log_L_samples = jnp.concatenate([jnp.array([-jnp.inf]), log_L_samples])
            # log_X = jnp.concatenate([jnp.array([0.]), log_X])
            # log(dX_i) = log(X[i-1] - X[i]) = log((1-t_i)*X[i-1]) = log(1-t_i) + log(X[i-1])
            # log_dX = - jnp.log(n_per_sample + 1.) + log_X[:-1]
            # log_dX = jnp.log(-jnp.diff(jnp.exp(log_X)))
            # log_avg_L = jnp.logaddexp(log_L_samples[:-1], log_L_samples[1:]) - jnp.log(2.)
            # w_i = dX_i avg_L_i
            # log_w = log_dX + log_avg_L
            # log_p = log_w - logsumexp(log_w)
            # data['log_p'] = log_p

            # if self.marginalise is not None:
            #     def single_marginalise(marginalise):
            #         return jnp.sum(vmap(lambda p, sample: p * marginalise(**sample))(jnp.exp(log_p), samples), axis=0)
            #
            #     data['marginalised'] = dict_multimap(single_marginalise, self.marginalise)

        return NestedSamplerResults(**data)
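
The dynamic_update_slice calls in _finalise_results all follow the same pattern: write the live-point contribution into a preallocated buffer starting at index state.num_dead. A minimal standalone sketch of that pattern with illustrative sizes (not the jaxns API):

import jax.numpy as jnp
from jax.lax import dynamic_update_slice

max_samples, num_dead = 10, 6
live = jnp.array([0.1, 0.2, 0.3])

# Preallocated buffer: the first `num_dead` entries are meaningful,
# the rest are placeholders.
buffer = jnp.full((max_samples,), jnp.inf)

# Overwrite entries [num_dead, num_dead + live.size) with the live values,
# as the log_X / log_w / log_L_samples updates above do with [state.num_dead].
combined = dynamic_update_slice(buffer, live, (num_dead,))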