Example #1
0
def plot_color_color(colorbox,
                     objects,
                     label_x,
                     label_y,
                     output,
                     xlim=None,
                     ylim=None,
                     offset=None,
                     legend='upper left'):
    """
    Function for creating a color-color diagram of field objects, color-coded by spectral
    type, with individual objects optionally overplotted.

    Parameters
    ----------
    colorbox : species.core.box
        Box with the ``sptype``, ``color1``, and ``color2`` arrays of the field objects.
        Entries with spectral type ``'None'`` are not plotted.
    objects : tuple(tuple(str, tuple(str, str), tuple(str, str)), ), None
        Tuple with individual objects. Each object requires a tuple with its database tag,
        the two filter names for the first color (x-axis), and the two filter names for the
        second color (y-axis). Not used if set to None.
    label_x : str
        Label for the x-axis.
    label_y : str
        Label for the y-axis.
    output : str
        Output filename. A relative path is resolved against the current working directory;
        an absolute path is used as is.
    xlim : tuple(float, float), None
        Limits for the x-axis.
    ylim : tuple(float, float), None
        Limits for the y-axis.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label. Default label positions are used if set to None.
    legend : str
        Legend position.

    Returns
    -------
    None
    """

    marker = itertools.cycle(('o', 's', '<', '>', 'p', 'v', '^', '*',
                              'd', 'x', '+', '1', '2', '3', '4'))

    sys.stdout.write('Plotting color-color diagram: '+output+'... ')
    sys.stdout.flush()

    plt.figure(1, figsize=(4, 4.3))
    # Row 0 holds the colorbar, row 1 is a spacer, and row 2 is the diagram itself
    gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 4.])
    gridsp.update(wspace=0., hspace=0., left=0, right=1, bottom=0, top=1)

    ax1 = plt.subplot(gridsp[2, 0])
    ax2 = plt.subplot(gridsp[0, 0])

    ax1.grid(True, linestyle=':', linewidth=0.7, color='silver', dashes=(1, 4), zorder=0)

    ax1.tick_params(axis='both', which='major', colors='black', labelcolor='black',
                    direction='in', width=0.8, length=5, labelsize=12, top=True,
                    bottom=True, left=True, right=True)

    ax1.tick_params(axis='both', which='minor', colors='black', labelcolor='black',
                    direction='in', width=0.8, length=3, labelsize=12, top=True,
                    bottom=True, left=True, right=True)

    ax1.set_xlabel(label_x, fontsize=14)
    ax1.set_ylabel(label_y, fontsize=14)

    ax1.invert_yaxis()

    if offset:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])
        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
    else:
        ax1.get_xaxis().set_label_coords(0.5, -0.08)
        ax1.get_yaxis().set_label_coords(-0.12, 0.5)

    if xlim:
        ax1.set_xlim(xlim[0], xlim[1])

    if ylim:
        ax1.set_ylim(ylim[0], ylim[1])

    # Seven discrete color bins (boundaries 0..7), one per spectral-type label below
    cmap = plt.cm.viridis
    bounds = np.arange(0, 8, 1)
    norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

    sptype = colorbox.sptype
    color1 = colorbox.color1
    color2 = colorbox.color2

    # Drop entries without a spectral type
    indices = np.where(sptype != 'None')[0]

    sptype = sptype[indices]
    color1 = color1[indices]
    color2 = color2[indices]

    # NOTE(review): presumably maps each spectral type to a bin index 0-6 matching the
    # colorbar labels below; confirm against plot_util.sptype_discrete.
    spt_disc = plot_util.sptype_discrete(sptype, color1.shape)

    # Keep the first entry for each distinct color1 value (this also drops distinct
    # objects that happen to share the same color1 value)
    _, unique = np.unique(color1, return_index=True)

    color1 = color1[unique]
    color2 = color2[unique]
    spt_disc = spt_disc[unique]

    scat = ax1.scatter(color1, color2, c=spt_disc, cmap=cmap, norm=norm,
                       zorder=5, s=40, alpha=0.6, edgecolor='none')

    cb = Colorbar(ax=ax2, mappable=scat, orientation='horizontal',
                  ticklocation='top', format='%.2f')

    cb.ax.tick_params(width=0.8, length=5, labelsize=10, direction='in', color='black')
    # Ticks at the bin centers
    cb.set_ticks(np.arange(0.5, 7., 1.))
    cb.set_ticklabels(['M0-M4', 'M5-M9', 'L0-L4', 'L5-L9', 'T0-T4', 'T6-T8', 'Y1-Y2'])

    if objects is not None:
        for item in objects:
            objdata = read_object.ReadObject(item[0])

            # Query each filter once; index 0 is used as the magnitude, index 1 as the error
            filter_names = (item[1][0], item[1][1], item[2][0], item[2][1])
            phot = [objdata.get_photometry(filter_name) for filter_name in filter_names]

            obj_color1 = phot[0][0] - phot[1][0]
            obj_color2 = phot[2][0] - phot[3][0]

            # Errors of the two magnitudes are added in quadrature
            error1 = math.sqrt(phot[0][1]**2 + phot[1][1]**2)
            error2 = math.sqrt(phot[2][1]**2 + phot[3][1]**2)

            ax1.errorbar(obj_color1, obj_color2, xerr=error1, yerr=error2,
                         marker=next(marker), ms=6, color='black', label=objdata.object_name,
                         markerfacecolor='white', markeredgecolor='black', zorder=10)

    handles, labels = ax1.get_legend_handles_labels()

    if handles:
        # errorbar() returns a container; its first element is the data line, so the
        # legend shows only the marker without error bars
        handles = [h[0] for h in handles]
        ax1.legend(handles, labels, loc=legend, prop={'size': 9}, frameon=False, numpoints=1)

    # os.path.join keeps an absolute output path intact instead of prefixing the cwd
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close()

    sys.stdout.write('[DONE]\n')
    sys.stdout.flush()
