Example #1
0
def _makeHTMLmap(fig, ax, htmlmap_path):
    """
    Write 'surface_current_tilemap.html', an HTML ``<map>`` fragment that makes
    the domain map clickable.

    One ``<area shape="rect">`` is written for every entry of the module-level
    ``tile_coords_dic``.  Each area's href calls ``regionMap(n)``, where ``n``
    is the 1-based position of the tile in that dict's iteration order.

    :arg fig: Figure containing the domain map.
    :arg ax: Axes whose data coordinates the tile corners are given in.
    :arg htmlmap_path: Directory (supports the ``/`` operator, e.g. a
        :py:class:`pathlib.Path`) in which the HTML fragment is written.
    """
    header = '<map name="tileclickmap">\n'
    footer = "</map>\n"

    # The figure's pixel size is the same for every tile, so query it once
    # instead of constructing a new canvas (which replaces fig.canvas) per tile.
    _width, height = FigureCanvasBase(fig).get_width_height()

    with open(htmlmap_path / "surface_current_tilemap.html", "w",
              encoding="utf-8") as tilemapfile:
        tilemapfile.write(header)
        # Only the corner values are needed; tile keys are not used.
        for i, values in enumerate(tile_coords_dic.values(), start=1):
            x1, x2, y1, y2 = values[0], values[1], values[2], values[3]

            # Transform x,y data into pixel coordinates
            xy_pixels = ax.transData.transform(
                numpy.vstack([[x1, x2], [y1, y2]]).T)
            xpix, ypix = xy_pixels.T

            # 0,0 is the lower left coordinate in matplotlib but it's the upper
            # left in HTML, so the y-coordinate needs to be flipped here
            ypix = height - ypix

            curline = (
                '    <area shape="rect" coords="{:3d},{:3d},{:3d},{:3d}"  '
                'href="JavaScript: regionMap({:2d}); void(0);">\n'.format(
                    int(xpix[0]), int(ypix[0]), int(xpix[1]), int(ypix[1]), i))
            tilemapfile.write(curline)
        tilemapfile.write(footer)
Example #2
0
def test_canvas_change():
    """Closing a figure still works after its canvas has been replaced."""
    figure = plt.figure()
    FigureCanvasBase(figure)  # constructing a canvas replaces figure.canvas
    plt.close(figure)
    assert not plt.fignum_exists(figure.number)
Example #3
0
def test_get_default_filename():
    """An empty save directory yields the plain default name 'image.png'."""
    # TemporaryDirectory creates the directory before the body runs and always
    # removes it afterwards, so cleanup never references an unbound name.
    with tempfile.TemporaryDirectory() as test_dir:
        # rc_context restores the previous rcParams on exit, so the temporary
        # save directory does not leak into other tests.
        with plt.rc_context({'savefig.directory': test_dir}):
            fig = plt.figure()
            canvas = FigureCanvasBase(fig)
            filename = canvas.get_default_filename()
            assert filename == 'image.png'
Example #4
0
def _render_figures(fig_list, tile_names, storage_path, date_stamp, file_type):
    """Save each tile figure into ``storage_path``.

    Figures and tile names are paired positionally.  The tile number is parsed
    from ``tile_name[4:]`` (presumably names like ``"tile7"`` — confirm against
    callers).  PNG outputs are additionally passed to ``_apply_pngquant`` with
    the argument 16.
    """
    for figure, tile_name in zip(fig_list, tile_names):
        tile_number = int(tile_name[4:])
        file_name = "surface_currents_tile{:02d}_{}_UTC.{}".format(
            tile_number, date_stamp, file_type
        )
        outfile = Path(storage_path, file_name)

        canvas = FigureCanvasBase(figure)
        canvas.print_figure(os.fspath(outfile), facecolor=figure.get_facecolor())
        logger.debug(f"{outfile} saved")

        if file_type != "png":
            continue
        _apply_pngquant(outfile, 16)
