def test_iwae(self):
    assert_allclose_ = functools.partial(assert_allclose, rtol=1e-5, atol=1e-6)

    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    wk_hat = f / T.reduce_sum(f, axis=[0], keepdims=True)
    cost = iwae_estimator(log_f, axis=[0])
    assert_allclose_(-cost, iwae_estimator(log_f, axis=[0], negative=True))
    assert_allclose_(T.reduce_mean(cost),
                     iwae_estimator(log_f, axis=[0], reduction='mean'))
    assert_allclose_(T.reduce_sum(cost),
                     iwae_estimator(log_f, axis=[0], reduction='sum'))
    cost_shape = T.shape(cost)
    assert_allclose_(
        T.grad([T.reduce_sum(cost)], [y])[0],
        T.reduce_sum(wk_hat * (2 * x * y), axis=[0]))

    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    wk_hat = f / T.reduce_sum(f, axis=[0], keepdims=True)
    cost_k = iwae_estimator(log_f, axis=[0], keepdims=True)
    assert_allclose_(
        T.reduce_mean(cost_k),
        iwae_estimator(log_f, axis=[0], keepdims=True, reduction='mean'))
    assert_allclose_(
        T.reduce_sum(cost_k),
        iwae_estimator(log_f, axis=[0], keepdims=True, reduction='sum'))
    assert_allclose_(
        -cost_k,
        T.to_numpy(iwae_estimator(log_f, axis=[0], keepdims=True,
                                  negative=True)))
    self.assertListEqual([1] + cost_shape, T.shape(cost_k))
    assert_allclose_(
        T.grad([T.reduce_sum(cost_k)], [y])[0],
        T.reduce_sum(wk_hat * (2 * x * y), axis=[0]))
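# A minimal NumPy sketch of the quantity `iwae_estimator` is expected to
# compute: the IWAE objective log((1/K) * sum_k f_k), evaluated stably as a
# log-mean-exp over the sampling axis.  Its gradient is the wk_hat-weighted
# average asserted above, where wk_hat = f / sum(f) are the self-normalized
# importance weights.  `np_iwae` is a hypothetical reference helper written
# only to document the tested value; it is not the library implementation.
# It assumes the module-level `import numpy as np` used throughout this file.
def np_iwae(log_f, axis=0):
    log_f = np.asarray(log_f)
    m = np.max(log_f, axis=axis, keepdims=True)
    # log-mean-exp: shift by the max for numerical stability
    return np.squeeze(m, axis=axis) + np.log(
        np.mean(np.exp(log_f - m), axis=axis))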
def test_invertible_matrices(self):
    for cls in (LooseInvertibleMatrix, StrictInvertibleMatrix):
        for n in [1, 3, 5]:
            m = cls(np.random.randn(n, n))
            self.assertEqual(repr(m), f'{cls.__qualname__}(size={n})')
            self.assertEqual(m.size, n)
            m = tk.layers.jit_compile(m)

            # check the initial value is an orthogonal matrix
            matrix, _ = m(inverse=False, compute_log_det=False)
            inv_matrix, _ = m(inverse=True, compute_log_det=False)
            assert_allclose(np.eye(n), T.matmul(matrix, inv_matrix),
                            rtol=1e-4, atol=1e-6)
            assert_allclose(np.eye(n), T.matmul(inv_matrix, matrix),
                            rtol=1e-4, atol=1e-6)

            # check the invertibility
            check_invertible_matrix(self, m, n)

            # check the gradient
            matrix, log_det = m(inverse=False, compute_log_det=True)
            params = list(tk.layers.iter_parameters(m))
            grads = T.grad(
                [T.reduce_sum(matrix), T.reduce_sum(log_det)], params)

            # update with gradient, then check the invertibility
            if cls is StrictInvertibleMatrix:
                for param, grad in zip(params, grads):
                    with T.no_grad():
                        T.assign(param, param + 0.001 * grad)
                check_invertible_matrix(self, m, n)
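# For reference, the properties checked above can be stated in plain NumPy:
# a non-singular W satisfies W @ inv(W) == I, and its log |det| is given by
# slogdet.  `np_check_invertible` is a hypothetical illustration of the
# identities that `check_invertible_matrix` relies on, assuming only NumPy;
# it is not the helper used by this test suite.
def np_check_invertible(w, atol=1e-6):
    n = w.shape[0]
    assert np.allclose(w @ np.linalg.inv(w), np.eye(n), atol=atol)
    sign, log_det = np.linalg.slogdet(w)
    assert np.isfinite(log_det)  # non-singular => finite log |det|
    return log_det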
def test_grad(self):
    def f():
        indices = T.as_tensor([[0, 1], [1, 0]])
        values = T.requires_grad(T.as_tensor([0.1, 0.2]))
        shape = [2, 2]
        x = T.sparse.make_sparse(indices, values, shape=shape)
        return values, x

    values, x = f()
    y = T.sparse.reduce_sum(x * x)
    [grad] = T.grad([y], [values])
    assert_allclose(grad, 2 * values, atol=1e-4, rtol=1e-6)

    values, x = f()
    y = T.sparse.reduce_sum(T.sparse.stop_grad(x) * x)
    [grad] = T.grad([y], [values])
    assert_allclose(grad, values, atol=1e-4, rtol=1e-6)
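# The expected gradients above follow from calculus on the dense analogue of
# the sparse ops: for y = sum(x * x), dy/dx = 2x, so the gradient w.r.t. the
# non-zero `values` is 2 * values; detaching one factor with stop_grad
# removes it from the gradient path, leaving just `values`.  A hypothetical
# NumPy illustration of both expected results:
def np_sparse_grad_expectations(values):
    values = np.asarray(values)
    # d/dv sum(v * v) = 2v;  d/dv sum(stop(v) * v) = stop(v) == v
    return 2 * values, values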
def test_sgvb(self):
    assert_allclose_ = functools.partial(assert_allclose, rtol=1e-5, atol=1e-6)

    # default
    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    cost = sgvb_estimator(f)
    assert_allclose_(-cost, sgvb_estimator(f, negative=True))
    assert_allclose_(T.reduce_mean(cost), sgvb_estimator(f, reduction='mean'))
    assert_allclose_(T.reduce_sum(cost), sgvb_estimator(f, reduction='sum'))
    cost_shape = T.shape(cost)
    assert_allclose_(
        T.grad([T.reduce_sum(cost)], [y])[0],
        T.reduce_sum(2 * x * y * f, axis=[0]))

    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    cost_r = sgvb_estimator(f, axis=[0])
    assert_allclose_(-cost_r, sgvb_estimator(f, axis=[0], negative=True))
    self.assertListEqual(cost_shape[1:], T.shape(cost_r))
    assert_allclose_(
        T.grad([T.reduce_sum(cost_r)], [y])[0],
        T.reduce_sum(2 * x * y * f, axis=[0]) / 7)

    x, y, z, f, log_f, log_q = prepare_test_payload(reparameterized=True)
    cost_rk = sgvb_estimator(f, axis=[0], keepdims=True)
    assert_allclose_(T.reduce_mean(cost_rk),
                     sgvb_estimator(f, axis=[0], reduction='mean'))
    assert_allclose_(T.reduce_sum(cost_rk),
                     sgvb_estimator(f, axis=[0], reduction='sum'))
    assert_allclose_(
        -cost_rk, sgvb_estimator(f, axis=[0], keepdims=True, negative=True))
    self.assertListEqual([1] + cost_shape[1:], T.shape(cost_rk))
    assert_allclose_(
        T.grad([T.reduce_sum(cost_rk)], [y])[0],
        T.reduce_sum(2 * x * y * f, axis=[0]) / 7)
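# A minimal NumPy sketch of the SGVB estimator's value (not its autograd
# machinery): with reparameterized samples, the surrogate objective is just
# the Monte Carlo average of f over the sampling axis, or f itself when no
# `axis` is given.  The division by 7 in the gradient assertions above is
# exactly this mean over the 7 samples.  `np_sgvb` is a hypothetical
# reference helper, not the library implementation.
def np_sgvb(f, axis=None):
    f = np.asarray(f)
    return f if axis is None else np.mean(f, axis=tuple(axis))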
def test_sample_and_log_prob(self):
    array_low = np.random.randn(2, 1)
    array_high = np.exp(np.random.randn(1, 3)) + 1.
    log_zero = -1e6

    def log_prob(x, low, high, group_ndims=0):
        if low is None and high is None:
            low, high = 0., 1.
        log_pdf = -np.log(np.ones_like(x) * (high - low))
        log_pdf = np.where(np.logical_and(low <= x, x <= high),
                           log_pdf, log_zero)
        log_pdf = np.sum(log_pdf, axis=tuple(range(-group_ndims, 0)))
        return log_pdf

    for shape, dtype, (low, high), event_ndims in product(
            [None, [], [5, 4]],
            float_dtypes,
            [(None, None), (-1., 2.), (array_low, array_high)],
            range(5)):
        low_rank = len(np.shape(low)) if low is not None else 0
        if event_ndims > len(shape or []) + low_rank:
            continue

        if isinstance(low, np.ndarray):
            low_t = T.as_tensor(low, dtype=dtype)
            high_t = T.as_tensor(high, dtype=dtype)
            uniform = Uniform(shape=shape, low=low_t, high=high_t,
                              event_ndims=event_ndims, log_zero=log_zero)
            value_shape = (shape or []) + [2, 3]
            self.assertIs(uniform.low, low_t)
            self.assertIs(uniform.high, high_t)
        else:
            uniform = Uniform(shape=shape, low=low, high=high, dtype=dtype,
                              event_ndims=event_ndims, log_zero=log_zero)
            value_shape = shape or []
            self.assertEqual(uniform.low, low)
            self.assertEqual(uniform.high, high)

        self.assertEqual(uniform.log_zero, log_zero)
        self.assertEqual(uniform.value_shape, value_shape)

        # sample(n_samples=None)
        t = uniform.sample()
        x = T.to_numpy(t.tensor)
        sample_shape = value_shape
        self.assertIsInstance(t, StochasticTensor)
        self.assertIs(t.distribution, uniform)
        self.assertEqual(T.get_dtype(t.tensor), dtype)
        self.assertEqual(t.n_samples, None)
        self.assertEqual(t.group_ndims, 0)
        self.assertEqual(t.reparameterized, True)
        self.assertIsInstance(t.tensor, T.Tensor)
        self.assertEqual(T.shape(t.tensor), sample_shape)

        for log_pdf in [t.log_prob(), uniform.log_prob(t)]:
            self.assertEqual(T.get_dtype(log_pdf), dtype)
            assert_allclose(log_pdf, log_prob(x, low, high, event_ndims),
                            rtol=1e-4)

        # test log-prob on out-of-range values
        assert_allclose(
            uniform.log_prob(t.tensor * 10.),
            log_prob(x * 10., low, high, event_ndims),
            rtol=1e-4,
        )

        # sample(n_samples=7)
        if event_ndims >= 1:
            t = uniform.sample(n_samples=7, group_ndims=-1,
                               reparameterized=False)
            x = T.to_numpy(t.tensor)
            sample_shape = [7] + value_shape
            self.assertIsInstance(t, StochasticTensor)
            self.assertIs(t.distribution, uniform)
            self.assertEqual(T.get_dtype(t.tensor), dtype)
            self.assertEqual(t.n_samples, 7)
            self.assertEqual(t.group_ndims, -1)
            self.assertEqual(t.reparameterized, False)
            self.assertIsInstance(t.tensor, T.Tensor)
            self.assertEqual(T.shape(t.tensor), sample_shape)

            reduce_ndims = event_ndims - 1
            for log_pdf in [t.log_prob(),
                            uniform.log_prob(t, group_ndims=-1)]:
                self.assertEqual(T.get_dtype(log_pdf), dtype)
                assert_allclose(log_pdf,
                                log_prob(x, low, high, reduce_ndims),
                                rtol=1e-4)

    # test reparameterized
    low_t = T.requires_grad(T.as_tensor(array_low))
    high_t = T.requires_grad(T.as_tensor(array_high))
    uniform = Uniform(low=low_t, high=high_t)

    t = uniform.sample()
    self.assertTrue(t.reparameterized)
    u = (T.to_numpy(t.tensor) - array_low) / (array_high - array_low)
    [low_grad, high_grad] = T.grad([T.reduce_sum(t.tensor)], [low_t, high_t])
    assert_allclose(low_grad, np.sum(1. - u, axis=-1, keepdims=True),
                    rtol=1e-4)
    assert_allclose(high_grad, np.sum(u, axis=0, keepdims=True), rtol=1e-4)

    # test not reparameterized
    t = uniform.sample(reparameterized=False)
    w_t = T.requires_grad(T.as_tensor(np.random.randn(2, 3)))
    self.assertFalse(t.reparameterized)
    [low_grad, high_grad] = T.grad([T.reduce_sum(w_t * t.tensor)],
                                   [low_t, high_t], allow_unused=True)
    self.assertTrue(T.is_null_grad(low_t, low_grad))
    self.assertTrue(T.is_null_grad(high_t, high_grad))
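# The reparameterization gradients asserted above follow from writing the
# sample as x = low + u * (high - low) with u ~ Uniform(0, 1), so that
# dx/dlow = 1 - u and dx/dhigh = u; the gradients are then summed over the
# broadcast axes of `low` (shape [2, 1]) and `high` (shape [1, 3]).
# A hypothetical NumPy sketch of this pathwise sampler, for reference only:
def np_uniform_rsample(low, high, u=None):
    if u is None:
        u = np.random.uniform(size=np.broadcast(low, high).shape)
    return low + u * (high - low)  # differentiable in low and high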
def do_test(low, high, dtype):
    # test(n_samples=n_samples)
    mean_t = T.as_tensor(mean, dtype)
    std_t = T.as_tensor(std, dtype)
    logstd_t = T.as_tensor(logstd, dtype)
    t = T.random.truncated_normal(mean_t, std_t, n_samples=n_samples,
                                  low=low, high=high)
    self.assertEqual(T.get_dtype(t), dtype)
    self.assertEqual(T.get_device(t), T.current_device())
    self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

    # test sample value range
    x = T.to_numpy(t)
    if low is not None:
        np.testing.assert_array_less(
            (low * std + mean - 1e-7) * np.ones_like(x), x)
    if high is not None:
        np.testing.assert_array_less(
            x, np.ones_like(x) * high * std + mean + 1e-7)

    # test log_prob
    do_check_log_prob(
        given=t,
        batch_ndims=len(x.shape),
        Z_log_prob_fn=partial(
            T.random.truncated_normal_log_pdf,
            mean=mean_t, std=std_t, logstd=logstd_t,
            low=low, high=high, log_zero=log_zero,
        ),
        np_log_prob=log_prob(x, low, high))
    do_check_log_prob(
        given=t * 10.,  # where the majority is out of [low, high] range
        batch_ndims=len(x.shape),
        Z_log_prob_fn=partial(
            T.random.truncated_normal_log_pdf,
            mean=mean_t, std=std_t, logstd=logstd_t,
            low=low, high=high, log_zero=log_zero,
        ),
        np_log_prob=log_prob(x * 10., low, high))

    # test(n_samples=None)
    mean_t = T.as_tensor(mean, dtype)
    std_t = T.as_tensor(std, dtype)
    logstd_t = T.as_tensor(logstd, dtype)
    t = T.random.truncated_normal(mean_t, std_t, low=low, high=high)
    self.assertEqual(T.get_dtype(t), dtype)
    self.assertEqual(T.get_device(t), T.current_device())

    # test sample value range
    x = T.to_numpy(t)
    if low is not None:
        np.testing.assert_array_less(
            (low * std + mean - 1e-7) * np.ones_like(x), x)
    if high is not None:
        np.testing.assert_array_less(
            x, np.ones_like(x) * high * std + mean + 1e-7)

    # test log_prob
    do_check_log_prob(
        given=t,
        batch_ndims=len(x.shape),
        Z_log_prob_fn=partial(
            T.random.truncated_normal_log_pdf,
            mean=mean_t, std=std_t, logstd=logstd_t,
            low=low, high=high, log_zero=log_zero,
        ),
        np_log_prob=log_prob(x, low, high))
    do_check_log_prob(
        given=t * 10.,  # where the majority is out of [low, high] range
        batch_ndims=len(x.shape),
        Z_log_prob_fn=partial(
            T.random.truncated_normal_log_pdf,
            mean=mean_t, std=std_t, logstd=logstd_t,
            low=low, high=high, log_zero=log_zero,
        ),
        np_log_prob=log_prob(x * 10., low, high))

    # test reparameterized
    w = np.random.randn(2, 3, 4)
    w_t = T.requires_grad(T.as_tensor(w))
    mean_t = T.requires_grad(T.as_tensor(mean, dtype))
    std_t = T.requires_grad(T.as_tensor(std, dtype))
    t = w_t * T.random.truncated_normal(mean_t, std_t)
    [mean_grad, std_grad] = T.grad([t], [mean_t, std_t], [T.ones_like(t)])
    assert_allclose(mean_grad, w, rtol=1e-4)
    assert_allclose(std_grad,
                    np.sum(T.to_numpy((t - w_t * mean_t) / std_t), axis=0),
                    rtol=1e-4)

    # test not reparameterized
    w_t = T.requires_grad(T.as_tensor(w))
    mean_t = T.requires_grad(T.as_tensor(mean, dtype))
    std_t = T.requires_grad(T.as_tensor(std, dtype))
    t = w_t * T.random.truncated_normal(mean_t, std_t,
                                        reparameterized=False)
    [mean_grad, std_grad] = T.grad([t], [mean_t, std_t], [T.ones_like(t)],
                                   allow_unused=True)
    self.assertTrue(T.is_null_grad(mean_t, mean_grad))
    self.assertTrue(T.is_null_grad(std_t, std_grad))

    # given has lower rank than params, broadcasted to match param
    mean_t = T.as_tensor(mean, dtype)
    std_t = T.as_tensor(std, dtype)
    logstd_t = T.as_tensor(logstd, dtype)
    assert_allclose(
        T.random.truncated_normal_log_pdf(
            T.float_scalar(0.), mean_t, std_t, logstd_t,
            low=low, high=high, log_zero=log_zero),
        log_prob(0., low=low, high=high),
        rtol=1e-4)

    # dtype mismatch
    with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
        _ = T.random.truncated_normal(T.as_tensor(mean, T.float32),
                                      T.as_tensor(std, T.float64),
                                      low=low, high=high)

    # check numerics
    mean_t = T.as_tensor(mean)
    std_t = T.zeros_like(mean_t)
    logstd_t = T.as_tensor(T.log(std_t))
    t = T.random.normal(mean_t, std_t)
    with pytest.raises(Exception,
                       match='Infinity or NaN value encountered'):
        _ = T.random.truncated_normal_log_pdf(
            t, mean_t, std_t, logstd_t, validate_tensors=True)
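# For reference, the truncated-normal log-density being checked can be
# written compactly: log N(x; mean, std) minus the log of the probability
# mass remaining inside [low, high] (note the tests give low/high in
# standardized units, i.e. the support is [low*std + mean, high*std + mean]),
# with `log_zero` outside the support.  A hypothetical sketch, assuming
# scipy is available; it is not the backend implementation under test.
def np_truncated_normal_log_pdf(x, mean, std, low, high, log_zero=-1e6):
    from scipy.stats import norm
    z = (x - mean) / std
    log_pdf = norm.logpdf(z) - np.log(std)
    a = -np.inf if low is None else low
    b = np.inf if high is None else high
    log_pdf -= np.log(norm.cdf(b) - norm.cdf(a))  # renormalize to [a, b]
    return np.where((z >= a) & (z <= b), log_pdf, log_zero)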
def test_normal(self):
    mean = np.random.randn(2, 3, 4)
    logstd = np.random.randn(3, 4)
    std = np.exp(logstd)

    def log_prob(given):
        # np.log(np.exp(-(given - mean) ** 2 / (2. * std ** 2)) /
        #        (np.sqrt(2 * np.pi) * std))
        return (-(given - mean) ** 2 * (0.5 * np.exp(-2. * logstd)) -
                np.log(np.sqrt(2 * np.pi)) - logstd)

    # test n_samples by manually expanding the param shape
    for dtype in float_dtypes:
        # test sample dtype and shape
        mean_t = T.cast(T.expand(T.as_tensor(mean), [n_samples, 2, 3, 4]),
                        dtype)
        std_t = T.cast(T.expand(T.as_tensor(std), [n_samples, 1, 3, 4]),
                       dtype)
        logstd_t = T.cast(
            T.expand(T.as_tensor(logstd), [n_samples, 1, 3, 4]), dtype)
        t = T.random.normal(mean_t, std_t)
        self.assertEqual(T.get_dtype(t), dtype)
        self.assertEqual(T.get_device(t), T.current_device())
        self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

        # test sample mean
        x = T.to_numpy(t)
        x_mean = np.mean(x, axis=0)
        np.testing.assert_array_less(
            np.abs(x_mean - mean),
            np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                    [2, 1, 1]))

        # test log_prob
        do_check_log_prob(given=t,
                          batch_ndims=len(x.shape),
                          Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                mean=mean_t,
                                                logstd=logstd_t),
                          np_log_prob=log_prob(x))

    # test with n_samples
    for dtype in float_dtypes:
        # test sample dtype and shape
        mean_t = T.as_tensor(mean, dtype)
        std_t = T.as_tensor(std, dtype)
        logstd_t = T.as_tensor(logstd, dtype)
        t = T.random.normal(mean_t, std_t, n_samples=n_samples)
        self.assertEqual(T.get_dtype(t), dtype)
        self.assertEqual(T.get_device(t), T.current_device())
        self.assertEqual(T.shape(t), [n_samples, 2, 3, 4])

        # test sample mean
        x = T.to_numpy(t)
        x_mean = np.mean(x, axis=0)
        np.testing.assert_array_less(
            np.abs(x_mean - mean),
            np.tile(np.expand_dims(5 * std / np.sqrt(n_samples), axis=0),
                    [2, 1, 1]))

        # test log_prob
        do_check_log_prob(given=t,
                          batch_ndims=len(x.shape),
                          Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                mean=mean_t,
                                                logstd=logstd_t),
                          np_log_prob=log_prob(x))

    # test no n_samples
    for dtype in float_dtypes:
        mean_t = T.as_tensor(mean, dtype)
        std_t = T.as_tensor(std, dtype)
        logstd_t = T.as_tensor(logstd, dtype)
        t = T.random.normal(mean_t, std_t)
        self.assertEqual(T.get_dtype(t), dtype)
        self.assertEqual(T.get_device(t), T.current_device())

        # test log_prob
        x = T.to_numpy(t)
        do_check_log_prob(given=t,
                          batch_ndims=len(x.shape),
                          Z_log_prob_fn=partial(T.random.normal_log_pdf,
                                                mean=mean_t,
                                                logstd=logstd_t),
                          np_log_prob=log_prob(x))

    # test reparameterized
    w = np.random.randn(2, 3, 4)
    for dtype in float_dtypes:
        w_t = T.requires_grad(T.as_tensor(w))
        mean_t = T.requires_grad(T.as_tensor(mean, dtype))
        std_t = T.requires_grad(T.as_tensor(std, dtype))
        t = w_t * T.random.normal(mean_t, std_t)
        [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                       [T.ones_like(t)])
        assert_allclose(mean_grad, w, rtol=1e-4)
        assert_allclose(std_grad,
                        np.sum(T.to_numpy((t - w_t * mean_t) / std_t),
                               axis=0),
                        rtol=1e-4)

    # test not reparameterized
    for dtype in float_dtypes:
        w_t = T.requires_grad(T.as_tensor(w))
        mean_t = T.requires_grad(T.as_tensor(mean, dtype))
        std_t = T.requires_grad(T.as_tensor(std, dtype))
        t = w_t * T.random.normal(mean_t, std_t, reparameterized=False)
        [mean_grad, std_grad] = T.grad([t], [mean_t, std_t],
                                       [T.ones_like(t)], allow_unused=True)
        self.assertTrue(T.is_null_grad(mean_t, mean_grad))
        self.assertTrue(T.is_null_grad(std_t, std_grad))

    # given has lower rank than params, broadcasted to match param
    for dtype in float_dtypes:
        mean_t = T.as_tensor(mean, dtype)
        logstd_t = T.as_tensor(logstd, dtype)
        for val in (0., 1., -1.):
            assert_allclose(
                T.random.normal_log_pdf(T.float_scalar(val), mean_t,
                                        logstd_t),
                log_prob(val), rtol=1e-4)

    # dtype mismatch
    with pytest.raises(Exception, match='`mean.dtype` != `std.dtype`'):
        _ = T.random.normal(T.as_tensor(mean, T.float32),
                            T.as_tensor(std, T.float64))

    # check numerics
    mean_t = T.as_tensor(mean)
    std_t = T.zeros_like(mean_t)
    logstd_t = T.as_tensor(T.log(std_t))
    t = T.random.normal(mean_t, std_t)
    with pytest.raises(Exception,
                       match='Infinity or NaN value encountered'):
        _ = T.random.normal_log_pdf(t, mean_t, logstd_t,
                                    validate_tensors=True)
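# The closed form verified above is the standard Gaussian log-density,
#   log N(x; mean, std) = -(x - mean)^2 / (2 std^2) - log(sqrt(2 pi)) - log(std),
# parameterized by logstd for numerical stability (std = exp(logstd)).
# A hypothetical NumPy reference mirroring the test's `log_prob`:
def np_normal_log_pdf(x, mean, logstd):
    return (-(x - mean) ** 2 * 0.5 * np.exp(-2. * logstd)
            - 0.5 * np.log(2 * np.pi) - logstd)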
def test_embedding(self):
    n_channels = 3
    n_embeddings = n_samples

    for spatial_ndims in (0, 1, 2, 3):
        w_shape = make_conv_shape([], n_channels, [4, 5, 6][:spatial_ndims])
        w_size = int(np.prod(w_shape))
        layer = getattr(
            tk.layers,
            f'Embedding{spatial_ndims}d' if spatial_ndims > 0 else 'Embedding'
        )(n_embeddings, w_shape)
        weight = layer.weight

        # check the weight
        self.assertEqual(T.shape(weight), [n_embeddings] + w_shape)
        reduce_axis = list(range(len(w_shape) + 1))
        reduce_axis.pop(-1 if T.IS_CHANNEL_LAST else 1)
        w_mean = np.average(T.to_numpy(weight), axis=tuple(reduce_axis))
        np.testing.assert_array_less(
            w_mean, 3. / np.sqrt(n_samples * w_size / n_channels))

        # check the output
        layer = jit_compile(layer)
        weight_array = T.to_numpy(weight)
        for in_shape in ([7], [7, 8]):
            indices = T.random.randint(0, n_samples, in_shape)
            indices = T.concat([indices, indices[:3]], axis=0)

            # check the output
            output = layer(indices)
            assert_allclose(output, T.embedding(weight, indices))

            # check the grad
            if spatial_ndims in (0, 1):
                out_sum = T.reduce_sum(output ** 2)
                [grad] = T.grad([out_sum], [weight])
                expected_grad = np.zeros(T.shape(weight))
                for idx in T.to_numpy(indices).reshape([-1]):
                    expected_grad[idx] += 2. * weight_array[idx]
                assert_allclose(grad, expected_grad)

        # test the constructor error
        if spatial_ndims > 0:
            with pytest.raises(
                    ValueError,
                    match=f'`embedding_size` must be a int list '
                          f'with {spatial_ndims + 1} elements'):
                _ = getattr(tk.layers, f'Embedding{spatial_ndims}d')(
                    n_embeddings, w_shape[:-1])

    # test no grad
    layer = Embedding(n_embeddings, n_channels, freeze=True)
    weight = layer.weight
    self.assertEqual(T.shape(weight), [n_embeddings, n_channels])
    layer = jit_compile(layer)
    indices = T.random.randint(0, n_samples, [7, 8])
    output = layer(indices)
    assert_allclose(output, T.embedding(weight, indices))
    out_sum = T.reduce_sum(output ** 2)
    try:
        [grad] = T.grad([out_sum], [weight])
    except Exception:
        pass
    else:
        self.assertTrue(T.is_null_grad(weight, grad))

    # test errors
    with pytest.raises(ValueError,
                       match='`embedding_size` must not be empty'):
        _ = Embedding(n_embeddings, [])
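# The expected embedding gradient above is index accumulation: for
# out = sum(W[idx] ** 2), row dW[i] receives 2 * W[i] once per occurrence
# of i in `indices`, so repeated indices accumulate.  A hypothetical NumPy
# sketch of this scatter-add, equivalent to the loop in the test:
def np_embedding_grad(weight, indices):
    flat = indices.reshape(-1)
    grad = np.zeros_like(weight)
    # np.add.at performs an unbuffered scatter-add over repeated indices
    np.add.at(grad, flat, 2. * weight[flat])
    return grad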