Example #1
  def compute_weighted_accuracy(self, logits, targets, weights=None):
    """Compute weighted accuracy for log probs and targets.

    Args:
     logits: [batch, length, num_classes] float array.
     targets: categorical targets [batch, length] int array.
     weights: None or array of shape [batch, length]

    Returns:
      Tuple of scalar accuracy and batch normalizing factor.
    """
    if logits.ndim != targets.ndim + 1:
      raise ValueError(f"Incorrect shapes. Got shape {str(logits.shape)} logits"
                       f" and {str(targets.shape)} targets")
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    normalizing_factor = np.prod(logits.shape[:-1])
    if weights is not None:
      accuracy = accuracy * weights
      normalizing_factor = weights.sum()

    return accuracy.sum(), normalizing_factor
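A minimal usage sketch for the method above (hypothetical toy shapes; assumes `import jax.numpy as jnp` and that the method lives on some `metrics` object):

# 'metrics' stands for whatever object defines the method above (hypothetical).
# batch=2, length=3, num_classes=4; argmax of these logits is class 1 everywhere.
logits = jnp.zeros((2, 3, 4)).at[:, :, 1].set(1.0)
targets = jnp.array([[1, 1, 0], [1, 2, 1]])
weights = jnp.array([[1.0, 1.0, 1.0], [1.0, 1.0, 0.0]])  # last position of row 1 is padding
correct_sum, denom = metrics.compute_weighted_accuracy(logits, targets, weights)
# correct_sum / denom is the weighted accuracy, here 3 / 5.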
Example #2
def compute_weighted_accuracy(
        logits,  # 3D ndarray of floats
        targets,  # 2D ndarray of ints
        weights=None,  # 2D ndarray of floats
):
    """Compute weighted accuracy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: Categorical targets [batch, length] int array.
   weights: None or array of shape [batch, length]

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    loss = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    return weight_loss(loss, targets, weights)
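The snippet above delegates to a weight_loss helper that is not shown; a plausible sketch of such a helper, following the weighting pattern in Examples #1 and #3 (the name, signature, and use of numpy as np are assumptions):

def weight_loss(loss, targets, weights=None):
    # Hypothetical helper (not shown in the original snippet).
    # Sum per-position values; optionally weight each position first and
    # normalize by the total weight instead of the number of positions.
    normalizing_factor = np.prod(targets.shape)
    if weights is not None:
        loss = loss * weights
        normalizing_factor = weights.sum()
    return loss.sum(), normalizing_factor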
Example #3
def compute_weighted_accuracy(logits, targets, weights=None):
  """Compute weighted accuracy for log probs and targets.

  Args:
   logits: `[batch, length, num_classes]` float array.
   targets: categorical targets `[batch, length]` int array.
   weights: None or array of shape [batch, length, 1]

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
  if logits.ndim != targets.ndim + 1:
    raise ValueError('Incorrect shapes. Got shape %s logits and %s targets' %
                     (str(logits.shape), str(targets.shape)))
  acc = jnp.equal(jnp.argmax(logits, axis=-1), targets)
  normalizing_factor = jnp.prod(jnp.asarray(targets.shape))
  if weights is not None:
    acc = acc * weights
    normalizing_factor = weights.sum()

  return acc.sum(), normalizing_factor
Example #4
    def eval_step(params, batch):
        labels = batch.pop("labels")

        logits = model(**batch, params=params, train=False)[0]

        # compute loss, ignore padded input tokens
        label_mask = jnp.where(labels > 0, 1.0, 0.0)
        loss = optax.softmax_cross_entropy(
            logits, onehot(labels, logits.shape[-1])) * label_mask

        # compute accuracy
        accuracy = jnp.equal(jnp.argmax(logits, axis=-1), labels) * label_mask

        # summarize metrics
        metrics = {
            "loss": loss.sum(),
            "accuracy": accuracy.sum(),
            "normalizer": label_mask.sum()
        }
        metrics = jax.lax.psum(metrics, axis_name="batch")

        return metrics
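A toy illustration of the label-mask pattern used in eval_step above (hypothetical values; assumes `import jax.numpy as jnp`):

labels = jnp.array([[5, 7, 0, 0]])                       # 0 marks padded tokens
label_mask = jnp.where(labels > 0, 1.0, 0.0)             # [[1., 1., 0., 0.]]
predictions = jnp.array([[5, 2, 0, 0]])
accuracy = jnp.equal(predictions, labels) * label_mask   # [[1., 0., 0., 0.]]
# accuracy.sum() / label_mask.sum() == 0.5: padded positions do not count.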
Example #5
def get_bigram_ids(ids: JTensor,
                   vocab_size: int,
                   segment_pos: Optional[JTensor] = None) -> JTensor:
    """Generate bi-gram ids from uni-gram ids.

  Args:
    ids: An int32 JTensor of shape [B, L].
    vocab_size: Vocabulary size of `ids`, must be > 0.
    segment_pos: If not None (meaning `ids` is packed, i.e. each example may
      contain multiple segments), an int32 tensor of shape [B, L] containing
      the position of each id in `ids` within its segment.

  Returns:
    ngram_ids: An int64 JTensor of shape [B, L].
  """
    assert vocab_size > 0
    batch_size = ids.shape[0]
    # Cast to int64 to avoid overflow, which would affect bucket collision
    # rate and model quality.
    ids = jnp.array(ids, dtype=jnp.int64)  # [batch, time]
    pad = jnp.zeros([batch_size, 1], dtype=ids.dtype)  # [batch, 1]

    # Mechanism: for bigrams, we shift ids by one position along the time
    # dimension, and compute:
    #   bigram_id = original_id + shifted_id * vocab_size.
    ids_0 = jnp.concatenate([ids, pad], 1)  # [batch, time+1]
    ids_1 = jnp.concatenate([pad, ids], 1)  # [batch, 1+time]

    if segment_pos is not None:
        # If input is packed, mask out the parts that cross the segment
        # boundaries.
        mask = jnp.array(jnp.equal(segment_pos, 0), dtype=ids_0.dtype)
        mask = 1 - mask
        mask = jnp.concatenate([mask, pad], 1)
        ids_1 *= mask

    ngram_ids = ids_0 + ids_1 * vocab_size  # Bigram ids.
    ngram_ids = ngram_ids[:, 0:-1]
    return ngram_ids
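A quick worked check of the bigram mechanism described in the comments above (toy values; assumes the function above and `import jax.numpy as jnp` are available):

ids = jnp.array([[3, 1, 2]], dtype=jnp.int32)
# ids_0 = [[3, 1, 2, 0]] (pad appended), ids_1 = [[0, 3, 1, 2]] (pad prepended),
# so ids_0 + ids_1 * vocab_size = [[3, 31, 12, 20]] and the last column is dropped:
print(get_bigram_ids(ids, vocab_size=10))  # -> [[ 3 31 12]]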
Example #6
def eval_step(model, inputs, prev_metrics):
  """A single eval step."""
  input_ids = inputs['input_ids']
  input_mask = inputs['input_mask']
  segment_ids = inputs['segment_ids']
  mask_lm_positions = inputs['masked_lm_positions']

  masked_lm_ids = inputs['masked_lm_ids']
  masked_lm_weights = inputs['masked_lm_weights']
  use_bf16 = FLAGS.use_bfloat16_activation
  dtype = jnp.bfloat16 if use_bf16 else jnp.float32
  lm_outputs, _ = model([input_ids, input_mask, segment_ids, mask_lm_positions],
                        train=False, dtype=dtype)
  assert lm_outputs.dtype == jnp.float32
  _, masked_lm_example_loss, masked_lm_log_probs = get_masked_lm_output(
      lm_outputs, masked_lm_ids, label_weights=masked_lm_weights)
  masked_lm_log_probs = jnp.reshape(masked_lm_log_probs,
                                    (-1, masked_lm_log_probs.shape[-1]))
  masked_lm_predictions = jnp.argmax(masked_lm_log_probs, axis=-1)
  masked_lm_example_loss = jnp.reshape(masked_lm_example_loss, (-1))
  masked_lm_ids = jnp.reshape(masked_lm_ids, (-1))
  masked_lm_weights = jnp.reshape(masked_lm_weights, (-1))

  masked_lm_weighted_correct = jnp.multiply(
      lax.convert_element_type(
          jnp.equal(masked_lm_ids, masked_lm_predictions), jnp.float32),
      masked_lm_weights)
  masked_lm_weighted_correct = jnp.sum(masked_lm_weighted_correct)
  masked_lm_weighted_count = jnp.sum(masked_lm_weights)

  metrics = {
      'masked_lm_weighted_correct':
          jnp.reshape(masked_lm_weighted_correct, (-1)),
      'masked_lm_weighted_count':
          jnp.reshape(masked_lm_weighted_count, (-1))
  }

  return jax.tree_multimap(jnp.add, prev_metrics, metrics)
Example #7
        def scan_body(state, i):
            U, K, G, Hs, ll = state
            U = U - us[i]

            def inner_map_body(k):
                H_k = Hs[k] + hs[i]
                G_k = G - jax_nn.nn_fwd(Hs[k], g_params) + jax_nn.nn_fwd(
                    H_k, g_params)
                return self.f_fn(G_k, U, f_params)

            def map_body(k):
                return jax.lax.cond(np.less(k, K + 1), inner_map_body,
                                    lambda x: -np.inf, k)

            log_potentials = jax.lax.map(map_body, np.arange(num_data_points))
            log_Z_hat = scipy.special.logsumexp(log_potentials, keepdims=True)
            log_q = log_potentials - log_Z_hat
            ll += log_q[cs[i]]
            K = jax.lax.cond(np.equal(cs[i], K), lambda x: x + 1, lambda x: x,
                             K)
            G = G - jax_nn.nn_fwd(Hs[cs[i]], g_params) + jax_nn.nn_fwd(
                Hs[cs[i]] + hs[i], g_params)
            Hs = jax.ops.index_update(Hs, cs[i], Hs[cs[i]] + hs[i])
            return (U, K, G, Hs, ll), None
Example #8
def test_basic():

    y_true = jnp.array([[0.0, 1.0], [1.0, 1.0]])
    y_pred = jnp.array([[1.0, 0.0], [1.0, 1.0]])

    # Using 'auto'/'sum_over_batch_size' reduction type.
    cosine_loss = elegy.losses.CosineSimilarity(axis=1)
    assert cosine_loss(y_true, y_pred) == -0.49999997

    # Calling with 'sample_weight'.
    assert (cosine_loss(y_true, y_pred,
                        sample_weight=jnp.array([0.8, 0.2])) == -0.099999994)

    # Using 'sum' reduction type.
    cosine_loss = elegy.losses.CosineSimilarity(
        axis=1, reduction=elegy.losses.Reduction.SUM)
    assert cosine_loss(y_true, y_pred) == -0.99999994

    # Using 'none' reduction type.
    cosine_loss = elegy.losses.CosineSimilarity(
        axis=1, reduction=elegy.losses.Reduction.NONE)

    assert jnp.equal(cosine_loss(y_true, y_pred),
                     jnp.array([-0.0, -0.99999994])).all()
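The expected values in the test above can be reproduced with plain jnp; a minimal sketch of the negative-cosine-similarity loss (not Elegy's implementation):

def cosine_loss_by_hand(y_true, y_pred, axis=1, eps=1e-12):
    # Normalize both inputs, take the per-sample dot product, negate, and average.
    y_true = y_true / jnp.maximum(jnp.linalg.norm(y_true, axis=axis, keepdims=True), eps)
    y_pred = y_pred / jnp.maximum(jnp.linalg.norm(y_pred, axis=axis, keepdims=True), eps)
    per_sample = -jnp.sum(y_true * y_pred, axis=axis)  # [-0., -1.] for the arrays above
    return jnp.mean(per_sample)                        # 'sum_over_batch_size' -> -0.5

assert jnp.allclose(
    cosine_loss_by_hand(jnp.array([[0.0, 1.0], [1.0, 1.0]]),
                        jnp.array([[1.0, 0.0], [1.0, 1.0]])), -0.5)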
Example #9
def compute_weighted_accuracy(logits, targets, weights=None):
    """Compute weighted accuracy for log probs and targets.

  Args:
   logits: [batch, length, num_classes] float array.
   targets: categorical one-hot targets [batch, length, category] int array.
   weights: None or array of shape [batch, length]

  Returns:
    Tuple of scalar accuracy and batch normalizing factor.
  """
    targets = targets.reshape((-1))
    if logits.ndim != targets.ndim + 1:
        raise ValueError(
            'Incorrect shapes. Got shape %s logits and %s targets' %
            (str(logits.shape), str(targets.shape)))
    accuracy = jnp.equal(jnp.argmax(logits, axis=-1), targets)
    normalizing_factor = np.prod(logits.shape[:-1])
    if weights is not None:
        weights = weights.reshape((-1))
        accuracy = accuracy * weights
        normalizing_factor = jnp.sum(weights)

    return jnp.sum(accuracy), normalizing_factor
Example #10
def top2_gating_on_logits(paddings,
                          logits,
                          experts_dim,
                          expert_capacity_dim,
                          fprop_dtype,
                          prng_key,
                          second_expert_policy='all',
                          second_expert_threshold=0.0,
                          legacy_mtf_behavior=True,
                          capacity_factor=None,
                          importance=None,
                          mask_dtype=jnp.int32):
    """Computes Top-2 gating for Mixture-of-Experts.

  This function takes gating logits, potentially sharded across TPU cores, as
  input. We rely on sharding propagation to work universally with 1D and 2D
  sharding cases. Dispatch and combine tensors should be explicitly annotated
  with jax.with_sharding_constraint by the caller.

  We perform dispatch/combine via einsum.

  Dimensions:

    G: group dim
    S: group size dim
    E: number of experts
    C: capacity per expert
    M: model_dim (same as input_dim and output_dim as in FF layer)
    B: original batch dim
    L: original seq len dim

  Note that for local_dispatch the original batch BLM is reshaped into GSM; each
  group `g = 0..G-1` is dispatched independently.

  Args:
    paddings: G`S tensor.
    logits: G`SE tensor.
    experts_dim: number of experts
    expert_capacity_dim: number of examples per minibatch/group per expert. Each
      example is typically a vector of size input_dim, representing an embedded
      token or an element of a Transformer layer's output.
    fprop_dtype: activation dtype
    prng_key: jax.random.PRNGKey used for randomness.
    second_expert_policy: 'all', 'sampling' or 'random'
      - 'all': we greedily pick the 2nd expert
      - 'sampling': we sample the 2nd expert from the softmax
      - 'random': we optionally randomize dispatch to the second-best expert,
        in proportion to (weight / second_expert_threshold).
    second_expert_threshold: threshold for probability normalization when
      second_expert_policy == 'random'
    legacy_mtf_behavior: bool, True if to match legacy mtf behavior exactly.
    capacity_factor: if set, increases expert_capacity_dim to at least
      (group_size * capacity_factor) / experts_dim
    importance: input importance weights for routing (G`S tensor or None)
    mask_dtype: using bfloat16 for fprop_dtype can be problematic for mask
      tensors; mask_dtype overrides the dtype for such tensors.

  Returns:
    A tuple (aux_loss, combine_tensor, dispatch_tensor, over_capacity ratios).

    - aux_loss: auxiliary loss, for equalizing the expert assignment ratios.
    - combine_tensor: a G`SEC tensor for combining expert outputs.
    - dispatch_tensor: a G`SEC tensor, scattering/dispatching inputs to experts.
    - over_capacity ratios: tuple representing the ratio of tokens that were
      not dispatched due to lack of capacity for the top_1 and top_2 experts
      respectively, e.g. (over_capacity_1, over_capacity_2)
  """
    assert (capacity_factor or expert_capacity_dim)
    if mask_dtype is None:
        assert fprop_dtype != jnp.bfloat16, 'Using bfloat16 for mask is an error.'
        mask_dtype = fprop_dtype

    raw_gates = jax.nn.softmax(logits, axis=-1)  # along E dim
    if raw_gates.dtype != fprop_dtype:
        raw_gates = raw_gates.astype(fprop_dtype)

    if capacity_factor is not None:
        # Determine expert capacity automatically depending on the input size
        group_size_dim = logits.shape[1]
        auto_expert_capacity = int(group_size_dim * capacity_factor /
                                   experts_dim)
        if expert_capacity_dim < auto_expert_capacity:
            expert_capacity_dim = auto_expert_capacity
            # Round up to a multiple of 4 to avoid possible padding.
            while expert_capacity_dim % 4:
                expert_capacity_dim += 1
            logging.info(
                'Setting expert_capacity_dim=%r (capacity_factor=%r '
                'group_size_dim=%r experts_dim=%r)', expert_capacity_dim,
                capacity_factor, group_size_dim, experts_dim)

    capacity = jnp.array(expert_capacity_dim, dtype=jnp.int32)

    # top-1 index: GS tensor
    index_1 = jnp.argmax(raw_gates, axis=-1)

    # GSE
    mask_1 = jax.nn.one_hot(index_1, experts_dim, dtype=mask_dtype)
    density_1_proxy = raw_gates

    if importance is not None:
        importance_is_one = jnp.equal(importance, 1.0)
        mask_1 *= jnp.expand_dims(importance_is_one.astype(mask_1.dtype), -1)
        density_1_proxy *= jnp.expand_dims(
            importance_is_one.astype(density_1_proxy.dtype), -1)
    else:
        assert len(mask_1.shape) == 3
        importance = jnp.ones_like(mask_1[:, :, 0]).astype(fprop_dtype)
        if paddings is not None:
            nonpaddings = 1.0 - paddings
            mask_1 *= jnp.expand_dims(nonpaddings.astype(mask_1.dtype), -1)
            density_1_proxy *= jnp.expand_dims(
                nonpaddings.astype(density_1_proxy.dtype), -1)
            importance = nonpaddings

    gate_1 = jnp.einsum('GSE,GSE->GS', raw_gates,
                        mask_1.astype(raw_gates.dtype))
    gates_without_top_1 = raw_gates * (1.0 - mask_1.astype(raw_gates.dtype))

    if second_expert_policy == 'sampling':
        # We directly sample the 2nd expert index from the softmax over the
        # remaining experts, after removing the 1st expert already selected
        # above. To do so, we set a very negative value to the logit
        # corresponding to the 1st expert. Then we sample from the softmax
        # distribution using the Gumbel max trick.
        prng_key, subkey = jax.random.split(prng_key)
        noise = jax.random.uniform(subkey, logits.shape, dtype=logits.dtype)
        # Generates standard Gumbel(0, 1) noise, GSE tensor.
        noise = -jnp.log(-jnp.log(noise))
        very_negative_logits = jnp.ones_like(logits) * (-0.7) * np.finfo(
            logits.dtype).max
        # Get rid of the first expert by setting its logit to be very negative.
        updated_logits = jnp.where(mask_1 > 0.0, very_negative_logits, logits)
        # Add Gumbel noise to the updated logits.
        noised_logits = updated_logits + noise
        # Pick the index of the largest noised logits as the 2nd expert. This is
        # equivalent to sampling from the softmax over the 2nd expert.
        index_2 = jnp.argmax(noised_logits, axis=-1)
    else:
        # Greedily pick the 2nd expert.
        index_2 = jnp.argmax(gates_without_top_1, axis=-1)

    mask_2 = jax.nn.one_hot(index_2, experts_dim, dtype=mask_dtype)
    if paddings is not None:
        importance_is_nonzero = importance > 0.0
        mask_2 *= jnp.expand_dims(importance_is_nonzero.astype(mask_2.dtype),
                                  -1)
    gate_2 = jnp.einsum('GSE,GSE->GS', gates_without_top_1,
                        mask_2.astype(gates_without_top_1.dtype))

    # See notes in lingvo/core/gshard_layers.py.
    if legacy_mtf_behavior:
        # Renormalize.
        denom = gate_1 + gate_2 + 1e-9
        gate_1 /= denom
        gate_2 /= denom

    # We reshape the mask as [X*S, E], and compute cumulative sums of assignment
    # indicators for each expert index e \in 0..E-1 independently.
    # First occurrence of assignment indicator is excluded, see exclusive=True
    # flag below.
    # cumsum over S dim: mask_1 is GSE tensor.
    position_in_expert_1 = cum_sum(mask_1, exclusive=True, axis=-2)

    # GE tensor (reduce S out of GSE tensor mask_1).
    # density_1[:, e] represents the assignment ratio (num assigned / total) to
    # expert e as top_1 expert without taking capacity into account.
    assert importance.dtype == fprop_dtype
    if legacy_mtf_behavior:
        density_denom = 1.0
    else:
        density_denom = jnp.mean(importance, axis=1)[:, jnp.newaxis] + 1e-6
    density_1 = jnp.mean(mask_1.astype(fprop_dtype), axis=-2) / density_denom
    # density_1_proxy[:, e] represents mean of raw_gates for expert e, including
    # those of examples not assigned to e with top_k
    density_1_proxy = jnp.mean(density_1_proxy, axis=-2) / density_denom

    # Compute aux_loss
    aux_loss = jnp.mean(density_1_proxy * density_1)  # element-wise
    aux_loss *= (experts_dim * experts_dim)  # const coefficients

    # Add the over capacity ratio for expert 1
    over_capacity_1 = _create_over_capacity_ratio_summary(
        mask_1, position_in_expert_1, capacity, 'over_capacity_1')

    mask_1 *= jnp.less(position_in_expert_1,
                       expert_capacity_dim).astype(mask_1.dtype)
    position_in_expert_1 = jnp.einsum('GSE,GSE->GS', position_in_expert_1,
                                      mask_1)

    # How many examples in this sequence go to this expert?
    mask_1_count = jnp.einsum('GSE->GE', mask_1)
    # [batch, group] - mostly ones, but zeros where something didn't fit.
    mask_1_flat = jnp.sum(mask_1, axis=-1)
    assert mask_1_count.dtype == mask_dtype
    assert mask_1_flat.dtype == mask_dtype

    if second_expert_policy == 'all' or second_expert_policy == 'sampling':
        pass
    else:
        assert second_expert_policy == 'random'
        # gate_2 is between 0 and 1, reminder:
        #
        #   raw_gates = jax.nn.softmax(logits)
        #   index_1 = jnp.argmax(raw_gates, axis=-1)
        #   mask_1 = jax.nn.one_hot(index_1, experts_dim, dtype=fprop_dtype)
        #   gate_1 = jnp.einsum('GSE,GSE->GS', raw_gates, mask_1)
        #
        # e.g., if gate_2 exceeds second_expert_threshold, then we definitely
        # dispatch to second-best expert. Otherwise, we dispatch with probability
        # proportional to (gate_2 / threshold).
        #
        prng_key, subkey = jax.random.split(prng_key)
        sampled_2 = jnp.less(
            jax.random.uniform(subkey, gate_2.shape, dtype=gate_2.dtype),
            gate_2 / max(second_expert_threshold, 1e-9))
        gate_2 *= sampled_2.astype(gate_2.dtype)
        mask_2 *= jnp.expand_dims(sampled_2, -1).astype(mask_2.dtype)

    position_in_expert_2 = cum_sum(
        mask_2, exclusive=True, axis=-2) + jnp.expand_dims(mask_1_count, -2)
    over_capacity_2 = _create_over_capacity_ratio_summary(
        mask_2, position_in_expert_2, capacity, 'over_capacity_2')

    mask_2 *= jnp.less(position_in_expert_2,
                       expert_capacity_dim).astype(mask_2.dtype)
    position_in_expert_2 = jnp.einsum('GSE,GSE->GS', position_in_expert_2,
                                      mask_2)
    mask_2_flat = jnp.sum(mask_2, axis=-1)

    gate_1 *= mask_1_flat.astype(gate_1.dtype)
    gate_2 *= mask_2_flat.astype(gate_2.dtype)

    if not legacy_mtf_behavior:
        denom = gate_1 + gate_2
        # To avoid divide by 0.
        denom = jnp.where(denom > 0, denom, jnp.ones_like(denom))
        gate_1 /= denom
        gate_2 /= denom

    # GSC tensor
    b = jax.nn.one_hot(position_in_expert_1.astype(np.int32),
                       expert_capacity_dim,
                       dtype=fprop_dtype)
    # GSE tensor
    a = jnp.expand_dims(gate_1 * mask_1_flat.astype(fprop_dtype),
                        axis=-1) * jax.nn.one_hot(
                            index_1, experts_dim, dtype=fprop_dtype)
    # GSEC tensor
    first_part_of_combine_tensor = jnp.einsum('GSE,GSC->GSEC', a, b)

    # GSC tensor
    b = jax.nn.one_hot(position_in_expert_2.astype(np.int32),
                       expert_capacity_dim,
                       dtype=fprop_dtype)
    # GSE tensor
    a = jnp.expand_dims(gate_2 * mask_2_flat.astype(fprop_dtype),
                        axis=-1) * jax.nn.one_hot(
                            index_2, experts_dim, dtype=fprop_dtype)
    second_part_of_combine_tensor = jnp.einsum('GSE,GSC->GSEC', a, b)

    # GSEC tensor
    combine_tensor = first_part_of_combine_tensor + second_part_of_combine_tensor

    # GSEC tensor
    dispatch_tensor = combine_tensor.astype(bool).astype(fprop_dtype)

    return aux_loss, combine_tensor, dispatch_tensor, (over_capacity_1,
                                                       over_capacity_2)
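A stripped-down sketch of the top-1 routing step that top2_gating_on_logits builds on (toy shapes; assumes `import jax` and `import jax.numpy as jnp`; not the full gating logic):

# Router logits for G=1 group, S=4 tokens, E=3 experts (hypothetical values).
logits = jnp.array([[[2.0, 0.1, 0.1],
                     [0.1, 3.0, 0.1],
                     [0.1, 0.1, 1.5],
                     [2.5, 0.1, 0.1]]])
raw_gates = jax.nn.softmax(logits, axis=-1)            # GSE
index_1 = jnp.argmax(raw_gates, axis=-1)               # GS, top-1 expert per token
mask_1 = jax.nn.one_hot(index_1, 3, dtype=jnp.int32)   # GSE assignment indicator
gate_1 = jnp.einsum('GSE,GSE->GS', raw_gates, mask_1.astype(raw_gates.dtype))
# Exclusive cumulative sum over S gives each token its position within its expert,
# which the full function compares against expert_capacity_dim to drop overflow tokens.
position_in_expert_1 = jnp.cumsum(mask_1, axis=-2) - mask_1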
Example #11
def piter(f, x, trust_radius, args):

    fg = jax.value_and_grad(f)
    g = jax.grad(f)
    #h = jax.hessian(f)
    h = jax.jacfwd(g)
    #h = jax.jacrev(g)

    #compute function value and gradient
    val, grad = fg(x, *args)
    gradmag = np.linalg.norm(grad, axis=-1)
    gradcol = np.expand_dims(grad, axis=-1)

    #compute hessian and eigen-decomposition
    hess = h(x, *args)
    e, u = np.linalg.eigh(hess)

    ut = u.T

    #convert gradient to eigen-basis
    a = np.matmul(ut, gradcol)
    a = np.squeeze(a, axis=-1)

    lam = e
    e0 = lam[..., 0]

    #TODO deal with null gradient components and repeated eigenvectors
    lambar = lam
    abarsq = np.square(a)

    def phif(s):
        pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
        pmag = np.sqrt(pmagsq)
        phipartial = np.reciprocal(pmag)
        singular = np.any(np.equal(-s, lambar), axis=-1)
        phipartial = np.where(singular, 0., phipartial)
        phi = phipartial - np.reciprocal(trust_radius)
        return phi

    def phiphiprime(s):
        phi = phif(s)
        pmagsq = np.sum(abarsq / np.square(lambar + s), axis=-1)
        phiprime = np.power(pmagsq, -1.5) * np.sum(
            abarsq / np.power(lambar + s, 3), axis=-1)
        return (phi, phiprime)

    #check if unconstrained solution is valid
    sigma0 = np.maximum(-e0, 0.)
    phisigma0 = phif(sigma0)
    usesolu = np.logical_and(e0 > 0., phisigma0 >= 0.)

    sigma = np.max(np.abs(a) / trust_radius - lam, axis=-1)
    sigma = np.maximum(sigma, 0.)
    sigma = np.where(usesolu, 0., sigma)
    phi, phiprime = phiphiprime(sigma)

    #TODO, add handling of additional cases here (singular and "hard" cases)

    #iteratively solve for sigma, enforcing unconstrained solution sigma=0 where appropriate
    unconverged = np.ones(shape=sigma.shape, dtype=np.bool_)
    j = 0
    maxiter = 200

    #This can't work with vmap+jit because of the dynamic condition, so we use the jax while_loop below
    #j = 0
    #while np.logical_and(np.any(unconverged), j<maxiter):
    #sigma = sigma - phi/phiprime
    #sigma = np.where(usesolu, 0., sigma)
    #phiout, phiprimeout = phiphiprime(sigma)
    #unconverged = np.logical_and( (phiout > phi) , (phiout < 0.) )
    #phi,phiprime = (phiout, phiprimeout)
    #j = j +1

    def cond(vals):
        sigma, phi, phiprime, unconverged, j = vals
        return np.logical_and(np.any(unconverged), j < maxiter)

    def body(vals):
        sigma, phi, phiprime, unconverged, j = vals
        sigma = sigma - phi / phiprime
        sigma = np.where(usesolu, 0., sigma)
        phiout, phiprimeout = phiphiprime(sigma)
        unconverged = np.logical_and((phiout > phi), (phiout < 0.))
        phi, phiprime = (phiout, phiprimeout)
        j = j + 1
        return (sigma, phi, phiprime, unconverged, j)

    sigma = jax.lax.while_loop(cond, body,
                               (sigma, phi, phiprime, unconverged, j))[0]

    #compute solution from eigenvalues and eigenvectors
    coeffs = -a / (lam + sigma)
    coeffscol = np.expand_dims(coeffs, axis=-1)

    p = np.matmul(u, coeffscol)
    p = np.squeeze(p, axis=-1)

    #compute predicted reduction in loss function from eigenvalues and eigenvectors
    predicted_reduction = -np.sum(a * coeffs + 0.5 * lam * np.square(coeffs),
                                  axis=-1)

    #compute actual reduction in loss
    x_new = x + p
    val_new = f(x_new, *args)
    #actual_reduction = -(val_new - val)
    actual_reduction = val - val_new

    #update trust radius and output parameters, following Nocedal and Wright 2nd ed. Algorithm 4.1
    eta = 0.15
    trust_radius_max = 1e3
    rho = actual_reduction / np.where(np.equal(actual_reduction, 0.), 1.,
                                      predicted_reduction)
    rho = np.where(np.isnan(rho), 0., rho)
    at_boundary = np.logical_not(usesolu)
    trust_radius_out = np.where(
        rho < 0.25, 0.25 * trust_radius,
        np.where(np.logical_and(rho > 0.75, at_boundary),
                 np.minimum(2. * trust_radius, trust_radius_max),
                 trust_radius))

    x_out = np.where(rho > eta, x_new, x)

    #compute estimated distance to minimum for unconstrained solution (only valid if e0>0)
    coeffs0 = -a / lam
    edm = -np.sum(a * coeffs0 + 0.5 * lam * np.square(coeffs0), axis=-1)

    return x_out, trust_radius_out, val, gradmag, edm, e0
Example #12
 def schedule_with_first_step_zero(global_step: jnp.ndarray):
     value = momentum_schedule(global_step)
     check = jnp.equal(global_step, 0)
     return check * jnp.zeros_like(value) + (1 - check) * value
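A toy check of the zero-at-step-0 behaviour above, with a stand-in value in place of momentum_schedule (assumes `import jax.numpy as jnp`):

value = jnp.asarray(0.9)                 # stand-in for momentum_schedule(global_step)
for step in (0, 5):
    check = jnp.equal(jnp.asarray(step), 0)
    out = check * jnp.zeros_like(value) + (1 - check) * value
    print(step, out)                     # step 0 -> 0.0, step 5 -> 0.9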
Example #13
def equal(x1, x2):
  if isinstance(x1, JaxArray): x1 = x1.value
  if isinstance(x2, JaxArray): x2 = x2.value
  return JaxArray(jnp.equal(x1, x2))
Example #14
 def equal(self, x, y):
     return jnp.equal(x, y)
Example #15
def eq(a: Numeric, b: Numeric):
    return jnp.equal(a, b)
Example #16
def entity_linking_loss(mention_encodings: Array, entity_embeddings: Array,
                        mention_target_ids: Array,
                        mention_target_weights: Array, mode: str) -> Array:
    """Compute entity linking loss.

  Args:
    mention_encodings: [n_mentions, hidden_size] mention encodings to be used
      for computing the loss.
    entity_embeddings: [n_entities, hidden_size] entity embeddings table.
    mention_target_ids: [n_mentions] IDs of mentions.
    mention_target_weights: [n_mentions] per-mention weight for computing loss
      and metrics.
    mode: how to compute the scores -- using dot product ('dot'), dot product
      divided by the square root of the hidden dim ('dot_sqrt'), or cosine
      similarity ('cos').

  Returns:
    Loss, a dictionary with metric values, and per-sample information
    (a tuple of accuracy per mention and weight per mention).
  """
    scores = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)
    scores = scores.astype(jnp.float32)

    mention_encodings_norm = jnp.linalg.norm(mention_encodings, axis=-1)
    entity_embeddings_norm = jnp.linalg.norm(entity_embeddings, axis=-1)

    # The cosine similarity is computed as dot product divided by norms of
    # both vectors.
    cos_scores = scores
    cos_scores /= (_SMALL_NUMBER + jnp.expand_dims(mention_encodings_norm, 1))
    cos_scores /= (_SMALL_NUMBER + jnp.expand_dims(entity_embeddings_norm, 0))

    if mode == 'dot':
        pass
    elif mode == 'dot_sqrt':
        hidden_dim = mention_encodings.shape[1]
        scores /= jnp.sqrt(hidden_dim)
    elif mode == 'cos':
        scores = cos_scores
    else:
        raise ValueError('Unknown entity linking mode: ' + mode)

    mention_target_weights = mention_target_weights.astype(jnp.float32)

    loss, _ = metric_utils.compute_weighted_cross_entropy(
        scores,
        mention_target_ids,
        mention_target_weights,
        inputs_are_prob=False)

    acc_per_mention = jnp.equal(jnp.argmax(scores, axis=-1),
                                mention_target_ids)

    acc_per_mention = acc_per_mention * mention_target_weights

    n_mentions = mention_target_ids.shape[0]
    cos_per_mention = cos_scores[jnp.arange(n_mentions), mention_target_ids]
    cos_per_mention = cos_per_mention * mention_target_weights

    metrics = {
        'loss': loss,
        'acc': acc_per_mention.sum(),
        'cos_sim': cos_per_mention.sum(),
        'denominator': mention_target_weights.sum()
    }
    return loss, metrics, (acc_per_mention, mention_target_weights)
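A small sketch of the three scoring modes described in the docstring above (toy arrays; not the project's metric_utils; assumes `import jax.numpy as jnp`):

mention_encodings = jnp.array([[1.0, 0.0], [0.0, 2.0]])                # [n_mentions=2, hidden=2]
entity_embeddings = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])    # [n_entities=3, hidden=2]

dot = jnp.einsum('qd,ed->qe', mention_encodings, entity_embeddings)    # 'dot'
dot_sqrt = dot / jnp.sqrt(mention_encodings.shape[1])                  # 'dot_sqrt'
cos = dot / (jnp.linalg.norm(mention_encodings, axis=-1, keepdims=True)
             * jnp.linalg.norm(entity_embeddings, axis=-1))            # 'cos'

# Per-mention accuracy against gold entity ids, as in the function above.
mention_target_ids = jnp.array([0, 1])
acc_per_mention = jnp.equal(jnp.argmax(dot, axis=-1), mention_target_ids)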
Example #17
 def _equal(a, b):
     return jnp.equal(a, b)
Example #18
def accuracy(batch, model_predictions):
    """Calculate accuracy."""
    _, targets = batch
    predicted_class = np.argmax(model_predictions, axis=-1)
    correct = np.equal(predicted_class, targets)
    return masked_mean(correct, targets)
Example #19
def accuracy(logits: chex.Array, labels: chex.Array) -> chex.Array:
  predicted_label = jnp.argmax(logits, axis=-1)
  correct = jnp.equal(predicted_label, labels).astype(jnp.float32)
  return jnp.sum(correct, axis=0) / logits.shape[0]
Example #20
        def loss_fn(
            model_config: ml_collections.FrozenConfigDict,
            model_params: Dict[str, Any],
            model_vars: Dict[str, Any],
            batch: Dict[str, Any],
            deterministic: bool,
            dropout_rng: Optional[Dict[str, Array]] = None,
        ) -> Tuple[float, MetricGroups, Dict[str, Any]]:
            """Model-specific loss function. See BaseTask."""

            batch_size = batch['text_ids'].shape[0]
            mention_target_ids = batch['mention_target_ids']
            mention_target_ids *= batch['mention_target_weights']

            variable_dict = {'params': model_params}
            variable_dict.update(model_vars)
            loss_helpers, logging_helpers = cls.build_model(
                model_config).apply(variable_dict,
                                    batch,
                                    deterministic=deterministic,
                                    rngs=dropout_rng)

            mlm_logits = loss_helpers['mlm_logits']
            mlm_target_is_mention = batch['mlm_target_is_mention']
            mlm_target_is_not_mention = 1 - batch['mlm_target_is_mention']
            mention_target_is_masked = batch['mention_target_is_masked']
            mention_target_is_not_masked = 1 - mention_target_is_masked
            mlm_loss, mlm_denom = metric_utils.compute_weighted_cross_entropy(
                mlm_logits, batch['mlm_target_ids'],
                batch['mlm_target_weights'])
            correct_mask = jnp.equal(
                jnp.argmax(mlm_logits, axis=-1),
                batch['mlm_target_ids']) * batch['mlm_target_weights']
            mlm_acc = correct_mask.sum()
            mlm_mention_acc = (correct_mask * mlm_target_is_mention).sum()
            mlm_mention_denom = (batch['mlm_target_weights'] *
                                 mlm_target_is_mention).sum()
            mlm_non_mention_acc = (correct_mask *
                                   mlm_target_is_not_mention).sum()
            mlm_non_mention_denom = (batch['mlm_target_weights'] *
                                     mlm_target_is_not_mention).sum()
            loss = mlm_weight * mlm_loss / mlm_denom

            metrics = {
                'mlm': {
                    'loss': mlm_loss,
                    'acc': mlm_acc,
                    'denominator': mlm_denom,
                },
                'mlm_mention': {
                    'acc': mlm_mention_acc,
                    'denominator': mlm_mention_denom,
                },
                'mlm_non_mention': {
                    'acc': mlm_non_mention_acc,
                    'denominator': mlm_non_mention_denom,
                },
            }

            def process_el_im_loss(loss, weight, prefix=''):
                memory_attention_weights = loss_helpers[
                    prefix + 'memory_attention_weights']
                memory_entity_ids = loss_helpers[prefix + 'top_entity_ids']

                target_mentions_memory_attention_weights = jut.matmul_slice(
                    memory_attention_weights, batch['mention_target_indices'])

                intermediate_entity_ids = jut.matmul_slice(
                    memory_entity_ids, batch['mention_target_indices'])

                el_loss_intermediate, same_entity_avg_prob, el_im_denom = metric_utils.compute_loss_and_prob_from_probs_with_duplicates(
                    target_mentions_memory_attention_weights,
                    intermediate_entity_ids, mention_target_ids,
                    batch['mention_target_weights'])

                if weight > 0:
                    loss += weight * el_loss_intermediate / el_im_denom
                metrics[prefix + 'el_intermediate'] = {
                    'loss': el_loss_intermediate,
                    'same_entity_avg_prob': same_entity_avg_prob,
                    'denominator': el_im_denom,
                }
                return loss

            loss = process_el_im_loss(loss, el_im_weight)
            if 'second_memory_attention_weights' in loss_helpers:
                loss = process_el_im_loss(loss, el_second_im_weight, 'second_')

            if coref_res_weight > 0:
                (coref_res_loss, coref_res_metrics
                 ) = mention_losses.coreference_resolution_loss(
                     loss_helpers['target_mention_encodings'],
                     batch['mention_target_batch_positions'],
                     mention_target_ids, batch_size, coref_res_mode,
                     mention_target_is_masked)
                coref_res_denom = coref_res_metrics['coref_resolution'][
                    'denominator']
                loss += coref_res_weight * coref_res_loss / coref_res_denom
                metrics.update(coref_res_metrics)

            if mtb_im_weight > 0:
                (mtb_im_loss, mtb_im_metrics) = mention_losses.mtb_loss(
                    loss_helpers['intermediate_target_mention_encodings'],
                    batch['mention_target_batch_positions'],
                    mention_target_ids, batch_size, mtb_score_mode,
                    mention_target_is_masked, 'im_')
                mtb_im_denom = mtb_im_metrics['im_mtb']['denominator']
                loss += mtb_im_weight * mtb_im_loss / mtb_im_denom
                metrics.update(mtb_im_metrics)

            if mtb_final_weight > 0:
                (mtb_final_loss, mtb_final_metrics) = mention_losses.mtb_loss(
                    loss_helpers['target_mention_encodings'],
                    batch['mention_target_batch_positions'],
                    mention_target_ids, batch_size, mtb_score_mode,
                    mention_target_is_masked, 'final_')
                mtb_final_denom = mtb_final_metrics['final_mtb']['denominator']
                loss += mtb_final_weight * mtb_final_loss / mtb_final_denom
                metrics.update(mtb_final_metrics)

            if same_passage_weight > 0:
                same_passage_mask = loss_helpers[
                    'memory_attention_disallowed_mask']
                (
                    same_passage_loss, same_passage_metrics, _
                ) = metric_utils.compute_cross_entropy_loss_with_positives_and_negatives_masks(
                    loss_helpers['memory_attention_scores_with_disallowed'],
                    same_passage_mask, jnp.logical_not(same_passage_mask),
                    batch['mention_mask'])
                same_passage_denom = same_passage_metrics['denominator']
                loss += same_passage_weight * same_passage_loss / same_passage_denom
                metrics['same_passage'] = same_passage_metrics

            if same_entity_set_retrieval_weight > 0:
                if config.get('same_entity_set_target_threshold') is None:
                    raise ValueError(
                        '`same_entitites_retrieval_threshold` must be specified '
                        'if `same_entity_set_retrieval_weight` is provided')

                (same_entity_set_retrieval_loss,
                 same_entity_set_retrieval_avg_prob,
                 same_entity_set_retrieval_denom
                 ) = mention_losses.same_entity_set_retrieval_loss(
                     mention_target_batch_positions=batch[
                         'mention_target_batch_positions'],
                     mention_target_ids=mention_target_ids,
                     mention_target_weights=batch['mention_target_weights'],
                     mention_batch_positions=batch['mention_batch_positions'],
                     mention_mask=batch['mention_mask'],
                     memory_text_entities=loss_helpers[
                         'memory_top_text_entities'],
                     memory_attention_weights=loss_helpers[
                         'memory_attention_weights'],
                     memory_mask=1 -
                     loss_helpers['memory_attention_disallowed_mask'],
                     batch_size=batch_size,
                     same_entity_set_target_threshold=config.
                     same_entity_set_target_threshold)

                loss += (same_entity_set_retrieval_weight *
                         same_entity_set_retrieval_loss /
                         same_entity_set_retrieval_denom)

                metrics['same_entity_set_retrieval'] = {
                    'loss': same_entity_set_retrieval_loss,
                    'avg_prob': same_entity_set_retrieval_avg_prob,
                    'denominator': same_entity_set_retrieval_denom,
                }

            if el_final_weight > 0:
                final_attention_weights = loss_helpers[
                    'final_memory_attention_weights']
                final_memory_entity_ids = loss_helpers['final_top_entity_ids']

                (
                    el_loss_final, same_entity_avg_prob_final, el_loss_denom
                ) = metric_utils.compute_loss_and_prob_from_probs_with_duplicates(
                    final_attention_weights, final_memory_entity_ids,
                    mention_target_ids, batch['mention_target_weights'])

                (
                    _, same_entity_avg_prob_final_masked, el_loss_denom_masked
                ) = metric_utils.compute_loss_and_prob_from_probs_with_duplicates(
                    final_attention_weights, final_memory_entity_ids,
                    mention_target_ids,
                    batch['mention_target_weights'] * mention_target_is_masked)

                (
                    _, same_entity_avg_prob_final_not_masked,
                    el_loss_denom_not_masked
                ) = metric_utils.compute_loss_and_prob_from_probs_with_duplicates(
                    final_attention_weights, final_memory_entity_ids,
                    mention_target_ids, batch['mention_target_weights'] *
                    mention_target_is_not_masked)

                metrics['el_final'] = {
                    'loss': el_loss_final,
                    'same_entity_avg_prob': same_entity_avg_prob_final,
                    'denominator': el_loss_denom,
                }
                metrics['el_final_masked'] = {
                    'same_entity_avg_prob': same_entity_avg_prob_final_masked,
                    'denominator': el_loss_denom_masked,
                }
                metrics['el_final_not_masked'] = {
                    'same_entity_avg_prob':
                    same_entity_avg_prob_final_not_masked,
                    'denominator': el_loss_denom_not_masked,
                }
                loss += el_final_weight * el_loss_final / (
                    el_loss_denom + default_values.SMALL_NUMBER)

            metrics['agg'] = {
                'loss': loss,
                'denominator': 1.0,
            }

            if 'n_disallowed' in logging_helpers:
                metrics['disallowed'] = {
                    'per_mention': logging_helpers['n_disallowed'],
                    'denominator': batch['mention_mask'].sum(),
                }

            if 'second_n_disallowed' in logging_helpers:
                metrics['second_n_disallowed'] = {
                    'per_mention': logging_helpers['second_n_disallowed'],
                    'denominator': batch['mention_mask'].sum(),
                }

            auxiliary_output = {
                'top_entity_ids': loss_helpers['top_entity_ids'],
                'top_memory_ids': loss_helpers['top_memory_ids'],
            }

            if 'second_top_entity_ids' in loss_helpers:
                auxiliary_output['second_top_entity_ids'] = loss_helpers[
                    'second_top_entity_ids']
                auxiliary_output['second_top_memory_ids'] = loss_helpers[
                    'second_top_memory_ids']

            return loss, metrics, auxiliary_output
Example #21
File: test.py Project: bkompa/pygau
import utils

x1 = np.array([[1., 2.]])
x2 = np.array([[1., 2.]])
x3 = np.array([[1., 3.]])

X = np.vstack((x1, x2, x3))

# test the squared distance function
assert(utils.squared_distance(x1, x2)[0][0] == 0.)
assert(utils.squared_distance(x1, x3)[0][0] == 1.)

# test the squared exponential function
assert(utils.squared_exponential(x1, x2)[0][0] == 1.)
assert(utils.squared_exponential(x1, x3)[0][0] == np.exp(-.5))

# test the kernel function for the same matrix
k_squared_distance = np.array([[0., 0., 1.],
                               [0., 0., 1.],
                               [1., 1., 0.]])
assert(np.equal(utils.kernel_matrix(X, X, utils.squared_distance, None), k_squared_distance).all())

k_se_distance = np.array([[1., 1., np.exp(-.5)],
                          [1., 1., np.exp(-.5)],
                          [np.exp(-.5), np.exp(-.5), 1.]])
assert(np.equal(utils.kernel_matrix(X, X, utils.squared_exponential, None), k_se_distance).all())

# test the kernel function for a matrix of different dimensions
k_001 = np.array([[0., 0., 1.]])
assert(np.equal(utils.kernel_matrix(X, x1, utils.squared_distance, None), k_001.T).all())
assert(np.equal(utils.kernel_matrix(x1, X, utils.squared_distance, None), k_001).all())
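The asserted values above are consistent with a squared-exponential kernel k(x, x') = exp(-||x - x'||^2 / 2); a minimal sketch of such functions (hypothetical, not the project's utils module; assumes `import jax.numpy as jnp`):

def squared_distance(a, b):
    # Pairwise squared Euclidean distances between rows of a and rows of b.
    return jnp.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)

def squared_exponential(a, b):
    return jnp.exp(-0.5 * squared_distance(a, b))

# squared_distance(x1, x3)[0][0] == 1.0 and squared_exponential(x1, x3)[0][0] == exp(-0.5),
# matching the asserts above.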
Example #22
        def loss_fn(
            model_config: ml_collections.FrozenConfigDict,
            model_params: Dict[str, Any],
            model_vars: Dict[str, Any],
            batch: Dict[str, Any],
            deterministic: bool,
            dropout_rng: Optional[Dict[str, Array]] = None,
        ) -> Tuple[float, MetricGroups, Dict[str, Any]]:
            """Model-specific loss function. See BaseTask."""

            variable_dict = {'params': model_params}
            variable_dict.update(model_vars)
            loss_helpers, _ = cls.build_model(model_config).apply(
                variable_dict,
                batch,
                deterministic=deterministic,
                rngs=dropout_rng)

            mlm_logits = loss_helpers['mlm_logits']
            mlm_target_is_mention = batch['mlm_target_is_mention']
            mlm_target_is_not_mention = 1 - batch['mlm_target_is_mention']
            mention_target_is_masked = batch['mention_target_is_masked']

            mlm_loss, mlm_denom = metric_utils.compute_weighted_cross_entropy(
                mlm_logits, batch['mlm_target_ids'],
                batch['mlm_target_weights'])
            correct_mask = jnp.equal(
                jnp.argmax(mlm_logits, axis=-1),
                batch['mlm_target_ids']) * batch['mlm_target_weights']
            mlm_acc = correct_mask.sum()
            mlm_mention_acc = (correct_mask * mlm_target_is_mention).sum()
            mlm_mention_denom = (batch['mlm_target_weights'] *
                                 mlm_target_is_mention).sum()
            mlm_non_mention_acc = (correct_mask *
                                   mlm_target_is_not_mention).sum()
            mlm_non_mention_denom = (batch['mlm_target_weights'] *
                                     mlm_target_is_not_mention).sum()

            loss = mlm_weight * mlm_loss / mlm_denom

            metrics = {
                'mlm': {
                    'loss': mlm_loss,
                    'acc': mlm_acc,
                    'denominator': mlm_denom,
                },
                'mlm_mention': {
                    'acc': mlm_mention_acc,
                    'denominator': mlm_mention_denom,
                },
                'mlm_non_mention': {
                    'acc': mlm_non_mention_acc,
                    'denominator': mlm_non_mention_denom,
                },
            }

            if coref_res_weight > 0:
                batch_size = batch['text_ids'].shape[0]
                mention_target_ids = batch['mention_target_ids']
                mention_target_ids = mention_target_ids * batch[
                    'mention_target_weights']

                (coref_res_loss, coref_res_metrics
                 ) = mention_losses.coreference_resolution_loss(
                     loss_helpers['target_mention_encodings'],
                     batch['mention_target_batch_positions'],
                     mention_target_ids, batch_size, coref_res_mode,
                     mention_target_is_masked)
                coref_res_denom = coref_res_metrics['coref_resolution'][
                    'denominator']
                loss += coref_res_weight * coref_res_loss / coref_res_denom
                metrics.update(coref_res_metrics)

            metrics['agg'] = {
                'loss': loss,
                'denominator': 1.0,
            }

            return loss, metrics, {}
Example #23
import numpy as onp
from functools import partial

#@jax.jit
def f(x):
    for _ in range(10):
        y = jax.device_put(x)
        x = jax.device_get(y)
    return x

#@jax.jit
def g(x):
    return x - 3.


xs = [onp.random.randn(i) for i in range(10)]
ed = jnp.ediff1d(xs[1], 0., -10.)

print("now concurrency")

with concurrent.futures.ProcessPoolExecutor() as executor:
    futures = [executor.submit(partial(f, x)) for x in xs]
    ys = [f.result() for f in futures]
for x, y in zip(xs, ys):
    if not all(jnp.equal(x,y)):
        print("These are not the same")
        print(x)
        print(y)
        print("----------------------")
print("done")
Example #24
        utils.numpy_dtype(dtype)(np.count_nonzero(input, axis))))

cumprod = utils.copy_docstring(tf.math.cumprod, _cumprod)

cumsum = utils.copy_docstring(tf.math.cumsum, _cumsum)

digamma = utils.copy_docstring(tf.math.digamma,
                               lambda x, name=None: scipy_special.digamma(x))

divide = utils.copy_docstring(tf.math.divide,
                              lambda x, y, name=None: np.divide(x, y))

divide_no_nan = utils.copy_docstring(
    tf.math.divide_no_nan,
    lambda x, y, name=None: np.where(  # pylint: disable=g-long-lambda
        onp.broadcast_to(np.equal(y, 0.),
                         np.array(x).shape), np.zeros_like(np.divide(x, y)),
        np.divide(x, y)))

equal = utils.copy_docstring(tf.math.equal,
                             lambda x, y, name=None: np.equal(x, y))

erf = utils.copy_docstring(tf.math.erf,
                           lambda x, name=None: scipy_special.erf(x))

erfc = utils.copy_docstring(tf.math.erfc,
                            lambda x, name=None: scipy_special.erfc(x))

exp = utils.copy_docstring(tf.math.exp, lambda x, name=None: np.exp(x))

expm1 = utils.copy_docstring(tf.math.expm1, lambda x, name=None: np.expm1(x))
Example #25
 def evaluate_batch(self, images, labels):
     logits = self.model(images, training=False)
     num_correct = jn.count_nonzero(
         jn.equal(jn.argmax(logits, axis=1), labels))
     return num_correct
Example #26
def collapse_and_remove_blanks(labels: jnp.ndarray,
                               seq_length: jnp.ndarray,
                               blank_id: int = 0):
  """Merge repeated labels into single labels and remove the designated blank symbol.

  Args:
    labels: Array of shape (batch, seq_length)
    seq_length: Array of shape (batch), sequence length of each batch element.
    blank_id: Optional id of the blank symbol

  Returns:
    Tuple of an array of shape (batch, seq_length) with repeated labels
    collapsed, eg: [[A, A, B, B, A],
                    [A, B, C, D, E]] => [[A, B, A],
                                         [A, B, C, D, E]]
    and an int array of shape [batch] with the new sequence lengths.
  """
  b, t = labels.shape
  # Zap out blank
  blank_mask = 1 - jnp.equal(labels, blank_id)
  labels = (labels * blank_mask).astype(labels.dtype)

  # Mask labels that don't equal previous label.
  label_mask = jnp.concatenate([
      jnp.ones_like(labels[:, :1], dtype=jnp.int32),
      jnp.not_equal(labels[:, 1:], labels[:, :-1])
  ],
                               axis=1)

  # Filter labels that aren't in the original sequence.
  maxlen = labels.shape[1]
  seq_mask = sequence_mask(seq_length, maxlen=maxlen)
  label_mask = label_mask * seq_mask

  # remove repetitions from the labels
  ulabels = label_mask * labels

  # Count masks for new sequence lengths.
  label_mask = jnp.not_equal(ulabels, 0).astype(labels.dtype)
  new_seq_len = jnp.sum(label_mask, axis=1)

  # Mask indexes based on sequence length mask.
  new_maxlen = maxlen
  idx_mask = sequence_mask(new_seq_len, maxlen=new_maxlen)

  # Flatten everything and mask out labels to keep and sparse indices.
  flat_labels = jnp.reshape(ulabels, [-1])
  flat_idx_mask = jnp.reshape(idx_mask, [-1])

  indices = jnp.nonzero(flat_idx_mask, size=b * t)[0]
  values = jnp.nonzero(flat_labels, size=b * t)[0]
  updates = jnp.take_along_axis(flat_labels, values, axis=-1)

  # Scatter to flat shape.
  flat = jnp.zeros(flat_idx_mask.shape).astype(labels.dtype)
  flat = flat.at[indices].set(updates)
  # 0'th position in the flat array gets clobbered by later padded updates,
  # so reset it here to its original value
  flat = flat.at[0].set(updates[0])

  # Reshape back to square batch.
  batch_size = labels.shape[0]
  new_shape = [batch_size, new_maxlen]
  return (jnp.reshape(flat, new_shape).astype(labels.dtype),
          new_seq_len.astype(seq_length.dtype))
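A toy illustration of the run-collapsing masks above (hypothetical labels with blank_id = 0; assumes `import jax.numpy as jnp`):

labels = jnp.array([[1, 1, 2, 2, 1]])        # "A A B B A"
blank_mask = 1 - jnp.equal(labels, 0)        # no blanks here, so all ones
label_mask = jnp.concatenate(
    [jnp.ones_like(labels[:, :1]),
     jnp.not_equal(labels[:, 1:], labels[:, :-1]).astype(labels.dtype)], axis=1)
# label_mask == [[1, 0, 1, 0, 1]]: only positions that start a new run survive,
# so labels * label_mask keeps [1, 0, 2, 0, 1], i.e. the collapsed sequence "A B A".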
Example #27
 def hvp(self, x, v, *args):
     if self.x is None or not np.equal(x, self.x).all():
         self.vhessres = self._vhess(x.reshape((self.nblocks, -1)), *args)
         self.x = x
     return self._vmatmul(self.vhessres, v.reshape(
         (self.nblocks, -1))).reshape((-1, ))
Example #28
def multiply_no_nan(x, y):
    dtype = jnp.result_type(x, y)
    return jnp.where(jnp.equal(x, 0.0), jnp.zeros((), dtype=dtype), jnp.multiply(x, y))
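A quick check of the zero-times-anything behaviour of multiply_no_nan above (toy values; assumes `import jax.numpy as jnp`):

print(multiply_no_nan(jnp.asarray(0.0), jnp.asarray(jnp.inf)))  # 0.0, not nan
print(multiply_no_nan(jnp.asarray(2.0), jnp.asarray(3.0)))      # 6.0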
Example #29
    def update(self, params, x, y, loss=None):
        """
        Description: Updates parameters based on correct value, loss and learning rate.
        Args:
            params (list/numpy.ndarray): Parameters of method pred method
            x (float): input to method
            y (float): true label
            loss (function): loss function. defaults to input value.
        Returns:
            Updated parameters in same shape as input
        """
        assert self.initialized
        assert type(
            params
        ) == dict, "optimizers can only take params in dictionary format"

        grad = self.gradient(params, x, y,
                             loss=loss)  # defined in optimizers core class

        if self.theta is None:
            self.theta = {
                k: -dw
                for (k, w), dw in zip(params.items(), grad.values())
            }
        else:
            self.theta = {
                k: v - dw
                for (k, v), dw in zip(self.theta.items(), grad.values())
            }

        if self.eta is None:
            self.eta = {
                k: dw * dw
                for (k, w), dw in zip(params.items(), grad.values())
            }
        else:
            self.eta = {
                k: v + dw * dw
                for (k, v), dw in zip(self.eta.items(), grad.values())
            }

        if self.theta_max is None:
            self.theta_max = {
                k: np.absolute(v)
                for (k, v) in self.theta.items()
            }
        else:
            self.theta_max = {
                k: np.where(np.greater(np.absolute(v), v_max), np.absolute(v),
                            v_max)
                for (k, v), v_max in zip(self.theta.items(),
                                         self.theta_max.values())
            }

        new_params = {
            k: np.where(np.equal(0.0, np.maximum(theta_max, eta)), theta,
                        theta / np.sqrt(np.maximum(theta_max, eta)))
            for (k, w), theta, theta_max, eta in zip(params.items(
            ), self.theta.values(), self.theta_max.values(), self.eta.values())
        }

        x_new = np.roll(x, 1)
        x_new = jax.ops.index_update(x_new, 0, y)
        y_t = self.pred(params=new_params, x=x_new)

        #        print('y before {0}'.format(y_t))
        x_plus_bias_new = np.vstack((np.ones((1, 1)), x_new))
        new_mapped_params = {
            k: self.norm_project(
                np.where(np.equal(0.0, np.maximum(theta_max, eta)), 0.0,
                         1.0 / np.sqrt(np.maximum(theta_max, eta))),
                x_plus_bias_new, y_t, p)
            for (k, p), theta_max, eta in zip(
                new_params.items(), self.theta_max.values(), self.eta.values())
        }

        #        y_t = self.pred(params=new_mapped_params, x=x_new)
        #        print('y after {0}'.format(y_t))
        return new_mapped_params
Example #30
value_and_grad = jax.value_and_grad(loss)
# opt_def = Adam(learning_rate=hyperparams.oracle_lr)
opt_def = Momentum(learning_rate=1e-3, weight_decay=1e-4, nesterov=True)

opt = opt_def.create(target=params)


def train_op(opt, x, y):
    v, g = value_and_grad(opt.target, x, y)
    return v, opt.apply_gradient(g)


train_op = jax.jit(train_op)
for step in range(400000):
    key, subkey = random.split(key)
    index = random.randint(subkey,
                           shape=(hyperparams.oracle_batch_size, ),
                           minval=0,
                           maxval=x_train.shape[0])
    v, opt = train_op(opt, x_train[index], y_train[index])
    if step % 500 == 0:
        print("test sgd result")
        f_x_test = model.apply(opt.target, x_test)
        test_loss = v_ce(f_x_test, y_test)
        pred = jnp.argmax(f_x_test, axis=1)
        correct = jnp.true_divide(
            jnp.sum(jnp.equal(pred, jnp.reshape(y_test, pred.shape))),
            y_test.shape[0])
        print("step %5d, test accuracy % .4f" % (step, correct))