def test_against_ground_truth(supply_chain_name, inv_holding_rate, total_cost):
    """
    Compare the tree-GSM optimum for ``supply_chain_name`` with a stored reference solution.

    The reference file ``<supply_chain_name>_solution.json`` (in ``dirname``) is iterated as
    ``(stage_id, {"s": ..., "cost": ...})`` pairs. For every stage the service time must match
    exactly. The stage's expected inventory cost times ``inv_holding_rate`` must be within
    0.1% relative error of the reference cost, or exactly 0 when the reference cost is 0.
    Finally, both the summed reference costs and the scaled optimal cost must be within 0.1%
    of ``total_cost``.
    """
    config_path = os.path.join(dirname, "{}.txt".format(supply_chain_name))
    stages, gsm = create_gsm_instance(GSM.Tree, config_path)

    serialized = gsm.find_optimal_solution().serialize()
    policy = serialized["policy"]

    inventory_costs = tree_gsm.compute_expected_inventory_costs(policy, stages)

    reference_path = os.path.join(dirname, "{}_solution.json".format(supply_chain_name))
    with open(reference_path, "r") as reference_file:
        reference = json.load(reference_file)

    def relative_gap(actual, expected):
        # Relative error measured against the expected (reference) value.
        return abs(actual - expected) / expected

    assert len(policy) == len(reference)
    for stage_id, expected in reference:
        assert policy[stage_id]["s"] == expected["s"], expected["s"]
        if expected["cost"] == 0:
            assert inventory_costs[stage_id] == 0
        else:
            scaled_cost = inventory_costs[stage_id] * inv_holding_rate
            assert relative_gap(scaled_cost, expected["cost"]) <= 0.001

    reference_total = sum(expected["cost"] for _, expected in reference)
    assert relative_gap(reference_total, total_cost) <= 0.001

    assert relative_gap(serialized["cost"] * inv_holding_rate, total_cost) <= 0.001
def run_gsm_with_timeout(data_set_filename: str, gsm_type: GSM, timeout: int = 15) \
        -> Optional[Dict]:
    """
    Run the specified version of gsm on one data set, giving up after ``timeout`` seconds.

    ``handler`` is installed for SIGALRM and the alarm is armed around the solve only.
    Building the instance is not subject to the timeout, but it is included in the reported
    execution time.

    :param data_set_filename: Path of the config file passed to ``create_gsm_instance``.
    :param gsm_type: The type of the GSM model to build.
    :param timeout: Seconds allowed for ``find_optimal_solution`` before the alarm fires.
    :returns: A dict with ``execution_time`` (wall-clock seconds, measured from before the
        instance is created), ``solution_cost``, ``solution`` (serialized), ``safety_stocks``
        and ``base_stocks``. Returns None if the run timed out or another GSMException was
        raised while solving or post-processing.
    """
    execution_start_time = datetime.utcnow()
    _, gsm = create_gsm_instance(gsm_type, data_set_filename)
    signal.signal(signal.SIGALRM, handler)
    signal.alarm(timeout)
    results = None
    try:
        try:
            solution = gsm.find_optimal_solution()  # type: GSM_Solution
        finally:
            # Cancel the timeout on every exit path, not only on success. Otherwise an
            # exception from the solve leaves the alarm armed, and it fires later in
            # unrelated code.
            signal.alarm(0)
        execution_time = (datetime.utcnow() -
                          execution_start_time).total_seconds()

        safety_stocks = compute_expected_inventories(solution.policy,
                                                     gsm.stages)
        base_stocks = compute_base_stocks(solution.policy, gsm.stages)
        results = dict(execution_time=execution_time,
                       solution_cost=solution.cost,
                       solution=solution.serialize(),
                       safety_stocks=safety_stocks,
                       base_stocks=base_stocks)
    except GSMTimeOutException:
        print("The {} version of GSM timed out for {}".format(
            gsm_type, data_set_filename))
    except GSMException as err:
        # Non-timeout GSM failures are reported as such rather than as a timeout.
        print("The {} version of GSM failed for {}: {}".format(
            gsm_type, data_set_filename, err))

    return results
