Example #1
    def train_init_policies(self):
        """
        Initialize the algorithm with a number of random distribution parameter sets a.k.a. candidates specified by
        the user. Train a policy for every candidate. Finally, store the policies and candidates.
        """
        cands = to.empty(self.num_init_cand, self.cand_dim)
        for i in range(self.num_init_cand):
            print_cbt(
                f'Generating initial domain instance and policy {i + 1} of {self.num_init_cand} ...',
                'g',
                bright=True)
            # Generate random samples within bounds
            cands[i, :] = (self.bounds[1, :] - self.bounds[0, :]) * to.rand(
                self.bounds.shape[1]) + self.bounds[0, :]
            # Train a policy for each candidate, repeat if the resulting policy did not exceed the success threshold
            print_cbt(
                f'Randomly sampled the next candidate: {cands[i].numpy()}',
                'g')
            wrapped_trn_fcn = until_thold_exceeded(
                self.thold_succ_subroutine.item(),
                max_iter=self.max_subroutine_rep)(self.train_policy_sim)
            wrapped_trn_fcn(cands[i], prefix=f'init_{i}')

        # Save candidates into a single tensor (policy is saved during training or exists already)
        to.save(cands, osp.join(self._save_dir, 'candidates.pt'))
        self.cands = cands
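The repeated call through until_thold_exceeded is the retry mechanism used in almost every example on this page: the wrapped training function is re-run until its return value exceeds a success threshold or a maximum number of attempts is reached. The following is only a rough, hypothetical sketch of such a decorator (not pyrado's actual implementation), assuming the wrapped function returns a scalar success metric:

import functools

def until_thold_exceeded(thold: float, max_iter: int):
    """Hypothetical sketch: re-run a training function until its return value exceeds `thold`."""
    def decorator(trn_fcn):
        @functools.wraps(trn_fcn)
        def wrapper(*args, **kwargs):
            result = None
            for _ in range(max_iter):
                result = trn_fcn(*args, **kwargs)
                if result >= thold:
                    break  # success threshold exceeded, stop repeating
            return result
        return wrapper
    return decorator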
Example #2
    def train_init_policies(self):
        """
        Initialize the algorithm with a number of random distribution parameter sets a.k.a. candidates specified by
        the user. Train a policy for every candidate. Finally, store the policies and candidates.
        """
        cands = to.empty(self.num_init_cand, self.ddp_space.shape[0])
        for i in range(self.num_init_cand):
            print_cbt(
                f"Generating initial domain instance and policy {i + 1} of {self.num_init_cand} ...",
                "g",
                bright=True)
            # Sample random domain distribution parameters
            cands[i, :] = to.from_numpy(self.ddp_space.sample_uniform())

            # Train a policy for each candidate, repeat if the resulting policy did not exceed the success threshold
            print_cbt(
                f"Randomly sampled the next candidate: {cands[i].numpy()}",
                "g")
            wrapped_trn_fcn = until_thold_exceeded(
                self.thold_succ_subrtn.item(),
                self.max_subrtn_rep)(self.train_policy_sim)
            wrapped_trn_fcn(cands[i], prefix=f"init_{i}")

        # Save candidates into a single tensor (policy is saved during training or exists already)
        pyrado.save(cands, "candidates.pt", self.save_dir)
        self.cands = cands
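Example #2 replaces the manual bounds arithmetic of Example #1 with self.ddp_space.sample_uniform(). The underlying idea is plain uniform sampling inside a box of domain distribution parameters. A minimal NumPy/PyTorch sketch with made-up bounds (not pyrado's BoxSpace API) could look like this:

import numpy as np
import torch as to

# Hypothetical bounds for two domain distribution parameters
lower = np.array([0.5, 0.01])
upper = np.array([1.5, 0.10])

def sample_uniform(lower: np.ndarray, upper: np.ndarray) -> np.ndarray:
    """Draw one uniform sample inside the box spanned by lower and upper."""
    return lower + (upper - lower) * np.random.rand(*lower.shape)

cand = to.from_numpy(sample_uniform(lower, upper))  # analogous to filling cands[i, :] above
print(cand)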
Example #3
    def step(self, snapshot_mode: str = 'latest', meta_info: dict = None):
        # Save snapshot to save the correct iteration count
        self.save_snapshot()

        if self.curr_checkpoint == -2:
            # Train the initial policies in the source domain
            self.train_init_policies()
            self.reached_checkpoint()  # setting counter to -1

        if self.curr_checkpoint == -1:
            # Evaluate the initial policies in the target domain
            self.eval_init_policies()
            self.reached_checkpoint()  # setting counter to 0

        if self.curr_checkpoint == 0:
            # Normalize the input data and standardize the output data
            cands_norm = self.ddp_projector.project_to(self.cands)
            cands_values_stdized = standardize(self.cands_values).unsqueeze(1)

            # Create and fit the GP model
            gp = SingleTaskGP(cands_norm, cands_values_stdized)
            gp.likelihood.noise_covar.register_constraint('raw_noise', GreaterThan(1e-5))
            mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
            fit_gpytorch_model(mll)
            print_cbt('Fitted the GP.', 'g')

            # Acquisition functions
            if self.acq_fcn_type == 'UCB':
                acq_fcn = UpperConfidenceBound(gp, beta=self.acq_param.get('beta', 0.1), maximize=True)
            elif self.acq_fcn_type == 'EI':
                acq_fcn = ExpectedImprovement(gp, best_f=cands_values_stdized.max().item(), maximize=True)
            elif self.acq_fcn_type == 'PI':
                acq_fcn = ProbabilityOfImprovement(gp, best_f=cands_values_stdized.max().item(), maximize=True)
            else:
                raise pyrado.ValueErr(given=self.acq_fcn_type, eq_constraint="'UCB', 'EI', 'PI'")

            # Optimize acquisition function and get new candidate point
            cand_norm, acq_value = optimize_acqf(
                acq_function=acq_fcn,
                bounds=to.stack([to.zeros(self.ddp_space.flat_dim), to.ones(self.ddp_space.flat_dim)]),
                q=1,
                num_restarts=self.acq_restarts,
                raw_samples=self.acq_samples
            )
            next_cand = self.ddp_projector.project_back(cand_norm)
            print_cbt(f'Found the next candidate: {next_cand.numpy()}', 'g')
            self.cands = to.cat([self.cands, next_cand], dim=0)
            pyrado.save(self.cands, 'candidates', 'pt', self.save_dir, meta_info)
            self.reached_checkpoint()  # setting counter to 1

        if self.curr_checkpoint == 1:
            # Train and evaluate a new policy, repeat if the resulting policy did not exceed the success threshold
            wrapped_trn_fcn = until_thold_exceeded(
                self.thold_succ_subrtn.item(), self.max_subrtn_rep
            )(self.train_policy_sim)
            wrapped_trn_fcn(self.cands[-1, :], prefix=f'iter_{self._curr_iter}')
            self.reached_checkpoint()  # setting counter to 2

        if self.curr_checkpoint == 2:
            # Evaluate the current policy in the target domain
            policy = pyrado.load(self.policy, 'policy', 'pt', self.save_dir,
                                        meta_info=dict(prefix=f'iter_{self._curr_iter}'))
            self.curr_cand_value = self.eval_policy(
                self.save_dir, self._env_real, policy, self.mc_estimator, f'iter_{self._curr_iter}',
                self.num_eval_rollouts_real
            )
            self.cands_values = to.cat([self.cands_values, self.curr_cand_value.view(1)], dim=0)
            pyrado.save(self.cands_values, 'candidates_values', 'pt', self.save_dir, meta_info)

            # Store the argmax after training and evaluating
            curr_argmax_cand = BayRn.argmax_posterior_mean(
                self.cands, self.cands_values.unsqueeze(1), self.ddp_space, self.acq_restarts, self.acq_samples
            )
            self.argmax_cand = to.cat([self.argmax_cand, curr_argmax_cand], dim=0)
            pyrado.save(self.argmax_cand, 'candidates_argmax', 'pt', self.save_dir, meta_info)
            self.reached_checkpoint()  # setting counter to 0
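The Bayesian optimization block at checkpoint 0 follows the standard BoTorch recipe: fit a SingleTaskGP on normalized inputs and standardized outputs, then maximize an acquisition function over the unit cube. Below is a stripped-down, self-contained version of that recipe with random toy data standing in for the algorithm's candidates and values; it is only a sketch of the pattern, not BayRn itself:

import torch as to
from botorch.models import SingleTaskGP
from botorch.fit import fit_gpytorch_model
from botorch.acquisition import UpperConfidenceBound
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

# Toy data standing in for the normalized candidates and standardized candidate values
dim = 3
train_x = to.rand(10, dim)   # inputs already scaled to the unit cube
train_y = to.randn(10, 1)    # outputs already standardized, shape [n, 1]

# Create and fit the GP model
gp = SingleTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_model(mll)

# Maximize the acquisition function over the unit cube to get one new candidate
acq_fcn = UpperConfidenceBound(gp, beta=0.1, maximize=True)
cand_norm, acq_value = optimize_acqf(
    acq_function=acq_fcn,
    bounds=to.stack([to.zeros(dim), to.ones(dim)]),
    q=1,
    num_restarts=5,
    raw_samples=100,
)
print(cand_norm)  # in BayRn this would be projected back to the physical parameter range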
Example #4
    def step(self, snapshot_mode: str, meta_info: dict = None):
        if not self.initialized:
            # Start initialization phase
            self.train_init_policies()
            self.eval_init_policies()
            self.initialized = True

        # Normalize the input data and standardize the output data
        cands_norm = self.uc_normalizer.project_to(self.cands)
        cands_values_stdized = standardize(self.cands_values).unsqueeze(1)

        # Create and fit the GP model
        gp = SingleTaskGP(cands_norm, cands_values_stdized)
        gp.likelihood.noise_covar.register_constraint('raw_noise',
                                                      GreaterThan(1e-5))
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_model(mll)
        print_cbt('Fitted the GP.', 'g')

        # Acquisition functions
        if self.acq_fcn_type == 'UCB':
            acq_fcn = UpperConfidenceBound(gp,
                                           beta=self.acq_param.get(
                                               'beta', 0.1),
                                           maximize=True)
        elif self.acq_fcn_type == 'EI':
            acq_fcn = ExpectedImprovement(
                gp, best_f=cands_values_stdized.max().item(), maximize=True)
        elif self.acq_fcn_type == 'PI':
            acq_fcn = ProbabilityOfImprovement(
                gp, best_f=cands_values_stdized.max().item(), maximize=True)
        else:
            raise pyrado.ValueErr(given=self.acq_fcn_type,
                                  eq_constraint="'UCB', 'EI', 'PI'")

        # Optimize acquisition function and get new candidate point
        cand, acq_value = optimize_acqf(
            acq_function=acq_fcn,
            bounds=to.stack([to.zeros(self.cand_dim),
                             to.ones(self.cand_dim)]),
            q=1,
            num_restarts=self.acq_restarts,
            raw_samples=self.acq_samples)
        next_cand = self.uc_normalizer.project_back(cand)
        print_cbt(f'Found the next candidate: {next_cand.numpy()}', 'g')
        self.cands = to.cat([self.cands, next_cand], dim=0)
        to.save(self.cands, osp.join(self._save_dir, 'candidates.pt'))

        # Train and evaluate the new candidate (saves to iter_{self._curr_iter}_policy.pt)
        prefix = f'iter_{self._curr_iter}'
        wrapped_trn_fcn = until_thold_exceeded(
            self.thold_succ_subroutine.item(),
            max_iter=self.max_subroutine_rep)(self.train_policy_sim)
        wrapped_trn_fcn(self.cands[-1, :], prefix)  # train on the projected-back candidate, not the normalized one

        # Evaluate the current policy on the target domain
        policy = to.load(osp.join(self._save_dir, f'{prefix}_policy.pt'))
        self.curr_cand_value = self.eval_policy(self._save_dir, self._env_real,
                                                policy,
                                                self.montecarlo_estimator,
                                                prefix,
                                                self.num_eval_rollouts_real)

        self.cands_values = to.cat(
            [self.cands_values,
             self.curr_cand_value.view(1)], dim=0)
        to.save(self.cands_values,
                osp.join(self._save_dir, 'candidates_values.pt'))

        # Store the argmax after training and evaluating
        curr_argmax_cand = BayRn.argmax_posterior_mean(
            self.cands, self.cands_values.unsqueeze(1), self.uc_normalizer,
            self.acq_restarts, self.acq_samples)
        self.argmax_cand = to.cat([self.argmax_cand, curr_argmax_cand], dim=0)
        to.save(self.argmax_cand,
                osp.join(self._save_dir, 'candidates_argmax.pt'))

        self.make_snapshot(snapshot_mode, float(to.mean(self.cands_values)),
                           meta_info)
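Both BayRn variants project the candidates to the unit cube and z-standardize the observed returns before fitting the GP. The transformations themselves are one-liners; here is a small sketch independent of pyrado's projector and standardize helpers, with made-up bounds and data:

import torch as to

def project_to_unit_cube(x: to.Tensor, lo: to.Tensor, hi: to.Tensor) -> to.Tensor:
    """Min-max scale each dimension of x from [lo, hi] to [0, 1]."""
    return (x - lo) / (hi - lo)

def standardize(x: to.Tensor, eps: float = 1e-8) -> to.Tensor:
    """Subtract the mean and divide by the standard deviation (z-scoring)."""
    return (x - x.mean()) / (x.std() + eps)

lo, hi = to.tensor([0.5, 0.01]), to.tensor([1.5, 0.10])  # hypothetical parameter bounds
cands = lo + (hi - lo) * to.rand(8, 2)                   # candidates in the physical range
cands_values = to.randn(8)                               # observed returns

cands_norm = project_to_unit_cube(cands, lo, hi)
cands_values_stdized = standardize(cands_values).unsqueeze(1)  # shape [n, 1] for the GP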
Example #5
    def step(self, snapshot_mode: str, meta_info: dict = None):
        """
        Perform a step of SPRL. This includes training the subroutine and updating the context distribution accordingly.
        For a description of the parameters see `pyrado.algorithms.base.Algorithm.step`.
        """
        self.save_snapshot()

        context_mean = to.cat([
            spl_param.context_mean for spl_param in self._spl_parameters
        ]).double()
        context_cov_chol = to.cat([
            spl_param.context_cov_chol_flat
            for spl_param in self._spl_parameters
        ]).double()

        target_mean = to.cat([
            spl_param.target_mean for spl_param in self._spl_parameters
        ]).double()
        target_cov_chol = to.cat([
            spl_param.target_cov_chol_flat
            for spl_param in self._spl_parameters
        ]).double()

        for param in self._spl_parameters:
            self.logger.add_value(f"cur context mean for {param.name}",
                                  param.context_mean.item())
            self.logger.add_value(f"cur context cov for {param.name}",
                                  param.context_cov.item())

        # In the first iteration, completely reset the policy if training is unsuccessful
        reset_policy = self.curr_iter == 0
        until_thold_exceeded(self._performance_lower_bound * 0.3,
                             self._max_subrtn_retries)(
                                 self._train_subroutine_and_evaluate_perf)(
                                     snapshot_mode, meta_info, reset_policy)

        # Update distribution
        previous_distribution = ParameterAgnosticMultivariateNormalWrapper(
            context_mean, context_cov_chol, self._optimize_mean,
            self._optimize_cov)
        target_distribution = ParameterAgnosticMultivariateNormalWrapper(
            target_mean, target_cov_chol, self._optimize_mean,
            self._optimize_cov)

        rollouts_all = self._subroutine.sampler.rollouts
        contexts = to.tensor(
            [[
                to.from_numpy(ro.rollout_info["domain_param"][param.name])
                for rollouts in rollouts_all for ro in rollouts
            ] for param in self._spl_parameters],
            requires_grad=True,
        ).T

        contexts_old_log_prob = previous_distribution.distribution.log_prob(
            contexts.double())
        kl_divergence = to.distributions.kl_divergence(
            previous_distribution.distribution,
            target_distribution.distribution)

        values = to.tensor([
            ro.undiscounted_return() for rollouts in rollouts_all
            for ro in rollouts
        ])

        def kl_constraint_fn(x):
            """Compute the constraint for the KL-divergence between current and proposed distribution."""
            distribution = previous_distribution.from_stacked(x)
            kl_divergence = to.distributions.kl_divergence(
                previous_distribution.distribution, distribution.distribution)
            return kl_divergence.detach().numpy()

        def kl_constraint_fn_prime(x):
            """Compute the derivative for the KL-constraint (used for scipy optimizer)."""
            distribution = previous_distribution.from_stacked(x)
            kl_divergence = to.distributions.kl_divergence(
                previous_distribution.distribution, distribution.distribution)
            grads = to.autograd.grad(kl_divergence, distribution.parameters())
            return np.concatenate([g.detach().numpy() for g in grads])

        kl_constraint = NonlinearConstraint(
            fun=kl_constraint_fn,
            lb=-np.inf,
            ub=self._kl_constraints_ub,
            jac=kl_constraint_fn_prime,
            # keep_feasible=True,
        )

        def performance_constraint_fn(x):
            """Compute the constraint for the expected performance under the proposed distribution."""
            distribution = previous_distribution.from_stacked(x)
            performance = self._compute_expected_performance(
                distribution, contexts, contexts_old_log_prob, values)
            return performance.detach().numpy()

        def performance_constraint_fn_prime(x):
            """Compute the derivative for the performance-constraint (used for scipy optimizer)."""
            distribution = previous_distribution.from_stacked(x)
            performance = self._compute_expected_performance(
                distribution, contexts, contexts_old_log_prob, values)
            grads = to.autograd.grad(performance, distribution.parameters())
            return np.concatenate([g.detach().numpy() for g in grads])

        performance_constraint = NonlinearConstraint(
            fun=performance_constraint_fn,
            lb=self._performance_lower_bound,
            ub=np.inf,
            jac=performance_constraint_fn_prime,
            # keep_feasible=True,
        )

        # Optionally clip the bounds of the new variance
        bounds = None
        x0, _, x0_cov_indices = previous_distribution.get_stacked(
            return_mean_cov_indices=True)
        if self._kl_threshold != -np.inf and (self._kl_threshold <
                                              kl_divergence):
            lower_bound = np.ones_like(x0) * -np.inf
            if x0_cov_indices is not None:
                lower_bound[x0_cov_indices] = self._std_lower_bound
            upper_bound = np.ones_like(x0) * np.inf
            # bounds = Bounds(lb=lower_bound, ub=upper_bound, keep_feasible=True)
            bounds = Bounds(lb=lower_bound, ub=upper_bound)
            x0 = np.clip(x0, lower_bound, upper_bound)

        objective_fn: Optional[Callable[..., Tuple[np.array, np.array]]] = None
        result = None
        constraints = None

        # Check whether we are already above our performance threshold
        if performance_constraint_fn(x0) >= self._performance_lower_bound:
            self._performance_lower_bound_reached = True
            constraints = [kl_constraint, performance_constraint]

            # Once the performance threshold is met, minimize the KL divergence to the target distribution
            def objective(x):
                """Optimization objective after the minimum specified performance has been reached.
                Finds the update distribution with the minimum KL divergence to the target distribution that
                still satisfies the KL update constraint and the performance constraint."""
                distribution = previous_distribution.from_stacked(x)
                kl_divergence = to.distributions.kl_divergence(
                    distribution.distribution,
                    target_distribution.distribution)
                grads = to.autograd.grad(kl_divergence,
                                         distribution.parameters())

                return (
                    kl_divergence.detach().numpy(),
                    np.concatenate([g.detach().numpy() for g in grads]),
                )

            objective_fn = objective

        # If we have never reached the performance threshold, we optimize subject only to the KL constraint
        elif not self._performance_lower_bound_reached:
            constraints = [kl_constraint]

            # In this case we maximize the expected performance
            def objective(x):
                """Optimization objective while the minimum specified performance has not yet been reached.
                Maximizes the expected performance while still satisfying the KL update constraint."""
                distribution = previous_distribution.from_stacked(x)
                performance = self._compute_expected_performance(
                    distribution, contexts, contexts_old_log_prob, values)
                grads = to.autograd.grad(performance,
                                         distribution.parameters())

                return (
                    -performance.detach().numpy(),
                    -np.concatenate([g.detach().numpy() for g in grads]),
                )

            objective_fn = objective

        if objective_fn:
            result = minimize(
                objective_fn,
                x0,
                method="trust-constr",
                jac=True,
                constraints=constraints,
                options={
                    "gtol": 1e-4,
                    "xtol": 1e-6
                },
                bounds=bounds,
            )
        if result and result.success:
            self._adapt_parameters(result.x)

        # We have a result but the optimization process was not a success
        elif result:
            old_f = objective_fn(previous_distribution.get_stacked())[0]
            constraints_satisfied = all(
                (const.lb <= const.fun(result.x) <= const.ub
                 for const in constraints))

            std_ok = bounds is None or (np.all(bounds.lb <= result.x) and np.all(result.x <= bounds.ub))

            if constraints_satisfied and std_ok and result.fun < old_f:
                self._adapt_parameters(result.x)
            else:
                print("Update unsuccessful, keeping the old SPRL parameters")
Example #6
    def step(self, snapshot_mode: str = "latest", meta_info: dict = None):
        # Save snapshot to save the correct iteration count
        self.save_snapshot()

        if self.curr_checkpoint == -1:
            if self._subrtn_policy is not None and self._train_initial_policy:
                # Add dummy values for variables that are logged later
                self.logger.add_value("avg log prob", -pyrado.inf)

                # Train the behavioral policy using the samples obtained from the prior.
                # Repeat the training if the resulting policy did not exceed the success threshold.
                domain_params = self._sbi_prior.sample(
                    sample_shape=(self.num_eval_samples, ))
                print_cbt(
                    "Training the initial policy using domain parameter sets sampled from prior.",
                    "c")
                wrapped_trn_fcn = until_thold_exceeded(
                    self.thold_succ_subrtn,
                    self.max_subrtn_rep)(self.train_policy_sim)
                wrapped_trn_fcn(
                    domain_params, prefix="init",
                    use_rec_init_states=False)  # overrides policy.pt

            self.reached_checkpoint()  # setting counter to 0

        if self.curr_checkpoint == 0:
            # Check if the rollout files already exist
            if (osp.isfile(
                    osp.join(self._save_dir,
                             f"iter_{self.curr_iter}_data_real.pt"))
                    and osp.isfile(osp.join(self._save_dir, "data_real.pt"))
                    and osp.isfile(
                        osp.join(self._save_dir, "rollouts_real.pkl"))):
                # Rollout files do exist (this is the case when continuing a previous experiment)
                self._curr_data_real = pyrado.load(
                    "data_real.pt",
                    self._save_dir,
                    prefix=f"iter_{self.curr_iter}")
                print_cbt(
                    f"Loaded existing rollout data for iteration {self.curr_iter}.",
                    "w")

            else:
                # If the policy depends on the domain-parameters, reset the policy with the
                # most likely dp-params from the previous round.
                pyrado.load(
                    "policy.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter - 1}"
                    if self.curr_iter != 0 else "init",
                    obj=self._policy,
                )
                if self.curr_iter != 0:
                    ml_domain_param = pyrado.load(
                        "ml_domain_param.pkl",
                        self.save_dir,
                        prefix=f"iter_{self._curr_iter - 1}")
                    self._policy.reset(**dict(domain_param=ml_domain_param))

                # Rollout files do not exist yet (usual case)
                self._curr_data_real, _ = SBIBase.collect_data_real(
                    self.save_dir,
                    self._env_real,
                    self._policy,
                    self._embedding,
                    prefix=f"iter_{self._curr_iter}",
                    num_rollouts=self.num_real_rollouts,
                    num_segments=self.num_segments,
                    len_segments=self.len_segments,
                )

                # Save the target domain data
                if self._curr_iter == 0:
                    # Save the first set of data
                    pyrado.save(self._curr_data_real, "data_real.pt",
                                self._save_dir)
                else:
                    # Append and save all data
                    prev_data = pyrado.load("data_real.pt", self._save_dir)
                    data_real_hist = to.cat([prev_data, self._curr_data_real],
                                            dim=0)
                    pyrado.save(data_real_hist, "data_real.pt", self._save_dir)

            # Initialize sbi simulator and prior
            self._setup_sbi(
                prior=self._sbi_prior,
                rollouts_real=pyrado.load("rollouts_real.pkl",
                                          self._save_dir,
                                          prefix=f"iter_{self._curr_iter}"),
            )

            self.reached_checkpoint()  # setting counter to 1

        if self.curr_checkpoint == 1:
            # Instantiate the sbi subroutine to retrain from scratch each iteration
            if self.reset_sbi_routine_each_iter:
                self._initialize_subrtn_sbi(
                    subrtn_sbi_class=SNPE_A,
                    num_components=self._num_components)

            # Initialize the proposal with the prior
            proposal = self._sbi_prior

            # Multi-round sbi
            for idx_r in range(self.num_sbi_rounds):
                # Sample parameters from the proposal and simulate them to obtain the data
                domain_param, data_sim = simulate_for_sbi(
                    simulator=self._sbi_simulator,
                    proposal=proposal,
                    num_simulations=self.num_sim_per_round,
                    simulation_batch_size=self.simulation_batch_size,
                    num_workers=self.num_workers,
                )
                self._cnt_samples += self.num_sim_per_round * self._env_sim_sbi.max_steps

                # Append simulations and proposals for sbi
                self._subrtn_sbi.append_simulations(
                    domain_param,
                    data_sim,
                    proposal=proposal,  # do not pass the proposal argument for SNLE or SNRE
                )

                # Train the posterior
                density_estimator = self._subrtn_sbi.train(
                    final_round=idx_r == self.num_sbi_rounds - 1,
                    component_perturbation=self._component_perturbation,
                    **self.subrtn_sbi_training_hparam,
                )
                posterior = self._subrtn_sbi.build_posterior(
                    density_estimator=density_estimator,
                    **self.subrtn_sbi_sampling_hparam)

                # Save the posterior of this iteration before tailoring it to the data (when it is still amortized)
                if idx_r == 0:
                    pyrado.save(
                        posterior,
                        "posterior.pt",
                        self._save_dir,
                        prefix=f"iter_{self._curr_iter}",
                    )

                # Set proposal of the next round to focus on the next data set.
                # set_default_x() expects dim [1, num_rollouts * data_samples]
                proposal = posterior.set_default_x(self._curr_data_real)

                # Save the posterior tailored to each round
                pyrado.save(
                    posterior,
                    "posterior.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter}_round_{idx_r}",
                )

                # Override the latest posterior
                pyrado.save(posterior, "posterior.pt", self._save_dir)

            self.reached_checkpoint()  # setting counter to 2

        if self.curr_checkpoint == 2:
            # Logging (the evaluation can be time-intensive)
            posterior = pyrado.load("posterior.pt", self._save_dir)
            self._curr_domain_param_eval, log_probs = SBIBase.eval_posterior(
                posterior,
                self._curr_data_real,
                self.num_eval_samples,
                calculate_log_probs=True,
                normalize_posterior=self.normalize_posterior,
                subrtn_sbi_sampling_hparam=self.subrtn_sbi_sampling_hparam,
            )
            self.logger.add_value("avg log prob", to.mean(log_probs), 4)
            self.logger.add_value("num total samples", self._cnt_samples)

            # Extract the most likely domain parameter set out of all target domain data sets
            current_domain_param = self._env_sim_sbi.domain_param
            idx_ml = to.argmax(log_probs).item()
            dp_vals = self._curr_domain_param_eval[
                idx_ml // self.num_eval_samples, idx_ml % self.num_eval_samples, :
            ]
            dp_vals = to.atleast_1d(dp_vals).numpy()
            ml_domain_param = dict(
                zip(self.dp_mapping.values(), dp_vals.tolist()))

            # Update the unchanged domain parameters with the most likely ones obtained from the posterior
            current_domain_param.update(ml_domain_param)
            pyrado.save(current_domain_param,
                        "ml_domain_param.pkl",
                        self.save_dir,
                        prefix=f"iter_{self._curr_iter}")

            self.reached_checkpoint()  # setting counter to 3

        if self.curr_checkpoint == 3:
            # Policy optimization
            if self._subrtn_policy is not None:
                pyrado.load(
                    "policy.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter - 1}"
                    if self.curr_iter != 0 else "init",
                    obj=self._policy,
                )
                # Train the behavioral policy using the posterior samples obtained before.
                # Repeat the training if the resulting policy did not exceed the success threshold.
                print_cbt(
                    "Training the next policy using domain parameter sets sampled from the current posterior.",
                    "c")
                wrapped_trn_fcn = until_thold_exceeded(
                    self.thold_succ_subrtn,
                    self.max_subrtn_rep)(self.train_policy_sim)
                wrapped_trn_fcn(self._curr_domain_param_eval.squeeze(0),
                                prefix=f"iter_{self._curr_iter}",
                                use_rec_init_states=True)
            else:
                # Save the prefixed policy either way
                pyrado.save(self.policy,
                            "policy.pt",
                            self.save_dir,
                            prefix=f"iter_{self._curr_iter}",
                            use_state_dict=True)

            self.reached_checkpoint()  # setting counter to 0

        # Save snapshot data
        self.make_snapshot(snapshot_mode, None, meta_info)
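Checkpoint 1 is essentially the multi-round neural posterior estimation loop of the sbi package: simulate under the current proposal, append the parameter/data pairs, train the density estimator, build a posterior, and condition it on the observed data for the next round. A bare-bones sketch of that loop with a toy simulator, using plain SNPE rather than pyrado's SNPE_A setup, could look like this:

import torch as to
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils import BoxUniform

def toy_simulator(theta: to.Tensor) -> to.Tensor:
    """Toy simulator: noisy identity mapping from parameters to data."""
    return theta + 0.1 * to.randn_like(theta)

prior = BoxUniform(low=-to.ones(2), high=to.ones(2))
x_obs = to.zeros(1, 2)  # stands in for the real-world rollout data

simulator, prior = prepare_for_sbi(toy_simulator, prior)
inference = SNPE(prior=prior)
proposal = prior
for _ in range(2):  # corresponds to num_sbi_rounds
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=200)
    density_estimator = inference.append_simulations(theta, x, proposal=proposal).train()
    posterior = inference.build_posterior(density_estimator)
    proposal = posterior.set_default_x(x_obs)  # focus the next round on the observation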
Example #7
    def step(self, snapshot_mode: str = 'latest', meta_info: dict = None):
        # Save snapshot to save the correct iteration count
        self.save_snapshot()

        if self.curr_checkpoint == 0:
            if self._curr_iter == 0:
                # First iteration, use the policy parameters (initialized from a prior)
                cand = self._subrtn_distr.policy.transform_to_ddp_space(
                    self._subrtn_distr.policy.param_values)
                self.cands = cand.unsqueeze(0)
            else:
                # Select the latest domain distribution parameter set
                assert isinstance(self.cands, to.Tensor)
                cand = self.cands[-1, :].clone()
            print_cbt(
                f'Current domain distribution parameters: {cand.detach().cpu().numpy()}',
                'g')

            # Train and evaluate the behavioral policy, repeat if the policy did not exceed the success threshold
            wrapped_trn_fcn = until_thold_exceeded(
                self.thold_succ_subrtn.item(),
                self.max_subrtn_rep)(self.train_policy_sim)
            wrapped_trn_fcn(cand, prefix=f'iter_{self._curr_iter}')

            # Save the latest behavioral policy
            self._subrtn_policy.save_snapshot()
            self.reached_checkpoint()  # setting counter to 1

        if self.curr_checkpoint == 1:
            # Evaluate the current policy in the target domain
            policy = pyrado.load(
                self.policy,
                'policy',
                'pt',
                self.save_dir,
                meta_info=dict(prefix=f'iter_{self._curr_iter}'))
            self.eval_behav_policy(self.save_dir, self._env_real, policy,
                                   f'iter_{self._curr_iter}',
                                   self.num_eval_rollouts, None)
            # if self._curr_iter == 0:
            #     # First iteration, also evaluate the random initialization
            #     self.cands_values = SimOpt.eval_ddp_policy(
            #         rollouts_real, self._env_sim, self.num_eval_rollouts, self._subrtn_distr, self._subrtn_policy
            #     )
            #     self.cands_values = to.tensor(self.cands_values).unsqueeze(0)
            self.reached_checkpoint()  # setting counter to 2

        if self.curr_checkpoint == 2:
            # Train and evaluate the policy that represents domain parameter distribution
            rollouts_real = pyrado.load(
                None,
                'rollouts_real',
                'pkl',
                self.save_dir,
                meta_info=dict(prefix=f'iter_{self._curr_iter}'))
            curr_cand_value = self.train_ddp_policy(
                rollouts_real, prefix=f'iter_{self._curr_iter}')
            if self._curr_iter == 0:
                self.cands_values = to.tensor(curr_cand_value).unsqueeze(0)
            else:
                self.cands_values = to.cat(
                    [self.cands_values, to.tensor(curr_cand_value).unsqueeze(0)], dim=0
                )
            pyrado.save(self.cands_values, 'candidates_values', 'pt',
                        self.save_dir, meta_info)

            # The next candidate is the current search distribution, not the best policy parameter set (which is saved anyway)
            next_cand = self._subrtn_distr.policy.transform_to_ddp_space(
                self._subrtn_distr.policy.param_values)
            self.cands = to.cat([self.cands, next_cand.unsqueeze(0)], dim=0)
            pyrado.save(self.cands, 'candidates', 'pt', self.save_dir,
                        meta_info)

            # Save the latest domain distribution parameter policy
            self._subrtn_distr.save_snapshot(
                meta_info=dict(prefix='ddp', rollouts_real=rollouts_real))
            self.reached_checkpoint()  # setting counter to 0
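Examples #3, #6 and #7 all share the same resumable-step structure: step() is split into numbered checkpoints, and reached_checkpoint() advances a counter so that an interrupted experiment can continue at the phase it last completed. The following schematic is hypothetical (not pyrado's actual base class) and assumes the counter is persisted together with the snapshot:

class CheckpointedAlgorithm:
    """Hypothetical sketch of the checkpoint pattern used in the step() methods above."""

    def __init__(self, num_checkpoints: int, init_checkpoint: int = 0):
        # Negative initial values (as in Examples #3 and #6) mark one-time initialization phases
        self._num_checkpoints = num_checkpoints
        self.curr_checkpoint = init_checkpoint

    def reached_checkpoint(self):
        """Advance the counter; wrap back to 0 after the last recurring phase."""
        self.curr_checkpoint += 1
        if self.curr_checkpoint >= self._num_checkpoints:
            self.curr_checkpoint = 0
        # the counter would be stored with the snapshot so a restarted run resumes here

    def step(self):
        if self.curr_checkpoint == 0:
            ...  # phase 0, e.g. train the policy in simulation
            self.reached_checkpoint()  # setting counter to 1
        if self.curr_checkpoint == 1:
            ...  # phase 1, e.g. evaluate the policy on the target domain
            self.reached_checkpoint()  # setting counter to 2
        if self.curr_checkpoint == 2:
            ...  # phase 2, e.g. update the domain parameter distribution
            self.reached_checkpoint()  # setting counter to 0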