def main():
    """Train a DCGAN on (a subset of the classes of) CIFAR10 and save the generator.

    Command-line options are parsed with argparse and forwarded as keyword
    arguments to ``DCGAN``. After training, the generator is saved to
    ``<prob_model_dir>/cifar10_c<classes>-dcgan-ep<n_epochs>_bs<batch_size>.pt``,
    where ``<classes>`` is the concatenation of the sorted class indices.
    """
    parser = argparse.ArgumentParser(description="Train a DCGAN on CIFAR10")
    parser.add_argument("--n_epochs", type=int, default=25, help="number of epochs of training")
    parser.add_argument("--batch_size", type=int, default=2 ** 5, help="size of the batches")
    parser.add_argument("--lr", type=float, default=0.0002, help="adam: learning rate")
    parser.add_argument("--b1", type=float, default=0.5, help="adam: decay of first order momentum of gradient")
    parser.add_argument("--b2", type=float, default=0.999, help="adam: decay of second order momentum of gradient")
    parser.add_argument("--latent_dim", type=int, default=100, help="dimensionality of the latent space")
    parser.add_argument(
        "--sample_interval",
        type=int,
        default=100,
        help="interval between image sampling. The number refers to the number of minibatch updates.",
    )
    parser.add_argument(
        "--save_model_interval", type=int, default=10, help="Save the generator once every this many epochs."
    )
    parser.add_argument(
        "--prob_model_dir",
        type=str,
        help="folder to store sampled images and saved generators. "
        "If not specified, a folder named after the chosen classes is used.",
    )
    parser.add_argument(
        "--classes", type=int, help="a list of integers (0-9) denoting the classes to consider", nargs="+"
    )
    # --------------------------------
    args = parser.parse_args()

    # Both intervals are used as modulo divisors during training; reject
    # non-positive values here instead of failing mid-training.
    if args.sample_interval <= 0:
        parser.error("--sample_interval must be a positive integer")
    if args.save_model_interval <= 0:
        parser.error("--save_model_interval must be a positive integer")

    # op is a dict
    op = vars(args)
    if op["classes"] is None:
        # classes not specified => consider all classes
        op["classes"] = list(range(10))
    invalid = [c for c in op["classes"] if not 0 <= c <= 9]
    if invalid:
        parser.error("--classes must be integers in 0-9. Invalid: {}".format(invalid))

    classes = sorted(op["classes"])
    cls_str = "".join(map(str, classes))
    if op["prob_model_dir"] is None:
        # use the list of classes to name the prob_model_dir
        prob_model_dir_name = "cifar10_c{}-dcgan".format(cls_str)
        op["prob_model_dir"] = glo.prob_model_folder(prob_model_dir_name)

    log.l().info("Options used: ")
    pprint.pprint(op)

    dcgan = DCGAN(**op)
    model_fname = "cifar10_c{}-dcgan-ep{}_bs{}.pt".format(cls_str, op["n_epochs"], op["batch_size"])
    model_fpath = os.path.join(dcgan.prob_model_dir, model_fname)

    # train
    log.l().info("Starting training a DCGAN on CIFAR10")
    dcgan.train()

    # save the generator
    g = dcgan.generator
    log.l().info("Saving the trained model to: {}".format(model_fpath))
    g.save(model_fpath)
