def test_learn_hello_world_tree_larger(self):
    """Learn a tree from the larger hello-world grammar and inspect its pruned form.

    The tree is learned from every string generated by
    ``self.hello_world_and_world_adjective``. After pruning redundant
    abstractions, the root must have exactly the two expected child
    templates, and the leaves must flatten back to the original dataset.
    """
    lattice_learner = TemplateLatticeLearner(
        minimal_variables=True,
        words_per_leaf_slot=2,
    )
    sentences = list(self.hello_world_and_world_adjective.generate_all_string())

    learned = lattice_learner.learn(sentences)
    print(template_tree_visualiser.render_tree_string(learned))

    pruned = learned.prune_redundant_abstractions()
    rendered_pruned = template_tree_visualiser.render_tree_string(pruned)
    print("pruned\n", rendered_pruned)

    # Only two templates in the top
    actual_top = {child.get_template() for child in pruned.get_children()}
    expected_top = {
        Template.from_string("The [SLOT] is [SLOT]"),
        Template.from_string("[SLOT], [SLOT]!"),
    }
    self.assertEqual(expected_top, actual_top)

    # Every leaf, flattened to text, must be one of the input sentences
    # and every input sentence must appear as a leaf.
    leaf_strings = {
        leaf.get_template().to_flat_string()
        for leaf in pruned.get_descendant_leaves()
    }
    self.assertEqual(set(sentences), leaf_strings)
def test_3_line_learner(self):
    """Check the tree learned from three two-word sentences.

    The expected tree has a ``[SLOT]`` root with two children,
    ``[SLOT] world`` and ``hello [SLOT]``, each holding its two matching
    sentences as leaves.
    """
    lattice_learner = TemplateLatticeLearner(minimal_variables=True)
    sentences = ["hello world", "hi world", "hello universe"]
    learned = lattice_learner.learn(sentences)

    def leaves(texts):
        # One childless tree per literal sentence, in the given order.
        return [TemplateTree(Template.from_string(text)) for text in texts]

    world_branch = TemplateTree(
        Template.from_string("[SLOT] world"),
        leaves(["hello world", "hi world"]),
    )
    hello_branch = TemplateTree(
        Template.from_string("hello [SLOT]"),
        leaves(["hello world", "hello universe"]),
    )
    expected = TemplateTree(
        Template.from_string("[SLOT]"),
        [world_branch, hello_branch],
    )

    print(template_tree_visualiser.render_tree_string(learned))
    self.assertEqual(expected, learned)
def check_same_tree_learned(self, learner: TemplateTreeLearner, dataset: Collection[str], trials: int = 20):
    """Assert that ``learner`` learns the same tree whatever the input order.

    The tree learned from ``dataset`` in its given order is the reference.
    The data is then shuffled ``trials`` times, and each re-learned tree
    must equal the reference.

    Args:
        learner: object whose ``learn`` method turns a collection of
            strings into a tree.
        dataset: the strings to learn from. It is never modified, and any
            collection (list, tuple, set, ...) is accepted.
        trials: number of shuffled orderings to try.
    """
    first_tree = learner.learn(dataset)

    # random.shuffle works in place, so shuffle a private copy. Shuffling
    # `dataset` itself would silently reorder the caller's list and would
    # raise TypeError for immutable collections such as tuples.
    shuffled = list(dataset)

    for i in range(trials):
        random.shuffle(shuffled)
        other_tree = learner.learn(shuffled)
        if first_tree == other_tree:
            continue
        # Rendering both trees is only useful for the failure report, so
        # build the message on mismatch instead of on every trial.
        self.assertEqual(
            first_tree,
            other_tree,
            "Non-equal trees " + str(i) + ":\n" +
            render_tree_string(first_tree) + "\n" +
            render_tree_string(other_tree),
        )
def test_learn_hello_world_tree(self):
    """Smoke test: learn a tree from the small hello-world grammar.

    The collapsed tree is only printed. Nothing is asserted, so this test
    fails only if learning, collapsing or rendering raises.
    """
    lattice_learner = TemplateLatticeLearner(minimal_variables=True)
    sentences = list(self.hello_world_small.generate_all_string())
    learned = lattice_learner.learn(sentences)

    collapsed = learned.collapse()
    rendered = template_tree_visualiser.render_tree_string(collapsed)
    print(rendered)
def test_disallow_empty_string_simple_2(self):
    """Check learning with ``allow_empty_string=False`` plus an unrelated sentence.

    The three "He likes ... cats" sentences are expected under a
    ``He likes [SLOT]`` node, with the slot-free "He likes cats" kept as a
    sibling of ``He likes [SLOT] cats``. The unrelated sentence hangs
    directly off the root.
    """
    sentences = [
        "He likes cute cats",
        "He likes nice cats",
        "He likes cats",
        "This is another sentence",
    ]
    lattice_learner = TemplateLatticeLearner(
        minimal_variables=True,
        allow_empty_string=False,
    )
    learned = lattice_learner.learn(sentences)

    def tree(text, children=None):
        # Leaves are built with the single-argument constructor form.
        template = Template.from_string(text)
        if children is None:
            return TemplateTree(template)
        return TemplateTree(template, children)

    adjective_cats = tree(
        "He likes [SLOT] cats",
        [tree(text) for text in ["He likes cute cats", "He likes nice cats"]],
    )
    he_likes = tree("He likes [SLOT]", [adjective_cats, tree("He likes cats")])
    expected = tree("[SLOT]", [he_likes, tree("This is another sentence")])

    print(template_tree_visualiser.render_tree_string(learned))
    self.assertEqual(expected, learned)
