Example #1
    def __init__(self, config):
        super(WGAN, self).__init__()
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        # get logger and config
        self.logger = get_logger("VAE_Model")
        self.config = config
        set_random_state(self.config)

        self.logger.info("WGAN_GradientPenalty init model.")
        if self.config["model_type"] == "face": self.C = 3
        if self.config["model_type"] == "sketch": self.C = 1

        self.netG = WGAN_Generator(self.C)
        self.netD = WGAN_Discriminator(input_channels=self.C)

        # WGAN values from paper
        self.b1 = 0.5
        self.b2 = 0.999

        self.learning_rate = config["learning_rate"]
        self.batch_size = self.config["batch_size"]

        # WGAN with gradient penalty uses the Adam optimizer
        self.d_optimizer = optim.Adam(self.netD.parameters(),
                                      lr=self.learning_rate,
                                      betas=(self.b1, self.b2))
        self.g_optimizer = optim.Adam(self.netG.parameters(),
                                      lr=self.learning_rate,
                                      betas=(self.b1, self.b2))

        self.generator_iters = self.config["num_steps"]
        self.critic_iter = 5
        self.lambda_term = 10
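
For orientation, here is a minimal sketch of a config that satisfies the lookups in this constructor (assuming the class is named WGAN, per the super() call); all values are illustrative assumptions, and set_random_state may read further keys not shown here.

    # Hypothetical config: keys mirror the lookups above, values are made up.
    config = {
        "model_type": "sketch",   # selects self.C = 1 ("face" would give 3)
        "learning_rate": 1e-4,
        "batch_size": 64,
        "num_steps": 10000,       # consumed as self.generator_iters
        "debug_log_level": False,
    }
    model = WGAN(config)
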
Example #2
    def __init__(self, config, train=False):
        """Initialize the dataset to load training or validation images according to the config.yaml file. 
        
        :param DatasetMixin: This class inherits from this class to enable a good workflow through the framework edflow.  
        :param config: This config is loaded from the config.yaml file which specifies all neccesary hyperparameter for to desired operation which will be executed by the edflow framework.
        """
        # Create Logging for the Dataset
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        self.logger = get_logger("Dataset")
        self.config = config
        self.data_types = self.setup_data_types()

        self.data_roots = self.get_data_roots()
        # Load parameters from config
        self.set_image_transforms()
        self.set_random_state()
        self.no_encoder = self.config["model"] in [
            "model.gan.DCGAN", "model.gan.WGAN"
        ]

        # Still a bit sloppy, but good enough for now
        self.sketch_data = self.load_sketch_data()

        self.indices = self.load_indices(train)
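
A hedged usage sketch, assuming the class is the edflow dataset whose logger is created above (the class name Dataset is taken from the logger); the helper methods (get_data_roots, set_image_transforms, load_sketch_data, ...) read any further keys they need from the same config, so only the directly visible keys are shown.

    # Illustrative only; "model" controls the no_encoder flag above.
    config = {"model": "model.gan.WGAN", "debug_log_level": False}
    train_data = Dataset(config, train=True)    # training split
    val_data = Dataset(config, train=False)     # validation split
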
Example #3
 def __init__(self, config):
     super(VAE_GAN, self).__init__()
     self.config = config
     if "debug_log_level" in config and config["debug_log_level"]:
         LogSingleton.set_log_level("debug")
     self.logger = get_logger("VAE_GAN")
     assert bool("sketch" in self.config["model_type"]) != bool("face" in self.config["model_type"]), "The model_type for this VAE GAN model can only be 'sketch' or 'face' but not 'sketch2face'."
     assert config["iterator"] == "iterator.vae_gan.VAE_GAN", "This model supports only the VAE_GAN iterator."
     set_random_state(self.config)
     self.sigma = self.config["variational"]["sigma"] if "variational" in self.config and "sigma" in self.config["variational"] else False
     sketch = True if "sketch" in self.config["model_type"] else False
     self.netG = VAE_config(self.config)
     self.netD = Discriminator_sketch() if sketch else Discriminator_face()
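
An illustrative config that passes both asserts above; the values are assumptions, not defaults from the source.

    # Hypothetical config: "model_type" must contain exactly one of "sketch"/"face".
    config = {
        "model_type": "sketch",
        "iterator": "iterator.vae_gan.VAE_GAN",
        "variational": {"sigma": True},  # optional; self.sigma falls back to False
    }
    model = VAE_GAN(config)
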
Example #4
 def __init__(self, config):
     super(VAE_WGAN, self).__init__()
     if "debug_log_level" in config and config["debug_log_level"]:
         LogSingleton.set_log_level("debug")
     # get logger and config
     self.logger = get_logger("CycleWGAN")
     self.config = config
     set_random_state(self.config)
     assert bool("sketch" in self.config["model_type"]) != bool("face" in self.config["model_type"]), "The model_type for this VAE GAN model can only be 'sketch' or 'face' but not 'sketch2face'."
     assert config["iterator"] == "iterator.vae_wgan.VAE_WGAN", "This model supports only the VAE_WGAN iterator."
     self.logger.info("VAE WGAN init model.")
     self.sigma = self.config.get('variational', {}).get('sigma', False)
     
     self.netG = VAE_config(self.config)
     self.netD = (WGAN_Discriminator_sketch()
                  if "sketch" in self.config["model_type"] else
                  WGAN_Discriminator_face(
                      input_resolution=config["data"]["transform"]["resolution"]))
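
Again a hedged sketch; the face branch additionally requires the data resolution, and all values are illustrative.

    # Hypothetical config for the "face" variant.
    config = {
        "model_type": "face",
        "iterator": "iterator.vae_wgan.VAE_WGAN",
        "variational": {"sigma": True},
        "data": {"transform": {"resolution": 64}},  # read by the face discriminator
    }
    model = VAE_WGAN(config)
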
Example #5
    def __init__(self, config, train=False):
        """Initialize the dataset to load training or validation images according to the config.yaml file. 
        
        :param DatasetMixin: This class inherits from this class to enable a good workflow through the framework edflow.  
        :param config: This config is loaded from the config.yaml file which specifies all neccesary hyperparameter for to desired operation which will be executed by the edflow framework.
        """
        # Create Logging for the Dataset
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        self.logger = get_logger("Dataset")
        self.config = config

        self.style_root = self.config["data"]["style_path"]
        self.content_root = self.config["data"]["content_path"]

        self.set_image_transform()
        self.set_random_state()
        self.indices = self.load_content_indices(train)
        self.art_list = self.load_art_list()
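
A usage sketch under the same assumptions as before (the class name Dataset is taken from the logger); the paths are placeholders, and further keys are consumed by set_image_transform and set_random_state.

    # Illustrative config; both paths are placeholders.
    config = {
        "data": {
            "style_path": "/path/to/style/images",
            "content_path": "/path/to/content/images",
        },
    }
    dataset = Dataset(config, train=True)
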
Example #6
    def __init__(self, config):
        super(DCGAN, self).__init__()
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        self.config = config
        self.logger = get_logger("DCGAN")
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        assert bool("sketch" in self.config["model_type"]) != bool(
            "face" in self.config["model_type"]
        ), "The model_type for this DCGAN model can only be 'sketch' or 'face' but not 'sketch2face'."
        self.sketch = "sketch" in self.config["model_type"]
        self.wasserstein = self.config["losses"]["adversarial_loss"] == "wasserstein"

        latent_dim = self.config['latent_dim']
        min_channels = self.config['conv']['n_channel_start']
        max_channels = self.config['conv']['n_channel_max']
        sketch_shape = [32, 1]
        face_shape = [self.config['data']['transform']['resolution'], 3]
        num_extra_conv_sketch = self.config['conv']['sketch_extra_conv']
        num_extra_conv_face = self.config['conv']['face_extra_conv']
        block_activation = nn.ReLU()
        final_activation = nn.Tanh()
        batch_norm_dec = self.config['batch_norm']
        drop_rate_dec = self.config['dropout']['dec_rate']
        drop_rate_disc = self.config['dropout']['disc_rate']
        bias_dec = self.config['bias']['dec']

        shapes = sketch_shape if self.sketch else face_shape
        num_extra_conv = (num_extra_conv_sketch
                          if self.sketch else num_extra_conv_face)
        self.netG = VAE_Decoder(latent_dim, min_channels, max_channels,
                                *shapes, num_extra_conv, block_activation,
                                final_activation, batch_norm_dec,
                                drop_rate_dec, bias_dec)
        self.netD = Discriminator_sketch(
            droprate=drop_rate_disc, wasserstein=self.wasserstein
        ) if self.sketch else Discriminator_face(droprate=drop_rate_disc,
                                                 wasserstein=self.wasserstein)
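
A minimal config sketch covering every key this constructor reads; the numeric values are illustrative assumptions only.

    # Hypothetical config for DCGAN; keys mirror the lookups above.
    config = {
        "model_type": "sketch",
        "losses": {"adversarial_loss": "wasserstein"},
        "latent_dim": 128,
        "conv": {"n_channel_start": 32, "n_channel_max": 512,
                 "sketch_extra_conv": 0, "face_extra_conv": 0},
        "data": {"transform": {"resolution": 64}},
        "batch_norm": True,
        "dropout": {"dec_rate": 0.0, "disc_rate": 0.0},
        "bias": {"dec": True},
    }
    model = DCGAN(config)
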
Example #7
    def __init__(self, config):
        super(Cycle_WGAN, self).__init__()
        assert config["iterator"] == "iterator.cycle_wgan.Cycle_WGAN", \
            "This model only works with the iterator: 'iterator.cycle_wgan.Cycle_WGAN'."
        assert "sketch" in config["model_type"] and "face" in config["model_type"], \
            "This model only works with model_type: 'sketch2face'."
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        # get logger and config
        self.logger = get_logger("CycleWGAN")
        self.config = config
        set_random_state(self.config)
        self.logger.info("WGAN_GradientPenalty init model.")
        self.output_names = [
            'real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B'
        ]
        self.sigma = self.config['variational']['sigma']
        self.num_latent_layer = self.config.get(
            'variational', {}).get('num_latent_layer', 0)

        latent_dim = self.config["latent_dim"]
        min_channels = self.config['conv']['n_channel_start']
        max_channels = self.config['conv']['n_channel_max']
        sketch_shape = [32, 1]
        face_shape = [self.config['data']['transform']['resolution'], 3]
        sigma = self.sigma
        num_extra_conv_sketch = self.config['conv']['sketch_extra_conv']
        num_extra_conv_face = self.config['conv']['face_extra_conv']
        block_activation = nn.ReLU()
        final_activation = nn.Tanh()
        batch_norm_enc = batch_norm_dec = self.config['batch_norm']
        drop_rate_enc = self.config.get('dropout', {}).get('enc_rate', 0)
        drop_rate_dec = self.config.get('dropout', {}).get('dec_rate', 0)
        bias_enc = self.config.get('bias', {}).get('enc', True)
        bias_dec = self.config.get('bias', {}).get('dec', True)
        num_latent_layer = self.num_latent_layer
        ## cycle A ##
        self.netG_A = VAE(latent_dim=latent_dim,
                          min_channels=min_channels,
                          max_channels=max_channels,
                          in_size=sketch_shape[0],
                          in_channels=sketch_shape[1],
                          out_size=face_shape[0],
                          out_channels=face_shape[1],
                          sigma=sigma,
                          num_extra_conv_enc=num_extra_conv_sketch,
                          num_extra_conv_dec=num_extra_conv_face,
                          block_activation=block_activation,
                          final_activation=final_activation,
                          batch_norm_enc=batch_norm_enc,
                          batch_norm_dec=batch_norm_dec,
                          drop_rate_enc=drop_rate_enc,
                          drop_rate_dec=drop_rate_dec,
                          bias_enc=bias_enc,
                          bias_dec=bias_dec,
                          same_max_channels=False,
                          num_latent_layer=num_latent_layer)
        self.netD_A = WGAN_Discriminator_sketch()
        ## cycle B ##
        self.netG_B = VAE(latent_dim=latent_dim,
                          min_channels=min_channels,
                          max_channels=max_channels,
                          in_size=face_shape[0],
                          in_channels=face_shape[1],
                          out_size=sketch_shape[0],
                          out_channels=sketch_shape[1],
                          sigma=sigma,
                          # cycle B maps face -> sketch, so the extra convs
                          # are swapped relative to cycle A
                          num_extra_conv_enc=num_extra_conv_face,
                          num_extra_conv_dec=num_extra_conv_sketch,
                          block_activation=block_activation,
                          final_activation=final_activation,
                          batch_norm_enc=batch_norm_enc,
                          batch_norm_dec=batch_norm_dec,
                          drop_rate_enc=drop_rate_enc,
                          drop_rate_dec=drop_rate_dec,
                          bias_enc=bias_enc,
                          bias_dec=bias_dec,
                          same_max_channels=False,
                          num_latent_layer=num_latent_layer)
        self.netD_B = WGAN_Discriminator_face(input_resolution=face_shape[0])
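
Finally, an illustrative config for the cycle model: it satisfies both asserts, and every key below corresponds to a lookup in the constructor, with made-up values.

    # Hypothetical config for Cycle_WGAN.
    config = {
        "iterator": "iterator.cycle_wgan.Cycle_WGAN",
        "model_type": "sketch2face",   # must mention both "sketch" and "face"
        "latent_dim": 128,
        "conv": {"n_channel_start": 32, "n_channel_max": 512,
                 "sketch_extra_conv": 0, "face_extra_conv": 0},
        "variational": {"sigma": True, "num_latent_layer": 2},
        "batch_norm": True,
        "data": {"transform": {"resolution": 64}},
    }
    model = Cycle_WGAN(config)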