Example #1
class GanWAE(object):

    def __init__(self, model, args):

        self.model = model
        self.args = args

        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x-model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            raise NotImplementedError

        def energ_emb(z):
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2
        pz_logits = model._marginal_energy(pz_sample)
        qz_logits = model._marginal_energy(model.z)
        self.gp_loss = tf.reduce_mean(energ_emb(model.z)**2) * 2
        self.score_loss = -(tf.reduce_mean(tf.log(tf.nn.sigmoid(pz_logits) + 1e-7)) +\
                            tf.reduce_mean(tf.log(1 - tf.nn.sigmoid(qz_logits) + 1e-7)))
        self.score_opt_op = optimize(
            self.score_loss + args.grad_penalty * self.gp_loss, [MARGINAL_ENERGY], args)

        self.kl = -tf.reduce_mean(tf.math.log_sigmoid(qz_logits)) # non-saturating GAN loss
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/disc': self.score_loss,
            'loss/gp': self.gp_loss,
            'loss/kl': self.kl
        }

        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
        for j in range(self.args.train_score_dupl):
            sess.run(self.score_opt_op, fd)
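
A minimal training-loop sketch for the trainer above, assuming a TF1 session and that the graph for `model` is already built; `next_batch()` and `args.num_iters` are hypothetical names, everything else comes from the class:

# Usage sketch (next_batch() and args.num_iters are assumed to exist).
trainer = GanWAE(model, args)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for it in range(args.num_iters):
        fd = {model.x: next_batch()}
        trainer.step(sess, fd)   # one WAE update, then args.train_score_dupl energy updates
        if it % 100 == 0:
            print(it, sess.run(trainer.print, fd))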
Example #2
    def __init__(self, model, args):

        self.model = model
        self.args = args

        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x-model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError

        def energ_emb(z):
            return model._marginal_energy(z)

        if args.latent == 'euc':
            if args.mpf_method == 'ld':
                y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
            else:
                y, neg_mpf_loss = mpf_euc_spos(
                    model.z, energ_emb, args.mpf_lr, alpha=args.mpf_spos_alpha)
        elif args.latent == 'sph' and args.mpf_method == 'ld':
            y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
        else:
            raise NotImplementedError

        self.mpf_loss = tf.reduce_mean(-neg_mpf_loss) * 1e-3 / args.mpf_lr
        self.gp_loss  = tf.reduce_mean(energ_emb(y)**2) + tf.reduce_mean(energ_emb(model.z)**2)
        self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
        self.score_opt_op = optimize(self.score_loss, [MARGINAL_ENERGY], args)

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
        else:
            raise NotImplementedError

        # KL(q||p) = E_q[log q - log p] ≈ E_q[-log p(z) - energy_q(z)], treating -energy_q as log q up to a constant
        self.kl = -tf.reduce_mean(self.log_pz) - tf.reduce_mean(model._marginal_energy(model.z))
        self.wae_loss = self.reconstruction_loss + self.kl * args.wae_lambda
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/mpf': self.mpf_loss,
            'loss/gp': self.gp_loss
        }

        self.lc = locals()
Example #3
    def __init__(self,
                 x_ph,
                 log_likelihood_fn,
                 dims,
                 num_samples=16,
                 method='hmc',
                 config=None):
        """
        The model implements Hamiltonian AIS.
        Developed by @bilginhalil on top of https://github.com/jiamings/ais/

        Example use case:
        log p(x) = log ∫ p(x|z, theta) p(z) dz
        where p(x|z, theta) is the likelihood and p(z) is the prior.
        The prior is assumed to be a normal distribution with mean 0 and identity covariance matrix.

        :param x_ph: Placeholder for x
        :param log_likelihood_fn: Outputs log p(x|z, theta); it should take two parameters: x and z
        :param dims: dict of dimensions, e.g. {'output_dim': 28*28, 'input_dim': FLAGS.d, 'batch_size': 1}
        :param num_samples: Number of samples used to estimate the likelihood.
        :param method: Sampler to use: 'hmc' or a Langevin-dynamics variant (e.g. 'riem_ld').
        :param config: Sampler configuration; for HMC this holds stepsize, n_steps,
            target_acceptance_rate, avg_acceptance_slowness, stepsize_min, stepsize_max,
            stepsize_dec and stepsize_inc. Defaults to default_config[method].
        """

        self.dims = dims
        self.log_likelihood_fn = log_likelihood_fn
        self.num_samples = num_samples

        self.z_shape = [
            dims['batch_size'] * self.num_samples, dims['input_dim']
        ]

        if method != 'riem_ld':
            self.prior = tfd.MultivariateNormalDiag(loc=tf.zeros(self.z_shape),
                                                    scale_diag=tf.ones(
                                                        self.z_shape))
        else:
            self.prior = HypersphericalUniform(dims['input_dim'] - 1)

        self.batch_size = dims['batch_size']
        self.x = tf.tile(x_ph, [self.num_samples, 1])

        self.method = method
        self.config = config if config is not None else default_config[method]
    def forward(self, X, u):
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(X[1])

        # reparametrize
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m), torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] - self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n, self.t_eval[1]-self.t_eval[0])

        # predict
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u], dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver) # T, bs, 7
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T*self.bs, 3)

        # decode
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)

        return None
Example #5
def make_prior(code_size, distribution, alt_prior):
    """
                        Returns the prior on embeddings for tensorflow distributions

                        (i) MultivariateNormalDiag function

                        (ii) HypersphericalUniform

                        with alternative prior on gaussian

                        (1) Alt: N(0,1/code_size)
                        (2) N(0,1)
        """

    if distribution == 'normal':
        if alt_prior:  # alternative prior N(0, 1/code_size): variance scaled by the embedding size
            loc = tf.zeros(code_size)
            scale = tf.sqrt(tf.divide(tf.ones(code_size), code_size))

        else:
            loc = tf.zeros(code_size)
            scale = tf.ones(code_size)

        dist = tfd.MultivariateNormalDiag(loc, scale)

    elif distribution == 'vmf':

        dist = HypersphericalUniform(code_size - 1, dtype=tf.float32)

    else:
        raise NotImplementedError

    return dist
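
A small usage sketch for `make_prior` (sizes are illustrative). The Gaussian prior is sampled with an integer sample shape, while the hyperspherical uniform takes a list, matching its use elsewhere in these examples:

# Illustrative: code_size=8, batch of 32.
prior_euc = make_prior(8, 'normal', alt_prior=False)
z_euc = prior_euc.sample(32)        # (32, 8) draws from N(0, I)

prior_sph = make_prior(8, 'vmf', alt_prior=False)
z_sph = prior_sph.sample([32])      # (32, 8) unit vectors on S^7
log_p = prior_sph.log_prob(z_sph)   # constant: minus the log surface area of S^7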
    def forward(self, X, u):
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(
            self.bs, d * d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(
            self.bs, d * d))

        # reparametrize
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample()  # bs, 2
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample()  # a bad way to avoid nan

        # estimate velocity
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                         self.t_eval[1] - self.t_eval[0])

        # predict
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T * self.bs, 2)

        # decode
        self.Xrec = self.obs_net(self.qT).view([T, self.bs, d, d])
        return None
Example #7
    def __init__(self, model, args):
        """
        OptimizerVAE initializer

        :param model: a model object
        :param learning_rate: float, learning rate of the optimizer
        """

        # binary cross entropy error
        assert args.observation == 'sigmoid', NotImplemented
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
        self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

        if args.latent == 'euc':
            # KL divergence between normal approximate posterior and standard normal prior
            self.p_z = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            kl = model.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=-1))
        elif args.latent == 'sph':
            # KL divergence between vMF approximate posterior and uniform hyper-spherical prior
            self.p_z = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            kl = model.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(kl)
        else:
            raise NotImplementedError

        self.ELBO = - self.reconstruction_loss - self.kl
        self.train_step = optimize(-self.ELBO, None, args)
        self.print = {'loss/recon': self.reconstruction_loss, 'loss/ELBO': self.ELBO, 'loss/KL': self.kl}
Example #8
    def forward(self, X, u):
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(self.bs, d*d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(self.bs, d*d))

        # reparametrize
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v) 
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample() # bs, 2
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample() # a bad way to avoid nan

        # estimate velocity
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n, self.t_eval[1]-self.t_eval[0])

        # predict
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver) # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T*self.bs, 2)

        # decode
        ones = torch.ones_like(self.qT[:,0:1])
        self.content = self.obs_net(ones)

        theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0, bs=T*self.bs) # cos , sin 

        grid = F.affine_grid(theta, torch.Size((T*self.bs, 1, d, d)))
        self.Xrec = F.grid_sample(self.content.view(T*self.bs, 1, d, d), grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])
        return None
Example #9
class MMDWAE(object):

    def __init__(self, model, args):

        self.model = model
        self.args = args

        im_axes = list(range(1, model.x.shape.ndims))
        if args.observation == 'normal':
            self.reconstruction_loss = tf.reduce_mean(
                tf.reduce_sum((model.x-model.qxz_mean)**2, axis=im_axes))
        elif args.observation == 'sigmoid':
            self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
            self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))
        else:
            raise NotImplementedError

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
            pz_sample = self.dist_pz.sample()
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
            pz_sample = self.dist_pz.sample([tf.shape(model.z)[0]])
        else:
            raise NotImplementedError

        def energ_emb(z):
            return model._marginal_energy(z)

        assert pz_sample.shape.ndims == 2

        self.kl = matching_loss = mmd(model.z, pz_sample)
        self.wae_loss = self.reconstruction_loss + self.kl * (args.wae_lambda*100)
        self.wae_opt_op = optimize(self.wae_loss, [ENCODER, DECODER], args)

        self.print = {
            'loss/recon': self.reconstruction_loss,
            'loss/wae': self.wae_loss,
            'loss/kl': self.kl
        }

        self.lc = locals()

    def step(self, sess, fd):
        sess.run(self.wae_opt_op, fd)
Example #10
 def _vmf_log_likelihood(self, sample, location=None, kappa=None):
     """Get the log likelihood of a sample under the vMF distribution with location and kappa."""
     if location is None and kappa is None:
         return HypersphericalUniform(self.z_dim - 1, device=self.device).log_prob(sample)
     elif location is not None and kappa is not None:
         return VonMisesFisher(location, kappa).log_prob(sample)
     else:
         raise InvalidArgumentError("Provide either location and kappa or neither.")
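
A usage sketch for this helper, assuming `enc` is a module exposing it with `z_dim = 3` and a `device`, and that the hyperspherical-VAE implementations of VonMisesFisher and HypersphericalUniform are in scope; shapes and values are illustrative:

import torch
import torch.nn.functional as F

z = F.normalize(torch.randn(4, 3), dim=-1)           # 4 samples on S^2
log_prior = enc._vmf_log_likelihood(z)                # hyperspherical-uniform log-density
loc = F.normalize(torch.randn(4, 3), dim=-1)
kappa = torch.full((4, 1), 10.0)                      # concentration, shape (batch, 1)
log_post = enc._vmf_log_likelihood(z, location=loc, kappa=kappa)   # vMF log-density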
Example #11
    def __init__(self, model, args):

        self.model = model
        self.args = args

        # binary cross entropy error
        assert args.observation == 'sigmoid', NotImplemented
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x, logits=model.logits)
        self.reconstruction_loss = tf.reduce_mean(tf.reduce_sum(self.bce, axis=-1))

        def energ_emb(z):
            return model._cond_energy(z, model.x)

        assert args.mpf_method == 'ld', NotImplemented
        if args.latent == 'euc':
            y, neg_mpf_loss = mpf_euc(model.z, energ_emb, args.mpf_lr)
        elif args.latent == 'sph':
            y, neg_mpf_loss = mpf_sph(model.z, energ_emb, args.mpf_lr)
        self.mpf_loss = -tf.reduce_mean(neg_mpf_loss)
        self.gp_loss = tf.reduce_mean(energ_emb(y)**2) + tf.reduce_mean(energ_emb(model.z)**2)
        self.score_loss = self.mpf_loss + args.grad_penalty * self.gp_loss
        self.score_opt_op = optimize(self.score_loss, [COND_ENERGY], args)

        if args.latent == 'euc':
            self.dist_pz = tf.distributions.Normal(tf.zeros_like(model.z), tf.ones_like(model.z))
            self.log_pz = tf.reduce_sum(self.dist_pz.log_prob(model.z), axis=-1)
        elif args.latent == 'sph':
            self.dist_pz = HypersphericalUniform(model.z_dim - 1, dtype=model.x.dtype)
            self.log_pz = self.dist_pz.log_prob(model.z)
        else:
            raise NotImplementedError

        # KL(q||p) = E_q[log q - log p] ≈ E_q[-log p(z) - cond_energy(z, x)], treating -cond_energy as log q(z|x) up to a constant
        self.kl = -tf.reduce_mean(self.log_pz) - tf.reduce_mean(model._cond_energy(model.z, model.x))
        self.ELBO = -self.reconstruction_loss - self.kl

        self.elbo_opt_op = optimize(-self.ELBO, [ENCODER, DECODER], args)
        self.print = {
            'loss/reconloss': self.reconstruction_loss,
            'loss/ELBO': self.ELBO,
            'loss/approx_KL': self.kl,
            'loss/mpf': self.mpf_loss,
            'loss/gp': self.gp_loss,
            'e/avg': tf.reduce_mean(model._cond_energy(model.z, model.x))
        }
        self.lc = locals()
    def reparameterize(self, z_mean, z_var):
        if self.distribution == 'normal':
            q_z = torch.distributions.normal.Normal(z_mean, z_var)
            p_z = torch.distributions.normal.Normal(torch.zeros_like(z_mean), torch.ones_like(z_var))
        elif self.distribution == 'vmf':
            q_z = VonMisesFisher(z_mean, z_var)
            p_z = HypersphericalUniform(self.z_dim - 1)
        else:
            raise NotImplementedError

        return q_z, p_z
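
A sketch of how the returned pair is typically consumed, mirroring the `kl_divergence` call in Example #21 below; `model`, `z_mean`, and `z_var` are assumed to exist, and the vMF/uniform KL relies on the registration shipped with the hyperspherical-VAE code used in these examples:

from torch.distributions.kl import kl_divergence

q_z, p_z = model.reparameterize(z_mean, z_var)
z = q_z.rsample()                       # differentiable sample for the decoder
kl = kl_divergence(q_z, p_z)            # elementwise for Normal, per-sample for vMF
if model.distribution == 'normal':
    kl = kl.sum(-1)                     # sum over latent dimensions
loss_kl = kl.mean()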
Example #13
File: mnist.py  Project: pimdh/svae-temp
    def __init__(self, model, learning_rate=1e-3):
        """
        OptimizerVAE initializer

        :param model: a model object
        :param learning_rate: float, learning rate of the optimizer
        """

        self.kl_weight = tf.placeholder_with_default(np.array(1.).astype(
            np.float64),
                                                     shape=())

        # binary cross entropy error
        self.bce = tf.nn.sigmoid_cross_entropy_with_logits(labels=model.x,
                                                           logits=model.logits)
        self.score = tf.reduce_sum(self.bce, axis=-1)

        print('s1', self.score)

        print(model.distribution)

        if model.distribution == 'normal':
            # KL divergence between normal approximate posterior and standard normal prior
            self.p_z = tf.distributions.Normal(tf.zeros_like(model.z),
                                               tf.ones_like(model.z))
            kl = model.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(tf.reduce_sum(kl, axis=-1))
        elif model.distribution == 'vmf':
            # KL divergence between vMF approximate posterior and uniform hyper-spherical prior
            self.p_z = HypersphericalUniform(model.z_dim - 1,
                                             dtype=model.x.dtype)
            kl = model.q_z.kl_divergence(self.p_z)
            self.kl = tf.reduce_mean(kl)

            self.score = -model.q_z.add_g_cor(-self.score)
        else:
            raise NotImplementedError

        self.reconstruction_loss = tf.reduce_mean(self.score)

        self.ELBO = -self.reconstruction_loss - self.kl

        self.train_step = tf.train.AdamOptimizer(
            learning_rate=learning_rate).minimize(self.reconstruction_loss +
                                                  self.kl * self.kl_weight)

        self.print = {
            'recon loss': self.reconstruction_loss,
            'ELBO': self.ELBO,
            'KL': self.kl,
            'KL weight': self.kl_weight
        }
Example #14
 def kl_distance(self):
     if self.vtype == "gauss":
         self.prior = tf.distributions.Normal(
             tf.zeros(self.central_state_size),
             tf.ones(self.central_state_size))
         self.kl = self.central_distribution.kl_divergence(self.prior)
         loss_kl = tf.reduce_mean(tf.reduce_sum(self.kl, axis=1))
     elif self.vtype == 'vmf':
         self.prior = HypersphericalUniform(self.central_state_size - 1,
                                            dtype=tf.float32)
         self.kl = self.central_distribution.kl_divergence(self.prior)
         loss_kl = tf.reduce_mean(self.kl)
     else:
         raise NotImplementedError
     return loss_kl
Example #15
 def _vmf_sample_z(self, location, kappa, shape, det):
     """Reparameterized sample from a vMF distribution with location and concentration kappa."""
     if location is None and kappa is None and shape is not None:
         if det:
             raise InvalidArgumentError("Cannot deterministically sample from the Uniform on a Hypersphere.")
         else:
             return HypersphericalUniform(self.z_dim - 1, device=self.device).sample(shape[:-1])
     elif location is not None and kappa is not None:
         if det:
             return location
         if self.training:
             return VonMisesFisher(location, kappa).rsample()
         else:
             return VonMisesFisher(location, kappa).sample()
     else:
         raise InvalidArgumentError("Either provide location and kappa or neither with a shape.")
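
A usage sketch for the two branches above (illustrative names; `enc` is assumed to expose this method together with `z_dim` and `device`):

z_post = enc._vmf_sample_z(loc, kappa, None, det=False)                 # reparameterized vMF draw while training
z_mode = enc._vmf_sample_z(loc, kappa, None, det=True)                  # deterministic: the mean direction
z_prior = enc._vmf_sample_z(None, None, (batch_size, enc.z_dim), det=False)  # uniform draw on the sphere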
Example #16
    def forward(self, X, u):
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()

        # estimate velocity
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])

        # predict
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 8
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
        return None
Example #17
    def sampled_z(self, mu, sigma, batch_size):
        if self.distribution == 'normal':
            epsilon = tf.random_normal(
                tf.stack([int(batch_size), self.n_latent_units]))
            z = mu + tf.multiply(epsilon, tf.exp(0.5 * sigma))
            loss = tf.reduce_mean(
                -0.5 * self.beta *
                tf.reduce_sum(1.0 + sigma - tf.square(mu) - tf.exp(sigma), 1))
        elif self.distribution == 'vmf':
            self.q_z = VonMisesFisher(mu,
                                      sigma,
                                      validate_args=True,
                                      allow_nan_stats=False)
            z = self.q_z.sample()
            self.p_z = HypersphericalUniform(self.n_latent_units,
                                             validate_args=True,
                                             allow_nan_stats=False)
            loss = tf.reduce_mean(-self.q_z.kl_divergence(self.p_z))
        else:
            raise NotImplementedError

        return z, loss
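
The 'normal' branch above is the standard closed-form KL between a diagonal Gaussian and the standard normal, with `sigma` holding the log-variance (as the reparameterization line implies); a small NumPy check of that identity with arbitrary values:

import numpy as np

mu = np.array([0.3, -0.1])
log_var = np.array([0.2, -0.5])
# KL( N(mu, diag(exp(log_var))) || N(0, I) )
kl_closed = -0.5 * np.sum(1.0 + log_var - mu**2 - np.exp(log_var))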
Example #18
    def forward(self, inputs, lengths, dist='normal', fix=True):
        inputs = pack(self.drop(inputs), lengths, batch_first=True)
        _, hn = self.rnn(inputs)
        h = torch.cat(hn, dim=2).squeeze(0)
        if dist == 'normal':
            p_z = Normal(
                torch.zeros((h.size(0), self.code_dim), device=h.device),
                (0.5 * torch.zeros(
                    (h.size(0), self.code_dim), device=h.device)).exp())
            mu, lv = self.fcmu(h), self.fclv(h)
            if self.bn:
                mu, lv = self.bnmu(mu), self.bnlv(lv)
            return hn, Normal(mu, (0.5 * lv).exp()), p_z

        elif dist == 'vmf':
            mu = self.fcmu(h)
            mu = mu / mu.norm(dim=-1, keepdim=True)
            var = F.softplus(self.fcvar(h)) + 1
            if fix:
                var = torch.ones_like(var) * 80
            return hn, VonMisesFisher(mu, var), HypersphericalUniform(
                self.code_dim - 1, device=mu.device)
        else:
            raise NotImplementedError
Example #19
    def reparameterize(self, z_mean, z_var):

        q_z = VonMisesFisher(z_mean, z_var)
        p_z = HypersphericalUniform(self.z_dim - 1)

        return q_z, p_z
Example #20
class AIS(object):
    def __init__(self,
                 x_ph,
                 log_likelihood_fn,
                 dims,
                 num_samples=16,
                 method='hmc',
                 config=None):
        """
        The model implements Hamiltonian AIS.
        Developed by @bilginhalil on top of https://github.com/jiamings/ais/

        Example use case:
        log p(x) = log ∫ p(x|z, theta) p(z) dz
        where p(x|z, theta) is the likelihood and p(z) is the prior.
        The prior is assumed to be a normal distribution with mean 0 and identity covariance matrix.

        :param x_ph: Placeholder for x
        :param log_likelihood_fn: Outputs log p(x|z, theta); it should take two parameters: x and z
        :param dims: dict of dimensions, e.g. {'output_dim': 28*28, 'input_dim': FLAGS.d, 'batch_size': 1}
        :param num_samples: Number of samples used to estimate the likelihood.
        :param method: Sampler to use: 'hmc' or a Langevin-dynamics variant (e.g. 'riem_ld').
        :param config: Sampler configuration; for HMC this holds stepsize, n_steps,
            target_acceptance_rate, avg_acceptance_slowness, stepsize_min, stepsize_max,
            stepsize_dec and stepsize_inc. Defaults to default_config[method].
        """

        self.dims = dims
        self.log_likelihood_fn = log_likelihood_fn
        self.num_samples = num_samples

        self.z_shape = [
            dims['batch_size'] * self.num_samples, dims['input_dim']
        ]

        if method != 'riem_ld':
            self.prior = tfd.MultivariateNormalDiag(loc=tf.zeros(self.z_shape),
                                                    scale_diag=tf.ones(
                                                        self.z_shape))
        else:
            self.prior = HypersphericalUniform(dims['input_dim'] - 1)

        self.batch_size = dims['batch_size']
        self.x = tf.tile(x_ph, [self.num_samples, 1])

        self.method = method
        self.config = config if config is not None else default_config[method]

    def log_f_i(self, z, t):
        return tf.reshape(-self.energy_fn(z, t),
                          [self.num_samples, self.batch_size])

    def energy_fn(self, z, t):
        e = self.prior.log_prob(z)
        assert e.shape.ndims == 1
        e += t * tf.reshape(self.log_likelihood_fn(self.x, z),
                            [self.num_samples * self.batch_size])
        assert e.shape.ndims == 1
        return -e

    def ais(self, schedule):
        """
            :param schedule: temperature schedule i.e. `p(z)p(x|z)^t`
            :return: [batch_size]
        """
        cfg = self.config
        if isinstance(self.prior, tfd.MultivariateNormalDiag):
            z = self.prior.sample()
        else:
            z = self.prior.sample([self.num_samples * self.batch_size])
        assert z.shape.ndims == 2

        index_summation = (tf.constant(0),
                           tf.zeros([self.num_samples,
                                     self.batch_size]), tf.cast(z, tf.float32),
                           cfg.stepsize, cfg.target_acceptance_rate)

        items = tf.unstack(
            tf.convert_to_tensor([[
                i, t0, t1
            ] for i, (t0, t1) in enumerate(zip(schedule[:-1], schedule[1:]))]))

        def condition(index, summation, z, stepsize, avg_acceptance_rate):
            return tf.less(index, len(schedule) - 1)

        def body(index, w, z, stepsize, avg_acceptance_rate):
            item = tf.gather(items, index)
            t0 = tf.gather(item, 1)
            t1 = tf.gather(item, 2)

            new_u = self.log_f_i(z, t1)
            prev_u = self.log_f_i(z, t0)
            w = tf.add(w, new_u - prev_u)

            def run_energy(z):
                e = self.energy_fn(z, t1)
                if self.method != 'hmc':
                    e = e[:, None]
                with tf.control_dependencies([e]):
                    return e

            # New step:
            if self.method == 'hmc':
                accept, final_pos, final_vel = hmc_move(
                    z, run_energy, stepsize, cfg.n_steps)
                new_z, new_stepsize, new_acceptance_rate = hmc_updates(
                    z,
                    stepsize,
                    avg_acceptance_rate=avg_acceptance_rate,
                    final_pos=final_pos,
                    accept=accept,
                    stepsize_min=cfg.stepsize_min,
                    stepsize_max=cfg.stepsize_max,
                    stepsize_dec=cfg.stepsize_dec,
                    stepsize_inc=cfg.stepsize_inc,
                    target_acceptance_rate=cfg.target_acceptance_rate,
                    avg_acceptance_slowness=cfg.avg_acceptance_slowness)

            elif self.method.endswith('ld'):
                new_z, cur_acc_rate = ld_move(z, run_energy, stepsize,
                                              cfg.n_steps, self.method)
                new_stepsize, new_acceptance_rate = ld_update(
                    stepsize,
                    cur_acc_rate=cur_acc_rate,
                    hist_acc_rate=avg_acceptance_rate,
                    target_acc_rate=cfg.target_acceptance_rate,
                    ssz_inc=cfg.stepsize_inc,
                    ssz_dec=cfg.stepsize_dec,
                    ssz_min=cfg.stepsize_min,
                    ssz_max=cfg.stepsize_max,
                    avg_acc_decay=cfg.avg_acceptance_slowness)

            return tf.add(index,
                          1), w, new_z, new_stepsize, new_acceptance_rate

        i, w, _, final_stepsize, final_acc_rate = tf.while_loop(
            condition,
            body,
            index_summation,
            parallel_iterations=1,
            swap_memory=True)
        # w = tf.Print(w, [final_stepsize, final_acc_rate], 'ff')
        return tf.squeeze(log_mean_exp(w, axis=0), axis=0)
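
A sketch of driving the estimator above end to end; the linear schedule, the illustrative sizes, and `log_likelihood_fn`/`x_batch` are assumptions, while the `AIS` interface is taken from the class:

import numpy as np

dims = {'output_dim': 28 * 28, 'input_dim': 10, 'batch_size': 1}      # illustrative sizes
x_ph = tf.placeholder(tf.float32, [dims['batch_size'], dims['output_dim']])
ais = AIS(x_ph, log_likelihood_fn, dims, num_samples=16, method='hmc')
log_px = ais.ais(np.linspace(0.0, 1.0, 500))       # 500-step linear temperature schedule
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    estimate = sess.run(log_px, {x_ph: x_batch})    # shape [batch_size]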
Example #21
 def _vmf_kl_divergence(self, location, kappa):
     """Get the estimated KL between the VMF function with a uniform hyperspherical prior."""
     return kl_divergence(
         VonMisesFisher(location, kappa),
         HypersphericalUniform(self.z_dim - 1, device=self.device))
    def forward(self, X, u):
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()

        # estimate velocity
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])

        # predict
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 8
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode
        ones = torch.ones_like(self.qT[:, 0:1])
        self.link1 = self.obs_net_1(ones)
        self.link2 = self.obs_net_2(ones)

        theta1 = self.get_theta_inv(self.qT[:, 0],
                                    self.qT[:, 2],
                                    0,
                                    0,
                                    bs=T * self.bs)  # cos phi1, sin phi1
        x = self.link1_l * self.qT[:, 2]  # l * sin phi1
        y = self.link1_l * self.qT[:, 0]  # l * cos phi 1
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 3],
                                    x,
                                    y,
                                    bs=T * self.bs)  # cos phi2, sin phi 2

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_link1 = F.grid_sample(
            self.link1.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_link2 = F.grid_sample(
            self.link2.view(T * self.bs, 1, self.d, self.d), grid2)
        self.Xrec = torch.cat(
            [transf_link1, transf_link2,
             torch.zeros_like(transf_link1)],
            dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
    def forward(self, X, u):
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])

        # predict
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 7
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
Example #24
    def reparameterize(self, z_mean, z_kappa):

        q_z = VonMisesFisher(z_mean, z_kappa)
        p_z = HypersphericalUniform(z_mean.size(1) - 1, device=DEVICE)

        return q_z, p_z
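
For the 2-D latents used in the PyTorch examples (points on the unit circle), the hyperspherical uniform has constant density 1/(2π); a quick numerical check, assuming the HypersphericalUniform implementation used above:

import math
import torch
import torch.nn.functional as F

p_z = HypersphericalUniform(1)                  # uniform on S^1, the unit circle
z = F.normalize(torch.randn(5, 2), dim=-1)      # 5 unit vectors in R^2
print(p_z.log_prob(z))                          # each entry ≈ -log(2*pi) ≈ -1.8379
print(-math.log(2 * math.pi))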