Exemplo n.º 1
0
 def test_load_numpyro_model_with_syntax_error(self):
     """A model file with invalid Python syntax must surface as a
     ModelException that wraps the original SyntaxError as its base."""
     caught = None
     try:
         load_custom_numpyro_model('./tests/models/syntax_error.py',
                                   Namespace(), [], pd.DataFrame())
     except ModelException as exc:
         caught = exc
     # success only if the loader raised AND kept the SyntaxError as cause
     if caught is not None and isinstance(caught.base, SyntaxError):
         return
     self.fail(
         "load_custom_numpyro_model did not raise SyntaxError on model with syntax error"
     )
Exemplo n.º 2
0
 def test_load_numpyro_model_model_not_a_function(self):
     """A model file whose ``model`` is not callable must be rejected.

     Expects a ModelException whose title mentions MODEL and whose message
     explains that the model must be a function.
     """
     try:
         load_custom_numpyro_model('./tests/models/model_not_a_function.py',
                                   Namespace(), [], pd.DataFrame())
     except ModelException as e:
         if 'MODEL' in e.title and 'must be a function' in e.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for model not being a function, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for model not being a function"
     )
Exemplo n.º 3
0
    def test_load_numpyro_model_model_factory_with_autoguide(self):
        """A model factory that returns a model with an autoguide loads, and
        both guide and model produce correctly shaped traces with and
        without observed data."""
        orig_data = pd.DataFrame({
            'first': np.zeros(10),
            'second': np.ones(10)
        })
        model, guide, preprocess, postprocess = load_custom_numpyro_model(
            './tests/models/model_factory_with_autoguide.py',
            Namespace(epsilon=1.), ['--prior_mu', '10'], orig_data)
        for loaded_fn in (model, guide, preprocess, postprocess):
            self.assertIsNotNone(loaded_fn)

        z = orig_data.to_numpy()

        def traced(fn, *fn_args):
            # run `fn` under a fresh fixed seed and record its sample sites
            return trace(seed(fn, jax.random.PRNGKey(0))).get_trace(
                *fn_args, num_obs_total=10)

        # guide: first with observations, then without; each guide site holds
        # 2 parameters (mu, sigma) with 2 dimensions each -> shape (4,)
        for guide_args in ((z, ), ()):
            guide_trace = traced(guide, *guide_args)
            for site in ('guide_loc', 'guide_scale'):
                self.assertEqual(guide_trace[site]['value'].shape, (4, ))

        # model: observed values are passed through; unobserved are sampled
        x_with_obs = traced(model, z)['x']['value']
        self.assertTrue(np.allclose(x_with_obs, z))
        x_no_obs = traced(model)['x']['value']
        self.assertEqual(x_no_obs.shape, (1, 2))
        self.assertFalse(np.allclose(x_no_obs, z))
Exemplo n.º 4
0
 def test_load_numpyro_model_model_factory_wrong_returns_bad_tuple(self):
     """A model_factory returning a malformed tuple must be rejected with a
     ModelException (title mentions MODEL FACTORY) explaining that it must
     return either a model function or a tuple."""
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     try:
         load_custom_numpyro_model(
             './tests/models/model_factory_wrong_returns_bad_tuple.py',
             Namespace(epsilon=1.), ['--prior_mu', '10'], orig_data)
     except ModelException as e:
         if ('MODEL FACTORY' in e.title
                 and 'either a model function or a tuple' in e.msg):
             return
         self.fail(
             f"load_custom_numpyro_model did raise for wrong returns in model_factory, but did not give expected explanation; got: {e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for wrong returns in model_factory"
     )
Exemplo n.º 5
0
 def test_load_numpyro_model_model_factory_broken(self):
     """An error raised inside model_factory must be reported as a
     ModelException whose title mentions MODEL FACTORY."""
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     try:
         load_custom_numpyro_model('./tests/models/model_factory_broken.py',
                                   Namespace(epsilon=1.),
                                   ['--prior_mu', '10'], orig_data)
     except ModelException as e:
         if 'MODEL FACTORY' in e.title:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model_factory, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for error in model_factory"
     )
