コード例 #1
0
def test_update(gridfsstore):
    """Check GridFSStore.update: metadata, replace-by-key and compression.

    Writes two arrays under the same ``task_id`` and expects a single
    document holding the second one. It then writes again through a store
    built with ``compression=True`` on "maggma_test"/"test" and expects
    ``zlib`` in the file metadata plus a faithful round trip of the array.
    """
    first_array = np.random.rand(256)
    second_array = np.random.rand(256)
    stamp = datetime(2018, 4, 12, 16)

    # The document key is looked up under "metadata" in the files collection.
    gridfsstore.update([{
        "task_id": "mp-1",
        "data": first_array,
        gridfsstore.last_updated_field: stamp,
    }])
    meta_doc = gridfsstore._files_collection.find_one(
        {"metadata.task_id": "mp-1"})
    assert meta_doc is not None

    # Writing the same key again must leave exactly one matching document.
    gridfsstore.update([{
        "task_id": "mp-1",
        "data": second_array,
        gridfsstore.last_updated_field: stamp,
    }])
    hits = list(gridfsstore.query({"task_id": "mp-1"}))
    assert len(hits) == 1
    assert "task_id" in gridfsstore.query_one({"task_id": "mp-1"})
    fetched = gridfsstore.query_one({"task_id": "mp-1"})["data"]
    nptu.assert_almost_equal(fetched, second_array, 7)

    # A compressing store should tag the stored file with zlib metadata.
    zlib_store = GridFSStore("maggma_test",
                             "test",
                             key="task_id",
                             compression=True)
    zlib_store.connect()
    zlib_store.update([{"task_id": "mp-1", "data": first_array}])
    zlib_doc = zlib_store._files_collection.find_one(
        {"metadata.compression": "zlib"})
    assert zlib_doc is not None

    fetched = zlib_store.query_one({"task_id": "mp-1"})["data"]
    nptu.assert_almost_equal(fetched, first_array, 7)
コード例 #2
0
from maggma.stores import MongoStore, GridFSStore

# Connection settings shared by every mp_rk_calculations store below.
# NOTE(review): credentials are hard-coded (masked in this copy); consider
# loading them from the environment or a config file instead.
_rk_connection = dict(
    database="mp_rk_calculations",
    host="mongodb03.nersc.gov",
    port=27017,
    username="******",
    password="******",
    last_updated_field="last_updated",
    key="metadata.task_id",
)

# Task documents live in a regular collection...
db = MongoStore(collection_name="tasks", **_rk_connection)

# ...while the volumetric data (ELFCAR, CHGCAR) is kept in GridFS buckets.
elfcar_store = GridFSStore(collection_name="elfcar_fs", **_rk_connection)

chgcar_store = GridFSStore(collection_name="chgcar_fs", **_rk_connection)

