def test_monto_carlo_objective(self):
    """Check ``monte_carlo_objective`` against ``T.log_mean_exp`` over axis 0.

    Both the plain (axis reduced away) and ``keepdims=True`` variants are
    checked, together with their ``'mean'`` / ``'sum'`` reductions and the
    resulting shapes.
    """
    log_p, log_q = prepare_test_payload()
    reduction_tol = {'rtol': 1e-4, 'atol': 1e-6}
    reducers = (('mean', T.reduce_mean), ('sum', T.reduce_sum))

    # objective reduced along the sample axis
    objective = monte_carlo_objective(log_p, log_q, axis=[0])
    for reduction, reducer in reducers:
        assert_allclose(
            reducer(objective),
            monte_carlo_objective(log_p, log_q, axis=[0], reduction=reduction),
            **reduction_tol
        )
    reduced_shape = T.shape(objective)
    assert_allclose(objective, T.log_mean_exp(log_p - log_q, axis=[0]))

    # same objective, keeping the sample axis as a size-1 dimension
    objective_k = monte_carlo_objective(log_p, log_q, axis=[0], keepdims=True)
    for reduction, reducer in reducers:
        assert_allclose(
            reducer(objective_k),
            monte_carlo_objective(
                log_p, log_q, axis=[0], keepdims=True, reduction=reduction),
        )
    self.assertListEqual([1] + reduced_shape, T.shape(objective_k))
    assert_allclose(
        objective_k,
        T.log_mean_exp(log_p - log_q, axis=[0], keepdims=True),
    )
def test_invertible_matrices(self):
    """Exercise both invertible-matrix classes for several matrix sizes.

    For every class and size ``n``:

    * check ``repr`` and ``size``;
    * check the forward and inverse matrices multiply to the identity
      (in both orders);
    * run ``check_invertible_matrix``;
    * take gradients of ``sum(matrix)`` and ``sum(log_det)`` w.r.t. all
      parameters and, for ``StrictInvertibleMatrix`` only, nudge each
      parameter along its gradient and re-run ``check_invertible_matrix``.
    """
    tol = dict(rtol=1e-4, atol=1e-6)
    for cls in (LooseInvertibleMatrix, StrictInvertibleMatrix):
        for size in [1, 3, 5]:
            layer = cls(np.random.randn(size, size))
            self.assertEqual(repr(layer), f'{cls.__qualname__}(size={size})')
            self.assertEqual(layer.size, size)
            layer = tk.layers.jit_compile(layer)

            # Forward and inverse matrices must be mutual inverses.
            # NOTE(review): this does not by itself verify that the initial
            # value is orthogonal.
            identity = np.eye(size)
            forward_mat, _ = layer(inverse=False, compute_log_det=False)
            inverse_mat, _ = layer(inverse=True, compute_log_det=False)
            assert_allclose(identity, T.matmul(forward_mat, inverse_mat), **tol)
            assert_allclose(identity, T.matmul(inverse_mat, forward_mat), **tol)

            check_invertible_matrix(self, layer, size)

            # gradients of sum(matrix) and sum(log_det) w.r.t. every parameter
            forward_mat, log_det = layer(inverse=False, compute_log_det=True)
            params = list(tk.layers.iter_parameters(layer))
            grads = T.grad(
                [T.reduce_sum(forward_mat), T.reduce_sum(log_det)], params)

            # only the strict variant is stepped along the gradient and re-checked
            if cls is not StrictInvertibleMatrix:
                continue
            for param, grad in zip(params, grads):
                with T.no_grad():
                    T.assign(param, param + 0.001 * grad)
            check_invertible_matrix(self, layer, size)
def test_elbo(self):
    """Check ``elbo_objective`` against ``log_p - log_q`` and its axis-0 mean.

    Covers the unreduced objective, ``axis=[0]``, and ``axis=[0]`` with
    ``keepdims=True``, plus the ``'mean'`` / ``'sum'`` reductions and the
    resulting shapes.
    """
    log_p, log_q = prepare_test_payload()
    log_ratio = log_p - log_q

    # no axis: the objective is the element-wise log-ratio
    elbo = elbo_objective(log_p, log_q)
    assert_allclose(T.reduce_mean(elbo),
                    elbo_objective(log_p, log_q, reduction='mean'))
    assert_allclose(T.reduce_sum(elbo),
                    elbo_objective(log_p, log_q, reduction='sum'))
    full_shape = T.shape(elbo)
    assert_allclose(elbo, log_ratio)

    # averaged over the sample axis
    elbo_r = elbo_objective(log_p, log_q, axis=[0])
    self.assertListEqual(full_shape[1:], T.shape(elbo_r))
    assert_allclose(elbo_r, T.reduce_mean(log_ratio, axis=[0]))

    # averaged over the sample axis, which is kept as a size-1 dimension
    elbo_rk = elbo_objective(log_p, log_q, axis=[0], keepdims=True)
    for reduction, reducer in (('mean', T.reduce_mean), ('sum', T.reduce_sum)):
        assert_allclose(
            reducer(elbo_rk),
            elbo_objective(log_p, log_q, axis=[0], keepdims=True,
                           reduction=reduction),
        )
    self.assertListEqual([1] + full_shape[1:], T.shape(elbo_rk))
    assert_allclose(elbo_rk,
                    T.reduce_mean(log_ratio, axis=[0], keepdims=True))
