コード例 #1
0
def test_save_and_laod_yaml():
    """
    Round-trip several small dicts through a YAML-file and check the content that is loaded back.

    The experiment folder created for the test is removed in a `finally` clause, such that a failing assertion
    or a failing save / load call does not leave the folder behind.
    """
    try:
        ex_dir = setup_experiment("testenv",
                                  "testalgo",
                                  "testinfo",
                                  base_dir=TEMP_DIR)

        # Save test data to YAML-file
        save_dicts_to_yaml(
            dict(a=1),
            dict(b=2.0),
            dict(c=np.array([1.0, 2.0])),
            dict(d=to.tensor([3.0, 4.0])),
            dict(e="string"),
            dict(f=[5, "f"]),
            dict(g=(6, "g")),
            save_dir=ex_dir,
            file_name="testfile",
        )

        # All dicts are expected to come back merged into one single dict
        data = load_dict_from_yaml(osp.join(ex_dir, "testfile.yaml"))
        assert isinstance(data, dict)
        assert data["a"] == 1
        assert data["b"] == 2
        assert data["c"] == [1.0, 2.0]  # the array is expected back as a list
        assert data["d"] == [3.0, 4.0]  # the tensor is expected back as a list
        assert data["e"] == "string"
        assert data["f"] == [5, "f"]
        assert data["g"] == (6, "g")

    finally:
        # Delete the created folder recursively; ignore_errors=True keeps the clean-up itself from raising
        shutil.rmtree(osp.join(TEMP_DIR, "testenv"), ignore_errors=True)
コード例 #2
0
def test_save_and_laod_yaml():
    """
    Save a list of small dicts to a YAML-file and check that loading the file yields a dict.

    The experiment folder created for the test is removed in a `finally` clause, such that a failing assertion
    or a failing save / load call does not leave the folder behind.
    """
    try:
        ex_dir = setup_experiment('testenv',
                                  'testalgo',
                                  'testinfo',
                                  base_dir=TEMP_DIR)

        # Save test data to YAML-file (the ndarray is converted to a list before saving)
        save_list_of_dicts_to_yaml(
            [dict(a=1),
             dict(b=2.0),
             dict(c=np.array([1., 2.]).tolist())], ex_dir, 'testfile')

        data = load_dict_from_yaml(osp.join(ex_dir, 'testfile.yaml'))
        assert isinstance(data, dict)

    finally:
        # Delete the created folder recursively; ignore_errors=True keeps the clean-up itself from raising
        shutil.rmtree(osp.join(TEMP_DIR, 'testenv'), ignore_errors=True)
コード例 #3
0
from pyrado.utils.argparser import get_argparser
from pyrado.utils.input_output import print_cbt
from pyrado.utils.data_types import RenderMode

if __name__ == '__main__':
    # Read the command line options
    args = get_argparser().parse_args()

    # Ask which experiment to use, and refuse anything that is not an existing directory
    ex_dir = ask_for_experiment()
    if not osp.isdir(ex_dir):
        raise pyrado.PathErr(given=ex_dir)

    # Restore the object pickled as env_sim.pkl and the stored hyper-parameters
    env_sim = joblib.load(osp.join(ex_dir, 'env_sim.pkl'))
    hparam = load_dict_from_yaml(osp.join(ex_dir, 'hyperparams.yaml'))

    # A time step size given on the command line replaces the stored one
    if args.dt is not None:
        env_sim.dt = args.dt

    # Collect the policy and candidate files; files prefixed with 'init_' only count if load_all is set.
    # NOTE(review): both lists are rebuilt for every directory visited by os.walk, so after the loop they
    # describe only the directory that was visited last - confirm that this is intended.
    for _, _, file_names in os.walk(ex_dir):
        skip_init = not args.load_all
        found_policies = []
        found_cands = []
        for fname in file_names:
            if skip_init and fname.startswith('init_'):
                continue
            if fname.endswith('_policy.pt'):
                found_policies.append(fname)
            if fname.endswith('_candidate.pt'):
                found_cands.append(fname)

    # Policies and candidates are expected to come in equal numbers
    if len(found_policies) != len(found_cands):
        raise pyrado.ValueErr(msg='Found a different number of initial policies than candidates!')
