def main(args):
    """
    Load model checkpoints (by default from /checkpoint/run1) and generate new text.
    """
    try:
        config_path = project_path + "/" + args.config
        input_data_path = project_path + "/" + args.input
        output_data_path = project_path + "/" + args.output

        config = load_config(config_path)

        # load data
        df = read_csv(input_data_path)
        lines = list(df['raw_line'])
        random.seed(config['generate']['random_seed'])
        sample_seeds = random.choices(lines, k=config['generate']['num'])

        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess)

        pred = []
        for i in sample_seeds:
            # NB: the generator config must set return_as_list=True; otherwise
            # gpt2.generate() prints to stdout and returns None.
            out = gpt2.generate(sess,
                                prefix=i,
                                **config['generate']['generator'])
            pred.append(out)

        pred_df = pd.DataFrame(pred, columns=['raw_line'])
        save_csv(pred_df, output_data_path)

    except Exception as e:
        logger.error(
            "Unexpected error occurred when generating dialogues with gpt2: " +
            str(e))
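The helpers used above (load_config, read_csv, save_csv) are project-specific and not shown; a minimal sketch of what they might look like, assuming a YAML config file and pandas for the CSV I/O:

import pandas as pd
import yaml

def load_config(path):
    # Parse the YAML config into a nested dict (config['generate'][...] above).
    with open(path) as f:
        return yaml.safe_load(f)

def read_csv(path):
    # Thin pandas wrapper, kept for symmetry with save_csv.
    return pd.read_csv(path)

def save_csv(df, path):
    # Drop the index so the output contains only the 'raw_line' column.
    df.to_csv(path, index=False)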
Example No. 2
async def homepage(request):
    global generate_count
    global sess

    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return JSONResponse({'text': ''}, headers=response_header)
    print('+++++++++++++++')
    print(params)
    text = gpt2.generate(sess,
                         length=100,
                         temperature=float(params.get('temperature', 0.7)),
                         prefix=params.get('prefix', '')[:500],
                         return_as_list=True)[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    text = text.split('\n')
    return JSONResponse({'text': text}, headers=response_header)
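For context, one way this handler might be wired into an app and served; the module-level setup is omitted from the snippet, so response_header, the thread count, and the port below are assumptions:

import gc
import re

import gpt_2_simple as gpt2
import tensorflow as tf
import uvicorn
from starlette.applications import Starlette
from starlette.responses import JSONResponse
from starlette.routing import Route

# Assumed module-level state used by homepage() above.
response_header = {'Access-Control-Allow-Origin': '*'}
sess = gpt2.start_tf_sess(threads=1)
gpt2.load_gpt2(sess)
generate_count = 0

app = Starlette(routes=[Route('/', homepage, methods=['GET', 'POST', 'HEAD'])])

if __name__ == '__main__':
    uvicorn.run(app, host='0.0.0.0', port=8080)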
Example No. 3
def loader(game_name):
    print(Fore.GREEN)
    banner = pyfiglet.figlet_format("Loading...", font="slant")
    print(Style.BRIGHT + banner)
    print(Fore.RESET)
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess)

    input1 = "I am Leo"
    stories = gpt2.generate(sess,
                            length=250,
                            temperature=0.7,
                            prefix=input1,
                            nsamples=5,
                            batch_size=5,
                            top_k=40,
                            return_as_list=True)

    print(Fore.RESET + Style.RESET_ALL)
    story = ""
    temp = stories[3].split(".")
    del temp[-1]
    for i in temp:
        story = story + i + '.'
    return str(story)
Example No. 4
def generate_story():
    global sess
    # input text is request.form['input']

    try:
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, run_name=request.form['genre'])
        print("GENRE")
        print(request.form['genre'])
        generated_text = gpt2.generate(sess,
                                       run_name=request.form['genre'],
                                       length=200,
                                       temperature=0.8,
                                       prefix=str(request.form['input']),
                                       nsamples=1,
                                       batch_size=1,
                                       return_as_list=True)[0]
        return Response(response=generated_text, status=200)

    except Exception:
        traceback.print_exc(file=sys.stdout)
        print('aborting gen text')
        abort(404)
Example No. 5
def main():
    args = parse_args()
    if not args.file:
        logger.error("No file entered. Use -f flag.")
        exit()
    filename = Path(args.file).stem

    logger.debug("Download model")
    gpt2.download_gpt2()

    logger.debug("Starting GPT-2 session")
    sess = gpt2.start_tf_sess()
    logger.debug("Finetuning model")
    gpt2.finetune(sess, args.file, steps=args.iteration)

    Path("Exports").mkdir(parent=False, exist_ok=True)

    logger.debug("Generating text")
    while True:
        generated_text = gpt2.generate(sess,
                                       return_as_list=True,
                                       temperature=args.temperature)[0]
        with open(f"Exports/{filename}_{args.temperature}_gpt2simple.txt",
                  "a") as f:
            test_hour = datetime.datetime.now().strftime("%Y/%m/%d %H:%M")
            f.write(f"{test_hour}\n")
            for i in generated_text:
                f.write(f"{i}\n")
    logger.info("Runtime : %.2f seconds" % (time.time() - temps_debut))
Example No. 6
def train(input_file):

    if os.path.exists('models/temp'):
        shutil.rmtree('models/temp')

    if not os.path.exists('models/124M'):
        download()

    sess = gpt2.start_tf_sess()

    model_name = '124M'
    model_dir = 'models/'
    training_dir = 'src/training_data/'

    gpt2.finetune(sess,
        training_dir + input_file,
        model_name=model_name,
        checkpoint_dir=model_dir + 'temp/',
        run_name='',
        steps=1)

    gpt2.reset_session(sess)

    if os.path.exists('models/latest'):
        shutil.rmtree('models/latest')
    shutil.copytree('models/temp', 'models/latest')
    # shutil.rmtree('models/temp')
Example No. 7
def train_data(inputFile, outputDir):
    sess = gpt2.start_tf_sess()

    # Fine-tuning on the input file is disabled in this example; uncomment to train:
    # gpt2.finetune(sess,
    #     "resource/" + inputFile + ".txt",
    #     model_name=model_name,
    #     overwrite=True,
    #     steps=2)  # steps is the max number of training steps

    # Generate 10 samples, each written to its own UUID-named file.
    gpt2.load_gpt2(sess)
    for x in range(10):
        gpt2.generate_to_file(
            sess,
            destination_path="newoutputs/" + outputDir + str(uuid.uuid4()) + ".txt")

    gpt2.reset_session(sess, threads=-1, server=None)
Example No. 8
    def get(self, context=''):
        run_name = 'run3'

        sess = gpt2.start_tf_sess()
        gpt2.load_gpt2(sess, run_name=run_name)

        results = gpt2.generate(sess,
                                run_name=run_name,
                                prefix=context,
                                nsamples=10,
                                length=200,
                                batch_size=10,
                                temperature=1,
                                top_k=40,
                                include_prefix=True,
                                return_as_list=True)

        all_tweets = []

        for result in results:
            subtweets = result.splitlines()
            all_tweets = list(set(all_tweets + subtweets))

        with io.open('tweets_unseparated.txt', 'r',
                     encoding="utf-8") as tweet_file:
            original_tweets = tweet_file.readlines()

        original_tweets = [x.strip() for x in original_tweets]

        all_tweets = list(set(all_tweets) - set(original_tweets))

        result = {'predicted_text': all_tweets}
        return jsonify(result)
Example No. 9
def generator(data_1, data_2, data_3):

    game_name = data_1
    epoch = data_2
    model = data_3
    db = Database.readBlobData(game_name)
    file_name = "/content/app/data/data.txt"

    g = pyfiglet.figlet_format("Generating world...", font="slant")
    print(Fore.BLACK + Style.DIM)

    print(Fore.GREEN)
    print(Style.BRIGHT + g)
    print(Fore.BLACK + Style.DIM)
    sess = gpt2.start_tf_sess()

    sample = gpt2.finetune(sess,
                           dataset=file_name,
                           model_name=model,
                           steps=epoch,
                           restore_from='fresh',
                           run_name="run1",
                           print_every=1,
                           sample_every=epoch,
                           save_every=epoch)
    return sample
Example No. 10
  def fit(self,
          input_path,
          reset = True,
          overwrite = False,
          num_steps = 1000,
          batch_size = 1,
          print_every = 10,
          sample_every = 200,
          save_every = 300,
          restore_from = 'fresh',
          run_name = 'reddit_comment_generator'):
    if reset:
      tf.reset_default_graph()
      self.tf_sess = gpt2.start_tf_sess()

    if overwrite and restore_from != 'latest':
      restore_from = 'latest'

    # Finetuning the model on new data
    gpt2.finetune(self.tf_sess,
                  dataset = input_path,
                  batch_size = batch_size,
                  model_name = self.model_type,
                  steps = num_steps,
                  restore_from = restore_from,
                  run_name = run_name,
                  print_every = print_every,
                  sample_every = sample_every,
                  save_every = save_every)
Example No. 11
  def generate_comments(self,
                        user_input,
                        bert_model_prediction,
                        length = 200,
                        temperature = 0.7,
                        num_samples = 2,
                        batch_size = 1,
                        top_k = 0,
                        top_p = 0,
                        run_name = 'reddit_comment_generator',
                        checkpoint_dir = './GPT2/checkpoint',
                        truncate_string = None):
    if not self.tf_sess:
      self.tf_sess = gpt2.start_tf_sess()

    # Generate samples
    subreddit_id = self.SubredditMapping[bert_model_prediction]
    prefix = '****S ' + subreddit_id + '\n' + user_input + '\n' + '****ES'

    comments = gpt2.generate(self.tf_sess,
                             length = length,
                             temperature = temperature,
                             prefix = prefix,
                             nsamples = num_samples,
                             batch_size = batch_size,
                             run_name = run_name,
                             top_k = top_k,
                             top_p = top_p,
                             return_as_list = True,
                             checkpoint_dir = checkpoint_dir,
                             truncate = truncate_string)

    index = 0
    shuffle(self.Names)
    ans = ''
    for text in comments:
        text = text.split('\n')
        L = len(text)

        i = 0
        while i < L and '****TC' not in text[i]:
            text[i] = ''
            i += 1

        start = i
        while(i < L and '****S' not in text[i]):
            if '****TC' in text[i]:
              text[i] = '<strong>' + str(self.Names[index]) + '</strong>'
              index += 1
            elif '****ETC' in text[i]:
              text[i] = ''
            i += 1

        text = text[start:i]
        text = '\n'.join(text)
        if not ans:
            ans = text
        else:
            ans = ans + '\n\n' + text
    return ans
Example No. 12
def main():
    helix = twitch.Helix('', use_cache=True)
    global lastmsg, msg
    lastmsg = datetime.datetime.now()
    msg = queue.Queue(100)
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name)

    while not msg.full():
        newmsg = genmsg(sess)
        print(newmsg)
        msg.put(newmsg)

    for channel in textchannels:
        chat = twitch.Chat(channel="#" + channel,
                           nickname='WoodenLongboard',
                           oauth="",
                           helix=helix)
        chats[channel] = chat
        chats[channel].subscribe(handle_message)

    print("Finished init")

    while True:
        if not msg.full():
            msg.put(genmsg(sess))
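genmsg is defined elsewhere; a plausible sketch, assuming it draws one short sample from the loaded run and collapses it to a single chat line:

def genmsg(sess):
    text = gpt2.generate(sess,
                         run_name=run_name,
                         length=60,
                         temperature=0.8,
                         return_as_list=True)[0]
    # Keep only the first non-empty line so it fits in one chat message.
    for line in text.splitlines():
        if line.strip():
            return line.strip()
    return text.strip()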
Example No. 13
    def prepare_fine_tuning(self, file_name: str):
        """
        prepare_fine_tuning : Personnalise et regle le modèle pour l'entrainer sur notre dataset.
        
        Args:
            file_name (str): Nom du fichier d'entrée.
        """
        if not os.path.isdir(os.path.join("models", self.model_name)):
            print(f"Downloading {self.model_name} model...")
            gpt2.download_gpt2(
                model_name=self.model_name
            )  # model is saved into current directory under /models/124M/

        sess = gpt2.start_tf_sess()

        gpt2.finetune(
            sess,
            dataset=file_name,
            model_name=self.model_name,
            steps=1000,
            restore_from="fresh",
            run_name=self.run_name,
            print_every=10,
            sample_every=200,
            save_every=500,
        )
Example No. 14
def generate():
    first_line = request.args['firstLine']
    first_line = '<|startoftext|> ' + first_line.lower()

    sess = gpt2.start_tf_sess(threads=1)
    gpt2.load_gpt2(sess, run_name="run1", checkpoint_dir="checkpoint")

    output = ['']
    while (len(output[0]) <= len(first_line) + 30):
        output = gpt2.generate(sess,
                               run_name='run1',
                               checkpoint_dir='checkpoint',
                               model_dir='models',
                               sample_dir='samples',
                               return_as_list=True,
                               length=120,
                               temperature=0.7,
                               prefix=first_line,
                               truncate="<|endoftext|>",
                               include_prefix=True)

    tf.reset_default_graph()
    sess.close()
    gc.collect()

    data = output[0].replace('<|startoftext|> ', '')
    return json.dumps({"data": data})
Example No. 15
 def __init__(self, source, num_words, prompt="DEFAULT", temp=0.7):
     #separating blocks into sentence tokens
     nltk.download('punkt')
     self.tokenizer = nltk.data.load('tokenizers/punkt/english.pickle')
     #where the training data is stored
     self.source = source
     #deviation from original dataset
     self.temperature = temp
     #length of output
     self.num_words = num_words
     #user input
     self.prompt = prompt
     if self.prompt == "DEFAULT":
         self.prompt = "The quick brown fox jumped over the lazy dog."
     files = os.listdir(os.getcwd())
     if "models" not in files:
         #first-time runthrough
         self.setupModel()
         print('Setup Complete')
     if "checkpoint" not in files:
         #story generator, give parameters if necessary
         self.trainGenerator()
         print('Training Complete')
     tf.reset_default_graph()
     self.session = gpt2.start_tf_sess()
     gpt2.load_gpt2(self.session, run_name='run1')
     print('Done')
Example No. 16
 async def _generate_samples(self, model: typing.Optional[str] = None, max_size=31):
     if model is None:
         return
     cog_data_path = data_manager.cog_data_path(self)
     model_path: pathlib.Path = cog_data_path / "models" / model
     if not model_path.exists():
         log.error(f"Model {model} not found in {str(cog_data_path)}")
         return
     tf_session = gpt_2_simple.start_tf_sess()
     gpt_2_simple.load_gpt2(
         tf_session,
         checkpoint_dir=str(cog_data_path / "checkpoints"),
         model_name=model,
         model_dir=str(model_path.parent),
     )
     while True:
         new_sample = gpt_2_simple.generate(
             tf_session,
             return_as_list=True,
             truncate="<|endoftext|>",
             temperature=1.0,
         )[0]
         async with self.full:
             if len(self.samples) >= max_size:
                 log.info("Cache full, waiting for next command")
                 await self.full.wait()
             self.samples.append(new_sample)
         async with self.empty:
             if len(self.samples) == 1:
                 self.empty.notify()
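The self.full / self.empty hand-off implies two asyncio.Condition objects plus a shared list created elsewhere in the cog; a sketch of the assumed initialization (the SampleCache name is hypothetical):

import asyncio

class SampleCache:
    # Minimal stand-in showing the state _generate_samples relies on.
    def __init__(self):
        self.samples = []                 # shared sample cache
        self.full = asyncio.Condition()   # producer waits here while the cache is full
        self.empty = asyncio.Condition()  # consumers wait here while the cache is empty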
Example No. 17
    def on_status(self, tweet):
        print("Received status.")
        if self.is_mentioned(tweet):
            username = tweet.user.screen_name
            text = str(tweet.text)
            # do I need to remove @sarcastic_trump before generating a new prediction?
            text_without_self_username = text.replace("@sarcastic_trump", "")
            try:
                generated_tweet_from_text_as_prefix = self.generate_gpt2_tweet_using_prefix(
                    prefix=text_without_self_username)
            except Exception:
                # The session may have gone stale; rebuild it and retry once.
                tf.reset_default_graph()
                self.sess = gpt2.start_tf_sess()
                gpt2.load_gpt2(self.sess, run_name='trump_clean_small')
                generated_tweet_from_text_as_prefix = self.generate_gpt2_tweet_using_prefix(
                    prefix=text_without_self_username)

            # remove the extra lines without punctuation. Might also remove hashtags # ?
            tweet_without_extra_lines = self.remove_extra_lines(
                generated_tweet_from_text_as_prefix[0])

            print(f"{username}: {text}")

            if len(generated_tweet_from_text_as_prefix[0]) + len(
                    username) + 5 > 240:
                self.api.update_status(
                    f"Hey @{username}, {tweet_without_extra_lines[0:240-len(username)-5]}",
                    in_reply_to_status_id=tweet.id)
            else:
                self.api.update_status(
                    f"Hey @{username}, {tweet_without_extra_lines}",
                    in_reply_to_status_id=tweet.id)
Example No. 18
    def finetune(self, corpus, return_text=True):
        """ Returns generated text sample

        Parameters
        ----------
        arg: corpus (object)
            - desc: Custom dataset text file

        arg: return_text (bool)
            - default: True
            - desc: Toggles whether to return custom-generated text in an array after fine-tuning

        Returns:
            Generated string in an array
        """
        sess = gpt2.start_tf_sess()
        gpt2.finetune(sess,
                corpus,
                model_name=self.model_name,
                steps=1000)     # steps is max number of training steps

        if return_text:
            text = gpt2.generate(sess, return_as_list=True)
            return text
        else:
            gpt2.generate(sess)
Example No. 19
async def generate(input: str = "", auth: str = ""):
    global sess, generate_count

    if auth != AUTH_KEY:
        return "Invalid auth token provided"

    result = gpt2.generate(
        sess,
        run_name="run1",
        length=300,
        temperature=0.9,
        prefix=input,
        top_p=100,  # NB: top_p is usually in (0, 1]; 100 keeps every token, so this may have been meant as top_k
        nsamples=1,
        batch_size=1,
        include_prefix=False,
        return_as_list=True,
    )[0]

    generate_count += 1

    if generate_count == 12:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=8)
        gpt2.load_gpt2(sess, run_name="run1")
        generate_count = 0

    return result
Example No. 20
def generate(prefix, input_file, similarity_threshold, nsamples, length,
             temperature, k):
    # load the quotes used for fine-tuning
    with open(input_file, 'r') as f:
        originals = f.readlines()
        original_quotes = [
            originals[i].strip() for i in range(1, len(originals), 3)
        ]

    # generate a batch of quotes
    sess = gpt2s.start_tf_sess()
    gpt2s.load_gpt2(sess)
    samples = gpt2s.generate(sess,
                             nsamples=nsamples,
                             length=length,
                             temperature=temperature,
                             top_k=k,
                             prefix=prefix + '\n',
                             return_as_list=True)

    # filter the samples
    quotes = []
    for s in samples:
        title, body = s.split('\n')[:2]
        is_long = len(body.split(' ')) > 3
        is_novel = all(
            similar(body, x) < similarity_threshold for x in original_quotes)

        if is_long and is_novel:
            quotes.append(body)

    return quotes
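The similar helper is not shown; a plausible definition based on difflib, returning a ratio in [0, 1]:

from difflib import SequenceMatcher

def similar(a, b):
    # 0.0 means no overlap, 1.0 means identical strings.
    return SequenceMatcher(None, a, b).ratio()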
Example No. 21
def gpt2_finetune(hparams):
    info_print("Model finetuning, please wait. (Press Ctrl+C to exit early)")
    sess = gpt2.start_tf_sess()

    # input check
    if not os.path.exists(
            os.path.join(hparams.gpt2_model_dir, hparams.gpt2_model_name)):
        raise FileNotFoundError(
            "The specified gpt2 pretrained model doesn't exist, please restore the default params."
        )

    # clear checkpoint dir
    model_path = os.path.join(hparams.finetuned_model_dir,
                              hparams.finetuned_model_name)
    if os.path.exists(model_path):
        shutil.rmtree(model_path)

    gpt2.finetune(sess=sess,
                  dataset=hparams.data_path,
                  model_dir=hparams.gpt2_model_dir,
                  model_name=hparams.gpt2_model_name,
                  checkpoint_dir=hparams.finetuned_model_dir,
                  run_name=hparams.finetuned_model_name,
                  multi_gpu=hparams.multi_gpu,
                  steps=hparams.steps)
Example No. 22
def start():
    print("Starting")
    start_time = datetime.datetime.now()

    sess = gpt2.start_tf_sess()

    gpt2.load_gpt2(sess, model_name=model_name)

    text = gpt2.generate(
        sess,
        model_name=model_name,
        prefix=
        "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "
        "previously unexplored valley, in the Andes Mountains. Even more surprising to the "
        "researchers was the fact that the unicorns spoke perfect English.",
        length=100,
        temperature=0.7,
        top_p=0.9,
        return_as_list=True)

    total_time = datetime.datetime.now() - start_time

    print("Total time required is = ", total_time)

    print(text)

    return " ".join(text)
Example No. 23
def finetune(
    model_name: str,
    text_path: str,
    num_steps: int,
    sample_length: int,
    save_every: Optional[int],
) -> None:

    # Download the model if it is not present
    if not os.path.isdir(os.path.join("models", model_name)):
        print(f"Downloading {model_name} model...")
        gpt2.download_gpt2(model_name=model_name)

    sess = gpt2.start_tf_sess()

    if save_every is None:
        save_every = int(num_steps / 4)

    gpt2.finetune(
        sess,
        text_path,
        model_name=model_name,
        steps=num_steps,
        sample_length=sample_length,
        save_every=save_every,
    )  # steps is max number of training steps

    gpt2.generate(sess)
Example No. 24
def main():
    if len(sys.argv) < 4:
        print(
            'Usage: python run_generator.py RUN_NAME SUBREDDIT NO_SAMPLES (TEMPERATURE)'
        )
        return

    run_name = sys.argv[1]
    subreddit = sys.argv[2]
    try:
        no_samples = int(sys.argv[3])
    except Exception as e:
        print(e)
        print('Third argument should be an integer')
        return

    temperature = 1
    if len(sys.argv) >= 5:
        try:
            temperature = float(sys.argv[4])
        except Exception as e:
            print(e)
            print('Fourth argument should be a float')
            return

    update_checkpoint(run_name)

    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, run_name=run_name, checkpoint_dir='generator_models')

    generate_to_file(sess, run_name, subreddit, n=no_samples, temp=temperature)

    print('Done.')
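update_checkpoint and generate_to_file are project helpers not shown here; a sketch of the latter, assuming it wraps gpt2.generate_to_file over the same checkpoint layout (the output path and prefix format are guesses):

def generate_to_file(sess, run_name, subreddit, n=1, temp=1.0):
    # Hypothetical wrapper; the real helper may format the prefix differently.
    gpt2.generate_to_file(sess,
                          run_name=run_name,
                          checkpoint_dir='generator_models',
                          destination_path=f'{run_name}_{subreddit}_samples.txt',
                          prefix=f'/r/{subreddit}',
                          nsamples=n,
                          temperature=temp)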
Example No. 25
async def homepage(request):
    global generate_count
    global sess

    if request.method == 'GET':
        params = request.query_params
    elif request.method == 'POST':
        params = await request.json()
    elif request.method == 'HEAD':
        return UJSONResponse({'text': ''}, headers=response_header)

    text = gpt2.generate(sess,
                         length=int(params.get('length', 1023)),
                         temperature=float(params.get('temperature', 0.7)),
                         top_k=int(params.get('top_k', 0)),
                         top_p=float(params.get('top_p', 0)),
                         prefix=params.get('prefix', '')[:500],
                         truncate=params.get('truncate', None),
                         include_prefix=str(params.get(
                             'include_prefix', True)).lower() == 'true',
                         return_as_list=True)[0]

    generate_count += 1
    if generate_count == 8:
        # Reload model to prevent Graph/Session from going OOM
        tf.reset_default_graph()
        sess.close()
        sess = gpt2.start_tf_sess(threads=1)
        gpt2.load_gpt2(sess)
        generate_count = 0

    gc.collect()
    return UJSONResponse({'text': text}, headers=response_header)
Example No. 26
def generate_ideas():
    prefix = request.args.get("prefix")
    length = int(request.args.get("length", 50))
    samples = int(request.args.get("samples", 1))

    if samples <= 0 or samples > 5:
        # abort() accepts a Response; build one so the 400 status is actually set.
        abort(make_response(
            jsonify(message="Samples value is invalid, min 1 and max 5 allowed."),
            400,
        ))

    session = gpt2.start_tf_sess()
    gpt2.load_gpt2(session, model_name=model_name)

    ideas = gpt2.generate(
        session,
        model_name=model_name,
        prefix=prefix,
        length=length,
        nsamples=samples,
        batch_size=samples,
        return_as_list=True,  # without this, generate() prints to stdout and returns None
    )

    session.close()

    return jsonify(ideas=ideas)
Example No. 27
 def __init__(self, group=None, target=None, name=None,
              args=(), kwargs=None, verbose=None):
     super(GeneratorThread,self).__init__()
     self.target = target
     self.name = name
     self.last_model = ''
     self.sess = gpt2.start_tf_sess()
     return
Example No. 28
def load_model(run):
    # Get our path for the checkpoint setup
    checkpoint_dir = Path("checkpoint").absolute()

    # Start tensorflow session & load our model
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, checkpoint_dir=checkpoint_dir, run_name=run)
    return sess
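Typical usage, assuming a fine-tuned run named "run1" saved under ./checkpoint:

sess = load_model("run1")
print(gpt2.generate(sess, run_name="run1", return_as_list=True)[0])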
Example No. 29
def loader():
    # if 'sess' not in cache:
    #     cache['sess'] = gpt2.start_tf_sess()
    #     gpt2.load_gpt2(cache['sess'], checkpoint_dir='./gpt_2/checkpoint', run_name='run1')
    # return cache['sess']
    sess = gpt2.start_tf_sess()
    gpt2.load_gpt2(sess, checkpoint_dir='../assets/gpt_2/checkpoint', run_name='run1')
    return sess
Example No. 30
def main():

    sess = gpt2.start_tf_sess()

    gpt2.load_gpt2(sess)

    single_text = gpt2.generate(sess, return_as_list=True)[0]
    print(single_text)