예제 #1
0
async def homepage(request):
    """Generate GPT-2 text for GET/POST requests; HEAD returns an empty body.

    Reads 'temperature' and 'prefix' from the query string (GET) or the JSON
    body (POST). Every 8th request rebuilds the TF session to keep the
    graph/session from going OOM.
    """
    global generate_count
    global sess

    # Default so unsupported HTTP methods (PUT, DELETE, ...) don't hit an
    # UnboundLocalError when params is read below.
    params = {}
    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return JSONResponse({'text': ''}, headers=response_header)
    print('+++++++++++++++')
    print(params)
    text = gpt2.generate(sess,
                         length=100,
                         temperature=float(params.get('temperature', 0.7)),
                         prefix=params.get('prefix', '')[:500],
                         return_as_list=True)[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    # Split the sample into lines for the JSON payload.
    text = re.split('\n', text)
    return JSONResponse({'text': text}, headers=response_header)
예제 #2
0
    def on_status(self, tweet):
        """Reply to tweets that mention the bot with GPT-2-generated text.

        Generation failures trigger one session rebuild + retry; replies are
        truncated to fit the 240-character tweet limit.
        """
        print("Received status.")
        if self.is_mentioned(tweet):
            username = tweet.user.screen_name
            text = str(tweet.text)
            # Strip our own handle so it is not fed into the prefix.
            text_without_self_username = text.replace("@sarcastic_trump", "")
            try:
                generated_tweet_from_text_as_prefix = self.generate_gpt2_tweet_using_prefix(
                    prefix=text_without_self_username)
            except Exception:
                # Narrowed from a bare `except:` so KeyboardInterrupt and
                # SystemExit still propagate. On a generation failure,
                # rebuild the TF session, reload the model, and retry once.
                tf.reset_default_graph()
                self.sess = gpt2.start_tf_sess()
                gpt2.load_gpt2(self.sess, run_name='trump_clean_small')
                generated_tweet_from_text_as_prefix = self.generate_gpt2_tweet_using_prefix(
                    prefix=text_without_self_username)

            # remove the extra lines without punctuation. Might also remove hashtags # ?
            tweet_without_extra_lines = self.remove_extra_lines(
                generated_tweet_from_text_as_prefix[0])

            print(f"{username}: {text}")

            # NOTE(review): the +5 margin looks smaller than the literal
            # "Hey @…, " framing text — confirm the truncation budget.
            if len(generated_tweet_from_text_as_prefix[0]) + len(
                    username) + 5 > 240:
                self.api.update_status(
                    f"Hey @{username}, {tweet_without_extra_lines[0:240-len(username)-5]}",
                    in_reply_to_status_id=tweet.id)
            else:
                self.api.update_status(
                    f"Hey @{username}, {tweet_without_extra_lines}",
                    in_reply_to_status_id=tweet.id)
예제 #3
0
def start():
    """Load GPT-2, generate one sample from a fixed prompt, and return it."""
    print("Starting")
    began = datetime.datetime.now()

    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session, model_name=model_name)

    unicorn_prompt = (
        "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
        "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
        "researchers was the fact that the unicorns spoke perfect English.")
    text = gpt2.generate(
        session,
        model_name=model_name,
        prefix=unicorn_prompt,
        length=100,
        temperature=0.7,
        top_p=0.9,
        return_as_list=True)

    print("Total time required is = ", datetime.datetime.now() - began)
    print(text)

    # generate() returned a list of samples; flatten to one string.
    return " ".join(text)
def main(args):
    """
    Load generated model checkpoints from by default in /checkpoint/run1 and generate new text
    """
    try:
        # Resolve all paths relative to the project root.
        cfg_file = project_path + "/" + args.config
        in_file = project_path + "/" + args.input
        out_file = project_path + "/" + args.output

        cfg = load_config(cfg_file)

        # Sample seed lines from the input corpus (seeded for repeatability).
        frame = read_csv(in_file)
        random.seed(cfg['generate']['random_seed'])
        seeds = random.choices(list(frame['raw_line']),
                               k=cfg['generate']['num'])

        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess)

        # One generated sample per seed line.
        outputs = [
            gpt2.generate(sess, prefix=seed, **cfg['generate']['generator'])
            for seed in seeds
        ]

        save_csv(pd.DataFrame(outputs, columns=['raw_line']), out_file)

    except Exception as e:
        logger.error(
            "Unexpected error occurred when generating dialogues with gpt2: " +
            str(e))