Example #5
0
def test_get_default_filename_already_exists():
    """Suggest a non-existing default filename (from #3068).

    When 'image.png' already exists in the save directory, the suggestion
    must be 'image-1.png'.
    """
    # TemporaryDirectory guarantees cleanup without an unbound-name risk;
    # rc_context restores rcParams so the directory does not leak elsewhere.
    with tempfile.TemporaryDirectory() as test_dir:
        with plt.rc_context({'savefig.directory': test_dir}):
            fig = plt.figure()
            canvas = FigureCanvasBase(fig)

            # create an empty 'image.png' in the figure's save dir
            with open(os.path.join(test_dir, 'image.png'), 'w'):
                pass

            filename = canvas.get_default_filename()
            assert filename == 'image-1.png'
Example #6
0
def print_figure(fig, fmt="png", bbox_inches="tight", base64=False, **kwargs):
    """Render *fig* to an in-memory image and return the file data.

    The result is ``bytes``, except for ``fmt='svg'``, where it is decoded to
    a unicode ``str``.  ``fmt='retina'`` renders PNG at twice the figure dpi.

    Extra keyword arguments (for example ``quality`` or ``bbox_inches``) are
    forwarded to ``fig.canvas.print_figure`` and override the defaults below.

    If `base64` is True, binary formats are returned as a base64-encoded
    ``str`` instead of raw bytes (SVG is always returned as text).

    Returns ``None`` for a figure with neither axes nor lines.

    .. versionadded:: 7.29
        base64 argument
    """
    # Returning nothing for an empty figure avoids big blank areas in the
    # qt console.
    if not (fig.axes or fig.lines):
        return

    is_retina = fmt == 'retina'
    dpi = fig.dpi * 2 if is_retina else fig.dpi
    if is_retina:
        fmt = 'png'

    options = {
        "format": fmt,
        "facecolor": fig.get_facecolor(),
        "edgecolor": fig.get_edgecolor(),
        "dpi": dpi,
        "bbox_inches": bbox_inches,
        **kwargs,  # caller-supplied keywords take priority
    }

    if fig.canvas is None:
        from matplotlib.backend_bases import FigureCanvasBase
        FigureCanvasBase(fig)

    buffer = BytesIO()
    fig.canvas.print_figure(buffer, **options)
    raw = buffer.getvalue()

    if fmt == 'svg':
        return raw.decode('utf-8')
    if base64:
        return b2a_base64(raw).decode("ascii")
    return raw
Example #7
0
    def compute_output(self, output_module, configuration):
        """Save the module's input figure to an image file at a pixel size.

        The figure from ``output_module.get_input('value')`` is temporarily
        resized so that, printed at 72 dpi, it measures
        ``configuration["width"]`` x ``configuration["height"]`` pixels.  It is
        written in the format from ``self.get_format(configuration)`` to the
        file from ``self.get_filename``.  The previous figure size is restored
        afterwards, even if printing fails, and the canvas is then redrawn.
        """
        figure = output_module.get_input('value')
        w = configuration["width"]
        h = configuration["height"]
        img_format = self.get_format(configuration)
        filename = self.get_filename(configuration, suffix='.%s' % img_format)

        # Printing at 72 dpi, so pixels / 72 gives the size in inches.
        w_inches = w / 72.0
        h_inches = h / 72.0

        previous_size = tuple(figure.get_size_inches())
        figure.set_size_inches(w_inches, h_inches)
        try:
            canvas = FigureCanvasBase(figure)
            canvas.print_figure(filename, dpi=72, format=img_format)
        finally:
            # Always give the caller back the figure size it had before.
            figure.set_size_inches(previous_size[0], previous_size[1])
        canvas.draw()
