class GanWAE(object):
    """WAE whose latent matching term is a GAN discriminator between q(z) and p(z)."""

    def __init__(self, model, args):
        self.model = model
        self.args = args
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x - model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError(args.observation)

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(
                tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            raise NotImplementedError(args.latent)

        def energ_emb(z):
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2
        # discriminator logits on prior samples and on encoder samples
        pz_logits = model._marginal_energy(pz_sample)
        qz_logits = model._marginal_energy(model.z)
        self.gp_loss = tf.reduce_mean(energ_emb(model.z)**2) * 2
        self.score_loss = -(
            tf.reduce_mean(tf.log(tf.nn.sigmoid(pz_logits) + 1e-7)) +
            tf.reduce_mean(tf.log(1 - tf.nn.sigmoid(qz_logits) + 1e-7)))
        self.score_opt_op = optimize(
            self.score_loss + args.grad_penalty * self.gp_loss,
            [MARGINAL_ENERGY], args)

        # non-saturating GAN loss for the encoder
        self.kl = -tf.reduce_mean(tf.math.log_sigmoid(qz_logits))
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/disc': self.score_loss,
            'loss/gp': self.gp_loss,
            'loss/kl': self.kl,
        }
        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
        for _ in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
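# A minimal usage sketch for the trainer classes in this file (hedged:
# `build_model`, `load_batches`, and `make_feed_dict` are illustrative
# stand-ins, not names from this repo; the model object must expose the
# attributes the trainers read, e.g. x, z, z_dim, logits or qxz_mean, and
# _marginal_energy / _cond_energy):
#
#   model = build_model(args)
#   trainer = GanWAE(model, args)   # or MMDWAE / ImplicitWAE / ImplicitVAE
#   with tf.Session() as sess:
#       sess.run(tf.global_variables_initializer())
#       for batch in load_batches():
#           trainer.step(sess, make_feed_dict(model, batch))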
class ImplicitVAE(object):
    """VAE with an implicit encoder; the conditional energy modelling log q(z|x)
    is trained with minimum probability flow (MPF)."""

    def __init__(self, model, args):
        self.model = model
        self.args = args
        # binary cross-entropy reconstruction error
        if args.observation != 'sigmoid':
            raise NotImplementedError(args.observation)
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(
            labels=model.x, logits=model.logits)
        self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

        def energ_emb(z):
            return model._cond_energy(z, model.x)

        if args.mpf_method != 'ld':
            raise NotImplementedError(args.mpf_method)
        if args.latent == 'euc':
            y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
        elif args.latent == 'sph':
            y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
        else:
            raise NotImplementedError(args.latent)
        self.mpf_loss = -tf.reduce_mean(neg_mpf_loss)
        self.gp_loss = (tf.reduce_mean(energ_emb(y)**2) +
                        tf.reduce_mean(energ_emb(model.z)**2))
        self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
        self.score_opt_op = optimize(self.score_loss, [COND_ENERGY], args)

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(
                tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
        else:
            raise NotImplementedError(args.latent)

        # E_q[log(q/p)]
        self.kl = (-tf.reduce_mean(self.log_pz) -
                   tf.reduce_mean(model._cond_energy(model.z, model.x)))
        self.ELBO = -self.reconstruction_loss - self.kl
        self.elbo_opt_op = optimize(-self.ELBO, [ENCODER, DECODER], args)

        self.print = {
            'loss/reconloss': self.reconstruction_loss,
            'loss/ELBO': self.ELBO,
            'loss/approx_KL': self.kl,
            'loss/mpf': self.mpf_loss,
            'loss/gp': self.gp_loss,
            'e/avg': tf.reduce_mean(model._cond_energy(model.z, model.x)),
        }
        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.elbo_opt_op, fd)
        for _ in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
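# Reading of ImplicitVAE's approximate KL (inferred from the code, since an
# implicit encoder has no closed-form log q(z|x)): the learned conditional
# energy serves as a surrogate for the posterior log-density,
#     log q(z|x) ~ -_cond_energy(z, x)    (up to a normalizing constant),
# so that
#     KL(q || p) = E_q[log q(z|x) - log p(z)]
#                ~ -E_q[_cond_energy(z, x)] - E_q[log p(z)],
# which is exactly the expression assigned to self.kl above.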
class MMDWAE(object):
    """WAE with a maximum mean discrepancy (MMD) penalty matching q(z) to p(z)."""

    def __init__(self, model, args):
        self.model = model
        self.args = args
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x - model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError(args.observation)

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(
                tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            raise NotImplementedError(args.latent)

        def energ_emb(z):
            # not used by the MMD objective
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2
        # MMD between encoder samples and prior samples, weighted by
        # wae_lambda scaled up by a factor of 100
        self.kl = matching_loss = mmd(model.z, pz_sample)
        self.wae_loss = self.reconstruction_loss + self.kl * (args.wae_lambda * 100)
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/kl': self.kl,
        }
        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
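# `mmd` used by MMDWAE is imported from elsewhere in the repo. For reference,
# a minimal biased (V-statistic) MMD estimator with a fixed-bandwidth RBF
# kernel could look like the hypothetical `mmd_rbf` below; the kernel family
# and bandwidth are illustrative assumptions, not necessarily what `mmd` does.
def mmd_rbf(qz_samples, pz_samples, sigma=1.0):
    def gram(a, b):
        # pairwise squared Euclidean distances between rows of a and b
        sq_dists = (tf.reduce_sum(a**2, axis=1)[:, None]
                    - 2.0 * tf.matmul(a, b, transpose_b=True)
                    + tf.reduce_sum(b**2, axis=1)[None, :])
        return tf.exp(-sq_dists / (2.0 * sigma**2))

    # MMD^2(q, p) = E[k(z, z')] + E[k(z~, z~')] - 2 E[k(z, z~)]
    return (tf.reduce_mean(gram(qz_samples, qz_samples))
            + tf.reduce_mean(gram(pz_samples, pz_samples))
            - 2.0 * tf.reduce_mean(gram(qz_samples, pz_samples)))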
class AIS(object):
    def __init__(self, x_ph, log_likelihood_fn, dims,
                 num_samples=16, method='hmc', config=None):
        """Annealed importance sampling (AIS) with HMC or Langevin transitions.

        Developed by @bilginhalil on top of https://github.com/jiamings/ais/

        Estimates log p(x) = log \\int p(x|z, theta) p(z) dz, where
        p(x|z, theta) is the likelihood and p(z) the prior. The prior is a
        normal distribution with mean 0 and identity covariance; for
        method='riem_ld' it is the hyperspherical uniform instead.

        :param x_ph: placeholder for x
        :param log_likelihood_fn: computes log p(x|z, theta); takes two
            parameters, x and z
        :param dims: e.g. {'output_dim': 28*28, 'input_dim': FLAGS.d,
            'batch_size': 1}
        :param num_samples: number of AIS chains per datum used to estimate
            the likelihood
        :param method: transition kernel, 'hmc' or a Langevin variant
            ending in 'ld'
        :param config: kernel settings (stepsize, n_steps,
            target_acceptance_rate, avg_acceptance_slowness, stepsize_min,
            stepsize_max, stepsize_dec, stepsize_inc); defaults to
            default_config[method]
        """
        self.dims = dims
        self.log_likelihood_fn = log_likelihood_fn
        self.num_samples = num_samples
        self.z_shape = [dims['batch_size'] * self.num_samples, dims['input_dim']]
        if method != 'riem_ld':
            self.prior = tfd.MultivariateNormalDiag(
                loc=tf.zeros(self.z_shape), scale_diag=tf.ones(self.z_shape))
        else:
            self.prior = HypersphericalUniform(dims['input_dim'] - 1)
        self.batch_size = dims['batch_size']
        self.x = tf.tile(x_ph, [self.num_samples, 1])
        self.method = method
        self.config = config if config is not None else default_config[method]

    def log_f_i(self, z, t):
        return tf.reshape(-self.energy_fn(z, t),
                          [self.num_samples, self.batch_size])

    def energy_fn(self, z, t):
        e = self.prior.log_prob(z)
        assert e.shape.ndims == 1
        e += t * tf.reshape(self.log_likelihood_fn(self.x, z),
                            [self.num_samples * self.batch_size])
        assert e.shape.ndims == 1
        return -e

    def ais(self, schedule):
        """
        :param schedule: temperature schedule, i.e. the exponents t of the
            intermediate densities `p(z)p(x|z)^t`
        :return: [batch_size] estimates of log p(x)
        """
        cfg = self.config
        if isinstance(self.prior, tfd.MultivariateNormalDiag):
            z = self.prior.sample()
        else:
            z = self.prior.sample([self.num_samples * self.batch_size])
        assert z.shape.ndims == 2

        # loop state: (index, log-weights, z, stepsize, avg acceptance rate)
        loop_vars = (tf.constant(0),
                     tf.zeros([self.num_samples, self.batch_size]),
                     tf.cast(z, tf.float32),
                     cfg.stepsize,
                     cfg.target_acceptance_rate)
        items = tf.unstack(
            tf.convert_to_tensor([
                [i, t0, t1]
                for i, (t0, t1) in enumerate(zip(schedule[:-1], schedule[1:]))
            ]))

        def condition(index, summation, z, stepsize, avg_acceptance_rate):
            return tf.less(index, len(schedule) - 1)

        def body(index, w, z, stepsize, avg_acceptance_rate):
            item = tf.gather(items, index)
            t0 = tf.gather(item, 1)
            t1 = tf.gather(item, 2)
            new_u = self.log_f_i(z, t1)
            prev_u = self.log_f_i(z, t0)
            # accumulate the incremental importance weight
            # log f_{i+1}(z) - log f_i(z)
            w = tf.add(w, new_u - prev_u)

            def run_energy(z):
                e = self.energy_fn(z, t1)
                if self.method != 'hmc':
                    e = e[:, None]
                with tf.control_dependencies([e]):
                    return e

            # transition step targeting the new intermediate density
            if self.method == 'hmc':
                accept, final_pos, _ = hmc_move(z, run_energy, stepsize, cfg.n_steps)
                new_z, new_stepsize, new_acceptance_rate = hmc_updates(
                    z, stepsize,
                    avg_acceptance_rate=avg_acceptance_rate,
                    final_pos=final_pos,
                    accept=accept,
                    stepsize_min=cfg.stepsize_min,
                    stepsize_max=cfg.stepsize_max,
                    stepsize_dec=cfg.stepsize_dec,
                    stepsize_inc=cfg.stepsize_inc,
                    target_acceptance_rate=cfg.target_acceptance_rate,
                    avg_acceptance_slowness=cfg.avg_acceptance_slowness)
            elif self.method.endswith('ld'):
                new_z, cur_acc_rate = ld_move(z, run_energy, stepsize,
                                              cfg.n_steps, self.method)
                new_stepsize, new_acceptance_rate = ld_update(
                    stepsize,
                    cur_acc_rate=cur_acc_rate,
                    hist_acc_rate=avg_acceptance_rate,
                    target_acc_rate=cfg.target_acceptance_rate,
                    ssz_inc=cfg.stepsize_inc,
                    ssz_dec=cfg.stepsize_dec,
                    ssz_min=cfg.stepsize_min,
                    ssz_max=cfg.stepsize_max,
                    avg_acc_decay=cfg.avg_acceptance_slowness)
            else:
                raise NotImplementedError(self.method)

            return tf.add(index, 1), w, new_z, new_stepsize, new_acceptance_rate

        i, w, _, final_stepsize, final_acc_rate = tf.while_loop(
            condition, body, loop_vars, parallel_iterations=1, swap_memory=True)
        # average the importance weights over the num_samples chains
        return tf.squeeze(log_mean_exp(w, axis=0), axis=0)
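# A minimal usage sketch for AIS (hedged: `decoder_log_likelihood` and the
# shapes below are illustrative stand-ins). A uniformly spaced schedule from
# 0 to 1 is the simplest valid temperature schedule:
#
#   import numpy as np
#   x_ph = tf.placeholder(tf.float32, [1, 784])
#   ais = AIS(x_ph, decoder_log_likelihood,
#             dims={'output_dim': 784, 'input_dim': 8, 'batch_size': 1},
#             num_samples=16, method='hmc')
#   log_px = ais.ais(schedule=np.linspace(0.0, 1.0, 1000))  # [batch_size]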
class ImplicitWAE(object):
    """WAE whose marginal energy is trained with minimum probability flow (MPF)."""

    def __init__(self, model, args):
        self.model = model
        self.args = args
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x - model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(
                labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError(args.observation)

        def energ_emb(z):
            return model._marginal_energy(z)

        if args.latent == 'euc':
            if args.mpf_method == 'ld':
                y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
            else:
                y, neg_mpf_loss = mpf_euc_spos(
                    model.z, energ_emb, args.mpf_lr, alpha=args.mpf_spos_alpha)
        elif args.latent == 'sph' and args.mpf_method == 'ld':
            y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
        else:
            raise NotImplementedError((args.latent, args.mpf_method))

        self.mpf_loss = tf.reduce_mean(-neg_mpf_loss) * 1e-3 / args.mpf_lr
        self.gp_loss = (tf.reduce_mean(energ_emb(y)**2) +
                        tf.reduce_mean(energ_emb(model.z)**2))
        self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
        self.score_opt_op = optimize(self.score_loss, [MARGINAL_ENERGY], args)

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(
                tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
        else:
            raise NotImplementedError(args.latent)

        # KL = E_q[log q - log p] = E_q[-log p - energy_q]
        self.kl = (-tf.reduce_mean(self.log_pz) -
                   tf.reduce_mean(model._marginal_energy(model.z)))
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/mpf': self.mpf_loss,
            'loss/gp': self.gp_loss,
        }
        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
        for _ in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
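# Note on ImplicitWAE's score loss (an inference about the helpers, which are
# defined elsewhere in the repo): `mpf_euc` / `mpf_sph` appear to implement a
# continuous-state minimum probability flow objective, which in its standard
# form (Sohl-Dickstein et al., 2011) penalizes, for data samples z and
# one-step proposals y,
#     K = mean_z exp((energy(z) - energy(y)) / 2),
# while `gp_loss` is a zero-centered penalty on the energy values at both z
# and y that keeps the energy scale bounded.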