def load_metanash_pure_strat(policy: Policy, pure_strat_spec: StrategySpec):
    """Restore ``policy`` from the checkpoint referenced by ``pure_strat_spec``.

    The checkpoint path is read from
    ``pure_strat_spec.metadata["checkpoint_path"]``, the file is loaded with
    ``deepdish.io.load`` and its ``"weights"`` entry is passed to
    ``policy.set_weights``. The spec is then stored on the policy as
    ``p2sro_policy_spec``.

    Args:
        policy: Policy whose weights are overwritten in place.
        pure_strat_spec: Spec whose metadata contains ``"checkpoint_path"``.
    """
    saved_state = deepdish.io.load(
        path=pure_strat_spec.metadata["checkpoint_path"])

    # Rewrite every "_dot_" in a weight key to "." (presumably undoing an
    # escape applied when the checkpoint was saved -- confirm with the writer).
    restored = {}
    for name, value in saved_state["weights"].items():
        restored[name.replace("_dot_", ".")] = value

    policy.set_weights(weights=restored)
    policy.p2sro_policy_spec = pure_strat_spec
def load_pure_strat(policy: Policy, pure_strat_spec, checkpoint_path: str = None):
    """Load checkpointed weights into ``policy``.

    Exactly one source should be given: either ``pure_strat_spec`` (whose
    ``metadata["checkpoint_path"]`` names the file) or an explicit
    ``checkpoint_path``. When loading from a spec and the policy already
    carries an equal ``policy_spec``, nothing is loaded.

    NOTE(review): this file contains a later definition with the same name,
    which replaces this one at import time -- confirm this version is still
    needed.

    Args:
        policy: Policy whose weights are overwritten in place.
        pure_strat_spec: Spec describing the checkpoint, or ``None`` when
            ``checkpoint_path`` is used instead.
        checkpoint_path: Direct path to a deepdish checkpoint, or ``None``.

    Raises:
        ValueError: If both ``pure_strat_spec`` and ``checkpoint_path`` are
            given, or if neither is given and there is nothing to skip.
    """
    # Raise explicitly instead of asserting: assertions are stripped under
    # ``python -O``, which would silently load ``checkpoint_path`` while
    # recording an unrelated spec on the policy.
    if pure_strat_spec is not None and checkpoint_path is not None:
        raise ValueError("can only pass one or the other")

    if checkpoint_path is None:
        # Skip the reload when the policy already holds this exact spec.
        if hasattr(policy, "policy_spec") and pure_strat_spec == policy.policy_spec:
            return
        if pure_strat_spec is None:
            raise ValueError(
                "must pass either pure_strat_spec or checkpoint_path")
        pure_strat_checkpoint_path = pure_strat_spec.metadata["checkpoint_path"]
    else:
        pure_strat_checkpoint_path = checkpoint_path

    checkpoint_data = deepdish.io.load(path=pure_strat_checkpoint_path)
    weights = checkpoint_data["weights"]
    # Rewrite every "_dot_" in a weight key to "." before handing the weights
    # to the policy.
    weights = {k.replace("_dot_", "."): v for k, v in weights.items()}
    policy.set_weights(weights=weights)
    # When loaded from a raw path this records ``None`` as the policy's spec.
    policy.policy_spec = pure_strat_spec
def load_pure_strat_cached(policy: Policy, pure_strat_spec):
    """Load the weights named by ``pure_strat_spec`` into ``policy``, with caching.

    Weights are looked up by checkpoint path in the ``cache`` mapping defined
    outside this function. On a miss the checkpoint is read with
    ``deepdish.io.load``, its ``"weights"`` entry has every ``"_dot_"`` in its
    keys rewritten to ``"."``, and the result is stored in ``cache`` under the
    path. The same cached dict object is handed to every policy that hits it.

    Args:
        policy: Policy whose weights are overwritten in place.
        pure_strat_spec: Spec whose metadata contains ``"checkpoint_path"``.
    """
    ckpt_path = pure_strat_spec.metadata["checkpoint_path"]

    if ckpt_path not in cache:
        # Cache miss: read the file once and remember the converted weights.
        raw_weights = deepdish.io.load(path=ckpt_path)["weights"]
        restored = {
            name.replace("_dot_", "."): value
            for name, value in raw_weights.items()
        }
        cache[ckpt_path] = restored
    else:
        restored = cache[ckpt_path]

    policy.set_weights(weights=restored)
    policy.policy_spec = pure_strat_spec
