コード例 #1
0
def get_latent_distributions(distribution, mu_s, mu_p, mu_o, log_sigma_sq_s,
                             log_sigma_sq_p, log_sigma_sq_o):
    """Build the latent distributions q(s), q(p), q(o) for the generative network.

    :param distribution: either 'normal' or 'vmf'
    :param mu_s: location parameter for the subject latent
    :param mu_p: location parameter for the predicate latent
    :param mu_o: location parameter for the object latent
    :param log_sigma_sq_s: raw scale for the subject, passed through `distribution_scale`
    :param log_sigma_sq_p: raw scale for the predicate, passed through `distribution_scale`
    :param log_sigma_sq_o: raw scale for the object, passed through `distribution_scale`
    :return: tuple `(q_s, q_p, q_o)` of distribution objects
    :raises NotImplementedError: for an unsupported `distribution` value
    """
    if distribution == 'normal':
        # diagonal Gaussian parameterized by mean and (transformed) scale
        q_s = tfd.MultivariateNormalDiag(mu_s,
                                         distribution_scale(log_sigma_sq_s))
        q_p = tfd.MultivariateNormalDiag(mu_p,
                                         distribution_scale(log_sigma_sq_p))
        q_o = tfd.MultivariateNormalDiag(mu_o,
                                         distribution_scale(log_sigma_sq_o))
    elif distribution == 'vmf':
        # von Mises-Fisher parameterized by mean direction and concentration;
        # '+1' used to prevent collapsing behaviors
        q_s = VonMisesFisher(mu_s, distribution_scale(log_sigma_sq_s) + 1)
        q_p = VonMisesFisher(mu_p, distribution_scale(log_sigma_sq_p) + 1)
        q_o = VonMisesFisher(mu_o, distribution_scale(log_sigma_sq_o) + 1)
    else:
        # fix: `raise NotImplemented` raises a TypeError at runtime, because
        # NotImplemented is a constant, not an exception class
        raise NotImplementedError(
            "unsupported distribution: {}".format(distribution))

    return q_s, q_p, q_o
コード例 #2
0
    def forward(self, X, u):
        """Encode the first two frames, integrate the latent ODE forward, and
        decode reconstructions for all time steps.

        :param X: image sequence; assumes shape (T, bs, d, d) — TODO confirm
        :param u: control input concatenated onto the latent state
        :return: None (reconstructions are stored in ``self.Xrec``)
        """
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode: frames 0 and 1 are flattened and mapped to
        # (mean, concentration, normalized mean) of the latent angle posterior
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(self.bs, d*d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(self.bs, d*d))

        # reparametrize: vMF posterior, hyperspherical-uniform prior
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v) 
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample() # bs, 2
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample() # a bad way to avoid nan: resample until finite

        # estimate velocity by finite difference of the two encoded frames
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n, self.t_eval[1]-self.t_eval[0])

        # predict: integrate the latent dynamics over t_eval
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver) # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T*self.bs, 2)

        # decode: render one canonical content image, then warp it into each
        # predicted pose with an affine grid sample
        ones = torch.ones_like(self.qT[:,0:1])
        self.content = self.obs_net(ones)

        theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0, bs=T*self.bs) # cos , sin 

        grid = F.affine_grid(theta, torch.Size((T*self.bs, 1, d, d)))
        self.Xrec = F.grid_sample(self.content.view(T*self.bs, 1, d, d), grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])
        return None
コード例 #3
0
    def __init__(self,
                 x,
                 h_dim,
                 z_dim,
                 activation=tf.nn.relu,
                 distribution='normal'):
        """
        ModelVAE initializer

        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :raises NotImplementedError: for an unsupported `distribution` value
        """
        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, activation, distribution

        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        elif distribution == 'vmf':
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)
        else:
            # fix: `raise NotImplemented` raises a TypeError at runtime, since
            # NotImplemented is a constant, not an exception class
            raise NotImplementedError

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)
コード例 #4
0
    def forward(self, X, u):
        """Encode two frames, integrate the latent ODE, and decode all steps.

        :param X: image sequence; assumes shape (T, bs, d, d) — TODO confirm
        :param u: control input concatenated onto the latent state
        :return: None (reconstructions are stored in ``self.Xrec``)
        """
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode frames 0 and 1 into (mean, concentration, normalized mean)
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(
            self.bs, d * d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(
            self.bs, d * d))

        # reparametrize: vMF posterior, hyperspherical-uniform prior
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample()  # bs, 2
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample()  # a bad way to avoid nan: resample until finite

        # estimate velocity by finite difference of the two encoded frames
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                         self.t_eval[1] - self.t_eval[0])

        # predict: integrate the latent dynamics over t_eval
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T * self.bs, 2)

        # decode each predicted pose directly to an image
        self.Xrec = self.obs_net(self.qT).view([T, self.bs, d, d])
        return None
コード例 #5
0
    def forward(self, X, u):
        """Encode radius (Gaussian) and angle (vMF) latents from two frames,
        integrate the latent ODE, and decode reconstructions.

        :param X: image sequence; assumes shape (T, bs, c, d, d) — TODO confirm
        :param u: control input concatenated onto the latent state
        :return: None (reconstructions are stored in ``self.Xrec``)
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode radius and angle parameters for frames 0 and 1
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(X[1])

        # reparametrize the radius: Gaussian posterior, standard-normal prior
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m), torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        # reparametrize the angle: vMF posterior, hyperspherical-uniform prior
        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        while torch.isnan(self.phi0).any():
            # resample until finite — workaround for occasional NaNs from rsample
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocities by finite differences of the encoded means
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] - self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n, self.t_eval[1]-self.t_eval[0])

        # predict: integrate the latent dynamics over t_eval
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u], dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval, method=self.hparams.solver) # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T*self.bs, 3)

        # decode each predicted state directly to a 3-channel image
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)

        return None
コード例 #6
0
 def _vmf_sample_z(self, location, kappa, shape, det):
     """Draw a latent sample: reparameterized vMF when (location, kappa) are
     given, hyperspherical-uniform prior when only a shape is given.

     :param location: mean direction of the vMF, or None
     :param kappa: concentration of the vMF, or None
     :param shape: batch shape used when sampling the uniform prior
     :param det: if True, return the deterministic mode (the location)
     :raises InvalidArgumentError: on an inconsistent argument combination
     """
     if location is None and kappa is None and shape is not None:
         # prior sampling: there is no deterministic mode on the sphere
         if det:
             raise InvalidArgumentError("Cannot deterministically sample from the Uniform on a Hypersphere.")
         prior = HypersphericalUniform(self.z_dim - 1, device=self.device)
         return prior.sample(shape[:-1])
     if location is not None and kappa is not None:
         if det:
             return location
         vmf = VonMisesFisher(location, kappa)
         # rsample keeps gradients flowing during training
         return vmf.rsample() if self.training else vmf.sample()
     raise InvalidArgumentError("Either provide location and kappa or neither with a shape.")
コード例 #7
0
 def _vmf_log_likelihood(self, sample, location=None, kappa=None):
     """Log-likelihood of `sample` under a vMF (when location/kappa are given)
     or under the hyperspherical uniform prior (when both are None).

     :raises InvalidArgumentError: if exactly one of location/kappa is given
     """
     if location is None and kappa is None:
         prior = HypersphericalUniform(self.z_dim - 1, device=self.device)
         return prior.log_prob(sample)
     if location is not None and kappa is not None:
         return VonMisesFisher(location, kappa).log_prob(sample)
     raise InvalidArgumentError("Provide either location and kappa or neither.")
コード例 #8
0
    def forward(self, X, u):
        """Two-link variant: encode both link angles from two frames, integrate
        the latent ODE, and decode reconstructions.

        :param X: image sequence; assumes shape (T, bs, c, d, d) — TODO confirm
        :param u: control input concatenated onto the latent state
        :return: None (reconstructions are stored in ``self.Xrec``)
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # learnable link-1 length, squashed to (0, 1)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode both link angles for frames 0 and 1
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize: one vMF posterior per link, shared uniform prior
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            # resample until finite — workaround for occasional NaNs from rsample
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()

        # estimate velocities by finite differences of the normalized means
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])

        # predict: latent state is the per-link samples' two components plus
        # velocities and control (presumably [cos1, cos2, sin1, sin2, ...] —
        # ordering matches the get_theta_inv usage in the warped-decoder variant)
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode each predicted state directly to an RGB image
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
        return None