Example #8
0
def new_figure_manager_given_figure(num, figure):
    """Create a new figure manager instance for the given figure.

    With a valid GUI and a live ``gui._qapp``, the canvas is built through
    ``gui.gui_call`` and wrapped in a ``FigureManagerGh2``.  Otherwise a plain
    ``FigureCanvasBase`` is wrapped in a ``FigureManagerGh2Child``.
    """
    if gui.valid() and gui._qapp is not None:
        canvas = gui.gui_call(FigureCanvasGh2, figure)
        return FigureManagerGh2(canvas, num)

    # Reached when coming from another process, e.g. when a figure, line, ...
    # is depickled from the interprocess queue.  No GUI call is made here,
    # because FigureCanvasGh2 and FigureManagerQT are not picklable.
    canvas = FigureCanvasBase(figure)
    return FigureManagerGh2Child(canvas, num)
Example #9
0
    def save_with_matplotlib(plot, width, height, dpi, filename):
        """Save a Chaco *plot* to *filename* through a matplotlib canvas.

        The plot is wrapped in a ``ChacoFigure`` of the given size and dpi.
        The output format comes from the filename extension: the canvas method
        ``print_<ext>`` is looked up and called with *filename*.  If no such
        method exists, an "Invalid Filename Extension" error dialog is shown
        instead and nothing is written.
        """
        from matplotlib.backend_bases import FigureCanvasBase
        figure = ChacoFigure(plot, width, height, dpi)
        canvas = FigureCanvasBase(figure)
        ext = os.path.splitext(filename)[1][1:]

        try:
            # Look up the relevant print_ method on the canvas.
            # It invokes the correct backend and prints the "figure".
            func = getattr(canvas, 'print_' + ext)
        except AttributeError as e:
            errmsg = ("The filename must have an extension that matches "
                      "a graphics format, such as '.png' or '.tiff'.")
            # Python 3 exceptions have no ``.message``; str(e) is the message.
            if str(e) != '':
                errmsg = ("Unknown filename extension: '%s'\n" %
                          str(e)) + errmsg
            error(None, errmsg, title="Invalid Filename Extension")
        else:
            # Actually write the file; the lookup alone saves nothing.
            func(filename)
def test_location_event_position():
    """LocationEvent casts its x and y arguments to int unless they are None."""
    fig = plt.figure()
    canvas = FigureCanvasBase(fig)
    positions = [(42, 24), (None, 42), (None, None),
                 (200, 100.01), (205.75, 2.0)]
    for x, y in positions:
        event = LocationEvent("test_event", canvas, x, y)
        # Check x first, then y, each against the value that was passed in.
        for given, stored in ((x, event.x), (y, event.y)):
            if given is None:
                assert stored is None
                continue
            assert stored == int(given)
            assert isinstance(stored, int)
Example #11
0
def test_location_event_position(x, y):
    """LocationEvent casts its x and y arguments to int unless they are None.

    When both are given, also check that ``Axes.format_coord`` uses
    ``format_xdata``/``format_ydata`` and honours ``fmt_xdata``/``fmt_ydata``.
    """
    fig, ax = plt.subplots()
    canvas = FigureCanvasBase(fig)
    event = LocationEvent("test_event", canvas, x, y)
    if x is None:
        assert event.x is None
    else:
        assert event.x == int(x)
        assert isinstance(event.x, int)
    if y is None:
        assert event.y is None
    else:
        assert event.y == int(y)
        assert isinstance(event.y, int)
    if x is not None and y is not None:
        # The formatted values are escaped so characters such as '.', '+'
        # (e.g. '1e+02') or '(' are matched literally rather than being
        # interpreted as regex syntax; ' +' still allows flexible spacing.
        assert re.match(
            "x={} +y={}".format(re.escape(ax.format_xdata(x)),
                                re.escape(ax.format_ydata(y))),
            ax.format_coord(x, y))
        ax.fmt_xdata = ax.fmt_ydata = lambda x: "foo"
        assert re.match("x=foo +y=foo", ax.format_coord(x, y))
