コード例 #1
0
ファイル: train.py プロジェクト: marilynzhang/DeepPrivacy
    def transition_model(self):
        """Move the training schedule to its next phase.

        ``latest_switch`` is always advanced by ``transition_iters``. After
        that exactly one of three things happens:

        * a transition is running: it is stopped, the transition value is
          updated and a checkpoint is saved;
        * no transition is running and ``current_imsize`` is below
          ``max_imsize``: a transition checkpoint and image are saved, the
          models are extended, both dataloaders are rebuilt, a new transition
          is started and the optimizers are re-initialised;
        * otherwise: nothing further is done.
        """
        self.latest_switch += self.transition_iters

        if self.is_transitioning:
            # A transition is in progress: close it and persist the state.
            self.is_transitioning = False
            self.update_transition_value()
            print(f"Stopping transition. Global step: {self.global_step}, transition_variable: {self.transition_variable}, Current imsize: {self.current_imsize}")
            self.save_checkpoint()
            return

        if not (self.current_imsize < self.max_imsize):
            # The maximum image size has been reached; there is nothing to grow.
            return

        # Record checkpoint and image before the models are extended.
        self.save_transition_checkpoint()
        self.save_transition_image(True)
        self.extend_models()

        # Release the old loaders before their replacements are created.
        del self.dataloader_train, self.dataloader_val
        new_loaders = load_dataset(
            self.dataset,
            self.batch_size,
            self.current_imsize,
            self.full_validation,
            self.pose_size,
            self.load_fraction_of_dataset,
        )
        self.dataloader_train, self.dataloader_val = new_loaders

        self.is_transitioning = True
        print(f"Start transition. Global step: {self.global_step}, transition_variable: {self.transition_variable}, Current imsize: {self.current_imsize}")

        self.init_optimizers()
        self.update_transition_value()
        print(f"New transition value: {self.transition_variable}")

        # Record the image again now that the transition has started.
        self.save_transition_image(False)
コード例 #2
0
ファイル: worker05.py プロジェクト: garima0106/GAN
    def __init__(self, config):
        """Set up the trainer from ``config``.

        Copies hyperparameters and image, logging and transition settings
        from ``config``, creates the logger, the discriminator/generator
        pair and the WGAN loss, attempts to load a checkpoint (extending the
        models and initialising the optimizers when that fails), and finally
        builds the train and validation dataloaders.

        Args:
            config: Configuration object. Attributes are read from its top
                level and from its ``train_config``, ``models`` and
                ``logging`` sections.
        """
        train_cfg = config.train_config

        # ---- Hyperparameters ----
        self.batch_size_schedule = train_cfg.batch_size_schedule
        self.dataset = config.dataset
        self.learning_rate = train_cfg.learning_rate
        model_cfg = config.models
        self.running_average_generator_decay = (
            model_cfg.generator.running_average_decay
        )
        self.pose_size = model_cfg.pose_size
        self.discriminator_model = model_cfg.discriminator.structure
        self.full_validation = config.use_full_validation
        self.load_fraction_of_dataset = config.load_fraction_of_dataset

        # ---- Image settings ----
        self.current_imsize, self.image_channels = 4, 3
        self.max_imsize = config.max_imsize

        # ---- Logging ----
        ckpt_dir = config.checkpoint_dir
        self.checkpoint_dir = ckpt_dir
        # The model name is the second-to-last "/"-separated path component.
        self.model_name = ckpt_dir.split("/")[-2]
        self.config_path = config.config_path
        self.global_step = 0
        self.logger = logger.Logger(
            config.summaries_dir, config.generated_data_dir
        )

        # ---- Transition settings ----
        self.transition_variable = 1.
        self.transition_iters = train_cfg.transition_iters
        self.is_transitioning = False
        self.transition_step = 0
        self.start_channel_size = model_cfg.start_channel_size
        self.latest_switch = 0
        self.opt_level = train_cfg.amp_opt_level
        self.start_time = time.time()

        # ---- Networks and loss ----
        networks = init_model(
            self.pose_size,
            model_cfg.start_channel_size,
            self.image_channels,
            self.discriminator_model,
        )
        self.discriminator, self.generator = networks
        self.init_running_average_generator()
        self.criterion = loss.WGANLoss(
            self.discriminator, self.generator, self.opt_level
        )

        # With no checkpoint to resume from, extend the models and create
        # the optimizers here instead.
        checkpoint_loaded = self.load_checkpoint()
        if not checkpoint_loaded:
            print("Could not load checkpoint, so extending the models")
            self.extend_models()
            self.init_optimizers()

        # The batch size is looked up by the current image size.
        self.batch_size = self.batch_size_schedule[self.current_imsize]
        self.update_running_average_beta()
        self.logger.log_variable("stats/batch_size", self.batch_size)

        # ---- Periodic logging / saving / checkpointing ----
        log_cfg = config.logging
        self.num_ims_per_log = log_cfg.num_ims_per_log
        self.next_log_point = self.global_step
        self.num_ims_per_save_image = log_cfg.num_ims_per_save_image
        self.next_image_save_point = self.global_step
        self.num_ims_per_checkpoint = log_cfg.num_ims_per_checkpoint
        self.next_validation_checkpoint = self.global_step

        # ---- Data ----
        loaders = load_dataset(
            self.dataset,
            self.batch_size,
            self.current_imsize,
            self.full_validation,
            self.pose_size,
            self.load_fraction_of_dataset,
        )
        self.dataloader_train, self.dataloader_val = loaders
        # Random tensor of shape (8, 32, 4, 4), drawn once and kept.
        self.static_z = to_cuda(torch.randn((8, 32, 4, 4)))
        self.num_skipped_steps = 0
