def test_mongostore_newer_in(mongostore):
    """Check ``newer_in`` when ``mongostore`` holds the newer copy of every key.

    Ten docs (keys 0-9) are written to a scratch ``test_target`` store first and
    then to ``mongostore``, each stamped with ``datetime.utcnow()`` at write time,
    so the ``mongostore`` copies are the more recent ones.
    """
    target = MongoStore("maggma_test", "test_target")
    target.connect()
    try:
        # Write the docs to target first ...
        target.update(
            [
                {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()}
                for i in range(10)
            ]
        )
        # ... then update the same keys in the source, so the docs are newer
        # in mongostore than in target.
        mongostore.update(
            [
                {mongostore.key: i, mongostore.last_updated_field: datetime.utcnow()}
                for i in range(10)
            ]
        )

        assert len(target.newer_in(mongostore)) == 10
        assert len(target.newer_in(mongostore, exhaustive=True)) == 10
        assert len(mongostore.newer_in(target)) == 0
    finally:
        # Drop the scratch collection even if one of the assertions fails.
        target._collection.drop()
def reporting_store():
    """Yield a connected ``maggma_test.reporting`` store that starts empty.

    After the consumer is done, the documents are removed and the collection
    is dropped.
    """
    reporting = MongoStore("maggma_test", "reporting")
    reporting.connect()
    reporting.remove_docs({})  # begin from an empty collection

    yield reporting

    # Teardown: clear the documents, then drop the collection itself.
    reporting.remove_docs({})
    reporting._collection.drop()
def mongostore():
    """Yield a connected ``maggma_test.test`` store that starts empty.

    After the consumer is done, the documents are removed and the collection
    is dropped.
    """
    scratch = MongoStore("maggma_test", "test")
    scratch.connect()
    scratch.remove_docs({})  # begin from an empty collection

    yield scratch

    # Teardown: clear the documents, then drop the collection itself.
    scratch.remove_docs({})
    scratch._collection.drop()
class TestDiffractionBuilder(TestCase):
    """Tests for ``DiffractionBuilder`` backed by a throwaway MongoDB database."""

    @classmethod
    def setUpClass(cls):
        # One uniquely named database per test class; dropped in tearDownClass.
        cls.dbname = "test_" + uuid4().hex
        bootstrap = MongoStore(cls.dbname, "test")
        bootstrap.connect()
        cls.client = bootstrap.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        store_kwargs = {"key": "k", "lu_field": "lu"}
        self.source = MongoStore(self.dbname, "source", **store_kwargs)
        self.target = MongoStore(self.dbname, "target", **store_kwargs)
        # Connect each store, then index "lu" and put a unique index on "k".
        for store in (self.source, self.target):
            store.connect()
            store.collection.create_index("lu")
            store.collection.create_index("k", unique=True)

    def tearDown(self):
        for store in (self.source, self.target):
            store.collection.drop()

    def test_get_xrd_from_struct(self):
        builder = DiffractionBuilder(self.source, self.target)
        silicon = PymatgenTest.get_structure("Si")
        xrd = builder.get_xrd_from_struct(silicon)
        self.assertIn("Cu", xrd)

    def test_serialization(self):
        builder = DiffractionBuilder(self.source, self.target)
        serialized = builder.as_dict()
        self.assertIsNone(serialized["xrd_settings"])
class TestMaterials(BuilderTest):
    """Tests for ``MaterialsBuilder`` writing into ``emmet_test.materials``."""

    def setUp(self):
        self.materials = MongoStore("emmet_test", "materials")
        self.materials.connect()
        # Start every test from an empty target collection.
        self.materials.collection.drop()
        self.mbuilder = MaterialsBuilder(
            self.tasks, self.materials, mat_prefix="", chunk_size=1
        )

    def test_get_items(self):
        to_process = list(self.mbuilder.get_items())
        # Each item is a group of task docs; use the formula of its first task.
        to_process_forms = {tasks[0]["formula_pretty"] for tasks in to_process}

        self.assertEqual(len(to_process), 12)
        self.assertEqual(len(to_process_forms), 12)
        self.assertEqual(len(list(chain.from_iterable(to_process))), 197)
        # assertIn/assertNotIn report the container on failure, unlike
        # assertTrue(x in y).
        self.assertIn("Sr", to_process_forms)
        self.assertIn("Hf", to_process_forms)
        self.assertIn("O2", to_process_forms)
        self.assertNotIn("H", to_process_forms)

    def test_process_item(self):
        # Expected number of material docs produced for each chemical system.
        expected_counts = {"Sr": 7, "Hf": 4, "O": 6, "O-Sr": 5, "Hf-O-Sr": 13}
        for chemsys, n_mats in expected_counts.items():
            with self.subTest(chemsys=chemsys):
                tasks = list(self.tasks.query(criteria={"chemsys": chemsys}))
                mats = self.mbuilder.process_item(tasks)
                self.assertEqual(len(mats), n_mats)

    def test_update_targets(self):
        tasks = list(self.tasks.query(criteria={"chemsys": "Sr"}))
        mats = self.mbuilder.process_item(tasks)
        self.assertEqual(len(mats), 7)

        self.mbuilder.update_targets([mats])
        self.assertEqual(len(self.materials.distinct("task_id")), 7)
        self.assertEqual(len(list(self.materials.query())), 7)

    def tearDown(self):
        self.materials.collection.drop()
def test_mongostore_connect_via_ssh():
    """``connect`` accepts an ``ssh_tunnel`` object that only exposes bind addresses."""
    mongostore = MongoStore("maggma_test", "test")

    class FakePipe:
        # Minimal stand-in for an SSH tunnel: just the two bind addresses.
        remote_bind_address = ("localhost", 27017)
        local_bind_address = ("localhost", 37017)

    tunnel = FakePipe()
    mongostore.connect(ssh_tunnel=tunnel)
    assert isinstance(mongostore._collection, pymongo.collection.Collection)