コード例 #9
0
def classify(dataset, labels):
    """classification using node embeddings and logistic regression"""
    print('classification using lr :dataset:', dataset)

    # load node embeddings (posterior mean direction and concentration)
    y = np.argmax(labels, axis=1)
    node_z_mean = np.load("result/{}.node.z.mean.npy".format(dataset))
    node_z_var = np.load("result/{}.node.z.var.npy".format(dataset))

    q_z = VonMisesFisher(torch.tensor(node_z_mean), torch.tensor(node_z_var))

    # train on 10 resampled embedding sets and average the metrics
    n_runs = 10
    macro_f1_sum, micro_f1_sum, acc_sum = 0.0, 0.0, 0.0

    for _ in range(n_runs):
        # sample one embedding set from the posterior
        embeddings = q_z.rsample().numpy()

        X_train, X_test, y_train, y_test = train_test_split(embeddings,
                                                            y,
                                                            train_size=0.2,
                                                            test_size=1000,
                                                            random_state=2019)

        clf = LogisticRegression(solver="lbfgs",
                                 multi_class="multinomial",
                                 random_state=0).fit(X_train, y_train)
        y_pred = clf.predict(X_test)

        macro_f1_sum += f1_score(y_test, y_pred, average="macro")
        micro_f1_sum += f1_score(y_test, y_pred, average="micro")
        acc_sum += accuracy_score(y_test, y_pred, normalize=True)

    return macro_f1_sum / n_runs, micro_f1_sum / n_runs, acc_sum / n_runs
コード例 #10
0
    def reparameterize(self, z_mean, z_var):
        """Build the posterior q(z) and prior p(z) pair for sampling.

        :param z_mean: posterior location (mean direction for 'vmf')
        :param z_var: posterior scale ('normal') or concentration ('vmf')
        :return: tuple `(q_z, p_z)`
        :raises NotImplementedError: for an unsupported `self.distribution`
        """
        if self.distribution == 'normal':
            q_z = torch.distributions.normal.Normal(z_mean, z_var)
            p_z = torch.distributions.normal.Normal(torch.zeros_like(z_mean), torch.ones_like(z_var))
        elif self.distribution == 'vmf':
            q_z = VonMisesFisher(z_mean, z_var)
            p_z = HypersphericalUniform(self.z_dim - 1)
        else:
            # fix: `raise NotImplemented` raises a TypeError at runtime, since
            # NotImplemented is a constant, not an exception class
            raise NotImplementedError

        return q_z, p_z
コード例 #11
0
 def encoder(self, inputs):
     """Encode `inputs` into a latent distribution and return a sample.

     :param inputs: batch of observations fed through the conv stack
     :return: sampled central states from the latent distribution
     :raises NotImplementedError: for an unsupported `self.vtype`
     """
     conv = self.conv_layers(inputs)
     assert conv.get_shape().as_list()[1:] == self.conv_out_shape
     self.central_mu, self.central_sigma = self.enfc_layers(conv)
     if self.vtype == "gauss":
         assert self.central_mu.get_shape().as_list()[1:] == [
             self.central_state_size
         ]
     elif self.vtype == "vmf":
         # vMF uses a single scalar concentration per sample
         assert self.central_sigma.get_shape().as_list()[1:] == [1]
     # (removed dead commented-out reparameterization code that was kept as a
     # no-op string literal)
     if self.vtype == "gauss":
         self.central_distribution = tf.distributions.Normal(
             self.central_mu, self.central_sigma)
     elif self.vtype == "vmf":
         self.central_distribution = VonMisesFisher(self.central_mu,
                                                    self.central_sigma)
     else:
         # previously an unknown vtype fell through to an AttributeError on the
         # .sample() call below; fail loudly with the actual problem instead
         raise NotImplementedError("unsupported vtype: {}".format(self.vtype))
     self.central_states = self.central_distribution.sample()
     return self.central_states
コード例 #12
0
ファイル: training.py プロジェクト: brettbevers/miner
    def sampled_z(self, mu, sigma, batch_size):
        """Sample a latent z and compute the KL regularization term.

        :param mu: posterior location
        :param sigma: log-variance ('normal') or concentration ('vmf')
        :param batch_size: number of samples in the batch
        :return: tuple `(z, loss)`
        :raises NotImplementedError: for an unsupported `self.distribution`
        """
        if self.distribution == 'normal':
            # z = mu + eps * std, with the closed-form Gaussian KL to N(0, I)
            epsilon = tf.random_normal(
                tf.stack([int(batch_size), self.n_latent_units]))
            z = mu + tf.multiply(epsilon, tf.exp(0.5 * sigma))
            loss = tf.reduce_mean(
                -0.5 * self.beta *
                tf.reduce_sum(1.0 + sigma - tf.square(mu) - tf.exp(sigma), 1))
        elif self.distribution == 'vmf':
            self.q_z = VonMisesFisher(mu,
                                      sigma,
                                      validate_args=True,
                                      allow_nan_stats=False)
            z = self.q_z.sample()
            self.p_z = HypersphericalUniform(self.n_latent_units,
                                             validate_args=True,
                                             allow_nan_stats=False)
            # NOTE(review): this is the *negative* KL — confirm the sign
            # convention against how `loss` is combined at the call site
            loss = tf.reduce_mean(-self.q_z.kl_divergence(self.p_z))
        else:
            # fix: `raise NotImplemented` raises a TypeError at runtime, since
            # NotImplemented is a constant, not an exception class
            raise NotImplementedError

        return z, loss
コード例 #13
0
    def forward(self, inputs, lengths, dist='normal', fix=True):
        """Encode packed sequences into a latent posterior/prior pair.

        :param inputs: padded input batch (batch_first)
        :param lengths: sequence lengths for packing
        :param dist: 'normal' or 'vmf' latent family
        :param fix: for 'vmf', pin the concentration to a constant 80
        :return: (rnn hidden state, posterior distribution, prior distribution)
        :raises NotImplementedError: for an unsupported `dist`
        """
        packed = pack(self.drop(inputs), lengths, batch_first=True)
        _, hn = self.rnn(packed)
        h = torch.cat(hn, dim=2).squeeze(0)

        if dist == 'normal':
            zeros = torch.zeros((h.size(0), self.code_dim), device=h.device)
            # standard-normal prior: mean 0, std exp(0.5 * 0) = 1
            p_z = Normal(zeros, (0.5 * zeros).exp())
            mu, lv = self.fcmu(h), self.fclv(h)
            if self.bn:
                mu, lv = self.bnmu(mu), self.bnlv(lv)
            q_z = Normal(mu, (0.5 * lv).exp())
            return hn, q_z, p_z

        if dist == 'vmf':
            mu = self.fcmu(h)
            mu = mu / mu.norm(dim=-1, keepdim=True)
            # softplus + 1 keeps the concentration above 1
            var = F.softplus(self.fcvar(h)) + 1
            if fix:
                var = torch.ones_like(var) * 80
            q_z = VonMisesFisher(mu, var)
            p_z = HypersphericalUniform(self.code_dim - 1, device=mu.device)
            return hn, q_z, p_z

        raise NotImplementedError
コード例 #14
0
 def _vmf_kl_divergence(self, location, kappa):
     """Get the estimated KL between the VMF function with a uniform hyperspherical prior."""
     posterior = VonMisesFisher(location, kappa)
     prior = HypersphericalUniform(self.z_dim - 1, device=self.device)
     return kl_divergence(posterior, prior)
コード例 #15
0
ファイル: model.py プロジェクト: Anonymous4sub/HCVA
    def reparameterize(self, z_mean, z_var):
        """Return the vMF posterior q(z) and the hyperspherical-uniform prior p(z)."""
        posterior = VonMisesFisher(z_mean, z_var)
        prior = HypersphericalUniform(self.z_dim - 1)
        return posterior, prior
コード例 #16
0
class ModelVAE(object):
    """VAE with either a Gaussian or a von Mises-Fisher latent (TF1 layers API)."""

    def __init__(self,
                 x,
                 h_dim,
                 z_dim,
                 activation=tf.nn.relu,
                 distribution='normal'):
        """
        ModelVAE initializer

        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :raises NotImplementedError: for an unsupported `distribution` value
        """
        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, activation, distribution

        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        elif distribution == 'vmf':
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)
        else:
            # fix: `raise NotImplemented` raises a TypeError at runtime, since
            # NotImplemented is a constant, not an exception class
            raise NotImplementedError

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)

    def _encoder(self, x):
        """
        Encoder network

        :param x: placeholder for input
        :return: tuple `(z_mean, z_var)` with mean and concentration around the mean
        :raises NotImplementedError: for an unsupported `self.distribution`
        """
        # 2 hidden layers encoder
        h0 = tf.layers.dense(x,
                             units=self.h_dim * 2,
                             activation=self.activation)
        h1 = tf.layers.dense(h0, units=self.h_dim, activation=self.activation)

        if self.distribution == 'normal':
            # compute mean and std of the normal distribution
            z_mean = tf.layers.dense(h1, units=self.z_dim, activation=None)
            z_var = tf.layers.dense(h1,
                                    units=self.z_dim,
                                    activation=tf.nn.softplus)
        elif self.distribution == 'vmf':
            # compute mean and concentration of the von Mises-Fisher
            z_mean = tf.layers.dense(
                h1,
                units=self.z_dim,
                activation=lambda x: tf.nn.l2_normalize(x, axis=-1))
            z_var = tf.layers.dense(h1, units=1, activation=tf.nn.softplus)
        else:
            # fix: NotImplementedError is the correct exception class here too
            raise NotImplementedError

        return z_mean, z_var

    def _decoder(self, z):
        """
        Decoder network

        :param z: tensor, latent representation of input (x)
        :return: logits, `reconstruction = sigmoid(logits)`
        """
        # 2 hidden layers decoder
        h2 = tf.layers.dense(z, units=self.h_dim, activation=self.activation)
        h2 = tf.layers.dense(h2,
                             units=self.h_dim * 2,
                             activation=self.activation)
        logits = tf.layers.dense(h2, units=self.x.shape[-1], activation=None)

        return logits