コード例 #3
0
ファイル: train.py プロジェクト: marilynzhang/DeepPrivacy
    def __init__(self, config):
        """Initialise the trainer from ``config``.

        Copies settings from ``config`` onto the instance, creates the
        logger, the discriminator/generator pair and the WGAN loss, tries to
        load a checkpoint (extending the models and initialising the
        optimizers when that fails), and builds the train and validation
        dataloaders.

        Args:
            config: Configuration object. Attributes are read from its top
                level and from its ``train_config``, ``models`` and
                ``logging`` sections.
        """
        # Set Hyperparameters

        # Batch size per image size; indexed with current_imsize further down.
        self.batch_size_schedule = config.train_config.batch_size_schedule
        self.dataset = config.dataset
        self.learning_rate = config.train_config.learning_rate
        self.running_average_generator_decay = config.models.generator.running_average_decay
        # Passed on to init_model and load_dataset below.
        # NOTE(review): the original author notes this is consumed by the
        # ProgressiveBaseModel shared by G and D; its exact meaning is not
        # visible here - TODO document.
        self.pose_size = config.models.pose_size
        # Selects the discriminator structure; passed on to init_model.
        # NOTE(review): original note mentions "normal" vs. "deep" - TODO
        # document the difference.
        self.discriminator_model = config.models.discriminator.structure
        # NOTE(review): original note says the default is False - confirm
        # against the config defaults.
        self.full_validation = config.use_full_validation
        # Passed on to load_dataset; presumably restricts loading to a part
        # of the dataset (e.g. for quick tests) - TODO confirm.
        self.load_fraction_of_dataset = config.load_fraction_of_dataset

        # Image settings
        # Training starts at an image size of 4.
        self.current_imsize = 4
        self.image_channels = 3
        # Upper bound for the image size.
        # NOTE(review): original note says DeepPrivacy usually uses 128.
        self.max_imsize = config.max_imsize

        # Logging variables
        self.checkpoint_dir = config.checkpoint_dir
        # Model name = second-to-last "/"-separated component of the
        # checkpoint directory.
        self.model_name = self.checkpoint_dir.split("/")[-2]
        self.config_path = config.config_path
        self.global_step = 0
        self.logger = logger.Logger(config.summaries_dir,
                                    config.generated_data_dir)
        # Transition settings
        self.transition_variable = 1.
        self.transition_iters = config.train_config.transition_iters
        self.is_transitioning = False
        self.transition_step = 0
        self.start_channel_size = config.models.start_channel_size
        self.latest_switch = 0
        self.opt_level = config.train_config.amp_opt_level
        self.start_time = time.time()
        # Build the discriminator/generator pair.
        self.discriminator, self.generator = init_model(self.pose_size,
                                                        config.models.start_channel_size,
                                                        self.image_channels,
                                                        self.discriminator_model)
        self.init_running_average_generator()
        self.criterion = loss.WGANLoss(self.discriminator,
                                       self.generator,
                                       self.opt_level)
        # When load_checkpoint() reports failure, extend the models and
        # create the optimizers here instead.
        if not self.load_checkpoint():
            self.extend_models()
            self.init_optimizers()

        # Batch size is looked up by the current image size.
        self.batch_size = self.batch_size_schedule[self.current_imsize]
        self.update_running_average_beta()
        self.logger.log_variable("stats/batch_size", self.batch_size)

        # Periodic logging / image saving / checkpointing: each "next" point
        # starts at the current global step.
        self.num_ims_per_log = config.logging.num_ims_per_log
        self.next_log_point = self.global_step
        self.num_ims_per_save_image = config.logging.num_ims_per_save_image
        self.next_image_save_point = self.global_step
        self.num_ims_per_checkpoint = config.logging.num_ims_per_checkpoint
        self.next_validation_checkpoint = self.global_step

        self.dataloader_train, self.dataloader_val = load_dataset(
            self.dataset, self.batch_size, self.current_imsize, self.full_validation, self.pose_size, self.load_fraction_of_dataset)
        # Random tensor of shape (8, 32, 4, 4), drawn once and stored.
        self.static_z = to_cuda(torch.randn((8, 32, 4, 4)))
        self.num_skipped_steps = 0
