예제 #1
0
def is_grids_list_of_grids(grids):
    """
    Determine whether the input to a scatter plot is a single grid or a list of grids.

    Parameters
    ----------
    grids
        Either a single 2D ``np.ndarray`` of (y,x) coordinates, a list of (y,x) tuples describing a single grid,
        or a list of grids (a list of ``np.ndarray``s or a list of lists).

    Returns
    -------
    bool or str
        ``"pass"`` if ``grids`` is empty, ``True`` if it is a list of grids and ``False`` if it is a single grid.

    Raises
    ------
    exc.PlottingException
        If ``grids`` is not a list or ``np.ndarray``, is an ``np.ndarray`` that is not 2D, or is a list whose
        structure cannot be interpreted.
    """
    # Objects without a length (e.g. ``None`` or a 0-d array) are not considered empty; they fall through to
    # the type checks below so the caller receives a ``PlottingException`` rather than a bare ``TypeError``.
    try:
        is_empty = len(grids) == 0
    except TypeError:
        is_empty = False

    if is_empty:
        return "pass"

    if isinstance(grids, list):
        # A list of (y,x) tuples describes a single grid.
        if any(isinstance(i, tuple) for i in grids):
            return False
        # One array in a list is a single grid; more than one array is a list of grids.
        elif any(isinstance(i, np.ndarray) for i in grids):
            return len(grids) != 1
        elif any(isinstance(i, list) for i in grids):
            return True
        else:
            raise exc.PlottingException(
                "The grid entered into scatter_grid is a list of values, but its data-structure "
                "cannot be determined so as to make a scatter plot")
    elif isinstance(grids, np.ndarray):
        if len(grids.shape) == 2:
            return False
        else:
            raise exc.PlottingException(
                "The input grid into scatter_Grid is not 2D and therefore "
                "cannot be plotted using scatter.")
    else:
        raise exc.PlottingException(
            "The grid passed into scatter_grid is not a list or a ndarray.")
예제 #2
0
    def scatter_grid_indexes(self, grid, indexes):
        """
        Scatter-plot specific (y,x) coordinates of a grid, selected via their 1D or 2D indexes.

        Parameters
        ----------
        grid : np.ndarray
            The 2D grid of (y,x) coordinates; y values are read from column 0 and x values from column 1.
        indexes : list or np.ndarray
            Either a single flat list of indexes, or a collection of index lists where each inner list is
            plotted in the next color of ``self.colors``. An inner list holds either 1D indexes (ints, numpy
            integers or floats) or 2D indexes ((y, x) tuples or lists), the latter indexing ``grid.in_2d``.

        Raises
        ------
        exc.PlottingException
            If ``grid`` is not a 2D ``np.ndarray``, or an index list does not consist of one supported type.
        """
        if not isinstance(grid, np.ndarray):
            raise exc.PlottingException(
                "The grid passed into scatter_grid_indexes is not a ndarray and thus its "
                "1D indexes cannot be marked and plotted.")

        if len(grid.shape) != 2:
            raise exc.PlottingException(
                "The grid passed into scatter_grid_indexes is not 2D (e.g. a flattened 1D "
                "grid) and thus its 1D indexes cannot be marked.")

        # A single flat list of indexes is wrapped so the loop below always iterates over index lists.
        if isinstance(indexes, list):
            if not any(isinstance(i, list) for i in indexes):
                indexes = [indexes]

        color = itertools.cycle(self.colors)
        for index_list in indexes:

            # 1D indexes; numpy integer scalars (e.g. rows of an integer ndarray) are accepted like ints.
            if all(isinstance(index, float) for index in index_list) or all(
                    isinstance(index, (int, np.integer)) for index in index_list):

                plt.scatter(
                    y=np.asarray(grid[index_list, 0]),
                    x=np.asarray(grid[index_list, 1]),
                    s=self.size,
                    color=next(color),
                    marker=self.marker,
                )

            # 2D (y, x) indexes into the grid's 2D representation.
            elif all(isinstance(index, tuple) for index in index_list) or all(
                    isinstance(index, list) for index in index_list):

                ys = [index[0] for index in index_list]
                xs = [index[1] for index in index_list]

                plt.scatter(
                    y=np.asarray(grid.in_2d[ys, xs, 0]),
                    x=np.asarray(grid.in_2d[ys, xs, 1]),
                    s=self.size,
                    color=next(color),
                    marker=self.marker,
                )

            else:

                raise exc.PlottingException(
                    "The indexes input into the grid_scatter_index method do not conform to a "
                    "useable type")