コード例 #17
0
class Model(pl.LightningModule):
    """Two-link image model: vMF latent angles with learned Lagrangian ODE
    dynamics, trained as a VAE (reconstruction + KL + norm penalty)."""

    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # element-wise loss; reduced manually in training_step
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # one recognition network per link (64x64 input -> 3 outputs:
        # 2-d mean + 1-d raw concentration); decoder maps the 4-d latent
        # state to a 3-channel 64x64 image
        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        self.obs_net = MLP_Decoder(4, 100, 3 * 64 * 64, nonlinearity='elu')

        # potential, mass and input-matrix networks of the Lagrangian dynamics
        V_net = MLP(4, 100, 1)
        M_net = PSD(4, 300, 2)
        g_net = MatrixNet(4, 100, 4, shape=(2, 2))

        self.ode = Lag_Net(q_dim=2,
                           u_dim=2,
                           g_net=g_net,
                           M_net=M_net,
                           V_net=V_net)

        # learnable link-1 length parameter; squashed by sigmoid in forward()
        self.link1_para = torch.nn.Parameter(
            torch.tensor(0.0, dtype=self.dtype))

        self.train_dataset = None
        # rotating index over the non-zero control datasets (1..8)
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader; with `homo_u` set, alternates between
        the zero-control dataset and rotating control datasets per epoch."""
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                # after warm-up, cycle through all 9 datasets
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path,
                                         self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference angular velocity from two unit vectors; assuming
        columns are (cos, sin) — matching the decoder comments elsewhere —
        this computes (-dcos * sin + dsin * cos) / dt."""
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode a two-channel image batch into per-link vMF parameters.

        :return: (phi1_m, phi1_v, phi1_m_n, phi2_m, phi2_v, phi2_m_n) where
            *_m is the raw 2-d mean, *_v the concentration (softplus + 1),
            and *_m_n the L2-normalized mean direction.
        """
        phi1_m_logv = self.recog_net_1(batch_image[:, 0:1].reshape(
            self.bs, self.d * self.d))
        phi1_m, phi1_logv = phi1_m_logv.split([2, 1], dim=1)
        phi1_m_n = phi1_m / phi1_m.norm(dim=-1, keepdim=True)
        # softplus + 1 keeps the concentration above 1
        phi1_v = F.softplus(phi1_logv) + 1

        phi2_m_logv = self.recog_net_2(batch_image[:, 1:2].reshape(
            self.bs, self.d * self.d))
        phi2_m, phi2_logv = phi2_m_logv.split([2, 1], dim=1)
        phi2_m_n = phi2_m / phi2_m.norm(dim=-1, keepdim=True)
        phi2_v = F.softplus(phi2_logv) + 1
        return phi1_m, phi1_v, phi1_m_n, phi2_m, phi2_v, phi2_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        """Batch of 2x3 affine matrices [[cos, sin, x], [-sin, cos, y]]."""
        # x, y should have shape (bs, )
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Inverse of get_theta: transposed rotation with the translation
        rotated back through it."""
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode frames 0 and 1, integrate the latent ODE, decode all steps.

        :param X: image sequence; assumes shape (T, bs, c, d, d) — TODO confirm
        :param u: control input concatenated onto the latent state
        :return: None (reconstructions stored in ``self.Xrec``)
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode both link angles for frames 0 and 1
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize: one vMF posterior per link, shared uniform prior
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            # resample until finite — workaround for occasional NaNs
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()

        # estimate velocities by finite differences of the normalized means
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])

        # predict: integrate the latent dynamics over t_eval
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode each predicted state directly to an RGB image
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        """ELBO-style loss: reconstruction + KL + encoder-mean norm penalty."""
        X, u = train_batch
        self.forward(X, u)

        # reconstruction log-likelihood: negative MSE summed over all dims
        # except batch, then averaged over the batch
        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_phi1, self.P_hyper_uni).mean() + \
               torch.distributions.kl.kl_divergence(self.Q_phi2, self.P_hyper_uni).mean()
        # encourages the raw (un-normalized) encoder means to stay near unit norm
        norm_penalty = (self.phi1_m_t0.norm(dim=-1).mean() - 1) ** 2 + \
                       (self.phi2_m_t0.norm(dim=-1).mean() - 1) ** 2

        loss = -lhood + kl_q + 1 / 100 * norm_penalty

        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-3, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)

        return parser
