示例#1
0
    def test_rule_and_params(self):
        """
        Follows the ``theta_moves`` parameter and the ``movements`` rule node
        through three stages: right after start-up, after a user act ``a_u``
        is added, and after a system act ``a_m`` is added.
        """
        system = DialogueSystem(XMLDomainReader.extract_domain(TestRuleAndParams.domain_file))
        system.get_settings().show_gui = False
        system.start_system()

        def theta_mean():
            # first component of the mean of the continuous theta_moves distribution
            return system.get_content("theta_moves").to_continuous().get_function().get_mean()[0]

        def theta_output_count():
            # number of output nodes currently attached to theta_moves
            return len(system.get_state().get_chance_node("theta_moves").get_output_node_ids())

        def predicted_left():
            # probability of "I want left" in the predicted user act
            return system.get_content("a_u^p").get_prob("I want left")

        # Stage 1: state right after start-up.
        assert theta_mean() == pytest.approx(0.2, abs=0.02)
        assert predicted_left() == pytest.approx(0.12, abs=0.03)
        assert theta_output_count() == 1
        assert system.get_state().has_chance_node("movements")
        assert isinstance(system.get_state().get_chance_node("movements").get_distrib(), AnchoredRule)

        # Stage 2: a user act is observed.
        user_act = CategoricalTableBuilder("a_u")
        for utterance, prob in (("I want left", 0.8), ("I want forward", 0.1)):
            user_act.add_row(utterance, prob)
        system.add_content(user_act.build())

        assert theta_output_count() == 0
        assert not system.get_state().has_chance_node("movements")
        assert theta_mean() == pytest.approx(2.0 / 6.0, abs=0.07)

        # Stage 3: a system act is added.
        system.add_content("a_m", "turning left")

        assert predicted_left() == pytest.approx(0.23, abs=0.04)
        assert theta_output_count() == 1
示例#2
0
    def test_IS2013(self):
        """
        Checks that observing the user act ``Move(Left)`` shifts the mean of
        ``theta_1``: the first component must grow by more than 0.04 while
        components 1..7 must each grow by less than 0.04.

        The global sampling settings are temporarily enlarged (3x samples,
        10x sampling time) and are always restored afterwards, even when an
        assertion fails, so that later tests are not affected.
        """
        domain = XMLDomainReader.extract_domain(TestLearning.domain_file)
        params = XMLStateReader.extract_bayesian_network(TestLearning.parameters_file, "parameters")
        domain.set_parameters(params)
        system = DialogueSystem(domain)
        system.get_settings().show_gui = False
        system.detach_module(ForwardPlanner)

        # Remember the exact original values instead of recomputing them by
        # division afterwards (which could also change their numeric type).
        saved_nr_samples = Settings.nr_samples
        saved_max_sampling_time = Settings.max_sampling_time
        Settings.nr_samples = saved_nr_samples * 3
        Settings.max_sampling_time = saved_max_sampling_time * 10
        try:
            system.start_system()

            init_mean = system.get_content("theta_1").to_continuous().get_function().get_mean()

            builder = CategoricalTableBuilder("a_u")
            builder.add_row("Move(Left)", 1.0)
            builder.add_row("Move(Right)", 0.0)
            builder.add_row("None", 0.0)
            system.add_content(builder.build())
            system.get_state().remove_nodes(system.get_state().get_utility_node_ids())
            system.get_state().remove_nodes(system.get_state().get_action_node_ids())

            after_mean = system.get_content("theta_1").to_continuous().get_function().get_mean()

            assert after_mean[0] - init_mean[0] > 0.04
            for index in range(1, 8):
                assert after_mean[index] - init_mean[index] < 0.04
        finally:
            Settings.nr_samples = saved_nr_samples
            Settings.max_sampling_time = saved_max_sampling_time
示例#3
0
    def test_5(self):
        """
        Verifies the utility of three joint assignments of the action
        variables ``a_ml'``, ``a_mg'`` and ``a_md'`` on the third rule domain.
        """
        system2 = DialogueSystem(XMLDomainReader.extract_domain(TestRule2.domain_file3))
        system2.detach_module(ForwardPlanner)
        system2.get_settings().show_gui = False
        system2.start_system()

        action_vars = ("a_ml'", "a_mg'", "a_md'")
        # (values for a_ml', a_mg', a_md') -> expected utility
        expected_utilities = (
            (("SayYes", "Nod", "None"), 2.4),
            (("SayYes", "Nod", "DanceAround"), -0.6),
            (("SayYes", "None", "None"), 1.6),
        )
        for action_values, utility in expected_utilities:
            TestRule2.inference.check_util(
                system2.get_state(),
                list(action_vars),
                Assignment([Assignment(var, val) for var, val in zip(action_vars, action_values)]),
                utility)
示例#4
0
    def __init__(self, arg1=None, arg2=None):
        """
        Creates a new user/environment simulator.

        :param arg1: the main dialogue system to which the simulator should connect
                     (must be a DialogueSystem)
        :param arg2: the dialogue domain for the simulator, given either as a
                     Domain object or as a string that is handed to
                     XMLDomainReader.extract_domain
        :raises NotImplementedError: if arg1 is not a DialogueSystem, or if arg2
                                     is neither a str nor a Domain
        """
        # Imported locally rather than at module level
        # (presumably to avoid a circular import -- TODO confirm).
        from dialogue_system import DialogueSystem
        if not isinstance(arg1, DialogueSystem):
            raise NotImplementedError("UNDEFINED PARAMETERS")

        if isinstance(arg2, str):
            domain = XMLDomainReader.extract_domain(arg2)
        elif isinstance(arg2, Domain):
            domain = arg2
        else:
            raise NotImplementedError("UNDEFINED PARAMETERS")

        self.system = arg1
        self.domain = domain
        # The simulator works on its own copy of the domain's initial state.
        self.simulator_state = copy(domain.get_initial_state())
        self.simulator_state.set_parameters(domain.get_parameters())
        # The connected system takes over the settings of the simulator domain.
        self.system.change_settings(domain.get_settings())

        self._lock = threading.RLock()
示例#5
0
 def test_priority(self):
     """
     Checks the distribution of ``a_u`` after start-up: "Opening" and
     "Nothing" carry the probability mass, while "start" has (about) zero
     probability and is not an entry of the discrete table.
     """
     domain = XMLDomainReader.extract_domain(TestRule3.domain_file)
     system = DialogueSystem(domain)
     system.get_settings().show_gui = False
     system.start_system()

     for value, expected in (("Opening", 0.8), ("Nothing", 0.1), ("start", 0.0)):
         assert system.get_content("a_u").get_prob(value) == pytest.approx(expected, abs=0.01)

     a_u_table = system.get_content("a_u").to_discrete()
     assert not a_u_table.has_prob(ValueFactory.create("start"))
示例#6
0
    def test_template_quick(self):
        """
        Loads the quick-test domain and checks that ``caught`` is False and
        ``caught2`` is True, each with probability 1.
        """
        system = DialogueSystem(XMLDomainReader.extract_domain("test/data/quicktest.xml"))
        system.get_settings().show_gui = False
        system.start_system()

        for variable, outcome in (("caught", False), ("caught2", True)):
            probability = system.get_content(variable).get_prob(outcome)
            assert probability == pytest.approx(1.0, abs=0.01)
示例#7
0
 def test_param_6(self):
     """
     Checks the discrete distribution of variable ``b`` in the
     ``testparams3`` domain: it has six entries and three of them are
     compared against reference probabilities.
     """
     domain = XMLDomainReader.extract_domain("test/data/testparams3.xml")
     system = DialogueSystem(domain)
     system.get_settings().show_gui = False
     system.start_system()

     table = system.get_content("b").to_discrete()
     assert len(table) == 6

     reference = (
         ("something else", 0.45),
         ("value: first with type 1", 0.175),
         ("value: second with type 2", 0.05),
     )
     for value, expected in reference:
         assert table.get_prob(value) == pytest.approx(expected, abs=0.05)
示例#8
0
 def test_6(self):
     """
     Checks two probabilities on the fourth rule domain: ``A`` equals the
     value parsed from "[a1,a2]" with probability 1.0, and ``a_u`` equals
     "Request(ball)" with probability 0.5.
     """
     system2 = DialogueSystem(XMLDomainReader.extract_domain(TestRule2.domain_file4))
     system2.detach_module(ForwardPlanner)
     system2.get_settings().show_gui = False
     system2.start_system()

     checker = TestRule2.inference
     checker.check_prob(system2.get_state(), "A", ValueFactory.create("[a1,a2]"), 1.0)
     checker.check_prob(system2.get_state(), "a_u", "Request(ball)", 0.5)
示例#9
0
    def test_incondition(self):
        """
        Checks the outputs of the "in" condition domain: the two possible
        orderings of the matched set in ``out`` sum to 0.56, and the string
        match in ``out2`` has probability 0.5.
        """
        system = DialogueSystem(XMLDomainReader.extract_domain(TestRule3.incondition_file))
        system.get_settings().show_gui = False
        system.start_system()

        # The set may be rendered in either order, so both renderings are summed.
        prob_order_a = system.get_content("out").get_prob("val1 is in [val1, val2]")
        prob_order_b = system.get_content("out").get_prob("val1 is in [val2, val1]")
        assert prob_order_a + prob_order_b == pytest.approx(0.56, abs=0.01)

        string_match = system.get_content("out2").get_prob("this is a string is matched")
        assert string_match == pytest.approx(0.5, abs=0.01)
示例#10
0
 def test_underspec(self):
     """
     Checks the ``match`` distribution and the number of chance nodes of
     the underspecification domain with state reduction switched off.

     ``StatePruner.enable_reduction`` is a class-wide flag, so it is
     switched back on in a ``finally`` clause: a failing assertion must not
     leave reduction disabled for the tests that run afterwards.
     """
     domain = XMLDomainReader.extract_domain("test/data/underspectest.xml")
     system = DialogueSystem(domain)
     system.get_settings().show_gui = False
     StatePruner.enable_reduction = False
     try:
         system.start_system()
         assert system.get_content("match").get_prob("obj_1") == pytest.approx(0.66, abs=0.05)
         assert system.get_content("match").get_prob("obj_3") == pytest.approx(0.307, abs=0.05)
         assert len(system.get_state().get_chance_node_ids()) == 14
     finally:
         StatePruner.enable_reduction = True