def test_iwae(self):
    """Check ``iwae_estimator`` values, reductions and gradients along axis 0.

    The gradient w.r.t. ``y`` is compared with
    ``sum_k w_k * (2 * x * y)``, where ``w_k = f_k / sum_j f_j`` over axis 0.
    Each section draws a fresh payload.
    """
    check = functools.partial(assert_allclose, rtol=1e-5, atol=1e-6)

    # --- sample axis reduced away ---
    x, y, _, f, log_f, _ = prepare_test_payload(reparameterized=True)
    weights = f / T.reduce_sum(f, axis=[0], keepdims=True)
    cost = iwae_estimator(log_f, axis=[0])
    check(-cost, iwae_estimator(log_f, axis=[0], negative=True))
    check(T.reduce_mean(cost),
          iwae_estimator(log_f, axis=[0], reduction='mean'))
    check(T.reduce_sum(cost),
          iwae_estimator(log_f, axis=[0], reduction='sum'))
    reduced_shape = T.shape(cost)
    check(T.grad([T.reduce_sum(cost)], [y])[0],
          T.reduce_sum(weights * (2 * x * y), axis=[0]))

    # --- sample axis kept as a size-1 dimension ---
    x, y, _, f, log_f, _ = prepare_test_payload(reparameterized=True)
    weights = f / T.reduce_sum(f, axis=[0], keepdims=True)
    cost_k = iwae_estimator(log_f, axis=[0], keepdims=True)
    check(T.reduce_mean(cost_k),
          iwae_estimator(log_f, axis=[0], keepdims=True, reduction='mean'))
    check(T.reduce_sum(cost_k),
          iwae_estimator(log_f, axis=[0], keepdims=True, reduction='sum'))
    check(-cost_k,
          T.to_numpy(iwae_estimator(log_f, axis=[0], keepdims=True,
                                    negative=True)))
    self.assertListEqual([1] + reduced_shape, T.shape(cost_k))
    check(T.grad([T.reduce_sum(cost_k)], [y])[0],
          T.reduce_sum(weights * (2 * x * y), axis=[0]))
def test_monto_carlo_objective(self):
    """Check ``importance_sampling_log_likelihood`` against ``T.log_mean_exp``.

    Covers the axis-0 estimate, its ``'mean'`` / ``'sum'`` reductions, and
    the ``keepdims=True`` variant together with its shape.

    NOTE(review): ``assert_allclose_`` is not defined inside this function;
    it presumably comes from module scope -- confirm it exists in this file.
    The method name also looks copied from the Monte Carlo objective test.
    """
    log_p, log_q = prepare_test_payload()
    log_ratio = log_p - log_q

    ll = importance_sampling_log_likelihood(log_p, log_q, axis=[0])
    reduced_shape = T.shape(ll)
    assert_allclose_(ll, T.log_mean_exp(log_ratio, axis=[0]))
    for reduction, reducer in (('mean', T.reduce_mean), ('sum', T.reduce_sum)):
        assert_allclose_(
            reducer(ll),
            importance_sampling_log_likelihood(
                log_p, log_q, axis=[0], reduction=reduction),
        )

    # keepdims=True keeps a size-1 sample axis in front
    ll_k = importance_sampling_log_likelihood(
        log_p, log_q, axis=[0], keepdims=True)
    self.assertListEqual([1] + reduced_shape, T.shape(ll_k))
    assert_allclose_(ll_k, T.log_mean_exp(log_ratio, axis=[0], keepdims=True))
def log_prob_fn(t):
    """Reference log-density of a flow-transformed sample ``t``.

    Evaluates the base ``distribution`` at ``t.transform_origin.tensor``,
    pushes that origin through ``flow`` (checking it reproduces
    ``t.tensor``), and returns ``-log_det`` plus the base log-prob summed
    over the extra event dimensions the flow has beyond the base
    distribution.  ``distribution``, ``flow`` and ``ctx`` are free variables
    taken from the enclosing scope.
    """
    origin = t.transform_origin.tensor
    log_px = distribution.log_prob(origin, group_ndims=0)
    y, log_det = flow(origin)  # y and log |dy/dx|
    assert_allclose(y, t.tensor, atol=1e-4, rtol=1e-6)

    # number of event dims the flow consumes beyond the base distribution
    extra_ndims = flow.get_x_event_ndims() - distribution.event_ndims
    ctx.assertEqual(T.rank(log_det), T.rank(log_px) - extra_ndims)
    summed_log_px = T.reduce_sum(log_px, T.int_range(-extra_ndims, 0))
    return -log_det + summed_log_px