def train(self):
    """
    Train a DCGAN model with the training hyperparameters as specified in
    the constructor. Directly modify the state of this object to store all
    relevant variables.

    * self.generator stores the trained generator.

    Side effects visible here:
    * creates self.data_dir;
    * every self.sample_interval minibatch updates, saves a 5x5 grid of
      samples (from fixed noise vectors) to self.prob_model_dir;
    * every self.save_model_interval epochs, saves the generator to
      self.prob_model_dir.
    """
    # Loss function
    adversarial_loss = torch.nn.BCELoss()

    # Initialize generator and discriminator
    minmax = (0.0, 1.0)
    generator = PatsornGenerator1(latent_dim=self.latent_dim, channels=3, minmax=minmax)
    discriminator = Discriminator(channels=3, minmax=minmax)
    # Keep a reference to the generator from the start so that
    # self.generator is always defined, even if no minibatch is processed.
    self.generator = generator

    # Use the GPU only if it was requested *and* is available. The input
    # tensor type below must follow the same decision as model placement;
    # otherwise CPU models would receive CUDA tensors (or vice versa).
    use_cuda = bool(self.use_cuda) and torch.cuda.is_available()
    if use_cuda:
        generator.cuda()
        discriminator.cuda()
        adversarial_loss.cuda()

    # Initialize weights
    generator.apply(weights_init_normal)
    discriminator.apply(weights_init_normal)

    # Configure data loader
    os.makedirs(self.data_dir, exist_ok=True)
    print("classes to use to train: {}".format(self.classes))
    trdata = cifar10_util.load_cifar10_class_subsets(self.classes, train=True, device="cpu", dtype=torch.float)
    print("dataset size: {}".format(len(trdata)))
    dataloader = torch.utils.data.DataLoader(trdata, batch_size=self.batch_size, shuffle=True, drop_last=True)

    # Optimizers
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=self.lr, betas=(self.b1, self.b2))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=self.lr, betas=(self.b1, self.b2))

    Tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

    # ----------
    #  Training
    # ----------
    # Fixed noise vectors, so that saved sample grids are comparable over time.
    z_save = generator.sample_noise(25).type(Tensor)
    for epoch in range(self.n_epochs):
        for i, (imgs, _) in enumerate(dataloader):
            # Adversarial ground truths (constants; no gradient needed)
            valid = Tensor(imgs.shape[0], 1).fill_(1.0)
            fake = Tensor(imgs.shape[0], 1).fill_(0.0)

            # Configure input
            real_imgs = imgs.type(Tensor)

            # -----------------
            #  Train Generator
            # -----------------
            optimizer_G.zero_grad()

            # Sample noise as generator input
            z = generator.sample_noise(imgs.shape[0]).type(Tensor)

            # Generate a batch of images
            gen_imgs = generator(z)

            # Loss measures generator's ability to fool the discriminator
            g_loss = adversarial_loss(discriminator(gen_imgs), valid)

            g_loss.backward()
            optimizer_G.step()

            # ---------------------
            #  Train Discriminator
            # ---------------------
            optimizer_D.zero_grad()

            # Measure discriminator's ability to classify real from generated samples
            real_loss = adversarial_loss(discriminator(real_imgs), valid)
            fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
            d_loss = (real_loss + fake_loss) / 2

            d_loss.backward()
            optimizer_D.step()

            print(
                "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                % (epoch, self.n_epochs, i, len(dataloader), d_loss.item(), g_loss.item())
            )

            batches_done = epoch * len(dataloader) + i
            if batches_done % self.sample_interval == 0:
                # Module.eval() switches the mode *in place*. Restore the
                # previous mode afterwards; otherwise all later updates would
                # run the generator in eval mode (e.g. batch-norm layers would
                # use running statistics while training).
                was_training = generator.training
                generator.eval()
                with torch.no_grad():
                    gen_save = generator(z_save)
                generator.train(was_training)
                save_image(
                    gen_save.data[:25],
                    "%s/%06d.png" % (self.prob_model_dir, batches_done),
                    nrow=5,
                    normalize=False,
                )

        # Save the model once in a while
        if (epoch + 1) % self.save_model_interval == 0:
            model_fname = DCGAN.make_model_file_name(self.classes, epoch + 1, self.batch_size)
            model_fpath = os.path.join(self.prob_model_dir, model_fname)
            log.l().info("Save the generator after {} epochs. Save to: {}".format(epoch + 1, model_fpath))
            generator.save(model_fpath)