# NOTE(review): another ``TestMaterials(BuilderTest)`` class exists in this
# SOURCE; if both live in one module, the later definition shadows the earlier.
class TestMaterials(BuilderTest):
    """Tests for ``MLStructuresBuilder`` writing into ``emmet_test.ml_strucs``."""

    def setUp(self):
        self.ml_strucs = MongoStore("emmet_test", "ml_strucs", key="entry_id")
        self.ml_strucs.connect()
        # Start every test from an empty target collection.
        self.ml_strucs.collection.drop()
        self.mlbuilder = MLStructuresBuilder(
            self.tasks,
            self.ml_strucs,
            task_types=("Structure Optimization", "Static"),
        )

    def test_get_items(self):
        to_process = list(self.mlbuilder.get_items())
        to_process_forms = {task["formula_pretty"] for task in to_process}

        self.assertEqual(len(to_process), 197)
        self.assertEqual(len(to_process_forms), 12)
        self.assertIn("Sr", to_process_forms)
        self.assertIn("Hf", to_process_forms)
        self.assertIn("O2", to_process_forms)
        self.assertNotIn("H", to_process_forms)

    def test_process_item(self):
        for task in self.tasks.query():
            ml_strucs = self.mlbuilder.process_item(task)
            t_type = task_type(get(task, 'input.incar'))
            if not any(t in t_type for t in self.mlbuilder.task_types):
                # Tasks outside the configured task types yield nothing.
                self.assertEqual(len(ml_strucs), 0)
            else:
                # Expected: one entry per ionic step summed over every calc.
                n_steps = sum(
                    len(calc["output"]["ionic_steps"])
                    for calc in task["calcs_reversed"]
                )
                self.assertEqual(len(ml_strucs), n_steps)

    def test_update_targets(self):
        for task in self.tasks.query():
            ml_strucs = self.mlbuilder.process_item(task)
            self.mlbuilder.update_targets([ml_strucs])
        self.assertEqual(len(self.ml_strucs.distinct("task_id")), 102)
        self.assertEqual(len(list(self.ml_strucs.query())), 1012)

    def tearDown(self):
        self.ml_strucs.collection.drop()
class JointStoreTest(unittest.TestCase):
    """Tests for ``JointStore`` joining ``maggma_test.test1`` and ``maggma_test.test2``."""

    def setUp(self):
        self.jointstore = JointStore("maggma_test", ["test1", "test2"])
        self.jointstore.connect()
        # 10 docs with task_id 0..9 into jointstore.collection (presumably the
        # master collection "test1" -- see the test1 store below).
        self.jointstore.collection.drop()
        self.jointstore.collection.insert_many([{
            "task_id": k,
            "my_prop": k + 1,
            "last_updated": datetime.utcnow(),
            "category": k // 5
        } for k in range(10)])
        # 5 docs with the even task_ids 0, 2, ..., 8 into "test2".
        self.jointstore.collection.database["test2"].drop()
        self.jointstore.collection.database["test2"].insert_many([{
            "task_id": 2 * k,
            "your_prop": k + 3,
            "last_updated": datetime.utcnow(),
            "category2": k // 3
        } for k in range(5)])
        self.test1 = MongoStore("maggma_test", "test1")
        self.test1.connect()
        self.test2 = MongoStore("maggma_test", "test2")
        self.test2.connect()

    def tearDown(self):
        # Don't leave the fixture documents behind in the maggma_test database.
        self.test1.collection.drop()
        self.test2.collection.drop()

    def test_query(self):
        # Test query all
        docs = list(self.jointstore.query())
        self.assertEqual(len(docs), 10)
        docs_w_field = [d for d in docs if "test2" in d]
        self.assertEqual(len(docs_w_field), 5)
        docs_w_field = sorted(docs_w_field, key=lambda x: x['task_id'])
        self.assertEqual(docs_w_field[0]['test2']['your_prop'], 3)
        self.assertEqual(docs_w_field[0]['task_id'], 0)
        self.assertEqual(docs_w_field[0]['my_prop'], 1)

    def test_query_one(self):
        doc = self.jointstore.query_one()
        self.assertEqual(doc['my_prop'], doc['task_id'] + 1)
        # Test limit properties
        doc = self.jointstore.query_one(properties=['test2', 'task_id'])
        self.assertEqual(doc['test2']['your_prop'], doc['task_id'] + 3)
        self.assertIsNone(doc.get("my_prop"))
        # Test criteria
        doc = self.jointstore.query_one(criteria={"task_id": {"$gte": 10}})
        self.assertIsNone(doc)
        doc = self.jointstore.query_one(
            criteria={"test2.your_prop": {"$gt": 6}})
        self.assertEqual(doc['task_id'], 8)

        # Test merge_at_root
        self.jointstore.merge_at_root = True

        # Test merging is working properly
        doc = self.jointstore.query_one(criteria={"task_id": 2})
        self.assertEqual(doc['my_prop'], 3)
        self.assertEqual(doc['your_prop'], 4)

        # Test merging is allowing for subsequent match
        doc = self.jointstore.query_one(criteria={"your_prop": {"$gt": 6}})
        self.assertEqual(doc['task_id'], 8)

    def test_distinct(self):
        dyour_prop = self.jointstore.distinct("test2.your_prop")
        self.assertEqual(set(dyour_prop), {k + 3 for k in range(5)})
        dmy_prop = self.jointstore.distinct("my_prop")
        self.assertEqual(set(dmy_prop), {k + 1 for k in range(10)})
        dmy_prop_cond = self.jointstore.distinct(
            "my_prop", {"test2.your_prop": {"$gte": 5}})
        self.assertEqual(set(dmy_prop_cond), {5, 7, 9})

    def test_last_updated(self):
        doc = self.jointstore.query_one({"task_id": 0})
        test1doc = self.test1.query_one({"task_id": 0})
        test2doc = self.test2.query_one({"task_id": 0})
        self.assertEqual(test2doc['last_updated'], doc['last_updated'])
        self.assertNotEqual(test1doc['last_updated'], doc['last_updated'])
        # Swap the two
        test2date = test2doc['last_updated']
        test2doc['last_updated'] = test1doc['last_updated']
        test1doc['last_updated'] = test2date
        self.test1.update([test1doc], update_lu=False)
        self.test2.update([test2doc], update_lu=False)
        doc = self.jointstore.query_one({"task_id": 0})
        test1doc = self.test1.query_one({"task_id": 0})
        test2doc = self.test2.query_one({"task_id": 0})
        self.assertEqual(test1doc['last_updated'], doc['last_updated'])
        self.assertNotEqual(test2doc['last_updated'], doc['last_updated'])
        # Check also that still has a field if no task2 doc
        doc = self.jointstore.query_one({"task_id": 1})
        self.assertIsNotNone(doc['last_updated'])

    def test_groupby(self):
        docs = list(self.jointstore.groupby("category"))
        self.assertEqual(len(docs[0]['docs']), 5)
        self.assertEqual(len(docs[1]['docs']), 5)
        docs = list(self.jointstore.groupby("test2.category2"))
        docs_by_id = {get(d, '_id.test2.category2'): d['docs'] for d in docs}
        self.assertEqual(len(docs_by_id[None]), 5)
        self.assertEqual(len(docs_by_id[0]), 3)
        self.assertEqual(len(docs_by_id[1]), 2)