def test_sgvb(self):
    """Check ``sgvb_estimator`` values, reductions and gradients.

    Three configurations are checked, each on a fresh payload:

    * no axis (default),
    * ``axis=[0]``,
    * ``axis=[0]`` with ``keepdims=True``.

    For each one, ``negative=True`` and the ``'mean'`` / ``'sum'`` reductions
    are compared against the plain cost, as is the gradient w.r.t. ``y``
    against ``sum_k 2 * x * y * f``.  That gradient is divided by 7 when axis
    0 is averaged; 7 is presumably the number of samples produced by
    ``prepare_test_payload`` -- TODO confirm.
    """
    assert_allclose_ = functools.partial(assert_allclose, rtol=1e-5, atol=1e-6)

    # default
    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    cost = sgvb_estimator(f)
    assert_allclose_(-cost, sgvb_estimator(f, negative=True))
    assert_allclose_(T.reduce_mean(cost), sgvb_estimator(f, reduction='mean'))
    assert_allclose_(T.reduce_sum(cost), sgvb_estimator(f, reduction='sum'))
    cost_shape = T.shape(cost)
    assert_allclose_(
        T.grad([T.reduce_sum(cost)], [y])[0],
        T.reduce_sum(2 * x * y * f, axis=[0]))

    # averaged over axis 0
    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    cost_r = sgvb_estimator(f, axis=[0])
    assert_allclose_(-cost_r, sgvb_estimator(f, axis=[0], negative=True))
    # the `reduction` argument must agree with a manual reduction when
    # `axis` is given (no keepdims)
    assert_allclose_(T.reduce_mean(cost_r),
                     sgvb_estimator(f, axis=[0], reduction='mean'))
    assert_allclose_(T.reduce_sum(cost_r),
                     sgvb_estimator(f, axis=[0], reduction='sum'))
    self.assertListEqual(cost_shape[1:], T.shape(cost_r))
    assert_allclose_(
        T.grad([T.reduce_sum(cost_r)], [y])[0],
        T.reduce_sum(2 * x * y * f, axis=[0]) / 7)

    # averaged over axis 0, keeping it as a size-1 dimension
    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    cost_rk = sgvb_estimator(f, axis=[0], keepdims=True)
    # pass `keepdims=True` here as well, so that the `reduction` argument is
    # actually exercised in the keepdims configuration
    assert_allclose_(
        T.reduce_mean(cost_rk),
        sgvb_estimator(f, axis=[0], keepdims=True, reduction='mean'))
    assert_allclose_(
        T.reduce_sum(cost_rk),
        sgvb_estimator(f, axis=[0], keepdims=True, reduction='sum'))
    assert_allclose_(
        -cost_rk, sgvb_estimator(f, axis=[0], keepdims=True, negative=True))
    self.assertListEqual([1] + cost_shape[1:], T.shape(cost_rk))
    assert_allclose_(
        T.grad([T.reduce_sum(cost_rk)], [y])[0],
        T.reduce_sum(2 * x * y * f, axis=[0]) / 7)
def check_invertible_linear(ctx,
                            spatial_ndims: int,
                            invertible_linear_factory,
                            linear_factory,
                            strict: bool,):
    """Check an invertible linear flow against an equivalent plain linear layer.

    The flow built by ``invertible_linear_factory`` is compared with a
    bias-free layer from ``linear_factory`` whose weight is the flow's
    forward matrix (reshaped to a 1-sized kernel when ``spatial_ndims > 0``).
    The expected log-det is the matrix log-det summed over every spatial
    position and broadcast over the batch.
    """
    batch_shape = [2]
    num_features = 4
    spatial_shape = [5, 6, 7][:spatial_ndims]
    x = T.random.randn(
        make_conv_shape(batch_shape, num_features, spatial_shape))

    # build the flow under test
    flow = invertible_linear_factory(num_features, strict=strict)
    ctx.assertIn(f'num_features={num_features}', repr(flow))
    flow = tk.layers.jit_compile(flow)

    # reference: a linear layer initialized with the flow's forward matrix
    weight, log_det = flow.invertible_matrix(
        inverse=False, compute_log_det=True)
    extra_kwargs = {'kernel_size': 1} if spatial_ndims > 0 else {}
    kernel = T.reshape(weight, T.shape(weight) + [1] * spatial_ndims)
    reference = linear_factory(
        num_features,
        num_features,
        weight_init=kernel,
        use_bias=False,
        **extra_kwargs
    )
    flat_x, front_shape = T.flatten_to_ndims(x, spatial_ndims + 2)
    expected_y = T.unflatten_from_ndims(reference(flat_x), front_shape)

    # the matrix log-det is counted once per spatial position
    per_position_log_det = T.expand(log_det, spatial_shape)
    expected_log_det = T.expand(T.reduce_sum(per_position_log_det), batch_shape)

    flow_standard_check(ctx, flow, x, expected_y, expected_log_det,
                        T.random.randn(batch_shape))
