def main():
    """Train or evaluate an invertible ResNet (i-ResNet).

    Options are read from the module-level ``parser``. Supported datasets are
    'mnist' (density estimation only), 'cifar10', 'cifar100' and 'svhn'.

    After building the data pipeline and model (and optionally resuming from
    ``args.resume``), exactly one of the following happens:

    * ``args.analysisTraceEst`` / ``args.norm`` / ``args.interpolate`` /
      ``args.evaluate``: run that analysis or evaluation and return;
    * otherwise: train for ``args.epochs`` epochs, saving a checkpoint to
      ``args.save_dir`` every second epoch, then evaluate once on the test
      set and write the final objective to ``final.txt``.

    Raises:
        ValueError: if MNIST is requested without density estimation, or the
            dataset name is not recognised.
        RuntimeError: if the visdom server cannot be reached.
    """
    args = parser.parse_args()

    if args.deterministic:
        print("MODEL NOT FULLY DETERMINISTIC")
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        np.random.seed(1234)
        random.seed(1234)
        torch.backends.cudnn.deterministic = True

    # Density-estimation preprocessing: map ToTensor's [0, 1] output back to
    # [0, 255], add uniform noise in [0, 1), rescale by 1/256, shift by -0.5.
    dens_est_chain = [
        lambda x: (255. * x) + torch.zeros_like(x).uniform_(0., 1.),
        lambda x: x / 256.,
        lambda x: x - 0.5,
    ]

    if args.dataset == 'mnist':
        if not args.densityEstimation:
            raise ValueError("Currently mnist is only supported for density estimation")
        # Pad by 2 pixels per side and repeat the single channel 3 times so
        # samples match in_shape = (3, 32, 32) below.
        mnist_transforms = [transforms.Pad(2, 0), transforms.ToTensor(),
                            lambda x: x.repeat((3, 1, 1))]
        transform_train_mnist = transforms.Compose(mnist_transforms + dens_est_chain)
        transform_test_mnist = transforms.Compose(mnist_transforms + dens_est_chain)
        trainset = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=transform_train_mnist)
        testset = torchvision.datasets.MNIST(
            root='./data', train=False, download=False, transform=transform_test_mnist)
        args.nClasses = 10
        in_shape = (3, 32, 32)
    else:
        if args.dataset == 'svhn':
            # SVHN uses no horizontal flip.
            train_chain = [transforms.Pad(4, padding_mode="symmetric"),
                           transforms.RandomCrop(32),
                           transforms.ToTensor()]
        else:
            train_chain = [transforms.Pad(4, padding_mode="symmetric"),
                           transforms.RandomCrop(32),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor()]
        test_chain = [transforms.ToTensor()]
        if args.densityEstimation:
            transform_train = transforms.Compose(train_chain + dens_est_chain)
            transform_test = transforms.Compose(test_chain + dens_est_chain)
        else:
            clf_chain = [transforms.Normalize(mean[args.dataset], std[args.dataset])]
            transform_train = transforms.Compose(train_chain + clf_chain)
            transform_test = transforms.Compose(test_chain + clf_chain)

        if args.dataset == 'cifar10':
            trainset = torchvision.datasets.CIFAR10(
                root='./data', train=True, download=True, transform=transform_train)
            testset = torchvision.datasets.CIFAR10(
                root='./data', train=False, download=True, transform=transform_test)
            args.nClasses = 10
        elif args.dataset == 'cifar100':
            trainset = torchvision.datasets.CIFAR100(
                root='./data', train=True, download=True, transform=transform_train)
            testset = torchvision.datasets.CIFAR100(
                root='./data', train=False, download=True, transform=transform_test)
            args.nClasses = 100
        elif args.dataset == 'svhn':
            trainset = torchvision.datasets.SVHN(
                root='./data', split='train', download=True, transform=transform_train)
            testset = torchvision.datasets.SVHN(
                root='./data', split='test', download=True, transform=transform_test)
            args.nClasses = 10
        else:
            raise ValueError("Unsupported dataset '{}'".format(args.dataset))
        in_shape = (3, 32, 32)

    # setup logging with visdom
    viz = visdom.Visdom(port=args.vis_port, server="http://" + args.vis_server)
    if not viz.check_connection():
        raise RuntimeError("Could not make visdom")

    loader_kwargs = {'batch_size': args.batch, 'num_workers': 2}
    if args.deterministic:
        def seed_worker(worker_id):
            # DataLoader calls this inside each worker process. Each worker's
            # torch seed is derived from the (seeded) main-process generator,
            # so reuse it to seed NumPy and `random` reproducibly per worker.
            worker_seed = torch.initial_seed() % 2 ** 32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        loader_kwargs['worker_init_fn'] = seed_worker
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
    testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

    def get_model(args):
        """Build the (multiscale or single-scale) i-ResNet described by args."""
        if args.multiScale:
            model = multiscale_iResNet(in_shape,
                                       args.nBlocks, args.nStrides, args.nChannels,
                                       args.init_ds == 2,
                                       args.inj_pad,
                                       args.coeff,
                                       args.densityEstimation,
                                       args.nClasses,
                                       args.numTraceSamples, args.numSeriesTerms,
                                       args.powerIterSpectralNorm,
                                       actnorm=(not args.noActnorm),
                                       learn_prior=(not args.fixedPrior),
                                       nonlin=args.nonlin)
        else:
            model = iResNet(nBlocks=args.nBlocks, nStrides=args.nStrides,
                            nChannels=args.nChannels, nClasses=args.nClasses,
                            init_ds=args.init_ds,
                            inj_pad=args.inj_pad,
                            in_shape=in_shape,
                            coeff=args.coeff,
                            numTraceSamples=args.numTraceSamples,
                            numSeriesTerms=args.numSeriesTerms,
                            n_power_iter=args.powerIterSpectralNorm,
                            density_estimation=args.densityEstimation,
                            actnorm=(not args.noActnorm),
                            learn_prior=(not args.fixedPrior),
                            nonlin=args.nonlin)
        return model

    model = get_model(args)
    # Initialise actnorm parameters with one forward pass on a training batch.
    init_batch = get_init_batch(trainloader, args.init_batch)
    print("initializing actnorm parameters...")
    with torch.no_grad():
        model(init_batch, ignore_logdet=True)
    print("initialized")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
        model = torch.nn.DataParallel(model, range(torch.cuda.device_count()))
        cudnn.benchmark = True
        in_shapes = model.module.get_in_shapes()
    else:
        in_shapes = model.get_in_shapes()

    # Defaults for the analysis/evaluation modes when no checkpoint supplies
    # them; previously these names were undefined without a resume.
    start_epoch = 0
    best_objective = -np.inf
    optimizer = None

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            # Checkpoints saved below may not contain 'epoch'/'objective'.
            start_epoch = checkpoint.get('epoch', start_epoch)
            best_objective = checkpoint.get('objective', best_objective)
            model = checkpoint['model']
            # Keep the restored optimizer (and its state) for training.
            optimizer = checkpoint.get('optimizer')
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    try_make_dir(args.save_dir)
    if args.analysisTraceEst:
        anaylse_trace_estimation(model, testset, use_cuda, args.extension)
        return

    if args.norm:
        test_spec_norm(model, in_shapes, args.extension)
        return

    if args.interpolate:
        interpolate(model, testloader, testset, start_epoch, use_cuda,
                    best_objective, args.dataset)
        return

    if args.evaluate:
        if use_cuda:
            model.module.set_num_terms(args.numSeriesTerms)
            # Re-wrap the inner module; DataParallel defaults to all visible
            # devices. Only valid when the model is already DataParallel.
            model = torch.nn.DataParallel(model.module)
        else:
            model.set_num_terms(args.numSeriesTerms)
        with open(os.path.join(args.save_dir, "test_log.txt"), 'w') as test_log:
            test(best_objective, args, model, start_epoch, testloader, viz,
                 use_cuda, test_log)
        return

    print('| Train Epochs: ' + str(args.epochs))
    print('| Initial Learning Rate: ' + str(args.lr))

    elapsed_time = 0
    test_objective = -np.inf

    if optimizer is None:
        if args.optimizer == "adam":
            optimizer = optim.Adam(model.parameters(), lr=args.lr,
                                   weight_decay=args.weight_decay)
        elif args.optimizer == "adamax":
            optimizer = optim.Adamax(model.parameters(), lr=args.lr,
                                     weight_decay=args.weight_decay)
        else:
            optimizer = optim.SGD(model.parameters(), lr=args.lr,
                                  momentum=0.9, weight_decay=args.weight_decay,
                                  nesterov=args.nesterov)

    with open(os.path.join(args.save_dir, 'params.txt'), 'w') as f:
        f.write(json.dumps(args.__dict__))

    epoch = 0  # stays 0 when args.epochs == 0 so the final test still runs
    with open(os.path.join(args.save_dir, "train_log.txt"), 'w') as train_log:
        for epoch in range(1, 1 + args.epochs):
            start_time = time.time()
            train(args, model, optimizer, epoch, trainloader, trainset, viz,
                  use_cuda, train_log)
            epoch_time = time.time() - start_time
            elapsed_time += epoch_time
            print('| Elapsed time : %d:%02d:%02d' % (get_hms(elapsed_time)))
            if epoch % 2 == 0:
                state = {'model': model, 'optimizer': optimizer, 'epoch': epoch}
                torch.save(state, os.path.join(args.save_dir,
                                               'IResModelGM_%d.t7' % epoch))

    print('Testing model')
    with open(os.path.join(args.save_dir, "test_log.txt"), 'w') as test_log:
        test_objective = test(test_objective, args, model, epoch, testloader,
                              viz, use_cuda, test_log)
    print('* Test results : objective = %.2f%%' % (test_objective))
    with open(os.path.join(args.save_dir, 'final.txt'), 'w') as f:
        f.write(str(test_objective))
