Exemplo n.º 1
0
class GanWAE(object):
    """Wasserstein auto-encoder whose latent-matching term is estimated with a
    GAN-style discriminator (density-ratio trick) on the marginal energy net.

    Expects ``model`` to expose ``x``, ``z``, ``z_dim``, ``logits``,
    ``qxz_mean`` and ``_marginal_energy(z)``; ``args`` carries the training
    hyper-parameters (observation, latent, grad_penalty, wae_lambda, ...).
    """

    def __init__(self, model, args):
        self.model = model
        self.args = args

        # Reconstruction term: sum over all non-batch axes, mean over batch.
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x-model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            # BUG FIX: was `raise NotImplemented` — NotImplemented is a
            # singleton, not an exception; raising it is a TypeError in Py3.
            raise NotImplementedError(args.observation)

        # Prior over z: standard normal (euclidean) or uniform on the sphere.
        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            # BUG FIX: NotImplementedError (see above).
            raise NotImplementedError(args.latent)

        def energ_emb(z):
            # Discriminator logit for a latent code.
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2
        pz_logits = model._marginal_energy(pz_sample)
        qz_logits = model._marginal_energy(model.z)
        # Magnitude penalty on the discriminator output — presumably used in
        # place of a gradient penalty to keep logits bounded (TODO confirm).
        self.gp_loss = tf.reduce_mean(energ_emb(model.z)**2) * 2
        # Standard GAN discriminator loss: prior samples are "real" (label 1),
        # encoded samples are "fake" (label 0); 1e-7 guards log(0).
        self.score_loss = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(pz_logits) + 1e-7)) +
                            tf.reduce_mean(tf.log(1 - tf.nn.sigmoid(qz_logits) + 1e-7)))
        self.score_opt_op = optimize(
            self.score_loss + args.grad_penalty * self.gp_loss, [MARGINAL_ENERGY], args)

        self.kl = -tf.reduce_mean(tf.math.log_sigmoid(qz_logits)) # non-saturating GAN loss
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        # Scalars reported during training.
        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/disc': self.score_loss,
            'loss/gp': self.gp_loss,
            'loss/kl': self.kl
        }

        # Snapshot of constructor locals, kept for debugging/inspection.
        self.lc = locals()

    def step(self, sess, fd):
        """One training step: encoder/decoder update, then `train_score_dupl`
        discriminator updates."""
        sess.run(self.wae_opt_op, fd)
        for j in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
Exemplo n.º 2
0
class ImplicitVAE(object):
    """VAE with an implicit encoder; the (intractable) encoder entropy term is
    estimated through a conditional energy network trained with minimum
    probability flow (MPF).

    Expects ``model`` to expose ``x``, ``z``, ``z_dim``, ``logits`` and
    ``_cond_energy(z, x)``.
    """

    def __init__(self, model, args):
        self.model = model
        self.args = args

        # Binary cross-entropy reconstruction; only sigmoid observations are
        # supported. BUG FIX: was `assert ..., NotImplemented` — asserts are
        # stripped under `python -O` and NotImplemented is not an exception.
        if args.observation != 'sigmoid':
            raise NotImplementedError(args.observation)
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
        self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

        def energ_emb(z):
            # Conditional energy of z given the current minibatch x.
            return model._cond_energy(z, model.x)

        # BUG FIX: same assert misuse as above.
        if args.mpf_method != 'ld':
            raise NotImplementedError(args.mpf_method)
        if args.latent == 'euc':
            y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
        elif args.latent == 'sph':
            y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
        else:
            # BUG FIX: previously fell through with `y` undefined (NameError).
            raise NotImplementedError(args.latent)
        self.mpf_loss = -tf.reduce_mean(neg_mpf_loss)
        # Penalize energy magnitude at both the data codes and MPF proposals.
        self.gp_loss = tf.reduce_mean(energ_emb(y)**2) + tf.reduce_mean(energ_emb(model.z)**2)
        self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
        self.score_opt_op = optimize(self.score_loss, [COND_ENERGY], args)

        # Prior over z: standard normal (euclidean) or uniform on the sphere.
        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
        else:
            # BUG FIX: was `raise NotImplemented` (TypeError in Py3).
            raise NotImplementedError(args.latent)

        # Eq log(q/p): -E[log p(z)] minus the conditional energy, which stands
        # in for -log q(z|x).
        self.kl = -tf.reduce_mean(self.log_pz) - tf.reduce_mean(model._cond_energy(model.z, model.x))
        self.ELBO = -self.reconstruction_loss - self.kl

        self.elbo_opt_op = optimize(-self.ELBO, [ENCODER, DECODER], args)
        # Scalars reported during training.
        self.print = {
            'loss/reconloss': self.reconstruction_loss,
            'loss/ELBO': self.ELBO,
            'loss/approx_KL': self.kl,
            'loss/mpf': self.mpf_loss,
            'loss/gp': self.gp_loss,
            'e/avg': tf.reduce_mean(model._cond_energy(model.z, model.x))
        }
        # Snapshot of constructor locals, kept for debugging/inspection.
        self.lc = locals()

    def step(self, sess, fd):
        """One training step: ELBO update, then `train_score_dupl` energy-net
        updates."""
        sess.run(self.elbo_opt_op, fd)
        for j in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