def setUpClass(cls):
    """Pick a unique scratch database name and keep a handle on its Mongo client."""
    cls.dbname = "test_{}".format(uuid4().hex)
    store = MongoStore(cls.dbname, "test")
    store.connect()
    cls.client = store.collection.database.client
class ElasticAnalysisBuilderTest(unittest.TestCase):
    """Tests for ``ElasticAnalysisBuilder`` and the elastic grouping helpers."""

    # setUp/tearDown are per-test instance hooks; they were previously
    # decorated @classmethod, which stored per-test state on the class itself.
    def setUp(self):
        # Set up test db, set up mpsft, etc.
        self.test_tasks = MongoStore("test_emmet", "tasks")
        self.test_tasks.connect()
        docs = loadfn(test_tasks, cls=None)
        self.test_tasks.update(docs)
        self.test_elasticity = MongoStore("test_emmet", "elasticity")
        self.test_elasticity.connect()
        if PROFILE_MODE:
            self.pr = cProfile.Profile()
            self.pr.enable()
            print("\n<<<---")

    def tearDown(self):
        if not DEBUG_MODE:
            self.test_elasticity.collection.drop()
            self.test_tasks.collection.drop()
        if PROFILE_MODE:
            p = Stats(self.pr)
            p.strip_dirs()
            p.sort_stats('cumtime')
            p.print_stats()
            print("\n--->>>")

    def test_builder(self):
        ec_builder = ElasticAnalysisBuilder(
            self.test_tasks, self.test_elasticity, incremental=False)
        ec_builder.connect()
        for t in ec_builder.get_items():
            processed = ec_builder.process_item(t)
            self.assertTrue(bool(processed))
        runner = Runner([ec_builder])
        runner.run()
        # Test warnings
        doc = ec_builder.elasticity.query_one(
            criteria={"pretty_formula": "NaN3"})
        self.assertIsNone(doc['warnings'])
        self.assertAlmostEqual(doc['compliance_tensor'][0][0], 41.576072, 6)

    def test_grouping_functions(self):
        docs1 = list(
            self.test_tasks.query(criteria={"formula_pretty": "NaN3"}))
        docs_grouped1 = group_by_parent_lattice(docs1)
        self.assertEqual(len(docs_grouped1), 1)
        grouped_by_opt = group_deformations_by_optimization_task(docs1)
        self.assertEqual(len(grouped_by_opt), 1)
        docs2 = self.test_tasks.query(
            criteria={"task_label": "elastic deformation"})
        # NOTE(review): result is not asserted; this only checks that the
        # grouping runs without raising.
        sgroup2 = group_by_parent_lattice(docs2)

    def test_get_distinct_rotations(self):
        struct = PymatgenTest.get_structure("Si")
        conv = SpacegroupAnalyzer(struct).get_conventional_standard_structure()
        rots = get_distinct_rotations(conv)
        ops = SpacegroupAnalyzer(conv).get_symmetry_operations()
        # Every symmetry operation's rotation must appear among the distinct rotations.
        for op in ops:
            self.assertTrue(
                any(np.allclose(op.rotation_matrix, r) for r in rots))
        self.assertEqual(len(rots), 48)

    def test_process_elastic_calcs(self):
        test_struct = PymatgenTest.get_structure('Sn')  # use cubic test struct
        dss = DeformedStructureSet(test_struct)
        # Construct test task set
        opt_task = {
            "output": {"structure": test_struct.as_dict()},
            "input": {"structure": test_struct.as_dict()}
        }
        defo_tasks = []
        for n, (struct, defo) in enumerate(zip(dss, dss.deformations)):
            strain = defo.green_lagrange_strain
            defo_tasks.append({
                "output": {
                    "structure": struct.as_dict(),
                    "stress": (strain * 5).tolist()
                },
                "input": None,
                "task_id": n,
                "completed_at": datetime.utcnow(),
                "transmuter": {
                    "transformation_params": [{"deformation": defo}]
                }
            })
        # Drop one deformation task before processing.
        defo_tasks.pop(0)
        explicit, derived = process_elastic_calcs(opt_task, defo_tasks)
        self.assertEqual(len(explicit), 23)
        self.assertEqual(len(derived), 1)

    def test_process_elastic_calcs_toec(self):
        # Test TOEC tasks
        test_struct = PymatgenTest.get_structure('Sn')  # use cubic test struct
        strain_states = get_default_strain_states(3)
        # Default stencil in atomate, this maybe shouldn't be hard-coded
        stencil = np.linspace(-0.075, 0.075, 7)
        strains = [
            Strain.from_voigt(s * np.array(strain_state))
            for s, strain_state in product(stencil, strain_states)
        ]
        strains = [s for s in strains if not np.allclose(s, 0)]
        sym_reduced = symmetry_reduce(strains, test_struct)
        opt_task = {
            "output": {"structure": test_struct.as_dict()},
            "input": {"structure": test_struct.as_dict()}
        }
        defo_tasks = []
        for n, strain in enumerate(sym_reduced):
            defo = strain.get_deformation_matrix()
            new_struct = defo.apply_to_structure(test_struct)
            defo_tasks.append({
                "output": {
                    "structure": new_struct.as_dict(),
                    "stress": (strain * 5).tolist()
                },
                "input": None,
                "task_id": n,
                "completed_at": datetime.utcnow(),
                "transmuter": {
                    "transformation_params": [{"deformation": defo}]
                }
            })

        explicit, derived = process_elastic_calcs(opt_task, defo_tasks)
        self.assertEqual(len(explicit), len(sym_reduced))
        self.assertEqual(len(derived), len(strains) - len(sym_reduced))
        for calc in derived:
            self.assertTrue(
                np.allclose(calc['strain'], calc['cauchy_stress'] / -0.5))
