コード例 #1
0
ファイル: util.py プロジェクト: basnijholt/pyABC
def generate_valid_proposal(
        t: int, m: np.ndarray, p: np.ndarray,
        model_prior: RV, parameter_priors: List[Distribution],
        model_perturbation_kernel: ModelPerturbationKernel,
        transitions: List[Transition]):
    """Propose a (model, parameter) pair with positive prior density.

    For ``t == 0`` the pair is drawn directly from the priors. For later
    generations, a model is picked among the alive ones and perturbed,
    a parameter is drawn from that model's transition, and the draw is
    repeated until the product of model prior mass and parameter prior
    density is positive.

    Parameters
    ----------
    t: Population index to generate for.
    m: Indices of alive models.
    p: Probabilities of alive models.
    model_prior: The model prior.
    parameter_priors: The parameter priors.
    model_perturbation_kernel: The model perturbation kernel.
    transitions: The transitions, one per model.

    Returns
    -------
    (m_ss, theta_ss): Model, parameter.
    """
    if t == 0:
        # first generation: draw straight from the priors
        model = int(model_prior.rvs())
        return model, parameter_priors[model].rvs()

    # later generations: count rejected proposals so that a warning can be
    # emitted once when the count reaches the soft limit
    rejected = 0
    warn_after = 1000
    while True:
        if len(m) <= 1:
            # only one model, nothing to choose or perturb
            candidate = m[0]
        else:
            ancestor = m[fast_random_choice(p)]
            candidate = model_perturbation_kernel.rvs(ancestor)
            # the perturbation kernel can return a model that has died
            # out; such a proposal is discarded and drawn again
            if candidate not in m:
                continue

        drawn = transitions[candidate].rvs()
        theta = Parameter(**drawn.to_dict())

        # accept only if the proposal is positive under the prior
        prior_density = (model_prior.pmf(candidate)
                         * parameter_priors[candidate].pdf(theta))
        if prior_density > 0:
            return candidate, theta

        # unhealthy sampling detection
        rejected += 1
        if rejected == warn_after:
            logger.warning(
                "Unusually many (model, parameter) samples have prior "
                "density zero. The transition might be inappropriate.")
コード例 #2
0
    def setUp(self):
        """Build the nested distributions and parameters used as fixtures.

        ``d`` and ``d_plus_one`` share one structure: a top-level entry
        ``a`` and a nested distribution ``b`` with entries ``b1`` and
        ``b2``, all of them ``randint`` random variables with common
        bounds. ``x_zero``, ``x_one`` and ``x_two`` are parameters of the
        same structure whose entries all equal 0, 1 and 2 respectively.
        """

        def nested_randint(low, high):
            # identical bounds for the top-level and both nested entries
            return Distribution(
                a=RV("randint", low=low, high=high),
                b=Distribution(
                    b1=RV("randint", low=low, high=high),
                    b2=RV("randint", low=low, high=high),
                ),
            )

        def constant_point(value):
            # nested parameter with every entry set to ``value``
            return Parameter(
                {"a": value, "b": Parameter({"b1": value, "b2": value})})

        self.d = nested_randint(0, 3 + 1)
        self.d_plus_one = nested_randint(1, 1 + 1)

        self.x_one = constant_point(1)
        self.x_zero = constant_point(0)
        self.x_two = constant_point(2)
