예제 #1
0
    def generate_text(self, start_string: str, num_words: int,
                      temperature: float) -> str:
        """Generate text that starts with ``start_string`` using the rebuilt model.

        The default graph is reset and the model is reloaded on every call (i.e. every
        POST request), otherwise the model won't run. The session is always closed
        before returning, even if loading or generation fails.

        :param start_string: str user wishes to start generation with. Can be a single letter.
        :param num_words: passed to ``gpt2.generate`` as ``length``; larger values take longer
            to generate. NOTE(review): ``length`` is presumably counted in model tokens rather
            than words — confirm against gpt_2_simple.
        :param temperature: parameter that determines how 'surprising' the predictions are.
            A value of 1 is neutral, lower is more predictable, higher is more surprising.
        :return: the first generated sample after the CleanOutput post-processing steps.
        """
        load_start_time = datetime.now()

        # A fresh graph is required per request; stale graph state stops the model running.
        tf.compat.v1.reset_default_graph()
        self.sess = gpt2.start_tf_sess()
        try:
            gpt2.load_gpt2(self.sess,
                           run_name=self.model_folder,
                           checkpoint_dir=self.checkpoint_directory)

            time_to_load = datetime.now() - load_start_time
            logging.info('Time taken to load model %s', time_to_load)

            gen_start_time = datetime.now()

            # return_as_list=True yields a list of samples; only the first one is used.
            txt = gpt2.generate(self.sess,
                                run_name=self.model_folder,
                                checkpoint_dir=self.checkpoint_directory,
                                length=num_words,
                                temperature=temperature,
                                prefix=start_string,
                                return_as_list=True)[0]

            time_to_generate = datetime.now() - gen_start_time
            logging.info('Time taken to generate lyrics %s', time_to_generate)
        finally:
            # Close the session on every path; previously an exception in loading or
            # generation skipped the close and leaked the session.
            self.sess.close()

        # Post-processing only touches the string, so the session is no longer needed.
        txt = CleanOutput.remove_lastline(text_in=txt)
        txt = CleanOutput.capitalise_first_character(text_in=txt)
        txt = CleanOutput.clean_line(text_in=txt)
        txt = CleanOutput.sanitise_string(text_in=txt,
                                          custom_badwords=custom_badwords)

        return txt
예제 #2
0
파일: server.py 프로젝트: mxdillon/kanyai
def get_text(text_input: str, num_words: int, generator: GenerateLyrics,
             temperature: float = 1.00) -> str:
    """Generate the lyrics for the text input from the model.

    :param text_input: starting lyric from the input form; may be ``None`` when nothing
        was submitted
    :param num_words: # words to generate, forwarded to ``generator.generate_text``
    :param generator: model generator
    :param temperature: sampling temperature forwarded to ``generator.generate_text``;
        defaults to 1.00, the value this function previously hard-coded
    :return: sanitised lyrics for rendering (str), or a single space when
        ``text_input`` is ``None``
    """
    if text_input is None:
        # No input submitted: render a blank placeholder without invoking the model.
        return ' '

    logging.info('Generating lyrics for %s', text_input)

    logging.debug('ensuring space for %s', text_input)
    start_phrase = CleanOutput.ensure_newline(text_input)

    logging.debug('generating text')
    return generator.generate_text(start_string=start_phrase,
                                   num_words=num_words,
                                   temperature=temperature)
예제 #3
0
def test_remove_lastline(text_in, expected):
    """Check that CleanOutput.remove_lastline maps text_in to the expected output."""
    assert CleanOutput.remove_lastline(text_in=text_in) == expected
예제 #4
0
def test_clean_sentence(text_in, expected):
    """Check that CleanOutput.clean_line maps text_in to the expected cleaned line."""
    assert CleanOutput.clean_line(text_in=text_in) == expected
예제 #5
0
def test_sanitise_string(text_in, expected):
    """Verify sanitise_string redacts the custom bad-word list to give the expected text."""
    sanitised = CleanOutput.sanitise_string(text_in=text_in,
                                            custom_badwords=custom_badwords)
    assert sanitised == expected
예제 #6
0
def test_ensure_newline(text_in, expected):
    """Verify CleanOutput.ensure_newline normalises text_in into the expected output."""
    normalised = CleanOutput.ensure_newline(text_in=text_in)
    assert normalised == expected
예제 #7
0
def test_capitalise_first_character(text_in, expected):
    """Verify the first character is capitalised so the rendered output reads nicely."""
    capitalised = CleanOutput.capitalise_first_character(text_in=text_in)
    assert capitalised == expected