示例#1
0
    def __init__(self, filepath: str,
                 new_shape: Optional[tuple] = None,
                 target_file_ext: Optional[str] = None,
                 model_type: str = 'classification', multilabel_classification: bool = False, label_sep: Optional[str] = None,
                 manual_labels: Optional[pd.DataFrame] = None, manual_labels_fileNameVar: Optional[str] = None, manual_labels_labelVar: Optional[str] = None,
                 x_scaling: str = "none", x_min_max_range: Optional[list] = None, lower_triangular_padding: Optional[int] = None,
                 resmaple_method: str = "random",
                 training_percentage: float = 0.8,
                 verbose: bool = True, random_state: int = 1):
        """Initialization.

        Store the data-loading settings on the instance and build the file and
        label annotations through ``self._get_file_annot()``.

        Args:
            filepath: Stored as ``self.filepath``.
            new_shape: Stored as ``self.new_shape``.
            target_file_ext: Stored as ``self.target_ext``.
            model_type: ``'classification'`` enables the ``label_sep`` checks;
                ``'semisupervised'`` sets ``self.semi_supervised = True``.
            multilabel_classification: Stored as ``self.multilabel_class``.
                When True together with ``model_type='classification'``,
                ``label_sep`` is required.
            label_sep: Separator for multilabel labels. In every other
                configuration it is ignored and ``self.label_sep`` is None.
            manual_labels, manual_labels_fileNameVar, manual_labels_labelVar:
                Stored unchanged on the instance.
            x_scaling: Stored as ``self.x_scaling``.
            x_min_max_range: Stored as ``self.x_min_max_range``. Defaults to a
                fresh ``[0, 1]`` list per instance.
            lower_triangular_padding: None or an int.
            resmaple_method: (sic; name kept for backward compatibility)
                stored as ``self.resample_method``.
            training_percentage: Stored as ``self.train_percentage``;
                ``self.test_percentage`` is ``1 - training_percentage``.
            verbose: Stored as ``self.verbose``.
            random_state: Stored as ``self.rand``.

        Raises:
            ValueError: multilabel classification without ``label_sep``, or a
                non-int, non-None ``lower_triangular_padding``.
        """
        # - argument check -
        # label_sep is only meaningful for multilabel classification.
        if model_type == 'classification' and multilabel_classification:
            if label_sep is None:
                raise ValueError(
                    'set label_sep for multilabel classification.')
            self.label_sep = label_sep
        else:
            if model_type == 'classification' and label_sep is not None:
                warn('label_sep ignored when multilabel_class=False')
            # Always define the attribute, including for non-classification
            # model types (previously it was left unset there).
            self.label_sep = None

        # - model information -
        self.model_type = model_type
        self.multilabel_class = multilabel_classification
        self.filepath = filepath
        self.target_ext = target_file_ext
        self.manual_labels = manual_labels
        self.manual_labels_fileNameVar = manual_labels_fileNameVar
        self.manual_labels_labelVar = manual_labels_labelVar
        self.new_shape = new_shape
        self.semi_supervised = model_type == 'semisupervised'

        # - processing -
        self.x_scaling = x_scaling
        # None sentinel avoids sharing one mutable default list across instances.
        self.x_min_max_range = [0, 1] if x_min_max_range is None else x_min_max_range
        if lower_triangular_padding is not None and not isinstance(lower_triangular_padding, int):
            raise ValueError(
                'lower_triangular_padding needs to be an int if not None.')
        self.lower_triangular_padding = lower_triangular_padding

        # - resampling -
        self.resample_method = resmaple_method
        self.train_percentage = training_percentage
        self.test_percentage = 1 - training_percentage

        # - random state and other settings -
        self.rand = random_state
        self.verbose = verbose

        # - load paths -
        # _get_file_annot is defined elsewhere in the class; it must return
        # these six items in this order. `lables_count` (sic) is kept because
        # other code may read it under that name.
        self.filepath_list, self.labels_list, self.lables_count, self.labels_map, self.labels_map_rev, self.encoded_labels = self._get_file_annot()
