def log_likelihood(self, obs, mu):
    assert mu.shape[0] == 1 or obs.shape[0] == 1  # safe broadcasting
    # two-component Gaussian mixture: the component means are mu[:, 0]
    # and mu[:, 0] + mu[:, 1], mixed with weights p and 1 - p
    pos_ll = DiagMvn.log_pdf(obs, mu[:, 0].view(-1, 1), self.l_sigma) + np.log(self.p)
    neg_ll = DiagMvn.log_pdf(obs, torch.sum(mu, dim=1).view(-1, 1), self.l_sigma) + np.log(1.0 - self.p)
    logits = torch.cat([pos_ll, neg_ll], dim=1)
    ll = log_sum_exp(logits)
    return ll
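# `log_sum_exp` is used above but not defined in this snippet. Below is a
# minimal numerically stable sketch, assuming `logits` has shape
# (batch, num_components) and a (batch, 1) result is wanted, and that torch
# is imported as in the rest of the file. In recent PyTorch,
# torch.logsumexp(logits, dim=1, keepdim=True) is equivalent.
def log_sum_exp(logits):
    # subtract the per-row max so exp() cannot overflow
    max_val, _ = torch.max(logits, dim=1, keepdim=True)
    return max_val + torch.log(torch.sum(torch.exp(logits - max_val),
                                         dim=1, keepdim=True))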
def test_prepare():
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    torch.set_grad_enabled(False)

    # load test data
    print('loading test data.')
    test_db = MvnUniModalTestset(l_sigma=cmd_args.l_sigma)

    # config
    num_particles = cmd_args.num_particles
    len_sequence = test_db.len_seq
    num_epoch = test_db.epoch

    # initial particles from prior
    mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                       sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)
    pos_mu = test_db.prior_mu
    pos_cov = test_db.prior_sigma * test_db.prior_sigma
    pos_cov = torch.diag(pos_cov.reshape(test_db.dim))

    metric = create_metric_dict(num_epoch, len_sequence)
    return test_db, metric, mvn_dist, pos_mu, pos_cov, num_epoch, len_sequence, num_particles
def get_init_dist(db, batch_size, cmd_args, mvn_dist):
    db.reset()
    data_gen = db.data_gen(batch_size=batch_size, phase='train',
                           auto_reset=False, shuffle=True)

    # consume a random-length prefix of the stream as "history"
    hist_obs = []
    assert 1 <= cmd_args.stage_len <= cmd_args.train_samples
    num_prior_ob = np.random.randint(cmd_args.train_samples - cmd_args.stage_len + 1)
    for i in range(num_prior_ob):
        ob = next(data_gen)
        hist_obs.append(ob)

    if len(hist_obs) == 0:
        # no history: start from the prior
        dist = mvn_dist
        func_log_prior = lambda new_x, old_x: db.log_prior(new_x)
    else:
        # history available: start from the exact posterior given the prefix
        pos_mu, pos_sigma = db.get_true_posterior(torch.cat(hist_obs, dim=0))
        dist = DiagMvn(pos_mu, pos_sigma)
        func_log_prior = lambda new_x, old_x: DiagMvn.log_pdf(new_x, pos_mu, pos_sigma)
    return data_gen, dist, func_log_prior
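# Hypothetical usage sketch of get_init_dist (kept commented so it does not
# run at import time): `db`, `cmd_args`, and `mvn_dist` come from the
# surrounding script; drawing particles via get_samples mirrors the eval
# loop elsewhere in this file.
# data_gen, init_dist, func_log_prior = get_init_dist(
#     db, batch_size=1, cmd_args=cmd_args, mvn_dist=mvn_dist)
# particles = init_dist.get_samples(cmd_args.num_particles)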
if __name__ == '__main__':
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)

    db = MnistDataset(cmd_args.data_folder)
    flow = build_model(cmd_args, x_dim=db.x_dim, ob_dim=db.ob_dim,
                       ll_func=db.log_likelihood)
    ob_net = KernelEmbedNet(db.ob_dim, str(db.ob_dim), cmd_args.nonlinearity,
                            trainable=True).to(DEVICE)
    if cmd_args.init_model_dump is not None:
        state = torch.load(cmd_args.init_model_dump)
        flow.load_state_dict(state)
    prior_dist = DiagMvn(mu=[cmd_args.prior_mu] * db.x_dim,
                         sigma=[cmd_args.prior_sigma] * db.x_dim)

    # roughly log-spaced evaluation checkpoints over the training stream
    test_locs = [100, 200, 300, 400, 600, 700, 800, 1000, 1300, 1600, 2000,
                 2600, 3200, 4000, 5100, 6400, 8000, 10000, 12600, 15900,
                 20000, 25200, 31700, 39900, 50200, 63100, 79500, 100000,
                 125900, 158500, 199600, 251200, 316300, 398200, 501200,
                 631000, 794400, 1000000, 1259000, 1584900, 1995300, 2511900,
                 3162300, 3981100, 5011900, 6309600, 7943300]

    if cmd_args.phase == 'train':
        if cmd_args.stage_len <= 0:
            cmd_args.stage_len = 1
        # we approximate the posterior of the entire dataset with n_stages flows,
        # so this true_bsize is equivalent to the actual batch size
        true_bsize = db.num_train / cmd_args.n_stages / cmd_args.stage_len
        # coefficient in front of the entropy term
def log_likelihood(self, obs, mu):
    assert mu.shape[0] == 1 or obs.shape[0] == 1  # safe broadcasting
    # project the latent mean through the observation matrix before scoring
    mu_bb = torch.mm(mu, self.l_bb.t())
    return DiagMvn.log_pdf(obs, mu_bb, self.l_sigma)
offline_val_list = []
for i in range(cmd_args.num_vals):
    obs_gen = db.gen_seq_obs(cmd_args.train_samples)
    offline_val = list(obs_gen)
    offline_val_list.append(offline_val)

flow = build_model(cmd_args, x_dim=db.dim, ob_dim=db.dim, ll_func=db.log_likelihood)
if cmd_args.init_model_dump is not None:
    print('loading', cmd_args.init_model_dump)
    flow.load_state_dict(torch.load(cmd_args.init_model_dump))
mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                   sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)

if cmd_args.phase == 'visualize':
    vis_flow(flow, mvn_dist, db, offline_val)
elif cmd_args.phase == 'eval_metric':
    test_db = HmmLdsTestset()
    tot_time = 0.0
    num_eval = []
    with torch.no_grad():
        for e in tqdm(range(test_db.epoch)):
            ob_seq = [torch.tensor(test_db.obs[e][t]).view(1, -1).to(DEVICE)
                      for t in range(test_db.len_seq)]
            particles = mvn_dist.get_samples(cmd_args.num_particles)
    partition_sizes={'train': cmd_args.train_samples})
val_db = TwoGaussDataset(prior_mu=cmd_args.prior_mu,
                         prior_sigma=cmd_args.prior_sigma,
                         mu_given=[-1, 2],
                         l_sigma=1.0,
                         p=0.5,
                         partition_sizes={'val': cmd_args.train_samples * cmd_args.num_vals})
flow = build_model(cmd_args, x_dim=2, ob_dim=db.dim)
if cmd_args.init_model_dump is not None:
    print('loading', cmd_args.init_model_dump)
    flow.load_state_dict(torch.load(cmd_args.init_model_dump))
mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                   sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)

if cmd_args.phase == 'train':
    optimizer = optim.Adam(flow.parameters(), lr=cmd_args.learning_rate,
                           weight_decay=cmd_args.weight_decay)
    train_global_x_loop(cmd_args,
                        lambda x: lm_train_gen(db, x),
                        prior_dist=mvn_dist,
                        flow=flow,
                        optimizer=optimizer,
                        func_ll=db.log_likelihood,
                        func_log_prior=db.log_prior,
                        eval_func=lambda f, p: eval_flow(cmd_args, f, p, val_db))
    eval_flow(cmd_args, flow, mvn_dist, val_db)
def log_posterior(self, mu, obs):
    # score mu under the exact (conjugate) posterior given obs
    pos_mu, pos_sigma = self.get_true_posterior(obs)
    return DiagMvn.log_pdf(mu, pos_mu, pos_sigma)
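# `get_true_posterior` is not shown in this snippet. For the unimodal case --
# a diagonal Gaussian prior N(prior_mu, prior_sigma^2) with a Gaussian
# likelihood of known std l_sigma and an identity observation model -- the
# conjugate update has the closed form sketched below. This is an assumption
# about the repo's implementation; the mixture and projected likelihoods
# above admit no such simple form.
def get_true_posterior_sketch(obs, prior_mu, prior_sigma, l_sigma):
    # obs: (n, d); prior_mu / prior_sigma: (1, d) tensors; l_sigma: scalar or (1, d)
    n = obs.shape[0]
    post_prec = 1.0 / prior_sigma ** 2 + n / l_sigma ** 2   # posterior precision
    post_sigma = post_prec ** -0.5                          # posterior std
    post_mu = (prior_mu / prior_sigma ** 2
               + torch.sum(obs, dim=0, keepdim=True) / l_sigma ** 2) / post_prec
    return post_mu, post_sigma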
def grad_log_likelihood(self, obs, mu):
    # sums the gradient over observations if multiple obs are given;
    # returns a tensor of the same shape as mu
    return DiagMvn.grad_mu_log_pdf(obs, mu, self.l_sigma)
def log_likelihood(self, obs, mu):
    assert mu.shape[0] == 1 or obs.shape[0] == 1  # safe broadcasting
    return DiagMvn.log_pdf(obs, mu, self.l_sigma)
def grad_log_prior(self, mu):
    return DiagMvn.grad_x_log_pdf(mu, self.prior_mu, self.prior_sigma)
def log_prior(self, mu):
    return DiagMvn.log_pdf(mu, self.prior_mu, self.prior_sigma)
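# The DiagMvn static helpers used throughout (log_pdf, grad_mu_log_pdf,
# grad_x_log_pdf) are not defined in this snippet. Below is a minimal sketch
# of what they plausibly compute for a diagonal Gaussian, assuming numpy (np)
# and torch are imported as in the rest of the file; the repo's real class
# may differ in shape and device handling.
class DiagMvnSketch:
    @staticmethod
    def log_pdf(x, mu, sigma):
        # per-row log density of N(mu, diag(sigma^2)), shape (batch, 1)
        sigma = torch.as_tensor(sigma, dtype=x.dtype)
        z = (x - mu) / sigma
        return torch.sum(-0.5 * z ** 2 - torch.log(sigma)
                         - 0.5 * np.log(2 * np.pi), dim=1, keepdim=True)

    @staticmethod
    def grad_mu_log_pdf(x, mu, sigma):
        # d log p / d mu = (x - mu) / sigma^2, summed over observations so
        # the result matches mu's shape (cf. grad_log_likelihood above)
        sigma = torch.as_tensor(sigma, dtype=x.dtype)
        return torch.sum((x - mu) / sigma ** 2, dim=0, keepdim=True)

    @staticmethod
    def grad_x_log_pdf(x, mu, sigma):
        # d log p / d x = -(x - mu) / sigma^2
        sigma = torch.as_tensor(sigma, dtype=x.dtype)
        return -(x - mu) / sigma ** 2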