예제 #1
0
    def train_decoder(self, layer_name):
        """Train the decoder attached to ``layer_name`` to invert the encoder.

        Each epoch encodes the source image, decodes the observed layer's
        embedding back to image space, and minimizes the image-space MSE plus
        the norm of the embedding difference after re-encoding the
        reconstruction.

        :param layer_name: key into ``self.observed_layers``
        :type layer_name: str
        :return: per-epoch training loss values (Python floats)
        :rtype: list
        """
        layer = self.observed_layers[layer_name]
        decoder = layer['decoder']
        n_epochs = layer['n_epochs']
        learning_rate = layer.get('learning_rate', 1e-3)

        training_loss_values = []
        image_loss = nn.MSELoss()
        optimizer = torch.optim.Adam(decoder.parameters(), lr=learning_rate)
        for epoch_index in range(n_epochs):
            # reconstruct: encode the source (hooks fill self.encoder_layers),
            # then decode the observed layer's embedding
            self.encoder(vgg_normalization(self.source_tensor).unsqueeze(0))
            embedding = self.encoder_layers[layer_name]
            generated_tensor = decoder(embedding).squeeze()

            # re-embed the reconstruction through the encoder
            self.encoder(vgg_normalization(generated_tensor).unsqueeze(0))
            generated_embedding = self.encoder_layers[layer_name]

            loss = image_loss(self.source_tensor, generated_tensor) + \
                torch.norm(embedding - generated_embedding)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # .item() detaches the scalar from the graph; appending the loss
            # tensor itself would keep every epoch's autograd graph alive
            # (unbounded memory growth over training)
            training_loss_values.append(loss.item())
        return training_loss_values
예제 #2
0
    def __init__(self, source_image, observed_layers, n_bins=128):
        """Set up the source image tensors, the frozen VGG-19 encoder and
        its layer hooks.

        :param source_image: source image
        :type source_image: PIL Image object
        :param observed_layers: dictionary containing layer-specific information ; see bottom of file decoders.py
        :type observed_layers: dictionary
        :param n_bins: number of transportation histogram bins, defaults to 128 [TO  BE IMPLEMENTED]
        :type n_bins: int, optional
        """

        # source image: raw tensor plus batched raw / VGG-normalized views
        self.source_tensor = image_preprocessing(source_image)
        self.normalized_source_batch = vgg_normalization(
            self.source_tensor).unsqueeze(0)
        self.source_batch = self.source_tensor.unsqueeze(0)

        # set encoder: pretrained VGG-19 in eval mode with frozen features
        self.encoder = vgg19(pretrained=True).float()
        self.encoder.eval()
        for param in self.encoder.features.parameters():
            param.requires_grad = False
        # populated by the forward hooks registered below
        self.encoder_layers = {}
        self.set_encoder_hooks(observed_layers)

        self.n_bins = n_bins
예제 #3
0
    def __init__(self, style_image, content_image, observed_layers, n_bins=128, decoder_weights_path=None):
        """Delegate style-side setup to the parent class, then prepare raw,
        batched and VGG-normalized views of the content image.

        :param content_image: content (input) image
        :type content_image: PIL Image object
        """
        super().__init__(style_image, observed_layers, n_bins=n_bins)

        # content (input) image tensors
        content_tensor = image_preprocessing(content_image)
        self.content_tensor = content_tensor
        self.content_batch = content_tensor.unsqueeze(0)
        self.normalized_content_batch = vgg_normalization(
            content_tensor).unsqueeze(0)
예제 #4
0
    def transfer(self, n_passes=5, content_strength=0.5):
        """Style transfer 

        :param n_passes: number of global passes, defaults to 5
        :type n_passes: int, optional
        :param content_strength: content strength, defaults to 0.5
        :type content_strength: float, optional
        :return: generated images layer by layer, step by step
        :rtype: list
        """

        self.n_passes = n_passes
        pass_generated_images = []

        # initialize with noise
        self.target_tensor = torch.randn_like(self.source_tensor)

        for global_pass in range(n_passes):
            print(f'global pass {global_pass}')
            for layer_name, layer_information in self.observed_layers.items():
                print(f'layer {layer_name}')

                # forward pass on source image (hooks fill encoder_layers)
                self.encoder(self.normalized_source_batch)
                source_layer = self.encoder_layers[layer_name]

                # forward pass on content image
                self.encoder(self.normalized_content_batch)
                content_layer = self.encoder_layers[layer_name]

                # forward pass on target image
                target_batch = vgg_normalization(
                    self.target_tensor).unsqueeze(0)
                self.encoder(target_batch)
                target_layer = self.encoder_layers[layer_name]

                # transport target features onto the source distribution
                target_layer = self.optimal_transport(layer_name,
                                                      source_layer.squeeze(),
                                                      target_layer.squeeze())
                target_layer = target_layer.view_as(source_layer)

                # feature style transfer: blend toward the content features.
                # Out-of-place (not +=) so we never mutate a tensor that may
                # share storage with a hooked encoder activation.
                target_layer = target_layer + content_strength * \
                    (content_layer - target_layer)

                # decode; detach so the decoder's autograd graph neither
                # accumulates across passes nor breaks the .numpy() call below
                decoder = layer_information['decoder']
                self.target_tensor = decoder(target_layer).squeeze().detach()

                generated_image = np.transpose(self.target_tensor.numpy(),
                                               (1, 2, 0)).copy()
                pass_generated_images.append(generated_image)

        return pass_generated_images
