Ejemplo n.º 1
0
def create_and_store_vision_plus_gt_baseline(objid,
                                             k=500,
                                             include_thresh=0.5,
                                             rerun_existing=False):
    """Build the vision-only mask seeded by the ground truth and store it.

    The ground-truth mask is used as the base mask of compute_hybrid_mask
    with vision_only=True, so the result consists of whole vision tiles only.
    Two files are written under vision_baseline_dir(objid, k, include_thresh):
    the pickled mask ('vision_with_gt_mask.pkl') and the metrics returned by
    all_metrics ('gt_vision_only_metrics.json', a five-element JSON list).

    Args:
        objid: object identifier handed to the mask/tile loaders.
        k: forwarded to get_pixtiles and to the output directory name.
        include_thresh: passed as expand_thresh to compute_hybrid_mask.
        rerun_existing: when False, return early if the metrics file exists.

    Returns:
        None. Results are only written to disk.
    """
    print('Computing vision baseline against gt...')
    outdir = vision_baseline_dir(objid, k, include_thresh)
    metrics_path = '{}/gt_vision_only_metrics.json'.format(outdir)
    if not rerun_existing and os.path.exists(metrics_path):
        print("already ran " + outdir)
        return

    agg_vision_mask, _unused_tiles = get_pixtiles(objid, k)
    gt_mask = get_gt_mask(objid)

    vision_only_mask = compute_hybrid_mask(
        gt_mask,
        agg_vision_mask,
        expand_thresh=include_thresh,
        contract_thresh=0,
        vision_only=True)

    mask_path = '{}/vision_with_gt_mask.pkl'.format(outdir)
    with open(mask_path, 'w') as fp:
        pickle.dump(vision_only_mask, fp)

    # Disabled visualisation, kept for reference:
    # sum_mask = vision_only_mask.astype(int) * 1 + gt_mask.astype(int) * 2
    #
    # plt.figure()
    # plt.imshow(sum_mask, interpolation="none", cmap=discrete_cmap(4, 'rainbow'))  # , cmap="rainbow")
    # plt.colorbar()
    # plt.savefig('{}/vision_with_gt_viz.png'.format(outdir))
    # plt.close()

    # Presumably precision, recall, jaccard, false-positive rate and
    # false-negative rate; confirm against all_metrics.
    prec, rec, jacc, fpr, fnr = all_metrics(vision_only_mask, gt_mask)
    metrics_json = json.dumps([prec, rec, jacc, fpr, fnr])
    with open(metrics_path, 'w') as fp:
        fp.write(metrics_json)