def test_vs_brute_force(supply_chain_name, stage_modify, new_rate):
    """
    Check the tree-GSM optimum against brute force before and after changing one cost rate.

    The chain is solved once and compared with a brute-force reference (recomputed). Then the
    ``cost_rate`` of ``stage_modify`` is set to ``new_rate`` and the chain is solved and
    compared again. The two optimal policies are expected to differ.
    """
    config_path = os.path.join(dirname, "{}.txt".format(supply_chain_name))
    stages, gsm = create_gsm_instance(GSM.Tree, config_path)

    def solve_and_compare(reference_template):
        # Solve the current model and check it against the brute-force file built from the
        # given filename template.
        serialized = gsm.find_optimal_solution().serialize()
        reference_path = os.path.join(dirname, reference_template.format(supply_chain_name))
        compare_with_brute_force_solution(serialized, stages, reference_path, recompute=True)
        return serialized

    before = solve_and_compare("brute_force_{}_sol_1.json")

    # change the balance of costs
    stages[stage_modify].cost_rate = new_rate
    after = solve_and_compare("brute_force_{}_sol_2.json")

    with pytest.raises(AssertionError):
        assert_solution_policies_equal(before["policy"], after["policy"], stages)
def test_expected_inventories_and_basestocks_computations():
    """Check that expected inventories and base stocks are reported once per stage (bulldozer)."""
    bulldozer_path = os.path.join(dirname, "bulldozer.txt")
    stages, gsm = create_gsm_instance(GSM.Tree, bulldozer_path)

    optimal = gsm.find_optimal_solution()
    inventories = tree_gsm.compute_expected_inventories(optimal.policy, stages)
    stocks = tree_gsm.compute_base_stocks(optimal.policy, stages)

    assert len(stages) == len(stocks)
    assert len(stocks) == len(inventories)
def test_gsm_solution_class():
    """Round-trip a GSM_Solution through ``serialize()`` and its constructor (bulldozer)."""
    bulldozer_path = os.path.join(dirname, "bulldozer.txt")
    stages, gsm = create_gsm_instance(GSM.Tree, bulldozer_path)

    first_dict = gsm.find_optimal_solution().serialize()
    rebuilt_dict = tree_gsm.GSM_Solution(**first_dict).serialize()

    assert_solution_policies_equal(first_dict["policy"], rebuilt_dict["policy"], stages)
    assert first_dict["cost"] == rebuilt_dict["cost"]
def test_against_gound_truth_camera(scenario, service_time_constraints,
                                    correct_cost):
    """
    Check the digital-camera supply chain optimum against known results.

    ``service_time_constraints`` is iterated as ``(stage_id, max_s_time)`` pairs, and each
    pair caps the maximum service time of that stage before solving. The optimal cost scaled
    by a 0.24 holding rate must be within 1% of ``correct_cost``.

    Scenarios 2 and 4 also check which stages hold expected inventory. Scenario 2 further
    checks specific service times.
    """
    holding_rate = 0.24

    camera_path = os.path.join(dirname, "digital_camera.txt")
    stages, gsm = create_gsm_instance(GSM.Tree, camera_path)

    for constrained_stage, cap in service_time_constraints:
        gsm.stages[constrained_stage].max_s_time = cap

    serialized = gsm.find_optimal_solution().serialize()

    scaled_cost = serialized["cost"] * holding_rate
    assert abs(scaled_cost - correct_cost) / correct_cost < 0.01

    policy = serialized["policy"]
    inventories = tree_gsm.compute_expected_inventories(policy, gsm.stages)

    if scenario == 2:
        # Stages that quote zero service time and therefore hold inventory.
        stocked_stages = {
            "Camera", "Imager", "Circuit_Board", "Other_Parts_L_60",
            "Other_Parts_M_60", "Build_Test_Pack"
        }
        # Inventory-free stages whose outbound service time is pinned.
        pinned_service_times = {"Transfer": 2, "Ship": 5}
        for stage_id in stages:
            if stage_id in stocked_stages:
                assert policy[stage_id]["s"] == 0
                assert inventories[stage_id] > 0
                continue
            assert inventories[stage_id] == 0
            if stage_id in pinned_service_times:
                assert policy[stage_id]["s"] == pinned_service_times[stage_id]

    elif scenario == 4:
        stockless_stages = {"Build_Test_Pack", "Ship"}
        for stage_id in stages:
            if stage_id not in stockless_stages:
                assert inventories[stage_id] > 0
            else:
                assert inventories[stage_id] == 0
