def get_stylegan_1_two_stem_autoencoder( image_size: int, latent_size: int, num_input_channels: int, n_mlp: int = 8, channel_multiplier: int = 2, init_ckpt: str = None, update_latent=True, update_noise=True, encoder_class: Union[WPlusNoNoiseEncoder, WNoNoiseEncoder] = WPlusNoNoiseEncoder ) -> TwoStemStyleganAutoencoder:
    """Build a StyleGAN1 autoencoder with two independent encoder stems.

    One stem predicts the latent code (``encoder_class``), the other the
    per-layer noise maps (``NoiseEncoder``); both feed the same generator.

    Args:
        image_size: Output resolution of the generator.
        latent_size: Dimensionality of the latent space.
        num_input_channels: Channels of the images fed to the encoders.
        n_mlp: Depth of the generator's mapping network.
        channel_multiplier: Unused here; kept for signature parity with the
            StyleGAN2 factory functions.
        init_ckpt: Optional checkpoint path used to initialize the generator.
        update_latent: Whether the latent stem is trainable.
        update_noise: Whether the noise stem is trainable.
        encoder_class: Encoder class used for the latent stem.

    Returns:
        A ``TwoStemStyleganAutoencoder`` wiring both stems to the generator.
    """
    gen = get_stylegan1_generator(image_size, latent_size, n_mlp=n_mlp, init_ckpt=init_ckpt)
    # Both stems share the StyleGAN1 channel layout.
    channels = StyleGan1Generator.get_channels()
    latent_stem = encoder_class(image_size, latent_size, num_input_channels, channels, stylegan_variant=1)
    noise_stem = NoiseEncoder(image_size, latent_size, num_input_channels, channels, stylegan_variant=1)
    return TwoStemStyleganAutoencoder(latent_stem, noise_stem, gen, update_latent=update_latent, update_noise=update_noise)
def get_stylegan_1_superresolution_autoencoder( image_size: int, latent_size: int, num_input_channels: int, n_mlp: int = 8, channel_multiplier: int = 2, init_ckpt: str = None, input_size: int = None, encoder_class: Union[WPlusEncoder, WWPlusEncoder] = WPlusEncoder, autoencoder_kwargs: dict = None) -> SuperResolutionStyleganAutoencoder:
    """Build a StyleGAN1 autoencoder for super-resolution training.

    The encoder consumes low-resolution images of size ``input_size`` while
    the generator produces images of size ``image_size``.

    Args:
        image_size: Output (high) resolution of the generator.
        latent_size: Dimensionality of the latent space.
        num_input_channels: Channels of the images fed to the encoder.
        n_mlp: Depth of the generator's mapping network.
        channel_multiplier: Unused here; kept for signature parity with the
            StyleGAN2 factory functions.
        init_ckpt: Optional checkpoint path used to initialize the generator.
        input_size: Input (low) resolution for the encoder; defaults to
            ``image_size`` with a warning, which disables super-resolution.
        encoder_class: Encoder class to instantiate.
        autoencoder_kwargs: Extra keyword arguments forwarded to
            ``SuperResolutionStyleganAutoencoder``.

    Returns:
        A ``SuperResolutionStyleganAutoencoder`` instance.
    """
    # BUG FIX: the original passed **autoencoder_kwargs unconditionally, so
    # the default of None raised a TypeError. Normalize to an empty dict
    # (avoiding a mutable default argument).
    if autoencoder_kwargs is None:
        autoencoder_kwargs = {}
    if input_size is None:
        input_size = image_size
        warnings.warn( "You wanted to train superresolution but you did not supply a new output size" )
    assert input_size <= image_size, "For training superresolution, the image size must be greater or equal than the input size"
    generator = get_stylegan1_generator(image_size, latent_size, n_mlp=n_mlp, init_ckpt=init_ckpt)
    # The encoder operates on the low-resolution input.
    encoder = encoder_class(input_size, latent_size, num_input_channels, StyleGan1Generator.get_channels(), stylegan_variant=1)
    autoencoder = SuperResolutionStyleganAutoencoder(encoder, generator, **autoencoder_kwargs)
    return autoencoder
def get_stylegan1_wplus_style_autoencoder( image_size: int, latent_size: int, num_input_channels: int, n_mlp: int = 8, channel_multiplier: int = 2, init_ckpt: str = None) -> ContentAndStyleStyleganAutoencoder:
    """Build a StyleGAN1 autoencoder that separates content and style.

    The ``WPlusEncoder`` receives twice the usual channel count
    (``num_input_channels * 2``) — content and style images are presumably
    stacked along the channel axis by the caller.

    Args:
        image_size: Output resolution of the generator.
        latent_size: Dimensionality of the latent space.
        num_input_channels: Channels of a single input image (doubled for the encoder).
        n_mlp: Depth of the generator's mapping network.
        channel_multiplier: Unused here; kept for signature parity with the
            StyleGAN2 factory functions.
        init_ckpt: Optional checkpoint path used to initialize the generator.

    Returns:
        A ``ContentAndStyleStyleganAutoencoder`` instance.
    """
    gen = get_stylegan1_generator(image_size, latent_size, n_mlp=n_mlp, init_ckpt=init_ckpt)
    enc = WPlusEncoder(image_size, latent_size, num_input_channels * 2, StyleGan1Generator.get_channels(), stylegan_variant=1)
    return ContentAndStyleStyleganAutoencoder(enc, gen)
