Example #1
def calculate_manfred_direction(current_x, step_size, state, gradient_weight,
                                momentum):
    cache = state["cache"]
    pos_values = _get_values_for_pseudo_gradient(current_x, step_size, 1,
                                                 cache)
    neg_values = _get_values_for_pseudo_gradient(current_x, step_size, -1,
                                                 cache)
    f0 = aggregate_evaluations(cache[hash_array(current_x)]["evals"])

    two_sided_gradient = (pos_values - neg_values) / (2 * step_size)
    right_gradient = (pos_values - f0) / step_size
    left_gradient = (f0 - neg_values) / step_size

    gradient = two_sided_gradient
    gradient = np.where(np.isnan(gradient), right_gradient, gradient)
    gradient = np.where(np.isnan(gradient), left_gradient, gradient)
    gradient = np.where(np.isnan(gradient), 0, gradient)

    gradient_direction = _normalize_direction(-gradient)

    last_x = cache[state["x_history"][-2]]["x"]
    step_direction = _normalize_direction(current_x - last_x)

    direction = (gradient_weight * gradient_direction +
                 (1 - gradient_weight) * step_direction)

    dir_hist = state["direction_history"]
    if momentum > 0 and len(dir_hist) >= 1:
        direction = momentum * dir_hist[-1] + (1 - momentum) * direction

    return direction
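
The helper _normalize_direction is not shown on this page. A minimal sketch, assuming it only rescales a vector to unit length (with a guard for an all-zero gradient):

import numpy as np

def _normalize_direction(direction):
    # Hypothetical sketch: return the unit vector pointing in the same
    # direction; leave an all-zero vector unchanged to avoid division by zero.
    norm = np.linalg.norm(direction)
    if norm == 0:
        return direction
    return direction / norm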
Example #2
def _determine_strategies_from_residuals(current_x, state):
    x_hash = hash_array(current_x)
    evals = state["cache"][x_hash]["evals"]
    residuals = np.array(
        [evaluation["root_contributions"] for evaluation in evals])
    residual_sum = residuals.sum()

    if residual_sum > 0:
        strategies = ["left"] * len(current_x)
    else:
        strategies = ["right"] * len(current_x)

    return strategies
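
The helper hash_array serves as the cache key in every example here but is not shown. A minimal sketch, assuming the key only needs to be deterministic for identical arrays:

import hashlib
import numpy as np

def hash_array(arr):
    # Hypothetical sketch: hash the raw bytes of a contiguous copy so that
    # equal arrays always map to the same cache key.
    return hashlib.sha1(np.ascontiguousarray(arr).tobytes()).hexdigest()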
Example #3
def _build_and_evaluate_msm_func(
    params,
    seed,
    prefix,
    fall_start_date,
    fall_end_date,
    spring_start_date,
    spring_end_date,
    mode,
    debug,
):
    """ """
    params_hash = hash_array(params["value"].to_numpy())
    share_known_path = BLD / "exploration" / f"share_known_{params_hash}_{seed}.pkl"
    if mode in ["fall", "combined"]:
        res_fall = _build_and_evaluate_msm_func_one_season(
            params=params,
            seed=seed,
            prefix=prefix,
            start_date=fall_start_date,
            end_date=fall_end_date,
            debug=debug,
        )
        res_fall["share_known_cases"].to_pickle(share_known_path)

    if mode in ["spring", "combined"]:
        res_spring = _build_and_evaluate_msm_func_one_season(
            params=params,
            seed=seed + 84587,
            prefix=prefix,
            start_date=spring_start_date,
            end_date=spring_end_date,
            debug=debug,
            group_share_known_case_path=share_known_path,
        )
    if mode == "fall":
        res = res_fall
    elif mode == "spring":
        res = res_spring
    else:
        results = [res_fall, res_spring]
        raw_weights = np.array([
            (fall_end_date - fall_start_date).days,
            (spring_end_date - spring_start_date).days,
        ])
        weights = raw_weights / raw_weights.sum()
        res = _combine_results(results, weights)

    return res
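
_combine_results is referenced above but not shown. A hedged sketch, assuming each result dict carries a scalar criterion under "value" that can be averaged with the day-count weights:

def _combine_results(results, weights):
    # Hypothetical sketch: weight the scalar criterion values by season
    # length; all other entries are taken from the first (fall) result.
    combined = dict(results[0])
    combined["value"] = sum(
        weight * res["value"] for weight, res in zip(weights, results)
    )
    return combined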
Example #4
def _get_values_for_pseudo_gradient(current_x, step_size, sign, cache):
    x_hashes = []
    for i, val in enumerate(current_x):
        x = current_x.copy()
        if sign > 0:
            x[i] = val + step_size
        else:
            x[i] = val - step_size
        x_hashes.append(hash_array(x))

    values = []
    for x_hash in x_hashes:
        if x_hash in cache:
            values.append(aggregate_evaluations(cache[x_hash]["evals"]))
        else:
            values.append(np.nan)
    return np.array(values)
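
# NOTE: The lines below are a fragment of the solver's top-level routine;
# its signature is omitted in this example. They normalize the tuning
# arguments, set up the optimizer state, and run the first batch of
# evaluations at the start point x.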
    n_evaluations_per_x = _process_n_evaluations_per_x(
        n_evaluations_per_x, len(step_sizes)
    )
    linesearch_active = _process_scalar_or_list_arg(linesearch_active, len(step_sizes))
    linesearch_frequency = _process_scalar_or_list_arg(
        linesearch_frequency, len(step_sizes)
    )

    assert 0 <= gradient_weight <= 1

    state = {
        "func_counter": 0,
        "iter_counter": 0,
        "inner_iter_counter": 0,
        "cache": {},
        "x_history": [hash_array(x)],
        "direction_history": [],
        "seed": itertools.count(seed),
    }

    do_evaluations(
        func,
        [x],
        state,
        n_evaluations_per_x[0],
        return_type="aggregated",
        batch_evaluator=batch_evaluator,
        batch_evaluator_options=batch_evaluator_options,
    )

    current_x = x
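
The _process_* helpers in the fragment above are not shown. A minimal sketch of _process_scalar_or_list_arg, assuming it broadcasts a scalar to one entry per step size and validates list inputs:

def _process_scalar_or_list_arg(arg, length):
    # Hypothetical sketch: accept a scalar or a list with one entry per
    # step size and always return a list of the requested length.
    if isinstance(arg, (list, tuple)):
        if len(arg) != length:
            raise ValueError("Argument must have one entry per step size.")
        return list(arg)
    return [arg] * length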
def _build_and_evaluate_msm_func_one_season(
    params,
    seed,
    prefix,
    start_date,
    end_date,
    debug,
    group_share_known_case_path=None,
):
    """Build and evaluate a msm criterion function.

    Building the criterion function freshly for each run is necessary for it to be
    parallelizable.

    """
    simulate_kwargs = load_simulation_inputs(
        "baseline",
        start_date=start_date,
        end_date=end_date,
        group_share_known_case_path=group_share_known_case_path,
        debug=debug,
        return_last_states=False,
    )
    params_hash = hash_array(params["value"].to_numpy())
    path = BLD / "exploration" / f"{prefix}_{params_hash}_{os.getpid()}"

    sim_start = simulate_kwargs["duration"]["start"]
    sim_end = simulate_kwargs["duration"]["end"]
    period_outputs = _get_period_outputs_for_simulate()

    simulate = get_simulate_func(
        **simulate_kwargs,
        params=params,
        path=path,
        seed=seed,
        period_outputs=period_outputs,
        return_time_series=False,
    )

    calc_moments = _get_calc_moments()
    rki_data = pd.read_pickle(
        BLD / "data" / "processed_time_series" / "rki.pkl"
    )

    age_group_info = pd.read_pickle(
        BLD / "data" / "population_structure" / "age_groups_rki.pkl"
    )

    state_info = pd.read_parquet(
        BLD / "data" / "population_structure" / "federal_states.parquet"
    )
    state_sizes = state_info.set_index("name")["population"]

    empirical_moments = _get_empirical_moments(
        rki_data,
        age_group_sizes=age_group_info["n"],
        state_sizes=state_sizes,
        start_date=sim_start,
        end_date=sim_end,
    )

    weight_mat = _get_weighting_matrix(
        empirical_moments=empirical_moments,
        age_weights=age_group_info["weight"],
        state_weights=state_sizes / state_sizes.sum(),
    )

    additional_outputs = {
        "infection_channels": _aggregate_infection_channels,
        "share_known_cases": _calculate_share_known_cases,
    }

    msm_func = get_msm_func(
        simulate=simulate,
        calc_moments=calc_moments,
        empirical_moments=empirical_moments,
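        # lambda x: x * 1 is effectively the identity, so simulated NaNs
        # are kept rather than replaced.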
        replace_nans=lambda x: x * 1,
        weighting_matrix=weight_mat,
        additional_outputs=additional_outputs,
    )

    res = msm_func(params)
    shutil.rmtree(path)
    return res
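
Because the criterion function is rebuilt for each run (see the docstring above) and the on-disk path includes the process id, several runs can be evaluated in parallel. A hedged usage sketch, assuming joblib is available and params_list is a hypothetical list of params DataFrames:

from joblib import Parallel, delayed

results = Parallel(n_jobs=4)(
    delayed(_build_and_evaluate_msm_func_one_season)(
        params=params,
        seed=seed,
        prefix="exploration",  # hypothetical prefix
        start_date=start_date,  # assumed to be defined elsewhere
        end_date=end_date,
        debug=False,
    )
    for seed, params in enumerate(params_list)
)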