Example #1
0
def get_density_image_old(im_shape, x, y):
    """Render a Gaussian-KDE density map of the points (x, y) and return it as an image array.

    The code treats ``im_shape[0]`` as the image height (y extent) and
    ``im_shape[1]`` as the image width (x extent): the figure is sized
    ``im_shape[1] x im_shape[0]`` pixels and the axes limits use the same order.
    NOTE(review): an earlier docstring described the order as
    "[# pixels left-to-right, # pixels top-to-bottom]", which is the opposite of
    what the code does — confirm against callers.

    Args:
        im_shape: image dimensions indexed as described above.
        x, y: point coordinates; stacked column-wise and converted with
            ``index_img_to_plt`` (defined elsewhere) before fitting the KDE.

    Returns:
        The rendered figure after ``matplotlib_to_numpy`` and
        ``invert_np_image`` (both defined elsewhere).
    """
    import copy

    my_dpi = 100
    # figsize is (width, height) in inches; at my_dpi this is im_shape[1] x im_shape[0] pixels.
    my_figsize = (im_shape[1] / my_dpi, im_shape[0] / my_dpi)
    centers = np.c_[x, y]

    def nice_cov(*args):
        # Fixed bandwidth factor, independent of the data.
        # NOTE(review): gaussian_kde scales its data covariance by factor**2; a
        # length-2 array scales each covariance column differently — confirm intended.
        return .5 * 10e-3 * np.array([1, 3])

    # Work on a copy: calling set_bad on plt.cm.Purples itself would mutate the
    # colormap object shared by everything else in the process.
    cmap = copy.copy(plt.cm.Purples)
    cmap.set_bad(color="white")

    # Fit the KDE and evaluate it on a grid sampled every gap_size pixels.
    centers = index_img_to_plt(centers, im_shape[0])
    kde = stats.gaussian_kde(centers.T)
    kde.set_bandwidth(nice_cov)
    gap_size = 2
    xx, yy = np.mgrid[0:im_shape[1]:gap_size, 0:im_shape[0]:gap_size]
    z = kde(np.c_[xx.flat, yy.flat].T).reshape(xx.shape)
    z = np_log(z)
    # Mask grid cells whose value is above -10e-4 (= -1e-3).
    z = np.ma.masked_where(z > -10e-4, z)

    # Full-bleed, axis-free figure so the rendered pixels map onto the image grid.
    fig = plt.figure(frameon=False, facecolor='white', figsize=my_figsize)
    try:
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        fig.add_axes(ax)
        ax.set_axis_off()
        ax.set_xlim([0, im_shape[1]])
        ax.set_ylim([0, im_shape[0]])

        ax.contourf(xx, yy, z, 25, cmap=cmap)
        img = matplotlib_to_numpy(fig, im_shape[::-1], dpi=my_dpi)
        img = invert_np_image(img)
    finally:
        # Always release the figure, even if rendering/conversion fails.
        plt.close(fig)
    return img
Example #2
0
def driftmap_color(clusters_depths,
                   spikes_times,
                   spikes_amps,
                   spikes_depths,
                   spikes_clusters,
                   ax=None,
                   axesoff=False,
                   return_lims=False):
    '''
    Plots the driftmap of a session or a trial

    The plot shows the spike times vs spike depths.
    Each dot is a spike, whose color indicates the cluster
    and opacity indicates the spike amplitude.

    Parameters
    -------------
    clusters_depths: ndarray
        depths of all clusters, in the order of the sorted unique ids in
        `spikes_clusters` (the k-th entry belongs to the k-th unique cluster id)
    spikes_times: ndarray
        spike times of all clusters
    spikes_amps: ndarray
        amplitude of each spike
    spikes_depths: ndarray
        depth of each spike
    spikes_clusters: ndarray
        cluster idx of each spike (spikes need not be sorted by cluster)
    ax: matplotlib.axes.Axes object (optional)
        The axis object to plot the driftmap on
        (if `None`, a new figure and axis is created and the figure is
        saved to 'driftmap.png' in the current directory)
    axesoff: bool
        if True, hide the axes
    return_lims: bool
        if True, also return the x and y limits

    Return
    ---
    ax: matplotlib.axes.Axes object
        The axis object with driftmap plotted
    x_lim: list of two elements
        range of x axis (only when `return_lims` is True)
    y_lim: list of two elements
        range of y axis (only when `return_lims` is True)
    '''

    color_bins = sns.color_palette("hls", 500)
    # Reorder the 500 colours (5 x 100 -> 100 x 5) so that consecutive ranks
    # take colours 100 bins apart (0, 100, 200, 300, 400, 1, 101, ...).
    new_color_bins = np.vstack(
        np.transpose(np.reshape(color_bins, [5, 100, 3]), [1, 0, 2]))

    # Depth rank of each cluster (rank 0 = smallest depth value) selects its colour.
    sorted_idx = np.argsort(np.argsort(clusters_depths))
    cluster_colors = new_color_bins[np.mod(sorted_idx, 500), :]

    # Give every spike its own cluster's colour via the position of its cluster
    # id among the unique ids; this does not require spikes sorted by cluster.
    _, spike_cluster_pos = np.unique(spikes_clusters, return_inverse=True)
    colors = cluster_colors[np.ravel(spike_cluster_pos)]

    # Opacity: amplitude rescaled between its 10th and 90th percentiles, clipped to [0, 1].
    spikes_amps = np.asarray(spikes_amps)
    max_amp = np.percentile(spikes_amps, 90)
    min_amp = np.percentile(spikes_amps, 10)
    amp_range = max_amp - min_amp
    if amp_range > 0:
        opacity = np.clip((spikes_amps - min_amp) / amp_range, 0, 1)
    else:
        # No amplitude spread: avoid 0/0 (NaN alpha) and draw every spike opaque.
        opacity = np.ones(len(spikes_amps))

    colorvec = np.zeros([len(opacity), 4], dtype='float16')
    colorvec[:, 3] = opacity.astype('float16')
    colorvec[:, 0:3] = colors.astype('float16')

    x = spikes_times.astype('float32')
    y = spikes_depths.astype('float32')

    args = dict(color=colorvec, edgecolors='none')

    # Only a figure created here is saved; a caller-supplied axis is left alone.
    savefig = False
    fig = None
    if ax is None:
        fig = plt.Figure(dpi=200, frameon=False, figsize=[10, 10])
        ax = plt.Axes(fig, [0.1, 0.1, 0.9, 0.9])
        ax.set_xlabel('Time (sec)')
        ax.set_ylabel('Distance from the probe tip (um)')
        savefig = True
        args.update(s=0.1)

    ax.scatter(x, y, **args)
    x_min, x_max = np.min(x), np.max(x)
    x_edge = (x_max - x_min) * 0.05
    x_lim = [x_min - x_edge, x_max + x_edge]
    y_lim = [np.min(y) - 50, np.max(y) + 100]
    ax.set_xlim(x_lim[0], x_lim[1])
    ax.set_ylim(y_lim[0], y_lim[1])

    if axesoff:
        ax.axis('off')

    if savefig:
        fig.add_axes(ax)
        fig.savefig('driftmap.png')

    if return_lims:
        return ax, x_lim, y_lim
    return ax
Example #3
0
    def get_zoom_iamges(image_path, save_dir, is_initial=False, is_last=False, use_first_zoom_factor=True):
        """Render zoomed views of ``image_path`` with title/legend overlays and save them as PNGs.

        The zoom is computed by the ``zoom_op`` executor bound on ``ctx`` (both
        from the enclosing scope); ``pixel_means`` is added back to its output.
        Output goes to ``save_dir/<name of image_path's parent directory>/``.

        Args:
            image_path: path to a ``.png`` image; ``read_info`` is called on its
                ``*_info.txt`` sibling for the title and legend text.
            save_dir: root output directory.
            is_initial: if True, save the unzoomed image as ``*_0.png``, two
                intermediate zoom steps (1/3 and 2/3 of the way) as ``*_1.png`` and
                ``*_2.png``, and the full zoom as ``*_3.png``.
            is_last: if True (and not ``is_initial``), save the full zoom under
                each of ``*_0.png``, ``*_1.png`` and ``*_2.png``.
            use_first_zoom_factor: if True, take the zoom factor from the info
                file obtained by replacing the part of the file name starting at
                "iter" with ``iter_00_info.txt``; otherwise from this image's own
                info file.

        When neither flag is set, the full zoom is saved under the image's own
        file name.
        NOTE(review): if "iter" is absent from the file name, ``find`` returns -1
        and the replacement targets the last character instead — confirm inputs.
        """
        legend_loc = 50
        t = 1
        cmap = {1: [1.0, 0.0, 0.0, t], 2: [0.75, 0.75, 0.75, t], 3: [0.0, 1.0, 0.0, t]}
        labels = {1: "Initial", 2: "GT", 3: "Refined"}
        patches = [mpatches.Patch(color=cmap[i], label=labels[i]) for i in range(1, 4)]

        file_name = os.path.basename(image_path)
        info_path = image_path.replace(".png", "_info.txt")
        if use_first_zoom_factor:
            init_info_path = image_path.replace(file_name[file_name.find("iter") :], "iter_00_info.txt")
            _, _, zoom_factor = read_info(init_info_path)
            title, legend, _ = read_info(info_path)
        else:
            title, legend, zoom_factor = read_info(info_path)

        zoom_factor = zoom_factor[None, :]
        # cv2 loads HxWxC (BGR); the executor expects 1xCxHxW.
        image_real = cv2.imread(image_path, cv2.IMREAD_COLOR).transpose([2, 0, 1])[None, :, :, :]
        image_rendered = image_real.copy()
        exe1 = zoom_op.simple_bind(
            ctx=ctx, zoom_factor=zoom_factor.shape, image_real=image_real.shape, image_rendered=image_rendered.shape
        )

        save_d = os.path.join(save_dir, os.path.dirname(image_path).split("/")[-1])

        def simple_forward(exe1, zoom_factor, image_real, image_rendered, ctx=ctx, is_train=False):
            print("zoom factor: ", zoom_factor)
            exe1.arg_dict["zoom_factor"][:] = mx.nd.array(zoom_factor, ctx=ctx)
            exe1.arg_dict["image_real"][:] = mx.nd.array(image_real, ctx=ctx)
            exe1.arg_dict["image_rendered"][:] = mx.nd.array(image_rendered, ctx=ctx)
            exe1.forward(is_train=is_train)

        def zoomed_image(factor):
            """Run the zoom executor with ``factor``; return an HxWx3 uint8 image in RGB order."""
            simple_forward(exe1, factor, image_real, image_rendered, ctx=ctx, is_train=True)
            zoom = exe1.outputs[0].asnumpy()[0].transpose((1, 2, 0)) + pixel_means
            # Clamp to the 8-bit range before casting; [2, 1, 0] turns BGR into RGB.
            return np.clip(zoom, 0, 255).astype("uint8")[:, :, [2, 1, 0]]

        def save_figure(image_rgb, out_names):
            """Draw ``image_rgb`` with title, legend text and patch legend; save it under each name."""
            fig = plt.figure(frameon=False, figsize=(8, 6), dpi=100)
            ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
            ax.set_axis_off()
            fig.add_axes(ax)
            ax.imshow(image_rgb)
            ax.text(10, 25, title, color="green", bbox=dict(facecolor="white", alpha=0.8))
            ax.text(10, legend_loc, legend, color="red", bbox=dict(facecolor="white", alpha=0.8))
            ax.legend(handles=patches, loc=4, borderaxespad=0.0)
            mkdir_if_missing(save_d)
            for out_name in out_names:
                # ``aspect`` was dropped: it is not a savefig parameter and newer
                # matplotlib rejects unknown savefig keyword arguments.
                fig.savefig(os.path.join(save_d, out_name))
            plt.close(fig)

        if is_initial:
            # Unzoomed original (BGR -> RGB).
            save_figure(image_real[0].transpose((1, 2, 0))[:, :, [2, 1, 0]], [file_name.replace(".png", "_0.png")])

            # Intermediate steps 1/3 and 2/3 of the way towards the full zoom.
            wx, _, tx, ty = zoom_factor[0]
            delta = (1 - wx) / 3
            for step in (1, 2):
                step_factor = np.zeros((1, 4))
                step_factor[0, 0] = 1 - step * delta
                step_factor[0, 1] = 1 - step * delta
                step_factor[0, 2] = tx / 3 * step
                step_factor[0, 3] = ty / 3 * step
                save_figure(zoomed_image(step_factor), [file_name.replace(".png", "_%d.png" % step)])

        # Full zoom.
        final_image = zoomed_image(zoom_factor)
        if is_initial:
            out_names = [file_name.replace(".png", "_3.png")]
        elif is_last:
            out_names = [file_name.replace(".png", suffix) for suffix in ("_0.png", "_1.png", "_2.png")]
        else:
            out_names = [file_name]
        save_figure(final_image, out_names)
Example #4
0
def _plot_pseudo_labels(batch_dict, ib):
    """Save an image crop and a local laser-point plot for every pseudo label of sample ``ib``.

    For the k-th pseudo label, two PDFs are written to
    ``_SAVE_DIR/samples/<sequence>/<true|false>/``:

    * ``<frame_id>_<k>_im.pdf`` is the stitched-image crop inside the label's box,
      overlaid with the laser points that project into it.
    * ``<frame_id>_<k>_pt.pdf`` shows the laser points within 0.5 of the label
      position, rotated into a local frame.

    ``true``/``false`` records whether the label lies within distance 0.3 of any
    valid ground-truth detection. Returns early when there are no pseudo labels.
    Image crops of zero width or height are skipped; their point plot is still
    written.
    """
    # pseudo labels
    pl_xy = batch_dict["pseudo_label_loc_xy"][ib]
    pl_boxes = batch_dict["pseudo_label_boxes"][ib]

    if len(pl_xy) == 0:
        return

    # groundtruth
    anns_rphi = np.array(batch_dict["dets_wp"][ib],
                         dtype=np.float32)[batch_dict["anns_valid_mask"][ib]]

    # match pseudo labels with groundtruth
    if len(anns_rphi) > 0:
        gts_x, gts_y = u.rphi_to_xy(anns_rphi[:, 0], anns_rphi[:, 1])

        # Pairwise distances, shape (n_pseudo_labels, n_groundtruth).
        x_diff = pl_xy[:, 0].reshape(-1, 1) - gts_x.reshape(1, -1)
        y_diff = pl_xy[:, 1].reshape(-1, 1) - gts_y.reshape(1, -1)
        d_diff = np.sqrt(x_diff * x_diff + y_diff * y_diff)
        match_found = (d_diff < 0.3).max(axis=1)
    else:
        # Builtin bool: the np.bool alias was removed in NumPy 1.24.
        match_found = np.zeros(len(pl_xy), dtype=bool)

    # overlay image with laser
    im = batch_dict["im_data"][ib]["stitched_image0"]
    scan_r = batch_dict["scans"][ib][-1]
    scan_phi = batch_dict["scan_phi"][ib]
    scan_x, scan_y = u.rphi_to_xy(scan_r, scan_phi)
    scan_z = batch_dict["laser_z"][ib]
    scan_xyz_laser = np.stack((scan_x, -scan_y, scan_z),
                              axis=0)  # in JRDB laser frame
    p_xy, ib_mask = jt.transform_pts_laser_to_stitched_im(scan_xyz_laser)
    p_xy = p_xy[:, ib_mask]
    c_bgr = _distance_to_bgr_color(scan_r[ib_mask])

    # plot
    frame_id = f"{batch_dict['frame_id'][ib]:06d}"
    sequence = batch_dict["sequence"][ib]
    fig_w_inch = 0.314961 * 2.0
    fig_h_inch = 0.708661 * 2.0
    plot_range = 0.5

    for count, (xy, box, is_pos) in enumerate(zip(pl_xy, pl_boxes,
                                                  match_found)):
        pos_neg_dir = "true" if is_pos else "false"
        im_file = os.path.join(
            _SAVE_DIR,
            f"samples/{sequence}/{pos_neg_dir}/{frame_id}_{count}_im.pdf")
        pt_file = os.path.join(
            _SAVE_DIR,
            f"samples/{sequence}/{pos_neg_dir}/{frame_id}_{count}_pt.pdf")
        # Create the directory up front: both plots below are saved into it.
        os.makedirs(os.path.dirname(im_file), exist_ok=True)

        # image
        x0, y0, x1, y1 = (int(v) for v in box)
        height = y1 - y0
        width = x1 - x0

        # A degenerate box would divide by zero in set_aspect and give dpi=0.
        if width > 0 and height > 0:
            fig_im = plt.figure()
            fig_im.set_size_inches(fig_w_inch, fig_h_inch, forward=False)
            ax_im = plt.Axes(fig_im, [0.0, 0.0, 1.0, 1.0])
            ax_im.imshow(im[y0:y1 + 1, x0:x1 + 1])
            ax_im.set_axis_off()
            ax_im.axis([0, width, height, 0])
            ax_im.set_aspect((fig_h_inch / fig_w_inch) / (height / width))
            fig_im.add_axes(ax_im)

            in_box_mask = np.logical_and(
                np.logical_and(p_xy[0] >= x0, p_xy[0] <= x1),
                np.logical_and(p_xy[1] >= y0, p_xy[1] <= y1),
            )
            ax_im.scatter(
                p_xy[0, in_box_mask] - x0,
                p_xy[1, in_box_mask] - y0,
                s=3,
                c=c_bgr[in_box_mask],
            )
            # dpi chosen so the saved crop keeps roughly its pixel height.
            fig_im.savefig(im_file, dpi=height / fig_h_inch)
            plt.close(fig_im)

        # lidar
        close_mask = np.hypot(scan_x - xy[0], scan_y - xy[1]) < plot_range

        fig = plt.figure(figsize=(5, 5))
        ax = fig.add_subplot()
        ax.set_aspect("equal")
        ax.axis("off")

        # plot points in local frame (so it looks aligned with image)
        ang = np.mean(scan_phi[close_mask]) - 0.5 * np.pi
        ca, sa = np.cos(ang), np.sin(ang)
        xy_plotting = np.array([[ca, sa], [-sa, ca]]) @ np.stack(
            (scan_x[close_mask] - xy[0], scan_y[close_mask] - xy[1]), axis=0)

        ax.scatter(-xy_plotting[0],
                   xy_plotting[1],
                   s=80,
                   color=(191 / 255, 83 / 255, 79 / 255))
        # Marker at the pseudo-label position (origin of the local frame).
        ax.scatter(0,
                   0,
                   s=500,
                   color=(18 / 255, 105 / 255, 176 / 255),
                   marker="+",
                   linewidth=5)

        fig.savefig(pt_file)
        plt.close(fig)