def check_scale(ctx, scale: Scale, x, pre_scale, expected_y, expected_log_det):
    """Check a ``Scale`` implementation on ``x`` with the given ``pre_scale``.

    First verifies that malformed calls raise with the expected messages.
    Then, for every ``event_ndims`` from ``rank(pre_scale)`` up to (but
    excluding) ``rank(x)``, checks forward and inverse outputs and the
    log-det, both alone and added onto a random ``input_log_det``.
    ``expected_log_det`` must have the same shape as ``x``; it is summed
    over the trailing ``event_ndims`` axes to get each expected log-det.
    """
    assert T.shape(x) == T.shape(expected_log_det)

    # (expected error-message pattern, call that must fail)
    invalid_calls = (
        (r'`rank\(input\) >= event_ndims` does not hold',
         lambda: scale(T.random.randn([1]), T.random.randn([1]),
                       event_ndims=2)),
        (r'`rank\(input\) >= rank\(pre_scale\)` does not hold',
         lambda: scale(T.random.randn([1]), T.random.randn([1, 2]),
                       event_ndims=1)),
        (r'The shape of `input_log_det` is not expected',
         lambda: scale(T.random.randn([2, 3]), T.random.randn([2, 3]),
                       event_ndims=1, input_log_det=T.random.randn([3]))),
        (r'The shape of `input_log_det` is not expected',
         lambda: scale(T.random.randn([2, 3]), T.random.randn([2, 3]),
                       event_ndims=2, input_log_det=T.random.randn([2]))),
    )
    for pattern, bad_call in invalid_calls:
        with pytest.raises(Exception, match=pattern):
            _ = bad_call()

    tol = {'rtol': 1e-4, 'atol': 1e-6}
    for event_ndims in range(T.rank(pre_scale), T.rank(x)):
        target_log_det = T.reduce_sum(
            expected_log_det, axis=T.int_range(-event_ndims, 0))
        input_log_det = T.random.randn(T.shape(target_log_det))

        # forward: no log-det / log-det / log-det added to input_log_det
        out, log_det = scale(x, pre_scale, event_ndims=event_ndims,
                             compute_log_det=False)
        assert_allclose(out, expected_y, **tol)
        ctx.assertIsNone(log_det)

        out, log_det = scale(x, pre_scale, event_ndims=event_ndims)
        assert_allclose(out, expected_y, **tol)
        assert_allclose(log_det, target_log_det, **tol)

        out, log_det = scale(x, pre_scale, event_ndims=event_ndims,
                             input_log_det=input_log_det)
        assert_allclose(out, expected_y, **tol)
        assert_allclose(log_det, input_log_det + target_log_det, **tol)

        # inverse: same three cases, with the log-det negated
        back, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                              inverse=True, compute_log_det=False)
        assert_allclose(back, x, **tol)
        ctx.assertIsNone(log_det)

        back, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                              inverse=True)
        assert_allclose(back, x, **tol)
        assert_allclose(log_det, -target_log_det, **tol)

        back, log_det = scale(expected_y, pre_scale, event_ndims=event_ndims,
                              inverse=True, input_log_det=input_log_det)
        assert_allclose(back, x, **tol)
        assert_allclose(log_det, input_log_det - target_log_det, **tol)
