def get_prob_distrib(self, condition):
    """
    Builds the categorical table for the given condition.

    The Effect values contained in the condition are merged into a single
    effect, and the values this effect assigns to the base variable are
    turned into a table over the variable ``self._base_var + self._primes``:

    - if the merged effect is non-exclusive for the base variable, a single
      value created from the list of all effect values gets probability 1.0;
    - otherwise, if there are values, each one gets its weight divided by
      the total weight;
    - if there is no value at all, the none value gets probability 1.0.

    :param condition: the conditional assignment (only its Effect values are used)
    :return: the built categorical table
    """
    builder = CategoricalTableBuilder(self._base_var + self._primes)

    # collect the sub-effects of every Effect found among the condition values
    full_effects = []
    for input_val in condition.get_values():
        if isinstance(input_val, Effect):
            full_effects.extend(input_val.get_sub_effects())
    full_effect = Effect(full_effects)

    values = full_effect.get_values(self._base_var)
    if full_effect.is_non_exclusive(self._base_var):
        add_val = ValueFactory.create(list(values.keys()))
        builder.add_row(add_val, 1.0)
    elif values:
        # normalise the weights so that they sum to one
        total = sum(float(weight) for weight in values.values())
        for value, weight in values.items():
            builder.add_row(value, weight / total)
    else:
        builder.add_row(ValueFactory.none(), 1.0)
    return builder.build()
def __init__(self, arg1=None, arg2=None, arg3=1, arg4=True, arg5=False):
    """
    Constructs a new template effect, with a variable label template, a
    value template, and other arguments.

    :param arg1: the variable label (a Template)
    :param arg2: the variable value (a Template)
    :param arg3: the priority level (an int, default is 1)
    :param arg4: whether distinct values are mutually exclusive or not
        (default True)
    :param arg5: whether to negate the effect or not (default False)
    :raises NotImplementedError: if the arguments do not match the signature
        (Template, Template, int, bool, bool)
    """
    if not (isinstance(arg1, Template) and isinstance(arg2, Template)
            and isinstance(arg3, int) and isinstance(arg4, bool)
            and isinstance(arg5, bool)):
        raise NotImplementedError()

    variable, value = arg1, arg2
    priority, exclusive, negated = arg3, arg4, arg5
    # an under-specified value template is passed to the base effect as none
    super(TemplateEffect, self).__init__(
        str(variable),
        ValueFactory.none() if value.is_under_specified()
        else ValueFactory.create(str(value)),
        priority, exclusive, negated)
    self._label_template = variable
    self._value_template = value
def _get_gaussian(node): """ Extracts the gaussian density function described by the XML specification :param node: the XML node :return: the corresponding Gaussian PDF properly encoded """ mean = None variance = None for child_node in node: str_value = child_node.text.strip() if str_value[:1] == '[': value = ValueFactory.create(str_value) else: value = ValueFactory.create("[%s]" % str_value) if child_node.tag == 'mean': mean = value.get_array() elif child_node.tag == 'variance': variance = value.get_array() else: raise ValueError() if mean is None or variance is None: raise ValueError() return GaussianDensityFunction(mean, variance)
def test_dep_empirical_distrib_continuous(self):
    """Checks the CDF of a continuous variable conditioned on a discrete one."""
    bn = BNetwork()

    prior = CategoricalTableBuilder("var1")
    prior.add_row(ValueFactory.create("one"), 0.7)
    prior.add_row(ValueFactory.create("two"), 0.3)
    var1 = ChanceNode("var1", prior.build())
    bn.add_node(var1)

    uniform = ContinuousDistribution("var2", UniformDensityFunction(-1.0, 3.0))
    gaussian = ContinuousDistribution("var2", GaussianDensityFunction(3.0, 10.0))
    table = ConditionalTable("var2")
    table.add_distrib(Assignment("var1", "one"), uniform)
    table.add_distrib(Assignment("var1", "two"), gaussian)
    var2 = ChanceNode("var2", table)
    var2.add_input_node(var1)
    bn.add_node(var2)

    checks = InferenceChecks()
    expected_cdfs = ((-1.5, 0.021), (0., 0.22), (2., 0.632), (8., 0.98))
    for threshold, expected in expected_cdfs:
        checks.check_cdf(bn, "var2", threshold, expected)
