def append_simulations(self, theta: Tensor, x: Tensor) -> "RestrictionEstimator":
    r"""
    Store parameters and simulation outputs to use them for training later.

    Data are stored as entries in lists for each type of variable
    (parameter/data/label): every call appends one round of `theta`, `x` and
    the labels produced by the valid/invalid criterion, and records a new round
    index.

    Args:
        theta: Parameter sets.
        x: Simulation outputs.

    Returns:
        `RestrictionEstimator` object (returned so that this function is chainable).
    """
    validate_theta_and_x(theta, x)

    # When the criterion is the string "nan", labels come from the first value
    # returned by `handle_invalid_x`; otherwise the criterion is called on `x`.
    criterion = self._valid_or_invalid_criterion
    if criterion == "nan":
        labels, _, _ = handle_invalid_x(x)
    else:
        labels = criterion(x)
    labels = labels.long()

    # Round indices start at 0 and increase by one per call.
    round_indices = self._data_round_index
    round_indices.append(round_indices[-1] + 1 if round_indices else 0)

    self._theta_roundwise.append(theta)
    self._x_roundwise.append(x)
    self._label_roundwise.append(labels)
    return self
def test_validate_theta_and_x_shapes(
    shape_x: Tuple[int], shape_theta: Tuple[int]
) -> None:
    """Run validation on uninitialized tensors of the parametrized shapes.

    The test fails only if `validate_theta_and_x` raises for these shapes.
    """
    # Tensor contents are irrelevant here; only the shapes are exercised.
    theta = torch.empty(shape_theta)
    x = torch.empty(shape_x)
    validate_theta_and_x(theta, x, training_device="cpu")
def test_validate_theta_and_x_device(device):
    """Check float32 tensors on `device` and run `validate_theta_and_x` on them.

    Asserts the tensor type/dtype invariants for the device, checks that
    validation accepts matching float32 inputs, and expects an AssertionError
    when `x` is cast to float64.

    NOTE(review): another function named `test_validate_theta_and_x_device`
    (taking `training_device, data_device`) also appears in this module; if both
    live in the same file, the later one shadows this one and this test never
    runs. Confirm and rename one of them.
    """
    if device == "cuda:0" and not torch.cuda.is_available():
        # Report the test as skipped instead of silently passing.
        pytest.skip("CUDA is not available.")

    theta = torch.ones((2, 2), dtype=torch.float32).to(device)
    x = torch.zeros((2, 10), dtype=torch.float32).to(device)

    assert isinstance(
        theta, torch.Tensor
    ), f"{device} based torch.tensor is not an instance of torch.Tensor"
    assert theta.dtype == torch.float32, (
        # Trailing space keeps the dtype from being glued to the word "dtype".
        f"{device} based torch.tensor(dtype=torch.float32) yields unexpected dtype "
        f"{theta.dtype}."
    )
    if device == "cuda:0":
        assert not isinstance(
            theta, torch.FloatTensor
        ), f"""{device} based torch.tensor(dtype=torch.float32) must not be FloatTensor."""
    else:
        assert isinstance(
            theta, torch.FloatTensor
        ), f"{device} based torch.tensor(dtype=torch.float32) must be FloatTensor."

    validate_theta_and_x(theta, x)
    with pytest.raises(AssertionError):
        validate_theta_and_x(theta, x.to(torch.float64))

    # A tuple passed to the legacy FloatTensor constructor is treated as data
    # (a 2-element tensor), not as a shape; only its dtype is checked here.
    plain_ft = torch.FloatTensor((32, 8))
    assert plain_ft.dtype == torch.float32, "FloatTensor does not expose float32 dtype."