def run_gsm_type_on_all_willems_datasets(
        dir_name: str,
        gsm_type: GSM,
        results: Dict,
        inventory_holding_rates: Dict[int, float],
        data_set_ids: Optional[List[int]] = None,
        timeout: int = 15,
        json_filename: str = 'results.json') -> None:
    """
    Run the specified version of gsm on all the Willems data sets (by default, although a
    subset can be specified).  Accumulate and print the results.

    :param dir_name: The directory within which the Willems data set config files are stored.
    :param gsm_type: The type of the GSM model (e.g. spanning tree or clusters of commonality).
    :param results: Where to store the results of the experiments (cost and execution time).
        Must satisfy ``results["metadata"]["timeout"] == timeout`` and contain a
        ``results["data"]`` dict. Each solved data set is stored at
        ``results["data"][data_set_id][str(gsm_type)]``.
    :param inventory_holding_rates:  To modify cost.  If not specified for a data set id (key),
                                     assumed to be 1.
    :param data_set_ids:  A potentially restricted set of data set identifiers [1..38].  When
                          None (the default), all 38 data sets are run.
    :param timeout:  A number of seconds for ending a run of an algorithm on a dataset if taking
                     too long.
    :param json_filename: filename to save results dictionary. This is rewritten after each
        successfully solved dataset; pass None to skip saving.
    """
    # Default to every Willems data set. Materialize the ids so that an iterator is not
    # exhausted by the validity check below before the main loop gets to it.
    data_set_ids = list(range(1, 39)) if data_set_ids is None else list(data_set_ids)

    # Check data set ids are valid
    assert all(0 < data_set_id < 39 for data_set_id in data_set_ids)
    assert results["metadata"][
        "timeout"] == timeout, 'cannot mix results with diff timeouts'

    for data_set_id in data_set_ids:
        data_set_filename = get_willems_filename(data_set_id)
        try:
            execution_start_time = datetime.utcnow()

            _, gsm = create_gsm_instance(
                gsm_type, os.path.join(dir_name, data_set_filename))

            signal.signal(signal.SIGALRM, handler)
            signal.alarm(timeout)
            try:
                # TODO: could use run_gsm_with_timeout
                try:
                    sol = gsm.find_optimal_solution()  # type: GSM_Solution
                finally:
                    # Cancel the timeout on every exit path. If the solve raises something other
                    # than a timeout, a still-armed alarm would otherwise fire later, e.g. while
                    # the next data set is loading, outside the GSMTimeOutException handler.
                    signal.alarm(0)

                sol_s = sol.serialize()

                solution_cost = sol_s["cost"] * inventory_holding_rates.get(
                    data_set_id, 1)

                execution_time = (datetime.utcnow() -
                                  execution_start_time).total_seconds()

                print("In {:6} seconds, the {} version of GSM computes a "
                      "total cost of {} for {}".format(
                          round(execution_time, 2), gsm_type, solution_cost,
                          data_set_filename))

                # NOTE(review): data_set_id is an int key here, but json.dump writes keys as
                # strings. If ``results`` was reloaded from an earlier JSON file, the same data
                # set may end up under both "N" and N. Confirm how callers build ``results``.
                results["data"].setdefault(data_set_id, {})["{}".format(gsm_type)] = \
                    {"execution_time": execution_time,
                     "solution_cost": solution_cost}

                if json_filename is not None:
                    with open(json_filename, 'w') as fp:
                        json.dump(results, fp)

            except GSMTimeOutException:
                print("The {} version of GSM timed out for {}".format(
                    gsm_type, data_set_filename))

        except UnSupportedGSMException:
            print("Skipping all files as model type not supported")
            break
        except IncompatibleGraphTopology:
            print("Skipping {} as not a compatible topology".format(
                data_set_filename))
        except InconsistentGSMConfiguration:
            print(
                "Skipping {} as network topology labels are not as expected.".
                format(data_set_filename))
def test_optimal_solution_invariance_to_labeling():
    """Check (via ``check_labeling_invariance``) that relabeling bulldozer stages keeps the optimum."""
    bulldozer_path = os.path.join(dirname, "bulldozer.txt")
    _, bulldozer_gsm = create_gsm_instance(GSM.Tree, bulldozer_path)

    check_labeling_invariance(gsm_obj=bulldozer_gsm)