Example #1
0
 def test_get_batch(self):
     """get_batch should map tokens to index batches in both case modes."""
     sentences = [
         ['All', 'work', 'and', 'no', 'play'],
         ['makes', 'Jack', 'a', 'dull', 'boy', '.'],
     ]
     token_dict = {
         'all': 3,
         'work': 4,
         'and': 5,
         'no': 6,
         'play': 7,
         'makes': 8,
         'a': 9,
         'dull': 10,
         'boy': 11,
         '.': 12,
     }
     # Case-sensitive lookup: 'All' and 'Jack' miss the dict and map to 1.
     inputs, outputs = BiLM.get_batch(sentences, token_dict, ignore_case=False)
     self.assertEqual(
         [
             [1, 4, 5, 6, 7, 0],
             [8, 1, 9, 10, 11, 12],
         ],
         inputs.tolist())
     self.assertEqual(
         [
             [[4], [5], [6], [7], [2], [0]],
             [[1], [9], [10], [11], [12], [2]],
         ],
         outputs[0].tolist())
     self.assertEqual(
         [
             [[2], [1], [4], [5], [6], [0]],
             [[2], [8], [1], [9], [10], [11]],
         ],
         outputs[1].tolist())
     # Case-insensitive lookup: 'All' now resolves to 3; 'Jack' is still unknown.
     inputs, outputs = BiLM.get_batch(sentences, token_dict, ignore_case=True)
     self.assertEqual(
         [
             [3, 4, 5, 6, 7, 0],
             [8, 1, 9, 10, 11, 12],
         ],
         inputs.tolist())
     self.assertEqual(
         [
             [[4], [5], [6], [7], [2], [0]],
             [[1], [9], [10], [11], [12], [2]],
         ],
         outputs[0].tolist())
     self.assertEqual(
         [
             [[2], [3], [4], [5], [6], [0]],
             [[2], [8], [1], [9], [10], [11]],
         ],
         outputs[1].tolist())
Example #2
0
def train_batch_generator(batch_size=32, training=True):
    """Yield ``(word_input, labels)`` batches of movie-review sentences forever.

    Each yielded batch contains ``batch_size // 2`` positive reviews followed
    by the same number of negative reviews, sampled at random; labels are
    one-hot with positives labelled 1 and negatives 0.

    :param batch_size: Total samples per batch (split evenly pos/neg).
    :param training: Sample from the train file lists when True, else the
                     validation lists.
    """
    batch_size //= 2  # half positive, half negative per yielded batch
    while True:
        if training:
            batch_pos = random.sample(train_pos_files, batch_size)
            batch_neg = random.sample(train_neg_files, batch_size)
        else:
            # NOTE(review): validation files are also read from TRAIN_ROOT
            # below — confirm the validation split lives in that directory.
            batch_pos = random.sample(val_pos_files, batch_size)
            batch_neg = random.sample(val_neg_files, batch_size)
        sentences = []
        # Read positives first, then negatives, so sentence order matches
        # the [1] * n + [0] * n label layout below.  (The original appended
        # negatives outside the ``with`` block — harmless but inconsistent;
        # both loops now share one consistent form.)
        for sub_dir, batch_files in (('pos', batch_pos), ('neg', batch_neg)):
            for file_name in batch_files:
                with codecs.open(os.path.join(TRAIN_ROOT, sub_dir, file_name),
                                 'r', 'utf8') as reader:
                    text = reader.read().strip()
                    sentences.append(get_word_list_eng(text))
        word_input, _ = BiLM.get_batch(
            sentences=sentences,
            token_dict=word_dict,
            ignore_case=True,
        )
        yield word_input, keras.utils.to_categorical([1] * batch_size +
                                                     [0] * batch_size)
