Code Example #1
File: sample.py  Project: neuroradiology/gpt-2
    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }
Code Example #2
File: sample.py  Project: neuroradiology/gpt-2
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
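Several of these examples call a top_k_logits helper that the page does not reproduce. Below is a minimal sketch of top-k filtering, modeled on the reference OpenAI gpt-2 implementation; the helper bundled with each quoted project may differ slightly.

import tensorflow as tf

def top_k_logits(logits, k):
    """Mask all but the k highest logits per row with a large negative value."""
    if k == 0:
        # k == 0 means no truncation
        return logits

    def _top_k():
        values, _ = tf.nn.top_k(logits, k=k)
        min_values = values[:, -1, tf.newaxis]
        return tf.where(
            logits < min_values,
            tf.ones_like(logits, dtype=logits.dtype) * -1e10,
            logits,
        )

    return tf.cond(
        tf.equal(k, 0),
        lambda: logits,
        lambda: _top_k(),
    )
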
Code Example #3
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0,
                    top_p=1):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]

        # myembed(globals(), locals())

        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):

        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)
            # myembed(globals(), locals())
            logits = tf.transpose(logits)
            # print(logits.shape)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            return [
                next_outputs['presents'] if past is None else tf.concat(
                    [past, next_outputs['presents']], axis=-2), samples,
                tf.concat([output, samples], axis=1)
            ]

        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
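Example #3 additionally filters with top_p_logits (nucleus sampling), which is likewise not shown on this page. The sketch below follows the shape of the reference OpenAI implementation and assumes a statically known batch size; treat it as an approximation of what the project's helper does rather than a copy of it.

import tensorflow as tf

def top_p_logits(logits, p):
    """Nucleus (top-p) filtering: keep the smallest set of tokens whose
    cumulative probability reaches p and mask everything else."""
    # Assumes the batch dimension of `logits` is static.
    batch, _ = logits.shape.as_list()
    sorted_logits = tf.sort(logits, direction='DESCENDING', axis=-1)
    cumulative_probs = tf.cumsum(tf.nn.softmax(sorted_logits, axis=-1), axis=-1)
    indices = tf.stack([
        tf.range(0, batch),
        # index of the last logit that is still inside the nucleus
        tf.maximum(tf.reduce_sum(tf.cast(cumulative_probs <= p, tf.int32), axis=-1) - 1, 0),
    ], axis=-1)
    min_values = tf.gather_nd(sorted_logits, indices)
    return tf.where(
        logits < min_values,
        tf.ones_like(logits) * -1e10,
        logits,
    )
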
Code Example #4
def get_per_subword_surprisal(*, corpus, hparams, encoder):
    start_token = encoder.encoder['<|endoftext|>']
    context = tf.fill([BATCH_SIZE, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=BATCH_SIZE))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('get_per_word_surprisal'):
        # word is a list of integers (encoded chunks)
        def body(corpus, i, past, prev, surprisals):
            # chunk should be a scalar here
            chunk = corpus[i]
            next_outputs = step(hparams, prev, past=past)
            # dimension is batch_size x vocab_size
            logits = next_outputs['logits'][:, -1, :]
            softmax = tf.nn.softmax(logits)
            surp = tf.math.scalar_mul(-1, tf.math.log(softmax[0, chunk]))
            # TODO assuming here that batch size is 1.
            # find a better solution
            return [
                corpus,
                tf.add(i, 1),
                next_outputs['presents'] if past is None else tf.concat(
                    [past, next_outputs['presents']], axis=-2),
                tf.reshape(chunk, [1, 1]),
                tf.concat([surprisals, tf.reshape(surp, [1, 1])], axis=1),
            ]

        corpus = tf.constant(corpus)
        i = tf.constant(0)
        corpus, i, past, prev, surprisals = body(corpus, i, None, context,
                                                 tf.constant([[]]))

        corpus_length = corpus.shape[0].value

        def cond(corpus, i, past, prev, surprisals):
            return tf.less(i, corpus_length)

        _, _, _, _, surprisals = tf.while_loop(
            cond=cond,
            body=body,
            loop_vars=[corpus, i, past, prev, surprisals],
            shape_invariants=[
                corpus.get_shape(),
                i.get_shape(),
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=BATCH_SIZE)),
                tf.TensorShape([BATCH_SIZE, None]),
                tf.TensorShape([BATCH_SIZE, None]),
            ],
            back_prop=False,
        )

        return surprisals
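Example #4 returns per-subword surprisal, i.e. the negative natural log-probability the model assigns to each token. Once the graph has been evaluated, turning the resulting array into corpus-level statistics is a small NumPy exercise; the numbers below are purely illustrative.

import numpy as np

# Hypothetical result of evaluating get_per_subword_surprisal(...) for a
# short corpus; shape is (BATCH_SIZE, num_subwords) with BATCH_SIZE == 1.
surprisals = np.array([[2.31, 0.87, 4.02, 1.55]])

total_nll = surprisals.sum()              # negative log-likelihood of the corpus
per_token_probs = np.exp(-surprisals)     # probability of each subword
perplexity = np.exp(surprisals.mean())    # perplexity over the subwords

print(total_nll, per_token_probs, perplexity)
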
Code Example #5
def tf_loop_bimodel(*,
                    hparams,
                    length,
                    modeldata,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0):

    context = utility.check_start_context(start_token, context, batch_size)
    pastshape = model.past_shape(hparams=hparams, batch_size=batch_size)
    assert modeldata is None or modeldata.shape == (batch_size, length)

    with tf.name_scope('tf_loop_bimodel'):

        mdltensor = None if modeldata is None else tf.constant(modeldata,
                                                               dtype=tf.int32)
        #print("MDL tensor shape is {}".format(mdltensor.shape))
        pastshape = model.past_shape(hparams=hparams, batch_size=batch_size)

        def body(past, output, tknprbs, loggies, curstep):

            # I still don't understand why this line is necessary
            xtok = output[:, -1]

            presents, logits = model.upmodel(hparams=hparams,
                                             X=xtok[:, tf.newaxis],
                                             past=past,
                                             reuse=tf.AUTO_REUSE)
            presents.set_shape(pastshape)

            logits = logits[:, -1, :hparams.n_vocab] / tf.to_float(temperature)
            logits = utility.top_k_logits(logits, k=top_k)

            if modeldata is None:
                items = tf.multinomial(logits,
                                       num_samples=1,
                                       output_dtype=tf.int32,
                                       seed=1000)
            else:
                items = tf.gather(mdltensor, curstep, axis=1)
                print("Items shape is {}".format(items.shape))
                items = tf.reshape(items, shape=(batch_size, 1))

            assert items.shape == (batch_size, 1)

            # This is the full prob distribution
            probs = tf.reshape(logits, (batch_size, hparams.n_vocab))
            probs = tf.nn.softmax(probs, axis=1)
            justidx = tf.reshape(items, shape=(batch_size, ))
            tknprb = tf.gather(probs, justidx, axis=1)
            tknprb = tf.reshape(tf.linalg.diag_part(tknprb), (batch_size, 1))

            return [
                tf.concat([past, presents], axis=-2),
                tf.concat([output, items], axis=1),
                tf.concat([tknprbs, tknprb], axis=1),
                tf.concat([
                    loggies,
                    tf.reshape(logits, shape=(batch_size, hparams.n_vocab, 1))
                ],
                          axis=2),
                tf.add(curstep, tf.constant(1))
            ]

        def cond(*args):
            return True

        history, _ = model.upmodel(hparams=hparams,
                                   X=context[:, :-1],
                                   past=None,
                                   reuse=tf.AUTO_REUSE)
        history.set_shape(pastshape)

        loggies = tf.fill([batch_size, hparams.n_vocab, 1], -1.5655)
        tknprbs = tf.fill([batch_size, 1], -1.513)
        curstep = tf.constant(0, shape=(1, ))

        result = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[history, context, tknprbs, loggies, curstep],
            shape_invariants=[
                tf.TensorShape(pastshape),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, hparams.n_vocab, None]),
                tf.TensorShape([1])
            ],
            back_prop=False)

        return result
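Examples #5 and #9 recover the probability of each emitted token with a tf.gather over the vocabulary axis followed by tf.linalg.diag_part. A toy NumPy illustration of what that pair of operations computes (the numbers are made up):

import numpy as np

probs = np.array([[0.10, 0.70, 0.20],
                  [0.50, 0.25, 0.25]])     # batch_size x vocab
sampled = np.array([1, 0])                 # token index chosen for each row

# tf.gather(probs, sampled, axis=1) builds a (batch, batch) matrix whose
# diagonal holds probs[row, sampled[row]]; diag_part then extracts it.
gathered = probs[:, sampled]               # shape (2, 2)
per_token_prob = np.diag(gathered)         # array([0.7, 0.5])

# An equivalent, more direct gather:
per_token_prob_direct = probs[np.arange(len(sampled)), sampled]
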
Code Example #6
def sample_sequence(config,
                    length,
                    start_token=None,
                    context=None,
                    temperature=1,
                    top_k=0,
                    top_p=0.0):
    if start_token is None:
        assert context is not None, 'Specify either `start_token` or `context`'
    else:
        assert context is None, 'Specify either `start_token` or `context`'
        context = tf.fill([config.batch_size, 1], start_token)

    def step(config, tokens, past=None):
        lm_output = model(config, tokens, past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :config.vocab_size]
        present = lm_output['present']
        present.set_shape(past_shape(config))
        return {'logits': logits, 'present': present}

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        '''TODO: Would be slightly faster if we called step on the entire context
        rather than leaving the last token transformer calculation to the while loop'''
        context_output = step(config, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(config, prev[:, tf.newaxis], past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, tf.float32)
            if top_p > 0.0:
                logits = top_p_logits(logits, top_p)
            else:
                logits = top_k_logits(logits, top_k)

            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['present']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1)
            ]

        def cond(t1, t2, t3):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[context_output['present'], context[:, -1], context],
            shape_invariants=[
                tf.TensorShape(past_shape(config)),
                tf.TensorShape([config.batch_size]),
                tf.TensorShape([config.batch_size, None])
            ],
            back_prop=False)

        return tokens
Code Example #7
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0):
    if context is not None:
        context_input = context
        encoder_length = tf.shape(context)[1]
    else:
        context_input = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):

        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            return [
                next_outputs['presents'] if past is None else tf.concat(
                    [past, next_outputs['presents']], axis=-2), samples,
                tf.concat([output, samples], axis=1)
            ]

        # past: the past hidden states of the transformer
        # prev: the prediction token
        # output: the cumulative tokens
        past, prev, output = body(None, context_input, context_input)
        prev = tf.reshape(prev, [batch_size, 1])
        output = tf.reshape(output, [batch_size, -1])

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
Code Example #8
def sample_sequence(
    *,
    hparams,
    length,
    start_token=None,
    target_tokens=None,
    target_bias=None,
    end_tokens=[],
    eval_tokens=None,
    batch_size=None,
    context=None,
    temperature=1,
    top_k=0,
    top_p=0.0,
    gpu_layers=20,
):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    if target_tokens is None:
        target_tokens = tf.constant([], dtype=tf.dtypes.int64)
    if target_bias is None:
        target_bias = tf.constant([], dtype=tf.dtypes.float32)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE,
                                gpu_layers=gpu_layers)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])
        initial_target_weights = tf.sparse.to_dense(
            tf.SparseTensor(
                tf.reshape(target_tokens, [-1, 1]),
                target_bias,
                [hparams.n_vocab],
            ),
            default_value=1,
        )

        def body(past, prev, output, target_weights, evaluate, evaluation):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)

            # If specified, increase the probability of each target word by its corresponding target_bias
            logits *= tf.reshape(target_weights, [1, -1])

            # Restrict the logits to either the top_p or top_k tokens
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)

            if eval_tokens is None:
                next_word = tf.multinomial(logits,
                                           num_samples=1,
                                           output_dtype=tf.int32)
                next_evaluation = [[]]
            else:
                next_word = [[evaluate[0]]]
                # I feel like this is wrong, and will likely break when a batch_size > 1 is used
                next_evaluation = [tf.gather(logits, evaluate[0], axis=1)]

            new_target_weights = tf.where(
                tf.sparse.to_dense(
                    tf.sparse.reorder(
                        tf.SparseTensor(
                            # This will do the wrong thing for batch sizes > 1
                            # Instead of removing a target weight only for that batch, it will remove it for all batches.
                            tf.reshape(tf.cast(next_word, tf.dtypes.int64),
                                       [-1, 1]),
                            [False],
                            [hparams.n_vocab],
                        ), ),
                    default_value=True,
                ),
                target_weights,
                tf.ones_like(target_weights),
            )

            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(next_word, axis=[1]),
                tf.concat([output, next_word], axis=1),
                new_target_weights,
                evaluate[1:],
                tf.concat([evaluation, next_evaluation], axis=1),
            ]

        def cond(past, prev, output, target_weights, evaluate, evaluation):
            prev_shaped = tf.reshape(prev, shape=[batch_size, 1])
            end_tokens_shaped = tf.constant(end_tokens,
                                            dtype=tf.int32,
                                            shape=[1, len(end_tokens)])
            end_token_not_seen = tf.math.logical_not(
                tf.math.reduce_any(
                    tf.math.equal(prev_shaped, end_tokens_shaped)))
            not_end_of_evaluation = True if eval_tokens is None else tf.greater(
                tf.shape(evaluate)[0], 0)
            return tf.logical_and(not_end_of_evaluation, end_token_not_seen)

        _, _, output, _, _, evaluation = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
                initial_target_weights,
                tf.constant([]) if eval_tokens is None else eval_tokens,
                tf.constant([], shape=[batch_size, 0]),
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([hparams.n_vocab]),
                tf.TensorShape([None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        if eval_tokens is not None:
            return evaluation
        else:
            return output
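Example #8 biases generation toward a set of target tokens by multiplying their logits with per-token weights assembled from a sparse tensor, and resets a weight to 1 once its token has been emitted. A toy NumPy sketch of the weighting step itself, with made-up values:

import numpy as np

n_vocab = 6
target_tokens = np.array([2, 4])
target_bias = np.array([1.5, 2.0])

# Dense weight vector: 1 everywhere except at the targeted tokens,
# mirroring tf.sparse.to_dense(..., default_value=1) above.
target_weights = np.ones(n_vocab)
target_weights[target_tokens] = target_bias

logits = np.array([[0.2, -1.0, 0.5, 0.1, 0.3, -0.5]])
biased_logits = logits * target_weights.reshape(1, -1)
# biased_logits: [[0.2, -1.0, 0.75, 0.1, 0.6, -0.5]]

Note that the bias is multiplicative on raw logits, so it only raises the relative weight of a target token whose logit is already positive; a negative logit is pushed further down.
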
Code Example #9
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0):

    context = utility.check_start_context(start_token, context, batch_size)

    with tf.name_scope('sample_sequence'):

        pastshape = model.past_shape(hparams=hparams, batch_size=batch_size)

        def body(past, output, tknprbs, loggies):

            # I still don't understand why this line is necessary
            xtok = output[:, -1]

            presents, logits = model.upmodel(hparams=hparams,
                                             X=xtok[:, tf.newaxis],
                                             past=past,
                                             reuse=tf.AUTO_REUSE)
            presents.set_shape(pastshape)

            logits = logits[:, -1, :hparams.n_vocab] / tf.to_float(temperature)
            logits = utility.top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32,
                                     seed=1000)

            # This is the full prob distribution
            probs = tf.reshape(logits, (batch_size, hparams.n_vocab))
            probs = tf.nn.softmax(probs, axis=1)
            justidx = tf.reshape(samples, shape=(batch_size, ))
            tknprb = tf.gather(probs, justidx, axis=1)
            tknprb = tf.reshape(tf.linalg.diag_part(tknprb), (batch_size, 1))

            return [
                tf.concat([past, presents], axis=-2),
                tf.concat([output, samples], axis=1),
                tf.concat([tknprbs, tknprb], axis=1),
                tf.concat([
                    loggies,
                    tf.reshape(logits, shape=(batch_size, hparams.n_vocab, 1))
                ],
                          axis=2)
            ]

        def cond(*args):
            return True

        history, _ = model.upmodel(hparams=hparams,
                                   X=context[:, :-1],
                                   past=None,
                                   reuse=tf.AUTO_REUSE)
        history.set_shape(pastshape)

        loggies = tf.fill([batch_size, hparams.n_vocab, 1], -1.5655)
        tknprbs = tf.fill([batch_size, 1], -1.513)

        result = tf.while_loop(cond=cond,
                               body=body,
                               maximum_iterations=length,
                               loop_vars=[history, context, tknprbs, loggies],
                               shape_invariants=[
                                   tf.TensorShape(pastshape),
                                   tf.TensorShape([batch_size, None]),
                                   tf.TensorShape([batch_size, None]),
                                   tf.TensorShape(
                                       [batch_size, hparams.n_vocab, None])
                               ],
                               back_prop=False)

        return result
Code Example #10
File: sample.py  Project: jgoodrich77/MC-tailor
def sample_sequence_ISMC_threshold(*,
                                   Dis,
                                   layer,
                                   hparams,
                                   length,
                                   start_token=None,
                                   batch_size=None,
                                   context=None,
                                   temperature=1,
                                   top_k=0,
                                   top_p=0.0):
    #if start_token is None:
    #    assert context is not None, 'Specify exactly one of start_token and context!'
    #else:
    #    assert context is None, 'Specify exactly one of start_token and context!'

    batch_size = 1000
    context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])
        shape1 = model.past_shape(hparams=hparams, batch_size=batch_size)
        shape1[-2] = -1

        def body(past, prev, output, stop_before):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)

            past = tf.concat([past, next_outputs['presents']], axis=-2)
            prev = tf.squeeze(samples, axis=[1])
            output = tf.concat([output, samples], axis=1)

            log_prob = Dis.log_prob_step(tf.concat([
                output,
                tf.ones(
                    shape=[tf.shape(output)[0], length - tf.shape(output)[1]],
                    dtype=tf.int32) * start_token
            ],
                                                   axis=1),
                                         layer=layer)
            log_prob_cut = tf.reduce_min(log_prob - Dis.dis[:layer], axis=1)
            ids = tf.range(tf.shape(log_prob_cut)[0])
            #already_end_ids=ids[:stop_before]
            #end_ids=tf.cast(tf.where(tf.equal(prev[stop_before:], start_token))[:,0], dtype=tf.int32)+stop_before
            #non_end_ids=tf.cast(tf.where(tf.not_equal(prev[stop_before:], start_token))[:,0], dtype=tf.int32)+stop_before
            non_end_ids = ids

            log_prob_non_end = tf.gather(log_prob_cut, non_end_ids)
            #r=tf.random_uniform(shape=tf.shape(prob_non_end)[0:1], dtype=tf.float32)
            #selected_ids_pre=tf.where(tf.less(r, prob_non_end))[:,0]
            selected_ids_pre = tf.where(tf.less(0.0, log_prob_non_end))[:, 0]

            def true_fn():
                return tf.gather(non_end_ids, selected_ids_pre)

            def false_fn():
                return non_end_ids

            sample_ids = tf.cond(tf.shape(non_end_ids)[0] > 0,
                                 true_fn=true_fn,
                                 false_fn=false_fn)
            #sample_ids=non_end_ids

            #combine_ids=tf.concat([already_end_ids, end_ids, sample_ids], axis=0)
            combine_ids = sample_ids
            end_ids = sample_ids
            return [
                tf.gather(past, combine_ids),
                tf.gather(prev, combine_ids),
                tf.gather(output, combine_ids),
                tf.shape(end_ids)[0] + stop_before,
            ]

        def cond(*args):
            return True

        _, _, tokens, _ = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
                0,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=None)),
                tf.TensorShape([None]),
                tf.TensorShape([None, None]),
                tf.TensorShape([])
            ],
            back_prop=False,
        )

        return tokens
Code Example #11
File: sample.py  Project: jgoodrich77/MC-tailor
def sample_sequence_SMC(*,
                        Dis,
                        layer,
                        hparams,
                        length,
                        start_token=None,
                        batch_size=None,
                        context=None,
                        temperature=1,
                        top_k=0,
                        top_p=0.0):
    #if start_token is None:
    #    assert context is not None, 'Specify exactly one of start_token and context!'
    #else:
    #    assert context is None, 'Specify exactly one of start_token and context!'
    context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])
        shape1 = model.past_shape(hparams=hparams, batch_size=batch_size)
        shape1[-2] = -1

        def body(past, prev, output, stop_before):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)

            past = tf.concat([past, next_outputs['presents']], axis=-2)
            prev = tf.squeeze(samples, axis=[1])
            output = tf.concat([output, samples], axis=1)
            output_full = tf.concat([
                output,
                tf.zeros([
                    hparams.batch_size,
                    hparams.seq_len + 1 - tf.shape(output)[1]
                ],
                         dtype=tf.int32) + start_token
            ],
                                    axis=1)[:, :hparams.seq_len]

            def transform(x, lift=2.0):
                x = x - 0.5
                x = x + tf.abs(x)
                return lift * x**2

            prob = transform(Dis.prob(output_full, layer=layer))
            ids = tf.range(hparams.batch_size)
            already_end_ids = ids[:stop_before]
            end_ids = tf.cast(tf.where(
                tf.equal(prev[stop_before:], start_token))[:, 0],
                              dtype=tf.int32) + stop_before
            non_end_ids = tf.cast(tf.where(
                tf.not_equal(prev[stop_before:], start_token))[:, 0],
                                  dtype=tf.int32) + stop_before

            prob_non_end = tf.gather(prob, non_end_ids)

            def true_fn():
                return tf.gather(
                    non_end_ids,
                    tf.random.multinomial(
                        logits=tf.log(prob_non_end + 1e-7)[tf.newaxis, :],
                        num_samples=tf.shape(non_end_ids)[0])[0])

            def false_fn():
                return non_end_ids

            sample_ids = tf.cond(tf.shape(non_end_ids)[0] > 0,
                                 true_fn=true_fn,
                                 false_fn=false_fn)

            combine_ids = tf.concat([already_end_ids, end_ids, sample_ids],
                                    axis=0)
            return [
                tf.reshape(tf.gather(past, combine_ids), shape1),
                tf.reshape(tf.gather(prev, combine_ids), [batch_size]),
                tf.reshape(tf.gather(output, combine_ids), [batch_size, -1]),
                tf.shape(end_ids)[0] + stop_before,
            ]

        def cond(*args):
            return True

        _, _, tokens, _ = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
                0,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([])
            ],
            back_prop=False,
        )

        return tokens
Code Example #12
File: sample.py  Project: jgoodrich77/MC-tailor
def sample_sequence_for_GAN(*,
                            hparams,
                            length,
                            start_token=None,
                            batch_size=None,
                            context=None,
                            temperature=1,
                            top_k=0,
                            top_p=0.0,
                            ST=False,
                            Gumbel_temperature=3.0):
    context = tf.one_hot(tf.fill([batch_size, 1], start_token),
                         hparams.n_vocab,
                         dtype=tf.float32)

    def Gumbel_variable(shape, dtype):
        r = tf.random_uniform(shape=shape, dtype=dtype)
        return -tf.log(-tf.log(r))

    def step(hparams, tokens, past=None):
        print('tokens:{}'.format(tokens.shape))
        lm_output = model.model_for_GAN(hparams=hparams,
                                        X=tokens,
                                        past=past,
                                        reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1, :])
        b = tf.get_variable(name='s', initializer=0.1, dtype=tf.float32)

        def body(past, prev, output, output2):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits += Gumbel_variable(tf.shape(logits), logits.dtype)
            sample_gumbel = tf.nn.softmax(logits / Gumbel_temperature)
            if ST:
                sample = tf.one_hot(tf.argmax(sample_gumbel, axis=-1),
                                    tf.shape(sample_gumbel)[-1],
                                    dtype=tf.float32)
            else:
                sample = sample_gumbel
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                sample,
                tf.concat([output, tf.expand_dims(sample, axis=1)], axis=1),
                tf.concat(
                    [output2, tf.expand_dims(sample_gumbel, axis=1)], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens_out, tokens_gumbel = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1, :],
                context,
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None, None]),
                tf.TensorShape([batch_size, None, None]),
            ],
            back_prop=True,
            swap_memory=True,
        )

        return tokens_out, tokens_gumbel
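Example #12 makes sampling differentiable for GAN training via the Gumbel-softmax trick: Gumbel_variable adds -log(-log(U)) noise to the logits, the softmax at Gumbel_temperature relaxes the discrete sample, and ST=True optionally snaps it back to a one-hot (straight-through) sample. A small NumPy sketch of the underlying trick, with illustrative values:

import numpy as np

rng = np.random.default_rng(0)
logits = np.array([2.0, 1.0, 0.1])

def gumbel_noise(shape):
    # Same -log(-log(U)) noise as Gumbel_variable above
    u = rng.uniform(size=shape)
    return -np.log(-np.log(u))

# Gumbel-max trick: argmax of (logits + noise) is an exact sample
# from softmax(logits).
hard_sample = np.argmax(logits + gumbel_noise(logits.shape))

# Gumbel-softmax relaxation: a differentiable, temperature-controlled
# approximation of the one-hot sample (cf. sample_gumbel in the body above).
tau = 3.0
y = np.exp((logits + gumbel_noise(logits.shape)) / tau)
soft_sample = y / y.sum()

print(hard_sample, soft_sample)
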
Code Example #13
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, top_p=0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        def body(past, prev, output):
            next_outputs = step(hparams, prev, past=past)
            # print(tf.unstack(next_outputs['logits'][:, -1, :] ))

            if temperature == 0:
                logits = tf.map_fn(fn=lambda logit_tensor: logit_tensor / tf.random.uniform((1,), minval=.69, maxval=.91, dtype=tf.dtypes.float32),
                    elems=next_outputs['logits'][:, -1, :],
                    back_prop=False,
                    dtype=tf.float32)
            else: 
                logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)

            # logits = top_p_logits(logits, p=top_p)
            if top_p:
                logits = top_p_logits(logits, p=top_p)
            else: 
                logits = top_k_logits(logits, k=top_k)
            
            samples = tf.compat.v1.multinomial(logits, num_samples=1, output_dtype=tf.int32)
            return [
                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([output, samples], axis=1)
            ]

        past, prev, output = body(None, context, context)

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                past,
                prev,
                output
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
Code Example #14
def run_sequence(*, hparams, start_token=None, batch_size=None, context=None, length=None):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present'] 
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        def body(past, prev, probs, count):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :]

            samples = context[:,tf.cast(count,dtype=tf.int32)]
            samples = tf.reshape(samples, [batch_size,1])
            print("Samples", samples)
            prob = tf.reduce_max(tf.nn.softmax(logits) * tf.one_hot(samples[:,0], logits.shape[1]),axis=1)
            
            return [
                next_outputs['presents'] if past is None else tf.concat([past, next_outputs['presents']], axis=-2),
                samples,
                tf.concat([probs, prob[:,tf.newaxis]], axis=1),
                count+1
            ]

        past, prev, probs, count = body(None,
                                        tf.zeros((batch_size,1),dtype=tf.int32) + context[0,0],
                                        tf.zeros((batch_size,0),dtype=tf.float32),
                                        tf.zeros([], dtype=tf.int32))

        def cond(*args):
            return True
        
        _, _, probs, _ = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length - 1,
            loop_vars=[
                past,
                prev,
                probs,
                count,
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([]),
            ],
            back_prop=False,
        )

        return probs
Code Example #15
def sample_sequence_glove_all_top_five_gpu(*,
                                           hparams,
                                           length,
                                           start_token=None,
                                           batch_size=None,
                                           context=None,
                                           temperature=1,
                                           top_k=0,
                                           top_p=1,
                                           glove=None,
                                           weight=None):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    converter_table = np.load(
        str(os.path.dirname(os.path.abspath(__file__))) +
        '/look_ups_gpt-2/converter_table.npy')

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):

        def body(past, prev, output, probabilities):
            next_outputs = step(hparams, prev, past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            logits = top_k_logits(logits, k=top_k)
            logits = top_p_logits(logits, p=top_p)

            glove_one, glove_two, glove_three, glove_four, glove_five = glove

            similar_five = tf.convert_to_tensor(cosine_similarity(
                np.reshape(glove_five, (1, -1)), converter_table),
                                                dtype=tf.float32)
            similar_four = tf.convert_to_tensor(cosine_similarity(
                np.reshape(glove_four, (1, -1)), converter_table),
                                                dtype=tf.float32)
            similar_three = tf.convert_to_tensor(cosine_similarity(
                np.reshape(glove_three, (1, -1)), converter_table),
                                                 dtype=tf.float32)
            similar_two = tf.convert_to_tensor(cosine_similarity(
                np.reshape(glove_two, (1, -1)), converter_table),
                                               dtype=tf.float32)
            similar_one = tf.convert_to_tensor(cosine_similarity(
                np.reshape(glove_one, (1, -1)), converter_table),
                                               dtype=tf.float32)

            value = weight  #7.0 #6.0  #8.0
            fact = tf.constant(value, tf.float32)

            prob = tf.nn.softmax(logits)

            logits = tf.add(logits, similar_one * fact)
            logits = tf.add(logits, similar_two * fact)
            logits = tf.add(logits, similar_three * fact)
            logits = tf.add(logits, similar_four * fact)
            logits = tf.add(logits, similar_five * fact)

            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            sample = samples[0, 0]

            probability_old = tf.cast(tf.gather_nd(prob, [[0, sample]]),
                                      tf.float64)
            probability = tf.multiply(probability_old, probabilities)
            return [
                next_outputs['presents'] if past is None else tf.concat(
                    [past, next_outputs['presents']], axis=-2), samples,
                tf.concat([output, samples], axis=1), probability
            ]

        probabilities = tf.ones([1], tf.float64)
        past, prev, output, probabilities = body(None, context, context,
                                                 probabilities)

        def cond(*args):
            return True

        _, _, tokens, probabilities = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output, probabilities],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([
                    batch_size,
                ]),
            ],
            back_prop=False,
        )

        return tokens, probabilities
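Example #15 steers generation toward five GloVe keywords by adding weighted cosine similarities, computed between each keyword vector and a per-token embedding table (converter_table), to the logits at every step. A small sketch of that similarity term, assuming cosine_similarity is sklearn.metrics.pairwise.cosine_similarity and using random stand-in vectors (the real table has one row per GPT-2 token):

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

emb_dim, n_vocab = 300, 1000             # 50257 for the real GPT-2 vocabulary
converter_table = np.random.randn(n_vocab, emb_dim)   # token id -> GloVe vector
keyword_vec = np.random.randn(1, emb_dim)              # GloVe vector of one keyword

# (1, n_vocab) similarity row; in body() this is scaled by `weight`
# and added to the logits before sampling.
similar = cosine_similarity(keyword_vec, converter_table)
weight = 7.0
# logits = logits + weight * similar
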
Code Example #16
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    sampler='k',
                    temperature=1,
                    top_k=0,
                    alpha=0.05,
                    nuc_prob=0.25,
                    flat_prob=0.02,
                    k_window_size=None,
                    window_weights=None):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'  # this is the case where the whole context is already given to the model.
        # it is the primer that I write for it!
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1],
                          start_token)  # this is not used in my case!

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):

        # this will store all of the logits for a specific sampling run.
        # ultimately I want this to be a:  samples * prompts (batch size) * words * length  sized matrix.

        #all_logits = tf.Variable(tf.zeros([batch_size, 50257, length]), dtype=tf.float32)
        #all_logits = tf.Variable(name='all_logits', shape=[batch_size, 50257, length], initializer= ,dtype=tf.float32, trainable=False)

        def body(past, prev, output, all_logits):
            next_outputs = step(hparams, prev, past=past)

            logits = next_outputs['logits'][:, -1, :] / tf.to_float(temperature)
            if sampler == 'k':
                print('using top k')
                logits = top_k_logits(logits, k=top_k)
            elif sampler == 'n':
                print('using nucleus')
                logits = nucleus(logits, p=nuc_prob)
            elif sampler == 'tfs':
                print('using tail free sampling')
                logits = tail_free(logits,
                                   alpha)  #, k_window_size, window_weights)
            elif sampler == 'flat':
                print('using flat percentage sampling')
                logits = flat_perc(logits, flat_prob)
            else:
                print('defaulting to top k sampling')
                logits = top_k_logits(logits, k=top_k)
            #print('the logits shape post processing is: ', logits.shape)
            samples = tf.multinomial(logits,
                                     num_samples=1,
                                     output_dtype=tf.int32)
            #print('the samples shape is: ', samples.shape)
            return [
                next_outputs['presents'] if past is None else tf.concat(
                    [past, next_outputs['presents']], axis=-2),
                tf.reshape(samples, [batch_size, 1]),
                tf.concat([output, samples], axis=1),
                tf.expand_dims(next_outputs['logits'][:, -1, :], axis=2)
                if all_logits is None else tf.concat([
                    all_logits,
                    tf.expand_dims(next_outputs['logits'][:, -1, :], axis=2)
                ],
                                                     axis=2)
                #tf.concat([all_logits, tf.expand_dims(next_outputs['logits'][:, -1, :], axis=2)], axis=2)
                #all_logits[:,:,tf.shape(output)[1]+1].assign(tf.expand_dims(next_outputs['logits'][:, -1, :], axis=2) )
            ]

        past, prev, output, all_logits = body(
            None, context, context, None
        )  # for the first run the output and previous are both the context.

        def cond(*args):
            return True

        _, _, tokens, all_logits_out = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length - 1,
            loop_vars=[past, prev, output, all_logits],
            #changed the 2nd shape invariant so that it can handle the ? shape (which is actually batch size) for the TFS sampling.
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([batch_size, 50257, None])  #batch size
            ],
            back_prop=False,
        )

        return (tokens, all_logits_out)
Code Example #17
def sample_sequence(*,
                    hparams,
                    length,
                    start_token=None,
                    batch_size=None,
                    context=None,
                    temperature=1,
                    top_k=0,
                    top_p=0.0):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None):
        lm_output = model.model(hparams=hparams,
                                X=tokens,
                                past=past,
                                reuse=tf.compat.v1.AUTO_REUSE)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(
            model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1])

        def body(past, prev, output):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :] / tf.cast(
                temperature, dtype=tf.float32)
            if top_p > 0.0:
                logits = top_p_logits(logits, p=top_p)
            else:
                logits = top_k_logits(logits, k=top_k)
            samples = tf.random.categorical(logits,
                                            num_samples=1,
                                            dtype=tf.int32)
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
            ]

        def cond(*args):
            return True

        _, _, tokens = tf.while_loop(
            cond=cond,
            body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
            ],
            shape_invariants=[
                tf.TensorShape(
                    model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
            ],
            back_prop=False,
        )

        return tokens
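All of the sample_sequence variants on this page only build a TensorFlow graph; the pretrained weights still have to be restored and the graph run in a session. Below is a hedged sketch of a typical TF1-style driver, loosely following the scripts in the OpenAI gpt-2 repository. The model name, directory layout, and the exact encoder.get_encoder signature are assumptions that vary between forks.

import json

import tensorflow as tf

import encoder   # BPE encoder module from the gpt-2 repository
import model
import sample

model_name = '124M'
enc = encoder.get_encoder(model_name, 'models')
hparams = model.default_hparams()
with open('models/%s/hparams.json' % model_name) as f:
    hparams.override_from_dict(json.load(f))

with tf.Session() as sess:
    context = tf.placeholder(tf.int32, [1, None])
    output = sample.sample_sequence(
        hparams=hparams, length=60,
        context=context, batch_size=1,
        temperature=0.9, top_k=40)

    # Restore the pretrained weights before sampling
    saver = tf.train.Saver()
    ckpt = tf.train.latest_checkpoint('models/%s' % model_name)
    saver.restore(sess, ckpt)

    tokens = enc.encode('The meaning of life is')
    out = sess.run(output, feed_dict={context: [tokens]})
    print(enc.decode(out[0]))
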
Code Example #18
def sample_sequence(*, hparams, length, start_token=None, batch_size=None, context=None, temperature=1, top_k=0, softmax_length=10):
    if start_token is None:
        assert context is not None, 'Specify exactly one of start_token and context!'
    else:
        assert context is None, 'Specify exactly one of start_token and context!'
        context = tf.fill([batch_size, 1], start_token)

    def step(hparams, tokens, past=None, mix_prompt=False):
        lm_output = model.model(hparams=hparams, X=tokens, past=past, reuse=tf.AUTO_REUSE, mix_prompt=mix_prompt)

        logits = lm_output['logits'][:, :, :hparams.n_vocab]
        presents = lm_output['present']
        presents.set_shape(model.past_shape(hparams=hparams, batch_size=batch_size))
        return {
            'logits': logits,
            'presents': presents,
        }

    with tf.name_scope('sample_sequence'):
        # Don't feed the last context token -- leave that to the loop below
        # TODO: Would be slightly faster if we called step on the entire context,
        # rather than leaving the last token transformer calculation to the while loop.
        context_output = step(hparams, context[:, :-1],mix_prompt=True)

        def body(past, prev, output, index, softmax):
            next_outputs = step(hparams, prev[:, tf.newaxis], past=past)
            logits = next_outputs['logits'][:, -1, :]  / tf.to_float(temperature)
            #Calculate Softmax of top N, before we cancel them to -1e10
            softmax_length_logits, index_loop = tf.nn.top_k(logits, k=softmax_length)
            softmax_loop = tf.nn.softmax(softmax_length_logits)

            logits = top_k_logits(logits, k=top_k)
            samples = tf.multinomial(logits, num_samples=1, output_dtype=tf.int32)

            #
            # #Index of Logits
            # indexOfLogits = tf.where(logits>-1e9)
            # indexAsList = indexOfLogits[:,1]
            # indexAsList = tf.expand_dims(indexAsList, axis=0)
            # #Softmax of top k logits
            # kLogits = tf.gather_nd(logits, indexOfLogits)
            # softmaxOfLogits = tf.nn.softmax(kLogits)
            # softmaxOfLogits = tf.expand_dims(softmaxOfLogits, axis=0)
            #We want the word distribution of the last one.
            #I think multinomial is converting logits to an index. Maybe convert num_samples = 10? Then Softmax?
            return [
                tf.concat([past, next_outputs['presents']], axis=-2),
                tf.squeeze(samples, axis=[1]),
                tf.concat([output, samples], axis=1),
                tf.concat([index, index_loop], axis=0),
                tf.concat([softmax, softmax_loop], axis=0)
            ]

        def cond(*args):
            return True

        #Filler Index and Softmax Tensors for first loop
        #Therefore, skip the first row in the output
        fillerIndex = tf.zeros(shape=[1,softmax_length], dtype=tf.int32)
        fillerSoftmax = tf.zeros(shape=[1,softmax_length], dtype=tf.float32)


        _, _, tokens, index, softmax = tf.while_loop(
            cond=cond, body=body,
            maximum_iterations=length,
            loop_vars=[
                context_output['presents'],
                context[:, -1],
                context,
                fillerIndex,
                fillerSoftmax,
            ],
            shape_invariants=[
                tf.TensorShape(model.past_shape(hparams=hparams, batch_size=batch_size)),
                tf.TensorShape([batch_size]),
                tf.TensorShape([batch_size, None]),
                tf.TensorShape([None,None]),
                tf.TensorShape([None,None]),
            ],
            back_prop=False,
        )

        return tokens, index, softmax