    )
    test_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(args.dataroot, train=False, transform=transform_test),
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.nworkers,
    )
elif args.data == 'mnist':
    im_dim = 1
    init_layer = layers.LogitTransform(1e-6)
    n_classes = 10
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            args.dataroot,
            train=True,
            transform=transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise,
            ])
        ),
        batch_size=args.batchsize,
        shuffle=True,
        num_workers=args.nworkers,
    )
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(
            args.dataroot,
            train=False,
            transform=transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise,
            ])
        ),
        batch_size=args.val_batchsize,
        shuffle=False,
        num_workers=args.nworkers,
    )
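# Both data pipelines above apply `add_noise` after ToTensor(). Its definition is not part of
# this excerpt; the commented sketch below shows the uniform-dequantization transform it is
# assumed to implement (the `nvals` parameter and exact scaling are assumptions for
# illustration, not necessarily this repo's definition):
#
#   def add_noise(x, nvals=256):
#       """Dequantize [0, 1] pixels: x -> (255 * x + u) / 256 with u ~ Uniform[0, 1)."""
#       noise = torch.rand_like(x)
#       return (x * (nvals - 1) + noise) / nvals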
def main(rank, world_size, args):
    setup(rank, world_size, args.port)

    # setup logger
    if rank == 0:
        utils.makedirs(args.save)
        logger = utils.get_logger(os.path.join(args.save, "logs"))

    def mprint(msg):
        # only rank 0 writes to the log
        if rank == 0:
            logger.info(msg)

    mprint(args)

    device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        mprint('Found {} CUDA devices.'.format(torch.cuda.device_count()))
        for i in range(torch.cuda.device_count()):
            props = torch.cuda.get_device_properties(i)
            mprint('{} \t Memory: {:.2f}GB'.format(props.name, props.total_memory / (1024**3)))
    else:
        mprint('WARNING: Using device {}'.format(device))

    # seed each rank differently so data order and dequantization noise differ across processes
    np.random.seed(args.seed + rank)
    torch.manual_seed(args.seed + rank)
    if device.type == 'cuda':
        torch.cuda.manual_seed(args.seed + rank)

    mprint('Loading dataset {}'.format(args.data))
    # Dataset and hyperparameters
    if args.data == 'cifar10':
        im_dim = 3
        transform_train = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            add_noise if args.add_noise else identity,
        ])
        transform_test = transforms.Compose([
            transforms.Resize(args.imagesize),
            transforms.ToTensor(),
            add_noise if args.add_noise else identity,
        ])
        init_layer = flows.LogitTransform(0.05)
        train_set = vdsets.CIFAR10(args.dataroot, download=True, train=True, transform=transform_train)
        # DistributedSampler gives each rank a disjoint shard of the training set
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batchsize,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            vdsets.CIFAR10(args.dataroot, download=True, train=False, transform=transform_test),
            batch_size=args.val_batchsize,
            shuffle=False,
        )
    elif args.data == 'mnist':
        im_dim = 1
        init_layer = flows.LogitTransform(1e-6)
        train_set = datasets.MNIST(
            args.dataroot,
            train=True,
            transform=transforms.Compose([
                transforms.Resize(args.imagesize),
                transforms.ToTensor(),
                add_noise if args.add_noise else identity,
            ]))
        sampler = torch.utils.data.distributed.DistributedSampler(train_set)
        train_loader = torch.utils.data.DataLoader(
            train_set,
            batch_size=args.batchsize,
            sampler=sampler,
        )
        test_loader = torch.utils.data.DataLoader(
            datasets.MNIST(
                args.dataroot,
                train=False,
                transform=transforms.Compose([
                    transforms.Resize(args.imagesize),
                    transforms.ToTensor(),
                    add_noise if args.add_noise else identity,
                ])),
            batch_size=args.val_batchsize,
            shuffle=False,
        )
    else:
        raise Exception(f'dataset not one of mnist / cifar10, got {args.data}')

    mprint('Dataset loaded.')
    mprint('Creating model.')

    input_size = (args.batchsize, im_dim, args.imagesize, args.imagesize)

    model = MultiscaleFlow(
        input_size,
        block_fn=partial(
            cpflow_block_fn,
            block_type=args.block_type,
            dimh=args.dimh,
            num_hidden_layers=args.num_hidden_layers,
            icnn_version=args.icnn,
            num_pooling=args.num_pooling,
        ),
        n_blocks=list(map(int, args.nblocks.split('-'))),
        factor_out=args.factor_out,
        init_layer=init_layer,
        actnorm=args.actnorm,
        fc_end=args.fc_end,
        glow=args.glow,
    )
    model.to(device)
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    ema = utils.ExponentialMovingAverage(model)
    mprint(model)
    mprint('EMA: {}'.format(ema))

    optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.99), weight_decay=args.wd)

    # Saving and resuming
    best_test_bpd = math.inf
    begin_epoch = 0
    most_recent_path = os.path.join(args.save, 'models', 'most_recent.pth')
    checkpt_exists = os.path.exists(most_recent_path)
    if checkpt_exists:
        mprint(f"Resuming from {most_recent_path}")

        # deal with data-dependent initialization like actnorm.
        with torch.no_grad():
            x = torch.rand(8, *input_size[1:]).to(device)
            model(x)

        checkpt = torch.load(most_recent_path)
        begin_epoch = checkpt["epoch"] + 1
        model.module.load_state_dict(checkpt["state_dict"])
        ema.set(checkpt['ema'])
        optimizer.load_state_dict(checkpt["opt_state_dict"])
    elif args.resume:
        mprint(f"Resuming from {args.resume}")

        # deal with data-dependent initialization like actnorm.
        with torch.no_grad():
            x = torch.rand(8, *input_size[1:]).to(device)
            model(x)

        checkpt = torch.load(args.resume)
        begin_epoch = checkpt["epoch"] + 1
        model.module.load_state_dict(checkpt["state_dict"])
        ema.set(checkpt['ema'])
        optimizer.load_state_dict(checkpt["opt_state_dict"])

    mprint(optimizer)

    batch_time = utils.RunningAverageMeter(0.97)
    bpd_meter = utils.RunningAverageMeter(0.97)
    gnorm_meter = utils.RunningAverageMeter(0.97)
    cg_meter = utils.RunningAverageMeter(0.97)
    hnorm_meter = utils.RunningAverageMeter(0.97)

    update_lr(optimizer, 0, args)

    # for visualization
    fixed_x = next(iter(train_loader))[0][:8].to(device)
    fixed_z = torch.randn(8, im_dim * args.imagesize * args.imagesize).to(fixed_x)

    if rank == 0:
        utils.makedirs(os.path.join(args.save, 'figs'))
        # visualize(model, fixed_x, fixed_z, os.path.join(args.save, 'figs', 'init.png'))

    for epoch in range(begin_epoch, args.nepochs):
        sampler.set_epoch(epoch)
        flows.CG_ITERS_TRACER.clear()
        flows.HESS_NORM_TRACER.clear()
        mprint('Current LR {}'.format(optimizer.param_groups[0]['lr']))

        train(epoch, train_loader, model, optimizer, bpd_meter, gnorm_meter, cg_meter,
              hnorm_meter, batch_time, ema, device, mprint, world_size, args)
        val_time, test_bpd = validate(epoch, model, test_loader, ema, device)
        mprint('Epoch: [{0}]\tTime {1:.2f} | Test bits/dim {test_bpd:.4f}'.format(
            epoch, val_time, test_bpd=test_bpd))

        if rank == 0:
            utils.makedirs(os.path.join(args.save, 'figs'))
            visualize(model, fixed_x, fixed_z, os.path.join(args.save, 'figs', f'{epoch}.png'))

            utils.makedirs(os.path.join(args.save, "models"))
            if test_bpd < best_test_bpd:
                best_test_bpd = test_bpd
                torch.save(
                    {
                        'epoch': epoch,
                        'state_dict': model.module.state_dict(),
                        'opt_state_dict': optimizer.state_dict(),
                        'args': args,
                        'ema': ema,
                        'test_bpd': test_bpd,
                    }, os.path.join(args.save, 'models', 'best_model.pth'))

        if rank == 0:
            torch.save(
                {
                    'epoch': epoch,
                    'state_dict': model.module.state_dict(),
                    'opt_state_dict': optimizer.state_dict(),
                    'args': args,
                    'ema': ema,
                    'test_bpd': test_bpd,
                }, os.path.join(args.save, 'models', 'most_recent.pth'))

    cleanup()
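# `main` above is written for torch.distributed / DistributedDataParallel, so it must be run
# once per GPU with its rank as the first argument. The commented sketch below shows a minimal
# per-process launcher for reference; it assumes `setup`/`cleanup` wrap init_process_group /
# destroy_process_group and that `parser` is the ArgumentParser defined elsewhere in this
# script, so these names are assumptions rather than this file's actual entry point.
#
#   if __name__ == '__main__':
#       args = parser.parse_args()
#       world_size = torch.cuda.device_count()
#       # spawn() passes the process index (rank) as the first argument to main
#       torch.multiprocessing.spawn(main, args=(world_size, args), nprocs=world_size, join=True)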