# od = torch.randperm(frames.shape[0])
            # frames = frames[od]
            # ink_frames = ink_frames[od]

            while now < frames.shape[0] - 1:
                end = min(frames.shape[0] - 1, now + 30)
                optF = PWCnet.estimate(pwc_model, frames[now:end] / 127.5 - 1,
                                       frames[now + 1:end + 1] / 127.5 - 1)
                toptF = optF.permute(0, 3, 1, 2)

                noptF = PWCnet.estimate(pwc_model,
                                        frames[now + 1:end + 1] / 127.5 - 1,
                                        frames[now:end] / 127.5 - 1)
                tnoptF = noptF.permute(0, 3, 1, 2)

                C = util.warp(tnoptF, optF)
                tmp = C + toptF
                C = (L2dis(tmp) < 0.01 *
                     (L2dis(C) + L2dis(toptF)) + 0.5).float()

                wframe = util.warp(ink_frames[now + 1:end + 1], optF)

                temploss += (
                    torch.sum(torch.abs((wframe - ink_frames[now:end])) * C) /
                    torch.sum(C)).item()
                now = end

                del optF, toptF, noptF, tnoptF, wframe, C, tmp
                # print(now,'/',frames.shape[0])
            print(fn, temploss / frame_num)
            all_loss += temploss
