Beispiel #1
0
def get_latent_distributions(distribution, mu_s, mu_p, mu_o, log_sigma_sq_s,
                             log_sigma_sq_p, log_sigma_sq_o):
    """
    Returns tf distributions for the generative network.

    One distribution is built per latent group (s, p, o) from its location
    and its raw scale parameter. The raw scale is passed through
    ``distribution_scale`` before use.

    Args:
        distribution: ``'normal'`` or ``'vmf'``.
        mu_s, mu_p, mu_o: location parameters of the three latents.
        log_sigma_sq_s, log_sigma_sq_p, log_sigma_sq_o: raw scale parameters
            of the three latents.

    Returns:
        Tuple ``(q_s, q_p, q_o)`` of distribution objects.

    Raises:
        NotImplementedError: if ``distribution`` is neither ``'normal'`` nor
            ``'vmf'``.
    """
    if distribution == 'normal':
        # diagonal Gaussian from mean and (transformed) scale
        q_s = tfd.MultivariateNormalDiag(mu_s,
                                         distribution_scale(log_sigma_sq_s))
        q_p = tfd.MultivariateNormalDiag(mu_p,
                                         distribution_scale(log_sigma_sq_p))
        q_o = tfd.MultivariateNormalDiag(mu_o,
                                         distribution_scale(log_sigma_sq_o))

    elif distribution == 'vmf':
        # von Mises-Fisher from mean direction and concentration;
        # '+1' is used to prevent collapsing behaviors
        q_s = VonMisesFisher(mu_s, distribution_scale(log_sigma_sq_s) + 1)
        q_p = VonMisesFisher(mu_p, distribution_scale(log_sigma_sq_p) + 1)
        q_o = VonMisesFisher(mu_o, distribution_scale(log_sigma_sq_o) + 1)

    else:
        # `NotImplemented` is a constant, not an exception class; raising it
        # yields an unrelated TypeError. Raise the intended exception instead.
        raise NotImplementedError(
            "Unsupported distribution: {!r}".format(distribution))

    return q_s, q_p, q_o
