def test_fn_transition(self):
        """A direct-input node routes through the value of its transition_fn.

        Grades below 80 go to 'not good', grades from 80 up to (not
        including) 90 go to 'okay', and grades of 90 or more go to 'good'.
        """
        transitions = ['not good', 'okay', 'good']

        def transition_fn(grade):
            # Guard-style returns; a value for which every comparison is
            # False (e.g. NaN) falls through to the ValueError.
            if grade < 80:
                return 0
            if grade < 90:
                return 1
            if grade >= 90:
                return 2
            raise ValueError

        node = Node(
            name='check grade',
            content='How did you do?',
            options='is how I did',
            message_type='direct input',
            transitions=transitions,
            transition_fn=transition_fn,
        )

        # (grades to try, transition each of them must produce)
        expectations = (
            (range(80), transitions[0]),
            (range(80, 90), transitions[1]),
            (range(90, 110), transitions[2]),
        )
        for grades, expected in expectations:
            for grade in grades:
                self.assertEqual(expected, node.get_transition(grade))
 def test_single_choice_transitions(self):
     """A multiple-choice node given a single transition uses it for every option."""
     options = ['Good', 'Okay', 'Bad']
     transition = 'oh'
     n = Node(name='name',
              content='How are you?',
              options=options,
              message_type='multiple choice',
              transitions=transition)
     for option in options:
         # subTest reports which option failed instead of stopping at the first.
         with self.subTest(option=option):
             self.assertEqual(transition, n.get_transition(option))
 def test_multiple_choice_transition(self):
     """Each option of a multiple-choice node maps to the transition at the same position."""
     options = ['Good', 'Okay', 'Bad']
     transitions = ['great', 'good', 'too bad']
     n = Node(name='name',
              content='How are you?',
              options=options,
              message_type='multiple choice',
              transitions=transitions)
     # Checked first so the zip below cannot silently drop unmatched entries.
     self.assertEqual(len(transitions), len(options))
     for option, transition in zip(options, transitions):
         with self.subTest(option=option):
             self.assertEqual(transition, n.get_transition(option))
    def test_valid_and_invalid_setups(self):
        """Node accepts the valid setups and raises ValueError for the invalid ones.

        Each setup is ``[options, message_type, transitions, transition_fn]``.
        """
        one_option = 'o1'
        two_options = ['o1', 'o2']
        three_options = ['o1', 'o2', 'o3']

        multi = Message.Type.MULTIPLE_CHOICE
        direct = Message.Type.DIRECT_INPUT

        one_transition = 't1'
        two_transitions = ['t1', 't2']
        three_transitions = ['t1', 't2', 't3']

        # `print` is just a convenient callable to pass as a transition_fn.
        is_fn = print
        none = None

        valid_sets = [
            [one_option, multi, one_transition, none],
            [two_options, multi, two_transitions, none],
            [three_options, multi, three_transitions, none],
            [three_options, multi, one_transition, none],
            [one_option, direct, one_transition, none],
            [one_option, direct, one_transition, is_fn],
            [one_option, direct, two_transitions, is_fn],
            [one_option, direct, three_transitions, is_fn],
        ]
        invalid_sets = [
            [one_option, multi, one_transition, is_fn],
            [two_options, multi, one_transition, is_fn],
            [three_options, multi, one_transition, is_fn],
            # Same as the valid direct-input rows above but without a
            # transition_fn (these previously passed `one_transition` as the
            # options by mistake).
            [one_option, direct, two_transitions, none],
            [one_option, direct, three_transitions, none],
            [two_options, direct, one_transition, is_fn],
            [three_options, direct, one_transition, is_fn],
        ]

        def build(options, message_type, transitions, transition_fn):
            # Only the four parameters under test vary between setups.
            return Node('name',
                        content='Here is a question, or is it?',
                        options=options,
                        message_type=message_type,
                        transitions=transitions,
                        transition_fn=transition_fn)

        for index, set_ in enumerate(valid_sets):
            with self.subTest(valid_set=index):
                self.assertIsInstance(build(*set_), Node)

        for index, set_ in enumerate(invalid_sets):
            with self.subTest(invalid_set=index):
                with self.assertRaises(ValueError):
                    build(*set_)
    def test_bad_node_transitions(self):
        """DirectedGraph rejects graphs whose wiring is invalid.

        The nodes are built *before* entering ``assertRaises`` so that only the
        ``DirectedGraph`` construction is under test; otherwise an exception
        raised by ``Node`` itself would make the assertion pass spuriously.
        """

        def make_nodes(node1_transition, node2_transition):
            # Two identical single-option multiple-choice nodes that differ
            # only in where they transition to.
            return [
                Node(name='node1',
                     content='foo',
                     options='bar',
                     message_type=Message.Type.MULTIPLE_CHOICE,
                     transitions=node1_transition),
                Node(name='node2',
                     content='foo',
                     options='bar',
                     message_type=Message.Type.MULTIPLE_CHOICE,
                     transitions=node2_transition),
            ]

        # Transition to a nonexistent node.
        nodes = make_nodes('not a node', 'exit')
        with self.assertRaises(ValueError):
            DirectedGraph(name="Foo", start_node="node1", nodes=nodes)

        # No exit: the two nodes only transition to each other.
        nodes = make_nodes('node2', 'node1')
        with self.assertRaises(ValueError):
            DirectedGraph(name="Foo", start_node="node1", nodes=nodes)

        # Bad start node: not the name of any node in the graph.
        nodes = make_nodes('node2', 'exit')
        with self.assertRaises(KeyError):
            DirectedGraph(name="Foo",
                          start_node="node that doesn't exist",
                          nodes=nodes)
