Ejemplo n.º 1
0
def construct_noise(name, modelParams, sigma=0.05, sigma_age=0.02):
    r"""
    Constructs additive noise for the reproduction number as a random walk
    over time, with a country-level and a country-by-age-group component:

    .. math::

        \sigma_c &\sim HalfNormal\left(\text{sigma}\right)\quad \forall c,\\
        \sigma_{c,a} &\sim HalfNormal\left(\text{sigma\_age}\right)\quad \forall c, a,\\
        \epsilon_{t,c} &= \sigma_c \, z_{t,c},\quad z_{t,c} \sim \mathcal{N}\left(0, 1\right),\\
        \epsilon_{t,c,a} &= \sigma_{c,a} \, z_{t,c,a},\quad z_{t,c,a} \sim \mathcal{N}\left(0, 1\right),\\
        \text{noise}_{t,c,a} &= \sum_{t' < t} \left(\epsilon_{t',c} + \epsilon_{t',c,a}\right)

    The cumulative sum is exclusive, so the noise at the first time step is
    zero.

    Parameters
    ----------
    name: str
        Prefix for the sampled variables: ``{name}_sigma``,
        ``{name}_sigma_age``, ``{name}`` and ``{name}_age``.
    modelParams: :py:class:`covid19_npis.ModelParams`
        Instance of modelParams; provides ``length_sim``, ``num_countries``
        and ``num_age_groups``.
    sigma: number, optional
        Scale of the HalfNormal prior on the per-country step size.
    sigma_age: number, optional
        Scale of the HalfNormal prior on the per-country-and-age-group step
        size.

    Returns
    -------
    :
        Noise tensor |shape| batch, time, country, age_group
    """

    # Step size per country |shape| batch, country
    noise_R_sigma = yield HalfNormal(
        name=f"{name}_sigma",
        scale=sigma,
        conditionally_independent=True,
        event_stack=(modelParams.num_countries,),
        shape_label=("country"),
        transform=transformations.SoftPlus(),
    )

    # Step size per country and age group |shape| batch, country, age_group
    noise_R_sigma_age = yield HalfNormal(
        name=f"{name}_sigma_age",
        scale=sigma_age,
        conditionally_independent=True,
        event_stack=(modelParams.num_countries, modelParams.num_age_groups),
        shape_label=("country", "age_group"),
        transform=transformations.SoftPlus(),
    )

    # Non-centered steps: unit normal draws scaled by the step size, which is
    # broadcast over the inserted time axis.
    noise_R = (
        yield Normal(
            name=f"{name}",
            loc=0.0,
            scale=1.0,
            event_stack=(modelParams.length_sim, modelParams.num_countries,),
            shape_label=("time", "country"),
            conditionally_independent=True,
        )
    ) * noise_R_sigma[..., tf.newaxis, :]

    noise_R_age = (
        yield Normal(
            name=f"{name}_age",
            loc=0.0,
            scale=1.0,
            event_stack=(
                modelParams.length_sim,
                modelParams.num_countries,
                modelParams.num_age_groups,
            ),
            shape_label=("time", "country", "age_group"),
            conditionally_independent=True,
        )
    ) * noise_R_sigma_age[..., tf.newaxis, :, :]

    # The summed steps have layout (..., time, country, age_group), so the
    # random walk must accumulate along axis -3 (time). Axis -2 is the country
    # axis: accumulating there would leave the first country noise-free and
    # make every other country inherit the noise of the preceding ones.
    sum_noise_R = tf.math.cumsum(
        noise_R[..., tf.newaxis] + noise_R_age, exclusive=True, axis=-3
    )
    return sum_noise_R
