acc = eval_particles_acc(particles, val_gen)
    print('%s accuracy: %.4f' % (phase, acc))
    if phase == 'test':
        factor = cmd_args.n_stages * cmd_args.batch_size * cmd_args.stage_len
        f_log.write('num_obs %d, loss %.8f\n' % (num_obs, 1.0 - acc))
    return 1.0 - acc


if __name__ == '__main__':
    # Seed Python's `random`, NumPy and torch from cmd_args.seed so runs are
    # reproducible.
    random.seed(cmd_args.seed)
    np.random.seed(cmd_args.seed)
    torch.manual_seed(cmd_args.seed)

    # NOTE(review): this MNIST setup looks like a leftover merged in from
    # another script. `db` is reassigned to a TwoGaussDataset and `flow` is
    # rebuilt with build_model(...) further down, so the MNIST `db` / `flow`
    # objects created here are discarded. Confirm whether `ob_net` and
    # `prior_dist` are used later in the script; if not, this section can
    # likely be removed.
    db = MnistDataset(cmd_args.data_folder)

    flow = build_model(cmd_args, x_dim=db.x_dim, ob_dim=db.ob_dim, ll_func=db.log_likelihood)
    # Observation embedding network sized by the MNIST observation dim and
    # moved to DEVICE (trainable=True is forwarded to KernelEmbedNet).
    ob_net = KernelEmbedNet(db.ob_dim, str(db.ob_dim), cmd_args.nonlinearity, trainable=True).to(DEVICE)
    
    # NOTE(review): the same init_model_dump is loaded again into the rebuilt
    # TwoGauss `flow` further down. If the checkpoint was saved from that
    # model, load_state_dict here (into the MNIST-shaped flow) will presumably
    # fail on mismatched keys/shapes -- TODO confirm which model it belongs to.
    if cmd_args.init_model_dump is not None:
        state = torch.load(cmd_args.init_model_dump)
        flow.load_state_dict(state)

    # Prior built from cmd_args.prior_mu / prior_sigma, repeated for each of
    # the db.x_dim dimensions (DiagMvn presumably a diagonal Gaussian).
    prior_dist = DiagMvn(mu=[cmd_args.prior_mu] * db.x_dim,
                         sigma=[cmd_args.prior_sigma] * db.x_dim)

    # Increasing, roughly logarithmically spaced checkpoints from 100 up to
    # ~7.9e6 (above 1e5 the entries are 10**(k/10) rounded to the nearest
    # hundred). Not used within this visible section -- presumably the
    # observation counts at which evaluation is run; confirm against its use.
    test_locs = [100, 200, 300, 400, 600, 700, 800, 1000, 1300,
                 1600, 2000, 2600, 3200, 4000, 5100, 6400, 8000, 10000, 12600, 
                 15900, 20000, 25200, 31700, 39900, 50200, 63100, 79500, 100000,
                 125900, 158500, 199600, 251200, 316300, 398200, 501200, 631000,
                 794400, 1000000, 1259000, 1584900, 1995300, 2511900, 
                 3162300, 3981100, 5011900, 6309600, 7943300]
    # Re-seed torch only (not `random` / NumPy), so the torch RNG state used
    # by the code below does not depend on what the setup above consumed.
    torch.manual_seed(cmd_args.seed)
    # Training data: TwoGaussDataset built from the CLI prior parameters and
    # fixed likelihood parameters; only a 'train' partition of
    # cmd_args.train_samples is requested.
    db = TwoGaussDataset(prior_mu=cmd_args.prior_mu,
                         prior_sigma=cmd_args.prior_sigma,
                         mu_given=[-1, 2],
                         l_sigma=1.0,
                         p=0.5,
                         partition_sizes={'train': cmd_args.train_samples})
    # Validation data: same generative parameters as `db`, with a 'val'
    # partition of train_samples * num_vals.
    val_db = TwoGaussDataset(
        prior_mu=cmd_args.prior_mu,
        prior_sigma=cmd_args.prior_sigma,
        mu_given=[-1, 2],
        l_sigma=1.0,
        p=0.5,
        partition_sizes={'val': cmd_args.train_samples * cmd_args.num_vals})

    # NOTE(review): x_dim is hard-coded to 2 (presumably matching
    # len(mu_given)); this replaces the MNIST `flow` built earlier.
    flow = build_model(cmd_args, x_dim=2, ob_dim=db.dim)

    # Optionally warm-start the flow from a saved state dict.
    if cmd_args.init_model_dump is not None:
        print('loading', cmd_args.init_model_dump)
        flow.load_state_dict(torch.load(cmd_args.init_model_dump))

    # Distribution over cmd_args.gauss_dim dimensions from the CLI prior
    # parameters. NOTE(review): the flow above uses x_dim=2; confirm that
    # gauss_dim is expected to equal 2.
    mvn_dist = DiagMvn(mu=[cmd_args.prior_mu] * cmd_args.gauss_dim,
                       sigma=[cmd_args.prior_sigma] * cmd_args.gauss_dim)

    if cmd_args.phase == 'train':
        optimizer = optim.Adam(flow.parameters(),
                               lr=cmd_args.learning_rate,
                               weight_decay=cmd_args.weight_decay)
        train_global_x_loop(
            cmd_args,
            lambda x: lm_train_gen(db, x),