示例#2
0
    def __init__(self,
                 initial_x_shape,
                 y_len,
                 bottleneck_dim,
                 output_n,
                 output_activation='softmax',
                 multilabel=False):
        """
        # Purpose:\n
            Build a small two-block CNN classifier with the Keras functional API
            and store the resulting model as self.m.\n
        # Arguments:\n
            initial_x_shape: shape passed to tf.keras.Input (stored as self.initial_x_shape).\n
            y_len: length of y (stored as self.y_len; not used to build the layers here).\n
            bottleneck_dim: units of the bottleneck Dense layer.\n
            output_n: units of the output Dense layer.\n
            output_activation: activation of the output layer. Replaced by 'sigmoid'
                (with a warning) when multilabel=True and 'softmax' was requested.\n
            multilabel: whether the model is a multilabel classifier.\n
        # Details:\n
            - Use "softmax" for binary or mutually exclusive multiclass modelling,
                and use "sigmoid" for multilabel classification.\n
            - y_len: this is the length of y.
                y_len = 1 or 2: binary classification.
                y_len >= 2: multiclass or multilabel classification
                (the two cases overlap at 2; the `multilabel` flag disambiguates).\n
        """
        # super(CnnClassifierFuncAPI, self).__init__()
        # (not called: this class wraps a functional model in self.m rather than
        # building layers on itself)

        # -- initialization and argument check--
        self.initial_x_shape = initial_x_shape
        self.y_len = y_len
        self.bottleneck_dim = bottleneck_dim
        self.multilabel = multilabel
        self.output_n = output_n
        if multilabel and output_activation == 'softmax':
            warn(
                'Activation automatically set to \'sigmoid\' for multilabel classification.'
            )
            self.output_activation = 'sigmoid'
        else:
            self.output_activation = output_activation

        # -- CNN model --
        # NOTE: layers are created in a fixed order; Keras auto-generates layer
        # names from creation order, so reordering would rename layers.
        # Shape comments use H x W for the input's spatial size
        # (e.g. a 90x90 input yields 9x9x8 after the second pooling).
        model_input = tf.keras.Input(shape=self.initial_x_shape)
        # CNN encoding sub layers
        x = Conv2D(16, (3, 3),
                   activation='relu',
                   kernel_regularizer=tf.keras.regularizers.l2(l2=0.01),
                   padding='same')(model_input)  # output: H, W, 16 ('same' keeps H, W)
        x = BatchNormalization()(x)
        x = LeakyReLU()(x)
        x = MaxPooling2D((2, 2))(x)  # output: H//2, W//2, 16
        x = Conv2D(8, (3, 3),
                   activation='relu',
                   padding='same',
                   name='last_conv')(x)  # output: H//2, W//2, 8
        x = BatchNormalization()(x)
        # Explicitly named so it can be fetched by name (presumably for Grad-CAM,
        # as the name suggests — TODO confirm against the caller).
        x = LeakyReLU(name='grads_cam_dense')(x)
        # x = MaxPooling2D((2, 2))(x)  # output: 7, 7, 8
        x = MaxPooling2D((5, 5))(x)  # output: (H//2)//5, (W//2)//5, 8 (Keras default stride = pool size, 'valid' padding)
        x = Flatten()(x)  # ((H//2)//5) * ((W//2)//5) * 8 features
        x = Dense(self.bottleneck_dim,
                  activation='relu',
                  activity_regularizer=tf.keras.regularizers.l2(l2=0.01))(x)
        # NOTE(review): the Dense above already applies relu (outputs >= 0), so
        # this LeakyReLU acts as an identity here.
        x = LeakyReLU()(x)
        x = Dense(self.output_n, activation=self.output_activation)(x)

        # The assembled functional model; prediction/training go through self.m.
        self.m = Model(model_input, x)
