Exemple #1
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Draws from an unweighted ``Empirical`` have the expected shape and values.

    Five constant tensors (filled with 0, 1, ..., 4) are added without weights.
    Sampling must return a tensor of size ``sample_shape + batch_shape``. The set
    of drawn values must be exactly {0, ..., 4}.
    """
    dist = Empirical()
    for fill_value in range(5):
        dist.add(torch.ones(batch_shape, dtype=dtype) * fill_value)

    draws = dist.sample(sample_shape=sample_shape)
    assert_equal(draws.size(), sample_shape + batch_shape)

    drawn_values = set(draws.view(-1).tolist())
    assert_equal(drawn_values, set(range(5)))
Exemple #2
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    # Equal-weight Empirical over the constant tensors 0, 1, 2, 3, 4.
    support = range(5)
    empirical_dist = Empirical()
    for level in support:
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * level)

    samples = empirical_dist.sample(sample_shape=sample_shape)
    expected_size = sample_shape + batch_shape
    assert_equal(samples.size(), expected_size)
    # The set of drawn values must be exactly the support.
    assert_equal(set(samples.view(-1).tolist()), set(support))
Exemple #3
0
def test_sample_examples(sample, weights, expected_mean, expected_var):
    """Check the exact and the Monte Carlo moments of a parametrised ``Empirical``."""
    dist = Empirical(sample, weights)
    assert_equal(dist.mean, expected_mean)
    assert_equal(dist.variance, expected_var)

    # Monte Carlo check: moments of many draws must match within 1% relative error.
    draws = dist.sample((10000,))
    assert_close(draws.mean(0), dist.mean, rtol=1e-2)
    assert_close(draws.var(0), dist.variance, rtol=1e-2)
Exemple #4
0
def test_unweighted_mean_and_var(size, dtype):
    """Equal-weight constant samples 0..4 have mean 2 and variance 2."""
    empirical_dist = Empirical()
    for k in range(5):
        empirical_dist.add(torch.ones(size, dtype=dtype) * k)
    # For the values 0..4 the mean and the variance are both 2.
    expected = torch.ones(size) * 2
    assert_equal(empirical_dist.mean, expected)
    assert_equal(empirical_dist.variance, expected)
Exemple #5
0
def test_unweighted_mean_and_var(size, dtype):
    # Build an unweighted Empirical from constant tensors filled with 0..4.
    empirical_dist = Empirical()
    for value in range(5):
        constant = torch.ones(size, dtype=dtype) * value
        empirical_dist.add(constant)

    true_mean, true_var = torch.ones(size) * 2, torch.ones(size) * 2
    assert_equal(empirical_dist.mean, true_mean)
    assert_equal(empirical_dist.variance, true_var)
Exemple #6
0
def test_unweighted_samples(batch_shape, sample_shape, dtype):
    """Sampling an ``Empirical`` built from stacked constants 0..4 with equal weights."""
    stacked = torch.stack(
        [torch.ones(batch_shape, dtype=dtype) * i for i in range(5)]
    )
    empirical_dist = Empirical(stacked, torch.ones(5))

    draws = empirical_dist.sample(sample_shape=torch.Size(sample_shape))
    assert_equal(draws.size(), torch.Size(sample_shape + batch_shape))
    # Every drawn element must be one of the five stored values, and all five appear.
    assert_equal(set(draws.view(-1).tolist()), set(range(5)))
Exemple #7
0
def test_unweighted_samples(batch_shape, event_shape, sample_shape, dtype):
    """Sample shape of a batched ``Empirical`` with a non-trivial event shape."""
    agg_dim_size = 5
    n_dims = len(batch_shape + event_shape) + 1  # +1 for the aggregation dim
    # Permutation that moves the last (aggregation) dim to just after the batch
    # dims, i.e. batch_shape + [agg] + event_shape.
    dims = list(range(n_dims - 1))
    dims.insert(len(batch_shape), n_dims - 1)

    # Values 0..agg_dim_size-1 along the aggregation dim, broadcast elsewhere.
    emp_samples = (
        torch.arange(agg_dim_size, dtype=dtype)
        .expand(batch_shape + event_shape + [agg_dim_size])
        .permute(dims)
    )
    weights = torch.ones(batch_shape + [agg_dim_size])  # equal initial weights
    dist = Empirical(emp_samples, weights)

    drawn = dist.sample(sample_shape=torch.Size(sample_shape))
    expected_size = torch.Size(sample_shape + batch_shape + event_shape)
    assert_equal(drawn.size(), expected_size)
Exemple #8
0
def test_weighted_mean_var(event_shape, dtype):
    """Weighted mean/variance of ``Empirical``; both must raise for integer dtypes.

    The value 1 carries total weight 0.5 + 0.5 = 1.0 and the value 0 carries
    1.5 + 1.5 = 3.0. The expected weighted mean is therefore 1.0 / 4.0 = 0.25, and
    the expected weighted variance is 0.25 * (1 - 0.25) = 0.1875.
    """
    # Integer sample values keep integer-dtype samples integral. Under PyTorch's
    # type promotion, ``1.0 * torch.ones(..., dtype=torch.long)`` is a float
    # tensor, so the integer-dtype branch below would not see integer samples.
    samples = [(1, 0.5), (0, 1.5), (1, 0.5), (0, 1.5)]
    empirical_dist = Empirical()
    for sample, weight in samples:
        empirical_dist.add(sample * torch.ones(event_shape, dtype=dtype), weight=weight)
    if dtype in (torch.float32, torch.float64):
        true_mean = torch.ones(event_shape, dtype=dtype) * 0.25
        true_var = torch.ones(event_shape, dtype=dtype) * 0.1875
        assert_equal(empirical_dist.mean, true_mean)
        assert_equal(empirical_dist.variance, true_var)
    else:
        # Use one ``raises`` block per property. The first statement that raises
        # leaves the block, so a second statement in the same block would never run.
        with pytest.raises(ValueError):
            empirical_dist.mean
        with pytest.raises(ValueError):
            empirical_dist.variance
Exemple #9
0
def test_weighted_mean_var(event_shape, dtype):
    """Weighted mean/variance of ``Empirical``; both must raise for integer dtypes.

    Weights 0.5 + 0.5 on the value 1 and 1.5 + 1.5 on the value 0 give a weighted
    mean of 0.25 and a weighted variance of 0.25 * 0.75 = 0.1875.
    """
    # Integer sample values: ``1.0 * long_tensor`` would be promoted to a float
    # tensor by PyTorch, and the integer-dtype branch would no longer be tested.
    samples = [(1, 0.5), (0, 1.5), (1, 0.5), (0, 1.5)]
    empirical_dist = Empirical()
    for sample, weight in samples:
        empirical_dist.add(sample * torch.ones(event_shape, dtype=dtype),
                           weight=weight)
    if dtype in (torch.float32, torch.float64):
        true_mean = torch.ones(event_shape, dtype=dtype) * 0.25
        true_var = torch.ones(event_shape, dtype=dtype) * 0.1875
        assert_equal(empirical_dist.mean, true_mean)
        assert_equal(empirical_dist.variance, true_var)
    else:
        # Separate blocks: with a single ``raises`` block, the ``.variance``
        # access would be skipped once ``.mean`` raises.
        with pytest.raises(ValueError):
            empirical_dist.mean
        with pytest.raises(ValueError):
            empirical_dist.variance
Exemple #10
0
 def posteriorLearning(self, round):
     """Run inference for one round and record the resulting posterior.

     Non-SMC algorithms: ``self.inference`` is called with
     ``simulation_budget_per_round`` simulations drawn from ``self.proposal``. The
     returned posterior is appended to ``self.posteriors``. The result of
     ``posterior.set_default_x(self.observation)`` becomes the next proposal.

     SMC: all the work happens in round 0, where ``self.inference`` is asked for the
     budget of all rounds at once (``simulation_budget_per_round * numRound``). The
     final population is then wrapped in an ``Empirical`` posterior, stored,
     plotted and logged. Later rounds do nothing.

     Args:
         round: Index of the current round (shadows the builtin ``round``).
     """
     args = self.args
     if args.algorithm == 'SMC':
         if round != 0:
             # Round 0 already requested the simulations for every round.
             return
         budget = args.simulation_budget_per_round
         self.posterior, self.summary = self.inference(
             x_o=self.observation,
             num_particles=budget,
             num_initial_pop=budget,
             num_simulations=budget * args.numRound,
             epsilon_decay=0.9,
             return_summary=True)
         # Rebuild the posterior from the last population's particles and
         # log-weights stored in the summary.
         last_particles = self.summary['particles'][-1]
         last_log_weights = self.summary['weights'][-1]
         self.posterior = Empirical(last_particles, log_weights=last_log_weights)
         self.posteriors.append(self.posterior)
         print("Plotting Start")
         self.plot(round)
         print("Logging Start")
         self.log(round)
         return

     self.posterior = self.inference(
         num_simulations=args.simulation_budget_per_round,
         proposal=self.proposal,
         validation_fraction=args.validationRatio,
         device=args.device)
     self.posteriors.append(self.posterior)
     self.proposal = self.posterior.set_default_x(self.observation)
Exemple #11
0
def test_weighted_sample_coherence(event_shape, dtype):
    """Check shapes, ``log_prob`` and draw frequencies of an ``Empirical`` built via ``add``.

    The value 1.0 has total weight 1.0 and the value 0.0 has total weight 3.0 (out
    of 4.0). Their expected draw frequencies are therefore 25% and 75%.
    """
    weighted_values = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    empirical_dist = Empirical()
    for value, weight in weighted_values:
        empirical_dist.add(value * torch.ones(event_shape, dtype=dtype), weight=weight)
    assert_equal(empirical_dist.event_shape, event_shape)
    assert_equal(empirical_dist.sample_size, 4)

    score_point = torch.ones(event_shape, dtype=dtype) * 1.0
    assert_equal(empirical_dist.log_prob(score_point), torch.tensor(0.25).log())

    n_draws = 1000
    draws = empirical_dist.sample(sample_shape=torch.Size((n_draws,)))

    def match_fraction(target):
        # A draw matches only if every one of its elements equals ``target``.
        per_draw = draws.eq(target).contiguous().view(n_draws, -1).min(dim=-1)[0]
        return per_draw.float().sum().item() / n_draws

    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    assert_equal(match_fraction(zeros), 0.75, prec=0.02)
    assert_equal(match_fraction(ones), 0.25, prec=0.02)
Exemple #12
0
def test_mean_var_non_nan():
    """Moments stay finite when every log-weight is very negative.

    All ten samples are the same tensor, so the mean must equal it and the
    variance must be zero. With log-weights of -1000, ``exp`` of each weight
    underflows to 0, which a naive normalisation would turn into NaN.
    """
    true_mean = torch.randn([1, 2, 3])
    num_copies = 10
    samples = torch.stack([true_mean] * num_copies)
    weights = torch.stack([torch.tensor(-1000.)] * num_copies)
    empirical_dist = Empirical(samples, weights)
    assert_equal(empirical_dist.mean, true_mean)
    assert_equal(empirical_dist.variance, torch.zeros_like(true_mean))
Exemple #13
0
def test_unweighted_mean_and_var(size, dtype):
    """Constant samples 0..4 with equal weights have mean 2 and variance 2."""
    samples = torch.stack([torch.ones(size, dtype=dtype) * k for k in range(5)])
    weights = torch.ones(5, dtype=dtype)  # identical weights -> uniform
    empirical_dist = Empirical(samples, weights)
    assert_equal(empirical_dist.mean, torch.ones(size) * 2)
    assert_equal(empirical_dist.variance, torch.ones(size) * 2)
Exemple #14
0
def test_weighted_sample_coherence(event_shape, dtype):
    """``log_prob`` and sampling frequencies must reflect the supplied weights."""
    data = [(1.0, 0.5), (0.0, 1.5), (1.0, 0.5), (0.0, 1.5)]
    samples = torch.stack(
        [value * torch.ones(event_shape, dtype=dtype) for value, _ in data]
    )
    # Weights are passed to Empirical as log-weights.
    weights = torch.stack([torch.tensor(weight).log() for _, weight in data])
    empirical_dist = Empirical(samples, weights)
    assert_equal(empirical_dist.event_shape, torch.Size(event_shape))
    assert_equal(empirical_dist.sample_size, 4)

    # The value 1.0 carries weight 0.5 + 0.5 out of a total of 4.0.
    sample_to_score = torch.ones(event_shape, dtype=dtype) * 1.0
    assert_equal(empirical_dist.log_prob(sample_to_score), torch.tensor(0.25).log())

    num_draws = 1000
    draws = empirical_dist.sample(sample_shape=torch.Size((num_draws,)))
    zeros = torch.zeros(event_shape, dtype=dtype)
    ones = torch.ones(event_shape, dtype=dtype)
    # A draw matches a target only if all of its elements match.
    zero_hits = draws.eq(zeros).contiguous().view(num_draws, -1).min(dim=-1)[0]
    one_hits = draws.eq(ones).contiguous().view(num_draws, -1).min(dim=-1)[0]
    assert_equal(zero_hits.float().sum().item() / num_draws, 0.75, prec=0.02)
    assert_equal(one_hits.float().sum().item() / num_draws, 0.25, prec=0.02)
Exemple #15
0
def test_log_prob(batch_shape, event_shape, dtype):
    """``log_prob`` on a batched, equal-weight ``Empirical`` over the values 0..4."""
    num_atoms = 5
    atoms = torch.stack(
        [torch.ones(event_shape, dtype=dtype) * i for i in range(num_atoms)]
    )
    samples = atoms.expand(batch_shape + [num_atoms] + event_shape)
    weights = torch.tensor(1.0).expand(batch_shape + [num_atoms])
    empirical_dist = Empirical(samples, weights)

    # An in-support value has probability 1/5 in every batch element.
    in_support = torch.tensor(1, dtype=dtype).expand(batch_shape + event_shape)
    expected = (weights.new_ones(batch_shape + [1]) * 0.2).sum(-1).log()
    assert_equal(empirical_dist.log_prob(in_support), expected)

    # Value outside support returns -Inf
    out_of_support = torch.tensor(1, dtype=dtype).expand(batch_shape + event_shape) * 6
    log_prob = empirical_dist.log_prob(out_of_support)
    assert log_prob.shape == torch.Size(batch_shape)
    assert torch.isinf(log_prob).all()

    # Vectorized ``log_prob`` raises ValueError
    with pytest.raises(ValueError):
        batched_values = torch.ones([3] + batch_shape + event_shape, dtype=dtype)
        empirical_dist.log_prob(batched_values)
Exemple #16
0
def test_log_prob(batch_shape, dtype):
    """``log_prob`` of an equal-weight ``Empirical`` over the constant tensors 0..4."""
    empirical_dist = Empirical()
    for i in range(5):
        empirical_dist.add(torch.ones(batch_shape, dtype=dtype) * i)

    # Each of the five stored values has probability 1/5.
    in_support = torch.ones(batch_shape, dtype=dtype)
    assert_equal(empirical_dist.log_prob(in_support), torch.tensor(0.2).log())

    # A value that was never added has log-probability -inf.
    outside = torch.ones(batch_shape, dtype=dtype) * 6
    assert empirical_dist.log_prob(outside) == -float("inf")

    # Scoring a batch of values at once is rejected with ValueError.
    with pytest.raises(ValueError):
        batch_of_values = torch.ones((3, ) + batch_shape, dtype=dtype)
        empirical_dist.log_prob(batch_of_values)
Exemple #17
0
def test_log_prob(batch_shape, dtype):
    # Five equally weighted constant tensors 0, 1, 2, 3, 4.
    empirical_dist = Empirical()
    for level in range(5):
        empirical_dist.add(level * torch.ones(batch_shape, dtype=dtype))

    ones = torch.ones(batch_shape, dtype=dtype)
    # In-support value: probability 1/5.
    assert_equal(empirical_dist.log_prob(ones), torch.tensor(0.2).log())

    # Out-of-support value: log-probability is -inf.
    log_prob = empirical_dist.log_prob(ones * 6)
    assert log_prob == -float("inf")

    # Passing several values at once raises ValueError.
    with pytest.raises(ValueError):
        empirical_dist.log_prob(torch.ones((3,) + batch_shape, dtype=dtype))
Exemple #18
0
def test_weighted_mean_var(event_shape, dtype, batch_shape):
    """Batched weighted mean/variance of ``Empirical``; integer dtypes must raise.

    Weights 0.5 + 0.5 on the value 1 and 1.5 + 1.5 on the value 0 give a weighted
    mean of 0.25 and a weighted variance of 0.25 * 0.75 = 0.1875 in every batch
    element.
    """
    data = [(1, 0.5), (0, 1.5), (1, 0.5), (0, 1.5)]
    # Weights are stored as log-weights. ``torch.long`` cannot represent the
    # fractional weights, so that case uses the default floating dtype instead.
    weight_dtype = dtype if dtype is not torch.long else None
    samples, weights = [], []
    for sample, weight in data:
        samples.append(sample * torch.ones(event_shape, dtype=dtype))
        weights.append(torch.tensor(weight, dtype=weight_dtype).log())
    samples = torch.stack(samples).expand(batch_shape + [4] + event_shape)
    weights = torch.stack(weights).expand(batch_shape + [4])
    empirical_dist = Empirical(samples, weights)
    if dtype in (torch.float32, torch.float64):
        true_mean = torch.ones(batch_shape + event_shape, dtype=dtype) * 0.25
        true_var = torch.ones(batch_shape + event_shape, dtype=dtype) * 0.1875
        assert_equal(empirical_dist.mean, true_mean)
        assert_equal(empirical_dist.variance, true_var)
    else:
        # One ``raises`` block per property. Inside a single block the ``.mean``
        # access would exit it, so ``.variance`` would never be checked.
        with pytest.raises(ValueError):
            empirical_dist.mean
        with pytest.raises(ValueError):
            empirical_dist.variance
Exemple #19
0
    def __call__(
        self,
        x_o: Union[Tensor, ndarray],
        num_simulations: int,
        eps: Optional[float] = None,
        quantile: Optional[float] = None,
        return_distances: bool = False,
    ) -> Union[Distribution, Tuple[Distribution, Tensor]]:
        r"""Run MCABC.

        Draws ``num_simulations`` parameters from the prior and simulates data for
        each. It then keeps either all parameters whose distance to ``x_o`` is below
        ``eps``, or the ``quantile`` fraction with the smallest distances.

        Args:
            x_o: Observed data.
            num_simulations: Number of simulations to run.
            eps: Acceptance threshold $\epsilon$ for distance between observed and
                simulated data.
            quantile: Upper quantile of smallest distances for which the corresponding
                parameters are returned, e.g, q=0.01 will return the top 1%. Exactly
                one of quantile or `eps` have to be passed.
            return_distances: Whether to return the distances corresponding to the
                selected parameters.

        Returns:
            posterior: Empirical distribution over the selected parameters, with
                equal weights.
            distances: Tensor of distances of the selected parameters (only when
                ``return_distances`` is True).

        Raises:
            AssertionError: If not exactly one of ``eps``/``quantile`` is passed, or
                if no parameters are accepted.
        """
        # Exactly one of eps or quantile need to be passed.
        assert (eps is not None) ^ (
            quantile
            is not None), "Eps or quantile must be passed, but not both."

        # Simulate and calculate distances.
        theta = self.prior.sample((num_simulations, ))
        x = self._batched_simulator(theta)

        # Infer shape of x to test and set x_o.
        self.x_shape = x[0].unsqueeze(0).shape
        self.x_o = process_x(x_o, self.x_shape)

        distances = self.distance(self.x_o, x)

        # Select based on acceptance threshold epsilon.
        if eps is not None:
            is_accepted = distances < eps
            num_accepted = is_accepted.sum().item()
            assert num_accepted > 0, f"No parameters accepted, eps={eps} too small"

            theta_accepted = theta[is_accepted]
            distances_accepted = distances[is_accepted]

        # Select based on quantile on sorted distances.
        elif quantile is not None:
            num_top_samples = int(num_simulations * quantile)
            # int() truncates, so a small quantile can select nothing at all. Fail
            # like the eps branch does, instead of returning an empty posterior.
            assert num_top_samples > 0, (
                f"No parameters accepted, quantile={quantile} too small for "
                f"num_simulations={num_simulations}"
            )
            top_idx = torch.argsort(distances)[:num_top_samples]
            theta_accepted = theta[top_idx]
            distances_accepted = distances[top_idx]

        else:
            raise ValueError("One of epsilon or quantile has to be passed.")

        # All log-weights are equal, so the Empirical is uniform over the samples.
        posterior = Empirical(theta_accepted,
                              log_weights=ones(theta_accepted.shape[0]))

        if return_distances:
            return posterior, distances_accepted
        return posterior
    def __call__(
        self,
        x_o: Union[Tensor, ndarray],
        num_particles: int,
        num_initial_pop: int,
        num_simulations: int,
        epsilon_decay: float,
        distance_based_decay: bool = False,
        ess_min: float = 0.5,
        kernel_variance_scale: float = 1.0,
        use_last_pop_samples: bool = True,
        return_summary: bool = False,
    ) -> Union[Distribution, Tuple[Distribution, dict]]:
        r"""Run SMCABC.

        Samples an initial population, then repeatedly lowers the acceptance
        threshold $\epsilon$ and samples a new population from the previous one.
        This continues while ``self.simulation_counter`` is below ``num_simulations``.

        Args:
            x_o: Observed data.
            num_particles: Number of particles in each population.
            num_initial_pop: Number of simulations used for initial population.
            num_simulations: Total number of possible simulations.
            epsilon_decay: Factor with which the acceptance threshold $\epsilon$ decays.
            distance_based_decay: Whether the $\epsilon$ decay is constant over
                populations or calculated from the previous populations distribution of
                distances.
            ess_min: Threshold of effective sampling size for resampling weights.
                Only used when ``self.algorithm_variant == "B"``.
            kernel_variance_scale: Factor for scaling the perturbation kernel variance.
            use_last_pop_samples: Whether to fill up the current population with
                samples from the previous population when the budget is used up. If
                False, the current population is discarded and the previous population
                is returned.
            return_summary: Whether to return a dictionary with all accepted particles,
                weights, etc. at the end.

        Returns:
            posterior: Empirical posterior distribution defined by the particles and
                log-weights of the last population.
            summary (optional): A dictionary with keys ``particles``, ``weights``
                (log-weights), ``epsilons``, ``distances`` and ``budgets``. Each holds
                one entry per population. ``budgets`` is 0 for the initial population
                and ``self.simulation_counter`` after each later population.
        """

        # Index of the current population; 0 is the initial population.
        pop_idx = 0
        self.num_simulations = num_simulations

        # run initial population
        particles, epsilon, distances = self._set_xo_and_sample_initial_population(
            x_o, num_particles, num_initial_pop
        )
        # The initial population starts with uniform weights, kept in log space.
        log_weights = torch.log(1 / num_particles * ones(num_particles))

        self.logger.info(
            (
                f"population={pop_idx}, eps={epsilon}, ess={1.0}, "
                f"num_sims={num_initial_pop}"
            )
        )

        # Per-population history; entry i belongs to population i.
        all_particles = [particles]
        all_log_weights = [log_weights]
        all_distances = [distances]
        all_epsilons = [epsilon]
        # NOTE(review): the initial entry is 0, while later entries record
        # self.simulation_counter after each population. Confirm this offset is
        # intended by consumers of summary["budgets"].
        all_budgets = [0]

        # NOTE(review): termination presumably relies on the sampling helpers
        # advancing self.simulation_counter; confirm.
        while self.simulation_counter < num_simulations:

            pop_idx += 1
            # Decay based on quantile of distances from previous pop.
            if distance_based_decay:
                epsilon = self._get_next_epsilon(
                    all_distances[pop_idx - 1], epsilon_decay
                )
            # Constant decay.
            else:
                epsilon = epsilon * epsilon_decay

            # Get kernel variance from previous pop.
            self.kernel_variance = self.get_kernel_variance(
                all_particles[pop_idx - 1],
                torch.exp(all_log_weights[pop_idx - 1]),
                num_samples=1000,
                kernel_variance_scale=kernel_variance_scale,
            )
            particles, log_weights, distances = self._sample_next_population(
                particles=all_particles[pop_idx - 1],
                log_weights=all_log_weights[pop_idx - 1],
                distances=all_distances[pop_idx - 1],
                epsilon=epsilon,
                use_last_pop_samples=use_last_pop_samples,
            )

            # Resample population if effective sampling size is too small.
            # Only algorithm variant "B" does this; ess_min is unused otherwise.
            if self.algorithm_variant == "B":
                particles, log_weights = self.resample_if_ess_too_small(
                    particles, log_weights, num_particles, ess_min, pop_idx
                )

            self.logger.info(
                (
                    f"population={pop_idx} done: eps={epsilon:.6f},"
                    f" num_sims={self.simulation_counter}."
                )
            )

            # collect results
            all_particles.append(particles)
            all_log_weights.append(log_weights)
            all_distances.append(distances)
            all_epsilons.append(epsilon)
            all_budgets.append(self.simulation_counter)

        # The returned posterior is built from the last population only.
        posterior = Empirical(all_particles[-1], log_weights=all_log_weights[-1])

        if return_summary:
            return (
                posterior,
                dict(
                    particles=all_particles,
                    weights=all_log_weights,
                    epsilons=all_epsilons,
                    distances=all_distances,
                    budgets=all_budgets
                ),
            )
        else:
            return posterior
Exemple #21
0
    def __call__(
        self,
        x_o: Union[Tensor, ndarray],
        num_simulations: int,
        eps: Optional[float] = None,
        quantile: Optional[float] = None,
        return_distances: bool = False,
        return_x_accepted: bool = False,
        lra: bool = False,
        sass: bool = False,
        sass_fraction: float = 0.25,
        sass_expansion_degree: int = 1,
    ) -> Union[
        Distribution,
        Tuple[Distribution, Tensor],
        Tuple[Distribution, Tensor, Tensor],
    ]:
        r"""Run MCABC.

        Args:
            x_o: Observed data.
            num_simulations: Number of simulations to run. When ``sass`` is True this
                includes the pilot simulations used to fit the SASS transform.
            eps: Acceptance threshold $\epsilon$ for distance between observed and
                simulated data.
            quantile: Upper quantile of smallest distances for which the corresponding
                parameters are returned, e.g, q=0.01 will return the top 1%. Exactly
                one of quantile or `eps` have to be passed.
            return_distances: Whether to return the distances corresponding to
                the accepted parameters.
            return_x_accepted: Whether to return the simulated data corresponding to
                the accepted parameters.
            lra: Whether to run linear regression adjustment as in Beaumont et al. 2002
            sass: Whether to determine semi-automatic summary statistics as in
                Fearnhead & Prangle 2012.
            sass_fraction: Fraction of simulation budget used for the initial sass run.
            sass_expansion_degree: Degree of the polynomial feature expansion for the
                sass regression, default 1 - no expansion.

        Returns:
            posterior: Empirical distribution over the accepted parameters (adjusted
                by LRA when ``lra`` is True), with equal weights.
            distances: Distances of the accepted parameters; returned only when
                ``return_distances`` is True.
            x_accepted: Simulated data of the accepted parameters; returned only when
                ``return_x_accepted`` is True. When both flags are set the result is
                ``(posterior, distances, x_accepted)``.

        Raises:
            AssertionError: If not exactly one of ``eps``/``quantile`` is passed, or
                if no parameters are accepted.
        """
        # Exactly one of eps or quantile need to be passed.
        assert (eps is not None) ^ (
            quantile
            is not None), "Eps or quantile must be passed, but not both."

        # Run SASS and change the simulator and x_o accordingly.
        if sass:
            num_pilot_simulations = int(sass_fraction * num_simulations)
            self.logger.info(
                f"Running SASS with {num_pilot_simulations} pilot samples.")
            # The pilot run is paid for out of the total simulation budget.
            num_simulations -= num_pilot_simulations

            pilot_theta = self.prior.sample((num_pilot_simulations, ))
            pilot_x = self._batched_simulator(pilot_theta)

            sass_transform = self.get_sass_transform(pilot_theta, pilot_x,
                                                     sass_expansion_degree)

            def simulator(theta):
                # Simulate, then map the data into the learned summary space.
                return sass_transform(self._batched_simulator(theta))

            x_o = sass_transform(x_o)
        else:
            simulator = self._batched_simulator

        # Simulate and calculate distances.
        theta = self.prior.sample((num_simulations, ))
        x = simulator(theta)

        # Infer shape of x to test and set x_o.
        self.x_shape = x[0].unsqueeze(0).shape
        self.x_o = process_x(x_o, self.x_shape)

        distances = self.distance(self.x_o, x)

        # Select based on acceptance threshold epsilon.
        if eps is not None:
            is_accepted = distances < eps
            num_accepted = is_accepted.sum().item()
            assert num_accepted > 0, f"No parameters accepted, eps={eps} too small"

            theta_accepted = theta[is_accepted]
            distances_accepted = distances[is_accepted]
            x_accepted = x[is_accepted]

        # Select based on quantile on sorted distances.
        elif quantile is not None:
            num_top_samples = int(num_simulations * quantile)
            # int() truncates, so a small quantile can select nothing at all. Fail
            # like the eps branch does, instead of returning an empty posterior.
            assert num_top_samples > 0, (
                f"No parameters accepted, quantile={quantile} too small for "
                f"num_simulations={num_simulations}"
            )
            top_idx = torch.argsort(distances)[:num_top_samples]
            theta_accepted = theta[top_idx]
            distances_accepted = distances[top_idx]
            x_accepted = x[top_idx]

        else:
            raise ValueError("One of epsilon or quantile has to be passed.")

        # Maybe adjust theta with LRA.
        if lra:
            self.logger.info("Running Linear regression adjustment.")
            theta_adjusted = self.run_lra(theta_accepted,
                                          x_accepted,
                                          observation=self.x_o)
        else:
            theta_adjusted = theta_accepted

        # All log-weights are equal, so the Empirical is uniform over the samples.
        posterior = Empirical(theta_adjusted,
                              log_weights=ones(theta_accepted.shape[0]))

        if return_distances and return_x_accepted:
            return posterior, distances_accepted, x_accepted
        if return_distances:
            return posterior, distances_accepted
        if return_x_accepted:
            return posterior, x_accepted
        return posterior
Exemple #22
0
    def __call__(
        self,
        x_o: Union[Tensor, ndarray],
        num_particles: int,
        num_initial_pop: int,
        num_simulations: int,
        epsilon_decay: float,
        distance_based_decay: bool = False,
        ess_min: Optional[float] = None,
        kernel_variance_scale: float = 1.0,
        use_last_pop_samples: bool = True,
        return_summary: bool = False,
        lra: bool = False,
        lra_with_weights: bool = False,
        sass: bool = False,
        sass_fraction: float = 0.25,
        sass_expansion_degree: int = 1,
    ) -> Union[Distribution, Tuple[Distribution, dict]]:
        r"""Run SMCABC.

        Samples an initial population, then repeatedly lowers the acceptance
        threshold $\epsilon$ and samples a new population from the previous one.
        This continues while ``self.simulation_counter`` is below
        ``self.num_simulations``. Optionally, SASS summary statistics are learned
        first from a pilot run, and LRA is applied to the last population.

        Args:
            x_o: Observed data.
            num_particles: Number of particles in each population.
            num_initial_pop: Number of simulations used for initial population.
            num_simulations: Total number of possible simulations.
            epsilon_decay: Factor with which the acceptance threshold $\epsilon$ decays.
            distance_based_decay: Whether the $\epsilon$ decay is constant over
                populations or calculated from the previous populations distribution of
                distances.
            ess_min: Threshold of effective sampling size for resampling weights. Not
                used when None (default).
            kernel_variance_scale: Factor for scaling the perturbation kernel variance.
            use_last_pop_samples: Whether to fill up the current population with
                samples from the previous population when the budget is used up. If
                False, the current population is discarded and the previous population
                is returned.
            return_summary: Whether to return a dictionary with all accepted particles,
                weights, etc. at the end.
            lra: Whether to run linear regression adjustment as in Beaumont et al. 2002
            lra_with_weights: Whether to run lra as weighted linear regression with SMC
                weights
            sass: Whether to determine semi-automatic summary statistics as in
                Fearnhead & Prangle 2012.
            sass_fraction: Fraction of simulation budget used for the initial sass run.
            sass_expansion_degree: Degree of the polynomial feature expansion for the
                sass regression, default 1 - no expansion.

        Returns:
            posterior: Empirical posterior distribution defined by the particles and
                log-weights of the last population (LRA-adjusted when ``lra``).
            summary (optional): A dictionary with keys ``particles``, ``weights``
                (log-weights), ``epsilons``, ``distances`` and ``xs`` (simulated
                data), each holding one entry per population.
        """

        # Index of the current population; 0 is the initial population.
        pop_idx = 0
        self.num_simulations = num_simulations

        # Pilot run for SASS.
        if sass:
            num_pilot_simulations = int(sass_fraction * num_simulations)
            self.logger.info(
                f"Running SASS with {num_pilot_simulations} pilot samples."
            )
            sass_transform = self.run_sass_set_xo(
                num_particles, num_pilot_simulations, x_o, lra, sass_expansion_degree
            )
            # Update simulator and x_o: from here on, observed and simulated data
            # are both mapped through the learned SASS transform.
            x_o = sass_transform(self.x_o)

            def sass_simulator(theta):
                self.simulation_counter += theta.shape[0]
                return sass_transform(self._batched_simulator(theta))

            # NOTE(review): this overrides _simulate_with_budget on the instance and
            # never restores it, so later calls on the same object may keep the
            # SASS transform even with sass=False. Confirm this is intended.
            self._simulate_with_budget = sass_simulator

        # run initial population
        particles, epsilon, distances, x = self._set_xo_and_sample_initial_population(
            x_o, num_particles, num_initial_pop
        )
        # The initial population starts with uniform weights, kept in log space.
        log_weights = torch.log(1 / num_particles * ones(num_particles))

        self.logger.info(
            (
                f"population={pop_idx}, eps={epsilon}, ess={1.0}, "
                f"num_sims={num_initial_pop}"
            )
        )

        # Per-population history; entry i belongs to population i.
        all_particles = [particles]
        all_log_weights = [log_weights]
        all_distances = [distances]
        all_epsilons = [epsilon]
        all_x = [x]

        while self.simulation_counter < self.num_simulations:

            pop_idx += 1
            # Decay based on quantile of distances from previous pop.
            if distance_based_decay:
                epsilon = self._get_next_epsilon(
                    all_distances[pop_idx - 1], epsilon_decay
                )
            # Constant decay.
            else:
                epsilon *= epsilon_decay

            # Get kernel variance from previous pop.
            self.kernel_variance = self.get_kernel_variance(
                all_particles[pop_idx - 1],
                torch.exp(all_log_weights[pop_idx - 1]),
                samples_per_dim=500,
                kernel_variance_scale=kernel_variance_scale,
            )
            particles, log_weights, distances, x = self._sample_next_population(
                particles=all_particles[pop_idx - 1],
                log_weights=all_log_weights[pop_idx - 1],
                distances=all_distances[pop_idx - 1],
                epsilon=epsilon,
                x=all_x[pop_idx - 1],
                use_last_pop_samples=use_last_pop_samples,
            )

            # Resample population if effective sampling size is too small.
            if ess_min is not None:
                particles, log_weights = self.resample_if_ess_too_small(
                    particles, log_weights, ess_min, pop_idx
                )

            self.logger.info(
                (
                    f"population={pop_idx} done: eps={epsilon:.6f},"
                    f" num_sims={self.simulation_counter}."
                )
            )

            # collect results
            all_particles.append(particles)
            all_log_weights.append(log_weights)
            all_distances.append(distances)
            all_epsilons.append(epsilon)
            all_x.append(x)

        # Maybe run LRA and adjust weights.
        if lra:
            self.logger.info("Running Linear regression adjustment.")
            # Pass the observation stored on self.x_o during setup (as MCABC does),
            # not the raw x_o argument, which may be an ndarray or unprocessed.
            adjusted_particles, adjusted_weights = self.run_lra_update_weights(
                particles=all_particles[-1],
                xs=all_x[-1],
                observation=self.x_o,
                log_weights=all_log_weights[-1],
                lra_with_weights=lra_with_weights,
            )
            posterior = Empirical(adjusted_particles, log_weights=adjusted_weights)
        else:
            posterior = Empirical(all_particles[-1], log_weights=all_log_weights[-1])

        if return_summary:
            return (
                posterior,
                dict(
                    particles=all_particles,
                    weights=all_log_weights,
                    epsilons=all_epsilons,
                    distances=all_distances,
                    xs=all_x,
                ),
            )
        else:
            return posterior