class TestDialogueState:
    """Inference checks run on a copy of the dialogue state of ``domain1``."""

    domain_file = "test/data/domain1.xml"

    # Parsed once, when the class body is executed, and shared by both tests.
    domain = XMLDomainReader.extract_domain(domain_file)
    inference = InferenceChecks()

    def test_state_copy(self):
        """
        Copies the started state (with state reduction disabled) and checks,
        on the copy, the effect distribution of the rule node that outputs
        ``+=HowAreYou`` as well as the distribution of ``a_u2``.
        """
        system = DialogueSystem(TestDialogueState.domain)
        system.detach_module(ForwardPlanner)
        StatePruner.enable_reduction = False
        try:
            system.get_settings().show_gui = False
            system.start_system()

            initial_state = copy(system.get_state())

            # Find the output node of u_u2 whose content mentions the
            # "+=HowAreYou" effect; if several match, the last one wins.
            rule_id = ""
            for node_id in system.get_state().get_node("u_u2").get_output_node_ids():
                if "+=HowAreYou" in str(system.get_content(node_id)):
                    rule_id = node_id

            checker = TestDialogueState.inference
            checker.check_prob(initial_state, rule_id,
                               Effect.parse_effect("a_u2+=HowAreYou"), 0.9)
            checker.check_prob(initial_state, rule_id,
                               Effect.parse_effect("Void"), 0.1)

            checker.check_prob(initial_state, "a_u2", "[HowAreYou]", 0.2)
            checker.check_prob(initial_state, "a_u2", "[Greet, HowAreYou]", 0.7)
            checker.check_prob(initial_state, "a_u2", "[]", 0.1)
        finally:
            # Class-wide flag: restore it even when a check above fails.
            StatePruner.enable_reduction = True

    def test_state_copy2(self):
        """
        Copies the started state (state reduction left as is) and checks the
        distribution of ``a_u2`` on the copy.
        """
        # NOTE(review): this assigns the attribute on the InferenceChecks
        # class, so it stays in effect for every test that runs afterwards.
        # Kept as in the original; confirm whether that is intended.
        InferenceChecks.exact_threshold = 0.08

        system = DialogueSystem(TestDialogueState.domain)
        system.get_settings().show_gui = False
        system.detach_module(ForwardPlanner)
        system.start_system()

        initial_state = copy(system.get_state())

        checker = TestDialogueState.inference
        checker.check_prob(initial_state, "a_u2", "[HowAreYou]", 0.2)
        checker.check_prob(initial_state, "a_u2", "[Greet, HowAreYou]", 0.7)
        checker.check_prob(initial_state, "a_u2", "[]", 0.1)
示例#12
0
    def test_1(self):
        """
        Checks the probabilities of ``found`` and ``found2`` on the first
        test domain, with state reduction disabled.

        ``StatePruner.enable_reduction`` is a class-wide flag, so it is
        restored in a ``finally`` clause and cannot leak into other tests
        when a check fails.
        """
        domain = XMLDomainReader.extract_domain(TestRule3.test1_domain_file)
        inference = InferenceChecks()
        inference.exact_threshold = 0.06
        system = DialogueSystem(domain)
        system.get_settings().show_gui = False
        system.detach_module(ForwardPlanner)
        StatePruner.enable_reduction = False
        try:
            system.start_system()

            inference.check_prob(system.get_state(), "found", "A", 0.7)

            inference.check_prob(system.get_state(), "found2", "D", 0.3)
            inference.check_prob(system.get_state(), "found2", "C", 0.5)
        finally:
            StatePruner.enable_reduction = True
示例#13
0
    def test_2(self):
        """
        Checks the probabilities of ``graspable(obj1)`` / ``graspable(obj2)``
        and the utility of ``grasp(obj1)`` for ``a_m'`` on the second test
        domain, with state reduction disabled.

        ``StatePruner.enable_reduction`` is a class-wide flag, so it is
        restored in a ``finally`` clause and cannot leak into other tests
        when a check fails.
        """
        inference = InferenceChecks()

        domain = XMLDomainReader.extract_domain(TestRule3.test2_domain_file)
        system = DialogueSystem(domain)
        system.get_settings().show_gui = False

        system.detach_module(ForwardPlanner)
        StatePruner.enable_reduction = False
        try:
            system.start_system()

            inference.check_prob(system.get_state(), "graspable(obj1)", "True", 0.81)

            inference.check_prob(system.get_state(), "graspable(obj2)", "True", 0.16)
            inference.check_util(system.get_state(), "a_m'", "grasp(obj1)", 0.592)
        finally:
            StatePruner.enable_reduction = True
示例#14
0
    def __init__(self, arg1=None):
        """
        Creates a new dialogue system.

        With no argument the system starts with an empty dialogue domain.
        With a Domain, or with a string that is handed to
        XMLDomainReader.extract_domain, the system is first set up as in the
        empty case and the domain is then installed through change_domain.

        :param arg1: None, a Domain, or a str designating the domain file
        :raises NotImplementedError: if arg1 has any other type (previously
                                     this silently produced an uninitialised
                                     object)
        :raises ValueError: if the planner named in the settings is neither
                            'forward' nor 'mcts'
        """
        # Reject unsupported arguments before anything is created.
        if arg1 is not None and not isinstance(arg1, (Domain, str)):
            raise NotImplementedError("UNDEFINED PARAMETERS")

        # Common set-up, done inline instead of via a recursive
        # self.__init__() call, which would re-dispatch to a subclass override.
        self._settings = Settings()  # the system setting
        self._cur_state = DialogueState()  # the dialogue state

        self._domain = Domain()  # the dialogue domain
        self._paused = True  # whether the system is paused or active
        self._modules = []  # the set of modules attached to the system

        # Inserting standard modules
        self._modules.append(GUIFrame(self))
        self._modules.append(DialogueRecorder(self))
        if self._settings.planner == 'forward':
            self.log.info("Forward planner will be used.")
            self._modules.append(ForwardPlanner(self))
        elif self._settings.planner == 'mcts':
            self.log.info("MCTS planner will be used.")
            self._modules.append(MCTSPlanner(self))
        else:
            raise ValueError("Not supported planner: %s" %
                             self._settings.planner)
        self._init_lock()

        # Install the requested domain, if any.
        if isinstance(arg1, Domain):
            self.change_domain(arg1)
        elif isinstance(arg1, str):
            self.change_domain(XMLDomainReader.extract_domain(arg1))
示例#15
0
    def refresh_domain(self):
        """
        Re-reads the current dialogue domain from its source file, so that
        changes made to the file by the user take effect.

        Nothing happens when the current domain is empty. When re-reading
        fails, the error is logged and displayed, and the domain is replaced
        by a new Domain() that keeps the same source file.
        """
        if not self._domain.is_empty():
            source_path = self._domain.get_source_file().get_path()

            try:
                self._domain = XMLDomainReader.extract_domain(source_path)
                self.change_settings(self._domain.get_settings())
                self.display_comment("Dialogue domain successfully updated")
            except Exception as err:
                self.log.critical("Cannot refresh domain %s" % err)
                self.display_comment("Syntax error: %s" % err)
                # Keep the source path so that a later refresh can retry.
                self._domain = Domain()
                self._domain.set_source_file(source_path)
示例#16
0
    def open_domain(self, domain=False):
        """
        Attaches a dialogue domain to the system.

        When ``domain`` is given, it is passed straight to the system. When
        it is left at its default (False), a file dialog is opened, the
        selected file is read with XMLDomainReader, the chat log is reset and
        the domain-dependent actions are enabled.

        If the file dialog is cancelled (empty file name) nothing is changed;
        previously the empty name was handed to the domain reader.

        :param domain: the domain to attach, or False to ask the user for a file
        """
        if domain is not False:
            self._system.change_domain(domain)
            return

        # getOpenFileName is indexed as (file name, selected filter), as in
        # the original code; the file name is empty when the dialog is cancelled.
        fname = QFileDialog.getOpenFileName(self, 'Open File')
        selected_path = fname[0]
        if not selected_path:
            return

        domain = XMLDomainReader.extract_domain(selected_path)

        domain_name = str(domain).split('.')[0]

        self.chatlog.clear()
        self.chatlog.append("[%s domain successfully attached]\n" %
                            domain_name)
        self.chatlog.setAlignment(Qt.AlignLeft)
        self._system.change_domain(domain)

        self.save_as.setEnabled(True)
        self.reset.setEnabled(True)
        self.pause.setEnabled(True)
        self.export_act1.setEnabled(True)
示例#17
0
    def test_param1(self):
        """
        Starts the demo domain with its parameter network, adds a system act
        and a user act, then prints the first component of 3000 samples drawn
        from the ``theta`` chance node.

        The test contains no assertions; its output is meant to be inspected
        by hand.
        """
        domain = XMLDomainReader.extract_domain(TestDemo.domain_file)
        params = XMLStateReader.extract_bayesian_network(
            TestDemo.param_file, "parameters")
        domain.set_parameters(params)
        system = DialogueSystem(domain)
        system.get_settings().show_gui = False
        system.detach_module(ForwardPlanner)

        system.start_system()
        system.add_content("a_m", "AskRepeat")

        t = CategoricalTableBuilder("a_u")
        t.add_row("DoA", 0.7)
        # NOTE(review): the next two rows use the variable name "a_u" as the
        # value, and use it twice. This looks like a transcription slip, but
        # the intended values cannot be determined from here, so the rows are
        # kept unchanged.
        t.add_row("a_u", 0.2)
        t.add_row("a_u", 0.1)
        system.add_content(t.build())
        for _ in range(3000):
            sample = system.get_state().get_chance_node("theta").sample()
            print(sample.get_array()[0])
示例#18
0
    def test_4(self):
        """
        Verifies the utility of three joint assignments of the action
        variables ``a_m3'`` and ``obj(a_m3)'`` on the second rule domain.
        """
        system2 = DialogueSystem(XMLDomainReader.extract_domain(TestRule2.domain_file2))
        system2.get_settings().show_gui = False
        system2.detach_module(ForwardPlanner)
        system2.start_system()

        action_vars = ("a_m3'", "obj(a_m3)'")
        # (values for a_m3', obj(a_m3)') -> expected utility
        expected_utilities = (
            (("Do", "A"), 0.3),
            (("Do", "B"), -1.7),
            (("SayHi", "None"), -0.9),
        )
        for action_values, utility in expected_utilities:
            TestRule2.inference.check_util(
                system2.get_state(),
                list(action_vars),
                Assignment([Assignment(var, val) for var, val in zip(action_vars, action_values)]),
                utility)
