def source2():
    """Build a connected in-memory store of 15 docs keyed by "k".

    Each doc has the shape {"k": <int>, "c": "c", "d": "d"}.
    """
    mem = MemoryStore("source2", key="k", last_updated_field="lu")
    mem.connect()
    # Index both the key field and the last-updated field.
    for field in ("k", "lu"):
        mem.ensure_index(field)
    mem.update([{"k": i, "c": "c", "d": "d"} for i in range(15)])
    return mem
class TestRobocrysBuilder(unittest.TestCase):
    """End-to-end test of RobocrysBuilder against the diamond fixture."""

    def setUp(self):
        """Set up materials and robocrys stores."""
        docs = loadfn(test_mats, cls=None)
        self.materials = MemoryStore("materials")
        self.materials.connect()
        self.materials.update(docs)
        # Target store; the builder connects it during the run.
        self.robocrys = MemoryStore("robocrys")

    def test_build(self):
        """Test building the robocrys database."""
        runner = Runner([RobocrysBuilder(self.materials, self.robocrys)])
        runner.run()

        doc = list(self.robocrys.query(criteria={'task_id': 'mp-66'}))[0]
        condensed = doc['condensed_structure']
        self.assertEqual(condensed['formula'], 'C')
        self.assertEqual(condensed['spg_symbol'], 'Fd-3m')
        self.assertEqual(condensed['mineral']['type'], 'diamond')
        self.assertEqual(condensed['dimensionality'], '3')

        description = doc['description']
        self.assertIn("C is diamond structured", description)
        self.assertIn("bond lengths are 1.55", description)
def source(docs):
    """Create a connected MemoryStore named "source" populated with *docs*.

    The store is keyed on "k" with "lu" as the last-updated field, and
    both fields are indexed.
    """
    st = MemoryStore("source", key="k", last_updated_field="lu")
    st.connect()
    for field in ("k", "lu"):
        st.ensure_index(field)
    st.update(docs)
    return st
def source1():
    """Build a connected in-memory store of 10 docs keyed by "k".

    Each doc has the shape {"k": <int>, "a": "a", "b": "b"}.
    """
    mem = MemoryStore("source1", key="k", last_updated_field="lu")
    mem.connect()
    for field in ("k", "lu"):
        mem.ensure_index(field)
    mem.update([{"k": i, "a": "a", "b": "b"} for i in range(10)])
    return mem
class StructureSimilarityBuilderTest(unittest.TestCase):
    """Tests for StructureSimilarityBuilder over the site-fingerprint fixtures."""

    @classmethod
    def setUpClass(cls):
        # Set up test db, etc.
        cls.test_site_descriptors = MemoryStore("site_descr")
        cls.test_site_descriptors.connect()
        site_fp_docs = loadfn(test_site_fp_stats, cls=None)
        cls.test_site_descriptors.update(site_fp_docs)

    def _make_builder(self):
        """Return a fresh builder writing to a new connected target store."""
        target = MemoryStore("struct_sim")
        target.connect()
        return StructureSimilarityBuilder(self.test_site_descriptors,
                                          target,
                                          fp_type='opsf')

    def test_get_items(self):
        sim_builder = self._make_builder()
        items = list(sim_builder.get_items())
        # 3 source docs -> 3 unordered pairs (C(3, 2) = 3).
        self.assertEqual(len(items), 3)
        for pair in items:
            d1 = pair[0]
            d2 = pair[1]
            self.assertIn("statistics", d1)
            self.assertIn("statistics", d2)
            self.assertIn("task_id", d1)
            self.assertIn("task_id", d2)
            # Fix: previously dropped into the nose debugger
            # (nose.tools.set_trace) on a falsy result, which hangs
            # non-interactive runs; fail the test instead.
            processed = sim_builder.process_item(pair)
            self.assertTrue(processed,
                            "process_item returned a falsy result")

    def test_get_all_site_descriptors(self):
        sim_builder = self._make_builder()
        # Self-similarity must be exact: cosine 1, distance 0.
        for d in self.test_site_descriptors.query():
            dsim = sim_builder.get_similarities(d, d)
            self.assertAlmostEqual(dsim['cos'], 1)
            self.assertAlmostEqual(dsim['dist'], 0)
        C = self.test_site_descriptors.query_one(
            criteria={"task_id": "mp-66"})
        NaCl = self.test_site_descriptors.query_one(
            criteria={"task_id": "mp-22862"})
        Fe = self.test_site_descriptors.query_one(
            criteria={"task_id": "mp-13"})
        # Cross-structure similarities pinned to fixture values.
        d = sim_builder.get_similarities(C, NaCl)
        self.assertAlmostEqual(d['cos'], 0.0013649)
        self.assertAlmostEqual(d['dist'], 2.6866749)
        d = sim_builder.get_similarities(C, Fe)
        self.assertAlmostEqual(d['cos'], 0.0013069)
        self.assertAlmostEqual(d['dist'], 2.6293889)
        d = sim_builder.get_similarities(NaCl, Fe)
        self.assertAlmostEqual(d['cos'], 0.0012729)
        self.assertAlmostEqual(d['dist'], 2.7235044)