def main():
    """Train or evaluate an i-RevNet classifier on CIFAR-10 or CIFAR-100.

    Options are read from the module-level ``parser``. If ``args.resume``
    names an existing checkpoint, the model, epoch and best accuracy are
    restored from it. With ``args.evaluate`` the model is tested once and the
    function returns. Otherwise it trains for ``args.epochs`` epochs, calling
    ``test`` after every epoch and keeping the best accuracy it returns.

    Raises:
        ValueError: if ``args.dataset`` is neither 'cifar10' nor 'cifar100'.
    """
    args = parser.parse_args()

    # dataset name -> (torchvision dataset class, number of classes)
    datasets = {
        'cifar10': (torchvision.datasets.CIFAR10, 10),
        'cifar100': (torchvision.datasets.CIFAR100, 100),
    }
    if args.dataset not in datasets:
        raise ValueError("Unsupported dataset '{}'; expected 'cifar10' or "
                         "'cifar100'".format(args.dataset))
    dataset_cls, nClasses = datasets[args.dataset]
    in_shape = [3, 32, 32]

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean[args.dataset], std[args.dataset]),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean[args.dataset], std[args.dataset]),
    ])

    trainset = dataset_cls(root='./data', train=True, download=True,
                           transform=transform_train)
    testset = dataset_cls(root='./data', train=False, download=False,
                          transform=transform_test)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.batch,
                                              shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(testset, batch_size=100,
                                             shuffle=False, num_workers=2)

    def get_model(args):
        """Return ``(model, fname)`` for the architecture named by args.model."""
        if args.model == 'i-revnet':
            model = iRevNet(nBlocks=args.nBlocks, nStrides=args.nStrides,
                            nChannels=args.nChannels, nClasses=nClasses,
                            init_ds=args.init_ds, dropout_rate=0.1,
                            affineBN=True, in_shape=in_shape,
                            mult=args.bottleneck_mult)
            fname = 'i-revnet-' + str(sum(args.nBlocks) + 1)
        elif args.model == 'revnet':
            raise NotImplementedError
        else:
            print('Choose i-revnet or revnet')
            # Non-zero status: an invalid choice is a failure, not success.
            sys.exit(1)
        return model, fname

    model, fname = get_model(args)

    # Previously hard-coded to True, which crashed on CPU-only machines.
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
        model = torch.nn.DataParallel(model, device_ids=(0,))
        cudnn.benchmark = True

    # Defaults used when no checkpoint is loaded (needed by args.evaluate).
    start_epoch = 0
    best_acc = 0.

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_acc = checkpoint['acc']
            model = checkpoint['model']
            print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    if args.evaluate:
        test(model, testloader, testset, start_epoch, use_cuda, best_acc,
             args.dataset, fname)
        return

    print('| Train Epochs: ' + str(args.epochs))
    print('| Initial Learning Rate: ' + str(args.lr))

    elapsed_time = 0
    # best_acc is not reset here, so a resumed run keeps the restored best.
    for epoch in range(1, 1 + args.epochs):
        start_time = time.time()
        train(model, trainloader, trainset, epoch, args.epochs, args.batch,
              args.lr, use_cuda, in_shape)
        best_acc = test(model, testloader, testset, epoch, use_cuda, best_acc,
                        args.dataset, fname)
        epoch_time = time.time() - start_time
        elapsed_time += epoch_time
        print('| Elapsed time : %d:%02d:%02d' % (get_hms(elapsed_time)))

    print('Testing model')
    print('* Test results : Acc@1 = %.2f%%' % (best_acc))
