def load_metanash_pure_strat(policy: Policy, pure_strat_spec: StrategySpec):
    """Load the checkpointed weights of a pure strategy into ``policy``.

    The checkpoint file is located through
    ``pure_strat_spec.metadata["checkpoint_path"]`` and read with
    ``deepdish.io.load``. Its ``"weights"`` entry is applied to the policy via
    ``policy.set_weights``, and the spec is then recorded on the policy as
    ``p2sro_policy_spec``.

    Args:
        policy: Policy whose weights are overwritten.
        pure_strat_spec: Spec whose metadata holds the checkpoint path.
    """
    ckpt_path = pure_strat_spec.metadata["checkpoint_path"]
    stored_weights = deepdish.io.load(path=ckpt_path)["weights"]

    # Replace every "_dot_" in the stored key names with ".".
    renamed_weights = {}
    for name, value in stored_weights.items():
        renamed_weights[name.replace("_dot_", ".")] = value

    policy.set_weights(weights=renamed_weights)
    policy.p2sro_policy_spec = pure_strat_spec
def load_pure_strat(policy: Policy, pure_strat_spec, checkpoint_path: str = None):
    """Load checkpointed weights into ``policy`` from a spec or a file path.

    Exactly one source is expected: either ``pure_strat_spec`` (the checkpoint
    path is then read from ``pure_strat_spec.metadata["checkpoint_path"]``) or
    an explicit ``checkpoint_path``. When loading from a spec that already
    equals ``policy.policy_spec``, nothing is loaded.

    After loading, ``policy.policy_spec`` is set to ``pure_strat_spec``, which
    is ``None`` when the weights came from ``checkpoint_path``.

    Args:
        policy: Policy whose weights are overwritten.
        pure_strat_spec: Spec whose metadata holds the checkpoint path, or
            ``None`` when ``checkpoint_path`` is used.
        checkpoint_path: Path of a checkpoint readable by
            ``deepdish.io.load``, or ``None`` when ``pure_strat_spec`` is used.

    Raises:
        ValueError: If both sources are given, or if neither is given and the
            policy's current spec does not already match.
    """
    # Raise explicitly rather than ``assert``: assertions are removed under
    # ``python -O``, which would let a spec be recorded on the policy while a
    # different checkpoint's weights were loaded.
    if pure_strat_spec is not None and checkpoint_path is not None:
        raise ValueError("can only pass one or the other")

    if checkpoint_path is None:
        # Skip the disk read when the policy already carries this spec.
        if hasattr(policy, "policy_spec") and pure_strat_spec == policy.policy_spec:
            return
        if pure_strat_spec is None:
            raise ValueError(
                "either pure_strat_spec or checkpoint_path must be provided")
        pure_strat_checkpoint_path = pure_strat_spec.metadata["checkpoint_path"]
    else:
        pure_strat_checkpoint_path = checkpoint_path

    checkpoint_data = deepdish.io.load(path=pure_strat_checkpoint_path)
    weights = checkpoint_data["weights"]
    # Replace every "_dot_" in the stored key names with ".".
    weights = {k.replace("_dot_", "."): v for k, v in weights.items()}
    policy.set_weights(weights=weights)
    policy.policy_spec = pure_strat_spec
Esempio n. 3
0
    def load_pure_strat_cached(policy: Policy, pure_strat_spec):
        """Load a pure strategy's weights into ``policy``, reusing ``cache``.

        Weights are looked up in ``cache`` by the checkpoint path found in
        ``pure_strat_spec.metadata["checkpoint_path"]``. On a miss the
        checkpoint is read with ``deepdish.io.load``, its ``"weights"`` keys
        have ``"_dot_"`` replaced with ``"."``, and the result is stored in
        ``cache`` under that path. The policy's ``policy_spec`` is then set to
        ``pure_strat_spec``.

        Args:
            policy: Policy whose weights are overwritten.
            pure_strat_spec: Spec whose metadata holds the checkpoint path.
        """
        ckpt_path = pure_strat_spec.metadata["checkpoint_path"]

        if ckpt_path not in cache:
            # Cache miss: read from disk, rename the keys, and remember the
            # converted weights for later calls.
            stored_weights = deepdish.io.load(path=ckpt_path)["weights"]
            weights = {
                name.replace("_dot_", "."): value
                for name, value in stored_weights.items()
            }
            cache[ckpt_path] = weights
        else:
            weights = cache[ckpt_path]

        policy.set_weights(weights=weights)
        policy.policy_spec = pure_strat_spec