class TaskTaggerTest(unittest.TestCase):
    """Checks that TaskTagger recovers the task type of synthetic VASP tasks."""

    def setUp(self):
        """Create one fake task document per Materials Project input set."""
        lattice = [
            [3.8401979337, 0.00, 0.00],
            [1.9200989668, 3.3257101909, 0.00],
            [0.00, -2.2171384943, 3.1355090603],
        ]
        coords = [[0, 0, 0], [0.75, 0.5, 0.75]]
        structure = Structure(lattice, ["Si", "Si"], coords)

        input_sets = {
            "GGA Structure Optimization": MPRelaxSet(structure),
            "GGA Static": MPStaticSet(structure),
            "GGA NSCF Line": MPNonSCFSet(structure, mode="line"),
            "GGA NSCF Uniform": MPNonSCFSet(structure, mode="uniform"),
        }

        # Task ids start at 1; each doc records its true type so the test
        # can compare against what TaskTagger infers.
        tasks = [
            {
                "true_task_type": task_type,
                "last_updated": datetime.now(),
                "task_id": t_id,
                "state": "successful",
                "orig_inputs": {
                    "incar": input_set.incar.as_dict(),
                    "kpoints": input_set.kpoints.as_dict(),
                },
                "output": {
                    "structure": structure.as_dict()
                },
            }
            for t_id, (task_type, input_set)
            in enumerate(input_sets.items(), start=1)
        ]

        self.test_tasks = MemoryStore("tasks")
        self.task_types = MemoryStore("task_types")
        self.test_tasks.connect()
        self.task_types.connect()
        self.test_tasks.update(tasks)

    def test_mp_defs(self):
        """Inferred task type must equal the stored true type for every task."""
        tagger = TaskTagger(tasks=self.test_tasks, task_types=self.task_types)
        for item in tagger.get_items():
            result = tagger.calc(item)
            expected = self.test_tasks.query_one(
                criteria={"task_id": item["task_id"]},
                properties=["true_task_type"])["true_task_type"]
            self.assertEqual(result["task_type"], expected)
def setUpClass(cls) -> None:
    """Reset the symbol registry and build web- and store-backed adapters."""
    Registry.clear_all_registries()
    add_builtin_symbols_to_registry()
    # Live (web) adapter.
    cls.afa_web = AflowAdapter()
    # Store-backed adapter fed from the serialized fixture, keyed by AFLOW uid.
    fixture = loadfn(os.path.join(TEST_DATA_DIR, 'aflow_store.json'))
    backing_store = MemoryStore()
    backing_store.connect()
    backing_store.update(fixture, key='auid')
    cls.afa_store = AflowAdapter(backing_store)
class BuilderTest(unittest.TestCase):
    """Tests for PropnetBuilder over the serialized test materials."""

    def setUp(self):
        """Load sanitized fixture materials into in-memory source/target stores."""
        self.materials = MemoryStore()
        self.materials.connect()
        materials = loadfn(os.path.join(TEST_DIR, "test_materials.json"))
        materials = jsanitize(materials, strict=True, allow_bson=True)
        self.materials.update(materials)
        self.propstore = MemoryStore()
        self.propstore.connect()

    def test_serial_runner(self):
        builder = PropnetBuilder(self.materials, self.propstore)
        runner = Runner([builder])
        runner.run()

    def test_multiproc_runner(self):
        builder = PropnetBuilder(self.materials, self.propstore)
        # Fix: pass max_workers so this actually exercises the multiprocess
        # path; previously it was identical to test_serial_runner. Matches
        # the convention used by CorrelationTest.test_multiproc_runner.
        runner = Runner([builder], max_workers=2)
        runner.run()

    def test_process_item(self):
        item = self.materials.query_one(criteria={"pretty_formula": "Cs"})
        builder = PropnetBuilder(self.materials, self.propstore)
        processed = builder.process_item(item)
        self.assertIsNotNone(processed)
        # Ensure vickers hardness gets populated
        self.assertIn("vickers_hardness", processed)

    # @unittest.skipIf(not os.path.isfile("runner.json"), "No runner file")
    # def test_runner_pipeline(self):
    #     from monty.serialization import loadfn
    #     runner = loadfn("runner.json")
    #     runner.builders[0].connect()
    #     items = list(runner.builders[0].get_items())
    #     processed = runner.builders[0].process_item(items[0])
    #     runner.run()

    # Just here for reference, in case anyone wants to create a new set
    # of test materials -jhm
    @unittest.skipIf(True, "Skipping test materials creation")
    def create_test_docs(self):
        formulas = ["BaNiO3", "Si", "Fe2O3", "Cs"]
        from maggma.advanced_stores import MongograntStore
        from monty.serialization import dumpfn
        mgstore = MongograntStore("ro:matgen2.lbl.gov/mp_prod", "materials")
        builder = PropnetBuilder(mgstore,
                                 self.propstore,
                                 criteria={
                                     "pretty_formula": {
                                         "$in": formulas
                                     },
                                     "e_above_hull": 0
                                 })
        builder.connect()
        dumpfn(list(builder.get_items()), "test_materials.json")