def pt_gkmm(
    g,
    cond_imgs,
    extractor,
    k,
    Z,
    optimizer,
    sum_writer,
    input_weights=None,
    z_penalty=TPNull(),
    device=torch.device("cpu"),
    tensor_type=torch.FloatTensor,
    n_opt_iter=500,
    seed=1,
    texture=0,
    img_log_steps=10,
    weigh_logits=0,
    log_img_dir="",
):
    """
    Conditionally generate images conditioning on the input images (cond_imgs)
    using kernel moment matching. The noise vectors Z are optimized in place
    to minimize the squared MMD between the extracted features of g(Z) and the
    weighted features of cond_imgs, plus z_penalty(Z).

    * g: a generator of type torch.nn.Module (forward() takes noise vectors
      and transforms them into images). Needs to be differentiable for the
      optimization.
    * cond_imgs: a stack of input images to condition on. Pixel value range
      should be [0,1].
    * extractor: an instance of torch.nn.Module representing a feature
      extractor for image input.
    * k: cadgan.kernel.PTKernel representing a kernel on top of the extracted
      features.
    * Z: a stack of noise vectors to be optimized. These are fed to the
      generator g for the optimization. Modified in place: requires_grad is
      set to True and its values are clamped to [-3.3, 3.3] before every
      loss evaluation.
    * optimizer: a Pytorch optimizer. The list of variables to optimize has
      to contain Z.
    * sum_writer: SummaryWriter for tensorboard.
    * input_weights: a one-dimensional Torch tensor (vector) whose length is
      the same as the number of conditioned images. Specifies weights of the
      conditioned images. 0 <= w_i <= 1 and weights sum to 1 (tolerance 1e-3).
      If None, automatically set to uniform weights.
    * z_penalty: a TensorPenalty to penalize Z. Use TPNull() for no penalty.
    * device: an object constructed from torch.device(..). Likely this might
      be torch.device('cuda') or torch.device('cpu'). Use CPU by default.
    * tensor_type: Default Pytorch tensor type to use e.g., torch.FloatTensor
      or torch.cuda.FloatTensor. Use torch.FloatTensor by default (for cpu).
    * n_opt_iter: number of iterations for the optimization.
    * seed: random seed (positive integer). Not used by this function body.
    * texture: not used by this function body.
    * img_log_steps: record generated images once every this many
      optimization steps. Images of the last step are always recorded; a
      non-positive value records only the last step.
    * weigh_logits: if truthy, apply weighing_logits() to the output of the
      feature extractor before applying the kernel.
    * log_img_dir: folder under which generated images (and, when the
      feature shape allows it, feature maps) are saved at logged steps.

    Write output in a Tensorboard log.

    Raises ValueError if input_weights has an entry outside [0,1] or does not
    sum to one.
    """
    # Check generator's output range and image pixel range. We work with [0, 1].
    pixel_values_check(cond_imgs, (0, 1), "cond_imgs")
    tmp_sam = g.forward(Z)
    pixel_values_check(tmp_sam, (0, 1), "generator's output")

    # number of images to condition on
    n_cond = cond_imgs.shape[0]
    if input_weights is None:
        # None => set to uniform weights.
        input_weights = torch.ones(n_cond, device=device).type(tensor_type) / float(n_cond)

    # Check the range of input_weights. Has to be in [0,1]
    if not ((input_weights >= 0.0).all() and (input_weights <= 1.0).all()):
        raise ValueError(
            '"input_weights" contains at least one weight which is outside [0,1] interval. Was {}'.format(
                input_weights
            )
        )
    # Check that the weights sum to 1
    if torch.abs(input_weights.sum() - 1.0) > 1e-3:
        raise ValueError('"input_weights" does not sum to one. Was {}'.format(input_weights.sum()))

    gens_cpu = tmp_sam.to(torch.device("cpu"))
    arranged_init_imgs = torchvision.utils.make_grid(gens_cpu, nrow=2, normalize=True)
    log.l().debug("Adding initial generated images to Tensorboard")
    sum_writer.add_image("Init_Images", arranged_init_imgs)
    del tmp_sam

    # Setting requires_grad=True is very important. We will optimize Z.
    Z.requires_grad = True

    arranged_cond_imgs = torchvision.utils.make_grid(cond_imgs, nrow=2, normalize=True)
    sum_writer.add_image("Cond_Images", arranged_cond_imgs)

    # The conditioning features and their weighted kernel mean do not depend
    # on Z, so compute them once, without building an autograd graph.
    with torch.no_grad():
        FX = extractor.forward(cond_imgs)
        if weigh_logits:
            FX = weighing_logits(FX)
        kFX = k.eval(FX, FX)
        # w^T K(FX, FX) w
        mean_KFX = kFX.mv(input_weights).dot(input_weights)

    loss_all = []
    for t in range(n_opt_iter):
        # Log images at every img_log_steps-th step and always at the last
        # one. A non-positive img_log_steps would otherwise divide by zero.
        log_images = (img_log_steps > 0 and t % img_log_steps == 0) or t == n_opt_iter - 1

        def closure():
            # Keep Z inside a bounded box before every evaluation.
            Z.data.clamp_(-3.3, 3.3)
            optimizer.zero_grad()
            gens = g.forward(Z)
            if gens.size()[3] == 1024:
                # Downsample images else it takes a lot of time in optimization.
                # TODO: WJ: To downsample, it is better to do it before calling this function.
                # The conditional generation function does not need to handle this.
                downsample = torch.nn.AvgPool2d(3, stride=2)
                gens = downsample(downsample(gens))
            if log_images:
                gens_cpu = gens.to(torch.device("cpu"))
                imutil.save_images(gens_cpu, os.path.join(log_img_dir, "output_images", str(t)))
                arranged_gens = torchvision.utils.make_grid(gens_cpu, nrow=2, normalize=True)
                log.l().debug("Logging generated images at iteration {}".format(t + 1))
                sum_writer.add_image("Generated_Images", arranged_gens, t)

            F_gz = extractor.forward(gens)
            if log_images:
                # Best effort: viewing the features as square single-channel
                # images only works for suitable feature shapes, so failures
                # are skipped. Exception (not a bare except) so Ctrl-C still works.
                try:
                    feature_size = int(np.sqrt(F_gz.shape[1]))
                    feat_out = F_gz.view(F_gz.shape[0], 1, feature_size, feature_size)
                    feat_cpu = feat_out.to(torch.device("cpu"))
                    imutil.save_images(feat_cpu, os.path.join(log_img_dir, "feature_images", str(t)))
                    arranged_feats = torchvision.utils.make_grid(feat_cpu, nrow=2, normalize=True)
                    sum_writer.add_image("feature_images", arranged_feats, t)
                except Exception:
                    if t == 0:
                        log.l().debug("Unable to plot features as image. Okay. Will skip plotting features.")

            if weigh_logits:
                # WJ: This option is not really used. Should be removed.
                F_gz = weighing_logits(F_gz)
            KF_gz = k.eval(F_gz, F_gz)
            Z_loss = z_penalty(Z)
            # Squared MMD between the uniformly weighted generated features and
            # the w-weighted conditioning features (assuming k.eval(A, B)
            # returns the kernel matrix between the rows of A and B).
            mmd2 = torch.mean(KF_gz) - 2.0 * torch.mean(k.eval(F_gz, FX).mv(input_weights)) + mean_KFX
            loss = mmd2 + Z_loss
            # compute the gradients
            loss.backward(retain_graph=True)

            # record losses
            sum_writer.add_scalar("loss/total", loss.item(), t)
            sum_writer.add_scalar("loss/mmd2", mmd2.item(), t)
            sum_writer.add_scalar("loss/Z_penalty", Z_loss, t)

            # record some statistics
            sum_writer.add_scalar("Z/max_z", torch.max(Z), t)
            sum_writer.add_scalar("Z/min_z", torch.min(Z), t)
            sum_writer.add_scalar("Z/avg_z", torch.mean(Z), t)
            sum_writer.add_scalar("Z/std_z", torch.std(Z), t)
            sum_writer.add_histogram("Z/hist", Z.reshape(-1), t)

            loss_all.append(mmd2.item())
            if t <= 20 or t % 20 == 0:
                log.l().info("Iter [{}], overall_loss: {}".format(t, loss.item()))
            return loss

        optimizer.step(closure)
