def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, sigma).

  Args:
    x: point(s) at which to evaluate the density; the last axis has size D.
    mu: mean(s); the last axis has size D and must broadcast against `x`.
    sigma: covariance matrix (or a stack of them) of shape (..., D, D). The
      sign returned by `slogdet` is discarded, so `sigma` is assumed to be
      positive definite.

  Returns:
    The log-density, with the trailing D axis reduced away (batch axes of
    `x`, `mu` and `sigma` are broadcast together).
  """
  a = mu.shape[-1] * np.log(2 * np.pi)
  _, b = np.linalg.solve.__self__.slogdet(sigma) if False else np.linalg.slogdet(sigma)
  diff = x - mu
  # Solve against an explicit column vector of shape (..., D, 1). That keeps the
  # result independent of how np.linalg.solve interprets a batched right-hand
  # side (the rule changed in NumPy 2.0). It also lets a single unbatched
  # sigma broadcast over a batch of x.
  y = np.linalg.solve(sigma, np.expand_dims(diff, axis=-1))
  xm = np.expand_dims(diff, axis=-2)
  # (..., 1, D) @ (..., D, 1) -> (..., 1, 1): the Mahalanobis term.
  c = np.matmul(xm, y)
  c = np.squeeze(np.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Compute log N(x | mu, diag(diag_sigma)).

  Args:
    x: point(s) at which to evaluate the density; the last axis has size D.
    mu: mean(s); the last axis has size D and must broadcast against `x`.
    diag_sigma: diagonal of the covariance matrix (per-dimension variances,
      since their log-sum is used as the log-determinant); last axis size D.

  Returns:
    The log-density, with the trailing D axis reduced away.
  """
  a = mu.shape[-1] * np.log(2 * np.pi)
  # log det(diag(diag_sigma)) = sum_i log(diag_sigma_i).
  b = np.sum(np.log(diag_sigma), axis=-1)
  diff = x - mu
  # Mahalanobis term for a diagonal covariance: sum_i (x_i - mu_i)^2 / s_i.
  # Note the grouping (x - mu) / diag_sigma, not x - (mu / diag_sigma).
  c = np.sum(diff * (diff / diag_sigma), axis=-1)
  return -0.5 * (a + b + c)
def dataset_to_stream(dataset, input_name):
  """Takes a tf.Dataset and creates a numpy stream of ready batches.

  Args:
    dataset: dataset whose numpy-converted elements are indexable as
      `(features, targets)`; `features` is mapping-like (it is indexed by
      `input_name` and checked for a `'mask'` key).
    input_name: key in `features` that holds the model input.

  Yields:
    `(inp, out)` tuples, or `(inp, out, mask)` when `features` has a `'mask'`
    entry. uint8 `inp`/`out` are cast to int32. A trailing size-1 axis on a
    multi-dimensional `out` is squeezed away.
  """
  for example in backend.dataset_as_numpy(dataset):
    features = example[0]
    inp, out = features[input_name], example[1]
    mask = features['mask'] if 'mask' in features else None
    # All input-pipeline processing should be on CPU.
    with tf.device('cpu:0'):
      # Some accelerators don't handle uint8 well, cast to int. Test the dtype
      # rather than isinstance(..., np.uint8): the latter only matches uint8
      # *scalars* and never fires for the arrays a dataset yields.
      if getattr(inp, 'dtype', None) == np.uint8:
        inp = inp.astype(np.int32)
      if getattr(out, 'dtype', None) == np.uint8:
        out = out.astype(np.int32)
      if len(out.shape) > 1 and out.shape[-1] == 1:
        out = np.squeeze(out, axis=-1)
    # Yield outside the device scope. A generator suspended inside a `with`
    # block keeps the scope entered while the consumer runs.
    yield (inp, out) if mask is None else (inp, out, mask)
