def get_test_ds():
    batch_shape = (batch_size,) if n_batches is None else (n_batches, batch_size)
    big_batch_size = util.list_prod(batch_shape)
    start_idx = n_train
    get_end_idx = lambda start: start + big_batch_size

    while True:
        end_idx = get_end_idx(start_idx)
        data_batch = data[start_idx:end_idx]
        if n_batches is not None:
            data_batch = rebatch(data_batch, batch_size, n_batches)
        inputs = {"x": data_batch}

        if classification and labels is not None:
            y = labels[start_idx:end_idx]
            if n_batches is not None:
                y = rebatch(y, batch_size, n_batches)
            y_one_hot = (y[..., None] == range_classes[..., :]) * 1.0
            inputs["y"] = y_one_hot

        yield inputs

        start_idx += big_batch_size
        if start_idx >= data.shape[0]:
            # Stop iterating
            return
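# NOTE: rebatch is not defined in this section. A minimal sketch consistent
# with how it is used above (an assumption, not the actual implementation):
# it reshapes a flat run of n_batches * batch_size examples into
# (n_batches, batch_size, ...) so that downstream code can scan over the
# leading axis.
def rebatch(data, batch_size, n_batches):
    return data.reshape((n_batches, batch_size) + data.shape[1:])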
def evaluate_test(self, key: PRNGKey, input_iterator, **kwargs):
    sum_log_px = 0.0
    total_examples = 0

    try:
        while True:
            key, test_key = random.split(key, 2)
            inputs = next(input_iterator)
            outputs = self.flow.scan_apply(test_key, inputs, **kwargs)

            # Accumulate the total sum of the log likelihoods in case the batch sizes differ
            sum_log_px += outputs["log_px"].sum()
            n_examples_in_batch = util.list_prod(self.flow.get_batch_shape(inputs))
            total_examples += n_examples_in_batch
    except StopIteration:
        pass

    nll = -sum_log_px / total_examples
    self.test_losses[self.n_train_steps] = nll
    return nll
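# Hypothetical usage, wiring the get_test_ds generator from above into this
# evaluator (the trainer name is illustrative, not from the source):
#
#   key = random.PRNGKey(0)
#   nll = trainer.evaluate_test(key, get_test_ds())
#   print(f"test NLL: {nll:.4f}")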
def call(self, inputs: Mapping[str, jnp.ndarray], rng: jnp.ndarray = None, sample: Optional[bool] = False, **kwargs) -> Mapping[str, jnp.ndarray]: x = inputs["x"] if sample == False: out_dim = util.list_prod(self.output_shape) expected_out_dim = util.list_prod(self.unbatched_input_shapes["x"]) assert out_dim == expected_out_dim, f"Dimension mismatch" z = x.reshape(self.batch_shape + self.output_shape) else: original_shape = self.batch_shape + self.unbatched_input_shapes["x"] z = x.reshape(original_shape) outputs = {"x": z, "log_det": jnp.zeros(self.batch_shape)} return outputs
def evaluate_test(self,
                  key: PRNGKey,
                  input_iterator,
                  bits_per_dim: bool = False,
                  **kwargs):
    sum_log_px = 0.0
    total_examples = 0
    n_correct = 0

    try:
        while True:
            key, test_key = random.split(key, 2)
            inputs = next(input_iterator)

            # If we're using a residual flow, catch up estimating the singular values
            # if our estimate is bad
            if total_examples == 0:
                inputs_batched = jax.tree_map(lambda x: x[0], inputs)
                self.flow.apply(test_key,
                                inputs_batched,
                                is_training=True,
                                force_update_params=True,
                                **kwargs)

            outputs = self.flow.scan_apply(test_key, inputs, is_training=False, **kwargs)

            # Accumulate the total sum of the log likelihoods in case the batch sizes differ
            sum_log_px += outputs["log_px"].sum()

            y_one_hot = inputs["y"]
            n_correct += (outputs["prediction_one_hot"] * y_one_hot).sum(axis=-1).sum()

            n_examples_in_batch = util.list_prod(self.flow.get_batch_shape(inputs))
            total_examples += n_examples_in_batch
    except StopIteration:
        pass

    nll = -sum_log_px / total_examples
    self.test_losses[self.n_train_steps] = nll

    acc = n_correct / total_examples

    if bits_per_dim:
        nll = self.flow.to_bits_per_dim(nll)

    return nll, (acc, nll)
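# How the n_correct accumulation above works: prediction_one_hot and y_one_hot
# are both one-hot over the class axis, so their elementwise product sums to 1
# along that axis for a correct prediction and 0 otherwise. Toy 3-class example:
import jax.numpy as jnp

pred_one_hot = jnp.array([[0., 1., 0.],
                          [1., 0., 0.]])
true_one_hot = jnp.array([[0., 1., 0.],
                          [0., 0., 1.]])
n_correct = (pred_one_hot * true_one_hot).sum(axis=-1).sum()  # 1.0: only the first example is correct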
def generate_grid_indices(shape, rng):
    total_dim = util.list_prod(shape)

    # Generate the indices for each pixel
    idx = jnp.arange(4).tile((total_dim, 1))

    # Shuffle the indices. random.permutation doesn't accept an axis argument for some reason.
    rngs = random.split(rng, total_dim)
    idx = vmap(random.permutation)(rngs, idx)

    # Separate into the max and non-max indices
    max_idx = idx[..., 0].reshape(shape)
    non_max_idx = idx[..., 1:]
    non_max_idx = non_max_idx.sort(axis=-1)
    non_max_idx = non_max_idx.reshape(shape + (3,))
    return max_idx, non_max_idx
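# Example usage (hypothetical shape): one shuffled index per output location.
# For an (8, 8, 3) output, max_idx marks which of the 4 candidates holds the
# max and non_max_idx lists the other 3 in sorted order.
from jax import random

rng = random.PRNGKey(0)
max_idx, non_max_idx = generate_grid_indices((8, 8, 3), rng)
assert max_idx.shape == (8, 8, 3)
assert non_max_idx.shape == (8, 8, 3, 3)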
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray = None,
         sample: Optional[bool] = False,
         **kwargs) -> Mapping[str, jnp.ndarray]:
    outputs = {}

    if sample == False:
        outputs["x"] = inputs["x"] / self.scale
    else:
        outputs["x"] = inputs["x"] * self.scale

    shape = self.get_unbatched_shapes(sample)["x"]
    outputs["log_det"] = jnp.ones(self.batch_shape)
    outputs["log_det"] *= -jnp.log(self.scale) * util.list_prod(shape)
    return outputs
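# Quick self-contained check of the log_det above (hypothetical scale): the
# elementwise map x -> x / s over D entries has Jacobian diag(1/s), so
# log|det J| = -D * log(s), matching -jnp.log(self.scale) * prod(shape).
import jax
import jax.numpy as jnp

s = 2.0
x = jnp.arange(1.0, 5.0)                  # D = 4
J = jax.jacobian(lambda x: x / s)(x)      # 4x4 diagonal Jacobian
assert jnp.allclose(jnp.linalg.slogdet(J)[1], -x.size * jnp.log(s))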
def call(self, inputs: Mapping[str, jnp.ndarray], rng: jnp.ndarray = None, sample: Optional[bool] = False, **kwargs) -> Mapping[str, jnp.ndarray]: x = inputs["x"] if sample == False: unbatched_shape = self.unbatched_input_shapes["x"] flat_dim = util.list_prod(unbatched_shape) flat_shape = self.batch_shape + (flat_dim, ) z = x.reshape(flat_shape) else: original_shape = self.batch_shape + self.unbatched_input_shapes["x"] z = x.reshape(original_shape) outputs = {"x": z, "log_det": jnp.zeros(self.batch_shape)} return outputs
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: PRNGKey,
         sample: Optional[bool] = False,
         **kwargs) -> Mapping[str, jnp.ndarray]:
    # Create the decoder
    decoder = self.default_decoder() if self.decoder is None else self.decoder()

    if sample == False:
        x = inputs["x"]

        # Get the max and non-max elements
        max_elts, non_max_elts, max_idx, non_max_idx = self.auto_batch(extract_max_elts)(x)

        # See how likely these non-max elements are. Condition on values and indices
        # so that the decoder has context on what to generate.
        cond = jnp.concatenate([max_elts[..., None], non_max_idx], axis=-1)
        cond = cond.reshape(cond.shape[:-2] + (-1,))
        decoder_inputs = {"x": non_max_elts, "condition": cond}
        decoder_outputs = decoder(decoder_inputs, rng, sample=False)
        log_qzgx = decoder_outputs["log_pz"] + decoder_outputs.get("log_det", jnp.array(0.0))

        # We are assuming a uniform distribution for the order of the indices
        log_qkgx = -jnp.log(4) * max_elts.size

        log_contribution = log_qzgx + log_qkgx
        outputs = {"x": max_elts, "log_det": log_contribution}
    else:
        max_elts = inputs["x"]
        max_elts_shape = self.get_unbatched_shapes(sample)["x"]
        max_elts_size = util.list_prod(max_elts_shape)
        rng1, rng2 = random.split(rng, 2)

        # Sample the max indices from q(k|x)
        n_keys = util.list_prod(self.batch_shape)
        rngs = random.split(rng1, n_keys).reshape(self.batch_shape + (-1,))
        max_idx, non_max_idx = self.auto_batch(partial(generate_grid_indices, max_elts_shape))(rngs)
        log_qkgx = -jnp.log(4) * max_elts_size

        # Sample the non-max elements from the decoder
        H, W, C = max_idx.shape[-3:]
        cond = jnp.concatenate([max_elts[..., None], non_max_idx], axis=-1)
        cond = cond.reshape(cond.shape[:-2] + (-1,))
        decoder_inputs = {"x": jnp.zeros(self.batch_shape + (H, W, 3 * C)),
                          "condition": cond}
        decoder_outputs = decoder(decoder_inputs, rng2, sample=True)
        non_max_elts = decoder_outputs["x"]
        log_qzgx = decoder_outputs["log_pz"] + decoder_outputs.get("log_det", jnp.array(0.0))

        # Combine the max elements with the non-max elements
        x = self.auto_batch(contruct_from_max_elts)(max_elts, non_max_elts, max_idx, non_max_idx)

        log_contribution = log_qzgx + log_qkgx
        outputs = {"x": x, "log_det": log_contribution}

    return outputs
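# NOTE: extract_max_elts and contruct_from_max_elts are not defined in this
# section. Below is a minimal sketch of the forward helper under the
# assumption (suggested by the (H, W, 3 * C) decoder shape above) that the
# unbatched input is a (2H, 2W, C) image split into 2x2 spatial blocks, with
# the block max kept and the 3 remaining elements handed to the decoder.
import jax.numpy as jnp

def extract_max_elts(x):
    H2, W2, C = x.shape
    H, W = H2 // 2, W2 // 2

    # Group each 2x2 spatial block into a trailing axis of size 4
    blocks = x.reshape(H, 2, W, 2, C).transpose(0, 2, 4, 1, 3).reshape(H, W, C, 4)

    # Index and value of the max element within each block
    max_idx = blocks.argmax(axis=-1)
    max_elts = jnp.take_along_axis(blocks, max_idx[..., None], axis=-1)[..., 0]

    # The three non-max indices in sorted order (the max position sorts last)
    all_idx = jnp.arange(4)
    masked = jnp.where(all_idx == max_idx[..., None], 4, all_idx)
    non_max_idx = jnp.sort(masked, axis=-1)[..., :3]
    non_max_elts = jnp.take_along_axis(blocks, non_max_idx, axis=-1)

    return max_elts, non_max_elts, max_idx, non_max_idx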
def output_dim(self):
    if hasattr(self, "_output_dim"):
        return self._output_dim
    return util.list_prod(self.output_shape)
def input_dim(self):
    if hasattr(self, "_input_dim"):
        return self._input_dim
    return util.list_prod(self.input_shape)
def to_bits_per_dim(self, log_likelihood):
    return log_likelihood / util.list_prod(self.data_shape) / jnp.log(2)
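# Worked example of the conversion above: dividing by the data dimension
# gives nats per dimension, and dividing by log(2) converts nats to bits.
# Numbers here are hypothetical, for 32x32x3 images.
import jax.numpy as jnp

log_likelihood = -3500.0
dim = 32 * 32 * 3
bits_per_dim = log_likelihood / dim / jnp.log(2)  # ~ -1.64, i.e. about 1.64 bits/dim as an NLL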
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: PRNGKey,
         sample: Optional[bool] = False,
         reconstruction: Optional[bool] = False,
         **kwargs) -> Mapping[str, jnp.ndarray]:
    x = inputs["x"]
    outputs = {}

    x_shape = self.get_unbatched_shapes(sample)["x"]
    sum_axes = tuple(-jnp.arange(1, 1 + len(x_shape)))
    x_flat = x.reshape(self.batch_shape + (-1,))
    y = inputs.get("y", jnp.ones(self.batch_shape, dtype=jnp.int32) * -1)

    # Keep these fixed. Learning doesn't make much difference apparently.
    means = hk.get_state("means",
                         shape=(self.n_classes, x_flat.shape[-1]),
                         dtype=x.dtype,
                         init=hk.initializers.RandomNormal())
    log_diag_covs = hk.get_state("log_diag_covs",
                                 shape=(self.n_classes, x_flat.shape[-1]),
                                 dtype=x.dtype,
                                 init=jnp.zeros)

    @partial(jax.vmap, in_axes=(0, 0, None))
    def diag_gaussian(mean, log_diag_cov, x_flat):
        # Log pdf of a Gaussian with diagonal covariance exp(log_diag_cov)
        dx = x_flat - mean
        log_pdf = jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
        log_pdf += log_diag_cov.sum()
        log_pdf += x_flat.size * jnp.log(2 * jnp.pi)
        return -0.5 * log_pdf

    log_pdfs = self.auto_batch(partial(diag_gaussian, means, log_diag_covs))(x_flat)

    # # Compute the log pdfs of each mixture component
    # normal = dists.Normal(means, jnp.exp(log_diag_covs))
    # log_pdfs = self.auto_batch(normal.log_prob)(x_flat)
    # log_pdfs = log_pdfs.sum(axis=-1)

    if sample == False:
        # Compute p(x,y) = p(x|y)p(y) if we have a label, p(x) otherwise.
        # p(y) is uniform, so log p(y) = -log(n_classes).
        def log_prob(y, log_pdfs):
            return jax.lax.cond(y >= 0,
                                lambda a: log_pdfs[y] - jnp.log(self.n_classes),
                                lambda a: logsumexp(log_pdfs) - jnp.log(self.n_classes),
                                None)

        outputs["log_pz"] = self.auto_batch(log_prob)(y, log_pdfs)
        outputs["x"] = x
    else:
        if reconstruction:
            outputs = {"x": x, "log_pz": jnp.array(0.0)}
        else:
            # Sample from all of the clusters
            # xs = normal.sample(rng)
            xs = random.normal(rng, x_flat.shape)

            def sample(log_pdfs, y, rng):
                def no_label(y):
                    y = random.randint(rng, minval=0, maxval=self.n_classes, shape=(1,))[0]
                    # y = dists.CategoricalLogits(jnp.zeros(self.n_classes)).sample(rng, (1,))[0]
                    return y, logsumexp(log_pdfs) - jnp.log(self.n_classes)

                def with_label(y):
                    return y, log_pdfs[y] - jnp.log(self.n_classes)

                # Either sample or use a specified cluster
                return jax.lax.cond(y < 0, no_label, with_label, y)

            n_keys = util.list_prod(self.batch_shape)
            rngs = random.split(rng, n_keys).reshape(self.batch_shape + (-1,))
            y, log_pz = self.auto_batch(sample)(log_pdfs, y, rngs)

            # Take a specific cluster
            outputs = {"x": xs[y].reshape(x.shape), "log_pz": log_pz}

    outputs["prediction"] = jnp.argmax(log_pdfs, axis=-1)
    return outputs
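# Self-contained sanity check of the diag_gaussian math above: the quadratic
# form plus log-determinant terms should match the sum of independent
# univariate normal log pdfs with variance exp(log_diag_cov).
import jax.numpy as jnp
from jax import random
from jax.scipy.stats import norm

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
mean = random.normal(k1, (5,))
log_diag_cov = random.normal(k2, (5,))
x_flat = random.normal(k3, (5,))

dx = x_flat - mean
log_pdf = -0.5 * (jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
                  + log_diag_cov.sum()
                  + x_flat.size * jnp.log(2 * jnp.pi))
ref = norm.logpdf(x_flat, mean, jnp.exp(0.5 * log_diag_cov)).sum()
assert jnp.allclose(log_pdf, ref, atol=1e-5)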