def setUpClass(cls):
    """Create materials and elasticity fixture stores for four structures."""
    mats = MemoryStore("materials")
    mats.connect()
    docs = [
        {"task_id": idx,
         "structure": PymatgenTest.get_structure(name).as_dict()}
        for idx, name in enumerate(["Si", "Sn", "TiO2", "VO2"])
    ]
    mats.update(docs, key='task_id')

    # Only the first material receives an elasticity document.
    elast = MemoryStore("elasticity")
    elast.connect()
    elast.update(docs[0:1], key="task_id")

    cls.materials = mats
    cls.elasticity = elast
class TestS3Store(unittest.TestCase):
    """Tests AmazonS3Store against a fully mocked boto3 resource."""

    def setUp(self):
        """Build an S3 store whose bucket interactions are MagicMocks."""
        # Fix: the index store name contained a stray quote ("index'").
        self.index = MemoryStore("index")
        with patch("boto3.resource") as mock_resource:
            mock_resource.return_value = MagicMock()
            mock_resource("s3").list_buckets.return_value = [
                "bucket1", "bucket2"
            ]
            self.s3store = AmazonS3Store(self.index, "bucket1")
            self.s3store.connect()

    def test_query_one(self):
        """query_one resolves keys via the index and decodes stored bodies.

        Fix: method name was misspelled "test_qeuery_one".
        """
        self.s3store.s3_bucket.Object.return_value = MagicMock()
        self.s3store.s3_bucket.Object(
        ).get.return_value = '{"task_id": "mp-1", "data": "asd"}'
        self.index.update([{"task_id": "mp-1"}])
        # Unknown key -> None; indexed key -> decoded JSON document.
        self.assertEqual(self.s3store.query_one(criteria={"task_id": "mp-2"}),
                         None)
        self.assertEqual(
            self.s3store.query_one(criteria={"task_id": "mp-1"})["data"],
            "asd")
        # Bodies flagged with compression "zlib" in the index are
        # transparently decompressed.
        self.s3store.s3_bucket.Object().get.return_value = zlib.compress(
            '{"task_id": "mp-3", "data": "sdf"}'.encode())
        self.index.update([{"task_id": "mp-3", "compression": "zlib"}])
        self.assertEqual(
            self.s3store.query_one(criteria={"task_id": "mp-3"})["data"],
            "sdf")

    def test_update(self):
        """update() writes one object with key and metadata from the doc."""
        self.s3store.update([{"task_id": "mp-1", "data": "asd"}])
        # Fix: call_count was asserted twice; once is enough.
        self.assertEqual(self.s3store.s3_bucket.put_object.call_count, 1)
        called_kwargs = self.s3store.s3_bucket.put_object.call_args[1]
        self.assertEqual(called_kwargs["Key"], "mp-1")
        self.assertTrue(len(called_kwargs["Body"]) > 0)
        self.assertEqual(called_kwargs["Metadata"]["task_id"], "mp-1")

    def test_update_compression(self):
        """compress=True records zlib compression in the object metadata."""
        self.s3store.update([{
            "task_id": "mp-1",
            "data": "asd"
        }], compress=True)
        self.assertEqual(self.s3store.s3_bucket.put_object.call_count, 1)
        called_kwargs = self.s3store.s3_bucket.put_object.call_args[1]
        self.assertEqual(called_kwargs["Key"], "mp-1")
        self.assertTrue(len(called_kwargs["Body"]) > 0)
        self.assertEqual(called_kwargs["Metadata"]["task_id"], "mp-1")
        self.assertEqual(called_kwargs["Metadata"]["compression"], "zlib")
