def __init__(self, args):
    super(CNFVAE, self).__init__(args)

    # CNF model
    self.cnf = build_model_tabular(args, args.z_size)

    if args.cuda:
        self.cuda()
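# A minimal sketch (not part of the original class) of how a CNF prior such
# as self.cnf is evaluated elsewhere in this codebase: the flow maps z0 to
# zk while accumulating delta_logp, and log p(z0) follows from the
# change-of-variables formula. Assumes `standard_normal_logprob` is in
# scope and that the CNF follows the (x, logpx) -> (z, delta_logp)
# interface used throughout this repo.
def _example_cnf_prior_logprob(cnf, z0):
    zero = torch.zeros(z0.shape[0], 1).to(z0)
    zk, delta_logp = cnf(z0, zero)
    return standard_normal_logprob(zk).sum(1, keepdim=True) - delta_logp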
def visualize_evolution():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    viz_times_np = viz_times[1:].detach().cpu().numpy()
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1).to(device)
    xx_np = xx.detach().cpu().numpy()
    # np.meshgrid(a, b) returns arrays of shape (len(b), len(a)).
    xs, ys = np.meshgrid(xx_np, viz_times_np)

    # Density evolution: one column per intermediate integration time.
    all_evolutions = np.zeros((args.num_particles, args.ntimes - 1), dtype='float32')
    with torch.no_grad():
        model.eval()
        for i, t in enumerate(tqdm(viz_times[1:])):
            # Integrate from time 0 to t and evaluate the model density via
            # the change-of-variables formula (assumes a single-CNF chain).
            for cnf in model.chain:
                z, delta_logp = cnf(xx, torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp
            all_evolutions[:, i] = generated_p.view(-1).exp().cpu().numpy()

    plt.figure(dpi=1200)
    plt.clf()
    plt.pcolormesh(xs, ys, all_evolutions.transpose())
    utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
    plt.savefig(os.path.join(args.save, 'test_times', 'figs', 'evolution.jpg'))
    plt.close()
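# Shape check for the pcolormesh call in visualize_evolution above:
# np.meshgrid(a, b) returns arrays of shape (len(b), len(a)), so xs and ys
# are (ntimes - 1, num_particles) while all_evolutions is filled as
# (num_particles, ntimes - 1) and therefore needs the transpose. A
# self-contained illustration (hypothetical helper, numpy only):
def _example_meshgrid_shapes(num_particles=4, ntimes=3):
    xs, ys = np.meshgrid(np.zeros(num_particles), np.zeros(ntimes - 1))
    assert xs.shape == ys.shape == (ntimes - 1, num_particles)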
def visualize_particle_flow():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1).to(device)

    # Particle positions at each intermediate integration time.
    zs = []
    with torch.no_grad():
        model.eval()
        for t in tqdm(viz_times[1:]):
            # Push the particles from time 0 to t (assumes a single-CNF chain).
            for cnf in model.chain:
                z, delta_logp = cnf(xx, torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
            zs.append(z.cpu().numpy())

    zs = np.array(zs).reshape(args.ntimes - 1, args.num_particles)
    viz_t = viz_times[1:].numpy()

    plt.figure(dpi=1200)
    plt.clf()
    with sns.color_palette("Blues_d"):
        plt.plot(viz_t, zs)
    plt.xlabel("Test Time")
    utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
    plt.savefig(os.path.join(args.save, 'test_times', 'figs', 'particle_trajectory.jpg'))
    plt.close()
def gen_model(scale=10, fraction=0.5):
    # Rebuild the normalizing-flow model from a previous fit.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    with open('args.pkl', 'rb') as f:
        args = pkl.load(f)
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 5, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)
    model.load_state_dict(torch.load('model_10000.pt'))
    return model
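# Hypothetical usage of gen_model: rebuild the fitted 5-D flow and draw
# samples by running it in reverse from the standard-normal base. Assumes
# the FFJORD-style tabular model supports model(z, reverse=True) when no
# log-density is requested, as used elsewhere in this repo;
# `_example_sample_from_flow` is illustrative only.
def _example_sample_from_flow(n=1000):
    model = gen_model()
    model.eval()
    with torch.no_grad():
        z = torch.randn(n, 5).to(next(model.parameters()).device)
        x = model(z, reverse=True)
    return x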
def get_ckpt_model_and_data(args):
    # Load checkpoint.
    checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
    ckpt_args = checkpt['args']
    state_dict = checkpt['state_dict']

    # Construct model and restore checkpoint.
    regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
    model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
    if ckpt_args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(ckpt_args, model)
    model.load_state_dict(state_dict)
    model.to(device)
    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))

    # Load samples from dataset.
    data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

    return model, data_samples
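# Hypothetical usage of the loader above: restore the 2-D toy model and
# score the returned dataset samples under it via the change-of-variables
# formula. Assumes `device` and `standard_normal_logprob` are in scope, as
# elsewhere in this file.
def _example_restore_and_score(args):
    model, data_samples = get_ckpt_model_and_data(args)
    x = torch.from_numpy(data_samples).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)
    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - delta_logp
    return logpx.mean()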
def compute_loss(args, model, batch_size=args.batch_size):
    # Note: the default batch size is bound to the module-level `args`
    # at definition time.
    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    z, change = model(x, zero)
    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
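# The meters above feed a training loop of the kind shown in train() later
# in this section. A minimal sketch of one optimization step under that
# setup, assuming compute_loss and the Adam optimizer defined above:
def _example_train_step(args, model, optimizer):
    optimizer.zero_grad()
    loss = compute_loss(args, model)
    loss.backward()
    optimizer.step()
    return loss.item()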
if __name__ == '__main__':
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    logger.info('Using {} GPUs.'.format(torch.cuda.device_count()))

    data = load_data(args.data)
    data.trn.x = torch.from_numpy(data.trn.x)
    data.val.x = torch.from_numpy(data.val.x)
    data.tst.x = torch.from_numpy(data.tst.x)

    args.dims = '-'.join([str(args.hdim_factor * data.n_dims)] * args.nhidden)

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, data.n_dims, regularization_fns).to(device)
    set_cnf_options(args, model)

    for k in model.state_dict().keys():
        logger.info(k)

    if args.resume is not None:
        checkpt = torch.load(args.resume)

        # Backwards compatibility with an older version of the code.
        # TODO: remove upon release.
        filtered_state_dict = {}
        for k, v in checkpt['state_dict'].items():
            if 'diffeq.diffeq' not in k:
                filtered_state_dict[k.replace('module.', '')] = v
        model.load_state_dict(filtered_state_dict)
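# Illustration of the backwards-compatibility filter above: checkpoints
# saved from a DataParallel-wrapped model carry a 'module.' prefix on every
# key, and an older code version nested the dynamics as 'diffeq.diffeq'.
# The filter strips the former and drops the latter. The same logic as a
# standalone (hypothetical) helper:
def _example_filter_state_dict(state_dict):
    return {k.replace('module.', ''): v
            for k, v in state_dict.items()
            if 'diffeq.diffeq' not in k}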
if __name__ == '__main__':
    # Only a single block of diffeq is supported for now.
    assert args.num_blocks == 1

    centers = DEFAULT_CENTERS.to(device)
    dim = centers.shape[1]
    convection = lambda x: -gaussian_mixture_score(x, centers)
    writer = SummaryWriter('out/wgf/gaussian')

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    # Regularization is disabled for this experiment.
    regularization_fns = None
    model = build_model_tabular(args=args,
                                dims=dim,
                                convection=convection,
                                regularization_fns=regularization_fns,
                                exp_decay=args.exp_decay).to(device)
    model_validate_dvp = build_model_compare_DVP(
        args=args,
        convection=convection,
        mollifier=IsometricGaussianMollifier(args.mollifier_sigma_square),
        diffeq=model.chain[0].odefunc.diffeq).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

# logger.info('Using {} GPUs.'.format(torch.cuda.device_count()))

data = load_data(args.data)
data.trn.x = torch.from_numpy(data.trn.x)
data.val.x = torch.from_numpy(data.val.x)
data.tst.x = torch.from_numpy(data.tst.x)

args.dims = '-'.join([str(args.hdim_factor * data.n_dims)] * args.nhidden)

regularization_fns, regularization_coeffs = create_regularization_fns(args)
model, cnfs = build_model_tabular(
    args,
    data.n_dims,
    regularization_fns,
    return_intermediate_points=args.return_inter_points)
model = model.to(device)
set_cnf_options(args, model)

if args.resume is not None:
    checkpt = torch.load(args.resume)

    # Backwards compatibility with an older version of the code.
    # TODO: remove upon release.
    filtered_state_dict = {}
    for k, v in checkpt['state_dict'].items():
        if 'diffeq.diffeq' not in k:
            filtered_state_dict[k.replace('module.', '')] = v
    model.load_state_dict(filtered_state_dict)
def train():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
                tt_meter.val, tt_meter.avg))
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args, model, batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='True')
                model_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         model_p.view(-1).exp().cpu().numpy(),
                         label='Model')
                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()
                model.train()

        end = time.time()

    logger.info('Training has finished.')
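# Note on the NFE bookkeeping in train() above: count_nfe(model) is read
# once right after the forward pass (forward function evaluations only) and
# again after loss.backward(), which adds the adjoint/backward evaluations,
# so the difference is attributed to the backward pass. This assumes the
# counter accumulates across both passes within an iteration, which is what
# the code above relies on.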
    def density_fn(x, logpx=None):
        if logpx is not None:
            return model(x, logpx, reverse=False)
        else:
            return model(x, reverse=False)

    return sample_fn, density_fn


if __name__ == '__main__':
    if args.discrete:
        model = construct_discrete_model().to(device)
        model.load_state_dict(torch.load(args.checkpt)['state_dict'])
    else:
        model = build_model_tabular(args, 2).to(device)

        sd = torch.load(args.checkpt)['state_dict']
        fixed_sd = {}
        for k, v in sd.items():
            fixed_sd[k.replace('odefunc.odefunc', 'odefunc')] = v
        model.load_state_dict(fixed_sd)

    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))

    model.eval()
    p_samples = toy_data.inf_train_gen(args.data, batch_size=800**2)

    with torch.no_grad():
        sample_fn, density_fn = get_transforms(model)
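# Hypothetical round trip through the transforms returned by
# get_transforms, assuming sample_fn is the reverse-direction counterpart
# of density_fn (base z -> data x) when called without a log-density;
# `_example_round_trip` is illustrative only.
def _example_round_trip(sample_fn, density_fn, n=512):
    z = torch.randn(n, 2).to(device)
    x = sample_fn(z)
    z_rec = density_fn(x)
    # Up to ODE-solver tolerance, z_rec should match z.
    return (z - z_rec).abs().max()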
    # Transform x to z.
    z, delta_logp = model(x, zero)

    # Compute log q(z).
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)

    logpx = logpz - delta_logp
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    input_dim = 2 if args.data != 'HDline' else 16
    model = build_model_tabular(args, input_dim, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
def main(args):
    # Logger.
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        displaying=not args.no_display_loss,  # `~` on a bool is bitwise, not logical negation
    )
    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0
    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times,
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(
        device
    )

    growth_model = None  # avoids a NameError below when growth is unused
    if args.use_growth:
        growth_model_path = None
        if args.leaveout_timepoint == -1:
            growth_model_path = "../data/externel/growth_model_v2.ckpt"
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = "../data/growth/model_%d" % args.leaveout_timepoint
        else:
            print("WARNING: Cannot use growth with this timepoint")
        if growth_model_path is not None:
            growth_model = torch.load(growth_model_path, map_location=device)

    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
        # TODO: can we load the arguments from the save?
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))
        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)
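# Worked example of the integration-time grid built in main() above: with
# observed timepoints [0, 1, 2, 3] and time_scale = 0.5,
# int_tps = (np.arange(4) + 1.0) * 0.5 = [0.5, 1.0, 1.5, 2.0], one
# integration endpoint per observed timepoint, offset from zero
# (presumably so that t = 0 can serve as the base-distribution time).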
def main(args):
    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args)

    args.timepoints = data.get_unique_times()
    # Use maximum timepoint to establish integration_times,
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(device)
    if args.use_growth:
        growth_model_path = data.get_growth_net_path()
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    args.data = data
    args.timepoints = args.data.get_unique_times()
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    print('integrating backwards')
    # Start from the cells observed at the final timepoint and integrate
    # the flow backwards in time.
    end_time_data = data.get_data()[
        args.data.get_times() == np.max(args.data.get_times())]
    integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device)
    exit()

    # NOTE: the evaluation below is unreachable unless the exit() above is removed.
    losses_list = []
    if args.dataset == 'CHAFFER':
        # Adjust the left-out integration time by interpolating between its
        # neighbouring timepoints.
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.95  # overrides an earlier fitted value of 0.6799872494335812
        elif lt == 2:
            factor = 0.01  # overrides an earlier fitted value of 0.2905983814032348
        else:
            raise RuntimeError('Unknown timepoint %d' % args.leaveout_timepoint)
        args.int_tps[lt] = (
            1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))
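# Worked example of the timepoint adjustment above: with
# int_tps = [1.0, 2.0, 3.0], lt = 1 and factor = 0.95, the left-out time
# becomes (1 - 0.95) * 1.0 + 0.95 * 3.0 = 2.9, i.e. it is pulled almost all
# the way to the later neighbour.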