コード例 #3
0
ファイル: test_distance.py プロジェクト: ICB-DCM/pyABC
def test_info_weighted_pnorm_distance():
    """Just test the info weighted distance pipeline.

    Runs a short ABCSMC analysis once per feature normalization option
    with all three log files enabled, and removes every temporary file
    afterwards.
    """
    db_file = create_sqlite_db_id()[len("sqlite:///"):]

    # mkstemp returns an open OS-level handle together with the path; only
    # the path is needed here, so the handle is closed right away instead
    # of being leaked for the rest of the test session.
    log_files = []
    for _ in range(3):
        handle, path = tempfile.mkstemp()
        os.close(handle)
        log_files.append(path)
    scale_log_file, info_log_file, info_sample_log_file = log_files

    try:

        def model(p):
            # one scalar and one 2-vector summary statistic, each a
            # parameter plus standard normal noise
            return {
                "s0": p["p0"] + np.random.normal(),
                "s1": p["p1"] + np.random.normal(size=2),
            }

        prior = Distribution(p0=RV("uniform", 0, 1), p1=RV("uniform", 0, 10))
        data = {"s0": 0.5, "s1": np.array([5, 5])}

        for feature_normalization in ["mad", "std", "weights", "none"]:
            distance = InfoWeightedPNormDistance(
                predictor=LinearPredictor(),
                fit_info_ixs={1, 3},
                feature_normalization=feature_normalization,
                scale_log_file=scale_log_file,
                info_log_file=info_log_file,
                info_sample_log_file=info_sample_log_file,
            )
            abc = ABCSMC(model, prior, distance, population_size=100)
            abc.new("sqlite:///" + db_file, data)
            abc.run(max_nr_populations=3)
    finally:
        # clean up the database and all log files, including the sample
        # log, whether or not the run succeeded
        for path in [db_file] + log_files:
            if os.path.exists(path):
                os.remove(path)
コード例 #4
0
ファイル: test_distance.py プロジェクト: ICB-DCM/pyABC
def test_wasserstein_distance():
    """Test Wasserstein and Sliced Wasserstein distances.

    Checks, for p in {1, 2}, that the distances are positive and smaller
    for data simulated near the reference parameter than for data
    simulated far from it, compares the Wasserstein distance against a
    weighted Minkowski distance of the sorted samples, checks that p=3
    raises ``ValueError``, and finally runs both distances in a short
    ABCSMC analysis.
    """
    n_sample = 11

    def model_1d(p):
        return {"y": np.random.normal(p["p0"], 1.0, size=n_sample)}

    # reference data
    p_true = {"p0": -0.5}
    y0 = model_1d(p_true)

    # data simulated at a parameter close to the reference
    p1 = {"p0": -0.55}
    y1 = model_1d(p1)

    # data simulated at a parameter far from the reference
    p2 = {"p0": 3.55}
    y2 = model_1d(p2)

    class IdSumstat(Sumstat):
        """Identity summary statistic."""
        def __call__(self, data: dict) -> np.ndarray:
            # shape (n, dim)
            return data["y"].reshape((-1, 1))

    for p in [1, 2]:
        for distance in [
                WassersteinDistance(
                    sumstat=IdSumstat(),
                    p=p,
                ),
                SlicedWassersteinDistance(
                    sumstat=IdSumstat(),
                    p=p,
                ),
        ]:
            distance.initialize(x_0=y0)

            # evaluate distance
            dist = distance(y1, y0)

            assert dist > 0

            # sample from somewhere else
            assert dist < distance(y2, y0)

            # compare to ground truth (not done for the sliced variant)
            if isinstance(distance, SlicedWassersteinDistance):
                continue

            # uniform weights over the samples
            w = np.ones(shape=n_sample) / n_sample

            dist_exp = sp_dist.minkowski(
                np.sort(y1["y"].flatten()),
                np.sort(y0["y"].flatten()),
                w=w,
                p=p,
            )

            assert np.isclose(dist, dist_exp)

    with pytest.raises(ValueError):
        WassersteinDistance(sumstat=IdSumstat(), p=3)

    # test integrated
    prior = Distribution(p0=RV("norm", 0, 2))
    # mkstemp also returns an open OS-level handle; only the path is used,
    # so close the handle immediately instead of leaking it
    handle, db_file = tempfile.mkstemp(suffix=".db")
    os.close(handle)
    try:
        for distance in [
                WassersteinDistance(sumstat=IdSumstat(), ),
                SlicedWassersteinDistance(sumstat=IdSumstat(), ),
        ]:
            abc = ABCSMC(model_1d, prior, distance, population_size=10)
            abc.new("sqlite:///" + db_file, y0)
            abc.run(max_nr_populations=3)
    finally:
        os.remove(db_file)