def hash_vectors(self, vecs, rng):
  """Assigns vectors to buckets by argmax over randomly rotated projections.

  Vectors are projected onto `rot_size // 2` random directions. The
  projections are concatenated with their negations, and the argmax over
  the result picks the bucket. When the hash is factorized, this happens
  separately per factor and the per-factor ids are combined in mixed radix.

  Args:
    vecs: vectors to hash. After `self.drop_for_hash` they must be 2-D to
      fit the 'tf,fhb->htb' einsum; presumably (seqlen, depth) — confirm
      against callers. The last axis sizes the rotations.
    rng: random key. It is split into one key for `self._sample_rotation`
      and one (`subrng`) for `self.drop_for_hash`.

  Returns:
    A flat 1-D array of bucket ids, ordered hash-round-major (all positions
    for round 0, then round 1, ...). With `self._rehash_each_round` the ids
    are offset per round into [0, n_hashes * n_buckets). Otherwise each
    position gets its top `self.n_hashes` buckets in [0, n_buckets), with
    no per-round offset.

  Raises:
    AssertionError: if n_buckets is odd, if a given factor list has an odd
      factor or does not multiply to n_buckets, or if factorized hashing is
      requested without rehashing each round.
  """
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  assert self.n_buckets % 2 == 0

  # If we factorize the hash, find a factor dividing n_buckets nicely.
  rot_size, factor_list = self.n_buckets, [self.n_buckets]
  if self._factorize_hash:
    # If we are given a list of factors, verify it and use later.
    if isinstance(self._factorize_hash, list):
      # Each factor f consumes f // 2 rotation columns, so rot_size is the
      # sum of the factors (all even), while their product must be n_buckets.
      rot_size, product = 0, 1
      factor_list = self._factorize_hash
      for factor in factor_list:
        assert factor % 2 == 0
        product *= factor
        rot_size += factor
      assert product == self.n_buckets
    else:
      # Find one factor if just set to True.
      # We want to represent self.n_buckets = factor * rest so that
      # (1) both factor and rest are even, and (2) factor + rest is minimal.
      # To compute this we start from factor = sqrt(n_buckets) and go down
      # with it until we find one that satisfies the constraints above.
      factor = int(math.sqrt(self.n_buckets))
      while factor > 0 and not (
          self.n_buckets % factor == 0 and
          factor % 2 == 0 and
          (self.n_buckets // factor) % 2 == 0):
        factor -= 1
      if factor > 2:  # Factor of 2 does not warrant the effort.
        rot_size = factor + (self.n_buckets // factor)
        factor_list = [factor, self.n_buckets // factor]

  # (depth, rounds, rot_size // 2): one rotation per round when rehashing,
  # otherwise a single rotation shared by all rounds.
  rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      rot_size // 2)

  # NOTE(review): jax.lax.tie_in is a no-op/deprecated in newer JAX
  # releases — confirm against the JAX version in use.
  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = self._sample_rotation(rotations_shape, vecs, rng)

  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  # Result shape: (rounds, seqlen, rot_size // 2).
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)

  if self._rehash_each_round:
    if self._factorize_hash and len(factor_list) > 1:
      # We factorized self.n_buckets as the product of factor_list.
      # Get the buckets for them and combine.
      # Each factor reads its own slice of factor // 2 columns. The ids are
      # combined in mixed radix: d0 + f0 * d1 + f0 * f1 * d2 + ...
      buckets, cur_sum, cur_product = None, 0, 1
      for factor in factor_list:
        rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
        cur_sum += factor // 2
        rv = np.concatenate([rv, -rv], axis=-1)
        if buckets is None:
          buckets = np.argmax(rv, axis=-1)
        else:
          buckets += cur_product * np.argmax(rv, axis=-1)
        cur_product *= factor
    else:
      # Concatenating with the negation yields n_buckets candidate columns.
      rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
      buckets = np.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1, ))
  else:
    # Factorized hashing is only supported together with rehash_each_round.
    assert not self._factorize_hash
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    # In this configuration, we map each item to the top self.n_hashes buckets
    # Drop the single-round axis: (seqlen, n_buckets).
    rotated_vecs = np.squeeze(rotated_vecs, 0)
    bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
    bucket_range = np.reshape(bucket_range, (1, -1))
    bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)
    # Sort bucket ids by projection score along the bucket axis. The last
    # n_hashes entries are the highest-scoring buckets
    # (presumably n_hashes <= n_buckets — confirm with the config).
    _, buckets = jax.lax.sort_key_val(rotated_vecs, bucket_range, dimension=-1)
    buckets = buckets[:, -self.n_hashes:]
    # (seqlen, n_hashes) -> (n_hashes, seqlen) -> flat, round-major.
    buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1, ))
  return buckets