コード例 #18
0
    def forward(self, X, u):
        """Two-link variant with compositional decoding: encode both link
        angles, integrate the latent ODE, then render each link as an
        affine-warped canonical image and stack the links into RGB channels.

        :param X: image sequence; assumes shape (T, bs, c, d, d) — TODO confirm
        :param u: control input concatenated onto the latent state
        :return: None (reconstructions stored in ``self.Xrec``)
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # learnable link-1 length squashed to (0, 1)
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode both link angles for frames 0 and 1
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize: one vMF posterior per link, shared uniform prior
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            # resample until finite — workaround for occasional NaNs from rsample
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()

        # estimate velocities by finite differences of the normalized means
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])

        # predict: integrate the latent dynamics over t_eval; the state layout
        # [cos1, cos2, sin1, sin2] is consumed by get_theta_inv below
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)

        # decode: render each link's canonical image from a constant input
        ones = torch.ones_like(self.qT[:, 0:1])
        self.link1 = self.obs_net_1(ones)
        self.link2 = self.obs_net_2(ones)

        # link 1 rotates about the origin
        theta1 = self.get_theta_inv(self.qT[:, 0],
                                    self.qT[:, 2],
                                    0,
                                    0,
                                    bs=T * self.bs)  # cos phi1, sin phi1
        # link 2 is translated to the end of link 1 (scaled by link1_l)
        x = self.link1_l * self.qT[:, 2]  # l * sin phi1
        y = self.link1_l * self.qT[:, 0]  # l * cos phi 1
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 3],
                                    x,
                                    y,
                                    bs=T * self.bs)  # cos phi2, sin phi 2

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_link1 = F.grid_sample(
            self.link1.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_link2 = F.grid_sample(
            self.link2.view(T * self.bs, 1, self.d, self.d), grid2)
        # stack: link1 -> R channel, link2 -> G channel, B left empty
        self.Xrec = torch.cat(
            [transf_link1, transf_link2,
             torch.zeros_like(transf_link1)],
            dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
コード例 #19
0
ファイル: models.py プロジェクト: thu-ml/wmvl
class ExplicitAE(object):
    """Autoencoder with a Gaussian or vMF latent, using explicit TF1 variable scopes."""

    def __init__(self,
                 x,
                 h_dim,
                 z_dim,
                 activation=tf.nn.relu,
                 distribution='normal',
                 rescale_sph_latent=False):
        """
        :param x: placeholder for input
        :param h_dim: dimension of the hidden layers
        :param z_dim: dimension of the latent representation
        :param activation: callable activation function
        :param distribution: string either `normal` or `vmf`, indicates which distribution to use
        :param rescale_sph_latent: if True, scale vmf latents by sqrt(z_dim) before decoding
        :raises NotImplementedError: for an unsupported `distribution` value
        """
        self.x, self.h_dim, self.z_dim, self.activation, self.distribution = x, h_dim, z_dim, \
            activation, distribution
        self.rescale_sph_latent = rescale_sph_latent

        self.z_mean, self.z_var = self._encoder(self.x)

        if distribution == 'normal':
            self.q_z = tf.distributions.Normal(self.z_mean, self.z_var)
        elif distribution == 'vmf':
            self.q_z = VonMisesFisher(self.z_mean, self.z_var)
        else:
            # fix: `raise NotImplemented` raises a TypeError at runtime, since
            # NotImplemented is a constant, not an exception class
            raise NotImplementedError

        self.z = self.q_z.sample()

        self.logits = self._decoder(self.z)

    def _encoder(self, x):
        """
        Encoder network

        :param x: placeholder for input
        :return: tuple `(z_mean, z_var)` with mean and concentration around the mean
        :raises NotImplementedError: for an unsupported `self.distribution`
        """

        with tf.variable_scope(ENCODER, reuse=AUTO_REUSE):
            # 2 hidden layers encoder
            h0 = tf.layers.dense(x,
                                 units=self.h_dim,
                                 activation=self.activation)
            h1 = tf.layers.dense(h0,
                                 units=self.h_dim,
                                 activation=self.activation)

            if self.distribution == 'normal':
                # compute mean and std of the normal distribution
                z_mean = tf.layers.dense(h1, units=self.z_dim, activation=None)
                z_var = tf.layers.dense(h1,
                                        units=self.z_dim,
                                        activation=tf.nn.softplus)
            elif self.distribution == 'vmf':
                # compute mean and concentration of the von Mises-Fisher
                z_mean = tf.layers.dense(
                    h1,
                    units=self.z_dim,
                    activation=lambda x: tf.nn.l2_normalize(x, axis=-1))
                # the `+ 1` prevent collapsing behaviors
                z_var = tf.layers.dense(h1, units=1,
                                        activation=tf.nn.softplus) + 1
            else:
                # fix: NotImplementedError is the correct exception class here too
                raise NotImplementedError

            return z_mean, z_var

    def _decoder(self, z):
        """
        Decoder network

        :param z: tensor, latent representation of input (x)
        :return: logits, `reconstruction = sigmoid(logits)`
        """
        # 2 hidden layers decoder
        if self.distribution == 'vmf' and self.rescale_sph_latent:
            # spread unit-sphere latents out to radius sqrt(z_dim)
            z = z * tf.sqrt(tf.to_float(self.z_dim))
        with tf.variable_scope(DECODER, reuse=AUTO_REUSE):
            h2 = tf.layers.dense(z,
                                 units=self.h_dim,
                                 activation=self.activation)
            h2 = tf.layers.dense(h2,
                                 units=self.h_dim,
                                 activation=self.activation)
            logits = tf.layers.dense(h2,
                                     units=self.x.shape[-1],
                                     activation=None)

        return logits
コード例 #20
0
class Model(pl.LightningModule):
    """Lagrangian-dynamics VAE over image sequences.

    Each 32x32 frame is encoded to a point on the unit circle via a
    von Mises-Fisher posterior, angular velocity is estimated from two
    consecutive frames, the latent state is rolled out with a learned
    Lagrangian ODE, and frames are reconstructed by rotating a learned
    content image with a spatial transformer.
    """

    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # Element-wise MSE; reduction is done manually in training_step.
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # Recognition net: flat 32x32 frame -> 3 values (2 for the mean
        # direction, 1 for the raw concentration; see encode()).
        self.recog_q_net = MLP_Encoder(32 * 32, 300, 3, nonlinearity='elu')
        # Observation net: a constant scalar -> flat 32x32 content image.
        self.obs_net = MLP_Encoder(1, 100, 32 * 32, nonlinearity='elu')
        V_net = MLP(2, 50, 1)  # presumably potential energy V(q) -- see Lag_Net
        g_net = MLP(2, 50, 1)  # presumably control gain g(q) -- see Lag_Net
        M_net = PSD(2, 50, 1)  # positive semi-definite matrix net M(q)
        self.ode = Lag_Net(q_dim=1,
                           u_dim=1,
                           g_net=g_net,
                           M_net=M_net,
                           V_net=V_net)

        self.train_dataset = None
        # Rotating index (1..8) used to alternate controlled datasets.
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader.

        With ``homo_u`` the same dataset object is reused and only its
        control index ``u_idx`` changes each epoch, so the trainer must be
        run with ``reload_dataloaders_every_epoch=True``.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                # after the warm-up phase, cycle through all 9 settings
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path, self.hparams.T_pred)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference estimate of the angular velocity.

        Inputs are unit vectors q = (cos t, sin t) at two consecutive
        steps; uses t_dot = (-sin(t) * d(cos t) + cos(t) * d(sin t)) / dt.
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode flattened frames into vMF posterior parameters.

        :param batch_image: tensor (batch, 32*32) of flattened frames
        :return: (raw mean, concentration, mean normalized onto the circle)
        """
        q_m_logv = self.recog_q_net(batch_image)
        # First two outputs -> mean direction, last one -> raw concentration.
        q_m, q_logv = q_m_logv.split([2, 1], dim=1)
        q_m_n = q_m / q_m.norm(dim=-1, keepdim=True)
        # softplus + 1 keeps the vMF concentration > 1 (avoids collapse).
        q_v = F.softplus(q_logv) + 1
        return q_m, q_v, q_m_n

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Build batched 2x3 affine matrices (rotation plus translation)
        in the layout consumed by ``F.affine_grid`` in forward()."""
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode the first two frames, roll the latent ODE forward, and
        reconstruct all T frames into ``self.Xrec``.

        :param X: frame sequence, shape (T, batch, d, d)
        :param u: control input, shape (batch, 1) -- assumed; TODO confirm
        """
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(
            self.bs, d * d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(
            self.bs, d * d))

        # reparametrize: vMF posterior vs. hyperspherical-uniform prior
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample()  # bs, 2
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample()  # a bad way to avoid nan

        # estimate velocity from the two posterior mean directions
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                         self.t_eval[1] - self.t_eval[0])

        # predict: integrate (q, q_dot, u) with the learned ODE
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T * self.bs, 2)

        # decode: rotate one learned content image by each predicted q
        ones = torch.ones_like(self.qT[:, 0:1])
        self.content = self.obs_net(ones)

        theta = self.get_theta_inv(self.qT[:, 0],
                                   self.qT[:, 1],
                                   0,
                                   0,
                                   bs=T * self.bs)  # cos , sin

        grid = F.affine_grid(theta, torch.Size((T * self.bs, 1, d, d)))
        self.Xrec = F.grid_sample(self.content.view(T * self.bs, 1, d, d),
                                  grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])
        return None

    def training_step(self, train_batch, batch_idx):
        """One negative-ELBO step: reconstruction + KL + norm penalty."""
        X, u = train_batch
        self.forward(X, u)

        # Per-sequence log-likelihood surrogate (negative summed MSE).
        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
        # Keep the unnormalized encoder mean close to the unit circle.
        norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1)**2

        # Anneal the penalty weight over 8000 epochs, else use fixed 0.01.
        lambda_ = self.current_epoch / 8000 if self.hparams.annealing else 1 / 100
        loss = -lhood + kl_q + lambda_ * norm_penalty

        logs = {
            'recon_loss': -lhood,
            'kl_q_loss': kl_q,
            'train_loss': loss,
            'monitor': -lhood + kl_q
        }
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Single Adam optimizer over all module parameters."""
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-3, type=float)
        parser.add_argument('--batch_size', default=512, type=int)

        return parser
Code example #21
0
File: training.py  Project: brettbevers/miner
class VariationalAutoEncoder(object):
    """VAE with Gaussian or von Mises-Fisher latent and progressive depth.

    Supports KL annealing via a beta schedule and 'serial layering':
    training starts with a shallow encoder/decoder and deepens at fixed
    intervals, copying already-trained layers into the deeper network.

    Fix over the original: the unsupported-distribution branches raised
    ``NotImplemented`` (a singleton value, not an exception), which in
    Python 3 produces a confusing ``TypeError``; they now raise
    ``NotImplementedError``.
    """

    def __init__(self,
                 n_input_units,
                 n_hidden_layers,
                 n_hidden_units,
                 n_latent_units,
                 learning_rate=0.05,
                 batch_size=100,
                 min_beta=1.0,
                 max_beta=1.0,
                 distribution='normal',
                 serial_layering=None):
        """
        :param distribution: 'normal' or 'vmf'
        :param serial_layering: optional list/tuple of ints; group i adds
            serial_layering[i] hidden layers, and the groups must sum to
            n_hidden_layers.
        :raises TypeError: if serial_layering is not a list/tuple of ints
        :raises ValueError: if the groups do not sum to n_hidden_layers
        """
        self.n_input_units = n_input_units
        self.n_hidden_layers = n_hidden_layers
        self.n_hidden_units = n_hidden_units
        self.n_latent_units = n_latent_units
        self.learning_rate = learning_rate
        self.batch_size = int(batch_size)
        self.min_beta = min_beta
        self.max_beta = max_beta
        self.distribution = distribution
        if serial_layering:
            if not isinstance(serial_layering, (list, tuple)):
                raise TypeError(
                    "Argument 'serial_layering' must be a list or tuple of integers."
                )
            elif not all([isinstance(x, int) for x in serial_layering]):
                raise TypeError(
                    "Argument 'serial_layering' must be a list or tuple of integers."
                )
            elif sum(serial_layering) != self.n_hidden_layers:
                raise ValueError(
                    "Groupings in 'serial_layering' must sum to 'n_hidden_layers'."
                )
        self.serial_layering = serial_layering or [self.n_hidden_layers]
        # Cumulative depth schedule, e.g. [2, 2] -> [2, 4].
        self.layer_sequence = [
            sum(self.serial_layering[:i + 1])
            for i in range(len(self.serial_layering))
        ]

    class Encoder(object):
        """Encoder MLP plus the (mu, sigma) distribution heads.

        ``initializers``, when given, carries constant initializers taken
        from a previously trained, shallower encoder so that trained
        layers survive when the network is deepened.
        """

        def __init__(self,
                     n_hidden_layers,
                     n_hidden_units,
                     n_latent_units,
                     distribution,
                     initializers=None):
            self.n_hidden_layers = n_hidden_layers
            self.n_hidden_units = n_hidden_units
            self.n_latent_units = n_latent_units
            self.distribution = distribution
            self.initializers = initializers

        def init_hidden_layers(self):
            """Reset the bookkeeping lists of layer objects/outputs."""
            self.hidden_layers = []
            self.applied_hidden_layers = []

        def add_hidden_layer(self, inputs):
            """Append one sigmoid hidden layer and return its output."""
            if self.initializers and self.initializers.get('layers', None):
                print("initializing encoder layer...")
                kernel_initializer, bias_initializer = self.initializers[
                    'layers'].pop(0)
            else:
                kernel_initializer, bias_initializer = None, None

            self.hidden_layers.append(
                tf.layers.Dense(units=self.n_hidden_units,
                                activation=tf.nn.sigmoid,
                                kernel_initializer=kernel_initializer,
                                bias_initializer=bias_initializer))
            self.applied_hidden_layers.append(
                self.hidden_layers[-1].apply(inputs))
            return self.applied_hidden_layers[-1]

        def add_mu(self, inputs):
            """Build and apply the mean head for the chosen distribution."""
            if self.initializers and self.initializers.get('mu', None):
                print("initializing encoder mu...")
                kernel_initializer, bias_initializer = self.initializers['mu']
            else:
                kernel_initializer, bias_initializer = None, None

            if self.distribution == 'normal':
                self.mu = tf.layers.Dense(
                    units=self.n_latent_units,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
            elif self.distribution == 'vmf':
                # vMF mean lives on the unit hypersphere S^{n_latent_units},
                # hence the extra unit and the l2 normalization.
                self.mu = tf.layers.Dense(
                    units=self.n_latent_units + 1,
                    activation=lambda x: tf.nn.l2_normalize(x, axis=-1),
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
            else:
                raise NotImplementedError(
                    "Unsupported distribution: {}".format(self.distribution))

            self.applied_mu = self.mu.apply(inputs)
            return self.applied_mu

        def add_sigma(self, inputs):
            """Build and apply the scale/concentration head."""
            if self.initializers and self.initializers.get('sigma', None):
                print("initializing encoder sigma...")
                kernel_initializer, bias_initializer = self.initializers[
                    'sigma']
            else:
                kernel_initializer, bias_initializer = None, None

            if self.distribution == 'normal':
                self.sigma = tf.layers.Dense(
                    units=self.n_latent_units,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
                self.applied_sigma = self.sigma.apply(inputs)
            elif self.distribution == 'vmf':
                # Scalar concentration; `+ 1` keeps it away from zero to
                # prevent collapsing toward the uniform distribution.
                self.sigma = tf.layers.Dense(
                    units=1,
                    activation=tf.nn.softplus,
                    kernel_initializer=kernel_initializer,
                    bias_initializer=bias_initializer)
                self.applied_sigma = self.sigma.apply(inputs) + 1
            else:
                raise NotImplementedError(
                    "Unsupported distribution: {}".format(self.distribution))
            return self.applied_sigma

        def build(self, inputs):
            """Assemble the full encoder graph; return (mu, sigma)."""
            self.init_hidden_layers()

            layer = self.add_hidden_layer(inputs)

            for i in range(self.n_hidden_layers - 1):
                layer = self.add_hidden_layer(layer)

            mu = self.add_mu(layer)
            sigma = self.add_sigma(layer)

            return mu, sigma

        def eval(self, sess):
            """Fetch (kernel, bias) numpy values for every layer."""
            layers = [sess.run([l.kernel, l.bias]) for l in self.hidden_layers]

            mu = sess.run([self.mu.kernel, self.mu.bias])

            sigma = sess.run([self.sigma.kernel, self.sigma.bias])

            return layers, mu, sigma

    class Decoder(object):
        """Decoder MLP mapping latent samples back to the input space."""

        def __init__(self,
                     n_hidden_layers,
                     n_hidden_units,
                     n_output_units,
                     initializers=None):
            self.n_hidden_layers = n_hidden_layers
            self.n_hidden_units = n_hidden_units
            self.n_output_units = n_output_units
            self.initializers = initializers

        def init_hidden_layers(self):
            """Reset the bookkeeping lists of layer objects/outputs."""
            self.hidden_layers = []
            self.applied_hidden_layers = []

        def add_hidden_layer(self, inputs):
            """Append one sigmoid hidden layer and return its output."""
            if self.initializers and self.initializers.get('layers', None):
                print("initializing decoder layer...")
                kernel_initializer, bias_initializer = self.initializers[
                    'layers'].pop(0)
            else:
                kernel_initializer, bias_initializer = None, None

            self.hidden_layers.append(
                tf.layers.Dense(units=self.n_hidden_units,
                                activation=tf.nn.sigmoid,
                                kernel_initializer=kernel_initializer,
                                bias_initializer=bias_initializer))
            self.applied_hidden_layers.append(
                self.hidden_layers[-1].apply(inputs))
            return self.applied_hidden_layers[-1]

        def add_output(self, inputs):
            """Build and apply the linear reconstruction layer."""
            if self.initializers and self.initializers.get('output', None):
                print("initializing decoder output...")
                kernel_initializer, bias_initializer = self.initializers[
                    'output']
            else:
                kernel_initializer, bias_initializer = None, None

            self.output = tf.layers.Dense(
                units=self.n_output_units,
                kernel_initializer=kernel_initializer,
                bias_initializer=bias_initializer)
            self.applied_output = self.output.apply(inputs)
            return self.applied_output

        def build(self, inputs):
            """Assemble the full decoder graph; return the output tensor."""
            self.init_hidden_layers()

            layer = self.add_hidden_layer(inputs)

            for i in range(self.n_hidden_layers - 1):
                layer = self.add_hidden_layer(layer)

            output = self.add_output(layer)

            return output

        def eval(self, sess):
            """Fetch (kernel, bias) numpy values for every layer."""
            layers = [sess.run([l.kernel, l.bias]) for l in self.hidden_layers]

            output = sess.run([self.output.kernel, self.output.bias])

            return layers, output

    def sampled_z(self, mu, sigma, batch_size):
        """Sample a latent z and return (z, latent loss term).

        For 'normal', `sigma` holds the log-variance and the loss is the
        beta-weighted analytic KL against N(0, I).
        """
        if self.distribution == 'normal':
            # Reparameterization trick: z = mu + eps * std.
            epsilon = tf.random_normal(
                tf.stack([int(batch_size), self.n_latent_units]))
            z = mu + tf.multiply(epsilon, tf.exp(0.5 * sigma))
            loss = tf.reduce_mean(
                -0.5 * self.beta *
                tf.reduce_sum(1.0 + sigma - tf.square(mu) - tf.exp(sigma), 1))
        elif self.distribution == 'vmf':
            self.q_z = VonMisesFisher(mu,
                                      sigma,
                                      validate_args=True,
                                      allow_nan_stats=False)
            z = self.q_z.sample()
            self.p_z = HypersphericalUniform(self.n_latent_units,
                                             validate_args=True,
                                             allow_nan_stats=False)
            # NOTE(review): the sign here is opposite to the 'normal'
            # branch (+KL); confirm the kl_divergence convention of the
            # VonMisesFisher implementation before changing it.
            loss = tf.reduce_mean(-self.q_z.kl_divergence(self.p_z))
        else:
            raise NotImplementedError(
                "Unsupported distribution: {}".format(self.distribution))

        return z, loss

    def build_feature_loss(self, x, output):
        """Mean summed squared reconstruction error."""
        return tf.reduce_mean(
            tf.reduce_sum(tf.squared_difference(x, output), 1))

    def build_encoder_initializers(self, sess, n_hidden_layers):
        """Snapshot the current encoder (if any) into constant initializers
        for a deeper encoder; new layers get identity-ish defaults."""
        if hasattr(self, 'encoder'):
            result = {'layers': []}
            layers, mu, sigma = self.encoder.eval(sess)
            for i in range(n_hidden_layers):
                if layers:
                    kernel, bias = layers.pop(0)
                    result['layers'].append((tf.constant_initializer(kernel),
                                             tf.constant_initializer(bias)))
                else:
                    result['layers'].append(
                        (tf.constant_initializer(
                            np.diag(np.ones(self.n_latent_units))),
                         tf.constant_initializer(np.diag(np.ones(1)))))

            result['mu'] = (tf.constant_initializer(mu[0]),
                            tf.constant_initializer(mu[1]))
            result['sigma'] = (tf.constant_initializer(sigma[0]),
                               tf.constant_initializer(sigma[1]))
        else:
            result = None

        return result

    def build_decoder_initializers(self, sess, n_hidden_layers):
        """Decoder counterpart of build_encoder_initializers."""
        if hasattr(self, 'decoder'):
            result = {'layers': []}
            layers, output = self.decoder.eval(sess)
            for i in range(n_hidden_layers):
                if layers:
                    kernel, bias = layers.pop(0)
                    result['layers'].append((tf.constant_initializer(kernel),
                                             tf.constant_initializer(bias)))
                else:
                    result['layers'].append(
                        (tf.constant_initializer(
                            np.diag(np.ones(self.n_latent_units))),
                         tf.constant_initializer(np.diag(np.ones(1)))))

            result['output'] = (tf.constant_initializer(output[0]),
                                tf.constant_initializer(output[1]))
        else:
            result = None

        return result

    def build_initializers(self, attr_name, sess, n_hidden_layers):
        """Generic hidden-layer-only initializer builder for `attr_name`."""
        if hasattr(self, attr_name):
            layers = getattr(self, attr_name).eval(sess)[0]
            result = []
            for i in range(n_hidden_layers):
                if layers:
                    kernel, bias = layers.pop(0)
                    result.append((tf.constant_initializer(kernel),
                                   tf.constant_initializer(bias)))
                else:
                    result.append(
                        (tf.constant_initializer(
                            np.diag(np.ones(self.n_latent_units))),
                         tf.constant_initializer(np.diag(np.ones(1)))))
            return result
        else:
            return None

    def initialize_tensors(self, sess, n_hidden_layers=None):
        """(Re)build the whole graph at the given depth, seeding it from
        the previous encoder/decoder when one exists."""
        n_hidden_layers = n_hidden_layers or self.n_hidden_layers

        self.x = tf.placeholder("float32",
                                [self.batch_size, self.n_input_units])
        self.beta = tf.placeholder("float32", [1, 1])
        self.encoder = self.Encoder(
            n_hidden_layers,
            self.n_hidden_units,
            self.n_latent_units,
            self.distribution,
            initializers=self.build_encoder_initializers(
                sess, n_hidden_layers))
        mu, sigma = self.encoder.build(self.x)
        self.mu = mu
        self.sigma = sigma

        z, latent_loss = self.sampled_z(self.mu, self.sigma, self.batch_size)
        self.z = z
        self.latent_loss = latent_loss

        self.decoder = self.Decoder(
            n_hidden_layers,
            self.n_hidden_units,
            self.n_input_units,
            initializers=self.build_decoder_initializers(
                sess, n_hidden_layers))
        self.output = self.decoder.build(self.z)

        self.feature_loss = self.build_feature_loss(self.x, self.output)
        self.loss = self.feature_loss + self.latent_loss

    def total_steps(self, data_count, epochs):
        """Total optimizer steps over the run (one epoch's worth held back)."""
        num_batches = int(data_count / self.batch_size)
        return (num_batches * epochs) - epochs

    def generate_beta_values(self, total_steps):
        """Exponential annealing schedule from min_beta toward max_beta."""
        beta_delta = self.max_beta - self.min_beta
        log_beta_step = 5 / float(total_steps)
        beta_values = [
            self.min_beta + (beta_delta * (1 - math.exp(-5 +
                                                        (i * log_beta_step))))
            for i in range(total_steps)
        ]
        return beta_values

    def train_from_rdd(self, data_rdd, epochs=1):
        """Train from a Spark RDD, deepening per the layer schedule.

        :return: a VariationalAutoEncoderModel of the final weights
        """
        data_count = data_rdd.count()
        total_steps = self.total_steps(data_count, epochs)
        beta_values = self.generate_beta_values(total_steps)

        layer_sequence_step = int(total_steps / len(self.layer_sequence))
        layer_sequence = self.layer_sequence.copy()

        with tf.Session() as sess:
            batch_index = 0
            for epoch_index in range(epochs):
                iterator = data_rdd.toLocalIterator()
                while True:
                    # Deepen the network at each schedule boundary; the
                    # graph is rebuilt seeded from the trained weights.
                    if (not batch_index %
                            layer_sequence_step) and layer_sequence:
                        n_hidden_layers = layer_sequence.pop(0)
                        self.initialize_tensors(sess, n_hidden_layers)
                        optimizer = tf.train.AdamOptimizer(
                            self.learning_rate).minimize(self.loss)
                        sess.run(tf.global_variables_initializer())

                    batch = np.array(list(islice(iterator, self.batch_size)))
                    if batch.shape[0] == self.batch_size:
                        beta = beta_values.pop(
                            0) if len(beta_values) > 0 else self.min_beta
                        feed_dict = {
                            self.x: np.array(batch),
                            self.beta: np.array([[beta]])
                        }

                        if not batch_index % 1000:
                            print("beta: {}".format(beta))
                            print("number of hidden layers: {}".format(
                                n_hidden_layers))
                            ls, f_ls, d_ls = sess.run([
                                self.loss, self.feature_loss, self.latent_loss
                            ],
                                                      feed_dict=feed_dict)
                            print(
                                "loss={}, avg_feature_loss={}, avg_latent_loss={}"
                                .format(ls, np.mean(f_ls), np.mean(d_ls)))
                            print('running batch {} (epoch {})'.format(
                                batch_index, epoch_index))
                        sess.run(optimizer, feed_dict=feed_dict)
                        batch_index += 1
                    else:
                        # Trailing partial batch: skip it and end the epoch.
                        print("incomplete batch: {}".format(batch.shape))
                        break

            print("evaluating model...")
            encoder_layers, eval_mu, eval_sigma = self.encoder.eval(sess)
            decoder_layers, eval_output = self.decoder.eval(sess)

        return VariationalAutoEncoderModel(encoder_layers, eval_mu, eval_sigma,
                                           decoder_layers, eval_output)

    def train(self, data, visualize=False, epochs=1):
        """Train from an in-memory array; optionally scatter-plot progress.

        :return: a VariationalAutoEncoderModel of the final weights
        """
        data_size = data.shape[0]
        batch_size = self.batch_size
        total_steps = self.total_steps(data_size, epochs)
        beta_values = self.generate_beta_values(total_steps)

        layer_sequence_step = int(total_steps / len(self.layer_sequence))
        layer_sequence = self.layer_sequence.copy()

        with tf.Session() as sess:
            for epoch_index in range(epochs):
                i = 0
                while (i * batch_size) < data_size:
                    # Deepen the network at each schedule boundary.
                    if (not i % layer_sequence_step) and layer_sequence:
                        n_hidden_layers = layer_sequence.pop(0)
                        self.initialize_tensors(sess, n_hidden_layers)
                        optimizer = tf.train.AdamOptimizer(
                            self.learning_rate).minimize(self.loss)
                        sess.run(tf.global_variables_initializer())

                    batch = data[i * batch_size:(i + 1) * batch_size]
                    beta = beta_values.pop(
                        0) if len(beta_values) > 0 else self.min_beta
                    feed_dict = {self.x: batch, self.beta: np.array([[beta]])}
                    sess.run(optimizer, feed_dict=feed_dict)
                    if visualize and (not i % int((data_size / batch_size) / 3)
                                      or i == int(data_size / batch_size) - 1):
                        ls, d, f_ls, d_ls = sess.run([
                            self.loss, self.output, self.feature_loss,
                            self.latent_loss
                        ],
                                                     feed_dict=feed_dict)
                        plt.scatter(batch[:, 0], batch[:, 1])
                        plt.show()
                        plt.scatter(d[:, 0], d[:, 1])
                        plt.show()
                        print(i, ls, np.mean(f_ls), np.mean(d_ls))

                    i += 1

            encoder_layers, eval_mu, eval_sigma = self.encoder.eval(sess)
            decoder_layers, eval_output = self.decoder.eval(sess)

        return VariationalAutoEncoderModel(encoder_layers, eval_mu, eval_sigma,
                                           decoder_layers, eval_output)
Code example #22
0
class VAE(CAE):
    def __init__(
            self,
            vtype,
            output_low_bound,
            output_up_bound,
            # relu bounds
            nonlinear_low_bound,
            nonlinear_up_bound,
            # conv layers
            conv_filter_sizes=[3, 3],  #[[3,3], [3,3], [3,3], [3,3], [3,3]],
            conv_strides=[1, 1],  #[[1,1], [1,1], [1,1], [1,1], [1,1]],
            conv_padding="SAME",  #["SAME", "SAME", "SAME", "SAME", "SAME"],
            conv_channel_sizes=[128, 128, 128, 64, 64, 64,
                                3],  # [128, 128, 128, 128, 1]
            conv_leaky_ratio=[0.4, 0.4, 0.4, 0.2, 0.2, 0.2, 0.1],
            # deconv layers
            decv_filter_sizes=[3, 3],  #[[3,3], [3,3], [3,3], [3,3], [3,3]],
            decv_strides=[1, 1],  #[[1,1], [1,1], [1,1], [1,1], [1,1]],
            decv_padding="SAME",  #["SAME", "SAME", "SAME", "SAME", "SAME"],
            decv_channel_sizes=[3, 64, 64, 64, 128, 128,
                                128],  # [1, 128, 128, 128, 128]
            decv_leaky_ratio=[0.1, 0.2, 0.2, 0.2, 0.4, 0.4, 0.01],
            # encoder fc layers
            enfc_state_sizes=[4096],
            enfc_leaky_ratio=[0.2, 0.2],
            enfc_drop_rate=[0, 0.75],
            # bottleneck
            central_state_size=2048,
            # decoder fc layers
            defc_state_sizes=[4096],
            defc_leaky_ratio=[0.2, 0.2],
            defc_drop_rate=[0.75, 0],
            # img channel
            img_channel=None,
            # switch
            use_norm=None):
        """Variational auto-encoder on top of the CAE conv/deconv backbone.

        :param vtype: latent distribution family, "gauss" or "vmf"
            (see ``_fc_weights_biases`` / ``enfc_layers``).
        All remaining arguments are forwarded positionally to
        ``CAE.__init__`` unchanged.

        NOTE(review): the list-typed defaults are mutable default
        arguments; they appear to be used read-only here, but confirm
        that ``CAE.__init__`` does not mutate them.
        """
        self.vtype = vtype
        super().__init__(
            output_low_bound, output_up_bound, nonlinear_low_bound,
            nonlinear_up_bound, conv_filter_sizes, conv_strides, conv_padding,
            conv_channel_sizes, conv_leaky_ratio, decv_filter_sizes,
            decv_strides, decv_padding, decv_channel_sizes, decv_leaky_ratio,
            enfc_state_sizes, enfc_leaky_ratio, enfc_drop_rate,
            central_state_size, defc_state_sizes, defc_leaky_ratio,
            defc_drop_rate, img_channel, use_norm)

    @lazy_method
    def enfc_weights_biases(self):
        """Create weights and biases for the encoder fully-connected stack.

        The input width is the flattened conv output; the final layer
        feeds the variational bottleneck, hence ``sampling=True``.
        """
        flat_conv_size = (self.conv_out_shape[0] * self.conv_out_shape[1] *
                          self.conv_out_shape[2])
        layer_sizes = self.enfc_state_sizes + [self.central_state_size]
        return self._fc_weights_biases(
            "W_enfc", "b_enfc", flat_conv_size, layer_sizes, sampling=True)

    def _fc_weights_biases(self,
                           W_name,
                           b_name,
                           in_size,
                           state_sizes,
                           sampling=False):
        """Create weight/bias variables for a fully-connected stack.

        :param W_name: prefix for weight variable keys
        :param b_name: prefix for bias variable keys
        :param in_size: width of the first layer's input
        :param state_sizes: output width of each layer
        :param sampling: if True, the last layer is split into "_mu" and
            "_sigma" heads according to ``self.vtype`` ("gauss": both heads
            full width; "vmf": scalar "_sigma" concentration head).
        :return: (weights dict, biases dict, number of layers)
        :raises NotImplementedError: for an unsupported ``self.vtype``
        """
        num_layer = len(state_sizes)
        _weights = {}
        _biases = {}

        def _func(in_size, out_size, idx, postfix=""):
            # Create one (weight, bias) pair, log it to TensorBoard, and
            # return the next layer's input width.
            W_key = "{}{}{}".format(W_name, idx, postfix)
            W_shape = [in_size, out_size]
            _weights[W_key] = ne.weight_variable(W_shape, name=W_key)

            b_key = "{}{}{}".format(b_name, idx, postfix)
            b_shape = [out_size]
            _biases[b_key] = ne.bias_variable(b_shape, name=b_key)

            in_size = out_size

            # tensorboard
            tf.summary.histogram("Weight_" + W_key, _weights[W_key])
            tf.summary.histogram("Bias_" + b_key, _biases[b_key])

            return in_size

        for idx in range(num_layer - 1):
            in_size = _func(in_size, state_sizes[idx], idx)
        # Last layer
        if sampling:
            if self.vtype == "gauss":
                for postfix in ["_mu", "_sigma"]:
                    _func(in_size, state_sizes[num_layer - 1], num_layer - 1,
                          postfix)
            elif self.vtype == "vmf":
                _func(in_size, state_sizes[num_layer - 1], num_layer - 1,
                      "_mu")
                _func(in_size, 1, num_layer - 1, "_sigma")
            else:
                # Fix: was `raise NotImplemented`, which raises a TypeError
                # because NotImplemented is not an exception class.
                raise NotImplementedError(
                    "Unsupported vtype: {}".format(self.vtype))
        else:
            _func(in_size, state_sizes[num_layer - 1], num_layer - 1)

        return _weights, _biases, num_layer

    @lazy_method
    def enfc_layers(self, inputs, W_name="W_enfc", b_name="b_enfc"):
        """Apply the encoder fully-connected stack and emit (mu, sigma).

        The conv output is flattened, passed through ``num_enfc - 1``
        hidden layers, then split into distribution parameters:
        - "gauss": mu via a leaky-ReLU head; sigma = sqrt(exp(log-var))
          with the log-variance clipped to [-10, 5] for stability.
        - "vmf": mu linear then l2-normalized onto the hypersphere;
          sigma (the concentration) = softplus(...) + 200.

        :raises NotImplementedError: for an unsupported ``self.vtype``
        """
        net = tf.reshape(inputs, [
            -1, self.conv_out_shape[0] * self.conv_out_shape[1] *
            self.conv_out_shape[2]
        ])

        def _func(net, layer_id, postfix="", act_func="leaky"):
            # One fully-connected layer with optional normalization and a
            # selectable activation ("leaky", "soft", or None for linear).
            weight_name = "{}{}{}".format(W_name, layer_id, postfix)
            bias_name = "{}{}{}".format(b_name, layer_id, postfix)
            curr_weight = self.enfc_weights[weight_name]
            curr_bias = self.enfc_biases[bias_name]
            net = ne.fully_conn(net, weights=curr_weight, biases=curr_bias)
            # batch normalization
            if self.use_norm == "BATCH":
                net = ne.batch_norm(net, self.is_training, axis=1)
            elif self.use_norm == "LAYER":
                net = ne.layer_norm(net, self.is_training)
            if act_func == "leaky":
                net = ne.leaky_relu(net, self.enfc_leaky_ratio[layer_id])
            elif act_func == "soft":
                net = tf.nn.softplus(net)
            return net

        for layer_id in range(self.num_enfc - 1):
            net = _func(net, layer_id)
        # Last layer: produce the distribution parameters.
        if self.vtype == "gauss":
            net_mu = _func(net, self.num_enfc - 1, "_mu")
            # Clip the log-variance to [-10, 5] before exponentiating.
            net_log_sigma_sq = tf.minimum(
                tf.maximum(-10.0, _func(net, self.num_enfc - 1, "_sigma")),
                5.0)
            net_sigma = tf.sqrt(tf.exp(net_log_sigma_sq))
        elif self.vtype == "vmf":
            # Mean direction lives on the unit hypersphere.
            net_mu = _func(net, self.num_enfc - 1, "_mu", None)
            net_mu = tf.nn.l2_normalize(net_mu, axis=-1)
            # Concentration; the +200 offset keeps the vMF well away from
            # the uniform (near-zero-concentration) regime.
            net_sigma = _func(net, self.num_enfc - 1, "_sigma", "soft") + 200.0
        else:
            # Fix: was `raise NotImplemented`, which raises a TypeError
            # because NotImplemented is not an exception class.
            raise NotImplementedError(
                "Unsupported vtype: {}".format(self.vtype))

        net_mu = tf.identity(net_mu, name="output_mu")
        net_sigma = tf.identity(net_sigma, name="output_sigma")
        return net_mu, net_sigma

    @lazy_method
    def encoder(self, inputs):
        """Encode `inputs` into sampled latent states.

        Runs the convolutional stack and the encoding fully-connected
        layers to obtain distribution parameters, builds the variational
        posterior selected by `self.vtype` ("gauss" or "vmf"), and draws
        one sample per input.
        """
        features = self.conv_layers(inputs)
        assert features.get_shape().as_list()[1:] == self.conv_out_shape
        self.central_mu, self.central_sigma = self.enfc_layers(features)

        if self.vtype == "gauss":
            # Diagonal Gaussian posterior: mu must span the full latent size.
            assert self.central_mu.get_shape().as_list()[1:] == [
                self.central_state_size
            ]
            self.central_distribution = tf.distributions.Normal(
                self.central_mu, self.central_sigma)
        elif self.vtype == "vmf":
            # von Mises-Fisher posterior: a single scalar concentration.
            assert self.central_sigma.get_shape().as_list()[1:] == [1]
            self.central_distribution = VonMisesFisher(self.central_mu,
                                                       self.central_sigma)
        self.central_states = self.central_distribution.sample()
        return self.central_states

    @lazy_method
    def kl_distance(self):
        """Return the mean KL divergence between the encoder posterior and
        the latent prior.

        For `vtype == "gauss"` the prior is a standard normal and the KL is
        summed over latent dimensions before batch-averaging; for
        `vtype == "vmf"` the prior is the hyperspherical uniform and the KL
        is already a per-example scalar.

        Raises:
            NotImplementedError: if `self.vtype` is not a supported family.
        """
        if self.vtype == "gauss":
            self.prior = tf.distributions.Normal(
                tf.zeros(self.central_state_size),
                tf.ones(self.central_state_size))
            self.kl = self.central_distribution.kl_divergence(self.prior)
            # Sum over latent dims, then average over the batch.
            loss_kl = tf.reduce_mean(tf.reduce_sum(self.kl, axis=1))
        elif self.vtype == 'vmf':
            self.prior = HypersphericalUniform(self.central_state_size - 1,
                                               dtype=tf.float32)
            self.kl = self.central_distribution.kl_divergence(self.prior)
            loss_kl = tf.reduce_mean(self.kl)
        else:
            # BUG FIX: `raise NotImplemented` raises a TypeError, because
            # NotImplemented is a sentinel value, not an exception class.
            # Raise the proper exception type instead.
            raise NotImplementedError(
                "unsupported vtype: {}".format(self.vtype))
        return loss_kl

    @lazy_method
    def gauss_kl_distance(self):
        """Closed-form per-example KL(N(mu, sigma^2) || N(0, I)).

        Uses the analytic formula -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        over the latent dimensions (axis 1).
        """
        log_var = self.central_log_sigma_sq
        # Same expression and operand order as the textbook identity, so the
        # floating-point result is unchanged.
        per_dim = 1 + log_var - tf.square(self.central_mu) - tf.exp(log_var)
        return -0.5 * tf.reduce_sum(per_dim, 1)

    def tf_load(self, sess, path, name='deep_vcae.ckpt', spec=""):
        """Restore all 'autoencoder'-scoped global variables from a checkpoint."""
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope='autoencoder')
        restorer = tf.train.Saver(var_list=var_list)
        restorer.restore(sess, path + '/' + name + spec)

    def tf_save(self, sess, path, name='deep_vcae.ckpt', spec=""):
        """Save all 'autoencoder'-scoped global variables to a checkpoint."""
        var_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES,
                                     scope='autoencoder')
        writer = tf.train.Saver(var_list=var_list)
        writer.save(sess, path + '/' + name + spec)