Example #5
0
def demo(opt):
    """Caption up to 1000 validation batches, save a visualization per batch and return the captions.

    Relies on module-level objects defined elsewhere:

    * ``model``, ``dataloader_val`` and ``dataset_val``;
    * the ``input_*`` buffers, ``gt_seqs``, ``gt_bboxs`` and ``mask_bboxs``;
    * ``utils``, ``plt`` and ``Image``.

    For each batch, the first image's caption is printed. The image, with its
    proposal boxes drawn by ``utils.vis_detections``, is written to
    ``visu/<image_id>.jpg``.

    Args:
        opt: options object; reads ``beam_size``, ``cbs_tag_size``,
            ``vocab_size``, ``dataset``, ``image_path`` and ``image_crop_size``
            (and is passed to ``utils.decode_sequence_det``).

    Returns:
        A list of ``{'image_id': ..., 'caption': ...}`` dicts, one per batch.
        Iteration stops early if the loader has fewer than 1000 batches.
    """
    import itertools

    model.eval()
    predictions = []
    # islice stops cleanly on short loaders (and avoids the Python 2 ``.next()``).
    for data in itertools.islice(iter(dataloader_val), 1000):
        img, iseq, gts_seq, num, proposals, bboxs, box_mask, img_id = data

        # Keep only as many proposals as the batch actually uses (at least one).
        proposals = proposals[:, :max(int(max(num[:, 1])), 1), :]

        input_imgs.data.resize_(img.size()).copy_(img)
        input_seqs.data.resize_(iseq.size()).copy_(iseq)
        gt_seqs.data.resize_(gts_seq.size()).copy_(gts_seq)
        input_num.data.resize_(num.size()).copy_(num)
        input_ppls.data.resize_(proposals.size()).copy_(proposals)
        gt_bboxs.data.resize_(bboxs.size()).copy_(bboxs)
        mask_bboxs.data.resize_(box_mask.size()).copy_(box_mask)

        eval_opt = {'sample_max': 1, 'beam_size': opt.beam_size, 'inference_mode': True, 'tag_size': opt.cbs_tag_size}
        seq, bn_seq, fg_seq, _, _, _ = model._sample(input_imgs, input_ppls, input_num, eval_opt)

        sents, det_idx, det_word = utils.decode_sequence_det(dataset_val.itow, dataset_val.itod, dataset_val.ltow, dataset_val.itoc, dataset_val.wtod,
                                                             seq, bn_seq, fg_seq, opt.vocab_size, opt)

        if opt.dataset == 'flickr30k':
            im2show = Image.open(os.path.join(opt.image_path, '%d.jpg' % img_id[0])).convert('RGB')
        else:
            val_path = os.path.join(opt.image_path, 'val2014/COCO_val2014_%012d.jpg' % img_id[0])
            if os.path.isfile(val_path):
                im2show = Image.open(val_path).convert('RGB')
            else:
                im2show = Image.open(os.path.join(opt.image_path, 'train2014/COCO_train2014_%012d.jpg' % img_id[0])).convert('RGB')

        w, h = im2show.size

        rest_idx = [i for i in range(proposals[0].shape[0]) if i not in det_idx]

        # Rescale boxes from crop coordinates to the original image size. This is
        # done unconditionally: rest_dets is needed even when nothing was detected.
        boxes = proposals[0].numpy()
        boxes[:, 0] = boxes[:, 0] * w / float(opt.image_crop_size)
        boxes[:, 2] = boxes[:, 2] * w / float(opt.image_crop_size)
        boxes[:, 1] = boxes[:, 1] * h / float(opt.image_crop_size)
        boxes[:, 3] = boxes[:, 3] * h / float(opt.image_crop_size)
        rest_dets = boxes[rest_idx] if len(rest_idx) > 0 else boxes[:0]
        cls_dets = boxes[det_idx] if len(det_idx) > 0 else boxes[:0]

        # Full-bleed, axis-free figure in image pixel coordinates (y pointing down).
        fig = plt.figure(frameon=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.set_frame_on(False)
        ax.set_xticks([])
        ax.set_yticks([])
        ax.axis('off')
        ax.set_xlim(0, w)
        ax.set_ylim(h, 0)

        ax.imshow(im2show)

        # Non-detected proposals are drawn with flag 1 and detected ones with flag 0
        # (the meaning of the flag is defined by utils.vis_detections).
        for i in range(len(rest_dets)):
            ax = utils.vis_detections(ax, dataset_val.itoc[int(rest_dets[i, 4])], rest_dets[i, :5], i, 1)
        for i in range(len(cls_dets)):
            ax = utils.vis_detections(ax, dataset_val.itoc[int(cls_dets[i, 4])], cls_dets[i, :5], i, 0)

        fig.savefig('visu/%d.jpg' % (img_id[0]), bbox_inches='tight', pad_inches=0, dpi=150)
        # Release the figure; otherwise every iteration leaks one open figure.
        plt.close(fig)
        print(str(img_id[0]) + ': ' + sents[0])

        predictions.append({'image_id': img_id[0], 'caption': sents[0]})

    return predictions
def main():
    # Initialize and parse command-line arguments.
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--verbose', action = 'store_true',
        help = 'Provide verbose output.')
    descwl.output.Reader.add_args(parser)
    parser.add_argument('--no-display', action = 'store_true',
        help = 'Do not display the image on screen.')
    parser.add_argument('-o','--output-name',type = str, default = None, metavar = 'FILE',
        help = 'Name of the output file to write.')

    select_group = parser.add_argument_group('Object selection options')
    select_group.add_argument('--galaxy', type = int, action = 'append',
        default = [ ], metavar = 'ID',
        help = 'Select the galaxy with this database ID (can be repeated).')
    select_group.add_argument('--group', type = int, action = 'append',
        default = [ ], metavar = 'ID',
        help = 'Select galaxies belonging to the group with this group ID (can be repeated).')
    select_group.add_argument('--select', type = str, action = 'append',
        default = [ ], metavar = 'CUT',
        help = 'Select objects passing the specified cut (can be repeated).')
    select_group.add_argument('--select-region', type = str,
        default = None, metavar = '[XMIN,XMAX,YMIN,YMAX]',
        help = 'Select objects within this region relative to the image center (arcsecs).')
    select_group.add_argument('--save-selected', type = str, default = None,
        help = 'Name of FITS file for saving table of selected objects')

    match_group = parser.add_argument_group('Detection catalog matching options')
    match_group.add_argument('--match-catalog', type = str,
        default = None, metavar = 'FILE',
        help = 'Name of SExtractor-compatible detection catalog to read.')
    match_group.add_argument('--match-color', type = str,
        default = 'black', metavar = 'COL',
        help = 'Matplotlib color name to use for displaying detection catalog matches.')
    match_group.add_argument('--match-info', type = str,
        default = None, metavar = 'FMT',
        help = 'String interpolation format to generate matched object annotations.')

    view_group = parser.add_argument_group('Viewing options')
    view_group.add_argument('--magnification', type = float,
        default = 1, metavar = 'MAG',
        help = 'Magnification factor to use for display.')
    view_group.add_argument('--crop', action = 'store_true',
        help = 'Crop the displayed pixels around the selected objects.')
    view_group.add_argument('--view-region', type = str,
        default = None, metavar = '[XMIN,XMAX,YMIN,YMAX]',
        help = 'Viewing region in arcsecs relative to the image center (overrides crop if set).')
    view_group.add_argument('--draw-moments', action = 'store_true',
        help = 'Draw ellipses to represent the 50%% iosophote second moments of selected objects.')
    view_group.add_argument('--info', type = str,
        default = None, metavar = 'FMT',
        help = 'String interpolation format to generate annotation labels.')
    view_group.add_argument('--no-crosshair', action = 'store_true',
        help = 'Do not draw a crosshair at the centroid of each selected object.')
    view_group.add_argument('--clip-lo-noise-fraction', type = float,
        default = 0.05, metavar = 'FRAC',
        help = 'Clip pixels with values below this fraction of the mean sky noise.')
    view_group.add_argument('--clip-hi-percentile', type = float,
        default = 90.0, metavar = 'PCT',
        help = 'Clip pixels with non-zero values above this percentile for the selected image.')
    view_group.add_argument('--hide-background', action = 'store_true',
        help = 'Do not display background pixels.')
    view_group.add_argument('--hide-selected', action = 'store_true',
        help = 'Do not overlay any selected pixels.')
    view_group.add_argument('--add-noise',type = int,default = None,metavar = 'SEED',
        help = 'Add Poisson noise using the seed provided (no noise is added unless this is set).')
    view_group.add_argument('--clip-noise',type = float,default = -1.,metavar = 'SIGMAS',
        help = 'Clip background images at this many sigmas when noise is added.')
    view_group.add_argument('--zscale-all', action='store_true',
        help = 'Set zscale using all displayed objects instead of only selected ones')

    format_group = parser.add_argument_group('Formatting options')
    format_group.add_argument('--info-size', type = str,
        default = 'large', metavar = 'SIZE',
        help = 'Matplotlib font size specification in points or relative (small,large,...)')
    format_group.add_argument('--dpi', type = float, default = 64.,
        help = 'Number of pixels per inch to use for display.')
    format_group.add_argument('--max-view-size', type = int,
        default = 2048, metavar = 'SIZE',
        help = 'Maximum allowed pixel dimensions of displayed image.')
    format_group.add_argument('--colormap', type = str,
        default = 'viridis', metavar = 'CMAP',
        help = 'Matplotlib colormap name to use for background pixel values.')
    format_group.add_argument('--highlight', type = str,
        default = 'red', metavar = 'COL',
        help = 'Matplotlib color name to use for highlighted pixel values.')
    format_group.add_argument('--crosshair-color', type = str,
        default = 'greenyellow', metavar = 'COL',
        help = 'Matplotlib color name to use for crosshairs.')
    format_group.add_argument('--ellipse-color', type = str,
        default = 'greenyellow', metavar = 'COL',
        help = 'Matplotlib color name to use for second-moment ellipses.')
    format_group.add_argument('--info-color', type = str,
        default = 'green', metavar = 'COL',
        help = 'Matplotlib color name to use for info text.')
    format_group.add_argument('--outline-color', type = str,
        default = None, metavar = 'COL',
        help = 'Matplotlib color name to use for outlining text.')

    args = parser.parse_args()

    if args.no_display and not args.output_name:
        print('No display our output requested.')
        return 0
    if args.hide_background and args.hide_selected:
        print('No pixels visible with --hide-background and --hide-selected.')
        return 0

    # Load the analysis results file we will display from.
    try:
        reader = descwl.output.Reader.from_args(defer_stamp_loading = True,args = args)
        results = reader.results
        if args.verbose:
            print(results.survey.description())
    except RuntimeError as e:
        print(str(e))
        return -1

    # Add noise, if requested.
    if args.add_noise is not None:
        results.add_noise(args.add_noise)

    # Match detected objects to simulated objects, if requested.
    if args.match_catalog:
        detected,matched,matched_indices,matched_distance = (
            results.match_sextractor(args.match_catalog))
        if args.verbose:
            print('Matched %d of %d detected objects (median sep. = %.2f arcsecs).' % (
                np.count_nonzero(matched),len(matched),np.median(matched_distance)))

    # Create region selectors.
    if args.select_region:
        try:
            assert args.select_region[0] == '[' and args.select_region[-1] == ']'
            xmin,xmax,ymin,ymax = [ float(token) for token in args.select_region[1:-1].split(',') ]
            assert xmin < xmax and ymin < ymax
        except (ValueError,AssertionError):
            print('Invalid select-region xmin,xmax,ymin,ymax = %s.' % args.select_region)
            return -1
        args.select.extend(['dx>=%f'%xmin,'dx<%f'%xmax,'dy>=%f'%ymin,'dy<%f'%ymax])

    # Perform object selection.
    if args.select:
        # Combine select clauses with logical AND.
        selection = results.select(*args.select,mode='and',format='mask')
    else:
        # Nothing is selected by default.
        selection = results.select('NONE',format='mask')
    # Add any specified groups to the selection with logical OR.
    for identifier in args.group:
        selected = results.select('grp_id==%d' % identifier,format='mask')
        if not np.any(selected):
            print('WARNING: no group found with ID %d.' % identifier)
        selection = np.logical_or(selection,selected)
    # Add any specified galaxies to the selection with logical OR.
    for identifier in args.galaxy:
        selected = results.select('db_id==%d' % identifier,format='mask')
        if not np.any(selected):
            print('WARNING: no galaxy found with ID %d.' % identifier)
        selection = np.logical_or(selection,selected)
    selected_indices = np.arange(results.num_objects)[selection]
    if args.verbose:
        print('Selected IDs:\n%s' % np.array(results.table['db_id'][selected_indices]))
        groups = np.unique(results.table[selected_indices]['grp_id'])
        print('Selected group IDs:\n%s' % np.array(groups))

    # Do we have individual objects available for selection in the output file?
    if np.any(selection) and not results.stamps:
        print('Cannot display selected objects without any stamps available.')
        return -1

    # Save table of selected objects if requested.
    if args.save_selected:
        if args.verbose:
            print('Saving selected objects to %s.' % args.save_selected)
        results.table[selected_indices].write(args.save_selected, overwrite=True)

    # Build the image of selected objects (might be None).
    selected_image = results.get_subimage(selected_indices)

    # Calculate our viewing bounds as (xmin,xmax,ymin,ymax) in floating-point pixels
    # relative to the image bottom-left corner. Also calculate view_bounds with
    # integer values that determine how to extract sub-images to display.
    scale = results.survey.pixel_scale
    if args.view_region is not None:
        try:
            assert args.view_region[0] == '[' and args.view_region[-1] == ']'
            xmin,xmax,ymin,ymax = [ float(token) for token in args.view_region[1:-1].split(',') ]
            assert xmin < xmax and ymin < ymax
        except (ValueError,AssertionError):
            print('Invalid view-window xmin,xmax,ymin,ymax = %s.' % args.view_region)
            return -1
        # Convert to pixels relative to bottom-left corner.
        xmin = xmin/scale + 0.5*results.survey.image_width
        xmax = xmax/scale + 0.5*results.survey.image_width
        ymin = ymin/scale + 0.5*results.survey.image_height
        ymax = ymax/scale + 0.5*results.survey.image_height
        # Calculate integer pixel bounds that cover the view window.
        view_bounds = galsim.BoundsI(
            int(math.floor(xmin)),int(math.ceil(xmax))-1,
            int(math.floor(ymin)),int(math.ceil(ymax))-1)
    elif args.crop and selected_image is not None:
        view_bounds = selected_image.bounds
        xmin,xmax,ymin,ymax = (
            view_bounds.xmin,view_bounds.xmax+1,view_bounds.ymin,view_bounds.ymax+1)
    else:
        view_bounds = results.survey.image.bounds
        xmin,xmax,ymin,ymax = 0,results.survey.image_width,0,results.survey.image_height
    if args.verbose:
        vxmin = (xmin - 0.5*results.survey.image_width)*scale
        vxmax = (xmax - 0.5*results.survey.image_width)*scale
        vymin = (ymin - 0.5*results.survey.image_height)*scale
        vymax = (ymax - 0.5*results.survey.image_height)*scale
        print('View window is [xmin,xmax,ymin,ymax] = [%.2f,%.2f,%.2f,%.2f] arcsecs' % (
            vxmin,vxmax,vymin,vymax))
        print('View pixels in %r' % view_bounds)

    # Initialize a matplotlib figure to display our view bounds.
    view_width = float(xmax - xmin)
    view_height = float(ymax - ymin)
    if (view_width*args.magnification > args.max_view_size or
        view_height*args.magnification > args.max_view_size):
        print('Requested view dimensions %d x %d too big. Increase --max-view-size if necessary.' % (
            view_width*args.magnification,view_height*args.magnification))
        return -1
    fig_height = args.magnification*(view_height/args.dpi)
    fig_width = args.magnification*(view_width/args.dpi)
    figure = plt.figure(figsize = (fig_width,fig_height),frameon = False,dpi = args.dpi)
    axes = plt.Axes(figure, [0., 0., 1., 1.])
    axes.axis(xmin = xmin,xmax = xmax,ymin = ymin,ymax = ymax)
    axes.set_axis_off()
    figure.add_axes(axes)

    # Get the background and highlighted images to display, sized to our view.
    background = galsim.Image(bounds = view_bounds,dtype = np.float32,scale = scale)
    highlighted = background.copy()
    if not args.hide_background:
        overlap = results.survey.image.bounds & view_bounds
        if overlap.area() > 0:
            background[overlap] = results.survey.image[overlap]
    if not args.hide_selected and selected_image is not None:
        overlap = selected_image.bounds & view_bounds
        if overlap.area() > 0:
            highlighted[overlap] = selected_image[overlap]
    if np.count_nonzero(highlighted.array) == 0:
        if args.hide_background or np.count_nonzero(background.array) == 0:
            print('There are no non-zero pixel values in the view window.')
            return -1

    # Prepare the z scaling.
    zscale_pixels = results.survey.image.array
    if selected_image and not args.zscale_all:
        if selected_image.bounds.area() < 16:
            print('WARNING: using full image for z-scaling since only %d pixel(s) selected.' % (
                selected_image.bounds.area()))
        else:
            zscale_pixels = selected_image.array
    # Clip large fluxes to a fixed percentile of the non-zero selected pixel values.
    non_zero_pixels = (zscale_pixels != 0)
    vmax = np.percentile(zscale_pixels[non_zero_pixels],q = (args.clip_hi_percentile))
    # Clip small fluxes to a fixed fraction of the mean sky noise.
    vmin = args.clip_lo_noise_fraction*np.sqrt(results.survey.mean_sky_level)
    if args.verbose:
        print('Clipping pixel values to [%.1f,%.1f] detected electrons.' % (vmin,vmax))

    # Define the z scaling function. See http://ds9.si.edu/ref/how.html#Scales
    def zscale(pixels):
        return np.sqrt(pixels)

    # Calculate the clipped and scaled pixel values to display.
    highlighted_z = zscale((np.clip(highlighted.array,vmin,vmax) - vmin)/(vmax-vmin))
    if args.add_noise:
        vmin = args.clip_noise*np.sqrt(results.survey.mean_sky_level)
        if args.verbose:
            print('Background pixels with noise clipped to [%.1f,%.1f].' % (vmin,vmax))
    background_z = zscale((np.clip(background.array,vmin,vmax) - vmin)/(vmax-vmin))

    # Convert the background image to RGB using the requested colormap.
    # Drop the alpha channel [3], which is all ones anyway.
    cmap = matplotlib.cm.get_cmap(args.colormap)
    background_rgb = cmap(background_z)[:,:,:3]

    # Overlay the highlighted image using alpha blending.
    # http://en.wikipedia.org/wiki/Alpha_compositing#Alpha_blending
    if args.highlight and args.highlight != 'none':
        alpha = highlighted_z[:,:,np.newaxis]
        color = np.array(matplotlib.colors.colorConverter.to_rgb(args.highlight))
        final_rgb = alpha*color + background_rgb*(1.-alpha)
    else:
        final_rgb = background_rgb

    # Draw the composite image.
    extent = (view_bounds.xmin,view_bounds.xmax+1,view_bounds.ymin,view_bounds.ymax+1)
    axes.imshow(final_rgb,extent = extent,aspect = 'equal',origin = 'lower',
        interpolation = 'nearest')

    # The argparse module escapes any \n or \t in string args, but we need these
    # to be unescaped in the annotation format string.
    if args.info:
        args.info = binary_type(args.info, 'utf-8').decode('unicode-escape')
    if args.match_info:
        args.match_info = binary_type(args.match_info, 'utf-8').decode('string-escape')

    num_selected = len(selected_indices)
    ellipse_centers = np.empty((num_selected,2))
    ellipse_widths = np.empty(num_selected)
    ellipse_heights = np.empty(num_selected)
    ellipse_angles = np.empty(num_selected)
    match_ellipse_centers = np.empty((num_selected,2))
    match_ellipse_widths = np.empty(num_selected)
    match_ellipse_heights = np.empty(num_selected)
    match_ellipse_angles = np.empty(num_selected)
    num_match_ellipses = 0
    for index,selected in enumerate(selected_indices):
        info = results.table[selected]
        # Do we have a detected object matched to this simulated source?
        match_info = None
        if args.match_catalog and info['match'] >= 0:
            match_info = detected[info['match']]
        # Calculate the selected object's centroid position in user display coordinates.
        x_center = (0.5*results.survey.image_width + info['dx']/scale)
        y_center = (0.5*results.survey.image_height + info['dy']/scale)
        if match_info is not None:
            x_match_center = match_info['X_IMAGE']-0.5
            y_match_center = match_info['Y_IMAGE']-0.5
        # Draw a crosshair at the centroid of selected objects.
        if not args.no_crosshair:
            axes.plot(x_center,y_center,'+',color = args.crosshair_color,
                markeredgewidth = 2,markersize = 24)
            if match_info:
                axes.plot(x_match_center,y_match_center,'x',color = args.match_color,
                    markeredgewidth = 2,markersize = 24)
        # Add annotation text if requested.
        if args.info:
            path_effects = None if args.outline_color is None else [
                matplotlib.patheffects.withStroke(linewidth = 2,
                foreground = args.outline_color)]
            try:
                annotation = args.info % info
            except IndexError:
                print('Invalid annotate-format %r' % args.info)
                return -1
            axes.annotate(annotation,xy = (x_center,y_center),xytext = (4,4),
                textcoords = 'offset points',color = args.info_color,
                fontsize = args.info_size,path_effects = path_effects)
        if match_info and args.match_info:
            path_effects = None if args.outline_color is None else [
                matplotlib.patheffects.withStroke(linewidth = 2,
                foreground = args.outline_color)]
            try:
                annotation = args.match_info % match_info
            except IndexError:
                print('Invalid match-format %r' % args.match_info)
                return -1
            axes.annotate(annotation,xy = (x_match_center,y_match_center),
                xytext = (4,4),textcoords = 'offset points',
                color = args.info_color,fontsize = args.info_size,
                path_effects = path_effects)
        # Add a second-moments ellipse if requested.
        if args.draw_moments:
            ellipse_centers[index] = (x_center,y_center)
            ellipse_widths[index] = info['a']/scale
            ellipse_heights[index] = info['b']/scale
            ellipse_angles[index] = np.degrees(info['beta'])
            if match_info:
                # This will only work if we have the necessary additional fields in the match catalog.
                try:
                    match_ellipse_centers[num_match_ellipses] = (x_match_center,y_match_center)
                    match_ellipse_widths[num_match_ellipses] = match_info['A_IMAGE']
                    match_ellipse_heights[num_match_ellipses] = match_info['B_IMAGE']
                    match_ellipse_angles[num_match_ellipses] = match_info['THETA_IMAGE']
                    num_match_ellipses += 1
                except IndexError:
                    pass

    # Draw any ellipses.
    if args.draw_moments:
        ellipses = matplotlib.collections.EllipseCollection(units = 'x',
            widths = ellipse_widths,heights = ellipse_heights,angles = ellipse_angles,
            offsets = ellipse_centers, transOffset = axes.transData)
        ellipses.set_facecolor('none')
        ellipses.set_edgecolor(args.ellipse_color)
        axes.add_collection(ellipses,autolim = True)
        if num_match_ellipses > 0:
            ellipses = matplotlib.collections.EllipseCollection(units = 'x',
                widths = match_ellipse_widths,heights = match_ellipse_heights,
                angles = match_ellipse_angles,offsets = match_ellipse_centers,
                transOffset = axes.transData)
            ellipses.set_facecolor('none')
            ellipses.set_edgecolor(args.match_color)
            #ellipses.set_linestyle('dashed')
            axes.add_collection(ellipses,autolim = True)

    if args.output_name:
        figure.savefig(args.output_name,dpi = args.dpi)

    if not args.no_display:
        plt.show()
def run_mdnet(img_list, init_bbox, gt=None, savefig_dir='', display=False):
    """Track one target through a frame sequence with MDNet.

    The network is fine-tuned on the first frame around ``init_bbox`` and
    then updated online while tracking: a short-term update whenever the
    frame's score falls below ``opts['success_thr']``, and a long-term
    update every ``opts['long_interval']`` frames otherwise.

    Args:
        img_list: sequence of image paths; each is opened with PIL and
            converted to RGB.
        init_bbox: target box in the first frame, stored as a row of the
            ``(len(img_list), 4)`` result arrays (presumably ``x, y, w, h``,
            matching how boxes are drawn with ``plt.Rectangle`` below).
        gt: optional ground-truth boxes indexed ``gt[i]`` per frame; used
            only for drawing and for the per-frame overlap printout.
        savefig_dir: if non-empty, each frame's visualisation is saved there
            as ``%04d.jpg``.
        display: if True, show the visualisation interactively.

    Returns:
        ``(result, result_bb, fps)``: the raw and the bbox-regressed box for
        every frame, and frames per second measured over the tracking work
        (visualisation time is excluded).
    """
    # Init bbox
    target_bbox = np.array(init_bbox)
    result = np.zeros((len(img_list), 4))
    result_bb = np.zeros((len(img_list), 4))
    result[0] = target_bbox
    result_bb[0] = target_bbox

    # Init model
    model = MDNet(opts['model_path'])
    if opts['use_gpu']:
        # NOTE(review): this assignment is a no-op, so the model is never
        # moved to a GPU even when 'use_gpu' is set. Presumably a `.cuda()`
        # call was removed here -- confirm the intended device placement.
        model = model
    model.set_learnable_params(opts['ft_layers'])

    # Init criterion and optimizer
    criterion = BinaryLoss()
    init_optimizer = set_optimizer(model, opts['lr_init'])
    update_optimizer = set_optimizer(model, opts['lr_update'])

    tic = time.time()
    # Load first image
    image = Image.open(img_list[0]).convert('RGB')

    # Train bbox regressor
    bbreg_examples = gen_samples(
        SampleGenerator('uniform', image.size, 0.3, 1.5, 1.1), target_bbox,
        opts['n_bbreg'], opts['overlap_bbreg'], opts['scale_bbreg'])
    bbreg_feats = forward_samples(model, image, bbreg_examples)
    bbreg = BBRegressor(image.size)
    bbreg.train(bbreg_feats, bbreg_examples, target_bbox)

    # Draw pos/neg samples
    pos_examples = gen_samples(
        SampleGenerator('gaussian', image.size, 0.1, 1.2), target_bbox,
        opts['n_pos_init'], opts['overlap_pos_init'])

    neg_examples = np.concatenate([
        gen_samples(SampleGenerator('uniform', image.size, 1, 2,
                                    1.1), target_bbox, opts['n_neg_init'] // 2,
                    opts['overlap_neg_init']),
        gen_samples(SampleGenerator('whole', image.size, 0, 1.2,
                                    1.1), target_bbox, opts['n_neg_init'] // 2,
                    opts['overlap_neg_init'])
    ])
    neg_examples = np.random.permutation(neg_examples)

    # Extract pos/neg features
    pos_feats = forward_samples(model, image, pos_examples)
    neg_feats = forward_samples(model, image, neg_examples)
    feat_dim = pos_feats.size(-1)

    # Initial training
    train(model, criterion, init_optimizer, pos_feats, neg_feats,
          opts['maxiter_init'])

    # Init sample generators
    sample_generator = SampleGenerator('gaussian',
                                       image.size,
                                       opts['trans_f'],
                                       opts['scale_f'],
                                       valid=True)
    pos_generator = SampleGenerator('gaussian', image.size, 0.1, 1.2)
    neg_generator = SampleGenerator('uniform', image.size, 1.5, 1.2)

    # Init pos/neg features for update
    pos_feats_all = [pos_feats[:opts['n_pos_update']]]
    neg_feats_all = [neg_feats[:opts['n_neg_update']]]

    spf_total = time.time() - tic

    # Display
    savefig = savefig_dir != ''
    if display or savefig:
        dpi = 80.0
        figsize = (image.size[0] / dpi, image.size[1] / dpi)

        fig = plt.figure(frameon=False, figsize=figsize, dpi=dpi)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        # 'auto' lets the image fill the frame-sized axes. The former
        # spelling 'normal' is rejected by newer matplotlib versions.
        im = ax.imshow(image, aspect='auto')

        if gt is not None:
            gt_rect = plt.Rectangle(tuple(gt[0, :2]),
                                    gt[0, 2],
                                    gt[0, 3],
                                    linewidth=3,
                                    edgecolor="#00ff00",
                                    zorder=1,
                                    fill=False)
            ax.add_patch(gt_rect)

        rect = plt.Rectangle(tuple(result_bb[0, :2]),
                             result_bb[0, 2],
                             result_bb[0, 3],
                             linewidth=3,
                             edgecolor="#ff0000",
                             zorder=1,
                             fill=False)
        ax.add_patch(rect)

        if display:
            plt.pause(.01)
            plt.draw()
        if savefig:
            fig.savefig(os.path.join(savefig_dir, '0000.jpg'), dpi=dpi)

    # Main loop
    for i in range(1, len(img_list)):

        tic = time.time()
        # Load image
        image = Image.open(img_list[i]).convert('RGB')

        # Estimate target bbox: average of the 5 best-scoring candidates.
        samples = gen_samples(sample_generator, target_bbox, opts['n_samples'])
        sample_scores = forward_samples(model, image, samples, out_layer='fc6')
        top_scores, top_idx = sample_scores[:, 1].topk(5)
        top_idx = top_idx.cpu().numpy()
        target_score = top_scores.mean()
        target_bbox = samples[top_idx].mean(axis=0)

        success = target_score > opts['success_thr']

        # Expand search area at failure
        if success:
            sample_generator.set_trans_f(opts['trans_f'])
        else:
            sample_generator.set_trans_f(opts['trans_f_expand'])

        # Bbox regression
        if success:
            bbreg_samples = samples[top_idx]
            bbreg_feats = forward_samples(model, image, bbreg_samples)
            bbreg_samples = bbreg.predict(bbreg_feats, bbreg_samples)
            bbreg_bbox = bbreg_samples.mean(axis=0)
        else:
            bbreg_bbox = target_bbox

        # Copy previous result at failure
        if not success:
            target_bbox = result[i - 1]
            bbreg_bbox = result_bb[i - 1]

        # Save result
        result[i] = target_bbox
        result_bb[i] = bbreg_bbox

        # Data collect
        if success:
            # Draw pos/neg samples
            pos_examples = gen_samples(pos_generator, target_bbox,
                                       opts['n_pos_update'],
                                       opts['overlap_pos_update'])
            neg_examples = gen_samples(neg_generator, target_bbox,
                                       opts['n_neg_update'],
                                       opts['overlap_neg_update'])

            # Extract pos/neg features
            pos_feats = forward_samples(model, image, pos_examples)
            neg_feats = forward_samples(model, image, neg_examples)
            pos_feats_all.append(pos_feats)
            neg_feats_all.append(neg_feats)
            # Positives are kept for a long window, negatives for a short one.
            if len(pos_feats_all) > opts['n_frames_long']:
                del pos_feats_all[0]
            if len(neg_feats_all) > opts['n_frames_short']:
                del neg_feats_all[0]

        # Short term update
        if not success:
            nframes = min(opts['n_frames_short'], len(pos_feats_all))
            pos_data = torch.stack(pos_feats_all[-nframes:],
                                   0).view(-1, feat_dim)
            neg_data = torch.stack(neg_feats_all, 0).view(-1, feat_dim)
            train(model, criterion, update_optimizer, pos_data, neg_data,
                  opts['maxiter_update'])

        # Long term update
        elif i % opts['long_interval'] == 0:
            pos_data = torch.stack(pos_feats_all, 0).view(-1, feat_dim)
            neg_data = torch.stack(neg_feats_all, 0).view(-1, feat_dim)
            train(model, criterion, update_optimizer, pos_data, neg_data,
                  opts['maxiter_update'])

        spf = time.time() - tic
        spf_total += spf

        # Display
        if display or savefig:
            im.set_data(image)

            if gt is not None:
                gt_rect.set_xy(gt[i, :2])
                gt_rect.set_width(gt[i, 2])
                gt_rect.set_height(gt[i, 3])

            rect.set_xy(result_bb[i, :2])
            rect.set_width(result_bb[i, 2])
            rect.set_height(result_bb[i, 3])

            if display:
                plt.pause(.01)
                plt.draw()
            if savefig:
                fig.savefig(os.path.join(savefig_dir, '%04d.jpg' % (i)),
                            dpi=dpi)

        if gt is None:
            print("Frame %d/%d, Score %.3f, Time %.3f" %
                  (i, len(img_list), target_score, spf))
        else:
            print("Frame %d/%d, Overlap %.3f, Score %.3f, Time %.3f" %
                  (i, len(img_list), overlap_ratio(gt[i], result_bb[i])[0],
                   target_score, spf))

    fps = len(img_list) / spf_total
    return result, result_bb, fps