def load_pure_strat(policy: Policy,
                    pure_strat_spec: StrategySpec = None,
                    checkpoint_path: str = None,
                    weights_key: str = "weights"):
    """Load checkpointed weights into ``policy``.

    Exactly one source should be given: either ``pure_strat_spec`` (whose
    ``metadata["checkpoint_path"]`` names the file) or an explicit
    ``checkpoint_path``. When loading from a spec and the policy already
    carries an equal ``policy_spec``, nothing is loaded.

    The checkpoint is first read with ``deepdish.io.load``, retrying up to
    five times with a one-second pause on ``HDF5ExtError`` or ``KeyError``.
    If that still fails with any ``Exception``, the file is re-read as a
    cloudpickle file instead.

    Args:
        policy: Policy whose weights are overwritten in place.
        pure_strat_spec: Spec describing the checkpoint, or ``None`` when
            ``checkpoint_path`` is used instead.
        checkpoint_path: Direct path to a checkpoint file, or ``None``.
        weights_key: Key of the weights entry inside the checkpoint data.

    Raises:
        ValueError: If both ``pure_strat_spec`` and ``checkpoint_path`` are
            given, or if neither is given and there is nothing to skip.
    """
    if pure_strat_spec is not None and checkpoint_path is not None:
        raise ValueError(
            "Can only pass pure_strat_spec or checkpoint_path but not both")

    if checkpoint_path is None:
        # Skip the reload when the policy already holds this exact spec.
        if hasattr(policy, "policy_spec") and pure_strat_spec == policy.policy_spec:
            return
        # Fail with a clear message rather than an AttributeError on
        # ``None.metadata`` when the caller supplied no source at all.
        if pure_strat_spec is None:
            raise ValueError(
                "Must pass either pure_strat_spec or checkpoint_path")
        pure_strat_checkpoint_path = pure_strat_spec.metadata["checkpoint_path"]
    else:
        pure_strat_checkpoint_path = checkpoint_path

    try:
        num_load_attempts = 5
        for attempt in range(num_load_attempts):
            try:
                checkpoint_data = deepdish.io.load(
                    path=pure_strat_checkpoint_path)
                weights = checkpoint_data[weights_key]
                break
            except (HDF5ExtError, KeyError):
                # Re-raise on the last attempt; otherwise pause and retry
                # (presumably the file may still be mid-write -- confirm).
                if attempt + 1 == num_load_attempts:
                    raise
                time.sleep(1.0)
    # TODO use correct exception
    except Exception:
        # Fallback: treat the file as a cloudpickle checkpoint.
        # NOTE(review): unpickling executes arbitrary code; only load
        # checkpoints from trusted locations.
        with open(pure_strat_checkpoint_path, "rb") as pickle_file:
            checkpoint_data = cloudpickle.load(pickle_file)
        weights = checkpoint_data[weights_key]

    # Rewrite every "_dot_" in a weight key to "." before handing the weights
    # to the policy.
    weights = {k.replace("_dot_", "."): v for k, v in weights.items()}
    policy.set_weights(weights=weights)
    # When loaded from a raw path this records ``None`` as the policy's spec.
    policy.policy_spec = pure_strat_spec
def set_policy_weights(policy: Policy, checkpoint_path: str):
    """Load the ``"weights"`` entry of a deepdish checkpoint into ``policy``.

    Every ``"_dot_"`` in a weight key is rewritten to ``"."`` before the
    weights are passed to ``policy.set_weights``.

    Args:
        policy: Policy whose weights are overwritten in place.
        checkpoint_path: Path passed to ``deepdish.io.load``.
    """
    stored_weights = deepdish.io.load(path=checkpoint_path)["weights"]
    policy.set_weights({
        key.replace("_dot_", "."): value
        for key, value in stored_weights.items()
    })