예제 #1
0
def generate_label_views(kzip_path,
                         ssd_version,
                         gt_type,
                         n_voting=40,
                         nb_views=2,
                         ws=(256, 128),
                         comp_window=8e3,
                         out_path=None,
                         verbose=False,
                         initial_run=None):
    """
    Render raw, label and index views of a ground-truth annotated SSV.

    The labels stored as node comments of the skeleton inside `kzip_path`
    are transferred to the mesh vertices of the SSV via a nearest-neighbor
    lookup, optionally smoothed with `bfs_smoothing`, and then rendered at
    locations sampled from the mesh.

    Parameters
    ----------
    kzip_path : str
        Path to the annotated skeleton k.zip. The SSV ID is parsed from the
        first run of digits that directly follows a ``/`` in this path.
    ssd_version : str
        Version of the SSD the `SuperSegmentationObject` is loaded from.
    gt_type :  str
        Either "axgt" or "spgt".
    n_voting : int
        Number of collected nodes during BFS for majority vote (label
        smoothing). Smoothing is skipped if this is not greater than 0.
    nb_views : int
        Forwarded to the rendering functions.
    ws: Tuple[int]
        Forwarded to the rendering functions.
    comp_window : float
        Forwarded to the rendering functions; ``comp_window / 6`` is used
        as the second argument of `generate_rendering_locs`.
    out_path : str
        If given, export mesh colored according to GT labels to
        ``<out_path>/sso_<id>_gtlabels.k.zip``.
    verbose : bool
        Print additional information
    initial_run : bool
        if True, will copy SSV from default SSD to the SSD with
        version=`ssd_version`. If None (default), a module-level variable
        named ``initial_run`` is used when one exists, otherwise False.

    Returns
    -------
    Tuple[np.array]
        raw, label and index views

    Raises
    ------
    IndexError
        If no SSV ID can be parsed from `kzip_path`.
    ValueError
        If the attribute dict of the SSV does not exist in the target SSD,
        or if the skeleton contains no labeled node.
    """
    assert gt_type in ["axgt",
                       "spgt"], "Currently only spine and axon GT is supported"
    if initial_run is None:
        # The flag used to be read from an undeclared name. Keep honoring a
        # module-level `initial_run` if one exists so that callers which do
        # not pass the argument behave as before, but do not crash with a
        # NameError when there is none.
        initial_run = globals().get("initial_run", False)
    n_labels = 5 if gt_type == "axgt" else 4
    palette = generate_palette(n_labels)
    sso_id = int(re.findall(r"/(\d+).", kzip_path)[0])
    sso = SuperSegmentationObject(sso_id, version=ssd_version)
    if initial_run:  # use default SSD version
        orig_sso = SuperSegmentationObject(sso_id)
        orig_sso.copy2dir(dest_dir=sso.ssv_dir, safe=False)
    if not sso.attr_dict_exists:
        msg = 'Attribute dict of original SSV was not copied successfully ' \
              'to target SSD.'
        raise ValueError(msg)
    sso.load_attr_dict()
    indices, vertices, normals = sso.mesh

    # the mesh stores vertices flat; bring them into one row per vertex
    vertices = vertices.reshape((-1, 3))

    # load skeleton
    skel = load_skeleton(kzip_path)
    if len(skel) == 1:
        skel = list(skel.values())[0]
    else:
        skel = skel["skeleton"]
    skel_nodes = list(skel.getNodes())

    node_coords = np.array(
        [n.getCoordinate() * sso.scaling for n in skel_nodes])
    # builtin `int` replaces the alias `np.int`, which was removed from NumPy
    node_labels = np.array(
        [str2intconverter(n.getComment(), gt_type) for n in skel_nodes],
        dtype=int)
    # nodes with label -1 are ignored
    is_labeled = node_labels != -1
    node_coords = node_coords[is_labeled]
    node_labels = node_labels[is_labeled]
    if len(node_labels) == 0:
        raise ValueError('No labeled skeleton node found in '
                         '"{}".'.format(kzip_path))

    # create KD tree from skeleton node coordinates
    tree = KDTree(node_coords)
    # transfer labels from skeleton to mesh
    _, ind = tree.query(vertices, k=1)
    vertex_labels = node_labels[ind]  # retrieving labels of vertices
    if n_voting > 0:
        vertex_labels = bfs_smoothing(vertices,
                                      vertex_labels,
                                      n_voting=n_voting)
    color_array = palette[vertex_labels].astype(np.float32) / 255.

    if out_path is not None:
        if gt_type == 'spgt':  #
            colors = [[0.6, 0.6, 0.6, 1], [0.9, 0.2, 0.2, 1],
                      [0.1, 0.1, 0.1, 1], [0.05, 0.6, 0.6, 1],
                      [0.9, 0.9, 0.9, 1]]
        else:  # dendrite, axon, soma, bouton, terminal, background
            colors = [[0.6, 0.6, 0.6, 1], [0.9, 0.2, 0.2, 1],
                      [0.1, 0.1, 0.1, 1], [0.05, 0.6, 0.6, 1],
                      [0.6, 0.05, 0.05, 1], [0.9, 0.9, 0.9, 1]]
        colors = (np.array(colors) * 255).astype(np.uint8)
        # TODO: check why only the first element is needed. Presumably
        #  `tree.query` returns a 2D index array (cf. `dist[:, 0]` below),
        #  which gives `vertex_labels` an extra axis of length 1 - confirm.
        color_array_mesh = colors[vertex_labels][:, 0]
        write_mesh2kzip("{}/sso_{}_gtlabels.k.zip".format(out_path, sso.id),
                        sso.mesh[0],
                        sso.mesh[1],
                        sso.mesh[2],
                        color_array_mesh,
                        ply_fname="gtlabels.ply")

    # Initializing mesh object with ground truth coloring
    mo = MeshObject("neuron", indices, vertices, color=color_array)

    # use downsampled locations for view locations, only if they are close to a
    # labeled skeleton node; 6 rendering locations per comp. window
    locs = generate_rendering_locs(vertices, comp_window / 6)
    dist, _ = tree.query(locs)
    locs = locs[dist[:, 0] < 2000]  # TODO add as parameter

    label_views, rot_mat = _render_mesh_coords(locs,
                                               mo,
                                               depth_map=False,
                                               return_rot_matrices=True,
                                               ws=ws,
                                               smooth_shade=False,
                                               nb_views=nb_views,
                                               comp_window=comp_window,
                                               verbose=verbose)
    # TODO: the 3 neglects the alpha channel, i.e. remapping labels bigger than 256**3 becomes
    #  invalid
    label_views = remap_rgb_labelviews(label_views[..., :3], palette)[:, None]
    # index and raw views reuse the rotation matrices of the label views
    index_views = render_sso_coords_index_views(sso,
                                                locs,
                                                rot_mat=rot_mat,
                                                verbose=verbose,
                                                nb_views=nb_views,
                                                ws=ws,
                                                comp_window=comp_window)
    raw_views = render_sso_coords(sso,
                                  locs,
                                  nb_views=nb_views,
                                  ws=ws,
                                  comp_window=comp_window,
                                  verbose=verbose,
                                  rot_mat=rot_mat)
    return raw_views, label_views, index_views
