def test_init(self):
    """GridDOSData should refuse mismatched lengths and uneven grids."""
    uniform_grid = np.linspace(0, 10, 11)
    quadratic_grid = np.linspace(0, 10, 11)**2

    # Eleven energies against ten weights: the lengths disagree
    with pytest.raises(ValueError):
        GridDOSData(uniform_grid, np.zeros(10))

    # Squaring a uniform grid leaves it unevenly spaced
    with pytest.raises(ValueError):
        GridDOSData(quadratic_grid, np.zeros(11))
def test_addition(self, dense_dos, another_dense_dos):
    """Check '+' on GridDOSData: summed weights, merged info, grid checks."""
    combined = dense_dos + another_dense_dos

    # The energy axis of the sum should be that of the operands
    expected_energies = dense_dos.get_energies()
    assert np.allclose(combined.get_energies(), expected_energies)

    # The result is compared against three times the first operand's weights
    expected_weights = dense_dos.get_weights() * 3
    assert np.allclose(combined.get_weights(), expected_weights)

    assert combined.info == {'symbol': 'C'}

    # Same number of points, but every energy shifted by one
    with pytest.raises(ValueError):
        shifted = GridDOSData(dense_dos.get_energies() + 1.,
                              dense_dos.get_weights())
        dense_dos + shifted

    # First point dropped, so the grids differ in length
    with pytest.raises(ValueError):
        truncated = GridDOSData(dense_dos.get_energies()[1:],
                                dense_dos.get_weights()[1:])
        dense_dos + truncated
def test_init_errors(self, griddos):
    """Check the inputs that GridDOSCollection is expected to reject."""
    # Items that are not GridDOSData
    with pytest.raises(TypeError):
        GridDOSCollection([RawDOSData([1.], [1.])])

    # Second item on a seven-point grid offset by one
    offset_grid = np.linspace(1, 10, 7) + 1
    with pytest.raises(ValueError):
        GridDOSCollection(
            [griddos, GridDOSData(offset_grid, np.sin(offset_grid))])

    # Second item on a six-point grid
    six_point_grid = np.linspace(1, 10, 6)
    with pytest.raises(ValueError):
        GridDOSCollection(
            [griddos, GridDOSData(six_point_grid, np.sin(six_point_grid))])

    # No data and no energies supplied
    with pytest.raises(ValueError):
        GridDOSCollection([], energies=None)

    # Explicit energies that differ from those of the supplied data
    with pytest.raises(ValueError):
        GridDOSCollection([griddos], energies=np.linspace(1, 10, 6))
def from_data(cls,
              energies: Sequence[float],
              weights: Sequence[Sequence[float]],
              info: Sequence[Info] = None) -> 'GridDOSCollection':
    """Build a collection from weight rows sharing one energy axis

    The weights are converted to a 2-D float array and validated here.
    Only the first row is passed through the GridDOSData constructor; the
    complete weights array and the info list are then assigned to the new
    collection directly.

    Args:
        energies: energy values common to every row of weights
        weights: 2-D array (or nested sequence) of DOS weights, one row
            per dataset
        info: sequence of info dicts, checked against weights by
            cls._check_weights_and_info()

    Returns:
        New instance of cls holding every row of weights

    Raises:
        IndexError: weights is not 2-D, has no rows, or its row length
            differs from len(energies)
    """
    weights_array = np.asarray(weights, dtype=float)

    if weights_array.ndim != 2:
        raise IndexError("Weights must be a 2-D array or nested sequence")

    n_rows, n_columns = weights_array.shape
    if n_rows < 1:
        raise IndexError("Weights cannot be empty")
    if n_columns != len(energies):
        raise IndexError("Length of weights rows must equal size of x")

    checked_info = cls._check_weights_and_info(weights, info)

    # Seed the collection with the first row, then swap in the full data
    seed = GridDOSData(energies, weights_array[0])
    collection = cls([seed])
    collection._weights = weights_array
    collection._info = list(checked_info)

    return collection