def _get_correlation_values():
    """Run a CorrelationBuilder over the fixture propnet data.

    Returns the (already-run) builder so callers can inspect its results.
    """
    data_path = os.path.join(CORR_TEST_DIR, "correlation_propnet_data.json")
    with open(data_path, 'r') as f:
        raw_docs = json.load(f)

    source_store = MemoryStore()
    source_store.connect()
    source_store.update(jsanitize(raw_docs, strict=True, allow_bson=True))

    builder = CorrelationBuilder(source_store,
                                 MemoryStore(),
                                 props=PROPNET_PROPS,
                                 funcs='all',
                                 from_quantity_db=False)
    Runner([builder]).run()
    return builder
def search_helper(payload, base: str = "/?", debug=True):
    """
    Helper function to directly query search endpoints.

    Fixes: removed a leftover debug ``print(inspect.signature(...))``,
    stopped shadowing the ``json`` module with a local variable, removed
    the inaccurate ``-> Response`` annotation (the function returns a
    tuple), and fixed a docstring typo.

    Args:
        payload: query in dictionary format
        base: base of the query, defaults to "/?"
        debug: True = print out the url, False = don't print anything

    Returns:
        Tuple of the Response object for the corresponding payload and
        the list of documents under its "data" key.
    """
    store = MemoryStore("owners", key="name")
    store.connect()
    store.update([d.dict() for d in owners])

    endpoint = ReadOnlyResource(
        store,
        Owner,
        query_operators=[
            StringQueryOperator(model=Owner),
            NumericQuery(model=Owner),
            SparseFieldsQuery(model=Owner),
        ],
        disable_validation=True,
    )

    app = FastAPI()
    app.include_router(endpoint.router)
    client = TestClient(app)

    url = base + urlencode(payload)
    if debug:
        print(url)

    res = client.get(url)
    body = res.json()
    return res, body.get("data", [])  # type: ignore
def create_correlation_quantity_indexed_docs():
    """
    Outputs a JSON file containing the same data as
    create_correlation_test_docs() but as individual quantities, mimicking
    the quantity-indexed store.

    Must run create_correlation_test_docs() first and have the JSON file
    in the test directory.
    """
    data_path = os.path.join(CORR_TEST_DIR, "correlation_propnet_data.json")
    with open(data_path, 'r') as f:
        raw_docs = json.load(f)

    pn_store = MemoryStore()
    pn_store.connect()
    pn_store.update(jsanitize(raw_docs, strict=True, allow_bson=True))

    # Separate the material-indexed docs into quantity- and material-indexed
    # stores, then dump the quantity docs (without Mongo _id) to JSON.
    quantity_store = MemoryStore()
    material_store = MemoryStore()
    Runner([SeparationBuilder(pn_store, quantity_store, material_store)]).run()

    quantity_docs = list(
        quantity_store.query(criteria={}, properties={'_id': False}))
    dumpfn(
        quantity_docs,
        os.path.join(CORR_TEST_DIR, "correlation_propnet_quantity_data.json"))
def owner_store():
    """Return a connected "owners" MemoryStore preloaded with owner fixtures."""
    mem = MemoryStore("owners", key="name")
    mem.connect()
    docs = [owner.dict() for owner in owners]
    mem.update(docs)
    return mem
