コード例 #1
0
ファイル: test_spectrum.py プロジェクト: zizai/pymatgen
 def test_site_weighted_spectrum(self):
     """Check site-weighted averaging of two site-resolved XANES spectra.

     Verifies the result type, the length of the energy grid, the weighted
     value of the first intensity point, the lower bound of the energy
     grid, and that two spectra sharing an absorbing index are rejected
     with ``ValueError``.
     """
     weighted_spectrum = site_weighted_spectrum([self.site1_xanes,
                                                 self.site2_xanes])
     self.assertIsInstance(weighted_spectrum, XAS)
     # assertEqual, not assertTrue: assertTrue(a, b) treats ``b`` as the
     # failure message, so it would pass for any non-empty energy grid.
     # NOTE(review): 500 is presumably the default number of samples of
     # site_weighted_spectrum -- confirm against the library.
     self.assertEqual(len(weighted_spectrum.x), 500)
     # The site multiplicities for site1 and site2 are 4 and 2, respectively.
     self.assertAlmostEqual(weighted_spectrum.y[0], (4*self.site1_xanes.y[0] +
                            2*self.site2_xanes.y[0])/6, 2)
     # The averaged grid starts at the larger of the two input minima.
     self.assertEqual(min(weighted_spectrum.x),
                      max(min(self.site1_xanes.x), min(self.site2_xanes.x)))
     # Give both spectra the same absorbing index: averaging must now fail.
     self.site2_xanes.absorbing_index = self.site1_xanes.absorbing_index
     self.assertRaises(ValueError, site_weighted_spectrum,
                       [self.site1_xanes, self.site2_xanes])
コード例 #2
0
ファイル: xas.py プロジェクト: utf/emmet
    def process_spectra(self, items: List[Dict]) -> List[Dict]:
        """Merge per-site spectra and build site-weighted average documents.

        Steps:
          1. Convert every task in ``items`` to a spectrum.
          2. For each absorbing site, stitch K-edge XANES + EXAFS into one
             spectrum (mode ``"XAFS"``) and L2 + L3 XANES into one spectrum
             (mode ``"L23"``), using the most recently updated spectrum of
             each kind. Stitching failures (``ValueError``) are logged as
             warnings and skipped.
          3. Group all spectra by (absorbing element, edge, spectrum type)
             and average each group over sites. Groups flagged by
             ``is_missing_sites`` are dropped; single-spectrum groups are
             passed through unchanged. Averaging failures (``ValueError``)
             are logged as errors and skipped.

        Args:
            items: Task documents to convert with ``feff_task_to_spectrum``.

        Returns:
            A list of ``XASDoc`` dictionaries, one per resulting spectrum.
        """

        def averaging_key(spectrum):
            # One shared key for both sorted() and groupby(): groupby only
            # merges *adjacent* items with *equal* keys, so the two must
            # agree exactly.
            return (
                spectrum.absorbing_element,
                spectrum.edge,
                spectrum.spectrum_type,
            )

        # (first spectrum key, second spectrum key, stitch mode), applied in
        # this order for every site: K-total first, then L23.
        stitch_recipes = (
            (("K", "XANES"), ("K", "EXAFS"), "XAFS"),
            (("L2", "XANES"), ("L3", "XANES"), "L23"),
        )

        all_spectra = [feff_task_to_spectrum(task) for task in items]

        # Dictionary of all site to spectra mapping
        sites_to_spectra = {
            index: list(group)
            for index, group in groupby(
                sorted(all_spectra, key=lambda x: x.absorbing_index),
                key=lambda x: x.absorbing_index,
            )
        }

        # perform spectra merging
        for site, spectra in sites_to_spectra.items():
            # Sorting additionally by last_updated makes the last entry of
            # every (edge, spectrum_type) group the most recently updated.
            type_to_spectra = {
                index: list(group)
                for index, group in groupby(
                    sorted(
                        spectra, key=lambda x: (x.edge, x.spectrum_type, x.last_updated)
                    ),
                    key=lambda x: (x.edge, x.spectrum_type),
                )
            }

            for first_key, second_key, mode in stitch_recipes:
                if first_key not in type_to_spectra:
                    continue
                if second_key not in type_to_spectra:
                    continue

                first = type_to_spectra[first_key][-1]
                second = type_to_spectra[second_key][-1]
                try:
                    total_spectrum = first.stitch(second, mode=mode)
                    total_spectrum.absorbing_index = site
                    total_spectrum.task_ids = first.task_ids + second.task_ids
                    all_spectra.append(total_spectrum)
                except ValueError as e:
                    self.logger.warning(e)

        self.logger.debug(f"Found {len(all_spectra)} spectra")

        # Site-weighted averaging.
        # The groupby key used to be ``lambda x: lambda x: (...)``, which
        # returned a brand-new function object per element; those never
        # compare equal, so every group held exactly one spectrum and no
        # averaging ever took place.
        spectra_to_average = [
            list(group)
            for _, group in groupby(
                sorted(all_spectra, key=averaging_key), key=averaging_key
            )
        ]
        averaged_spectra = []

        for relevant_spectra in spectra_to_average:
            if not relevant_spectra or is_missing_sites(relevant_spectra):
                continue

            if len(relevant_spectra) == 1:
                # Nothing to average: pass the lone spectrum through.
                averaged_spectra.append(relevant_spectra[0])
                continue

            try:
                avg_spectrum = site_weighted_spectrum(
                    relevant_spectra, num_samples=self.num_samples
                )
                # The average is derived from every contributing task.
                avg_spectrum.task_ids = [
                    task_id
                    for spectrum in relevant_spectra
                    for task_id in spectrum.task_ids
                ]
                averaged_spectra.append(avg_spectrum)
            except ValueError as e:
                self.logger.error(e)

        spectra_docs = [
            XASDoc.from_spectrum(spectrum).dict() for spectrum in averaged_spectra
        ]

        return spectra_docs