Exemple #1
0
    def test_simple_regression_correctness(self):
        # Fit the weights of a linear regression by gradient descent on the
        # negative log-probability of a series generated from known weights,
        # then check that the fitted weights land near the known ones.
        batch_shape = [4, 3]
        num_timesteps = 10
        num_features = 2
        num_train_steps = 80

        design_shape = batch_shape + [num_timesteps, num_features]
        design_matrix = self._build_placeholder(
            np.random.randn(*design_shape))

        # The target series is the design matrix multiplied by the known
        # weights (expanded with a trailing axis for the matmul).
        true_weights = self._build_placeholder([4., -3.])
        weights_column = true_weights[..., tf.newaxis]
        predicted_time_series = linear_operator_util.matmul_with_broadcast(
            design_matrix, weights_column)

        cauchy = tfd.Cauchy(
            loc=self._build_placeholder(np.zeros([num_features])),
            scale=self._build_placeholder(np.ones([num_features])))
        weights_prior = tfd.Independent(cauchy, reinterpreted_batch_ndims=1)
        linear_regression = LinearRegression(
            design_matrix=design_matrix, weights_prior=weights_prior)

        noise_scale_prior = tfd.LogNormal(
            loc=self._build_placeholder(-2),
            scale=self._build_placeholder(0.1))
        model = Sum(
            components=[linear_regression],
            observation_noise_scale_prior=noise_scale_prior)

        # The weights start at zero and are the only trainable quantity; the
        # observation noise scale is fixed at the mode of its prior.
        learnable_weights = tf.Variable(
            tf.zeros([num_features], dtype=true_weights.dtype))
        param_vals = {
            "LinearRegression/_weights": learnable_weights,
            "observation_noise_scale": noise_scale_prior.mode(),
        }
        learnable_ssm = model.make_state_space_model(
            num_timesteps=num_timesteps, param_vals=param_vals)

        loss = -learnable_ssm.log_prob(predicted_time_series)
        train_op = tf.train.AdamOptimizer(0.1).minimize(loss)
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            for _ in range(num_train_steps):
                sess.run(train_op)
            expected, fitted = sess.run((true_weights, learnable_weights))
            self.assertAllClose(expected, fitted, atol=0.2)
Exemple #2
0
  def test_broadcast_batch_shapes(self):
    # Build a Sum model from components whose priors have different batch
    # shapes, then check that the model and the state space model derived
    # from it agree on the broadcast batch shape.
    seed = test_util.test_seed(sampler_type='stateless')

    full_shape = [3, 1, 4]
    partial_shape = [2, 1]
    broadcast_shape = [3, 2, 4]

    # Two location tensors: one with the partial batch shape and one with
    # the full batch shape.
    partial_loc = self._build_placeholder(np.random.randn(*partial_shape))
    full_loc = self._build_placeholder(np.random.randn(*full_shape))

    # All priors use unit scale; only their batch shapes differ.
    partial_scale_prior = tfd.LogNormal(
        loc=partial_loc,
        scale=tf.ones_like(partial_loc))
    full_scale_prior = tfd.LogNormal(
        loc=full_loc,
        scale=tf.ones_like(full_loc))
    loc_prior = tfd.Normal(
        loc=partial_loc,
        scale=tf.ones_like(partial_loc))

    trend_component = LocalLinearTrend(
        level_scale_prior=full_scale_prior,
        slope_scale_prior=full_scale_prior,
        initial_level_prior=loc_prior,
        initial_slope_prior=loc_prior)
    seasonal_component = Seasonal(
        num_seasons=3,
        drift_scale_prior=partial_scale_prior,
        initial_effect_prior=loc_prior)
    model = Sum(
        [trend_component, seasonal_component],
        observation_noise_scale_prior=partial_scale_prior)

    # Draw one value per model parameter from its prior; the same seed is
    # passed to every draw.
    param_samples = [
        param.prior.sample(seed=seed) for param in model.parameters]
    ssm = model.make_state_space_model(
        num_timesteps=2, param_vals=param_samples)

    # Static batch shapes of the model and the SSM must agree.
    self.assertAllEqual(model.batch_shape, ssm.batch_shape)

    # Evaluated batch shapes must agree with each other and with the
    # expected broadcast shape.
    shape_tensors = (model.batch_shape_tensor(), ssm.batch_shape_tensor())
    model_shape_, ssm_shape_ = self.evaluate(shape_tensors)
    self.assertAllEqual(model_shape_, ssm_shape_)
    self.assertAllEqual(model_shape_, broadcast_shape)
  def test_broadcast_batch_shapes(self):
    # Verify that mixing priors of different batch shapes in a Sum model
    # yields the broadcast batch shape, both on the model itself and on the
    # state space model it constructs.
    # NOTE(review): a method with this same name is defined earlier in this
    # file; if both live in the same class, this later definition replaces
    # the earlier one -- confirm which is intended.
    full_shape = [3, 1, 4]
    partial_shape = [2, 1]
    broadcast_shape = [3, 2, 4]

    # Location tensors with the partial and the full batch shape.
    partial_loc = self._build_placeholder(np.random.randn(*partial_shape))
    full_loc = self._build_placeholder(np.random.randn(*full_shape))

    # Unit-scale priors that differ only in batch shape.
    partial_scale_prior = tfd.LogNormal(
        loc=partial_loc,
        scale=tf.ones_like(partial_loc))
    full_scale_prior = tfd.LogNormal(
        loc=full_loc,
        scale=tf.ones_like(full_loc))
    loc_prior = tfd.Normal(
        loc=partial_loc,
        scale=tf.ones_like(partial_loc))

    trend_component = LocalLinearTrend(
        level_scale_prior=full_scale_prior,
        slope_scale_prior=full_scale_prior,
        initial_level_prior=loc_prior,
        initial_slope_prior=loc_prior)
    seasonal_component = Seasonal(
        num_seasons=3,
        drift_scale_prior=partial_scale_prior,
        initial_effect_prior=loc_prior)
    model = Sum(
        [trend_component, seasonal_component],
        observation_noise_scale_prior=partial_scale_prior)

    # One prior draw per model parameter.
    # NOTE(review): these draws are not given an explicit seed.
    param_samples = [param.prior.sample() for param in model.parameters]
    ssm = model.make_state_space_model(
        num_timesteps=2, param_vals=param_samples)

    # Static batch shapes of the model and the SSM must agree.
    self.assertAllEqual(model.batch_shape, ssm.batch_shape)

    # Evaluated batch shapes must agree with each other and with the
    # expected broadcast shape.
    shape_tensors = (model.batch_shape_tensor(), ssm.batch_shape_tensor())
    model_shape_, ssm_shape_ = self.evaluate(shape_tensors)
    self.assertAllEqual(model_shape_, ssm_shape_)
    self.assertAllEqual(model_shape_, broadcast_shape)