示例#1
0
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import os
from src.models.modnet import MODNet

os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Image -> tensor transform: ToTensor, then normalize every channel with
# mean 0.5 and std 0.5.
torch_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

print('Load pre-trained MODNet...')
pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet).cuda()
modnet.load_state_dict(torch.load(pretrained_ckpt))
modnet.eval()

print('Init WebCam...')
# NOTE(review): despite the message above, the source is a video file; the
# frame-size requests below probably have no effect on a file source.
cap = cv2.VideoCapture("./demo/input/align2.mp4")
if not cap.isOpened():
    raise IOError('Cannot open video: ./demo/input/align2.mp4')
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

# Step1 : Background Inpainting
bg_np = cv2.imread("./demo/input/first_frame.png")
if bg_np is None:
    # cv2.imread returns None instead of raising when it cannot read a file.
    raise IOError('Cannot read image: ./demo/input/first_frame.png')
# Crop rows 0..671 and columns 480..1439 (a 672 x 960 region).
background_np = bg_np[0:672, 480:1440, :]

# The third positional argument of cv2.resize is `dst`, so the interpolation
# flag must be passed by keyword to take effect.
bg_LR_np = cv2.resize(background_np, (512, 352), interpolation=cv2.INTER_AREA)
# NOTE(review): bg_LR_np is in OpenCV's BGR order, while Image.fromarray
# treats a 3-channel array as RGB — confirm the consumer expects this.
bg_LR_PIL = Image.fromarray(bg_LR_np)
示例#2
0
        exit()
    if not os.path.exists(args.ckpt_path):
        print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
        exit()

    # define hyper-parameters
    ref_size = 512

    # define image to tensor transform
    im_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    if torch.cuda.is_available():
        modnet = modnet.cuda()
        weights = torch.load(args.ckpt_path)
    else:
        weights = torch.load(args.ckpt_path, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)
    modnet.eval()

    # inference images
    im_names = os.listdir(args.input_path)
    for im_name in im_names:
        print('Process image: {0}'.format(im_name))
示例#3
0
文件: run.py 项目: yqGANs/MODNet
    # Parse CLI options, load the MODNet webcam checkpoint on GPU or CPU and
    # write the matted video next to the input.
    parser = argparse.ArgumentParser()
    parser.add_argument('--video', type=str, required=True, help='input video file')
    parser.add_argument('--result-type', type=str, default='fg', choices=['fg', 'matte'],
                        help='matte - save the alpha matte; fg - save the foreground')
    parser.add_argument('--fps', type=int, default=30, help='fps of the result video')

    print('Get CMD Arguments...')
    args = parser.parse_args()

    if not os.path.exists(args.video):
        # parser.error prints usage + message to stderr and exits with a
        # non-zero status; the site-provided exit() is meant for the
        # interactive shell and reported success (status 0) here.
        parser.error('Cannot find the input video: {0}'.format(args.video))

    print('Load pre-trained MODNet...')
    pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    # Kept as a module-level name: code elsewhere (e.g. matting) may read it.
    GPU = torch.cuda.device_count() > 0
    if GPU:
        print('Use GPU...')
        modnet = modnet.cuda()
        weights = torch.load(pretrained_ckpt)
    else:
        print('Use CPU...')
        weights = torch.load(pretrained_ckpt, map_location=torch.device('cpu'))
    modnet.load_state_dict(weights)
    modnet.eval()

    # e.g. "clip.mp4" -> "clip_fg.mp4" or "clip_matte.mp4"
    result = os.path.splitext(args.video)[0] + '_{0}.mp4'.format(args.result_type)
    alpha_matte = args.result_type == 'matte'
    matting(args.video, result, alpha_matte, args.fps)
示例#4
0
        exit()
    if not os.path.exists(args.ckpt_path):
        print('Cannot find ckpt path: {0}'.format(args.ckpt_path))
        exit()

    # define hyper-parameters
    ref_size = 512

    # define image to tensor transform
    im_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet).cuda()
    modnet.load_state_dict(torch.load(args.ckpt_path))
    modnet.eval()

    # inference images
    im_names = os.listdir(args.input_path)
    for im_name in im_names:
        print('Process image: {0}'.format(im_name))

        # read image
        im = Image.open(os.path.join(args.input_path, im_name))

        # unify image channels to 3
        im = np.asarray(im)
        if len(im.shape) == 2:
示例#5
0
import torch
import torch.nn as nn
from src.models.modnet import MODNet
from torch.autograd import Variable