예제 #3
0
    def norm_from_array(self, array: np.ndarray) -> object:
        """
        Returns the `Normalization` object which scales the colormap.

        If vmin / vmax are not manually input by the user, the minimum / maximum values of the data being plotted
        are used.

        The normalization is chosen by ``self.config_dict["norm"]``, which must be exactly one of ``"linear"``,
        ``"log"`` or ``"symmetric_log"``. For ``"symmetric_log"`` the ``"linthresh"`` and ``"linscale"`` entries
        of ``self.config_dict`` are also used.

        Parameters
        -----------
        array : np.ndarray
            The array of data which is to be plotted.

        Raises
        ------
        exc.PlottingException
            If ``self.config_dict["norm"]`` is not one of the supported strings.
        """
        vmin = self.vmin_from_array(array=array)
        vmax = self.vmax_from_array(array=array)

        norm = self.config_dict["norm"]

        # Exact comparisons (not a substring `in` test) so partial strings such as "lin" or "" are rejected.
        if norm == "linear":
            return colors.Normalize(vmin=vmin, vmax=vmax)
        elif norm == "log":
            # A log scale cannot include 0, so a zero lower bound is replaced with a small positive value.
            if vmin == 0.0:
                vmin = 1.0e-4
            return colors.LogNorm(vmin=vmin, vmax=vmax)
        elif norm == "symmetric_log":
            return colors.SymLogNorm(
                vmin=vmin,
                vmax=vmax,
                linthresh=self.config_dict["linthresh"],
                linscale=self.config_dict["linscale"],
            )
        else:
            raise exc.PlottingException(
                "The normalization (norm) supplied to the plotter is not a valid string (must be "
                "{linear, log, symmetric_log})")
예제 #4
0
    def set(self):
        """
        Set up the colorbar of the figure, specifically its tick size and the size it appears relative to the
        figure.

        The following instance attributes are used:

        fraction : float
            The fraction of the figure that the colorbar takes up, which resizes the colorbar relative to the figure.
        pad : float
            Pads the color bar in the figure, which resizes the colorbar relative to the figure.
        ticksize : int
            The size of the tick labels on the colorbar.
        tick_values : [float] or None
            Manually specified values of where the colorbar tick labels appear on the colorbar.
        tick_labels : [float] or None
            Manually specified labels of the color bar tick labels, which appear where specified by tick_values.

        Returns
        -------
        The colorbar object created by ``plt.colorbar``.

        Raises
        ------
        exc.PlottingException
            If only one of ``tick_values`` / ``tick_labels`` is supplied.
        """
        if self.tick_values is None and self.tick_labels is None:
            cb = plt.colorbar(fraction=self.fraction, pad=self.pad)
        elif self.tick_values is not None and self.tick_labels is not None:
            cb = plt.colorbar(fraction=self.fraction,
                              pad=self.pad,
                              ticks=self.tick_values)
            cb.ax.set_yticklabels(labels=self.tick_labels)
        else:
            raise exc.PlottingException(
                "Only 1 entry of tick_values or tick_labels was input. You must either supply "
                "both the values and labels, or neither.")

        cb.ax.tick_params(labelsize=self.ticksize)

        return cb
예제 #5
0
    def draw_y_vs_x(self, y, x, plot_axis_type, label=None):
        """
        Plot 1D y-data against 1D x-data using the matplotlib method selected by ``plot_axis_type``.

        Parameters
        ----------
        y
            The y data that is plotted.
        x
            The x data that is plotted.
        plot_axis_type : str
            One of ``"linear"`` (``plt.plot``), ``"semilogy"`` (``plt.semilogy``), ``"loglog"`` (``plt.loglog``)
            or ``"scatter"`` (``plt.scatter``).
        label : str or None
            Optional label for the plotted data, for use by a legend.

        Raises
        ------
        exc.PlottingException
            If ``plot_axis_type`` is not one of the supported strings.
        """
        # `==` rather than `is`: string identity depends on interning and is not a reliable equality test.
        if plot_axis_type == "linear":
            plt.plot(x, y, c=self.colors[0], lw=self.width, ls=self.style, label=label)
        elif plot_axis_type == "semilogy":
            plt.semilogy(x, y, c=self.colors[0], lw=self.width, ls=self.style, label=label)
        elif plot_axis_type == "loglog":
            plt.loglog(x, y, c=self.colors[0], lw=self.width, ls=self.style, label=label)
        elif plot_axis_type == "scatter":
            plt.scatter(x, y, c=self.colors[0], s=self.pointsize, label=label)
        else:
            raise exc.PlottingException(
                "The plot_axis_type supplied to the plotter is not a valid string (must be linear "
                "| semilogy | loglog | scatter)")
