Example #1
 def save(self, directory, atlas_name):
     # Temporarily detach the polydata object so the atlas can be pickled
     polydata_tmp = self.nystrom_polydata
     self.nystrom_polydata = None
     fname = os.path.join(directory, atlas_name)
     with open(fname + '.p', 'wb') as f:
         pickle.dump(self, f)
     # Save the polydata separately in VTK XML format
     io.write_polydata(polydata_tmp, fname + '.vtp')
     # Re-attach the polydata object
     self.nystrom_polydata = polydata_tmp
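
A matching loader would reverse these two steps: unpickle the atlas, then re-attach the separately stored polydata. A minimal sketch, assuming `io` is the same polydata I/O module used in save() and that it offers a read_polydata counterpart to write_polydata (that name is an assumption, not confirmed by the snippet):

import os
import pickle

def load_atlas(directory, atlas_name):
    # Hypothetical inverse of save(); assumes io.read_polydata exists
    fname = os.path.join(directory, atlas_name)
    with open(fname + '.p', 'rb') as f:
        atlas = pickle.load(f)
    # Re-attach the polydata that save() wrote separately as .vtp
    atlas.nystrom_polydata = io.read_polydata(fname + '.vtp')
    return atlas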
Example #2
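# Context: this function relies on module-level imports from its source file
# (apparently whitematteranalysis' cluster.py): os, pickle, numpy, the package
# modules io, filter, fibers, mrml, and render, and an optional matplotlib
# pyplot import guarded by the HAVE_PLT flag.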
def output_and_quality_control_cluster_atlas(atlas, output_polydata_s, subject_fiber_list, input_polydatas, number_of_subjects, outdir, cluster_numbers_s, color, embed, number_of_fibers_to_display, testing=False, verbose=False, render_images=True):

    """Save the output in our atlas format for automatic labeling of clusters.

    First save the atlas.vtp and atlas.p datasets. This is the data used to
    label a new subject.  Also write the polydata with cluster indices
    saved as cell data. This is a one-file output option for clusters.
    Finally, save some quality control metrics and save the atlas
    clusters as individual polydatas. This is used to set up a mrml
    hierarchy file and to visualize the output in Slicer. This data is
    not used to label a new subject.
    """

    # Write additional output for software testing if the code has changed (requires the random seed to be constant)
    if testing:
        expected_results = [('numbers', cluster_numbers_s), ('colors', color), ('embeddings', embed)]
        for name, data in expected_results:
            expected_results_file = os.path.join(outdir, 'test_cluster_atlas_{0}.pkl'.format(name))
            with open(expected_results_file, 'wb') as f:
                pickle.dump(data, f)
         
    # Save the output in our atlas format for automatic labeling of full brain datasets.
    # This is the data used to label a new subject
    atlas.save(outdir, 'atlas')

    # Write the polydata with cluster indices saved as cell data
    fname_output = os.path.join(outdir, 'clustered_whole_brain.vtp')
    io.write_polydata(output_polydata_s, fname_output)

    # Output a summary file recording information about all input subjects
    subjects_qc_fname = os.path.join(outdir, 'input_subjects.txt')
    with open(subjects_qc_fname, 'w') as subjects_qc_file:
        subjects_qc_file.write("Subject_idx\tSubject_ID\tfilename\n")
        for idx, fname in enumerate(input_polydatas, start=1):
            subject_id = os.path.splitext(os.path.basename(fname))[0]
            subjects_qc_file.write('{0}\t{1}\t{2}\n'.format(idx, subject_id, fname))

    # Output a summary file recording quality control information about all clusters
    clusters_qc_fname = os.path.join(outdir, 'cluster_quality_control.txt')

    # Compute per-cluster quality control metrics (ideally, most clusters contain fibers from most subjects)
    subjects_per_cluster = list()
    mean_fiber_len_per_cluster = list()
    std_fiber_len_per_cluster = list()
    mean_fibers_per_subject_per_cluster = list()
    std_fibers_per_subject_per_cluster = list()

    # Compute the length of each fiber (used for the per-cluster length statistics)
    fiber_length, step_size = filter.compute_lengths(output_polydata_s)

    # loop over each cluster and compute quality control metrics
    cluster_indices = range(atlas.centroids.shape[0])
    for cidx in cluster_indices:
        cluster_mask = (cluster_numbers_s == cidx)
        subjects_per_cluster.append(len(set(subject_fiber_list[cluster_mask])))
        fibers_per_subject = list()
        for sidx in range(number_of_subjects):
            fibers_per_subject.append(list(subject_fiber_list[cluster_mask]).count(sidx))
        mean_fibers_per_subject_per_cluster.append(numpy.mean(numpy.array(fibers_per_subject)))
        std_fibers_per_subject_per_cluster.append(numpy.std(numpy.array(fibers_per_subject)))
        mean_fiber_len_per_cluster.append(numpy.mean(fiber_length[cluster_mask]))
        std_fiber_len_per_cluster.append(numpy.std(fiber_length[cluster_mask]))

    percent_subjects_per_cluster = numpy.array(subjects_per_cluster) / float(number_of_subjects)

    # Save output quality control information
    print "<cluster.py> Saving cluster quality control information file."
    clusters_qc_file = open(clusters_qc_fname, 'w')
    print >> clusters_qc_file, 'cluster_idx','\t', 'number_subjects','\t', 'percent_subjects','\t', 'mean_length','\t', 'std_length','\t', 'mean_fibers_per_subject','\t', 'std_fibers_per_subject'
    for cidx in cluster_indices:
        print >> clusters_qc_file, cidx + 1,'\t', subjects_per_cluster[cidx],'\t', percent_subjects_per_cluster[cidx] * 100.0,'\t', \
            mean_fiber_len_per_cluster[cidx],'\t', std_fiber_len_per_cluster[cidx],'\t', \
            mean_fibers_per_subject_per_cluster[cidx],'\t', std_fibers_per_subject_per_cluster[cidx]

    clusters_qc_file.close()

    if HAVE_PLT:
        print("<cluster.py> Saving subjects per cluster histogram.")
        fig, ax = plt.subplots()
        counts = numpy.zeros(number_of_subjects + 1)
        counts[:numpy.max(subjects_per_cluster) + 1] = numpy.bincount(subjects_per_cluster)
        ax.bar(range(number_of_subjects + 1), counts, width=1, align='center')
        ax.set(xlim=[-1, number_of_subjects + 1])
        plt.title('Histogram of Subjects per Cluster')
        plt.xlabel('subjects per cluster')
        plt.ylabel('number of clusters')
        plt.savefig(os.path.join(outdir, 'subjects_per_cluster_hist.pdf'))
        plt.close()
        
    # Save the entire combined atlas as individual clusters for visualization
    # and labeling/naming of structures. This will include all of the data
    # that was clustered to make the atlas.

    # Figure out file name and mean color for each cluster, and write the individual polydatas
    print "<cluster.py> Beginning to save individual clusters as polydata files. TOTAL CLUSTERS:", len(cluster_indices),
    fnames = list()
    cluster_colors = list()
    cluster_sizes = list()
    cluster_fnames = list()
    for c in cluster_indices:
        print(c, end=' ')
        mask = cluster_numbers_s == c
        cluster_size = numpy.sum(mask)
        cluster_sizes.append(cluster_size)
        pd_c = filter.mask(output_polydata_s, mask, verbose=verbose)
        # Color fibers by subject so we can see which subject each one came from
        filter.add_point_data_array(pd_c, subject_fiber_list[mask], "Subject_ID")
        # Save hemisphere information into the polydata
        farray = fibers.FiberArray()
        farray.hemispheres = True
        farray.hemisphere_percent_threshold = 0.90
        farray.convert_from_polydata(pd_c, points_per_fiber=50)
        filter.add_point_data_array(pd_c, farray.fiber_hemisphere, "Hemisphere")
        # The clusters are stored starting with 1, not 0, for user friendliness.
        fname_c = 'cluster_{0:05d}.vtp'.format(c+1)
        # save the filename for writing into the MRML file
        fnames.append(fname_c)
        # prepend the output directory
        fname_c = os.path.join(outdir, fname_c)
        io.write_polydata(pd_c, fname_c)
        cluster_fnames.append(fname_c)
        if cluster_size > 0:
            color_c = color[mask,:]
            cluster_colors.append(numpy.mean(color_c,0))
        else:
            cluster_colors.append([0,0,0])
        del pd_c
    print "\n<cluster.py> Finishes saving individual clusters as polydata files."

    # Notify the user if any clusters are empty
    empty_count = 0
    for sz, fname in zip(cluster_sizes, cluster_fnames):
        if sz == 0:
            print(sz, ":", fname)
            empty_count += 1
    if empty_count:
        print("<cluster.py> Warning. Empty clusters found:", empty_count)

    cluster_sizes = numpy.array(cluster_sizes)
    print "<cluster.py> Mean number of fibers per cluster:", numpy.mean(cluster_sizes), "Range:", numpy.min(cluster_sizes), "..", numpy.max(cluster_sizes)

    # Estimate the subsampling ratio needed to display approximately number_of_fibers_to_display total fibers in 3D Slicer.
    # For example, displaying 10000 of 500000 input fibers gives a ratio of 0.02.
    number_fibers = len(cluster_numbers_s)
    if number_fibers < number_of_fibers_to_display:
        ratio = 1.0
    else:
        # Use float division so the ratio is not truncated to zero under Python 2 integer division
        ratio = number_of_fibers_to_display / float(number_fibers)
    print("<cluster.py> Subsampling ratio for display of", number_of_fibers_to_display, "total fibers estimated as:", ratio)

    # Write the MRML file into the directory where the polydatas were already stored
    fname = os.path.join(outdir, 'clustered_tracts.mrml')
    mrml.write(fnames, numpy.around(numpy.array(cluster_colors), decimals=3), fname, ratio=ratio)

    # Also write one with 100% of fibers displayed
    fname = os.path.join(outdir, 'clustered_tracts_display_100_percent.mrml')
    mrml.write(fnames, numpy.around(numpy.array(cluster_colors), decimals=3), fname, ratio=1.0)
    
    # Render the full atlas and save images for quality control
    if render_images:
        print('<cluster.py> Rendering and saving images of cluster atlas.')
        ren = render.render(output_polydata_s, 1000, data_mode='Cell', data_name='EmbeddingColor', verbose=verbose)
        ren.save_views(outdir, verbose=verbose)
        del ren
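
The cluster_quality_control.txt written above is a plain tab-separated table with a header row, so downstream checks can parse it directly. A minimal sketch (the column names follow the header written by this function; the 50 percent coverage threshold is only an illustration):

import csv

def read_cluster_qc(fname):
    # Parse the tab-separated quality control table written above
    with open(fname) as f:
        return list(csv.DictReader(f, delimiter='\t'))

# Flag clusters containing fibers from fewer than half of the subjects
for row in read_cluster_qc('cluster_quality_control.txt'):
    if float(row['percent_subjects']) < 50.0:
        print('Low subject coverage in cluster', row['cluster_idx'])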