Example No. 1
    def test_generation(self):
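        # Send the module logger's output to stdout so it is visible in the test run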
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
        model_type, model_name = ("--model_type=openai-gpt", "--model_name_or_path=openai-gpt")
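        # Substitute sys.argv so main() parses these arguments with argparse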
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            result = run_generation.main()
Example No. 2
    def test_generation(self):
        stream_handler = logging.StreamHandler(sys.stdout)
        logger.addHandler(stream_handler)

        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]
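        # sshleifer/tiny-gpt2 is a tiny test checkpoint, which keeps this test fast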
        model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=sshleifer/tiny-gpt2")
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            result = run_generation.main()
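            # main() returns the generated sequences; sanity-check the first one's length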
            self.assertGreaterEqual(len(result[0]), 10)
Example No. 3
    def test_generation(self):
        testargs = ["run_generation.py", "--prompt=Hello", "--length=10", "--seed=42"]

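        # Exercise mixed-precision generation when CUDA and apex are both available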
        if is_cuda_and_apex_available():
            testargs.append("--fp16")

        model_type, model_name = (
            "--model_type=gpt2",
            "--model_name_or_path=sshleifer/tiny-gpt2",
        )
        with patch.object(sys, "argv", testargs + [model_type, model_name]):
            result = run_generation.main()
            self.assertGreaterEqual(len(result[0]), 10)
Example No. 4
def generate_text(params):
    """Generate text using transformers."""
    prompt = ''
    if not params['genre'] and not params['title'] and not params['prefix']:
        prompt += EOS_TOKEN
    if params['genre']:
        prompt += params['genre'] + EOG_TOKEN
    if params['title']:
        prompt += params['title'] + EOT_TOKEN
    if params['prefix']:
        prompt += params['prefix']
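    # Hand the assembled arguments to run_generation's main() entry point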
    text = run_generation.main([
        '--model_type=gpt2', '--model_name_or_path=app/output',
        f"--prompt={prompt}" if prompt else '--prompt=""',
        f'--temperature={float(params["temp"]) if params["temp"] else uniform(0.7, 1)}',
        f'--top_p={float(params["top_p"]) if params["top_p"] else 0}',
        '--num_samples=1', '--length=256', f'--stop_token={EOS_TOKEN}'
    ])
    return prompt + text
Example No. 5
def generate_text(params):
    """Generate text using transformers."""
    if params['genre'] and params['prefix']:
        prompt = BOS_TOKEN + params['genre'] + EOG_TOKEN + params['prefix']
    elif params['genre']:
        prompt = BOS_TOKEN + params['genre'] + EOG_TOKEN
    elif params['prefix']:
        # If user specifies prefix, the model cannot generate the genre anymore
        prompt = params['prefix']
    else:
        prompt = BOS_TOKEN
    text = run_generation.main([
        '--model_type=gpt2', '--model_name_or_path=app/output',
        f"--prompt={prompt}" if prompt else '--prompt=""',
        f'--temperature={float(params["temp"]) if params["temp"] else uniform(0.7, 1)}',
        f'--top_p={float(params["top_p"]) if params["top_p"] else 0}',
        '--num_samples=1', '--length=256', f'--stop_token={EOS_TOKEN}'
    ])
    return prompt + text
Example No. 6
def inference():
    SEED = random.randint(0, 9999)
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_return_sequences',
                        help='number of sequences to generate per prompt',
                        type=int,
                        default=3)
    parser.add_argument('--batch_size', type=int, default=1)
    # argparse's type=bool is a footgun (bool('False') is True), so use a flag
    parser.add_argument('--no_fp16', help='disable fp16 generation', action='store_true')
    args = parser.parse_args()
    num_return_sequences = args.num_return_sequences
    batch_size = args.batch_size
    fp16 = not args.no_fp16
    prompt = input("Model prompt >>> ")
    # Name the forwarded CLI list distinctly so it does not shadow the parsed args
    gen_args = [
        "--num_return_sequences=" + str(num_return_sequences),
        "--prompt=" + prompt, "--model_type=" + MODEL_TYPE,
        "--model_name_or_path=" + MODEL_PATH, "--seed=" + str(SEED),
        "--length=350"
    ]
    if fp16:
        gen_args.append("--fp16")
    start_time = time.time()
    sequences = run_generation.main(gen_args, batch_size)
    print("text generation took --- %s seconds ---" %
          (time.time() - start_time))
    ts = time.time()
    timestamp = datetime.datetime.fromtimestamp(ts).strftime('%Y-%m-%d_%H-%M')
    # Avoid ':' in the filename so the path also works on Windows
    filename = DATA_PATH + "/inference_results_" + timestamp + ".csv"
    print("saving results to: ", filename)
    with open(filename, 'w', newline='') as myfile:
        wr = csv.writer(myfile, quoting=csv.QUOTE_ALL)
        wr.writerow(sequences)
Example No. 7
import logging
import sys
from unittest.mock import patch

import run_generation

logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()

testargs = ["run_generation.py", "--prompt=Hello", "--length=240", "--seed=42"]

model_type, model_name = ("--model_type=gpt2", "--model_name_or_path=gpt2")
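# Patch sys.argv for the duration of the call; main() reads its arguments from there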
with patch.object(sys, "argv", testargs + [model_type, model_name]):
    result = run_generation.main()
    print(result)
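
The sys.argv-patching pattern shared by these examples can be factored into a
small helper. The sketch below is hypothetical: the helper name "generate" and
its default arguments are assumptions, not part of run_generation itself.

import sys
from unittest.mock import patch

import run_generation


def generate(prompt, model_type="gpt2", model_name="gpt2", length=20, seed=42):
    """Patch sys.argv, call run_generation.main(), and return its result."""
    argv = [
        "run_generation.py",
        f"--prompt={prompt}",
        f"--length={length}",
        f"--seed={seed}",
        f"--model_type={model_type}",
        f"--model_name_or_path={model_name}",
    ]
    # main() reads its arguments from sys.argv via argparse, so patch it
    # for the duration of the call
    with patch.object(sys, "argv", argv):
        return run_generation.main()

For instance, generate("Hello") reproduces the invocation in Example No. 7
with a shorter generation length.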