コード例 #1
0
 def test_beam_search(self):
     """Checks output shapes of beam-search decoding in PREDICT mode.

     With a beam width of 5, ``log_prob`` must have shape
     ``(batch_size, 5)`` and ``sample_id`` must have shape
     ``(batch_size, max_decode_len, 5)``.
     """
     beam_width = 5
     decoder = TransformerDecoder(embedding=self._embedding)
     beam_outputs = decoder(
         memory=self._memory,
         memory_sequence_length=self._memory_sequence_length,
         memory_attention_bias=None,
         inputs=None,
         beam_width=beam_width,
         start_tokens=self._start_tokens,
         end_token=2,
         max_decoding_length=self._max_decode_len,
         mode=tf.estimator.ModeKeys.PREDICT)

     with self.test_session() as sess:
         sess.run(tf.global_variables_initializer())
         evaluated = sess.run(beam_outputs)
         self.assertEqual(evaluated['log_prob'].shape,
                          (self._batch_size, beam_width))
         self.assertEqual(
             evaluated['sample_id'].shape,
             (self._batch_size, self._max_decode_len, beam_width))
コード例 #2
0
    def test_train(self):
        """Checks the trainable-variable count and output type of
        ``train_greedy`` decoding in TRAIN mode.
        """
        decoder = TransformerDecoder(embedding=self._embedding)
        train_outputs = decoder(
            memory=self._memory,
            memory_sequence_length=self._memory_sequence_length,
            memory_attention_bias=None,
            inputs=self._inputs,
            decoding_strategy='train_greedy',
            mode=tf.estimator.ModeKeys.TRAIN)

        # Expected trainable variables:
        #   each of the 6 blocks holds
        #     self multihead attention:  4 dense w/o bias + 2 layer norm = 6
        #     encdec multihead attention: 4 dense w/o bias + 2 layer norm = 6
        #     poswise network: 2 dense with bias (2 vars each)
        #                      + 2 layer norm                         = 6
        #   plus 2 more layer norm vars           ->  6 * 18 + 2 = 110
        per_block = 6 + 6 + 6
        expected_num_vars = 6 * per_block + 2
        self.assertEqual(len(decoder.trainable_variables), expected_num_vars)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            evaluated = sess.run(train_outputs)
            self.assertIsInstance(evaluated, TransformerDecoderOutput)
コード例 #3
0
    def test_beam_search(self):
        """Checks output shapes of beam-search decoding with the decoder in
        eval mode.

        With a beam width of 5, ``log_prob`` must have shape
        ``(batch_size, 5)`` and ``sample_id`` must have shape
        ``(batch_size, max_decode_len, 5)``.
        """
        beam_width = 5
        decoder = TransformerDecoder(
            vocab_size=self._vocab_size,
            output_layer=self._output_layer)
        decoder.eval()

        beam_outputs = decoder(
            memory=self._memory,
            memory_sequence_length=self._memory_sequence_length,
            memory_attention_bias=None,
            inputs=None,
            beam_width=beam_width,
            start_tokens=self._start_tokens,
            end_token=self._end_token,
            max_decoding_length=self._max_decode_len)

        log_prob_shape = beam_outputs['log_prob'].shape
        self.assertEqual(log_prob_shape, (self._batch_size, beam_width))
        sample_id_shape = beam_outputs['sample_id'].shape
        self.assertEqual(sample_id_shape,
                         (self._batch_size, self._max_decode_len, beam_width))
コード例 #4
0
    def test_decode_infer_sample(self):
        """Checks that decoding through a ``SampleEmbeddingHelper`` in
        PREDICT mode evaluates to a ``TransformerDecoderOutput``.
        """
        decoder = TransformerDecoder(
            vocab_size=self._vocab_size,
            output_layer=self._output_layer)
        sample_helper = tx_helper.SampleEmbeddingHelper(
            self._embedding_fn, self._start_tokens, self._end_token)

        # The decoder returns a pair; only the first element is inspected
        # (the second is presumably the decoded lengths).
        infer_outputs, _ = decoder(
            memory=self._memory,
            memory_sequence_length=self._memory_sequence_length,
            memory_attention_bias=None,
            inputs=None,
            helper=sample_helper,
            max_decoding_length=self._max_decode_len,
            mode=tf.estimator.ModeKeys.PREDICT)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            evaluated = sess.run(infer_outputs)
            self.assertIsInstance(evaluated, TransformerDecoderOutput)
