Example #1
0
    dst_dir = params['dst_dir']
    images_ext = params['images_ext']
    labels_ext = params['labels_ext']
    n_classes = params['n_classes']
    n_indices = params['n_indices']
    start_id = params['start_id']
    end_id = params['end_id']
    copy_images = params['copy_images']

    images_path = os.path.join(db_root_dir, src_dir, 'images')
    labels_path = os.path.join(db_root_dir, src_dir, 'labels')

    images_path = os.path.abspath(images_path)
    labels_path = os.path.abspath(labels_path)

    src_files, src_labels_list, total_frames = read_data(
        images_path, images_ext, labels_path, labels_ext)
    if start_id < 0:
        if end_id < 0:
            raise AssertionError(
                'end_id must be non negative for random selection')
        elif end_id >= total_frames:
            raise AssertionError(
                'end_id must be less than total_frames for random selection')
        print('Using {} random images for selection'.format(end_id + 1))
        img_ids = np.random.choice(total_frames, end_id + 1, replace=False)
    else:
        if end_id < start_id:
            end_id = total_frames - 1
        print('Using all {} images for selection'.format(end_id - start_id +
                                                         1))
        img_ids = range(start_id, end_id + 1)
def _column_ice_concentration(label_img, ice_type, img_height):
    """Return the ice concentration (%) of every column of a class-id image.

    :param label_img: 2-D array of per-pixel class ids, indexed [row, col]
    :param int ice_type: 0 counts every non-zero pixel as ice; any other value
        counts only the pixels equal to that class id
    :param img_height: number of pixels per column (the denominator)
    :return: float64 array holding one percentage per column of ``label_img``
    """
    if ice_type == 0:
        ice_mask = label_img != 0
    else:
        ice_mask = label_img == ice_type
    # counting along axis 0 replaces an explicit per-column Python loop
    ice_counts = np.count_nonzero(ice_mask, axis=0).astype(np.float64)
    return (ice_counts / float(img_height)) * 100.0