示例#3
0
def csvRowSum(file: str, header: bool = True, col_idx: Union[int, str, None] = None, logger: MyLogger = None, verbose=True):
    """
    # Purpose:\n
        Sum the values of a CSV file over all of its data rows, giving one total
        per (selected) column.\n

    # Arguments:\n
        file: `str`. A csv file.\n
        header: `bool`. If the input file has a header. The header line is not summed.\n
        col_idx: `int`, `str` or `None`. Index for the columns to sum.\n
        logger: `MyLogger` or `None`. A `MyLogger` object for custom logging.\n
        verbose: `bool`. Verbose or not.\n

    # Return:\n
        `np.ndarray` of floats with one total per selected column. Cells that
        cannot be converted to float are skipped with a warning. An empty file
        gives an empty array.\n

    # Details:\n
        - `col_idx` can take many forms.
            - single int, e.g. `0`, `2`, `-3`
            - list of int, e.g. `[2, 3, 6]`
            - single string, e.g. `'2'`, `'2:6'`, `"2:"`, `":2"`, `"-1"`, `"label1"`
            - list of string, e.g. `["2", "3", "4"]`, `["label1", "label2", "label3"]`
            - Note, the following are NOT SUPPORTED: `["1, 2, 3"]`, or  `"label1label2lable3"` (three labels, will be treated as one label: `label1label2lable3`).
            However, int strings like `"123"` are supported, e.g. `"123"` as three ints.\n
        - When `col_idx=None`, all columns are going to be used.\n
        - The number of columns is taken from the first line of the file.\n
    """
    if verbose:
        if logger:
            logger.info(f'Calculating row sum for file: {file}')
        else:
            print(f"Calculating row sum for file {file}...")

    # newline='' is what the csv module expects for files passed to csv.reader.
    with open(file, 'r', newline='') as f:
        reader = csv.reader(f)
        first_row = next(reader, None)

        if first_row is None:
            # Empty file: there are no columns to sum.
            o = np.zeros(0)
        else:
            # The first line fixes the number of columns before selection.
            ncol = len(first_row)
            # `is not None` so that col_idx=0 selects column 0 instead of all columns.
            if col_idx is not None:
                ncol = len(string_flex(np.arange(ncol), col_idx))
            o = np.zeros(ncol)

            # Without a header, the first line is data and must be summed too.
            data_rows = reader if header else itertools.chain([first_row], reader)

            for i, row in enumerate(tqdm(data_rows, unit=' line')):
                r = string_flex(row, col_idx)
                for j, v in enumerate(r):
                    try:
                        o[j] += float(v)
                    except (ValueError, TypeError, IndexError):
                        # Non-numeric cell, or a row wider than the first line:
                        # skip the value but report it (best-effort summing).
                        if logger:
                            logger.warning(f'value skipped: {i} row, {j} column')
                        else:
                            warn(f'value skipped: {i} row, {j} column')

    if verbose:
        if logger:
            logger.info(f'Row sum calculated for file: {file}')
        else:
            print("Done!")

    return o