def train_data(inputFile, outputDir):
    """Generate sample text files from a GPT-2 session.

    NOTE(review): despite the name, the fine-tuning call is commented out
    below, so this currently only generates; `inputFile` is unused while
    that stays disabled — confirm intent.
    """
    sess = gpt2.start_tf_sess()

    # Fine-tuning disabled; re-enable to train on resource/<inputFile>.txt.
    # gpt2.finetune(sess,
    #               "resource/" + inputFile + ".txt",
    #               model_name=model_name,
    #               overwrite=True,
    #               steps=2)  # steps is the max number of training steps

    # Load the checkpoint once. The original reloaded it on every loop
    # iteration, rebuilding the same graph variables for no benefit, and
    # left a dangling `'''` that opened an unterminated string literal.
    gpt2.load_gpt2(sess)

    # Generate 10 example files, each with a unique name under newoutputs/.
    for _ in range(10):
        gpt2.generate_to_file(
            sess,
            destination_path="newoutputs/" + outputDir + str(uuid.uuid4()) + ".txt")

    # Release the graph/session resources when done.
    gpt2.reset_session(sess, threads=-1, server=None)
예제 #6
0
File: app.py — Project: zeta1999/gpt2-french
async def homepage(request):
    """Generate GPT-2 text for GET/POST requests; HEAD returns an empty body.

    The checkpoint named by 'run_name' is loaded per request and the TF
    session is reset afterwards, so successive requests can target
    different runs without the graph growing.
    """
    global generate_count
    global sess

    # Default so unsupported HTTP methods don't hit an UnboundLocalError
    # when params is read below.
    params = {}
    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return UJSONResponse({'text': ''},
                             headers=response_header)

    gpt2.load_gpt2(sess, run_name=params.get('run_name', ''))

    text = gpt2.generate(sess,
                         run_name=params.get('run_name', ''),
                         length=int(params.get('length', 1023)),
                         temperature=float(params.get('temperature', 0.7)),
                         top_k=int(params.get('top_k', 0)),
                         top_p=float(params.get('top_p', 0)),
                         prefix=params.get('prefix', '')[:500],
                         truncate=params.get('truncate', None),
                         include_prefix=str(params.get('include_prefix', True)).lower() == 'true',
                         return_as_list=True)[0]

    # Tear down and recreate the session so the next request starts clean.
    sess = gpt2.reset_session(sess)

    gc.collect()
    return UJSONResponse({'text': text},
                         headers=response_header)
예제 #7
0
    def get(self, context=''):
        """Generate candidate tweets from `context`, excluding any line that
        duplicates the original training corpus, and return them as JSON."""
        run_name = 'run3'

        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, run_name=run_name)

        generated = gpt2.generate(sess,
                                  run_name=run_name,
                                  prefix=context,
                                  nsamples=10,
                                  length=200,
                                  batch_size=10,
                                  temperature=1,
                                  top_k=40,
                                  include_prefix=True,
                                  return_as_list=True)

        # Collect the unique lines across all generated samples.
        candidates = set()
        for sample in generated:
            candidates.update(sample.splitlines())

        # Load the training tweets so verbatim copies can be dropped.
        with io.open('tweets_unseparated.txt', 'r',
                     encoding="utf-8") as tweet_file:
            known = {line.strip() for line in tweet_file.readlines()}

        all_tweets = list(candidates - known)

        return jsonify({'predicted_text': all_tweets})
예제 #8
0
def generate(prefix, input_file, similarity_threshold, nsamples, length,
             temperature, k):
    """Generate quote candidates with GPT-2 and filter them.

    Keeps only generated bodies that are longer than 3 words and less
    similar than `similarity_threshold` to every fine-tuning quote.
    Returns the surviving bodies as a list of strings.
    """
    # load the quotes used for fine-tuning (every third line is a quote body)
    with open(input_file, 'r') as f:
        originals = f.readlines()
        original_quotes = [
            originals[i].strip() for i in range(1, len(originals), 3)
        ]

    # generate a batch of quotes
    sess = gpt2s.start_tf_sess()
    gpt2s.load_gpt2(sess)
    samples = gpt2s.generate(sess,
                             nsamples=nsamples,
                             length=length,
                             temperature=temperature,
                             top_k=k,
                             prefix=prefix + '\n',
                             return_as_list=True)

    # filter the samples
    quotes = []
    for s in samples:
        lines = s.split('\n')
        if len(lines) < 2:
            # Sample has no body line after the title; skip it instead of
            # crashing on a 2-way unpack (the original did `title, body = ...`).
            continue
        body = lines[1]
        is_long = len(body.split(' ')) > 3
        is_novel = all(
            similar(body, x) < similarity_threshold for x in original_quotes)

        if is_long and is_novel:
            quotes.append(body)

    return quotes
