def on_train_begin(self, num_epochs, logs=None):
    """Reset the training history and display an empty loss plot.

    Clears the per-epoch buffers, draws an empty "loss"/"val_loss" plot
    whose x axis spans ``1..self.num_epochs``, saves it to
    ``temp/Epoch-0.png`` and shows it, scaled with the aspect ratio kept,
    in ``self.label``.

    Args:
        num_epochs: Unused; the epoch count is read from ``self.num_epochs``.
        logs: Unused. Defaults to ``None`` rather than a shared mutable
            ``{}`` default.
    """
    import os  # local import keeps this snippet self-contained

    self.progress = 0
    self.increment = rescale(1, self.num_epochs)
    self.i = 1
    self.x = []
    self.losses = []
    self.val_losses = []
    self.logs = []

    image_path = 'temp/Epoch-{}.png'.format(0)
    # plt.savefig does not create missing directories.
    os.makedirs(os.path.dirname(image_path), exist_ok=True)

    fig = plt.figure()
    try:
        plt.xlim(1, self.num_epochs)
        plt.ylim(0, 1)
        # Empty placeholder series so both curves are registered up front.
        plt.plot(0, 0, label="loss")
        plt.plot(0, 0, label="val_loss")
        plt.title("Training and Validation loss [Epoch {}]".format(0))
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.savefig(image_path)
    finally:
        # Release this specific figure even if saving failed.
        plt.close(fig)

    pixmap = QPixmap(image_path)
    scaled_pixmap = pixmap.scaled(self.label.size(),
                                  QtCore.Qt.KeepAspectRatio)
    self.label.setPixmap(scaled_pixmap)
Ejemplo n.º 2
0
    def start_reconstruction(self):
        """Reconstruct every ``.nrrd`` sinogram file and write the results.

        The ``.nrrd`` files in ``self.sinogram_path`` are processed in sorted
        order. Each one is reconstructed with ``self.reconstruct_astra`` when
        ``self._case == 0`` or with ``self.reconstruct_odl`` when
        ``self._case == 1``. The result is written to
        ``self.reconstructed_path`` as ``noisy_astra_NNN.nrrd`` /
        ``noisy_odl_NNN.nrrd``, numbered from 101. ``self.progressbar`` is
        set to 100 at the end, and also immediately when there is nothing
        to reconstruct.
        """
        os.makedirs(self.reconstructed_path, exist_ok=True)
        # Only .nrrd files are reconstructed, so only they may contribute to
        # the total length, the "last file" read and the progress step;
        # other directory entries used to break or skew all three.
        sinogram_files = sorted(
            name for name in os.listdir(self.sinogram_path)
            if name.endswith('.nrrd'))
        if not sinogram_files:
            # Nothing to reconstruct (previously an IndexError).
            self.progressbar.emit(100)
            return

        # Total slice count: every file but the last is counted as 128
        # slices, and the last file's actual length is read from disk.
        # NOTE(review): 128 is presumably the samples-per-file chunk size
        # used when the sinograms were written -- confirm.
        last_volume, _ = nrrd.read(
            os.path.join(self.sinogram_path, sinogram_files[-1]),
            index_order='C')
        length = (len(sinogram_files) - 1) * 128 + len(last_volume)

        progress = 0
        increment = rescale(1, length)
        per_file_progress = 100 / len(sinogram_files)
        # Output filename prefix and reconstruction method per backend case.
        backends = {
            0: ('noisy_astra', self.reconstruct_astra),
            1: ('noisy_odl', self.reconstruct_odl),
        }
        for index, sinogram_file in enumerate(tqdm(sinogram_files)):
            backend = backends.get(self._case)
            if backend is not None:
                prefix, reconstruct = backend
                recon_filename = os.path.join(
                    self.reconstructed_path,
                    '{}_{:03d}.nrrd'.format(prefix, index + 101))
                sinogram = os.path.join(self.sinogram_path, sinogram_file)
                nrrd.write(recon_filename,
                           reconstruct(sinogram, progress, increment),
                           index_order='C')
            progress += per_file_progress

        self.progressbar.emit(100)
    def start_sinogram_generation(self):

        file_list = list_files(self.files_path)

        os.makedirs(self.sinogram_path, exist_ok=True)
        os.makedirs(self.ground_truth_path, exist_ok=True)

        lidc_idri_gen_len = len(file_list)

        vol_geom = astra.create_vol_geom(512, 512)
        angles = np.linspace(0, np.pi, 1000, False)
        proj_geom = astra.create_proj_geom('parallel', 1.0, 727, angles)
        projector_id = astra.create_projector('cuda', proj_geom, vol_geom)

        reco_space = odl.uniform_discr(min_pt=self.MIN_PT, max_pt=self.MAX_PT,
                                       shape=self.RECO_IM_SHAPE, dtype=np.float32)
        space = odl.uniform_discr(min_pt=self.MIN_PT, max_pt=self.MAX_PT, shape=self.IM_SHAPE,
                                  dtype=np.float32)

        reco_geometry = odl.tomo.parallel_beam_geometry(
            reco_space, num_angles=self.NUM_ANGLES)
        geometry = odl.tomo.parallel_beam_geometry(
            space, num_angles=self.NUM_ANGLES, det_shape=reco_geometry.detector.shape)

        # IMPL = 'astra_cpu'
        # reco_ray_trafo = odl.tomo.RayTransform(reco_space, reco_geometry)
        ray_trafo = odl.tomo.RayTransform(space, geometry)


        rs = np.random.RandomState(3)

        n_files = ceil(lidc_idri_gen_len / self.NUM_SAMPLES_PER_FILE)

        slices = self.lidc_idri_gen(file_list)
        it1 = 0
        it2 = self.NUM_SAMPLES_PER_FILE
        progress = 0
        increment = rescale(1, n_files)
        for filenumber in tqdm(range(n_files)):
            obs_filename = os.path.join(
                self.sinogram_path, 'sinogram_{:03d}.nrrd'.format(filenumber))
            ground_truth_filename = os.path.join(
                self.ground_truth_path, 'ground_truth_{:03d}.nrrd'.format(filenumber))

            observation_dataset = []
            ground_truth_dataset = []

            for data in tqdm(slices[it1:it2, ...]):
                data = np.flipud(data)
                # ground_truth_dataset.append(data)
                data = self.ff(ray_trafo, data)
                data /=  self.PHOTONS_PER_PIXEL
                np.maximum(0.1 / self.PHOTONS_PER_PIXEL, data, out=data)
                np.log(data, out=data)
                data /= (-self.MU_MAX)
                observation_dataset.append(data)
                if self.FLAG:
                    print('terminating process')
                    return

            it1 += self.NUM_SAMPLES_PER_FILE
            it2 += self.NUM_SAMPLES_PER_FILE

            if it2 > len(slices):
                it2 = len(slices)
            # ground_truth_dataset = np.array(ground_truth_dataset)
            # ground_truth_dataset = np.rot90(ground_truth_dataset, -1, (1, 2))
            nrrd.write(obs_filename, np.array(observation_dataset), index_order='C')