Ejemplo n.º 2
0
def create_and_store_hybrid_masks(sample_name,
                                  objid,
                                  clust="",
                                  base='MV',
                                  k=500,
                                  expand_thresh=0.8,
                                  contract_thresh=0.2,
                                  rerun_existing=False):
    """Combine a base mask with vision tiles and store mask plus metrics.

    The base mask is refined by compute_hybrid_mask using the aggregated
    vision tiles. Under hybrid_dir(...) this writes
    '<base>_<clust>_hybrid_mask.pkl' (pickled mask) and
    '<base>_<clust>_hybrid_prj.json' (the three values returned by
    faster_compute_prj against the ground-truth mask). An empty `clust`
    is spelled '-1' in the file names.

    Args:
        sample_name: sample/batch identifier handed to the loaders.
        objid: object identifier.
        clust: cluster id, or "" for no cluster.
        base: which base mask to start from; only 'MV' is implemented.
        k: forwarded to get_pixtiles and hybrid_dir.
        expand_thresh: forwarded to compute_hybrid_mask.
        contract_thresh: forwarded to compute_hybrid_mask.
        rerun_existing: when False, return early if the prj file exists.

    Returns:
        None. Results are only written to disk.

    Raises:
        NotImplementedError: if `base` is not 'MV' (and the early-return for
            an existing result did not apply).
    """
    print("creating and storing hybrid mask for {}, obj{}, clust{}, k={}".format(
        sample_name, objid, clust, k))
    outdir = hybrid_dir(sample_name, objid, k, expand_thresh, contract_thresh)
    clust_num = '-1' if clust == "" else str(clust)
    algo_name = base + '_' + clust_num
    prj_path = '{}/{}_hybrid_prj.json'.format(outdir, algo_name)
    if not rerun_existing and os.path.exists(prj_path):
        print("already ran " + outdir)
        return

    # Reject unsupported bases before doing any of the expensive loads.
    if base != 'MV':
        print('Only supports MV base right now')
        raise NotImplementedError

    agg_vision_mask, _ = get_pixtiles(objid, k)
    # Load the MV mask once: it is both the base of the hybrid and one of the
    # layers of the debug overlay. compute_hybrid_mask works on a copy of its
    # base mask, so sharing the array is safe.
    MV_mask = get_MV_mask(sample_name, objid, cluster_id=clust)
    gt_mask = get_gt_mask(objid)
    base_mask = MV_mask

    hybrid_mask = compute_hybrid_mask(base_mask,
                                      agg_vision_mask,
                                      expand_thresh=expand_thresh,
                                      contract_thresh=contract_thresh,
                                      objid=objid,
                                      DEBUG=False)
    with open('{}/{}_hybrid_mask.pkl'.format(outdir, algo_name), 'w') as fp:
        fp.write(pickle.dumps(hybrid_mask))

    # NOTE(review): DEBUG is not a parameter of this function; it is expected
    # to be a module-level flag defined elsewhere in the file - confirm.
    if DEBUG:
        # Encode the three masks as bits (hybrid=1, MV=2, gt=4) so that every
        # combination gets its own colour; only needed for the plot.
        sum_mask = hybrid_mask.astype(int) * 1 + MV_mask.astype(
            int) * 2 + gt_mask.astype(int) * 4
        plt.figure()
        plt.imshow(sum_mask,
                   interpolation="none",
                   cmap=discrete_cmap(8, 'rainbow'))  # , cmap="rainbow")
        plt.colorbar()
        plt.savefig('{}/{}_hybrid_mask.png'.format(outdir, algo_name))
        plt.close()

    p, r, j = faster_compute_prj(hybrid_mask, gt_mask)
    with open(prj_path, 'w') as fp:
        fp.write(json.dumps([p, r, j]))
Ejemplo n.º 3
0
def process_all_worker_tiles(sample, objid, algo, pixels_in_tile, cluster_id="", mode='', DEBUG=False):
    '''
    Collect per-tile statistics for every tile id in range(len(pixels_in_tile)).

    returns three defaultdict(int) keyed by tile id:
        1) num_votes[tid] = value of the worker mega mask at the first pixel
           of tile tid (taken as the vote count of the whole tile)
        2) tile_area[tid] = number of pixels in tile tid
        3) tile_int_area[tid] = estimated intersection area for tid under algo:
            'ground_truth'    -> sum of int(gt_mask[pix]) over the tile's pixels
            'worker_fraction' -> (votes / number of workers) * tile area
            anything else     -> normalized in-probability * tile area, using
                                 the log probabilities read from disk
                                 (names containing 'real_' read the ground
                                 truth variant)

    mode must be '' for 'ground_truth' / 'worker_fraction' and 'iso' or ''
    otherwise (enforced with assert). DEBUG is accepted but not used.
    '''
    needs_log_probs = algo not in ['ground_truth', 'worker_fraction']
    if needs_log_probs:
        assert mode in ['iso', '']
        read_gt = 'real_' in algo
        log_probability_in, log_probability_not_in = read_tile_log_probabilities(
            sample, objid, cluster_id, algo, mode, read_gt)
    else:
        assert mode == ''

    gt_mask = get_gt_mask(objid)
    worker_mega_mask = get_mega_mask(sample, objid, cluster_id)
    nworkers = num_workers(sample, objid, cluster_id)

    num_votes = defaultdict(int)
    tile_area = defaultdict(int)
    tile_int_area = defaultdict(int)

    for tid in range(len(pixels_in_tile)):
        tile_pixels = list(pixels_in_tile[tid])
        # Only the first pixel is looked up for the vote count.
        votes = worker_mega_mask[tile_pixels[0]]
        area = len(tile_pixels)
        num_votes[tid] = votes
        tile_area[tid] = area

        if algo == 'ground_truth':
            tile_int_area[tid] = sum(int(gt_mask[pix]) for pix in tile_pixels)
            continue
        if algo == 'worker_fraction':
            tile_int_area[tid] = (float(votes) / float(nworkers)) * area
            continue

        # One log probability per tile: all pixels in a tile share it.
        p_in = np.exp(log_probability_in[tid])
        p_not_in = np.exp(log_probability_not_in[tid])
        total = p_in + p_not_in
        # A zero total (both exponentials evaluate to 0) is treated as
        # "certainly inside"; the original author hit this for object 18 in
        # the isoGT case.
        norm_p_in = p_in / total if total != 0 else 1.
        assert 0 <= norm_p_in <= 1
        tile_int_area[tid] = norm_p_in * area

    return num_votes, tile_area, tile_int_area
