Example #1
def get_knot_params(theta: jnp.ndarray,
                    K: int,
                    min_width: Optional[float] = 1e-3,
                    min_height: Optional[float] = 1e-3,
                    min_derivative: Optional[float] = 1e-3,
                    bounds: Sequence[Sequence[float]] = ((-3.0, 3.0), (-3.0, 3.0))):
    # Get the individual parameters
    tw, th, td = jnp.split(theta, jnp.array([K, 2 * K]), axis=-1)

    # Make the parameters fit the description of knots
    tw, th = jax.nn.softmax(tw), jax.nn.softmax(th)
    tw = min_width + (1.0 - min_width * K) * tw
    th = min_height + (1.0 - min_height * K) * th
    td = min_derivative + util.proximal_relu(td)
    # td = min_derivative + jax.nn.softplus(td)
    knot_x, knot_y = jnp.cumsum(tw, axis=-1), jnp.cumsum(th, axis=-1)

    # Pad the knots so that the first element is 0
    pad = [(0, 0)] * (len(td.shape) - 1) + [(1, 0)]
    knot_x = jnp.pad(knot_x, pad)
    knot_y = jnp.pad(knot_y, pad)

    # Scale by the bounds
    knot_x = (bounds[0][1] - bounds[0][0]) * knot_x + bounds[0][0]
    knot_y = (bounds[1][1] - bounds[1][0]) * knot_y + bounds[1][0]

    # Pad the derivatives so that the first and last elts are 1
    pad = [(0, 0)] * (len(td.shape) - 1) + [(1, 1)]
    knot_derivs = jnp.pad(td, pad, constant_values=1)

    return knot_x, knot_y, knot_derivs
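
A minimal usage sketch (not part of the library): it assumes get_knot_params is pasted into a module where util.proximal_relu is available, or where jax.nn.softplus stands in for it as the commented-out line suggests. Here theta packs K widths, K heights, and K-1 interior derivatives (3K - 1 values), so all three returned arrays have K+1 entries along the last axis.

import jax
import jax.numpy as jnp

K = 8
# K widths + K heights + (K - 1) interior derivatives = 3K - 1 parameters per spline
theta = jax.random.normal(jax.random.PRNGKey(0), (16, 3 * K - 1))
knot_x, knot_y, knot_derivs = get_knot_params(theta, K)
print(knot_x.shape, knot_y.shape, knot_derivs.shape)  # (16, 9) (16, 9) (16, 9)
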
Example #2
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: Optional[jnp.ndarray] = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        outputs = {}
        x = inputs["x"]
        x_shape = self.get_unbatched_shapes(sample)["x"]

        param_shape = tuple([x_shape[ax] for ax in self.axes])
        b = hk.get_parameter("b",
                             shape=param_shape,
                             dtype=x.dtype,
                             init=jnp.zeros)
        log_s = hk.get_parameter("log_s",
                                 shape=param_shape,
                                 dtype=x.dtype,
                                 init=jnp.zeros)

        s = util.proximal_relu(log_s) + 1e-5

        if not sample:
            outputs["x"] = (x - b) / s
        else:
            outputs["x"] = s * x + b

        log_det = -jnp.log(s).sum() * jnp.ones(self.batch_shape)
        outputs["log_det"] = log_det

        return outputs
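
For reference, a standalone sketch (mine, not the library's API) of the same elementwise affine bijector as pure functions, using jnp.exp in place of util.proximal_relu(log_s) + 1e-5, to show the forward/inverse round trip and the log-determinant term.

import jax.numpy as jnp

def forward(x, b, log_s):
    s = jnp.exp(log_s)                     # stand-in for util.proximal_relu(log_s) + 1e-5
    return (x - b) / s, -jnp.log(s).sum()  # z and its log determinant

def inverse(z, b, log_s):
    s = jnp.exp(log_s)
    return s * z + b

x, b, log_s = jnp.arange(4.0), jnp.ones(4), 0.5 * jnp.ones(4)
z, log_det = forward(x, b, log_s)
assert jnp.allclose(inverse(z, b, log_s), x)
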
Example #3
    def transform(self, x, params=None, sample=False, rng=None, **kwargs):
        # Remember that self.get_unbatched_shapes(sample)["x"] is NOT the shape of x here!
        # The x we see here is only half of the actual x!

        # Get the parameters of the transformation
        scale_init = hk.initializers.RandomNormal(stddev=0.01)
        if params is None:
            x_shape = x.shape[len(self.batch_shape):]
            if self.kind == "affine":
                log_s = hk.get_parameter("log_s",
                                         shape=x_shape,
                                         dtype=x.dtype,
                                         init=scale_init)
            t = hk.get_parameter("t",
                                 shape=x_shape,
                                 dtype=x.dtype,
                                 init=scale_init)

        else:
            if self.kind == "affine":
                scale_scale = hk.get_parameter("scale_scale",
                                               shape=(),
                                               dtype=x.dtype,
                                               init=scale_init)
            shift_scale = hk.get_parameter("shift_scale",
                                           shape=(),
                                           dtype=x.dtype,
                                           init=scale_init)

            # Split the output and bound the scaling term
            if self.kind == "affine":
                t, log_s = jnp.split(params, 2, axis=self.axis)
                log_s = util.constrain_log_scale(log_s)
            else:
                t = params

            # Scale the parameters so that we can initialize this function to the identity
            t = shift_scale * t
            if self.kind == "affine":
                log_s = scale_scale * log_s

        if self.kind == "affine":
            if self.safe_diag:
                s = util.proximal_relu(log_s) + 1e-6
                log_s = jnp.log(s)
            else:
                s = jnp.exp(log_s)

        # Evaluate the transformation
        if not sample:
            z = (x - t) / s if self.kind == "affine" else x - t
        else:
            z = x * s + t if self.kind == "affine" else x + t

        if self.kind == "affine":
            elementwise_log_det = jnp.broadcast_to(-log_s, x.shape)
        else:
            elementwise_log_det = jnp.zeros_like(x)

        return z, elementwise_log_det
Example #4
    def f(self, weight_logits, means, log_scales, x):
        if self.restrict_scales:
            if not self.safe_diag:
                log_scales = jnp.maximum(-7.0, log_scales)
            else:
                scales = util.proximal_relu(log_scales) + 1e-5
                log_scales = jnp.log(scales)

        return logistic_cdf_mixture_logit(weight_logits, means, log_scales, x)
Example #5
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: Optional[jnp.ndarray] = None,
           sample: Optional[bool] = False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    outputs = {}

    dim, dtype = inputs["x"].shape[-1], inputs["x"].dtype

    L     = hk.get_parameter("L", shape=(dim, dim), dtype=dtype, init=hk.initializers.RandomNormal(0.01))
    U     = hk.get_parameter("U", shape=(dim, dim), dtype=dtype, init=hk.initializers.RandomNormal(0.01))
    log_d = hk.get_parameter("log_d", shape=(dim,), dtype=dtype, init=jnp.zeros)
    lower_mask = jnp.ones((dim, dim), dtype=bool)
    lower_mask = lower_mask.at[jnp.triu_indices(dim)].set(False)  # jax.ops.index_update was removed from JAX

    if self.safe_diag:
      d = util.proximal_relu(log_d) + 1e-5
      log_d = jnp.log(d)

    def b_init(shape, dtype):
      x = inputs["x"]
      if x.ndim == 1:
        return jnp.zeros(shape, dtype=dtype)

      # Initialize to the batch mean
      z = jnp.dot(x, (U*lower_mask.T).T) + x
      z *= jnp.exp(log_d)
      z = jnp.dot(z, (L*lower_mask).T) + z
      b = -jnp.mean(z, axis=0)
      return b

    b = hk.get_parameter("b", shape=(dim,), dtype=dtype, init=b_init)

    # It's much faster to allocate full matrices for L and U and then mask them
    # than it is to allocate only the lower/upper parts and then reshape.
    if not sample:
      x = inputs["x"]
      z = jnp.dot(x, (U*lower_mask.T).T) + x
      z *= jnp.exp(log_d)
      z = jnp.dot(z, (L*lower_mask).T) + z
      outputs["x"] = z + b
    else:
      z = inputs["x"]

      @self.auto_batch
      def invert(z):
        x = L_solve(L, z - b)
        x = x*jnp.exp(-log_d)
        return U_solve(U, x)

      outputs["x"] = invert(z)

    outputs["log_det"] = jnp.sum(log_d, axis=-1)*jnp.ones(self.batch_shape)
    return outputs
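
A quick standalone check (mine, not library code) of why the layer above can return sum(log_d) as the log determinant: the masked L and U factors are unit triangular, so the determinant of W = (I + tril(L, -1)) @ diag(exp(log_d)) @ (I + triu(U, 1)) is just the product of the diagonal scaling terms.

import jax
import jax.numpy as jnp

dim = 5
kL, kU, kd = jax.random.split(jax.random.PRNGKey(0), 3)
L = jnp.tril(jax.random.normal(kL, (dim, dim)), k=-1)  # strictly lower triangular
U = jnp.triu(jax.random.normal(kU, (dim, dim)), k=1)   # strictly upper triangular
log_d = jax.random.normal(kd, (dim,))

W = (jnp.eye(dim) + L) @ jnp.diag(jnp.exp(log_d)) @ (jnp.eye(dim) + U)
_, logabsdet = jnp.linalg.slogdet(W)
assert jnp.allclose(logabsdet, log_d.sum(), atol=1e-4)
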
Example #6
    def f_and_elementwise_log_det(self, weight_logits, means, log_scales, x):
        if self.restrict_scales:
            if not self.safe_diag:
                log_scales = jnp.maximum(-7.0, log_scales)
            else:
                scales = util.proximal_relu(log_scales) + 1e-5
                log_scales = jnp.log(scales)

        primals = weight_logits, means, log_scales, x
        tangents = jax.tree_util.tree_map(jnp.zeros_like,
                                          primals[:-1]) + (jnp.ones_like(x),)
        z, dzdx = jax.jvp(logistic_cdf_mixture_logit, primals, tangents)

        elementwise_log_det = jnp.log(dzdx)
        return z, elementwise_log_det
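
The jvp call above is a general trick: pushing a tangent of ones through an elementwise map returns the elementwise derivative dz/dx alongside z. A standalone sketch with a toy monotone function (not the mixture CDF used by the library):

import jax
import jax.numpy as jnp

def f(x):
    return 2.0 * jnp.tanh(x) + x          # any smooth, strictly increasing elementwise map

x = jnp.linspace(-1.0, 1.0, 5)
z, dzdx = jax.jvp(f, (x,), (jnp.ones_like(x),))
elementwise_log_det = jnp.log(dzdx)       # valid because dz/dx > 0 here

# Agrees with per-element autodiff
assert jnp.allclose(dzdx, jax.vmap(jax.grad(f))(x))
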
Example #7
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             no_noise: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        network = self.get_generator_network()

        # network_in = inputs["x"]
        network_out = network(inputs, rng)["x"]
        # network_out = self.auto_batch(network, expected_depth=1, in_axes=(0, None))(network_in, rng)
        mu, log_diag_cov = jnp.split(network_out, 2, axis=-1)

        diag_cov = util.proximal_relu(log_diag_cov) + 1e-5
        log_diag_cov = jnp.log(diag_cov)

        x = mu
        if not no_noise:
            x += random.normal(rng, mu.shape) * jnp.sqrt(diag_cov)

        outputs = {"x": x, "mu": mu, "log_diag_cov": log_diag_cov}

        return outputs
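
The sampling step reduces to the usual reparameterization x = mu + sqrt(diag_cov) * eps. A minimal standalone sketch with illustrative names (not the library's API):

import jax
import jax.numpy as jnp

def sample_diag_gaussian(rng, mu, diag_cov, no_noise=False):
    if no_noise:
        return mu
    eps = jax.random.normal(rng, mu.shape)
    return mu + jnp.sqrt(diag_cov) * eps

x = sample_diag_gaussian(jax.random.PRNGKey(0), jnp.zeros(4), jnp.ones(4))
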
Example #8
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: PRNGKey,
             sample: Optional[bool] = False,
             reconstruction: Optional[bool] = False,
             is_training: bool = True,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]
        outputs = {}
        x_shape = self.get_unbatched_shapes(sample)["x"]

        # Make sure that we're using 1d inputs
        x = inputs["x"].reshape(self.batch_shape + (-1, ))
        x_dim = x.shape[-1]

        # Work with one-hot labels
        y_one_hot = inputs.get("y", None)
        if y_one_hot is not None:
            assert y_one_hot.shape == self.batch_shape + (self.n_classes, )
            y_one_hot *= 1.0
        else:
            if not sample:
                # Assign equal probability to each class
                y_one_hot = jnp.ones(self.batch_shape + (self.n_classes, ))
            else:
                # Sample class labels
                y = random.randint(rng,
                                   minval=0,
                                   maxval=self.n_classes,
                                   shape=self.batch_shape)
                y_one_hot = y[..., None] == jnp.arange(self.n_classes)[..., :]
                y_one_hot *= 1.0

        # GMM parameters.  Assume uniform mixture component weights so that things are differentiable.
        means = hk.get_parameter("means",
                                 shape=(self.n_classes, x_dim),
                                 dtype=x.dtype,
                                 init=hk.initializers.RandomNormal())
        log_diag_covs = hk.get_parameter("log_diag_covs",
                                         shape=(self.n_classes, x_dim),
                                         dtype=x.dtype,
                                         init=jnp.ones)
        diag_covs = util.proximal_relu(log_diag_covs) + 1e-3
        log_diag_covs = jnp.log(diag_covs)

        # Sample a new input
        if sample and not reconstruction:
            # Sample from all of the clusters
            noise = random.normal(rng,
                                  self.batch_shape + (self.n_classes, x_dim))
            xs = means + jnp.exp(0.5 * log_diag_covs) * noise

            # Select the mixture component
            x = xs * y_one_hot[..., None]
            x = x.sum(axis=-2)

        # Evaluate the log pdf for each mixture component
        @partial(jax.vmap, in_axes=(0, 0, None))
        def diag_gaussian(mean, log_diag_cov, x):
            dx = x - mean
            log_pdf = jnp.sum(dx**2 * jnp.exp(-log_diag_cov), axis=-1)
            log_pdf += log_diag_cov.sum()
            log_pdf += x_dim * jnp.log(2 * jnp.pi)
            return -0.5 * log_pdf

        # Last axis will be across the mixture components
        log_pdfs = self.auto_batch(partial(diag_gaussian, means,
                                           log_diag_covs))(x)

        # Make a class prediction
        y_pred = jnp.argmax(log_pdfs, axis=-1)
        y_pred_one_hot = (y_pred[..., None] == jnp.arange(self.n_classes)) * 1.0

        # Compute p(x,y) = p(x|y)p(y) if we have a label, p(x) otherwise.
        # If we have a label, zero out all but the label index then reduce.
        # Otherwise, reduce over all of the indices.
        if is_training:

            # Apply the label masks
            if "y_is_labeled" in inputs:
                y_is_labeled = inputs["y_is_labeled"][..., None].astype(bool)
                y_one_hot = (y_one_hot * y_is_labeled +
                             jnp.ones_like(y_one_hot) * (~y_is_labeled))

            log_pz = util.lse(log_pdfs, b=y_one_hot, axis=-1)
            # log_pz = logsumexp(log_pdfs, b=y_one_hot, axis=-1)
        else:
            # If we're doing classification, use the predicted label
            if "y" in inputs:
                log_pz = util.lse(log_pdfs, b=y_pred_one_hot, axis=-1)
            else:
                log_pz = logsumexp(log_pdfs, axis=-1)

        # Account for the uniform prior p(y) = 1/N (the 1/N in the mixture mean)
        log_pz -= jnp.log(self.n_classes)

        # p(y|x) is a categorical distribution
        log_pygx = jax.nn.log_softmax(log_pdfs)
        if is_training:
            log_pygx *= y_one_hot

            if "y_is_labeled" in inputs:
                # This time, zero out values that aren't labeled
                log_pygx *= y_is_labeled

        else:
            if "y" in inputs:
                log_pygx *= y_pred_one_hot

        log_pygx = log_pygx.sum(axis=-1)

        # Reshape the output
        x = x.reshape(self.batch_shape + x_shape)

        outputs = {"x": x, "log_pz": log_pz, "log_pygx": log_pygx}
        outputs["prediction"] = y_pred
        outputs["prediction_one_hot"] = outputs["prediction"][
            ..., None] == jnp.arange(self.n_classes)[..., :]
        return outputs
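
As a sanity check on the diag_gaussian helper above (my check, not part of the library), its output matches jax.scipy.stats.multivariate_normal.logpdf with a diagonal covariance:

import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

x_dim = 3
kx, km, kc = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (x_dim,))
mean = jax.random.normal(km, (x_dim,))
log_diag_cov = jax.random.normal(kc, (x_dim,))

def diag_gaussian(mean, log_diag_cov, x):
    dx = x - mean
    log_pdf = jnp.sum(dx**2 * jnp.exp(-log_diag_cov), axis=-1)
    log_pdf += log_diag_cov.sum()
    log_pdf += x_dim * jnp.log(2 * jnp.pi)
    return -0.5 * log_pdf

ref = multivariate_normal.logpdf(x, mean, jnp.diag(jnp.exp(log_diag_cov)))
assert jnp.allclose(diag_gaussian(mean, log_diag_cov, x), ref, atol=1e-4)
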