Exemple #8
0
def vis_one_image(im,
                  im_name,
                  output_dir,
                  boxes,
                  segms=None,
                  keypoints=None,
                  thresh=0.9,
                  kp_thresh=2,
                  dpi=200,
                  box_alpha=0.0,
                  dataset=None,
                  show_class=False,
                  ext='pdf'):
    """Visual debugging of detections.

    Draws boxes, optional class labels, instance masks and keypoints for one
    image with matplotlib. When ``output_dir`` is given, the figure is saved
    once as ``<basename(im_name)>.<ext>`` in that directory and then all
    pyplot figures are closed. Otherwise the figure is left open.

    Args:
        im: image array shown with ``ax.imshow``. ``im.shape[0]`` and
            ``im.shape[1]`` are used as height and width for the figure size.
        im_name: image name; its basename names the output file.
        output_dir: directory to save into (created if missing), or None.
        boxes: per-class list accepted by ``convert_from_cls_format``, or an
            array whose column 4 / last column is the score.
            NOTE(review): ``classes`` is only defined for the list form, so
            array input with a box above ``thresh`` presumably fails with a
            NameError. Confirm that callers always pass the list form.
        segms: encoded masks decoded with ``mask_util.decode``; mask ``k`` is
            ``masks[:, :, k]``.
        keypoints: per-detection arrays indexed ``kps[0|1|2, k]``, used as
            x, y and score respectively.
        thresh: minimum score to draw. A negative value means "draw the
            ``-thresh`` highest-scoring boxes".
        kp_thresh: minimum keypoint score to draw a keypoint or limb.
        dpi: figure dpi used for sizing and saving.
        box_alpha: alpha of the box outline (0 hides it).
        dataset: passed to ``get_class_string`` for label text.
        show_class: draw the class label next to each box.
        ext: output file extension.
    """
    if output_dir is not None:
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

    if isinstance(boxes, list):
        boxes, segms, keypoints, classes = convert_from_cls_format(
            boxes, segms, keypoints)

    if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
        return

    if segms is not None:
        masks = mask_util.decode(segms)

    color_list = colormap(rgb=True) / 255

    dataset_keypoints, _ = keypoint_utils.get_keypoints()
    kp_lines = kp_connections(dataset_keypoints)
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]

    fig = plt.figure(frameon=False)
    if output_dir is not None:
        fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(im)

    # Row r of `boxes` is detection det_inds[r] of the decoded segms and
    # keypoints. This mapping is carried through every reordering below so
    # that masks and keypoints stay attached to the right box.
    det_inds = np.arange(boxes.shape[0])

    # preprocess the boxes
    if thresh < 0:
        # When VIS_TH less than zero, it means take the highest -thresh score boxes
        sorted_inds = np.argsort(-boxes[:, -1])
        top_inds = sorted_inds[:-int(thresh)]
        boxes = boxes[top_inds]
        classes = [classes[k] for k in top_inds]
        det_inds = det_inds[top_inds]

    # Display in largest to smallest order to reduce occlusion
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_inds = np.argsort(-areas)

    kept_boxes = []
    kept_dets = []  # detection index (into segms/keypoints) of each kept box
    texts = []
    for i in sorted_inds:
        bbox = boxes[i, :4]
        score = boxes[i, -1]
        if score < thresh:
            continue
        if kept_boxes and (bbox == kept_boxes[-1][:4]).all():
            # Same box, merge prediction
            texts[-1] += '/' + get_class_string(classes[i], score, dataset)
        else:
            kept_boxes.append(boxes[i])
            kept_dets.append(det_inds[i])
            texts.append(get_class_string(classes[i], score, dataset))
    boxes = np.stack(kept_boxes)

    mask_color_id = 0

    for i in range(len(boxes)):
        bbox = boxes[i, :4]
        score = boxes[i, -1]
        if score < thresh:
            continue
        det = kept_dets[i]

        print(texts[i])
        # show box (off by default, box_alpha=0.0)
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1],
                          fill=False,
                          edgecolor='r',
                          linewidth=5,
                          alpha=box_alpha))

        if show_class:
            ax.text(bbox[0],
                    bbox[1] + 6,
                    texts[i].split(' ')[0],
                    fontsize=3,
                    family='serif',
                    bbox=dict(
                        facecolor='#66FF66' if '@' in texts[i] else '#0099FF',
                        alpha=0.4,
                        pad=0,
                        edgecolor='none'),
                    color='white')

        # show mask
        if segms is not None and len(segms) > det:
            color_mask = color_list[mask_color_id % len(color_list), 0:3]
            mask_color_id += 1

            # Lighten towards white. This builds a new array, so the shared
            # palette (which is reused cyclically) is not modified in place.
            w_ratio = .4
            color_mask = color_mask * (1 - w_ratio) + w_ratio
            e = masks[:, :, det]

            # findContours returns (image, contours, hierarchy) in OpenCV 3
            # and (contours, hierarchy) in OpenCV 4; [-2] is contours in both.
            contours = cv2.findContours(e.copy(), cv2.RETR_CCOMP,
                                        cv2.CHAIN_APPROX_NONE)[-2]

            for c in contours:
                polygon = Polygon(c.reshape((-1, 2)),
                                  fill=True,
                                  facecolor=color_mask,
                                  edgecolor='w',
                                  linewidth=1.2,
                                  alpha=0.5)
                ax.add_patch(polygon)

        # show keypoints
        if keypoints is not None and len(keypoints) > det:
            kps = keypoints[det]
            ax.autoscale(False)
            for l, pair in enumerate(kp_lines):
                i1, i2 = pair[0], pair[1]
                if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
                    x = [kps[0, i1], kps[0, i2]]
                    y = [kps[1, i1], kps[1, i2]]
                    line = ax.plot(x, y)
                    plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7)
                if kps[2, i1] > kp_thresh:
                    ax.plot(kps[0, i1],
                            kps[1, i1],
                            '.',
                            color=colors[l],
                            markersize=3.0,
                            alpha=0.7)
                if kps[2, i2] > kp_thresh:
                    ax.plot(kps[0, i2],
                            kps[1, i2],
                            '.',
                            color=colors[l],
                            markersize=3.0,
                            alpha=0.7)

            # add mid shoulder / mid hip for better visualization
            mid_shoulder = (
                kps[:2, dataset_keypoints.index('right_shoulder')] +
                kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
            sc_mid_shoulder = np.minimum(
                kps[2, dataset_keypoints.index('right_shoulder')],
                kps[2, dataset_keypoints.index('left_shoulder')])
            mid_hip = (kps[:2, dataset_keypoints.index('right_hip')] +
                       kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
            sc_mid_hip = np.minimum(
                kps[2, dataset_keypoints.index('right_hip')],
                kps[2, dataset_keypoints.index('left_hip')])
            if (sc_mid_shoulder > kp_thresh
                    and kps[2, dataset_keypoints.index('nose')] > kp_thresh):
                x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]]
                y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]]
                line = ax.plot(x, y)
                plt.setp(line,
                         color=colors[len(kp_lines)],
                         linewidth=1.0,
                         alpha=0.7)
            if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
                x = [mid_shoulder[0], mid_hip[0]]
                y = [mid_shoulder[1], mid_hip[1]]
                line = ax.plot(x, y)
                plt.setp(line,
                         color=colors[len(kp_lines) + 1],
                         linewidth=1.0,
                         alpha=0.7)

    # Save once after every box has been drawn. Saving and closing inside
    # the per-box loop rewrote the file for each box, and the following
    # plt.* calls then created stray new figures.
    if output_dir is not None:
        output_name = os.path.basename(im_name) + '.' + ext
        fig.savefig(os.path.join(output_dir, output_name), dpi=dpi)
        plt.close('all')
    else:
        plt.plot()
Exemple #9
0
def demo(sess,
         net,
         im_file,
         vis_file,
         fits_fn,
         conf_thresh=0.8,
         eval_class=True,
         extra_vis_png=False):
    """
    Detect object classes in an image using pre-computed object proposals.

    im_file:     The "fused" image file path
    vis_file:    The background image file on which detections are laid.
                 Normally, this is just the IR image file path
    fits_fn:     The FITS file path, or None. When given and extra_vis_png
                 is False, the contour patch returned by ``fuse`` is drawn.
    conf_thresh: score threshold passed to ``vis_detections``.
    eval_class:  True - use traditional per class-based evaluation style
                 False - use per RoI-based evaluation
                 NOTE: only the per-class path is implemented, so this flag
                 is currently ignored.

    Detections are drawn on a new matplotlib figure, which is left open for
    the caller. Returns 0 on success and -1 if an input image cannot be
    found or read.
    """
    show_img_size = cfg.TEST.SCALES[0]
    if not os.path.exists(im_file):
        print('%s cannot be found' % (im_file))
        return -1
    im = cv2.imread(im_file)
    if im is None:
        print('%s cannot be read' % (im_file))
        return -1

    # Detect all object classes and regress object bounds
    timer = Timer()
    timer.tic()
    image_name = osp.basename(im_file)
    scores, boxes = im_detect(sess,
                              net,
                              im,
                              save_vis_dir=None,
                              img_name=os.path.splitext(image_name)[0])
    # Rescale boxes from the detection image to the square show_img_size
    # display. x and y are scaled separately so that non-square inputs are
    # handled too. This assumes the usual per-class (x1, y1, x2, y2) column
    # layout, as sliced in the class loop below.
    boxes[:, 0::2] *= float(show_img_size) / float(im.shape[1])
    boxes[:, 1::2] *= float(show_img_size) / float(im.shape[0])
    timer.toc()
    sys.stdout.write('Done in {:.3f} secs'.format(timer.total_time))
    sys.stdout.flush()
    print(scores)

    im = cv2.imread(vis_file)
    if im is None:
        print('%s cannot be read' % (vis_file))
        return -1

    my_dpi = 100
    fig = plt.figure()
    fig.set_size_inches(show_img_size / my_dpi, show_img_size / my_dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    ax.set_xlim([0, show_img_size])
    ax.set_ylim([show_img_size, 0])
    im = cv2.resize(im, (show_img_size, show_img_size))
    im = im[:, :, (2, 1, 0)]  # OpenCV BGR -> RGB for matplotlib
    ax.imshow(im, aspect='equal')
    if (fits_fn is not None) and (not extra_vis_png):
        patch_contour = fuse(fits_fn,
                             im,
                             None,
                             sigma_level=4,
                             mask_ir=False,
                             get_path_patch_only=True)
        ax.add_patch(patch_contour)
    NMS_THRESH = cfg.TEST.NMS

    num_sources = 0
    # CLASSES[0] is the background class, so it is skipped.
    for cls_ind, cls in enumerate(CLASSES[1:], start=1):
        cls_boxes = boxes[:, 4 * cls_ind:4 * (cls_ind + 1)]
        cls_scores = scores[:, cls_ind]
        dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis]))
        keep = nms(dets, NMS_THRESH)
        dets = dets[keep, :]
        num_sources += vis_detections(im, cls, dets, ax, thresh=conf_thresh)

    print(', found %d sources' % num_sources)
    return 0
def plot_all(plain_path, pre_HA_path, pre_PV_path, post_HA_path, post_PV_path,
             pre_HA_label_path, pre_PV_label_path, post_HA_label_path,
             post_PV_label_path, output_dir):
    """Render every z-slice as a five-panel JPEG with label contours.

    Panels, left to right:
      1. the plain image with all four label contours and a legend;
      2-5. the pre-HA, pre-PV, post-HA and post-PV images, each with only its
         own label contour.

    Intensity images are windowed to [0, 300] and mapped to [0, 255]. Slice
    ``z`` is written to ``output_dir`` as ``"%02d.jpg" % (z + 1)``. Every
    volume is indexed with the same ``z`` as the plain image, so they are
    assumed to share its voxel grid (TODO confirm against the data).
    """
    reader = sitk.ImageFileReader()

    def read_array(path, window=None):
        """Read ``path`` with SimpleITK, optionally window it, and return the array."""
        reader.SetFileName(path)
        image = reader.Execute()
        if window is not None:
            image = window.Execute(image)
        return sitk.GetArrayFromImage(image)

    window = sitk.IntensityWindowingImageFilter()
    window.SetOutputMaximum(255)
    window.SetOutputMinimum(0)
    window.SetWindowMaximum(300)
    window.SetWindowMinimum(0)

    plain_np = read_array(plain_path, window)
    # (legend name, contour color, intensity volume, label volume) per phase.
    phases = [
        ("Pre HA", 'r',
         read_array(pre_HA_path, window), read_array(pre_HA_label_path)),
        ("Pre PV", 'lawngreen',
         read_array(pre_PV_path, window), read_array(pre_PV_label_path)),
        ("Post HA", 'dodgerblue',
         read_array(post_HA_path, window), read_array(post_HA_label_path)),
        ("Post PV", 'yellow',
         read_array(post_PV_path, window), read_array(post_PV_label_path)),
    ]

    # Figure is five slices wide and one slice tall at one pixel per voxel.
    # This is loop-invariant because all slices share a shape.
    dpi = 100
    height, width = plain_np.shape[1:3]
    size = [float(width) / dpi * 5, float(height) / dpi]

    def add_panel(fig, left, image_slice):
        """Add a borderless axes at horizontal offset ``left`` showing ``image_slice``."""
        ax = plt.Axes(fig, [left, 0, 0.2, 1])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(image_slice, cmap="gray", origin='lower')
        return ax

    def draw_contour(ax, label_slice, color):
        """Outline ``label_slice`` on ``ax`` in ``color``."""
        ax.contour(label_slice, [0, 1],
                   colors=color,
                   origin='lower',
                   linewidths=1)

    for z in tqdm(range(plain_np.shape[0])):
        fig = plt.figure()
        fig.set_size_inches(size)

        # Overview panel: plain image with every phase's contour.
        ax = add_panel(fig, 0, plain_np[z, :, :])
        for _, color, _, label_np in phases:
            draw_contour(ax, label_np[z, :, :], color)
        # Legend entries come from empty proxy lines. The previous approach,
        # ContourSet.collections, is deprecated in matplotlib 3.8 and removed
        # in 3.10.
        for name, color, _, _ in phases:
            ax.plot([], [], color=color, linewidth=1, label=name)
        leg = ax.legend(loc='upper right', frameon=False)
        for text in leg.get_texts():
            plt.setp(text, color='w')

        # One panel per phase, each with only its own contour.
        for left, (_, color, image_np, label_np) in zip((0.2, 0.4, 0.6, 0.8),
                                                         phases):
            ax = add_panel(fig, left, image_np[z, :, :])
            draw_contour(ax, label_np[z, :, :], color)

        fig.savefig(os.path.join(output_dir, "%02d.jpg" % (z + 1)), dpi=dpi)
        plt.close(fig)
Exemple #11
0
def _select_instance_index(boxes, scores):
    """Return the index of the single instance that ``display_instances`` draws."""
    if scores is None:
        # Without scores every instance is treated as tied for the best score.
        candidates = list(range(len(boxes)))
    else:
        scores = np.asarray(scores)
        best = np.amax(scores)
        candidates = [i for i, s in enumerate(scores) if s == best]

    if len(candidates) == 1:
        return candidates[0]

    # Tie-break heuristic: walk the tied instances and shortlist each one that
    # lies in the left half (x1 < 512) and sets a new running minimum of y1.
    shortlist = []
    min_y1 = 1024
    for i in candidates:
        y1, x1 = boxes[i][0], boxes[i][1]
        if x1 < 512 and y1 < min_y1:
            min_y1 = y1
            shortlist.append(i)

    if not shortlist:
        # Previously this fell through to index -1, i.e. the last box, which
        # is not necessarily one of the top-scoring instances.
        return candidates[0]
    # Among the shortlist, take the left-most box (first one wins on ties).
    return min(shortlist, key=lambda i: boxes[i][1])


def _draw_mask_outline(ax, mask, color):
    """Add polygons tracing the 0.5 contour of 2-D ``mask`` to ``ax``."""
    # Pad to ensure proper polygons for masks that touch image edges.
    padded_mask = np.zeros((mask.shape[0] + 2, mask.shape[1] + 2),
                           dtype=np.uint8)
    padded_mask[1:-1, 1:-1] = mask
    for verts in find_contours(padded_mask, 0.5):
        # Subtract the padding and flip (y, x) to (x, y)
        verts = np.fliplr(verts) - 1
        ax.add_patch(Polygon(verts, facecolor="none", edgecolor=color))


def display_instances(image,
                      boxes,
                      masks,
                      class_ids,
                      class_names,
                      scores=None,
                      title="",
                      figsize=(16, 16),
                      ax=None,
                      dstPath=None,
                      filename=None,
                      truemask=None):
    """Draw one selected instance over ``image`` (plus an optional reference mask).

    Only a single instance is rendered: the highest-scoring one, with ties
    broken by a position heuristic. Its box is drawn dashed, its mask is
    blended into the image with ``apply_mask`` and outlined, and
    ``truemask`` (when given) is outlined in a second colour.

    boxes: [num_instances, (y1, x1, y2, x2, ...)] in image coordinates;
        only the first four columns are read.
    masks: [height, width, num_instances]
    class_ids, class_names: currently unused; kept for API compatibility.
    scores: (optional) confidence scores for each box. When None, every
        instance is treated as tied for the best score.
    title: title set on the axes.
    figsize: currently not applied; kept for API compatibility.
    ax: (optional) existing Axes to draw into. Its own figure is used and
        is left open for the caller.
    dstPath, filename: when both are given, the figure is saved to
        ``<dstPath>/<filename without extension>_fig.png`` at dpi=1024.
    truemask: (optional) 2-D reference mask, outlined only.
    """
    import os

    if len(boxes) == 0:
        print("\n*** No instances to display *** \n")
        return

    index = _select_instance_index(boxes, scores)
    print("index", index)

    owns_figure = ax is None
    if owns_figure:
        fig = plt.figure()
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
    else:
        # An Axes can only live in the figure it was created for.
        fig = ax.figure

    # Generate random colors: one for the prediction, one for truemask.
    colors = random_colors(2)
    color = colors[0]

    # Show area outside image boundaries.
    height, width = image.shape[:2]
    ax.set_ylim(height + 10, -10)
    ax.set_xlim(-10, width + 10)
    ax.axis('off')
    ax.set_title(title)

    masked_image = image.astype(np.uint32)

    y1, x1, y2, x2 = boxes[index][:4]
    ax.add_patch(
        patches.Rectangle((x1, y1),
                          x2 - x1,
                          y2 - y1,
                          linewidth=0.5,
                          alpha=0.7,
                          linestyle="dashed",
                          edgecolor=color,
                          facecolor='none'))

    # Predicted mask: blended into the image and outlined.
    mask = masks[:, :, index]
    masked_image = apply_mask(masked_image, mask, color)
    _draw_mask_outline(ax, mask, color)

    # Reference mask: outline only, in the second colour.
    if truemask is not None:
        _draw_mask_outline(ax, np.asarray(truemask), colors[1])

    ax.imshow(masked_image.astype(np.uint8))
    plt.show()

    if dstPath is not None and filename is not None:
        stem = os.path.splitext(filename)[0]
        fig.savefig(os.path.join(dstPath, stem + "_fig.png"), dpi=1024)

    if owns_figure:
        # Release the pyplot-managed figure we created.
        plt.close(fig)
Exemple #12
0
def plot_chunk(chunk,
               mode='spectrogram',
               output_folder=None,
               base_path=None,
               size=227,
               nfft=None,
               file_type='png',
               labelling=False,
               **kwargs):
    """
    Plot spectrograms for a chunk of a wav-file using the described parameters.
    :param chunk: audio chunk to be plotted, unpacked as
        ``(filename, sr, ts, audio)``. When ``ts`` is not None it is appended
        to the saved file name.
    :param mode: type of audio plot to create; key into ``PLOTTING_FUNCTIONS``.
    :param output_folder: if given, the plot is also saved below this folder
        in ``file_type`` format (Default: None)
    :param base_path: if given, the path of ``filename`` relative to it is
        mirrored below ``output_folder``; otherwise only the basename is used.
    :param size: dpi used when saving. With ``labelling=False`` the figure is
        1x1 inch, so the plot is ``size`` x ``size`` pixels (Default: 227)
    :param nfft: number of samples for the fast fourier transformation
        (Default: ``_next_power_of_two(int(sr * 0.025))``)
    :param file_type: format of the file written to ``output_folder``
        (Default: 'png')
    :param labelling: if True, add axis labels, tick relabelling, a colorbar
        and a tight layout instead of a borderless 1x1 inch canvas.
    :param kwargs: keyword args for plotting functions
    :return: PlotTuple with the relative name, timestamp and the RGB image
        read back from the rendered PNG (alpha channel dropped), or None if
        the rendered blob cannot be read.
    """
    matplotlib.use('Agg')
    import matplotlib.pyplot as plt
    filename, sr, ts, audio = chunk
    write_index = ts is not None
    if not nfft:
        nfft = _next_power_of_two(int(sr * 0.025))
    log.debug(f'Using nfft={nfft} for the FFT.')
    fig = plt.figure(frameon=False, tight_layout=False)

    if not labelling:
        # Borderless 1x1 inch canvas filled by a single axis-less Axes.
        fig.set_size_inches(1, 1)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        spectrogram_axes = PLOTTING_FUNCTIONS[mode](audio, sr, nfft, **kwargs)
        if labelling:
            original_xlim = spectrogram_axes.get_xlim()
            if mode != 'chroma':
                # Relabel the frequency ticks in kHz.
                kHz_ticks = np.asarray(spectrogram_axes.get_yticks()) / 1000
                spectrogram_axes.set_yticklabels(kHz_ticks)
                spectrogram_axes.set_ylabel('Frequency [kHz]',
                                            fontdict=label_font)
            else:
                spectrogram_axes.set_ylabel('Pitch Classes',
                                            fontdict=label_font)
            # Keep every other time tick, then restore the original x-limits.
            spectrogram_axes.set_xticks(spectrogram_axes.get_xticks()[::2])
            spectrogram_axes.set_xlabel('Time [s]', fontdict=label_font)
            spectrogram_axes.set_xlim(original_xlim)
        del audio
    fig.add_axes(spectrogram_axes, id='spectrogram')

    if labelling:
        plt.colorbar(format='%+2.1f dB')
        plt.tight_layout()

    buf = io.BytesIO()
    try:
        if output_folder:
            stem = splitext(get_relative_path(filename, base_path))[0]
            if write_index:
                relative_file_name = f'{stem}_{ts:g}.{file_type}'
            else:
                relative_file_name = f'{stem}.{file_type}'
            if base_path is None:
                outfile = join(output_folder, basename(relative_file_name))
            else:
                outfile = join(output_folder, relative_file_name)

            log.debug(f'Saving spectrogram plot to {outfile}.')
            makedirs(dirname(outfile), exist_ok=True)
            fig.savefig(outfile, format=file_type, dpi=size)
        fig.savefig(buf, format='png', dpi=size)
        img_blob = buf.getvalue()
    finally:
        # Always release the figure, even if saving failed.
        fig.clf()
        plt.close(fig)
        buf.close()

    try:
        img = imread_from_blob(img_blob, 'png')
        img = img[:, :, :-1]
        log.debug(f'Read spectrogram plot with shape {img.shape}.')
    except IOError:
        log.error('Error while reading the spectrogram blob.')
        return None
    return PlotTuple(name=get_relative_path(filename, base_path),
                     timestamp=ts,
                     plot=img)
