def classify(dataset, labels, n_trials=10):
    """Evaluate node embeddings with logistic regression.

    Loads vMF posterior parameters for node embeddings from
    ``result/{dataset}.node.z.{mean,var}.npy``, then repeatedly samples an
    embedding matrix, trains a multinomial logistic-regression classifier on
    a 20% split, and averages the metrics over ``n_trials`` runs.

    Args:
        dataset: dataset name used to locate the saved embedding files.
        labels: one-hot label array; class ids are recovered via argmax.
        n_trials: number of sample/train/evaluate repetitions to average
            (default 10, matching the original hard-coded behavior).

    Returns:
        Tuple of (macro-F1, micro-F1, accuracy), each averaged over trials.
    """
    print('classification using lr :dataset:', dataset)
    # load node embedding posterior parameters
    y = np.argmax(labels, axis=1)
    node_z_mean = np.load("result/{}.node.z.mean.npy".format(dataset))
    node_z_var = np.load("result/{}.node.z.var.npy".format(dataset))
    q_z = VonMisesFisher(torch.tensor(node_z_mean), torch.tensor(node_z_var))

    macro_f1_avg = 0.0
    micro_f1_avg = 0.0
    acc_avg = 0.0
    for _ in range(n_trials):
        # draw a fresh embedding sample for each trial
        node_embedding = q_z.rsample().numpy()
        # fixed random_state keeps the split identical across trials, so the
        # only variation between trials is the embedding sample itself
        X_train, X_test, y_train, y_test = train_test_split(
            node_embedding, y,
            train_size=0.2, test_size=1000, random_state=2019)
        clf = LogisticRegression(solver="lbfgs",
                                 multi_class="multinomial",
                                 random_state=0).fit(X_train, y_train)
        y_pred = clf.predict(X_test)
        macro_f1_avg += f1_score(y_test, y_pred, average="macro")
        micro_f1_avg += f1_score(y_test, y_pred, average="micro")
        acc_avg += accuracy_score(y_test, y_pred, normalize=True)
    return macro_f1_avg / n_trials, micro_f1_avg / n_trials, acc_avg / n_trials
