def transition_model(self):
    """Advance the progressive-growing schedule by one phase switch.

    ``latest_switch`` is always pushed forward by ``transition_iters``.
    After that, exactly one of the following happens:

    * A transition is in progress (``is_transitioning``): it is ended. The
      transition value is refreshed, a status line is printed and a regular
      checkpoint is saved.
    * No transition is in progress and ``current_imsize < max_imsize``:
      1. A transition checkpoint and a "before" image are saved.
      2. ``extend_models()`` is called.
      3. The train/val dataloaders are rebuilt for the current settings.
      4. A new transition is started and the optimizers are re-initialised.
      5. The transition value is refreshed and an "after" image is saved.
    * Otherwise (already at ``max_imsize``), nothing else happens.
    """
    self.latest_switch += self.transition_iters

    def report(headline):
        # Shared status line for both the "stop" and "start" messages; it
        # reads the attributes at call time, i.e. after any updates above it.
        print(f"{headline} Global step: {self.global_step}, "
              f"transition_variable: {self.transition_variable}, "
              f"Current imsize: {self.current_imsize}")

    if self.is_transitioning:
        # End the running transition, then persist the state.
        self.is_transitioning = False
        self.update_transition_value()
        report("Stopping transition.")
        self.save_checkpoint()
        return

    if self.current_imsize >= self.max_imsize:
        # Final image size already reached: nothing left to extend.
        return

    # Snapshot checkpoint + sample image before the models are extended.
    self.save_transition_checkpoint()
    self.save_transition_image(True)
    self.extend_models()

    # The old loaders are deleted before new ones are built
    # (presumably to release their resources first).
    del self.dataloader_train, self.dataloader_val
    loaders = load_dataset(
        self.dataset,
        self.batch_size,
        self.current_imsize,
        self.full_validation,
        self.pose_size,
        self.load_fraction_of_dataset,
    )
    self.dataloader_train, self.dataloader_val = loaders

    self.is_transitioning = True
    report("Start transition.")
    self.init_optimizers()
    self.update_transition_value()
    print(f"New transition value: {self.transition_variable}")
    # Sample image produced after the models were extended.
    self.save_transition_image(False)
def __init__(self, config):
    """Set up hyperparameters, models, loss, optimizers and data loaders.

    Args:
        config: Parsed configuration object, read through attribute access.
            Sub-sections used are ``train_config``, ``models`` and
            ``logging``. Top-level fields used include ``dataset``,
            ``max_imsize``, ``checkpoint_dir``, ``config_path``,
            ``summaries_dir``, ``generated_data_dir``,
            ``use_full_validation`` and ``load_fraction_of_dataset``.

    The image size starts at 4. ``load_checkpoint()`` is tried first. If it
    returns a falsy value, a message is printed and ``extend_models()`` is
    called once. Everything after that point (batch size, logging
    trigger points, dataloaders) is derived from the resulting state.
    """
    train_cfg = config.train_config
    model_cfg = config.models

    # --- Hyperparameters ---------------------------------------------------
    self.batch_size_schedule = train_cfg.batch_size_schedule  # indexed by image size below
    self.dataset = config.dataset
    self.learning_rate = train_cfg.learning_rate
    self.running_average_generator_decay = model_cfg.generator.running_average_decay
    self.pose_size = model_cfg.pose_size
    self.discriminator_model = model_cfg.discriminator.structure
    self.full_validation = config.use_full_validation
    self.load_fraction_of_dataset = config.load_fraction_of_dataset

    # --- Image settings ----------------------------------------------------
    self.current_imsize = 4
    self.image_channels = 3
    self.max_imsize = config.max_imsize

    # --- Logging / bookkeeping --------------------------------------------
    self.checkpoint_dir = config.checkpoint_dir
    # Second-to-last "/"-separated component of checkpoint_dir.
    self.model_name = self.checkpoint_dir.split("/")[-2]
    self.config_path = config.config_path
    self.global_step = 0
    self.logger = logger.Logger(config.summaries_dir, config.generated_data_dir)

    # --- Transition state --------------------------------------------------
    self.transition_variable = 1.
    self.transition_iters = train_cfg.transition_iters
    self.is_transitioning = False
    self.transition_step = 0
    self.start_channel_size = model_cfg.start_channel_size
    self.latest_switch = 0
    self.opt_level = train_cfg.amp_opt_level
    self.start_time = time.time()

    # --- Models, running-average generator, loss --------------------------
    self.discriminator, self.generator = init_model(
        self.pose_size,
        model_cfg.start_channel_size,
        self.image_channels,
        self.discriminator_model,
    )
    self.init_running_average_generator()
    self.criterion = loss.WGANLoss(self.discriminator, self.generator, self.opt_level)

    # Resume if a checkpoint loads; otherwise extend the fresh models once.
    if not self.load_checkpoint():
        print("Could not load checkpoint, so extending the models")
        self.extend_models()
    self.init_optimizers()

    # Batch size follows the schedule for whatever image size we ended at.
    self.batch_size = self.batch_size_schedule[self.current_imsize]
    self.update_running_average_beta()
    self.logger.log_variable("stats/batch_size", self.batch_size)

    # Cadences and next trigger points. The trigger points start at the
    # global_step value as it stands after load_checkpoint().
    log_cfg = config.logging
    self.num_ims_per_log = log_cfg.num_ims_per_log
    self.next_log_point = self.global_step
    self.num_ims_per_save_image = log_cfg.num_ims_per_save_image
    self.next_image_save_point = self.global_step
    self.num_ims_per_checkpoint = log_cfg.num_ims_per_checkpoint
    self.next_validation_checkpoint = self.global_step

    self.dataloader_train, self.dataloader_val = load_dataset(
        self.dataset,
        self.batch_size,
        self.current_imsize,
        self.full_validation,
        self.pose_size,
        self.load_fraction_of_dataset,
    )
    # Fixed random tensor of shape (8, 32, 4, 4), moved with to_cuda().
    self.static_z = to_cuda(torch.randn((8, 32, 4, 4)))
    self.num_skipped_steps = 0