def fill_slots(self, assignment):
    """
    Fills the slots of the template, and returns the result of the function
    evaluation.

    - If the filled string still contains a '{' (an unfilled slot), it is
      returned unchanged.
    - If it is a simple arithmetic expression, it is evaluated and the short
      form of the result is returned. If evaluation fails, a warning is
      logged and the filled string is returned.
    - Otherwise the string is treated as a set manipulation, where '+' adds
      elements and '-' removes them, and the string form of the resulting
      value is returned.

    :param assignment: the assignment used to fill the slots
    :return: the resulting string
    """
    filled = super(ArithmeticTemplate, self).fill_slots(assignment)
    if '{' in filled:
        return filled

    if ArithmeticTemplate.is_arithmetic_expression(filled):
        try:
            return StringUtils.get_short_form(
                MathExpression(filled).evaluate())
        # TODO: need to check exception handling
        except Exception:
            # deliberate best-effort: fall back to the raw filled string
            self.log.warning("cannot evaluate " + filled)
            return filled

    # handling expressions that manipulate sets
    # (using + and - to respectively add/remove elements)
    merge = ValueFactory.none()
    for str_val in filled.split("+"):
        negations = str_val.split("-")
        merge = merge.concatenate(ValueFactory.create(negations[0]))
        for negation in negations[1:]:
            values = merge.get_sub_values()
            old_value = ValueFactory.create(negation)
            if old_value in values:
                values.remove(old_value)
            # NOTE(review): the original collapsed line leaves it ambiguous
            # whether this rebuild belonged inside the `if`; placed here to
            # mirror the upstream OpenDial Java implementation — confirm.
            merge = ValueFactory.create(values)
    return str(merge)
def test_build_basic_network(self):
    """Checks the structure and parameters of the basic example network."""
    bn = NetworkExamples.construct_basic_network()
    tolerance = 0.0001

    assert len(bn.get_nodes()) == 8

    burglary = bn.get_node("Burglary")
    alarm = bn.get_node("Alarm")
    assert len(burglary.get_output_nodes()) == 3
    assert len(alarm.get_output_nodes()) == 2
    assert len(alarm.get_input_nodes()) == 2
    assert len(bn.get_node("Util1").get_input_nodes()) == 2
    for binary_node in ("Burglary", "Alarm", "MaryCalls"):
        assert len(bn.get_node(binary_node).get_values()) == 2
    assert ValueFactory.create(True) in burglary.get_values()

    prior = bn.get_chance_node("Burglary").get_prob(ValueFactory.create(True))
    assert prior == pytest.approx(0.001, abs=tolerance)

    alarm_prob = bn.get_chance_node("Alarm").get_prob(
        Assignment(["Burglary", "Earthquake"]), ValueFactory.create(True))
    assert alarm_prob == pytest.approx(0.95, abs=tolerance)

    john_prob = bn.get_chance_node("JohnCalls").get_prob(
        Assignment("Alarm"), ValueFactory.create(True))
    assert john_prob == pytest.approx(0.9, abs=tolerance)

    assert len(bn.get_action_node("Action").get_values()) == 3

    util_input = Assignment(
        Assignment("Burglary"), "Action", ValueFactory.create("DoNothing"))
    assert bn.get_utility_node("Util2").get_utility(util_input) == pytest.approx(
        -10, abs=tolerance)
def sample(self):
    """
    Samples from the distribution.

    :return: the sampled value. For a density function with more than one
        dimension the whole sampled array is converted into a value;
        otherwise only the first (single) component is used.
    """
    # query the dimensionality first, keeping the original call order
    multi_dimensional = self._density_func.get_dimensions() > 1
    point = self._density_func.sample()
    if multi_dimensional:
        return ValueFactory.create(point)
    return ValueFactory.create(point[0])
