def test_broadcasting_correctness(self):
  """AdditiveStateSpaceModel must broadcast batchless component parameters.

  A batchless LGSSM is added to a dummy model of batch shape [3]. The result
  must assign the same log-probabilities as an additive model built from an
  LGSSM whose parameters were tiled to batch shape [3] by hand.
  """
  seed = test_util.test_seed(sampler_type='stateless')
  num_timesteps = 5

  # Random, batchless parameters for a latent-size-2 / observation-size-1 model.
  transition_matrix = np.random.randn(2, 2)
  transition_noise_diag = np.exp(np.random.randn(2))
  observation_matrix = np.random.randn(1, 2)
  observation_noise_diag = np.exp(np.random.randn(1))
  initial_state_prior_diag = np.exp(np.random.randn(2))

  placeholder = self._build_placeholder

  def build_lgssm(trans_mat, trans_diag, obs_mat, obs_diag, prior_diag):
    # Every parameter goes through `_build_placeholder`, in the same order as
    # the keyword arguments below.
    return tfd.LinearGaussianStateSpaceModel(
        num_timesteps=num_timesteps,
        transition_matrix=placeholder(trans_mat),
        transition_noise=tfd.MultivariateNormalDiag(
            scale_diag=placeholder(trans_diag)),
        observation_matrix=placeholder(obs_mat),
        observation_noise=tfd.MultivariateNormalDiag(
            scale_diag=placeholder(obs_diag)),
        initial_state_prior=tfd.MultivariateNormalDiag(
            scale_diag=placeholder(prior_diag)))

  # Model 1: AdditiveStateSpaceModel is responsible for the broadcasting.
  batchless_ssm = build_lgssm(transition_matrix, transition_noise_diag,
                              observation_matrix, observation_noise_diag,
                              initial_state_prior_diag)
  another_ssm = self._dummy_model(num_timesteps=num_timesteps,
                                  latent_size=4,
                                  batch_shape=[3])
  broadcast_additive_ssm = AdditiveStateSpaceModel(
      [batchless_ssm, another_ssm])

  # Model 2: tile every parameter to batch shape [3] explicitly.
  tile_vector = np.ones([3, 1])
  tile_matrix = np.ones([3, 1, 1])
  batch_ssm = build_lgssm(transition_matrix * tile_matrix,
                          transition_noise_diag * tile_vector,
                          observation_matrix * tile_matrix,
                          observation_noise_diag * tile_vector,
                          initial_state_prior_diag * tile_vector)
  manual_additive_ssm = AdditiveStateSpaceModel([batch_ssm, another_ssm])

  # Both additive models describe the same distribution, so their log_probs
  # on a shared sample must agree.
  y = self.evaluate(broadcast_additive_ssm.sample(seed=seed))
  broadcast_lp = self.evaluate(broadcast_additive_ssm.log_prob(y))
  manual_lp = self.evaluate(manual_additive_ssm.log_prob(y))
  self.assertAllClose(broadcast_lp, manual_lp)
def _make_state_space_model(self,
                            num_timesteps,
                            param_map,
                            initial_state_prior=None,
                            initial_step=0):
  """Builds a latent-size-0 LGSSM whose observations track the regression.

  Args:
    num_timesteps: number of timesteps of the returned model.
    param_map: keyword arguments forwarded to `self.params_to_weights`.
    initial_state_prior: optional prior over the (empty) latent state; a
      zero-dimensional MVN is used when omitted.
    initial_step: forwarded unchanged to `tfd.LinearGaussianStateSpaceModel`.

  Returns:
    A `tfd.LinearGaussianStateSpaceModel` with 0x0 transition and 1x0
    observation matrices, whose observation noise is produced by
    `_observe_timeseries_fn` from the design-matrix predictions.
  """
  weights = self.params_to_weights(**param_map)
  # Append a trailing axis so `matmul` treats the weights as column vectors.
  predicted_timeseries = self.design_matrix.matmul(weights[..., tf.newaxis])
  dtype = self.design_matrix.dtype

  # With `latent_size=0` the latent prior and the transition noise are
  # stand-ins: zero-dimensional MVNs.
  zero_dim_mvn = _zero_dimensional_mvndiag(dtype)
  prior = zero_dim_mvn if initial_state_prior is None else initial_state_prior

  return tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=tf.zeros([0, 0], dtype=dtype),
      transition_noise=zero_dim_mvn,
      observation_matrix=tf.zeros([1, 0], dtype=dtype),
      observation_noise=_observe_timeseries_fn(predicted_timeseries),
      initial_state_prior=prior,
      initial_step=initial_step)