コード例 #4
0
        save_path = default_path
    model_name = config.config_path.split("/")[-2]
    ckpt = load_checkpoint(os.path.join("validation_checkpoints", model_name))
    #ckpt = load_checkpoint(os.path.join(
    #                                    os.path.dirname(config.config_path),
    #                                    "checkpoints"))
    generator = init_generator(config, ckpt)
    imsize = ckpt["current_imsize"]
    pose_size = config.models.pose_size
    return generator, imsize, save_path, pose_size


# Script body: run the generator over the validation set and convert both the
# generated and the real batches to uint8 numpy images.
# read_args() supplies the generator, the checkpoint's image size, the save
# path and the pose size.
generator, imsize, save_path, pose_size = read_args()

batch_size = 128
# NOTE(review): the dataset is loaded with a fixed image size of 128 while
# the buffers below are sized with `imsize` from the checkpoint - confirm
# that the two always agree.
dataloader_train, dataloader_val = load_dataset("fdf", batch_size, 128, True, pose_size, True )
# Presumably sets the transition variable the loader uses for its next
# batches - TODO confirm.
dataloader_val.update_next_transition_variable(1.0)
# Preallocated uint8 buffers with len(dataloader_val) * batch_size entries of
# imsize x imsize x 3 each, one for generated and one for real images.
fake_images = np.zeros((len(dataloader_val)*batch_size, imsize, imsize, 3),
                       dtype=np.uint8)
real_images = np.zeros((len(dataloader_val)*batch_size, imsize, imsize, 3),
                       dtype=np.uint8)