Example #2
0
def _within_limits(value, limits, margin=0.2):
    """
    Check if a value lies strictly inside axis limits that are shrunk by ``margin`` on both
    sides. The limits are sorted first, so an inverted axis (for which ``get_xlim`` or
    ``get_ylim`` returns ``(max, min)``) is handled as well.
    """

    lower, upper = sorted(limits)

    return lower + margin < value < upper - margin


def _annotate_track(ax, item, label_items, color, make_label):
    """
    Mark and label points along a track at interpolated positions.

    ``item.sptype`` is used as the interpolation grid for the requested values, and
    ``item.color1`` and ``item.color2`` are the coordinates. Each entry of ``label_items`` is
    either a value or a tuple ``(value, 'left'/'right')`` that sets the side of the label.
    Points that are outside the axes, or within 0.2 of an axis edge, are skipped.
    """

    interp_color1 = interp1d(item.sptype, item.color1)
    interp_color2 = interp1d(item.sptype, item.color2)

    for label_item in label_items:
        if isinstance(label_item, tuple):
            value, position = label_item[0], label_item[1]
        else:
            value, position = label_item, 'right'

        pos_color1 = interp_color1(value)
        pos_color2 = interp_color2(value)

        if position == 'left':
            text_ha = 'right'
            text_xy = (pos_color1 - 0.05, pos_color2)
        else:
            text_ha = 'left'
            text_xy = (pos_color1 + 0.05, pos_color2)

        # Limits are read per label because the scatter calls may rescale the axes
        if _within_limits(pos_color1, ax.get_xlim()) and \
                _within_limits(pos_color2, ax.get_ylim()):

            ax.scatter(pos_color1,
                       pos_color2,
                       c=color,
                       s=15,
                       edgecolor='none',
                       zorder=0)

            ax.annotate(make_label(value),
                        (pos_color1, pos_color2),
                        color=color,
                        fontsize=9,
                        xytext=text_xy,
                        ha=text_ha,
                        va='center',
                        zorder=3)


def _reddening_label(composition, radius):
    """
    Create the label of a reddening arrow from the dust composition and grain radius (um).
    """

    if composition == 'MgSiO3':
        dust_species = r'MgSiO$_{3}$'
    elif composition == 'Fe':
        dust_species = 'Fe'
    else:
        dust_species = composition

    # float() so that an int radius also works (int.is_integer requires Python 3.12)
    if float(radius).is_integer():
        return f'{dust_species} ({radius:.0f} µm)'

    return f'{dust_species} ({radius:.1f} µm)'