Example #6
0
    def build_graph_from_dict(self,
                              interactions_dict,
                              graph_name,
                              text_populator=None,
                              speaking_rate=None):
        """Build a ``DirectedGraph`` from the node specs in ``interactions_dict``.

        Args:
            interactions_dict: Mapping of graph name to a graph spec holding a
                ``"start_node_name"`` and a ``"nodes"`` mapping of node name to
                node-info dict. It is not modified.
            graph_name: Key of the graph to build in ``interactions_dict``.
            text_populator: Passed unchanged to every ``Node``.
            speaking_rate: Optional prosody rate; when given, each node's
                content is wrapped in ``<prosody rate="...">...</prosody> ``.

        Returns:
            A ``DirectedGraph`` named ``graph_name`` that starts at the spec's
            ``"start_node_name"``.

        Raises:
            ValueError: If ``speaking_rate`` is given but is not one of the
                allowed rates.
            KeyError: If ``graph_name`` or a required key (e.g.
                ``"transitions"``, ``"content"``) is missing from the specs.
        """
        graph_dict = interactions_dict[graph_name]
        start_node_name = graph_dict["start_node_name"]

        valid_speaking_rates = ("x-slow", "slow", "medium", "fast", "x-fast")
        if speaking_rate is not None and speaking_rate not in valid_speaking_rates:
            raise ValueError("Not a valid speaking rate")

        # Values used for any of these keys a node spec leaves out.
        optional_values = {
            "args": None,
            "result_convert_from_str_fn": str,
            "result_db_key": None,
            "is_append_result": False,
            "tests": None,
            "is_confirm": False,
            "error_message": "Please enter a valid input",
            "error_options": ("Okay", "Oops"),
        }

        nodes = []
        for node_name, node_spec in graph_dict["nodes"].items():
            # Work on a per-node copy: the original filled defaults, replaced
            # names with bound methods and wrapped the content directly in the
            # caller's dict, so building the same graph twice double-wrapped
            # the content in prosody tags.
            node_info = dict(optional_values)
            node_info.update(node_spec)

            if speaking_rate is not None:
                node_info["content"] = (
                    "<prosody rate=\"{speaking_rate}\">".format(speaking_rate=speaking_rate)
                    + node_info["content"] + "</prosody> ")

            # Some converters/tests are named by string in the spec and are
            # resolved to methods of this object here.
            convert_fn = node_info["result_convert_from_str_fn"]
            if convert_fn == "save_tomorrow_checkin_datetime":
                node_info["result_convert_from_str_fn"] = self.next_day_checkin_datetime_from_str
            elif convert_fn == "later_today_checkin_datetime":
                node_info["result_convert_from_str_fn"] = self.later_today_checkin_datetime

            if node_info["tests"] == "check reading id":
                node_info["tests"] = self.check_reading_id

            nodes.append(Node(name=node_name,
                              transitions=node_info["transitions"],
                              content=node_info["content"],
                              options=node_info["options"],
                              message_type=node_info["message_type"],
                              args=node_info["args"],
                              result_convert_from_str_fn=node_info[
                                  "result_convert_from_str_fn"],
                              result_db_key=node_info["result_db_key"],
                              is_append_result=node_info["is_append_result"],
                              tests=node_info["tests"],
                              is_confirm=node_info["is_confirm"],
                              error_message=node_info["error_message"],
                              error_options=node_info["error_options"],
                              text_populator=text_populator))

        return DirectedGraph(name=graph_name,
                             nodes=nodes,
                             start_node=start_node_name)
        json.dump(variation_dict, f)

    import atexit
    atexit.register(lambda: os.remove(db_file))
    atexit.register(lambda: os.remove(variation_file))

    variety_populator_ = VarietyPopulator(variation_file)
    database_populator_ = DatabasePopulator(db_file)
    text_populator = TextPopulator(variety_populator_, database_populator_)

    ask_name = Node(name='ask name',
                    content="What's your name?",
                    options='Okay',
                    message_type='text entry',
                    result_db_key='user_name',
                    result_convert_from_str_fn=str,
                    tests=lambda x: len(x) > 1,
                    error_message='Enter something with at least two letters',
                    is_confirm=True,
                    text_populator=text_populator,
                    transitions='ask age')
    ask_age = Node(name='ask age',
                   content="Alright, {'db': 'user_name'}, how old are you?",
                   options='years_old',
                   message_type='text entry',
                   result_convert_from_str_fn=float,
                   result_db_key='user_age',
                   tests=[
                       lambda x: x >= 0,
                       lambda x: x <= 200,
                   ],
    'options': options,
    'message_type': Message.Type.MULTIPLE_CHOICE,
}


def left_option():
    """Return the entry at index 0 of the module-level ``options``."""
    leftmost = options[0]
    return leftmost


def right_option():
    """Return the entry at index 1 of the module-level ``options``."""
    rightmost = options[1]
    return rightmost


# Three nodes sharing the same multiple-choice message. The i-th transition in
# each list goes with the i-th option (see left_option / right_option):
# node1 -> node2 or itself, node2 -> node1 or node3, node3 -> 'exit' either way.
node1 = Node(**message_kwargs,
             name='node1',
             transitions=['node2', 'node1'])
node2 = Node(**message_kwargs,
             name='node2',
             transitions=['node1', 'node3'])
node3 = Node(**message_kwargs,
             name='node3',
             transitions=['exit', 'exit'])


class TestDirectedGraph(unittest.TestCase):