示例#4
0
    def __init__(self,
                 initial_x_shape,
                 y_len,
                 bottleneck_dim,
                 output_n,
                 output_activation='softmax',
                 multilabel=False):
        """
        # Purpose:\n
            Create the layers of a small two-block CNN classifier (subclassed Keras model).\n
        # Arguments:\n
            initial_x_shape: input shape given to the first Conv2D layer (stored as self.initial_x_shape).\n
            y_len: length of y (stored as self.y_len; not used to build the layers here).\n
            bottleneck_dim: units of the bottleneck Dense layer.\n
            output_n: units of the output Dense layer (stored as self.output_n).\n
            output_activation: activation of the output layer. Replaced by 'sigmoid'
                (with a warning) when multilabel=True and 'softmax' was requested;
                the output layer always uses the final self.output_activation.\n
            multilabel: whether the model is a multilabel classifier.\n
        # Details:\n
            - Use "softmax" for binary or mutually exclusive multiclass modelling,
                and use "sigmoid" for multilabel classification.\n
            - y_len: this is the length of y.
                y_len = 1 or 2: binary classification.
                y_len >= 2: multiclass or multilabel classification
                (the two cases overlap at 2; the `multilabel` flag disambiguates).\n
        """
        super(CnnClassifier, self).__init__()

        # - initialization and argument check-
        self.initial_x_shape = initial_x_shape
        self.y_len = y_len
        self.bottleneck_dim = bottleneck_dim
        self.multilabel = multilabel
        # Stored for parity with the functional-API variant of this classifier.
        self.output_n = output_n
        if multilabel and output_activation == 'softmax':
            warn(
                'Activation automatically set to \'sigmoid\' for multilabel classification.'
            )
            self.output_activation = 'sigmoid'
        else:
            self.output_activation = output_activation

        # - CNN layers -
        # Layers are created in a fixed order because Keras auto-generates
        # layer names from creation order. Shapes use H x W for the input's
        # spatial size (e.g. a 90x90 input yields 9x9x8 after maxpooling_2).
        # CNN encoding sub layers
        self.conv2d_1 = Conv2D(
            16, (3, 3),
            activation='relu',
            kernel_regularizer=tf.keras.regularizers.l2(l2=0.01),
            padding='same',
            input_shape=initial_x_shape)  # output: H, W, 16
        self.bn1 = BatchNormalization()
        self.leakyr1 = LeakyReLU()
        self.maxpooling_1 = MaxPooling2D((2, 2))  # output: H//2, W//2, 16
        self.conv2d_2 = Conv2D(8, (3, 3),
                               activation='relu',
                               padding='same',
                               name='last_conv')  # output: H//2, W//2, 8
        self.bn2 = BatchNormalization()
        self.leakyr2 = LeakyReLU(name='grads_cam_dense')
        # Stride defaults to the pool size with 'valid' padding.
        self.maxpooling_2 = MaxPooling2D((5, 5))  # output: (H//2)//5, (W//2)//5, 8
        self.fl = Flatten()  # ((H//2)//5) * ((W//2)//5) * 8 features
        self.dense1 = Dense(
            bottleneck_dim,
            activation='relu',
            activity_regularizer=tf.keras.regularizers.l2(l2=0.01))
        self.encoded = LeakyReLU()
        # Use the resolved activation so a forced 'sigmoid' actually takes effect.
        self.dense2 = Dense(output_n, activation=self.output_activation)