コード例 #23
0
    def forward(self, X, u):
        """Encode the first two frames, integrate latent dynamics, and decode
        a reconstructed image sequence into `self.Xrec`.

        X: image sequence; unpacked below as (time, bs, c, d, d).
        u: control input; shape assumed (bs, 2) from the [3, 2, 2] split below
           -- TODO confirm against caller.
        Side effects: stores posterior/prior distributions and intermediate
        tensors on `self` (consumed by the training step elsewhere).
        Returns None.
        """
        # Note: self.d is bound twice by this unpacking (both spatial dims
        # assumed equal).
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode: cart position r and pole angle phi from frames 0 and 1
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize: Gaussian for r, von Mises-Fisher on the circle for phi
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        # vMF rsample can occasionally produce NaNs; resample until clean
        # (crude workaround, not a fix of the underlying sampler).
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity by finite differences between frames 0 and 1
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])

        # predict: integrate [r, cos(phi), sin(phi), r_dot, phi_dot, u] forward
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 7
        # qT = (r, cos(phi), sin(phi)); q_dotT = (r_dot, phi_dot); rest is u
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode: render constant cart/pole templates, then place them with
        # spatial-transformer grids built from the predicted state
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        # cart: translation by r only; pole: rotation by phi about the cart
        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        # stack cart / pole / empty channel into a 3-channel reconstruction
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None
コード例 #24
0
class Model(pl.LightningModule):
    """Lightning module for a cart-pole latent-dynamics VAE.

    Encodes image pairs into a cart position (Gaussian) and pole angle
    (von Mises-Fisher on the circle), integrates a Lagrangian ODE in the
    latent space, and decodes via spatial transformers.
    """

    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        # NOTE(review): assigning to self.hparams directly is rejected by
        # newer pytorch-lightning versions (use save_hyperparameters) --
        # confirm the pinned PL version.
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # reduction='none' so the training step can sum/mean over chosen axes
        self.loss_fn = torch.nn.MSELoss(reduction='none')

        # recognition (encoder) nets: 64x64 images -> r (2 outs: mean, logvar)
        # and phi (3 outs: 2-d mean direction, 1 concentration)
        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 2, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        # observation (decoder) nets: constant input -> 64x64 template images
        self.obs_net_1 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')
        self.obs_net_2 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')

        g_net = MatrixNet(3, 100, 4, shape=(2, 2))
        g_baseline_net = MLP(5, 400, 2)

        # latent dynamics integrated by odeint in forward()
        self.ode = Lag_Net_R1_T1(g_net=g_net,
                                 g_baseline=g_baseline_net,
                                 dyna_model='g_baseline')

        self.train_dataset = None
        # cycling index over the non-zero-control sub-datasets (see below)
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader; alternates control regimes when
        `homo_u` is set (requires reload_dataloaders_every_epoch=True)."""
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    # cycle non-zero control indices 1..8
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                # after warm-up, sweep all 9 control indices round-robin
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path,
                                         self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Finite-difference angular velocity from unit vectors (cos, sin).

        Uses the identity theta_dot = -cos_dot*sin(theta) + sin_dot*cos(theta)
        with forward differences over delta_t.
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode one frame into cart-position and pole-angle parameters.

        Channel 0 feeds the cart encoder; channel 1 is warped by an attention
        window centered on the predicted cart position before feeding the pole
        encoder. Returns (r mean, r var, phi mean, phi concentration,
        normalized phi mean direction).
        """
        r_m_logv = self.recog_net_1(batch_image[:, 0].reshape(
            self.bs, self.d * self.d))
        r_m, r_logv = r_m_logv.split([1, 1], dim=1)
        # tanh keeps the cart position in (-1, 1); epsilon keeps var positive
        r_m = torch.tanh(r_m)
        r_v = torch.exp(r_logv) + 0.0001
        # crop an attention window around the cart from the pole channel
        theta = self.get_theta(1, 0, r_m[:, 0], 0)
        grid = F.affine_grid(theta, torch.Size((self.bs, 1, self.d, self.d)))
        pole_att_win = F.grid_sample(batch_image[:, 1:2], grid)
        phi_m_logv = self.recog_net_2(
            pole_att_win.reshape(self.bs, self.d * self.d))
        phi_m, phi_logv = phi_m_logv.split([2, 1], dim=1)
        # project the 2-d mean onto the unit circle for the vMF distribution
        phi_m_n = phi_m / phi_m.norm(dim=-1, keepdim=True)
        # '+1' keeps the concentration bounded away from zero
        phi_v = F.softplus(phi_logv) + 1
        return r_m, r_v, phi_m, phi_v, phi_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        """2x3 affine matrix [[cos, sin, x], [-sin, cos, y]] for affine_grid."""
        # x, y should have shape (bs, )
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        """Inverse of get_theta: rotate by -angle and undo the translation."""
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode frames 0-1, integrate latent dynamics, decode into self.Xrec.

        X: image sequence unpacked as (time, bs, c, d, d); u: control input,
        shape (bs, 2) implied by the [3, 2, 2] split below -- TODO confirm.
        Stores distributions/tensors on self for the training step; returns None.
        """
        # self.d is bound twice (square images assumed)
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])

        # reparametrize: Gaussian for r, vMF on the circle for phi
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()

        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        # crude NaN guard: vMF rsample can occasionally produce NaNs
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()

        # estimate velocity by finite differences between frames 0 and 1
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])

        # predict: state is [r, cos(phi), sin(phi), r_dot, phi_dot, u]
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 7
        # qT = (r, cos(phi), sin(phi)); q_dotT = (r_dot, phi_dot); rest is u
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)

        # decode constant templates, then place them with spatial transformers
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)

        # cart: translate by r; pole: rotate by phi about the cart position
        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1],
                                    self.qT[:, 2],
                                    self.qT[:, 0],
                                    0,
                                    bs=T * self.bs)

        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))

        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        # cart / pole / empty channel stacked into a 3-channel reconstruction
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        """ELBO-style loss: reconstruction + KL terms + mean-norm penalty."""
        X, u = train_batch
        self.forward(X, u)

        # negative MSE summed over time/channels/pixels, averaged over batch
        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_r0, self.P_normal).mean() + \
                torch.distributions.kl.kl_divergence(self.Q_phi0, self.P_hyper_uni).mean()
        # encourage the unnormalized phi mean to stay near unit norm
        norm_penalty = (self.phi0_m.norm(dim=-1).mean() - 1)**2

        loss = -lhood + kl_q + 1 / 100 * norm_penalty

        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        """Adam over all parameters with the configured learning rate."""
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)

        return parser
コード例 #25
0
    def reparameterize(self, z_mean, z_kappa):
        """Build the variational posterior and matching prior.

        Returns a von Mises-Fisher posterior with mean direction `z_mean` and
        concentration `z_kappa`, paired with the hyperspherical uniform prior
        on the sphere of the same dimensionality.
        """
        posterior = VonMisesFisher(z_mean, z_kappa)
        prior = HypersphericalUniform(z_mean.size(1) - 1, device=DEVICE)
        return posterior, prior