예제 #6
0
    def norm_from_array(self, array):
        """
        Get the normalization scale of the colormap, based on the input min / max normalization values.

        For a 'symmetric_log' colormap, linthresh and linscale also change the colormap.

        If norm_min / norm_max are not supplied (are None), the minimum / maximum values of the array are used.

        Parameters
        -----------
        array : data_type.array.aa.Scaled
            The 2D array of data_type which is plotted.

        The following instance attributes are also used:

        norm : str
            The normalization, which must be exactly one of "linear", "log" or "symmetric_log".
        norm_min : float or None
            The minimum array value the colormap spans (all values below this value are plotted the same color).
        norm_max : float or None
            The maximum array value the colormap spans (all values above this value are plotted the same color).
        linthresh : float
            For the 'symmetric_log' colormap normalization, this specifies the range of values within which the
            colormap is linear.
        linscale : float
            For the 'symmetric_log' colormap normalization, this allows the linear range set by linthresh to be
            stretched relative to the logarithmic range.

        Raises
        ------
        exc.PlottingException
            If ``self.norm`` is not one of the supported strings.
        """
        norm_min = array.min() if self.norm_min is None else self.norm_min
        norm_max = array.max() if self.norm_max is None else self.norm_max

        # Exact comparisons (not a substring `in` test) so partial strings such as "lin" or "" are rejected.
        if self.norm == "linear":
            return colors.Normalize(vmin=norm_min, vmax=norm_max)
        elif self.norm == "log":
            # A log scale cannot include 0, so a zero lower bound is replaced with a small positive value.
            if norm_min == 0.0:
                norm_min = 1.0e-4
            return colors.LogNorm(vmin=norm_min, vmax=norm_max)
        elif self.norm == "symmetric_log":
            return colors.SymLogNorm(
                linthresh=self.linthresh,
                linscale=self.linscale,
                vmin=norm_min,
                vmax=norm_max,
            )
        else:
            raise exc.PlottingException(
                "The normalization (norm) supplied to the plotter is not a valid string (must be "
                "linear | log | symmetric_log)")
예제 #7
0
    def set(self):
        """
        Set the figure's colorbar, optionally overriding the tick labels and values with manual inputs.

        ``self.config_dict`` is forwarded to ``plt.colorbar``. ``self.manual_tick_values`` and
        ``self.manual_tick_labels`` must either both be given or both be None.

        Returns
        -------
        The colorbar object created by ``plt.colorbar``.

        Raises
        ------
        exc.PlottingException
            If only one of the manual tick values / labels is supplied.
        """
        if self.manual_tick_values is None and self.manual_tick_labels is None:
            cb = plt.colorbar(**self.config_dict)
        elif (self.manual_tick_values is not None
              and self.manual_tick_labels is not None):
            cb = plt.colorbar(ticks=self.manual_tick_values,
                              **self.config_dict)
            cb.ax.set_yticklabels(labels=self.manual_tick_labels)
        else:
            raise exc.PlottingException(
                "Only 1 entry of tick_values or tick_labels was input. You must either supply "
                "both the values and labels, or neither.")

        return cb
예제 #8
0
    def set_xticks(self, array, extent, units, symmetric_around_centre=False):
        """
        Set the x-axis tick positions and tick labels of the figure.

        The tick label size is always set from ``self.xsize``. Unless ``symmetric_around_centre`` is True, five
        ticks are placed evenly between ``extent[0]`` and ``extent[1]``. Their labels are taken from
        ``self.x_manual`` if it is set. Otherwise they are pixel values spanning ``array.shape_2d[0]`` when
        ``units.use_scaled`` is False. In scaled units they are the extent values, multiplied by
        ``units.conversion_factor`` when one is given and rounded to 2 decimal places.

        Parameters
        -----------
        array : data_type.array.aa.Scaled
            The 2D array of data_type which is plotted.
        extent : sequence of float
            The extent of the plotted array; ``extent[0]`` and ``extent[1]`` give the x tick range.
        units : Units
            The units of the figure, whose ``use_scaled`` and ``conversion_factor`` attributes set the labels.
        symmetric_around_centre : bool
            If True, only the tick label size is set and the tick positions / labels are left unchanged.
        """
        plt.tick_params(labelsize=self.xsize)

        if symmetric_around_centre:
            return

        xticks = np.linspace(extent[0], extent[1], 5)

        if self.x_manual is not None:
            # NOTE(review): only two labels (x_manual[0] and x_manual[3]) are built here while five tick
            # positions are set above; matplotlib expects matching counts - confirm the layout of `x_manual`.
            xtick_labels = np.asarray([self.x_manual[0], self.x_manual[3]])
        elif not units.use_scaled:
            xtick_labels = np.linspace(0, array.shape_2d[0], 5).astype("int")
        elif units.use_scaled and units.conversion_factor is None:
            xtick_labels = np.round(np.linspace(extent[0], extent[1], 5), 2)
        elif units.use_scaled and units.conversion_factor is not None:
            xtick_labels = np.round(
                np.linspace(
                    extent[0] * units.conversion_factor,
                    extent[1] * units.conversion_factor,
                    5,
                ),
                2,
            )

        else:
            raise exc.PlottingException(
                "The y and x ticks cannot be set using the input options.")

        plt.xticks(ticks=xticks, labels=xtick_labels)