示例#5
0
    def predict_classes(self,
                        label_dict,
                        x,
                        proba_threshold=None,
                        batch_size=32,
                        verbose=1):
        """
        # Purpose:\n
            Generate class predictions for the input samples batch by batch.\n
        # Arguments:\n
            label_dict: dict. Dictionary with index (integers) as keys. It is not modified.\n
            x: input data, as a Numpy array or list of Numpy arrays
                (if the model has multiple inputs).\n
            proba_threshold: None or float. The probability threshold to allocate class labels
                for sigmoid outputs. Required for multilabel models; defaults to 0.5 for
                non-multilabel sigmoid models.\n
            batch_size: integer.\n
            verbose: verbosity mode, 0 or 1.\n
        # Return:\n
            Two pandas dataframes for probability results and 0/1 classification results, in this order.\n
        # Raises:\n
            ValueError: invalid label_dict, or missing proba_threshold for a multilabel model.\n
            NotImplementedError: output activation other than 'softmax' or 'sigmoid'.\n
        # Details:\n
            - For label_dict, this is a dictionary with keys as index integers 0..n-1.
                Example:
                {0: 'all', 1: 'alpha', 2: 'beta', 3: 'fmri', 4: 'hig', 5: 'megs', 6: 'pc', 7: 'pt', 8: 'sc'}.
                This can be derived from the "label_map_rev" attribute from BatchDataLoader class.\n
            - For binary classification, the length of the label_dict should be 1.
                Example: {0: 'case'}. Its single label is always used for output column 0.\n
        """
        # - argument check -
        if not isinstance(label_dict, dict):
            raise ValueError('label_dict needs to be a dictionary.')

        if not all(isinstance(key, int) for key in label_dict):
            raise ValueError('The keys in label_dict need to be integers.')

        if self.multilabel and proba_threshold is None:
            raise ValueError(
                'Set proba_threshold for multilabel class prediction.')

        # - set up output column names -
        # Work on a copy so the caller's dictionary is never mutated.
        labels = dict(label_dict)
        if len(labels) == 1:
            # Binary case: the single label always names output column 0.
            labels = {0: next(iter(labels.values()))}
        if not labels or sorted(labels) != list(range(len(labels))):
            raise ValueError(
                'The keys in label_dict need to be consecutive integers starting from 0.')
        res_colnames = [labels[k] for k in range(len(labels))]

        # - prediction -
        proba = self.predict(x, batch_size=batch_size, verbose=verbose)
        if proba.min() < 0. or proba.max() > 1.:
            # One message string: warn()'s second positional arg is a category class.
            warn('Network returning invalid probability values. '
                 'The last layer might not normalize predictions '
                 'into probabilities (like softmax or sigmoid would).')

        proba_res = pd.DataFrame(proba, dtype=float)
        proba_res.columns = res_colnames

        if self.output_activation == 'softmax':
            if proba.shape[-1] > 1:
                # One-hot encode the most probable class of each sample.
                multiclass_res = to_categorical(proba.argmax(axis=1),
                                                proba.shape[-1])
            else:
                multiclass_res = (proba > 0.5).astype('int32')

            multiclass_out = pd.DataFrame(multiclass_res, dtype=int)
            multiclass_out.columns = res_colnames
            return proba_res, multiclass_out

        if self.output_activation == 'sigmoid':
            # Each output is an independent probability; threshold them all at once.
            threshold = 0.5 if proba_threshold is None else proba_threshold
            multilabel_res = (proba >= threshold).astype(int)

            if verbose:
                # Show each sample's class percentages in decreasing order.
                for i, sample_proba in enumerate(proba):
                    print(f'Sample: {i}')
                    for n in np.argsort(sample_proba)[::-1]:
                        print(f'\t{labels[int(n)]}: {sample_proba[n]*100:.2f}%')

            multilabel_out = pd.DataFrame(multilabel_res, dtype=int)
            multilabel_out.columns = res_colnames
            return proba_res, multilabel_out

        raise NotImplementedError(
            f'predict_classes method not implemented for {self.output_activation}'
        )