def test_switching(self):
    """
    Checks that the switching algorithm picks exact inference (a
    MultivariateTable) for small networks and sampling (an
    EmpiricalDistribution) once the branching factor exceeds the limit or
    continuous nodes are involved.
    """
    old_factor = SwitchingAlgorithm.max_branching_factor
    SwitchingAlgorithm.max_branching_factor = 4
    # restore the class-level setting even if an assertion fails, so that
    # other tests do not run with the lowered branching factor
    try:
        network = NetworkExamples.construct_basic_network2()

        def query():
            # fresh query arguments on every call, as in the original test
            return SwitchingAlgorithm().query_prob(
                network, ["Burglary"], Assignment(["JohnCalls", "MaryCalls"]))

        distrib = query()
        assert isinstance(distrib, MultivariateTable)

        extra_nodes = []
        for node_id, label, prob in (("n1", "aha", 1.0),
                                     ("n2", "oho", 0.7),
                                     ("n3", "ihi", 0.7)):
            builder = CategoricalTableBuilder(node_id)
            builder.add_row(ValueFactory.create(label), prob)
            node = ChanceNode(node_id, builder.build())
            network.add_node(node)
            extra_nodes.append(node)
        for node in extra_nodes:
            network.get_node("Alarm").add_input_node(node)
        distrib = query()
        assert type(distrib) is EmpiricalDistribution

        network.remove_node(extra_nodes[0].get_id())
        network.remove_node(extra_nodes[1].get_id())
        distrib = query()
        assert isinstance(distrib, MultivariateTable)

        n1 = ChanceNode(
            "n1", ContinuousDistribution("n1", UniformDensityFunction(-2.0, 2.0)))
        n2 = ChanceNode(
            "n2", ContinuousDistribution("n2", GaussianDensityFunction(-1.0, 3.0)))
        network.add_node(n1)
        network.add_node(n2)
        network.get_node("Earthquake").add_input_node(n1)
        network.get_node("Earthquake").add_input_node(n2)
        distrib = query()
        assert isinstance(distrib, EmpiricalDistribution)
    finally:
        SwitchingAlgorithm.max_branching_factor = old_factor
def sample(self, condition):
    """
    Draws a boolean value given the conditional assignment.

    The true value is returned when the sampler's next double falls below
    P(condition) as computed by ``get_prob``; the false value otherwise.

    :param condition: the conditional assignment
    :return: the true or the false value
    """
    threshold = self.get_prob(condition)
    outcome = bool(self._sampler.next_double() < threshold)
    return ValueFactory.create(outcome)
def get_values(self):
    """
    Returns the two possible values of the variable.

    :return: a set holding the true value and the false value
    """
    return {ValueFactory.create(flag) for flag in (True, False)}
def test_assign(self):
    """Checks the parsing of an assignment from its string representation."""
    parsed = Assignment.create_from_string(
        'blabla=3 ^ !bloblo^TTT=32.4 ^v=[0.4,0.6] ^ final')
    expected = {
        'blabla': ValueFactory.create('3'),
        'bloblo': ValueFactory.create(False),
        'TTT': ValueFactory.create('32.4'),
        'v': ValueFactory.create([0.4, 0.6]),
        'final': ValueFactory.create(True),
    }
    assert len(parsed.get_variables()) == 5
    assert parsed.get_variables() == set(expected)
    for variable, value in expected.items():
        assert parsed.get_value(variable) == value
def construct_basic_network2():
    """
    Builds the basic example network, then replaces the priors of
    "Burglary" (P(True)=0.1) and "Earthquake" (P(True)=0.2).

    :return: the modified network
    """
    network = NetworkExamples.construct_basic_network()
    priors = (("Burglary", 0.1, 0.9), ("Earthquake", 0.2, 0.8))
    for variable, prob_true, prob_false in priors:
        table = CategoricalTableBuilder(variable)
        table.add_row(ValueFactory.create(True), prob_true)
        table.add_row(ValueFactory.create(False), prob_false)
        network.get_chance_node(variable).set_distrib(table.build())
    return network
