Exemplo n.º 1
0
def tb_plot_posterior(writer: SummaryWriter,
                      samples: torch.Tensor,
                      tag: str = "posterior"):
    """Draw a pairplot of ``samples`` and log the figure through ``writer``.

    Args:
        writer: Summary writer that receives the figure via ``add_figure``.
        samples: Samples to plot. A ``torch.Tensor`` (including subclasses)
            is detached and copied to CPU before conversion to NumPy; any
            other object is passed to ``pairplot`` after ``.squeeze()``.
        tag: Tag under which the figure is logged.
    """
    # isinstance (rather than an exact type comparison) also converts Tensor
    # subclasses such as nn.Parameter instead of passing them through as-is.
    if isinstance(samples, torch.Tensor):
        # Tensor.numpy() raises for tensors that require grad or are not on
        # the CPU, so detach and move first.
        samples = samples.detach().cpu().numpy()
    fig, _ = pairplot(samples.squeeze(), points=[])
    # close=True closes the matplotlib figure after it has been logged.
    writer.add_figure(tag, fig, close=True)
Exemplo n.º 2
0
def test_rejection_with_proposal(
    plt,
    task_name="gaussian_linear_uniform",
    num_observation=1,
    num_samples=10000,
    prior_weight=0.1,
    multiplier_M=1.2,
    batch_size=10000,
    num_batches_without_new_max=1000,
):
    """Check rejection sampling with a fitted proposal against reference samples.

    A proposal (``density_estimator="flow"``, ``flow_model="nsf"``) is built
    from the task's reference posterior samples for ``num_observation``.
    ``run`` then draws ``num_samples`` samples with it. Finally, ``c2st``
    between those draws and the first ``num_samples`` reference samples must
    lie within 0.01 of 0.5.

    ``plt`` is not referenced in the body; presumably it is a plotting fixture
    that captures the pairplot figure — TODO confirm.
    """
    task = sbibm.get_task(task_name)
    reference = task.get_reference_posterior_samples(
        num_observation=num_observation)

    proposal = get_proposal(
        task=task,
        samples=reference,
        prior_weight=prior_weight,
        bounded=False,
        density_estimator="flow",
        flow_model="nsf",
    )

    drawn = run(
        task=task,
        num_observation=num_observation,
        num_samples=num_samples,
        batch_size=batch_size,
        num_batches_without_new_max=num_batches_without_new_max,
        multiplier_M=multiplier_M,
        proposal_dist=proposal,
    )

    # Only the first 1000 rows of each sample set are plotted.
    n_plot = 1000
    drawn_np = drawn.numpy()
    reference_np = reference.numpy()
    pairplot([drawn_np[:n_plot, :], reference_np[:n_plot, :]])

    score = c2st(drawn, reference[:num_samples, :])
    assert torch.abs(score - 0.5) < 0.01
Exemplo n.º 3
0
Arquivo: base.py Projeto: ulamaca/sbi
    def _summarize(
        self,
        round_: int,
        x_o: Union[Tensor, None],
        theta_bank: Tensor,
        x_bank: Tensor,
        posterior_samples_acceptance_rate: Optional[Tensor] = None,
    ) -> None:
        """Write one round's statistics to the summary writer.

        The following entries are logged, all at global step ``round_ + 1``:

        - The median Euclidean distance between the rows of ``x_bank`` and
          ``x_o``, only when ``x_o`` is given.
        - The rejection-sampling acceptance rate, only when one is passed.
        - A pairplot figure of ``theta_bank``.
        - The latest ``epochs`` and ``best_validation_log_probs`` entries of
          ``self._summary``. These lists are filled during training.

        The distance and acceptance-rate values are also appended to their
        lists in ``self._summary``. The writer is flushed at the end.
        """
        # NB. Only a subset of the logging from `GH:conormdurkan/lfi`. Logging
        # against ground-truth parameters and samples was removed because of
        # API changes.
        step = round_ + 1
        summary = self._summary
        writer = self._summary_writer

        if x_o is not None:
            # Distance from each row of x_bank to the flattened observation,
            # reduced to its median.
            residuals = x_bank - x_o.reshape(1, -1)
            distances = torch.sqrt(torch.sum(residuals**2, dim=-1))
            summary["median_observation_distances"].append(
                torch.median(distances).item())
            writer.add_scalar(
                tag="median_observation_distance",
                scalar_value=summary["median_observation_distances"][-1],
                global_step=step,
            )

        # The acceptance rate is only supplied by SNPE.
        if posterior_samples_acceptance_rate is not None:
            summary["rejection_sampling_acceptance_rates"].append(
                posterior_samples_acceptance_rate.item())
            writer.add_scalar(
                tag="rejection_sampling_acceptance_rate",
                scalar_value=summary["rejection_sampling_acceptance_rates"][-1],
                global_step=step,
            )

        # TODO: pass more plotting kwargs, e.g. prior limits.
        figure, _ = pairplot(theta_bank.to("cpu").numpy())
        writer.add_figure(tag="posterior_samples",
                          figure=figure,
                          global_step=step)

        # Most recent training statistics recorded during training.
        for scalar_tag, summary_key in (
            ("epochs_trained", "epochs"),
            ("best_validation_log_prob", "best_validation_log_probs"),
        ):
            writer.add_scalar(
                tag=scalar_tag,
                scalar_value=summary[summary_key][-1],
                global_step=step,
            )

        writer.flush()
Exemplo n.º 4
0
    reference_samples = task.get_reference_posterior_samples(
        num_observation=num_observation)

    proposal_dist = get_proposal(
        task=task,
        samples=reference_samples,
        prior_weight=prior_weight,
        bounded=False,
        density_estimator="flow",
        flow_model="nsf",
    )

    samples = run(
        task=task,
        num_observation=num_observation,
        num_simulations=num_simulations,
        num_samples=num_samples,
        batch_size=batch_size,
        proposal_dist=proposal_dist,
    )

    num_samples_plotting = 1000
    pairplot([
        samples.numpy()[:num_samples_plotting, :],
        reference_samples.numpy()[:num_samples_plotting, :],
    ])

    acc = c2st(samples, reference_samples[:num_samples, :])

    assert torch.abs(acc - 0.5) < 0.01
Exemplo n.º 5
0
def plot_blobs_different(num_samples: int = 500):
    """Pairplot the two sample sets from ``sample_blobs_different`` and show them.

    Args:
        num_samples: Passed as the first argument to ``sample_blobs_different``.
            The default of 500 is the value that was previously hard-coded.
    """
    X, Y = sample_blobs_different(num_samples)
    pairplot([X.numpy(), Y.numpy()])
    plt.show()
Exemplo n.º 6
0
def plot_blobs_same(num_samples: int = 500, sep=10):
    """Pairplot the two sample sets from ``sample_blobs_same`` and show them.

    Args:
        num_samples: Passed as the first argument to ``sample_blobs_same``.
            The default of 500 is the value that was previously hard-coded.
        sep: Forwarded as ``sep=`` to ``sample_blobs_same``. The default of 10
            is the value that was previously hard-coded.
    """
    X, Y = sample_blobs_same(num_samples, sep=sep)
    pairplot([X.numpy(), Y.numpy()])
    plt.show()