Пример #1
0
    def forward(self, th):
        """Propagate channel-0 data from frame ``th`` into frame ``th + 1``.

        The frame that gets updated is ``cur = th + 1``; its predecessor is
        ``cur - 1``. On the first iteration (``self.iter_num == 0``) the frame
        and its mask are read from ``self.img_root`` / ``self.mask_root``;
        on later iterations they are taken from ``self.result_pool`` /
        ``self.label_pool``. The predecessor's channel-0 pixels and time
        stamps are warped with ``flo.get_warp_label`` using the ``.flo`` /
        ``.rflo`` flow files, then every unmasked pixel is overwritten with
        the frame's own value and stamped with ``cur``.

        Args:
            th: index of the predecessor frame; the method works on ``th + 1``.

        Side effects:
            Rebinds ``self.image`` and ``self.label``. Masked pixels of
            ``self.image`` are zeroed in place, which after the first
            iteration also modifies ``self.result_pool[th + 1]``. Writes
            channel 0 of ``self.results[th + 1]`` and
            ``self.time_stamp[th + 1]``.
        """
        cur = th + 1
        prev = cur - 1
        first_pass = self.iter_num == 0

        # Current frame: from disk on the first pass, otherwise the output
        # stored by the previous iteration.
        if not first_pass:
            self.image = self.result_pool[cur]
        else:
            frame_path = os.path.join(self.img_root, self.frame_name_list[cur])
            self.image = cv2.imread(frame_path)
            self.image = cv2.resize(self.image, (self.shape[1], self.shape[0]))

        # The .flo file is indexed by the predecessor and the .rflo file by
        # the current frame, both offset by flow_start_no.
        flo_path = os.path.join(self.flow_root,
                                '%05d.flo' % (prev + self.flow_start_no))
        rflo_path = os.path.join(self.flow_root,
                                 '%05d.rflo' % (cur + self.flow_start_no))
        flo_field = flo.readFlow(flo_path)
        rflo_field = flo.readFlow(rflo_path)
        flo_field = flo.flow_tf(flo_field, self.image.shape)
        rflo_field = flo.flow_tf(rflo_field, self.image.shape)

        if not first_pass:
            self.label = self.label_pool[cur]
        else:
            mask_path = os.path.join(self.mask_root,
                                     '%05d.png' % (cur + self.flow_start_no))
            self.label = cv2.imread(mask_path, cv2.IMREAD_COLOR)
            self.label = cv2.resize(self.label,
                                    (self.image.shape[1], self.image.shape[0]),
                                    interpolation=cv2.INTER_NEAREST)

        # Multi-channel masks: keep channel 0 only.
        if self.label.ndim == 3:
            self.label = self.label[:, :, 0]

        if self.args.enlarge_mask and first_pass:
            k = self.args.enlarge_kernel
            self.label = cv2.dilate(self.label, np.ones((k, k), np.uint8),
                                    iterations=1)

        # Binary 0/1 mask; 1 marks a hole whose pixels are blanked out.
        self.label = (self.label > 0).astype(np.uint8)
        known = self.label == 0
        self.image[~known, :] = 0

        warped_pix = flo.get_warp_label(flo_field, rflo_field,
                                        self.results[prev][..., 0],
                                        th=self.th_warp)
        warped_ts = flo.get_warp_label(flo_field, rflo_field,
                                       self.time_stamp[prev],
                                       th=self.th_warp,
                                       value=-1)[..., 0]

        out_pix = self.results[cur]
        out_ts = self.time_stamp[cur]
        out_pix[..., 0] = warped_pix
        out_ts[..., 0] = warped_ts

        # Known pixels keep the frame's own value and are stamped with cur.
        out_pix[known, :, 0] = self.image[known, :]
        out_ts[known, 0] = cur
Пример #2
0
def _imread_checked(path, flags=None):
    """Read an image with OpenCV, raising instead of returning ``None``.

    ``cv2.imread`` signals an unreadable/missing file by returning ``None``,
    which would otherwise surface later as an opaque error in ``cv2.resize``.
    """
    image = cv2.imread(path) if flags is None else cv2.imread(path, flags)
    if image is None:
        raise FileNotFoundError('Cannot read image: %s' % path)
    return image


def _read_frame(path, dsize):
    """Read a colour frame from ``path`` resized to ``dsize`` = (width, height)."""
    return cv2.resize(_imread_checked(path), dsize)