def test_sample_and_log_prob(self):
    """Check ``Uniform`` sampling and log-prob against a NumPy reference.

    Iterates over sample shapes, float dtypes, (low, high) settings (unit
    default, scalars, and broadcasting arrays) and ``event_ndims``, checking
    attributes, samples and log-probs for ``n_samples=None`` and, when
    ``event_ndims >= 1``, ``n_samples=7``.  Finally checks the
    reparameterized gradients w.r.t. ``low`` / ``high`` and the absence of
    gradients when ``reparameterized=False``.
    """
    # low has shape (2, 1) and high (1, 3), so they broadcast to (2, 3).
    # NOTE(review): randn(2, 1) can exceed exp(randn) + 1, so low > high is
    # possible for some draws -- confirm the test tolerates that.
    array_low = np.random.randn(2, 1)
    array_high = np.exp(np.random.randn(1, 3)) + 1.
    log_zero = -1e6

    def log_prob(x, low, high, group_ndims=0):
        # NumPy reference: -log(high - low) inside [low, high], `log_zero`
        # outside, summed over the trailing `group_ndims` axes
        if low is None and high is None:
            low, high = 0., 1.
        log_pdf = -np.log(np.ones_like(x) * (high - low))
        log_pdf = np.where(np.logical_and(low <= x, x <= high), log_pdf, log_zero)
        log_pdf = np.sum(log_pdf, axis=tuple(range(-group_ndims, 0)))
        return log_pdf

    for shape, dtype, (low, high), event_ndims in product(
            [None, [], [5, 4]],
            float_dtypes,
            [(None, None), (-1., 2.), (array_low, array_high)],
            range(5)):
        # skip combinations where event_ndims exceeds the value rank
        low_rank = len(np.shape(low)) if low is not None else 0
        if event_ndims > len(shape or []) + low_rank:
            continue

        if isinstance(low, np.ndarray):
            # tensor parameters: the broadcast (2, 3) shape is appended
            low_t = T.as_tensor(low, dtype=dtype)
            high_t = T.as_tensor(high, dtype=dtype)
            uniform = Uniform(shape=shape, low=low_t, high=high_t,
                              event_ndims=event_ndims, log_zero=log_zero)
            value_shape = (shape or []) + [2, 3]
            self.assertIs(uniform.low, low_t)
            self.assertIs(uniform.high, high_t)
        else:
            # scalar (or default) parameters with an explicit dtype
            uniform = Uniform(shape=shape, low=low, high=high, dtype=dtype,
                              event_ndims=event_ndims, log_zero=log_zero)
            value_shape = shape or []
            self.assertEqual(uniform.low, low)
            self.assertEqual(uniform.high, high)

        self.assertEqual(uniform.log_zero, log_zero)
        self.assertEqual(uniform.value_shape, value_shape)

        # sample(n_samples=None)
        t = uniform.sample()
        x = T.to_numpy(t.tensor)
        sample_shape = value_shape
        self.assertIsInstance(t, StochasticTensor)
        self.assertIs(t.distribution, uniform)
        self.assertEqual(T.get_dtype(t.tensor), dtype)
        self.assertEqual(t.n_samples, None)
        self.assertEqual(t.group_ndims, 0)
        self.assertEqual(t.reparameterized, True)
        self.assertIsInstance(t.tensor, T.Tensor)
        self.assertEqual(T.shape(t.tensor), sample_shape)

        for log_pdf in [t.log_prob(), uniform.log_prob(t)]:
            self.assertEqual(T.get_dtype(log_pdf), dtype)
            assert_allclose(log_pdf, log_prob(x, low, high, event_ndims),
                            rtol=1e-4)

        # test log-prob on out-of-range values
        assert_allclose(
            uniform.log_prob(t.tensor * 10.),
            log_prob(x * 10., low, high, event_ndims),
            rtol=1e-4,
        )

        # sample(n_samples=7)
        if event_ndims >= 1:
            # group_ndims=-1 removes one event dim from the log-prob reduction
            t = uniform.sample(n_samples=7, group_ndims=-1,
                               reparameterized=False)
            x = T.to_numpy(t.tensor)
            sample_shape = [7] + value_shape
            self.assertIsInstance(t, StochasticTensor)
            self.assertIs(t.distribution, uniform)
            self.assertEqual(T.get_dtype(t.tensor), dtype)
            self.assertEqual(t.n_samples, 7)
            self.assertEqual(t.group_ndims, -1)
            self.assertEqual(t.reparameterized, False)
            self.assertIsInstance(t.tensor, T.Tensor)
            self.assertEqual(T.shape(t.tensor), sample_shape)
            reduce_ndims = event_ndims - 1
            for log_pdf in [
                t.log_prob(),
                uniform.log_prob(t, group_ndims=-1)
            ]:
                self.assertEqual(T.get_dtype(log_pdf), dtype)
                assert_allclose(log_pdf, log_prob(x, low, high, reduce_ndims),
                                rtol=1e-4)

    # test reparameterized
    low_t = T.requires_grad(T.as_tensor(array_low))
    high_t = T.requires_grad(T.as_tensor(array_high))
    uniform = Uniform(low=low_t, high=high_t)
    t = uniform.sample()
    self.assertTrue(t.reparameterized)
    # u is the sample mapped back to [0, 1]; d(sample)/d(low) = 1 - u and
    # d(sample)/d(high) = u, summed over each parameter's broadcast axis
    u = (T.to_numpy(t.tensor) - array_low) / (array_high - array_low)
    [low_grad, high_grad] = T.grad([T.reduce_sum(t.tensor)], [low_t, high_t])
    assert_allclose(low_grad, np.sum(1. - u, axis=-1, keepdims=True),
                    rtol=1e-4)
    assert_allclose(high_grad, np.sum(u, axis=0, keepdims=True),
                    rtol=1e-4)

    # non-reparameterized samples must not propagate gradients to low/high
    t = uniform.sample(reparameterized=False)
    w_t = T.requires_grad(T.as_tensor(np.random.randn(2, 3)))
    self.assertFalse(t.reparameterized)
    [low_grad, high_grad] = T.grad([T.reduce_sum(w_t * t.tensor)],
                                   [low_t, high_t],
                                   allow_unused=True)
    self.assertTrue(T.is_null_grad(low_t, low_grad))
    self.assertTrue(T.is_null_grad(high_t, high_grad))