Ejemplo n.º 4
0
def greedy(sample, objid, algo='worker_fraction', cluster_id="", mode='', output="prj", rerun_existing=False, DEBUG=False):
    """Run the greedy Jaccard tile selection for one object.

    Args:
        sample: sample/batch identifier handed to the loaders.
        objid: object identifier.
        algo: intersection-area estimator, forwarded to
            process_all_worker_tiles.
        cluster_id: cluster id, or "" for none.
        mode: forwarded to process_all_worker_tiles and used as a prefix in
            the metrics file name.
        output: selects the return value:
            "tiles" -> set of selected tile ids
            "mask"  -> mask built from those tiles via tiles_to_mask
            "prj"   -> (p, r, j, fpr, fnr) from all_metrics; also written to
                       '<outdir>/<mode><algo>_greedy_metrics.json'
            Any other value returns None after the selection has run.
        rerun_existing: when False and output == "prj", a previously written
            metrics file is read back instead of recomputing.
        DEBUG: print timing information.

    Returns:
        See `output`.
    """
    print("do greedy {}, obj{},clust{} [{}]".format(sample, objid, cluster_id, mode + algo))
    outdir = tile_and_mask_dir(sample, objid, cluster_id)
    outfile = '{}/{}{}_greedy_metrics.json'.format(outdir, mode, algo)

    # The cached file only holds metrics, so it can only answer "prj"
    # requests. Previously it was returned for "tiles"/"mask" requests too,
    # handing callers a metrics tuple instead of tiles or a mask.
    if output == "prj" and (not rerun_existing) and os.path.exists(outfile):
        print(outfile + " already exist, read from file")
        with open(outfile) as fp:
            p, r, j, fpr, fnr = json.load(fp)
        return p, r, j, fpr, fnr

    start_read_tiles_time = time.time()
    pixels_in_tile = get_all_worker_tiles(sample, objid, cluster_id)
    start_process_time = time.time()
    num_votes, tile_area, tile_int_area = process_all_worker_tiles(
        sample, objid, algo, pixels_in_tile, cluster_id, mode, DEBUG=DEBUG)
    end_process_time = time.time()
    sorted_order_tids, tile_added, est_jacc_list = run_greedy_jaccard(
        tile_area, tile_int_area, DEBUG=False)
    end_greedy_time = time.time()

    if DEBUG:
        print('get_all_worker_tiles time: {}'.format(start_process_time - start_read_tiles_time))
        print('process_all_worker_tiles time: {}'.format(end_process_time - start_process_time))
        print('run_greedy_jaccard time: {}'.format(end_greedy_time - end_process_time))

    gt_est_tiles = set(tid for tid in sorted_order_tids if tile_added[tid])

    if output == "tiles":
        return gt_est_tiles
    if output not in ("mask", "prj"):
        # Unknown output modes have always fallen through to None.
        return None

    # The ground-truth mask is only needed once a mask has to be built.
    gt_mask = get_gt_mask(objid)
    start_metrics_time = time.time()
    gt_est_mask = tiles_to_mask(gt_est_tiles, pixels_in_tile, gt_mask)
    if output == "mask":
        return gt_est_mask

    p, r, j, fpr, fnr = all_metrics(gt_est_mask, gt_mask)
    end_metrics_time = time.time()
    if DEBUG:
        print('tiles_to_mask and all_metrics time: {}'.format(end_metrics_time - start_metrics_time))
    with open(outfile, 'w') as fp:
        fp.write(json.dumps([p, r, j, fpr, fnr]))
    if j <= 0.5:  # in the case where we are aggregating a semantic error cluster
        # NOTE(review): unlike `outfile`, this path has no '/' after outdir
        # and no `mode` prefix - confirm against whatever reads this file
        # before changing it.
        with open('{}{}_gt_est_mask_greedy.pkl'.format(outdir, algo), 'w') as fp:
            pkl.dump(gt_est_mask, fp)
    return p, r, j, fpr, fnr
