Esempio n. 1
0
def eval_domain_params(
    pool: SamplerPool,
    env: SimEnv,
    policy: Policy,
    params: List[Dict],
    init_state: Optional[np.ndarray] = None,
) -> List[StepSequence]:
    """
    Evaluate a policy on a multidimensional grid of domain parameters.

    All domain randomization wrappers are removed from the environment before
    it is pickled and passed, together with the pickled policy, to the pool's
    initialization function. The grid is then mapped over the pool with the
    rollouts being run in evaluation mode (`eval=True`).

    :param pool: parallel sampler
    :param env: environment to evaluate in
    :param policy: policy to evaluate
    :param params: multidimensional grid of domain parameters, given as a list of dicts
    :param init_state: initial state of the environment which will be fixed if not set to `None`
    :return: list of rollouts
    """
    # Work on the environment without any domain randomization wrapper
    env = remove_all_dr_wrappers(env, verbose=True)

    # Pin the initial state if one was requested
    # NOTE(review): this assigns on the object returned by `remove_all_dr_wrappers`;
    # if that is the caller's own instance, the caller's env is modified -- confirm
    if init_state is not None:
        env.init_space = SingularStateSpace(fixed_state=init_state)

    # Initialize the pool with serialized versions of the environment and the policy
    pickled_env = pickle.dumps(env)
    pickled_policy = pickle.dumps(policy)
    pool.invoke_all(_ps_init, pickled_env, pickled_policy)

    # Map the grid over the pool, reporting the progress on stdout
    run_one = functools.partial(_ps_run_one_domain_param, eval=True)
    with tqdm(leave=False, file=sys.stdout, unit="rollouts",
              desc="Sampling") as progress_bar:
        rollouts = pool.run_map(run_one, params, progress_bar)
    return rollouts
