예제 #1
0
 def dump_dicts(self):
     """
     Dump the model evaluation dictionaries.

     Pickles ``self.eval_train``, ``self.eval_test`` and ``self.eval_validation``
     to ``train_dict.pkl``, ``test_dict.pkl`` and ``validation_dict.pkl``. The
     target paths are resolved by ``paths.get_plot_evaluation_path_for_model``
     from the model's root path.
     """
     # Each file is opened in a ``with`` block so the handle is flushed and
     # closed deterministically, even if pickling raises.
     p_train = paths.get_plot_evaluation_path_for_model(self.model.get_root_path(), "train_dict.pkl")
     with open(p_train, "wb") as f_train:
         pkl.dump(self.eval_train, f_train)

     p_test = paths.get_plot_evaluation_path_for_model(self.model.get_root_path(), "test_dict.pkl")
     with open(p_test, "wb") as f_test:
         pkl.dump(self.eval_test, f_test)

     p_val = paths.get_plot_evaluation_path_for_model(self.model.get_root_path(), "validation_dict.pkl")
     with open(p_val, "wb") as f_val:
         pkl.dump(self.eval_validation, f_val)
 def dump_dicts(self):
     """
     Dump the model evaluation dictionaries.

     Writes the train, test and validation evaluation dictionaries held on
     this instance as pickle files (``train_dict.pkl``, ``test_dict.pkl``,
     ``validation_dict.pkl``). Each destination is obtained from
     ``paths.get_plot_evaluation_path_for_model`` using the model's root path.
     """
     dumps = (
         ("train_dict.pkl", self.eval_train),
         ("test_dict.pkl", self.eval_test),
         ("validation_dict.pkl", self.eval_validation),
     )
     for file_name, eval_dict in dumps:
         target = paths.get_plot_evaluation_path_for_model(self.model.get_root_path(), file_name)
         # Context manager guarantees the file is closed (and its buffer
         # flushed) instead of relying on garbage collection.
         with open(target, "wb") as handle:
             pkl.dump(eval_dict, handle)
예제 #3
0
파일: base.py 프로젝트: ylfzr/ADGM
    def plot_eval(self, eval_dict, labels, path_extension=""):
        """
        Plot the evaluation series in an overall plot and a zoomed plot.

        The upper subplot draws every entry of ``eval_dict`` as a scatter; the
        lower subplot draws only the trailing quarter of the entries, with a
        regression fit. The figure is saved as ``path_extension + ".png"`` at
        the location returned by ``paths.get_plot_evaluation_path_for_model``.

        :param eval_dict: Mapping whose keys are used as x values (the axis is
            labelled 'Epochs') and whose values are equally long sequences, one
            element per plotted series. Assumed to iterate in plotting order
            (e.g. an ordered dict) -- TODO confirm against callers.
        :param labels: Legend labels, indexed by series position.
        :param path_extension: If the plot should be saved in an incremental way.
        :raises IndexError: If ``eval_dict`` is empty.
        """
        def plot(x, y, fit, label):
            # x/y are passed by keyword: recent seaborn releases no longer
            # accept the data vectors positionally.
            sns.regplot(x=np.array(x),
                        y=np.array(y),
                        fit_reg=fit,
                        label=label,
                        scatter_kws={"s": 5})

        plt.clf()
        plt.subplot(211)
        # Materialise the dict views: on Python 3 they are not subscriptable
        # and ``np.array`` would wrap the view object instead of its contents.
        keys = list(eval_dict.keys())
        values = list(eval_dict.values())
        idx = np.array(values[0]).shape[0]
        x = np.array(values)
        for i in range(idx):
            plot(keys, x[:, i], False, labels[i])
        plt.legend()
        plt.subplot(212)
        # Size of the zoom window. With fewer than four entries this is 0 and
        # the ``[-0:]`` slices below select everything.
        tail = int(len(x) * 0.25)
        for i in range(idx):
            plot(keys[-tail:], x[-tail:][:, i], True, labels[i])
        plt.xlabel('Epochs')
        plt.savefig(
            paths.get_plot_evaluation_path_for_model(
                self.model.get_root_path(), path_extension + ".png"))
예제 #4
0
    def plot_eval(self, eval_dict, labels, path_extension=""):
        """
        Plot the evaluation series in an overall plot and a zoomed plot.

        Subplot 211 shows all entries of ``eval_dict`` without a fit; subplot
        212 shows the last quarter of the entries with a regression fit. The
        result is written to ``path_extension + ".png"`` under the path given
        by ``paths.get_plot_evaluation_path_for_model`` for the model root.

        :param eval_dict: Mapping from x value (axis labelled 'Epochs') to a
            sequence holding one number per series; all sequences are expected
            to have the same length. Iteration order is presumably
            chronological -- verify against the caller.
        :param labels: One legend label per series, looked up by index.
        :param path_extension: If the plot should be saved in an incremental way.
        :raises IndexError: If ``eval_dict`` is empty.
        """

        def plot(x, y, fit, label):
            # Keyword x=/y= works on both old and new seaborn; newer versions
            # reject positional data vectors.
            sns.regplot(x=np.array(x), y=np.array(y), fit_reg=fit, label=label, scatter_kws={"s": 5})

        plt.clf()
        plt.subplot(211)
        # Copy keys/values into lists so indexing and slicing work on Python 3,
        # where ``dict.keys()`` / ``dict.values()`` return views.
        epochs = list(eval_dict.keys())
        rows = list(eval_dict.values())
        idx = np.array(rows[0]).shape[0]
        x = np.array(rows)
        for i in range(idx):
            plot(epochs, x[:, i], False, labels[i])
        plt.legend()
        plt.subplot(212)
        # Last 25% of the entries; a window of 0 makes ``[-0:]`` keep them all.
        window = int(len(x) * 0.25)
        for i in range(idx):
            plot(epochs[-window:], x[-window:][:, i], True, labels[i])
        plt.xlabel('Epochs')
        plt.savefig(paths.get_plot_evaluation_path_for_model(self.model.get_root_path(), path_extension+".png"))