def losses(self):
    """Return the summed KL divergence of the posteriors from their priors.

    ``self.weight`` is compared against a flat Dirichlet prior over
    ``self.K`` components. ``self.ASR`` is compared against a fixed batch of
    three Beta priors with means 0.01, 0.5 and 0.99. Each element-wise KL
    tensor is reduced to a scalar before the two are added.
    """
    flat_dirichlet = tfd.Dirichlet(tf.ones([self.K]))
    # Concentrations (0.1, 9.9), (3, 3) and (9.9, 0.1).
    beta_prior = tfd.Beta([0.1, 3, 9.9], [9.9, 3, 0.1])
    weight_kl = tf.reduce_sum(tfd.kl_divergence(self.weight, flat_dirichlet))
    asr_kl = tf.reduce_sum(tfd.kl_divergence(self.ASR, beta_prior))
    return weight_kl + asr_kl
def testChiChiKL(self):
    """Check the analytic Chi/Chi KL against a closed form and a MC estimate."""
    # The df grids are disjoint (odd vs. even values). If they shared a value,
    # the analytic KL for that pair would be exactly zero while the Monte
    # Carlo estimate would be a tiny nonzero number, and the relative
    # tolerance check against kl_sample_ would fail.
    a_df = np.arange(1, 6, step=2, dtype=np.float64).reshape((-1, 1))
    b_df = np.arange(2, 7, step=2, dtype=np.float64).reshape((1, -1))
    a = tfd.Chi(df=a_df)
    b = tfd.Chi(df=b_df)

    # Closed-form KL(Chi(a_df) || Chi(b_df)), broadcast to a (3, 3) grid.
    half_a = 0.5 * a_df
    half_b = 0.5 * b_df
    true_kl = (0.5 * special.digamma(half_a) * (a_df - b_df)
               + special.gammaln(half_b) - special.gammaln(half_a))

    kl = tfd.kl_divergence(a, b)
    samples = a.sample(int(8e5), seed=test_util.test_seed())
    kl_sample = tf.reduce_mean(a.log_prob(samples) - b.log_prob(samples),
                               axis=0)
    kl_, kl_sample_ = self.evaluate([kl, kl_sample])
    self.assertAllClose(true_kl, kl_, atol=0., rtol=1e-14)
    self.assertAllClose(true_kl, kl_sample_, atol=0., rtol=5e-2)

    # The KL of a distribution from itself must be exactly zero.
    self_kl = tfd.kl_divergence(a, a)
    expected_zero_, self_kl_ = self.evaluate(
        [tf.zeros_like(self_kl), self_kl])
    self.assertAllEqual(expected_zero_, self_kl_)
def kl(self, img1, img2):
    """Kullback-Leibler divergence between the posterior and the prior.

    Args:
      img1: First input, forwarded to ``self._build`` together with ``img2``.
        NOTE(review): an earlier docstring described the inputs as
        ``seg`` (b, h, w, num_classes) and ``img`` (b, h, w, c); confirm the
        expected shapes for this two-image variant.
      img2: Second input, forwarded to ``self._build``.

    Returns:
      A dict mapping each hierarchy level index to a scalar KL term for that
      level. The per-pixel KL is summed over axes 1 and 2, then averaged.
      When ``self._loss_kwargs['balanced']`` is truthy, each level uses KL
      balancing: ``0.2 * KL(q || sg(p)) + 0.8 * KL(sg(q) || p)``, where
      ``sg`` blocks gradients.
    """
    self._build(img1, img2)
    posterior_out = self._q_sample
    prior_out = self._p_sample_z_q
    q_dists = posterior_out['distributions']
    p_dists = prior_out['distributions']
    kl_balanced = self._loss_kwargs['balanced']
    # KL-balancing weights. The term with the posterior frozen trains the
    # prior; the term with the prior frozen trains the posterior.
    posterior_weight = 0.2
    prior_weight = 0.8
    kl = {}
    for level, (q, p) in enumerate(zip(q_dists, p_dists)):
        # Presumably (b, h, w) per-pixel KL; axes 1 and 2 are summed below.
        if kl_balanced:
            # Gradient-blocked copies. Assumes q and p expose .loc and
            # .scale.diag, as MultivariateNormalDiag does.
            q_frozen = tfd.MultivariateNormalDiag(
                loc=tf.stop_gradient(q.loc),
                scale_diag=tf.stop_gradient(q.scale.diag))
            p_frozen = tfd.MultivariateNormalDiag(
                loc=tf.stop_gradient(p.loc),
                scale_diag=tf.stop_gradient(p.scale.diag))
            kl_per_pixel = (
                posterior_weight * tfd.kl_divergence(q, p_frozen)
                + prior_weight * tfd.kl_divergence(q_frozen, p))
        else:
            kl_per_pixel = tfd.kl_divergence(q, p)
        # Sum over axes 1 and 2: one KL value per batch instance.
        kl_per_instance = tf.reduce_sum(kl_per_pixel, axis=[1, 2])
        # Scalar (shape ()): mean over the batch.
        kl[level] = tf.reduce_mean(kl_per_instance)
    return kl
def KLsum(self):
    """Return the total KL divergence of the posteriors from their priors.

    Covers the (mu, mu_prior), (sigma, sigma_prior) and (theta, theta_prior)
    pairs. Each element-wise KL tensor is summed to a scalar before the
    three scalars are added.
    """
    kl_terms = [
        tf.reduce_sum(tfd.kl_divergence(posterior, prior))
        for posterior, prior in (
            (self.mu, self.mu_prior),
            (self.sigma, self.sigma_prior),
            (self.theta, self.theta_prior),
        )
    ]
    return kl_terms[0] + kl_terms[1] + kl_terms[2]
def KLsum(self):
    """Sum of KL divergences between posteriors and priors.

    * ``theta``: closed-form KL via ``tfd.kl_divergence``.
    * ``Y`` and ``Z``: computed manually from the distribution means as
      ``sum(m_q * log(m_q / m_p))``.

    ``gamma`` is deliberately excluded from the total. The previous version
    computed its KL but never added it, so the unused computation was
    dropped. Add ``tf.reduce_sum(tfd.kl_divergence(self.gamma,
    self.gamma_prior))`` back to the sum if it should be included.
    """
    def _mean_kl(posterior, prior):
        # xlogy(x, y) = x * log(y), but returns 0 where x == 0. This gives
        # the 0 * log 0 = 0 convention instead of NaN for zero-probability
        # entries.
        q_mean = posterior.mean()
        return tf.reduce_sum(tf.math.xlogy(q_mean, q_mean / prior.mean()))

    kl_theta = tf.reduce_sum(tfd.kl_divergence(self.theta, self.theta_prior))
    kl_Y = _mean_kl(self.Y, self.Y_prior)
    kl_Z = _mean_kl(self.Z, self.Z_prior)
    return kl_theta + kl_Y + kl_Z
