def get_stylegan_1_two_stem_autoencoder(
    image_size: int,
    latent_size: int,
    num_input_channels: int,
    n_mlp: int = 8,
    channel_multiplier: int = 2,
    init_ckpt: str = None,
    update_latent=True,
    update_noise=True,
    encoder_class: Union[WPlusNoNoiseEncoder,
                         WNoNoiseEncoder] = WPlusNoNoiseEncoder
) -> TwoStemStyleganAutoencoder:
    """Assemble a ``TwoStemStyleganAutoencoder`` around a StyleGAN1 generator.

    Two encoder stems are built with identical configuration: one from
    ``encoder_class`` (the latent stem) and one ``NoiseEncoder`` (the noise
    stem). Both are created with ``stylegan_variant=1``.

    Args:
        image_size: Image resolution passed to the generator and both encoders.
        latent_size: Latent size passed to the generator and both encoders.
        num_input_channels: Number of input channels for both encoders.
        n_mlp: Forwarded to ``get_stylegan1_generator``.
        channel_multiplier: Accepted but not used by this function
            (presumably kept for signature parity with other builders).
        init_ckpt: Forwarded to ``get_stylegan1_generator``.
        update_latent: Forwarded to ``TwoStemStyleganAutoencoder``.
        update_noise: Forwarded to ``TwoStemStyleganAutoencoder``.
        encoder_class: Class used to build the latent stem.

    Returns:
        The assembled ``TwoStemStyleganAutoencoder``.
    """
    stylegan = get_stylegan1_generator(image_size,
                                       latent_size,
                                       n_mlp=n_mlp,
                                       init_ckpt=init_ckpt)

    # Both stems take the same leading positional configuration.
    shared_args = (image_size, latent_size, num_input_channels)

    w_stem = encoder_class(*shared_args,
                           StyleGan1Generator.get_channels(),
                           stylegan_variant=1)
    noise_stem = NoiseEncoder(*shared_args,
                              StyleGan1Generator.get_channels(),
                              stylegan_variant=1)

    return TwoStemStyleganAutoencoder(w_stem,
                                      noise_stem,
                                      stylegan,
                                      update_latent=update_latent,
                                      update_noise=update_noise)
def get_stylegan_1_superresolution_autoencoder(
        image_size: int,
        latent_size: int,
        num_input_channels: int,
        n_mlp: int = 8,
        channel_multiplier: int = 2,
        init_ckpt: str = None,
        input_size: int = None,
        encoder_class: Union[WPlusEncoder, WWPlusEncoder] = WPlusEncoder,
        autoencoder_kwargs: dict = None) -> SuperResolutionStyleganAutoencoder:
    """Build a superresolution autoencoder on top of a StyleGAN1 generator.

    The encoder is built for ``input_size`` images, while the generator is
    built for ``image_size`` images.

    Args:
        image_size: Resolution passed to the generator factory.
        latent_size: Latent size passed to the generator and the encoder.
        num_input_channels: Number of input channels of the encoder.
        n_mlp: Forwarded to ``get_stylegan1_generator``.
        channel_multiplier: Accepted but not used by this function.
        init_ckpt: Forwarded to ``get_stylegan1_generator``.
        input_size: Resolution of the encoder input. When ``None``, falls back
            to ``image_size`` and emits a warning.
        encoder_class: Class used to build the encoder.
        autoencoder_kwargs: Extra keyword arguments for
            ``SuperResolutionStyleganAutoencoder``. ``None`` means no extras.

    Returns:
        The assembled ``SuperResolutionStyleganAutoencoder``.

    Raises:
        AssertionError: If ``input_size`` is larger than ``image_size``
            (only when assertions are enabled).
    """
    if input_size is None:
        input_size = image_size
        warnings.warn(
            "You wanted to train superresolution but you did not supply an input size; "
            "falling back to image_size"
        )

    assert input_size <= image_size, "For training superresolution, the image size must be greater or equal than the input size"

    # The default is None; `**None` raises TypeError, so treat it as "no extras".
    if autoencoder_kwargs is None:
        autoencoder_kwargs = {}

    generator = get_stylegan1_generator(image_size,
                                        latent_size,
                                        n_mlp=n_mlp,
                                        init_ckpt=init_ckpt)

    # The encoder works at the (possibly smaller) input resolution.
    encoder = encoder_class(input_size,
                            latent_size,
                            num_input_channels,
                            StyleGan1Generator.get_channels(),
                            stylegan_variant=1)

    autoencoder = SuperResolutionStyleganAutoencoder(encoder, generator,
                                                     **autoencoder_kwargs)
    return autoencoder