Exemple #13
0
def vis_one_image(im,
                  im_name,
                  output_dir,
                  boxes,
                  segms=None,
                  keypoints=None,
                  thresh=0.9,
                  kp_thresh=2,
                  dpi=200,
                  box_alpha=0.8,
                  dataset=None,
                  show_class=False,
                  ext='pdf',
                  labels=None):
    """Export per-instance masks and a label map for one image's detections.

    When ``boxes`` is a list it is converted with ``convert_from_cls_format``
    into ``(boxes, segms, keypoints, classes)``. Nothing is written (and None
    is returned) when there are no boxes or the best score in column 4 is
    below ``thresh``.

    Otherwise ``output_dir/{instances_text,instances,labels}`` are created.
    Every detection with score > 0.5 (a fixed cut-off, independent of
    ``thresh``) is processed, in descending box-area order:

    * its binary mask (255 where the mask is 1) is saved to
      ``instances/<im_name>_<label_id*256 + instance_id>.png``;
    * a line ``<that png path> <class id> <score>`` is appended to
      ``instances_text/<im_name>.txt``;
    * its class id is painted into a label map, which is saved to
      ``labels/<im_name>.png``. Smaller instances overwrite larger ones.

    ``dataset.classes`` maps class ids to the names used for per-class
    instance numbering. ``keypoints``, ``kp_thresh``, ``dpi``,
    ``box_alpha``, ``show_class``, ``ext`` and ``labels`` are currently
    unused and kept for API compatibility.

    Raises:
        ValueError: if detections pass ``thresh`` but class ids are
            unavailable (``boxes`` was not given in per-class list form), or
            if an instance must be exported but ``segms`` is None.

    NOTE(review): ``scipy.misc.imsave`` no longer exists in recent SciPy
    releases; confirm the SciPy version this runs against.
    """
    print("Processing image: {}".format(im_name))

    classes = None
    if isinstance(boxes, list):
        boxes, segms, keypoints, classes = convert_from_cls_format(
            boxes, segms, keypoints)

    if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
        return

    if classes is None:
        raise ValueError(
            "vis_one_image needs boxes in per-class list format so that "
            "each detection's class id is known")

    # masks is indexed as masks[:, :, i]: the first two axes are the image
    # rows/columns and the third has one slice per predicted instance.
    masks = mask_util.decode(segms) if segms is not None else None

    # Sort by area, largest first, so smaller instances are painted later and
    # overwrite larger ones in the label map.
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_inds = np.argsort(-areas)
    instance_cnt = defaultdict(int)

    labels_graph = np.zeros((im.shape[0], im.shape[1]))

    for sub_dir in ("instances_text", "instances", "labels"):
        path = os.path.join(output_dir, sub_dir)
        if not os.path.exists(path):
            os.makedirs(path, 0o777)

    with open("{}/instances_text/{}.txt".format(output_dir, im_name),
              "w") as mask_file:
        for i in sorted_inds:
            score = boxes[i, -1]
            # Only export instances with confidence above 0.5.
            if score > 0.5:
                if masks is None:
                    raise ValueError(
                        "segms is required to export instance masks")
                single_mask = masks[:, :, i]
                instances_graph = np.zeros((im.shape[0], im.shape[1]))

                label_id = classes[i]
                class_name = dataset.classes[label_id]
                instance_cnt[class_name] += 1
                instance_id = instance_cnt[class_name]
                labels_graph[single_mask == 1] = label_id
                instances_graph[single_mask == 1] = 255
                scipy.misc.imsave(
                    '{}/instances/{}_{}.png'.format(
                        output_dir, im_name, label_id * 256 + instance_id),
                    instances_graph)

                # Line format: <mask png path> <class id> <score>
                mask_file.write("{}/instances/{}_{}.png {} {}\n".format(
                    output_dir, im_name, label_id * 256 + instance_id,
                    classes[i], score))

        scipy.misc.imsave('{}/labels/{}.png'.format(output_dir, im_name),
                          labels_graph)
Exemple #14
0
def _resample_with_fallback(resample_func, *args, **kwargs):
    """Run a pyresample routine on all CPUs, retrying single-process on failure."""
    try:
        return resample_func(*args, nprocs=cpu_count(), **kwargs)
    except Exception:
        return resample_func(*args, nprocs=1, **kwargs)


def make_map(e, n, t, d, dat_port, dat_star, data_R, pix_m, res, cs2cs_args,
             sonpath, p, mode, nn, numstdevs, c, dx,
             use_uncorrected):  #dogrid, influence,dowrite,
    """Grid merged sidescan intensities and write point-cloud and map outputs.

    Parameters:
        e, n, t, d: passed (with ``d`` squeezed) to ``getXY`` to obtain the
            per-sample X, Y, D, h, t arrays; ``len(n)`` also truncates the
            number of columns of the merged scan. Presumably easting,
            northing, heading and depth — TODO confirm against ``getXY``.
        dat_port, dat_star: 2-D intensity arrays, stacked vertically with
            ``dat_port`` on top. NaNs and negative values become 0.
        data_R: 2-D array used for the per-sample resolution grid
            ``sqrt(data_R * dx * c * 0.007 / 2)``; its R column is written
            as ``flipud(data_R)`` stacked over ``data_R``.
        pix_m: pixel size; across-track positions span pix_m .. extent*pix_m.
        res: grid resolution; 99 means half the smallest positive value of
            the resolution grid.
        cs2cs_args: projection string for ``pyproj.Proj(init=...)``; the part
            after ':' is used as the Basemap EPSG code (e.g. "epsg:26949").
        sonpath: output directory.
        p: suffix appended to every output file name.
        mode: 1 nearest neighbour, 2 inverse-distance (1/r**2) weighting,
            3 Gaussian weighting (all via ``pyresample.kd_tree``).
        nn: neighbour count for modes 2 and 3; cells with 0 < counts < nn
            are blanked.
        numstdevs: stdev threshold applied in modes 2 and 3.
        c, dx: factors of the resolution grid (see ``data_R``).
        use_uncorrected: currently unused.

    Files written under ``sonpath``: ``x_y_ss_raw<p>.asc`` (X, Y, intensity,
    D, R, h, t columns), ``map<p>.png``, ``legend<p>.png``,
    ``GroundOverlay<p>.kmz`` and ``map_imagery<p>`` (via ``custom_save2``).
    Failures while building the KMZ are reported and otherwise ignored.

    Returns:
        The grid resolution returned by ``getmesh``.

    Raises:
        ValueError: if ``mode`` is not 1, 2 or 3, or no valid samples remain
            (raised after the point cloud file is written).
    """
    thres = 5  # gridded cells below this intensity are blanked

    trans = pyproj.Proj(init=cs2cs_args)

    merge = np.vstack((dat_port, dat_star))
    del dat_port, dat_star

    merge[np.isnan(merge)] = 0
    merge = merge[:, :len(n)]

    ## actual along-track resolution is this: dx times dy = Af
    tmp = data_R * dx * (c * 0.007 / 2)  #dx = np.arcsin(c/(1000*meta['t']*meta['f']))
    res_grid = np.sqrt(np.vstack((tmp, tmp)))
    del tmp
    res_grid = res_grid[:np.shape(merge)[0], :np.shape(merge)[1]]
    res_grid = res_grid.astype('float32')

    merge[merge < 0] = 0
    merge = merge.astype('float32')

    R = np.vstack((np.flipud(data_R), data_R))
    del data_R
    R = R[:np.shape(merge)[0], :np.shape(merge)[1]]

    # get number pixels in scan line
    extent = int(np.shape(merge)[0] / 2)

    yvec = np.squeeze(
        np.linspace(np.squeeze(pix_m), extent * np.squeeze(pix_m), extent))

    X, Y, D, h, t = getXY(e, n, yvec, np.squeeze(d), t, extent)

    X, Y, D, h, t = [a.astype('float32') for a in (X, Y, D, h, t)]

    D[np.isnan(D)] = 0
    h[np.isnan(h)] = 0
    t[np.isnan(t)] = 0

    # Drop samples without a position, then samples with a non-finite
    # intensity. Each index set is computed once and applied to every array
    # so all output columns stay aligned sample-for-sample.
    X, Y, D, h, t = [np.ravel(a) for a in (X, Y, D, h, t)]
    keep = np.flatnonzero(~(np.isnan(X) | np.isnan(Y)))
    X, Y, D, h, t = [a[keep] for a in (X, Y, D, h, t)]
    merge = merge.flatten()[keep]
    res_grid = res_grid.flatten()[keep]
    R = R.flatten()[keep]

    keep = np.flatnonzero(np.isfinite(merge))
    X, Y, merge, res_grid, D, R, h, t = [
        a[keep] for a in (X, Y, merge, res_grid, D, R, h, t)
    ]
    del keep

    print("writing point cloud")
    ## write raw bs to file
    outfile = os.path.normpath(
        os.path.join(sonpath, 'x_y_ss_raw' + str(p) + '.asc'))
    np.savetxt(outfile,
               np.hstack([
                   humutils.ascol(a.flatten())
                   for a in (X, Y, merge, D, R, h, t)
               ]),
               fmt="%8.6f %8.6f %8.6f %8.6f %8.6f %8.6f %8.6f")

    del D, R, h, t

    if mode not in (1, 2, 3):
        raise ValueError(
            "mode must be 1 (nearest), 2 (inverse distance) or 3 (gaussian), "
            "got %r" % (mode,))
    if X.size == 0:
        # Without this check the mesh loop below would retry forever.
        raise ValueError("no valid samples left to grid")

    sigmas = 0.1  #m
    eps = 2

    if res == 99:
        resg = np.min(res_grid[res_grid > 0]) / 2
    else:
        resg = res

    xmin, xmax = np.min(X), np.max(X)
    ymin, ymax = np.min(Y), np.max(Y)

    tree = KDTree(np.c_[X.flatten(), Y.flatten()])
    complete = 0
    while complete == 0:
        try:
            grid_x, grid_y, res = getmesh(xmin, xmax, ymin, ymax, resg)
            longrid, latgrid = trans(grid_x, grid_y, inverse=True)
            longrid = longrid.astype('float32')
            latgrid = latgrid.astype('float32')
            shape = np.shape(grid_x)

            ## create mask for where the data is not: distance from every
            ## grid node to its nearest sample
            query_pts = np.c_[grid_x.ravel(), grid_y.ravel()]
            if pykdtree == 1:
                dist, _ = tree.query(query_pts, k=1)
            else:
                try:
                    dist, _ = tree.query(query_pts, k=1, n_jobs=cpu_count())
                except Exception:
                    # older/newer KD-tree implementations may reject n_jobs
                    dist, _ = tree.query(query_pts, k=1)
            del query_pts

            dist = dist.reshape(grid_x.shape)

            targ_def = pyresample.geometry.SwathDefinition(
                lons=longrid.flatten(), lats=latgrid.flatten())
            del longrid, latgrid

            humlon, humlat = trans(X, Y, inverse=True)
            orig_def = pyresample.geometry.SwathDefinition(
                lons=humlon.flatten(), lats=humlat.flatten())
            del humlon, humlat
            complete = 1
        except Exception:
            # Typically a MemoryError: retry with a coarser grid.
            print("memory error: trying grid resolution of %s" % (str(
                resg * 2)))
            resg = resg * 2

    radius = res * 20

    # NOTE(review): every fallback branch below deletes grid_x/grid_y, which
    # are needed again for glon/glat further down, and getgrid_lm's grid need
    # not match `dist`'s shape. The fallback therefore cannot complete as
    # written; confirm getgrid_lm's outputs before relying on it.
    if mode == 1:
        complete = 0
        while complete == 0:
            try:
                dat = _resample_with_fallback(
                    pyresample.kd_tree.resample_nearest,
                    orig_def,
                    merge.flatten(),
                    targ_def,
                    radius_of_influence=radius,
                    fill_value=None)
                r_dat = _resample_with_fallback(
                    pyresample.kd_tree.resample_nearest,
                    orig_def,
                    res_grid.flatten(),
                    targ_def,
                    radius_of_influence=radius,
                    fill_value=None)
                stdev = None
                counts = None
                complete = 1
            except Exception:
                del grid_x, grid_y, targ_def, orig_def

                wf = None
                humlon, humlat = trans(X, Y, inverse=True)
                dat, stdev, counts, resg, complete, shape = getgrid_lm(
                    humlon, humlat, merge, radius, xmin, xmax, ymin, ymax,
                    resg * 2, mode, trans, nn, wf, sigmas, eps)
                r_dat, stdev_null, counts_null, resg, complete, shape = getgrid_lm(
                    humlon, humlat, res_grid, radius, xmin, xmax, ymin, ymax,
                    resg * 2, mode, trans, nn, wf, sigmas, eps)
                del humlon, humlat, stdev_null, counts_null

    elif mode == 2:

        # custom inverse distance
        def wf(r):
            return 1 / r**2

        complete = 0
        while complete == 0:
            try:
                dat, stdev, counts = _resample_with_fallback(
                    pyresample.kd_tree.resample_custom,
                    orig_def,
                    merge.flatten(),
                    targ_def,
                    radius_of_influence=radius,
                    neighbours=nn,
                    weight_funcs=wf,
                    fill_value=None,
                    with_uncert=True)
                r_dat = _resample_with_fallback(
                    pyresample.kd_tree.resample_custom,
                    orig_def,
                    res_grid.flatten(),
                    targ_def,
                    radius_of_influence=radius,
                    neighbours=nn,
                    weight_funcs=wf,
                    fill_value=None,
                    with_uncert=False)
                complete = 1
            except Exception:
                del grid_x, grid_y, targ_def, orig_def
                humlon, humlat = trans(X, Y, inverse=True)
                # NOTE(review): this fallback uses res * 2 while modes 1 and
                # 3 use res * 20 — confirm which radius is intended.
                dat, stdev, counts, resg, complete, shape = getgrid_lm(
                    humlon, humlat, merge, res * 2, xmin, xmax, ymin, ymax,
                    resg * 2, mode, trans, nn, wf, sigmas, eps)
                r_dat, stdev_null, counts_null, resg, complete, shape = getgrid_lm(
                    humlon, humlat, res_grid, res * 2, xmin, xmax, ymin, ymax,
                    resg * 2, mode, trans, nn, wf, sigmas, eps)
                del humlat, humlon, stdev_null, counts_null

    else:  # mode == 3
        wf = None

        complete = 0
        while complete == 0:
            try:
                dat, stdev, counts = _resample_with_fallback(
                    pyresample.kd_tree.resample_gauss,
                    orig_def,
                    merge.flatten(),
                    targ_def,
                    radius_of_influence=radius,
                    neighbours=nn,
                    sigmas=sigmas,
                    fill_value=None,
                    with_uncert=True,
                    epsilon=eps)
                r_dat = _resample_with_fallback(
                    pyresample.kd_tree.resample_gauss,
                    orig_def,
                    res_grid.flatten(),
                    targ_def,
                    radius_of_influence=radius,
                    neighbours=nn,
                    sigmas=sigmas,
                    fill_value=None,
                    with_uncert=False,
                    epsilon=eps)
                complete = 1
            except Exception:
                del grid_x, grid_y, targ_def, orig_def
                humlon, humlat = trans(X, Y, inverse=True)
                dat, stdev, counts, resg, complete, shape = getgrid_lm(
                    humlon, humlat, merge, radius, xmin, xmax, ymin, ymax,
                    resg * 2, mode, trans, nn, wf, sigmas, eps)
                r_dat, stdev_null, counts_null, resg, complete, shape = getgrid_lm(
                    humlon, humlat, res_grid, radius, xmin, xmax, ymin, ymax,
                    resg * 2, mode, trans, nn, wf, sigmas, eps)
                del humlat, humlon, stdev_null, counts_null

    humlon, humlat = trans(X, Y, inverse=True)
    del X, Y, res_grid, merge

    dat = dat.reshape(shape)

    # Blank grid cells farther than 10 resolutions from the nearest sample.
    dat[dist > res * 10] = np.nan
    del dist

    r_dat = r_dat.reshape(shape)
    r_dat[r_dat < 1] = 1
    r_dat[r_dat > 2 * np.pi] = 1
    r_dat[np.isnan(dat)] = np.nan

    dat = dat + r_dat  #np.sqrt(np.cos(np.deg2rad(r_dat))) #dat*np.sqrt(r_dat) + dat

    del r_dat

    if mode > 1:
        stdev = stdev.reshape(shape)
        counts = counts.reshape(shape)

    mask = dat.mask.copy()

    dat[mask == 1] = np.nan

    if mode > 1:
        # NOTE(review): cells with mask != 0 were already set to NaN just
        # above, so this stdev filter has no effect as written; `mask == 0`
        # may have been intended — confirm.
        dat[(stdev > numstdevs) & (mask != 0)] = np.nan
        dat[(counts < nn) & (counts > 0)] = np.nan

    dat[dat == 0] = np.nan
    dat[np.isinf(dat)] = np.nan

    dat[dat < thres] = np.nan

    datm = np.ma.masked_invalid(dat)

    glon, glat = trans(grid_x, grid_y, inverse=True)
    del grid_x, grid_y

    try:

        # =========================================================
        print("creating kmz file ...")
        ## new way to create kml file
        pixels = 1024 * 10

        map_fig, map_ax = humutils.gearth_fig(llcrnrlon=glon.min(),
                                              llcrnrlat=glat.min(),
                                              urcrnrlon=glon.max(),
                                              urcrnrlat=glat.max(),
                                              pixels=pixels)
        cs = map_ax.pcolormesh(glon, glat, datm, cmap='gray')
        map_ax.set_axis_off()
        map_fig.savefig(os.path.normpath(
            os.path.join(sonpath, 'map' + str(p) + '.png')),
                        transparent=True,
                        format='png')

        # =========================================================
        legend_fig = plt.figure(figsize=(1.0, 4.0),
                                facecolor=None,
                                frameon=False)
        legend_ax = legend_fig.add_axes([0.0, 0.05, 0.2, 0.9])
        cb = legend_fig.colorbar(cs, cax=legend_ax)
        cb.set_label('Intensity [dB W]', rotation=-90, color='k', labelpad=20)
        legend_fig.savefig(os.path.normpath(
            os.path.join(sonpath, 'legend' + str(p) + '.png')),
                           transparent=False,
                           format='png')
        # pyplot keeps figures it manages alive until they are closed.
        plt.close(legend_fig)
        plt.close(map_fig)
        del map_fig, map_ax, legend_fig, legend_ax, cs, cb

        # =========================================================
        humutils.make_kml(
            llcrnrlon=glon.min(),
            llcrnrlat=glat.min(),
            urcrnrlon=glon.max(),
            urcrnrlat=glat.max(),
            figs=[
                os.path.normpath(os.path.join(sonpath,
                                              'map' + str(p) + '.png'))
            ],
            colorbar=os.path.normpath(
                os.path.join(sonpath, 'legend' + str(p) + '.png')),
            kmzfile=os.path.normpath(
                os.path.join(sonpath, 'GroundOverlay' + str(p) + '.kmz')),
            name='Sidescan Intensity')

    except Exception:
        print("error: map could not be created...")

    print("drawing and printing map ...")
    fig = plt.figure(frameon=False)
    bmap = Basemap(
        projection='merc',
        epsg=cs2cs_args.split(':')[1],
        resolution='i',  #h #f
        llcrnrlon=np.min(humlon) - 0.001,
        llcrnrlat=np.min(glat) - 0.001,
        urcrnrlon=np.max(humlon) + 0.001,
        urcrnrlat=np.max(glat) + 0.001)

    try:
        bmap.arcgisimage(server='http://server.arcgisonline.com/ArcGIS',
                         service='World_Imagery',
                         xpixels=1000,
                         ypixels=None,
                         dpi=300)
    except Exception:
        # fall back to the legacy imagery service
        bmap.arcgisimage(server='http://server.arcgisonline.com/ArcGIS',
                         service='ESRI_Imagery_World_2D',
                         xpixels=1000,
                         ypixels=None,
                         dpi=300)

    gx, gy = bmap.projtran(glon, glat)

    ax = plt.Axes(
        fig,
        [0., 0., 1., 1.],
    )
    ax.set_axis_off()
    fig.add_axes(ax)

    # Decimate very large grids for display only; the colour limits are
    # still taken from the full grid.
    if datm.size > 25000000:
        print("matrix size > 25,000,000 - decimating by factor of 5 for display")
        step = 5
    else:
        step = 1
    bmap.pcolormesh(gx[::step, ::step],
                    gy[::step, ::step],
                    datm[::step, ::step],
                    cmap='gray',
                    vmin=np.nanmin(datm),
                    vmax=np.nanmax(datm))
    del datm, dat

    custom_save2(sonpath, 'map_imagery' + str(p))
    plt.close(fig)
    del fig

    del humlat, humlon
    return res  #return the new resolution
