def setup(opts):
    """Create the matting network and load its weights for inference.

    Args:
        opts: Mapping whose ``'checkpoint'`` entry is the path of a state dict
            saved from an ``nn.DataParallel``-wrapped ``ResnetConditionHR``.

    Returns:
        The ``nn.DataParallel`` model, moved to the GPU and set to eval mode.
    """
    model = nn.DataParallel(
        ResnetConditionHR(input_nc=(3, 3, 1, 4),
                          output_nc=4,
                          n_blocks1=7,
                          n_blocks2=3))
    weights = torch.load(opts['checkpoint'])
    model.load_state_dict(weights)
    model.cuda()
    model.eval()
    # Let cuDNN pick the fastest convolution algorithms for fixed-size inputs.
    cudnn.benchmark = True
    return model
# Shuffled training loader over `traindata`; batches go through
# collate_filter_none (presumably dropping unusable samples -- confirm).
# NOTE: num_workers deliberately mirrors args.batch_size here.
train_loader = torch.utils.data.DataLoader(
    traindata,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.batch_size,
    collate_fn=collate_filter_none,
)

print('\n[Phase 2] : Initialization')

# Frozen reference network netB, initialised from args.init_model.
netB = nn.DataParallel(
    ResnetConditionHR(input_nc=(3, 3, 1, 4),
                      output_nc=4,
                      n_blocks1=args.n_blocks1,
                      n_blocks2=args.n_blocks2))
netB.load_state_dict(torch.load(args.init_model))
netB.cuda().eval()
for param in netB.parameters():  # freeze netB
    param.requires_grad = False

# Trainable generator netG; Module.apply returns the module, so the
# conv_init-initialised network can be wrapped directly.
netG = nn.DataParallel(
    ResnetConditionHR(input_nc=(3, 3, 1, 4),
                      output_nc=4,
                      n_blocks1=args.n_blocks1,
                      n_blocks2=args.n_blocks2).apply(conv_init))
netG.cuda()
torch.backends.cudnn.benchmark = True

# NOTE(review): the MultiscaleDiscriminator(...) call below is truncated in
# this chunk -- its remaining arguments and closing parenthesis are missing,
# so this fragment does not parse as-is.
netD = MultiscaleDiscriminator(input_nc=3,
                               num_D=1,
                               norm_layer=nn.InstanceNorm2d,
# Green-screen background with the same shape as back_img10; the three
# channels are filled with 120, 255 and 155.
back_img20 = np.zeros(back_img10.shape)
back_img20[..., 0] = 120
back_img20[..., 1] = 255
back_img20[..., 2] = 155

#initialize network
# Uses the first file matching netG_epoch_* under model_main_dir; glob result
# order depends on the file system, so with several checkpoints the choice
# is not deterministic.
fo = glob.glob(model_main_dir + 'netG_epoch_*')
model_name1 = fo[0]
netM = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                         output_nc=4,
                         n_blocks1=7,
                         n_blocks2=3)
# Wrap before loading so the state-dict keys match a checkpoint saved from a
# DataParallel model (presumably "module."-prefixed -- confirm).
netM = nn.DataParallel(netM)
netM.load_state_dict(torch.load(model_name1))
netM.cuda()
netM.eval()
cudnn.benchmark = True
reso = (512, 512)  #input resolution to the network

#Create a list of test images: regular files in data_path ending in _img.png,
#sorted by name.
test_imgs = [
    f for f in os.listdir(data_path)
    if os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')
]
test_imgs.sort()

#output directory (created if it does not exist)
result_path = args.output_dir

if not os.path.exists(result_path):
    os.makedirs(result_path)
