def get_ckpt_model_and_data(args):
    """Rebuild a 2-D tabular CNF from a checkpoint and draw toy-data samples.

    Args:
        args: namespace whose ``checkpt`` attribute is the checkpoint path.
            The checkpoint must hold ``'args'`` (the saved training options)
            and ``'state_dict'`` (the model weights).

    Returns:
        ``(model, data_samples)``: the restored model, moved to the
        module-level ``device``, and a batch of 2000 samples from
        ``toy_data.inf_train_gen`` for the dataset named in the saved options.
    """
    # Keep every storage where it is deserialised; the model is moved later.
    ckpt = torch.load(args.checkpt, map_location=lambda storage, loc: storage)
    saved_args, weights = ckpt['args'], ckpt['state_dict']

    # Rebuild with the checkpoint's own options (presumably so that the
    # architecture matches the saved weights).
    reg_fns, _ = create_regularization_fns(saved_args)
    net = build_model_tabular(saved_args, 2, reg_fns).to(device)
    if saved_args.spectral_norm:
        add_spectral_norm(net)
    set_cnf_options(saved_args, net)

    net.load_state_dict(weights)
    net.to(device)

    print(net)
    n_params = count_parameters(net)
    print("Number of trainable parameters: {}".format(n_params))

    samples = toy_data.inf_train_gen(saved_args.data, batch_size=2000)
    return net, samples
