コード例 #1
0
            if len(seg_img.shape) == 3:
                seg_img = np.squeeze(seg_img[:, :, 0])

            seg_height, seg_width = seg_img.shape

            if seg_width == 2 * src_width or seg_width == 3 * src_width:
                _start_id = seg_width - src_width
                seg_img = seg_img[:, _start_id:]

            # print('seg_img.shape: ', seg_img.shape)
            # print('labels_img_orig.shape: ', labels_img_orig.shape)

            pix_acc[img_id] = eval.pixel_accuracy(seg_img, labels_img_orig)
            _acc, mean_acc[img_id] = eval.mean_accuracy(seg_img,
                                                        labels_img_orig,
                                                        return_acc=1)
            _IU, mean_IU[img_id] = eval.mean_IU(seg_img,
                                                labels_img_orig,
                                                return_iu=1)
            fw_IU[img_id], _fw = eval.frequency_weighted_IU(seg_img,
                                                            labels_img_orig,
                                                            return_freq=1)

            fw_sum += _fw
            _fw_total = np.sum(_fw)

            # print('_fw: {}'.format(_fw))
            # print('_fw_total: {}'.format(_fw_total))

            _fw_frac = np.array(_fw) / float(_fw_total)
コード例 #2
0
def _blend_with_source(src_img, mask_img):
    """Blend a colour mask 50/50 over the source image.

    Pixels whose mask is black (zero after BGR-to-gray conversion) are copied
    unchanged from ``src_img`` instead of being halved in intensity.

    :param src_img: source image, same shape as ``mask_img``
    :param mask_img: BGR colour mask
    :return: uint8 blended image
    """
    mask_gs = cv2.cvtColor(mask_img, cv2.COLOR_BGR2GRAY)
    background = mask_gs == 0
    blended = (0.5 * src_img + 0.5 * mask_img).astype(np.uint8)
    blended[background] = src_img[background]
    return blended


