def bound_variables(
        self,
        seir_timeseries_variables,
    ):
        """See parent class."""

        # Unpack the per-variable time series; only the most recent value of
        # each variable is bounded and returned.
        (vaccinated_ratio_ts, vaccine_effectiveness_ts, average_contact_id_ts,
         average_contact_iud_ts, reinfectable_rate_ts, alpha_ts,
         diagnosis_rate_ts, recovery_rate_id_ts, recovery_rate_iud_ts,
         recovery_rate_h_ts, hospitalization_rate_ts, death_rate_id_ts,
         death_rate_h_ts) = seir_timeseries_variables

        def bounded(timeseries, scale, offset=0.0):
            # Squash the latest value into the open interval
            # (offset, offset + scale) via a scaled sigmoid.
            return offset + scale * tf.nn.sigmoid(timeseries[-1])

        # Ratios are hard-clipped to [0, 1]; all rates are sigmoid-bounded to
        # their admissible ranges.
        return (
            model_utils.apply_relu_bounds(vaccinated_ratio_ts[-1], 0.0, 1.0),
            model_utils.apply_relu_bounds(vaccine_effectiveness_ts[-1], 0.0,
                                          1.0),
            bounded(average_contact_id_ts, 1.0),
            bounded(average_contact_iud_ts, 1.0),
            bounded(reinfectable_rate_ts, 0.001),
            bounded(alpha_ts, 0.2),
            bounded(diagnosis_rate_ts, 0.09, 0.01),
            bounded(recovery_rate_id_ts, 0.1),
            bounded(recovery_rate_iud_ts, 0.1),
            bounded(recovery_rate_h_ts, 0.095, 0.005),
            bounded(hospitalization_rate_ts, 0.095, 0.005),
            bounded(death_rate_id_ts, 0.01),
            bounded(death_rate_h_ts, 0.1),
        )
    def bound_variables(
        self,
        seir_timeseries_variables,
    ):
        """See parent class."""

        (first_dose_vaccine_ratio_per_day_list,
         second_dose_vaccine_ratio_per_day_list, average_contact_id_list,
         average_contact_iud_list, reinfectable_rate_list, alpha_list,
         diagnosis_rate_list, recovery_rate_id_list, recovery_rate_iud_list,
         recovery_rate_h_list, recovery_rate_i_list, recovery_rate_v_list,
         hospitalization_rate_list, icu_rate_list, ventilator_rate_list,
         death_rate_id_list, death_rate_h_list, death_rate_i_list,
         death_rate_v_list) = seir_timeseries_variables

        first_dose_vaccine_ratio_per_day = model_utils.apply_relu_bounds(
            first_dose_vaccine_ratio_per_day_list[-1], 0.0, 1.0)
        second_dose_vaccine_ratio_per_day = model_utils.apply_relu_bounds(
            second_dose_vaccine_ratio_per_day_list[-1], 0.0, 1.0)
        average_contact_id = 1.0 * tf.nn.sigmoid(average_contact_id_list[-1])
        average_contact_iud = 1.0 * tf.nn.sigmoid(average_contact_iud_list[-1])
        reinfectable_rate = 0.001 * tf.nn.sigmoid(reinfectable_rate_list[-1])
        alpha = 0.2 * tf.nn.sigmoid(alpha_list[-1])
        diagnosis_rate = 0.01 + 0.09 * tf.nn.sigmoid(diagnosis_rate_list[-1])
        recovery_rate_id = 0.1 * tf.nn.sigmoid(recovery_rate_id_list[-1])
        recovery_rate_iud = 0.1 * tf.nn.sigmoid(recovery_rate_iud_list[-1])
        recovery_rate_h = 0.1 * tf.nn.sigmoid(recovery_rate_h_list[-1])
        recovery_rate_i = 0.1 * tf.nn.sigmoid(recovery_rate_i_list[-1])
        recovery_rate_v = 0.1 * tf.nn.sigmoid(recovery_rate_v_list[-1])
        hospitalization_rate = 0.1 * tf.nn.sigmoid(
            hospitalization_rate_list[-1])
        icu_rate = 0.1 * tf.nn.sigmoid(icu_rate_list[-1])
        ventilator_rate = 0.01 + 0.19 * tf.nn.sigmoid(ventilator_rate_list[-1])
        death_rate_id = 0.01 * tf.nn.sigmoid(death_rate_id_list[-1])
        death_rate_h = 0.1 * tf.nn.sigmoid(death_rate_h_list[-1])
        death_rate_i = 0.1 * tf.nn.sigmoid(death_rate_i_list[-1])
        death_rate_v = 0.01 + 0.09 * tf.nn.sigmoid(death_rate_v_list[-1])

        return (first_dose_vaccine_ratio_per_day,
                second_dose_vaccine_ratio_per_day, average_contact_id,
                average_contact_iud, reinfectable_rate, alpha, diagnosis_rate,
                recovery_rate_id, recovery_rate_iud, recovery_rate_h,
                recovery_rate_i, recovery_rate_v, hospitalization_rate,
                icu_rate, ventilator_rate, death_rate_id, death_rate_h,
                death_rate_i, death_rate_v)