class SiteDescriptorsBuilderTest(unittest.TestCase):
    """Tests for SiteDescriptorsBuilder over the serialized test structures."""

    @classmethod
    def setUpClass(self):
        # Set up test db, etc.
        # NOTE(review): the classmethod's first parameter is named "self"
        # but receives the class; attributes set here are class-level.
        self.test_materials = MemoryStore("mat_site_fingerprint")
        self.test_materials.connect()
        struct_docs = loadfn(test_structs, cls=None)
        self.test_materials.update(struct_docs)

    def test_builder(self):
        # Run the builder to exhaustion, then verify incremental behavior.
        test_site_descriptors = MemoryStore("test_site_descriptors")
        sd_builder = SiteDescriptorsBuilder(self.test_materials,
                                            test_site_descriptors)
        sd_builder.connect()
        for t in sd_builder.get_items():
            processed = sd_builder.process_item(t)
            if processed:
                sd_builder.update_targets([processed])
            else:
                # NOTE(review): leftover interactive debugging — drops into
                # the nose debugger if processing fails; hangs in CI.
                import nose
                nose.tools.set_trace()
        # All items processed -> nothing left to build.
        self.assertEqual(len([t for t in sd_builder.get_items()]), 0)
        # Remove one data piece in diamond entry and test partial update.
        test_site_descriptors.collection.find_one_and_update(
            {'task_id': 'mp-66'}, {'$unset': {
                'site_descriptors': 1
            }})
        items = [e for e in list(sd_builder.get_items())]
        self.assertEqual(len(items), 1)

    def test_get_all_site_descriptors(self):
        test_site_descriptors = MemoryStore("test_site_descriptors")
        sd_builder = SiteDescriptorsBuilder(self.test_materials,
                                            test_site_descriptors)
        C = self.test_materials.query_one(criteria={"task_id": "mp-66"})
        NaCl = self.test_materials.query_one(criteria={"task_id": "mp-22862"})
        Fe = self.test_materials.query_one(criteria={"task_id": "mp-13"})
        # Diamond.
        d = sd_builder.get_site_descriptors_from_struct(
            Structure.from_dict(C["structure"]))
        # Two sites in the diamond primitive cell -> two entries per descriptor.
        for di in d.values():
            self.assertEqual(len(di), 2)
        # Coordination numbers from the various near-neighbor finders;
        # values pinned to the fixture structure.
        self.assertEqual(d['cn_VoronoiNN'][0]['CN_VoronoiNN'], 20)
        self.assertAlmostEqual(d['cn_wt_VoronoiNN'][0]['CN_VoronoiNN'],
                               4.5381162)
        self.assertEqual(d['cn_JMolNN'][0]['CN_JMolNN'], 4)
        self.assertAlmostEqual(d['cn_wt_JMolNN'][0]['CN_JMolNN'], 4.9617398)
        self.assertEqual(d['cn_MinimumDistanceNN'][0]['CN_MinimumDistanceNN'],
                         4)
        self.assertAlmostEqual(
            d['cn_wt_MinimumDistanceNN'][0]['CN_MinimumDistanceNN'], 4)
        self.assertEqual(d['cn_MinimumOKeeffeNN'][0]['CN_MinimumOKeeffeNN'], 4)
        self.assertAlmostEqual(
            d['cn_wt_MinimumOKeeffeNN'][0]['CN_MinimumOKeeffeNN'], 4)
        self.assertEqual(d['cn_MinimumVIRENN'][0]['CN_MinimumVIRENN'], 4)
        self.assertAlmostEqual(d['cn_wt_MinimumVIRENN'][0]['CN_MinimumVIRENN'],
                               4)
        self.assertEqual(
            d['cn_BrunnerNN_reciprocal'][0]['CN_BrunnerNN_reciprocal'], 4)
        self.assertAlmostEqual(
            d['cn_wt_BrunnerNN_reciprocal'][0]['CN_BrunnerNN_reciprocal'], 4)
        # Order-parameter site fingerprint: diamond sites are tetrahedral.
        self.assertAlmostEqual(d['opsf'][0]['tetrahedral CN_4'], 0.9995)
        #self.assertAlmostEqual(d['csf'][0]['tetrahedral CN_4'], 0.9886777)
        # Per-descriptor statistics (max/min/mean/std + name = 5 keys each).
        ds = sd_builder.get_statistics(d)
        self.assertTrue('opsf' in list(ds.keys()))
        self.assertTrue('csf' in list(ds.keys()))
        for k, dsk in ds.items():
            for di in dsk:
                self.assertEqual(len(list(di.keys())), 5)

        def get_index(li, optype):
            # Locate the statistics entry for a given order-parameter name.
            for i, di in enumerate(li):
                if di['name'] == optype:
                    return i
            raise RuntimeError('did not find optype {}'.format(optype))

        # Both sites are equivalent, so max == min == mean and std == 0.
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['max'],
            0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['min'],
            0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['mean'],
            0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'tetrahedral CN_4')]['std'], 0)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['mean'],
            0.0005)
        # NaCl.
        d = sd_builder.get_site_descriptors_from_struct(
            Structure.from_dict(NaCl["structure"]))
        # Rock-salt sites are octahedrally coordinated.
        self.assertAlmostEqual(d['opsf'][0]['octahedral CN_6'], 0.9995)
        #self.assertAlmostEqual(d['csf'][0]['octahedral CN_6'], 1)
        ds = sd_builder.get_statistics(d)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['max'],
            0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['min'],
            0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['mean'],
            0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'], 'octahedral CN_6')]['std'], 0)
        # Iron.
        d = sd_builder.get_site_descriptors_from_struct(
            Structure.from_dict(Fe["structure"]))
        # BCC iron sites match the body-centered cubic order parameter.
        self.assertAlmostEqual(d['opsf'][0]['body-centered cubic CN_8'],
                               0.9995)
        #self.assertAlmostEqual(d['csf'][0]['body-centered cubic CN_8'], 0.755096)
        ds = sd_builder.get_statistics(d)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'],
                                 'body-centered cubic CN_8')]['max'], 0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'],
                                 'body-centered cubic CN_8')]['min'], 0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'],
                                 'body-centered cubic CN_8')]['mean'], 0.9995)
        self.assertAlmostEqual(
            ds['opsf'][get_index(ds['opsf'],
                                 'body-centered cubic CN_8')]['std'], 0)