Exemple #15
0
def vis_one_image(im,
                  im_name,
                  boxes,
                  segms=None,
                  keypoints=None,
                  thresh=0.9,
                  kp_thresh=2,
                  dpi=200,
                  box_alpha=0.0,
                  dataset=None,
                  show_class=False,
                  ext=None):
    if isinstance(boxes, list):
        boxes, segms, keypoints, classes = convert_from_cls_format(
            boxes, segms, keypoints)
    if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
        return
    dataset_keypoints, _ = keypoint_utils.get_keypoints()
    if segms is not None and len(segms) > 0:
        masks = mask_util.decode(segms)
    color_list = colormap(rgb=True) / 255
    kp_lines = kp_connections(dataset_keypoints)
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]
    fig = plt.figure(frameon=False)
    fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(im)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_inds = np.argsort(-areas)
    mask_color_id = 0
    C = []
    for i in sorted_inds:
        bbox = boxes[i, :4]
        score = boxes[i, -1]
        if score < thresh:
            continue
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1],
                          fill=False,
                          edgecolor='g',
                          linewidth=0.5,
                          alpha=box_alpha))
        if show_class:
            ax.text(bbox[0],
                    bbox[1] - 2,
                    get_class_string(classes[i], score, dataset),
                    fontsize=3,
                    family='serif',
                    bbox=dict(facecolor='g',
                              alpha=0.4,
                              pad=0,
                              edgecolor='none'),
                    color='white')
        if segms is not None and len(segms) > i:
            #print ("duoshaogeIIIIIIIIIIIIIIIIIIIIIIIIIII")
            #print (i)
            #print ("len(segms)*********")
            #print (len(segms))
            img = np.ones(im.shape)
            color_mask = color_list[mask_color_id % len(color_list), 0:3]
            mask_color_id += 1
            w_ratio = .4
            for c in range(3):
                color_mask[c] = color_mask[c] * (1 - w_ratio) + w_ratio
            for c in range(3):
                img[:, :, c] = color_mask[c]
            e = masks[:, :, i]
            _, contour, hier = cv2.findContours(e.copy(), cv2.RETR_CCOMP,
                                                cv2.CHAIN_APPROX_NONE)
            for c in contour:
                polygon = Polygon(c.reshape((-1, 2)),
                                  fill=True,
                                  facecolor=color_mask,
                                  edgecolor='w',
                                  linewidth=0.2,
                                  alpha=0.5)
                ddd = list(c)
                woca = c.reshape((-1, 2))
                arr = np.array(woca)
                key = np.unique(woca)
                result = {}
                k_qu = []
                arr_end = []
                arr_end2 = []
                arr_endlast = []
                A = []
                B = []
                Bplus = []
                jjj = []
                Bwanmei = []
                Bcao = []
                Bao = []
                for k in key:
                    mask = (arr == k)
                    arr_new = arr[mask]
                    v = arr_new.size
                    result[k] = v
                    x = np.argwhere(arr == k)
                    if v > 1:
                        x = np.argwhere(arr == k)
                        x = np.array(x)
                        x0 = arr[:, 0]
                        y0 = arr[:, 1]
                        y0lie = []
                        for i in range(0, len(x0)):
                            if x0[i] != k:
                                pass
                            if x0[i] == k:
                                y0lie.append(y0[i])
                        y0lienew = []
                        arr_new = []
                        arr_new_2 = []
                        if y0lie == []:
                            pass
                        else:
                            miny0 = np.min(y0lie)
                            maxy0 = np.max(y0lie)
                            for i in range(miny0, maxy0 + 1):
                                y0lienew.append(i)
                            y0liezuizhong = []
                            if y0lienew == []:
                                pass
                            else:
                                miny0lienew = np.min(y0lienew)
                                maxy0lienew = np.max(y0lienew)
                                for i in range(miny0lienew, maxy0lienew + 1):
                                    y0liezuizhong.append(i)
                            for i in range(0, len(y0liezuizhong)):
                                arr_temp = [k, y0liezuizhong[i]]
                                arr_new.append(arr_temp)
                            arr_end.append(arr_new)
                        x0lie = []
                        for i in range(0, len(y0)):
                            if y0[i] != k:
                                pass
                            if y0[i] == k:
                                x0lie.append(x0[i])
                        x0lienew = []
                        arr_new2 = []
                        if x0lie == []:
                            pass
                        else:
                            minx0 = np.min(x0lie)
                            maxx0 = np.max(x0lie)
                            for i in range(minx0, maxx0 + 1):
                                x0lienew.append(i)
                            x0liezuizhong = []
                            if x0lienew == []:
                                pass
                            else:
                                minx0lienew = np.min(x0lienew)
                                maxx0lienew = np.max(x0lienew)
                                for i in range(minx0lienew, maxx0lienew + 1):
                                    x0liezuizhong.append(i)
                            for i in range(0, len(x0liezuizhong)):
                                arr_temp = [x0liezuizhong[i], k]
                                arr_new2.append(arr_temp)
                            arr_end2.append(arr_new2)
            arr_endlast = arr_end + arr_end2 + ddd
            A = list(chain(*arr_endlast))
            B = np.array(list(set([tuple(t) for t in A])))

            if len(B) > 10:
                Bplus = random.sample(B, 5)
            if len(B) < 10:
                jjj = arr_endlast
            Bwanmei = Bplus + jjj
            Bcaotmp = list(chain(*Bwanmei))
            Bcao = np.array(Bcaotmp)
            Bao = Bcao.reshape(-1, 2)
            C.append(Bao)

        if keypoints is not None and len(keypoints) > i:
            kps = keypoints[i]
            plt.autoscale(False)
            for l in range(len(kp_lines)):
                i1 = kp_lines[l][0]
                i2 = kp_lines[l][1]
                if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
                    x = [kps[0, i1], kps[0, i2]]
                    y = [kps[1, i1], kps[1, i2]]
                    line = plt.plot(x, y)
                    plt.setp(line, color=colors[l], linewidth=1.0, alpha=0.7)
                if kps[2, i1] > kp_thresh:
                    plt.plot(kps[0, i1],
                             kps[1, i1],
                             '.',
                             color=colors[l],
                             markersize=3.0,
                             alpha=0.7)
                if kps[2, i2] > kp_thresh:
                    plt.plot(kps[0, i2],
                             kps[1, i2],
                             '.',
                             color=colors[l],
                             markersize=3.0,
                             alpha=0.7)
            mid_shoulder = (
                kps[:2, dataset_keypoints.index('right_shoulder')] +
                kps[:2, dataset_keypoints.index('left_shoulder')]) / 2.0
            sc_mid_shoulder = np.minimum(
                kps[2, dataset_keypoints.index('right_shoulder')],
                kps[2, dataset_keypoints.index('left_shoulder')])
            mid_hip = (kps[:2, dataset_keypoints.index('right_hip')] +
                       kps[:2, dataset_keypoints.index('left_hip')]) / 2.0
            sc_mid_hip = np.minimum(
                kps[2, dataset_keypoints.index('right_hip')],
                kps[2, dataset_keypoints.index('left_hip')])
            if (sc_mid_shoulder > kp_thresh
                    and kps[2, dataset_keypoints.index('nose')] > kp_thresh):
                x = [mid_shoulder[0], kps[0, dataset_keypoints.index('nose')]]
                y = [mid_shoulder[1], kps[1, dataset_keypoints.index('nose')]]
                line = plt.plot(x, y)
                plt.setp(line,
                         color=colors[len(kp_lines)],
                         linewidth=1.0,
                         alpha=0.7)
            if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
                x = [mid_shoulder[0], mid_hip[0]]
                y = [mid_shoulder[1], mid_hip[1]]
                line = plt.plot(x, y)
                plt.setp(line,
                         color=colors[len(kp_lines) + 1],
                         linewidth=1.0,
                         alpha=0.7)
    plt.close('all')
    C = np.array(C)
    return C
def plot_rms(rms,
             output='rms.png',
             mode='both',
             scale=1.0,
             cmap='RdBu',
             cbar_size=0.8,
             cbar_sens=1.0,
             vmin=None,
             vmax=None,
             cbar_pad=0.01,
             pct_ticks=False,
             **kwargs):
    """Render a 2-D error/RMS map as a borderless heatmap and save it.

    Args:
        rms: 2-D array; ``rms.shape`` (rows, cols) sets the figure size.
        output: path passed to ``fig.savefig``.
        mode: 'both' saves heatmap plus colorbar (tight bbox), 'error_only'
            saves only the heatmap, 'cbar_only' saves only the colorbar.
        scale: factor applied to ``rms`` before plotting and before the
            automatic color limits are derived.
        cmap: name of the base colormap; it is rebuilt as a three-stop ramp
            (low end, white, high end), so it suits diverging maps.
        cbar_size: colorbar height as a fraction of the figure height.
        cbar_sens: ``sensitivity`` passed to ``MidpointNormalize``.
        vmin, vmax: color limits. When both are None they are derived from
            the scaled data: symmetric around 0 if it changes sign,
            otherwise spanning from 0 to the data extreme.
        cbar_pad: unused.
        pct_ticks: format colorbar ticks with ``pct_fmt``.
        **kwargs: unused.

    Raises:
        ValueError: if ``mode`` is not one of the three values above.
    """
    if mode not in ('both', 'error_only', 'cbar_only'):
        raise ValueError(
            "mode must be 'both', 'error_only' or 'cbar_only', got %r" %
            (mode,))

    orig_cmap = plt.get_cmap(cmap)
    # Three-stop ramp with white at the midpoint of the chosen colormap.
    rms_cmap = colors.LinearSegmentedColormap.from_list(
        name='rms_cmap',
        N=255,
        colors=[orig_cmap(0), (1, 1, 1), orig_cmap(255)])

    # Limits must describe what is actually drawn, i.e. the scaled data.
    data = scale * rms
    if (vmin is None) and (vmax is None):
        lo, hi = data.min(), data.max()
        if sign(lo) != sign(hi):
            absmax = max(-lo, hi)
            vmin, vmax = -absmax, absmax
        elif sign(lo) == +1:
            vmin, vmax = 0.0, hi
        else:
            vmin, vmax = lo, 0.0

    # With a figure of rms.shape / dpi inches saved at the same dpi (without a
    # tight bbox), each rms cell maps to one output pixel.
    dpi = 100.0
    fig = plt.figure(figsize=(rms.shape[1] / dpi, rms.shape[0] / dpi),
                     frameon=False)
    try:
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)

        # Colorbar axes placed just right of the full-figure heatmap axes.
        ax2 = plt.Axes(fig, [1.0, 0.5 * (1 - cbar_size), 0.05, cbar_size])
        fig.add_axes(ax2)

        sns.heatmap(data,
                    norm=MidpointNormalize(midpoint=0.,
                                           vmin=vmin,
                                           vmax=vmax,
                                           sensitivity=cbar_sens),
                    vmin=vmin,
                    vmax=vmax,
                    cbar=True,
                    cmap=rms_cmap,
                    cbar_ax=ax2,
                    annot=False,
                    xticklabels=False,
                    yticklabels=False,
                    ax=ax,
                    cbar_kws={'format': ticker.FuncFormatter(pct_fmt)}
                    if pct_ticks else {})

        if mode == 'both':
            fig.savefig(output, dpi=dpi, bbox_inches='tight')
        elif mode == 'error_only':
            ax2.remove()
            fig.savefig(output, dpi=dpi)
        else:  # 'cbar_only'
            ax.remove()
            fig.savefig(output, bbox_inches='tight')
    finally:
        # The figure is never returned; release it so repeated calls don't
        # accumulate open figures.
        plt.close(fig)
Exemple #17
0
 def gen_land_bitmap(bmap, resolution_meters):
     """Rasterize the land polygons of a basemap into a bitmap.

     The joint bounding box of all land polygons is snapped outward to
     multiples of ``resolution_meters``, the polygons are drawn in black with
     the Agg backend, and the green channel of the rendered image is kept as
     ``data``. The figure is 1 inch high, so its dpi equals the number of
     rows; while that would exceed 10000, ``resolution_meters`` is doubled.

     Args:
         bmap: object exposing ``landpolygons``, each with ``get_coords()``
             returning a 2-D array whose columns 0 and 1 are x and y.
         resolution_meters: requested cell size in the polygons' coordinate
             units (presumably metres, as the name says).

     Returns:
         RasterizedBasemap(xmin, xmax, ymin, ymax, resolution_meters, data),
         where ``resolution_meters`` is the possibly coarsened value.

     Raises:
         Exception: if ``bmap`` has no land polygons, or if rasterizing runs
             out of memory.
     """
     # Collect land polygons and their joint bounding box
     polys = []
     xmin = np.finfo(np.float64).max
     xmax = -np.finfo(np.float64).max
     ymin = xmin
     ymax = xmax

     logging.debug('Rasterizing Basemap, number of land polys: ' + str(len(bmap.landpolygons)))
     # Without polygons there is nothing to rasterize
     if len(bmap.landpolygons) == 0:
         raise Exception('Basemap contains no land polys to rasterize')

     for polygon in bmap.landpolygons:
         coords = polygon.get_coords()
         xmin = min(xmin, np.min(coords[:, 0]))
         xmax = max(xmax, np.max(coords[:, 0]))
         ymin = min(ymin, np.min(coords[:, 1]))
         ymax = max(ymax, np.max(coords[:, 1]))
         polys.append(coords)

     # Snap the bounding box outward onto the resolution grid
     xmin = np.floor(xmin / resolution_meters) * resolution_meters
     xmax = np.ceil(xmax / resolution_meters) * resolution_meters
     ymin = np.floor(ymin / resolution_meters) * resolution_meters
     ymax = np.ceil(ymax / resolution_meters) * resolution_meters

     logging.debug('Rasterizing Basemap, bounding box: ' + str([xmin, xmax, ymin, ymax]))

     # The figure is `aspect` inches wide and 1 inch high, so the dpi equals
     # the number of raster rows; coarsen the resolution until it is sane.
     aspect = (xmax - xmin) / (ymax - ymin)
     resolution_dpi = (ymax - ymin) / resolution_meters
     while resolution_dpi > 10000:
         logging.debug('Too large dpi %s, reducing by factor of 2'
                       % resolution_dpi)
         resolution_meters = resolution_meters * 2
         resolution_dpi = (ymax - ymin) / resolution_meters

     # Switch to the non-interactive Agg backend to prevent an empty figure
     # from appearing (e.g. in a notebook); always restore the caller's one.
     orig_backend = plt.get_backend()
     plt.switch_backend('agg')
     try:
         fig = plt.figure(frameon=False)
         try:
             ax = plt.Axes(fig, [0., 0., 1., 1.])
             ax.set_axis_off()
             fig.add_axes(ax)
             ax.set_xlim(xmin, xmax)
             ax.set_ylim(ymin, ymax)

             fig.set_dpi(resolution_dpi)
             fig.set_size_inches(aspect, 1)

             ax.add_collection(PolyCollection(polys, facecolor='k', lw=0))

             canvas = FigureCanvasAgg(fig)
             canvas.draw()
             width, height = canvas.get_width_height()
             # buffer_rgba() replaces tostring_rgb() (removed in Matplotlib
             # 3.10); its green channel is the same. Copy so the result does
             # not keep a view into the renderer's buffer.
             rgba = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
             data = rgba.reshape(height, width, 4)[:, :, 1].copy()
             logging.debug('Rasterized size: ' + str([width, height]))
         except MemoryError:
             gc.collect()
             raise Exception('Basemap rasterized size too large: '
                             + str(aspect*resolution_dpi) + '*' + str(resolution_dpi)
                             + ' cells')
         finally:
             # Close on every path, not only on success.
             plt.close(fig)
     finally:
         plt.switch_backend(orig_backend)

     return RasterizedBasemap(xmin, xmax, ymin, ymax, resolution_meters, data)
Exemple #18
0
def create_raster_plot_combined(trials,
                                align_event,
                                sorting_var='trial_id',
                                x_lim=[-1, 1],
                                show_plot=False,
                                fig_dir=None,
                                store_type=None):

    sorting_query, mark, label = get_sort_and_marker(align_event, sorting_var)

    fig = plt.figure(dpi=150, frameon=False, figsize=[10, 5])
    ax = plt.Axes(fig, [0., 0., 1., 1.])

    if len(trials):
        if sorting_var == 'trial_id':
            spk_times, trial_ids = (trials
                                    & 'event="{}"'.format(align_event)).fetch(
                                        'trial_spike_times',
                                        'trial_id',
                                        order_by='trial_id')
            spk_trial_ids = np.hstack(
                [[trial_id] * len(spk_time)
                 for trial_id, spk_time in enumerate(spk_times)])
            ax.plot(np.hstack(spk_times),
                    spk_trial_ids,
                    'k.',
                    alpha=0.5,
                    markeredgewidth=0)
        elif sorting_var == 'contrast':
            spk_times, trial_contrasts = (
                trials & 'event="{}"'.format(align_event)).fetch(
                    'trial_spike_times',
                    'trial_signed_contrast',
                    order_by='trial_signed_contrast, trial_id')
            spk_trial_ids = np.hstack(
                [[trial_id] * len(spk_time)
                 for trial_id, spk_time in enumerate(spk_times)])
            ax.plot(np.hstack(spk_times),
                    spk_trial_ids,
                    'k.',
                    alpha=0.5,
                    markeredgewidth=0)

            # plot different contrasts as background
            contrasts, u_inds = np.unique(trial_contrasts, return_index=True)
            u_inds = list(u_inds) + [len(trial_contrasts)]

            tick_positions = np.add(u_inds[1:], u_inds[:-1]) / 2

            puor = cl.scales[str(len(contrasts))]['div']['PuOr']
            puor = np.divide(cl.to_numeric(puor), 255)

            for i, ind in enumerate(u_inds[:-1]):
                ax.fill_between([-1, 1],
                                u_inds[i],
                                u_inds[i + 1] - 1,
                                color=puor[i],
                                alpha=0.8)
            fig.add_axes(ax)
        elif sorting_var == 'feedback type':
            spk_times, trial_fb_types = (
                trials & 'event="{}"'.format(align_event)).fetch(
                    'trial_spike_times',
                    'trial_feedback_type',
                    order_by='trial_feedback_type, trial_id')
            spk_trial_ids = np.hstack(
                [[trial_id] * len(spk_time)
                 for trial_id, spk_time in enumerate(spk_times)])
            ax.plot(np.hstack(spk_times),
                    spk_trial_ids,
                    'k.',
                    alpha=0.5,
                    markeredgewidth=0)

            # plot different feedback types as background
            fb_types, u_inds = np.unique(trial_fb_types, return_index=True)
            u_inds = list(u_inds) + [len(trial_fb_types)]

            colors = sns.diverging_palette(10, 240, n=len(fb_types))

            for i, ind in enumerate(u_inds[:-1]):
                ax.fill_between([-1, 1],
                                u_inds[i],
                                u_inds[i + 1] - 1,
                                color=colors[i],
                                alpha=0.5)
            fig.add_axes(ax)
        else:
            spk_times_left, marking_points_left, \
                spk_times_right, marking_points_right, \
                spk_times_incorrect, marking_points_incorrect = \
                get_spike_times_trials(
                    trials, sorting_var, align_event, sorting_query, mark)

            id_gap = len(trials) * 0

            if len(spk_times_incorrect):
                spk_times_all_incorrect = np.hstack(spk_times_incorrect)
                id_incorrect = [
                    [i] * len(spike_time)
                    for i, spike_time in enumerate(spk_times_incorrect)
                ]
                id_incorrect = np.hstack(id_incorrect)
                ax.plot(spk_times_all_incorrect,
                        id_incorrect,
                        'r.',
                        alpha=0.5,
                        markeredgewidth=0,
                        label='incorrect trials')
                ax.plot(marking_points_incorrect,
                        range(len(spk_times_incorrect)),
                        'r',
                        label=label)
            else:
                id_incorrect = [0]

            if not len(id_incorrect):
                id_incorrect = [0]

            if len(spk_times_left):
                spk_times_all_left = np.hstack(spk_times_left)
                id_left = [[i + max(id_incorrect) + id_gap] * len(spike_time)
                           for i, spike_time in enumerate(spk_times_left)]
                id_left = np.hstack(id_left)
                ax.plot(spk_times_all_left,
                        id_left,
                        'g.',
                        alpha=0.5,
                        markeredgewidth=0,
                        label='left trials')
                ax.plot(
                    marking_points_left,
                    np.add(range(len(spk_times_left)),
                           max(id_incorrect) + id_gap), 'g')
            else:
                id_left = [max(id_incorrect)]

            if not len(id_left):
                id_left = [max(id_incorrect)]

            if len(spk_times_right):
                spk_times_all_right = np.hstack(spk_times_right)
                id_right = [[i + max(id_left) + id_gap] * len(spike_time)
                            for i, spike_time in enumerate(spk_times_right)]
                id_right = np.hstack(id_right)

                ax.plot(spk_times_all_right,
                        id_right,
                        'b.',
                        alpha=0.5,
                        markeredgewidth=0,
                        label='right trials')
                ax.plot(
                    marking_points_right,
                    np.add(range(len(spk_times_right)),
                           max(id_left) + id_gap), 'b')
            else:
                id_right = [max(id_left)]

            if not len(id_right):
                id_right = [max(id_left)]

    ax.set_axis_off()
    fig.add_axes(ax)

    # hide the axis
    ax.get_xaxis().set_visible(False)
    ax.get_yaxis().set_visible(False)

    # set the limits
    ax.set_xlim(x_lim[0], x_lim[1])
    if sorting_var in ('trial_id', 'contrast', 'feedback type'):
        if len(spk_trial_ids):
            y_lim = max(spk_trial_ids) * 1.02
        else:
            y_lim = 2
    else:
        y_lim = max(id_right) * 1.02
    ax.set_ylim(-2, y_lim)

    if not show_plot:
        plt.close(fig)

    # save the figure with `pad_inches=0` to remove
    # any padding in the image

    if fig_dir:
        store_fig_external(fig, store_type, fig_dir)
        fig.clear()
        gc.collect()
        if sorting_var == 'contrast':
            return [0, y_lim], label, contrasts, tick_positions
        else:
            return [0, y_lim], label
    else:
        encoded_string = convert_fig_to_encoded_string(fig)
        if sorting_var == 'contrast':
            return encoded_string, [0, y_lim], label, contrasts, tick_positions
        else:
            return encoded_string, [0, y_lim], label
Exemple #19
0
    def plot(self,
             figure=None,
             overlays=None,
             draw_limb=True,
             gamma=None,
             draw_grid=False,
             colorbar=True,
             basic_plot=False,
             **matplot_args):
        """Plots the map object using matplotlib

        Parameters
        ----------
        figure : matplotlib Figure, optional
            Figure to draw into; a new one is created when None.
        overlays : list
            List of overlays to include in the plot. Each is called as
            ``overlay(figure, axes)`` and must return ``(figure, axes)``.
            The caller's sequence is not modified.
        draw_limb : bool
            Whether the solar limb should be plotted.
        gamma : float
            Gamma value to use for the color map. The colormap (an instance
            or a registered name) is copied before the gamma is applied.
        draw_grid : bool
            Whether solar meridians and parallels should be drawn (with the
            default spacing of ``_draw_grid``).
        colorbar : bool
            Whether to display a colorbar next to the plot (ignored for a
            basic plot).
        basic_plot : bool
            If true, the data is plotted by itself at its natural scale; no
            title, labels, or axes are shown.
        **matplot_args : dict
            Any additional imshow arguments that should be used when
            plotting the image; they override the default ``cmap``/``norm``.

        Returns
        -------
        figure : matplotlib Figure
            The figure that was drawn into.
        """
        # Copy so the caller's list (or tuple) is never touched.
        overlays = [] if overlays is None else list(overlays)
        if draw_limb:
            overlays.append(self._draw_limb)
        # TODO: need to be able to pass the grid spacing to _draw_grid from the
        # plot command.
        if draw_grid:
            overlays.append(self._draw_grid)

        # Create a figure and add title and axes
        if figure is None:
            figure = plt.figure(frameon=not basic_plot)

        if basic_plot:
            # Axes fill the whole figure and carry no decorations.
            axes = plt.Axes(figure, [0., 0., 1., 1.])
            axes.set_axis_off()
            figure.add_axes(axes)
        else:
            axes = figure.add_subplot(111)
            axes.set_title("%s %s" % (self.name, self.date))

            if self.coordinate_system['x'] == 'HG':
                xlabel = 'Longitude [%s]' % self.units['x']
            else:
                xlabel = 'X-position [%s]' % self.units['x']

            if self.coordinate_system['y'] == 'HG':
                ylabel = 'Latitude [%s]' % self.units['y']
            else:
                ylabel = 'Y-position [%s]' % self.units['y']

            axes.set_xlabel(xlabel)
            axes.set_ylabel(ylabel)

        # Determine extent
        extent = self.xrange + self.yrange

        # Matplotlib arguments
        params = {"cmap": self.cmap, "norm": self.norm()}
        params.update(matplot_args)

        if gamma is not None:
            cmap = params['cmap']
            # A colormap given by name has no set_gamma; resolve it first.
            if isinstance(cmap, str):
                cmap = plt.get_cmap(cmap)
            # Copy so the shared colormap instance is not altered.
            params['cmap'] = copy(cmap)
            params['cmap'].set_gamma(gamma)

        im = axes.imshow(self, origin='lower', extent=extent, **params)

        if colorbar and not basic_plot:
            figure.colorbar(im)

        for overlay in overlays:
            figure, axes = overlay(figure, axes)
        return figure