Example #3
0
 def test_bidirectional_overfitting(self):
     """Overfit a tiny bidirectional LM on two sentences and check that it
     memorises both directions.

     The forward head must predict each next token and the backward head
     each previous token; 'Jack' is absent from ``token_dict`` so it must
     decode as '<UNK>'.
     """
     sentences = [
         ['All', 'work', 'and', 'no', 'play'],
         ['makes', 'Jack', 'a', 'dull', 'boy', '.'],
     ]
     # Index 0 is padding, 1 the unknown token, 2 the end-of-sequence marker.
     token_dict = {
         '': 0,
         '<UNK>': 1,
         '<EOS>': 2,
         'all': 3,
         'work': 4,
         'and': 5,
         'no': 6,
         'play': 7,
         'makes': 8,
         'a': 9,
         'dull': 10,
         'boy': 11,
         '.': 12,
     }
     # Reverse mapping used to decode predicted indices back to tokens.
     token_dict_rev = {v: k for k, v in token_dict.items()}
     inputs, outputs = BiLM.get_batch(sentences,
                                      token_dict,
                                      ignore_case=True,
                                      unk_index=token_dict['<UNK>'],
                                      eos_index=token_dict['<EOS>'])
     bi_lm = BiLM(token_num=len(token_dict),
                  embedding_dim=10,
                  rnn_units=10,
                  use_bidirectional=True)
     bi_lm.model.summary()
     # Repeat the two samples 2**12 times so 5 epochs suffice to overfit.
     bi_lm.fit(
         np.repeat(inputs, 2**12, axis=0),
         [
             np.repeat(outputs[0], 2**12, axis=0),
             np.repeat(outputs[1], 2**12, axis=0),
         ],
         epochs=5,
     )
     predict = bi_lm.predict(inputs)
     forward = predict[0].argmax(axis=-1)   # next-token predictions
     backward = predict[1].argmax(axis=-1)  # previous-token predictions
     # The shorter sentence is padded, so its last position is dropped ([:-1]).
     self.assertEqual(
         'work and no play <EOS>',
         ' '.join(map(lambda x: token_dict_rev[x],
                      forward[0].tolist()[:-1])).strip())
     self.assertEqual(
         '<UNK> a dull boy . <EOS>',
         ' '.join(map(lambda x: token_dict_rev[x],
                      forward[1].tolist())).strip())
     self.assertEqual(
         '<EOS> all work and no', ' '.join(
             map(lambda x: token_dict_rev[x],
                 backward[0].tolist()[:-1])).strip())
     self.assertEqual(
         '<EOS> makes <UNK> a dull boy',
         ' '.join(map(lambda x: token_dict_rev[x],
                      backward[1].tolist())).strip())
Example #4
0
def train_lm_generator(batch_size=32):
    """Yield ``(inputs, outputs)`` LM batches, cycling over ``sentences``.

    :param batch_size: Number of sentences per batch.

    BUG FIX: the original never incremented ``index``, so the inner loop
    yielded the first batch forever instead of advancing through the data.
    """
    while True:
        index = 0
        while index * batch_size < len(sentences):
            batch_sentences = sentences[index * batch_size:
                                        (index + 1) * batch_size]
            index += 1  # advance to the next batch (missing in the original)
            inputs, outputs = BiLM.get_batch(batch_sentences,
                                             token_dict=word_dict,
                                             ignore_case=True)
            yield inputs, outputs
Example #5
0
def lm_batch_generator(sentences, steps):
    """Yield ``(inputs, outputs)`` LM batches forever, ``steps`` per epoch.

    :param sentences: Full list of tokenised sentences to batch over.
    :param steps: Number of batches that make up one pass.
    """
    global word_dict, char_dict, max_word_len
    while True:
        for step in range(steps):
            start = BATCH_SIZE * step
            stop = min(BATCH_SIZE * (step + 1), len(sentences))
            # get_batch returns the (inputs, outputs) pair yielded to the caller.
            yield BiLM.get_batch(
                sentences=sentences[start:stop],
                token_dict=word_dict,
                ignore_case=True,
                unk_index=word_dict['<UNK>'],
                eos_index=word_dict['<EOS>'],
            )
Example #6
0
def test_batch_generator(batch_size=32):
    """Yield ``word_input`` batches over the test set, in file order.

    Each batch holds up to ``batch_size // 2`` positive reviews followed by
    the same slice of negative reviews; iteration stops after ``test_num``
    files per class have been consumed.

    :param batch_size: Total samples per batch (split evenly pos/neg).
    """
    batch_size //= 2  # half positive, half negative per yielded batch
    index = 0
    while index < test_num:
        batch_pos = test_pos_files[index:min(index + batch_size, test_num)]
        batch_neg = test_neg_files[index:min(index + batch_size, test_num)]
        index += batch_size
        sentences = []
        # Positives first, then negatives, matching train_batch_generator's
        # ordering.  (The original appended negatives outside the ``with``
        # block — harmless but inconsistent; both loops now share one form.)
        for sub_dir, batch_files in (('pos', batch_pos), ('neg', batch_neg)):
            for file_name in batch_files:
                with codecs.open(os.path.join(TEST_ROOT, sub_dir, file_name),
                                 'r', 'utf8') as reader:
                    text = reader.read().strip()
                    sentences.append(get_word_list_eng(text))
        word_input, _ = BiLM.get_batch(
            sentences=sentences,
            token_dict=word_dict,
            ignore_case=True,
        )
        yield word_input