示例#19
0
    def test1(self):
        """
        Exercises incremental updates of the user utterance ``u_u``.

        Content is added to ``u_u`` piece by piece through
        ``add_incremental_content`` and, after each step, the test checks the
        resulting ``u_u`` values and whether the ``nlu`` chance node is
        present in the state. The statements depend on each other and on the
        short sleeps, so their order matters.
        """
        domain = XMLDomainReader.extract_domain(TestIncremental.domain_file)
        system = DialogueSystem(domain)

        # NEED GUI & Recording
        system.get_settings().show_gui = False
        # system.get_settings().recording = Settings.Recording.ALL

        system.start_system()
        # Mark the user-speech variable as "busy" before the first increment.
        system.add_content(system.get_settings().user_speech, "busy")
        # The second argument is presumably a "follow previous content" flag
        # (False = start a new utterance) -- TODO confirm against the API.
        system.add_incremental_content(SingleValueDistribution("u_u", "go"),
                                       False)

        # Short pause before inspecting the state (presumably to let the
        # update settle -- TODO confirm).
        sleep(0.1)

        assert ValueFactory.create("go") in system.get_content(
            "u_u").get_values()

        t = CategoricalTableBuilder("u_u")
        t.add_row("forward", 0.7)
        t.add_row("backward", 0.2)

        # With the flag set to True the new words show up appended to the
        # previous "go" (see the "go forward" / "go backward" checks below).
        system.add_incremental_content(t.build(), True)

        sleep(0.1)

        assert ValueFactory.create("go forward") in system.get_content(
            "u_u").get_values()
        assert system.get_content("u_u").get_prob(
            "go backward") == pytest.approx(0.2, abs=0.001)
        assert system.get_state().has_chance_node("nlu")

        # The user-speech variable is set to "None" before the next increment.
        system.add_content(system.get_settings().user_speech, "None")
        assert len(system.get_state().get_chance_nodes()) == 7
        system.add_incremental_content(
            SingleValueDistribution("u_u", "please"), True)
        assert system.get_state().get_evidence().contains_pair(
            "=_a_u", ValueFactory.create(True))

        assert system.get_content("u_u").get_prob(
            "go please") == pytest.approx(0.1, abs=0.001)
        assert system.get_state().has_chance_node("nlu")

        # Once u_u is committed, the nlu node is expected to be gone.
        system.get_state().set_as_committed("u_u")
        assert not system.get_state().has_chance_node("nlu")

        t2 = CategoricalTableBuilder("u_u")
        t2.add_row("I said go backward", 0.3)

        # After the commit the new content is not appended to the old one:
        # the value checked below is exactly "I said go backward".
        system.add_incremental_content(t2.build(), True)
        assert system.get_content("a_u").get_prob(
            "Request(Backward)") == pytest.approx(0.82, abs=0.05)
        assert ValueFactory.create("I said go backward") in system.get_content(
            "u_u").get_values()
        assert system.get_state().has_chance_node("nlu")

        system.get_state().set_as_committed("u_u")
        assert not system.get_state().has_chance_node("nlu")
        system.add_incremental_content(
            SingleValueDistribution("u_u", "yes that is right"), False)
        assert ValueFactory.create("yes that is right") in system.get_content(
            "u_u").get_values()
示例#20
0
class TestPlanning:
    """Planner tests: which system action ``a_m`` ends up selected in three small domains."""

    domain_file = "test/data/domain3.xml"
    domain_file2 = "test/data/basicplanning.xml"
    domain_file3 = "test/data/planning2.xml"
    settings_file = "test/data/settings_test2.xml"

    inference = InferenceChecks()
    # The domains are parsed once, when the class body is executed.
    domain = XMLDomainReader.extract_domain(domain_file)
    domain2 = XMLDomainReader.extract_domain(domain_file2)
    domain3 = XMLDomainReader.extract_domain(domain_file3)

    @staticmethod
    def _start(domain, horizon=None):
        """Builds a GUI-less system for ``domain``, optionally sets the horizon, and starts it."""
        system = DialogueSystem(domain)
        system.get_settings().show_gui = False
        if horizon is not None:
            system.get_settings().horizon = horizon
        system.start_system()
        return system

    @staticmethod
    def _add_user_act(system, rows):
        """Adds an ``a_u`` table made of the given (value, probability) rows to the system."""
        table = CategoricalTableBuilder("a_u")
        for value, prob in rows:
            table.add_row(value, prob)
        system.add_content(table.build())

    def test_planning(self):
        """After start-up the state holds three chance nodes, no evidence, and a certain action."""
        system = TestPlanning._start(TestPlanning.domain)

        assert len(system.get_state().get_nodes()) == 3
        assert len(system.get_state().get_chance_nodes()) == 3
        assert len(system.get_state().get_evidence().get_variables()) == 0
        for variable, value in (("a_m3", "Do"), ("obj(a_m3)", "A")):
            TestPlanning.inference.check_prob(system.get_state(), variable, value, 1.0)

    def test_planning2(self):
        """With the default horizon no ``a_m`` node is created in the basic planning domain."""
        system = TestPlanning._start(TestPlanning.domain2)

        assert len(system.get_state().get_node_ids()) == 2
        assert not system.get_state().has_chance_node("a_m")

    def test_planning3(self):
        """With horizon 2 the basic planning domain selects ``AskRepeat``."""
        system = TestPlanning._start(TestPlanning.domain2, horizon=2)

        TestPlanning.inference.check_prob(system.get_state(), "a_m", "AskRepeat", 1.0)

    def test_planning4(self):
        """With horizon 3 and a confident coffee request, ``Do(Coffee)`` is selected."""
        system = TestPlanning._start(TestPlanning.domain3, horizon=3)

        TestPlanning._add_user_act(system, (("Ask(Coffee)", 0.95), ("Ask(Tea)", 0.02)))
        TestPlanning.inference.check_prob(system.get_state(), "a_m", "Do(Coffee)", 1.0)

    def test_planning5(self):
        """With horizon 3 and an ambiguous request, ``AskRepeat`` is selected."""
        system = TestPlanning._start(TestPlanning.domain3, horizon=3)

        TestPlanning._add_user_act(system, (("Ask(Coffee)", 0.3), ("Ask(Tea)", 0.3)))
        TestPlanning.inference.check_prob(system.get_state(), "a_m", "AskRepeat", 1.0)