def _read_mask(args, index, dsize):
    """Read the mask for frame ``index``, nearest-resized to ``dsize``.

    With ``args.FIX_MASK`` the single file at ``args.mask_root`` is used for
    every frame; otherwise ``<args.mask_root>/%05d.png`` is read.
    """
    if args.FIX_MASK:
        path = os.fspath(args.mask_root)
    else:
        path = os.path.join(args.mask_root, '%05d.png' % index)
    mask = _imread_checked(path, cv2.IMREAD_UNCHANGED)
    return cv2.resize(mask, dsize, interpolation=cv2.INTER_NEAREST)


def _binarize_mask(label, args, first_iter):
    """Return ``label`` as a single-channel 0/1 ``uint8`` hole mask.

    Multi-channel masks keep channel 0 only. On the first iteration the mask
    is dilated with an ``enlarge_kernel`` square kernel when
    ``args.enlarge_mask`` is set.
    """
    if label.ndim == 3:
        label = label[:, :, 0]
    if args.enlarge_mask and first_iter:
        kernel = np.ones((args.enlarge_kernel, args.enlarge_kernel), np.uint8)
        label = cv2.dilate(label, kernel, iterations=1)
    return (label > 0).astype(np.uint8)


def _read_flows(flow_root, index, digits, ext, shape):
    """Read flows ``%05d<k><ext>`` for each ``k`` in ``digits``, then ``flow_tf`` them."""
    raw = [flo.readFlow(os.path.join(flow_root, '%05d%d%s' % (index, k, ext)))
           for k in digits]
    return [flo.flow_tf(field, shape) for field in raw]


def _warp_two_hop(flows, prev_pixels, prev_stamp, channel, th_warp):
    """Warp a neighbour's pixels and time stamps through two flow pairs.

    ``flows`` is ``(f1, f2, f3, f4)``. The data is warped with ``(f1, f2)``
    and that intermediate result again with ``(f3, f4)``. ``prev_stamp`` is a
    two-channel stamp map; only ``channel`` is carried into the second hop,
    and the other channel is filled with -1.

    Returns:
        ``(pixels, stamps)``: warped pixels and single-channel stamps.
    """
    f1, f2, f3, f4 = flows
    pixels = flo.get_warp_label(f1, f2, prev_pixels, th=th_warp)
    pixels = flo.get_warp_label(f3, f4, pixels, th=th_warp)

    stamps = -np.ones(prev_stamp.shape[:2] + (2, ), dtype=int)
    stamps[..., channel] = flo.get_warp_label(
        f1, f2, prev_stamp, th=th_warp, value=-1)[..., channel]
    stamps = flo.get_warp_label(
        f3, f4, stamps, th=th_warp, value=-1)[..., channel]
    return pixels, stamps


def _merge_directions(pair, stamps, th):
    """Combine forward (channel 0) and backward (channel 1) candidates of frame ``th``.

    A stamp of -1 means that direction produced no value. Where only one
    direction has a value it is used. Where both do, the backward candidate
    is taken when ``th - ts_fwd > 0.5 * (ts_bwd - ts_fwd)``, otherwise the
    forward one.

    Returns:
        ``(merged, holes)``: merged image and boolean mask of pixels that
        neither direction filled.
    """
    fwd_missing = stamps[..., 0] == -1
    bwd_missing = stamps[..., 1] == -1
    holes = fwd_missing & bwd_missing

    merged = pair[..., 0].copy()
    merged[fwd_missing, :] = pair[fwd_missing, :, 1]

    both = ~fwd_missing & ~bwd_missing
    dist = stamps[..., 1] - stamps[..., 0]
    dist[dist < 1] = 1
    # Hard 0/1 weight: 1 selects the backward candidate.
    w2 = ((th - stamps[..., 0]) / dist > 0.5).astype(np.float64)
    blended = (pair[..., 1] * w2[..., np.newaxis] +
               pair[..., 0] * (1 - w2)[..., np.newaxis])
    merged[both, :] = blended[both, :]
    return merged, holes


