Example 1
def run(
    task: Task,
    num_samples: int,
    num_observation: int,
    rerun: bool = True,
    **kwargs: Any,
) -> torch.Tensor:
    """Random samples from saved reference posterior as baseline

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_observation: Observation number to load
        rerun: Whether to rerun reference or load from disk

    Returns:
        Random samples from reference posterior
    """
    log = sbibm.get_logger(__name__)

    if "num_simulations" in kwargs:
        log.warning(
            "`num_simulations` was passed as a keyword but will be ignored, since this is a baseline method."
        )

    if rerun:
        return task._sample_reference_posterior(
            num_samples=num_samples, num_observation=num_observation)
    else:
        reference_posterior_samples = task.get_reference_posterior_samples(
            num_observation)

        return reference_posterior_samples[
            np.random.randint(reference_posterior_samples.shape[0], size=num_samples), :
        ]
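A minimal usage sketch for the baseline above; `sbibm.get_task` and the reference-posterior accessors are part of the public sbibm API, while the task name and sample count are illustrative.

import sbibm

task = sbibm.get_task("two_moons")  # any task that ships reference posterior samples
# `run` is the baseline defined above; with rerun=False it subsamples the stored reference posterior.
samples = run(task=task, num_samples=1_000, num_observation=1, rerun=False)
print(samples.shape)  # expected: (1000, task.dim_parameters)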
Example 2
def get_proposal(
    task: Task,
    samples: torch.Tensor,
    prior_weight: float = 0.01,
    bounded: bool = True,
    density_estimator: str = "flow",
    flow_model: str = "nsf",
    **kwargs: Any,
) -> DenfensiveProposal:
    """Gets proposal distribution by performing density estimation on `samples`

    If `prior_weight` > 0., the proposal is defensive, i.e., the prior is mixed in

    Args:
        task: Task instance
        samples: Samples to fit
        prior_weight: Prior weight
        bounded: If True, will automatically transform proposal density to bounded space
        density_estimator: Density estimator
        flow_model: Flow to use if `density_estimator` is `flow`
        kwargs: Passed on to `get_flow` or `get_kde`

    Returns:
        Proposal distribution
    """
    tic = time.time()
    log = sbibm.get_logger(__name__)
    log.info("Get proposal distribution called")

    prior_dist = task.get_prior_dist()
    transform = task._get_transforms(
        automatic_transforms_enabled=bounded)["parameters"]

    if density_estimator == "flow":
        density_estimator_ = get_flow(model=flow_model,
                                      dim_distribution=task.dim_parameters,
                                      **kwargs)
        density_estimator_ = train_flow(density_estimator_,
                                        samples,
                                        transform=transform)

    elif density_estimator == "kde":
        density_estimator_ = get_kde(X=samples, transform=transform, **kwargs)

    else:
        raise NotImplementedError

    proposal_dist = DenfensiveProposal(
        dim=task.dim_parameters,
        proposal=density_estimator_,
        prior=prior_dist,
        prior_weight=prior_weight,
    )

    log.info(f"Proposal distribution is set up, took {time.time()-tic:.3f}sec")

    return proposal_dist
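A sketch of how the proposal might be fit and then used; the task name and sample counts are illustrative, and the `sample`/`log_prob` calls mirror how the proposal is consumed by the importance-sampling baselines further below.

import sbibm

task = sbibm.get_task("gaussian_linear_uniform")
reference_samples = task.get_reference_posterior_samples(num_observation=1)

# Fit a defensive flow proposal on a subset of the reference samples.
proposal = get_proposal(task=task, samples=reference_samples[:5_000], prior_weight=0.01)

theta = proposal.sample((100,))   # draws from the prior/flow mixture
log_q = proposal.log_prob(theta)  # proposal log density, e.g. for importance weights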
Example 3
def run(
    task: Task,
    num_samples: int,
    **kwargs: Any,
) -> torch.Tensor:
    """Random samples from prior as baseline

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior

    Returns:
        Random samples from prior
    """
    log = sbibm.get_logger(__name__)

    if "num_simulations" in kwargs:
        log.warning(
            "`num_simulations` was passed as a keyword but will be ignored, since this is a baseline method."
        )

    prior = task.get_prior()
    return prior(num_samples=num_samples)
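The prior baseline is called the same way; a brief sketch, relying only on the fact that `task.get_prior()` returns a callable accepting `num_samples`.

import sbibm

task = sbibm.get_task("two_moons")
prior_samples = run(task=task, num_samples=1_000)  # `run` as defined above
assert prior_samples.shape[0] == 1_000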
Example 4
def run(
    task: Task,
    num_samples: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    num_chains: int = 10,
    num_warmup: int = 10000,
    kernel: str = "slice",
    kernel_parameters: Optional[Dict[str, Any]] = None,
    thinning: int = 1,
    diagnostics: bool = True,
    available_cpu: int = 1,
    mp_context: str = "fork",
    jit_compile: bool = False,
    automatic_transforms_enabled: bool = True,
    initial_params: Optional[torch.Tensor] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Runs MCMC using Pyro on potential function

    Produces `num_samples` while accounting for warmup (burn-in) and thinning.

    Note that the actual number of simulations is not controlled for with MCMC, since
    this algorithm is only used as a reference method in the benchmark.

    MCMC is run on the potential function, which returns the unnormalized
    negative log posterior probability. Note that this requires a tractable likelihood.
    Pyro is used to automatically construct the potential function.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        num_chains: Number of chains
        num_warmup: Warmup steps, during which parameters of the sampler are adapted.
            Warmup samples are not returned by the algorithm.
        kernel: HMC, NUTS, or Slice
        kernel_parameters: Parameters passed to kernel
        thinning: Amount of thinning to apply, in order to avoid drawing
            correlated samples from the chain
        diagnostics: Flag for diagnostics
        available_cpu: Number of CPUs used to parallelize chains
        mp_context: multiprocessing context, only fork might work
        jit_compile: Just-in-time (JIT) compilation, can yield significant speed ups
        automatic_transforms_enabled: Whether or not to use automatic transforms
        initial_params: Parameters to initialize at

    Returns:
        Samples from posterior
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)

    tic = time.time()
    log = sbibm.get_logger(__name__)

    hook_fn = None
    if diagnostics:
        log.info(f"MCMC sampling for observation {num_observation}")
        tb_writer, tb_close = tb_make_writer(
            logger=log,
            basepath=f"tensorboard/pyro_{kernel.lower()}/observation_{num_observation}",
        )
        hook_fn = tb_make_hook_fn(tb_writer)

    if "num_simulations" in kwargs:
        warnings.warn(
            "`num_simulations` was passed as a keyword but will be ignored, see docstring for more info."
        )

    # Prepare model and transforms
    conditioned_model = task._get_pyro_model(num_observation=num_observation,
                                             observation=observation)
    transforms = task._get_transforms(
        num_observation=num_observation,
        observation=observation,
        automatic_transforms_enabled=automatic_transforms_enabled,
    )

    kernel_parameters = kernel_parameters if kernel_parameters is not None else {}
    kernel_parameters["jit_compile"] = jit_compile
    kernel_parameters["transforms"] = transforms
    log.info("Using kernel: {name}({parameters})".format(
        name=kernel,
        parameters=",".join([f"{k}={v}"
                             for k, v in kernel_parameters.items()]),
    ))
    if kernel.lower() == "nuts":
        mcmc_kernel = NUTS(model=conditioned_model, **kernel_parameters)

    elif kernel.lower() == "hmc":
        mcmc_kernel = HMC(model=conditioned_model, **kernel_parameters)

    elif kernel.lower() == "slice":
        mcmc_kernel = Slice(model=conditioned_model, **kernel_parameters)

    else:
        raise NotImplementedError

    if initial_params is not None:
        site_name = "parameters"
        initial_params = {site_name: transforms[site_name](initial_params)}
    else:
        initial_params = None

    mcmc_parameters = {
        "num_chains": num_chains,
        "num_samples": thinning * num_samples,
        "warmup_steps": num_warmup,
        "available_cpu": available_cpu,
        "initial_params": initial_params,
    }
    log.info("Calling MCMC with: MCMC({name}_kernel, {parameters})".format(
        name=kernel,
        parameters=",".join([f"{k}={v}" for k, v in mcmc_parameters.items()]),
    ))

    mcmc = MCMC(mcmc_kernel, hook_fn=hook_fn, **mcmc_parameters)
    mcmc.run()

    toc = time.time()
    log.info(f"Finished MCMC after {toc-tic:.3f} seconds")
    log.info(f"Automatic transforms {mcmc.transforms}")

    log.info(f"Apply thinning of {thinning}")
    mcmc._samples = {
        "parameters": mcmc._samples["parameters"][:, ::thinning, :]
    }

    num_samples_available = (mcmc._samples["parameters"].shape[0] *
                             mcmc._samples["parameters"].shape[1])
    if num_samples_available < num_samples:
        warnings.warn("Some samples will be included multiple times")
        samples = mcmc.get_samples(
            num_samples=num_samples,
            group_by_chain=False)["parameters"].squeeze()
    else:
        samples = mcmc.get_samples(
            group_by_chain=False)["parameters"].squeeze()
        idx = torch.randperm(samples.shape[0])[:num_samples]
        samples = samples[idx, :]

    assert samples.shape[0] == num_samples

    if diagnostics:
        mcmc.summary()
        tb_ess(tb_writer, mcmc)
        tb_r_hat(tb_writer, mcmc)
        tb_marginals(tb_writer, mcmc)
        tb_acf(tb_writer, mcmc)
        tb_posteriors(tb_writer, mcmc)
        tb_plot_posterior(tb_writer, samples, tag="posterior/final")
        tb_close()

    return samples
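A usage sketch for the Pyro reference sampler; the kernel and chain settings below are illustrative, and a task with a tractable likelihood is assumed so that Pyro can build the potential function.

import sbibm

task = sbibm.get_task("gaussian_linear")  # tractable likelihood
samples = run(
    task=task,
    num_samples=10_000,
    num_observation=1,
    num_chains=4,
    num_warmup=1_000,
    kernel="slice",
    thinning=10,
    diagnostics=False,  # skip TensorBoard diagnostics
)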
Example 5
        batch_size: Batch size used when finding M
        num_batches_without_new_max: Number of batches that need to be evaluated without
            finding new M before search is stopped
        multiplier_M: Multiplier used when determining M
        proposal_dist: If specified, will be used as a proposal distribution instead
            of prior
        kwargs: Not used

    Returns:
        Random samples from reference posterior
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)

    tic = time.time()
    log = sbibm.get_logger(__name__)
    log.info("Rejection sampling")

    if "num_simulations" in kwargs:
        log.warning(
            "`num_simulations` was passed as a keyword but will be ignored, since this is a baseline method."
        )

    prior = task.get_prior()
    prior_dist = task.get_prior_dist()

    if proposal_dist is None:
        proposal_dist = prior_dist

    log_prob_fn = task._get_log_prob_fn(
        num_observation=num_observation,
Example 6
def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    num_top_samples: Optional[int] = 100,
    quantile: Optional[float] = None,
    eps: Optional[float] = None,
    distance: str = "l2",
    batch_size: int = 1000,
    save_distances: bool = False,
    kde_bandwidth: Optional[str] = "cv",
    sass: bool = False,
    sass_fraction: float = 0.5,
    sass_feature_expansion_degree: int = 3,
    lra: bool = False,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """Runs REJ-ABC from `sbi`

    Choose one of `num_top_samples`, `quantile`, `eps`.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_simulations: Simulation budget
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        num_top_samples: If given, will use `top=True` with num_top_samples
        quantile: Quantile to use
        eps: Epsilon threshold to use
        distance: Distance to use
        batch_size: Batch size for simulator
        save_distances: If True, stores distances of samples to disk
        kde_bandwidth: If not None, will resample using KDE when necessary, set
            e.g. to "cv" for cross-validated bandwidth selection
        sass: If True, summary statistics are learned as in
            Fearnhead & Prangle 2012.
        sass_fraction: Fraction of simulation budget to use for sass.
        sass_feature_expansion_degree: Degree of polynomial expansion of the summary
            statistics.
        lra: If True, posterior samples are adjusted with
            linear regression as in Beaumont et al. 2002.
    Returns:
        Samples from posterior, number of simulator calls, log probability of true params if computable
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)

    assert not (num_top_samples is None and quantile is None and eps is None)

    log = sbibm.get_logger(__name__)
    log.info(f"Running REJ-ABC")

    prior = task.get_prior_dist()
    simulator = task.get_simulator(max_calls=num_simulations)
    kde = kde_bandwidth is not None
    if observation is None:
        observation = task.get_observation(num_observation)

    if num_top_samples is not None and quantile is None:
        if sass:
            quantile = num_top_samples / (num_simulations -
                                          int(sass_fraction * num_simulations))
        else:
            quantile = num_top_samples / num_simulations

    inference_method = MCABC(
        simulator=simulator,
        prior=prior,
        simulation_batch_size=batch_size,
        distance=distance,
        show_progress_bars=True,
    )
    # Returns samples or kde posterior in output.
    output, summary = inference_method(
        x_o=observation,
        num_simulations=num_simulations,
        eps=eps,
        quantile=quantile,
        return_summary=True,
        kde=kde,
        kde_kwargs={"bandwidth": kde_bandwidth} if kde else {},
        lra=lra,
        sass=sass,
        sass_expansion_degree=sass_feature_expansion_degree,
        sass_fraction=sass_fraction,
    )

    assert simulator.num_simulations == num_simulations

    if save_distances:
        save_tensor_to_csv("distances.csv", summary["distances"])

    if kde:
        kde_posterior = output
        samples = kde_posterior.sample(num_samples)

        # LPTP can only be returned with KDE posterior.
        if num_observation is not None:
            true_parameters = task.get_true_parameters(
                num_observation=num_observation)
            log_prob_true_parameters = kde_posterior.log_prob(
                true_parameters.squeeze())
            return samples, simulator.num_simulations, log_prob_true_parameters
    else:
        samples = output

    return samples, simulator.num_simulations, None
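A usage sketch with assumed settings: keep the 100 closest of 10,000 simulations and smooth them with a cross-validated KDE so that the requested number of posterior samples can be drawn.

import sbibm

task = sbibm.get_task("slcp")
samples, num_simulations, log_prob_true = run(
    task=task,
    num_samples=10_000,
    num_simulations=10_000,
    num_observation=1,
    num_top_samples=100,
    kde_bandwidth="cv",
)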
Example 7
def run(
    task: Task,
    num_samples: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    num_simulations: Optional[int] = None,
    low: Optional[torch.Tensor] = None,
    high: Optional[torch.Tensor] = None,
    eps: float = 0.001,
    resolution: Optional[int] = None,
    batch_size: int = 10000,
    save: bool = False,
    **kwargs: Any,
) -> torch.Tensor:
    """Random samples from gridded posterior as baseline

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        num_simulations: Number of simulations to determine resolution
        low: Lower limit per dimension, tries to infer it if not passed
        high: Upper limit per dimension, tries to infer it if not passed
        eps: Eps added to bounds to avoid NaN evaluations
        resolution: Resolution for all dimensions, alternatively use `num_simulations`
        batch_size: Batch size
        save: If True, saves grid and log probs

    Returns:
        Random samples from reference posterior
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)
    assert not (num_simulations is None and resolution is None)
    assert not (num_simulations is not None and resolution is not None)

    tic = time.time()
    log = sbibm.get_logger(__name__)

    if num_simulations is not None:
        resolution = int(
            math.floor(math.exp(math.log(num_simulations) / task.dim_parameters))
        )
    log.info(f"Resolution: {resolution}")

    # Infer bounds if not passed
    prior_params = task.get_prior_params()
    if low is None:
        if "low" in prior_params:
            low = prior_params["low"]
        else:
            raise ValueError("`low` could not be inferred from prior")
    if high is None:
        if "high" in prior_params:
            high = prior_params["high"]
        else:
            raise ValueError("`high` could not be inferred from prior")

    dim_parameters = task.dim_parameters
    assert len(low) == dim_parameters
    assert len(high) == dim_parameters

    # Apply eps to bounds to avoid NaN evaluations
    low += eps
    high -= eps

    # Construct grid
    grid = torch.stack(
        torch.meshgrid(
            [torch.linspace(low[d], high[d], resolution) for d in range(dim_parameters)]
        )
    )  # dim_parameters x resolution x ... x resolution
    grid_flat = grid.view(
        dim_parameters, -1
    ).T  # resolution^dim_parameters x dim_parameters

    # Get log probability function (unnormalized log posterior)
    log_prob_fn = task._get_log_prob_fn(
        num_observation=num_observation,
        observation=observation,
        implementation="experimental",
        posterior=True,
        **kwargs,
    )

    total_evaluations = grid_flat.shape[0]
    log.info(f"Total evaluations: {total_evaluations}")

    batch_size = min(batch_size, total_evaluations)
    num_batches = int(math.ceil(total_evaluations / batch_size))

    log_probs = torch.empty([resolution for _ in range(dim_parameters)])
    for i in tqdm(range(num_batches)):
        ix_from = i * batch_size
        ix_to = ix_from + batch_size
        if ix_to > total_evaluations:
            ix_to = total_evaluations
        log_probs.view(-1)[ix_from:ix_to] = log_prob_fn(grid_flat[ix_from:ix_to, :])

    if save:
        log.info("Saving grid and log probs")
        torch.save(grid, "grid.pt")
        torch.save(log_probs, "log_probs.pt")

    probs = torch.exp(log_probs.view(-1))
    indices = torch.arange(0, len(probs))
    idxs = choice(indices, num_samples, True, probs)
    samples = grid_flat[idxs, :]
    num_unique_samples = len(torch.unique(samples, dim=0))
    log.info(f"Unique samples: {num_unique_samples}")

    toc = time.time()
    log.info(f"Finished after {toc-tic:.3f} seconds")

    return samples
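The per-dimension resolution is chosen as (approximately) the largest integer r such that r**dim_parameters does not exceed the simulation budget; a small, task-independent check of that heuristic:

import math

num_simulations = 100_000
dim_parameters = 2
resolution = int(math.floor(math.exp(math.log(num_simulations) / dim_parameters)))
print(resolution)  # 316, since 316**2 = 99856 <= 100000 < 317**2
assert resolution ** dim_parameters <= num_simulations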
Example 8
    def _sample_reference_posterior(
        self,
        num_samples: int,
        num_observation: int,
        observation: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Sample reference posterior for given observation

        Uses closed form solution

        Args:
            num_samples: Number of samples to generate
            num_observation: Observation number
            observation: Observed data, if None, will be loaded using `num_observation`

        Returns:
            Samples from reference posterior
        """
        log = sbibm.get_logger(__name__)

        if observation is None:
            observation = self.get_observation(num_observation)

        ang = torch.tensor([-math.pi / 4.0])
        c = torch.cos(-ang)
        s = torch.sin(-ang)

        simulator = self.get_simulator()

        reference_posterior_samples = []
        counter = 0
        while len(reference_posterior_samples) < num_samples:
            counter += 1

            p = simulator(torch.zeros(1, 2))
            q = torch.zeros(2)
            q[0] = p[0, 0] - observation[0, 0]
            q[1] = observation[0, 1] - p[0, 1]

            if np.random.rand() < 0.5:
                q[0] = -q[0]

            sample = torch.tensor([[c * q[0] - s * q[1], s * q[0] + c * q[1]]])

            is_outside_prior = torch.isinf(self.prior_dist.log_prob(sample).sum())

            if len(reference_posterior_samples) > 0:
                is_duplicate = sample in torch.cat(reference_posterior_samples)
            else:
                is_duplicate = False

            if not is_outside_prior and not is_duplicate:
                reference_posterior_samples.append(sample)

        reference_posterior_samples = torch.cat(reference_posterior_samples)
        acceptance_rate = float(num_samples / counter)

        log.info(
            f"Acceptance rate for observation {num_observation}: {acceptance_rate}"
        )

        return reference_posterior_samples
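A sketch of how this closed-form sampler might be called, assuming the method above lives on the Two Moons task; in practice the stored samples from `get_reference_posterior_samples` would usually be preferred.

import sbibm

task = sbibm.get_task("two_moons")
samples = task._sample_reference_posterior(num_samples=1_000, num_observation=1)
print(samples.shape)  # (1000, 2) for the two-dimensional Two Moons parameter space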
Example 9
def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    batch_size: int = 100000,
    proposal_dist: Optional[DenfensiveProposal] = None,
    **kwargs: Any,
) -> torch.Tensor:
    """Random samples from Sequential Importance Resampling (SIR) as a baseline

    SIR is also referred to as weighted bootstrap [1]. The prior is used as the
    proposal, so the importance weights reduce to the likelihood; this has also
    been referred to as likelihood weighting in the literature.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_simulations: Simulation budget, i.e., number of proposal draws used to
            form the importance weights
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        batch_size: Batch size for simulations
        proposal_dist: If specified, will be used as a proposal distribution instead
            of prior
        kwargs: Not used

    Returns:
        Random samples from reference posterior

    [1] A. F. M. Smith and A. E. Gelfand. Bayesian statistics without tears: a
    sampling-resampling perspective. The American Statistician, 46(2):84-88, 1992.
    doi:10.1080/00031305.1992.10475856.
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)

    tic = time.time()
    log = sbibm.get_logger(__name__)
    log.info("Sequential Importance Resampling (SIR)")

    prior_dist = task.get_prior_dist()

    if proposal_dist is None:
        proposal_dist = prior_dist

    log_prob_fn = task._get_log_prob_fn(
        num_observation=num_observation,
        observation=observation,
        implementation="experimental",
        posterior=True,
    )

    batch_size = min(batch_size, num_simulations)
    num_batches = int(num_simulations / batch_size)

    particles = []
    log_weights = []
    for i in tqdm(range(num_batches)):
        batch_draws = proposal_dist.sample((batch_size, ))
        log_weights.append(
            log_prob_fn(batch_draws) - proposal_dist.log_prob(batch_draws))
        particles.append(batch_draws)
    log.info("Finished sampling")

    particles = torch.cat(particles)
    log_weights = torch.cat(log_weights)
    probs = torch.exp(log_weights.view(-1))
    probs /= probs.sum()

    indices = torch.arange(0, len(probs))
    idxs = choice(indices, num_samples, True, probs)
    samples = particles[idxs, :]
    log.info("Finished resampling")

    num_unique = torch.unique(samples, dim=0).shape[0]
    log.info(f"Unique particles: {num_unique} out of {len(samples)}")

    toc = time.time()
    log.info(f"Finished after {toc-tic:.3f} seconds")

    return samples
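A usage sketch for the SIR baseline; the task and budget are illustrative, and leaving `proposal_dist` unset means the prior is used as the proposal so that the weights reduce to the likelihood.

import sbibm

task = sbibm.get_task("gaussian_mixture")
samples = run(
    task=task,
    num_samples=10_000,
    num_simulations=100_000,  # number of proposal draws used to form the weights
    num_observation=1,
)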
Example 10
def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    population_size: Optional[int] = None,
    distance: str = "l2",
    epsilon_decay: float = 0.2,
    distance_based_decay: bool = True,
    ess_min: Optional[float] = None,
    initial_round_factor: int = 5,
    batch_size: int = 1000,
    kernel: str = "gaussian",
    kernel_variance_scale: float = 0.5,
    use_last_pop_samples: bool = True,
    algorithm_variant: str = "C",
    save_summary: bool = False,
    sass: bool = False,
    sass_fraction: float = 0.5,
    sass_feature_expansion_degree: int = 3,
    lra: bool = False,
    lra_sample_weights: bool = True,
    kde_bandwidth: Optional[str] = "cv",
    kde_sample_weights: bool = False,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """Runs SMC-ABC from `sbi`

    SMC-ABC supports two different ways of scheduling epsilon:
    1) Exponential decay: eps_t+1 = epsilon_decay * eps_t
    2) Distance-based decay: the new eps is determined from the `epsilon_decay`
        quantile of the distances of the accepted simulations in the previous
        population. This is used if `distance_based_decay` is set to True.

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_simulations: Simulation budget
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        population_size: If None, uses heuristic: 1000 if `num_simulations` is greater
            than 10k, else 100
        distance: Distance function, options = {l1, l2, mse}
        epsilon_decay: Decay for epsilon; treated as quantile in case of distance based decay.
        distance_based_decay: Whether to determine new epsilon from quantile of
            distances of the previous population.
        ess_min: Threshold for resampling a population if the effective sample size is
            too small.
        initial_round_factor: Used to determine initial round size
        batch_size: Batch size for the simulator
        kernel: Kernel distribution used to perturb the particles.
        kernel_variance_scale: Scaling factor for kernel variance.
        use_last_pop_samples: If True, samples of a population that was quit due to
            budget are used by filling up missing particles from the previous
            population.
        algorithm_variant: There are three SMCABC variants implemented: A, B, and C.
            See docstrings in the `sbi` package for more details.
        save_summary: Whether to save a summary containing all populations, distances,
            etc. to file.
        sass: If True, summary statistics are learned as in
            Fearnhead & Prangle 2012.
        sass_fraction: Fraction of simulation budget to use for sass.
        sass_feature_expansion_degree: Degree of polynomial expansion of the summary
            statistics.
        lra: If True, posterior samples are adjusted with
            linear regression as in Beaumont et al. 2002.
        lra_sample_weights: Whether to weigh LRA samples
        kde_bandwidth: If not None, will resample using KDE when necessary, set
            e.g. to "cv" for cross-validated bandwidth selection
        kde_sample_weights: Whether to weigh KDE samples


    Returns:
        Samples from posterior, number of simulator calls, log probability of true params if computable
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)

    log = sbibm.get_logger(__name__)
    smc_papers = dict(A="Toni 2010", B="Sisson et al. 2007", C="Beaumont et al. 2009")
    log.info(f"Running SMC-ABC as in {smc_papers[algorithm_variant]}.")

    prior = task.get_prior_dist()
    simulator = task.get_simulator(max_calls=num_simulations)
    if observation is None:
        observation = task.get_observation(num_observation)

    if population_size is None:
        population_size = 100
        if num_simulations > 10_000:
            population_size = 1000

    population_size = min(population_size, num_simulations)

    initial_round_size = clip_int(
        value=initial_round_factor * population_size,
        minimum=population_size,
        maximum=max(0.5 * num_simulations, population_size),
    )

    inference_method = SMCABC(
        simulator=simulator,
        prior=prior,
        simulation_batch_size=batch_size,
        distance=distance,
        show_progress_bars=True,
        kernel=kernel,
        algorithm_variant=algorithm_variant,
    )
    posterior, summary = inference_method(
        x_o=observation,
        num_particles=population_size,
        num_initial_pop=initial_round_size,
        num_simulations=num_simulations,
        epsilon_decay=epsilon_decay,
        distance_based_decay=distance_based_decay,
        ess_min=ess_min,
        kernel_variance_scale=kernel_variance_scale,
        use_last_pop_samples=use_last_pop_samples,
        return_summary=True,
        lra=lra,
        lra_with_weights=lra_sample_weights,
        sass=sass,
        sass_fraction=sass_fraction,
        sass_expansion_degree=sass_feature_expansion_degree,
    )

    if save_summary:
        log.info("Saving smcabc summary to csv.")
        pd.DataFrame.from_dict(summary).to_csv("summary.csv", index=False)

    assert simulator.num_simulations == num_simulations

    if kde_bandwidth is not None:
        samples = posterior._samples

        log.info(
            f"KDE on {samples.shape[0]} samples with bandwidth option {kde_bandwidth}"
        )

        kde = get_kde(
            samples,
            bandwidth=kde_bandwidth,
            sample_weight=posterior._log_weights.exp() if kde_sample_weights else None,
        )
        samples = kde.sample(num_samples)
    else:
        samples = posterior.sample((num_samples,)).detach()

    if num_observation is not None:
        true_parameters = task.get_true_parameters(num_observation=num_observation)
        log_prob_true_parameters = posterior.log_prob(true_parameters)
        return samples, simulator.num_simulations, log_prob_true_parameters
    else:
        return samples, simulator.num_simulations, None
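A usage sketch with illustrative settings: distance-based (quantile) epsilon decay at the 0.2 quantile and a cross-validated KDE so that the requested number of samples can be drawn from the final population.

import sbibm

task = sbibm.get_task("lotka_volterra")
samples, num_simulations, log_prob_true = run(
    task=task,
    num_samples=10_000,
    num_simulations=100_000,
    num_observation=1,
    population_size=1_000,
    epsilon_decay=0.2,
    distance_based_decay=True,
    kde_bandwidth="cv",
)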
Example 11
def run(
    task: Task,
    num_samples: int,
    num_simulations: int,
    num_observation: Optional[int] = None,
    observation: Optional[torch.Tensor] = None,
    population_size: Optional[int] = None,
    distance: Optional[str] = "l2",
    initial_round_factor: int = 5,
    batch_size: int = 1000,
    epsilon_decay: Optional[float] = 0.5,
    kernel: Optional[str] = "gaussian",
    kernel_variance_scale: Optional[float] = 0.5,
    population_strategy: Optional[str] = "constant",
    use_last_pop_samples: bool = False,
    num_workers: int = 1,
    sass: bool = False,
    sass_sample_weights: bool = False,
    sass_feature_expansion_degree: int = 1,
    sass_fraction: float = 0.5,
    lra: bool = False,
    lra_sample_weights: bool = True,
    kde_bandwidth: Optional[str] = None,
    kde_sample_weights: bool = False,
) -> Tuple[torch.Tensor, int, Optional[torch.Tensor]]:
    """ABC-SMC using pyabc toolbox

    Args:
        task: Task instance
        num_samples: Number of samples to generate from posterior
        num_simulations: Simulation budget
        num_observation: Observation number to load, alternative to `observation`
        observation: Observation, alternative to `num_observation`
        population_size: If None, uses heuristic: 1000 if `num_simulations` is greater
            than 10k, else 100
        distance: Distance function, options = {l1, l2, mse}
        epsilon_decay: Decay for epsilon, quantile based.
        kernel: Kernel distribution used to perturb the particles.
        kernel_variance_scale: Scaling factor for kernel variance.
        sass: If True, summary statistics are learned as in
            Fearnhead & Prangle 2012.
        sass_sample_weights: Whether to weigh SASS samples
        sass_feature_expansion_degree: Degree of polynomial expansion of the summary
            statistics.
        sass_fraction: Fraction of simulation budget to use for sass.
        lra: If True, posterior samples are adjusted with
            linear regression as in Beaumont et al. 2002.
        lra_sample_weights: Whether to weigh LRA samples
        kde_bandwidth: If not None, will resample using KDE when necessary, set
            e.g. to "cv" for cross-validated bandwidth selection
        kde_sample_weights: Whether to weigh KDE samples
    Returns:
        Samples from posterior, number of simulator calls, log probability of true params if computable
    """
    assert not (num_observation is None and observation is None)
    assert not (num_observation is not None and observation is not None)
    log = sbibm.get_logger(__name__)
    db = "sqlite:///" + os.path.join(
        tempfile.gettempdir(),
        f"pyabc_{time.time()}_{random.randint(0, 1e9)}.db")

    # Wrap sbibm prior and simulator for pyABC
    prior = wrap_prior(task)
    simulator = PyAbcSimulator(task)
    distance_str = distance

    if observation is None:
        observation = task.get_observation(num_observation)

    # Population size strategy
    if population_size is None:
        population_size = 100
        if num_simulations > 10_000:
            population_size = 1000

    # Find initial epsilon with rej abc run.
    initial_round_size = clip_int(
        value=initial_round_factor * population_size,
        minimum=population_size,
        maximum=max(0.5 * num_simulations, population_size),
    )
    log.info(
        f"Running REJ-ABC with {initial_round_size} samples to find initial epsilon."
    )
    _, distances = run_rejection_abc(task, initial_round_size, population_size,
                                     observation, distance_str, batch_size)
    initial_epsilon = distances[-1].item()

    # Wrap observation and distance for pyabc.
    distance = get_distance(distance_str)
    observation = np.atleast_1d(np.array(observation, dtype=float).squeeze())
    # Define quantile based epsilon decay.
    epsilon = pyabc.epsilon.QuantileEpsilon(initial_epsilon=initial_epsilon,
                                            alpha=epsilon_decay)

    # Perturbation kernel
    transition = pyabc.transition.MultivariateNormalTransition(
        scaling=kernel_variance_scale)

    population_size = min(population_size, num_simulations)

    if population_strategy == "constant":
        population_size_strategy = population_size
    elif population_strategy == "adaptive":
        raise NotImplementedError("Not implemented atm.")
        population_size_strategy = pyabc.populationstrategy.AdaptivePopulationSize(
            start_nr_particles=population_size,
            max_population_size=int(10 * population_size),
            min_population_size=int(0.1 * population_size),
        )

    # Multiprocessing
    if num_workers > 1:
        sampler = pyabc.sampler.MulticoreParticleParallelSampler(
            n_procs=num_workers)
    else:
        sampler = pyabc.sampler.SingleCoreSampler(check_max_eval=False)

    # Collect kwargs
    kwargs = dict(
        parameter_priors=[prior],
        distance_function=distance,
        population_size=population_size_strategy,
        transitions=[transition],
        eps=epsilon,
        sampler=sampler,
    )

    # Semi-automatic summary statistics.
    if sass:
        num_pilot_simulations = int(sass_fraction * num_simulations)
        log.info(f"SASS pilot run with {num_pilot_simulations} simulations.")
        kwargs["models"] = [simulator]

        # Run pyabc with fixed budget.
        pilot_theta, pilot_weights = run_pyabc(
            task,
            db,
            num_pilot_simulations,
            observation,
            pyabc_kwargs=kwargs,
            use_last_pop_samples=use_last_pop_samples,
            distance_str=distance_str,
            batch_size=batch_size,
        )

        # Regression
        # TODO: Posterior does not return xs, which we would need for regression
        # adjustment, so we resimulate, which is unnecessary. Ideally,
        # `inference_method` would return xs on request. This step thus does not
        # count towards the budget.
        pilot_x = task.get_simulator(max_calls=None)(pilot_theta)

        # Run SASS.
        sumstats_transform = get_sass_transform(
            theta=pilot_theta,
            x=pilot_x,
            expansion_degree=sass_feature_expansion_degree,
            sample_weight=pilot_weights if sass_sample_weights else None,
        )

        # Update simulator to use sass summary stats.
        def sumstats_simulator(theta):
            # Pyabc simulator returns dict.
            x = simulator(theta)["data"].reshape(1, -1)
            # Transform return Tensor.
            sx = sumstats_transform(x)
            return dict(data=sx.numpy().squeeze())

        observation = sumstats_transform(observation.reshape(1, -1))
        observation = np.atleast_1d(
            np.array(observation, dtype=float).squeeze())
        log.info(f"Finished learning summary statistics.")
    else:
        sumstats_simulator = simulator
        num_pilot_simulations = 0
        population_size = min(population_size, num_simulations)

    log.info("""Running ABC-SMC-pyabc with {} simulations""".format(
        num_simulations - num_pilot_simulations))
    kwargs["models"] = [sumstats_simulator]

    # Run pyabc with fixed budget.
    particles, weights = run_pyabc(
        task,
        db,
        num_simulations=num_simulations - num_pilot_simulations,
        observation=observation,
        pyabc_kwargs=kwargs,
        use_last_pop_samples=use_last_pop_samples,
        distance_str=distance_str,
        batch_size=batch_size,
    )

    if lra:
        log.info(f"Running linear regression adjustment.")
        # TODO: Posterior does not return xs, which we would need for regression
        # adjustment, so we resimulate, which is unnecessary. Ideally,
        # `inference_method` would return xs on request.
        xs = task.get_simulator(max_calls=None)(particles)

        # NOTE: If posterior is bounded we should do the regression in
        # unbounded space, as described in https://arxiv.org/abs/1707.01254
        transform_to_unbounded = True
        transforms = task._get_transforms(transform_to_unbounded)["parameters"]

        # Update the particles with LRA.
        particles = run_lra(
            theta=particles,
            x=xs,
            observation=torch.tensor(observation,
                                     dtype=torch.float32).unsqueeze(0),
            sample_weight=weights if lra_sample_weights else None,
            transforms=transforms,
        )

        # TODO: Maybe set weights uniform because they can't be updated?
        # weights = torch.ones(particles.shape[0]) / particles.shape[0]

    if kde_bandwidth is not None:
        samples = particles

        log.info(
            f"KDE on {samples.shape[0]} samples with bandwidth option {kde_bandwidth}"
        )
        kde = get_kde(
            samples,
            bandwidth=kde_bandwidth,
            sample_weight=weights if kde_sample_weights else None,
        )
        samples = kde.sample(num_samples)
    else:
        log.info(f"Sampling {num_samples} samples from trace")
        samples = sample_with_weights(particles,
                                      weights,
                                      num_samples=num_samples)

    log.info(f"Unique samples: {torch.unique(samples, dim=0).shape[0]}")

    return samples, simulator.simulator.num_simulations, None
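The pyabc variant is called the same way as the `sbi` one; the settings below are illustrative, and `num_workers > 1` would switch to pyabc's multicore sampler.

import sbibm

task = sbibm.get_task("gaussian_linear_uniform")
samples, num_simulations, _ = run(
    task=task,
    num_samples=10_000,
    num_simulations=100_000,
    num_observation=1,
    population_size=1_000,
    epsilon_decay=0.5,
    kde_bandwidth=None,
)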
Example 12
def run_pyabc(
    task: Task,
    db,
    num_simulations: int,
    observation: np.ndarray,
    pyabc_kwargs: dict,
    distance_str: str = "l2",
    batch_size: int = 1000,
    use_last_pop_samples: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run pyabc SMC with fixed budget and return particles and weights.
    
    Return previous population or prior samples if budget is exceeded.
    """
    log = sbibm.get_logger(__name__)

    abc = pyabc.ABCSMC(**pyabc_kwargs)
    abc.new(db, {"data": observation})
    history = abc.run(max_total_nr_simulations=num_simulations)
    num_calls = history.total_nr_simulations

    if num_calls < num_simulations:
        (particles_df, weights) = history.get_distribution(t=history.max_t)
        particles = torch.as_tensor(particles_df.values, dtype=torch.float32)
        weights = torch.as_tensor(weights, dtype=torch.float32)
    else:
        if history.max_t > 0:
            log.info(
                f"Last population exceeded budget by {num_calls - num_simulations}."
            )
            (particles_df, weights) = history.get_distribution(t=history.max_t - 1)
            old_particles = torch.as_tensor(particles_df.values, dtype=torch.float32)
            old_weights = torch.as_tensor(weights, dtype=torch.float32)
            if use_last_pop_samples:
                df = history.get_all_populations()
                num_calls_last_pop = df.samples.values[-1]
                over_budget = num_calls - num_simulations
                proportion_over_budget = over_budget / num_calls_last_pop
                # The proportion over budget needs to be replaced with old particles.
                num_old_particles = int(
                    np.ceil(proportion_over_budget *
                            pyabc_kwargs["population_size"]))
                log.info(
                    f"Filling up with {num_old_particles+1} samples from previous population."
                )
                # Combining populations.
                (particles_df, weights) = history.get_distribution(t=history.max_t)
                new_particles = torch.as_tensor(particles_df.values, dtype=torch.float32)
                new_weights = torch.as_tensor(weights, dtype=torch.float32)

                particles = torch.zeros_like(old_particles)
                weights = torch.zeros_like(old_weights)
                particles[:num_old_particles] = old_particles[:num_old_particles]
                particles[num_old_particles:] = new_particles[num_old_particles:]
                weights[:num_old_particles] = old_weights[:num_old_particles]
                weights[num_old_particles:] = new_weights[num_old_particles:]
                # Normalize combined weights.
                weights /= weights.sum()
            else:
                log.info("Returning previous population.")
                particles = old_particles
                weights = old_weights
        else:
            log.info(
                "Running REJABC because first population exceeded budget.")
            posterior, _ = run_rejection_abc(
                task,
                num_simulations,
                pyabc_kwargs["population_size"],
                observation=torch.tensor(observation, dtype=torch.float32),
                distance=distance_str,
                batch_size=batch_size,
            )
            particles = posterior._samples
            weights = posterior._log_weights.exp()

    return particles, weights
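The fill-up branch above splices old and new populations in proportion to how far the last population overshot the budget; a small self-contained illustration of that splicing with toy tensors (no pyabc objects involved):

import torch

population_size = 5
num_old_particles = 2  # number of particles taken from the previous population

old_particles = torch.zeros(population_size, 3)  # stand-in for the previous population
new_particles = torch.ones(population_size, 3)   # stand-in for the over-budget population
old_weights = torch.full((population_size,), 0.2)
new_weights = torch.full((population_size,), 0.2)

particles = torch.zeros_like(old_particles)
weights = torch.zeros_like(old_weights)
particles[:num_old_particles] = old_particles[:num_old_particles]
particles[num_old_particles:] = new_particles[num_old_particles:]
weights[:num_old_particles] = old_weights[:num_old_particles]
weights[num_old_particles:] = new_weights[num_old_particles:]
weights /= weights.sum()  # renormalize the combined weights

print(particles)  # first 2 rows from the old population, remaining 3 from the new one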