def _make_state_space_model(self,
                            num_timesteps,
                            param_map,
                            initial_state_prior=None,
                            initial_step=0):
  """Builds a latent-size-0 LGSSM whose observations equal the regression.

  Args:
    num_timesteps: number of timesteps of the returned model.
    param_map: dict holding the regression coefficients under `'weights'`.
    initial_state_prior: optional prior over the (empty) latent state; a
      zero-dimensional MVN is used when omitted.
    initial_step: forwarded unchanged to `tfd.LinearGaussianStateSpaceModel`.

  Returns:
    A `tfd.LinearGaussianStateSpaceModel` with 0x0 transition and 1x0
    observation matrices whose per-step observation noise is a zero-scale MVN
    centred on the design-matrix prediction for that step.
  """
  weights = param_map['weights']  # shape: [B, num_features]
  predicted_timeseries = self.design_matrix.matmul(weights[..., tf.newaxis])
  dtype = self.design_matrix.dtype

  # With `latent_size=0` the latent prior and the transition noise are
  # stand-ins: zero-dimensional MVNs.
  zero_dim_mvn = tfd.MultivariateNormalDiag(
      scale_diag=tf.ones([0], dtype=dtype))

  def _zero_dim_covariance():
    # Overrides `covariance` on this instance with variance plus a trailing
    # axis. NOTE(review): presumably a workaround for computing the
    # covariance of a 0-dim MVN -- confirm.
    return zero_dim_mvn.variance()[..., tf.newaxis]

  zero_dim_mvn.covariance = _zero_dim_covariance

  if initial_state_prior is None:
    initial_state_prior = zero_dim_mvn

  def observation_noise_fn(t):
    # Zero-scale MVN, i.e. the observation at step `t` is exactly the
    # predicted value (timestep indexed on axis -2).
    step_prediction = predicted_timeseries[..., t, :]
    return tfd.MultivariateNormalDiag(
        loc=step_prediction,
        scale_diag=tf.zeros_like(step_prediction))

  return tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=tf.zeros([0, 0], dtype=dtype),
      transition_noise=zero_dim_mvn,
      observation_matrix=tf.zeros([1, 0], dtype=dtype),
      observation_noise=observation_noise_fn,
      initial_state_prior=initial_state_prior,
      initial_step=initial_step)
def _make_state_space_model(self,
                            num_timesteps,
                            param_map,
                            initial_state_prior=None,
                            **linear_gaussian_ssm_kwargs):
  """Builds a latent-size-0 LGSSM whose observations track the regression.

  Args:
    num_timesteps: number of timesteps of the returned model.
    param_map: keyword arguments forwarded to `self.params_to_weights`.
    initial_state_prior: optional prior over the (empty) latent state; a
      zero-dimensional MVN is used when omitted.
    **linear_gaussian_ssm_kwargs: extra keyword arguments passed straight to
      `tfd.LinearGaussianStateSpaceModel`.

  Returns:
    A `tfd.LinearGaussianStateSpaceModel` with 0x0 transition and 1x0
    observation matrices, whose observation noise is produced by
    `_observe_timeseries_fn` from the time-major design-matrix predictions.
  """
  weights = self.params_to_weights(**param_map)
  predictions = self.design_matrix.matmul(weights[..., tf.newaxis])
  # Make the timestep axis (currently -2) the leading axis, ahead of any
  # batch dimensions.
  time_major_predictions = distribution_util.move_dimension(predictions, -2, 0)
  dtype = self.design_matrix.dtype

  # With `latent_size=0` the latent prior and the transition noise are
  # stand-ins: zero-dimensional MVNs.
  zero_dim_mvn = _zero_dimensional_mvndiag(dtype)
  prior = zero_dim_mvn if initial_state_prior is None else initial_state_prior

  return tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=tf.zeros([0, 0], dtype=dtype),
      transition_noise=zero_dim_mvn,
      observation_matrix=tf.zeros([1, 0], dtype=dtype),
      observation_noise=_observe_timeseries_fn(time_major_predictions),
      initial_state_prior=prior,
      **linear_gaussian_ssm_kwargs)