def test_outputs(self):
    """Checks that effects built incrementally match their parsed form."""
    effects = []
    assert Effect(effects) == Effect.parse_effect("Void")

    steps = [
        (BasicEffect("v1", "val1"), "v1:=val1"),
        (BasicEffect("v2", ValueFactory.create("val2"), 1, False, False),
         "v1:=val1 ^ v2+=val2"),
        (BasicEffect("v2", ValueFactory.create("val3"), 1, True, True),
         "v1:=val1 ^ v2+=val2 ^ v2!=val3"),
    ]
    for new_effect, effect_str in steps:
        effects.append(new_effect)
        assert Effect(effects) == Effect.parse_effect(effect_str)
def generate_xml(self):
    """
    Generates the XML representation of the Gaussian density function.

    :return: a one-element list holding a <distrib type="gaussian"> element
        with <mean> and <variance> children. Arrays with more than one
        component are written via ValueFactory; single components are written
        in short form.
    """
    def _format(array):
        # multi-dimensional arrays become a value, scalars use the short form
        if len(array) > 1:
            return str(ValueFactory.create(array))
        return str(StringUtils.get_short_form(array[0]))

    distrib_element = Element('distrib')
    distrib_element.set('type', 'gaussian')
    for tag, array in (('mean', self._mean), ('variance', self._variance)):
        child = Element(tag)
        # Element.text is an attribute (None by default), not a method
        child.text = _format(array)
        distrib_element.append(child)
    return [distrib_element]
def add_pair(self, boolean_assignment):
    """
    Adds a boolean (var, value) pair described by a single string.

    A leading exclamation mark means the variable (without the mark) is set
    to False; otherwise the whole string is the variable and it is set to
    True.

    :param boolean_assignment: the pair to add, e.g. "var" or "!var"
    """
    # NOTE(review): Python does not overload methods. If this class also
    # defines add_pair(self, variable, value), only the later definition
    # survives; the two-argument call below relies on that other signature.
    negated = boolean_assignment.startswith("!")
    variable = boolean_assignment[1:] if negated else boolean_assignment
    self.add_pair(variable, ValueFactory.create(not negated))
def generate_xml(self):
    """
    Generates the XML representation of the uniform density function.

    :return: a one-element list holding a <distrib type="uniform"> element
        with <min> and <max> children
    """
    distrib_element = Element('distrib')
    distrib_element.set('type', 'uniform')
    for tag, bound in (('min', self._min_val), ('max', self._max_val)):
        bound_element = Element(tag)
        bound_element.text = str(ValueFactory.create(bound))
        distrib_element.append(bound_element)
    return [distrib_element]
def test_table_expansion(self):
    """Checks that adding an irrelevant parent leaves probabilities intact."""
    bn = NetworkExamples.construct_basic_network()
    tolerance = 0.0001

    sizes = CategoricalTableBuilder("HouseSize")
    for label, prob in (("Small", 0.7), ("Big", 0.2), ("None", 0.1)):
        sizes.add_row(ValueFactory.create(label), prob)
    house_size = ChanceNode("HouseSize", sizes.build())
    bn.add_node(house_size)

    bn.get_node("Burglary").add_input_node(house_size)
    burglary = bn.get_chance_node("Burglary")
    for size in ("Small", "Big"):
        prob = burglary.get_prob(
            Assignment(["HouseSize", size]), ValueFactory.create(True))
        assert prob == pytest.approx(0.001, abs=tolerance)

    bn.get_node("Alarm").add_input_node(house_size)
    alarm = bn.get_chance_node("Alarm")
    prob = alarm.get_prob(
        Assignment(["Burglary", "Earthquake"]), ValueFactory.create(True))
    assert prob == pytest.approx(0.95, abs=tolerance)
    with_size = Assignment(Assignment(["Burglary", "Earthquake"]),
                           "HouseSize", ValueFactory.create("None"))
    prob = alarm.get_prob(with_size, ValueFactory.create(True))
    assert prob == pytest.approx(0.95, abs=tolerance)