def gen_model(scale=10, fraction=0.5):
    """Rebuild the 5-D normalizing-flow model from a previous fit.

    Reads the training arguments from ``args.pkl`` and the weights from
    ``model_10000.pt``. Both paths are relative to the current working
    directory.

    Args:
        scale: unused; kept for backward compatibility. Earlier versions used
            it to build checkpoint paths.
        fraction: unused; kept for backward compatibility (same history).

    Returns:
        The model with weights loaded, on ``cuda:0`` when available, else CPU.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # NOTE(review): unpickling can execute arbitrary code -- only load a
    # trusted args.pkl. The context manager closes the handle (it used to leak).
    with open('args.pkl', 'rb') as f:
        args = pkl.load(f)

    regularization_fns, _ = create_regularization_fns(args)
    model = build_model_tabular(args, 5, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    # map_location lets a state dict saved on a GPU load on a CPU-only host
    # (or on a different GPU index).
    state_dict = torch.load('model_10000.pt', map_location=device)
    model.load_state_dict(state_dict)
    return model
def compute_loss(args, model, batch_size=None):
    """Return the negative mean log-likelihood of a fresh toy-data batch.

    Args:
        args: namespace providing ``data`` (toy dataset name passed to
            ``toy_data.inf_train_gen``) and ``batch_size``.
        model: flow called as ``model(x, zero)``; its two outputs are the
            latent ``z`` and a correction subtracted from ``log p(z)``.
        batch_size: number of samples to draw. ``None`` (the default) means
            ``args.batch_size``. It is resolved at call time; the old default
            expression required a global ``args`` at import and froze its value.

    Returns:
        Scalar tensor ``-mean(log p(x))``.
    """
    if batch_size is None:
        batch_size = args.batch_size

    x = toy_data.inf_train_gen(args.data, batch_size=batch_size)
    x = torch.from_numpy(x).type(torch.float32).to(device)
    zero = torch.zeros(x.shape[0], 1).to(x)

    z, change = model(x, zero)

    # log p(x) = log N(z; 0, I) summed over features, minus the model's correction.
    logpx = standard_normal_logprob(z).sum(1, keepdim=True) - change
    loss = -torch.mean(logpx)
    return loss


if __name__ == '__main__':
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, 2, regularization_fns).to(device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    logger.info(model)
    logger.info("Number of trainable parameters: {}".format(count_parameters(model)))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    time_meter = utils.RunningAverageMeter(0.93)
    loss_meter = utils.RunningAverageMeter(0.93)
    nfef_meter = utils.RunningAverageMeter(0.93)
    nfeb_meter = utils.RunningAverageMeter(0.93)
    tt_meter = utils.RunningAverageMeter(0.93)
def main():
    """Checkpoint a (possibly resumed) image CNF and write image grids.

    Driven by the module-level ``args`` namespace and ``write_log`` flag.

    Steps:
        1. Set up logging and, when ``args.distributed``, the process group.
        2. Build the model and optimizer, restoring from ``args.resume`` if set.
        3. On rank 0, save the state to ``<args.save>/checkpt_<epoch>.pth``.
        4. Save one batch of real test images to ``<chkdir>/real.jpg``.
        5. Save model samples at several temperatures to
           ``<chkdir>/generated-T<t>.jpg``.

    ``chkdir`` is ``args.save`` for a fresh run and the directory holding
    ``args.resume`` when resuming.
    """
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)
        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert dist_utils.env_world_size() == distributed.get_world_size()
        if write_log:
            logger.info("Distributed: success (%d/%d)"
                        % (args.local_rank, distributed.get_world_size()))
        device = torch.device("cuda:%d" % torch.cuda.current_device()
                              if torch.cuda.is_available() else "cpu")
    else:
        # NOTE(review): this is a CUDA device index, so this path needs a GPU
        # (the model is moved with .cuda() below anyway).
        device = torch.cuda.current_device()

    def cvt(x):
        # Cast to float32 and move to the working device.
        return x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, _ = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns).cuda()
    if args.distributed:
        model = dist_utils.DDP(model, device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    # Fresh run: (re)create both CSV logs containing only their header rows.
    if not args.resume and write_log:
        for path, columns in ((trainlog, traincolumns), (testlog, testcolumns)):
            with open(path, 'w') as f:
                csv.DictWriter(f, columns).writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)
    else:
        # Previously fell through and failed later with an unbound `optimizer`.
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    # restore parameters
    if args.resume is not None:
        print('resume from checkpoint')
        checkpt = torch.load(args.resume,
                             map_location=lambda storage, loc: storage.cuda(args.local_rank))
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt:
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # Noise stored in the checkpoint. It is only drawn when logging (as
    # before), so default to None instead of leaving the name unbound on a
    # non-logging rank 0.
    fixed_z = None
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
        chkdir = args.save
    else:
        chkdir = os.path.dirname(args.resume)
        filename = os.path.join(chkdir, 'test.csv')
        print(filename)
        tedf = pd.read_csv(filename)
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        # Resume bookkeeping mirrored from the training script (not used below).
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    # A single evaluation pass, for the epoch after the last logged one.
    epoch = begin_epoch
    print('Evaluating')
    model.eval()
    if args.local_rank == 0:
        utils.makedirs(args.save)
        # Save the inner module's weights when the model is wrapped (e.g. DDP).
        net = model.module if hasattr(model, 'module') else model
        torch.save({
            "args": args,
            "state_dict": net.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
            "fixed_z": fixed_z.cpu() if fixed_z is not None else None,
        }, os.path.join(args.save, "checkpt_%d.pth" % epoch))

    fig_num = 64

    # Real test images: batch index 100 of test_loader (or the last batch if
    # the loader is shorter), truncated to fig_num images. The old loop read
    # one batch past index 100 and failed on an empty loader.
    real_batch = None
    for i, (x, _y) in enumerate(test_loader):
        real_batch = x
        if i >= 100:
            break
    if real_batch is not None:
        real_batch = real_batch[:fig_num, ...]
        save_image(real_batch.float() / 255.0, os.path.join(chkdir, "real.jpg"), nrow=8)

    print('\nGenerating images... ')
    sample_z = cvt(torch.randn(fig_num, *data_shape))
    nb = int(np.ceil(np.sqrt(float(sample_z.size(0)))))
    # Bit depth for unshift: args.nbits, as in the training script (8 if the
    # namespace has no such attribute, matching the old hard-coded value).
    nbits = getattr(args, 'nbits', 8)
    temperatures = [1.0, 0.99, 0.98, 0.97, 0.96, 0.95, 0.93, 0.92, 0.90,
                    0.85, 0.8, 0.75, 0.7, 0.65, 0.6]
    # Pure sampling: no autograd graph is needed.
    with torch.no_grad():
        for t in temperatures:
            # visualize samples and density
            fig_filename = os.path.join(chkdir, "generated-T%g.jpg" % t)
            utils.makedirs(os.path.dirname(fig_filename))
            generated_samples = model(t * sample_z, reverse=True)
            images = unshift(generated_samples[0].view(-1, *data_shape), nbits=nbits)
            save_image(images, fig_filename, nrow=nb)
def main():
    """Train an image CNF, validating, checkpointing and sampling as it goes.

    Configuration comes from the module-level ``args`` namespace.
    ``write_log`` gates all logging, CSV rows and sample images.

    Each epoch does the following:
        1. One training pass over ``train_loader`` (skipped when
           ``args.validate``).
        2. On rank 0, a checkpoint to ``<args.save>/checkpt.pth``.
        3. Every ``args.val_freq`` epochs (or when validating), a test pass.
           The best checkpoint by bits/dim is copied to ``best.pth``, and a
           sample grid is written to ``<args.save>/figs/<epoch>.jpg``.
    """
    # os.system('shutdown -c')  # cancel previous shutdown command

    if write_log:
        utils.makedirs(args.save)
        logger = utils.get_logger(logpath=os.path.join(args.save, 'logs'),
                                  filepath=os.path.abspath(__file__))
        logger.info(args)
        args_file_path = os.path.join(args.save, 'args.yaml')
        with open(args_file_path, 'w') as f:
            yaml.dump(vars(args), f, default_flow_style=False)

    if args.distributed:
        if write_log:
            logger.info('Distributed initializing process group')
        torch.cuda.set_device(args.local_rank)
        distributed.init_process_group(backend=args.dist_backend,
                                       init_method=args.dist_url,
                                       world_size=dist_utils.env_world_size(),
                                       rank=env_rank())
        assert dist_utils.env_world_size() == distributed.get_world_size()
        if write_log:
            logger.info("Distributed: success (%d/%d)"
                        % (args.local_rank, distributed.get_world_size()))

    # Device is currently pinned to CPU; the GPU selection was:
    # device = torch.device("cuda:%d" % torch.cuda.current_device() if torch.cuda.is_available() else "cpu")
    device = "cpu"

    def cvt(x):
        # Cast to float32 and move to the working device.
        return x.type(torch.float32).to(device, non_blocking=True)

    # load dataset
    train_loader, test_loader, data_shape = get_dataset(args)

    trainlog = os.path.join(args.save, 'training.csv')
    testlog = os.path.join(args.save, 'test.csv')

    traincolumns = ['itr', 'wall', 'itr_time', 'loss', 'bpd', 'fe', 'total_time', 'grad_norm']
    testcolumns = ['wall', 'epoch', 'eval_time', 'bpd', 'fe', 'total_time', 'transport_cost']

    # build model
    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = create_model(args, data_shape, regularization_fns)
    if args.distributed:
        model = dist_utils.DDP(model, device_ids=[args.local_rank],
                               output_device=args.local_rank)

    traincolumns = append_regularization_keys_header(traincolumns, regularization_fns)

    # Fresh run: (re)create both CSV logs containing only their header rows.
    if not args.resume and write_log:
        for path, columns in ((trainlog, traincolumns), (testlog, testcolumns)):
            with open(path, 'w') as f:
                csv.DictWriter(f, columns).writeheader()

    set_cnf_options(args, model)

    if write_log:
        logger.info(model)
        logger.info("Number of trainable parameters: {}".format(count_parameters(model)))
        logger.info('Iters per train epoch: {}'.format(len(train_loader)))
        logger.info('Iters per test: {}'.format(len(test_loader)))

    # optimizer
    if args.optimizer == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == 'sgd':
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              weight_decay=args.weight_decay,
                              momentum=0.9, nesterov=False)
    else:
        # Previously fell through and failed later with an unbound `optimizer`.
        raise ValueError('Unknown optimizer: {}'.format(args.optimizer))

    # restore parameters
    if args.resume is not None:
        # Map storages straight onto the working device. The former
        # storage.cuda(local_rank) mapping crashed on CPU-only hosts even
        # though device == "cpu".
        checkpt = torch.load(args.resume, map_location=device)
        model.load_state_dict(checkpt["state_dict"])
        if "optim_state_dict" in checkpt:
            optimizer.load_state_dict(checkpt["optim_state_dict"])
            # Manually move optimizer state to device.
            for state in optimizer.state.values():
                for k, v in state.items():
                    if torch.is_tensor(v):
                        state[k] = cvt(v)

    # Fixed noise for the per-epoch sample grid. Only drawn when logging (as
    # before); None otherwise so the checkpoint save below cannot hit an
    # unbound name.
    fixed_z = None
    if write_log:
        fixed_z = cvt(torch.randn(min(args.test_batch_size, 100), *data_shape))

    if write_log:
        time_meter = utils.RunningAverageMeter(0.97)
        bpd_meter = utils.RunningAverageMeter(0.97)
        loss_meter = utils.RunningAverageMeter(0.97)
        steps_meter = utils.RunningAverageMeter(0.97)
        grad_meter = utils.RunningAverageMeter(0.97)
        tt_meter = utils.RunningAverageMeter(0.97)

    if not args.resume:
        best_loss = float("inf")
        itr = 0
        wall_clock = 0.
        begin_epoch = 1
    else:
        chkdir = os.path.dirname(args.resume)
        tedf = pd.read_csv(os.path.join(chkdir, 'test.csv'))
        trdf = pd.read_csv(os.path.join(chkdir, 'training.csv'))
        wall_clock = trdf['wall'].to_numpy()[-1]
        itr = trdf['itr'].to_numpy()[-1]
        best_loss = tedf['bpd'].min()
        begin_epoch = int(tedf['epoch'].to_numpy()[-1] + 1)  # not exactly correct

    if args.distributed:
        if write_log:
            logger.info('Syncing machines before training')
        dist_utils.sum_tensor(torch.tensor([1.0]).float().cuda())

    for epoch in range(begin_epoch, args.num_epochs + 1):
        if not args.validate:
            model.train()

            with open(trainlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, traincolumns)

                for x, _y in train_loader:
                    start = time.time()
                    update_lr(optimizer, itr)
                    optimizer.zero_grad()

                    # cast data and move to device
                    x = add_noise(cvt(x), nbits=args.nbits)

                    # compute loss
                    bpd, (x, z), reg_states = compute_bits_per_dim(x, model)
                    bpd_value = bpd.item()
                    if np.isnan(bpd_value):
                        raise ValueError('model returned nan during training')
                    elif np.isinf(bpd_value):
                        raise ValueError('model returned inf during training')

                    loss = bpd
                    if regularization_coeffs:
                        reg_loss = sum(reg_state * coeff
                                       for reg_state, coeff in zip(reg_states, regularization_coeffs)
                                       if coeff != 0)
                        loss = loss + reg_loss
                    total_time = count_total_time(model)

                    loss.backward()
                    nfe_opt = count_nfe(model)
                    if write_log:
                        steps_meter.update(nfe_opt)
                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                               args.max_grad_norm)

                    optimizer.step()

                    itr_time = time.time() - start
                    wall_clock += itr_time

                    batch_size = x.size(0)
                    metrics = torch.tensor([1., batch_size, loss.item(), bpd_value, nfe_opt,
                                            grad_norm, *reg_states]).float()

                    # Summed across processes; the leading 1. counts processes.
                    total_gpus, _batch_total, r_loss, r_bpd, r_nfe, r_grad_norm, *rv = \
                        dist_utils.sum_tensor(metrics).cpu().numpy()

                    if write_log:
                        time_meter.update(itr_time)
                        bpd_meter.update(r_bpd / total_gpus)
                        loss_meter.update(r_loss / total_gpus)
                        grad_meter.update(r_grad_norm / total_gpus)
                        tt_meter.update(total_time)

                        fmt = '{:.4f}'
                        logdict = {
                            'itr': itr,
                            'wall': fmt.format(wall_clock),
                            'itr_time': fmt.format(itr_time),
                            'loss': fmt.format(r_loss / total_gpus),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'total_time': fmt.format(total_time),
                            'fe': r_nfe / total_gpus,
                            'grad_norm': fmt.format(r_grad_norm / total_gpus),
                        }
                        if regularization_coeffs:
                            rv = tuple(v_ / total_gpus for v_ in rv)
                            logdict = append_regularization_csv_dict(logdict, regularization_fns, rv)
                        csvlogger.writerow(logdict)

                        if itr % args.log_freq == 0:
                            log_message = (
                                "Itr {:06d} | Wall {:.3e}({:.2f}) | "
                                "Time/Itr {:.2f}({:.2f}) | BPD {:.2f}({:.2f}) | "
                                "Loss {:.2f}({:.2f}) | "
                                "FE {:.0f}({:.0f}) | Grad Norm {:.3e}({:.3e}) | "
                                "TT {:.2f}({:.2f})".format(
                                    itr, wall_clock, wall_clock / (itr + 1),
                                    time_meter.val, time_meter.avg,
                                    bpd_meter.val, bpd_meter.avg,
                                    loss_meter.val, loss_meter.avg,
                                    steps_meter.val, steps_meter.avg,
                                    grad_meter.val, grad_meter.avg,
                                    tt_meter.val, tt_meter.avg))
                            if regularization_coeffs:
                                log_message = append_regularization_to_log(
                                    log_message, regularization_fns, rv)
                            logger.info(log_message)

                    itr += 1

        # compute test loss
        model.eval()
        if args.local_rank == 0:
            utils.makedirs(args.save)
            # Only a wrapped (e.g. DDP) model has `.module`. The former
            # `torch.cuda.is_available()` test crashed for an unwrapped model
            # on a GPU host.
            net = model.module if hasattr(model, 'module') else model
            torch.save({
                "args": args,
                "state_dict": net.state_dict(),
                "optim_state_dict": optimizer.state_dict(),
                "fixed_z": fixed_z.cpu() if fixed_z is not None else None,
            }, os.path.join(args.save, "checkpt.pth"))

        if epoch % args.val_freq == 0 or args.validate:
            with open(testlog, 'a') as f:
                if write_log:
                    csvlogger = csv.DictWriter(f, testcolumns)
                with torch.no_grad():
                    start = time.time()
                    if write_log:
                        logger.info("validating...")

                    # Incremental means over test batches:
                    # m_i = i/(i+1) * m_{i-1} + v_i/(i+1).
                    lossmean = 0.
                    meandist = 0.
                    steps = 0
                    tt = 0.
                    for i, (x, _y) in enumerate(test_loader):
                        x = shift(cvt(x), nbits=args.nbits)
                        loss, (x, z), _ = compute_bits_per_dim(x, model)
                        dist = (x.view(x.size(0), -1) - z).pow(2).mean(dim=-1).mean()
                        # The weights were previously swapped here
                        # (i/(i+1) on the new value), which dropped the first
                        # batch and over-weighted the latest one.
                        meandist = i / (i + 1) * meandist + dist / (i + 1)
                        lossmean = i / (i + 1) * lossmean + loss / (i + 1)
                        tt = i / (i + 1) * tt + count_total_time(model) / (i + 1)
                        steps = i / (i + 1) * steps + count_nfe(model) / (i + 1)

                    loss = float(lossmean)
                    metrics = torch.tensor([1., loss, float(meandist), float(steps)]).float()

                    total_gpus, r_bpd, r_mdist, r_steps = \
                        dist_utils.sum_tensor(metrics).cpu().numpy()
                    eval_time = time.time() - start

                    if write_log:
                        fmt = '{:.4f}'
                        logdict = {
                            'epoch': epoch,
                            'eval_time': fmt.format(eval_time),
                            'bpd': fmt.format(r_bpd / total_gpus),
                            'wall': fmt.format(wall_clock),
                            'total_time': fmt.format(tt),
                            'transport_cost': fmt.format(r_mdist / total_gpus),
                            'fe': '{:.2f}'.format(r_steps / total_gpus),
                        }
                        csvlogger.writerow(logdict)

                        logger.info(
                            "Epoch {:04d} | Time {:.4f}, Bit/dim {:.4f}, Steps {:.4f}, TT {:.2f}, Transport Cost {:.2e}"
                            .format(epoch, eval_time, r_bpd / total_gpus,
                                    r_steps / total_gpus, tt, r_mdist / total_gpus))

                    loss = r_bpd / total_gpus

                    if loss < best_loss and args.local_rank == 0:
                        best_loss = loss
                        shutil.copyfile(os.path.join(args.save, "checkpt.pth"),
                                        os.path.join(args.save, "best.pth"))

            # visualize samples and density
            if write_log:
                with torch.no_grad():
                    fig_filename = os.path.join(args.save, "figs", "{:04d}.jpg".format(epoch))
                    utils.makedirs(os.path.dirname(fig_filename))
                    generated_samples, _, _ = model(fixed_z, reverse=True)
                    generated_samples = generated_samples.view(-1, *data_shape)
                    nb = int(np.ceil(np.sqrt(float(fixed_z.size(0)))))
                    save_image(unshift(generated_samples, nbits=args.nbits),
                               fig_filename, nrow=nb)

        if args.validate:
            break
def main():
    """Semi-supervised CIFAR-10 training of a WideResNet with a flow prior.

    Uses the module-level ``args``, ``state``, ``num_classes`` and
    ``use_cuda``, and updates the global ``best_acc``.

    Each epoch does the following:
        1. Train via ``train``.
        2. Evaluate the EMA model on the labeled-train, validation and test
           loaders.
        3. Log to ``<args.out>/log.txt`` and TensorBoard.
        4. Save a checkpoint through ``save_checkpoint``.
    """
    global best_acc

    if not os.path.isdir(args.out):
        mkdir_p(args.out)

    # Data
    print('==> Preparing cifar10')
    transform_train = transforms.Compose([
        dataset.RandomPadandCrop(32),
        dataset.RandomFlip(),
        dataset.ToTensor(),
    ])
    transform_val = transforms.Compose([
        dataset.ToTensor(),
    ])

    # Dataset root can be overridden with args.data_root; the fallback is the
    # previously hard-coded location.
    data_root = (getattr(args, 'data_root', None)
                 or '/home/fengchan/stor/dataset/original-data/cifar10')
    train_labeled_set, train_unlabeled_set, val_set, test_set = dataset.get_cifar10(
        data_root, args.n_labeled,
        transform_train=transform_train, transform_val=transform_val)

    labeled_trainloader = data.DataLoader(train_labeled_set, batch_size=args.batch_size,
                                          shuffle=True, num_workers=0, drop_last=True)
    unlabeled_trainloader = data.DataLoader(train_unlabeled_set, batch_size=args.batch_size,
                                            shuffle=True, num_workers=0, drop_last=True)
    val_loader = data.DataLoader(val_set, batch_size=args.batch_size,
                                 shuffle=False, num_workers=0)
    test_loader = data.DataLoader(test_set, batch_size=args.batch_size,
                                  shuffle=False, num_workers=0)

    # Model
    print("==> creating WRN-28-2")

    def create_model(ema=False):
        # WideResNet classifier on the GPU. For the EMA copy the parameters are
        # detached in place, so autograd does not track them (presumably
        # WeightEMA updates them instead).
        model = models.WideResNet(num_classes=num_classes)
        model = model.cuda()
        if ema:
            for param in model.parameters():
                param.detach_()
        return model

    data_shape = [3, 32, 32]
    # NOTE(review): the regularization functions are not passed to the flow below.
    _regularization_fns, _ = create_regularization_fns(args)

    def create_cnf():
        # Flow model over images of `data_shape`, on the GPU when use_cuda.
        cnf = create_nf_model(args, data_shape, regularization_fns=None)
        return cnf.cuda() if use_cuda else cnf

    model = create_model()
    ema_model = create_model(ema=True)
    cnf = create_cnf()
    if args.spectral_norm:
        # `logger` is only assigned further down in this function, so passing
        # it here raised UnboundLocalError. Use the one-argument form used
        # elsewhere.
        add_spectral_norm(cnf)
    set_cnf_options(args, cnf)

    cudnn.benchmark = True
    print(' Total params: %.2fM' % (sum(p.numel() for p in model.parameters()) / 1000000.0))

    train_criterion = SemiLoss()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    # CNF
    cnf_optimizer = optim.Adam(cnf.parameters(), lr=args.lr, weight_decay=args.weight_decay)

    ema_optimizer = WeightEMA(model, ema_model, alpha=args.ema_decay)
    start_epoch = 0

    # generate prior
    means = generate_gaussian_means(num_classes, data_shape, seed=num_classes)
    title = 'noisy-cifar-10'
    # Resume
    if args.resume:
        # Load checkpoint.
        print('==> Resuming from checkpoint..')
        assert os.path.isfile(args.resume), 'Error: no checkpoint directory found!'
        args.out = os.path.dirname(args.resume)
        checkpoint = torch.load(args.resume)
        best_acc = checkpoint['best_acc']
        start_epoch = checkpoint['epoch']
        model.load_state_dict(checkpoint['state_dict'])
        ema_model.load_state_dict(checkpoint['ema_state_dict'])
        cnf.load_state_dict(checkpoint['cnf_state_dict'])
        means = checkpoint['means']
        cnf_optimizer.load_state_dict(checkpoint['cnf_optimizer'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title, resume=True)
    else:
        logger = Logger(os.path.join(args.out, 'log.txt'), title=title)
        logger.set_names([
            'Train Loss', 'Train Loss X', 'Train Loss U', 'Train loss NLL X',
            'Train loss NLL U', 'Train loss mixed X', 'Valid Loss', 'Valid Acc.',
            'Test Loss', 'Test Acc.'
        ])

    means = means.cuda() if use_cuda else means
    prior = SSLGaussMixture(means, device='cuda' if use_cuda else 'cpu')

    writer = SummaryWriter(args.out)
    test_accs = []

    # Train and val
    for epoch in range(start_epoch, args.epochs):
        print('\nEpoch: [%d | %d] LR: %f' % (epoch + 1, args.epochs, state['lr']))

        (train_loss, train_loss_x, train_loss_u,
         train_loss_nll_x, train_loss_nll_u, train_loss_mixed_x) = train(
            labeled_trainloader, unlabeled_trainloader, model, cnf, prior,
            cnf_optimizer, optimizer, ema_optimizer, train_criterion, epoch, use_cuda)
        _, train_acc = validate(labeled_trainloader, ema_model, criterion, epoch,
                                use_cuda, mode='Train Stats')
        val_loss, val_acc = validate(val_loader, ema_model, criterion, epoch,
                                     use_cuda, mode='Valid Stats')
        test_loss, test_acc = validate(test_loader, ema_model, criterion, epoch,
                                       use_cuda, mode='Test Stats ')

        step = args.train_iteration * (epoch + 1)

        # Each loss/accuracy is written once. Previously train_loss_nll_x was
        # logged twice and train_loss_x / train_loss_u were never logged.
        scalars = {
            'losses/train_loss': train_loss,
            'losses/train_loss_x': train_loss_x,
            'losses/train_loss_u': train_loss_u,
            'losses/train_loss_nll_x': train_loss_nll_x,
            'losses/train_loss_nll_u': train_loss_nll_u,
            'losses/train_loss_mixed_x': train_loss_mixed_x,
            'losses/valid_loss': val_loss,
            'losses/test_loss': test_loss,
            'accuracy/train_acc': train_acc,
            'accuracy/val_acc': val_acc,
            'accuracy/test_acc': test_acc,
        }
        for tag, value in scalars.items():
            writer.add_scalar(tag, value, step)

        # append logger file
        logger.append([
            train_loss, train_loss_x, train_loss_u, train_loss_nll_x,
            train_loss_nll_u, train_loss_mixed_x, val_loss, val_acc,
            test_loss, test_acc
        ])

        # save model
        is_best = val_acc > best_acc
        best_acc = max(val_acc, best_acc)
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'cnf_state_dict': cnf.state_dict(),
                'means': means,
                'ema_state_dict': ema_model.state_dict(),
                'acc': val_acc,
                'best_acc': best_acc,
                'optimizer': optimizer.state_dict(),
                'cnf_optimizer': cnf_optimizer.state_dict(),
            }, is_best)
        test_accs.append(test_acc)
    logger.close()
    writer.close()

    print('Best acc:')
    print(best_acc)

    print('Mean acc:')
    print(np.mean(test_accs[-20:]))
def main(args):
    """Build a tabular CNF from ``args`` and either train it or restore it.

    Steps:
        1. Set up logging under ``args.save`` and load ``args.dataset``
           (stored on ``args.data``).
        2. Derive the integration times ``args.int_tps`` from the data's time
           points.
        3. Build the model, and optionally load a growth model.
        4. With ``args.test``, load ``<args.save>/checkpt.pth`` into the model;
           otherwise train it via ``train``.
        5. For 2-D data, finish with ``plot_output``.
    """
    # logger
    print(args.no_display_loss)
    utils.makedirs(args.save)
    logger = utils.get_logger(
        logpath=os.path.join(args.save, "logs"),
        filepath=os.path.abspath(__file__),
        # `not` rather than `~`: ~True == -2 and ~False == -1 are both truthy,
        # so the flag previously had no effect.
        displaying=not args.no_display_loss,
    )

    if args.layer_type == "blend":
        logger.info("!! Setting time_scale from None to 1.0 for Blend layers.")
        args.time_scale = 1.0

    logger.info(args)

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu"
    )
    if args.use_cpu:
        device = torch.device("cpu")

    args.data = dataset.SCData.factory(args.dataset, args.max_dim)

    args.timepoints = args.data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, regularization_coeffs = create_regularization_fns(args)
    model = build_model_tabular(args, args.data.get_shape()[0], regularization_fns).to(
        device
    )

    # Without this default, `train(...)` below raised NameError whenever
    # args.use_growth was false.
    growth_model = None
    if args.use_growth:
        growth_model_path = None
        if args.leaveout_timepoint == -1:
            growth_model_path = "../data/externel/growth_model_v2.ckpt"
        elif args.leaveout_timepoint in [1, 2, 3]:
            assert args.max_dim == 5
            growth_model_path = "../data/growth/model_%d" % args.leaveout_timepoint
        else:
            # Continue without a growth model instead of crashing on an
            # unbound path.
            print("WARNING: Cannot use growth with this timepoint")
        if growth_model_path is not None:
            growth_model = torch.load(growth_model_path, map_location=device)

    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    if args.test:
        # Test mode only restores the weights; evaluation hooks are disabled.
        # TODO can we load the arguments from the save?
        state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
        model.load_state_dict(state_dict["state_dict"])
    else:
        logger.info(model)
        n_param = count_parameters(model)
        logger.info("Number of trainable parameters: {}".format(n_param))

        train(
            device,
            args,
            model,
            growth_model,
            regularization_coeffs,
            regularization_fns,
            logger,
        )

    if args.data.data.shape[1] == 2:
        plot_output(device, args, model)
def main(args):
    """Restore a trained model and integrate the last time point backwards.

    Loads ``<args.save>/checkpt.pth`` into a freshly built model. It then calls
    ``integrate_backwards`` on the samples at the latest time point and
    terminates the process with ``sys.exit()``.

    The Kantorovich/EMD evaluation that used to follow the exit was
    unreachable. It now lives in ``_evaluate_kantorovich_adjusted``.
    """
    import sys

    device = torch.device(
        "cuda:" + str(args.gpu) if torch.cuda.is_available() else "cpu")
    if args.use_cpu:
        device = torch.device("cpu")

    data = dataset.SCData.factory(args.dataset, args)

    args.timepoints = data.get_unique_times()
    # Use maximum timepoint to establish integration_times
    # as some timepoints may be left out for validation etc.
    args.int_tps = (np.arange(max(args.timepoints) + 1) + 1.0) * args.time_scale

    regularization_fns, _ = create_regularization_fns(args)
    model = build_model_tabular(args, data.get_shape()[0], regularization_fns).to(device)
    if args.use_growth:
        growth_model_path = data.get_growth_net_path()
        # NOTE(review): loaded but not used by this function.
        growth_model = torch.load(growth_model_path, map_location=device)
    if args.spectral_norm:
        add_spectral_norm(model)
    set_cnf_options(args, model)

    state_dict = torch.load(args.save + "/checkpt.pth", map_location=device)
    model.load_state_dict(state_dict["state_dict"])

    # args.timepoints / args.int_tps were already derived from `data` above.
    args.data = data

    print('integrating backwards')
    times = data.get_times()
    end_time_data = data.get_data()[times == np.max(times)]
    integrate_backwards(end_time_data, model, args.save, ntimes=100, device=device)
    # sys.exit rather than the site-module `exit()`, which is meant for the
    # interactive shell only.
    sys.exit()


def _evaluate_kantorovich_adjusted(device, args, model):
    """Compute Kantorovich (EMD) losses, optionally shifting the left-out time.

    This was formerly the unreachable tail of ``main`` (it followed
    ``exit()``), and nothing in this block calls it. It expects ``args`` as
    prepared by ``main``, i.e. with ``args.data`` and ``args.int_tps`` set.

    For ``args.dataset == 'CHAFFER'``, the integration time of
    ``args.leaveout_timepoint`` is moved to ``(1 - factor) * prev +
    factor * next``. Any left-out time point other than 1 or 2 raises
    RuntimeError. The losses are printed and saved with ``np.save`` to
    ``<args.save>/emd_list.npy``.
    """
    losses_list = []
    # Earlier sweeps over factor: np.linspace(0.05, 0.95, 19) and
    # np.linspace(0.91, 0.99, 9).
    if args.dataset == 'CHAFFER':  # Do timepoint adjustment
        print('adjusting_timepoints')
        lt = args.leaveout_timepoint
        if lt == 1:
            factor = 0.95  # alternative value in the original: 0.6799872494335812
        elif lt == 2:
            factor = 0.01  # alternative value in the original: 0.2905983814032348
        else:
            raise RuntimeError('Unknown timepoint %d' % args.leaveout_timepoint)
        args.int_tps[lt] = (1 - factor) * args.int_tps[lt - 1] + factor * args.int_tps[lt + 1]
    losses = eval_utils.evaluate_kantorovich_v2(device, args, model)
    losses_list.append(losses)
    print(np.array(losses_list))
    np.save(os.path.join(args.save, 'emd_list'), np.array(losses_list))