示例#21
0
class TestFlightBooking:
    domain = XMLDomainReader.extract_domain("test/data/example-flightbooking.xml")

    def test_dialogue(self):
        """
        Plays a complete flight-booking dialogue and checks the state after
        every user turn.

        Each turn fills the ``u_u`` dictionary (utterance -> probability),
        feeds it to the system with ``add_user_input`` and then asserts on the
        recognised user act ``a_u``, the selected system act ``a_m`` (and the
        previous one, ``a_m-prev``), the system utterance ``u_m`` and the
        slot variables. The turns build on each other, so the order matters.
        """
        system = DialogueSystem(TestFlightBooking.domain)
        system.get_settings().show_gui = False
        system.start_system()
        # Opening: the system asks for the destination.
        assert str(system.get_content("u_m").get_best()).find("your destination?") != -1

        # Turn: two competing destination hypotheses -> confirmation of Bergen.
        u_u = dict()
        u_u["to Bergen"] = 0.4
        u_u["to Bethleem"] = 0.2
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Airport,Bergen)]") == pytest.approx(0.833, abs=0.01)
        assert len(system.get_content("a_u").get_values()) == pytest.approx(3, abs=0.01)
        assert system.get_state().query_prob("a_u", False).get_prob("[Other]") == pytest.approx(0.055, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Confirm(Destination,Bergen)"
        # Turn: the user confirms -> destination grounded, departure requested.
        u_u.clear()
        u_u["yes exactly"] = 0.8
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(0.98, abs=0.01)
        assert system.get_content("Destination").get_prob("Bergen") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Ground(Destination,Bergen)"
        assert str(system.get_content("u_m").get_best()).find("your departure?") != -1
        # Turn: input recognised as [Other] -> no Departure node, the system asks to repeat.
        u_u.clear()
        u_u["to Stockholm"] = 0.8
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Other]") == pytest.approx(0.8, abs=0.01)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Ground(Destination,Bergen)"
        assert len(system.get_content("Destination").to_discrete().get_values()) == 1
        assert system.get_content("Destination").get_prob("Bergen") == pytest.approx(1.0, abs=0.01)
        assert not system.get_state().has_chance_node("Departure")
        assert str(system.get_content("a_m").get_best()) == "AskRepeat"
        assert str(system.get_content("u_m").get_best()).find("you repeat?") != -1
        # Turn: departure Sandefjord is recognised -> confirmation requested.
        u_u.clear()
        u_u["to Sandefjord then"] = 0.6
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("None") == pytest.approx(0.149, abs=0.05)
        assert system.get_content("Departure").get_prob("Sandefjord") == pytest.approx(0.88, abs=0.05)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Confirm(Departure,Sandefjord)"
        assert str(system.get_content("u_m").get_best()).find("that correct?") != -1
        # Turn: low-confidence correction to Trondheim -> the system asks to repeat.
        u_u.clear()
        u_u["no to Trondheim sorry"] = 0.08
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Airport,Trondheim),Disconfirm]") == pytest.approx(0.51, abs=0.01)
        assert system.get_content("Departure").get_prob("Trondheim") == pytest.approx(0.51, abs=0.05)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "AskRepeat"
        assert str(system.get_content("u_m").get_best()).find("repeat?") != -1
        # Turn: Trondheim repeated -> confirmation of the departure requested.
        u_u.clear()
        u_u["to Trondheim"] = 0.3
        u_u["Sandefjord"] = 0.1
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Airport,Trondheim)]") == pytest.approx(0.667, abs=0.01)
        assert system.get_content("Destination").get_prob("Bergen") == pytest.approx(1.0, abs=0.01)
        assert system.get_content("Departure").get_prob("Trondheim") == pytest.approx(0.89, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Confirm(Departure,Trondheim)"
        # Turn: the user confirms -> departure grounded, date requested.
        u_u.clear()
        u_u["yes exactly that's it"] = 0.8
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Ground(Departure,Trondheim)"
        assert str(system.get_content("u_m").get_best()).find("which date") != -1
        # Turn: two uncertain date hypotheses -> the system asks to repeat.
        u_u.clear()
        u_u["that will be on May 26"] = 0.4
        u_u["this will be on May 24"] = 0.2
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Date,May,24)]") == pytest.approx(0.2, abs=0.01)
        assert system.get_content("a_u").get_prob("[Inform(Date,May,26)]") == pytest.approx(0.4, abs=0.01)
        assert system.get_content("Destination").get_prob("Bergen") == pytest.approx(1.0, abs=0.01)
        assert system.get_content("Departure").get_prob("Trondheim") == pytest.approx(1.0, abs=0.01)
        assert system.get_content("Date").get_prob("May 26") == pytest.approx(0.4, abs=0.01)
        assert system.get_content("Date").get_prob("May 24") == pytest.approx(0.2, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "AskRepeat"
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Ground(Departure,Trondheim)"
        # Turn: the repeated date favours May 24 -> date grounded, return trip asked.
        u_u.clear()
        u_u["May 24"] = 0.5
        u_u["Mayday four"] = 0.5
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Date,May,24)]") == pytest.approx(0.82, abs=0.05)
        assert system.get_content("a_u").get_prob("[Inform(Number,4)]") == pytest.approx(0.176, abs=0.01)
        assert system.get_content("Date").get_prob("May 26") == pytest.approx(0.02, abs=0.01)
        assert system.get_content("Date").get_prob("May 24") == pytest.approx(0.94, abs=0.01)
        assert system.get_state().has_chance_node("a_m")
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "AskRepeat"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Ground(Date,May 24)"
        assert str(system.get_content("u_m").get_best()).find("return trip") != -1
        # Turn: no return trip -> the system makes an offer (179).
        u_u.clear()
        u_u["no thanks"] = 0.9
        system.add_user_input(u_u)
        assert str(system.get_content("u_m").get_best()).find("to order tickets?") != -1
        assert system.get_content("ReturnDate").get_prob("NoReturn") == pytest.approx(1.0, abs=0.01)
        assert system.get_content("current_step").get_prob("MakeOffer") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m").get_best()) == "MakeOffer(179)"
        # Turn: a very uncertain "yes" -> the step stays MakeOffer and no a_m node is present.
        u_u.clear()
        u_u["yes"] = 0.02
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(0.177, abs=0.01)
        assert system.get_content("current_step").get_prob("MakeOffer") == pytest.approx(1.0, abs=0.01)
        assert not system.get_state().has_chance_node("a_m")
        # Turn: a confident "yes" -> the step moves to NbTickets.
        u_u.clear()
        u_u["yes"] = 0.8
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(0.978, abs=0.01)
        assert system.get_content("current_step").get_prob("NbTickets") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("u_m").get_best()).find("many tickets") != -1
        # Turn: input recognised as [Other] -> the system asks to repeat.
        u_u.clear()
        u_u["uh I don't know me"] = 0.6
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Other]") == pytest.approx(0.6, abs=0.01)
        assert system.get_content("current_step").get_prob("NbTickets") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Ground(MakeOffer)"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "AskRepeat"
        # Turn: three tickets requested -> confirmation of the number requested.
        u_u.clear()
        u_u["three tickets please"] = 0.9
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Number,3)]") == pytest.approx(0.9, abs=0.01)
        assert system.get_content("current_step").get_prob("NbTickets") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "AskRepeat"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Confirm(NbTickets,3)"
        # Turn: correction to two tickets -> confirmation of the new number requested.
        u_u.clear()
        u_u["no sorry two tickets"] = 0.4
        u_u["sorry to tickets"] = 0.3
        system.add_user_input(u_u)
        assert len(system.get_content("a_u").get_values()) == pytest.approx(3, abs=0.01)
        assert system.get_content("a_u").get_prob("[Disconfirm,Inform(Number,2)]") == pytest.approx(0.86, abs=0.05)
        assert system.get_content("NbTickets").get_prob(2) == pytest.approx(0.86, abs=0.05)
        assert system.get_content("NbTickets").get_prob(3) == pytest.approx(0.125, abs=0.05)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Confirm(NbTickets,3)"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Confirm(NbTickets,2)"
        assert system.get_content("current_step").get_prob("NbTickets") == pytest.approx(1.0, abs=0.01)
        # Turn: the user confirms -> number grounded, last confirmation with the total price.
        u_u.clear()
        u_u["yes thank you"] = 0.75
        u_u["yes mind you"] = 0.15
        system.add_user_input(u_u)
        assert len(system.get_content("a_u").get_values()) == pytest.approx(2, abs=0.01)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(1.0, abs=0.05)
        assert system.get_content("NbTickets").get_prob(2) == pytest.approx(1.0, abs=0.05)
        assert system.get_content("NbTickets").get_prob(3) == pytest.approx(0.0, abs=0.05)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Confirm(NbTickets,2)"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Ground(NbTickets,2)"
        assert system.get_content("current_step").get_prob("LastConfirm") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("u_m").get_best()).find("Shall I confirm") != -1
        assert str(system.get_content("u_m").get_best()).find("358 EUR") != -1
        # Turn: a hesitant "yes" -> the step stays LastConfirm.
        u_u.clear()
        u_u["err yes"] = 0.2
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(0.726, abs=0.01)
        assert system.get_content("current_step").get_prob("LastConfirm") == pytest.approx(1.0, abs=0.01)
        # Turn: a clearer confirmation -> the booking is made, the step becomes Final.
        u_u.clear()
        u_u["yes please confirm"] = 0.5
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(0.934, abs=0.01)
        assert system.get_content("current_step").get_prob("Final") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Book"
        assert str(system.get_content("u_m").get_best()).find("additional tickets?") != -1
        # Turn: no additional tickets -> the dialogue is closed.
        u_u.clear()
        u_u["thanks but no thanks"] = 0.7
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Disconfirm]") == pytest.approx(0.97, abs=0.01)
        assert system.get_content("current_step").get_prob("Close") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("u_m").get_best()).find("welcome back!") != -1

        # Final state: the expected set of chance nodes, and a single joint
        # value for all booking slots with probability 1.
        assert sorted(system.get_state().get_chance_node_ids()) == ["Date", "Departure", "Destination", "NbTickets", "ReturnDate", "TotalCost", "a_m", "a_m-prev", "a_u", "current_step", "u_m", "u_u"]
        assert len(system.get_content(["Date", "Departure", "Destination", "NbTickets", "ReturnDate", "TotalCost"]).to_discrete().get_values()) == 1
        assert system.get_content(["Date", "Departure", "Destination", "NbTickets", "ReturnDate", "TotalCost"]).to_discrete().get_prob(
            Assignment([
                Assignment("Date", "May 24"),
                Assignment("Departure", "Trondheim"),
                Assignment("Destination", "Bergen"),
                Assignment("NbTickets", 2.),
                Assignment("ReturnDate", "NoReturn"),
                Assignment("TotalCost", 358.)
            ])
        ) == pytest.approx(1.0, abs=0.01)

    def test_dialogue2(self):
        """Walk through a complete flight-booking dialogue, turn by turn.

        Each turn feeds an N-best list of user utterances (utterance ->
        probability) to the system and then asserts on the resulting dialogue
        act ``a_u``, the slot variables, the selected system action ``a_m``,
        the ``current_step`` variable and the system utterance ``u_m``.

        The state is snapshotted with ``copy`` at three points and later
        re-installed, so that alternative continuations (cancelling, booking,
        starting a new booking) are exercised from the same point.  The
        assertions depend on the exact order of the turns.
        """
        system = DialogueSystem(TestFlightBooking.domain)
        system.get_settings().show_gui = False
        system.start_system()

        # Turn 1: utterance classified as [Other]; the system asks to repeat.
        u_u = dict()
        u_u["err I don't know, where can I go?"] = 0.8
        system.add_user_input(u_u)
        assert system.get_content("a_u").to_discrete().get_prob("[Other]") == pytest.approx(0.8, abs=0.01)
        assert str(system.get_content("a_m").get_best()) == "AskRepeat"
        # Destination given; the system asks for confirmation.
        u_u.clear()
        u_u["ah ok well I want to go to Tromsø please"] = 0.8
        system.add_user_input(u_u)
        assert system.get_content("a_u").to_discrete().get_prob("[Inform(Airport,Tromsø)]") == pytest.approx(0.91, abs=0.01)
        assert system.get_content("Destination").to_discrete().get_prob("Tromsø") == pytest.approx(0.91, abs=0.01)
        assert str(system.get_content("a_m").get_best()) == "Confirm(Destination,Tromsø)"
        # Confirmation: destination is grounded and the departure is requested.
        u_u.clear()
        u_u["that's right"] = 0.6
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "Ground(Destination,Tromsø)"
        assert str(system.get_content("u_m").get_best()).find("departure?") != -1
        # Departure given with probability 0.1 only: the system asks to repeat.
        u_u.clear()
        u_u["I'll be leaving from Moss"] = 0.1
        system.add_user_input(u_u)
        assert system.get_content("a_u").to_discrete().get_prob("[Inform(Airport,Moss)]") == pytest.approx(0.357, abs=0.01)
        assert system.get_content("Destination").to_discrete().get_prob("Tromsø") == pytest.approx(1.0, abs=0.01)
        assert system.get_content("Departure").to_discrete().get_prob("Moss") == pytest.approx(0.357, abs=0.01)
        assert str(system.get_content("a_m").get_best()) == "AskRepeat"
        # Two competing hypotheses (Moss / Bodø): Moss is now the one to confirm.
        u_u.clear()
        u_u["I am leaving from Moss, did you get that right?"] = 0.2
        u_u["Bodø, did you get that right?"] = 0.4
        system.add_user_input(u_u)
        assert system.get_content("a_u").to_discrete().get_prob("[Confirm,Inform(Airport,Moss)]") == pytest.approx(0.72, abs=0.01)
        assert system.get_content("Departure").to_discrete().get_prob("Moss") == pytest.approx(0.88, abs=0.01)
        assert system.get_content("Departure").to_discrete().get_prob("Bodø") == pytest.approx(0.10, abs=0.01)
        assert str(system.get_content("a_m").get_best()) == "Confirm(Departure,Moss)"
        # Confirmation: departure is grounded and the date is requested.
        u_u.clear()
        u_u["yes"] = 0.6
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "Ground(Departure,Moss)"
        assert str(system.get_content("u_m").get_best()).find("which date") != -1
        # "March 60" is classified as [Other]; the system asks to repeat.
        u_u.clear()
        u_u["March 16"] = 0.7
        u_u["March 60"] = 0.2
        system.add_user_input(u_u)
        assert system.get_content("a_u").to_discrete().get_prob("[Inform(Date,March,16)]") == pytest.approx(0.7, abs=0.01)
        assert system.get_content("a_u").to_discrete().get_prob("[Other]") == pytest.approx(0.2, abs=0.01)
        assert str(system.get_content("a_m").get_best()) == "AskRepeat"
        # Same hypotheses again with other weights: the date goes to confirmation.
        u_u.clear()
        u_u["March 16"] = 0.05
        u_u["March 60"] = 0.3
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "Confirm(Date,March 16)"
        # Confirmation: date is grounded and the return trip question follows.
        u_u.clear()
        u_u["yes"] = 0.6
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "Ground(Date,March 16)"
        assert str(system.get_content("u_m").get_best()).find("return trip?") != -1
        # After "err" (0.1) the state holds no a_m node at all.
        u_u.clear()
        u_u["err"] = 0.1
        system.add_user_input(u_u)
        assert not system.get_state().has_chance_node("a_m")
        # "yes" at 0.3 leads to AskRepeat ...
        u_u.clear()
        u_u["yes"] = 0.3
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "AskRepeat"
        # ... and "yes" at 0.5 moves the dialogue to the ReturnDate step.
        u_u.clear()
        u_u["yes"] = 0.5
        system.add_user_input(u_u)
        assert system.get_content("current_step").get_prob("ReturnDate") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("u_m").get_best()).find("travel back") != -1
        # Return date given; the system asks for confirmation.
        u_u.clear()
        u_u["on the 20th of March"] = 0.7
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "Confirm(ReturnDate,March 20)"
        assert system.get_content("ReturnDate").to_discrete().get_prob("March 20") == pytest.approx(0.7, abs=0.01)
        # Confirmation: the system makes an offer (299 EUR).
        u_u.clear()
        u_u["yes"] = 0.6
        system.add_user_input(u_u)
        assert str(system.get_content("u_m").get_best()).find("299 EUR") != -1
        assert str(system.get_content("u_m").get_best()).find("to order tickets?") != -1
        # Snapshot 1: state right after the offer, re-installed further below.
        copy_state1 = copy(system.get_state())
        # Branch A: the user declines the offer.
        u_u.clear()
        u_u["no"] = 0.7
        system.add_user_input(u_u)
        assert str(system.get_content("a_m").get_best()) == "Ground(Cancel)"
        assert str(system.get_content("current_step").get_best()) == "Final"
        assert str(system.get_content("u_m").get_best()).find("additional tickets?") != -1
        # Snapshot 2: state at the "Final" step after cancelling.
        copy_state2 = copy(system.get_state())
        assert str(copy_state2.query_prob("current_step").get_best()) == "Final"
        # No additional tickets: the live state moves to "Close" while the
        # snapshot is expected to still report "Final".
        u_u.clear()
        u_u["no"] = 0.7
        system.add_user_input(u_u)
        assert str(copy_state2.query_prob("current_step").get_best()) == "Final"
        assert str(system.get_content("a_m").get_best()) == "Ground(Close)"
        assert str(system.get_content("current_step").get_best()) == "Close"
        assert str(system.get_content("u_m").get_best()).find("welcome back") != -1
        # Replace the live state by snapshot 2 and answer "yes" instead:
        # the Destination node is gone and the destination is asked for again.
        system.get_state().remove_nodes(system.get_state().get_chance_node_ids())
        system.get_state().add_network(copy_state2)
        assert str(copy_state2.query_prob("current_step").get_best()) == "Final"
        u_u.clear()
        u_u["yes"] = 0.7
        system.add_user_input(u_u)
        assert not system.get_state().has_chance_node("Destination")
        assert str(system.get_content("u_m").get_best()).find("destination?") != -1
        assert sorted(system.get_state().get_chance_node_ids()) == ["Destination^p", "a_m-prev", "current_step", "u_m", "u_u"]

        # A plain string input (no N-best list) is grounded directly.
        system.add_user_input("Oslo")
        assert str(system.get_content("a_m").get_best()) == "Ground(Destination,Oslo)"
        # Branch B: replace the live state by snapshot 1 and accept the offer.
        system.get_state().remove_nodes(system.get_state().get_chance_node_ids())
        system.get_state().add_network(copy_state1)
        u_u.clear()
        u_u["yes"] = 0.8
        system.add_user_input(u_u)

        assert system.get_content("current_step").get_prob("NbTickets") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Ground(MakeOffer)"
        # Number of tickets given; the system asks for confirmation.
        u_u.clear()
        u_u["one single ticket"] = 0.9
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Inform(Number,1)]") == pytest.approx(0.9, abs=0.01)
        assert system.get_content("current_step").get_prob("NbTickets") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Ground(MakeOffer)"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Confirm(NbTickets,1)"
        # Confirmation with probability 1.0: a_u has a single value and the
        # dialogue reaches the LastConfirm step.
        u_u.clear()
        u_u["yes thank you"] = 1.0
        system.add_user_input(u_u)
        assert len(system.get_content("a_u").get_values()) == pytest.approx(1, abs=0.01)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(1.0, abs=0.05)
        assert system.get_content("NbTickets").get_prob(1.) == pytest.approx(1.0, abs=0.05)
        assert str(system.get_content("a_m-prev").to_discrete().get_best()) == "Confirm(NbTickets,1)"
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Ground(NbTickets,1)"
        assert system.get_content("current_step").get_prob("LastConfirm") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("u_m").get_best()).find("Shall I confirm") != -1
        assert str(system.get_content("u_m").get_best()).find("299 EUR") != -1
        # Final confirmation: the system books and asks about additional tickets.
        u_u.clear()
        u_u["yes please"] = 0.5
        u_u["yellow"] = 0.4
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Confirm]") == pytest.approx(0.9397, abs=0.01)
        assert system.get_content("current_step").get_prob("Final") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("a_m").to_discrete().get_best()) == "Book"
        assert str(system.get_content("u_m").get_best()).find("additional tickets?") != -1
        # Snapshot 3: state right after booking, re-installed at the end.
        copystate3 = copy(system.get_state())
        # No additional tickets: the dialogue closes.
        u_u.clear()
        u_u["thanks but no thanks"] = 0.7
        system.add_user_input(u_u)
        assert system.get_content("a_u").get_prob("[Disconfirm]") == pytest.approx(0.97, abs=0.01)
        assert system.get_content("current_step").get_prob("Close") == pytest.approx(1.0, abs=0.01)
        assert str(system.get_content("u_m").get_best()).find("welcome back!") != -1
        # The final state holds exactly these chance nodes, and the six slot
        # variables have one single joint value with probability ~1.
        assert sorted(system.get_state().get_chance_node_ids()) == ["Date", "Departure", "Destination", "NbTickets", "ReturnDate", "TotalCost", "a_m", "a_m-prev", "a_u", "current_step", "u_m", "u_u"]
        assert len(system.get_content(["Date", "Departure", "Destination", "NbTickets", "ReturnDate", "TotalCost"]).to_discrete().get_values()) == 1
        assert system.get_content(["Date", "Departure", "Destination", "NbTickets", "ReturnDate", "TotalCost"]).to_discrete().get_prob(
            Assignment([
                Assignment("Date", "March 16"),
                Assignment("Departure", "Moss"),
                Assignment("Destination", "Tromsø"),
                Assignment("NbTickets", 1.),
                Assignment("ReturnDate", "March 20"),
                Assignment("TotalCost", 299.)
            ])
        ) == pytest.approx(1.0, abs=0.01)

        # Replace the live state by snapshot 3 and ask for additional tickets:
        # the slots are dropped and a new booking starts with the destination.
        system.get_state().remove_nodes(system.get_state().get_chance_node_ids())  # empty the live state first
        system.get_state().add_network(copystate3)
        u_u.clear()
        u_u["yes"] = 0.7
        system.add_user_input(u_u)
        assert not system.get_state().has_chance_node("Destination")
        assert str(system.get_content("u_m").get_best()).find("destination?") != -1
        assert sorted(system.get_state().get_chance_node_ids()) == ["Destination^p", "a_m-prev", "current_step", "u_m", "u_u"]
        system.add_user_input("Oslo")
        assert str(system.get_content("a_m").get_best()) == "Ground(Destination,Oslo)"