def create_from_string(assignments_str):
    """
    Creates an assignment from a string such as ``"a=1 ^ !b ^ c"``.

    Segments are separated by '^':

    - ``var=value`` assigns the value (everything after the FIRST '=', so
      values may themselves contain '=');
    - a segment containing '!' assigns False to the variable (all '!'
      characters removed);
    - a bare variable name assigns True.

    Blank segments (e.g. from an empty string or a trailing '^') are ignored.

    :param assignments_str: the string to parse
    :return: the corresponding assignment
    """
    assignment = Assignment()
    for segment in assignments_str.split('^'):
        if not segment.strip():
            # nothing to assign; previously this created a variable named ""
            continue
        if '=' in segment:
            variable, _, value = segment.partition('=')
            assignment.add_pair(variable.strip(),
                                ValueFactory.create(value.strip()))
        elif '!' in segment:
            variable = segment.replace('!', '').strip()
            assignment.add_pair(variable, ValueFactory.create(False))
        else:
            assignment.add_pair(segment.strip(), ValueFactory.create(True))
    return assignment
def test_pruning8(self):
    """
    Checks the distribution of "a_u3" and of the two nodes created for it,
    then restores the dialogue state.
    """
    system = TestPruning.system
    inference = TestPruning.inference
    initial_state = copy(system.get_state())
    # restore the shared state even when an assertion fails, so that the
    # other tests of the class are not affected
    try:
        created_nodes = SortedSet(
            node_id for node_id in system.get_state().get_node_ids()
            if "a_u3^" in node_id)
        assert len(created_nodes) == 2

        first, last = created_nodes[0], created_nodes[-1]
        first_values = system.get_state().get_node(first).get_values()
        if ValueFactory.create("Greet") in first_values:
            greet_node, howareyou_node = first, last
        else:
            greet_node, howareyou_node = last, first

        state = system.get_state()
        inference.check_prob(
            state, "a_u3", "[" + howareyou_node + "," + greet_node + "]", 0.7)
        inference.check_prob(state, "a_u3", "none", 0.1)
        inference.check_prob(state, "a_u3", "[" + howareyou_node + "]", 0.2)
        inference.check_prob(state, greet_node, "Greet", 0.7)
        inference.check_prob(state, howareyou_node, "HowAreYou", 0.9)
    finally:
        system.get_state().reset(initial_state)
def test_param_4(self):
    """Checks the learning of the single parameter theta_5 from utterances."""
    system = DialogueSystem(TestParameters.domain1)
    system.detach_module(ForwardPlanner)
    system.get_settings().show_gui = False
    system.start_system()

    rules = TestParameters.domain1.get_models()[1].get_rules()
    outputs = rules[0].get_output(Assignment("u_u", "my name is"))
    pierre_effect = Effect(BasicEffect("u_u^p", "Pierre"))
    assert isinstance(outputs.get_parameter(pierre_effect), SingleParameter)
    theta_input = Assignment(
        "theta_5", ValueFactory.create("[0.36, 0.24, 0.40]"))
    assert outputs.get_parameter(pierre_effect).get_value(
        theta_input) == pytest.approx(0.36, abs=0.01)

    for utterance in ("my name is", "Pierre", "my name is", "Pierre"):
        # clear planning nodes before each new user input
        state = system.get_state()
        state.remove_nodes(state.get_action_node_ids())
        state.remove_nodes(state.get_utility_node_ids())
        system.add_content("u_u", utterance)

    theta = system.get_state().query_prob("theta_5").to_continuous()
    assert theta.get_function().get_mean()[0] == pytest.approx(0.3, abs=0.12)
def get_sub_values(self):
    """
    Returns the words of the string value, each converted into a value.

    Words are separated by spaces. Empty tokens (produced by consecutive,
    leading or trailing spaces) are skipped, since they are not words.

    :return: list of word values
    """
    from bn.values.value_factory import ValueFactory
    return [ValueFactory.create(word)
            for word in self._value.split(" ") if word]
def get_prob(self, value):
    """
    Intended to return the probability P(value) (per the original docs,
    0.0 when no probability is specified).

    The raw value is converted with ValueFactory.create and the lookup is
    delegated to ``self.get_prob``.

    :param value: the value for the random variable (as a float array)
    :return: associated probability, if one exists.
    """
    # NOTE(review): ``self.get_prob`` resolves to this very method unless a
    # subclass overrides it; in that case the call below recurses without
    # bound (RecursionError). This looks like a port of an overloaded Java
    # method (float[] -> Value) — confirm which implementation is meant to
    # handle the converted Value.
    return self.get_prob(ValueFactory.create(value))