예제 #5
0
    def __init__(self,
                 style_image,
                 content_image,
                 observed_layers,
                 n_bins=128):
        """
        :param style_image: style image, forwarded to the parent constructor
        :type style_image: PIL Image object
        :param content_image: content image
        :type content_image: PIL Image object
        :param observed_layers: dictionary containing layer-specific information, forwarded to the parent constructor
        :type observed_layers: dictionary
        :param n_bins: number of transportation histogram bins, defaults to 128
        :type n_bins: int, optional
        """

        super().__init__(style_image, observed_layers, n_bins=n_bins)
        # input image: raw tensor plus batched raw / VGG-normalized views
        self.content_tensor = image_preprocessing(content_image)
        self.normalized_content_batch = vgg_normalization(
            self.content_tensor).unsqueeze(0)
        self.content_batch = self.content_tensor.unsqueeze(0)
예제 #6
0
    def __init__(self, image, observed_layers, n_bins=128):
        """Prepare the source image tensors, load a frozen pretrained VGG-19
        encoder in eval mode, and register the layer observation hooks.

        :param image: source image
        :type image: PIL Image object
        :param observed_layers: dictionary containing layer-specific information
        :type observed_layers: dictionary
        :param n_bins: number of transportation histogram bins, defaults to 128
        :type n_bins: int, optional
        """
        # source image: raw tensor plus batched raw / normalized views
        tensor = image_preprocessing(image)
        self.source_tensor = tensor
        self.source_batch = tensor.unsqueeze(0)
        self.normalized_source_batch = vgg_normalization(tensor).unsqueeze(0)

        # frozen pretrained VGG-19 encoder
        encoder = vgg19(pretrained=True).float()
        encoder.eval()
        for parameter in encoder.features.parameters():
            parameter.requires_grad = False
        self.encoder = encoder

        # populated by the forward hooks registered below
        self.encoder_layers = {}
        self.set_encoder_hooks(observed_layers)

        self.n_bins = n_bins
예제 #7
0
    def generate(self, n_passes=5):
        """Synthesize a texture from noise by repeated transport-and-decode.

        :param n_passes: number of global passes, defaults to 5
        :type n_passes: int, optional
        :return: generated images layer by layer, step by step (HWC numpy arrays)
        :rtype: list
        """

        self.n_passes = n_passes
        pass_generated_images = []

        # initialize with noise
        self.target_tensor = torch.randn_like(self.source_tensor)

        for global_pass in range(n_passes):
            print(f'global pass {global_pass}')
            for layer_name, layer_information in self.observed_layers.items():
                print(f'layer {layer_name}')

                # forward pass on source image (hooks fill encoder_layers)
                self.encoder(self.normalized_source_batch)
                source_layer = self.encoder_layers[layer_name]

                # forward pass on target image
                target_batch = vgg_normalization(
                    self.target_tensor).unsqueeze(0)
                self.encoder(target_batch)
                target_layer = self.encoder_layers[layer_name]

                # transport target features onto the source distribution
                target_layer = self.optimal_transport(layer_name,
                                                      source_layer.squeeze(),
                                                      target_layer.squeeze())
                target_layer = target_layer.view_as(source_layer)

                # decode; detach so the decoder's autograd graph neither
                # accumulates across passes nor breaks the .numpy() call below
                decoder = layer_information['decoder']
                self.target_tensor = decoder(target_layer).squeeze().detach()

                generated_image = np.transpose(self.target_tensor.numpy(),
                                               (1, 2, 0)).copy()
                pass_generated_images.append(generated_image)

        return pass_generated_images