Exemplo n.º 3
0
class MMDWAE(object):
    """Wasserstein auto-encoder that matches the aggregate posterior to the
    prior with a maximum-mean-discrepancy (MMD) penalty — no discriminator,
    so there is a single optimization op and a single-phase `step`.
    """

    def __init__(self, model, args):
        self.model = model
        self.args = args

        # Reconstruction term: sum over all non-batch axes, mean over batch.
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x-model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            # BUG FIX: was `raise NotImplemented` — NotImplemented is a
            # singleton, not an exception; raising it is a TypeError in Py3.
            raise NotImplementedError(args.observation)

        # Prior over z: standard normal (euclidean) or uniform on the sphere.
        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            # BUG FIX: NotImplementedError (see above).
            raise NotImplementedError(args.latent)

        def energ_emb(z):
            # NOTE(review): unused by the MMD variant; kept only because it is
            # exposed through `self.lc = locals()` below.
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2

        # MMD between encoded codes and prior samples stands in for the KL.
        self.kl = matching_loss = mmd(model.z, pz_sample)
        # The *100 rescaling compensates for MMD's smaller magnitude —
        # presumably tuned empirically (TODO confirm).
        self.wae_loss = self.reconstruction_loss + self.kl * (args.wae_lambda*100)
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        # Scalars reported during training.
        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/kl': self.kl
        }

        # Snapshot of constructor locals, kept for debugging/inspection.
        self.lc = locals()

    def step(self, sess, fd):
        """One optimization step over encoder and decoder."""
        sess.run(self.wae_opt_op, fd)
