Code Example #1
    def hash_vectors(self, vecs, rng):
        if self.bin_by_time:
            # Instead of hashing, put chunks of consecutive items in the same bin.
            # This exists as a sanity check for the other parts of this class.
            return self.bin_vectors_by_time(vecs)

        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each batch element, head, and
        # (crucially) each round of hashing. All of these are part of dimension 0
        # of vecs. Applying multiple hashes to the same input is important because
        # it increases the probability of being in the same bin as relevant items.
        n_buckets = self.n_buckets_per_bin * self.n_bins
        assert n_buckets % 2 == 0
        rot_rng = rng
        if self._one_rng:
            rot_rng = jax.lax.tie_in(vecs, self._prng)
        random_rotation = jax.random.normal(
            rot_rng,
            (vecs.shape[0], vecs.shape[-1], n_buckets // 2)).astype('float32')

        # TODO(kitaev): making the vectors unit-length here is probably redundant.
        # vecs = self.make_unit_length(vecs)
        rng, subrng = backend.random.split(rng)
        vecs = self.drop_for_hash(vecs, subrng)
        rotated_vecs = np.matmul(vecs, random_rotation)
        rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
        bins = np.argmax(rotated_vecs, axis=-1)
        return bins
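
The helper self.drop_for_hash is not shown in these excerpts. Judging from its name and from the dropout-mask TODO in Code Example #4, it applies dropout to the vectors before they are hashed. A minimal, hypothetical stand-in (standalone JAX; the function name, dropout rate, and broadcasting are assumptions, not the original method) might look like this:

import jax
import jax.numpy as jnp

def drop_for_hash_sketch(vecs, rng, rate=0.1):
  # Hypothetical stand-in for self.drop_for_hash: plain dropout applied to the
  # vectors before hashing. The real method's rate and mask shape may differ.
  keep = jax.random.bernoulli(rng, 1.0 - rate, vecs.shape)
  return jnp.where(keep, vecs / (1.0 - rate), 0.0)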
Code Example #2
def accuracy(batch, model_predictions):
  """Calculate accuracy."""
  _, targets = batch
  model_predictions, targets = _make_list(model_predictions, targets)
  correct = []
  for (prediction, target) in zip(model_predictions, targets):
    predicted_class = np.argmax(prediction, axis=-1)
    correct.append(np.equal(predicted_class, target))
  return masked_mean(correct, targets)
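
The masked_mean helper used here is not included in these excerpts. Assuming it averages the per-position correctness while ignoring padding positions (those whose target equals a padding id), a simplified single-array sketch could look like the following; the name masked_mean_sketch, the signature, and the padding id of 0 are assumptions, and the original helper is also called with lists of arrays as in this example:

import numpy as np

def masked_mean_sketch(values, targets, mask_id=0):
  # Hypothetical stand-in for masked_mean: average values only at positions
  # where the target is not the padding id.
  mask = np.not_equal(targets, mask_id).astype(np.float32)
  return np.sum(values * mask) / np.maximum(np.sum(mask), 1.0)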
Code Example #3
  def hash_vectors(self, vecs, rng):
    if self.bin_by_time:
      # Instead of hashing, put chunks of consecutive items in the same bin.
      # This exists as a sanity check for the other parts of this class.
      return self.bin_vectors_by_time(vecs)

    # See https://arxiv.org/pdf/1509.02897.pdf
    assert self.n_bins % 2 == 0
    random_rotation = jax.random.normal(
        rng, (vecs.shape[-1], self.n_bins//2)).astype('float32')

    # TODO(kitaev): making the vectors unit-length here is probably redundant.
    vecs = self.make_unit_length(vecs)
    rotated_vecs = np.matmul(vecs, random_rotation)
    rotated_vecs = self.make_unit_length(rotated_vecs)
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    bins = np.argmax(rotated_vecs, axis=-1)
    return bins
Code Example #4
  def hash_vectors(self, vecs, rng):
    # See https://arxiv.org/pdf/1509.02897.pdf
    # We sample a different random rotation for each round of hashing to
    # decrease the probability of hash misses.
    assert self.n_buckets % 2 == 0
    random_rotations_shape = (
        vecs.shape[-1],
        self.n_hashes if self._rehash_each_round else 1,
        self.n_buckets // 2)

    rng = jax.lax.tie_in(vecs, rng)
    rng, subrng = backend.random.split(rng)
    random_rotations = jax.random.normal(
        rng, random_rotations_shape).astype('float32')
    # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
    # hashing, so it's shared between them. Check if that's what we want.
    dropped_vecs = self.drop_for_hash(vecs, subrng)
    rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)

    if self._rehash_each_round:
      buckets = np.argmax(rotated_vecs, axis=-1)
      # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
      # bucket numbers from different hashing rounds don't overlap.
      offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
      offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
      buckets = np.reshape(buckets + offsets, (-1,))
    else:
      # In this configuration, we map each item to the top self.n_hashes buckets
      rotated_vecs = np.squeeze(rotated_vecs, 0)
      bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
      bucket_range = np.reshape(bucket_range, (1, -1))
      bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

      _, buckets = jax.lax.sort_key_val(
          rotated_vecs, bucket_range, dimension=-1)
      buckets = buckets[:, -self.n_hashes:]
      buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

    return buckets
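
Put together, the rotation-and-argmax scheme above (one random rotation per hashing round, concatenation with the negated projections, argmax over buckets, then per-round offsets) can be exercised outside the class. The following is a minimal standalone sketch in plain NumPy; the function name and the toy shapes are illustrative only:

import numpy as onp  # plain NumPy; the excerpts above use a JAX-backed np

def hash_vectors_sketch(vecs, n_buckets, n_hashes, seed=0):
  # Angular LSH as in https://arxiv.org/pdf/1509.02897.pdf: project onto random
  # rotations, then take the argmax over the concatenation [rotated, -rotated].
  assert n_buckets % 2 == 0
  rotations = onp.random.RandomState(seed).normal(
      size=(vecs.shape[-1], n_hashes, n_buckets // 2))
  rotated = onp.einsum('tf,fhb->htb', vecs, rotations)
  rotated = onp.concatenate([rotated, -rotated], axis=-1)
  buckets = onp.argmax(rotated, axis=-1)  # (n_hashes, seqlen)
  # Offset each round so bucket ids from different rounds do not overlap.
  offsets = (onp.arange(n_hashes) * n_buckets).reshape(-1, 1)
  return onp.reshape(buckets + offsets, (-1,))

# Toy usage: 8 vectors of dimension 4, hashed twice into 4 buckets per round.
toy_buckets = hash_vectors_sketch(onp.random.randn(8, 4), n_buckets=4, n_hashes=2)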
Code Example #5
    def hash_vectors(self, vecs, rng):
        if self.bin_by_time:
            # Instead of hashing, put chunks of consecutive items in the same bin.
            # This exists as a sanity check for the other parts of this class.
            return self.bin_vectors_by_time(vecs)

        # See https://arxiv.org/pdf/1509.02897.pdf
        # It's not clear whether sampling a different random rotation for each head
        # and batch element matters here, but see MergedMultiHashedCausalAttention.
        assert self.n_bins % 2 == 0
        rot_rng = rng
        if self._one_rng:
            rot_rng = jax.lax.tie_in(vecs, self._prng)
        random_rotation = jax.random.normal(
            rot_rng, (vecs.shape[0], vecs.shape[-1],
                      self.n_bins // 2)).astype('float32')

        # TODO(kitaev): making the vectors unit-length here is probably redundant.
        vecs = self.make_unit_length(vecs)
        rotated_vecs = np.matmul(vecs, random_rotation)
        rotated_vecs = self.make_unit_length(rotated_vecs)
        rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
        bins = np.argmax(rotated_vecs, axis=-1)
        return bins
Code Example #6
def accuracy(batch, model_predictions):
  """Calculate accuracy."""
  _, targets = batch
  predicted_class = np.argmax(model_predictions, axis=-1)
  correct = np.equal(predicted_class, targets)
  return masked_mean(correct, targets)
Code Example #7
def Accuracy(x, axis=-1, **kw):
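    """Per-element accuracy: True where argmax of prediction equals the target."""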
    del kw
    prediction, target = x
    predicted_class = np.argmax(prediction, axis=axis)
    return np.equal(predicted_class, target)
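
This last variant returns per-element correctness and leaves masking and averaging to the caller. A small usage sketch, assuming np is plain NumPy and using made-up toy values:

import numpy as np

logits = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])  # (seqlen, n_classes)
targets = np.array([1, 0, 0])
per_token_correct = Accuracy((logits, targets))  # array([ True,  True, False])
mean_accuracy = per_token_correct.mean()         # 0.666...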