Exemple #1
0
class TestCopyBuilder(TestCase):
    """Integration tests for CopyBuilder against a disposable MongoDB database."""

    @classmethod
    def setUpClass(cls):
        # Unique database name so concurrent test runs cannot collide.
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        """Create source/target stores plus 20 "old" docs and 10 newer docs."""
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)  # strictly later "last updated" stamp
        keys = list(range(20))
        self.old_docs = [{"lu": tic, "k": k, "v": "old"} for k in keys]
        self.new_docs = [{"lu": toc, "k": k, "v": "new"} for k in keys[:10]]
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.builder = CopyBuilder(self.source, self.target)
        self.source.connect()
        self.source.collection.create_index("lu")
        self.target.connect()
        self.target.collection.create_index("lu")
        self.target.collection.create_index("k")

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_items(self):
        """get_items yields only docs that are newer in source than target."""
        self.source.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.old_docs))
        self.target.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())), 0)
        self.source.update(self.new_docs, update_lu=False)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.new_docs))

    def test_process_item(self):
        """process_item is the identity transformation for CopyBuilder."""
        self.source.collection.insert_many(self.old_docs)
        items = list(self.builder.get_items())
        self.assertCountEqual(items, map(self.builder.process_item, items))

    def test_update_targets(self):
        """update_targets overwrites only docs with newer source versions."""
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        items = list(map(self.builder.process_item, self.builder.get_items()))
        self.builder.update_targets(items)
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_confirm_lu_field_index(self):
        """get_items raises when the lu index is missing on the source."""
        self.source.collection.drop_index("lu_1")
        with self.assertRaises(Exception) as cm:
            self.builder.get_items()
        self.assertTrue(cm.exception.args[0].startswith("Need index"))
        # Restore the index dropped above so later tests are unaffected.
        self.source.collection.create_index("lu")

    def test_runner(self):
        """End-to-end copy via Runner behaves like update_targets."""
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        runner = Runner([self.builder])
        runner.run()
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_query(self):
        """builder.query restricts the copy to matching documents (k > 5)."""
        self.builder.query = {"k": {"$gt": 5}}
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        runner = Runner([self.builder])
        runner.run()
        all_docs = list(self.target.query(criteria={}))
        self.assertEqual(len(all_docs), 14)
        # Bug fix: assertTrue(min(...), 6) passed 6 as the *msg* argument, so
        # the assertion could never fail; assert the minimum key explicitly.
        self.assertEqual(min(d["k"] for d in all_docs), 6)
