Example 1
def test_mongostore_newer_in(mongostore):
    target = MongoStore("maggma_test", "test_target")
    target.connect()

    # make sure docs are newer in mongostore than in target and check updated_keys

    target.update(
        [
            {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()}
            for i in range(10)
        ]
    )

    # Update docs in source
    mongostore.update(
        [
            {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()}
            for i in range(10)
        ]
    )

    assert len(target.newer_in(mongostore)) == 10
    assert len(target.newer_in(mongostore, exhaustive=True)) == 10
    assert len(mongostore.newer_in(target)) == 0

    target._collection.drop()
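For reference, a minimal sketch of how newer_in can drive an incremental copy outside of a test, assuming a local MongoDB and hypothetical collection names:

from maggma.stores import MongoStore

source = MongoStore("maggma_test", "source_docs")
target = MongoStore("maggma_test", "target_docs")
source.connect()
target.connect()

# Keys whose documents are newer in `source` than in `target`
stale_keys = list(target.newer_in(source))
if stale_keys:
    target.update(list(source.query(criteria={source.key: {"$in": stale_keys}})))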
Example 2
def reporting_store():
    store = MongoStore("maggma_test", "reporting")
    store.connect()
    store.remove_docs({})
    yield store
    store.remove_docs({})
    store._collection.drop()
Example 3
def mongostore():
    store = MongoStore("maggma_test", "test")
    store.connect()
    store.remove_docs({})
    yield store
    store.remove_docs({})
    store._collection.drop()
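The two functions above are pytest fixture bodies; a self-contained sketch of how such a fixture is typically declared and consumed (the decorator, imports, and the sample test are assumptions, since the excerpts show only the bodies):

import pytest
from maggma.stores import MongoStore


@pytest.fixture
def mongostore():
    store = MongoStore("maggma_test", "test")
    store.connect()
    store.remove_docs({})
    yield store
    store.remove_docs({})
    store._collection.drop()


def test_store_starts_empty(mongostore):
    # The fixture hands each test a freshly emptied collection.
    assert mongostore.query_one() is None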
Example 4
class TestDiffractionBuilder(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.source.connect()
        self.source.collection.create_index("lu")
        self.source.collection.create_index("k", unique=True)
        self.target.connect()
        self.target.collection.create_index("lu")
        self.target.collection.create_index("k", unique=True)

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_xrd_from_struct(self):
        builder = DiffractionBuilder(self.source, self.target)
        structure = PymatgenTest.get_structure("Si")
        self.assertIn("Cu", builder.get_xrd_from_struct(structure))

    def test_serialization(self):
        builder = DiffractionBuilder(self.source, self.target)
        self.assertIsNone(builder.as_dict()["xrd_settings"])
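The setUpClass/tearDownClass pair above isolates the whole test class in a uniquely named database; the same pattern shown standalone with pymongo (a local MongoDB on the default port is an assumption):

from uuid import uuid4
from pymongo import MongoClient

dbname = "test_" + uuid4().hex               # unique name avoids collisions between runs
client = MongoClient("localhost", 27017)
client[dbname]["source"].insert_one({"k": 1})
assert client[dbname]["source"].count_documents({}) == 1
client.drop_database(dbname)                 # nothing is left behind afterwards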
Example 5
class TestMaterials(BuilderTest):
    def setUp(self):
        self.materials = MongoStore("emmet_test", "materials")
        self.materials.connect()

        self.materials.collection.drop()
        self.mbuilder = MaterialsBuilder(self.tasks,
                                         self.materials,
                                         mat_prefix="",
                                         chunk_size=1)

    def test_get_items(self):
        to_process = list(self.mbuilder.get_items())
        to_process_forms = {tasks[0]["formula_pretty"] for tasks in to_process}

        self.assertEqual(len(to_process), 12)
        self.assertEqual(len(to_process_forms), 12)
        self.assertEqual(len(list(chain.from_iterable(to_process))), 197)
        self.assertTrue("Sr" in to_process_forms)
        self.assertTrue("Hf" in to_process_forms)
        self.assertTrue("O2" in to_process_forms)
        self.assertFalse("H" in to_process_forms)

    def test_process_item(self):
        tasks = list(self.tasks.query(criteria={"chemsys": "Sr"}))
        mats = self.mbuilder.process_item(tasks)
        self.assertEqual(len(mats), 7)

        tasks = list(self.tasks.query(criteria={"chemsys": "Hf"}))
        mats = self.mbuilder.process_item(tasks)
        self.assertEqual(len(mats), 4)

        tasks = list(self.tasks.query(criteria={"chemsys": "O"}))
        mats = self.mbuilder.process_item(tasks)

        self.assertEqual(len(mats), 6)

        tasks = list(self.tasks.query(criteria={"chemsys": "O-Sr"}))
        mats = self.mbuilder.process_item(tasks)
        self.assertEqual(len(mats), 5)

        tasks = list(self.tasks.query(criteria={"chemsys": "Hf-O-Sr"}))
        mats = self.mbuilder.process_item(tasks)
        self.assertEqual(len(mats), 13)

    def test_update_targets(self):
        tasks = list(self.tasks.query(criteria={"chemsys": "Sr"}))
        mats = self.mbuilder.process_item(tasks)
        self.assertEqual(len(mats), 7)

        self.mbuilder.update_targets([mats])
        self.assertEqual(len(self.materials.distinct("task_id")), 7)
        self.assertEqual(len(list(self.materials.query())), 7)

    def tearDown(self):
        self.materials.collection.drop()
Example 6
def test_mongostore_connect_via_ssh():
    mongostore = MongoStore("maggma_test", "test")

    class fake_pipe:
        remote_bind_address = ("localhost", 27017)
        local_bind_address = ("localhost", 37017)

    server = fake_pipe()
    mongostore.connect(ssh_tunnel=server)
    assert isinstance(mongostore._collection, pymongo.collection.Collection)
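The test above stubs out the tunnel with a fake object that only exposes the two bind addresses; a hedged sketch of the same call with a real tunnel (the sshtunnel usage, gateway host, and credentials are assumptions, not part of the test):

from sshtunnel import SSHTunnelForwarder
from maggma.stores import MongoStore

tunnel = SSHTunnelForwarder(
    ("gateway.example.com", 22),              # hypothetical SSH gateway
    ssh_username="user",
    remote_bind_address=("localhost", 27017),
    local_bind_address=("localhost", 37017),
)
tunnel.start()

store = MongoStore("maggma_test", "test")
store.connect(ssh_tunnel=tunnel)              # same keyword the test exercises
# ... use the store ...
store.close()
tunnel.stop()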
Example 7
class TestMaterials(BuilderTest):
    def setUp(self):
        self.ml_strucs = MongoStore("emmet_test", "ml_strucs", key="entry_id")
        self.ml_strucs.connect()

        self.ml_strucs.collection.drop()
        self.mlbuilder = MLStructuresBuilder(
            self.tasks,
            self.ml_strucs,
            task_types=("Structure Optimization", "Static"))

    def test_get_items(self):
        to_process = list(self.mlbuilder.get_items())
        to_process_forms = {task["formula_pretty"] for task in to_process}

        self.assertEqual(len(to_process), 197)
        self.assertEqual(len(to_process_forms), 12)
        self.assertTrue("Sr" in to_process_forms)
        self.assertTrue("Hf" in to_process_forms)
        self.assertTrue("O2" in to_process_forms)
        self.assertFalse("H" in to_process_forms)

    def test_process_item(self):
        for task in self.tasks.query():
            ml_strucs = self.mlbuilder.process_item(task)
            t_type = task_type(get(task, 'input.incar'))
            if not any([t in t_type for t in self.mlbuilder.task_types]):
                self.assertEqual(len(ml_strucs), 0)
            else:
                self.assertEqual(
                    len(ml_strucs),
                    sum([
                        len(t["output"]["ionic_steps"])
                        for t in task["calcs_reversed"]
                    ]))

    def test_update_targets(self):
        for task in self.tasks.query():
            ml_strucs = self.mlbuilder.process_item(task)
            self.mlbuilder.update_targets([ml_strucs])
        self.assertEqual(len(self.ml_strucs.distinct("task_id")), 102)
        self.assertEqual(len(list(self.ml_strucs.query())), 1012)

    def tearDown(self):
        self.ml_strucs.collection.drop()
Example 8
class JointStoreTest(unittest.TestCase):
    def setUp(self):
        self.jointstore = JointStore("maggma_test", ["test1", "test2"])
        self.jointstore.connect()
        self.jointstore.collection.drop()
        self.jointstore.collection.insert_many([{
            "task_id": k,
            "my_prop": k + 1,
            "last_updated": datetime.utcnow(),
            "category": k // 5
        } for k in range(10)])
        self.jointstore.collection.database["test2"].drop()
        self.jointstore.collection.database["test2"].insert_many([{
            "task_id": 2 * k,
            "your_prop": k + 3,
            "last_updated": datetime.utcnow(),
            "category2": k // 3
        } for k in range(5)])
        self.test1 = MongoStore("maggma_test", "test1")
        self.test1.connect()
        self.test2 = MongoStore("maggma_test", "test2")
        self.test2.connect()

    def test_query(self):
        # Test query all
        docs = list(self.jointstore.query())
        self.assertEqual(len(docs), 10)
        docs_w_field = [d for d in docs if "test2" in d]
        self.assertEqual(len(docs_w_field), 5)
        docs_w_field = sorted(docs_w_field, key=lambda x: x['task_id'])
        self.assertEqual(docs_w_field[0]['test2']['your_prop'], 3)
        self.assertEqual(docs_w_field[0]['task_id'], 0)
        self.assertEqual(docs_w_field[0]['my_prop'], 1)

    def test_query_one(self):
        doc = self.jointstore.query_one()
        self.assertEqual(doc['my_prop'], doc['task_id'] + 1)
        # Test limit properties
        doc = self.jointstore.query_one(properties=['test2', 'task_id'])
        self.assertEqual(doc['test2']['your_prop'], doc['task_id'] + 3)
        self.assertIsNone(doc.get("my_prop"))
        # Test criteria
        doc = self.jointstore.query_one(criteria={"task_id": {"$gte": 10}})
        self.assertIsNone(doc)
        doc = self.jointstore.query_one(
            criteria={"test2.your_prop": {
                "$gt": 6
            }})
        self.assertEqual(doc['task_id'], 8)

        # Test merge_at_root
        self.jointstore.merge_at_root = True

        # Test merging is working properly
        doc = self.jointstore.query_one(criteria={"task_id": 2})
        self.assertEqual(doc['my_prop'], 3)
        self.assertEqual(doc['your_prop'], 4)

        # Test merging is allowing for subsequent match
        doc = self.jointstore.query_one(criteria={"your_prop": {"$gt": 6}})
        self.assertEqual(doc['task_id'], 8)

    def test_distinct(self):
        dyour_prop = self.jointstore.distinct("test2.your_prop")
        self.assertEqual(set(dyour_prop), {k + 3 for k in range(5)})
        dmy_prop = self.jointstore.distinct("my_prop")
        self.assertEqual(set(dmy_prop), {k + 1 for k in range(10)})
        dmy_prop_cond = self.jointstore.distinct(
            "my_prop", {"test2.your_prop": {
                "$gte": 5
            }})
        self.assertEqual(set(dmy_prop_cond), {5, 7, 9})

    def test_last_updated(self):
        doc = self.jointstore.query_one({"task_id": 0})
        test1doc = self.test1.query_one({"task_id": 0})
        test2doc = self.test2.query_one({"task_id": 0})
        self.assertEqual(test2doc['last_updated'], doc['last_updated'])
        self.assertNotEqual(test1doc['last_updated'], doc['last_updated'])
        # Swap the two
        test2date = test2doc['last_updated']
        test2doc['last_updated'] = test1doc['last_updated']
        test1doc['last_updated'] = test2date
        self.test1.update([test1doc], update_lu=False)
        self.test2.update([test2doc], update_lu=False)
        doc = self.jointstore.query_one({"task_id": 0})
        test1doc = self.test1.query_one({"task_id": 0})
        test2doc = self.test2.query_one({"task_id": 0})
        self.assertEqual(test1doc['last_updated'], doc['last_updated'])
        self.assertNotEqual(test2doc['last_updated'], doc['last_updated'])
        # Check also that still has a field if no task2 doc
        doc = self.jointstore.query_one({"task_id": 1})
        self.assertIsNotNone(doc['last_updated'])

    def test_groupby(self):
        docs = list(self.jointstore.groupby("category"))
        self.assertEqual(len(docs[0]['docs']), 5)
        self.assertEqual(len(docs[1]['docs']), 5)
        docs = list(self.jointstore.groupby("test2.category2"))
        docs_by_id = {get(d, '_id.test2.category2'): d['docs'] for d in docs}
        self.assertEqual(len(docs_by_id[None]), 5)
        self.assertEqual(len(docs_by_id[0]), 3)
        self.assertEqual(len(docs_by_id[1]), 2)
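A condensed sketch of the JointStore behaviour these tests check, using only calls that appear above (the maggma.stores import path and the sample data are assumptions):

from datetime import datetime
from maggma.stores import MongoStore, JointStore

test1 = MongoStore("maggma_test", "test1")
test2 = MongoStore("maggma_test", "test2")
for s in (test1, test2):
    s.connect()

test1.collection.insert_many([{"task_id": 0, "my_prop": 1, "last_updated": datetime.utcnow()}])
test2.collection.insert_many([{"task_id": 0, "your_prop": 3, "last_updated": datetime.utcnow()}])

joint = JointStore("maggma_test", ["test1", "test2"])
joint.connect()
doc = joint.query_one(criteria={"task_id": 0})
print(doc["my_prop"], doc["test2"]["your_prop"])  # 1 3: the second collection is nested under its name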
Example 9
def setUpClass(cls):
    cls.dbname = "test_" + uuid4().hex
    s = MongoStore(cls.dbname, "test")
    s.connect()
    cls.client = s.collection.database.client
Example 10
class ElasticAnalysisBuilderTest(unittest.TestCase):
    def setUp(self):
        # Set up test db, set up mpsft, etc.
        self.test_tasks = MongoStore("test_emmet", "tasks")
        self.test_tasks.connect()
        docs = loadfn(test_tasks, cls=None)
        self.test_tasks.update(docs)
        self.test_elasticity = MongoStore("test_emmet", "elasticity")
        self.test_elasticity.connect()
        if PROFILE_MODE:
            self.pr = cProfile.Profile()
            self.pr.enable()
            print("\n<<<---")

    def tearDown(self):
        if not DEBUG_MODE:
            self.test_elasticity.collection.drop()
            self.test_tasks.collection.drop()
        if PROFILE_MODE:
            p = Stats(self.pr)
            p.strip_dirs()
            p.sort_stats('cumtime')
            p.print_stats()
            print("\n--->>>")

    def test_builder(self):
        ec_builder = ElasticAnalysisBuilder(self.test_tasks,
                                            self.test_elasticity,
                                            incremental=False)
        ec_builder.connect()
        for t in ec_builder.get_items():
            processed = ec_builder.process_item(t)
            self.assertTrue(bool(processed))
        runner = Runner([ec_builder])
        runner.run()
        # Test warnings
        doc = ec_builder.elasticity.query_one(
            criteria={"pretty_formula": "NaN3"})
        self.assertEqual(doc['warnings'], None)
        self.assertAlmostEqual(doc['compliance_tensor'][0][0], 41.576072, 6)

    def test_grouping_functions(self):
        docs1 = list(
            self.test_tasks.query(criteria={"formula_pretty": "NaN3"}))
        docs_grouped1 = group_by_parent_lattice(docs1)
        self.assertEqual(len(docs_grouped1), 1)
        grouped_by_opt = group_deformations_by_optimization_task(docs1)
        self.assertEqual(len(grouped_by_opt), 1)
        docs2 = self.test_tasks.query(
            criteria={"task_label": "elastic deformation"})
        sgroup2 = group_by_parent_lattice(docs2)

    def test_get_distinct_rotations(self):
        struct = PymatgenTest.get_structure("Si")
        conv = SpacegroupAnalyzer(struct).get_conventional_standard_structure()
        rots = get_distinct_rotations(conv)
        ops = SpacegroupAnalyzer(conv).get_symmetry_operations()
        for op in ops:
            self.assertTrue(
                any([np.allclose(op.rotation_matrix, r) for r in rots]))
        self.assertEqual(len(rots), 48)

    def test_process_elastic_calcs(self):
        test_struct = PymatgenTest.get_structure('Sn')  # use cubic test struct
        dss = DeformedStructureSet(test_struct)

        # Construct test task set
        opt_task = {
            "output": {
                "structure": test_struct.as_dict()
            },
            "input": {
                "structure": test_struct.as_dict()
            }
        }
        defo_tasks = []
        for n, (struct, defo) in enumerate(zip(dss, dss.deformations)):
            strain = defo.green_lagrange_strain
            defo_task = {
                "output": {
                    "structure": struct.as_dict(),
                    "stress": (strain * 5).tolist()
                },
                "input": None,
                "task_id": n,
                "completed_at": datetime.utcnow()
            }
            defo_task.update({
                "transmuter": {
                    "transformation_params": [{
                        "deformation": defo
                    }]
                }
            })
            defo_tasks.append(defo_task)

        defo_tasks.pop(0)
        explicit, derived = process_elastic_calcs(opt_task, defo_tasks)
        self.assertEqual(len(explicit), 23)
        self.assertEqual(len(derived), 1)

    def test_process_elastic_calcs_toec(self):
        # Test TOEC tasks
        test_struct = PymatgenTest.get_structure('Sn')  # use cubic test struct
        strain_states = get_default_strain_states(3)
        # Default stencil in atomate; this perhaps shouldn't be hard-coded
        stencil = np.linspace(-0.075, 0.075, 7)
        strains = [
            Strain.from_voigt(s * np.array(strain_state))
            for s, strain_state in product(stencil, strain_states)
        ]
        strains = [s for s in strains if not np.allclose(s, 0)]
        sym_reduced = symmetry_reduce(strains, test_struct)
        opt_task = {
            "output": {
                "structure": test_struct.as_dict()
            },
            "input": {
                "structure": test_struct.as_dict()
            }
        }
        defo_tasks = []
        for n, strain in enumerate(sym_reduced):
            defo = strain.get_deformation_matrix()
            new_struct = defo.apply_to_structure(test_struct)
            defo_task = {
                "output": {
                    "structure": new_struct.as_dict(),
                    "stress": (strain * 5).tolist()
                },
                "input": None,
                "task_id": n,
                "completed_at": datetime.utcnow()
            }
            defo_task.update({
                "transmuter": {
                    "transformation_params": [{
                        "deformation": defo
                    }]
                }
            })
            defo_tasks.append(defo_task)
        explicit, derived = process_elastic_calcs(opt_task, defo_tasks)
        self.assertEqual(len(explicit), len(sym_reduced))
        self.assertEqual(len(derived), len(strains) - len(sym_reduced))
        for calc in derived:
            self.assertTrue(
                np.allclose(calc['strain'], calc['cauchy_stress'] / -0.5))
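A quick check of the fixture counts asserted in test_process_elastic_calcs above: the default DeformedStructureSet carries 24 deformations, so dropping one task leaves 23 explicit calculations plus 1 derived by symmetry. A small sketch (import paths assumed from pymatgen):

from pymatgen.util.testing import PymatgenTest
from pymatgen.analysis.elasticity.strain import DeformedStructureSet

struct = PymatgenTest.get_structure("Sn")
dss = DeformedStructureSet(struct)
print(len(dss.deformations))  # 24 with the default strain grid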
Example 11
class ElasticAggregateBuilderTest(unittest.TestCase):
    def setUp(self):
        # Empty aggregated collection
        self.test_elasticity_agg = MongoStore("test_emmet", "elasticity_agg")
        self.test_elasticity_agg.connect()

        # Generate test materials collection
        self.test_materials = MongoStore("test_emmet", "materials")
        self.test_materials.connect()
        mat_docs = []
        for n, formula in enumerate(['Si', 'BaNiO3', 'Li2O2', 'TiO2']):
            structure = PymatgenTest.get_structure(formula)
            structure.add_site_property("magmoms", [0.0] * len(structure))
            mat_docs.append({
                "task_id": "mp-{}".format(n),
                "structure": structure.as_dict(),
                "pretty_formula": formula
            })
        self.test_materials.update(mat_docs, update_lu=False)

        # Create elasticity collection and add docs
        self.test_elasticity = MongoStore("test_emmet",
                                          "elasticity",
                                          key="optimization_task_id")
        self.test_elasticity.connect()

        si = PymatgenTest.get_structure("Si")
        si.add_site_property("magmoms", [0.0] * len(si))
        et = ElasticTensor.from_voigt([[50, 25, 25, 0, 0, 0],
                                       [25, 50, 25, 0, 0, 0],
                                       [25, 25, 50, 0, 0, 0],
                                       [0, 0, 0, 75, 0,
                                        0], [0, 0, 0, 0, 75, 0],
                                       [0, 0, 0, 0, 0, 75]])
        doc = {
            "input_structure": si.copy().as_dict(),
            "order": 2,
            "magnetic_type": "non-magnetic",
            "optimization_task_id": "mp-1",
            "last_updated": datetime.utcnow(),
            "completed_at": datetime.utcnow(),
            "optimized_structure": si.copy().as_dict(),
            "pretty_formula": "Si",
            "state": "successful"
        }
        doc['elastic_tensor'] = et.voigt
        doc.update(et.property_dict)
        self.test_elasticity.update([doc])
        # Insert second doc with diff params
        si.perturb(0.005)
        doc.update({
            "optimized_structure": si.copy().as_dict(),
            "updated_at": datetime.utcnow(),
            "optimization_task_id": "mp-5"
        })
        self.test_elasticity.update([doc])
        self.builder = self.get_a_new_builder()

    def tearDown(self):
        if not DEBUG_MODE:
            self.test_elasticity.collection.drop()
            self.test_elasticity_agg.collection.drop()
            self.test_materials.collection.drop()

    def test_materials_aggregator(self):
        materials_dict = generate_formula_dict(self.test_materials)
        docs = []
        grouped_by_mpid = group_by_material_id(
            materials_dict['Si'],
            [{
                'structure': PymatgenTest.get_structure('Si').as_dict(),
                'magnetic_type': "non-magnetic"
            }])
        self.assertEqual(len(grouped_by_mpid), 1)
        materials_dict = generate_formula_dict(self.test_materials)

    def test_get_items(self):
        iterator = self.builder.get_items()
        for item in iterator:
            self.assertIsNotNone(item)

    def test_process_items(self):
        docs = list(
            self.test_elasticity.query(criteria={"pretty_formula": "Si"}))
        formula_dict = generate_formula_dict(self.test_materials)
        processed = self.builder.process_item((docs, formula_dict['Si']))
        self.assertEqual(len(processed), 1)
        self.assertEqual(len(processed[0]['all_elastic_fits']), 2)

    def test_update_targets(self):
        processed = [
            self.builder.process_item(item)
            for item in self.builder.get_items()
        ]
        self.builder.update_targets(processed)

    def test_aggregation(self):
        runner = Runner([self.builder])
        runner.run()
        all_agg_docs = list(self.test_elasticity_agg.query())
        self.assertTrue(bool(all_agg_docs))

    def get_a_new_builder(self):
        return ElasticAggregateBuilder(self.test_elasticity,
                                       self.test_materials,
                                       self.test_elasticity_agg)
Example 12
def test_mongostore_connect():
    mongostore = MongoStore("maggma_test", "test")
    assert mongostore._collection is None
    mongostore.connect()
    assert isinstance(mongostore._collection, pymongo.collection.Collection)
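A short sketch of the connect/close lifecycle this test checks, with explicit connection parameters (host, port, and names are placeholders):

from maggma.stores import MongoStore

store = MongoStore("maggma_test", "test", host="localhost", port=27017)
assert store._collection is None      # nothing is opened until connect()
store.connect()
print(store.query_one())              # None for an empty collection
store.close()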
Example 13
class TestCopyBuilder(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        self.old_docs = [{"lu": tic, "k": k, "v": "old"} for k in keys]
        self.new_docs = [{"lu": toc, "k": k, "v": "new"} for k in keys[:10]]
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.builder = CopyBuilder(self.source, self.target)
        self.source.connect()
        self.source.collection.create_index("lu")
        self.target.connect()
        self.target.collection.create_index("lu")
        self.target.collection.create_index("k")

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_items(self):
        self.source.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.old_docs))
        self.target.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())), 0)
        self.source.update(self.new_docs, update_lu=False)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.new_docs))

    def test_process_item(self):
        self.source.collection.insert_many(self.old_docs)
        items = list(self.builder.get_items())
        self.assertCountEqual(items, map(self.builder.process_item, items))

    def test_update_targets(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        items = list(map(self.builder.process_item, self.builder.get_items()))
        self.builder.update_targets(items)
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_confirm_lu_field_index(self):
        self.source.collection.drop_index("lu_1")
        with self.assertRaises(Exception) as cm:
            self.builder.get_items()
        self.assertTrue(cm.exception.args[0].startswith("Need index"))
        self.source.collection.create_index("lu")

    def test_runner(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        runner = Runner([self.builder])
        runner.run()
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_query(self):
        self.builder.query = {"k": {"$gt": 5}}
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        runner = Runner([self.builder])
        runner.run()
        all_docs = list(self.target.query(criteria={}))
        self.assertEqual(len(all_docs), 14)
        self.assertEqual(min(d['k'] for d in all_docs), 6)
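A standalone sketch of the CopyBuilder plus Runner wiring exercised above (database and collection names are placeholders, and the import paths are assumptions that may differ between maggma releases):

from maggma.stores import MongoStore
from maggma.builders import CopyBuilder   # import path assumed
from maggma.runner import Runner          # import path assumed

source = MongoStore("maggma_test", "source", key="k", lu_field="lu")
target = MongoStore("maggma_test", "target", key="k", lu_field="lu")
for store in (source, target):
    store.connect()
    store.ensure_index(store.key)
    store.ensure_index(store.lu_field)

builder = CopyBuilder(source, target)     # copies documents whose "lu" is newer in source
Runner([builder]).run()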
Example 14
class TestCopyBuilder(TestCase):
    @classmethod
    def setUpClass(cls):
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        self.old_docs = [{"lu": tic, "k": k, "v": "old"} for k in keys]
        self.new_docs = [{"lu": toc, "k": k, "v": "new"} for k in keys[:10]]
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.builder = CopyBuilder(self.source, self.target)

        self.source.connect()
        self.source.ensure_index(self.source.key)
        self.source.ensure_index(self.source.lu_field)

        self.target.connect()
        self.target.ensure_index(self.target.key)
        self.target.ensure_index(self.target.lu_field)

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_items(self):
        self.source.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.old_docs))
        self.target.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())), 0)
        self.source.update(self.new_docs, update_lu=False)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.new_docs))

    def test_process_item(self):
        self.source.collection.insert_many(self.old_docs)
        items = list(self.builder.get_items())
        self.assertCountEqual(items, map(self.builder.process_item, items))

    def test_update_targets(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        items = list(map(self.builder.process_item, self.builder.get_items()))
        self.builder.update_targets(items)
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    @unittest.skip(
        "Have to refactor how we force read-only so a warning will get thrown")
    def test_index_warning(self):
        """Should log warning when recommended store indexes are not present."""
        self.source.collection.drop_index([(self.source.key, 1)])
        with self.assertLogs(level=logging.WARNING) as cm:
            list(self.builder.get_items())
        self.assertIn("Ensure indices", "\n".join(cm.output))

    def test_run(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        self.builder.run()
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_query(self):
        self.builder.query = {"k": {"$gt": 5}}
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.builder.run()
        all_docs = list(self.target.query(criteria={}))
        self.assertEqual(len(all_docs), 14)
        self.assertEqual(min(d['k'] for d in all_docs), 6)

    def test_delete_orphans(self):
        self.builder = CopyBuilder(self.source,
                                   self.target,
                                   delete_orphans=True)
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)

        deletion_criteria = {"k": {"$in": list(range(5))}}
        self.source.collection.delete_many(deletion_criteria)
        self.builder.run()

        self.assertEqual(
            self.target.collection.count_documents(deletion_criteria), 0)
        self.assertEqual(self.target.query_one(criteria={"k": 5})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_incremental_false(self):
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        earlier = [{"lu": tic, "k": k, "v": "val"} for k in keys]
        later = [{"lu": toc, "k": k, "v": "val"} for k in keys]
        self.source.collection.insert_many(earlier)
        self.target.collection.insert_many(later)
        query = {"k": {"$gt": 5}}
        self.builder = CopyBuilder(self.source,
                                   self.target,
                                   incremental=False,
                                   query=query)
        self.builder.run()
        docs = sorted(self.target.query(), key=lambda d: d["k"])
        self.assertTrue(all(d["lu"] == tic for d in docs[6:]))
        self.assertTrue(all(d["lu"] == toc for d in docs[:6]))
Example 15
class TestThermo(BuilderTest):
    def setUp(self):

        self.materials = MongoStore("emmet_test", "materials")
        self.thermo = MongoStore("emmet_test", "thermo")

        self.materials.connect()
        self.thermo.connect()

        self.mbuilder = MaterialsBuilder(self.tasks,
                                         self.materials,
                                         mat_prefix="",
                                         chunk_size=1)
        self.tbuilder = ThermoBuilder(self.materials,
                                      self.thermo,
                                      chunk_size=1)
        runner = Runner([self.mbuilder])
        runner.run()

    def test_get_entries(self):
        self.assertEqual(len(self.tbuilder.get_entries("Sr")), 7)
        self.assertEqual(len(self.tbuilder.get_entries("Hf")), 4)
        self.assertEqual(len(self.tbuilder.get_entries("O")), 6)
        self.assertEqual(len(self.tbuilder.get_entries("Hf-O-Sr")), 44)
        self.assertEqual(len(self.tbuilder.get_entries("Sr-Hf")), 11)

    def test_get_items(self):
        self.thermo.collection.drop()
        comp_systems = list(self.tbuilder.get_items())
        self.assertEqual(len(comp_systems), 1)
        self.assertEqual(len(comp_systems[0]), 44)

    def test_process_item(self):

        tbuilder = ThermoBuilder(self.materials,
                                 self.thermo,
                                 query={"elements": ["Sr"]},
                                 chunk_size=1)
        entries = list(tbuilder.get_items())[0]
        self.assertEqual(len(entries), 7)

        t_docs = self.tbuilder.process_item(entries)
        e_above_hulls = [t['thermo']['e_above_hull'] for t in t_docs]
        sorted_t_docs = list(
            sorted(t_docs, key=lambda x: x['thermo']['e_above_hull']))
        self.assertEqual(sorted_t_docs[0]["task_id"], "mp-76")

    def test_update_targets(self):
        self.thermo.collection.drop()

        tbuilder = ThermoBuilder(self.materials,
                                 self.thermo,
                                 query={"elements": ["Sr"]},
                                 chunk_size=1)
        entries = list(tbuilder.get_items())[0]
        self.assertEqual(len(entries), 7)

        t_docs = self.tbuilder.process_item(entries)
        self.tbuilder.update_targets([t_docs])
        self.assertEqual(len(list(self.thermo.query())), len(t_docs))

    def tearDown(self):
        self.materials.collection.drop()
        self.thermo.collection.drop()
Example 16
def jointstore_test2():
    store = MongoStore("maggma_test", "test2")
    store.connect()
    yield store
    store._collection.drop()
Example 17
        def bs_dos_data(
            mpid,
            path_convention,
            dos_select,
            label_select,
            bandstructure_symm_line,
            density_of_states,
        ):
            if not mpid and (bandstructure_symm_line is None
                             or density_of_states is None):
                raise PreventUpdate

            elif bandstructure_symm_line is None or density_of_states is None:
                if label_select == "":
                    raise PreventUpdate

                # --
                # -- BS and DOS from API or DB
                # --

                bs_data = {"ticks": {}}

                bs_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="bandstructure_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                dos_store = GridFSStore(
                    database="fw_bs_prod",
                    collection_name="dos_fs",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                )

                es_store = MongoStore(
                    database="fw_bs_prod",
                    collection_name="electronic_structure",
                    host="mongodb03.nersc.gov",
                    port=27017,
                    username="******",
                    password="",
                    key="task_id",
                )

                # - BS traces from DB using task_id
                es_store.connect()
                bs_query = es_store.query_one(
                    criteria={"task_id": int(mpid)},
                    properties=[
                        "bandstructure.{}.task_id".format(path_convention),
                        "bandstructure.{}.total.equiv_labels".format(
                            path_convention),
                    ],
                )

                es_store.close()

                bs_store.connect()
                bandstructure_symm_line = bs_store.query_one(criteria={
                    "metadata.task_id":
                    int(bs_query["bandstructure"][path_convention]["task_id"])
                }, )

                # If LM convention, get equivalent labels
                if path_convention != label_select:
                    bs_equiv_labels = bs_query["bandstructure"][
                        path_convention]["total"]["equiv_labels"]

                    new_labels_dict = {}
                    for label in bandstructure_symm_line["labels_dict"].keys():

                        label_formatted = label.replace("$", "")

                        if "|" in label_formatted:
                            f_label = label_formatted.split("|")
                            new_labels_dict[
                                "$" +
                                bs_equiv_labels[label_select][f_label[0]] +
                                "|" +
                                bs_equiv_labels[label_select][f_label[1]] +
                                "$"] = bandstructure_symm_line[
                                    "labels_dict"][label]
                        else:
                            new_labels_dict["$" + bs_equiv_labels[label_select]
                                            [label_formatted] +
                                            "$"] = bandstructure_symm_line[
                                                "labels_dict"][label]

                    bandstructure_symm_line["labels_dict"] = new_labels_dict

                # - DOS traces from DB using task_id
                es_store.connect()
                dos_query = es_store.query_one(
                    criteria={"task_id": int(mpid)},
                    properties=["dos.task_id"],
                )
                es_store.close()

                dos_store.connect()
                density_of_states = dos_store.query_one(
                    criteria={"task_id": int(dos_query["dos"]["task_id"])}, )

            # - BS Data
            if (type(bandstructure_symm_line) != dict
                    and bandstructure_symm_line is not None):
                bandstructure_symm_line = bandstructure_symm_line.to_dict()

            if (type(density_of_states) != dict
                    and density_of_states is not None):
                density_of_states = density_of_states.to_dict()

            bsml = BSML.from_dict(bandstructure_symm_line)

            bs_reg_plot = BSPlotter(bsml)

            bs_data = bs_reg_plot.bs_plot_data()

            # Make the plot continuous for lm
            if path_convention == "lm":
                distance_map, kpath_euler = HSKP(
                    bsml.structure).get_continuous_path(bsml)

                kpath_labels = [pair[0] for pair in kpath_euler]
                kpath_labels.append(kpath_euler[-1][1])

            else:
                distance_map = [(i, False)
                                for i in range(len(bs_data["distances"]))]
                kpath_labels = []
                for label_ind in range(len(bs_data["ticks"]["label"]) - 1):
                    if (bs_data["ticks"]["label"][label_ind] !=
                            bs_data["ticks"]["label"][label_ind + 1]):
                        kpath_labels.append(
                            bs_data["ticks"]["label"][label_ind])
                kpath_labels.append(bs_data["ticks"]["label"][-1])

            bs_data["ticks"]["label"] = kpath_labels

            # Obtain bands to plot over and generate traces for bs data:
            energy_window = (-6.0, 10.0)
            bands = []
            for band_num in range(bs_reg_plot._nb_bands):
                if (bs_data["energy"][0][str(Spin.up)][band_num][0] <=
                        energy_window[1]) and (bs_data["energy"][0][str(
                            Spin.up)][band_num][0] >= energy_window[0]):
                    bands.append(band_num)

            bstraces = []

            pmin = 0.0
            tick_vals = [0.0]

            cbm = bsml.get_cbm()
            vbm = bsml.get_vbm()

            cbm_new = bs_data["cbm"]
            vbm_new = bs_data["vbm"]

            for dnum, (d, rev) in enumerate(distance_map):

                x_dat = [
                    dval - bs_data["distances"][d][0] + pmin
                    for dval in bs_data["distances"][d]
                ]

                pmin = x_dat[-1]

                tick_vals.append(pmin)

                if not rev:
                    traces_for_segment = [{
                        "x":
                        x_dat,
                        "y": [
                            bs_data["energy"][d][str(Spin.up)][i][j]
                            for j in range(len(bs_data["distances"][d]))
                        ],
                        "mode":
                        "lines",
                        "line": {
                            "color": "#1f77b4"
                        },
                        "hoverinfo":
                        "skip",
                        "name":
                        "spin ↑"
                        if bs_reg_plot._bs.is_spin_polarized else "Total",
                        "hovertemplate":
                        "%{y:.2f} eV",
                        "showlegend":
                        False,
                        "xaxis":
                        "x",
                        "yaxis":
                        "y",
                    } for i in bands]
                elif rev:
                    traces_for_segment = [{
                        "x":
                        x_dat,
                        "y": [
                            bs_data["energy"][d][str(Spin.up)][i][j]
                            for j in reversed(
                                range(len(bs_data["distances"][d])))
                        ],
                        "mode":
                        "lines",
                        "line": {
                            "color": "#1f77b4"
                        },
                        "hoverinfo":
                        "skip",
                        "name":
                        "spin ↑"
                        if bs_reg_plot._bs.is_spin_polarized else "Total",
                        "hovertemplate":
                        "%{y:.2f} eV",
                        "showlegend":
                        False,
                        "xaxis":
                        "x",
                        "yaxis":
                        "y",
                    } for i in bands]

                if bs_reg_plot._bs.is_spin_polarized:

                    if not rev:
                        traces_for_segment += [{
                            "x":
                            x_dat,
                            "y": [
                                bs_data["energy"][d][str(Spin.down)][i][j]
                                for j in range(len(bs_data["distances"][d]))
                            ],
                            "mode":
                            "lines",
                            "line": {
                                "color": "#ff7f0e",
                                "dash": "dot"
                            },
                            "hoverinfo":
                            "skip",
                            "showlegend":
                            False,
                            "name":
                            "spin ↓",
                            "hovertemplate":
                            "%{y:.2f} eV",
                            "xaxis":
                            "x",
                            "yaxis":
                            "y",
                        } for i in bands]
                    elif rev:
                        traces_for_segment += [{
                            "x":
                            x_dat,
                            "y": [
                                bs_data["energy"][d][str(Spin.down)][i][j]
                                for j in reversed(
                                    range(len(bs_data["distances"][d])))
                            ],
                            "mode":
                            "lines",
                            "line": {
                                "color": "#ff7f0e",
                                "dash": "dot"
                            },
                            "hoverinfo":
                            "skip",
                            "showlegend":
                            False,
                            "name":
                            "spin ↓",
                            "hovertemplate":
                            "%{y:.2f} eV",
                            "xaxis":
                            "x",
                            "yaxis":
                            "y",
                        } for i in bands]

                bstraces += traces_for_segment

                # - Get proper cbm and vbm coords for lm
                if path_convention == "lm":
                    for (x_point, y_point) in bs_data["cbm"]:
                        if x_point in bs_data["distances"][d]:
                            xind = bs_data["distances"][d].index(x_point)
                            if not rev:
                                x_point_new = x_dat[xind]
                            else:
                                x_point_new = x_dat[len(x_dat) - xind - 1]

                            new_label = bs_data["ticks"]["label"][
                                tick_vals.index(x_point_new)]

                            if (cbm["kpoint"].label is None
                                    or cbm["kpoint"].label in new_label):
                                cbm_new.append((x_point_new, y_point))

                    for (x_point, y_point) in bs_data["vbm"]:
                        if x_point in bs_data["distances"][d]:
                            xind = bs_data["distances"][d].index(x_point)
                            if not rev:
                                x_point_new = x_dat[xind]
                            else:
                                x_point_new = x_dat[len(x_dat) - xind - 1]

                            new_label = bs_data["ticks"]["label"][
                                tick_vals.index(x_point_new)]

                            if (vbm["kpoint"].label is None
                                    or vbm["kpoint"].label in new_label):
                                vbm_new.append((x_point_new, y_point))

            bs_data["ticks"]["distance"] = tick_vals

            # - Strip latex math wrapping for labels
            str_replace = {
                "$": "",
                "\\mid": "|",
                "\\Gamma": "Γ",
                "\\Sigma": "Σ",
                "GAMMA": "Γ",
                "_1": "₁",
                "_2": "₂",
                "_3": "₃",
                "_4": "₄",
                "_{1}": "₁",
                "_{2}": "₂",
                "_{3}": "₃",
                "_{4}": "₄",
                "^{*}": "*",
            }

            bar_loc = []
            for entry_num in range(len(bs_data["ticks"]["label"])):
                for key in str_replace.keys():
                    if key in bs_data["ticks"]["label"][entry_num]:
                        bs_data["ticks"]["label"][entry_num] = bs_data[
                            "ticks"]["label"][entry_num].replace(
                                key, str_replace[key])
                        if key == "\\mid":
                            bar_loc.append(
                                bs_data["ticks"]["distance"][entry_num])

            # Vertical lines for disjointed segments
            vert_traces = [{
                "x": [x_point, x_point],
                "y": energy_window,
                "mode": "lines",
                "marker": {
                    "color": "white"
                },
                "hoverinfo": "skip",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            } for x_point in bar_loc]

            bstraces += vert_traces

            # Dots for cbm and vbm

            dot_traces = [{
                "x": [x_point],
                "y": [y_point],
                "mode":
                "markers",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {
                        "color": "white",
                        "width": 2
                    },
                },
                "showlegend":
                False,
                "hoverinfo":
                "text",
                "name":
                "",
                "hovertemplate":
                "CBM: k = {}, {} eV".format(list(cbm["kpoint"].frac_coords),
                                            cbm["energy"]),
                "xaxis":
                "x",
                "yaxis":
                "y",
            } for (x_point, y_point) in set(cbm_new)] + [{
                "x": [x_point],
                "y": [y_point],
                "mode":
                "markers",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {
                        "color": "white",
                        "width": 2
                    },
                },
                "showlegend":
                False,
                "hoverinfo":
                "text",
                "name":
                "",
                "hovertemplate":
                "VBM: k = {}, {} eV".format(list(vbm["kpoint"].frac_coords),
                                            vbm["energy"]),
                "xaxis":
                "x",
                "yaxis":
                "y",
            } for (x_point, y_point) in set(vbm_new)]

            bstraces += dot_traces

            # - DOS Data
            dostraces = []

            dos = CompleteDos.from_dict(density_of_states)

            dos_max = np.abs(
                (dos.energies - dos.efermi - energy_window[1])).argmin()
            dos_min = np.abs(
                (dos.energies - dos.efermi - energy_window[0])).argmin()

            if bs_reg_plot._bs.is_spin_polarized:
                # Add second spin data if available
                trace_tdos = {
                    "x": -1.0 * dos.densities[Spin.down][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": "Total DOS (spin ↓)",
                    "line": go.scatter.Line(color="#444444", dash="dot"),
                    "fill": "tozerox",
                    "fillcolor": "#C4C4C4",
                    "xaxis": "x2",
                    "yaxis": "y2",
                }

                dostraces.append(trace_tdos)

                tdos_label = "Total DOS (spin ↑)"
            else:
                tdos_label = "Total DOS"

            # Total DOS
            trace_tdos = {
                "x": dos.densities[Spin.up][dos_min:dos_max],
                "y": dos.energies[dos_min:dos_max] - dos.efermi,
                "mode": "lines",
                "name": tdos_label,
                "line": go.scatter.Line(color="#444444"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "legendgroup": "spinup",
                "xaxis": "x2",
                "yaxis": "y2",
            }

            dostraces.append(trace_tdos)

            ele_dos = dos.get_element_dos()
            elements = [str(entry) for entry in ele_dos.keys()]

            if dos_select == "ap":
                proj_data = ele_dos
            elif dos_select == "op":
                proj_data = dos.get_spd_dos()
            elif "orb" in dos_select:
                proj_data = dos.get_element_spd_dos(
                    Element(dos_select.replace("orb", "")))
            else:
                raise PreventUpdate

            # Projected DOS
            count = 0
            colors = [
                "#d62728",  # brick red
                "#2ca02c",  # cooked asparagus green
                "#17becf",  # blue-teal
                "#bcbd22",  # curry yellow-green
                "#9467bd",  # muted purple
                "#8c564b",  # chestnut brown
                "#e377c2",  # raspberry yogurt pink
            ]

            for label in proj_data.keys():

                if bs_reg_plot._bs.is_spin_polarized:
                    trace = {
                        "x":
                        -1.0 *
                        proj_data[label].densities[Spin.down][dos_min:dos_max],
                        "y":
                        dos.energies[dos_min:dos_max] - dos.efermi,
                        "mode":
                        "lines",
                        "name":
                        str(label) + " (spin ↓)",
                        "line":
                        dict(width=3, color=colors[count], dash="dot"),
                        "xaxis":
                        "x2",
                        "yaxis":
                        "y2",
                    }

                    dostraces.append(trace)
                    spin_up_label = str(label) + " (spin ↑)"

                else:
                    spin_up_label = str(label)

                trace = {
                    "x": proj_data[label].densities[Spin.up][dos_min:dos_max],
                    "y": dos.energies[dos_min:dos_max] - dos.efermi,
                    "mode": "lines",
                    "name": spin_up_label,
                    "line": dict(width=2, color=colors[count]),
                    "xaxis": "x2",
                    "yaxis": "y2",
                }

                dostraces.append(trace)

                count += 1
            traces = [bstraces, dostraces, bs_data]

            return (traces, elements)
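The callback above opens, queries, and closes each store inline; the same access pattern extracted into a small helper (credentials and host are the same placeholders used in the excerpt):

from maggma.stores import GridFSStore


def fetch_bandstructure(task_id):
    """Pull one band-structure document from the GridFS-backed store, then close the connection."""
    bs_store = GridFSStore(
        database="fw_bs_prod",
        collection_name="bandstructure_fs",
        host="mongodb03.nersc.gov",
        port=27017,
        username="******",
        password="",
    )
    bs_store.connect()
    try:
        return bs_store.query_one(criteria={"metadata.task_id": int(task_id)})
    finally:
        bs_store.close()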