示例#4
0
    def __init__(self, device=None, jit=True):
        """Set up the data loader, networks, losses and optimizers.

        Args:
            device: When ``'cuda'``, netB, netG and netD are moved to the GPU
                (and cuDNN benchmarking is enabled); otherwise they stay on
                the CPU.
            jit: Stored on ``self.jit``; presumably consulted by
                ``self._maybe_trace()``, which is defined elsewhere.
        """
        self.device = device
        self.jit = jit
        self.opt = Namespace(
            **{
                'n_blocks1': 7,
                'n_blocks2': 3,
                'batch_size': 1,
                'resolution': 512,
                'name': 'Real_fixed'
            })

        scriptdir = os.path.dirname(os.path.realpath(__file__))
        # Resolve the CSV template and its processed copy next to this file
        # instead of the current working directory, so construction does not
        # depend on where the process was started. The template is formatted
        # with ``scriptdir=`` to fill in this directory.
        csv_template = os.path.join(scriptdir, "Video_data_train.csv")
        csv_file = os.path.join(scriptdir, "Video_data_train_processed.csv")
        with open(csv_template, "r") as r, open(csv_file, "w") as w:
            w.write(r.read().format(scriptdir=scriptdir))

        data_config_train = {
            'reso': (self.opt.resolution, self.opt.resolution)
        }
        traindata = VideoData(csv_file=csv_file,
                              data_config=data_config_train,
                              transform=None)
        self.train_loader = torch.utils.data.DataLoader(
            traindata,
            batch_size=self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.batch_size,
            collate_fn=_collate_filter_none)

        # Frozen network netB (not loaded from a checkpoint here).
        netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        if self.device == 'cuda':
            netB.cuda()
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        # Trainable generator, weights initialised with conv_init.
        netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        netG.apply(conv_init)
        self.netG = netG

        if self.device == 'cuda':
            self.netG.cuda()
            # TODO(asuhan): is this needed?
            torch.backends.cudnn.benchmark = True

        # Discriminator; unlike netB/netG it is wrapped in DataParallel.
        netD = MultiscaleDiscriminator(input_nc=3,
                                       num_D=1,
                                       norm_layer=nn.InstanceNorm2d,
                                       ndf=64)
        netD.apply(conv_init)
        netD = nn.DataParallel(netD)
        self.netD = netD
        if self.device == 'cuda':
            self.netD.cuda()

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()

        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

        self.log_writer = SummaryWriter(scriptdir)
        self.model_dir = scriptdir

        self._maybe_trace()
示例#5
0
def _parse_mask_ops(mask_ops):
    """Turn a ``"op,ks,it;op,ks,it;..."`` spec into ``(op, kernel, it)`` tuples.

    For ``blur`` the kernel is the ``(ks, ks)`` Gaussian window and ``it`` is
    later passed as the sigma; for every other op the kernel is an elliptical
    ``ks x ks`` structuring element and ``it`` is the iteration count. Fields
    after the third are ignored. A falsy *mask_ops* yields an empty list.
    """
    ops = []
    if not mask_ops:
        return ops
    for op_st in mask_ops.split(";"):
        fields = op_st.split(",")
        op = fields[0]
        ks = int(fields[1])
        it = int(fields[2])
        if op == "blur":
            kernel = (ks, ks)
        else:
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ks, ks))
        ops.append((op, kernel, it))
    return ops


def _green_screen(shape):
    """Return a float array of *shape* whose channels are 120, 255 and 155."""
    screen = np.zeros(shape)
    screen[..., 0] = 120
    screen[..., 1] = 255
    screen[..., 2] = 155
    return screen


def _refine_mask(rcnn, ops, height, pad=25):
    """Post-process a cropped uint8 segmentation mask with *ops*.

    Values above 0.2 (after scaling to [0, 1]) are snapped to 1. Fully empty
    rows with index > 250, plus the two rows just above the first of them,
    are removed; the bottom is then replicate-padded by ``pad`` + the number
    of removed rows before the dilate/erode/blur steps run. Afterwards rows
    ``height .. height + pad - 1`` are dropped, giving back *height* rows
    (assuming the input had *height* rows). Returns a uint8 array.
    """
    mask = rcnn.astype(np.float32) / 255
    mask[mask > 0.2] = 1

    zero_rows = np.nonzero(np.sum(mask, axis=1) == 0)[0]
    del_id = zero_rows[zero_rows > 250]
    if len(del_id) > 0:
        del_id = [del_id[0] - 2, del_id[0] - 1, *del_id]
        mask = np.delete(mask, del_id, 0)
    mask = cv2.copyMakeBorder(mask, 0, pad + len(del_id), 0, 0, cv2.BORDER_REPLICATE)

    for op, kernel, it in ops:
        if op == "dilate":
            mask = cv2.dilate(mask, kernel, iterations=it)
        elif op == "erode":
            mask = cv2.erode(mask, kernel, iterations=it)
        elif op == "blur":
            mask = cv2.GaussianBlur(mask.astype(np.float32), kernel, it)
    mask = (255 * mask).astype(np.uint8)
    return np.delete(mask, range(height, height + pad), 0)


