def register(self, component):
    """Create and register the cache requested by ``component``'s context.

    The cache config is read from ``component.context["cache"]`` and
    defaults to ``"no"`` when the key is absent.

    Returns:
        ``None`` if the cache config is falsy, a ``NoCache`` for kind
        ``"no"``, and otherwise a ``DiskCache``, which is also appended to
        ``self.caches`` as ``(str(component), key, cache)``.

    Raises:
        NotImplementedError: For any cache kind other than "no" or "disk".
    """
    # in case it gets overwritten!
    self.location = Path(self.location)

    cache_config = component.context.get("cache", "no")  # default to off
    if not cache_config:
        return None

    key = f"{component.kind}/{component.get_hash()}"
    cache_kind, cache_inner = parse_config(cache_config, shortcut_ok=True)

    if cache_kind == "no":
        return NoCache()

    if cache_kind != "disk":
        raise NotImplementedError(
            f"Currently, only 'disk' type caches are supported. Not {cache_config}."
        )

    # Work on a copy: the parsed options may be the very dict stored in the
    # component's context, and rewriting "location" in place would leak this
    # component's key into the config seen by later registrations.
    options = dict(cache_inner)
    base = Path(options["location"]) if "location" in options else self.location
    options["location"] = base / key

    cache = DiskCache(**options)
    self.caches.append((str(component), key, cache))
    return cache
def _to_weightf(weightf, acc): if weightf == "unity": return {"function": "unity", "cutoff": acc} else: kind, inner = parse_config(weightf) assert kind == "exp" return {"function": "exp", "cutoff": acc, "scale": inner["ls"]}
def test_catches_exceptions(self):
    """Caught exceptions are reported as "error" results and stored in the pool."""
    pool = EvaluationPool(
        max_workers=4,
        evaluator_config={"mock_eval": {}},
        caught_exceptions=(ValueError,),
    )

    # Shut the pool down even if one of the assertions below fails.
    try:
        future = pool.schedule("raise")
        result = pool.finish(future)

        state, inner = parse_config(result)
        self.assertEqual(state, "error")
        self.assertEqual(inner["error"], "ValueError")
        self.assertEqual(inner["error_text"], "Hello!")
        self.assertIn("traceback", inner)

        # is the same found in the cache?
        state2, inner2 = pool.evals[future.eid]
        self.assertEqual(state2, "error")
        self.assertEqual(inner2, inner)

        # what if we do it again?
        future = pool.schedule("raise")
        pool.finish(future)

        state3, inner3 = pool.evals[future.eid]
        self.assertEqual(state3, "error")
        self.assertEqual(inner3, inner)
    finally:
        pool.shutdown()
def test_catches_timeout_exceptions(self):
    """Caught timeouts are reported as "error" results and stored in the pool."""
    # this is a separate case because this exception
    # is raised at a slightly different location!
    pool = EvaluationPool(
        max_workers=4,
        evaluator_config={"mock_eval": {}},
        caught_exceptions=(TimeoutError,),
        trial_timeout=0.01,
    )

    # Shut the pool down even if one of the assertions below fails.
    try:
        future = pool.schedule("wait")
        result = pool.finish(future)

        state, inner = parse_config(result)
        self.assertEqual(state, "error")
        self.assertEqual(inner["error"], "TimeoutError")
        self.assertIn("traceback", inner)

        # is the same found in the cache?
        state2, inner2 = pool.evals[future.eid]
        self.assertEqual(state2, "error")
        self.assertEqual(inner2, inner)

        # what if we do it again?
        future = pool.schedule("wait")
        pool.finish(future)

        state3, inner3 = pool.evals[future.eid]
        self.assertEqual(state3, "error")
        self.assertEqual(inner3, inner)
    finally:
        pool.shutdown()