Ejemplo n.º 5
0
def testing_viz_test_gt_vision_overlay():
    # try out different result masks and overlay against vision and gt
    # test potential for improvement
    for batch in ['5workers_rand0']:
        for k in [500]:
            for objid in range(1, 2):
                print 'Visualizing stuff for obj ', objid
                outdir = VISION_DIR + 'visualizing_and_testing_stuff/obj{}/{}'.format(
                    objid, k)
                if not os.path.isdir(outdir):
                    os.makedirs(outdir)
                # all_worker_tiles = pickle.load(open('{}/obj{}/tiles.pkl'.format(batch_path, objid)))
                vision_mask, vision_tiles = get_pixtiles(objid, k)
                gt_mask = get_gt_mask(objid)
                worker_mega_mask = get_mega_mask(batch, objid)
                mv_mask = get_MV_mask(batch, objid)

                # worker_tiles_mask, num_votes, tile_area, tile_int_area = process_all_worker_tiles(
                #     all_worker_tiles, worker_mega_mask, gt_mask)

                print 'Computing simple hybrid mask...'
                print 'MV p, r, j: ', faster_compute_prj(mv_mask, gt_mask)
                all_voted_mask = np.zeros_like(worker_mega_mask)
                num_workers = int(batch.split('workers')[0])
                all_voted_mask[np.where(worker_mega_mask == num_workers)] = 1
                all_mv_but_low_confidence = mv_mask - all_voted_mask
                hybrid_mask = all_voted_mask + compute_hybrid_mask(
                    all_mv_but_low_confidence,
                    vision_mask,
                    expand_thresh=0.8,
                    contract_thresh=0.2,
                    objid=objid,
                    DEBUG=False)
                hybrid_mask[np.where(hybrid_mask > 1)] = 1
                # show_mask(hybrid_mask, figname='{}/hybrid_mask.png'.format(outdir))
                print 'Hybrid p, r, j: ', faster_compute_prj(
                    hybrid_mask, gt_mask)

                visualize_test_gt_vision_overlay(batch, objid, k, hybrid_mask,
                                                 outdir)
                print '----------------------------------------------------------------------------------'
