Exemplo n.º 1
0
def sis_r(db, ob, particles):
    """Run one SIS step with multinomial resampling and optional KDE rejuvenation.

    Every particle is weighted by the summed ``db.log_likelihood`` of all rows
    of ``ob`` (starting from uniform weights), then particles are resampled
    with a multinomial draw proportional to those weights. If the effective
    sample size of the weights is at most
    ``cmd_args.threshold * cmd_args.num_particles``, the resampled particles
    are replaced by ``cmd_args.num_particles`` fresh draws from a KDE fitted
    on them.

    Args:
        db: object exposing ``log_likelihood(ob_row, particles)``; its result
            is accumulated into a ``(num_particles, 1)`` tensor.
        ob: iterable of observations; each is reshaped into a single row.
        particles: particle tensor, presumably ``(num_particles, dim)`` —
            TODO confirm against callers.

    Returns:
        The new particle tensor. The resampled particles stay on the device
        of the input ``particles``.
    """
    num_particles = cmd_args.num_particles

    # Re-weight: accumulate log-likelihoods of every observation row.
    log_weights = torch.zeros([num_particles, 1],
                              dtype=torch.float).to(DEVICE)
    for ob_i in ob:
        log_weights += db.log_likelihood(ob_i.reshape([1, -1]), particles)

    # Normalise in log space. The uniform starting weights cancel out, and
    # softmax subtracts the max log-weight before exponentiating. A raw
    # exp() underflows to all zeros for long observation batches, which
    # would turn the normalisation into 0/0 = NaN.
    weights = torch.softmax(log_weights, dim=0)

    # Re-sample: draw how many copies each particle receives, then replicate
    # the particles in index order. repeat_interleave keeps the result on the
    # particles' own device.
    mnrnd = torch.distributions.multinomial.Multinomial(
        num_particles, torch.reshape(weights, [num_particles]))
    counts = mnrnd.sample().long().to(particles.device)
    particles = torch.repeat_interleave(particles, counts, dim=0)

    # Effective sample size of the (pre-resampling) normalised weights.
    ess = torch.sum(weights)**2 / torch.sum(weights * weights)

    if ess <= cmd_args.threshold * num_particles:
        # Weight degeneracy: rejuvenate by sampling fresh locations from a
        # KDE over the resampled particles.
        kde = KDE(particles)
        particles = kde.get_samples(cmd_args.num_particles)

    return particles
Exemplo n.º 2
0
def eval_flow(flow, mvn_dist, db, offline_val_list):
    """Score the flow against the exact posterior on offline sequences.

    For each sequence, particles drawn from ``mvn_dist`` are pushed through
    ``flow`` one observation at a time, while ``db.get_true_new_posterior``
    tracks the exact Gaussian posterior. After every step the cross entropy
    ``-mean(log q(x))`` is estimated, where ``q`` is a KDE over the flow's
    particles and ``x`` are ``cmd_args.num_mc_samples`` draws from the exact
    posterior. Diagnostics are printed for the first sequence only.

    Returns:
        The per-step cross entropies, summed over all steps and divided by
        the number of sequences. The flow is put back in train mode.
    """
    flow.eval()
    total_ent = 0.0
    for seq_no, offline_val in enumerate(offline_val_list):
        verbose = seq_no == 0
        val_gen = iter(offline_val)
        particles = mvn_dist.get_samples(cmd_args.num_particles)
        densities = mvn_dist.get_log_pdf(particles)

        # The exact posterior starts at the prior N(prior_mu, diag(sigma^2)).
        pos_mu = db.prior_mu
        prior_var = db.prior_sigma * db.prior_sigma
        pos_cov = torch.diag(prior_var.reshape(db.dim))

        for step, ob in enumerate(val_gen):
            particles, densities = flow(particles, densities,
                                        prior_samples=particles,
                                        ob_m=ob)
            pos_mu, pos_cov = db.get_true_new_posterior(ob, pos_mu, pos_cov)
            if verbose:
                est = particles.cpu().data.numpy()
                print('step:', step)
                print('true posterior:',
                      pos_mu.cpu().data.numpy(),
                      pos_cov.cpu().data.numpy())
                print('estimated:',
                      np.mean(est, axis=0),
                      np.cov(est.transpose()))

            # Monte-Carlo draws from the exact posterior, scored under a KDE
            # of the flow's particles.
            mc_samples = np.random.multivariate_normal(
                pos_mu.data.cpu().numpy().flatten(),
                pos_cov.data.cpu().numpy(),
                cmd_args.num_mc_samples).astype(np.float32)
            log_q = KDE(particles).log_pdf(torch.tensor(mc_samples).to(DEVICE))
            step_ent = -torch.mean(log_q).item()
            if verbose:
                print('cross entropy:', step_ent)
            total_ent += step_ent

    total_ent /= len(offline_val_list)
    print('avg ent over %d seqs: %.4f' % (len(offline_val_list), total_ent))
    flow.train()
    return total_ent