示例#22
0
class TestPruning:
    """Probability checks on ``test/data/domain1.xml`` against one shared system.

    The domain is loaded and the dialogue system is started once, while the
    class body executes, and every test reads that same system.  Tests that
    snapshot the state restore it in a ``finally`` clause, so a failing
    assertion cannot leak a modified state into the tests that run afterwards.
    """

    domainFile = "test/data/domain1.xml"

    domain = XMLDomainReader.extract_domain(domainFile)
    inference = InferenceChecks()
    # Tolerances are set on the InferenceChecks class itself and are not
    # restored afterwards.
    InferenceChecks.exact_threshold = 0.1
    InferenceChecks.sampling_threshold = 0.1
    system = DialogueSystem(domain)
    system.get_settings().show_gui = False

    system.start_system()

    def _check_probs(self, variable, expected):
        """Check every ``(value, probability)`` pair of *variable* on the shared state."""
        for value, prob in expected:
            TestPruning.inference.check_prob(TestPruning.system.get_state(), variable, value, prob)

    def test_pruning0(self):
        """The started system holds 15 nodes and no evidence variables."""
        assert len(TestPruning.system.get_state().get_node_ids()) == 15
        assert len(TestPruning.system.get_state().get_evidence().get_variables()) == 0

    def test_pruning1(self):
        """Distribution of ``a_u``."""
        self._check_probs("a_u", [("Greeting", 0.8), ("None", 0.2)])

    def test_pruning2(self):
        """Distribution of ``i_u``."""
        self._check_probs("i_u", [("Inform", 0.7 * 0.8), ("None", 1 - 0.7 * 0.8)])

    def test_pruning3(self):
        """Distribution of ``direction``."""
        self._check_probs("direction", [("straight", 0.79), ("left", 0.20), ("right", 0.01)])

    def test_pruning4(self):
        """Distribution of ``o``."""
        self._check_probs("o", [
            ("and we have var1=value2", 0.3),
            ("and we have localvar=value1", 0.2),
            ("and we have localvar=value3", 0.31),
        ])

    def test_pruning5(self):
        """Distribution of ``o2``."""
        self._check_probs("o2", [
            ("here is value1", 0.35),
            ("and value2 is over there", 0.07),
            ("value3, finally", 0.28),
        ])

    def test_pruning6(self):
        """Adding ``var1`` via ``add_to_state`` leaves the expected values of ``o`` as in test_pruning4."""
        initial_state = copy(TestPruning.system.get_state())
        try:
            builder = CategoricalTableBuilder("var1")
            builder.add_row("value2", 0.9)
            TestPruning.system.get_state().add_to_state(builder.build())

            self._check_probs("o", [
                ("and we have var1=value2", 0.3),
                ("and we have localvar=value1", 0.2),
                ("and we have localvar=value3", 0.31),
            ])
        finally:
            # The system is shared: undo the modification even if a check failed.
            TestPruning.system.get_state().reset(initial_state)

    def test_pruning7(self):
        """Distribution of ``a_u2``."""
        self._check_probs("a_u2", [("[Greet, HowAreYou]", 0.7), ("none", 0.1), ("[HowAreYou]", 0.2)])

    def test_pruning8(self):
        """``a_u3`` refers to two generated ``a_u3^...`` nodes, identified here by their values."""
        initial_state = copy(TestPruning.system.get_state())
        try:
            created_nodes = SortedSet()
            for node_id in TestPruning.system.get_state().get_node_ids():
                if "a_u3^" in node_id:
                    created_nodes.add(node_id)

            assert len(created_nodes) == 2

            # Which of the two generated nodes carries "Greet" is not fixed,
            # so look at the values of the first one to tell them apart.
            first_node, last_node = created_nodes[0], created_nodes[-1]
            values = TestPruning.system.get_state().get_node(first_node).get_values()
            if ValueFactory.create("Greet") in values:
                greet_node, howareyou_node = first_node, last_node
            else:
                greet_node, howareyou_node = last_node, first_node

            self._check_probs("a_u3", [
                ("[" + howareyou_node + "," + greet_node + "]", 0.7),
                ("none", 0.1),
                ("[" + howareyou_node + "]", 0.2),
            ])
            self._check_probs(greet_node, [("Greet", 0.7)])
            self._check_probs(howareyou_node, [("HowAreYou", 0.9)])
        finally:
            TestPruning.system.get_state().reset(initial_state)
