import tensorflow as tf

import model  # the GPT-2 repo's model.py


def step(hparams, tokens, past=None):
    # Run the transformer over `tokens`, reusing the cached key/values in
    # `past` when given. `batch_size` is expected to be in scope: this helper
    # is normally nested inside sample_sequence, as in the variants below.
    lm_output = model.model(hparams=hparams, X=tokens, past=past,
                            reuse=tf.compat.v1.AUTO_REUSE)
    logits = lm_output['logits'][:, :, :hparams.n_vocab]
    presents = lm_output['present']
    presents.set_shape(model.past_shape(
        hparams=hparams, batch_size=batch_size))
    return {
        'logits': logits,
        'presents': presents,
    }
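
# A minimal sketch of how step() is meant to be driven: run the context once
# to fill the key/value cache, then feed one token at a time with `past`.
# The sample_sequence variants below do exactly this inside tf.while_loop;
# this unrolled greedy version is for intuition only. It assumes `batch_size`
# is in scope (as step() itself does) and that `context` is an int32 tensor
# of shape [batch_size, None] with at least two tokens.
def greedy_decode_sketch(hparams, context, steps=10):
    out = step(hparams, context[:, :-1])  # prime the cache on all but the last token
    past = out['presents']
    prev = context[:, -1]
    output = context
    for _ in range(steps):  # unrolled graph-mode loop; fine for a small fixed length
        nxt = step(hparams, prev[:, tf.newaxis], past=past)
        prev = tf.argmax(nxt['logits'][:, -1, :], axis=-1, output_type=tf.int32)
        past = tf.concat([past, nxt['presents']], axis=-2)  # append new key/values
        output = tf.concat([output, prev[:, tf.newaxis]], axis=1)
    return output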
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0, top_p=0.0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            # top_p_logits / top_k_logits are the nucleus and top-k filtering
            # helpers defined elsewhere in sample.py.
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(logits, num_samples=1, dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
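
# A sketch of wiring sample_sequence into a TF1-style session, loosely
# following the GPT-2 repo's interactive_conditional_samples.py. The model
# name, directory layout, and the `encoder` module are assumptions about
# that repo; adjust to taste.
import json
import os

import encoder  # assumed: the GPT-2 repo's BPE encoder module


def sampling_demo_sketch(model_name='124M', models_dir='models'):
    batch_size = 1
    enc = encoder.get_encoder(model_name, models_dir)
    hparams = model.default_hparams()
    with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
        hparams.override_from_dict(json.load(f))

    with tf.compat.v1.Session() as sess:
        context = tf.compat.v1.placeholder(tf.int32, [batch_size, None])
        output = sample_sequence(
            hparams=hparams, length=40, context=context,
            batch_size=batch_size, temperature=1.0, top_k=40)

        saver = tf.compat.v1.train.Saver()
        saver.restore(sess, tf.train.latest_checkpoint(
            os.path.join(models_dir, model_name)))

        tokens = sess.run(output, feed_dict={
            context: [enc.encode('Hello, world')] * batch_size})
        print(enc.decode(tokens[0]))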
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0, top_p=0.0):
    '''Not worth documenting at the moment, I barely understand what I did'''
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    # Split the context: the head grows as tokens are replayed through the
    # model one at a time, the tail holds the tokens still to be replayed.
    context_head = context[:, 0:1]
    context_tail = context[:, 1:]

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        # NOTE: context_head has length 1, so context_head[:, :-1] is an empty
        # slice; the first real forward pass happens inside the loop body.
        context_output = step(hparams, context_head[:, :-1])

        def body(past, prev, context_head, context_tail, all_logits):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            only_logits = logits
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                context_tail[:, 0],  # next token to replay
                tf.concat([context_head, context_tail[:, 0:1]], axis=1),
                context_tail[:, 1:],
                tf.concat(
                    [all_logits, tf.expand_dims(only_logits, 1)], axis=1),
            ]

        def cond(*args):
            return True

        past, prev, tokens, context_tail, all_logits = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context_head,
                context_tail,
                # empty logits accumulator; 50257 is GPT-2's vocabulary size
                # (hparams.n_vocab would avoid the hard-coded constant)
                tf.ones([batch_size, 0, 50257]),
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),  # past
                tf.TensorShape([batch_size]),               # prev
                tf.TensorShape([batch_size, None]),         # context head
                tf.TensorShape([batch_size, None]),         # context tail
                tf.TensorShape([batch_size, None, 50257]),  # all logits
            ],
            back_prop=False,
        )

        return past, prev, tokens, all_logits
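
# A sketch of what the all-logits variant above enables: a full next-token
# distribution for every replayed position, e.g. for scoring how likely each
# observed context token was. `context` and `batch_size` are assumed to be
# set up as in the demo above; `length` should not exceed the number of
# tokens being replayed, since the loop body indexes into context_tail.
def context_logprobs_sketch(hparams, context, batch_size, length):
    past, prev, tokens, all_logits = sample_sequence(
        hparams=hparams, length=length, context=context,
        batch_size=batch_size, temperature=1)
    # all_logits: [batch, length, 50257] -> per-position log-probabilities
    log_probs = tf.nn.log_softmax(all_logits, axis=-1)
    return tokens, log_probs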
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0, top_p=0.0,
                    return_attention=False):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(logits, num_samples=1, dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        past_n_present, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        if return_attention:
            # past_n_present has shape [batch, layers, 2, heads, sequence,
            # features]; index 0 of the third axis holds the cached keys and
            # index 1 the cached values.
            past = past_n_present[:, :, :1, :, :, :]
            present = past_n_present[:, :, 1:, :, :, :]
            # Compute a rough attention-like map from the temperature-scaled
            # key/value dot products (not the model's actual attention
            # weights, which would also need the queries).
            attention = tf.matmul(
                past / tf.cast(temperature, tf.float32),
                tf.transpose(present, perm=[0, 1, 2, 3, 5, 4]))
            attention = tf.nn.softmax(attention, axis=-1)
            return tokens, attention
        else:
            return tokens
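
# A sketch of fetching the attention-like map from the return_attention
# variant above and collapsing it into a single heatmap. The returned map has
# shape [batch, layers, 1, heads, sequence, sequence]; the session and
# placeholder setup are assumed to match the sampling demo earlier.
def attention_demo_sketch(sess, hparams, context, context_tokens, batch_size):
    tokens_t, attention_t = sample_sequence(
        hparams=hparams, length=20, context=context,
        batch_size=batch_size, temperature=1.0, top_k=40,
        return_attention=True)
    tokens_out, attn_out = sess.run(
        [tokens_t, attention_t],
        feed_dict={context: [context_tokens] * batch_size})
    # e.g. average over heads in the last layer of the first batch element
    heatmap = attn_out[0, -1, 0].mean(axis=0)  # -> [sequence, sequence]
    return tokens_out, heatmap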