def generate_valid_proposal(
        t: int, m: np.ndarray, p: np.ndarray,
        model_prior: RV,
        parameter_priors: List[Distribution],
        model_perturbation_kernel: ModelPerturbationKernel,
        transitions: List[Transition]):
    """Sample a parameter for a model.

    Parameters
    ----------
    t: Population index to generate for.
    m: Indices of alive models.
    p: Probabilities of alive models.
    model_prior: The model prior.
    parameter_priors: The parameter priors.
    model_perturbation_kernel: The model perturbation kernel.
    transitions: The transitions, one per model.

    Returns
    -------
    (m_ss, theta_ss): Model, parameter.
    """
    # generation 0: draw directly from the priors
    if t == 0:
        m_ss = int(model_prior.rvs())
        return m_ss, parameter_priors[m_ss].rvs()

    # later generations: propose by perturbing the previous population,
    # rejecting proposals until the prior density is positive
    n_sample, n_sample_soft_limit = 0, 1000
    while True:
        if len(m) > 1:
            # pick an alive model, then perturb the model index
            m_s = m[fast_random_choice(p)]
            m_ss = model_perturbation_kernel.rvs(m_s)
            # the model perturbation kernel can propose a model number
            # whose population has died out; in that case simply redraw
            if m_ss not in m:
                continue
        else:
            # only one model remains alive
            m_ss = m[0]
        theta_ss = Parameter(**transitions[m_ss].rvs().to_dict())

        # accept the proposal iff it has positive prior density
        if (model_prior.pmf(m_ss)
                * parameter_priors[m_ss].pdf(theta_ss) > 0):
            return m_ss, theta_ss

        # unhealthy sampling detection: warn once after many rejections
        n_sample += 1
        if n_sample == n_sample_soft_limit:
            logger.warning(
                "Unusually many (model, parameter) samples have prior "
                "density zero. The transition might be inappropriate.")
def setUp(self):
    """Build the test fixtures: two nested distributions and three parameters."""
    def uniform_int(low, high):
        # discrete uniform over {low, ..., high} (randint's high is exclusive)
        return RV("randint", low=low, high=high + 1)

    # distribution over {0, ..., 3} in every coordinate, nested under "b"
    self.d = Distribution(
        a=uniform_int(0, 3),
        b=Distribution(b1=uniform_int(0, 3), b2=uniform_int(0, 3)),
    )
    # point mass at 1 in every coordinate
    self.d_plus_one = Distribution(
        a=uniform_int(1, 1),
        b=Distribution(b1=uniform_int(1, 1), b2=uniform_int(1, 1)),
    )
    # fixed parameter vectors used as evaluation points
    self.x_one = Parameter({"a": 1, "b": Parameter({"b1": 1, "b2": 1})})
    self.x_zero = Parameter({"a": 0, "b": Parameter({"b1": 0, "b2": 0})})
    self.x_two = Parameter({"a": 2, "b": Parameter({"b1": 2, "b2": 2})})
def test_info_weighted_pnorm_distance():
    """Just test the info weighted distance pipeline.

    Runs one short ABC-SMC analysis per feature normalization mode and
    checks only that the pipeline executes; all temporary files are
    removed afterwards.
    """
    db_file = create_sqlite_db_id()[len("sqlite:///"):]
    # tempfile.mkstemp returns (fd, path); close each fd immediately so the
    # descriptor is not leaked -- only the paths are needed here
    fd, scale_log_file = tempfile.mkstemp()
    os.close(fd)
    fd, info_log_file = tempfile.mkstemp()
    os.close(fd)
    fd, info_sample_log_file = tempfile.mkstemp()
    os.close(fd)
    try:
        def model(p):
            # noisy observations around the two parameters
            return {
                "s0": p["p0"] + np.random.normal(),
                "s1": p["p1"] + np.random.normal(size=2),
            }

        prior = Distribution(p0=RV("uniform", 0, 1), p1=RV("uniform", 0, 10))
        data = {"s0": 0.5, "s1": np.array([5, 5])}

        for feature_normalization in ["mad", "std", "weights", "none"]:
            distance = InfoWeightedPNormDistance(
                predictor=LinearPredictor(),
                fit_info_ixs={1, 3},
                feature_normalization=feature_normalization,
                scale_log_file=scale_log_file,
                info_log_file=info_log_file,
                info_sample_log_file=info_sample_log_file,
            )
            abc = ABCSMC(model, prior, distance, population_size=100)
            abc.new("sqlite:///" + db_file, data)
            abc.run(max_nr_populations=3)
    finally:
        # remove every temporary file, including info_sample_log_file,
        # which the previous cleanup missed
        for path in (db_file, scale_log_file, info_log_file,
                     info_sample_log_file):
            if os.path.exists(path):
                os.remove(path)