def propagation(args, frame_inapint_model=None):
    """Inpaint masked regions of a frame sequence by flow-guided propagation.

    Each iteration does three passes:

    * A forward pass from frame 0 through frame ``frames_num - 2`` (channel 0).
    * A backward pass from the last frame down to frame 1 (channel 1).
    * A per-pixel merge of the two passes.

    Each step of the forward and backward passes warps the neighbour's data
    through two flow pairs. Pixels that neither pass reached become the new
    holes. While holes remain, key frames chosen by ``get_key_ids`` are
    filled by ``frame_inapint_model`` and the loop repeats. After every
    iteration frames ``1 .. frames_num - 2`` are written to
    ``<args.output_root_propagation>/inpaint_res/%05d.png``.

    Args:
        args: configuration namespace. Attributes read: ``img_root``,
            ``mask_root``, ``flow_root``, ``output_root_propagation``,
            ``img_shape`` (``img_shape[0] < 1`` keeps the first frame's
            size), ``th_warp``, ``FIX_MASK``, ``enlarge_mask``,
            ``enlarge_kernel``.
        frame_inapint_model: object whose ``forward(image, mask)`` returns an
            inpainted image; called under ``torch.no_grad()``. It is only
            needed when holes remain after propagation. The misspelled
            parameter name is kept for caller compatibility.

    Expected files:
        * Forward flows ``%05d{0,1,2,3}.flo`` and backward flows
          ``%05d{3,2,1,0}.rflo`` in ``flow_root``, for frames
          ``1 .. frames_num - 2``.
        * Masks ``%05d.png`` in ``mask_root``, or a single mask file at
          ``mask_root`` when ``FIX_MASK`` is set.

    Raises:
        ValueError: no ``.flo`` files or no frames were found, or holes
            remain and no model was supplied.
        FileNotFoundError: a frame or mask image could not be read.
    """
    img_root = args.img_root
    flow_root = args.flow_root
    output_root = args.output_root_propagation
    img_shape = args.img_shape  # may be replaced by the frame's own shape
    th_warp = args.th_warp

    st_time = time.time()

    flow_no_list = [int(x[:5]) for x in os.listdir(flow_root) if '.flo' in x]
    if not flow_no_list:
        raise ValueError('No .flo files found in %s' % flow_root)
    flow_start_no = min(flow_no_list)
    print('Flow Start no', flow_start_no)
    os.makedirs(output_root, exist_ok=True)

    frame_name_list = sorted(os.listdir(img_root))
    frames_num = len(frame_name_list)
    if frames_num == 0:
        raise ValueError('No frames found in %s' % img_root)
    # One flag per frame except the last; 1 while the frame still has holes.
    frame_inpaint_seq = np.ones(frames_num - 1)
    masked_frame_num = int(np.count_nonzero(frame_inpaint_seq > 0))
    print(masked_frame_num, 'frames need to be inpainted.')

    first = _imread_checked(os.path.join(img_root, frame_name_list[0]))
    shape = first.shape if img_shape[0] < 1 else img_shape
    print('The output shape is:', shape)
    dsize = (shape[1], shape[0])  # OpenCV sizes are (width, height)
    # Fixed template for all buffers, so their shape/dtype never depend on
    # whichever frame happened to be loaded last.
    template = cv2.resize(first, dsize)
    result_pool = [np.zeros_like(template) for _ in range(frames_num)]
    label_pool = [np.zeros_like(template) for _ in range(frames_num)]

    def load(index, mask_index, first_iter):
        """Return ``(image, label)`` for frame ``index``, holes zeroed in ``image``."""
        if first_iter:
            image = _read_frame(
                os.path.join(img_root, frame_name_list[index]), dsize)
            label = _read_mask(args, mask_index, dsize)
        else:
            image = result_pool[index]
            label = label_pool[index]
        label = _binarize_mask(label, args, first_iter)
        image[label > 0, :] = 0
        return image, label

    iter_num = 0
    while masked_frame_num > 0:
        first_iter = iter_num == 0
        results = [np.zeros(template.shape + (2, ), dtype=template.dtype)
                   for _ in range(frames_num)]
        time_stamp = [-np.ones(template.shape[:2] + (2, ), dtype=int)
                      for _ in range(frames_num)]

        # Forward pass: frame 0 -> frames_num - 2, stored in channel 0.
        print('Iter', iter_num, 'Forward Propagation')
        # NOTE(review): frame 0's mask is looked up as flow_start_no while
        # every other mask uses the frame index; these agree only when
        # flow_start_no == 0 -- confirm this is intended.
        image, label = load(0, flow_start_no, first_iter)
        if first_iter:
            print(flow_start_no)
        results[0][..., 0] = image
        time_stamp[0][label == 0, 0] = 0
        prog_bar = ProgressBar(frames_num - 2)
        for th in range(1, frames_num - 1):
            prog_bar.update()
            image, label = load(th, th, first_iter)
            flows = _read_flows(flow_root, th, (0, 1, 2, 3), '.flo',
                                image.shape)
            pixels, stamps = _warp_two_hop(flows, results[th - 1][..., 0],
                                           time_stamp[th - 1], 0, th_warp)
            known = label == 0
            results[th][..., 0] = pixels
            results[th][known, :, 0] = image[known, :]
            time_stamp[th][..., 0] = stamps
            time_stamp[th][known, 0] = th
        sys.stdout.write('\n')

        # Backward pass: last frame -> frame 1, stored in channel 1.
        print('Iter', iter_num, 'Backward Propagation')
        last = frames_num - 1
        image, label = load(last, last, first_iter)
        if first_iter:
            # The last frame is never merged. Keep its first-pass state so
            # later iterations start the backward pass from the real frame
            # instead of the all-zero placeholder.
            result_pool[last] = image.copy()
            label_pool[last] = label.copy()
        results[last][..., 1] = image
        time_stamp[last][label == 0, 1] = last
        prog_bar = ProgressBar(frames_num - 2)
        for th in range(frames_num - 2, 0, -1):
            prog_bar.update()
            image, label = load(th, th, first_iter)
            flows = _read_flows(flow_root, th, (3, 2, 1, 0), '.rflo',
                                image.shape)
            pixels, stamps = _warp_two_hop(flows, results[th + 1][..., 1],
                                           time_stamp[th + 1], 1, th_warp)
            known = label == 0
            results[th][..., 1] = pixels
            results[th][known, :, 1] = image[known, :]
            time_stamp[th][..., 1] = stamps
            time_stamp[th][known, 1] = th
        sys.stdout.write('\n')

        # Merge both directions and collect the remaining holes.
        print('Iter', iter_num, 'Merge Results')
        tmp_label_seq = np.zeros(frames_num - 1)
        prog_bar = ProgressBar(frames_num - 1)
        for th in range(frames_num - 1):
            prog_bar.update()
            merged, holes = _merge_directions(results[th], time_stamp[th], th)
            result_pool[th] = merged
            hole_mask = np.zeros_like(merged)
            hole_mask[holes, :] = 255
            label_pool[th] = hole_mask
            tmp_label_seq[th] = np.count_nonzero(holes)
        sys.stdout.write('\n')

        frame_inpaint_seq[tmp_label_seq == 0] = 0
        masked_frame_num = int(np.count_nonzero(frame_inpaint_seq > 0))
        print(masked_frame_num)
        iter_num += 1

        if masked_frame_num > 0:
            key_frame_ids = get_key_ids(frame_inpaint_seq)
            print(key_frame_ids)
            for key_id in key_frame_ids:
                if frame_inapint_model is None:
                    raise ValueError(
                        'frame_inapint_model is required: %d frames still '
                        'have holes after propagation' % masked_frame_num)
                with torch.no_grad():
                    inpainted = frame_inapint_model.forward(
                        result_pool[key_id], label_pool[key_id])
                label_pool[key_id] = np.zeros_like(label_pool[key_id])
                result_pool[key_id] = inpainted
        else:
            print(frames_num, 'frames have been inpainted by', iter_num,
                  'iterations.')

        # Frames whose mask is now empty no longer need inpainting.
        still_masked = np.array(
            [np.any(label_pool[th]) for th in range(frames_num - 1)],
            dtype=bool)
        frame_inpaint_seq[~still_masked] = 0
        masked_frame_num = int(np.count_nonzero(frame_inpaint_seq > 0))
        print(masked_frame_num)

        out_dir = os.path.join(output_root, 'inpaint_res')
        print('Writing frames to:', out_dir)
        os.makedirs(out_dir, exist_ok=True)
        for th in range(1, frames_num - 1):
            cv2.imwrite(os.path.join(out_dir, '%05d.png' % (th)),
                        result_pool[th].astype(np.uint8))

    print('Propagation has been finished')
    pro_time = time.time() - st_time
    print(pro_time)