def KLsum(self):
    """Sum of KL divergences between posteriors and priors.

    ``mu`` and ``sigma`` use the closed-form ``tfd.kl_divergence``. The KL
    for the multinomial ``ident`` variable is defined manually from the
    distribution means as ``sum(m_q * log(m_q / m_p))``.
    """
    kl_mu = tf.reduce_sum(tfd.kl_divergence(self.mu, self.mu_prior))
    kl_sigma = tf.reduce_sum(
        tfd.kl_divergence(self.sigma, self.sigma_prior))
    ident_mean = self.ident.mean()
    # xlogy returns 0 where the posterior mean is exactly 0 (0 * log 0 = 0)
    # instead of producing NaN.
    kl_ident = tf.reduce_sum(
        tf.math.xlogy(ident_mean, ident_mean / self.ident_prior.mean()))
    return kl_mu + kl_sigma + kl_ident
def divergence_from_states(self, lhs, rhs, mask=None):
    """Return KL(lhs || rhs) between the distributions built from two states.

    Both states are converted with ``self.dist_from_state(state, mask)``.
    When ``mask`` is given, the divergence is also passed through
    ``tools.mask`` with the same mask.
    """
    lhs_dist = self.dist_from_state(lhs, mask)
    rhs_dist = self.dist_from_state(rhs, mask)
    kl = tfd.kl_divergence(lhs_dist, rhs_dist)
    return kl if mask is None else tools.mask(kl, mask)
def kl(self, seg, img):
    """Kullback-Leibler divergence between the posterior and the prior.

    Args:
      seg: A tensor of shape (b, h, w, num_classes).
      img: A tensor of shape (b, h, w, c).

    Returns:
      A dictionary mapping each hierarchy level index to that level's KL
      term. The per-pixel KL is summed over axes 1 and 2 (one value per
      instance), then averaged over the batch into a scalar.
    """
    self._build(seg, img)
    posterior_out = self._q_sample
    prior_out = self._p_sample_z_q
    posteriors = posterior_out['distributions']
    priors = prior_out['distributions']
    return {
        level: tf.reduce_mean(
            tf.reduce_sum(tfd.kl_divergence(q, p), axis=[1, 2]))
        for level, (q, p) in enumerate(zip(posteriors, priors))
    }
def compute_losses(self, obs):
    """Compute the world-model loss and diagnostics for a batch.

    Args:
      obs: Dict with 'image', 'action' and 'reward' tensors. Images are
        scored with ``frames.log_prob`` and mapped to pixels with
        ``(x + 0.5) * 255``. They are therefore assumed to live in
        [-0.5, 0.5]. NOTE(review): confirm this against the data pipeline.

    Returns:
      A tuple ``(loss, mean reward log-likelihood, mean KL divergence,
      predicted frames as uint8 pixels, reward mode, target rewards,
      mean squared pixel error)``.
    """
    image = obs['image']
    action = obs['action']
    reward = obs['reward']

    state = self.initialize(tf.shape(image)[0])
    state['rewards'] = reward
    priors, posteriors = self.observe(action, image, state)
    features = self._get_features(posteriors)

    frames = self._predict_frames_tpl(features)
    if self._reward_from_frames:
        rewards = self._predict_reward_tpl(frames.mode(), reward[:, -1])
    else:
        rewards = self._predict_reward_tpl(features, reward[:, -1])

    obs_likelihood = frames.log_prob(image)
    reward_likelihood = rewards.log_prob(tf.to_float(reward))
    divergence = tfd.kl_divergence(self._get_distribution(posteriors),
                                   self._get_distribution(priors))
    # Free nats: KL values below the threshold are not penalised further.
    divergence = tf.maximum(self._free_nats, divergence)
    loss = tf.reduce_mean(divergence - obs_likelihood -
                          reward_likelihood * self._reward_loss_mul)

    # Compare prediction and target in the same [0, 255] pixel space.
    # Previously the pixel-space prediction was compared with the
    # normalised image, which made the reported error meaningless.
    frames_mode = tf.clip_by_value((frames.mode() + 0.5) * 255, 0, 255)
    image_pixels = (image + 0.5) * 255
    return (loss,
            tf.reduce_mean(reward_likelihood),
            tf.reduce_mean(divergence),
            tf.cast(frames_mode, dtype=tf.uint8),
            rewards.mode(),
            obs['reward'],
            tf.reduce_mean(tf.math.squared_difference(frames_mode,
                                                      image_pixels)))
def _train_model(self, data):
    """Run one world-model training step.

    Encodes the batch and rolls the dynamics model over it. It then builds
    the model loss from a KL term, negative likelihoods and, when
    ``self._c.cpc`` is set, a CPC loss. Finally it applies the update through
    ``self._model_opt``.

    NOTE(review): indentation was lost in this copy. The nesting below puts
    ``likes.image`` inside the ``else`` branch and the optimizer call after
    the tape context; confirm against the original file. The visible body
    ends after the optimizer call and does not use ``model_norm``.
    """
    with tf.GradientTape() as model_tape:
        embed = self._encode(data)
        post, prior = self._dynamics.observe(embed, data['action'], data['desc'])
        feat = self._dynamics.get_feat(post)
        image_pred = self._decode(feat)
        reward_pred = self._reward(feat)
        likes = tools.AttrDict()
        if self._c.cpc:
            # The CPC objective replaces the image reconstruction likelihood.
            # likes.image is not set in this branch.
            pred = self._cpc_pred(embed)
            # NOTE(review): the original author marked this line "caution!"
            # without explanation.
            cpc_loss = -1. * tf.math.reduce_mean(
                tools.compute_cpc_loss(pred, feat, self._c))
            model_loss = cpc_loss
        else:
            model_loss = cpc_loss = 0
            likes.image = tf.reduce_mean(image_pred.log_prob(
                data['image']))
        likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
        if self._c.pcont:
            pcont_pred = self._pcont(feat)
            pcont_target = self._c.discount * data['discount']
            likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
            likes.pcont *= self._c.pcont_scale
        prior_dist = self._dynamics.get_dist(prior)
        post_dist = self._dynamics.get_dist(post)
        div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
        # Free nats: the KL term is floored at self._c.free_nats.
        div = tf.maximum(div, self._c.free_nats)
        # Added on top of the CPC loss (0 when CPC is disabled).
        model_loss += self._c.kl_scale * div - sum(likes.values())
        # Presumably so that gradients summed across replicas average out.
        model_loss /= float(self._strategy.num_replicas_in_sync)
    model_norm = self._model_opt(model_tape, model_loss)
def call(self, inputs):
    """Return the weighted KL of N(mu, diag(std)) from a standard normal.

    Args:
      inputs: A pair ``(mu, std)`` that parameterises the variational
        distribution.

    Returns:
      ``self.lamb_kl`` times the ``K.mean`` of the KL divergence.
    """
    mu, std = inputs
    posterior = tfp.MultivariateNormalDiag(loc=mu, scale_diag=std)
    standard_normal = tfp.MultivariateNormalDiag(
        loc=K.zeros_like(mu), scale_diag=K.ones_like(std))
    mean_kl = K.mean(tfp.kl_divergence(posterior, standard_normal))
    return self.lamb_kl * mean_kl