def main():
    """Train or evaluate a multiscale invertible ResNet density model.

    Options are read from the module-level ``parser``. Supported datasets are
    'cifar10' (optionally a single class via ``args.single_label``),
    'cifar100' and 'svhn'; any other name falls through to MNIST. Only the
    multiscale model is supported.

    After building the data pipeline and model (and optionally resuming from
    ``args.resume``), exactly one of the following happens:

    * ``args.analysisTraceEst`` / ``args.norm`` / ``args.interpolate`` /
      ``args.evaluate``: run that analysis or evaluation and return;
    * otherwise: snapshot the project's ``*.py`` files into
      ``<save_dir>/code``, train for ``args.epochs`` epochs, evaluate once on
      the test set and write the final objective to ``final.txt``.
    """
    args = parser.parse_args()

    if args.deterministic:
        print("MODEL NOT FULLY DETERMINISTIC")
        torch.manual_seed(1234)
        torch.cuda.manual_seed(1234)
        np.random.seed(1234)
        random.seed(1234)
        torch.backends.cudnn.deterministic = True

    # Density-estimation preprocessing: map ToTensor's [0, 1] output back to
    # [0, 255], add uniform noise in [0, 1), rescale by 1/256, shift by -0.5.
    dens_est_chain = [
        lambda x: (255. * x) + torch.zeros_like(x).uniform_(0., 1.),
        lambda x: x / 256.,
        lambda x: x - 0.5,
    ]
    # Undoes only the final "- 0.5" shift of dens_est_chain; passed to test().
    inverse_den_est_chain = [lambda x: x + 0.5]
    inverse_den_est = transforms.Compose(inverse_den_est_chain)

    test_chain = [transforms.ToTensor()]
    if args.dataset in ('cifar10', 'cifar100', 'svhn'):
        train_chain = [transforms.Pad(4, padding_mode="symmetric"),
                       transforms.RandomCrop(32)]
        if args.dataset != 'svhn':  # SVHN uses no horizontal flip
            train_chain.append(transforms.RandomHorizontalFlip())
        train_chain.append(transforms.ToTensor())
        transform_train = transforms.Compose(train_chain + dens_est_chain)
        transform_test = transforms.Compose(test_chain + dens_est_chain)

        if args.dataset == 'cifar10':
            trainset = torchvision.datasets.CIFAR10(
                root='./data', train=True, download=True, transform=transform_train)
            testset = torchvision.datasets.CIFAR10(
                root='./data', train=False, download=True, transform=transform_test)
            args.nClasses = 10
            if args.single_label:
                # NOTE(review): machine-specific absolute paths; consider
                # making these configurable.
                trainset = CifarSingleDataset(
                    '/home/billy/Downloads/CIFAR-10-images/train/airplane',
                    transform=transform_train)
                testset = CifarSingleDataset(
                    '/home/billy/Downloads/CIFAR-10-images/test/airplane',
                    transform=transform_test)
        elif args.dataset == 'cifar100':
            trainset = torchvision.datasets.CIFAR100(
                root='./data', train=True, download=True, transform=transform_train)
            testset = torchvision.datasets.CIFAR100(
                root='./data', train=False, download=True, transform=transform_test)
            args.nClasses = 100
        else:  # svhn
            trainset = torchvision.datasets.SVHN(
                root='./data', split='train', download=True, transform=transform_train)
            testset = torchvision.datasets.SVHN(
                root='./data', split='test', download=True, transform=transform_test)
            args.nClasses = 10
    else:  # mnist (any other dataset name ends up here)
        # Pad by 2 pixels per side and repeat the single channel 3 times so
        # samples match in_shape = (3, 32, 32) below.
        mnist_transforms = [transforms.Pad(2, 0), transforms.ToTensor(),
                            lambda x: x.repeat((3, 1, 1))]
        transform_train_mnist = transforms.Compose(mnist_transforms + dens_est_chain)
        transform_test_mnist = transforms.Compose(mnist_transforms + dens_est_chain)
        trainset = torchvision.datasets.MNIST(
            root='./data', train=True, download=True, transform=transform_train_mnist)
        testset = torchvision.datasets.MNIST(
            root='./data', train=False, download=False, transform=transform_test_mnist)
        args.nClasses = 10
    in_shape = (3, 32, 32)

    viz = None  # visdom logging is disabled in this entry point

    loader_kwargs = {'batch_size': args.batch, 'num_workers': 2}
    if args.deterministic:
        def seed_worker(worker_id):
            # DataLoader calls this inside each worker process. Each worker's
            # torch seed is derived from the (seeded) main-process generator,
            # so reuse it to seed NumPy and `random` reproducibly per worker.
            worker_seed = torch.initial_seed() % 2 ** 32
            np.random.seed(worker_seed)
            random.seed(worker_seed)

        loader_kwargs['worker_init_fn'] = seed_worker
    trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **loader_kwargs)
    testloader = torch.utils.data.DataLoader(testset, shuffle=False, **loader_kwargs)

    def get_model(args):
        """Build the multiscale i-ResNet; exit with an error otherwise."""
        if not args.multiScale:
            print("Only multiscale model supported.")
            raise SystemExit(1)  # non-zero status: unsupported configuration
        return multiscale_iResNet(in_shape,
                                  args.nBlocks, args.nStrides, args.nChannels,
                                  args.doAttention,
                                  args.init_ds == 2,
                                  args.coeff,
                                  args.nClasses,
                                  args.numTraceSamples, args.numSeriesTerms,
                                  args.powerIterSpectralNorm,
                                  actnorm=(not args.noActnorm),
                                  nonlin=args.nonlin,
                                  use_label=args.use_label)

    model = get_model(args)
    # Initialise actnorm parameters with one forward pass on a training batch.
    init_batch, init_target = get_init_batch(trainloader, args.init_batch)
    print("initializing actnorm parameters...")
    with torch.no_grad():
        model(init_batch, init_target, ignore_logdet=True)
    print("initialized")

    use_cuda = torch.cuda.is_available()
    if use_cuda:
        model.cuda()
        model = torch.nn.DataParallel(model, range(torch.cuda.device_count()))
        cudnn.benchmark = True
        in_shapes = model.module.get_in_shapes()
    else:
        in_shapes = model.get_in_shapes()

    # Defaults for the analysis/evaluation modes when no checkpoint is loaded;
    # previously these names were undefined without a resume.
    start_epoch = 0
    best_objective = -np.inf

    # optionally resume from a checkpoint
    if args.resume:
        if os.path.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            best_objective = checkpoint['objective']
            print('objective: ' + str(best_objective))
            model = checkpoint['model']
            if use_cuda:
                model.module.set_num_terms(args.numSeriesTerms)
            else:
                model.set_num_terms(args.numSeriesTerms)
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    try_make_dir(args.save_dir)
    if args.analysisTraceEst:
        anaylse_trace_estimation(model, testset, use_cuda, args.extension)
        return

    if args.norm:
        test_spec_norm(model, in_shapes, args.extension)
        return

    if args.interpolate:
        interpolate(model, testloader, testset, start_epoch, use_cuda,
                    best_objective, args.dataset)
        return

    if args.evaluate:
        if use_cuda:
            model.module.set_num_terms(args.numSeriesTerms)
            # Re-wrap the inner module; DataParallel defaults to all visible
            # devices. Only valid when the model is already DataParallel.
            model = torch.nn.DataParallel(model.module)
        else:
            model.set_num_terms(args.numSeriesTerms)
        with open(os.path.join(args.save_dir, "test_log.txt"), 'w') as test_log:
            test(best_objective, args, model, start_epoch, testloader, viz,
                 use_cuda, test_log, inverse_den_est)
        return

    print('| Train Epochs: ' + str(args.epochs))
    print('| Initial Learning Rate: ' + str(args.lr))

    elapsed_time = 0
    test_objective = -np.inf

    if args.optimizer == "adam":
        optimizer = optim.Adam(model.parameters(), lr=args.lr,
                               weight_decay=args.weight_decay)
    elif args.optimizer == "adamax":
        optimizer = optim.Adamax(model.parameters(), lr=args.lr,
                                 weight_decay=args.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=args.lr,
                              momentum=0.9, weight_decay=args.weight_decay,
                              nesterov=args.nesterov)

    with open(os.path.join(args.save_dir, 'params.txt'), 'w') as f:
        f.write(json.dumps(args.__dict__))

    def copy_python_files(src_dir, dst_dir):
        """Copy every regular ``*.py`` file directly inside src_dir into dst_dir."""
        try_make_dir(dst_dir)
        for name in listdir(src_dir):
            src = join(src_dir, name)
            if isfile(src) and name.endswith('.py'):
                copyfile(src, os.path.join(dst_dir, name))

    # Snapshot the project code alongside the run for reproducibility.
    code_dir = os.path.join(args.save_dir, 'code')
    copy_python_files('./', code_dir)
    copy_python_files('models/', os.path.join(code_dir, 'models'))

    epoch = 0  # stays 0 when args.epochs == 0 so the final test still runs
    with open(os.path.join(args.save_dir, "train_log.txt"), 'w') as train_log:
        for epoch in range(1, 1 + args.epochs):
            start_time = time.time()
            train(args, model, optimizer, epoch, trainloader, trainset, viz,
                  use_cuda, train_log)
            epoch_time = time.time() - start_time
            elapsed_time += epoch_time
            print('| Elapsed time : %d:%02d:%02d' % (get_hms(elapsed_time)))

    print('Testing model')
    with open(os.path.join(args.save_dir, "test_log.txt"), 'w') as test_log:
        test_objective = test(test_objective, args, model, epoch, testloader,
                              viz, use_cuda, test_log, inverse_den_est, args.gen)
    print('* Test results : objective = %.2f%%' % (test_objective))
    with open(os.path.join(args.save_dir, 'final.txt'), 'w') as f:
        f.write(str(test_objective))