示例#1
0
    def run(self):
        """Run the style transfer, one pass per scale, from coarsest to finest.

        For every scale the content and style images are downscaled by a power
        of two, a starting image is chosen, the style/content targets are
        pre-computed from ``self.model`` and ``self.optimize`` is invoked.  The
        upscaled result is stored in ``self.seed_img`` for the following pass.
        After the last pass ``self.image`` is written to disk.
        """
        for self.scale in range(0, self.args.scales):
            # The divisor halves with every pass and reaches 1 on the last one.
            factor = 2 ** (self.args.scales - self.scale - 1)

            # Bring all input images to the resolution of the current scale.
            content_imgs = [
                resize.DownscaleBuilder(factor, cuda=self.cuda).build(img)
                for img in self.content_imgs
            ]
            style_imgs = [
                resize.DownscaleBuilder(factor, cuda=self.cuda).build(img)
                for img in self.style_imgs
            ]

            # Choose the starting point for the optimizer.
            if self.seed_img is not None:
                # A previous scale produced an output: resize it to the current
                # resolution and perturb it with a little gaussian noise.
                resized = resize.DownscaleBuilder(factor, cuda=self.cuda).build(self.seed_img)
                noise = torch.empty_like(content_imgs[0]).normal_(std=0.1)
                seed_img = (resized + noise).clamp_(-2.0, +2.0)
            elif self.args.seed is not None:
                # First scale with a user-supplied seed: the file on disk must
                # already have exactly the shape of the first content image.
                seed_img = images.load_from_file(self.args.seed, self.device)
                assert seed_img.shape == content_imgs[0].shape
            else:
                # First scale without a seed: start from clamped gaussian noise.
                seed_img = torch.empty_like(content_imgs[0]).normal_(std=0.5).clamp_(-2.0, +2.0)

            # Cross-correlation statistics (aka. gram matrices) of the style
            # layers, keyed by (style image index, layer).
            self.style_gram = {}
            for index, style in enumerate(style_imgs):
                for layer, feats in self.model.extract(style, layers=self.args.style_layers):
                    self.style_gram[index, layer] = histogram.square_matrix(feats - 1.0).detach()

            # Feature histograms of the histogram layers, keyed the same way.
            self.style_hist = {}
            for index, style in enumerate(style_imgs):
                for layer, feats in self.model.extract(style, layers=self.args.histogram_layers):
                    self.style_hist[index, layer] = histogram.extract_histograms(
                        feats, bins=5, min=torch.tensor(-1.0), max=torch.tensor(+4.0))

            # Detached activations of the content layers, keyed by
            # (content image index, layer).
            self.content_feat = {}
            for index, content in enumerate(content_imgs):
                for layer, feats in self.model.extract(content, layers=self.args.content_layers):
                    self.content_feat[index, layer] = feats.detach()

            # Optimize from the seed image with this scale's iteration budget.
            output = self.optimize(seed_img, self.iterations[self.scale])

            # The next scale reuses a bilinearly interpolated copy of this output.
            upscaler = resize.UpscaleBuilder(factor, mode='bilinear')
            self.seed_img = upscaler.build(output).detach()

        # Save the final image at the finest scale to disk; the default file
        # name is derived from the content path (or the style path if unset).
        source_path = self.args.content or self.args.style
        basename = os.path.splitext(os.path.basename(source_path))[0]
        final_image = self.image.clone().detach().cpu()
        images.save_to_file(final_image, self.args.output or ('output/%s_final.png' % basename))
示例#2
0
def test_match_offset(source, offset):
    """Matching an offset copy of ``source`` to its own histogram must undo the offset.

    The histogram is extracted from ``source``, then ``source + offset`` is
    matched against it; every element of the result has to be within 1e-4 of
    the corresponding element of ``source``.

    ``source`` and ``offset`` are supplied by the test harness (presumably
    fixtures or generated values -- confirm against the test module).
    """
    h = histogram.extract_histograms(source)
    output = histogram.match_histograms(source + offset, h)
    # Check the largest *absolute* deviation: the signed maximum alone would
    # let arbitrarily large negative errors slip through unnoticed.
    assert pytest.approx(0.0, abs=1e-4) == torch.max(torch.abs(output - source))
示例#3
0
def test_match_identity(source):
    """Matching ``source`` to its own histogram must leave it unchanged.

    Every element of the matched output has to be within 1e-6 of the
    corresponding element of ``source``.

    ``source`` is supplied by the test harness (presumably a fixture or a
    generated value -- confirm against the test module).
    """
    h = histogram.extract_histograms(source)
    output = histogram.match_histograms(source, h)
    # Check the largest *absolute* deviation: the signed maximum alone would
    # let arbitrarily large negative errors slip through unnoticed.
    assert pytest.approx(0.0, abs=1e-6) == torch.max(torch.abs(output - source))
示例#4
0
def test_extract_normalized(source):
    """The first element returned by ``extract_histograms`` must sum to 1.0 (within 1e-6)."""
    extracted = histogram.extract_histograms(source)
    total = torch.sum(extracted[0])
    expected = pytest.approx(1.0, abs=1e-6)
    assert expected == total
示例#5
0
def test_extract_balanced(source):
    """The extracted histogram must average out to ``1 / shape[2]`` (within 1e-6).

    The mean difference between the first element returned by
    ``extract_histograms`` and the reciprocal of its third dimension is
    expected to be zero.
    """
    extracted = histogram.extract_histograms(source)
    # Reciprocal of the size of the third dimension of the histogram tensor.
    uniform = 1.0 / extracted[0].shape[2]
    mean_deviation = torch.mean(extracted[0] - uniform)
    expected = pytest.approx(0.0, abs=1e-6)
    assert expected == mean_deviation
示例#6
0
def test_extract_deterministic(source):
    """Two extractions from the same ``source`` must give element-wise equal histograms."""
    # Run the extraction twice, in order, before comparing anything.
    first, second = (histogram.extract_histograms(source) for _ in range(2))
    identical = first[0] == second[0]
    assert identical.all()