def visualize_times():
    """Plot the model's 1-D density at a sequence of integration times and
    stitch the saved frames into a video.

    Relies on module-level ``args``, ``device`` and helpers
    (``build_model_tabular``, ``set_cnf_options``, ``standard_normal_logprob``,
    ``utils.makedirs``, ``trajectory_to_video``); reads the trained weights
    from the checkpoint under ``args.save``.
    """
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)
    # Restore trained weights from the saved checkpoint.
    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)
    # Times to visualize; t=0 is skipped below (flow is the identity there).
    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []  # NOTE(review): never populated — appears vestigial.
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            # Dense 1-D grid of evaluation points.
            xx = torch.linspace(-10, 10, 10000).view(-1, 1)
            #generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                # Integrate from time 0 to t; delta_logp accumulates the
                # change in log-density along the flow.
                z, delta_logp = cnf(xx, torch.zeros_like(xx), integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp
            plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')
            utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
            # One numbered frame per visualization time.
            plt.savefig(
                os.path.join(args.save, 'test_times', 'figs', '{:04d}.jpg'.format(i)))
            plt.close()
    # Assemble the saved frames into a video.
    trajectory_to_video(os.path.join(args.save, 'test_times', 'figs'))
def compute_loss(x, model):
    """Negative mean log-likelihood of ``x`` under the flow ``model``.

    Runs the model forward to obtain latents ``z`` and the accumulated
    change in log-density ``delta_logp``, then scores ``z`` under a
    standard-normal prior.
    """
    logp_init = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, logp_init)  # run model forward
    # logp(z) under the standard-normal prior, summed over dimensions.
    prior_ll = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)
    log_px = prior_ll - delta_logp
    return -log_px.mean()
def compute_loss(x, model):
    """Negative log-likelihood loss for a flow whose forward pass returns
    the change in log-density directly as ``change``."""
    log_det_init = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, log_det_init)  # run model forward
    # Score latents under the standard-normal prior, then undo the flow's
    # density change to get log p(x).
    prior_ll = standard_normal_logprob(z).view(z.shape[0], -1)
    logpx = prior_ll.sum(1, keepdim=True) - change
    return -torch.mean(logpx)
def compute_loss(x, model):
    """Negative-log-likelihood loss with an optional ``lec`` accumulator.

    ``lec`` is threaded through the model only when ``args.poly_coef`` is
    set and the model is in training mode; otherwise it stays ``None``.

    Returns:
        (loss, lec): scalar NLL and the (possibly None) lec value returned
        by the model.
    """
    zero = torch.zeros(x.shape[0], 1).to(x)
    track_lec = args.poly_coef is not None and model.training
    lec = torch.tensor(0.0).to(x) if track_lec else None
    z, delta_logp, lec = model(x, zero, lec)
    # logp(z) under the standard-normal prior, summed over dimensions.
    logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)
    loss = -torch.mean(logpz - delta_logp)
    return loss, lec
def compute_loss(args, model, batch_size=None):
    """Draw a fresh toy-data batch and return its mean NLL under ``model``.

    Args:
        args: experiment configuration; provides ``data`` and ``batch_size``.
        model: flow returning ``(z, change)`` where ``change`` is the
            accumulated change in log-density.
        batch_size: number of samples to draw; defaults to
            ``args.batch_size`` resolved at call time. (Fix: the original
            default ``batch_size=args.batch_size`` was evaluated once at
            definition time, freezing the value and requiring ``args`` to
            exist at import — this matches the ``None``-default convention
            used by the other ``compute_loss`` variants in this file.)

    Returns:
        Scalar tensor: negative mean log-likelihood.
    """
    if batch_size is None:
        batch_size = args.batch_size
    # Sample a batch from the toy data generator.
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, change = model(x, zero)
    # log p(x) = log p(z) - change, with a standard-normal prior on z.
    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss
