def get_covariance_function():
    gp_dtype = gpf.config.default_float()
    # Matern 32
    m32_cov = Matern32(variance=1, lengthscales=100.)
    m32_cov.variance.prior = Normal(gp_dtype(1.), gp_dtype(0.1))
    m32_cov.lengthscales.prior = Normal(gp_dtype(100.), gp_dtype(50.))
    # Periodic base kernel
    periodic_base_cov = SquaredExponential(variance=5., lengthscales=1.)
    set_trainable(periodic_base_cov.variance, False)
    periodic_base_cov.lengthscales.prior = Normal(gp_dtype(5.), gp_dtype(1.))
    # Periodic
    periodic_cov = Periodic(periodic_base_cov, period=1., order=FLAGS.qp_order)
    set_trainable(periodic_cov.period, False)
    # Periodic damping
    periodic_damping_cov = Matern32(variance=1e-1, lengthscales=50)
    periodic_damping_cov.variance.prior = Normal(gp_dtype(1e-1), gp_dtype(1e-3))
    periodic_damping_cov.lengthscales.prior = Normal(gp_dtype(50), gp_dtype(10.))
    # Final covariance
    co2_cov = periodic_cov * periodic_damping_cov + m32_cov
    return co2_cov
def call(self, state):
    # Get mean and standard deviation from the policy network
    a1 = self.dense1_layer(state)
    a2 = self.dense2_layer(a1)
    mu = self.mean_layer(a2)

    # The standard deviation must be non-negative, so the network outputs
    # log stdev, which is unconstrained in (-inf, inf), and we exponentiate it.
    log_sigma = self.stdev_layer(a2)
    sigma = tf.exp(log_sigma)

    # Use the re-parameterization trick to sample an action from the policy
    # network: dist.sample() draws standard-normal noise of the action's shape
    # and scales/shifts it by sigma and mu internally.
    dist = Normal(mu, sigma)
    action_ = dist.sample()

    # Apply the tanh squashing to keep the gaussian bounded in (-1, 1)
    action = tf.tanh(action_)

    # Calculate the log probability
    log_pi_ = dist.log_prob(action_)
    log_pi = log_pi_ - tf.reduce_sum(
        tf.math.log(1 - action**2 + eps), axis=1, keepdims=True)
    return action, log_pi
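# Aside (not part of the policy class above): the correction term
# log(1 - tanh(u)^2) used in `call` can also be evaluated in the numerically
# stable form 2 * (log 2 - u - softplus(-2u)), which avoids taking the log of a
# value near zero when tanh(u) saturates. A minimal sketch of the identity,
# assuming TensorFlow 2.x; the tensor values below are illustrative only.
import numpy as np
import tensorflow as tf

u = tf.constant([-3.0, -0.5, 0.0, 0.5, 3.0])
naive = tf.math.log(1.0 - tf.tanh(u) ** 2)
stable = 2.0 * (np.log(2.0) - u - tf.math.softplus(-2.0 * u))
assert np.allclose(naive.numpy(), stable.numpy(), atol=1e-4)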
def total_correlation(z_samples: Tensor, qZ_X: Distribution) -> Tensor:
    r"""Estimate of total correlation using a Gaussian distribution on a batch.

    We need to compute the expectation over a batch of:
    `E_j [log(q(z(x_j))) - log(prod_l q(z(x_j)_l))]`

    We ignore the constants as they do not matter for the minimization.
    The constant should be equal to
    `(num_latents - 1) * log(batch_size * dataset_size)`

    If `alpha = gamma = 1`, Eq(4) can be written as `ELBO + (1 - beta) * TC`
    (i.e. `(1. - beta) * total_correlation(z_sampled, qZ_X)`).

    Parameters
    ----------
    z_samples : Tensor
        shape `[batch_size, num_latents]` - tensor with sampled representation.
    qZ_X : Distribution
        the posterior distribution, shape `[batch_size, num_latents]`

    Note
    ----
    This involves calculating pair-wise distances, with memory complexity up to
    `O(n*n*d)`.

    Returns
    -------
    Total correlation estimated on a batch.

    References
    ----------
    Chen, R.T.Q., Li, X., Grosse, R., Duvenaud, D., 2019. Isolating Sources of
    Disentanglement in Variational Autoencoders. arXiv:1802.04942 [cs, stat].
    Github code https://github.com/google-research/disentanglement_lib
    """
    gaus = Normal(loc=tf.expand_dims(qZ_X.mean(), 0),
                  scale=tf.expand_dims(qZ_X.stddev(), 0))
    # Compute log(q(z(x_j)|x_i)) for every sample in the batch, which is a
    # tensor of size [batch_size, batch_size, num_latents]. In the following
    # comments, [batch_size, batch_size, num_latents] are indexed by [j, i, l].
    log_qz_prob = gaus.log_prob(tf.expand_dims(z_samples, 1))
    # Compute log prod_l p(z(x_j)_l) = sum_l(log(sum_i(q(z(x_j)_l|x_i)))
    # + constant) for each sample in the batch, which is a vector of size
    # [batch_size,].
    log_qz_product = tf.reduce_sum(
        tf.reduce_logsumexp(log_qz_prob, axis=1, keepdims=False),
        axis=1,
        keepdims=False)
    # Compute log(q(z(x_j))) as log(sum_i(q(z(x_j)|x_i))) + constant =
    # log(sum_i(prod_l q(z(x_j)_l|x_i))) + constant.
    log_qz = tf.reduce_logsumexp(
        tf.reduce_sum(log_qz_prob, axis=2, keepdims=False),
        axis=1,
        keepdims=False)
    return tf.reduce_mean(log_qz - log_qz_product)
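# A minimal usage sketch for `total_correlation` above, assuming a factorized
# Gaussian posterior; the shapes and the random "encoder outputs" below are
# stand-ins, not taken from any particular model.
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

batch_size, num_latents = 16, 8
mean = tf.random.normal([batch_size, num_latents])
stddev = tf.nn.softplus(tf.random.normal([batch_size, num_latents]))
qZ_X = Normal(loc=mean, scale=stddev)
z_samples = qZ_X.sample()  # [batch_size, num_latents]
tc = total_correlation(z_samples, qZ_X)  # scalar batch estimate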
def get_action_distribution(actor, state, network, recurrent_state=None):
    if network == "gru":
        mu, sigma, recurrent_state = actor(state, initial_state=recurrent_state)
        return Normal(loc=mu, scale=sigma), recurrent_state
    if network == "mlp":
        mu, sigma, _ = actor(state)
        return Normal(loc=mu, scale=sigma), None
def __init__(self, normal_mean, normal_std, epsilon=1e-6):
    """
    :param normal_mean: Mean of the normal distribution
    :param normal_std: Std of the normal distribution
    :param epsilon: Numerical stability epsilon when computing log-prob.
    """
    self.normal_mean = normal_mean
    self.normal_std = normal_std
    self.normal = Normal(normal_mean, normal_std)
    self.epsilon = epsilon
def elbo_components(self, inputs, training=None, mask=None):
    X_u, y_u, X_l, y_l = _prepare_elbo(self, inputs, training=training, mask=mask)
    y_l = tf.clip_by_value(y_l, 1e-8, 1. - 1e-8)
    px_z_u, (qz_x_u, qzc_x_u, qy_zx_u) = self(X_u, training=training)
    px_z_l, (qz_x_l, qzc_x_l, qy_zx_l) = self(X_l, training=training)
    z_exc = tf.concat(
        [tf.convert_to_tensor(qz_x_u), tf.convert_to_tensor(qz_x_l)], axis=0)
    z_c = tf.concat(
        [tf.convert_to_tensor(qzc_x_u), tf.convert_to_tensor(qzc_x_l)], axis=0)
    # Convert y to one-hot vector and sample y for those without labels
    y_sup = y_l
    y_uns = tf.convert_to_tensor(qy_zx_u)
    y = tf.concat((y_uns, y_sup), axis=0)
    # log q(y|z_c)
    h = tf.concat([qy_zx_u.logits, qy_zx_l.logits], axis=0)
    log_q_y_zc = tf.reduce_sum(h * y, axis=1)
    # log p(x|z)
    log_p_x_z = tf.concat([px_z_u.log_prob(X_u), px_z_l.log_prob(X_l)], axis=0)
    # log p(z_c|y)
    pzc_y = self.regressor(y)
    log_p_zc_y = pzc_y.log_prob(z_c)
    # log p(z_\c)
    dist = Normal(tf.cast(0., self.dtype), 1.)
    log_p_zexc = tf.reduce_sum(dist.log_prob(z_exc), axis=-1)
    # log p(z|y)
    log_p_z_y = log_p_zc_y + log_p_zexc
    # log q(y|x) (Draw 128 points from q(z_c|x). Supervised samples only)
    h = qzc_x_l.sample(self.n_resamples)
    h = tf.reshape(h, (-1, h.shape[-1]))
    qy_x = self.classify(h, training=training)
    qy_x_logits = tf.reshape(qy_x.logits, (self.n_resamples, -1, h.shape[-1]))
    h = tf.reduce_logsumexp(h, axis=0) - tf.math.log(128.)
    log_q_y_x = tf.reduce_sum(h * y_l, axis=1)
    # log q(z|x)
    log_qz_x = tf.concat([qz_x_u.log_prob(qz_x_u), qz_x_l.log_prob(qz_x_l)],
                         axis=0)
    log_qzc_x = tf.concat(
        [qzc_x_u.log_prob(qzc_x_u), qzc_x_l.log_prob(qzc_x_l)], axis=0)
    log_q_z_x = log_qz_x + log_qzc_x
    # Calculate the lower bound
    n_uns = ps.shape(X_u)[0]
    h = log_p_x_z + log_p_z_y - log_q_y_zc - log_q_z_x
    coef_sup = tf.math.exp(log_q_y_zc[n_uns:] - log_q_y_x)
    coef_uns = tf.ones((n_uns,), dtype=self.dtype)
    coef = tf.concat((coef_uns, coef_sup), axis=0)
    zeros = tf.zeros((n_uns,), dtype=self.dtype)
    lb = coef * h + tf.concat((zeros, log_q_y_x), axis=0)
    return {'elbo': lb}, {}
def test_pad_mixture_dimensions_mixture(self):
    gm = Mixture(
        cat=Categorical(probs=[[0.3, 0.7]]),
        components=[
            Normal(loc=[-1.0], scale=[1.0]),
            Normal(loc=[1.0], scale=[0.5])
        ])
    x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
    x_pad = distribution_util.pad_mixture_dimensions(
        x, gm, gm.cat, tensorshape_util.rank(gm.event_shape))
    x_out, x_pad_out = self.evaluate([x, x_pad])

    self.assertAllEqual(x_pad_out.shape, [2, 2])
    self.assertAllEqual(x_out.reshape([-1]), x_pad_out.reshape([-1]))
def get_covariance_function():
    gp_dtype = gpf.config.default_float()
    matern_variance = 5500.
    matern_lengthscales = 5.
    m32_cov = Matern32(variance=matern_variance,
                       lengthscales=matern_lengthscales)
    m32_cov.variance.prior = Normal(gp_dtype(matern_variance),
                                    gp_dtype(matern_variance))
    m32_cov.lengthscales.prior = Normal(gp_dtype(matern_lengthscales),
                                        gp_dtype(matern_lengthscales))
    return m32_cov
def set_priors(gp_model: GPModel):
    to_dtype = gpf.utilities.to_default_float
    if FLAGS.model == ModelEnum.GP.value:
        gp_model.likelihood.variance.prior = Normal(to_dtype(0.1), to_dtype(1.))
        gp_model.likelihood.variance.prior_on = PriorOn.UNCONSTRAINED
    else:
        gp_model.noise_variance.prior = Normal(to_dtype(0.1), to_dtype(1.))
        gp_model.noise_variance.prior_on = PriorOn.UNCONSTRAINED
    gp_model.kernel.variance.prior = Normal(to_dtype(1.), to_dtype(3.))
    gp_model.kernel.variance.prior_on = PriorOn.UNCONSTRAINED
    gp_model.kernel.lengthscales.prior = Normal(to_dtype(1.), to_dtype(3.))
    gp_model.kernel.lengthscales.prior_on = PriorOn.UNCONSTRAINED
    return gp_model
class NormalPyramid(Distribution):
    # means is indexed by *, y, x, channel
    # base_sigma is a scalar, and is the std used for the 'raw' pixels;
    # subsequent levels use smaller stds

    def __init__(self, means, base_sigma, levels=None, validate_args=False,
                 allow_nan_stats=True, name='NormalPyramid'):
        with ops.name_scope(name, values=[means, base_sigma]) as ns:
            self._means = array_ops.identity(means, name='means')
            self._base_sigma = array_ops.identity(base_sigma, name='base_sigma')
            self._base_dist = Normal(loc=self._means, scale=self._base_sigma)
            self._standard_normal = Normal(loc=0., scale=1.)
            self._levels = levels
            super(NormalPyramid, self).__init__(
                dtype=tf.float32,
                parameters={'means': means, 'base_sigma': base_sigma},
                reparameterization_type=FULLY_REPARAMETERIZED,
                validate_args=validate_args,
                allow_nan_stats=allow_nan_stats,
                name=ns)

    def _log_prob(self, x):
        # The resulting density here will be indexed by *, i.e. we sum over
        # x, y, channel, and pyramid-levels
        z = (x - self._means) / self._base_sigma
        z_shape = list(map(int, z.get_shape()))
        z_pyramid = gaussian_pyramid(z, self._levels)
        return sum(
            tf.reduce_mean(self._standard_normal.log_prob(z_level),
                           axis=[-3, -2, -1])  # ** check the rescaling here!
            for level_index, z_level in enumerate(z_pyramid)) / len(z_pyramid)

    def _sample_n(self, n, seed=None):
        return self._base_dist._sample_n(n, seed)

    def _mean(self):
        return self._means

    def _mode(self):
        return self._means
def make_gaussian_out(p: tf.Tensor,
                      event_shape: Sequence[int]) -> Independent:
    loc, scale = tf.split(p, 2, -1)
    loc = tf.reshape(loc, (-1,) + tuple(event_shape))
    scale = tf.reshape(scale, (-1,) + tuple(event_shape))
    scale = tf.nn.softplus(scale)
    return Independent(Normal(loc=loc, scale=scale), len(event_shape))
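# A quick usage sketch for `make_gaussian_out` above; the parameter tensor
# shape (batch of 5, event_shape (3, 4), hence 2 * 12 parameters per example)
# is an assumption for illustration.
import tensorflow as tf

p = tf.random.normal([5, 2 * 3 * 4])
dist = make_gaussian_out(p, event_shape=(3, 4))
print(dist.event_shape, dist.batch_shape)  # (3, 4) (5,)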
def variational_posterior(shape, name, prior, istraining):
    """
    Creates a variational posterior q(w|theta) over a given "weight" w of the
    network. theta is parameterized as mean + std * noise; we apply the
    reparameterization trick from Kingma et al., 2014 with the correct loss
    function (free energy). We learn mean and std to estimate theta, and can
    then estimate the posterior p(w|D) by computing the KL loss for each
    variational posterior q(w|theta) against the prior p(w).

    :param name: name of the tensor/variable to create the variational
        posterior q(w|theta) for the true posterior p(w|D)
    :param shape: shape of the weight variable
    :param istraining: whether in training or inference mode
    :return: samples (i.e. weights), mean of weights, std; in training mode
        there is noise associated with the weights
    """
    # theta = mu + sigma, i.e. mu + log(1 + exp(rho)); log(1 + exp(rho)) is
    # computed with tf.math.softplus(rho) to avoid a negative sigma
    # need to check for init
    mu = tf.get_variable("{}_mean".format(name), shape=shape, dtype=tf.float32)
    rho = tf.get_variable("{}_rho".format(name), shape=shape, dtype=tf.float32)
    sigma = tf.math.softplus(rho)
    # if training we add noise to the variational parameters theta
    if istraining:
        epsilon = Normal(0, 1.0).sample(shape)
        sample = mu + sigma * epsilon
    else:
        sample = mu + sigma
    return sample, mu, sigma
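# A hedged sketch of how the (sample, mu, sigma) returned above could feed the
# free-energy/KL term mentioned in the docstring, assuming a standard-normal
# prior p(w) = N(0, 1). The helper name `kl_to_standard_normal` is illustrative
# and not part of the original code.
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

def kl_to_standard_normal(sample, mu, sigma):
    # Single-sample Monte-Carlo estimate of KL(q(w|theta) || p(w)):
    # E_q[log q(w|theta) - log p(w)], summed over all weights.
    posterior = Normal(mu, sigma)
    prior = Normal(tf.zeros_like(mu), tf.ones_like(sigma))
    return tf.reduce_sum(posterior.log_prob(sample) - prior.log_prob(sample))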
def _parse_distribution(
        input_shape: Tuple[int, int, int],
        distribution: Literal['qlogistic', 'mixqlogistic', 'bernoulli', 'gaussian'],
        n_components=10,
        n_channels=3) -> Tuple[int, DistributionLambda, Layer]:
    from odin.bay.layers import DistributionDense
    n_channels = input_shape[-1]
    last_layer = Activation('linear')
    # === 1. Quantized logistic
    if distribution == 'qlogistic':
        n_params = 2
        observation = DistributionLambda(
            lambda params: QuantizedLogistic(
                *[
                    # loc
                    p if i == 0 else
                    # Ensure scales are positive and do not collapse to near-zero
                    tf.nn.softplus(p) + tf.cast(tf.exp(-7.), tf.float32)
                    for i, p in enumerate(tf.split(params, 2, -1))
                ],
                low=0,
                high=255,
                inputs_domain='sigmoid',
                reinterpreted_batch_ndims=3),
            convert_to_tensor_fn=Distribution.sample,
            name='image')
    # === 2. Mixture Quantized logistic
    elif distribution == 'mixqlogistic':
        n_params = MixtureQuantizedLogistic.params_size(
            n_components=n_components,
            n_channels=n_channels) // n_channels
        observation = DistributionLambda(
            lambda params: MixtureQuantizedLogistic(params,
                                                    n_components=n_components,
                                                    n_channels=n_channels,
                                                    inputs_domain='sigmoid',
                                                    high=255,
                                                    low=0),
            convert_to_tensor_fn=Distribution.mean,
            name='image')
    # === 3. Bernoulli
    elif distribution == 'bernoulli':
        n_params = 1
        observation = DistributionDense(
            event_shape=input_shape,
            posterior=lambda p: Independent(Bernoulli(logits=p), len(input_shape)),
            projection=False,
            name="image")
    # === 4. Gaussian
    elif distribution == 'gaussian':
        n_params = 2
        observation = DistributionDense(
            event_shape=input_shape,
            posterior=lambda p: Independent(Normal(*tf.split(p, 2, -1)),
                                            len(input_shape)),
            projection=False,
            name="image")
    else:
        raise ValueError(f'No support for distribution {distribution}')
    return n_params, observation, last_layer
def __init__(self, units, **kwargs):
    super().__init__(units,
                     posterior=NormalLayer,
                     posterior_kwargs=dict(scale_activation='softplus1'),
                     prior=Independent(
                         Normal(loc=tf.zeros(shape=units),
                                scale=tf.ones(shape=units)), 1),
                     **kwargs)
def set_gp_priors(gp_model: GPModel):
    gp_dtype = gpf.config.default_float()
    variance_prior = Normal(gp_dtype(FLAGS.noise_variance),
                            gp_dtype(FLAGS.noise_variance))
    if FLAGS.model == ModelEnum.GP.value:
        gp_model.likelihood.variance.prior = variance_prior
    else:
        gp_model.noise_variance.prior = variance_prior
def __init__(self, args: Arguments, free_bits=None, beta=1, **kwargs):
    networks = get_networks(args.ds,
                            zdim=args.zdim,
                            is_hierarchical=False,
                            is_semi_supervised=False)
    zdim = args.zdim
    prior = Normal(loc=tf.zeros([zdim]), scale=tf.ones([zdim]))
    networks['latents'] = DistributionDense(units=zdim * 2,
                                            posterior=make_normal,
                                            prior=prior,
                                            name=networks['latents'].name)
    super().__init__(free_bits=free_bits, beta=beta, **networks, **kwargs)
def model_fullcov(args: Arguments):
    nets = get_networks(args.ds,
                        zdim=args.zdim,
                        is_hierarchical=False,
                        is_semi_supervised=False)
    zdims = int(np.prod(nets['latents'].event_shape))
    nets['latents'] = RVconf(
        event_shape=zdims,
        projection=True,
        posterior='mvntril',
        prior=Independent(Normal(tf.zeros([zdims]), tf.ones([zdims])), 1),
        name='latents').create_posterior()
    return VariationalAutoencoder(**nets, name='FullCov')
def rsample(self, return_pretanh_value=False):
    """
    Sampling in the reparameterization case.
    """
    z = (self.normal_mean +
         self.normal_std * Normal(tf.zeros_like(self.normal_mean),
                                  tf.ones_like(self.normal_std)).sample())
    if return_pretanh_value:
        return tf.math.tanh(z), z
    else:
        return tf.math.tanh(z)
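# For reference, a sketch of the log-probability that typically pairs with the
# reparameterized sample above (the standard tanh change-of-variables
# correction, reusing the `epsilon` stored in `__init__`). This method body is
# an assumed completion of the class, not taken from the original code.
def log_prob(self, value, pre_tanh_value=None):
    if pre_tanh_value is None:
        # Invert the squashing: atanh(value) = 0.5 * log((1 + value) / (1 - value))
        pre_tanh_value = 0.5 * tf.math.log((1 + value) / (1 - value))
    # log q(a) = log N(u; mu, sigma) - log(1 - tanh(u)^2 + epsilon), u = pre-tanh
    return self.normal.log_prob(pre_tanh_value) - tf.math.log(
        1 - value ** 2 + self.epsilon)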
def get_action_distribution(self, state, recurrent_state=None, update=False):
    # Get the normal distribution over the action space, determined by mu and sigma
    if self.network == "mlp":
        mu, sigma, _ = self.actor(state.squeeze())
        return Normal(loc=mu, scale=sigma), None
    if self.network == "gru":
        if update:
            # create mask to reset the recurrent state for finished environments
            mask = np.ones(((self.num_agents * self.num_steps),
                            self.obs_space_size))
            mask[self.memory.terminals, :] = 0
            state = tf.concat((state, mask), axis=1)
            state = tf.reshape(state, (self.num_agents, self.num_steps,
                                       (self.obs_space_size * 2)))
        # forward mask together with input to the actor
        mu, sigma, recurrent_state = self.actor(state,
                                                initial_state=recurrent_state)
        return Normal(loc=mu, scale=sigma), recurrent_state
def new(dists):
    q_e, q_d = dists
    mu_e = q_e.mean()
    mu_d = q_d.mean()
    prec_e = 1 / q_e.variance()
    prec_d = 1 / q_d.variance()
    mu = (mu_e * prec_e + mu_d * prec_d) / (prec_e + prec_d)
    scale = tf.math.sqrt(1 / (prec_e + prec_d))
    dist = Normal(loc=mu, scale=scale)
    if isinstance(q_e, Independent):
        ndim = q_e.reinterpreted_batch_ndims
        dist = Independent(dist, reinterpreted_batch_ndims=ndim)
    return dist
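# A small sanity check of the precision-weighted fusion above (a product of two
# Gaussian experts). The distributions and values are illustrative only.
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

q_encoder = Normal(loc=0.0, scale=1.0)   # precision 1
q_decoder = Normal(loc=2.0, scale=0.5)   # precision 4
fused = new((q_encoder, q_decoder))
# Expected: mean = (0 * 1 + 2 * 4) / 5 = 1.6, scale = sqrt(1 / 5) ≈ 0.447
print(fused.mean().numpy(), fused.stddev().numpy())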
def __call__(self):
    print("Calling posterior variable " + self.name)
    # if training we add noise to the variational parameters theta
    if self.isTraining:
        epsilon = Normal(0, 1.0).sample(self.shape)
        self.samples = self.mu + self.sigma * epsilon
        self.KL = self.compute_KL_univariate_prior(self.prior, self.samples)
        # debug only
        print("Training variational posterior " + self.name +
              " KL loss " + str(self.KL))
    else:
        self.samples = self.mu + self.sigma
def __init__(self, args: Arguments, **kwargs):
    networks = get_networks(args.ds,
                            zdim=args.zdim,
                            is_hierarchical=False,
                            is_semi_supervised=False)
    zdim = args.zdim
    prior = Normal(loc=tf.zeros([zdim]), scale=tf.ones([zdim]))
    latents = [
        DistributionDense(units=zdim * 2,
                          posterior=make_normal,
                          prior=prior,
                          name='latents1'),
        DistributionDense(units=zdim * 2,
                          posterior=make_normal,
                          prior=prior,
                          name='latents2')
    ]
    networks['latents'] = latents
    super().__init__(**networks, **kwargs)
def set_prior(self,
              loc=0.,
              log_scale=np.log(np.expm1(1)),
              mixture_logits=None):
    r"""Set the prior for the mixture density network.

    loc : Scalar or Tensor with shape `[n_components, event_size]`
    log_scale : Scalar or Tensor with shape `[n_components, event_size]`
        for the 'none' and 'diag' components, and
        `[n_components, event_size * (event_size + 1) // 2]` for the 'full'
        component.
    mixture_logits : Scalar or Tensor with shape `[n_components]`
    """
    event_size = self.event_size
    if self.covariance == 'diag':
        scale_shape = [self.n_components, event_size]
        fn = lambda l, s: MultivariateNormalDiag(loc=l,
                                                 scale_diag=tf.nn.softplus(s))
    elif self.covariance == 'none':
        scale_shape = [self.n_components, event_size]
        fn = lambda l, s: Independent(Normal(loc=l, scale=tf.math.softplus(s)), 1)
    elif self.covariance == 'full':
        scale_shape = [self.n_components, event_size * (event_size + 1) // 2]
        fn = lambda l, s: MultivariateNormalTriL(
            loc=l,
            scale_tril=FillScaleTriL(diag_shift=1e-5)(tf.math.softplus(s)))
    #
    if isinstance(loc, Number) or tf.rank(loc) == 0:
        loc = tf.fill([self.n_components, self.event_size], loc)
    #
    if isinstance(log_scale, Number) or tf.rank(log_scale) == 0:
        log_scale = tf.fill(scale_shape, log_scale)
    #
    if mixture_logits is None:
        p = 1. / self.n_components
        mixture_logits = np.log(p / (1. - p))
    if isinstance(mixture_logits, Number) or tf.rank(mixture_logits) == 0:
        mixture_logits = tf.fill([self.n_components], mixture_logits)
    #
    loc = tf.cast(loc, self.dtype)
    log_scale = tf.cast(log_scale, self.dtype)
    mixture_logits = tf.cast(mixture_logits, self.dtype)
    self._prior = MixtureSameFamily(
        components_distribution=fn(loc, log_scale),
        mixture_distribution=Categorical(logits=mixture_logits),
        name="prior")
    return self
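# A standalone sketch of the kind of prior `set_prior` builds for the 'diag'
# covariance case, using tensorflow_probability directly. The sizes below
# (3 components, event_size 4) are assumptions for illustration.
import numpy as np
import tensorflow as tf
from tensorflow_probability import distributions as tfd

n_components, event_size = 3, 4
loc = tf.zeros([n_components, event_size])
# softplus(log(expm1(1))) == 1, matching the default log_scale above
log_scale = tf.fill([n_components, event_size], float(np.log(np.expm1(1.))))
prior = tfd.MixtureSameFamily(
    mixture_distribution=tfd.Categorical(logits=tf.zeros([n_components])),
    components_distribution=tfd.MultivariateNormalDiag(
        loc=loc, scale_diag=tf.nn.softplus(log_scale)))
print(prior.sample(2).shape)  # (2, 4)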
def __init__(self,
             loc=0.,
             scale=1.,
             min_value=None,
             max_value=None,
             validate_args=False,
             allow_nan_stats=True,
             name="qNormal"):
    super(qNormal, self).__init__(
        distribution=Normal(loc=loc,
                            scale=scale,
                            validate_args=validate_args,
                            allow_nan_stats=allow_nan_stats),
        low=min_value,
        high=max_value,
        name=name)
def __init__(self,
             units: int,
             prior_loc: float = 0.,
             prior_scale: float = 1.,
             projection: bool = True,
             name: str = "Latents",
             **kwargs):
    super().__init__(
        event_shape=(int(units),),
        posterior=NormalLayer,
        posterior_kwargs=dict(scale_activation='softplus'),
        prior=Independent(Normal(loc=tf.fill((units,), prior_loc),
                                 scale=tf.fill((units,), prior_scale)),
                          reinterpreted_batch_ndims=1),
        projection=projection,
        name=name,
        **kwargs,
    )
def get_KL_multivariate_prior(self, multivariateprior, theta, sample):
    """
    :param multivariateprior: assuming a mixture of two univariate Normal
        priors, i.e. Normal(0, s1) and Normal(0, s2)
    :param theta: (mean, std) used to create the posterior q(w|theta),
        i.e. Normal(mean, std)
    :param sample:
    :return:
    """
    sample = tf.reshape(sample, [-1])  # flatten vector
    (mean, std) = theta
    posterior = Normal(mean, std)
    (std1, std2) = multivariateprior
    prior1 = Normal(0, std1)
    prior2 = Normal(0, std2)
    q_theta = tf.reduce_sum(posterior.log_prob(sample))
    p1 = tf.reduce_sum(prior1.log_prob(sample))
    p2 = tf.reduce_sum(prior2.log_prob(sample))  # this is wrong, need to work this out
    KL = tf.subtract(q_theta, tf.reduce_logsumexp([p1, p2]))
    return KL
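# The comment above flags the two-component prior term as unresolved. For
# comparison, a minimal sketch of the usual scale-mixture prior log-density
# from Blundell et al. (2015), "Weight Uncertainty in Neural Networks",
# assuming an equal mixture weight pi = 0.5 (the weight is an assumption and
# is not part of `multivariateprior`).
import tensorflow as tf
from tensorflow_probability.python.distributions import Normal

def scale_mixture_log_prob(sample, std1, std2, pi=0.5):
    # log p(w) = sum_i log( pi * N(w_i; 0, std1) + (1 - pi) * N(w_i; 0, std2) )
    comp1 = Normal(0., std1).log_prob(sample) + tf.math.log(pi)
    comp2 = Normal(0., std2).log_prob(sample) + tf.math.log(1. - pi)
    return tf.reduce_sum(
        tf.reduce_logsumexp(tf.stack([comp1, comp2], axis=0), axis=0))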
def encode(self,
           inputs,
           library=None,
           training=None,
           mask=None,
           sample_shape=(),
           **kwargs):
    qZ_X = super().encode(inputs=inputs,
                          library=library,
                          training=training,
                          mask=mask,
                          sample_shape=sample_shape)
    if library is not None:
        mean, var = tf.split(tf.nest.flatten(library)[0], 2, axis=1)
        pL = Independent(Normal(loc=mean, scale=tf.math.sqrt(var)), 1)
    else:
        pL = None
    qZ_X[-1].KL_divergence.prior = pL
    return qZ_X
def llk_pixels(model: VariationalModel, valid_ds: tf.data.Dataset):
    llk = []
    for x, y in valid_ds.take(5):
        px, _ = _call(model, x, y, decode=True)
        px = as_tuple(px)[0]
        if hasattr(px, 'distribution'):
            px = px.distribution
        if isinstance(px, Bernoulli):
            px = Bernoulli(logits=px.logits)
        elif isinstance(px, Normal):
            px = Normal(loc=px.loc, scale=px.scale)
        elif isinstance(px, QuantizedLogistic):
            px = QuantizedLogistic(loc=px.loc,
                                   scale=px.scale,
                                   low=px.low,
                                   high=px.high,
                                   inputs_domain=px.inputs_domain,
                                   reinterpreted_batch_ndims=None)
        else:
            return  # nothing to do
        llk.append(px.log_prob(x))
    # average over all channels
    llk_image = tf.reduce_mean(tf.reduce_mean(tf.concat(llk, 0), axis=0),
                               axis=-1)
    llk = tf.reshape(llk_image, -1)
    tf.summary.histogram('valid/llk_pixels', llk, step=model.step)
    # show the image heatmap of llk pixels
    fig = plt.figure(figsize=(3, 3))
    ax = plt.gca()
    im = ax.pcolormesh(llk_image.numpy(),
                       cmap='Spectral',
                       vmin=np.min(llk),
                       vmax=np.max(llk))
    ax.axis('off')
    ax.margins(0.)
    # color bar
    ticks = np.linspace(np.min(llk), np.max(llk), 5)
    cbar = plt.colorbar(im, ax=ax, fraction=0.04, pad=0.02, ticks=ticks)
    cbar.ax.set_yticklabels([f'{i:.2f}' for i in ticks])
    cbar.ax.tick_params(labelsize=6)
    plt.tight_layout()
    tf.summary.image('llk_heatmap', vs.plot_to_image(fig, dpi=100))
def build(self, input_shape=None):
    super().build(input_shape)
    if self._disable:
        return
    decoder_shape = self.layer.compute_output_shape(input_shape)
    layer, layer_t = _NDIMS_CONV[self.input_ndim]
    # === 1. create projection layer
    assert self.encoder is not None, \
        'ParallelLatents require encoder to be specified'
    # posterior projection (assume encoder shape and decoder shape the same)
    self._conv_posterior = layer(**self._network_kw, name='ConvPosterior')
    self._conv_posterior.build(decoder_shape)
    # === 2. distribution
    params_shape = self._conv_posterior.compute_output_shape(decoder_shape)
    self._dist_posterior = DistributionLambda(
        make_distribution_fn=partial(_create_dist,
                                     event_ndims=len(params_shape) - 1,
                                     dtype=self.dtype),
        name=f'{self.name}_posterior')
    self._dist_posterior.build(params_shape)
    # dynamically infer the shape
    latents_shape = tf.convert_to_tensor(
        self._dist_posterior(keras.layers.Input(params_shape[1:]))).shape
    self._latents_shape = latents_shape[1:]
    # create the prior N(0, I)
    self._prior = Independent(
        Normal(loc=tf.zeros(self.latents_shape, dtype=self.dtype),
               scale=tf.ones(self.latents_shape, dtype=self.dtype)),
        reinterpreted_batch_ndims=len(self.latents_shape),
        name=f'{self.name}_prior')
    # === 3. final output affine
    self._conv_out = _upsample_by_conv(
        layer,
        layer_t,
        input_shape=latents_shape,
        output_shape=decoder_shape,
        kernel_size=self._conv_posterior.kernel_size,
        padding=self._conv_posterior.padding,
        strides=self._conv_posterior.strides)