def check_distribution_instance(ctx,
                                d,
                                event_ndims,
                                batch_shape,
                                min_event_ndims,
                                max_event_ndims,
                                log_prob_fn,
                                transform_origin_distribution=None,
                                transform_origin_group_ndims=None,
                                **expected_attrs):
    """Run a standard battery of checks on a distribution instance ``d``.

    Checks shape attributes and any extra ``expected_attrs``.  Then, for
    every combination of ``n_samples``, ``group_ndims`` and
    ``reparameterized``, it draws a sample and checks its attributes and
    shape.  Its ``log_prob`` / ``prob`` for every ``group_ndims`` are checked
    against ``log_prob_fn``.

    Args:
        ctx: test case providing the ``assert*`` methods.
        d: the distribution under test.
        event_ndims: expected ``d.event_ndims``.
        batch_shape: expected ``d.batch_shape`` (a list).
        min_event_ndims: expected ``d.min_event_ndims``.
        max_event_ndims: largest event ndims reachable via ``group_ndims``.
        log_prob_fn: called with each sampled ``StochasticTensor``; its
            result is summed over the trailing
            ``group_ndims + (event_ndims - min_event_ndims)`` axes and
            compared with ``t.log_prob``.  Presumably it returns the
            log-prob with ``min_event_ndims`` already reduced -- confirm
            against callers.
        transform_origin_distribution: if given, samples and their
            log-prob / prob must carry a ``transform_origin`` drawn from
            this distribution ...
        transform_origin_group_ndims: ... with this ``group_ndims``.
        **expected_attrs: extra attribute values asserted on ``d``;
            ``event_shape`` and ``validate_tensors`` are also used by the
            explicit checks above.
    """
    ctx.assertLessEqual(max_event_ndims - event_ndims, d.batch_ndims)

    event_shape = expected_attrs.get('event_shape', None)
    ctx.assertEqual(d.min_event_ndims, min_event_ndims)
    ctx.assertEqual(d.value_ndims, len(batch_shape) + event_ndims)
    if event_shape is not None:
        # `batch_shape + event_shape` can only be built when the event
        # shape is known
        ctx.assertEqual(d.value_shape, batch_shape + event_shape)
    ctx.assertEqual(d.batch_shape, batch_shape)
    ctx.assertEqual(d.batch_ndims, len(batch_shape))
    ctx.assertEqual(d.event_ndims, event_ndims)
    ctx.assertEqual(d.event_shape, event_shape)

    for attr, val in expected_attrs.items():
        ctx.assertEqual(getattr(d, attr), val)
    ctx.assertEqual(
        d.validate_tensors,
        expected_attrs.get('validate_tensors', settings.validate_tensors)
    )

    # `None` means "omit the argument and rely on the default"
    group_ndims_choices = (None, 0, -(event_ndims - min_event_ndims),
                           max_event_ndims - event_ndims)

    # check sample
    for n_samples in (None, 5):
        for group_ndims in group_ndims_choices:
            for reparameterized2 in (None, True, False):
                if reparameterized2 and not d.reparameterized:
                    continue

                # sample()
                sample_kwargs = {}
                if n_samples is not None:
                    sample_kwargs['n_samples'] = n_samples
                    sample_shape = [n_samples]
                else:
                    sample_shape = []

                # Keep the effective values in separate names: rebinding the
                # loop variable `group_ndims` here would make later
                # `reparameterized2` iterations pass 0 explicitly instead of
                # exercising the omitted-argument default.
                if group_ndims is not None:
                    sample_kwargs['group_ndims'] = group_ndims
                    sample_group_ndims = group_ndims
                else:
                    sample_group_ndims = 0

                if reparameterized2 is not None:
                    sample_kwargs['reparameterized'] = reparameterized2
                    expected_reparameterized = reparameterized2
                else:
                    expected_reparameterized = d.reparameterized

                t = d.sample(**sample_kwargs)
                ctx.assertEqual(t.group_ndims, sample_group_ndims)
                ctx.assertEqual(t.reparameterized, expected_reparameterized)
                ctx.assertEqual(T.rank(t.tensor),
                                d.value_ndims + len(sample_shape))
                ctx.assertEqual(
                    T.shape(t.tensor)[:(d.batch_ndims + len(sample_shape))],
                    sample_shape + d.batch_shape
                )

                if transform_origin_distribution is not None:
                    ctx.assertIsInstance(t.transform_origin, StochasticTensor)
                    ctx.assertIs(t.transform_origin.distribution,
                                 transform_origin_distribution)
                    # compare ints by value; identity only works by accident
                    # of small-int caching
                    ctx.assertEqual(t.transform_origin.group_ndims,
                                    transform_origin_group_ndims)

                # log_prob()
                expected_log_prob = log_prob_fn(t)
                for group_ndims2 in group_ndims_choices:
                    if group_ndims2 is not None:
                        log_prob_kwargs = {'group_ndims': group_ndims2}
                    else:
                        # omitted: log_prob() should use the sample's own
                        # group_ndims
                        log_prob_kwargs = {}
                        group_ndims2 = sample_group_ndims

                    log_prob = t.log_prob(**log_prob_kwargs)
                    ctx.assertEqual(
                        T.shape(log_prob),
                        T.shape(t.tensor)[:T.rank(t.tensor) -
                                          (group_ndims2 + event_ndims)]
                    )

                    assert_allclose(
                        log_prob,
                        T.reduce_sum(
                            expected_log_prob,
                            T.int_range(
                                -(group_ndims2 +
                                  (event_ndims - min_event_ndims)),
                                0
                            )
                        ),
                        rtol=1e-4, atol=1e-6,
                    )

                    prob = t.prob(**log_prob_kwargs)
                    assert_allclose(prob, T.exp(log_prob),
                                    rtol=1e-4, atol=1e-6)

                    if transform_origin_distribution is not None:
                        for p in (log_prob, prob):
                            ctx.assertIsInstance(p.transform_origin,
                                                 StochasticTensor)
                            ctx.assertIs(p.transform_origin.distribution,
                                         transform_origin_distribution)
                            ctx.assertEqual(p.transform_origin.group_ndims,
                                            transform_origin_group_ndims)
