def test_default_min_max_length():
    """Check that pinning min_length and max_length to 5 yields 5 tokens.

    The generated text is re-encoded with the model's own tokenizer and
    the length of the first encoded sequence is compared against 5.
    """
    model = HappyTextToText()
    fixed_length_settings = TTSettings(min_length=5, max_length=5)

    generated = model.generate_text(
        "translate English to French: Hello my name is Eric",
        args=fixed_length_settings,
    )

    # Re-encode the output text; index 0 selects the first (only) sequence.
    encoded = model.tokenizer.encode(generated.text, return_tensors="pt")
    token_count = len(encoded[0])

    assert token_count == 5
def example_7_1():
    """Translate one sentence using top-p sampling and print the result.

    Prints the result object first and then only its ``text`` attribute.
    """
    prompt = (
        "translate English to French: "
        "nlp is a field of artificial intelligence"
    )

    # No arguments: the original comment notes this defaults to t5-small.
    translator = HappyTextToText()

    sampling_args = TTSettings(
        do_sample=True,
        top_k=0,
        top_p=0.8,
        temperature=0.7,
        min_length=20,
        max_length=20,
        early_stopping=True,
    )

    translation = translator.generate_text(prompt, args=sampling_args)

    # Sample output recorded by the original author for both prints
    # (do_sample=True, so presumably not reproducible run to run):
    #   nlp est un domaine de l’intelligence artificielle. n
    print(translation)
    print(translation.text)
def test_tt_save():
    """Check that a model reloaded from disk generates the same text.

    Saves a freshly constructed model to ``model/``, generates a
    translation with it, loads a second model from that path and
    asserts both produce identical text for the same prompt.
    """
    save_dir = "model/"
    prompt = "translate English to French: Hello my name is Eric"

    original_model = HappyTextToText()
    original_model.save(save_dir)
    text_before = original_model.generate_text(prompt)

    reloaded_model = HappyTextToText(load_path=save_dir)
    text_after = reloaded_model.generate_text(prompt)

    assert text_before.text == text_after.text
def example_7_2():
    """Translate one sentence with five generation configurations.

    Runs, in order: greedy (with ``no_repeat_ngram_size=2``), beam search,
    generic sampling, top-k sampling and top-p sampling. All generations
    are performed first; the labelled texts are printed afterwards.
    """
    translator = HappyTextToText("T5", "t5-small")

    # NOTE: the trailing space is part of the original prompt.
    prompt = (
        "translate English to French: "
        "nlp is a field of artificial intelligence "
    )

    # (print label, TTSettings keyword arguments) in execution order.
    configurations = (
        ("Greedy:", {"no_repeat_ngram_size": 2, "max_length": 20}),
        ("Beam:", {"num_beams": 5, "max_length": 20}),
        (
            "Generic Sampling:",
            {
                "do_sample": True,
                "top_k": 0,
                "temperature": 0.7,
                "max_length": 20,
            },
        ),
        (
            "Top-k Sampling:",
            {
                "do_sample": True,
                "top_k": 50,
                "temperature": 0.7,
                "max_length": 20,
            },
        ),
        (
            "Top-p Sampling:",
            {
                "do_sample": True,
                "top_k": 0,
                "top_p": 0.8,
                "temperature": 0.7,
                "max_length": 20,
            },
        ),
    )

    # Build each settings object immediately before its generation call so
    # the sequence of constructor and generate_text calls is unchanged.
    labelled_results = []
    for label, options in configurations:
        settings = TTSettings(**options)
        labelled_results.append(
            (label, translator.generate_text(prompt, args=settings))
        )

    # Sample output recorded by the original author:
    #   Greedy: nlp est un domaine de l'intelligence artificielle
    #   Beam: nlp est un domaine de l'intelligence artificielle
    #   Generic Sampling: nlp est un champ d'intelligence artificielle
    #   Top-k Sampling: nlp est un domaine de l’intelligence artificielle
    #   Top-p Sampling: nlp est un domaine de l'intelligence artificielle
    for label, result in labelled_results:
        print(label, result.text)
def example_7_0():
    """Construct a HappyTextToText with explicit "T5" / "t5-small" arguments.

    The original comment marks this as the default configuration. The
    instance is not used further.
    """
    translator = HappyTextToText("T5", "t5-small")
def test_default_simple():
    """Check that default generation returns a result whose text is a str.

    Uses a default-constructed model and no explicit generation settings.
    """
    happy_tt = HappyTextToText()
    output = happy_tt.generate_text(
        "translate English to French: Hello my name is Eric"
    )

    # isinstance is the idiomatic type check (vs. comparing type() objects).
    assert isinstance(output.text, str)
def test_all_methods():
    """Check that every generation configuration returns string text.

    Runs, in order: greedy, beam search, generic sampling, top-k sampling
    and top-p sampling, each with ``min_length`` and ``max_length`` fixed
    at 5. All generations run first, then every output is asserted to be
    a ``str``, and finally each output is printed under its label.
    """
    happy_tt = HappyTextToText()
    prompt = "translate English to French: Hello my name is Eric"

    # (label, TTSettings keyword arguments) in execution order.
    configurations = (
        (
            "greedy",
            {"min_length": 5, "max_length": 5, "no_repeat_ngram_size": 2},
        ),
        (
            "beam-search",
            {
                "min_length": 5,
                "max_length": 5,
                "early_stopping": True,
                "num_beams": 5,
            },
        ),
        (
            "generic-sampling",
            {
                "min_length": 5,
                "max_length": 5,
                "do_sample": True,
                "early_stopping": False,
                "top_k": 0,
                "temperature": 0.7,
            },
        ),
        (
            "top-k-sampling",
            {
                "min_length": 5,
                "max_length": 5,
                "do_sample": True,
                "early_stopping": False,
                "top_k": 50,
                "temperature": 0.7,
            },
        ),
        (
            "top-p-sampling",
            {
                "min_length": 5,
                "max_length": 5,
                "do_sample": True,
                "early_stopping": False,
                "top_k": 0,
                "top_p": 0.8,
                "temperature": 0.7,
            },
        ),
    )

    # Build each settings object right before its generation call so the
    # constructor/generate_text call sequence matches the original.
    outputs = []
    for label, options in configurations:
        settings = TTSettings(**options)
        outputs.append((label, happy_tt.generate_text(prompt, args=settings)))

    # Assert everything before printing anything, as the original did.
    for label, output in outputs:
        assert isinstance(output.text, str), f"{label} text is not a str"

    for label, output in outputs:
        print(f"{label}: ", output.text, end="\n\n")