def _plot_reddening_arrows(ax, reddening):
    """
    Draw a labeled reddening arrow for each tuple of the ``reddening`` argument of
    :func:`plot_color_color`.
    """

    for item in reddening:
        ext_1, ext_2 = dust_util.calc_reddening(item[0],
                                                item[2],
                                                composition=item[3],
                                                structure='crystalline',
                                                radius_g=item[4])

        ext_3, ext_4 = dust_util.calc_reddening(item[1],
                                                item[2],
                                                composition=item[3],
                                                structure='crystalline',
                                                radius_g=item[4])

        delta_x = ext_1 - ext_2
        delta_y = ext_3 - ext_4

        x_start, y_start = item[5][0], item[5][1]
        x_end = x_start + delta_x
        y_end = y_start + delta_y

        ax.annotate('', (x_end, y_end),
                    xytext=(x_start, y_start),
                    fontsize=8,
                    arrowprops={'arrowstyle': '->'},
                    color='black',
                    zorder=3.)

        vector_len = math.sqrt(delta_x**2 + delta_y**2)

        # The label sits at the arrow midpoint, offset by 7 points perpendicular to it
        text = ax.annotate(_reddening_label(item[3], item[4]),
                           (x_start + delta_x / 2., y_start + delta_y / 2.),
                           xytext=(-7. * delta_y / vector_len,
                                   7. * delta_x / vector_len),
                           textcoords='offset points',
                           fontsize=8.,
                           color='black',
                           ha='center',
                           va='center')

        # NOTE(review): a white line along the arrow; presumably so the arrow span is
        # included in autoscaling (annotations are not) — confirm intent.
        ax.plot([x_start, x_end], [y_start, y_end],
                '-',
                color='white')

        # Rotate the label along the arrow as it appears on screen (display coordinates)
        sp1 = ax.transData.transform_point((x_start, y_start))
        sp2 = ax.transData.transform_point((x_end, y_end))

        angle = np.degrees(np.arctan2(sp2[1] - sp1[1], sp2[0] - sp1[0]))
        text.set_rotation(angle)


def _select_first_photometry(objphot, filter_name, object_tag):
    """
    Return the photometry of a single measurement. If ``objphot`` is 2D (several
    measurements, one per column, with row 0 the magnitude and row 1 the error), a message
    is printed and the first column is returned. Otherwise ``objphot`` is returned as is.
    """

    if objphot.ndim == 2:
        print(
            f'Found {objphot.shape[1]} values for filter {filter_name} of {object_tag}'
        )
        print(
            f'so using the first value:  {objphot[0, 0]} +/- {objphot[1, 0]} mag'
        )
        objphot = objphot[:, 0]

    return objphot


def _plot_objects(ax, objects, companion_labels):
    """
    Plot the individual objects of :func:`plot_color_color` with error bars and, if
    ``companion_labels`` is True, annotate them with their names.
    """

    for item in objects:
        objdata = read_object.ReadObject(item[0])

        filter_names = (item[1][0], item[1][1], item[2][0], item[2][1])
        raw_phot = [objdata.get_photometry(filter_name) for filter_name in filter_names]

        objphot1, objphot2, objphot3, objphot4 = [
            _select_first_photometry(phot, filter_name, item[0])
            for phot, filter_name in zip(raw_phot, filter_names)
        ]

        color1 = objphot1[0] - objphot2[0]
        color2 = objphot3[0] - objphot4[0]

        # Errors of the two magnitudes are added in quadrature
        error1 = math.sqrt(objphot1[1]**2 + objphot2[1]**2)
        error2 = math.sqrt(objphot3[1]**2 + objphot4[1]**2)

        if len(item) > 3 and item[3] is not None:
            marker_kwargs = item[3]
        else:
            marker_kwargs = {
                'marker': '>',
                'ms': 6.,
                'color': 'black',
                'mfc': 'white',
                'mec': 'black',
                'label': 'Direct imaging'
            }

        ax.errorbar(color1,
                    color2,
                    xerr=error1,
                    yerr=error2,
                    zorder=3,
                    **marker_kwargs)

        if companion_labels:
            # The label keywords are the optional fifth element; fall back to the
            # defaults when it is missing or None
            if len(item) > 4 and item[4] is not None:
                label_kwargs = item[4]
            else:
                label_kwargs = {
                    'ha': 'left',
                    'va': 'bottom',
                    'fontsize': 8.5,
                    'xytext': (5., 5.),
                    'color': 'black'
                }

            ax.annotate(objdata.object_name, (color1, color2),
                        zorder=3,
                        textcoords='offset points',
                        **label_kwargs)


