Code Example #1
File: mosaicFAMOS.py  Project: zhuwenxing/famos
def famosGeneration(content, noise, templatePatch, bVis=False):
    if opt.multiScale > 0:
        x = netMix(content, noise, templatePatch)
    else:
        x = netMix(content, noise)
    a5 = x[:, -5:]  ## last 5 channels: 3 for a painted RGB image, 1 for alpha, 1 for beta
    A = 4 * torch.tanh(x[:, :-5])  ## bound the mixing logits to [-4, 4] so the softmax stays smooth
    A = nn.functional.softmax(A - A.detach().max(), dim=1)  ## per-pixel mixing weights over the template patches (max subtracted for stability)
    mixed = getTemplateMixImage(A, templatePatch)
    alpha = torch.sigmoid(a5[:, 3:4])
    beta = torch.sigmoid(a5[:, 4:5])
    fake = blend(torch.tanh(a5[:, :3]), mixed, alpha, beta)

    ##optionally call a second U-Net to refine the blended result
    if opt.refine:
        a5 = netRefine(
            torch.cat([content, mixed, fake, a5[:, :3],
                       tvArray(A)], 1), noise)
        alpha = torch.sigmoid(a5[:, 3:4])
        beta = torch.sigmoid(a5[:, 4:5])
        fake = blend(torch.tanh(a5[:, :3]), mixed, alpha, beta)

    if bVis:
        ## also return the blend masks (stacked as an RGB visualization), the mixing map A, and the mixed template image
        return fake, torch.cat([alpha, beta, (alpha + beta) * 0.5], 1), A, mixed
    return fake
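The helpers getTemplateMixImage and blend are defined elsewhere in the project and are not shown on this page. As a rough guide to what the mixing and blending steps compute, here is a minimal sketch, assuming templatePatch has shape (B, N, 3, H, W), the mixing map A has shape (B, N, H, W), and alpha/beta are single-channel masks; the blend formula is an illustrative assumption, not necessarily the project's exact one.

import torch

def template_mix_sketch(A, templatePatch):
    ## convex combination of the N template patches, weighted per pixel by A:
    ## out[b, :, h, w] = sum_n A[b, n, h, w] * templatePatch[b, n, :, h, w]
    return (A.unsqueeze(2) * templatePatch).sum(dim=1)

def blend_sketch(painted, mixed, alpha, beta):
    ## hypothetical blend: alpha gates the freshly painted image, beta scales
    ## the copied template mix (assumption for illustration only)
    return alpha * painted + (1.0 - alpha) * beta * mixed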
Code Example #2
    def forward(self, input1, input2=None, M=None):
        if bfirstNoise and input2 is not None:  ## bfirstNoise is a module-level flag defined elsewhere in the project
            x = torch.cat([
                input1,
                nn.functional.interpolate(
                    input2, scale_factor=2**self.nDep, mode='bilinear',
                    align_corners=False)  ## upsample the noise to the full input resolution
            ], 1)
            input2 = None
        else:
            x = input1  ##initial input

        skips = []
        input1 = input1[:, 3:5]  ## keep only the two coordinate channels
        for i in range(self.nDep):
            if i > 0 and self.bCopyIn:
                input1 = nn.functional.avg_pool2d(input1, 2)
                x = torch.cat([x, input1], 1)
            x = self.eblocks[i].forward(x)
            if i != self.nDep - 1:
                if self.bCopyIn:
                    skips += [
                        torch.cat(
                            [x, nn.functional.avg_pool2d(input1, 2)], 1)
                    ]
                else:
                    skips += [x]
        bottle = x
        if input2 is not None:
            bottle = torch.cat((x, input2), 1)  ## append the noise to the deterministic encoder output
        x = bottle

        with torch.no_grad():
            MM = M.view(-1, 3, M.shape[3], M.shape[4])  ## flatten to a batch of RGB template images
            mFeat = []  ## multi-scale pyramid of average-pooled templates
            for i in range(1, self.nDep):
                sc = 2**i
                mFeat.append(
                    nn.functional.avg_pool2d(MM, sc).view(
                        M.shape[0], M.shape[1], 3, M.shape[3] // sc,
                        M.shape[4] // sc))
            mFeat = mFeat[::-1]  ## coarsest scale first, matching the decoder order

        for i in range(len(self.dblocks)):
            x = self.dblocks[i].forward(x)
            if i < self.nDep - 1 and self.bSkip:
                x = torch.cat((x, skips[-1 - i]), 1)
            if i < self.nDep - 1:
                blendA = 4 * torch.tanh(x[:, :opt.N])  ## bound the N mixing logits to [-4, 4]
                blendA = nn.functional.softmax(
                    blendA - blendA.detach().max(), dim=1)  ## mixing weights over the templates at this scale
                mixed = getTemplateMixImage(blendA, mFeat[i])
                x = torch.cat((x, mixed), 1)  ## append the mixed template image to the decoder features
        return x
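The torch.no_grad() block above pools the template stack M into one tensor per decoder depth so that template mixing can happen at every resolution. Below is a small standalone sketch of that shape bookkeeping with made-up sizes (B=2 content images, N=4 templates, 64x64 pixels, nDep=3); the names are hypothetical and only illustrate the pyramid construction.

import torch
import torch.nn.functional as F

B, N, C, H, W = 2, 4, 3, 64, 64
nDep = 3
M = torch.rand(B, N, C, H, W)      ## stack of N template patches per content image

MM = M.view(-1, C, H, W)           ## flatten to (B*N, 3, H, W) for pooling
pyramid = []
for i in range(1, nDep):
    sc = 2 ** i
    pooled = F.avg_pool2d(MM, sc)  ## (B*N, 3, H//sc, W//sc)
    pyramid.append(pooled.view(B, N, C, H // sc, W // sc))
pyramid = pyramid[::-1]            ## coarsest scale first, matching the decoder's upsampling order

for level, t in enumerate(pyramid):
    print(level, tuple(t.shape))   ## 0 -> (2, 4, 3, 16, 16), 1 -> (2, 4, 3, 32, 32)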