예제 #9
0
def main():
    """Connect to Twitch chat and keep a queue of generated messages topped up."""
    import time  # local import: only needed for the refill loop below

    helix = twitch.Helix('', use_cache=True)
    global lastmsg, msg
    lastmsg = datetime.datetime.now()
    msg = queue.Queue(100)
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name)

    # Pre-fill the message queue before subscribing to any channels.
    while not msg.full():
        newmsg = genmsg(sess)
        print(newmsg)
        msg.put(newmsg)

    for channel in textchannels:
        chat = twitch.Chat(channel="#" + channel,
                           nickname='WoodenLongboard',
                           oauth="",
                           helix=helix)
        chats[channel] = chat
        chats[channel].subscribe(handle_message)

    print("Finished init")

    # Refill loop: top up the queue as messages are consumed. Sleep while
    # the queue is full — the original spun at 100% CPU here.
    while True:
        if not msg.full():
            msg.put(genmsg(sess))
        else:
            time.sleep(0.1)
예제 #10
0
def generate():
    """Generate a continuation of the submitted first line and return it as JSON."""
    first_line = '<|startoftext|> ' + request.args['firstLine'].lower()

    sess = gpt2.start_tf_sess(threads=1)
    gpt2.load_gpt2(sess, run_name="run1", checkpoint_dir="checkpoint")

    # Retry until the sample extends at least ~30 chars beyond the prompt.
    output = ['']
    while len(output[0]) <= len(first_line) + 30:
        output = gpt2.generate(sess,
                               run_name='run1',
                               checkpoint_dir='checkpoint',
                               model_dir='models',
                               sample_dir='samples',
                               return_as_list=True,
                               length=120,
                               temperature=0.7,
                               prefix=first_line,
                               truncate="<|endoftext|>",
                               include_prefix=True)

    # Tear down the TF graph/session and reclaim memory.
    tf.reset_default_graph()
    sess.close()
    gc.collect()

    return json.dumps({"data": output[0].replace('<|startoftext|> ', '')})
예제 #11
0
async def generate(input: str = "", auth: str = ""):
    """Generate one GPT-2 continuation of `input`; requires a valid auth key."""
    global sess, generate_count

    if auth != AUTH_KEY:
        return "Invalid auth token provided"

    samples = gpt2.generate(
        sess,
        run_name="run1",
        length=300,
        temperature=0.9,
        prefix=input,
        top_p=100,
        nsamples=1,
        batch_size=1,
        include_prefix=False,
        return_as_list=True,
    )
    result = samples[0]

    generate_count += 1

    # Rebuild the session every 12 generations so the TF graph cannot
    # grow until it OOMs.
    if generate_count == 12:
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=8)
        gpt2.load_gpt2(sess, run_name="run1")
        generate_count = 0

    return result
예제 #12
0
 async def _generate_samples(self, model: typing.Optional[str] = None, max_size=31):
     """Continuously generate GPT-2 samples from `model` into self.samples.

     Returns early when no model name is given or its checkpoint directory
     is missing. Waits on self.full while the cache holds `max_size`
     samples, and notifies self.empty when the first sample arrives.
     """
     if model is None:  # `is None`, not `== None` (PEP 8 identity check)
         return
     cog_data_path = data_manager.cog_data_path(self)
     model_path: pathlib.Path = cog_data_path / "models" / model
     if not model_path.exists():
         log.error(f"Model {model} not found in {str(cog_data_path)}")
         return
     tf_session = gpt_2_simple.start_tf_sess()
     gpt_2_simple.load_gpt2(
         tf_session,
         checkpoint_dir=str(cog_data_path / "checkpoints"),
         model_name=model,
         model_dir=str(model_path.parent),
     )
     # Producer loop: generate one sample at a time, forever.
     while True:
         new_sample = gpt_2_simple.generate(
             tf_session,
             return_as_list=True,
             truncate="<|endoftext|>",
             temperature=1.0,
         )[0]
         async with self.full:
             if len(self.samples) >= max_size:
                 log.info("Cache full, waiting for next command")
                 await self.full.wait()
             self.samples.append(new_sample)
         async with self.empty:
             if len(self.samples) == 1:
                 self.empty.notify()