Exemplo n.º 4
0
class AIS(object):
    def __init__(self,
                 x_ph,
                 log_likelihood_fn,
                 dims,
                 num_samples=16,
                 method='hmc',
                 config=None):
        """
        Annealed Importance Sampling estimator of log p(x).
        Developed by @bilginhalil on top of https://github.com/jiamings/ais/

        Anneals from the prior p(z) to p(z)p(x|z) along the temperature
        schedule given to `ais()`, using HMC or Langevin-dynamics transitions:

        logp(x) = |integrate over z|{logp(x|z,theta) + logp(z)}

        The prior is a standard normal for 'hmc'/'ld' and a hyperspherical
        uniform for 'riem_ld'.

        :param x_ph: Placeholder for x.
        :param log_likelihood_fn: Outputs logp(x|z, theta); it should take
            two parameters: x and z.
        :param dims: dict of sizes,
            e.g. {'output_dim': 28*28, 'input_dim': FLAGS.d, 'batch_size': 1}
        :param num_samples: Number of importance samples used to estimate the
            likelihood.
        :param method: transition kernel — 'hmc', 'ld' or 'riem_ld'.
        :param config: kernel hyper-parameters (stepsize, n_steps,
            target_acceptance_rate, avg_acceptance_slowness, stepsize
            min/max/dec/inc); defaults to `default_config[method]` when None.
        """

        self.dims = dims
        self.log_likelihood_fn = log_likelihood_fn
        self.num_samples = num_samples

        # Latents for all importance samples are stacked along the batch axis:
        # [batch_size * num_samples, input_dim].
        self.z_shape = [
            dims['batch_size'] * self.num_samples, dims['input_dim']
        ]

        if method != 'riem_ld':
            self.prior = tfd.MultivariateNormalDiag(loc=tf.zeros(self.z_shape),
                                                    scale_diag=tf.ones(
                                                        self.z_shape))
        else:
            # Riemannian LD operates on the unit hypersphere.
            self.prior = HypersphericalUniform(dims['input_dim'] - 1)

        self.batch_size = dims['batch_size']
        # Replicate each x once per importance sample, matching z_shape.
        self.x = tf.tile(x_ph, [self.num_samples, 1])

        self.method = method
        self.config = config if config is not None else default_config[method]

    def log_f_i(self, z, t):
        # Log of the annealed (unnormalized) density at temperature t,
        # reshaped to [num_samples, batch_size].
        return tf.reshape(-self.energy_fn(z, t),
                          [self.num_samples, self.batch_size])

    def energy_fn(self, z, t):
        # E_t(z) = -(log p(z) + t * log p(x|z)).
        e = self.prior.log_prob(z)
        assert e.shape.ndims == 1
        e += t * tf.reshape(self.log_likelihood_fn(self.x, z),
                            [self.num_samples * self.batch_size])
        assert e.shape.ndims == 1
        return -e

    def ais(self, schedule):
        """
            :param schedule: temperature schedule i.e. `p(z)p(x|z)^t`
            :return: [batch_size] — log-mean-exp of the accumulated
                importance log-weights over the sample axis
        """
        cfg = self.config
        if isinstance(self.prior, tfd.MultivariateNormalDiag):
            z = self.prior.sample()
        else:
            # The hyperspherical prior needs an explicit sample count.
            z = self.prior.sample([self.num_samples * self.batch_size])
        assert z.shape.ndims == 2

        # while_loop state: (step index, log-weights w, current z,
        # stepsize, running average acceptance rate).
        index_summation = (tf.constant(0),
                           tf.zeros([self.num_samples,
                                     self.batch_size]), tf.cast(z, tf.float32),
                           cfg.stepsize, cfg.target_acceptance_rate)

        # Pre-build (i, t_i, t_{i+1}) triples so the loop body can gather
        # the temperature pair for each step inside the graph.
        items = tf.unstack(
            tf.convert_to_tensor([[
                i, t0, t1
            ] for i, (t0, t1) in enumerate(zip(schedule[:-1], schedule[1:]))]))

        def condition(index, summation, z, stepsize, avg_acceptance_rate):
            # One iteration per consecutive temperature pair.
            return tf.less(index, len(schedule) - 1)

        def body(index, w, z, stepsize, avg_acceptance_rate):
            item = tf.gather(items, index)
            t0 = tf.gather(item, 1)
            t1 = tf.gather(item, 2)

            # AIS log-weight increment: log f_{t1}(z) - log f_{t0}(z).
            new_u = self.log_f_i(z, t1)
            prev_u = self.log_f_i(z, t0)
            w = tf.add(w, new_u - prev_u)

            def run_energy(z):
                # Energy at the *new* temperature t1; non-HMC kernels get a
                # trailing singleton axis added.
                e = self.energy_fn(z, t1)
                if self.method != 'hmc':
                    e = e[:, None]
                with tf.control_dependencies([e]):
                    return e

            # New step: transition kernel at temperature t1 with stepsize
            # adaptation toward the target acceptance rate.
            if self.method == 'hmc':
                accept, final_pos, final_vel = hmc_move(
                    z, run_energy, stepsize, cfg.n_steps)
                new_z, new_stepsize, new_acceptance_rate = hmc_updates(
                    z,
                    stepsize,
                    avg_acceptance_rate=avg_acceptance_rate,
                    final_pos=final_pos,
                    accept=accept,
                    stepsize_min=cfg.stepsize_min,
                    stepsize_max=cfg.stepsize_max,
                    stepsize_dec=cfg.stepsize_dec,
                    stepsize_inc=cfg.stepsize_inc,
                    target_acceptance_rate=cfg.target_acceptance_rate,
                    avg_acceptance_slowness=cfg.avg_acceptance_slowness)

            elif self.method.endswith('ld'):
                new_z, cur_acc_rate = ld_move(z, run_energy, stepsize,
                                              cfg.n_steps, self.method)
                new_stepsize, new_acceptance_rate = ld_update(
                    stepsize,
                    cur_acc_rate=cur_acc_rate,
                    hist_acc_rate=avg_acceptance_rate,
                    target_acc_rate=cfg.target_acceptance_rate,
                    ssz_inc=cfg.stepsize_inc,
                    ssz_dec=cfg.stepsize_dec,
                    ssz_min=cfg.stepsize_min,
                    ssz_max=cfg.stepsize_max,
                    avg_acc_decay=cfg.avg_acceptance_slowness)

            return tf.add(index,
                          1), w, new_z, new_stepsize, new_acceptance_rate

        i, w, _, final_stepsize, final_acc_rate = tf.while_loop(
            condition,
            body,
            index_summation,
            parallel_iterations=1,
            swap_memory=True)
        # w = tf.Print(w, [final_stepsize, final_acc_rate], 'ff')
        return tf.squeeze(log_mean_exp(w, axis=0), axis=0)