def construct_R_0_old(name, modelParams, mean, beta):
    r"""
    Legacy :math:`R_0` constructor: an independent gamma prior for every
    country and age group.

    .. math::

        R_{0,c,a} &\sim Gamma\left(\alpha = \mu\beta,\ \beta\right)

    With concentration :math:`\mu\beta` and rate :math:`\beta`, the prior
    mean is :math:`\mu` (the original prior used :math:`\mu=2.5`,
    :math:`\beta=2.0`).

    Parameters
    ----------
    name: string
        Trace name of the gamma-distributed variable.
    modelParams: :py:class:`covid19_npis.ModelParams`
        Instance of modelParams; provides ``num_countries`` and
        ``num_age_groups``.
    mean:
        Prior mean :math:`\mu`.
    beta:
        Rate :math:`\beta` of the gamma distribution.

    Returns
    -------
    :
        R_0 tensor |shape| batch, country, age_group
    """
    stack = (modelParams.num_countries, modelParams.num_age_groups)
    # Gamma(concentration, rate) has mean concentration / rate, so a
    # concentration of ``mean * beta`` makes ``mean`` the prior mean.
    prior = Gamma(
        name=name,
        concentration=mean * beta,
        rate=beta,
        conditionally_independent=True,
        event_stack=stack,
        transform=transformations.SoftPlus(),
        shape_label=("country", "age_group"),
    )
    return (yield prior)
