def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray=None,
         sample: Optional[bool]=False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  x = inputs["x"]
  x_shape = self.get_unbatched_shapes(sample)["x"]
  sum_axes = util.last_axes(x_shape)

  if sample == False:
    z = jax.nn.sigmoid(x)
    if self.has_scale == True:
      z -= self.scale
      z /= 1.0 - 2*self.scale
    log_det = -(jax.nn.softplus(x) + jax.nn.softplus(-x))
  else:
    if self.has_scale == True:
      x *= 1.0 - 2*self.scale
      x += self.scale
    z = jax.scipy.special.logit(x)
    log_det = -(jax.nn.softplus(z) + jax.nn.softplus(-z))

  if self.has_scale == True:
    log_det -= jnp.log(1.0 - 2*self.scale)

  log_det = log_det.sum(axis=sum_axes)*jnp.ones(self.batch_shape)

  outputs = {"x": z, "log_det": log_det}
  return outputs
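# The layers in this section reduce their elementwise log determinants over the
# unbatched (event) axes with util.last_axes, and the residual-flow estimator
# below aligns coefficient vectors with util.broadcast_to_first_axis.  Those
# helpers are not shown here; the definitions below are only a minimal sketch of
# the behavior the call sites appear to assume, not the library's own code.
import jax.numpy as jnp

def last_axes(shape):
  # Negative axis indices covering the trailing len(shape) dimensions,
  # e.g. last_axes((28, 28, 1)) == (-3, -2, -1).
  return tuple(range(-len(shape), 0))

def broadcast_to_first_axis(x, target_ndim):
  # Append singleton dimensions so that x broadcasts against an array that
  # shares its leading axis, e.g. a (K,) coefficient vector against (K, H, W, C).
  return x.reshape(x.shape + (1,)*(target_ndim - x.ndim))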
def log_det_sliced_estimate(apply_fun, params, state, x, rng, batch_info):
  trace_key, roulette_key = random.split(rng, 2)

  # Evaluate the flow and get the vjp function
  gx, state = apply_fun(params, state, x, rng, update_params=True)
  _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x, rng, update_params=False), x, has_aux=True)
  z = x + gx

  # Generate the probe vector for the trace estimate
  v = random.normal(trace_key, x.shape)

  # Get all of the terms we need for the log det and gradient estimates
  terms = unbiased_neumann_vjp_terms(vjp_fun, v, roulette_key, n_terms=7, n_exact=4)

  # Rescale the terms and sum over k (starting at k=1)
  cut_terms = terms[1:]
  log_det_coeff = -1/(1 + jnp.arange(cut_terms.shape[0]))
  log_det_coeff = util.broadcast_to_first_axis(log_det_coeff, cut_terms.ndim)
  log_det_terms = log_det_coeff*cut_terms
  summed_log_det_terms = log_det_terms.sum(axis=0)

  # Compute the log det
  x_shape, batch_shape = batch_info
  log_det = jnp.sum(summed_log_det_terms*v, axis=util.last_axes(x_shape))

  return z, log_det, v, terms, state
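# unbiased_neumann_vjp_terms is not defined in this section.  The sketch below
# shows the core computation the call above relies on: repeatedly apply the VJP
# of the residual branch to build the vectors v, vJ, vJ^2, ..., which feed both
# the truncated log-det series and the gradient estimate.  It omits the
# Russian-roulette reweighting (the role of roulette_key and n_exact in the
# real helper), so treat it as an illustration of the structure only.
import jax
import jax.numpy as jnp

def neumann_vjp_terms_sketch(vjp_fun, v, n_terms=7):
  def body(w, _):
    w_next, = vjp_fun(w)   # one more multiplication by J from the right: w -> wJ
    return w_next, w_next

  _, higher_order_terms = jax.lax.scan(body, v, None, length=n_terms - 1)
  # terms[k] ~ v J^k, with terms[0] = v, matching how terms[1:] is used above.
  return jnp.concatenate([v[None], higher_order_terms], axis=0)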
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray=None,
         sample: Optional[bool]=False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  outputs = {}
  x_shape = self.get_unbatched_shapes(sample)["x"]
  sum_axes = util.last_axes(x_shape)

  if sample == False:
    x = inputs["x"]
    sqrt_1px2 = jnp.sqrt(1 + x**2)
    z = (x + self.alpha*(sqrt_1px2 - 1))/(1 + self.alpha)
    outputs["x"] = z
  else:
    z = inputs["x"]
    alpha_sq = self.alpha**2
    b = (1 + self.alpha)*z + self.alpha
    x = (jnp.sqrt(alpha_sq*(1 + b**2 - alpha_sq)) - b)/(alpha_sq - 1)
    outputs["x"] = x
    sqrt_1px2 = jnp.sqrt(1 + x**2)

  log_det = jnp.log(1 + self.alpha*x/sqrt_1px2) - jnp.log(1 + self.alpha)
  log_det = log_det.sum(axis=sum_axes)*jnp.ones(self.batch_shape)
  outputs["log_det"] = log_det
  return outputs
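# Standalone sanity check of the log determinant used above: the elementwise
# expression log(1 + alpha*x/sqrt(1 + x^2)) - log(1 + alpha) should equal the
# log-derivative of the forward map.  alpha is a free value here, not the
# layer's parameter.
import jax
import jax.numpy as jnp

alpha = 0.5
forward = lambda x: (x + alpha*(jnp.sqrt(1 + x**2) - 1))/(1 + alpha)

x = jnp.linspace(-3.0, 3.0, 7)
analytic = jnp.log(1 + alpha*x/jnp.sqrt(1 + x**2)) - jnp.log(1 + alpha)
autodiff = jnp.log(jax.vmap(jax.grad(forward))(x))
assert jnp.allclose(analytic, autodiff, atol=1e-5)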
def sliced_estimate_fwd(apply_fun, params, state, x, rng, batch_info):
  z, log_det, v, terms, state = log_det_sliced_estimate(apply_fun, params, state, x, rng, batch_info)

  # Accumulate the terms we need for the gradient
  summed_terms_for_grad = terms.sum(axis=0)

  x_shape, batch_shape = batch_info
  batch_dim = len(batch_shape)
  sum_axes = util.last_axes(x_shape)

  # Compute dlogdet(I + J(x;theta))/dtheta
  def vjvp(params, unbatched_x, unbatched_summed_terms, unbatched_v):
    # Remember that apply_fun is autobatched and can automatically pad leading dims!
    if batch_dim > 0:
      _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x[None], rng, update_params=False), unbatched_x, has_aux=True)
      w, = vjp_fun(unbatched_summed_terms[None])
    else:
      _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x, rng, update_params=False), unbatched_x, has_aux=True)
      w, = vjp_fun(unbatched_summed_terms)
    return jnp.sum(w*unbatched_v)

  # vmap over the batch dimensions
  vmapped_vjvp = jax.grad(vjvp, argnums=(0, 1))
  for i in range(batch_dim):
    vmapped_vjvp = jax.vmap(vmapped_vjvp, in_axes=(None, 0, 0, 0))

  # Compute the vector Jacobian vector products to get the gradient estimate terms.
  dlogdet_dtheta, dlogdet_dx = vmapped_vjvp(params, x, summed_terms_for_grad, v)

  # Store off everything for the backward pass
  ctx = x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx
  return (z, log_det, state), ctx
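# Why contracting against the probe vector recovers a trace: for v ~ N(0, I),
# E[v^T A v] = tr(A) (Hutchinson's estimator).  Standalone illustration with a
# random matrix; it is not part of the flow code above, just the identity that
# makes the jnp.sum(w*unbatched_v) terms unbiased trace estimates.
import jax
import jax.numpy as jnp

A = jax.random.normal(jax.random.PRNGKey(0), (4, 4))
vs = jax.random.normal(jax.random.PRNGKey(1), (200_000, 4))
estimate = jnp.mean(jax.vmap(lambda v: v @ A @ v)(vs))
# estimate is close to jnp.trace(A), up to Monte Carlo error.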
def estimate_fwd(apply_fun, params, state, x, rng, batch_info):
  z, log_det, terms, state = log_det_estimate(apply_fun, params, state, x, rng, batch_info)

  # Accumulate the terms we need for the gradient
  summed_terms_for_grad = terms.sum(axis=0)

  x_shape, batch_shape = batch_info
  sum_axes = util.last_axes(x_shape)

  # Compute dlogdet(I + J(x;theta))/dtheta
  def jjp(params, unbatched_x, unbatched_summed_terms):
    jac_fun = jax.jacobian(lambda x: apply_fun(params, state, x[None], rng)[0][0])
    J = jac_fun(unbatched_x)
    return jnp.trace(unbatched_summed_terms @ J)

  vmapped_grad_vjvp = jax.grad(jjp, argnums=(0, 1))
  for i in range(len(batch_shape)):
    vmapped_grad_vjvp = jax.vmap(vmapped_grad_vjvp, in_axes=(None, 0, 0))

  dlogdet_dtheta, dlogdet_dx = vmapped_grad_vjvp(params, x, summed_terms_for_grad)

  ctx = x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx
  return (z, log_det, state), ctx
def sliced_estimate_fwd(apply_fun, params, state, x, rng, batch_info):
  z, log_det, v, terms = log_det_sliced_estimate(apply_fun, params, state, x, rng, batch_info)

  # Accumulate the terms we need for the gradient
  summed_terms_for_grad = terms.sum(axis=0)

  x_shape, batch_shape = batch_info
  sum_axes = util.last_axes(x_shape)

  # Compute dlogdet(I + J(x;theta))/dtheta
  def vjvp(params, unbatched_x, unbatched_summed_terms, unbatched_v):
    _, vjp_fun, _ = jax.vjp(lambda x: apply_fun(params, state, x[None], rng), unbatched_x, has_aux=True)
    w, = vjp_fun(unbatched_summed_terms[None])
    return jnp.sum(w*unbatched_v)

  vmapped_vjvp = jax.grad(vjvp, argnums=(0, 1))
  for i in range(len(batch_shape)):
    vmapped_vjvp = jax.vmap(vmapped_vjvp, in_axes=(None, 0, 0, 0))

  dlogdet_dtheta, dlogdet_dx = vmapped_vjvp(params, x, summed_terms_for_grad, v)

  ctx = x, params, state, rng, batch_info, dlogdet_dtheta, dlogdet_dx
  return (z, log_det), ctx
def transform(self, x, params=None, sample=False, mask=None):
  # Remember that self.get_unbatched_shapes(sample)["x"] is NOT the shape of x here!
  # The x we see here is only half of the actual x!

  # Get the parameters of the transformation
  scale_init = hk.initializers.RandomNormal(stddev=0.01)
  if params is None:
    x_shape = x.shape[len(self.batch_shape):]
    if self.kind == "affine":
      log_s = hk.get_parameter("log_s", shape=x_shape, dtype=x.dtype, init=scale_init)
    t = hk.get_parameter("t", shape=x_shape, dtype=x.dtype, init=scale_init)
  else:
    if self.kind == "affine":
      scale_scale = hk.get_parameter("scale_scale", shape=(), dtype=x.dtype, init=scale_init)
    shift_scale = hk.get_parameter("shift_scale", shape=(), dtype=x.dtype, init=scale_init)

    # Split the output and bound the scaling term
    if self.kind == "affine":
      t, log_s = jnp.split(params, 2, axis=self.axis)
      log_s = util.constrain_log_scale(log_s)
    else:
      t = params

    # Scale the parameters so that we can initialize this function to the identity
    t = shift_scale*t
    if self.kind == "affine":
      log_s = scale_scale*log_s

  # Evaluate the transformation
  if sample == False:
    z = (x - t)*jnp.exp(-log_s) if self.kind == "affine" else x - t
  else:
    z = x*jnp.exp(log_s) + t if self.kind == "affine" else x + t

  # If we're doing mask coupling, need to correctly mask log_s before
  # computing the log determinant and also mask the output
  if mask is not None:
    z *= mask
    log_s *= mask

  # Compute the log determinant
  if self.kind == "affine":
    sum_axes = util.last_axes(x.shape[len(self.batch_shape):])
    log_det = -log_s.sum(axis=sum_axes)
  else:
    log_det = jnp.zeros(self.batch_shape)

  return z, log_det
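# Round-trip check for the affine branch of the transform above, with arbitrary
# stand-in values for t and log_s (in the real layer these are Haiku parameters
# or conditioner outputs).  The sample == True branch inverts the
# sample == False branch exactly.
import jax.numpy as jnp

x = jnp.array([0.3, -1.2, 2.0])
t = jnp.array([0.1, 0.0, -0.5])
log_s = jnp.array([0.2, -0.3, 0.05])

z = (x - t)*jnp.exp(-log_s)    # sample == False
x_rec = z*jnp.exp(log_s) + t   # sample == True
assert jnp.allclose(x, x_rec, atol=1e-6)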
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray = None,
         sample: Optional[bool] = False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  x = inputs["x"]
  outputs = {}
  x_shape = self.get_unbatched_shapes(sample)["x"]

  theta = hk.get_parameter("theta",
                           shape=x_shape + (3*self.n_components + self.extra,),
                           dtype=x.dtype,
                           init=hk.initializers.RandomNormal(0.1))

  # Split the parameters
  if self.with_affine_coupling:
    in_axes = (0, None, None, None, None, None)
    out_axes = (None, None, None, None, None)
  else:
    in_axes = (0, None, None, None)
    out_axes = (None, None, None)

  params = self.split_theta(theta)
  init_fun = self.auto_batch(self.safe_init, in_axes=in_axes, out_axes=out_axes, expected_depth=1)
  params = init_fun(x, *params)

  # Run the transform
  if sample == False:
    z, elementwise_log_det = self.auto_batch(self.mixture_forward, in_axes=in_axes, expected_depth=1)(x, *params)
  else:
    z, elementwise_log_det = self.auto_batch(self.mixture_inverse, in_axes=in_axes, expected_depth=1)(x, *params)

  sum_axes = util.last_axes(self.unbatched_input_shapes["x"])
  log_det = elementwise_log_det.sum(sum_axes)

  outputs = {"x": z, "log_det": log_det}
  return outputs
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray = None,
         sample: Optional[bool] = False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  x = inputs["x"]
  x_shape = self.get_unbatched_shapes(sample)["x"]
  sum_axes = util.last_axes(x_shape)

  if sample == False:
    z = jnp.where(x > 0, x, self.alpha*x)
  else:
    z = jnp.where(x > 0, x, x/self.alpha)

  log_dx_dz = jnp.where(x > 0, 0, jnp.log(self.alpha))
  log_det = log_dx_dz.sum(axis=sum_axes)*jnp.ones(self.batch_shape)

  outputs = {"x": z, "log_det": log_det}
  return outputs
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray=None,
         sample: Optional[bool]=False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  x_shape = self.get_unbatched_shapes(sample)["x"]
  sum_axes = util.last_axes(x_shape)

  if sample == False:
    x = inputs["x"]
    x = jnp.where(x < 0.0, 1e-5, x)
    dx = jnp.log1p(-jnp.exp(-x))
    z = x + dx  # log(exp(x) - 1), the inverse of softplus
    log_det = -dx.sum(axis=sum_axes)*jnp.ones(self.batch_shape)
    outputs = {"x": z, "log_det": log_det}
  else:
    x = jax.nn.softplus(inputs["x"])
    log_det = -jnp.log1p(-jnp.exp(-x)).sum(axis=sum_axes)*jnp.ones(self.batch_shape)
    outputs = {"x": x, "log_det": log_det}

  return outputs
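# Standalone check that x + log1p(-exp(-x)) inverts softplus, which is what the
# sample == False branch above computes (the clamp to 1e-5 keeps the log
# argument positive).  Test points are arbitrary positive values.
import jax
import jax.numpy as jnp

x = jnp.array([0.1, 1.0, 5.0])
z = x + jnp.log1p(-jnp.exp(-x))   # log(exp(x) - 1) = softplus^{-1}(x)
assert jnp.allclose(jax.nn.softplus(z), x, atol=1e-5)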
def call(self,
         inputs: Mapping[str, jnp.ndarray],
         rng: jnp.ndarray=None,
         sample: Optional[bool]=False,
         generate_image: Optional[bool]=False,
         **kwargs
  ) -> Mapping[str, jnp.ndarray]:
  x = inputs["x"]
  outputs = {}
  x_shape = self.get_unbatched_shapes(sample)["x"]
  sum_axes = util.last_axes(x_shape)

  if sample == False:
    if self.has_scale == True:
      x *= (1.0 - 2*self.scale)
      x += self.scale
    z = jax.scipy.special.logit(x)
    log_det = (jax.nn.softplus(z) + jax.nn.softplus(-z))
  else:
    z = jax.nn.sigmoid(x)
    log_det = (jax.nn.softplus(x) + jax.nn.softplus(-x))

    # If we are generating images, we want to pass the normalized image
    # to matplotlib!
    if generate_image:
      outputs["image"] = z

    if self.has_scale == True:
      z -= self.scale
      z /= (1.0 - 2*self.scale)

  if self.has_scale == True:
    log_det += jnp.log(1.0 - 2*self.scale)

  log_det = log_det.sum(axis=sum_axes)*jnp.ones(self.batch_shape)

  outputs["x"] = z
  outputs["log_det"] = log_det
  return outputs
def transform(self, x, params=None, sample=False, mask=None):
  x_flat = x.reshape(self.batch_shape + (-1,))
  param_dim = (3*self.K - 1)

  if params is None:
    x_shape = x_flat.shape[len(self.batch_shape):]
    theta = hk.get_parameter("theta", shape=(x_flat.shape[-1],) + (param_dim,), dtype=x_flat.dtype, init=hk.initializers.RandomNormal())
    in_axes = (None, 0)
  else:
    theta = params.reshape(self.batch_shape + (x_flat.shape[-1],) + (param_dim,))
    in_axes = (0, 0)

  if sample == False:
    z, ew_log_det = self.auto_batch(self.forward_spline, in_axes=in_axes)(theta, x_flat)
  else:
    z, ew_log_det = self.auto_batch(self.inverse_spline, in_axes=in_axes)(theta, x_flat)

  z = z.reshape(x.shape)
  ew_log_det = ew_log_det.reshape(x.shape)

  # If we're doing mask coupling, need to correctly mask log_s before
  # computing the log determinant and also mask the output
  if mask is not None:
    z *= mask
    ew_log_det *= mask

  sum_axes = util.last_axes(x.shape[len(self.batch_shape):])
  log_det = ew_log_det.sum(axis=sum_axes)
  return z, log_det
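# The 3*K - 1 parameters per dimension are consistent with a rational-quadratic
# spline with K bins: K bin widths, K bin heights and K - 1 interior knot
# derivatives.  The split below is an assumption about how forward_spline /
# inverse_spline consume theta, shown only to motivate param_dim above.
import jax.numpy as jnp

K = 8
theta = jnp.zeros((3*K - 1,))
widths, heights, derivatives = jnp.split(theta, (K, 2*K))
assert widths.shape == (K,) and heights.shape == (K,) and derivatives.shape == (K - 1,)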
def _transform(self, x, params=None, sample=False, mask=None, rng=None, **kwargs):
  z, ew_log_det = self.transform(x, params=params, sample=sample, rng=rng, **kwargs)
  assert z.shape == ew_log_det.shape

  # If we're doing mask coupling, need to correctly mask log_s before
  # computing the log determinant and also mask the output
  if mask is not None:
    z *= mask
    ew_log_det *= mask

  # Sum over each dimension
  sum_axes = util.last_axes(x.shape[len(self.batch_shape):])
  log_det = ew_log_det.sum(sum_axes)
  return z, log_det