コード例 #5
0
 def test_decode_train(self):
     """Checks the trainable-variable count and output type of
     ``train_greedy`` decoding with the decoder in train mode.
     """
     decoder = TransformerDecoder(
         vocab_size=self._vocab_size,
         output_layer=self._output_layer)
     decoder.train()

     train_outputs = decoder(
         memory=self._memory,
         memory_sequence_length=self._memory_sequence_length,
         memory_attention_bias=None,
         inputs=self._inputs,
         decoding_strategy='train_greedy')

     # Expected trainable variables:
     #   each of the 6 blocks holds
     #     self multihead attention:  4 dense w/o bias + 2 layer norm = 6
     #     encdec multihead attention: 4 dense w/o bias + 2 layer norm = 6
     #     poswise network: 2 dense with bias (2 vars each)
     #                      + 2 layer norm                         = 6
     #   plus 2 more layer norm vars           ->  6 * 18 + 2 = 110
     per_block = 6 + 6 + 6
     expected_num_vars = 6 * per_block + 2
     self.assertEqual(len(decoder.trainable_variables), expected_num_vars)
     self.assertIsInstance(train_outputs, TransformerDecoderOutput)
コード例 #6
0
    def __init__(self,
                 pretrained_model_name: Optional[str] = None,
                 cache_dir: Optional[str] = None,
                 hparams=None):
        """Builds the GPT-2 model and initializes its parameters.

        Creates a word embedder, a position embedder and a
        ``TransformerDecoder`` whose output layer reuses the word-embedding
        matrix. Then either loads the pretrained checkpoint (when a
        pretrained model directory is available) or applies
        ``hparams.initializer`` to the non-LayerNorm ``weight`` parameters.

        Args:
            pretrained_model_name: Optional pretrained model name, forwarded
                to the base class.
            cache_dir: Optional cache directory, forwarded to the base class.
            hparams: Model hyperparameters, forwarded to the base class.

        Raises:
            ValueError: if ``hparams.initializer`` is set but
                ``layers.get_initializer`` cannot build an initializer
                from it.
        """
        super().__init__(pretrained_model_name=pretrained_model_name,
                         cache_dir=cache_dir,
                         hparams=hparams)

        if self.pretrained_model_dir:
            # Rebuild hparams from the pretrained model's configuration, with
            # the current hparams passed as the second (presumably default)
            # argument -- NOTE(review): confirm HParams precedence.
            self._hparams = HParams(self.pretrained_model_hparams,
                                    self._hparams.todict())

        # Word embedding
        self.word_embedder = WordEmbedder(vocab_size=self._hparams.vocab_size,
                                          hparams=self._hparams.embed)

        # Position embedding
        self.position_embedder = PositionEmbedder(
            position_size=self._hparams.position_size,
            hparams=self._hparams.position_embed)

        # The GPT2 decoder (a TransformerDecoder). The word-embedding matrix
        # is passed as the output layer, tying input and output embeddings.
        self.decoder = TransformerDecoder(
            vocab_size=self._hparams.vocab_size,
            output_layer=self.word_embedder.embedding,
            hparams=self._hparams.decoder)

        if self.pretrained_model_dir:
            gpt2_utils.init_gpt2_checkpoint(self, self.pretrained_model_dir)
        elif self._hparams.initializer:
            initialize = layers.get_initializer(self._hparams.initializer)
            # Explicit check instead of `assert`: asserts are stripped under
            # `python -O`, which would turn this into an obscure TypeError.
            if initialize is None:
                raise ValueError(
                    "`hparams.initializer` is set but no initializer could "
                    "be built from it.")
            # Only re-initialize parameters named `weight`, and do not
            # re-initialize LayerNorm modules.
            for name, param in self.named_parameters():
                if name.split(
                        '.')[-1] == 'weight' and 'layer_norm' not in name:
                    initialize(param)
コード例 #7
0
    def test_beam_search(self):
        """Checks output shapes of beam-search decoding in eval mode.

        ``log_prob`` must be ``(batch_size, beam_width)``. ``sample_id`` is
        laid out as ``(batch_size, time, beam_width)``; the time axis is only
        bounded above by ``max_decode_len``.
        """
        decoder = TransformerDecoder(token_pos_embedder=self._embedding_fn,
                                     vocab_size=self._vocab_size,
                                     output_layer=self._output_layer)
        decoder.eval()
        beam_width = 5
        outputs = decoder(memory=self._memory,
                          memory_sequence_length=self._memory_sequence_length,
                          memory_attention_bias=None,
                          inputs=None,
                          beam_width=beam_width,
                          start_tokens=self._start_tokens,
                          end_token=self._end_token,
                          max_decoding_length=self._max_decode_len)

        self.assertEqual(outputs['log_prob'].size(),
                         (self._batch_size, beam_width))
        sample_id = outputs['sample_id']
        self.assertEqual(sample_id.size(0), self._batch_size)
        # The decoded length lives on dim 1 (dim 2 is the beam axis, checked
        # below), so the max-length bound must be applied to dim 1.
        self.assertLessEqual(sample_id.size(1), self._max_decode_len)
        self.assertEqual(sample_id.size(2), beam_width)