def _create_distributions(modelParams):
    r"""
    Returns a dict of distributions for further processing/sampling with the
    following priors:

    .. math::

        \alpha^\dagger_i &\sim \mathcal{N}\left(-1, 2\right)\quad \forall i,\\
        \Delta \alpha^\dagger_c &\sim \mathcal{N}\left(0, \sigma_{\alpha, \text{country}}\right) \quad \forall c, \\
        \Delta \alpha^\dagger_a &\sim \mathcal{N}\left(0, \sigma_{\alpha, \text{age}}\right)\quad \forall a, \\
        \sigma_{\alpha, \text{country}}  &\sim HalfNormal\left(0.1\right),\\
        \sigma_{\alpha, \text{age}} &\sim HalfNormal\left(0.1\right)

    .. math::

        l^\dagger_{\text{positive}} &\sim \mathcal{N}\left(3, 1\right),\\
        l^\dagger_{\text{negative}} &\sim \mathcal{N}\left(5, 2\right),\\
        \Delta l^\dagger_i &\sim \mathcal{N}\left(0,\sigma_{l, \text{interv.}} \right)\quad \forall i,\\
        \sigma_{l, \text{interv.}}&\sim HalfNormal\left(1\right)

    .. math::

        \Delta d_i  &\sim \mathcal{N}\left(0, \sigma_{d, \text{interv.}}\right)\quad \forall i,\\
        \Delta d_c &\sim \mathcal{N}\left(0, \sigma_{d, \text{country}}\right)\quad \forall c,\\
        \sigma_{d, \text{interv.}}  &\sim HalfNormal\left(0.3\right),\\
        \sigma_{d, \text{country}} &\sim HalfNormal\left(0.3\right)

    The :math:`\Delta` variables are created here as standard normals
    (non-centered form). Multiplying them by their :math:`\sigma`
    hyperpriors is left to the code that consumes this dict (for
    :math:`\alpha`, construct R_t).

    The distributions are only constructed here; nothing is yielded or
    sampled in this function.

    Parameters
    ----------
    modelParams: :py:class:`covid19_npis.ModelParams`
        Instance of modelParams; provides ``num_interventions``,
        ``num_countries`` and ``num_age_groups``.

    Returns
    -------
    :
        dict mapping ``"alpha_sigma_c"``, ``"alpha_sigma_a"``,
        ``"delta_alpha_cross_c"``, ``"delta_alpha_cross_a"``,
        ``"alpha_cross_i"``, ``"l_sigma_interv"``, ``"l_positive_cross"``,
        ``"l_negative_cross"``, ``"delta_l_cross_i"``, ``"d_sigma_interv"``,
        ``"d_sigma_country"``, ``"delta_d_i"`` and ``"delta_d_c"`` to the
        corresponding distribution objects.
    """
    log.debug("_create_distributions")

    # --- Δα† for each country and age group, with hyperdistributions ---
    alpha_sigma_c = HalfNormal(
        name="alpha_sigma_country",
        scale=0.1,
        transform=transformations.SoftPlus(scale=0.1),
        conditionally_independent=True,
    )
    alpha_sigma_a = HalfNormal(
        name="alpha_sigma_age_group",
        scale=0.1,
        transform=transformations.SoftPlus(scale=0.1),
        conditionally_independent=True,
    )
    # alpha_sigma_c and alpha_sigma_a are multiplied in later (see
    # construct R_t). Event axes are (intervention, country, age_group);
    # singleton axes broadcast.
    delta_alpha_cross_c = Normal(
        name="delta_alpha_cross_c",
        loc=0.0,
        scale=1.0,
        event_stack=(1, modelParams.num_countries, 1),
        shape_label=(None, "country", None),
        conditionally_independent=True,
    )
    delta_alpha_cross_a = Normal(
        name="delta_alpha_cross_a",
        loc=0.0,
        scale=1.0,
        event_stack=(1, 1, modelParams.num_age_groups),
        shape_label=(None, None, "age_group"),
        conditionally_independent=True,
    )
    alpha_cross_i = Normal(
        name="alpha_cross_i",
        loc=-1.0,  # See publication for reasoning behind -1 and 2
        scale=2.0,
        conditionally_independent=True,
        event_stack=(modelParams.num_interventions, 1, 1),
        shape_label=("intervention", None, None),
    )

    # --- l (length) distributions ---
    l_sigma_interv = HalfNormal(
        name="l_sigma_interv",
        scale=1.0,
        transform=transformations.SoftPlus(),
        conditionally_independent=True,
    )
    delta_l_cross_i = Normal(
        name="delta_l_cross_i",
        loc=0.0,
        scale=1.0,
        conditionally_independent=True,
        event_stack=(modelParams.num_interventions, ),
        shape_label=("intervention"),
    )
    log.debug(f"l_sigma_interv\n{l_sigma_interv}")
    # Baseline l† for positive and negative intervention effects; the
    # per-intervention offsets are delta_l_cross_i above.
    l_positive_cross = Normal(
        name="l_positive_cross",
        loc=3.0,
        scale=1.0,
        conditionally_independent=True,
        event_stack=(1, ),
    )
    l_negative_cross = Normal(
        name="l_negative_cross",
        loc=5.0,
        scale=2.0,
        conditionally_independent=True,
        event_stack=(1, ),
    )

    # --- date d distributions ---
    d_sigma_interv = HalfNormal(
        name="d_sigma_interv",
        scale=0.3,
        transform=transformations.SoftPlus(scale=0.3),
        conditionally_independent=True,
    )
    d_sigma_country = HalfNormal(
        name="d_sigma_country",
        scale=0.3,
        transform=transformations.SoftPlus(scale=0.3),
        conditionally_independent=True,
    )
    delta_d_i = Normal(
        name="delta_d_i",
        loc=0.0,
        scale=1.0,
        event_stack=(modelParams.num_interventions, 1, 1),
        shape_label=("intervention", None, None),
        conditionally_independent=True,
    )
    delta_d_c = Normal(
        name="delta_d_c",
        loc=0.0,
        scale=1.0,
        event_stack=(1, modelParams.num_countries, 1),
        shape_label=(None, "country", None),
        conditionally_independent=True,
    )

    # Collect everything in a dict so it can be passed to another function.
    return {
        "alpha_sigma_c": alpha_sigma_c,
        "alpha_sigma_a": alpha_sigma_a,
        "delta_alpha_cross_c": delta_alpha_cross_c,
        "delta_alpha_cross_a": delta_alpha_cross_a,
        "alpha_cross_i": alpha_cross_i,
        "l_sigma_interv": l_sigma_interv,
        "l_positive_cross": l_positive_cross,
        "l_negative_cross": l_negative_cross,
        "delta_l_cross_i": delta_l_cross_i,
        "d_sigma_interv": d_sigma_interv,
        "d_sigma_country": d_sigma_country,
        "delta_d_i": delta_d_i,
        "delta_d_c": delta_d_c,
    }