def _make_state_space_model(self,
                            num_timesteps,
                            param_map,
                            initial_state_prior=None,
                            **linear_gaussian_ssm_kwargs):
  """Builds a latent-size-0 LGSSM whose observations track the regression.

  Args:
    num_timesteps: number of timesteps of the returned model.
    param_map: dict holding the regression coefficients under `'weights'`.
    initial_state_prior: optional prior over the (empty) latent state; a
      zero-dimensional MVN is used when omitted.
    **linear_gaussian_ssm_kwargs: extra keyword arguments passed straight to
      `tfd.LinearGaussianStateSpaceModel`.

  Returns:
    A `tfd.LinearGaussianStateSpaceModel` with 0x0 transition and 1x0
    observation matrices, whose observation noise is produced by
    `_observe_timeseries_fn` from the time-major design-matrix predictions.
  """
  # TODO(b/215267145): Automatically ensure that sample dimensions of
  # `weights` do not collide with batch dimensions of `design_matrix`.
  weights = param_map['weights']  # shape: [B, num_features]
  predictions = self.design_matrix.matmul(weights[..., tf.newaxis])
  # Make the timestep axis (currently -2) the leading axis, ahead of any
  # batch dimensions.
  time_major_predictions = distribution_util.move_dimension(predictions, -2, 0)
  dtype = self.design_matrix.dtype

  # With `latent_size=0` the latent prior and the transition noise are
  # stand-ins: zero-dimensional MVNs.
  zero_dim_mvn = _zero_dimensional_mvndiag(dtype)
  prior = zero_dim_mvn if initial_state_prior is None else initial_state_prior

  return tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=tf.zeros([0, 0], dtype=dtype),
      transition_noise=zero_dim_mvn,
      observation_matrix=tf.zeros([1, 0], dtype=dtype),
      observation_noise=_observe_timeseries_fn(time_major_predictions),
      initial_state_prior=prior,
      **linear_gaussian_ssm_kwargs)
def _dummy_model(self,
                 num_timesteps=5,
                 batch_shape=None,
                 initial_state_prior_batch_shape=None,
                 latent_size=2,
                 observation_size=1,
                 dtype=None):
  """Builds a simple LGSSM for tests.

  Args:
    num_timesteps: number of timesteps of the model.
    batch_shape: optional batch shape (any sequence of ints, e.g. list or
      tuple) for the noise/observation parameters; defaults to `[]`.
    initial_state_prior_batch_shape: optional batch shape (any sequence of
      ints) for the initial state prior; defaults to `batch_shape`.
    latent_size: dimension of the latent state.
    observation_size: dimension of each observation.
    dtype: parameter dtype; defaults to `self.dtype`.

  Returns:
    A `tfd.LinearGaussianStateSpaceModel` with an identity transition matrix,
    unit transition/observation/prior scales, observation-noise loc of ones,
    and a standard-normal random observation matrix.
  """
  # Normalize shapes to lists so `shape + [size]` concatenates. A tuple would
  # raise TypeError, and a numpy array would silently add element-wise.
  batch_shape = [] if batch_shape is None else list(batch_shape)
  if initial_state_prior_batch_shape is None:
    initial_state_prior_batch_shape = batch_shape
  else:
    initial_state_prior_batch_shape = list(initial_state_prior_batch_shape)
  dtype = dtype if dtype is not None else self.dtype

  return tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_timesteps,
      transition_matrix=self._build_placeholder(np.eye(latent_size),
                                                dtype=dtype),
      # Kept as a concrete array rather than a placeholder, as in the original.
      transition_noise=tfd.MultivariateNormalDiag(
          scale_diag=np.ones(batch_shape + [latent_size]).astype(dtype)),
      observation_matrix=self._build_placeholder(
          np.random.standard_normal(batch_shape +
                                    [observation_size, latent_size]),
          dtype=dtype),
      observation_noise=tfd.MultivariateNormalDiag(
          loc=self._build_placeholder(
              np.ones(batch_shape + [observation_size]), dtype=dtype),
          scale_diag=self._build_placeholder(
              np.ones(batch_shape + [observation_size]), dtype=dtype)),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(
              np.ones(initial_state_prior_batch_shape + [latent_size]),
              dtype=dtype)))