コード例 #8
0
def main(_):
    """
    Builds the GPT-2 model, restores a checkpoint and prints samples.

    Samples are drawn with top-k sampling (``FLAGS.top_k``,
    ``FLAGS.temperature``). The mode is chosen by ``FLAGS.is_interactive``:

    * interactive: repeatedly reads a prompt from stdin and prints
      ``FLAGS.nsamples`` continuations of it, with the prompt tokens
      stripped from each sample;
    * unconditional: starts every sample from ``<|endoftext|>`` and prints
      ``FLAGS.nsamples`` samples (endlessly when ``nsamples == 0``).

    Raises:
        ValueError: on an unknown ``FLAGS.config_type``, when
            ``max_decoding_length`` exceeds the decoder position size, or
            when ``nsamples`` is not a multiple of ``batch_size``.
    """
    np.random.seed(FLAGS.seed)
    tf.set_random_seed(FLAGS.seed)

    nsamples = FLAGS.nsamples
    batch_size = FLAGS.batch_size
    max_decoding_length = FLAGS.max_decoding_length

    ckpt_path = FLAGS.checkpoint
    # Load GPT-2 model configuration
    if FLAGS.config_type == "json":
        gpt2_config = model_utils.transform_gpt2_to_texar_config(
            FLAGS.config_model)
    elif FLAGS.config_type == 'texar':
        gpt2_config = importlib.import_module(FLAGS.config_model)
    else:
        raise ValueError('Unknown config_type.')

    # Explicit checks rather than `assert`, which is stripped under
    # `python -O` and would let invalid settings through silently.
    if max_decoding_length > gpt2_config.decoder["position_size"]:
        raise ValueError(
            "max_decoding_length should not exceed the position size")
    if nsamples % batch_size != 0:
        raise ValueError("nsamples must be divisible by batch_size")

    # Create a data pre-processor for, e.g., BPE encoding
    proc = processor.get_encoder("gpt2_pretrained_models/model_117M")

    context = tf.placeholder(tf.int32, [batch_size, None])
    context_length = tf.placeholder(tf.int32, [batch_size])

    end_token = proc.encoder['<|endoftext|>']
    if FLAGS.is_interactive:
        start_tokens = context[:, 0]
    else:
        start_tokens = tf.fill([batch_size], end_token)

    # Build the GPT-2 model
    embedder = tx.modules.WordEmbedder(vocab_size=gpt2_config.vocab_size,
                                       hparams=gpt2_config.embed)

    helper = tx.modules.TopKSampleEmbeddingHelper(
        embedding=embedder,
        start_tokens=start_tokens,
        end_token=end_token,
        top_k=FLAGS.top_k,
        softmax_temperature=FLAGS.temperature)

    decoder = TransformerDecoder(embedding=embedder.embedding,
                                 hparams=gpt2_config.decoder)

    with tf.Session() as sess:

        if FLAGS.is_interactive:
            # Generate continuations of context
            lm_output, _ = decoder(context=context,
                                   context_sequence_length=context_length,
                                   max_decoding_length=max_decoding_length,
                                   helper=helper,
                                   mode=tf.estimator.ModeKeys.PREDICT)

            # Load model checkpoint
            model_utils.init_gpt2_checkpoint(sess, ckpt_path)
            print("\nFinished loading\n")

            # Enter interactive mode
            while True:

                raw_text = input("Model input >>> ")

                while not raw_text:
                    print('Input should not be empty!')
                    raw_text = input("Model input >>> ")

                context_tokens = proc.encode(raw_text)

                # Every row of the batch is fed the same prompt.
                feed_dict = {
                    context: [context_tokens for _ in range(batch_size)],
                    context_length:
                    [len(context_tokens) for _ in range(batch_size)],
                    tx.context.global_mode():
                    tf.estimator.ModeKeys.PREDICT
                }
                generated = 0
                for _ in range(nsamples // batch_size):

                    output = sess.run(lm_output, feed_dict=feed_dict)

                    sample_id = output.sample_id
                    for i in range(batch_size):

                        generated += 1
                        print("=" * 40 + " SAMPLE " + str(generated) + " " +
                              "=" * 40)
                        # Drop the prompt tokens; print only the continuation.
                        si = sample_id[i][len(context_tokens):]
                        print(proc.decode(si))
                print("=" * 80)
        else:
            # Generate samples from scratch
            lm_output, _ = decoder(max_decoding_length=max_decoding_length,
                                   helper=helper,
                                   mode=tf.estimator.ModeKeys.PREDICT)

            # Load model checkpoint
            model_utils.init_gpt2_checkpoint(sess, ckpt_path)
            print("\nFinished loading\n")

            feed_dict = {
                tx.context.global_mode(): tf.estimator.ModeKeys.PREDICT
            }
            generated = 0
            # nsamples == 0 means "generate forever".
            while nsamples == 0 or generated < nsamples:

                output = sess.run(lm_output, feed_dict=feed_dict)

                sample_id = output.sample_id
                for i in range(batch_size):

                    # Count samples one at a time so that the printed sample
                    # numbers and the `generated < nsamples` stop condition
                    # are both exact.
                    generated += 1
                    text = proc.decode(sample_id[i])
                    print("=" * 40 + " SAMPLE " + str(generated) + " " +
                          "=" * 40)
                    print(text)