def grad_planner(state, num_actions, horizon, proposals, iterations, imagine,
                 objective, kl_scale, step_size):
    """Plan actions by gradient ascent on a variational objective.

    The planner keeps a diagonal Gaussian over action sequences of shape
    (B, horizon, num_actions). For each of ``iterations`` steps it does the
    following:

    1. Samples ``proposals`` sequences per batch entry and clips them to
       [-1, 1] with a straight-through gradient.
    2. Scores them with ``objective(imagine(actions, state))``.
    3. Ascends ``(sum(scores) - kl_scale * sum(KL(N(mean, std) || N(0, 1))))
       / (B * P)`` with RMS-normalised gradients w.r.t. the mean and the
       pre-softplus std.

    Args:
      state: Dict of state tensors whose leading dimension is the batch B.
      num_actions: Action dimensionality A.
      horizon: Planning horizon H.
      proposals: Number of sampled sequences P per batch entry.
      iterations: Number of optimisation steps.
      imagine: Callable ``(actions[B*P, H, A], state) -> states``.
      objective: Callable ``states -> scores``; the last axis is summed.
      kl_scale: Weight of the KL penalty.
      step_size: Gradient step size.

    Returns:
      The first action of the optimised mean sequence, shape (B, A).
    """
    dtype = prec.global_policy().compute_dtype
    B, P = list(state.values())[0].shape[0], proposals
    H, A = horizon, num_actions
    flat_state = {k: tf.repeat(v, P, 0) for k, v in state.items()}
    mean = tf.zeros((B, H, A), dtype)
    # Pre-softplus std; softplus(0.54) ~= 1.0, so planning starts near N(0, 1).
    rawstd = 0.54 * tf.ones((B, H, A), dtype)
    for _ in range(iterations):
        noise = tf.random.normal((B, P, H, A), dtype=dtype)
        with tf.GradientTape(watch_accessed_variables=False) as tape:
            tape.watch(mean)
            tape.watch(rawstd)
            std = tf.nn.softplus(rawstd)
            samples = noise * std[:, None] + mean[:, None]
            # Straight-through clipping. The forward value is clipped to
            # [-1, 1]; the gradient passes through as if unclipped.
            samples = (tf.stop_gradient(tf.clip_by_value(samples, -1, 1))
                       + samples - tf.stop_gradient(samples))
            flat_samples = tf.reshape(samples, (B * P, H, A))
            states = imagine(flat_samples, flat_state)
            scores = objective(states)
            scores = tf.reshape(tf.reduce_sum(scores, -1), (B, P))
            # Reduce the element-wise (B, H, A) KL to a scalar. Left
            # unreduced, it broadcast the scalar score sum across all B*H*A
            # entries, inflating the score gradient by that factor relative
            # to the KL penalty.
            div = tf.reduce_sum(tfd.kl_divergence(
                tfd.Normal(mean, std),
                tfd.Normal(tf.zeros_like(mean), tf.ones_like(std))))
            elbo = tf.reduce_sum(scores) - kl_scale * div
            elbo /= tf.cast(tf.reduce_prod(tf.shape(scores)), dtype)
        grad_mean, grad_rawstd = tape.gradient(elbo, [mean, rawstd])
        # Normalise each gradient per batch entry by its RMS
        # (sqrt(E[g]^2 + Var[g]) = sqrt(E[g^2])), with a small epsilon.
        e, v = tf.nn.moments(grad_mean, [1, 2], keepdims=True)
        grad_mean /= tf.sqrt(e * e + v + 1e-4)
        e, v = tf.nn.moments(grad_rawstd, [1, 2], keepdims=True)
        grad_rawstd /= tf.sqrt(e * e + v + 1e-4)
        mean = tf.clip_by_value(mean + step_size * grad_mean, -1, 1)
        rawstd = rawstd + step_size * grad_rawstd
    return mean[:, 0, :]
def divergence_from_states(self, lhs, rhs, mask=None):
    """Compute the divergence measure between two states.

    Builds a distribution from each state with
    ``self.dist_from_state(state, mask)`` and returns the KL divergence of
    the first from the second. A non-None ``mask`` is also applied to the
    result through ``tools.mask``.
    """
    first, second = (self.dist_from_state(state, mask) for state in (lhs, rhs))
    kl = tfd.kl_divergence(first, second)
    if mask is not None:
        kl = tools.mask(kl, mask)
    return kl
def divergence(self, other: StateDist, mask: Optional[tf.Tensor] = None) -> tf.Tensor:
    """Compute the divergence measure with other state.

    Args:
      other: State whose distribution is the second argument of the KL.
      mask: Optional mask. It is forwarded to both ``to_dist`` calls and,
        when not None, applied to the result with ``apply_mask``.

    Returns:
      ``KL(self.to_dist(mask) || other.to_dist(mask))``, masked if requested.
    """
    own_dist = self.to_dist(mask)
    other_dist = other.to_dist(mask)
    result = tfd.kl_divergence(own_dist, other_dist)
    if mask is None:
        return result
    return apply_mask(result, mask)