# Export MODNet to ONNX from a plain (non-DataParallel) state dict.
modnet = MODNet(backbone_pretrained=False)

# 'modnet_512x672_float32.pth' was produced from the DataParallel webcam
# checkpoint by saving the wrapped module's weights:
#   modnet = nn.DataParallel(modnet).cuda()
#   modnet.load_state_dict(torch.load('pretrained/modnet_webcam_portrait_matting.ckpt'))
#   torch.save(modnet.module.state_dict(), 'modnet_512x672_float32.pth')
# Those tensors may live on CUDA, so map them to the CPU explicitly; without
# map_location the load fails on machines without a GPU.
modnet.load_state_dict(torch.load('modnet_512x672_float32.pth', map_location='cpu'))
modnet.eval()

# torch.autograd.Variable is a deprecated no-op wrapper; a plain tensor is
# what torch.onnx.export traces.
dummy_input = torch.randn(1, 3, 512, 512)
torch.onnx.export(modnet,
                  dummy_input,
                  'modnet_512x512_float32.onnx',
                  export_params=True,
                  opset_version=12)
    #GPU = True if torch.cuda.device_count() > 0 else False

    # define hyper-parameters
    ref_size = 512

    # define image to tensor transform
    im_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        ]
    )

    # create MODNet and load the pre-trained ckpt
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)#.cuda()
    modnet.load_state_dict(torch.load(args.ckpt_path, map_location=torch.device('cpu')))
    modnet.eval()

    # inference images

    frame_np = cv2.imread('{}'.format(args.input_path))
    #print(frame_np.shape)
    (width, height, channel) = frame_np.shape
    #print(frame_np)
    frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
    frame_np = cv2.resize(frame_np, (910, 512), cv2.INTER_AREA)
    frame_np = frame_np[:, 120:792, :]
    frame_np = cv2.flip(frame_np, 1)