def __init__(self, config):
    """Set up hyperparameters, models, loss, optimizers and data loaders.

    Args:
        config: Parsed configuration object, read through attribute access.
            Sub-sections used are ``train_config``, ``models`` and
            ``logging``. Top-level fields used include ``dataset``,
            ``max_imsize``, ``checkpoint_dir``, ``config_path``,
            ``summaries_dir``, ``generated_data_dir``,
            ``use_full_validation`` and ``load_fraction_of_dataset``.

    The image size starts at 4 and the target size is ``config.max_imsize``.
    ``load_checkpoint()`` is tried first. If it returns a falsy value,
    ``extend_models()`` is called once (silently). Batch size, logging
    trigger points and dataloaders are derived from the resulting state.
    """
    train_cfg = config.train_config
    model_cfg = config.models

    # --- Hyperparameters ---------------------------------------------------
    # Mapping from image size to batch size (looked up with current_imsize).
    self.batch_size_schedule = train_cfg.batch_size_schedule
    self.dataset = config.dataset
    self.learning_rate = train_cfg.learning_rate
    self.running_average_generator_decay = model_cfg.generator.running_average_decay
    # Forwarded to init_model() and load_dataset().
    self.pose_size = model_cfg.pose_size
    # Discriminator structure selector passed to init_model().
    self.discriminator_model = model_cfg.discriminator.structure
    # Forwarded to load_dataset() as its full-validation flag.
    self.full_validation = config.use_full_validation
    # Forwarded to load_dataset() as its load-fraction argument.
    self.load_fraction_of_dataset = config.load_fraction_of_dataset

    # --- Image settings ----------------------------------------------------
    self.current_imsize = 4
    self.image_channels = 3
    self.max_imsize = config.max_imsize

    # --- Logging / bookkeeping --------------------------------------------
    self.checkpoint_dir = config.checkpoint_dir
    # Second-to-last "/"-separated component of checkpoint_dir.
    self.model_name = self.checkpoint_dir.split("/")[-2]
    self.config_path = config.config_path
    self.global_step = 0
    self.logger = logger.Logger(config.summaries_dir, config.generated_data_dir)

    # --- Transition state --------------------------------------------------
    self.transition_variable = 1.
    self.transition_iters = train_cfg.transition_iters
    self.is_transitioning = False
    self.transition_step = 0
    self.start_channel_size = model_cfg.start_channel_size
    self.latest_switch = 0
    self.opt_level = train_cfg.amp_opt_level
    self.start_time = time.time()

    # --- Models, running-average generator, loss --------------------------
    self.discriminator, self.generator = init_model(
        self.pose_size,
        model_cfg.start_channel_size,
        self.image_channels,
        self.discriminator_model,
    )
    self.init_running_average_generator()
    self.criterion = loss.WGANLoss(self.discriminator, self.generator, self.opt_level)

    # Resume if a checkpoint loads; otherwise extend the fresh models once.
    if not self.load_checkpoint():
        self.extend_models()
    self.init_optimizers()

    # Batch size follows the schedule for whatever image size we ended at.
    self.batch_size = self.batch_size_schedule[self.current_imsize]
    self.update_running_average_beta()
    self.logger.log_variable("stats/batch_size", self.batch_size)

    # Cadences and next trigger points. The trigger points start at the
    # global_step value as it stands after load_checkpoint().
    log_cfg = config.logging
    self.num_ims_per_log = log_cfg.num_ims_per_log
    self.next_log_point = self.global_step
    self.num_ims_per_save_image = log_cfg.num_ims_per_save_image
    self.next_image_save_point = self.global_step
    self.num_ims_per_checkpoint = log_cfg.num_ims_per_checkpoint
    self.next_validation_checkpoint = self.global_step

    self.dataloader_train, self.dataloader_val = load_dataset(
        self.dataset,
        self.batch_size,
        self.current_imsize,
        self.full_validation,
        self.pose_size,
        self.load_fraction_of_dataset,
    )
    # Fixed random tensor of shape (8, 32, 4, 4), moved with to_cuda().
    self.static_z = to_cuda(torch.randn((8, 32, 4, 4)))
    self.num_skipped_steps = 0