def closure():
    """Evaluate the kernel-moment-matching loss for the current noise vectors
    Z, back-propagate it, log statistics, and return the loss.

    Intended as the closure passed to ``optimizer.step``. It mirrors the
    closure defined inside the optimization loop of ``pt_gkmm``.

    NOTE(review): every name used here other than the standard modules
    (Z, optimizer, g, t, img_log_steps, n_opt_iter, log_img_dir, sum_writer,
    extractor, weigh_logits, k, z_penalty, FX, input_weights, mean_KFX,
    loss_all) is a free variable. It must be bound in the scope where this
    function is evaluated.
    """
    # Keep Z inside a bounded box before every evaluation.
    Z.data.clamp_(-3.3, 3.3)
    optimizer.zero_grad()
    gens = g.forward(Z)
    if gens.size()[3] == 1024:
        # Downsample images else it takes a lot of time in optimization.
        # TODO: WJ: To downsample, it is better to do it before calling this function.
        # The conditional generation function does not need to handle this.
        downsample = torch.nn.AvgPool2d(3, stride=2)
        gens = downsample(downsample(gens))

    # Log images at every img_log_steps-th step and always at the last one.
    # A non-positive img_log_steps would otherwise divide by zero.
    log_images = (img_log_steps > 0 and t % img_log_steps == 0) or t == n_opt_iter - 1
    if log_images:
        gens_cpu = gens.to(torch.device("cpu"))
        imutil.save_images(gens_cpu, os.path.join(log_img_dir, "output_images", str(t)))
        arranged_gens = torchvision.utils.make_grid(gens_cpu, nrow=2, normalize=True)
        log.l().debug("Logging generated images at iteration {}".format(t + 1))
        sum_writer.add_image("Generated_Images", arranged_gens, t)

    F_gz = extractor.forward(gens)
    if log_images:
        # Best effort: viewing the features as square single-channel images
        # only works for suitable feature shapes, so failures are skipped.
        # Exception (not a bare except) so Ctrl-C / SystemExit propagate.
        try:
            feature_size = int(np.sqrt(F_gz.shape[1]))
            feat_out = F_gz.view(F_gz.shape[0], 1, feature_size, feature_size)
            feat_cpu = feat_out.to(torch.device("cpu"))
            imutil.save_images(feat_cpu, os.path.join(log_img_dir, "feature_images", str(t)))
            arranged_feats = torchvision.utils.make_grid(feat_cpu, nrow=2, normalize=True)
            sum_writer.add_image("feature_images", arranged_feats, t)
        except Exception:
            if t == 0:
                log.l().debug("Unable to plot features as image. Okay. Will skip plotting features.")

    if weigh_logits:
        # WJ: This option is not really used. Should be removed.
        F_gz = weighing_logits(F_gz)
    KF_gz = k.eval(F_gz, F_gz)
    Z_loss = z_penalty(Z)
    # Squared MMD between the uniformly weighted generated features and the
    # w-weighted conditioning features (assuming k.eval(A, B) returns the
    # kernel matrix between the rows of A and B).
    mmd2 = torch.mean(KF_gz) - 2.0 * torch.mean(k.eval(F_gz, FX).mv(input_weights)) + mean_KFX
    loss = mmd2 + Z_loss
    # compute the gradients
    loss.backward(retain_graph=True)

    # record losses
    sum_writer.add_scalar("loss/total", loss.item(), t)
    sum_writer.add_scalar("loss/mmd2", mmd2.item(), t)
    sum_writer.add_scalar("loss/Z_penalty", Z_loss, t)

    # record some statistics
    sum_writer.add_scalar("Z/max_z", torch.max(Z), t)
    sum_writer.add_scalar("Z/min_z", torch.min(Z), t)
    sum_writer.add_scalar("Z/avg_z", torch.mean(Z), t)
    sum_writer.add_scalar("Z/std_z", torch.std(Z), t)
    sum_writer.add_histogram("Z/hist", Z.reshape(-1), t)

    loss_all.append(mmd2.item())
    if t <= 20 or t % 20 == 0:
        log.l().info("Iter [{}], overall_loss: {}".format(t, loss.item()))
    return loss
