Example #1
    def __init__(self, intents=None, testing=False):
        # Load the topic and conversation datasets
        self.td = TopicProcessor()
        self.cp = ConversationProcessor()
        self.convX, self.convy = self.cp.get_dataset()

        self.topX, self.topy = self.td.get_dataset()
        # One-hot encode the topic labels
        self.topy = np_utils.to_categorical(self.topy)
        self.top_classes = self.td.get_num_categories()
        self.top_sentence_length = self.td.get_sentence_length()
        self.graph = None
        self.seq2seqmodel = seq2seqmodel()
        if not testing:
            # Try to load an existing model from file
            try:
                clear_session()
                self.model, self.graph, self.sess = self._load_model()
            except (FileNotFoundError, OSError):
                # If the model doesn't exist, we'll have to build one
                model = self._build_model()
                self.model = self._train_model(model, self.convX, self.convy,
                                               self.topX, self.topy)
                self.model, self.graph, self.sess = self._load_model()
            # If the user hasn't supplied a location, use the default
            if not intents:
                intents = 'chatbot/corpus/output.json'

            with open(intents, 'r') as json_file:
                self.intents = json.load(json_file)
        else:
            # In testing mode, skip model loading and run the k-fold evaluation instead
            # self._test_model(self.convX, self.convy, self.topX, self.topy)
            self._k_fold_test(self.convX, self.convy, self.topX, self.topy)
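The enclosing class is not shown in this excerpt; a minimal usage sketch, assuming a hypothetical class name ChatEngine in a hypothetical module chatbot.engine:

    # Hypothetical sketch: ChatEngine and chatbot.engine are assumed names, not from the source.
    from chatbot.engine import ChatEngine

    # Normal mode: load (or build and train) the model, then read the intents file.
    engine = ChatEngine(intents='chatbot/corpus/output.json')

    # Testing mode: skip model loading and run the k-fold evaluation instead.
    ChatEngine(testing=True)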
Example #2
 def test_get_sentence_length(self):
     tp = TopicProcessor()
     test_length = 4
     tp.max_sentence_length = test_length
     max_sentence_length = tp.get_sentence_length()
     self.assertEqual(test_length, max_sentence_length)
Example #3
 def test__load_files(self):
     tp = TopicProcessor()
     statements, topics = tp._load_files()
     # The loaded corpus should contain at least one statement
     self.assertTrue(len(statements))
Example #4
 def test_get_dataset(self):
     tp = TopicProcessor()
     X, y = tp.get_dataset()
     # Each encoded statement should contain more than one element
     self.assertGreater(X[0].size, 1)
Example #5
 def test_get_num_categories(self):
     tp = TopicProcessor()
     test_num = 4
     tp.num_topics = test_num
     num_topics = tp.get_num_categories()
     self.assertEqual(test_num, num_topics)
Example #6
 def test_init(self):
     # Construction should succeed without raising
     tp = TopicProcessor()
Example #7
 def test_decode_topic(self):
     tp = TopicProcessor()
     original_topic = 15
     decoded_topic = tp.decode_topic(original_topic)
     self.assertNotEqual(original_topic, decoded_topic)
     self.assertEqual(decoded_topic, 'abuse')
Example #8
 def test__build_dictionary(self):
     tp = TopicProcessor()
     _, topics = tp._load_files()
     dictionary, reverse_dictionary, num_topics = tp._build_dictionarys(topics)
     # The topic dictionary should contain more than one entry
     self.assertGreater(len(dictionary), 1)
Example #9
 def test_encode_topic(self):
     tp = TopicProcessor()
     original_topic = 'abuse'
     encoded_topic = tp.encode_topic(original_topic)
     self.assertNotEqual(original_topic, encoded_topic)
     self.assertEqual(encoded_topic, 15)
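Examples #7 and #9 together suggest that encode_topic and decode_topic are inverses; a short round-trip sketch (the import path below is an assumption):

    from chatbot.topic_processor import TopicProcessor  # hypothetical module path

    tp = TopicProcessor()
    code = tp.encode_topic('abuse')          # 15, per the tests above
    assert tp.decode_topic(code) == 'abuse'  # decoding recovers the original label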
Example #10
 def test_encode_statement(self):
     tp = TopicProcessor()
     original_statement = 'My parents hit me'
     new_statement = tp.encode_statement(original_statement)
     self.assertNotEqual(original_statement, new_statement)
Example #11
 def test__pad_lines(self):
     tp = TopicProcessor()
     padded_lines = tp._pad_lines([[40]])
     # Padding should extend the line to the maximum length and keep the original token
     self.assertGreater(len(padded_lines[0]), 39)
     self.assertIn(40, padded_lines[0])
Example #12
 def test__vectorize_text(self):
     tp = TopicProcessor()
     statements, _ = tp._load_files()
     vectorized_text = tp._vectorize_text(statements)
     self.assertNotEqual(statements, vectorized_text)
Example #13
 def test__create_tokenizer(self):
     tp = TopicProcessor()
     statements, _ = tp._load_files()
     tokenizer = tp._create_tokenizer(statements)
     self.assertIsInstance(tokenizer, keras.preprocessing.text.Tokenizer)
Example #14
 def test__clean_text(self):
     tp = TopicProcessor()
     statements_1, topics_1 = tp._load_files()
     statements_2, topics_2 = tp._clean_text(statements_1, topics_1)
     self.assertNotEqual(statements_1, statements_2)
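The test methods above would normally be grouped in a single unittest.TestCase; a minimal runner sketch, with the class and module names assumed rather than taken from the project:

    import unittest

    from chatbot.topic_processor import TopicProcessor  # hypothetical module path


    class TestTopicProcessor(unittest.TestCase):
        def test_init(self):
            # Construction should succeed without raising
            TopicProcessor()


    if __name__ == '__main__':
        unittest.main()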