Beispiel #2
0
    def __init__(self,
                 x,
                 h_dim,
                 z_dim,
                 activation=tf.nn.relu,
                 distribution='normal'):
        """
        ModelVAE initializer

        Builds the encoder, the approximate posterior ``q_z``, one latent
        sample ``z`` and the decoder ``logits``.

        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :raises NotImplementedError: if `distribution` is neither `normal` nor `vmf`
        """
        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, activation, distribution

        # posterior parameters from the encoder
        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            # NOTE(review): z_var is passed as Normal's scale (std-dev) argument
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        elif distribution == 'vmf':
            # for vMF the second parameter acts as the concentration
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)
        else:
            # `NotImplemented` is not an exception; raising it gives a TypeError.
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(distribution))

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)
Beispiel #3
0
    def forward(self, X, u):
        """
        Encode the first two frames, integrate the latent ODE and decode.

        Results are stored on ``self`` rather than returned: posterior ``Q_q``
        and prior ``P_q``, sampled initial state ``q0``, velocity estimate
        ``q_dot0``, trajectory ``qT`` / ``q_dotT``, the ``content`` image and
        the reconstruction ``Xrec`` of shape ``(T, bs, d, d)``.

        :param X: image tensor unpacked as ``[_, bs, d, d]``; only ``X[0]`` and
            ``X[1]`` are encoded.
        :param u: control input concatenated onto the initial ODE state.
        :returns: None
        :raises RuntimeError: if the vMF sampler keeps returning NaN.
        """
        def _rsample_without_nan(dist, max_attempts=100):
            """Return ``dist.rsample()``, redrawing while it contains NaN.

            The vMF sampler can occasionally produce NaN. Retries are bounded so
            degenerate parameters (e.g. a NaN concentration) raise instead of
            looping forever.
            """
            for _ in range(max_attempts):
                sample = dist.rsample()
                if not torch.isnan(sample).any():
                    return sample
            raise RuntimeError(
                "rsample() returned NaN on {} consecutive attempts".format(
                    max_attempts))

        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames, flattened to (bs, d*d)
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(self.bs, d*d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(self.bs, d*d))

        # reparametrize: vMF posterior, uniform prior on the 1-sphere
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = _rsample_without_nan(self.Q_q)  # bs, 2

        # estimate velocity from the two frames' encodings
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n, self.t_eval[1]-self.t_eval[0])

        # predict: integrate [q0, q_dot0, u] over t_eval
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T*self.bs, 2)

        # decode: obs_net maps a constant input to a content image, which is
        # then warped by an affine transform built from qT
        ones = torch.ones_like(self.qT[:, 0:1])
        self.content = self.obs_net(ones)

        theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0, bs=T*self.bs)  # cos , sin

        grid = F.affine_grid(theta, torch.Size((T*self.bs, 1, d, d)))
        self.Xrec = F.grid_sample(self.content.view(T*self.bs, 1, d, d), grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])
        return None
    def forward(self, X, u):
        """
        Encode the first two frames, integrate the latent ODE and decode.

        The latent state has a Gaussian part ``r`` and a vMF part ``phi``.
        Everything is stored on ``self``; the reconstruction ``Xrec`` has shape
        ``(T, bs, 3, d, d)``.

        :param X: image tensor unpacked as ``[_, bs, c, d, d]``; only ``X[0]``
            and ``X[1]`` are encoded.
        :param u: control input concatenated onto the initial ODE state.
        :returns: None
        :raises RuntimeError: if the vMF sampler keeps returning NaN.
        """
        def _rsample_without_nan(dist, max_attempts=100):
            """Return ``dist.rsample()``, redrawing while it contains NaN.

            The vMF sampler can occasionally produce NaN. Retries are bounded so
            degenerate parameters (e.g. a NaN concentration) raise instead of
            looping forever.
            """
            for _ in range(max_attempts):
                sample = dist.rsample()
                if not torch.isnan(sample).any():
                    return sample
            raise RuntimeError(
                "rsample() returned NaN on {} consecutive attempts".format(
                    max_attempts))

        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(X[1])

        # reparametrize r: Gaussian posterior (r0_v used as scale), N(0, 1) prior
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m), torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        # reparametrize phi: vMF posterior, uniform prior on the 1-sphere
        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = _rsample_without_nan(self.Q_phi0)

        # estimate velocity: finite difference for r, angle_vel_est for phi
        dt = self.t_eval[1] - self.t_eval[0]
        self.r_dot0 = (self.r1_m - self.r0_m) / dt
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n, dt)

        # predict
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u], dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver)  # T, bs, 7 (= 3 + 2 + 2, see split)
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T*self.bs, 3)

        # decode
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)

        return None
 def _vmf_sample_z(self, location, kappa, shape, det):
     """Reparameterized sample from a vMF distribution with location and concentration kappa."""
     if location is None and kappa is None and shape is not None:
         if det:
             raise InvalidArgumentError("Cannot deterministically sample from the Uniform on a Hypersphere.")
         else:
             return HypersphericalUniform(self.z_dim - 1, device=self.device).sample(shape[:-1])
     elif location is not None and kappa is not None:
         if det:
             return location
         if self.training:
             return VonMisesFisher(location, kappa).rsample()
         else:
             return VonMisesFisher(location, kappa).sample()
     else:
         raise InvalidArgumentError("Either provide location and kappa or neither with a shape.")
    def forward(self, X, u):
        """
        Encode the first two frames, integrate the latent ODE and decode.

        Results are stored on ``self``: posterior ``Q_q``, prior ``P_q``,
        sampled ``q0``, velocity estimate ``q_dot0``, trajectory ``qT`` /
        ``q_dotT`` and the reconstruction ``Xrec`` of shape ``(T, bs, d, d)``.

        :param X: image tensor unpacked as ``[_, bs, d, d]``; only ``X[0]`` and
            ``X[1]`` are encoded.
        :param u: control input concatenated onto the initial ODE state.
        :returns: None
        :raises RuntimeError: if the vMF sampler keeps returning NaN.
        """
        def _rsample_without_nan(dist, max_attempts=100):
            """Return ``dist.rsample()``, redrawing while it contains NaN.

            The vMF sampler can occasionally produce NaN. Retries are bounded so
            degenerate parameters (e.g. a NaN concentration) raise instead of
            looping forever.
            """
            for _ in range(max_attempts):
                sample = dist.rsample()
                if not torch.isnan(sample).any():
                    return sample
            raise RuntimeError(
                "rsample() returned NaN on {} consecutive attempts".format(
                    max_attempts))

        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames, flattened to (bs, d*d)
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(
            self.bs, d * d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(
            self.bs, d * d))

        # reparametrize: vMF posterior, uniform prior on the 1-sphere
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = _rsample_without_nan(self.Q_q)  # bs, 2

        # estimate velocity from the two frames' encodings
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                         self.t_eval[1] - self.t_eval[0])

        # predict: integrate [q0, q_dot0, u] over t_eval
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T * self.bs, 2)

        # decode
        self.Xrec = self.obs_net(self.qT).view([T, self.bs, d, d])
        return None
 def _vmf_log_likelihood(self, sample, location=None, kappa=None):
     """Return the log-density of ``sample``.

     If neither ``location`` nor ``kappa`` is given, the density is the uniform
     distribution on the (z_dim - 1)-sphere. If both are given, it is
     vMF(location, kappa).

     Raises:
         InvalidArgumentError: if exactly one of ``location`` / ``kappa`` is given.
     """
     missing = (location is None, kappa is None)
     if all(missing):
         prior = HypersphericalUniform(self.z_dim - 1, device=self.device)
         return prior.log_prob(sample)
     if any(missing):
         raise InvalidArgumentError("Provide either location and kappa or neither.")
     return VonMisesFisher(location, kappa).log_prob(sample)
