Esempio n. 1
0
 def test_save_load(self):
     """Check that a model survives a save / load round trip.

     Builds a model with ``get_model``, saves it to an ``.h5`` file in the
     system temp directory, reloads it through ``keras.models.load_model``
     with the project's custom objects, and prints the reloaded model's
     summary. The temporary file is removed afterwards, whether or not the
     round trip succeeded.
     """
     model = get_model(
         n_vocab=50257,
         n_ctx=1024,
         n_embd=768,
         n_head=12,
         n_layer=12,
     )
     # Random suffix keeps concurrent / repeated runs from colliding on the
     # same file name.
     model_path = os.path.join(tempfile.gettempdir(), 'test_gpt_2_%f.h5' % np.random.random())
     try:
         model.save(model_path)
         # The custom objects are required so that Keras can rebuild the
         # project-specific layers referenced by the saved file.
         model = keras.models.load_model(model_path, custom_objects=get_custom_objects())
         model.summary()
     finally:
         # The saved model file is large; never leave it behind in the temp
         # directory. It may not exist if ``model.save`` itself failed.
         if os.path.exists(model_path):
             os.remove(model_path)
Esempio n. 2
0
 def test_train_and_gen(self):
     """Train a small model on two sentences and check generation.

     The vocabulary is one token per code point in ``range(2 ** 9)`` plus
     the merged pairs 'Po' and 'er'. Weights are cached in ``gen_test.h5``
     next to this file: when the cache exists it is loaded and training
     runs for one more epoch; otherwise the model trains for ten epochs
     and the weights are written to the cache. The first prompt must then
     be completed to the first training sentence.
     """
     token_dict = {chr(code): code for code in range(2 ** 9)}
     # Each merged token takes the next free id.
     for merged in ('Po', 'er'):
         token_dict[merged] = len(token_dict)
     model = get_model(
         n_vocab=len(token_dict),
         n_ctx=100,
         n_embd=30,
         n_head=5,
         n_layer=2,
     )
     bpe = BytePairEncoding(
         token_dict=token_dict,
         bpe_rank={('P', 'o'): 0, ('e', 'r'): 1},
     )

     train_texts = [
         'Power, give me more power!',
         'From the day forth, my arm changed.',
     ]
     pad = bpe.encode(' ')
     encoded = [bpe.encode(sentence) for sentence in train_texts]
     longest = max(len(seq) for seq in encoded)
     # Right-pad every sequence with the encoded space up to the longest one.
     inputs = [seq + pad * (longest - len(seq)) for seq in encoded]
     # Targets are the inputs shifted left by one, padded at the end.
     outputs = [seq[1:] + pad for seq in inputs]

     weights_file = os.path.join(
         os.path.dirname(os.path.abspath(__file__)),
         'gen_test.h5',
     )
     has_cached_weights = os.path.exists(weights_file)
     if has_cached_weights:
         model.load_weights(weights_file)
     train_x = np.array(inputs * 1000)
     train_y = np.expand_dims(np.array(outputs * 1000), axis=-1)
     model.fit(
         x=train_x,
         y=train_y,
         epochs=1 if has_cached_weights else 10,
     )
     if not has_cached_weights:
         model.save_weights(weights_file)

     prompts = [
         'Power, give me more',
         'Power',
         'give me more ',
         'the day forth ',
         'From',
     ]
     results = generate(model, bpe, prompts, length=30)
     expected = 'Power, give me more power!'
     self.assertEqual(results[0][:len(expected)], expected)