class GanWAE(object):

    def __init__(self, model, args):
        self.model = model
        self.args = args
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x - model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                               logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z),
                                                   tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            raise NotImplementedError

        def energ_emb(z):
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2
        pz_logits = model._marginal_energy(pz_sample)
        qz_logits = model._marginal_energy(model.z)
        self.gp_loss = tf.reduce_mean(energ_emb(model.z)**2) * 2
        # discriminator loss: prior samples are "real", posterior samples "fake"
        self.score_loss = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(pz_logits) + 1e-7)) +
                            tf.reduce_mean(tf.log(1 - tf.nn.sigmoid(qz_logits) + 1e-7)))
        self.score_opt_op = optimize(
            self.score_loss + args.grad_penalty * self.gp_loss, [MARGINAL_ENERGY], args)
        self.kl = -tf.reduce_mean(tf.math.log_sigmoid(qz_logits))  # non-saturating GAN loss
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)
        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/disc': self.score_loss,
            'loss/gp': self.gp_loss,
            'loss/kl': self.kl
        }
        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
        for j in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
def __init__(self, model, args):
    self.model = model
    self.args = args
    im_axes = list(range(1, model.x.shape.ndims))
    if args.observation == 'normal':
        self.reconstruction_loss = tf.reduce_mean(
            tf.reduce_sum((model.x - model.qxz_mean)**2, axis=im_axes))
    elif args.observation == 'sigmoid':
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                           logits=model.logits)
        self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
    else:
        raise NotImplementedError

    def energ_emb(z):
        return model._marginal_energy(z)

    if args.latent == 'euc':
        if args.mpf_method == 'ld':
            y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
        else:
            y, neg_mpf_loss = mpf_euc_spos(
                model.z, energ_emb, args.mpf_lr, alpha=args.mpf_spos_alpha)
    elif args.latent == 'sph' and args.mpf_method == 'ld':
        y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
    else:
        raise NotImplementedError

    self.mpf_loss = tf.reduce_mean(-neg_mpf_loss) * 1e-3 / args.mpf_lr
    self.gp_loss = tf.reduce_mean(energ_emb(y)**2) + tf.reduce_mean(energ_emb(model.z)**2)
    self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
    self.score_opt_op = optimize(self.score_loss, [MARGINAL_ENERGY], args)

    if args.latent == 'euc':
        self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z),
                                               tf.ones_like(model.z))
        self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
    elif args.latent == 'sph':
        self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
        self.log_pz = self.dist_pz.log_prob(model.z)
    else:
        raise NotImplementedError

    # KL = E_q[log q - log p] = E_q[-log p - energy_q], up to the (constant)
    # log normalizer of the energy-based model of q(z)
    self.kl = -tf.reduce_mean(self.log_pz) - tf.reduce_mean(model._marginal_energy(model.z))
    self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
    self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)
    self.print = {
        'loss/recon': self.reconstruction_loss,
        'loss/wae': self.wae_loss,
        'loss/mpf': self.mpf_loss,
        'loss/gp': self.gp_loss
    }
    self.lc = locals()
def forward(self, X, u):
    [_, self.bs, c, self.d, self.d] = X.shape
    T = len(self.t_eval)
    # encode
    self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(X[0])
    self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(X[1])
    # reparametrize
    self.Q_r0 = Normal(self.r0_m, self.r0_v)
    self.P_normal = Normal(torch.zeros_like(self.r0_m), torch.ones_like(self.r0_v))
    self.r0 = self.Q_r0.rsample()
    self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
    self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
    self.phi0 = self.Q_phi0.rsample()
    while torch.isnan(self.phi0).any():
        self.phi0 = self.Q_phi0.rsample()
    # estimate velocity
    self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] - self.t_eval[0])
    self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                       self.t_eval[1] - self.t_eval[0])
    # predict
    z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u], dim=1)
    zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
    self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
    self.qT = self.qT.view(T * self.bs, 3)
    # decode
    self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
    return None
def make_prior(code_size, distribution, alt_prior):
    """
    Returns the prior on embeddings as a tensorflow distribution:
    (i)  MultivariateNormalDiag, with an alternative Gaussian prior:
         (1) alt: N(0, 1/code_size)
         (2) default: N(0, 1)
    (ii) HypersphericalUniform
    """
    if distribution == 'normal':
        if alt_prior:  # alternative prior with variance 1/code_size
            loc = tf.zeros(code_size)
            scale = tf.sqrt(tf.divide(tf.ones(code_size), code_size))
        else:
            loc = tf.zeros(code_size)
            scale = tf.ones(code_size)
        dist = tfd.MultivariateNormalDiag(loc, scale)
    elif distribution == 'vmf':
        dist = HypersphericalUniform(code_size - 1, dtype=tf.float32)
    else:
        raise NotImplementedError
    return dist
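# A minimal usage sketch of make_prior (assumes the same tf/tfd/
# HypersphericalUniform imports as the rest of this file; code_size=8 is
# an arbitrary illustrative choice):
prior = make_prior(code_size=8, distribution='vmf', alt_prior=False)
z = prior.sample([32])       # 32 points on S^7, shape [32, 8]
log_p = prior.log_prob(z)    # constant for the uniform prior, shape [32]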
def forward(self, X, u):
    [_, self.bs, d, d] = X.shape
    T = len(self.t_eval)
    # encode
    self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(self.bs, d * d))
    self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(self.bs, d * d))
    # reparametrize
    self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
    self.P_q = HypersphericalUniform(1, device=self.device)
    self.q0 = self.Q_q.rsample()  # bs, 2
    while torch.isnan(self.q0).any():
        self.q0 = self.Q_q.rsample()  # a crude way to avoid NaN samples
    # estimate velocity
    self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                     self.t_eval[1] - self.t_eval[0])
    # predict
    z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
    zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
    self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
    self.qT = self.qT.view(T * self.bs, 2)
    # decode
    self.Xrec = self.obs_net(self.qT).view([T, self.bs, d, d])
    return None
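# A minimal sketch of driving the forward pass above (assumptions: `model` is
# an instance of this class with t_eval, device, etc. configured; the shapes
# are illustrative, and only frames 0 and 1 are actually encoded):
import torch
T, bs, d = 10, 32, 28
X = torch.rand(T, bs, d, d)             # pixel sequence
u = torch.zeros(bs, 1)                  # control input appended to the latent state
model(X, u)                             # fills model.Xrec, shape [T, bs, d, d]
loss = ((X - model.Xrec) ** 2).mean()   # e.g. a pixel reconstruction term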
def __init__(self, model, args):
    """
    OptimizerVAE initializer

    :param model: a model object
    :param args: arguments specifying the observation model and latent geometry
    """
    # binary cross-entropy reconstruction error
    assert args.observation == 'sigmoid', 'only sigmoid observations are implemented'
    self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                       logits=model.logits)
    self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

    if args.latent == 'euc':
        # KL divergence between normal approximate posterior and standard normal prior
        self.p_z = tf.distributions.Normal(tf.zeros_like(model.z),
                                           tf.ones_like(model.z))
        kl = model.q_z.kl_divergence(self.p_z)
        self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=-1))
    elif args.latent == 'sph':
        # KL divergence between vMF approximate posterior and uniform hyperspherical prior
        self.p_z = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
        kl = model.q_z.kl_divergence(self.p_z)
        self.kl = tf.reduce_mean(kl)
    else:
        raise NotImplementedError

    self.ELBO = -self.reconstruction_loss - self.kl
    self.train_step = optimize(-self.ELBO, None, args)
    self.print = {
        'loss/recon': self.reconstruction_loss,
        'loss/ELBO': self.ELBO,
        'loss/KL': self.kl
    }
def forward(self, X, u):
    [_, self.bs, d, d] = X.shape
    T = len(self.t_eval)
    # encode
    self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(self.bs, d * d))
    self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(self.bs, d * d))
    # reparametrize
    self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
    self.P_q = HypersphericalUniform(1, device=self.device)
    self.q0 = self.Q_q.rsample()  # bs, 2
    while torch.isnan(self.q0).any():
        self.q0 = self.Q_q.rsample()  # a crude way to avoid NaN samples
    # estimate velocity
    self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                     self.t_eval[1] - self.t_eval[0])
    # predict
    z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
    zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
    self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
    self.qT = self.qT.view(T * self.bs, 2)
    # decode
    ones = torch.ones_like(self.qT[:, 0:1])
    self.content = self.obs_net(ones)
    theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0,
                               bs=T * self.bs)  # cos, sin
    grid = F.affine_grid(theta, torch.Size((T * self.bs, 1, d, d)))
    self.Xrec = F.grid_sample(self.content.view(T * self.bs, 1, d, d), grid)
    self.Xrec = self.Xrec.view([T, self.bs, d, d])
    return None
class MMDWAE(object):

    def __init__(self, model, args):
        self.model = model
        self.args = args
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x - model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                               logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z),
                                                   tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            raise NotImplementedError

        def energ_emb(z):  # unused in the MMD variant
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2
        self.kl = matching_loss = mmd(model.z, pz_sample)
        self.wae_loss = self.reconstruction_loss + self.kl * (args.wae_lambda * 100)
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)
        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/kl': self.kl
        }
        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
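# A training-loop sketch for the WAE trainers above (assumptions: `model`,
# `args`, and a data iterator `batches` come from the surrounding code):
trainer = MMDWAE(model, args)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for batch in batches:
        fd = {model.x: batch}
        trainer.step(sess, fd)
        # periodically inspect the tracked losses:
        # print(sess.run(trainer.print, fd))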
def _vmf_log_likelihood(self, sample, location=None, kappa=None):
    """Log likelihood of a sample under the vMF distribution with the given location
    and kappa, or under the uniform hyperspherical prior when neither is given."""
    if location is None and kappa is None:
        return HypersphericalUniform(self.z_dim - 1, device=self.device).log_prob(sample)
    elif location is not None and kappa is not None:
        return VonMisesFisher(location, kappa).log_prob(sample)
    else:
        raise InvalidArgumentError("Provide either location and kappa or neither.")
def __init__(self, model, args):
    self.model = model
    self.args = args
    # binary cross-entropy reconstruction error
    assert args.observation == 'sigmoid', 'only sigmoid observations are implemented'
    self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                       logits=model.logits)
    self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

    def energ_emb(z):
        return model._cond_energy(z, model.x)

    assert args.mpf_method == 'ld', 'only Langevin-dynamics MPF is implemented'
    if args.latent == 'euc':
        y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
    elif args.latent == 'sph':
        y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
    else:
        raise NotImplementedError

    self.mpf_loss = -tf.reduce_mean(neg_mpf_loss)
    self.gp_loss = tf.reduce_mean(energ_emb(y)**2) + tf.reduce_mean(energ_emb(model.z)**2)
    self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
    self.score_opt_op = optimize(self.score_loss, [COND_ENERGY], args)

    if args.latent == 'euc':
        self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z),
                                               tf.ones_like(model.z))
        self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
    elif args.latent == 'sph':
        self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
        self.log_pz = self.dist_pz.log_prob(model.z)
    else:
        raise NotImplementedError

    # E_q log(q/p), with the conditional energy standing in for log q up to a constant
    self.kl = -tf.reduce_mean(self.log_pz) - tf.reduce_mean(model._cond_energy(model.z, model.x))
    self.ELBO = -self.reconstruction_loss - self.kl
    self.elbo_opt_op = optimize(-self.ELBO, [ENCODER, DECODER], args)
    self.print = {
        'loss/reconloss': self.reconstruction_loss,
        'loss/ELBO': self.ELBO,
        'loss/approx_KL': self.kl,
        'loss/mpf': self.mpf_loss,
        'loss/gp': self.gp_loss,
        'e/avg': tf.reduce_mean(model._cond_energy(model.z, model.x))
    }
    self.lc = locals()
def reparameterize(self, z_mean, z_var):
    if self.distribution == 'normal':
        q_z = torch.distributions.normal.Normal(z_mean, z_var)
        p_z = torch.distributions.normal.Normal(torch.zeros_like(z_mean),
                                                torch.ones_like(z_var))
    elif self.distribution == 'vmf':
        q_z = VonMisesFisher(z_mean, z_var)
        p_z = HypersphericalUniform(self.z_dim - 1)
    else:
        raise NotImplementedError
    return q_z, p_z
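# A minimal sketch of the vMF branch (assumes the s-vae-pytorch package that
# provides VonMisesFisher/HypersphericalUniform; `enc` is a hypothetical
# instance of the class above with distribution='vmf' and z_dim=3):
import torch
import torch.nn.functional as F
z_mean = F.normalize(torch.randn(8, 3), dim=-1)       # mean directions on S^2
z_var = F.softplus(torch.randn(8, 1)) + 1             # concentration kappa > 0
q_z, p_z = enc.reparameterize(z_mean, z_var)
z = q_z.rsample()                                     # differentiable samples, [8, 3]
kl = torch.distributions.kl.kl_divergence(q_z, p_z)   # per-example KL, [8]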
def __init__(self, model, learning_rate=1e-3):
    """
    OptimizerVAE initializer

    :param model: a model object
    :param learning_rate: float, learning rate of the optimizer
    """
    self.kl_weight = tf.placeholder_with_default(np.array(1.).astype(np.float64),
                                                 shape=())
    # binary cross-entropy reconstruction error
    self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                       logits=model.logits)
    self.score = tf.reduce_sum(self.bce, axis=-1)
    print('s1', self.score)
    print(model.distribution)

    if model.distribution == 'normal':
        # KL divergence between normal approximate posterior and standard normal prior
        self.p_z = tf.distributions.Normal(tf.zeros_like(model.z),
                                           tf.ones_like(model.z))
        kl = model.q_z.kl_divergence(self.p_z)
        self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=-1))
    elif model.distribution == 'vmf':
        # KL divergence between vMF approximate posterior and uniform hyperspherical prior
        self.p_z = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
        kl = model.q_z.kl_divergence(self.p_z)
        self.kl = tf.reduce_mean(kl)
        self.score = -model.q_z.add_g_cor(-self.score)
    else:
        raise NotImplementedError

    self.reconstruction_loss = tf.reduce_mean(self.score)
    self.ELBO = -self.reconstruction_loss - self.kl
    self.train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(
        self.reconstruction_loss + self.kl * self.kl_weight)
    self.print = {
        'recon loss': self.reconstruction_loss,
        'ELBO': self.ELBO,
        'KL': self.kl,
        'KL weight': self.kl_weight
    }
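# A KL-annealing sketch using the kl_weight placeholder above (assumptions:
# `model`, `sess`, `num_steps`, and the batch feed come from the surrounding
# training script):
opt = OptimizerVAE(model, learning_rate=1e-3)
for step in range(num_steps):
    w = min(1.0, step / 10000.0)   # linear warm-up of the KL term
    sess.run(opt.train_step, {model.x: batch, opt.kl_weight: w})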
def kl_distance(self):
    if self.vtype == 'gauss':
        self.prior = tf.distributions.Normal(tf.zeros(self.central_state_size),
                                             tf.ones(self.central_state_size))
        self.kl = self.central_distribution.kl_divergence(self.prior)
        loss_kl = tf.reduce_mean(tf.reduce_sum(self.kl, axis=1))
    elif self.vtype == 'vmf':
        self.prior = HypersphericalUniform(self.central_state_size - 1,
                                           dtype=tf.float32)
        self.kl = self.central_distribution.kl_divergence(self.prior)
        loss_kl = tf.reduce_mean(self.kl)
    else:
        raise NotImplementedError
    return loss_kl
def _vmf_sample_z(self, location, kappa, shape, det):
    """Reparameterized sample from a vMF distribution with location and concentration kappa."""
    if location is None and kappa is None and shape is not None:
        if det:
            raise InvalidArgumentError(
                "Cannot deterministically sample from the uniform on a hypersphere.")
        else:
            return HypersphericalUniform(self.z_dim - 1,
                                         device=self.device).sample(shape[:-1])
    elif location is not None and kappa is not None:
        if det:
            return location
        if self.training:
            return VonMisesFisher(location, kappa).rsample()
        else:
            return VonMisesFisher(location, kappa).sample()
    else:
        raise InvalidArgumentError("Either provide location and kappa, or neither with a shape.")
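# Both call patterns of the two helpers above, sketched (assumptions: `vae` is
# a hypothetical instance with z_dim=3; `loc` rows are unit-norm, `kappa` > 0):
z_prior = vae._vmf_sample_z(None, None, shape=(16, 3), det=False)  # uniform on S^2
z_post = vae._vmf_sample_z(loc, kappa, None, det=False)            # vMF (r)sample
lp = vae._vmf_log_likelihood(z_post, location=loc, kappa=kappa)    # vMF log-density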
def forward(self, X, u):
    [_, self.bs, c, self.d, self.d] = X.shape
    T = len(self.t_eval)
    self.link1_l = torch.sigmoid(self.link1_para)
    # encode
    self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(X[0])
    self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(X[1])
    # reparametrize
    self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
    self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
    self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
    self.phi1_t0 = self.Q_phi1.rsample()
    while torch.isnan(self.phi1_t0).any():
        self.phi1_t0 = self.Q_phi1.rsample()
    self.phi2_t0 = self.Q_phi2.rsample()
    while torch.isnan(self.phi2_t0).any():
        self.phi2_t0 = self.Q_phi2.rsample()
    # estimate velocity
    self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0, self.phi1_m_n_t1,
                                          self.t_eval[1] - self.t_eval[0])
    self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0, self.phi2_m_n_t1,
                                          self.t_eval[1] - self.t_eval[0])
    # predict
    z0_u = torch.cat([
        self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
        self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
    ], dim=1)
    zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
    self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
    self.qT = self.qT.view(T * self.bs, 4)
    # decode
    self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
    return None
def sampled_z(self, mu, sigma, batch_size):
    if self.distribution == 'normal':
        epsilon = tf.random_normal(tf.stack([int(batch_size), self.n_latent_units]))
        z = mu + tf.multiply(epsilon, tf.exp(0.5 * sigma))
        loss = tf.reduce_mean(-0.5 * self.beta *
                              tf.reduce_sum(1.0 + sigma - tf.square(mu) - tf.exp(sigma), 1))
    elif self.distribution == 'vmf':
        self.q_z = VonMisesFisher(mu, sigma, validate_args=True,
                                  allow_nan_stats=False)
        z = self.q_z.sample()
        self.p_z = HypersphericalUniform(self.n_latent_units, validate_args=True,
                                         allow_nan_stats=False)
        loss = tf.reduce_mean(-self.q_z.kl_divergence(self.p_z))
    else:
        raise NotImplementedError
    return z, loss
def forward(self, inputs, lengths, dist='normal', fix=True):
    inputs = pack(self.drop(inputs), lengths, batch_first=True)
    _, hn = self.rnn(inputs)
    h = torch.cat(hn, dim=2).squeeze(0)
    if dist == 'normal':
        p_z = Normal(
            torch.zeros((h.size(0), self.code_dim), device=h.device),
            (0.5 * torch.zeros((h.size(0), self.code_dim), device=h.device)).exp())
        mu, lv = self.fcmu(h), self.fclv(h)
        if self.bn:
            mu, lv = self.bnmu(mu), self.bnlv(lv)
        return hn, Normal(mu, (0.5 * lv).exp()), p_z
    elif dist == 'vmf':
        mu = self.fcmu(h)
        mu = mu / mu.norm(dim=-1, keepdim=True)
        var = F.softplus(self.fcvar(h)) + 1
        if fix:
            var = torch.ones_like(var) * 80
        return hn, VonMisesFisher(mu, var), HypersphericalUniform(
            self.code_dim - 1, device=mu.device)
    else:
        raise NotImplementedError
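# Sketch of consuming the returned triple for a sequence-VAE KL term (names
# such as `encoder`, `inputs`, `lengths` are assumptions from the surrounding
# model):
hn, q_z, p_z = encoder(inputs, lengths, dist='vmf', fix=False)
z = q_z.rsample()                                     # differentiable latent code
kl = torch.distributions.kl.kl_divergence(q_z, p_z)   # one value per sequence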
def reparameterize(self, z_mean, z_var):
    q_z = VonMisesFisher(z_mean, z_var)
    p_z = HypersphericalUniform(self.z_dim - 1)
    return q_z, p_z
class AIS(object):

    def __init__(self, x_ph, log_likelihood_fn, dims, num_samples=16,
                 method='hmc', config=None):
        """
        The model implements Hamiltonian AIS.
        Developed by @bilginhalil on top of https://github.com/jiamings/ais/

        Example use case:
        logp(x) = log integral over z of p(x|z, theta) p(z) dz
        p(x|z, theta) -> likelihood function
        p(z) -> prior
        The prior is assumed to be a normal distribution with mean 0 and
        identity covariance matrix (or the uniform on the hypersphere when
        method == 'riem_ld').

        :param x_ph: placeholder for x
        :param log_likelihood_fn: outputs logp(x|z, theta); it should take two
            parameters, x and z
        :param dims: e.g. {'output_dim': 28*28, 'input_dim': FLAGS.d, 'batch_size': 1}
        :param num_samples: number of samples used to estimate the likelihood
        :param method: transition operator, e.g. 'hmc' or a Langevin-dynamics variant
        :param config: HMC/LD settings (stepsize, n_steps, target_acceptance_rate,
            avg_acceptance_slowness, stepsize_min, stepsize_max, stepsize_dec,
            stepsize_inc); defaults to default_config[method]
        """
        self.dims = dims
        self.log_likelihood_fn = log_likelihood_fn
        self.num_samples = num_samples

        self.z_shape = [dims['batch_size'] * self.num_samples, dims['input_dim']]
        if method != 'riem_ld':
            self.prior = tfd.MultivariateNormalDiag(loc=tf.zeros(self.z_shape),
                                                    scale_diag=tf.ones(self.z_shape))
        else:
            self.prior = HypersphericalUniform(dims['input_dim'] - 1)

        self.batch_size = dims['batch_size']
        self.x = tf.tile(x_ph, [self.num_samples, 1])
        self.method = method
        self.config = config if config is not None else default_config[method]

    def log_f_i(self, z, t):
        return tf.reshape(-self.energy_fn(z, t),
                          [self.num_samples, self.batch_size])

    def energy_fn(self, z, t):
        e = self.prior.log_prob(z)
        assert e.shape.ndims == 1
        e += t * tf.reshape(self.log_likelihood_fn(self.x, z),
                            [self.num_samples * self.batch_size])
        assert e.shape.ndims == 1
        return -e

    def ais(self, schedule):
        """
        :param schedule: temperature schedule, i.e. the exponents t in `p(z)p(x|z)^t`
        :return: [batch_size]
        """
        cfg = self.config
        if isinstance(self.prior, tfd.MultivariateNormalDiag):
            z = self.prior.sample()
        else:
            z = self.prior.sample([self.num_samples * self.batch_size])
        assert z.shape.ndims == 2

        index_summation = (tf.constant(0),
                           tf.zeros([self.num_samples, self.batch_size]),
                           tf.cast(z, tf.float32),
                           cfg.stepsize,
                           cfg.target_acceptance_rate)
        items = tf.unstack(
            tf.convert_to_tensor([[i, t0, t1] for i, (t0, t1) in enumerate(
                zip(schedule[:-1], schedule[1:]))]))

        def condition(index, summation, z, stepsize, avg_acceptance_rate):
            return tf.less(index, len(schedule) - 1)

        def body(index, w, z, stepsize, avg_acceptance_rate):
            item = tf.gather(items, index)
            t0 = tf.gather(item, 1)
            t1 = tf.gather(item, 2)

            # accumulate the importance weight for this annealing step
            new_u = self.log_f_i(z, t1)
            prev_u = self.log_f_i(z, t0)
            w = tf.add(w, new_u - prev_u)

            def run_energy(z):
                e = self.energy_fn(z, t1)
                if self.method != 'hmc':
                    e = e[:, None]
                with tf.control_dependencies([e]):
                    return e

            # New step:
            if self.method == 'hmc':
                accept, final_pos, final_vel = hmc_move(z, run_energy, stepsize,
                                                        cfg.n_steps)
                new_z, new_stepsize, new_acceptance_rate = hmc_updates(
                    z,
                    stepsize,
                    avg_acceptance_rate=avg_acceptance_rate,
                    final_pos=final_pos,
                    accept=accept,
                    stepsize_min=cfg.stepsize_min,
                    stepsize_max=cfg.stepsize_max,
                    stepsize_dec=cfg.stepsize_dec,
                    stepsize_inc=cfg.stepsize_inc,
                    target_acceptance_rate=cfg.target_acceptance_rate,
                    avg_acceptance_slowness=cfg.avg_acceptance_slowness)
            elif self.method.endswith('ld'):
                new_z, cur_acc_rate = ld_move(z, run_energy, stepsize,
                                              cfg.n_steps, self.method)
                new_stepsize, new_acceptance_rate = ld_update(
                    stepsize,
                    cur_acc_rate=cur_acc_rate,
                    hist_acc_rate=avg_acceptance_rate,
                    target_acc_rate=cfg.target_acceptance_rate,
                    ssz_inc=cfg.stepsize_inc,
                    ssz_dec=cfg.stepsize_dec,
                    ssz_min=cfg.stepsize_min,
                    ssz_max=cfg.stepsize_max,
                    avg_acc_decay=cfg.avg_acceptance_slowness)

            return tf.add(index, 1), w, new_z, new_stepsize, new_acceptance_rate

        i, w, _, final_stepsize, final_acc_rate = tf.while_loop(
            condition, body, index_summation, parallel_iterations=1,
            swap_memory=True)
        # w = tf.Print(w, [final_stepsize, final_acc_rate], 'ff')
        return tf.squeeze(log_mean_exp(w, axis=0), axis=0)
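# A driving sketch for AIS (assumptions: `x_ph`, `log_lik`, `x_batch`, and the
# session setup come from the surrounding script; a linear schedule is one
# common choice):
schedule = np.linspace(0.0, 1.0, 1000)
ais = AIS(x_ph, log_lik,
          dims={'output_dim': 28 * 28, 'input_dim': 8, 'batch_size': 1},
          num_samples=16, method='hmc')
log_px = ais.ais(schedule)   # [batch_size] estimates of log p(x)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(log_px, {x_ph: x_batch}))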
def _vmf_kl_divergence(self, location, kappa):
    """Estimated KL between a vMF distribution and the uniform hyperspherical prior."""
    return kl_divergence(
        VonMisesFisher(location, kappa),
        HypersphericalUniform(self.z_dim - 1, device=self.device))
def forward(self, X, u):
    [_, self.bs, c, self.d, self.d] = X.shape
    T = len(self.t_eval)
    self.link1_l = torch.sigmoid(self.link1_para)
    # encode
    self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(X[0])
    self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(X[1])
    # reparametrize
    self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
    self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
    self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
    self.phi1_t0 = self.Q_phi1.rsample()
    while torch.isnan(self.phi1_t0).any():
        self.phi1_t0 = self.Q_phi1.rsample()
    self.phi2_t0 = self.Q_phi2.rsample()
    while torch.isnan(self.phi2_t0).any():
        self.phi2_t0 = self.Q_phi2.rsample()
    # estimate velocity
    self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0, self.phi1_m_n_t1,
                                          self.t_eval[1] - self.t_eval[0])
    self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0, self.phi2_m_n_t1,
                                          self.t_eval[1] - self.t_eval[0])
    # predict
    z0_u = torch.cat([
        self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
        self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
    ], dim=1)
    zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
    self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
    self.qT = self.qT.view(T * self.bs, 4)
    # decode
    ones = torch.ones_like(self.qT[:, 0:1])
    self.link1 = self.obs_net_1(ones)
    self.link2 = self.obs_net_2(ones)
    theta1 = self.get_theta_inv(self.qT[:, 0], self.qT[:, 2], 0, 0,
                                bs=T * self.bs)  # cos phi1, sin phi1
    x = self.link1_l * self.qT[:, 2]  # l * sin phi1
    y = self.link1_l * self.qT[:, 0]  # l * cos phi1
    theta2 = self.get_theta_inv(self.qT[:, 1], self.qT[:, 3], x, y,
                                bs=T * self.bs)  # cos phi2, sin phi2
    grid1 = F.affine_grid(theta1, torch.Size((T * self.bs, 1, self.d, self.d)))
    grid2 = F.affine_grid(theta2, torch.Size((T * self.bs, 1, self.d, self.d)))
    transf_link1 = F.grid_sample(self.link1.view(T * self.bs, 1, self.d, self.d), grid1)
    transf_link2 = F.grid_sample(self.link2.view(T * self.bs, 1, self.d, self.d), grid2)
    self.Xrec = torch.cat(
        [transf_link1, transf_link2, torch.zeros_like(transf_link1)], dim=1)
    self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
    return None
def forward(self, X, u):
    [_, self.bs, c, self.d, self.d] = X.shape
    T = len(self.t_eval)
    # encode
    self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(X[0])
    self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(X[1])
    # reparametrize
    self.Q_r0 = Normal(self.r0_m, self.r0_v)
    self.P_normal = Normal(torch.zeros_like(self.r0_m), torch.ones_like(self.r0_v))
    self.r0 = self.Q_r0.rsample()
    self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
    self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
    self.phi0 = self.Q_phi0.rsample()
    while torch.isnan(self.phi0).any():
        self.phi0 = self.Q_phi0.rsample()
    # estimate velocity
    self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] - self.t_eval[0])
    self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                       self.t_eval[1] - self.t_eval[0])
    # predict
    z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u], dim=1)
    zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
    self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
    self.qT = self.qT.view(T * self.bs, 3)
    # decode
    ones = torch.ones_like(self.qT[:, 0:1])
    self.cart = self.obs_net_1(ones)
    self.pole = self.obs_net_2(ones)
    theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
    theta2 = self.get_theta_inv(self.qT[:, 1], self.qT[:, 2], self.qT[:, 0], 0,
                                bs=T * self.bs)
    grid1 = F.affine_grid(theta1, torch.Size((T * self.bs, 1, self.d, self.d)))
    grid2 = F.affine_grid(theta2, torch.Size((T * self.bs, 1, self.d, self.d)))
    transf_cart = F.grid_sample(self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
    transf_pole = F.grid_sample(self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
    self.Xrec = torch.cat(
        [transf_cart, transf_pole, torch.zeros_like(transf_cart)], dim=1)
    self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
    return None
def reparameterize(self, z_mean, z_kappa):
    q_z = VonMisesFisher(z_mean, z_kappa)
    p_z = HypersphericalUniform(z_mean.size(1) - 1, device=DEVICE)
    return q_z, p_z