def train_init_policies(self):
    """
    Initialize the algorithm with a number of random distribution parameter sets a.k.a. candidates specified by
    the user. Train a policy for every candidate. Finally, store the policies and candidates.
    """
    cands = to.empty(self.num_init_cand, self.cand_dim)
    for i in range(self.num_init_cand):
        print_cbt(f'Generating initial domain instance and policy {i + 1} of {self.num_init_cand} ...',
                  'g', bright=True)
        # Generate random samples within bounds
        cands[i, :] = (self.bounds[1, :] - self.bounds[0, :]) * to.rand(self.bounds.shape[1]) + self.bounds[0, :]
        # Train a policy for each candidate, repeat if the resulting policy did not exceed the success threshold
        print_cbt(f'Randomly sampled the next candidate: {cands[i].numpy()}', 'g')
        wrapped_trn_fcn = until_thold_exceeded(
            self.thold_succ_subroutine.item(), max_iter=self.max_subroutine_rep
        )(self.train_policy_sim)
        wrapped_trn_fcn(cands[i], prefix=f'init_{i}')

    # Save candidates into a single tensor (policy is saved during training or exists already)
    to.save(cands, osp.join(self._save_dir, 'candidates.pt'))
    self.cands = cands
def train_init_policies(self):
    """
    Initialize the algorithm with a number of random distribution parameter sets a.k.a. candidates specified by
    the user. Train a policy for every candidate. Finally, store the policies and candidates.
    """
    cands = to.empty(self.num_init_cand, self.ddp_space.shape[0])
    for i in range(self.num_init_cand):
        print_cbt(f"Generating initial domain instance and policy {i + 1} of {self.num_init_cand} ...",
                  "g", bright=True)
        # Sample random domain distribution parameters
        cands[i, :] = to.from_numpy(self.ddp_space.sample_uniform())
        # Train a policy for each candidate, repeat if the resulting policy did not exceed the success threshold
        print_cbt(f"Randomly sampled the next candidate: {cands[i].numpy()}", "g")
        wrapped_trn_fcn = until_thold_exceeded(self.thold_succ_subrtn.item(), self.max_subrtn_rep)(
            self.train_policy_sim
        )
        wrapped_trn_fcn(cands[i], prefix=f"init_{i}")

    # Save candidates into a single tensor (policy is saved during training or exists already)
    pyrado.save(cands, "candidates.pt", self.save_dir)
    self.cands = cands
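# For reference, a minimal sketch of the retry-until-threshold pattern behind
# `until_thold_exceeded` as used above. This is an illustrative assumption, not
# pyrado's actual implementation: the wrapped training function is re-run until
# its scalar return value reaches `thold`, but at most `max_iter` times.
def _until_thold_exceeded_sketch(thold: float, max_iter: int):
    def decorator(trn_fcn):
        def wrapper(*args, **kwargs):
            ret = -float('inf')
            for _ in range(max_iter):
                ret = trn_fcn(*args, **kwargs)
                if ret >= thold:
                    break  # the trained policy was successful enough, stop repeating
            return ret
        return wrapper
    return decorator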
def step(self, snapshot_mode: str = 'latest', meta_info: dict = None):
    # Save snapshot to save the correct iteration count
    self.save_snapshot()

    if self.curr_checkpoint == -2:
        # Train the initial policies in the source domain
        self.train_init_policies()
        self.reached_checkpoint()  # setting counter to -1

    if self.curr_checkpoint == -1:
        # Evaluate the initial policies in the target domain
        self.eval_init_policies()
        self.reached_checkpoint()  # setting counter to 0

    if self.curr_checkpoint == 0:
        # Normalize the input data and standardize the output data
        cands_norm = self.ddp_projector.project_to(self.cands)
        cands_values_stdized = standardize(self.cands_values).unsqueeze(1)

        # Create and fit the GP model
        gp = SingleTaskGP(cands_norm, cands_values_stdized)
        gp.likelihood.noise_covar.register_constraint('raw_noise', GreaterThan(1e-5))
        mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
        fit_gpytorch_model(mll)
        print_cbt('Fitted the GP.', 'g')

        # Acquisition functions
        if self.acq_fcn_type == 'UCB':
            acq_fcn = UpperConfidenceBound(gp, beta=self.acq_param.get('beta', 0.1), maximize=True)
        elif self.acq_fcn_type == 'EI':
            acq_fcn = ExpectedImprovement(gp, best_f=cands_values_stdized.max().item(), maximize=True)
        elif self.acq_fcn_type == 'PI':
            acq_fcn = ProbabilityOfImprovement(gp, best_f=cands_values_stdized.max().item(), maximize=True)
        else:
            raise pyrado.ValueErr(given=self.acq_fcn_type, eq_constraint="'UCB', 'EI', 'PI'")

        # Optimize acquisition function and get new candidate point
        cand_norm, acq_value = optimize_acqf(
            acq_function=acq_fcn,
            bounds=to.stack([to.zeros(self.ddp_space.flat_dim), to.ones(self.ddp_space.flat_dim)]),
            q=1,
            num_restarts=self.acq_restarts,
            raw_samples=self.acq_samples
        )
        next_cand = self.ddp_projector.project_back(cand_norm)
        print_cbt(f'Found the next candidate: {next_cand.numpy()}', 'g')
        self.cands = to.cat([self.cands, next_cand], dim=0)
        pyrado.save(self.cands, 'candidates', 'pt', self.save_dir, meta_info)
        self.reached_checkpoint()  # setting counter to 1

    if self.curr_checkpoint == 1:
        # Train and evaluate a new policy, repeat if the resulting policy did not exceed the success threshold
        wrapped_trn_fcn = until_thold_exceeded(
            self.thold_succ_subrtn.item(), self.max_subrtn_rep
        )(self.train_policy_sim)
        wrapped_trn_fcn(self.cands[-1, :], prefix=f'iter_{self._curr_iter}')
        self.reached_checkpoint()  # setting counter to 2

    if self.curr_checkpoint == 2:
        # Evaluate the current policy in the target domain
        policy = pyrado.load(self.policy, 'policy', 'pt', self.save_dir,
                             meta_info=dict(prefix=f'iter_{self._curr_iter}'))
        self.curr_cand_value = self.eval_policy(
            self.save_dir, self._env_real, policy, self.mc_estimator, f'iter_{self._curr_iter}',
            self.num_eval_rollouts_real
        )
        self.cands_values = to.cat([self.cands_values, self.curr_cand_value.view(1)], dim=0)
        pyrado.save(self.cands_values, 'candidates_values', 'pt', self.save_dir, meta_info)

        # Store the argmax after training and evaluating
        curr_argmax_cand = BayRn.argmax_posterior_mean(
            self.cands, self.cands_values.unsqueeze(1), self.ddp_space, self.acq_restarts, self.acq_samples
        )
        self.argmax_cand = to.cat([self.argmax_cand, curr_argmax_cand], dim=0)
        pyrado.save(self.argmax_cand, 'candidates_argmax', 'pt', self.save_dir, meta_info)
        self.reached_checkpoint()  # setting counter to 0
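# Self-contained illustration of the BoTorch workflow from the checkpoint-0
# branch above: fit a GP on normalized candidates, then optimize a UCB
# acquisition function over the unit cube. Shapes and hyper-parameters are
# placeholder assumptions, not BayRn's actual settings.
import torch as to
from botorch.acquisition import UpperConfidenceBound
from botorch.fit import fit_gpytorch_model
from botorch.models import SingleTaskGP
from botorch.optim import optimize_acqf
from gpytorch.mlls import ExactMarginalLogLikelihood

train_x = to.rand(10, 2)  # 10 candidates with 2 domain distribution params, already mapped to [0, 1]
train_y = to.randn(10, 1)  # standardized returns of the trained policies
gp = SingleTaskGP(train_x, train_y)
mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
fit_gpytorch_model(mll)

acq_fcn = UpperConfidenceBound(gp, beta=0.1, maximize=True)
cand_norm, acq_value = optimize_acqf(
    acq_function=acq_fcn,
    bounds=to.stack([to.zeros(2), to.ones(2)]),  # search within the unit cube
    q=1,
    num_restarts=5,
    raw_samples=64,
)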
def step(self, snapshot_mode: str, meta_info: dict = None):
    if not self.initialized:
        # Start initialization phase
        self.train_init_policies()
        self.eval_init_policies()
        self.initialized = True

    # Normalize the input data and standardize the output data
    cands_norm = self.uc_normalizer.project_to(self.cands)
    cands_values_stdized = standardize(self.cands_values).unsqueeze(1)

    # Create and fit the GP model
    gp = SingleTaskGP(cands_norm, cands_values_stdized)
    gp.likelihood.noise_covar.register_constraint('raw_noise', GreaterThan(1e-5))
    mll = ExactMarginalLogLikelihood(gp.likelihood, gp)
    fit_gpytorch_model(mll)
    print_cbt('Fitted the GP.', 'g')

    # Acquisition functions
    if self.acq_fcn_type == 'UCB':
        acq_fcn = UpperConfidenceBound(gp, beta=self.acq_param.get('beta', 0.1), maximize=True)
    elif self.acq_fcn_type == 'EI':
        acq_fcn = ExpectedImprovement(gp, best_f=cands_values_stdized.max().item(), maximize=True)
    elif self.acq_fcn_type == 'PI':
        acq_fcn = ProbabilityOfImprovement(gp, best_f=cands_values_stdized.max().item(), maximize=True)
    else:
        raise pyrado.ValueErr(given=self.acq_fcn_type, eq_constraint="'UCB', 'EI', 'PI'")

    # Optimize acquisition function and get new candidate point
    cand, acq_value = optimize_acqf(
        acq_function=acq_fcn,
        bounds=to.stack([to.zeros(self.cand_dim), to.ones(self.cand_dim)]),
        q=1,
        num_restarts=self.acq_restarts,
        raw_samples=self.acq_samples
    )
    next_cand = self.uc_normalizer.project_back(cand)
    print_cbt(f'Found the next candidate: {next_cand.numpy()}', 'g')
    self.cands = to.cat([self.cands, next_cand], dim=0)
    to.save(self.cands, osp.join(self._save_dir, 'candidates.pt'))

    # Train and evaluate the new candidate (saves to iter_{self._curr_iter}_policy.pt)
    prefix = f'iter_{self._curr_iter}'
    wrapped_trn_fcn = until_thold_exceeded(
        self.thold_succ_subroutine.item(), max_iter=self.max_subroutine_rep
    )(self.train_policy_sim)
    wrapped_trn_fcn(self.cands[-1, :], prefix)  # train on the projected-back candidate, not the unit-cube point

    # Evaluate the current policy on the target domain
    policy = to.load(osp.join(self._save_dir, f'{prefix}_policy.pt'))
    self.curr_cand_value = self.eval_policy(
        self._save_dir, self._env_real, policy, self.montecarlo_estimator, prefix, self.num_eval_rollouts_real
    )
    self.cands_values = to.cat([self.cands_values, self.curr_cand_value.view(1)], dim=0)
    to.save(self.cands_values, osp.join(self._save_dir, 'candidates_values.pt'))

    # Store the argmax after training and evaluating
    curr_argmax_cand = BayRn.argmax_posterior_mean(
        self.cands, self.cands_values.unsqueeze(1), self.uc_normalizer, self.acq_restarts, self.acq_samples
    )
    self.argmax_cand = to.cat([self.argmax_cand, curr_argmax_cand], dim=0)
    to.save(self.argmax_cand, osp.join(self._save_dir, 'candidates_argmax.pt'))

    self.make_snapshot(snapshot_mode, float(to.mean(self.cands_values)), meta_info)
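# A minimal sketch, assuming `uc_normalizer` (and `ddp_projector` in the other
# variant above) is an affine unit-cube projector: domain distribution
# parameters are mapped into [0, 1] for the GP and mapped back before
# simulation. The class name and attributes here are illustrative, not
# necessarily pyrado's exact API.
import torch as to

class UnitCubeProjectorSketch:
    def __init__(self, bound_lo: to.Tensor, bound_up: to.Tensor):
        self.bound_lo = bound_lo
        self.bound_up = bound_up

    def project_to(self, data: to.Tensor) -> to.Tensor:
        # Normalize element-wise to [0, 1]
        return (data - self.bound_lo) / (self.bound_up - self.bound_lo)

    def project_back(self, data: to.Tensor) -> to.Tensor:
        # De-normalize element-wise from [0, 1] back to the original bounds
        return self.bound_lo + data * (self.bound_up - self.bound_lo)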
def step(self, snapshot_mode: str, meta_info: dict = None):
    """
    Perform a step of SPRL. This includes training the subroutine and updating the context distribution
    accordingly. For a description of the parameters see `pyrado.algorithms.base.Algorithm.step`.
    """
    self.save_snapshot()

    context_mean = to.cat([spl_param.context_mean for spl_param in self._spl_parameters]).double()
    context_cov_chol = to.cat([spl_param.context_cov_chol_flat for spl_param in self._spl_parameters]).double()
    target_mean = to.cat([spl_param.target_mean for spl_param in self._spl_parameters]).double()
    target_cov_chol = to.cat([spl_param.target_cov_chol_flat for spl_param in self._spl_parameters]).double()

    for param in self._spl_parameters:
        self.logger.add_value(f"cur context mean for {param.name}", param.context_mean.item())
        self.logger.add_value(f"cur context cov for {param.name}", param.context_cov.item())

    # If we are in the first iteration and have a bad performance,
    # we want to completely reset the policy if training is unsuccessful
    reset_policy = False
    if self.curr_iter == 0:
        reset_policy = True
    until_thold_exceeded(self._performance_lower_bound * 0.3, self._max_subrtn_retries)(
        self._train_subroutine_and_evaluate_perf
    )(snapshot_mode, meta_info, reset_policy)

    # Update distribution
    previous_distribution = ParameterAgnosticMultivariateNormalWrapper(
        context_mean, context_cov_chol, self._optimize_mean, self._optimize_cov
    )
    target_distribution = ParameterAgnosticMultivariateNormalWrapper(
        target_mean, target_cov_chol, self._optimize_mean, self._optimize_cov
    )

    rollouts_all = self._subroutine.sampler.rollouts
    contexts = to.tensor(
        [
            [
                to.from_numpy(ro.rollout_info["domain_param"][param.name])
                for rollouts in rollouts_all
                for ro in rollouts
            ]
            for param in self._spl_parameters
        ],
        requires_grad=True,
    ).T

    contexts_old_log_prob = previous_distribution.distribution.log_prob(contexts.double())
    kl_divergence = to.distributions.kl_divergence(
        previous_distribution.distribution, target_distribution.distribution
    )

    values = to.tensor([ro.undiscounted_return() for rollouts in rollouts_all for ro in rollouts])

    def kl_constraint_fn(x):
        """Compute the constraint for the KL-divergence between the current and the proposed distribution."""
        distribution = previous_distribution.from_stacked(x)
        kl_divergence = to.distributions.kl_divergence(
            previous_distribution.distribution, distribution.distribution
        )
        return kl_divergence.detach().numpy()

    def kl_constraint_fn_prime(x):
        """Compute the derivative of the KL-constraint (used for the scipy optimizer)."""
        distribution = previous_distribution.from_stacked(x)
        kl_divergence = to.distributions.kl_divergence(
            previous_distribution.distribution, distribution.distribution
        )
        grads = to.autograd.grad(kl_divergence, distribution.parameters())
        return np.concatenate([g.detach().numpy() for g in grads])

    kl_constraint = NonlinearConstraint(
        fun=kl_constraint_fn,
        lb=-np.inf,
        ub=self._kl_constraints_ub,
        jac=kl_constraint_fn_prime,
        # keep_feasible=True,
    )

    def performance_constraint_fn(x):
        """Compute the constraint for the expected performance under the proposed distribution."""
        distribution = previous_distribution.from_stacked(x)
        performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
        return performance.detach().numpy()

    def performance_constraint_fn_prime(x):
        """Compute the derivative of the performance-constraint (used for the scipy optimizer)."""
        distribution = previous_distribution.from_stacked(x)
        performance = self._compute_expected_performance(distribution, contexts, contexts_old_log_prob, values)
        grads = to.autograd.grad(performance, distribution.parameters())
        return np.concatenate([g.detach().numpy() for g in grads])

    performance_constraint = NonlinearConstraint(
        fun=performance_constraint_fn,
        lb=self._performance_lower_bound,
        ub=np.inf,
        jac=performance_constraint_fn_prime,
        # keep_feasible=True,
    )

    # Optionally clip the bounds of the new variance
    bounds = None
    x0, _, x0_cov_indices = previous_distribution.get_stacked(return_mean_cov_indices=True)
    if self._kl_threshold != -np.inf and (self._kl_threshold < kl_divergence):
        lower_bound = np.ones_like(x0) * -np.inf
        if x0_cov_indices is not None:
            lower_bound[x0_cov_indices] = self._std_lower_bound
        upper_bound = np.ones_like(x0) * np.inf
        # bounds = Bounds(lb=lower_bound, ub=upper_bound, keep_feasible=True)
        bounds = Bounds(lb=lower_bound, ub=upper_bound)
        x0 = np.clip(x0, lower_bound, upper_bound)

    objective_fn: Optional[Callable[..., Tuple[np.ndarray, np.ndarray]]] = None
    result = None
    constraints = None

    # Check whether we are already above our performance threshold
    if performance_constraint_fn(x0) >= self._performance_lower_bound:
        self._performance_lower_bound_reached = True
        constraints = [kl_constraint, performance_constraint]

        # We now optimize based on the KL-divergence between the target and the context distribution by
        # minimizing it
        def objective(x):
            """
            Optimization objective after the minimum specified performance was reached.
            Tries to find the minimum KL-divergence between the current and the updated distribution which
            still satisfies the minimum update constraint and the performance constraint.
            """
            distribution = previous_distribution.from_stacked(x)
            kl_divergence = to.distributions.kl_divergence(
                distribution.distribution, target_distribution.distribution
            )
            grads = to.autograd.grad(kl_divergence, distribution.parameters())
            return (
                kl_divergence.detach().numpy(),
                np.concatenate([g.detach().numpy() for g in grads]),
            )

        objective_fn = objective

    # If we have never reached the performance threshold, we optimize just based on the KL-constraint
    elif not self._performance_lower_bound_reached:
        constraints = [kl_constraint]

        # Now we optimize on the expected performance, meaning maximizing it
        def objective(x):
            """
            Optimization objective while the minimum specified performance has not been reached yet.
            Tries to maximize the performance while still satisfying the minimum KL update constraint.
            """
            distribution = previous_distribution.from_stacked(x)
            performance = self._compute_expected_performance(
                distribution, contexts, contexts_old_log_prob, values
            )
            grads = to.autograd.grad(performance, distribution.parameters())
            return (
                -performance.detach().numpy(),
                -np.concatenate([g.detach().numpy() for g in grads]),
            )

        objective_fn = objective

    if objective_fn:
        result = minimize(
            objective_fn,
            x0,
            method="trust-constr",
            jac=True,
            constraints=constraints,
            options={"gtol": 1e-4, "xtol": 1e-6},
            bounds=bounds,
        )

    if result and result.success:
        self._adapt_parameters(result.x)
    # We have a result, but the optimization process was not a success
    elif result:
        old_f = objective_fn(previous_distribution.get_stacked())[0]
        constraints_satisfied = all(const.lb <= const.fun(result.x) <= const.ub for const in constraints)
        std_ok = bounds is None or (np.all(bounds.lb <= result.x) and np.all(result.x <= bounds.ub))
        if constraints_satisfied and std_ok and result.fun < old_f:
            self._adapt_parameters(result.x)
        else:
            print("Update unsuccessful, keeping the old SPRL parameters")
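# Hedged sketch of what `_compute_expected_performance` above plausibly computes
# (an assumption based on the self-paced RL formulation, not the verified
# implementation): an importance-weighted estimate of the expected return under
# a proposed context distribution, using contexts sampled from the old one.
import torch as to

def _compute_expected_performance_sketch(distribution, contexts, contexts_old_log_prob, values):
    contexts_new_log_prob = distribution.distribution.log_prob(contexts.double())
    importance_weights = to.exp(contexts_new_log_prob - contexts_old_log_prob)
    return to.mean(importance_weights * values)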
def step(self, snapshot_mode: str = "latest", meta_info: dict = None): # Save snapshot to save the correct iteration count self.save_snapshot() if self.curr_checkpoint == -1: if self._subrtn_policy is not None and self._train_initial_policy: # Add dummy values of variables that are logger later self.logger.add_value("avg log prob", -pyrado.inf) # Train the behavioral policy using the samples obtained from the prior. # Repeat the training if the resulting policy did not exceed the success threshold. domain_params = self._sbi_prior.sample( sample_shape=(self.num_eval_samples, )) print_cbt( "Training the initial policy using domain parameter sets sampled from prior.", "c") wrapped_trn_fcn = until_thold_exceeded( self.thold_succ_subrtn, self.max_subrtn_rep)(self.train_policy_sim) wrapped_trn_fcn( domain_params, prefix="init", use_rec_init_states=False) # overrides policy.pt self.reached_checkpoint() # setting counter to 0 if self.curr_checkpoint == 0: # Check if the rollout files already exist if (osp.isfile( osp.join(self._save_dir, f"iter_{self.curr_iter}_data_real.pt")) and osp.isfile(osp.join(self._save_dir, "data_real.pt")) and osp.isfile( osp.join(self._save_dir, "rollouts_real.pkl"))): # Rollout files do exist (can be when continuing a previous experiment) self._curr_data_real = pyrado.load( "data_real.pt", self._save_dir, prefix=f"iter_{self.curr_iter}") print_cbt( f"Loaded existing rollout data for iteration {self.curr_iter}.", "w") else: # If the policy depends on the domain-parameters, reset the policy with the # most likely dp-params from the previous round. pyrado.load( "policy.pt", self._save_dir, prefix=f"iter_{self._curr_iter - 1}" if self.curr_iter != 0 else "init", obj=self._policy, ) if self.curr_iter != 0: ml_domain_param = pyrado.load( "ml_domain_param.pkl", self.save_dir, prefix=f"iter_{self._curr_iter - 1}") self._policy.reset(**dict(domain_param=ml_domain_param)) # Rollout files do not exist yet (usual case) self._curr_data_real, _ = SBIBase.collect_data_real( self.save_dir, self._env_real, self._policy, self._embedding, prefix=f"iter_{self._curr_iter}", num_rollouts=self.num_real_rollouts, num_segments=self.num_segments, len_segments=self.len_segments, ) # Save the target domain data if self._curr_iter == 0: # Append the first set of data pyrado.save(self._curr_data_real, "data_real.pt", self._save_dir) else: # Append and save all data prev_data = pyrado.load("data_real.pt", self._save_dir) data_real_hist = to.cat([prev_data, self._curr_data_real], dim=0) pyrado.save(data_real_hist, "data_real.pt", self._save_dir) # Initialize sbi simulator and prior self._setup_sbi( prior=self._sbi_prior, rollouts_real=pyrado.load("rollouts_real.pkl", self._save_dir, prefix=f"iter_{self._curr_iter}"), ) self.reached_checkpoint() # setting counter to 1 if self.curr_checkpoint == 1: # Instantiate the sbi subroutine to retrain from scratch each iteration if self.reset_sbi_routine_each_iter: self._initialize_subrtn_sbi( subrtn_sbi_class=SNPE_A, num_components=self._num_components) # Initialize the proposal with the prior proposal = self._sbi_prior # Multi-round sbi for idx_r in range(self.num_sbi_rounds): # Sample parameters proposal, and simulate these parameters to obtain the data domain_param, data_sim = simulate_for_sbi( simulator=self._sbi_simulator, proposal=proposal, num_simulations=self.num_sim_per_round, simulation_batch_size=self.simulation_batch_size, num_workers=self.num_workers, ) self._cnt_samples += self.num_sim_per_round * self._env_sim_sbi.max_steps # Append simulations 
and proposals for sbi self._subrtn_sbi.append_simulations( domain_param, data_sim, proposal= proposal, # do not pass proposal arg for SNLE or SNRE ) # Train the posterior density_estimator = self._subrtn_sbi.train( final_round=idx_r == self.num_sbi_rounds - 1, component_perturbation=self._component_perturbation, **self.subrtn_sbi_training_hparam, ) posterior = self._subrtn_sbi.build_posterior( density_estimator=density_estimator, **self.subrtn_sbi_sampling_hparam) # Save the posterior of this iteration before tailoring it to the data (when it is still amortized) if idx_r == 0: pyrado.save( posterior, "posterior.pt", self._save_dir, prefix=f"iter_{self._curr_iter}", ) # Set proposal of the next round to focus on the next data set. # set_default_x() expects dim [1, num_rollouts * data_samples] proposal = posterior.set_default_x(self._curr_data_real) # Save the posterior tailored to each round pyrado.save( posterior, "posterior.pt", self._save_dir, prefix=f"iter_{self._curr_iter}_round_{idx_r}", ) # Override the latest posterior pyrado.save(posterior, "posterior.pt", self._save_dir) self.reached_checkpoint() # setting counter to 2 if self.curr_checkpoint == 2: # Logging (the evaluation can be time-intensive) posterior = pyrado.load("posterior.pt", self._save_dir) self._curr_domain_param_eval, log_probs = SBIBase.eval_posterior( posterior, self._curr_data_real, self.num_eval_samples, calculate_log_probs=True, normalize_posterior=self.normalize_posterior, subrtn_sbi_sampling_hparam=self.subrtn_sbi_sampling_hparam, ) self.logger.add_value("avg log prob", to.mean(log_probs), 4) self.logger.add_value("num total samples", self._cnt_samples) # Extract the most likely domain parameter set out of all target domain data sets current_domain_param = self._env_sim_sbi.domain_param idx_ml = to.argmax(log_probs).item() dp_vals = self._curr_domain_param_eval[idx_ml // self.num_eval_samples, idx_ml % self.num_eval_samples, :] dp_vals = to.atleast_1d(dp_vals).numpy() ml_domain_param = dict( zip(self.dp_mapping.values(), dp_vals.tolist())) # Update the unchanged domain parameters with the most likely ones obtained from the posterior current_domain_param.update(ml_domain_param) pyrado.save(current_domain_param, "ml_domain_param.pkl", self.save_dir, prefix=f"iter_{self._curr_iter}") self.reached_checkpoint() # setting counter to 3 if self.curr_checkpoint == 3: # Policy optimization if self._subrtn_policy is not None: pyrado.load( "policy.pt", self._save_dir, prefix=f"iter_{self._curr_iter - 1}" if self.curr_iter != 0 else "init", obj=self._policy, ) # Train the behavioral policy using the posterior samples obtained before. # Repeat the training if the resulting policy did not exceed the success threshold. print_cbt( "Training the next policy using domain parameter sets sampled from the current posterior.", "c") wrapped_trn_fcn = until_thold_exceeded( self.thold_succ_subrtn, self.max_subrtn_rep)(self.train_policy_sim) wrapped_trn_fcn(self._curr_domain_param_eval.squeeze(0), prefix=f"iter_{self._curr_iter}", use_rec_init_states=True) else: # save prefixed policy either way pyrado.save(self.policy, "policy.pt", self.save_dir, prefix=f"iter_{self._curr_iter}", use_state_dict=True) self.reached_checkpoint() # setting counter to 0 # Save snapshot data self.make_snapshot(snapshot_mode, None, meta_info)
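# Self-contained illustration of the multi-round sbi loop from checkpoint 1
# above, with a toy Gaussian simulator standing in for pyrado's wrapped
# environment (the toy simulator, observation, and hyper-parameters are
# assumptions for demonstration only).
import torch as to
from sbi.inference import SNPE_A, simulate_for_sbi
from sbi.utils import BoxUniform

prior = BoxUniform(low=to.zeros(2), high=to.ones(2))
simulator = lambda dp: dp + 0.05 * to.randn_like(dp)  # toy stand-in for simulated rollout data
x_obs = to.tensor([[0.4, 0.6]])  # stand-in for the real-world observation

inference = SNPE_A(prior=prior)
proposal = prior
num_rounds = 2
for idx_r in range(num_rounds):
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=200)
    density_estimator = inference.append_simulations(theta, x, proposal=proposal).train(
        final_round=idx_r == num_rounds - 1
    )
    posterior = inference.build_posterior(density_estimator)
    proposal = posterior.set_default_x(x_obs)  # focus the next round on the observed data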
def step(self, snapshot_mode: str = 'latest', meta_info: dict = None):
    # Save snapshot to save the correct iteration count
    self.save_snapshot()

    if self.curr_checkpoint == 0:
        if self._curr_iter == 0:
            # First iteration, use the policy parameters (initialized from a prior)
            cand = self._subrtn_distr.policy.transform_to_ddp_space(self._subrtn_distr.policy.param_values)
            self.cands = cand.unsqueeze(0)
        else:
            # Select the latest domain distribution parameter set
            assert isinstance(self.cands, to.Tensor)
            cand = self.cands[-1, :].clone()
        print_cbt(f'Current domain distribution parameters: {cand.detach().cpu().numpy()}', 'g')

        # Train and evaluate the behavioral policy, repeat if the policy did not exceed the success threshold
        wrapped_trn_fcn = until_thold_exceeded(
            self.thold_succ_subrtn.item(), self.max_subrtn_rep
        )(self.train_policy_sim)
        wrapped_trn_fcn(cand, prefix=f'iter_{self._curr_iter}')

        # Save the latest behavioral policy
        self._subrtn_policy.save_snapshot()
        self.reached_checkpoint()  # setting counter to 1

    if self.curr_checkpoint == 1:
        # Evaluate the current policy in the target domain
        policy = pyrado.load(self.policy, 'policy', 'pt', self.save_dir,
                             meta_info=dict(prefix=f'iter_{self._curr_iter}'))
        self.eval_behav_policy(self.save_dir, self._env_real, policy, f'iter_{self._curr_iter}',
                               self.num_eval_rollouts, None)
        # if self._curr_iter == 0:
        #     # First iteration, also evaluate the random initialization
        #     self.cands_values = SimOpt.eval_ddp_policy(
        #         rollouts_real, self._env_sim, self.num_eval_rollouts, self._subrtn_distr, self._subrtn_policy
        #     )
        #     self.cands_values = to.tensor(self.cands_values).unsqueeze(0)
        self.reached_checkpoint()  # setting counter to 2

    if self.curr_checkpoint == 2:
        # Train and evaluate the policy that represents the domain parameter distribution
        rollouts_real = pyrado.load(None, 'rollouts_real', 'pkl', self.save_dir,
                                    meta_info=dict(prefix=f'iter_{self._curr_iter}'))
        curr_cand_value = self.train_ddp_policy(rollouts_real, prefix=f'iter_{self._curr_iter}')
        if self._curr_iter == 0:
            self.cands_values = to.tensor(curr_cand_value).unsqueeze(0)
        else:
            self.cands_values = to.cat([self.cands_values, to.tensor(curr_cand_value).unsqueeze(0)], dim=0)
        pyrado.save(self.cands_values, 'candidates_values', 'pt', self.save_dir, meta_info)

        # The next candidate is the current search distribution, not the best policy parameter set (saved, too)
        next_cand = self._subrtn_distr.policy.transform_to_ddp_space(self._subrtn_distr.policy.param_values)
        self.cands = to.cat([self.cands, next_cand.unsqueeze(0)], dim=0)
        pyrado.save(self.cands, 'candidates', 'pt', self.save_dir, meta_info)

        # Save the latest domain distribution parameter policy
        self._subrtn_distr.save_snapshot(meta_info=dict(prefix='ddp', rollouts_real=rollouts_real))
        self.reached_checkpoint()  # setting counter to 0
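# Hedged sketch (an assumption, not pyrado's actual base-class code) of the
# checkpoint mechanism shared by the `step()` methods above: a counter that is
# persisted with each snapshot selects which branch to (re-)execute, so an
# interrupted experiment resumes at the last incomplete stage after loading.
class CheckpointedAlgorithmSketch:
    def __init__(self, init_checkpoint: int, num_checkpoints: int):
        # e.g. init_checkpoint = -2 for BayRn's two one-time initialization stages
        self.curr_checkpoint = init_checkpoint
        self._num_checkpoints = num_checkpoints

    def reached_checkpoint(self):
        # Advance to the next stage; wrap around to 0 (not to the negative
        # initialization stages, which run only once) after the last one
        self.curr_checkpoint += 1
        if self.curr_checkpoint >= self._num_checkpoints:
            self.curr_checkpoint = 0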