# The latent variable is zeroed in place, so every batch is generated from the
# same all-zero latent (assumes generate_latent_variable returns a torch
# tensor - TODO confirm).
z = generator.generate_latent_variable(batch_size, "cuda", torch.float32).zero_()
# Gradients are not needed here.
with torch.no_grad():
    for idx, (real_data, condition, landmarks) in enumerate(tqdm.tqdm(dataloader_val)):

        # A fresh copy of z is passed for every batch.
        fake_data = generator(condition, landmarks, z.clone())
        fake_data = torch_utils.image_to_numpy(fake_data, to_uint8=True, denormalize=True)
        real_data = torch_utils.image_to_numpy(real_data, to_uint8=True, denormalize=True)
        # Index range [start_idx, end_idx) for this batch; presumably the
        # slice of the output buffers it fills (that use is not visible here).
        start_idx = idx * batch_size
        end_idx = (idx+1) * batch_size
コード例 #5
0
import matplotlib.pyplot as plt
import numpy as np
import torch
from deep_privacy.data_tools.dataloaders import load_dataset
from deep_privacy import torch_utils
from deep_privacy.data_tools import data_utils

# Script body: take one validation batch at start_imsize and one at twice
# that size from the "fdf" dataset.
start_imsize = 8
batch_size = 32

dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize, False, 14,
                                True)

dl = dl_val
# Presumably sets the transition variable the loader uses for its next
# batches - TODO confirm.
dl.update_next_transition_variable(1.0)
# One batch is fetched and discarded.
# NOTE(review): presumably so the new transition variable is in effect for
# the batch read below - confirm.
next(iter(dl))

# Only the first batch is used (the loop breaks immediately).
for im, condition, landmark in dl:
    im = data_utils.denormalize_img(im)
    im = torch_utils.image_to_numpy(im, to_uint8=True)
    to_save1 = im
    break
# Join the entries of the batch along axis 1 of each entry (assumes
# image_to_numpy returns (N, H, W, C), which would stack the images along
# their height - TODO confirm).
to_save1 = np.concatenate(to_save1, axis=1)
# Reload the dataset at twice the image size.
dl_train, dl_val = load_dataset("fdf", batch_size, start_imsize * 2, False, 14,
                                True)
dl = dl_val
dl.update_next_transition_variable(0.0)
# One batch is fetched and discarded, as above.
next(iter(dl))

for im, condition, landmark in dl:
    # 2x2 average pooling halves the spatial size, bringing the batch back to
    # start_imsize (presumably for comparison with the first batch).
    im = torch.nn.functional.avg_pool2d(im, 2)
コード例 #6
0
import deep_privacy.config_parser as config_parser
import torch
import os
import torchvision
from deep_privacy.models.unet_model import init_model
from deep_privacy.data_tools.dataloaders import load_dataset
from deep_privacy.data_tools.data_utils import denormalize_img

# Script body: load a generator from a saved transition checkpoint, run it on
# one training batch and write the result to .debug/test.jpg.
# Only the training loader is kept; the second return value is discarded.
dl_train, _ = load_dataset("yfcc100m128",
                           batch_size=64,
                           imsize=64,
                           full_validation=False,
                           pose_size=14,
                           load_fraction=True)
config = config_parser.load_config("models/minibatch_std/config.yml")
ckpt = torch.load("models/minibatch_std/transition_checkpoints/imsize64.ckpt")
# Build both networks from the config; only the generator is used below.
discriminator, generator = init_model(config.models.pose_size,
                                      config.models.start_channel_size,
                                      config.models.image_channels,
                                      config.models.discriminator.structure)
# Restore the generator weights stored under the "G" key of the checkpoint.
generator.load_state_dict(ckpt["G"])
generator.cuda()
print(generator.network.current_imsize)
# Presumably sets the transition variable the loader uses for its next
# batches - TODO confirm.
dl_train.update_next_transition_variable(1.0)
ims, conditions, landmarks = next(iter(dl_train))

# NOTE(review): the generator was moved to CUDA above; this assumes the
# loader already yields CUDA tensors - confirm.
fakes = denormalize_img(generator(conditions, landmarks))
os.makedirs(".debug", exist_ok=True)
torchvision.utils.save_image(fakes, ".debug/test.jpg")

# Extend