Beispiel #8
0
    def forward(self, X, u):
        """
        Encode two frames of a two-angle system, integrate the latent ODE and decode.

        Results are stored on ``self``: the link length ``link1_l``, both vMF
        posteriors, the shared uniform prior, sampled angles ``phi1_t0`` /
        ``phi2_t0``, velocity estimates, trajectory ``qT`` / ``q_dotT`` and the
        reconstruction ``Xrec`` of shape ``(T, bs, 3, d, d)``.

        :param X: image tensor unpacked as ``[_, bs, c, d, d]``; only ``X[0]``
            and ``X[1]`` are encoded.
        :param u: control input concatenated onto the initial ODE state.
        :returns: None
        :raises RuntimeError: if a vMF sampler keeps returning NaN.
        """
        def _rsample_without_nan(dist, max_attempts=100):
            """Return ``dist.rsample()``, redrawing while it contains NaN.

            The vMF sampler can occasionally produce NaN. Retries are bounded so
            degenerate parameters (e.g. a NaN concentration) raise instead of
            looping forever.
            """
            for _ in range(max_attempts):
                sample = dist.rsample()
                if not torch.isnan(sample).any():
                    return sample
            raise RuntimeError(
                "rsample() returned NaN on {} consecutive attempts".format(
                    max_attempts))

        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # sigmoid keeps the learned link-1 length in (0, 1)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode the first two frames
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize: one vMF posterior per angle, shared uniform prior
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = _rsample_without_nan(self.Q_phi1)
        self.phi2_t0 = _rsample_without_nan(self.Q_phi2)

        # estimate velocity
        dt = self.t_eval[1] - self.t_eval[0]
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1, dt)
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1, dt)

        # predict: state ordered as [phi1[0], phi2[0], phi1[1], phi2[1],
        # phi1_dot, phi2_dot, u] (components presumably cos, sin)
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 8 (= 4 + 2 + 2, see split)
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
        return None
    def reparameterize(self, z_mean, z_var):
        """
        Build the approximate posterior q(z) and the prior p(z).

        :param z_mean: posterior location (Gaussian mean or vMF mean direction).
        :param z_var: passed as the Gaussian ``scale`` for ``'normal'`` and as
            the vMF concentration for ``'vmf'``.
        :returns: ``(q_z, p_z)``. For ``'normal'``, the prior is N(0, 1)
            elementwise. For ``'vmf'``, it is uniform on the (z_dim - 1)-sphere.
        :raises NotImplementedError: if ``self.distribution`` is neither
            ``'normal'`` nor ``'vmf'``.
        """
        if self.distribution == 'normal':
            q_z = torch.distributions.normal.Normal(z_mean, z_var)
            p_z = torch.distributions.normal.Normal(torch.zeros_like(z_mean), torch.ones_like(z_var))
        elif self.distribution == 'vmf':
            q_z = VonMisesFisher(z_mean, z_var)
            # NOTE(review): other call sites pass device=...; confirm the prior
            # should not be created on z_mean.device for GPU training.
            p_z = HypersphericalUniform(self.z_dim - 1)
        else:
            # `NotImplemented` is not an exception; raising it gives a TypeError.
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(self.distribution))

        return q_z, p_z