示例#23
0
class TestParameters:
    """Tests for rule parameters (the ``theta_*`` variables) loaded from ``test/data/params.xml``.

    Both domains are read once, while the class body executes, and receive the
    same parameter network; each test then builds its own dialogue system.
    """

    domain_file = "test/data/testwithparams.xml"
    domain_file2 = "test/data/testwithparams2.xml"
    param_file = "test/data/params.xml"

    params = XMLStateReader.extract_bayesian_network(param_file, "parameters")

    domain1 = XMLDomainReader.extract_domain(domain_file)
    domain1.set_parameters(params)

    domain2 = XMLDomainReader.extract_domain(domain_file2)
    domain2.set_parameters(params)

    inference = InferenceChecks()

    @staticmethod
    def _new_system(domain):
        """Return a not yet started system for *domain*, without ForwardPlanner and without GUI."""
        system = DialogueSystem(domain)
        system.detach_module(ForwardPlanner)
        system.get_settings().show_gui = False
        return system

    @staticmethod
    def _say(system, utterance):
        """Remove the action and utility nodes from the state, then add *utterance* as ``u_u``."""
        system.get_state().remove_nodes(system.get_state().get_action_node_ids())
        system.get_state().remove_nodes(system.get_state().get_utility_node_ids())
        system.add_content("u_u", utterance)

    def test_param_1(self):
        """Check the CDFs of ``theta_1``/``theta_2`` and the utilities of ``u_m'`` after a greeting."""
        # Set on the InferenceChecks class itself; not restored afterwards.
        InferenceChecks.exact_threshold = 0.1

        system = self._new_system(TestParameters.domain1)
        assert system.get_state().has_chance_node("theta_1")
        TestParameters.inference.check_cdf(system.get_state(), "theta_1", 0.5, 0.5)
        TestParameters.inference.check_cdf(system.get_state(), "theta_1", 5., 0.99)

        TestParameters.inference.check_cdf(system.get_state(), "theta_2", 1., 0.07)
        TestParameters.inference.check_cdf(system.get_state(), "theta_2", 2., 0.5)

        system.start_system()
        system.add_content("u_u", "hello there")
        utils = SamplingAlgorithm().query_util(system.get_state(), "u_m'")
        dismissive = utils.get_util(Assignment("u_m'", "yeah yeah talk to my hand"))
        interested = utils.get_util(Assignment("u_m'", "so interesting!"))
        assert dismissive > 0
        assert interested > 1.7
        assert dismissive < interested
        assert len(system.get_state().get_node_ids()) == 11

    def test_param_2(self):
        """Check the CDF of ``theta_3`` and the resulting probability of ``approval``."""
        InferenceChecks.exact_threshold = 0.1

        system = self._new_system(TestParameters.domain1)

        assert system.get_state().has_chance_node("theta_3")
        TestParameters.inference.check_cdf(system.get_state(), "theta_3", 0.6, 0.0)
        TestParameters.inference.check_cdf(system.get_state(), "theta_3", 0.8, 0.5)
        TestParameters.inference.check_cdf(system.get_state(), "theta_3", 0.95, 1.0)

        system.start_system()
        system.add_content("u_u", "brilliant")
        distrib = system.get_content("a_u")

        assert distrib.get_prob("approval") == pytest.approx(0.8, abs=0.05)

    def test_param_3(self):
        """The ``disapproval`` effect is driven by a SingleParameter over ``theta_4``."""
        system = self._new_system(TestParameters.domain1)
        system.start_system()

        rules = TestParameters.domain1.get_models()[0].get_rules()
        outputs = rules[1].get_output(Assignment("u_u", "no no"))
        effect = Effect(BasicEffect("a_u", "disapproval"))
        assert isinstance(outputs.get_parameter(effect), SingleParameter)
        # Named ``param_values`` rather than ``input`` to avoid shadowing the builtin.
        param_values = Assignment("theta_4", ValueFactory.create("[0.36, 0.64]"))
        assert outputs.get_parameter(effect).get_value(param_values) == pytest.approx(0.64, abs=0.01)

        self._say(system, "no no")
        assert system.get_state().query_prob("theta_4").to_continuous().get_function().get_mean()[0] == pytest.approx(0.36, abs=0.1)

        assert system.get_content("a_u").get_prob("disapproval") == pytest.approx(0.64, abs=0.1)

    def test_param_4(self):
        """The ``Pierre`` effect is driven by a SingleParameter over ``theta_5``."""
        system = self._new_system(TestParameters.domain1)
        system.start_system()

        rules = TestParameters.domain1.get_models()[1].get_rules()
        outputs = rules[0].get_output(Assignment("u_u", "my name is"))
        effect = Effect(BasicEffect("u_u^p", "Pierre"))
        assert isinstance(outputs.get_parameter(effect), SingleParameter)
        param_values = Assignment("theta_5", ValueFactory.create("[0.36, 0.24, 0.40]"))
        assert outputs.get_parameter(effect).get_value(param_values) == pytest.approx(0.36, abs=0.01)

        # The same two-utterance exchange is fed to the system twice.
        for utterance in ("my name is", "Pierre", "my name is", "Pierre"):
            self._say(system, utterance)

        assert system.get_state().query_prob("theta_5").to_continuous().get_function().get_mean()[0] == pytest.approx(0.3, abs=0.12)

    def test_param_5(self):
        """The ``approval`` effect is driven by a ComplexParameter over ``theta_6`` and ``theta_7``."""
        system = self._new_system(TestParameters.domain2)
        system.start_system()

        rules = TestParameters.domain2.get_models()[0].get_rules()
        outputs = rules[0].get_output(Assignment("u_u", "brilliant"))
        effect = Effect(BasicEffect("a_u", "approval"))
        assert isinstance(outputs.get_parameter(effect), ComplexParameter)
        param_values = Assignment([Assignment("theta_6", 2.1), Assignment("theta_7", 1.3)])
        assert outputs.get_parameter(effect).get_value(param_values) == pytest.approx(0.74, abs=0.01)

        self._say(system, "brilliant")

        assert system.get_state().query_prob("theta_6").to_continuous().get_function().get_mean()[0] == pytest.approx(1.0, abs=0.08)

        assert system.get_content("a_u").get_prob("approval") == pytest.approx(0.63, abs=0.08)
        assert system.get_content("a_u").get_prob("irony") == pytest.approx(0.3, abs=0.08)

    def test_param_6(self):
        """Check the distribution of ``b`` in ``test/data/testparams3.xml`` (ForwardPlanner stays attached)."""
        system = DialogueSystem(XMLDomainReader.extract_domain("test/data/testparams3.xml"))
        system.get_settings().show_gui = False
        system.start_system()
        table = system.get_content("b").to_discrete()
        assert len(table) == 6
        assert table.get_prob("something else") == pytest.approx(0.45, abs=0.05)
        assert table.get_prob("value: first with type 1") == pytest.approx(0.175, abs=0.05)
        assert table.get_prob("value: second with type 2") == pytest.approx(0.05, abs=0.05)
示例#24
0
 def test_function(self):
     """Load ``test/data/relationaltest.xml`` and start a GUI-less dialogue system on it."""
     system = DialogueSystem(XMLDomainReader.extract_domain("test/data/relationaltest.xml"))
     system.get_settings().show_gui = False
     system.start_system()
示例#25
0
    def test_demo(self):
        """Run a scripted dialogue on ``TestDemo.domain_file2`` and check each system reply.

        User turns are N-best lists of ``(utterance, probability)`` pairs; after
        each turn the best value of ``u_m`` is compared with the expected reply,
        or ``u_m`` is asserted to be absent from the state.
        """
        system = DialogueSystem(XMLDomainReader.extract_domain(TestDemo.domain_file2))
        system.get_settings().show_gui = False

        system.start_system()
        assert len(system.get_state().get_chance_nodes()) == 5

        def hear(*hypotheses):
            # Add the hypotheses as a categorical table over "u_u" and return
            # whatever add_content returns.
            table = CategoricalTableBuilder("u_u")
            for utterance, prob in hypotheses:
                table.add_row(utterance, prob)
            return system.add_content(table.build())

        def reply():
            # Best current value of the system utterance, as a string.
            return str(system.get_content("u_m").get_best())

        updates = hear(("hello there", 0.7), ("hello", 0.2))
        assert updates.issuperset({"a_u", "a_m", "u_m"})
        assert reply() == "Hi there"

        # The same request at 0.06 yields no system utterance, at 0.45 it does.
        system.add_user_input({"move forward": 0.06})
        assert not system.get_state().has_chance_node("u_m")

        system.add_user_input({"move forward": 0.45})
        assert reply() == "OK, moving Forward"

        hear(("now do that again", 0.3), ("move backward", 0.22), ("move a bit to the left", 0.22))
        assert reply() == "Sorry, could you repeat?"

        hear(("do that one more time", 0.65))
        assert reply() == "OK, moving Forward"

        system.add_content(SingleValueDistribution("perceived", "[BlueObj]"))

        hear(("what do you see", 0.6), ("do you see it", 0.3))
        assert reply() == "I see a blue cylinder"

        hear(("pick up the blue object", 0.75), ("turn left", 0.12))
        assert reply() == "OK, picking up the blue object"

        system.add_content(SingleValueDistribution("perceived", "[]"))
        system.add_content(SingleValueDistribution("carried", "[BlueObj]"))

        hear(("now please move a bit forward", 0.21), ("move backward a little bit", 0.13))
        assert reply() == "Should I move a bit forward?"

        hear(("yes", 0.8), ("move backward", 0.1))
        assert reply() == "OK, moving Forward a little bit"

        hear(("and now move forward", 0.21), ("move backward", 0.09))
        assert reply() == "Should I move forward?"

        hear(("no", 0.6))
        assert reply() == "Should I move backward?"

        hear(("yes", 0.5))
        assert reply() == "OK, moving Backward"

        hear(("now what can you see now?", 0.7))
        assert reply() == "I do not see anything"

        hear(("please release the object", 0.5))
        assert reply() == "OK, putting down the object"

        hear(("something unexpected", 0.7))
        assert not system.get_state().has_chance_node("u_m")

        hear(("goodbye", 0.7))
        assert reply() == "Bye, see you next time"
