Example #1
    def log_likelihood(self, obs, mu):
        assert mu.shape[0] == 1 or obs.shape[0] == 1  # safe broadcasting

        pos_ll = DiagMvn.log_pdf(obs, mu[:, 0].view(-1, 1), self.l_sigma) + np.log(self.p)
        neg_ll = DiagMvn.log_pdf(obs, torch.sum(mu, dim=1).view(-1, 1), self.l_sigma) + np.log(1.0 - self.p)
        logits = torch.cat([pos_ll, neg_ll], dim=1)
        ll = log_sum_exp(logits)
        return ll
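
Example #1 combines the two mixture components with a numerically stable log-sum-exp. A minimal sketch of the log_sum_exp helper it assumes, for logits of shape (N, 2) reduced over dim=1 (in recent PyTorch, torch.logsumexp(logits, dim=1, keepdim=True) computes the same thing):

import torch

def log_sum_exp(logits, dim=1, keepdim=True):
    # Subtract the per-row max before exponentiating so exp() cannot overflow.
    m, _ = torch.max(logits, dim=dim, keepdim=True)
    out = m + torch.log(torch.exp(logits - m).sum(dim=dim, keepdim=True))
    return out if keepdim else out.squeeze(dim)
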
Example #2
def test_prepare():
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)
    torch.set_grad_enabled(False)
    # load test data
    print('loading test data.')
    test_db = MvnUniModalTestset(l_sigma=cmd_args.l_sigma)

    # config
    num_particles = cmd_args.num_particles
    len_sequence = test_db.len_seq
    num_epoch = test_db.epoch

    # initial particles from prior
    mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                       sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)

    pos_mu = test_db.prior_mu
    pos_cov = test_db.prior_sigma * test_db.prior_sigma
    pos_cov = torch.diag(pos_cov.reshape(test_db.dim))

    metric = create_metric_dict(num_epoch, len_sequence)

    return test_db, metric, mvn_dist, pos_mu, pos_cov, num_epoch, len_sequence, num_particles
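
Note how test_prepare expands the diagonal parameterization into a full covariance: the elementwise variance prior_sigma**2 is reshaped into a vector and placed on the diagonal of a square matrix. A toy illustration with hypothetical values:

import torch

sigma = torch.full((1, 3), 2.0)               # per-dimension std-devs, shape (1, dim)
cov = torch.diag((sigma * sigma).reshape(3))  # (3, 3) matrix with 4.0 on the diagonal
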
Example #3
def get_init_dist(db, batch_size, cmd_args, mvn_dist):
    db.reset()
    data_gen = db.data_gen(batch_size=batch_size,
                           phase='train',
                           auto_reset=False,
                           shuffle=True)
    hist_obs = []
    assert 1 <= cmd_args.stage_len <= cmd_args.train_samples
    num_prior_ob = np.random.randint(cmd_args.train_samples - cmd_args.stage_len + 1)
    for i in range(num_prior_ob):
        ob = next(data_gen)
        hist_obs.append(ob)
    if len(hist_obs) == 0:
        dist = mvn_dist
        func_log_prior = lambda new_x, old_x: db.log_prior(new_x)
    else:
        pos_mu, pos_sigma = db.get_true_posterior(torch.cat(hist_obs, dim=0))
        dist = DiagMvn(pos_mu, pos_sigma)
        func_log_prior = lambda new_x, old_x: DiagMvn.log_pdf(new_x, pos_mu, pos_sigma)
    return data_gen, dist, func_log_prior
if __name__ == '__main__':
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)

    db = MnistDataset(cmd_args.data_folder)

    flow = build_model(cmd_args, x_dim=db.x_dim, ob_dim=db.ob_dim, ll_func=db.log_likelihood)
    ob_net = KernelEmbedNet(db.ob_dim, str(db.ob_dim), cmd_args.nonlinearity, trainable=True).to(DEVICE)
    
    if cmd_args.init_model_dump is not None:
        state = torch.load(cmd_args.init_model_dump)
        flow.load_state_dict(state)

    prior_dist = DiagMvn(mu=[cmd_args.prior_mu] * db.x_dim,
                         sigma=[cmd_args.prior_sigma] * db.x_dim)

    test_locs = [100, 200, 300, 400, 600, 700, 800, 1000, 1300,
                 1600, 2000, 2600, 3200, 4000, 5100, 6400, 8000, 10000, 12600, 
                 15900, 20000, 25200, 31700, 39900, 50200, 63100, 79500, 100000,
                 125900, 158500, 199600, 251200, 316300, 398200, 501200, 631000,
                 794400, 1000000, 1259000, 1584900, 1995300, 2511900, 
                 3162300, 3981100, 5011900, 6309600, 7943300]

    if cmd_args.phase == 'train':
        if cmd_args.stage_len <= 0:
            cmd_args.stage_len = 1
        # we need to approximate the posterior of the entire dataset using n_stages flows,
        # so it is equivalent to use this true_bsize as the actual batch size
        true_bsize = db.num_train / cmd_args.n_stages / cmd_args.stage_len
        # coefficient in front of entropy term
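
To make the true_bsize arithmetic concrete with hypothetical settings: if db.num_train = 60000, n_stages = 10, and stage_len = 6, the dataset is covered by 10 flows taking 6 minibatches each, so each update effectively accounts for 60000 / 10 / 6 = 1000 training samples. Note that in Python 3 the / operator yields a float here, which is fine for scaling a loss coefficient but would need // if an integer batch size were required downstream.
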
Example #5
    def log_likelihood(self, obs, mu):
        assert mu.shape[0] == 1 or obs.shape[0] == 1  # safe broadcasting
        mu_bb = torch.mm(mu, self.l_bb.t())
        return DiagMvn.log_pdf(obs, mu_bb, self.l_sigma)
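
In Example #5, mu holds latent particles and l_bb a fixed observation matrix, so torch.mm(mu, self.l_bb.t()) maps every particle into observation space before the diagonal-Gaussian density is evaluated; the assert then lets one observation broadcast against many particles (or one particle against many observations). A shape sketch with hypothetical dimensions:

import torch

mu = torch.randn(128, 5)        # 128 particles in a 5-d latent space
l_bb = torch.randn(3, 5)        # observation matrix mapping latents to 3-d observations
obs = torch.randn(1, 3)         # a single observation, broadcast against all particles
mu_bb = torch.mm(mu, l_bb.t())  # (128, 3) predicted observation means
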
Example #6
    offline_val_list = []
    for i in range(cmd_args.num_vals):
        obs_gen = db.gen_seq_obs(cmd_args.train_samples)
        offline_val = [ob for ob in obs_gen]
        offline_val_list.append(offline_val)

    flow = build_model(cmd_args,
                       x_dim=db.dim,
                       ob_dim=db.dim,
                       ll_func=db.log_likelihood)

    if cmd_args.init_model_dump is not None:
        print('loading', cmd_args.init_model_dump)
        flow.load_state_dict(torch.load(cmd_args.init_model_dump))

    mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                       sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)

    if cmd_args.phase == 'visualize':
        vis_flow(flow, mvn_dist, db, offline_val)
    elif cmd_args.phase == 'eval_metric':
        test_db = HmmLdsTestset()
        tot_time = 0.0
        num_eval = []
        with torch.no_grad():
            for e in tqdm(range(test_db.epoch)):
                ob_seq = [
                    torch.tensor(test_db.obs[e][t]).view(1, -1).to(DEVICE)
                    for t in range(test_db.len_seq)
                ]

                particles = mvn_dist.get_samples(cmd_args.num_particles)
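
The snippet breaks off after drawing the initial particle set from the prior. DiagMvn.get_samples presumably uses the standard reparameterization for a diagonal Gaussian; a minimal sketch under that assumption (not the repo's actual implementation):

import torch

def get_samples(mu, sigma, n):
    # Draw n samples from N(mu, diag(sigma^2)) as mu + sigma * eps, eps ~ N(0, I).
    return mu + sigma * torch.randn(n, mu.shape[-1])
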
Example #7
                         partition_sizes={'train': cmd_args.train_samples})
    val_db = TwoGaussDataset(
        prior_mu=cmd_args.prior_mu,
        prior_sigma=cmd_args.prior_sigma,
        mu_given=[-1, 2],
        l_sigma=1.0,
        p=0.5,
        partition_sizes={'val': cmd_args.train_samples * cmd_args.num_vals})

    flow = build_model(cmd_args, x_dim=2, ob_dim=db.dim)

    if cmd_args.init_model_dump is not None:
        print('loading', cmd_args.init_model_dump)
        flow.load_state_dict(torch.load(cmd_args.init_model_dump))

    mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                       sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)

    if cmd_args.phase == 'train':
        optimizer = optim.Adam(flow.parameters(),
                               lr=cmd_args.learning_rate,
                               weight_decay=cmd_args.weight_decay)
        train_global_x_loop(
            cmd_args,
            lambda x: lm_train_gen(db, x),
            prior_dist=mvn_dist,
            flow=flow,
            optimizer=optimizer,
            func_ll=db.log_likelihood,
            func_log_prior=db.log_prior,
            eval_func=lambda f, p: eval_flow(cmd_args, f, p, val_db))
    eval_flow(cmd_args, flow, mvn_dist, val_db)
Example #8
    def log_posterior(self, mu, obs):
        pos_mu, pos_sigma = self.get_true_posterior(obs)
        return DiagMvn.log_pdf(mu, pos_mu, pos_sigma)
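
Example #8 works because the model is conjugate: with a Gaussian prior on the mean and a Gaussian likelihood of known noise, get_true_posterior has a closed form. A sketch of what it presumably computes, with illustrative names (obs of shape (n, d), prior std sigma0 and likelihood std sigma_l, scalar or per-dimension):

def true_posterior(obs, mu0, sigma0, sigma_l):
    # Conjugate update for a Gaussian mean with known noise:
    # posterior precision = prior precision + n * likelihood precision.
    n = obs.shape[0]
    prec = 1.0 / sigma0 ** 2 + n / sigma_l ** 2
    pos_mu = (mu0 / sigma0 ** 2 + obs.sum(dim=0) / sigma_l ** 2) / prec
    pos_sigma = prec ** -0.5
    return pos_mu, pos_sigma
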
Example #9
    def grad_log_likelihood(self, obs, mu):
        # sums grad_log_likelihood over observations if multiple obs are given;
        # returns a tensor of the same shape as mu
        return DiagMvn.grad_mu_log_pdf(obs, mu, self.l_sigma)
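
The gradient in Example #9 has a simple closed form for a diagonal Gaussian: d/dmu log N(obs; mu, diag(sigma^2)) = (obs - mu) / sigma^2. A minimal sketch of what grad_mu_log_pdf presumably evaluates, assuming obs of shape (n, d) and mu of shape (1, d):

def grad_mu_log_pdf(obs, mu, sigma):
    # (obs - mu) / sigma^2, summed over observations, keeping mu's shape.
    return ((obs - mu) / sigma ** 2).sum(dim=0, keepdim=True)
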
Example #10
    def log_likelihood(self, obs, mu):
        assert mu.shape[0] == 1 or obs.shape[0] == 1  # safe broadcasting
        return DiagMvn.log_pdf(obs, mu, self.l_sigma)
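
Several of these examples delegate to DiagMvn.log_pdf. The log-density of a diagonal-covariance Gaussian factorizes over dimensions, so the computation presumably looks like the following sketch (illustrative, not the repo's implementation; sigma may be a scalar or a broadcastable tensor of std-devs):

import numpy as np
import torch

def diag_mvn_log_pdf(x, mu, sigma):
    # Per-dimension log N(x; mu, sigma^2), summed over the feature dimension.
    var = torch.as_tensor(sigma) ** 2
    ll = -0.5 * (np.log(2 * np.pi) + torch.log(var) + (x - mu) ** 2 / var)
    return ll.sum(dim=-1, keepdim=True)
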
Example #11
    def grad_log_prior(self, mu):
        return DiagMvn.grad_x_log_pdf(mu, self.prior_mu, self.prior_sigma)
Example #12
    def log_prior(self, mu):
        return DiagMvn.log_pdf(mu, self.prior_mu, self.prior_sigma)