def inference(
    output_dir,
    input_dir,
    sharpen=False,
    mask_ops="erode,3,10;dilate,5,5;blur,31,0",
    video=True,
    target_back=None,
    back=None,
    trained_model="real-fixed-cam",
    mask_suffix="_masksDL",
    outputs=("out",),
    output_suffix="",
):
    """Run the background-matting network on every ``*_img.png`` in *input_dir*.

    Args:
        output_dir: Directory for results (created if missing).
        input_dir: Directory with ``<name>_img.png`` frames, their masks
            (``_img`` replaced by *mask_suffix*) and, when *back* is None,
            per-frame ``<name>_back.png`` captured backgrounds.
        sharpen: Sharpen the frames and captured background(s) before use.
        mask_ops: ``;``-separated ``op,ksize,iterations`` steps (``erode``,
            ``dilate``, ``blur``) applied to each segmentation mask.
        video: If True, the 4-channel motion cue holds grayscale frames
            40 and 20 before and after the current one (wrapping around the
            sequence) and *target_back* is a directory of per-frame
            backgrounds; otherwise the current frame is repeated and
            *target_back* is a single image.
        target_back: Target background directory/image; only read when
            ``"compose"`` is in *outputs*.
        back: Path of one captured background for all frames; if None the
            per-frame ``_back`` images are used.
        trained_model: Sub-directory of ``Models/`` holding a
            ``netG_epoch_*`` checkpoint (the first glob match is used).
        mask_suffix: Replaces ``_img`` in a frame name to locate its mask.
        outputs: Any of ``"out"`` (alpha), ``"fg"``, ``"compose"`` (over the
            target background) and ``"matte"`` (over a green screen).
        output_suffix: Appended to each output tag in the file names.

    Frames whose mask makes ``get_bbox`` raise ValueError get placeholder
    outputs; frames whose predicted alpha has no pixel above 0.05 are skipped
    without writing anything.
    """
    model_main_dir = "Models/" + trained_model + "/"
    data_path = input_dir

    alpha_output = "out" in outputs
    matte_output = "matte" in outputs
    fg_output = "fg" in outputs
    compose_output = "compose" in outputs

    # initialize network
    fo = glob.glob(model_main_dir + "netG_epoch_*")
    model_name1 = fo[0]
    netM = ResnetConditionHR(
        input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=7, n_blocks2=3
    )
    netM = nn.DataParallel(netM)
    netM.load_state_dict(torch.load(model_name1))
    netM.cuda()
    netM.eval()
    cudnn.benchmark = True
    reso = (512, 512)  # input resolution to the network

    # load captured background for video mode, fixed camera
    if back is not None:
        bg_im0 = cv2.cvtColor(cv2.imread(back), cv2.COLOR_BGR2RGB)
        if sharpen:
            bg_im0 = sharpen_image(bg_im0)

    test_imgs = sorted(
        f
        for f in os.listdir(data_path)
        if os.path.isfile(os.path.join(data_path, f)) and f.endswith("_img.png")
    )
    n_imgs = len(test_imgs)

    result_path = output_dir
    os.makedirs(result_path, exist_ok=True)

    ops = _parse_mask_ops(mask_ops)

    for i in tqdm(range(n_imgs)):
        filename = test_imgs[i]
        bgr_img = cv2.imread(os.path.join(data_path, filename))
        bgr_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)

        if back is None:
            # per-frame captured background image
            bg_im0 = cv2.imread(
                os.path.join(data_path, filename.replace("_img", "_back"))
            )
            bg_im0 = cv2.cvtColor(bg_im0, cv2.COLOR_BGR2RGB)

        # segmentation mask (grayscale)
        rcnn = cv2.imread(
            os.path.join(data_path, filename.replace("_img", mask_suffix)), 0
        )

        if video:
            if compose_output:
                back_img10 = cv2.imread(
                    os.path.join(target_back, filename.replace("_img.png", ".png"))
                )
                back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)
            back_img20 = _green_screen(bgr_img.shape)

            # Motion cue from neighbouring frames. Modulo wraps correctly even
            # for sequences shorter than 2 * gap, where a single +/- n_imgs
            # correction could still fall out of range.
            gap = 20
            multi_fr_w = np.zeros((bgr_img.shape[0], bgr_img.shape[1], 4))
            for t, offset in enumerate((-2 * gap, -gap, gap, 2 * gap)):
                file_tmp = test_imgs[(i + offset) % n_imgs]
                bgr_img_mul = cv2.imread(os.path.join(data_path, file_tmp))
                multi_fr_w[..., t] = cv2.cvtColor(bgr_img_mul, cv2.COLOR_BGR2GRAY)
        else:
            if i == 0:
                if compose_output:
                    back_img10 = cv2.imread(target_back)
                    back_img10 = cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB)
                back_img20 = _green_screen(bgr_img.shape)
            # Repeat the current frame in all four channels.
            # NOTE(review): bgr_img is already RGB here, so BGR2GRAY applies
            # swapped R/B weights (video mode converts true BGR frames).
            gray = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2GRAY)
            multi_fr_w = np.zeros((bgr_img.shape[0], bgr_img.shape[1], 4))
            for t in range(4):
                multi_fr_w[..., t] = gray

        # crop tightly around the mask
        bgr_img0 = bgr_img
        R0 = bgr_img0.shape[0]
        C0 = bgr_img0.shape[1]
        try:
            bbox = get_bbox(rcnn, R=R0, C=C0)
        except ValueError:
            # No usable mask: write placeholder images for the requested outputs.
            if compose_output:
                back_img10 = cv2.resize(back_img10, (C0, R0))
            back_img20 = cv2.resize(back_img20, (C0, R0)).astype(np.uint8)
            if alpha_output:
                cv2.imwrite(
                    os.path.join(
                        result_path, filename.replace("_img", "_out" + output_suffix)
                    ),
                    rcnn,
                )
            if fg_output:
                cv2.imwrite(
                    os.path.join(
                        result_path, filename.replace("_img", "_fg" + output_suffix)
                    ),
                    cv2.cvtColor(cv2.resize(rcnn, (C0, R0)), cv2.COLOR_GRAY2RGB),
                )
            if compose_output:
                cv2.imwrite(
                    os.path.join(
                        result_path,
                        filename.replace("_img", "_compose" + output_suffix),
                    ),
                    cv2.cvtColor(back_img10, cv2.COLOR_BGR2RGB),
                )
            if matte_output:
                cv2.imwrite(
                    os.path.join(
                        result_path,
                        filename.replace("_img", "_matte" + output_suffix).format(i),
                    ),
                    cv2.cvtColor(back_img20, cv2.COLOR_BGR2RGB),
                )
            continue

        crop_list = crop_images([bgr_img, bg_im0, rcnn, multi_fr_w], reso, bbox)
        bgr_img = crop_list[0]
        bg_im = crop_list[1]
        rcnn = crop_list[2]
        multi_fr = crop_list[3]

        if sharpen:
            bgr_img = sharpen_image(bgr_img)
            if back is None:
                bg_im = sharpen_image(bg_im)

        rcnn = _refine_mask(rcnn, ops, reso[0])

        # convert to torch, scaling [0, 255] -> [-1, 1]
        img = torch.from_numpy(bgr_img.transpose((2, 0, 1))).unsqueeze(0)
        img = 2 * img.float().div(255) - 1
        bg = torch.from_numpy(bg_im.transpose((2, 0, 1))).unsqueeze(0)
        bg = 2 * bg.float().div(255) - 1
        rcnn_al = torch.from_numpy(rcnn).unsqueeze(0).unsqueeze(0)
        rcnn_al = 2 * rcnn_al.float().div(255) - 1
        multi_fr = torch.from_numpy(multi_fr.transpose((2, 0, 1))).unsqueeze(0)
        multi_fr = 2 * multi_fr.float().div(255) - 1

        with torch.no_grad():
            img = img.cuda()
            bg = bg.cuda()
            rcnn_al = rcnn_al.cuda()
            multi_fr = multi_fr.cuda()

            alpha_pred, fg_pred_tmp = netM(img, bg, rcnn_al, multi_fr)

            # for regions with alpha>0.95, simply use the image as fg
            al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor)
            fg_pred = img * al_mask + fg_pred_tmp * (1 - al_mask)

            alpha_out = to_image(alpha_pred[0, ...])

            # Keep only the largest connected component of alpha > 0.05. An
            # explicit check (not assert + bare except) so it also works
            # under `python -O`.
            labels = label((alpha_out > 0.05).astype(int))
            if labels.max() == 0:
                continue
            largestCC = labels == np.argmax(np.bincount(labels.flat)[1:]) + 1
            alpha_out = alpha_out * largestCC

            alpha_out = (255 * alpha_out[..., 0]).astype(np.uint8)

            fg_out = to_image(fg_pred[0, ...])
            fg_out = fg_out * np.expand_dims(
                (alpha_out.astype(float) / 255 > 0.01).astype(float), axis=2
            )
            fg_out = (255 * fg_out).astype(np.uint8)

            # Uncrop back to the full frame
            alpha_out0 = uncrop(alpha_out, bbox, R0, C0)
            fg_out0 = uncrop(fg_out, bbox, R0, C0)

        # compose and write
        if alpha_output:
            cv2.imwrite(
                os.path.join(
                    result_path, filename.replace("_img", "_out" + output_suffix)
                ),
                alpha_out0,
            )
        if fg_output:
            cv2.imwrite(
                os.path.join(
                    result_path, filename.replace("_img", "_fg" + output_suffix)
                ),
                cv2.cvtColor(fg_out0, cv2.COLOR_BGR2RGB),
            )
        if compose_output:
            back_img10 = cv2.resize(back_img10, (C0, R0))
            comp_im_tr1 = composite4(fg_out0, back_img10, alpha_out0)
            cv2.imwrite(
                os.path.join(
                    result_path, filename.replace("_img", "_compose" + output_suffix)
                ),
                cv2.cvtColor(comp_im_tr1, cv2.COLOR_BGR2RGB),
            )
        if matte_output:
            back_img20 = cv2.resize(back_img20, (C0, R0))
            comp_im_tr2 = composite4(fg_out0, back_img20, alpha_out0)
            cv2.imwrite(
                os.path.join(
                    result_path,
                    filename.replace("_img", "_matte" + output_suffix).format(i),
                ),
                cv2.cvtColor(comp_im_tr2, cv2.COLOR_BGR2RGB),
            )
