def train_decoder(self, layer_name):
    """Train the decoder attached to ``layer_name`` to invert the encoder.

    Each epoch encodes the source image, decodes the hooked embedding back
    to image space, re-encodes the reconstruction, and minimizes an
    image-space MSE plus the norm of the embedding mismatch.

    :param layer_name: key into ``self.observed_layers``
    :type layer_name: str
    :return: per-epoch training loss values
    :rtype: list of float
    """
    decoder = self.observed_layers[layer_name]['decoder']
    n_epochs = self.observed_layers[layer_name]['n_epochs']
    learning_rate = self.observed_layers[layer_name].get(
        'learning_rate', 1e-3)

    training_loss_values = []
    image_loss = nn.MSELoss()
    optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)

    for epoch_index in range(n_epochs):
        # reconstruct: encode the source, then decode the hooked embedding
        self.encoder(vgg_normalization(self.source_tensor).unsqueeze(0))
        embedding = self.encoder_layers[layer_name]
        generated_tensor = decoder(embedding).squeeze()

        # re-embed the reconstruction to compare embeddings
        self.encoder(vgg_normalization(generated_tensor).unsqueeze(0))
        generated_embedding = self.encoder_layers[layer_name]

        loss = image_loss(self.source_tensor, generated_tensor) + \
            torch.norm(embedding - generated_embedding)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Store a plain float, not the loss tensor: appending the tensor
        # would keep every epoch's autograd graph alive and leak memory.
        training_loss_values.append(loss.item())

    return training_loss_values
def __init__(self, source_image, observed_layers, n_bins=128): """ :param source_image: source image :type image: PIL Image object :param observed_layers: dictionary containing layer-specific information ; see bottom of file decoders.py :type observed_layers: dictionary :param n_bins: number of transportation histogram bins, defaults to 128 [TO BE IMPLEMENTED] :type n_bins: int, optional """ # source image self.source_tensor = image_preprocessing(source_image) self.normalized_source_batch = vgg_normalization( self.source_tensor).unsqueeze(0) self.source_batch = self.source_tensor.unsqueeze(0) # set encoder self.encoder = vgg19(pretrained=True).float() self.encoder.eval() for param in self.encoder.features.parameters(): param.requires_grad = False self.encoder_layers = {} self.set_encoder_hooks(observed_layers) self.n_bins = n_bins
def __init__(self, style_image, content_image, observed_layers, n_bins=128, decoder_weights_path=None): super().__init__(style_image, observed_layers, n_bins=n_bins) # input image self.content_tensor = image_preprocessing(content_image) self.normalized_content_batch = vgg_normalization( self.content_tensor).unsqueeze(0) self.content_batch = self.content_tensor.unsqueeze(0)
def transfer(self, n_passes=5, content_strength=0.5):
    """Style transfer

    :param n_passes: number of global passes, defaults to 5
    :type n_passes: int, optional
    :param content_strength: content strength, defaults to 0.5
    :type content_strength: float, optional
    :return: generated images layer by layer, step by step
    :rtype: list
    """
    self.n_passes = n_passes
    pass_generated_images = []

    # initialize with noise
    self.target_tensor = torch.randn_like(self.source_tensor)

    for global_pass in range(n_passes):
        print(f'global pass {global_pass}')
        for layer_name, layer_information in self.observed_layers.items():
            print(f'layer {layer_name}')

            # forward pass on source image
            self.encoder(self.normalized_source_batch)
            source_layer = self.encoder_layers[layer_name]

            # forward pass on content image
            self.encoder(self.normalized_content_batch)
            content_layer = self.encoder_layers[layer_name]

            # forward pass on target image
            target_batch = vgg_normalization(
                self.target_tensor).unsqueeze(0)
            self.encoder(target_batch)
            target_layer = self.encoder_layers[layer_name]

            # transport
            target_layer = self.optimal_transport(layer_name,
                                                  source_layer.squeeze(),
                                                  target_layer.squeeze())
            target_layer = target_layer.view_as(source_layer)

            # feature style transfer: blend features toward the content
            target_layer += content_strength * \
                (content_layer - target_layer)

            # decode; detach so .numpy() below cannot raise on a
            # grad-tracking tensor (decoder parameters are trainable) and
            # the autograd graph does not accumulate across passes
            decoder = layer_information['decoder']
            self.target_tensor = decoder(target_layer).squeeze().detach()

            generated_image = np.transpose(self.target_tensor.numpy(),
                                           (1, 2, 0)).copy()
            pass_generated_images.append(generated_image)

    return pass_generated_images
def __init__(self, style_image, content_image, observed_layers, n_bins=128):
    """
    :param content_image: content image
    :type content_image: PIL Image object
    """
    super().__init__(style_image, observed_layers, n_bins=n_bins)

    # input (content) image: keep the raw tensor plus batched raw and
    # VGG-normalized views for later forward passes
    content_tensor = image_preprocessing(content_image)
    self.content_tensor = content_tensor
    self.normalized_content_batch = vgg_normalization(content_tensor).unsqueeze(0)
    self.content_batch = content_tensor.unsqueeze(0)
def __init__(self, image, observed_layers, n_bins=128):
    """Prepare the source-image tensors and a hooked, frozen VGG-19 encoder."""
    # source image: raw tensor plus batched raw and VGG-normalized views
    tensor = image_preprocessing(image)
    self.source_tensor = tensor
    self.normalized_source_batch = vgg_normalization(tensor).unsqueeze(0)
    self.source_batch = tensor.unsqueeze(0)

    # encoder: pretrained VGG-19 in eval mode with frozen feature weights
    encoder = vgg19(pretrained=True).float()
    encoder.eval()
    for parameter in encoder.features.parameters():
        parameter.requires_grad = False
    self.encoder = encoder

    # hooks fill self.encoder_layers with the observed activations
    self.encoder_layers = {}
    self.set_encoder_hooks(observed_layers)
    self.n_bins = n_bins
def generate(self, n_passes=5):
    """Synthesize a texture from the source image.

    :param n_passes: number of global passes, defaults to 5
    :type n_passes: int, optional
    :return: generated images layer by layer, step by step
    :rtype: list
    """
    self.n_passes = n_passes
    pass_generated_images = []

    # initialize with noise
    self.target_tensor = torch.randn_like(self.source_tensor)

    for global_pass in range(n_passes):
        print(f'global pass {global_pass}')
        for layer_name, layer_information in self.observed_layers.items():
            print(f'layer {layer_name}')

            # forward pass on source image
            self.encoder(self.normalized_source_batch)
            source_layer = self.encoder_layers[layer_name]

            # forward pass on target image
            target_batch = vgg_normalization(
                self.target_tensor).unsqueeze(0)
            self.encoder(target_batch)
            target_layer = self.encoder_layers[layer_name]

            # transport
            target_layer = self.optimal_transport(layer_name,
                                                  source_layer.squeeze(),
                                                  target_layer.squeeze())
            target_layer = target_layer.view_as(source_layer)

            # decode; detach so .numpy() below cannot raise on a
            # grad-tracking tensor (decoder parameters are trainable) and
            # the autograd graph does not accumulate across passes
            decoder = layer_information['decoder']
            self.target_tensor = decoder(target_layer).squeeze().detach()

            generated_image = np.transpose(self.target_tensor.numpy(),
                                           (1, 2, 0)).copy()
            pass_generated_images.append(generated_image)

    return pass_generated_images