示例#1
0
def _load_experiment(ex_dir: pyrado.PathLike):
    """
    Load the algorithm, prior, target domain data, and per-round posteriors from an experiment directory.

    :param ex_dir: experiment directory to load from
    :return: tuple of (algorithm, prior, real-world data tensor, list of posteriors)
    """
    # Restore the algorithm snapshot and make sure it is an sbi-based one
    algo = Algorithm.load_snapshot(ex_dir)
    if not isinstance(algo, (NPDR, BayesSim)):
        raise pyrado.TypeErr(given=algo, expected_type=(NPDR, BayesSim))

    # Load the prior and the recorded target domain data
    prior = pyrado.load("prior.pt", ex_dir)
    data_real = pyrado.load("data_real.pt", ex_dir)

    # Collect one posterior per sbi round; rounds that were never reached yield None
    posteriors = []
    for idx_round in range(algo.num_sbi_rounds):
        posteriors.append(SBIBase.load_posterior(ex_dir, idx_round=idx_round, verbose=True))
    posteriors = remove_none_from_list(posteriors)  # in case the algorithm terminated early

    num_posteriors = len(posteriors)
    if data_real.shape[0] > num_posteriors:
        print_cbt(
            f"Found {data_real.shape[0]} data sets but {num_posteriors} posteriors. Truncated the superfluous data.",
            "y",
        )
        data_real = data_real[:num_posteriors, :]

    # Artificially repeat the data (which was the same for every round) to later be able to use the same code
    data_real = data_real.repeat(num_posteriors, 1)
    assert data_real.shape[0] == num_posteriors

    return algo, prior, data_real, posteriors