class ElasticAggregateBuilderTest(unittest.TestCase):
    """Tests for ``ElasticAggregateBuilder`` over small hand-built collections."""

    def setUp(self):
        # Empty aggregated collection
        self.test_elasticity_agg = MongoStore("test_emmet", "elasticity_agg")
        self.test_elasticity_agg.connect()

        # Generate test materials collection
        self.test_materials = MongoStore("test_emmet", "materials")
        self.test_materials.connect()
        mat_docs = []
        for n, formula in enumerate(['Si', 'BaNiO3', 'Li2O2', 'TiO2']):
            structure = PymatgenTest.get_structure(formula)
            structure.add_site_property("magmoms", [0.0] * len(structure))
            mat_docs.append({
                "task_id": "mp-{}".format(n),
                "structure": structure.as_dict(),
                "pretty_formula": formula
            })
        self.test_materials.update(mat_docs, update_lu=False)

        # Create elasticity collection and add docs
        self.test_elasticity = MongoStore("test_emmet", "elasticity",
                                          key="optimization_task_id")
        self.test_elasticity.connect()

        si = PymatgenTest.get_structure("Si")
        si.add_site_property("magmoms", [0.0] * len(si))
        et = ElasticTensor.from_voigt([[50, 25, 25, 0, 0, 0],
                                       [25, 50, 25, 0, 0, 0],
                                       [25, 25, 50, 0, 0, 0],
                                       [0, 0, 0, 75, 0, 0],
                                       [0, 0, 0, 0, 75, 0],
                                       [0, 0, 0, 0, 0, 75]])
        doc = {
            "input_structure": si.copy().as_dict(),
            "order": 2,
            "magnetic_type": "non-magnetic",
            "optimization_task_id": "mp-1",
            "last_updated": datetime.utcnow(),
            "completed_at": datetime.utcnow(),
            "optimized_structure": si.copy().as_dict(),
            "pretty_formula": "Si",
            "state": "successful"
        }
        doc['elastic_tensor'] = et.voigt
        doc.update(et.property_dict)
        self.test_elasticity.update([doc])
        # Insert second doc with diff params
        si.perturb(0.005)
        # NOTE(review): the first doc uses "last_updated"; "updated_at" here
        # may have been intended as "last_updated" -- confirm.
        doc.update({
            "optimized_structure": si.copy().as_dict(),
            "updated_at": datetime.utcnow(),
            "optimization_task_id": "mp-5"
        })
        self.test_elasticity.update([doc])
        self.builder = self.get_a_new_builder()

    def tearDown(self):
        if not DEBUG_MODE:
            self.test_elasticity.collection.drop()
            self.test_elasticity_agg.collection.drop()
            self.test_materials.collection.drop()

    def test_materials_aggregator(self):
        materials_dict = generate_formula_dict(self.test_materials)
        grouped_by_mpid = group_by_material_id(
            materials_dict['Si'],
            [{
                'structure': PymatgenTest.get_structure('Si').as_dict(),
                'magnetic_type': "non-magnetic"
            }])
        self.assertEqual(len(grouped_by_mpid), 1)

    def test_get_items(self):
        iterator = self.builder.get_items()
        for item in iterator:
            self.assertIsNotNone(item)

    def test_process_items(self):
        docs = list(
            self.test_elasticity.query(criteria={"pretty_formula": "Si"}))
        formula_dict = generate_formula_dict(self.test_materials)
        processed = self.builder.process_item((docs, formula_dict['Si']))
        self.assertEqual(len(processed), 1)
        self.assertEqual(len(processed[0]['all_elastic_fits']), 2)

    def test_update_targets(self):
        # NOTE(review): no assertion; only checks that update_targets runs.
        processed = [
            self.builder.process_item(item)
            for item in self.builder.get_items()
        ]
        self.builder.update_targets(processed)

    def test_aggregation(self):
        runner = Runner([self.builder])
        runner.run()
        all_agg_docs = list(self.test_elasticity_agg.query())
        self.assertTrue(bool(all_agg_docs))

    def get_a_new_builder(self):
        return ElasticAggregateBuilder(self.test_elasticity,
                                       self.test_materials,
                                       self.test_elasticity_agg)
def test_mongostore_connect():
    """A fresh MongoStore has no collection until ``connect`` is called."""
    store = MongoStore("maggma_test", "test")
    assert store._collection is None

    store.connect()
    assert isinstance(store._collection, pymongo.collection.Collection)
