def _Accuracy(inputs, axis=-1, **unused_kwargs): """Returns a layer to score matches of predicted versus target categories.""" y_hat, target_category = inputs predicted_category = np.argmax(y_hat, axis=axis) # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment. # shapes.assert_same_shape(predicted_category, target_category) return np.equal(predicted_category, target_category).astype(np.float32)
def sample(self, inputs, temperature=1.0):
  """Picks one category index per position from the logits in `inputs`.

  `inputs` is first passed through `self._unflatten_inputs` (its exact
  reshaping is defined elsewhere). With `temperature == 0.0` the choice is
  greedy (argmax over the last axis); otherwise it is drawn with Gumbel
  sampling via `tl.gumbel_sample` at the given temperature.

  Args:
    inputs: Logits in the flattened layout expected by `_unflatten_inputs`.
    temperature: Sampling temperature; exactly 0.0 selects greedy argmax.

  Returns:
    Array of sampled category indices.
  """
  logits = self._unflatten_inputs(inputs)
  if temperature != 0.0:
    # No need for LogSoftmax with Gumbel sampling - softmax normalization is
    # subtracting a constant from every logit, and Gumbel sampling is taking
    # a max over logits plus noise, so invariant to adding a constant.
    return tl.gumbel_sample(logits, temperature)
  return jnp.argmax(logits, axis=-1)
def hash_vectors(self, vecs, rng):
  """Assigns every input vector a bucket id in each hashing round.

  Angular LSH via random rotations: `vecs` is projected with a random
  Gaussian matrix, the projection is concatenated with its negation, and the
  argmax index along the last axis is the bucket. `self.n_buckets` may be an
  int, or a list/tuple of factors; in the latter case each factor is hashed
  on its own slice of the projection and the per-factor ids are combined
  mixed-radix style into a single id in [0, prod(factors)).

  Args:
    vecs: Array indexed as (t, f) by the einsum below — presumably
      (seqlen, features); TODO confirm against callers.
    rng: JAX PRNG key used to sample the random rotations.

  Returns:
    A flat integer array of length `self.n_hashes * t`. Ids from hashing
    round h lie in [h * n_buckets, (h + 1) * n_buckets), where n_buckets is
    `self.n_buckets` or the product of its factors.

  Raises:
    AssertionError: If `self.n_buckets` (or any of its factors) is odd
      (plain `assert`, so not checked under `python -O`).
  """
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  if isinstance(self.n_buckets, int):
    # Even because half the buckets come from the negated projection.
    assert self.n_buckets % 2 == 0
    rot_size = self.n_buckets
    n_buckets = self.n_buckets
  else:
    # Factorize the hash if self.n_buckets is a list or tuple
    rot_size, n_buckets = 0, 1
    for factor in self.n_buckets:
      assert factor % 2 == 0
      rot_size += factor
      n_buckets *= factor

  # Only half the rotation columns are sampled; the other half is the
  # negation, added by the concatenations below.
  rotations_shape = (vecs.shape[-1], self.n_hashes, rot_size // 2)

  # stop_gradient keeps the PRNG key out of differentiation.
  # NOTE(review): jax.lax.tie_in is a no-op under omnistaging and is
  # deprecated/removed in newer JAX releases — confirm against the pinned
  # JAX version before upgrading.
  rng = jax.lax.stop_gradient(jax.lax.tie_in(vecs, rng))
  random_rotations = jax.random.normal(rng, rotations_shape).astype('float32')
  # (t, f) x (f, h, b) -> (h, t, b): one projection per hashing round.
  rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)

  if isinstance(self.n_buckets, int) or len(self.n_buckets) == 1:
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    buckets = np.argmax(rotated_vecs, axis=-1)
  else:
    # Get the buckets for them and combine.
    # Each factor owns a contiguous slice of factor // 2 projection columns;
    # its argmax (in [0, factor)) is a digit weighted by the product of the
    # preceding factors.
    buckets, cur_sum, cur_product = None, 0, 1
    for factor in self.n_buckets:
      rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
      cur_sum += factor // 2
      rv = np.concatenate([rv, -rv], axis=-1)
      if buckets is None:
        buckets = np.argmax(rv, axis=-1)
      else:
        buckets += cur_product * np.argmax(rv, axis=-1)
      cur_product *= factor

  # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
  # bucket numbers from different hashing rounds don't overlap.
  offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
  # Shape (n_hashes, 1) so it broadcasts across the sequence axis.
  offsets = np.reshape(offsets * n_buckets, (-1, 1))
  buckets = np.reshape(buckets + offsets, (-1, ))
  return buckets
def Accuracy(inputs, axis=-1, **unused_kwargs):
  """Marks where the argmax of a prediction equals its target class.

  Args:
    inputs: Pair `(prediction, target)`.
    axis: Axis of `prediction` over which the argmax is taken.
    **unused_kwargs: Ignored.

  Returns:
    Boolean array from `np.equal`: True where the predicted class matches
    `target`.
  """
  prediction, target = inputs
  return np.equal(np.argmax(prediction, axis=axis), target)
def _Accuracy(inputs, axis=-1, **unused_kwargs): """Returns a layer to score matches of predicted versus target categories.""" y_hat, target_category = inputs predicted_category = np.argmax(y_hat, axis=axis) return np.equal(predicted_category, target_category).astype(np.float32)
def f(y_hat, target_category):  # pylint: disable=invalid-name
  """Scores each prediction: 1.0 if its argmax matches the target, else 0.0.

  `axis` is not a parameter here; it is read from the enclosing scope (not
  visible in this snippet).

  Returns:
    float32 array of per-position match scores.
  """
  guesses = np.argmax(y_hat, axis=axis)
  # TODO(pkozakowski): This assertion breaks some tests. Fix and uncomment.
  # shapes.assert_same_shape(guesses, target_category)
  is_correct = np.equal(guesses, target_category)
  return is_correct.astype(np.float32)
def Accuracy(x, axis=-1, **kw):
  """Marks where the argmax of a prediction equals its target class.

  Args:
    x: Pair `(prediction, target)`.
    axis: Axis of `prediction` over which the argmax is taken.
    **kw: Ignored.

  Returns:
    Boolean array from `np.equal`: True where the predicted class matches
    `target`.
  """
  del kw  # Extra keyword arguments are ignored.
  prediction, target = x
  best = np.argmax(prediction, axis=axis)
  hits = np.equal(best, target)
  return hits