# ---- Example 1 ----
 def calc_dis(i, j, A, B):
     """Return the mean optical-flow magnitude between frames A and B.

     Side effect: caches a quarter-resolution, magnitude-rescaled copy of
     the flow in the free dict ``allFopt`` under key ``(i, j)``.

     NOTE(review): this function reads the free names ``maxl``, ``allFopt``
     and ``self`` (``self.pwc_model``, ``self.frames``) from an enclosing
     scope that is not visible here — confirm they are in scope at the call
     site.
     """
     # Frame pairs farther apart than `maxl` indices are treated as
     # effectively infinitely distant.
     if abs(i - j) > maxl:
         return 1e10
     optF = PWCnet.estimate(self.pwc_model, A, B)
     # (N, H, W, 2) -> (N, 2, H, W) so interpolate() works on channels.
     toptF = optF.permute(0, 3, 1, 2)
     tmp = nn.functional.interpolate(toptF,
                                     size=(self.frames.shape[2] // 4,
                                           self.frames.shape[3] // 4),
                                     mode="bilinear")
     # Rescale the displacement vectors to the downsampled resolution:
     # channel 0 scales with width, channel 1 with height.
     tmp[:, 0, :, :] = tmp[:, 0, :, :] * tmp.shape[3] / toptF.shape[3]
     tmp[:, 1, :, :] = tmp[:, 1, :, :] * tmp.shape[2] / toptF.shape[2]
     FoptF = tmp.permute(0, 2, 3, 1)
     allFopt[(i, j)] = FoptF
     # NOTE(review): L2dis reduces over dim=1; if optF is (N, H, W, 2)
     # that sums over image height rather than the two flow channels —
     # verify the intended reduction axis.
     d = torch.mean(torch.sqrt(L2dis(optF)))
     del FoptF, tmp, optF, toptF
     return d
import math
import scipy.stats as st
import numpy as np
import torch.nn as nn
# ~~~~~~
from models.model import Hed
import models.PWCNet as PWCnet
import random
import matplotlib.pyplot as plt


def L2dis(v):
    """Squared L2 magnitude of `v` along dim 1, keeping that dimension."""
    squared = torch.pow(v, 2)
    return squared.sum(dim=1, keepdim=True)


# Load a frozen PWC-Net optical-flow estimator on the GPU for evaluation.
pwc_model = PWCnet.PWCNet().cuda().eval()
model_path = './pretrained_models/network-default.pytorch'
pwc_model.load_state_dict(torch.load(model_path))
for param in pwc_model.parameters():
    param.requires_grad = False

# Directory of stylized result videos to evaluate.
d = 'results/CVPR19-Linear'
frame_cnt = 0   # running count of evaluated frames
all_loss = 0    # running loss accumulator

# NOTE(review): `torch`, `os` and `cv2` are used below but their imports are
# not visible in this fragment — confirm they appear earlier in the file.
with torch.no_grad():
    for fn in os.listdir(d):
        # Select ink-style videos (mp4/avi), skipping merged outputs.
        if fn.find('ink') >= 0 and fn.find('merge') == -1 and (
                fn.find('mp4') >= 0 or fn.find('avi') >= 0):
            cap = cv2.VideoCapture(os.path.join(d, fn))
            # Property id 7 is CAP_PROP_FRAME_COUNT (total frame count).
            frame_num = int(cap.get(7))
# ---- Example 3 ----
    def initialize(self, opt):
        """Build and load the frozen networks used for video ink stylization.

        Creates the generator encoder/decoder pair, two fusion conv nets
        (M, M2), an optional PointRend instance-segmentation model, a frozen
        PWC-Net optical-flow estimator and a fixed Gaussian blur convolution,
        then restores all generator/fusion weights from `opt.which_epoch`.

        Args:
            opt: parsed options namespace (argparse-style).
        """
        BaseModel.initialize(self, opt)

        # Frame offsets considered for long-term temporal consistency.
        self.long_term = [0, 1, 10, 20, 40]
        self.alpha1 = opt.alpha1
        self.alpha2 = opt.alpha2
        self.alpha = opt.alpha
        # load/define networks
        self.netG_A_encoder = networks.define_G_encoder(
            opt.input_nc, opt.output_nc, opt.ngf, opt, opt.norm,
            not opt.no_dropout, opt.init_type, self.gpu_ids, opt.saliency,
            opt.multisa)
        self.netG_A_decoder = networks.define_G_decoder(
            opt.input_nc, opt.output_nc, opt.ngf, opt, opt.norm,
            not opt.no_dropout, opt.init_type, self.gpu_ids, opt.multisa)

        # Fusion networks: each maps a pair of encoder feature maps
        # (concatenated along channels, hence the *2) to a 1-channel score.
        self.netM = networks.define_convs(self.netG_A_encoder.channel_size() *
                                          2,
                                          1,
                                          opt.M_layers,
                                          opt.M_size,
                                          gpu_ids=self.gpu_ids)
        self.netM2 = networks.define_convs(self.netG_A_encoder.channel_size() *
                                           2,
                                           1,
                                           opt.M_layers,
                                           opt.M_size,
                                           gpu_ids=self.gpu_ids)
        # Optional PointRend instance segmentation (detectron2), frozen.
        if opt.saliency:
            cfg = get_cfg()
            point_rend.add_pointrend_config(cfg)
            cfg.merge_from_file(
                "/home/linchpin/Documents/ink_stylize/ChipGAN_release/models/detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
            )
            cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
            cfg.MODEL.WEIGHTS = "/home/linchpin/Documents/ink_stylize/ChipGAN_release/pretrained_models/model_final_3c3198.pkl"
            self.NetIS = build_model(cfg)
            checkpointer = DetectionCheckpointer(self.NetIS)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            self.NetIS.eval()
            if len(self.opt.gpu_ids) == 0:
                self.NetIS.cpu()
            for param in self.NetIS.parameters():
                param.requires_grad = False

        # Frozen PWC-Net optical-flow estimator.
        self.pwc_model = PWCnet.PWCNet().eval()
        if len(self.opt.gpu_ids) != 0:
            self.pwc_model.cuda()
        model_path = './pretrained_models/network-default.pytorch'
        self.pwc_model.load_state_dict(torch.load(model_path))
        for param in self.pwc_model.parameters():
            param.requires_grad = False

        # Fixed (non-trainable) 3x3 Gaussian blur; padding keeps size.
        kw = 3
        g_kernel = self.gauss_kernel(kw, 3, 1).transpose((3, 2, 1, 0))
        self.gauss_conv_kw = nn.Conv2d(1,
                                       1,
                                       kernel_size=kw,
                                       stride=1,
                                       padding=(kw - 1) // 2,
                                       bias=False)
        self.gauss_conv_kw.weight.data.copy_(torch.from_numpy(g_kernel))
        self.gauss_conv_kw.weight.requires_grad = False
        if len(self.opt.gpu_ids) != 0:
            self.gauss_conv_kw.cuda()

        # Restore generator and fusion-network weights from checkpoint.
        which_epoch = opt.which_epoch
        self.load_network(self.netG_A_encoder, 'G_A_encoder', which_epoch)
        self.load_network(self.netG_A_decoder, 'G_A_decoder', which_epoch)
        self.load_network(self.netM, "M", which_epoch)
        self.load_network(self.netM2, "M2", which_epoch)
        # self.netG_A_decoder.eval()
        # self.netG_A_encoder.eval()
        # self.netM.eval()
        # self.netM2.eval()
        self.pwc_model.eval()

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A_encoder)
        networks.print_network(self.netG_A_decoder)
        networks.print_network(self.netM)
        print('-----------------------------------------------')
