コード例 #1
0
    def test_builds_without_errors(self):
        """Sample from a SparseLinearRegression SSM and check the output shape.

        The design matrix has batch shape [4, 3], 10 timesteps and 2 features.
        When `self.use_static_shape` is False, the (empty) weights batch shape is
        fed through a `placeholder_with_default` with unknown shape, so the
        component receives a dynamically-shaped argument. The sampled series
        must have shape `batch_shape + [num_timesteps, 1]`.
        """
        batch_shape = [4, 3]
        num_timesteps = 10
        num_features = 2
        full_design_shape = batch_shape + [num_timesteps, num_features]
        # NOTE(review): `_build_placeholder` is defined elsewhere in the test
        # class; presumably it wraps the array as static or dynamic depending
        # on `use_static_shape` — confirm.
        design_matrix = self._build_placeholder(
            np.random.randn(*full_design_shape))

        if self.use_static_shape:
            weights_batch_shape = []
        else:
            # Hide the (empty) batch shape behind a shapeless placeholder.
            weights_batch_shape = tf1.placeholder_with_default(
                np.array([], dtype=np.int32), shape=None)

        regression = SparseLinearRegression(
            design_matrix=design_matrix,
            weights_batch_shape=weights_batch_shape)
        sampled_params = [p.prior.sample() for p in regression.parameters]
        ssm = regression.make_state_space_model(
            num_timesteps=num_timesteps, param_vals=sampled_params)

        draw = ssm.sample()
        if not self.use_static_shape:
            output_shape = self.evaluate(tf.shape(draw))
        else:
            output_shape = draw.shape.as_list()
        expected_shape = batch_shape + [num_timesteps, 1]
        self.assertAllEqual(output_shape, expected_shape)
コード例 #2
0
    def _build_sts(self, observed_time_series=None):
        """Build a `Sum` model holding one `SparseLinearRegression` component.

        Args:
          observed_time_series: optional observed series. It is forwarded to
            `Sum`. When given, the regression weights' batch shape is taken from
            the canonicalized tensor's shape with the last two dimensions
            dropped. When omitted, `weights_batch_shape` is `None`.

        Returns:
          A `Sum` over a single regression component whose float32 design
          matrix is random with shape [100, 3].
        """
        num_features = 3
        max_timesteps = 100

        # The regression component takes no `observed_time_series` argument, so
        # it cannot infer a prior batch shape itself. Derive the batch shape the
        # tests expect and pass it in explicitly.
        if observed_time_series is None:
            batch_shape = None
        else:
            series_tensor, _unused_mask = (
                sts_util.canonicalize_observed_time_series_with_mask(
                    observed_time_series))
            batch_shape = tf.shape(series_tensor)[:-2]

        design_matrix = np.random.randn(
            max_timesteps, num_features).astype(np.float32)
        regression = SparseLinearRegression(
            design_matrix=design_matrix, weights_batch_shape=batch_shape)
        return Sum(
            components=[regression],
            observed_time_series=observed_time_series)