示例#7
0
class BGRemove():
    """Remove or replace the background of images, videos and webcam frames.

    The transform, device and MODNet network below are class attributes:
    they are built once, when the class body executes, and are shared by
    every instance.  Each instance loads checkpoint weights into that shared
    network in ``parameter_load``.
    """

    # Target size (pixels) used by pre_process to pick the network input size.
    ref_size = 512

    # Image -> tensor transform: ToTensor, then normalize every channel with
    # mean 0.5 and std 0.5.
    im_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Create MODNet once for the class; weights are loaded in parameter_load.
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)
    if device == 'cuda':
        modnet = modnet.cuda()

    def __init__(self, ckpt_path):
        """Load the checkpoint at ``ckpt_path`` into the shared network."""
        self.parameter_load(ckpt_path)

    def parameter_load(self, ckpt_path):
        """Load a state dict into the shared network and switch it to eval mode.

        Tensors are mapped onto ``BGRemove.device`` so a checkpoint saved on
        GPU can also be loaded on a CPU-only machine.
        """
        BGRemove.modnet.load_state_dict(
            torch.load(ckpt_path, map_location=BGRemove.device))
        BGRemove.modnet.eval()

    def file_load(self, filename):
        """Read an image with OpenCV and return it as an H x W x 3 array.

        Raises:
            IOError: if OpenCV cannot read ``filename`` (cv2.imread signals
                failure by returning None rather than raising).
        """
        im = cv2.imread(filename)
        if im is None:
            raise IOError('Cannot read image file: {}'.format(filename))
        # Normalize the channel count to 3: replicate a single channel and
        # drop a fourth (alpha) channel.
        if im.ndim == 2:
            im = im[:, :, None]
        if im.shape[2] == 1:
            im = np.repeat(im, 3, axis=2)
        elif im.shape[2] == 4:
            im = im[:, :, 0:3]
        return im

    def dir_check(self, path):
        """Create directory ``path`` if needed and return it with a trailing '/'."""
        os.makedirs(path, exist_ok=True)
        if not path.endswith('/'):
            path += '/'
        return path

    def pre_process(self, im):
        """Turn an OpenCV (BGR) H x W x 3 image into a network input batch.

        Side effects: keeps an untouched copy in ``self.original_im`` and the
        input size in ``self.height`` / ``self.width`` for post_process.

        Returns:
            A 1 x 3 x H' x W' tensor on ``BGRemove.device`` whose sides are
            multiples of 32.
        """
        self.original_im = im.copy()

        # Frames come from cv2 and are BGR; MODNet's own demos feed RGB, so
        # swap the channel order for the network input only (compositing
        # keeps the original BGR copy).  The contiguous copy is needed
        # because torch cannot wrap arrays with negative strides.
        rgb = np.ascontiguousarray(im[:, :, ::-1])

        # Convert to a tensor and add the mini-batch dimension.
        tensor = BGRemove.im_transform(rgb)[None, :, :, :]
        _, _, im_h, im_w = tensor.shape
        self.height, self.width = im_h, im_w

        ref = BGRemove.ref_size
        if max(im_h, im_w) < ref or min(im_h, im_w) > ref:
            # Rescale so that the shorter side equals ref_size.
            if im_w >= im_h:
                im_rh = ref
                im_rw = int(im_w / im_h * ref)
            else:
                im_rw = ref
                im_rh = int(im_h / im_w * ref)
        else:
            im_rh, im_rw = im_h, im_w

        # Round both sides down to a multiple of 32 (presumably required by
        # the network's downsampling stages).
        im_rw -= im_rw % 32
        im_rh -= im_rh % 32
        tensor = F.interpolate(tensor, size=(im_rh, im_rw), mode='area')
        if BGRemove.device == 'cuda':
            tensor = tensor.cuda()
        return tensor

    def _predict(self, im):
        """Run the shared MODNet on a pre-processed batch and return the matte."""
        # Inference only: skip autograd bookkeeping to save memory and time.
        with torch.no_grad():
            _, _, matte = BGRemove.modnet(im, inference=False)
        return matte

    def post_process(self,
                     mask_data,
                     background=False,
                     backgound_path='assets/background/background.jpg'):
        """Composite ``self.original_im`` over a background using the matte.

        Args:
            mask_data: matte tensor from the network; assumed N x 1 x h x w
                (it is repeated to 3 channels below) — TODO confirm.
            background: if True, composite over the image at
                ``backgound_path`` resized to the input size; otherwise over
                plain white.
            backgound_path: background image file (name kept as-is for
                backward compatibility).

        Returns:
            A floating-point H x W x 3 array in the input's channel order.
        """
        matte = F.interpolate(mask_data,
                              size=(self.height, self.width),
                              mode='area')
        matte = matte.repeat(1, 3, 1, 1)
        matte = matte[0].data.cpu().numpy().transpose(1, 2, 0)
        height, width, _ = matte.shape
        if background:
            back_image = self.file_load(backgound_path)
            # interpolation must be a keyword: the 3rd positional arg is dst.
            back_image = cv2.resize(back_image, (width, height),
                                    interpolation=cv2.INTER_AREA)
        else:
            back_image = np.full(self.original_im.shape, 255.0)

        return matte * self.original_im + (1 - matte) * back_image

    def image(self, filename, background=False, output='output/', save=True):
        """Matte a single image file.

        With ``save=True`` the composite is written to ``output`` under the
        input's file name and the status message from ``save`` is returned.
        Otherwise the input and the composite are shown side by side (scaled
        to 720 px high) until "q" is pressed; that side-by-side image is also
        written to ``output`` and None is returned.
        """
        output = self.dir_check(output)

        self.im_name = os.path.basename(filename)
        im = self.pre_process(self.file_load(filename))
        matte = self.post_process(self._predict(im), background)

        if save:
            return self.save(np.uint8(matte), output)

        h, w, _ = matte.shape
        r_h, r_w = 720, int((w / h) * 720)
        image = cv2.resize(self.original_im, (r_w, r_h),
                           interpolation=cv2.INTER_AREA)
        matte = cv2.resize(matte, (r_w, r_h), interpolation=cv2.INTER_AREA)

        full_image = np.uint8(np.concatenate((image, matte), axis=1))
        self.save(full_image, output)

        exit_key = ord('q')
        window_name = 'MODNet - {} [Press "Q" To Exit]'.format(self.im_name)
        while True:
            cv2.imshow(window_name, full_image)
            # waitKey takes a delay in milliseconds, not the key to wait for.
            if (cv2.waitKey(30) & 0xFF) == exit_key:
                break
        cv2.destroyAllWindows()

    def video(self, filename, background=False, output='output/'):
        """Matte a video file frame by frame.

        Writes ``<output><input base name>.mp4`` (MP4V, 20 fps); every output
        frame is the input frame and its composite side by side.  Prints a
        message and returns if the video cannot be opened.
        """
        output = self.dir_check(output)
        # Keep the base name, always use the .mp4 extension.
        output_name = os.path.splitext(os.path.basename(filename))[0] + '.mp4'
        fourcc = cv2.VideoWriter_fourcc(*'MP4V')

        cap = cv2.VideoCapture(filename)
        if not cap.isOpened():
            print("Error opening video stream or file")
            cap.release()
            return

        out = None
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                if out is None:
                    # Size the writer from the first decoded frame: input and
                    # composite are placed side by side.
                    height, width, _ = frame.shape
                    out = cv2.VideoWriter(output + output_name, fourcc, 20.0,
                                          (2 * width, height))

                print('Video is processing..', end='\r')
                im = self.pre_process(frame)
                matte = np.uint8(
                    self.post_process(self._predict(im), background))
                # Both halves are uint8 and have the frame's size, so the
                # concatenation already has the writer's (2 * width, height).
                out.write(np.concatenate((frame, matte), axis=1))
        finally:
            cap.release()
            if out is not None:
                out.release()
            cv2.destroyAllWindows()

    def folder(self, foldername, background=False, output='output/'):
        """Matte every entry of ``foldername`` and save the composites to ``output``.

        Best effort: an entry that fails (unreadable file, sub-directory, ...)
        is reported and skipped.
        """
        output = self.dir_check(output)
        foldername = self.dir_check(foldername)

        for filename in os.listdir(foldername):
            try:
                self.im_name = filename
                im = self.pre_process(self.file_load(foldername + filename))
                matte = self.post_process(self._predict(im), background)
                # Same uint8 conversion as image() so both paths write
                # identical files.
                status = self.save(np.uint8(matte), output)
                print(status)
            except Exception as err:
                print(
                    'There is an error for {} file/folder ({})'.format(
                        foldername + filename, err))

    def webcam(self, background=False):
        """Show live webcam frames next to their composites until "q" is pressed."""
        cap = cv2.VideoCapture(0)
        cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
        cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)
        width, height = 455, 512

        exit_key = ord('q')
        try:
            while True:
                ret, frame_np = cap.read()
                if not ret:
                    # No frame from the camera (unavailable or disconnected).
                    break
                frame_np = cv2.resize(frame_np, (width, height),
                                      interpolation=cv2.INTER_AREA)
                im = self.pre_process(frame_np)
                processed_image = self.post_process(self._predict(im),
                                                    background)

                # Both halves are height x width, so the result is already
                # height x (2 * width).
                full_image = np.uint8(
                    np.concatenate((frame_np, processed_image), axis=1))

                cv2.imshow('MODNet - WebCam [Press "Q" To Exit]', full_image)
                # waitKey takes a delay in milliseconds, not the key to wait for.
                if (cv2.waitKey(1) & 0xFF) == exit_key:
                    break
        finally:
            cap.release()
            cv2.destroyAllWindows()

    def save(self, matte, output_path='output/'):
        """Write ``matte`` to ``output_path``/``self.im_name`` and return a status message.

        cv2.imwrite reports most failures by returning False and may raise
        cv2.error for others (e.g. an unsupported extension); both produce
        the error message.
        """
        path = os.path.join(output_path, self.im_name)
        try:
            saved = cv2.imwrite(path, matte)
        except cv2.error:
            saved = False
        if saved:
            return "Successfully saved {}".format(path)
        return "Error while saving {}".format(path)