def _train(self, data, test_data, log_images, step=1, should_print=False):
    """Run one world-model update and periodically dump decoded frames.

    Args:
      data: Training batch. Reads 'action', 'image', 'reward' and, when
        ``self._c.pcont`` is set, 'discount'.
      test_data: Batch used only for the frame dump; reads 'action' plus
        whatever ``self._encode`` consumes.
      log_images: Not used in the visible body.
      step: Training step. Frames are dumped when it is a positive multiple
        of 100000.
      should_print: Gates the frame dump, and the step is then recorded in
        ``self.already_printed_dict``.

    NOTE(review): indentation was lost in this copy. The nesting below
    assumes the whole dump runs inside ``if should_print:``; confirm against
    the original. The ``actor_tape`` context records no actor loss in the
    visible body.
    """
    with tf.GradientTape() as model_tape:
        embed = self._encode(data)
        post, prior = self._dynamics.observe(embed, data['action'])
        feat = self._dynamics.get_feat(post)
        image_pred = self._decode(feat)
        reward_pred = self._reward(feat)
        likes = tools.AttrDict()
        likes.image = tf.reduce_mean(image_pred.log_prob(data['image']))
        likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
        if self._c.pcont:
            pcont_pred = self._pcont(feat)
            pcont_target = self._c.discount * data['discount']
            likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
            likes.pcont *= self._c.pcont_scale
        prior_dist = self._dynamics.get_dist(prior)
        post_dist = self._dynamics.get_dist(post)
        div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
        # Free nats: the KL term is floored at self._c.free_nats.
        div = tf.maximum(div, self._c.free_nats)
        model_loss = self._c.kl_scale * div - sum(likes.values())
        model_loss /= float(self._strategy.num_replicas_in_sync)
    model_norm = self._model_opt(model_tape, model_loss)
    with tf.GradientTape() as actor_tape:
        if step % 100000 == 0 and step > 0:
            if should_print:
                self.already_printed_dict[step] = True
                # Imagine ahead from the test posterior and decode the result.
                test_embed = self._encode(test_data)
                test_post, test_prior = self._dynamics.observe(
                    test_embed, test_data['action'])
                imag_feat = self._imagine_ahead(test_post)
                imag_feat_sliced = imag_feat[:]
                decoded_images = self._decode(imag_feat_sliced)
                # Writes index 5 along axis 1 for the first 100 entries
                # along axis 0 (both bounds hard-coded).
                for j in range(100):
                    for i in [5]:
                        current_normal = decoded_images[j][i].distribution
                        mean = current_normal.loc
                        # Min-max normalise to [0, 1] for convert_image_dtype.
                        # NOTE(review): yields NaN if the frame is constant
                        # (max == min).
                        normalized_mean = tf.math.divide(
                            tf.math.subtract(mean, tf.reduce_min(mean)),
                            tf.math.subtract(tf.reduce_max(mean),
                                             tf.reduce_min(mean)))
                        normalized_mean_int = tf.image.convert_image_dtype(
                            normalized_mean, tf.uint8)
                        image_file = tf.io.encode_jpeg(normalized_mean_int)
                        file_name = "./img/steps{}traj{}img{}.jpg".format(
                            step, i, j)
                        tf.io.write_file(tf.constant(file_name), image_file)
