Example #1
    # Fragment extracted from inside `sample_sequence`: `model` is the GPT-2
    # model module, and `batch_size` comes from the enclosing scope (see
    # Example #2 below for the full function).
    def step(hparams, tokens, past=None):
        # Run the transformer on `tokens`, reusing the cached keys/values
        # carried in `past`.
        lm_output = model.model(hparams=hparams, X=tokens,
                                past=past, reuse=tf.compat.v1.AUTO_REUSE)

        # Keep only real-vocabulary logits, and pin the static shape of the
        # new key/value cache so later shape invariants hold.
        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(
            hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }
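
The `set_shape` call above relies on the cache layout that `model.past_shape`
describes. A minimal sketch of that helper, following the layout used in
openai/gpt-2's model.py (treat the exact signature as an assumption if your
fork differs):

def past_shape(*, hparams, batch_size=None, sequence=None):
    # [batch, layer, 2 (key/value), head, sequence, features per head]
    return [batch_size, hparams.n_layer, 2, hparams.n_head,
            sequence, hparams.n_embd // hparams.n_head]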
Example #2
import tensorflow as tf

import model


def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0,
                    top_p=0.0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            # Temperature-scale the logits for the newest position only.
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            # Nucleus (top-p) filtering takes precedence over top-k when set.
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            return [
                # grow the key/value cache with this step's presents
                tf.concat([past, next_outputs['presents']], axis=-2),
                # the sampled token becomes the next input
                tf.squeeze(samples, axis=[1]),
                # append the sample to the running output sequence
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            # Always true; termination comes from maximum_iterations below.
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
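
Both this example and Example #4 call top_k_logits and top_p_logits without
defining them. Hedged sketches of the two filters, modeled on the versions in
openai/gpt-2's sample.py (the -1e10 mask value and the tie handling are that
implementation's choices, not requirements):

def top_k_logits(logits, k):
    # Keep the k highest logits per row; push the rest to -1e10 so softmax
    # gives them effectively zero probability. k == 0 means no filtering.
    if k == 0:
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(logits < min_values,
                        tf.ones_like(logits, dtype=logits.dtype) * -1e10,
                        logits)

    return tf.cond(tf.equal(k, 0), lambda: logits, _top_k)


def top_p_logits(logits, p):
    # Nucleus sampling: keep the smallest set of top logits whose cumulative
    # probability reaches p; mask everything below that threshold.
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1),
                                 axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # per-row index of the last logit kept
        tf.maximum(
            tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32),
                          axis=-1) - 1, 0),
    ], axis=-1)
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(logits < min_values,
                    tf.ones_like(logits) * -1e10,
                    logits)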
Example #3
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0,
                    top_p=0.0):
    '''Feed the context through the model one token at a time, accumulating
    the logits produced at every position instead of sampling new tokens.'''
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)
    context_head = context[:, 0:1]
    context_tail = context[:, 1:]

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        # context_head holds a single token, so this call runs the model on an
        # empty prefix purely to initialize an empty key/value cache; the loop
        # below then feeds the context one token at a time.
        context_output = step(hparams, context_head[:, :-1])

        def body(past, prev, context_head, context_tail, all_logits):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            return [
                # grow the key/value cache with this step's presents
                tf.concat([past, next_outputs['presents']], axis=-2),
                # the next context token becomes the next input
                context_tail[:, 0],
                # move that token from the tail onto the consumed head
                tf.concat([context_head, context_tail[:, 0:1]], axis=1),
                context_tail[:, 1:],
                # append this position's logits to the accumulator
                tf.concat([all_logits, tf.expand_dims(logits, 1)], axis=1),
            ]

        def cond(*args):
            return True

        past, prev, tokens, context_tail, all_logits = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context_head[:, -1],  # first context token to feed
                context_head,
                context_tail,
                # logits accumulator; 50257 is the GPT-2 vocabulary size
                # (hparams.n_vocab)
                tf.ones([batch_size, 0, 50257])
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams,
                                     batch_size=batch_size)),  # past
                tf.TensorShape([batch_size]),  # prev
                tf.TensorShape([batch_size, None]),  # context head
                tf.TensorShape([batch_size, None]),  # context tail
                tf.TensorShape([batch_size, None, 50257]),  # all logits
            ],
            back_prop=False,
        )

        return past, prev, tokens, all_logits
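
The logits this variant accumulates are usually turned into per-token
probabilities for scoring a known context. A hedged usage sketch (sess,
context_ph, hparams, and token_ids are assumed to exist, set up with the usual
GPT-2 session boilerplate):

past, prev, tokens, all_logits = sample_sequence(
    hparams=hparams, length=len(token_ids) - 1,
    context=context_ph, batch_size=1)

probs = tf.nn.softmax(all_logits, axis=-1)  # [batch, position, vocab]
out = sess.run(probs, feed_dict={context_ph: [token_ids]})
# out[0, i, token_ids[i + 1]] is the model's probability of token i + 1
# given tokens 0..i.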
Example #4
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0,
                    top_p=0.0,
                    return_attention=False):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.compat.v1.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(logits,
                                            num_samples=1,
                                            dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        past_n_present, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        if return_attention:
            # past_n_present has shape
            # [batch, layers, 2, heads, sequence, features]
            past = past_n_present[:, :, :1, :, :, :]
            present = past_n_present[:, :, 1:, :, :, :]
            # compute attention between the two halves of the stacked cache
            attention = tf.matmul(
                past / tf.cast(temperature, tf.float32),
                tf.transpose(present, perm=[0, 1, 2, 3, 5, 4]))
            attention = tf.nn.softmax(attention, axis=-1)
            return tokens, attention
        else:
            return tokens
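
For completeness, a hedged sketch of how these graphs are typically driven,
following the pattern of openai/gpt-2's interactive_conditional_samples.py
(the '124M' checkpoint layout and the encoder module are assumptions about
your local setup):

import json
import os

import tensorflow as tf

import encoder  # BPE encoder module from the GPT-2 repo
import model

model_name, models_dir = '124M', 'models'
enc = encoder.get_encoder(model_name, models_dir)
hparams = model.default_hparams()
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
    hparams.override_from_dict(json.load(f))

with tf.compat.v1.Session() as sess:
    context = tf.compat.v1.placeholder(tf.int32, [1, None])
    output = sample_sequence(hparams=hparams, length=40,
                             context=context, batch_size=1,
                             temperature=0.8, top_k=40)

    saver = tf.compat.v1.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint(
        os.path.join(models_dir, model_name)))

    tokens = enc.encode('Hello, world.')
    out = sess.run(output, feed_dict={context: [tokens]})
    print(enc.decode(out[0, len(tokens):]))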