Example #1
def visualize_times():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            xx = torch.linspace(-10, 10, 10000).view(-1, 1)

            #generated_p = model_density(xx, model)
            generated_p = 0
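            # For each CNF block in the chain, integrate from time 0 to the
            # intermediate time t and evaluate the model log-density from the
            # accumulated change in log-probability (delta_logp).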
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp

            plt.plot(xx.view(-1).cpu().numpy(),
                     generated_p.view(-1).exp().cpu().numpy(),
                     label='Model')

            utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
            plt.savefig(
                os.path.join(args.save, 'test_times', 'figs',
                             '{:04d}.jpg'.format(i)))
            plt.close()
    trajectory_to_video(os.path.join(args.save, 'test_times', 'figs'))
Example #2
def compute_loss(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)

    z, delta_logp = model(x, zero)  # run model forward

    logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss
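All of these snippets rely on a standard_normal_logprob helper. A minimal sketch of such a helper, assuming it returns the elementwise log-density of a standard normal (which the callers then sum over dimensions):

import math
import torch

def standard_normal_logprob(z):
    # log N(z; 0, 1) = -0.5 * log(2*pi) - z**2 / 2, applied elementwise
    log_z = -0.5 * math.log(2 * math.pi)
    return log_z - z.pow(2) / 2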
Example #3
def compute_loss(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)

    z, change = model(x, zero)  # run model forward

    logpx = standard_normal_logprob(z).view(z.shape[0], -1).sum(
        1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss
Example #4
def compute_loss(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)
    lec = None if (args.poly_coef is None or not model.training) else torch.tensor(0.0).to(x)

    z, delta_logp, lec = model(x, zero, lec)

    logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss, lec
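A hedged usage sketch for the extra lec term above; args.poly_coef, optimizer, and the surrounding training step are assumptions, not shown in the original snippet:

# Hypothetical training step: weight the returned lec regularizer by
# args.poly_coef and add it to the likelihood loss (assumed usage).
loss, lec = compute_loss(x, model)
total_loss = loss if lec is None else loss + args.poly_coef * lec
optimizer.zero_grad()
total_loss.backward()
optimizer.step()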
Example #5
def compute_loss(args, model, batch_size=args.batch_size):

    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, zero)

    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss
Example #6
def visualize_evolution():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    viz_times_np = viz_times[1:].detach().cpu().numpy()
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    xx_np = xx.detach().cpu().numpy()
    xs, ys = np.meshgrid(xx_np, viz_times_np)
    #xx,yy = np.meshgrid(args.num_particles, viz_times_np )
    #all_evolutions = np.zeros((args.ntimes-1,args.num_particles))
    all_evolutions = np.zeros((args.num_particles, args.ntimes - 1))
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            #xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)

            #generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp

            generated_p = generated_p.detach()
            #plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')
            cur_evolution = generated_p.view(-1).exp().cpu().numpy()

            #all_evolutions[i]= np.array(cur_evolution)
            all_evolutions[:, i] = np.array(cur_evolution)
        #xx = np.array(xx.detach().cpu().numpy())
        #yy = np.array(yy)
        plt.figure(dpi=1200)
        plt.clf()
        all_evolutions = all_evolutions.astype('float32')
        print(xs.shape)
        print(ys.shape)
        print(all_evolutions.shape)
        #plt.pcolormesh(ys, xs, all_evolutions)
        plt.pcolormesh(xs, ys, all_evolutions.transpose())

        utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
        plt.savefig(
            os.path.join(args.save, 'test_times', 'figs',
                         'evolution.jpg'))
        plt.close()
Example #7
def visualize_particle_flow():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    zs = []
    #zs.append(xx.view(-1).cpu().numpy())
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)

            #generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx,
                                    torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp

            zs.append(z.cpu().numpy())

            #plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')

            #plt.savefig(os.path.join(args.save,'test_times', 'figs', '{:04d}.jpg'.format(i)))
            #plt.close()

    zs = np.array(zs).reshape(args.ntimes - 1, args.num_particles)
    viz_t = viz_times[1:].numpy()
    #print(zs)
    plt.figure(dpi=1200)

    plt.clf()
    #plt.plot(viz_t , zs[:,0])
    with sns.color_palette("Blues_d"):
        plt.plot(viz_t, zs)
        plt.xlabel("Test Time")
        #plt.tight_layout()
        utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
        plt.savefig(
            os.path.join(args.save, 'test_times', 'figs',
                         'particle_trajectory.jpg'))
        plt.close()
Example #8
def compute_bits_per_dim(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)

    z, delta_logp, reg_states = model(x, zero)  # run model forward

    reg_states = tuple(torch.mean(rs) for rs in reg_states)

    logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp

    logpx_per_dim = torch.sum(logpx) / x.nelement()  # averaged over batches
    bits_per_dim = -(logpx_per_dim - np.log(nvals)) / np.log(2)

    return bits_per_dim, (x, z), reg_states
Example #9
def calc_logpx(model, x):
    # load data
    #x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    #x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    return logpx
Example #10
def compute_loss(args, model, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    # load data
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss
Example #11
def compute_bits_per_dim(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)

    # Don't use data parallelize if batch size is small.
    # if x.shape[0] < 200:
    #     model = model.module

    z, delta_logp = model(x, zero)  # run model forward

    logpz = standard_normal_logprob(z).view(z.shape[0],
                                            -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp

    logpx_per_dim = torch.sum(logpx) / x.nelement()  # averaged over batches
    bits_per_dim = -(logpx_per_dim - np.log(256)) / np.log(2)

    return bits_per_dim
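For reference, the bits-per-dim conversion used above: the average log-likelihood per dimension (in nats, on inputs rescaled to [0, 1]) is shifted by log(256) to account for the original 8-bit range, then divided by log(2) to convert nats to bits. A standalone sketch of the same arithmetic (the helper name is illustrative, not from the original code):

import numpy as np

def nats_to_bits_per_dim(logpx_per_dim, nvals=256):
    # bits/dim = -(logpx_per_dim - log(nvals)) / log(2)
    return -(logpx_per_dim - np.log(nvals)) / np.log(2)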
Example #12
def my_compute_loss(dataLoaderIt, model):
    # load data
    truth, reco = next(dataLoaderIt)
    x = reco
    # I think this is the prior
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero)

    # compute log q(z) This really should be the prior
    #logpz = truth
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss
Example #13
def compute_loss(args, model, data, batch_size=None, end_times=None):
    if batch_size is None: batch_size = args.batch_size

    # load data
    x = sample_data(data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    # transform to z
    z, delta_logp = model(x, zero, integration_times=end_times)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss
Example #14
def compute_loss_wgf(args, model, dim, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    z = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    # mu_0 = torch.zeros(2, dtype=torch.float32, device=device)
    # sigma_half_0 = torch.eye(2, dtype=torch.float32, device=device)
    # score_error_0 = torch.zeros(1, dtype=torch.float32, device=device)
    x, logp_x, score_x, wgf_reg = model(z,
                                        logpz=logp_z,
                                        score=score_z,
                                        wgf_reg=wgf_reg_0)

    nfe = count_nfe(model)

    return wgf_reg / nfe
Example #15
def compute_kl_divergence(args, model, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    # TODO: should have an input specifying the data dimension. Now it is fixed to 2
    z = torch.randn(batch_size, 2, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    # x, logp_x, score_x, wgf_reg = model(z, logp_z, score_z, wgf_reg_0)
    x, logp_x, score_x, wgf_reg = model(z,
                                        logpz=logp_z,
                                        score=score_z,
                                        wgf_reg=wgf_reg_0)

    # logp_true_x = gaussian_mixture_logprob(x)
    logp_true_x = gaussian_logprob(x).sum(1, keepdim=True).to(z)
    # print(torch.mean(x, 0))
    return torch.mean(logp_x - logp_true_x)
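The value returned above is a Monte Carlo estimate of KL(q_model || p_true), since KL(q || p) = E_{x ~ q}[log q(x) - log p(x)] and x, logp_x come from pushing base samples through the flow. A generic standalone sketch of that estimator (function names are illustrative):

import torch

def mc_kl_estimate(samples_q, logq_fn, logp_fn):
    # KL(q || p) ~= mean over samples x ~ q of [log q(x) - log p(x)]
    return torch.mean(logq_fn(samples_q) - logp_fn(samples_q))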
Example #16
def compute_bits_per_dim(x, model):
    zero = torch.zeros(x.shape[0], 1).to(x)
    lec = None if (args.poly_coef is None or not model.training) else torch.tensor(0.0).to(x)

    # Don't use data parallelize if batch size is small.
    # if x.shape[0] < 200:
    #     model = model.module

    z, delta_logp, lec = model(x, zero, lec)

    logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp

    logpx_per_dim = torch.sum(logpx) / x.nelement()  # averaged over batches
    bits_per_dim = -(logpx_per_dim - np.log(256)) / np.log(2)
    lec = lec / (x[0].nelement() * np.log(2)) if lec is not None else None

    return bits_per_dim, lec
Example #17
def score_error_wgf(args, model, batch_size=None):
    if batch_size is None: batch_size = args.batch_size

    # TODO: should have an input specifying the data dimension. Now it is fixed to 2
    z = torch.randn(batch_size, 2, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    mu_0 = torch.zeros(2, dtype=torch.float32, device=device)
    sigma_half_0 = torch.eye(2, dtype=torch.float32, device=device)
    score_error_0 = torch.zeros(1, dtype=torch.float32, device=device)
    # x, logp_x, score_x, wgf_reg = model(z, logp_z, score_z, wgf_reg_0)
    x, logp_x, score_x, wgf_reg, mu, sigma_half, score_error = \
        model(z, logpz=logp_z, score=score_z, wgf_reg=wgf_reg_0, mu_0=mu_0, sigma_half_0=sigma_half_0,
              score_error_0=score_error_0)

    nfe = count_nfe(model)

    return score_error / nfe
Example #18
def compute_loss(args, model):
    """
    Compute the loss by integrating backwards from the last time step.
    At each time step, integrate back one step and concatenate the result
    with samples of the empirical distribution at that previous timepoint,
    repeating iteratively so that the likelihood of samples at later
    timepoints is computed with the ODE evaluated at every intermediate
    time step.
    """
    deltas = []
    for i, (itp, tp) in enumerate(zip(int_tps[::-1], timepoints[::-1])): # tp counts down from last
        integration_times = torch.tensor([itp-args.time_length, itp]).type(torch.float32).to(device)
        print(integration_times)

        # load data
        x = train_sampler(tp)
        x = torch.from_numpy(x).type(torch.float32).to(device)
        if i > 0:
            x = torch.cat((z, x))
        zero = torch.zeros(x.shape[0], 1).to(x)

        # transform to previous timepoint
        z, delta_logp = model(x, zero, integration_times=integration_times)
        deltas.append(delta_logp)

    # compute log q(z)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logps = [logpz]
    losses = []
    for delta_logp in deltas[::-1]:
        logpx = logps[-1] - delta_logp
        logps.append(logpx[:-args.batch_size])
        losses.append(-torch.mean(logpx[-args.batch_size:]))
    #weights = torch.tensor([0,0,0,0,1]).to(logpx)
    #weights = torch.tensor([1,0,0,0,0]).to(logpx)
    #weights = torch.tensor([1,1,1,1,1]).to(logpx)
    weights = torch.tensor([3,2,1]).to(logpx)
    loss = torch.sum(torch.stack(losses) * weights)
    return loss
Example #19
def compute_loss(args, model, growth_model):
    """
    Compute the loss by integrating backwards from the last time step.
    At each time step, integrate back one step and concatenate the result
    with samples of the empirical distribution at that previous timepoint,
    repeating iteratively so that the likelihood of samples at later
    timepoints is computed with the ODE evaluated at every intermediate
    time step.

    The growth model is a single, time-independent model of cell growth /
    death rate, defined as a variation from uniform.
    """

    # Backward pass accumulating losses, previous state and deltas
    deltas = []
    xs = []
    zs = []
    for i, (itp, tp) in enumerate(zip(int_tps[::-1], timepoints[::-1])): # tp counts down from last
        integration_times = torch.tensor([itp-args.time_length, itp]).type(torch.float32).to(device)

        # load data
        x = train_sampler(tp)
        x = torch.from_numpy(x).type(torch.float32).to(device)
        xs.append(x)
        if i > 0:
            x = torch.cat((z, x))
            zs.append(z)
        zero = torch.zeros(x.shape[0], 1).to(x)

        # transform to previous timepoint
        z, delta_logp = model(x, zero, integration_times=integration_times)
        deltas.append(delta_logp)

    # compute log growth probability
    xs = torch.cat(xs)
    #growth_zs, growth_delta_logps = growth_model(xs, torch.zeros(xs.shape[0], 1).to(xs)) # Use default timestep
    #growth_logpzs = uniform_logprob(growth_zs).sum(1, keepdim=True)
    #growth_logpzs = standard_normal_logprob(growth_zs).sum(1, keepdim=True)
    #growth_logpxs = growth_logpzs - growth_delta_logps

    # compute log q(z) with forward pass
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    logps = [logpz]
    
    # build growth rates
    growthrates = [torch.ones_like(logpz)]
    for z_state, tp in zip(zs[::-1], timepoints[::-1][1:]):

        full_state = torch.cat([z_state, tp * torch.ones(z_state.shape[0],1).to(z_state)], 1)
        growthrates.append(growth_model(full_state))
    losses = []
    for gr, delta_logp in zip(growthrates, deltas[::-1]):
        #logpx = logps[-1] - delta_logp# + gr
        logpx = logps[-1] - delta_logp + torch.log(gr)
        logps.append(logpx[:-args.batch_size])
        losses.append(-torch.mean(logpx[-args.batch_size:]))
    #weights = torch.tensor([1,1,10]).to(logpx)
    #weights = torch.tensor([2,1]).to(logpx)
    losses = torch.stack(losses)
    weights = torch.ones_like(losses).to(logpx)
    losses = torch.mean(losses * weights)
    #losses = torch.mean(losses)

    # Add a hinge loss on the growth model so that we prefer sums over the batch
    # to be not too much more than 1 on average
    reg = 0.
    for gr in growthrates[1:]:
        reg += F.relu(torch.mean(gr[-1000:])) # Only put a loss on the last portion with real data
        #reg += F.relu(torch.mean(gr[-1000:]) - 1) # Only put a loss on the last portion with real data
    #mean_growthrate = torch.mean(torch.cat(growthrates[1:]))
    #reg = F.relu(mean_growthrate - 1)
    #print(reg.item())
    #losses += 3*reg
    #losses += 0.001 * torch.mean(gr[-1000:] ** 2)

    # Direction regularization
    if args.vecint:
        similarity_loss = 0
        for i, (itp, tp) in enumerate(zip(int_tps, timepoints)):
            itp = torch.tensor(itp).type(torch.float32).to(device)
            x = dir_train_sampler(tp)
            x = torch.from_numpy(x).type(torch.float32).to(device)
            y,zz = torch.split(x, 2, dim=1)
            y = y + torch.randn_like(y) * 0.1
            # This is really hacky but I don't know a better way (alex)
            direction = model.chain[0].odefunc.odefunc.diffeq(itp, y)
            similarity_loss += 1 - torch.mean(F.cosine_similarity(direction, zz))
        print(similarity_loss)
        losses += similarity_loss * args.vecint

    #loss = loss + vec_reg_loss


    #growth_losses = -torch.mean(growth_logpxs)
    #alpha = torch.tensor(args.alpha).to(growth_losses)
    #loss = (1 - alpha) * losses + alpha * growth_losses
    #loss = losses + growth_losses
    return losses#, growth_losses
Example #20
                    torch.save(
                        {
                            'args': args,
                            'state_dict': model.state_dict(),
                        }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                p_gtr = inf_train_gen(sample_loader,
                                      batch_size=2000).float().to(device)
                sample_fn, density_fn = get_transforms(model)
                prior = torch.randn_like(p_gtr).float().to(device)
                p_samples = sample_fn(prior)
                p_density = standard_normal_logprob(density_fn(p_gtr)).sum(
                    1, keepdim=True)

                buf = visualize_point_clouds(p_samples, p_gtr, name=args.data)
                writer.add_image('samples', buf, itr)
                model.train()

        end = time.time()

    logger.info('Training has finished.')

    # save_traj_dir = os.path.join(args.save, 'trajectory')
    # logger.info('Plotting trajectory to {}'.format(save_traj_dir))
    # data_samples = inf_train_gen(sample_loader, batch_size=2000)
    # save_trajectory(model, data_samples, save_traj_dir, device=device)
    # trajectory_to_video(save_traj_dir)
Example #21
def model_sample(model, batch_size):
    z = torch.randn(batch_size, 1)
    logqz = standard_normal_logprob(z)
    x, logqx = model(z, logqz, reverse=True)
    return x, logqx
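A minimal usage sketch for the sampler above, assuming model is a flow that supports reverse=True as in the snippet:

# Draw samples from the flow and inspect their model log-density.
x_samp, logqx = model_sample(model, batch_size=512)
print(x_samp.shape, logqx.mean().item())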
Example #22
def model_density(x, model):
    x = x.to(device)
    z, delta_logp = model(x, torch.zeros_like(x))
    logpx = standard_normal_logprob(z) - delta_logp
    return logpx
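A usage sketch mirroring the commented-out model_density call in Example #1: evaluate and plot the learned 1-D density on a grid (the matplotlib/torch setup and the output filename are assumptions):

xx = torch.linspace(-10, 10, 1000).view(-1, 1)
logpx = model_density(xx, model)
plt.plot(xx.view(-1).cpu().numpy(), logpx.view(-1).exp().cpu().numpy())
plt.savefig('density.jpg')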
Example #23
    with torch.no_grad():

        if args.validate:
            cleanbpd = 0.
            dirtybpd = 0.
            for i, (x, y) in enumerate(test_loader):

                xdirty = add_noise(cvt(255 * x), nbits=args.nbits)
                xclean = shift(cvt(255 * x), nbits=args.nbits)

                # Dirty
                # -----
                zero = torch.zeros(xdirty.shape[0], 1).to(xdirty)
                z, delta_logp, _ = model(xdirty, zero)  # run model forward

                logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(
                    1, keepdim=True)  # logp(z)
                logpx = logpz - delta_logp

                logpx_per_dim = torch.sum(
                    logpx) / x.nelement()  # averaged over batches
                bits_per_dim = -(logpx_per_dim - np.log(nvals)) / np.log(2)
                dirtybpd = bits_per_dim.detach().cpu().item() / (i + 1) + i / (
                    i + 1) * dirtybpd

                # Clean
                # -----
                zero = torch.zeros(xclean.shape[0], 1).to(xclean)
                z, delta_logp, _ = model(xclean, zero)  # run model forward

                logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(
                    1, keepdim=True)  # logp(z)