def test_validate_theta_and_x_device(training_device: str, data_device: str) -> None:
    """Check that `validate_theta_and_x` moves theta and x to `training_device`.

    A UserWarning is expected whenever the data starts on a device different
    from `training_device`.
    """
    theta = torch.empty((1, 1)).to(data_device)
    x = torch.empty((1, 1)).to(data_device)

    if training_device != data_device:
        with pytest.warns(UserWarning):
            theta, x = validate_theta_and_x(theta, x, training_device=training_device)
    else:
        theta, x = validate_theta_and_x(theta, x, training_device=training_device)

    # Both returned tensors must be converted, not just theta.
    for name, tensor in (("theta", theta), ("x", x)):
        assert str(tensor.device) == training_device, (
            f"Data ({name}) should have its device converted from '{data_device}' "
            f"to training_device '{training_device}'."
        )
def test_validate_theta_and_x_gpu():
    """Check float32 CUDA tensors and run `validate_theta_and_x` on them.

    Skipped when CUDA is unavailable. On CPU a float32 tensor *is* a
    `torch.FloatTensor`, so the `not isinstance(..., torch.FloatTensor)` check
    below cannot hold there, and a CPU fallback would always fail.
    """
    if not torch.cuda.is_available():
        pytest.skip("CUDA is not available.")
    gpu = torch.device("cuda:0")

    theta = torch.ones((32, 8), dtype=torch.float32).to(gpu)
    x = torch.zeros((32, 100), dtype=torch.float32).to(gpu)

    assert isinstance(
        theta, torch.Tensor
    ), "gpu based torch.tensor is not an instance of torch.Tensor"
    assert (
        theta.dtype == torch.float32
    ), f"gpu based torch.tensor(dtype=torch.float32) yields unexpected dtype {theta.dtype}."
    assert not isinstance(
        theta, torch.FloatTensor
    ), "gpu based torch.tensor(dtype=torch.float32) is FloatTensor, even though it shouldn't be."

    validate_theta_and_x(theta, x)
    # Mismatched dtypes (float32 theta vs. float64 x) must be rejected.
    with pytest.raises(AssertionError):
        validate_theta_and_x(theta, x.to(torch.float64))
def test_validate_theta_and_x_cpu():
    """Check float32 CPU tensors and run `validate_theta_and_x` on them.

    Float32 CPU tensors must be both `torch.Tensor` and `torch.FloatTensor`
    instances. Validation must accept matching float32 inputs and raise an
    AssertionError when `x` is cast to float64.
    """
    cpu = torch.device("cpu")
    params = torch.ones((32, 8), dtype=torch.float32).to(cpu)
    observations = torch.zeros((32, 100), dtype=torch.float32).to(cpu)
    # The tuple is treated as data by the legacy constructor (a 2-element
    # tensor); only its dtype is inspected below.
    legacy_float = torch.FloatTensor((32, 8))

    assert isinstance(
        params, torch.Tensor
    ), "cpu based torch.tensor is not an instance of torch.Tensor"
    assert (
        params.dtype == torch.float32
    ), f"cpu based torch.tensor(dtype=torch.float32) yields unexpected dtype {params.dtype}."
    assert isinstance(
        params, torch.FloatTensor
    ), "cpu based torch.tensor(dtype=torch.float32) is no FloatTensor."
    assert legacy_float.dtype == torch.float32, "FloatTensor does not expose float32 dtype."

    validate_theta_and_x(params, observations)
    with pytest.raises(AssertionError):
        validate_theta_and_x(params, observations.to(torch.float64))
def test_validate_theta_and_x_type() -> None:
    """Validation must raise when the parameters have an integer dtype."""
    # `dtype=torch.int64` is the dtype torch assigns to the Python `int` type.
    int_theta = torch.empty((1, 1), dtype=torch.int64)
    float_x = torch.empty((1, 1))
    with pytest.raises(Exception):
        validate_theta_and_x(int_theta, float_x, training_device="cpu")
def test_validate_theta_and_x_tensor() -> None:
    """Validation must raise when the parameters are a plain Python list."""
    # Equal to torch.ones((1, 1)).tolist(): a nested list, not a tensor.
    theta_as_list = [[1.0]]
    x = torch.empty((1, 1))
    with pytest.raises(Exception):
        validate_theta_and_x(theta_as_list, x, training_device="cpu")