def run(params):
    """Visualize segmentation labels and optionally evaluate segmentation outputs.

    Source images are read either from ``params.images_path`` or, when
    ``params.multi_sequence_db`` is set, from every sequence of
    ``params.vis_split`` in the selected dataset. Ground-truth label images
    are read as well unless ``params.no_labels`` is set.

    When ``params.seg_path`` / ``params.seg_ext`` are given, each segmentation
    output is compared against the ground truth. The comparison yields
    per-frame pixel accuracy, mean accuracy, mean IU and frequency-weighted IU,
    plus running per-class averages. Visualizations can be stitched side by
    side, displayed and/or saved. Progress and summary metrics are written to
    a time-stamped log file in ``params.save_path``.

    Frames ``params.start_id`` .. ``params.end_id`` (inclusive) are processed.
    If ``end_id < start_id``, processing continues to the last frame.

    :param VisParams params: visualization / evaluation settings.
        ``params.end_id`` and ``params.save_path`` may be filled in here.
    :return: None
    :raises AssertionError: missing ``vis_split``, unsupported
        multi-sequence dataset, image/label filename mismatch, or GT/seg frame
        count mismatch in multi-sequence mode
    :raises SystemError: unreadable image, or GT/seg frame count mismatch in
        single-directory mode
    """
    eval_mode = False
    seg_labels_list = []

    if params.multi_sequence_db:
        if not params.vis_split:
            raise AssertionError("vis_split must be provided for CTC")
        # some repeated code here to allow better IntelliSense
        dataset = params.dataset.lower()
        if dataset == 'ctc':
            from new_deeplab.datasets.build_ctc_data import CTCInfo
            db_splits = CTCInfo.DBSplits().__dict__
            sequences = CTCInfo.sequences
        elif dataset in ('ipsc', 'ipsc_2_class', 'ipsc_5_class'):
            from new_deeplab.datasets.ipsc_info import IPSCInfo
            db_splits = IPSCInfo.DBSplits().__dict__
            sequences = IPSCInfo.sequences
        elif dataset == 'ipsc_patches':
            from new_deeplab.datasets.ipsc_info import IPSCPatchesInfo
            db_splits = IPSCPatchesInfo.DBSplits().__dict__
            sequences = IPSCPatchesInfo.sequences
        else:
            raise AssertionError(
                'multi_sequence_db {} is not supported yet'.format(
                    params.dataset))

        seq_ids = db_splits[params.vis_split]

        src_files = []
        src_labels_list = None if params.no_labels else []
        total_frames = 0

        for seq_id in seq_ids:
            seq_name, _ = sequences[seq_id]

            images_path = os.path.join(params.images_path, seq_name)
            if params.no_labels:
                labels_path = ''
            else:
                labels_path = os.path.join(params.labels_path, seq_name)

            _src_files, _src_labels_list, _total_frames = read_data(
                images_path, params.images_ext, labels_path, params.labels_ext)

            _src_filenames = [
                os.path.splitext(os.path.basename(k))[0] for k in _src_files
            ]

            if not params.no_labels:
                _src_labels_filenames = [
                    os.path.splitext(os.path.basename(k))[0]
                    for k in _src_labels_list
                ]
                if _src_labels_filenames != _src_filenames:
                    raise AssertionError(
                        "mismatch between image and label filenames")

            if params.seg_path and params.seg_ext:
                seg_path = os.path.join(params.seg_path, seq_name)

                _, _seg_labels_list, _seg_total_frames = read_data(
                    labels_path=seg_path,
                    labels_ext=params.seg_ext,
                    labels_type='seg')

                if _seg_total_frames != _total_frames:
                    if params.seg_on_subset and _seg_total_frames < _total_frames:
                        # segmentation exists only for some frames: restrict the
                        # source (and GT) lists to the frames that have outputs
                        _seg_filenames = [
                            os.path.splitext(os.path.basename(k))[0]
                            for k in _seg_labels_list
                        ]
                        matching_ids = [
                            _src_filenames.index(k) for k in _seg_filenames
                        ]

                        _src_files = [_src_files[i] for i in matching_ids]
                        if not params.no_labels:
                            _src_labels_list = [
                                _src_labels_list[i] for i in matching_ids
                            ]

                        _total_frames = _seg_total_frames
                    else:
                        raise AssertionError(
                            'Mismatch between no. of frames in GT and seg labels: {} and {}'
                            .format(_total_frames, _seg_total_frames))

                # must be collected for every sequence, not only on a frame-count
                # mismatch, so that seg_labels_list stays aligned with src_files
                seg_labels_list += _seg_labels_list
                eval_mode = True

            src_files += _src_files
            if not params.no_labels:
                src_labels_list += _src_labels_list

            total_frames += _total_frames
    else:
        src_files, src_labels_list, total_frames = read_data(
            params.images_path, params.images_ext, params.labels_path,
            params.labels_ext)

        if params.labels_path and params.seg_path and params.seg_ext:
            _, seg_labels_list, seg_total_frames = read_data(
                labels_path=params.seg_path,
                labels_ext=params.seg_ext,
                labels_type='seg')
            if seg_total_frames != total_frames:
                raise SystemError(
                    'Mismatch between no. of frames in GT and seg labels: {} and {}'
                    .format(total_frames, seg_total_frames))
            eval_mode = True

    if params.end_id < params.start_id:
        params.end_id = total_frames - 1

    classes, composite_classes = read_class_info(params.class_info_path)
    n_classes = len(classes)
    class_ids = list(range(n_classes))
    class_id_to_color = {i: k[1] for i, k in enumerate(classes)}

    all_classes = [k[0] for k in classes + composite_classes]

    if not params.save_path:
        vis_root = params.seg_path if eval_mode else params.images_path
        params.save_path = os.path.join(os.path.dirname(vis_root), 'vis')

    if not os.path.isdir(params.save_path):
        os.makedirs(params.save_path)

    if params.stitch and params.save_stitched:
        print('Saving visualization images to: {}'.format(params.save_path))

    log_fname = os.path.join(params.save_path,
                             'vis_log_{:s}.txt'.format(getDateTime()))
    print('Saving log to: {}'.format(log_fname))

    # row label for the summary table: names of the two parent directories
    save_path_parent = os.path.dirname(params.save_path)
    templ_1 = os.path.basename(save_path_parent)
    templ_2 = os.path.basename(os.path.dirname(save_path_parent))
    templ = '{}_{}'.format(templ_1, templ_2)

    print('templ: {}'.format(templ))

    n_frames = params.end_id - params.start_id + 1

    # per-frame metrics, indexed relative to params.start_id
    pix_acc = np.zeros((n_frames, ))
    mean_acc = np.zeros((n_frames, ))
    mean_IU = np.zeros((n_frames, ))
    fw_IU = np.zeros((n_frames, ))

    print_diff = max(1, int(n_frames * 0.01))

    # running per-class averages (base and composite classes), plus the count of
    # frames in which a base class had no metric value and was left out
    avg_mean_acc = {c: 0 for c in all_classes}
    avg_mean_IU = {c: 0 for c in all_classes}
    skip_mean_acc = {c: 0 for c in all_classes}
    skip_mean_IU = {c: 0 for c in all_classes}

    _pause = 1
    labels_img = None

    for img_id in range(params.start_id, params.end_id + 1):
        # position of this frame within the processed range
        frame_idx = img_id - params.start_id

        stitched = []

        src_img_fname = src_files[img_id]
        img_dir = os.path.dirname(src_img_fname)
        seq_name = os.path.basename(img_dir)
        img_fname = os.path.basename(src_img_fname)
        img_fname_no_ext = os.path.splitext(img_fname)[0]

        src_img = None
        border_img = None

        if params.stitch or params.show_img:
            src_img = cv2.imread(src_img_fname)
            if src_img is None:
                raise SystemError(
                    'Source image could not be read from: {}'.format(
                        src_img_fname))

            try:
                _src_height, _src_width, _ = src_img.shape
            except ValueError as e:
                print('src_img_fname: {}'.format(src_img_fname))
                print('src_img: {}'.format(src_img))
                print('src_img.shape: {}'.format(src_img.shape))
                print('error: {}'.format(e))
                sys.exit(1)

            if not params.blended:
                stitched.append(src_img)

            # 5-pixel-wide white separator between stitched panels
            border_img = np.full_like(src_img, 255)
            border_img = border_img[:, :5, ...]

        if not params.no_labels:
            labels_img_fname = src_labels_list[img_id]

            labels_img_orig = cv2.imread(labels_img_fname)
            if labels_img_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))

            # width of the GT image; used to crop side-by-side seg outputs below
            _, src_width = labels_img_orig.shape[:2]

            if params.show_img:
                cv2.imshow('labels_img_orig', labels_img_orig)

            labels_img_orig, label_img_raw, _ = remove_fuzziness_in_mask(
                labels_img_orig,
                n_classes,
                class_id_to_color,
                fuzziness=5,
                check_equality=0)
            labels_img = np.copy(labels_img_orig)

            if params.stitch:
                if params.blended:
                    labels_img_vis = _blend_with_source(src_img, labels_img)
                else:
                    labels_img_vis = labels_img

                if params.add_border:
                    stitched.append(border_img)
                stitched.append(labels_img_vis)

            if eval_mode:
                seg_img_fname = seg_labels_list[img_id]

                seg_img = cv2.imread(seg_img_fname)
                if seg_img is None:
                    raise SystemError(
                        'Segmentation image could not be read from: {}'.format(
                            seg_img_fname))

                if len(seg_img.shape) == 3:
                    seg_img = np.squeeze(seg_img[:, :, 0])

                _, seg_width = seg_img.shape

                # outputs saved as 2 or 3 panels side by side: keep the last one
                if seg_width == 2 * src_width or seg_width == 3 * src_width:
                    seg_img = seg_img[:, seg_width - src_width:]

                pix_acc[frame_idx] = eval.pixel_accuracy(seg_img, label_img_raw,
                                                         class_ids)
                _acc, mean_acc[frame_idx] = eval.mean_accuracy(seg_img,
                                                               label_img_raw,
                                                               class_ids,
                                                               return_acc=1)
                _IU, mean_IU[frame_idx] = eval.mean_IU(seg_img,
                                                       label_img_raw,
                                                       class_ids,
                                                       return_iu=1)
                fw_IU[frame_idx], _ = eval.frequency_weighted_IU(seg_img,
                                                                 label_img_raw,
                                                                 class_ids,
                                                                 return_freq=1)

                # composite classes: mean of their base-class values, indexed by
                # position in the per-class dicts' value order
                _acc_list = np.asarray(list(_acc.values()))
                _IU_list = np.asarray(list(_IU.values()))
                for _class_name, _, base_ids in composite_classes:
                    _mean_acc = np.mean(_acc_list[base_ids])
                    avg_mean_acc[_class_name] += (
                        _mean_acc - avg_mean_acc[_class_name]) / (frame_idx + 1)

                    _mean_IU = np.mean(_IU_list[base_ids])
                    avg_mean_IU[_class_name] += (
                        _mean_IU - avg_mean_IU[_class_name]) / (frame_idx + 1)

                # base classes: frames where a class has no value are skipped and
                # excluded from that class's running-average denominator
                for _class_id, _class_data in enumerate(classes):
                    _class_name = _class_data[0]
                    try:
                        _mean_acc = _acc[_class_id]
                    except KeyError:
                        print('\nskip_mean_acc {}: {}'.format(
                            _class_name, img_id))
                        skip_mean_acc[_class_name] += 1
                    else:
                        avg_mean_acc[_class_name] += (
                            _mean_acc - avg_mean_acc[_class_name]) / (
                                frame_idx - skip_mean_acc[_class_name] + 1)

                    try:
                        _mean_IU = _IU[_class_id]
                    except KeyError:
                        print('\nskip_mean_IU {}: {}'.format(
                            _class_name, img_id))
                        skip_mean_IU[_class_name] += 1
                    else:
                        avg_mean_IU[_class_name] += (
                            _mean_IU - avg_mean_IU[_class_name]) / (
                                frame_idx - skip_mean_IU[_class_name] + 1)

                seg_img_vis = raw_seg_to_rgb(seg_img, class_id_to_color)

                if params.stitch and params.stitch_seg:
                    if params.blended:
                        seg_img_vis = _blend_with_source(src_img, seg_img_vis)

                    if params.add_border:
                        stitched.append(border_img)
                    stitched.append(seg_img_vis)

                if not params.stitch and params.show_img:
                    cv2.imshow('seg_img', seg_img_vis)

        if params.stitch:
            stitched = np.concatenate(stitched, axis=1)
            if params.save_stitched:
                seg_save_path = os.path.join(
                    params.save_path,
                    '{}_{}.{}'.format(seq_name, img_fname_no_ext,
                                      params.out_ext))
                cv2.imwrite(seg_save_path, stitched)

            if params.show_img:
                cv2.imshow('stitched', stitched)
        elif params.show_img:
            cv2.imshow('src_img', src_img)
            # labels_img stays None when no_labels is set
            if params.labels_path and labels_img is not None:
                cv2.imshow('labels_img', labels_img)

        if params.show_img:
            k = cv2.waitKey(1 - _pause)
            if k == 27:
                sys.exit(0)
            elif k == 32:
                _pause = 1 - _pause

        img_done = frame_idx + 1
        if img_done % print_diff == 0:
            log_txt = 'Done {:5d}/{:5d} frames'.format(img_done, n_frames)
            if eval_mode:
                log_txt = '{:s} pix_acc: {:.5f} mean_acc: {:.5f} mean_IU: {:.5f} fw_IU: {:.5f}'.format(
                    log_txt, pix_acc[frame_idx], mean_acc[frame_idx],
                    mean_IU[frame_idx], fw_IU[frame_idx])
                for _class in all_classes:
                    log_txt += ' acc {}: {:.5f} '.format(
                        _class, avg_mean_acc[_class])

                for _class in all_classes:
                    log_txt += ' IU {}: {:.5f} '.format(
                        _class, avg_mean_IU[_class])

            print_and_write(log_txt, log_fname)

    sys.stdout.write('\n')
    sys.stdout.flush()

    if eval_mode:
        log_txt = "pix_acc\t mean_acc\t mean_IU\t fw_IU\n{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
            np.mean(pix_acc), np.mean(mean_acc), np.mean(mean_IU),
            np.mean(fw_IU))

        for _class in all_classes:
            log_txt += ' mean_acc {}\t'.format(_class)
        log_txt += '\n'

        for _class in all_classes:
            log_txt += ' {:.5f}\t'.format(avg_mean_acc[_class])
        log_txt += '\n'

        for _class in all_classes:
            log_txt += ' mean_IU {}\t'.format(_class)
        log_txt += '\n'

        for _class in all_classes:
            log_txt += ' {:.5f}\t'.format(avg_mean_IU[_class])
        log_txt += '\n'

        print_and_write(log_txt, log_fname)

        # tab-separated table: one column per base class plus overall columns
        log_txt = templ + '\n\t'
        for _class in classes:
            log_txt += '{}\t '.format(_class[0])
        log_txt += 'all_classes\t all_classes(fw)\n'

        log_txt += 'recall\t'
        for _class in classes:
            log_txt += '{:.5f}\t'.format(avg_mean_acc[_class[0]])
        log_txt += '{:.5f}\t{:.5f}\n'.format(np.mean(mean_acc),
                                             np.mean(pix_acc))

        log_txt += 'precision\t'
        for _class in classes:
            log_txt += '{:.5f}\t'.format(avg_mean_IU[_class[0]])
        log_txt += '{:.5f}\t{:.5f}\n'.format(np.mean(mean_IU), np.mean(fw_IU))

        print_and_write(log_txt, log_fname)

    print('Wrote log to: {}'.format(log_fname))
