def test_io_method(self, data, eight_schools_params):
    """Round-trip InferenceData through ``to_netcdf`` / ``InferenceData.from_netcdf``.

    Writes the object built by ``self.get_inference_data`` to a temporary
    ``saved_models/io_method_testfile.nc`` next to this test file, reloads it,
    and checks the ``posterior`` group survived. The temporary file is removed
    even when an assertion or the I/O itself fails.
    """
    inference_data = self.get_inference_data(data, eight_schools_params)
    assert hasattr(inference_data, "posterior")
    here = os.path.dirname(os.path.abspath(__file__))
    data_directory = os.path.join(here, "saved_models")
    filepath = os.path.join(data_directory, "io_method_testfile.nc")
    # A stale file would let the existence/size checks below pass even if
    # to_netcdf wrote nothing, so refuse to start on top of one.
    assert not os.path.exists(filepath)
    try:
        # InferenceData method
        inference_data.to_netcdf(filepath)
        assert os.path.exists(filepath)
        assert os.path.getsize(filepath) > 0
        inference_data2 = InferenceData.from_netcdf(filepath)
        assert hasattr(inference_data2, "posterior")
    finally:
        # Clean up on every path; otherwise one failed run leaves the file
        # behind and the precondition above fails all subsequent runs.
        if os.path.exists(filepath):
            os.remove(filepath)
    assert not os.path.exists(filepath)
def test_io_method(self, data, eight_schools_params, groups_arg):
    """Round-trip InferenceData through ``to_netcdf`` / ``InferenceData.from_netcdf``.

    Builds the object with ``self.get_inference_data``, checks its groups and
    variables, writes it to ``saved_models/io_method_testfile.nc`` next to this
    test file and reloads it. When ``groups_arg`` is truthy, only the
    ``posterior`` and ``observed_data`` groups are written. The reloaded object
    must then contain those groups and none of the other groups listed below.
    The temporary file is removed even when an assertion or the I/O fails.
    """
    # create InferenceData and check it has been properly created
    inference_data = self.get_inference_data(data, eight_schools_params)
    test_dict = {
        "posterior": ["eta", "theta", "mu", "tau"],
        "posterior_predictive": ["eta", "theta", "mu", "tau"],
        "sample_stats": ["eta", "theta", "mu", "tau"],
        "prior": ["eta", "theta", "mu", "tau"],
        "prior_predictive": ["eta", "theta", "mu", "tau"],
        "sample_stats_prior": ["eta", "theta", "mu", "tau"],
        "observed_data": ["J", "y", "sigma"],
    }
    fails = check_multiple_attrs(test_dict, inference_data)
    assert not fails

    # check filename does not exist and use to_netcdf method
    here = os.path.dirname(os.path.abspath(__file__))
    data_directory = os.path.join(here, "saved_models")
    filepath = os.path.join(data_directory, "io_method_testfile.nc")
    # A stale file would let the existence/size checks below pass even if
    # to_netcdf wrote nothing, so refuse to start on top of one.
    assert not os.path.exists(filepath)
    # None means "write every group"; otherwise the subset to write and expect back.
    saved_groups = ("posterior", "observed_data") if groups_arg else None
    try:
        # InferenceData method
        inference_data.to_netcdf(filepath, groups=saved_groups)
        # assert file has been saved correctly
        assert os.path.exists(filepath)
        assert os.path.getsize(filepath) > 0
        inference_data2 = InferenceData.from_netcdf(filepath)
        if saved_groups is not None:
            # Every group that was not requested must be absent after reload,
            # not just sample_stats.
            for group in test_dict:
                if group not in saved_groups:
                    assert not hasattr(inference_data2, group), group
            # Only the saved groups are expected back; derive them from
            # test_dict rather than re-typing their variable lists.
            test_dict = {group: test_dict[group] for group in saved_groups}
        fails = check_multiple_attrs(test_dict, inference_data2)
        assert not fails
    finally:
        # Clean up on every path; otherwise one failed run leaves the file
        # behind and the precondition above fails all subsequent runs.
        if os.path.exists(filepath):
            os.remove(filepath)
    assert not os.path.exists(filepath)