示例#6
0
def rocaucPlot(classifier,
               x,
               y=None,
               label_dict=None,
               legend_pos='inside',
               **kwargs):
    """
    # Purpose\n
        To calculate and plot ROC-AUC for binary or mutual multiclass classification.
    # Arguments\n
        classifier: tf.keras.model subclass.
            These classes were created with a custom "predict_classes" method, along with other smaller custom attributes.\n
        x: tf.Dataset or np.ndarray. Input x data for prediction.\n
        y: None or np.ndarray. Only needed when x is a np.ndarray. Label information.\n
        label_dict: dict. Dictionary with index (integers) as keys.\n
        legend_pos: str. Legend position setting: 'inside', 'outside' (to the right of the axes) or 'none' to hide legends.\n
        **kwargs: additional arguments for the classifier.predict_classes.\n
    # Return\n
        - AUC scores for all the classes.\n
        - Plot objects "fg" and "ax" from matplotlib.\n
        - Order: auc_res, fg, ax.\n
    # Details\n
        - The function will throw a warning if a multilabel classifier is used.\n
        - The output auc_res is a pd.DataFrame with one row per class.
            Column names:  'label', 'auc', 'thresholds', 'fpr', 'tpr'.
            Since the threshold contains multiple values, so do the corresponding 'fpr' and 'tpr';
            each cell of these columns holds the array returned by roc_curve.\n
        - For label_dict, this is a dictionary with keys as index integers.
            Example:
            {0: 'all', 1: 'alpha', 2: 'beta', 3: 'fmri', 4: 'hig', 5: 'megs', 6: 'pc', 7: 'pt', 8: 'sc'}.
            This can be derived from the "label_map_rev" attribute from BatchDataLoader class.\n
    # Note\n
        - need to test the non-tf.Dataset inputs.\n
        - In the case of using tf.Dataset as x, y is not needed; labels are read from
          the (x, y) batches of the dataset.\n
    """
    # - arguments check -
    # more model classes are going to be added.
    if not isinstance(classifier, (CnnClassifier, CnnClassifierFuncAPI)):
        raise ValueError(
            'The classifier should be one of \'CnnClassifier\', \'CnnClassifierFuncAPI\'.'
        )

    if not isinstance(x, (np.ndarray, tf.data.Dataset)):
        raise TypeError(
            'x needs to be either a np.ndarray or tf.data.Dataset class.')

    if isinstance(x, np.ndarray):
        # Also covers y=None.
        if not isinstance(y, np.ndarray):
            raise ValueError('Set y (np.ndarray) when x is np.ndarray')
        if y.shape[-1] != classifier.y_len:
            raise ValueError(
                'Make sure y is the same length as classifier.y_len.')

    if classifier.multilabel:
        warn('ROC-AUC for multilabel models should not be used.')

    if legend_pos not in ['none', 'inside', 'outside']:
        raise ValueError(
            'Options for legend_pos are \'none\', \'inside\' and \'outside\'.')

    if label_dict is None:
        raise ValueError('Set label_dict.')

    # - make prediction -
    proba, _ = classifier.predict_classes(x=x, label_dict=label_dict, **kwargs)
    proba_np = proba.to_numpy()

    # - set up plotting data -
    if isinstance(x, tf.data.Dataset):
        # NOTE(review): the dataset is iterated a second time here, so labels
        # only line up with the predictions if it yields batches in the same
        # order on every pass (e.g. no per-iteration reshuffle).
        label_batches = [b.numpy() for _, b in x]
        # Concatenate once instead of growing an array per batch (O(n^2)).
        if label_batches:
            t = np.concatenate(label_batches, axis=0)
        else:
            t = np.empty((0, proba.shape[-1]))
    else:
        t = y

    # - calculate AUC and plotting -
    records = []
    fg, ax = plt.subplots()
    ax.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    for class_idx, auc_class in enumerate(proba.columns):
        y_true = t[:, class_idx]
        y_score = proba_np[:, class_idx]
        fpr, tpr, thresholds = roc_curve(y_true, y_score)
        auc_score = roc_auc_score(y_true, y_score)
        records.append({'label': auc_class, 'auc': auc_score,
                        'thresholds': thresholds, 'fpr': fpr, 'tpr': tpr})

        ax.plot(fpr, tpr, label=f'{auc_class} vs rest: {auc_score:.3f}')

    # Build the frame in one go; array-valued cells are stored as objects.
    auc_res = pd.DataFrame(
        records, columns=['label', 'auc', 'thresholds', 'fpr', 'tpr'])

    ax.set_title('ROC-AUC')
    ax.set_xlabel('False Positive Rate')
    ax.set_ylabel('True Positive Rate')
    if legend_pos == 'inside':
        ax.legend(loc='best')
    elif legend_pos == 'outside':
        # Anchor the legend's upper-left corner just right of the axes.
        ax.legend(loc='upper left', bbox_to_anchor=(1.01, 1.0))
    plt.show()

    return auc_res, fg, ax
