def _build_sts(self, observed_time_series=None):
    """Build an additive STS model made of two local linear trends.

    Args:
      observed_time_series: Optional observed series, forwarded unchanged to
        each component and to the enclosing `Sum`.

    Returns:
      A `Sum` whose components are two `LocalLinearTrend` instances named
      'first_component' and 'second_component', in that order.
    """
    component_names = ('first_component', 'second_component')
    trends = [
        LocalLinearTrend(observed_time_series=observed_time_series, name=label)
        for label in component_names
    ]
    return Sum(components=trends, observed_time_series=observed_time_series)
def test_broadcast_batch_shapes(self):
    """STS parameters with mismatched batch shapes must broadcast together.

    Priors are built over locations with batch shapes [2, 1] and [3, 1, 4].
    Both the resulting `Sum` model and the state space model it produces are
    expected to report the broadcast batch shape [3, 2, 4].
    """
    seed = test_util.test_seed(sampler_type='stateless')
    full_shape = [3, 1, 4]
    partial_shape = [2, 1]
    # Right-aligned broadcast of [3, 1, 4] with [2, 1].
    want_shape = [3, 2, 4]

    # Locations with two different batch shapes, wrapped by the test class's
    # placeholder helper. Both draws come from the global NumPy RNG, so the
    # partial-shape draw is taken first.
    partial_loc = self._build_placeholder(np.random.randn(*partial_shape))
    full_loc = self._build_placeholder(np.random.randn(*full_shape))

    def _unit_lognormal(loc):
        # LogNormal prior with the given location and unit scale.
        return tfd.LogNormal(loc=loc, scale=tf.ones_like(loc))

    partial_scale_prior = _unit_lognormal(partial_loc)
    full_scale_prior = _unit_lognormal(full_loc)
    loc_prior = tfd.Normal(loc=partial_loc, scale=tf.ones_like(partial_loc))

    # Mix priors of both batch shapes across the two components and the
    # observation noise so the model has to broadcast them.
    trend = LocalLinearTrend(
        level_scale_prior=full_scale_prior,
        slope_scale_prior=full_scale_prior,
        initial_level_prior=loc_prior,
        initial_slope_prior=loc_prior)
    seasonal = Seasonal(
        num_seasons=3,
        drift_scale_prior=partial_scale_prior,
        initial_effect_prior=loc_prior)
    model = Sum([trend, seasonal],
                observation_noise_scale_prior=partial_scale_prior)

    # One prior draw per parameter; every draw reuses the same stateless seed.
    param_samples = [param.prior.sample(seed=seed)
                     for param in model.parameters]
    ssm = model.make_state_space_model(
        num_timesteps=2, param_vals=param_samples)

    # Static batch shapes: model and SSM must agree.
    self.assertAllEqual(model.batch_shape, ssm.batch_shape)

    # Dynamic batch shapes: model and SSM must agree with each other and
    # with the expected broadcast shape.
    model_shape_, ssm_shape_ = self.evaluate(
        (model.batch_shape_tensor(), ssm.batch_shape_tensor()))
    self.assertAllEqual(model_shape_, ssm_shape_)
    self.assertAllEqual(model_shape_, want_shape)
def _build_sts(self, observed_time_series=None):
    """Build a single-component STS model: one local linear trend.

    Args:
      observed_time_series: Optional observed series, forwarded unchanged to
        the `LocalLinearTrend` constructor.

    Returns:
      A `LocalLinearTrend` instance.
    """
    trend = LocalLinearTrend(observed_time_series=observed_time_series)
    return trend