Esempio n. 4
0
def load_pure_strat(policy: Policy,
                    pure_strat_spec: StrategySpec = None,
                    checkpoint_path: str = None,
                    weights_key: str = "weights"):
    """Load checkpointed weights into ``policy`` from a spec or a file path.

    Exactly one source is expected: either ``pure_strat_spec`` (the checkpoint
    path is then read from ``pure_strat_spec.metadata["checkpoint_path"]``) or
    an explicit ``checkpoint_path``. When loading from a spec that already
    equals ``policy.policy_spec``, nothing is loaded.

    The checkpoint is first read with ``deepdish.io.load``, retrying up to five
    times (one second apart) on ``HDF5ExtError`` or ``KeyError``. If that still
    fails, or any other exception occurs, the file is read with
    ``cloudpickle.load`` instead.

    After loading, ``policy.policy_spec`` is set to ``pure_strat_spec``, which
    is ``None`` when the weights came from ``checkpoint_path``.

    Args:
        policy: Policy whose weights are overwritten.
        pure_strat_spec: Spec whose metadata holds the checkpoint path, or
            ``None`` when ``checkpoint_path`` is used.
        checkpoint_path: Path of the checkpoint file, or ``None`` when
            ``pure_strat_spec`` is used.
        weights_key: Key under which the weights mapping is stored in the
            checkpoint data.

    Raises:
        ValueError: If both sources are given, or if neither is given and the
            policy's current spec does not already match.
    """
    if pure_strat_spec is not None and checkpoint_path is not None:
        raise ValueError(
            "Can only pass pure_strat_spec or checkpoint_path but not both")
    if checkpoint_path is None:
        # Skip the disk read when the policy already carries this spec.
        if hasattr(policy,
                   "policy_spec") and pure_strat_spec == policy.policy_spec:
            return
        if pure_strat_spec is None:
            # Fail with a clear message instead of an AttributeError on
            # ``None.metadata`` below.
            raise ValueError(
                "Either pure_strat_spec or checkpoint_path must be provided")
        pure_strat_checkpoint_path = pure_strat_spec.metadata[
            "checkpoint_path"]
    else:
        pure_strat_checkpoint_path = checkpoint_path

    try:
        num_load_attempts = 5
        for attempt in range(num_load_attempts):
            try:
                checkpoint_data = deepdish.io.load(
                    path=pure_strat_checkpoint_path)
                weights = checkpoint_data[weights_key]
                break
            except (HDF5ExtError, KeyError):
                # Re-raise on the final attempt so the fallback below runs;
                # otherwise wait briefly and try again.
                if attempt + 1 == num_load_attempts:
                    raise
                time.sleep(1.0)

    # TODO use correct exception
    # Deliberately broad: any failure of the deepdish path falls back to
    # treating the file as a cloudpickle checkpoint.
    except Exception:
        # NOTE(review): unpickling can execute arbitrary code; only load
        # checkpoints that come from a trusted source.
        with open(pure_strat_checkpoint_path, "rb") as pickle_file:
            checkpoint_data = cloudpickle.load(pickle_file)
            weights = checkpoint_data[weights_key]

    # Replace every "_dot_" in the stored key names with ".".
    weights = {k.replace("_dot_", "."): v for k, v in weights.items()}
    policy.set_weights(weights=weights)
    policy.policy_spec = pure_strat_spec
Esempio n. 5
0
 def set_policy_weights(policy: Policy, checkpoint_path: str):
     """Apply the weights stored in a checkpoint file to ``policy``.

     The file is read with ``deepdish.io.load``; the mapping under its
     ``"weights"`` entry has ``"_dot_"`` replaced with ``"."`` in every key
     and is then passed to ``policy.set_weights``.

     Args:
         policy: Policy whose weights are overwritten.
         checkpoint_path: Path of the checkpoint file to read.
     """
     stored_weights = deepdish.io.load(path=checkpoint_path)["weights"]
     renamed_weights = dict(
         (name.replace("_dot_", "."), value)
         for name, value in stored_weights.items()
     )
     policy.set_weights(renamed_weights)