示例#7
0
def epochsPlotV2(model_history,
                 file=None,
                 figure_width: float = 5,
                 figure_height: float = 5,
                 **kwargs):
    """
    # Purpose\n
        Plot training (and, when available, validation) metrics per epoch from a Keras History object.\n
    # Arguments\n
        model_history: keras.callbacks.History. Keras modelling history object, generated by model.fit process.\n
        file: string or None. (optional) Path to save the figure to; the containing directory must exist.
            The format is inferred by matplotlib from the file extension.\n
        figure_width: float. Figure width in inches.\n
        figure_height: float. Figure height in inches.\n
        kwargs: the VALUES are names of metrics in model_history.history to plot (the keyword names are ignored).
            When no kwargs are given, every metric not starting with 'val_' is plotted.\n
    # Return\n
        fig, ax: the matplotlib figure and the last plotted axis.\n
    # Details\n
        - For each plotted metric, 'val_' + metric is drawn as the validation curve if present.\n
    """
    # -- argument check --
    if not isinstance(model_history, tf.keras.callbacks.History):
        raise TypeError('model_history needs to be a keras History object.')
    metrics_dict = model_history.history

    # - set up metrics names -
    if kwargs:
        hist_metrics = []
        for key_val in kwargs.values():
            if key_val in metrics_dict:
                hist_metrics.append(key_val)
            else:
                warn(f'Input metric {key_val} not found in the model_history.')
    else:
        # Keras prefixes validation metrics with 'val_'; those are drawn
        # alongside their training counterpart instead of in their own panel.
        hist_metrics = [k for k in metrics_dict if not k.startswith('val_')]

    if not hist_metrics:
        raise ValueError('No valid metrics found to plot.')

    # -- set up data and plotting-
    # generate_subplots returns an extra list of panel indexes whose x tick
    # labels should be hidden when there is more than one panel.
    if len(hist_metrics) == 1:
        fig, axes = generate_subplots(1)
        idxes_to_turn_off = []
    else:
        fig, axes, idxes_to_turn_off = generate_subplots(len(hist_metrics))

    for hist_metric, ax in zip(hist_metrics, axes):
        plot_metric = np.array(metrics_dict[hist_metric])
        plot_x = np.arange(1, len(plot_metric) + 1)

        val_key = 'val_' + hist_metric
        if val_key in metrics_dict:
            ax.plot(plot_x,
                    np.array(metrics_dict[val_key]),
                    linestyle='-',
                    color='red',
                    label='validation')
        else:
            warn(f'{hist_metric} on validation data not found.')

        ax.plot(plot_x,
                plot_metric,
                linestyle='-',
                color='blue',
                label='train')
        ax.set_facecolor('white')
        ax.set_title(hist_metric, color='black')
        ax.set_ylabel(hist_metric, fontsize=10, color='black')
        ax.legend()
        ax.tick_params(labelsize=5, color='black', labelcolor='black')
        plt.setp(ax.spines.values(), color='black')

    for i in idxes_to_turn_off:
        plt.setp(axes[i].get_xticklabels(), visible=False)

    plt.xlabel('Epoch')
    plt.tight_layout()

    fig.set_facecolor('white')
    fig.set_size_inches(figure_width, figure_height, forward=True)

    # - save output -
    if file is not None:
        full_path = os.path.normpath(os.path.abspath(os.path.expanduser(file)))
        out_dir = os.path.dirname(full_path)
        # The target file is being created, so only its directory must exist.
        if not os.path.isdir(out_dir):
            raise ValueError(f'Output directory not found: {out_dir}')
        print(f'Saving plot as {file}...', end='')
        fig.savefig(full_path,
                    dpi=600,
                    bbox_inches='tight',
                    facecolor='white')
        print('Done!')

    return fig, ax
示例#8
0
def _plot_train_val_panel(ax, plot_x, train_values, val_values, title, ylabel):
    """Draw one train-vs-validation line panel with the shared white/black styling."""
    ax.plot(plot_x, train_values, linestyle='-', color='blue', label='train')
    ax.plot(plot_x, val_values, linestyle='-', color='red', label='validation')
    ax.set_facecolor('white')
    ax.set_title(title, color='black')
    ax.set_xlabel('Epoch', fontsize=10, color='black')
    ax.set_ylabel(ylabel, fontsize=10, color='black')
    ax.legend()
    ax.tick_params(labelsize=5, color='black', labelcolor='black')
    plt.setp(ax.spines.values(), color='black')


