Example #1
File: run.py  Project: yqGANs/MODNet
    parser.add_argument('--video', type=str, required=True, help='input video file')
    parser.add_argument('--result-type', type=str, default='fg', choices=['fg', 'matte'], 
                        help='matte - save the alpha matte; fg - save the foreground')
    parser.add_argument('--fps', type=int, default=30, help='fps of the result video')

    print('Get CMD Arguments...')
    args = parser.parse_args()

    if not os.path.exists(args.video):
        print('Cannot find the input video: {0}'.format(args.video))
        exit()

    print('Load pre-trained MODNet...')
    pretrained_ckpt = './pretrained/modnet_webcam_portrait_matting.ckpt'
    modnet = MODNet(backbone_pretrained=False)
    modnet = nn.DataParallel(modnet)

    GPU = torch.cuda.device_count() > 0
    if GPU:
        print('Use GPU...')
        modnet = modnet.cuda()
        modnet.load_state_dict(torch.load(pretrained_ckpt))
    else:
        print('Use CPU...')
        modnet.load_state_dict(torch.load(pretrained_ckpt, map_location=torch.device('cpu')))
    modnet.eval()

    result = os.path.splitext(args.video)[0] + '_{0}.mp4'.format(args.result_type)
    alpha_matte = args.result_type == 'matte'
    matting(args.video, result, alpha_matte, args.fps)
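
(Note: the excerpt begins mid-way through the argument parser and calls a `matting` helper defined elsewhere in run.py. A minimal sketch of the prologue it implies, for context; the MODNet import path is an assumption based on the repository layout:)

import os
import argparse

import torch
import torch.nn as nn

from src.models.modnet import MODNet  # assumed import path

parser = argparse.ArgumentParser(description='MODNet video matting')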
Example #2
                    default=False,
                    help="use cpu inferece")
args = parser.parse_args()

torch_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
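# ToTensor() scales uint8 pixels to [0, 1]; the Normalize above then maps
# them to [-1, 1], the input range used throughout the MODNet demos.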

print('Load pre-trained MODNet...')
pretrained_ckpt = args.ckpt_path
modnet = MODNet(backbone_pretrained=False)
modnet = nn.DataParallel(modnet)
modnet.load_state_dict(torch.load(pretrained_ckpt, map_location='cpu'))
if not args.cpu:
    modnet.cuda()
modnet.eval()

print('Init WebCam...')
cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1280)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 720)

print('Start matting...')
while True:
    _, frame_np = cap.read()
    frame_np = cv2.cvtColor(frame_np, cv2.COLOR_BGR2RGB)
    # Resize to 512 px high, center-crop to 672 px wide, and mirror the view.
    # Note: cv2.resize takes interpolation as a keyword argument; passed
    # positionally it would be interpreted as the `dst` parameter instead.
    frame_np = cv2.resize(frame_np, (910, 512), interpolation=cv2.INTER_AREA)
    frame_np = frame_np[:, 120:792, :]
    frame_np = cv2.flip(frame_np, 1)
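
(The excerpt ends mid-loop. A hedged sketch of the inference and display steps that would complete it, reusing torch_transforms and modnet from above and assuming `import numpy as np` from the truncated header; the white-background compositing is an illustrative choice, not necessarily the repository's exact code:)

    # Preprocess: ToTensor + Normalize, then add a batch dimension (1 x 3 x H x W).
    frame_tensor = torch_transforms(frame_np)[None, :, :, :]
    if not args.cpu:
        frame_tensor = frame_tensor.cuda()

    # MODNet's inference path returns (semantic, detail, matte) predictions.
    with torch.no_grad():
        _, _, matte_tensor = modnet(frame_tensor, True)

    # Broadcast the 1-channel matte to H x W x 3 and composite over white.
    matte_np = matte_tensor[0].repeat(3, 1, 1).data.cpu().numpy().transpose(1, 2, 0)
    fg_np = matte_np * frame_np + (1 - matte_np) * 255
    view_np = cv2.cvtColor(np.uint8(fg_np), cv2.COLOR_RGB2BGR)

    cv2.imshow('MODNet WebCam', view_np)
    if cv2.waitKey(1) & 0xFF == ord('q'):
        break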