def step(params, tokens, past=None, batch_size=None):
    """Run one forward pass of ``gpt2.model`` and return predictions and cache.

    Args:
        params: Config dict. ``params["precision"] == 'bfloat16'`` runs the
            model under ``tf.contrib.tpu.bfloat16_scope``; the whole dict is
            forwarded to ``gpt2.model`` and ``gpt2.past_shape``.
        tokens: Model input, passed to ``gpt2.model`` as ``X``.
        past: Optional cached ``present`` tensors from earlier calls,
            forwarded to ``gpt2.model``.
        batch_size: Batch size used to pin the static shape of ``presents``
            via ``gpt2.past_shape``. ``None`` leaves that dimension
            unconstrained. (The body previously referenced this name without
            it being defined in the function.)

    Returns:
        dict with ``'pred'`` (the model's ``pred`` output, cast to float32
        when run in bfloat16) and ``'presents'`` (the model's ``present``
        output).
    """
    if params["precision"] == 'bfloat16':
        with tf.contrib.tpu.bfloat16_scope():
            lm_output = gpt2.model(params=params, X=tokens, past=past,
                                   reuse=tf.AUTO_REUSE)
        # Cast back so downstream sampling math runs in float32.
        lm_output["pred"] = tf.cast(lm_output["pred"], tf.float32)
    else:
        lm_output = gpt2.model(params=params, X=tokens, past=past,
                               reuse=tf.AUTO_REUSE)

    pred = lm_output['pred']
    presents = lm_output['present']
    presents.set_shape(
        gpt2.past_shape(params=params, batch_size=batch_size))
    return {
        'pred': pred,
        'presents': presents,
    }
def sample_sequence(*, params, length, start_token=None, batch_size=None,
                    context=None, temperature=1, gaussian_noise=0, top_k=0):
    """Autoregressively sample continuous multi-bin frames from the model.

    At each step the model yields, for every feature and every bin, a logit
    and a residual. One bin is drawn per feature from the temperature-scaled,
    ``top_k_logits``-filtered logits. The sampled value is that bin's center
    plus its residual, plus Gaussian noise with stddev ``gaussian_noise``.

    NOTE(review): SOURCE defines ``sample_sequence`` again further down. If
    both are module-level, the later definition shadows this one at import
    time. Confirm which one callers should get.

    Args:
        params: Config dict. Reads ``multibin_nbins``, ``multibin_overlap``,
            ``multibin_min``, ``multibin_max`` and ``precision``, and is
            forwarded to ``gpt2.model`` / ``gpt2.past_shape``.
        length: ``maximum_iterations`` of the sampling loop, used as-is
            (no ``params["text_len"]`` offset is subtracted).
        start_token: Scalar used to fill a one-frame seed context. Exactly one
            of ``start_token`` and ``context`` must be given.
        batch_size: Batch size for the seed context and the static KV-cache
            shape.
        context: Seed frames of shape ``(batch, time, 85)``.
        temperature: The bin logits are divided by this before sampling.
        gaussian_noise: Stddev of the noise added to every sampled value.
        top_k: Forwarded to ``top_k_logits``.

    Returns:
        The context followed by the sampled frames, concatenated along
        axis 1, of shape ``(batch, time, 85)``.
    """
    n_bins = params['multibin_nbins']
    overlap = params['multibin_overlap']
    s_min = params['multibin_min']
    s_max = params['multibin_max']
    # Width of one frame. The reshapes and while_loop shape invariants below
    # all depend on it.
    n_features = 85
    _, centers = make_sect_and_center(n_bins=n_bins, overlap=overlap,
                                      span=(s_min, s_max))

    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        # The loop reads context[:, -1] as a (batch, n_features) frame and
        # concatenates frames along axis 1, so the seed must be rank 3.
        # NOTE(review): start_token presumably needs a float dtype to match
        # the sampled frames - confirm.
        context = tf.fill([batch_size, 1, n_features], start_token)

    def step(params, tokens, past=None):
        # One forward pass: returns the model's 'pred' and KV cache 'presents'.
        if params["precision"] == 'bfloat16':
            with tf.contrib.tpu.bfloat16_scope():
                lm_output = gpt2.model(params=params, X=tokens, past=past,
                                       reuse=tf.AUTO_REUSE)
            lm_output["pred"] = tf.cast(lm_output["pred"], tf.float32)
        else:
            lm_output = gpt2.model(params=params, X=tokens, past=past,
                                   reuse=tf.AUTO_REUSE)
        pred = lm_output['pred']
        presents = lm_output['present']
        presents.set_shape(
            gpt2.past_shape(params=params, batch_size=batch_size))
        return {
            'pred': pred,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Prime the KV cache on all but the last context frame; the last frame
        # is fed as ``prev`` on the first loop iteration.
        context_output = step(params, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(params, prev[:, tf.newaxis], past=past)
            pred = next_outputs['pred'][:, -1, :]
            # Last axis of pred: [..., 0] are bin logits, [..., 1] residuals.
            logits = pred[:, :, :, 0] / tf.to_float(temperature)
            resids = pred[:, :, :, 1]
            logits = top_k_logits(logits, k=top_k)  # batch x n_features x n_bins
            # Candidate value for every bin: bin center + predicted residual.
            regressed = resids + tf.constant(centers)[None, None]

            # Draw one bin per (batch, feature) pair.
            sampled_bin = tf.multinomial(
                tf.reshape(logits, (-1, n_bins)),  # (batch*n_features) x n_bins
                num_samples=1,
                output_dtype=tf.int32)  # (batch*n_features) x 1
            idxs = tf.reshape(sampled_bin, (-1,))
            # Pair each flattened row index with its chosen bin for gather_nd.
            idxs = tf.stack([tf.range(tf.size(idxs)), idxs], axis=-1)
            samples = tf.gather_nd(
                tf.reshape(regressed, (-1, n_bins)),  # (batch*n_features) x n_bins
                idxs)  # (batch*n_features)
            samples = tf.reshape(samples, (-1, n_features))
            samples = samples + tf.random.normal(tf.shape(samples),
                                                 stddev=gaussian_noise)

            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([output, samples[:, None]], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    gpt2.past_shape(params=params, batch_size=batch_size)),
                tf.TensorShape([None, n_features]),
                tf.TensorShape([None, None, n_features]),
            ],
            back_prop=False,
        )

        return tokens
