Example 1
    def r_fn(args):
        beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, events_ = args
        t = events_.shape[-2] - 1
        state = compute_state(init_state, events_, model_spec.STOICHIOMETRY)
        state = tf.gather(state, t, axis=-2)  # State on final inference day

        model = model_spec.CovidUK(
            covariates=covar_data,
            initial_state=init_state,
            initial_step=0,
            num_steps=events_.shape[-2],
            priors=priors,
        )

        xi_pred = model_spec.conditional_gp(
            model.model["xi"](beta1_, sigma_),
            xi_,
            tf.constant([events.shape[-2] + model_spec.XI_FREQ],
                        dtype=model_spec.DTYPE)[:, tf.newaxis],
        )

        par = dict(
            beta1=beta1_,
            beta2=beta2_,
            beta3=beta3_,
            sigma=sigma_,
            gamma0=gamma0_,
            xi=xi_,  # NOTE: xi_pred computed above is currently unused
        )
        print("xi shape:", par["xi"].shape)
        ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
        ngm = ngm_fn(t, state)
        return ngm
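
In use, `r_fn` is mapped over batched posterior draws; a minimal usage sketch, assuming `param` is a dict of posterior sample tensors and `events` is the batched events tensor (this mirrors the `tf.vectorized_map` call in Example 10):

    ngms = tf.vectorized_map(
        r_fn,
        elems=(
            param["beta1"],
            param["beta2"],
            param["beta3"],
            param["sigma"],
            param["xi"],
            param["gamma0"],
            events,
        ),
    )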
Example 2
    def sim_fn(args):
        theta_, xi_, init_ = args

        par = dict(beta1=theta_[0], beta2=theta_[1], gamma=theta_[2], xi=xi_)

        model = model_spec.CovidUK(
            covar_data,
            initial_state=init_,
            initial_step=init_step,
            num_steps=num_steps,
        )
        sim = model.sample(**par)
        return sim["seir"]
Example 3
    def sim_fn(args):
        beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, gamma1_, init_ = args

        par = dict(
            beta1=beta1_,
            beta2=beta2_,
            beta3=beta3_,
            gamma0=gamma0_,
            gamma1=gamma1_,
            xi=xi_,
        )

        model = model_spec.CovidUK(
            covar_data,
            initial_state=init_,
            initial_step=init_step,
            num_steps=num_steps,
            priors=priors,
        )
        sim = model.sample(**par)
        return sim["seir"]
Example 4
        def sim_fn(args):
            beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, gamma1_, init_ = args

            # FNC NOTE:
            # padding beta3 with two extra 0.0 entries, as TF otherwise
            # complains of a dimension mismatch
            par = dict(
                beta1=beta1_,
                beta2=beta2_,
                beta3=tf.concat([beta3_, [0.0, 0.0]], axis=-1),
                gamma0=gamma0_,
                gamma1=gamma1_,
                xi=xi_,
            )

            model = model_spec.CovidUK(
                covar_data,
                initial_state=init_,
                initial_step=init_step,
                num_steps=num_steps,
                priors=priors,
            )
            sim = model.sample(**par)
            return sim["seir"]
Example 5
        def r_fn(args):
            beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, events_ = args
            t = events_.shape[-2] - 1
            state = compute_state(init_state, events_,
                                  model_spec.STOICHIOMETRY)
            state = tf.gather(state, t,
                              axis=-2)  # State on final inference day

            model = model_spec.CovidUK(
                covariates=covar_data,
                initial_state=init_state,
                initial_step=0,
                num_steps=events_.shape[-2],
                priors=priors,
            )

            xi_pred = model_spec.conditional_gp(
                model.model["xi"](beta1_, sigma_),
                xi_,
                tf.constant([events.shape[-2] + model_spec.XI_FREQ],
                            dtype=model_spec.DTYPE)[:, tf.newaxis],
            )

            # FNC NOTE:
            # padding beta3 with two extra 0.0 entries, as TF otherwise
            # complains of a dimension mismatch
            par = dict(
                beta1=beta1_,
                beta2=beta2_,
                beta3=tf.concat([beta3_, [0.0, 0.0]], axis=-1),
                sigma=sigma_,
                gamma0=gamma0_,
                xi=xi_,  # tf.reshape(xi_pred.sample(), [1]),
            )
            print("xi shape:", par["xi"].shape)
            ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
            ngm = ngm_fn(t, state)
            return ngm
Example 6
# We load in cases and impute missing infections first, since this sets the
# time epoch which we are analysing.
cases = model_spec.read_cases(config["data"]["reported_cases"])

# Single imputation of censored data
events = model_spec.impute_censored_events(cases)

# Initial conditions S(0), E(0), I(0), R(0) are obtained by computing
# the state at the beginning of the inference period
state = compute_state(
    initial_state=tf.concat(
        [covar_data["N"], tf.zeros_like(events[:, 0, :])], axis=-1),
    events=events,
    stoichiometry=model_spec.STOICHIOMETRY,
)
start_time = state.shape[1] - cases.shape[1]
initial_state = state[:, start_time, :]
events = events[:, start_time:, :]

# Build model and sample
full_probability_model = model_spec.CovidUK(
    covariates=covar_data,
    initial_state=initial_state,
    initial_step=0,
    num_steps=80,
)
seir = full_probability_model.model["seir"](beta1=0.35,
                                            beta2=0.65,
                                            xi=[0.0] * 5,
                                            gamma=0.49)
sim = seir.sample()
Example 7
    param = dict(
        beta1=posterior["samples/beta1"][idx],
        beta2=posterior["samples/beta2"][idx],
        beta3=posterior["samples/beta3"][idx],
        sigma=posterior["samples/sigma"][idx],
        xi=posterior["samples/xi"][idx],
        gamma0=posterior["samples/gamma0"][idx],
        gamma1=posterior["samples/gamma1"][idx],
    )
    events = posterior["samples/events"][idx]
    init_state = posterior["initial_state"][:]
    state_timeseries = compute_state(init_state, events,
                                     model_spec.STOICHIOMETRY)

    # Build model
    model = model_spec.CovidUK(
        covar_data,
        initial_state=init_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=config["mcmc"]["prior"],
    )

    ngms = calc_R_it(param, events, init_state, covar_data,
                     config["mcmc"]["prior"])
    b, _ = power_iteration(ngms)
    rt = rayleigh_quotient(ngms, b)
    q = np.arange(0.05, 1.0, 0.05)
    rt_quantiles = pd.DataFrame({
        "Rt": np.quantile(rt, q, axis=-1)
    }, index=q).T
    rt_quantiles.to_excel(
        os.path.join(config["output"]["results_dir"],
                     config["output"]["national_rt"]))
Example 8
        stoichiometry=STOICHIOMETRY,
    )
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    xi_freq = 14
    num_xi = events.shape[1] // xi_freq
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
    model = model_spec.CovidUK(
        covariates=covar_data,
        xi_freq=xi_freq,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
    )

    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
    def logp(theta, xi, events):
        return model.log_prob(
            dict(
                beta1=theta[0],
                beta2=theta[1],
                gamma=theta[2],
                xi=xi,
                nu=0.5,  # Fixed!
                seir=events,
Example 9
        rdcc_nslots=1e6,
    )

    # Pre-determined thinning of posterior (better done in MCMC?)
    idx = slice(posterior["samples/theta"].shape[0])  # range(6000, 10000, 10)
    theta = posterior["samples/theta"][idx]
    xi = posterior["samples/xi"][idx]
    events = posterior["samples/events"][idx]
    init_state = posterior["initial_state"][:]
    state_timeseries = compute_state(init_state, events,
                                     model_spec.STOICHIOMETRY)

    # Build model
    model = model_spec.CovidUK(
        covar_data,
        initial_state=init_state,
        initial_step=0,
        num_steps=events.shape[1],
    )

    ngms = calc_R_it(theta, xi, events, init_state, covar_data)
    b, _ = power_iteration(ngms)
    rt = rayleigh_quotient(ngms, b)
    q = np.arange(0.05, 1.0, 0.05)
    rt_quantiles = np.stack([q, np.quantile(rt, q)], axis=-1)

    # Prediction requires simulation from the last available timepoint for 28 + 4 + 1 days
    # Note a 4 day recording lag in the case timeseries data requires that
    # now = state_timeseries.shape[-2] + 4
    prediction = predicted_incidence(
        theta,
        xi,
Example 10
def runSummary(pipelineData):
    # Pipeline data should contain config at a minimum
    config = pipelineData['config']
    settings = config['SummaryData']

    if settings['input'] == 'processed':
        summaryData = GetData.SummaryData.process(config)
        pipelineData['summary'] = summaryData
        return pipelineData

    # As we're running inside a function, we need to assign covar_data before
    # defining the functions that call it, so that it is in scope.
    # Previously, covar_dict was defined in the __name__ == '__main__' portion
    # of this script; moving to a pipeline necessitates this change.
    # grab all data from dicts
    # inference_period = config['dates']['inference_period']
    # date_low = config['dates']['low']
    # date_high = config['dates']['high']
    # weekday = config['dates']['weekday']

    if 'covar_data' in pipelineData:
        covar_data = pipelineData['covar_data']
    else:
        covar_data, tmp = GetData.CovarData(config)

    # Reproduction number calculation
    def calc_R_it(param, events, init_state, covar_data, priors):
        """Calculates effective reproduction number for batches of metapopulations
        :param theta: a tensor of batched theta parameters [B] + theta.shape
        :param xi: a tensor of batched xi parameters [B] + xi.shape
        :param events: a [B, M, T, X] batched events tensor
        :param init_state: the initial state of the epidemic at earliest inference date
        :param covar_data: the covariate data
        :return a batched vector of R_it estimates
        """
        def r_fn(args):
            beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, events_ = args
            t = events_.shape[-2] - 1
            state = compute_state(init_state, events_,
                                  model_spec.STOICHIOMETRY)
            state = tf.gather(state, t,
                              axis=-2)  # State on final inference day

            model = model_spec.CovidUK(
                covariates=covar_data,
                initial_state=init_state,
                initial_step=0,
                num_steps=events_.shape[-2],
                priors=priors,
            )

            xi_pred = model_spec.conditional_gp(
                model.model["xi"](beta1_, sigma_),
                xi_,
                tf.constant([events.shape[-2] + model_spec.XI_FREQ],
                            dtype=model_spec.DTYPE)[:, tf.newaxis],
            )

            # FNC NOTE:
            # padding beta3 with two extra 0.0 entries, as TF otherwise
            # complains of a dimension mismatch
            par = dict(
                beta1=beta1_,
                beta2=beta2_,
                beta3=tf.concat([beta3_, [0.0, 0.0]], axis=-1),
                sigma=sigma_,
                gamma0=gamma0_,
                xi=xi_,  # tf.reshape(xi_pred.sample(), [1]),
            )
            print("xi shape:", par["xi"].shape)
            ngm_fn = model_spec.next_generation_matrix_fn(covar_data, par)
            ngm = ngm_fn(t, state)
            return ngm

        return tf.vectorized_map(
            r_fn,
            elems=(
                param["beta1"],
                param["beta2"],
                param["beta3"],
                param["sigma"],
                param["xi"],
                param["gamma0"],
                events,
            ),
        )

    @tf.function
    def predicted_incidence(param, init_state, init_step, num_steps, priors):
        """Runs the simulation forward in time from `init_state` at time `init_time`
        for `num_steps`.
        :param theta: a tensor of batched theta parameters [B] + theta.shape
        :param xi: a tensor of batched xi parameters [B] + xi.shape
        :param events: a [B, M, S] batched state tensor
        :param init_step: the initial time step
        :param num_steps: the number of steps to simulate
        :param priors: the priors for gamma
        :returns: a tensor of srt_quhape [B, M, num_steps, X] where X is the number of state
                transitions
        """
        def sim_fn(args):
            beta1_, beta2_, beta3_, sigma_, xi_, gamma0_, gamma1_, init_ = args

            # FNC NOTE:
            # padding beta3 with two extra 0.0 entries, as TF otherwise
            # complains of a dimension mismatch
            par = dict(
                beta1=beta1_,
                beta2=beta2_,
                beta3=tf.concat([beta3_, [0.0, 0.0]], axis=-1),
                gamma0=gamma0_,
                gamma1=gamma1_,
                xi=xi_,
            )

            model = model_spec.CovidUK(
                covar_data,
                initial_state=init_,
                initial_step=init_step,
                num_steps=num_steps,
                priors=priors,
            )
            sim = model.sample(**par)
            return sim["seir"]

        events = tf.map_fn(
            sim_fn,
            elems=(
                param["beta1"],
                param["beta2"],
                param["beta3"],
                param["sigma"],
                param["xi"],
                param["gamma0"],
                param["gamma1"],
                init_state,
            ),
            fn_output_signature=tf.float64,
        )
        return events

    # Today's prevalence
    def prevalence(predicted_state, population_size, name=None):
        """Computes prevalence of E and I individuals

        :param predicted_state: the state at a particular timepoint [batch, M, S]
        :param population_size: the size of the population
        :returns: a dict of mean and 95% credibility intervals for prevalence
                in units of infections per person
        """
        prev = tf.reduce_sum(predicted_state[:, :, 1:3],
                             axis=-1) / tf.squeeze(population_size)
        return mean_and_ci(prev, name=name)
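
    # `mean_and_ci` is an imported helper; a plausible sketch of its shape,
    # assuming quantile-based 95% credibility intervals (hypothetical -- the
    # real implementation may differ):
    #
    #     def mean_and_ci(arr, name=None):
    #         return {
    #             f"{name}_mean": np.mean(arr, axis=0),
    #             f"{name}_2.5%": np.quantile(arr, 0.025, axis=0),
    #             f"{name}_97.5%": np.quantile(arr, 0.975, axis=0),
    #         }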

    def predicted_events(events, name=None):
        num_events = tf.reduce_sum(events, axis=-1)
        return mean_and_ci(num_events, name=name)

    # Load posterior file
    posterior_path = config['PosteriorData']['address']
    print("Using posterior:", posterior_path)
    posterior = h5py.File(
        os.path.expandvars(posterior_path),
        "r",
        rdcc_nbytes=1024**3,
        rdcc_nslots=1e6,
    )

    # Pre-determined thinning of posterior (better done in MCMC?)
    if posterior["samples/beta1"].size >= 10000:
        idx = range(6000, 10000, 10)
    else:
        print('Using smaller MCMC sample range')
        print('Size of posterior["samples/beta1"] is',
              posterior["samples/beta1"].size)
        idx = range(600, 1000, 10)
    param = dict(
        beta1=posterior["samples/beta1"][idx],
        beta2=posterior["samples/beta2"][idx],
        beta3=posterior["samples/beta3"][idx],
        sigma=posterior["samples/sigma"][idx],
        xi=posterior["samples/xi"][idx],
        gamma0=posterior["samples/gamma0"][idx],
        gamma1=posterior["samples/gamma1"][idx],
    )
    events = posterior["samples/events"][idx]
    init_state = posterior["initial_state"][:]
    state_timeseries = compute_state(init_state, events,
                                     model_spec.STOICHIOMETRY)

    # Build model
    model = model_spec.CovidUK(
        covar_data,
        initial_state=init_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=config["mcmc"]["prior"],
    )

    ngms = calc_R_it(param, events, init_state, covar_data,
                     config["mcmc"]["prior"])
    b, _ = power_iteration(ngms)
    rt = rayleigh_quotient(ngms, b)
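
    # `power_iteration` and `rayleigh_quotient` are imported helpers: R_t is
    # taken as the dominant eigenvalue of each next-generation matrix, which
    # power iteration approximates.  A minimal single-matrix sketch under that
    # assumption (hypothetical names, not the imported implementations):
    def _power_iteration_sketch(ngm, num_iter=50):
        b = tf.ones([ngm.shape[-1]], dtype=ngm.dtype)
        for _ in range(num_iter):
            b = tf.linalg.matvec(ngm, b)  # apply the matrix...
            b = b / tf.norm(b)  # ...and renormalise
        return b  # approximate dominant eigenvector

    def _rayleigh_quotient_sketch(ngm, b):
        # b.(Ab) / b.b estimates the dominant eigenvalue, i.e. R_t
        return tf.reduce_sum(b * tf.linalg.matvec(ngm, b)) / tf.reduce_sum(b * b)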
    q = np.arange(0.05, 1.0, 0.05)

    # FNC Note: removed dict from this and
    # instead added Rt as a sheet name in the excel writer
    rt_quantiles = pd.DataFrame(np.quantile(rt, q, axis=-1), index=q).T
    rt_quantiles.to_excel(config['RtQuantileData']['address'], sheet_name='Rt')

    # Prediction requires simulation from the last available timepoint for 28 + 4 + 1 days
    # Note a 4 day recording lag in the case timeseries data requires that
    # now = state_timeseries.shape[-2] + 4
    prediction = predicted_incidence(
        param,
        init_state=state_timeseries[..., -1, :],
        init_step=state_timeseries.shape[-2] - 1,
        num_steps=70,
        priors=config["mcmc"]["prior"],
    )
    predicted_state = compute_state(state_timeseries[..., -1, :], prediction,
                                    model_spec.STOICHIOMETRY)

    # Prevalence now
    prev_now = prevalence(predicted_state[..., 4, :],
                          covar_data["N"],
                          name="prev")

    # Incidence of detections now
    cases_now = predicted_events(prediction[..., 4:5, 2], name="cases")

    # Incidence from now to now+7
    cases_7 = predicted_events(prediction[..., 4:11, 2], name="cases7")
    cases_14 = predicted_events(prediction[..., 4:18, 2], name="cases14")
    cases_21 = predicted_events(prediction[..., 4:25, 2], name="cases21")
    cases_28 = predicted_events(prediction[..., 4:32, 2], name="cases28")
    cases_56 = predicted_events(prediction[..., 4:60, 2], name="cases56")
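
    # The slice offsets above encode the 4-day recording lag: index 4 is "now",
    # so an h-day-ahead case window is prediction[..., 4:4 + h, 2].  A
    # hypothetical helper (not in the original code) makes this explicit:
    #     _horizon_slice = lambda h, lag=4: slice(lag, lag + h)
    #     _horizon_slice(7) == slice(4, 11)  # the cases_7 window above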

    # Prevalence at day 7
    prev_7 = prevalence(predicted_state[..., 11, :],
                        covar_data["N"],
                        name="prev7")
    prev_14 = prevalence(predicted_state[..., 18, :],
                         covar_data["N"],
                         name="prev14")
    prev_21 = prevalence(predicted_state[..., 25, :],
                         covar_data["N"],
                         name="prev21")
    prev_28 = prevalence(predicted_state[..., 32, :],
                         covar_data["N"],
                         name="prev28")
    prev_56 = prevalence(predicted_state[..., 60, :],
                         covar_data["N"],
                         name="prev56")

    # Package up summary data
    # this will be saved into a pickle
    # Add LADs in for later reference
    summaryData = {
        'cases': {
            'now': cases_now,
            '7': cases_7,
            '14': cases_14,
            '21': cases_21,
            '28': cases_28,
            '56': cases_56
        },
        'prev': {
            'now': prev_now,
            '7': prev_7,
            '14': prev_14,
            '21': prev_21,
            '28': prev_28,
            '56': prev_56
        },
        'metrics': {
            'ngms': ngms,
            'b': b,
            'rt': rt,
            'q': q
        },
        'LADs': config['lad19cds']
    }

    # Save and pass on the output data
    if config['GenerateOutput']['summary']:
        settings = config['SummaryData']
        if settings['format'] == 'pickle':
            fn = settings['address']
            with open(fn, 'wb') as file:
                pickle.dump(summaryData, file)
    pipelineData['summary'] = summaryData
    return pipelineData
Example 11
def runInference(pipelineData):

    # Read in settings
    config = pipelineData['config']

    # Extract data
    if 'covar_data' in pipelineData:
        covar_data = pipelineData['covar_data']
    else:
        covar_data, data = GetData.CovarData(config)
        pipelineData['covar_data'] = covar_data
        pipelineData['data'] = data
    # inference_period = config['dates']['inference_period']
    # date_low = config['dates']['low']
    # date_high = config['dates']['high']
    # weekday = config['dates']['weekday']

    # We load in cases and impute missing infections first, since this sets the
    # time epoch which we are analysing.
    cases = pipelineData['data']['cases_wide']

    # Impute censored events, return cases
    events = model_spec.impute_censored_events(cases)

    # Initial conditions are calculated by calculating the state
    # at the beginning of the inference period
    #
    # Imputed censored events that pre-date the first I-R events
    # in the cases dataset are discarded.  They are only used to
    # to set up a sensible initial state.
    _initial_state = tf.concat(
        [covar_data["N"], tf.zeros_like(events[:, 0, :])], axis=-1)
    state = compute_state(
        initial_state=_initial_state,
        events=events,
        stoichiometry=model_spec.STOICHIOMETRY,
    )
    start_time = state.shape[1] - cases.shape[1]
    initial_state = state[:, start_time, :]
    events = events[:, start_time:, :]
    num_metapop = covar_data["N"].shape[0]

    ########################################################
    # Build the model, and then construct the MCMC kernels #
    ########################################################
    def convert_priors(node):
        if isinstance(node, dict):
            for k, v in node.items():
                node[k] = convert_priors(v)
            return node
        return float(node)
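
    # For example (hypothetical key names):
    #     convert_priors({"gamma": {"concentration": 2, "rate": "0.5"}})
    #     == {"gamma": {"concentration": 2.0, "rate": 0.5}}
    # i.e. every leaf of the YAML-loaded prior config is coerced to float.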

    model = model_spec.CovidUK(
        covariates=covar_data,
        initial_state=initial_state,
        initial_step=0,
        num_steps=events.shape[1],
        priors=convert_priors(config["mcmc"]["prior"]),
    )

    # Full joint log posterior distribution
    # $\pi(\theta, \xi, y^{se}, y^{ei} | y^{ir})$
    # FNC NOTE:
    # padding beta3 with two extra 0.0 entries, as TF otherwise complains of a
    # dimension mismatch
    def logp(block0, block1, events):
        return model.log_prob(
            dict(
                beta2=block0[0],
                gamma0=block0[1],
                gamma1=block0[2],
                sigma=block0[3],
                beta3=tf.concat([block0[4:6], [0.0, 0.0]], axis=-1),
                beta1=block1[0],
                xi=block1[1:],
                seir=events,
            ))
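    # Parameter packing assumed by logp (and mirrored when the posterior file
    # is written below):
    #   block0 = [beta2, gamma0, gamma1, sigma, beta3[0], beta3[1]]
    #   block1 = [beta1, xi[0], ..., xi[-1]]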

    # Build Metropolis within Gibbs sampler
    #
    # Kernels are:
    #     Q(\theta, \theta^\prime)
    #     Q(\xi, \xi^\prime)
    #     Q(Z^{se}, Z^{se\prime}) (partially-censored)
    #     Q(Z^{ei}, Z^{ei\prime}) (partially-censored)
    #     Q(Z^{se}, Z^{se\prime}) (occult)
    #     Q(Z^{ei}, Z^{ei\prime}) (occult)
    def make_blk0_kernel(shape, name):
        def fn(target_log_prob_fn, _):
            return tfp.mcmc.TransformedTransitionKernel(
                inner_kernel=AdaptiveRandomWalkMetropolis(
                    target_log_prob_fn=target_log_prob_fn,
                    initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE)
                    * 1e-1,
                    covariance_burnin=200,
                ),
                bijector=tfp.bijectors.Blockwise(
                    bijectors=[
                        tfp.bijectors.Exp(),       # beta2 (positive)
                        tfp.bijectors.Identity(),  # gamma0, gamma1
                        tfp.bijectors.Exp(),       # sigma (positive)
                        tfp.bijectors.Identity(),  # beta3
                    ],
                    block_sizes=[1, 2, 1, 2],
                ),
                name=name,
            )

        return fn

    def make_blk1_kernel(shape, name):
        def fn(target_log_prob_fn, _):
            return AdaptiveRandomWalkMetropolis(
                target_log_prob_fn=target_log_prob_fn,
                initial_covariance=np.eye(shape[0], dtype=model_spec.DTYPE) *
                1e-1,
                covariance_burnin=200,
                name=name,
            )

        return fn

    def make_partially_observed_step(target_event_id,
                                     prev_event_id=None,
                                     next_event_id=None,
                                     name=None):
        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedEventTimesUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    target_event_id=target_event_id,
                    prev_event_id=prev_event_id,
                    next_event_id=next_event_id,
                    initial_state=initial_state,
                    dmax=config["mcmc"]["dmax"],
                    mmax=config["mcmc"]["m"],
                    nmax=config["mcmc"]["nmax"],
                ),
                name=name,
            )

        return fn

    def make_occults_step(prev_event_id, target_event_id, next_event_id, name):
        def fn(target_log_prob_fn, _):
            return tfp.mcmc.MetropolisHastings(
                inner_kernel=UncalibratedOccultUpdate(
                    target_log_prob_fn=target_log_prob_fn,
                    topology=TransitionTopology(prev_event_id, target_event_id,
                                                next_event_id),
                    cumulative_event_offset=initial_state,
                    nmax=config["mcmc"]["occult_nmax"],
                    t_range=(events.shape[1] - 21, events.shape[1]),
                    name=name,
                ),
                name=name,
            )

        return fn

    def make_event_multiscan_kernel(target_log_prob_fn, _):
        return MultiScanKernel(
            config["mcmc"]["num_event_time_updates"],
            GibbsKernel(
                target_log_prob_fn=target_log_prob_fn,
                kernel_list=[
                    (0, make_partially_observed_step(0, None, 1, "se_events")),
                    (0, make_partially_observed_step(1, 0, 2, "ei_events")),
                    (0, make_occults_step(None, 0, 1, "se_occults")),
                    (0, make_occults_step(0, 1, 2, "ei_occults")),
                ],
                name="gibbs1",
            ),
        )

    # MCMC tracing functions
    def trace_results_fn(_, results):
        """Packs results into a dictionary"""
        results_dict = {}
        res0 = results.inner_results

        results_dict["block0"] = {
            "is_accepted":
            res0[0].inner_results.is_accepted,
            "target_log_prob":
            res0[0].inner_results.accepted_results.target_log_prob,
        }
        results_dict["block1"] = {
            "is_accepted": res0[1].is_accepted,
            "target_log_prob": res0[1].accepted_results.target_log_prob,
        }

        def get_move_results(results):
            return {
                "is_accepted":
                results.is_accepted,
                "target_log_prob":
                results.accepted_results.target_log_prob,
                "proposed_delta":
                tf.stack([
                    results.accepted_results.m,
                    results.accepted_results.t,
                    results.accepted_results.delta_t,
                    results.accepted_results.x_star,
                ]),
            }

        res1 = res0[2].inner_results
        results_dict["move/S->E"] = get_move_results(res1[0])
        results_dict["move/E->I"] = get_move_results(res1[1])
        results_dict["occult/S->E"] = get_move_results(res1[2])
        results_dict["occult/E->I"] = get_move_results(res1[3])

        return results_dict

    # Build MCMC algorithm here.  This will be run in bursts for memory economy
    @tf.function(autograph=False, experimental_compile=True)
    def sample(n_samples, init_state, thin=0, previous_results=None):
        with tf.name_scope("main_mcmc_sample_loop"):

            init_state = init_state.copy()

            gibbs_schema = GibbsKernel(
                target_log_prob_fn=logp,
                kernel_list=[
                    (0, make_blk0_kernel(init_state[0].shape, "block0")),
                    (1, make_blk1_kernel(init_state[1].shape, "block1")),
                    (2, make_event_multiscan_kernel),
                ],
                name="gibbs0",
            )

            samples, results, final_results = tfp.mcmc.sample_chain(
                n_samples,
                init_state,
                kernel=gibbs_schema,
                num_steps_between_results=thin,
                previous_kernel_results=previous_results,
                return_final_kernel_results=True,
                trace_fn=trace_results_fn,
            )

            return samples, results, final_results

    ####################################
    # Construct bursted MCMC loop here #
    ####################################

    # MCMC Control
    NUM_BURSTS = int(config["mcmc"]["num_bursts"])
    NUM_BURST_SAMPLES = int(config["mcmc"]["num_burst_samples"])
    NUM_EVENT_TIME_UPDATES = int(config["mcmc"]["num_event_time_updates"])
    NUM_SAVED_SAMPLES = NUM_BURST_SAMPLES * NUM_BURSTS

    # RNG stuff
    tf.random.set_seed(2)

    current_state = [
        np.array([0.2, 0.0, 0.0, 0.1, 0.0, 0.0], dtype=DTYPE),
        np.zeros(
            model.model["xi"](0.0, 0.1).event_shape[-1]
            # + model.model["beta3"]().event_shape[-1]
            + 1,
            dtype=DTYPE,
        ),
        events,
    ]
    print("Initial logpi:", logp(*current_state))

    # Output file
    samples, results, _ = sample(1, current_state)
    print('Storing posterior data at', config["PosteriorData"]["address"])
    posterior = Posterior(
        config["PosteriorData"]["address"],
        sample_dict={
            "beta2": (samples[0][:, 0], (NUM_BURST_SAMPLES, )),
            "gamma0": (samples[0][:, 1], (NUM_BURST_SAMPLES, )),
            "gamma1": (samples[0][:, 2], (NUM_BURST_SAMPLES, )),
            "sigma": (samples[0][:, 3], (NUM_BURST_SAMPLES, )),
            "beta3": (samples[0][:, 4:], (NUM_BURST_SAMPLES, 2)),
            "beta1": (samples[1][:, 0], (NUM_BURST_SAMPLES, )),
            "xi": (
                samples[1][:, 1:],
                (NUM_BURST_SAMPLES, samples[1].shape[1] - 1),
            ),
            "events": (samples[2], (NUM_BURST_SAMPLES, 43, 43,
                                    1)),  # Change so it adapts size correctly
        },
        results_dict=results,
        num_samples=NUM_SAVED_SAMPLES,
    )
    posterior._file.create_dataset("initial_state", data=initial_state)
    posterior._file.create_dataset("config", data=yaml.dump(config))

    # We loop over successive calls to sample because we have to dump results
    #   to disc, or else end OOM (even on a 32GB system).
    # with tf.profiler.experimental.Profile("/tmp/tf_logdir"):
    final_results = None
    for i in tqdm.tqdm(range(NUM_BURSTS),
                       unit_scale=NUM_BURST_SAMPLES * config["mcmc"]["thin"]):
        samples, results, final_results = sample(
            NUM_BURST_SAMPLES,
            init_state=current_state,
            thin=config["mcmc"]["thin"] - 1,
            previous_results=final_results,
        )
        current_state = [s[-1] for s in samples]
        print(current_state[0].numpy(), flush=True)

        start = perf_counter()
        posterior.write_samples(
            {
                "beta2": samples[0][:, 0],
                "gamma0": samples[0][:, 1],
                "gamma1": samples[0][:, 2],
                "sigma": samples[0][:, 3],
                "beta3": samples[0][:, 4:],
                "beta1": samples[1][:, 0],
                "xi": samples[1][:, 1:],
                "events": samples[2],
            },
            first_dim_offset=i * NUM_BURST_SAMPLES,
        )
        posterior.write_results(results,
                                first_dim_offset=i * NUM_BURST_SAMPLES)
        end = perf_counter()

        print("Storage time:", end - start, "seconds")
        print(
            "Acceptance theta:",
            tf.reduce_mean(
                tf.cast(results["block0"]["is_accepted"], tf.float32)),
        )
        print(
            "Acceptance xi:",
            tf.reduce_mean(
                tf.cast(results["block1"]["is_accepted"], tf.float32), ),
        )
        print(
            "Acceptance move S->E:",
            tf.reduce_mean(
                tf.cast(results["move/S->E"]["is_accepted"], tf.float32)),
        )
        print(
            "Acceptance move E->I:",
            tf.reduce_mean(
                tf.cast(results["move/E->I"]["is_accepted"], tf.float32)),
        )
        print(
            "Acceptance occult S->E:",
            tf.reduce_mean(
                tf.cast(results["occult/S->E"]["is_accepted"], tf.float32)),
        )
        print(
            "Acceptance occult E->I:",
            tf.reduce_mean(
                tf.cast(results["occult/E->I"]["is_accepted"], tf.float32)),
        )

    print(
        f"Acceptance theta: {posterior['results/block0/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance xi: {posterior['results/block1/is_accepted'][:].mean()}")
    print(
        f"Acceptance move S->E: {posterior['results/move/S->E/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance move E->I: {posterior['results/move/E->I/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance occult S->E: {posterior['results/occult/S->E/is_accepted'][:].mean()}"
    )
    print(
        f"Acceptance occult E->I: {posterior['results/occult/E->I/is_accepted'][:].mean()}"
    )

    del posterior