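import tensorflow as tf  # TF1-style graph API (tf.while_loop, tf.AUTO_REUSE)

import gpt2_model  # module exposing model() and past_shape(); exact import path is an assumption


# The sampling loop below calls top_k_logits and top_p_logits, which are not
# defined in this section. The following are minimal sketches of what those
# filters might look like, modeled on the sampling code in the original
# OpenAI GPT-2 release; the masking constant -1e10 and the exact formulation
# are assumptions, not necessarily this code's own implementation.

def top_k_logits(logits, k):
    """Keep only the k largest logits per row; mask the rest to ~-inf."""
    if k == 0:
        # k == 0 is treated as "no top-k filtering".
        return logits
    values, _ = tf.nn.top_k(logits, k=k)
    min_values = values[:, -1, tf.newaxis]  # k-th largest logit per row
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )


def top_p_logits(logits, p):
    """Nucleus filtering: keep the smallest prefix of probability-sorted
    tokens whose cumulative mass reaches p; mask everything else to ~-inf."""
    logits_sort = tf.sort(logits, direction='DESCENDING', axis=-1)
    probs_sort = tf.nn.softmax(logits_sort)
    # Exclusive cumsum so the token that crosses the p threshold is kept.
    probs_sums = tf.cumsum(probs_sort, axis=1, exclusive=True)
    logits_masked = tf.where(
        probs_sums < p, logits_sort, tf.ones_like(logits_sort) * 1e10)
    # Smallest logit that survives the nucleus cut-off, per row.
    min_logits = tf.reduce_min(logits_masked, axis=1, keepdims=True)
    return tf.where(
        logits < min_logits,
        tf.ones_like(logits) * -1e10,
        logits,
    )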
def sample_sequence(
        hparams,
        length,
        start_token=None,
        batch_size=None,
        context=None,
        temperature=1,
        top_k=0,
        top_p=0.0,
):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        # Run the transformer over `tokens`, reusing the cached attention
        # keys/values in `past` so each decoding step processes one token.
        lm_output = gpt2_model.model(
            hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            gpt2_model.past_shape(hparams=hparams, batch_size=batch_size))
        return {'logits': logits, 'presents': presents}

    with tf.name_scope('sample_sequence'):
        # Prime the key/value cache on all but the last context token; that
        # last token is fed as `prev` on the first loop iteration.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            # Nucleus (top-p) filtering takes precedence; otherwise fall
            # back to top-k (k == 0 means no truncation).
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(
                logits, num_samples=1, dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        # `cond` is always true; `maximum_iterations` bounds the loop, so
        # exactly `length` tokens are sampled. `back_prop=False` skips
        # keeping activations, since sampling needs no gradients.
        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[context_output['presents'], context[:, -1], context],
            shape_invariants=[
                tf.TensorShape(
                    gpt2_model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )
        return tokens
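
# A minimal usage sketch, assuming the TF1 Session API and GPT-2-style
# checkpoints. `default_hparams`, `get_encoder`, and the 'models/124M'
# checkpoint path are illustrative assumptions, not part of the code above.

if __name__ == '__main__':
    batch_size = 1
    hparams = gpt2_model.default_hparams()  # assumed hparams loader
    enc = get_encoder('models/124M')        # hypothetical BPE tokenizer loader

    with tf.Session() as sess:
        context = tf.placeholder(tf.int32, [batch_size, None])
        output = sample_sequence(
            hparams=hparams, length=40,
            context=context, batch_size=batch_size,
            temperature=0.7, top_p=0.9)

        # Restore weights after the graph is built so the Saver sees them.
        saver = tf.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint('models/124M'))

        tokens = enc.encode('Hello, world!')
        out = sess.run(output, feed_dict={context: [tokens]})
        print(enc.decode(out[0]))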