def plot_color_color(
        boxes: list,
        objects: Optional[Union[
            List[Tuple[str, Tuple[str, str], Tuple[str, str]]],
            List[Tuple[str, Tuple[str, str], Tuple[str, str],
                       Optional[dict], Optional[dict]]]]] = None,
        mass_labels: Optional[Union[List[float], List[Tuple[float, str]]]] = None,
        teff_labels: Optional[Union[List[float], List[Tuple[float, str]]]] = None,
        companion_labels: bool = False,
        reddening: Optional[List[Tuple[Tuple[str, str], Tuple[str, str],
                                       Tuple[str, float], str, float,
                                       Tuple[float, float]]]] = None,
        field_range: Optional[Tuple[str, str]] = None,
        label_x: str = 'Color',
        label_y: str = 'Color',
        xlim: Optional[Tuple[float, float]] = None,
        ylim: Optional[Tuple[float, float]] = None,
        offset: Optional[Tuple[float, float]] = None,
        legend: Optional[Union[str, dict, Tuple[float, float]]] = 'upper left',
        figsize: Optional[Tuple[float, float]] = (4., 4.3),
        output: str = 'color-color.pdf') -> None:
    """
    Function for creating a color-color diagram.

    Parameters
    ----------
    boxes : list(species.core.box.ColorColorBox, species.core.box.IsochroneBox, )
        Boxes with the color-color and isochrone data from photometric libraries, spectral
        libraries, and/or atmospheric models. The synthetic data have to be created with
        :func:`~species.read.read_isochrone.ReadIsochrone.get_color_color`. These boxes
        contain synthetic colors for a given age and a range of masses.
    objects : tuple(tuple(str, tuple(str, str), tuple(str, str)), ),
              tuple(tuple(str, tuple(str, str), tuple(str, str), dict, dict), ), None
        Tuple with individual objects. The objects require a tuple with their database tag, the
        two filter names for the first color, and the two filter names for the second color.
        Optionally, a dictionary with keyword arguments can be provided for the object's marker
        and label, respectively. For example, ``{'marker': 'o', 'ms': 10}`` for the marker and
        ``{'ha': 'left', 'va': 'bottom', 'xytext': (5, 5)}`` for the label. Either dictionary
        can be set to None to use the default style. Not used if set to None.
    mass_labels : list(float, ), list(tuple(float, str), ), None
        Plot labels with masses next to the isochrone data of the first model box. The list
        with masses has to be provided in Jupiter mass. Alternatively, a list of tuples can be
        provided with the planet mass and position of the label ('left' or 'right'), for
        example ``[(10., 'left'), (20., 'right')]``. No labels are shown if set to None.
    teff_labels : list(float, ), list(tuple(float, str), ), None
        Plot labels with temperatures (K) next to the synthetic Planck photometry.
        Alternatively, a list of tuples can be provided with the temperature and position of
        the label ('left' or 'right'), for example ``[(1000., 'left'), (1200., 'right')]``.
        No labels are shown if set to None.
    companion_labels : bool
        Plot labels with the names of the directly imaged companions.
    reddening : list(tuple(tuple(str, str), tuple(str, str), tuple(str, float), str, float, tuple(float, float)), None
        Include reddening arrows by providing a list with tuples. Each tuple contains the two
        filter names of the first color, the two filter names of the second color, a
        ``(str, float)`` tuple that is passed to ``dust_util.calc_reddening`` together with
        each color (NOTE(review): presumably a reference filter and its extinction in mag;
        confirm), the composition, the particle radius (um), and the start position (x, y) of
        the arrow in the plot. The composition can be either 'Fe' or 'MgSiO3' (both with
        crystalline structure). The parameter is not used if set to ``None``.
    field_range : tuple(str, str), None
        Range of the discrete colorbar for the field dwarfs. The tuple should contain the lower
        and upper value ('early M', 'late M', 'early L', 'late L', 'early T', 'late T',
        'early Y'). The full range is used if set to None.
    label_x : str
        Label for the x-axis.
    label_y : str
        Label for the y-axis.
    xlim : tuple(float, float)
        Limits for the x-axis.
    ylim : tuple(float, float)
        Limits for the y-axis.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label.
    legend : str, tuple(float, float), dict, None
        Legend position (passed as ``loc``) or a dictionary with keyword arguments for
        ``Axes.legend`` that override the default legend style. No legend is shown if set to
        ``None``.
    figsize : tuple(float, float)
        Figure size.
    output : str
        Output filename. A relative path is resolved against the current working directory;
        an absolute path is used as is.

    Returns
    -------
    NoneType
        None
    """

    # NOTE: rcParams/rc change the global Matplotlib configuration, so these settings
    # also apply to figures created after this function returns.
    mpl.rcParams['font.serif'] = ['Bitstream Vera Serif']
    mpl.rcParams['font.family'] = 'serif'

    plt.rc('axes', edgecolor='black', linewidth=2.2)

    model_color = ('#234398', '#f6a432', 'black')
    model_linestyle = ('-', '--', ':', '-.')

    isochrones = []
    planck = []
    models = []
    empirical = []

    for item in boxes:
        if isinstance(item, box.IsochroneBox):
            isochrones.append(item)

        elif isinstance(item, box.ColorColorBox):
            if item.object_type == 'model':
                models.append(item)

            elif item.library == 'planck':
                planck.append(item)

            else:
                empirical.append(item)

        else:
            raise ValueError(
                f'Found a {type(item)} while only ColorColorBox and IsochroneBox '
                f'objects can be provided to \'boxes\'.')

    plt.figure(1, figsize=figsize)
    # Row 0 holds the colorbar, row 1 is a spacer, and row 2 is the diagram itself
    gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 4.])
    gridsp.update(wspace=0., hspace=0., left=0, right=1, bottom=0, top=1)

    ax1 = plt.subplot(gridsp[2, 0])
    ax2 = plt.subplot(gridsp[0, 0])

    for tick_type, tick_length in (('major', 5), ('minor', 3)):
        ax1.tick_params(axis='both',
                        which=tick_type,
                        colors='black',
                        labelcolor='black',
                        direction='in',
                        width=1,
                        length=tick_length,
                        labelsize=12,
                        top=True,
                        bottom=True,
                        left=True,
                        right=True)

    ax1.xaxis.set_major_locator(MultipleLocator(0.5))
    ax1.yaxis.set_major_locator(MultipleLocator(0.5))

    ax1.xaxis.set_minor_locator(MultipleLocator(0.1))
    ax1.yaxis.set_minor_locator(MultipleLocator(0.1))

    ax1.set_xlabel(label_x, fontsize=14)
    ax1.set_ylabel(label_y, fontsize=14)

    ax1.invert_yaxis()

    if offset:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])
        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
    else:
        ax1.get_xaxis().set_label_coords(0.5, -0.08)
        ax1.get_yaxis().set_label_coords(-0.12, 0.5)

    if xlim:
        ax1.set_xlim(xlim[0], xlim[1])

    if ylim:
        ax1.set_ylim(ylim[0], ylim[1])

    # Each model library gets a color (in order of first appearance) and each further box
    # of the same library the next line style. Both cycle, so more libraries or boxes than
    # available colors or styles no longer raise an IndexError.
    model_dict = {}

    for j, item in enumerate(models):
        if item.library in model_dict:
            color_index, style_index = model_dict[item.library]
            style_index += 1
        else:
            color_index, style_index = len(model_dict), 0

        model_dict[item.library] = (color_index, style_index)

        line_color = model_color[color_index % len(model_color)]
        line_style = model_linestyle[style_index % len(model_linestyle)]

        # Determined per box so that a repeated library never inherits the label of a
        # different library that was plotted in between
        label = plot_util.model_name(item.library)

        if style_index > 0:
            ax1.plot(item.color1,
                     item.color2,
                     linestyle=line_style,
                     linewidth=0.6,
                     color=line_color,
                     label=label,
                     zorder=0)

        elif item.library == 'zhu2015':
            ax1.plot(item.color1,
                     item.color2,
                     marker='x',
                     ms=5,
                     linestyle=line_style,
                     linewidth=0.6,
                     color='gray',
                     label=label,
                     zorder=0)

            # NOTE(review): with an inverted y-axis, get_ylim() returns (max, min), so this
            # compares against the minimum; confirm which bound is intended.
            y_limits = ax1.get_ylim()

            for i, teff_item in enumerate(item.sptype):
                teff_label = rf'{teff_item:.0e} $M_\mathregular{{Jup}}^{2}$ yr$^{{-1}}$'

                if item.color2[i] < y_limits[1]:
                    ax1.annotate(teff_label,
                                 (item.color1[i], item.color2[i]),
                                 color='gray',
                                 fontsize=8,
                                 ha='left',
                                 va='center',
                                 xytext=(item.color1[i] + 0.1,
                                         item.color2[i] - 0.05),
                                 zorder=3)

        else:
            ax1.plot(item.color1,
                     item.color2,
                     linestyle=line_style,
                     linewidth=1.,
                     color=line_color,
                     label=label,
                     zorder=0)

            # Mass labels are only drawn along the first model box
            if mass_labels is not None and j == 0:
                _annotate_track(ax1, item, mass_labels, line_color,
                                lambda mass: str(int(mass)) + r' M$_\mathregular{J}$')

    for j, item in enumerate(planck):
        if j == 0:
            ax1.plot(item.color1,
                     item.color2,
                     ls='--',
                     linewidth=0.8,
                     color='black',
                     label=plot_util.model_name(item.library),
                     zorder=0)

            if teff_labels is not None:
                _annotate_track(ax1, item, teff_labels, 'black',
                                lambda teff: f'{int(teff)} K')

        else:
            ax1.plot(item.color1,
                     item.color2,
                     ls='--',
                     lw=0.5,
                     color='black',
                     zorder=0)

    if empirical:
        cmap = plt.cm.viridis

        bounds, ticks, ticklabels = plot_util.field_bounds_ticks(field_range)
        norm = mpl.colors.BoundaryNorm(bounds, cmap.N)

        for item in empirical:
            sptype = item.sptype
            names = item.names
            color1 = item.color1
            color2 = item.color2

            if isinstance(sptype, list):
                sptype = np.array(sptype)

            if item.object_type in ['field', None]:
                # Drop entries without a spectral type
                indices = np.where(sptype != 'None')[0]

                sptype = sptype[indices]
                color1 = color1[indices]
                color2 = color2[indices]

                spt_disc = plot_util.sptype_substellar(sptype, color1.shape)

                # Keep the first entry for each distinct color1 value
                _, unique = np.unique(color1, return_index=True)

                color1 = color1[unique]
                color2 = color2[unique]
                spt_disc = spt_disc[unique]

                scat = ax1.scatter(color1,
                                   color2,
                                   c=spt_disc,
                                   cmap=cmap,
                                   norm=norm,
                                   s=50,
                                   alpha=0.7,
                                   edgecolor='none',
                                   zorder=2)

                cb = Colorbar(ax=ax2,
                              mappable=scat,
                              orientation='horizontal',
                              ticklocation='top',
                              format='%.2f')

                cb.ax.tick_params(width=1,
                                  length=5,
                                  labelsize=10,
                                  direction='in',
                                  color='black')

                cb.set_ticks(ticks)
                cb.set_ticklabels(ticklabels)

            elif item.object_type == 'young':
                if objects is not None:
                    object_names = [obj_item[0] for obj_item in objects]

                    indices = plot_util.remove_color_duplicates(
                        object_names, names)

                    color1 = color1[indices]
                    color2 = color2[indices]

                ax1.plot(color1,
                         color2,
                         marker='s',
                         ms=4,
                         linestyle='none',
                         alpha=0.7,
                         color='gray',
                         markeredgecolor='black',
                         label='Young/low-gravity',
                         zorder=2)

    for item in isochrones:
        ax1.plot(item.colors[0],
                 item.colors[1],
                 linestyle='-',
                 linewidth=1.,
                 color='black')

    if reddening is not None:
        _plot_reddening_arrows(ax1, reddening)

    if objects is not None:
        _plot_objects(ax1, objects, companion_labels)

    print(f'Plotting color-color diagram: {output}...', end='', flush=True)

    if legend is not None:
        handles, labels = ax1.get_legend_handles_labels()

        # Prevent duplicate legend entries (the last handle per label is kept)
        by_label = dict(zip(labels, handles))

        if handles:
            legend_kwargs = {'fontsize': 8.5, 'frameon': False, 'numpoints': 1}

            if isinstance(legend, dict):
                legend_kwargs.update(legend)
            else:
                legend_kwargs['loc'] = legend

            ax1.legend(by_label.values(), by_label.keys(), **legend_kwargs)

    # os.path.join keeps an absolute output path intact instead of prefixing the cwd
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.clf()
    plt.close()

    print(' [DONE]')