def plot(self,
         npts: int = 0,
         xmin: float = None,
         xmax: float = None,
         width: float = None,
         smearing: str = 'Gauss',
         ax: 'matplotlib.axes.Axes' = None,
         show: bool = False,
         filename: str = None,
         mplargs: dict = None) -> 'matplotlib.axes.Axes':
    """Plot every dataset in the collection, optionally resampled

    One label per dataset is generated with DOSData.label_from_info()
    from that dataset's info. Labels only become visible once a legend is
    added to the axes (e.g. with `ax.legend()`).

    Args:
        npts: number of points in the resampled x-axis. When the smearing
            arguments resolve to zero points, the stored energies and
            weights are plotted as they are.
        xmin, xmax: range handed to self.sample_grid() when resampling
        width: broadening width, passed to self.sample_grid()
        smearing: broadening kernel name, passed to self.sample_grid()
        ax: existing Matplotlib axes to draw on, passed to
            SimplePlottingAxes
        show: passed to SimplePlottingAxes (show the figure on-screen)
        filename: passed to SimplePlottingAxes (path for a saved figure)
        mplargs: extra plotting arguments forwarded to
            self._plot_broadened() (e.g. {'linewidth': 2})

    Returns:
        The axes object yielded by SimplePlottingAxes
    """
    # Resolve the npts/width combination into concrete settings
    npts, width = GridDOSData._interpret_smearing_args(npts, width)

    if not npts:
        dos = self
    else:
        assert isinstance(width, float)
        dos = self.sample_grid(npts,
                               xmin=xmin, xmax=xmax,
                               width=width, smearing=smearing)

    energies = dos._energies
    all_y = dos._weights

    # Labels are taken from this collection, not the resampled one
    all_labels = [DOSData.label_from_info(entry.info) for entry in self]

    with SimplePlottingAxes(ax=ax, show=show,
                            filename=filename) as plot_ax:
        self._plot_broadened(plot_ax, energies, all_y, all_labels, mplargs)

    return plot_ax
def dense_dos(self):
    """GridDOSData with weights sin(x / 10) on 11 points spanning 0 to 10."""
    energies = np.linspace(0., 10., 11)
    weights = np.sin(energies / 10)
    metadata = {'symbol': 'C', 'orbital': '2s', 'day': 'Tue'}
    return GridDOSData(energies, weights, info=metadata)
def another_dense_dos(self):
    """GridDOSData with weights 2 * sin(x / 10) on 11 points from 0 to 10."""
    energies = np.linspace(0., 10., 11)
    weights = np.sin(energies / 10) * 2
    metadata = {'symbol': 'C', 'orbital': '2p', 'month': 'Feb'}
    return GridDOSData(energies, weights, info=metadata)
def __getitem__(self, item):  # noqa F811
    """Return one dataset for an int index, or a sub-collection for a slice

    Args:
        item: integer row index or slice over the collection

    Returns:
        GridDOSData built from the stored energies and the selected row of
        weights and info (int index), or a new instance of this
        collection's type built from the selected items (slice)

    Raises:
        TypeError: item is neither an int nor a slice
    """
    if isinstance(item, int):
        row_weights = self._weights[item, :]
        return GridDOSData(self._energies, row_weights,
                           info=self._info[item])

    if isinstance(item, slice):
        # Slicing a range resolves the slice into concrete row indices
        chosen = range(len(self))[item]
        return type(self)([self[index] for index in chosen])

    raise TypeError("index in DOSCollection must be an integer or "
                    "slice")
def test_smearing_args_interpreter(self, inputs, expected):
    """Compare _interpret_smearing_args(**inputs) with the expected value."""
    interpreted = GridDOSData._interpret_smearing_args(**inputs)
    assert interpreted == expected
def denser_dos(self):
    """GridDOSData with weights sin(x / 10) on 21 points spanning 0 to 10."""
    energies = np.linspace(0., 10., 21)
    weights = np.sin(energies / 10)
    return GridDOSData(energies, weights)
def another_griddos(self):
    """GridDOSData with cosine weights on a 7-point grid from 1 to 10."""
    grid = np.linspace(1, 10, 7)
    metadata = {'my_key': 'other_value'}
    return GridDOSData(grid, np.cos(grid), info=metadata)
def griddos(self):
    """GridDOSData with sine weights on a 7-point grid from 1 to 10."""
    grid = np.linspace(1, 10, 7)
    metadata = {'my_key': 'my_value'}
    return GridDOSData(grid, np.sin(grid), info=metadata)