aeccar0_store = GridFSStore(database="mp_rk_calculations",
                            collection_name="aeccar0_fs",
コード例 #3
0
def gridfsstore():
    """Yield a connected GridFSStore on "maggma_test"/"test" keyed by task_id.

    After the generator is resumed past the ``yield``, both the GridFS files
    and chunks collections are dropped, so each consumer starts clean.
    """
    fs_store = GridFSStore("maggma_test", "test", key="task_id")
    fs_store.connect()
    yield fs_store
    # Teardown: remove everything written through the store.
    for collection in (fs_store._files_collection,
                       fs_store._chunks_collection):
        collection.drop()
コード例 #4
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            """Build the band-structure and DOS plot traces.

            When ``bandstructure_symm_line`` or ``density_of_states`` is None,
            both are loaded from the ``fw_bs_prod`` database, using
            ``int(mpid)`` as the electronic-structure ``task_id``.

            Args:
                mpid: Task identifier used for the database lookup.
                path_convention: Key under ``bandstructure`` in the
                    electronic-structure document; ``"lm"`` additionally
                    re-orders the segments into a continuous k-path.
                dos_select: ``"ap"`` (element projections), ``"op"``
                    (orbital projections) or ``"<element>orb"`` (orbital
                    projections of that element).
                label_select: Label convention that database-loaded
                    high-symmetry labels are translated into when it differs
                    from ``path_convention``; ``""`` prevents a database load.
                bandstructure_symm_line: Band structure (object or dict) or
                    None.
                density_of_states: Complete DOS (object or dict) or None.

            Returns:
                ``([bstraces, dostraces, bs_data], elements)``, where
                ``elements`` are the element names of the element-projected
                DOS.

            Raises:
                PreventUpdate: if there is nothing to load or plot, or if
                    ``dos_select`` is not one of the recognised options.
            """
            if not mpid and (bandstructure_symm_line is None
                             or density_of_states is None):
                raise PreventUpdate

            elif bandstructure_symm_line is None or density_of_states is None:
                if label_select == "":
                    raise PreventUpdate

                # --
                # -- BS and DOS from API or DB
                # --

                bs_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="bandstructure_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                dos_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="dos_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                es_store = MongoStore(
                    database="fw_bs_prod",
                    collection_name="electronic_structure",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                    key="task_id",
                )

                # - BS traces from DB using task_id
                es_store.connect()
                try:
                    bs_query = es_store.query_one(
                        criteria={"task_id": int(mpid)},
                        properties=[
                            "bandstructure.{}.task_id".format(path_convention),
                            "bandstructure.{}.total.equiv_labels".format(
                                path_convention),
                        ],
                    )
                finally:
                    # Release the connection even if the query fails.
                    es_store.close()

                # NOTE(review): the GridFS stores are connected but never
                # closed; confirm GridFSStore.close() is safe in the maggma
                # version in use before adding it.
                bs_store.connect()
                bandstructure_symm_line = bs_store.query_one(criteria={
                    "metadata.task_id":
                    int(bs_query["bandstructure"][path_convention]["task_id"])
                })

                # Translate the high-symmetry labels into the requested
                # convention when it differs from the path convention.
                if path_convention != label_select:
                    bs_equiv_labels = bs_query["bandstructure"][
                        path_convention]["total"]["equiv_labels"]

                    new_labels_dict = {}
                    for label, kpoint in bandstructure_symm_line[
                            "labels_dict"].items():
                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            # Composite label: translate each side of the bar.
                            f_label = label_formatted.split("|")
                            new_label = (
                                "$" +
                                bs_equiv_labels[label_select][f_label[0]] +
                                "|" +
                                bs_equiv_labels[label_select][f_label[1]] +
                                "$")
                        else:
                            new_label = ("$" + bs_equiv_labels[label_select]
                                         [label_formatted] + "$")
                        new_labels_dict[new_label] = kpoint

                    bandstructure_symm_line["labels_dict"] = new_labels_dict

                # - DOS traces from DB using task_id
                es_store.connect()
                try:
                    dos_query = es_store.query_one(
                        criteria={"task_id": int(mpid)},
                        properties=["dos.task_id"],
                    )
                finally:
                    es_store.close()

                dos_store.connect()
                density_of_states = dos_store.query_one(
                    criteria={"task_id": int(dos_query["dos"]["task_id"])})

            # - BS Data
            # Objects are converted to dicts; dicts are used as-is.
            if (bandstructure_symm_line is not None
                    and not isinstance(bandstructure_symm_line, dict)):
                bandstructure_symm_line = bandstructure_symm_line.to_dict()

            if (density_of_states is not None
                    and not isinstance(density_of_states, dict)):
                density_of_states = density_of_states.to_dict()

            bsml = BSML.from_dict(bandstructure_symm_line)

            bs_reg_plot = BSPlotter(bsml)

            bs_data = bs_reg_plot.bs_plot_data()

            is_spin_polarized = bs_reg_plot._bs.is_spin_polarized

            # Make plot continous for lm
            if path_convention == "lm":
                distance_map, kpath_euler = HSKP(
                    bsml.structure).get_continuous_path(bsml)

                kpath_labels = [pair[0] for pair in kpath_euler]
                kpath_labels.append(kpath_euler[-1][1])

            else:
                distance_map = [(i, False)
                                for i in range(len(bs_data["distances"]))]
                tick_labels = bs_data["ticks"]["label"]
                # Drop a label when the next tick repeats it.
                kpath_labels = [
                    current
                    for current, following in zip(tick_labels, tick_labels[1:])
                    if current != following
                ]
                kpath_labels.append(tick_labels[-1])

            bs_data["ticks"]["label"] = kpath_labels

            # Obtain bands to plot over and generate traces for bs data:
            # a band is kept when its spin-up energy at the first k-point of
            # the first segment lies inside the window.
            energy_window = (-6.0, 10.0)
            first_segment_up = bs_data["energy"][0][str(Spin.up)]
            bands = [
                band_num for band_num in range(bs_reg_plot._nb_bands)
                if energy_window[0] <= first_segment_up[band_num][0] <=
                energy_window[1]
            ]

            def _band_traces(d, rev, x_dat, spin, line, name):
                """Return one line trace per selected band for segment d."""
                order = range(len(bs_data["distances"][d]))
                if rev:
                    # Segment is traversed backwards on the continuous path.
                    order = order[::-1]
                spin_energies = bs_data["energy"][d][str(spin)]
                return [{
                    "x": x_dat,
                    "y": [spin_energies[i][j] for j in order],
                    "mode": "lines",
                    "line": dict(line),
                    "hoverinfo": "skip",
                    "name": name,
                    "hovertemplate": "%{y:.2f} eV",
                    "showlegend": False,
                    "xaxis": "x",
                    "yaxis": "y",
                } for i in bands]

            def _lm_band_edges(points, extremum, d, rev, x_dat, tick_vals):
                """Map band-edge points on segment d onto the continuous path.

                A point is kept when the extremum k-point is unlabelled, or
                when its label occurs in the tick label at the mapped x.
                """
                segment_distances = bs_data["distances"][d]
                mapped = []
                for x_point, y_point in points:
                    if x_point not in segment_distances:
                        continue
                    xind = segment_distances.index(x_point)
                    if rev:
                        x_point_new = x_dat[len(x_dat) - xind - 1]
                    else:
                        x_point_new = x_dat[xind]

                    kpoint_label = extremum["kpoint"].label
                    if kpoint_label is None:
                        mapped.append((x_point_new, y_point))
                    elif x_point_new in tick_vals:
                        # Only tick positions carry a label; an off-tick x
                        # would make tick_vals.index raise ValueError.
                        new_label = bs_data["ticks"]["label"][
                            tick_vals.index(x_point_new)]
                        if kpoint_label in new_label:
                            mapped.append((x_point_new, y_point))
                return mapped

            bstraces = []

            pmin = 0.0
            tick_vals = [0.0]

            cbm = bsml.get_cbm()
            vbm = bsml.get_vbm()

            # For "lm" the segments are re-ordered, so the band-edge positions
            # are rebuilt from the mapped coordinates below. Working on copies
            # leaves bs_data untouched and never appends to a list that is
            # being iterated.
            if path_convention == "lm":
                cbm_new, vbm_new = [], []
            else:
                cbm_new, vbm_new = list(bs_data["cbm"]), list(bs_data["vbm"])

            up_name = "spin ↑" if is_spin_polarized else "Total"

            for d, rev in distance_map:
                segment_distances = bs_data["distances"][d]
                # Shift the segment so it starts where the previous one ended.
                x_dat = [
                    dval - segment_distances[0] + pmin
                    for dval in segment_distances
                ]

                pmin = x_dat[-1]

                tick_vals.append(pmin)

                bstraces += _band_traces(d, rev, x_dat, Spin.up,
                                         {"color": "#1f77b4"}, up_name)
                if is_spin_polarized:
                    bstraces += _band_traces(d, rev, x_dat, Spin.down, {
                        "color": "#ff7f0e",
                        "dash": "dot"
                    }, "spin ↓")

                # - Get proper cbm and vbm coords for lm
                if path_convention == "lm":
                    cbm_new += _lm_band_edges(bs_data["cbm"], cbm, d, rev,
                                              x_dat, tick_vals)
                    vbm_new += _lm_band_edges(bs_data["vbm"], vbm, d, rev,
                                              x_dat, tick_vals)

            bs_data["ticks"]["distance"] = tick_vals

            # - Strip latex math wrapping for labels
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "GAMMA": "Γ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
                "_{1}": "₁",
                "_{2}": "₂",
                "_{3}": "₃",
                "_{4}": "₄",
                "^{*}": "*",
            }

            tick_labels = bs_data["ticks"]["label"]
            bar_loc = []
            for entry_num, tick_label in enumerate(tick_labels):
                # Replacements are applied in dict order, each one seeing the
                # result of the previous ones.
                for key, replacement in str_replace.items():
                    if key in tick_label:
                        tick_label = tick_label.replace(key, replacement)
                        if key == "\\mid":
                            # "\mid" marks a jump between disjoint segments.
                            bar_loc.append(
                                bs_data["ticks"]["distance"][entry_num])
                tick_labels[entry_num] = tick_label

            # Vertical lines for disjointed segments
            vert_traces = [{
                "x": [x_point, x_point],
                "y": energy_window,
                "mode": "lines",
                "marker": {
                    "color": "white"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            } for x_point in bar_loc]

            bstraces += vert_traces

            def _band_edge_markers(points, prefix, extremum):
                """Return one marker trace per distinct band-edge point."""
                return [{
                    "x": [x_point],
                    "y": [y_point],
                    "mode": "markers",
                    "marker": {
                        "color": "#7E259B",
                        "size": 16,
                        "line": {
                            "color": "white",
                            "width": 2
                        },
                    },
                    "showlegend": False,
                    "hoverinfo": "text",
                    "name": "",
                    "hovertemplate": "{}: k = {}, {} eV".format(
                        prefix, list(extremum["kpoint"].frac_coords),
                        extremum["energy"]),
                    "xaxis": "x",
                    "yaxis": "y",
                } for (x_point, y_point) in set(points)]

            # Dots for cbm and vbm (Plotly's scatter mode value is "markers").
            bstraces += (_band_edge_markers(cbm_new, "CBM", cbm) +
                         _band_edge_markers(vbm_new, "VBM", vbm))

            # - DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(density_of_states)

            # Indices of the energies closest to the window edges, relative to
            # the Fermi level.
            dos_max = np.abs(
                (dos.energies - dos.efermi - energy_window[1])).argmin()
            dos_min = np.abs(
                (dos.energies - dos.efermi - energy_window[0])).argmin()

            if is_spin_polarized:
                # Spin-down densities are mirrored to negative x.
                dostraces.append({
                    "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": "Total DOS (spin ↓)",
                    "line": go.scatter.Line(color="#444444", dash="dot"),
                    "fill": "tozerox",
                    "fillcolor": "#C4C4C4",
                    "xaxis": "x2",
                    "yaxis": "y2",
                })

                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            dostraces.append({
                "x": dos.densities[Spin.up][dos_min:dos_max],
                "y": dos.energies[dos_min:dos_max] - dos.efermi,
                "mode": "lines",
                "name": tdos_label,
                "line": go.scatter.Line(color="#444444"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "legendgroup": "spinup",
                "xaxis": "x2",
                "yaxis": "y2",
            })

            ele_dos = dos.get_element_dos()
            elements = [str(entry) for entry in ele_dos]

            if dos_select == "ap":
                proj_data = ele_dos
            elif dos_select == "op":
                proj_data = dos.get_spd_dos()
            elif "orb" in dos_select:
                proj_data = dos.get_element_spd_dos(
                    Element(dos_select.replace("orb", "")))
            else:
                raise PreventUpdate

            # Projected DOS
            colors = [
                "#d62728",  # brick red
                "#2ca02c",  # cooked asparagus green
                "#17becf",  # blue-teal
                "#bcbd22",  # curry yellow-green
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
            ]

            for count, (label, proj_dos) in enumerate(proj_data.items()):
                # Cycle the palette so more than len(colors) projections do
                # not raise IndexError.
                color = colors[count % len(colors)]

                if is_spin_polarized:
                    dostraces.append({
                        "x":
                        -1.0 * proj_dos.densities[Spin.down][dos_min:dos_max],
                        "y": dos.energies[dos_min:dos_max] - dos.efermi,
                        "mode": "lines",
                        "name": str(label) + " (spin ↓)",
                        "line": dict(width=3, color=color, dash="dot"),
                        "xaxis": "x2",
                        "yaxis": "y2",
                    })
                    spin_up_label = str(label) + " (spin ↑)"
                else:
                    spin_up_label = str(label)

                dostraces.append({
                    "x": proj_dos.densities[Spin.up][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": spin_up_label,
                    "line": dict(width=2, color=color),
                    "xaxis": "x2",
                    "yaxis": "y2",
                })

            traces = [bstraces, dostraces, bs_data]

            return (traces, elements)