예제 #9
0
    def wrapper(*args, **kwargs):
        """
        Swap the ``SubPlotter`` passed by keyword for one whose output filename is derived from ``func``,
        then call ``func`` with the updated keyword arguments.

        Raises
        ------
        exc.PlottingException
            If the plotter found in ``kwargs`` is not a ``SubPlotter``.
        """
        key = plotter_key_from_dictionary(dictionary=kwargs)

        if not isinstance(kwargs[key], SubPlotter):
            raise exc.PlottingException(
                "The decorator set_subplot_title was applied to a function without a SubPlotter class"
            )

        output_filename = kwargs[key].output.filename_from_func(func=func)
        kwargs[key] = kwargs[key].plotter_with_new_output(filename=output_filename)

        return func(*args, **kwargs)
예제 #10
0
    def scatter_colored_grid(self, grid, color_array, cmap):
        """
        Scatter-plot a single grid of (y,x) coordinates, coloring each point by ``color_array`` via ``cmap``.

        Parameters
        ----------
        grid
            A single grid of (y,x) coordinates (a 2D ``np.ndarray`` or a list of (y,x) tuples). An empty grid
            results in nothing being plotted.
        color_array
            The values passed to ``plt.scatter`` as ``c``, one per point.
        cmap
            The colormap passed to ``plt.scatter``.

        Raises
        ------
        exc.PlottingException
            If ``grid`` is a list of grids.
        """
        list_of_grids = is_grids_list_of_grids(grids=grid)

        # `is_grids_list_of_grids` returns the (truthy) string "pass" for an empty input: there is nothing to
        # plot, and the input must not be reported as a list of grids.
        if list_of_grids == "pass":
            return

        if list_of_grids:
            raise exc.PlottingException(
                "Cannot plot colored grid if input grid is a list of grids.")

        grid = np.asarray(grid)

        plt.scatter(
            y=grid[:, 0],
            x=grid[:, 1],
            s=self.size,
            c=color_array,
            marker=self.marker,
            cmap=cmap,
        )
예제 #11
0
    def tick_values_in_units_from(self, array: array_2d.Array2D,
                                  min_value: float, max_value: float,
                                  units: Units) -> typing.Optional[np.ndarray]:
        """
        Compute five evenly spaced tick labels for the y or x axis from the axis' minimum and maximum coordinates.

        The labels are chosen in this order of priority:
          1. ``self.manual_values``, if set, returned as an array unchanged.
          2. Pixel units (``units.use_scaled`` is False): integers from 0 to ``array.shape_native[0]``.
          3. Scaled units with no ``units.conversion_factor``, or ``units.in_kpc`` False: the raw values
             between ``min_value`` and ``max_value``, rounded to 2 decimal places.
          4. Otherwise: those values multiplied by ``units.conversion_factor``, rounded to 2 decimal places.

        Parameters
        ----------
        array : array_2d.Array2D
            The array being plotted; its native 2D shape sets the pixel-unit tick values.
        min_value : float
            The minimum coordinate of the axis the ticks are computed for.
        max_value : float
            The maximum coordinate of the axis the ticks are computed for.
        units : Units
            The units the tick values are expressed in.
        """
        if self.manual_values is not None:
            return np.asarray(self.manual_values)

        if not units.use_scaled:
            return np.linspace(0, array.shape_native[0], 5).astype("int")

        # Scaled units from here on.
        factor = units.conversion_factor

        if factor is None or not units.in_kpc:
            return np.round(np.linspace(min_value, max_value, 5), 2)

        lower, upper = min_value * factor, max_value * factor
        return np.round(np.linspace(lower, upper, 5), 2)