示例#26
0
class TestRule1:
    """Probability checks on ``test/data/domain1.xml``, one fresh dialogue system per test.

    Most tests switch off ``StatePruner.enable_reduction`` while they run.  The
    flag lives on the class and is therefore global, so it is set back to
    ``True`` in a ``finally`` clause: a failing check must not leave reduction
    disabled for the tests that run afterwards.
    """

    domain = XMLDomainReader.extract_domain('test/data/domain1.xml')
    inference = InferenceChecks()

    def _check_without_reduction(self, variable, expected):
        """Start a system with reduction disabled and check ``(value, probability)`` pairs of *variable*."""
        system = DialogueSystem(TestRule1.domain)
        system.detach_module(ForwardPlanner)
        StatePruner.enable_reduction = False
        try:
            system.get_settings().show_gui = False
            system.start_system()

            for value, prob in expected:
                TestRule1.inference.check_prob(system.get_state(), variable, value, prob)
        finally:
            StatePruner.enable_reduction = True

    def test_1(self):
        """Distribution of ``a_u``."""
        self._check_without_reduction("a_u", [("Greeting", 0.8), ("None", 0.2)])

    def test_2(self):
        """Distribution of ``i_u``."""
        self._check_without_reduction("i_u", [("Inform", 0.7 * 0.8), ("None", 1. - 0.7 * 0.8)])

    def test_3(self):
        """Distribution of ``direction``, checked with an exact threshold of 0.06."""
        # Set on the InferenceChecks class itself; not restored afterwards.
        InferenceChecks.exact_threshold = 0.06

        self._check_without_reduction("direction", [("straight", 0.79), ("left", 0.20), ("right", 0.01)])

    def test_4(self):
        """Distribution of ``o``."""
        self._check_without_reduction("o", [
            ("and we have var1=value2", 0.3),
            ("and we have localvar=value1", 0.2),
            ("and we have localvar=value3", 0.28),
        ])

    def test_5(self):
        """Distribution of ``o2``."""
        self._check_without_reduction("o2", [
            ("here is value1", 0.35),
            ("and value2 is over there", 0.07),
            ("value3, finally", 0.28),
        ])

    def test_6(self):
        """Distribution of ``o`` after ``var1`` is added through ``add_content``.

        Unlike the other tests this one does not disable reduction; it only
        sets the flag to ``True`` when it is done.
        """
        InferenceChecks.exact_threshold = 0.06

        try:
            system = DialogueSystem(TestRule1.domain)
            system.detach_module(ForwardPlanner)
            system.get_settings().show_gui = False
            system.start_system()

            builder = CategoricalTableBuilder("var1")
            builder.add_row("value2", 0.9)
            system.add_content(builder.build())

            expected = [
                ("and we have var1=value2", 0.9),
                ("and we have localvar=value1", 0.05),
                ("and we have localvar=value3", 0.04),
            ]
            for value, prob in expected:
                TestRule1.inference.check_prob(system.get_state(), "o", value, prob)
        finally:
            StatePruner.enable_reduction = True

    def test_7(self):
        """Distribution of ``a_u2``."""
        self._check_without_reduction("a_u2", [("[Greet, HowAreYou]", 0.7), ("[]", 0.1), ("[HowAreYou]", 0.2)])
示例#27
0
class TestRule2:
    """Inference tests over the domain2, domain3, domain4 and thesistest2 files.

    ``test_1`` to ``test_3`` temporarily override class-level settings on
    ``EquivalenceDistribution`` and ``StatePruner``.  The overrides are undone
    in ``finally`` clauses so that a failing check cannot leak the modified
    settings into tests that run later in the same process.
    """
    domain_file = "test/data/domain2.xml"
    domain_file2 = "test/data/domain3.xml"
    domain_file3 = "test/data/domain4.xml"
    domain_file4 = "test/data/thesistest2.xml"

    # Extracted once, when the class body is executed; used by test_1..test_3.
    domain = XMLDomainReader.extract_domain(domain_file)
    inference = InferenceChecks()

    @staticmethod
    def _remove_action_and_utility_nodes(system):
        """Remove the action nodes, then the utility nodes, from the state."""
        system.get_state().remove_nodes(
            system.get_state().get_action_node_ids())
        system.get_state().remove_nodes(
            system.get_state().get_utility_node_ids())

    def test_1(self):
        """Check ``a_u^p``, ``i_u`` and ``a_u`` before/after adding ``a_u``."""
        system = DialogueSystem(TestRule2.domain)
        check_prob = TestRule2.inference.check_prob

        eq_factor = EquivalenceDistribution.none_prob
        old_prune_threshold = StatePruner.value_pruning_threshold
        EquivalenceDistribution.none_prob = 0.1
        StatePruner.value_pruning_threshold = 0.0
        try:
            system.get_settings().show_gui = False
            system.detach_module(ForwardPlanner)
            system.start_system()

            check_prob(system.get_state(), "a_u^p", "Ask(A)", 0.63)
            check_prob(system.get_state(), "a_u^p", "Ask(B)", 0.27)
            check_prob(system.get_state(), "a_u^p", "None", 0.1)

            builder = CategoricalTableBuilder("a_u")
            builder.add_row("Ask(B)", 0.8)
            builder.add_row("None", 0.2)

            self._remove_action_and_utility_nodes(system)
            system.add_content(builder.build())

            check_prob(system.get_state(), "i_u", "Want(A)", 0.090)
            check_prob(system.get_state(), "i_u", "Want(B)", 0.91)

            check_prob(system.get_state(), "a_u^p", "Ask(B)", 0.91 * 0.9)
            check_prob(system.get_state(), "a_u^p", "Ask(A)", 0.09 * 0.9)
            check_prob(system.get_state(), "a_u^p", "None", 0.1)

            check_prob(system.get_state(), "a_u", "Ask(B)", 0.918)
            check_prob(system.get_state(), "a_u", "None", 0.081)
        finally:
            # Restore the shared settings even when a check above fails.
            EquivalenceDistribution.none_prob = eq_factor
            StatePruner.value_pruning_threshold = old_prune_threshold

    def test_2(self):
        """Check ``u_u2^p`` at start-up and ``i_u2`` after adding ``u_u2``."""
        system = DialogueSystem(TestRule2.domain)
        system.get_settings().show_gui = False
        system.detach_module(ForwardPlanner)
        check_prob = TestRule2.inference.check_prob

        eq_factor = EquivalenceDistribution.none_prob
        old_prune_threshold = StatePruner.value_pruning_threshold
        EquivalenceDistribution.none_prob = 0.1
        StatePruner.value_pruning_threshold = 0.0
        try:
            system.start_system()

            check_prob(system.get_state(), "u_u2^p", "Do A", 0.216)
            check_prob(system.get_state(), "u_u2^p", "Please do C", 0.027)
            check_prob(system.get_state(), "u_u2^p", "Could you do B?", 0.054)
            check_prob(system.get_state(), "u_u2^p", "Could you do A?", 0.162)
            check_prob(system.get_state(), "u_u2^p", "none", 0.19)

            builder = CategoricalTableBuilder("u_u2")
            builder.add_row("Please do B", 0.4)
            builder.add_row("Do B", 0.4)

            self._remove_action_and_utility_nodes(system)
            system.add_content(builder.build())

            check_prob(system.get_state(), "i_u2", "Want(B)", 0.654)
            check_prob(system.get_state(), "i_u2", "Want(A)", 0.1963)
            check_prob(system.get_state(), "i_u2", "Want(C)", 0.0327)
            check_prob(system.get_state(), "i_u2", "none", 0.1168)
        finally:
            # Restore the shared settings even when a check above fails.
            EquivalenceDistribution.none_prob = eq_factor
            StatePruner.value_pruning_threshold = old_prune_threshold

    def test_3(self):
        """Check the utilities of ``a_m'`` before and after adding ``a_u``."""
        system = DialogueSystem(TestRule2.domain)
        system.get_settings().show_gui = False
        system.detach_module(ForwardPlanner)
        check_util = TestRule2.inference.check_util

        eq_factor = EquivalenceDistribution.none_prob
        old_prune_threshold = StatePruner.value_pruning_threshold
        EquivalenceDistribution.none_prob = 0.1
        StatePruner.value_pruning_threshold = 0.0
        try:
            system.start_system()

            check_util(system.get_state(), "a_m'", "Do(A)", 0.6)
            check_util(system.get_state(), "a_m'", "Do(B)", -2.6)

            builder = CategoricalTableBuilder("a_u")
            builder.add_row("Ask(B)", 0.8)
            builder.add_row("None", 0.2)

            self._remove_action_and_utility_nodes(system)
            system.add_content(builder.build())

            check_util(system.get_state(), "a_m'", "Do(A)", -4.35)
            check_util(system.get_state(), "a_m'", "Do(B)", 2.357)
        finally:
            # Restore the shared settings even when a check above fails.
            EquivalenceDistribution.none_prob = eq_factor
            StatePruner.value_pruning_threshold = old_prune_threshold

    def test_4(self):
        """Check joint utilities over ``a_m3'`` and ``obj(a_m3)'`` (domain3)."""
        domain2 = XMLDomainReader.extract_domain(TestRule2.domain_file2)
        system2 = DialogueSystem(domain2)
        system2.get_settings().show_gui = False
        system2.detach_module(ForwardPlanner)
        system2.start_system()

        # (value of a_m3', value of obj(a_m3)', expected utility)
        expected = (
            ("Do", "A", 0.3),
            ("Do", "B", -1.7),
            ("SayHi", "None", -0.9),
        )
        for action, obj, utility in expected:
            TestRule2.inference.check_util(
                system2.get_state(), ["a_m3'", "obj(a_m3)'"],
                Assignment([
                    Assignment("a_m3'", action),
                    Assignment("obj(a_m3)'", obj)
                ]), utility)

    def test_5(self):
        """Check joint utilities over ``a_ml'``, ``a_mg'``, ``a_md'`` (domain4)."""
        domain2 = XMLDomainReader.extract_domain(TestRule2.domain_file3)
        system2 = DialogueSystem(domain2)
        system2.detach_module(ForwardPlanner)
        system2.get_settings().show_gui = False
        system2.start_system()

        # (value of a_ml', value of a_mg', value of a_md', expected utility)
        expected = (
            ("SayYes", "Nod", "None", 2.4),
            ("SayYes", "Nod", "DanceAround", -0.6),
            ("SayYes", "None", "None", 1.6),
        )
        for a_ml, a_mg, a_md, utility in expected:
            TestRule2.inference.check_util(
                system2.get_state(), ["a_ml'", "a_mg'", "a_md'"],
                Assignment([
                    Assignment("a_ml'", a_ml),
                    Assignment("a_mg'", a_mg),
                    Assignment("a_md'", a_md)
                ]), utility)

    def test_6(self):
        """Check the start-up distributions of ``A`` and ``a_u`` (thesistest2)."""
        domain2 = XMLDomainReader.extract_domain(TestRule2.domain_file4)
        system2 = DialogueSystem(domain2)
        system2.detach_module(ForwardPlanner)
        system2.get_settings().show_gui = False
        system2.start_system()
        TestRule2.inference.check_prob(system2.get_state(), "A",
                                       ValueFactory.create("[a1,a2]"), 1.0)
        TestRule2.inference.check_prob(system2.get_state(), "a_u",
                                       "Request(ball)", 0.5)
示例#28
0
# Optional path to a domain file from which a user simulator is built below.
parser.add_argument('--simulator', type=str, help='simulator file path')
args = parser.parse_args()

# Set logger: the 'PyOpenDial' logger passes everything from DEBUG upwards
# to a StreamHandler.
logger = logging.getLogger('PyOpenDial')
logger.setLevel(logging.DEBUG)
logger.addHandler(logging.StreamHandler())

# Build the dialogue system, handing over the domain path when one was given.
# NOTE(review): `args.domain` is presumably registered by a `--domain`
# argument defined before this point -- confirm.
if args.domain is not None:
    system = DialogueSystem(args.domain)
else:
    system = DialogueSystem()

# Extract the domain explicitly so that a failure can be reported through the
# system; on any error an empty domain for the same path is installed instead.
if args.domain:
    try:
        domain = XMLDomainReader.extract_domain(args.domain)
        system.log.info("Domain from %s successfully extracted" % args.domain)
    except Exception as e:
        system.display_comment("Cannot load domain: %s" % e)
        domain = XMLDomainReader.extract_empty_domain(args.domain)
    system.change_domain(domain)

