def style_transfer_ori(vgg, decoder, fc_encoder, fc_decoder, content, style,
                       alpha=1.0, interpolation_weights=None):
    """Stylize ``content`` with AdaIN using the features of ``style``.

    The style features' statistics (``calc_feat_mean_std``) are also passed
    through ``fc_encoder`` -> reparameterized sample -> ``fc_decoder``, but
    that reconstruction does not feed into the returned image.

    Args:
        vgg: encoder applied to both ``content`` and ``style``.
        decoder: maps the blended feature map back to image space.
        fc_encoder: applied to the style statistics; its output is split at
            column 512 into a mean half and a std half.
        fc_decoder: applied to the sampled code.
        content: content batch.
        style: style batch.
        alpha: blend in [0, 1] between stylized (1.0) and content (0.0)
            features.
        interpolation_weights: optional per-style weights; when given, the
            AdaIN results for each style in the batch are combined into a
            single feature map and only the first content feature is kept.

    Returns:
        ``decoder(feat)`` for the blended feature map.
    """
    assert 0.0 <= alpha <= 1.0
    content_feat = vgg(content)
    style_feat = vgg(style)

    style_feat_mean_std = calc_feat_mean_std(style_feat)
    intermediate = fc_encoder(style_feat_mean_std)
    intermediate_mean = intermediate[:, :512]
    intermediate_std = intermediate[:, 512:]
    noise = torch.randn_like(intermediate_mean)
    sampling = intermediate_mean + noise * intermediate_std  # N, 512
    # NOTE(review): the reconstructed statistics (N, 1024) are never used
    # below; confirm whether they were meant to replace the raw style stats.
    fc_decoder(sampling)

    if interpolation_weights:
        _, C, H, W = content_feat.size()
        # Match the features' dtype/device instead of forcing float32 and
        # relying on a module-level ``device``.
        feat = content_feat.new_zeros((1, C, H, W))
        base_feat = adaptive_instance_normalization(content_feat, style_feat)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_feat = content_feat[0:1]
    else:
        feat = adaptive_instance_normalization(content_feat, style_feat)
    feat = feat * alpha + content_feat * (1 - alpha)
    return decoder(feat)