示例#8
0
# -*- coding: utf-8 -*-

import torch
from src.dataset import HumanMattingDataset
from torch.utils.data import DataLoader

from src.models.modnet import MODNet
from src.trainer import supervised_training_iter

bs = 16         # batch size
lr = 0.01       # learning rate
epochs = 40     # total epochs

modnet = torch.nn.DataParallel(MODNet()).cuda()
optimizer = torch.optim.SGD(modnet.parameters(), lr=lr, momentum=0.9)
# step_size is expressed in epochs (int(0.25 * epochs) = 10): the LR is
# multiplied by 0.1 every 10 epochs, so step() is called once per epoch below.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=int(0.25 * epochs), gamma=0.1)

train_file = "train data list"  # TODO: placeholder, set to the real training file list

train_dataset = HumanMattingDataset(files=train_file)
dataloader = DataLoader(train_dataset, batch_size=bs, shuffle=True)
iter_time = len(dataloader)  # number of batches per epoch
for epoch in range(0, epochs):
    for idx, (image, trimap, gt_matte) in enumerate(dataloader):
        image = image.to(device='cuda')
        trimap = trimap.to(device='cuda')
        gt_matte = gt_matte.to(device='cuda')
        semantic_loss, detail_loss, matte_loss = \
            supervised_training_iter(modnet, optimizer, image, trimap, gt_matte)
    # Once per epoch: stepping inside the batch loop would decay the LR every
    # 10 batches instead of every 10 epochs.
    lr_scheduler.step()