Exemplo n.º 1
0
    def _make_root(self, rootpath):
        """
            Build a single root mesh by merging the meshes of all neurons
            and save it to rootpath.

            Currently disabled: NotImplementedError is raised before any of
            the mesh-building code below is executed.
        """
        raise NotImplementedError(
            "Create root method not supported yet, sorry")

        # ---- everything below is unreachable while the guard is in place ----
        print(f"Creating root mesh for atlas {self.atlas_name}")

        # Throw-away scene used only to obtain the neuron actors
        scene_options = dict(
            atlas=Celegans,
            add_root=False,
            display_inset=False,
            atlas_kwargs=dict(data_folder=self.data_folder),
        )
        scratch_scene = Scene(**scene_options)
        scratch_scene.add_neurons(self.neurons_names)
        scratch_scene.render(interactive=False)
        scratch_scene.close()

        # Fuse all neuron actors into one cleaned, capped mesh
        merged = merge(*scratch_scene.actors["neurons"])
        merged = merged.clean().cap()

        # Smooth the merged vertices and reconstruct a surface from them
        cloud = Points(merged.points())
        cloud = cloud.smoothMLS2D(f=0.8).clean(tol=0.005)
        root = recoSurface(cloud, dims=100, radius=0.2)

        # Save
        write(root, rootpath)

        del scratch_scene
        return root
Exemplo n.º 2
0
    def make_root_mesh(self):
        """
            Create the root mesh file by merging the mesh of every region
            listed in self.region_acronyms.

            Nothing is done when self.structures is None, or when a
            "root.vtk" file already exists in self.meshes_folder.
        """
        if self.structures is None:
            return None

        root_file = os.path.join(self.meshes_folder, "root.vtk")
        if not os.path.isfile(root_file):
            # Collect one mesh per brain region, then fuse them into the root
            region_meshes = []
            for acronym in self.region_acronyms:
                region_meshes.append(self._get_structure_mesh(acronym))
            write(merge(region_meshes), root_file)
Exemplo n.º 3
0
    def write_neuron_to_cache(self, neuron_name, neuron, _params):
        """
            Save a neuron's parameters and meshes to the cache files.

            :param neuron_name: name used to look up the cache file names
            :param neuron: either a single Mesh (written to the cache file
                ending in "soma.obj") or a dict mapping component names to
                a Mesh, a list/tuple of meshes, or None
            :param _params: parameters written to the cache yaml file
            :raises ValueError: if neuron is neither a Mesh nor a dict, if
                the meshes of a component cannot be merged, or if no cache
                file name matches a component
            :raises IndexError: if no cache file name ends in "soma.obj"
                (Mesh input) or in "whole_neuron.obj"
        """
        # Write params to file
        save_yaml(self.get_cache_params_filename(neuron_name), _params)

        # Write neurons to file
        file_names = self.get_cache_filenames(neuron_name)

        if isinstance(neuron, Mesh):
            write(neuron, [f for f in file_names if f.endswith("soma.obj")][0])
            return

        if not isinstance(neuron, dict):
            raise ValueError(
                f"Invalid neuron argument passed while caching: {neuron}")

        for key, actor in neuron.items():
            fname = [f for f in file_names if f.endswith(f"{key}.obj")]

            if key == "whole_neuron":
                # The whole-neuron actor is written as-is
                write(actor, fname[0])
                continue

            # Get a single actor for each neuron component.
            # If there's no data for the component create an empty actor
            if not isinstance(actor, Mesh) and isinstance(actor, (list, tuple)):
                if len(actor) == 1:
                    actor = actor[0]
                elif not actor:
                    actor = Mesh()
                else:
                    try:
                        actor = merge(actor)
                    except Exception as err:
                        # Only real errors are converted; KeyboardInterrupt
                        # and SystemExit propagate untouched, and the
                        # original cause is kept on the raised error.
                        raise ValueError(
                            f"{key} actor should be a mesh or a list of 1 mesh not {actor}"
                        ) from err

            if actor is None:
                actor = Mesh()

            # Save to file
            if not fname:
                raise ValueError(
                    f"No filename found for {key}. Filenames {file_names}"
                )
            write(actor, fname[0])
Exemplo n.º 4
0
            A_S = torch.from_numpy(A_S).to(device, dtype=torch.float)
            A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)

            tensor_prob_output = model(X, A_S, A_L).to(device,
                                                       dtype=torch.float)
            patch_prob_output = tensor_prob_output.cpu().numpy()

            for i_label in range(num_classes):
                predicted_labels_d[np.argmax(patch_prob_output[0, :], axis=-1)
                                   == i_label] = i_label

            # output downsampled predicted labels
            mesh2 = mesh_d.clone()
            mesh2.addCellArray(predicted_labels_d, 'Label')
            vedo.write(
                mesh2,
                os.path.join(output_path,
                             '{}_d_predicted.vtp'.format(i_sample[:-4])))

            # refinement
            print('\tRefining by pygco...')
            round_factor = 100
            patch_prob_output[patch_prob_output < 1.0e-6] = 1.0e-6

            # unaries
            unaries = -round_factor * np.log10(patch_prob_output)
            unaries = unaries.astype(np.int32)
            unaries = unaries.reshape(-1, num_classes)

            # pairwise
            pairwise = (1 - np.eye(num_classes, dtype=np.int32))
Exemplo n.º 5
0
def vtk_plot(vtk_dir, output_dir):
    """
    Load the object found at vtk_dir and write it as "vtk_obj.vti" in output_dir.

    :param vtk_dir: path handed to load()
    :param output_dir: directory that receives the "vtk_obj.vti" file; may be
        a str or any os.PathLike object
    """
    # Local import keeps the function self-contained.
    import os

    vtk_obj = load(vtk_dir)
    # os.path.join avoids a hard-coded "/" separator and accepts PathLike
    # inputs, which plain string concatenation rejects.
    write(objct=vtk_obj, fileoutput=os.path.join(output_dir, "vtk_obj.vti"))
Exemplo n.º 6
0
            A_L = torch.from_numpy(A_L).to(device, dtype=torch.float)

            tensor_prob_output = model(X, A_S, A_L).to(device,
                                                       dtype=torch.float)
            patch_prob_output = tensor_prob_output.cpu().numpy()

            for i_label in range(num_classes):
                predicted_labels[np.argmax(patch_prob_output[0, :], axis=-1) ==
                                 i_label] = i_label

            # output predicted labels
            mesh2 = mesh.clone()
            mesh2.addCellArray(predicted_labels, 'Label')
            vedo.write(
                mesh2,
                os.path.join(test_path,
                             'Sample_{}_predicted.vtp'.format(i_sample)),
                binary=True)

            # convert predict result and label to one-hot maps
            tensor_predicted_labels = torch.from_numpy(predicted_labels)
            tensor_test_labels = torch.from_numpy(labels)
            tensor_predicted_labels = tensor_predicted_labels.long()
            tensor_test_labels = tensor_test_labels.long()
            one_hot_predicted_labels = nn.functional.one_hot(
                tensor_predicted_labels[:, 0], num_classes=num_classes)
            one_hot_labels = nn.functional.one_hot(tensor_test_labels[:, 0],
                                                   num_classes=num_classes)

            # calculate DSC
            i_dsc = DSC(one_hot_predicted_labels, one_hot_labels)