Example No. 1
import warnings
from typing import Tuple

from torch import float32
from torch.distributions import Distribution, Uniform

from sbi.utils import BoxUniform

# `check_prior_batch_behavior`, `check_prior_batch_dims`,
# `check_prior_return_type`, and `PytorchReturnTypeWrapper` are helper
# utilities defined alongside this function in sbi's user-input checks.


def process_pytorch_prior(
        prior: Distribution) -> Tuple[Distribution, int, bool]:
    """Return PyTorch prior adapted to the requirements for sbi.

    Args:
        prior: PyTorch distribution prior provided by the user.

    Raises:
        ValueError: If prior is defined over an unwrapped scalar variable.

    Returns:
        prior: PyTorch distribution prior.
        theta_numel: Number of parameters, i.e., the number of elements in a
            single sample from the prior.
        prior_returns_numpy: Always False for PyTorch priors.
    """

    # Turn off validation of input arguments to allow `log_prob()` on samples outside
    # of the support.
    prior.set_default_validate_args(False)

    # Reject unwrapped scalar priors, e.g., `Uniform(0.0, 1.0)`, whose samples
    # have `ndim == 0`.
    if prior.sample().ndim == 0:
        raise ValueError(
            "Detected scalar prior. Please make sure to pass a PyTorch prior with "
            "`batch_shape=torch.Size([1])` or `event_shape=torch.Size([1])`.")
    # Cast 1D Uniform to BoxUniform to avoid shape error in mdn log prob.
    elif isinstance(prior, Uniform) and prior.batch_shape.numel() == 1:
        prior = BoxUniform(low=prior.low, high=prior.high)
        warnings.warn(
            "Casting 1D Uniform prior to BoxUniform to match sbi batch requirements."
        )

    check_prior_batch_behavior(prior)
    check_prior_batch_dims(prior)

    if prior.sample().dtype != float32:
        prior = PytorchReturnTypeWrapper(prior,
                                         return_type=float32,
                                         validate_args=False)

    # Sanity check: raises if the prior still returns non-float32 samples.
    check_prior_return_type(prior)

    theta_numel = prior.sample().numel()

    return prior, theta_numel, False
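
A minimal usage sketch (assuming `sbi` is installed; `process_pytorch_prior` is the function above): a multidimensional `BoxUniform` passes through unchanged, while an unwrapped scalar `Uniform` triggers the `ValueError` from the scalar check.

import torch
from torch.distributions import Uniform

from sbi.utils import BoxUniform

# A 2D box prior satisfies the batch/event shape requirements.
prior, theta_numel, returns_numpy = process_pytorch_prior(
    BoxUniform(low=torch.zeros(2), high=torch.ones(2)))
print(theta_numel, returns_numpy)  # 2 False

# An unwrapped scalar prior is rejected.
try:
    process_pytorch_prior(Uniform(0.0, 1.0))
except ValueError as err:
    print(err)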
Example No. 2
def setUp(self):
    super().setUp()
    self.scalar_sample = 1
    self.tensor_sample_1 = torch.ones(3, 2)
    self.tensor_sample_2 = torch.ones(3, 2, 3)
    Distribution.set_default_validate_args(True)
Example No. 3
def tearDown(self):
    super().tearDown()
    Distribution.set_default_validate_args(False)
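
For context, a short sketch of what this flag toggles: with validation enabled, `log_prob()` raises on values outside the support; with it disabled, the computation goes through (yielding `-inf` here). The flag only affects distributions constructed after it is set.

import torch
from torch.distributions import Distribution, Uniform

Distribution.set_default_validate_args(True)
try:
    Uniform(0.0, 1.0).log_prob(torch.tensor(2.0))
except ValueError as err:
    print(err)  # 2.0 lies outside the support

Distribution.set_default_validate_args(False)
print(Uniform(0.0, 1.0).log_prob(torch.tensor(2.0)))  # tensor(-inf)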
Example No. 4
#!/usr/bin/env python

import math
import numpy as np
import os
import psutil
import random
import torch
import torch.nn as nn
import torch.utils.data as data

from functools import cached_property
from torch.distributions import Distribution
from typing import Any, Tuple, List

Distribution.set_default_validate_args(False)


class Simulator(nn.Module):
    r"""Abstract Simulator"""
    @cached_property
    def prior(self) -> Distribution:
        r""" p(theta) """

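        # `...` (Ellipsis) indexes every parameter, i.e., an all-True mask.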
        return self.masked_prior(...)

    def masked_prior(self, mask: torch.BoolTensor) -> Distribution:
        r""" p(theta_a) """

        raise NotImplementedError()
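
A hypothetical concrete subclass, sketched to show how `prior` and `masked_prior` fit together: the cached `prior` property passes `...` (Ellipsis), which indexes every parameter, so subclasses only implement the general masked case. The class name and the 3-parameter Gaussian prior are invented for illustration.

from torch.distributions import Independent, Normal


class GaussianSimulator(Simulator):
    r"""Toy simulator with a standard normal prior over 3 parameters."""

    def __init__(self):
        super().__init__()
        self.register_buffer('loc', torch.zeros(3))
        self.register_buffer('scale', torch.ones(3))

    def masked_prior(self, mask) -> Distribution:
        # With a boolean mask this is p(theta_a); with `...` it is p(theta).
        return Independent(Normal(self.loc[mask], self.scale[mask]), 1)


sim = GaussianSimulator()
theta = sim.prior.sample()  # all 3 parameters
theta_a = sim.masked_prior(torch.tensor([True, False, True])).sample()  # 2 of them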