示例#2
0
        raise pyrado.TypeErr(given=algo, expected_type=NPDR)
    env_sim = inner_env(pyrado.load("env_sim.pkl", ex_dir_npdr))
    prior_npdr = pyrado.load("prior.pt", ex_dir_npdr)
    posterior_npdr = algo.load_posterior(ex_dir_npdr,
                                         idx_iter=0,
                                         idx_round=6,
                                         obj=None,
                                         verbose=True)  # CHOICE
    data_real_npdr = pyrado.load(f"data_real.pt",
                                 ex_dir_npdr,
                                 prefix="iter_0",
                                 verbose=True)  # CHOICE
    domain_params_npdr, log_probs = SBIBase.eval_posterior(
        posterior_npdr,
        data_real_npdr,
        args.num_samples,
        normalize_posterior=False,  # not necessary here
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
    )
    domain_params_posterior_npdr = domain_params_npdr.reshape(
        1, -1, domain_params_npdr.shape[-1]).squeeze()

    # Bayessim
    ex_dir_bs = os.path.join(pyrado.TEMP_DIR, "mg-ik", "bayessim_time", "")
    algo = Algorithm.load_snapshot(ex_dir_bs)
    if not isinstance(algo, BayesSim):
        raise pyrado.TypeErr(given=algo, expected_type=BayesSim)
    posterior_bs = algo.load_posterior(ex_dir_bs,
                                       idx_iter=0,
                                       idx_round=0,
                                       obj=None,
示例#3
0
            nrows=1,
            ncols=3,
            figsize=pyrado.
            figsize_CoRL_6perrow_square  # , constrained_layout=True
        )
        for idx, (posterior, data) in enumerate(zip(posteriors, data_real)):
            # Select round or not
            if idx not in config["sel_rounds"]:
                continue

            if args.mode == "scatter":
                # Sample from the posterior
                domain_params, log_probs = SBIBase.eval_posterior(
                    posterior,
                    data.unsqueeze(0),
                    args.num_samples,
                    normalize_posterior=False,  # not necessary here
                    subrtn_sbi_sampling_hparam=dict(
                        sample_with_mcmc=args.use_mcmc),
                )
                domain_params = domain_params.squeeze(0)

                # Plot
                color_palette = sns.color_palette()[1:]
                _ = draw_posterior_scatter_2d(
                    ax=axs[ax_cnt],
                    dp_samples=[domain_params],
                    dp_mapping=algo.dp_mapping,
                    dims=(0, 1),
                    prior=prior,
                    env_sim=None,
                    env_real=algo._env_real,
    # Use the environments number of steps in case of the default argument (inf)
    max_steps = env.max_steps if args.max_steps == pyrado.inf else args.max_steps

    # Check which algorithm was used in the experiment
    algo = Algorithm.load_snapshot(load_dir=ex_dir, load_name="algo")
    if not isinstance(algo, (NPDR, BayesSim)):
        raise pyrado.TypeErr(given=algo, expected_type=(NPDR, BayesSim))

    # Sample domain parameters from the posterior. Use all samples, by hijacking the get_ml_posterior_samples to obtain
    # them sorted.
    domain_params, log_probs = SBIBase.get_ml_posterior_samples(
        dp_mapping=algo.dp_mapping,
        posterior=kwout["posterior"],
        data_real=data_real,
        num_eval_samples=args.num_samples,
        num_ml_samples=args.num_samples,
        calculate_log_probs=True,
        normalize_posterior=args.normalize,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
        return_as_tensor=False,
    )
    assert len(domain_params
               ) == 1  # the list has as many elements as evaluated iterations
    domain_params = domain_params[0]

    if args.normalize:
        # If the posterior is normalized, we do not rescale the probabilities since they already sum to 1
        probs = to.exp(log_probs)
    else:
        # If the posterior is not normalized, we rescale the probabilities to make them interpretable
        probs = to.exp(log_probs -
示例#5
0
    def step(self, snapshot_mode: str = "latest", meta_info: dict = None):
        """
        Run one iteration of the algorithm as a sequence of checkpoints.

        The checkpoints are visited in order within a single call (plain `if`s, no `elif`,
        and `reached_checkpoint()` advances the counter — see the inline comments):
        -1: optionally train an initial policy on domain parameters sampled from the prior
         0: collect (or re-load) the real-world rollout data and set up the sbi simulator/prior
         1: run multi-round sbi training and save the resulting posteriors
         2: evaluate the posterior and extract the most likely domain parameter set
         3: train the next policy (or just save the current one)

        :param snapshot_mode: determines how the snapshot is saved at the end of the step
        :param meta_info: forwarded unchanged to `make_snapshot()`
        """
        # Save snapshot to save the correct iteration count
        self.save_snapshot()

        if self.curr_checkpoint == -1:
            if self._subrtn_policy is not None and self._train_initial_policy:
                # Add dummy values of variables that are logged later
                self.logger.add_value("avg log prob", -pyrado.inf)

                # Train the behavioral policy using the samples obtained from the prior.
                # Repeat the training if the resulting policy did not exceed the success threshold.
                domain_params = self._sbi_prior.sample(
                    sample_shape=(self.num_eval_samples, ))
                print_cbt(
                    "Training the initial policy using domain parameter sets sampled from prior.",
                    "c")
                wrapped_trn_fcn = until_thold_exceeded(
                    self.thold_succ_subrtn,
                    self.max_subrtn_rep)(self.train_policy_sim)
                wrapped_trn_fcn(
                    domain_params, prefix="init",
                    use_rec_init_states=False)  # overrides policy.pt

            self.reached_checkpoint()  # setting counter to 0

        if self.curr_checkpoint == 0:
            # Check if the rollout files already exist
            if (osp.isfile(
                    osp.join(self._save_dir,
                             f"iter_{self.curr_iter}_data_real.pt"))
                    and osp.isfile(osp.join(self._save_dir, "data_real.pt"))
                    and osp.isfile(
                        osp.join(self._save_dir, "rollouts_real.pkl"))):
                # Rollout files do exist (can be when continuing a previous experiment)
                self._curr_data_real = pyrado.load(
                    "data_real.pt",
                    self._save_dir,
                    prefix=f"iter_{self.curr_iter}")
                print_cbt(
                    f"Loaded existing rollout data for iteration {self.curr_iter}.",
                    "w")

            else:
                # If the policy depends on the domain-parameters, reset the policy with the
                # most likely dp-params from the previous round.
                pyrado.load(
                    "policy.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter - 1}"
                    if self.curr_iter != 0 else "init",
                    obj=self._policy,
                )
                if self.curr_iter != 0:
                    ml_domain_param = pyrado.load(
                        "ml_domain_param.pkl",
                        self.save_dir,
                        prefix=f"iter_{self._curr_iter - 1}")
                    self._policy.reset(**dict(domain_param=ml_domain_param))

                # Rollout files do not exist yet (usual case)
                self._curr_data_real, _ = SBIBase.collect_data_real(
                    self.save_dir,
                    self._env_real,
                    self._policy,
                    self._embedding,
                    prefix=f"iter_{self._curr_iter}",
                    num_rollouts=self.num_real_rollouts,
                    num_segments=self.num_segments,
                    len_segments=self.len_segments,
                )

                # Save the target domain data
                if self._curr_iter == 0:
                    # Append the first set of data
                    pyrado.save(self._curr_data_real, "data_real.pt",
                                self._save_dir)
                else:
                    # Append the current data to the history of all previous iterations and save it
                    prev_data = pyrado.load("data_real.pt", self._save_dir)
                    data_real_hist = to.cat([prev_data, self._curr_data_real],
                                            dim=0)
                    pyrado.save(data_real_hist, "data_real.pt", self._save_dir)

            # Initialize sbi simulator and prior
            self._setup_sbi(
                prior=self._sbi_prior,
                rollouts_real=pyrado.load("rollouts_real.pkl",
                                          self._save_dir,
                                          prefix=f"iter_{self._curr_iter}"),
            )

            self.reached_checkpoint()  # setting counter to 1

        if self.curr_checkpoint == 1:
            # Instantiate the sbi subroutine to retrain from scratch each iteration
            if self.reset_sbi_routine_each_iter:
                self._initialize_subrtn_sbi(
                    subrtn_sbi_class=SNPE_A,
                    num_components=self._num_components)

            # Initialize the proposal with the prior
            proposal = self._sbi_prior

            # Multi-round sbi
            for idx_r in range(self.num_sbi_rounds):
                # Sample parameters from the proposal, and simulate these parameters to obtain the data
                domain_param, data_sim = simulate_for_sbi(
                    simulator=self._sbi_simulator,
                    proposal=proposal,
                    num_simulations=self.num_sim_per_round,
                    simulation_batch_size=self.simulation_batch_size,
                    num_workers=self.num_workers,
                )
                # Bookkeeping: every simulated rollout contributes max_steps samples
                self._cnt_samples += self.num_sim_per_round * self._env_sim_sbi.max_steps

                # Append simulations and proposals for sbi
                self._subrtn_sbi.append_simulations(
                    domain_param,
                    data_sim,
                    proposal=
                    proposal,  # do not pass proposal arg for SNLE or SNRE
                )

                # Train the posterior
                density_estimator = self._subrtn_sbi.train(
                    final_round=idx_r == self.num_sbi_rounds - 1,
                    component_perturbation=self._component_perturbation,
                    **self.subrtn_sbi_training_hparam,
                )
                posterior = self._subrtn_sbi.build_posterior(
                    density_estimator=density_estimator,
                    **self.subrtn_sbi_sampling_hparam)

                # Save the posterior of this iteration before tailoring it to the data (when it is still amortized)
                if idx_r == 0:
                    pyrado.save(
                        posterior,
                        "posterior.pt",
                        self._save_dir,
                        prefix=f"iter_{self._curr_iter}",
                    )

                # Set proposal of the next round to focus on the next data set.
                # set_default_x() expects dim [1, num_rollouts * data_samples]
                proposal = posterior.set_default_x(self._curr_data_real)

                # Save the posterior tailored to each round
                pyrado.save(
                    posterior,
                    "posterior.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter}_round_{idx_r}",
                )

                # Override the latest posterior
                pyrado.save(posterior, "posterior.pt", self._save_dir)

            self.reached_checkpoint()  # setting counter to 2

        if self.curr_checkpoint == 2:
            # Logging (the evaluation can be time-intensive)
            posterior = pyrado.load("posterior.pt", self._save_dir)
            self._curr_domain_param_eval, log_probs = SBIBase.eval_posterior(
                posterior,
                self._curr_data_real,
                self.num_eval_samples,
                calculate_log_probs=True,
                normalize_posterior=self.normalize_posterior,
                subrtn_sbi_sampling_hparam=self.subrtn_sbi_sampling_hparam,
            )
            self.logger.add_value("avg log prob", to.mean(log_probs), 4)
            self.logger.add_value("num total samples", self._cnt_samples)

            # Extract the most likely domain parameter set out of all target domain data sets.
            # argmax() flattens log_probs, so the 2d index is recovered via div/mod below.
            current_domain_param = self._env_sim_sbi.domain_param
            idx_ml = to.argmax(log_probs).item()
            dp_vals = self._curr_domain_param_eval[idx_ml //
                                                   self.num_eval_samples,
                                                   idx_ml %
                                                   self.num_eval_samples, :]
            dp_vals = to.atleast_1d(dp_vals).numpy()
            ml_domain_param = dict(
                zip(self.dp_mapping.values(), dp_vals.tolist()))

            # Update the unchanged domain parameters with the most likely ones obtained from the posterior
            current_domain_param.update(ml_domain_param)
            pyrado.save(current_domain_param,
                        "ml_domain_param.pkl",
                        self.save_dir,
                        prefix=f"iter_{self._curr_iter}")

            self.reached_checkpoint()  # setting counter to 3

        if self.curr_checkpoint == 3:
            # Policy optimization
            if self._subrtn_policy is not None:
                # Warm-start from the previous iteration's policy (or the initial one for iter 0)
                pyrado.load(
                    "policy.pt",
                    self._save_dir,
                    prefix=f"iter_{self._curr_iter - 1}"
                    if self.curr_iter != 0 else "init",
                    obj=self._policy,
                )
                # Train the behavioral policy using the posterior samples obtained before.
                # Repeat the training if the resulting policy did not exceed the success threshold.
                print_cbt(
                    "Training the next policy using domain parameter sets sampled from the current posterior.",
                    "c")
                wrapped_trn_fcn = until_thold_exceeded(
                    self.thold_succ_subrtn,
                    self.max_subrtn_rep)(self.train_policy_sim)
                wrapped_trn_fcn(self._curr_domain_param_eval.squeeze(0),
                                prefix=f"iter_{self._curr_iter}",
                                use_rec_init_states=True)
            else:
                # save prefixed policy either way
                pyrado.save(self.policy,
                            "policy.pt",
                            self.save_dir,
                            prefix=f"iter_{self._curr_iter}",
                            use_state_dict=True)

            self.reached_checkpoint()  # setting counter to 0

        # Save snapshot data
        self.make_snapshot(snapshot_mode, None, meta_info)
    if args.iter != -1:
        # Only load the selected iteration's rollouts
        rollouts_real = rollouts_real[args.iter * algo.num_real_rollouts : (args.iter + 1) * algo.num_real_rollouts]
    num_rollouts_real = len(rollouts_real)
    [ro.numpy() for ro in rollouts_real]

    # Decide on the policy: either use the exact actions or use the same policy which is however observation-dependent
    if args.use_rec:
        policy = PlaybackPolicy(env_sim.spec, [ro.actions for ro in rollouts_real], no_reset=True)

    # Compute the most likely domain parameters for every target domain observation
    domain_params_ml_all, _ = SBIBase.get_ml_posterior_samples(
        algo.dp_mapping,
        posterior,
        data_real,
        num_eval_samples=args.num_samples,
        num_ml_samples=num_ml_samples,
        normalize_posterior=args.normalize,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=args.use_mcmc),
    )

    # Repeat the domain parameters to zip them later with the real rollouts, such that they all belong to the same iter
    num_iter = len(domain_params_ml_all)
    num_rep = num_rollouts_real // num_iter
    domain_params_ml_all = repeat_interleave(domain_params_ml_all, num_rep)
    assert len(domain_params_ml_all) == num_rollouts_real

    # Split rollouts into segments
    segments_real_all = []
    for ro in rollouts_real:
        # Split the target domain rollout, see SimRolloutSamplerForSBI.__call__()
示例#7
0
def test_pair_plot_scatter(
    env: SimEnv,
    policy: Policy,
    layout: str,
    labels: Optional[str],
    legend_labels: Optional[str],
    axis_limits: Optional[str],
    use_kde: bool,
    use_trafo: bool,
):
    """Train a tiny SNPE posterior on a toy setup and check that the pairwise scatter plot is created."""

    def _simulator(dp: to.Tensor) -> to.Tensor:
        """The most simple interface of a simulation to sbi, using `env` and `policy` from outer scope"""
        traj = rollout(
            env,
            policy,
            eval=True,
            reset_kwargs=dict(domain_param=dict(m=dp[0], k=dp[1], d=dp[2])),
        )
        last_obs = to.from_numpy(traj.observations[-1]).to(dtype=to.float32)
        return to.atleast_2d(last_obs)

    # Freeze the initial state so every rollout starts identically
    env.init_space = SingularStateSpace(env.init_space.sample_uniform())
    env_real = deepcopy(env)
    env_real.domain_param = {"mass": 0.8, "stiffness": 15, "d": 0.7}

    # Optionally transformed domain parameters for inference
    if use_trafo:
        env = LogDomainParamTransform(env, mask=["stiffness"])

    # Domain parameter mapping and prior
    dp_mapping = {0: "mass", 1: "stiffness", 2: "d"}
    stiffness_lo, stiffness_hi = (np.log(10), np.log(20)) if use_trafo else (10, 20)
    prior = sbiutils.BoxUniform(
        low=to.tensor([0.5, stiffness_lo, 0.2]),
        high=to.tensor([1.5, stiffness_hi, 0.8]),
    )

    # Learn a likelihood from the simulator
    density_estimator = sbiutils.posterior_nn(
        model="maf", hidden_features=10, num_transforms=3
    )
    snpe = SNPE(prior, density_estimator)
    simulator, prior = prepare_for_sbi(_simulator, prior)
    domain_param, data_sim = simulate_for_sbi(
        simulator=simulator, proposal=prior, num_simulations=50, num_workers=1
    )
    snpe.append_simulations(domain_param, data_sim)
    density_estimator = snpe.train(max_num_epochs=5)
    posterior = snpe.build_posterior(density_estimator)

    # Create a fake (random) true domain parameter
    domain_param_gt = to.tensor(
        [env_real.domain_param[dp_mapping[key]] for key in sorted(dp_mapping.keys())]
    )
    domain_param_gt += domain_param_gt * to.randn(len(dp_mapping)) / 10
    domain_param_gt = domain_param_gt.unsqueeze(0)
    data_real = simulator(domain_param_gt)

    # Sample from the trained posterior conditioned on the synthetic real-world data
    domain_params, log_probs = SBIBase.eval_posterior(
        posterior,
        data_real,
        num_samples=6,
        normalize_posterior=False,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=False),
    )
    dp_samples = [domain_params.reshape(1, -1, domain_params.shape[-1]).squeeze()]

    # Grid size depends on the marginal layout
    if layout == "inside":
        num_rows = num_cols = len(dp_mapping)
    else:
        num_rows = num_cols = len(dp_mapping) + 1

    _, axs = plt.subplots(num_rows, num_cols, figsize=(8, 8), tight_layout=True)
    fig = draw_posterior_pairwise_scatter(
        axs=axs,
        dp_samples=dp_samples,
        dp_mapping=dp_mapping,
        prior=prior if axis_limits == "use_prior" else None,
        env_sim=env,
        env_real=env_real,
        axis_limits=axis_limits,
        marginal_layout=layout,
        labels=labels,
        legend_labels=legend_labels,
        use_kde=use_kde,
    )
    assert fig is not None
示例#8
0
    env_sim, policy, kwout = load_experiment(ex_dir, args)
    env_real = pyrado.load("env_real.pkl", ex_dir)
    prior = kwout["prior"]
    posterior = kwout["posterior"]
    data_real = kwout["data_real"]

    if args.mode.lower() == "evolution-round" and args.iter == -1:
        args.iter = algo.curr_iter
        print_cbt(
            "Set the evaluation iteration to the latest iteration of the algorithm.",
            "y")

    # Load the sequence of posteriors if desired
    if args.mode.lower() == "evolution-iter":
        posterior = [
            SBIBase.load_posterior(ex_dir, idx_iter=i, verbose=True)
            for i in range(algo.max_iter)
        ]
        posterior = remove_none_from_list(
            posterior)  # in case the algorithm terminated early
    elif args.mode.lower() == "evolution-round":
        posterior = [
            SBIBase.load_posterior(ex_dir, idx_round=i, verbose=True)
            for i in range(algo.num_sbi_rounds)
        ]
        posterior = remove_none_from_list(
            posterior)  # in case the algorithm terminated early

    if "evolution" in args.mode.lower(
    ) and data_real.shape[0] > len(posterior):
        print_cbt(