예제 #13
0
def generate_story():
    """Flask handler: generate a story continuation for the posted genre/input.

    Rebuilds the TF session, loads the checkpoint named by the 'genre'
    form field, and generates 200 tokens from the 'input' text. Responds
    200 with the text, or 404 on any failure.
    """
    global sess
    # input text is request.form['input']

    try:
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, run_name=request.form['genre'])
        print("GENRE")
        print(request.form['genre'])
        generated_text = gpt2.generate(sess,
                                       run_name=request.form['genre'],
                                       length=200,
                                       temperature=0.8,
                                       prefix=str(request.form['input']),
                                       nsamples=1,
                                       batch_size=1,
                                       return_as_list=True)[0]
        return Response(response=generated_text, status=200)

    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate.
        traceback.print_exc(file=sys.stdout)
        print('aborting gen text')
        abort(404)
예제 #14
0
def main():
    """CLI entry point: generate samples for a subreddit with a named run.

    argv: RUN_NAME SUBREDDIT NO_SAMPLES [TEMPERATURE]
    """
    argv = sys.argv
    if len(argv) < 4:
        print(
            'Usage: python run_generator.py RUN_NAME SUBREDDIT NO_SAMPLES (TEMPERATURE)'
        )
        return

    run_name, subreddit = argv[1], argv[2]

    try:
        no_samples = int(argv[3])
    except Exception as e:
        print(e)
        print('Third argument should be an integer')
        return

    # Optional fourth argument overrides the default temperature of 1.
    temperature = 1
    if len(argv) >= 5:
        try:
            temperature = float(argv[4])
        except Exception as e:
            print(e)
            print('Fourth argument should be a float')
            return

    update_checkpoint(run_name)

    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name, checkpoint_dir='generator_models')

    generate_to_file(sess, run_name, subreddit, n=no_samples, temp=temperature)

    print('Done.')
예제 #15
0
File: app.py — Project: Visya/ghostwriter
def generate_ideas():
    """Flask handler: generate up to 5 GPT-2 ideas for an optional prefix.

    Query args: prefix, length (default 50), samples (1-5, default 1).
    Responds 400 when `samples` is out of range.
    """
    prefix = request.args.get("prefix")
    length = int(request.args.get("length", 50))
    samples = int(request.args.get("samples", 1))

    if samples <= 0 or samples > 5:
        # Build the error response explicitly: the original called
        # `jsonify(body, 400)`, which serializes 400 into a JSON array
        # instead of setting the HTTP status code.
        error = jsonify(
            {
                "message":
                "Samples value is invalid, min 1 and max 5 allowed."
            })
        error.status_code = 400
        abort(error)

    session = gpt2.start_tf_sess()
    try:
        gpt2.load_gpt2(session, model_name=model_name)

        ideas = gpt2.generate(
            session,
            model_name=model_name,
            prefix=prefix,
            length=length,
            nsamples=samples,
            batch_size=samples,
        )
    finally:
        # Always release the TF session, even when generation raises.
        session.close()

    return jsonify(ideas=ideas)
예제 #16
0
File: app.py — Project: basedrhys/quote-gen
async def homepage(request):
    """Generate GPT-2 text for GET/POST requests; HEAD returns an empty body.

    Generation parameters (length, temperature, top_k, top_p, prefix,
    truncate, include_prefix) come from the query string (GET) or JSON
    body (POST). Every 8th request rebuilds the TF session to keep the
    graph/session from going OOM.
    """
    global generate_count
    global sess

    # Default so unsupported HTTP methods don't hit an UnboundLocalError
    # when params is read below.
    params = {}
    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return UJSONResponse({'text': ''}, headers=response_header)

    text = gpt2.generate(sess,
                         length=int(params.get('length', 1023)),
                         temperature=float(params.get('temperature', 0.7)),
                         top_k=int(params.get('top_k', 0)),
                         top_p=float(params.get('top_p', 0)),
                         prefix=params.get('prefix', '')[:500],
                         truncate=params.get('truncate', None),
                         include_prefix=str(params.get(
                             'include_prefix', True)).lower() == 'true',
                         return_as_list=True)[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    return UJSONResponse({'text': text}, headers=response_header)
예제 #17
0
def loader(game_name):
    """Show a loading banner, then return an opening story generated by GPT-2.

    NOTE(review): `game_name` is currently unused — confirm whether it was
    meant to select a model or a prompt.
    """
    print(Fore.GREEN)
    banner = pyfiglet.figlet_format("Loading...", font="slant")
    print(Style.BRIGHT + banner)
    print(Fore.RESET)
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess)

    input1 = "I am Leo"
    stories = gpt2.generate(sess,
                            length=250,
                            temperature=0.7,
                            prefix=input1,
                            nsamples=5,
                            batch_size=5,
                            top_k=40,
                            return_as_list=True)

    print(Fore.RESET + Style.RESET_ALL)
    # Take the 4th sample, drop the trailing partial sentence, and rejoin
    # the rest (str.join replaces the original quadratic `+=` loop; the
    # slice also avoids `del temp[-1]` crashing on an empty list).
    sentences = stories[3].split(".")[:-1]
    story = "".join(sentence + "." for sentence in sentences)
    return str(story)