def visualize_evolution():
    """Render a (position x time) heat map of the model density as it
    evolves over the integration interval, saved as a single image.

    Relies on module-level ``args``, ``device`` and helpers; loads trained
    weights from the checkpoint under ``args.save``.
    """
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)
    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)
    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []  # NOTE(review): never populated — appears vestigial.
    viz_times_np = viz_times[1:].detach().cpu().numpy()
    # 1-D grid of positions at which the density is evaluated.
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    xx_np = xx.detach().cpu().numpy()
    # Mesh for pcolormesh: xs varies with position, ys with time.
    xs, ys = np.meshgrid(xx, viz_times_np)
    #xx,yy = np.meshgrid(args.num_particles, viz_times_np )
    #all_evolutions = np.zeros((args.ntimes-1,args.num_particles))
    # One column of density values per visualization time.
    all_evolutions = np.zeros((args.num_particles, args.ntimes - 1))
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            #xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
            #generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                # Integrate from 0 to t; delta_logp is the log-density change.
                z, delta_logp = cnf(xx, torch.zeros_like(xx), integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp
            generated_p = generated_p.detach()
            #plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')
            # Density profile (probabilities, not log) at time t.
            cur_evolution = generated_p.view(-1).exp().cpu().numpy()
            #all_evolutions[i]= np.array(cur_evolution)
            all_evolutions[:, i] = np.array(cur_evolution)
    #xx = np.array(xx.detach().cpu().numpy())
    #yy = np.array(yy)
    plt.figure(dpi=1200)
    plt.clf()
    all_evolutions = all_evolutions.astype('float32')
    print(xs.shape)
    print(ys.shape)
    print(all_evolutions.shape)
    #plt.pcolormesh(ys, xs, all_evolutions)
    plt.pcolormesh(xs, ys, all_evolutions.transpose())
    utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
    # NOTE(review): '.format(i)' is a no-op here — the filename has no
    # placeholder; also uses the loop variable after the loop ends.
    plt.savefig(
        os.path.join(args.save, 'test_times', 'figs', 'evolution.jpg'.format(i)))
    plt.close()
def visualize_particle_flow():
    """Plot the trajectories of a grid of particles through the flow over
    test time, saved as a single figure.

    Relies on module-level ``args``, ``device``, ``sns`` and helpers;
    loads trained weights from the checkpoint under ``args.save``.
    """
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)
    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)
    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []  # NOTE(review): never populated — appears vestigial.
    # Initial particle positions on a 1-D grid.
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    zs = []
    #zs.append(xx.view(-1).cpu().numpy())
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)
            #generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                # Integrate from 0 to t; z holds particle positions at t.
                z, delta_logp = cnf(xx, torch.zeros_like(xx), integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp
            # Record the particle positions at this time step.
            zs.append(z.cpu().numpy())
            #plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')
            #plt.savefig(os.path.join(args.save,'test_times', 'figs', '{:04d}.jpg'.format(i)))
            #plt.close()
    # Shape (ntimes-1, num_particles): one row of positions per time step.
    zs = np.array(zs).reshape(args.ntimes - 1, args.num_particles)
    viz_t = viz_times[1:].numpy()
    #print(zs)
    plt.figure(dpi=1200)
    plt.clf()
    #plt.plot(viz_t , zs[:,0])
    with sns.color_palette("Blues_d"):
        plt.plot(viz_t, zs)
    plt.xlabel("Test Time")
    #plt.tight_layout()
    utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
    # NOTE(review): '.format(i)' is a no-op here — the filename has no
    # placeholder; also uses the loop variable after the loop ends.
    plt.savefig(
        os.path.join(args.save, 'test_times', 'figs', 'particle_trajectory.jpg'.format(i)))
    plt.close()
def compute_bits_per_dim(x, model):
    """Bits/dim of batch ``x`` plus batch-averaged regularization states.

    Returns:
        (bits_per_dim, (x, z), reg_states) — the scalar bits/dim, the
        input/latent pair, and a tuple of mean regularization values.
    """
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp, reg_states = model(x, zero)  # run model forward
    # Average each regularization accumulator over the batch.
    reg_states = tuple(rs.mean() for rs in reg_states)
    # logp(z) under the standard-normal prior, summed over dimensions.
    prior_ll = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)
    logpx = prior_ll - delta_logp
    # Average over all elements, then convert nats to bits per dimension.
    avg_logpx = torch.sum(logpx) / x.nelement()
    bits_per_dim = -(avg_logpx - np.log(nvals)) / np.log(2)
    return bits_per_dim, (x, z), reg_states
def calc_logpx(model, x):
    """Per-sample log-likelihood log p(x) of ``x`` under the flow ``model``."""
    init_logp = torch.zeros(x.shape[0], 1).to(x)
    # Transform x into latent space, tracking the log-density change.
    z, delta_logp = model(x, init_logp)
    # Score latents under the standard-normal base distribution.
    base_logp = standard_normal_logprob(z).sum(1, keepdim=True)
    return base_logp - delta_logp
