コード例 #1
0
ファイル: file_oracle.py プロジェクト: Philipp15b/PrIC3
def _load_oracle_dict(state_graph: StateGraph,
                      filename="oracle.pr") -> Dict[StateId, z3.ExprRef]:
    """
    Load an oracle's dict from a file.

    The file holds a pickled mapping from the string form of a state
    valuation to that state's oracle value. State ids are assigned anew on
    every exploration. So the mapping is re-keyed by exploring the state graph
    level by level from the initial state, and matching each visited state's
    valuation string against the loaded data.

    Note that the state ids are different on each load!

    :param state_graph: The state graph whose states the oracle is for.
    :param filename: Path of the pickled oracle file.
    :return: A dict from state ids to oracle values (as z3 reals). States
             whose valuation does not occur in the file are absent.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load oracle
    # files that come from a trusted source.
    with open(filename, "rb") as fd:
        data: _SerializedData = pickle.load(fd)

    oracle: Dict[StateId, z3.ExprRef] = dict()

    to_consider: Set[StateId] = {state_graph.get_initial_state_id()}
    seen: Set[StateId] = set()

    # Stop once every stored value has been matched, or once no unexplored
    # reachable state is left. Otherwise, entries without a reachable state
    # would keep this loop running forever.
    while len(data) != 0 and to_consider:
        to_add: Set[StateId] = set()

        for state_id in to_consider:
            seen.add(state_id)

            state_val_str = str(state_graph.get_state_valuation(state_id))
            # The loaded data is keyed by valuation string, not by state id.
            if state_val_str in data:
                oracle[state_id] = RealVal(data[state_val_str])
                del data[state_val_str]

            to_add.update(
                state
                for state, prob in state_graph.get_successor_distribution(
                    state_id))

        # Next level: successors that have not been visited yet.
        to_consider = to_add - seen

    return oracle
コード例 #2
0
    def __init__(self, state_graph: StateGraph, default_value: Fraction,
                 statistics: Statistics, settings: Settings,
                 model_type: PrismModelType):
        """
        Set up the oracle.

        :param state_graph: The state graph the oracle assigns values to.
        :param default_value: Non-negative value, stored as a z3 real in
            ``self.default_value``.
        :param statistics: Statistics collector.
        :param settings: Settings object.
        :param model_type: The PRISM model type; it selects which refinement
            method is bound to ``self.refine_oracle``.
        :raises ValueError: If ``default_value`` is negative.
        :raises Exception: If ``model_type`` is neither DTMC nor MDP.
        """
        self.state_graph, self.statistics = state_graph, statistics
        self.settings, self.model_type = settings, model_type

        if default_value < 0:
            raise ValueError("Oracle values must be greater or equal to 0")
        self.default_value = RealVal(default_value)

        self.solver = Solver()
        self.solver_mdp = Optimize()

        # How the oracle is refined depends on the model type.
        refiners = (
            (PrismModelType.DTMC, self.refine_oracle_mc),
            (PrismModelType.MDP, self.refine_oracle_mdp),
        )
        for kind, refine in refiners:
            if model_type == kind:
                self.refine_oracle = refine
                break
        else:
            raise Exception("Oracle: Unsupported model type")

        self.oracle_states: Set[StateId] = set()
        self.oracle: Dict[StateId, z3.ExprRef] = dict()
コード例 #3
0
def z3_real_floored_division(divisor, dividend):
    """
    Return ``floor(divisor / dividend)`` as a z3 real value.

    Note that the parameter names are swapped relative to the usual
    terminology: ``divisor`` is the numerator and ``dividend`` the denominator.
    The quotient is obtained by asking the module-level ``compare_solver``
    for a model of ``real_div_res == divisor / dividend``.

    :raises RuntimeError: If the solver answers ``unsat`` or ``unknown``.
    """
    res = compare_solver.check(real_div_res == divisor / dividend)

    if res == unsat or res == unknown:
        # Raise explicitly: there is no active exception that a bare
        # ``raise`` could re-raise here.
        raise RuntimeError(
            "Could not compute floored division of %s by %s (solver: %s)" %
            (divisor, dividend, res))

    quotient = compare_solver.model()[real_div_res].as_fraction()
    return RealVal(math.floor(quotient))
コード例 #4
0
    def _ensure_value_in_oracle(self, state_id):
        """
        Hook that makes sure ``self.oracle`` has an entry for ``state_id``.

        It is meant to be overridden to change the standard behaviour.
        get_oracle_value(state) invokes it when ``state`` is not an oracle
        state. This implementation stores ``settings.default_oracle_value``
        as a z3 real.

        :param state_id: Id of the state that needs an oracle value.
        :return: None
        """
        # Design choice: fall back to the configured default value.
        default = self.settings.default_oracle_value
        self.oracle[state_id] = RealVal(default)
コード例 #5
0
    def __init__(self, state_graph, statistics, settings, model_type):
        """
        Set up the oracle, the obligation cache and the optimization solver.

        :param state_graph: The state graph to work on.
        :param statistics: Statistics collector.
        :param settings: Settings; handed to ``_initialize_oracle`` and queried
            for the obligation queue class.
        :param model_type: The model type.
        """
        self.state_graph, self.statistics, self.model_type = (
            state_graph, statistics, model_type)

        logger.debug("Initialize oracle...")
        self._initialize_oracle(settings)

        logger.debug("Initialize obligation cache...")
        self._obligation_cache = ObligationCache()

        # Solver for optimization queries
        logger.debug("Initialize optimization solver...")
        self.opt_solver = Optimize()

        # Constants used in the probability bounds, created once.
        self._realval_zero, self._realval_one = RealVal(0), RealVal(1)

        self.obligation_queue_class = settings.get_obligation_queue_class()
コード例 #6
0
    def _create_inexact_arithmetic_oracle(self, states):
        """
        Compute oracle values for ``states`` with floating-point linear algebra.

        For every state ``s`` in ``states`` the equation

            x_s = sum(p for goal successors) + sum(p * x_t for successors t in states)

        is set up. Non-goal successors outside ``states`` contribute nothing.
        The system is solved with ``numpy.linalg.solve``, and the (inexact)
        solution is stored in ``self.oracle`` as z3 reals.

        :param states: Sized, re-iterable collection of state ids.
        :raises numpy.linalg.LinAlgError: If the equation system is singular.
        """
        num_states = len(states)

        # Numpy needs a coefficient matrix: give every state a row/column index.
        state_to_index = {state: index for index, state in enumerate(states)}

        rhs_list = []
        coeff_list = []  # One coefficient row per state

        for row_index, state in enumerate(states):
            # Query the successors once; they feed both sides of the equation.
            succ_dist = self.state_graph.get_filtered_successors(state)

            row = [0] * num_states
            # x_s itself, moved to the left-hand side.
            row[row_index] = -1
            goal_prob = 0
            for succ_id, prob in succ_dist:
                if succ_id == -1:
                    # Mass leading directly to a goal state forms the rhs.
                    goal_prob += float(prob)
                else:
                    # Dict lookup instead of scanning ``states`` per successor.
                    col = state_to_index.get(succ_id)
                    if col is not None:
                        row[col] += float(prob)

            coeff_list.append(row)
            # The rhs is -(sum of probs leading to a goal state).
            rhs_list.append(-goal_prob)

        solution = np.linalg.solve(np.array(coeff_list), np.array(rhs_list))

        # Update oracle
        for state in states:
            self.oracle[state] = RealVal(solution[state_to_index[state]])
コード例 #7
0
    def refine_oracle_mc(self, visited_states: Set[StateId]) -> Set[StateId]:
        """
        Refine the oracle of a Markov chain by solving an equation system.

        The oracle states grow by ``visited_states``. If that adds nothing new,
        they grow instead by all non-goal successors of the current oracle
        states. For every oracle state ``s`` a variable ``x_s >= 0`` with

            x_s = sum over successors t of p(t) * v(t)

        is asserted. Here ``v(t)`` is 1 for the goal (-1), ``x_t`` for oracle
        states and ``get_oracle_value(t)`` otherwise. If the system is
        satisfiable, the model's values become the oracle values. Otherwise
        the refine counter is decremented, and the call is delegated to
        ``refine_oracle_mdp``, which increments it again.

        :param visited_states: States visited since the last refinement.
        :return: The updated set of oracle states.
        """
        self.statistics.inc_refine_oracle_counter()

        # Guarantee progress.
        if not visited_states <= self.oracle_states:
            self.oracle_states = self.oracle_states.union(visited_states)
        else:
            # Nothing new was visited: add all non-target successors of the
            # oracle states instead.
            successors = {
                succ_id
                for state_id in self.oracle_states
                for succ_id, _ in self.state_graph.get_filtered_successors(
                    state_id)
                if succ_id != -1
            }
            self.oracle_states = self.oracle_states.union(successors)

        # TODO: A lot of optimization potential
        self.solver.push()

        # One variable per oracle state.
        variables = {
            state_id: Real("x_%s" % state_id)
            for state_id in self.oracle_states
        }

        def weighted_value(succ_id, prob):
            # Contribution of one successor, weighted by its probability.
            if succ_id == -1:  # target state
                return RealVal(1) * prob
            if succ_id in self.oracle_states:  # oracle state
                return variables[succ_id] * prob
            return self.get_oracle_value(succ_id) * prob  # any other state

        # Set up the equation system.
        for state_id in self.oracle_states:
            terms = [
                weighted_value(succ_id, prob) for succ_id, prob in
                self.state_graph.get_filtered_successors(state_id)
            ]
            self.solver.add(variables[state_id] == Sum(terms))
            self.solver.add(variables[state_id] >= RealVal(0))

        if self.solver.check() == sat:
            model = self.solver.model()
            for state_id in self.oracle_states:
                self.oracle[state_id] = model[variables[state_id]]

            logger.info("Refined oracle.")
            self.solver.pop()
            return self.oracle_states

        # The equation system is unsat: solve the LP instead. Undo this call's
        # counter increment, since the LP refinement counts itself.
        self.solver.pop()
        self.statistics.refine_oracle_counter -= 1
        return self.refine_oracle_mdp(visited_states)
コード例 #8
0
    def refine_oracle_mdp(self, visited_states: Set[StateId]) -> Set[StateId]:
        """
        Refine the oracle by solving a linear program over all choices.

        The oracle states grow by ``visited_states``. If that adds nothing new,
        they grow instead by the non-goal successors (under every choice) of
        the current oracle states. For every oracle state ``s`` and every
        choice of ``s`` the constraint

            x_s >= sum over successors t of p(t) * v(t)

        is added, plus ``x_s >= 0``. Here ``v(t)`` is 1 for the goal (-1),
        ``x_t`` for oracle states and ``get_oracle_value(t)`` otherwise. Only
        the initial state's variable is minimized; the other variables just
        have to satisfy the constraints. The optimizer's scope is always
        popped again, also when an error is raised.

        :param visited_states: States visited since the last refinement.
        :return: The updated set of oracle states.
        :raises RuntimeError: If the optimizer does not report ``sat``.
        """
        self.statistics.inc_refine_oracle_counter()
        # First ensure progress
        if visited_states <= self.oracle_states:
            # Add all non-target successors of the oracle states (for every
            # action).
            self.oracle_states = self.oracle_states.union({
                succ[0]
                for state_id in self.oracle_states for choice in
                self.state_graph.get_successors_filtered(state_id).choices
                for succ in choice.distribution if succ[0] != -1
            })
        else:
            self.oracle_states = self.oracle_states.union(visited_states)

        # TODO: A lot of optimization potential
        self.solver_mdp.push()
        try:
            # One variable per oracle state.
            variables = {
                state_id: Real("x_%s" % state_id)
                for state_id in self.oracle_states
            }

            # Set up the constraint system.
            for state_id in self.oracle_states:
                for choice in self.state_graph.get_successors_filtered(
                        state_id).choices:
                    self.solver_mdp.add(variables[state_id] >= Sum([
                        RealVal(1) * prob if succ_id == -1  # target state
                        else variables[succ_id] * prob
                        if succ_id in self.oracle_states  # oracle state
                        else self.get_oracle_value(succ_id) * prob  # other
                        for succ_id, prob in choice.distribution
                    ]))

                self.solver_mdp.add(variables[state_id] >= RealVal(0))

            # Minimize value for initial state
            self.solver_mdp.minimize(
                variables[self.state_graph.get_initial_state_id()])

            if self.solver_mdp.check() != sat:
                logger.error("Oracle solver unsat")
                raise RuntimeError("Oracle solver inconsistent.")

            m = self.solver_mdp.model()

            # update oracle
            for state_id in self.oracle_states:
                self.oracle[state_id] = m[variables[state_id]]

            logger.info("Refined oracle.")
        finally:
            # Restore the optimizer's scope on every path, including errors.
            self.solver_mdp.pop()

        return self.oracle_states
コード例 #9
0
ファイル: input_program.py プロジェクト: Philipp15b/PrIC3
    def translate_expression(self, variables,
                             prism_expr: stormpy.Expression) -> z3.ExprRef:
        """
        Translate a stormpy (PRISM) expression into a z3 expression.

        :param variables: Indexable by PRISM identifier; its entries have a
            ``variable`` attribute holding the z3 variable. An entry that is
            falsy is treated as "not an environment variable".
        :param prism_expr: The expression to translate. Plain Python ``bool``
            and ``int`` values are accepted as well.
        :return: The z3 expression. A division (only supported between integer
            constants) is returned as an exact ``fractions.Fraction``.
        :raises NotImplementedError: For expressions that cannot be translated.
        """

        def int_val(value):
            # z3 integer literals are cached per value.
            cached = self._int_cache.get(value)
            if cached is None:
                cached = INTVAL_CTOR(value)
                self._int_cache[value] = cached
            return cached

        # Translate primitive types.
        # TODO: missing types!
        if isinstance(prism_expr, bool):
            # Note: the order of comparing bool and int types is significant,
            # as bool is a subtype of int in Python.
            return BoolVal(prism_expr)
        if isinstance(prism_expr, int):
            return int_val(prism_expr)
        # If we have a variable from the environment, look up its z3 variable.
        if prism_expr.is_variable() and variables[prism_expr.identifier()]:
            return variables[prism_expr.identifier()].variable
        # Otherwise, evaluate the value if it is a variable or a literal.
        if prism_expr.is_variable() or prism_expr.is_literal():
            if prism_expr.has_boolean_type():
                return BoolVal(prism_expr.evaluate_as_bool())
            if prism_expr.has_integer_type():
                return int_val(prism_expr.evaluate_as_int())
            if prism_expr.has_rational_type():
                rational_value = prism_expr.evaluate_as_rational()
                return RealVal(
                    Fraction(Fraction(str(rational_value.numerator)),
                             Fraction(str(rational_value.denominator))))
            raise NotImplementedError()
        # Special case for division: intentionally only supported for constant
        # values.
        if (prism_expr.is_function_application
                and prism_expr.operator == stormpy.OperatorType.Divide):
            return Fraction(
                Fraction(prism_expr.get_operand(0).evaluate_as_int()),
                Fraction(prism_expr.get_operand(1).evaluate_as_int()))
        # Lastly, handle function applications.
        if prism_expr.is_function_application:
            sot = stormpy.OperatorType
            operators = {
                sot.And: And,
                sot.Or: Or,
                # Boolean connectives must use z3's functions: z3 Booleans do
                # not implement the arithmetic/bitwise Python operators.
                sot.Xor: z3.Xor,
                sot.Implies: z3.Implies,
                sot.Iff: operator.eq,
                sot.Plus: operator.add,
                sot.Minus: operator.sub,
                sot.Times: operator.mul,
                # sot.Divide is handled above
                # sot.Min: z3.Min,
                # sot.Max: z3.Max,
                # TODO: Power, Modulo missing
                sot.Equal: operator.eq,
                sot.NotEqual: operator.ne,
                sot.Less: operator.lt,
                sot.LessOrEqual: operator.le,
                sot.Greater: operator.gt,
                sot.GreaterOrEqual: operator.ge,
                # Logical negation (operator.neg would be arithmetic negation).
                sot.Not: z3.Not,
                # TODO: Floor and Ceil missing
                sot.Ite: z3.If,
            }
            op_fn = operators[prism_expr.operator]

            # Recurse through this method so the int cache is shared.
            operands = (self.translate_expression(variables,
                                                  prism_expr.get_operand(i))
                        for i in range(prism_expr.arity))

            return op_fn(*operands)

        raise NotImplementedError()
コード例 #10
0
    def run(self, state_id, chosen_command, delta, states_with_fixed_probabilities=None):
        """
        Find probabilities for the successors of ``state_id`` that realise ``delta``.

        The method looks for values ``x_t`` in [0, 1] for every non-goal
        successor ``t`` of ``state_id`` under ``chosen_command`` such that

            sum over successors t of p(t) * (x_t, or 1 for the goal) == delta.

        Successors in ``states_with_fixed_probabilities`` are pinned to
        ``obligation_queue_class.smallest_probability_for_state[t]``.

        With several non-goal successors:

        - If all their oracle values are 0, the mass is distributed equally.
        - Otherwise ``_get_probabilities_by_solving_eq_system`` is tried first.
        - Failing that, the method minimises the sum over ``t`` of
          ``|x_t * sum(oracle) - oracle(t) * sum(x)|``.

        Results are cached per ``(state_id, chosen_command, delta)``.

        :param state_id: The state whose successors get probabilities.
        :param chosen_command: Index of the command whose distribution is used.
        :param delta: The required value of the weighted sum.
        :param states_with_fixed_probabilities: Optional collection of successor
            ids whose probability is fixed as described above (default: none).
        :return: (1) True iff it is possible to find probabilities for the successors of the given state_id and delta.
                 (2) If True, a dict from succ_ids to probabilities. This dict does not contain goal states.
                     Otherwise None.
        """
        # TODO consider changing to None if not possible, and dict otherwise.
        if states_with_fixed_probabilities is None:
            states_with_fixed_probabilities = set()

        self.statistics.inc_get_probability_counter()
        self.statistics.start_get_probability_timer()
        try:
            # First check whether we have cached the corresponding obligation
            # (a result of False is treated as a cache miss).
            res = self._obligation_cache.get_cached(state_id, chosen_command, delta)
            if res != False:
                return (True, res)

            # If not, we have to ask the SMT-Solver
            succ_dist = self.state_graph.get_successors_filtered(state_id).by_command_index(chosen_command)

            non_goal_succs = [(succ_id, prob)
                              for (succ_id, prob) in succ_dist
                              if succ_id != -1]

            # Without a non-target successor repairing is impossible (the SMT
            # solver would return unsat), so checking this is an optimization.
            if not non_goal_succs:
                return (False, None)

            self.opt_solver.push()
            prob_vars = {}

            # We need a variable for each successor; all of them are probabilities.
            for (succ_id, prob) in succ_dist:
                if succ_id != -1:
                    prob_vars[succ_id] = Real("x_%s" % succ_id)
                    self.opt_solver.add(prob_vars[succ_id] >= self._realval_zero)
                    self.opt_solver.add(prob_vars[succ_id] <= self._realval_one)

            # \Phi(F)[s] = delta constraint; target successors count with value 1.
            # TODO: Type of prob is pycarl.gmp.gmp.Rational. Z3 magically deals with this
            self.opt_solver.add(
                Sum([
                    (prob_vars[succ_id] if succ_id != -1 else RealVal(1)) * prob
                    for (succ_id, prob) in succ_dist
                ]) == delta)

            for (succ_id, prob) in succ_dist:
                if succ_id in states_with_fixed_probabilities:
                    self.opt_solver.add(
                        prob_vars[succ_id] ==
                        self.obligation_queue_class.smallest_probability_for_state[succ_id])

            # With more than one non-target successor we have to optimize.
            if len(non_goal_succs) > 1:
                oracle_total = sum(
                    self.oracle.get_oracle_value(succ_id).as_fraction()
                    for succ_id, _ in non_goal_succs)

                if oracle_total == 0:
                    # All oracle values are 0: distribute the mass equally.
                    for (left, _), (right, _) in zip(non_goal_succs, non_goal_succs[1:]):
                        self.opt_solver.add(prob_vars[left] == prob_vars[right])

                # First try to solve the eq system
                # TODO: Do not use opt_solver for this
                elif self._get_probabilities_by_solving_eq_system(non_goal_succs, prob_vars):
                    self.statistics.inc_solved_eq_system_instead_of_optimization_counter()
                    m = self.opt_solver.model()

                    result = {
                        succ_id: m[prob_vars[succ_id]]
                        for (succ_id, _) in non_goal_succs
                    }

                    # Because get_probabilities_by_solving_eq_system pushes,
                    # two scopes are popped here.
                    # TODO: This is ugly
                    # TODO: Compare solve-eq-system-time with optimization-problem-time
                    self.opt_solver.pop()
                    self.opt_solver.pop()

                    self._obligation_cache.cache(state_id, chosen_command, delta, result)
                    return (True, result)

                else:
                    self.statistics.inc_had_to_solve_optimization_problem_counter()

                    # These sums are the same for every successor; build them once.
                    oracle_sum = Sum([
                        self.oracle.get_oracle_value(succ_id)
                        for (succ_id, _) in non_goal_succs
                    ])
                    var_sum = Sum([prob_vars[succ_id] for (succ_id, _) in non_goal_succs])

                    # For every non-target successor, we need an optimization variable
                    opt_vars = {
                        succ_id: Real("opt_var_%s" % succ_id)
                        for (succ_id, _) in non_goal_succs
                    }

                    # opt_var_i = | x_i * (o_1 + ... + o_n) - o_i * (x_1 + ... + x_n) |
                    # with o_j = oracle(s_j): the cross-multiplied difference of
                    # the ratios x_i / sum(x) and o_i / sum(o).
                    for (succ_id, _) in non_goal_succs:
                        own_oracle = self.oracle.get_oracle_value(succ_id)
                        diff = (prob_vars[succ_id] * oracle_sum) - (own_oracle * var_sum)
                        neg_diff = (own_oracle * var_sum) - (prob_vars[succ_id] * oracle_sum)
                        self.opt_solver.add(
                            If(diff < 0,
                               opt_vars[succ_id] == neg_diff,
                               opt_vars[succ_id] == diff))

                    # Minimize the sum of opt-vars (a single objective, added
                    # once after all opt-vars are constrained).
                    self.opt_solver.minimize(
                        Sum([opt_vars[succ_id] for (succ_id, _) in non_goal_succs]))

            if self.opt_solver.check() == sat:
                # We found probabilities for the successors
                m = self.opt_solver.model()

                result = {
                    succ_id: m[prob_vars[succ_id]]
                    for (succ_id, _) in non_goal_succs
                }
                self.opt_solver.pop()

                self._obligation_cache.cache(state_id, chosen_command, delta, result)
                return (True, result)

            # There are no such probabilities
            self.opt_solver.pop()
            return (False, None)
        finally:
            # Stop the timer on every exit path, including exceptions.
            self.statistics.stop_get_probability_timer()
コード例 #11
0
ファイル: conv_safety_solve.py プロジェクト: gwgundersen/DLV
def getDecimalValue(v0):
    """
    Convert ``v0`` to a Python float by way of an exact z3 rational.

    ``str(v0)`` is parsed by z3's ``RealVal``. The resulting rational's
    numerator and denominator are combined into a ``Fraction`` and converted
    to float. That conversion is correctly rounded and, unlike converting the
    numerator to float first, does not overflow or lose precision when the
    numerator or denominator exceeds the float range.

    :param v0: A value whose string form is a numeral z3 accepts as a real.
    :return: The value as a float.
    """
    from fractions import Fraction

    v = RealVal(str(v0))
    return float(Fraction(v.numerator_as_long(), v.denominator_as_long()))
コード例 #12
0
    def get_generalization_linear_functions(self, frame_id, state_id, low_val,
                                            low_delta, high_val, high_delta,
                                            input_variable):
        """
        Try to generalize a constraint by a linear function in ``input_variable``.

        The states considered agree with ``state_id`` on every variable except
        ``input_variable``, and have ``low_val <= input_variable <= upper``.
        A linear function through ``(low_val, low_delta)`` and
        ``(upper, upper_delta)`` is interpolated and checked for relative
        inductiveness to frame ``frame_id``. Initially ``upper`` is
        ``high_val``. If the check fails, the counterexample's (CTG's) value
        of ``input_variable`` becomes the new ``upper``, with its approximated
        delta, and the check is repeated at most ``settings.max_num_ctgs``
        times.

        :return: ``[(low_val, low_delta, upper, upper_delta, linear_function)]``
            for the first relatively inductive function that also passes
            ``is_poly_probability`` on ``[low_val, upper]``; otherwise ``[]``.
        """
        data_points = [(low_val, low_delta), (high_val, high_delta)]

        if z3_values_check_eq(low_val, high_val):
            return []

        state_valuation = self.state_graph.get_state_valuation(state_id)

        def range_args(upper_val):
            # Fix every other variable to its value in the state, and restrict
            # input_variable to [low_val, upper_val].
            return [
                eq_no_coerce(var.variable, val)
                if var != input_variable else ge_no_coerce(var.variable, low_val)
                for var, val in state_valuation.items()
            ] + [input_variable.variable <= upper_val]

        def accept(upper_val, upper_delta, poly):
            # The generalization only covers [low_val, upper_val], so the
            # polynomial has to be checked on exactly that interval.
            if self.is_poly_probability(poly, low_val, upper_val,
                                        input_variable.variable):
                return [(low_val, low_delta, upper_val, upper_delta, poly)]
            return []

        linear_function = self._interpolator.get_interpolating_polynomial(
            data_points, input_variable.variable)

        rel_ind_result = self.p_solver.is_relative_inductive(
            frame_id, range_args(high_val), linear_function)

        if rel_ind_result == True:
            return accept(high_val, high_delta, linear_function)

        for _ in range(self.settings.max_num_ctgs):
            # Interpolate from low_val to the CTG's value, but only if the two
            # differ.
            ctg_val = rel_ind_result[input_variable.variable]

            ctg_delta = self.approximate_phi_value_for_state(
                frame_id, [
                    eq_no_coerce(var.variable, val)
                    if var != input_variable else eq_no_coerce(
                        var.variable, ctg_val)
                    for var, val in state_valuation.items()
                ])

            if z3_values_check_eq(ctg_delta, RealVal(1)):
                return []

            if z3_values_check_eq(ctg_val, low_val):
                return []

            # Replace the upper interpolation point by the CTG.
            data_points[-1] = (ctg_val, ctg_delta)

            linear_function = self._interpolator.get_interpolating_polynomial(
                data_points, input_variable.variable)

            rel_ind_result = self.p_solver.is_relative_inductive(
                frame_id, range_args(ctg_val), linear_function)

            if rel_ind_result == True:
                return accept(ctg_val, ctg_delta, linear_function)

        return []
コード例 #13
0
    def get_generalization_for_variable(self, frame_index, state_id,
                                        start_value, start_delta,
                                        input_variable):
        """
        Try to generalize the constraint for ``state_id`` by dropping ``input_variable``.

        The constraint to generalize is

               state_vars = state_valuation(state_id)      =>      Frame <= start_delta

        with ``input_variable`` set to ``start_value``. A generalization is
        attempted only under the following conditions:

        - ``start_delta`` is not 1;
        - a state of the same kind with a different value of
          ``input_variable`` exists;
        - ``start_value`` lies below the variable's upper bound;
        - ``is_generalization_possible`` agrees.

        In that case the list produced by ``settings.get_generalization_method()``
        is returned.

        :param frame_index: Index of the frame the generalization has to be
            relatively inductive to.
        :param state_id: The ID of the state the constraint is talking about.
        :param start_value: The value of ``input_variable`` in that constraint.
        :param start_delta: The delta (probability) of that constraint.
        :param input_variable: The variable that is to be "dropped".
        :return: The generalization method's result. Without a generalization,
            the single pair ``(state_args, start_delta)``, whose state_args
            describe exactly the original state.
        """
        state_valuation = self.state_graph.get_state_valuation(state_id)

        def pinned_state_args(value):
            # Every variable fixed to its value in the state, except that
            # input_variable is fixed to ``value``.
            return [
                eq_no_coerce(var.variable,
                             val if var != input_variable else value)
                for var, val in state_valuation.items()
            ]

        def no_generalization():
            return [(pinned_state_args(start_value), start_delta)]

        # We do not generalize if delta = 1.
        if start_delta == RealVal(1):
            return no_generalization()

        # Generalize only if there is a state of the same kind that differs
        # from the current one in input_variable.
        same_kind_id = self._state_of_the_same_kind_cache.get_first_state_of_this_kind(
            state_id, input_variable)
        if same_kind_id == -1:
            return no_generalization()

        same_kind_valuation = self.state_graph.get_state_valuation(
            same_kind_id)[input_variable]
        if not z3_values_check_neq(same_kind_valuation,
                                   state_valuation[input_variable]):
            return no_generalization()

        end_value = input_variable.upper_bound
        end_delta = self.approximate_phi_value_for_state(
            frame_index, pinned_state_args(end_value))

        # Continue only if start_value is below the variable's max value.
        if z3_values_check_geq(start_value, end_value):
            return no_generalization()

        state_args = [
            ge_no_coerce(var.variable, start_value) if var == input_variable
            else eq_no_coerce(var.variable, val)
            for var, val in state_valuation.items()
        ] + [input_variable.variable <= end_value]

        if not self.is_generalization_possible(frame_index, state_args,
                                               input_variable, start_value,
                                               start_delta):
            return no_generalization()

        generalize = self.settings.get_generalization_method()
        # The method's result is returned whether or not it reports success.
        _, generalization_result = generalize(
            self, frame_index, state_id, state_valuation, start_value,
            start_delta, end_value, end_delta, input_variable)
        return generalization_result
コード例 #14
0
 def to_hit_probability_dict(self) -> Dict[StateId, ExprRef]:
     """
     Return the hit probability for each state.

     For every ``(hits, visits)`` entry of ``self._state_stats``, the ratio
     ``hits / visits`` is returned as a z3 real, keyed by the state.
     """
     probabilities = {}
     for state, counts in self._state_stats.items():
         hits, visits = counts
         probabilities[state] = RealVal(hits / visits)
     return probabilities