Example #12
0
def print_figure(fig, fmt='png', bbox_inches='tight', **kwargs):
    """Render *fig* to an in-memory image and return the file data.

    The result is ``bytes``, except for ``fmt='svg'``, where it is decoded to
    a unicode ``str``.  ``fmt='retina'`` renders PNG at twice the figure dpi.

    Extra keyword arguments (for example ``quality`` or ``bbox_inches``) are
    forwarded to ``fig.canvas.print_figure`` and override the defaults below.

    Returns ``None`` for a figure with neither axes nor lines.
    """
    # Returning nothing for an empty figure avoids big blank areas in the
    # qt console.
    if not (fig.axes or fig.lines):
        return

    if fmt == 'retina':
        fmt, dpi = 'png', fig.dpi * 2
    else:
        dpi = fig.dpi

    options = dict(
        format=fmt,
        facecolor=fig.get_facecolor(),
        edgecolor=fig.get_edgecolor(),
        dpi=dpi,
        bbox_inches=bbox_inches,
    )
    options.update(kwargs)  # caller-supplied keywords take priority

    if fig.canvas is None:
        from matplotlib.backend_bases import FigureCanvasBase
        FigureCanvasBase(fig)

    out = BytesIO()
    fig.canvas.print_figure(out, **options)
    payload = out.getvalue()
    return payload.decode('utf-8') if fmt == 'svg' else payload
Example #13
0
    def _destroy(event):
        """Handle a key event: on a quit key, close the window and fall back
        to a plain FigureCanvasBase.

        Keys not listed in ``mpl.rcParams["keymap.quit"]`` are ignored.

        :raises RuntimeError: if the event's canvas has no manager.
        """
        if event.key not in mpl.rcParams["keymap.quit"]:
            return

        # grab the manager off the event
        manager = event.canvas.manager
        if manager is None:
            raise RuntimeError("Should never be here, please report a bug")
        figure = event.canvas.figure

        # Remove this callback.  Callbacks live on the Figure, so they survive
        # the canvas being replaced.
        callback_id = getattr(manager, "_destroy_cid", None)
        if callback_id is not None:
            figure.canvas.mpl_disconnect(callback_id)
            manager._destroy_cid = None

        manager.destroy()              # close the window
        figure.canvas.manager = None   # disconnect the manager from the canvas
        # reset the dpi to the saved original, if one was recorded
        figure.dpi = getattr(figure, "_original_dpi", figure.dpi)
        # Go back to a "base" canvas (its __init__ sets state on the figure).
        FigureCanvasBase(figure)
def test_get_default_filename(tmpdir):
    """An empty save directory yields the plain default name 'image.png'."""
    plt.rcParams['savefig.directory'] = str(tmpdir)
    canvas = FigureCanvasBase(plt.figure())
    assert canvas.get_default_filename() == 'image.png'
Example #15
0
def _render_figure(fig, storage_path, file_type):
    """Save *fig* as ``surface_currents_tilemap.<file_type>`` in *storage_path*.

    The figure's own face colour is kept, and the file name is printed.
    """
    domain_name = "surface_currents_tilemap.{}".format(file_type)
    outfile = Path(storage_path) / domain_name

    canvas = FigureCanvasBase(fig)
    canvas.print_figure(outfile.as_posix(), facecolor=fig.get_facecolor())
    print(domain_name)
Example #16
0
 def function___init__(self):
     """Construct a FigureCanvasBase with no arguments and discard it."""
     FigureCanvasBase()
