def evaluate():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    tols = [1e-1, 1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
    errors = []
    with torch.no_grad():
        for tol in tols:
            args.rtol = tol
            args.atol = tol
            set_cnf_options(args, model)

            xx = torch.linspace(-15, 15, 500000).view(-1, 1).to(device)
            prob_xx = model_density(xx, model).double().view(-1).cpu()
            xx = xx.double().cpu().view(-1)
            dxx = torch.log(xx[1:] - xx[:-1])
            # Integrate p(x) numerically in log space and record |1 - integral|.
            num_integral = torch.logsumexp(prob_xx[:-1] + dxx, 0).exp()
            errors.append(float(torch.abs(num_integral - 1.)))
            print(errors[-1])

    plt.figure(figsize=(5, 3))
    plt.plot(tols, errors, linewidth=3, marker='o', markersize=7)
    # plt.plot([-1, 0.2], [-1, 0.2], '--', color='grey', linewidth=1)
    plt.xscale("log", nonposx='clip')
    # plt.yscale("log", nonposy='clip')
    plt.xlabel('Solver Tolerance', fontsize=17)
    plt.ylabel(r'$| 1 - \int p(x) |$', fontsize=17)
    plt.tight_layout()
    plt.savefig('ode_solver_error_vs_tol.pdf')
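# A minimal, self-contained sketch of the log-domain Riemann-sum check performed
# in evaluate() above, with a standard normal log-density standing in for the
# learned CNF (model_density is not defined in this file). Since
# logsumexp_i(log p(x_i) + log dx_i) approximates log of the integral of p(x),
# the returned error should be near 0 for a normalized density. The helper name
# is illustrative, not part of the original script.
def _integration_error_demo(lo=-15., hi=15., n=500000):
    xx = torch.linspace(lo, hi, n).double()
    logp = -0.5 * xx.pow(2) - 0.5 * float(np.log(2 * np.pi))  # standard normal log-density
    log_dx = torch.log(xx[1:] - xx[:-1])
    integral = torch.logsumexp(logp[:-1] + log_dx, dim=0).exp()
    return float(torch.abs(integral - 1.))  # ~0 up to discretization error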
def get_ckpt_model_and_data(args):
    # Load checkpoint.
    checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
    ckpt_args = checkpt['args']
    state_dict = checkpt['state_dict']

    # Construct model and restore checkpoint.
    regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
    model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
    if ckpt_args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(ckpt_args, model)

    model.load_state_dict(state_dict)
    model.to(device)

    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))

    # Load samples from dataset
    data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

    return model, data_samples
def visualize_times():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)

            xx = torch.linspace(-10, 10, 10000).view(-1, 1)
            # generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx, torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp

            plt.plot(xx.view(-1).cpu().numpy(),
                     generated_p.view(-1).exp().cpu().numpy(),
                     label='Model')
            utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
            plt.savefig(os.path.join(args.save, 'test_times', 'figs', '{:04d}.jpg'.format(i)))
            plt.close()

    trajectory_to_video(os.path.join(args.save, 'test_times', 'figs'))
def visualize_evolution():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    viz_times_np = viz_times[1:].detach().cpu().numpy()
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    xx_np = xx.detach().cpu().numpy()
    xs, ys = np.meshgrid(xx, viz_times_np)
    # xx, yy = np.meshgrid(args.num_particles, viz_times_np)
    # all_evolutions = np.zeros((args.ntimes - 1, args.num_particles))
    all_evolutions = np.zeros((args.num_particles, args.ntimes - 1))
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)

            # xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
            # generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx, torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp
            generated_p = generated_p.detach()

            # plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')
            cur_evolution = generated_p.view(-1).exp().cpu().numpy()
            # all_evolutions[i] = np.array(cur_evolution)
            all_evolutions[:, i] = np.array(cur_evolution)

    # xx = np.array(xx.detach().cpu().numpy())
    # yy = np.array(yy)
    plt.figure(dpi=1200)
    plt.clf()
    all_evolutions = all_evolutions.astype('float32')
    print(xs.shape)
    print(ys.shape)
    print(all_evolutions.shape)
    # plt.pcolormesh(ys, xs, all_evolutions)
    plt.pcolormesh(xs, ys, all_evolutions.transpose())
    utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
    plt.savefig(os.path.join(args.save, 'test_times', 'figs', 'evolution.jpg'))
    plt.close()
def visualize_particle_flow():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    checkpt = torch.load(os.path.join(args.save, 'checkpt.pth'))
    model.load_state_dict(checkpt['state_dict'])
    model.to(device)

    viz_times = torch.linspace(0., args.time_length, args.ntimes)
    errors = []
    xx = torch.linspace(-5, 5, args.num_particles).view(-1, 1)
    zs = []
    # zs.append(xx.view(-1).cpu().numpy())
    with torch.no_grad():
        for i, t in enumerate(tqdm(viz_times[1:])):
            model.eval()
            set_cnf_options(args, model)

            # generated_p = model_density(xx, model)
            generated_p = 0
            for cnf in model.chain:
                xx = xx.to(device)
                z, delta_logp = cnf(xx, torch.zeros_like(xx),
                                    integration_times=torch.Tensor([0, t]))
                generated_p = standard_normal_logprob(z) - delta_logp
            zs.append(z.cpu().numpy())

            # plt.plot(xx.view(-1).cpu().numpy(), generated_p.view(-1).exp().cpu().numpy(), label='Model')
            # plt.savefig(os.path.join(args.save, 'test_times', 'figs', '{:04d}.jpg'.format(i)))
            # plt.close()

    zs = np.array(zs).reshape(args.ntimes - 1, args.num_particles)
    viz_t = viz_times[1:].numpy()
    # print(zs)
    plt.figure(dpi=1200)
    plt.clf()
    # plt.plot(viz_t, zs[:, 0])
    with sns.color_palette("Blues_d"):
        plt.plot(viz_t, zs)
    plt.xlabel("Test Time")
    # plt.tight_layout()
    utils.makedirs(os.path.join(args.save, 'test_times', 'figs'))
    plt.savefig(os.path.join(args.save, 'test_times', 'figs', 'particle_trajectory.jpg'))
    plt.close()
def create_model(args, data_shape):
    hidden_dims = tuple(map(int, args.dims.split(",")))

    model = odenvp.ODENVP(
        (BATCH_SIZE, *data_shape),
        n_blocks=args.num_blocks,
        intermediate_dims=hidden_dims,
        nonlinearity=args.nonlinearity,
        alpha=args.alpha,
        cnf_kwargs={"T": args.time_length, "train_T": args.train_T},
    )
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)
    return model
def gen_model(scale=10, fraction=0.5):
    # build normalizing flow model from previous fit
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    args = pkl.load(open('args.pkl', 'rb'))
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 5, regularization_fns).to(device)  # .cuda()
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)
    model.load_state_dict(torch.load('model_10000.pt'))

    # if torch.cuda.is_available():
    #     model = init_flow_model(num_inputs=5, num_cond_inputs=None).cuda()  # len(cond_cols)).cuda()
    # else:
    #     model = init_flow_model(num_inputs=5, num_cond_inputs=None)  # len(cond_cols)).cuda()

    # num_layers = 5
    # base_dist = StandardNormal(shape=(5,))
    # transforms = []
    # for _ in range(num_layers):
    #     transforms.append(ReversePermutation(features=5))
    #     transforms.append(MaskedAffineAutoregressiveTransform(features=5, hidden_features=4))
    # transform = CompositeTransform(transforms)
    # model = Flow(transform, base_dist).to(device)
    # model.cpu()

    # filename = 'checkpoint11434epochs_cycle.pth'
    # filename = f'gauss_scale{scale}_frac{fraction}/checkpoint200000epochs_cycle_gauss.pth'
    # filename = 'gauss_scale10_frac0.25/checkpoint100000epochs_cycle_gauss.pth'
    # filename = 'checkpoint_epoch{}.pth'.format(95000)
    # data = torch.load(filename, map_location=device)
    # breakpoint()
    # model.load_state_dict(data['model'])
    # if torch.cuda.is_available():
    #     data = torch.load(filename)
    #     model.load_state_dict(data['model'])
    #     model.cuda()
    # else:
    #     data = torch.load(filename, map_location=torch.device('cpu'))
    #     model.load_state_dict(data['model'])

    return model
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device()
                              if torch.cuda.is_available() else "cpu")
    else:
        device = torch.cuda.current_device()

    # import pdb; pdb.set_trace()
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log:
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)

    # restore parameters
    # import pdb; pdb.set_trace()
    if args.resume is not None:
        # import pdb; pdb.set_trace()
        print('resume from checkpoint')
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
        '''
        elif args.resume and args.validate:
            chkdir = os.path.dirname(args.resume)
            wall_clock = 0
            itr = 0
            best_loss = 0.0
            begin_epoch = 0
        '''
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        # import pdb; pdb.set_trace()
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # import pdb; pdb.set_trace()
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,  # model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real and generate with different temperatures
        fig_num = 64
        if True:  # args.save_real:
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            # import pdb; pdb.set_trace()
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90,
                      0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if args.use_cpu:
    device = torch.device("cpu")
cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

# load dataset
test_loader = get_dataset(args, args.test_batch_size)

# build model
regularization_fns, regularization_coeffs = create_regularization_fns(args)
aug_model = build_augmented_model_tabular(
    args,
    args.aug_size + args.effective_shape,
    regularization_fns=regularization_fns,
)
set_cnf_options(args, aug_model)
logger.info(aug_model)

# restore parameters
itr = 0
if args.resume is not None:
    checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage)
    aug_model.load_state_dict(checkpt["state_dict"])
if torch.cuda.is_available() and not args.use_cpu:
    aug_model = torch.nn.DataParallel(aug_model).cuda()

best_loss = float("inf")
aug_model.eval()
with torch.no_grad():
best_loss = float('inf')
itr = 0
n_vals_without_improvement = 0
end = time.time()
model.train()
while True:
    if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping:
        break
    for x in batch_iter(data.trn.x, shuffle=True):
        if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping:
            break

        atol, rtol = update_tolerances(args, itr, decay_factors)
        set_cnf_options(args, atol, rtol, model)
        print(atol)
        print(rtol)

        optimizer.zero_grad()

        x = cvt(x)
        loss = compute_loss(x, model)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff
                for reg_state, coeff in zip(reg_states, regularization_coeffs)
            )
def train():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
                tt_meter.val, tt_meter.avg))
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args, model, batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='True')
                true_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(),
                         label='Model')
                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()
                model.train()

        end = time.time()

    logger.info('Training has finished.')
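# Hypothetical sketch of the model_density(...) helper called by evaluate() and
# train() but not defined in this file. It follows the same change-of-variables
# pattern used in the __main__ block further below (z, delta_logp = model(x, 0);
# log p(x) = standard_normal_logprob(z) - delta_logp). Treat it as an assumption
# about the missing helper, not the original implementation.
def model_density(x, model):
    x = x.to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zero)
    return standard_normal_logprob(z).sum(1, keepdim=True) - delta_logp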
def main(args):
    # logger
    print(args.no_display_loss)
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        displaying=not args.no_display_loss,  # `~` on a bool is bitwise negation, not logical `not`
    )
    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0
    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)
    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(device)

    growth_model = None  # default so the train(...) call below is defined when growth is unused
    if args.use_growth:
        if args.leaveout_timepoint == -1:
            growth_model_path = "../data/externel/growth_model_v2.ckpt"
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = "../data/growth/model_%d" % args.leaveout_timepoint
        else:
            print("WARNING: Cannot use growth with this timepoint")
        growth_model = torch.load(growth_model_path, map_location=device)

    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
        # if "growth_state_dict" not in state_dict:
        #     print("error growth model note in save")
        #     growth_model = None
        # else:
        #     checkpt = torch.load(args.save + "/checkpt.pth", map_location=device)
        #     growth_model.load_state_dict(checkpt["growth_state_dict"])
        # TODO can we load the arguments from the save?
        # eval_utils.generate_samples(
        #     device, args, model, growth_model, timepoint=args.leaveout_timepoint
        # )
        # with torch.no_grad():
        #     evaluate(device, args, model, growth_model)
        # exit()
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))
        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)
def run(args, kwargs):
    # ==========================================================================================
    # SNAPSHOTS
    # ==========================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    if args.automatic_saving:
        path = '{}/{}/{}/{}/{}/{}/{}/{}/{}/'.format(
            args.solver, args.dataset, args.layer_type, args.atol, args.rtol,
            args.atol_start, args.rtol_start, args.warmup_steps, args.manual_seed)
    else:
        path = 'test/'
    args.snap_dir = os.path.join(args.out_dir, path)
    if not os.path.exists(args.snap_dir):
        os.makedirs(args.snap_dir)

    # logger
    utils.makedirs(args.snap_dir)
    logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # SAVING
    torch.save(args, args.snap_dir + 'config.config')

    # ==========================================================================================
    # LOAD DATA
    # ==========================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    if not args.evaluate:
        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()

        # ======================================================================================
        # SELECT MODEL
        # ======================================================================================
        # flow parameters and architecture choice are passed on to model through args
        if args.flow == 'no_flow':
            model = VAE.VAE(args)
        elif args.flow == 'planar':
            model = VAE.PlanarVAE(args)
        elif args.flow == 'iaf':
            model = VAE.IAFVAE(args)
        elif args.flow == 'orthogonal':
            model = VAE.OrthogonalSylvesterVAE(args)
        elif args.flow == 'householder':
            model = VAE.HouseholderSylvesterVAE(args)
        elif args.flow == 'triangular':
            model = VAE.TriangularSylvesterVAE(args)
        elif args.flow == 'cnf':
            model = CNFVAE.CNFVAE(args)
        elif args.flow == 'cnf_bias':
            model = CNFVAE.AmortizedBiasCNFVAE(args)
        elif args.flow == 'cnf_hyper':
            model = CNFVAE.HypernetCNFVAE(args)
        elif args.flow == 'cnf_lyper':
            model = CNFVAE.LypernetCNFVAE(args)
        elif args.flow == 'cnf_rank':
            model = CNFVAE.AmortizedLowRankCNFVAE(args)
        else:
            raise ValueError('Invalid flow choice')

        if args.retrain_encoder:
            logger.info(f"Initializing decoder from {args.model_path}")
            dec_model = torch.load(args.model_path)
            dec_sd = {}
            for k, v in dec_model.state_dict().items():
                if 'p_x' in k:
                    dec_sd[k] = v
            model.load_state_dict(dec_sd, strict=False)

        if args.cuda:
            logger.info("Model on GPU")
            model.cuda()

        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

        if args.retrain_encoder:
            parameters = []
            logger.info('Optimizing over:')
            for name, param in model.named_parameters():
                if 'p_x' not in name:
                    logger.info(name)
                    parameters.append(param)
        else:
            parameters = model.parameters()

        optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7)

        # ======================================================================================
        # TRAINING
        # ======================================================================================
        train_loss = []
        val_loss = []

        # for early stopping
        best_loss = np.inf
        best_bpd = np.inf
        e = 0
        epoch = 0

        train_times = []

        for epoch in range(1, args.epochs + 1):
            atol, rtol = update_tolerances(args, epoch, decay_factors)
            print(atol)
            set_cnf_options(args, atol, rtol, model)

            t_start = time.time()
            if 'cnf' not in args.flow:
                tr_loss = train(epoch, train_loader, model, optimizer, args, logger)
            else:
                tr_loss, nfef_meter, nfeb_meter = train(
                    epoch, train_loader, model, optimizer, args, logger,
                    nfef_meter, nfeb_meter)
            train_loss.append(tr_loss)
            train_times.append(time.time() - t_start)
            logger.info('One training epoch took %.2f seconds' % (time.time() - t_start))

            v_loss, v_bpd = evaluate(val_loader, model, args, logger, epoch=epoch)
            val_loss.append(v_loss)

            # early-stopping
            if v_loss < best_loss:
                e = 0
                best_loss = v_loss
                if args.input_type != 'binary':
                    best_bpd = v_bpd
                logger.info('->model saved<-')
                torch.save(model, args.snap_dir + 'model.model')
                # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model')
            elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup):
                e += 1
                if e > args.early_stopping_epochs:
                    break

            if args.input_type == 'binary':
                logger.info('--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(
                    e, args.early_stopping_epochs, best_loss))
            else:
                logger.info('--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'.format(
                    e, args.early_stopping_epochs, best_loss, best_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

        train_loss = np.hstack(train_loss)
        val_loss = np.array(val_loss)
        plot_training_curve(train_loss, val_loss, fname=args.snap_dir + '/training_curve.pdf')

        # training time per epoch
        train_times = np.array(train_times)
        mean_train_time = np.mean(train_times)
        std_train_time = np.std(train_times, ddof=1)
        logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time))

        # ======================================================================================
        # EVALUATION
        # ======================================================================================
        logger.info(args)
        logger.info('Stopped after %d epochs' % epoch)
        logger.info('Average train time per epoch: %.2f +/- %.2f' % (mean_train_time, std_train_time))

        final_model = torch.load(args.snap_dir + 'model.model')
        validation_loss, validation_bpd = evaluate(val_loader, final_model, args, logger)
    else:
        validation_loss = "N/A"
        validation_bpd = "N/A"
        logger.info(f"Loading model from {args.model_path}")
        final_model = torch.load(args.model_path)

    test_loss, test_bpd = evaluate(test_loader, final_model, args, logger, testing=True)

    logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format(validation_loss))
def main():
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    print(f'==> Preparing cifar10')
    transform_train = transforms.Compose([
        dataset.RandomPadandCrop(32),
        dataset.RandomFlip(),
        dataset.ToTensor(),
    ])
    transform_val = transforms.Compose([
        dataset.ToTensor(),
    ])

    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(
        '/home/fengchan/stor/dataset/original-data/cifar10', args.n_labeled,
        transform_train=transform_train, transform_val=transform_val)
    labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size,
                                          shuffle=True, num_workers=0, drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size,
                                            shuffle=True, num_workers=0, drop_last=True)
    val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, num_workers=0)
    test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, num_workers=0)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        model = models.WideResNet(num_classes=num_classes)
        model = model.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    data_shape = [3, 32, 32]
    regularization_fns, regularization_coeffs = create_regularization_fns(args)

    def create_cnf():
        # generate cnf
        # cnf = create_cnf_model_1(args, data_shape, regularization_fns=None)
        # cnf = create_cnf_model(args, data_shape, regularization_fns=regularization_fns)
        cnf = create_nf_model(args, data_shape, regularization_fns=None)
        cnf = cnf.cuda() if use_cuda else cnf
        return cnf

    model = create_model()
    ema_model = create_model(ema=True)
    cnf = create_cnf()
    if args.spectral_norm:
        add_spectral_norm(cnf, logger)
    set_cnf_options(args, cnf)

    cudnn.benchmark = True
    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # CNF
    cnf_optimizer = optim.Adam(cnf.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    ema_optimizer = WeightEMA(model, ema_model, alpha=args.ema_decay)
    start_epoch = 0

    # Resume
    # generate prior
    means = generate_gaussian_means(num_classes, data_shape, seed=num_classes)

    title = 'noisy-cifar-10'
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        cnf.load_state_dict(checkpoint['cnf_state_dict'])
        means = checkpoint['means']
        cnf_optimizer.load_state_dict(checkpoint['cnf_optimizer'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Train loss NLL X',
            'Train loss NLL U', 'Train loss mixed X', 'Valid Loss', 'Valid Acc.',
            'Test Loss', 'Test Acc.'
        ])

    means = means.cuda() if use_cuda else means
    prior = SSLGaussMixture(means, device='cuda' if use_cuda else 'cpu')

    writer = SummaryWriter(args.out)
    step = 0
    test_accs = []
    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        train_loss, train_loss_x, train_loss_u, train_loss_nll_x, train_loss_nll_u, train_loss_mixed_x = train(
            labeled_trainloader, unlabeled_trainloader, model, cnf, prior,
            cnf_optimizer, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
        _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch,
                                use_cuda, mode='Train Stats')
        val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch,
                                     use_cuda, mode='Valid Stats')
        test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch,
                                       use_cuda, mode='Test Stats ')

        step = args.train_iteration * (epoch + 1)

        writer.add_scalar('losses/train_loss', train_loss, step)
        writer.add_scalar('losses/train_loss_nll_x', train_loss_nll_x, step)
        writer.add_scalar('losses/train_loss_nll_u', train_loss_nll_u, step)
        writer.add_scalar('losses/train_loss_mixed_x', train_loss_mixed_x, step)
        writer.add_scalar('losses/valid_loss', val_loss, step)
        writer.add_scalar('losses/test_loss', test_loss, step)
        writer.add_scalar('accuracy/train_acc', train_acc, step)
        writer.add_scalar('accuracy/val_acc', val_acc, step)
        writer.add_scalar('accuracy/test_acc', test_acc, step)

        # append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, train_loss_nll_x,
            train_loss_nll_u, train_loss_mixed_x, val_loss, val_acc,
            test_loss, test_acc
        ])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'cnf_state_dict': cnf.state_dict(),
                'means': means,
                'ema_state_dict': ema_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'cnf_optimizer': cnf_optimizer.state_dict(),
            }, is_best)
        test_accs.append(test_acc)

    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))
def build_augmented_model_tabular(args, dims, regularization_fns=None):
    """
    Create a conditional continuous normalizing flow with an augmented neural ODE.

    Parameters:
        args: arguments used to create the conditional CNF. Check the args parser for details.
        dims: dimension of the input. Currently only 1-d input is allowed.
        regularization_fns: regularizations applied to the ODE function

    Returns:
        a CTFP model based on an augmented neural ODE
    """
    hidden_dims = tuple(map(int, args.dims.split(",")))
    if args.aug_hidden_dims is not None:
        aug_hidden_dims = tuple(map(int, args.aug_hidden_dims.split(",")))
    else:
        aug_hidden_dims = None

    def build_cnf():
        diffeq = layers.AugODEnet(
            hidden_dims=hidden_dims,
            input_shape=(dims,),
            effective_shape=args.effective_shape,
            strides=None,
            conv=False,
            layer_type=args.layer_type,
            nonlinearity=args.nonlinearity,
            aug_dim=args.aug_dim,
            aug_mapping=args.aug_mapping,
            aug_hidden_dims=args.aug_hidden_dims,  # note: the parsed tuple aug_hidden_dims above is not used here
        )
        odefunc = layers.AugODEfunc(
            diffeq=diffeq,
            divergence_fn=args.divergence_fn,
            residual=args.residual,
            rademacher=args.rademacher,
            effective_shape=args.effective_shape,
        )
        cnf = layers.CNF(
            odefunc=odefunc,
            T=args.time_length,
            train_T=args.train_T,
            regularization_fns=regularization_fns,
            solver=args.solver,
            rtol=args.rtol,
            atol=args.atol,
        )
        return cnf

    chain = [build_cnf() for _ in range(args.num_blocks)]
    if args.batch_norm:
        bn_layers = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag, effective_shape=args.effective_shape)
            for _ in range(args.num_blocks)
        ]
        bn_chain = [
            layers.MovingBatchNorm1d(dims, bn_lag=args.bn_lag, effective_shape=args.effective_shape)
        ]
        for a, b in zip(chain, bn_layers):
            bn_chain.append(a)
            bn_chain.append(b)
        chain = bn_chain
    model = layers.SequentialFlow(chain)
    set_cnf_options(args, model)
    return model
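# A hedged usage sketch for build_augmented_model_tabular(): the Namespace
# fields below are exactly the attributes the function reads from `args`, while
# the concrete values are illustrative assumptions. Constructing the model also
# requires the repository's `layers` package and whatever extra fields
# set_cnf_options() expects, so treat this as documentation of the expected
# arguments rather than a guaranteed-runnable configuration.
def _example_augmented_model_args():
    from argparse import Namespace
    return Namespace(
        dims="32,32",              # hidden widths of the ODE net
        aug_hidden_dims=None,      # optional widths for the augmentation mapping
        effective_shape=1,         # dimension of the observed process
        layer_type="concat",
        nonlinearity="softplus",
        aug_dim=1,
        aug_mapping=True,
        divergence_fn="approximate",
        residual=False,
        rademacher=False,
        time_length=1.0,
        train_T=False,
        solver="dopri5",
        rtol=1e-5,
        atol=1e-5,
        num_blocks=1,
        batch_norm=False,
        bn_lag=0.0,
    )
# e.g. aug_model = build_augmented_model_tabular(_example_augmented_model_args(), dims=2)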
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
    if write_log:
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
    if write_log:
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
    if write_log:
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    # x = x.clamp_(min=0, max=1)

                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([1., batch_size, loss.item(), bpd.item(),
                                            nfe_opt, grad_norm, *reg_states]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)
                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save({
                "args": args,
                "state_dict": model.module.state_dict() if torch.cuda.is_available() else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt.pth"))

        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) - z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)
                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }
                        csvlogger.writerow(logdict)
                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt, r_mdist / total_gpus))

                    loss = r_bpd / total_gpus
                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs", "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits), fig_filename, nrow=nb)
            if args.validate:
                break
def compute_loss(args, model, batch_size=None):
    # Signature reconstructed from the calls in train(); only the body below is
    # present in the original fragment.
    if batch_size is None:
        batch_size = args.batch_size

    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    z, change = model(x, zero)
    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
def main(args):
    device = torch.device("cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args)

    args.timepoints = data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(device)
    if args.use_growth:
        growth_model_path = data.get_growth_net_path()
        # growth_model_path = "/home/atong/TrajectoryNet/data/externel/growth_model_v2.ckpt"
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    # plot_output(device, args, model, data)
    # exit()
    # get_trajectory_samples(device, model, data)

    args.data = data
    args.timepoints = args.data.get_unique_times()
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    print('integrating backwards')
    # end_time_data = data.data_dict[args.embedding_name]
    end_time_data = data.get_data()[args.data.get_times() == np.max(args.data.get_times())]
    # np.random.permutation(end_time_data)
    # rand_idx = np.random.randint(end_time_data.shape[0], size=5000)
    # end_time_data = end_time_data[rand_idx, :]
    integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device)
    exit()

    losses_list = []
    # for factor in np.linspace(0.05, 0.95, 19):
    # for factor in np.linspace(0.91, 0.99, 9):
    if args.dataset == 'CHAFFER':
        # Do timepoint adjustment
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.6799872494335812
            factor = 0.95
        elif lt == 2:
            factor = 0.2905983814032348
            factor = 0.01
        else:
            raise RuntimeError('Unknown timepoint %d' % args.leaveout_timepoint)
        args.int_tps[lt] = (1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]

    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))