def test_default_shape_from_params():
    """Exercise ``default_shape_from_params`` across its call forms.

    Covers:
    - rejection of ``ndim_supp=0``;
    - inferring the shape from a reference parameter chosen by ``rep_param_idx``;
    - supplying shape information through ``param_shapes``;
    - the error for a reference parameter with too few dimensions;
    - a two-dimensional support (``ndim_supp=2``).
    """
    # ``match`` is a regular expression searched against the message. The
    # previous pattern "^ndim_supp*" used ``*`` like a glob wildcard, but in a
    # regex it only repeats the final "p". Anchor on the intended prefix instead.
    with raises(ValueError, match=r"^ndim_supp"):
        default_shape_from_params(0, (np.array([1, 2]), 0))

    # Reference parameter at index 0 is a length-2 vector -> shape (2,).
    res = default_shape_from_params(
        1, (np.array([1, 2]), np.eye(2)), rep_param_idx=0
    )
    assert res == (2,)

    # Shapes supplied explicitly via ``param_shapes``; the expected result
    # matches the first entry, (2,).
    res = default_shape_from_params(
        1, (np.array([1, 2]), 0), param_shapes=((2,), ())
    )
    assert res == (2,)

    # A 0-d reference parameter used for a 1-d support must raise.
    with raises(ValueError, match=r"^Reference parameter"):
        default_shape_from_params(1, (np.array(1),), rep_param_idx=0)

    # With ndim_supp=2, the result is the trailing two dimensions of the
    # (2, 3, 4) reference parameter.
    res = default_shape_from_params(
        2, (np.array([1, 2]), np.ones((2, 3, 4))), rep_param_idx=1
    )
    assert res == (3, 4)
def _shape_from_params(self, dist_params, rep_param_idx=1, param_shapes=None):
    """Return the support shape implied by ``dist_params``.

    Forwards to ``default_shape_from_params`` using ``self.ndim_supp`` as the
    support dimensionality. Unlike that helper's usage elsewhere, the
    reference parameter defaults to index 1 (the second parameter) here.

    Args:
        dist_params: The distribution parameters to inspect.
        rep_param_idx: Index of the parameter whose shape determines the result.
        param_shapes: Optional precomputed shapes passed through unchanged.
    """
    support_ndim = self.ndim_supp
    return default_shape_from_params(
        support_ndim,
        dist_params,
        rep_param_idx=rep_param_idx,
        param_shapes=param_shapes,
    )