def test_truncated_at_zero(self):
    """Rescale shape parameters describing a truncation at zero.

    The lower bound is 0 and the upper bound is unbounded (``np.inf``).
    After rescaling with this test case's location and scale parameters,
    the lower bounds become finite per-point values and the upper bounds
    remain infinite.
    """
    bounds = [0, np.inf]
    plugin = Plugin(distribution="truncnorm", shape_parameters=bounds)
    plugin._rescale_shape_parameters(self.location_parameter, self.scale_parameter)

    lower_expected = np.array([1.0, 0, -0.5])
    # An infinite bound stays infinite whatever the location/scale values.
    upper_expected = np.full(3, np.inf)
    self.assertArrayAlmostEqual(
        plugin.shape_parameters, [lower_expected, upper_expected]
    )
def test_discrete_shape_parameters(self):
    """Rescale finite lower and upper truncation bounds.

    Bounds of -4 and 6 are rescaled with this test case's location and
    scale parameters. The result is checked to approximate per-point
    precision.
    """
    plugin = Plugin(distribution="truncnorm", shape_parameters=[-4, 6])
    plugin._rescale_shape_parameters(self.location_parameter, self.scale_parameter)

    expected_lower = np.array([-3, -2.666667, -2.5])
    expected_upper = np.array([7, 4, 2.5])
    self.assertArrayAlmostEqual(
        plugin.shape_parameters, [expected_lower, expected_upper]
    )
def test_alternative_distribution(self):
    """Leave shape parameters untouched for a non-truncated-normal distribution.

    The plain normal distribution ("norm") is requested, so calling the
    rescaling method should not alter the supplied shape parameters.
    """
    original_bounds = [0, np.inf]
    plugin = Plugin(distribution="norm", shape_parameters=original_bounds)
    plugin._rescale_shape_parameters(self.location_parameter, self.scale_parameter)
    # Compare against the very list that was passed in.
    self.assertArrayEqual(plugin.shape_parameters, original_bounds)
def test_no_shape_parameters_exception(self):
    """Raise ValueError when a truncated normal has no shape parameters.

    The plugin is built for "truncnorm" without ``shape_parameters``.
    Rescaling must then fail with a message that begins with the
    truncated-normal explanation.
    """
    plugin = Plugin(distribution="truncnorm")
    with self.assertRaisesRegex(ValueError, "For the truncated normal distribution"):
        plugin._rescale_shape_parameters(self.location_parameter, self.scale_parameter)