Beispiel #10
0
def classify(dataset, labels, n_repeats=10):
    """Classification using node embeddings and logistic regression.

    Loads the saved vMF parameters ``result/<dataset>.node.z.mean.npy`` and
    ``result/<dataset>.node.z.var.npy``. Then, ``n_repeats`` times, it draws
    one embedding per node, fits a multinomial logistic regression on a fixed
    train/test split (``random_state=2019``) and scores it on the test part.

    Args:
        dataset: dataset name used to locate the saved embedding files.
        labels: 2-D label array (presumably one-hot); the class index is the
            row-wise argmax.
        n_repeats: number of embedding samples to average over (default 10,
            the previously hard-coded value).

    Returns:
        Tuple ``(macro_f1, micro_f1, accuracy)`` averaged over the repeats.

    Raises:
        ValueError: if ``n_repeats`` is less than 1.
    """
    if n_repeats < 1:
        raise ValueError("n_repeats must be >= 1, got {}".format(n_repeats))

    print('classification using lr :dataset:', dataset)

    # load node embeddings
    y = np.argmax(labels, axis=1)
    node_z_mean = np.load("result/{}.node.z.mean.npy".format(dataset))
    node_z_var = np.load("result/{}.node.z.var.npy".format(dataset))

    q_z = VonMisesFisher(torch.tensor(node_z_mean), torch.tensor(node_z_var))

    # train the model and accumulate metrics over repeated embedding samples
    macro_f1_sum = 0.0
    micro_f1_sum = 0.0
    acc_sum = 0.0

    for _ in range(n_repeats):
        # fresh embedding sample; the split indices stay fixed via random_state
        node_embedding = q_z.rsample().numpy()

        X_train, X_test, y_train, y_test = train_test_split(node_embedding,
                                                            y,
                                                            train_size=0.2,
                                                            test_size=1000,
                                                            random_state=2019)

        clf = LogisticRegression(solver="lbfgs",
                                 multi_class="multinomial",
                                 random_state=0).fit(X_train, y_train)

        y_pred = clf.predict(X_test)

        macro_f1_sum += f1_score(y_test, y_pred, average="macro")
        micro_f1_sum += f1_score(y_test, y_pred, average="micro")
        acc_sum += accuracy_score(y_test, y_pred, normalize=True)

    return (macro_f1_sum / n_repeats, micro_f1_sum / n_repeats,
            acc_sum / n_repeats)