def transfer_helper(self, style, content, alpha=1.0, interpolation_weights=None):
    """Helper method for style transfer.

    Keyword arguments:
    style -- a torch.autograd.variable.Variable with the style image(s)
    content -- a torch.autograd.variable.Variable with the content image(s)
    alpha -- blend factor in [0, 1] between the AdaIN-stylized features (1.0)
        and the original content features (0.0)
    interpolation_weights -- optional sequence of per-style weights; when
        given, the AdaIN result for each style in the batch is weighted and
        summed into a single feature map, and only the first content feature
        is used for the alpha blend

    Returns: the output of ``self.decoder`` for the blended features
    """
    assert (0.0 <= alpha <= 1.0)
    content_f = self.vgg(content)
    style_f = self.vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        # ``.data.new`` yields a tensor of the same type and device as the
        # features, so this no longer requires CUDA (was ``.cuda()``).
        feat = Variable(content_f.data.new(1, C, H, W).zero_(), volatile=True)
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return self.decoder(feat)
def style_transfer(vgg, decoder, content, style, alpha=1.0, conv1=None,
                   conv2=None, interpolation_weights=None):
    """Stylize ``content`` with AdaIN using the features of ``style``.

    Args:
        vgg: encoder applied to both ``content`` and ``style``.
        decoder: maps the blended feature map back to image space.
        content: content batch.
        style: style batch.
        alpha: blend in [0, 1] between stylized (1.0) and content (0.0)
            features.
        conv1, conv2: accepted for backward compatibility only. A previous
            version built a self-attention map from them but never applied it
            to the content features, so they do not affect the result.
        interpolation_weights: optional per-style weights; when given, the
            AdaIN results for each style in the batch are combined into a
            single feature map and only the first content feature is kept.

    Returns:
        ``decoder(feat)`` for the blended feature map.
    """
    assert 0.0 <= alpha <= 1.0
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        # Match the features' dtype/device instead of forcing float32 and
        # relying on a module-level ``device``.
        feat = content_f.new_zeros((1, C, H, W))
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def style_transfer(vgg, decoder, content, style, alpha=1.0, safin_list=None,
                   interpolation_weights=None):
    """Stylize ``content`` with either AdaIN or a SAFIN pair.

    Args:
        vgg: encoder; its output is indexed with ``[-1]`` to take the last
            feature map. Also provides ``encode_transform`` for SAFIN mode.
        decoder: called as ``decoder(feat, skips)`` in SAFIN mode, otherwise
            ``decoder(feat)``.
        content: content batch.
        style: style batch.
        alpha: blend in [0, 1] between stylized (1.0) and content (0.0)
            features.
        safin_list: optional pair; ``safin_list[0]`` is passed to
            ``vgg.encode_transform`` and ``safin_list[1]`` maps the
            transformed features plus style features to the stylized map.
        interpolation_weights: optional per-style weights; when given, the
            stylized maps for each style in the batch are combined into a
            single feature map and only the first content feature is kept.

    Returns:
        The decoder output for the blended feature map.
    """
    assert 0.0 <= alpha <= 1.0
    content_f = vgg(content)[-1]
    style_f = vgg(style)[-1]

    # Stylize once; previously this selection was duplicated in both the
    # interpolation and single-style branches.
    skips = None
    if safin_list:
        skips = {}  # presumably populated by encode_transform for the decoder
        transformed_f = vgg.encode_transform(safin_list[0], content, style, skips)
        base_feat = safin_list[1](transformed_f, style_f)
    else:
        base_feat = adaptive_instance_normalization(content_f, style_f)

    if interpolation_weights:
        _, C, H, W = content_f.size()
        # Match the features' dtype/device instead of forcing float32 and
        # relying on a module-level ``device``.
        feat = content_f.new_zeros((1, C, H, W))
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = base_feat

    feat = feat * alpha + content_f * (1 - alpha)
    if safin_list:
        return decoder(feat, skips)
    return decoder(feat)
def style_transfer(vgg, decoder, content, style, alpha, interpolation_weights=None):
    """Stylize ``content`` with AdaIN using the features of ``style``.

    Args:
        vgg: encoder applied to both ``content`` and ``style``.
        decoder: maps the blended feature map back to image space.
        content: content batch.
        style: style batch.
        alpha: required blend in [0, 1] between stylized (1.0) and content
            (0.0) features.
        interpolation_weights: optional per-style weights; when given, the
            AdaIN results for each style in the batch are combined into a
            single feature map and only the first content feature is kept.

    Returns:
        ``decoder(feat)`` for the blended feature map.
    """
    assert 0.0 <= alpha <= 1.0
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        # Match the features' dtype/device instead of forcing float32 and
        # relying on a module-level ``device``.
        feat = content_f.new_zeros((1, C, H, W))
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def style_transfer(vgg, decoder, content, style, alpha=1.0):
    """Render ``content`` in the style of ``style`` via AdaIN.

    Both inputs are encoded with ``vgg``. The AdaIN output is linearly mixed
    with the raw content encoding (``alpha`` = 1 keeps only the stylized
    features, 0 keeps only the content features) and then decoded.
    """
    assert 0.0 <= alpha <= 1.0
    encoded_content = vgg(content)
    encoded_style = vgg(style)
    target = adaptive_instance_normalization(encoded_content, encoded_style)
    blended = target * alpha + encoded_content * (1 - alpha)
    return decoder(blended)
def _transform(self, img, param):
    """Stylize ``img`` using precomputed style features.

    ``param`` unpacks to ``(alpha, style_f)``. The image is encoded on
    ``self.enc_device``, and the mixed features are decoded on
    ``self.dec_device``. The result is returned on ``self.enc_device``.
    """
    alpha, style_f = param
    with torch.no_grad():
        encoded = self.vgg(img.to(self.enc_device))
        adain_out = adaptive_instance_normalization(encoded, style_f)
        mixed = alpha * adain_out + (1 - alpha) * encoded
        decoded = self.decoder(mixed.to(self.dec_device))
        result = decoded.to(self.enc_device)
    return result