Exemple #20
0
def driftmap(clusters_depths,
             spikes_times,
             spikes_amps,
             spikes_depths,
             spikes_clusters,
             ax=None,
             axesoff=False,
             return_lims=False):
    '''
    Plots the driftmap of a session or a trial.

    The plot shows the spike times vs spike depths.
    Each dot is a spike, whose color indicates the cluster
    and opacity indicates the spike amplitude.

    Parameters
    -------------
    clusters_depths: ndarray
        depths of all clusters; entry j belongs to the j-th smallest cluster
        id in `spikes_clusters`
    spikes_times: ndarray
        spike times of all clusters
    spikes_amps: ndarray
        amplitude of each spike; opacity rises linearly from the 10th to the
        90th amplitude percentile (clipped to [0, 1])
    spikes_depths: ndarray
        depth of each spike
    spikes_clusters: ndarray
        cluster idx of each spike (any order)
    ax: axessubplot (optional)
        The axis handle to plot the driftmap on
        (if `None`, a new figure and axis is created)
    axesoff: bool
        hide the axis decorations
    return_lims: bool
        also return the x/y limits that were applied

    Return
    ---
    ax: axessubplot
    x_lim: list of two elements (only if `return_lims`)
    y_lim: list of two elements (only if `return_lims`)

    '''
    # Rank of each cluster by depth; the rank selects the cluster's color.
    depth_rank = np.argsort(np.argsort(clusters_depths))

    # Position of each spike's cluster among the sorted unique cluster ids,
    # so every spike gets its own cluster's color regardless of spike order.
    _, cluster_pos = np.unique(spikes_clusters, return_inverse=True)
    colors = new_color_bins[np.mod(depth_rank[cluster_pos], 500), :]

    amps = np.asarray(spikes_amps)
    max_amp = np.percentile(amps, 90)
    min_amp = np.percentile(amps, 10)
    if max_amp > min_amp:
        opacity = np.clip((amps - min_amp) / (max_amp - min_amp), 0, 1)
    else:
        # Degenerate range: avoid 0/0 (NaN alpha); spikes at or above the
        # common value are opaque, lower ones transparent.
        opacity = (amps >= min_amp).astype(float)

    colorvec = np.zeros([len(opacity), 4], dtype='float16')
    colorvec[:, 3] = opacity.astype('float16')
    colorvec[:, 0:3] = colors.astype('float16')

    x = spikes_times.astype('float32')
    y = spikes_depths.astype('float32')

    if ax is None:
        fig = plt.Figure(dpi=50, frameon=False, figsize=[90, 90])
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        # Attach the axes; otherwise drawing or saving `fig` shows nothing.
        fig.add_axes(ax)

    ax.scatter(x, y, color=colorvec, edgecolors='none')
    x_edge = (np.nanmax(x) - np.nanmin(x)) * 0.05
    x_lim = [np.nanmin(x) - x_edge, np.nanmax(x) + x_edge]
    y_lim = [np.nanmin(y) - 50, np.nanmax(y) + 100]
    ax.set_xlim(x_lim[0], x_lim[1])
    ax.set_ylim(y_lim[0], y_lim[1])

    if axesoff:
        ax.axis('off')

    if return_lims:
        return ax, x_lim, y_lim
    return ax
Exemple #21
0
def draw_display(dispsize, imagefile=None):
    """Returns a matplotlib.pyplot Figure and its axes, with a size of
	dispsize, a black background colour, and optionally with an image drawn
	onto it
	
	arguments
	
	dispsize		-	tuple or list indicating the size of the display,
					e.g. (1024,768)
	
	keyword arguments
	
	imagefile		-	full path to an image file over which the heatmap
					is to be laid, or None for no image; NOTE: the image
					may be smaller than the display size, the function
					assumes that the image was presented at the centre of
					the display (default = None)
	
	returns
	fig, ax		-	matplotlib.pyplot Figure and its axes: field of zeros
					with a size of dispsize, and an image drawn onto it
					if an imagefile was passed
	"""

    # construct screen (black background)
    screen = numpy.zeros((dispsize[1], dispsize[0], 3), dtype='uint8')
    # if an image location has been passed, draw the image
    if imagefile != None:
        # check if the path to the image exists
        if not os.path.isfile(imagefile):
            raise Exception(
                "ERROR in draw_display: imagefile not found at '%s'" %
                imagefile)
        # load image
        img = image.imread(imagefile)
        # flip image over the horizontal axis
        # (do not do so on Windows, as the image appears to be loaded with
        # the correct side up there; what's up with that? :/)
        if not os.name == 'nt':
            img = numpy.flipud(img)
        # width and height of the image
        w, h = len(img[0]), len(img)
        # x and y position of the image on the display
        x = dispsize[0] / 2 - w / 2
        y = dispsize[1] / 2 - h / 2
        # draw the image on the screen
        screen[y:y + h, x:x + w, :] += img
    # dots per inch
    dpi = 100.0
    # determine the figure size in inches
    figsize = (dispsize[0] / dpi, dispsize[1] / dpi)
    # create a figure
    fig = pyplot.figure(figsize=figsize, dpi=dpi, frameon=False)
    ax = pyplot.Axes(fig, [0, 0, 1, 1])
    ax.set_axis_off()
    fig.add_axes(ax)
    # plot display
    ax.axis([0, dispsize[0], 0, dispsize[1]])
    ax.imshow(screen)  #, origin='upper')

    return fig, ax
def plot_array(imgdata, pixelsize=1., pixelunit="", scale_bar=True,
               show_fig=True, width=15, dpi=None,
               sb_settings=None,
               imshow_kwargs=None):
    '''
    Plot a 2D numpy array as an image.

    A scale-bar can be included. The image fills the whole figure (no axes,
    no frame).

    Parameters
    ----------
    imgdata : array-like, 2D
        the image frame; ``imgdata.shape`` is read as (rows, columns) to size
        the figure
    pixelsize : float, optional
        the scale size of one pixel
    pixelunit : str, optional
        the unit in which pixelsize is expressed
    scale_bar : bool, optional
        whether to add a scale bar to the image. Defaults to True.
    show_fig : bool, optional
        whether to show the figure. Defaults to True. When False, matplotlib
        interactive mode is switched off with ``plt.ioff()``; note that this
        is a global setting and is not restored afterwards.
    width : float, optional
        width (in cm) of the plot. Default is 15 cm
    dpi : int, optional
        alternative to width. dots-per-inch can give an indication of size
        if the image is printed. Overrides width.
    sb_settings : dict, optional
        key word args passed to the scale bar function. When None (default)
        the following settings are used:
        {"location": 'lower right', "color": 'k', "length_fraction": 0.15,
         "font_properties": {"size": 12}}
        See: <https://pypi.org/project/matplotlib-scalebar/>
    imshow_kwargs : dict, optional
        optional formating arguments passed to the pyplot.imshow function.
        When None (default), {"cmap": "Greys_r"} is used.

    Returns
    -------
    ax : matplotlib Axes object
    im : the image plot object
    '''
    # Build fresh default dicts on every call so that no caller (or the
    # scale bar helper) can mutate a shared default between calls.
    if sb_settings is None:
        sb_settings = {"location": 'lower right',
                       "color": 'k',
                       "length_fraction": 0.15,
                       "font_properties": {"size": 12}}
    if imshow_kwargs is None:
        imshow_kwargs = {"cmap": "Greys_r"}

    # initialize the figure and axes objects
    if not show_fig:
        plt.ioff()
    if dpi is not None:
        figsize = (imgdata.shape[1] / dpi, imgdata.shape[0] / dpi)
    else:
        # change cm units into inches; keep the image aspect ratio
        width_in = width * 0.3937008
        height_in = width_in / imgdata.shape[1] * imgdata.shape[0]
        figsize = (width_in, height_in)
    fig = plt.figure(frameon=False, figsize=figsize)
    # a single axes spanning the full figure so the image has no margins
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    # plot the figure on the axes
    im = ax.imshow(imgdata, **imshow_kwargs)

    if scale_bar:
        # get_scalebar adapts the bar to the pixel size / unit
        scalebar = get_scalebar(pixelsize, pixelunit, sb_settings)
        # attach to the axes we created rather than whatever plt.gca() is
        ax.add_artist(scalebar)
    return ax, im
def detect_objects(model,
                   IMAGE_PATH,
                   CATEGORY_INDEX,
                   ANCHOR_POINTS,
                   MINIMUM_CONFIDENCE,
                   SAVE_DIR=None):
    '''
    Adapted from tensorflow slim.
    Performs object detection inference on given image (.jpg in IMAGE_PATH)
    with a tflite Mobilenet-V1 model which has already been allocated tensors.

    The raw box outputs are decoded against ANCHOR_POINTS, drawn onto the
    image, optionally saved, and a status file
    ``<cwd>/model/output/status_1.txt`` is (over)written with "1" when the
    highest score exceeds 0.60 and "0" otherwise.

    Args:
    model              : (tf.lite.Interpreter) applied to MODEL_PATH
                        Should have run model.allocate_tensors() previously.
                        See initiate_tflite_model

    IMAGE_PATH         : (String) Path to image (.jpg) on which to perform inference

    CATEGORY_INDEX     : (Dictionary) Dictionary of dictionaries, key: position in labels.json
                        Each sub-dictionary has "name" field which will be displayed
                        beside the bounding box

    ANCHOR_POINTS      : array(N, 4) Shaped numpy array representing each of the
                        anchor points provided for Mobilenet-v1 SSD in anchors.json
                        (columns: y-center, x-center, height, width)

    MINIMUM_CONFIDENCE : (Float) Minimum score permissible for box to be considered for
                        display. Mobilenet-V1 SSD tends to provide quite low probabilities.

    SAVE_DIR           : (String) (Optional) Directory to save the output inference
                        image to (written as SAVE_DIR/location_1, then the figure
                        is closed). If None, nothing is saved and the figure is left
                        open so the caller can display it (e.g. with plt.show()).
    '''

    image = Image.open(IMAGE_PATH)
    start = t.time()
    classes, boxes, scores = call_tflite_model(model, image)
    print("Inference time: {}".format(t.time() - start))
    image_np = load_image_into_numpy_array(image)

    # Convert the quantised boxes to normalised coordinates.
    # The divisors (10, 10, 5, 5) are presumably the box-coder scale factors
    # the model was trained with -- TODO confirm against the model config.
    ty = boxes[:, 0] / float(10)
    tx = boxes[:, 1] / float(10)
    th = boxes[:, 2] / float(5)
    tw = boxes[:, 3] / float(5)

    yACtr = ANCHOR_POINTS[:, 0]
    xACtr = ANCHOR_POINTS[:, 1]
    ha = ANCHOR_POINTS[:, 2]
    wa = ANCHOR_POINTS[:, 3]

    # sizes are log-encoded relative to the anchor, centers are offsets
    # scaled by the anchor size
    w = np.exp(tw) * wa
    h = np.exp(th) * ha

    yCtr = ty * ha + yACtr
    xCtr = tx * wa + xACtr

    yMin = yCtr - h / float(2)
    xMin = xCtr - w / float(2)
    yMax = yCtr + h / float(2)
    xMax = xCtr + w / float(2)

    boxes_normalised = [yMin, xMin, yMax, xMax]
    # one row per box in [ymin, xmin, ymax, xmax] order
    boxes_out = np.transpose(np.squeeze(boxes_normalised))
    classes_out = np.round(np.squeeze(classes)).astype(np.int32)
    scores_out = np.squeeze(scores)
    max_score = np.max(scores)

    print("-" * 10)
    print("Inference Summary:")
    print("Highest Score: {}".format(max_score))
    print("Highest Scoring Box: {}".format(boxes_out[np.argmax(scores)]))
    print("-" * 10)
    print("Image shape: {}".format(np.squeeze(image_np).shape))
    print("Boxes shape: {}".format(boxes_out.shape))
    print("Classes shape: {}".format(classes_out.shape))
    print("Scores shape: {}".format(scores_out.shape))
    fig = plt.figure()
    out_image = vis_util.visualize_boxes_and_labels_on_image_array(
        np.squeeze(image_np),
        boxes_out,
        classes_out,
        scores_out,
        CATEGORY_INDEX,
        min_score_thresh=MINIMUM_CONFIDENCE,
        use_normalized_coordinates=True,
        line_thickness=1,
        ret=True)

    fig.set_size_inches(16, 9)
    # borderless axes spanning the whole figure
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)

    ax.imshow(out_image / 255)
    if SAVE_DIR:
        output_dir = SAVE_DIR + '/' + 'location_1'  #image name

        mkdir_p(Path(output_dir).parent)

        # save the image with the detected frame in the output directory
        fig.savefig(output_dir, dpi=62)
        plt.close(fig)
        print("Image Saved")
        print("=" * 10)

    # Write to status_ID.txt file to indicate whether a fire is detected.
    # The context manager guarantees the file is flushed and closed (the
    # original "else" branch referenced f.close without calling it).
    status_path = os.path.join(os.getcwd(), 'model/output/status_1.txt')
    with open(status_path, "w") as status_file:
        status_file.write("1" if max_score > 0.60 else "0")
Exemple #24
0
def export_html(exp,
                sample_field=None,
                feature_field=False,
                title=None,
                xticklabel_len=50,
                cmap=None,
                clim=None,
                transform=log_n,
                output_file='out',
                html_template=None,
                **kwargs):
    '''Export an interactive html heatmap for the experiment.

    Creates a standalone html file with interactive d3.js heatmap of the experiment and interface to dbBact.

    Parameters
    ----------
    exp : Experiment
        The experiment to export.
    sample_field : str or None (optional)
        The field to display on the x-axis (sample):
        None (default) to not show x labels.
        str to display field values for this field
    feature_field : str or False (optional)
        Name of the feature metadata field whose values are used as the y-axis (features) labels.
        False (default) to use the experiment subclass default field (``exp.heatmap_feature_field``).
    title : None or str (optional)
        None (default) to show experiment description field as title. str to set title to str.
    xticklabel_len : int (optional) or None
        The maximal length for the x label strings (will be cut to
        this length if longer). Used to prevent long labels from
        taking too much space. None indicates no cutting
    cmap : None or str (optional)
        None (default) to use mpl default color map. str to use colormap named str.
    clim : tuple of (float, float) or None (optional)
        the min and max values for the heatmap or None to use all range. It uses the min
        and max values in the ``data`` array by default.
    transform : function (optional)
        The transform function to apply to the data before plotting. default is log_n
    output_file : str (optional)
        Name of the output html file (no .html ending - it will be appended).
    html_template : str or None (optional)
        Name of the html template to use. None to use the default export_html_template.html template
    kwargs : dict
        Extra keyword arguments passed to ``transform``.

    Raises
    ------
    ValueError
        If ``sample_field`` is not a column of the sample metadata.
    '''
    import base64
    import json
    import matplotlib.pyplot as plt

    if html_template is None:
        html_template = resource_filename(__package__,
                                          'export_html_template.html')
        logger.debug('using default template file %s' % html_template)

    logger.debug('export_html heatmap')

    # get the default feature field if not specified (i.e. False)
    if feature_field is False:
        feature_field = exp.heatmap_feature_field

    # step 1. transform data. Always request a dense array, since imshow
    # cannot draw a sparse matrix.
    if transform is None:
        data = exp.get_data(sparse=False)
    else:
        logger.debug('transform exp with %r with param %r' %
                     (transform, kwargs))
        data = transform(exp, inplace=False, **kwargs).get_data(sparse=False)

    # step 2. plot heatmap.
    # init the default colormap
    if cmap is None:
        cmap = plt.rcParams['image.cmap']
    vmin, vmax = (None, None) if clim is None else clim
    # plot the heatmap with 1 pixel per feature/sample, no axes/lines
    fig = plt.figure(frameon=False, dpi=300)
    try:
        fig.set_size_inches(exp.shape[0] / 300, exp.shape[1] / 300)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(data.transpose(), interpolation='nearest',
                  cmap=cmap, vmin=vmin, vmax=vmax)

        # render the figure to base64-encoded png for embedding in the page
        with BytesIO() as figfile:
            fig.savefig(figfile, format='png', dpi=300)
            figdata_png = urllib.parse.quote(
                base64.b64encode(figfile.getvalue()))
    finally:
        # release the figure even if rendering fails
        plt.close(fig)

    if title is None:
        title = exp.description

    # add parameters to html template
    with open(html_template) as fl:
        html_page = fl.read()
    html_page = html_page.replace(
        '// yticklabels go here', 'var yticklabels = %s;' %
        _list_to_string(exp.feature_metadata[feature_field].values))
    html_page = html_page.replace(
        '// ids go here',
        'var ids = %s;' % _list_to_string(exp.feature_metadata.index.values))
    html_page = html_page.replace(
        '// samples go here', 'var samples = %s;' %
        _list_to_string(exp.sample_metadata.index.values))
    # json.dumps yields a correctly escaped JS string literal, so quotes or
    # backslashes in the field name / title cannot break the page script
    if sample_field is not None:
        html_page = html_page.replace(
            '// field_name goes here',
            'var field_name = %s;' % json.dumps(str(sample_field)))
    html_page = html_page.replace(
        '// title_text goes here',
        'var title_text = %s;' % json.dumps(str(title)))

    # add vertical lines between sample groups and add x tick labels
    if sample_field is not None:
        try:
            xticks = _transition_index(exp.sample_metadata[sample_field])
        except KeyError:
            raise ValueError('Sample field %r not in sample metadata.' %
                             sample_field)
        x_pos, x_val = zip(*xticks)
        x_pos = np.array([0.] + list(x_pos))

        html_page = html_page.replace(
            '// vlines go here',
            'var vlines = %s;' % _list_to_string(x_pos[1:-1]))
        # ticks go in the middle of each sample group
        xtick_pos = x_pos[:-1] + (x_pos[1:] - x_pos[:-1]) / 2
        html_page = html_page.replace(
            '// xtick_pos go here',
            'var xtick_pos = %s;' % _list_to_string(xtick_pos))

        xticklabels = [str(i) for i in x_val]
        # shorten x tick labels that are too long: keep head and tail.
        # Slice the tail with an explicit start index so mid == 0 gives an
        # empty tail (i[-0:] would return the whole string).
        if xticklabel_len is not None:
            mid = int(xticklabel_len / 2)
            xticklabels = [
                '%s..%s' % (i[:mid], i[len(i) - mid:])
                if len(i) > xticklabel_len else i
                for i in xticklabels
            ]
        html_page = html_page.replace(
            '// xtick_labels go here',
            'var xtick_labels = %s;' % _list_to_string(xticklabels))

    # embed the figure png into the html page
    html_page = html_page.replace('**image_goes_here**', figdata_png)

    if not output_file.endswith('.html'):
        output_file = output_file + '.html'

    # save the output html export
    with open(output_file, 'w') as fl:
        fl.write(html_page)
    logger.info('exported experiment to html file %s' % output_file)