def get_stylegan1_wplus_style_autoencoder(
        image_size: int,
        latent_size: int,
        num_input_channels: int,
        n_mlp: int = 8,
        channel_multiplier: int = 2,
        init_ckpt: str = None) -> ContentAndStyleStyleganAutoencoder:
    """Build a ``ContentAndStyleStyleganAutoencoder`` with a StyleGAN1 generator.

    A ``WPlusEncoder`` is created for ``num_input_channels * 2`` input
    channels.

    Args:
        image_size: Resolution passed to the generator and the encoder.
        latent_size: Latent size passed to the generator and the encoder.
        num_input_channels: Channel count of one input; the encoder receives
            twice this many channels.
        n_mlp: Forwarded to ``get_stylegan1_generator``.
        channel_multiplier: Accepted but not used by this function.
        init_ckpt: Forwarded to ``get_stylegan1_generator``.

    Returns:
        The assembled ``ContentAndStyleStyleganAutoencoder``.
    """
    stylegan = get_stylegan1_generator(image_size, latent_size,
                                       n_mlp=n_mlp, init_ckpt=init_ckpt)

    # NOTE(review): presumably the content and style images are stacked along
    # the channel axis, hence the doubled channel count -- confirm with caller.
    doubled_channels = num_input_channels * 2
    wplus_encoder = WPlusEncoder(image_size, latent_size, doubled_channels,
                                 StyleGan1Generator.get_channels(),
                                 stylegan_variant=1)

    return ContentAndStyleStyleganAutoencoder(wplus_encoder, stylegan)
def get_stylegan1_wplus_noise_renset_autoencoder(
        image_size: int,
        latent_size: int,
        num_input_channels: int,
        n_mlp: int = 8,
        channel_multiplier: int = 2,
        init_ckpt: str = None) -> StyleganAutoencoder:
    """Build a ``StyleganAutoencoder`` using a ``WPlusResnetNoiseEncoder``.

    NOTE(review): "renset" in the function name looks like a typo for
    "resnet"; the name is kept unchanged so existing callers keep working.

    Args:
        image_size: Resolution passed to the generator and the encoder.
        latent_size: Latent size passed to the generator and the encoder.
        num_input_channels: Number of input channels of the encoder.
        n_mlp: Forwarded to ``get_stylegan1_generator``.
        channel_multiplier: Accepted but not used by this function.
        init_ckpt: Forwarded to ``get_stylegan1_generator``.

    Returns:
        The assembled ``StyleganAutoencoder``.
    """
    stylegan = get_stylegan1_generator(image_size, latent_size,
                                       n_mlp=n_mlp, init_ckpt=init_ckpt)

    channel_map = StyleGan1Generator.get_channels()
    resnet_encoder = WPlusResnetNoiseEncoder(image_size, latent_size,
                                             num_input_channels, channel_map,
                                             stylegan_variant=1)

    return StyleganAutoencoder(resnet_encoder, stylegan)