# Optionally attach a simulator built from its own domain file.  Unlike the
# domain above, an extraction error here is not caught.
if args.simulator:
    simulator = Simulator(system,
                          XMLDomainReader.extract_domain(args.simulator))
    system.log.info("Simulator with domain %s successfully extracted" %
                    args.simulator)
    system.attach_module(simulator)

# Hand the system's current settings object back through change_settings().
# NOTE(review): presumably this triggers settings-dependent setup -- confirm.
settings = system.get_settings()
system.change_settings(settings)
示例#29
0
    def test_domain(self):
        """Run the step-by-step example domain through a fixed series of turns.

        Every turn adds a ``u_u`` table to the running system.  For checked
        turns, the utility of ``AskRepeat`` and of one movement action is then
        queried on ``a_m'`` and compared with its expected value.  After each
        turn, checked or not, the action and utility nodes are removed from
        the state.
        """
        dialogue = DialogueSystem(
            XMLDomainReader.extract_domain(
                "test/data/example-step-by-step_params.xml"))
        dialogue.detach_module(ForwardPlanner)
        dialogue.get_settings().show_gui = False
        dialogue.start_system()

        # Each entry: (rows of the u_u table, action to check, expected
        # utility).  The action is None for a turn that is only fed in.
        turns = [
            ([("move a little bit left", 0.4),
              ("please move a little right", 0.3)], "Move(Left)", -0.1),
            ([("move a little bit left", 0.5),
              ("please move a little left", 0.2)], "Move(Left)", 0.2),
            ([("now move right please", 0.8)], "Move(Right)", 0.3),
            ([("move left", 0.7)], "Move(Left)", 0.2),
            ([("turn left", 0.32),
              ("move left again", 0.3)], "Move(Left)", 0.1),
            ([("and do that again", 0.0)], None, None),
            ([("turn left", 1.0)], "Move(Left)", 0.5),
            ([("turn right", 0.4)], "Move(Right)", -0.1),
            ([("please turn right", 0.8)], "Move(Right)", 0.3),
            ([("and turn a bit left", 0.3),
              ("move a bit left", 0.3)], "Move(Left)", 0.1),
        ]

        for rows, action, expected_util in turns:
            user_table = CategoricalTableBuilder("u_u")
            for utterance, prob in rows:
                user_table.add_row(utterance, prob)
            dialogue.add_content(user_table.build())

            if action is not None:
                # The utilities are queried afresh for each assertion.
                assert dialogue.get_state().query_util(["a_m'"]).get_util(
                    Assignment("a_m'", "AskRepeat")) == pytest.approx(
                        0.0, abs=0.3)
                assert dialogue.get_state().query_util(["a_m'"]).get_util(
                    Assignment("a_m'", action)) == pytest.approx(
                        expected_util, abs=0.15)

            dialogue.get_state().remove_nodes(
                dialogue.get_state().get_action_node_ids())
            dialogue.get_state().remove_nodes(
                dialogue.get_state().get_utility_node_ids())
示例#30
0
class TestRefResolution:
    """Tests driving the refres.xml and underspectest.xml domains."""

    # Extracted once, when the class body is executed; used by test_nlu and
    # test_resolution.
    domain = XMLDomainReader.extract_domain("test/data/refres.xml")

    def test_nlu(self):
        """Check the content of the ``ref_*`` variables after user inputs."""
        system = DialogueSystem(TestRefResolution.domain)
        system.get_settings().show_gui = False
        system.start_system()

        def best(variable):
            """Return the best value currently held for ``variable``."""
            return system.get_content(variable).get_best()

        value = ValueFactory.create

        system.add_user_input("take the red box")
        assert best("properties(ref_main)") == value(
            "[type=box, def=def, nb=sg, attr=red]")
        system.add_user_input("take the big yellow box")
        assert best("properties(ref_main)") == value(
            "[type=box, def=def, nb=sg, attr=big, attr=yellow]")
        system.add_user_input("take the big and yellow box")
        assert best("properties(ref_main)") == value(
            "[type=box, def=def, nb=sg, attr=big, attr=yellow]")
        system.add_user_input("take the big box on your left")
        assert best("properties(ref_main)") == value(
            "[rel=left(agent), type=box, def=def, nb=sg, attr=big]")

        # "the left" is expected to split evenly between agent and speaker.
        system.add_user_input("take the big box on the left")
        assert system.get_content("properties(ref_main)").to_discrete().get_prob(
            "[rel=left(agent), type=box, def=def, nb=sg, attr=big]"
        ) == pytest.approx(0.5, abs=0.01)
        assert system.get_content("properties(ref_main)").to_discrete().get_prob(
            "[rel=left(spk), type=box, def=def, nb=sg, attr=big]"
        ) == pytest.approx(0.5, abs=0.01)

        system.add_user_input("take one box now")
        assert best("properties(ref_main)") == value(
            "[def=indef, nb=sg, type=box]")
        system.add_user_input("take the small and ugly box ")
        assert best("properties(ref_main)") == value(
            "[type=box, def=def, nb=sg, attr=small, attr=ugly]")

        system.add_user_input("now please pick up the book that is behind you")
        assert best("properties(ref_main)") == value(
            "[type=book, def=def, nb=sg, rel=behind(ref_behind)]")
        assert best("ref_behind") == value("you")
        assert best("ref_main") == value("the book")

        system.add_user_input("could you take the red ball on the desk")
        assert best("properties(ref_main)") == value(
            "[type=ball, attr=red, rel=on(ref_on), def=def, nb=sg]")
        assert best("ref_main") == value("the red ball")
        assert best("ref_on") == value("the desk")

        system.add_user_input("could you take the red ball next to the window")
        assert best("properties(ref_main)") == value(
            "[type=ball, attr=red, rel=next to(ref_next to), def=def, nb=sg]")
        assert best("ref_main") == value("the red ball")
        assert best("ref_next to") == value("the window")

        system.add_user_input(
            "could you take the big red ball near the window to your left")
        assert best("properties(ref_main)") == value(
            "[type=ball, attr=red, attr=big, rel=near(ref_near), def=def, nb=sg]")
        assert best("ref_main") == value("the big red ball")
        assert best("properties(ref_near)") == value(
            "[type=window, rel=left(agent), def=def, nb=sg]")
        assert best("ref_near") == value("the window")

        system.add_user_input(
            "could you take the big red ball near the window and to your left")
        assert best("properties(ref_main)") == value(
            "[type=ball, attr=red, attr=big, rel=left(agent), rel=near(ref_near), def=def, nb=sg]")
        assert best("ref_main") == value("the big red ball")
        assert best("properties(ref_near)") == value(
            "[type=window, def=def, nb=sg]")
        assert best("ref_near") == value("the window")

        system.add_user_input(
            "and now pick up the books that are on top of the shelf")
        assert best("properties(ref_main)") == value(
            "[type=book, rel=top(ref_top), def=def, nb=pl]")
        assert best("properties(ref_top)") == value(
            "[type=shelf,def=def, nb=sg]")
        system.add_user_input("and now pick up one book which is big")
        assert best("properties(ref_main)") == value(
            "[type=book,def=indef, attr=big, nb=sg]")

        # N-best input: each hypothesis should keep its own probability.
        nbest = {
            "and take the red book": 0.5,
            "and take the wred hook": 0.1,
        }
        system.add_user_input(nbest)
        assert system.get_content("properties(ref_main)").get_prob(
            "[type=book,attr=red,def=def,nb=sg]"
        ) == pytest.approx(0.5, abs=0.01)
        assert system.get_content("properties(ref_main)").get_prob(
            "[type=hook,attr=wred,def=def,nb=sg]"
        ) == pytest.approx(0.1, abs=0.01)

    def test_resolution(self):
        """Check the selected action and ``matches(ref_main)`` per input."""
        system = DialogueSystem(TestRefResolution.domain)
        system.get_settings().show_gui = False
        system.start_system()

        def action():
            """Return the best value of ``a_m`` rendered as a string."""
            return str(system.get_content("a_m").get_best())

        def match_prob(objects):
            """Return the probability of ``objects`` in ``matches(ref_main)``."""
            return system.get_content("matches(ref_main)").get_prob(objects)

        system.add_user_input("take the red ball")
        assert action() == "Select(object_1)"
        assert match_prob("[object_1]") == pytest.approx(0.94, abs=0.05)

        system.add_user_input("take the red object")
        assert action() == "AskConfirm(object_1)"
        assert match_prob("[object_1,object_3]") == pytest.approx(0.39, abs=0.05)
        assert match_prob("[object_1]") == pytest.approx(0.108, abs=0.05)

        system.add_user_input("take the box")
        assert action() == "Select(object_2)"
        assert match_prob("[object_2]") == pytest.approx(0.7, abs=0.05)
        assert match_prob("[]") == pytest.approx(0.3, abs=0.05)

        nbest = {"and take the ball now": 0.3}
        system.add_user_input(nbest)
        assert action() == "AskConfirm(object_1)"
        assert match_prob("[object_1]") == pytest.approx(0.27, abs=0.005)
        system.add_user_input("yes")
        assert action() == "Select(object_1)"

        system.add_user_input("pick up the newspaper")
        assert action() == "Failed(the newspaper)"

        # Which object is proposed is left open here; only the action type
        # is checked.
        system.add_user_input("pick up an object")
        assert action().startswith("AskConfirm(object_")
        assert match_prob("[object_1,object_2,object_3]") == pytest.approx(
            1.0, abs=0.005)
        system.add_user_input("no")
        assert action().startswith("AskConfirm(object_")
        system.add_user_input("yes")
        assert action().startswith("Select(object_")

        system.add_user_input("pick up the ball to the left of the box")
        assert action() == "Select(object_1)"
        assert match_prob("[object_1]") == pytest.approx(0.75, abs=0.05)

        system.add_user_input("pick up the box to the left of the ball")
        assert action() == "Failed(the box)"
        assert match_prob("[object_2]") == pytest.approx(0.34, abs=0.05)

    def test_underspec(self):
        """Check ``match`` and the chance-node count with reduction disabled.

        ``StatePruner.enable_reduction`` is shared by every test in the
        process, so it is re-enabled in a ``finally`` clause: a failing
        assertion must not leave reduction switched off for later tests.
        """
        domain = XMLDomainReader.extract_domain("test/data/underspectest.xml")
        system = DialogueSystem(domain)
        system.get_settings().show_gui = False
        StatePruner.enable_reduction = False
        try:
            system.start_system()
            assert system.get_content("match").get_prob("obj_1") == pytest.approx(0.66, abs=0.05)
            assert system.get_content("match").get_prob("obj_3") == pytest.approx(0.307, abs=0.05)
            assert len(system.get_state().get_chance_node_ids()) == 14
        finally:
            # Same end state as before (reduction on), now also on failure.
            StatePruner.enable_reduction = True