# NOTE(review): this fragment starts inside a tab-indented block whose header
# (presumably an if/else on the input mode) is not in this chunk; as-is it
# does not parse at module level.
	args.video=False
	print('Using image mode')
	#target background path (single image, converted to RGB)
	back_img10=cv2.imread(args.target_back); back_img10=cv2.cvtColor(back_img10,cv2.COLOR_BGR2RGB);
	#Green-screen background: shape of back_img10, channels set to 120, 255, 155
	back_img20=np.zeros(back_img10.shape); back_img20[...,0]=120; back_img20[...,1]=255; back_img20[...,2]=155;



#initialize network
#Uses the first glob match for netG_epoch_*; glob order depends on the file system.
fo=glob.glob(model_main_dir + 'netG_epoch_*')
model_name1=fo[0]
netM=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=7,n_blocks2=3)
netM=nn.DataParallel(netM)
netM.load_state_dict(torch.load(model_name1))
netM.cuda(); netM.eval()
cudnn.benchmark=True
reso=(512,512) #input resolution to the network

#load captured background for video mode, fixed camera
if args.back is not None:
	bg_im0=cv2.imread(args.back)
	bg_im0=cv2.cvtColor(bg_im0,cv2.COLOR_BGR2RGB)


#Create a list of test images (regular files ending in _img.png, sorted by name)
test_imgs = [f for f in os.listdir(data_path) if
			   os.path.isfile(os.path.join(data_path, f)) and f.endswith('_img.png')]
