Ejemplo n.º 1
0
    def evaluate(self, image):
        """Compute the total loss for ``image``; called at each optimizer step.

        The loss is the sum of three weighted terms, each accumulated over the
        layers yielded by ``self.model.extract(image, layers=self.all_layers)``:

        * content loss: MSE between the activations and the stored content
          activations (``self.content_feat``), for ``args.content_layers``;
        * style loss: MSE between ``histogram.square_matrix(f - 1.0)`` and the
          stored gram matrices (``self.style_gram``), for ``args.style_layers``;
        * histogram loss: MSE between the activations and their versions
          matched to the stored histograms (``self.style_hist``), for
          ``args.histogram_layers``.

        ``run()`` stores those targets under ``(image_index, layer)`` keys, one
        entry per content/style image. Every entry for a layer is used and the
        per-image losses are averaged, so with a single image the result is the
        plain per-layer MSE. Targets stored under a bare ``layer`` key are
        accepted as well.

        Per-layer weights are consumed in extraction order from
        ``args.content_weights``, ``args.style_weights`` and
        ``args.histogram_weights``.

        Side effects: saves ``self.image`` to ``output/test####.png`` when
        ``self.should_do(args.save_every)`` is true, and prints the three losses
        when ``self.should_do(args.print_every)`` is true.

        Args:
            image: the image being optimized, passed to ``self.model.extract``.

        Returns:
            Scalar tensor ``content_score + hist_score + style_score``.

        Raises:
            KeyError: if no stored target exists for an enabled layer.
        """

        def _targets(store, layer):
            # All stored targets for ``layer`` in insertion (image) order, whether
            # keyed by the bare layer or by an ``(image_index, layer)`` tuple.
            found = [value for key, value in store.items()
                     if key == layer or (isinstance(key, tuple) and key[-1] == layer)]
            if not found:
                raise KeyError(layer)
            return found

        def _mean_mse(targets, value):
            # Average of the MSE against each target; equals the MSE for one target.
            return sum(F.mse_loss(t, value) for t in targets) / len(targets)

        # Default scores start at zero, stored as tensors so it's possible to compute gradients
        # even if there are no layers enabled below.
        style_score = torch.tensor(0.0).to(self.device)
        hist_score = torch.tensor(0.0).to(self.device)
        content_score = torch.tensor(0.0).to(self.device)

        # Each layer can have custom weights for style and content loss, stored as Python iterators.
        cw = iter(self.args.content_weights)
        sw = iter(self.args.style_weights)
        hw = iter(self.args.histogram_weights)

        # Ask the model to prepare each layer one by one, then decide which losses to calculate.
        for i, f in self.model.extract(image, layers=self.all_layers):

            # The content loss is a mean squared error directly on the activation features.
            if i in self.args.content_layers:
                content_score += _mean_mse(_targets(self.content_feat, i), f) * next(cw)

            # The style loss is mean squared error on cross-correlation statistics (aka. gram matrix).
            if i in self.args.style_layers:
                gram = histogram.square_matrix(f - 1.0)
                style_score += _mean_mse(_targets(self.style_gram, i), gram) * next(sw)

            # Histogram loss is computed like a content loss, but only after the values have been
            # adjusted to match the target histogram.
            if i in self.args.histogram_layers:
                # NOTE(review): whether gradients flow through the matched tensors depends on
                # histogram.match_histograms; confirm that is intended.
                matched = [histogram.match_histograms(f, h, same_range=True)
                           for h in _targets(self.style_hist, i)]
                hist_score += _mean_mse(matched, f) * next(hw)

        # Store the image to disk at the specified intervals.
        if self.should_do(self.args.save_every):
            images.save_to_file(
                self.image.clone().detach().cpu(),
                'output/test%04i.png' % (self.scale * 1000 + self.counter))

        # Print optimization statistics at regular intervals.
        if self.should_do(self.args.print_every):
            print(
                'Iteration: {}    Style Loss: {:4f}     Content Loss: {:4f}    Histogram Loss: {:4f}'
                .format(self.counter, style_score.item(), content_score.item(),
                        hist_score.item()))

        # Total loss is passed back to the optimizer.
        return content_score + hist_score + style_score
