Beispiel #1
0
    def train_argmax_policy(load_dir: str,
                            env_sim: MetaDomainRandWrapper,
                            subroutine: Algorithm,
                            num_restarts: int,
                            num_samples: int,
                            policy_param_init: to.Tensor = None,
                            valuefcn_param_init: to.Tensor = None) -> Policy:
        """
        Train a policy on the domain distribution that maximizes the posterior mean.

        The candidates, their values, and the bounds are read from `load_dir`. The maximizer found from them is used
        to adapt the randomizer of `env_sim`, and finally the subroutine is reset, re-initialized, and trained.

        :param load_dir: directory to load `candidates.pt`, `candidates_values.pt`, and `bounds.pt` from
        :param env_sim: simulation environment whose domain randomizer is adapted
        :param subroutine: algorithm which performs the policy / value-function optimization
        :param num_restarts: number of restarts for the optimization of the acquisition function
        :param num_samples: number of samples for the optimization of the acquisition function
        :param policy_param_init: initial policy parameter values for the subroutine, set `None` to be random
        :param valuefcn_param_init: initial value function parameter values for the subroutine, set `None` to be random
        :return: the policy of the subroutine after training
        """
        # Fetch the stored optimization data
        candidates = to.load(osp.join(load_dir, 'candidates.pt'))
        candidate_values = to.load(osp.join(load_dir, 'candidates_values.pt'))
        candidate_values = candidate_values.unsqueeze(1)
        bounds = to.load(osp.join(load_dir, 'bounds.pt'))

        # The first row of the bounds is passed as the first, the second row as the second projector argument
        projector = UnitCubeProjector(bounds[0, :], bounds[1, :])

        # Determine the candidate which maximizes the posterior mean
        best_cand = BayRn.argmax_posterior_mean(
            candidates, candidate_values, projector, num_restarts, num_samples
        )

        # Configure the domain randomizer with the found hyper-parameters
        env_sim.adapt_randomizer(best_cand.numpy())

        # Bring the subroutine back to its initial state
        subroutine.reset()

        # Re-initialize the policy, and the value function if the subroutine has a critic
        subroutine.policy.init_param(policy_param_init)
        if isinstance(subroutine, ActorCritic):
            subroutine.critic.value_fcn.init_param(valuefcn_param_init)

        # Tell the user whether we start from scratch or from a given initialization
        if policy_param_init is not None:
            start_msg = 'Learning the argmax solution given an initialization'
        else:
            start_msg = 'Learning the argmax solution from scratch'
        print_cbt(start_msg, 'y')

        subroutine.train(snapshot_mode='best')
        return subroutine.policy
Beispiel #2
0
def _load_experiment(ex_dir: pyrado.PathLike):
    """
    Load the algorithm, the prior, the real-world data, and all available posteriors of an experiment.

    :param ex_dir: experiment directory to load from
    :return: the algorithm, the prior, the (repeated) real-world data, and the list of posteriors
    :raises pyrado.TypeErr: if the loaded algorithm is neither `NPDR` nor `BayesSim`
    """
    # The algorithm snapshot tells us how many rounds to expect
    algo = Algorithm.load_snapshot(ex_dir)
    if not isinstance(algo, (NPDR, BayesSim)):
        raise pyrado.TypeErr(given=algo, expected_type=(NPDR, BayesSim))

    # Prior and recorded data
    prior = pyrado.load("prior.pt", ex_dir)
    data_real = pyrado.load("data_real.pt", ex_dir)

    # One posterior per round
    posteriors = []
    for idx_round in range(algo.num_sbi_rounds):
        posteriors.append(SBIBase.load_posterior(ex_dir, idx_round=idx_round, verbose=True))
    # Drop the missing entries, in case the algorithm terminated early
    posteriors = remove_none_from_list(posteriors)
    num_posteriors = len(posteriors)

    num_data_sets = data_real.shape[0]
    if num_data_sets > num_posteriors:
        print_cbt(
            f"Found {num_data_sets} data sets but {num_posteriors} posteriors. Truncated the superfluous data.",
            "y",
        )
        data_real = data_real[:num_posteriors, :]

    # Artificially repeat the data (which was the same for every round) to later be able to use the same code
    # NOTE(review): assuming data_real is a torch tensor, the repetition multiplies the number of rows, thus the
    # check below can only pass if data_real has exactly one row at this point (or if there are no posteriors).
    # Confirm that this is the intended contract.
    data_real = data_real.repeat(num_posteriors, 1)
    assert data_real.shape[0] == num_posteriors

    return algo, prior, data_real, posteriors
