def validate(epoch, model, ema=None):
    """
    Evaluates the cross entropy between p_data and p_model.
    """
    bpd_meter = utils.AverageMeter()
    ce_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    update_lipschitz(model)

    model.eval()

    correct = 0
    total = 0

    start = time.time()
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader)):
            x = x.to(device)
            bpd, _, _ = compute_loss(x, model)
            bpd_meter.update(bpd.item(), x.size(0))
    val_time = time.time() - start

    if ema is not None:
        ema.swap()
    s = 'Epoch: [{0}]\tTime {1:.2f} | Test Nats {bpd_meter.avg:.4f}'.format(
        epoch, val_time, bpd_meter=bpd_meter)
    logger.info(s)
    return bpd_meter.avg
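`utils.AverageMeter` is used throughout these loops but not defined in this excerpt; a minimal sketch, assuming the common `update(val, n)` / `val` / `avg` interface seen at the call sites, is:

# Sketch of the AverageMeter assumed above (the utils module is not shown in
# this excerpt). It keeps a weighted running sum and count so that .avg is
# the sample-weighted mean of everything passed to update().
class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0.0   # most recent value
        self.sum = 0.0   # weighted running sum
        self.count = 0   # total weight seen so far
        self.avg = 0.0   # running mean

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count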
def train_handcraft(args, train_loader, valid_loader, index_loader,
                    valid_dataset, index_dataset, save_root, writer):
    ext = handcraft_extractor(args)
    start_epoch = 0
    if args.ckpt_path is not None:
        ext.load(args.ckpt_path)
    else:
        batch_time = u.AverageMeter()
        data_time = u.AverageMeter()
        start = time.time()
        pbar = tqdm.tqdm(enumerate(train_loader),
                         desc="Extract local descriptor!")
        if args.train is True:
            for batch_i, data in pbar:
                data_time.update(time.time() - start)
                start = time.time()

                ext.extract_ld(data)

                batch_time.update(time.time() - start)
                start = time.time()

                state_msg = ('Data time: {:0.5f}; Batch time: {:0.5f};'.format(
                    data_time.avg, batch_time.avg))
                pbar.set_description(state_msg)

            ext.build_voca(k=args.cluster)
            ext.extract_vlad()
            filename = os.path.join(save_root, 'ckpt', 'checkpoint.pkl')
            ext.save(filename)

    if (args.valid is True) or (args.valid_sample is True):
        pbar = tqdm.tqdm(enumerate(valid_loader),
                         desc="Extract query descriptor!")
        for batch_i, data in pbar:
            ext.extract_vlad_query(data)
        indexdb, validdb = ext.get_data()

        if args.metric == 0:
            ldm = mt.LocDegThreshMetric(args, indexdb, validdb, index_dataset,
                                        valid_dataset, 0,
                                        os.path.join(save_root, "result"))
    return
def validate(epoch, model, ema=None):
    """
    Evaluates the cross entropy between p_data and p_model.
    """
    bpd_meter = utils.AverageMeter()
    ce_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    update_lipschitz(model)

    model = parallelize(model)
    model.eval()

    correct = 0
    total = 0

    start = time.time()
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(test_loader)):
            x = x.to(device)
            bpd, logits, _, _ = compute_loss(x, model)
            bpd_meter.update(bpd.item(), x.size(0))

            if args.task in ['classification', 'hybrid']:
                y = y.to(device)
                loss = criterion(logits, y)
                ce_meter.update(loss.item(), x.size(0))
                _, predicted = logits.max(1)
                total += y.size(0)
                correct += predicted.eq(y).sum().item()
    val_time = time.time() - start

    if ema is not None:
        ema.swap()
    s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(
        epoch, val_time, bpd_meter=bpd_meter)
    if args.task in ['classification', 'hybrid']:
        s += ' | CE {:.4f} | Acc {:.2f}'.format(ce_meter.avg,
                                                100 * correct / total)
    logger.info(s)
    return bpd_meter.avg
def train_on_dataset(self, data_loader, models, criterions, optimizers,
                     epoch, logs, **kwargs):
    """ train on dataset for one epoch """
    loss_meters = [utils.AverageMeter() for i in range(len(models))]
    top1_meters = [utils.AverageMeter() for i in range(len(models))]
    for model in models:
        model.train()
    for i, (input_, target) in enumerate(data_loader):
        input_, target = self.to_cuda(input_, target)
        self.train_on_batch(input_, target, models, criterions, optimizers,
                            logs, loss_meters, top1_meters, **kwargs)
    self.write_log(logs, loss_meters, top1_meters, epoch, mode="train")
    return logs
def validate_on_dataset(self, data_loader, models, criterions, epoch, logs,
                        **kwargs):
    """ validate on dataset """
    loss_meters = [utils.AverageMeter() for i in range(len(models))]
    top1_meters = [utils.AverageMeter() for i in range(len(models))]
    for model in models:
        model.eval()
    for i, (input_, target) in enumerate(data_loader):
        input_, target = self.to_cuda(input_, target)
        self.validate_on_batch(input_, target, models, criterions, logs,
                               loss_meters, top1_meters)
    self.write_log(logs, loss_meters, top1_meters, epoch, mode="test")
    return logs
def validate(epoch, model, data_loader, ema, device):
    """
    Evaluates the cross entropy between p_data and p_model.
    """
    bpd_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    model.eval()

    start = time.time()
    with torch.no_grad():
        for i, (x, y) in enumerate(tqdm(data_loader)):
            x = x.to(device)
            bpd = compute_loss(x, model)
            bpd_meter.update(bpd.item(), x.size(0))
    val_time = time.time() - start

    if ema is not None:
        ema.swap()
    return val_time, bpd_meter.avg
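The paired `ema.swap()` calls bracket evaluation so the model is scored with its exponential-moving-average weights and then restored. The actual helper is not part of this excerpt; the class below is a hypothetical sketch of an EMA object with that `swap()` semantics:

import torch

# Hypothetical EMA helper matching the swap() usage in validate(): swap()
# exchanges the live parameters with their moving averages, so calling it
# before and after evaluation scores the averaged weights without touching
# the optimizer's training state.
class ExponentialMovingAverage:
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {k: v.detach().clone()
                       for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self):
        for k, v in self.model.state_dict().items():
            if v.dtype.is_floating_point:
                self.shadow[k].mul_(self.decay).add_(v, alpha=1 - self.decay)
            else:
                self.shadow[k].copy_(v)  # e.g. integer buffers

    def swap(self):
        live = {k: v.detach().clone()
                for k, v in self.model.state_dict().items()}
        self.model.load_state_dict(self.shadow)
        self.shadow = live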
def train(model, trainD, evalD, checkpt=None):
    global ndecs

    optim = torch.optim.Adam(model.parameters(), lr=args.lr,
                             betas=(0.9, 0.99), weight_decay=args.wd)
    # sch = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.nepochs * trainD.N)
    if checkpt is not None:
        optim.load_state_dict(checkpt['optim'])
        ndecs = checkpt['ndecs']

    batch_time = utils.RunningAverageMeter(0.98)
    cg_meter = utils.RunningAverageMeter(0.98)
    gnorm_meter = utils.RunningAverageMeter(0.98)
    train_est_meter = utils.RunningAverageMeter(0.98**args.train_est_freq)

    best_logp = -float('inf')
    itr = 0 if checkpt is None else checkpt['iters']
    n_vals_without_improvement = 0

    model.train()
    while True:
        if itr >= args.nepochs * math.ceil(trainD.N / args.batch_size):
            break
        if 0 < args.early_stopping < n_vals_without_improvement:
            break

        for x in batch_iter(trainD.x, shuffle=True):
            if 0 < args.early_stopping < n_vals_without_improvement:
                break

            end = time.time()
            optim.zero_grad()

            x = cvt(x)
            train_est = [0] if itr % args.train_est_freq == 0 else None
            loss = -model.logp(x, extra=train_est).mean()
            if train_est is not None:
                train_est = train_est[0].mean().detach().item()

            if loss != loss:
                raise ValueError('NaN encountered @ training logp!')

            loss.backward()

            if args.clip_grad == 0:
                parameters = [p for p in model.parameters() if p.grad is not None]
                grad_norm = torch.norm(
                    torch.stack([
                        torch.norm(p.grad.detach(), 2.0) for p in parameters
                    ]), 2.0)
            else:
                grad_norm = torch.nn.utils.clip_grad.clip_grad_norm_(
                    model.parameters(), args.clip_grad)
            optim.step()
            # sch.step()

            gnorm_meter.update(float(grad_norm))
            cg_meter.update(sum(flows.CG_ITERS_TRACER))
            flows.CG_ITERS_TRACER.clear()
            batch_time.update(time.time() - end)
            if train_est is not None:
                train_est_meter.update(train_est)

            del loss
            gc.collect()
            torch.clear_autocast_cache()

            if itr % args.log_freq == 0:
                log_message = (
                    'Iter {:06d} | Epoch {:.2f} | Time {batch_time.val:.3f} | '
                    'GradNorm {gnorm_meter.avg:.2f} | CG iters {cg_meter.val} ({cg_meter.avg:.2f}) | '
                    'Train logp {train_logp.val:.6f} ({train_logp.avg:.6f})'.format(
                        itr, float(itr) / (trainD.N / float(args.batch_size)),
                        batch_time=batch_time, gnorm_meter=gnorm_meter,
                        cg_meter=cg_meter, train_logp=train_est_meter))
                logger.info(log_message)

            # Validation loop.
            if itr % args.val_freq == 0:
                with eval_ctx(model, bruteforce=args.brute_val):
                    val_logp = utils.AverageMeter()
                    with tqdm(total=evalD.N) as pbar:
                        # noinspection PyAssignmentToLoopOrWithParameter
                        for x in batch_iter(evalD.x, batch_size=args.val_batch_size):
                            x = cvt(x)
                            val_logp.update(model.logp(x).mean().item(), x.size(0))
                            pbar.update(x.size(0))

                if val_logp.avg > best_logp:
                    best_logp = val_logp.avg
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'model': model.state_dict(),
                        'optim': optim.state_dict(),
                        'iters': itr + 1,
                        'ndecs': ndecs,
                    }, save_path)
                    n_vals_without_improvement = 0
                else:
                    n_vals_without_improvement += 1
                    update_lr(optim, n_vals_without_improvement)

                log_message = ('[VAL] Iter {:06d} | Val logp {:.6f} | '
                               'NoImproveEpochs {:02d}/{:02d}'.format(
                                   itr, val_logp.avg,
                                   n_vals_without_improvement,
                                   args.early_stopping))
                logger.info(log_message)

            itr += 1

    logger.info('Training has finished, yielding the best model...')
    best_checkpt = torch.load(save_path)
    model.load_state_dict(best_checkpt['model'])
    return model
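The `utils.RunningAverageMeter(0.98)` meters smooth noisy per-iteration statistics; a minimal sketch, matching the `update`/`val`/`avg` interface used above (the actual utils implementation is not shown in this excerpt):

# Sketch of the RunningAverageMeter assumed above: an exponentially decayed
# average with the given momentum, so .avg tracks recent iterations rather
# than the whole run.
class RunningAverageMeter:
    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.reset()

    def reset(self):
        self.val = None
        self.avg = 0.0

    def update(self, val):
        if self.val is None:
            self.avg = val
        else:
            self.avg = self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val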
"Resume file provided, but not found... starting from scratch: {}". format(args.resume)) logger.info(flow) logger.info("Number of trainable parameters:{}".format( count_parameters(flow))) ################################################################################ # Training # ################################################################################ if not args.evaluate: flow = train(flow, data.trn, data.val, checkpt) ################################################################################ # Testing # ################################################################################ logger.info('Evaluating model on test set.') with eval_ctx(flow, bruteforce=True): test_logp = utils.AverageMeter() with tqdm(total=data.tst.N) as pbar: for itr, x in enumerate( batch_iter(data.tst.x, batch_size=args.test_batch_size)): x = cvt(x) test_logp.update(flow.logp(x).mean().item(), x.size(0)) pbar.update(x.size(0)) log_message = '[TEST] Iter {:06d} | Test logp {:.6f}'.format( itr, test_logp.avg) logger.info(log_message)
filtered_state_dict = {}
for k, v in checkpt['state_dict'].items():
    if 'diffeq.diffeq' not in k:
        filtered_state_dict[k.replace('module.', '')] = v
model.load_state_dict(filtered_state_dict)

logger.info(model)
logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

if not args.evaluate:
    optimizer = Adam(model.parameters(), lr=args.lr,
                     weight_decay=args.weight_decay)

    time_meter = utils.AverageMeter()
    loss_meter = utils.AverageMeter()
    nfef_meter = utils.AverageMeter()
    nfeb_meter = utils.AverageMeter()
    tt_meter = utils.AverageMeter()

    best_loss = float('inf')
    itr = 0
    n_vals_without_improvement = 0
    end = time.time()
    model.train()
    while True:
        if args.early_stopping > 0 and n_vals_without_improvement > args.early_stopping:
            break
        for x in batch_iter(data.trn.x, shuffle=True):
def train(args, train_loader, valid_loader, index_loader, valid_dataset,
          index_dataset, save_root, writer):
    ext = extractor(args)
    crt = criterion(args)
    optim = get_optimizer(args, ext)
    sche = get_scheduler(args, optim)

    start_epoch = 0
    if args.ckpt_path is not None:
        ckpt = torch.load(args.ckpt_path)
        ext = ckpt['model']
        optim = ckpt['optimizer']
        start_epoch = ckpt['epoch'] + 1

    for epoch in range(start_epoch, args.epochs):
        ext.train()
        sche.step()

        batch_time = u.AverageMeter()
        data_time = u.AverageMeter()
        losses = u.AverageMeter()

        start = time.time()
        pbar = tqdm.tqdm(enumerate(train_loader), desc="Epoch : %d" % epoch)
        size_all = len(train_loader)
        interv = math.floor(size_all / args.save_interval)
        sub_p = 0

        if args.train is True:
            for batch_i, data in pbar:
                image = data['image'].cuda()
                label = data['label'].cuda()
                data_time.update(time.time() - start)
                start = time.time()

                output = ext(image)
                if output.dim() == 1:
                    output = output.unsqueeze(0)
                loss = crt(output, label, args.tuple, args.batch)

                optim.zero_grad()
                loss.backward()
                optim.step()

                losses.update(loss.item())
                batch_time.update(time.time() - start)
                start = time.time()

                writer.add_scalars('train/loss', {'loss': losses.avg},
                                   global_step=epoch * len(train_loader) + batch_i)

                state_msg = (
                    'Epoch: {:4d}; Loss: {:0.5f}; Data time: {:0.5f}; Batch time: {:0.5f};'
                    .format(epoch, losses.avg, data_time.avg, batch_time.avg))
                pbar.set_description(state_msg)

                if ((batch_i + 1) % interv == 0) or (size_all == batch_i + 1):
                    state = {
                        'epoch': epoch,
                        'loss': losses.avg,
                        'model': ext,
                        'optimizer': optim
                    }
                    filename = os.path.join(
                        save_root, 'ckpt',
                        'checkpoint_subset{:03d}_epoch{:03d}.pth.tar'.format(
                            sub_p, epoch))
                    torch.save(state, filename)
                    sub_p += 1

        if ((args.valid is True) or (args.valid_sample is True)) and (
                (epoch + 1) % args.valid_interval == 0):
            if args.db_load is not None:
                with open(args.db_load, "rb") as a_file:
                    indexdb = pickle.load(a_file)
            else:
                # index
                indexdb = make_inferDBandPredict(args, index_loader, ext,
                                                 epoch, tp='index')

            # valid
            validdb = make_inferDBandPredict(args, valid_loader, ext, epoch,
                                             tp='valid')

            if args.db_load is None:
                if args.extractor >= 4:
                    indexdb['feat'], validdb['feat'] = ext.postprocessing(
                        indexdb['feat'], validdb['feat'])
                if args.pca is True:
                    pca = pp.PCAwhitening(pca_dim=args.pca_dim,
                                          pca_whitening=True)
                    indexdb['feat'] = pca.fit_transform(indexdb['feat'])
                    validdb['feat'] = pca.transform(validdb['feat'])

            if args.db_save is not None:
                if os.path.isfile(args.db_save) is True:
                    os.remove(args.db_save)
                a_file = open(args.db_save, "wb")
                pickle.dump(indexdb, a_file)
                a_file.close()

            if args.metric == 0:
                ldm = mt.LocDegThreshMetric(args, indexdb, validdb,
                                            index_dataset, valid_dataset,
                                            epoch,
                                            os.path.join(save_root, "result"))
                if args.train is False:
                    return
                if args.qualitative:
                    return
                for key, value in ldm.items():
                    writer.add_scalars('valid/top' + str(args.topk) + "_" + key,
                                       {key: value}, global_step=epoch)
    return
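Note that the checkpoint above pickles the whole `ext` module and optimizer objects, which ties the file to the exact class definitions in the repo. A common, more portable alternative (a sketch, not this repo's own convention) saves `state_dict`s instead:

# Sketch: state_dict-based checkpointing as an alternative to pickling whole
# objects; the file then only contains tensors and survives refactors of the
# extractor class.
state = {
    'epoch': epoch,
    'loss': losses.avg,
    'model': ext.state_dict(),
    'optimizer': optim.state_dict(),
}
torch.save(state, filename)

# Loading (assumes ext and optim were first constructed as in train()):
ckpt = torch.load(filename, map_location='cpu')
ext.load_state_dict(ckpt['model'])
optim.load_state_dict(ckpt['optimizer'])
start_epoch = ckpt['epoch'] + 1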
logger.info("saveLocation = {:}".format(args.save)) logger.info("-------------------------\n") begin = time.time() end = begin best_loss = float('inf') best_costs = [0.0] * 3 best_params = None log_msg = ( '{:5s} {:6s} {:7s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} ' .format('iter', ' time', 'lr', 'loss', 'L (L_2)', 'C (loss)', 'R (HJB)', 'valLoss', 'valL', 'valC', 'valR')) logger.info(log_msg) timeMeter = utils.AverageMeter() clampMax = 2.0 clampMin = -2.0 net.train() itr = 1 while itr < args.niters: # train for data in train_loader: images, _ = data # flatten images x0 = images.view(images.size(0), -1) x0 = cvt(x0) x0 = autoEnc.encode(x0) # encode x0 = (x0 - autoEnc.mu) / (autoEnc.std + args.eps) # normalize
def validate(epoch, model, gmm, ema=None):
    """
    - Deploys the color normalization on test image dataset
    - Evaluates NMI / CV / SD
    # Evaluates the cross entropy between p_data and p_model.
    """
    print("Starting Validation")
    model = parallelize(model)
    gmm = parallelize(gmm)
    model.to(device)
    gmm.to(device)

    bpd_meter = utils.AverageMeter()
    ce_meter = utils.AverageMeter()

    if ema is not None:
        ema.swap()

    update_lipschitz(model)

    model.eval()
    gmm.eval()

    mu_tmpl = 0
    std_tmpl = 0
    N = 0
    print(f"Deploying on {len(train_loader)} batches of {args.batchsize} templates...")
    idx = 0
    for x, y in tqdm(train_loader):
        x = x.to(device)

        ### TEMPLATES ###
        D = x[:, 0, ...].unsqueeze(1)
        D = rescale(D)  # Scale to [0,1] interval
        D = D.repeat(1, args.nclusters, 1, 1)

        with torch.no_grad():
            if isinstance(model, torch.nn.DataParallel):
                z_logp = model.module(D.view(-1, *input_size[1:]), 0,
                                      classify=False)
            else:
                z_logp = model(D.view(-1, *input_size[1:]), 0, classify=False)
            z, delta_logp = z_logp
            if isinstance(gmm, torch.nn.DataParallel):
                logpz, params = gmm.module(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x.permute(0, 2, 3, 1))
            else:
                logpz, params = gmm(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x.permute(0, 2, 3, 1))

        mu, std, gamma = params
        mu = mu.cpu().numpy()
        std = std.cpu().numpy()
        gamma = gamma.cpu().numpy()

        mu = mu[..., np.newaxis]
        std = std[..., np.newaxis]
        mu = np.swapaxes(mu, 0, 1)    # (3,4,1) -> (4,3,1)
        mu = np.swapaxes(mu, 1, 2)    # (4,3,1) -> (4,1,3)
        std = np.swapaxes(std, 0, 1)  # (3,4,1) -> (4,3,1)
        std = np.swapaxes(std, 1, 2)  # (4,3,1) -> (4,1,3)

        # Running average of the template statistics over batches.
        N = N + 1
        mu_tmpl = (N - 1) / N * mu_tmpl + 1 / N * mu
        std_tmpl = (N - 1) / N * std_tmpl + 1 / N * std

        if idx == len(train_loader) - 1:
            break
        idx += 1

    print("Estimated Mu for template(s):")
    print(mu_tmpl)
    print("Estimated Sigma for template(s):")
    print(std_tmpl)

    metrics = dict()
    for tc in range(1, args.nclusters + 1):
        metrics[f'mean_{tc}'] = []
        metrics[f'median_{tc}'] = []
        metrics[f'perc_95_{tc}'] = []
        metrics[f'nmi_{tc}'] = []
        metrics[f'sd_{tc}'] = []
        metrics[f'cv_{tc}'] = []

    print(f"Predicting on {len(test_loader)} batches of {args.val_batchsize} templates...")
    idx = 0
    for x_test, y_test in tqdm(test_loader):
        x_test = x_test.to(device)

        ### DEPLOY ###
        D = x_test[:, 0, ...].unsqueeze(1)
        D = rescale(D)  # Scale to [0,1] interval
        D = D.repeat(1, args.nclusters, 1, 1)

        with torch.no_grad():
            if isinstance(model, torch.nn.DataParallel):
                z_logp = model.module(D.view(-1, *input_size[1:]), 0,
                                      classify=False)
            else:
                z_logp = model(D.view(-1, *input_size[1:]), 0, classify=False)
            z, delta_logp = z_logp
            if isinstance(gmm, torch.nn.DataParallel):
                logpz, params = gmm.module(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x_test.permute(0, 2, 3, 1))
            else:
                logpz, params = gmm(
                    z.view(-1, args.nclusters, args.imagesize, args.imagesize),
                    x_test.permute(0, 2, 3, 1))

        mu, std, pi = params
        mu = mu.cpu().numpy()
        std = std.cpu().numpy()
        pi = pi.cpu().numpy()

        mu = mu[..., np.newaxis]
        std = std[..., np.newaxis]
        mu = np.swapaxes(mu, 0, 1)    # (3,4,1) -> (4,3,1)
        mu = np.swapaxes(mu, 1, 2)    # (4,3,1) -> (4,1,3)
        std = np.swapaxes(std, 0, 1)  # (3,4,1) -> (4,3,1)
        std = np.swapaxes(std, 1, 2)  # (4,3,1) -> (4,1,3)

        X_hsd = np.swapaxes(x_test.cpu().numpy(), 1, 2)
        X_hsd = np.swapaxes(X_hsd, 2, 3)

        X_conv = imgtf.image_dist_transform(X_hsd, mu, std, pi, mu_tmpl,
                                            std_tmpl, args)

        ClsLbl = np.argmax(np.asarray(pi), axis=-1)
        ClsLbl = ClsLbl.astype('int32')
        mean_rgb = np.mean(X_conv, axis=-1)

        for tc in range(1, args.nclusters + 1):
            msk = ClsLbl == tc
            if not msk.any():
                continue  # skip metric if no class labels are found
            ma = mean_rgb[msk]
            mean = np.mean(ma)
            median = np.median(ma)
            perc = np.percentile(ma, 95)
            nmi = median / perc
            metrics[f'mean_{tc}'].append(mean)
            metrics[f'median_{tc}'].append(median)
            metrics[f'perc_95_{tc}'].append(perc)
            metrics[f'nmi_{tc}'].append(nmi)

        if idx == len(test_loader) - 1:
            break
        idx += 1

    av_sd = []
    av_cv = []
    for tc in range(1, args.nclusters + 1):
        if len(metrics[f'mean_{tc}']) == 0:
            continue
        metrics[f'sd_{tc}'] = np.array(metrics[f'nmi_{tc}']).std()
        metrics[f'cv_{tc}'] = (np.array(metrics[f'nmi_{tc}']).std()
                               / np.array(metrics[f'nmi_{tc}']).mean())
        print(f'sd_{tc}:', metrics[f'sd_{tc}'])
        print(f'cv_{tc}:', metrics[f'cv_{tc}'])
        av_sd.append(metrics[f'sd_{tc}'])
        av_cv.append(metrics[f'cv_{tc}'])

    print(f"Average sd = {np.array(av_sd).mean()}")
    print(f"Average cv = {np.array(av_cv).mean()}")

    import csv
    file = open(f"metrics-{args.train_centers[0]}-{args.val_centers[0]}.csv", "w")
    writer = csv.writer(file)
    for key, value in metrics.items():
        writer.writerow([key, value])
    file.close()

    # correct = 0
    # total = 0
    # start = time.time()
    # with torch.no_grad():
    #     for i, (x, y) in enumerate(tqdm(test_loader)):
    #         x = x.to(device)
    #         bpd, logits, _, _ = compute_loss(x, model)
    #         bpd_meter.update(bpd.item(), x.size(0))
    # val_time = time.time() - start
    # if ema is not None:
    #     ema.swap()
    # s = 'Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {bpd_meter.avg:.4f}'.format(
    #     epoch, val_time, bpd_meter=bpd_meter)
    # if args.task in ['classification', 'hybrid']:
    #     s += ' | CE {:.4f} | Acc {:.2f}'.format(ce_meter.avg, 100 * correct / total)
    # logger.info(s)
    # return bpd_meter.avg
    return
def run(args, kwargs):
    # ==================================================================================================================
    # SNAPSHOTS
    # ==================================================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    if args.automatic_saving:
        path = '{}/{}/{}/{}/{}/{}/{}/{}/{}/'.format(
            args.solver, args.dataset, args.layer_type, args.atol, args.rtol,
            args.atol_start, args.rtol_start, args.warmup_steps,
            args.manual_seed)
    else:
        path = 'test/'

    args.snap_dir = os.path.join(args.out_dir, path)
    if not os.path.exists(args.snap_dir):
        os.makedirs(args.snap_dir)

    # logger
    utils.makedirs(args.snap_dir)
    logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # SAVING
    torch.save(args, args.snap_dir + 'config.config')

    # ==================================================================================================================
    # LOAD DATA
    # ==================================================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    if not args.evaluate:
        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()

        # ==============================================================================================================
        # SELECT MODEL
        # ==============================================================================================================
        # flow parameters and architecture choice are passed on to model through args
        if args.flow == 'no_flow':
            model = VAE.VAE(args)
        elif args.flow == 'planar':
            model = VAE.PlanarVAE(args)
        elif args.flow == 'iaf':
            model = VAE.IAFVAE(args)
        elif args.flow == 'orthogonal':
            model = VAE.OrthogonalSylvesterVAE(args)
        elif args.flow == 'householder':
            model = VAE.HouseholderSylvesterVAE(args)
        elif args.flow == 'triangular':
            model = VAE.TriangularSylvesterVAE(args)
        elif args.flow == 'cnf':
            model = CNFVAE.CNFVAE(args)
        elif args.flow == 'cnf_bias':
            model = CNFVAE.AmortizedBiasCNFVAE(args)
        elif args.flow == 'cnf_hyper':
            model = CNFVAE.HypernetCNFVAE(args)
        elif args.flow == 'cnf_lyper':
            model = CNFVAE.LypernetCNFVAE(args)
        elif args.flow == 'cnf_rank':
            model = CNFVAE.AmortizedLowRankCNFVAE(args)
        else:
            raise ValueError('Invalid flow choice')

        if args.retrain_encoder:
            logger.info(f"Initializing decoder from {args.model_path}")
            dec_model = torch.load(args.model_path)
            dec_sd = {}
            for k, v in dec_model.state_dict().items():
                if 'p_x' in k:
                    dec_sd[k] = v
            model.load_state_dict(dec_sd, strict=False)

        if args.cuda:
            logger.info("Model on GPU")
            model.cuda()

        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))

        if args.retrain_encoder:
            parameters = []
            logger.info('Optimizing over:')
            for name, param in model.named_parameters():
                if 'p_x' not in name:
                    logger.info(name)
                    parameters.append(param)
        else:
            parameters = model.parameters()

        optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7)

        # ==============================================================================================================
        # TRAINING
        # ==============================================================================================================
        train_loss = []
        val_loss = []

        # for early stopping
        best_loss = np.inf
        best_bpd = np.inf
        e = 0
        epoch = 0

        train_times = []

        for epoch in range(1, args.epochs + 1):
            atol, rtol = update_tolerances(args, epoch, decay_factors)
            print(atol)
            set_cnf_options(args, atol, rtol, model)

            t_start = time.time()
            if 'cnf' not in args.flow:
                tr_loss = train(epoch, train_loader, model, optimizer, args,
                                logger)
            else:
                tr_loss, nfef_meter, nfeb_meter = train(
                    epoch, train_loader, model, optimizer, args, logger,
                    nfef_meter, nfeb_meter)
            train_loss.append(tr_loss)
            train_times.append(time.time() - t_start)
            logger.info('One training epoch took %.2f seconds' %
                        (time.time() - t_start))

            v_loss, v_bpd = evaluate(val_loader, model, args, logger,
                                     epoch=epoch)
            val_loss.append(v_loss)

            # early-stopping
            if v_loss < best_loss:
                e = 0
                best_loss = v_loss
                if args.input_type != 'binary':
                    best_bpd = v_bpd
                logger.info('->model saved<-')
                torch.save(model, args.snap_dir + 'model.model')
                # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model')
            elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup):
                e += 1
                if e > args.early_stopping_epochs:
                    break

            if args.input_type == 'binary':
                logger.info(
                    '--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(
                        e, args.early_stopping_epochs, best_loss))
            else:
                logger.info(
                    '--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'
                    .format(e, args.early_stopping_epochs, best_loss, best_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

        train_loss = np.hstack(train_loss)
        val_loss = np.array(val_loss)
        plot_training_curve(train_loss, val_loss,
                            fname=args.snap_dir + '/training_curve.pdf')

        # training time per epoch
        train_times = np.array(train_times)
        mean_train_time = np.mean(train_times)
        std_train_time = np.std(train_times, ddof=1)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        # ==============================================================================================================
        # EVALUATION
        # ==============================================================================================================
        logger.info(args)
        logger.info('Stopped after %d epochs' % epoch)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        final_model = torch.load(args.snap_dir + 'model.model')
        validation_loss, validation_bpd = evaluate(val_loader, final_model,
                                                   args, logger)
    else:
        validation_loss = "N/A"
        validation_bpd = "N/A"
        logger.info(f"Loading model from {args.model_path}")
        final_model = torch.load(args.model_path)

    test_loss, test_bpd = evaluate(test_loader, final_model, args, logger,
                                   testing=True)

    logger.info(
        'FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format(
            validation_loss))
    return atol, rtol


if __name__ == '__main__':

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(
        count_parameters(model)))

    if not args.only_viz_samples:
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)

        time_meter = utils.AverageMeter()
        loss_meter = utils.AverageMeter()
        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()
        tt_meter = utils.AverageMeter()

        end = time.time()
        best_loss = float('inf')
        model.train()
        for itr in range(1, args.niters + 1):
            atol, rtol = update_tolerances(args, itr, decay_factors)
            set_cnf_options(args, atol, rtol, model)

            optimizer.zero_grad()
            if args.spectral_norm:
                spectral_norm_power_iteration(model, 1)
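Only the `return atol, rtol` tail of `update_tolerances` and its call sites appear in this excerpt; the sketch below is a hypothetical reconstruction, assuming it anneals the ODE-solver tolerances from loose starting values (`args.atol_start`, `args.rtol_start`) down to the final `args.atol`/`args.rtol` over `args.warmup_steps` iterations via precomputed `decay_factors`:

# Hypothetical reconstruction of the tolerance schedule: geometric decay from
# (atol_start, rtol_start) to (atol, rtol) over warmup_steps iterations.
def compute_decay_factors(args):
    # Per-step multiplicative factors that reach the target after warmup_steps.
    atol_factor = (args.atol / args.atol_start) ** (1.0 / args.warmup_steps)
    rtol_factor = (args.rtol / args.rtol_start) ** (1.0 / args.warmup_steps)
    return atol_factor, rtol_factor

def update_tolerances(args, itr, decay_factors):
    atol_factor, rtol_factor = decay_factors
    step = min(itr, args.warmup_steps)
    atol = args.atol_start * atol_factor ** step
    rtol = args.rtol_start * rtol_factor ** step
    return atol, rtol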
    # toy data: mixture of two unit Gaussians centered at -3 and +3
    x0 = torch.randn(args.batch_size, 1) - 3 \
        + 6 * (torch.rand(args.batch_size, 1) > 0.5).float()
    x0 = cvt(x0)

    # x0val = toy_data.inf_train_gen(args.data, batch_size=args.val_batch_size)
    x0val = torch.randn(args.batch_size, 1) - 3 \
        + 6 * (torch.rand(args.batch_size, 1) > 0.5).float()
    x0val = cvt(x0val)

    log_msg = (
        '{:5s} {:6s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} {:9s} '.format(
            'iter', ' time', 'loss', 'L (L_2)', 'C (loss)', 'R (HJB)',
            'valLoss', 'valL', 'valC', 'valR'))
    logger.info(log_msg)

    time_meter = utils.AverageMeter()

    net.train()
    for itr in range(1, args.niters + 1):
        # train
        optim.zero_grad()
        loss, costs = compute_loss(net, x0, nt=nt)
        loss.backward()
        optim.step()

        time_meter.update(time.time() - end)

        log_message = (
            '{:05d} {:6.3f} {:9.3e} {:9.3e} {:9.3e} {:9.3e} '.format(
                itr, time_meter.val, loss, costs[0], costs[1], costs[2])
    logger.info('must use --resume flag to provide the state_dict to evaluate')
    exit(1)

logger.info(model)
nWeights = count_parameters(model)
logger.info("Number of trainable parameters: {}".format(nWeights))

logger.info('Evaluating model on test set.')
model.eval()
override_divergence_fn(model, "brute_force")

bInverse = True  # check one batch for inverse error, for speed
with torch.no_grad():
    test_loss = utils.AverageMeter()
    test_nfe = utils.AverageMeter()
    for itr, x in enumerate(batch_iter(data.tst.x, batch_size=test_batch_size)):
        x = cvt(x)
        test_loss.update(compute_loss(x, model).item(), x.shape[0])
        test_nfe.update(count_nfe(model))

        if bInverse:  # check the inverse error
            z = model(x, reverse=False)      # push forward
            xpred = model(z, reverse=True)   # inverse
            logger.info('inverse norm for first batch: ')
            logger.info(torch.norm(xpred - x).item() / x.shape[0])
            bInverse = False
if not cf.gpu:  # assume debugging and run a subset
    nSamples = 1000
    testData = testData[:nSamples, :]
    normSamples = normSamples[:nSamples, :]
    if args.long_version:
        ffjordFx = ffjordFx[:nSamples, :]
        ffjordFinvfx = ffjordFinvfx[:nSamples, :]
        ffjordGen = ffjordGen[:nSamples, :]

net.eval()
with torch.no_grad():

    # meters to hold testing results
    testLossMeter = utils.AverageMeter()
    testAlphMeterL = utils.AverageMeter()
    testAlphMeterC = utils.AverageMeter()
    testAlphMeterR = utils.AverageMeter()

    # scale the GAS data set as it was in the training
    if args.data == 'gas':
        print(torch.min(testData), torch.max(testData))
        testData = testData / 5.0

    itr = 1
    for x0 in batch_iter(testData, batch_size=args.batch_size):
        x0 = cvt(x0)
        nex = x0.shape[0]
        test_loss, test_cs = compute_loss(net, x0, nt=nt_test)