Ejemplo n.º 6
0
def binarySearchDeriveBestThresh(sample_name,
                                 objid,
                                 cluster_id,
                                 log_probability_in,
                                 log_probability_not_in,
                                 tiles_to_search_against,
                                 exclude_isovote=False,
                                 rerun_existing=False,
                                 DEBUG=False):
    """Bisect the decision threshold until precision equals recall.

    Each step calls estimate_gt_compute_PRJ_against_MV with the current
    threshold and narrows the interval [thresh_min, thresh_max] (initially
    [-200, 200]) depending on how p compares with r.

    Args:
        sample_name, objid, cluster_id: forwarded to the helper functions.
        log_probability_in, log_probability_not_in: forwarded unchanged to
            estimate_gt_compute_PRJ_against_MV.
        tiles_to_search_against: reference tiles that p and r are computed
            against.
        exclude_isovote: forwarded to estimate_gt_compute_PRJ_against_MV.
        rerun_existing: accepted but not used in this function.
        DEBUG: only loads data for the plotting code, which is currently
            disabled with `if False:`.

    Returns:
        (p, r, j, thresh, gt_est_tiles). p, r, j and gt_est_tiles come from
        the last evaluation. NOTE(review): thresh is the midpoint computed
        *after* that evaluation, i.e. the next candidate, not the threshold
        that produced the returned values - confirm this is intended.
    """
    # binary search for p == r point
    # p and r computed against tiles_to_search_against
    # if arbitrary reference (for eg., gt) that does not respect tile boundaries, need to change to mask
    thresh_min = -200.0
    thresh_max = 200.0
    thresh = (thresh_min + thresh_max) / 2.0
    # Start with p != r so that the first loop iteration always evaluates.
    p, r = 0, -1
    iterations = 0
    # Stop once the search interval is narrower than this.
    epsilon = 0.125
    # Only referenced by the disabled plotting block at the end.
    outdir = tile_em_output_dir(sample_name, objid, cluster_id)

    if DEBUG:
        # used to compute actual prj
        track = []  # for plotting
        gt_mask = get_gt_mask(objid)
        tiles = get_all_worker_tiles(sample_name, objid, cluster_id)

    while (iterations <= 100
           or p == -1):  # continue iterations below max iterations or if p=-1
        if (p == r) or (thresh_min + epsilon >= thresh_max):
            # stop if p==r or if epsilon (range in x) gets below a certain threshold
            break
        [p, r, j], gt_est_tiles = estimate_gt_compute_PRJ_against_MV(
            sample_name,
            objid,
            cluster_id,
            log_probability_in,
            log_probability_not_in,
            tiles_to_search_against,
            thresh,
            exclude_isovote=exclude_isovote)
        # p > r (or a non-positive p) moves the upper bound down to thresh;
        # otherwise the lower bound moves up.
        if p > r or p <= 0:
            thresh_max = thresh
        else:
            thresh_min = thresh

        # Disabled per-iteration diagnostics; needs the DEBUG-only variables
        # (track, tiles, gt_mask) if it is ever re-enabled.
        if False:
            print "----Trying threshold:", thresh, "-----"
            print p, r, j, thresh_max, thresh_min
            gt_est_mask = tiles_to_mask(gt_est_tiles, tiles, gt_mask)
            gt_p, gt_r, gt_j = faster_compute_prj(gt_est_mask, gt_mask)
            print "actual prj against GT", gt_p, gt_r, gt_j
            track.append([thresh, p, r, j, gt_p, gt_r, gt_j])
            # plt.figure()
            # plt.title("Iter #"+str(iterations))
            # plt.imshow(gt_est_mask)
            # plt.colorbar()

        # Next candidate: midpoint of the updated interval.
        thresh = (thresh_min + thresh_max) / 2.0
        iterations += 1

    # Disabled crossover plot of the tracked p/r/j values per threshold.
    if False:
        # TODO: currently overwritten by different algos and iterations
        track = np.array(track)
        plt.figure()
        plt.title('prj_crossover')
        idx = np.argsort(track[:, 0])
        ths = track[:, 0][idx]
        ps = track[:, 1][idx]
        rs = track[:, 2][idx]
        js = track[:, 3][idx]
        act_ps = track[:, 4][idx]
        act_rs = track[:, 5][idx]
        act_js = track[:, 6][idx]
        plt.plot(ths, ps, 'o', color='blue', label="p")
        plt.plot(ths, rs, 'x', color='orange', label="r")
        plt.plot(ths, js, '-', color='green', label="j")
        plt.plot(ths, act_ps, '--', color='blue', label="act_p")
        plt.plot(ths, act_rs, '--', color='orange', label="act_r")
        plt.plot(ths, act_js, '--', color='green', label="act_j")
        plt.legend(loc='lower right')
        plt.plot(track[-1][0], track[-1][1], '^', color='red')
        # print track
        # print track[-1][0], track[-1][1]
        plt.savefig('{}prj_crossover.png'.format(outdir))
        plt.close()

    return p, r, j, thresh, gt_est_tiles
Ejemplo n.º 7
0
def _dump_pickle(obj, path):
    """Pickle `obj` to `path` and close the file handle before returning."""
    # Text mode 'w' is what this module has always used for its pickles; it
    # relies on the default ASCII pickle protocol of Python 2. Switch to 'wb'
    # if a binary protocol is ever requested.
    with open(path, 'w') as fp:
        pickle.dump(obj, fp)