Beispiel #11
0
    def sampled_z(self, mu, sigma, batch_size):
        """
        Draw a latent sample and compute its beta-weighted KL regularizer.

        :param mu: posterior location (Gaussian mean or vMF mean direction).
        :param sigma: for ``'normal'``, treated as a log-variance (the sample
            uses ``exp(0.5 * sigma)`` as std-dev). For ``'vmf'``, passed to
            ``VonMisesFisher`` as the concentration.
        :param batch_size: number of Gaussian noise rows drawn in the
            ``'normal'`` branch.
        :returns: ``(z, loss)``, where ``loss = beta * mean(KL(q || p))`` is a
            non-negative penalty in both branches.
        :raises NotImplementedError: if ``self.distribution`` is neither
            ``'normal'`` nor ``'vmf'``.
        """
        if self.distribution == 'normal':
            # reparameterization: z = mu + eps * exp(0.5 * log_var)
            epsilon = tf.random_normal(
                tf.stack([int(batch_size), self.n_latent_units]))
            z = mu + tf.multiply(epsilon, tf.exp(0.5 * sigma))
            # closed-form KL(N(mu, exp(sigma)) || N(0, I)), weighted by beta
            loss = tf.reduce_mean(
                -0.5 * self.beta *
                tf.reduce_sum(1.0 + sigma - tf.square(mu) - tf.exp(sigma), 1))
        elif self.distribution == 'vmf':
            self.q_z = VonMisesFisher(mu,
                                      sigma,
                                      validate_args=True,
                                      allow_nan_stats=False)
            z = self.q_z.sample()
            # NOTE(review): other call sites pass (latent dim - 1) as the sphere
            # dimension; confirm mu has n_latent_units + 1 components here.
            self.p_z = HypersphericalUniform(self.n_latent_units,
                                             validate_args=True,
                                             allow_nan_stats=False)
            # KL is non-negative. Add it as a positive, beta-weighted penalty
            # like the Gaussian branch; negating it would reward divergence
            # from the prior.
            loss = self.beta * tf.reduce_mean(self.q_z.kl_divergence(self.p_z))
        else:
            # `NotImplemented` is not an exception; raising it gives a TypeError.
            raise NotImplementedError(
                "Unsupported distribution: {!r}".format(self.distribution))

        return z, loss
Beispiel #12
0
 def encoder(self, inputs):
     """Encode ``inputs`` into a sampled central latent state.

     Runs ``conv_layers`` and then ``enfc_layers`` to obtain ``central_mu`` and
     ``central_sigma``. It builds ``central_distribution`` according to
     ``self.vtype``, then stores one sample in ``central_states`` and returns it.

     Raises:
         NotImplementedError: if ``self.vtype`` is neither ``"gauss"`` nor
             ``"vmf"``. Previously this fell through to an AttributeError or
             reused a stale distribution.
     """
     conv = self.conv_layers(inputs)
     assert conv.get_shape().as_list()[1:] == self.conv_out_shape
     self.central_mu, self.central_sigma = self.enfc_layers(conv)
     if self.vtype == "gauss":
         assert self.central_mu.get_shape().as_list()[1:] == [
             self.central_state_size
         ]
         self.central_distribution = tf.distributions.Normal(
             self.central_mu, self.central_sigma)
     elif self.vtype == "vmf":
         # the vMF concentration is a single value per example
         assert self.central_sigma.get_shape().as_list()[1:] == [1]
         self.central_distribution = VonMisesFisher(self.central_mu,
                                                    self.central_sigma)
     else:
         raise NotImplementedError(
             "Unsupported vtype: {!r}".format(self.vtype))
     self.central_states = self.central_distribution.sample()
     return self.central_states
