def famosGeneration(content, noise, templatePatch, bVis=False):
    """Run the mixing network and blend its outputs into the final image.

    The output of ``netMix`` is split along dim 1. Its last five channels
    carry a 3-channel image (squashed with tanh) followed by two blending
    masks (squashed with sigmoid); all channels before them are converted
    into softmax weights used by ``getTemplateMixImage``. When
    ``opt.refine`` is set, ``netRefine`` produces a replacement for those
    five channels and the blend is recomputed against the same template mix.

    Args:
        content: content input handed to ``netMix`` (and to ``netRefine``).
        noise: noise input handed to ``netMix`` (and to ``netRefine``).
        templatePatch: templates given to ``getTemplateMixImage``; also
            forwarded to ``netMix`` when ``opt.multiScale > 0``.
        bVis: when true, the intermediate tensors are returned as well.

    Returns:
        The blended image, or, when ``bVis`` is true, the tuple
        ``(fake, masks, A, mixed)`` where ``masks`` concatenates alpha,
        beta and their average along dim 1.
    """

    def _blendFromChannels(channels, templateMix):
        ## channel 3 -> alpha, channel 4 -> beta, channels 0..2 -> image
        alphaMask = nn.functional.sigmoid(channels[:, 3:4])
        betaMask = nn.functional.sigmoid(channels[:, 4:5])
        image = blend(nn.functional.tanh(channels[:, :3]), templateMix,
                      alphaMask, betaMask)
        return image, alphaMask, betaMask

    ## the multi-scale mixer additionally consumes the templates
    if opt.multiScale > 0:
        netOut = netMix(content, noise, templatePatch)
    else:
        netOut = netMix(content, noise)

    tail = netOut[:, -5:]

    ## bound the logits to (-4, 4), then shift by the (detached) global max
    ## before the softmax over dim 1
    logits = 4 * nn.functional.tanh(netOut[:, :-5])
    A = nn.functional.softmax(1 * (logits - logits.detach().max()), dim=1)
    mixed = getTemplateMixImage(A, templatePatch)

    fake, alpha, beta = _blendFromChannels(tail, mixed)

    ## optional second network refines the first blend
    if opt.refine:
        refineIn = torch.cat(
            [content, mixed, fake, tail[:, :3], tvArray(A)], 1)
        tail = netRefine(refineIn, noise)
        fake, alpha, beta = _blendFromChannels(tail, mixed)

    if not bVis:
        return fake

    masks = torch.cat([alpha, beta, (alpha + beta) * 0.5], 1)
    return fake, masks, A, mixed
def forward(self, input1, input2=None, M=None):
    """Encode ``input1``, then decode while mixing in downscaled templates.

    Args:
        input1: network input. Channels 3:5 (called "only coords" by the
            original author) are kept aside and, when ``self.bCopyIn`` is
            set, average-pooled and re-appended at every encoder level.
        input2: optional second input. If the module-level flag
            ``bfirstNoise`` is set it is upsampled by ``2 ** self.nDep``
            and concatenated with ``input1`` before encoding; otherwise it
            is concatenated with the encoder output at the bottleneck.
        M: template tensor. It is viewed as ``(-1, 3, H, W)`` using its
            dims 3 and 4, so a 5-D tensor with 3 channels on dim 2 is
            presumably expected (TODO confirm against callers). Despite
            the ``None`` default it is required: the method dereferences
            it unconditionally.

    Returns:
        The output of the last decoder block (with the template mix
        appended whenever that block is not past index ``nDep - 2``).
    """
    if bfirstNoise and input2 is not None:
        ## nn.functional.upsample is deprecated; interpolate is its
        ## documented equivalent. align_corners stays at its default so the
        ## numerical result is unchanged.
        upNoise = nn.functional.interpolate(
            input2, scale_factor=2 ** self.nDep, mode='bilinear')
        x = torch.cat([input1, upNoise], 1)
        input2 = None  ## already consumed, do not append at the bottleneck
    else:
        x = input1  ##initial input

    ## ---- encoder ----
    skips = []
    input1 = input1[:, 3:5]  ##only coords
    for i in range(self.nDep):
        if i > 0 and self.bCopyIn:
            ## bring the kept channels to the current resolution
            input1 = nn.functional.avg_pool2d(input1, 2)
            x = torch.cat([x, input1], 1)
        x = self.eblocks[i].forward(x)
        if i != self.nDep - 1:
            if self.bCopyIn:
                ## pooled once more for the skip tensor, without updating
                ## input1 itself
                pooled = nn.functional.avg_pool2d(input1, 2)
                skips.append(torch.cat([x, pooled], 1))
            else:
                skips.append(x)

    ## ---- bottleneck ----
    if input2 is not None:
        ##the det. output and the noise appended
        x = torch.cat((x, input2), 1)

    ## ---- template pyramid (no gradients) ----
    with torch.no_grad():
        height, width = M.shape[3], M.shape[4]
        flatM = M.view(-1, 3, height, width)  ##only RGB images
        ## Built coarsest first (scale 2**(nDep-1) down to 2) so that
        ## mFeat[i] lines up with decoder block i; the unscaled M itself is
        ## not part of the list.
        mFeat = []
        for level in range(self.nDep - 1, 0, -1):
            sc = 2 ** level
            pooledM = nn.functional.avg_pool2d(flatM, sc)
            mFeat.append(pooledM.view(
                M.shape[0], M.shape[1], 3, height // sc, width // sc))

    ## ---- decoder ----
    for i, dblock in enumerate(self.dblocks):
        x = dblock.forward(x)
        if i < self.nDep - 1:
            if self.bSkip:
                x = torch.cat((x, skips[-1 - i]), 1)
            ## the first opt.N channels act as mixing logits: bounded to
            ## (-4, 4), shifted by the detached global max, softmaxed
            blendA = 4 * nn.functional.tanh(x[:, :opt.N])
            blendA = nn.functional.softmax(
                1 * (blendA - blendA.detach().max()), dim=1)
            mixed = getTemplateMixImage(blendA, mFeat[i])
            x = torch.cat((x, mixed), 1)
    return x