コード例 #4
0
ファイル: continue.py プロジェクト: fdamken/SimuRLacra
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Continue a training run in the same folder
"""
import os.path as osp

from pyrado.algorithms.base import Algorithm
from pyrado.logger.experiment import ask_for_experiment, load_dict_from_yaml
from pyrado.utils.argparser import get_argparser

if __name__ == "__main__":
    # Read the command line options
    args = get_argparser().parse_args()

    # An explicitly given directory wins, otherwise ask the user to select an experiment
    if args.dir is None:
        ex_dir = ask_for_experiment(hparam_list=args.show_hparams)
    else:
        ex_dir = args.dir

    # The stored hyper-parameters are only consulted for the random seed
    hparams_path = osp.join(ex_dir, "hyperparams.yaml")
    hparams = load_dict_from_yaml(hparams_path)

    # Restore the complete algorithm from its snapshot
    algo = Algorithm.load_snapshot(ex_dir)

    # Resume training, seeding with the stored seed (None if there is no such entry)
    seed = hparams.get("seed", None)
    algo.train(seed=seed)
コード例 #5
0
def load_experiment(
        ex_dir: str,
        args: Any = None) -> (Union[SimEnv, EnvWrapper], Policy, dict):
    """
    Load the (training) environment, the policy, and algorithm-specific extras of an experiment.

    The algorithm snapshot stored in the experiment's directory is loaded first, and its type decides which
    entities are loaded. The hyper-parameters yaml-file is read on a best-effort basis, i.e. only a warning is
    printed if it could not be loaded.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser, pass `None` to fall back to the values from the default argparser
    :return: environment (stays `None` for algorithms without one, e.g. `TSPred`), policy, and a dict with optional
             output (e.g. `hparams`, `vfcn`, `qfcn_target`, `ddp_policy`, `particle<idx>`, or `dataset`)
    :raises pyrado.TypeErr: if the loaded algorithm is not supported, if a randomizer has an unexpected type, or if
                            one of the returned objects has an unexpected type
    :raises NotImplementedError: if the algorithm is `ValueBased` but neither `DQL` nor `SAC`
    """
    env, policy, extra = None, None, dict()

    if args is None:
        # Fall back to default arguments. By passing [], we ignore the command line arguments
        args = get_argparser().parse_args([])

    # Hyper-parameters (optional, a missing file is only reported)
    hparams_file_name = 'hyperparams.yaml'
    try:
        extra['hparams'] = load_dict_from_yaml(osp.join(ex_dir, hparams_file_name))
    except (pyrado.PathErr, FileNotFoundError, KeyError):
        print_cbt(
            f'Did not find {hparams_file_name} in {ex_dir} or could not crawl the loaded hyper-parameters.',
            'y',
            bright=True)

    # Algorithm specific. Meta-algorithms are checked before the algorithm families of their subroutines.
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name='algo')
    if isinstance(algo, BayRn):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            # Use the last row of the stored candidates to adapt the randomizer
            last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine.critic.vfcn,
                                        f'{args.vfcn_name}', 'pt', ex_dir,
                                        None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}",
                      'g')

    elif isinstance(algo, SPOTA):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            if not isinstance(env.randomizer, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env.randomizer,
                                     expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'env.pkl')} and filled it with 100 random instances.",
                'g')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.subroutine_cand.policy,
                             f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine_cand, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine_cand.critic.vfcn,
                                        f'{args.vfcn_name}', 'pt', ex_dir,
                                        None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}",
                      'g')

    elif isinstance(algo, SimOpt):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            # Use the last row of the stored candidates to adapt the randomizer
            last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.subroutine_policy.policy,
                             f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (domain parameter distribution policy)
        extra['ddp_policy'] = pyrado.load(algo.subroutine_distr.policy,
                                          'ddp_policy', 'pt', ex_dir, None)

    elif isinstance(algo, (EPOpt, UDR)):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        if hasattr(env, 'randomizer'):
            if not isinstance(env.randomizer, DomainRandWrapperLive):
                raise pyrado.TypeErr(given=env.randomizer,
                                     expected_type=DomainRandWrapperLive)
            # Report the file that was actually loaded above, i.e. env_sim.pkl
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'env_sim.pkl')} with DomainRandWrapperLive randomizer.",
                'g')
        else:
            print_cbt('Loaded environment has no randomizer.', 'y')
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine.critic.vfcn,
                                        f'{args.vfcn_name}', 'pt', ex_dir,
                                        None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}",
                      'g')

    elif isinstance(algo, ActorCritic):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        extra['vfcn'] = pyrado.load(algo.critic.vfcn, f'{args.vfcn_name}',
                                    'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}", 'g')

    elif isinstance(algo, ParameterExploring):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')

    elif isinstance(algo, ValueBased):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Target value functions
        if isinstance(algo, DQL):
            extra['qfcn_target'] = pyrado.load(algo.qfcn_targ, 'qfcn_target',
                                               'pt', ex_dir, None)
            print_cbt(f"Loaded {osp.join(ex_dir, 'qfcn_target.pt')}", 'g')
        elif isinstance(algo, SAC):
            extra['qfcn_target1'] = pyrado.load(algo.qfcn_targ_1,
                                                'qfcn_target1', 'pt', ex_dir,
                                                None)
            extra['qfcn_target2'] = pyrado.load(algo.qfcn_targ_2,
                                                'qfcn_target2', 'pt', ex_dir,
                                                None)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'qfcn_target1.pt')} and {osp.join(ex_dir, 'qfcn_target2.pt')}",
                'g')
        else:
            raise NotImplementedError

    elif isinstance(algo, SVPG):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (particles), note that the dict key has no underscore while the file name has one
        for idx, particle in enumerate(algo.particles):
            extra[f'particle{idx}'] = pyrado.load(particle,
                                                  f'particle_{idx}', 'pt',
                                                  ex_dir, None)

    elif isinstance(algo, TSPred):
        # Dataset
        extra['dataset'] = to.load(osp.join(ex_dir, 'dataset.pt'))
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir,
                             None)

    else:
        raise pyrado.TypeErr(
            msg=
            'No matching algorithm name found during loading the experiment!')

    # Check if the return types are correct. They can be None, too.
    if env is not None and not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if policy is not None and not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if extra is not None and not isinstance(extra, dict):
        raise pyrado.TypeErr(given=extra, expected_type=dict)

    return env, policy, extra
コード例 #6
0
from pyrado.algorithms.base import Algorithm
from pyrado.logger.experiment import ask_for_experiment, load_dict_from_yaml
from pyrado.utils.argparser import get_argparser
from pyrado.utils.data_types import update_matching_keys_recursively


if __name__ == "__main__":
    # Read the command line options
    args = get_argparser().parse_args()

    # An explicitly given directory wins, otherwise ask the user to select an experiment
    if args.dir is None:
        ex_dir = ask_for_experiment(hparam_list=args.show_hparams)
    else:
        ex_dir = args.dir

    # Look for the hyper-parameter file ("hparam*.yaml") and the settings file ("settings.yaml").
    # If several hyper-parameter files exist, the one listed last by os.listdir is kept.
    hparam_args = None
    setting_args = None
    for file_name in os.listdir(ex_dir):
        is_hparam_file = file_name.startswith("hparam") and file_name.endswith(".yaml")
        if is_hparam_file:
            hparam_args = load_dict_from_yaml(osp.join(ex_dir, file_name))
        elif file_name == "settings.yaml":
            setting_args = load_dict_from_yaml(osp.join(ex_dir, file_name))

    # Update matching keys of the two loaded dicts
    # NOTE(review): setting_args and / or hparam_args are still None if the respective file was not found,
    # which at the latest makes the seed lookup below fail - confirm that both files always exist.
    update_matching_keys_recursively(setting_args, hparam_args)

    # Restore the complete algorithm from its snapshot
    algo = Algorithm.load_snapshot(ex_dir)

    # Resume training, seeding with the stored seed (None if there is no such entry)
    seed = setting_args.get("seed", None)
    algo.train(seed=seed)
コード例 #7
0
def load_experiment(
        ex_dir: str,
        args: Any = None) -> ([SimEnv, EnvWrapper], Policy, Optional[dict]):
    """
    Load the (training) environment and the policy.
    This helper function first tries to read the hyper-parameters yaml-file in the experiment's directory to infer
    which entities should be loaded. If the file was not found, no known algorithm name was found in it, or one of
    the expected files is missing, we fall back to some heuristic on the file names and hope for the best.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser, only the `iter` attribute is read (`-1` selects the final
                 policy). If `args` is `None` or has no `iter` attribute, `-1` is used.
    :return: environment, policy, and optional output (e.g. valuefcn)
    :raises pyrado.TypeErr: if the loaded environment or policy has an unexpected type
    :raises FileNotFoundError: if the fall back heuristic does not find the files it tries last
    """
    hparams_file_name = 'hyperparams.yaml'
    env, policy, kwout = None, None, dict()

    # Resolve the requested iteration once. Using getattr keeps the advertised default args=None usable.
    iter_idx = getattr(args, 'iter', -1)

    try:
        hparams = load_dict_from_yaml(osp.join(ex_dir, hparams_file_name))
        kwout['hparams'] = hparams

        # Check which algorithm has been used for training, i.e. what can be loaded, by crawling the hyper-parameters
        # First check meta algorithms so they don't get masked by their subroutines
        algo_name = hparams.get('algo_name', '')
        subroutine_name = hparams.get('subroutine_name', '')

        if SPOTA.name in algo_name:
            # Environment
            env = joblib.load(osp.join(ex_dir, 'init_env.pkl'))
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'init_env.pkl')} and filled it with 100 random instances.",
                'g')
            # Policy
            if iter_idx == -1:
                policy = to.load(osp.join(ex_dir, 'final_policy_cand.pt'))
                print_cbt(f"Loaded {osp.join(ex_dir, 'final_policy_cand.pt')}",
                          'g')
            else:
                policy = to.load(
                    osp.join(ex_dir, f'iter_{iter_idx}_policy_cand.pt'))
                print_cbt(
                    f"Loaded {osp.join(ex_dir, f'iter_{iter_idx}_policy_cand.pt')}",
                    'g')
            # Value function (optional)
            if any(a.name in subroutine_name for a in [PPO, PPO2, A2C]):
                try:
                    kwout['valuefcn'] = to.load(
                        osp.join(ex_dir, 'final_valuefcn.pt'))
                    print_cbt(
                        f"Loaded {osp.join(ex_dir, 'final_valuefcn.pt')}", 'g')
                except FileNotFoundError:
                    kwout['valuefcn'] = to.load(osp.join(
                        ex_dir, 'valuefcn.pt'))
                    print_cbt(f"Loaded {osp.join(ex_dir, 'valuefcn.pt')}", 'g')

        elif BayRn.name in algo_name:
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env_sim.pkl'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
            if hasattr(env, 'randomizer'):
                # Use the last row of the stored candidates to adapt the randomizer
                last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
                env.adapt_randomizer(last_cand.numpy())
                print_cbt(f'Loaded the domain randomizer\n{env.randomizer}',
                          'w')
            # Policy
            if iter_idx == -1:
                policy = to.load(osp.join(ex_dir, 'policy.pt'))
                print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
            else:
                policy = to.load(osp.join(ex_dir, f'iter_{iter_idx}.pt'))
                print_cbt(f"Loaded {osp.join(ex_dir, f'iter_{iter_idx}.pt')}",
                          'g')
            # Value function (optional)
            if any(a.name in subroutine_name for a in [PPO, PPO2, A2C]):
                try:
                    kwout['valuefcn'] = to.load(
                        osp.join(ex_dir, 'final_valuefcn.pt'))
                    print_cbt(
                        f"Loaded {osp.join(ex_dir, 'final_valuefcn.pt')}", 'g')
                except FileNotFoundError:
                    kwout['valuefcn'] = to.load(osp.join(
                        ex_dir, 'valuefcn.pt'))
                    print_cbt(f"Loaded {osp.join(ex_dir, 'valuefcn.pt')}", 'g')

        elif EPOpt.name in algo_name:
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')

        elif any(a.name in algo_name for a in [PPO, PPO2, A2C]):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
            # Value function
            kwout['valuefcn'] = to.load(osp.join(ex_dir, 'valuefcn.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'valuefcn.pt')}", 'g')

        elif SAC.name in algo_name:
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
            # Target value functions
            kwout['target1'] = to.load(osp.join(ex_dir, 'target1.pt'))
            kwout['target2'] = to.load(osp.join(ex_dir, 'target2.pt'))
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'target1.pt')} and {osp.join(ex_dir, 'target2.pt')}",
                'g')

        elif any(a.name in algo_name
                 for a in [HC, PEPG, NES, REPS, PoWER, CEM]):
            # Environment
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            # Policy
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')

        else:
            # This error is caught right below and triggers the fall back heuristic
            raise KeyError(
                'No matching algorithm name found during loading the experiment. '
                'Check for the algo_name field in the yaml-file.')

    except (FileNotFoundError, KeyError):
        print_cbt(
            f'Did not find {hparams_file_name} in {ex_dir} or could not crawl the loaded hyper-parameters.',
            'y',
            bright=True)

        try:
            # Results of a standard algorithm
            env = joblib.load(osp.join(ex_dir, 'env.pkl'))
            policy = to.load(osp.join(ex_dir, 'policy.pt'))
            print_cbt(f"Loaded {osp.join(ex_dir, 'policy.pt')}", 'g')
        except FileNotFoundError:
            try:
                # Results of SPOTA
                env = joblib.load(osp.join(ex_dir, 'init_env.pkl'))
                typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
                print_cbt(
                    f"Loaded {osp.join(ex_dir, 'init_env.pkl')} and filled it with 100 random instances.",
                    'g')
            except FileNotFoundError:
                # Results of BayRn
                env = joblib.load(osp.join(ex_dir, 'env_sim.pkl'))

            try:
                # Results of SPOTA
                if iter_idx == -1:
                    policy = to.load(osp.join(ex_dir, 'final_policy_cand.pt'))
                    print_cbt('Loaded final_policy_cand.pt', 'g')
                else:
                    policy = to.load(
                        osp.join(ex_dir, f'iter_{iter_idx}_policy_cand.pt'))
                    print_cbt(f'Loaded iter_{iter_idx}_policy_cand.pt', 'g')
            except FileNotFoundError:
                # Results of BayRn
                if iter_idx == -1:
                    policy = to.load(osp.join(ex_dir, 'final_policy.pt'))
                    print_cbt('Loaded final_policy.pt', 'g')
                else:
                    policy = to.load(
                        osp.join(ex_dir, f'iter_{iter_idx}_policy.pt'))
                    print_cbt(f'Loaded iter_{iter_idx}_policy.pt', 'g')

    # Check if the return types are correct
    if not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)

    return env, policy, kwout