nfe_backward = nfe_total - nfe_forward nfef_meter.update(nfe_forward) nfeb_meter.update(nfe_backward) time_meter.update(time.time() - end) tt_meter.update(total_time) log_message = ( 'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})' ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format( itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg, nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg, tt_meter.val, tt_meter.avg ) ) if len(regularization_coeffs) > 0: log_message = append_regularization_to_log(log_message, regularization_fns, reg_states) logger.info(log_message) if itr % args.val_freq == 0 or itr == args.niters: with torch.no_grad(): model.eval() test_loss = compute_loss(args, model, batch_size=args.test_batch_size) test_nfe = count_nfe(model) log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe) logger.info(log_message) if test_loss.item() < best_loss: best_loss = test_loss.item() utils.makedirs(args.save) torch.save({
def train(args, model, growth_model):
    """Optimize ``model`` with Adam for ``args.niters`` iterations.

    Only ``model.parameters()`` are given to the optimizer. ``growth_model``
    is put in eval mode, passed to ``compute_loss`` and stored in the
    checkpoint, but it is never stepped by this function.

    The function relies on names that are not parameters and must be
    provided by the enclosing module: ``logger``, ``regularization_coeffs``,
    ``regularization_fns``, ``timepoints``, ``int_tps``, ``viz_sampler`` and
    ``device``.

    Args:
        args: Options object; reads ``lr``, ``weight_decay``, ``niters``,
            ``spectral_norm``, ``val_freq``, ``viz_freq`` and ``save``.
        model: Network being trained.
        growth_model: Auxiliary network forwarded to ``compute_loss``.

    Side effects:
        Logs one line per iteration, writes ``checkpt.pth`` under
        ``args.save`` whenever the test loss improves, and writes one figure
        per timepoint under ``args.save/figs`` every ``args.viz_freq``
        iterations.
    """
    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Five independent meters, created in the order they are reported.
    time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter = (
        utils.RunningAverageMeter(0.93) for _ in range(5)
    )

    train_template = (
        'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
        ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'
    )
    test_template = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'

    end = time.time()
    best_loss = float('inf')
    model.train()
    growth_model.eval()

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        # --- optimisation step -------------------------------------------
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(args, model, growth_model)
        # The meter records the loss before any regularization is added.
        loss_meter.update(loss.item())

        regularized = len(regularization_coeffs) > 0
        if regularized:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            penalty = 0
            for state, coeff in zip(reg_states, regularization_coeffs):
                if coeff != 0:
                    penalty = penalty + state * coeff
            loss = loss + penalty

        total_time = count_total_time(model)
        # Counter value after the forward pass; the backward share is the
        # difference measured once backward() has run.
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        # --- bookkeeping -------------------------------------------------
        nfe_backward = count_nfe(model) - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        stats = []
        for meter in (time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter):
            stats.extend((meter.val, meter.avg))
        log_message = train_template.format(itr, *stats)
        if regularized:
            log_message = append_regularization_to_log(
                log_message, regularization_fns, reg_states
            )
        logger.info(log_message)

        # --- periodic test + checkpoint ------------------------------------
        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                growth_model.eval()
                test_loss = compute_loss(args, model, growth_model)
                test_nfe = count_nfe(model)
                logger.info(test_template.format(itr, test_loss, test_nfe))

                current = test_loss.item()
                if current < best_loss:
                    best_loss = current
                    utils.makedirs(args.save)
                    snapshot = {
                        'args': args,
                        'state_dict': model.state_dict(),
                        'growth_state_dict': growth_model.state_dict(),
                    }
                    torch.save(snapshot, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        # --- periodic visualisation --------------------------------------
        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                for idx, tp in enumerate(timepoints):
                    p_samples = viz_sampler(tp)
                    # Transforms are built over the integration times up to
                    # and including this timepoint.
                    sample_fn, density_fn = get_transforms(model, int_tps[:idx + 1])
                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples,
                        torch.randn,
                        standard_normal_logprob,
                        transform=sample_fn,
                        inverse_transform=density_fn,
                        samples=True,
                        npts=100,
                        device=device,
                    )
                    fig_filename = os.path.join(
                        args.save, 'figs', '{:04d}_{:01d}.jpg'.format(itr, idx)
                    )
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()
                model.train()

        # Restart the per-iteration timer so test/plot time is not counted.
        end = time.time()

    logger.info('Training has finished.')
def main():
    """Train and/or validate the CNF model described by the global ``args``.

    Configuration comes from the module-level ``args`` namespace and the
    module-level ``write_log`` flag; nothing is passed in or returned.

    Steps performed:
      1. If ``write_log``: create ``args.save``, open the logger and dump
         ``args`` to ``args.yaml``.
      2. If ``args.distributed``: initialise the process group.
      3. Build dataset loaders, model, and optimizer; optionally restore
         them from ``args.resume``.
      4. For each epoch: train (unless ``args.validate``), save
         ``checkpt.pth`` on local rank 0, and every ``args.val_freq`` epochs
         run validation, copy the checkpoint to ``best.pth`` when the
         bits/dim improves, and save a grid of generated samples.

    Raises:
        ValueError: if ``args.optimizer`` is neither ``'adam'`` nor
            ``'sgd'``, or if the training bits/dim becomes NaN or infinite.
    """
    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(
            logpath=os.path.join(args.save, 'logs'),
            filepath=os.path.abspath(__file__),
        )
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=dist_utils.env_world_size(),
            rank=env_rank(),
        )
        assert (dist_utils.env_world_size() == distributed.get_world_size())
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # Computation is pinned to the CPU in this script.
    device = "cpu"

    def cvt(x):
        """Cast a tensor to float32 and move it to ``device``."""
        return x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = [
        'itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time',
        'grad_norm',
    ]
    testcolumns = [
        'wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time',
        'transport_cost',
    ]

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns,
                                                     regularization_fns)

    # A fresh run starts new CSV logs; a resumed run appends to the old ones.
    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(
            count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9,
                              nesterov=False)
    else:
        # Fail here rather than with a NameError on first use of `optimizer`.
        raise ValueError(
            'unsupported optimizer: {!r}'.format(args.optimizer))

    # restore parameters
    if args.resume is not None:
        # Map storages onto the device actually used by this script; mapping
        # to CUDA would fail on hosts without a GPU even though everything
        # here runs on `device`.
        checkpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    if write_log:
        # Fixed latent batch reused for every sample visualization.
        fixed_z = cvt(
            torch.randn(min(args.test_batch_size, 100), *data_shape))

        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        # Recover counters from the CSV logs next to the checkpoint.
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for x, _ in train_loader:
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)

                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    bpd_value = bpd.item()
                    if np.isnan(bpd_value):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd_value):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(
                            reg_state * coeff
                            for reg_state, coeff in zip(
                                reg_states, regularization_coeffs)
                            if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        model.parameters(), args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    # The leading 1. lets the summed vector also yield the
                    # number of contributing processes (total_gpus).
                    metrics = torch.tensor([
                        1., batch_size,
                        loss.item(),
                        bpd_value, nfe_opt, grad_norm, *reg_states
                    ]).float()

                    (total_gpus, _batch_total, r_loss, r_bpd, r_nfe,
                     r_grad_norm, *rv) = dist_utils.sum_tensor(
                         metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(
                                logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # The model is only wrapped (and so only has `.module`) when
            # running distributed; CUDA availability alone says nothing
            # about wrapping.
            if args.distributed:
                model_state = model.module.state_dict()
            else:
                model_state = model.state_dict()
            # NOTE(review): fixed_z is only created when write_log is set;
            # this assumes write_log is true on local rank 0 - confirm.
            torch.save(
                {
                    "args": args,
                    "state_dict": model_state,
                    "optim_state_dict": optimizer.state_dict(),
                    "fixed_z": fixed_z.cpu()
                }, os.path.join(args.save, "checkpt.pth"))

        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    # Incremental means over test batches:
                    #   mean_{i+1} = i/(i+1) * mean_i + value/(i+1)
                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, _) in enumerate(test_loader):
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) -
                                z).pow(2).mean(dim=-1).mean()
                        # Weight the old mean by i/(i+1) and the new sample
                        # by 1/(i+1), exactly as for the other statistics.
                        meandist = i / (i + 1) * meandist + dist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)
                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    # float() accepts both a 0-dim tensor and the initial
                    # Python float left in place by an empty test loader.
                    loss = float(lossmean)
                    metrics = torch.tensor(
                        [1., loss, float(meandist), steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = dist_utils.sum_tensor(
                        metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }
                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt,
                                    r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    # checkpt.pth was written above for this epoch; promote
                    # it when the averaged bits/dim improves.
                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs",
                                                "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    # Smallest square grid that fits all samples.
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename,
                               nrow=nb)
            if args.validate:
                break
def train(
    device, args, model, growth_model, regularization_coeffs, regularization_fns, logger
):
    """Fit ``model`` with Adam for ``args.niters`` iterations.

    Only ``model.parameters()`` are optimized. ``growth_model`` is put in
    eval mode once, forwarded to ``compute_loss`` / ``train_eval`` and
    included in saved checkpoints.

    Args:
        device: Target device for the training data and helper calls.
        args: Options object; reads ``lr``, ``weight_decay``, ``niters``,
            ``spectral_norm``, ``val_freq``, ``viz_freq``, ``save_freq``,
            ``save``, ``leaveout_timepoint`` and ``data`` (which must offer
            ``get_data()``, ``get_times()`` and ``get_shape()``).
        model: Network being trained.
        growth_model: Auxiliary network, not stepped here.
        regularization_coeffs: Weights paired with the regularization states
            returned by ``get_regularization``; zero weights are skipped.
        regularization_fns: Passed to ``append_regularization_to_log``.
        logger: Receives per-iteration and final log messages.
    """
    optimizer = optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # Five independent meters; tt_meter is updated but not part of the log line.
    time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter = (
        utils.RunningAverageMeter(0.93) for _ in range(5)
    )

    # Keep every sample whose time label differs from the left-out timepoint,
    # then convert to a float32 tensor on the requested device.
    observations = args.data.get_data()
    keep_mask = args.data.get_times() != args.leaveout_timepoint
    full_data = (
        torch.from_numpy(observations[keep_mask]).type(torch.float32).to(device)
    )

    log_template = (
        "Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) |"
        " NFE Forward {:.0f}({:.1f})"
        " | NFE Backward {:.0f}({:.1f})"
    )

    best_loss = float("inf")
    growth_model.eval()
    end = time.time()

    for itr in range(1, args.niters + 1):
        # Re-enter train mode each iteration.
        # NOTE(review): presumably because train_eval/visualize may leave the
        # model in eval mode - confirm against those helpers.
        model.train()
        optimizer.zero_grad()

        # --- optimisation step -------------------------------------------
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)

        loss = compute_loss(device, args, model, growth_model, logger, full_data)
        # The meter records the loss before any regularization is added.
        loss_meter.update(loss.item())

        regularized = len(regularization_coeffs) > 0
        if regularized:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            penalty = 0
            for state, coeff in zip(reg_states, regularization_coeffs):
                if coeff != 0:
                    penalty = penalty + state * coeff
            loss = loss + penalty

        total_time = count_total_time(model)
        # Counter value after the forward pass; the backward share is the
        # difference measured once backward() has run.
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        # --- bookkeeping -------------------------------------------------
        nfe_backward = count_nfe(model) - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        stats = []
        for meter in (time_meter, loss_meter, nfef_meter, nfeb_meter):
            stats.extend((meter.val, meter.avg))
        log_message = log_template.format(itr, *stats)
        if regularized:
            log_message = append_regularization_to_log(
                log_message, regularization_fns, reg_states
            )
        logger.info(log_message)

        # --- periodic evaluation -----------------------------------------
        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                # NOTE(review): best_loss is handed over but never reassigned
                # in this function, so it stays at +inf here - confirm that
                # train_eval does not rely on the caller tracking it.
                train_eval(
                    device, args, model, growth_model, itr, best_loss, logger, full_data
                )

        # --- periodic visualisation --------------------------------------
        if itr % args.viz_freq == 0:
            if args.data.get_shape()[0] <= 2:
                with torch.no_grad():
                    visualize(device, args, model, itr)
            else:
                logger.warning("Skipping vis as data dimension is >2")

        # --- periodic checkpoint -----------------------------------------
        if itr % args.save_freq == 0:
            snapshot = {
                "state_dict": model.state_dict(),
                "growth_state_dict": growth_model.state_dict(),
            }
            utils.save_checkpoint(snapshot, args.save, epoch=itr)

        # Restart the per-iteration timer so eval/plot/save time is excluded.
        end = time.time()

    logger.info("Training has finished.")