コード例 #1
0
    def __call__(self, listmode: Listmode) -> Image:
        """Reconstruct an image from ``listmode`` with random-subset updates.

        Only ``self.mode == 'numba'`` is implemented. Starting from an image
        of ones shaped like ``self.emap``, each of the ``self.n_iter``
        iterations:

        1. draws ``len(listmode) // self.n_sub`` event indices uniformly at
           random (``np.random.randint``, i.e. with replacement),
        2. projects the current image along the selected LORs,
        3. back-projects the ratio of the selected events to that projection,
        4. multiplies the image in place by ``back_projection / self.emap``.

        Args:
            listmode: Measured events. Must support ``len()``, indexing with
                an integer array, and expose ``lors`` that can be indexed as
                ``lors[index, :]``.

        Returns:
            The image obtained after ``self.n_iter`` updates.

        Raises:
            NotImplementedError: If ``self.mode`` is anything other than
                ``'numba'``. The ``'tf-eager'`` and ``'tf'`` modes never had
                reachable code, and unknown modes used to fall through and
                return ``None``.
        """
        if self.mode != 'numba':
            raise NotImplementedError(
                'mode {!r} is not implemented'.format(self.mode))

        sub_length = int(len(listmode) // self.n_sub)
        x = self.emap.update(data=np.ones(self.emap.shape, dtype=np.float32))
        for _ in tqdm(range(self.n_iter)):
            # A fresh random subset of events is used for every update.
            index = np.random.randint(0, high=len(listmode), size=sub_length)
            _listmode = Project(self.mode)(x, listmode.lors[index, :])
            _bp = BackProject(self.mode)(listmode[index] / _listmode, x)
            x *= _bp / self.emap
        return x
コード例 #2
0
ファイル: mlem_deform.py プロジェクト: twj2417/srf
    def __call__(self, listmodes: List(Listmode), dvfs: List(Dvf)) -> Image:
        """Run the iterative update over paired listmodes and deformation fields.

        The estimate starts as an image of ones shaped like ``self.emap``.
        For each of the ``self.n_iter`` iterations, every ``(listmode, dvf)``
        pair (matched positionally with ``zip``) contributes one
        multiplicative update:

        * the current estimate goes through ``self._deform_invert_tf`` with
          the three components of ``dvf``,
        * that image is projected along the listmode's LORs and ``1e-8`` is
          added to the projection before it is used as a divisor,
        * the ratio of measured to projected events is back-projected,
        * both the back-projection and the efficiency map go through
          ``self._deform_tf`` (``1e-8`` is added to the deformed map),
        * the estimate is multiplied by their quotient.

        Args:
            listmodes: Listmode objects, one per deformation field.
            dvfs: Deformation fields exposing ``dvf_x``, ``dvf_y``, ``dvf_z``.

        Returns:
            The final estimate with its data converted back to a NumPy array.
        """
        declare_eager_execution()

        start = self.emap.update(
            data=np.ones(self.emap.shape, dtype=np.float32))
        x_tf = start.update(data=tf.Variable(start.data))
        emap_tf = self.emap.update(data=tf.constant(self.emap.data))

        for _ in tqdm(range(self.n_iter)):
            for listmode, dvf in zip(listmodes, dvfs):
                # Deformation components in (x, y, z) order.
                dvf_tf = tuple(
                    tf.constant(component.data)
                    for component in (dvf.dvf_x, dvf.dvf_y, dvf.dvf_z))

                lors_tf = listmode.lors.update(
                    data=tf.constant(listmode.lors.data))
                measured_tf = listmode.update(
                    data=tf.constant(listmode.data), lors=lors_tf)

                moved_x_tf = x_tf.update(
                    data=self._deform_invert_tf(x_tf.data, *dvf_tf))
                # The small offset keeps the division below away from zero.
                projected_tf = Project('tf')(moved_x_tf, lors_tf) + 1e-8
                back_projected = BackProject('tf')(
                    measured_tf / projected_tf, emap_tf)

                deformed_emap = emap_tf.update(
                    data=self._deform_tf(emap_tf.data, *dvf_tf)) + 1e-8
                deformed_bp = back_projected.update(
                    data=self._deform_tf(back_projected.data, *dvf_tf))

                x_tf = x_tf * deformed_bp / deformed_emap

        return x_tf.update(data=x_tf.data.numpy())
コード例 #3
0
def listmode_from_gate_out(path, scanner, nb_sub):
    """Load GATE output sub-files and snap the LOR end points to crystals.

    ``nb_sub`` files are read with ``np.load``; the file name of sub-file
    ``i`` is ``path`` with every ``'?'`` replaced by ``str(i)``. Only the
    first seven columns of each file are kept, and the files are stacked in
    index order. Both end points of every row are then converted to a
    crystal index and back to a crystal position, using the converters that
    match the scanner type.

    Args:
        path: File name pattern containing ``'?'`` as the sub-file index
            placeholder.
        scanner: A ``nef.PetEcatScanner`` or ``nef.PetCylindricalScanner``.
        nb_sub: Number of sub-files to read (at least 1).

    Returns:
        A ``Listmode`` whose data is a float32 vector of ones and whose
        ``lors`` is a ``Lors`` built from the snapped end points.

    Raises:
        ValueError: If ``nb_sub`` yields no sub-file.
        NotImplementedError: If ``scanner`` is of an unsupported type.
    """
    # Collect the chunks first and stack once; stacking inside the loop
    # re-copies everything loaded so far on every iteration.
    chunks = [
        np.load(path.replace('?', str(i)))[:, :7]
        for i in tqdm(range(nb_sub))
    ]
    if not chunks:
        raise ValueError('nb_sub must be at least 1')
    full_data = np.vstack(chunks)

    # NOTE(review): ``lors.data[:, 3:]`` passes every column after the third
    # to the converter (four columns for seven-column input) - confirm that
    # the converters only read the first three.
    if isinstance(scanner, nef.PetEcatScanner):
        lors = Lors(full_data)
        row = nef.EcatCrystalPosToIndex(scanner)(lors.data[:, :3])
        col = nef.EcatCrystalPosToIndex(scanner)(lors.data[:, 3:])
        lors_data1 = nef.EcatIndexToCrystalPos(scanner)(row)
        lors_data2 = nef.EcatIndexToCrystalPos(scanner)(col)
        lors_data = np.hstack((lors_data1, lors_data2))
        # Wrap the raw array so this branch returns the same kind of
        # listmode (``lors`` is a Lors) as the cylindrical branch.
        return Listmode(np.ones((lors_data.shape[0], ), dtype=np.float32),
                        Lors(lors_data))
    elif isinstance(scanner, nef.PetCylindricalScanner):
        lors = Lors(full_data)
        listmode = Listmode(np.ones((lors.length, ), dtype=np.float32), lors)
        ind1 = nef.CylindricalCrystalPosToIndex(scanner)(lors.data[:, :3])
        ind2 = nef.CylindricalCrystalPosToIndex(scanner)(lors.data[:, 3:])
        lors_data1 = nef.CylindricalIndexToCrystalPos(scanner)(ind1)
        lors_data2 = nef.CylindricalIndexToCrystalPos(scanner)(ind2)
        lors_data = np.hstack((lors_data1, lors_data2))
        return listmode.update(lors=nef.Lors(lors_data))
    else:
        raise NotImplementedError
コード例 #4
0
    def __call__(self) -> Image:
        """Run the reconstruction and return the (optionally PSF-corrected) image.

        Steps, in order:

        1. Eager execution is declared and ``self.generate_emap()`` is called
           when ``self.emap`` is still ``None``.
        2. ``self.atten_corr`` (if set) is applied to ``self.listmode``; only
           the resulting event values are used, the LORs always come from
           ``self.listmode``.
        3. Starting from an image of ones described by ``self.image_config``,
           ``self.n_iter`` multiplicative updates are performed. Zero entries
           of the efficiency map are replaced by ``1e8`` before it is used as
           a divisor, and the measured/projected ratio uses
           ``tf.div_no_nan``.
        4. ``self.psf_corr`` (if set) is applied to the result.

        Returns:
            The reconstructed image with NumPy data, after the optional PSF
            correction.
        """
        declare_eager_execution()
        if self.emap is None:
            self.generate_emap()

        listmode_ = (self.listmode if self.atten_corr is None else
                     self.atten_corr(self.listmode))

        initial_data = tf.Variable(
            np.ones(self.image_config.shape, dtype=np.float32))
        x_tf = Image(data=initial_data,
                     center=self.image_config.center,
                     size=self.image_config.size)

        # Work on a copy so the stored efficiency map keeps its zeros.
        divisor = copy(self.emap.data)
        divisor[divisor == 0.0] = 1e8
        emap_tf = self.emap.update(data=tf.constant(divisor))

        lors_tf = self.listmode.lors.update(
            data=tf.constant(self.listmode.lors.data))
        listmode_tf = self.listmode.update(data=tf.constant(listmode_.data),
                                           lors=lors_tf)

        for _ in tqdm(range(self.n_iter)):
            projected = Project('tf-eager')(x_tf, lors_tf)
            ratio = tf.div_no_nan(listmode_tf.data, projected.data)
            back_projected = BackProject('tf-eager')(
                listmode_tf.update(data=ratio), emap_tf)
            x_tf = x_tf * back_projected / emap_tf

        x = x_tf.update(data=x_tf.data.numpy())

        # NOTE: no scatter correction is applied here; ``self.scatter_corr``
        # is not used by this method.

        if self.psf_corr is None:
            return x
        return self.psf_corr(x)
コード例 #5
0
def listmode_tof_from_gate_out(path, scanner, nb_sub):
    """Load GATE output sub-files into a listmode whose LORs carry a TOF column.

    ``nb_sub`` files are read with ``np.load``; the file name of sub-file
    ``i`` is ``path`` with every ``'?'`` replaced by ``str(i)``. Only the
    first seven columns of each file are kept, and the files are stacked in
    index order. The LORs are converted to crystal ids and back
    (``ListmodeToId`` / ``IdToListmode``), then the last kept column of the
    raw data is appended to the converted LOR data as an extra column.

    Args:
        path: File name pattern containing ``'?'`` as the sub-file index
            placeholder.
        scanner: Scanner description handed to the id converters.
        nb_sub: Number of sub-files to read (at least 1).

    Returns:
        A ``Listmode`` whose data is a float32 vector of ones and whose
        ``lors`` hold the converted end points plus the appended column.

    Raises:
        ValueError: If ``nb_sub`` yields no sub-file.
    """
    # Collect the chunks first and stack once; stacking inside the loop
    # re-copies everything loaded so far on every iteration.
    chunks = [
        np.load(path.replace('?', str(i)))[:, :7]
        for i in tqdm(range(nb_sub))
    ]
    if not chunks:
        raise ValueError('nb_sub must be at least 1')
    full_data = np.vstack(chunks)

    lors = Lors(full_data)
    listmode = Listmode(np.ones((lors.length, ), dtype=np.float32), lors)
    row, col = ListmodeToId()(listmode, scanner)
    lors = IdToListmode()(row, col, scanner)

    # Re-attach the last raw column, which the id round trip does not carry.
    lors_tof = lors.update(
        data=np.append(lors.data, full_data[:, -1].reshape(-1, 1), axis=1))
    return Listmode(np.ones((lors.shape[0], ), dtype=np.float32), lors_tof)
コード例 #6
0
    def __call__(self, image: Image):
        """Apply the iterative PSF correction to ``image``.

        The image data is viewed as a 2-D array of shape
        ``(shape[0] * shape[1], shape[2])``. Starting from an array of ones,
        each of the ``self.n_iter`` iterations applies ``kernel_z`` and
        ``kernel_xy`` to the estimate, divides the input data by the result
        (offset by ``1e-16``), applies the adjoint of both kernels to that
        ratio and multiplies the estimate by it.

        Args:
            image: Image to correct; its ``data`` must be reshapeable to
                ``(-1, image.shape[2])``.

        Returns:
            A copy of ``image`` (via ``image.update``) holding the corrected
            data reshaped to ``image.shape``.

        Raises:
            ValueError: If either kernel has not been built yet, or if
                ``kernel_xy.nnz * image.shape[2]`` exceeds ``2**31``, the
                limit quoted in the error message for the GPU sparse matmul.
        """
        # Both kernels are dereferenced below, so validate both up front.
        if self.kernel_xy is None or self.kernel_z is None:
            raise ValueError('Please do make kernel first')
        from srfnef.utils import declare_eager_execution
        declare_eager_execution()

        # Fail before building the sparse tensors, not after.
        if self.kernel_xy.nnz * image.shape[2] > 2**31:
            raise ValueError(
                'Cannot use GPU when output.shape[1] * nnz(a) > 2^31 [Op:SparseTensorDenseMatMul]'
            )

        x = np.ones((image.shape[0] * image.shape[1], image.shape[2]),
                    dtype=np.float32)
        x_tf = tf.Variable(x)

        d = image.data.reshape((-1, image.shape[2]))
        d_tf = tf.constant(d)
        kernel_xy_tf = tf.sparse.SparseTensor(
            indices=list(zip(self.kernel_xy.row, self.kernel_xy.col)),
            values=self.kernel_xy.data,
            dense_shape=self.kernel_xy.shape)
        kernel_z_tf = tf.sparse.SparseTensor(
            indices=list(zip(self.kernel_z.row, self.kernel_z.col)),
            values=self.kernel_z.data,
            dense_shape=self.kernel_z.shape)

        for _ in tqdm(range(self.n_iter)):
            # Forward model: kernel_z, then kernel_xy.
            c_tf = tf.transpose(
                tf.sparse.sparse_dense_matmul(kernel_z_tf,
                                              x_tf,
                                              adjoint_b=True))
            c_tf = tf.sparse.sparse_dense_matmul(kernel_xy_tf, c_tf) + 1e-16
            c_tf = d_tf / c_tf
            # Adjoint model applied to the ratio.
            c_tf = tf.sparse.sparse_dense_matmul(kernel_xy_tf,
                                                 c_tf,
                                                 adjoint_a=True)
            c_tf = tf.transpose(
                tf.sparse.sparse_dense_matmul(kernel_z_tf,
                                              c_tf,
                                              adjoint_a=True,
                                              adjoint_b=True))
            x_tf = x_tf * c_tf
        image = image.update(data=x_tf.numpy().reshape(image.shape))
        return image
コード例 #7
0
ファイル: mlem.py プロジェクト: twj2417/srf
    def __call__(self, listmode: Listmode) -> Image:
        """Reconstruct an image from ``listmode`` with ``self.n_iter`` updates.

        The estimate starts as an image of ones with the shape, center and
        size of ``self.emap``. Each iteration projects the estimate along the
        listmode's LORs, back-projects the ratio of measured to projected
        events and multiplies the estimate by ``back_projection / emap``.
        Zero entries of the efficiency map are replaced by ``1e8`` (on a
        copy) before it is used as a divisor.

        Args:
            listmode: Measured events with their LORs.

        Returns:
            The final estimate with its data converted to a NumPy array.
        """
        declare_eager_execution()
        x_tf = Image(data=tf.Variable(
            np.ones(self.emap.shape, dtype=np.float32)),
                     center=self.emap.center,
                     size=self.emap.size)
        emap_data_n0_zero = copy(self.emap.data)
        emap_data_n0_zero[emap_data_n0_zero == 0.0] = 1e8
        emap_tf = self.emap.update(data=tf.constant(emap_data_n0_zero))
        lors_tf = listmode.lors.update(data=tf.constant(listmode.lors.data))
        listmode_tf = listmode.update(data=tf.constant(listmode.data),
                                      lors=lors_tf)

        # Offset added to the projection before dividing by it. It depends
        # only on the measured data, so compute it once instead of on every
        # iteration.
        epsilon = np.mean(listmode.data) * 1e-8

        for _ in tqdm(range(self.n_iter)):
            _listmode_tf = Project('tf-eager')(x_tf, lors_tf)
            _listmode_tf = _listmode_tf + epsilon
            _bp = BackProject('tf-eager')(listmode_tf / _listmode_tf, emap_tf)
            x_tf = x_tf * _bp / emap_tf

        return x_tf.update(data=x_tf.data.numpy())
コード例 #8
0
def listmode_from_gate_out_multi_bed(path, scanner, nb_sub):
    """Load GATE output sub-files and build one listmode per bed position.

    ``nb_sub`` files are read with ``np.load``; the file name of sub-file
    ``i`` is ``path`` with every ``'?'`` replaced by ``str(i)``. The last
    column of each file is used as the bed id and the first seven columns as
    event data. Rows are grouped by bed id across all files (file order is
    preserved within a bed), and each group is converted to crystal ids and
    back (``ListmodeToId`` / ``IdToListmode``).

    Args:
        path: File name pattern containing ``'?'`` as the sub-file index
            placeholder.
        scanner: Scanner description handed to the id converters.
        nb_sub: Number of sub-files to read.

    Returns:
        A dict mapping each bed id found in the files to a ``Listmode`` whose
        data is a float32 vector of ones.
    """
    chunks_per_bed = {}
    for i in tqdm(range(nb_sub)):
        # Load each file once; both the event columns and the bed id come
        # from the same array.
        raw = np.load(path.replace('?', str(i)))
        data = raw[:, :7]
        bed_id = raw[:, -1]
        for i_bed in set(bed_id):
            chunks_per_bed.setdefault(i_bed, []).append(
                data[bed_id == i_bed, :])

    listmode_out = {}
    for key, chunks in chunks_per_bed.items():
        # One stack per bed instead of re-stacking after every file.
        lors = Lors(np.vstack(chunks))
        listmode = Listmode(np.ones((lors.length, ), dtype=np.float32), lors)
        row, col = ListmodeToId()(listmode, scanner)
        lors = IdToListmode()(row, col, scanner)
        listmode_out[key] = Listmode(
            np.ones((lors.shape[0], ), dtype=np.float32), lors)
    return listmode_out
コード例 #9
0
ファイル: emap_generator.py プロジェクト: twj2417/srf
    def __call__(self, image: Image):
        """Build an efficiency map on the grid of ``image``.

        The map is obtained by back-projecting LORs that connect pairs of
        crystal positions. ``self.mode`` selects how the pairs are formed:

        * ``'full'``: every ordered pair of the scanner's crystals in a
          single back-projection.
        * ``'block'``: crystals of one block paired with the crystals of a
          block rotated by each block angle; every partial map is then
          rotated to all block angles with ``self._rotate_tf`` and summed.
        * ``'block-full'``: every ordered pair of rotated blocks is
          back-projected and summed directly.
        * ``'rsector'`` / ``'rsector-full'``: aliases that re-dispatch to
          ``'block'`` / ``'block-full'``.

        Args:
            image: Image providing the shape, center and size of the map.

        Returns:
            The efficiency map as an ``Emap``.

        Raises:
            NotImplementedError: For any other mode.
        """
        from srfnef import EcatIndexToCrystalPos

        def _repeat_each(column, times):
            # [a, b] -> [a, a, ..., b, b, ...]; np.kron is kept so the dtype
            # promotion matches the previous implementation.
            return np.kron(column, [1] * times)

        def _tile_all(column, times):
            # [a, b] -> [a, b, a, b, ...]
            return np.kron(column, [[1]] * times).ravel()

        def _rotated_rows(xs, ys, zs, angle):
            # Rotate the points about the z axis; one (x, y, z) per row.
            cos_a = np.cos(angle)
            sin_a = np.sin(angle)
            return np.vstack((xs * cos_a - ys * sin_a,
                              xs * sin_a + ys * cos_a,
                              zs)).transpose()

        if self.mode == 'full':
            declare_eager_execution()
            ind2pos = EcatIndexToCrystalPos(self.scanner)
            pos = ind2pos(np.arange(self.scanner.nb_crystals))
            nb_pos = pos.shape[0]
            # All ordered crystal pairs: the first end point repeats each
            # crystal nb_pos times, the second cycles through all crystals.
            # (The earlier np.kron-based expansion produced two arrays with
            # different row counts, which np.hstack cannot join.)
            pos1_ = np.repeat(pos, nb_pos, axis=0)
            pos2_ = np.tile(pos, (nb_pos, 1))
            lors_data = np.hstack((pos1_, pos2_))
            listmode = LorsToListmode()(nef.Lors(lors_data))
            return Emap(**BackProject(
                mode='tf-eager')(listmode, image).asdict())
        elif self.mode == 'block':
            declare_eager_execution()
            single_block_scanner = self.scanner.update(nb_blocks_per_ring=1)
            ind2pos = EcatIndexToCrystalPos(single_block_scanner)
            ind = np.arange(self.scanner.nb_crystals_per_block *
                            self.scanner.nb_rings)
            pos1 = pos2 = ind2pos(ind)
            pos1_x = _repeat_each(pos1[:, 0], ind.size)
            pos1_y = _repeat_each(pos1[:, 1], ind.size)
            pos1_z = _repeat_each(pos1[:, 2], ind.size)
            pos1_ = np.vstack((pos1_x, pos1_y, pos1_z)).transpose()
            # The unrotated second end points do not depend on the block
            # angle, so expand them once.
            pos2_x = _tile_all(pos2[:, 0], ind.size)
            pos2_y = _tile_all(pos2[:, 1], ind.size)
            pos2_z = _tile_all(pos2[:, 2], ind.size)

            emap_data = np.zeros(image.shape, np.float32)
            emap_tf = Emap(data=tf.Variable(emap_data),
                           center=image.center,
                           size=image.size)
            for d in tqdm(range(self.scanner.nb_blocks_per_ring)):
                angle = d * self.scanner.angle_per_block
                pos2_ = _rotated_rows(pos2_x, pos2_y, pos2_z, angle)
                lors_data = np.hstack((pos1_, pos2_)).astype(np.float32)
                listmode = LorsToListmode()(nef.Lors(lors_data))
                listmode_tf = listmode.update(data=tf.Variable(listmode.data),
                                              lors=nef.Lors(
                                                  tf.Variable(lors_data)))
                _emap = BackProject(mode='tf')(listmode_tf, emap_tf)
                # Accumulate this partial map rotated to every block angle.
                for i in range(self.scanner.nb_blocks_per_ring):
                    _emap_rotate_data = self._rotate_tf(
                        _emap.data, i * self.scanner.angle_per_block)
                    tf.compat.v1.assign_add(emap_tf.data, _emap_rotate_data)
            emap_data = emap_tf.data.numpy()
            return emap_tf.update(data=emap_data,
                                  center=image.center,
                                  size=image.size)

        elif self.mode == 'block-full':
            declare_eager_execution()
            single_block_scanner = self.scanner.update(nb_blocks_per_ring=1)
            ind2pos = EcatIndexToCrystalPos(single_block_scanner)
            ind = np.arange(self.scanner.nb_crystals_per_block *
                            self.scanner.nb_rings)
            pos1 = pos2 = ind2pos(ind)
            # Unrotated end points are shared by every block pair.
            pos1_x = _repeat_each(pos1[:, 0], ind.size)
            pos1_y = _repeat_each(pos1[:, 1], ind.size)
            pos1_z = _repeat_each(pos1[:, 2], ind.size)
            pos2_x = _tile_all(pos2[:, 0], ind.size)
            pos2_y = _tile_all(pos2[:, 1], ind.size)
            pos2_z = _tile_all(pos2[:, 2], ind.size)

            emap_data = np.zeros(image.shape, np.float32)
            emap_tf = Emap(data=tf.Variable(emap_data),
                           center=image.center,
                           size=image.size)
            for i in tqdm(range(self.scanner.nb_blocks_per_ring)):
                angle1 = i * self.scanner.angle_per_block
                pos1_ = _rotated_rows(pos1_x, pos1_y, pos1_z, angle1)
                for j in range(self.scanner.nb_blocks_per_ring):
                    angle2 = j * self.scanner.angle_per_block
                    pos2_ = _rotated_rows(pos2_x, pos2_y, pos2_z, angle2)

                    lors_data = np.hstack((pos1_, pos2_)).astype(np.float32)
                    listmode = LorsToListmode()(nef.Lors(lors_data))
                    listmode_tf = listmode.update(
                        data=tf.Variable(listmode.data),
                        lors=nef.Lors(tf.Variable(lors_data)))
                    _emap = BackProject(mode='tf')(listmode_tf, emap_tf)
                    tf.compat.v1.assign_add(emap_tf.data, _emap.data)
            emap_data = emap_tf.data.numpy()
            return emap_tf.update(data=emap_data,
                                  center=image.center,
                                  size=image.size)
        elif self.mode == 'rsector':
            return self.update(mode='block')(image)
        elif self.mode == 'rsector-full':
            return self.update(mode='block-full')(image)
        else:
            raise NotImplementedError