def _load_experiment(ex_dir: pyrado.PathLike):
    # Load the algorithm
    algo = Algorithm.load_snapshot(ex_dir)
    if not isinstance(algo, (NPDR, BayesSim)):
        raise pyrado.TypeErr(given=algo, expected_type=(NPDR, BayesSim))

    # Load the prior and the data
    prior = pyrado.load("prior.pt", ex_dir)
    data_real = pyrado.load("data_real.pt", ex_dir)

    # Load the posteriors
    posteriors = [SBIBase.load_posterior(ex_dir, idx_round=i, verbose=True) for i in range(algo.num_sbi_rounds)]
    posteriors = remove_none_from_list(posteriors)  # in case the algorithm terminated early

    if data_real.shape[0] > len(posteriors):
        print_cbt(
            f"Found {data_real.shape[0]} data sets but {len(posteriors)} posteriors. Truncated the superfluous data.",
            "y",
        )
        data_real = data_real[: len(posteriors), :]

    # Artificially repeat the data (which was the same for every round) to later be able to use the same code
    data_real = data_real.repeat(len(posteriors), 1)
    assert data_real.shape[0] == len(posteriors)

    return algo, prior, data_real, posteriors
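
# Hypothetical usage sketch of the helper above. The experiment directory is a placeholder;
# pyrado.TEMP_DIR is only reused here for illustration (any valid experiment dir works).
if __name__ == "__main__":
    ex_dir = os.path.join(pyrado.TEMP_DIR, "some_env", "npdr_experiment", "")  # placeholder path
    algo, prior, data_real, posteriors = _load_experiment(ex_dir)
    print(f"Loaded {len(posteriors)} posteriors and real data of shape {tuple(data_real.shape)}.")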
    raise pyrado.TypeErr(given=algo, expected_type=NPDR)
env_sim = inner_env(pyrado.load("env_sim.pkl", ex_dir_npdr))
prior_npdr = pyrado.load("prior.pt", ex_dir_npdr)
posterior_npdr = algo.load_posterior(ex_dir_npdr, idx_iter=0, idx_round=6, obj=None, verbose=True)  # CHOICE
data_real_npdr = pyrado.load("data_real.pt", ex_dir_npdr, prefix="iter_0", verbose=True)  # CHOICE
domain_params_npdr, log_probs = SBIBase.eval_posterior(
    posterior_npdr,
    data_real_npdr,
    args.num_samples,
    normalize_posterior=False,  # not necessary here
    subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
)
# Collapse the per-data-set axis of the [num_data_sets, num_samples, dim_dp] samples into one batch
domain_params_posterior_npdr = domain_params_npdr.reshape(1, -1, domain_params_npdr.shape[-1]).squeeze()

# BayesSim
ex_dir_bs = os.path.join(pyrado.TEMP_DIR, "mg-ik", "bayessim_time", "")
algo = Algorithm.load_snapshot(ex_dir_bs)
if not isinstance(algo, BayesSim):
    raise pyrado.TypeErr(given=algo, expected_type=BayesSim)
posterior_bs = algo.load_posterior(ex_dir_bs, idx_iter=0, idx_round=0, obj=None,
    nrows=1, ncols=3, figsize=pyrado.figsize_CoRL_6perrow_square  # , constrained_layout=True
)

for idx, (posterior, data) in enumerate(zip(posteriors, data_real)):
    # Only plot the selected rounds
    if idx not in config["sel_rounds"]:
        continue

    if args.mode == "scatter":
        # Sample from the posterior
        domain_params, log_probs = SBIBase.eval_posterior(
            posterior,
            data.unsqueeze(0),
            args.num_samples,
            normalize_posterior=False,  # not necessary here
            subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
        )
        domain_params = domain_params.squeeze(0)

        # Plot
        color_palette = sns.color_palette()[1:]
        _ = draw_posterior_scatter_2d(
            ax=axs[ax_cnt],
            dp_samples=[domain_params],
            dp_mapping=algo.dp_mapping,
            dims=(0, 1),
            prior=prior,
            env_sim=None,
            env_real=algo._env_real,
# Use the environment's number of steps in case of the default argument (inf)
max_steps = env.max_steps if args.max_steps == pyrado.inf else args.max_steps

# Check which algorithm was used in the experiment
algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name="algo")
if not isinstance(algo, (NPDR, BayesSim)):
    raise pyrado.TypeErr(given=algo, expected_type=(NPDR, BayesSim))

# Sample domain parameters from the posterior. Use all samples by hijacking get_ml_posterior_samples()
# to obtain them sorted.
domain_params, log_probs = SBIBase.get_ml_posterior_samples(
    dp_mapping=algo.dp_mapping,
    posterior=kwout["posterior"],
    data_real=data_real,
    num_eval_samples=args.num_samples,
    num_ml_samples=args.num_samples,
    calculate_log_probs=True,
    normalize_posterior=args.normalize,
    subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
    return_as_tensor=False,
)
assert len(domain_params) == 1  # the list has as many elements as evaluated iterations
domain_params = domain_params[0]

if args.normalize:
    # If the posterior is normalized, we do not rescale the probabilities since they already sum to 1
    probs = to.exp(log_probs)
else:
    # If the posterior is not normalized, we rescale the probabilities to make them interpretable
    probs = to.exp(log_probs -
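# A common, numerically stable way to rescale unnormalized log-probabilities is a softmax
# over all samples. This is only a sketch of the idea, not a reconstruction of the
# truncated line above:
#   probs = to.exp(log_probs - to.logsumexp(log_probs, dim=-1, keepdim=True))
# Subtracting the log-normalizer before exponentiating avoids numerical overflow.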
def step(self, snapshot_mode: str = "latest", meta_info: dict = None):
    # Save snapshot to save the correct iteration count
    self.save_snapshot()

    if self.curr_checkpoint == -1:
        if self._subrtn_policy is not None and self._train_initial_policy:
            # Add dummy values for variables that are logged later
            self.logger.add_value("avg log prob", -pyrado.inf)

            # Train the behavioral policy using the samples obtained from the prior.
            # Repeat the training if the resulting policy did not exceed the success threshold.
            domain_params = self._sbi_prior.sample(sample_shape=(self.num_eval_samples,))
            print_cbt("Training the initial policy using domain parameter sets sampled from the prior.", "c")
            wrapped_trn_fcn = until_thold_exceeded(self.thold_succ_subrtn, self.max_subrtn_rep)(self.train_policy_sim)
            wrapped_trn_fcn(domain_params, prefix="init", use_rec_init_states=False)  # overrides policy.pt

        self.reached_checkpoint()  # setting counter to 0

    if self.curr_checkpoint == 0:
        # Check if the rollout files already exist
        if (
            osp.isfile(osp.join(self._save_dir, f"iter_{self.curr_iter}_data_real.pt"))
            and osp.isfile(osp.join(self._save_dir, "data_real.pt"))
            and osp.isfile(osp.join(self._save_dir, "rollouts_real.pkl"))
        ):
            # Rollout files do exist (can be the case when continuing a previous experiment)
            self._curr_data_real = pyrado.load("data_real.pt", self._save_dir, prefix=f"iter_{self.curr_iter}")
            print_cbt(f"Loaded existing rollout data for iteration {self.curr_iter}.", "w")

        else:
            # If the policy depends on the domain parameters, reset the policy with the
            # most likely domain parameters from the previous round
            pyrado.load(
                "policy.pt",
                self._save_dir,
                prefix=f"iter_{self._curr_iter - 1}" if self.curr_iter != 0 else "init",
                obj=self._policy,
            )
            if self.curr_iter != 0:
                ml_domain_param = pyrado.load(
                    "ml_domain_param.pkl", self.save_dir, prefix=f"iter_{self._curr_iter - 1}"
                )
                self._policy.reset(**dict(domain_param=ml_domain_param))

            # Rollout files do not exist yet (the usual case)
            self._curr_data_real, _ = SBIBase.collect_data_real(
                self.save_dir,
                self._env_real,
                self._policy,
                self._embedding,
                prefix=f"iter_{self._curr_iter}",
                num_rollouts=self.num_real_rollouts,
                num_segments=self.num_segments,
                len_segments=self.len_segments,
            )

            # Save the target domain data
            if self._curr_iter == 0:
                # Save the first set of data
                pyrado.save(self._curr_data_real, "data_real.pt", self._save_dir)
            else:
                # Append the new data and save the full history
                prev_data = pyrado.load("data_real.pt", self._save_dir)
                data_real_hist = to.cat([prev_data, self._curr_data_real], dim=0)
                pyrado.save(data_real_hist, "data_real.pt", self._save_dir)

        # Initialize the sbi simulator and prior
        self._setup_sbi(
            prior=self._sbi_prior,
            rollouts_real=pyrado.load("rollouts_real.pkl", self._save_dir, prefix=f"iter_{self._curr_iter}"),
        )

        self.reached_checkpoint()  # setting counter to 1

    if self.curr_checkpoint == 1:
        # Instantiate the sbi subroutine to retrain from scratch each iteration
        if self.reset_sbi_routine_each_iter:
            self._initialize_subrtn_sbi(subrtn_sbi_class=SNPE_A, num_components=self._num_components)

        # Initialize the proposal with the prior
        proposal = self._sbi_prior

        # Multi-round sbi
        for idx_r in range(self.num_sbi_rounds):
            # Sample parameters from the proposal, and simulate these parameters to obtain the data
            domain_param, data_sim = simulate_for_sbi(
                simulator=self._sbi_simulator,
                proposal=proposal,
                num_simulations=self.num_sim_per_round,
                simulation_batch_size=self.simulation_batch_size,
                num_workers=self.num_workers,
            )
            self._cnt_samples += self.num_sim_per_round * self._env_sim_sbi.max_steps

            # Append simulations and proposals for sbi
            self._subrtn_sbi.append_simulations(
                domain_param,
                data_sim,
                proposal=proposal,  # do not pass the proposal argument for SNLE or SNRE
            )

            # Train the posterior
            density_estimator = self._subrtn_sbi.train(
                final_round=idx_r == self.num_sbi_rounds - 1,
                component_perturbation=self._component_perturbation,
                **self.subrtn_sbi_training_hparam,
            )
            posterior = self._subrtn_sbi.build_posterior(
                density_estimator=density_estimator, **self.subrtn_sbi_sampling_hparam
            )

            # Save the posterior of this iteration before tailoring it to the data (while it is still amortized)
            if idx_r == 0:
                pyrado.save(
                    posterior,
                    "posterior.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter}",
                )

            # Set the proposal of the next round to focus on the current data set.
            # set_default_x() expects dim [1, num_rollouts * data_samples]
            proposal = posterior.set_default_x(self._curr_data_real)

            # Save the posterior tailored to this round
            pyrado.save(
                posterior,
                "posterior.pt",
                self._save_dir,
                prefix=f"iter_{self._curr_iter}_round_{idx_r}",
            )

            # Override the latest posterior
            pyrado.save(posterior, "posterior.pt", self._save_dir)

        self.reached_checkpoint()  # setting counter to 2

    if self.curr_checkpoint == 2:
        # Logging (the evaluation can be time-intensive)
        posterior = pyrado.load("posterior.pt", self._save_dir)
        self._curr_domain_param_eval, log_probs = SBIBase.eval_posterior(
            posterior,
            self._curr_data_real,
            self.num_eval_samples,
            calculate_log_probs=True,
            normalize_posterior=self.normalize_posterior,
            subrtn_sbi_sampling_hparam=self.subrtn_sbi_sampling_hparam,
        )
        self.logger.add_value("avg log prob", to.mean(log_probs), 4)
        self.logger.add_value("num total samples", self._cnt_samples)

        # Extract the most likely domain parameter set out of all target domain data sets.
        # log_probs has shape [num_data_sets, num_eval_samples], so the flat argmax index
        # is mapped back to a 2D index via integer division and modulo.
        current_domain_param = self._env_sim_sbi.domain_param
        idx_ml = to.argmax(log_probs).item()
        dp_vals = self._curr_domain_param_eval[idx_ml // self.num_eval_samples, idx_ml % self.num_eval_samples, :]
        dp_vals = to.atleast_1d(dp_vals).numpy()
        ml_domain_param = dict(zip(self.dp_mapping.values(), dp_vals.tolist()))

        # Update the unchanged domain parameters with the most likely ones obtained from the posterior
        current_domain_param.update(ml_domain_param)
        pyrado.save(current_domain_param, "ml_domain_param.pkl", self.save_dir, prefix=f"iter_{self._curr_iter}")

        self.reached_checkpoint()  # setting counter to 3

    if self.curr_checkpoint == 3:
        # Policy optimization
        if self._subrtn_policy is not None:
            pyrado.load(
                "policy.pt",
                self._save_dir,
                prefix=f"iter_{self._curr_iter - 1}" if self.curr_iter != 0 else "init",
                obj=self._policy,
            )
            # Train the behavioral policy using the posterior samples obtained before.
            # Repeat the training if the resulting policy did not exceed the success threshold.
            print_cbt("Training the next policy using domain parameter sets sampled from the current posterior.", "c")
            wrapped_trn_fcn = until_thold_exceeded(self.thold_succ_subrtn, self.max_subrtn_rep)(self.train_policy_sim)
            wrapped_trn_fcn(
                self._curr_domain_param_eval.squeeze(0),
                prefix=f"iter_{self._curr_iter}",
                use_rec_init_states=True,
            )
        else:
            # Save the prefixed policy anyway
            pyrado.save(
                self.policy, "policy.pt", self.save_dir, prefix=f"iter_{self._curr_iter}", use_state_dict=True
            )

        self.reached_checkpoint()  # setting counter to 0

    # Save snapshot data
    self.make_snapshot(snapshot_mode, None, meta_info)
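
# Minimal self-contained sketch of the multi-round SNPE pattern implemented in the loop
# above, stripped of all pyrado bookkeeping. The toy simulator, the constants, and the
# observed data are made up; the sbi calls mirror step() and assume the sbi v0.17-style API.
import torch as to
from sbi.inference import SNPE, prepare_for_sbi, simulate_for_sbi
from sbi.utils import BoxUniform


def _toy_simulator(theta: to.Tensor) -> to.Tensor:
    # Observation = parameter plus Gaussian noise; stands in for a full simulator rollout
    return theta + 0.1 * to.randn_like(theta)


prior = BoxUniform(low=-2 * to.ones(2), high=2 * to.ones(2))
simulator, prior = prepare_for_sbi(_toy_simulator, prior)
inference = SNPE(prior=prior)

x_o = to.tensor([[0.5, -0.3]])  # stand-in for the real-world data
proposal = prior
for _ in range(3):  # num_sbi_rounds
    theta, x = simulate_for_sbi(simulator, proposal, num_simulations=200)
    density_estimator = inference.append_simulations(theta, x, proposal=proposal).train()
    posterior = inference.build_posterior(density_estimator)
    # As in the loop above: focus the next round's proposal on the observed data
    proposal = posterior.set_default_x(x_o)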
if args.iter != -1:
    # Only load the selected iteration's rollouts
    rollouts_real = rollouts_real[args.iter * algo.num_real_rollouts : (args.iter + 1) * algo.num_real_rollouts]
num_rollouts_real = len(rollouts_real)
[ro.numpy() for ro in rollouts_real]  # convert every rollout's data to numpy in place; the list result is discarded

# Decide on the policy: either replay the exact recorded actions, or use the same policy,
# which is however observation-dependent
if args.use_rec:
    policy = PlaybackPolicy(env_sim.spec, [ro.actions for ro in rollouts_real], no_reset=True)

# Compute the most likely domain parameters for every target domain observation
domain_params_ml_all, _ = SBIBase.get_ml_posterior_samples(
    algo.dp_mapping,
    posterior,
    data_real,
    num_eval_samples=args.num_samples,
    num_ml_samples=num_ml_samples,
    normalize_posterior=args.normalize,
    subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
)

# Repeat the domain parameters to zip them later with the real rollouts, such that they all belong
# to the same iteration, e.g. repeat_interleave(["iter_0", "iter_1"], 2) -> ["iter_0", "iter_0", "iter_1", "iter_1"]
num_iter = len(domain_params_ml_all)
num_rep = num_rollouts_real // num_iter
domain_params_ml_all = repeat_interleave(domain_params_ml_all, num_rep)
assert len(domain_params_ml_all) == num_rollouts_real

# Split rollouts into segments
segments_real_all = []
for ro in rollouts_real:
    # Split the target domain rollout, see SimRolloutSamplerForSBI.__call__()
def test_pair_plot_scatter(
    env: SimEnv,
    policy: Policy,
    layout: str,
    labels: Optional[str],
    legend_labels: Optional[str],
    axis_limits: Optional[str],
    use_kde: bool,
    use_trafo: bool,
):
    def _simulator(dp: to.Tensor) -> to.Tensor:
        """The simplest interface of a simulation to sbi, using `env` and `policy` from the outer scope"""
        ro = rollout(env, policy, eval=True, reset_kwargs=dict(domain_param=dict(m=dp[0], k=dp[1], d=dp[2])))
        observation_sim = to.from_numpy(ro.observations[-1]).to(dtype=to.float32)
        return to.atleast_2d(observation_sim)

    # Fix the init state
    env.init_space = SingularStateSpace(env.init_space.sample_uniform())
    env_real = deepcopy(env)
    env_real.domain_param = {"mass": 0.8, "stiffness": 15, "d": 0.7}

    # Optionally transformed domain parameters for inference
    if use_trafo:
        env = LogDomainParamTransform(env, mask=["stiffness"])

    # Domain parameter mapping and prior
    dp_mapping = {0: "mass", 1: "stiffness", 2: "d"}
    k_low = np.log(10) if use_trafo else 10
    k_up = np.log(20) if use_trafo else 20
    prior = sbiutils.BoxUniform(low=to.tensor([0.5, k_low, 0.2]), high=to.tensor([1.5, k_up, 0.8]))

    # Train a neural posterior (SNPE) on data from the simulator
    density_estimator = sbiutils.posterior_nn(model="maf", hidden_features=10, num_transforms=3)
    snpe = SNPE(prior, density_estimator)
    simulator, prior = prepare_for_sbi(_simulator, prior)
    domain_param, data_sim = simulate_for_sbi(simulator=simulator, proposal=prior, num_simulations=50, num_workers=1)
    snpe.append_simulations(domain_param, data_sim)
    density_estimator = snpe.train(max_num_epochs=5)
    posterior = snpe.build_posterior(density_estimator)

    # Create a fake (randomly perturbed) true domain parameter
    domain_param_gt = to.tensor([env_real.domain_param[dp_mapping[key]] for key in sorted(dp_mapping.keys())])
    domain_param_gt += domain_param_gt * to.randn(len(dp_mapping)) / 10
    domain_param_gt = domain_param_gt.unsqueeze(0)
    data_real = simulator(domain_param_gt)

    domain_params, log_probs = SBIBase.eval_posterior(
        posterior,
        data_real,
        num_samples=6,
        normalize_posterior=False,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=False),
    )
    dp_samples = [domain_params.reshape(1, -1, domain_params.shape[-1]).squeeze()]

    if layout == "inside":
        num_rows, num_cols = len(dp_mapping), len(dp_mapping)
    else:
        num_rows, num_cols = len(dp_mapping) + 1, len(dp_mapping) + 1

    _, axs = plt.subplots(num_rows, num_cols, figsize=(8, 8), tight_layout=True)
    fig = draw_posterior_pairwise_scatter(
        axs=axs,
        dp_samples=dp_samples,
        dp_mapping=dp_mapping,
        prior=prior if axis_limits == "use_prior" else None,
        env_sim=env,
        env_real=env_real,
        axis_limits=axis_limits,
        marginal_layout=layout,
        labels=labels,
        legend_labels=legend_labels,
        use_kde=use_kde,
    )
    assert fig is not None
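
# Note on the sbi glue used in the test above (assuming the sbi v0.17-style API):
# prepare_for_sbi() wraps a single-sample simulator so that simulate_for_sbi() can call it
# with a batch of parameter sets, and returns the (possibly adapted) prior. A minimal,
# standalone sanity check of that behavior, using a made-up identity simulator:
def _check_prepare_for_sbi_batching():
    import torch as to
    from sbi.inference import prepare_for_sbi
    from sbi.utils import BoxUniform

    sim, box_prior = prepare_for_sbi(lambda th: th.clone(), BoxUniform(low=to.zeros(3), high=to.ones(3)))
    theta_test = box_prior.sample((4,))
    assert sim(theta_test).shape == (4, 3)  # one observation per parameter set in the batch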
env_sim, policy, kwout = load_experiment(ex_dir, args)
env_real = pyrado.load("env_real.pkl", ex_dir)
prior = kwout["prior"]
posterior = kwout["posterior"]
data_real = kwout["data_real"]

if args.mode.lower() == "evolution-round" and args.iter == -1:
    args.iter = algo.curr_iter
    print_cbt("Set the evaluation iteration to the latest iteration of the algorithm.", "y")

# Load the sequence of posteriors if desired
if args.mode.lower() == "evolution-iter":
    posterior = [SBIBase.load_posterior(ex_dir, idx_iter=i, verbose=True) for i in range(algo.max_iter)]
    posterior = remove_none_from_list(posterior)  # in case the algorithm terminated early
elif args.mode.lower() == "evolution-round":
    posterior = [SBIBase.load_posterior(ex_dir, idx_round=i, verbose=True) for i in range(algo.num_sbi_rounds)]
    posterior = remove_none_from_list(posterior)  # in case the algorithm terminated early

if "evolution" in args.mode.lower() and data_real.shape[0] > len(posterior):
    print_cbt(