def make_params(sfs):
    """Collect ACSF parameters from a list of symmetry-function configs.

    Returns:
        A ``(g2_params, g4_params)`` tuple. ``g2_params`` holds one
        ``[eta, mu]`` row per "rad" config, ``g4_params`` one
        ``[eta, zeta, lambd]`` row per "ang" config; each is a numpy array,
        or ``None`` when no config of that kind was given.

    Raises:
        ValueError: If a config has a kind other than "rad" or "ang".
    """
    radial = []
    angular = []

    for sf in sfs:
        kind, inner = parse_config(sf)
        if kind == "rad":
            radial.append([inner["eta"], inner["mu"]])
        elif kind == "ang":
            angular.append([inner["eta"], inner["zeta"], inner["lambd"]])
        else:
            raise ValueError(
                f"ACSF kind {kind} is not yet implemented. (Allowed: rad and ang.)"
            )

    # empty parameter lists are reported as None rather than empty arrays
    g2_params = np.array(radial) if radial else None
    g4_params = np.array(angular) if angular else None
    return g2_params, g4_params
def normalize_universal_sfs(sfs):
    """Expand universal symmetry-function configs into plain "rad"/"ang" ones.

    ``None`` entries are skipped. "rad" and "ang" configs are passed through
    unchanged; a config whose kind is a key of ``schemes`` is replaced by
    whatever that scheme generates.

    Returns:
        ``(result, count_rad, count_ang)``: the collected configs and the
        number of radial and angular entries counted.

    Raises:
        ValueError: If a config's kind is neither "rad", "ang", nor a scheme.
    """
    normalized = []
    n_rad = 0
    n_ang = 0

    for sf in sfs:
        if sf is None:
            continue

        kind, inner = parse_config(sf)
        if kind == "rad":
            normalized.append(sf)
            n_rad += 1
        elif kind == "ang":
            normalized.append(sf)
            n_ang += 1
        elif kind in schemes:
            # a scheme reports its own radial/angular counts
            generated, extra_rad, extra_ang = schemes[kind](**inner)
            normalized.extend(generated)
            n_rad += extra_rad
            n_ang += extra_ang
        else:
            raise ValueError(
                f"Don't know how to deal with universal symmetry function config {sf}."
            )

    return normalized, n_rad, n_ang
def replay(self, tape):
    """Re-apply every record of ``tape`` in order.

    Each record is parsed into ``(action, payload)``; "suggest" records go
    to ``replay_suggest`` and "submit" records to ``replay_submit``. Records
    with any other action are ignored.
    """
    for entry in tape:
        action, payload = parse_config(entry)
        if action == "suggest":
            self.replay_suggest(payload)
        elif action == "submit":
            self.replay_submit(payload)
def replay_submit(self, payload):
    """Replay a recorded submission and restore its entry in ``self.evals``.

    ``payload`` must provide ``"tid"`` and ``"result"``; the parsed result
    must contain a ``"suggestion"``, whose hash is used as the evals key.
    """
    trial_id, result = payload["tid"], payload["result"]
    self.submit(trial_id, result)

    # refilling the evals...! (since the pool can't do it)
    _state, outcome = parse_config(result)
    eid = compute_hash(outcome["suggestion"])
    self.evals.submit_result(eid, result)
def _to_weightf(config): """Translate to qmmlpack weightf format.""" if isinstance(config, str): return config else: kind, inner = parse_config(config) return (kind, (inner["ls"], ))
def _from_config(config, classes={}, **kwargs):
    """Instantiate the class named by ``config`` from the ``classes`` registry.

    Args:
        config: A config to parse, or an existing ``Configurable`` instance,
            which is returned as-is.
        classes: Mapping from kind name to a class providing ``from_config``.
        **kwargs: Forwarded to ``from_config``.

    Raises:
        ValueError: If the parsed kind is not a key of ``classes``.
    """
    # did we accidentally pass an already instantiated object?
    if isinstance(config, Configurable):
        return config

    kind, inner = parse_config(config)
    if kind not in classes:
        raise ValueError(f"Cannot find class with name {kind} in registry.")

    return classes[kind].from_config(inner, **kwargs)