def transfer(self, content, style, preserve_color=False, alpha=1.0, interpolation_weights=None):
    """
    CONTENT is always a single image.
    STYLE can be either a single image or a list of images.
    If STYLE is a list of images, you must also pass a list of
    INTERPOLATION_WEIGHTS (one weight per style image).

    PRESERVE_COLOR applies ``coral(style, content)`` to each transformed style
    image before encoding. This happens in both the single-style and the
    interpolation modes (the interpolation mode previously ignored it).

    ALPHA in [0, 1] blends stylized (1.0) and original content (0.0)
    features.

    Returns the decoder output moved to the CPU.
    """
    content_t = self.content_tf(content)
    if interpolation_weights:
        # one content image, N style images
        styles = [self.style_tf(s) for s in style]
        if preserve_color:
            styles = [coral(s, content_t) for s in styles]
        style_t = torch.stack(styles)
        content_t = content_t.unsqueeze(0).expand_as(style_t)
    else:
        # one content image, one style image
        style_t = self.style_tf(style)
        if preserve_color:
            style_t = coral(style_t, content_t)
        style_t = style_t.unsqueeze(0)
        content_t = content_t.unsqueeze(0)
    style_t = style_t.to(self.device)
    content_t = content_t.to(self.device)

    with torch.no_grad():
        content_f = self.vgg(content_t)
        style_f = self.vgg(style_t)
        if interpolation_weights:
            _, C, H, W = content_f.size()
            # Match the features' dtype/device instead of forcing float32.
            feat = content_f.new_zeros((1, C, H, W))
            base_feat = adaptive_instance_normalization(content_f, style_f)
            for i, w in enumerate(interpolation_weights):
                feat = feat + w * base_feat[i:i + 1]
            content_f = content_f[0:1]
        else:
            feat = adaptive_instance_normalization(content_f, style_f)
        feat = feat * alpha + content_f * (1 - alpha)
        output = self.decoder(feat)
    return output.cpu()
def style_transfer(vgg, decoder, content, style, alpha=1.0, interpolation_weights=None):
    """Stylize ``content`` with AdaIN using the features of ``style``.

    Args:
        vgg: encoder applied to both ``content`` and ``style``.
        decoder: maps the blended feature map back to image space.
        content: content batch.
        style: style batch.
        alpha: blend between stylized (1.0) and content (0.0) features. This
            variant does not validate its range.
        interpolation_weights: optional per-style weights; when given, the
            AdaIN results for each style in the batch are combined into a
            single feature map and only the first content feature is kept.
            (Previously this argument was accepted but ignored.)

    Returns:
        ``decoder(feat)`` for the blended feature map.
    """
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = content_f.new_zeros((1, C, H, W))
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def style_transfer(vgg, decoder, content, style, alpha=1.0, interpolation_weights=None):
    """Stylize ``content`` with AdaIN using the features of ``style``.

    The encoder path depends on the module-level ``args.mode``. For modes
    "SE64x+BD", "SE16x+BD" and "E2D1", the last output of
    ``vgg.forward_aux(x, False)`` is used. Otherwise ``vgg(x)`` is used.

    Args:
        vgg: encoder applied to both ``content`` and ``style``.
        decoder: maps the blended feature map back to image space.
        content: content batch.
        style: style batch.
        alpha: blend in [0, 1] between stylized (1.0) and content (0.0)
            features.
        interpolation_weights: optional per-style weights; when given, the
            AdaIN results for each style in the batch are combined into a
            single feature map and only the first content feature is kept.

    Returns:
        ``decoder(feat)`` for the blended feature map.
    """
    assert 0.0 <= alpha <= 1.0
    if args.mode in ["SE64x+BD", "SE16x+BD", "E2D1"]:
        content_f = vgg.forward_aux(content, False)[-1]
        style_f = vgg.forward_aux(style, False)[-1]
    else:
        content_f = vgg(content)
        style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        # Match the features' dtype/device instead of forcing float32 and
        # relying on a module-level ``device``.
        feat = content_f.new_zeros((1, C, H, W))
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def style_transfer(vgg, decoder, content, style, alpha=1.0, interpolation_weights=None):
    """Stylize ``content`` with AdaIN using the features of ``style``.

    Content and style feature maps may legitimately differ in spatial size.
    This function therefore no longer prints their sizes when they differ
    (that was leftover debug output).

    Args:
        vgg: encoder applied to both ``content`` and ``style``.
        decoder: maps the blended feature map back to image space.
        content: content batch.
        style: style batch.
        alpha: blend in [0, 1] between stylized (1.0) and content (0.0)
            features.
        interpolation_weights: optional per-style weights; when given, the
            AdaIN results for each style in the batch are combined into a
            single feature map and only the first content feature is kept.
            (Previously this argument was accepted but ignored.)

    Returns:
        ``decoder(feat)`` for the blended feature map.
    """
    assert 0.0 <= alpha <= 1.0
    content_f = vgg(content)
    style_f = vgg(style)
    if interpolation_weights:
        _, C, H, W = content_f.size()
        feat = content_f.new_zeros((1, C, H, W))
        base_feat = adaptive_instance_normalization(content_f, style_f)
        for i, w in enumerate(interpolation_weights):
            feat = feat + w * base_feat[i:i + 1]
        content_f = content_f[0:1]
    else:
        feat = adaptive_instance_normalization(content_f, style_f)
    feat = feat * alpha + content_f * (1 - alpha)
    return decoder(feat)
