def visualize(device, args, model, itr):
    """Save one transform figure per entry of ``args.timepoints``.

    The model is put in eval mode first. For the timepoint at position ``k``,
    a batch of data indices is drawn with ``args.data.sample_index``, the
    sample/density transforms are built from ``args.int_tps[: k + 1]`` and the
    figure produced by ``visualize_transform`` is written to
    ``<args.save>/figs/<itr>_<k>.jpg``.

    Args:
        device: Forwarded to ``get_transforms`` and ``visualize_transform``.
        args: Run configuration; ``timepoints``, ``data``, ``viz_batch_size``,
            ``int_tps`` and ``save`` are read here.
        model: Model to visualize; only ``model.eval()`` is called on it
            directly in this function.
        itr: Iteration number, used only to build the figure filename.

    Note:
        This function does not switch the model back to training mode.
    """
    model.eval()

    for tp_index, timepoint in enumerate(args.timepoints):
        # Draw the batch of observed points shown for this timepoint.
        batch_idx = args.data.sample_index(args.viz_batch_size, timepoint)
        observed = args.data.get_data()[batch_idx]

        # Transforms are built from the integration times up to and
        # including this timepoint.
        tps_so_far = args.int_tps[: tp_index + 1]
        sample_fn, density_fn = get_transforms(device, args, model, tps_so_far)

        plt.figure(figsize=(9, 3))
        visualize_transform(
            observed,
            args.data.base_sample(),
            args.data.base_density(),
            transform=sample_fn,
            inverse_transform=density_fn,
            samples=True,
            npts=100,
            device=device,
        )

        figure_name = "{:04d}_{:01d}.jpg".format(itr, tp_index)
        fig_filename = os.path.join(args.save, "figs", figure_name)
        utils.makedirs(os.path.dirname(fig_filename))
        plt.savefig(fig_filename)
        plt.close()
# NOTE(review): this run of statements has no enclosing `def` in view and
# relies on `args`, `model`, `itr` and `device` being in scope; it looks like
# the tail of a training-loop body from a different script — TODO confirm
# where it belongs and at which indentation level.

# Write the run arguments and the model weights to <args.save>/checkpt.pth.
torch.save({
    'args': args,
    'state_dict': model.state_dict(),
}, os.path.join(args.save, 'checkpt.pth'))
# Put the model back in training mode.
model.train()

#if itr == 1 or itr % args.viz_freq == 0:
# Visualize whenever itr is a multiple of args.viz_freq.
if itr % args.viz_freq == 0:
    with torch.no_grad():
        model.eval()
        # Draw 20000 points from the generator selected by args.data.
        p_samples = toy_data.inf_train_gen(args.data, batch_size=20000)

        # The model's inverse is used as the sampling transform and its
        # forward pass as the density transform.
        sample_fn, density_fn = model.inverse, model.forward

        plt.figure(figsize=(9, 3))
        visualize_transform(p_samples, torch.randn, standard_normal_logprob, transform=sample_fn, inverse_transform=density_fn, samples=True, npts=400, device=device)
        # Figure path: <args.save>/figs/<itr zero-padded to 4 digits>.jpg
        fig_filename = os.path.join(args.save, 'figs', '{:04d}.jpg'.format(itr))
        # Echo the output path, surrounded by blank lines.
        print('')
        print(fig_filename)
        print('')
        utils.makedirs(os.path.dirname(fig_filename))
        plt.savefig(fig_filename)
        # NOTE(review): assumes only the plt.ion() call is commented out and
        # the three calls below are live — TODO confirm against the original.
        #plt.ion()
        # NOTE(review): with a non-interactive-mode GUI backend plt.show()
        # presumably blocks until the window is closed — confirm this is
        # wanted inside a training loop.
        plt.show()
        plt.pause(0.5)
        plt.close()