Exemplo n.º 3
0
def onepass_smc(db, ob, particles, alpha):
    """Run one SMC step: re-weight, then resample and rejuvenate if degenerate.

    The importance weights ``alpha`` are multiplied by the likelihood of
    every row of ``ob``. If the effective sample size then drops to at most
    ``cmd_args.threshold * cmd_args.num_particles``, particles are resampled
    multinomially and replaced by fresh draws from a KDE over the resampled
    set, and the weights are reset to uniform. Otherwise the particles are
    kept and the weights are only normalised.

    Args:
        db: object exposing ``log_likelihood(ob_row, particles)``; its result
            is viewed to ``alpha``'s shape.
        ob: iterable of observations; each is reshaped into a single row.
        particles: particle tensor, presumably ``(num_particles, dim)`` —
            TODO confirm against callers.
        alpha: current importance weights, presumably ``(num_particles, 1)``.

    Returns:
        ``(xi, alpha)``: the (possibly rejuvenated) particles and the
        normalised weights. The weights stay on ``alpha``'s device.
    """
    num_particles = cmd_args.num_particles

    # Re-weight. Accumulate on alpha's device and dtype so the in-place sum
    # and the product below never mix devices.
    log_reweight = torch.zeros_like(alpha)
    for ob_i in ob:
        ll = db.log_likelihood(ob_i.reshape([1, -1]), particles)
        log_reweight += ll.view(log_reweight.size()).to(log_reweight.device)
    # Subtract the max before exp() to avoid underflow to all-zero weights.
    # exp(-max) is a constant factor; it cancels both in the ESS (scale
    # invariant) and in the normalisation applied below.
    alpha = alpha * torch.exp(log_reweight - torch.max(log_reweight))

    ess = torch.sum(alpha)**2 / torch.sum(alpha * alpha)
    alpha = alpha / torch.sum(alpha)

    if ess <= cmd_args.threshold * num_particles:
        # Re-sample: draw copy counts, then replicate particles in index order.
        mnrnd = torch.distributions.multinomial.Multinomial(
            num_particles, torch.reshape(torch.abs(alpha), [num_particles]))
        counts = mnrnd.sample().long().to(particles.device)
        new_xi = torch.repeat_interleave(particles, counts, dim=0).to(DEVICE)

        # Generate new locations from a KDE over the resampled particles.
        kde = KDE(new_xi)
        xi = kde.get_samples(cmd_args.num_particles)
        alpha = torch.ones([num_particles, 1],
                           dtype=torch.float,
                           device=alpha.device) / num_particles
    else:
        xi = particles

    return xi, alpha
Exemplo n.º 4
0
def eval_flow(flow, mvn_dist, val_db):
    """Evaluate the flow's cross entropy against the true posterior.

    Runs ``cmd_args.num_vals`` validation sequences. Each sequence starts from
    fresh ``mvn_dist`` particles and consumes at most
    ``cmd_args.train_samples`` observations from a single shared validation
    generator. After every observation the exact posterior of all
    observations seen so far in the sequence is obtained from
    ``val_db.get_true_posterior``. The cross entropy ``-mean(log q(x))`` is
    then estimated, with ``q`` a KDE over the flow's particles and ``x``
    Gaussian draws from that posterior. Per-step diagnostics are printed for
    the last sequence.

    Returns:
        The cross entropy summed over every evaluated step of every sequence.
        NOTE(review): the printed "avg" divides by ``cmd_args.num_vals`` but
        the returned value does not — confirm callers expect the sum.
    """
    flow.eval()
    # Created once and shared by all sequences; with auto_reset=False a
    # sequence presumably continues where the previous one stopped — confirm.
    val_gen = val_db.data_gen(batch_size=1,
                              phase='val',
                              auto_reset=False,
                              shuffle=False)
    ent = 0.0
    for n_s in tqdm(range(cmd_args.num_vals)):
        verbose = n_s + 1 == cmd_args.num_vals
        hist_obs = []
        particles = mvn_dist.get_samples(cmd_args.num_particles)
        densities = mvn_dist.get_log_pdf(particles)
        for t, ob in enumerate(val_gen):
            particles, densities = flow(particles, densities,
                                        prior_samples=particles,
                                        ob_m=ob)
            hist_obs.append(ob)
            with torch.no_grad():
                # Exact posterior given every observation of this sequence.
                pos_mu, pos_sigma = val_db.get_true_posterior(
                    torch.cat(hist_obs, dim=0))
                q_mu = torch.mean(particles, dim=0, keepdim=True)
                q_std = torch.std(particles, dim=0, keepdim=True)
                if verbose:
                    print('step:', t)
                    print('true posterior:', pos_mu.cpu().data.numpy(), pos_sigma.cpu().data.numpy())
                    print('estimated:', q_mu.cpu().data.numpy(), q_std.cpu().data.numpy())

                p_particles = torch_randn2d(cmd_args.num_mc_samples,
                                            val_db.dim) * pos_sigma + pos_mu
                kde = KDE(particles)
                cur_ent = -torch.mean(kde.log_pdf(p_particles)).item()
                if verbose:
                    print('cross entropy:', cur_ent)
                ent += cur_ent
            if t + 1 == cmd_args.train_samples:
                break
    print('avg ent over %d seqs: %.4f' % (cmd_args.num_vals, ent/cmd_args.num_vals))
    flow.train()
    return ent