def test_2_line_learner(self):
    """Check the tree learned from two sentences that differ in their first word.

    The root template must be ``[SLOT] world`` and its children must be
    the two input sentences as leaves.
    """
    lattice_learner = TemplateLatticeLearner(minimal_variables=True)
    sentences = ["hello world", "hi world"]
    learned = lattice_learner.learn(sentences)

    expected_top_template = Template.from_string("[SLOT] world")
    expected_leaves = []
    for sentence in sentences:
        expected_leaves.append(TemplateTree(Template.from_string(sentence)))
    expected = TemplateTree(expected_top_template, expected_leaves)

    print(template_tree_visualiser.render_tree_string(learned))

    # Check the root template on its own first, then the whole tree.
    self.assertEqual(expected_top_template, learned.get_template())
    self.assertEqual(expected, learned)
def test_disallow_empty_string_hard(self):
    """Check learning with ``allow_empty_string=False`` on two sentence families.

    Each family (the "He likes ... cats" one and the "I saw her on the
    ... hill" one) is expected to form its own subtree under the
    ``[SLOT]`` root. In each, the slot-free sentence sits beside the
    template that has a slot for the adjective.
    """
    sentences = [
        "I saw her on the quiet hill",
        "I saw her on the tall hill",
        "I saw her on the hill",
        "He likes cute cats",
        "He likes nice cats",
        "He likes cats",
    ]
    lattice_learner = TemplateLatticeLearner(
        minimal_variables=True,
        allow_empty_string=False,
    )
    learned = lattice_learner.learn(sentences)

    def leaf(text):
        return TemplateTree(Template.from_string(text))

    def branch(text, children):
        return TemplateTree(Template.from_string(text), children)

    cats_family = branch(
        "He likes [SLOT]",
        [
            branch(
                "He likes [SLOT] cats",
                [leaf("He likes cute cats"), leaf("He likes nice cats")],
            ),
            leaf("He likes cats"),
        ],
    )
    hill_family = branch(
        "I saw her on the [SLOT]",
        [
            branch(
                "I saw her on the [SLOT] hill",
                [
                    leaf("I saw her on the tall hill"),
                    leaf("I saw her on the quiet hill"),
                ],
            ),
            leaf("I saw her on the hill"),
        ],
    )
    expected = branch("[SLOT]", [cats_family, hill_family])

    print(template_tree_visualiser.render_tree_string(learned))
    self.assertEqual(expected, learned)
def test_4_line_learner_longer_second(self):
    """Check the tree learned when the second part may span two words.

    With ``words_per_leaf_slot=2``, the four sentences are expected to be
    grouped under four templates below the ``[SLOT]`` root: one per
    shared ending and one per shared greeting.
    """
    lattice_learner = TemplateLatticeLearner(
        minimal_variables=True,
        words_per_leaf_slot=2,
    )
    sentences = [
        "hello world",
        "hi world",
        "hello solar system",
        "hi solar system",
    ]
    learned = lattice_learner.learn(sentences)

    # (template text, sentences expected as its leaves), in child order.
    expected_groups = [
        ("[SLOT] world", ["hello world", "hi world"]),
        ("[SLOT] solar system", ["hello solar system", "hi solar system"]),
        ("hello [SLOT]", ["hello world", "hello solar system"]),
        ("hi [SLOT]", ["hi world", "hi solar system"]),
    ]
    expected_children = [
        TemplateTree(
            Template.from_string(pattern),
            [TemplateTree(Template.from_string(member)) for member in members],
        )
        for pattern, members in expected_groups
    ]
    expected = TemplateTree(Template.from_string("[SLOT]"), expected_children)

    print(template_tree_visualiser.render_tree_string(learned))
    self.assertEqual(expected, learned)
def test_disallow_empty_string_simple(self):
    """Verify that learning with ``allow_empty_string=False`` works.

    "I am a human" must stay a separate leaf next to the
    ``I am a [SLOT] human`` template. It must not be absorbed as that
    template with an empty slot.
    """
    lattice_learner = TemplateLatticeLearner(
        minimal_variables=True,
        allow_empty_string=False,
    )
    sentences = ["I am a human", "I am a nice human", "I am a bad human"]
    learned = lattice_learner.learn(sentences)

    adjective_leaves = [
        TemplateTree(Template.from_string(text))
        for text in ["I am a nice human", "I am a bad human"]
    ]
    adjective_branch = TemplateTree(
        Template.from_string("I am a [SLOT] human"),
        adjective_leaves,
    )
    plain_leaf = TemplateTree(Template.from_string("I am a human"))
    expected = TemplateTree(
        Template.from_string("I am a [SLOT]"),
        [adjective_branch, plain_leaf],
    )

    print(template_tree_visualiser.render_tree_string(learned))
    self.assertEqual(expected, learned)
def log_tree(text, template_tree):
    """Print ``text`` and, on the following line, the rendered ``template_tree``."""
    rendered = render_tree_string(template_tree)
    print(text, rendered, sep="\n")