Example #1
0
    def evaluate(self, image):
        """Return the total optimization loss for ``image`` and report progress.

        ``image`` is passed through ``self.model.extract`` over ``self.all_layers``.
        For each yielded layer that is enabled in the matching ``self.args`` list,
        one of three terms is accumulated:

        * content -- MSE between ``self.content_feat[layer]`` and the layer output;
        * style -- MSE between ``self.style_gram[layer]`` and
          ``histogram.square_matrix`` of the layer output shifted by ``-1.0``;
        * histogram -- MSE between the layer output and a copy of it remapped by
          ``histogram.match_histograms`` towards ``self.style_hist[layer]``.

        Every term is multiplied by the next value drawn from the corresponding
        ``*_weights`` sequence, consumed in the order layers are yielded.
        When ``self.should_do`` permits, ``self.image`` is saved to disk and the
        three losses are printed.

        Returns:
            The tensor ``content + histogram + style``.
        """
        args = self.args

        # Accumulators start as zero tensors on the working device, so the
        # returned sum is still a tensor even if no layer is enabled for a term.
        style_loss = torch.tensor(0.0).to(self.device)
        histogram_loss = torch.tensor(0.0).to(self.device)
        content_loss = torch.tensor(0.0).to(self.device)

        # Per-layer weights, pulled lazily one at a time as matching layers appear.
        content_weights = iter(args.content_weights)
        style_weights = iter(args.style_weights)
        histogram_weights = iter(args.histogram_weights)

        for layer, feats in self.model.extract(image, layers=self.all_layers):
            if layer in args.content_layers:
                # Direct comparison of activations.
                content_loss += F.mse_loss(self.content_feat[layer], feats) * next(content_weights)

            if layer in args.style_layers:
                # Compare second-order statistics of the shifted activations.
                shifted_gram = histogram.square_matrix(feats - 1.0)
                style_loss += F.mse_loss(self.style_gram[layer], shifted_gram) * next(style_weights)

            if layer in args.histogram_layers:
                # The target is the activation itself, remapped onto the style histogram.
                remapped = histogram.match_histograms(feats,
                                                      self.style_hist[layer],
                                                      same_range=True)
                histogram_loss += F.mse_loss(remapped, feats) * next(histogram_weights)

        if self.should_do(args.save_every):
            snapshot = self.image.clone().detach().cpu()
            frame = self.scale * 1000 + self.counter
            images.save_to_file(snapshot, 'output/test%04i.png' % frame)

        if self.should_do(args.print_every):
            template = 'Iteration: {}    Style Loss: {:4f}     Content Loss: {:4f}    Histogram Loss: {:4f}'
            print(template.format(self.counter, style_loss.item(), content_loss.item(),
                                  histogram_loss.item()))

        # Summation order kept as content + histogram + style.
        return content_loss + histogram_loss + style_loss
Example #2
0
def test_match_offset(source, offset):
    """Matching a shifted tensor to the histograms of ``source`` should undo the shift.

    ``source`` and ``offset`` are presumably pytest fixtures / parametrized values
    defined elsewhere in this file.
    """
    h = histogram.extract_histograms(source)
    output = histogram.match_histograms(source + offset, h)
    # Compare the largest *absolute* deviation. A plain ``torch.max(output - source)``
    # is one-sided: it passes whenever every element of ``output`` is at or below
    # ``source``, no matter how far below.
    error = (output - source).abs().max().item()
    assert error == pytest.approx(0.0, abs=1e-4)
Example #3
0
def test_match_identity(source):
    """Matching ``source`` to its own histograms should leave it unchanged.

    ``source`` is presumably a pytest fixture defined elsewhere in this file.
    """
    h = histogram.extract_histograms(source)
    output = histogram.match_histograms(source, h)
    # Compare the largest *absolute* deviation. A plain ``torch.max(output - source)``
    # is one-sided and would accept outputs that fall below ``source``.
    error = (output - source).abs().max().item()
    assert error == pytest.approx(0.0, abs=1e-6)