Code Example #1
    def append_simulations(self, theta: Tensor,
                           x: Tensor) -> "RestrictionEstimator":
        r"""
        Store parameters and simulation outputs to use them for training later.
        Data are stored as entries in lists for each type of variable (parameter/data).

        Args:
            theta: Parameter sets.
            x: Simulation outputs.

        Returns:
            `RestrictionEstimator` object (returned so that this function is chainable).
        """

        validate_theta_and_x(theta, x)

        if self._valid_or_invalid_criterion == "nan":
            label, _, _ = handle_invalid_x(x)
        else:
            label = self._valid_or_invalid_criterion(x)

        label = label.long()

        if self._data_round_index:
            self._data_round_index.append(self._data_round_index[-1] + 1)
        else:
            self._data_round_index.append(0)

        self._theta_roundwise.append(theta)
        self._x_roundwise.append(x)
        self._label_roundwise.append(label)

        return self
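
The docstring above notes that `append_simulations` is chainable. The following is a minimal usage sketch, not taken from the repository: the `BoxUniform` prior, the toy `simulator`, and the `RestrictionEstimator(prior=prior)` construction are assumptions based on sbi's documented interface.

import torch
from sbi.utils import BoxUniform, RestrictionEstimator

# Assumed two-dimensional uniform prior and a toy simulator that produces NaNs
# for part of the parameter space, so that the default "nan" criterion marks
# those simulations as invalid.
prior = BoxUniform(low=-torch.ones(2), high=torch.ones(2))

def simulator(theta):
    x = theta + torch.randn_like(theta)
    x[theta[:, 0] < 0] = float("nan")
    return x

theta = prior.sample((100,))
x = simulator(theta)

restriction_estimator = RestrictionEstimator(prior=prior)
# `append_simulations` returns the estimator itself, so it can be chained with `train`.
restriction_estimator.append_simulations(theta, x).train()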
Code Example #2
File: inference_on_device_test.py  Project: bkmi/sbi
def test_validate_theta_and_x_shapes(
    shape_x: Tuple[int], shape_theta: Tuple[int]
) -> None:
    x = torch.empty(shape_x)
    theta = torch.empty(shape_theta)

    validate_theta_and_x(theta, x, training_device="cpu")
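
As excerpted, `shape_x` and `shape_theta` are supplied by a `pytest.mark.parametrize` decorator that is not shown. The shapes below are hypothetical placeholders, included only to illustrate how such a parametrization might look; they are not the values used in bkmi/sbi.

from typing import Tuple

import pytest

# Hypothetical parametrization -- not the decorator used in bkmi/sbi.
@pytest.mark.parametrize("shape_x", ((10, 3),))
@pytest.mark.parametrize("shape_theta", ((10, 2),))
def test_validate_theta_and_x_shapes(
    shape_x: Tuple[int], shape_theta: Tuple[int]
) -> None:
    ...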
Code Example #3
def test_validate_theta_and_x_device(device):

    # Skip GPU test if not available.
    if device == "cuda:0" and not torch.cuda.is_available():
        pass
    else:
        theta = torch.ones((2, 2), dtype=torch.float32).to(device)
        x = torch.zeros((2, 10), dtype=torch.float32).to(device)

        assert isinstance(
            theta, torch.Tensor
        ), f"{device} based torch.tensor is not an instance of torch.Tensor"
        assert theta.dtype == torch.float32, (
            f"{device} based torch.tensor(dtype=torch.float32) yields unexpected dtype "
            f"{theta.dtype}."
        )
        if device == "cuda:0":
            assert not isinstance(
                theta, torch.FloatTensor
            ), f"""{device} based torch.tensor(dtype=torch.float32) must not be 
            FloatTensor."""
        else:
            assert isinstance(
                theta, torch.FloatTensor
            ), f"{device} based torch.tensor(dtype=torch.float32) must be FloatTensor."
        validate_theta_and_x(theta, x)

        with pytest.raises(AssertionError) as _:
            validate_theta_and_x(theta, x.to(torch.float64))

        plain_ft = torch.FloatTensor((32, 8))
        assert (plain_ft.dtype == torch.float32
                ), "FloatTensor does not expose float32 dtype."
Code Example #4
File: inference_on_device_test.py  Project: bkmi/sbi
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
    theta = torch.empty((1, 1)).to(data_device)
    x = torch.empty((1, 1)).to(data_device)

    if training_device != data_device:
        with pytest.warns(UserWarning):
            theta, x = validate_theta_and_x(theta, x, training_device=training_device)
    else:
        theta, x = validate_theta_and_x(theta, x, training_device=training_device)

    assert str(theta.device) == training_device, (
        f"Data should have its device converted from '{data_device}' "
        f"to training_device '{training_device}'."
    )
Code Example #5
def test_validate_theta_and_x_gpu():

    gpu_if_present = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    theta = torch.ones((32, 8), dtype=torch.float32).to(gpu_if_present)
    x = torch.zeros((32, 100), dtype=torch.float32).to(gpu_if_present)

    assert isinstance(
        theta, torch.Tensor
    ), "gpu based torch.tensor is not an instance of torch.Tensor"
    assert (
        theta.dtype == torch.float32
    ), f"gpu based torch.tensor(dtype=torch.float32) yields unexpected dtype {theta.dtype}."
    assert not isinstance(
        theta, torch.FloatTensor
    ), "gpu based torch.tensor(dtype=torch.float32) is FloatTensor, even though it shouldn't be."

    validate_theta_and_x(theta, x)

    with pytest.raises(AssertionError) as exc:
        validate_theta_and_x(theta, x.to(torch.float64))
Code Example #6
def test_validate_theta_and_x_cpu():

    cpu_device = torch.device("cpu")

    theta = torch.ones((32, 8), dtype=torch.float32).to(cpu_device)
    x = torch.zeros((32, 100), dtype=torch.float32).to(cpu_device)
    # Legacy typed constructor; passing a tuple builds the 1-D tensor [32., 8.], not a (32, 8) tensor.
    plain_ft = torch.FloatTensor((32, 8))

    assert isinstance(
        theta, torch.Tensor
    ), "cpu based torch.tensor is not an instance of torch.Tensor"
    assert (
        theta.dtype == torch.float32
    ), f"cpu based torch.tensor(dtype=torch.float32) yields unexpected dtype {theta.dtype}."
    assert isinstance(
        theta, torch.FloatTensor
    ), "cpu based torch.tensor(dtype=torch.float32) is no FloatTensor."
    assert plain_ft.dtype == torch.float32, "FloatTensor does not expose float32 dtype."

    # test on cpu
    validate_theta_and_x(theta, x)

    with pytest.raises(AssertionError) as exc:
        validate_theta_and_x(theta, x.to(torch.float64))
Code Example #7
File: inference_on_device_test.py  Project: bkmi/sbi
def test_validate_theta_and_x_type() -> None:
    x = torch.empty((1, 1))
    theta = torch.empty((1, 1), dtype=int)

    with pytest.raises(Exception):
        validate_theta_and_x(theta, x, training_device="cpu")
Code Example #8
File: inference_on_device_test.py  Project: bkmi/sbi
def test_validate_theta_and_x_tensor() -> None:
    x = torch.empty((1, 1))
    theta = torch.ones((1, 1)).tolist()

    with pytest.raises(Exception):
        validate_theta_and_x(theta, x, training_device="cpu")
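
Taken together, the tests above exercise four properties of `validate_theta_and_x`: both arguments must be tensors, they must be float32, their batch dimensions must match, and the data is moved to `training_device` with a `UserWarning` if it lives elsewhere. The sketch below re-implements only that observed behaviour; it is an illustration, not the actual sbi implementation, and the function name and warning text are made up.

import warnings
from typing import Tuple

import torch
from torch import Tensor


def validate_theta_and_x_sketch(
    theta: Tensor, x: Tensor, training_device: str = "cpu"
) -> Tuple[Tensor, Tensor]:
    # Both inputs must be tensors (Code Example #8).
    assert isinstance(theta, Tensor) and isinstance(x, Tensor), "theta and x must be torch.Tensors."
    # Both must be float32 (Code Examples #3, #5, #6, #7).
    assert theta.dtype == torch.float32, "theta must be float32."
    assert x.dtype == torch.float32, "x must be float32."
    # Batch dimensions must match (the property Code Example #2 appears to target).
    assert theta.shape[0] == x.shape[0], "theta and x must have the same number of rows."
    # Move data to the training device, warning when a transfer is needed (Code Example #4).
    if str(theta.device) != training_device:
        warnings.warn(
            f"Data is on '{theta.device}' but training_device is '{training_device}'; "
            "moving the data.",
            UserWarning,
        )
    return theta.to(training_device), x.to(training_device)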