Example #1
def _Accuracy(inputs, axis=-1, **unused_kwargs):
    """Returns a layer to score matches of predicted versus target categories."""
    y_hat, target_category = inputs
    predicted_category = np.argmax(y_hat, axis=axis)
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_same_shape(predicted_category, target_category)
    return np.equal(predicted_category, target_category).astype(np.float32)
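In the snippet above, `np` is trax's backend NumPy (jax.numpy); the same computation behaves identically in plain NumPy. A minimal standalone sketch with illustrative `logits` and `targets` arrays (not from the original):

import numpy as np

logits = np.array([[0.1, 0.7, 0.2],
                   [0.9, 0.05, 0.05]])
targets = np.array([1, 2])
predicted = np.argmax(logits, axis=-1)                     # -> [1, 0]
matches = np.equal(predicted, targets).astype(np.float32)  # -> [1.0, 0.0]
print(matches.mean())                                      # mean accuracy: 0.5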
Example #2
def sample(self, inputs, temperature=1.0):
    # No need for LogSoftmax with Gumbel sampling - softmax normalization is
    # subtracting a constant from every logit, and Gumbel sampling is taking
    # a max over logits plus noise, so invariant to adding a constant.
    if temperature == 0.0:
        return jnp.argmax(self._unflatten_inputs(inputs), axis=-1)
    return tl.gumbel_sample(self._unflatten_inputs(inputs), temperature)
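The comment relies on the Gumbel-max trick: adding independent Gumbel(0, 1) noise to the logits and taking the argmax draws a sample from the corresponding softmax distribution, and shifting every logit by the same constant cannot change which index wins the argmax. A minimal standalone sketch of that idea (the helper `gumbel_max_sample` is illustrative and is not the `tl.gumbel_sample` used above):

import jax
import jax.numpy as jnp

def gumbel_max_sample(rng, logits, temperature=1.0):
    # Gumbel(0, 1) noise: -log(-log(U)) with U ~ Uniform(0, 1).
    u = jax.random.uniform(rng, logits.shape, minval=1e-6, maxval=1.0 - 1e-6)
    gumbel = -jnp.log(-jnp.log(u))
    # argmax over perturbed logits samples from softmax(logits / temperature);
    # adding a constant to all logits leaves the argmax unchanged.
    return jnp.argmax(logits / temperature + gumbel, axis=-1)

sample = gumbel_max_sample(jax.random.PRNGKey(0), jnp.array([1.0, 2.0, 0.5]))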
Example #3
    def hash_vectors(self, vecs, rng):
        # See https://arxiv.org/pdf/1509.02897.pdf
        # We sample a different random rotation for each round of hashing to
        # decrease the probability of hash misses.
        if isinstance(self.n_buckets, int):
            assert self.n_buckets % 2 == 0
            rot_size = self.n_buckets
            n_buckets = self.n_buckets
        else:
            # Factorize the hash if self.n_buckets is a list or tuple
            rot_size, n_buckets = 0, 1
            for factor in self.n_buckets:
                assert factor % 2 == 0
                rot_size += factor
                n_buckets *= factor

        rotations_shape = (vecs.shape[-1], self.n_hashes, rot_size // 2)

        rng = jax.lax.stop_gradient(jax.lax.tie_in(vecs, rng))
        random_rotations = jax.random.normal(rng,
                                             rotations_shape).astype('float32')
        rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)

        if isinstance(self.n_buckets, int) or len(self.n_buckets) == 1:
            rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs],
                                          axis=-1)
            buckets = np.argmax(rotated_vecs, axis=-1)
        else:
            # Get the buckets for them and combine.
            buckets, cur_sum, cur_product = None, 0, 1
            for factor in self.n_buckets:
                rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
                cur_sum += factor // 2
                rv = np.concatenate([rv, -rv], axis=-1)
                if buckets is None:
                    buckets = np.argmax(rv, axis=-1)
                else:
                    buckets += cur_product * np.argmax(rv, axis=-1)
                cur_product *= factor

        # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
        # bucket numbers from different hashing rounds don't overlap.
        offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
        offsets = np.reshape(offsets * n_buckets, (-1, 1))
        buckets = np.reshape(buckets + offsets, (-1, ))

        return buckets
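The hashing above is angular LSH: vectors are projected onto random rotations, the projections are concatenated with their negations, and the argmax index becomes the bucket, so nearby vectors tend to land in the same bucket. A simplified single-factor, single-round sketch in plain NumPy (the function name, fixed seed, and shapes are illustrative):

import numpy as np

def simple_hash(vecs, n_buckets, seed=0):
    # vecs: (seqlen, d_model); n_buckets must be even.
    assert n_buckets % 2 == 0
    rng = np.random.default_rng(seed)
    rotations = rng.standard_normal((vecs.shape[-1], n_buckets // 2)).astype('float32')
    rotated = vecs @ rotations                    # (seqlen, n_buckets // 2)
    # Appending -rotated mirrors each random direction, so the argmax picks
    # among n_buckets half-spaces; this is the same argmax trick used above.
    rotated = np.concatenate([rotated, -rotated], axis=-1)
    return np.argmax(rotated, axis=-1)            # one bucket id per position

buckets = simple_hash(np.random.default_rng(1).standard_normal((16, 64)), n_buckets=8)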
Example #4
def Accuracy(inputs, axis=-1, **unused_kwargs):
    prediction, target = inputs
    predicted_class = np.argmax(prediction, axis=axis)
    return np.equal(predicted_class, target)
Example #5
File: metrics.py  Project: zsunpku/trax
def _Accuracy(inputs, axis=-1, **unused_kwargs):
  """Returns a layer to score matches of predicted versus target categories."""
  y_hat, target_category = inputs
  predicted_category = np.argmax(y_hat, axis=axis)
  return np.equal(predicted_category, target_category).astype(np.float32)
Example #6
File: metrics.py  Project: zhaoqiuye/trax
def f(y_hat, target_category):  # pylint: disable=invalid-name
    predicted_category = np.argmax(y_hat, axis=axis)
    # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
    # shapes.assert_same_shape(predicted_category, target_category)
    return np.equal(predicted_category, target_category).astype(np.float32)
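Examples #1 and #6 compute the same metric in two styles: a flat layer function taking `(y_hat, target_category)` as `inputs`, versus an inner function `f` that closes over `axis` and is wrapped into a layer elsewhere. A generic sketch of that closure pattern, using plain NumPy and not the actual trax wrapper:

import numpy as np

def make_accuracy(axis=-1):
    def f(y_hat, target_category):
        predicted_category = np.argmax(y_hat, axis=axis)
        return np.equal(predicted_category, target_category).astype(np.float32)
    return f

accuracy = make_accuracy()
accuracy(np.array([[0.2, 0.8]]), np.array([1]))   # -> array([1.], dtype=float32)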
Example #7
File: metrics.py  Project: galloperx/trax
def Accuracy(x, axis=-1, **kw):
    del kw
    prediction, target = x
    predicted_class = np.argmax(prediction, axis=axis)
    return np.equal(predicted_class, target)