Beispiel #13
0
    def forward(self, inputs, lengths, dist='normal', fix=True):
        """
        Encode a padded batch of sequences into a posterior and a prior.

        :param inputs: batch-first padded input; ``self.drop`` is applied
            before packing with ``lengths``.
        :param lengths: sequence lengths used for packing.
        :param dist: ``'normal'`` or ``'vmf'``.
        :param fix: ``'vmf'`` only. If True, the concentration is replaced by
            the constant 80.
        :returns: ``(hn, q_z, p_z)``: the RNN's final state, the posterior and
            the prior.
        :raises NotImplementedError: for any other ``dist``.
        """
        inputs = pack(self.drop(inputs), lengths, batch_first=True)
        _, hn = self.rnn(inputs)
        # hn is presumably a tuple of final states (e.g. LSTM (h_n, c_n));
        # concatenate them on the feature axis and drop the leading axis
        h = torch.cat(hn, dim=2).squeeze(0)
        if dist == 'normal':
            # standard-normal prior N(0, I); exp(0.5 * 0) == 1 exactly
            p_z = Normal(
                torch.zeros((h.size(0), self.code_dim), device=h.device),
                torch.ones((h.size(0), self.code_dim), device=h.device))
            mu, lv = self.fcmu(h), self.fclv(h)
            if self.bn:
                mu, lv = self.bnmu(mu), self.bnlv(lv)
            # lv is a log-variance, so std = exp(0.5 * lv)
            return hn, Normal(mu, (0.5 * lv).exp()), p_z

        elif dist == 'vmf':
            # Project the mean onto the unit sphere. F.normalize clamps the
            # norm at a small eps, so an all-zero row no longer yields NaN (0/0).
            mu = F.normalize(self.fcmu(h), dim=-1)
            # softplus(.) > 0, so the predicted concentration is > 1
            var = F.softplus(self.fcvar(h)) + 1
            if fix:
                var = torch.ones_like(var) * 80
            return hn, VonMisesFisher(mu, var), HypersphericalUniform(
                self.code_dim - 1, device=mu.device)
        else:
            raise NotImplementedError
