def __init__(self, format=None, crit="all"):
        """Build the plot that pairs the default and "_OS" statistics databases.

        Opens the default ``Stats_Database`` and a second one created with
        ``extension="_OS"``, delegates the drawing to ``self.get_data()``,
        labels both axes and finally saves or shows the figure.

        Args:
            format: File extension appended to ``lsos_plot_<crit>.`` when the
                figure is saved. When falsy the figure is shown instead.
            crit: Criterion name; stored as ``self.crit`` and embedded in the
                axis labels and in the output file name.
        """
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        self.crit = crit
        self.db_os = Stats_Database(extension="_OS")
        self.stats_process_os = self.db_os._stats_process

        # Parenthesised single-argument form: prints the same text under
        # Python 2 and is also valid syntax under Python 3.
        print(self.stats_process_os.count())

        self.fontname = "Trebuchet MS"
        self.fontsize = 14

        self.get_data()

        label_font = {"fontname": self.fontname, "fontsize": self.fontsize}
        plt.xlabel("Efficiency (one-photon water splitters, " + crit + ")",
                   **label_font)
        plt.ylabel("Efficiency (transparent shields, " + crit + ")",
                   **label_font)

        if format:
            plt.savefig("lsos_plot_" + crit + "." + format)
        else:
            plt.show()
    def __init__(self, format=None, criteria="all"):
        """Draw one subplot per GA parameter on a 3x2 grid.

        For each of the six parameter names the matching subplot is selected
        and ``self.get_data(param)`` is called to fill it. The figure is then
        saved or shown.

        Args:
            format: File extension appended to ``parameters_plot.`` when the
                figure is saved. When falsy the figure is shown instead.
            criteria: Stored as ``self.criteria``; not otherwise used here.
        """
        plt.figure(1, figsize=(16, 8))
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        self.num_exps = self.stats_process.count()
        self.criteria = criteria
        self.fontname = "Trebuchet MS"

        params = [
            "crossover_fnc", "popsize", "selection_overall", "mutation_rate",
            "elitism_num", "fitness_fnc"
        ]

        # The first iteration already selects subplot (3, 2, 1), so no
        # separate call is needed before the loop.
        for idx, param in enumerate(params):
            plt.subplot(3, 2, idx + 1)
            self.get_data(param)

        if format:
            mpl.rcParams['figure.figsize'] = (16, 8)
            # Parenthesised single-argument form: valid under Python 2 and 3.
            print(mpl.rcParams['figure.figsize'])

            plt.savefig("parameters_plot." + format)
        else:
            plt.show()
    def __init__(self, format=None):
        """Set up figure 2, style its axes and draw the comparison plot.

        The series themselves are drawn by ``self.get_reference()`` and
        ``self.get_data()``, called in that order after the axes are styled.

        Args:
            format: File extension appended to ``comparison_plot.`` when the
                figure is saved. When falsy the figure is shown instead.
        """
        plt.figure(2)
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        self.num_exps = self.stats_process.count()

        self.lw = 4
        self.fontsize = 16
        self.fontname = "Trebuchet MS"

        # One font specification shared by axis labels and tick labels.
        text_style = {"fontname": self.fontname, "fontsize": self.fontsize}

        plt.xlabel("GA parameter set rank", **text_style)
        plt.ylabel("Efficiency vs. random", **text_style)
        axes = plt.gca()
        plt.setp(axes.get_xticklabels(), **text_style)
        plt.setp(axes.get_yticklabels(), **text_style)

        # The x axis spans the number of documents in the stats collection.
        plt.xlim((0, self.num_exps))

        self.get_reference()
        self.get_data()

        if not format:
            plt.show()
        else:
            plt.savefig("comparison_plot." + format)
    def __init__(self):
        """Print one tab-separated row per document in the stats collection.

        Columns, in order: ``popsize``, ``selection_fnc``, ``fitness_fnc``,
        ``crossover_fnc``, ``elitism_num`` (all read from the document's
        ``parameters`` sub-document), followed by the document's ``ten`` and
        ``all`` fields. The three ``*_fnc`` values are passed through
        ``get_pretty_name`` before printing.
        """
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process

        row_template = "{}\t{}\t{}\t{}\t{}\t{}\t{}"

        for it in self.stats_process.find():
            p = it['parameters']

            # Format first, then print. The former ``print(...).format(...)``
            # only worked as a Python 2 print statement; on Python 3 it
            # printed the raw template and then failed on ``None.format``.
            print(row_template.format(
                p['popsize'], get_pretty_name(p['selection_fnc']),
                get_pretty_name(p['fitness_fnc']),
                get_pretty_name(p['crossover_fnc']), p['elitism_num'],
                it['ten'], it['all']))
    def __init__(self, format=None):
        """Draw the performance plot that includes the "_exclusion" database.

        Opens the default ``Stats_Database`` and a second one created with
        ``extension="_exclusion"``, then draws four series through
        ``get_goldschmidt_data``, ``get_data``, ``get_data_exclusion`` and
        ``get_reference_data`` (in that order) before styling the axes.

        Args:
            format: File extension appended to
                ``performance_plot_exclusion.`` when the figure is saved.
                When falsy the figure is shown instead.
        """
        plt.figure(1, figsize=(8, 6))
        self.db = Stats_Database()
        self.db2 = Stats_Database(extension="_exclusion")
        self.stats_process = self.db._stats_process
        self.stats_process2 = self.db2._stats_process

        self.lw = 4
        self.fontsize = 14
        self.fontname = "Trebuchet MS"

        # Series, drawn in this order.
        self.get_goldschmidt_data()  # goldschmidt reference
        self.get_data(0, "best GA", "dodgerblue",
                      pos="left", crit="mixed")  # best
        self.get_data_exclusion(0, "best GA +\nchemical ", "#00B31B",
                                pos="right", crit="all")  # best
        self.get_reference_data()  # reference

        # One font specification shared by axis labels and tick labels.
        text_style = {"fontname": self.fontname, "fontsize": self.fontsize}

        plt.xlabel("Average number of calculations", **text_style)
        plt.ylabel("Potential solar water splitting materials", **text_style)
        axes = plt.gca()
        plt.setp(axes.get_xticklabels(), **text_style)
        plt.setp(axes.get_yticklabels(), **text_style)

        # Half a unit of headroom above the number of good candidates.
        plt.ylim((0, len(GOOD_CANDS_LS) + 0.5))
        plt.xlim((0, 4500))

        if not format:
            plt.show()
        else:
            plt.savefig("performance_plot_exclusion." + format)
    def __init__(self, format=None):
        """Draw the "10 solutions" versus "all 20 solutions" efficiency plot.

        Selects figure 3, delegates the drawing to ``self.get_data()`` and
        then labels the axes.

        Args:
            format: File extension appended to ``tenall_plot.`` when the
                figure is saved. When falsy the figure is shown instead.
        """
        plt.figure(3)
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        self.fontname = "Trebuchet MS"
        self.fontsize = 14

        self.get_data()

        label_font = {"fontname": self.fontname, "fontsize": self.fontsize}
        plt.xlabel("Efficiency (10 solutions)", **label_font)
        plt.ylabel("Efficiency (all 20 solutions)", **label_font)

        if not format:
            plt.show()
        else:
            plt.savefig("tenall_plot." + format)
    def __init__(self, format=None):
        """Draw the global-mutation count versus efficiency plot.

        Selects figure 3, delegates the drawing to ``self.get_data()`` and
        then labels the axes.

        Args:
            format: File extension appended to ``breakout_plot.`` when the
                figure is saved. When falsy the figure is shown instead.
        """
        plt.figure(3)
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        self.fontname = "Trebuchet MS"
        self.fontsize = 14

        self.get_data()

        label_font = {"fontname": self.fontname, "fontsize": self.fontsize}
        plt.xlabel("Average number of global mutations", **label_font)
        plt.ylabel("Efficiency (all 20 materials)", **label_font)

        if not format:
            plt.show()
        else:
            plt.savefig("breakout_plot." + format)
    def __init__(self, format=None):
        """Draw the performance plot of the best GA against the reference.

        The series are drawn by ``self.get_reference_data()`` and
        ``self.get_data(...)``; afterwards the axes are labelled, ticked and
        limited.

        Args:
            format: File extension appended to ``performance_plot.`` when the
                figure is saved. When falsy the figure is shown instead.
        """
        plt.figure(1, figsize=(8, 6))
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        # Only needed by the disabled "worst GA" series below.
        num_exps = self.stats_process.count()

        self.lw = 4
        self.fontsize = 14
        self.fontname = "Trebuchet MS"

        self.get_reference_data()  # reference
        self.get_data(0, "best GA", "dodgerblue",
                      pos="right", crit="mixed")  # best
        # Disabled series, kept for reference:
        # self.get_goldschmidt_data()  # goldschmidt reference
        # self.get_data(0, "best GA (ten)", "green", pos="right", crit="ten")
        # self.get_data(num_exps-1, "worst GA", "tomato", "right")  # ~worst

        # One font specification shared by axis labels and tick labels.
        text_style = {"fontname": self.fontname, "fontsize": self.fontsize}

        plt.xlabel("Average number of calculations", **text_style)
        plt.ylabel("Potential solar water splitting materials", **text_style)
        axes = plt.gca()
        plt.setp(axes.get_xticklabels(), **text_style)
        plt.setp(axes.get_yticklabels(), **text_style)

        # NOTE(review): the tick positions come from the x-limit in effect
        # *before* plt.xlim((0, NUM_CANDS)) below -- confirm this is intended.
        upper_x = plt.xlim()[1]
        plt.xticks(np.arange(0, upper_x, 2500))
        plt.ylim((0, len(GOOD_CANDS_LS) + 0.5))
        plt.xlim((0, NUM_CANDS))

        if not format:
            plt.show()
        else:
            plt.savefig("performance_plot." + format)
    def __init__(self, format=None, criteria="all"):
        """Draw a square heat map over every (parameter, value) pair.

        The distinct values of each GA parameter are read from the stats
        collection and sorted, giving an ordered list of (parameter, value)
        pairs that indexes both axes. Each cell holds:

        * ``0.5`` when both pairs name the same parameter with different
          values;
        * ``self.get_score(constraint)`` otherwise, where ``constraint``
          pins one parameter (diagonal cells) or two parameters.

        Args:
            format: File extension appended to ``heatmap_plot.`` when the
                figure is saved. When falsy the figure is shown instead.
            criteria: Stored as ``self.criteria``; not otherwise used here.
        """
        plt.figure(1, figsize=(12, 8))
        self.db = Stats_Database()
        self.stats_process = self.db._stats_process
        self.num_exps = self.stats_process.count()
        self.criteria = criteria
        self.fontname = "Trebuchet MS"
        self.fontsize = 14

        params = [
            "crossover_fnc", "popsize", "selection_overall", "elitism_num",
            "fitness_fnc", "mutation_rate"
        ]

        def hyphenated_key(label):
            # Sort key for "a-b-c" labels: ten times the code point of the
            # second character of the first field, plus the two numeric
            # fields that follow it.
            pieces = label.split("-")
            return (ord(pieces[0][1]) * 10 + float(pieces[1]) +
                    float(pieces[2]))

        # Ordered (parameter name, parameter value) pairs for both axes.
        param_vals = []
        for name in params:
            labels = self.stats_process.distinct("parameters." + name)
            if "-" in str(labels[0]):
                labels.sort(key=hyphenated_key)
            else:
                labels.sort()
            param_vals.extend((name, value) for value in labels)

        side = len(param_vals)

        # Cell values in row-major order.
        values = []
        for x_name, x_val in param_vals:
            for y_name, y_val in param_vals:
                same_param = x_name == y_name
                if same_param and x_val != y_val:
                    # Two different values of one parameter: fixed filler.
                    values.append(0.5)
                    continue
                constraint = {"parameters." + x_name: x_val}
                if not same_param:
                    constraint["parameters." + y_name] = y_val
                values.append(self.get_score(constraint))

        values = np.array(values).reshape(side, side)

        plt.hot()
        plt.pcolormesh(values)

        tick_text = [
            str(get_short_name(value)) + "=" + get_short_name(name)
            for name, value in param_vals
        ]
        # Centre each tick on its cell.
        ticklok = [pos + 0.5 for pos in range(side)]
        plt.xticks(ticklok,
                   tick_text,
                   ha='center',
                   rotation=90,
                   fontsize=14,
                   fontname=self.fontname)
        plt.yticks(ticklok,
                   tick_text,
                   ha='right',
                   fontsize=14,
                   fontname=self.fontname)

        # Separator lines wherever the parameter name changes.
        previous = param_vals[0][0]
        for pos, (name, _value) in enumerate(param_vals):
            if name == previous:
                continue
            plt.plot([pos, pos], [0, side],
                     color="darkslategray",
                     linewidth=3)
            plt.plot([0, side], [pos, pos],
                     color="darkslategray",
                     linewidth=3)
            previous = name

        plt.xlim(0, side)
        plt.ylim(0, side)
        plt.colorbar()

        if not format:
            plt.show()
        else:
            plt.savefig("heatmap_plot." + format)
#!/usr/bin/env python
'''
Created on Oct 4, 2012
'''
from ga_optimization_ternary.database import Stats_Database
from ga_optimization_ternary.utils import get_reference_array

__author__ = "Anubhav Jain"
__copyright__ = "Copyright 2012, The Materials Project"
__version__ = "0.1"
__maintainer__ = "Anubhav Jain"
__email__ = "*****@*****.**"
__date__ = "Oct 4, 2012"

if __name__ == "__main__":
    sdb = Stats_Database()

    sdb_os = Stats_Database(extension="_OS")
    print get_reference_array()[20]
    print get_reference_array()[10]

    # print headers
    params = [
        "crossover_fnc", "popsize", "selection_overall", "mutation_rate",
        "elitism_num", "fitness_fnc"
    ]
    for param in params:
        print param,
    print 'ten', 'all', 'half (OS)', 'all (OS)'

    for item in sdb._stats_process.find():