Exemplo n.º 6
0
 def test_load_numpyro_model_preprocess_single_return_series(self):
     """A preprocess that returns a single pd.Series is wrapped so that
     train_data is a 1-tuple holding that Series."""
     orig_data = pd.DataFrame({'first': np.ones(10), 'second': np.zeros(10)})
     _, _, preprocess, _ = load_custom_numpyro_model(
         './tests/models/preprocess_single_return_series.py', Namespace(),
         [], orig_data)
     train_data, num_data, feature_names = preprocess(orig_data)

     self.assertEqual(10, num_data)
     self.assertIsInstance(train_data, tuple)
     self.assertEqual(1, len(train_data))
     (only_split, ) = train_data
     self.assertIsInstance(only_split, pd.Series)
     expected = orig_data['first'] + 2
     self.assertTrue(np.allclose(expected, only_split))
     self.assertEqual(['new_first'], feature_names)
Exemplo n.º 7
0
 def test_load_numpyro_model_model_not_allowing_None_arguments(self):
     """Calling, without data, a model that does not accept None for its
     observations must raise a ModelException (title mentions MODEL) whose
     message explains that None is used for synthesising data."""
     model, _, _, _ = load_custom_numpyro_model(
         './tests/models/simple_gauss_model_no_none.py', Namespace(), [],
         pd.DataFrame())
     try:
         seed(model, jax.random.PRNGKey(0))(num_obs_total=100)
     except ModelException as e:
         if 'MODEL' in e.title and 'None for synthesising data' in e.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for error in model")
Exemplo n.º 8
0
 def test_load_numpyro_model_model_without_num_obs_total(self):
     """Passing num_obs_total to a model that lacks that parameter must raise a
     ModelException (title mentions MODEL) whose message names num_obs_total."""
     model, _, _, _ = load_custom_numpyro_model(
         './tests/models/simple_gauss_model_no_num_obs_total.py',
         Namespace(), [], pd.DataFrame())
     z = np.ones((10, 2))
     try:
         seed(model, jax.random.PRNGKey(0))(z, num_obs_total=100)
     except ModelException as e:
         if 'MODEL' in e.title and 'num_obs_total' in e.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for error in model")