def _transform(self, img, param):
    """Stylize ``img`` with a batch of style images.

    ``param`` unpacks to ``(alpha, style)``. Only as many style images as
    ``img`` has batch entries are used. Features are computed on
    ``self.enc_device`` and decoded on ``self.dec_device``; data moves
    between devices only when the two differ.
    """
    alpha, style = param
    # Trim the style batch to the image batch: this covers an undercomplete
    # final dataloader batch, or no batching at all.
    style = style[:img.size(0)].to(self.enc_device)
    split_devices = self.enc_device != self.dec_device
    with torch.no_grad():
        encoded_img = self.vgg(img)
        encoded_style = self.vgg(style)
        adain_out = adaptive_instance_normalization(encoded_img, encoded_style)
        mixed = alpha * adain_out + (1 - alpha) * encoded_img
        if split_devices:
            mixed = mixed.to(self.dec_device)
        result = self.decoder(mixed)
        if split_devices:
            result = result.to(self.enc_device)
    return result
def style_transfer(vgg, decoder, abstracter, corrector, content, style,
                   alpha=1.0, interpolation_weights=None):
    """Stylize ``content`` using a style feature generated from the content.

    ``style`` is not encoded. Instead, ``abstracter.execute`` produces a
    pseudo style feature from the content features ("attention content ->
    style"), and ``corrector.execute`` produces a correction term. Without
    interpolation weights, ``correct_adaptive_instance_normalization`` combines
    the three. With interpolation weights, plain AdaIN results are weighted
    and summed, and only the first content feature is kept.

    ``alpha`` in [0, 1] blends stylized (1.0) and content (0.0) features.
    """
    assert 0.0 <= alpha <= 1.0
    content_f = vgg(content)
    style_gen = abstracter.execute(content_f)
    correct_f = corrector.execute(content_f)

    if not interpolation_weights:
        feat = correct_adaptive_instance_normalization(content_f, style_gen, correct_f)
    else:
        _, C, H, W = content_f.size()
        feat = torch.FloatTensor(1, C, H, W).zero_().to(device)
        base_feat = adaptive_instance_normalization(content_f, style_gen)
        for idx, weight in enumerate(interpolation_weights):
            feat = feat + weight * base_feat[idx:idx + 1]
        content_f = content_f[0:1]

    return decoder(feat * alpha + content_f * (1 - alpha))