Exemplo n.º 5
0
def vis_flow(flow, mvn_dist, val_db):
    """Render flow and reference posterior heatmaps on a 2-D grid as a video.

    The first ``cmd_args.train_samples`` validation batches are replayed
    through ``flow``. After each step a KDE heatmap of the particles is
    recorded. Reference heatmaps are the prior (``val_db.log_prior``)
    followed by ``get_normalized_heatmaps`` for the same observation
    sequence. Every heatmap is a softmax over a 100 x 100 grid on [-3, 3]^2.
    Outputs are written under ``<cmd_args.save_dir>/video-<cmd_args.seed>``:
    ``flow_heats.npy``, ``true_heats.npy``, ``heat-step*`` frames and
    ``traj.mp4``.

    The flow is switched to eval mode while rendering and restored to train
    mode afterwards, also when an exception is raised.
    """
    flow.eval()
    try:
        w = 100
        x = np.linspace(-3, 3, w)
        y = np.linspace(-3, 3, w)
        xx, yy = np.meshgrid(x, y)
        mus = np.stack([xx.flatten(), yy.flatten()]).transpose()
        mus = torch.Tensor(mus.astype(np.float32)).to(DEVICE)

        def _heat(log_scores):
            # Softmax over every grid cell, reshaped into a (w, w) image.
            return torch.softmax(log_scores.view(-1),
                                 -1).view(w, w).data.cpu().numpy()

        val_set = val_db.data_gen(batch_size=cmd_args.batch_size,
                                  phase='val',
                                  auto_reset=False,
                                  shuffle=False)

        # Materialise a fixed observation sequence so that the flow and the
        # reference heatmaps are computed from the same observations.
        ob_list = [next(val_set) for _ in range(cmd_args.train_samples)]

        def lm_val_gen():
            return iter(ob_list)

        particles = mvn_dist.get_samples(cmd_args.num_particles)
        densities = mvn_dist.get_log_pdf(particles)
        flow_heats = [_heat(KDE(particles).log_pdf(mus))]
        for ob in lm_val_gen():
            particles, densities = flow(particles,
                                        densities,
                                        prior_samples=particles,
                                        ob_m=ob)
            flow_heats.append(_heat(KDE(particles).log_pdf(mus)))

        true_heats = [_heat(val_db.log_prior(mus))] + get_normalized_heatmaps(
            mvn_dist, lm_val_gen, val_db, mus)

        out_dir = os.path.join(cmd_args.save_dir, 'video-%d' % cmd_args.seed)
        os.makedirs(out_dir, exist_ok=True)
        np.save(os.path.join(out_dir, 'flow_heats.npy'), flow_heats)
        np.save(os.path.join(out_dir, 'true_heats.npy'), true_heats)

        images = list(zip(flow_heats, true_heats))
        save_prefix = os.path.join(out_dir, 'heat-step')
        plot_image_seqs(images, save_prefix)
        create_video(save_prefix,
                     output_name=os.path.join(out_dir, 'traj.mp4'))
    finally:
        flow.train()
