def get_ckpt_model_and_data(args):
    # Load checkpoint.
    checkpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
    ckpt_args = checkpt['args']
    state_dict = checkpt['state_dict']

    # Construct model and restore checkpoint.
    regularization_fns, regularization_coeffs = create_regularization_fns(ckpt_args)
    model = build_model_tabular(ckpt_args, 2, regularization_fns).to(device)
    if ckpt_args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(ckpt_args, model)

    model.load_state_dict(state_dict)
    model.to(device)

    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))

    # Load samples from dataset.
    data_samples = toy_data.inf_train_gen(ckpt_args.data, batch_size=2000)

    return model, data_samples
    z, change = model(x, zero)
    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)
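# standard_normal_logprob is a small helper used throughout these scripts. A minimal
# sketch of what it presumably computes (the element-wise log-density of a standard
# normal), stated here for reference rather than copied from the codebase:
import math
import torch


def standard_normal_logprob(z):
    # log N(z; 0, I), evaluated element-wise; summing over dimensions gives log p(z).
    log_z = -0.5 * math.log(2 * math.pi)
    return log_z - z.pow(2) / 2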
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command
    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert dist_utils.env_world_size() == distributed.get_world_size()
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device()
                              if torch.cuda.is_available() else "cpu")
    else:
        device = torch.cuda.current_device()

    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)

    # restore parameters
    if args.resume is not None:
        print('resume from checkpoint')
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
    # elif args.resume and args.validate:
    #     chkdir = os.path.dirname(args.resume)
    #     wall_clock = 0
    #     itr = 0
    #     best_loss = 0.0
    #     begin_epoch = 0
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, begin_epoch + 1):
        # compute test loss
        print('Evaluating')
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            if hasattr(model, 'module'):
                _state = model.module.state_dict()
            else:
                _state = model.state_dict()
            torch.save({
                "args": args,
                "state_dict": _state,
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

        # save real images, and generate samples at different temperatures
        fig_num = 64
        if True:  # args.save_real:
            for i, (x, y) in enumerate(test_loader):
                if i < 100:
                    pass
                elif i == 100:
                    real = x.size(0)
                else:
                    break
            if x.shape[0] > fig_num:
                x = x[:fig_num, ...]
            fig_filename = os.path.join(chkdir, "real.jpg")
            save_image(x.float() / 255.0, fig_filename, nrow=8)

        if True:  # args.generate:
            print('\nGenerating images... ')
            fixed_z = cvt(torch.randn(fig_num, *data_shape))
            nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
            for t in [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90,
                      0.85, 0.8, 0.75, 0.7, 0.65, 0.6]:
                # visualize samples and density
                fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples = model(t * fixed_z, reverse=True)
                x = unshift(generated_samples[0].view(-1, *data_shape), 8)
                save_image(x, fig_filename, nrow=nb)
        logger.info(k)

    if args.resume is not None:
        checkpt = torch.load(args.resume)

        # Backwards compatibility with an older version of the code.
        # TODO: remove upon release.
        filtered_state_dict = {}
        for k, v in checkpt['state_dict'].items():
            if 'diffeq.diffeq' not in k:
                filtered_state_dict[k.replace('module.', '')] = v
        model.load_state_dict(filtered_state_dict)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    if not args.evaluate:
        optimizer = Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

        time_meter = utils.RunningAverageMeter(0.98)
        loss_meter = utils.RunningAverageMeter(0.98)
        nfef_meter = utils.RunningAverageMeter(0.98)
        nfeb_meter = utils.RunningAverageMeter(0.98)
        tt_meter = utils.RunningAverageMeter(0.98)

        best_loss = float('inf')
        itr = 0
        n_vals_without_improvement = 0
    for k in model.state_dict().keys():
        logger.info(k)

    if args.resume is not None:
        logger.info('Training has finished.')
        model = restore_model(model, args.resume).to(device)
        set_cnf_options(args, model)
    else:
        logger.info('must use --resume flag to provide the state_dict to evaluate')
        exit(1)

    logger.info(model)
    nWeights = count_parameters(model)
    logger.info("Number of trainable parameters: {}".format(nWeights))

    logger.info('Evaluating model on test set.')
    model.eval()
    override_divergence_fn(model, "brute_force")
    bInverse = True  # check one batch for inverse error, for speed

    with torch.no_grad():
        test_loss = utils.AverageMeter()
        test_nfe = utils.AverageMeter()
        for itr, x in enumerate(batch_iter(data.tst.x, batch_size=test_batch_size)):
            x = cvt(x)
def train():
    model = build_model_tabular(args, 1).to(device)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()

        loss = compute_loss(args, model)
        loss_meter.update(loss.item())

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        optimizer.step()

        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
                tt_meter.val, tt_meter.avg
            )
        )
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                test_loss = compute_loss(args, model, batch_size=args.test_batch_size)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                xx = torch.linspace(-10, 10, 10000).view(-1, 1)
                true_p = data_density(xx)
                plt.plot(xx.view(-1).cpu().numpy(),
                         true_p.view(-1).exp().cpu().numpy(), label='True')

                model_p = model_density(xx, model)
                plt.plot(xx.view(-1).cpu().numpy(),
                         model_p.view(-1).exp().cpu().numpy(), label='Model')

                utils.makedirs(os.path.join(args.save, 'figs'))
                plt.savefig(os.path.join(args.save, 'figs', '{:06d}.jpg'.format(itr)))
                plt.close()
                model.train()

        end = time.time()

    logger.info('Training has finished.')
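# The helpers data_density and model_density used in the visualization step above are
# not shown. A minimal sketch of model_density, assuming the same CNF call convention
# used elsewhere in these scripts (model(x, zeros) returns the latent code and the
# accumulated change in log-density); the actual helper in the codebase may differ:
def model_density(x, model):
    x = x.to(device)
    zeros = torch.zeros(x.shape[0], 1).to(x)
    z, delta_logp = model(x, zeros)
    logpz = standard_normal_logprob(z).sum(1, keepdim=True)
    # change of variables: log p(x) = log p(z) - delta_logp
    return logpz - delta_logp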
def main():
    # os.system('shutdown -c')  # cancel previous shutdown command
    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)

        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert dist_utils.env_world_size() == distributed.get_world_size()
        if write_log:
            logger.info("Distributed: success (%d/%d)" %
                        (args.local_rank, distributed.get_world_size()))

    # get device
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"
    cvt = lambda x: x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    # model = model.cuda()
    if args.distributed:
        model = dist_utils.DDP(model,
                               device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    if not args.resume and write_log:
        with open(trainlog, 'w') as f:
            csvlogger = csv.DictWriter(f, traincolumns)
            csvlogger.writeheader()
        with open(testlog, 'w') as f:
            csvlogger = csv.DictWriter(f, testcolumns)
            csvlogger.writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr, weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)

    # restore parameters
    if args.resume is not None:
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # For visualization.
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()
            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for _, (x, y) in enumerate(train_loader):
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)
                    # x = x.clamp_(min=0, max=1)

                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    if np.isnan(bpd.data.item()):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd.data.item()):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(
                            reg_state * coeff
                            for reg_state, coeff in zip(reg_states, regularization_coeffs)
                            if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                               args.max_grad_norm)
                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([1., batch_size, loss.item(), bpd.item(),
                                            nfe_opt, grad_norm, *reg_states]).float()

                    rv = tuple(torch.tensor(0.) for r in reg_states)
                    total_gpus, batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = \
                        dist_utils.sum_tensor(metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(logdict,
                                                                     regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            torch.save({
                "args": args,
                "state_dict": model.module.state_dict() if torch.cuda.is_available()
                else model.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu()
            }, os.path.join(args.save, "checkpt.pth"))

        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, y) in enumerate(test_loader):
                        sh = x.shape
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) - z).pow(2).mean(dim=-1).mean()
                        meandist = i / (i + 1) * dist + meandist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)

                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = lossmean.item()
                    metrics = torch.tensor([1., loss, meandist, steps]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = \
                        dist_utils.sum_tensor(metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus)
                        }
                        csvlogger.writerow(logdict)
                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, "
                            "TT {:.2f}, Transport Cost {:.2e}".format(
                                epoch, eval_time, r_bpd / total_gpus,
                                r_steps / total_gpus, tt, r_mdist / total_gpus))

                    loss = r_bpd / total_gpus
                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

        # visualize samples and density
        if write_log:
            with torch.no_grad():
                fig_filename = os.path.join(args.save, "figs", "{:04d}.jpg".format(epoch))
                utils.makedirs(os.path.dirname(fig_filename))
                generated_samples, _, _ = model(fixed_z, reverse=True)
                generated_samples = generated_samples.view(-1, *data_shape)
                nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                save_image(unshift(generated_samples, nbits=args.nbits), fig_filename, nrow=nb)

        if args.validate:
            break
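# For reference, the bits-per-dim values logged above are conventionally obtained from
# the model log-likelihood (in nats, on dequantized data scaled to [0, 1]) as sketched
# below. This is the standard conversion, not necessarily the exact body of
# compute_bits_per_dim in this codebase:
import numpy as np


def bits_per_dim(logpx, ndims, nbits=8):
    # -logpx / (D * ln 2) converts nats to bits per dimension; the +nbits term accounts
    # for rescaling the 2**nbits integer levels into the unit interval before modeling.
    return -logpx / (ndims * np.log(2)) + nbits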
if __name__ == '__main__':
    if args.discrete:
        model = construct_discrete_model().to(device)
        model.load_state_dict(torch.load(args.checkpt)['state_dict'])
    else:
        model = build_model_tabular(args, 2).to(device)

        sd = torch.load(args.checkpt)['state_dict']
        fixed_sd = {}
        for k, v in sd.items():
            fixed_sd[k.replace('odefunc.odefunc', 'odefunc')] = v
        model.load_state_dict(fixed_sd)

    print(model)
    print("Number of trainable parameters: {}".format(count_parameters(model)))
    model.eval()

    p_samples = toy_data.inf_train_gen(args.data, batch_size=800**2)

    with torch.no_grad():
        sample_fn, density_fn = get_transforms(model)

        plt.figure(figsize=(10, 10))
        ax = plt.gca()
        viz_flow.plt_samples(p_samples, ax, npts=800)
        plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
        fig_filename = os.path.join(args.save, 'figs', 'true_samples.jpg')
        utils.makedirs(os.path.dirname(fig_filename))
        plt.savefig(fig_filename)
        plt.close()
    train_loader, val_loader = get_dataset(args)

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    aug_model = build_augmented_model_tabular(
        args,
        args.aug_size + args.effective_shape,
        regularization_fns=regularization_fns,
    )
    set_cnf_options(args, aug_model)
    logger.info(aug_model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(aug_model)))

    # optimizer
    parameter_list = list(aug_model.parameters())
    optimizer, num_params = optimizer_factory(args, parameter_list)
    print("Num of Parameters: %d" % num_params)

    # restore parameters
    itr = 0
    if args.resume is not None:
        checkpt = torch.load(args.resume, map_location=lambda storage, loc: storage)
        aug_model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt.keys():
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
def run(args, kwargs):
    # ==================================================================================
    # SNAPSHOTS
    # ==================================================================================
    args.model_signature = str(datetime.datetime.now())[0:19].replace(' ', '_')
    args.model_signature = args.model_signature.replace(':', '_')

    if args.automatic_saving:
        path = '{}/{}/{}/{}/{}/{}/{}/{}/{}/'.format(
            args.solver, args.dataset, args.layer_type, args.atol, args.rtol,
            args.atol_start, args.rtol_start, args.warmup_steps, args.manual_seed)
    else:
        path = 'test/'
    args.snap_dir = os.path.join(args.out_dir, path)
    if not os.path.exists(args.snap_dir):
        os.makedirs(args.snap_dir)

    # logger
    utils.makedirs(args.snap_dir)
    logger = utils.get_logger(logpath=os.path.join(args.snap_dir, 'logs'),
                              filepath=os.path.abspath(__file__))
    logger.info(args)

    # SAVING
    torch.save(args, args.snap_dir + 'config.config')

    # ==================================================================================
    # LOAD DATA
    # ==================================================================================
    train_loader, val_loader, test_loader, args = load_dataset(args, **kwargs)

    if not args.evaluate:
        nfef_meter = utils.AverageMeter()
        nfeb_meter = utils.AverageMeter()

        # ==============================================================================
        # SELECT MODEL
        # ==============================================================================
        # flow parameters and architecture choice are passed on to model through args
        if args.flow == 'no_flow':
            model = VAE.VAE(args)
        elif args.flow == 'planar':
            model = VAE.PlanarVAE(args)
        elif args.flow == 'iaf':
            model = VAE.IAFVAE(args)
        elif args.flow == 'orthogonal':
            model = VAE.OrthogonalSylvesterVAE(args)
        elif args.flow == 'householder':
            model = VAE.HouseholderSylvesterVAE(args)
        elif args.flow == 'triangular':
            model = VAE.TriangularSylvesterVAE(args)
        elif args.flow == 'cnf':
            model = CNFVAE.CNFVAE(args)
        elif args.flow == 'cnf_bias':
            model = CNFVAE.AmortizedBiasCNFVAE(args)
        elif args.flow == 'cnf_hyper':
            model = CNFVAE.HypernetCNFVAE(args)
        elif args.flow == 'cnf_lyper':
            model = CNFVAE.LypernetCNFVAE(args)
        elif args.flow == 'cnf_rank':
            model = CNFVAE.AmortizedLowRankCNFVAE(args)
        else:
            raise ValueError('Invalid flow choice')

        if args.retrain_encoder:
            logger.info(f"Initializing decoder from {args.model_path}")
            dec_model = torch.load(args.model_path)
            dec_sd = {}
            for k, v in dec_model.state_dict().items():
                if 'p_x' in k:
                    dec_sd[k] = v
            model.load_state_dict(dec_sd, strict=False)

        if args.cuda:
            logger.info("Model on GPU")
            model.cuda()

        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

        if args.retrain_encoder:
            parameters = []
            logger.info('Optimizing over:')
            for name, param in model.named_parameters():
                if 'p_x' not in name:
                    logger.info(name)
                    parameters.append(param)
        else:
            parameters = model.parameters()

        optimizer = optim.Adamax(parameters, lr=args.learning_rate, eps=1.e-7)

        # ==============================================================================
        # TRAINING
        # ==============================================================================
        train_loss = []
        val_loss = []

        # for early stopping
        best_loss = np.inf
        best_bpd = np.inf
        e = 0
        epoch = 0

        train_times = []

        for epoch in range(1, args.epochs + 1):
            atol, rtol = update_tolerances(args, epoch, decay_factors)
            print(atol)
            set_cnf_options(args, atol, rtol, model)

            t_start = time.time()
            if 'cnf' not in args.flow:
                tr_loss = train(epoch, train_loader, model, optimizer, args, logger)
            else:
                tr_loss, nfef_meter, nfeb_meter = train(epoch, train_loader, model, optimizer,
                                                        args, logger, nfef_meter, nfeb_meter)
            train_loss.append(tr_loss)
            train_times.append(time.time() - t_start)
            logger.info('One training epoch took %.2f seconds' % (time.time() - t_start))

            v_loss, v_bpd = evaluate(val_loader, model, args, logger, epoch=epoch)
            val_loss.append(v_loss)

            # early-stopping
            if v_loss < best_loss:
                e = 0
                best_loss = v_loss
                if args.input_type != 'binary':
                    best_bpd = v_bpd
                logger.info('->model saved<-')
                torch.save(model, args.snap_dir + 'model.model')
                # torch.save(model, snap_dir + args.flow + '_' + args.architecture + '.model')
            elif (args.early_stopping_epochs > 0) and (epoch >= args.warmup):
                e += 1
                if e > args.early_stopping_epochs:
                    break

            if args.input_type == 'binary':
                logger.info('--> Early stopping: {}/{} (BEST: loss {:.4f})\n'.format(
                    e, args.early_stopping_epochs, best_loss))
            else:
                logger.info('--> Early stopping: {}/{} (BEST: loss {:.4f}, bpd {:.4f})\n'.format(
                    e, args.early_stopping_epochs, best_loss, best_bpd))

            if math.isnan(v_loss):
                raise ValueError('NaN encountered!')

        train_loss = np.hstack(train_loss)
        val_loss = np.array(val_loss)
        plot_training_curve(train_loss, val_loss, fname=args.snap_dir + '/training_curve.pdf')

        # training time per epoch
        train_times = np.array(train_times)
        mean_train_time = np.mean(train_times)
        std_train_time = np.std(train_times, ddof=1)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        # ==============================================================================
        # EVALUATION
        # ==============================================================================
        logger.info(args)
        logger.info('Stopped after %d epochs' % epoch)
        logger.info('Average train time per epoch: %.2f +/- %.2f' %
                    (mean_train_time, std_train_time))

        final_model = torch.load(args.snap_dir + 'model.model')
        validation_loss, validation_bpd = evaluate(val_loader, final_model, args, logger)

    else:
        validation_loss = "N/A"
        validation_bpd = "N/A"
        logger.info(f"Loading model from {args.model_path}")
        final_model = torch.load(args.model_path)

    test_loss, test_bpd = evaluate(test_loader, final_model, args, logger, testing=True)

    logger.info('FINAL EVALUATION ON VALIDATION SET. ELBO (VAL): {:.4f}'.format(validation_loss))
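# update_tolerances, called at the top of the epoch loop above, is not shown in this
# excerpt. A hypothetical sketch, assuming it anneals the ODE solver tolerances from the
# (atol_start, rtol_start) warm-up values toward the target atol/rtol using precomputed
# per-epoch decay_factors; the actual schedule in this codebase may differ:
def update_tolerances(args, epoch, decay_factors):
    atol_decay, rtol_decay = decay_factors
    # Decay geometrically with the epoch index, but never below the target tolerances.
    atol = max(args.atol, args.atol_start * atol_decay ** epoch)
    rtol = max(args.rtol, args.rtol_start * rtol_decay ** epoch)
    return atol, rtol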
def train(args, model, growth_model):
    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    # optimizer = optim.Adam(set(model.parameters()) | set(growth_model.parameters()), ...)
    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    # growth_optimizer = optim.Adam(growth_model.parameters(), lr=args.lr,
    #                               weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)

    end = time.time()
    best_loss = float('inf')
    model.train()
    growth_model.eval()
    for itr in range(1, args.niters + 1):
        optimizer.zero_grad()
        # growth_optimizer.zero_grad()

        ### Train
        if args.spectral_norm:
            spectral_norm_power_iteration(model, 1)
        # if args.spectral_norm:
        #     spectral_norm_power_iteration(growth_model, 1)

        loss = compute_loss(args, model, growth_model)
        loss_meter.update(loss.item())

        if len(regularization_coeffs) > 0:
            # Only regularize on the last timepoint
            reg_states = get_regularization(model, regularization_coeffs)
            reg_loss = sum(
                reg_state * coeff
                for reg_state, coeff in zip(reg_states, regularization_coeffs)
                if coeff != 0
            )
            loss = loss + reg_loss
        # if len(growth_regularization_coeffs) > 0:
        #     growth_reg_states = get_regularization(growth_model, growth_regularization_coeffs)
        #     reg_loss = sum(
        #         reg_state * coeff
        #         for reg_state, coeff in zip(growth_reg_states, growth_regularization_coeffs)
        #         if coeff != 0
        #     )
        #     loss2 = loss2 + reg_loss

        total_time = count_total_time(model)
        nfe_forward = count_nfe(model)

        loss.backward()
        # loss2.backward()
        optimizer.step()
        # growth_optimizer.step()

        ### Eval
        nfe_total = count_nfe(model)
        nfe_backward = nfe_total - nfe_forward
        nfef_meter.update(nfe_forward)
        nfeb_meter.update(nfe_backward)

        time_meter.update(time.time() - end)
        tt_meter.update(total_time)

        log_message = (
            'Iter {:04d} | Time {:.4f}({:.4f}) | Loss {:.6f}({:.6f}) | NFE Forward {:.0f}({:.1f})'
            ' | NFE Backward {:.0f}({:.1f}) | CNF Time {:.4f}({:.4f})'.format(
                itr, time_meter.val, time_meter.avg, loss_meter.val, loss_meter.avg,
                nfef_meter.val, nfef_meter.avg, nfeb_meter.val, nfeb_meter.avg,
                tt_meter.val, tt_meter.avg
            )
        )
        if len(regularization_coeffs) > 0:
            log_message = append_regularization_to_log(log_message, regularization_fns, reg_states)
        logger.info(log_message)

        if itr % args.val_freq == 0 or itr == args.niters:
            with torch.no_grad():
                model.eval()
                growth_model.eval()
                test_loss = compute_loss(args, model, growth_model)
                test_nfe = count_nfe(model)
                log_message = '[TEST] Iter {:04d} | Test Loss {:.6f} | NFE {:.0f}'.format(
                    itr, test_loss, test_nfe)
                logger.info(log_message)

                if test_loss.item() < best_loss:
                    best_loss = test_loss.item()
                    utils.makedirs(args.save)
                    torch.save({
                        'args': args,
                        'state_dict': model.state_dict(),
                        'growth_state_dict': growth_model.state_dict(),
                    }, os.path.join(args.save, 'checkpt.pth'))
                model.train()

        if itr % args.viz_freq == 0:
            with torch.no_grad():
                model.eval()

                for i, tp in enumerate(timepoints):
                    p_samples = viz_sampler(tp)
                    sample_fn, density_fn = get_transforms(model, int_tps[:i + 1])
                    # growth_sample_fn, growth_density_fn = get_transforms(growth_model, int_tps[:i + 1])
                    plt.figure(figsize=(9, 3))
                    visualize_transform(
                        p_samples, torch.randn, standard_normal_logprob,
                        transform=sample_fn, inverse_transform=density_fn,
                        samples=True, npts=100, device=device
                    )
                    fig_filename = os.path.join(args.save, 'figs',
                                                '{:04d}_{:01d}.jpg'.format(itr, i))
                    utils.makedirs(os.path.dirname(fig_filename))
                    plt.savefig(fig_filename)
                    plt.close()

                    # visualize_transform(
                    #     p_samples, torch.rand, uniform_logprob,
                    #     transform=growth_sample_fn, inverse_transform=growth_density_fn,
                    #     samples=True, npts=800, device=device
                    # )
                    # fig_filename = os.path.join(args.save, 'growth_figs',
                    #                             '{:04d}_{:01d}.jpg'.format(itr, i))
                    # utils.makedirs(os.path.dirname(fig_filename))
                    # plt.savefig(fig_filename)
                    # plt.close()
                model.train()

        # if itr % args.viz_freq_growth == 0:
        #     with torch.no_grad():
        #         growth_model.eval()
        #         # Visualize growth transform
        #         growth_filename = os.path.join(args.save, 'growth', '{:04d}.jpg'.format(itr))
        #         utils.makedirs(os.path.dirname(growth_filename))
        #         visualize_growth(growth_model, data, labels, npts=200, device=device)
        #         plt.savefig(growth_filename)
        #         plt.close()
        #         growth_model.train()

        end = time.time()

    logger.info('Training has finished.')
if args.encoder == "ode_rnn": encoder = create_ode_rnn_encoder(args, device) else: raise NotImplementedError regularization_fns, regularization_coeffs = create_regularization_fns(args) aug_model = build_augmented_model_tabular( args, args.aug_size + args.effective_shape + args.latent_size, regularization_fns=regularization_fns, ) set_cnf_options(args, aug_model) logger.info(aug_model) logger.info( "Number of trainable parameters: {}".format(count_parameters(aug_model)) ) # optimizer optimizer = optim.Adam( list(aug_model.parameters()) + list(encoder.parameters()), lr=args.lr, weight_decay=args.weight_decay, ) num_params = sum(p.numel() for p in aug_model.parameters() if p.requires_grad) if args.aggressive: encoder_optimizer = optim.Adam( encoder.parameters(), lr=args.lr, weight_decay=args.weight_decay ) enc_num_params = sum(p.numel() for p in encoder.parameters() if p.requires_grad)
def main(args):
    # logger
    print(args.no_display_loss)
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        displaying=~args.no_display_loss,
    )
    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0
    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(device)
    if args.use_growth:
        if args.leaveout_timepoint == -1:
            growth_model_path = "../data/externel/growth_model_v2.ckpt"
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = "../data/growth/model_%d" % args.leaveout_timepoint
        else:
            print("WARNING: Cannot use growth with this timepoint")
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
        # if "growth_state_dict" not in state_dict:
        #     print("error growth model not in save")
        #     growth_model = None
        # else:
        #     checkpt = torch.load(args.save + "/checkpt.pth", map_location=device)
        #     growth_model.load_state_dict(checkpt["growth_state_dict"])
        # TODO can we load the arguments from the save?
        # eval_utils.generate_samples(
        #     device, args, model, growth_model, timepoint=args.leaveout_timepoint
        # )
        # with torch.no_grad():
        #     evaluate(device, args, model, growth_model)
        # exit()
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))
        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)