def main():
    """Command-line entry point for conditional image generation with GKMM.

    Steps:
    1. Parse the options.
    2. Load a pretrained generator. For the LarsGAN ``*.yaml`` configs the
       checkpoint is downloaded if it is missing.
    3. Wrap the generator so that its output is resized and mapped to [0, 1].
    4. Build the feature extractor and the kernel.
    5. Load the images to condition on.
    6. Choose initial noise vectors with a greedy heuristic.
    7. Run ``kmain.pt_gkmm`` to optimize the noise vectors.

    Tensorboard logs and images are written to
    ``glo.result_folder(<logdir>/<name built from the options>)``.
    """
    # Training settings
    parser = argparse.ArgumentParser(
        description='PyTorch GKMM. Some paths are relative to the "(share_path)/prob_models/". See settings.ini for (share_path).'
    )
    parser.add_argument(
        "--extractor_type",
        type=str,
        default="vgg",
        help="The feature extractor. The saved object should be a torch.nn.Module representing a "
        "feature extractor. Currently support [vgg | vgg_face | alexnet_365 | resnet18_365 | resnet50_365 | "
        "hed | hed_color | hed_vgg | hed_color_vgg | color | color_count | mnist_cnn | "
        "mnist_cnn_digit_layer | mnist_cnn_digit_layer_color | pixel]",
        required=True,
    )
    parser.add_argument(
        "--extractor_layers",
        nargs="+",
        default=["4", "9", "18", "27"],
        help="Indices of the layers to include. Used by the VGG-based feature extractors "
        "(mnist_cnn uses only the first entry). Default: 4 9 18 27",
    )
    parser.add_argument(
        "--texture", type=float, default=0, help="Use texture (grammatrix) of extracted features. Default=0"
    )
    # "max" is intentionally not offered: depth_process_map below has no
    # max-pooling entry, so it used to fail with a KeyError after all setup.
    parser.add_argument(
        "--depth_process",
        nargs="?",
        choices=["avg", "no"],
        default="no",
        help="Processing module to run on the output from each filter in the specified layer(s).",
    )
    parser.add_argument(
        "--g_path",
        type=str,
        required=True,
        help="Relative path (relative to (share_path)/prob_models) to the file that can be loaded "
        "to get a cadgan.gen.PTNoiseTransformer representing an image generator.",
    )
    parser.add_argument(
        "--g_type",
        type=str,
        default="celebAHQ.yaml",
        help="Generator type based on the data it is trained for "
        "(a LarsGAN *.yaml config, mnist_dcgan, or colormnist_dcgan).",
    )
    parser.add_argument(
        "--g_min", type=float, help="The minimum value of the pixel output from the generator.", required=True
    )
    parser.add_argument(
        "--g_max", type=float, help="The maximum value of the pixel output from the generator.", required=True
    )
    parser.add_argument(
        "--logdir", type=str, required=True, help="full path to the folder to contain Tensorboard log files"
    )
    parser.add_argument(
        "--device", nargs="?", choices=["cpu", "gpu"], default="cpu", help="Device to use for computation."
    )
    parser.add_argument("--n_sample", type=int, default=16, metavar="n", help="Number of images to generate")
    parser.add_argument("--n_opt_iter", type=int, default=500, help="Number of optimization iterations")
    parser.add_argument("--lr", type=float, default=0.001, metavar="LR", help="learning rate (for the optimizer)")
    parser.add_argument(
        "--n_init_resample", type=float, default=1, help="number of time to resample z for the heuristic"
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1,
        metavar="S",
        help="Random seed. Among others, this affects the initialization of the noise vectors of the generator in the optimization.",
    )
    parser.add_argument(
        "--img_log_steps",
        type=int,
        default=10,
        metavar="N",
        help="how many optimization iterations to wait before logging generated images",
    )
    parser.add_argument("--img_size", type=int, default=224, help="image size nxn. Default: 224")
    parser.add_argument(
        "--cond_path",
        type=str,
        required=True,
        help="Path (relative to the data folder) to the images to condition on: a folder, "
        "a .txt file listing image paths, or a single .png/.jpg image",
    )
    parser.add_argument(
        "--kernel",
        nargs="?",
        required=True,
        choices=["linear", "gauss", "imq"],
        help="choice of kernel to put on top of extracted features. May need to specify also --kparams.",
    )
    parser.add_argument(
        "--kparams",
        nargs="*",
        type=float,
        dest="kparams",
        default=[],
        help="A list of kernel parameters (float). Semantic of parameters depends on the chosen kernel",
    )
    parser.add_argument(
        "--w_input",
        nargs="+",
        default=[],
        help="weight of the input, must be equal to the number of cond images and sum to 1. "
        "if none specified, equal weights will be used.",
    )

    img_transform = target_transform()
    args = parser.parse_args()
    print("Training options: ")
    args_dict = vars(args)
    pprint.pprint(args_dict, width=5)
    # ---------------------------------

    # Check if texture and extractor are called correctly
    if args.texture and (not args.extractor_layers or not args.extractor_type):
        parser.error("Texture call, Extractor layers and Extractor type must be given at the same time!")

    # Fail early on an unknown generator type. Otherwise "generator" is never
    # assigned and the script dies later with a NameError.
    known_g_types = ("mnist_dcgan", "colormnist_dcgan")
    if not (args.g_type.endswith(".yaml") or args.g_type in known_g_types):
        parser.error(
            "Unknown --g_type {}. Use a *.yaml LarsGAN config or one of {}".format(args.g_type, known_g_types)
        )

    # True to use GPU
    use_cuda = args.device == "gpu" and torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    tensor_type = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
    torch.set_default_tensor_type(tensor_type)

    # initialize the noise vectors for the generator
    # Set the random seed
    seed = args.seed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    n_sample = args.n_sample

    if args.g_type.endswith(".yaml"):
        # sample a stack of noise vectors
        latent_dim = 256
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        # Loading Configs for LarsGAN
        yaml_folder = os.path.dirname(ganstab.configs.__file__)
        yaml_config_path = os.path.join(yaml_folder, args.g_type)
        config = load_config(yaml_config_path)

        # load generator
        nlabels = config["data"]["nlabels"]
        out_dir = config["training"]["out_dir"]
        checkpoint_dir = os.path.join(out_dir, "chkpts")

        generator = build_generator(config)
        # Put models on gpu if needed
        generator = generator.to(device)

        # Use multiple GPUs if possible
        generator = nn.DataParallel(generator)

        # Logger
        checkpoint_io = CheckpointIO(checkpoint_dir=checkpoint_dir)

        # Register modules to checkpoint
        checkpoint_io.register_modules(generator=generator)

        # Test generator
        if config["test"]["use_model_average"]:
            generator_test = copy.deepcopy(generator)
            checkpoint_io.register_modules(generator_test=generator_test)
        else:
            generator_test = generator

        # Loading Generator
        ydist = get_ydist(nlabels, device=device)

        full_g_path = glo.prob_model_folder(args.g_path)
        if not os.path.exists(full_g_path):
            # download the LarsGAN pretrained model file if it does not exist yet
            print(
                "Generator file does not exist: {}\n I will load a pretrained model for you. Please wait ...".format(
                    full_g_path
                ),
                end="",
            )
            dict_url = {
                "lsun_bedroom.yaml": "https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bedroom-df4e7dd2.pt",
                "lsun_bridge.yaml": "https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_bridge-82887d22.pt",
                "celebAHQ.yaml": "https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/celebahq-baab46b2.pt",
                "lsun_tower.yaml": "https://s3.eu-central-1.amazonaws.com/avg-projects/gan_stability/models/lsun_tower-1af5e570.pt",
            }
            # Explicit error instead of assert: asserts are stripped under -O.
            if args.g_type not in dict_url:
                raise ValueError("g_type of {} not support".format(args.g_type))
            url = dict_url[args.g_type]
            r = requests.get(url)
            # Do not write an HTTP error page to disk as if it were a model.
            r.raise_for_status()
            os.makedirs(os.path.dirname(full_g_path), exist_ok=True)
            with open(full_g_path, "wb") as f:
                f.write(r.content)
            print("done")

        # load option depends on whether GPU is used
        load_options = {} if use_cuda else {"map_location": lambda storage, loc: storage}
        it = checkpoint_io.load(full_g_path, **load_options)
    elif args.g_type == "mnist_dcgan":
        # TODO should probably reorganize these
        latent_dim = 100
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        full_g_path = glo.prob_model_folder(args.g_path)
        generator = mnist_dcgan.Generator()
        if os.path.exists(full_g_path):
            generator.load(full_g_path)
        else:
            print("Generator file does not exist: {}\nLoading pretrain model...".format(full_g_path))
            generator.download_pretrain(output=full_g_path)

        generator = generator.to(device)
        generator_test = generator
        ydist = None
    elif args.g_type == "colormnist_dcgan":
        # TODO should probably reorganize these
        latent_dim = 100
        f_noise = lambda n: torch.randn(n, latent_dim).float()
        Z0 = f_noise(n_sample)

        full_g_path = glo.prob_model_folder(args.g_path)
        generator = cmnist_dcgan.Generator()
        if os.path.exists(full_g_path):
            generator.load(full_g_path)
        else:
            print("Generator file does not exist: {}\nLoading pretrain model...".format(full_g_path))
            generator.download_pretrain(output=full_g_path)

        generator = generator.to(device)
        generator_test = generator
        ydist = None

    # Noise distribution is Gaussian. Unlikely that the magnitude of the
    # coordinate is above the bound.
    z_penalty = kmain.TPNull()  # kmain.TPSymLogBarrier(bound=4.2, scale=1e-4)
    args_dict["zpen"] = z_penalty

    # output range of the generator (according to what the user specifies)
    g_range = (args.g_min, args.g_max)

    # Sanity check. Check that the specified g-range is plausible
    g_out_uncontrolled = Generator(ydist=ydist, generator=generator_test.to(device))

    temp_sample = g_out_uncontrolled.forward(Z0)
    kmain.pixel_values_check(temp_sample, g_range, "Generator's samples")
    extractor_in_size = args.img_size

    # Resize the output of g to the extractor input size and transform its
    # range to (0,1).
    g = nn.Sequential(
        g_out_uncontrolled,
        nn.AdaptiveAvgPool2d((extractor_in_size, extractor_in_size)),
        gen.LinearRangeTransform(from_range=g_range, to_range=(0, 1)),
    )

    depth_process_map = {"no": ext.Identity(), "avg": ext.GlobalAvgPool()}
    feature_size = 128
    if args.texture == 1:
        post_process = nn.Sequential(depth_process_map[args.depth_process], GramMatrix())
    else:
        post_process = nn.Sequential(depth_process_map[args.depth_process])
    mnist_device = "cuda" if use_cuda else "cpu"

    # Loading Extractor
    if args.extractor_type == "vgg":
        extractor_layers = [int(i) for i in args.extractor_layers]
        extractor = ext.VGG19(layers=extractor_layers, layer_postprocess=post_process)
    elif args.extractor_type == "vgg_face":
        extractor_layers = [int(i) for i in args.extractor_layers]
        extractor = ext.VGG19_face(layers=extractor_layers, layer_postprocess=post_process)
    elif args.extractor_type == "alexnet_365":
        extractor = ext.AlexNet_365()
    elif args.extractor_type == "resnet18_365":
        extractor = ext.ResNet18_365()
    elif args.extractor_type == "resnet50_365":
        extractor = ext.ResNet50_365(n_remove_last_layers=2, layer_postprocess=post_process)
    elif args.extractor_type == "hed":
        extractor = ext.HED(device=device, resize=feature_size)
    elif args.extractor_type == "hed_color":
        # stack features from HED and a tiny image to get both edge and color information
        hed = ext.HED(device=device, resize=feature_size)
        tiny = ext.TinyImage(device=device, grid_size=(10, 10))
        extractor = ext.StackModule(device=device, module_list=[hed, tiny], weights=[0.01, 0.99])
    elif args.extractor_type == "hed_vgg":
        # stack features from HED and VGG to get both edge and high-level VGG information
        hed = ext.HED(device=device, resize=feature_size)
        extractor_layers = [int(i) for i in args.extractor_layers]
        vgg = ext.VGG19(layers=extractor_layers, layer_postprocess=post_process)
        extractor = ext.StackModule(device=device, module_list=[hed, vgg], weights=[0.99, 0.01])
    elif args.extractor_type == "hed_color_vgg":
        # stack features from HED, VGG and a tiny image to get edge, color, and high-level VGG information
        hed = ext.HED(device=device, resize=feature_size)
        extractor_layers = [int(i) for i in args.extractor_layers]
        vgg = ext.VGG19(layers=extractor_layers, layer_postprocess=post_process)
        tiny = ext.TinyImage(device=device, grid_size=(10, 10))
        extractor = ext.StackModule(device=device, module_list=[hed, vgg, tiny], weights=[0.005, 0.005, 0.99])
    elif args.extractor_type == "color":
        extractor = ext.TinyImage(device=device, grid_size=(128, 128))
    elif args.extractor_type == "color_count":
        # to use with Waleed color mnist only:
        # the purpose is to count color based on the template, currently not working as expected.
        prototypes = torch.tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1], [1, 1, 0], [1, 0, 1], [0.4, 0.2, 0]])
        extractor = ext.SoftCountPixels(prototypes=prototypes, gwidth2=0.3, device=device, tensor_type=tensor_type)
    elif args.extractor_type == "mnist_cnn":
        extractor = ext.MnistCNN(
            device=mnist_device, layer_postprocess=post_process, layer=int(args.extractor_layers[0])
        )
    elif args.extractor_type == "mnist_cnn_digit_layer":
        # use the last layer of MNIST CNN (digit classification)
        extractor = ext.MnistCNN(device=mnist_device, layer_postprocess=post_process, layer=3)
    elif args.extractor_type == "mnist_cnn_digit_layer_color":
        # use the last layer of MNIST CNN (digit classification), stacked with color information
        mnistcnn = ext.MnistCNN(device=mnist_device, layer_postprocess=post_process, layer=3)
        color = ext.MaxColor(device=device)
        extractor = ext.StackModule(device=device, module_list=[mnistcnn, color], weights=[1, 99])
    elif args.extractor_type == "pixel":
        # raw pixels as features
        extractor = ext.Identity(flatten=True, slice_dim=0 if args.g_type == "mnist_dcgan" else None)
    else:
        raise ValueError("Unknown extractor type. Check --extractor_type")

    if use_cuda:
        extractor = extractor.cuda()
    assert isinstance(extractor, torch.nn.Module)

    print("Summary of the extractor:")
    try:
        torchsummary.summary(extractor, input_size=(3, extractor_in_size, extractor_in_size))
    except Exception:
        # Best effort only; the summary is purely informational.
        log.l().info("Exception occured when getting a summary of the extractor")

    # run a forward pass through the extractor just to test
    tmp_extracted = extractor(g(Z0[[0]]))
    n_features = torch.prod(torch.tensor(tmp_extracted.shape))
    print("Number of extracted features = {}".format(n_features))
    del tmp_extracted

    def load_multiple_images(list_imgs):
        """Load, resize and transform each image path in list_imgs and
        concatenate the results along dimension 0.

        Raises ValueError if list_imgs is empty.
        """
        if not list_imgs:
            raise ValueError("No images to condition on were found.")
        loaded_imgs = []
        for path_img in list_imgs:
            loaded = imutil.load_resize_image(path_img, extractor_in_size).copy()
            loaded_imgs.append(img_transform(loaded).unsqueeze(0).type(tensor_type))
        return torch.cat(loaded_imgs)

    cond_path = glo.data_file(args.cond_path)
    if os.path.isdir(cond_path):
        # Use all images in the folder. Join with "*" so that the pattern
        # lists the folder's content (path + "*" would match the folder
        # itself and its siblings). Sort so that the order, which --w_input
        # refers to, is deterministic.
        list_imgs = sorted(glob.glob(os.path.join(cond_path, "*")))
        cond_imgs = load_multiple_images(list_imgs)
    elif args.cond_path.endswith(".txt"):
        # read a list of image paths from a text file
        with open(cond_path, "r") as f:
            data = f.readlines()
        list_imgs = [glo.data_file(x.strip()) for x in data if x.strip()]
        if not list_imgs:
            raise ValueError("Empty list of images to condition. Make sure that {} is valid".format(cond_path))
        cond_imgs = load_multiple_images(list_imgs)
    elif args.cond_path.endswith((".png", ".jpg")):
        cond_imgs = load_multiple_images([cond_path])
    else:
        # (Raising a plain string is a TypeError in Python 3.)
        raise ValueError(
            "Unsupported input type at {} (currently supported: a folder, a .txt file listing images, "
            "or a single .png/.jpg image)".format(cond_path)
        )

    cond_imgs = cond_imgs.to(device).type(tensor_type)

    # kernel on top of the extracted features
    k_map = {"linear": kernel.PTKLinear, "gauss": kernel.PTKGauss, "imq": kernel.PTKIMQ}
    kernel_key = args.kernel
    kernel_params = args.kparams
    k_constructor = k_map[kernel_key]
    # construct the chosen kernel with the specified parameters
    k = k_constructor(*kernel_params)

    # texture flag
    texture = args.texture

    # run the kernel moment matching optimization
    n_opt_iter = args.n_opt_iter
    logdir = args.logdir
    print("LOGDIR: ", logdir)

    # dictionary containing key-value pairs for experimental settings.
    log_str_dict = dict((ke, str(va)) for (ke, va) in args_dict.items())

    # logdir is just a parent folder.
    # Form the actual file name by concatenating the values of all
    # hyperparameters used.
    log_str_dict2 = copy.deepcopy(log_str_dict)
    now = datetime.datetime.now()
    time_str = "{:02}.{:02}.{}_{:02}{:02}{:02}".format(
        now.day, now.month, now.year, now.hour, now.minute, now.second
    )
    log_str_dict2["t"] = time_str
    util.translate_keys(
        log_str_dict2,
        {
            "cond_path": "co",
            "data_dir": "dat",
            "depth_process": "dp",
            "extractor_path": "ep",
            "extractor_type": "et",
            "extractor_layers": "el",
            "g_type": "gt",
            "kernel": "k",
            "kparams": "kp",
            "n_opt_iter": "it",
            "n_sample": "n",
            "seed": "s",
            "texture": "te",
        },
    )
    parameters_str = util.dict_to_string(
        log_str_dict2,
        exclude=["device", "img_log_steps", "logdir", "g_min", "g_max", "g_path", "t"],
        entry_sep="-",
        kv_sep="_",
    )
    img_log_steps = args.img_log_steps
    logdir_fname = util.clean_filename(parameters_str, replace="/\\[]")
    log_dir_path = glo.result_folder(os.path.join(logdir, logdir_fname))

    # kernel on images: k applied to the extracted features
    k_img = kernel.PTKFuncCompose(k, f=extractor)

    tmp_gen = g(Z0)
    assert tmp_gen.shape[-1] == extractor_in_size and tmp_gen.shape[-2] == extractor_in_size
    del tmp_gen

    if len(args.w_input) == 0:
        input_weights = None
    else:
        # Explicit error instead of assert: asserts are stripped under -O.
        if cond_imgs.shape[0] != len(args.w_input):
            raise ValueError("number of input weights must equal to number of input images")
        input_weights = torch.tensor([float(x) for x in args.w_input], device=device).type(tensor_type)

    # A heuristic to pick good Z to start the optimization
    multi_restarts_refiner = kmain.ZRMMDIterGreedy(
        g,
        z_sampler=f_noise,
        k=k_img,
        X=cond_imgs,
        n_draws=int(args.n_init_resample),  # number of times to draw each z_i
        n_sample=Z0.shape[0],
        device=device,
        tensor_type=tensor_type,
        input_weights=input_weights,
    )

    # Summary writer for Tensorboard logging
    sum_writer = SummaryWriter(log_dir=log_dir_path)

    # write all key-value pairs in log_str_dict to the Tensorboard
    for ke, va in log_str_dict.items():
        sum_writer.add_text(ke, va)

    with open(os.path.join(log_dir_path, "metadata"), "wb") as f:
        dill.dump(log_str_dict, f)

    imutil.save_images(cond_imgs, os.path.join(log_dir_path, "input_images"))

    gens = g.forward(Z0)
    gens_cpu = gens.to(torch.device("cpu"))
    imutil.save_images(gens_cpu, os.path.join(log_dir_path, "prior_images"))
    del gens
    del gens_cpu

    # Get a better Z
    Z = multi_restarts_refiner(Z0)

    # Try to plot (in Tensorboard) extracted features as images if possible
    log.l().info("Attemping to plot extracted features as images. Will skip if this does not work")
    try:
        feat_out = extractor.forward(cond_imgs)
        feature_size = int(np.sqrt(feat_out.shape[1]))
        feat_out = feat_out.view(feat_out.shape[0], 1, feature_size, feature_size)
        gens_cpu = feat_out.to(torch.device("cpu"))
        imutil.save_images(gens_cpu, os.path.join(log_dir_path, "input_feature"))
        arranged_init_imgs = torchvision.utils.make_grid(gens_cpu, nrow=2, normalize=True)
        sum_writer.add_image("Init_feature", arranged_init_imgs)
        del feat_out
    except Exception as err:
        log.l().info(err)
        log.l().info("unable to plot feature as image")

    imutil.save_images(cond_imgs, os.path.join(log_dir_path, "input_images"))

    # optimizer
    optimizer = torch.optim.Adam([Z], lr=args.lr)

    # Solve the kernel moment matching problem
    kmain.pt_gkmm(
        g,
        cond_imgs,
        extractor,
        k,
        Z,
        optimizer,
        z_penalty=z_penalty,
        sum_writer=sum_writer,
        device=device,
        tensor_type=tensor_type,
        n_opt_iter=n_opt_iter,
        seed=seed,
        texture=texture,
        input_weights=input_weights,
        img_log_steps=img_log_steps,
        log_img_dir=log_dir_path,
    )
    print("Finished, results location : {}".format(log_dir_path))