Exemplo n.º 6
0
def vis_flow(flow, mvn_dist, db, offline_val):
    """Save side-by-side heatmaps of the flow posterior and the exact posterior.

    Log-densities are evaluated on a 100 x 100 grid over [-3, 3]^2, and each
    heatmap is a softmax over all grid cells. The flow estimate at every step
    is a KDE over the flow's particles. The exact posterior is a
    ``MultivariateNormal`` updated with ``db.get_true_new_posterior`` and
    started from ``mvn_dist``'s density. Frames are written with
    ``plot_image_seqs`` using prefix ``<cmd_args.save_dir>/heat-step``, and
    assembled into ``traj.mp4`` in the same directory. The flow is returned
    to train mode at the end.
    """
    flow.eval()
    grid_size = 100
    axis = np.linspace(-3, 3, grid_size)
    grid_x, grid_y = np.meshgrid(axis, axis)
    grid = np.stack([grid_x.flatten(), grid_y.flatten()]).transpose()
    mus = torch.Tensor(grid.astype(np.float32)).to(DEVICE)

    def to_heat(log_scores):
        # Normalise log-densities over the whole grid and reshape to an image.
        probs = torch.softmax(log_scores.view(-1), -1)
        return probs.view(grid_size, grid_size).data.cpu().numpy()

    true_heats = [to_heat(mvn_dist.get_log_pdf(mus))]

    val_gen = iter(offline_val)

    # Exact posterior starts at the prior N(prior_mu, diag(prior_sigma^2)).
    pos_mu = db.prior_mu
    prior_var = db.prior_sigma * db.prior_sigma
    pos_cov = torch.diag(prior_var.reshape(db.dim))

    particles = mvn_dist.get_samples(cmd_args.num_particles)
    densities = mvn_dist.get_log_pdf(particles)
    flow_heats = [to_heat(KDE(particles).log_pdf(mus))]

    for ob in val_gen:
        pos_mu, pos_cov = db.get_true_new_posterior(ob, pos_mu, pos_cov)
        particles, densities = flow(particles, densities,
                                    prior_samples=particles,
                                    ob_m=ob)
        flow_heats.append(to_heat(KDE(particles).log_pdf(mus)))
        exact = torch.distributions.MultivariateNormal(pos_mu, pos_cov)
        true_heats.append(to_heat(exact.log_prob(mus)))

    save_prefix = os.path.join(cmd_args.save_dir, 'heat-step')
    plot_image_seqs(list(zip(flow_heats, true_heats)), save_prefix)
    video_path = os.path.join(cmd_args.save_dir, 'traj.mp4')
    create_video(save_prefix, output_name=video_path)
    flow.train()