def compute_loss(args, model, batch_size=None):
    """Sample a toy-data batch and return its mean NLL under ``model``."""
    batch_size = args.batch_size if batch_size is None else batch_size
    # Draw a fresh batch from the toy data generator.
    samples = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(samples).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    # Transform to latent space, tracking the log-density change.
    z, delta_logp = model(x, zero)
    # Score latents under the standard-normal base distribution.
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    return -torch.mean(logpz - delta_logp)
def compute_bits_per_dim(x, model):
    """Bits per dimension of image batch ``x`` (assumes 256 intensity levels)."""
    zero = torch.zeros(x.shape[0], 1).to(x)
    # Don't use data parallelize if batch size is small.
    # if x.shape[0] < 200:
    #     model = model.module
    z, delta_logp = model(x, zero)  # run model forward
    # logp(z) under the standard-normal prior, summed over dimensions.
    prior_ll = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)
    logpx = prior_ll - delta_logp
    mean_logpx = torch.sum(logpx) / x.nelement()  # averaged over batches
    # Convert nats to bits and account for the 256 discrete levels.
    return -(mean_logpx - np.log(256)) / np.log(2)
def my_compute_loss(dataLoaderIt, model):
    """Pull one ``(truth, reco)`` batch from the iterator and return the
    mean NLL of the reco samples under ``model``."""
    # load data
    truth, reco = next(dataLoaderIt)
    x = reco
    # I think this is the prior
    zero = torch.zeros(x.shape[0], 1).to(x)
    # transform to z
    z, delta_logp = model(x, zero)
    # compute log q(z) This really should be the prior
    #logpz = truth
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    return -torch.mean(logpz - delta_logp)
def compute_loss(args, model, data, batch_size=None, end_times=None):
    """Sample a batch from ``data`` and return its mean NLL under ``model``.

    Args:
        args: configuration providing the default ``batch_size``.
        model: flow called as ``model(x, zero, integration_times=...)``.
        data: dataset object understood by ``sample_data``.
        batch_size: batch size; defaults to ``args.batch_size``.
        end_times: optional integration times forwarded to the model.
            Fix: the original body referenced an undefined name
            ``integration_times`` (a NameError unless a same-named module
            global happened to exist — TODO confirm) while silently
            ignoring this parameter; it is now forwarded.

    Returns:
        Scalar tensor: negative mean log-likelihood.
    """
    if batch_size is None:
        batch_size = args.batch_size
    # load data
    x = sample_data(data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    # transform to z over the requested integration interval
    z, delta_logp = model(x, zero, integration_times=end_times)
    # compute log q(z) under the standard-normal base distribution
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss
def compute_loss_wgf(args, model, dim, batch_size=None):
    """Return the Wasserstein-gradient-flow regularizer averaged over the
    number of function evaluations used by the ODE solver."""
    if batch_size is None:
        batch_size = args.batch_size
    # Base samples with their standard-normal log-density and score.
    base = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
    base_logp = standard_normal_logprob(base).sum(1, keepdim=True).to(base)
    base_score = standard_normal_score(base).to(base)
    reg_init = torch.tensor(0, device=device)
    # Forward pass accumulates the WGF regularizer alongside the flow.
    _, _, _, accumulated_reg = model(base, logpz=base_logp, score=base_score, wgf_reg=reg_init)
    # Normalize by the solver's function-evaluation count.
    return accumulated_reg / count_nfe(model)
def compute_kl_divergence(args, model, batch_size=None, dim=2):
    """Monte-Carlo estimate of KL(model || target Gaussian).

    Samples z ~ N(0, I), pushes them through the flow to obtain x with
    model log-density ``logp_x``, and compares against the target
    Gaussian log-density.

    Args:
        args: configuration providing the default ``batch_size``.
        model: flow called with ``(z, logpz=..., score=..., wgf_reg=...)``.
        batch_size: number of Monte-Carlo samples; defaults to
            ``args.batch_size``.
        dim: data dimensionality. Resolves the original TODO — the
            dimension was hard-coded to 2; it is now a backward-compatible
            parameter defaulting to 2.

    Returns:
        Scalar tensor: mean of ``logp_x - logp_true_x`` over the batch.
    """
    if batch_size is None:
        batch_size = args.batch_size
    z = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    # x, logp_x, score_x, wgf_reg = model(z, logp_z, score_z, wgf_reg_0)
    x, logp_x, score_x, wgf_reg = model(z, logpz=logp_z, score=score_z, wgf_reg=wgf_reg_0)
    # logp_true_x = gaussian_mixture_logprob(x)  # alternative target
    logp_true_x = gaussian_logprob(x).sum(1, keepdim=True).to(z)
    # print(torch.mean(x, 0))
    return torch.mean(logp_x - logp_true_x)
def compute_bits_per_dim(x, model):
    """Bits/dim of image batch ``x`` plus the optional ``lec`` regularizer.

    ``lec`` is only accumulated when ``args.poly_coef`` is set and the
    model is training; otherwise it stays ``None``.

    Returns:
        (bits_per_dim, lec): the scalar bits/dim, and ``lec`` rescaled to
        bits per input dimension (or None when not tracked).
    """
    zero = torch.zeros(x.shape[0], 1).to(x)
    lec = None if (args.poly_coef is None or not model.training) else torch.tensor(0.0).to(x)
    # Don't use data parallelize if batch size is small.
    # if x.shape[0] < 200:
    #     model = model.module
    z, delta_logp, lec = model(x, zero, lec)
    logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum(1, keepdim=True)  # logp(z)
    logpx = logpz - delta_logp
    logpx_per_dim = torch.sum(logpx) / x.nelement()  # averaged over batches
    bits_per_dim = -(logpx_per_dim - np.log(256)) / np.log(2)
    # Fix: the original guard was ``if lec``, which takes the truth value
    # of a tensor — wrongly mapping an accumulated value of exactly 0.0 to
    # None and raising for non-scalar tensors. Test identity instead.
    lec = lec / (x[0].nelement() * np.log(2)) if lec is not None else None
    return bits_per_dim, lec
def score_error_wgf(args, model, batch_size=None, dim=2):
    """Average accumulated score error per ODE function evaluation.

    Args:
        args: configuration providing the default ``batch_size``.
        model: flow called with ``(z, logpz=..., score=..., wgf_reg=...,
            mu_0=..., sigma_half_0=..., score_error_0=...)``.
        batch_size: number of samples; defaults to ``args.batch_size``.
        dim: data dimensionality. Resolves the original TODO — the
            dimension was hard-coded to 2; it is now a backward-compatible
            parameter defaulting to 2, and sizes the moment accumulators.

    Returns:
        Accumulated score error divided by the solver's NFE count.
    """
    if batch_size is None:
        batch_size = args.batch_size
    z = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    # Initial moment / error accumulators threaded through the model.
    mu_0 = torch.zeros(dim, dtype=torch.float32, device=device)
    sigma_half_0 = torch.eye(dim, dtype=torch.float32, device=device)
    score_error_0 = torch.zeros(1, dtype=torch.float32, device=device)
    # x, logp_x, score_x, wgf_reg = model(z, logp_z, score_z, wgf_reg_0)
    x, logp_x, score_x, wgf_reg, mu, sigma_half, score_error = \
        model(z, logpz=logp_z, score=score_z, wgf_reg=wgf_reg_0,
              mu_0=mu_0, sigma_half_0=sigma_half_0, score_error_0=score_error_0)
    nfe = count_nfe(model)
    return score_error / nfe
def compute_loss(args, model):
    """ Compute loss by integrating backwards from the last time step
    At each time step integrate back one time step, and concatenate that
    to samples of the empirical distribution at that previous timestep
    repeating over and over to calculate the likelihood of samples in
    later timepoints iteratively, making sure that the ODE is evaluated
    at every time step to calculate those later points.

    Relies on module-level ``int_tps``, ``timepoints``, ``train_sampler``
    and ``device``.
    """
    # Backward pass: accumulate log-density changes per time step.
    deltas = []
    for i, (itp, tp) in enumerate(zip(int_tps[::-1], timepoints[::-1])):
        # tp counts down from last
        integration_times = torch.tensor([itp-args.time_length, itp]).type(torch.float32).to(device)
        print(integration_times)
        # load data
        x = train_sampler(tp)
        x = torch.from_numpy(x).type(torch.float32).to(device)
        if i > 0:
            # Prepend the back-integrated state from the later timepoint so
            # its likelihood keeps propagating backwards through the flow.
            x = torch.cat((z, x))
        zero = torch.zeros(x.shape[0], 1).to(x)
        # transform to previous timepoint
        z, delta_logp = model(x, zero, integration_times=integration_times)
        deltas.append(delta_logp)
    # compute log q(z) at the earliest timepoint under the standard normal.
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    logps = [logpz]
    # Forward pass over the stored deltas: split each logpx into the
    # carried-forward rows and the last ``batch_size`` rows (the real
    # samples for that timepoint), which contribute to the loss.
    losses = []
    for delta_logp in deltas[::-1]:
        logpx = logps[-1] - delta_logp
        logps.append(logpx[:-args.batch_size])
        losses.append(-torch.mean(logpx[-args.batch_size:]))
    #weights = torch.tensor([0,0,0,0,1]).to(logpx)
    #weights = torch.tensor([1,0,0,0,0]).to(logpx)
    #weights = torch.tensor([1,1,1,1,1]).to(logpx)
    # NOTE(review): hard-coded per-timepoint weights — assumes exactly
    # three timepoints; confirm against the dataset in use.
    weights = torch.tensor([3,2,1]).to(logpx)
    loss = torch.sum(torch.stack(losses) * weights)
    return loss
def compute_loss(args, model, growth_model):
    """ Compute loss by integrating backwards from the last time step
    At each time step integrate back one time step, and concatenate that
    to samples of the empirical distribution at that previous timestep
    repeating over and over to calculate the likelihood of samples in
    later timepoints iteratively, making sure that the ODE is evaluated
    at every time step to calculate those later points.

    The growth model is a single model of time independent cell growth /
    death rate defined as a variation from uniform.

    Relies on module-level ``int_tps``, ``timepoints``, ``train_sampler``,
    ``dir_train_sampler`` and ``device``.
    """
    # Backward pass accumulating losses, previous state and deltas
    deltas = []
    xs = []
    zs = []
    for i, (itp, tp) in enumerate(zip(int_tps[::-1], timepoints[::-1])):
        # tp counts down from last
        integration_times = torch.tensor([itp-args.time_length, itp]).type(torch.float32).to(device)
        # load data
        x = train_sampler(tp)
        x = torch.from_numpy(x).type(torch.float32).to(device)
        xs.append(x)
        if i > 0:
            # Prepend the back-integrated state from the later timepoint so
            # its likelihood keeps propagating backwards.
            x = torch.cat((z, x))
            zs.append(z)
        zero = torch.zeros(x.shape[0], 1).to(x)
        # transform to previous timepoint
        z, delta_logp = model(x, zero, integration_times=integration_times)
        deltas.append(delta_logp)
    # compute log growth probability
    xs = torch.cat(xs)
    #growth_zs, growth_delta_logps = growth_model(xs, torch.zeros(xs.shape[0], 1).to(xs)) # Use default timestep
    #growth_logpzs = uniform_logprob(growth_zs).sum(1, keepdim=True)
    #growth_logpzs = standard_normal_logprob(growth_zs).sum(1, keepdim=True)
    #growth_logpxs = growth_logpzs - growth_delta_logps
    # compute log q(z) with forward pass
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    logps = [logpz]
    # build growth rates: identity (rate 1) for the earliest timepoint,
    # then the growth model evaluated on [state, time] for later ones.
    growthrates = [torch.ones_like(logpz)]
    for z_state, tp in zip(zs[::-1], timepoints[::-1][1:]):
        full_state = torch.cat([z_state, tp * torch.ones(z_state.shape[0],1).to(z_state)], 1)
        growthrates.append(growth_model(full_state))
    # Forward pass: per-timepoint NLL on the last ``batch_size`` rows (the
    # real samples), carrying the remaining rows forward.
    losses = []
    for gr, delta_logp in zip(growthrates, deltas[::-1]):
        #logpx = logps[-1] - delta_logp# + gr
        logpx = logps[-1] - delta_logp + torch.log(gr)
        logps.append(logpx[:-args.batch_size])
        losses.append(-torch.mean(logpx[-args.batch_size:]))
    #weights = torch.tensor([1,1,10]).to(logpx)
    #weights = torch.tensor([2,1]).to(logpx)
    losses = torch.stack(losses)
    # Uniform weights across timepoints.
    weights = torch.ones_like(losses).to(logpx)
    losses = torch.mean(losses * weights)
    #losses = torch.mean(losses)
    # Add a hinge loss on the growth model so that we prefer sums over the batch
    # to be not too much more than 1 on average
    # NOTE(review): ``reg`` is computed but never added to the loss (the
    # ``losses += 3*reg`` line is commented out) — confirm intentional.
    reg = 0.
    for gr in growthrates[1:]:
        reg += F.relu(torch.mean(gr[-1000:]))  # Only put a loss on the last portion with real data
        #reg += F.relu(torch.mean(gr[-1000:]) - 1) # Only put a loss on the last portion with real data
    #mean_growthrate = torch.mean(torch.cat(growthrates[1:]))
    #reg = F.relu(mean_growthrate - 1)
    #print(reg.item())
    #losses += 3*reg
    #losses += 0.001 * torch.mean(gr[-1000:] ** 2)
    # Direction regularization: penalize misalignment between the learned
    # vector field and observed directions, weighted by args.vecint.
    if args.vecint:
        similarity_loss = 0
        for i, (itp, tp) in enumerate(zip(int_tps, timepoints)):
            itp = torch.tensor(itp).type(torch.float32).to(device)
            x = dir_train_sampler(tp)
            x = torch.from_numpy(x).type(torch.float32).to(device)
            # Split sampled rows into positions y and direction targets zz.
            y,zz = torch.split(x, 2, dim=1)
            y = y + torch.randn_like(y) * 0.1
            # This is really hacky but I don't know a better way (alex)
            direction = model.chain[0].odefunc.odefunc.diffeq(itp, y)
            similarity_loss += 1 - torch.mean(F.cosine_similarity(direction, zz))
        print(similarity_loss)
        losses += similarity_loss * args.vecint
    #loss = loss + vec_reg_loss
    #growth_losses = -torch.mean(growth_logpxs)
    #alpha = torch.tensor(args.alpha).to(growth_losses)
    #loss = (1 - alpha) * losses + alpha * growth_losses
    #loss = losses + growth_losses
    return losses#, growth_losses
torch.save( { 'args': args, 'state_dict': model.state_dict(), }, os.path.join(args.save, 'checkpt.pth')) model.train() if itr % args.viz_freq == 0: with torch.no_grad(): model.eval() p_gtr = inf_train_gen(sample_loader, batch_size=2000).float().to(device) sample_fn, density_fn = get_transforms(model) prior = torch.randn_like(p_gtr).float().to(device) p_samples = sample_fn(prior) p_density = standard_normal_logprob(density_fn(p_gtr)).sum( 1, keepdim=True) buf = visualize_point_clouds(p_samples, p_gtr, name=args.data) writer.add_image('samples', buf, itr) model.train() end = time.time() logger.info('Training has finished.') # save_traj_dir = os.path.join(args.save, 'trajectory') # logger.info('Plotting trajectory to {}'.format(save_traj_dir)) # data_samples = inf_train_gen(sample_loader, batch_size=2000) # save_trajectory(model, data_samples, save_traj_dir, device=device) # trajectory_to_video(save_traj_dir)
def model_sample(model, batch_size):
    """Draw ``batch_size`` samples by pushing standard-normal noise through
    the flow in reverse; returns the samples and their log-density."""
    noise = torch.randn(batch_size, 1)
    noise_logq = standard_normal_logprob(noise)
    # Reverse pass maps base noise to data space.
    samples, samples_logq = model(noise, noise_logq, reverse=True)
    return samples, samples_logq
def model_density(x, model):
    """Log-density log p(x) of ``x`` under the flow ``model``."""
    x = x.to(device)
    # Forward pass yields latents and the accumulated log-density change.
    z, delta_logp = model(x, torch.zeros_like(x))
    return standard_normal_logprob(z) - delta_logp
with torch.no_grad(): if args.validate: cleanbpd = 0. dirtybpd = 0. for i, (x, y) in enumerate(test_loader): xdirty = add_noise(cvt(255 * x), nbits=args.nbits) xclean = shift(cvt(255 * x), nbits=args.nbits) # Dirty # ----- zero = torch.zeros(xdirty.shape[0], 1).to(xdirty) z, delta_logp, _ = model(xdirty, zero) # run model forward logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum( 1, keepdim=True) # logp(z) logpx = logpz - delta_logp logpx_per_dim = torch.sum( logpx) / x.nelement() # averaged over batches bits_per_dim = -(logpx_per_dim - np.log(nvals)) / np.log(2) dirtybpd = bits_per_dim.detach().cpu().item() / (i + 1) + i / ( i + 1) * dirtybpd # Clean # ----- zero = torch.zeros(xclean.shape[0], 1).to(xclean) z, delta_logp, _ = model(xclean, zero) # run model forward logpz = standard_normal_logprob(z).view(z.shape[0], -1).sum( 1, keepdim=True) # logp(z)