예제 #2
0
            args.append(pkl.load(f))
        except EOFError:
            break

# `scaling` entry of the global config, handed to every SSO created below
scaling = global_params.config['scaling']
# TODO: This could be chunked by loading `mesh_bb` and glia prob. prediction cache arrays
#  (might have to be created via `dataset_analysis`)
for cc in args:
    # every entry of `args` is a graph whose nodes are supervoxel IDs
    svixs = list(cc.nodes())
    # the smallest supervoxel ID serves as ID of the SSO
    cc_ix = np.min(svixs)
    sso = SuperSegmentationObject(
        cc_ix, version="gliaremoval", nb_cpus=1,
        working_dir=global_params.config.working_dir, create=True,
        scaling=scaling, sv_ids=svixs)

    # rebuild the graph with segmentation objects instead of plain IDs
    so_cc = nx.Graph()
    for sv_a, sv_b in cc.edges():
        so_cc.add_edge(sso.get_seg_obj("sv", sv_a),
                       sso.get_seg_obj("sv", sv_b))
    sso._rag = so_cc

    # attach the supervoxel objects directly to the SSO
    sos = init_sos(sos_dict_fact(svixs))
    sso._objects["sv"] = sos

    sso.load_attr_dict()
    sso.gliasplit(verbose=False, recompute=False)

# write a dummy result to the output file
with open(path_out_file, "wb") as f:
    pkl.dump("0", f)