Esempio n. 2
0
def test_pair_plot_scatter(
    env: SimEnv,
    policy: Policy,
    layout: str,
    labels: Optional[str],
    legend_labels: Optional[str],
    axis_limits: Optional[str],
    use_kde: bool,
    use_trafo: bool,
):
    """
    Smoke test for `draw_posterior_pairwise_scatter`.

    A posterior is trained with SNPE for a few epochs on 50 simulations,
    evaluated via `SBIBase.eval_posterior` with `num_samples=6`, and the
    resulting domain parameters are passed to the plotting function. The test
    only asserts that the plotting function does not return `None`.

    :param env: simulation environment, a deep copy of it serves as the "real" environment
    :param policy: policy used for the rollouts inside the simulator
    :param layout: forwarded as `marginal_layout`; for `"inside"` the axes grid has one row and column
                   per domain parameter, otherwise one more of each
    :param labels: forwarded to the plotting function
    :param legend_labels: forwarded to the plotting function
    :param axis_limits: forwarded to the plotting function; the prior is only passed if this equals `"use_prior"`
    :param use_kde: forwarded to the plotting function
    :param use_trafo: if `True`, wrap the environment with a `LogDomainParamTransform` masking `"stiffness"`,
                      and use the logarithm of the stiffness bounds for the prior
    """

    def _simulator(dp: to.Tensor) -> to.Tensor:
        """Minimal simulator interface for sbi, relying on `env` and `policy` from the enclosing scope"""
        # NOTE(review): the keys used here are `m`, `k`, `d`, whereas `env_real` and `dp_mapping` below use
        # `mass`, `stiffness`, `d` -- confirm which naming the environment expects
        reset_kwargs = {"domain_param": {"m": dp[0], "k": dp[1], "d": dp[2]}}
        ro = rollout(env, policy, eval=True, reset_kwargs=reset_kwargs)
        # Only the last observation of the rollout is returned
        last_obs = to.from_numpy(ro.observations[-1])
        return to.atleast_2d(last_obs.to(dtype=to.float32))

    # Pin the initial state to one uniformly sampled state
    fixed_init_state = env.init_space.sample_uniform()
    env.init_space = SingularStateSpace(fixed_init_state)

    # The "real" environment is a copy taken before any transformation wrapper is applied
    env_real = deepcopy(env)
    env_real.domain_param = {"mass": 0.8, "stiffness": 15, "d": 0.7}

    # Optionally infer the stiffness in log-space
    if use_trafo:
        env = LogDomainParamTransform(env, mask=["stiffness"])
        k_low, k_up = np.log(10), np.log(20)
    else:
        k_low, k_up = 10, 20

    # Mapping from the indices of the inferred vector to the domain parameter names, and the prior
    dp_mapping = {0: "mass", 1: "stiffness", 2: "d"}
    prior_low = to.tensor([0.5, k_low, 0.2])
    prior_high = to.tensor([1.5, k_up, 0.8])
    prior = sbiutils.BoxUniform(low=prior_low, high=prior_high)

    # Set up the density estimator and the inference algorithm
    density_estimator = sbiutils.posterior_nn(
        model="maf", hidden_features=10, num_transforms=3)
    snpe = SNPE(prior, density_estimator)

    # Generate the training data and fit the posterior
    simulator, prior = prepare_for_sbi(_simulator, prior)
    domain_param, data_sim = simulate_for_sbi(
        simulator=simulator, proposal=prior, num_simulations=50, num_workers=1)
    snpe.append_simulations(domain_param, data_sim)
    density_estimator = snpe.train(max_num_epochs=5)
    posterior = snpe.build_posterior(density_estimator)

    # Create a fake (random) true domain parameter by perturbing the ones of the "real" environment
    domain_param_gt = to.tensor(
        [env_real.domain_param[dp_mapping[idx]] for idx in sorted(dp_mapping)])
    noise = to.randn(len(dp_mapping))
    domain_param_gt += domain_param_gt * noise / 10
    domain_param_gt = domain_param_gt.unsqueeze(0)
    data_real = simulator(domain_param_gt)

    # Evaluate the posterior; the log probabilities are not needed for the scatter plot
    domain_params, _ = SBIBase.eval_posterior(
        posterior,
        data_real,
        num_samples=6,
        normalize_posterior=False,
        subrtn_sbi_sampling_hparam=dict(sample_with_mcmc=False),
    )
    # Merge the leading dimensions, and wrap the result into a list with one entry
    merged = domain_params.reshape(1, -1, domain_params.shape[-1])
    dp_samples = [merged.squeeze()]

    # Quadratic grid of axes, one row and column larger unless the marginals are drawn inside
    if layout == "inside":
        grid_size = len(dp_mapping)
    else:
        grid_size = len(dp_mapping) + 1
    _, axs = plt.subplots(
        grid_size, grid_size, figsize=(8, 8), tight_layout=True)

    plot_prior = prior if axis_limits == "use_prior" else None
    fig = draw_posterior_pairwise_scatter(
        axs=axs,
        dp_samples=dp_samples,
        dp_mapping=dp_mapping,
        prior=plot_prior,
        env_sim=env,
        env_real=env_real,
        axis_limits=axis_limits,
        marginal_layout=layout,
        labels=labels,
        legend_labels=legend_labels,
        use_kde=use_kde,
    )
    assert fig is not None