def pet_store(pets):
    """Return a connected "pets" MemoryStore keyed by name, loaded with *pets*.

    Pet models are converted to JSON-compatible dicts before insertion.
    """
    mem = MemoryStore("pets", key="name")
    mem.connect()
    encoded = [jsonable_encoder(p) for p in pets]
    mem.update(encoded)
    return mem
def owner_store(owners):
    """Return a connected "owners" MemoryStore keyed by name, loaded with *owners*.

    Owner models are converted to JSON-compatible dicts before insertion.
    """
    mem = MemoryStore("owners", key="name")
    mem.connect()
    encoded = [jsonable_encoder(o) for o in owners]
    mem.update(encoded)
    return mem
class BuilderTest(unittest.TestCase):
    """Tests for PropnetBuilder, including provenance-chain verification."""

    def setUp(self):
        # Source store of sanitized fixture materials; empty target store.
        self.materials = MemoryStore()
        self.materials.connect()
        materials = loadfn(os.path.join(TEST_DIR, "test_materials.json"))
        materials = jsanitize(materials, strict=True, allow_bson=True)
        self.materials.update(materials)
        self.propstore = MemoryStore()
        self.propstore.connect()

    def test_serial_runner(self):
        builder = PropnetBuilder(self.materials, self.propstore)
        runner = Runner([builder])
        runner.run()

    def test_multiproc_runner(self):
        # NOTE(review): identical to test_serial_runner — no max_workers is
        # passed, so this does not actually exercise a multiprocess run.
        builder = PropnetBuilder(self.materials, self.propstore)
        runner = Runner([builder])
        runner.run()

    def test_process_item(self):
        item = self.materials.query_one(criteria={"pretty_formula": "Cs"})
        builder = PropnetBuilder(self.materials, self.propstore)
        processed = builder.process_item(item)
        self.assertIsNotNone(processed)
        # Ensure vickers hardness gets populated
        self.assertIn("vickers_hardness", processed)
        if 'created_at' in item.keys():
            date_value = item['created_at']
        else:
            date_value = ""

        # Check that provenance values propagate correctly
        # Walk the provenance chain from a derived quantity down to its
        # original Materials Project input.
        current_quantity = processed['vickers_hardness']['quantities'][0]
        at_deepest_level = False
        while not at_deepest_level:
            current_provenance = current_quantity['provenance']
            if current_provenance['inputs'] is not None:
                # Derived quantity: source must be propnet itself, keyed by
                # the quantity's internal id, with a real creation date.
                self.assertEqual(current_provenance['source']['source'],
                                 "propnet")
                self.assertEqual(current_provenance['source']['source_key'],
                                 current_quantity['internal_id'])
                self.assertNotIn(current_provenance['source']['date_created'],
                                 ("", None))
                current_quantity = current_provenance['inputs'][0]
            else:
                # Leaf quantity: traces back to the original MP record.
                self.assertEqual(current_provenance['source']['source'],
                                 "Materials Project")
                self.assertEqual(current_provenance['source']['source_key'],
                                 item['task_id'])
                self.assertEqual(current_provenance['source']['date_created'],
                                 date_value)
                at_deepest_level = True

    # @unittest.skipIf(not os.path.isfile("runner.json"), "No runner file")
    # def test_runner_pipeline(self):
    #     from monty.serialization import loadfn
    #     runner = loadfn("runner.json")
    #     runner.builders[0].connect()
    #     items = list(runner.builders[0].get_items())
    #     processed = runner.builders[0].process_item(items[0])
    #     runner.run()

    # Just here for reference, in case anyone wants to create a new set
    # of test materials -jhm
    @unittest.skipIf(True, "Skipping test materials creation")
    def create_test_docs(self):
        formulas = ["BaNiO3", "Si", "Fe2O3", "Cs"]
        from maggma.advanced_stores import MongograntStore
        from monty.serialization import dumpfn
        mgstore = MongograntStore("ro:matgen2.lbl.gov/mp_prod", "materials")
        builder = PropnetBuilder(
            mgstore,
            self.propstore,
            criteria={"pretty_formula": {"$in": formulas},
                      "e_above_hull": 0})
        builder.connect()
        dumpfn(list(builder.get_items()), "test_materials.json")