コード例 #3
0
def run(params):
    """

    :param StitchParams params:
    :return:
    """
    video_exts = ['mp4', 'mkv', 'avi', 'mpg', 'mpeg', 'mjpg']

    # if not params.src_path:
    #     assert params.db_root_dir, "either params.src_path or params.db_root_dir must be provided"
    #
    #     params.src_path = os.path.join(params.db_root_dir, params.seq_name, 'images')

    print('Reading source images from: {}'.format(params.src_path))

    src_files = [
        k for k in os.listdir(params.src_path)
        if k.endswith('.{:s}'.format(params.images_ext))
    ]
    total_frames = len(src_files)
    # print('file_list: {}'.format(file_list))
    if total_frames <= 0:
        print('params: {}'.format(params))
        raise SystemError('No input frames of type {} found'.format(
            params.images_ext))

    print('total_frames: {}'.format(total_frames))

    classes, composite_classes = read_class_info(params.class_info_path)
    n_classes = len(classes)
    class_ids = list(range(n_classes))

    src_files.sort(key=sort_key)

    if params.n_frames <= 0:
        params.n_frames = total_frames

    if params.end_id < params.start_id:
        params.end_id = params.n_frames - 1

    if params.patch_width <= 0:
        params.patch_width = params.patch_height

    if not params.patch_seq_path:
        if not params.patch_seq_name:
            params.patch_seq_name = '{:s}_{:d}_{:d}_{:d}_{:d}_{:d}_{:d}'.format(
                params.seq_name, params.start_id, params.end_id,
                params.patch_height, params.patch_width, params.patch_height,
                params.patch_width)
        params.patch_seq_path = os.path.join(params.db_root_dir,
                                             params.patch_seq_name,
                                             params.patch_seq_type)
        if not os.path.isdir(params.patch_seq_path):
            raise SystemError(
                'params.patch_seq_path does not exist: {}'.format(
                    params.patch_seq_path))
    else:
        params.patch_seq_name = os.path.basename(params.patch_seq_path)

    if not params.stitched_seq_path:
        stitched_seq_name = '{}_stitched_{}'.format(params.patch_seq_name,
                                                    params.method)
        if params.stacked:
            stitched_seq_name = '{}_stacked'.format(stitched_seq_name)
            params.method = 1
        stitched_seq_name = '{}_{}_{}'.format(stitched_seq_name,
                                              params.start_id, params.end_id)
        params.stitched_seq_path = os.path.join(params.db_root_dir,
                                                stitched_seq_name,
                                                params.patch_seq_type)

    gt_labels_orig = gt_labels = video_out = None
    write_to_video = params.out_ext in video_exts
    if write_to_video:
        if not params.stitched_seq_path.endswith('.{}'.format(params.out_ext)):
            params.stitched_seq_path = '{}.{}'.format(params.stitched_seq_path,
                                                      params.out_ext)
        print('Writing {}x{} output video to: {}'.format(
            params.width, params.height, params.stitched_seq_path))
        save_dir = os.path.dirname(params.stitched_seq_path)

        fourcc = cv2.VideoWriter_fourcc(*params.codec)
        video_out = cv2.VideoWriter(params.stitched_seq_path, fourcc,
                                    params.fps, (params.width, params.height))
    else:
        print('Writing output images to: {}'.format(params.stitched_seq_path))
        save_dir = params.stitched_seq_path

    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    log_fname = os.path.join(save_dir, 'log_{:s}.txt'.format(getDateTime()))
    print_and_write('Saving log to: {}'.format(log_fname), log_fname)
    print_and_write(
        'Reading patch images from: {}'.format(params.patch_seq_path),
        log_fname)

    n_patches = 0
    pause_after_frame = 1
    label_diff = int(255.0 / (params.n_classes - 1))

    eval_mode = False
    if params.labels_path and params.labels_ext:
        _, labels_list, labels_total_frames = read_data(
            labels_path=params.labels_path, labels_ext=params.labels_ext)
        if labels_total_frames != total_frames:
            raise SystemError(
                'Mismatch between no. of frames in GT and seg labels')
        eval_mode = True
        import densenet.evaluation.eval_segm as eval_segm

    avg_pix_acc = avg_mean_acc = avg_mean_IU = avg_fw_IU = 0
    avg_mean_acc_ice = avg_mean_acc_ice_1 = avg_mean_acc_ice_2 = 0
    avg_mean_IU_ice = avg_mean_IU_ice_1 = avg_mean_IU_ice_2 = 0

    skip_mean_acc_ice_1 = skip_mean_acc_ice_2 = 0
    skip_mean_IU_ice_1 = skip_mean_IU_ice_2 = 0

    _n_frames = params.end_id - params.start_id + 1

    for img_id in range(params.start_id, params.end_id + 1):

        # img_fname = '{:s}_{:d}.{:s}'.format(params.fname_templ, img_id + 1, params.img_ext)
        img_fname = src_files[img_id]
        img_fname_no_ext = os.path.splitext(img_fname)[0]

        src_img_fname = os.path.join(params.src_path, img_fname)
        src_img = cv2.imread(src_img_fname)

        if src_img is None:
            raise SystemError('Source image could not be read from: {}'.format(
                src_img_fname))

        n_rows, n_cols, n_channels = src_img.shape

        # np.savetxt(os.path.join(params.db_root_dir, params.seq_name, 'labels_img_{}.txt'.format(img_id + 1)),
        #            labels_img[:, :, 2], fmt='%d')

        out_id = 0
        # skip_id = 0
        min_row = 0

        enable_stitching = 1
        if params.method == -1:
            patch_src_img_fname = os.path.join(
                params.patch_seq_path,
                '{:s}.{:s}'.format(img_fname_no_ext, params.patch_ext))
            if not os.path.exists(patch_src_img_fname):
                raise SystemError('Patch image does not exist: {}'.format(
                    patch_src_img_fname))
            stitched_img = cv2.imread(patch_src_img_fname)
            enable_stitching = 0
        elif params.method == 0:
            stitched_img = None
        else:
            stitched_img = np.zeros((n_rows, n_cols, n_channels),
                                    dtype=np.uint8)

        while enable_stitching:
            max_row = min_row + params.patch_height
            if max_row > n_rows:
                diff = max_row - n_rows
                min_row -= diff
                max_row -= diff

            curr_row = None
            min_col = 0
            while True:
                max_col = min_col + params.patch_width
                if max_col > n_cols:
                    diff = max_col - n_cols
                    min_col -= diff
                    max_col -= diff

                patch_img_fname = '{:s}_{:d}'.format(img_fname_no_ext,
                                                     out_id + 1)
                patch_src_img_fname = os.path.join(
                    params.patch_seq_path,
                    '{:s}.{:s}'.format(patch_img_fname, params.patch_ext))

                if not os.path.exists(patch_src_img_fname):
                    raise SystemError('Patch image does not exist: {}'.format(
                        patch_src_img_fname))

                src_patch = cv2.imread(patch_src_img_fname)
                seg_height, seg_width, _ = src_patch.shape

                if seg_width == 2 * params.patch_width or seg_width == 3 * params.patch_width:
                    _start_id = seg_width - params.patch_width
                    src_patch = src_patch[:, _start_id:]

                if params.normalize_patches:
                    src_patch = (src_patch * label_diff).astype(np.uint8)

                # print('max(src_patch): {}'.format(np.max(src_patch)))
                # print('min(src_patch): {}'.format(np.min(src_patch)))

                out_id += 1

                if params.method == 0:
                    if curr_row is None:
                        curr_row = src_patch
                    else:
                        curr_row = np.concatenate((curr_row, src_patch),
                                                  axis=1)
                else:
                    stitched_img[min_row:max_row,
                                 min_col:max_col, :] = src_patch

                if params.show_img:
                    disp_img = src_img.copy()
                    cv2.rectangle(disp_img, (min_col, min_row),
                                  (max_col, max_row), (255, 0, 0), 2)

                    stitched_img_disp = stitched_img
                    if params.disp_resize_factor != 1:
                        disp_img = cv2.resize(disp_img, (0, 0),
                                              fx=params.disp_resize_factor,
                                              fy=params.disp_resize_factor)
                        if stitched_img_disp is not None:
                            stitched_img_disp = cv2.resize(
                                stitched_img_disp, (0, 0),
                                fx=params.disp_resize_factor,
                                fy=params.disp_resize_factor)

                    cv2.imshow('disp_img', disp_img)
                    cv2.imshow('src_patch', src_patch)

                    if stitched_img_disp is not None:
                        cv2.imshow('stacked_img', stitched_img_disp)
                    if curr_row is not None:
                        cv2.imshow('curr_row', curr_row)

                    k = cv2.waitKey(1 - pause_after_frame)
                    if k == 27:
                        sys.exit(0)
                    elif k == 32:
                        pause_after_frame = 1 - pause_after_frame

                # sys.stdout.write('\rDone {:d} patches in frame {:d}'.format(out_id, img_id + 1))
                # sys.stdout.flush()

                if max_col >= n_cols:
                    break

                min_col = max_col

            if params.method == 0:
                if stitched_img is None:
                    stitched_img = curr_row
                else:
                    stitched_img = np.concatenate((stitched_img, curr_row),
                                                  axis=0)

            if max_row >= n_rows:
                break

            min_row = max_row

        if eval_mode:
            labels_img_fname = os.path.join(
                params.labels_path,
                img_fname_no_ext + '.{}'.format(params.labels_ext))
            gt_labels_orig = imageio.imread(labels_img_fname)

            if gt_labels_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))

            if len(gt_labels_orig.shape) == 3:
                gt_labels = np.squeeze(gt_labels_orig[:, :, 0])
            else:
                gt_labels = gt_labels_orig

            seg_labels = np.squeeze(stitched_img[:, :, 0])

            pix_acc = eval_segm.pixel_accuracy(seg_labels, gt_labels,
                                               class_ids)
            _acc, mean_acc = eval_segm.mean_accuracy(seg_labels,
                                                     gt_labels,
                                                     class_ids,
                                                     return_acc=1)
            _IU, mean_IU = eval_segm.mean_IU(seg_labels,
                                             gt_labels,
                                             class_ids,
                                             return_iu=1)
            fw_IU = eval_segm.frequency_weighted_IU(seg_labels, gt_labels,
                                                    class_ids)

            avg_pix_acc += (pix_acc - avg_pix_acc) / (img_id + 1)
            avg_mean_acc += (mean_acc - avg_mean_acc) / (img_id + 1)
            avg_mean_IU += (mean_IU - avg_mean_IU) / (img_id + 1)
            avg_fw_IU += (fw_IU - avg_fw_IU) / (img_id + 1)

            # print('_acc: {}'.format(_acc))
            # print('_IU: {}'.format(_IU))

            mean_acc_ice = np.mean(list(_acc.values())[1:])
            avg_mean_acc_ice += (mean_acc_ice - avg_mean_acc_ice) / (img_id +
                                                                     1)
            try:
                mean_acc_ice_1 = _acc[1]
                avg_mean_acc_ice_1 += (mean_acc_ice_1 - avg_mean_acc_ice_1) / (
                    img_id - skip_mean_acc_ice_1 + 1)
            except KeyError:
                print('\nskip_mean_acc_ice_1: {}'.format(img_id))
                skip_mean_acc_ice_1 += 1
            try:
                mean_acc_ice_2 = _acc[2]
                avg_mean_acc_ice_2 += (mean_acc_ice_2 - avg_mean_acc_ice_2) / (
                    img_id - skip_mean_acc_ice_2 + 1)
            except KeyError:
                print('\nskip_mean_acc_ice_2: {}'.format(img_id))
                skip_mean_acc_ice_2 += 1

            mean_IU_ice = np.mean(list(_IU.values())[1:])
            avg_mean_IU_ice += (mean_IU_ice - avg_mean_IU_ice) / (img_id + 1)
            try:
                mean_IU_ice_1 = _IU[1]
                avg_mean_IU_ice_1 += (mean_IU_ice_1 - avg_mean_IU_ice_1) / (
                    img_id - skip_mean_IU_ice_1 + 1)
            except KeyError:
                print('\nskip_mean_IU_ice_1: {}'.format(img_id))
                skip_mean_IU_ice_1 += 1
            try:
                mean_IU_ice_2 = _IU[2]
                avg_mean_IU_ice_2 += (mean_IU_ice_2 - avg_mean_IU_ice_2) / (
                    img_id - skip_mean_IU_ice_2 + 1)
            except KeyError:
                print('\nskip_mean_IU_ice_2: {}'.format(img_id))
                skip_mean_IU_ice_2 += 1

            log_txt = '\nDone {:d}/{:d} frames '.format(
                img_id - params.start_id + 1, _n_frames)
            log_txt += "pix_acc: {:.5f} mean_acc: {:.5f} mean_IU: {:.5f} fw_IU: {:.5f}" \
                       " avg_acc_ice: {:.5f} avg_acc_ice_1: {:.5f} avg_acc_ice_2: {:.5f}" \
                       " avg_IU_ice: {:.5f} avg_IU_ice_1: {:.5f} avg_IU_ice_2: {:.5f}".format(
                avg_pix_acc, avg_mean_acc, avg_mean_IU, avg_fw_IU,
                avg_mean_acc_ice, avg_mean_acc_ice_1, avg_mean_acc_ice_2,
                avg_mean_IU_ice, avg_mean_IU_ice_1, avg_mean_IU_ice_2,
            )
            print_and_write(log_txt, log_fname)
        else:
            sys.stdout.write('\rDone {:d}/{:d} frames'.format(
                img_id + 1 - params.start_id, _n_frames))
            sys.stdout.flush()

        if not params.normalize_patches:
            # print('max(stitched_img): {:d}'.format(int(np.max(stitched_img))))
            # print('max(stitched_img): {:d}'.format(int(np.max(stitched_img))))
            # print('min(stitched_img): {:d}'.format(int(np.min(stitched_img))))

            # print('label_diff: {}'.format(label_diff))

            seg_img = (stitched_img * label_diff).astype(np.uint8)

            # print('max(seg_img): {:d}'.format(int(np.max(seg_img))))
            # print('min(seg_img): {:d}'.format(int(np.min(seg_img))))

            # np.savetxt('/home/abhineet/seg_img_{}.data'.format(img_id), np.squeeze(seg_img[:, :, 0]).astype(
            # np.float64), fmt='%f')
            # np.savetxt('/home/abhineet/stitched_img_{}.data'.format(img_id), np.squeeze(stitched_img[:, :,
            # 0]).astype(np.float64), fmt='%f')

            # seg_img = seg_img.astype(np.uint8)

        else:
            seg_img = stitched_img

        if params.stacked:
            # print('stitched_img.shape', stitched_img.shape)
            # print('src_img.shape', src_img.shape)
            if eval_mode:
                labels_img = (gt_labels_orig * label_diff).astype(np.uint8)
                stitched = np.concatenate((src_img, labels_img), axis=1)
            else:
                stitched = src_img
            out_img = np.concatenate((stitched, seg_img), axis=1)
        else:
            out_img = seg_img

        if write_to_video:
            out_img = resize_ar(out_img, params.width, params.height)
            video_out.write(out_img)
            # statinfo = os.stat(params.stitched_seq_path)
            # print('\nvideo_size: {}'.format(statinfo.st_size))
        else:
            if params.resize_factor != 1:
                out_img = cv2.resize(out_img, (0, 0),
                                     fx=params.resize_factor,
                                     fy=params.resize_factor)
            stacked_img_path = os.path.join(
                params.stitched_seq_path,
                '{}.{}'.format(img_fname_no_ext, params.out_ext))
            cv2.imwrite(stacked_img_path, out_img)

        n_patches += out_id

    sys.stdout.write('\n')
    sys.stdout.flush()
    sys.stdout.write('Total patches processed: {}\n'.format(n_patches))

    if params.show_img:
        cv2.destroyAllWindows()
    if write_to_video:
        video_out.release()

    if params.del_patch_seq:
        print('Removing patch folder {}'.format(params.patch_seq_path))
        shutil.rmtree(params.patch_seq_path)

    if eval_mode:
        log_txt = "pix_acc\t mean_acc\t mean_IU\t fw_IU\n{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_pix_acc, avg_mean_acc, avg_mean_IU, avg_fw_IU)
        log_txt += "mean_acc_ice\t mean_acc_ice_1\t mean_acc_ice_2\n{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_mean_acc_ice, avg_mean_acc_ice_1, avg_mean_acc_ice_2)
        log_txt += "mean_IU_ice\t mean_IU_ice_1\t mean_IU_ice_2\n{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_mean_IU_ice, avg_mean_IU_ice_1, avg_mean_IU_ice_2)
        print_and_write(log_txt, log_fname)

        print_and_write('Saved log to: {}'.format(log_fname), log_fname)
        print_and_write(
            'Read patch images from: {}'.format(params.patch_seq_path),
            log_fname)
コード例 #4
0
            if gt_labels_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))

            if len(gt_labels_orig.shape) == 3:
                gt_labels = np.squeeze(gt_labels_orig[:, :, 0])
            else:
                gt_labels = gt_labels_orig

            seg_labels = np.squeeze(stitched_img[:, :, 0])

            pix_acc = eval_segm.pixel_accuracy(seg_labels, gt_labels)
            _acc, mean_acc = eval_segm.mean_accuracy(seg_labels,
                                                     gt_labels,
                                                     return_acc=1)
            _IU, mean_IU = eval_segm.mean_IU(seg_labels,
                                             gt_labels,
                                             return_iu=1)
            fw_IU = eval_segm.frequency_weighted_IU(seg_labels, gt_labels)

            avg_pix_acc += (pix_acc - avg_pix_acc) / (img_id + 1)
            avg_mean_acc += (mean_acc - avg_mean_acc) / (img_id + 1)
            avg_mean_IU += (mean_IU - avg_mean_IU) / (img_id + 1)
            avg_fw_IU += (fw_IU - avg_fw_IU) / (img_id + 1)

            # print('_acc: {}'.format(_acc))
            # print('_IU: {}'.format(_IU))

            mean_acc_ice = np.mean(list(_acc.values())[1:])