test_imgs.sort()

#output directory
#NOTE(review): the code for this step is cut off at the end of this chunk.
示例#7
0
def _parse_train_args():
    """Parse the command-line options for training.

    ``--batch_size``, ``--reso`` and ``--init_model`` are required because
    training cannot proceed without them. ``--workers`` falls back to the
    batch size when omitted.
    """
    parser = argparse.ArgumentParser(
        description='Training Background Matting on Adobe Dataset')
    parser.add_argument('-n',
                        '--name',
                        type=str,
                        help='Name of tensorboard and model saving folders')
    parser.add_argument('-bs',
                        '--batch_size',
                        type=int,
                        required=True,
                        help='Batch Size')
    parser.add_argument('-res',
                        '--reso',
                        type=int,
                        required=True,
                        help='Input image resolution')
    parser.add_argument('-init_model',
                        '--init_model',
                        type=str,
                        required=True,
                        help='Initial model file')
    parser.add_argument('-w',
                        '--workers',
                        type=int,
                        default=None,
                        help='Number of worker to load data')
    parser.add_argument('-ep',
                        '--epochs',
                        type=int,
                        default=15,
                        help='Maximum Epoch')
    parser.add_argument(
        '-n_blocks1',
        '--n_blocks1',
        type=int,
        default=7,
        help='Number of residual blocks after Context Switching')
    parser.add_argument('-n_blocks2',
                        '--n_blocks2',
                        type=int,
                        default=3,
                        help='Number of residual blocks for Fg and alpha each')

    args = parser.parse_args()
    if args.workers is None:
        args.workers = args.batch_size
    return args