Beispiel #3
0
    def _adapt_batch_size(self, subroutine: Algorithm, n: int):
        """
        Scale the amount of data (steps or rollouts) a subroutine collects per iteration with the number of domains
        used in the current iteration of SPOTA.

        :param subroutine: algorithm whose sampler settings are adapted
        :param n: number of domains used in the current iteration
        :raises NotImplementedError: if the subroutine is neither a `ParameterExploring` nor an `ActorCritic` algorithm
        """
        if isinstance(subroutine, ParameterExploring):
            # These algorithms sample num_rollouts_per_param complete rollouts per iteration
            subroutine.sampler.num_rollouts_per_param = self.ntau * n
            return

        if not isinstance(subroutine, ActorCritic):
            raise NotImplementedError(
                f'No _adapt_batch_size method found for class {type(subroutine)}!'
            )

        # The sampler can be configured by a minimum number of steps and / or a minimum number of rollouts
        if subroutine.sampler.min_steps is not None:
            num_steps = self.ntau * n * self._env_dr.max_steps
            subroutine.min_steps = num_steps
            subroutine.sampler.set_min_count(min_steps=num_steps)

        if subroutine.sampler.min_rollouts is not None:
            num_rollouts = self.ntau * n
            subroutine.min_rollouts = num_rollouts
            subroutine.sampler.set_min_count(min_rollouts=num_rollouts)
Beispiel #4
0
def test_pddr(ex_dir, env: SimEnv, policy, algo_hparam):
    """
    Train PDDR for a few iterations, then check that the snapshot and the experiment can be loaded again.

    .. note::
        The passed `algo_hparam` is not used, it is replaced by the hyper-parameters defined in this test.
    """
    pyrado.set_seed(0)

    # Teacher setup: a copy of the policy, and a critic with a small feed-forward value function
    teacher_policy = deepcopy(policy)
    vfcn = FNNPolicy(
        spec=EnvSpec(env.obs_space, ValueFunctionSpace),
        hidden_sizes=[16, 16],
        hidden_nonlin=to.tanh,
    )
    critic = GAE(vfcn=vfcn)
    teacher_algo_hparam = {"critic": critic, "min_steps": 1500, "max_iter": 2}

    # Wrap the environment with a domain randomizer
    env = DomainRandWrapperLive(env, create_default_randomizer(env))

    # Hyper-parameters of the algorithm under test
    algo_hparam = {
        "max_iter": 2,
        "min_steps": env.max_steps,
        "std_init": 0.15,
        "num_epochs": 10,
        "num_teachers": 2,
        "teacher_policy": teacher_policy,
        "teacher_algo": PPO,
        "teacher_algo_hparam": teacher_algo_hparam,
        "num_workers": 1,
    }

    # Create and run the algorithm
    algo = PDDR(ex_dir, env, policy, **algo_hparam)
    algo.train()

    # The training must have run until the last iteration
    assert algo.curr_iter == algo.max_iter

    # Store a snapshot and read it back
    algo.save_snapshot(meta_info=None)
    algo_loaded = Algorithm.load_snapshot(load_dir=ex_dir)
    assert isinstance(algo_loaded, Algorithm)
    policy_loaded = algo_loaded.policy

    # The loaded policy must have the same parameters as the trained one
    assert all(algo.policy.param_values == policy_loaded.param_values)

    # Load the complete experiment (no hyper-parameters were saved in this test)
    loaded_env, loaded_policy, loaded_extra = load_experiment(ex_dir)
    assert isinstance(loaded_env, Env)
    assert isinstance(loaded_policy, Policy)
    assert isinstance(loaded_extra, dict)
Beispiel #5
0
    def train_argmax_policy(
        load_dir: pyrado.PathLike,
        env_sim: MetaDomainRandWrapper,
        subrtn: Algorithm,
        num_restarts: int,
        num_samples: int,
        policy_param_init: to.Tensor = None,
        valuefcn_param_init: to.Tensor = None,
        subrtn_snapshot_mode: str = "best",
    ) -> Policy:
        """
        Train a policy based on the maximizer of the posterior mean.

        The candidates, their values, and the domain distribution parameter space are read from `load_dir`. If there
        are more candidates than evaluations, the surplus candidates are ignored. The maximizer is then used to adapt
        the randomizer of `env_sim`, and the subroutine is reset, warm-started, and trained.

        :param load_dir: directory to load from
        :param env_sim: simulation environment
        :param subrtn: algorithm which performs the policy / value-function optimization
        :param num_restarts: number of restarts for the optimization of the acquisition function
        :param num_samples: number of samples for the optimization of the acquisition function
        :param policy_param_init: initial policy parameter values for the subroutine, set `None` to be random
        :param valuefcn_param_init: initial value function parameter values for the subroutine, set `None` to be random
        :param subrtn_snapshot_mode: snapshot mode for saving during training of the subroutine
        :return: the final BayRn policy
        """
        # Load the required data
        cands = pyrado.load("candidates.pt", load_dir)
        cands_values = pyrado.load("candidates_values.pt", load_dir).unsqueeze(1)
        ddp_space = pyrado.load("ddp_space.pkl", load_dir)

        num_cands, num_values = cands.shape[0], cands_values.shape[0]
        if num_cands > num_values:
            # Adjacent string literals are concatenated verbatim, hence the explicit space at the end of the first one
            print_cbt(
                f"There are {num_cands} candidates but only {num_values} evaluations. Ignoring the "
                f"candidates without evaluation for computing the argmax.",
                "y",
            )
            cands = cands[:num_values, :]

        # Find the maximizer
        argmax_cand = BayRn.argmax_posterior_mean(cands, cands_values, ddp_space, num_restarts, num_samples)

        # Set the domain randomizer
        env_sim.adapt_randomizer(argmax_cand.numpy())

        # Reset the subroutine algorithm which includes resetting the exploration
        subrtn.reset()

        # Do a warm start
        subrtn.init_modules(
            warmstart=True,
            policy_param_init=policy_param_init,
            valuefcn_param_init=valuefcn_param_init,
        )

        # The suffix is passed on as meta information to the snapshots of this final training run
        subrtn.train(snapshot_mode=subrtn_snapshot_mode, meta_info=dict(suffix="argmax"))
        return subrtn.policy