def search_helper(payload, base: str = "/?", debug=True) -> Tuple[Response, Any]:
    """
    Helper function to directly query search endpoints.

    Builds an API exposing "owners" and "pets" read-only resources backed
    by in-memory stores, then issues a GET for the encoded payload.

    Args:
        payload: query in dictionary format
        base: base of the query, defaults to "/?"
        debug: if True, print the generated URL

    Returns:
        Tuple of the Response object and the "data" list from its JSON
        body (or the response reason when the body is not JSON).
    """
    store_owners = MemoryStore("owners", key="name")
    store_owners.connect()
    store_owners.update([d.dict() for d in owners])

    store_pets = MemoryStore("pets", key="name")
    store_pets.connect()
    store_pets.update([jsonable_encoder(d) for d in pets])

    owners_resource = ReadOnlyResource(
        store_owners,
        Owner,
        query_operators=[
            StringQueryOperator(model=Owner),  # type: ignore
            NumericQuery(model=Owner),  # type: ignore
            SparseFieldsQuery(model=Owner),
            PaginationQuery(),
        ],
    )
    pets_resource = ReadOnlyResource(
        store_pets,
        Owner,
        query_operators=[
            StringQueryOperator(model=Pet),
            NumericQuery(model=Pet),
            SparseFieldsQuery(model=Pet),
            PaginationQuery(),
        ],
    )
    resources = {"owners": [owners_resource], "pets": [pets_resource]}

    api = API(resources=resources)
    client = TestClient(api.app)

    url = base + urlencode(payload)
    if debug:
        print(url)

    res = client.get(url)
    try:
        data = res.json().get("data", [])
    except Exception:
        # Non-JSON body (e.g. an error page): fall back to the reason phrase.
        data = res.reason

    return res, data
