Example #1
def style_transfer_ori(vgg,
                       decoder,
                       fc_encoder,
                       fc_decoder,
                       content,
                       style,
                       alpha=1.0,
                       interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    content_feat = vgg(content)
    style_feat = vgg(style)
    style_feat_mean_std = calc_feat_mean_std(style_feat)
    intermediate = fc_encoder(style_feat_mean_std)
    intermediate_mean = intermediate[:, :512]
    intermediate_std = intermediate[:, 512:]
    # reparameterization trick: sample around the predicted statistics
    noise = torch.randn_like(intermediate_mean)
    sampling = intermediate_mean + noise * intermediate_std  # (N, 512)
    style_feat_mean_std_recons = fc_decoder(sampling)  # (N, 1024), unused below
    if interpolation_weights:
        _, C, H, W = content_feat.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        base_feat = adaptive_instance_normalization(content_feat, style_feat)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_feat = content_feat[0:1]
    else:
        feat = adaptive_instance_normalization(content_feat, style_feat)

    feat = feat * alpha + content_feat * (1 - alpha)
    return decoder(feat)
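Several of these snippets call helpers they never define. As a point of reference, here is a minimal sketch of adaptive_instance_normalization and the statistics helpers the code above implies; the eps default and the exact flattening in calc_feat_mean_std are assumptions, not taken from any of these repos:

import torch

def calc_mean_std(feat, eps=1e-5):
    # channel-wise mean and std over the spatial dims of an (N, C, H, W) tensor
    n, c = feat.size()[:2]
    feat_var = feat.view(n, c, -1).var(dim=2) + eps
    feat_std = feat_var.sqrt().view(n, c, 1, 1)
    feat_mean = feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    return feat_mean, feat_std

def adaptive_instance_normalization(content_feat, style_feat):
    # AdaIN: re-normalize the content features so their channel-wise
    # statistics match those of the style features
    c_mean, c_std = calc_mean_std(content_feat)
    s_mean, s_std = calc_mean_std(style_feat)
    normalized = (content_feat - c_mean) / c_std
    return normalized * s_std + s_mean

def calc_feat_mean_std(feat, eps=1e-5):
    # flatten the per-channel statistics into one (N, 2*C) vector, matching
    # the (N, 1024) shape Example #1 slices into mean and std halves
    mean, std = calc_mean_std(feat, eps)
    n, c = mean.size()[:2]
    return torch.cat([mean.view(n, c), std.view(n, c)], dim=1)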
Example #2
    def transfer_helper(self,
                        style,
                        content,
                        alpha=1.0,
                        interpolation_weights=None):
        """Helper method for style transfer.

        Keyword arguments:
        style -- a torch.autograd.variable.Variable with the style image
        content -- a torch.autograd.variable.Variable with the content image
        alpha -- content-style trade-off in [0, 1]; 1.0 means fully stylized
        interpolation_weights -- optional list of weights used to interpolate
            between several style images

        Returns:
        a torch.FloatTensor with the transferred image
        """

        assert (0.0 <= alpha <= 1.0)
        content_f = self.vgg(content)
        style_f = self.vgg(style)
        if interpolation_weights:
            _, C, H, W = content_f.size()
            feat = Variable(torch.FloatTensor(1, C, H, W).zero_().cuda(),
                            volatile=True)
            base_feat = adaptive_instance_normalization(content_f, style_f)
            for i, w in enumerate(interpolation_weights):
                feat = feat + w * base_feat[i:i + 1]
            content_f = content_f[0:1]
        else:
            feat = adaptive_instance_normalization(content_f, style_f)
        feat = feat * alpha + content_f * (1 - alpha)
        return self.decoder(feat)
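A side note on this example: Variable(..., volatile=True) is pre-0.4 PyTorch. The volatile flag corresponds to running the whole computation under torch.no_grad(), and the zero buffer itself reduces to a one-liner on current PyTorch (a sketch, not from the original repo):

# modern equivalent of the volatile Variable above, assuming the whole
# method runs under torch.no_grad()
feat = torch.zeros(1, C, H, W, device='cuda')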
Example #3
def style_transfer(vgg,
                   decoder,
                   content,
                   style,
                   alpha=1.0,
                   conv1=None,
                   conv2=None,
                   interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    if conv1 is not None:
        # self-attention over the spatial positions of the content features
        shape = content_f.shape
        mat = conv1(content_f).view(shape[0], shape[1] // 2, -1)
        another_mat = conv2(content_f).view(shape[0], shape[1], -1)
        mat = torch.matmul(mat.transpose(1, 2), mat)
        mat = nn.functional.softmax(mat, dim=2)
        att = torch.matmul(another_mat, mat).reshape(shape)
        # residual application of the attention map, disabled in the source:
        # content_f = att * content_f + content_f
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
Example #4
def style_transfer(vgg, decoder, content, style, alpha=1.0, safin_list=None,
                   interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)[-1]  # keep only the last of the returned feature maps
    style_f = vgg(style)[-1]
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        if safin_list:
            skips = {}
            transformed_f = vgg.encode_transform(safin_list[0], content, style, skips)
            base_feat = safin_list[1](transformed_f, style_f)
        else:
            base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        if safin_list:
            skips = {}
            transformed_f = vgg.encode_transform(safin_list[0], content, style, skips)
            feat = safin_list[1](transformed_f, style_f)
        else:
            feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    if safin_list:
        return decoder(feat, skips)
    return decoder(feat)
Example #5
def style_transfer(vgg, decoder, content, style, alpha, interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
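The interpolation loop that recurs throughout these examples is just a weighted sum over the batch dimension of base_feat. An equivalent vectorized form, assuming interpolation_weights is a plain list of floats:

w = torch.tensor(interpolation_weights,
                 device=base_feat.device).view(-1, 1, 1, 1)
feat = (w * base_feat).sum(dim=0, keepdim=True)  # (1, C, H, W)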
Example #6
def style_transfer(vgg, decoder, content, style, alpha=1.0):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
Example #7
    def _transform(self, img, param):
        alpha, style_f = param
        with torch.no_grad():
            content_f = self.vgg(img.to(self.enc_device))
        feat = adaptive_instance_normalization(content_f, style_f)
        feat = alpha * feat + (1 - alpha) * content_f
        feat = feat.to(self.dec_device)  # decoder may live on a different device
        stylized = self.decoder(feat)
        stylized = stylized.to(self.enc_device)
        return stylized
Example #8
    def transfer(self,
                 content,
                 style,
                 preserve_color=False,
                 alpha=1.0,
                 interpolation_weights=None):
        """
        CONTENT is always a single image.
        STYLE can be either a single image or a list of images.
        If STYLE is a list of images, you must also pass a list of INTERPOLATION_WEIGHTS.
        """
        if interpolation_weights:
            # one content image, N style images
            style = torch.stack([self.style_tf(s) for s in style])
            content = self.content_tf(content).unsqueeze(0).expand_as(style)
            style = style.to(self.device)
            content = content.to(self.device)
        else:
            # one content image, one style image
            content = self.content_tf(content)
            style = self.style_tf(style)
            if preserve_color:
                style = coral(style, content)
            style = style.to(self.device).unsqueeze(0)
            content = content.to(self.device).unsqueeze(0)

        with torch.no_grad():
            content_f = self.vgg(content)
            style_f = self.vgg(style)

            if interpolation_weights:
                _, C, H, W = content_f.size()
                feat = torch.FloatTensor(1, C, H, W).zero_().to(self.device)
                base_feat = adaptive_instance_normalization(content_f, style_f)
                for i, w in enumerate(interpolation_weights):
                    feat = feat + w * base_feat[i:i + 1]
                content_f = content_f[0:1]
            else:
                feat = adaptive_instance_normalization(content_f, style_f)
            feat = feat * alpha + content_f * (1 - alpha)
            output = self.decoder(feat)
        return output.cpu()
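Going by the docstring above, a call with a single style versus an interpolated set would look roughly like this; model, content_img and the style images are placeholders, not names from the original code:

# one content image, one style image; preserve_color matches the style
# image's colors to the content via CORAL before transfer
out = model.transfer(content_img, style_img, alpha=0.8, preserve_color=True)

# one content image, four style images blended 40/30/20/10
out = model.transfer(content_img, [s1, s2, s3, s4],
                     interpolation_weights=[0.4, 0.3, 0.2, 0.1])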
Example #9
def style_transfer(vgg,
                   decoder,
                   content,
                   style,
                   alpha=1.0,
                   interpolation_weights=None):
    # NOTE: interpolation_weights is accepted but never used in this variant
    content_f = vgg(content)
    style_f = vgg(style)
    feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
Example #10
def style_transfer(vgg, decoder, content, style, alpha=1.0,
                   interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    if args.mode in ["SE64x+BD", "SE16x+BD", "E2D1"]:
        content_f = vgg.forward_aux(content, False)[-1]
        style_f = vgg.forward_aux(style, False)[-1]
    else:
        content_f = vgg(content)
        style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
Example #11
def style_transfer(vgg,
                   decoder,
                   content,
                   style,
                   alpha=1.0,
                   interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    style_f = vgg(style)
    if content_f.size() != style_f.size():
        # report mismatched feature shapes; AdaIN itself only matches
        # channel-wise statistics, so differing spatial sizes still work
        print(content_f.size())
        print(style_f.size())
    feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
Example #12
    def _transform(self, img, param):
        alpha, style = param
        # handle undersized batches (last dataloader iteration) or no batching at all
        style = style[:img.size(0)]
        style = style.to(self.enc_device)
        with torch.no_grad():
            content_f = self.vgg(img)
            style_f = self.vgg(style)
            feat = adaptive_instance_normalization(content_f, style_f)

        feat = alpha * feat + (1 - alpha) * content_f
        if self.enc_device != self.dec_device:
            feat = feat.to(self.dec_device)
        stylized = self.decoder(feat)
        if self.enc_device != self.dec_device:
            stylized = stylized.to(self.enc_device)
        return stylized
Example #13
def style_transfer(vgg, decoder, abstracter, corrector, content, style, alpha=1.0, interpolation_weights=None):
    assert (0.0 <= alpha <= 1.0)
    content_f = vgg(content)
    # style_gen = vgg(style)
    style_gen = abstracter.execute(content_f)  # attention: content -> style
    correct_f = corrector.execute(content_f)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        base_feat = adaptive_instance_normalization(content_f, style_gen)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        # feat = single_adaptive_instance_normalization(content_f, style_gen)
        feat = correct_adaptive_instance_normalization(content_f, style_gen, correct_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)