save_path = default_path model_name = config.config_path.split("/")[-2] ckpt = load_checkpoint(os.path.join("validation_checkpoints", model_name)) #ckpt = load_checkpoint(os.path.join( # os.path.dirname(config.config_path), # "checkpoints")) generator = init_generator(config, ckpt) imsize = ckpt["current_imsize"] pose_size = config.models.pose_size return generator, imsize, save_path, pose_size generator, imsize, save_path, pose_size = read_args() batch_size = 128 dataloader_train, dataloader_val = load_dataset("fdf", batch_size, 128, True, pose_size, True ) dataloader_val.update_next_transition_variable(1.0) fake_images = np.zeros((len(dataloader_val)*batch_size, imsize, imsize, 3), dtype=np.uint8) real_images = np.zeros((len(dataloader_val)*batch_size, imsize, imsize, 3), dtype=np.uint8) z = generator.generate_latent_variable(batch_size, "cuda", torch.float32).zero_() with torch.no_grad(): for idx, (real_data, condition, landmarks) in enumerate(tqdm.tqdm(dataloader_val)): fake_data = generator(condition, landmarks, z.clone()) fake_data = torch_utils.image_to_numpy(fake_data, to_uint8=True, denormalize=True) real_data = torch_utils.image_to_numpy(real_data, to_uint8=True, denormalize=True) start_idx = idx * batch_size end_idx = (idx+1) * batch_size
# Debug script comparing two loader configurations:
#   1. One validation batch at ``start_imsize`` with the transition variable
#      set to 1.0, converted to uint8 numpy images.
#   2. Validation data at ``start_imsize * 2`` with the transition variable
#      set to 0.0, average-pooled by 2.
# NOTE(review): the script looks unfinished as visible here. ``plt`` is
# imported but unused, ``to_save1`` is built but never used, and the second
# loop has no ``break`` and discards its result.
import matplotlib.pyplot as plt
import numpy as np
import torch
from deep_privacy.data_tools.dataloaders import load_dataset
from deep_privacy import torch_utils
from deep_privacy.data_tools import data_utils
start_imsize = 8
batch_size = 32
# Positional args: dataset, batch_size, imsize, full_validation=False,
# pose_size=14, load_fraction=True.
dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize, False, 14, True)
dl = dl_val
dl.update_next_transition_variable(1.0)
# Fetches and discards one batch.
# NOTE(review): presumably needed so the transition-variable update takes
# effect before the loop below; confirm.
next(iter(dl))
# Only the first batch is used: denormalize it and convert it to uint8
# numpy images.
for im, condition, landmark in dl:
    im = data_utils.denormalize_img(im)
    im = torch_utils.image_to_numpy(im, to_uint8=True)
    to_save1 = im
    break
# Join the batch's images along axis 1 of each image.
# Assumes image_to_numpy returns (N, H, W, C), which makes this a wide
# horizontal strip; TODO confirm.
to_save1 = np.concatenate(to_save1, axis=1)
# Same loader, but at twice the image size and transition variable 0.0.
dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize * 2, False, 14, True)
dl = dl_val
dl.update_next_transition_variable(0.0)
next(iter(dl))
# Halve each batch spatially with 2x2 average pooling (presumably back to
# start_imsize).
# NOTE(review): without a ``break`` this walks the whole validation set.
for im, condition, landmark in dl:
    im = torch.nn.functional.avg_pool2d(im, 2)
# Debug script: restore the generator from the imsize-64 transition
# checkpoint of models/minibatch_std, run it on one batch of conditions and
# landmarks from the "yfcc100m128" dataset, and save the denormalized
# outputs to .debug/test.jpg.
import deep_privacy.config_parser as config_parser
import torch
import os
import torchvision
from deep_privacy.models.unet_model import init_model
from deep_privacy.data_tools.dataloaders import load_dataset
from deep_privacy.data_tools.data_utils import denormalize_img

dl_train, _ = load_dataset("yfcc100m128", batch_size=64, imsize=64,
                           full_validation=False, pose_size=14,
                           load_fraction=True)
config = config_parser.load_config("models/minibatch_std/config.yml")
ckpt = torch.load("models/minibatch_std/transition_checkpoints/imsize64.ckpt")
discriminator, generator = init_model(config.models.pose_size,
                                      config.models.start_channel_size,
                                      config.models.image_channels,
                                      config.models.discriminator.structure)
# ckpt["G"] holds the generator's state dict.
generator.load_state_dict(ckpt["G"])
generator.cuda()
print(generator.network.current_imsize)

dl_train.update_next_transition_variable(1.0)
ims, conditions, landmarks = next(iter(dl_train))
# Inference only: the generated batch is just written to disk, so do not
# build an autograd graph (saves activation memory for the whole batch).
with torch.no_grad():
    fakes = denormalize_img(generator(conditions, landmarks))
os.makedirs(".debug", exist_ok=True)
torchvision.utils.save_image(fakes, ".debug/test.jpg")
# Extend