def test_embedding(self):
    """Check the ``Embedding`` layers (plain and 1d/2d/3d variants).

    For each spatial rank this checks the weight shape and a bound on its
    per-channel mean, the layer output against ``T.embedding``, the weight
    gradient (spatial rank 0 and 1 only, with repeated indices), and the
    constructor's error message for a wrong ``embedding_size``.  It then
    checks a frozen layer yields no gradient and that an empty
    ``embedding_size`` is rejected.
    """
    n_channels = 3
    # `n_samples` is not defined here; presumably a module-level constant
    n_embeddings = n_samples

    for spatial_ndims in (0, 1, 2, 3):
        w_shape = make_conv_shape([], n_channels, [4, 5, 6][:spatial_ndims])
        w_size = int(np.prod(w_shape))

        layer = getattr(
            tk.layers,
            (f'Embedding{spatial_ndims}d' if spatial_ndims > 0
             else 'Embedding')
        )(n_embeddings, w_shape)
        weight = layer.weight

        # check the weight
        self.assertEqual(T.shape(weight), [n_embeddings] + w_shape)
        # average over every axis except the channel axis (last axis when
        # channel-last, otherwise axis 1)
        reduce_axis = list(range(len(w_shape) + 1))
        reduce_axis.pop(-1 if T.IS_CHANNEL_LAST else 1)
        w_mean = np.average(T.to_numpy(weight), axis=tuple(reduce_axis))
        # NOTE(review): this bound is one-sided (negative means always pass);
        # an `np.abs(w_mean)` check may have been intended -- confirm
        np.testing.assert_array_less(
            w_mean,
            3. / np.sqrt(n_samples * w_size / n_channels)
        )

        # check the output
        layer = jit_compile(layer)
        weight_array = T.to_numpy(weight)
        for in_shape in ([7], [7, 8]):
            indices = T.random.randint(0, n_samples, in_shape)
            # repeat the first 3 rows so some indices occur more than once,
            # which exercises gradient accumulation below
            indices = T.concat([indices, indices[:3]], axis=0)

            # check the output
            output = layer(indices)
            assert_allclose(output, T.embedding(weight, indices))

            # check the grad
            if spatial_ndims in (0, 1):
                out_sum = T.reduce_sum(output ** 2)
                [grad] = T.grad([out_sum], [weight])
                # d(sum(w[idx] ** 2)) / dw[idx] = 2 * w[idx], accumulated
                # once per occurrence of idx
                expected_grad = np.zeros(T.shape(weight))
                for idx in T.to_numpy(indices).reshape([-1]):
                    expected_grad[idx] += 2. * weight_array[idx]
                assert_allclose(grad, expected_grad)

        # test the constructor error
        if spatial_ndims > 0:
            with pytest.raises(
                    ValueError,
                    match=f'`embedding_size` must be a int list '
                          f'with {spatial_ndims + 1} elements'):
                _ = getattr(tk.layers, f'Embedding{spatial_ndims}d')(
                    n_embeddings, w_shape[:-1])

    # test no grad
    layer = Embedding(n_embeddings, n_channels, freeze=True)
    weight = layer.weight
    self.assertEqual(T.shape(weight), [n_embeddings, n_channels])

    layer = jit_compile(layer)
    indices = T.random.randint(0, n_samples, [7, 8])
    output = layer(indices)
    assert_allclose(output, T.embedding(weight, indices))

    out_sum = T.reduce_sum(output ** 2)
    # asking for the gradient of a frozen weight may raise; either an
    # exception or a null gradient is accepted
    try:
        [grad] = T.grad([out_sum], [weight])
    except Exception:
        pass
    else:
        self.assertTrue(T.is_null_grad(weight, grad))

    # test errors
    with pytest.raises(ValueError,
                       match='`embedding_size` must not be empty'):
        _ = Embedding(n_embeddings, [])
