def test_full_number(numbers_type_context):
    # Create an index
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    type_name = "Number"
    x_y = [("one", "1"), ("two", "2"), ("three", "3"),
           ("ten", "10"), ("negative one", "-1")]
    for x, y in x_y:
        index.add_yset_default_weight([x], [y], index.DEFAULT_X_TYPE, type_name)
    # Create an expected value
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("-1", type_name)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("negative one", type_name, False)
    print("Expected")
    print(expected.dump_str())
    print("predicted")
    print(prediction.dump_str())
    assert expected.dump_str() == prediction.dump_str()
    assert expected == prediction

    expected = parser.create_parse_tree("2", type_name)
    prediction = model.predict("two", type_name, False)
    assert expected == prediction

    with pytest.raises(ModelCantPredictException):
        model.predict("sdsdfas sdfasdf asdf", type_name, False)

def test_word_parts_2():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(
        tc, [('foo', True), ("bar", True), ("fo", True), ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree(
        "fooBarBaz", WORD_PART_TYPE_NAME, allow_partial_consume=True)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_LOWER_NAME
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == _name_for_word_part("bar")
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "fooBar"
    pointers = list(ast.depth_first_iter())
    assert get_str_and_assert_same_part(result, pointers[1], word_part_o) == "fooBar"
    assert get_str_and_assert_same_part(result, pointers[2], mod_type_choice) == "foo"
    assert get_str_and_assert_same_part(result, pointers[3], mod_type_object) == ""
    assert get_str_and_assert_same_part(result, pointers[4], next_type_choice) == "Bar"
    assert get_str_and_assert_same_part(result, pointers[5], next_part_o) == "Bar"

def test_string_parse_e2e_pos_unparse(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context, "fo", "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "p1", "FooType",
                       arg_data={POSITION: 0, MULTIWORD_POS_ARG: False},
                       parent_object_name="bw", required=True)],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo", "Program")
    unparser = AstUnparser(type_context)
    unparse_result = unparser.to_string(ast)
    assert unparse_result.total_string == "hello foo"
    for p in ast.depth_first_iter():
        n = p.cur_node
        if isinstance(n, ObjectChoiceNode) and n.type_to_choose.name == "FooType":
            arg_node_pointer = p
            break
    assert unparse_result.pointer_to_string(arg_node_pointer) == "foo"

def test_string_parse_e2e_sequence(type_context):
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "a", None, arg_data={"short_name": "a"},
                       parent_object_name="sdf"),
         AInixArgument(type_context, "barg", None, arg_data={"short_name": "b"},
                       parent_object_name="bw")],
        type_data={"invoke_name": "hello"})
    parser = StringParser(type_context)
    unparser = AstUnparser(type_context)
    string = "hello -a | hello -b"
    ast = parser.create_parse_tree(string, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string

    no_space = "hello -a|hello -b"
    ast = parser.create_parse_tree(no_space, "CommandSequence")
    to_string = unparser.to_string(ast)
    # Unparsing normalizes the missing spaces around the pipe
    assert to_string.total_string == string

    string = "hello -a | hello"
    ast = parser.create_parse_tree(string, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string

def test_string_parse_e2e_multiword3(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context, "fo", "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "a", None, arg_data={"short_name": "a"},
                       parent_object_name="sdf"),
         AInixArgument(type_context, "barg", None, arg_data={"short_name": "b"},
                       parent_object_name="bw"),
         _make_positional()],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello -a", "CommandSequence")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == "hello -a"

def test_string_parse_e2e_multiword2(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context, "fo", "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo bar"').name)
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "a", None, arg_data={"short_name": "a"},
                       parent_object_name="sdf"),
         AInixArgument(type_context, "p1", "FooType",
                       arg_data={POSITION: 0, MULTIWORD_POS_ARG: True},
                       parent_object_name="sdf")],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo bar -a", "Program")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    # The unparser emits the -a flag before the multiword positional arg,
    # so the round trip normalizes the argument order.
    assert to_string.total_string == "hello -a foo bar"

def test_max_munch():
    tc = TypeContext()
    loader.load_path("builtin_types/generic_parsers.ainix.yaml", tc,
                     up_search_limit=3)
    foo_type = "MMTestType"
    AInixType(tc, foo_type, default_type_parser_name="max_munch_type_parser")

    def make_mock_with_parse_rep(representation: str):
        loader._load_object(
            {
                "name": representation,
                "type": foo_type,
                "preferred_object_parser": {
                    "grammar": f'"{representation}"'
                }
            },
            tc, "foopathsdf")
        assert tc.get_object_by_name(
            representation).preferred_object_parser_name is not None

    for rep in ("fo", "bar", "f", "foo", "foot", "baz"):
        make_mock_with_parse_rep(rep)
    parser = StringParser(tc)
    ast = parser.create_parse_tree("foobar", foo_type, allow_partial_consume=True)
    # Max munch should choose the longest object matching a prefix of
    # "foobar" ("foo", not "f" or "fo"; "foot" does not match).
    assert ast.next_node_not_copy.implementation.name == "foo"

def test_word_parts_upper():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(
        tc, [('foo', True), ("bar", True), ("fo", True), ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("FOO", WORD_PART_TYPE_NAME)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_ALL_UPPER
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == WORD_PART_TERMINAL_NAME
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "FOO"
    pointers = list(ast.depth_first_iter())
    assert ast == pointers[0].cur_node
    assert result.pointer_to_string(pointers[0]) == "FOO"
    assert word_part_o == pointers[1].cur_node
    assert result.pointer_to_string(pointers[1]) == "FOO"
    assert mod_type_choice == pointers[2].cur_node
    assert result.pointer_to_string(pointers[2]) == "FOO"
    assert mod_type_object == pointers[3].cur_node
    assert result.pointer_to_string(pointers[3]) == ""

def test_not_fail_find_expr(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "FindExpression")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast)
    assert result.total_string == string

def test_make_copy_optional_arg():
    tc = TypeContext()
    ft = AInixType(tc, "ft")
    bt = AInixType(tc, "bt")
    arg1 = AInixArgument(tc, "arg1", "bt", required=False,
                         parent_object_name="fo")
    fo = AInixObject(
        tc, "fo", "ft", [arg1],
        preferred_object_parser_name=create_object_parser_from_grammar(
            tc, "masfoo_parser", '"foo" arg1?').name)
    bo = AInixObject(
        tc, "bo", "bt", None,
        preferred_object_parser_name=create_object_parser_from_grammar(
            tc, "masdfo_parser", '"bar"').name)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    ast = parser.create_parse_tree("foobar", "ft")
    tokenizer = SpaceTokenizer()
    in_str = "Hello bar sdf cow"
    tokens, metadata = tokenizer.tokenize(in_str)
    unpar_res = unparser.to_string(ast)
    assert unpar_res.total_string == "foobar"
    result = make_copy_version_of_tree(ast, unparser, metadata)
    # "bar" appears in the tokenized input, so the optional arg's value
    # should be replaced with a copy node.
    assert result.next_node_not_copy.get_choice_node_for_arg(
        "arg1").copy_was_chosen

def test_path_parse_and_unparse_with_error(tc, in_str):
    parser = StringParser(tc)
    with pytest.raises(AInixParseError):
        parser.create_parse_tree(in_str, "Path")

def make_vocab_from_example_store(
    example_store: ExamplesStore,
    x_tokenizer: Tokenizer,
    y_tokenizer: Tokenizer,
    x_vocab_builder: VocabBuilder = None,
    y_vocab_builder: VocabBuilder = None
) -> typing.Tuple[Vocab, Vocab]:
    """Creates an x and y vocab based off all examples in an example store.

    Args:
        example_store: The example store to build from.
        x_tokenizer: The tokenizer to use for x queries.
        y_tokenizer: The tokenizer to use for the y side. It is applied to
            parsed ASTs, not raw y strings.
        x_vocab_builder: The builder for the kind of x vocab we want to
            produce. If None, it just picks a reasonable default.
        y_vocab_builder: The builder for the kind of y vocab we want to
            produce. If None, it just picks a reasonable default.

    Returns:
        Tuple of the new (x vocab, y vocab).
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=1)
    if y_vocab_builder is None:
        y_vocab_builder = CounterVocabBuilder()
    already_done_ys = set()
    parser = StringParser(example_store.type_context)
    for example in example_store.get_all_x_values():
        x_vocab_builder.add_sequence(x_tokenizer.tokenize(example.xquery)[0])
        if example.ytext not in already_done_ys:
            ast = parser.create_parse_tree(example.ytext, example.ytype)
            y_tokens, _ = y_tokenizer.tokenize(ast)
            y_vocab_builder.add_sequence(y_tokens)
            # Record the y text so duplicate y values are only parsed once
            already_done_ys.add(example.ytext)
    return x_vocab_builder.produce_vocab(), y_vocab_builder.produce_vocab()

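# A minimal usage sketch of make_vocab_from_example_store. This is a hedged,
# hypothetical example (not part of the original module): it assumes `store`
# is an already-populated ExamplesStore and that the passed tokenizers match
# the signature above. Note the y tokenizer must accept ASTs, since the
# function calls y_tokenizer.tokenize(ast) on parse trees, not raw strings.
def _example_build_vocabs(store: ExamplesStore,
                          x_tokenizer: Tokenizer,
                          y_tokenizer: Tokenizer) -> typing.Tuple[Vocab, Vocab]:
    # Both builders default to CounterVocabBuilder, so token frequencies
    # across all stored examples determine the produced vocabs.
    x_vocab, y_vocab = make_vocab_from_example_store(
        store, x_tokenizer, y_tokenizer)
    return x_vocab, y_vocab
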
def test_cp(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast)
    assert result.total_string == string
    pointers = list(ast.depth_first_iter())
    assert result.pointer_to_string(pointers[0]) == string

def test_fails(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    try:
        ast = parser.create_parse_tree(string, "CommandSequence")
        print(ast.dump_str())
        pytest.fail("Expected parse to fail")
    except AInixParseError:
        pass

def __init__(self, index: ExamplesIndex, comparer: 'Comparer'):
    self.index = index
    self.type_context = index.type_context
    self.comparer = comparer
    self.parser = StringParser(self.type_context)
    self.prepared_trainers = False
    self.present_pred_criterion = torch.nn.BCEWithLogitsLoss()
    self.train_sample_count = 20
    self.train_search_sample_dropout = 0.6
    self.max_examples_to_compare = 10
    self.optimizer = None

def test_get_copy_paths():
    tc = get_toy_strings_context()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    ast = parser.create_parse_tree("TWO foo bar", "ToySimpleStrs")
    unpar_res = unparser.to_string(ast)
    assert unpar_res.total_string == "TWO foo bar"
    tokenizer = SpaceTokenizer()
    in_str = "Hello there foo cow"
    tokens, metadata = tokenizer.tokenize(in_str)
    result = make_copy_version_of_tree(ast, unparser, metadata)
    assert get_paths_to_all_copies(result) == ((0, 0, 0),)

def test_word_parts_3():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(
        tc, [('f', True), ('ooo', True), ("bar", True), ("!", False)])
    tc.finalize_data()
    word_part_type = tc.get_type_by_name(WORD_PART_TYPE_NAME)
    parser = StringParser(tc)
    node, data = parser._parse_object_choice_node(
        "fooo.bar", word_part_type.default_type_parser, word_part_type)
    assert data.parse_success
    assert data.remaining_string == ".bar"

def test_numbers_copies(numbers_type_context):
    tc = numbers_type_context
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    ast = parser.create_parse_tree("0", "Number")
    tokenizer = NonLetterTokenizer()
    in_str = "nil"
    tokens, metadata = tokenizer.tokenize(in_str)
    ast_set = AstObjectChoiceSet(tc.get_type_by_name("Number"))
    ast_set.add(ast, True, 1, 1)
    add_copies_to_ast_set(ast, ast_set, unparser, metadata)
    # "0" never appears in the input "nil", so no copy should be possible
    assert not ast_set.copy_is_known_choice()

def test_partial_copy_numbers():
    tc = TypeContext()
    loader.load_path("builtin_types/generic_parsers.ainix.yaml", tc,
                     up_search_limit=3)
    loader.load_path("builtin_types/numbers.ainix.yaml", tc, up_search_limit=3)
    tc.finalize_data()
    parser = StringParser(tc)
    tokenizer = NonLetterTokenizer()
    unparser = AstUnparser(tc, tokenizer)
    # Currently only checks that parsing succeeds without raising
    ast = parser.create_parse_tree("1000", "Number")

def test_full_number_2(numbers_type_context):
    # Create an index
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    ainix_kernel.indexing.exampleloader.load_path(
        f"{BUILTIN_TYPES_PATH}/numbers_examples.ainix.yaml", index)
    # Create an expected value
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("9", "Number")
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("nineth", "Number", False)
    assert expected == prediction

def get_parsable_commands(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    tc = get_a_tc()
    parser = StringParser(tc)
    parsable_data = []
    for nl, cm in data:
        try:
            if cm.strip().startswith("tar"):
                continue
            ast = parser.create_parse_tree(cm, "CommandSequence")
            parsable_data.append((nl, cm))
            print(f"PASS {cm}")
        except AInixParseError:
            pass
    return parsable_data

def test_predict_digit(numbers_type_context):
    # Create an index
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    x_y = [("one", "1"), ("two", "2"), ("three", "3")]
    for x, y in x_y:
        index.add_yset_default_weight([x], [y], index.DEFAULT_X_TYPE, "BaseTen")
    # Create an expected value
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("2", "BaseTen")
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("two", "BaseTen", False)
    assert expected == prediction

def test_type_pred_gt_result(numbers_type_context):
    parser = StringParser(numbers_type_context)
    ast = parser.create_parse_tree("9", "BaseTen")
    type_ = numbers_type_context.get_type_by_name("BaseTen")
    valid_set = AstObjectChoiceSet(type_, None)
    valid_set.add(ast, True, 1, 1)
    choose = ObjectChoiceNode(type_)
    gt_res = models.SeaCR.comparer._create_gt_compare_result(
        ast, choose, valid_set)
    assert gt_res.prob_valid_in_example == 1
    assert gt_res.impl_scores == ((1, "nine"),)

    ast = parser.create_parse_tree("6", "BaseTen")
    gt_res = models.SeaCR.comparer._create_gt_compare_result(
        ast, choose, valid_set)
    assert gt_res.prob_valid_in_example == 0
    assert gt_res.impl_scores is None

def update_latent_store_from_examples(
        model: 'StringTypeTranslateCF',
        latent_store: LatentStore,
        examples: ExamplesStore,
        replacer: Replacer,
        parser: StringParser,
        splits: Optional[Tuple[DataSplits]],
        unparser: AstUnparser,
        tokenizer: StringTokenizer):
    model.set_in_eval_mode()
    for example in examples.get_all_x_values(splits):
        # TODO multi sampling and average replacers
        x_replaced, y_replaced = replacer.strings_replace(
            example.xquery, example.ytext, seed_from_x_val(example.xquery))
        ast = parser.create_parse_tree(y_replaced, example.ytype)
        _, token_metadata = tokenizer.tokenize(x_replaced)
        copy_ast = copy_tools.make_copy_version_of_tree(
            ast, unparser, token_metadata)
        # TODO Think about whether feeding in the raw x is a good idea.
        # Will change once we have replacer sampling.
        latents = model.get_latent_select_states(example.xquery, copy_ast)
        nodes = list(copy_ast.depth_first_iter())
        for i, latent in enumerate(latents):
            # Latents are produced per ObjectChoiceNode, which this code
            # assumes appear at every other position of the depth-first
            # traversal (alternating with ObjectNodes), hence the * 2.
            dfs_depth = i * 2
            node = nodes[dfs_depth].cur_node
            assert isinstance(node, ObjectChoiceNode)
            detached = latent.detach()
            assert not detached.requires_grad
            latent_store.set_latent_for_example(
                detached, node.type_to_choose.ind, example.id, dfs_depth)
    model.set_in_train_mode()

def make_y_ast_set(
        y_type: AInixType,
        all_y_examples: List[YValue],
        replacement_sample: ReplacementSampling,
        string_parser: StringParser,
        this_x_metadata: StringTokensMetadata,
        unparser: AstUnparser):
    y_ast_set = AstObjectChoiceSet(y_type, None)
    y_texts = set()
    individual_asts = []
    individual_asts_preferences = []
    for y_example in all_y_examples:
        replaced_y = replacement_sample.replace_x(y_example.y_text)
        if replaced_y not in y_texts:
            parsed_ast = string_parser.create_parse_tree(replaced_y, y_type.name)
            individual_asts.append(parsed_ast)
            individual_asts_preferences.append(y_example.y_preference)
            y_ast_set.add(parsed_ast, True, y_example.y_preference, 1.0)
            y_texts.add(replaced_y)
            # Handle copies
            # TODO figure out how to weight the copy node??
            add_copies_to_ast_set(parsed_ast, y_ast_set, unparser,
                                  this_x_metadata, copy_node_weight=1)
    y_ast_set.freeze()
    teacher_force_path_ast = WeightedRandomChooser(
        individual_asts, individual_asts_preferences).sample()
    return y_ast_set, y_texts, teacher_force_path_ast

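# Hedged sketch (not original code) of how the returned y_ast_set is
# typically consumed: it mirrors the pattern exercised in tests such as
# test_touch_set below, where candidate ASTs are checked against the set.
# `candidate_text` is a hypothetical y string for illustration.
def _example_check_candidate(y_ast_set: AstObjectChoiceSet,
                             string_parser: StringParser,
                             y_type: AInixType,
                             candidate_text: str) -> bool:
    candidate = string_parser.create_parse_tree(candidate_text, y_type.name)
    # True when the candidate matches one of the added ASTs or an
    # equivalent copy-node version added via add_copies_to_ast_set.
    return y_ast_set.is_node_known_valid(candidate)
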
def get_default_encdec_model(examples: ExamplesStore, standard_size=16,
                             replacer: Replacer = None,
                             use_retrieval_decoder: bool = False,
                             pretrain_checkpoint: str = None,
                             learn_on_top_pretrained: bool = False):
    (x_tokenizer, x_vocab), y_tokenizer = get_default_tokenizers(
        use_word_piece=pretrain_checkpoint is not None)
    if x_vocab is None or pretrain_checkpoint is None:
        # If we don't have a pretrained checkpoint, then we might not know all
        # words in the vocab, so we should filter the vocab to only tokens in
        # the dataset.
        x_vocab = vocab.make_x_vocab_from_examples(
            examples, x_tokenizer, replacer, min_freq=5, replace_samples=3)
    hidden_size = standard_size
    tc = examples.type_context
    encoder = make_default_query_encoder(x_tokenizer, x_vocab, hidden_size,
                                         pretrain_checkpoint)
    if not use_retrieval_decoder:
        decoder = decoders.get_default_nonretrieval_decoder(tc, hidden_size)
    else:
        parser = StringParser(tc)
        unparser = AstUnparser(tc, x_tokenizer)
        decoder = decoders.get_default_retrieval_decoder(
            tc, hidden_size, examples, replacer, parser, unparser)
    model = EncDecModel(examples.type_context, encoder, decoder)
    if use_retrieval_decoder:
        # TODO lolz, this is such a crappy interface
        model.plz_train_this_latent_store_thanks = \
            lambda: decoder.action_selector.latent_store
    return model

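# Hypothetical usage sketch (assumptions: `examples` is a populated
# ExamplesStore and `replacer` a configured Replacer; argument names follow
# the signature above, but the values shown are illustrative, not defaults
# confirmed anywhere else in the project):
#
#   model = get_default_encdec_model(
#       examples, standard_size=64, replacer=replacer,
#       use_retrieval_decoder=False)
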
def create_from_save_state_dict(
        cls,
        state_dict: dict,
        new_type_context: TypeContext,
        new_example_store: ExamplesStore,
) -> 'EncDecModel':
    # TODO (DNGros): actually handle the new type context.
    # TODO check the name of the query encoder
    if state_dict['name'] == "EncoderDecoder":
        # I don't think this is right??
        query_encoder = encoders.StringQueryEncoder.create_from_save_state_dict(
            state_dict['query_encoder'])
    else:
        query_encoder = PretrainPoweredQueryEncoder.create_from_save_state_dict(
            state_dict['query_encoder'])
    parser = StringParser(new_type_context)
    unparser = AstUnparser(new_type_context, query_encoder.get_tokenizer())
    replacers = get_all_replacers()
    decoder = TreeRNNDecoder.create_from_save_state_dict(
        state_dict['decoder'], new_type_context, new_example_store,
        replacers, parser, unparser)
    model = cls(type_context=new_type_context,
                query_encoder=query_encoder,
                tree_decoder=decoder)
    if state_dict['need_latent_train']:
        update_latent_store_from_examples(
            model, decoder.action_selector.latent_store, new_example_store,
            replacers, parser, None, unparser, query_encoder.get_tokenizer())
    return model

def test_digit_list_2(numbers_type_context):
    # Create an index
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    type_name = "IntBase"
    x_y = [("ten", "10"), ("twenty", "20"), ("thirty", "30")]
    for x, y in x_y:
        index.add_yset_default_weight([x], [y], index.DEFAULT_X_TYPE, type_name)
    # Create an expected value
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("20", type_name)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("twenty", type_name, False)
    print(expected.dump_str())
    print(prediction.dump_str())
    assert expected == prediction

def test_only_one_option(base_type_context):
    # Create some types
    foo_type = AInixType(base_type_context, "FooType")
    AInixObject(base_type_context, "FooObj", "FooType")
    base_type_context.finalize_data()
    # Create an index
    index = ExamplesIndex(base_type_context,
                          ExamplesIndex.get_default_ram_backend())
    index.add_yset_default_weight(
        ["example"], ["y"], index.DEFAULT_X_TYPE, "FooType")
    # Create an expected value
    parser = StringParser(base_type_context)
    expected = parser.create_parse_tree("string", foo_type.name)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("example", "FooType", False)
    assert expected == prediction

def test_touch_set(all_the_stuff_context):
    x_str = 'set the last mod time of out.txt to now'
    tc = all_the_stuff_context
    parser = StringParser(tc)
    string = "touch out.txt"
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast, x_str)
    assert result.total_string == string

    cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
    cset.add(ast, True, 1, 1)
    new_ast = parser.create_parse_tree(string, "Program")
    assert cset.is_node_known_valid(new_ast)

    tokenizer = NonLetterTokenizer()
    _, tok_metadata = tokenizer.tokenize(x_str)
    ast_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)
    add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
    assert cset.is_node_known_valid(ast_copies)
    assert cset.is_node_known_valid(ast)

    # Scary complicated reconstruction of something that broke it.
    # Could be made into a simpler unit test in copy_tools.
    touch_o = tc.get_object_by_name("touch")
    file_list = tc.get_type_by_name("PathList")
    r_arg = touch_o.get_arg_by_name("r")
    m_arg = touch_o.get_arg_by_name("m")
    other_copy = ObjectChoiceNode(
        tc.get_type_by_name("Program"),
        ObjectNode(
            touch_o,
            pmap({
                "r": ObjectChoiceNode(
                    r_arg.present_choice_type,
                    ObjectNode(r_arg.not_present_object, pmap())),
                "m": ObjectChoiceNode(
                    m_arg.present_choice_type,
                    ObjectNode(m_arg.not_present_object, pmap())),
                "file_list": ObjectChoiceNode(
                    file_list, CopyNode(file_list, 12, 14))
            })))
    other_result = unparser.to_string(other_copy, x_str)
    assert other_result.total_string == string
    assert cset.is_node_known_valid(other_copy)