Exemple #25
0
def vis_one_image(im,
                  im_name,
                  output_dir,
                  boxes,
                  segms=None,
                  keypoints=None,
                  thresh=0.9,
                  kp_thresh=2,
                  dpi=200,
                  box_alpha=0.0,
                  dataset=None,
                  show_class=False,
                  ext='pdf',
                  out_when_no_box=False):
    """Visual debugging of detections.

    Draws detections (boxes, optional instance masks and keypoints) on top of
    ``im`` and saves the figure to ``<output_dir>/<basename(im_name)>.<ext>``.
    Nothing is written when no box scores at least ``thresh``, unless
    ``out_when_no_box`` is set.

    Args:
        im: image array; ``im.shape[1]`` / ``im.shape[0]`` are used as the
            width / height in pixels to size the figure.
        im_name: image name; only its basename is used for the output file.
        output_dir: output directory, created if it does not exist.
        boxes: either a per-class list (converted with
            ``convert_from_cls_format``, which also yields class ids) or an
            array whose rows hold box corners in columns 0-3 and the score in
            column 4 / the last column -- presumably (N, 5); TODO confirm.
        segms: encoded masks decoded with ``mask_util.decode``; mask ``i``
            is drawn for box ``i``.
        keypoints: per-box keypoint arrays read as ``kps[0, k]`` (x),
            ``kps[1, k]`` (y) and ``kps[2, k]`` (score).
        thresh: minimum box score to draw.
        kp_thresh: minimum keypoint score to draw a keypoint or limb.
        dpi: figure resolution; together with the figure size this gives
            one output pixel per image pixel.
        box_alpha: alpha of the box outlines (0 hides them).
        dataset: passed to ``get_class_string`` to name classes.
        show_class: draw class labels. Labels are only drawn when class ids
            are known, i.e. when ``boxes`` is given in per-class list form.
        ext: output file extension (also selects the output format).
        out_when_no_box: save the image even if no box passes ``thresh``.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Class ids are only available when boxes come in per-class list form.
    classes = None
    if isinstance(boxes, list):
        boxes, segms, keypoints, classes = convert_from_cls_format(
            boxes, segms, keypoints)

    no_box = (boxes is None or boxes.shape[0] == 0
              or max(boxes[:, 4]) < thresh)
    if no_box and not out_when_no_box:
        return

    dataset_keypoints, _ = keypoint_utils.get_keypoints()
    kp_idx = dataset_keypoints.index

    masks = None
    if segms is not None and len(segms) > 0:
        masks = mask_util.decode(segms)

    color_list = colormap(rgb=True) / 255

    kp_lines = kp_connections(dataset_keypoints)
    cmap = plt.get_cmap('rainbow')
    # one color per limb, plus two for the synthetic mid-shoulder limbs
    colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]

    fig = plt.figure(frameon=False)
    try:
        fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.axis('off')
        fig.add_axes(ax)
        ax.imshow(im)

        if boxes is None:
            sorted_inds = []  # avoid crash when 'boxes' is None
        else:
            # Display in largest to smallest order to reduce occlusion
            areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            sorted_inds = np.argsort(-areas)

        mask_color_id = 0
        for i in sorted_inds:
            bbox = boxes[i, :4]
            score = boxes[i, -1]
            if score < thresh:
                continue

            # show box (invisible with the default box_alpha=0.0)
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1],
                              fill=False,
                              edgecolor='g',
                              linewidth=0.5,
                              alpha=box_alpha))

            if show_class and classes is not None:
                ax.text(bbox[0],
                        bbox[1] - 2,
                        get_class_string(classes[i], score, dataset),
                        fontsize=3,
                        family='serif',
                        bbox=dict(facecolor='g',
                                  alpha=0.4,
                                  pad=0,
                                  edgecolor='none'),
                        color='white')

            # show mask
            if segms is not None and len(segms) > i:
                # Blend the palette color towards white. This builds a new
                # array, so the shared color_list row is never modified
                # (colors reused after wraparound keep their original tint).
                w_ratio = .4
                base_color = color_list[mask_color_id % len(color_list), 0:3]
                color_mask = base_color * (1 - w_ratio) + w_ratio
                mask_color_id += 1

                e = masks[:, :, i]
                # [-2] picks the contours across OpenCV 3/4 return signatures
                contour = cv2.findContours(e.copy(), cv2.RETR_CCOMP,
                                           cv2.CHAIN_APPROX_NONE)[-2]

                for c in contour:
                    polygon = Polygon(c.reshape((-1, 2)),
                                      fill=True,
                                      facecolor=color_mask,
                                      edgecolor='w',
                                      linewidth=1.2,
                                      alpha=0.5)
                    ax.add_patch(polygon)

            # show keypoints
            if keypoints is not None and len(keypoints) > i:
                kps = keypoints[i]
                ax.autoscale(False)
                for line_id, kp_line in enumerate(kp_lines):
                    i1, i2 = kp_line[0], kp_line[1]
                    if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
                        x = [kps[0, i1], kps[0, i2]]
                        y = [kps[1, i1], kps[1, i2]]
                        line = ax.plot(x, y)
                        plt.setp(line, color=colors[line_id], linewidth=1.0,
                                 alpha=0.7)
                    if kps[2, i1] > kp_thresh:
                        ax.plot(kps[0, i1],
                                kps[1, i1],
                                '.',
                                color=colors[line_id],
                                markersize=3.0,
                                alpha=0.7)
                    if kps[2, i2] > kp_thresh:
                        ax.plot(kps[0, i2],
                                kps[1, i2],
                                '.',
                                color=colors[line_id],
                                markersize=3.0,
                                alpha=0.7)

                # add mid shoulder / mid hip for better visualization
                mid_shoulder = (kps[:2, kp_idx('right_shoulder')] +
                                kps[:2, kp_idx('left_shoulder')]) / 2.0
                sc_mid_shoulder = np.minimum(
                    kps[2, kp_idx('right_shoulder')],
                    kps[2, kp_idx('left_shoulder')])
                mid_hip = (kps[:2, kp_idx('right_hip')] +
                           kps[:2, kp_idx('left_hip')]) / 2.0
                sc_mid_hip = np.minimum(kps[2, kp_idx('right_hip')],
                                        kps[2, kp_idx('left_hip')])
                nose = kp_idx('nose')
                if sc_mid_shoulder > kp_thresh and kps[2, nose] > kp_thresh:
                    x = [mid_shoulder[0], kps[0, nose]]
                    y = [mid_shoulder[1], kps[1, nose]]
                    line = ax.plot(x, y)
                    plt.setp(line,
                             color=colors[len(kp_lines)],
                             linewidth=1.0,
                             alpha=0.7)
                if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
                    x = [mid_shoulder[0], mid_hip[0]]
                    y = [mid_shoulder[1], mid_hip[1]]
                    line = ax.plot(x, y)
                    plt.setp(line,
                             color=colors[len(kp_lines) + 1],
                             linewidth=1.0,
                             alpha=0.7)

        output_name = os.path.basename(im_name) + '.' + ext
        fig.savefig(os.path.join(output_dir, output_name), dpi=dpi)
    finally:
        # close only this figure (not the caller's), even if drawing fails
        plt.close(fig)
def main(argv):
    ntrans = 1
    save_to_mat = 'off'
    flip_profile = 'no'
    which_gps = 'all'
    flip_updown = 'yes'
    incidence_file = 'incidence_file'
    display_InSAR = 'on'
    display_Average = 'on'
    display_Standard_deviation = 'on'

    try:
        opts, args = getopt.getopt(
            argv, "f:s:e:n:d:g:l:h:r:L:F:p:u:G:S:i:I:A:U:E:D:W:x:X:")

    except getopt.GetoptError:
        usage()
        sys.exit(1)

    for opt, arg in opts:
        if opt == '-f':
            velocityFile = arg
        elif opt == '-s':
            pnt1 = arg.split(',')
            y0 = int(pnt1[0])
            x0 = int(pnt1[1])
        elif opt == '-e':
            pnt2 = arg.split(',')
            y1 = int(pnt2[0])
            x1 = int(pnt2[1])
        elif opt == '-n':
            ntrans = int(arg)
        elif opt == '-d':
            dp = float(arg)
        elif opt == '-g':
            gpsFile = arg
        elif opt == '-r':
            refStation = arg
        elif opt == '-i':
            incidence_file = arg
        elif opt == '-L':
            stationsList = arg.split(',')
        elif opt == '-F':
            Fault_coord_file = arg
        elif opt == '-p':
            flip_profile = arg
        elif opt == '-u':
            flip_updown = arg
            print(flip_updown)
        elif opt == '-G':
            which_gps = arg
        elif opt == '-S':
            gps_source = arg
        elif opt == '-l':
            lbound = float(arg)
        elif opt == '-I':
            display_InSAR = arg
        elif opt == '-A':
            display_Average = arg
        elif opt == '-U':
            display_Standard_deviation = arg
        elif opt == '-E':
            save_to_mat = arg
        elif opt == '-h':
            hbound = float(arg)
        elif opt == '-D':
            Dp = float(arg)
        elif opt == '-W':
            profile_Length = float(arg)
        elif opt == '-x':
            x_lbound = float(arg)
        elif opt == '-X':
            x_hbound = float(arg)

    try:
        h5file = h5py.File(velocityFile, 'r')
    except:
        usage()
        sys.exit(1)

    k = list(h5file.keys())
    dset = h5file[k[0]].get(k[0])
    z = dset[0:dset.shape[0], 0:dset.shape[1]]
    dx = float(h5file[k[0]].attrs['X_STEP']) * 6375000.0 * np.pi / 180.0
    dy = float(h5file[k[0]].attrs['Y_STEP']) * 6375000.0 * np.pi / 180.0

    #############################################################################

    try:
        lat, lon, lat_step, lon_step, lat_all, lon_all = get_lat_lon(h5file)
    except:
        print('radar coordinate')

    Fault_lon, Fault_lat = read_fault_coords(Fault_coord_file, Dp)

    # Fault_lon=[66.40968453947265,66.36000186563085,66.31103920134248]
    # Fault_lat=[30.59405079532564,30.51565960186412,30.43928430936202]

    Num_profiles = len(Fault_lon) - 1
    print('*********************************************')
    print('*********************************************')
    print('Number of profiles to be generated: ' + str(Num_profiles))
    print('*********************************************')
    print('*********************************************')

    for Np in range(Num_profiles):
        FaultCoords = [
            Fault_lat[Np], Fault_lon[Np], Fault_lat[Np + 1], Fault_lon[Np + 1]
        ]
        print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
        print('%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%')
        print('')
        print('Profile ' + str(Np) + ' [of total ' + str(Num_profiles) + ']')
        print('')

        try:
            #  Lat0 = dms2d(FaultCoords[0]); Lon0 = dms2d(FaultCoords[1])
            #  Lat1 = dms2d(FaultCoords[2]); Lon1 = dms2d(FaultCoords[3])

            Lat0 = FaultCoords[0]
            Lon0 = FaultCoords[1]
            Lat1 = FaultCoords[2]
            Lon1 = FaultCoords[3]
            Length, Width = np.shape(z)
            Yf0, Xf0 = find_row_column(Lon0, Lat0, lon, lat, lon_step,
                                       lat_step)
            Yf1, Xf1 = find_row_column(Lon1, Lat1, lon, lat, lon_step,
                                       lat_step)

            print('*********************************************')
            print(' Fault Coordinates:')
            print('   --------------------------  ')
            print('    Lat          Lon')
            print(str(Lat0) + ' , ' + str(Lon0))
            print(str(Lat1) + ' , ' + str(Lon1))
            print('   --------------------------  ')
            print('    row          column')
            print(str(Yf0) + ' , ' + str(Xf0))
            print(str(Yf1) + ' , ' + str(Xf1))
            print('*********************************************')
            #mf=float(Yf1-Yf0)/float((Xf1-Xf0))  # slope of the fault line
            #cf=float(Yf0-mf*Xf0)   # intercept of the fault line
            #df0=dist_point_from_line(mf,cf,x0,y0,1,1) #distance of the profile start point from the Fault line
            #df1=dist_point_from_line(mf,cf,x1,y1,1,1) #distance of the profile end point from the Fault line

            #mp=-1./mf  # slope of profile which is perpendicualr to the fault line
            # correcting the end point of the profile to be on a line perpendicular to the Fault
            #x1=int((df0+df1)/np.sqrt(1+mp**2)+x0)
            #y1=int(mp*(x1-x0)+y0)

        except:
            print('*********************************************')
            print('No information about the Fault coordinates!')
            print('*********************************************')

        #############################################################################
        y0, x0, y1, x1 = get_start_end_point(Xf0, Yf0, Xf1, Yf1,
                                             profile_Length, dx, dy)

        try:
            x0
            y0
            x1
            y1
        except:
            fig = plt.figure()
            ax = fig.add_subplot(111)
            ax.imshow(z)
            try:
                ax.plot([Xf0, Xf1], [Yf0, Yf1], 'k-')
            except:
                print('Fault line is not specified')

            xc = []
            yc = []
            print('please click on start and end point of the desired profile')

            def onclick(event):
                if event.button == 1:
                    print('click')
                    xc.append(int(event.xdata))
                    yc.append(int(event.ydata))

            cid = fig.canvas.mpl_connect('button_press_event', onclick)
            plt.show()
            x0 = xc[0]
            x1 = xc[1]
            y0 = yc[0]
            y1 = yc[1]

        ##############################################################################
        print('******************************************************')
        print('First profile coordinates:')
        print('Start point:  y = ' + str(y0) + ',x = ' + str(x0))
        print('End point:   y = ' + str(y1) + '  , x = ' + str(x1))
        print('')
        print(str(y0) + ',' + str(x0))
        print(str(y1) + ',' + str(x1))
        print('******************************************************')
        length = int(np.hypot(x1 - x0, y1 - y0))
        x, y = np.linspace(x0, x1, length), np.linspace(y0, y1, length)
        zi = z[y.astype(np.int), x.astype(np.int)]
        try:
            lat_transect = lat_all[y.astype(np.int), x.astype(np.int)]
            lon_transect = lon_all[y.astype(np.int), x.astype(np.int)]
        except:
            lat_transect = 'Nan'
            lon_transect = 'Nan'

        # zi=get_transect(z,x0,y0,x1,y1)

        try:
            dx = float(
                h5file[k[0]].attrs['X_STEP']) * 6375000.0 * np.pi / 180.0
            dy = float(
                h5file[k[0]].attrs['Y_STEP']) * 6375000.0 * np.pi / 180.0
            DX = (x - x0) * dx
            DY = (y - y0) * dy
            D = np.hypot(DX, DY)
            print('geo coordinate:')
            print('profile length = ' + str(D[-1] / 1000.0) + ' km')
            #df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)
        except:
            dx = float(h5file[k[0]].attrs['RANGE_PIXEL_SIZE'])
            dy = float(h5file[k[0]].attrs['AZIMUTH_PIXEL_SIZE'])
            DX = (x - x0) * dx
            DY = (y - y0) * dy
            D = np.hypot(DX, DY)
            print('radar coordinate:')
            print('profile length = ' + str(D[-1] / 1000.0) + ' km')
            #df0_km=dist_point_from_line(mf,cf,x0,y0,dx,dy)

        try:
            mf, cf = line(Xf0, Yf0, Xf1, Yf1)
            df0_km = dist_point_from_line(mf, cf, x0, y0, dx, dy)
        except:
            print('Fault line is not specified')

        transect = np.zeros([len(D), ntrans])
        transect[:, 0] = zi
        XX0 = []
        XX1 = []
        YY0 = []
        YY1 = []
        XX0.append(x0)
        XX1.append(x1)
        YY0.append(y0)
        YY1.append(y1)

        if ntrans > 1:

            m = float(y1 - y0) / float((x1 - x0))
            c = float(y0 - m * x0)
            m1 = -1.0 / m
            try:
                dp
            except:
                dp = 1.0
            if lat_transect == 'Nan':
                for i in range(1, ntrans):

                    X0 = i * dp / np.sqrt(1 + m1**2) + x0
                    Y0 = m1 * (X0 - x0) + y0
                    X1 = i * dp / np.sqrt(1 + m1**2) + x1
                    Y1 = m1 * (X1 - x1) + y1
                    zi = get_transect(z, X0, Y0, X1, Y1)
                    transect[:, i] = zi
                    XX0.append(X0)
                    XX1.append(X1)
                    YY0.append(Y0)
                    YY1.append(Y1)
            else:
                transect_lat = np.zeros([len(D), ntrans])
                transect_lat[:, 0] = lat_transect
                transect_lon = np.zeros([len(D), ntrans])
                transect_lon[:, 0] = lon_transect

                for i in range(1, ntrans):

                    X0 = i * dp / np.sqrt(1 + m1**2) + x0
                    Y0 = m1 * (X0 - x0) + y0
                    X1 = i * dp / np.sqrt(1 + m1**2) + x1
                    Y1 = m1 * (X1 - x1) + y1
                    zi = get_transect(z, X0, Y0, X1, Y1)
                    lat_transect = get_transect(lat_all, X0, Y0, X1, Y1)
                    lon_transect = get_transect(lon_all, X0, Y0, X1, Y1)
                    transect[:, i] = zi
                    transect_lat[:, i] = lat_transect
                    transect_lon[:, i] = lon_transect
                    XX0.append(X0)
                    XX1.append(X1)
                    YY0.append(Y0)
                    YY1.append(Y1)

        #############################################
        try:
            m_prof_edge, c_prof_edge = line(XX0[0], YY0[0], XX0[-1], YY0[-1])
        except:
            print('Plotting one profile')

        ###############################################################################
        if flip_profile == 'yes':
            transect = np.flipud(transect)
            try:
                df0_km = np.max(D) - df0_km
            except:
                print('')

        print('******************************************************')
        try:
            gpsFile
        except:
            gpsFile = 'Nogps'
        print('GPS velocity file:')
        print(gpsFile)
        print('*******************************************************')
        if os.path.isfile(gpsFile):
            insarData = z
            del z
            fileName, fileExtension = os.path.splitext(gpsFile)
            #print fileExtension
            #if fileExtension =='.cmm4':
            #    print 'reading cmm4 velocities'
            #    Stations, gpsData = redGPSfile_cmm4(gpsFile)
            #    idxRef=Stations.index(refStation)
            #    Lon,Lat,Ve,Vn,Se,Sn,Corr,Hrate,H12=gpsData[idxRef,:]
            #    Lon=Lon-360.0
            #    Lat,Lon,Ve,Se,Vn,Sn,Corr,NumEpochs,timeSpan,AvgEpochTimes = gpsData[idxRef,:]
            #    Vu=0
            #else:
            #    Stations, gpsData = redGPSfile(gpsFile)
            #    idxRef=Stations.index(refStation)
            #    Lat,Lon,Vn,Ve,Sn,Se,Corr,Vu,Su = gpsData[idxRef,:]

            Stations, Lat, Lon, Ve, Se, Vn, Sn = readGPSfile(
                gpsFile, gps_source)
            idxRef = Stations.index(refStation)
            Length, Width = np.shape(insarData)
            lat, lon, lat_step, lon_step = get_lat_lon(h5file)
            IDYref, IDXref = find_row_column(Lon[idxRef], Lat[idxRef], lon,
                                             lat, lon_step, lat_step)
            if (not np.isnan(IDYref)) and (not np.isnan(IDXref)):
                print('referencing InSAR data to the GPS station at : ' +
                      str(IDYref) + ' , ' + str(IDXref))
                if not np.isnan(insarData[IDYref][IDXref]):
                    transect = transect - insarData[IDYref][IDXref]
                    insarData = insarData - insarData[IDYref][IDXref]

                else:

                    print(""" 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
      
      WARNING: nan value for InSAR data at the refernce pixel!
               reference station should be a pixel with valid value in InSAR data.
                               
               please select another GPS station as the reference station.

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%                       
                   """)
                    sys.exit(1)
            else:
                print('WARNING:')
                print(
                    'Reference GPS station is out of the area covered by InSAR data'
                )
                print(
                    'please select another GPS station as the reference station.'
                )
                sys.exit(1)

            try:
                stationsList
            except:
                stationsList = Stations

            # theta=23.0*np.pi/180.0
            if os.path.isfile(incidence_file):
                print('Using exact look angle for each pixel')
                h5file_theta = h5py.File(incidence_file, 'r')
                dset = h5file_theta['mask'].get('mask')
                theta = dset[0:dset.shape[0], 0:dset.shape[1]]
                theta = theta * np.pi / 180.0
            else:
                print('Using average look angle')
                theta = np.ones(np.shape(insarData)) * 23.0 * np.pi / 180.0

            heading = 193.0 * np.pi / 180.0

            #unitVec=[-np.sin(theta)*np.sin(heading),-np.cos(heading)*np.sin(theta),-np.cos(theta)]
            # -np.cos(theta)]
            unitVec = [
                np.cos(heading) * np.sin(theta),
                -np.sin(theta) * np.sin(heading), 0
            ]

            #  [0.0806152480932643, 0.34918300221540616, -0.93358042649720174]
            # print unitVec
            # unitVec=[0.3,-0.09,0.9]
            # unitVec=[-0.3,0.09,-0.9]
            # unitVec=[-0.3,0.09,0]

            # print '*******************************************'
            # print 'unit vector to project GPS to InSAR LOS:'
            # print unitVec
            # print '*******************************************'
            # gpsLOS_ref=unitVec[0]*Ve[idxRef]+unitVec[1]*Vn[idxRef]#+unitVec[2]*Vu[idxRef]

            gpsLOS_ref = gps_to_LOS(Ve[idxRef], Vn[idxRef],
                                    theta[IDYref, IDXref], heading)
            print('%%%%%%^^^^^^^%%%%%%%%')
            print(gpsLOS_ref / 1000.0)
            # insarData=insarData -gpsLOS_ref/1000.0
            # transect = transect -gpsLOS_ref/1000.0

            GPS = []
            GPS_station = []
            GPSx = []
            GPSy = []
            GPS_lat = []
            GPS_lon = []
            for st in stationsList:
                try:
                    idx = Stations.index(st)

                    #gpsLOS = unitVec[0]*Ve[idx]+unitVec[1]*Vn[idx]#+unitVec[2]*Vu[idx]

                    #gpsLOS = gps_to_LOS(Ve[idx],Vn[idx],theta[idx],heading)
                    #gpsLOS=gpsLOS-gpsLOS_ref

                    IDY, IDX = find_row_column(Lon[idx], Lat[idx], lon, lat,
                                               lon_step, lat_step)
                    print(theta[IDY, IDX])
                    gpsLOS = gps_to_LOS(Ve[idx], Vn[idx], theta[IDY, IDX],
                                        heading)
                    #gpsLOS = gpsLOS-gpsLOS_ref

                    if which_gps == 'all':
                        if theta[IDY, IDX] != 0.0:
                            GPS.append(gpsLOS - gpsLOS_ref)
                            GPS_station.append(st)
                            GPSx.append(IDX)
                            GPSy.append(IDY)
                            GPS_lat.append(Lat[idx])
                            GPS_lon.append(Lon[idx])
                    elif not np.isnan(insarData[IDY][IDX]):
                        if theta[IDY, IDX] != 0.0:
                            GPS.append(gpsLOS - gpsLOS_ref)
                            GPS_station.append(st)
                            GPSx.append(IDX)
                            GPSy.append(IDY)
                            GPS_lat.append(Lat[idx])
                            GPS_lon.append(Lon[idx])
                except:
                    NoInSAR = 'yes'

            DistGPS = []
            GPS_in_bound = []
            GPS_in_bound_st = []
            GPSxx = []
            GPSyy = []
            for i in range(len(GPS_station)):
                gx = GPSx[i]
                gy = GPSy[i]

                if which_gps in ['all', 'insar']:
                    check_result = 'True'
                else:
                    check_result = check_st_in_box(gx, gy, x0, y0, x1, y1, X0,
                                                   Y0, X1, Y1)

                if check_result == 'True':
                    check_result2 = check_st_in_box2(gx, gy, x0, y0, x1, y1,
                                                     X0, Y0, X1, Y1)
                    GPS_in_bound_st.append(GPS_station[i])
                    GPS_in_bound.append(GPS[i])
                    GPSxx.append(GPSx[i])
                    GPSyy.append(GPSy[i])
                    # gy=y0+1
                    # gx=x0+1
                    # gxp,gyp=get_intersect(m,c,gx,gy)
                    # Dx=dx*(gx-gxp);Dy=dy*(gy-gyp)
                    # print gxp
                    # print gyp
                    # distance of GPS station from the first profile line
                    dg = dist_point_from_line(m, c, gx, gy, 1, 1)
                    # DistGPS.append(np.hypot(Dx,Dy))
                    # X0=dg/np.sqrt(1+m1**2)+x0
                    # Y0=m1*(X0-x0)+y0
                    # DistGPS.append(np.hypot(dx*(gx-X0), dy*(gy-Y0)))

                    DistGPS.append(
                        dist_point_from_line(m_prof_edge, c_prof_edge, GPSx[i],
                                             GPSy[i], dx, dy))

            print('****************************************************')
            print('GPS stations in the profile area:')
            print(GPS_in_bound_st)
            print('****************************************************')
            GPS_in_bound = np.array(GPS_in_bound)
            DistGPS = np.array(DistGPS)
            #axes[1].plot(DistGPS/1000.0, -1*GPS_in_bound/1000, 'bo')

        if gpsFile == 'Nogps':

            insarData = z
            GPSxx = []
            GPSyy = []
            GPSx = []
            GPSy = []
            GPS = []
            XX0[0] = x0
            XX1[0] = x1
            YY0[0] = y0
            YY1[0] = y1

        print('****************')
        print('flip up-down')
        print(flip_updown)

        if flip_updown == 'yes' and gpsFile != 'Nogps':
            print('Flipping up-down')
            transect = -1 * transect
            GPS_in_bound = -1 * GPS_in_bound
        elif flip_updown == 'yes':
            print('Flipping up-down')
            transect = -1 * transect

        if flip_profile == 'yes' and gpsFile != 'Nogps':

            GPS = np.flipud(GPS)
            GPS_in_bound = np.flipud(GPS_in_bound)
            DistGPS = np.flipud(max(D) - DistGPS)

        fig, axes = plt.subplots(nrows=2)
        axes[0].imshow(insarData)
        for i in range(ntrans):
            axes[0].plot([XX0[i], XX1[i]], [YY0[i], YY1[i]], 'r-')

        axes[0].plot(GPSx, GPSy, 'b^')
        axes[0].plot(GPSxx, GPSyy, 'k^')
        if gpsFile != 'Nogps':
            axes[0].plot(IDXref, IDYref, 'r^')
        axes[0].axis('image')
        axes[1].plot(D / 1000.0, transect, 'ko', ms=1)

        avgInSAR = np.array(nanmean(transect, axis=1))
        stdInSAR = np.array(nanstd(transect, axis=1))

        # std=np.std(transect,1)
        # axes[1].plot(D/1000.0, avgInSAR, 'r-')
        try:
            axes[1].plot(DistGPS / 1000.0,
                         -1 * GPS_in_bound / 1000,
                         'b^',
                         ms=10)
        except:
            print('')
        # pl.fill_between(x, y-error, y+error,alpha=0.6, facecolor='0.20')
        # print transect
        #############################################################################

        fig2, axes2 = plt.subplots(nrows=1)
        axes2.imshow(insarData)
        # for i in range(ntrans):
        axes2.plot([XX0[0], XX1[0]], [YY0[0], YY1[0]], 'k-')
        axes2.plot([XX0[-1], XX1[-1]], [YY0[-1], YY1[-1]], 'k-')
        axes2.plot([XX0[0], XX0[-1]], [YY0[0], YY0[-1]], 'k-')
        axes2.plot([XX1[0], XX1[-1]], [YY1[0], YY1[-1]], 'k-')

        try:
            axes2.plot([Xf0, Xf1], [Yf0, Yf1], 'k-')
        except:
            FaultLine = 'None'

        axes2.plot(GPSx, GPSy, 'b^')
        axes2.plot(GPSxx, GPSyy, 'k^')
        if gpsFile != 'Nogps':
            axes2.plot(IDXref, IDYref, 'r^')
        axes2.axis('image')

        figName = 'transect_area_' + str(Np) + '.png'
        print('writing ' + figName)
        plt.savefig(figName)

        #############################################################################
        fig = plt.figure()
        fig.set_size_inches(10, 4)
        ax = plt.Axes(
            fig,
            [0., 0., 1., 1.],
        )
        ax = fig.add_subplot(111)
        if display_InSAR in ['on', 'On', 'ON']:
            ax.plot(D / 1000.0,
                    transect * 1000,
                    'o',
                    ms=1,
                    mfc='Black',
                    linewidth='0')

        ############################################################################
        # save the profile data:
        if save_to_mat in ['ON', 'on', 'On']:
            import scipy.io as sio
            matFile = 'transect' + str(Np) + '.mat'
            dataset = {}
            dataset['datavec'] = transect
            try:
                dataset['lat'] = transect_lat
                dataset['lon'] = transect_lon
            except:
                dataset['lat'] = 'Nan'
                dataset['lon'] = 'Nan'
            dataset['Unit'] = 'm'
            dataset['Distance_along_profile'] = D
            print('*****************************************')
            print('')
            print('writing transect to >>> ' + matFile)
            sio.savemat(matFile, {'dataset': dataset})
            print('')
            print('*****************************************')

        #############################################################################
        if display_Standard_deviation in ['on', 'On', 'ON']:

            for i in np.arange(0.0, 1.01, 0.01):
                # ,color='#DCDCDC')#'LightGrey')
                ax.plot(D / 1000.0, (avgInSAR - i * stdInSAR) * 1000,
                        '-',
                        color='#DCDCDC',
                        alpha=0.5)
            for i in np.arange(0.0, 1.01, 0.01):
                ax.plot(D / 1000.0, (avgInSAR + i * stdInSAR) * 1000,
                        '-',
                        color='#DCDCDC',
                        alpha=0.5)  # 'LightGrey')
        #############################################################################
        if display_Average in ['on', 'On', 'ON']:
            ax.plot(D / 1000.0, avgInSAR * 1000, 'r-')
        ###########
        try:
            ax.plot(DistGPS / 1000.0,
                    -1 * GPS_in_bound,
                    '^',
                    ms=10,
                    mfc='Cyan')
        except:
            print('')
        ax.set_ylabel('LOS velocity [mm/yr]', fontsize=26)
        ax.set_xlabel('Distance along profile [km]', fontsize=26)

        ###################################################################
        # lower and higher bounds for diplaying the profile

        try:
            lbound
            hbound
        except:
            lbound = np.nanmin(transect) * 1000
            hbound = np.nanmax(transect) * 1000

        ###################################################################
        # To plot the Fault location on the profile
        ax.plot([df0_km / 1000.0, df0_km / 1000.0], [lbound, hbound],
                '--',
                color='black',
                linewidth='2')

        ###################################################################

        try:
            ax.set_ylim(lbound, hbound)
        except:
            ylim = 'no'

        try:
            ax.set_xlim(x_lbound, x_hbound)
        except:
            xlim = 'no'

        ##########
        # Temporary To plot DEM
        # majorLocator = MultipleLocator(5)
        # ax.yaxis.set_major_locator(majorLocator)
        # minorLocator   = MultipleLocator(1)
        # ax.yaxis.set_minor_locator(minorLocator)

        # plt.tick_params(which='major', length=15,width=2)
        # plt.tick_params(which='minor', length=6,width=2)

        # try:
        #    for tick in ax.xaxis.get_major_ticks():
        #             tick.label.set_fontsize(26)
        #    for tick in ax.yaxis.get_major_ticks():
        #             tick.label.set_fontsize(26)
        #
        #    plt.tick_params(which='major', length=15,width=2)
        #    plt.tick_params(which='minor', length=6,width=2)
        # except:
        #    print 'couldn not fix the ticks! '

        figName = 'transect_' + str(Np) + '.png'
        print('writing ' + figName)
        plt.savefig(figName)
        print('')
        print('________________________________')
def demo(image_name,
         image_no,
         image_index,
         net,
         image_dir="/media/sadaf/e4da0f25-29be-4c9e-a432-3193ff5f5baf/Code/AWA_data/Animals_with_Attributes2/clean_images",
         output_pattern="/media/sadaf/e4da0f25-29be-4c9e-a432-3193ff5f5baf/Code/AWA_data/Animals_with_Attributes2/clean_images_1/clean_bb{}.jpg"):
    """Draw attribute-labelled detection boxes on one image and save the plot.

    Runs ``im_detect`` on the image, ranks every proposal by its best
    per-class score among boxes surviving per-class NMS, keeps exactly
    ``min_boxes``/``max_boxes`` (15) proposals when the confidence threshold
    yields too few/too many, and draws a box plus attribute label for each
    kept proposal whose top attribute scores above ``attr_thresh`` and is
    listed in this image's entries of the module-level ``att_names``.

    Side effect: sets the global ``cfg.TEST.NMS`` to 0.6.

    Args:
        image_name: file name of the input image inside ``image_dir``.
        image_no: value substituted into ``output_pattern`` (``{}``) to name
            the saved figure.
        image_index: selects the block ``[image_index * scale,
            image_index * scale + scale)`` of the module-level ``att_names``
            and ``att_cls`` arrays.
        net: network object whose ``blobs`` mapping exposes ``rois``,
            ``cls_prob`` and ``attr_prob`` after ``im_detect`` runs.
        image_dir: directory containing the input images (defaults to the
            original hard-coded location).
        output_pattern: ``str.format`` pattern for the output path (defaults
            to the original hard-coded location).

    Raises:
        FileNotFoundError: if OpenCV cannot read the input image.
    """
    conf_thresh = 0.3
    min_boxes = 15
    max_boxes = 15
    attr_thresh = 0.1
    colors = [
        "blue", "green", "red", "cyan", "magenta", "yellow", "black", "white",
        "darkblue", "orchid", "springgreen", "lime", "deepskyblue",
        "mediumvioletred", "maroon", "orangered"
    ]
    # NOTE: mutates the shared global config, exactly as before.
    cfg.TEST.NMS = 0.6

    image_path = os.path.join(image_dir, image_name)
    im = cv2.imread(image_path)
    if im is None:
        # cv2.imread reports a missing/unreadable file by returning None;
        # fail here instead of deep inside im_detect.
        raise FileNotFoundError('could not read image: ' + image_path)

    scores, _, _, _ = im_detect(net, im)

    # Keep the original proposal boxes (ignore the regressed bbox outputs)
    # and scale them back to raw image coordinates.
    rois = net.blobs['rois'].data.copy()
    _, im_scales = _get_blobs(im, None)
    cls_boxes = rois[:, 1:5] / im_scales[0]
    print(len(cls_boxes))
    cls_prob = net.blobs['cls_prob'].data
    attr_prob = net.blobs['attr_prob'].data

    # For every proposal, record its best foreground-class score among the
    # boxes that survive that class's NMS (class 0 is skipped).
    max_conf = np.zeros((rois.shape[0]))
    for cls_ind in range(1, cls_prob.shape[1]):
        cls_scores = scores[:, cls_ind]
        dets = np.hstack(
            (cls_boxes, cls_scores[:, np.newaxis])).astype(np.float32)
        # Integer dtype so an empty NMS result is still a valid index array.
        keep = np.array(nms(dets, cfg.TEST.NMS), dtype=np.intp)
        max_conf[keep] = np.where(cls_scores[keep] > max_conf[keep],
                                  cls_scores[keep], max_conf[keep])

    keep_boxes = np.where(max_conf >= conf_thresh)[0]
    if len(keep_boxes) < min_boxes:
        keep_boxes = np.argsort(max_conf)[::-1][:min_boxes]
    elif len(keep_boxes) > max_boxes:
        keep_boxes = np.argsort(max_conf)[::-1][:max_boxes]

    # Attribute / class names recorded for this image.
    start = image_index * scale
    stop = start + scale
    att_unique = np.unique(att_names[start:stop])
    print(att_unique)
    cls_unique = np.unique(att_cls[start:stop])
    print(cls_unique)

    im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
    height = float(im.shape[0])
    width = float(im.shape[1])
    fig = plt.figure()
    try:
        fig.set_size_inches(width / height, 1, forward=False)
        ax = plt.Axes(fig, [0., 0., 1., 1.])
        ax.set_axis_off()
        fig.add_axes(ax)
        ax.imshow(im)

        boxes = cls_boxes[keep_boxes]
        attr = np.argmax(attr_prob[keep_boxes][:, 1:], axis=1)
        attr_conf = np.max(attr_prob[keep_boxes][:, 1:], axis=1)

        for i, bbox in enumerate(boxes):
            # Boxes starting exactly at 0 are shifted to 1 (kept from the
            # original; presumably to keep them visible on the border).
            if bbox[0] == 0:
                bbox[0] = 1
            if bbox[1] == 0:
                bbox[1] = 1
            if attr_conf[i] <= attr_thresh:
                continue
            # attr indexes attr_prob[:, 1:], hence the +1 into attributes.
            cls = attributes[attr[i] + 1]
            if cls not in att_unique:
                continue
            # Modulo keeps this safe if max_boxes ever exceeds len(colors).
            color = colors[i % len(colors)]
            ax.add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1],
                              fill=False,
                              edgecolor=color,
                              linewidth=0.3,
                              alpha=0.5))
            ax.text(bbox[0],
                    bbox[1] + 30,
                    '%s' % (cls),
                    bbox=dict(facecolor='blue', alpha=0, linewidth=0.2),
                    fontsize=2,
                    color=color)

        fig.savefig(output_pattern.format(image_no), dpi=1500)
    finally:
        # Close this figure even if drawing/saving fails.
        plt.close(fig)
def vis_one_image(im,
                  im_name,
                  output_dir,
                  boxes,
                  segms=None,
                  keypoints=None,
                  body_uv=None,
                  thresh=0.9,
                  kp_thresh=2,
                  dpi=200,
                  box_alpha=0.0,
                  dataset=None,
                  show_class=False,
                  ext='jpg',
                  frame_no=None):
    """Visual debugging of detections.

    Draws boxes, optional class labels, mask contours, keypoint skeletons and
    (when ``body_uv`` is given) DensePose U/V/part contours over ``im``. The
    result is shown in an OpenCV window named "plot" and saved under
    ``<output_dir>/vid``.

    Args:
        im: background image; ``im.shape[1]`` x ``im.shape[0]`` pixels set the
            figure size at ``dpi``.
        im_name: source image name; used for the output file name when
            ``frame_no`` is None.
        output_dir: directory (created if needed) receiving a ``vid`` subdir.
        boxes: (N, 5) array of ``[x0, y0, x1, y1, score]`` rows, or a per-class
            list flattened by ``convert_from_cls_format``.
        segms: optional encoded masks, decoded with ``mask_util.decode``.
        keypoints: optional per-detection keypoint arrays; rows 0/1 are x/y and
            row 2 is the score compared against ``kp_thresh``.
        body_uv: optional DensePose output; ``body_uv[1][k]`` is a
            (channels, h, w) array for detection k whose channel 0 > 0 marks
            foreground. Skipped entirely when None.
        thresh: minimum score for a detection to be drawn.
        kp_thresh: minimum keypoint score for a keypoint/limb to be drawn.
        dpi: figure resolution used for sizing and saving.
        box_alpha: opacity of the box outline (0 hides it).
        dataset: passed to ``get_class_string`` for class labels.
        show_class: draw class labels; requires class ids, which only the
            per-class list form of ``boxes`` provides.
        ext: extension of the output file when ``frame_no`` is None.
        frame_no: if given, the output file is ``file%02d.png``.

    Returns:
        True after rendering, or None when no detection reaches ``thresh``.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Class ids only exist for the per-class list input format.
    classes = None
    if isinstance(boxes, list):
        boxes, segms, keypoints, classes = convert_from_cls_format(
            boxes, segms, keypoints)

    if boxes is None or boxes.shape[0] == 0 or max(boxes[:, 4]) < thresh:
        print("Box:None, Shape:0 or MaxBox below thresh")
        return

    dataset_keypoints, _ = keypoint_utils.get_keypoints()

    if segms is not None and len(segms) > 0:
        masks = mask_util.decode(segms)

    optime = time.time()
    color_list = colormap(rgb=True) / 255

    kp_lines = kp_connections(dataset_keypoints)
    cmap = plt.get_cmap('rainbow')
    colors = [cmap(i) for i in np.linspace(0, 1, len(kp_lines) + 2)]

    fig = plt.figure(frameon=False)
    fig.set_size_inches(im.shape[1] / dpi, im.shape[0] / dpi)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.axis('off')
    fig.add_axes(ax)
    ax.imshow(im)

    # Display in largest to smallest order to reduce occlusion
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    sorted_inds = np.argsort(-areas)

    mask_color_id = 0
    for i in sorted_inds:
        bbox = boxes[i, :4]
        score = boxes[i, -1]
        if score < thresh:
            continue

        # show box (invisible by default: box_alpha=0.0)
        ax.add_patch(
            plt.Rectangle((bbox[0], bbox[1]),
                          bbox[2] - bbox[0],
                          bbox[3] - bbox[1],
                          fill=False,
                          edgecolor='g',
                          linewidth=0.5,
                          alpha=box_alpha))

        if show_class and classes is not None:
            ax.text(bbox[0],
                    bbox[1] - 2,
                    get_class_string(classes[i], score, dataset),
                    fontsize=10,
                    family='serif',
                    bbox=dict(facecolor='g',
                              alpha=0.4,
                              pad=0,
                              edgecolor='none'),
                    color='white')

        # show mask contours
        if segms is not None and len(segms) > i:
            # Blend the palette colour towards white. This builds a new array
            # instead of writing through a view into color_list.
            w_ratio = .4
            base_color = color_list[mask_color_id % len(color_list), 0:3]
            color_mask = base_color * (1 - w_ratio) + w_ratio
            mask_color_id += 1

            e = masks[:, :, i]
            # findContours returns (image, contours, hierarchy) in OpenCV 3
            # but (contours, hierarchy) in OpenCV 2/4; [-2] works for both.
            contours = cv2.findContours(e.copy(), cv2.RETR_CCOMP,
                                        cv2.CHAIN_APPROX_NONE)[-2]
            for c in contours:
                ax.add_patch(
                    Polygon(c.reshape((-1, 2)),
                            fill=True,
                            facecolor=color_mask,
                            edgecolor='w',
                            linewidth=1.2,
                            alpha=0.5))

        # show keypoints
        if keypoints is not None and len(keypoints) > i:
            kps = keypoints[i]
            ax.autoscale(False)
            for l, kp_line in enumerate(kp_lines):
                i1 = kp_line[0]
                i2 = kp_line[1]
                if kps[2, i1] > kp_thresh and kps[2, i2] > kp_thresh:
                    ax.plot([kps[0, i1], kps[0, i2]],
                            [kps[1, i1], kps[1, i2]],
                            color=colors[l],
                            linewidth=1.0,
                            alpha=0.7)
                if kps[2, i1] > kp_thresh:
                    ax.plot(kps[0, i1],
                            kps[1, i1],
                            '.',
                            color=colors[l],
                            markersize=3.0,
                            alpha=0.7)
                if kps[2, i2] > kp_thresh:
                    ax.plot(kps[0, i2],
                            kps[1, i2],
                            '.',
                            color=colors[l],
                            markersize=3.0,
                            alpha=0.7)

            # add mid shoulder / mid hip for better visualization
            r_sh = dataset_keypoints.index('right_shoulder')
            l_sh = dataset_keypoints.index('left_shoulder')
            r_hip = dataset_keypoints.index('right_hip')
            l_hip = dataset_keypoints.index('left_hip')
            nose = dataset_keypoints.index('nose')
            mid_shoulder = (kps[:2, r_sh] + kps[:2, l_sh]) / 2.0
            sc_mid_shoulder = np.minimum(kps[2, r_sh], kps[2, l_sh])
            mid_hip = (kps[:2, r_hip] + kps[:2, l_hip]) / 2.0
            sc_mid_hip = np.minimum(kps[2, r_hip], kps[2, l_hip])
            if sc_mid_shoulder > kp_thresh and kps[2, nose] > kp_thresh:
                ax.plot([mid_shoulder[0], kps[0, nose]],
                        [mid_shoulder[1], kps[1, nose]],
                        color=colors[len(kp_lines)],
                        linewidth=1.0,
                        alpha=0.7)
            if sc_mid_shoulder > kp_thresh and sc_mid_hip > kp_thresh:
                ax.plot([mid_shoulder[0], mid_hip[0]],
                        [mid_shoulder[1], mid_hip[1]],
                        color=colors[len(kp_lines) + 1],
                        linewidth=1.0,
                        alpha=0.7)

    # DensePose visualization: only when IUV output was supplied.
    if body_uv is not None:
        iuv_fields = body_uv[1]
        all_coords = np.zeros(im.shape)
        all_inds = np.zeros([im.shape[0], im.shape[1]])
        # Detections are visited in ascending score order. Pixels already
        # written by an earlier detection are kept; only zero pixels are filled.
        for rank, ind in enumerate(np.argsort(boxes[:, 4])):
            entry = boxes[ind, :]
            if entry[4] <= 0.65:
                continue
            x0 = int(entry[0])
            y0 = int(entry[1])
            output = iuv_fields[ind]
            h = output.shape[1]
            w = output.shape[2]

            # Basic slicing yields views, so these updates write into the
            # full-image buffers directly.
            coords_region = all_coords[y0:y0 + h, x0:x0 + w, :]
            empty = coords_region == 0
            coords_region[empty] = output.transpose([1, 2, 0])[empty]

            current_mask = (output[0, :, :] > 0).astype(np.float32)
            inds_region = all_inds[y0:y0 + h, x0:x0 + w]
            unset = inds_region == 0
            inds_region[unset] = current_mask[unset] * rank

        all_coords[:, :, 1:3] = 255. * all_coords[:, :, 1:3]
        all_coords[all_coords > 255] = 255.
        all_coords = all_coords.astype(np.uint8)
        all_inds = all_inds.astype(np.uint8)
        ax.contour(all_coords[:, :, 1] / 256., 10, linewidths=1)
        ax.contour(all_coords[:, :, 2] / 256., 10, linewidths=1)
        ax.contour(all_inds, linewidths=3)

    ax.axis('off')
    fig.canvas.draw()
    # Convert the rendered RGB canvas into an array (frombuffer replaces the
    # deprecated binary mode of np.fromstring).
    canvas_w, canvas_h = fig.canvas.get_width_height()
    img = np.frombuffer(fig.canvas.tostring_rgb(),
                        dtype=np.uint8).reshape(canvas_h, canvas_w, 3)
    # img is rgb, convert to opencv's default bgr
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    cv2.imshow("plot", img)
    cv2.waitKey(1)
    print('\t-Visualized in {: .3f}s'.format(time.time() - optime))

    filetime = time.time()
    out_dir = os.path.join(output_dir, 'vid')
    if not os.path.exists(out_dir):
        os.mkdir(out_dir)
    if frame_no is None:
        # Previously '%02d' % None raised TypeError for the default frame_no.
        out_file = os.path.basename(im_name) + '.' + ext
    else:
        out_file = 'file%02d.png' % frame_no
    fig.savefig(os.path.join(out_dir, out_file), dpi=dpi)
    print('\t-Output file wrote in {: .3f}s'.format(time.time() - filetime))
    plt.close('all')
    return True