def normalize_elemental_sfs(sfs):
    """Validate elemental symmetry-function configs and drop ``None`` entries.

    Returns:
        A list with the "rad" and "ang" configs of ``sfs`` in their
        original order.

    Raises:
        ValueError: If a config has a kind other than "rad" or "ang".
    """
    kept = []
    for sf in sfs:
        if sf is None:
            continue

        kind, _inner = parse_config(sf)
        if kind not in ["rad", "ang"]:
            raise ValueError(
                f"Don't know how to deal with elemental symmetry function config {sf}."
            )
        kept.append(sf)

    return kept
def check_result(result):
    """Enforce the run-internal result format.

    Returns:
        The ``(state, outcome)`` pair obtained by parsing ``result``.

    Raises:
        ValueError: If the state is neither "error" nor "ok", or if an "ok"
            result lacks a "loss" or a "suggestion" entry.
    """
    state, outcome = parse_config(result)

    if state not in ("error", "ok"):
        raise ValueError(f"Received a result with invalid state={state}.")

    # Compare against the string itself: a comparison with the list ["ok"]
    # can never be true, which would silently disable the checks below.
    # NOTE(review): the suggestion check is applied to "ok" results only, as
    # its message states -- confirm whether "error" results need one too.
    if state == "ok":
        if "loss" not in outcome:
            raise ValueError("Results with status 'ok' must contain a loss.")
        if "suggestion" not in outcome:
            raise ValueError("Results with status 'ok' must contain a suggestion dict.")

    return state, outcome
def __init__(self, elems, cutoff, sfs=[], stratify=True, context={}):
    """Set up the runner config and the component config.

    Args:
        elems: Passed to ``prepare_config`` and stored in ``self.config``.
        cutoff: Added as ``"cutoff"`` to every symmetry-function config
            handed to ``prepare_config``; also stored in ``self.config``.
        sfs: Symmetry-function configs, each parseable by ``parse_config``.
        stratify: Stored in ``self.config``.
        context: Forwarded to the parent initializer.
    """
    super().__init__(context=context)

    # Build new inner dicts instead of writing the cutoff into the parsed
    # ones, so the caller's ``sfs`` (which is also stored in ``self.config``
    # below) is left untouched.
    sfs_with_cutoff = []
    for sf in sfs:
        kind, inner = parse_config(sf)
        sfs_with_cutoff.append({kind: {**inner, "cutoff": cutoff}})

    self.runner_config = prepare_config(
        elems=elems, elemental=[], universal=sfs_with_cutoff
    )

    self.config = {
        "elems": elems,
        "sfs": sfs,
        "cutoff": cutoff,
        "stratify": stratify,
    }
def test_parse_valid_config(self):
    """A single-key dict is split into its key and its inner dict."""
    kind, inner = parse_config({"a": {"b": 3}})

    self.assertEqual(kind, "a")
    self.assertEqual(inner, {"b": 3})
def run_wait():
    """Schedule a "wait" evaluation on ``pool``, finish it, and parse the result."""
    pending = pool.schedule("wait")
    outcome = pool.finish(pending)
    # the parsed pair is not used further; the call only has to go through
    state, inner = parse_config(outcome)
def test_parse_valid_config_with_shortcut_ok(self):
    """With ``shortcut_ok=True`` a bare string parses to that kind and an empty dict."""
    kind, inner = parse_config("test", shortcut_ok=True)

    self.assertEqual(kind, "test")
    self.assertEqual(inner, {})
def test_parse_invalid_config(self):
    """A malformed config makes ``parse_config`` raise ``ValueError``."""
    malformed = {3: {"b": 3}, "b": 3}

    with self.assertRaises(ValueError):
        parse_config(malformed)
def test_parse_invalid_config_with_shortcut_ok(self):
    """A malformed config still raises ``ValueError`` when ``shortcut_ok=True``."""
    malformed = {3: {"b": 3}, "b": 3}

    with self.assertRaises(ValueError):
        parse_config(malformed, shortcut_ok=True)
def submit_result(self, key, result):
    """Parse ``result`` and submit its state and outcome under ``key``."""
    parsed_state, parsed_outcome = parse_config(result)
    self.submit(key, parsed_state, parsed_outcome)