def test_basic_interface(self):
    """Check ``StochasticTensor`` construction, equality and prob caching.

    Builds one tensor with no precomputed log-prob and checks its attributes,
    ``repr``/``hash``/equality, and that ``log_prob()`` / ``prob()`` results
    are cached only for the tensor's own ``group_ndims``.  Then builds a
    second tensor with a supplied ``log_prob`` and ``transform_origin`` after
    marking the distribution as non-continuous.
    """
    normal = UnitNormal(shape=[2, 3])
    samples = normal.sample(n_samples=5)
    samples_0 = samples.tensor[0]
    samples_no_grad = T.stop_grad(samples.tensor)
    log_prob = normal.log_prob(samples.tensor, group_ndims=0)
    log_prob_reduce_1 = T.reduce_sum(log_prob, axis=[-1])

    # case 1: tensor without precomputed log-prob or transform origin
    t = StochasticTensor(
        tensor=samples_no_grad,
        distribution=normal,
        n_samples=5,
        group_ndims=0,
        reparameterized=False,
    )

    self.assertIs(t.tensor, samples_no_grad)
    self.assertIs(t.distribution, normal)
    self.assertEqual(t.n_samples, 5)
    self.assertEqual(t.group_ndims, 0)
    self.assertEqual(t.reparameterized, False)
    self.assertIsNone(t.transform_origin)
    # nothing is cached before log_prob()/prob() are first called
    self.assertIsNone(t._cached_log_prob)
    self.assertIsNone(t._cached_prob)

    self.assertEqual(repr(t), f'StochasticTensor({t.tensor!r})')
    self.assertEqual(hash(t), hash(t))
    self.assertEqual(t, t)
    # same metadata but a different underlying tensor must compare unequal
    self.assertNotEqual(
        t,
        StochasticTensor(
            tensor=samples_0,
            distribution=normal,
            n_samples=5,
            group_ndims=0,
            reparameterized=False,
        ))
    self.assertEqual(t.continuous, True)

    # log_prob(): the result for the tensor's own group_ndims (0) is cached
    # and returned by identity on later calls
    this_log_prob = t.log_prob()
    self.assertIs(t._cached_log_prob, this_log_prob)
    self.assertIs(t.log_prob(), t._cached_log_prob)
    self.assertIs(t.log_prob(group_ndims=0), t._cached_log_prob)
    assert_allclose(this_log_prob, log_prob, rtol=1e-4)

    # a different group_ndims yields a fresh (uncached) result
    this_log_prob = t.log_prob(group_ndims=1)
    self.assertIsNot(this_log_prob, t._cached_log_prob)
    assert_allclose(this_log_prob, log_prob_reduce_1, rtol=1e-4)

    # prob(): same caching behavior as log_prob()
    this_prob = t.prob()
    self.assertIs(t._cached_prob, this_prob)
    self.assertIs(t.prob(), t._cached_prob)
    self.assertIs(t.prob(group_ndims=0), t._cached_prob)
    assert_allclose(this_prob, np.exp(T.to_numpy(log_prob)), rtol=1e-4)

    this_prob = t.prob(group_ndims=1)
    self.assertIsNot(this_prob, t._cached_prob)
    assert_allclose(this_prob, np.exp(T.to_numpy(log_prob_reduce_1)), rtol=1e-4)

    # case 2: precomputed log-prob and a transform origin; the distribution
    # is flagged non-continuous and the tensor is expected to report that too
    normal.continuous = False
    t = StochasticTensor(
        tensor=samples_0,
        distribution=normal,
        n_samples=None,
        group_ndims=1,
        reparameterized=True,
        log_prob=log_prob_reduce_1,
        transform_origin=samples,
    )
    self.assertEqual(t.continuous, False)
    self.assertIs(t.tensor, samples_0)
    self.assertIs(t.distribution, normal)
    self.assertEqual(t.n_samples, None)
    self.assertEqual(t.group_ndims, 1)
    self.assertEqual(t.reparameterized, True)
    self.assertIs(t.transform_origin, samples)
    # the supplied log_prob is used as the cache as-is; prob is still empty
    self.assertIs(t._cached_log_prob, log_prob_reduce_1)
    self.assertIsNone(t._cached_prob)