def do_EM_for(sample_name,
              objid,
              cluster_id="",
              algo='GTLSA',
              rerun_existing=False,
              exclude_isovote=False,
              dump_output_at_every_iter=False,
              compute_PR_every_iter=False,
              PLOT=False,
              DEBUG=False):
    """Run the tile-based EM loop for one object and store the results.

    Starting from the majority-vote tiles, each iteration
      1. estimates per-worker quality parameters for the chosen `algo`,
      2. turns them into per-tile log probabilities, and
      3. re-derives the estimated ground-truth tiles with
         binarySearchDeriveBestThresh (searched against the MV tiles).
    The loop runs for at most 6 iterations and, after at least two of them,
    stops once the Jaccard between consecutive tile estimates reaches 0.999.

    Files written (all prefixed with '<outdir><mode><algo>', where mode is
    'iso' when exclude_isovote is set and '' otherwise):
      _EM_prj_best_thresh.json, _gt_est_tiles_best_thresh.pkl,
      _p_in_tiles_best_thresh.pkl, _p_not_in_tiles_best_thresh.pkl, one
      _<param>_best_thresh.pkl per worker parameter dict, optionally
      _EM_prj_iter<it>_thresh<thresh>.json per iteration and
      _EM_mask_thresh<thresh>.png.

    Args:
        sample_name, objid, cluster_id: identify the data to load.
        algo: 'basic', 'GT' or 'GTLSA'.
        rerun_existing: when False, skip if the final prj file exists.
        exclude_isovote: forwarded to the worker-quality and threshold
            helpers (see the note on 'GT' below).
        dump_output_at_every_iter: accepted but not used in this function.
        compute_PR_every_iter: also write prj against the ground truth after
            every iteration.
        PLOT: save an image of the final estimated mask.
        DEBUG: print progress and timing information.

    Returns:
        Elapsed wall-clock time in seconds, or None when the run was skipped
        because results already exist.

    Raises:
        ValueError: if `algo` is not one of the supported names.
    """
    # Fail fast: an unknown algo used to surface only later, as a NameError
    # on variables that no branch had assigned.
    if algo not in ('basic', 'GT', 'GTLSA'):
        raise ValueError("Unsupported algo: {!r}".format(algo))

    mode = 'iso' if exclude_isovote else ''

    start = time.time()
    outdir = tile_em_output_dir(sample_name, objid, cluster_id)
    prefix = '{}{}{}'.format(outdir, mode, algo)

    if DEBUG:
        print("Doing {} mode = {}".format(algo, mode))
    if not rerun_existing and os.path.isfile(prefix + '_EM_prj_best_thresh.json'):
        print("Already ran {}, skipped".format(algo))
        return

    tarea = get_tile_to_area_map(sample_name, objid, cluster_id)
    tworkers = get_tile_to_workers_map(sample_name, objid, cluster_id)
    wtiles = get_worker_to_tiles_map(sample_name, objid, cluster_id)
    Nworkers = num_workers(sample_name, objid, cluster_id)

    # only used to convert tiles to pixels when prj needs to be computed
    tiles = get_all_worker_tiles(sample_name, objid, cluster_id)
    gt_mask = get_gt_mask(objid)

    # initialize the estimate T* with the majority-vote tiles
    MV_tiles = get_MV_tiles(sample_name, objid, cluster_id)
    gt_est_tiles = MV_tiles.copy()
    prev_gt_est = gt_est_tiles.copy()
    jaccard_against_prev_gt_est = 0
    it = 0
    max_iter = 6
    # Always run at least two iterations, never more than max_iter.
    while it < max_iter and (jaccard_against_prev_gt_est < 0.999 or it <= 1):
        if DEBUG:
            print("iteration: {}".format(it))

        it += 1

        if algo == 'GTLSA':
            area_thresh_gt, area_thresh_ngt = compute_area_thresh(
                gt_est_tiles, tarea)

        if DEBUG:
            t0 = time.time()

        # Worker-quality estimation. worker_params keeps the (name, dict)
        # pairs of the current algo so that debugging output and the final
        # dump do not need to know which algo ran.
        if algo == 'basic':
            q = {}
            for wid in wtiles:
                q[wid] = basic_worker_prob_correct(
                    gt_est_tiles,
                    wtiles[wid],
                    tarea,
                    tworkers,
                    Nworkers,
                    exclude_isovote=exclude_isovote)
            worker_params = [('q', q)]
        elif algo == 'GT':
            qp, qn = {}, {}
            for wid in wtiles:
                # NOTE(review): exclude_isovote is hard-coded to False here,
                # unlike the 'basic' and 'GTLSA' branches - confirm whether
                # that is intentional.
                qp[wid], qn[wid] = GT_worker_prob_correct(
                    gt_est_tiles,
                    wtiles[wid],
                    tarea,
                    tworkers,
                    Nworkers,
                    exclude_isovote=False)
            worker_params = [('qp', qp), ('qn', qn)]
        else:  # 'GTLSA'
            qp1, qn1, qp2, qn2 = {}, {}, {}, {}
            for wid in wtiles:
                qp1[wid], qn1[wid], qp2[wid], qn2[
                    wid] = GTLSA_worker_prob_correct(
                        gt_est_tiles,
                        wtiles[wid],
                        tarea,
                        tworkers,
                        Nworkers,
                        area_thresh_gt,
                        area_thresh_ngt,
                        exclude_isovote=exclude_isovote)
            worker_params = [('qp1', qp1), ('qn1', qn1), ('qp2', qp2),
                             ('qn2', qn2)]

        if DEBUG:
            t1 = time.time()
            print("Time for worker prob calculation: {}".format(t1 - t0))

        # Per-tile log probabilities of being inside / outside the object.
        if algo == 'basic':
            log_probability_in, log_probability_not_in = basic_log_probabilities(
                wtiles, q, tarea)
        elif algo == 'GT':
            log_probability_in, log_probability_not_in = GT_log_probabilities(
                wtiles, qp, qn, tarea)
        else:  # 'GTLSA'
            log_probability_in, log_probability_not_in = GTLSA_log_probabilities(
                wtiles, qp1, qn1, qp2, qn2, tarea, area_thresh_gt,
                area_thresh_ngt)

        if DEBUG:
            t2 = time.time()
            print("Time for mask log prob calculation: {}".format(t2 - t1))

        p, r, j, thresh, gt_est_tiles = binarySearchDeriveBestThresh(
            sample_name,
            objid,
            cluster_id,
            log_probability_in,
            log_probability_not_in,
            MV_tiles,
            exclude_isovote=exclude_isovote,
            rerun_existing=rerun_existing,
            DEBUG=DEBUG)

        # Compute PR mask based on the EM estimate mask from every iteration
        if compute_PR_every_iter:
            gt_est_mask = tiles_to_mask(gt_est_tiles, tiles, gt_mask)
            [p, r, j] = faster_compute_prj(gt_est_mask, gt_mask)
            iter_path = '{}_EM_prj_iter{}_thresh{}.json'.format(
                prefix, it, thresh)
            with open(iter_path, 'w') as fp:
                fp.write(json.dumps([p, r, j]))

            if DEBUG:
                # Print whichever parameter dicts the current algo produced;
                # referencing qp1..qn2 directly raised NameError for 'basic'
                # and 'GT'.
                print(' '.join(str(values) for _, values in worker_params))
                print("-->" + str([p, r, j]))

        # compute jaccard between previous and current gt estimation
        [p_against_prev, r_against_prev, jaccard_against_prev_gt_est
         ] = prj_tile_against_tile(gt_est_tiles, prev_gt_est, tarea)
        if DEBUG:
            print("jaccard_against_prev_gt_est: {}".format(
                jaccard_against_prev_gt_est))
        prev_gt_est = gt_est_tiles.copy()

    gt_est_mask = tiles_to_mask(gt_est_tiles, tiles, gt_mask)
    [p, r, j] = faster_compute_prj(gt_est_mask, gt_mask)
    if DEBUG:
        print('Final prj: {}'.format([p, r, j]))

    with open(prefix + '_EM_prj_best_thresh.json', 'w') as fp:
        fp.write(json.dumps([p, r, j]))

    _dump_pickle(gt_est_tiles, prefix + '_gt_est_tiles_best_thresh.pkl')
    _dump_pickle(log_probability_in, prefix + '_p_in_tiles_best_thresh.pkl')
    _dump_pickle(log_probability_not_in,
                 prefix + '_p_not_in_tiles_best_thresh.pkl')
    for param_name, values in worker_params:
        _dump_pickle(values,
                     '{}_{}_best_thresh.pkl'.format(prefix, param_name))

    if PLOT:
        plt.figure()
        plt.imshow(gt_est_mask, interpolation="none")  # ,cmap="rainbow")
        plt.colorbar()
        plt.savefig('{}_EM_mask_thresh{}.png'.format(prefix, thresh))
        # Release the figure; it used to stay open after saving.
        plt.close()

    end = time.time()
    if DEBUG:
        print("Time:{}".format(end - start))

    return end - start
