Esempio n. 1
0
    def get_inference_data3(self, data, eight_schools_params):
        """Read with an observed TensorFlow tensor, explicit var_names and dims.

        On TensorFlow 2+ the ``tensorflow.compat.v1`` API is used and
        ``disable_v2_behavior()`` is called before the ``observed`` tensor is
        built. Note that this switch is process-wide, not local to this call.

        The observed data is ``eight_schools_params["y"]`` stacked
        ``n_replicates`` times into a float32 tensor. The same count is used as
        ``posterior_predictive_size``, so the leading ``size_dim`` of ``obs``
        lines up between observed and predictive data (presumably what
        ``from_tfp`` expects for these dims; TODO confirm).

        Returns:
            Whatever ``from_tfp`` returns for these arguments.
        """
        import tensorflow as tf

        # Parse the whole major-version component; indexing only the first
        # character of ``__version__`` would misread a two-digit major version.
        if int(tf.__version__.split(".")[0]) > 1:
            import tensorflow.compat.v1 as tf  # pylint: disable=import-error

            tf.disable_v2_behavior()

        # Number of stacked copies of ``y``; kept equal to the predictive size.
        n_replicates = 3
        observed_rows = np.vstack(
            [eight_schools_params["y"]] * n_replicates
        ).astype(np.float32)

        inference_data = from_tfp(
            data.obj,
            var_names=["mu", "tau", "eta"],
            # Deferred so ``data.model`` is only built when ``from_tfp`` calls it.
            model_fn=lambda: data.model(
                eight_schools_params["J"], eight_schools_params["sigma"].astype(np.float32)
            ),
            posterior_predictive_samples=100,
            posterior_predictive_size=n_replicates,
            observed=tf.convert_to_tensor(observed_rows, np.float32),
            coords={"school": np.arange(eight_schools_params["J"])},
            dims={"eta": ["school"], "obs": ["size_dim", "school"]},
        )
        return inference_data
Esempio n. 2
0
 def get_inference_data4(self, data, eight_schools_params):
     """Read with an extra, constant ``avg_effect`` variable (setter test).

     A float32 array of ones, shaped like the first entry of ``data.obj``, is
     appended to the samples and registered under the name ``avg_effect``.
     This assumes ``data.obj`` is a list holding mu, tau and eta in that
     order (TODO confirm against the fixture).
     """
     samples = data.obj
     avg_effect = np.ones_like(samples[0]).astype(np.float32)

     def model_fn():
         # Zero-argument factory; not called here, only handed to from_tfp.
         return data.model(
             eight_schools_params["J"], eight_schools_params["sigma"].astype(np.float32)
         )

     return from_tfp(
         samples + [avg_effect],
         var_names=["mu", "tau", "eta", "avg_effect"],
         model_fn=model_fn,
         observed=eight_schools_params["y"].astype(np.float32),
     )
Esempio n. 3
0
 def get_inference_data(self, data, eight_schools_params):
     """Read the samples with float32 observed ``y`` and explicit var_names.

     ``data.model`` is not called here; it is wrapped in a zero-argument
     callable that receives ``J`` and a float32 copy of ``sigma``.
     """
     samples = data.obj
     names = ["mu", "tau", "eta"]

     def model_fn():
         return data.model(
             eight_schools_params["J"], eight_schools_params["sigma"].astype(np.float32)
         )

     observed_y = eight_schools_params["y"].astype(np.float32)
     return from_tfp(samples, var_names=names, model_fn=model_fn, observed=observed_y)
Esempio n. 4
0
 def get_inference_data2(self, data):
     """Read only the fitted samples, with no var_names, model or observed data."""
     return from_tfp(data.obj)