def get_stylegan1_code_autoencoder(
        image_size: int,
        latent_size: int,
        num_input_channels: int,
        n_mlp: int = 8,
        channel_multiplier: int = 2,
        init_ckpt: str = None,
        code_dim: int = 10) -> CodeStyleganAutoencoder:
    """Build a ``CodeStyleganAutoencoder`` with a ``WCodeEncoder``.

    The generator is created with a latent size of ``latent_size + code_dim``.
    The encoder receives ``code_dim`` and the plain ``latent_size``
    separately.

    Args:
        image_size: Resolution passed to the generator and the encoder.
        latent_size: Base latent size (without the code part).
        num_input_channels: Number of input channels of the encoder.
        n_mlp: Forwarded to ``get_stylegan1_generator``.
        channel_multiplier: Accepted but not used by this function.
        init_ckpt: Forwarded to ``get_stylegan1_generator``.
        code_dim: Size of the extra code appended to the generator latent size.

    Returns:
        The assembled ``CodeStyleganAutoencoder``.
    """
    generator_latent_size = latent_size + code_dim
    stylegan = get_stylegan1_generator(image_size, generator_latent_size,
                                       n_mlp=n_mlp, init_ckpt=init_ckpt)

    channel_map = StyleGan1Generator.get_channels()
    code_encoder = WCodeEncoder(code_dim, image_size, latent_size,
                                num_input_channels, channel_map,
                                stylegan_variant=1)

    return CodeStyleganAutoencoder(code_encoder, stylegan)
def get_stylegan1_autoencoder(
    image_size: int,
    latent_size: int,
    num_input_channels: int,
    n_mlp: int = 8,
    channel_multiplier: int = 2,
    init_ckpt: str = None,
    autoencoder_class=StyleganAutoencoder,
    encoder_class: Union[WPlusEncoder, WWPlusEncoder] = WPlusEncoder
) -> StyleganAutoencoder:
    """Build a generic autoencoder around a StyleGAN1 generator.

    Args:
        image_size: Resolution passed to the generator and the encoder.
        latent_size: Latent size passed to the generator and the encoder.
        num_input_channels: Number of input channels of the encoder.
        n_mlp: Forwarded to ``get_stylegan1_generator``.
        channel_multiplier: Accepted but not used by this function.
        init_ckpt: Forwarded to ``get_stylegan1_generator``.
        autoencoder_class: Called as ``autoencoder_class(encoder, generator)``.
        encoder_class: Class used to build the encoder.

    Returns:
        Whatever ``autoencoder_class`` constructs.
    """
    stylegan = get_stylegan1_generator(image_size, latent_size,
                                       n_mlp=n_mlp, init_ckpt=init_ckpt)

    channel_map = StyleGan1Generator.get_channels()
    encoder = encoder_class(image_size, latent_size, num_input_channels,
                            channel_map, stylegan_variant=1)

    return autoencoder_class(encoder, stylegan)
Exemplo n.º 7
0
    def build_encoder(self, checkpoint) -> UNetLikeEncoder:
        """Create a ``UNetLikeEncoder`` from the projector config and load its weights.

        The channel map comes from ``StyleGan1Generator`` when
        ``config['stylegan_variant'] == 1``; otherwise it comes from
        ``StyleGan2Generator``.

        Args:
            checkpoint: Anything accepted by ``torch.load``, e.g. a file path.
                The loaded object is either a plain encoder state dict or a
                dict with an ``'autoencoder'`` entry. In the latter case the
                encoder tensors are extracted from it.

        Returns:
            The encoder in eval mode, moved to ``self.projector.device``.
        """
        config = self.projector.config
        if config['stylegan_variant'] == 1:
            channel_map = StyleGan1Generator.get_channels()
        else:
            channel_map = StyleGan2Generator.get_channels()

        encoder = UNetLikeEncoder(
            config['image_size'],
            config['latent_size'],
            config['input_dim'],
            channel_map
        )
        encoder.eval()

        # Deserialize onto the CPU. A checkpoint saved from a GPU would
        # otherwise fail to load on a machine without CUDA. The encoder is
        # moved to the projector's device at the end anyway.
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(checkpoint, map_location='cpu')

        if 'autoencoder' in state_dict:
            # Full-autoencoder checkpoint: keep only tensors whose key contains
            # 'encoder'. Drop the first two dotted components of each key so
            # the names line up with the standalone encoder's parameters.
            encoder_items = {key: value for key, value in state_dict['autoencoder'].items() if 'encoder' in key}
            state_dict = {'.'.join(key.split('.')[2:]): value for key, value in encoder_items.items()}

        encoder.load_state_dict(state_dict)

        return encoder.to(self.projector.device)