class CorrelationTest(unittest.TestCase):
    """Tests for CorrelationBuilder over fixture propnet/MP data."""

    def setUp(self):
        # Source store of propnet-derived property documents.
        self.propstore = MemoryStore()
        self.propstore.connect()
        materials = loadfn(
            os.path.join(TEST_DIR, "correlation_propnet_data.json"))
        materials = jsanitize(materials, strict=True, allow_bson=True)
        self.propstore.update(materials)
        # Source store of Materials Project documents.
        self.materials = MemoryStore()
        self.materials.connect()
        materials = loadfn(os.path.join(TEST_DIR, "correlation_mp_data.json"))
        materials = jsanitize(materials, strict=True, allow_bson=True)
        self.materials.update(materials)
        # Target store for correlation results.
        self.correlation = MemoryStore()
        self.correlation.connect()
        self.propnet_props = [
            "band_gap_pbe", "bulk_modulus", "vickers_hardness"
        ]
        self.mp_query_props = ["magnetism.total_magnetization_normalized_vol"]
        self.mp_props = ["total_magnetization_normalized_vol"]

        # Expected correlation per scoring function, pinned to the fixtures.
        # vickers hardness (x-axis) vs. bulk modulus (y-axis)
        self.correlation_values_vickers_bulk = {
            'linlsq': 0.4155837083845686,
            'pearson': 0.6446578227126143,
            'mic': 0.5616515521782413,
            'theilsen': 0.4047519736540858,
            'ransac': 0.3747245847179631
        }
        self.correlation_values_bulk_vickers = {
            'linlsq': 0.4155837083845686,
            'pearson': 0.6446578227126143,
            'mic': 0.5616515521782413,
            'theilsen': 0.39860109570815505,
            'ransac': 0.3119656700613579
        }

    def test_serial_runner(self):
        builder = CorrelationBuilder(self.propstore, self.materials,
                                     self.correlation)
        runner = Runner([builder])
        runner.run()

    def test_multiproc_runner(self):
        builder = CorrelationBuilder(self.propstore, self.materials,
                                     self.correlation)
        runner = Runner([builder], max_workers=2)
        runner.run()

    def test_process_item(self):
        # Each case: (x, y) property pair, expected linlsq correlation,
        # and expected graph shortest-path length between the symbols.
        test_props = [['band_gap_pbe', 'total_magnetization_normalized_vol'],
                      ['bulk_modulus', 'vickers_hardness']]
        linlsq_correlation_values = [0.03620401274778131, 0.4155837083845686]
        path_lengths = [None, 2]
        for props, expected_correlation_val, expected_path_length in \
                zip(test_props, linlsq_correlation_values, path_lengths):
            builder = CorrelationBuilder(self.propstore,
                                         self.materials,
                                         self.correlation,
                                         props=props)
            processed = None
            prop_x, prop_y = props
            # Find the item for this exact (x, y) orientation.
            for item in builder.get_items():
                if item['x_name'] == prop_x and \
                        item['y_name'] == prop_y:
                    processed = builder.process_item(item)
                    break
            self.assertIsNotNone(processed)
            self.assertIsInstance(processed, tuple)
            px, py, correlation, func_name, n_points, path_length = processed
            self.assertEqual(px, prop_x)
            self.assertEqual(py, prop_y)
            self.assertAlmostEqual(correlation, expected_correlation_val)
            self.assertEqual(func_name, 'linlsq')
            self.assertEqual(n_points, 200)
            self.assertEqual(path_length, expected_path_length)

    def test_correlation_funcs(self):
        # A user-supplied scoring function must be registered alongside
        # the built-ins and produce its own result.
        def custom_correlation_func(x, y):
            return 0.5

        correlation_values = {
            k: v
            for k, v in self.correlation_values_bulk_vickers.items()
        }
        correlation_values['test_correlation.custom_correlation_func'] = 0.5

        builder = CorrelationBuilder(
            self.propstore,
            self.materials,
            self.correlation,
            props=['vickers_hardness', 'bulk_modulus'],
            funcs=['all', custom_correlation_func])
        self.assertEqual(
            set(builder._funcs.keys()),
            set(correlation_values.keys()),
            msg="Are there new built-in functions in the correlation builder?")

        for item in builder.get_items():
            if item['x_name'] == 'bulk_modulus' and \
                    item['y_name'] == 'vickers_hardness':
                processed = builder.process_item(item)
                self.assertIsInstance(processed, tuple)
                prop_x, prop_y, correlation, func_name, n_points, path_length = processed
                self.assertEqual(prop_x, 'bulk_modulus')
                self.assertEqual(prop_y, 'vickers_hardness')
                self.assertIn(func_name, correlation_values.keys())
                self.assertAlmostEqual(correlation,
                                       correlation_values[func_name])
                self.assertEqual(n_points, 200)
                self.assertEqual(path_length, 2)

    def test_database_write(self):
        builder = CorrelationBuilder(self.propstore,
                                     self.materials,
                                     self.correlation,
                                     props=self.propnet_props + self.mp_props,
                                     funcs='all')
        runner = Runner([builder])
        runner.run()
        data = list(self.correlation.query(criteria={}))
        # count = n_props**2 * n_funcs
        # n_props = 4, n_funcs = 5
        self.assertEqual(len(data), 80)
        for d in data:
            self.assertIsInstance(d, dict)
            self.assertEqual(
                set(d.keys()), {
                    'property_x', 'property_y', 'correlation',
                    'correlation_func', 'n_points', 'shortest_path_length',
                    'id', '_id', 'last_updated'
                })
            self.assertEqual(d['n_points'], 200)
            # Spot-check both orientations of the hardness/modulus pair.
            if d['property_x'] == 'vickers_hardness' and \
                    d['property_y'] == 'bulk_modulus':
                self.assertAlmostEqual(
                    d['correlation'],
                    self.correlation_values_vickers_bulk[
                        d['correlation_func']])
            elif d['property_x'] == 'bulk_modulus' and \
                    d['property_y'] == 'vickers_hardness':
                self.assertAlmostEqual(
                    d['correlation'],
                    self.correlation_values_bulk_vickers[
                        d['correlation_func']])

    # Just here for reference, in case anyone wants to create a new set
    # of test materials. Requires mongogrant read access to knowhere.lbl.gov.
    @unittest.skipIf(True, "Skipping test materials creation")
    def create_test_docs(self):
        from maggma.advanced_stores import MongograntStore
        from monty.serialization import dumpfn
        pnstore = MongograntStore("ro:knowhere.lbl.gov/mp_core", "propnet")
        pnstore.connect()
        mpstore = MongograntStore("ro:knowhere.lbl.gov/mp_core", "materials")
        mpstore.connect()
        # Propnet docs may hold a property either top-level or as an input
        # symbol, hence the $or per property.
        cursor = pnstore.query(criteria={
            '$and': [{
                '$or': [{
                    p: {
                        '$exists': True
                    }
                }, {
                    'inputs.symbol_type': p
                }]
            } for p in self.propnet_props]
        },
                               properties=['task_id'])
        pn_mpids = [item['task_id'] for item in cursor]
        cursor = mpstore.query(
            criteria={p: {
                '$exists': True
            } for p in self.mp_query_props},
            properties=['task_id'])
        mp_mpids = [item['task_id'] for item in cursor]
        # Keep materials present in both databases, capped at 200.
        mpids = list(set(pn_mpids).intersection(set(mp_mpids)))[:200]
        pn_data = pnstore.query(criteria={'task_id': {
            '$in': mpids
        }},
                                properties=['task_id', 'inputs'] +
                                [p + '.mean' for p in self.propnet_props] +
                                [p + '.units' for p in self.propnet_props])
        dumpfn(list(pn_data),
               os.path.join(TEST_DIR, "correlation_propnet_data.json"))
        mp_data = mpstore.query(criteria={'task_id': {
            '$in': mpids
        }},
                                properties=['task_id'] + self.mp_query_props)
        dumpfn(list(mp_data),
               os.path.join(TEST_DIR, "correlation_mp_data.json"))