Example #1
def main():
    task = "mg1"
    simulator, prior = simulators.get_simulator_and_prior(task)
    parameter_dim, observation_dim = (
        simulator.parameter_dim,
        simulator.observation_dim,
    )
    true_observation = simulator.get_ground_truth_observation()
    neural_likelihood = utils.get_neural_likelihood(
        "maf", parameter_dim, observation_dim
    )
    snl = SNL(
        simulator=simulator,
        true_observation=true_observation,
        prior=prior,
        neural_likelihood=neural_likelihood,
        mcmc_method="slice-np",
    )

    num_rounds, num_simulations_per_round = 10, 1000
    snl.run_inference(
        num_rounds=num_rounds, num_simulations_per_round=num_simulations_per_round
    )

    samples = snl.sample_posterior(1000)
    samples = utils.tensor2numpy(samples)
    figure = utils.plot_hist_marginals(
        data=samples,
        ground_truth=utils.tensor2numpy(
            simulator.get_ground_truth_parameters()
        ).reshape(-1),
        lims=simulator.parameter_plotting_limits,
    )
    figure.savefig("./corner-posterior-snl.pdf")
Example #2
def test_():
    task = "nonlinear-gaussian"
    simulator, prior = simulators.get_simulator_and_prior(task)
    parameter_dim, observation_dim = (
        simulator.parameter_dim,
        simulator.observation_dim,
    )
    true_observation = simulator.get_ground_truth_observation()
    neural_posterior = utils.get_neural_posterior("maf", parameter_dim,
                                                  observation_dim, simulator)
    apt = APT(
        simulator=simulator,
        true_observation=true_observation,
        prior=prior,
        neural_posterior=neural_posterior,
        num_atoms=-1,
        use_combined_loss=False,
        train_with_mcmc=False,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        discard_prior_samples=False,
    )

    num_rounds, num_simulations_per_round = 20, 1000
    apt.run_inference(num_rounds=num_rounds,
                      num_simulations_per_round=num_simulations_per_round)

    samples = apt.sample_posterior(2500)
    samples = utils.tensor2numpy(samples)
    figure = utils.plot_hist_marginals(
        data=samples,
        ground_truth=utils.tensor2numpy(
            simulator.get_ground_truth_parameters()).reshape(-1),
        lims=simulator.parameter_plotting_limits,
    )
    figure.savefig(
        os.path.join(utils.get_output_root(), "corner-posterior-apt.pdf"))

    samples = apt.sample_posterior_mcmc(num_samples=1000)
    samples = utils.tensor2numpy(samples)
    figure = utils.plot_hist_marginals(
        data=samples,
        ground_truth=utils.tensor2numpy(
            simulator.get_ground_truth_parameters()).reshape(-1),
        lims=simulator.parameter_plotting_limits,
    )
    figure.savefig(
        os.path.join(utils.get_output_root(), "corner-posterior-apt-mcmc.pdf"))
Example #3
def test_():
    # if torch.cuda.is_available():
    #     device = torch.device("cuda")
    #     torch.set_default_tensor_type("torch.cuda.FloatTensor")
    # else:
    #     device = torch.device("cpu")
    #     torch.set_default_tensor_type("torch.FloatTensor")

    loc = torch.Tensor([0, 0])
    covariance_matrix = torch.Tensor([[1, 0.99], [0.99, 1]])

    likelihood = distributions.MultivariateNormal(
        loc=loc, covariance_matrix=covariance_matrix)
    bound = 1.5
    low, high = -bound * torch.ones(2), bound * torch.ones(2)
    prior = distributions.Uniform(low=low, high=high)

    # def potential_function(inputs_dict):
    #     parameters = next(iter(inputs_dict.values()))
    #     return -(likelihood.log_prob(parameters) + prior.log_prob(parameters).sum())
    prior = distributions.Uniform(low=-5 * torch.ones(4),
                                  high=2 * torch.ones(4))
    from lfi.nsf import distributions as distributions_

    likelihood = distributions_.LotkaVolterraOscillating()
    potential_function = PotentialFunction(likelihood, prior)

    # kernel = Slice(potential_function=potential_function)
    from pyro.infer.mcmc import HMC, NUTS

    # kernel = HMC(potential_fn=potential_function)
    kernel = NUTS(potential_fn=potential_function)
    num_chains = 3
    sampler = MCMC(
        kernel=kernel,
        num_samples=10000 // num_chains,
        warmup_steps=200,
        initial_params={"": torch.zeros(num_chains, 4)},
        num_chains=num_chains,
    )
    sampler.run()
    samples = next(iter(sampler.get_samples().values()))

    # NOTE: `loc` comes from the earlier 2-D Gaussian setup and does not match the
    # 4-D Lotka-Volterra samples drawn above.
    utils.plot_hist_marginals(utils.tensor2numpy(samples),
                              ground_truth=utils.tensor2numpy(loc),
                              lims=[-6, 3])
    # plt.show()
    plt.savefig("/home/conor/Dropbox/phd/projects/lfi/out/mcmc.pdf")
    plt.close()
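
For reference, the PotentialFunction object built above is not defined in this snippet. Based on the commented-out potential_function earlier in the test, a minimal sketch could look like the following (hypothetical class; assumes Pyro's convention that potential_fn receives a dict of parameter tensors and returns the negative log joint density):

class PotentialFunction:
    """Sketch of a picklable potential for Pyro MCMC (assumed, not the repo's code)."""

    def __init__(self, likelihood, prior):
        self.likelihood = likelihood
        self.prior = prior

    def __call__(self, parameters_dict):
        # Pyro passes a dict of parameter tensors; take the single entry.
        parameters = next(iter(parameters_dict.values()))
        # Potential energy = negative unnormalized log joint density.
        return -(
            self.likelihood.log_prob(parameters)
            + self.prior.log_prob(parameters).sum()
        )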
Example #4
    def simulate(self, parameters):

        parameters = utils.tensor2numpy(parameters)

        assert parameters.shape[1] == 3, "parameter must be 3-dimensional"
        p1, p2, p3 = parameters[:, 0:1], parameters[:, 1:2], parameters[:, 2:3]
        N = parameters.shape[0]

        # service times (uniformly distributed)
        sts = (p2 - p1) * np.random.rand(N, self.n_sim_steps) + p1

        # inter-arrival times (exponentially distributed)
        iats = -np.log(1.0 - np.random.rand(N, self.n_sim_steps)) / p3

        # arrival times
        ats = np.cumsum(iats, axis=1)

        # inter-departure times
        idts = np.empty([N, self.n_sim_steps], dtype=float)
        idts[:, 0] = sts[:, 0] + ats[:, 0]

        # departure times
        dts = np.empty([N, self.n_sim_steps], dtype=float)
        dts[:, 0] = idts[:, 0]

        for i in range(1, self.n_sim_steps):
            idts[:, i] = sts[:, i] + np.maximum(0.0, ats[:, i] - dts[:, i - 1])
            dts[:, i] = dts[:, i - 1] + idts[:, i]

        self.num_total_simulations += N

        if self._summarize_observations:
            idts = self._summarizer(idts)

        return torch.Tensor(idts)
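
The vectorized code above simulates inter-departure times of an M/G/1 queue: service times are Uniform[p1, p2], inter-arrival times are Exponential with rate p3, and departures follow the recursion idt_i = st_i + max(0, at_i - dt_{i-1}). A minimal single-trajectory sketch of the same recursion (illustrative only; plain NumPy, not part of the simulator class):

import numpy as np


def mg1_interdeparture_times(p1, p2, p3, num_steps, rng=np.random):
    # Service times: Uniform[p1, p2]; inter-arrival times: Exponential(rate=p3).
    sts = (p2 - p1) * rng.rand(num_steps) + p1
    iats = -np.log(1.0 - rng.rand(num_steps)) / p3
    ats = np.cumsum(iats)  # arrival times

    idts = np.empty(num_steps)
    dts = np.empty(num_steps)
    idts[0] = sts[0] + ats[0]
    dts[0] = idts[0]
    for i in range(1, num_steps):
        # A job departs after its service finishes, which starts either at its
        # arrival or at the previous departure, whichever comes later.
        idts[i] = sts[i] + max(0.0, ats[i] - dts[i - 1])
        dts[i] = dts[i - 1] + idts[i]
    return idts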
Example #5
    def simulate(self, parameters):
        """
        Generates observations for the given batch of parameters.

        :param parameters: torch.Tensor
            Batch of parameters.
        :return: torch.Tensor
            Batch of observations.
        """

        # Run simulator in NumPy.
        if isinstance(parameters, torch.Tensor):
            parameters = utils.tensor2numpy(parameters)

        # If we have a single parameter then view it as a batch of one.
        if parameters.ndim == 1:
            return self.simulate(parameters[None, ...])

        num_simulations = parameters.shape[0]

        # Keep track of total simulations.
        self.num_total_simulations += num_simulations

        # Run simulator: points p on a noisy semicircular arc, shifted by a
        # parameter-dependent offset s.
        a = np.pi * np.random.rand(num_simulations) - np.pi / 2
        r = 0.01 * np.random.randn(num_simulations) + 0.1
        p = np.column_stack([r * np.cos(a) + 0.25, r * np.sin(a)])
        s = (1 / np.sqrt(2)) * np.column_stack(
            [
                -np.abs(parameters[:, 0] + parameters[:, 1]),
                (-parameters[:, 0] + parameters[:, 1]),
            ]
        )
        return torch.Tensor(p + s)
Example #6
def test_():
    task = "lotka-volterra"
    simulator, prior = simulators.get_simulator_and_prior(task)
    parameter_dim, observation_dim = (
        simulator.parameter_dim,
        simulator.observation_dim,
    )
    true_observation = simulator.get_ground_truth_observation()

    classifier = utils.get_classifier("mlp", parameter_dim, observation_dim)
    ratio_estimator = SRE(
        simulator=simulator,
        true_observation=true_observation,
        classifier=classifier,
        prior=prior,
        num_atoms=-1,
        mcmc_method="slice-np",
        retrain_from_scratch_each_round=False,
    )

    num_rounds, num_simulations_per_round = 10, 1000
    ratio_estimator.run_inference(
        num_rounds=num_rounds,
        num_simulations_per_round=num_simulations_per_round)

    samples = ratio_estimator.sample_posterior(num_samples=2500)
    samples = utils.tensor2numpy(samples)
    figure = utils.plot_hist_marginals(
        data=samples,
        ground_truth=utils.tensor2numpy(
            simulator.get_ground_truth_parameters()).reshape(-1),
        lims=[-4, 4],
    )
    figure.savefig(
        os.path.join(utils.get_output_root(), "corner-posterior-ratio.pdf"))

    mmds = ratio_estimator.summary["mmds"]
    if mmds:
        figure, axes = plt.subplots(1, 1)
        axes.plot(
            np.arange(0, num_rounds * num_simulations_per_round,
                      num_simulations_per_round),
            np.array(mmds),
            "-o",
            linewidth=2,
        )
        figure.savefig(os.path.join(utils.get_output_root(), "mmd-ratio.pdf"))
Example #7
    def log_prob(self, observations, parameters):
        """
        Likelihood is proportional to a product of self._num_observations_per_parameter 2D
        Gaussians and so log likelihood can be computed analytically.

        :param observations: torch.Tensor [batch_size, observation_dim]
            Batch of observations.
        :param parameters: torch.Tensor [batch_size, parameter_dim]
            Batch of parameters.
        :return: torch.Tensor [batch_size]
            Log likelihood log p(x | theta) for each item in the batch.
        """

        if isinstance(parameters, torch.Tensor):
            parameters = utils.tensor2numpy(parameters)

        if isinstance(observations, torch.Tensor):
            observations = utils.tensor2numpy(observations)

        if observations.ndim == 1 and parameters.ndim == 1:
            observations, parameters = (
                observations.reshape(1, -1),
                parameters.reshape(1, -1),
            )

        m0, m1, s0, s1, r = self._unpack_params(parameters)
        logdet = np.log(s0) + np.log(s1) + 0.5 * np.log(1.0 - r**2)

        observations = observations.reshape(
            [observations.shape[0], self._num_observations_per_parameter, 2])
        us = np.empty_like(observations)

        us[:, :, 0] = (observations[:, :, 0] - m0) / s0
        us[:, :, 1] = (observations[:, :, 1] - m1 -
                       s1 * r * us[:, :, 0]) / (s1 * np.sqrt(1.0 - r**2))
        us = us.reshape(
            [us.shape[0], 2 * self._num_observations_per_parameter])

        L = (np.sum(scipy.stats.norm.logpdf(us), axis=1) -
             self._num_observations_per_parameter * logdet[:, 0])

        return L
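
Since the likelihood is a product of independent 2-D Gaussians, the whitening computation above can be cross-checked directly against scipy; a standalone sketch (assumes the same parameterization: means m0, m1, standard deviations s0, s1, correlation r):

import numpy as np
import scipy.stats


def gaussian_2d_log_prob(observations, m0, m1, s0, s1, r):
    # observations: array of shape [num_observations_per_parameter, 2].
    mean = np.array([m0, m1])
    covariance = np.array([
        [s0 ** 2, r * s0 * s1],
        [r * s0 * s1, s1 ** 2],
    ])
    # Sum of per-observation log densities, as in the whitened computation above.
    return scipy.stats.multivariate_normal.logpdf(observations, mean, covariance).sum()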
Example #8
def _test():
    # prior = MG1Uniform(low=torch.zeros(3), high=torch.Tensor([10, 10, 1 / 3]))
    # uniform = distributions.Uniform(
    #     low=torch.zeros(3), high=torch.Tensor([10, 10, 1 / 3])
    # )
    # x = torch.Tensor([10, 20, 1 / 3]).reshape(1, -1)
    # print(uniform.log_prob(x))
    # print(prior.log_prob(x))
    d = LotkaVolterraOscillating()
    samples = d.sample((1000, ))
    utils.plot_hist_marginals(utils.tensor2numpy(samples), lims=[-6, 3])
    plt.show()
Example #9
    def simulate(self, parameters):
        """
        Generates observations for the given batch of parameters.

        :param parameters: torch.Tensor
            Batch of parameters.
        :return: torch.Tensor
            Batch of observations.
        """

        # Run simulator in NumPy.
        if isinstance(parameters, torch.Tensor):
            parameters = utils.tensor2numpy(parameters)

        # If we have a single parameter then view it as a batch of one.
        if parameters.ndim == 1:
            return self.simulate(parameters[np.newaxis, :])[0]

        num_simulations = parameters.shape[0]

        # Keep track of total simulations.
        self.num_total_simulations += num_simulations

        # Run simulator to generate self._num_observations_per_parameter
        # observations from a 2D Gaussian parameterized by the 5 given parameters.
        m0, m1, s0, s1, r = self._unpack_params(parameters)

        us = np.random.randn(num_simulations,
                             self._num_observations_per_parameter, 2)
        observations = np.empty_like(us)

        observations[:, :, 0] = s0 * us[:, :, 0] + m0
        observations[:, :, 1] = (
            s1 * (r * us[:, :, 0] + np.sqrt(1.0 - r**2) * us[:, :, 1]) + m1)

        # Flatten observations and normalize using precomputed mean and std.
        mean, std = self._get_observation_normalization_parameters()
        observations = torch.Tensor(
            observations.reshape(
                [num_simulations, 2 * self._num_observations_per_parameter]
            )
        )
        return (observations - mean.reshape(1, -1)) / std.reshape(1, -1)
Example #10
def test_():
    features = 3
    context_features = 5
    num_mixture_components = 5
    model = MixtureOfGaussiansMADE(
        features=features,
        hidden_features=32,
        context_features=context_features,
        num_mixture_components=num_mixture_components,
        num_blocks=2,
        use_residual_blocks=True,
        random_mask=False,
        activation=F.relu,
        dropout_probability=0,
        use_batch_norm=False,
        epsilon=1e-2,
        custom_initialization=True,
    )
    # Use a single context so that squeeze(0) below yields [num_samples, features].
    context = torch.randn(1, context_features)
    samples = model.sample(1000, context=context)
    utils.plot_hist_marginals(utils.tensor2numpy(samples.squeeze(0)),
                              lims=[-10, 10])

    plt.show()
Example #11
    def _summarize(self, round_):

        # Update summaries. MMD against ground-truth posterior samples is only
        # computed when the simulator provides such samples.
        try:
            mmd = utils.unbiased_mmd_squared(
                self._parameter_bank[-1],
                self._simulator.get_ground_truth_posterior_samples(
                    num_samples=1000),
            )
            self._summary["mmds"].append(mmd.item())
        except Exception:
            pass

        # Median |x - x0| for most recent round.
        median_observation_distance = torch.median(
            torch.sqrt(
                torch.sum(
                    (self._summary_net(self._observation_bank[-1]) -
                     self._summary_net(self._true_observation).reshape(1,
                                                                       -1))**2,
                    dim=-1,
                )))
        self._summary["median-observation-distances"].append(
            median_observation_distance.item())

        # KDE estimate of negative log prob true parameters using
        # parameters from most recent round.
        negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
            samples=self._parameter_bank[-1],
            query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
        )
        self._summary["negative-log-probs-true-parameters"].append(
            negative_log_prob_true_parameters.item())

        # Rejection sampling acceptance rate
        rejection_sampling_acceptance_rate = self._estimate_acceptance_rate()
        self._summary["rejection-sampling-acceptance-rates"].append(
            rejection_sampling_acceptance_rate)

        # Plot most recently sampled parameters.
        parameters = utils.tensor2numpy(self._parameter_bank[-1])
        figure = utils.plot_hist_marginals(
            data=parameters,
            ground_truth=utils.tensor2numpy(
                self._simulator.get_ground_truth_parameters()).reshape(-1),
            lims=self._simulator.parameter_plotting_limits,
        )

        # Write quantities using SummaryWriter.
        self._summary_writer.add_figure(tag="posterior-samples",
                                        figure=figure,
                                        global_step=round_ + 1)

        self._summary_writer.add_scalar(
            tag="epochs-trained",
            scalar_value=self._summary["epochs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="best-validation-log-prob",
            scalar_value=self._summary["best-validation-log-probs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="median-observation-distance",
            scalar_value=self._summary["median-observation-distances"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="negative-log-prob-true-parameters",
            scalar_value=self._summary["negative-log-probs-true-parameters"]
            [-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="rejection-sampling-acceptance-rate",
            scalar_value=self._summary["rejection-sampling-acceptance-rates"]
            [-1],
            global_step=round_ + 1,
        )

        if self._summary["mmds"]:
            self._summary_writer.add_scalar(
                tag="mmd",
                scalar_value=self._summary["mmds"][-1],
                global_step=round_ + 1,
            )

        self._summary_writer.flush()
Example #12
    def __init__(self,
                 simulator,
                 prior,
                 true_observation,
                 neural_posterior,
                 num_atoms=-1,
                 use_combined_loss=False,
                 train_with_mcmc=False,
                 mcmc_method="slice-np",
                 summary_net=None,
                 retrain_from_scratch_each_round=False,
                 discard_prior_samples=False,
                 summary_writer=None,
                 label=None,
                 device='cpu'):
        """
        :param simulator:
            Python object with 'simulate' method which takes a torch.Tensor
            of parameter values, and returns a simulation result for each parameter
            as a torch.Tensor.
        :param prior: Distribution
            Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor [observation_dim] or [1, observation_dim]
            True observation x0 for which to perform inference on the posterior p(theta | x0).
        :param neural_posterior: nets.Module
            Conditional density estimator q(theta | x) with 'log_prob' and 'sample' methods.
        :param num_atoms: int
            Number of atoms to use for classification.
            If -1, use all other parameters in minibatch.
        :param use_combined_loss: bool
            Whether to jointly train prior samples using maximum likelihood.
            Useful to prevent density leaking when using box uniform priors.
        :param train_with_mcmc: bool
            Whether to sample with MCMC instead of i.i.d. sampling at the end of each round.
        :param mcmc_method: str
            MCMC method to use if 'train_with_mcmc' is True.
            One of ['slice-np', 'hmc', 'nuts'].
        :param summary_net: nets.Module
            Optional network which may be used to produce feature vectors
            f(x) for high-dimensional observations.
        :param retrain_from_scratch_each_round: bool
            Whether to retrain the conditional density estimator for the posterior
            from scratch each round.
        :param discard_prior_samples: bool
            Whether to discard prior samples from round two onwards.
        :param summary_writer: SummaryWriter
            Optionally pass summary writer.
            If None, will create one internally.
        :param label: str
            Optional label identifying this run.
        :param device: str
            Device on which to run, e.g. 'cpu'.
        """

        self._simulator = simulator
        self._prior = prior
        self._true_observation = true_observation
        self._neural_posterior = neural_posterior
        self._device = device
        self._label = label

        assert isinstance(num_atoms,
                          int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms

        self._use_combined_loss = use_combined_loss

        # We may want to summarize high-dimensional observations.
        # This may be either a fixed or learned transformation.
        if summary_net is None:
            self._summary_net = nn.Identity()
        else:
            self._summary_net = summary_net
            self._summary_bank = []

        self._mcmc_method = mcmc_method
        self._train_with_mcmc = train_with_mcmc

        # HMC and NUTS from Pyro.
        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            neural_posterior, prior, self._summary_net(self._true_observation))

        # Axis-aligned slice sampling implementation in NumPy
        target_log_prob = lambda parameters: (
            self._neural_posterior.log_prob(
                inputs=torch.Tensor(parameters).reshape(1, -1),
                context=self._summary_net(self._true_observation).reshape(1, -1),
            ).item()
            if not np.isinf(
                self._prior.log_prob(torch.Tensor(parameters)).sum().item()
            )
            else -np.inf
        )
        self._neural_posterior.eval()
        self.posterior_sampler = SliceSampler(
            utils.tensor2numpy(self._prior.sample((1, ))).reshape(-1),
            lp_f=target_log_prob,
            thin=10,
        )
        self._neural_posterior.train()

        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        # If we're retraining from scratch each round,
        # keep a copy of the original untrained model for reinitialization.
        self._untrained_neural_posterior = deepcopy(neural_posterior)

        self._discard_prior_samples = discard_prior_samples

        # Need somewhere to store (parameter, observation) pairs from each round.
        self._parameter_bank, self._observation_bank, self._prior_masks = [], [], []

        self._model_bank = []

        self._total_num_generated_examples = 0

        # Each APT run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(utils.get_log_root(), "apt", simulator.name,
                                   utils.get_timestamp())
            self._summary_writer = SummaryWriter(log_dir)
        else:
            self._summary_writer = summary_writer

        # Each run also has a dictionary of summary statistics which are populated
        # over the course of training.
        self._summary = {
            "mmds": [],
            "median-observation-distances": [],
            "negative-log-probs-true-parameters": [],
            "neural-net-fit-times": [],
            "epochs": [],
            "best-validation-log-probs": [],
            "rejection-sampling-acceptance-rates": [],
        }
Example #13
    def simulate(self, parameters):
        # Run the wrapped simulator on NumPy parameters and summarize its output.
        parameters = utils.tensor2numpy(parameters)
        observations = self._summarizer.calc(self._simulator.sim(parameters))
        return torch.Tensor(observations)
Example #14
    def _summarize(self, round_):

        # Update summaries. MMD against ground-truth posterior samples is only
        # computed when the simulator provides such samples.
        try:
            mmd = utils.unbiased_mmd_squared(
                self._parameter_bank[-1],
                self._simulator.get_ground_truth_posterior_samples(
                    num_samples=1000),
            )
            self._summary["mmds"].append(mmd.item())
        except Exception:
            pass

        median_observation_distance = torch.median(
            torch.sqrt(
                torch.sum(
                    (self._observation_bank[-1] -
                     self._true_observation.reshape(1, -1))**2,
                    dim=-1,
                )))
        self._summary["median-observation-distances"].append(
            median_observation_distance.item())

        negative_log_prob_true_parameters = -utils.gaussian_kde_log_eval(
            samples=self._parameter_bank[-1],
            query=self._simulator.get_ground_truth_parameters().reshape(1, -1),
        )
        self._summary["negative-log-probs-true-parameters"].append(
            negative_log_prob_true_parameters.item())

        # Plot most recently sampled parameters in TensorBoard.
        parameters = utils.tensor2numpy(self._parameter_bank[-1])
        figure = utils.plot_hist_marginals(
            data=parameters,
            ground_truth=utils.tensor2numpy(
                self._simulator.get_ground_truth_parameters()).reshape(-1),
            lims=self._simulator.parameter_plotting_limits,
        )
        self._summary_writer.add_figure(tag="posterior-samples",
                                        figure=figure,
                                        global_step=round_ + 1)

        self._summary_writer.add_scalar(
            tag="epochs-trained",
            scalar_value=self._summary["epochs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="best-validation-log-prob",
            scalar_value=self._summary["best-validation-log-probs"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="median-observation-distance",
            scalar_value=self._summary["median-observation-distances"][-1],
            global_step=round_ + 1,
        )

        self._summary_writer.add_scalar(
            tag="negative-log-prob-true-parameters",
            scalar_value=self._summary["negative-log-probs-true-parameters"]
            [-1],
            global_step=round_ + 1,
        )

        if self._summary["mmds"]:
            self._summary_writer.add_scalar(
                tag="mmd",
                scalar_value=self._summary["mmds"][-1],
                global_step=round_ + 1,
            )

        self._summary_writer.flush()
Example #15
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        classifier,
        num_atoms=-1,
        mcmc_method="slice-np",
        summary_net=None,
        retrain_from_scratch_each_round=False,
        summary_writer=None,
    ):
        """
        :param simulator: Python object with 'simulate' method which takes a torch.Tensor
            of parameter values, and returns a simulation result for each parameter as a
            torch.Tensor.
        :param prior: Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor containing the observation x0 for which to
            perform inference on the posterior p(theta | x0).
        :param classifier: Binary classifier in the form of an nets.Module.
            Takes as input (x, theta) pairs and outputs pre-sigmoid activations.
        :param num_atoms: int
            Number of atoms to use for classification.
            If -1, use all other parameters in minibatch.
        :param mcmc_method: MCMC method to use for posterior sampling.
            One of ['slice-np', 'hmc', 'nuts'].
        :param summary_net: Optional network which may be used to produce feature vectors
            f(x) for high-dimensional observations.
        :param retrain_from_scratch_each_round: Whether to retrain the classifier
            from scratch each round.
        :param summary_writer: SummaryWriter
            Optionally pass summary writer. If None, will create one internally.
        """

        self._simulator = simulator
        self._true_observation = true_observation
        self._classifier = classifier
        self._prior = prior

        assert isinstance(num_atoms,
                          int), "Number of atoms must be an integer."
        self._num_atoms = num_atoms

        self._mcmc_method = mcmc_method

        # We may want to summarize high-dimensional observations.
        # This may be either a fixed or learned transformation.
        if summary_net is None:
            self._summary_net = nn.Identity()
        else:
            self._summary_net = summary_net

        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            classifier, prior, true_observation)

        # TODO: decide on Slice Sampling implementation
        target_log_prob = lambda parameters: (
            self._classifier(
                torch.cat(
                    (torch.Tensor(parameters), self._true_observation)
                ).reshape(1, -1)
            ).item()
            + self._prior.log_prob(torch.Tensor(parameters)).sum().item()
        )
        self._classifier.eval()
        self.posterior_sampler = SliceSampler(
            utils.tensor2numpy(self._prior.sample((1, ))).reshape(-1),
            lp_f=target_log_prob,
            thin=10,
        )
        self._classifier.train()

        self._retrain_from_scratch_each_round = retrain_from_scratch_each_round
        # If we're retraining from scratch each round,
        # keep a copy of the original untrained model for reinitialization.
        if retrain_from_scratch_each_round:
            self._untrained_classifier = deepcopy(classifier)
        else:
            self._untrained_classifier = None

        # Need somewhere to store (parameter, observation) pairs from each round.
        self._parameter_bank, self._observation_bank = [], []

        # Each SRE run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(utils.get_log_root(), "sre", simulator.name,
                                   utils.get_timestamp())
            self._summary_writer = SummaryWriter(log_dir)
        else:
            self._summary_writer = summary_writer

        # Each run also has a dictionary of summary statistics which are populated
        # over the course of training.
        self._summary = {
            "mmds": [],
            "median-observation-distances": [],
            "negative-log-probs-true-parameters": [],
            "neural-net-fit-times": [],
            "mcmc-times": [],
            "epochs": [],
            "best-validation-log-probs": [],
        }
Example #16
    def __init__(
        self,
        simulator,
        prior,
        true_observation,
        neural_likelihood,
        mcmc_method="slice-np",
        summary_writer=None,
    ):
        """

        :param simulator: Python object with 'simulate' method which takes a torch.Tensor
        of parameter values, and returns a simulation result for each parameter as a torch.Tensor.
        :param prior: Distribution object with 'log_prob' and 'sample' methods.
        :param true_observation: torch.Tensor containing the observation x0 for which to
        perform inference on the posterior p(theta | x0).
        :param neural_likelihood: Conditional density estimator q(x | theta) in the form of an
        nets.Module. Must have 'log_prob' and 'sample' methods.
        :param mcmc_method: MCMC method to use for posterior sampling. Must be one of
        ['slice', 'hmc', 'nuts'].
        """

        self._simulator = simulator
        self._prior = prior
        self._true_observation = true_observation
        self._neural_likelihood = neural_likelihood
        self._mcmc_method = mcmc_method

        # Defining the potential function as an object means Pyro's MCMC scheme
        # can pickle it to be used across multiple chains in parallel, even if
        # the potential function requires evaluating a neural likelihood as is the
        # case here.
        self._potential_function = NeuralPotentialFunction(
            neural_likelihood=self._neural_likelihood,
            prior=self._prior,
            true_observation=self._true_observation,
        )

        # TODO: decide on Slice Sampling implementation
        target_log_prob = (
            lambda parameters: self._neural_likelihood.log_prob(
                inputs=self._true_observation.reshape(1, -1),
                context=torch.Tensor(parameters).reshape(1, -1),
            ).item()
            + self._prior.log_prob(torch.Tensor(parameters)).sum().item()
        )
        self._neural_likelihood.eval()
        self.posterior_sampler = SliceSampler(
            utils.tensor2numpy(self._prior.sample((1,))).reshape(-1),
            lp_f=target_log_prob,
            thin=10,
        )
        self._neural_likelihood.train()

        # Need somewhere to store (parameter, observation) pairs from each round.
        self._parameter_bank, self._observation_bank = [], []

        # Each SNL run has an associated log directory for TensorBoard output.
        if summary_writer is None:
            log_dir = os.path.join(
                utils.get_log_root(), "snl", simulator.name, utils.get_timestamp()
            )
            self._summary_writer = SummaryWriter(log_dir)
        else:
            self._summary_writer = summary_writer

        # Each run also has a dictionary of summary statistics which are populated
        # over the course of training.
        self._summary = {
            "mmds": [],
            "median-observation-distances": [],
            "negative-log-probs-true-parameters": [],
            "neural-net-fit-times": [],
            "mcmc-times": [],
            "epochs": [],
            "best-validation-log-probs": [],
        }
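
The NeuralPotentialFunction referenced in the SNL, SRE, and APT constructors above is not included in these examples. Mirroring the slice-sampling target defined a few lines earlier, a minimal sketch for the SNL case might be (hypothetical; assumes Pyro's potential_fn convention of a dict of parameter tensors in, negative log density out):

class NeuralPotentialFunction:
    """Sketch of a potential wrapping a neural likelihood (assumed, not the repo's code)."""

    def __init__(self, neural_likelihood, prior, true_observation):
        self.neural_likelihood = neural_likelihood
        self.prior = prior
        self.true_observation = true_observation

    def __call__(self, parameters_dict):
        parameters = next(iter(parameters_dict.values()))
        # Negative unnormalized log posterior: -(log q(x0 | theta) + log p(theta)).
        log_likelihood = self.neural_likelihood.log_prob(
            inputs=self.true_observation.reshape(1, -1),
            context=parameters.reshape(1, -1),
        )
        return -(log_likelihood + self.prior.log_prob(parameters).sum())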