Ejemplo n.º 1
0
def plot_cube(color_cube, fig=None, save=False, title=None, path='', normalize=True):
    """Draw a ColorDensityCube as a 3D scatter of RGB-coloured markers.

    The 0..255 range of each axis is sampled at the centre of every
    ``win``-wide bin. Marker colour is (x/256, y/256, z/256). Marker size is
    proportional to the cube value for the (x, y) cell, and the marker edge is
    white for values >= 0 and red for negative values.

    :param color_cube: ``metrics_color.ColorDensityCube`` to draw.
    :param fig: optional existing figure. When given, the plot is added as
        subplot 212 and the function returns right away. The caller is then
        responsible for layout, title, display and saving.
    :param save: when True and ``fig`` is None, save the figure to
        ``path + title + '.png'``. ``title`` must then be provided.
    :param title: figure title text.
    :param path: prefix (e.g. a directory ending in '/') for the saved file.
    :param normalize: plot ``get_normalized()`` instead of ``get_cube()``.
    """
    assert isinstance(color_cube, metrics_color.ColorDensityCube)
    win = color_cube.get_win()
    res = color_cube.get_res()
    fig_was_none = not fig

    if fig_was_none:
        fig = plt.figure(dpi=200)
        ax = fig.add_subplot(111, projection='3d')
    else:
        ax = fig.add_subplot(212, projection='3d')

    # Stretch the default projection matrix (burst view along the y-axis).
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), np.diag([0.8, 1.2, 0.8, 1.1]))
    ax.view_init(azim=-25)

    # Bin centres along each colour axis; `range` works on Python 2 and 3.
    centres = list(range(win // 2, 256, win))
    # The third colour channel is identical for every (x, y) column. Use
    # float division so the channel values stay in [0, 1) on any Python version.
    blue = np.array(centres) / 256.0
    # Fetch the cube once rather than once per (x, y) cell.
    cube = color_cube.get_normalized() if normalize else color_cube.get_cube()

    for x in centres:
        for y in centres:
            # Presumably a length-`res` vector along the third axis; it must
            # match len(centres) for the colour array below. TODO confirm.
            size = cube[x // win][y // win] * 5000 / res
            color = np.swapaxes([np.repeat(x / 256.0, res),
                                 np.repeat(y / 256.0, res),
                                 blue], 0, 1)
            edge_colors = np.where(size >= 0.0, 'w', 'r')  # sign -> edge colour
            ax.scatter(x, y, centres, c=color, s=abs(size),
                       edgecolor=edge_colors, alpha=1)

    if not fig_was_none:
        # The caller owns the figure: leave layout, display and saving to them.
        return

    plt.tight_layout()
    fig.text(0.5, 0.975, title, ha='center')
    if save:
        assert title is not None
        # Save before plt.show(): once a blocking show() returns, pyplot has
        # torn the figure down and saving would write an empty image.
        fig.savefig(path + title + '.png')
    plt.show()
    plt.close()
Ejemplo n.º 2
0
def short_proj():
    """Return ``Axes3D``'s projection for ``ax``, right-multiplied by ``scale``.

    Takes no arguments. ``ax`` and ``scale`` are looked up from the enclosing
    scope. Presumably this is installed as ``ax.get_proj`` to stretch a 3D
    plot, and ``scale`` is a 4x4 (diagonal) matrix. TODO confirm at the
    definition site.
    """
    base_projection = Axes3D.get_proj(ax)
    return np.dot(base_projection, scale)
Ejemplo n.º 3
0
def plot3D_stack_with_overlay_1axes(ax,
                                    na_3D,
                                    overlay_image,
                                    intensity_threshold,
                                    trim_overlay,
                                    overlay_type,
                                    a_shapes,
                                    view_init,
                                    shape_by_shape,
                                    number=1,
                                    label_exterior_cell=1,
                                    verbose=True):
    """Plot a 3D stack, an optional overlay and shapes on one 3D axes.

    Draws three layers:
    - Voxels of ``na_3D`` strictly above ``intensity_threshold``, drawn with
      the module-level ``d_plot3D_red`` style.
    - Selected voxels of ``overlay_image``, drawn with ``d_plot3D_green``.
    - The pixels of every 'Sphere'/'Cell' shape in ``a_shapes``, drawn with
      ``d_plot3D_sphere`` and a colour cycled from ``a_colors``.

    :param ax: matplotlib 3D axes to draw on.
    :param na_3D: 3D array indexed as (z, y, x).
    :param overlay_image: array indexed like ``na_3D``. An empty sequence
        skips the overlay.
    :param intensity_threshold: threshold for selecting ``na_3D`` voxels.
    :param trim_overlay: if truthy, it is called as
        ``trim_overlay(na_3D, overlay_image)`` before a 'level_set' overlay is
        plotted, and its return value is ignored. NOTE(review): this parameter
        shadows any module-level function of the same name, and passing a plain
        ``True`` would raise TypeError. Callers presumably pass the function.
    :param overlay_type: 'level_set' plots overlay voxels < 0 (the level-set
        interior). Any other value plots non-zero voxels.
    :param a_shapes: iterable of shape objects. Shapes whose class is named
        'Sphere' or 'Cell' are drawn from ``shape.l_pixels`` (reversed before
        unpacking, presumably (z, y, x) -> (x, y, z); TODO confirm).
    :param view_init: optional (elev, azim) camera pair. When given, it is also
        appended to the module-level ``d_save_info['f_name']``.
    :param shape_by_shape: hide the x/y axes and skip the summary title.
    :param number: unused.
    :param label_exterior_cell: a Cell with this label gets an '(exterior)'
        title suffix.
    :param verbose: print diagnostics (empty selection, overlay dtype).
    """
    z, y, x = na_3D.shape

    if shape_by_shape:
        ax.get_xaxis().set_visible(False)
        ax.get_yaxis().set_visible(False)
        ax.get_xaxis().set_ticks([])
        ax.get_yaxis().set_ticks([])
    else:
        ax.set_xlabel("x-axis")
        ax.set_ylabel("y-axis")
        ax.set_zlabel("z-axis")

    ax.set_xlim(0, x)
    ax.set_ylim(0, y)
    ax.set_zlim(z, 0)  # inverted z: the highest slice index is at the bottom

    # mplot3d has no per-axis aspect scaling, so stretch the projection matrix
    # instead. See:
    # https://stackoverflow.com/questions/30223161/matplotlib-mplot3d-how-to-increase-the-size-of-an-axis-stretch-in-a-3d-plo/30315313
    pixel_width = 0.1294751  # check fiji properties
    voxel_depth = 1.1000000
    # One step in z spans ~8.5 steps in x or y in physical distance (1.1 / 0.1295).
    xy_vs_z = voxel_depth / pixel_width
    # Extra z stretch so the slices are easier to tell apart. 1 = true size;
    # 5 = max, otherwise the plot no longer fits on the page.
    Z_STRETCH = 3
    # Computed once here rather than on every redraw inside the lambda.
    stretch = np.diag([x / y, 1, z / y * xy_vs_z * Z_STRETCH, 1])
    ax.get_proj = lambda: np.dot(Axes3D.get_proj(ax), stretch)

    if view_init:
        ax.view_init(elev=view_init[0], azim=view_init[1])

    Z1, Y1, X1 = np.where(na_3D > intensity_threshold)
    if verbose and not len(Z1):
        print('warning : no points selected with intensity threshold',
              intensity_threshold, 'the maximum of the stack is ',
              np.max(na_3D))
    ax.plot3D(X1, Y1, Z1, **d_plot3D_red)

    # Show a level set / overlay image on top of the original.
    if len(overlay_image):
        if overlay_type == 'level_set':
            if verbose:
                print(overlay_image.dtype)
            if trim_overlay:
                trim_overlay(na_3D, overlay_image)
            # The interior of a level set is below zero.
            Z2, Y2, X2 = np.where(overlay_image < 0)
        else:
            Z2, Y2, X2 = np.where(overlay_image != 0)
        ax.plot3D(X2, Y2, Z2, **d_plot3D_green)

    # Draw the shapes (spheres / cells) inside the stack.
    for ix_stack, shape in enumerate(a_shapes):
        shape_kind = shape.__class__.__name__
        if shape_kind == 'Sphere':
            ax.plot3D(*shape.l_pixels[::-1],
                      **d_plot3D_sphere,
                      color=a_colors[ix_stack % len(a_colors)])
        elif shape_kind == 'Cell' and shape.l_pixels:
            ax.plot3D(*shape.l_pixels[::-1],
                      **d_plot3D_sphere,
                      color=a_colors[ix_stack % len(a_colors)])
            if shape.label == label_exterior_cell:
                ax.set_title(str(shape.label) + '(exterior)')
            else:
                ax.set_title(shape.label)

    if not shape_by_shape:
        title = 'intensity_threshold={0};view_init={1}'.format(
            intensity_threshold, view_init)
        ax.set_title(title, fontsize=12)

    if view_init:
        # NOTE(review): this appends to the module-level file name on every
        # call, so repeated calls accumulate suffixes unless the caller resets
        # d_save_info['f_name']. Confirm this is intended.
        d_save_info['f_name'] = d_save_info['f_name'] + '_elev=' + str(
            view_init[0]) + '_azim=' + str(view_init[1])
Ejemplo n.º 4
0
 def short_proj():
     """Projection matrix of ``ax`` stretched by the ``scale`` matrix.

     ``ax`` and ``scale`` are free names resolved in the enclosing scope.
     Presumably this replaces ``ax.get_proj`` and ``scale`` is 4x4. TODO confirm.
     """
     projection = Axes3D.get_proj(ax)
     stretched = np.dot(projection, scale)
     return stretched
Ejemplo n.º 5
0
def demo_visualize(args, demo_loader, disp_net, ego_pose_net, obj_pose_net):
    global device
    torch.set_printoptions(sci_mode=False)
    np.set_printoptions(suppress=True)
    # np.set_printoptions(formatter={'all':lambda x: str(x)})

    # switch to eval mode
    disp_net.eval().to(device)
    ego_pose_net.eval().to(device)
    obj_pose_net.eval().to(device)

    ego_global_mat = np.identity(4)
    ego_global_mats = [ego_global_mat]
    
    objOs, objHs, objXs, objYs, objZs = [], [], [], [], []
    objIDs = []
    colors = ['yellow', 'lightskyblue', 'lime', 'magenta', 'orange', 'coral', 'gold', 'cyan']

    vidx = 0

    for i, (ref_img, tgt_img, ref_seg, tgt_seg, intrinsics, intrinsics_inv) in enumerate(demo_loader):

        ref_img = ref_img.to(device)
        tgt_img = tgt_img.to(device)
        ref_seg = ref_seg.to(device)
        tgt_seg = tgt_seg.to(device)
        intrinsics = intrinsics.to(device)
        intrinsics_inv = intrinsics_inv.to(device)

        # input instance masking
        ref_bg_mask = 1 - (ref_seg[:,1:].sum(dim=1, keepdim=True)>0).float()
        tgt_bg_mask = 1 - (tgt_seg[:,1:].sum(dim=1, keepdim=True)>0).float()
        ref_bg_img = ref_img * ref_bg_mask * tgt_bg_mask
        tgt_bg_img = tgt_img * ref_bg_mask * tgt_bg_mask
        num_inst = int(ref_seg[:,0,0,0])
        num_insts = [[num_inst], [num_inst]]

        # tracking info
        if len(objIDs) == 0: 
            objIDs.append( np.arange(num_inst).tolist() )
        else:
            # -> ref_seg의 인스턴스들이 tgt_seg_prev의 몇 번째 채널 인스턴스에 매칭되는가?
            p2c_iou, p2c_idx = inst_iou(ref_seg.cpu(), tgt_seg_prev.cpu(), torch.ones(1,1,ref_seg.size(2),ref_seg.size(3)).type_as(ref_seg).cpu())
            p2c_iou = p2c_iou[1:]
            p2c_idx = p2c_idx[1:] - 1
            newColorID = list(set(np.arange(len(colors))) - set(objIDs[-1]))
            newID = []

            for ii, iou in enumerate(p2c_iou):
                if iou > 0.5:
                    newID.append( objIDs[-1][int(p2c_idx[ii])] )
                elif iou != iou:
                    break;
                else:
                    newID.append( newColorID[0] )
                    newColorID = newColorID[1:]
            objIDs.append( newID )

        
        tgt_seg_prev = tgt_seg.clone()
        tgt_seg_prev[0,0] = 0
        objIDs_flatten = list(itertools.chain.from_iterable(objIDs))
        # pdb.set_trace()
        '''
            # plt.close('all')
            ea1 = 4; ea2 = 5; ii = 1;
            fig = plt.figure(99, figsize=(20, 10))
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(tgt_seg_prev[0,0].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(tgt_seg_prev[0,1].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(tgt_seg_prev[0,2].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(tgt_seg_prev[0,3].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(tgt_seg_prev[0,4].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(ref_seg[0,0].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(ref_seg[0,1].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(ref_seg[0,2].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(ref_seg[0,3].cpu()); plt.colorbar(); 
            fig.add_subplot(ea1,ea2,ii); ii += 1;
            plt.imshow(ref_seg[0,4].cpu()); plt.colorbar(); 
            plt.tight_layout(); plt.ion(); plt.show()

        '''

        # compute depth & camera motion
        ref_depth = 1 / disp_net(ref_img)
        tgt_depth = 1 / disp_net(tgt_img)
        ego_pose = ego_pose_net(tgt_bg_img, ref_bg_img)
        ego_pose_inv = ego_pose_net(ref_bg_img, tgt_bg_img)
        # ego_pose = ego_pose_net(tgt_img, ref_img)
        # ego_pose_inv = ego_pose_net(ref_img, tgt_img)

        ego_mat = pose_vec2mat(ego_pose).squeeze(0).cpu().detach().numpy()
        ego_mat = np.concatenate([ego_mat, np.array([0, 0, 0, 1]).reshape(1,4)], axis=0)
        ego_global_mat = ego_global_mat @ ego_mat
        ego_global_mats.append(ego_global_mat)

        ### Batch-wise computing ###    rtt_Is, rtt_Ms, prj_Ds, cmp_Ds  --> (19.11.18) change from fw to iw for bg regions
        ### Outputs:  warped-masked-bg-img,  valid-bg-mask,  valid-bg-proj-depth,  valid-bg-comp-depth ###
        ### NumScales(1) >> NumRefs(2) >> I-M-D-D(4) >> 2B(fwd/bwd)xCxHxW ###
        IMDDs = compute_batch_bg_warping(tgt_img, [ref_img, ref_img], [tgt_bg_mask, tgt_bg_mask], [ref_bg_mask, ref_bg_mask], 
                                         tgt_depth, [ref_depth, ref_depth], [ego_pose, ego_pose], [ego_pose_inv, ego_pose_inv], intrinsics)


        ref_obj_img = ref_img.repeat(num_inst,1,1,1) * ref_seg[0,1:1+num_inst].unsqueeze(1)
        tgt_obj_img = tgt_img.repeat(num_inst,1,1,1) * tgt_seg[0,1:1+num_inst].unsqueeze(1)
        ref_obj_mask = ref_seg[0,1:1+num_inst].unsqueeze(1)
        tgt_obj_mask = tgt_seg[0,1:1+num_inst].unsqueeze(1)
        ref_obj_depth = ref_depth.repeat(num_inst,1,1,1) * ref_seg[0,1:1+num_inst].unsqueeze(1)
        tgt_obj_depth = tgt_depth.repeat(num_inst,1,1,1) * tgt_seg[0,1:1+num_inst].unsqueeze(1)

        _, _, _, _,    r2t_obj_imgs, r2t_obj_masks, _, r2t_obj_sc_depths = \
            compute_reverse_warp_ego([ref_depth, ref_depth], [ref_obj_img, ref_obj_img], [ref_obj_mask, ref_obj_mask], [ego_pose_inv, ego_pose_inv], intrinsics, num_insts)
        _, _, _, _,    t2r_obj_imgs, t2r_obj_masks, _, t2r_obj_sc_depths = \
            compute_reverse_warp_ego([tgt_depth, tgt_depth], [tgt_obj_img, tgt_obj_img], [tgt_obj_mask, tgt_obj_mask], [ego_pose, ego_pose], intrinsics, num_insts)

        obj_pose = obj_pose_net(tgt_obj_img, r2t_obj_imgs[0])
        obj_pose_inv = obj_pose_net(ref_obj_img, t2r_obj_imgs[0])
        obj_pose = torch.cat([obj_pose, torch.zeros_like(obj_pose)], dim=1)
        obj_pose_inv = torch.cat([obj_pose_inv, torch.zeros_like(obj_pose_inv)], dim=1)

        obj_mat = pose_vec2mat(obj_pose).cpu().detach().numpy()
        obj_mat = np.concatenate([obj_mat, np.array([0, 0, 0, 1]).reshape(1,1,4).repeat(obj_pose.size(0),axis=0)], axis=1)
        obj_global_mat = ego_global_mat.reshape(1,4,4).repeat(obj_pose.size(0),axis=0) @ obj_mat

        obj_IMDDs, obj_ovls = compute_batch_obj_warping(tgt_img, [ref_img, ref_img], [tgt_obj_mask, tgt_obj_mask], [ref_obj_mask, ref_obj_mask], tgt_depth, [ref_depth, ref_depth], 
                                                        [ego_pose, ego_pose], [ego_pose_inv, ego_pose_inv], [obj_pose, obj_pose], [obj_pose_inv, obj_pose_inv], intrinsics, num_insts)


        tr_fwd, tr_bwd = compute_obj_translation(r2t_obj_sc_depths, t2r_obj_sc_depths, [tgt_obj_depth, tgt_obj_depth], [ref_obj_depth, ref_obj_depth], num_insts, intrinsics)
        
        rtt_obj_imgs, rtt_obj_masks, rtt_obj_depths, rtt_obj_sc_depths = compute_reverse_warp_obj(r2t_obj_sc_depths, r2t_obj_imgs, r2t_obj_masks, [-obj_pose, -obj_pose], intrinsics.repeat(num_inst,1,1), num_insts)
        # pdb.set_trace()
        '''
            sq = 0; bb = 0; 
            plt.close('all')
            plt.figure(1); plt.imshow(r2t_obj_sc_depths[sq][bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(2); plt.imshow(r2t_sc_depths[sq][0,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(3); plt.imshow(rtt_obj_sc_depths[sq][bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(4); plt.imshow(rtt_obj_sc_depth_2[bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(5); plt.imshow(rev_d2f[bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(6); plt.imshow(d2f[bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(7); plt.imshow(norm[bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(8); plt.imshow(r2t_obj_masks[sq][bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(9); plt.imshow(rtt_obj_masks[sq][bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()
            plt.figure(10); plt.imshow(rtt_obj_imgs[sq][bb,0].detach().cpu()); plt.colorbar(); plt.ion(); plt.show()

        '''

        _, _, r2t_ego_projected_depth, r2t_ego_computed_depth = inverse_warp2(ref_img, tgt_depth, ego_pose, intrinsics, ref_depth)


        ### KITTI ###
        if 'kitti' in args.data:
            xlim_1 = 0.25; ylim_1 = 0.1;  zlim_1 = 1.2;
            xlim_2 = 0.25; ylim_2 = 0.1;  zlim_2 = 1.2;
            obj_vo_scale = 3.0
            ego_vo_scale = 0.015

        ### CS ###
        if 'cityscapes' in args.data:
            xlim_1 = 0.1;  ylim_1 = 0.1;  zlim_1 = 0.4;
            xlim_2 = 0.12;  ylim_2 = 0.06;  zlim_2 = 0.60;
            obj_vo_scale = 3.0
            ego_vo_scale = 0.005


        ego_init_o = np.array([0,0,0,1]).reshape(4,1)
        ego_init_x = np.array([ego_vo_scale*1,0,0,1]).reshape(4,1)
        ego_init_y = np.array([0,ego_vo_scale*1,0,1]).reshape(4,1)
        ego_init_z = np.array([0,0,ego_vo_scale*1,1]).reshape(4,1)
        egoOs = np.array([mat @ ego_init_o for mat in ego_global_mats])[:,:3,0]
        egoXs = np.array([mat @ ego_init_x for mat in ego_global_mats])[:,:3,0]
        egoYs = np.array([mat @ ego_init_y for mat in ego_global_mats])[:,:3,0]
        egoZs = np.array([mat @ ego_init_z for mat in ego_global_mats])[:,:3,0]

        bbox_y = dict(boxstyle='round', facecolor='yellow', alpha=0.5)
        bbox_c = dict(boxstyle='round', facecolor='coral', alpha=0.5)
        bbox_m = dict(boxstyle='round', facecolor='magenta', alpha=0.5)
        bbox_l = dict(boxstyle='round', facecolor='lime', alpha=0.5)
        bbox_w = dict(boxstyle='round', facecolor='white', alpha=0.5)
        bbox_b = dict(boxstyle='round', facecolor='deepskyblue', alpha=0.5)
        # pdb.set_trace()

        sq = 0; bb = 0;
        r2t_objs_coords = pixel2cam(r2t_obj_sc_depths[0][:,0], intrinsics.inverse().repeat(num_inst,1,1))
        rtt_objs_coords = pixel2cam(rtt_obj_sc_depths[0][:,0], intrinsics.inverse().repeat(num_inst,1,1))
        tgt_objs_coords = pixel2cam(tgt_obj_depth[:,0], intrinsics.inverse().repeat(num_inst,1,1))
        r2t_obj_3d_locs = []
        rtt_obj_3d_locs = []
        tgt_obj_3d_locs = []
        for r2t_obj_coords in r2t_objs_coords: r2t_obj_3d_locs.append(torch.cat([coords[coords!=0].mean().unsqueeze(0) for coords in r2t_obj_coords]))
        for rtt_obj_coords in rtt_objs_coords: rtt_obj_3d_locs.append(torch.cat([coords[coords!=0].mean().unsqueeze(0) for coords in rtt_obj_coords]))
        for tgt_obj_coords in tgt_objs_coords: tgt_obj_3d_locs.append(torch.cat([coords[coords!=0].mean().unsqueeze(0) for coords in tgt_obj_coords]))
        for obj_loc in tgt_obj_3d_locs: objOs.append( (ego_global_mat @ np.concatenate([obj_loc.detach().cpu().numpy(), np.array([1])]).reshape(4,1))[:3].squeeze() );
        objHs_pred, objHs_comp = [], []
        for ii in range(len(obj_pose_inv)): objHs_pred.append( (ego_global_mat @ np.concatenate([tgt_obj_3d_locs[ii].detach().cpu().numpy(), np.array([1])]).reshape(4,1))[:3].squeeze() + obj_vo_scale*obj_pose_inv[ii].detach().cpu().numpy()[:3] )
        for ii in range(len(obj_pose_inv)): objHs_comp.append( (ego_global_mat @ np.concatenate([tgt_obj_3d_locs[ii].detach().cpu().numpy(), np.array([1])]).reshape(4,1))[:3].squeeze() - obj_vo_scale*tr_fwd[0][ii].detach().cpu().numpy() )
        for pred, comp in zip(objHs_pred, objHs_comp): objHs.append( (pred + comp) / 2 )
        r2t_obj_3d_loc = torch.stack(r2t_obj_3d_locs).unsqueeze(-1).unsqueeze(-1)
        r2t_obj_homo, _ = cam2homo(r2t_obj_3d_loc, intrinsics.repeat(num_inst,1,1), torch.zeros([1,3,1]).cuda())
        r2t_obj_tail = r2t_obj_homo.reshape(num_inst,2).detach().cpu().numpy()
        r2t_obj_trans = -obj_pose[:,:3]
        r2t_obj_trans_gt = -tr_fwd[0]
        r2t_obj_3d_loc_tr = r2t_obj_3d_loc.reshape(num_inst,3) + r2t_obj_trans
        r2t_obj_3d_loc_tr_gt = r2t_obj_3d_loc.reshape(num_inst,3) + r2t_obj_trans_gt
        r2t_obj_homo_tr, _ = cam2homo(r2t_obj_3d_loc_tr.unsqueeze(-1).unsqueeze(-1), intrinsics.repeat(num_inst,1,1), torch.zeros([1,3,1]).cuda())
        r2t_obj_homo_tr_gt, _ = cam2homo(r2t_obj_3d_loc_tr_gt.unsqueeze(-1).unsqueeze(-1), intrinsics.repeat(num_inst,1,1), torch.zeros([1,3,1]).cuda())
        r2t_obj_head = r2t_obj_homo_tr.reshape(num_inst,2).detach().cpu().numpy()
        r2t_obj_head_gt = r2t_obj_homo_tr_gt.reshape(num_inst,2).detach().cpu().numpy()
        arr_scale = 1.5
        tgt = (tgt_img[bb%args.batch_size]*0.5+0.5).detach().cpu().numpy().transpose(1,2,0)
        tgt_inst = 1 - tgt_bg_mask[bb].repeat(3,1,1).detach().cpu().numpy().transpose(1,2,0)
        tgt_masked = (tgt + 0.2 * tgt_inst).clip(max=1.0)
        ref = (ref_img[bb%args.batch_size]*0.5+0.5).detach().cpu().numpy().transpose(1,2,0)
        ref_inst = 1 - ref_bg_mask[bb].repeat(3,1,1).detach().cpu().numpy().transpose(1,2,0)
        ref_masked = (ref + 0.2 * ref_inst).clip(max=1.0)
        d_tgt = 1/tgt_depth.detach().cpu()[bb%args.batch_size,0]
        d_ref = 1/ref_depth.detach().cpu()[bb%args.batch_size,0]
        r2t_obj = (r2t_obj_imgs[0].sum(dim=0) * 0.5 + 0.5).detach().cpu().numpy().transpose(1,2,0) if num_inst != 0 else np.zeros([256,832,3])
        tgt_obj = (tgt_obj_img.sum(dim=0) * 0.5 + 0.5).detach().cpu().numpy().transpose(1,2,0) if num_inst != 0 else np.zeros([256,832,3])
        i_w_bg = (IMDDs[sq][0] * 0.5 + 0.5)[bb].detach().cpu().numpy().transpose(1,2,0)
        i_w_obj = (obj_IMDDs[sq][0] * 0.5 + 0.5)[bb].detach().cpu().numpy().transpose(1,2,0)
        i_w = ((IMDDs[sq][0] + obj_IMDDs[sq][0]) * 0.5 + 0.5)[bb].detach().cpu().numpy().transpose(1,2,0)
        m_w = obj_IMDDs[sq][1][0].repeat(3,1,1).detach().cpu().numpy().transpose(1,2,0)
        i_w_masked = i_w + 0.2 * m_w
        d_diff = ( ((IMDDs[sq][3] + obj_IMDDs[sq][3]) - (IMDDs[sq][2] + obj_IMDDs[sq][2])).abs() / ((IMDDs[sq][3] + obj_IMDDs[sq][3]) + (IMDDs[sq][2] + obj_IMDDs[sq][2])).abs().clamp(min=1e-3) ).clamp(0,1)[bb,0].detach().cpu()
        d_diff_ego = ( (r2t_ego_projected_depth - r2t_ego_computed_depth).abs() / (r2t_ego_projected_depth + r2t_ego_computed_depth).abs().clamp(min=1e-3) ).clamp(0,1)[bb,0].detach().cpu() * (IMDDs[sq][1] + obj_IMDDs[sq][1])[bb,0].detach().cpu()
        occ = 1.5 * d_diff.unsqueeze(-1).repeat(1,1,3).numpy()
        occ[:,:,2] = 0
        occ[occ<0.1] = 0
        i_w_occ = (i_w_masked + occ).clip(max=1.0)
        tgt_diff = np.abs(i_w-tgt).mean(axis=2)
        th = 5; samp = 20;
        r2t_obj_coords = r2t_objs_coords.sum(dim=0, keepdim=True)
        rtt_obj_coords = rtt_objs_coords.sum(dim=0, keepdim=True)
        tgt_obj_coords = tgt_objs_coords.sum(dim=0, keepdim=True)
        r2t_filt = np.abs(stats.zscore( r2t_obj_coords[bb,2].view(-1)[r2t_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy() )) < th
        rtt_filt = np.abs(stats.zscore( rtt_obj_coords[bb,2].view(-1)[rtt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy() )) < th
        tgt_filt = np.abs(stats.zscore( tgt_obj_coords[bb,2].view(-1)[tgt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy() )) < th
        npts_r2t = int(r2t_filt.sum())
        npts_rtt = int(rtt_filt.sum())
        npts_tgt = int(tgt_filt.sum())
        X_r2t = r2t_obj_coords[bb,0].view(-1)[r2t_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[r2t_filt][range(0,npts_r2t,samp)]
        Y_r2t = r2t_obj_coords[bb,1].view(-1)[r2t_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[r2t_filt][range(0,npts_r2t,samp)]
        Z_r2t = r2t_obj_coords[bb,2].view(-1)[r2t_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[r2t_filt][range(0,npts_r2t,samp)]
        C_r2t = r2t_obj_imgs[0].sum(dim=0).view(3,-1)[:,r2t_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[:,r2t_filt][:,range(0,npts_r2t,samp)] * 0.5 + 0.5
        C_r2t[0] = 1; C_r2t[1] = 0; C_r2t[2] = 0;
        X_rtt = rtt_obj_coords[bb,0].view(-1)[rtt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[rtt_filt][range(0,npts_rtt,samp)]
        Y_rtt = rtt_obj_coords[bb,1].view(-1)[rtt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[rtt_filt][range(0,npts_rtt,samp)]
        Z_rtt = rtt_obj_coords[bb,2].view(-1)[rtt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[rtt_filt][range(0,npts_rtt,samp)]
        C_rtt = rtt_obj_imgs[0].sum(dim=0).view(3,-1)[:,rtt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[:,rtt_filt][:,range(0,npts_rtt,samp)] * 0.5 + 0.5
        C_rtt[0] = 1; C_rtt[1] = 1; C_rtt[2] = 0;
        X_tgt = tgt_obj_coords[bb,0].view(-1)[tgt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[tgt_filt][range(0,npts_tgt,samp)]
        Y_tgt = tgt_obj_coords[bb,1].view(-1)[tgt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[tgt_filt][range(0,npts_tgt,samp)]
        Z_tgt = tgt_obj_coords[bb,2].view(-1)[tgt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[tgt_filt][range(0,npts_tgt,samp)]
        C_tgt = tgt_obj_img.sum(dim=0).view(3,-1)[:,tgt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[:,tgt_filt][:,range(0,npts_tgt,samp)] * 0.5 + 0.5
        C_tgt[0] = 0; C_tgt[1] = 0; C_tgt[2] = 1;
        XYZ_global_tgt = np.expand_dims(ego_global_mat, axis=0).repeat(X_tgt.shape[0],axis=0) @ np.expand_dims(np.stack([X_tgt, Y_tgt, Z_tgt, np.ones([X_tgt.shape[0]])]).transpose(1,0), axis=-1)
        C_global_tgt = (tgt_obj_img.sum(dim=0).view(3,-1)[:,tgt_obj_coords[bb].mean(dim=0).view(-1)!=0].detach().cpu().numpy()[:,tgt_filt][:,range(0,npts_tgt,samp)] * 0.5 + 0.5).clip(min=0.0, max=1.0)

        plt.close('all')
        fig = plt.figure(1, figsize=(1920/100, 1080/100), dpi=100)   # figsize=(23, 13)
        gs = GridSpec(nrows=5, ncols=6)
        text_xy = [7, -16]
        text_fd = {'family': 'sans', 'size': 13, 'color': 'black', 'style': 'italic'}
        fig.add_subplot(gs[0, 0:2])
        plt.imshow(ref_masked, vmax=1); plt.text(text_xy[0], text_xy[1], "$I_{t}$", fontdict=text_fd);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4) 
        if not args.save_fig: plt.grid(linestyle=':', linewidth=0.4);
        plt.text(55, -29, "Scene: {}, Iter: {}".format(args.data, i), fontsize=6.5);
        plt.text(55, -9, "Model: {}".format(args.pretrained_disp), fontsize=6.5);
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);
        fig.add_subplot(gs[0, 2:4])
        plt.imshow(d_ref, cmap='turbo', vmax=14); plt.text(text_xy[0], text_xy[1], "$D_{t}$", fontdict=text_fd);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4) 
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);
        fig.add_subplot(gs[1, 0:2])
        plt.imshow(tgt_masked, vmax=1); plt.text(text_xy[0], text_xy[1], "$I_{t+1}$", fontdict=text_fd);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4) 
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);
        fig.add_subplot(gs[1, 2:4])
        plt.imshow(d_tgt, cmap='turbo', vmax=14); plt.text(text_xy[0], text_xy[1], "$D_{t+1}$", fontdict=text_fd);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4) 
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);
        fig.add_subplot(gs[2, 0:2])
        plt.imshow(r2t_obj, vmax=1); plt.text(text_xy[0], text_xy[1], "Ego-warped objects with motion", fontdict=text_fd, size=10);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4)
        plt.text(130, 250, "*ego speed {:0.4f},  6-DoF {}".format(float(ego_pose[0,:3].pow(2).sum().sqrt()), ego_pose[0].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_b, ha='left', va='bottom');
        if num_inst > 0: plt.text(7, 7,  "Obj-1: {:0.4f} {}".format(float(obj_pose[0,:3].pow(2).sum().sqrt()), obj_pose[0][:3].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_m, ha='left', va='top');
        if num_inst > 0 and not args.save_fig: plt.text(330, 7, "#1: {:0.4f} {}".format(float(tr_fwd[0].pow(2).sum(dim=1).sqrt()[0]), tr_fwd[0][0].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_c, ha='left', va='top');
        if num_inst > 1: plt.text(7, 31, "Obj-2: {:0.4f} {}".format(float(obj_pose[1,:3].pow(2).sum().sqrt()), obj_pose[1][:3].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_m, ha='left', va='top');
        if num_inst > 1 and not args.save_fig: plt.text(330, 31, "#2: {:0.4f} {}".format(float(tr_fwd[0].pow(2).sum(dim=1).sqrt()[1]), tr_fwd[0][1].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_c, ha='left', va='top');
        if num_inst > 2: plt.text(7, 55, "Obj-3: {:0.4f} {}".format(float(obj_pose[2,:3].pow(2).sum().sqrt()), obj_pose[2][:3].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_m, ha='left', va='top');
        if num_inst > 2 and not args.save_fig: plt.text(330, 55, "#3: {:0.4f} {}".format(float(tr_fwd[0].pow(2).sum(dim=1).sqrt()[2]), tr_fwd[0][2].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_c, ha='left', va='top');
        if num_inst > 3: plt.text(7, 79, "Obj-4: {:0.4f} {}".format(float(obj_pose[3,:3].pow(2).sum().sqrt()), obj_pose[3][:3].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_m, ha='left', va='top');
        if num_inst > 3 and not args.save_fig: plt.text(330, 79, "#4: {:0.4f} {}".format(float(tr_fwd[0].pow(2).sum(dim=1).sqrt()[3]), tr_fwd[0][3].detach().cpu().numpy().round(4)), fontsize=7, bbox=bbox_c, ha='left', va='top');
        if num_inst > 0 and not args.save_fig: plt.arrow(r2t_obj_tail[0,0], r2t_obj_tail[0,1], arr_scale*(-r2t_obj_tail[0,0]+r2t_obj_head_gt[0,0]), arr_scale*(-r2t_obj_tail[0,1]+r2t_obj_head_gt[0,1]), width=2, head_width=9, head_length=9, color='red', alpha=1); 
        if num_inst > 0: plt.arrow(r2t_obj_tail[0,0], r2t_obj_tail[0,1], arr_scale*(-r2t_obj_tail[0,0]+r2t_obj_head[0,0]), arr_scale*(-r2t_obj_tail[0,1]+r2t_obj_head[0,1]), width=3, head_width=10, head_length=9, color='magenta', alpha=1); 
        if num_inst > 0: plt.text(r2t_obj_tail[0,0]-30, r2t_obj_tail[0,1]+25, "1: {:0.4f}".format(float(obj_pose[0,:3].pow(2).sum().sqrt())), fontsize=7, bbox=bbox_l);
        if num_inst > 1 and not args.save_fig: plt.arrow(r2t_obj_tail[1,0], r2t_obj_tail[1,1], arr_scale*(-r2t_obj_tail[1,0]+r2t_obj_head_gt[1,0]), arr_scale*(-r2t_obj_tail[1,1]+r2t_obj_head_gt[1,1]), width=2, head_width=9, head_length=9, color='red', alpha=1); 
        if num_inst > 1: plt.arrow(r2t_obj_tail[1,0], r2t_obj_tail[1,1], arr_scale*(-r2t_obj_tail[1,0]+r2t_obj_head[1,0]), arr_scale*(-r2t_obj_tail[1,1]+r2t_obj_head[1,1]), width=3, head_width=10, head_length=9, color='magenta', alpha=1); 
        if num_inst > 1: plt.text(r2t_obj_tail[1,0]-30, r2t_obj_tail[1,1]+25, "2: {:0.4f}".format(float(obj_pose[1,:3].pow(2).sum().sqrt())), fontsize=7, bbox=bbox_l);
        if num_inst > 2 and not args.save_fig: plt.arrow(r2t_obj_tail[2,0], r2t_obj_tail[2,1], arr_scale*(-r2t_obj_tail[2,0]+r2t_obj_head_gt[2,0]), arr_scale*(-r2t_obj_tail[2,1]+r2t_obj_head_gt[2,1]), width=2, head_width=9, head_length=9, color='red', alpha=1); 
        if num_inst > 2: plt.arrow(r2t_obj_tail[2,0], r2t_obj_tail[2,1], arr_scale*(-r2t_obj_tail[2,0]+r2t_obj_head[2,0]), arr_scale*(-r2t_obj_tail[2,1]+r2t_obj_head[2,1]), width=3, head_width=10, head_length=9, color='magenta', alpha=1); 
        if num_inst > 2: plt.text(r2t_obj_tail[2,0]-30, r2t_obj_tail[2,1]+25, "3: {:0.4f}".format(float(obj_pose[2,:3].pow(2).sum().sqrt())), fontsize=7, bbox=bbox_l);
        if num_inst > 3 and not args.save_fig: plt.arrow(r2t_obj_tail[3,0], r2t_obj_tail[3,1], arr_scale*(-r2t_obj_tail[3,0]+r2t_obj_head_gt[3,0]), arr_scale*(-r2t_obj_tail[3,1]+r2t_obj_head_gt[3,1]), width=2, head_width=9, head_length=9, color='red', alpha=1); 
        if num_inst > 3: plt.arrow(r2t_obj_tail[3,0], r2t_obj_tail[3,1], arr_scale*(-r2t_obj_tail[3,0]+r2t_obj_head[3,0]), arr_scale*(-r2t_obj_tail[3,1]+r2t_obj_head[3,1]), width=3, head_width=10, head_length=9, color='magenta', alpha=1); 
        if num_inst > 3: plt.text(r2t_obj_tail[3,0]-30, r2t_obj_tail[3,1]+25, "4: {:0.4f}".format(float(obj_pose[3,:3].pow(2).sum().sqrt())), fontsize=7, bbox=bbox_l);
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);
        fig.add_subplot(gs[3, 0:2])
        plt.imshow(i_w_occ, vmax=1); plt.text(text_xy[0], text_xy[1], "Final synthesis (yellow: dis/occlusion)", fontdict=text_fd, size=10);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4) 
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);
        fig.add_subplot(gs[4, 0:2])
        plt.imshow(tgt_diff, cmap='bone', vmax=0.5); plt.text(text_xy[0], text_xy[1], "$I_{diff}$", fontdict=text_fd);
        plt.xticks([]) and plt.yticks([]) if args.save_fig else plt.grid(linestyle=':', linewidth=0.4) 
        plt.xlim(0, 832-1); plt.ylim(256-1, 0);

        ### 3d plot 1: cam-coord ###
        ax1 = fig.add_subplot(gs[3:5, 2:4], projection='3d')
        ax1_axfont = {'family': 'sans', 'size': 12, 'weight': 'heavy', 'style': 'italic', 'color': 'gray'}
        ax1_titlefont = {'family': 'sans', 'size': 12, 'color': 'black', 'ha': 'center', 'va': 'bottom', 'linespacing': 2}
        ax1_annotfont = {'family': 'sans', 'size': 8, 'color': 'black', 'ha': 'center', 'va': 'center'}
        ax1.scatter(X_r2t, Y_r2t, Z_r2t, c=C_r2t.transpose(1,0), s=1, alpha=0.4)
        ax1.scatter(X_rtt, Y_rtt, Z_rtt, c=C_rtt.transpose(1,0), s=1, alpha=0.6)
        ax1.scatter(X_tgt, Y_tgt, Z_tgt, c=C_tgt.transpose(1,0), s=1, alpha=0.4)
        ax1.set_xlabel('X', fontdict=ax1_axfont); ax1.set_zlabel('Z', fontdict=ax1_axfont);
        ax1.axes.yaxis.set_ticklabels([])
        ax1.set_xlim(-xlim_1, xlim_1)
        ax1.set_ylim(-ylim_1, ylim_1)
        ax1.set_zlim(0,       zlim_1)
        ax1.text(0, 0, zlim_1*1.20, "[Top-view] Objects in {$t+1$} frame on camera coordinate\n(red: ego-warped $t$→$t+1$, yellow: final-warped $t$→$t+1$, blue: $t+1$)", fontdict=ax1_titlefont)
        if num_inst > 0: ax1.text(-xlim_1/2, 0, zlim_1*1.10, "1: XYZ {}".format(r2t_obj_3d_locs[0].detach().cpu().numpy().round(4)), fontdict=ax1_annotfont, bbox=bbox_c);
        if num_inst > 0: ax1.text(-xlim_1/2, 0, zlim_1*1.05, "1: XYZ {}".format(rtt_obj_3d_locs[0].detach().cpu().numpy().round(4)), fontdict=ax1_annotfont, bbox=bbox_y);
        if num_inst > 0: ax1.text(-xlim_1/2, 0, zlim_1*1.00, "1: XYZ {}".format(tgt_obj_3d_locs[0].detach().cpu().numpy().round(4)), fontdict=ax1_annotfont, bbox=bbox_b);
        if num_inst > 1: ax1.text(+xlim_1/2, 0, zlim_1*1.10, "2: XYZ {}".format(r2t_obj_3d_locs[1].detach().cpu().numpy().round(4)), fontdict=ax1_annotfont, bbox=bbox_c);
        if num_inst > 1: ax1.text(+xlim_1/2, 0, zlim_1*1.05, "2: XYZ {}".format(rtt_obj_3d_locs[1].detach().cpu().numpy().round(4)), fontdict=ax1_annotfont, bbox=bbox_y);
        if num_inst > 1: ax1.text(+xlim_1/2, 0, zlim_1*1.00, "2: XYZ {}".format(tgt_obj_3d_locs[1].detach().cpu().numpy().round(4)), fontdict=ax1_annotfont, bbox=bbox_b);
        ax1.get_proj = lambda: np.dot(Axes3D.get_proj(ax1), np.diag([2, 1, 2, 1]))
        ax1.view_init(elev=0, azim=-90)

        ### 3d plot 2: world-coord ###
        ax2 = fig.add_subplot(gs[1:5, 4:6], projection='3d')
        ax2_axfont = {'family': 'sans', 'size': 14, 'weight': 'heavy', 'style': 'italic', 'color': 'gray'}
        ax2_titlefont = {'family': 'sans', 'size': 12, 'color': 'black', 'ha': 'center', 'va': 'center'}
        ax2.scatter(XYZ_global_tgt[:,0,0], XYZ_global_tgt[:,1,0], XYZ_global_tgt[:,2,0], c=C_global_tgt.transpose(1,0), s=6, zorder=vidx+1)
        for ii in range(len(egoOs)-1): dR.drawVector(ax2, egoOs[ii], egoOs[ii+1], mutation_scale=1, alpha=0.5, arrowstyle='-', lineStyle=':', lineWidth=1, lineColor='k', zorder=vidx+2)
        for ii in range(len(egoOs)):
            if ii >= len(egoOs) - 1:
                dR.drawPointWithAxis(ax2, egoOs[ii], egoXs[ii]-egoOs[ii], egoYs[ii]-egoOs[ii], egoZs[ii]-egoOs[ii], mutation_scale=3, alpha=1.0, arrowstyle='-', lineWidth=2.0, vectorLength=2, zorder=vidx+4)
                ax2.text(egoOs[ii,0]-xlim_2/4, egoOs[ii,1], egoOs[ii,2], "{:0.4f}".format( np.linalg.norm(egoOs[ii]-egoOs[ii-1]).round(4) ), fontsize=8, bbox=bbox_w, zorder=vidx+5);
            else:
                dR.drawPointWithAxis(ax2, egoOs[ii], egoXs[ii]-egoOs[ii], egoYs[ii]-egoOs[ii], egoZs[ii]-egoOs[ii], mutation_scale=1, alpha=0.4, arrowstyle='-', lineWidth=1.5, zorder=vidx+3)
        for ii in range(len(objOs)):
            if ii >= len(objOs) - len(obj_pose):
                if np.linalg.norm(objOs[ii]-objHs[ii], 2) < 0.1 and np.linalg.norm(objOs[ii]-objHs[ii], 2) > 0.002:
                    dR.drawVector(ax2, objOs[ii], objHs[ii], mutation_scale=20, arrowstyle='fancy', pointEnable=False, lineWidth=0.5, faceColor=colors[objIDs_flatten[ii]], edgeColor='k', zorder=vidx+20);
                    ax2.text(objHs[ii][0]-xlim_2/4, objHs[ii][1], objHs[ii][2], "{:0.4f}".format( (np.linalg.norm(objOs[ii]-objHs[ii])/obj_vo_scale).round(4) ), fontsize=8, bbox=bbox_w, zorder=vidx+30);
            else:
                if np.linalg.norm(objOs[ii]-objHs[ii], 2) < 0.1 and np.linalg.norm(objOs[ii]-objHs[ii], 2) > 0.002:
                    dR.drawVector(ax2, objOs[ii], objHs[ii], mutation_scale=20, alpha=0.3, arrowstyle='fancy', pointEnable=False, lineWidth=0.5, faceColor=colors[objIDs_flatten[ii]], edgeColor='k', zorder=vidx+10);
        ax2.text(-xlim_2*0.9, 0, zlim_2*0.02, "Speed", fontsize=9, style='italic', bbox=bbox_w, zorder=vidx+25);
        ax2.set_xlabel('X', fontdict=ax2_axfont); ax2.set_zlabel('Z', fontdict=ax2_axfont);
        ax2.axes.yaxis.set_ticklabels([])
        ax2.set_xlim(-xlim_2, xlim_2)
        ax2.set_ylim(-ylim_2, ylim_2)
        ax2.set_zlim(0,       zlim_2)
        ax2.get_proj = lambda: np.dot(Axes3D.get_proj(ax2), np.diag([1.2, 0.6, 2.4, 1]))
        ax2.xaxis._axinfo['juggled'] = (2,0,1)
        ax2.yaxis._axinfo['juggled'] = (2,1,0)
        ax2.zaxis._axinfo['juggled'] = (0,2,1)
        elv = 2; azm = 1;
        if 'cityscapes' in args.data:
            ax2.view_init(elev=-0.01-elv*vidx, azim=-90+azm*vidx)
            # ax2.view_init(elev=-0.01-elv*vidx, azim=-90.01-azm*vidx)
        else:
            if vidx <= 10: ax2.view_init(elev=-0.01-elv*vidx, azim=-90+azm*vidx)                                                # elev: -0~-40, azim: -90~-70
            if 10 < vidx and vidx <= 20: ax2.view_init(elev= -0.01-elv*10 + elv*(vidx-10), azim= -90+azm*10 - azm*(vidx-10))    # elev: -40~-0, azim: -70~-90
            if 20 < vidx and vidx <= 30: ax2.view_init(elev= -0.01 - elv*(vidx-20), azim= -90 - azm*(vidx-20))                  # elev: -0~-40, azim: -90~-110
            if 30 < vidx and vidx <= 40: ax2.view_init(elev= -0.01-elv*10 + elv*(vidx-30), azim= -90-azm*10 + azm*(vidx-30))    # elev: -40~-0, azim: -110~-90
            if 40 < vidx and vidx <= 50: ax2.view_init(elev= -0.01 - elv*(vidx-40), azim= -90 + azm*(vidx-40))                  # elev: -0~-40, azim: -90~-70
            if 50 < vidx and vidx <= 60: ax2.view_init(elev= -0.01-elv*10 + elv*(vidx-50), azim= -90+azm*10 - azm*(vidx-50))    # elev: -40~-0, azim: -70~-90
        ax2.dist = 10 + 0.1*vidx
        ax2_title = fig.add_subplot(gs[4, 4:6])
        ax2_title.axis('off')
        ax2_title.text(0.5, 0.5, '[Top-view] Unified visual odometry on world coordinate', fontdict=ax2_titlefont)
        plt.tight_layout();


        if args.save_fig:
            print('>> Saving image #{:02d}'.format(vidx))
            # plt.savefig('{:}/{:}_{:04d}.png'.format(args.save_path, Path(args.data).basename(), i), dpi=100)
            plt.savefig('{:}/{:04d}.png'.format(args.save_path, vidx), dpi=100)
            plt.close('all')
        else:
            plt.ion(); plt.show();
            print('>> Type \'c\' to continue')
            pdb.set_trace()
            plt.close('all')

        vidx += 1

    return 0
Ejemplo n.º 6
0
def MFP3D_bmax_time_all_comp (fig, MFP_stacked_thres, MFP_thres_poly, MFP_poly_bmax, depths, xx, yy, xmin, ymin, xmax, ymax, mofettes, comp_colors):
    """Draw Bmax-vs-time panels next to a stacked 3D overview of thresholded MFP maps.

    The figure is split into a 5 x 7 grid:
      * columns 0-1: one 3D axes showing, for every component and depth, a white
        base rectangle, the thresholded map (via PlotMFPLoczCountSubplot_3D), the
        outline of non-zero areas, polygon letters, mofette stars and a north arrow;
      * columns 2-5: one Bmax-vs-time scatter panel per polygon signature
        (rows = depth 100..500 m, columns = letters A..D).

    Parameters
    ----------
    fig : matplotlib Figure to draw into.
    MFP_stacked_thres : dict component -> dict depth -> 2D array; zero cells are
        treated as empty.
    MFP_thres_poly : dict component -> dict depth -> dict signature -> polygon
        object exposing ``centroid.x`` / ``centroid.y`` (presumably shapely — TODO confirm).
    MFP_poly_bmax : dict signature (e.g. "300B") -> 2D array; column 0 is plotted
        on the x (time) axis, column 1 is Bmax.
    depths : iterable of depth values used for the z limits and z ticks.
    xx, yy : grid coordinates passed to ``contour`` and the surface helper.
    xmin, xmax : range of the x/y tick labels (step 100). ``ymin``/``ymax`` are
        accepted but not used.
    mofettes : array whose first three rows hold (x, y) positions drawn as stars.
    comp_colors : dict component -> colour; must contain "Z", "N" and "E"
        (used for the legend).
    """
    grid = fig.add_gridspec(nrows = 5, ncols = 7)

    ####################### Bmax vs time Plots

    # Map each polygon signature to its grid cell.
    # Rows: depth 100 m -> row 0, 200 m -> row 1, ..., 500 m -> row 4.
    # Columns: letter A -> column 2, B -> 3, C -> 4, D -> 5
    # (columns 0 and 1 are reserved for the 3D plot).
    # Example: "300B" is drawn in row 2, column 3.
    plot_axes = {}
    for row_idx, depth in enumerate(range(100, 600, 100)):
        for column_idx, letter in enumerate("ABCD"):
            plot_axes[str(depth) + letter] = {"row": row_idx, "col": column_idx + 2}

    # Fixed Bmax axis range shared by all panels.
    bmax_top = 79
    bmax_bottom = 69

    colormap = plt.cm.gist_stern_r
    # Panels on the outer edges of the panel block that show tick labels.
    labelbottom = ["100C", "400A", "500B", "500C", "500D"]
    labelleft = ["100A", "200A", "300A", "400A", "500B"]
    # One clock-hour label per x tick at 0, 60, ..., 540.
    hour_labels = ["22", "23", "0", "1", "2", "3", "4", "5", "6", "7"]

    ###### Plot Bmax
    for poly_sig, poly_bmax in MFP_poly_bmax.items():
        ax = plt.subplot (grid[plot_axes[poly_sig]["row"], plot_axes[poly_sig]["col"]])

        x = poly_bmax[:, 0]
        y = poly_bmax[:, 1]
        # Sort ascending by Bmax, so the highest values are drawn last (on top).
        # NOTE(review): an earlier comment claimed the lowest values are drawn
        # last; confirm which order is intended.
        order = y.argsort()
        x_, y_ = x[order], y[order]
        normalize = matplotlib.colors.Normalize(vmin = y.min(), vmax = y.max())
        # 'none' is the documented value for "no marker edge".
        ax.scatter(x_, y_, c = y_, s = 80, cmap = colormap, norm = normalize, edgecolor = 'none')

        ###### Layout
        # general tick settings
        ax.tick_params(axis = 'both', which = 'major', labelsize = 16, length = 5, width = 1)

        ax.set_xlim(left = 0, right = 540)
        ax.set_ylim(top = bmax_top, bottom = bmax_bottom)

        ax.set_xticks(np.arange(0, 540 + 60, 60))
        ax.set_yticks(np.arange(bmax_bottom, bmax_top + 5, 5))

        ax.set_axisbelow(True)
        ax.xaxis.grid(True, lw = 2, alpha = 0.5, zorder = -5)

        # Anchor the outermost y tick labels so they extend into the panel.
        yticklabels = ax.get_yticklabels()
        yticklabels[0].set_va("bottom")
        yticklabels[-1].set_va("top")

        ax.text(0, 1.02, poly_sig, fontsize = 19, weight = "bold", transform = ax.transAxes)

        # axis labels are drawn on a single panel only
        if poly_sig == "500B":
            ax.set_ylabel("Bmax [dB]", fontsize = 18, color = "black", weight = "bold")
            ax.set_xlabel ("Uhrzeit", fontsize = 20, weight = "bold")

        # ticklabels (x ticks were fixed above, so there are exactly 10 of them)
        if poly_sig in labelbottom:
            ax.tick_params(labelbottom = True)
            ax.set_xticklabels(hour_labels)
        else:
            ax.tick_params(labelbottom = False)

        ax.tick_params(labelleft = poly_sig in labelleft)

    ####################### 3D Plots
    ax_3d = plt.subplot (grid[0:, 0:2], projection='3d')

    # Scale the rendered box: x and y by 0.6, z by 0.9.
    x_scale, y_scale, z_scale = 0.6, 0.6, 0.9
    from mpl_toolkits.mplot3d.axes3d import Axes3D
    ax_3d.get_proj = lambda: np.dot(Axes3D.get_proj(ax_3d), np.diag([x_scale, y_scale, z_scale, 1]))

    # plot surface and area outlines for every component and depth
    for component in MFP_stacked_thres:
        # Largest value of this component over all depths; passed to the helper.
        comp_max = max([array.max() for array in MFP_stacked_thres[component].values()])
        for depth, array in MFP_stacked_thres[component].items():
            # add semi-transparent white rectangle as base for current depth
            rect = patches.Rectangle((0, 0), 500, 500, facecolor = (1, 1, 1, 0.5)) # white (1, 1, 1) with alpha 0.5
            ax_3d.add_patch(rect)
            art3d.pathpatch_2d_to_3d(rect, z = depth, zdir = "z")
            # Mark empty (zero) cells as NaN. This must be the float np.nan:
            # the string "nan" would turn the whole array into strings.
            array_ = np.where(array == 0, np.nan, array)
            PlotMFPLoczCountSubplot_3D (fig, ax_3d, array_, comp_max, xx, yy, component, depth, modus = [], unique_areas = [])
            # plot border of non-zero areas with color according to component
            outline = np.where(array != 0, 1, array)
            ax_3d.contour(xx, yy, outline, offset = depth, zdir = 'z', colors = comp_colors[component])

    # label every polygon with the last character of its signature at its depth
    for component in MFP_thres_poly:
        for depth in MFP_thres_poly[component]:
            for poly_signature, poly in MFP_thres_poly[component][depth].items():
                poly_txt = ax_3d.text(poly.centroid.x + 30, poly.centroid.y, depth, poly_signature[-1:], color = 'white', fontsize = 32, weight = "bold", transform = ax_3d.transData, zorder = 1000000)
                poly_txt.set_path_effects([PathEffects.Stroke(linewidth = 2, foreground = "k"), PathEffects.Normal()])

    # plot mofettes (first three rows of `mofettes`) as red stars at z = 0
    for row in range(3):
        ax_3d.text(mofettes[row, 0], mofettes[row, 1], 0, u'\u2605', color = 'red', fontsize = 40, transform = ax_3d.transData, zorder = 2000)

    # draw north arrow
    arrow3d(ax_3d, length = 150, width = 2, head = 0.3, headwidth = 3, offset = [-200, 200, 500], theta_x = 270, theta_z = 0, color = "black")
    ax_3d.text(-200, 300, 530, "N", size = 24, zorder = 1, color='k')

    # set axis limits
    ax_3d.set_zlim(min(depths),max(depths))
    ax_3d.set_xlim(0,500)
    ax_3d.set_ylim(0,500)

    # define distance titles to plots
    # NOTE: this changes the global rcParams, so later figures are affected too.
    rcParams['axes.titlepad'] = -10

    # invert z-axis so that larger depths are drawn lower
    ax_3d.invert_zaxis()

    # set x-, y- & z-axislabel
    ax_3d.set_xlabel('x [m]', labelpad = 0, fontsize = 20, verticalalignment = 'center')
    ax_3d.set_ylabel('y [m]', labelpad = 1, fontsize = 20, verticalalignment = 'center')
    ax_3d.set_zlabel('Tiefe [m]', labelpad = 28, fontsize = 24, rotation = 90)
    ax_3d.zaxis.set_rotate_label(False) # disable automatic axislabel rotation

    # define ticklabels (x and y share the same labels, built from xmin/xmax)
    xy_labels = np.arange(xmin, xmax+1, 100)
    xy_labels = xy_labels.astype(int)
    z_labels = depths
    # set ticks and ticklabels
    # x
    ax_3d.set_xticks(xy_labels) # create ticks
    ax_3d.set_xticklabels(xy_labels, verticalalignment = 'baseline', horizontalalignment = 'center') # label ticks
    ax_3d.tick_params(axis = 'x', which = 'major', pad = 2, labelsize = 20) # adjust ticklabel positions
    # y
    ax_3d.set_yticks(xy_labels)
    ax_3d.set_yticklabels(xy_labels, verticalalignment = 'bottom', horizontalalignment = 'right')
    ax_3d.tick_params(axis = 'y', which = 'major', pad = 2, labelsize = 20)
    # z
    ax_3d.set_zticks(z_labels)
    ax_3d.set_zticklabels(z_labels, verticalalignment = 'center')
    ax_3d.tick_params(axis = 'z', which = 'major', pad = 12, labelsize = 20)

    # rotate view
    ax_3d.view_init(elev = 10, azim = 225)

    # legend: one thick line per component colour
    line_z = Line2D([0,1],[0,1],  linewidth = 8, linestyle = '-', color = comp_colors["Z"])
    line_ns = Line2D([0,1],[0,1], linewidth = 8, linestyle = '-', color = comp_colors["N"])
    line_ew = Line2D([0,1],[0,1], linewidth = 8, linestyle = '-', color = comp_colors["E"])
    plt.figlegend((line_z, line_ns, line_ew), ('Z', 'N', 'E'), loc = "center left",
                  fontsize = 24, borderaxespad = 0.1, bbox_to_anchor = (0.23, 0.12))

    grid.update(wspace = 0.15)

    ax_3d.dist = 6
# Sort state_ci entries by their index value (element [1]), highest first.
# Element [0] of each entry is the state name used for the *_epi_df / *_demo_df globals.
state_ci.sort(key = lambda x: x[1], reverse=True)
ci_thresholds = [0, 48.5, 52.7, 100] # thresholds determine where to start new plot

# One 3D figure (plus one 2D figure) per index band [ci_thresholds[x], ci_thresholds[x+1]).
for x in range(3):
    # parameters for 3D plot
    fig = plt.figure(figsize=(12, 8))
    ax0 = plt.figure().gca()
    # Figure.gca() no longer accepts projection= (deprecated in matplotlib 3.4,
    # removed in 3.6); add_subplot() creates the same single 3D axes.
    ax = fig.add_subplot(projection='3d')
    ax.elev = 35
    ax.azim = 270
    ax.dist = 14
    x_scale = 1
    y_scale = 2.1
    z_scale = 1
    # Bind this iteration's axes and scale matrix as default arguments. A plain
    # closure would look up the module-level `ax` when the figure is drawn, so
    # every figure would be projected with the LAST iteration's axes.
    ax.get_proj = lambda ax=ax, scale=np.diag([x_scale, y_scale, z_scale, 1]): np.dot(Axes3D.get_proj(ax), scale)

    lower, upper = ci_thresholds[x], ci_thresholds[x + 1]
    # States (among the first len(state_code) entries) whose index lies in [lower, upper).
    selected = [state_ci[i][0] for i in range(len(state_code))
                if lower <= state_ci[i][1] < upper]
    # list of states for this plot, with the first 3 characters of each name dropped
    state_list = [state[3:] for state in selected]
    # plot states on above list; j is the state's position within this plot
    for j, state in enumerate(selected):
        plot_state_by_index(globals()["{0}_epi_df".format(state)], globals()["{0}_demo_df".format(state)],
                            state, ax0, ax, j, state_list, zlim=200)
Ejemplo n.º 8
0
def scale3d(ax, x_scale=1, y_scale=1, z_scale=1):
    """Stretch the rendered 3D box of *ax* along each axis.

    Replaces ``ax.get_proj`` with a callable that right-multiplies the default
    ``Axes3D.get_proj`` matrix by a diagonal scaling matrix.

    The three factors are divided by the largest diagonal entry, and the
    homogeneous 1.0 is included in that comparison. So if any factor exceeds 1,
    the largest factor becomes 1. If all factors are <= 1, they are used unchanged.

    Parameters
    ----------
    ax : mplot3d axes to modify in place.
    x_scale, y_scale, z_scale : relative stretch factors for each axis.
    """
    scale = np.diag([x_scale, y_scale, z_scale, 1.0])
    scale = scale * (1.0 / scale.max())
    scale[3, 3] = 1.0

    def short_proj():
        # Axes3D calls ``get_proj()`` with no arguments when rendering, so this
        # must be a callable rather than a precomputed matrix. The base
        # projection also depends on the current view, so recompute it each time.
        return np.dot(Axes3D.get_proj(ax), scale)

    ax.get_proj = short_proj