Example #3
0
def plot_color_magnitude(colorbox=None,
                         objects=None,
                         isochrones=None,
                         models=None,
                         label_x='color [mag]',
                         label_y='M [mag]',
                         xlim=None,
                         ylim=None,
                         offset=None,
                         legend='upper left',
                         output='color-magnitude.pdf'):
    """
    Function for creating a color-magnitude diagram.

    Parameters
    ----------
    colorbox : species.core.box.ColorMagBox, None
        Box with the colors and magnitudes. If its ``object_type`` is ``'temperature'``, the
        points are colored by their ``sptype`` values (temperatures) and a vertical colorbar is
        drawn to the right of the diagram. Otherwise, points without a spectral type are
        removed, the remaining points are colored by spectral class, and a horizontal colorbar
        is drawn above the diagram. Not used if set to None.
    objects : tuple(tuple(str, str, str, str), ), None
        Tuple with individual objects. The objects require a tuple with their database tag, the two
        filter IDs for the color, and the filter ID for the absolute magnitude. Not used if set
        to None.
    isochrones : tuple(species.core.box.IsochroneBox, ), None
        Tuple with boxes of isochrone data. Not used if set to None.
    models : tuple(species.core.box.ColorMagBox, ), None
        Tuple with boxes of synthetic colors and magnitudes, each drawn as a line. Boxes with
        the same ``library`` share a color and are distinguished by line style. Only the first
        box of each library is labeled in the legend. Not used if set to None.
    label_x : str
        Label for the x-axis.
    label_y : str
        Label for the y-axis.
    xlim : tuple(float, float), None
        Limits for the x-axis. Automatic limits are used if set to None.
    ylim : tuple(float, float), None
        Limits for the y-axis. They are applied after the y-axis is inverted. Provide them as
        (faint, bright) to keep the brightest magnitudes at the top. Automatic limits are used
        if set to None.
    offset : tuple(float, float), None
        Offset of the x- and y-axis label (in axes coordinates). Default offsets are used if
        set to None.
    legend : str
        Legend position.
    output : str
        Output filename. A relative path is resolved against the current working directory.

    Returns
    -------
    None
    """

    marker = itertools.cycle(('o', 's', '<', '>', 'p', 'v', '^', '*',
                              'd', 'x', '+', '1', '2', '3', '4'))

    model_color = ('tomato', 'teal', 'dodgerblue')
    model_linestyle = ('-', '--', ':', '-.')

    sys.stdout.write('Plotting color-magnitude diagram: '+output+'... ')
    sys.stdout.flush()

    # The temperature colorbar is vertical (right of the diagram) and the spectral-type
    # colorbar is horizontal (above the diagram). Without a colorbox the side layout is used.
    if colorbox is None or colorbox.object_type == 'temperature':
        plt.figure(1, figsize=(4.4, 4.5))
        gridsp = mpl.gridspec.GridSpec(1, 3, width_ratios=[4, 0.15, 0.25])
        gridsp.update(wspace=0., hspace=0., left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[0, 0])
        ax2 = plt.subplot(gridsp[0, 2])

    else:
        plt.figure(1, figsize=(4., 4.8))
        gridsp = mpl.gridspec.GridSpec(3, 1, height_ratios=[0.2, 0.1, 4.5])
        gridsp.update(wspace=0., hspace=0., left=0, right=1, bottom=0, top=1)

        ax1 = plt.subplot(gridsp[2, 0])
        ax2 = plt.subplot(gridsp[0, 0])

    if colorbox is None:
        # No colorbar will be drawn, so hide the frame and ticks of the empty panel
        ax2.set_axis_off()

    ax1.grid(True, linestyle=':', linewidth=0.7, color='silver', dashes=(1, 4), zorder=0)

    ax1.tick_params(axis='both', which='major', colors='black', labelcolor='black',
                    direction='in', width=0.8, length=5, labelsize=12, top=True,
                    bottom=True, left=True, right=True)

    ax1.tick_params(axis='both', which='minor', colors='black', labelcolor='black',
                    direction='in', width=0.8, length=3, labelsize=12, top=True,
                    bottom=True, left=True, right=True)

    ax1.set_xlabel(label_x, fontsize=14)
    ax1.set_ylabel(label_y, fontsize=14)

    # Magnitudes: brighter (smaller) values at the top
    ax1.invert_yaxis()

    if offset:
        ax1.get_xaxis().set_label_coords(0.5, offset[0])
        ax1.get_yaxis().set_label_coords(offset[1], 0.5)
    else:
        ax1.get_xaxis().set_label_coords(0.5, -0.08)
        ax1.get_yaxis().set_label_coords(-0.12, 0.5)

    if xlim:
        ax1.set_xlim(xlim[0], xlim[1])

    if ylim:
        ax1.set_ylim(ylim[0], ylim[1])

    if colorbox is not None:
        sptype = colorbox.sptype
        color = colorbox.color
        magnitude = colorbox.magnitude

        cmap_sptype = plt.cm.viridis

        if colorbox.object_type == 'temperature':
            # For this object type the sptype array holds temperatures, used directly as colors
            scat_sptype = ax1.scatter(color, magnitude, c=sptype, cmap=cmap_sptype,
                                      zorder=6, s=40, alpha=0.6, edgecolor='none')

            cb1 = Colorbar(ax=ax2, mappable=scat_sptype, orientation='vertical',
                           ticklocation='right', format='%i')

            cb1.ax.tick_params(width=0.8, length=5, labelsize=10, direction='in', color='black')
            cb1.ax.set_ylabel('Temperature [K]', rotation=270, fontsize=12, labelpad=22)
            cb1.solids.set_edgecolor("face")

        else:
            # One colormap bin per spectral class: 10 classes for stars, 7 for substellar objects
            if colorbox.object_type == 'star':
                bounds_sptype = np.arange(0, 11, 1)
            else:
                bounds_sptype = np.arange(0, 8, 1)

            norm_sptype = mpl.colors.BoundaryNorm(bounds_sptype, cmap_sptype.N)

            # Remove the objects without a spectral type
            indices = np.where(sptype != b'None')[0]

            sptype = sptype[indices]
            color = color[indices]
            magnitude = magnitude[indices]

            if colorbox.object_type == 'star':
                spt_disc = plot_util.sptype_stellar(sptype, color.shape)
                unique = np.arange(0, color.size, 1)

            else:
                spt_disc = plot_util.sptype_substellar(sptype, color.shape)
                # Keep a single point per distinct color value
                _, unique = np.unique(color, return_index=True)

            color = color[unique]
            magnitude = magnitude[unique]
            spt_disc = spt_disc[unique]

            scat_sptype = ax1.scatter(color, magnitude, c=spt_disc, cmap=cmap_sptype,
                                      norm=norm_sptype, zorder=6, s=40, alpha=0.6,
                                      edgecolor='none')

            cb1 = Colorbar(ax=ax2, mappable=scat_sptype, orientation='horizontal',
                           ticklocation='top', format='%.2f')

            cb1.ax.tick_params(width=0.8, length=5, labelsize=10, direction='in', color='black')

            # Ticks are placed in the middle of each spectral class bin
            if colorbox.object_type == 'star':
                cb1.set_ticks(np.arange(0.5, 10., 1.))
                cb1.set_ticklabels(['O', 'B', 'A', 'F', 'G', 'K', 'M', 'L', 'T', 'Y'])

            else:
                cb1.set_ticks(np.arange(0.5, 7., 1.))
                cb1.set_ticklabels(['M0-M4', 'M5-M9', 'L0-L4', 'L5-L9', 'T0-T4', 'T6-T8', 'Y1-Y2'])

    if models is not None:
        # Library name -> [color index, number of earlier boxes of the same library]
        model_dict = {}

        for item in models:
            if item.library not in model_dict:
                model_dict[item.library] = [len(model_dict), 0]
            else:
                model_dict[item.library][1] += 1

            color_index, style_index = model_dict[item.library]

            # Cycle through colors and line styles so that more libraries or boxes than
            # available styles do not raise an IndexError
            plot_kwargs = {'linestyle': model_linestyle[style_index % len(model_linestyle)],
                           'linewidth': 0.6,
                           'zorder': 3,
                           'color': model_color[color_index % len(model_color)]}

            if style_index == 0:
                # Label only the first box of each library to avoid duplicate legend entries
                plot_kwargs['label'] = plot_util.model_name(item.library)

            ax1.plot(item.color, item.magnitude, **plot_kwargs)

    if isochrones is not None:
        for item in isochrones:
            ax1.plot(item.color, item.magnitude, linestyle='-', linewidth=1., color='black')

    if objects is not None:
        for item in objects:
            objdata = read_object.ReadObject(item[0])

            objcolor1 = objdata.get_photometry(item[1])
            objcolor2 = objdata.get_photometry(item[2])
            abs_mag = objdata.get_absmag(item[3])

            # Uncertainties of the two filters are added in quadrature
            colorerr = math.sqrt(objcolor1[1]**2+objcolor2[1]**2)

            ax1.errorbar(objcolor1[0]-objcolor2[0], abs_mag[0], yerr=abs_mag[1], xerr=colorerr,
                         marker=next(marker), ms=6, color='black', label=objdata.object_name,
                         markerfacecolor='white', markeredgecolor='black', zorder=10)

    handles, _ = ax1.get_legend_handles_labels()

    if handles:
        ax1.legend(loc=legend, prop={'size': 9}, frameon=False, numpoints=1)

    # os.path.join keeps absolute output paths intact
    plt.savefig(os.path.join(os.getcwd(), output), bbox_inches='tight')
    plt.close()

    sys.stdout.write('[DONE]\n')
    sys.stdout.flush()