示例#1
0
def generate(model, inputs):
    """
    Sample continuations of ``inputs['prompt']`` and return the first one.

    Besides its arguments, this function reads module-level names that are
    defined elsewhere in the file: ``sess``, ``encoder``, ``num_chunks``,
    ``batch_size_per_chunk``, ``top_p`` (indexed once per chunk) and ``args``
    (for ``args.target``).

    :param model: mapping holding the graph tensors/placeholders under the
        keys 'tokens', 'probs', 'initial_context', 'eos_token' and 'p_for_topp'.
    :param inputs: mapping holding the raw prompt text under 'prompt'.
    :return: the extracted text of the first generated sample.
    :raises IndexError: if no sample was produced (e.g. ``num_chunks == 0``).
    """
    text = inputs['prompt']

    # NOTE(review): generate_article_attribute treats the result of
    # _tokenize_article_pieces as a dict (``.pop(key, [])``), but here the
    # helper is given a bare string and its result is sliced like a list,
    # which would fail on a dict. Confirm the helper's contract; a plain list
    # of token ids is presumably what this code expects.
    encoded = _tokenize_article_pieces(encoder, text)

    # Drop the last token of the prompt encoding (presumably a closing/end
    # marker -- TODO confirm) so the model continues after the prompt.
    context_formatted = list(encoded[:-1])

    # NOTE(review): unlike generate_article_attribute, no ignore-ids mask is
    # fed to the graph here. If the sampling graph requires one, sess.run will
    # fail on an unfed placeholder -- confirm.

    # Every chunk feeds the same context, replicated across the chunk's batch.
    batch_context = [context_formatted] * batch_size_per_chunk

    gens = []
    for chunk_i in range(num_chunks):
        tokens_out, probs_out = sess.run(
            [model['tokens'], model['probs']],
            feed_dict={
                model['initial_context']: batch_context,
                # NOTE(review): hard-coded end-of-sequence id; presumably chosen
                # to lie outside the vocabulary so sampling never stops early.
                # Confirm against the encoder's vocabulary size.
                model['eos_token']: 60000,
                model['p_for_topp']: top_p[chunk_i],
            })
        for t_i, _p_i in zip(tokens_out, probs_out):
            extraction = extract_generated_target(output_tokens=t_i,
                                                  encoder=encoder,
                                                  target=args.target)
            gens.append(extraction['extraction'])

    return gens[0]
def generate_article_attribute(sess, encoder, tokens, probs, article, target='article',
                               nucleus_p=0.95):
    """
    Generate a replacement for one article attribute using the Grover model.

    The remaining attributes of ``article`` (domain, date, authors, title and
    article body, whichever the tokenizer produced) are used as conditioning
    context, and the model samples a new value for ``target``.

    This function is based on the Grover examples distributed with the Grover code.
    It also reads the module-level placeholders ``initial_context``,
    ``eos_token``, ``ignore_ids`` and ``p_for_topp``, which are defined
    elsewhere in the file.

    :param sess: TensorFlow session used to run the sampling graph.
    :param encoder: Grover encoder. It must expose ``begin_<target>`` and
        ``end_<target>`` token ids as instance attributes, plus
        ``special_tokens_onehot``.
    :param tokens: graph tensor producing the sampled token ids.
    :param probs: graph tensor producing the matching probabilities.
    :param article: raw article fields, passed to ``_tokenize_article_pieces``.
        Side effect: ``article['top_ps']`` is set to ``[nucleus_p]``.
    :param target: name of the attribute to generate (e.g. 'article', 'title').
    :param nucleus_p: top-p cut-off used when sampling; defaults to 0.95.
    :return: the extracted text of the generated attribute.
    :raises KeyError: if ``encoder`` has no ``begin_<target>`` or
        ``end_<target>`` instance attribute.
    """
    # Tokenize the raw article fields.
    article_pieces = _tokenize_article_pieces(encoder, article)

    # Grab the article elements the model cares about (domain, date, title,
    # etc.) in this fixed order, skipping the attribute being regenerated.
    context_formatted = []
    for key in ['domain', 'date', 'authors', 'title', 'article']:
        if key != target:
            context_formatted.extend(article_pieces.pop(key, []))

    # Start the slot for the attribute we want generated.
    context_formatted.append(encoder.__dict__['begin_{}'.format(target)])

    # Tell the model which special tokens aren't part of the text. The target's
    # own end token is un-masked because it is also fed as the eos token below.
    end_id = encoder.__dict__['end_{}'.format(target)]
    ignore_ids_np = np.array(encoder.special_tokens_onehot)
    ignore_ids_np[end_id] = 0

    # A single, unbatched context with one top-p value. The cut-off is also
    # recorded on the caller's article dict (kept for backward compatibility).
    article['top_ps'] = [nucleus_p]

    # Run the input through the TensorFlow model and grab the generated output.
    tokens_out, probs_out = sess.run(
        [tokens, probs],
        feed_dict={
            initial_context: [context_formatted],
            eos_token: end_id,
            ignore_ids: ignore_ids_np,
            p_for_topp: np.array([nucleus_p]),
        }
    )

    # Convert each returned sample back into normal text.
    gens = [
        extract_generated_target(output_tokens=t_i, encoder=encoder, target=target)['extraction']
        for t_i, _p_i in zip(tokens_out, probs_out)
    ]

    # Only one context row was fed, so presumably a single sample comes back.
    # Return the last one, as before.
    return gens[-1]
示例#3
0
            # Format context end

            # Indices we definitely DONT WANT TO PREDICT
            ignore_ids_np = np.array(encoder.special_tokens_onehot)
            ignore_ids_np[encoder.endoftext] = 0

            gens = []
            gens_raw = []
            gen_probs = []

            # article['top_ps'] = top_p.reshape(-1).tolist()
            for chunk_i in range(num_chunks):
                tokens_out, probs_out = sess.run([tokens, probs],
                                                 feed_dict={initial_context: [context_formatted] * batch_size_per_chunk,
                                                            eos_token: 60000,
                                                            p_for_topp: top_p[chunk_i]})

                for t_i, p_i in zip(tokens_out, probs_out):
                    extraction = extract_generated_target(output_tokens=t_i, encoder=encoder, target=args.target)
                    gens.append(extraction['extraction'])

            # article['gens_{}'.format(args.target)] = gens
            # article['gensraw_{}'.format(args.target)] = gens_raw
            # article['probs_{}'.format(args.target)] = gen_probs

            # these were in there for whatever reason...
            # article.pop('input_ids_conditional', None)
            # article.pop('input_ids_unconditional', None)
            # f_out.write(json.dumps(article) + '\n')
            print(gens[0])
        text = input()