def train_eval(device, args, model, growth_model, itr, best_loss, logger, full_data):
    model.eval()
    test_loss = compute_loss(device, args, model, growth_model, logger, full_data)
    test_nfe = count_nfe(model)
    log_message = "[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}".format(
        itr, test_loss, test_nfe
    )
    logger.info(log_message)

    utils.makedirs(args.save)
    with open(os.path.join(args.save, "train_eval.csv"), "a") as f:
        import csv

        writer = csv.writer(f)
        writer.writerow((itr, test_loss.item()))

    if test_loss.item() < best_loss:
        best_loss = test_loss.item()
        torch.save(
            {
                # 'args': args,
                "state_dict": model.state_dict(),
                "growth_state_dict": growth_model.state_dict(),
            },
            os.path.join(args.save, "checkpt.pth"),
        )
    # Return the (possibly updated) best loss so the caller can track it across evaluations.
    return best_loss
def compare_with_DV_particle_method(args, model, dim, batch_size=None):
    if batch_size is None:
        batch_size = args.batch_size

    x = torch.randn([batch_size, dim], dtype=torch.float32, device=device)
    diff_0 = torch.zeros(1, dtype=torch.float32, device=device)

    x_t, diff_t = model(x, diff_0, integration_times=args.time_length)
    nfe = count_nfe(model)
    torch.save(x_t, 'output/DVP_output_gaussian_mixture.pt')
    return diff_t[0] / nfe
def compute_loss_wgf(args, model, dim, batch_size=None):
    if batch_size is None:
        batch_size = args.batch_size

    z = torch.randn(batch_size, dim, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    # mu_0 = torch.zeros(2, dtype=torch.float32, device=device)
    # sigma_half_0 = torch.eye(2, dtype=torch.float32, device=device)
    # score_error_0 = torch.zeros(1, dtype=torch.float32, device=device)

    x, logp_x, score_x, wgf_reg = model(z, logpz=logp_z, score=score_z, wgf_reg=wgf_reg_0)
    nfe = count_nfe(model)
    return wgf_reg / nfe
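# The helpers called above (`standard_normal_logprob`, `standard_normal_score`) are not defined in
# this excerpt. The reference sketch below is an assumption, following the usual FFJORD-style
# convention: a per-dimension log-density that callers sum over dim 1, and the score
# grad_z log p(z), which equals -z for a standard normal. The `_sketch` names are hypothetical.
import math

def standard_normal_logprob_sketch(z):
    # elementwise log N(z; 0, I); sum over the feature dimension to get log p(z)
    return -0.5 * math.log(2 * math.pi) - z.pow(2) / 2

def standard_normal_score_sketch(z):
    # grad_z log N(z; 0, I) = -z
    return -z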
def score_error_wgf(args, model, batch_size=None):
    if batch_size is None:
        batch_size = args.batch_size

    # TODO: should have an input specifying the data dimension. Now it is fixed to 2.
    z = torch.randn(batch_size, 2, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)
    mu_0 = torch.zeros(2, dtype=torch.float32, device=device)
    sigma_half_0 = torch.eye(2, dtype=torch.float32, device=device)
    score_error_0 = torch.zeros(1, dtype=torch.float32, device=device)

    # x, logp_x, score_x, wgf_reg = model(z, logp_z, score_z, wgf_reg_0)
    x, logp_x, score_x, wgf_reg, mu, sigma_half, score_error = model(
        z, logpz=logp_z, score=score_z, wgf_reg=wgf_reg_0,
        mu_0=mu_0, sigma_half_0=sigma_half_0, score_error_0=score_error_0)
    nfe = count_nfe(model)
    return score_error / nfe
def compute_likelihood(args, model, batch_size=None):
    if batch_size is None:
        batch_size = args.batch_size

    # TODO: should have an input specifying the data dimension. Now it is fixed to 2.
    z = torch.randn(batch_size, 2, dtype=torch.float32, device=device)
    logp_z = standard_normal_logprob(z).sum(1, keepdim=True).to(z)
    score_z = standard_normal_score(z).to(z)
    wgf_reg_0 = torch.tensor(0, device=device)

    # x, logp_x, score_x, wgf_reg = model(z, logp_z, score_z, wgf_reg_0)
    x, logp_x, score_x, wgf_reg = model(z, logpz=logp_z, score=score_z, wgf_reg=wgf_reg_0)
    nfe = count_nfe(model)

    logp_true_x = gaussian_logprob(x).sum(1, keepdim=True).to(z)
    # logp_true_x = gaussian_mixture_logprob(x)
    # print(torch.mean(x, 0))
    return -torch.mean(logp_true_x)
def compare_with_Gaussian(args, model, dim, batch_size=None):
    if batch_size is None:
        batch_size = args.batch_size

    x = torch.randn([batch_size, dim], dtype=torch.float32, device=device)
    mu_0 = torch.zeros(dim, dtype=torch.float32, device=device)
    sigma_half_0 = torch.eye(dim, dtype=torch.float32, device=device)
    diff_0 = torch.zeros(1, dtype=torch.float32, device=device)

    x_t, mu_t, sigma_half_t, diff_t = model(
        x, mu_0, sigma_half_0, diff_0, integration_times=args.time_length)
    nfe = count_nfe(model)
    # print(torch.mean(x_t, dim=0))
    # print(torch.mean(y_t, dim=0))
    # print(torch.norm(score_t[0])**2)
    return diff_t[0] / nfe
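# Usage sketch for the two comparison utilities above (a hypothetical call site; `args`, `model`,
# `device`, and `logger` are assumed to be set up elsewhere in this script):
#
# with torch.no_grad():
#     gaussian_gap = compare_with_Gaussian(args, model, dim=2)
#     particle_gap = compare_with_DV_particle_method(args, model, dim=2)
#     logger.info('Gaussian gap per NFE: {:.6e} | DV particle gap per NFE: {:.6e}'.format(
#         gaussian_gap.item(), particle_gap.item()))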
logger.info("Number of trainable parameters: {}".format(nWeights)) logger.info('Evaluating model on test set.') model.eval() override_divergence_fn(model, "brute_force") bInverse = True # check one batch for inverse error, for speed with torch.no_grad(): test_loss = utils.AverageMeter() test_nfe = utils.AverageMeter() for itr, x in enumerate( batch_iter(data.tst.x, batch_size=test_batch_size)): x = cvt(x) test_loss.update(compute_loss(x, model).item(), x.shape[0]) test_nfe.update(count_nfe(model)) if bInverse: # check the ivnerse error z = model(x, reverse=False) # push forward xpred = model(z, reverse=True) # inverse logger.info('inverse norm for first batch: ') logger.info(torch.norm(xpred - x).item() / x.shape[0]) bInverse = False logger.info('Progress: {:.2f}%'.format( 100. * itr / (data.tst.x.shape[0] / test_batch_size))) log_message = '[TEST] Iter {:06d} | Test Loss {:.6f} | NFE {:.0f}'.format( itr, test_loss.avg, test_nfe.avg) logger.info(log_message)
def train(epoch, train_loader, model, opt, args, logger):
    model.train()
    train_loss = np.zeros(len(train_loader))
    train_bpd = np.zeros(len(train_loader))
    num_data = 0

    # set warmup coefficient
    beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta])
    logger.info('beta = {:5.4f}'.format(beta))
    end = time.time()

    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()
            target = target.cuda()

        if args.dynamic_binarization:
            data = torch.bernoulli(data)

        data = data.view(-1, *args.input_size)
        opt.zero_grad()

        if args.conditional:
            x_mean, z_mu, z_var, ldj, z0, zk = model(data, target)
        else:
            x_mean, z_mu, z_var, ldj, z0, zk = model(data)

        # if batch_idx == len(train_loader) - 1:
        #     print('-' * 10)
        #     for i in range(len(x_mean)):
        #         print(x_mean[i].data[0].item(), x_mean[i].data[1].item(),
        #               data[i].data[0].item(), data[i].data[1].item())

        if 'cnf' in args.flow:
            f_nfe = count_nfe(model)

        loss, rec, kl, bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)
        loss.backward()

        if 'cnf' in args.flow:
            t_nfe = count_nfe(model)
            b_nfe = t_nfe - f_nfe

        train_loss[batch_idx] = loss.item()
        train_bpd[batch_idx] = bpd

        opt.step()

        rec = rec.item()
        kl = kl.item()
        num_data += len(data)

        batch_time = time.time() - end
        end = time.time()

        if batch_idx % args.log_interval == 0:
            if args.input_type == 'binary':
                perc = 100. * batch_idx / len(train_loader)
                log_msg = (
                    'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | '
                    'Rec {:11.6f} | KL {:11.6f}'.format(
                        epoch, num_data, len(train_loader.sampler), perc, batch_time,
                        loss.item(), rec, kl))
            else:
                perc = 100. * batch_idx / len(train_loader)
                tmp = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}'
                log_msg = (tmp.format(epoch, num_data, len(train_loader.sampler), perc,
                                      batch_time, loss.item(), bpd),
                           '\trec: {:11.3f}\tkl: {:11.6f}\tvar: {}'.format(
                               rec, kl, torch.mean(torch.mean(z_var, dim=0))))
                log_msg = "".join(log_msg)

            if 'cnf' in args.flow:
                log_msg += ' | NFE Forward {} | NFE Backward {}'.format(f_nfe, b_nfe)

            logger.info(log_msg)

    if args.input_type == 'binary':
        logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(
            epoch, train_loss.sum() / len(train_loader)))
    else:
        logger.info(
            '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'
            .format(epoch, train_loss.sum() / len(train_loader),
                    train_bpd.sum() / len(train_loader)))

    return train_loss
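# Small numerical check of the KL warmup schedule used above (illustrative numbers, not from the
# repo): beta ramps linearly from 0 to args.max_beta over the first args.warmup epochs, then stays flat.
def warmup_beta_example(epoch, warmup=10, max_beta=1.0):
    return min([(epoch * 1.) / max([warmup, 1.]), max_beta])

# warmup_beta_example(0)  -> 0.0
# warmup_beta_example(5)  -> 0.5
# warmup_beta_example(10) -> 1.0 (and 1.0 for all later epochs)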
loss = loss + reg_loss

total_time = count_total_time(model)
loss = loss + total_time * args.time_penalty

loss.backward()
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
optimizer.step()

if args.spectral_norm:
    spectral_norm_power_iteration(model, args.spectral_norm_niter)

time_meter.update(time.time() - start)
loss_meter.update(loss.item())
steps_meter.update(count_nfe(model))
grad_meter.update(grad_norm)
tt_meter.update(total_time)

if itr % args.log_freq == 0:
    log_message = (
        "Iter {:04d} | Time {:.4f}({:.4f}) | Bit/dim {:.4f}({:.4f}) | "
        "Steps {:.0f}({:.2f}) | Grad Norm {:.4f}({:.4f}) | Total Time {:.2f}({:.2f})"
        .format(itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                steps_meter.val, steps_meter.avg, grad_meter.val, grad_meter.avg,
                tt_meter.val, tt_meter.avg))
    if regularization_coeffs:
        log_message = append_regularization_to_log(
            log_message, regularization_fns, reg_states)
    logger.info(log_message)
def train():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
                tt_meter.val, tt_meter.avg))
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args, model, batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(), label='True')
                true_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(), label='Model')
                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()

                model.train()

        end = time.time()

    logger.info('Training has finished.')
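# The `.val`/`.avg` pairs logged throughout these loops come from utils.RunningAverageMeter.
# The sketch below is an assumption about its behaviour (an exponential moving average seeded by
# the first update), not code copied from this repo's utils module.
class RunningAverageMeterSketch(object):
    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.val = None
        self.avg = 0.

    def update(self, val):
        # the first update seeds the average; later updates blend in with the given momentum
        self.avg = val if self.val is None else self.avg * self.momentum + val * (1 - self.momentum)
        self.val = val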
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')
    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm'
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost'
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(
            args.resume,
            map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()
            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    # x = x.clamp_(min=0, max=1)

                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff for reg_state, coeff in zip(
                            reg_states, regularization_coeffs) if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)
                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([
                        1., batch_size, loss.item(), bpd.item(), nfe_opt, grad_norm,
                        *reg_states
                    ]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)
                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = dist_utils.sum_tensor(
                        metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save(
                {
                    "args": args,
                    "state_dict": model.module.state_dict()
                    if torch.cuda.is_available() else model.state_dict(),
                    "optim_state_dict": optimizer.state_dict(),
                    "fixed_z": fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))

        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) - z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }
                        csvlogger.writerow(logdict)
                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt, r_mdist / total_gpus))

                    loss = r_bpd / total_gpus
                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

        # visualize samples and density
        if write_log:
            with torch.no_grad():
                fig_filename = os.path.join(args.save, "figs",
                                            "{:04d}.jpg".format(epoch))
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples, _, _ = model(fixed_z, reverse=True)
                generated_samples = generated_samples.view(-1, *data_shape)
                nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                save_image(unshift(generated_samples, nbits=args.nbits),
                           fig_filename, nrow=nb)

        if args.validate:
            break
def train(epoch, train_loader, model, opt, args, logger, nfef_meter=None, nfeb_meter=None):
    model.train()
    train_loss = np.zeros(len(train_loader))
    train_bpd = np.zeros(len(train_loader))
    num_data = 0

    # set warmup coefficient
    beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta])
    logger.info('beta = {:5.4f}'.format(beta))
    end = time.time()

    for batch_idx, (data, _) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()

        if args.dynamic_binarization:
            data = torch.bernoulli(data)

        data = data.view(-1, *args.input_size)
        opt.zero_grad()

        x_mean, z_mu, z_var, ldj, z0, zk = model(data, is_eval=False, epoch=epoch)

        if 'cnf' in args.flow:
            f_nfe = count_nfe(model)

        loss, rec, kl, bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)
        loss.backward()

        if 'cnf' in args.flow:
            t_nfe = count_nfe(model)
            b_nfe = t_nfe - f_nfe
            nfef_meter.update(f_nfe)
            nfeb_meter.update(b_nfe)

        train_loss[batch_idx] = loss.item()
        train_bpd[batch_idx] = bpd

        opt.step()

        rec = rec.item()
        kl = kl.item()
        num_data += len(data)

        batch_time = time.time() - end
        end = time.time()

        if batch_idx % args.log_interval == 0:
            if args.input_type == 'binary':
                perc = 100. * batch_idx / len(train_loader)
                log_msg = (
                    'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | '
                    'Rec {:11.6f} | KL {:11.6f}'.format(
                        epoch, num_data, len(train_loader.sampler), perc, batch_time,
                        loss.item(), rec, kl))
            else:
                perc = 100. * batch_idx / len(train_loader)
                tmp = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}'
                log_msg = (tmp.format(epoch, num_data, len(train_loader.sampler), perc,
                                      batch_time, loss.item(), bpd),
                           '\trec: {:11.3f}\tkl: {:11.6f}'.format(rec, kl))
                log_msg = "".join(log_msg)

            if 'cnf' in args.flow:
                log_msg += ' | NFE Forward {:.0f}({:.1f}) | NFE Backward {:.0f}({:.1f})'.format(
                    f_nfe, nfef_meter.avg, b_nfe, nfeb_meter.avg)

            logger.info(log_msg)

    if args.input_type == 'binary':
        logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(
            epoch, train_loss.sum() / len(train_loader)))
    else:
        logger.info(
            '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'
            .format(epoch, train_loss.sum() / len(train_loader),
                    train_bpd.sum() / len(train_loader)))

    if 'cnf' not in args.flow:
        return train_loss
    else:
        return train_loss, nfef_meter, nfeb_meter
def train(args, model, growth_model):
    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    # optimizer = optim.Adam(set(model.parameters()) | set(growth_model.parameters()),
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # growth_optimizer = optim.Adam(growth_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    growth_model.eval()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        # growth_optimizer.zero_grad()

        ### Train
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)
        # if args.spectral_norm:
        #     spectral_norm_power_iteration(growth_model, 1)

        loss = compute_loss(args, model, growth_model)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff
                for reg_state, coeff in zip(reg_states, regularization_coeffs)
                if coeff != 0
            )
            loss = loss + reg_loss
        # if len(growth_regularization_coeffs) > 0:
        #     growth_reg_states = get_regularization(growth_model, growth_regularization_coeffs)
        #     reg_loss = sum(
        #         reg_state * coeff for reg_state, coeff in zip(growth_reg_states, growth_regularization_coeffs) if coeff != 0
        #     )
        #     loss2 = loss2 + reg_loss

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        # loss2.backward()
        optimizer.step()
        # growth_optimizer.step()

        ### Eval
        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
                tt_meter.val, tt_meter.avg
            )
        )
        if len(regularization_coeffs) > 0:
            log_message = append_regularization_to_log(log_message, regularization_fns, reg_states)
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                growth_model.eval()
                test_loss = compute_loss(args, model, growth_model)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                        'growth_state_dict': growth_model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                for i, tp in enumerate(timepoints):
                    p_samples = viz_sampler(tp)
                    sample_fn, density_fn = get_transforms(model, int_tps[:i + 1])
                    # growth_sample_fn, growth_density_fn = get_transforms(growth_model, int_tps[:i+1])
                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples, torch.randn, standard_normal_logprob,
                        transform=sample_fn, inverse_transform=density_fn,
                        samples=True, npts=100, device=device
                    )
                    fig_filename = os.path.join(
                        args.save, 'figs', '{:04d}_{:01d}.jpg'.format(itr, i))
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()
                    # visualize_transform(
                    #     p_samples, torch.rand, uniform_logprob, transform=growth_sample_fn,
                    #     inverse_transform=growth_density_fn,
                    #     samples=True, npts=800, device=device
                    # )
                    # fig_filename = os.path.join(args.save, 'growth_figs', '{:04d}_{:01d}.jpg'.format(itr, i))
                    # utils.makedirs(os.path.dirname(fig_filename))
                    # plt.savefig(fig_filename)
                    # plt.close()
                model.train()

        """
        if itr % args.viz_freq_growth == 0:
            with torch.no_grad():
                growth_model.eval()
                # Visualize growth transform
                growth_filename = os.path.join(args.save, 'growth', '{:04d}.jpg'.format(itr))
                utils.makedirs(os.path.dirname(growth_filename))
                visualize_growth(growth_model, data, labels, npts=200, device=device)
                plt.savefig(growth_filename)
                plt.close()
                growth_model.train()
        """
        end = time.time()

    logger.info('Training has finished.')
x = cvt(x)
loss = compute_loss(x, model)
loss_meter.update(loss.item())

if len(regularization_coeffs) > 0:
    reg_states = get_regularization(model, regularization_coeffs)
    reg_loss = sum(reg_state * coeff for reg_state, coeff in zip(
        reg_states, regularization_coeffs) if coeff != 0)
    loss = loss + reg_loss

total_time = count_total_time(model)
nfe_forward = count_nfe(model)

loss.backward()
optimizer.step()

nfe_total = count_nfe(model)
nfe_backward = nfe_total - nfe_forward

nfef_meter.update(nfe_forward)
nfeb_meter.update(nfe_backward)

time_meter.update(time.time() - end)
tt_meter.update(total_time)

if itr % args.log_freq == 0:
    log_message = (
        'Iter {:06d} | Epoch {:.2f} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | '
def train(epoch, train_loader, model, opt, args, wandb):
    model.train()
    train_loss = np.zeros(len(train_loader))
    train_bpd = np.zeros(len(train_loader))
    num_data = 0

    # set warmup coefficient
    beta = min([(epoch * 1.) / max([args.warmup, 1.]), args.max_beta])
    # logger.info('beta = {:5.4f}'.format(beta))
    end = time.time()

    for batch_idx, (data, _) in enumerate(train_loader):
        if args.cuda:
            data = data.cuda()

        if args.dynamic_binarization:
            data = torch.bernoulli(data)

        data = data.view(-1, *args.input_size)
        opt.zero_grad()

        x_mean, z_mu, z_var, ldj, z0, zk = model(data)

        # NFE counts are only meaningful for CNF flows; default to 0 otherwise so the
        # wandb.log call below never hits an undefined name.
        f_nfe = t_nfe = b_nfe = 0
        if 'cnf' in args.flow:
            f_nfe = count_nfe(model)

        loss, rec, kl, bpd = calculate_loss(x_mean, data, z_mu, z_var, z0, zk, ldj, args, beta=beta)
        loss.backward()

        if 'cnf' in args.flow:
            t_nfe = count_nfe(model)
            b_nfe = t_nfe - f_nfe

        train_loss[batch_idx] = loss.item()
        train_bpd[batch_idx] = bpd
        wandb.log({
            'train_loss': loss.item(),
            'train_bpd': bpd,
            'nfe': t_nfe,
            'nbe': b_nfe
        })

        opt.step()

        rec = rec.item()
        kl = kl.item()
        num_data += len(data)

        batch_time = time.time() - end
        end = time.time()

        # if batch_idx % args.log_interval == 0:
        #     if args.input_type == 'binary':
        #         perc = 100. * batch_idx / len(train_loader)
        #         log_msg = (
        #             'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | '
        #             'Rec {:11.6f} | KL {:11.6f}'.format(
        #                 epoch, num_data, len(train_loader.sampler), perc, batch_time, loss.item(), rec, kl
        #             )
        #         )
        #     else:
        #         perc = 100. * batch_idx / len(train_loader)
        #         tmp = 'Epoch {:3d} [{:5d}/{:5d} ({:2.0f}%)] | Time {:.3f} | Loss {:11.6f} | Bits/dim {:8.6f}'
        #         log_msg = tmp.format(epoch, num_data, len(train_loader.sampler), perc, batch_time, loss.item(),
        #                              bpd), '\trec: {:11.3f}\tkl: {:11.6f}'.format(rec, kl)
        #         log_msg = "".join(log_msg)
        #     if 'cnf' in args.flow:
        #         log_msg += ' | NFE Forward {} | NFE Backward {}'.format(f_nfe, b_nfe)
        #     logger.info(log_msg)

    # if args.input_type == 'binary':
    #     logger.info('====> Epoch: {:3d} Average train loss: {:.4f}'.format(epoch, train_loss.sum() / len(train_loader)))
    # else:
    #     logger.info(
    #         '====> Epoch: {:3d} Average train loss: {:.4f}, average bpd: {:.4f}'.
    #         format(epoch, train_loss.sum() / len(train_loader), train_bpd.sum() / len(train_loader))
    #     )

    return train_loss.sum() / len(train_loader)
for itr in range(1, args.niters + 1):
    optimizer.zero_grad()
    if args.spectral_norm:
        spectral_norm_power_iteration(model, 1)

    loss = compute_loss(args, model)
    loss_meter.update(loss.item())

    if len(regularization_coeffs) > 0:
        reg_states = get_regularization(model, regularization_coeffs)
        reg_loss = sum(
            reg_state * coeff
            for reg_state, coeff in zip(reg_states, regularization_coeffs)
            if coeff != 0
        )
        loss = loss + reg_loss

    total_time = count_total_time(model)
    nfe_forward = count_nfe(model)

    loss.backward()
    optimizer.step()

    nfe_total = count_nfe(model)
    nfe_backward = nfe_total - nfe_forward

    nfef_meter.update(nfe_forward)
    nfeb_meter.update(nfe_backward)

    time_meter.update(time.time() - end)
    tt_meter.update(total_time)

    log_message = (
        'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
        ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
def train(
    device, args, model, growth_model, regularization_coeffs, regularization_fns, logger
):
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    full_data = (
        torch.from_numpy(
            args.data.get_data()[args.data.get_times() != args.leaveout_timepoint]
        )
        .type(torch.float32)
        .to(device)
    )

    best_loss = float("inf")
    growth_model.eval()
    end = time.time()
    for itr in range(1, args.niters + 1):
        model.train()
        optimizer.zero_grad()

        # Train
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(device, args, model, growth_model, logger, full_data)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff
                for reg_state, coeff in zip(reg_states, regularization_coeffs)
                if coeff != 0
            )
            loss = loss + reg_loss

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        # Eval
        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |"
            " NFE Forward {:.0f}({:.1f})"
            " | NFE Backward {:.0f}({:.1f})".format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
            )
        )
        if len(regularization_coeffs) > 0:
            log_message = append_regularization_to_log(
                log_message, regularization_fns, reg_states
            )
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                # Capture the returned best loss so checkpoints are only refreshed on improvement.
                best_loss = train_eval(
                    device, args, model, growth_model, itr, best_loss, logger, full_data
                )

        if itr % args.viz_freq == 0:
            if args.data.get_shape()[0] > 2:
                logger.warning("Skipping vis as data dimension is >2")
            else:
                with torch.no_grad():
                    visualize(device, args, model, itr)

        if itr % args.save_freq == 0:
            utils.save_checkpoint(
                {
                    # 'args': args,
                    "state_dict": model.state_dict(),
                    "growth_state_dict": growth_model.state_dict(),
                },
                args.save,
                epoch=itr,
            )

        end = time.time()

    logger.info("Training has finished.")