Ejemplo n.º 8
0
def compute_hybrid_mask(base_mask,
                        agg_vision_mask,
                        expand_thresh=0.8,
                        contract_thresh=0,
                        objid=None,
                        vision_only=False,
                        DEBUG=False):
    """Snap a base mask to the boundaries of vision tiles.

    For every vision tile (non-zero id in `agg_vision_mask`) that overlaps
    the base mask, the fraction of the tile covered by the base mask decides
    what happens to the tile's pixels in the result:
      * fraction > expand_thresh   -> the whole tile is switched on,
      * fraction < contract_thresh -> the whole tile is switched off,
      * otherwise                  -> the tile's pixels are left as they are.
    Tiles without any overlap and pixels with vision id 0 are never touched.

    Args:
        base_mask: 2-D array-like; non-zero/True marks the base region.
        agg_vision_mask: array-like of the same shape holding a vision tile
            id per pixel, with 0 meaning "no tile".
        expand_thresh: coverage above which a tile is fully included.
        contract_thresh: coverage below which a tile is fully removed.
        objid: only used when DEBUG is set, to load the ground-truth mask
            for the diagnostic plot.
        vision_only: start from an empty mask instead of a copy of
            `base_mask`, so the result contains whole vision tiles only.
        DEBUG: print per-tile statistics and show a plot for every
            overlapping tile (requires `objid`).

    Returns:
        A new numpy array shaped like `base_mask`; the inputs are not
        modified.
    """
    # objid only for debugging purposes
    vision = np.asarray(agg_vision_mask)
    base = np.asarray(base_mask)

    if vision_only:
        # only include or delete full vision tiles
        # only using base mask to decide which to include / delete
        final_mask = np.zeros_like(base)
    else:
        final_mask = np.copy(base)

    # print(...) with a single argument behaves the same under the Python 2
    # print statement and the Python 3 print function.
    if DEBUG:
        print('Num unique vision tiles:  {}'.format(len(np.unique(vision))))

    # Tally tile areas and their overlap with the base mask in one vectorised
    # pass instead of a Python loop over every pixel.
    in_tile = vision != 0
    tile_ids, inverse = np.unique(vision[in_tile], return_inverse=True)
    if tile_ids.size == 0:
        return final_mask
    covered = base[in_tile].astype(bool).astype(float)
    tile_areas = np.bincount(inverse, minlength=tile_ids.size)
    overlaps = np.bincount(inverse, weights=covered, minlength=tile_ids.size)
    # Base-mask pixels that fall inside some vision tile (debug output only).
    base_mask_area = float(overlaps.sum())

    for vtile_id, area, overlap in zip(tile_ids, tile_areas, overlaps):
        # for each vtile, decide to either fill out the leftovers of it,
        # delete its intersection with base_mask, or leave it unchanged
        if overlap == 0:
            continue

        frac_vtile_covered = float(overlap) / float(area)

        if DEBUG:
            print('-----------------')
            print('Intersection area:  {}'.format(float(overlap)))
            print('vtile area:  {}'.format(float(area)))
            print('Base mask area:  {}'.format(base_mask_area))
            print('Frac vtile covered:  {}'.format(frac_vtile_covered))

        if frac_vtile_covered > expand_thresh:
            # expand mask to include entire vision tile
            final_mask[vision == vtile_id] = True
            if DEBUG:
                print('Expanding')
        elif frac_vtile_covered < contract_thresh:
            # delete mask to exclude entire vision tile
            final_mask[vision == vtile_id] = False
            if DEBUG:
                print('Deleting')
        elif DEBUG:
            print('Passing')

        if DEBUG:
            # Overlay of the current tile (20), the base mask (50) and the
            # ground truth (100).
            v_mask = np.copy(vision)
            v_mask[v_mask != vtile_id] = 0
            v_mask[v_mask == vtile_id] = 20
            plot_base_mask = np.copy(base).astype(int) * 50
            plot_gt_mask = np.copy(get_gt_mask(objid)).astype(int) * 100
            plot_sum_mask = v_mask + plot_base_mask + plot_gt_mask
            plt.figure()
            plt.imshow(plot_sum_mask, interpolation="none")  # , cmap="hot")
            plt.colorbar()
            plt.show()
            plt.close()
    return final_mask