def epochsPlot(model_history,
               file=None,
               loss_var='loss',
               val_loss_var='val_loss',
               accuracy_var=None,
               val_accuracy_var=None,
               plot_title_loss='Loss',
               plot_title_acc='Accuracy',
               figure_size=(5, 5)):
    """
    # Purpose\n
        Plot training and validation loss (and optionally accuracy) per epoch from a Keras History object.\n
    # Arguments\n
        model_history: keras.callbacks.History. Keras modelling history object, generated by model.fit process.\n
        file: string or None. (optional) Path to save the figure to; the containing directory must exist.
            The format is inferred by matplotlib from the file extension.\n
        loss_var: string. Key for the training loss in model_history.history.\n
        val_loss_var: string. Key for the validation loss in model_history.history.\n
        accuracy_var: string or None. Key for the training accuracy.\n
        val_accuracy_var: string or None. Key for the validation accuracy.\n
        plot_title_loss: string. Title of the loss panel.\n
        plot_title_acc: string. Title of the accuracy panel.\n
        figure_size: two-tuple/list of floats. Figure size of the loss-only plot
            (the two-panel loss + accuracy figure is 15 x 5).\n
    # Return\n
        fig, ax: the matplotlib figure and its axis (an array of two axes when accuracy is plotted).\n
    # Details\n
        - The loss_var and accuracy_var are keys in the history.history object.\n
        - Accuracy is plotted only when both accuracy_var and val_accuracy_var are set;
          if only one is set, a warning is issued and only loss is plotted.\n
    """
    # -- argument check --
    if not isinstance(model_history, tf.keras.callbacks.History):
        raise TypeError('model_history needs to be a keras History object.')
    history = model_history.history

    if not all(hist_key in history for hist_key in [loss_var, val_loss_var]):
        raise ValueError(
            'Make sure both loss_var and val_loss_var exist in model_history.')

    acc_vars = [accuracy_var, val_accuracy_var]
    if all(acc is not None for acc in acc_vars):
        if not all(hist_key in history for hist_key in acc_vars):
            raise ValueError(
                'Make sure both accuracy_var and val_accuracy_var exist in model_history.'
            )
        acc_plot = True
    else:
        if any(acc is not None for acc in acc_vars):
            # Single message string: warn()'s second positional arg is a category.
            warn('Only one of accuracy_var, val_accuracy_var are set. '
                 'Proceed with only loss plot.')
        acc_plot = False

    # -- prepare data --
    plot_loss = np.array(history[loss_var])
    plot_val_loss = np.array(history[val_loss_var])
    plot_x = np.arange(1, len(plot_loss) + 1)

    # -- plotting --
    if acc_plot:  # two plots
        fig, ax = plt.subplots(1, 2, figsize=(15, 5))
        _plot_train_val_panel(ax[0], plot_x, plot_loss, plot_val_loss,
                              plot_title_loss, 'Loss')
        _plot_train_val_panel(ax[1], plot_x,
                              np.array(history[accuracy_var]),
                              np.array(history[val_accuracy_var]),
                              plot_title_acc, 'Accuracy')
    else:
        fig, ax = plt.subplots(figsize=figure_size)
        _plot_train_val_panel(ax, plot_x, plot_loss, plot_val_loss,
                              plot_title_loss, 'Loss')

    fig.set_facecolor('white')

    # - save output -
    if file is not None:
        full_path = os.path.normpath(os.path.abspath(os.path.expanduser(file)))
        out_dir = os.path.dirname(full_path)
        # The target file is being created, so only its directory must exist.
        if not os.path.isdir(out_dir):
            raise ValueError(f'Output directory not found: {out_dir}')
        fig.savefig(full_path,
                    dpi=600,
                    bbox_inches='tight',
                    facecolor='white')

    return fig, ax