def sample_sequence(*, params, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0):
    """Autoregressively sample token ids from the language model.

    NOTE(review): SOURCE defines ``sample_sequence`` more than once. If both
    definitions are module-level, this later one shadows the earlier
    (multi-bin) variant at import time. Confirm this is intended.

    Args:
        params: Config dict. Reads ``precision``, ``n_vocab`` and
            ``text_len``, and is forwarded to ``gpt2.model`` /
            ``gpt2.past_shape``.
        length: ``params["text_len"]`` is subtracted from this to get the
            loop's ``maximum_iterations``.
        start_token: Token id used to fill a ``(batch_size, 1)`` seed. Exactly
            one of ``start_token`` and ``context`` must be given.
        batch_size: Batch size for the seed and the static KV-cache shape.
        context: Seed token ids of shape ``(batch, time)``.
        temperature: The logits are divided by this before sampling.
        top_k: Forwarded to ``top_k_logits``.

    Returns:
        ``(batch, time)`` tensor: the context followed by sampled token ids.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    length = length - params["text_len"]

    def step(params, tokens, past=None):
        # One forward pass: returns vocab logits and the KV cache 'presents'.
        if params["precision"] == 'bfloat16':
            with tf.contrib.tpu.bfloat16_scope():
                lm_output = gpt2.model(params=params, X=tokens, past=past,
                                       reuse=tf.AUTO_REUSE)
            lm_output["logits"] = tf.cast(lm_output["logits"], tf.float32)
        else:
            lm_output = gpt2.model(params=params, X=tokens, past=past,
                                   reuse=tf.AUTO_REUSE)

        # Keep only the first n_vocab logits (presumably the model's output
        # layer is padded beyond the real vocabulary - confirm).
        logits = lm_output['logits'][:, :, :params["n_vocab"]]
        presents = lm_output['present']
        presents.set_shape(
            gpt2.past_shape(params=params, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Prime the KV cache on all but the last context token; the last token
        # is fed as ``prev`` on the first loop iteration.
        context_output = step(params, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(params, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits, num_samples=1,
                                     output_dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    gpt2.past_shape(params=params, batch_size=batch_size)),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
            ],
            back_prop=False,
        )

        return tokens