# ---- Example 4 ----
    def initialize(self, opt):
        """Build the full training/inference graph for the video ink model.

        Creates generators (A encoder/decoder, B), fusion nets (M, M2),
        discriminators (train only), fixed Gaussian/Haar/HED auxiliaries,
        optional PointRend segmentation, a frozen PWC-Net flow estimator,
        then restores checkpoints and sets up optimizers/schedulers.

        Args:
            opt: parsed options namespace (argparse-style).
        """
        BaseModel.initialize(self, opt)
        self.display_param = dict()

        # Frame offsets considered for long-term temporal consistency.
        self.long_term = [0, 1, 10, 20, 40]
        self.alpha1 = opt.alpha1
        self.alpha2 = opt.alpha2
        self.alpha = opt.alpha
        # load/define networks
        self.netG_A_encoder = networks.define_G_encoder(opt.input_nc, opt.output_nc,
                                                        opt.ngf, opt, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt.saliency, opt.multisa)
        self.netG_A_decoder = networks.define_G_decoder(opt.input_nc, opt.output_nc,
                                                        opt.ngf, opt, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids, opt.multisa)

        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc,
                                        opt.ngf, opt, opt.norm, not opt.no_dropout, opt.init_type, self.gpu_ids)

        # Fusion nets take two encoder feature maps concatenated on channels.
        chs = self.netG_A_encoder.channel_size()
        self.netM = networks.define_convs(chs * 2, 1, opt.M_layers, opt.M_size, gpu_ids=self.gpu_ids)
        self.netM2 = networks.define_convs(chs * 2, 1, opt.M_layers, opt.M_size, gpu_ids=self.gpu_ids)

        if self.isTrain:
            use_sigmoid = opt.no_lsgan
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf,
                                            opt.which_model_netD,
                                            opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

            # Extra discriminator for the ink-wash (blurred) appearance.
            self.netD_ink = networks.define_D(opt.output_nc, opt.ndf,
                                              opt.which_model_netD,
                                              opt.n_layers_D, opt.norm, use_sigmoid, opt.init_type, self.gpu_ids)

            # Fixed 21-tap Gaussian blur used to produce the ink appearance.
            # NOTE(review): kernel_size=21 with padding=1 shrinks the spatial
            # size by 18 px; padding=10 would preserve it — confirm intended.
            g_kernel = self.gauss_kernel(21, 3, 1).transpose((3, 2, 1, 0))
            self.gauss_conv = nn.Conv2d(
                1, 1, kernel_size=21, stride=1, padding=1, bias=False)
            self.gauss_conv.weight.data.copy_(torch.from_numpy(g_kernel))
            self.gauss_conv.weight.requires_grad = False
            self.gauss_conv.cuda()

            # Gaussain blur
            # 7-tap variant; size preserved via explicit reflection padding.
            kw = 7
            g_kernel = self.gauss_kernel(kw, 3, 1).transpose((3, 2, 1, 0))
            self.gauss_conv_kw = nn.Conv2d(
                1, 1, kernel_size=kw, stride=1, padding=0, bias=False)
            self.gauss_conv_kw.weight.data.copy_(torch.from_numpy(g_kernel))
            self.gauss_conv_kw.weight.requires_grad = False
            self.gauss_conv_kw.cuda()
            self.gauss_conv_kw_pad = nn.ReflectionPad2d((kw-1)//2)

            # Fixed 2x2 Haar wavelet filters (LH, HL, HH), stride 2.
            L = np.array([1, 1]).reshape(2, 1)
            H = np.array([-1, 1]).reshape(2, 1)
            haar_kernel = np.stack(
                (L@(H.T), H@(L.T), H@(H.T))).reshape(3, 1, 2, 2) / 2
            self.haar_kernel = nn.Conv2d(
                1, 3, 2, stride=2, padding=0, bias=False)
            self.haar_kernel.weight.data.copy_(torch.from_numpy(haar_kernel))
            self.haar_kernel.weight.requires_grad = False
            self.haar_kernel.cuda()

            # Hed Model
            # Frozen HED edge detector used as an auxiliary loss signal.
            self.hed_model = Hed()
            self.hed_model.cuda()
            save_path = './pretrained_models/35.pth'
            self.hed_model.load_state_dict(torch.load(save_path))
            for param in self.hed_model.parameters():
                param.requires_grad = False

        if opt.saliency:
            # detectron2
            # Frozen PointRend instance segmentation for saliency masks.
            cfg = get_cfg()
            point_rend.add_pointrend_config(cfg)
            cfg.merge_from_file(
                "/home/linchpin/Documents/ink_stylize/ChipGAN_release/models/detectron2_repo/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml")
            cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5
            cfg.MODEL.WEIGHTS = "pretrained_models/model_final_3c3198.pkl"
            self.NetIS = build_model(cfg)
            checkpointer = DetectionCheckpointer(self.NetIS)
            checkpointer.load(cfg.MODEL.WEIGHTS)
            self.NetIS.eval()
            for param in self.NetIS.parameters():
                param.requires_grad = False

        # pwcnet
        # Frozen PWC-Net optical-flow estimator.
        self.pwc_model = PWCnet.PWCNet().cuda().eval()
        model_path = './pretrained_models/network-default.pytorch'
        self.pwc_model.load_state_dict(torch.load(model_path))
        for param in self.pwc_model.parameters():
            param.requires_grad = False

        # Restore checkpoints when testing or resuming training.
        if not self.isTrain or opt.continue_train:
            which_epoch = opt.which_epoch
            self.load_network(self.netG_A_encoder, 'G_A_encoder', which_epoch)
            self.load_network(self.netG_A_decoder, 'G_A_decoder', which_epoch)
            self.load_network(self.netG_B, 'G_B', which_epoch)
            self.load_network(self.netM, "M", which_epoch)
            self.load_network(self.netM2, "M2", which_epoch)
        if opt.continue_train:
            self.load_network(self.netD_A, 'D_A', which_epoch)
            self.load_network(self.netD_B, 'D_B', which_epoch)
            self.load_network(self.netD_ink, 'D_ink', which_epoch)

        if self.isTrain:
            self.old_lr = opt.lr
            # Image pools stabilize discriminator training with history.
            self.fake_A_pool = ImagePool(opt.pool_size)
            self.fake_B_pool = ImagePool(opt.pool_size)
            self.ink_fake_B_pool = ImagePool(opt.pool_size)
            # define loss functions
            self.criterionGAN = networks.GANLoss(
                use_lsgan=not opt.no_lsgan, tensor=self.Tensor)
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            self.TV_LOSS = networks.TVLoss()
            # initialize optimizers
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A_encoder.parameters(), self.netG_A_decoder.parameters(), self.netG_B.parameters()),
                                                lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_M = torch.optim.Adam(itertools.chain(self.netM.parameters(
            ), self.netM2.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_A = torch.optim.Adam(
                self.netD_A.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_B = torch.optim.Adam(
                self.netD_B.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D_ink = torch.optim.Adam(
                self.netD_ink.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
            if opt.continue_train:
                self.load_optim(self.optimizer_M, "M", which_epoch)
                self.load_optim(self.optimizer_G, "G", which_epoch)
                self.load_optim(self.optimizer_D_A, "D_A", which_epoch)
                self.load_optim(self.optimizer_D_B, "D_B", which_epoch)
                self.load_optim(self.optimizer_D_ink, "D_ink", which_epoch)
            self.optimizers = []
            self.schedulers = []
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_M)
            self.optimizers.append(self.optimizer_D_A)
            self.optimizers.append(self.optimizer_D_B)
            self.optimizers.append(self.optimizer_D_ink)
            for optimizer in self.optimizers:
                self.schedulers.append(networks.get_scheduler(optimizer, opt))

        print('---------- Networks initialized -------------')
        networks.print_network(self.netG_A_encoder)
        networks.print_network(self.netG_A_decoder)
        networks.print_network(self.netG_B)
        networks.print_network(self.netM)
        networks.print_network(self.netM2)
        if self.isTrain:
            networks.print_network(self.netD_A)
            networks.print_network(self.netD_B)
        print('-----------------------------------------------')
# ---- Example 5 ----
    def backward_M(self, lambda_sup, lambda_newloss):
        """Compute and backpropagate the fusion-network (M/M2) losses.

        Combines a flow-warped temporal consistency loss, an occlusion-mask
        supervision loss, a score temporal-stability loss and the common
        generator loss, then calls ``backward()`` on their sum.

        Args:
            lambda_sup: weight forwarded to ``clac_GA_common_loss``.
            lambda_newloss: weight forwarded to ``clac_GA_common_loss``.

        Raises:
            ValueError: if the total loss is NaN (after backward).

        NOTE(review): relies on ``self.frames``, ``self.FoptF``,
        ``self.optF``, ``self.C``, ``self.SA`` having been populated by a
        prior ``forward()`` call — confirm call ordering.
        """
        # optimize network M
        lambda_temporal = self.opt.lambda_temporal
        self.input_A = self.frames[0:1]
        temporal_loss = 0

        F, _, _, _ = self.getF()
        FoptF = self.FoptF   # flow at feature (1/4) resolution
        optF = self.optF     # flow at frame resolution
        C = self.C           # flow-consistency (non-occlusion) mask

        # feature fuse
        # Warp later-frame features back to frame 0, score each with netM,
        # and blend them by softmax over the frame axis.
        oF = util.warp(F[1:], FoptF)
        tmp = list()
        for f in F[1:]:
            tmp.append(self.netM(torch.cat((f,F[0]), dim=0).unsqueeze(0)))
        tmp = torch.cat(tmp, dim=0)
        score = torch.softmax(tmp, dim=0)
        oF = torch.sum(oF * score, dim=0, keepdim=True)
        self.display_param["score1"] = score.data.cpu()
        # netM2 produces a per-pixel blend weight in [0, 1] between the
        # fused warped features and frame 0's own features.
        score = (torch.tanh(self.netM2(torch.cat((oF,F[:1]), dim=1))) + 1) / 2
        self.score = torch.mean(score[0].data.cpu(), dim=0)
        oF = score * oF + (1-score) * F[:1]

        if self.opt.saliency:
            frames = self.netG_A_decoder(F, self.SA)
        else:
            frames = self.netG_A_decoder(F)
        # I: decoded frame 0 plus later decoded frames warped onto frame 0.
        I = torch.cat((frames[:1], util.warp(frames[1:], optF)), dim=0)

        # temporal loss
        if self.opt.saliency:
            O = self.netG_A_decoder(oF, self.SA[:1])
        else:
            O = self.netG_A_decoder(oF)
        self.M_fake_B = O.cpu()
        self.fake_B = I[:1].cpu()
        # nC accumulates per-pixel coverage by any consistent warped frame.
        nC = torch.zeros(O.shape[2:]).to(
            device=I.device).unsqueeze(0).unsqueeze(0)

        for i in range(1, F.shape[0]):
            mask = (C[i-1:i] > nC).float()
            nC = nC + mask
            mask = C[i-1:i] + 0.001  # small floor avoids a zero-weight mask
            dI = I[i:i+1] - O
            temporal_loss += torch.sum(mask * torch.pow(dI,2)) / torch.sum(mask) / (F.shape[0]-1)
            print("{}: {}".format(i, temporal_loss.item()))
        # Pixels never covered by a consistent warp are tied to frame 0.
        dI = I[0:1] - O
        temporal_loss += torch.sum((1-nC+0.01) * torch.pow(dI,2)) / torch.sum(1-nC+0.01)
        print("temp loss:", temporal_loss.item())
        temporal_loss *= lambda_temporal / self.frames.shape[0]

        self.mask = nC[0].data.cpu()
        print("check score2:", torch.mean(
            score[-1]).item(), torch.max(score[-1]).item())
        self.display_param["score2"] = score[-1].data.cpu()

        self.temporal_loss = temporal_loss.item()
        self.occlusion = nC.data[0].cpu()

        # mask loss
        # Supervise the blend score toward the (resized) coverage mask.
        if abs(self.opt.lambda_occ) > 0:
            loss_occ = torch.mean(torch.pow(
                score-nn.functional.interpolate(nC, score.shape[2:], mode="bilinear"), 2))
            loss_occ *= self.opt.lambda_occ
            self.loss_occ = loss_occ.item()
        else:
            self.loss_occ = loss_occ = 0

        # score temporal loss
        # Encourage the blend score to be stable along the flow between
        # consecutive frames (where the flow is consistent).
        if self.frames.shape[0] > 2 and abs(self.opt.lambda_score_temp) > 0:
            oF = util.warp(F[2:3], FoptF[1:2])
            score = (torch.tanh(self.netM2(torch.cat((oF,F[:1]), dim=1))) + 1) / 2

            noptF = PWCnet.estimate(self.pwc_model, self.frames[1:2], self.frames[2:3])
            tnoptF = noptF.permute(0, 3, 1, 2)
            tmp = nn.functional.interpolate(
                tnoptF, size=(self.frames.shape[2] // 4, self.frames.shape[3] // 4), mode="bilinear")
            # Rescale flow components to the downsampled resolution.
            tmp[:, 0, :, :] = tmp[:, 0, :, :] * \
                tmp.shape[3] / tnoptF.shape[3]
            tmp[:, 1, :, :] = tmp[:, 1, :, :] * \
                tmp.shape[2] / tnoptF.shape[2]
            tmp = tmp.permute(0, 2, 3, 1)

            oF2 = util.warp(F[2:3], tmp)
            score2 = (torch.tanh(self.netM2(torch.cat((oF2,F[1:2]), dim=1))) + 1) / 2
            score2 = util.warp(score2, FoptF[:1])
            tC = nn.functional.interpolate(self.C[:1], score.shape[2:], mode="bilinear")
            ds = torch.abs(score2-score)
            # Ignore large score jumps (likely occlusions) via the 0.5 gate.
            mask = (ds < 0.5).float() * tC

            loss_score_temp = torch.sum(ds*mask) / (torch.sum(mask) + 1)
            loss_score_temp *= self.opt.lambda_score_temp
            self.loss_score_temp = loss_score_temp.item()
        else:
            self.loss_score_temp = loss_score_temp = 0

        total_loss = temporal_loss + loss_occ + loss_score_temp \
            + self.clac_GA_common_loss(
                O, lambda_sup, self.SA[:1], edge=True, lambda_newloss=lambda_newloss)
        total_loss.backward()
        # NOTE(review): the NaN check runs after backward(), so NaN gradients
        # have already been propagated when this raises — confirm intended.
        if torch.isnan(total_loss).item():
            print("temp:", temporal_loss)
            print("occ:", loss_occ)
            print("score_temp:", loss_score_temp)
            raise ValueError
        del nC, dI, I, FoptF, optF, F, oF, score
# ---- Example 6 ----
    def forward(self):
        """Precompute ink-eroded target, optical flows and saliency masks.

        Under ``no_grad``: erodes+blurs ``input_B`` into ``ink_real_B``,
        estimates forward/backward optical flow between frame 0 and every
        later frame (caching full- and quarter-resolution versions), builds
        a forward-backward flow-consistency mask ``C``, and optionally
        computes per-frame saliency masks.

        Returns:
            bool: False if saliency is enabled and a frame yields an
            (almost) empty mask; True otherwise.
        """
        # calc instance segmentation and optical flow
        with torch.no_grad():
            # Morphological erosion of input_B via min-pooling (max-pool of
            # the negated, 1-padded image), then channel-wise Gaussian blur.
            kernel_size = 5
            pad_size = kernel_size//2
            p1d = (pad_size, pad_size, pad_size, pad_size)
            p_real_B = nn.functional.pad(self.input_B, p1d, "constant", 1)
            erode_real_B = -1 * \
                (nn.functional.max_pool2d(-1*p_real_B, kernel_size, 1))

            res1 = self.gauss_conv(erode_real_B[:, 0, :, :].unsqueeze(1))
            res2 = self.gauss_conv(erode_real_B[:, 1, :, :].unsqueeze(1))
            res3 = self.gauss_conv(erode_real_B[:, 2, :, :].unsqueeze(1))

            self.ink_real_B = torch.cat((res1, res2, res3), dim=1)
            if self.frames.shape[0] > 1:
                # Forward flow: frame 0 -> each later frame.
                optF = PWCnet.estimate(self.pwc_model, self.frames[0:1].repeat(
                    self.frames.shape[0]-1, 1, 1, 1), self.frames[1:])
                toptF = optF.permute(0, 3, 1, 2)
                tmp = nn.functional.interpolate(
                    toptF, size=(self.frames.shape[2] // 4, self.frames.shape[3] // 4), mode="bilinear")
                # Rescale flow components to the quarter resolution.
                tmp[:, 0, :, :] = tmp[:, 0, :, :] * \
                    tmp.shape[3] / toptF.shape[3]
                tmp[:, 1, :, :] = tmp[:, 1, :, :] * \
                    tmp.shape[2] / toptF.shape[2]
                tFoptF = tmp
                FoptF = tmp.permute(0, 2, 3, 1)

                # Backward flow: each later frame -> frame 0.
                noptF = PWCnet.estimate(self.pwc_model, self.frames[1:], self.frames[0:1].repeat(
                    self.frames.shape[0]-1, 1, 1, 1))
                tnoptF = noptF.permute(0, 3, 1, 2)
                tmp = nn.functional.interpolate(
                    tnoptF, size=(self.frames.shape[2] // 4, self.frames.shape[3] // 4), mode="bilinear")
                tmp[:, 0, :, :] = tmp[:, 0, :, :] * \
                    tmp.shape[3] / tnoptF.shape[3]
                tmp[:, 1, :, :] = tmp[:, 1, :, :] * \
                    tmp.shape[2] / tnoptF.shape[2]
                tFnoptF = tmp
                FnoptF = tmp.permute(0, 2, 3, 1)

                # Forward-backward consistency check: a pixel is consistent
                # when warp(backward flow) + forward flow is small relative
                # to the flow magnitudes (alpha1/alpha2 thresholds).
                C = util.warp(tnoptF, optF)
                tmp = C + toptF
                C = (L2dis(tmp) < self.alpha1 *
                     (L2dis(C) + L2dis(toptF)) + self.alpha2).float()

                self.optF = optF
                self.toptF = toptF
                self.tFoptF = tFoptF
                self.FoptF = FoptF
                self.noptF = noptF
                self.tFnoptF = tFnoptF
                self.FnoptF = FnoptF
                self.C = C
                del tmp

            if self.opt.saliency:
                # Per-frame saliency masks and boxes; abort (return False)
                # if any frame's mask is essentially empty.
                sa = list()
                box = list()
                for img in self.frames:
                    a, b = self.get_sa(img)
                    if torch.sum(a) < 0.001:
                        return False
                    sa.append(a)
                    box.append(torch.stack(b))
                    # exit(0)
                self.SA = torch.cat(sa, dim=0)
                self.sa = sa[0][0].data.cpu()
                self.box = box
                # print(self.sa.shape)
                if self.frames.shape[0] > 1:
                    self.warped_sa = util.warp(self.SA[1:], self.optF)
        return True