Exemplo n.º 9
0
 def test_load_numpyro_model_broken_model(self):
     """A NameError inside the model body must be reported as a
     ModelException (title mentions MODEL) keeping the NameError as base."""
     model, _, _, _ = load_custom_numpyro_model(
         './tests/models/simple_gauss_model_broken.py', Namespace(), [],
         pd.DataFrame())
     z = np.ones((10, 2))
     try:
         seed(model, jax.random.PRNGKey(0))(z)
     except ModelException as e:
         if isinstance(e.base, NameError) and 'MODEL' in e.title:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in model, but did not correctly pass causal exception; got: {e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for error in model")
Exemplo n.º 10
0
 def test_load_numpyro_model_simple_working_model(self):
     """Load a simple Gaussian model file: all four returned functions are set,
     the model reproduces given observations, and without observations it
     samples a (1, 2)-shaped ``x`` that differs from the data."""
     model, guide, preprocess, postprocess = load_custom_numpyro_model(
         './tests/models/simple_gauss_model.py', Namespace(), [],
         pd.DataFrame())
     for loaded_fn in (model, guide, preprocess, postprocess):
         self.assertIsNotNone(loaded_fn)

     z = np.ones((10, 2))

     def sample_sites(*obs):
         # fresh fixed seed per call, as each trace is independent
         return trace(seed(model, jax.random.PRNGKey(0))).get_trace(
             *obs, num_obs_total=10)

     x_with_obs = sample_sites(z)['x']['value']
     self.assertTrue(np.allclose(x_with_obs, z))
     x_no_obs = sample_sites()['x']['value']
     self.assertEqual(x_no_obs.shape, (1, 2))
     self.assertFalse(np.allclose(x_no_obs, z))
Exemplo n.º 11
0
 def test_load_numpyro_model_with_broken_preprocess(self):
     """A preprocess raising KeyError must be wrapped in a ModelException
     whose title mentions PREPROCESSING DATA and whose base is the KeyError."""
     orig_data = pd.DataFrame({
         'first': np.ones(10),
         'second': np.zeros(10)
     })
     _, _, preprocess, _ = load_custom_numpyro_model(
         './tests/models/preprocess_broken.py', Namespace(), [], orig_data)
     try:
         preprocess(orig_data)
     except ModelException as err:
         wraps_key_error = isinstance(err.base, KeyError)
         if wraps_key_error and 'PREPROCESSING DATA' in err.title:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in preprocess, but did not correctly pass causal exception; got: {err.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for error in preprocess")
Exemplo n.º 12
0
 def test_load_numpyro_model_with_postprocess(self):
     """postprocess turns sample site ``x`` into two DataFrames: the first
     mirrors the samples column-wise, the encoded one is shifted by +2."""
     samples = {'x': np.zeros((10, 2))}
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     feature_names = ['first', 'second']
     _, _, _, postprocess = load_custom_numpyro_model(
         './tests/models/postprocess.py', Namespace(), [], orig_data)
     syn_data, encoded_syn_data = postprocess(samples, orig_data,
                                              feature_names)
     x = samples['x']

     self.assertIsInstance(syn_data, pd.DataFrame)
     for col, name in enumerate(('first', 'second')):
         self.assertTrue(np.allclose(x[:, col], syn_data[name]))

     self.assertIsInstance(encoded_syn_data, pd.DataFrame)
     for col, name in enumerate(('first', 'second')):
         self.assertTrue(np.allclose(x[:, col] + 2, encoded_syn_data[name]))
Exemplo n.º 13
0
 def test_load_numpyro_model_with_preprocess_returns_array(self):
     """A preprocess returning a plain array is rejected with a ModelException
     titled PREPROCESSING DATA whose message states what it 'must return'."""
     orig_data = pd.DataFrame({
         'first': np.ones(10),
         'second': np.zeros(10)
     })
     _, _, preprocess, _ = load_custom_numpyro_model(
         './tests/models/preprocess_returns_array.py', Namespace(), [],
         orig_data)
     try:
         preprocess(orig_data)
     except ModelException as err:
         title_ok = 'PREPROCESSING DATA' in err.title
         if title_ok and 'must return' in err.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for non-dataframe returns, but did not give expected explanation; got:\n{err.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for non-dataframe returns in preprocess"
         )
Exemplo n.º 14
0
 def test_load_numpyro_model_with_broken_postprocess(self):
     """A postprocess raising KeyError must be wrapped in a ModelException
     titled POSTPROCESSING DATA whose base is the original KeyError."""
     samples = {'x': np.zeros((10, 2))}
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     feature_names = ['first', 'second']
     _, _, _, postprocess = load_custom_numpyro_model(
         './tests/models/postprocess_broken.py', Namespace(), [], orig_data)
     try:
         postprocess(samples, orig_data, feature_names)
     except ModelException as err:
         # an exception was raised; now check the causal one was kept
         wraps_key_error = isinstance(err.base, KeyError)
         if wraps_key_error and 'POSTPROCESSING DATA' in err.title:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for error in postprocess, but did not correctly pass causal exception; got: {err.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for error in postprocess")
Exemplo n.º 15
0
 def test_load_numpyro_model_with_postprocess_old_style_wrong_returns(self):
     """An old-style postprocess with a wrong return value is rejected with a
     ModelException titled POSTPROCESSING DATA saying what it 'must return'."""
     samples = {'x': np.zeros((10, 2))}
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     feature_names = ['first', 'second']
     _, _, _, postprocess = load_custom_numpyro_model(
         './tests/models/postprocess_old_style_wrong_returns.py',
         Namespace(), [], orig_data)
     try:
         postprocess(samples, orig_data, feature_names)
     except ModelException as err:
         title_ok = 'POSTPROCESSING DATA' in err.title
         if title_ok and 'must return' in err.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for wrong return value in postprocess, but did not give expected explanation; got:\n{err.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for wrong return value in postprocess"
         )
Exemplo n.º 16
0
 def test_load_numpyro_model_with_postprocess_wrong_signature(self):
     """A postprocess with an unexpected signature is rejected with a
     ModelException titled POSTPROCESSING DATA whose message mentions the
     expected arguments ('as argument')."""
     samples = {'x': np.zeros((10, 2))}
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     feature_names = ['first', 'second']
     _, _, _, postprocess = load_custom_numpyro_model(
         './tests/models/postprocess_wrong_signature.py', Namespace(), [],
         orig_data)
     try:
         postprocess(samples, orig_data, feature_names)
     except ModelException as err:
         # an exception was raised; now check it carries the explanation
         title_ok = 'POSTPROCESSING DATA' in err.title
         if title_ok and 'as argument' in err.msg:
             return
         self.fail(
             f"load_custom_numpyro_model did raise for wrong signature in postprocess, but did not give expected explanation; got:\n{err.format_message('')}"
         )
     else:
         self.fail(
             "load_custom_numpyro_model did not raise for wrong signature in postprocess"
         )
Exemplo n.º 17
0
 def test_load_numpyro_model_with_old_postprocess_but_assumed_new_model(
         self):
     """Feeding per-feature sample sites to an old-style (single-argument)
     postprocess must raise a ModelException titled POSTPROCESSING DATA that
     points at the 'postprocessing function with a single argument'."""
     samples = {'first': np.zeros((10, )), 'second': np.zeros((10, ))}
     orig_data = pd.DataFrame({
         'first': np.zeros(10),
         'second': np.ones(10)
     })
     feature_names = ['first', 'second']
     _, _, _, postprocess = load_custom_numpyro_model(
         './tests/models/postprocess_old_style.py', Namespace(), [],
         orig_data)
     try:
         postprocess(samples, orig_data, feature_names)
     except ModelException as e:
         if ('POSTPROCESSING DATA' in e.title
                 and 'postprocessing function with a single argument' in e.msg):
             return
         self.fail(
             f"load_custom_numpyro_model did raise for wrong sample sites for old-style postprocess, but did not give expected explanation; got:\n{e.format_message('')}"
         )
     self.fail(
         "load_custom_numpyro_model did not raise for wrong sample sites for old-style postprocess"
     )
Exemplo n.º 18
0
 def test_load_numpyro_model_file_not_found(self):
     """A non-existent model path propagates FileNotFoundError unwrapped."""
     self.assertRaises(FileNotFoundError, load_custom_numpyro_model,
                       './tests/models/does_not_exist', Namespace(), [],
                       pd.DataFrame())
Exemplo n.º 19
0
 def test_load_numpyro_model_no_model_fn(self):
     """A module defining neither a model nor a model factory is rejected."""
     self.assertRaisesRegex(ModelException,
                            "does neither specify a 'model'",
                            load_custom_numpyro_model,
                            './tests/models/empty_model.py', Namespace(), [],
                            pd.DataFrame())
Exemplo n.º 20
0
 def test_load_numpyro_model_not_a_module(self):
     """A file that cannot be imported as Python (here a CSV) is rejected."""
     self.assertRaisesRegex(ModelException, "as a Python module",
                            load_custom_numpyro_model,
                            './tests/models/gauss_data.csv', Namespace(), [],
                            pd.DataFrame())
Exemplo n.º 21
0
def main(args: argparse.Namespace, unknown_args: Iterable[str]) -> int:
    """Smoke-test a custom numpyro model file end to end on dummy data.

    Loads the model module at ``args.model_path``, samples the model's prior
    predictive distribution to build dummy data, runs a short non-private
    inference on it, samples from the resulting posterior and finally runs the
    postprocessing step. Each stage is announced on stdout so that a failing
    stage can be identified.

    Args:
        args: parsed command line arguments; this function reads
            ``data_path``, ``drop_na``, ``model_path`` and ``full_traceback``.
            A copy with ``output_path`` set to ``''`` is passed to the loader.
        unknown_args: additional command line arguments forwarded to the
            model file.

    Returns:
        0 if every stage completed, 1 if a ModelException was reported.
        Assertion failures and uncategorised exceptions are re-raised.

    Exits the process with status 1 if the data or model file cannot be found
    or read, and with status 2 if the model path is not a ``.py`` file.
    """
    import sys  # the bare `exit` builtin only exists when `site` is loaded

    # read data
    try:
        df = pd.read_csv(args.data_path)
    except Exception as e:
        print("#### UNABLE TO READ DATA FILE ####")
        print(e)
        sys.exit(1)

    # force output_path to ''; merging the dicts (instead of passing it as an
    # extra keyword) avoids a TypeError if args already has an output_path
    args = argparse.Namespace(**{**vars(args), 'output_path': ''})

    train_df = df.copy()
    if args.drop_na:
        train_df = train_df.dropna()
    num_data = 100  # number of prior predictive samples used as dummy data

    try:
        # loading the model
        if args.model_path.endswith('.py'):
            try:
                model, guide, preprocess_fn, postprocess_fn = load_custom_numpyro_model(
                    args.model_path, args, unknown_args, train_df)
            except (ModuleNotFoundError, FileNotFoundError) as e:
                print("#### COULD NOT FIND THE MODEL FILE ####")
                print(e)
                sys.exit(1)
        else:
            print("#### loading txt file model currently not supported ####")
            sys.exit(2)

        print("Extracting relevant features from data (using preprocess)")
        zeroed_train_data, _, feature_names = preprocess_fn(train_df.iloc[:2])
        zeroed_train_data = tuple(np.zeros_like(x) for x in zeroed_train_data)

        print("Sampling from prior distribution (using model, guide)")
        # We use Predictive with model to sample from the prior predictive distribution. Since this does not involve guide,
        # Predictive has no clue about which of the samples are for observations and which are for parameter values.
        # Since we expect postprocess_fn to deal only with observations, we trace through guide to identify
        # parameter sample sites and filter those out. (To invoke guide we need a small batch of data, for which we
        # use whatever preprocess_fn returned to get the right shapes, but zero it out to prevent information leakage).
        try:
            prior_samples = Predictive(model, num_samples=num_data)(
                jax.random.PRNGKey(0))
        except Exception as e:
            raise ModelException(
                "Error while obtaining prior samples from model",
                base_exception=e)
        try:
            parameter_sites = trace(seed(
                guide, jax.random.PRNGKey(0))).get_trace(*zeroed_train_data)
        except Exception as e:
            # pass the causal exception on, like every other stage does
            raise ModelException(
                "Error while determining the sampling sites of parameter priors",
                base_exception=e)
        parameter_sites = parameter_sites.keys()
        # keep only observation sites; squeeze(1) drops axis 1 of each sample
        # array (presumably a singleton batch axis -- TODO confirm)
        prior_samples = {
            site: samples.squeeze(1)
            for site, samples in prior_samples.items()
            if site not in parameter_sites
        }

        print(
            "Transforming prior samples to output domain to obtain dummy data (using postprocess)"
        )
        _, syn_prior_encoded = postprocess_fn(prior_samples, df, feature_names)

        print("Preprocessing dummy data (using preprocess)")
        train_data, num_train_data, feature_names = preprocess_fn(
            syn_prior_encoded)

        assert isinstance(train_data, tuple)
        assert num_train_data == num_data  # TODO: maybe not?

        print("Inferring model parameters (using model, guide)")
        try:
            posterior_params, _ = train_model_no_dp(
                d3p.random.PRNGKey(0),
                model,
                guide,
                train_data,
                batch_size=num_train_data // 2,
                num_data=num_train_data,
                num_epochs=3,
                silent=True)
        except Exception as e:
            raise ModelException("Error while performing inference",
                                 base_exception=e)

        print("Sampling from posterior distribution (using model, guide)")
        try:
            posterior_samples = sample_synthetic_data(model, guide,
                                                      posterior_params,
                                                      jax.random.PRNGKey(0),
                                                      num_train_data,
                                                      num_train_data)
        except Exception as e:
            raise ModelException(
                "Error while obtaining posterior samples from model",
                base_exception=e)
        print("Postprocessing (using postprocess)")
        conditioned_postprocess_fn = lambda samples: postprocess_fn(
            samples, df, feature_names)
        # The result of reshape_and_postprocess_synthetic_data is iterated by
        # twinify's main. If it is lazy (e.g. a generator), postprocess_fn only
        # runs when the result is consumed, so consume it here to actually
        # exercise the postprocessing stage.
        for _ in reshape_and_postprocess_synthetic_data(
                posterior_samples,
                conditioned_postprocess_fn,
                separate_output=True,
                num_parameter_samples=num_train_data):
            pass

        print("Everything okay!")
        return 0

    except ModelException as e:
        if args.full_traceback:
            print(e)
        else:
            print(e.format_message(args.model_path))
    except AssertionError:
        raise
    except Exception:
        print("#### AN UNCATEGORISED ERROR OCCURRED ####")
        raise
    return 1
Exemplo n.º 22
0
def main():
    """Command line entry point: learn a model from a data set and sample synthetic twins.

    Parses the command line with ``parser`` and reads the CSV at
    ``args.data_path``. It then builds model, guide and pre-/postprocessing,
    either from a custom numpyro module (``.py`` model path) or from a txt
    model description via ``automodel``.

    The model is trained with differential privacy unless ``args.no_privacy``
    is set, and the learned parameters are stored to ``<output_path>.p``.
    Synthetic data is written to ``<output_path>.csv``, or to
    ``<output_path>.<i>.csv`` per parameter sample when
    ``args.separate_output`` is set.

    Returns:
        Process exit status:

        - 0 on success.
        - 1 if the data or model file cannot be read, or a ModelException was
          reported.
        - 2 if inference failed with an InferenceException/FloatingPointError.
        - 4 if the user declined to continue with an unsafe delta.

        Assertion failures and uncategorised exceptions are re-raised.
    """
    args, unknown_args = parser.parse_known_args()
    print(args)
    if unknown_args:
        print(f"Additional received arguments: {unknown_args}")

    # read data
    try:
        df = pd.read_csv(args.data_path)
    except Exception as e:
        print("#### UNABLE TO READ DATA FILE ####")
        print(e)
        return 1
    print("Loaded data set has {} rows (entries) and {} columns (features).".format(*df.shape))
    num_data = len(df)

    try:
        # check whether we parse model from txt or whether we have a numpyro module
        if args.model_path.endswith('.py'):

            train_df = df.copy()
            if args.drop_na:
                train_df = train_df.dropna()

            try:
                model, guide, preprocess_fn, postprocess_fn = load_custom_numpyro_model(args.model_path, args, unknown_args, train_df)
            except (ModuleNotFoundError, FileNotFoundError) as e:
                print("#### COULD NOT FIND THE MODEL FILE ####")
                print(e)
                return 1

            train_data, num_data, feature_names = preprocess_fn(train_df)
        else:
            print("Parsing model from txt file (was unable to read it as python module containing numpyro code)")
            k = args.k
            # read model file
            with open(args.model_path, 'r') as model_handle:
                model_str = model_handle.read()
            features = automodel.parse_model(model_str)
            feature_names = [feature.name for feature in features]

            # pick features from data according to model file
            missing_features = set(feature_names).difference(df.columns)
            if missing_features:
                raise automodel.ParsingError(
                    "The model specifies features that are not present in the data:\n{}".format(
                        ", ".join(missing_features)
                    )
                )

            df = df.loc[:, feature_names]

            train_df = df.copy() # TODO: this duplicates code with the other branch but cannot currently pull it out because we are manipulating df above
            if args.drop_na:
                train_df = train_df.dropna()

            # TODO normalize?

            # data preprocessing: determines number of categories for Categorical
            #   distribution and maps categorical values in the data to ints
            for feature in features:
                train_df = feature.preprocess_data(train_df)

            # build model
            model = automodel.make_model(features, k)

            # build variational guide for optimization
            guide = AutoDiagonalNormal(model)

            # postprocessing for automodel
            postprocess_fn = automodel.postprocess_function_factory(features)
            num_data = train_df.shape[0]
            train_data = (train_df,)

        assert isinstance(train_data, tuple)
        if len(train_data) == 1:
            print("After preprocessing, the data has {} entries with {} features each.".format(*train_data[0].shape))
        else:
            print("After preprocessing, the data was split into {} splits:".format(len(train_data)))
            for i, x in enumerate(train_data):
                print("\tSplit {} has {} entries with {} features each.".format(i, x.shape[0], 1 if x.ndim == 1 else x.shape[1]))

        # compute DP values
        # TODO need to make this fail safely
        batch_size = q_to_batch_size(args.sampling_ratio, num_data)

        if not args.no_privacy:
            target_delta = args.delta
            if target_delta is None:
                target_delta = 1. / num_data
            if target_delta * num_data > 1.:
                print("!!!!! WARNING !!!!! The given value for privacy parameter delta ({:1.3e}) exceeds 1/(number of data) ({:1.3e}),\n"
                      "which is the maximum value that is usually considered safe!".format(
                          target_delta, 1. / num_data
                      ))
                x = input("Continue? (type YES ): ")
                if x != "YES":
                    print("Aborting...")
                    return 4
                print("Continuing... (YOU HAVE BEEN WARNED!)")

            num_compositions = int(args.num_epochs / args.sampling_ratio)
            dp_sigma, epsilon, _ = approximate_sigma_remove_relation(
                args.epsilon, target_delta, args.sampling_ratio, num_compositions
            )
            # reuse the batch size computed above instead of recomputing it
            sigma_per_sample = dp_sigma / batch_size
            print("Will apply noise with std deviation {:.2f} (~ {:.2f} per element in batch) to achieve privacy epsilon "
                  "of {:.3f} (for delta {:.2e}) ".format(dp_sigma, sigma_per_sample, epsilon, target_delta))
            # TODO: warn for high noise? but when is it too high? what is a good heuristic?

            do_training = lambda inference_rng: train_model(
                inference_rng,
                d3p.random,
                model, guide,
                train_data,
                batch_size=batch_size,
                num_data=num_data,
                num_epochs=args.num_epochs,
                dp_scale=dp_sigma,
                clipping_threshold=args.clipping_threshold
            )
        else:
            print("!!!!! WARNING !!!!! PRIVACY FEATURES HAVE BEEN DISABLED!")
            do_training = lambda inference_rng: train_model_no_dp(
                inference_rng,
                model, guide,
                train_data,
                batch_size=batch_size,
                num_data=num_data,
                num_epochs=args.num_epochs
            )

        inference_rng, sampling_rng = initialize_rngs(args.seed)

        # learn posterior distributions
        try:
            posterior_params, elbo = do_training(inference_rng)
        except (InferenceException, FloatingPointError):
            print("################################## ERROR ##################################")
            print("!!!!! The inference procedure encountered a NaN value (not a number). !!!!!")
            print("This means the model has major difficulties in capturing the data and is")
            print("likely to happen when the dataset is very small and/or sparse.")
            print("Try adapting (simplifying) the model.")
            print("Aborting...")
            return 2

        # Store learned model parameters
        # TODO: we should have a mode for twinify that allows to rerun the sampling without training, using stored parameters
        store_twinify_run_result(f"{args.output_path}.p", posterior_params, elbo, args, unknown_args, __version__)

        # sample synthetic data
        print("Model learning complete; now sampling data!")
        num_synthetic = args.num_synthetic
        if num_synthetic is None:
            num_synthetic = num_data

        # round up to a whole number of parameter samples
        num_parameter_samples = int(np.ceil(num_synthetic / args.num_synthetic_records_per_parameter_sample))
        num_synthetic = num_parameter_samples * args.num_synthetic_records_per_parameter_sample
        print(f"Will sample {args.num_synthetic_records_per_parameter_sample} synthetic data records for each of "
              f"{num_parameter_samples} samples from the parameter posterior for a total of {num_synthetic} records.")
        if args.separate_output:
            print("They will be stored in separate data sets for each parameter posterior sample.")
        else:
            print("They will be stored in a single large data set.")
        posterior_samples = sample_synthetic_data(
            model, guide, posterior_params, sampling_rng, num_parameter_samples, args.num_synthetic_records_per_parameter_sample
        )

        # postprocess: so that the synthetic twin looks like the original data
        #   - extract samples from the posterior_samples dictionary and construct pd.DataFrame
        #   - if preprocessing involved data mapping, it is mapped back here
        conditioned_postprocess_fn = lambda posterior_samples: postprocess_fn(posterior_samples, df, feature_names)
        for i, (syn_df, encoded_syn_df) in enumerate(reshape_and_postprocess_synthetic_data(
            posterior_samples, conditioned_postprocess_fn, args.separate_output, num_parameter_samples
        )):
            if args.separate_output:
                filename = f"{args.output_path}.{i}.csv"
            else:
                filename = f"{args.output_path}.csv"
            encoded_syn_df.to_csv(filename, index=False)

        ### illustrate results TODO need to adopt new way of handing train_df
        #if args.visualize != 'none':
        #    show_popups = args.visualize in ('popup', 'both')
        #    save_plots = args.visualize in ('store', 'both')
        #    # Missing value rate
        #    if not args.drop_na:
        #        missing_value_fig = plot_missing_values(syn_df, train_df, show=show_popups)
        #        if save_plots:
        #            missing_value_fig.savefig(args.output_path + "_missing_value_plots.svg")
        #    # Marginal violins
        #    margin_fig = plot_margins(syn_df, train_df, show=show_popups)
        #    # Covariance matrices
        #    cov_fig = plot_covariance_heatmap(syn_df, train_df, show=show_popups)
        #    if save_plots:
        #        margin_fig.savefig(args.output_path + "_marginal_plots.svg")
        #        cov_fig.savefig(args.output_path + "_correlation_plots.svg")
        #    if show_popups:
        #        plt.show()
        return 0
    except ModelException as e:
        print(e.format_message(args.model_path))
    except AssertionError:
        raise
    except Exception:
        print("#### AN UNCATEGORISED ERROR OCCURRED ####")
        raise
    return 1