Example #17
0
    def plot_hops_corner(self, fitting_directory):
        """Draw a corner plot of the MCMC traces and save it as ``corner.pdf``.

        Only parameters whose ``'initial'`` flag is set in
        ``self.results['parameters']`` are plotted.
        - Diagonal panels show each parameter's trace histogram, with lines at
          the value and at its lower/upper error bounds.
        - Upper-triangle panels show the log-scaled 2-D density of each pair of
          traces, annotated with their correlation coefficient rounded to two
          decimals.

        :param fitting_directory: directory that receives ``corner.pdf``.
        :raises RuntimeError: if the MCMC run has not completed.
        """
        def correlation(x, y):
            # Pearson correlation coefficient, rounded to 2 decimals.
            # np.std defaults to ddof=0 (population std), so the covariance
            # sum must also be normalised by n; dividing by (n - 1) would
            # inflate the result by n / (n - 1) and could push |r| above 1.
            n = len(x)
            mx = np.mean(x)
            sx = np.std(x)
            my = np.mean(y)
            sy = np.std(y)
            return np.round(np.sum((x - mx) * (y - my)) / (n * sx * sy), 2)

        def td_distribution(datax, datay, axx):
            # Draw a log-scaled 2-D histogram of (datax, datay) on axx.  The
            # bin width along each axis is one fifth of
            # sqrt(median((d - median(d))**2)) for that axis' data.
            # NOTE(review): a zero spread gives a zero bin width (division by
            # zero) — confirm traces are never constant.
            datax = np.array(datax)
            median = np.median(datax)
            med = np.sqrt(np.median((datax - median)**2))
            xstep = med / 5.0
            xmin = min(datax)
            xmax = max(datax)
            x_size = int(round((xmax - xmin) / xstep)) + 1
            datax = np.int_((datax - xmin) / xstep)
            datay = np.array(datay)
            median = np.median(datay)
            med = np.sqrt(np.median((datay - median)**2))
            ystep = med / 5.0
            ymin = min(datay)
            ymax = max(datay)
            y_size = int(round((ymax - ymin) / ystep)) + 1
            datay = np.int_((datay - ymin) / ystep)

            # Flatten (y, x) bin indices and count samples per bin; minlength
            # pads the counts with zeros up to the full grid size.
            yx_size = x_size * y_size
            yx = datay * x_size + datax
            yx = np.bincount(yx, minlength=yx_size)

            xx, yy = np.meshgrid(xmin + np.arange(x_size) * xstep,
                                 ymin + np.arange(y_size) * ystep)

            final = np.reshape(yx, (y_size, x_size))
            # log of the counts; empty bins stay 0
            axx.imshow(np.where(final > 0,
                                np.log(np.where(final > 0, final, 1)), 0),
                       extent=(np.min(xx), np.max(xx), np.min(yy), np.max(yy)),
                       cmap=cm.Greys,
                       origin='lower',
                       aspect='auto')

        if not self.mcmc_run_complete:
            raise RuntimeError('MCMC not completed')

        names = []
        results = []
        print_results = []
        errors1 = []
        print_errors1 = []
        errors2 = []
        print_errors2 = []
        errors = []
        traces = []
        traces_bins = []
        traces_counts = []

        for i in self.names:
            param = self.results['parameters'][i]
            if not param['initial']:
                continue
            names.append(param['print_name'])
            results.append(param['value'])
            print_results.append(param['print_value'])
            errors1.append(param['m_error'])
            print_errors1.append(param['print_m_error'])
            errors2.append(param['p_error'])
            print_errors2.append(param['print_p_error'])
            # symmetric error: mean of the lower and upper errors
            errors.append(0.5 * (param['m_error'] + param['p_error']))
            traces.append(param['trace'])
            traces_bins.append(param['trace_bins'])
            traces_counts.append(param['trace_counts'])

        all_var = len(traces)
        fig = Figure(figsize=(2.5 * all_var, 2.5 * all_var),
                     tight_layout=False)
        FigureCanvasBase(fig)  # attach a canvas to the figure
        # NOTE(review): cm.get_cmap is deprecated in recent matplotlib
        # (matplotlib.colormaps['brg'] replaces it) — confirm target version.
        cmap = cm.get_cmap('brg')

        def add_panel(index):
            # White subplot at 1-based grid position ``index``.  ``axisbg``
            # is the fallback keyword when ``facecolor`` raises AttributeError
            # (presumably older matplotlib).
            try:
                return fig.add_subplot(all_var, all_var, index, facecolor='w')
            except AttributeError:
                return fig.add_subplot(all_var, all_var, index, axisbg='w')

        for var in range(len(names)):

            # diagonal panel (row var, column var): this parameter's histogram
            ax = add_panel(all_var * var + var + 1)

            ax.step(traces_bins[var],
                    traces_counts[var],
                    color='k',
                    where='mid')

            ax.axvline(results[var], c='k')
            ax.axvline(results[var] - errors1[var], c='k', ls='--', lw=0.5)
            ax.axvline(results[var] + errors2[var], c='k', ls='--', lw=0.5)

            ax.set_xticks([results[var]])
            ax.set_yticks([0])
            ax.tick_params(left=False,
                           right=False,
                           top=False,
                           bottom=False,
                           labelbottom=False,
                           labelleft=False)

            ax.set_xlabel('{0}\n{1}\n{2}\n{3}'.format(
                r'${0}$'.format(names[var]),
                r'${0}$'.format(print_results[var]),
                r'$-{0}$'.format(print_errors1[var]),
                r'$+{0}$'.format(print_errors2[var])),
                          fontsize=20)

            ax.set_xlim(results[var] - 6 * errors[var],
                        results[var] + 6 * errors[var])
            ax.set_ylim(0, ax.get_ylim()[1])

            # upper-triangle panels (row var, column j > var)
            for j in range(var + 1, all_var):

                ax2 = add_panel(all_var * var + 1 + j)

                td_distribution(traces[j], traces[var], ax2)

                ax2.set_yticks([0])
                ax2.set_xticks([results[j]])
                ax2.tick_params(bottom=False,
                                left=False,
                                right=False,
                                top=False,
                                labelbottom=False,
                                labelleft=False,
                                labelright=False,
                                labeltop=False)

                ax2.set_xlim(results[j] - 6 * errors[j],
                             results[j] + 6 * errors[j])
                ax2.set_ylim(results[var] - 6 * errors[var],
                             results[var] + 6 * errors[var])

                # annotate the top-right corner, 5% in from each edge
                x_lo, x_hi = ax2.get_xlim()
                y_lo, y_hi = ax2.get_ylim()
                text_x = x_hi - 0.05 * (x_hi - x_lo)
                text_y = y_hi - 0.05 * (y_hi - y_lo)
                corr = correlation(traces[j], traces[var])
                ax2.text(text_x,
                         text_y,
                         '{0}{1}{2}'.format(r'$', str(corr), '$'),
                         color=cmap(abs(corr) / 2.),
                         fontsize=20,
                         ha='right',
                         va='top')

        fig.subplots_adjust(hspace=0, wspace=0)
        fig.savefig(os.path.join(fitting_directory, 'corner.pdf'),
                    transparent=False)