def train(args, model, growth_model):
    """Run the optimization loop for ``model`` with periodic validation and plots.

    Only ``model.parameters()`` are handed to the Adam optimizer.
    ``growth_model`` is put in eval mode, passed to ``compute_loss`` and its
    state dict is stored in the checkpoint next to the model's.

    Each iteration:
      1. computes the loss (plus weighted regularization terms when
         ``regularization_coeffs`` is non-empty), back-propagates and steps;
      2. updates running meters and logs one line;
      3. when ``itr`` is a multiple of ``args.val_freq`` or is the last
         iteration, recomputes the loss without gradients and writes
         ``<args.save>/checkpt.pth`` if that loss is the lowest seen so far;
      4. when ``itr`` is a multiple of ``args.viz_freq``, saves one figure
         per entry of ``timepoints`` under ``<args.save>/figs``.

    Args:
        args: Run configuration; ``lr``, ``weight_decay``, ``niters``,
            ``spectral_norm``, ``val_freq``, ``viz_freq`` and ``save`` are
            read here.
        model: Module being optimized.
        growth_model: Auxiliary module; never stepped by the optimizer here.

    Besides its parameters, this function reads the module-level names
    ``logger``, ``regularization_coeffs``, ``regularization_fns``,
    ``timepoints``, ``viz_sampler``, ``int_tps`` and ``device``.
    """
    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    # The growth model's parameters are not given to an optimizer; the
    # variants that did so are disabled:
    # optimizer = optim.Adam(set(model.parameters()) | set(growth_model.parameters()), ...)
    # growth_optimizer = optim.Adam(growth_model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter = (
        utils.RunningAverageMeter(0.93) for _ in range(5)
    )

    # One "value(average)" pair per meter, in the order the meters are
    # listed when the template is filled in below.
    log_template = (
        'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
        ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'
    )

    end = time.time()
    best_loss = float('inf')
    model.train()
    growth_model.eval()

    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        # growth_optimizer.zero_grad()

        ### Train
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)
        # if args.spectral_norm: spectral_norm_power_iteration(growth_model, 1)

        loss = compute_loss(args, model, growth_model)
        # The loss meter records the value before regularization is added.
        loss_meter.update(loss.item())

        use_regularization = len(regularization_coeffs) > 0
        if use_regularization:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            penalty = 0
            for state, coeff in zip(reg_states, regularization_coeffs):
                if coeff != 0:
                    penalty = penalty + state * coeff
            loss = loss + penalty

        # Disabled: regularization of the growth model.
        # if len(growth_regularization_coeffs) > 0:
        #     growth_reg_states = get_regularization(growth_model, growth_regularization_coeffs)
        #     reg_loss = sum(
        #         reg_state * coeff for reg_state, coeff in zip(growth_reg_states, growth_regularization_coeffs) if coeff != 0
        #     )
        #     loss2 = loss2 + reg_loss

        # Both counters are read before backward() so that the backward
        # function-evaluation count can be obtained by subtraction below.
        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        # loss2.backward()
        optimizer.step()
        # growth_optimizer.step()

        ### Eval
        nfe_backward = count_nfe(model) - nfe_forward

        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)
        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        meter_fields = []
        for meter in (time_meter, loss_meter, nfef_meter, nfeb_meter, tt_meter):
            meter_fields.extend((meter.val, meter.avg))
        log_message = log_template.format(itr, *meter_fields)
        if use_regularization:
            log_message = append_regularization_to_log(log_message, regularization_fns, reg_states)
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                growth_model.eval()
                test_loss = compute_loss(args, model, growth_model)
                test_nfe = count_nfe(model)
                logger.info(
                    '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(itr, test_loss, test_nfe)
                )

                # Checkpoint only when this is the lowest test loss so far.
                test_loss_value = test_loss.item()
                if test_loss_value < best_loss:
                    best_loss = test_loss_value
                    utils.makedirs(args.save)
                    checkpoint = {
                        'args': args,
                        'state_dict': model.state_dict(),
                        'growth_state_dict': growth_model.state_dict(),
                    }
                    torch.save(checkpoint, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()
                for tp_index, timepoint in enumerate(timepoints):
                    p_samples = viz_sampler(timepoint)
                    # Transforms use the integration times up to and
                    # including this timepoint.
                    sample_fn, density_fn = get_transforms(model, int_tps[:tp_index + 1])
                    # growth_sample_fn, growth_density_fn = get_transforms(growth_model, int_tps[:i+1])
                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples,
                        torch.randn,
                        standard_normal_logprob,
                        transform=sample_fn,
                        inverse_transform=density_fn,
                        samples=True,
                        npts=100,
                        device=device,
                    )
                    figure_name = '{:04d}_{:01d}.jpg'.format(itr, tp_index)
                    fig_filename = os.path.join(args.save, 'figs', figure_name)
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()

                    # Disabled: per-timepoint figure for the growth model.
                    # visualize_transform(
                    #     p_samples, torch.rand, uniform_logprob, transform=growth_sample_fn,
                    #     inverse_transform=growth_density_fn,
                    #     samples=True, npts=800, device=device
                    # )
                    # fig_filename = os.path.join(args.save, 'growth_figs', '{:04d}_{:01d}.jpg'.format(itr, i))
                    # utils.makedirs(os.path.dirname(fig_filename))
                    # plt.savefig(fig_filename)
                    # plt.close()
                model.train()

        # Disabled: periodic visualization of the growth transform.
        # if itr % args.viz_freq_growth == 0:
        #     with torch.no_grad():
        #         growth_model.eval()
        #         # Visualize growth transform
        #         growth_filename = os.path.join(args.save, 'growth', '{:04d}.jpg'.format(itr))
        #         utils.makedirs(os.path.dirname(growth_filename))
        #         visualize_growth(growth_model, data, labels, npts=200, device=device)
        #         plt.savefig(growth_filename)
        #         plt.close()
        #     growth_model.train()

        # Restart the per-iteration timer after validation and plotting.
        end = time.time()

    logger.info('Training has finished.')