Exemplo n.º 5
0
class ImplicitWAE(object):
    """Wasserstein auto-encoder whose latent-matching term uses a marginal
    energy network trained with minimum probability flow (MPF).

    Expects ``model`` to expose ``x``, ``z``, ``z_dim``, ``logits``,
    ``qxz_mean`` and ``_marginal_energy(z)``.
    """

    def __init__(self, model, args):
        self.model = model
        self.args = args

        # Reconstruction term: sum over all non-batch axes, mean over batch.
        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x-model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            # BUG FIX: was `raise NotImplemented` — NotImplemented is a
            # singleton, not an exception; raising it is a TypeError in Py3.
            raise NotImplementedError(args.observation)

        def energ_emb(z):
            # Marginal energy of a latent code.
            return model._marginal_energy(z)

        # MPF proposal y and the (negated) MPF objective for the chosen
        # latent geometry / sampler combination.
        if args.latent == 'euc':
            if args.mpf_method == 'ld':
                y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
            else:
                y, neg_mpf_loss = mpf_euc_spos(
                    model.z, energ_emb, args.mpf_lr, alpha=args.mpf_spos_alpha)
        elif args.latent == 'sph' and args.mpf_method == 'ld':
            y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
        else:
            # BUG FIX: NotImplementedError (see above).
            raise NotImplementedError((args.latent, args.mpf_method))

        # 1e-3 / mpf_lr scaling — presumably normalizes the loss magnitude
        # against the MPF step size; TODO confirm against training configs.
        self.mpf_loss = tf.reduce_mean(-neg_mpf_loss) * 1e-3 / args.mpf_lr
        # Penalize energy magnitude at both the proposals and the data codes.
        self.gp_loss = tf.reduce_mean(energ_emb(y)**2) + tf.reduce_mean(energ_emb(model.z)**2)
        self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
        self.score_opt_op = optimize(self.score_loss, [MARGINAL_ENERGY], args)

        # Prior over z: standard normal (euclidean) or uniform on the sphere.
        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
        else:
            # BUG FIX: NotImplementedError (see above).
            raise NotImplementedError(args.latent)

        # KL = Eq(logq - logp) = Eq(-logp - energy_q)
        self.kl = -tf.reduce_mean(self.log_pz) - tf.reduce_mean(model._marginal_energy(model.z))
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        # Scalars reported during training.
        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/mpf': self.mpf_loss,
            'loss/gp': self.gp_loss
        }

        # Snapshot of constructor locals, kept for debugging/inspection.
        self.lc = locals()

    def step(self, sess, fd):
        """One training step: encoder/decoder update, then `train_score_dupl`
        energy-net updates."""
        sess.run(self.wae_opt_op, fd)
        for j in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)