Example #18
0
def test_canvas_ctor():
    """A canvas constructed without arguments exposes a Figure."""
    canvas = FigureCanvasBase()
    assert isinstance(canvas.figure, Figure)
Example #19
0
    def plot_hops_output(self, target, data_dates, observer, observatory,
                         fitting_directory):
        """Plot the de-trended light curve and residuals and save them as JPEGs.

        The figure contains:
        - the HOLOMON logo, the target name as title, and the data-set label;
        - the de-trended data with the fitted model;
        - the fitted ``Rp/R*`` and mid-time, with the O-C against the
          predicted mid-time in minutes;
        - dashed markers at the predicted ingress start and egress end;
        - a residuals panel with its rms.

        It is saved in *fitting_directory* as JPEG at 300, 900 and 1200 dpi.

        :param target: name used as title (``None`` gives a blank title).
        :param data_dates: per-set labels indexed by set number; ``None``
            generates ``'set_1'``, ``'set_2'``, ...
        :param observer: observer name written in the upper-left area.
        :param observatory: observatory name written below the observer.
        :param fitting_directory: directory receiving the JPEG files.
        :returns: the figure built for set 0 (the loop returns after its first
            iteration), or ``None`` if ``self.total_sets`` is 0.
        """
        if target is None:
            target = ' '

        if data_dates is None:
            # A list, not a lazy ``map``: it is indexed per set below, and
            # map objects are not subscriptable in Python 3.
            data_dates = ['set_{0}'.format(str(ff))
                          for ff in range(1, self.total_sets + 1)]

        for set_number in range(self.total_sets):

            # Layout: a frow x fcol grid of funit-inch cells; fbottom and
            # fright are the bottom and right margins as figure fractions.
            funit = 1.0
            fcol = 7
            frow = 5
            fbottom = 0.11
            fright = 0.05
            fsmain = 10
            fsbig = 15
            fig = Figure(figsize=(funit * fcol / (1 - fright),
                                  funit * frow / (1 - fbottom)))
            FigureCanvasBase(fig)  # attach a canvas to the figure
            try:
                gs = gridspec.GridSpec(frow, fcol, fig, 0, fbottom,
                                       1.0 - fright, 1.0, 0.0, 0.0)
            except TypeError:
                # fallback signature without the figure argument
                # (presumably older matplotlib)
                gs = gridspec.GridSpec(frow, fcol, 0, fbottom, 1.0 - fright,
                                       1.0, 0.0, 0.0)

            fig.text(0.5,
                     0.94,
                     '{0}{1}{2}'.format(r'$\mathbf{', target, '}$'),
                     fontsize=24,
                     va='center',
                     ha='center')
            fig.text(0.97,
                     0.97,
                     data_dates[set_number],
                     fontsize=fsmain,
                     va='top',
                     ha='right')

            logo_ax = fig.add_subplot(gs[0, 0])
            logo_ax.imshow(log.holomon_logo_jpg)
            for side in ('top', 'bottom', 'left', 'right'):
                logo_ax.spines[side].set_visible(False)
            logo_ax.tick_params(left=False,
                                bottom=False,
                                labelleft=False,
                                labelbottom=False)

            self.results = {ff: self.results[ff] for ff in self.results}

            parameters = self.results['parameters']
            output_series = self.results['detrended_output_series']
            input_series = self.results['detrended_input_series']

            # Shift the fitted mid-time by whole periods towards the mean of
            # self.data[set_number][0] (presumably this set's time stamps).
            period = parameters['P']['value']
            mt = parameters['mt']['value']
            mt += round(
                (np.mean(self.data[set_number][0]) - mt) / period) * period

            prediction = (self.mid_time + round(
                (np.mean(self.data[set_number][0]) - self.mid_time) /
                self.period) * self.period)

            duration = plc.transit_duration(self.rp_over_rs, self.period,
                                            self.sma_over_rs, self.inclination,
                                            self.eccentricity, self.periastron)

            ingress = prediction - duration / 2
            egress = prediction + duration / 2

            set_indices = np.where(self.data_set_number == set_number)

            # Per-set slices reused by both panels.
            phase = output_series['phase'][set_indices]
            model = output_series['model'][set_indices]
            residuals = output_series['residuals'][set_indices]
            flux = input_series['value'][set_indices]
            residuals_std = np.std(residuals)

            ax1 = fig.add_subplot(gs[1:4, 1:])

            ax1.plot(phase, flux, 'ko', ms=2)
            ax1.plot(phase, model, 'r-')

            fig.text(0.04,
                     fbottom + 2.5 * (1 - fbottom) / frow,
                     'relative flux (de-trended)',
                     fontsize=fsbig,
                     va='center',
                     ha='center',
                     rotation='vertical')

            data_ymin = min(flux) - 3 * residuals_std
            data_ymax = max(flux) + 2 * residuals_std

            ax1.set_yticks(
                ax1.get_yticks()[np.where(ax1.get_yticks() > data_ymin)])

            ymin, ymax = data_ymax - 1.05 * (data_ymax - data_ymin), data_ymax

            ax1.set_ylim(ymin, ymax)

            x_max = max(np.abs(phase) + 0.05 * (max(phase) - min(phase)))

            ax1.set_xlim(-x_max, x_max)
            ax1.tick_params(labelbottom=False, labelsize=fsmain)

            rp = parameters['rp']
            rpstr = '{0}{1}{2}{3}{4}{5}{6}{7}'.format(
                r'$R_\mathrm{p}/R_* = ',
                rp['print_value'], '_{-',
                rp['print_m_error'], '}', '^{+',
                rp['print_p_error'], '}$')
            # O-C uses the fitted (unshifted) mid-time, in minutes.
            mt_param = parameters['mt']
            mtstr = '${0}{1}{2}{3}{4}{5}{6}{7}{8}{9}{10}{11}{12}{13}{14}{15}$'.format(
                r'T_\mathrm{BJD_{TDB}} = ',
                mt_param['print_value'], '_{-',
                mt_param['print_m_error'], '}', '^{+',
                mt_param['print_p_error'], '}',
                r' \quad \mathrm{O-C_{minutes}} = ',
                round((mt_param['value'] - prediction) * 24 * 60, 1), '_{-',
                round(mt_param['m_error'] * 24 * 60, 1), '}', '^{+',
                round(mt_param['p_error'] * 24 * 60, 1), '}')

            ax1.text(0,
                     ymin + 0.1 * (ymax - ymin),
                     '{0}{1}{2}'.format(rpstr, '\n', mtstr),
                     ha='center',
                     va='center',
                     fontsize=fsmain)

            # Dashed markers (in phase units) at predicted ingress and egress.
            for edge, label, align in (
                    (ingress, 'predicted\ningress\nstart', 'right'),
                    (egress, 'predicted\negress\nend', 'left')):
                ax1.axvline((edge - mt) / period,
                            0.3,
                            1.0,
                            ls='--',
                            c='k',
                            lw=0.75)
                ax1.text((edge - mt) / period,
                         ax1.get_ylim()[0] + 0.3 * (ymax - ymin),
                         label,
                         ha=align,
                         va='top',
                         fontsize=fsmain)

            fig.text((1 - fright) / fcol,
                     1 - (1 - fbottom) / frow,
                     '\n\n{0}\n{1}'.format(observer, observatory),
                     fontsize=fsmain,
                     ha='left',
                     va='bottom')

            ax2 = fig.add_subplot(gs[4, 1:])
            ax2.plot(phase, residuals, 'ko', ms=2)
            ax2.plot(phase, np.zeros_like(phase), 'r-')

            ax2.set_ylim(-5 * residuals_std, 5 * residuals_std)

            ax2.set_xlabel('phase', fontsize=fsbig)
            fig.text(0.04,
                     fbottom + 0.5 * (1 - fbottom) / frow,
                     'residuals',
                     fontsize=fsbig,
                     va='center',
                     ha='center',
                     rotation='vertical')

            ax2.set_xlim(-x_max, x_max)
            ax2.tick_params(labelsize=fsmain)

            x_lo, x_hi = ax2.get_xlim()[0], ax2.get_xlim()[-1]
            y_lo, y_hi = ax2.get_ylim()[0], ax2.get_ylim()[-1]
            ax2.text(x_lo + 0.02 * (x_hi - x_lo),
                     y_lo + 0.07 * (y_hi - y_lo),
                     r'$\mathrm{rms}_\mathrm{res} = %.1e$' % residuals_std,
                     fontsize=fsmain)

            for dpi in (300, 900, 1200):
                fig.savefig(os.path.join(
                    fitting_directory,
                    'detrended_model_{0}dpi.jpg'.format(dpi)),
                            dpi=dpi,
                            transparent=False)

            # NOTE(review): this return is inside the loop, so only set 0 is
            # plotted and saved.  The file names are not per-set either, so
            # moving the return out would make later sets overwrite earlier
            # files.  Confirm the intended multi-set behaviour.
            return fig