class Model(pl.LightningModule):
    """Variational latent-dynamics model for cart-pole-style image sequences.

    The first two frames are encoded into a translational coordinate ``r``
    (Gaussian latent) and a pole angle represented as a unit vector
    ``(cos, sin)`` (von Mises-Fisher latent).  Initial velocities are
    estimated by finite differences, the state is rolled forward with a
    learned Lagrangian ODE, and each predicted state is rendered back to an
    image with spatial-transformer (affine_grid / grid_sample) decoding of
    static cart/pole sprites.
    """

    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # per-pixel losses; reduction is done manually in training_step
        self.loss_fn = torch.nn.MSELoss(reduction='none')
        # recog_net_1: 64x64 frame -> (mean, log-variance) of cart coordinate
        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 2, nonlinearity='elu')
        # recog_net_2: pole-centred attention window -> (cos, sin, conc.)
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        # static sprite decoders (constant input -> 64x64 content image)
        self.obs_net_1 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')
        self.obs_net_2 = MLP_Decoder(1, 100, 64 * 64, nonlinearity='elu')
        g_net = MatrixNet(3, 100, 4, shape=(2, 2))
        g_baseline_net = MLP(5, 400, 2)
        self.ode = Lag_Net_R1_T1(g_net=g_net,
                                 g_baseline=g_baseline_net,
                                 dyna_model='g_baseline')
        self.train_dataset = None
        # rotates through non-zero control indices 1..8 (see train_dataloader)
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader; also sets ``self.t_eval``.

        NOTE(review): ``forward`` depends on ``self.t_eval`` being set here,
        so this hook must run before any training step.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                # after epoch 1000, cycle through all 9 control settings
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path, self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Estimate angular velocity from two consecutive unit vectors.

        Uses the chain rule (d cos = -sin * dtheta, d sin = cos * dtheta):
        theta_dot ~= (-dcos * sin0 + dsin * cos0) / dt.
        Inputs are (bs, 2) rows of (cos, sin); returns shape (bs, 1).
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode one frame into cart-position and pole-angle posteriors.

        Assumes ``batch_image`` is (bs, c, d, d) with channel 0 = cart and
        channel 1 = pole — TODO confirm against the dataset.
        Returns (r_m, r_v, phi_m, phi_v, phi_m_n).
        """
        # cart coordinate: Gaussian posterior from channel 0
        r_m_logv = self.recog_net_1(batch_image[:, 0].reshape(
            self.bs, self.d * self.d))
        r_m, r_logv = r_m_logv.split([1, 1], dim=1)
        r_m = torch.tanh(r_m)  # keep the cart coordinate in (-1, 1)
        r_v = torch.exp(r_logv) + 0.0001  # floor for numerical stability
        # shift an attention window to the cart position and crop the pole
        theta = self.get_theta(1, 0, r_m[:, 0], 0)
        grid = F.affine_grid(theta, torch.Size((self.bs, 1, self.d, self.d)))
        pole_att_win = F.grid_sample(batch_image[:, 1:2], grid)
        # pole angle: vMF posterior from the attended window
        phi_m_logv = self.recog_net_2(
            pole_att_win.reshape(self.bs, self.d * self.d))
        phi_m, phi_logv = phi_m_logv.split([2, 1], dim=1)
        phi_m_n = phi_m / phi_m.norm(dim=-1, keepdim=True)  # unit direction
        phi_v = F.softplus(phi_logv) + 1  # concentration >= 1
        return r_m, r_v, phi_m, phi_v, phi_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        # x, y should have shape (bs, )
        # Builds the (bs, 2, 3) affine matrices [[cos, sin, x], [-sin, cos, y]]
        # consumed by F.affine_grid.
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        # Inverse of get_theta: rotation transposed, translation rotated back.
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode, roll out the latent ODE, and reconstruct frames.

        ``X`` is (T, bs, c, d, d); ``u`` is the control input concatenated
        into the ODE state (width 2, per the split below).
        Posteriors, samples, and ``self.Xrec`` are stored on ``self`` for
        use in training_step.  Returns None.
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames
        self.r0_m, self.r0_v, self.phi0_m, self.phi0_v, self.phi0_m_n = self.encode(
            X[0])
        self.r1_m, self.r1_v, self.phi1_m, self.phi1_v, self.phi1_m_n = self.encode(
            X[1])
        # reparametrize
        # NOTE(review): torch Normal takes a scale (std) as 2nd argument,
        # but r0_v is built as exp(logv) — verify whether this is meant to
        # be a variance or a std.
        self.Q_r0 = Normal(self.r0_m, self.r0_v)
        self.P_normal = Normal(torch.zeros_like(self.r0_m),
                               torch.ones_like(self.r0_v))
        self.r0 = self.Q_r0.rsample()
        self.Q_phi0 = VonMisesFisher(self.phi0_m_n, self.phi0_v)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        self.phi0 = self.Q_phi0.rsample()
        # vMF sampling occasionally yields NaN; resample until clean
        while torch.isnan(self.phi0).any():
            self.phi0 = self.Q_phi0.rsample()
        # estimate initial velocities by finite differences over one step
        self.r_dot0 = (self.r1_m - self.r0_m) / (self.t_eval[1] -
                                                 self.t_eval[0])
        self.phi_dot0 = self.angle_vel_est(self.phi0_m_n, self.phi1_m_n,
                                           self.t_eval[1] - self.t_eval[0])
        # predict: state = [r (1), cos/sin (2), r_dot (1), phi_dot (1), u (2)]
        z0_u = torch.cat([self.r0, self.phi0, self.r_dot0, self.phi_dot0, u],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([3, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 3)
        # decode: render static sprites, then warp them into place
        ones = torch.ones_like(self.qT[:, 0:1])
        self.cart = self.obs_net_1(ones)
        self.pole = self.obs_net_2(ones)
        # cart: translation only; pole: rotation + translation
        theta1 = self.get_theta_inv(1, 0, self.qT[:, 0], 0, bs=T * self.bs)
        theta2 = self.get_theta_inv(self.qT[:, 1], self.qT[:, 2],
                                    self.qT[:, 0], 0, bs=T * self.bs)
        grid1 = F.affine_grid(theta1,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        grid2 = F.affine_grid(theta2,
                              torch.Size((T * self.bs, 1, self.d, self.d)))
        transf_cart = F.grid_sample(
            self.cart.view(T * self.bs, 1, self.d, self.d), grid1)
        transf_pole = F.grid_sample(
            self.pole.view(T * self.bs, 1, self.d, self.d), grid2)
        # channels: cart, pole, empty third channel
        self.Xrec = torch.cat(
            [transf_cart, transf_pole,
             torch.zeros_like(transf_cart)], dim=1)
        self.Xrec = self.Xrec.view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        """ELBO-style loss: reconstruction + KL + norm penalty on phi mean."""
        X, u = train_batch
        self.forward(X, u)
        # negative MSE summed over time/channels/pixels, averaged over batch
        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_r0, self.P_normal).mean() + \
            torch.distributions.kl.kl_divergence(self.Q_phi0, self.P_hyper_uni).mean()
        # encourage the un-normalized vMF mean to stay near unit norm
        norm_penalty = (self.phi0_m.norm(dim=-1).mean() - 1)**2
        loss = -lhood + kl_q + 1 / 100 * norm_penalty
        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-4, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)
        return parser
class Model(pl.LightningModule):
    """Variational latent-dynamics model for single-pendulum image sequences.

    The pendulum angle is encoded as a unit vector (cos, sin) with a von
    Mises-Fisher posterior, the initial angular velocity is estimated by
    finite differences from the first two frames, the state is rolled
    forward with a learned Lagrangian ODE, and frames are reconstructed by
    rotating a decoded static content image with a spatial transformer.
    """

    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # per-pixel losses; reduction is done manually in training_step
        self.loss_fn = torch.nn.MSELoss(reduction='none')
        # 32x32 frame -> (cos, sin, concentration-logit)
        self.recog_q_net = MLP_Encoder(32 * 32, 300, 3, nonlinearity='elu')
        # constant input -> static 32x32 content image
        self.obs_net = MLP_Encoder(1, 100, 32 * 32, nonlinearity='elu')
        V_net = MLP(2, 50, 1)
        g_net = MLP(2, 50, 1)
        M_net = PSD(2, 50, 1)
        self.ode = Lag_Net(q_dim=1, u_dim=1, g_net=g_net, M_net=M_net,
                           V_net=V_net)
        self.train_dataset = None
        # rotates through non-zero control indices 1..8 (see train_dataloader)
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader; also sets ``self.t_eval``.

        NOTE(review): ``forward`` depends on ``self.t_eval`` being set here,
        so this hook must run before any training step.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                # after epoch 1000, cycle through all 9 control settings
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path, self.hparams.T_pred)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Estimate angular velocity from two consecutive unit vectors.

        Uses the chain rule (d cos = -sin * dtheta, d sin = cos * dtheta):
        theta_dot ~= (-dcos * sin0 + dsin * cos0) / dt.
        Inputs are (bs, 2) rows of (cos, sin); returns shape (bs, 1).
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode a flattened frame into a vMF posterior over the angle.

        Returns (q_m, q_v, q_m_n): raw mean, concentration, and normalized
        (unit) mean direction.
        """
        q_m_logv = self.recog_q_net(batch_image)
        q_m, q_logv = q_m_logv.split([2, 1], dim=1)
        q_m_n = q_m / q_m.norm(dim=-1, keepdim=True)  # unit direction
        q_v = F.softplus(q_logv) + 1  # concentration >= 1
        return q_m, q_v, q_m_n

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        # Inverse affine matrix (bs, 2, 3) for F.affine_grid: rotation
        # transposed, translation rotated back.
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode, roll out the latent ODE, and reconstruct frames.

        ``X`` is (T, bs, d, d); ``u`` is the control input concatenated into
        the ODE state (width 1, per the split below).  Posteriors, samples,
        and ``self.Xrec`` are stored on ``self`` for use in training_step.
        Returns None.
        """
        [_, self.bs, d, d] = X.shape
        T = len(self.t_eval)
        # encode the first two frames
        self.q0_m, self.q0_v, self.q0_m_n = self.encode(X[0].reshape(
            self.bs, d * d))
        self.q1_m, self.q1_v, self.q1_m_n = self.encode(X[1].reshape(
            self.bs, d * d))
        # reparametrize
        self.Q_q = VonMisesFisher(self.q0_m_n, self.q0_v)
        self.P_q = HypersphericalUniform(1, device=self.device)
        self.q0 = self.Q_q.rsample()  # bs, 2
        while torch.isnan(self.q0).any():
            self.q0 = self.Q_q.rsample()  # a bad way to avoid nan
        # estimate initial angular velocity by finite differences
        self.q_dot0 = self.angle_vel_est(self.q0_m_n, self.q1_m_n,
                                         self.t_eval[1] - self.t_eval[0])
        # predict: state = [cos/sin (2), q_dot (1), u (1)]
        z0_u = torch.cat((self.q0, self.q_dot0, u), dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([2, 1, 1], dim=-1)
        self.qT = self.qT.view(T * self.bs, 2)
        # decode: render static content, then rotate it into place
        ones = torch.ones_like(self.qT[:, 0:1])
        self.content = self.obs_net(ones)
        theta = self.get_theta_inv(self.qT[:, 0], self.qT[:, 1], 0, 0,
                                   bs=T * self.bs)  # cos , sin
        grid = F.affine_grid(theta, torch.Size((T * self.bs, 1, d, d)))
        self.Xrec = F.grid_sample(self.content.view(T * self.bs, 1, d, d),
                                  grid)
        self.Xrec = self.Xrec.view([T, self.bs, d, d])
        return None

    def training_step(self, train_batch, batch_idx):
        """ELBO-style loss: reconstruction + KL + annealed norm penalty."""
        X, u = train_batch
        self.forward(X, u)
        # negative MSE summed over time/pixels, averaged over batch
        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_q, self.P_q).mean()
        # encourage the un-normalized vMF mean to stay near unit norm
        norm_penalty = (self.q0_m.norm(dim=-1).mean() - 1)**2
        # penalty weight either anneals linearly over 8000 epochs or is fixed
        lambda_ = self.current_epoch / 8000 if self.hparams.annealing else 1 / 100
        loss = -lhood + kl_q + lambda_ * norm_penalty
        logs = {
            'recon_loss': -lhood,
            'kl_q_loss': kl_q,
            'train_loss': loss,
            'monitor': -lhood + kl_q
        }
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-3, type=float)
        parser.add_argument('--batch_size', default=512, type=int)
        return parser
class Model(pl.LightningModule):
    """Variational latent-dynamics model for two-link (acrobot-style) images.

    Each link's angle is encoded from its own image channel as a unit vector
    (cos, sin) with a von Mises-Fisher posterior, initial angular velocities
    are estimated by finite differences, the 2-DOF state is rolled forward
    with a learned Lagrangian ODE, and frames are reconstructed directly by
    an MLP decoder on the predicted coordinates.
    """

    def __init__(self, hparams, data_path=None):
        super(Model, self).__init__()
        self.hparams = hparams
        self.data_path = data_path
        self.T_pred = self.hparams.T_pred
        # per-pixel losses; reduction is done manually in training_step
        self.loss_fn = torch.nn.MSELoss(reduction='none')
        # one encoder per link channel: 64x64 -> (cos, sin, concentration)
        self.recog_net_1 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        self.recog_net_2 = MLP_Encoder(64 * 64, 300, 3, nonlinearity='elu')
        # decoder maps the 4-dim coordinate directly to a 3-channel image
        self.obs_net = MLP_Decoder(4, 100, 3 * 64 * 64, nonlinearity='elu')
        V_net = MLP(4, 100, 1)
        M_net = PSD(4, 300, 2)
        g_net = MatrixNet(4, 100, 4, shape=(2, 2))
        self.ode = Lag_Net(q_dim=2, u_dim=2, g_net=g_net, M_net=M_net,
                           V_net=V_net)
        # learnable scalar squashed to (0, 1) in forward; see NOTE there
        self.link1_para = torch.nn.Parameter(
            torch.tensor(0.0, dtype=self.dtype))
        self.train_dataset = None
        # rotates through non-zero control indices 1..8 (see train_dataloader)
        self.non_ctrl_ind = 1

    def train_dataloader(self):
        """Build the training DataLoader; also sets ``self.t_eval``.

        NOTE(review): ``forward`` depends on ``self.t_eval`` being set here,
        so this hook must run before any training step.
        """
        if self.hparams.homo_u:
            # must set trainer flag reload_dataloaders_every_epoch=True
            if self.train_dataset is None:
                self.train_dataset = HomoImageDataset(self.data_path,
                                                      self.hparams.T_pred)
            if self.current_epoch < 1000:
                # feed zero ctrl dataset and ctrl dataset in turns
                if self.current_epoch % 2 == 0:
                    u_idx = 0
                else:
                    u_idx = self.non_ctrl_ind
                    self.non_ctrl_ind += 1
                    if self.non_ctrl_ind == 9:
                        self.non_ctrl_ind = 1
            else:
                # after epoch 1000, cycle through all 9 control settings
                u_idx = self.current_epoch % 9
            self.train_dataset.u_idx = u_idx
            self.t_eval = torch.from_numpy(self.train_dataset.t_eval)
            return DataLoader(self.train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)
        else:
            train_dataset = ImageDataset(self.data_path, self.hparams.T_pred,
                                         ctrl=True)
            self.t_eval = torch.from_numpy(train_dataset.t_eval)
            return DataLoader(train_dataset,
                              batch_size=self.hparams.batch_size,
                              shuffle=True,
                              collate_fn=my_collate)

    def angle_vel_est(self, q0_m_n, q1_m_n, delta_t):
        """Estimate angular velocity from two consecutive unit vectors.

        Uses the chain rule (d cos = -sin * dtheta, d sin = cos * dtheta):
        theta_dot ~= (-dcos * sin0 + dsin * cos0) / dt.
        Inputs are (bs, 2) rows of (cos, sin); returns shape (bs, 1).
        """
        delta_cos = q1_m_n[:, 0:1] - q0_m_n[:, 0:1]
        delta_sin = q1_m_n[:, 1:2] - q0_m_n[:, 1:2]
        q_dot0 = -delta_cos * q0_m_n[:, 1:
                                     2] / delta_t + delta_sin * q0_m_n[:, 0:
                                                                       1] / delta_t
        return q_dot0

    def encode(self, batch_image):
        """Encode one frame into per-link vMF angle posteriors.

        Assumes channel 0 holds link 1 and channel 1 holds link 2 — TODO
        confirm against the dataset.  Returns
        (phi1_m, phi1_v, phi1_m_n, phi2_m, phi2_v, phi2_m_n).
        """
        # link 1 from channel 0
        phi1_m_logv = self.recog_net_1(batch_image[:, 0:1].reshape(
            self.bs, self.d * self.d))
        phi1_m, phi1_logv = phi1_m_logv.split([2, 1], dim=1)
        phi1_m_n = phi1_m / phi1_m.norm(dim=-1, keepdim=True)  # unit dir
        phi1_v = F.softplus(phi1_logv) + 1  # concentration >= 1
        # link 2 from channel 1
        phi2_m_logv = self.recog_net_2(batch_image[:, 1:2].reshape(
            self.bs, self.d * self.d))
        phi2_m, phi2_logv = phi2_m_logv.split([2, 1], dim=1)
        phi2_m_n = phi2_m / phi2_m.norm(dim=-1, keepdim=True)
        phi2_v = F.softplus(phi2_logv) + 1
        return phi1_m, phi1_v, phi1_m_n, phi2_m, phi2_v, phi2_m_n

    def get_theta(self, cos, sin, x, y, bs=None):
        # x, y should have shape (bs, )
        # Builds the (bs, 2, 3) affine matrices [[cos, sin, x], [-sin, cos, y]]
        # consumed by F.affine_grid.
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += sin
        theta[:, 0, 2] += x
        theta[:, 1, 0] += -sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += y
        return theta

    def get_theta_inv(self, cos, sin, x, y, bs=None):
        # Inverse of get_theta: rotation transposed, translation rotated back.
        bs = self.bs if bs is None else bs
        theta = torch.zeros([bs, 2, 3], dtype=self.dtype, device=self.device)
        theta[:, 0, 0] += cos
        theta[:, 0, 1] += -sin
        theta[:, 0, 2] += -x * cos + y * sin
        theta[:, 1, 0] += sin
        theta[:, 1, 1] += cos
        theta[:, 1, 2] += -x * sin - y * cos
        return theta

    def forward(self, X, u):
        """Encode, roll out the latent ODE, and reconstruct frames.

        ``X`` is (T, bs, c, d, d); ``u`` is the control input concatenated
        into the ODE state (width 2, per the split below).  Posteriors,
        samples, and ``self.Xrec`` are stored on ``self`` for use in
        training_step.  Returns None.
        """
        [_, self.bs, c, self.d, self.d] = X.shape
        T = len(self.t_eval)
        # NOTE(review): link1_l is computed every forward pass but is not
        # used anywhere in this visible code — possibly consumed by the ODE
        # elsewhere or dead; confirm.
        self.link1_l = torch.sigmoid(self.link1_para)
        # encode the first two frames
        self.phi1_m_t0, self.phi1_v_t0, self.phi1_m_n_t0, self.phi2_m_t0, self.phi2_v_t0, self.phi2_m_n_t0 = self.encode(
            X[0])
        self.phi1_m_t1, self.phi1_v_t1, self.phi1_m_n_t1, self.phi2_m_t1, self.phi2_v_t1, self.phi2_m_n_t1 = self.encode(
            X[1])
        # reparametrize
        self.Q_phi1 = VonMisesFisher(self.phi1_m_n_t0, self.phi1_v_t0)
        self.Q_phi2 = VonMisesFisher(self.phi2_m_n_t0, self.phi2_v_t0)
        self.P_hyper_uni = HypersphericalUniform(1, device=self.device)
        # vMF sampling occasionally yields NaN; resample until clean
        self.phi1_t0 = self.Q_phi1.rsample()
        while torch.isnan(self.phi1_t0).any():
            self.phi1_t0 = self.Q_phi1.rsample()
        self.phi2_t0 = self.Q_phi2.rsample()
        while torch.isnan(self.phi2_t0).any():
            self.phi2_t0 = self.Q_phi2.rsample()
        # estimate initial angular velocities by finite differences
        self.phi1_dot_t0 = self.angle_vel_est(self.phi1_m_n_t0,
                                              self.phi1_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        self.phi2_dot_t0 = self.angle_vel_est(self.phi2_m_n_t0,
                                              self.phi2_m_n_t1,
                                              self.t_eval[1] - self.t_eval[0])
        # predict: state = [cos1, cos2, sin1, sin2, dot1, dot2, u (2)]
        z0_u = torch.cat([
            self.phi1_t0[:, 0:1], self.phi2_t0[:, 0:1], self.phi1_t0[:, 1:2],
            self.phi2_t0[:, 1:2], self.phi1_dot_t0, self.phi2_dot_t0, u
        ],
                         dim=1)
        zT_u = odeint(self.ode, z0_u, self.t_eval,
                      method=self.hparams.solver)  # T, bs, 4
        self.qT, self.q_dotT, _ = zT_u.split([4, 2, 2], dim=-1)
        self.qT = self.qT.view(T * self.bs, 4)
        # decode coordinates straight to 3-channel images
        self.Xrec = self.obs_net(self.qT).view(T, self.bs, 3, self.d, self.d)
        return None

    def training_step(self, train_batch, batch_idx):
        """ELBO-style loss: reconstruction + per-link KL + norm penalties."""
        X, u = train_batch
        self.forward(X, u)
        # negative MSE summed over time/channels/pixels, averaged over batch
        lhood = -self.loss_fn(self.Xrec, X)
        lhood = lhood.sum([0, 2, 3, 4]).mean()
        kl_q = torch.distributions.kl.kl_divergence(self.Q_phi1, self.P_hyper_uni).mean() + \
            torch.distributions.kl.kl_divergence(self.Q_phi2, self.P_hyper_uni).mean()
        # encourage both un-normalized vMF means to stay near unit norm
        norm_penalty = (self.phi1_m_t0.norm(dim=-1).mean() - 1) ** 2 + \
            (self.phi2_m_t0.norm(dim=-1).mean() - 1) ** 2
        loss = -lhood + kl_q + 1 / 100 * norm_penalty
        logs = {'recon_loss': -lhood, 'kl_q_loss': kl_q, 'train_loss': loss}
        return {'loss': loss, 'log': logs, 'progress_bar': logs}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), self.hparams.learning_rate)

    @staticmethod
    def add_model_specific_args(parent_parser):
        """
        Specify the hyperparams for this LightningModule
        """
        # MODEL specific
        parser = ArgumentParser(parents=[parent_parser], add_help=False)
        parser.add_argument('--learning_rate', default=1e-3, type=float)
        parser.add_argument('--batch_size', default=1024, type=int)
        return parser