def main():
    """Visualise and evaluate per-column ice concentration over a sequence.

    For every frame in ``[start_id, end_id]`` the source image and, when
    given, the ground-truth label image and one or more segmentation images
    are read.  For each of them the percentage of ice pixels in every image
    column is computed and plotted.

    * With ground truth (``--labels_path``), euclidean / MSE / MAE distances
      between every model's curve and the GT curve are accumulated and
      summarised at the end.
    * Without ground truth, either the mean absolute change of each model's
      concentration curve between consecutive frames or (with
      ``--plot_changed_seg_count``) the number of changed labels is tracked;
      the former is saved to ``<out_path>/<label>_ice_concentration_diff.txt``.

    With ``--enable_plotting`` the annotated frames are stitched together and
    written to a video or an image sequence under ``out_path``.

    All settings are read from the command line.

    :raises IOError: if the numbers of seg paths and seg labels differ, or if
        no output directory is given and none can be derived
    :raises SystemError: if an image cannot be read
    :raises AssertionError: if a label / seg image size differs from the size
        of the corresponding source image
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--log_dir", type=str)

    parser.add_argument("--images_path", type=str)
    parser.add_argument("--images_ext", type=str, default='png')
    parser.add_argument("--labels_path", type=str, default='')
    parser.add_argument("--labels_ext", type=str, default='png')
    parser.add_argument("--labels_col", type=str, default='green')
    parser.add_argument("--seg_paths", type=str_to_list, default=[])
    parser.add_argument("--seg_ext", type=str, default='png')
    parser.add_argument("--seg_root_dir", type=str, default='')

    parser.add_argument("--seg_labels", type=str_to_list, default=[])
    parser.add_argument(
        "--seg_cols",
        type=str_to_list,
        default=['blue', 'forest_green', 'magenta', 'cyan', 'red'])

    parser.add_argument("--out_path", type=str, default='')
    parser.add_argument("--out_ext", type=str, default='jpg')
    parser.add_argument("--out_size", type=str, default='1920x1080')
    parser.add_argument("--fps", type=float, default=30)
    parser.add_argument("--codec", type=str, default='H264')

    parser.add_argument("--save_path", type=str, default='')

    parser.add_argument("--n_classes", type=int)

    parser.add_argument("--save_stitched", type=int, default=0)
    parser.add_argument("--load_ice_conc_diff", type=int, default=0)

    parser.add_argument("--start_id", type=int, default=0)
    parser.add_argument("--end_id", type=int, default=-1)

    parser.add_argument("--show_img", type=int, default=0)
    parser.add_argument("--stitch", type=int, default=0)
    parser.add_argument("--stitch_seg", type=int, default=1)

    parser.add_argument("--plot_changed_seg_count", type=int, default=0)
    parser.add_argument("--normalize_labels", type=int, default=0)
    parser.add_argument("--selective_mode", type=int, default=0)
    parser.add_argument("--ice_type",
                        type=int,
                        default=0,
                        help='0: combined, 1: anchor, 2: frazil')
    parser.add_argument("--enable_plotting",
                        type=int,
                        default=1,
                        help='enable_plotting')

    args = parser.parse_args()

    images_path = args.images_path
    images_ext = args.images_ext
    labels_path = args.labels_path
    labels_ext = args.labels_ext
    labels_col = args.labels_col

    seg_paths = args.seg_paths
    seg_root_dir = args.seg_root_dir
    seg_ext = args.seg_ext

    out_path = args.out_path
    out_ext = args.out_ext
    out_size = args.out_size
    fps = args.fps
    codec = args.codec

    n_classes = args.n_classes

    end_id = args.end_id
    start_id = args.start_id

    show_img = args.show_img
    # --stitch, --stitch_seg, --save_stitched and --save_path are accepted for
    # command-line compatibility but are not used by this function.

    normalize_labels = args.normalize_labels
    selective_mode = args.selective_mode

    seg_labels = args.seg_labels
    seg_cols = args.seg_cols

    ice_type = args.ice_type
    plot_changed_seg_count = args.plot_changed_seg_count

    load_ice_conc_diff = args.load_ice_conc_diff
    enable_plotting = args.enable_plotting

    ice_types = {
        0: 'Ice',
        1: 'Anchor Ice',
        2: 'Frazil Ice',
    }

    # text annotation settings passed to put_text_with_background
    loc = (5, 120)
    size = 8
    thickness = 6
    fgr_col = (255, 255, 255)
    bgr_col = (0, 0, 0)
    font_id = 0

    video_exts = ['mp4', 'mkv', 'avi', 'mpg', 'mpeg', 'mjpg']

    labels_col_rgb = col_bgr[labels_col]
    seg_cols_rgb = [col_bgr[seg_col] for seg_col in seg_cols]

    ice_type_str = ice_types[ice_type]

    print('ice_type_str: {}'.format(ice_type_str))

    src_files, src_labels_list, total_frames = read_data(
        images_path, images_ext, labels_path, labels_ext)
    if end_id < start_id:
        end_id = total_frames - 1

    # defined unconditionally: it also selects the stitching layout below
    n_seg_paths = len(seg_paths)
    if seg_paths:
        n_seg_labels = len(seg_labels)

        if n_seg_paths != n_seg_labels:
            raise IOError(
                'Mismatch between n_seg_labels: {} and n_seg_paths: {}'.format(
                    n_seg_labels, n_seg_paths))
        if seg_root_dir:
            seg_paths = [
                os.path.join(seg_root_dir, name) for name in seg_paths
            ]

    if not out_path:
        if labels_path:
            out_path = labels_path + '_conc'
        elif seg_paths:
            out_path = seg_paths[0] + '_conc'
        else:
            raise IOError(
                'out_path must be provided when neither labels_path nor '
                'seg_paths is given')

    # results are read from / written to out_path even when it was passed in
    # explicitly, so make sure it exists in both cases
    if not os.path.isdir(out_path):
        os.makedirs(out_path)

    if selective_mode:
        label_diff = int(255.0 / n_classes)
    else:
        label_diff = int(255.0 / (n_classes - 1))

    print('label_diff: {}'.format(label_diff))

    n_frames = end_id - start_id + 1

    n_cols = len(seg_cols_rgb)

    plot_y_label = '{} concentration (%)'.format(ice_type_str)
    plot_x_label = 'distance in pixels from left edge'

    # per-model distances between the GT and seg concentration curves
    dists = {}
    for _label in seg_labels:
        dists[_label] = {
            'euclidean': [],
            'mae': [],
            'mse': [],
        }

    plot_title = '{} concentration'.format(ice_type_str)

    out_size = tuple([int(x) for x in out_size.split('x')])
    write_to_video = out_ext in video_exts
    out_width, out_height = out_size

    out_seq_name = os.path.basename(out_path)

    if enable_plotting:
        if write_to_video:
            stitched_seq_path = os.path.join(
                out_path, '{}.{}'.format(out_seq_name, out_ext))
            print('Writing {}x{} output video to: {}'.format(
                out_width, out_height, stitched_seq_path))
            save_dir = os.path.dirname(stitched_seq_path)
        else:
            stitched_seq_path = os.path.join(out_path, out_seq_name)
            print('Writing {}x{} output images of type {} to: {}'.format(
                out_width, out_height, out_ext, stitched_seq_path))
            save_dir = stitched_seq_path

        # create the output directory before the writer tries to open a file
        if save_dir and not os.path.isdir(save_dir):
            os.makedirs(save_dir)

        if write_to_video:
            fourcc = cv2.VideoWriter_fourcc(*codec)
            video_out = cv2.VideoWriter(stitched_seq_path, fourcc, fps,
                                        out_size)

    prev_seg_img = {}
    prev_conc_data_y = {}

    changed_seg_count = {}
    ice_concentration_diff = {}

    if load_ice_conc_diff:
        for seg_id in seg_labels:
            ice_concentration_diff[seg_id] = np.loadtxt(os.path.join(
                out_path, '{}_ice_concentration_diff.txt'.format(seg_id)),
                                                        dtype=np.float64)
    _pause = 0

    # one MAE history list per seg model (filled only when GT is available)
    mae_data_y = [[] for _ in seg_paths]

    for img_id in range(start_id, end_id + 1):

        start_t = time.time()

        img_fname = src_files[img_id]
        img_fname_no_ext = os.path.splitext(img_fname)[0]

        src_img_fname = os.path.join(images_path, img_fname)
        src_img = imread(src_img_fname)
        if src_img is None:
            raise SystemError('Source image could not be read from: {}'.format(
                src_img_fname))

        try:
            src_height, src_width = src_img.shape[:2]
        except ValueError as e:
            print('src_img_fname: {}'.format(src_img_fname))
            print('src_img: {}'.format(src_img))
            print('src_img.shape: {}'.format(src_img.shape))
            print('error: {}'.format(e))
            sys.exit(1)

        conc_data_x = np.arange(src_width, dtype=np.float64)
        plot_data_x = conc_data_x

        plot_data_y = []
        plot_cols = []
        plot_labels = []

        stitched_img = src_img

        if labels_path:
            labels_img_fname = os.path.join(
                labels_path, img_fname_no_ext + '.{}'.format(labels_ext))
            labels_img_orig = imread(labels_img_fname)
            if labels_img_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))
            labels_height, labels_width = labels_img_orig.shape[:2]

            if labels_height != src_height or labels_width != src_width:
                raise AssertionError(
                    'Mismatch between dimensions of source: {} and label: {}'.
                    format((src_height, src_width),
                           (labels_height, labels_width)))

            # keep only the first channel of multi-channel label images
            if len(labels_img_orig.shape) == 3:
                labels_img_orig = np.squeeze(labels_img_orig[:, :, 0])

            if show_img:
                cv2.imshow('labels_img_orig', labels_img_orig)

            # with normalize_labels the label values are divided by label_diff
            # before the ice pixels are counted
            if normalize_labels:
                labels_img = (labels_img_orig.astype(np.float64) /
                              label_diff).astype(np.uint8)
            else:
                labels_img = np.copy(labels_img_orig)

            conc_data_y = _column_ice_concentration(labels_img, ice_type,
                                                    src_height)

            plot_data_y.append(conc_data_y)
            plot_cols.append(labels_col_rgb)

            gt_dict = dict(zip(conc_data_x, conc_data_y))

            # scale raw label values by label_diff for display
            if not normalize_labels:
                labels_img_orig = (labels_img_orig.astype(np.float64) *
                                   label_diff).astype(np.uint8)

            if len(labels_img_orig.shape) == 2:
                labels_img_orig = np.stack(
                    (labels_img_orig, labels_img_orig, labels_img_orig),
                    axis=2)

            stitched_img = np.concatenate((stitched_img, labels_img_orig),
                                          axis=1)

            plot_labels.append('GT')

        seg_count_data_y = []
        curr_mae_data_y = []
        conc_diff_data_y = []
        seg_img_disp_list = []

        for seg_id, seg_path in enumerate(seg_paths):
            seg_img_fname = os.path.join(
                seg_path, img_fname_no_ext + '.{}'.format(seg_ext))
            seg_img_orig = imread(seg_img_fname)

            seg_col = seg_cols_rgb[seg_id % n_cols]

            _label = seg_labels[seg_id]

            if seg_img_orig is None:
                raise SystemError(
                    'Seg image could not be read from: {}'.format(
                        seg_img_fname))
            seg_height, seg_width = seg_img_orig.shape[:2]

            if seg_height != src_height or seg_width != src_width:
                raise AssertionError(
                    'Mismatch between dimensions of source: {} and seg: {}'.
                    format((src_height, src_width), (seg_height, seg_width)))

            if len(seg_img_orig.shape) == 3:
                seg_img_orig = np.squeeze(seg_img_orig[:, :, 0])

            # values above n_classes - 1 are presumably already scaled by
            # label_diff, so divide them back for counting; otherwise scale up
            # for display
            if seg_img_orig.max() > n_classes - 1:
                seg_img = (seg_img_orig.astype(np.float64) /
                           label_diff).astype(np.uint8)
                seg_img_disp = seg_img_orig
            else:
                seg_img = seg_img_orig
                seg_img_disp = (seg_img_orig.astype(np.float64) *
                                label_diff).astype(np.uint8)

            if len(seg_img_disp.shape) == 2:
                seg_img_disp = np.stack(
                    (seg_img_disp, seg_img_disp, seg_img_disp), axis=2)

            ann_fmt = (font_id, loc[0], loc[1], size,
                       thickness) + fgr_col + bgr_col
            put_text_with_background(seg_img_disp,
                                     seg_labels[seg_id],
                                     fmt=ann_fmt)

            seg_img_disp_list.append(seg_img_disp)

            if show_img:
                cv2.imshow('seg_img_orig', seg_img_orig)

            conc_data_y = _column_ice_concentration(seg_img, ice_type,
                                                    src_height)

            plot_cols.append(seg_col)
            plot_data_y.append(conc_data_y)

            if labels_path:
                seg_dict = dict(zip(conc_data_x, conc_data_y))
                dists[_label]['euclidean'].append(euclidean(gt_dict, seg_dict))
                dists[_label]['mse'].append(mse(gt_dict, seg_dict))
                dists[_label]['mae'].append(mae(gt_dict, seg_dict))
                curr_mae_data_y.append(dists[_label]['mae'][-1])
            elif img_id > start_id:
                # compare against the previously processed frame; testing
                # against start_id (not 0) keeps this valid when start_id > 0
                if plot_changed_seg_count:
                    flow = cv2.calcOpticalFlowFarneback(
                        prev_seg_img[_label], seg_img, None, 0.5, 3, 15, 3,
                        5, 1.2, 0)

                    print('flow: {}'.format(flow.shape))

                    # clip the flow-displaced coordinates so they remain
                    # valid indices into seg_img
                    curr_x = np.clip(prev_x + flow[..., 0], 0,
                                     seg_width - 1).astype(np.int32)
                    curr_y = np.clip(prev_y + flow[..., 1], 0,
                                     seg_height - 1).astype(np.int32)
                    # NOTE(review): the flow-warped image is not used yet; the
                    # count below compares the unwarped frames
                    seg_img_flow = seg_img[curr_y, curr_x]

                    changed_seg_count[_label].append(
                        np.count_nonzero(
                            np.not_equal(seg_img, prev_seg_img[_label])))
                    seg_count_data_y.append(changed_seg_count[_label])
                else:
                    ice_concentration_diff[_label].append(
                        np.mean(
                            np.abs(conc_data_y - prev_conc_data_y[_label])))
                    conc_diff_data_y.append(ice_concentration_diff[_label])
            else:
                # first processed frame: start the per-model histories.
                # NOTE(review): this also discards any values loaded through
                # --load_ice_conc_diff
                if plot_changed_seg_count:
                    prev_x, prev_y = np.meshgrid(range(seg_width),
                                                 range(seg_height),
                                                 sparse=False,
                                                 indexing='xy')
                    changed_seg_count[_label] = []
                else:
                    ice_concentration_diff[_label] = []

            prev_seg_img[_label] = seg_img
            prev_conc_data_y[_label] = conc_data_y

        if labels_path:
            for i, k in enumerate(curr_mae_data_y):
                mae_data_y[i].append(k)

            print('')

            if seg_paths:
                # the MAE histories cover frames start_id..img_id
                n_test_images = img_id - start_id + 1
                mae_data_X = np.arange(1, n_test_images + 1, dtype=np.float64)

                if img_id == end_id:
                    mae_data_y_arr = np.array(mae_data_y).transpose()
                    print('mae_data_y:\n {}'.format(
                        tabulate(mae_data_y_arr,
                                 headers=seg_labels,
                                 tablefmt='plain')))

                    pd.DataFrame(data=mae_data_y_arr,
                                 columns=seg_labels).to_clipboard(excel=True)

                mae_img = getPlotImage(mae_data_X, mae_data_y, plot_cols,
                                       'MAE', seg_labels, 'frame', 'MAE')
                cv2.imshow('mae_img', mae_img)
                conc_diff_img = resize_ar(mae_img,
                                          src_width,
                                          src_height,
                                          bkg_col=255)
        elif img_id > start_id:
            # the difference / changed-count histories cover frames
            # start_id + 1..img_id
            n_test_images = img_id - start_id
            seg_count_data_X = np.arange(1, n_test_images + 1,
                                         dtype=np.float64)

            if plot_changed_seg_count:
                seg_count_img = getPlotImage(seg_count_data_X,
                                             seg_count_data_y, plot_cols,
                                             'Count', seg_labels, 'frame',
                                             'Changed Label Count')
                cv2.imshow('seg_count_img', seg_count_img)
            else:
                conc_diff_img = getPlotImage(
                    seg_count_data_X, conc_diff_data_y, plot_cols,
                    'Mean {} concentration difference between consecutive '
                    'frames'.format(ice_type_str), seg_labels, 'frame',
                    'Concentration Difference (%)')
                conc_diff_img = resize_ar(conc_diff_img,
                                          src_width,
                                          src_height,
                                          bkg_col=255)
        else:
            conc_diff_img = np.zeros((src_height, src_width, 3),
                                     dtype=np.uint8)

        plot_labels += seg_labels
        if enable_plotting:
            plot_img = getPlotImage(plot_data_x,
                                    plot_data_y,
                                    plot_cols,
                                    plot_title,
                                    plot_labels,
                                    plot_x_label,
                                    plot_y_label,
                                    legend=0)

            plot_img = resize_ar(plot_img, src_width, src_height, bkg_col=255)

            ann_fmt = (font_id, loc[0], loc[1], size,
                       thickness) + labels_col_rgb + bgr_col

            put_text_with_background(src_img,
                                     'frame {}'.format(img_id + 1),
                                     fmt=ann_fmt)

            if n_seg_paths == 1:
                print('seg_img_disp: {}'.format(seg_img_disp.shape))
                print('plot_img: {}'.format(plot_img.shape))
                stitched_seg_img = np.concatenate((seg_img_disp, plot_img),
                                                  axis=1)

                print('stitched_seg_img: {}'.format(stitched_seg_img.shape))
                print('stitched_img: {}'.format(stitched_img.shape))
                stitched_img = np.concatenate((stitched_img, stitched_seg_img),
                                              axis=0 if labels_path else 1)
            elif n_seg_paths == 2:
                stitched_img = np.concatenate((
                    np.concatenate((src_img, conc_diff_img), axis=1),
                    np.concatenate(seg_img_disp_list, axis=1),
                ),
                                              axis=0)
            elif n_seg_paths == 3:
                stitched_img = np.concatenate((
                    np.concatenate((src_img, plot_img, conc_diff_img), axis=1),
                    np.concatenate(seg_img_disp_list, axis=1),
                ),
                                              axis=0)

            stitched_img = resize_ar(stitched_img,
                                     width=out_width,
                                     height=out_height)

            if write_to_video:
                video_out.write(stitched_img)
            else:
                stacked_img_path = os.path.join(
                    stitched_seq_path, '{}.{}'.format(img_fname_no_ext,
                                                      out_ext))
                cv2.imwrite(stacked_img_path, stitched_img)

            cv2.imshow('stitched_img', stitched_img)
            # waitKey(0) blocks while paused; ESC (27) quits, SPACE (32)
            # toggles the pause
            k = cv2.waitKey(1 - _pause)
            if k == 27:
                break
            elif k == 32:
                _pause = 1 - _pause

        elapsed = time.time() - start_t
        # guard against a zero interval from a coarse clock
        curr_fps = 1.0 / elapsed if elapsed > 0 else float('inf')

        sys.stdout.write('\rDone {:d}/{:d} frames. fps: {}'.format(
            img_id + 1 - start_id, n_frames, curr_fps))
        sys.stdout.flush()

    print()

    if enable_plotting and write_to_video:
        video_out.release()

    if labels_path:
        median_dists = {}
        mean_dists = {}
        mae_data_y = []
        for _label in seg_labels:
            _dists = dists[_label]
            mae_data_y.append(_dists['mae'])
            mean_dists[_label] = {k: np.mean(_dists[k]) for k in _dists}
            median_dists[_label] = {k: np.median(_dists[k]) for k in _dists}

        print('mean_dists:\n{}'.format(pformat(mean_dists)))
        print('median_dists:\n{}'.format(pformat(median_dists)))

        # nothing to plot when no seg models were evaluated
        if mae_data_y:
            n_test_images = len(mae_data_y[0])

            mae_data_x = np.arange(1, n_test_images + 1, dtype=np.float64)
            mae_img = getPlotImage(mae_data_x, mae_data_y, plot_cols, 'MAE',
                                   seg_labels, 'test image',
                                   'Mean Absolute Error')
            cv2.imshow('MAE', mae_img)
            cv2.waitKey(0)
    else:
        mean_seg_counts = {}
        median_seg_counts = {}

        mean_conc_diff = {}
        median_conc_diff = {}

        # each mode fills a different per-model history during the loop
        seg_histories = (changed_seg_count
                         if plot_changed_seg_count else ice_concentration_diff)

        for seg_id in seg_histories:
            if plot_changed_seg_count:
                mean_seg_counts[seg_id] = np.mean(changed_seg_count[seg_id])
                median_seg_counts[seg_id] = np.median(
                    changed_seg_count[seg_id])
            else:
                _ice_concentration_diff = ice_concentration_diff[seg_id]

                mean_conc_diff[seg_id] = np.mean(_ice_concentration_diff)
                median_conc_diff[seg_id] = np.median(_ice_concentration_diff)

                np.savetxt(os.path.join(
                    out_path, '{}_ice_concentration_diff.txt'.format(seg_id)),
                           _ice_concentration_diff,
                           fmt='%8.4f',
                           delimiter='\t')

        if plot_changed_seg_count:
            print('mean_seg_counts:\n{}'.format(pformat(mean_seg_counts)))
            print('median_seg_counts:\n{}'.format(pformat(median_seg_counts)))
        else:
            print('mean_conc_diff:')
            for seg_id in mean_conc_diff:
                print('{}\t{}'.format(seg_id, mean_conc_diff[seg_id]))
            print('median_conc_diff:')
            for seg_id in mean_conc_diff:
                print('{}\t{}'.format(seg_id, median_conc_diff[seg_id]))
def run(params):
    """
    Stitch per-patch segmentation outputs back into full-size frames.

    For every frame ``start_id..end_id`` of ``params.src_path`` (files with
    extension ``params.images_ext``, ordered by ``sort_key``), the patch images
    ``<frame>_<k>.<patch_ext>`` (1-based ``k``, row-major order) are read from
    ``params.patch_seq_path`` and reassembled. The result is written either as
    individual images under ``params.stitched_seq_path`` or as frames of a
    video when ``params.out_ext`` is one of the known video extensions.

    When ``params.labels_path`` and ``params.labels_ext`` are set, each
    stitched frame is also evaluated against its ground-truth label image, and
    running averages of the segmentation metrics are logged to a
    ``log_<datetime>.txt`` file in the output directory.

    Stitching modes (``params.method``):
        -1: no stitching; a full-size image named like the source frame is
            read directly from ``params.patch_seq_path``.
         0: patches are concatenated row by row. The last patch of each
            row/column is shifted back to end at the frame border, and the
            overlap is kept, so the result can be larger than the source frame.
         other: patches are pasted into a zero canvas of the source frame size.

    :param StitchParams params: run configuration. Several fields (n_frames,
        end_id, patch_width, patch_seq_name, patch_seq_path,
        stitched_seq_path, method) may be filled in / overwritten in place.
    :return: None
    :raises SystemError: if no source frames are found, the patch folder does
        not exist, a source/patch image is missing or unreadable, or the
        number of GT label frames differs from the number of source frames.
    """
    video_exts = ['mp4', 'mkv', 'avi', 'mpg', 'mpeg', 'mjpg']

    print('Reading source images from: {}'.format(params.src_path))

    src_files = [
        k for k in os.listdir(params.src_path)
        if k.endswith('.{:s}'.format(params.images_ext))
    ]
    total_frames = len(src_files)
    if total_frames <= 0:
        print('params: {}'.format(params))
        raise SystemError('No input frames of type {} found'.format(
            params.images_ext))

    print('total_frames: {}'.format(total_frames))

    classes, _ = read_class_info(params.class_info_path)
    n_classes = len(classes)
    class_ids = list(range(n_classes))

    src_files.sort(key=sort_key)

    if params.n_frames <= 0:
        params.n_frames = total_frames

    if params.end_id < params.start_id:
        params.end_id = params.n_frames - 1

    if params.patch_width <= 0:
        params.patch_width = params.patch_height

    if not params.patch_seq_path:
        if not params.patch_seq_name:
            params.patch_seq_name = '{:s}_{:d}_{:d}_{:d}_{:d}_{:d}_{:d}'.format(
                params.seq_name, params.start_id, params.end_id,
                params.patch_height, params.patch_width, params.patch_height,
                params.patch_width)
        params.patch_seq_path = os.path.join(params.db_root_dir,
                                             params.patch_seq_name,
                                             params.patch_seq_type)
        if not os.path.isdir(params.patch_seq_path):
            raise SystemError(
                'params.patch_seq_path does not exist: {}'.format(
                    params.patch_seq_path))
    else:
        params.patch_seq_name = os.path.basename(params.patch_seq_path)

    if not params.stitched_seq_path:
        stitched_seq_name = '{}_stitched_{}'.format(params.patch_seq_name,
                                                    params.method)
        if params.stacked:
            stitched_seq_name = '{}_stacked'.format(stitched_seq_name)
            params.method = 1
        stitched_seq_name = '{}_{}_{}'.format(stitched_seq_name,
                                              params.start_id, params.end_id)
        params.stitched_seq_path = os.path.join(params.db_root_dir,
                                                stitched_seq_name,
                                                params.patch_seq_type)

    gt_labels_orig = gt_labels = video_out = None
    write_to_video = params.out_ext in video_exts
    if write_to_video:
        if not params.stitched_seq_path.endswith('.{}'.format(params.out_ext)):
            params.stitched_seq_path = '{}.{}'.format(params.stitched_seq_path,
                                                      params.out_ext)
        print('Writing {}x{} output video to: {}'.format(
            params.width, params.height, params.stitched_seq_path))
        save_dir = os.path.dirname(params.stitched_seq_path)

        fourcc = cv2.VideoWriter_fourcc(*params.codec)
        video_out = cv2.VideoWriter(params.stitched_seq_path, fourcc,
                                    params.fps, (params.width, params.height))
    else:
        print('Writing output images to: {}'.format(params.stitched_seq_path))
        save_dir = params.stitched_seq_path

    if save_dir and not os.path.isdir(save_dir):
        os.makedirs(save_dir)
    log_fname = os.path.join(save_dir, 'log_{:s}.txt'.format(getDateTime()))
    print_and_write('Saving log to: {}'.format(log_fname), log_fname)
    print_and_write(
        'Reading patch images from: {}'.format(params.patch_seq_path),
        log_fname)

    n_patches = 0
    # start paused (waitKey(0)) when displaying; space toggles pausing
    pause_after_frame = 1
    # Scale factor mapping raw class ids to visible gray levels. Note: this
    # uses params.n_classes, not the class count read from class_info_path.
    label_diff = int(255.0 / (params.n_classes - 1))

    eval_mode = False
    if params.labels_path and params.labels_ext:
        _, labels_list, labels_total_frames = read_data(
            labels_path=params.labels_path, labels_ext=params.labels_ext)
        if labels_total_frames != total_frames:
            raise SystemError(
                'Mismatch between no. of frames in GT and seg labels')
        eval_mode = True
        import densenet.evaluation.eval_segm as eval_segm

    avg_pix_acc = avg_mean_acc = avg_mean_IU = avg_fw_IU = 0
    avg_mean_acc_ice = avg_mean_acc_ice_1 = avg_mean_acc_ice_2 = 0
    avg_mean_IU_ice = avg_mean_IU_ice_1 = avg_mean_IU_ice_2 = 0

    # Number of frames excluded from each per-class running mean because the
    # corresponding class(es) did not occur in that frame.
    skip_mean_acc_ice = skip_mean_IU_ice = 0
    skip_mean_acc_ice_1 = skip_mean_acc_ice_2 = 0
    skip_mean_IU_ice_1 = skip_mean_IU_ice_2 = 0

    _n_frames = params.end_id - params.start_id + 1

    for img_id in range(params.start_id, params.end_id + 1):

        img_fname = src_files[img_id]
        img_fname_no_ext = os.path.splitext(img_fname)[0]

        src_img_fname = os.path.join(params.src_path, img_fname)
        src_img = cv2.imread(src_img_fname)

        if src_img is None:
            raise SystemError('Source image could not be read from: {}'.format(
                src_img_fname))

        n_rows, n_cols, n_channels = src_img.shape

        out_id = 0
        min_row = 0

        enable_stitching = 1
        if params.method == -1:
            patch_src_img_fname = os.path.join(
                params.patch_seq_path,
                '{:s}.{:s}'.format(img_fname_no_ext, params.patch_ext))
            if not os.path.exists(patch_src_img_fname):
                raise SystemError('Patch image does not exist: {}'.format(
                    patch_src_img_fname))
            stitched_img = cv2.imread(patch_src_img_fname)
            if stitched_img is None:
                raise SystemError(
                    'Patch image could not be read from: {}'.format(
                        patch_src_img_fname))
            enable_stitching = 0
        elif params.method == 0:
            stitched_img = None
        else:
            stitched_img = np.zeros((n_rows, n_cols, n_channels),
                                    dtype=np.uint8)

        while enable_stitching:
            max_row = min_row + params.patch_height
            if max_row > n_rows:
                # shift the last row of patches back so it ends at the border
                diff = max_row - n_rows
                min_row -= diff
                max_row -= diff

            curr_row = None
            min_col = 0
            while True:
                max_col = min_col + params.patch_width
                if max_col > n_cols:
                    # shift the last patch of the row back to end at the border
                    diff = max_col - n_cols
                    min_col -= diff
                    max_col -= diff

                patch_img_fname = '{:s}_{:d}'.format(img_fname_no_ext,
                                                     out_id + 1)
                patch_src_img_fname = os.path.join(
                    params.patch_seq_path,
                    '{:s}.{:s}'.format(patch_img_fname, params.patch_ext))

                if not os.path.exists(patch_src_img_fname):
                    raise SystemError('Patch image does not exist: {}'.format(
                        patch_src_img_fname))

                src_patch = cv2.imread(patch_src_img_fname)
                if src_patch is None:
                    raise SystemError(
                        'Patch image could not be read from: {}'.format(
                            patch_src_img_fname))
                seg_width = src_patch.shape[1]

                # Patches saved side by side with their source (and GT): keep
                # only the right-most patch_width columns.
                if seg_width == 2 * params.patch_width or seg_width == 3 * params.patch_width:
                    _start_id = seg_width - params.patch_width
                    src_patch = src_patch[:, _start_id:]

                if params.normalize_patches:
                    src_patch = (src_patch * label_diff).astype(np.uint8)

                out_id += 1

                if params.method == 0:
                    if curr_row is None:
                        curr_row = src_patch
                    else:
                        curr_row = np.concatenate((curr_row, src_patch),
                                                  axis=1)
                else:
                    stitched_img[min_row:max_row,
                                 min_col:max_col, :] = src_patch

                if params.show_img:
                    disp_img = src_img.copy()
                    cv2.rectangle(disp_img, (min_col, min_row),
                                  (max_col, max_row), (255, 0, 0), 2)

                    stitched_img_disp = stitched_img
                    if params.disp_resize_factor != 1:
                        disp_img = cv2.resize(disp_img, (0, 0),
                                              fx=params.disp_resize_factor,
                                              fy=params.disp_resize_factor)
                        if stitched_img_disp is not None:
                            stitched_img_disp = cv2.resize(
                                stitched_img_disp, (0, 0),
                                fx=params.disp_resize_factor,
                                fy=params.disp_resize_factor)

                    cv2.imshow('disp_img', disp_img)
                    cv2.imshow('src_patch', src_patch)

                    if stitched_img_disp is not None:
                        cv2.imshow('stacked_img', stitched_img_disp)
                    if curr_row is not None:
                        cv2.imshow('curr_row', curr_row)

                    k = cv2.waitKey(1 - pause_after_frame)
                    if k == 27:
                        sys.exit(0)
                    elif k == 32:
                        pause_after_frame = 1 - pause_after_frame

                if max_col >= n_cols:
                    break

                min_col = max_col

            if params.method == 0:
                if stitched_img is None:
                    stitched_img = curr_row
                else:
                    stitched_img = np.concatenate((stitched_img, curr_row),
                                                  axis=0)

            if max_row >= n_rows:
                break

            min_row = max_row

        if eval_mode:
            labels_img_fname = os.path.join(
                params.labels_path,
                img_fname_no_ext + '.{}'.format(params.labels_ext))
            gt_labels_orig = imageio.imread(labels_img_fname)

            if gt_labels_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))

            if len(gt_labels_orig.shape) == 3:
                gt_labels = np.squeeze(gt_labels_orig[:, :, 0])
            else:
                gt_labels = gt_labels_orig

            seg_labels = np.squeeze(stitched_img[:, :, 0])

            pix_acc = eval_segm.pixel_accuracy(seg_labels, gt_labels,
                                               class_ids)
            _acc, mean_acc = eval_segm.mean_accuracy(seg_labels,
                                                     gt_labels,
                                                     class_ids,
                                                     return_acc=1)
            _IU, mean_IU = eval_segm.mean_IU(seg_labels,
                                             gt_labels,
                                             class_ids,
                                             return_iu=1)
            fw_IU = eval_segm.frequency_weighted_IU(seg_labels, gt_labels,
                                                    class_ids)

            # Frames evaluated so far in this run. The running means must be
            # normalized by this, not by the absolute frame index, otherwise
            # they are wrong whenever start_id > 0.
            n_done = img_id - params.start_id + 1

            avg_pix_acc += (pix_acc - avg_pix_acc) / n_done
            avg_mean_acc += (mean_acc - avg_mean_acc) / n_done
            avg_mean_IU += (mean_IU - avg_mean_IU) / n_done
            avg_fw_IU += (fw_IU - avg_fw_IU) / n_done

            # _acc / _IU are indexed by class id (see _acc[1] below). Average
            # over the classes other than 0 (the "ice" classes here). Frames
            # without any such class are skipped rather than contributing
            # np.mean([]) == NaN, which would poison the running mean.
            ice_acc = [v for c, v in _acc.items() if c != 0]
            if ice_acc:
                avg_mean_acc_ice += (np.mean(ice_acc) - avg_mean_acc_ice) / (
                    n_done - skip_mean_acc_ice)
            else:
                print('\nskip_mean_acc_ice: {}'.format(img_id))
                skip_mean_acc_ice += 1
            try:
                mean_acc_ice_1 = _acc[1]
                avg_mean_acc_ice_1 += (mean_acc_ice_1 - avg_mean_acc_ice_1) / (
                    n_done - skip_mean_acc_ice_1)
            except KeyError:
                print('\nskip_mean_acc_ice_1: {}'.format(img_id))
                skip_mean_acc_ice_1 += 1
            try:
                mean_acc_ice_2 = _acc[2]
                avg_mean_acc_ice_2 += (mean_acc_ice_2 - avg_mean_acc_ice_2) / (
                    n_done - skip_mean_acc_ice_2)
            except KeyError:
                print('\nskip_mean_acc_ice_2: {}'.format(img_id))
                skip_mean_acc_ice_2 += 1

            ice_IU = [v for c, v in _IU.items() if c != 0]
            if ice_IU:
                avg_mean_IU_ice += (np.mean(ice_IU) - avg_mean_IU_ice) / (
                    n_done - skip_mean_IU_ice)
            else:
                print('\nskip_mean_IU_ice: {}'.format(img_id))
                skip_mean_IU_ice += 1
            try:
                mean_IU_ice_1 = _IU[1]
                avg_mean_IU_ice_1 += (mean_IU_ice_1 - avg_mean_IU_ice_1) / (
                    n_done - skip_mean_IU_ice_1)
            except KeyError:
                print('\nskip_mean_IU_ice_1: {}'.format(img_id))
                skip_mean_IU_ice_1 += 1
            try:
                mean_IU_ice_2 = _IU[2]
                avg_mean_IU_ice_2 += (mean_IU_ice_2 - avg_mean_IU_ice_2) / (
                    n_done - skip_mean_IU_ice_2)
            except KeyError:
                print('\nskip_mean_IU_ice_2: {}'.format(img_id))
                skip_mean_IU_ice_2 += 1

            log_txt = '\nDone {:d}/{:d} frames '.format(n_done, _n_frames)
            log_txt += "pix_acc: {:.5f} mean_acc: {:.5f} mean_IU: {:.5f} fw_IU: {:.5f}" \
                       " avg_acc_ice: {:.5f} avg_acc_ice_1: {:.5f} avg_acc_ice_2: {:.5f}" \
                       " avg_IU_ice: {:.5f} avg_IU_ice_1: {:.5f} avg_IU_ice_2: {:.5f}".format(
                avg_pix_acc, avg_mean_acc, avg_mean_IU, avg_fw_IU,
                avg_mean_acc_ice, avg_mean_acc_ice_1, avg_mean_acc_ice_2,
                avg_mean_IU_ice, avg_mean_IU_ice_1, avg_mean_IU_ice_2,
            )
            print_and_write(log_txt, log_fname)
        else:
            sys.stdout.write('\rDone {:d}/{:d} frames'.format(
                img_id + 1 - params.start_id, _n_frames))
            sys.stdout.flush()

        if not params.normalize_patches:
            # patches still hold raw class ids: scale them to gray levels
            seg_img = (stitched_img * label_diff).astype(np.uint8)
        else:
            seg_img = stitched_img

        if params.stacked:
            if eval_mode:
                labels_img = (gt_labels_orig * label_diff).astype(np.uint8)
                if labels_img.ndim == 2:
                    # single-channel GT: replicate to 3 channels so it can be
                    # concatenated next to the 3-channel source image
                    labels_img = np.stack((labels_img, ) * 3, axis=2)
                stitched = np.concatenate((src_img, labels_img), axis=1)
            else:
                stitched = src_img
            out_img = np.concatenate((stitched, seg_img), axis=1)
        else:
            out_img = seg_img

        if write_to_video:
            out_img = resize_ar(out_img, params.width, params.height)
            video_out.write(out_img)
        else:
            if params.resize_factor != 1:
                out_img = cv2.resize(out_img, (0, 0),
                                     fx=params.resize_factor,
                                     fy=params.resize_factor)
            stacked_img_path = os.path.join(
                params.stitched_seq_path,
                '{}.{}'.format(img_fname_no_ext, params.out_ext))
            cv2.imwrite(stacked_img_path, out_img)

        n_patches += out_id

    sys.stdout.write('\n')
    sys.stdout.flush()
    sys.stdout.write('Total patches processed: {}\n'.format(n_patches))

    if params.show_img:
        cv2.destroyAllWindows()
    if write_to_video:
        video_out.release()

    if params.del_patch_seq:
        print('Removing patch folder {}'.format(params.patch_seq_path))
        shutil.rmtree(params.patch_seq_path)

    if eval_mode:
        log_txt = "pix_acc\t mean_acc\t mean_IU\t fw_IU\n{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_pix_acc, avg_mean_acc, avg_mean_IU, avg_fw_IU)
        log_txt += "mean_acc_ice\t mean_acc_ice_1\t mean_acc_ice_2\n{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_mean_acc_ice, avg_mean_acc_ice_1, avg_mean_acc_ice_2)
        log_txt += "mean_IU_ice\t mean_IU_ice_1\t mean_IU_ice_2\n{:.5f}\t{:.5f}\t{:.5f}\n".format(
            avg_mean_IU_ice, avg_mean_IU_ice_1, avg_mean_IU_ice_2)
        print_and_write(log_txt, log_fname)

        print_and_write('Saved log to: {}'.format(log_fname), log_fname)
        print_and_write(
            'Read patch images from: {}'.format(params.patch_seq_path),
            log_fname)
def run(params):
    """

    :param VisParams params:
    :return:
    """
    eval_mode = False

    if params.multi_sequence_db:
        assert params.vis_split, "vis_split must be provided for CTC"
        """some repeated code here to allow better IntelliSense"""
        if params.dataset.lower() == 'ctc':
            from new_deeplab.datasets.build_ctc_data import CTCInfo
            db_splits = CTCInfo.DBSplits().__dict__
            sequences = CTCInfo.sequences
        elif params.dataset.lower() in ('ipsc', 'ipsc_2_class',
                                        'ipsc_5_class'):
            from new_deeplab.datasets.ipsc_info import IPSCInfo
            db_splits = IPSCInfo.DBSplits().__dict__
            sequences = IPSCInfo.sequences
        elif params.dataset.lower() == 'ipsc_patches':
            from new_deeplab.datasets.ipsc_info import IPSCPatchesInfo
            db_splits = IPSCPatchesInfo.DBSplits().__dict__
            sequences = IPSCPatchesInfo.sequences
        else:
            raise AssertionError(
                'multi_sequence_db {} is not supported yet'.format(
                    params.dataset))

        seq_ids = db_splits[params.vis_split]

        src_files = []
        seg_labels_list = []
        if params.no_labels:
            src_labels_list = None
        else:
            src_labels_list = []

        total_frames = 0
        seg_total_frames = 0

        for seq_id in seq_ids:
            seq_name, n_frames = sequences[seq_id]

            images_path = os.path.join(params.images_path, seq_name)

            if params.no_labels:
                labels_path = ''
            else:
                labels_path = os.path.join(params.labels_path, seq_name)

            _src_files, _src_labels_list, _total_frames = read_data(
                images_path, params.images_ext, labels_path, params.labels_ext)

            _src_filenames = [
                os.path.splitext(os.path.basename(k))[0] for k in _src_files
            ]

            if not params.no_labels:
                _src_labels_filenames = [
                    os.path.splitext(os.path.basename(k))[0]
                    for k in _src_labels_list
                ]

                assert _src_labels_filenames == _src_filenames, "mismatch between image and label filenames"

            eval_mode = False
            if params.seg_path and params.seg_ext:
                seg_path = os.path.join(params.seg_path, seq_name)

                _, _seg_labels_list, _seg_total_frames = read_data(
                    labels_path=seg_path,
                    labels_ext=params.seg_ext,
                    labels_type='seg')

                _seg_labels__filenames = [
                    os.path.splitext(os.path.basename(k))[0]
                    for k in _seg_labels_list
                ]

                if _seg_total_frames != _total_frames:

                    if params.seg_on_subset and _seg_total_frames < _total_frames:
                        matching_ids = [
                            _src_filenames.index(k)
                            for k in _seg_labels__filenames
                        ]

                        _src_files = [_src_files[i] for i in matching_ids]
                        if not params.no_labels:
                            _src_labels_list = [
                                _src_labels_list[i] for i in matching_ids
                            ]

                        _total_frames = _seg_total_frames

                    else:
                        raise AssertionError(
                            'Mismatch between no. of frames in GT and seg labels: {} and {}'
                            .format(_total_frames, _seg_total_frames))

                    seg_labels_list += _seg_labels_list

                seg_total_frames += _seg_total_frames
                eval_mode = True

            src_files += _src_files
            if not params.no_labels:
                src_labels_list += _src_labels_list
            # else:
            #     params.stitch = params.save_stitched = 1

            total_frames += _total_frames
    else:
        src_files, src_labels_list, total_frames = read_data(
            params.images_path, params.images_ext, params.labels_path,
            params.labels_ext)

        eval_mode = False
        if params.labels_path and params.seg_path and params.seg_ext:
            _, seg_labels_list, seg_total_frames = read_data(
                labels_path=params.seg_path,
                labels_ext=params.seg_ext,
                labels_type='seg')
            if seg_total_frames != total_frames:
                raise SystemError(
                    'Mismatch between no. of frames in GT and seg labels: {} and {}'
                    .format(total_frames, seg_total_frames))
            eval_mode = True
        # else:
        #     params.stitch = params.save_stitched = 1

    if params.end_id < params.start_id:
        params.end_id = total_frames - 1

    classes, composite_classes = read_class_info(params.class_info_path)
    n_classes = len(classes)
    class_ids = list(range(n_classes))
    class_id_to_color = {i: k[1] for i, k in enumerate(classes)}

    all_classes = [k[0] for k in classes + composite_classes]

    if not params.save_path:
        if eval_mode:
            params.save_path = os.path.join(os.path.dirname(params.seg_path),
                                            'vis')
        else:
            params.save_path = os.path.join(
                os.path.dirname(params.images_path), 'vis')

    if not os.path.isdir(params.save_path):
        os.makedirs(params.save_path)

    if params.stitch and params.save_stitched:
        print('Saving visualization images to: {}'.format(params.save_path))

    log_fname = os.path.join(params.save_path,
                             'vis_log_{:s}.txt'.format(getDateTime()))
    print('Saving log to: {}'.format(log_fname))

    save_path_parent = os.path.dirname(params.save_path)
    templ_1 = os.path.basename(save_path_parent)
    templ_2 = os.path.basename(os.path.dirname(save_path_parent))

    templ = '{}_{}'.format(templ_1, templ_2)

    # if params.selective_mode:
    #     label_diff = int(255.0 / n_classes)
    # else:
    #     label_diff = int(255.0 / (n_classes - 1))

    print('templ: {}'.format(templ))
    # print('label_diff: {}'.format(label_diff))

    n_frames = params.end_id - params.start_id + 1

    pix_acc = np.zeros((n_frames, ))

    mean_acc = np.zeros((n_frames, ))
    # mean_acc_ice = np.zeros((n_frames,))
    # mean_acc_ice_1 = np.zeros((n_frames,))
    # mean_acc_ice_2 = np.zeros((n_frames,))
    #
    mean_IU = np.zeros((n_frames, ))
    # mean_IU_ice = np.zeros((n_frames,))
    # mean_IU_ice_1 = np.zeros((n_frames,))
    # mean_IU_ice_2 = np.zeros((n_frames,))

    fw_IU = np.zeros((n_frames, ))
    fw_sum = np.zeros((n_classes, ))

    print_diff = max(1, int(n_frames * 0.01))

    avg_mean_acc = {c: 0 for c in all_classes}
    avg_mean_IU = {c: 0 for c in all_classes}
    skip_mean_acc = {c: 0 for c in all_classes}
    skip_mean_IU = {c: 0 for c in all_classes}

    _pause = 1
    labels_img = None

    for img_id in range(params.start_id, params.end_id + 1):

        stitched = []

        # img_fname = '{:s}_{:d}.{:s}'.format(fname_templ, img_id + 1, img_ext)
        src_img_fname = src_files[img_id]
        img_dir = os.path.dirname(src_img_fname)
        seq_name = os.path.basename(img_dir)
        img_fname = os.path.basename(src_img_fname)

        img_fname_no_ext = os.path.splitext(img_fname)[0]

        src_img = None
        border_img = None

        if params.stitch or params.show_img:
            # src_img_fname = os.path.join(params.images_path, img_fname)
            src_img = cv2.imread(src_img_fname)
            if src_img is None:
                raise SystemError(
                    'Source image could not be read from: {}'.format(
                        src_img_fname))

            try:
                src_height, src_width, _ = src_img.shape
            except ValueError as e:
                print('src_img_fname: {}'.format(src_img_fname))
                print('src_img: {}'.format(src_img))
                print('src_img.shape: {}'.format(src_img.shape))
                print('error: {}'.format(e))
                sys.exit(1)

            if not params.blended:
                stitched.append(src_img)

            border_img = np.full_like(src_img, 255)
            border_img = border_img[:, :5, ...]

        if not params.no_labels:
            # labels_img_fname = os.path.join(params.labels_path, img_fname_no_ext + '.{}'.format(params.labels_ext))
            labels_img_fname = src_labels_list[img_id]

            labels_img_orig = cv2.imread(labels_img_fname)
            if labels_img_orig is None:
                raise SystemError(
                    'Labels image could not be read from: {}'.format(
                        labels_img_fname))

            _, src_width = labels_img_orig.shape[:2]

            # if len(labels_img_orig.shape) == 3:
            #     labels_img_orig = np.squeeze(labels_img_orig[:, :, 0])

            if params.show_img:
                cv2.imshow('labels_img_orig', labels_img_orig)

            labels_img_orig, label_img_raw, class_to_ids = remove_fuzziness_in_mask(
                labels_img_orig,
                n_classes,
                class_id_to_color,
                fuzziness=5,
                check_equality=0)
            labels_img = np.copy(labels_img_orig)
            # if params.normalize_labels:
            #     if params.selective_mode:
            #         selective_idx = (labels_img_orig == 255)
            #         print('labels_img_orig.shape: {}'.format(labels_img_orig.shape))
            #         print('selective_idx count: {}'.format(np.count_nonzero(selective_idx)))
            #         labels_img_orig[selective_idx] = n_classes
            #         if params.show_img:
            #             cv2.imshow('labels_img_orig norm', labels_img_orig)
            #     labels_img = (labels_img_orig.astype(np.float64) * label_diff).astype(np.uint8)
            # else:
            #     labels_img = np.copy(labels_img_orig)

            # if len(labels_img.shape) != 3:
            #     labels_img = np.stack((labels_img, labels_img, labels_img), axis=2)

            if params.stitch:
                if params.blended:
                    full_mask_gs = cv2.cvtColor(labels_img, cv2.COLOR_BGR2GRAY)
                    mask_binary = full_mask_gs == 0
                    labels_img_vis = (0.5 * src_img + 0.5 * labels_img).astype(
                        np.uint8)
                    labels_img_vis[mask_binary] = src_img[mask_binary]
                else:
                    labels_img_vis = labels_img

                if params.add_border:
                    stitched.append(border_img)
                stitched.append(labels_img_vis)

            if eval_mode:
                # seg_img_fname = os.path.join(params.seg_path, img_fname_no_ext + '.{}'.format(params.seg_ext))
                seg_img_fname = seg_labels_list[img_id]

                seg_img = cv2.imread(seg_img_fname)
                if seg_img is None:
                    raise SystemError(
                        'Segmentation image could not be read from: {}'.format(
                            seg_img_fname))

                # seg_img = convert_to_raw_mask(seg_img, n_classes, seg_img_fname)

                if len(seg_img.shape) == 3:
                    seg_img = np.squeeze(seg_img[:, :, 0])

                eval_cl, _ = eval.extract_classes(seg_img)
                gt_cl, _ = eval.extract_classes(label_img_raw)

                # if seg_img.max() > n_classes - 1:
                #     seg_img = (seg_img.astype(np.float64) / label_diff).astype(np.uint8)

                seg_height, seg_width = seg_img.shape

                if seg_width == 2 * src_width or seg_width == 3 * src_width:
                    _start_id = seg_width - src_width
                    seg_img = seg_img[:, _start_id:]

                # print('seg_img.shape: ', seg_img.shape)
                # print('labels_img_orig.shape: ', labels_img_orig.shape)

                pix_acc[img_id] = eval.pixel_accuracy(seg_img, label_img_raw,
                                                      class_ids)
                _acc, mean_acc[img_id] = eval.mean_accuracy(seg_img,
                                                            label_img_raw,
                                                            class_ids,
                                                            return_acc=1)
                _IU, mean_IU[img_id] = eval.mean_IU(seg_img,
                                                    label_img_raw,
                                                    class_ids,
                                                    return_iu=1)
                fw_IU[img_id], _fw = eval.frequency_weighted_IU(seg_img,
                                                                label_img_raw,
                                                                class_ids,
                                                                return_freq=1)
                # try:
                #     fw_sum += _fw
                # except ValueError as e:
                #     print('fw_sum: {}'.format(fw_sum))
                #     print('_fw: {}'.format(_fw))
                #
                #     eval_cl, _ = eval.extract_classes(seg_img)
                #     gt_cl, _ = eval.extract_classes(label_img_raw)
                #     cl = np.union1d(eval_cl, gt_cl)
                #
                #     print('cl: {}'.format(cl))
                #     print('eval_cl: {}'.format(eval_cl))
                #     print('gt_cl: {}'.format(gt_cl))
                #
                #     raise ValueError(e)

                for _class_name, _, base_ids in composite_classes:
                    _acc_list = np.asarray(list(_acc.values()))
                    _mean_acc = np.mean(_acc_list[base_ids])
                    avg_mean_acc[_class_name] += (
                        _mean_acc - avg_mean_acc[_class_name]) / (img_id + 1)

                    _IU_list = np.asarray(list(_IU.values()))
                    _mean_IU = np.mean(_IU_list[base_ids])
                    avg_mean_IU[_class_name] += (
                        _mean_IU - avg_mean_IU[_class_name]) / (img_id + 1)

                for _class_id, _class_data in enumerate(classes):
                    _class_name = _class_data[0]
                    try:
                        _mean_acc = _acc[_class_id]
                        avg_mean_acc[_class_name] += (
                            _mean_acc - avg_mean_acc[_class_name]) / (
                                img_id - skip_mean_acc[_class_name] + 1)
                    except KeyError:
                        print('\nskip_mean_acc {}: {}'.format(
                            _class_name, img_id))
                        skip_mean_acc[_class_name] += 1
                    try:
                        _mean_IU = _IU[_class_id]
                        avg_mean_IU[_class_name] += (
                            _mean_IU - avg_mean_IU[_class_name]) / (
                                img_id - skip_mean_IU[_class_name] + 1)
                    except KeyError:
                        print('\nskip_mean_IU {}: {}'.format(
                            _class_name, img_id))
                        skip_mean_IU[_class_name] += 1

                # seg_img = (seg_img * label_diff).astype(np.uint8)

                seg_img_vis = raw_seg_to_rgb(seg_img, class_id_to_color)

                if params.stitch and params.stitch_seg:
                    if params.blended:
                        full_mask_gs = cv2.cvtColor(seg_img_vis,
                                                    cv2.COLOR_BGR2GRAY)
                        mask_binary = full_mask_gs == 0
                        seg_img_vis = (0.5 * src_img +
                                       0.5 * seg_img_vis).astype(np.uint8)
                        seg_img_vis[mask_binary] = src_img[mask_binary]

                    if params.add_border:
                        stitched.append(border_img)
                    stitched.append(seg_img_vis)

                if not params.stitch and params.show_img:
                    cv2.imshow('seg_img', seg_img_vis)
            # else:
            # _, _fw = eval.frequency_weighted_IU(label_img_raw, label_img_raw, return_freq=1)
            # try:
            #     fw_sum += _fw
            # except ValueError as e:
            #     print('fw_sum: {}'.format(fw_sum))
            #     print('_fw: {}'.format(_fw))
            #
            #     gt_cl, _ = eval.extract_classes(label_img_raw)
            #     print('gt_cl: {}'.format(gt_cl))
            #     for k in range(n_classes):
            #         if k not in gt_cl:
            #             _fw.insert(k, 0)
            #
            #     fw_sum += _fw

            # _fw_total = np.sum(_fw)

            # print('_fw: {}'.format(_fw))
            # print('_fw_total: {}'.format(_fw_total))

            # _fw_frac = np.array(_fw) / float(_fw_total)

            # print('_fw_frac: {}'.format(_fw_frac))

        if params.stitch:
            stitched = np.concatenate(stitched, axis=1)
            if params.save_stitched:
                seg_save_path = os.path.join(
                    params.save_path,
                    '{}_{}.{}'.format(seq_name, img_fname_no_ext,
                                      params.out_ext))
                cv2.imwrite(seg_save_path, stitched)

            if params.show_img:
                cv2.imshow('stitched', stitched)
        else:
            if params.show_img:
                cv2.imshow('src_img', src_img)
                if params.labels_path:
                    cv2.imshow('labels_img', labels_img)

        if params.show_img:
            k = cv2.waitKey(1 - _pause)
            if k == 27:
                sys.exit(0)
            elif k == 32:
                _pause = 1 - _pause
        img_done = img_id - params.start_id + 1
        if img_done % print_diff == 0:
            log_txt = 'Done {:5d}/{:5d} frames'.format(img_done, n_frames)
            if eval_mode:
                log_txt = '{:s} pix_acc: {:.5f} mean_acc: {:.5f} mean_IU: {:.5f} fw_IU: {:.5f}'.format(
                    log_txt, pix_acc[img_id], mean_acc[img_id],
                    mean_IU[img_id], fw_IU[img_id])
                for _class in all_classes:
                    log_txt += ' acc {}: {:.5f} '.format(
                        _class, avg_mean_acc[_class])

                for _class in all_classes:
                    log_txt += ' IU {}: {:.5f} '.format(
                        _class, avg_mean_acc[_class])

            print_and_write(log_txt, log_fname)

    sys.stdout.write('\n')
    sys.stdout.flush()

    if eval_mode:
        log_txt = "pix_acc\t mean_acc\t mean_IU\t fw_IU\n{:.5f}\t{:.5f}\t{:.5f}\t{:.5f}\n".format(
            np.mean(pix_acc), np.mean(mean_acc), np.mean(mean_IU),
            np.mean(fw_IU))

        for _class in all_classes:
            log_txt += ' mean_acc {}\t'.format(_class)
        log_txt += '\n'

        for _class in all_classes:
            log_txt += ' {:.5f}\t'.format(avg_mean_acc[_class])
        log_txt += '\n'

        for _class in all_classes:
            log_txt += ' mean_IU {}\t'.format(_class)
        log_txt += '\n'

        for _class in all_classes:
            log_txt += ' {:.5f}\t'.format(avg_mean_IU[_class])
        log_txt += '\n'

        print_and_write(log_txt, log_fname)

        log_txt = templ + '\n\t'
        for _class in classes:
            log_txt += '{}\t '.format(_class[0])
        log_txt += 'all_classes\t all_classes(fw)\n'

        log_txt += 'recall\t'
        for _class in classes:
            log_txt += '{:.5f}\t'.format(avg_mean_acc[_class[0]])
        log_txt += '{:.5f}\t{:.5f}\n'.format(np.mean(mean_acc),
                                             np.mean(pix_acc))

        log_txt += 'precision\t'
        for _class in classes:
            log_txt += '{:.5f}\t'.format(avg_mean_IU[_class[0]])
        log_txt += '{:.5f}\t{:.5f}\n'.format(np.mean(mean_IU), np.mean(fw_IU))

        print_and_write(log_txt, log_fname)

    # fw_sum_total = np.sum(fw_sum)
    # fw_sum_frac = fw_sum / float(fw_sum_total)

    # print('fw_sum_total: {}'.format(fw_sum_total))
    # print('fw_sum_frac: {}'.format(fw_sum_frac))

    print('Wrote log to: {}'.format(log_fname))