예제 #18
0
 def __init__(self, source, num_words, prompt="DEFAULT", temp=0.7):
     """Set up the sentence tokenizer and a ready-to-generate GPT-2 session.

     source: location of the training data; num_words: output length;
     prompt: seed text ("DEFAULT" selects a stock sentence);
     temp: sampling temperature.
     """
     # Sentence tokenizer used to split generated blocks into sentences.
     nltk.download('punkt')
     self.tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
     self.source = source
     self.temperature = temp
     self.num_words = num_words
     self.prompt = prompt
     if self.prompt == "DEFAULT":
         self.prompt = "The quick brown fox jumped over the lazy dog."
     # First-time run: download the base model and fine-tune if the
     # expected artifacts are not present in the working directory.
     existing = os.listdir(os.getcwd())
     if "models" not in existing:
         self.setupModel()
         print('Setup Complete')
     if "checkpoint" not in existing:
         self.trainGenerator()
         print('Training Complete')
     tf.reset_default_graph()
     self.session = gpt2.start_tf_sess()
     gpt2.load_gpt2(self.session, run_name='run1')
     print('Done')
예제 #19
0
def main():
    """Load the default GPT-2 checkpoint and print one generated sample."""
    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session)

    # return_as_list=True yields a list of samples; print the first.
    print(gpt2.generate(session, return_as_list=True)[0])
예제 #20
0
def _load_model(name):
    """Return a TF session with the checkpoint for run `name` loaded."""
    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session, run_name=name, checkpoint_dir=cfg.MODEL_DIR)
    return session
예제 #21
0
def loader():
    """Start a fresh TF session with the run1 checkpoint loaded and return it.

    NOTE: a cached-session variant was previously sketched here; each call
    currently builds a brand-new session.
    """
    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session,
                   checkpoint_dir='../assets/gpt_2/checkpoint',
                   run_name='run1')
    return session
예제 #22
0
def load_model(run):
    """Load checkpoint `run` from ./checkpoint into a new TF session."""
    # Absolute path keeps gpt2 happy regardless of the working directory.
    ckpt = Path("checkpoint").absolute()

    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session, checkpoint_dir=ckpt, run_name=run)
    return session
예제 #23
0
    def __init__(self):
        """Load the GPT-2 774M model and register/run the Discord bot commands.

        NOTE(review): bot.run(TOKEN) at the end blocks until the bot stops,
        so this constructor never returns while the bot is online.
        """
        model_name = "774M"
        run_name = 'run1'  # NOTE(review): unused below — confirm intent

        # One shared TF session; command handlers capture it by closure.
        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, model_name=model_name)
        @bot.event
        async def on_ready():
            # Log identity once the Discord connection is established.
            print('Logged in as')
            print(bot.user.name)
            print(bot.user.id)
            print('------')

        @bot.command()
        async def hello(ctx):
            """-Tonto says hello!"""
            await ctx.send("""Hello! I am Jaden Bot, a clone of my original host Jaden. I was designed to act as the Jaden replacment in the event people can not locate or access their local Jaden.""")

        @bot.command()
        async def story(ctx,*,input):
            """-Makes Jaden tell a story."""
            # Show a placeholder embed first — generation is slow.
            eTxt = discord.Embed(name="LOADING", description="Loading Story ...")
            await ctx.send(embed=eTxt)
            result = gpt2.generate(
            sess, 
            model_name=model_name,
            top_k=40, 
            top_p = 0.9,
            prefix=input, 
            truncate='<|endoftext|>', 
            length=100, 
            temperature=0.7,
            nsamples=1,
            batch_size=1,
            return_as_list=True
            )[0]
            await ctx.send(result)


        @bot.command()
        async def say(ctx,*,input):
            """-Tonto will repeate what you write."""
            await ctx.send(input)

        @bot.command()
        async def cmds(ctx):
            """-Will list commands if you can't use .help."""
            with open('help.txt', 'r') as hfile:
                data = hfile.read()
            await ctx.send(data)

        @bot.command()
        async def repeat(ctx, times: int, content='repeating...'):
            """Repeats a message multiple times."""
            for i in range(times):
                await ctx.send(content)  

        # Start the Discord event loop; execution blocks here.
        bot.run(TOKEN)