예제 #12
0
    def plot_y_vs_x(
        self,
        y: typing.Union[np.ndarray, array_1d.Array1D],
        x: typing.Union[np.ndarray, array_1d.Array1D],
        label: str = None,
        plot_axis_type=None,
    ):
        """
        Plots 1D y-data against 1D x-data using the matplotlib method `plt.plot`, `plt.semilogy`, `plt.loglog`,
        or `plt.scatter`, forwarding ``self.config_dict`` as keyword arguments.

        Parameters
        ----------
        y : np.ndarray or array_1d.Array1D
            The ydata that is plotted.
        x : np.ndarray or array_1d.Array1D
            The xdata that is plotted.
        label : str
            Optionally include a label on the plot for a `Legend` to display.
        plot_axis_type : str
            The method used to make the plot: "linear" and "symlog" use `plt.plot`, "semilogy" uses
            `plt.semilogy`, "loglog" uses `plt.loglog` and "scatter" uses `plt.scatter`.

        Raises
        ------
        exc.PlottingException
            If ``plot_axis_type`` is not one of the supported strings.
        """
        if plot_axis_type == "linear" or plot_axis_type == "symlog":
            plt.plot(x, y, label=label, **self.config_dict)
        elif plot_axis_type == "semilogy":
            plt.semilogy(x, y, label=label, **self.config_dict)
        elif plot_axis_type == "loglog":
            plt.loglog(x, y, label=label, **self.config_dict)
        elif plot_axis_type == "scatter":
            plt.scatter(x, y, label=label, **self.config_dict)
        else:
            raise exc.PlottingException(
                "The plot_axis_type supplied to the plotter is not a valid string (must be linear "
                "| symlog | semilogy | loglog | scatter)")
예제 #13
0
    def scatter_grid_indexes(self, grid: typing.Union[np.ndarray,
                                                      grid_2d.Grid2D],
                             indexes: np.ndarray):
        """
        Plot specific points of an input grid of (y,x) coordinates, which are specified according to the 1D or 2D
        indexes of the `Grid2D`.

        This method allows us to color in points on grids that map between one another. Each index list is
        plotted in the next color cycled from ``self.config_dict["c"]``. The remaining entries of
        ``self.config_dict`` are forwarded to ``plt.scatter``.

        Parameters
        ----------
        grid : Grid2D
            The grid of (y,x) coordinates that is plotted.
        indexes : np.ndarray or list
            Either a single flat list of indexes, or a collection of index lists. An index list holds either 1D
            indexes (ints, numpy integers or floats) or 2D indexes ((y, x) tuples or lists), the latter indexing
            ``grid.native``.

        Raises
        ------
        exc.PlottingException
            If ``grid`` is not a 2D ``np.ndarray``, or an index list does not consist of one supported type.
        """
        if not isinstance(grid, np.ndarray):
            raise exc.PlottingException(
                "The grid passed into scatter_grid_indexes is not a ndarray and thus its "
                "1D indexes cannot be marked and plotted.")

        if len(grid.shape) != 2:
            raise exc.PlottingException(
                "The grid passed into scatter_grid_indexes is not 2D (e.g. a flattened 1D "
                "grid) and thus its 1D indexes cannot be marked.")

        # A single flat list of indexes is wrapped so the loop below always iterates over index lists.
        if isinstance(indexes, list):
            if not any(isinstance(i, list) for i in indexes):
                indexes = [indexes]

        color = itertools.cycle(self.config_dict["c"])

        # Work on a copy: popping "c" from `self.config_dict` itself would make every later call fail with a
        # KeyError on `self.config_dict["c"]`.
        config_dict = dict(self.config_dict)
        config_dict.pop("c")

        for index_list in indexes:

            # 1D indexes; numpy integer scalars (e.g. rows of an integer ndarray) are accepted like ints.
            if all(isinstance(index, float) for index in index_list) or all(
                    isinstance(index, (int, np.integer)) for index in index_list):

                plt.scatter(
                    y=grid[index_list, 0],
                    x=grid[index_list, 1],
                    color=next(color),
                    **config_dict,
                )

            # 2D (y, x) indexes into the grid's native 2D representation.
            elif all(isinstance(index, tuple) for index in index_list) or all(
                    isinstance(index, list) for index in index_list):

                ys, xs = map(list, zip(*index_list))

                plt.scatter(
                    y=grid.native[ys, xs, 0],
                    x=grid.native[ys, xs, 1],
                    color=next(color),
                    **config_dict,
                )

            else:

                raise exc.PlottingException(
                    "The indexes input into the grid_scatter_index method do not conform to a "
                    "useable type")