def main():
    """Train the matting generator ``netG`` on captured video data.

    A frozen network ``netB`` loaded from ``--init_model`` supplies
    pseudo-supervision targets for alpha and foreground. The discriminator
    ``netD`` gives an adversarial loss on the predicted foreground composited
    onto a different background (a permuted in-batch background or
    ``back-rnd``). Scalars and images go to ``tb_summary/<name>``;
    checkpoints go to ``models/<name>`` after every other epoch, when the
    pseudo-supervision weight is also halved.
    """
    print(f'Is CUDA available: {torch.cuda.is_available()}')
    args = _parse_train_args()

    ##Directories
    tb_dir = f'tb_summary/{args.name}'
    model_dir = f'models/{args.name}'
    os.makedirs(model_dir, exist_ok=True)
    os.makedirs(tb_dir, exist_ok=True)

    ## Input list
    data_config_train = {'reso': (args.reso, args.reso)}

    # DATA LOADING
    print('\n[Phase 1] : Data Preparation')

    traindata = VideoData(csv_file='Video_data_train.csv',
                          data_config=data_config_train,
                          transform=None)

    train_loader = DataLoader(traindata,
                              batch_size=args.batch_size,
                              shuffle=True,
                              num_workers=args.workers,
                              collate_fn=collate_filter_none)

    print('\n[Phase 2] : Initialization')

    netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=args.n_blocks1,
                             n_blocks2=args.n_blocks2)
    netB = nn.DataParallel(netB)
    netB.load_state_dict(torch.load(args.init_model))
    netB.cuda()
    netB.eval()
    for param in netB.parameters():  # freeze netB
        param.requires_grad = False

    netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                             output_nc=4,
                             n_blocks1=args.n_blocks1,
                             n_blocks2=args.n_blocks2)
    netG.apply(conv_init)
    netG = nn.DataParallel(netG)
    netG.cuda()
    torch.backends.cudnn.benchmark = True

    netD = MultiscaleDiscriminator(input_nc=3,
                                   num_D=1,
                                   norm_layer=nn.InstanceNorm2d,
                                   ndf=64)
    netD.apply(conv_init)
    netD = nn.DataParallel(netD)
    netD.cuda()

    # Loss
    l1_loss = alpha_loss()
    c_loss = compose_loss()
    g_loss = alpha_gradient_loss()
    GAN_loss = GANloss()

    optimizerG = Adam(netG.parameters(), lr=1e-4)
    optimizerD = Adam(netD.parameters(), lr=1e-5)

    log_writer = SummaryWriter(tb_dir)

    print('Starting Training')
    step = 50  # print/image-log interval in iterations

    KK = len(train_loader)

    wt = 1  # weight of the pseudo-supervised losses
    for epoch in range(args.epochs):

        netG.train()
        netD.train()

        lG, lD, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0

        t0 = get_time()

        for i, data in enumerate(train_loader):
            global_step = epoch * KK + i + 1

            # Initiating
            bg = data['bg'].cuda()
            image = data['image'].cuda()
            seg = data['seg'].cuda()
            multi_fr = data['multi_fr'].cuda()
            seg_gt = data['seg-gt'].cuda()
            back_rnd = data['back-rnd'].cuda()

            mask0 = torch.ones(seg.shape).cuda()

            tr0 = get_time()

            # pseudo-supervision; masks are built on-device (the CPU
            # FloatTensor round trip produced the same float32 CUDA tensor)
            alpha_pred_sup, fg_pred_sup = netB(image, bg, seg, multi_fr)
            mask = (alpha_pred_sup > -0.98).float()

            mask1 = (seg_gt > 0.95).float()

            ## Train Generator

            alpha_pred, fg_pred = netG(image, bg, seg, multi_fr)

            ##pseudo-supervised losses
            al_loss = l1_loss(
                alpha_pred_sup, alpha_pred,
                mask0) + 0.5 * g_loss(alpha_pred_sup, alpha_pred, mask0)
            fg_loss = l1_loss(fg_pred_sup, fg_pred, mask)

            # compose into same background
            comp_loss = c_loss(image, alpha_pred, fg_pred, bg, mask1)

            # randomly permute the background
            perm = torch.LongTensor(np.random.permutation(bg.shape[0]))
            bg_sh = bg[perm, :, :, :]

            al_mask = (alpha_pred > 0.95).float()

            # Choose the target background for composition
            # back_rnd: contains separate set of background videos captured
            # bg_sh: contains randomly permuted captured background from the same minibatch
            if np.random.random_sample() > 0.5:
                bg_sh = back_rnd

            image_sh = compose_image_withshift(
                alpha_pred, image * al_mask + fg_pred * (1 - al_mask), bg_sh,
                seg)

            fake_response = netD(image_sh)

            loss_ganG = GAN_loss(fake_response, label_type=True)

            lossG = loss_ganG + wt * (0.05 * comp_loss + 0.05 * al_loss +
                                      0.05 * fg_loss)

            optimizerG.zero_grad()

            lossG.backward()
            optimizerG.step()

            # Train Discriminator

            fake_response = netD(image_sh)
            real_response = netD(image)

            loss_ganD_fake = GAN_loss(fake_response, label_type=False)
            loss_ganD_real = GAN_loss(real_response, label_type=True)

            lossD = (loss_ganD_real + loss_ganD_fake) * 0.5

            # Update discriminator for every 5 generator update
            if i % 5 == 0:
                optimizerD.zero_grad()
                lossD.backward()
                optimizerD.step()

            lG += lossG.data
            lD += lossD.data
            alL += al_loss.data
            fgL += fg_loss.data
            compL += comp_loss.data

            for tag, value in (
                ('Generator Loss', lossG),
                ('Discriminator Loss', lossD),
                ('Generator Loss: Fake', loss_ganG),
                ('Discriminator Loss: Real', loss_ganD_real),
                ('Discriminator Loss: Fake', loss_ganD_fake),
                ('Generator Loss: Alpha', al_loss),
                ('Generator Loss: Fg', fg_loss),
                ('Generator Loss: Comp', comp_loss),
            ):
                log_writer.add_scalar(tag, value.data, global_step)

            t1 = get_time()

            elapse += t1 - t0
            elapse_run += t1 - tr0
            t0 = t1

            if i % step == (step - 1):
                print(f'[{epoch + 1}, {i + 1:5d}] '
                      f'Gen-loss: {lG / step:.4f} '
                      f'Disc-loss: {lD / step:.4f} '
                      f'Alpha-loss: {alL / step:.4f} '
                      f'Fg-loss: {fgL / step:.4f} '
                      f'Comp-loss: {compL / step:.4f} '
                      f'Time-all: {elapse / step:.4f} '
                      f'Time-fwbw: {elapse_run / step:.4f}')
                lG, lD, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0

                # Offset the per-epoch index so image steps keep increasing
                # across epochs like the scalar steps (identical to plain `i`
                # in the first epoch). Assumes write_tb_log's last argument
                # is the TensorBoard step -- confirm.
                img_step = epoch * KK + i
                write_tb_log(image, 'image', log_writer, img_step)
                write_tb_log(seg, 'seg', log_writer, img_step)
                write_tb_log(alpha_pred_sup, 'alpha-sup', log_writer, img_step)
                write_tb_log(alpha_pred, 'alpha_pred', log_writer, img_step)
                write_tb_log(fg_pred_sup * mask, 'fg-pred-sup', log_writer, img_step)
                write_tb_log(fg_pred * mask, 'fg_pred', log_writer, img_step)

                # composition
                alpha_pred = (alpha_pred + 1) / 2
                comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
                write_tb_log(comp, 'composite-same', log_writer, img_step)
                write_tb_log(image_sh, 'composite-diff', log_writer, img_step)

                del comp

            del bg, image, seg, multi_fr, seg_gt, back_rnd
            del mask0, alpha_pred_sup, fg_pred_sup, mask, mask1
            del alpha_pred, fg_pred, al_loss, fg_loss, comp_loss
            del bg_sh, image_sh, fake_response, real_response
            del lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG

        if epoch % 2 == 0:
            ep = epoch + 1
            torch.save(netG.state_dict(), f'{model_dir}/netG_epoch_{ep}.pth')
            torch.save(optimizerG.state_dict(),
                       f'{model_dir}/optimG_epoch_{ep}.pth')
            torch.save(netD.state_dict(), f'{model_dir}/netD_epoch_{ep}.pth')
            torch.save(optimizerD.state_dict(),
                       f'{model_dir}/optimD_epoch_{ep}.pth')

            # Change weight every 2 epoch to put more stress on discriminator weight and less on pseudo-supervision
            wt = wt / 2
