def doHistogramPlotting(self,
                            dataFrameIn,
                            dataFrameExtra,
                            freqVal='1min',
                            min_periodsVal=1,
                            Title="YourTitle",
                            YLabel="Y",
                            XLabel="Timestamp",
                            fileType="png",
                            typeWeight="EWMA"):
        """Specify the dateFrame that you want to plotted against its EMWA or rolling_mean using a histogram plot
        you require the parameters to shape the plot.

        The whole idea of this is to not aggregate but to do sums, this will possibly plot summations
        of readings, which is basically comparing the current to the expected sum of usage in that interval

        fileType: This is .png or .bmp ect which the self.histPlotDouble will produce for this method to return
        freqVal: this is the down sampling to make the intervals cover a large section of the data
        min_periodsVal: This is good to default to 1 if you using summations

        Return: This will attempt to return the file that stores the figured for the plot
        """
        LegendToSend = ""
        dfIN = dataFrameIn
        dfWeighted = dataFrameExtra

        dfIN = rs.Resampling().downsample_data_frame(data_frame=dfIN,
                                                     freq=freqVal,
                                                     method="sum")
        # TODO: derive this from freqVal (e.g. 10 x freq) rather than hard-coding, to keep the downsampling consistent
        dfIN = rs.Resampling().downsample_data_frame(data_frame=dfIN,
                                                     freq="10min",
                                                     method="mean")

        #dfWeighted = pd.DataFrame(dfWeighted,columns=("reading",))
        dfWeighted = rs.Resampling().downsample_data_frame(
            data_frame=dfWeighted, freq=freqVal, method="sum")

        # check typeWeight to decide which weighting system to use
        if typeWeight == "EWMA":
            dfWeighted = pl.Plotter().ewma_resampling(
                data_frame=dfWeighted,
                freq="10min",
                min_periods=min_periodsVal)
            LegendToSend = self.EWMAHeading
        else:
            dfWeighted = pl.Plotter().equal_weight_moving_average(
                data_frame=dfWeighted,
                freq="10min",
                min_periods=min_periodsVal)
            LegendToSend = self.EQHeading

        return self.histPlotDouble(dataFrameOriginal=dfIN,
                                   dataFrameWeighted=dfWeighted,
                                   LegendLabelWeighted=LegendToSend,
                                   Title=Title,
                                   YLabel=YLabel,
                                   XLabel=XLabel,
                                   fileType=fileType)
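The resampling and weighting steps above go through project helpers (rs.Resampling, pl.Plotter) whose internals are not shown here. As a minimal, hypothetical sketch of the same resample-then-EWMA idea using only standard pandas calls and made-up data:

import numpy as np
import pandas as pd

# Hypothetical minute-level meter readings
idx = pd.date_range("2021-01-01", periods=600, freq="1min")
readings = pd.DataFrame({"reading": np.random.rand(600)}, index=idx)

# Sum the readings into 1-minute buckets, then average those sums into 10-minute buckets
summed = readings.resample("1min").sum()
smoothed = summed.resample("10min").mean()

# Exponentially weighted moving average of the summed readings
ewma = summed["reading"].ewm(span=10, min_periods=1).mean()
print(smoothed.head())
print(ewma.head())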
Example #2
def confidence_intervals(x, y, z_flat, model, degree, alpha = 0, noise = 0):
    """
    Function for finding the estimated confidence intervals of a given models beta-parameters,
    and makes a plot of the parameters with confidence intervals corresponing to
    a 95% confidence interval.
    """
    X = create_design_matrix(x, y, degree)
    resample = Resampling(X, z_flat)
    betas, variance = resample.bootstrap(model, get_beta_var=True)

    CI = 1.96*np.sqrt(variance)


    #plotting
    plt.xticks(np.arange(0, len(betas), step=1))
    plt.errorbar(range(len(betas)), betas, CI, fmt="b.", capsize=3, label=r'$\beta_j \pm 1.96 \sigma$')
    plt.legend()
    plt.xlabel(r'index $j$')
    plt.ylabel(r'$\beta_j$')
    plt.grid()
    plt.show()
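As a standalone illustration of the same idea (the create_design_matrix and Resampling helpers above are project-specific and not shown), here is a hypothetical numpy-only sketch that estimates OLS coefficients, their analytic 95% confidence intervals, and plots them with error bars:

import numpy as np
import matplotlib.pyplot as plt

# Hypothetical data: y = X @ beta + noise
rng = np.random.default_rng(0)
n, p = 200, 5
X = np.column_stack([np.ones(n)] + [rng.normal(size=n) for _ in range(p - 1)])
beta_true = rng.normal(size=p)
y = X @ beta_true + rng.normal(scale=0.5, size=n)

# OLS estimate and analytic variance of the coefficients
beta_hat = np.linalg.pinv(X.T @ X) @ X.T @ y
residuals = y - X @ beta_hat
sigma2 = residuals @ residuals / (n - p)
var_beta = sigma2 * np.diag(np.linalg.pinv(X.T @ X))
ci = 1.96 * np.sqrt(var_beta)   # 95% confidence half-width

plt.errorbar(range(p), beta_hat, ci, fmt="b.", capsize=3,
             label=r'$\hat\beta_j \pm 1.96\,\hat\sigma_{\beta_j}$')
plt.xlabel(r'index $j$')
plt.ylabel(r'$\beta_j$')
plt.legend()
plt.grid()
plt.show()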
Example #3
def generate_heatmaps(X,
                      y,
                      model_name,
                      min_deg=1,
                      max_deg=10,
                      num_alpha=13,
                      min_alpha=-10,
                      max_alpha=2):
    """
    In general, this function uses the same concept as complexity and alpha_tests, but
    now, we study the three variables: model complexity, hyperparameters, and MSE
    altogether by making a heatmap.
    
    N.B: Vizualization of the heatmap is performed by using the 
    Python-library Altair. 
    """

    # Calling in the model and the variables used
    model = Regress(model_name)
    alpha_vals = np.linspace(min_alpha, max_alpha, num_alpha)
    degrees = np.linspace(min_deg, max_deg, max_deg - min_deg + 1)
    variables = pd.DataFrame(columns=['degrees', 'mse', 'log lambda', 'r2'])

    min_mse = 1e100
    min_r2 = 0

    i = 0

    # Run the loop over degrees and hyperparameters.
    # Due to runtime issues, tqdm is used to show progress in the
    # loops and verify that the program is functioning.
    for deg in degrees:
        j = 0
        resample = Resampling(X, y)

        for alpha in tqdm(alpha_vals):
            model.set_alpha(10**alpha)

            mse, bias, variance, r2, mse_train = resample.sklearns_kfold(model)
            variables = variables.append(
                {
                    'degrees': deg,
                    'log lambda': alpha,
                    'mse': mse,
                    'r2': r2
                },
                ignore_index=True)

            # track the configuration with the smallest error
            if mse < min_mse:
                min_mse = mse
                min_r2 = r2
                min_deg = deg
                min_alpha = alpha

            j += 1
        i += 1

    # Save values to a csv file, and pivot the variables
    variables.to_csv("test.csv")

    doc_raw = pd.read_csv('test.csv')
    doc_matrix = doc_raw.pivot("degrees", "log lambda", "r2")

    # Creating heat maps using Altair
    alt.renderers.enable('altair_viewer')

    chart = alt.Chart(variables).mark_rect().encode(x='degrees:O',
                                                    y='log lambda:O',
                                                    color='r2:Q')

    text = chart.mark_text(baseline='middle').encode(
        text=alt.Text('r2:Q', format=',.2r'))

    tot = chart + text

    # Directs you to Altair's local viewer
    # and gives you the opportunity to save the figure
    tot.show()

    print("R2:", min_r2)
    print("Optimal deg:", min_deg)
    print("Optimal alpha:", min_alpha)
    print(r2)
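For reference, a self-contained, hypothetical Altair sketch of the mark_rect + mark_text heatmap built above, saving to HTML instead of relying on the altair_viewer renderer, and using fabricated scores:

import altair as alt
import numpy as np
import pandas as pd

# Hypothetical grid of (degree, log lambda) -> r2 scores
records = [{'degrees': d, 'log lambda': l, 'r2': float(np.random.rand())}
           for d in range(1, 6) for l in range(-10, 3)]
df = pd.DataFrame(records)

chart = alt.Chart(df).mark_rect().encode(x='degrees:O',
                                         y='log lambda:O',
                                         color='r2:Q')
text = chart.mark_text(baseline='middle').encode(
    text=alt.Text('r2:Q', format=',.2r'))

(chart + text).save('heatmap.html')  # open in a browser to view or save the figure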
Example #4
def complexity(X, y, model_name, min_deg=1, max_deg=10, alpha=0):
    """
    Plots the training MSE against the test MSE in an interval of
    polynomial degrees (model complexity) from 1 to 10 as a default with desired
    regression method. It also plots the bias-variance tradeoff.
    
    Inputs:
        x and y: the coordinates used to define the design matrix.
        z = the target value
        model_name: Name of regression method: "OLS", "Ridge" or "Lasso"
        min_deg: minimum degree; max_deg: maximum degree
        alpha = Hyperparameter value
        
    
    N.B to change resampling type: Edit line 77 to desired method in Resampling.py
    """

    # Create a pandas dataframe to store error metrics.
    errors = pd.DataFrame(
        columns=['degrees', 'mse', 'bias', 'variance', 'r2', 'mse_train'])

    #initialize regression model and arrays for saving error values
    model = Regress(model_name, alpha=alpha)
    degrees = np.linspace(min_deg, max_deg, max_deg - min_deg + 1)
    num_deg = len(degrees)
    mse = np.zeros(num_deg)
    mse_train = np.zeros(num_deg)
    bias = np.zeros(num_deg)
    variance = np.zeros(num_deg)
    r2 = np.zeros(num_deg)

    # Initialize the optimal error metrics
    min_mse = 1e100
    min_r2 = 0
    min_deg = 0
    i = 0

    #loop through the specified degrees to be analyzed
    for deg in degrees:
        resample = Resampling(X, y)

        #perform k-fold cross-validation (via sklearn) and save error values
        mse[i], bias[i], variance[i], r2[i], mse_train[i] = resample.sklearns_kfold(model)

        #save to pandas dataframe
        errors = errors.append(
            {
                'degrees': degrees[i],
                'mse': mse[i],
                'bias': bias[i],
                'variance': variance[i],
                'r2': r2[i],
                'mse_train': mse_train[i]
            },
            ignore_index=True)

        # track the configuration with the smallest test error
        if mse[i] < min_mse:
            min_mse = mse[i]
            min_r2 = r2[i]
            min_deg = deg

        i += 1

    #plot error of test set and training set
    """
    plt.title("Bootstrap MSE, OLS (Franke function)")
    plt.plot(degrees, mse, label='test set')
    plt.plot(degrees, mse_train, label='training set')
    plt.legend()
    plt.xlabel('Model complexity [degree]')
    plt.ylabel('Mean Squared Error')
    plt.show()
    

    #plot bias^2 variance decomposition of the test error
    plt.title("Bias-variance tradeoff, Bootstrap")
    plt.plot(degrees, mse, label='mse')
    plt.plot(degrees, bias,'--', label='bias')
    plt.plot(degrees, variance, label='variance')
    plt.xlabel('Model complexity [degree]')
    plt.ylabel('Mean Squared Error')
    plt.legend()
    plt.show()
    """
    print('min mse:', min_mse)
    print('r2:', min_r2)
    print('deg:', min_deg)
    print(r2)
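The Regress and Resampling helpers used above are project-specific; the following hypothetical scikit-learn sketch shows the same train/test MSE vs. polynomial degree comparison on synthetic data:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import PolynomialFeatures

rng = np.random.default_rng(0)
x = rng.uniform(-1, 1, size=(300, 1))
y = np.sin(3 * x[:, 0]) + rng.normal(scale=0.2, size=300)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.3, random_state=0)

degrees = range(1, 11)
mse_test, mse_train = [], []
for deg in degrees:
    # polynomial features + OLS for each model complexity
    model = make_pipeline(PolynomialFeatures(deg), LinearRegression())
    model.fit(x_train, y_train)
    mse_train.append(mean_squared_error(y_train, model.predict(x_train)))
    mse_test.append(mean_squared_error(y_test, model.predict(x_test)))

plt.plot(degrees, mse_test, label='test set')
plt.plot(degrees, mse_train, label='training set')
plt.xlabel('Model complexity [degree]')
plt.ylabel('Mean Squared Error')
plt.legend()
plt.show()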
Example #5
def alpha_tests(X, y, model_name, min_alpha=-10, max_alpha=2, num_alpha=13):
    """
    Plots the MSE for a fixed polynomial degree against the hyperparameter values.

    Input values remain mostly the same as for the function above.
    """

    #Store the errors in pandas dataframe
    errors = pd.DataFrame(
        columns=['log lambda', 'mse', 'bias', 'variance', 'r2', 'mse_train'])

    model = Regress(model_name)
    alpha_vals = np.linspace(min_alpha, max_alpha, num_alpha)

    # Again, initialize values
    mse = np.zeros(num_alpha)
    mse_train = np.zeros(num_alpha)
    bias = np.zeros(num_alpha)
    variance = np.zeros(num_alpha)
    r2 = np.zeros(num_alpha)

    # Loop to perform regression analysis with different
    # hyperparameter values

    i = 0
    for alpha in alpha_vals:
        resample = Resampling(X, y)
        model.set_alpha(10**alpha)

        mse[i], bias[i], variance[i], r2[i], mse_train[i] = resample.bootstrap(
            model)

        errors = errors.append(
            {
                'log lambda': alpha_vals[i],
                'mse': mse[i],
                'bias': bias[i],
                'variance': variance[i],
                'r2': r2[i],
                'mse_train': mse_train[i]
            },
            ignore_index=True)

        i += 1

    # Plotting with MSE and R2-score together.
    fig, ax1 = plt.subplots()
    color = "tab:green"

    #plt.hlines(0.955,-10,0,linestyles="dashed",label="OLS R2-score")
    #plt.hlines(0.004,-10,0,linestyles="dashed",label="OLS MSE",colors="b")
    #plt.legend()

    ax1.set_xlabel(r'$log_{10}\lambda$')
    ax1.set_ylabel("MSE", color=color)
    ax1.plot(alpha_vals, mse, color=color, label="MSE")
    ax1.tick_params(axis="y", labelcolor=color)

    ax2 = ax1.twinx()

    color = "tab:red"
    ax2.set_ylabel("R2-score", color=color)
    ax2.plot(alpha_vals, r2, color=color, label="R2")
    ax2.tick_params(axis="y", labelcolor=color)

    fig.tight_layout()
    plt.title("Error metrics, Ridge")

    plt.show()
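A hypothetical, self-contained scikit-learn version of the same hyperparameter sweep, plotting MSE and R2 on twin y-axes as above, with made-up data:

import numpy as np
import matplotlib.pyplot as plt
from sklearn.linear_model import Ridge
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(300, 10))
y = X @ rng.normal(size=10) + rng.normal(scale=0.5, size=300)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.3, random_state=0)

log_alphas = np.linspace(-10, 2, 13)
mse, r2 = [], []
for log_alpha in log_alphas:
    # fit Ridge for each value of log10(lambda) and score on the held-out set
    model = Ridge(alpha=10**log_alpha).fit(X_tr, y_tr)
    pred = model.predict(X_te)
    mse.append(mean_squared_error(y_te, pred))
    r2.append(r2_score(y_te, pred))

fig, ax1 = plt.subplots()
ax1.set_xlabel(r'$\log_{10}\lambda$')
ax1.set_ylabel('MSE', color='tab:green')
ax1.plot(log_alphas, mse, color='tab:green')
ax1.tick_params(axis='y', labelcolor='tab:green')
ax2 = ax1.twinx()
ax2.set_ylabel('R2-score', color='tab:red')
ax2.plot(log_alphas, r2, color='tab:red')
ax2.tick_params(axis='y', labelcolor='tab:red')
fig.tight_layout()
plt.show()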
Example #6
def model_degree_analysis(x, y, z, model_name, min_deg=1, max_deg=10, n_bootstraps = 100, alpha = 0, ID = '000'):
    """
    Function for analyzing the performance of a model at different model complexities.
    Performs bootstrap resampling for each configuration.
    Plots the training and test MSE for each configuration,
    and the bias^2 - variance decomposition of the test error.
    The error scores for each configuration are also saved to a .csv file.


    Inputs:
    -x, y values, dimensions (n, n)
    -z values, dimension (n^2, 1)
    -model_name, name of the model: 'ols', 'ridge', 'lasso'
    -min_deg, max_deg - degrees to analyze
    -n_bootstraps - number of resamples in bootstrap
    -alpha - hyperparameter in 'ridge' and 'lasso'
    -ID - figure IDs
    """


    #setup directories
    dat_filename = 'results/' + 'error_scores_deg_analysis_' + model_name
    fig_filename = 'figures/' + 'deg_analysis_' + model_name
    error_scores = pd.DataFrame(columns=['degree', 'mse', 'bias', 'variance', 'r2', 'mse_train'])

    #initialize regression model and arrays for saving error values
    model = RegressionMethods(model_name, alpha=alpha)
    degrees = np.linspace(min_deg, max_deg, max_deg - min_deg + 1)
    nDegs = len(degrees)
    mse = np.zeros(nDegs)
    bias = np.zeros(nDegs)
    variance = np.zeros(nDegs)
    r2 = np.zeros(nDegs)
    mse_train = np.zeros(nDegs)


    min_mse = 1e100
    min_r2 = 0
    min_deg = 0
    i = 0

    #loop through the specified degrees to be analyzed
    for deg in degrees:
        X = create_design_matrix(x, y, int(deg))
        resample = Resampling(X, z)

        #perform bootstrap resampling and save error values
        mse[i], bias[i], variance[i], mse_train[i], r2[i] = resample.bootstrap(model, n_bootstraps)

        #save to pandas dataframe
        error_scores = error_scores.append({'degree': degrees[i],
                                            'mse': mse[i],
                                            'bias': bias[i],
                                            'variance': variance[i],
                                            'r2': r2[i],
                                            'mse_train': mse_train[i]}, ignore_index=True)

        #check if this configuration gives smallest error
        if mse[i] < min_mse:
            min_mse = mse[i]
            min_r2 = r2[i]
            min_deg = deg


        i += 1
    #end for


    #plot error of test set and training set
    plt.plot(degrees, mse, label='test set')
    plt.plot(degrees, mse_train, label='training set')
    plt.legend()
    plt.xlabel('Model complexity [degree]')
    plt.ylabel('Mean Squared Error')
    plt.savefig(fig_filename + '_test_train_' + ID + '.pdf')
    plt.show()

    #plot bias^2 variance decomposition of the test error
    plt.plot(degrees, mse, label='mse')
    plt.plot(degrees, bias,'--', label='bias')
    plt.plot(degrees, variance, label='variance')
    plt.xlabel('Model complexity [degree]')
    plt.ylabel('Mean Squared Error')
    plt.legend()
    plt.savefig(fig_filename + '_bias_variance_' + ID + '.pdf')
    plt.show()


    #save error scores to file
    error_scores.to_csv(dat_filename + '.csv')
    print('min mse:', min_mse)
    print('r2:', min_r2)
    print('deg:', min_deg)
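The Resampling.bootstrap call above is project code; as a hypothetical numpy-only sketch, this is one common way to estimate the bias^2/variance decomposition of the test error by bootstrap-resampling the training set:

import numpy as np

def bootstrap_bias_variance(X_train, y_train, X_test, y_test, fit_predict, n_bootstraps=100):
    """Resample the training set with replacement, refit each time, and decompose the test error."""
    rng = np.random.default_rng(0)
    preds = np.empty((n_bootstraps, len(y_test)))
    for b in range(n_bootstraps):
        idx = rng.integers(0, len(y_train), len(y_train))   # bootstrap sample of training rows
        preds[b] = fit_predict(X_train[idx], y_train[idx], X_test)
    mse = np.mean((y_test[None, :] - preds) ** 2)
    bias2 = np.mean((y_test - preds.mean(axis=0)) ** 2)
    variance = np.mean(preds.var(axis=0))
    return mse, bias2, variance

# Example with ordinary least squares as the model
def ols_fit_predict(Xtr, ytr, Xte):
    beta = np.linalg.pinv(Xtr.T @ Xtr) @ Xtr.T @ ytr
    return Xte @ beta

rng = np.random.default_rng(1)
X = rng.normal(size=(400, 5))
y = X @ rng.normal(size=5) + rng.normal(scale=0.3, size=400)
mse, bias2, var = bootstrap_bias_variance(X[:300], y[:300], X[300:], y[300:], ols_fit_predict)
print('mse:', mse, 'bias^2:', bias2, 'variance:', var)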
Example #7
def ridge_lasso_complexity_analysis(x, y, z, model_name, min_deg=1, max_deg=10, n_lambdas=13, min_lamb=-10, max_lamb=2, ID = '000'):
    """
    Function for analyzing ridge or lasso model performance for
    different values of lambda and model complexity.
    Performs bootstrap resampling for each configuration and plots
    a heat map of the error for each configuration. Error scores are
    also saved to a .csv file.

    Inputs:
    -x, y values, dimensions (n, n)
    -z values, dimension (n^2, 1)
    -model_name, name of the model: 'ridge', 'lasso'
    -min_deg, max_deg - degrees to analyze
    -min_lamb, max_lamb, n_lambdas - values of log_10(lambda) to analyze, and how many
    -ID - figure IDs
    """


    #initialize model and arrays for parameters lambda and complexity
    model = RegressionMethods(model_name)
    lambdas = np.linspace(min_lamb, max_lamb, n_lambdas)
    degrees = np.linspace(min_deg, max_deg, max_deg - min_deg + 1)

    #setup directories
    dat_filename = 'results/' + 'error_scores_' + model_name
    fig_filename = 'figures/' + 'min_mse_heatmap_' + model_name
    error_scores = pd.DataFrame(columns=['degree', 'log lambda', 'mse', 'bias', 'variance', 'r2', 'mse_train'])


    min_mse = 1e100
    min_lambda = 0
    min_degree = 0
    min_r2 = 0


    i = 0
    #loop through specified degrees
    for deg in degrees:
        j = 0
        X = create_design_matrix(x, y, int(deg))
        resample = Resampling(X, z)
        #loop through specified lambdas
        for lamb in tqdm(lambdas):
            model.set_alpha(10**lamb)

            #perform resampling
            mse, bias, variance, mse_train, r2 = resample.bootstrap(model, n_bootstraps=10)

            #save error scores in pandas dataframe
            error_scores = error_scores.append({'degree': deg,
                                                'log lambda': lamb,
                                                'mse': mse,
                                                'bias': bias,
                                                'variance': variance,
                                                'r2': r2,
                                                'mse_train': mse_train}, ignore_index=True)

            #check if current configuration gives minimal error
            if mse < min_mse:
                min_mse = mse
                min_lambda = lamb
                min_degree = deg
                min_r2 = r2

            j+=1
        #end for lambdas
        i+=1
    #end for degrees

    print('min mse:', min_mse)
    print('min r2:', min_r2)
    print('degree:', min_degree)
    print('lambda:', min_lambda)


    #save scores to file
    error_scores.to_csv(dat_filename + '.csv')



    #plot heat map of error scores of each configuration
    mse_table = pd.pivot_table(error_scores, values='mse', index=['degree'], columns='log lambda')
    idx_i, idx_j = np.argwhere(mse_table.values == min_mse)[0]

    fig = plt.figure()
    ax = sns.heatmap(mse_table, annot=True, fmt='.2g', cbar=True, linewidths=1, linecolor='white',
                            cbar_kws={'label': 'Mean Squared Error'})
    ax.add_patch(Rectangle((idx_j, idx_i), 1, 1, fill=False, edgecolor='red', lw=3))

    ax.set_xlabel(r"$\log_{10}(\lambda)$")
    ax.set_ylabel("Complexity")
    ax.set_ylim(len(degrees), 0)
    plt.show()
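A hypothetical, self-contained sketch of the same heat-map step: pivot a score table and mark the minimum-MSE cell with a red rectangle using seaborn and matplotlib, on fabricated scores:

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

# Hypothetical error scores over a (degree, log lambda) grid
rng = np.random.default_rng(0)
scores = pd.DataFrame([{'degree': d, 'log lambda': l, 'mse': rng.random()}
                       for d in range(1, 6) for l in range(-10, 3)])

mse_table = pd.pivot_table(scores, values='mse', index=['degree'], columns='log lambda')
i, j = np.argwhere(mse_table.values == mse_table.values.min())[0]

ax = sns.heatmap(mse_table, annot=True, fmt='.2g', linewidths=1, linecolor='white',
                 cbar_kws={'label': 'Mean Squared Error'})
ax.add_patch(Rectangle((j, i), 1, 1, fill=False, edgecolor='red', lw=3))
ax.set_xlabel(r'$\log_{10}(\lambda)$')
ax.set_ylabel('Complexity')
plt.show()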
Example #8
    def plot_single_frame_unusual(self,data_frame,freq="1min", title=None, legend=None, y_label=None, x_label=None, file_name=None,
                          prediction=False,file_type="png"):
        """Specify the dataFrame that you want to plot, single dataFrame
        title: The string title for the plot
        y_label: The label for the y axis
        x_label: The label for the x axis
        file_name: file name you want to save as or None for a byteIO
        file_type: the file type you want your fig saved as using byteIO
        prediction: Make this true if plotting a prediction plot
        Return: return the figure or a byteIO of the figure
        """
        # Review comments: How does this end up in the plt object? Don't have docs to check

        if data_frame is None:
            raise ValueError('Invalid DataFrame, please pass a DataFrame with actual data')

        try:
            data_frame = rs.Resampling().downsample_data_frame(data_frame, freq)
            data_frameWeighted = self.ewma_resampling(data_frame=data_frame, freq=freq,min_periods=1)
        except Exception:
            raise ValueError("One of the parameters is incorrect")


        #data_frame.reading.plot(label=legend, color=self.GraphColorGreen,sharex=True,marker="o")
            # This is to handle plotting forecast data
        if title:
            plt.title(title)
        if y_label:
            plt.ylabel(y_label)
        if x_label:
            plt.xlabel(x_label)
        if legend:
            plt.legend()
        data_frameWeighted = pd.DataFrame(data_frameWeighted,columns=("reading",))
        #data_frameWeighted.reading.plot()

        std = rs.Resampling().get_frame_std_dev(data_frame=data_frameWeighted)
        toRed = std*3
        toOrange = std*1.75

        means = data_frameWeighted.loc[:, "reading"]
        true_values = data_frame.loc[:, "reading"]



        #single.reading.plot(color=self.AlertOrange,marker='o',markerfacecolor=self.AlertOrange)
        single = []
        for x in range(len(true_values)):
            difference = abs(true_values.iloc[x] - means.iloc[x])
            if difference < toOrange:
                plottedColor = self.AlertGreen
            elif difference >= toRed:
                plottedColor = self.AlertRed
            else:
                plottedColor = self.AlertOrange
            single.append(plottedColor)

        data_frame.reading.plot(kind='bar',label=legend,sharex=True,color=single,grid=False)
        #data_frame.reading.bar(label=legend,color= self.GraphColorGreen)
        #data_frame.reading.plot(label=legend,color= self.GraphColorGreen)
        #data_frame.reading.plot(label=legend)
        plt.ylabel("Usage(kwh)")
        plt.xlabel("Time")
        plt.title("unusual plot")
        ##print(data_frame["reading"])
        #plt.scatter(data_frame.index,data_frame["reading"],marker="o")
        #data_frame["reading",0].plot(style='.')
        if not file_name:   # if file_name is None
            if file_type not in ("bmp", "png", "jpg"):
                file_type = "png"
            buf = io.BytesIO()
            plt.savefig(buf, format=file_type, bbox_inches="tight", dpi=300, facecolor="w", edgecolor="g")
            buf.seek(0)
            im = Image.open(buf)
            # possibly should use save fig
            return im
        else:
            file_name_split = file_name.split(".")
            if (file_name_split[-1] == "svg") or (file_name_split[-1] == "png") or (file_name_split[-1] == "jpg"):
                return plt.savefig(file_name)
            else:
                return None
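A hypothetical standalone sketch of the colour-coding rule used above: bars are coloured green/orange/red according to how far each reading deviates from its EWMA, in multiples of the standard deviation (the data and thresholds here are made up, the thresholds mirror the 1.75x and 3x factors in the method):

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

idx = pd.date_range("2021-01-01", periods=48, freq="30min")
readings = pd.Series(np.random.default_rng(0).gamma(2.0, 1.0, size=48), index=idx)

# Expected usage: an EWMA of the readings; thresholds in units of its standard deviation
expected = readings.ewm(span=6, min_periods=1).mean()
std = expected.std()
deviation = (readings - expected).abs()

def pick_color(d):
    # green: within 1.75 std of the EWMA, red: beyond 3 std, orange: in between
    if d < 1.75 * std:
        return "green"
    if d >= 3 * std:
        return "red"
    return "orange"

colors = [pick_color(d) for d in deviation]
readings.plot(kind="bar", color=colors, grid=False)
plt.ylabel("Usage (kWh)")
plt.xlabel("Time")
plt.title("Unusual usage plot")
plt.show()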
Example #9
S2_zip = tkFileDialog.askopenfilenames(
    parent=root,
    title='Choose at least two Sentinel 2 products',
    filetypes=(("S2 Zip File", "*.zip"), ))
zip_path_list = list(S2_zip)

work_dir = raw_input("Choose name for working directory: ")
cwd = os.getcwd() + '/' + work_dir
os.mkdir(cwd)
os.mkdir(cwd + jp2_images_dir)
os.mkdir(cwd + resampled_images_dir)
os.mkdir(cwd + mosaiced_image_dir)
os.mkdir(cwd + masked_image_dir)
os.mkdir(cwd + stacked_bands_dir)

rs.resampling(cwd + jp2_images_dir, cwd + resampled_images_dir, zip_path_list)
# work_dir = 'Prova'
# cwd = os.getcwd() + '/' + work_dir
mi.mosaic_images(cwd + resampled_images_dir, cwd + mosaiced_image_dir)
# work_dir = 'Prova'
# cwd = os.getcwd() + '/' + work_dir
geojson_path = tkFileDialog.askopenfilenames(
    parent=root,
    title='Choose a valid geoJSON file',
    filetypes=(("GeoJSON File", "*.json"), ))
mki.mask_image(cwd + mosaiced_image_dir, cwd + masked_image_dir,
               geojson_path[0])
# work_dir = 'Prova'
# cwd = os.getcwd() + '/' + work_dir
ms.stack_bands(cwd + masked_image_dir, cwd + stacked_bands_dir)
# shutil.rmtree(cwd + jp2_images_dir)
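This excerpt uses Python 2 idioms (tkFileDialog, raw_input) and assumes root and the *_dir names are defined elsewhere. A hypothetical Python 3 sketch of the same dialog and directory setup, with made-up sub-directory names, might look like this:

import os
import tkinter as tk
from tkinter import filedialog
from pathlib import Path

root = tk.Tk()
root.withdraw()  # hide the empty main window

zip_paths = list(filedialog.askopenfilenames(
    parent=root,
    title='Choose at least two Sentinel 2 products',
    filetypes=(("S2 Zip File", "*.zip"),)))

work_dir = input("Choose name for working directory: ")
cwd = Path(os.getcwd()) / work_dir

# Hypothetical sub-directory names; the original script defines its own
for sub in ("jp2_images", "resampled_images", "mosaiced_image",
            "masked_image", "stacked_bands"):
    (cwd / sub).mkdir(parents=True, exist_ok=True)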
Example #10
    def histPlotDouble(self,
                       dataFrameOriginal,
                       dataFrameWeighted,
                       LegendLabelOriginal="Current reading",
                       LegendLabelWeighted="weighted plot",
                       Title="YourTitle",
                       YLabel="Y",
                       XLabel="Timestamp",
                       fileType="png",
                       file_name=None):
        """Specify the dataFrame that you want to plot, a weighted vs actual readings currently

        Title: The title for the plot : "THE BEST PLOT EVER"
        YLabel: The label for the y axis : reading(kWh)
        XLabel: The label for the x axis : Time Stamp
        fileType: The type of image you looking for
        Labels: These are for your plots so you can have a lengend which is useful

        Return: a image generated by the figure and combination of the ByteIO
        """

        if not LegendLabelOriginal:
            LegendLabelOriginal = "current readings"
        if not LegendLabelWeighted:
            LegendLabelWeighted = "weighted readings"
        if not XLabel:
            XLabel = "X"
        if not YLabel:
            YLabel = "Y"
        if not Title:
            Title = "The plot"

        #print(dataFrameOriginal)
        dataFrameWeighted = pd.DataFrame(dataFrameWeighted,
                                         columns=("reading", ))
        mino, maxo = rs.Resampling().get_max_value_in_frame(dataFrameOriginal)
        minw, maxw = rs.Resampling().get_max_value_in_frame(
            dataFrameWeighted)  # the issue is indexing in a weighted df

        maxv = max(maxw, maxo) + 5
        minv = min(minw, mino) - 1

        if (maxw > maxo):
            ori = dataFrameOriginal.reading.plot(label=LegendLabelOriginal,
                                                 legend=LegendLabelOriginal,
                                                 kind="bar",
                                                 color=self.coloredFirst)
            wei = dataFrameWeighted.reading.plot(label=LegendLabelWeighted,
                                                 legend=LegendLabelWeighted,
                                                 kind="bar",
                                                 color=self.coloredSecond,
                                                 stacked=True)
        else:
            wei = dataFrameWeighted.reading.plot(label=LegendLabelWeighted,
                                                 legend=LegendLabelWeighted,
                                                 kind="bar",
                                                 color=self.coloredFirst)
            ori = dataFrameOriginal.reading.plot(label=LegendLabelOriginal,
                                                 legend=LegendLabelOriginal,
                                                 kind="bar",
                                                 color=self.coloredSecond,
                                                 stacked=True)

        plt.legend()
        plt.ylim([minv, maxv])
        plt.ylabel(YLabel)
        plt.xlabel(XLabel)
        plt.title(Title)

        if not file_name:  # if file_name is None
            if fileType not in ("bmp", "png", "jpg"):
                fileType = "png"
            buf = io.BytesIO()
            plt.savefig(buf,
                        format=fileType,
                        bbox_inches="tight",
                        dpi=300,
                        facecolor="w",
                        edgecolor="g")
            buf.seek(0)
            im = Image.open(buf)
            # possibly should use save fig
            return im
        else:
            file_name_split = file_name.split(".")
            if (file_name_split[-1]
                    == "svg") or (file_name_split[-1]
                                  == "png") or (file_name_split[-1] == "jpg"):
                return plt.savefig(file_name)
            else:
                return None
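A minimal, hypothetical sketch of the BytesIO round-trip used at the end of both plotting methods: render the current matplotlib figure into an in-memory buffer and hand it back as a PIL image.

import io
import matplotlib.pyplot as plt
from PIL import Image

def figure_to_image(file_type="png"):
    """Render the current matplotlib figure to a PIL image via an in-memory buffer."""
    if file_type not in ("bmp", "png", "jpg"):
        file_type = "png"
    buf = io.BytesIO()
    plt.savefig(buf, format=file_type, bbox_inches="tight", dpi=300)
    buf.seek(0)
    return Image.open(buf)

plt.plot([1, 2, 3], [2, 4, 8])
im = figure_to_image()
print(im.size)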