def test_beam_search(self):
    """Tests beam_search.
    """
    decoder = TransformerDecoder(embedding=self._embedding)

    outputs = decoder(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        inputs=None,
        beam_width=5,
        start_tokens=self._start_tokens,
        end_token=2,
        max_decoding_length=self._max_decode_len,
        mode=tf.estimator.ModeKeys.PREDICT)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_ = sess.run(outputs)
        self.assertEqual(outputs_['log_prob'].shape,
                         (self._batch_size, 5))
        self.assertEqual(outputs_['sample_id'].shape,
                         (self._batch_size, self._max_decode_len, 5))
def test_train(self):
    """Tests train_greedy.
    """
    decoder = TransformerDecoder(embedding=self._embedding)
    # 6 blocks, each with:
    # - self multihead_attention: 4 dense layers without bias + 2 layer norm vars
    # - encdec multihead_attention: 4 dense layers without bias + 2 layer norm vars
    # - poswise_network: 2 dense layers with bias + 2 layer norm vars
    # plus 2 final layer norm vars
    outputs = decoder(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        inputs=self._inputs,
        decoding_strategy='train_greedy',
        mode=tf.estimator.ModeKeys.TRAIN)
    self.assertEqual(len(decoder.trainable_variables), 110)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_ = sess.run(outputs)
        self.assertIsInstance(outputs_, TransformerDecoderOutput)
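The 110 expected above follows from the per-block breakdown in the comment. A quick sanity check of that arithmetic (a standalone sketch, not part of the test; the counts are taken from the comment rather than re-derived from the decoder):

# Per block: self-attention (4 kernels + 2 layer-norm vars)
#          + encoder-decoder attention (4 kernels + 2 layer-norm vars)
#          + position-wise network (2 biased dense layers = 4 vars, + 2 layer-norm vars)
per_block = (4 + 2) + (4 + 2) + (2 + 2 + 2)   # = 18 variables per block
assert 6 * per_block + 2 == 110               # 6 blocks + final layer norm pair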
def test_beam_search(self):
    """Tests beam_search.
    """
    decoder = TransformerDecoder(
        vocab_size=self._vocab_size,
        output_layer=self._output_layer)
    decoder.eval()

    outputs = decoder(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        inputs=None,
        beam_width=5,
        start_tokens=self._start_tokens,
        end_token=self._end_token,
        max_decoding_length=self._max_decode_len)

    self.assertEqual(outputs['log_prob'].shape,
                     (self._batch_size, 5))
    self.assertEqual(outputs['sample_id'].shape,
                     (self._batch_size, self._max_decode_len, 5))
def test_decode_infer_sample(self):
    """Tests infer_sample.
    """
    decoder = TransformerDecoder(
        vocab_size=self._vocab_size,
        output_layer=self._output_layer)
    helper = tx_helper.SampleEmbeddingHelper(
        self._embedding_fn, self._start_tokens, self._end_token)

    outputs, length = decoder(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        inputs=None,
        helper=helper,
        max_decoding_length=self._max_decode_len,
        mode=tf.estimator.ModeKeys.PREDICT)

    with self.test_session() as sess:
        sess.run(tf.global_variables_initializer())
        outputs_ = sess.run(outputs)
        self.assertIsInstance(outputs_, TransformerDecoderOutput)
def test_decode_train(self):
    """Tests train_greedy.
    """
    decoder = TransformerDecoder(
        vocab_size=self._vocab_size,
        output_layer=self._output_layer)
    decoder.train()
    # 6 blocks, each with:
    # - self multihead_attention: 4 dense layers without bias + 2 layer norm vars
    # - encdec multihead_attention: 4 dense layers without bias + 2 layer norm vars
    # - poswise_network: 2 dense layers with bias + 2 layer norm vars
    # plus 2 final layer norm vars
    outputs = decoder(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        inputs=self._inputs,
        decoding_strategy='train_greedy')

    self.assertEqual(len(decoder.trainable_variables), 110)
    self.assertIsInstance(outputs, TransformerDecoderOutput)
def __init__(self,
             pretrained_model_name: Optional[str] = None,
             cache_dir: Optional[str] = None,
             hparams=None):
    super().__init__(pretrained_model_name=pretrained_model_name,
                     cache_dir=cache_dir,
                     hparams=hparams)

    if self.pretrained_model_dir:
        self._hparams = HParams(self.pretrained_model_hparams,
                                self._hparams.todict())

    # Word embedding
    self.word_embedder = WordEmbedder(
        vocab_size=self._hparams.vocab_size,
        hparams=self._hparams.embed)

    # Position embedding
    self.position_embedder = PositionEmbedder(
        position_size=self._hparams.position_size,
        hparams=self._hparams.position_embed)

    # The GPT2 decoder (a TransformerDecoder)
    self.decoder = TransformerDecoder(
        vocab_size=self._hparams.vocab_size,
        output_layer=self.word_embedder.embedding,
        hparams=self._hparams.decoder)

    if self.pretrained_model_dir:
        gpt2_utils.init_gpt2_checkpoint(self, self.pretrained_model_dir)
    elif self._hparams.initializer:
        initialize = layers.get_initializer(self._hparams.initializer)
        assert initialize is not None
        # Do not re-initialize LayerNorm modules.
        for name, param in self.named_parameters():
            if name.split('.')[-1] == 'weight' and 'layer_norm' not in name:
                initialize(param)
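Passing output_layer=self.word_embedder.embedding above re-uses the word-embedding weights as the decoder's output projection (weight tying). A minimal sketch of that pattern in plain PyTorch, with illustrative sizes (vocab_size=50257 and hidden_dim=768 are assumptions standing in for the real hparams):

import torch
import torch.nn.functional as F

vocab_size, hidden_dim = 50257, 768            # assumed, illustrative sizes
embedding = torch.nn.Embedding(vocab_size, hidden_dim)

# Re-using the embedding weight as the softmax projection is what
# output_layer=self.word_embedder.embedding achieves in the constructor above.
hidden = torch.randn(2, 4, hidden_dim)         # [batch, time, hidden]
logits = F.linear(hidden, embedding.weight)    # [batch, time, vocab_size]
assert logits.shape == (2, 4, vocab_size)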
def test_beam_search(self):
    """Tests beam_search.
    """
    decoder = TransformerDecoder(
        token_pos_embedder=self._embedding_fn,
        vocab_size=self._vocab_size,
        output_layer=self._output_layer)
    decoder.eval()
    beam_width = 5

    outputs = decoder(
        memory=self._memory,
        memory_sequence_length=self._memory_sequence_length,
        memory_attention_bias=None,
        inputs=None,
        beam_width=beam_width,
        start_tokens=self._start_tokens,
        end_token=self._end_token,
        max_decoding_length=self._max_decode_len)

    self.assertEqual(outputs['log_prob'].size(),
                     (self._batch_size, beam_width))
    # sample_id has shape [batch_size, decoded_length, beam_width];
    # the decoded length may be shorter than max_decoding_length.
    self.assertEqual(outputs['sample_id'].size(0), self._batch_size)
    self.assertLessEqual(outputs['sample_id'].size(1),
                         self._max_decode_len)
    self.assertEqual(outputs['sample_id'].size(2), beam_width)
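For reference, a tiny shape sketch of what these assertions expect from the beam-search outputs (dummy tensors only; not produced by the decoder):

import torch

batch_size, max_len, beam_width = 2, 6, 5
log_prob = torch.zeros(batch_size, beam_width)
sample_id = torch.zeros(batch_size, max_len, beam_width, dtype=torch.long)

assert log_prob.size() == (batch_size, beam_width)
assert sample_id.size(0) == batch_size
assert sample_id.size(1) <= max_len      # decoding may terminate early
assert sample_id.size(2) == beam_width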
def main(_):
    """Builds the model and runs.
    """
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    nsamples = FLAGS.nsamples
    batch_size = FLAGS.batch_size
    max_decoding_length = FLAGS.max_decoding_length
    ckpt_path = FLAGS.checkpoint

    # Load GPT-2 model configuration
    if FLAGS.config_type == "json":
        gpt2_config = model_utils.transform_gpt2_to_texar_config(
            FLAGS.config_model)
    elif FLAGS.config_type == 'texar':
        gpt2_config = importlib.import_module(FLAGS.config_model)
    else:
        raise ValueError('Unknown config_type.')

    assert max_decoding_length <= gpt2_config.decoder["position_size"], (
        "max_decoding_length should not exceed position_size")
    assert nsamples % batch_size == 0, (
        "nsamples must be divisible by batch_size")

    # Create a data pre-processor for, e.g., BPE encoding
    proc = processor.get_encoder("gpt2_pretrained_models/model_117M")

    context = tf.placeholder(tf.int32, [batch_size, None])
    context_length = tf.placeholder(tf.int32, [batch_size])

    end_token = proc.encoder['<|endoftext|>']
    if FLAGS.is_interactive:
        start_tokens = context[:, 0]
    else:
        start_tokens = tf.fill([batch_size], end_token)

    # Build the GPT-2 model
    embedder = tx.modules.WordEmbedder(
        vocab_size=gpt2_config.vocab_size,
        hparams=gpt2_config.embed)

    helper = tx.modules.TopKSampleEmbeddingHelper(
        embedding=embedder,
        start_tokens=start_tokens,
        end_token=end_token,
        top_k=FLAGS.top_k,
        softmax_temperature=FLAGS.temperature)

    decoder = TransformerDecoder(
        embedding=embedder.embedding,
        hparams=gpt2_config.decoder)

    with tf.Session() as sess:
        if FLAGS.is_interactive:
            # Generate continuations of context
            lm_output, _ = decoder(
                context=context,
                context_sequence_length=context_length,
                max_decoding_length=max_decoding_length,
                helper=helper,
                mode=tf.estimator.ModeKeys.PREDICT)

            # Load model checkpoint
            model_utils.init_gpt2_checkpoint(sess, ckpt_path)
            print("\nFinished loading\n")

            # Enter interactive mode
            while True:
                raw_text = input("Model input >>> ")
                while not raw_text:
                    print('Input should not be empty!')
                    raw_text = input("Model input >>> ")

                context_tokens = proc.encode(raw_text)
                feed_dict = {
                    context: [context_tokens for _ in range(batch_size)],
                    context_length:
                        [len(context_tokens) for _ in range(batch_size)],
                    tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
                }

                generated = 0
                for _ in range(nsamples // batch_size):
                    output = sess.run(lm_output, feed_dict=feed_dict)
                    sample_id = output.sample_id
                    for i in range(batch_size):
                        generated += 1
                        print("=" * 40 +
                              " SAMPLE " + str(generated) + " " + "=" * 40)
                        si = sample_id[i][len(context_tokens):]
                        print(proc.decode(si))
                print("=" * 80)
        else:
            # Generate samples from scratch
            lm_output, _ = decoder(
                max_decoding_length=max_decoding_length,
                helper=helper,
                mode=tf.estimator.ModeKeys.PREDICT)

            # Load model checkpoint
            model_utils.init_gpt2_checkpoint(sess, ckpt_path)
            print("\nFinished loading\n")

            feed_dict = {
                tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT,
            }
            generated = 0
            while nsamples == 0 or generated < nsamples:
                output = sess.run(lm_output, feed_dict=feed_dict)
                sample_id = output.sample_id
                for i in range(batch_size):
                    generated += batch_size
                    text = proc.decode(sample_id[i])
                    print("=" * 40 +
                          " SAMPLE " + str(generated) + " " + "=" * 40)
                    print(text)