Ejemplo n.º 2
0
    def test(self):
        """Stylize every frame of ``self.frames`` using frame reordering.

        Frames are not processed in temporal order. At each step the frame
        minimising ``evalue`` is picked, its features from ``self.getF`` are
        fused with the stored features of up to ``maxr`` already processed
        frames (warped with the pre-computed optical flow and weighted by
        ``self.netM`` / ``self.netM2``), and the fused features are kept on
        the CPU. Once all frames are fused they are decoded one by one with
        ``self.netG_A_decoder``.

        Side effects: when ``self.opt.saliency`` is set, ``self.SA``,
        ``self.sa`` and ``self.box`` are refreshed from ``self.get_sa``.

        Returns:
            ``False`` if saliency is enabled and the saliency map of some
            frame sums to less than 0.001. Otherwise the 3-tuple
            ``(np.stack(images), None, None)`` where ``images`` are the
            ``util.tensor2im`` outputs in frame order.
        """
        # Inference only: everything leaving this method has gone through
        # util.tensor2im / np.stack, so no autograd graph is needed. This
        # mirrors ``forward``, which also estimates its flows under no_grad,
        # and keeps the features stored in ``newf`` from pinning graphs.
        with torch.no_grad():
            if self.opt.saliency:
                sa = list()
                box = list()
                for img in self.frames:
                    a, b = self.get_sa(img)
                    if torch.sum(a) < 0.001:
                        return False
                    sa.append(a)
                    box.append(torch.stack(b))
                self.SA = torch.cat(sa, dim=0)
                self.sa = sa[0][0].data.cpu()
                self.box = box
            print("start")

            # stylization with frame reordering
            print("start ink")
            num_frames = self.frames.shape[0]
            inked = set()                      # frames already fused
            uninked = set(range(num_frames))   # frames still to process
            allFopt = dict()                   # (i, j) -> down-scaled flow
            newf = [None] * num_frames         # fused features (on CPU)

            # only frames at most ``maxl`` indices apart are compared
            maxl = 50
            # max number of reference frames
            maxr = 7
            # Flows are resized to 1/4 of the frame resolution (presumably
            # the resolution of the feature maps they warp -- TODO confirm).
            flow_size = (self.frames.shape[2] // 4, self.frames.shape[3] // 4)

            def calc_dis(i, j, A, B):
                """Distance between frames i and j; caches the scaled flow."""
                if abs(i - j) > maxl:
                    return 1e10
                optF = PWCnet.estimate(self.pwc_model, A, B)
                toptF = optF.permute(0, 3, 1, 2)
                small = nn.functional.interpolate(toptF,
                                                  size=flow_size,
                                                  mode="bilinear")
                # rescale the flow vectors by the same ratio as the grid
                small[:, 0, :, :] = (small[:, 0, :, :] * small.shape[3] /
                                     toptF.shape[3])
                small[:, 1, :, :] = (small[:, 1, :, :] * small.shape[2] /
                                     toptF.shape[2])
                allFopt[(i, j)] = small.permute(0, 2, 3, 1)
                return torch.mean(torch.sqrt(L2dis(optF)))

            dis = dict()
            for i in range(num_frames):
                print("calc optical flow {}/{}".format(i, num_frames))
                for j in range(i + 1, num_frames):
                    dis[(i, j)] = calc_dis(i, j, self.frames[i:i + 1],
                                           self.frames[j:j + 1])
                    dis[(j, i)] = calc_dis(j, i, self.frames[j:j + 1],
                                           self.frames[i:i + 1])

            def evalue(x):
                """Score of frame x; the lowest score is processed next."""
                val1 = sum(dis[(x, i)] for i in inked)
                val2 = sum(dis[(x, i)] for i in uninked if i != x)
                return (-val1 / (len(inked) + 1) +
                        self.alpha * val2 / (len(uninked) + 1))

            while uninked:
                # min() always yields a member of ``uninked``, so the
                # remove() below cannot fail even for degenerate scores.
                idx = min(uninked, key=evalue)
                print("deal {} ...".format(idx))

                # nearest already processed frames, closest first
                l = sorted((dis[(i, idx)], i) for i in inked
                           if abs(i - idx) <= maxl)[:maxr]
                now = [newf[i].cuda() for _, i in l]
                F, _, _, _ = self.getF(idx, idx + 1)
                print("referrence:", l)

                if len(now) > 0:
                    now = torch.cat(now, dim=0)
                    FoptF = torch.cat([allFopt[(idx, i)] for _, i in l],
                                      dim=0)
                    now = util.warp(now, FoptF)
                    # per-reference weights, normalised over the references
                    score = torch.softmax(self.netM(now - F), dim=0)
                    oF = torch.sum(now * score, dim=0, keepdim=True)
                    # blend weight in [0, 1] between reference and own feature
                    score = (torch.tanh(self.netM2(oF - F)) + 1) / 2
                    oF = oF * score + F * (1 - score)
                    del score, FoptF, now
                else:
                    oF = F
                del F
                newf[idx] = oF.cpu()
                del oF
                inked.add(idx)
                uninked.remove(idx)

            # drop the cached distances / flows before decoding
            del dis, allFopt

            fin = list()
            for i, feat in enumerate(newf):
                feat = feat.cuda()
                # Same decoder call convention as backward_G / backward_M:
                # the saliency map is only passed when saliency is enabled.
                if self.opt.saliency:
                    ink = self.netG_A_decoder(feat, self.SA[i:i + 1])
                else:
                    ink = self.netG_A_decoder(feat)
                fin.append(util.tensor2im(ink))
                del ink, feat
            del newf
            return np.stack(fin), None, None
Ejemplo n.º 3
0
    def backward_G(self, lambda_sup, lambda_newloss):
        """Compute the generator losses for the current batch and backpropagate.

        The total loss is the sum of the forward cycle loss (skipped when
        ``opt.lambda_A`` is 0), the GAN loss of ``netG_B`` against
        ``netD_B``, the backward cycle loss, the optional warp loss and the
        value returned by ``self.clac_GA_common_loss``. The crop loss is
        computed and recorded in ``self.loss_crop`` but, as in the original
        code, is not added to the total here.

        Args:
            lambda_sup: forwarded unchanged to ``clac_GA_common_loss``.
            lambda_newloss: multiplier for the crop and warp loss weights;
                also forwarded to ``clac_GA_common_loss``.

        Side effects: stores scalar losses and detached / CPU tensors for
        display on ``self`` (``loss_*``, ``fake_A``, ``fake_B``, ``rec_A``,
        ``rec_B``, ``warp_error`` ...).

        Raises:
            ValueError: if the total loss is NaN. The check runs before
                ``backward()`` so no NaN gradients are accumulated.
        """
        # optimize network G

        self.input_A = self.frames[0:1]

        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B

        # stylization
        if self.opt.inpaint:
            F, F_horse, SA_horse, pos = self.getF()
        else:
            F = self.getF()

        if self.opt.saliency:
            inked = self.netG_A_decoder(F, self.SA)
            fake_B = inked[:1]

            # Crop Loss
            if self.opt.inpaint and abs(self.opt.lambda_crop * lambda_newloss) > 0:
                ink_horse = self.netG_A_decoder(F_horse[0], SA_horse[0])
                a, b, c, d = pos[0]

                # resize the separately decoded crop to its box (a, b, c, d)
                ink_horse = nn.functional.interpolate(
                    ink_horse, size=(d-b, c-a), mode="bilinear")
                self.boxed = ink_horse

                loss_crop = torch.mean(torch.pow(
                    fake_B[:, :, b:d, a:c] - ink_horse, 2)) * self.opt.lambda_crop * lambda_newloss
                self.loss_crop = loss_crop.item()
            else:
                loss_crop = self.loss_crop = 0
        else:
            # The original also decoded an undefined ``F2`` into an unused
            # ``fake_rev`` here, which raised NameError; that call is gone.
            inked = self.netG_A_decoder(F)
            fake_B = inked[:1]
            loss_crop = self.loss_crop = 0
        self.fake_B = fake_B.detach()

        # GAN loss D_B(G_B(B))
        fake_A = self.netG_B(self.input_B)
        pred_fake = self.netD_B(fake_A)
        loss_G_B = self.criterionGAN(
            pred_fake, True)
        self.loss_G_B = loss_G_B.item()
        self.fake_A = fake_A

        # Forward cycle loss
        if abs(lambda_A) > 0:
            rec_A = self.netG_B(fake_B)
            loss_cycle_A = self.criterionCycle(
                rec_A, self.input_A) * lambda_A
            self.loss_cycle_A = loss_cycle_A.item()
            self.rec_A = rec_A.data.cpu()
            del rec_A
        else:
            loss_cycle_A = 0

        # Backward cycle loss
        if self.opt.saliency:
            B_sa = torch.cat([self.get_sa(img)[0] for img in fake_A], dim=0)
            self.B_sa = B_sa[0].data.cpu()
            rec_B = self.netG_A(fake_A, B_sa)
            # B_sa only exists on this branch, so it is released here
            del B_sa
        else:
            self.loss_class = 0
            rec_B = self.netG_A(fake_A)
        loss_cycle_B = self.criterionCycle(
            rec_B, self.input_B) * lambda_B
        self.loss_cycle_B = loss_cycle_B.item()
        self.rec_B = rec_B.data.cpu()
        del rec_B, fake_A

        # Warp loss: stylize-then-warp should match warp-then-stylize on the
        # pixels where the consistency mask ``self.C`` is set.
        if abs(self.opt.lambda_warp * lambda_newloss) > 0 and self.frames.shape[0] > 1 and torch.mean(self.C) > 0.1:
            I = util.warp(inked[1:2], self.optF[:1])
            if self.opt.saliency:
                II = self.netG_A_decoder(
                    util.warp(F[1:2], self.FoptF[:1]), self.warped_sa[:1])
            else:
                II = self.netG_A_decoder(util.warp(F[1:2], self.FoptF[:1]))
            dI = I - II
            warp_loss = torch.sum(dI * dI * self.C[:1]) / torch.sum(self.C[:1])
            warp_loss *= self.opt.lambda_warp * lambda_newloss
            self.warp_loss = warp_loss.item()
            self.fake_warp_error = dI[-1].data.cpu()
            self.warp_error = (util.warp(self.frames[1:2], self.optF[:1])-self.frames[:1]).data.cpu()
            del I, II, dI
        else:
            warp_loss = 0
            # keep the last recorded values for display if there are any
            if not hasattr(self, "warp_loss"):
                self.warp_loss = 0
            if not hasattr(self, "warp_error"):
                self.warp_error = torch.zeros((3, 1, 1))
        del F

        loss = loss_cycle_A + loss_G_B + loss_cycle_B + warp_loss +\
            self.clac_GA_common_loss(
                fake_B, lambda_sup, self.SA[:1], lambda_newloss=lambda_newloss)
        # Fail before backward() so that NaN gradients are never written.
        if torch.isnan(loss).item():
            raise ValueError("backward_G: total loss is NaN")
        loss.backward()
Ejemplo n.º 4
0
    def backward_M(self, lambda_sup, lambda_newloss):
        """Compute the losses for the fusion networks and backpropagate.

        Features of frames 1.. are warped onto frame 0 with ``self.FoptF``,
        weighted by ``self.netM`` (softmax over the reference frames) and
        blended with the frame-0 feature using the ``self.netM2`` score. The
        decoded fused result ``O`` is then compared with the warped,
        individually decoded frames.

        The total loss is the sum of the temporal loss, the optional
        occlusion-mask loss (``opt.lambda_occ``), the optional score
        temporal loss (``opt.lambda_score_temp``, needs more than two
        frames) and the value of ``self.clac_GA_common_loss``.

        Args:
            lambda_sup: forwarded unchanged to ``clac_GA_common_loss``.
            lambda_newloss: forwarded to ``clac_GA_common_loss``.

        Side effects: stores scalar losses and CPU tensors for display on
        ``self`` (``temporal_loss``, ``loss_occ``, ``loss_score_temp``,
        ``score``, ``mask``, ``occlusion``, ``M_fake_B``, ``fake_B`` and
        entries of ``self.display_param``) and prints progress values.

        Raises:
            ValueError: if the total loss is NaN. The check runs before
                ``backward()`` so no NaN gradients are accumulated.
        """
        # optimize network M
        lambda_temporal = self.opt.lambda_temporal
        self.input_A = self.frames[0:1]
        temporal_loss = 0

        F, _, _, _ = self.getF()
        FoptF = self.FoptF
        optF = self.optF
        C = self.C

        # feature fuse
        oF = util.warp(F[1:], FoptF)
        # one netM score map per reference frame, each computed from the
        # reference feature stacked with the frame-0 feature
        tmp = torch.cat(
            [self.netM(torch.cat((f, F[0]), dim=0).unsqueeze(0))
             for f in F[1:]],
            dim=0)
        score = torch.softmax(tmp, dim=0)
        oF = torch.sum(oF * score, dim=0, keepdim=True)
        self.display_param["score1"] = score.data.cpu()
        # blend weight in [0, 1]: 1 keeps the fused reference feature
        score = (torch.tanh(self.netM2(torch.cat((oF,F[:1]), dim=1))) + 1) / 2
        self.score = torch.mean(score[0].data.cpu(), dim=0)
        oF = score * oF + (1-score) * F[:1]

        if self.opt.saliency:
            frames = self.netG_A_decoder(F, self.SA)
        else:
            frames = self.netG_A_decoder(F)
        # frame 0 as decoded, the other frames warped with ``optF``
        I = torch.cat((frames[:1], util.warp(frames[1:], optF)), dim=0)

        # temporal loss
        if self.opt.saliency:
            O = self.netG_A_decoder(oF, self.SA[:1])
        else:
            O = self.netG_A_decoder(oF)
        self.M_fake_B = O.cpu()
        self.fake_B = I[:1].cpu()
        nC = torch.zeros(O.shape[2:]).to(
            device=I.device).unsqueeze(0).unsqueeze(0)

        for i in range(1, F.shape[0]):
            # nC accumulates the mask positions that exceed its current value
            mask = (C[i-1:i] > nC).float()
            nC = nC + mask
            # the small offset keeps the normaliser below away from zero
            mask = C[i-1:i] + 0.001
            dI = I[i:i+1] - O
            temporal_loss += torch.sum(mask * torch.pow(dI,2)) / torch.sum(mask) / (F.shape[0]-1)
            print("{}: {}".format(i, temporal_loss.item()))
        # where no reference frame is marked in nC, compare with frame 0
        dI = I[0:1] - O
        temporal_loss += torch.sum((1-nC+0.01) * torch.pow(dI,2)) / torch.sum(1-nC+0.01)
        print("temp loss:", temporal_loss.item())
        temporal_loss *= lambda_temporal / self.frames.shape[0]

        self.mask = nC[0].data.cpu()
        print("check score2:", torch.mean(
            score[-1]).item(), torch.max(score[-1]).item())
        self.display_param["score2"] = score[-1].data.cpu()

        self.temporal_loss = temporal_loss.item()
        self.occlusion = nC.data[0].cpu()

        # mask loss
        if abs(self.opt.lambda_occ) > 0:
            loss_occ = torch.mean(torch.pow(
                score-nn.functional.interpolate(nC, score.shape[2:], mode="bilinear"), 2))
            loss_occ *= self.opt.lambda_occ
            self.loss_occ = loss_occ.item()
        else:
            self.loss_occ = loss_occ = 0

        # score temporal loss
        if self.frames.shape[0] > 2 and abs(self.opt.lambda_score_temp) > 0:
            oF = util.warp(F[2:3], FoptF[1:2])
            score = (torch.tanh(self.netM2(torch.cat((oF,F[:1]), dim=1))) + 1) / 2

            # Flow estimation needs no gradients; ``forward`` computes its
            # flows under no_grad as well.
            with torch.no_grad():
                noptF = PWCnet.estimate(self.pwc_model, self.frames[1:2], self.frames[2:3])
                tnoptF = noptF.permute(0, 3, 1, 2)
                tmp = nn.functional.interpolate(
                    tnoptF, size=(self.frames.shape[2] // 4, self.frames.shape[3] // 4), mode="bilinear")
                # rescale the flow vectors by the same ratio as the grid
                tmp[:, 0, :, :] = tmp[:, 0, :, :] * \
                    tmp.shape[3] / tnoptF.shape[3]
                tmp[:, 1, :, :] = tmp[:, 1, :, :] * \
                    tmp.shape[2] / tnoptF.shape[2]
                tmp = tmp.permute(0, 2, 3, 1)

            oF2 = util.warp(F[2:3], tmp)
            score2 = (torch.tanh(self.netM2(torch.cat((oF2,F[1:2]), dim=1))) + 1) / 2
            score2 = util.warp(score2, FoptF[:1])
            tC = nn.functional.interpolate(self.C[:1], score.shape[2:], mode="bilinear")
            ds = torch.abs(score2-score)
            # only differences below 0.5 inside the mask contribute
            mask = (ds < 0.5).float() * tC

            loss_score_temp = torch.sum(ds*mask) / (torch.sum(mask) + 1)
            loss_score_temp *= self.opt.lambda_score_temp
            self.loss_score_temp = loss_score_temp.item()
        else:
            self.loss_score_temp = loss_score_temp = 0

        total_loss = temporal_loss + loss_occ + loss_score_temp \
            + self.clac_GA_common_loss(
                O, lambda_sup, self.SA[:1], edge=True, lambda_newloss=lambda_newloss)
        # Fail before backward() so that NaN gradients are never written.
        if torch.isnan(total_loss).item():
            print("temp:", temporal_loss)
            print("occ:", loss_occ)
            print("score_temp:", loss_score_temp)
            raise ValueError("backward_M: total loss is NaN")
        total_loss.backward()
Ejemplo n.º 5
0
    def forward(self):
        """Precompute the gradient-free inputs for the backward passes.

        Under ``torch.no_grad()`` this

        * erodes ``self.input_B`` (5x5 min filter, border padded with 1),
          passes each of the three channels through ``self.gauss_conv`` and
          stores the result in ``self.ink_real_B``;
        * when there is more than one frame, estimates the flow between
          frame 0 and every other frame in both directions, stores the
          full-resolution and 1/4-resolution versions and a consistency
          mask ``self.C`` built from the two directions;
        * when ``self.opt.saliency`` is set, stores the saliency maps and
          boxes from ``self.get_sa`` and the warped saliency of frames 1...

        Returns:
            ``False`` if saliency is enabled and the saliency map of some
            frame sums to less than 0.001, ``True`` otherwise.
        """
        # calc instance segmentation and optical flow
        with torch.no_grad():
            ksize = 5
            border = ksize // 2
            padded_B = nn.functional.pad(
                self.input_B, (border, border, border, border), "constant", 1)
            # min filter expressed as a negated max pool
            eroded_B = -1 * nn.functional.max_pool2d(-1 * padded_B, ksize, 1)

            blurred = [
                self.gauss_conv(eroded_B[:, ch, :, :].unsqueeze(1))
                for ch in range(3)
            ]
            self.ink_real_B = torch.cat(blurred, dim=1)

            num_frames = self.frames.shape[0]
            if num_frames > 1:
                small_size = (self.frames.shape[2] // 4,
                              self.frames.shape[3] // 4)

                def shrink(flow_chw):
                    """Resize a (N, 2, H, W) flow and rescale its vectors."""
                    small = nn.functional.interpolate(
                        flow_chw, size=small_size, mode="bilinear")
                    small[:, 0, :, :] = small[:, 0, :, :] * \
                        small.shape[3] / flow_chw.shape[3]
                    small[:, 1, :, :] = small[:, 1, :, :] * \
                        small.shape[2] / flow_chw.shape[2]
                    return small

                # flow from frame 0 (repeated) to every other frame
                first = self.frames[0:1].repeat(num_frames - 1, 1, 1, 1)
                optF = PWCnet.estimate(self.pwc_model, first, self.frames[1:])
                toptF = optF.permute(0, 3, 1, 2)
                tFoptF = shrink(toptF)
                FoptF = tFoptF.permute(0, 2, 3, 1)

                # flow in the opposite direction
                first = self.frames[0:1].repeat(num_frames - 1, 1, 1, 1)
                noptF = PWCnet.estimate(self.pwc_model, self.frames[1:], first)
                tnoptF = noptF.permute(0, 3, 1, 2)
                tFnoptF = shrink(tnoptF)
                FnoptF = tFnoptF.permute(0, 2, 3, 1)

                # consistency mask: 1 where the warped opposite flow and the
                # flow itself sum to (nearly) zero relative to their size
                warped_back = util.warp(tnoptF, optF)
                residual = warped_back + toptF
                bound = self.alpha1 * \
                    (L2dis(warped_back) + L2dis(toptF)) + self.alpha2
                C = (L2dis(residual) < bound).float()

                self.optF = optF
                self.toptF = toptF
                self.tFoptF = tFoptF
                self.FoptF = FoptF
                self.noptF = noptF
                self.tFnoptF = tFnoptF
                self.FnoptF = FnoptF
                self.C = C

            if self.opt.saliency:
                maps = []
                boxes = []
                for frame in self.frames:
                    sal, frame_boxes = self.get_sa(frame)
                    if torch.sum(sal) < 0.001:
                        return False
                    maps.append(sal)
                    boxes.append(torch.stack(frame_boxes))
                self.SA = torch.cat(maps, dim=0)
                self.sa = maps[0][0].data.cpu()
                self.box = boxes
                if num_frames > 1:
                    self.warped_sa = util.warp(self.SA[1:], self.optF)
        return True