def define_graph(config):
    """Build the TF1 training graph for ``network`` with an OOD output prior.

    The network is applied to the inputs and to noise-perturbed copies
    (``ood_inputs``) with shared weights. The loss combines four terms:

    * the weight KL from the REGULARIZATION_LOSSES collection, scaled per
      batch;
    * the data negative log-likelihood;
    * ``ncp_scale * KL(ood_mean_prior || ood_mean_dist)`` on the perturbed
      inputs;
    * optionally, a KL on the predicted std at the perturbed inputs.

    Args:
      config: Needs num_inputs, noise_std, center_at_target,
        divergence_scale, ncp_scale, ood_std_prior, learning_rate and
        clip_gradient, plus whatever ``network`` reads.

    Returns:
      ``tools.AttrDict`` of every local variable (placeholders,
      distributions, ``loss``, ``optimize``, ...). Because the result is
      ``locals()``, renaming any local here changes the returned keys.
    """
    network_tpl = tf.make_template('network', network, config=config)
    inputs = tf.placeholder(tf.float32, [None, config.num_inputs])
    targets = tf.placeholder(tf.float32, [None, 1])
    num_visible = tf.placeholder(tf.int32, [])
    batch_size = tf.shape(inputs)[0]
    data_dist, mean_dist = network_tpl(inputs)
    # Perturbed inputs, evaluated with the same template (shared weights).
    ood_inputs = inputs + tf.random_normal(tf.shape(inputs), 0.0,
                                           config.noise_std)
    ood_data_dist, ood_mean_dist = network_tpl(ood_inputs)
    # The network must have registered at least one weight KL term.
    assert len(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    divergence = sum([
        tf.reduce_sum(tensor)
        for tensor in tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    ])
    # Batches per pass over the num_visible points; used to spread the
    # weight KL over the batches.
    num_batches = tf.to_float(num_visible) / tf.to_float(batch_size)
    if config.center_at_target:
        ood_mean_prior = tfd.Normal(targets, 1.0)
    else:
        ood_mean_prior = tfd.Normal(0.0, 1.0)
    losses = [
        config.divergence_scale * divergence / num_batches,
        -data_dist.log_prob(targets),
        config.ncp_scale * tfd.kl_divergence(ood_mean_prior, ood_mean_dist),
    ]
    if config.ood_std_prior:
        sg = tf.stop_gradient
        # Means are gradient-blocked, so this term only shapes the std.
        ood_std_dist = tfd.Normal(sg(ood_mean_dist.mean()),
                                  ood_data_dist.stddev())
        ood_std_prior = tfd.Normal(sg(ood_mean_dist.mean()),
                                   config.ood_std_prior)
        # NOTE(review): this rebinds `divergence`. The returned `divergence`
        # entry is the weight KL only when ood_std_prior is unset.
        divergence = tfd.kl_divergence(ood_std_prior, ood_std_dist)
        losses.append(config.ncp_scale * divergence)
    loss = sum(tf.reduce_sum(loss) for loss in losses) / tf.to_float(batch_size)
    optimizer = tf.train.AdamOptimizer(config.learning_rate)
    gradients, variables = zip(
        *optimizer.compute_gradients(loss, colocate_gradients_with_ops=True))
    if config.clip_gradient:
        gradients, _ = tf.clip_by_global_norm(gradients, config.clip_gradient)
    optimize = optimizer.apply_gradients(zip(gradients, variables))
    # Exposed for evaluation: predictive mean, aleatoric std (data noise)
    # and std of the mean (model uncertainty).
    data_mean = mean_dist.mean()
    data_noise = data_dist.stddev()
    data_uncertainty = mean_dist.stddev()
    return tools.AttrDict(locals())
def _compute_kl_loss(self, post_probs, prior_probs):
    """Compute KL divergence between two OneHotCategorical distributions.

    Notes:
        KL[ Q(z_post) || P(z_prior) ]
        Q(z_post) := Q(z | h, o)
        P(z_prior) := P(z | h)

        Scratch Impl.:
            qlogq = post_probs * tf.math.log(post_probs)
            qlogp = post_probs * tf.math.log(prior_probs)
            kl_div = tf.reduce_sum(qlogq - qlogp, [1, 2])

        KL balancing with ``alpha = self.config.kl_alpha``:
            alpha * KL[sg(Q) || P] + (1 - alpha) * KL[Q || sg(P)]
        where sg blocks gradients. The first term trains the prior; the
        second regularises the posterior.

    Inputs:
        prior_probs (L, B, latent_dim, n_atoms)
        post_probs (L, B, latent_dim, n_atoms)

    Returns:
        Scalar: mean of the balanced KL over all remaining dimensions.
    """
    #: Add a small value to prevent inf KL. Out-of-place addition, so a
    #: caller's NumPy arrays are never mutated (``+=`` would modify them).
    post_probs = post_probs + 1e-5
    prior_probs = prior_probs + 1e-5

    def _latent_dist(probs):
        # One categorical per latent. latent_dim is folded into the event,
        # so the KL is summed over it.
        return tfd.Independent(tfd.OneHotCategorical(probs=probs),
                               reinterpreted_batch_ndims=1)

    #: KL Balancing: See 2.2 BEHAVIOR LEARNING Algorithm 2
    kl_div1 = tfd.kl_divergence(_latent_dist(tf.stop_gradient(post_probs)),
                                _latent_dist(prior_probs))
    kl_div2 = tfd.kl_divergence(_latent_dist(post_probs),
                                _latent_dist(tf.stop_gradient(prior_probs)))
    alpha = self.config.kl_alpha
    kl_loss = alpha * kl_div1 + (1. - alpha) * kl_div2
    #: Batch mean
    return tf.reduce_mean(kl_loss)
def _train(self, data, log_images):
    """Run one training step for the world model, actor and value function.

    Args:
      data: Batch dict. Reads 'action', 'reward', the observation key
        ``self._c.obs_type`` and, when ``self._c.pcont`` is set, 'discount'.
      log_images: Boolean (tensor) that enables the image and reward
        summaries.

    Each loss is divided by ``num_replicas_in_sync``. Summaries are written
    only by the replica with ``replica_id_in_sync_group == 0``.
    """
    with tf.GradientTape() as model_tape:
        embed = self._encode(data)
        post, prior = self._dynamics.observe(embed, data['action'])
        feat = self._dynamics.get_feat(post)
        image_pred = self._decode(feat)
        reward_pred = self._reward(feat)
        likes = tools.AttrDict()
        # The observation key is configurable; the term is still named 'image'.
        likes.image = tf.reduce_mean(image_pred.log_prob(data[self._c.obs_type]))
        likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
        if self._c.pcont:
            pcont_pred = self._pcont(feat)
            pcont_target = self._c.discount * data['discount']
            likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
            likes.pcont *= self._c.pcont_scale
        prior_dist = self._dynamics.get_dist(prior)
        post_dist = self._dynamics.get_dist(post)
        div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
        # Free nats: the KL term is floored at self._c.free_nats.
        div = tf.maximum(div, self._c.free_nats)
        model_loss = self._c.kl_scale * div - sum(likes.values())
        model_loss /= float(self._strategy.num_replicas_in_sync)

    with tf.GradientTape() as actor_tape:
        imag_feat = self._imagine_ahead(post)
        reward = tf.cast(self._reward(imag_feat).mode(), 'float')  # cast: to address the output of bernoulli
        if self._c.pcont:
            pcont = self._pcont(imag_feat).mean()
        else:
            pcont = self._c.discount * tf.ones_like(reward)
        value = self._value(imag_feat).mode()
        returns = tools.lambda_return(
            reward[:-1], value[:-1], pcont[:-1],
            bootstrap=value[-1], lambda_=self._c.disclam, axis=0)
        # Cumulative discount weights along axis 0 (presumably the
        # imagination horizon). They start at 1 and have one entry fewer
        # than pcont, matching `returns`. Treated as constants.
        discount = tf.stop_gradient(tf.math.cumprod(tf.concat(
            [tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
        actor_loss = -tf.reduce_mean(discount * returns)
        actor_loss /= float(self._strategy.num_replicas_in_sync)

    with tf.GradientTape() as value_tape:
        value_pred = self._value(imag_feat)[:-1]
        # Regress the value head onto the (constant) lambda-returns.
        target = tf.stop_gradient(returns)
        value_loss = -tf.reduce_mean(discount * value_pred.log_prob(target))
        value_loss /= float(self._strategy.num_replicas_in_sync)

    model_norm = self._model_opt(model_tape, model_loss)
    actor_norm = self._actor_opt(actor_tape, actor_loss)
    value_norm = self._value_opt(value_tape, value_loss)

    # Only the first replica writes summaries.
    if tf.distribute.get_replica_context().replica_id_in_sync_group == 0:
        if self._c.log_scalars:
            self._scalar_summaries(
                data, feat, prior_dist, post_dist, likes, div,
                model_loss, value_loss, actor_loss, model_norm, value_norm,
                actor_norm)
        if tf.equal(log_images, True):
            self._image_summaries(data, embed, image_pred)
            self._reward_summaries(data, reward_pred)
def KL_Beta_Binomial(Z_a, Z_b, X_a, X_b):
    """Calculate KL divergence between Beta distribution and Binomial likelihood.

    Computes ``E_q[log q(theta) - log Binom(X_a | X_a + X_b, theta)]`` for
    ``q = Beta(Z_a, Z_b)``, with ``k = X_a`` and ``n = X_a + X_b``.

    The binomial coefficient and the Beta function are related by
    ``C(n, k) = 1 / ((n + 1) * B(k + 1, n - k + 1))``. Hence
    ``Binom(k | n, theta) = Beta(theta | k + 1, n - k + 1) / (n + 1)``, and

        KL = KL(Beta(Z_a, Z_b) || Beta(X_a + 1, X_b + 1)) + log(n + 1)

    See https://en.wikipedia.org/wiki/Beta_function#Properties

    Args:
      Z_a, Z_b: Concentrations of the variational Beta distribution.
      X_a, X_b: Success / failure counts of the binomial observations.

    Returns:
      Element-wise tensor (broadcast of the inputs).
    """
    # TODO: introduce sparse matrix
    beta_kl = tfd.kl_divergence(tfd.Beta(Z_a, Z_b), tfd.Beta(X_a + 1, X_b + 1))
    # The binomial likelihood is the Beta density divided by (n + 1), so the
    # correction is +log(n + 1). It was previously subtracted, which is the
    # wrong sign. The term does not depend on Z_a/Z_b, so gradients w.r.t.
    # the variational parameters are unaffected.
    log_norm_diff = tf.math.log(X_a + X_b + 1)
    return beta_kl + log_norm_diff
def call(self, inputs):
    """Register the latent regularisation loss and return the KL term.

    Args:
      inputs: Pair ``(mu, std)`` for the variational distribution
        ``N(mu, diag(std))``.

    The layer adds ``(1 - self.rho) * (self.lamb_kl * KL + self.lamb_ent * H)``
    through ``self.add_loss``. Here KL is the ``K.mean`` of the KL from a
    standard normal, and H is the ``K.mean`` of the variational entropy.

    Returns:
      The unweighted mean KL divergence.
    """
    mu, std = inputs
    posterior = tfp.MultivariateNormalDiag(loc=mu, scale_diag=std)
    standard_normal = tfp.MultivariateNormalDiag(
        loc=K.zeros_like(mu), scale_diag=K.ones_like(std))
    mean_entropy = K.mean(posterior.entropy())
    mean_kl = K.mean(tfp.kl_divergence(posterior, standard_normal))
    weighted = self.lamb_kl * mean_kl + self.lamb_ent * mean_entropy
    self.add_loss((1 - self.rho) * weighted)
    return mean_kl
def kl_divergence(q,
                  p,
                  use_analytic_kl=False,
                  q_sample=lambda q: q.sample(),
                  reduce_axis=(),
                  auto_remove_independent=True,
                  name=None):
    """Calculate KL(q(x) || p(x)), analytically or by Monte Carlo.

    Parameters
    ----------
    q : the first distribution
    p : the second distribution
    use_analytic_kl : bool (default: False)
        if True, return the closed form ``tfd.kl_divergence(q, p)``;
        otherwise estimate ``mean(q.log_prob(z) - p.log_prob(z))`` from
        samples ``z``
    q_sample : {callable, Number, Tensor}
        - callable: called as ``q_sample(q)`` to draw the samples
        - Number: number of samples, drawn with ``q.sample(int(q_sample))``
        - anything else: used directly as the samples
    reduce_axis : {None, int, tuple}
        axis over which the Monte Carlo estimate is averaged; the default
        ``()`` keeps all original dimensions
    auto_remove_independent : bool (default: True)
        if ``q`` or ``p`` is a ``tfd.Independent`` wrapper, replace it with
        the wrapped distribution. This happens before either branch, so the
        Monte Carlo estimate also uses the unwrapped distributions.
    name : {None, str}
        name scope; defaults to ``"KL_q<q_name>_p<p_name>"``

    Returns
    -------
    Tensor : the analytic KL, or the (optionally reduced) Monte Carlo estimate
    """
    if auto_remove_independent:
        if isinstance(q, tfd.Independent):
            q = q.distribution
        if isinstance(p, tfd.Independent):
            p = p.distribution

    def _short_name(dist):
        # Last non-empty component of the (possibly scoped) distribution name.
        return [part for part in dist.name.split('/') if part][-1]

    default_scope = "KL_q%s_p%s" % (_short_name(q), _short_name(p))
    with tf.compat.v1.name_scope(name, default_scope):
        if bool(use_analytic_kl):
            return tfd.kl_divergence(q, p)
        if callable(q_sample):
            z = q_sample(q)
        elif isinstance(q_sample, Number):
            z = q.sample(int(q_sample))
        else:
            z = q_sample
        # Log-ratio per sample, then average over reduce_axis.
        log_ratio = q.log_prob(z) - p.log_prob(z)
        return tf.reduce_mean(input_tensor=log_ratio, axis=reduce_axis)
def _train(self, data, log_images):
    """Run one training step for the world model, actor and value function.

    Args:
      data: Batch dict. Reads 'action', 'laser' (the decoded observation),
        'reward' and, when ``self._c.pcont`` is set, 'discount'.
      log_images: Not used in the visible body.

    NOTE(review): the losses here are not divided by a replica count; confirm
    this step never runs under a multi-replica distribution strategy.
    """
    with tf.GradientTape() as model_tape:
        embed = self._encode(data)
        post, prior = self._dynamics.observe(embed, data['action'])
        feat = self._dynamics.get_feat(post)
        image_pred = self._decode(feat)
        reward_pred = self._reward(feat)
        likes = tools.AttrDict()
        # The decoder reconstructs the 'laser' observation; the likelihood is
        # still stored under the key 'image'.
        likes.image = tf.reduce_mean(image_pred.log_prob(data['laser']))
        likes.reward = tf.reduce_mean(reward_pred.log_prob(data['reward']))
        if self._c.pcont:
            pcont_pred = self._pcont(feat)
            pcont_target = self._c.discount * data['discount']
            likes.pcont = tf.reduce_mean(pcont_pred.log_prob(pcont_target))
            likes.pcont *= self._c.pcont_scale
        prior_dist = self._dynamics.get_dist(prior)
        post_dist = self._dynamics.get_dist(post)
        div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
        # Free nats: the KL term is floored at self._c.free_nats.
        div = tf.maximum(div, self._c.free_nats)
        model_loss = self._c.kl_scale * div - sum(likes.values())

    with tf.GradientTape() as actor_tape:
        imag_feat = self._imagine_ahead(post)
        reward = self._reward(imag_feat).mode()
        if self._c.pcont:
            pcont = self._pcont(imag_feat).mean()
        else:
            pcont = self._c.discount * tf.ones_like(reward)
        value = self._value(imag_feat).mode()
        returns = tools.lambda_return(reward[:-1],
                                      value[:-1],
                                      pcont[:-1],
                                      bootstrap=value[-1],
                                      lambda_=self._c.disclam,
                                      axis=0)
        # Cumulative discount weights along axis 0. They start at 1 and have
        # one entry fewer than pcont, matching `returns`. Treated as constants.
        discount = tf.stop_gradient(
            tf.math.cumprod(
                tf.concat([tf.ones_like(pcont[:1]), pcont[:-2]], 0), 0))
        actor_loss = -tf.reduce_mean(discount * returns)

    with tf.GradientTape() as value_tape:
        value_pred = self._value(imag_feat)[:-1]
        # Regress the value head onto the (constant) lambda-returns.
        target = tf.stop_gradient(returns)
        value_loss = -tf.reduce_mean(
            discount * value_pred.log_prob(target))

    model_norm = self._model_opt(model_tape, model_loss)
    actor_norm = self._actor_opt(actor_tape, actor_loss)
    value_norm = self._value_opt(value_tape, value_loss)
    if self._c.log_scalars:
        self._scalar_summaries(data, feat, prior_dist, post_dist, likes,
                               div, model_loss, value_loss, actor_loss,
                               model_norm, value_norm, actor_norm)
def get_loss(self, count_layers, target="ELBO", axis=None, **kwargs):
    """Return the loss per gene (axis=0) or over all genes.

    Args:
      count_layers: Count data forwarded to ``self.logLik_MC``.
      target: ``"marginLik"`` returns the negative summed Monte Carlo
        marginal log-likelihood. Any other value returns the negative ELBO:
        ``KL(tauDist || tauPrior) + KL(Z || Z_prior) - tau_eblo_term
        - logLik``.
      axis: Reduction axis passed to every reduction.
      **kwargs: Forwarded to ``self.logLik_MC``.

    Note: each module must be reduced (``reduce_sum``) on its own before the
    terms are combined. Adding un-reduced modules first and summing
    afterwards does not work properly.
    """
    if target == "marginLik":
        marginal = self.logLik_MC(count_layers, target="marginLik", **kwargs)
        return -tf.reduce_sum(marginal, axis=axis)

    kl_tau = tf.reduce_sum(
        tfd.kl_divergence(self.tauDist, self.tauPrior), axis=axis)
    kl_z = tf.reduce_sum(tfd.kl_divergence(self.Z, self.Z_prior), axis=axis)
    tau_term = self.tau_eblo_term(axis=axis)
    log_lik = tf.reduce_sum(
        self.logLik_MC(count_layers, target="ELBO", **kwargs), axis=axis)
    return kl_tau + kl_z - tau_term - log_lik
def _model_train_step(self, data, prefix='train'):
    """Compute the world-model loss on a masked batch; optimise when training.

    Args:
      data: Batch dict with 'image', 'action', 'reward', 'mask' and, when
        ``self._c.pcont`` is set, 'terminal'. Only entries selected by
        ``data['mask']`` contribute to the likelihood and KL terms.
      prefix: ``'train'`` applies an optimizer step. Any other value only
        computes and logs the losses.

    NOTE(review): indentation was lost in this copy. The placement of
    ``self._model_step += 1`` inside the training branch is reconstructed;
    confirm against the original.
    NOTE(review): the summary keys are hard-coded to 'model_train/...'
    whatever ``prefix`` is, while ``_image_summaries`` receives ``prefix``.
    Confirm whether non-train calls should log under their own prefix.
    """
    with tf.GradientTape() as model_tape:
        embed = self._encode(data)
        post, prior = self._dynamics.observe(embed, data['action'])
        feat = self._dynamics.get_feat(post)
        image_pred = self._decode(feat)
        reward_pred = self._reward(feat)
        likes = tools.AttrDict()
        likes.image = tf.reduce_mean(
            tf.boolean_mask(image_pred.log_prob(data['image']), data['mask']))
        likes.reward = tf.reduce_mean(
            tf.boolean_mask(reward_pred.log_prob(data['reward']), data['mask']))
        if self._c.pcont:
            pcont_pred = self._pcont(feat)
            pcont_target = data['terminal']
            likes.pcont = tf.reduce_mean(
                tf.boolean_mask(pcont_pred.log_prob(pcont_target), data['mask']))
            likes.pcont *= self._c.pcont_scale
        # Drop masked-out entries from every prior/posterior statistic before
        # building the distributions. This overwrites the entries of the dicts
        # returned by observe and assumes `post` has the same keys as `prior`.
        for key in prior.keys():
            prior[key] = tf.boolean_mask(prior[key], data['mask'])
            post[key] = tf.boolean_mask(post[key], data['mask'])
        prior_dist = self._dynamics.get_dist(prior)
        post_dist = self._dynamics.get_dist(post)
        div = tf.reduce_mean(tfd.kl_divergence(post_dist, prior_dist))
        model_loss = self._c.kl_scale * div - sum(likes.values())
    if prefix == 'train':
        model_norm = self._model_opt(model_tape, model_loss)
        self._model_step += 1

    if self._model_step % self._c.log_every == 0:
        self._image_summaries(data, embed, image_pred, self._model_step,
                              prefix)
        model_summaries = dict()
        model_summaries['model_train/KL Divergence'] = tf.reduce_mean(div)
        model_summaries['model_train/image_recon'] = tf.reduce_mean(
            likes.image)
        model_summaries['model_train/reward_recon'] = tf.reduce_mean(
            likes.reward)
        model_summaries['model_train/model_loss'] = tf.reduce_mean(
            model_loss)
        # model_norm only exists after an optimizer step.
        if prefix == 'train':
            model_summaries['model_train/model_norm'] = tf.reduce_mean(
                model_norm)
        if self._c.pcont:
            model_summaries['model_train/terminal_recon'] = tf.reduce_mean(
                likes.pcont)
        self._write_summaries(model_summaries, self._model_step)
def kullback_leibler_x(self, x_param):
    """Return the mean KL of q(x) from a standard normal prior.

    ``x_param`` is split along its last axis. The first ``self.dim_latent``
    channels are the mean. The remaining channels, passed through softplus,
    are used as the Normal *scale* (standard deviation).
    NOTE(review): the original local name ``x_var`` suggested a variance;
    confirm which was intended.

    Returns:
      KL summed over axes 1 and 2, then averaged over axis 0 (a scalar).
    """
    loc = x_param[:, :, :self.dim_latent]
    scale = tf.nn.softplus(x_param[:, :, self.dim_latent:])
    posterior = tfd.Normal(loc, scale)
    prior = tfd.Normal(tf.zeros_like(loc), tf.ones_like(scale))
    kl_per_example = tf.reduce_sum(tfd.kl_divergence(posterior, prior), [2, 1])
    return tf.reduce_mean(kl_per_example)
def call(self, inputs):
    """Register the classification and regularisation loss; return the KL.

    Args:
      inputs: Pair ``(prob, y)``. ``prob`` holds predicted class
        probabilities along the last axis. ``y`` holds the targets passed to
        ``K.categorical_crossentropy`` (presumably one-hot).

    The layer adds the following through ``self.add_loss``:
        ``rho * CE + (1 - rho) * (lamb_kl * KL + lamb_ent * H)``
    CE is the mean categorical cross-entropy, KL the mean KL from a uniform
    categorical prior, and H the mean entropy of the predicted categorical.

    Returns:
      The unweighted mean KL divergence.
    """
    prob, y = inputs
    uni_prob = K.ones_like(prob)/tf.cast(K.shape(prob)[-1], tf.float32)
    # `prob` holds probabilities; it is fed to categorical_crossentropy with
    # the default from_logits=False. It must therefore be passed as `probs=`.
    # The first positional parameter of Categorical is `logits`, and passing
    # probabilities there silently applies an extra softmax.
    var_dist = tfp.Categorical(probs=prob)
    # Uniform prior over the classes.
    pri_dist = tfp.Categorical(probs=uni_prob)
    ent_loss = K.mean(var_dist.entropy())
    kl_loss = K.mean(tfp.kl_divergence(var_dist, pri_dist))
    cent_loss = K.mean(K.categorical_crossentropy(y, prob))
    self.add_loss(
        self.rho * cent_loss
        + (1 - self.rho) * (self.lamb_kl * kl_loss + self.lamb_ent * ent_loss))
    return kl_loss
def loss(self, prior, latent_dist, target):
    """Return the KL term plus the reconstruction negative log-likelihood.

    One sample is drawn from ``latent_dist`` and decoded. The reconstruction
    NLL ``-log p(target | z)`` is clipped to [-1e5, 1e5] and summed to a
    scalar. ``KL(latent_dist || self.prior)`` is clipped the same way but not
    reduced, so the result has the KL's shape.

    NOTE(review): the ``prior`` argument is ignored; the KL uses
    ``self.prior``. Confirm which one is intended.
    """
    z = latent_dist.sample()
    decoded = self.decode(z)
    # Concatenating a single tensor along axis 1 leaves it unchanged.
    nll = -tf.concat([
        decoded.log_prob(target),
    ], axis=1)
    nll = tf.reduce_sum(tf.clip_by_value(nll, -1e5, 1e5))
    kl = tf.clip_by_value(tfd.kl_divergence(latent_dist, self.prior),
                          -1e5, 1e5)
    return kl + nll
def network(inputs, config):
    """Build a regression network whose output layer is Bayesian.

    Hidden layers are deterministic ``tf.layers.dense`` layers with
    leaky-ReLU, sized by ``config.layer_sizes``. The output layer has a
    factorised Normal posterior over its kernel, a Normal(0, weight_std)
    kernel prior and a point-mass (Deterministic) bias. The kernel KL is
    added to the REGULARIZATION_LOSSES collection.

    Args:
      inputs: Input tensor; presumably (batch, num_inputs). TODO: confirm.
      config: Needs ``weight_std`` and ``layer_sizes``.

    Returns:
      ``(data_dist, mean_dist)``:
        * ``data_dist`` is Normal(output of DenseReparameterization,
          learned std).
        * ``mean_dist`` is the Normal over the mean output induced by the
          kernel posterior: mean ``hidden @ E[W] + b``, variance
          ``hidden**2 @ Var[W]``.
    """
    # Inverse softplus of weight_std, so softplus(init_std) == weight_std.
    init_std = np.log(np.exp(config.weight_std) - 1).astype(np.float32)
    hidden = inputs
    # Define hidden layers according to config.layer_sizes
    for size in config.layer_sizes:
        hidden = tf.layers.dense(hidden, size, tf.nn.leaky_relu)
    # Posterior over the (hidden_dim, 1) output kernel. Means are initialised
    # from N(0, weight_std); stds start at softplus(init_std) == weight_std.
    kernel_posterior = tfd.Independent(
        tfd.Normal(
            tf.get_variable('kernel_mean', (hidden.shape[-1].value, 1),
                            tf.float32,
                            tf.random_normal_initializer(0, config.weight_std)),
            tf.nn.softplus(
                tf.get_variable('kernel_std', (hidden.shape[-1].value, 1),
                                tf.float32,
                                tf.constant_initializer(init_std)))), 2)
    # Prior on the kernel: Normal(0, weight_std) with the posterior's shape.
    kernel_prior = tfd.Independent(
        tfd.Normal(
            tf.zeros_like(kernel_posterior.mean()),
            tf.zeros_like(kernel_posterior.mean()) + tf.nn.softplus(init_std)), 2)
    # No prior on the bias; its posterior is a point mass.
    bias_prior = None
    bias_posterior = tfd.Deterministic(
        tf.get_variable('bias_mean', (1, ), tf.float32,
                        tf.constant_initializer(0.0)))
    tf.add_to_collection(tf.GraphKeys.REGULARIZATION_LOSSES,
                         tfd.kl_divergence(kernel_posterior, kernel_prior))
    # Bayesian output layer that computes the mean from the fixed
    # distributions above.
    mean = tfp.layers.DenseReparameterization(
        1,
        kernel_prior_fn=lambda *args, **kwargs: kernel_prior,
        kernel_posterior_fn=lambda *args, **kwargs: kernel_posterior,
        bias_prior_fn=lambda *args, **kwargs: bias_prior,
        bias_posterior_fn=lambda *args, **kwargs: bias_posterior)(hidden)
    # Analytic distribution of the mean output under the kernel posterior.
    mean_dist = tfd.Normal(
        tf.matmul(hidden, kernel_posterior.mean()) + bias_posterior.mean(),
        tf.sqrt(tf.matmul(hidden**2, kernel_posterior.variance())))
    # Std from a separate non-Bayesian head running in parallel with the mean
    # layer. The 1e-6 keeps it strictly positive.
    std = tf.layers.dense(hidden, 1, tf.nn.softplus) + 1e-6
    data_dist = tfd.Normal(mean, std)
    return data_dist, mean_dist
def call(self, inputs):
    """Return the weighted KL of the variational Gaussian from its prior.

    When ``self.encodertype == "user"``, ``inputs`` is ``(mu, std)`` and the
    prior is a standard normal. Otherwise ``inputs`` is ``(mu, std, uemb)``
    and the prior is ``N(uemb[:, 0, :], diag(uemb[:, 1, :]))``.

    Returns:
      ``self.lamb_kl`` times the ``K.mean`` of the KL divergence.
    """
    if self.encodertype == "user":
        mu, std = inputs
        posterior = tfp.MultivariateNormalDiag(loc=mu, scale_diag=std)
        prior_loc, prior_scale = K.zeros_like(mu), K.ones_like(std)
    else:
        mu, std, uemb = inputs
        posterior = tfp.MultivariateNormalDiag(loc=mu, scale_diag=std)
        # Prior built from the user embedding: row 0 = loc, row 1 = scale.
        prior_loc, prior_scale = uemb[:, 0, :], uemb[:, 1, :]
    prior = tfp.MultivariateNormalDiag(loc=prior_loc, scale_diag=prior_scale)
    return self.lamb_kl * K.mean(tfp.kl_divergence(posterior, prior))
def perform_fwd_pass(self):
    """Run the VAE forward pass on ``self.x`` and store the mean ELBO.

    The encoder returns the mean and log-variance of q(z | x). One sample of
    z is decoded into Bernoulli logits for p(x | z). Sets ``self.elbo`` to
    the mean over axis 0 of ``sum log p(x | z) - KL(q(z | x) || N(0, I))``.
    """
    loc, log_variance = self.nets.encoder(self.x)
    posterior = tfpd.Normal(loc=loc, scale=tf.exp(0.5 * log_variance))
    z = posterior.sample()
    likelihood = tfpd.Bernoulli(logits=self.nets.decoder(z))
    prior = tfpd.Normal(loc=tf.zeros_like(z), scale=tf.ones_like(z))
    # Per-example terms: KL summed over axis 1, log-likelihood over axes 1-3.
    kl_per_example = tf.reduce_sum(tfpd.kl_divergence(posterior, prior), axis=1)
    loglik_per_example = tf.reduce_sum(likelihood.log_prob(self.x),
                                       axis=(1, 2, 3))
    self.elbo = tf.reduce_mean(loglik_per_example - kl_per_example, axis=0)
def kl_divergence(q,
                  p,
                  use_analytic_kl=False,
                  q_sample=lambda q: q.sample(),
                  reduce_axis=(),
                  name=None):
    """Calculate KL(q(x) || p(x)), analytically or by Monte Carlo.

    Parameters
    ----------
    q : the first distribution
    p : the second distribution
    use_analytic_kl : boolean
        if True, return the closed form ``tfd.kl_divergence(q, p)``
    q_sample : {callable, Tensor}
        callable that draws samples from ``q(x)`` (called with ``q``), or
        the samples themselves
    reduce_axis : {None, int, tuple}
        axis over which the Monte Carlo estimate is averaged; the default
        ``()`` keeps all original dimensions
    name : {None, str}
        name scope; defaults to ``"KL_q<q_name>_p<p_name>"``

    Returns
    -------
    Tensor : the analytic KL, or the (optionally reduced) Monte Carlo
        estimate ``mean(q.log_prob(z) - p.log_prob(z))``
    """
    # Last non-empty component of each (possibly scoped) distribution name.
    q_label, p_label = (
        [token for token in dist.name.split('/') if token][-1]
        for dist in (q, p))
    with tf.compat.v1.name_scope(name, "KL_q%s_p%s" % (q_label, p_label)):
        if bool(use_analytic_kl):
            return tfd.kl_divergence(q, p)
        z = q_sample(q) if callable(q_sample) else q_sample
        return tf.reduce_mean(input_tensor=q.log_prob(z) - p.log_prob(z),
                              axis=reduce_axis)