def get_stylegan1_wplus_noise_renset_autoencoder( image_size: int, latent_size: int, num_input_channels: int, n_mlp: int = 8, channel_multiplier: int = 2, init_ckpt: str = None) -> StyleganAutoencoder:
    """Build a StyleGAN1 autoencoder using the ResNet-based W+/noise encoder.

    NOTE(review): "renset" in the function name is a typo for "resnet";
    kept as-is because renaming would break existing callers.

    Args:
        image_size: Output resolution of the generator.
        latent_size: Dimensionality of the latent space.
        num_input_channels: Channels of the images fed to the encoder.
        n_mlp: Depth of the generator's mapping network.
        channel_multiplier: Unused here; kept for signature parity with the
            StyleGAN2 factory functions.
        init_ckpt: Optional checkpoint path used to initialize the generator.

    Returns:
        A ``StyleganAutoencoder`` built around a ``WPlusResnetNoiseEncoder``.
    """
    gen = get_stylegan1_generator(image_size, latent_size, n_mlp=n_mlp, init_ckpt=init_ckpt)
    enc = WPlusResnetNoiseEncoder(image_size, latent_size, num_input_channels, StyleGan1Generator.get_channels(), stylegan_variant=1)
    return StyleganAutoencoder(enc, gen)
def get_stylegan1_code_autoencoder( image_size: int, latent_size: int, num_input_channels: int, n_mlp: int = 8, channel_multiplier: int = 2, init_ckpt: str = None, code_dim: int = 10) -> CodeStyleganAutoencoder:
    """Build a StyleGAN1 autoencoder with an additional discrete/continuous code.

    The generator's latent space is widened by ``code_dim`` so the encoder's
    extra code vector can be concatenated to the latent.

    Args:
        image_size: Output resolution of the generator.
        latent_size: Dimensionality of the base latent space.
        num_input_channels: Channels of the images fed to the encoder.
        n_mlp: Depth of the generator's mapping network.
        channel_multiplier: Unused here; kept for signature parity with the
            StyleGAN2 factory functions.
        init_ckpt: Optional checkpoint path used to initialize the generator.
        code_dim: Size of the extra code appended to the latent.

    Returns:
        A ``CodeStyleganAutoencoder`` instance.
    """
    # Generator latent is widened to hold latent + code.
    gen = get_stylegan1_generator(image_size, latent_size + code_dim, n_mlp=n_mlp, init_ckpt=init_ckpt)
    enc = WCodeEncoder(code_dim, image_size, latent_size, num_input_channels, StyleGan1Generator.get_channels(), stylegan_variant=1)
    return CodeStyleganAutoencoder(enc, gen)
def get_stylegan1_autoencoder( image_size: int, latent_size: int, num_input_channels: int, n_mlp: int = 8, channel_multiplier: int = 2, init_ckpt: str = None, autoencoder_class=StyleganAutoencoder, encoder_class: Union[WPlusEncoder, WWPlusEncoder] = WPlusEncoder ) -> StyleganAutoencoder:
    """Build a generic StyleGAN1 autoencoder from pluggable classes.

    Args:
        image_size: Output resolution of the generator.
        latent_size: Dimensionality of the latent space.
        num_input_channels: Channels of the images fed to the encoder.
        n_mlp: Depth of the generator's mapping network.
        channel_multiplier: Unused here; kept for signature parity with the
            StyleGAN2 factory functions.
        init_ckpt: Optional checkpoint path used to initialize the generator.
        autoencoder_class: Autoencoder class to instantiate.
        encoder_class: Encoder class to instantiate.

    Returns:
        An ``autoencoder_class`` instance wrapping encoder and generator.
    """
    gen = get_stylegan1_generator(image_size, latent_size, n_mlp=n_mlp, init_ckpt=init_ckpt)
    enc = encoder_class(image_size, latent_size, num_input_channels, StyleGan1Generator.get_channels(), stylegan_variant=1)
    return autoencoder_class(enc, gen)
def build_encoder(self, checkpoint) -> UNetLikeEncoder:
    """Construct a ``UNetLikeEncoder`` matching the projector's config and load weights.

    Args:
        checkpoint: Path to a ``torch.load``-able checkpoint file. It may be
            either a bare encoder state dict or a full training checkpoint
            with an ``'autoencoder'`` entry.

    Returns:
        The encoder in eval mode, moved to the projector's device.
    """
    # Pick the channel layout that matches the StyleGAN variant the
    # projector was configured for.
    if self.projector.config['stylegan_variant'] == 1:
        channel_map = StyleGan1Generator.get_channels()
    else:
        channel_map = StyleGan2Generator.get_channels()
    encoder = UNetLikeEncoder(
        self.projector.config['image_size'],
        self.projector.config['latent_size'],
        self.projector.config['input_dim'],
        channel_map
    )
    encoder.eval()
    # NOTE(review): torch.load without map_location requires the devices the
    # checkpoint was saved on to be available — consider map_location='cpu'.
    checkpoint = torch.load(checkpoint)
    if 'autoencoder' in checkpoint:
        # Full training checkpoint: keep only encoder tensors and strip the
        # leading two qualifier components (e.g. 'autoencoder.encoder.') so the
        # keys match this encoder's own state dict.
        # NOTE(review): the substring test 'encoder' also matches keys like
        # 'noise_encoder' — assumed intentional here; verify against the
        # training-time checkpoint layout.
        stripped_checkpoint = {key: value for key, value in checkpoint['autoencoder'].items() if 'encoder' in key}
        checkpoint = {'.'.join(key.split('.')[2:]): value for key, value in stripped_checkpoint.items()}
    encoder.load_state_dict(checkpoint)
    return encoder.to(self.projector.device)