示例#1
0
 def test_io_method(self, data, eight_schools_params):
     """Round-trip an InferenceData object through netCDF using its methods.

     Writes ``io_method_testfile.nc`` into the ``saved_models`` directory next
     to this test module with ``InferenceData.to_netcdf``, reads it back with
     ``InferenceData.from_netcdf`` and checks that the ``posterior`` group
     survives the round trip.

     The file is removed in a ``finally`` clause so that a failure part-way
     through does not leave a stale file behind; otherwise the
     "file must not exist" precondition would fail on every subsequent run.
     """
     inference_data = self.get_inference_data(data, eight_schools_params)
     assert hasattr(inference_data, "posterior")
     here = os.path.dirname(os.path.abspath(__file__))
     data_directory = os.path.join(here, "saved_models")
     filepath = os.path.join(data_directory, "io_method_testfile.nc")
     # Refuse to run over a pre-existing file: we must not clobber or
     # accidentally re-read something this test did not write.
     assert not os.path.exists(filepath)
     try:
         # InferenceData method
         inference_data.to_netcdf(filepath)
         assert os.path.exists(filepath)
         assert os.path.getsize(filepath) > 0
         inference_data2 = InferenceData.from_netcdf(filepath)
         assert hasattr(inference_data2, "posterior")
     finally:
         # Clean up even on failure so later runs start from a clean directory.
         if os.path.exists(filepath):
             os.remove(filepath)
     assert not os.path.exists(filepath)
示例#2
0
    def test_io_method(self, data, eight_schools_params, groups_arg):
        """Round-trip InferenceData through netCDF, optionally saving only some groups.

        Builds the InferenceData under test, checks that every expected group
        holds the expected variables, then writes it to
        ``saved_models/io_method_testfile.nc`` (next to this module) with
        ``InferenceData.to_netcdf`` and reads it back with
        ``InferenceData.from_netcdf``.

        When ``groups_arg`` is truthy only ``posterior`` and ``observed_data``
        are written. The reloaded object must then contain exactly those
        groups, and none of the other groups checked before saving.

        The file is removed in a ``finally`` clause so a failing assertion does
        not leave a stale file that would break the "file must not exist"
        precondition on the next run.
        """
        # create InferenceData and check it has been properly created
        inference_data = self.get_inference_data(data, eight_schools_params)
        test_dict = {
            "posterior": ["eta", "theta", "mu", "tau"],
            "posterior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats": ["eta", "theta", "mu", "tau"],
            "prior": ["eta", "theta", "mu", "tau"],
            "prior_predictive": ["eta", "theta", "mu", "tau"],
            "sample_stats_prior": ["eta", "theta", "mu", "tau"],
            "observed_data": ["J", "y", "sigma"],
        }
        fails = check_multiple_attrs(test_dict, inference_data)
        assert not fails

        # Groups written when ``groups_arg`` is set. This tuple is also used
        # below to derive what the reloaded object should and should not contain.
        saved_groups = ("posterior", "observed_data")

        # check filename does not exist and use to_netcdf method
        here = os.path.dirname(os.path.abspath(__file__))
        data_directory = os.path.join(here, "saved_models")
        filepath = os.path.join(data_directory, "io_method_testfile.nc")
        assert not os.path.exists(filepath)
        try:
            # InferenceData method
            inference_data.to_netcdf(
                filepath, groups=saved_groups if groups_arg else None
            )

            # assert file has been saved correctly
            assert os.path.exists(filepath)
            assert os.path.getsize(filepath) > 0
            inference_data2 = InferenceData.from_netcdf(filepath)
            if groups_arg:
                # Only the requested groups were saved: every other group
                # must be missing, and only the saved ones are checked.
                unsaved_groups = [g for g in test_dict if g not in saved_groups]
                for group in unsaved_groups:
                    assert not hasattr(inference_data2, group), group
                test_dict = {group: test_dict[group] for group in saved_groups}
            fails = check_multiple_attrs(test_dict, inference_data2)
            assert not fails
        finally:
            # Clean up even on failure so later runs start from a clean directory.
            if os.path.exists(filepath):
                os.remove(filepath)
        assert not os.path.exists(filepath)