def step(hparams, tokens, past=None):
    """Run a single transformer forward pass.

    Args:
        hparams: model hyperparameters, forwarded to ``model.model``.
        tokens: integer tensor of token ids fed as ``X``.
        past: optional cached key/value tensor from earlier calls.

    Returns:
        dict with 'logits' (raw language-model logits) and 'presents'
        (the new key/value cache, static shape pinned via
        ``model.past_shape`` so loops can use it as a shape invariant).
    """
    outputs = model.model(hparams=hparams, X=tokens, past=past,
                          reuse=tf.AUTO_REUSE)
    cache = outputs['present']
    # Pin the static shape (batch unknown) for downstream graph code.
    cache.set_shape(model.past_shape(hparams=hparams, batch_size=None))
    return {
        'logits': outputs['logits'],
        'presents': cache,
    }
def decode_all(self, hparams, tokens, past):
    """Decode `tokens` against an existing cache and return the grown cache.

    Args:
        hparams: model hyperparameters, forwarded to ``model.model``.
        tokens: integer tensor of token ids fed as ``X``.
        past: previous key/value cache tensor (must not be None, since it
            is concatenated with the new presents below).

    Returns:
        dict with 'logits' and 'presents', where 'presents' is `past`
        concatenated with the newly produced key/values along the
        sequence axis (axis=-2).
    """
    with tf.variable_scope(self.scope):
        outputs = model.model(hparams=hparams, X=tokens, past=past,
                              reuse=tf.AUTO_REUSE)
        new_cache = outputs['present']
        new_cache.set_shape(
            model.past_shape(hparams=hparams, batch_size=None))
        return {
            'logits': outputs['logits'],
            # axis=-2 is the sequence dimension of the past tensor.
            'presents': tf.concat([past, new_cache], axis=-2),
        }
def encode(self, input, input_len, past=None):
    """Encode `input` tokens into transformer key/value states, zeroing
    out the positions that lie beyond each example's `input_len` so
    padding does not contribute to the returned cache.

    Args:
        input: integer token-id tensor; assumes shape (batch, seq) —
            the mask maxlen is taken from ``tf.shape(input)[1]``.
        input_len: per-example valid lengths, shape (batch,).
        past: optional previous key/value cache.

    Returns:
        Tensor shaped like ``model.past_shape`` with padded positions
        zeroed out.
    """
    with tf.variable_scope(self.scope, reuse=tf.AUTO_REUSE):
        lm_output = model.model(hparams=self.hparam, X=input, past=past,
                                reuse=tf.AUTO_REUSE)
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=self.hparam, batch_size=None))
        # (batch, seq, 1) mask: 1.0 for real tokens, 0.0 for padding.
        target_mask = tf.sequence_mask(
            input_len, maxlen=tf.shape(input)[1], dtype=tf.float32)
        target_mask = tf.expand_dims(target_mask, 2)
        # Bring the sequence axis next to batch, flatten the remaining
        # axes so the (batch, seq, 1) mask broadcasts over them, then
        # restore the original layout. The permutation (0, 4, 2, 3, 1, 5)
        # is its own inverse, so the same perm undoes the first transpose.
        encode_out = tf.transpose(presents, perm=(0, 4, 2, 3, 1, 5))
        ori_enc_shape = tf.shape(encode_out)
        encode_out = tf.reshape(
            encode_out,
            shape=(tf.shape(presents)[0], tf.shape(presents)[4], -1))
        encode_out = tf.multiply(encode_out, target_mask)
        encode_out = tf.reshape(encode_out, shape=ori_enc_shape)
        encode_out = tf.transpose(encode_out, perm=(0, 4, 2, 3, 1, 5))
        return encode_out
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0, top_p=1):
    """Autoregressively sample up to `length` tokens from the model.

    Exactly one of `start_token` / `context` must be given; with
    `start_token`, a (batch_size, 1) context is synthesized. At every
    step the last-position logits are temperature-scaled, filtered by
    top-k and nucleus (top-p) truncation, then sampled multinomially.

    Returns:
        int32 tensor of shape (batch_size, None): the context followed
        by the sampled tokens.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        out = model.model(hparams=hparams, X=tokens, past=past,
                          reuse=tf.AUTO_REUSE)
        # Trim any padded vocab columns back to the true vocab size.
        trimmed = out['logits'][:, :, :hparams.n_vocab]
        cache = out['present']
        cache.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {'logits': trimmed, 'presents': cache}

    with tf.name_scope('sample_sequence'):

        def body(past, prev, output):
            nxt = step(hparams, prev, past=past)
            # Only the final position's logits matter for the next token.
            scaled = nxt['logits'][:, -1, :] / tf.to_float(temperature)
            scaled = top_k_logits(scaled, k=top_k)
            scaled = top_p_logits(scaled, p=top_p)
            samples = tf.multinomial(scaled, num_samples=1,
                                     output_dtype=tf.int32)
            # On the priming call past is None, so the cache starts fresh;
            # inside the while_loop past is always a tensor.
            grown_past = (nxt['presents'] if past is None
                          else tf.concat([past, nxt['presents']], axis=-2))
            return [
                grown_past,
                samples,
                tf.concat([output, samples], axis=1),
            ]

        # Prime the loop by running the body once over the whole context.
        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams,
                                     batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )
        return tokens
def sample_sequence(*, hparams, length, start_token=None, batch_size=None,
                    context=None, temperature=1, top_k=0):
    """Sample `length` tokens, priming the key/value cache on all but the
    last context token and carrying per-step logits through the loop.

    Exactly one of `start_token` / `context` must be given; with
    `start_token`, a (batch_size, 1) context is synthesized. Logits are
    temperature-scaled and top-k filtered before multinomial sampling.

    Returns:
        int32 tensor of shape (batch_size, None): the context followed
        by the sampled tokens.
    """
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        out = model.model(hparams=hparams, X=tokens, past=past,
                          reuse=tf.AUTO_REUSE)
        # Trim any padded vocab columns back to the true vocab size.
        trimmed = out['logits'][:, :, :hparams.n_vocab]
        cache = out['present']
        cache.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {'logits': trimmed, 'presents': cache}

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below.
        # TODO: Would be slightly faster if we called step on the entire
        # context, rather than leaving the last token transformer
        # calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output, all_logits):
            nxt = step(hparams, prev[:, tf.newaxis], past=past)
            scaled = nxt['logits'][:, -1, :] / tf.to_float(temperature)
            scaled = top_k_logits(scaled, k=top_k)
            samples = tf.multinomial(scaled, num_samples=1,
                                     output_dtype=tf.int32)
            return [
                tf.concat([past, nxt['presents']], axis=-2),
                # prev is carried as a rank-1 (batch,) tensor.
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
                tf.concat([all_logits, nxt['logits']], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens, all_logits = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
                context_output['logits'],
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams,
                                     batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None, None]),
            ],
            back_prop=False,
        )
        return tokens