class TestCopyBuilder(TestCase):
    """Tests for ``CopyBuilder`` copying between two stores in a scratch database."""

    @classmethod
    def setUpClass(cls):
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        # 20 "old" docs at tic; "new" versions of the first 10 keys at toc.
        self.old_docs = [{"lu": tic, "k": k, "v": "old"} for k in keys]
        self.new_docs = [{"lu": toc, "k": k, "v": "new"} for k in keys[:10]]
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.builder = CopyBuilder(self.source, self.target)
        self.source.connect()
        self.source.collection.create_index("lu")
        self.target.connect()
        self.target.collection.create_index("lu")
        self.target.collection.create_index("k")

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_items(self):
        self.source.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.old_docs))
        self.target.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())), 0)
        self.source.update(self.new_docs, update_lu=False)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.new_docs))

    def test_process_item(self):
        self.source.collection.insert_many(self.old_docs)
        items = list(self.builder.get_items())
        self.assertCountEqual(items, map(self.builder.process_item, items))

    def test_update_targets(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        items = list(map(self.builder.process_item, self.builder.get_items()))
        self.builder.update_targets(items)
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_confirm_lu_field_index(self):
        self.source.collection.drop_index("lu_1")
        with self.assertRaises(Exception) as cm:
            self.builder.get_items()
        self.assertTrue(cm.exception.args[0].startswith("Need index"))
        # Restore the index for any later use of the collection.
        self.source.collection.create_index("lu")

    def test_runner(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        runner = Runner([self.builder])
        runner.run()
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_query(self):
        self.builder.query = {"k": {"$gt": 5}}
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        runner = Runner([self.builder])
        runner.run()
        all_docs = list(self.target.query(criteria={}))
        self.assertEqual(len(all_docs), 14)
        # Was assertTrue(min(...), 6): 6 was only the failure message, so the
        # minimum key was never actually checked.
        self.assertEqual(min(d['k'] for d in all_docs), 6)
class TestCopyBuilder(TestCase):
    """Tests for ``CopyBuilder.run`` copying between two stores in a scratch database."""

    @classmethod
    def setUpClass(cls):
        cls.dbname = "test_" + uuid4().hex
        s = MongoStore(cls.dbname, "test")
        s.connect()
        cls.client = s.collection.database.client

    @classmethod
    def tearDownClass(cls):
        cls.client.drop_database(cls.dbname)

    def setUp(self):
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        # 20 "old" docs at tic; "new" versions of the first 10 keys at toc.
        self.old_docs = [{"lu": tic, "k": k, "v": "old"} for k in keys]
        self.new_docs = [{"lu": toc, "k": k, "v": "new"} for k in keys[:10]]
        kwargs = dict(key="k", lu_field="lu")
        self.source = MongoStore(self.dbname, "source", **kwargs)
        self.target = MongoStore(self.dbname, "target", **kwargs)
        self.builder = CopyBuilder(self.source, self.target)
        self.source.connect()
        self.source.ensure_index(self.source.key)
        self.source.ensure_index(self.source.lu_field)
        self.target.connect()
        self.target.ensure_index(self.target.key)
        self.target.ensure_index(self.target.lu_field)

    def tearDown(self):
        self.source.collection.drop()
        self.target.collection.drop()

    def test_get_items(self):
        self.source.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.old_docs))
        self.target.collection.insert_many(self.old_docs)
        self.assertEqual(len(list(self.builder.get_items())), 0)
        self.source.update(self.new_docs, update_lu=False)
        self.assertEqual(len(list(self.builder.get_items())),
                         len(self.new_docs))

    def test_process_item(self):
        self.source.collection.insert_many(self.old_docs)
        items = list(self.builder.get_items())
        self.assertCountEqual(items, map(self.builder.process_item, items))

    def test_update_targets(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        items = list(map(self.builder.process_item, self.builder.get_items()))
        self.builder.update_targets(items)
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    @unittest.skip(
        "Have to refactor how we force read-only so a warning will get thrown")
    def test_index_warning(self):
        """Should log warning when recommended store indexes are not present."""
        self.source.collection.drop_index([(self.source.key, 1)])
        with self.assertLogs(level=logging.WARNING) as cm:
            list(self.builder.get_items())
        self.assertIn("Ensure indices", "\n".join(cm.output))

    def test_run(self):
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)
        self.builder.run()
        self.assertEqual(self.target.query_one(criteria={"k": 0})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_query(self):
        self.builder.query = {"k": {"$gt": 5}}
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.builder.run()
        all_docs = list(self.target.query(criteria={}))
        self.assertEqual(len(all_docs), 14)
        # Was assertTrue(min(...), 6): 6 was only the failure message, so the
        # minimum key was never actually checked.
        self.assertEqual(min(d['k'] for d in all_docs), 6)

    def test_delete_orphans(self):
        self.builder = CopyBuilder(self.source, self.target,
                                   delete_orphans=True)
        self.source.collection.insert_many(self.old_docs)
        self.source.update(self.new_docs, update_lu=False)
        self.target.collection.insert_many(self.old_docs)

        deletion_criteria = {"k": {"$in": list(range(5))}}
        self.source.collection.delete_many(deletion_criteria)
        self.builder.run()

        self.assertEqual(
            self.target.collection.count_documents(deletion_criteria), 0)
        self.assertEqual(self.target.query_one(criteria={"k": 5})["v"], "new")
        self.assertEqual(self.target.query_one(criteria={"k": 10})["v"], "old")

    def test_incremental_false(self):
        tic = datetime.now()
        toc = tic + timedelta(seconds=1)
        keys = list(range(20))
        earlier = [{"lu": tic, "k": k, "v": "val"} for k in keys]
        later = [{"lu": toc, "k": k, "v": "val"} for k in keys]
        self.source.collection.insert_many(earlier)
        self.target.collection.insert_many(later)
        query = {"k": {"$gt": 5}}
        self.builder = CopyBuilder(self.source, self.target,
                                   incremental=False, query=query)
        self.builder.run()
        docs = sorted(self.target.query(), key=lambda d: d["k"])

        # The old assertions passed a generator object to assertTrue (always
        # truthy) and sliced at index 5, although k=5 does not match "$gt": 5.
        # Partition by key instead. BSON datetimes keep only millisecond
        # precision, so compare against tic/toc with a small tolerance.
        tolerance = timedelta(milliseconds=1)
        copied = [d for d in docs if d["k"] > 5]
        untouched = [d for d in docs if d["k"] <= 5]
        self.assertEqual(len(copied), 14)
        self.assertEqual(len(untouched), 6)
        # Docs matching the query were overwritten with the older source copy ...
        self.assertTrue(all(abs(d["lu"] - tic) < tolerance for d in copied))
        # ... even though the target copy was newer; the rest are untouched.
        self.assertTrue(all(abs(d["lu"] - toc) < tolerance for d in untouched))
class TestThermo(BuilderTest):
    """Tests for ``ThermoBuilder`` run on materials built by ``MaterialsBuilder``."""

    def setUp(self):
        self.materials = MongoStore("emmet_test", "materials")
        self.thermo = MongoStore("emmet_test", "thermo")

        self.materials.connect()
        self.thermo.connect()

        self.mbuilder = MaterialsBuilder(self.tasks, self.materials,
                                         mat_prefix="", chunk_size=1)
        self.tbuilder = ThermoBuilder(self.materials, self.thermo,
                                      chunk_size=1)
        # Populate the materials collection the thermo builder reads from.
        runner = Runner([self.mbuilder])
        runner.run()

    def test_get_entries(self):
        self.assertEqual(len(self.tbuilder.get_entries("Sr")), 7)
        self.assertEqual(len(self.tbuilder.get_entries("Hf")), 4)
        self.assertEqual(len(self.tbuilder.get_entries("O")), 6)
        self.assertEqual(len(self.tbuilder.get_entries("Hf-O-Sr")), 44)
        self.assertEqual(len(self.tbuilder.get_entries("Sr-Hf")), 11)

    def test_get_items(self):
        self.thermo.collection.drop()
        comp_systems = list(self.tbuilder.get_items())
        self.assertEqual(len(comp_systems), 1)
        self.assertEqual(len(comp_systems[0]), 44)

    def test_process_item(self):
        tbuilder = ThermoBuilder(self.materials, self.thermo,
                                 query={"elements": ["Sr"]}, chunk_size=1)
        entries = list(tbuilder.get_items())[0]
        self.assertEqual(len(entries), 7)

        t_docs = self.tbuilder.process_item(entries)
        # The doc with the lowest energy above hull should be mp-76.
        most_stable = min(t_docs, key=lambda x: x['thermo']['e_above_hull'])
        self.assertEqual(most_stable["task_id"], "mp-76")

    def test_update_targets(self):
        self.thermo.collection.drop()

        tbuilder = ThermoBuilder(self.materials, self.thermo,
                                 query={"elements": ["Sr"]}, chunk_size=1)
        entries = list(tbuilder.get_items())[0]
        self.assertEqual(len(entries), 7)

        t_docs = self.tbuilder.process_item(entries)
        self.tbuilder.update_targets([t_docs])
        self.assertEqual(len(list(self.thermo.query())), len(t_docs))

    def tearDown(self):
        self.materials.collection.drop()
        self.thermo.collection.drop()
def jointstore_test2():
    """Yield a connected ``maggma_test.test2`` store and drop its collection afterwards."""
    second = MongoStore("maggma_test", "test2")
    second.connect()

    yield second

    second._collection.drop()  # teardown
def bs_dos_data(
    mpid,
    path_convention,
    dos_select,
    label_select,
    bandstructure_symm_line,
    density_of_states,
):
    """Build Plotly trace data for a band structure and its density of states.

    If either ``bandstructure_symm_line`` or ``density_of_states`` is missing,
    both are fetched from the ``fw_bs_prod`` database, using ``int(mpid)`` as
    the electronic-structure ``task_id``.

    Args:
        mpid: Identifier converted with ``int`` and used as the ``task_id`` lookup key.
        path_convention: Key selecting the band-structure document in the
            electronic-structure record; ``"lm"`` also remaps the plot onto a
            continuous k-path.
        dos_select: ``"ap"`` for element projections, ``"op"`` for
            ``get_spd_dos`` projections, or a string containing ``"orb"`` whose
            remainder is an element symbol (``get_element_spd_dos`` for it).
        label_select: Label convention to display. When it differs from
            ``path_convention``, k-point labels are translated via the stored
            ``equiv_labels``.
        bandstructure_symm_line: Band structure as a dict, an object with
            ``to_dict()``, or None.
        density_of_states: DOS as a dict, an object with ``to_dict()``, or None.

    Returns:
        ``([bstraces, dostraces, bs_data], elements)``, where ``elements`` are
        the element names present in the DOS.

    Raises:
        PreventUpdate: If there is nothing to fetch or plot, or if
            ``dos_select`` is not recognised.
    """
    if not mpid and (bandstructure_symm_line is None or density_of_states is None):
        raise PreventUpdate

    elif bandstructure_symm_line is None or density_of_states is None:
        if label_select == "":
            raise PreventUpdate

        # --
        # -- BS and DOS from API or DB
        # --

        # NOTE(review): connection settings/credentials are hard-coded here;
        # consider loading them from configuration instead.
        bs_store = GridFSStore(
            database="fw_bs_prod",
            collection_name="bandstructure_fs",
            host="mongodb03.nersc.gov",
            port=27017,
            username="******",
            password="",
        )

        dos_store = GridFSStore(
            database="fw_bs_prod",
            collection_name="dos_fs",
            host="mongodb03.nersc.gov",
            port=27017,
            username="******",
            password="",
        )

        es_store = MongoStore(
            database="fw_bs_prod",
            collection_name="electronic_structure",
            host="mongodb03.nersc.gov",
            port=27017,
            username="******",
            password="",
            key="task_id",
        )

        # - BS and DOS task_ids (and label equivalences) in a single lookup,
        #   so the store is opened and closed exactly once.
        es_store.connect()
        try:
            es_query = es_store.query_one(
                criteria={"task_id": int(mpid)},
                properties=[
                    "bandstructure.{}.task_id".format(path_convention),
                    "bandstructure.{}.total.equiv_labels".format(path_convention),
                    "dos.task_id",
                ],
            )
        finally:
            es_store.close()

        # - BS from GridFS using task_id
        bs_store.connect()
        try:
            bandstructure_symm_line = bs_store.query_one(
                criteria={
                    "metadata.task_id": int(
                        es_query["bandstructure"][path_convention]["task_id"]
                    )
                },
            )
        finally:
            bs_store.close()

        # If LM convention, get equivalent labels
        if path_convention != label_select:
            bs_equiv_labels = es_query["bandstructure"][path_convention]["total"]["equiv_labels"]
            label_map = bs_equiv_labels[label_select]

            new_labels_dict = {}
            for label, kpoint in bandstructure_symm_line["labels_dict"].items():
                label_formatted = label.replace("$", "")
                if "|" in label_formatted:
                    # Labels joined with "|" (e.g. "A|B"): translate each side.
                    # (Previously appended to an undefined `new_labels` -> NameError.)
                    f_label = label_formatted.split("|")
                    new_label = "$" + label_map[f_label[0]] + "|" + label_map[f_label[1]] + "$"
                else:
                    new_label = "$" + label_map[label_formatted] + "$"
                new_labels_dict[new_label] = kpoint

            bandstructure_symm_line["labels_dict"] = new_labels_dict

        # - DOS from GridFS using task_id
        dos_store.connect()
        try:
            density_of_states = dos_store.query_one(
                criteria={"task_id": int(es_query["dos"]["task_id"])},
            )
        finally:
            dos_store.close()

    # - BS Data
    if bandstructure_symm_line is not None and not isinstance(bandstructure_symm_line, dict):
        bandstructure_symm_line = bandstructure_symm_line.to_dict()

    if density_of_states is not None and not isinstance(density_of_states, dict):
        density_of_states = density_of_states.to_dict()

    bsml = BSML.from_dict(bandstructure_symm_line)

    bs_reg_plot = BSPlotter(bsml)
    bs_data = bs_reg_plot.bs_plot_data()
    is_spin_polarized = bs_reg_plot._bs.is_spin_polarized

    # Make plot continuous for lm
    if path_convention == "lm":
        distance_map, kpath_euler = HSKP(bsml.structure).get_continuous_path(bsml)
        kpath_labels = [pair[0] for pair in kpath_euler]
        kpath_labels.append(kpath_euler[-1][1])
    else:
        distance_map = [(i, False) for i in range(len(bs_data["distances"]))]
        tick_labels = bs_data["ticks"]["label"]
        # Collapse consecutive duplicate tick labels.
        kpath_labels = [
            current
            for current, following in zip(tick_labels, tick_labels[1:])
            if current != following
        ]
        kpath_labels.append(tick_labels[-1])

    bs_data["ticks"]["label"] = kpath_labels

    # Obtain bands to plot over and generate traces for bs data:
    energy_window = (-6.0, 10.0)
    spin_up = str(Spin.up)
    spin_down = str(Spin.down)

    # Keep bands whose energy at the first k-point of the first segment lies in the window.
    bands = [
        band_num
        for band_num in range(bs_reg_plot._nb_bands)
        if energy_window[0] <= bs_data["energy"][0][spin_up][band_num][0] <= energy_window[1]
    ]

    bstraces = []

    pmin = 0.0
    tick_vals = [0.0]

    cbm = bsml.get_cbm()
    vbm = bsml.get_vbm()

    # Copies: remapped points are appended below while bs_data["cbm"]/["vbm"]
    # are being iterated. Aliasing them (as before) mutated the list under
    # iteration, which could loop forever, and also altered the returned bs_data.
    cbm_new = list(bs_data["cbm"])
    vbm_new = list(bs_data["vbm"])

    def _remap_extrema(points, extremum, seg_distances, x_dat, rev, out):
        """Append points lying on this segment, shifted to continuous-path x, to ``out``."""
        for x_point, y_point in points:
            if x_point in seg_distances:
                xind = seg_distances.index(x_point)
                x_point_new = x_dat[len(x_dat) - xind - 1] if rev else x_dat[xind]
                new_label = bs_data["ticks"]["label"][tick_vals.index(x_point_new)]
                if extremum["kpoint"].label is None or extremum["kpoint"].label in new_label:
                    out.append((x_point_new, y_point))

    for d, rev in distance_map:
        seg_distances = bs_data["distances"][d]

        # Shift this segment so it starts where the previous one ended.
        x_dat = [dval - seg_distances[0] + pmin for dval in seg_distances]

        pmin = x_dat[-1]
        tick_vals.append(pmin)

        # Reversed segments take their energies in reverse k-point order.
        point_order = list(range(len(seg_distances)))
        if rev:
            point_order.reverse()

        traces_for_segment = [
            {
                "x": x_dat,
                "y": [bs_data["energy"][d][spin_up][i][j] for j in point_order],
                "mode": "lines",
                "line": {"color": "#1f77b4"},
                "hoverinfo": "skip",
                "name": "spin ↑" if is_spin_polarized else "Total",
                "hovertemplate": "%{y:.2f} eV",
                "showlegend": False,
                "xaxis": "x",
                "yaxis": "y",
            }
            for i in bands
        ]

        if is_spin_polarized:
            traces_for_segment += [
                {
                    "x": x_dat,
                    "y": [bs_data["energy"][d][spin_down][i][j] for j in point_order],
                    "mode": "lines",
                    "line": {"color": "#ff7f0e", "dash": "dot"},
                    "hoverinfo": "skip",
                    "showlegend": False,
                    "name": "spin ↓",
                    "hovertemplate": "%{y:.2f} eV",
                    "xaxis": "x",
                    "yaxis": "y",
                }
                for i in bands
            ]

        bstraces += traces_for_segment

        # - Get proper cbm and vbm coords for lm
        # NOTE(review): the original (un-remapped) points remain in cbm_new/vbm_new
        # for lm as well; confirm whether they should be dropped.
        if path_convention == "lm":
            _remap_extrema(bs_data["cbm"], cbm, seg_distances, x_dat, rev, cbm_new)
            _remap_extrema(bs_data["vbm"], vbm, seg_distances, x_dat, rev, vbm_new)

    bs_data["ticks"]["distance"] = tick_vals

    # - Strip latex math wrapping for labels
    # Order matters: replacements are applied in insertion order.
    str_replace = {
        "$": "",
        "\\mid": "|",
        "\\Gamma": "Γ",
        "\\Sigma": "Σ",
        "GAMMA": "Γ",
        "_1": "₁",
        "_2": "₂",
        "_3": "₃",
        "_4": "₄",
        "_{1}": "₁",
        "_{2}": "₂",
        "_{3}": "₃",
        "_{4}": "₄",
        "^{*}": "*",
    }

    bar_loc = []
    labels = bs_data["ticks"]["label"]
    for entry_num, label in enumerate(labels):
        for key, replacement in str_replace.items():
            if key in label:
                label = label.replace(key, replacement)
                if key == "\\mid":
                    # Disjoint path segment: draw a vertical separator here.
                    bar_loc.append(bs_data["ticks"]["distance"][entry_num])
        labels[entry_num] = label

    # Vertical lines for disjointed segments
    bstraces += [
        {
            "x": [x_point, x_point],
            "y": energy_window,
            "mode": "lines",
            "marker": {"color": "white"},
            "hoverinfo": "skip",
            "showlegend": False,
            "xaxis": "x",
            "yaxis": "y",
        }
        for x_point in bar_loc
    ]

    # Dots for cbm and vbm
    def _extremum_dots(points, extremum, tag):
        """Marker traces for band-extremum points (hover text built lazily per point)."""
        return [
            {
                "x": [x_point],
                "y": [y_point],
                # Was "marker" for the VBM, which is not a valid Plotly scatter mode.
                "mode": "markers",
                "marker": {
                    "color": "#7E259B",
                    "size": 16,
                    "line": {"color": "white", "width": 2},
                },
                "showlegend": False,
                "hoverinfo": "text",
                "name": "",
                "hovertemplate": "{}: k = {}, {} eV".format(
                    tag, list(extremum["kpoint"].frac_coords), extremum["energy"]
                ),
                "xaxis": "x",
                "yaxis": "y",
            }
            for (x_point, y_point) in set(points)
        ]

    bstraces += _extremum_dots(cbm_new, cbm, "CBM") + _extremum_dots(vbm_new, vbm, "VBM")

    # - DOS Data
    dostraces = []

    dos = CompleteDos.from_dict(density_of_states)

    # Indices of the energies closest to the window edges (relative to E_fermi).
    dos_max = np.abs((dos.energies - dos.efermi - energy_window[1])).argmin()
    dos_min = np.abs((dos.energies - dos.efermi - energy_window[0])).argmin()
    dos_window = slice(dos_min, dos_max)
    dos_energies = dos.energies[dos_window] - dos.efermi

    if is_spin_polarized:
        # Add second spin data if available (drawn mirrored on the negative x side)
        dostraces.append(
            {
                "x": -1.0 * dos.densities[Spin.down][dos_window],
                "y": dos_energies,
                "mode": "lines",
                "name": "Total DOS (spin ↓)",
                "line": go.scatter.Line(color="#444444", dash="dot"),
                "fill": "tozerox",
                "fillcolor": "#C4C4C4",
                "xaxis": "x2",
                "yaxis": "y2",
            }
        )
        tdos_label = "Total DOS (spin ↑)"
    else:
        tdos_label = "Total DOS"

    # Total DOS
    dostraces.append(
        {
            "x": dos.densities[Spin.up][dos_window],
            "y": dos_energies,
            "mode": "lines",
            "name": tdos_label,
            "line": go.scatter.Line(color="#444444"),
            "fill": "tozerox",
            "fillcolor": "#C4C4C4",
            "legendgroup": "spinup",
            "xaxis": "x2",
            "yaxis": "y2",
        }
    )

    ele_dos = dos.get_element_dos()
    elements = [str(entry) for entry in ele_dos.keys()]

    if dos_select == "ap":
        proj_data = ele_dos
    elif dos_select == "op":
        proj_data = dos.get_spd_dos()
    elif "orb" in dos_select:
        proj_data = dos.get_element_spd_dos(Element(dos_select.replace("orb", "")))
    else:
        raise PreventUpdate

    # Projected DOS
    colors = [
        "#d62728",  # brick red
        "#2ca02c",  # cooked asparagus green
        "#17becf",  # blue-teal
        "#bcbd22",  # curry yellow-green
        "#9467bd",  # muted purple
        "#8c564b",  # chestnut brown
        "#e377c2",  # raspberry yogurt pink
    ]

    for count, label in enumerate(proj_data):
        # Cycle the palette; indexing directly raised IndexError past 7 projections.
        color = colors[count % len(colors)]

        if is_spin_polarized:
            dostraces.append(
                {
                    "x": -1.0 * proj_data[label].densities[Spin.down][dos_window],
                    "y": dos_energies,
                    "mode": "lines",
                    "name": str(label) + " (spin ↓)",
                    "line": dict(width=3, color=color, dash="dot"),
                    "xaxis": "x2",
                    "yaxis": "y2",
                }
            )
            spin_up_label = str(label) + " (spin ↑)"
        else:
            spin_up_label = str(label)

        dostraces.append(
            {
                "x": proj_data[label].densities[Spin.up][dos_window],
                "y": dos_energies,
                "mode": "lines",
                "name": spin_up_label,
                "line": dict(width=2, color=color),
                "xaxis": "x2",
                "yaxis": "y2",
            }
        )

    traces = [bstraces, dostraces, bs_data]

    return (traces, elements)