Example #1
    def __init__(self, config):
        super(WGAN, self).__init__()
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        # get logger and config
        self.logger = get_logger("VAE_Model")
        self.config = config
        set_random_state(self.config)

        self.logger.info("WGAN_GradientPenalty init model.")
        if self.config["model_type"] == "face": self.C = 3
        if self.config["model_type"] == "sketch": self.C = 1

        self.netG = WGAN_Generator(self.C)
        self.netD = WGAN_Discriminator(input_channels=self.C)

        # WGAN values from paper
        self.b1 = 0.5
        self.b2 = 0.999

        self.learning_rate = config["learning_rate"]
        self.batch_size = self.config["batch_size"]

        # WGAN_gradient penalty uses ADAM
        self.d_optimizer = optim.Adam(self.netD.parameters(),
                                      lr=self.learning_rate,
                                      betas=(self.b1, self.b2))
        self.g_optimizer = optim.Adam(self.netG.parameters(),
                                      lr=self.learning_rate,
                                      betas=(self.b1, self.b2))

        self.generator_iters = self.config["num_steps"]
        self.critic_iter = 5
        self.lambda_term = 10
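
For orientation, a minimal config sketch for this constructor; only the key names are taken from the snippet above, while the values (and the assumption that a plain dict suffices) are illustrative.

# Illustrative config for the WGAN model above; key names match the constructor,
# values are placeholders.
config = {
    "model_type": "face",      # selects self.C = 3; "sketch" would give self.C = 1
    "learning_rate": 1e-4,     # used by both Adam optimizers
    "batch_size": 64,
    "num_steps": 10000,        # becomes self.generator_iters
    "debug_log_level": False,  # True switches LogSingleton to debug logging
}
model = WGAN(config)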
Example #2
    def __init__(self, config, train=False):
        """Initialize the dataset to load training or validation images according to the config.yaml file. 
        
        :param DatasetMixin: This class inherits from this class to enable a good workflow through the framework edflow.  
        :param config: This config is loaded from the config.yaml file which specifies all neccesary hyperparameter for to desired operation which will be executed by the edflow framework.
        """
        # Create Logging for the Dataset
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        self.logger = get_logger("Dataset")
        self.config = config
        self.data_types = self.setup_data_types()

        self.data_roots = self.get_data_roots()
        # Load parameters from config
        self.set_image_transforms()
        self.set_random_state()
        self.no_encoder = self.config["model"] in [
            "model.gan.DCGAN", "model.gan.WGAN"
        ]

        # Yet a bit sloppy but ok
        self.sketch_data = self.load_sketch_data()

        self.indices = self.load_indices(train)
Example #3
 def __init__(self, config):
     super(VAE_GAN, self).__init__()
     self.config = config
     if "debug_log_level" in config and config["debug_log_level"]:
         LogSingleton.set_log_level("debug")
     self.logger = get_logger("VAE_GAN")
     assert bool("sketch" in self.config["model_type"]) != bool("face" in self.config["model_type"]), "The model_type for this VAE GAN model can only be 'sketch' or 'face' but not 'sketch2face'."
     assert config["iterator"] == "iterator.vae_gan.VAE_GAN", "This model supports only the VAE_GAN iterator."
     set_random_state(self.config)
     self.sigma = self.config["variational"]["sigma"] if "variational" in self.config and "sigma" in self.config["variational"] else False
     sketch = True if "sketch" in self.config["model_type"] else False
     self.netG = VAE_config(self.config)
     self.netD = Discriminator_sketch() if sketch else Discriminator_face()
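
A corresponding config sketch, assuming a plain dict; the key names are those read above, the values are placeholders.

config = {
    "model_type": "sketch",                  # exactly one of "sketch" / "face" (enforced by the first assert)
    "iterator": "iterator.vae_gan.VAE_GAN",  # required by the second assert
    "variational": {"sigma": True},          # optional; self.sigma falls back to False when absent
}
model = VAE_GAN(config)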
Example #4
 def __init__(self, config):
     super(VAE_WGAN, self).__init__()
     if "debug_log_level" in config and config["debug_log_level"]:
         LogSingleton.set_log_level("debug")
     # get logger and config
     self.logger = get_logger("CycleWGAN")
     self.config = config
     set_random_state(self.config)
     assert bool("sketch" in self.config["model_type"]) != bool("face" in self.config["model_type"]), "The model_type for this VAE GAN model can only be 'sketch' or 'face' but not 'sketch2face'."
     assert config["iterator"] == "iterator.vae_wgan.VAE_WGAN", "This model supports only the VAE_WGAN iterator."
     self.logger.info("VAE WGAN init model.")
     self.sigma = self.config['variational']['sigma'] if "variational" in self.config and "sigma" in self.config["variational"] else False
     
     self.netG = VAE_config(self.config)
     self.netD = WGAN_Discriminator_sketch() if "sketch" in self.config["model_type"] else WGAN_Discriminator_face(input_resolution=config["data"]["transform"]["resolution"])
Example #5
    def __init__(self, config, train=False):
        """Initialize the dataset to load training or validation images according to the config.yaml file. 
        
        :param DatasetMixin: This class inherits from this class to enable a good workflow through the framework edflow.  
        :param config: This config is loaded from the config.yaml file which specifies all neccesary hyperparameter for to desired operation which will be executed by the edflow framework.
        """
        # Create Logging for the Dataset
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        self.logger = get_logger("Dataset")
        self.config = config

        self.style_root = self.config["data"]["style_path"]
        self.content_root = self.config["data"]["content_path"]

        self.set_image_transform()
        self.set_random_state()
        self.indices = self.load_content_indices(train)
        self.art_list = self.load_art_list()
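
A minimal config sketch for this style/content dataset; the class name Dataset, the paths, and any keys read by the helper methods not shown here are assumptions.

config = {
    "debug_log_level": False,
    "data": {
        "style_path": "data/art_images",        # hypothetical directory with style/art images
        "content_path": "data/content_images",  # hypothetical directory with content images
    },
    # set_image_transform() and set_random_state() read further keys not visible in this snippet.
}
dataset = Dataset(config, train=True)  # train=True loads the training indices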
Example #6
    def __init__(self, config):
        super(DCGAN, self).__init__()
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        self.config = config
        self.logger = get_logger("DCGAN")
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        assert bool("sketch" in self.config["model_type"]) != bool(
            "face" in self.config["model_type"]
        ), "The model_type for this DCGAN model can only be 'sketch' or 'face' but not 'sketch2face'."
        self.sketch = True if "sketch" in self.config["model_type"] else False
        self.wasserstein = bool(
            self.config["losses"]['adversarial_loss'] == 'wasserstein')

        latent_dim = self.config['latent_dim']
        min_channels = self.config['conv']['n_channel_start']
        max_channels = self.config['conv']['n_channel_max']
        sketch_shape = [32, 1]
        face_shape = [self.config['data']['transform']['resolution'], 3]
        num_extra_conv_sketch = self.config['conv']['sketch_extra_conv']
        num_extra_conv_face = self.config['conv']['face_extra_conv']
        block_activation = nn.ReLU()
        final_activation = nn.Tanh()
        batch_norm_dec = self.config['batch_norm']
        drop_rate_dec = self.config['dropout']['dec_rate']
        drop_rate_disc = self.config['dropout']['disc_rate']
        bias_dec = self.config['bias']['dec']

        shapes = sketch_shape if "sketch" in self.config[
            "model_type"] else face_shape
        num_extra_conv = num_extra_conv_sketch if "sketch" in self.config[
            "model_type"] else num_extra_conv_face
        self.netG = VAE_Decoder(latent_dim, min_channels, max_channels,
                                *shapes, num_extra_conv, block_activation,
                                final_activation, batch_norm_dec,
                                drop_rate_dec, bias_dec)
        self.netD = Discriminator_sketch(
            droprate=drop_rate_disc, wasserstein=self.wasserstein
        ) if self.sketch else Discriminator_face(droprate=drop_rate_disc,
                                                 wasserstein=self.wasserstein)
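
To make the nested key layout explicit, a config sketch covering only the keys this DCGAN constructor reads; all values are illustrative.

config = {
    "model_type": "face",
    "losses": {"adversarial_loss": "wasserstein"},  # any other value disables the Wasserstein flag
    "latent_dim": 128,
    "conv": {"n_channel_start": 32, "n_channel_max": 512,
             "sketch_extra_conv": 0, "face_extra_conv": 1},
    "data": {"transform": {"resolution": 64}},
    "batch_norm": True,
    "dropout": {"dec_rate": 0.0, "disc_rate": 0.3},
    "bias": {"dec": True},
}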
Example #7
def _train(config, root, checkpoint=None, retrain=False):
    """Run training. Loads model, iterator and dataset according to config."""
    from edflow.iterators.batches import make_batches

    LogSingleton().set_default("train")
    logger = get_logger("train")
    logger.info("Starting Training.")

    implementations = get_implementations_from_config(
        config, ["model", "iterator", "dataset"]
    )

    # fork early to avoid taking all the crap into forked processes
    logger.info("Instantiating dataset.")
    dataset = implementations["dataset"](config=config)
    dataset.expand = True
    logger.info("Number of training samples: {}".format(len(dataset)))
    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)
    with make_batches(
        dataset,
        batch_size=config["batch_size"],
        shuffle=True,
        n_processes=n_processes,
        n_prefetch=n_prefetch,
        error_on_timeout=config.get("error_on_timeout", False),
    ) as batches:
        # get them going
        logger.info("Warm up batches.")
        next(batches)
        batches.reset()
        logger.info("Reset batches.")

        if "num_steps" in config:
            # set number of epochs to perform at least num_steps steps
            steps_per_epoch = len(dataset) / config["batch_size"]
            num_epochs = config["num_steps"] / steps_per_epoch
            config["num_epochs"] = math.ceil(num_epochs)
        else:
            steps_per_epoch = len(dataset) / config["batch_size"]
            num_steps = config["num_epochs"] * steps_per_epoch
            config["num_steps"] = math.ceil(num_steps)

        logger.info("Instantiating model.")
        Model = implementations["model"](config)
        if not "hook_freq" in config:
            config["hook_freq"] = 1
        compat_kwargs = dict(
            hook_freq=config["hook_freq"], num_epochs=config["num_epochs"]
        )
        logger.info("Instantiating iterator.")
        Trainer = implementations["iterator"](
            config, root, Model, dataset=dataset, **compat_kwargs
        )

        logger.info("Initializing model.")
        if checkpoint is not None:
            Trainer.initialize(checkpoint_path=checkpoint)
        else:
            Trainer.initialize()

        if retrain:
            Trainer.reset_global_step()

        # save current config
        logger.info("Starting Training with config:\n{}".format(yaml.dump(config)))
        cpath = _save_config(config, prefix="train")
        logger.info("Saved config at {}".format(cpath))

        logger.info("Iterating.")
        Trainer.iterate(batches)
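
The num_steps/num_epochs conversion in the middle of _train is easy to misread, so here is the same arithmetic as a standalone sketch with made-up numbers.

import math

# 10,000 samples, batch_size 50, and a requested 3,000 steps (illustrative values).
len_dataset, batch_size, num_steps = 10_000, 50, 3_000
steps_per_epoch = len_dataset / batch_size           # 200.0 steps per epoch
num_epochs = math.ceil(num_steps / steps_per_epoch)  # ceil(15.0) = 15 epochs cover at least 3,000 steps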
Example #8
def _test(config, root, checkpoint=None, nogpu=False, bar_position=0):
    """Run tests. Loads model, iterator and dataset from config."""
    from edflow.iterators.batches import make_batches

    LogSingleton().set_default("latest_eval")
    logger = get_logger("test")
    logger.info("Starting Evaluation.")

    if "test_batch_size" in config:
        config["batch_size"] = config["test_batch_size"]
    if "test_mode" not in config:
        config["test_mode"] = True

    implementations = get_implementations_from_config(
        config, ["model", "iterator", "dataset"]
    )

    dataset = implementations["dataset"](config=config)
    dataset.expand = True
    logger.info("Number of testing samples: {}".format(len(dataset)))
    n_processes = config.get("n_data_processes", min(16, config["batch_size"]))
    n_prefetch = config.get("n_prefetch", 1)
    batches = make_batches(
        dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        n_processes=n_processes,
        n_prefetch=n_prefetch,
        error_on_timeout=config.get("error_on_timeout", False),
    )
    # get going
    next(batches)
    batches.reset()

    logger.info("Initializing model.")
    Model = implementations["model"](config)

    config["hook_freq"] = 1
    config["num_epochs"] = 1
    config["nogpu"] = nogpu
    compat_kwargs = dict(
        hook_freq=config["hook_freq"],
        bar_position=bar_position,
        nogpu=config["nogpu"],
        num_epochs=config["num_epochs"],
    )
    Evaluator = implementations["iterator"](
        config, root, Model, dataset=dataset, **compat_kwargs
    )

    logger.info("Initializing model.")
    if checkpoint is not None:
        Evaluator.initialize(checkpoint_path=checkpoint)
    else:
        Evaluator.initialize()

    # save current config
    logger.info("Starting Evaluation with config:\n{}".format(yaml.dump(config)))
    prefix = "eval"
    if bar_position > 0:
        prefix = prefix + str(bar_position)
    cpath = _save_config(config, prefix=prefix)
    logger.info("Saved config at {}".format(cpath))

    logger.info("Iterating")
    while True:
        Evaluator.iterate(batches)
        if not config.get("eval_forever", False):
            break
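
Besides the usual model/iterator/dataset/batch_size entries, _test only reads a few evaluation-specific keys; a sketch with placeholder values:

config = {
    "test_batch_size": 8,   # overrides batch_size for evaluation when present
    "test_mode": True,      # inserted automatically if missing
    "eval_forever": False,  # True keeps re-running Evaluator.iterate in the loop above
}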
Example #9
    def __init__(self, config):
        super(Cycle_WGAN, self).__init__()
        assert config["iterator"] == "iterator.cycle_wgan.Cycle_WGAN", \
            "This model only works with the iterator: 'iterator.cycle_wgan.Cycle_WGAN'"
        assert "sketch" in config["model_type"] and "face" in config["model_type"], \
            "This model only works with model_type: 'sketch2face'"
        if "debug_log_level" in config and config["debug_log_level"]:
            LogSingleton.set_log_level("debug")
        # get logger and config
        self.logger = get_logger("CycleWGAN")
        self.config = config
        set_random_state(self.config)
        self.logger.info("WGAN_GradientPenalty init model.")
        self.output_names = [
            'real_A', 'fake_B', 'rec_A', 'real_B', 'fake_A', 'rec_B'
        ]
        self.sigma = self.config['variational']['sigma']
        self.num_latent_layer = (
            self.config["variational"]["num_latent_layer"]
            if "variational" in self.config and "num_latent_layer" in self.config["variational"]
            else 0)

        latent_dim = self.config["latent_dim"]
        min_channels = self.config['conv']['n_channel_start']
        max_channels = self.config['conv']['n_channel_max']
        sketch_shape = [32, 1]
        face_shape = [self.config['data']['transform']['resolution'], 3]
        sigma = self.config['variational']['sigma']
        num_extra_conv_sketch = self.config['conv']['sketch_extra_conv']
        num_extra_conv_face = self.config['conv']['face_extra_conv']
        block_activation = nn.ReLU()
        final_activation = nn.Tanh()
        batch_norm_enc = batch_norm_dec = self.config['batch_norm']
        drop_rate_enc = (self.config["dropout"]["enc_rate"]
                         if "dropout" in self.config and "enc_rate" in self.config["dropout"]
                         else 0)
        drop_rate_dec = (self.config["dropout"]["dec_rate"]
                         if "dropout" in self.config and "dec_rate" in self.config["dropout"]
                         else 0)
        bias_enc = (self.config["bias"]["enc"]
                    if "bias" in self.config and "enc" in self.config["bias"]
                    else True)
        bias_dec = (self.config["bias"]["dec"]
                    if "bias" in self.config and "dec" in self.config["bias"]
                    else True)
        num_latent_layer = (self.config["variational"]["num_latent_layer"]
                            if "variational" in self.config and "num_latent_layer" in self.config["variational"]
                            else 0)
        ## cycle A ##
        self.netG_A = VAE(latent_dim=latent_dim,
                          min_channels=min_channels,
                          max_channels=max_channels,
                          in_size=sketch_shape[0],
                          in_channels=sketch_shape[1],
                          out_size=face_shape[0],
                          out_channels=face_shape[1],
                          sigma=sigma,
                          num_extra_conv_enc=num_extra_conv_sketch,
                          num_extra_conv_dec=num_extra_conv_face,
                          block_activation=block_activation,
                          final_activation=final_activation,
                          batch_norm_enc=batch_norm_enc,
                          batch_norm_dec=batch_norm_dec,
                          drop_rate_enc=drop_rate_enc,
                          drop_rate_dec=drop_rate_dec,
                          bias_enc=bias_enc,
                          bias_dec=bias_dec,
                          same_max_channels=False,
                          num_latent_layer=num_latent_layer)
        self.netD_A = WGAN_Discriminator_sketch()
        ## cycle B ##
        self.netG_B = VAE(latent_dim=latent_dim,
                          min_channels=min_channels,
                          max_channels=max_channels,
                          in_size=face_shape[0],
                          in_channels=face_shape[1],
                          out_size=sketch_shape[0],
                          out_channels=sketch_shape[1],
                          sigma=sigma,
                          num_extra_conv_enc=num_extra_conv_face,
                          num_extra_conv_dec=num_extra_conv_sketch,
                          block_activation=block_activation,
                          final_activation=final_activation,
                          batch_norm_enc=batch_norm_enc,
                          batch_norm_dec=batch_norm_dec,
                          drop_rate_enc=drop_rate_enc,
                          drop_rate_dec=drop_rate_dec,
                          bias_enc=bias_enc,
                          bias_dec=bias_dec,
                          same_max_channels=False,
                          num_latent_layer=num_latent_layer)
        self.netD_B = WGAN_Discriminator_face(input_resolution=face_shape[0])
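
Finally, a config sketch for the Cycle_WGAN constructor above; it lists only keys read in this snippet, marks the ones with fallbacks, and uses placeholder values.

config = {
    "iterator": "iterator.cycle_wgan.Cycle_WGAN",  # enforced by the first assert
    "model_type": "sketch2face",                   # must contain both "sketch" and "face"
    "latent_dim": 128,
    "variational": {"sigma": True,                 # required (read without a fallback)
                    "num_latent_layer": 2},        # optional, defaults to 0
    "conv": {"n_channel_start": 32, "n_channel_max": 512,
             "sketch_extra_conv": 0, "face_extra_conv": 1},
    "data": {"transform": {"resolution": 64}},
    "batch_norm": True,
    "dropout": {"enc_rate": 0.0, "dec_rate": 0.0},  # optional, default 0
    "bias": {"enc": True, "dec": True},             # optional, default True
}
model = Cycle_WGAN(config)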