Exemple #29
0
print('\nPlotting results')

#fig = plt.figure(figsize=(10, 1.2),frameon= False)
#gs = gridspec.GridSpec(1, 10, wspace=0.05, hspace=0.05)

# Per-sample predicted label and its score, taken from the score rows of y_tmp.
label = np.argmax(y_tmp, axis=1)
proba = np.max(y_tmp, axis=1)

import os
# savefig fails if the output directory does not exist.
if not os.path.isdir("img"):
    os.makedirs("img")

# Save at most the first 10 samples, each as an axis-free image whose figure
# size matches the image's pixel size at dpi=100.
for i in range(min(10, len(X_tmp), len(label))):

    # Replace the decimal point so savefig does not read the digits after it
    # as a file extension.
    probab = "-".join(str(proba[i]).split("."))
    Shape = X_tmp[i].shape
    fig = plt.figure(figsize=(Shape[1] / 100, Shape[0] / 100),
                     dpi=100,
                     frameon=False)
    ax = plt.Axes(fig, [0., 0., 1., 1.])
    ax.set_axis_off()
    fig.add_axes(ax)
    im = ax.imshow(X_tmp[i])
    # 'bbox_inches' (was misspelled 'bbbox_inches', which Matplotlib rejects
    # or ignores depending on version).
    fig.savefig("img/{}_{}".format(label[i], probab),
                bbox_inches='tight',
                pad_inches=0,
                dpi=100)
    plt.show()
    # Release the figure; with non-interactive backends show() does not.
    plt.close(fig)
#    ax = fig.add_subplot(gs[0, i])
#
#    plt.show(x_tmp[i])
#    ax.imshow(X_tmp[i], cmap='gray', interpolation='none')
#    ax.set_xticks([])
#    ax.set_yticks([])
#    ax.set_xlabel('{0} ({1:.2f})'.format(label[i], proba[i]),
Exemple #30
0
    def __init__(self, sequence: str):
        """Build and show the interactive demo window for one sequence.

        Loads the frames of ``sequence``, constructs the model on them, and
        creates the Qt widgets:

        - a frame counter;
        - the prev/play/next and propagate buttons;
        - a frame slider;
        - an overlay-mode combo box;
        - a matplotlib canvas wired to mouse press/release/motion handlers;
        - a repeating timer.

        It then initialises the per-frame mask and drawing state, and finally
        shows the window.

        Args:
            sequence: name of the sequence; frames are loaded from
                ``'sequences/' + sequence``.
        """
        super().__init__()
        self.sequence = sequence
        self.frames = load_frames('sequences/' + self.sequence)
        # The first three dims of frames are taken as (num_frames, height,
        # width); any further dims (presumably channels) are ignored here.
        self.num_frames, self.height, self.width = self.frames.shape[:3]
        # init model
        self.model = model(self.frames)

        # set window: frame-sized, plus 100 px of height -- presumably room
        # for the slider and navigation row under the canvas
        self.setWindowTitle('Demo: Interaction-and-Propagation Network')
        self.setGeometry(100, 100, self.width, self.height + 100)

        # buttons, each wired to its click handler
        self.prev_button = QPushButton('Prev')
        self.prev_button.clicked.connect(self.on_prev)
        self.next_button = QPushButton('Next')
        self.next_button.clicked.connect(self.on_next)
        self.play_button = QPushButton('Play')
        self.play_button.clicked.connect(self.on_play)
        self.run_button = QPushButton('Propagate!')
        self.run_button.clicked.connect(self.on_run)

        # LCD: a small read-only text box used as a frame counter, showing
        # "<current index> / <last index>" (starts at 0 / num_frames - 1)
        self.lcd = QTextEdit()
        self.lcd.setReadOnly(True)
        self.lcd.setMaximumHeight(28)
        self.lcd.setMaximumWidth(100)
        self.lcd.setText('{: 3d} / {: 3d}'.format(0, self.num_frames - 1))

        # slider: one tick per frame over indices 0 .. num_frames - 1.
        # valueChanged is connected only after setValue(0), so initialising
        # the value does not call self.slide.
        self.slider = QSlider(Qt.Horizontal)
        self.slider.setMinimum(0)
        self.slider.setMaximum(self.num_frames - 1)
        self.slider.setValue(0)
        self.slider.setTickPosition(QSlider.TicksBelow)
        self.slider.setTickInterval(1)
        self.slider.valueChanged.connect(self.slide)

        # combobox: overlay-mode selector; its first item ('fade') matches
        # the initial self.viz_mode assigned further down in this method
        self.combo = QComboBox(self)
        self.combo.addItem("fade")
        self.combo.addItem("davis")
        self.combo.addItem("checker")
        self.combo.addItem("color")
        self.combo.currentTextChanged.connect(self.set_viz_mode)

        # canvas: a matplotlib Figure holding one axes that spans the whole
        # figure ([0, 0, 1, 1]) with axis decorations hidden
        self.fig = plt.Figure()
        self.ax = plt.Axes(self.fig, [0., 0., 1., 1.])
        self.ax.set_axis_off()
        self.fig.add_axes(self.ax)

        # Must be created before mpl_connect below: it attaches itself as
        # self.fig.canvas.
        self.canvas = FigureCanvas(self.fig)

        # Route mouse press/release/motion on the canvas to the handlers.
        # The connection ids are kept, presumably so they can be
        # disconnected later.
        self.cidpress = self.fig.canvas.mpl_connect('button_press_event',
                                                    self.on_press)
        self.cidrelease = self.fig.canvas.mpl_connect('button_release_event',
                                                      self.on_release)
        self.cidmotion = self.fig.canvas.mpl_connect('motion_notify_event',
                                                     self.on_motion)

        # navigator: counter | prev, play, next | overlay-mode selector |
        # propagate button, with stretches between the groups
        navi = QHBoxLayout()
        navi.addWidget(self.lcd)
        navi.addWidget(self.prev_button)
        navi.addWidget(self.play_button)
        navi.addWidget(self.next_button)
        navi.addStretch(1)
        navi.addWidget(QLabel('Overlay Mode'))
        navi.addWidget(self.combo)
        navi.addStretch(1)
        navi.addWidget(self.run_button)

        # main layout: canvas, then slider, then the navigation row.
        # NOTE(review): the navigation row gets stretch 1 and the canvas 0,
        # so surplus vertical space goes to the row -- confirm intended.
        layout = QVBoxLayout()
        layout.addWidget(self.canvas)
        layout.addWidget(self.slider)
        layout.addLayout(navi)
        layout.setStretchFactor(navi, 1)
        layout.setStretchFactor(self.canvas, 0)
        self.setLayout(layout)

        # timer: repeating (not single-shot), fires self.on_time; it is not
        # started in this method (presumably started by on_play -- confirm)
        self.timer = QTimer()
        self.timer.setSingleShot(False)
        self.timer.timeout.connect(self.on_time)

        # initialize visualize: overlay mode, an all-zero uint8 mask of shape
        # (num_frames, height, width), and `cursur` (sic; presumably the
        # index of the displayed frame) at 0. These must be set before
        # show_current() is called.
        self.viz_mode = 'fade'
        self.current_mask = np.zeros(
            (self.num_frames, self.height, self.width), dtype=np.uint8)
        self.cursur = 0
        self.on_showing = None
        self.show_current()

        # initialize action: reset scribbles and the mouse-drawing state
        # (pressed flag, presumably the in-progress drawing, and the list of
        # strokes drawn so far)
        self.reset_scribbles()
        self.pressed = False
        self.on_drawing = None
        self.drawn_strokes = []

        # finally make the window visible
        self.show()