Esempio n. 3
0
def test_pair_plot(
    env: SimEnv,
    policy: Policy,
    layout: str,
    labels: Optional[str],
    prob_labels: Optional[str],
    use_prior: bool,
    use_trafo: bool,
):
    """
    Smoke test for `draw_posterior_pairwise_heatmap`.

    A posterior is trained with SNPE for a few epochs on 50 simulations and
    then passed to the plotting function together with a randomly perturbed
    domain parameter as condition. The test only asserts that the plotting
    function does not return `None`.

    :param env: simulation environment, a deep copy of it serves as the "real" environment
    :param policy: policy used for the rollouts inside the simulator
    :param layout: forwarded as `marginal_layout`; for `"inside"` the axes grid has one row and column
                   per domain parameter, otherwise one more of each
    :param labels: if `None`, no labels are passed, otherwise a list of empty strings (one per domain parameter)
    :param prob_labels: forwarded to the plotting function
    :param use_prior: if `True`, pass the prior and no grid bounds, else pass no prior and bounds of
                      `[0, 1]` for every domain parameter
    :param use_trafo: if `True`, wrap the environment with a `SqrtDomainParamTransform` masking `"stiffness"`
    """

    def _simulator(dp: to.Tensor) -> to.Tensor:
        """Minimal simulator interface for sbi, relying on `env` and `policy` from the enclosing scope"""
        # NOTE(review): the keys used here are `m`, `k`, `d`, whereas `env_real` and `dp_mapping` below use
        # `mass`, `stiffness`, `d` -- confirm which naming the environment expects
        reset_kwargs = {"domain_param": {"m": dp[0], "k": dp[1], "d": dp[2]}}
        ro = rollout(env, policy, eval=True, reset_kwargs=reset_kwargs)
        # Only the last observation of the rollout is returned
        last_obs = to.from_numpy(ro.observations[-1])
        return to.atleast_2d(last_obs.to(dtype=to.float32))

    # Pin the initial state to one uniformly sampled state
    fixed_init_state = env.init_space.sample_uniform()
    env.init_space = SingularStateSpace(fixed_init_state)

    # The "real" environment is a copy taken before any transformation wrapper is applied
    env_real = deepcopy(env)
    env_real.domain_param = {"mass": 0.8, "stiffness": 35, "d": 0.7}

    # Optionally wrap the simulation environment with a square-root transformation of the stiffness
    if use_trafo:
        env = SqrtDomainParamTransform(env, mask=["stiffness"])

    # Mapping from the indices of the inferred vector to the domain parameter names, and the prior
    dp_mapping = {0: "mass", 1: "stiffness", 2: "d"}
    num_dp = len(dp_mapping)
    prior_low = to.tensor([0.5, 20, 0.2])
    prior_high = to.tensor([1.5, 40, 0.8])
    prior = sbiutils.BoxUniform(low=prior_low, high=prior_high)

    # Set up the density estimator and the inference algorithm
    density_estimator = sbiutils.posterior_nn(
        model="maf", hidden_features=10, num_transforms=3)
    snpe = SNPE(prior, density_estimator)

    # Generate the training data and fit the posterior
    simulator, prior = prepare_for_sbi(_simulator, prior)
    domain_param, data_sim = simulate_for_sbi(
        simulator=simulator, proposal=prior, num_simulations=50, num_workers=1)
    snpe.append_simulations(domain_param, data_sim)
    density_estimator = snpe.train(max_num_epochs=5)
    posterior = snpe.build_posterior(density_estimator)

    # Create a fake (random) true domain parameter by perturbing the ones of the "real" environment
    domain_param_gt = to.tensor(
        [env_real.domain_param[name] for name in dp_mapping.values()])
    noise = to.randn(num_dp)
    domain_param_gt += domain_param_gt * noise / 5
    domain_param_gt = domain_param_gt.unsqueeze(0)
    data_real = simulator(domain_param_gt)

    # Use a copy of the (random) ground truth as condition
    condition = Embedding.pack(domain_param_gt.clone())

    # Either the prior or explicit bounds are passed to the plotting function, never both
    grid_bounds = None
    if not use_prior:
        prior = None
        lower_bounds = to.zeros((num_dp, 1))
        upper_bounds = to.ones((num_dp, 1))
        grid_bounds = to.cat([lower_bounds, upper_bounds], dim=1)

    # Empty placeholder labels are used whenever any labels are requested
    plot_labels = None if labels is None else [""] * num_dp

    # Quadratic grid of axes, one row and column larger unless the marginals are drawn inside
    grid_size = num_dp if layout == "inside" else num_dp + 1
    _, axs = plt.subplots(
        grid_size, grid_size, figsize=(14, 14), tight_layout=True)

    fig = draw_posterior_pairwise_heatmap(
        axs,
        posterior,
        data_real,
        dp_mapping,
        condition,
        prior=prior,
        env_real=env_real,
        marginal_layout=layout,
        grid_bounds=grid_bounds,
        grid_res=100,
        normalize_posterior=False,
        rescale_posterior=True,
        labels=plot_labels,
        prob_labels=prob_labels,
    )

    assert fig is not None