Example #1
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: jnp.ndarray=None,
           sample: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    x = inputs["x"]
    x_shape = self.get_unbatched_shapes(sample)["x"]
    sum_axes = util.last_axes(x_shape)

    if sample == False:
      z = jax.nn.sigmoid(x)

      if self.has_scale == True:
        z -= self.scale
        z /= 1.0 - 2*self.scale

      log_det = -(jax.nn.softplus(x) + jax.nn.softplus(-x))
    else:
      if self.has_scale == True:
        x *= 1.0 - 2*self.scale
        x += self.scale

      z = jax.scipy.special.logit(x)
      log_det = -(jax.nn.softplus(z) + jax.nn.softplus(-z))

    if self.has_scale == True:
      log_det -= jnp.log(1.0 - 2*self.scale)

    log_det = log_det.sum(axis=sum_axes)*jnp.ones(self.batch_shape)

    outputs = {"x": z, "log_det": log_det}
    return outputs
Example #2
def log_det_sliced_estimate(apply_fun,
                            params,
                            state,
                            x,
                            rng,
                            batch_info):
  trace_key, roulette_key = random.split(rng, 2)

  # Evaluate the flow and get the vjp function
  gx, state = apply_fun(params, state, x, rng, update_params=True)
  _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x, rng, update_params=False), x, has_aux=True)
  z = x + gx

  # Generate the probe vector for the trace estimate
  v = random.normal(trace_key, x.shape)

  # Get all of the terms we need for the log det and gradient estimates
  terms = unbiased_neumann_vjp_terms(vjp_fun, v, roulette_key, n_terms=7, n_exact=4)

  # Rescale the terms and sum over k (starting at k=1)
  cut_terms = terms[1:]
  log_det_coeff = -1/(1 + jnp.arange(cut_terms.shape[0]))
  log_det_coeff = util.broadcast_to_first_axis(log_det_coeff, cut_terms.ndim)
  log_det_terms = log_det_coeff*cut_terms
  summed_log_det_terms = log_det_terms.sum(axis=0)

  # Compute the log det
  x_shape, batch_shape = batch_info
  log_det = jnp.sum(summed_log_det_terms*v, axis=util.last_axes(x_shape))

  return z, log_det, v, terms, state
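For context on what log_det_sliced_estimate computes: for a residual block z = x + g(x) with contractive g, log det(I + J_g) = sum_{k>=1} (-1)^(k+1) tr(J_g^k) / k. Each trace is estimated with a single Hutchinson probe v (tr(A) ~ E[v^T A v]), and unbiased_neumann_vjp_terms appears to debias the truncated series with Russian-roulette reweighting (hence roulette_key). A minimal, self-contained sketch of the same series with an exact trace and a dense Jacobian (toy g, no roulette debiasing):

import jax
import jax.numpy as jnp

def log_det_series_dense(g, x, n_terms=10):
  # Truncated power series log det(I + J) = sum_k (-1)^(k+1) tr(J^k) / k,
  # valid when the spectral radius of J is below 1; trace computed exactly here.
  J = jax.jacobian(g)(x)
  total, Jk = 0.0, jnp.eye(x.shape[0])
  for k in range(1, n_terms + 1):
    Jk = Jk @ J
    total += (-1.0)**(k + 1)*jnp.trace(Jk)/k
  return total

g = lambda x: 0.1*jnp.tanh(x)   # contractive residual branch (Lipschitz < 1)
x = jnp.ones(3)
approx = log_det_series_dense(g, x)
exact = jnp.linalg.slogdet(jnp.eye(3) + jax.jacobian(g)(x))[1]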
Example #3
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: jnp.ndarray=None,
           sample: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:

    outputs = {}
    x_shape = self.get_unbatched_shapes(sample)["x"]
    sum_axes = util.last_axes(x_shape)

    if sample == False:
      x = inputs["x"]
      sqrt_1px2 = jnp.sqrt(1 + x**2)
      z = (x + self.alpha*(sqrt_1px2 - 1))/(1 + self.alpha)
      outputs["x"] = z
    else:
      z = inputs["x"]
      alpha_sq = self.alpha**2
      b = (1 + self.alpha)*z + self.alpha
      x = (jnp.sqrt(alpha_sq*(1 + b**2 - alpha_sq)) - b)/(alpha_sq - 1)
      outputs["x"] = x
      sqrt_1px2 = jnp.sqrt(1 + x**2)

    log_det = jnp.log(1 + self.alpha*x/sqrt_1px2) - jnp.log(1 + self.alpha)
    log_det = log_det.sum(axis=sum_axes)*jnp.ones(self.batch_shape)

    outputs["log_det"] = log_det
    return outputs
Example #4
def sliced_estimate_fwd(apply_fun, params, state, x, rng, batch_info):
  z, log_det, v, terms, state = log_det_sliced_estimate(apply_fun, params, state, x, rng, batch_info)

  # Accumulate the terms we need for the gradient
  summed_terms_for_grad = terms.sum(axis=0)

  x_shape, batch_shape = batch_info
  batch_dim = len(batch_shape)
  sum_axes = util.last_axes(x_shape)

  # Compute dlogdet(I + J(x;theta))/dtheta
  def vjvp(params, unbatched_x, unbatched_summed_terms, unbatched_v):
    # Remember that apply_fun is autobatched and can automatically pad leading dims!
    if batch_dim > 0:
      _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x[None], rng, update_params=False), unbatched_x, has_aux=True)
      w, = vjp_fun(unbatched_summed_terms[None])
    else:
      _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x, rng, update_params=False), unbatched_x, has_aux=True)
      w, = vjp_fun(unbatched_summed_terms)
    return jnp.sum(w*unbatched_v)

  # vmap over the batch dimensions
  vmapped_vjvp = jax.grad(vjvp, argnums=(0, 1))
  for i in range(batch_dim):
    vmapped_vjvp = jax.vmap(vmapped_vjvp, in_axes=(None, 0, 0, 0))

  # Compute the vector Jacobian vector products to get the gradient estimate terms.
  dlogdet_dtheta, dlogdet_dx = vmapped_vjvp(params, x, summed_terms_for_grad, v)

  # Store off everything for the backward pass
  ctx = x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx
  return (z, log_det, state), ctx
Example #5
def estimate_fwd(apply_fun, params, state, x, rng, batch_info):
    z, log_det, terms, state = log_det_estimate(apply_fun, params, state, x,
                                                rng, batch_info)

    # Accumulate the terms we need for the gradient
    summed_terms_for_grad = terms.sum(axis=0)

    x_shape, batch_shape = batch_info
    sum_axes = util.last_axes(x_shape)

    # Compute dlogdet(I + J(x;theta))/dtheta
    def jjp(params, unbatched_x, unbatched_summed_terms):
        jac_fun = jax.jacobian(
            lambda x: apply_fun(params, state, x[None], rng)[0][0])
        J = jac_fun(unbatched_x)
        return jnp.trace(unbatched_summed_terms @ J)

    vmapped_grad_vjvp = jax.grad(jjp, argnums=(0, 1))
    for i in range(len(batch_shape)):
        vmapped_grad_vjvp = jax.vmap(vmapped_grad_vjvp, in_axes=(None, 0, 0))

    dlogdet_dtheta, dlogdet_dx = vmapped_grad_vjvp(params, x,
                                                   summed_terms_for_grad)

    ctx = x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx
    return (z, log_det, state), ctx
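Both estimate_fwd and sliced_estimate_fwd look like the forward half of a jax.custom_vjp pair: they return the primal outputs together with a ctx tuple of residuals for the backward pass. As a reminder of that wiring, here is a toy scalar example of the pattern (not the library's actual pairing):

import jax

@jax.custom_vjp
def square(x):
    return x*x

def square_fwd(x):
    # Forward pass: primal output plus the residuals (ctx) needed by the backward pass.
    return x*x, x

def square_bwd(ctx, g):
    # Backward pass: one cotangent per primal argument.
    x = ctx
    return (2*x*g,)

square.defvjp(square_fwd, square_bwd)
assert jax.grad(square)(3.0) == 6.0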
Example #6
def sliced_estimate_fwd(apply_fun, params, state, x, rng, batch_info):
    z, log_det, v, terms = log_det_sliced_estimate(apply_fun, params, state, x,
                                                   rng, batch_info)

    # Accumulate the terms we need for the gradient
    summed_terms_for_grad = terms.sum(axis=0)

    x_shape, batch_shape = batch_info
    sum_axes = util.last_axes(x_shape)

    # Compute dlogdet(I + J(x;theta))/dtheta
    def vjvp(params, unbatched_x, unbatched_summed_terms, unbatched_v):
        _, vjp_fun, _ = jax.vjp(
            lambda x: apply_fun(params, state, x[None], rng),
            unbatched_x,
            has_aux=True)
        w, = vjp_fun(unbatched_summed_terms[None])
        return jnp.sum(w * unbatched_v)

    vmapped_vjvp = jax.grad(vjvp, argnums=(0, 1))
    for i in range(len(batch_shape)):
        vmapped_vjvp = jax.vmap(vmapped_vjvp, in_axes=(None, 0, 0, 0))

    dlogdet_dtheta, dlogdet_dx = vmapped_vjvp(params, x, summed_terms_for_grad,
                                              v)

    ctx = x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx
    return (z, log_det), ctx
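The loop that re-wraps vmapped_vjvp above is the usual way to take a gradient with respect to shared parameters and per-example inputs at once, then map it over each batch dimension. A minimal illustration with a toy loss (hypothetical shapes):

import jax
import jax.numpy as jnp

def loss(params, x):
    return jnp.sum((params*x)**2)

# Gradient w.r.t. both arguments, vmapped over the batch axis of x only.
per_example_grads = jax.vmap(jax.grad(loss, argnums=(0, 1)), in_axes=(None, 0))
dparams, dx = per_example_grads(jnp.array([1.0, 2.0]), jnp.ones((5, 2)))
# dparams.shape == (5, 2): one parameter gradient per batch element; dx.shape == (5, 2)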
Example #7
  def transform(self, x, params=None, sample=False, mask=None):
    # Remember that self.get_unbatched_shapes(sample)["x"] is NOT the shape of x here!
    # The x we see here is only half of the actual x!

    # Get the parameters of the transformation
    scale_init = hk.initializers.RandomNormal(stddev=0.01)
    if params is None:
      x_shape = x.shape[len(self.batch_shape):]
      if self.kind == "affine":
        log_s = hk.get_parameter("log_s", shape=x_shape, dtype=x.dtype, init=scale_init)
      t = hk.get_parameter("t", shape=x_shape, dtype=x.dtype, init=scale_init)

    else:
      if self.kind == "affine":
        scale_scale = hk.get_parameter("scale_scale", shape=(), dtype=x.dtype, init=scale_init)
      shift_scale = hk.get_parameter("shift_scale", shape=(), dtype=x.dtype, init=scale_init)

      # Split the output and bound the scaling term
      if self.kind == "affine":
        t, log_s = jnp.split(params, 2, axis=self.axis)
        log_s = util.constrain_log_scale(log_s)
      else:
        t = params

      # Scale the parameters so that we can initialize this function to the identity
      t = shift_scale*t
      if self.kind == "affine":
        log_s = scale_scale*log_s

    # Evaluate the transformation
    if sample == False:
      z = (x - t)*jnp.exp(-log_s) if self.kind == "affine" else x - t
    else:
      z = x*jnp.exp(log_s) + t if self.kind == "affine" else x + t

    # If we're doing mask coupling, need to correctly mask log_s before
    # computing the log determinant and also mask the output
    if mask is not None:
      z *= mask
      if self.kind == "affine":
        log_s *= mask

    # Compute the log determinant
    if self.kind == "affine":
      sum_axes = util.last_axes(x.shape[len(self.batch_shape):])
      log_det = -log_s.sum(axis=sum_axes)
    else:
      log_det = jnp.zeros(self.batch_shape)

    return z, log_det
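A quick numeric sanity check (toy values, not part of the library) of the affine-coupling log determinant used above: z = (x - t)*exp(-log_s) has a diagonal Jacobian, so log|det J| = -sum(log_s).

import jax
import jax.numpy as jnp

x = jnp.array([0.3, -1.2])
t = jnp.array([0.1, 0.2])
log_s = jnp.array([0.5, -0.4])
f = lambda x: (x - t)*jnp.exp(-log_s)
_, logdet = jnp.linalg.slogdet(jax.jacobian(f)(x))
assert jnp.allclose(logdet, -log_s.sum())   # matches log_det = -log_s.sum(axis=sum_axes)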
Example #8
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]
        outputs = {}

        x_shape = self.get_unbatched_shapes(sample)["x"]
        theta = hk.get_parameter("theta",
                                 shape=x_shape +
                                 (3 * self.n_components + self.extra, ),
                                 dtype=x.dtype,
                                 init=hk.initializers.RandomNormal(0.1))

        # Split the parameters
        if self.with_affine_coupling:
            in_axes = (0, None, None, None, None, None)
            out_axes = (None, None, None, None, None)
        else:
            in_axes = (0, None, None, None)
            out_axes = (None, None, None)

        params = self.split_theta(theta)
        init_fun = self.auto_batch(self.safe_init,
                                   in_axes=in_axes,
                                   out_axes=out_axes,
                                   expected_depth=1)
        params = init_fun(x, *params)

        # Run the transform
        if sample == False:
            z, elementwise_log_det = self.auto_batch(
                self.mixture_forward, in_axes=in_axes, expected_depth=1)(x, *params)
        else:
            z, elementwise_log_det = self.auto_batch(
                self.mixture_inverse, in_axes=in_axes, expected_depth=1)(x, *params)

        sum_axes = util.last_axes(self.unbatched_input_shapes["x"])
        log_det = elementwise_log_det.sum(sum_axes)
        outputs = {"x": z, "log_det": log_det}

        return outputs
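mixture_forward and mixture_inverse are not shown in this example. The standard technique they presumably implement is an elementwise mixture-of-logistics CDF, whose elementwise log det is the log of the mixture density. A hedged sketch for a single scalar (hypothetical helper, not NuX's own code):

import jax
import jax.numpy as jnp

def mixture_cdf_forward(x, logit_pi, mu, log_s):
    # x: scalar; logit_pi, mu, log_s: (n_components,) parameters for this coordinate.
    log_w = jax.nn.log_softmax(logit_pi)
    u = (x - mu)*jnp.exp(-log_s)
    z = jnp.sum(jnp.exp(log_w)*jax.nn.sigmoid(u))   # mixture CDF, lands in (0, 1)
    # log dz/dx = logsumexp_k [log w_k + log sigmoid'(u_k) - log_s_k]
    elementwise_log_det = jax.nn.logsumexp(
        log_w - jax.nn.softplus(u) - jax.nn.softplus(-u) - log_s)
    return z, elementwise_log_det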
Example #9
    def call(self,
             inputs: Mapping[str, jnp.ndarray],
             rng: jnp.ndarray = None,
             sample: Optional[bool] = False,
             **kwargs) -> Mapping[str, jnp.ndarray]:
        x = inputs["x"]
        x_shape = self.get_unbatched_shapes(sample)["x"]
        sum_axes = util.last_axes(x_shape)

        if sample == False:
            z = jnp.where(x > 0, x, self.alpha * x)
        else:
            z = jnp.where(x > 0, x, x / self.alpha)

        log_dx_dz = jnp.where(x > 0, 0, jnp.log(self.alpha))
        log_det = log_dx_dz.sum(axis=sum_axes) * jnp.ones(self.batch_shape)

        outputs = {"x": z, "log_det": log_det}
        return outputs
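A toy check (not part of the library) of the leaky-ReLU log determinant above: the forward slope is 1 for x > 0 and alpha otherwise, so each coordinate contributes 0 or log(alpha).

import jax
import jax.numpy as jnp

alpha = 0.3
f = lambda x: jnp.where(x > 0, x, alpha*x)
x = jnp.array([-2.0, 1.5])
_, logdet = jnp.linalg.slogdet(jax.jacobian(f)(x))
assert jnp.allclose(logdet, jnp.log(alpha))   # only the negative coordinate contributes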
Example #10
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: jnp.ndarray=None,
           sample: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    x_shape = self.get_unbatched_shapes(sample)["x"]
    sum_axes = util.last_axes(x_shape)

    if sample == False:
      x = inputs["x"]
      x = jnp.where(x < 0.0, 1e-5, x)
      dx = jnp.log1p(-jnp.exp(-x))
      z = x + dx
      log_det = -dx.sum(axis=sum_axes)*jnp.ones(self.batch_shape)
      outputs = {"x": z, "log_det": log_det}
    else:
      x = jax.nn.softplus(inputs["x"])
      log_det = -jnp.log1p(-jnp.exp(-x)).sum(axis=sum_axes)*jnp.ones(self.batch_shape)
      outputs = {"x": x, "log_det": log_det}

    return outputs
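The forward branch above computes z = x + log1p(-exp(-x)) = log(expm1(x)), i.e. the inverse of softplus, so dz/dx = 1/(1 - exp(-x)) and the elementwise log det is -log1p(-exp(-x)) in both branches. A toy check (not part of the library):

import jax
import jax.numpy as jnp

x = jnp.array([0.7, 2.0])
f = lambda x: x + jnp.log1p(-jnp.exp(-x))   # inverse softplus
_, logdet = jnp.linalg.slogdet(jax.jacobian(f)(x))
assert jnp.allclose(logdet, -jnp.log1p(-jnp.exp(-x)).sum())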
Example #11
  def call(self,
           inputs: Mapping[str, jnp.ndarray],
           rng: jnp.ndarray=None,
           sample: Optional[bool]=False,
           generate_image: Optional[bool]=False,
           **kwargs
  ) -> Mapping[str, jnp.ndarray]:
    x = inputs["x"]
    outputs = {}
    x_shape = self.get_unbatched_shapes(sample)["x"]
    sum_axes = util.last_axes(x_shape)

    if sample == False:
      if self.has_scale == True:
        x *= (1.0 - 2*self.scale)
        x += self.scale
      z = jax.scipy.special.logit(x)
      log_det = (jax.nn.softplus(z) + jax.nn.softplus(-z))
    else:
      z = jax.nn.sigmoid(x)
      log_det = (jax.nn.softplus(x) + jax.nn.softplus(-x))

      # If we are generating images, we want to pass the normalized image
      # to matplotlib!
      if generate_image:
        outputs["image"] = z

      if self.has_scale == True:
        z -= self.scale
        z /= (1.0 - 2*self.scale)

    if self.has_scale == True:
      log_det += jnp.log(1.0 - 2*self.scale)

    log_det = log_det.sum(axis=sum_axes)*jnp.ones(self.batch_shape)

    outputs["x"] = z
    outputs["log_det"] = log_det
    return outputs
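Example #11 is the logit counterpart of Example #1; both rely on the identity log sigmoid'(x) = -softplus(x) - softplus(-x). A toy check (not part of the library):

import jax
import jax.numpy as jnp

x = jnp.linspace(-3.0, 3.0, 7)
lhs = jnp.log(jax.vmap(jax.grad(jax.nn.sigmoid))(x))
rhs = -(jax.nn.softplus(x) + jax.nn.softplus(-x))
assert jnp.allclose(lhs, rhs, atol=1e-5)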
Example #12
File: spline.py  Project: jxzhangjhu/NuX
    def transform(self, x, params=None, sample=False, mask=None):
        x_flat = x.reshape(self.batch_shape + (-1, ))
        param_dim = (3 * self.K - 1)
        if params is None:
            x_shape = x_flat.shape[len(self.batch_shape):]
            theta = hk.get_parameter("theta",
                                     shape=(x_flat.shape[-1], ) +
                                     (param_dim, ),
                                     dtype=x_flat.dtype,
                                     init=hk.initializers.RandomNormal())
            in_axes = (None, 0)
        else:
            theta = params.reshape(self.batch_shape + (x_flat.shape[-1], ) +
                                   (param_dim, ))
            in_axes = (0, 0)

        if sample == False:
            z, ew_log_det = self.auto_batch(self.forward_spline,
                                            in_axes=in_axes)(theta, x_flat)
        else:
            z, ew_log_det = self.auto_batch(self.inverse_spline,
                                            in_axes=in_axes)(theta, x_flat)

        z = z.reshape(x.shape)
        ew_log_det = ew_log_det.reshape(x.shape)

        # If we're doing mask coupling, need to correctly mask log_s before
        # computing the log determinant and also mask the output
        if mask is not None:
            z *= mask
            ew_log_det *= mask

        sum_axes = util.last_axes(x.shape[len(self.batch_shape):])
        log_det = ew_log_det.sum(axis=sum_axes)

        return z, log_det
    def _transform(self,
                   x,
                   params=None,
                   sample=False,
                   mask=None,
                   rng=None,
                   **kwargs):
        z, ew_log_det = self.transform(x,
                                       params=params,
                                       sample=sample,
                                       rng=rng,
                                       **kwargs)
        assert z.shape == ew_log_det.shape

        # If we're doing mask coupling, need to correctly mask log_s before
        # computing the log determinant and also mask the output
        if mask is not None:
            z *= mask
            ew_log_det *= mask

        # Sum over each dimension
        sum_axes = util.last_axes(x.shape[len(self.batch_shape):])
        log_det = ew_log_det.sum(sum_axes)
        return z, log_det
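The 3*K - 1 parameters per dimension in theta match the usual rational-quadratic spline parameterization (K bin widths, K bin heights, K - 1 interior derivatives). A sketch of how such a vector is typically split (hypothetical, not NuX's forward_spline itself):

import jax.numpy as jnp

K = 8
theta_d = jnp.zeros(3*K - 1)
widths, heights, derivatives = jnp.split(theta_d, [K, 2*K])
# widths.shape == (8,), heights.shape == (8,), derivatives.shape == (7,)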