def test_wasserstein_distance():
    """Test Wasserstein and Sliced Wasserstein distances.

    Checks that the distances are positive, increase for samples drawn
    farther from the observation, and (for the exact Wasserstein distance)
    match the closed-form 1-d ground truth. Finally runs a short
    integrated ABC-SMC analysis with both distances.
    """
    n_sample = 11

    def model_1d(p):
        # n_sample draws from a unit-variance normal centered at p["p0"]
        return {"y": np.random.normal(p["p0"], 1.0, size=n_sample)}

    p_true = {"p0": -0.5}
    y0 = model_1d(p_true)
    p1 = {"p0": -0.55}
    y1 = model_1d(p1)
    p2 = {"p0": 3.55}
    y2 = model_1d(p2)

    class IdSumstat(Sumstat):
        """Identity summary statistic."""

        def __call__(self, data: dict) -> np.ndarray:
            # shape (n, dim)
            return data["y"].reshape((-1, 1))

    for p in [1, 2]:
        for distance in [
            WassersteinDistance(sumstat=IdSumstat(), p=p),
            SlicedWassersteinDistance(sumstat=IdSumstat(), p=p),
        ]:
            distance.initialize(x_0=y0)

            # evaluate distance
            dist = distance(y1, y0)
            assert dist > 0

            # sample from somewhere else
            assert dist < distance(y2, y0)

            # compare to ground truth
            if isinstance(distance, SlicedWassersteinDistance):
                continue

            # uniform weights over the samples
            w = np.ones(shape=n_sample) / n_sample
            # fix: flatten before sorting for both arrays (the original
            # flattened y0 after sorting -- equivalent only because the
            # data are 1-d, and inconsistent with the y1 handling)
            dist_exp = sp_dist.minkowski(
                np.sort(y1["y"].flatten()),
                np.sort(y0["y"].flatten()),
                w=w,
                p=p,
            )
            assert np.isclose(dist, dist_exp)

    # only p in {1, 2} is supported
    with pytest.raises(ValueError):
        WassersteinDistance(sumstat=IdSumstat(), p=3)

    # test integrated
    prior = Distribution(p0=RV("norm", 0, 2))
    # mkstemp returns (fd, path); close the fd so it is not leaked
    fd, db_file = tempfile.mkstemp(suffix=".db")
    os.close(fd)
    try:
        for distance in [
            WassersteinDistance(sumstat=IdSumstat()),
            SlicedWassersteinDistance(sumstat=IdSumstat()),
        ]:
            abc = ABCSMC(model_1d, prior, distance, population_size=10)
            abc.new("sqlite:///" + db_file, y0)
            abc.run(max_nr_populations=3)
    finally:
        os.remove(db_file)
def __init__(
        self,
        models: Union[List[Model], Model, Callable],
        parameter_priors: Union[List[Distribution], Distribution, Callable],
        distance_function: Union[Distance, Callable] = None,
        population_size: Union[PopulationStrategy, int] = 100,
        summary_statistics: Callable[[model_output], dict] = identity,
        model_prior: RV = None,
        model_perturbation_kernel: ModelPerturbationKernel = None,
        transitions: Union[List[Transition], Transition] = None,
        eps: Epsilon = None,
        sampler: Sampler = None,
        acceptor: Acceptor = None,
        stop_if_only_single_model_alive: bool = False,
        max_nr_recorded_particles: int = np.inf):
    """Initialize the ABC-SMC analysis.

    Normalizes all arguments (single values are wrapped into lists,
    ``None`` values are replaced by defaults) and stores them on the
    instance. Run-specific attributes are initialized to ``None`` and
    set later. Raises ``AssertionError`` if the number of models and
    parameter priors disagree.
    """
    # normalize models to a list of Model instances
    if not isinstance(models, list):
        models = [models]
    models = list(map(SimpleModel.assert_model, models))
    self.models = models

    # normalize parameter priors to a list, one per model
    if not isinstance(parameter_priors, list):
        parameter_priors = [parameter_priors]
    self.parameter_priors = parameter_priors

    # sanity checks
    if len(self.models) != len(self.parameter_priors):
        raise AssertionError(
            "Number models and number parameter priors have to agree.")

    # default distance: p-norm distance
    if distance_function is None:
        distance_function = PNormDistance()
    self.distance_function = to_distance(distance_function)

    self.summary_statistics = summary_statistics

    # default model prior: discrete uniform over all models
    if model_prior is None:
        model_prior = RV("randint", 0, len(self.models))
    self.model_prior = model_prior

    # default model perturbation: stay with probability 0.7
    if model_perturbation_kernel is None:
        model_perturbation_kernel = ModelPerturbationKernel(
            len(self.models), probability_to_stay=.7)
    self.model_perturbation_kernel = model_perturbation_kernel

    # default transitions: one multivariate normal kernel per model
    if transitions is None:
        transitions = [MultivariateNormalTransition() for _ in self.models]
    if not isinstance(transitions, list):
        transitions = [transitions]
    self.transitions = transitions  # type: List[Transition]

    # default epsilon schedule: median of previous distances
    if eps is None:
        eps = MedianEpsilon(median_multiplier=1)
    self.eps = eps

    # an int population size means a constant-size strategy
    if isinstance(population_size, int):
        population_size = ConstantPopulationSize(
            population_size)
    self.population_size = population_size

    # default sampler
    if sampler is None:
        sampler = DefaultSampler()
    self.sampler = sampler

    # default acceptor: uniform acceptance within epsilon
    if acceptor is None:
        acceptor = UniformAcceptor()
    self.acceptor = SimpleFunctionAcceptor.assert_acceptor(acceptor)

    self.stop_if_only_single_model_alive = stop_if_only_single_model_alive
    self.max_nr_recorded_particles = max_nr_recorded_particles

    # will be set later
    self.x_0 = None
    self.history = None
    self._initial_population = None
    self.minimum_epsilon = None
    self.max_nr_populations = None
    self.min_acceptance_rate = None

    self._sanity_check()
def test_no_kwargs(self):
    """An RV built from a dictionary with positional args only samples correctly."""
    rv = RV.from_dictionary({"type": "uniform", "args": [0, 0]})
    # uniform(loc=0, scale=0) is a point mass at 0
    self.assertEqual(0, rv.rvs())