Example #1
    def get_test_ds():
        batch_shape = (batch_size, ) if n_batches is None else (n_batches,
                                                                batch_size)

        big_batch_size = util.list_prod(batch_shape)
        start_idx = n_train
        get_end_idx = lambda start: start + big_batch_size

        while True:
            end_idx = get_end_idx(start_idx)

            data_batch = data[start_idx:end_idx]
            if n_batches is not None:
                data_batch = rebatch(data_batch, batch_size, n_batches)
            inputs = {"x": data_batch}

            if classification and labels is not None:
                y = labels[start_idx:end_idx]
                if n_batches is not None:
                    y = rebatch(y, batch_size, n_batches)
                y_one_hot = (y[..., None] == range_classes[..., :]) * 1.0

                inputs["y"] = y_one_hot

            yield inputs

            start_idx += big_batch_size

            if start_idx >= data.shape[0]:
                # Stop iterating
                return
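
Every example on this page leans on util.list_prod, whose source is not shown here; from the way it is applied to shape tuples, it presumably just multiplies the entries of a sequence together. A minimal sketch under that assumption (not the library's actual implementation):

    # Assumed behavior of util.list_prod: the product of the entries of a
    # list or tuple, e.g. a shape, with 1 returned for an empty sequence.
    from functools import reduce
    from operator import mul

    def list_prod(xs):
        return reduce(mul, xs, 1)

    print(list_prod((8, 4)))   # 32, e.g. a (n_batches, batch_size) shape
    print(list_prod(()))       # 1
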
Example #2
  def evaluate_test(self,
                    key: PRNGKey,
                    input_iterator,
                    **kwargs):

    sum_log_px = 0.0
    total_examples = 0

    try:
      while True:
        key, test_key = random.split(key, 2)
        inputs = next(input_iterator)
        outputs = self.flow.scan_apply(test_key, inputs, **kwargs)

        # Accumulate the total sum of the log likelihoods in case the batch sizes differ
        sum_log_px += outputs["log_px"].sum()
        n_examples_in_batch = util.list_prod(self.flow.get_batch_shape(inputs))
        total_examples += n_examples_in_batch

    except StopIteration:
      pass

    nll = -sum_log_px/total_examples
    self.test_losses[self.n_train_steps] = nll
    return nll
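
Example #2 consumes an iterator like the one in Example #1: it keeps pulling batches until the generator raises StopIteration, accumulates the summed log-likelihood rather than per-batch means so that a smaller final batch is weighted correctly, and uses util.list_prod on the batch shape to count how many examples each batch actually contained before dividing at the end.
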
Example #3
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]

        if sample == False:
            out_dim = util.list_prod(self.output_shape)
            expected_out_dim = util.list_prod(self.unbatched_input_shapes["x"])
            assert out_dim == expected_out_dim, f"Dimension mismatch: {out_dim} != {expected_out_dim}"

            z = x.reshape(self.batch_shape + self.output_shape)
        else:
            original_shape = self.batch_shape + self.unbatched_input_shapes["x"]
            z = x.reshape(original_shape)

        outputs = {"x": z, "log_det": jnp.zeros(self.batch_shape)}
        return outputs
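
The log_det of zero in Example #3 (and in the flattening version in Example #7 below) reflects the fact that a reshape only re-indexes values. A small standalone check of that claim, independent of the library:

    # A reshape has an identity Jacobian, so log|det| = 0.
    import jax
    import jax.numpy as jnp

    f = lambda x: x.reshape(2, 3)
    J = jax.jacobian(f)(jnp.arange(6.0)).reshape(6, 6)
    print(jnp.linalg.slogdet(J)[1])   # 0.0
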
Example #4
    def evaluate_test(self,
                      key: PRNGKey,
                      input_iterator,
                      bits_per_dim: bool = False,
                      **kwargs):

        sum_log_px = 0.0
        total_examples = 0
        n_correct = 0

        try:
            while True:
                key, test_key = random.split(key, 2)
                inputs = next(input_iterator)

                # If we're using a residual flow, catch up estimating the singular values
                # if our estimate is bad
                if total_examples == 0:
                    inputs_batched = jax.tree_map(lambda x: x[0], inputs)
                    self.flow.apply(test_key,
                                    inputs_batched,
                                    is_training=True,
                                    force_update_params=True,
                                    **kwargs)

                outputs = self.flow.scan_apply(test_key,
                                               inputs,
                                               is_training=False,
                                               **kwargs)

                # Accumulate the total sum of the log likelihoods in case the batch sizes differ
                sum_log_px += outputs["log_px"].sum()
                y_one_hot = inputs["y"]
                n_correct += (outputs["prediction_one_hot"] *
                              y_one_hot).sum(axis=-1).sum()
                n_examples_in_batch = util.list_prod(
                    self.flow.get_batch_shape(inputs))
                total_examples += n_examples_in_batch

        except StopIteration:
            pass

        nll = -sum_log_px / total_examples
        self.test_losses[self.n_train_steps] = nll

        acc = n_correct / total_examples

        if bits_per_dim:
            nll = self.flow.to_bits_per_dim(nll)

        return nll, (acc, nll)
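
The accuracy accumulation above works because the product of two one-hot vectors, summed over the class axis, is 1 exactly when they select the same class. A standalone illustration with made-up arrays:

    import jax.numpy as jnp

    prediction_one_hot = jnp.array([[0., 1., 0.],
                                    [1., 0., 0.]])
    y_one_hot = jnp.array([[0., 1., 0.],
                           [0., 0., 1.]])
    # The first prediction matches its label, the second does not.
    print((prediction_one_hot * y_one_hot).sum(axis=-1).sum())   # 1.0
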
Example #5
def generate_grid_indices(shape, rng):
    total_dim = util.list_prod(shape)

    # Generate the indices for each pixel
    idx = jnp.arange(4).tile((total_dim, 1))

    # Shuffle the indices.  random.permutation doesn't accept an axis argument for some reason.
    rngs = random.split(rng, total_dim)
    idx = vmap(random.permutation)(rngs, idx)

    # Separate into the max and non-max indices
    max_idx = idx[..., 0].reshape(shape)
    non_max_idx = idx[..., 1:]
    non_max_idx = non_max_idx.sort(axis=-1)
    non_max_idx = non_max_idx.reshape(shape + (3, ))

    return max_idx, non_max_idx
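
The split-then-vmap pattern above (also used later in Examples #8 and #12 to get one PRNG key per batch element) is worth isolating. A standalone sketch that shuffles each row of a matrix independently:

    import jax.numpy as jnp
    from jax import random, vmap

    rng = random.PRNGKey(0)
    idx = jnp.tile(jnp.arange(4), (5, 1))        # 5 rows of [0, 1, 2, 3]
    rngs = random.split(rng, 5)                  # one key per row
    shuffled = vmap(random.permutation)(rngs, idx)
    print(shuffled.shape)                        # (5, 4)
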
Example #6
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        outputs = {}

        if sample == False:
            outputs["x"] = inputs["x"] / self.scale
        else:
            outputs["x"] = inputs["x"] * self.scale

        shape = self.get_unbatched_shapes(sample)["x"]
        outputs["log_det"] = jnp.ones(self.batch_shape)
        outputs["log_det"] *= -jnp.log(self.scale) * util.list_prod(shape)

        return outputs
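
The log_det in Example #6 is -log(scale) times the number of elements because dividing every element of a d-dimensional tensor by scale gives a diagonal Jacobian with d entries of 1/scale. A quick numerical check, independent of the library:

    import jax
    import jax.numpy as jnp

    s, d = 2.0, 6
    f = lambda x: x / s
    J = jax.jacobian(f)(jnp.ones(d))
    print(jnp.linalg.slogdet(J)[1])   # -d * log(s)
    print(-d * jnp.log(s))
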
Example #7
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]

        if sample == False:
            unbatched_shape = self.unbatched_input_shapes["x"]
            flat_dim = util.list_prod(unbatched_shape)
            flat_shape = self.batch_shape + (flat_dim, )
            z = x.reshape(flat_shape)
        else:
            original_shape = self.batch_shape + self.unbatched_input_shapes["x"]
            z = x.reshape(original_shape)

        outputs = {"x": z, "log_det": jnp.zeros(self.batch_shape)}
        return outputs
Example #8
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:

        # Create the decoder
        decoder = self.default_decoder() if self.decoder is None else self.decoder()

        if sample == False:
            x = inputs["x"]

            # Get the max and non-max elements
            max_elts, non_max_elts, max_idx, non_max_idx = self.auto_batch(
                extract_max_elts)(x)

            # See how likely these non-max elements are.  Condition on values and indices
            # so that the decoder has context on what to generate.
            cond = jnp.concatenate([max_elts[..., None], non_max_idx], axis=-1)
            cond = cond.reshape(cond.shape[:-2] + (-1, ))
            decoder_inputs = {"x": non_max_elts, "condition": cond}
            decoder_outputs = decoder(decoder_inputs, rng, sample=False)
            log_qzgx = decoder_outputs["log_pz"] + decoder_outputs.get(
                "log_det", jnp.array(0.0))

            # We are assuming a uniform distribution for the order of the indices
            log_qkgx = -jnp.log(4) * max_elts.size

            log_contribution = log_qzgx + log_qkgx
            outputs = {"x": max_elts, "log_det": log_contribution}

        else:
            max_elts = inputs["x"]
            max_elts_shape = self.get_unbatched_shapes(sample)["x"]
            max_elts_size = util.list_prod(max_elts_shape)
            rng1, rng2 = random.split(rng, 2)

            # Sample the max indices from q(k|x)
            n_keys = util.list_prod(self.batch_shape)
            rngs = random.split(rng1,
                                n_keys).reshape(self.batch_shape + (-1, ))
            max_idx, non_max_idx = self.auto_batch(
                partial(generate_grid_indices, max_elts_shape))(rngs)
            log_qkgx = -jnp.log(4) * max_elts_size

            # Sample the non-max indices
            H, W, C = max_idx.shape[-3:]
            cond = jnp.concatenate([max_elts[..., None], non_max_idx], axis=-1)
            cond = cond.reshape(cond.shape[:-2] + (-1, ))
            decoder_inputs = {
                "x": jnp.zeros(self.batch_shape + (H, W, 3 * C)),
                "condition": cond
            }
            decoder_outputs = decoder(decoder_inputs, rng2, sample=True)
            non_max_elts = decoder_outputs["x"]
            log_qzgx = decoder_outputs["log_pz"] + decoder_outputs.get(
                "log_det", jnp.array(0.0))

            # Combine the max elements with the non-max elements
            x = self.auto_batch(contruct_from_max_elts)(max_elts, non_max_elts,
                                                        max_idx, non_max_idx)

            log_contribution = log_qzgx + log_qkgx
            outputs = {"x": x, "log_det": log_contribution}

        return outputs
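
Example #8 is the companion of Example #5: each location has four candidate positions for its maximum, and the index of the maximum is treated as uniformly distributed over them, which is where the -jnp.log(4) * size terms for log q(k|x) come from; the remaining non-max values are modeled by a conditional decoder whose log-likelihood is folded into log_det.
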
Example #9
 def output_dim(self):
     if hasattr(self, "_output_dim"):
         return self._output_dim
     return util.list_prod(self.output_shape)
Example #10
 def input_dim(self):
     if hasattr(self, "_input_dim"):
         return self._input_dim
     return util.list_prod(self.input_shape)
Example #11
 def to_bits_per_dim(self, log_likelihood):
   return log_likelihood/util.list_prod(self.data_shape)/jnp.log(2)
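
The conversion in Example #11 divides by the number of data dimensions and by log 2. With made-up numbers: an average negative log-likelihood of 5000 nats per image on 32x32x3 data works out to about 2.35 bits per dimension.

    import math
    nll_nats = 5000.0            # hypothetical per-image value
    n_dims = 32 * 32 * 3         # util.list_prod of a (32, 32, 3) data_shape
    print(nll_nats / n_dims / math.log(2))   # ~2.35
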
Example #12
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             sample: Optional[bool] = False,
             reconstruction: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]
        outputs = {}
        x_shape = self.get_unbatched_shapes(sample)["x"]
        sum_axes = tuple(-jnp.arange(1, 1 + len(x_shape)))
        x_flat = x.reshape(self.batch_shape + (-1, ))
        y = inputs.get("y", jnp.ones(self.batch_shape, dtype=jnp.int32) * -1)

        # Keep these fixed.  Learning doesn't make much difference apparently.
        means = hk.get_state("means",
                             shape=(self.n_classes, x_flat.shape[-1]),
                             dtype=x.dtype,
                             init=hk.initializers.RandomNormal())
        log_diag_covs = hk.get_state("log_diag_covs",
                                     shape=(self.n_classes, x_flat.shape[-1]),
                                     dtype=x.dtype,
                                     init=jnp.zeros)

        @partial(jax.vmap, in_axes=(0, 0, None))
        def diag_gaussian(mean, log_diag_cov, x_flat):
            dx = x_flat - mean
            log_pdf = jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
            log_pdf += log_diag_cov.sum()
            log_pdf += x_flat.size * jnp.log(2 * jnp.pi)
            return -0.5 * log_pdf

        log_pdfs = self.auto_batch(partial(diag_gaussian, means,
                                           log_diag_covs))(x_flat)

        # # Compute the log pdfs of each mixture component
        # normal = dists.Normal(means, jnp.exp(log_diag_covs))
        # log_pdfs = self.auto_batch(normal.log_prob)(x_flat)
        # log_pdfs = log_pdfs.sum(axis=-1)

        if sample == False:
            # Compute p(x,y) = p(x|y)p(y) if we have a label, p(x) otherwise
            def log_prob(y, log_pdfs):
                return jax.lax.cond(
                    y >= 0, lambda a: log_pdfs[y] + jnp.log(self.n_classes),
                    lambda a: logsumexp(log_pdfs) - jnp.log(self.n_classes),
                    None)

            outputs["log_pz"] = self.auto_batch(log_prob)(y, log_pdfs)
            outputs["x"] = x

        else:
            if reconstruction:
                outputs = {"x": x, "log_pz": jnp.array(0.0)}
            else:
                # Sample from all of the clusters
                # xs = normal.sample(rng)
                xs = random.normal(rng, x_flat.shape)

                def sample(log_pdfs, y, rng):
                    def no_label(y):
                        y = random.randint(rng,
                                           minval=0,
                                           maxval=self.n_classes,
                                           shape=(1, ))[0]
                        # y = dists.CategoricalLogits(jnp.zeros(self.n_classes)).sample(rng, (1,))[0]
                        return y, logsumexp(log_pdfs) - jnp.log(self.n_classes)

                    def with_label(y):
                        return y, log_pdfs[y] - jnp.log(self.n_classes)

                    # Either sample or use a specified cluster
                    return jax.lax.cond(y < 0, no_label, with_label, y)

                n_keys = util.list_prod(self.batch_shape)
                rngs = random.split(rng,
                                    n_keys).reshape(self.batch_shape + (-1, ))
                y, log_pz = self.auto_batch(sample)(log_pdfs, y, rngs)

                # Take a specific cluster
                outputs = {"x": xs[y].reshape(x.shape), "log_pz": log_pz}

        outputs["prediction"] = jnp.argmax(log_pdfs)

        return outputs
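
The diag_gaussian helper in Example #12 is the standard diagonal-covariance Gaussian log-density. A standalone check against jax.scipy, with made-up values that are not part of the library:

    import jax.numpy as jnp
    from jax.scipy.stats import multivariate_normal

    mean = jnp.array([0.5, -1.0])
    log_diag_cov = jnp.array([0.1, -0.3])
    x = jnp.array([1.0, 2.0])

    dx = x - mean
    log_pdf = -0.5 * (jnp.dot(dx * jnp.exp(-log_diag_cov), dx)
                      + log_diag_cov.sum()
                      + x.size * jnp.log(2 * jnp.pi))

    ref = multivariate_normal.logpdf(x, mean, jnp.diag(jnp.exp(log_diag_cov)))
    print(log_pdf, ref)   # the two values should agree
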