def test_day_of_week_example(self):
  """Seasonal SSM equals an independent random walk per season.

  With one step per season, timesteps `d, d + 7, d + 14, ...` all belong to
  season `d`. The model's mean, variance and log_prob restricted to those
  timesteps must match a 4-step scalar random walk.
  """
  seed = test_util.test_seed(sampler_type='stateless')
  num_seasons = 7
  drift_scale = 0.6
  observation_noise_scale = 0.1

  day_of_week = SeasonalStateSpaceModel(
      num_timesteps=28,
      num_seasons=num_seasons,
      drift_scale=self._build_placeholder(drift_scale),
      observation_noise_scale=observation_noise_scale,
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder(np.ones([num_seasons]))),
      num_steps_per_season=1)

  # 28 timesteps / 7 seasons = 4 observations of each season.
  random_walk_model = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=4,
      transition_matrix=self._build_placeholder([[1.]]),
      transition_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([drift_scale])),
      observation_matrix=self._build_placeholder([[1.]]),
      observation_noise=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([observation_noise_scale])),
      initial_state_prior=tfd.MultivariateNormalDiag(
          scale_diag=self._build_placeholder([1.])))

  sample = day_of_week.sample(seed=seed)
  sample_, total_lp_, seasonal_mean_, seasonal_variance_ = self.evaluate([
      sample,
      day_of_week.log_prob(sample),
      day_of_week.mean(),
      day_of_week.variance(),
  ])
  walk_mean_, walk_variance_ = self.evaluate(
      [random_walk_model.mean(), random_walk_model.variance()])

  per_season_lps = []
  for season in range(num_seasons):
    # Noncontiguous timesteps belonging to this season.
    season_steps = slice(season, None, num_seasons)
    self.assertAllClose(seasonal_mean_[season_steps], walk_mean_)
    self.assertAllClose(seasonal_variance_[season_steps], walk_variance_)
    per_season_lps.append(
        self.evaluate(random_walk_model.log_prob(sample_[season_steps])))

  # The seasons are independent, so the total log_prob is their sum.
  self.assertAlmostEqual(total_lp_, sum(per_season_lps), places=3)
def test_log_prob_matches_linear_gaussian_ssm(self):
  """An LGSSM observing its state exactly matches the equivalent MarkovChain.

  The LGSSM uses an identity observation matrix and zero observation noise,
  so its observations are the latent states themselves. Its log_prob must
  therefore match a `tfd.MarkovChain` with the same prior and transitions.
  """
  dim = 2
  batch_shape = [3, 1]
  num_steps = 7
  seed, *model_seeds = samplers.split_seed(test_util.test_seed(), n=6)
  (prior_loc_seed, prior_scale_seed, matrix_seed, bias_seed,
   wishart_seed) = model_seeds

  # Sample the parameters of a random linear Gaussian process.
  standard_normal = tfd.Normal(0., 1.)
  prior_loc = self.evaluate(
      standard_normal.sample(batch_shape + [dim], seed=prior_loc_seed))
  prior_scale = self.evaluate(
      tfd.InverseGamma(1., 1.).sample(batch_shape + [dim],
                                      seed=prior_scale_seed))
  transition_matrix = self.evaluate(
      standard_normal.sample([dim, dim], seed=matrix_seed))
  transition_bias = self.evaluate(
      standard_normal.sample(batch_shape + [dim], seed=bias_seed))
  transition_covariance = tfd.WishartTriL(
      df=dim, scale_tril=tf.eye(dim)).sample(seed=wishart_seed)
  transition_scale_tril = self.evaluate(
      tf.linalg.cholesky(transition_covariance))

  initial_state_prior = tfd.MultivariateNormalDiag(
      loc=prior_loc, scale_diag=prior_scale, name='initial_state_prior')

  lgssm = tfd.LinearGaussianStateSpaceModel(
      num_timesteps=num_steps,
      transition_matrix=transition_matrix,
      transition_noise=tfd.MultivariateNormalTriL(
          loc=transition_bias, scale_tril=transition_scale_tril),
      # Identity observation with zero noise: observations are the states.
      observation_matrix=tf.eye(dim),
      observation_noise=tfd.MultivariateNormalDiag(
          loc=tf.zeros(dim), scale_diag=tf.zeros(dim)),
      initial_state_prior=initial_state_prior)

  def transition_fn(_, state):
    # Same affine-Gaussian transition as the LGSSM above.
    return tfd.MultivariateNormalTriL(
        loc=tf.linalg.matvec(transition_matrix, state) + transition_bias,
        scale_tril=transition_scale_tril)

  markov_chain = tfd.MarkovChain(
      initial_state_prior=initial_state_prior,
      transition_fn=transition_fn,
      num_steps=num_steps)

  x = markov_chain.sample(5, seed=seed)
  self.assertAllClose(lgssm.log_prob(x), markov_chain.log_prob(x), rtol=1e-5)