def evaluate(self, image):
    """Compute the style, histogram and content loss for the image specified, used at each
    step of the optimization.

    Arguments:
        image: Input passed to ``self.model.extract()``, which yields ``(layer, features)``
            pairs for every layer listed in ``self.all_layers``.

    Returns:
        The total loss tensor: the sum of the weighted content, histogram and style losses.

    Side effects:
        When ``self.should_do(self.args.save_every)`` is true, writes ``self.image`` to
        ``output/testNNNN.png``; when ``self.should_do(self.args.print_every)`` is true,
        prints the current loss values.
    """
    # Default scores start at zero, stored as tensors so it's possible to compute gradients
    # even if there are no layers enabled below. They are created directly on the device.
    style_score = torch.tensor(0.0, device=self.device)
    hist_score = torch.tensor(0.0, device=self.device)
    content_score = torch.tensor(0.0, device=self.device)

    # Each layer can have custom weights for each loss, consumed in order via Python iterators.
    # NOTE: each weight list needs one entry per matching layer; otherwise ``next()`` below
    # raises StopIteration.
    cw = iter(self.args.content_weights)
    sw = iter(self.args.style_weights)
    hw = iter(self.args.histogram_weights)

    # Ask the model to prepare each layer one by one, then decide which losses to calculate.
    for i, f in self.model.extract(image, layers=self.all_layers):

        # The content loss is a mean squared error directly on the activation features.
        if i in self.args.content_layers:
            content_score += F.mse_loss(self.content_feat[i], f) * next(cw)

        # The style loss is mean squared error on cross-correlation statistics (aka. gram matrix).
        # NOTE(review): the ``- 1.0`` shift presumably mirrors how ``self.style_gram`` was
        # built elsewhere; keep the two in sync.
        if i in self.args.style_layers:
            gram = histogram.square_matrix(f - 1.0)
            style_score += F.mse_loss(self.style_gram[i], gram) * next(sw)

        # Histogram loss is computed like a content loss, but only after the values have been
        # adjusted to match the target histogram.
        # NOTE(review): ``tl`` is derived from ``f``; whether gradients should also flow through
        # the matched target depends on ``match_histograms`` — confirm (e.g. ``tl.detach()``).
        if i in self.args.histogram_layers:
            tl = histogram.match_histograms(f, self.style_hist[i], same_range=True)
            hist_score += F.mse_loss(tl, f) * next(hw)

    # Store the image to disk at the specified intervals.
    # NOTE(review): this saves ``self.image``, not the ``image`` argument — presumably the
    # optimized parameter; confirm.
    if self.should_do(self.args.save_every):
        images.save_to_file(
            self.image.clone().detach().cpu(),
            'output/test%04i.png' % (self.scale * 1000 + self.counter),
        )

    # Print optimization statistics at regular intervals (4 decimal places per loss).
    if self.should_do(self.args.print_every):
        print(
            'Iteration: {} Style Loss: {:.4f} Content Loss: {:.4f} Histogram Loss: {:.4f}'.format(
                self.counter, style_score.item(), content_score.item(), hist_score.item()
            )
        )

    # Total loss is passed back to the optimizer.
    return content_score + hist_score + style_score
def test_match_offset(source, offset):
    """Matching ``source + offset`` against the histograms extracted from ``source`` should
    recover ``source`` to within an absolute tolerance of 1e-4."""
    h = histogram.extract_histograms(source)
    output = histogram.match_histograms(source + offset, h)
    # Bound the largest *absolute* error: checking only max(output - source) would let
    # arbitrarily large negative deviations pass.
    assert (output - source).abs().max().item() == pytest.approx(0.0, abs=1e-4)
def test_match_identity(source):
    """Matching ``source`` against its own extracted histograms should leave it unchanged
    to within an absolute tolerance of 1e-6."""
    h = histogram.extract_histograms(source)
    output = histogram.match_histograms(source, h)
    # Bound the largest *absolute* error: checking only max(output - source) would let
    # arbitrarily large negative deviations pass.
    assert (output - source).abs().max().item() == pytest.approx(0.0, abs=1e-6)