def apply_quantile_transform(self,
                                 hparams,
                                 propagated_states,
                                 quantile_kernel,
                                 quantile_biases,
                                 ground_truth_timeseries,
                                 num_train_steps,
                                 num_forecast_steps,
                                 num_quantiles=23,
                                 epsilon=1e-8,
                                 is_training=True,
                                 initial_quantile_step=0):
        """Transform predictions into vector representing different quantiles.

    For timesteps in the training range the quantile multiplier is identity
    (all ones), so the point forecast is replicated across quantiles.  For
    forecast timesteps, multipliers are estimated from windowed encoding
    features, made monotone across quantiles via a cumulative sum, normalized
    so the median (or a partial mean around it) matches the point forecast,
    and smoothed over time with an exponential moving average.

    Args:
      hparams: Hyperparameters.
      propagated_states: single value predictions, its dimensions represent
        timestep * states * location.
      quantile_kernel: Quantile mapping kernel.
      quantile_biases: Biases for quantiles.
      ground_truth_timeseries: Ground truth time series.
      num_train_steps: number of train steps
      num_forecast_steps: number of forecasting steps
      num_quantiles: Number of quantiles
      epsilon: A small number to avoid 0 division issues.
      is_training: Whether the phase is training or inference.
        NOTE(review): this flag is not referenced in the body below — confirm
        whether it is kept only for interface compatibility.
      initial_quantile_step: start index for quantile training

    Returns:
      Vector value predictions of size
        timestep * states * location * num_quantiles
    """
        (_, gt_list, gt_indicator, _, _) = ground_truth_timeseries

        # Split the packed state tensor along the compartment (state) axis.
        # NOTE(review): the hard-coded indices below assume a fixed compartment
        # ordering defined by the propagation model elsewhere in this file —
        # confirm against the state-construction code before changing them.
        unstacked_propagated_states = tf.unstack(propagated_states, axis=1)
        pred_infected = unstacked_propagated_states[1]
        pred_recovered = unstacked_propagated_states[3]
        pred_hospitalized = unstacked_propagated_states[5]
        pred_icu = unstacked_propagated_states[8]
        pred_ventilator = unstacked_propagated_states[9]
        pred_death = unstacked_propagated_states[10]
        pred_reinfected = unstacked_propagated_states[12]

        # Compartments that together constitute the "confirmed" count.
        pred_confirmed = (pred_infected + pred_recovered + pred_death +
                          pred_hospitalized + pred_icu + pred_ventilator +
                          pred_reinfected)

        quantile_encoding_window = hparams["quantile_encoding_window"]
        smooth_coef = hparams["quantile_smooth_coef"]
        partial_mean_interval = hparams["partial_mean_interval"]

        # Softplus keeps kernel weights and biases strictly positive, so the
        # cumulative sum below produces monotonically increasing quantiles.
        quantile_mapping_kernel = tf.math.softplus(
            tf.expand_dims(quantile_kernel, 2))
        quantile_biases = tf.math.softplus(tf.expand_dims(quantile_biases, 1))

        propagated_states_quantiles = []
        # Multiplier from the previous timestep, used for temporal smoothing
        # of the quantile multipliers in the forecast range.
        state_quantiles_multiplier_prev = tf.ones_like(
            tf.expand_dims(propagated_states[0, :, :], 2))

        def gt_ratio_feature(gt_values, predicted):
            """Creates the GT ratio feature."""

            # Relative error of the prediction vs. ground truth over the
            # training range; epsilon guards against division by zero.
            # This uses the imputed values when the values are not valid.
            ratio_pred = (1 - (predicted[:num_train_steps, :] /
                               (epsilon + gt_values[:num_train_steps])))
            # Add 0 at the beginning
            # (zero-padding so index ti aligns with the encoding window).
            ratio_pred = tf.concat([
                0 * ratio_pred[:(quantile_encoding_window +
                                 num_forecast_steps), :], ratio_pred
            ],
                                   axis=0)
            # Broadcast the per-location feature across all states.
            ratio_pred = tf.expand_dims(ratio_pred, 1)
            ratio_pred = tf.tile(ratio_pred, [1, self.num_states, 1])
            return ratio_pred

        def indicator_feature(gt_indicator):
            """Creates the indicator feature."""

            # 1 where ground truth is missing/invalid, 0 where it is valid.
            indicator = 1. - gt_indicator
            # Add 0 at the beginning
            # (zero-padding so index ti aligns with the encoding window).
            indicator = tf.concat([
                0 *
                indicator[:(quantile_encoding_window + num_forecast_steps), :],
                indicator
            ],
                                  axis=0)
            # Broadcast the per-location feature across all states.
            indicator = tf.expand_dims(indicator, 1)
            indicator = tf.tile(indicator, [1, self.num_states, 1])
            return indicator

        # Propagated states features
        # (zero-padded so a window of length quantile_encoding_window ending
        # at timestep ti can always be sliced).
        temp_propagated_states = tf.concat([
            0 * propagated_states[:quantile_encoding_window, :, :],
            propagated_states
        ],
                                           axis=0)

        # GT ratio features
        death_gt_ratio_feature = gt_ratio_feature(gt_list["death"], pred_death)
        confirmed_gt_ratio_feature = gt_ratio_feature(gt_list["confirmed"],
                                                      pred_confirmed)
        hospitalized_gt_ratio_feature = gt_ratio_feature(
            gt_list["hospitalized"], pred_hospitalized)

        # Indicator features
        death_indicator_feature = indicator_feature(gt_indicator["death"])
        confirmed_indicator_feature = indicator_feature(
            gt_indicator["confirmed"])
        hospitalized_indicator_feature = indicator_feature(
            gt_indicator["hospitalized"])

        for ti in range(initial_quantile_step,
                        num_train_steps + num_forecast_steps):

            if ti < num_train_steps:
                # Training range: identity multiplier, i.e. every quantile
                # equals the point forecast.
                state_quantiles_multiplier = tf.ones_like(
                    tf.expand_dims(propagated_states[0, :, :], 2))
                state_quantiles_multiplier = tf.tile(
                    state_quantiles_multiplier, [1, 1, num_quantiles])
            else:
                # Construct the input features to be used for quantile estimation.
                encoding_input = []

                # Features coming from the trend of the estimated.
                # (relative change of each state over the encoding window,
                # normalized by the value at the window's end)
                encoding_input.append(
                    1 - (temp_propagated_states[ti:(
                        ti + quantile_encoding_window), :, :] /
                         (epsilon + temp_propagated_states[
                             ti + quantile_encoding_window, :, :])))

                # Features coming from the ground truth ratio of death.
                encoding_input.append(death_gt_ratio_feature[ti:(
                    ti + quantile_encoding_window), :, :])
                # Features coming from the ground truth ratio of confirmed.
                encoding_input.append(confirmed_gt_ratio_feature[ti:(
                    ti + quantile_encoding_window), :, :])
                # Features coming from the ground truth ratio of hospitalized.
                encoding_input.append(hospitalized_gt_ratio_feature[ti:(
                    ti + quantile_encoding_window), :, :])

                # Features coming from death indicator.
                encoding_input.append(death_indicator_feature[ti:(
                    ti + quantile_encoding_window), :, :])
                # Features coming from confirmed indicator.
                encoding_input.append(confirmed_indicator_feature[ti:(
                    ti + quantile_encoding_window), :, :])
                # Features coming from hospitalized indicator.
                encoding_input.append(hospitalized_indicator_feature[ti:(
                    ti + quantile_encoding_window), :, :])

                encoding_input_t = tf.expand_dims(
                    tf.concat(encoding_input, axis=0), 3)

                # Limit the range of features.
                # (also replaces NaNs so missing ground truth cannot poison
                # the quantile estimate)
                encoding_input_t = model_utils.apply_relu_bounds(
                    encoding_input_t,
                    lower_bound=0.0,
                    upper_bound=2.0,
                    replace_nan=True)

                # Estimate the multipliers of quantiles
                state_quantiles_multiplier = quantile_biases + tf.math.reduce_mean(
                    tf.multiply(encoding_input_t, quantile_mapping_kernel), 0)

                # Consider accumulation to guarantee monotonicity
                # (summands are positive after softplus, so quantiles are
                # strictly increasing along the last axis).
                state_quantiles_multiplier = tf.math.cumsum(
                    state_quantiles_multiplier, axis=-1)
                if partial_mean_interval == 0:
                    # Normalize to match the median to point forecasts
                    state_quantiles_multiplier /= (epsilon + tf.expand_dims(
                        state_quantiles_multiplier[:, :,
                                                   (num_quantiles - 1) // 2],
                        -1))
                else:
                    # Normalize with major densities to approximate point forecast (mean)
                    # using a trapezoidal average of the quantiles within
                    # +/- partial_mean_interval around the median.
                    median_idx = (num_quantiles - 1) // 2
                    normalize_start = median_idx - partial_mean_interval
                    normalize_end = median_idx + partial_mean_interval
                    normalizer = tf.reduce_mean(
                        0.5 *
                        (state_quantiles_multiplier[:, :, normalize_start:
                                                    normalize_end] +
                         state_quantiles_multiplier[:, :, normalize_start +
                                                    1:normalize_end + 1]),
                        axis=2,
                        keepdims=True)
                    state_quantiles_multiplier /= (epsilon + normalizer)

                # Exponential moving average over time to avoid abrupt
                # changes of the quantile spread between adjacent timesteps.
                state_quantiles_multiplier = (
                    smooth_coef * state_quantiles_multiplier_prev +
                    (1 - smooth_coef) * state_quantiles_multiplier)

            state_quantiles_multiplier_prev = state_quantiles_multiplier

            # Return the estimated quantiles
            # (point forecast scaled by the per-quantile multiplier).
            propagated_states_quantiles_timestep = tf.multiply(
                tf.expand_dims(propagated_states[ti, :, :], 2),
                state_quantiles_multiplier)

            propagated_states_quantiles.append(
                propagated_states_quantiles_timestep)

        return tf.stack(propagated_states_quantiles)