def test_priority(self):
    """Checks that rule priorities shape the distribution of a_u."""
    domain = XMLDomainReader.extract_domain(TestRule3.domain_file)
    system = DialogueSystem(domain)
    system.get_settings().show_gui = False
    system.start_system()

    for label, expected in (("Opening", 0.8), ("Nothing", 0.1), ("start", 0.0)):
        assert system.get_content("a_u").get_prob(label) == pytest.approx(
            expected, abs=0.01)
    discrete = system.get_content("a_u").to_discrete()
    assert not discrete.has_prob(ValueFactory.create("start"))
def remove_from_state(self, variable_id):
    """
    Removes a variable from the dialogue state by adding an assignment that
    sets it to the none value, while holding the 'remove_from_state' lock.

    :param variable_id: identifier of the node to remove
    """
    lock = self._locks['remove_from_state']
    with lock:
        none_value = ValueFactory.none()
        self.add_to_state(Assignment(variable_id, none_value))
def get_value(self, variable):
    """
    Looks up the value of a variable in the assignment.

    :param variable: the variable to look up
    :return: the stored value, or the none value when the variable is absent
    """
    fallback = ValueFactory.none()
    mapping = self._map
    return mapping.get(variable, fallback)
def add_pair(self, variable, value):
    """
    Adds (or overwrites) a (var, value) pair in the assignment.

    :param variable: the variable
    :param value: the value, passed through ValueFactory.create before storing
    """
    converted = ValueFactory.create(value)
    self._map[variable] = converted
    # the content changed, so reset the cached hash (presumably 0 means
    # "not computed yet" — confirm against __hash__)
    self._cached_hash = 0
def create_value(self, str_representation):
    """
    Converts a string representation used within the graph into a value.

    :param str_representation: string representation
    :return: the value produced by ValueFactory.create
    """
    # imported locally, as in the original (possibly to avoid an import cycle)
    from bn.values.value_factory import ValueFactory
    created = ValueFactory.create(str_representation)
    return created
def test_assign_interchance(self):
    """Checks that swapped boolean assignments differ in equality and hash."""
    a1 = Assignment(Assignment("Burglary", True), "Earthquake",
                    ValueFactory.create(False))
    a1bis = Assignment(Assignment("Earthquake", False), "Burglary",
                       ValueFactory.create(True))
    a2 = Assignment(Assignment("Burglary", False), "Earthquake",
                    ValueFactory.create(True))
    a2bis = Assignment(Assignment("Earthquake", True), "Burglary",
                       ValueFactory.create(False))

    for left, right in ((a1, a2), (a1bis, a2bis), (a1, a2bis), (a1bis, a2)):
        assert left != right
        assert hash(left) != hash(right)
def find(self, str_val, max_results):
    """
    Searches for occurrences of the relational template in the string, when
    the string parses as a relational structure.

    :param str_val: the string to search
    :param max_results: currently not used by this implementation
    :return: the matches, or an empty list if the string is not relational
    """
    parsed = ValueFactory.create(str_val)
    if not isinstance(parsed, RelationalVal):
        return []
    return self.get_matches(parsed)
def generate_xml(self):
    """
    Generates the XML elements describing the points of the distribution.

    Each (point, probability) entry becomes a <value> element whose 'prob'
    attribute holds the short form of the probability and whose text holds
    the point itself.

    :return: the list of <value> elements
    """
    element_list = []
    for point, prob in self._points.items():
        value_node = Element('value')
        # ElementTree attribute values must be strings
        value_node.set('prob', str(StringUtils.get_short_form(prob)))
        # Element.text is an attribute, not a method; the text is the point
        # (the original wrote the probability here, losing the point).
        # NOTE(review): assumes ValueFactory.create accepts the point keys
        # as stored (presumably tuples) — confirm.
        value_node.text = str(ValueFactory.create(point))
        element_list.append(value_node)
    return element_list