Beispiel #6
0
from pyrado.algorithms.meta.npdr import NPDR
from pyrado.algorithms.meta.sbi_base import SBIBase
from pyrado.environment_wrappers.utils import inner_env
from pyrado.plotting.distribution import draw_posterior_pairwise_scatter
from pyrado.utils.argparser import get_argparser

if __name__ == "__main__":
    # Parse command line arguments
    args = get_argparser().parse_args()
    plt.rc("text", usetex=args.use_tex)
    if not isinstance(args.num_samples, int) or args.num_samples < 1:
        raise pyrado.ValueErr(given=args.num_samples, ge_constraint="1")

    # NPDR
    ex_dir_npdr = os.path.join(pyrado.TEMP_DIR, "mg-ik", "npdr_time", "")
    algo = Algorithm.load_snapshot(ex_dir_npdr)
    if not isinstance(algo, NPDR):
        raise pyrado.TypeErr(given=algo, expected_type=NPDR)
    env_sim = inner_env(pyrado.load("env_sim.pkl", ex_dir_npdr))
    prior_npdr = pyrado.load("prior.pt", ex_dir_npdr)
    posterior_npdr = algo.load_posterior(ex_dir_npdr,
                                         idx_iter=0,
                                         idx_round=6,
                                         obj=None,
                                         verbose=True)  # CHOICE
    data_real_npdr = pyrado.load(f"data_real.pt",
                                 ex_dir_npdr,
                                 prefix="iter_0",
                                 verbose=True)  # CHOICE
    domain_params_npdr, log_probs = SBIBase.eval_posterior(
        posterior_npdr,
Beispiel #7
0
    def update(self, rollouts: Sequence[StepSequence], use_empirical_returns: bool = False):
        """
        Fit the value function to the given samples by minimizing the loss, then estimate the advantages with the
        updated value function.

        :param rollouts: batch of rollouts
        :param use_empirical_returns: use the return from the rollout (True) or the ones from the V-fcn (False)
        :return: tensor of advantages after V-function updates
        """
        # Merge all rollouts into one sequence of steps, and convert the data to tensors
        concat_ros = StepSequence.concat(rollouts)
        concat_ros.torch(data_type=to.get_default_dtype())

        # Determine the regression targets for the value function
        if not use_empirical_returns:
            # Bootstrapping, i.e. the targets come from the value function itself
            v_targ = self.tdlamda_returns(concat_ros=concat_ros)
        else:
            # Empirical discounted returns for all samples
            v_targ = discounted_values(rollouts, self.gamma).view(-1, 1)
        concat_ros.add_data('v_targ', v_targ)

        # Remember the loss before the update for logging
        with to.no_grad():
            loss_old = self.loss_fcn(self.values(concat_ros), v_targ)
        grad_norms = []

        # Go through all gathered samples num_epoch times
        for epoch in range(self.num_epoch):
            batches = concat_ros.split_shuffled_batches(
                self.batch_size, complete_rollouts=isinstance(self.vfcn, RecurrentPolicy)
            )
            num_batches = num_iter_from_rollouts(None, concat_ros, self.batch_size)
            progress = tqdm(
                batches, total=num_batches, desc=f'Epoch {epoch}', unit='batches', file=sys.stdout, leave=False
            )

            for batch in progress:
                # Start every mini-batch with fresh gradients
                self.optim.zero_grad()

                # Loss between the predictions and the targets of this mini-batch
                batch_loss = self.loss_fcn(self.values(batch), batch.v_targ)
                batch_loss.backward()

                # Clip the gradients if desired, and store the returned norm for logging
                grad_norms.append(Algorithm.clip_grad(self.vfcn, self.max_grad_norm))

                # Apply the parameter update
                self.optim.step()

            # Step the learning rate scheduler once per epoch if one has been specified
            if self._lr_scheduler is not None:
                self._lr_scheduler.step()

        # Estimate the advantage after fitting the parameters of the V-fcn
        adv = self.gae(concat_ros)

        # Evaluate the updated value function on all samples
        with to.no_grad():
            v_pred_new = self.values(concat_ros)
            loss_new = self.loss_fcn(v_pred_new, v_targ)
            vfcn_loss_impr = loss_old - loss_new  # positive values are desired
            explvar = explained_var(v_pred_new, v_targ)  # values close to 1 are desired

        # Log the metrics of this update (explained variance and loss are based on the updated value function)
        self.logger.add_value('explained var critic', explvar, 4)
        self.logger.add_value('loss improv critic', vfcn_loss_impr, 4)
        self.logger.add_value('avg grad norm critic', np.mean(grad_norms), 4)
        if self._lr_scheduler is not None:
            self.logger.add_value('lr critic', self._lr_scheduler.get_last_lr(), 6)

        return adv
from pyrado.logger.experiment import ask_for_experiment
from pyrado.utils.argparser import get_argparser
from pyrado.utils.experiments import load_experiment

if __name__ == "__main__":
    # Parse command line arguments
    args = get_argparser().parse_args()

    # Get the experiment's directory to load from
    ex_dir = ask_for_experiment(
        hparam_list=args.show_hparams) if args.dir is None else args.dir

    # Load the environment and the policy
    env_sim, policy, kwout = load_experiment(ex_dir, args)

    subrtn = Algorithm.load_snapshot(load_dir=ex_dir, load_name="subrtn")

    # Start from previous results policy if desired
    ppi = policy.param_values.data if args.warmstart is not None else None
    if isinstance(subrtn, ActorCritic):
        vpi = kwout[
            "vfcn"].param_values.data if args.warmstart is not None else None
    else:
        vpi = None

    # Set seed if desired
    pyrado.set_seed(args.seed, verbose=True)

    # Train the policy on the most lucrative domain
    BayRn.train_argmax_policy(ex_dir,
                              env_sim,
    env, policy, kwout = load_experiment(ex_dir, args)
    env_real = pyrado.load("env_real.pkl", ex_dir)
    data_real = kwout["data_real"]
    if args.iter == -1:
        # This script is not made to evaluate multiple iterations at once, thus we always select the data one iteration
        data_real = to.atleast_2d(data_real[args.iter])

    # Override the time step size if specified
    if args.dt is not None:
        env.dt = args.dt

    # Use the environments number of steps in case of the default argument (inf)
    max_steps = env.max_steps if args.max_steps == pyrado.inf else args.max_steps

    # Check which algorithm was used in the experiment
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name="algo")
    if not isinstance(algo, (NPDR, BayesSim)):
        raise pyrado.TypeErr(given=algo, expected_type=(NPDR, BayesSim))

    # Sample domain parameters from the posterior. Use all samples, by hijacking the get_ml_posterior_samples to obtain
    # them sorted.
    domain_params, log_probs = SBIBase.get_ml_posterior_samples(
        dp_mapping=algo.dp_mapping,
        posterior=kwout["posterior"],
        data_real=data_real,
        num_eval_samples=args.num_samples,
        num_ml_samples=args.num_samples,
        calculate_log_probs=True,
        normalize_posterior=args.normalize,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
        return_as_tensor=False,
Beispiel #10
0
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
# IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""
Continue a training run in the same folder
"""
import os.path as osp

from pyrado.algorithms.base import Algorithm
from pyrado.logger.experiment import ask_for_experiment, load_dict_from_yaml
from pyrado.utils.argparser import get_argparser

if __name__ == "__main__":
    # Read the command line arguments
    args = get_argparser().parse_args()

    # Use the given directory, or ask for one if none was given
    if args.dir is None:
        ex_dir = ask_for_experiment(hparam_list=args.show_hparams)
    else:
        ex_dir = args.dir

    # The hyper-parameters are only needed to recover the seed (if one was stored)
    hparams = load_dict_from_yaml(osp.join(ex_dir, "hyperparams.yaml"))
    seed = hparams.get("seed", None)

    # Restore the complete algorithm from its snapshot
    algo = Algorithm.load_snapshot(ex_dir)

    # Continue the training
    algo.train(seed=seed)
Beispiel #11
0
def test_snapshots_notmeta(ex_dir, env: SimEnv, policy, algo_class,
                           algo_hparam):
    """
    Check that a snapshot of a non-meta algorithm can be saved and loaded without changing the parameters of the
    policy (and of the value function for actor-critic algorithms), and that the experiment can be loaded afterwards.
    """
    # Start from the settings shared by all algorithms, then add the given ones
    common_hparam = {"max_iter": 1, "num_workers": 1}
    common_hparam.update(algo_hparam)

    # Add the settings which depend on the family of the algorithm
    if issubclass(algo_class, ActorCritic):
        vfcn = FNNPolicy(
            spec=EnvSpec(env.obs_space, ValueFunctionSpace),
            hidden_sizes=[16, 16],
            hidden_nonlin=to.tanh,
        )
        critic = GAE(vfcn=vfcn)
        common_hparam["min_rollouts"] = 3
        common_hparam["critic"] = critic

    elif issubclass(algo_class, ParameterExploring):
        common_hparam["num_init_states_per_domain"] = 1

    elif issubclass(algo_class, (DQL, SAC)):
        common_hparam.update(
            {"memory_size": 1000, "num_updates_per_step": 2, "gamma": 0.99, "min_rollouts": 1}
        )
        fnn_hparam = {"hidden_sizes": [8, 8], "hidden_nonlin": to.tanh}

        if not issubclass(algo_class, DQL):
            # Replace the given environment and policy, and add the two Q-functions
            env = ActNormWrapper(env)
            policy = TwoHeadedGRUPolicy(env.spec,
                                        shared_hidden_size=8,
                                        shared_num_recurrent_layers=1)
            obsact_space = BoxSpace.cat([env.obs_space, env.act_space])
            for qfcn_key in ("qfcn_1", "qfcn_2"):
                common_hparam[qfcn_key] = FNNPolicy(
                    spec=EnvSpec(obsact_space, ValueFunctionSpace), **fnn_hparam)
        else:
            # Replace the given environment and policy by discrete-action counterparts
            env = BallOnBeamDiscSim(env.dt, env.max_steps)
            net = FNN(
                input_size=DiscreteActQValPolicy.get_qfcn_input_size(env.spec),
                output_size=DiscreteActQValPolicy.get_qfcn_output_size(),
                **fnn_hparam,
            )
            policy = DiscreteActQValPolicy(spec=env.spec, net=net)

    else:
        raise NotImplementedError

    # Instead of training, shift the parameters such that they differ from the initial ones
    algo = algo_class(ex_dir, env, policy, **common_hparam)
    has_critic = isinstance(algo, ActorCritic)
    algo.policy.param_values += to.tensor([42.0])
    if has_critic:
        algo.critic.vfcn.param_values += to.tensor([42.0])

    # Store a snapshot and read it back
    algo.save_snapshot(meta_info=None)
    algo_loaded = Algorithm.load_snapshot(load_dir=ex_dir)
    assert isinstance(algo_loaded, Algorithm)
    policy_loaded = algo_loaded.policy
    if has_critic:
        critic_loaded = algo_loaded.critic

    # The loaded parameters must be equal to the stored ones
    assert all(algo.policy.param_values == policy_loaded.param_values)
    if has_critic:
        assert all(
            algo.critic.vfcn.param_values == critic_loaded.vfcn.param_values)

    # Load the complete experiment (no hyper-parameters were saved in this test)
    loaded_env, loaded_policy, loaded_extra = load_experiment(ex_dir)
    assert isinstance(loaded_env, Env)
    assert isinstance(loaded_policy, Policy)
    assert isinstance(loaded_extra, dict)
Beispiel #12
0
def load_experiment(
    ex_dir: str,
    args: Any = None
) -> Tuple[Optional[Union[SimEnv, EnvWrapper]], Optional[Policy],
           Optional[dict]]:
    """
    Load the (training) environment and the policy.
    This helper function loads the hyper-parameters and the algorithm snapshot from the experiment's directory, and
    then decides based on the algorithm's name which entities should be loaded.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser, pass `None` to fall back to the values from the default argparser
    :return: environment (can be `None`), policy, and optional output (e.g. the value function) as a dict
    :raises pyrado.TypeErr: if the algorithm's name is unknown, or if a loaded entity has an unexpected type
    """
    env, policy, extra = None, None, dict()

    if args is None:
        # Fall back to default arguments. By passing [], we ignore the command line arguments
        args = get_argparser().parse_args([])

    # Hyper-parameters
    extra["hparams"] = load_hyperparameters(ex_dir)

    # Algorithm specific
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name="algo")

    if algo.name == "spota":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            if not isinstance(env, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env,
                                     expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(10)
            print_cbt(
                f"Loaded the domain randomizer\n{env.randomizer}\nand filled it with 10 random instances.",
                "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "r")
        # Policy (the file name comes first, the object to load into is passed via obj, like in all other branches)
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.subroutine_cand.policy,
                             verbose=True)
        # Extra (value function)
        if isinstance(algo.subroutine_cand, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo.subroutine_cand.critic.vfcn,
                                        verbose=True)

    elif algo.name == "bayrn":
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        # Treat a randomizer which is None like a missing one, as stated in the message below
        if getattr(env, "randomizer", None) is not None:
            last_cand = to.load(osp.join(ex_dir, "candidates.pt"))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f"Loaded the domain randomizer\n{env.randomizer}", "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "r")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo.subroutine.critic.vfcn,
                                        verbose=True)

    elif algo.name == "simopt":
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            last_cand = to.load(osp.join(ex_dir, "candidates.pt"))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f"Loaded the domain randomizer\n{env.randomizer}", "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "r")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.subroutine_policy.policy,
                             verbose=True)
        # Extra (domain parameter distribution policy)
        extra["ddp_policy"] = pyrado.load("ddp_policy.pt",
                                          ex_dir,
                                          obj=algo.subroutine_distr.policy,
                                          verbose=True)

    elif algo.name in ["epopt", "udr"]:
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            if not isinstance(env, DomainRandWrapperLive):
                raise pyrado.TypeErr(given=env,
                                     expected_type=DomainRandWrapperLive)
            print_cbt(f"Loaded the domain randomizer\n{env.randomizer}", "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "y")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo.subroutine.critic.vfcn,
                                        verbose=True)

    elif algo.name in ["bayessim", "npdr"]:
        # Environment
        env = pyrado.load("env_sim.pkl", ex_dir)
        if getattr(env, "randomizer", None) is not None:
            if not isinstance(env, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env,
                                     expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(10)
            print_cbt(
                f"Loaded the domain randomizer\n{env.randomizer}\nand filled it with 10 random instances.",
                "w")
        else:
            print_cbt("Loaded environment has no randomizer, or it is None.",
                      "y")
            env = remove_all_dr_wrappers(env, verbose=True)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (prior, posterior, data)
        extra["prior"] = pyrado.load("prior.pt", ex_dir, verbose=True)
        # By default load the latest posterior (latest iteration and the last round)
        try:
            extra["posterior"] = algo.load_posterior(ex_dir,
                                                     args.iter,
                                                     args.round,
                                                     obj=None,
                                                     verbose=True)
            # Load the complete data or the data of the given iteration
            prefix = "" if args.iter == -1 else f"iter_{args.iter}"
            extra["data_real"] = pyrado.load("data_real.pt",
                                             ex_dir,
                                             prefix=prefix,
                                             verbose=True)
        except FileNotFoundError:
            # Loading these is best effort, but tell the user that some of the entries will be missing
            print_cbt(
                "Could not find the posterior or the real-world data for the given iteration and round. "
                "Continuing without them.",
                "y")

    elif algo.name in ["a2c", "ppo", "ppo2"]:
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (value function)
        extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                    ex_dir,
                                    obj=algo.critic.vfcn,
                                    verbose=True)

    elif algo.name in ["hc", "pepg", "power", "cem", "reps", "nes"]:
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)

    elif algo.name in ["dql", "sac"]:
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Target value functions
        if algo.name == "dql":
            extra["qfcn_target"] = pyrado.load("qfcn_target.pt",
                                               ex_dir,
                                               obj=algo.qfcn_targ,
                                               verbose=True)
        elif algo.name == "sac":
            extra["qfcn_target1"] = pyrado.load("qfcn_target1.pt",
                                                ex_dir,
                                                obj=algo.qfcn_targ_1,
                                                verbose=True)
            extra["qfcn_target2"] = pyrado.load("qfcn_target2.pt",
                                                ex_dir,
                                                obj=algo.qfcn_targ_2,
                                                verbose=True)
        else:
            raise NotImplementedError

    elif algo.name == "svpg":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Extra (particles)
        for idx, particle in enumerate(algo.particles):
            extra[f"particle{idx}"] = pyrado.load(f"particle_{idx}.pt",
                                                  ex_dir,
                                                  obj=particle,
                                                  verbose=True)

    elif algo.name == "tspred":
        # Dataset (there is no environment for this algorithm, thus env stays None)
        extra["dataset"] = to.load(osp.join(ex_dir, "dataset.pt"))
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)

    elif algo.name == "sprl":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env.pkl')}.", "g")
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt", ex_dir, obj=algo.policy)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", "g")
        # Extra (value function)
        if isinstance(algo._subroutine, ActorCritic):
            extra["vfcn"] = pyrado.load(f"{args.vfcn_name}.pt",
                                        ex_dir,
                                        obj=algo._subroutine.critic.vfcn,
                                        verbose=True)

    elif algo.name == "pddr":
        # Environment
        env = pyrado.load("env.pkl", ex_dir)
        # Policy
        policy = pyrado.load(f"{args.policy_name}.pt",
                             ex_dir,
                             obj=algo.policy,
                             verbose=True)
        # Teachers (taken directly from the loaded algorithm)
        extra["teacher_policies"] = algo.teacher_policies
        extra["teacher_envs"] = algo.teacher_envs
        extra["teacher_expl_strats"] = algo.teacher_expl_strats
        extra["teacher_critics"] = algo.teacher_critics
        extra["teacher_ex_dirs"] = algo.teacher_ex_dirs

    else:
        raise pyrado.TypeErr(
            msg=
            "No matching algorithm name found during loading the experiment!")

    # Check if the return types are correct. They can be None, too.
    if env is not None and not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if policy is not None and not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if extra is not None and not isinstance(extra, dict):
        raise pyrado.TypeErr(given=extra, expected_type=dict)

    return env, policy, extra
Beispiel #13
0
def load_experiment(
        ex_dir: str,
        args: Any = None) -> (Union[SimEnv, EnvWrapper], Policy, dict):
    """
    Load the (training) environment and the policy of a stored experiment.
    This helper function first tries to read the hyper-parameters yaml-file in the experiment's directory and, if that
    succeeds, stores them under the key `hparams` in the returned extra dict. Afterwards, the algorithm snapshot is
    loaded from the experiment's directory, and the type of the loaded algorithm determines which entities
    (environment, policy, value function, target networks, particles, data set, ...) are loaded.

    :param ex_dir: experiment's parent directory
    :param args: arguments from the argument parser, pass `None` to fall back to the values from the default argparser
    :return: environment, policy, and optional output (e.g. valuefcn) collected in a dict. The environment and the
             policy remain `None` if the branch for the loaded algorithm does not load them.
    :raises pyrado.TypeErr: if the loaded algorithm matches none of the supported types, or if a loaded entity
                            is not of the expected type
    :raises NotImplementedError: if the algorithm is a `ValueBased` one, but neither `DQL` nor `SAC`
    """
    env, policy, extra = None, None, dict()

    if args is None:
        # Fall back to default arguments. By passing [], we ignore the command line arguments
        args = get_argparser().parse_args([])

    # Hyper-parameters (optional, a missing or unreadable file only triggers a warning)
    hparams_file_name = 'hyperparams.yaml'
    try:
        hparams = load_dict_from_yaml(osp.join(ex_dir, hparams_file_name))
        extra['hparams'] = hparams
    except (pyrado.PathErr, FileNotFoundError, KeyError):
        print_cbt(
            f'Did not find {hparams_file_name} in {ex_dir} or could not crawl the loaded hyper-parameters.',
            'y',
            bright=True)

    # Algorithm specific
    # NOTE(review): the order of the isinstance checks is kept as is; presumably some of these classes are related
    # by inheritance, so confirm before reordering the branches.
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name='algo')
    if isinstance(algo, BayRn):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            # Configure the randomizer with the last row of the stored candidates
            last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine.critic.vfcn, f'{args.vfcn_name}', 'pt', ex_dir, None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}", 'g')

    elif isinstance(algo, SPOTA):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            # NOTE(review): this checks the type of env.randomizer against a wrapper class, while typed_env() below
            # is applied to env itself; confirm that this is intended.
            if not isinstance(env.randomizer, DomainRandWrapperBuffer):
                raise pyrado.TypeErr(given=env.randomizer, expected_type=DomainRandWrapperBuffer)
            typed_env(env, DomainRandWrapperBuffer).fill_buffer(100)
            print_cbt(f"Loaded {osp.join(ex_dir, 'env.pkl')} and filled it with 100 random instances.", 'g')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.subroutine_cand.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine_cand, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine_cand.critic.vfcn, f'{args.vfcn_name}', 'pt', ex_dir, None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}", 'g')

    elif isinstance(algo, SimOpt):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')}.", 'g')
        if hasattr(env, 'randomizer'):
            # Configure the randomizer with the last row of the stored candidates
            last_cand = to.load(osp.join(ex_dir, 'candidates.pt'))[-1, :]
            env.adapt_randomizer(last_cand.numpy())
            print_cbt(f'Loaded the domain randomizer\n{env.randomizer}', 'w')
        else:
            print_cbt('Loaded environment has no randomizer.', 'r')
        # Policy
        policy = pyrado.load(algo.subroutine_policy.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (domain parameter distribution policy)
        extra['ddp_policy'] = pyrado.load(algo.subroutine_distr.policy, 'ddp_policy', 'pt', ex_dir, None)

    elif isinstance(algo, (EPOpt, UDR)):
        # Environment
        env = pyrado.load(None, 'env_sim', 'pkl', ex_dir, None)
        if hasattr(env, 'randomizer'):
            if not isinstance(env.randomizer, DomainRandWrapperLive):
                raise pyrado.TypeErr(given=env.randomizer, expected_type=DomainRandWrapperLive)
            # Report the file that was actually loaded above, i.e. env_sim.pkl
            print_cbt(f"Loaded {osp.join(ex_dir, 'env_sim.pkl')} with DomainRandWrapperLive randomizer.", 'g')
        else:
            print_cbt('Loaded environment has no randomizer.', 'y')
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        if isinstance(algo.subroutine, ActorCritic):
            extra['vfcn'] = pyrado.load(algo.subroutine.critic.vfcn, f'{args.vfcn_name}', 'pt', ex_dir, None)
            print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}", 'g')

    elif isinstance(algo, ActorCritic):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (value function)
        extra['vfcn'] = pyrado.load(algo.critic.vfcn, f'{args.vfcn_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.vfcn_name}.pt')}", 'g')

    elif isinstance(algo, ParameterExploring):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')

    elif isinstance(algo, ValueBased):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Target value functions
        if isinstance(algo, DQL):
            extra['qfcn_target'] = pyrado.load(algo.qfcn_targ, 'qfcn_target', 'pt', ex_dir, None)
            print_cbt(f"Loaded {osp.join(ex_dir, 'qfcn_target.pt')}", 'g')
        elif isinstance(algo, SAC):
            extra['qfcn_target1'] = pyrado.load(algo.qfcn_targ_1, 'qfcn_target1', 'pt', ex_dir, None)
            extra['qfcn_target2'] = pyrado.load(algo.qfcn_targ_2, 'qfcn_target2', 'pt', ex_dir, None)
            print_cbt(
                f"Loaded {osp.join(ex_dir, 'qfcn_target1.pt')} and {osp.join(ex_dir, 'qfcn_target2.pt')}",
                'g')
        else:
            raise NotImplementedError

    elif isinstance(algo, SVPG):
        # Environment
        env = pyrado.load(None, 'env', 'pkl', ex_dir, None)
        # Policy
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)
        print_cbt(f"Loaded {osp.join(ex_dir, f'{args.policy_name}.pt')}", 'g')
        # Extra (particles), note that the dict keys have no underscore while the file names do
        for idx, particle in enumerate(algo.particles):
            extra[f'particle{idx}'] = pyrado.load(particle, f'particle_{idx}', 'pt', ex_dir, None)

    elif isinstance(algo, TSPred):
        # Dataset
        extra['dataset'] = to.load(osp.join(ex_dir, 'dataset.pt'))
        # Policy (no environment is loaded in this branch, i.e. env stays None)
        policy = pyrado.load(algo.policy, f'{args.policy_name}', 'pt', ex_dir, None)

    else:
        raise pyrado.TypeErr(msg='No matching algorithm name found during loading the experiment!')

    # Check if the return types are correct. They can be None, too.
    if env is not None and not isinstance(env, (SimEnv, EnvWrapper)):
        raise pyrado.TypeErr(given=env, expected_type=[SimEnv, EnvWrapper])
    if policy is not None and not isinstance(policy, Policy):
        raise pyrado.TypeErr(given=policy, expected_type=Policy)
    if extra is not None and not isinstance(extra, dict):
        raise pyrado.TypeErr(given=extra, expected_type=dict)

    return env, policy, extra
    parser.add_argument('--new_ex_dir', type=str, nargs='?',
                        help="path to the directory where the experiment should be saved/moved to")
    args = parser.parse_args()

    # Validate the inputs: the source directory must exist and a target directory must be given
    if not osp.isdir(args.ex_dir):
        raise pyrado.PathErr(given=args.ex_dir)
    if args.new_ex_dir is None:
        raise pyrado.ValueErr(msg='Provide the path to the new experiment directory using --new_ex_dir')

    # Create the new directory and test it
    os.makedirs(args.new_ex_dir, exist_ok=True)
    if not osp.isdir(args.new_ex_dir):
        raise pyrado.PathErr(given=args.new_ex_dir)

    # Load the old algorithm including the loggers
    algo = Algorithm.load_snapshot(args.ex_dir)

    # Update all entries that contain information about where the experiment is stored
    algo.save_dir = args.new_ex_dir
    for printer in algo.logger.printers:
        if isinstance(printer, CSVPrinter):
            # Keep the file name and re-root it in the new directory. Using osp.basename instead of searching
            # for '/' by hand also works with the path separators of non-POSIX platforms.
            printer.file = osp.join(args.new_ex_dir, osp.basename(printer.file))
        elif isinstance(printer, TensorBoardPrinter):
            printer.dir = args.new_ex_dir

    # Copy the complete content
    copy_tree(args.ex_dir, args.new_ex_dir)

    # Save the new algorithm with the updated entries
    # NOTE(review): presumably this writes to algo.save_dir and thereby replaces the snapshot copied above; confirm
    algo.save_snapshot()