Exemplo n.º 7
0
def supervised_train_loop(cmd_args,
                          db,
                          prior_dist,
                          flow,
                          ob_net,
                          coeff,
                          eval_func,
                          test_locs=None):
    """Jointly train ``flow`` and ``ob_net`` stage by stage on streamed data.

    Every epoch restarts the particles from ``prior_dist`` and runs
    ``cmd_args.n_stages`` stages. A stage consumes ``cmd_args.stage_len``
    batches from ``db.data_gen``. After each batch the flow transports the
    particles, and the stage loss accumulates

        coeff * mean(log q) - mean(sum(db.log_likelihood(all batches of the
        stage so far), dim=0)) - mean(log prior)

    One optimizer step is taken per stage. On the first stage of an epoch
    the log prior is ``coeff * prior_dist.get_log_pdf``. Afterwards it is a
    KDE log-density (``stage_dist_metric == 'ce'``) or a negative MMD
    relative to the particles at the start of the stage.

    Every ``cmd_args.iters_per_eval`` stages (checked at epoch ends), the
    validation loss from ``eval_func`` drives the LR scheduler, a test
    evaluation is run, and ``flow``'s state dict is saved to
    ``<save_dir>/best_val_model.dump`` when validation improves.

    Args:
        cmd_args: configuration namespace.
        db: dataset providing ``data_gen`` and ``log_likelihood``.
        prior_dist: initial particle distribution (``get_samples``,
            ``get_log_pdf``).
        flow: particle flow, called as
            ``flow(particles, densities, prior_samples=..., ob_m=...)``.
        ob_net: observation encoder applied to ``(feats, labels)``.
        coeff: weight of the entropy term (and of the first-stage prior).
        eval_func: ``eval_func(num_obs, particles, db, phase)``; its return
            value is used as the validation loss.
        test_locs: optional observation counts, presumably ascending. When
            the number of consumed observations reaches the next entry, a
            'test' evaluation runs (at most one per stage). Defaults to none.
    """
    # Avoid a shared mutable default; copy so the caller's sequence is never
    # aliased.
    test_locs = [] if test_locs is None else list(test_locs)

    print('coeff:', coeff)
    optimizer = optim.Adam(chain(flow.parameters(), ob_net.parameters()),
                           lr=cmd_args.learning_rate,
                           weight_decay=cmd_args.weight_decay)

    best_val_loss = None
    if cmd_args.stage_dist_metric == 'ce':
        fn_log_prior = lambda x, y: KDE(y, coeff=cmd_args.kernel_bw).log_pdf(x)
    else:
        fn_log_prior = lambda x, y: -MMD(x, y, bandwidth=cmd_args.kernel_bw)
    scheduler = ReduceLROnPlateau(optimizer,
                                  'min',
                                  factor=0.1,
                                  patience=2,
                                  min_lr=1e-6,
                                  verbose=True)
    num_obs = 0
    for epoch in range(cmd_args.num_epochs):
        train_gen = db.data_gen(batch_size=cmd_args.batch_size,
                                phase='train',
                                auto_reset=True,
                                shuffle=True)
        pbar = tqdm(range(cmd_args.n_stages))
        particles = prior_dist.get_samples(cmd_args.num_particles)
        densities = prior_dist.get_log_pdf(particles)
        for it in pbar:
            loss = 0.0
            feats_all = []
            labels_all = []
            # Truncate backprop at stage boundaries.
            particles = particles.detach()
            densities = densities.detach()
            prior_particles = particles
            optimizer.zero_grad()
            acc = 0.0
            for l in range(cmd_args.stage_len):
                feats, labels = next(train_gen)
                num_obs += feats.shape[0]
                ob = ob_net((feats, labels))
                if l + 1 == cmd_args.stage_len:
                    # Accuracy of the particle-averaged sigmoid prediction on
                    # the last batch of the stage, before the flow update.
                    pred = torch.sigmoid(torch.matmul(feats, particles.t()))
                    pred = torch.mean(pred, dim=-1, keepdim=True)
                    acc = (labels < 0.5) == (pred < 0.5)
                    acc = torch.mean(acc.float())
                new_particles, new_densities = flow(particles,
                                                    densities,
                                                    prior_samples=particles,
                                                    ob_m=ob)
                feats_all.append(feats)
                labels_all.append(labels)
                feats = torch.cat(feats_all, dim=0)
                labels = torch.cat(labels_all, dim=0)

                ll = torch.mean(
                    torch.sum(db.log_likelihood((feats, labels),
                                                new_particles),
                              dim=0))
                if it == 0:
                    log_prior = prior_dist.get_log_pdf(new_particles) * coeff
                else:
                    log_prior = fn_log_prior(new_particles, prior_particles)
                loss += coeff * torch.mean(new_densities) - ll - torch.mean(
                    log_prior)
                particles = new_particles
                densities = new_densities
            loss.backward()
            optimizer.step()
            if test_locs and num_obs >= test_locs[0]:
                eval_func(num_obs, particles, db, 'test')
                test_locs = test_locs[1:]
            pbar.set_description(
                'epoch %.2f, loss: %.4f, last_acc: %.4f' %
                (epoch + float(it + 1) / cmd_args.n_stages, loss.item(),
                 float(acc)))
        if ((epoch + 1) * cmd_args.n_stages) % cmd_args.iters_per_eval == 0:
            loss = 0.0
            if cmd_args.num_vals == 1:
                loss += eval_func(num_obs, particles, db, 'val')
            else:
                for i in range(cmd_args.num_vals):
                    print('evaluating val-%d' % i)
                    loss += eval_func(num_obs, particles, db, 'val-%d' % i)
            loss /= cmd_args.num_vals
            scheduler.step(loss)
            eval_func(num_obs, particles, db, 'test')
            if best_val_loss is None or loss < best_val_loss:
                best_val_loss = loss
                print('saving model with best valid error')
                torch.save(
                    flow.state_dict(),
                    os.path.join(cmd_args.save_dir, 'best_val_model.dump'))
Exemplo n.º 8
0
def _kde_prior_func(particles):
    """Return a log-prior callable backed by a KDE fitted once on ``particles``.

    The callable takes ``(x, y)`` to match the other log-prior functions; it
    ignores ``y`` and returns the fitted KDE's ``log_pdf(x)``.
    """
    density = KDE(particles)

    def log_prior(x, y):
        return density.log_pdf(x)

    return log_prior