Exemple #2
0
class TestCopyBuilder(TestCase):
    """Integration tests for CopyBuilder against a disposable MongoDB database."""

    @classmethod
    def setUpClass(cls):
        # Unique database name so concurrent test runs cannot collide.
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        """Create source/target stores plus 20 "old" docs and 10 newer docs."""
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)  # strictly later "last updated" stamp
        keys = list(range(20))
        self.old_docs = [{"lu": tic, "k": k, "v": "old"} for k in keys]
        self.new_docs = [{"lu": toc, "k": k, "v": "new"} for k in keys[:10]]
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.builder = CopyBuilder(self.source, self.target)

        self.source.connect()
        self.source.ensure_index(self.source.key)
        self.source.ensure_index(self.source.lu_field)

        self.target.connect()
        self.target.ensure_index(self.target.key)
        self.target.ensure_index(self.target.lu_field)

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_items(self):
        """get_items yields only docs that are newer in source than target."""
        self.source.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.old_docs))
        self.target.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())), 0)
        self.source.update(self.new_docs, update_lu=False)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.new_docs))

    def test_process_item(self):
        """process_item is the identity transformation for CopyBuilder."""
        self.source.collection.insert_many(self.old_docs)
        items = list(self.builder.get_items())
        self.assertCountEqual(items, map(self.builder.process_item, items))

    def test_update_targets(self):
        """update_targets overwrites only docs with newer source versions."""
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        items = list(map(self.builder.process_item, self.builder.get_items()))
        self.builder.update_targets(items)
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    @unittest.skip(
        "Have to refactor how we force read-only so a warning will get thrown")
    def test_index_warning(self):
        """Should log warning when recommended store indexes are not present."""
        self.source.collection.drop_index([(self.source.key, 1)])
        with self.assertLogs(level=logging.WARNING) as cm:
            list(self.builder.get_items())
        self.assertIn("Ensure indices", "\n".join(cm.output))

    def test_run(self):
        """End-to-end builder.run() behaves like update_targets."""
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        self.builder.run()
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_query(self):
        """builder.query restricts the copy to matching documents (k > 5)."""
        self.builder.query = {"k": {"$gt": 5}}
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.builder.run()
        all_docs = list(self.target.query(criteria={}))
        self.assertEqual(len(all_docs), 14)
        # Bug fix: assertTrue(min(...), 6) passed 6 as the *msg* argument, so
        # the assertion could never fail; assert the minimum key explicitly.
        self.assertEqual(min(d["k"] for d in all_docs), 6)

    def test_delete_orphans(self):
        """delete_orphans removes target docs whose source docs were deleted."""
        self.builder = CopyBuilder(self.source,
                                   self.target,
                                   delete_orphans=True)
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)

        deletion_criteria = {"k": {"$in": list(range(5))}}
        self.source.collection.delete_many(deletion_criteria)
        self.builder.run()

        self.assertEqual(
            self.target.collection.count_documents(deletion_criteria), 0)
        self.assertEqual(self.target.query_one(criteria={"k": 5})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_incremental_false(self):
        """incremental=False copies matching docs even when target is newer."""
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        earlier = [{"lu": tic, "k": k, "v": "val"} for k in keys]
        later = [{"lu": toc, "k": k, "v": "val"} for k in keys]
        self.source.collection.insert_many(earlier)
        self.target.collection.insert_many(later)
        query = {"k": {"$gt": 5}}
        self.builder = CopyBuilder(self.source,
                                   self.target,
                                   incremental=False,
                                   query=query)
        self.builder.run()
        docs = sorted(self.target.query(), key=lambda d: d["k"])
        # Bug fix: the original passed a generator expression to assertTrue
        # (assertTrue(all(...) for d in ...)), which is always truthy, so the
        # test asserted nothing. The slice boundary was also off by one: the
        # query k > 5 copies keys 6..19, leaving docs[:6] (k = 0..5) untouched.
        self.assertTrue(all(d["lu"] == tic for d in docs[6:]))
        self.assertTrue(all(d["lu"] == toc for d in docs[:6]))
Exemple #3
0
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            """Assemble plotly traces for a band structure + DOS figure.

            Args:
                mpid: material id used for DB lookups when the band structure
                    or DOS is not passed in directly (coerced with int()).
                path_convention: k-path convention of the stored band
                    structure (e.g. "lm").
                dos_select: DOS projection to plot -- "ap" (by element), "op"
                    (by orbital), or "orb<Element>" (an element's spd
                    projection); anything else raises PreventUpdate.
                label_select: k-path label convention requested by the user.
                bandstructure_symm_line: band structure (object or dict) or
                    None to fetch it from the database.
                density_of_states: DOS (object or dict) or None to fetch it.

            Returns:
                Tuple ([bstraces, dostraces, bs_data], elements) where the
                first two entries are lists of plotly trace dicts and
                `elements` lists the element symbols present in the DOS.

            Raises:
                PreventUpdate: when there is not enough input to plot.
            """
            if not mpid and (bandstructure_symm_line is None
                             or density_of_states is None):
                raise PreventUpdate

            elif bandstructure_symm_line is None or density_of_states is None:
                if label_select == "":
                    raise PreventUpdate

                # --
                # -- BS and DOS from API or DB
                # --

                # NOTE(review): hard-coded host with redacted credentials --
                # these should come from configuration, not source code.
                bs_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="bandstructure_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                dos_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="dos_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                es_store = MongoStore(
                    database="fw_bs_prod",
                    collection_name="electronic_structure",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                    key="task_id",
                )

                # - BS traces from DB using task_id
                es_store.connect()
                bs_query = es_store.query_one(
                    criteria={"task_id": int(mpid)},
                    properties=[
                        "bandstructure.{}.task_id".format(path_convention),
                        "bandstructure.{}.total.equiv_labels".format(
                            path_convention),
                    ],
                )

                es_store.close()

                bs_store.connect()
                bandstructure_symm_line = bs_store.query_one(criteria={
                    "metadata.task_id":
                    int(bs_query["bandstructure"][path_convention]["task_id"])
                }, )

                # If LM convention, get equivalent labels
                if path_convention != label_select:
                    bs_equiv_labels = bs_query["bandstructure"][
                        path_convention]["total"]["equiv_labels"]

                    new_labels_dict = {}
                    for label in bandstructure_symm_line["labels_dict"].keys():

                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            f_label = label_formatted.split("|")
                            # Bug fix: this branch appended to an undefined
                            # name `new_labels` (NameError at runtime).
                            # Translate both halves of the jump label and map
                            # the result to the original k-point, mirroring
                            # the else branch below.
                            new_key = (
                                "$" +
                                bs_equiv_labels[label_select][f_label[0]] +
                                "|" +
                                bs_equiv_labels[label_select][f_label[1]] +
                                "$")
                            new_labels_dict[new_key] = bandstructure_symm_line[
                                "labels_dict"][label]
                        else:
                            new_labels_dict["$" + bs_equiv_labels[label_select]
                                            [label_formatted] +
                                            "$"] = bandstructure_symm_line[
                                                "labels_dict"][label]

                    bandstructure_symm_line["labels_dict"] = new_labels_dict

                # - DOS traces from DB using task_id
                es_store.connect()
                dos_query = es_store.query_one(
                    criteria={"task_id": int(mpid)},
                    properties=["dos.task_id"],
                )
                es_store.close()

                dos_store.connect()
                density_of_states = dos_store.query_one(
                    criteria={"task_id": int(dos_query["dos"]["task_id"])}, )

            # - BS Data: normalize object inputs to plain dicts
            if (not isinstance(bandstructure_symm_line, dict)
                    and bandstructure_symm_line is not None):
                bandstructure_symm_line = bandstructure_symm_line.to_dict()

            if (not isinstance(density_of_states, dict)
                    and density_of_states is not None):
                density_of_states = density_of_states.to_dict()

            bsml = BSML.from_dict(bandstructure_symm_line)

            bs_reg_plot = BSPlotter(bsml)

            bs_data = bs_reg_plot.bs_plot_data()

            # Make plot continuous for lm
            if path_convention == "lm":
                distance_map, kpath_euler = HSKP(
                    bsml.structure).get_continuous_path(bsml)

                kpath_labels = [pair[0] for pair in kpath_euler]
                kpath_labels.append(kpath_euler[-1][1])

            else:
                # Identity mapping: draw segments in stored order, none
                # reversed; collapse consecutive duplicate tick labels.
                distance_map = [(i, False)
                                for i in range(len(bs_data["distances"]))]
                kpath_labels = []
                for label_ind in range(len(bs_data["ticks"]["label"]) - 1):
                    if (bs_data["ticks"]["label"][label_ind] !=
                            bs_data["ticks"]["label"][label_ind + 1]):
                        kpath_labels.append(
                            bs_data["ticks"]["label"][label_ind])
                kpath_labels.append(bs_data["ticks"]["label"][-1])

            bs_data["ticks"]["label"] = kpath_labels

            # Obtain bands to plot over and generate traces for bs data:
            # keep bands whose first energy falls inside the display window.
            energy_window = (-6.0, 10.0)
            bands = []
            for band_num in range(bs_reg_plot._nb_bands):
                first_energy = bs_data["energy"][0][str(Spin.up)][band_num][0]
                if energy_window[0] <= first_energy <= energy_window[1]:
                    bands.append(band_num)

            bstraces = []

            pmin = 0.0
            tick_vals = [0.0]

            cbm = bsml.get_cbm()
            vbm = bsml.get_vbm()

            # NOTE(review): these alias the lists inside bs_data, so the
            # appends in the lm branch below also mutate bs_data["cbm"] /
            # bs_data["vbm"] while they are being iterated -- confirm intended.
            cbm_new = bs_data["cbm"]
            vbm_new = bs_data["vbm"]

            for d, rev in distance_map:

                # Shift this segment so it starts where the previous ended.
                x_dat = [
                    dval - bs_data["distances"][d][0] + pmin
                    for dval in bs_data["distances"][d]
                ]

                pmin = x_dat[-1]

                tick_vals.append(pmin)

                # Reversed segments are sampled right-to-left so the plotted
                # path stays continuous; only the y ordering differs.
                j_indices = list(range(len(bs_data["distances"][d])))
                if rev:
                    j_indices.reverse()

                traces_for_segment = [{
                    "x": x_dat,
                    "y": [
                        bs_data["energy"][d][str(Spin.up)][i][j]
                        for j in j_indices
                    ],
                    "mode": "lines",
                    "line": {
                        "color": "#1f77b4"
                    },
                    "hoverinfo": "skip",
                    "name": "spin ↑"
                    if bs_reg_plot._bs.is_spin_polarized else "Total",
                    "hovertemplate": "%{y:.2f} eV",
                    "showlegend": False,
                    "xaxis": "x",
                    "yaxis": "y",
                } for i in bands]

                if bs_reg_plot._bs.is_spin_polarized:
                    # Add second spin channel for polarized calculations.
                    traces_for_segment += [{
                        "x": x_dat,
                        "y": [
                            bs_data["energy"][d][str(Spin.down)][i][j]
                            for j in j_indices
                        ],
                        "mode": "lines",
                        "line": {
                            "color": "#ff7f0e",
                            "dash": "dot"
                        },
                        "hoverinfo": "skip",
                        "showlegend": False,
                        "name": "spin ↓",
                        "hovertemplate": "%{y:.2f} eV",
                        "xaxis": "x",
                        "yaxis": "y",
                    } for i in bands]

                bstraces += traces_for_segment

                # - Get proper cbm and vbm coords for lm
                if path_convention == "lm":
                    for extremum, new_coords, data_key in (
                        (cbm, cbm_new, "cbm"),
                        (vbm, vbm_new, "vbm"),
                    ):
                        for (x_point, y_point) in bs_data[data_key]:
                            if x_point in bs_data["distances"][d]:
                                xind = bs_data["distances"][d].index(x_point)
                                if not rev:
                                    x_point_new = x_dat[xind]
                                else:
                                    x_point_new = x_dat[len(x_dat) - xind - 1]

                                new_label = bs_data["ticks"]["label"][
                                    tick_vals.index(x_point_new)]

                                if (extremum["kpoint"].label is None
                                        or extremum["kpoint"].label
                                        in new_label):
                                    new_coords.append((x_point_new, y_point))

            bs_data["ticks"]["distance"] = tick_vals

            # - Strip latex math wrapping for labels
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "GAMMA": "Γ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
                "_{1}": "₁",
                "_{2}": "₂",
                "_{3}": "₃",
                "_{4}": "₄",
                "^{*}": "*",
            }

            bar_loc = []
            for entry_num in range(len(bs_data["ticks"]["label"])):
                for key in str_replace.keys():
                    if key in bs_data["ticks"]["label"][entry_num]:
                        bs_data["ticks"]["label"][entry_num] = bs_data[
                            "ticks"]["label"][entry_num].replace(
                                key, str_replace[key])
                        if key == "\\mid":
                            # Remember where path jumps occur for the white
                            # separator lines drawn below.
                            bar_loc.append(
                                bs_data["ticks"]["distance"][entry_num])

            # Vertical lines for disjointed segments
            vert_traces = [{
                "x": [x_point, x_point],
                "y": energy_window,
                "mode": "lines",
                "marker": {
                    "color": "white"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            } for x_point in bar_loc]

            bstraces += vert_traces

            # Dots for cbm and vbm

            def _band_edge_trace(x_point, y_point, hovertext):
                # One purple marker for a band-edge (CBM/VBM) point.
                # Bug fix: the VBM traces previously used mode "marker",
                # which is not a valid plotly scatter mode ("markers" is).
                return {
                    "x": [x_point],
                    "y": [y_point],
                    "mode": "markers",
                    "marker": {
                        "color": "#7E259B",
                        "size": 16,
                        "line": {
                            "color": "white",
                            "width": 2
                        },
                    },
                    "showlegend": False,
                    "hoverinfo": "text",
                    "name": "",
                    "hovertemplate": hovertext,
                    "xaxis": "x",
                    "yaxis": "y",
                }

            cbm_text = "CBM: k = {}, {} eV".format(
                list(cbm["kpoint"].frac_coords), cbm["energy"])
            vbm_text = "VBM: k = {}, {} eV".format(
                list(vbm["kpoint"].frac_coords), vbm["energy"])

            dot_traces = [
                _band_edge_trace(x_point, y_point, cbm_text)
                for (x_point, y_point) in set(cbm_new)
            ] + [
                _band_edge_trace(x_point, y_point, vbm_text)
                for (x_point, y_point) in set(vbm_new)
            ]

            bstraces += dot_traces

            # - DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(density_of_states)

            # Indices bracketing the plotted energy window (relative to E_F).
            dos_max = np.abs(
                (dos.energies - dos.efermi - energy_window[1])).argmin()
            dos_min = np.abs(
                (dos.energies - dos.efermi - energy_window[0])).argmin()

            if bs_reg_plot._bs.is_spin_polarized:
                # Add second spin data if available; spin-down densities are
                # negated so they mirror across the energy axis.
                trace_tdos = {
                    "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": "Total DOS (spin ↓)",
                    "line": go.scatter.Line(color="#444444", dash="dot"),
                    "fill": "tozerox",
                    "fillcolor": "#C4C4C4",
                    "xaxis": "x2",
                    "yaxis": "y2",
                }

                dostraces.append(trace_tdos)

                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            trace_tdos = {
                "x": dos.densities[Spin.up][dos_min:dos_max],
                "y": dos.energies[dos_min:dos_max] - dos.efermi,
                "mode": "lines",
                "name": tdos_label,
                "line": go.scatter.Line(color="#444444"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "legendgroup": "spinup",
                "xaxis": "x2",
                "yaxis": "y2",
            }

            dostraces.append(trace_tdos)

            ele_dos = dos.get_element_dos()
            elements = [str(entry) for entry in ele_dos.keys()]

            if dos_select == "ap":
                proj_data = ele_dos
            elif dos_select == "op":
                proj_data = dos.get_spd_dos()
            elif "orb" in dos_select:
                proj_data = dos.get_element_spd_dos(
                    Element(dos_select.replace("orb", "")))
            else:
                raise PreventUpdate

            # Projected DOS
            colors = [
                "#d62728",  # brick red
                "#2ca02c",  # cooked asparagus green
                "#17becf",  # blue-teal
                "#bcbd22",  # curry yellow-green
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
            ]

            # NOTE(review): more than len(colors) projections would raise
            # IndexError on colors[count] -- confirm the count is bounded.
            for count, label in enumerate(proj_data.keys()):

                if bs_reg_plot._bs.is_spin_polarized:
                    trace = {
                        "x":
                        -1.0 *
                        proj_data[label].densities[Spin.down][dos_min:dos_max],
                        "y":
                        dos.energies[dos_min:dos_max] - dos.efermi,
                        "mode":
                        "lines",
                        "name":
                        str(label) + " (spin ↓)",
                        "line":
                        dict(width=3, color=colors[count], dash="dot"),
                        "xaxis":
                        "x2",
                        "yaxis":
                        "y2",
                    }

                    dostraces.append(trace)
                    spin_up_label = str(label) + " (spin ↑)"

                else:
                    spin_up_label = str(label)

                trace = {
                    "x": proj_data[label].densities[Spin.up][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": spin_up_label,
                    "line": dict(width=2, color=colors[count]),
                    "xaxis": "x2",
                    "yaxis": "y2",
                }

                dostraces.append(trace)

            traces = [bstraces, dostraces, bs_data]

            return (traces, elements)
Exemple #4
0
def get_ent_from_db(
    elec_store: MongoStore,
    material_store: MongoStore,
    tasks_store: MongoStore,
    batt_id: Union[str, int] = None,
    task_id: Union[str, int] = None,
    get_aeccar: bool = False,
    working_ion: str = "Li",
    add_fields: list = None,
    get_initial: bool = False,
):
    """
    Get the migration path information in the form of a ComputedEntryGraph
    object from an atomate data stack.

    The algorithm gets all tasks with structures that are valid (i.e. match a
    base structure) and generates a migration pathway object using all
    possible relaxed working ion positions found in that set. Since each
    material entry might contain multiple calculations with different cell
    sizes, this has to work at the task level: tasks are grouped together
    based on the cell size of the base material.

    Note that SPGlib is sometimes inconsistent when it comes to getting the
    number of symmetry operations for a given structure. Structures that are
    the same according to StructureMatcher.fit may report different numbers
    of symmetry operations. As such we check the number of operations for
    each base structure in a given family of structures and take the case
    with the highest number of symmetry operations. In cases where AECCAR is
    required, only the tasks with AECCARs will have this data.

    Args:
        elec_store: Electrode documents, one per each similar group of
            insertion materials; can also be any store whose documents
            contain a material_ids list with topotactic structures.
        material_store: Material documents, one per each similar structure
            (multiple tasks).
        tasks_store: Task documents, one per each VASP calculation.
        batt_id: battery id to look up in a database.
        task_id: if battery id is not provided, look up by this task id.
        get_aeccar: If True, only find base tasks with the charge density
            stored, and attach the AECCAR data to each base entry.
        working_ion: Name of the working ion. Defaults to 'Li'.
        add_fields: Copy these fields from the task documents into the
            resulting ComputedStructureEntry data.
        get_initial: Store the initial structure of a calculation in the
            entry data.

    Returns:
        Tuple ``(all_ents_base, all_ents_insert)`` of lists of
        ComputedStructureEntry objects: the working-ion-free base tasks and
        the working-ion-containing inserted tasks, respectively.

    Raises:
        ValueError: if neither ``batt_id`` nor ``task_id`` is given.
    """
    # Fail early with a clear message instead of crashing on int(None)
    # inside get_batt_ids_from_task_id.
    if not batt_id and task_id is None:
        raise ValueError("Either batt_id or task_id must be provided.")

    # The stores are not consistent about id types (str vs int), so coerce
    # looked-up ids to whatever type the documents actually contain.
    task_ids_type = type(material_store.query_one({})["task_ids"][0])
    material_ids_type = type(elec_store.query_one({})["material_ids"][0])

    logger.debug(material_ids_type)

    def get_task_ids_from_batt_id(b_id):
        # battery id -> material_ids -> all task_ids of those materials
        mat_ids = list(
            map(task_ids_type,
                elec_store.query_one({"battid": b_id})["material_ids"]))
        logger.debug(f"mat_ids : {mat_ids}")
        l_task_ids = [
            imat["task_ids"]
            for imat in material_store.query({"task_ids": {
                "$in": mat_ids
            }})
        ]
        l_task_ids = list(chain.from_iterable(l_task_ids))
        logger.debug(f"l_task_ids : {l_task_ids}")
        return l_task_ids

    def get_batt_ids_from_task_id(t_id):
        # task id -> its material's task_ids -> matching electrode
        # material_ids -> all task_ids of those materials
        l_task_ids = [
            c0["task_ids"]
            for c0 in material_store.query({"task_ids": {
                "$in": [int(t_id)]
            }})
        ]
        l_task_ids = list(chain.from_iterable(l_task_ids))
        l_task_ids = list(map(material_ids_type, l_task_ids))
        logger.debug(f"l_task_ids : {l_task_ids}")
        l_mat_ids = [
            c0["material_ids"]
            for c0 in elec_store.query({"material_ids": {
                "$in": l_task_ids
            }})
        ]
        l_mat_ids = list(chain.from_iterable(l_mat_ids))
        l_mat_ids = list(map(task_ids_type, l_mat_ids))
        logger.debug(f"l_mat_ids : {l_mat_ids}")
        l_task_ids = [
            c0["task_ids"]
            for c0 in material_store.query({"task_ids": {
                "$in": l_mat_ids
            }})
        ]
        l_task_ids = list(chain.from_iterable(l_task_ids))
        logger.debug(f"l_task_ids : {l_task_ids}")
        return l_task_ids

    def get_entry(task_doc,
                  base_with_aeccar=False,
                  add_fields=None,
                  get_initial=None):
        # Build a ComputedStructureEntry from a raw task document.
        # We don't really need to think about compatibility for now if we
        # just want code that automates NEB calculations.
        tmp_struct = Structure.from_dict(task_doc["output"]["structure"])
        settings_dict = dict(
            potcar_spec=task_doc["calcs_reversed"][0]["input"]["potcar_spec"],
            # NOTE(review): "rung_type" looks like a typo of "run_type", but
            # this key is stored in the entry parameters and may be read
            # downstream — confirm all consumers before renaming.
            rung_type=task_doc["calcs_reversed"][0]["run_type"],
        )
        if "is_hubbard" in task_doc["calcs_reversed"][0].keys():
            settings_dict["hubbards"] = task_doc["calcs_reversed"][0][
                "hubbards"]
            settings_dict["is_hubbard"] = (
                task_doc["calcs_reversed"][0]["is_hubbard"], )

        entry = ComputedStructureEntry(
            structure=tmp_struct,
            energy=task_doc["output"]["energy"],
            parameters=settings_dict,
            entry_id=task_doc["task_id"],
        )
        if base_with_aeccar:
            logger.debug("test")
            aec_id = tasks_store.query_one({"task_id":
                                            entry.entry_id})["task_id"]
            aeccar = get_aeccar_from_store(tasks_store, aec_id)
            entry.data.update({"aeccar": aeccar})

        if add_fields:
            for field in add_fields:
                if field in task_doc:
                    entry.data.update({field: task_doc[field]})
        if get_initial:
            entry.data.update(
                {"initial_structure": task_doc["input"]["structure"]})

        return entry

    # Require a single base entry and multiple inserted entries to populate
    # the migration pathways

    # getting a full list of task ids
    # batt_id -> material_id -> all task_ids
    # task_id -> mat_ids -> batt_ids -> material_id -> all task_ids
    if batt_id:
        all_tasks = get_task_ids_from_batt_id(batt_id)
    else:
        all_tasks = get_batt_ids_from_task_id(task_id)

    # Base entries: tasks whose structures do NOT contain the working ion.
    # add_fields/get_initial are honored on both paths (previously they were
    # silently dropped when get_aeccar was False).
    base_criteria = {
        "task_id": {
            "$in": all_tasks
        },
        "elements": {
            "$nin": [working_ion]
        },
    }
    if get_aeccar:
        # Restrict to tasks that actually have a stored charge density.
        base_criteria["calcs_reversed.0.aeccar0_fs_id"] = {"$exists": 1}
    all_ents_base = [
        get_entry(
            c0,
            base_with_aeccar=get_aeccar,
            add_fields=add_fields,
            get_initial=get_initial,
        ) for c0 in tasks_store.query(base_criteria)
    ]
    logger.debug(f"Number of base entries: {len(all_ents_base)}")

    # Inserted entries: tasks whose structures DO contain the working ion.
    all_ents_insert = [
        get_entry(c0, add_fields=add_fields, get_initial=get_initial)
        for c0 in tasks_store.query({
            "task_id": {
                "$in": all_tasks
            },
            "elements": {
                "$in": [working_ion]
            }
        })
    ]
    logger.debug(f"Number of inserted entries: {len(all_ents_insert)}")
    tmp = [f"{itr.name}({itr.entry_id})" for itr in all_ents_insert]
    logger.debug(f"{tmp}")
    return all_ents_base, all_ents_insert
class JointStoreTest(unittest.TestCase):
    def setUp(self):
        """Populate the test1/test2 collections and connect all stores."""
        self.jointstore = JointStore("maggma_test", ["test1", "test2"])
        self.jointstore.connect()
        # Primary collection (test1): ten docs keyed by task_id 0..9.
        self.jointstore.collection.drop()
        self.jointstore.collection.insert_many([
            {
                "task_id": k,
                "my_prop": k + 1,
                "last_updated": datetime.utcnow(),
                "category": k // 5,
            } for k in range(10)
        ])
        # Secondary collection (test2): five docs with even task_ids 0..8.
        secondary = self.jointstore.collection.database["test2"]
        secondary.drop()
        secondary.insert_many([
            {
                "task_id": 2 * k,
                "your_prop": k + 3,
                "last_updated": datetime.utcnow(),
                "category2": k // 3,
            } for k in range(5)
        ])
        # Plain single-collection stores used to cross-check the join.
        self.test1 = MongoStore("maggma_test", "test1")
        self.test1.connect()
        self.test2 = MongoStore("maggma_test", "test2")
        self.test2.connect()

    def test_query(self):
        """All ten primary docs come back; five of them carry test2 data."""
        results = list(self.jointstore.query())
        self.assertEqual(len(results), 10)
        joined = sorted((d for d in results if "test2" in d),
                        key=lambda d: d["task_id"])
        self.assertEqual(len(joined), 5)
        lowest = joined[0]
        self.assertEqual(lowest["task_id"], 0)
        self.assertEqual(lowest["my_prop"], 1)
        self.assertEqual(lowest["test2"]["your_prop"], 3)

    def test_query_one(self):
        """query_one honors properties, criteria, and merge_at_root."""
        doc = self.jointstore.query_one()
        self.assertEqual(doc['my_prop'], doc['task_id'] + 1)

        # Limiting properties drops fields outside the projection.
        doc = self.jointstore.query_one(properties=['test2', 'task_id'])
        self.assertEqual(doc['test2']['your_prop'], doc['task_id'] + 3)
        self.assertIsNone(doc.get("my_prop"))

        # No match -> None; dotted criteria reach into the joined doc.
        self.assertIsNone(
            self.jointstore.query_one(criteria={"task_id": {"$gte": 10}}))
        doc = self.jointstore.query_one(
            criteria={"test2.your_prop": {
                "$gt": 6
            }})
        self.assertEqual(doc['task_id'], 8)

        # With merge_at_root the joined fields are flattened to the top...
        self.jointstore.merge_at_root = True
        doc = self.jointstore.query_one(criteria={"task_id": 2})
        self.assertEqual(doc['my_prop'], 3)
        self.assertEqual(doc['your_prop'], 4)

        # ...and can then be matched directly in criteria.
        doc = self.jointstore.query_one(criteria={"your_prop": {"$gt": 6}})
        self.assertEqual(doc['task_id'], 8)

    def test_distinct(self):
        """distinct() works on joined fields, plain fields, and criteria."""
        # your_prop = k + 3 for k in 0..4  ->  {3, 4, 5, 6, 7}
        self.assertEqual(set(self.jointstore.distinct("test2.your_prop")),
                         set(range(3, 8)))
        # my_prop = k + 1 for k in 0..9  ->  {1, ..., 10}
        self.assertEqual(set(self.jointstore.distinct("my_prop")),
                         set(range(1, 11)))
        conditioned = self.jointstore.distinct("my_prop",
                                               {"test2.your_prop": {
                                                   "$gte": 5
                                               }})
        self.assertEqual(set(conditioned), {5, 7, 9})

    def test_last_updated(self):
        """The joined doc carries the newer last_updated of its sources."""
        doc = self.jointstore.query_one({"task_id": 0})
        test1doc = self.test1.query_one({"task_id": 0})
        test2doc = self.test2.query_one({"task_id": 0})
        # Initially the joined doc matches test2's timestamp, not test1's.
        self.assertEqual(doc['last_updated'], test2doc['last_updated'])
        self.assertNotEqual(doc['last_updated'], test1doc['last_updated'])

        # Swap the timestamps between the two source docs and persist them
        # verbatim (update_lu=False keeps the stores from re-stamping).
        test1doc['last_updated'], test2doc['last_updated'] = (
            test2doc['last_updated'], test1doc['last_updated'])
        self.test1.update([test1doc], update_lu=False)
        self.test2.update([test2doc], update_lu=False)

        # After the swap the joined doc follows test1 instead.
        doc = self.jointstore.query_one({"task_id": 0})
        test1doc = self.test1.query_one({"task_id": 0})
        test2doc = self.test2.query_one({"task_id": 0})
        self.assertEqual(doc['last_updated'], test1doc['last_updated'])
        self.assertNotEqual(doc['last_updated'], test2doc['last_updated'])

        # A doc with no test2 counterpart still exposes last_updated.
        doc = self.jointstore.query_one({"task_id": 1})
        self.assertIsNotNone(doc['last_updated'])

    def test_groupby(self):
        """groupby partitions correctly on plain and on joined fields."""
        groups = list(self.jointstore.groupby("category"))
        self.assertEqual(len(groups[0]['docs']), 5)
        self.assertEqual(len(groups[1]['docs']), 5)
        # Grouping on a test2 field: docs without a test2 match land in the
        # None bucket.
        groups = list(self.jointstore.groupby("test2.category2"))
        sizes = {
            get(g, '_id.test2.category2'): len(g['docs'])
            for g in groups
        }
        self.assertEqual(sizes[None], 5)
        self.assertEqual(sizes[0], 3)
        self.assertEqual(sizes[1], 2)