Beispiel #14
0
    def reparameterize(self, z_mean, z_var):
        """Return ``(q_z, p_z)``: a vMF posterior and a uniform hyperspherical prior.

        ``z_var`` is passed to ``VonMisesFisher`` as its second (concentration)
        argument. The prior lives on the (z_dim - 1)-sphere.
        """
        posterior = VonMisesFisher(z_mean, z_var)
        # NOTE(review): no device is passed here, unlike other call sites;
        # confirm the prior's device matches z_mean when training on GPU.
        prior = HypersphericalUniform(self.z_dim - 1)
        return posterior, prior
    def forward(self, X, u):
        """
        Encode two frames of a two-link system, integrate the latent ODE and
        render the reconstruction by warping one template image per link.

        Results are stored on ``self``: the link length ``link1_l``, both vMF
        posteriors, the shared uniform prior, sampled angles, velocity
        estimates, trajectory ``qT`` / ``q_dotT``, link templates ``link1`` /
        ``link2`` and the reconstruction ``Xrec`` of shape ``(T, bs, 3, d, d)``.

        :param X: image tensor unpacked as ``[_, bs, c, d, d]``; only ``X[0]``
            and ``X[1]`` are encoded.
        :param u: control input concatenated onto the initial ODE state.
        :returns: None
        :raises RuntimeError: if a vMF sampler keeps returning NaN.
        """
        def _rsample_without_nan(dist, max_attempts=100):
            """Return ``dist.rsample()``, redrawing while it contains NaN.

            The vMF sampler can occasionally produce NaN. Retries are bounded so
            degenerate parameters (e.g. a NaN concentration) raise instead of
            looping forever.
            """
            for _ in range(max_attempts):
                sample = dist.rsample()
                if not torch.isnan(sample).any():
                    return sample
            raise RuntimeError(
                "rsample() returned NaN on {} consecutive attempts".format(
                    max_attempts))

        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # sigmoid keeps the learned link-1 length in (0, 1)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode the first two frames
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize: one vMF posterior per angle, shared uniform prior
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = _rsample_without_nan(self.Q_phi1)
        self.phi2_t0 = _rsample_without_nan(self.Q_phi2)

        # estimate velocity
        dt = self.t_eval[1] - self.t_eval[0]
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1, dt)
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1, dt)

        # predict: state ordered as [cos phi1, cos phi2, sin phi1, sin phi2,
        # phi1_dot, phi2_dot, u] (cos/sin order per the decoder comments below)
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 8 (= 4 + 2 + 2, see split)
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode: each link template comes from its obs_net applied to a
        # constant input, then is placed with an inverse affine transform
        ones = torch.ones_like(self.qT[:, 0:1])
        self.link1 = self.obs_net_1(ones)
        self.link2 = self.obs_net_2(ones)

        theta1 = self.get_theta_inv(self.qT[:, 0],
                                    self.qT[:, 2],
                                    0,
                                    0,
                                    bs=T * self.bs)  # cos phi1, sin phi1
        # link 2 is translated by (l * sin phi1, l * cos phi1), presumably to
        # the tip of link 1
        x = self.link1_l * self.qT[:, 2]  # l * sin phi1
        y = self.link1_l * self.qT[:, 0]  # l * cos phi 1
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 3],
                                    x,
                                    y,
                                    bs=T * self.bs)  # cos phi2, sin phi 2

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_link1 = F.grid_sample(
            self.link1.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_link2 = F.grid_sample(
            self.link2.view(T * self.bs, 1, self.d, self.d), grid2)
        # channels: [link 1, link 2, empty]
        self.Xrec = torch.cat(
            [transf_link1, transf_link2,
             torch.zeros_like(transf_link1)],
            dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
Beispiel #16
0
 def _vmf_kl_divergence(self, location, kappa):
     """Estimate KL(vMF(location, kappa) || uniform on the (z_dim - 1)-sphere)."""
     posterior = VonMisesFisher(location, kappa)
     prior = HypersphericalUniform(self.z_dim - 1, device=self.device)
     return kl_divergence(posterior, prior)
    def forward(self, X, u):
        """
        Encode two frames of a cart-pole-like system, integrate the latent ODE
        and render the reconstruction by warping a cart and a pole template.

        The latent state has a Gaussian part ``r`` and a vMF part ``phi``.
        Results are stored on ``self``; ``Xrec`` has shape ``(T, bs, 3, d, d)``.

        :param X: image tensor unpacked as ``[_, bs, c, d, d]``; only ``X[0]``
            and ``X[1]`` are encoded.
        :param u: control input concatenated onto the initial ODE state.
        :returns: None
        :raises RuntimeError: if the vMF sampler keeps returning NaN.
        """
        def _rsample_without_nan(dist, max_attempts=100):
            """Return ``dist.rsample()``, redrawing while it contains NaN.

            The vMF sampler can occasionally produce NaN. Retries are bounded so
            degenerate parameters (e.g. a NaN concentration) raise instead of
            looping forever.
            """
            for _ in range(max_attempts):
                sample = dist.rsample()
                if not torch.isnan(sample).any():
                    return sample
            raise RuntimeError(
                "rsample() returned NaN on {} consecutive attempts".format(
                    max_attempts))

        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize r: Gaussian posterior (r0_v used as scale), N(0, 1) prior
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        # reparametrize phi: vMF posterior, uniform prior on the 1-sphere
        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = _rsample_without_nan(self.Q_phi0)

        # estimate velocity: finite difference for r, angle_vel_est for phi
        dt = self.t_eval[1] - self.t_eval[0]
        self.r_dot0 = (self.r1_m - self.r0_m) / dt
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n, dt)

        # predict
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 7 (= 3 + 2 + 2, see split)
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode: cart and pole templates come from obs_net_1 / obs_net_2
        # applied to a constant input
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        # cart: no rotation (cos=1, sin=0), translated by qT[:, 0]
        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        # pole: rotated by (qT[:, 1], qT[:, 2]) and translated with the cart
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        # channels: [cart, pole, empty]
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
Beispiel #18
0
    def reparameterize(self, z_mean, z_kappa):
        """Return ``(q_z, p_z)``: vMF(z_mean, z_kappa) and a uniform prior on the
        sphere whose dimension is one less than ``z_mean.size(1)``, placed on
        ``DEVICE``.
        """
        posterior = VonMisesFisher(z_mean, z_kappa)
        sphere_dim = z_mean.size(1) - 1
        return posterior, HypersphericalUniform(sphere_dim, device=DEVICE)