def construct_R_0(name, modelParams, loc, scale, hn_scale):
    r"""
    Constructs a per-country :math:`R_0` in the following hierarchical manner:

    .. math::

        R^*_0 &\sim \mathcal{N}\left(\text{loc}, \text{scale}\right),\\
        \sigma_{R^*, \text{country}} &\sim HalfNormal\left(\text{hn\_scale}\right),\\
        \Delta R^*_{0,c} &\sim \mathcal{N}\left(0, \sigma_{R^*, \text{country}}\right)\quad \forall c,\\
        R^*_{0,c} &= R^*_0 + \Delta R^*_{0,c},\\
        R_{0,c} &= \text{softplus}\left(R^*_{0,c}\right)

    Each prior is sampled on a shifted/rescaled variable: :math:`R^*_0` as
    :math:`\mathcal{N}(0, \text{scale}) + \text{loc}`, the HalfNormal as
    :math:`HalfNormal(1) \cdot \text{hn\_scale}`, and
    :math:`\Delta R^*_{0,c}` as a unit normal times
    :math:`\sigma_{R^*, \text{country}}`. The softplus keeps
    :math:`R_{0,c}` positive. The per-country value is then repeated for
    every age group.

    Parameters
    ----------
    name: str
        Trace name of the resulting per-country :math:`R_{0,c}` (added via
        ``Deterministic``). The intermediate variables always use the fixed
        names ``"R_0"``, ``"R_0_sigma_c"`` and ``"delta_R_0_c"``.
        NOTE(review): passing ``name="R_0"`` would presumably clash with the
        intermediate variable of the same name -- confirm.
    modelParams: :py:class:`covid19_npis.ModelParams`
        Instance of modelParams; provides ``num_countries`` and
        ``num_age_groups``.
    loc: number
        Location (mean) of the :math:`R^*_0` Normal distribution.
    scale: number
        Scale of the :math:`R^*_0` Normal distribution.
    hn_scale: number
        Scale of the :math:`\sigma_{R^*, \text{country}}` HalfNormal
        distribution.

    Returns
    -------
    :
        R_0 tensor |shape| batch, country, age_group
    """

    # R^*_0 ~ N(loc, scale), sampled as N(0, scale) shifted by loc.
    R_0 = (
        yield Normal(
            name="R_0",
            loc=0.0,
            scale=scale,
            conditionally_independent=True,
        )
    ) + loc
    log.debug(f"R_0:\n{R_0}")

    # sigma_{R^*, country} ~ HalfNormal(hn_scale), sampled as HalfNormal(1) * hn_scale.
    R_0_sigma_c = (
        yield HalfNormal(
            name="R_0_sigma_c",
            scale=1.0,
            conditionally_independent=True,
            transform=transformations.SoftPlus(),
        )
    ) * hn_scale

    # Delta R^*_{0,c} ~ N(0, sigma), sampled as N(0, 1) * sigma (sigma
    # broadcast over the country axis).
    delta_R_0_c = (
        yield Normal(
            name="delta_R_0_c",
            loc=0.0,
            scale=1.0,
            event_stack=(modelParams.num_countries),
            shape_label=("country"),
            conditionally_independent=True,
        )
    ) * R_0_sigma_c[..., tf.newaxis]
    log.debug(f"delta_R_0_c:\n{delta_R_0_c}")

    R_0_c = R_0[..., tf.newaxis] + delta_R_0_c
    log.debug(f"R_0_c before softplus:\n{R_0_c}")

    # Softplus because we want to make sure that R_0 > 0.
    R_0_c = tf.math.softplus(R_0_c)

    # Add to trace via deterministic
    R_0_c = yield Deterministic(
        name=name,
        value=R_0_c,
        shape_label=("country"),
    )
    log.debug(f"R_0_c:\n{R_0_c}")

    # No clipping is applied to R_0_c. A former
    # `tf.clip_by_value(R_0_c, 1, 5)` here discarded its result and was a
    # no-op. It was removed rather than activated: clipping would change the
    # model, zero the gradient outside [1, 5], and make the returned value
    # disagree with the traced Deterministic.

    # Same R_0 for every age group within a country.
    return tf.repeat(
        R_0_c[..., tf.newaxis],
        repeats=modelParams.num_age_groups,
        axis=-1,
    )