示例#8
0
    def __init__(self, device=None, jit=True):
        """Set up the data loader, networks, losses, optimizers and log dirs.

        Args:
            device: When ``'cuda'``, netB, netG and netD are moved to the GPU
                (and cuDNN benchmarking is enabled); otherwise they stay on
                the CPU.
            jit: Stored on ``self.jit``; presumably consulted by
                ``self._maybe_trace()``, which is defined elsewhere.
        """
        self.device = device
        self.jit = jit
        self.opt = Namespace(
            **{
                'n_blocks1': 7,
                'n_blocks2': 3,
                'batch_size': 1,
                'resolution': 512,
                'name': 'Real_fixed'
            })

        data_config_train = {
            'reso': (self.opt.resolution, self.opt.resolution)
        }
        # NOTE(review): the CSV path is relative to the current working directory.
        traindata = VideoData(csv_file='Video_data_train.csv',
                              data_config=data_config_train,
                              transform=None)
        self.train_loader = torch.utils.data.DataLoader(
            traindata,
            batch_size=self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.batch_size,
            collate_fn=_collate_filter_none)

        # Frozen network netB (not loaded from a checkpoint here).
        netB = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        if self.device == 'cuda':
            netB.cuda()
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        # Trainable generator, weights initialised with conv_init.
        netG = ResnetConditionHR(input_nc=(3, 3, 1, 4),
                                 output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        netG.apply(conv_init)
        self.netG = netG

        if self.device == 'cuda':
            self.netG.cuda()
            # TODO(asuhan): is this needed?
            torch.backends.cudnn.benchmark = True

        # Discriminator; unlike netB/netG it is wrapped in DataParallel.
        netD = MultiscaleDiscriminator(input_nc=3,
                                       num_D=1,
                                       norm_layer=nn.InstanceNorm2d,
                                       ndf=64)
        netD.apply(conv_init)
        netD = nn.DataParallel(netD)
        self.netD = netD
        if self.device == 'cuda':
            self.netD.cuda()

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()

        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

        # CI-specific output root (unchanged). exist_ok avoids the race in
        # check-then-create when several processes start at once.
        base_dir = '/home/circleci/project/benchmark/models/Background-Matting'
        tb_dir = base_dir + '/TB_Summary/' + self.opt.name
        os.makedirs(tb_dir, exist_ok=True)
        self.log_writer = SummaryWriter(tb_dir)
        self.model_dir = base_dir + '/Models/' + self.opt.name
        os.makedirs(self.model_dir, exist_ok=True)

        self._maybe_trace()