コード例 #5
0
ファイル: smc.py プロジェクト: basnijholt/pyABC
    def __init__(
            self,
            models: Union[List[Model], Model, Callable],
            parameter_priors: Union[List[Distribution],
                                    Distribution, Callable],
            distance_function: Union[Distance, Callable] = None,
            population_size: Union[PopulationStrategy, int] = 100,
            summary_statistics: Callable[[model_output], dict] = identity,
            model_prior: RV = None,
            model_perturbation_kernel: ModelPerturbationKernel = None,
            transitions: Union[List[Transition], Transition] = None,
            eps: Epsilon = None,
            sampler: Sampler = None,
            acceptor: Acceptor = None,
            stop_if_only_single_model_alive: bool = False,
            max_nr_recorded_particles: int = np.inf):
        """Store the configuration, filling in defaults for omitted parts.

        Single models, parameter priors and transitions are wrapped into
        one-element lists. Arguments left as ``None`` are replaced as
        follows: ``PNormDistance()`` for the distance, a ``randint`` RV
        over the model indices for the model prior, a
        ``ModelPerturbationKernel`` with ``probability_to_stay=.7``, one
        ``MultivariateNormalTransition`` per model,
        ``MedianEpsilon(median_multiplier=1)``, ``DefaultSampler()`` and
        ``UniformAcceptor()``. An integer population size is wrapped in
        ``ConstantPopulationSize``.

        Raises
        ------
        AssertionError: If the numbers of models and parameter priors
            differ.
        """
        # models: accept a single model as well as a list
        model_list = models if isinstance(models, list) else [models]
        self.models = [SimpleModel.assert_model(mod) for mod in model_list]

        # parameter priors: same list convention as for the models
        if isinstance(parameter_priors, list):
            self.parameter_priors = parameter_priors
        else:
            self.parameter_priors = [parameter_priors]

        # sanity checks
        n_models = len(self.models)
        if n_models != len(self.parameter_priors):
            raise AssertionError(
                "Number models and number parameter priors have to agree.")

        self.distance_function = to_distance(
            PNormDistance() if distance_function is None
            else distance_function)

        self.summary_statistics = summary_statistics

        self.model_prior = (
            RV("randint", 0, n_models) if model_prior is None
            else model_prior)

        self.model_perturbation_kernel = (
            ModelPerturbationKernel(n_models, probability_to_stay=.7)
            if model_perturbation_kernel is None
            else model_perturbation_kernel)

        # transitions: default to one per model, wrap a single one
        if transitions is None:
            self.transitions = [MultivariateNormalTransition()
                                for _ in self.models]
        elif isinstance(transitions, list):
            self.transitions = transitions
        else:
            self.transitions = [transitions]

        self.eps = MedianEpsilon(median_multiplier=1) if eps is None else eps

        if isinstance(population_size, int):
            self.population_size = ConstantPopulationSize(population_size)
        else:
            self.population_size = population_size

        self.sampler = DefaultSampler() if sampler is None else sampler

        self.acceptor = SimpleFunctionAcceptor.assert_acceptor(
            UniformAcceptor() if acceptor is None else acceptor)

        self.stop_if_only_single_model_alive = stop_if_only_single_model_alive
        self.max_nr_recorded_particles = max_nr_recorded_particles

        # will be set later
        self.x_0 = None
        self.history = None
        self._initial_population = None
        self.minimum_epsilon = None
        self.max_nr_populations = None
        self.min_acceptance_rate = None

        self._sanity_check()
コード例 #6
0
 def test_no_kwargs(self):
     """An RV built from a dictionary with only "type" and "args" samples 0."""
     spec = {"type": "uniform", "args": [0, 0]}
     rv = RV.from_dictionary(spec)
     sample = rv.rvs()
     self.assertEqual(0, sample)