def test_infer_ast_spec(self):
    """Check the spec that `infer_ast_spec` derives from hand-built observations.

    The fixture describes two node types:

    * "root" -- 10 occurrences, all of them counted as roots; the test
      expects `has_parent=False`.
    * "foo" -- 20 occurrences, none counted as a root; the test expects
      `has_parent=True`.

    Each field's observation counts are chosen so that, between the two node
    types, every `FieldType` value named in the expected spec appears once.
    """
    # Short local aliases keep the fixture readable.
    field_obs = ast_spec_inference.FieldObservations
    node_obs = ast_spec_inference.NodeObservations
    node_spec = generic_ast_graphs.ASTNodeSpec
    field_type = generic_ast_graphs.FieldType

    root_observed = node_obs(
        count=10,
        count_root=10,
        fields={
            # 2 + 8 == 10: populated on every occurrence, sometimes with
            # several children.
            "nonempty_sequence": field_obs(count_one=2, count_many=8),
            # Exactly one child on each of the 10 occurrences.
            "one_child": field_obs(count_one=10),
        },
    )
    foo_observed = node_obs(
        count=20,
        count_root=0,
        fields={
            # One child on only 15 of the 20 occurrences.
            "optional_child": field_obs(count_one=15),
            # 4 + 4 == 8 populated occurrences out of 20, some with many.
            "sequence": field_obs(count_many=4, count_one=4),
            # Never observed with any child.
            "no_children": field_obs(),
        },
    )
    observations = ast_spec_inference.ASTObservations(
        example_count=1,
        node_types={"root": root_observed, "foo": foo_observed},
    )

    spec = ast_spec_inference.infer_ast_spec(observations)

    root_expected = node_spec(
        fields={
            "nonempty_sequence": field_type.NONEMPTY_SEQUENCE,
            "one_child": field_type.ONE_CHILD,
        },
        sequence_item_types={"nonempty_sequence": "root_nonempty_sequence"},
        has_parent=False,
    )
    foo_expected = node_spec(
        fields={
            "optional_child": field_type.OPTIONAL_CHILD,
            "sequence": field_type.SEQUENCE,
            "no_children": field_type.NO_CHILDREN,
        },
        sequence_item_types={"sequence": "foo_sequence"},
        has_parent=True,
    )
    expected = {"root": root_expected, "foo": foo_expected}

    self.assertEqual(spec, expected)
def test_observe_types_and_fields(self):
    """Check the observations collected from a small hand-built tree.

    The tree is a "root" node whose "children" field holds two "foo" nodes.
    The "foo" nodes carry fields "a".."d" with zero, one or two "bar" leaves
    each (the second "foo" has no "c" field at all). The expected
    observations tally, per node type, how often each field held exactly one
    child versus more than one.
    """
    ast_node = generic_ast_graphs.GenericASTNode
    field_obs = ast_spec_inference.FieldObservations
    node_obs = ast_spec_inference.NodeObservations

    def bar_leaf(node_id):
        # A "bar" node with no fields of its own.
        return ast_node(node_id, "bar", {})

    first_foo = ast_node(
        1,
        "foo",
        {
            "a": [bar_leaf(12)],
            "b": [bar_leaf(13)],
            "c": [],
            "d": [bar_leaf(14), bar_leaf(15)],
        },
    )
    second_foo = ast_node(
        2,
        "foo",
        {
            "a": [bar_leaf(22)],
            "b": [],
            "d": [bar_leaf(24)],
        },
    )
    tree = ast_node(0, "root", {"children": [first_foo, second_foo]})

    observations = ast_spec_inference.observe_types_and_fields(tree)

    expected_node_types = {
        # The single root holds two children in its one field.
        "root": node_obs(
            count=1,
            count_root=1,
            fields={"children": field_obs(count_many=1)},
        ),
        "foo": node_obs(
            count=2,
            count_root=0,
            fields={
                # One child in both "foo" nodes.
                "a": field_obs(count_one=2),
                # One child in the first node, empty in the second.
                "b": field_obs(count_one=1),
                # Only ever seen empty.
                "c": field_obs(),
                # Two children in the first node, one in the second.
                "d": field_obs(count_one=1, count_many=1),
            },
        ),
        # Leaves 12, 13, 14, 15, 22 and 24.
        "bar": node_obs(count=6, count_root=0, fields={}),
    }
    expected = ast_spec_inference.ASTObservations(
        example_count=1,
        node_types=expected_node_types,
    )

    self.assertEqual(observations, expected)
def test_merge_observations(self):
    """Check that adding two `ASTObservations` merges them as expected.

    Each operand has one node type unique to it and one shared node type
    ("appears_in_both"); the shared node type in turn has one field unique
    to each side and one shared field. The expected result keeps the unique
    entries unchanged and sums every counter of the shared entries, as well
    as `example_count`.
    """
    ast_obs = ast_spec_inference.ASTObservations
    node_obs = ast_spec_inference.NodeObservations
    field_obs = ast_spec_inference.FieldObservations

    observations_a = ast_obs(
        example_count=3,
        node_types={
            "appears_in_a": node_obs(
                count=12,
                count_root=3,
                fields={"foo": field_obs(count_one=4, count_many=5)},
            ),
            "appears_in_both": node_obs(
                count=67,
                count_root=8,
                fields={
                    "field_in_a": field_obs(count_one=1, count_many=2),
                    "field_in_both": field_obs(count_one=3, count_many=4),
                },
            ),
        },
    )
    observations_b = ast_obs(
        example_count=5,
        node_types={
            "appears_in_b": node_obs(
                count=7,
                count_root=5,
                fields={"bar": field_obs(count_one=3, count_many=2)},
            ),
            "appears_in_both": node_obs(
                count=13,
                count_root=10,
                fields={
                    "field_in_b": field_obs(count_one=2, count_many=1),
                    "field_in_both": field_obs(count_one=10, count_many=20),
                },
            ),
        },
    )

    observations_merged = observations_a + observations_b

    # Shared entries: 67 + 13 == 80, 8 + 10 == 18, 3 + 10 == 13, 4 + 20 == 24.
    merged_shared_node = node_obs(
        count=80,
        count_root=18,
        fields={
            "field_in_a": field_obs(count_one=1, count_many=2),
            "field_in_b": field_obs(count_one=2, count_many=1),
            "field_in_both": field_obs(count_one=13, count_many=24),
        },
    )
    expected = ast_obs(
        example_count=8,  # 3 + 5
        node_types={
            "appears_in_a": node_obs(
                count=12,
                count_root=3,
                fields={"foo": field_obs(count_one=4, count_many=5)},
            ),
            "appears_in_b": node_obs(
                count=7,
                count_root=5,
                fields={"bar": field_obs(count_one=3, count_many=2)},
            ),
            "appears_in_both": merged_shared_node,
        },
    )

    self.assertEqual(observations_merged, expected)