Ejemplo n.º 2
0
    def run(self):
        """Main entry point for style transfer, operating coarse-to-fine over the scales.

        For each scale index ``self.scale`` in ``range(self.args.scales)``:

        1. Downscale every content and style image by
           ``2 ** (scales - scale - 1)``, which is 1 at the last scale.
        2. Choose the optimizer's starting image (seed):
           a. the file ``self.args.seed``;
           b. clamped normal noise;
           c. the previous scale's output plus noise.
        3. Pre-compute per-image targets keyed by ``(image_index, layer)``:
           gram matrices in ``self.style_gram``, feature histograms in
           ``self.style_hist``, and content activations in ``self.content_feat``.
        4. Run ``self.optimize`` and keep a bilinearly upscaled, detached copy
           of the result in ``self.seed_img`` for the next scale.

        Finally ``self.image`` is saved to ``self.args.output`` or, if that is
        unset, to ``output/<basename>_final.png``.

        Raises:
            ValueError: if ``self.args.seed`` is given but the loaded image's
                shape differs from the downscaled first content image.
        """
        # ``self.scale`` is assigned by the loop on purpose: evaluate() reads it.
        for self.scale in range(0, self.args.scales):
            # Pre-process the input images so they have the expected size.
            factor = 2 ** (self.args.scales - self.scale - 1)
            content_imgs = [resize.DownscaleBuilder(factor, cuda=self.cuda).build(img)
                            for img in self.content_imgs]
            style_imgs = [resize.DownscaleBuilder(factor, cuda=self.cuda).build(img)
                          for img in self.style_imgs]

            # Determine the starting point for the optimizer: was there an output of a previous scale?
            if self.seed_img is None:
                if self.args.seed is not None:
                    # a) Load an image from disk; it must already match this scale's size.
                    seed_img = images.load_from_file(self.args.seed, self.device)
                    if seed_img.shape != content_imgs[0].shape:
                        raise ValueError(
                            'Seed image shape {} does not match content image shape {} at scale {}.'
                            .format(tuple(seed_img.shape), tuple(content_imgs[0].shape), self.scale))
                else:
                    # b) Use a completely random buffer from a normal distribution.
                    seed_img = torch.empty_like(content_imgs[0]).normal_(std=0.5).clamp_(-2.0, +2.0)
            else:
                # c) There was a previous scale, so resize and add noise from a normal distribution.
                seed_img = (resize.DownscaleBuilder(factor, cuda=self.cuda).build(self.seed_img)
                            + torch.empty_like(content_imgs[0]).normal_(std=0.1)).clamp_(-2.0, +2.0)

            # Pre-compute the cross-correlation statistics for the style image layers (aka. gram matrices).
            self.style_gram = {}
            for n, img in enumerate(style_imgs):
                for i, f in self.model.extract(img, layers=self.args.style_layers):
                    self.style_gram[n, i] = histogram.square_matrix(f - 1.0).detach()

            # Pre-compute feature histograms for the style image layers specified.
            self.style_hist = {}
            for n, img in enumerate(style_imgs):
                for k, v in self.model.extract(img, layers=self.args.histogram_layers):
                    # NOTE(review): min/max are created on the default (CPU) device; confirm
                    # extract_histograms copes when ``v`` lives on another device.
                    self.style_hist[n, k] = histogram.extract_histograms(
                        v, bins=5, min=torch.tensor(-1.0), max=torch.tensor(+4.0))

            # Prepare and store the content image activations for image layers too.
            self.content_feat = {}
            for n, img in enumerate(content_imgs):
                for i, f in self.model.extract(img, layers=self.args.content_layers):
                    self.content_feat[n, i] = f.detach()

            # Now run the optimization starting from the seed image.
            output = self.optimize(seed_img, self.iterations[self.scale])

            # For the next scale, we'll reuse a bilinear-interpolated version of this output.
            self.seed_img = resize.UpscaleBuilder(factor, mode='bilinear').build(output).detach()

        # Save the final image at the finest scale to disk.
        basename = os.path.splitext(os.path.basename(self.args.content or self.args.style))[0]
        images.save_to_file(self.image.clone().detach().cpu(),
                            self.args.output or ('output/%s_final.png' % basename))