예제 #24
0
 def __init__(self, checkpoint_dir, run_name):
     """Create a generator bound to a fine-tuned checkpoint.

     Loads run `run_name` from `checkpoint_dir` into a fresh TF session,
     stored on self.sess.
     """
     self.checkpoint_dir = checkpoint_dir
     self.run_name = run_name
     session = gpt2.start_tf_sess()
     gpt2.load_gpt2(session,
                    run_name=run_name,
                    checkpoint_dir=checkpoint_dir,
                    model_name=None,
                    model_dir='models')
     self.sess = session
예제 #25
0
def start():
    """Download the 124M model if absent, then return a session with it loaded."""
    model_name = "124M"
    model_dir = os.path.join("models", model_name)
    if not os.path.isdir(model_dir):
        print(f"Downloading {model_name} model...")
        gpt2.download_gpt2(model_name=model_name)

    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session, model_name=model_name)
    return session
예제 #26
0
def generate_text():
    """Download and load GPT-2, generate one recipe, and render the index page.

    NOTE: the model is downloaded and loaded on every call; kept as-is to
    preserve behavior.
    """
    gpt2.download_gpt2()

    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session)

    return render_template("index.html", recipe=gpt2.generate(session))
예제 #27
0
def setup(opts):
    """Stage the checkpoint file for `opts['run_name']` and load the model.

    Sets the module-level `run_name` and `sess` globals; always returns None.
    """
    global run_name, sess
    run_name = opts['run_name']
    checkpoint_file = opts['checkpoint_file']
    print(f'Run name: {run_name}')

    # Place the provided weights where load_gpt2 expects to find them.
    shutil.copy(checkpoint_file,
                f'checkpoints/{run_name}/model-10000.data-00000-of-00001')

    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name)
    return None
예제 #28
0
 def loadRun(self):
     """Generate from the existing run1 checkpoint into bot_says.txt."""
     # Fresh session with the previously trained run loaded.
     tf_session = gpt2.start_tf_sess()
     gpt2.load_gpt2(tf_session, run_name='run1')
     # Write one sample, seeded by the prompt and truncated at the first
     # period, without echoing the prompt itself.
     gpt2.generate_to_file(tf_session,
                           include_prefix=False,
                           truncate=".",
                           destination_path='bot_says.txt',
                           length=self.num_words,
                           temperature=self.temperature,
                           prefix=self.prompt)
예제 #29
0
def gpt(sentence):
    """Generate a continuation of `sentence` from the run1 checkpoint.

    Returns the generated text. (The original computed the generation into
    `sent` but then returned the input `sentence` unchanged, discarding
    the model output.)
    """
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name='run1')
    generated = gpt2.generate(sess,
                              run_name='run1',
                              return_as_list=True,
                              include_prefix=False,
                              prefix=sentence,
                              truncate='<|endoftext|>')
    sess.close()

    # return_as_list=True yields a list of samples; return the first one.
    return generated[0]
예제 #30
0
def gen_story(mode, input_text, out_length=1023):
    """Generate a story continuation using the 355M checkpoint for `mode`."""
    sess = gpt2.start_tf_sess()

    # Checkpoints live under tf_model/355M_<mode>.
    ckpt_dir = 'tf_model/355M_' + mode
    gpt2.load_gpt2(sess, checkpoint_dir=ckpt_dir)
    samples = gpt2.generate(sess,
                            return_as_list=True,
                            checkpoint_dir=ckpt_dir,
                            prefix=input_text,
                            length=out_length,
                            include_prefix=True)

    # Session/graph teardown intentionally left disabled, as before:
    # tf.reset_default_graph(); sess.close()

    return samples[0]