コード例 #1
0
ファイル: test_seacr.py プロジェクト: AInixProject/AInix
def test_full_number(numbers_type_context):
    """Check rule-based SeaCR predictions against parsed "Number" literals.

    Indexes a few (query, literal) pairs, then verifies that predicting an
    indexed query yields the same AST the parser builds for its literal, and
    that an unrelated query raises ModelCantPredictException.
    """
    number_type = "Number"
    # Build the example index the model will search over.
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    training_pairs = [
        ("one", "1"),
        ("two", "2"),
        ("three", "3"),
        ("ten", "10"),
        ("negative one", "-1"),
    ]
    for query, literal in training_pairs:
        index.add_yset_default_weight([query], [literal],
                                      index.DEFAULT_X_TYPE, number_type)
    # Expected tree for the first query
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("-1", number_type)
    # Predict and compare both the dumped form and the trees themselves
    model = make_rulebased_seacr(index)
    prediction = model.predict("negative one", number_type, False)
    print("Expected")
    print(expected.dump_str())
    print("predicted")
    print(prediction.dump_str())
    assert expected.dump_str() == prediction.dump_str()
    assert expected == prediction
    # A second indexed query
    expected = parser.create_parse_tree("2", number_type)
    prediction = model.predict("two", number_type, False)
    assert expected == prediction
    # A query with nothing similar in the index must be rejected
    with pytest.raises(ModelCantPredictException):
        model.predict("sdsdfas sdfasdf asdf", number_type, False)
コード例 #2
0
def test_word_parts_2():
    """Parse "fooBarBaz" as word parts (partial consume) and unparse it.

    Verifies the parsed chain foo -> bar, the lower-case modifier on the
    first part, and the substring each depth-first pointer unparses to.
    """
    tc = TypeContext()
    _create_root_types(tc)
    word_parts = [('foo', True), ("bar", True), ("fo", True), ("!", False)]
    _create_all_word_parts(tc, word_parts)
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("fooBarBaz",
                                   WORD_PART_TYPE_NAME,
                                   allow_partial_consume=True)
    # First word part and its two argument choices
    first_part = ast.next_node_not_copy
    assert first_part.implementation.name == _name_for_word_part("foo")
    modifier_choice = first_part.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    modifier_object = modifier_choice.next_node_not_copy
    assert modifier_object.implementation.name == MODIFIER_LOWER_NAME
    next_choice = first_part.get_choice_node_for_arg(WORD_PART_NEXT_ARG_NAME)
    second_part = next_choice.next_node_not_copy
    assert second_part.implementation.name == _name_for_word_part("bar")
    # Unparse: only the consumed prefix is expected back
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "fooBar"
    pointers = list(ast.depth_first_iter())
    # (node, expected text) for depth-first positions 1..5
    expected_parts = [
        (first_part, "fooBar"),
        (modifier_choice, "foo"),
        (modifier_object, ""),
        (next_choice, "Bar"),
        (second_part, "Bar"),
    ]
    for position, (node, text) in enumerate(expected_parts, start=1):
        assert get_str_and_assert_same_part(result, pointers[position],
                                            node) == text
コード例 #3
0
def test_string_parse_e2e_pos_unparse(type_context):
    """Round-trip "hello foo" and check the positional arg maps to "foo"."""
    # Register FooType with a single object parsed from the literal "foo".
    AInixType(type_context, "FooType")
    foo_parser_name = create_object_parser_from_grammar(
        type_context, "fooname", '"foo"').name
    AInixObject(type_context,
                "fo",
                "FooType", [],
                preferred_object_parser_name=foo_parser_name)
    # Program invoked as "hello" with one required, single-word positional.
    positional_arg = AInixArgument(type_context,
                                   "p1",
                                   "FooType",
                                   arg_data={
                                       POSITION: 0,
                                       MULTIWORD_POS_ARG: False
                                   },
                                   parent_object_name="bw",
                                   required=True)
    AInixObject(type_context,
                "FooProgram",
                "Program", [positional_arg],
                type_data={"invoke_name": "hello"})
    type_context.finalize_data()

    ast = StringParser(type_context).create_parse_tree("hello foo", "Program")
    unparse_result = AstUnparser(type_context).to_string(ast)
    assert unparse_result.total_string == "hello foo"
    # Find the first choice node (depth-first) that selects a FooType.
    for pointer in ast.depth_first_iter():
        node = pointer.cur_node
        if not isinstance(node, ObjectChoiceNode):
            continue
        if node.type_to_choose.name == "FooType":
            arg_node_pointer = pointer
            break
    assert unparse_result.pointer_to_string(arg_node_pointer) == "foo"
コード例 #4
0
def test_string_parse_e2e_sequence(type_context):
    """Round-trip piped command sequences through parse and unparse.

    The unparsed form is expected to put spaces around the pipe even when
    the input omitted them.
    """
    flag_a = AInixArgument(type_context,
                           "a",
                           None,
                           arg_data={"short_name": "a"},
                           parent_object_name="sdf")
    flag_b = AInixArgument(type_context,
                           "barg",
                           None,
                           arg_data={"short_name": "b"},
                           parent_object_name="bw")
    AInixObject(type_context,
                "FooProgram",
                "Program", [flag_a, flag_b],
                type_data={"invoke_name": "hello"})
    parser = StringParser(type_context)
    unparser = AstUnparser(type_context)
    # (input command, expected unparsed string)
    cases = (
        ("hello -a | hello -b", "hello -a | hello -b"),
        ("hello -a|hello -b", "hello -a | hello -b"),
        ("hello -a | hello", "hello -a | hello"),
    )
    for command, expected in cases:
        ast = parser.create_parse_tree(command, "CommandSequence")
        to_string = unparser.to_string(ast)
        assert to_string.total_string == expected
コード例 #5
0
def test_string_parse_e2e_multiword3(type_context):
    """Round-trip "hello -a" when the program also declares a positional."""
    AInixType(type_context, "FooType")
    foo_parser_name = create_object_parser_from_grammar(
        type_context, "fooname", '"foo"').name
    AInixObject(type_context,
                "fo",
                "FooType", [],
                preferred_object_parser_name=foo_parser_name)
    # Two short flags followed by the shared positional argument helper.
    program_args = [
        AInixArgument(type_context,
                      "a",
                      None,
                      arg_data={"short_name": "a"},
                      parent_object_name="sdf"),
        AInixArgument(type_context,
                      "barg",
                      None,
                      arg_data={"short_name": "b"},
                      parent_object_name="bw"),
        _make_positional(),
    ]
    AInixObject(type_context,
                "FooProgram",
                "Program",
                program_args,
                type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    ast = StringParser(type_context).create_parse_tree("hello -a",
                                                       "CommandSequence")
    to_string = AstUnparser(type_context).to_string(ast)
    assert to_string.total_string == "hello -a"
コード例 #6
0
def test_string_parse_e2e_multiword2(type_context):
    """Check a multiword positional unparses after the flag.

    Input "hello foo bar -a" is expected to come back as "hello -a foo bar".
    """
    AInixType(type_context, "FooType")
    foo_parser_name = create_object_parser_from_grammar(
        type_context, "fooname", '"foo bar"').name
    AInixObject(type_context,
                "fo",
                "FooType", [],
                preferred_object_parser_name=foo_parser_name)
    flag_a = AInixArgument(type_context,
                           "a",
                           None,
                           arg_data={"short_name": "a"},
                           parent_object_name="sdf")
    # Positional argument allowed to span several words
    multiword_positional = AInixArgument(type_context,
                                         "p1",
                                         "FooType",
                                         arg_data={
                                             POSITION: 0,
                                             MULTIWORD_POS_ARG: True
                                         },
                                         parent_object_name="sdf")
    AInixObject(type_context,
                "FooProgram",
                "Program", [flag_a, multiword_positional],
                type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    ast = StringParser(type_context).create_parse_tree("hello foo bar -a",
                                                       "Program")
    to_string = AstUnparser(type_context).to_string(ast)
    assert to_string.total_string == "hello -a foo bar"
コード例 #7
0
def test_max_munch():
    """Parse "foobar" with the max-munch type parser; "foo" must be chosen.

    Candidate objects are "fo", "bar", "f", "foo", "foot" and "baz", each
    parsed by a grammar matching its own name.
    """
    tc = TypeContext()
    loader.load_path("builtin_types/generic_parsers.ainix.yaml",
                     tc,
                     up_search_limit=3)
    foo_type = "MMTestType"
    AInixType(tc, foo_type, default_type_parser_name="max_munch_type_parser")

    def register_literal_object(representation: str):
        """Load an object named and parsed as the literal representation."""
        spec = {
            "name": representation,
            "type": foo_type,
            "preferred_object_parser": {
                "grammar": f'"{representation}"'
            }
        }
        loader._load_object(spec, tc, "foopathsdf")
        loaded = tc.get_object_by_name(representation)
        assert loaded.preferred_object_parser_name is not None

    for rep in ("fo", "bar", "f", "foo", "foot", 'baz'):
        register_literal_object(rep)

    parser = StringParser(tc)
    ast = parser.create_parse_tree("foobar",
                                   foo_type,
                                   allow_partial_consume=True)
    assert ast.next_node_not_copy.implementation.name == "foo"
コード例 #8
0
def test_word_parts_upper():
    """Parse the all-caps word "FOO" and check modifier, terminal, unparse."""
    tc = TypeContext()
    _create_root_types(tc)
    word_parts = [('foo', True), ("bar", True), ("fo", True), ("!", False)]
    _create_all_word_parts(tc, word_parts)
    tc.finalize_data()
    ast = StringParser(tc).create_parse_tree("FOO", WORD_PART_TYPE_NAME)
    # The single word part, its modifier and the terminating next-part
    word_part = ast.next_node_not_copy
    assert word_part.implementation.name == _name_for_word_part("foo")
    modifier_choice = word_part.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    modifier_object = modifier_choice.next_node_not_copy
    assert modifier_object.implementation.name == MODIFIER_ALL_UPPER
    next_choice = word_part.get_choice_node_for_arg(WORD_PART_NEXT_ARG_NAME)
    next_object = next_choice.next_node_not_copy
    assert next_object.implementation.name == WORD_PART_TERMINAL_NAME
    # Unparse and check what the first four depth-first pointers map to
    result = AstUnparser(tc).to_string(ast)
    assert result.total_string == "FOO"
    pointers = list(ast.depth_first_iter())
    expected_by_position = [
        (ast, "FOO"),
        (word_part, "FOO"),
        (modifier_choice, "FOO"),
        (modifier_object, ""),
    ]
    for position, (node, text) in enumerate(expected_by_position):
        assert node == pointers[position].cur_node
        assert result.pointer_to_string(pointers[position]) == text
コード例 #9
0
def test_not_fail_find_expr(all_the_stuff_context, string):
    """Parse `string` as a FindExpression and expect an exact round trip."""
    ast = StringParser(all_the_stuff_context).create_parse_tree(
        string, "FindExpression")
    unparser = AstUnparser(all_the_stuff_context, NonLetterTokenizer())
    unparsed = unparser.to_string(ast)
    assert unparsed.total_string == string
コード例 #10
0
ファイル: test_copy_tools.py プロジェクト: AInixProject/AInix
def test_make_copy_optional_arg():
    """An optional arg whose text appears in the query is marked as a copy.

    "foobar" parses to object "fo" with optional arg1 = "bar"; the query
    "Hello bar sdf cow" contains "bar", so the copy version of the tree is
    expected to have chosen a copy for arg1.
    """
    tc = TypeContext()
    AInixType(tc, "ft")
    AInixType(tc, "bt")
    optional_arg = AInixArgument(tc,
                                 "arg1",
                                 "bt",
                                 required=False,
                                 parent_object_name="fo")
    fo_parser_name = create_object_parser_from_grammar(
        tc, "masfoo_parser", '"foo" arg1?').name
    AInixObject(tc,
                "fo",
                "ft", [optional_arg],
                preferred_object_parser_name=fo_parser_name)
    bo_parser_name = create_object_parser_from_grammar(
        tc, "masdfo_parser", '"bar"').name
    AInixObject(tc,
                "bo",
                "bt",
                None,
                preferred_object_parser_name=bo_parser_name)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    ast = parser.create_parse_tree("foobar", "ft")
    # Only the token metadata of the query is needed below.
    _, metadata = SpaceTokenizer().tokenize("Hello bar sdf cow")
    unpar_res = unparser.to_string(ast)
    assert unpar_res.total_string == "foobar"
    copy_tree = make_copy_version_of_tree(ast, unparser, metadata)
    arg1_choice = copy_tree.next_node_not_copy.get_choice_node_for_arg("arg1")
    assert arg1_choice.copy_was_chosen
コード例 #11
0
def test_path_parse_and_unparse_with_error(tc, in_str):
    """Expect parsing `in_str` as a "Path" to raise AInixParseError."""
    parser = StringParser(tc)
    try:
        parser.create_parse_tree(in_str, "Path")
    except AInixParseError:
        # The parse error is the outcome under test.
        return
    pytest.fail(f"{in_str} unexpectedly worked")
コード例 #12
0
def make_vocab_from_example_store(
    exampe_store: ExamplesStore,
    x_tokenizer: Tokenizer,
    y_tokenizer: Tokenizer,
    x_vocab_builder: typing.Optional[VocabBuilder] = None,
    y_vocab_builder: typing.Optional[VocabBuilder] = None
) -> typing.Tuple[Vocab, Vocab]:
    """Creates an x and y vocab based off all examples in an example store.

    Every example contributes its tokenized x query to the x vocab. Each
    distinct y text is parsed and tokenized only once, no matter how many
    examples share it.

    Args:
        exampe_store: The example store to build from
        x_tokenizer: The tokenizer to use for x queries
        y_tokenizer: The tokenizer to use for y queries
        x_vocab_builder: The builder for the kind of x vocab we want to produce.
            If None, a CounterVocabBuilder(min_freq=1) is used.
        y_vocab_builder: The builder for the kind of y vocab we want to produce.
            If None, a default CounterVocabBuilder() is used.
    Returns:
        Tuple of the new (x vocab, y vocab).
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=1)
    if y_vocab_builder is None:
        y_vocab_builder = CounterVocabBuilder()

    already_done_ys = set()
    parser = StringParser(exampe_store.type_context)
    for example in exampe_store.get_all_x_values():
        x_vocab_builder.add_sequence(x_tokenizer.tokenize(example.xquery)[0])
        if example.ytext in already_done_ys:
            continue
        # Record the y text so later examples sharing it are skipped; the
        # set was previously never filled, which made this guard a no-op.
        already_done_ys.add(example.ytext)
        ast = parser.create_parse_tree(example.ytext, example.ytype)
        y_tokens, _ = y_tokenizer.tokenize(ast)
        y_vocab_builder.add_sequence(y_tokens)
    return x_vocab_builder.produce_vocab(), y_vocab_builder.produce_vocab()
コード例 #13
0
def test_cp(all_the_stuff_context, string):
    """Round-trip `string` as a Program; the root pointer spans all of it."""
    ast = StringParser(all_the_stuff_context).create_parse_tree(
        string, "Program")
    unparser = AstUnparser(all_the_stuff_context, NonLetterTokenizer())
    unparsed = unparser.to_string(ast)
    assert unparsed.total_string == string
    root_pointer = list(ast.depth_first_iter())[0]
    assert unparsed.pointer_to_string(root_pointer) == string
コード例 #14
0
def test_fails(all_the_stuff_context, string):
    """Expect `string` to be rejected when parsed as a CommandSequence."""
    parser = StringParser(all_the_stuff_context)
    try:
        unexpected_ast = parser.create_parse_tree(string, "CommandSequence")
        # Dump the tree that should not exist, to help debug the failure.
        print(unexpected_ast.dump_str())
    except AInixParseError:
        return
    pytest.fail("Expected fail")
コード例 #15
0
ファイル: type_predictor.py プロジェクト: AInixProject/AInix
 def __init__(self, index: ExamplesIndex, comparer: 'Comparer'):
     """Store the index and comparer and set default training settings.

     Args:
         index: Examples index; its type context is also used to build
             the string parser.
         comparer: Comparer object kept on the instance.
     """
     type_context = index.type_context
     # Collaborators
     self.index = index
     self.type_context = type_context
     self.comparer = comparer
     self.parser = StringParser(type_context)
     # Training state; the optimizer starts unset.
     self.prepared_trainers = False
     self.present_pred_criterion = torch.nn.BCEWithLogitsLoss()
     # Sampling / comparison hyperparameters
     self.train_sample_count = 20
     self.train_search_sample_dropout = 0.6
     self.max_examples_to_compare = 10
     self.optimizer = None
コード例 #16
0
ファイル: test_copy_tools.py プロジェクト: AInixProject/AInix
def test_get_copy_paths():
    """The copy version of "TWO foo bar" has one copy, at path (0, 0, 0)."""
    tc = get_toy_strings_context()
    unparser = AstUnparser(tc)
    ast = StringParser(tc).create_parse_tree("TWO foo bar", "ToySimpleStrs")
    unparsed = unparser.to_string(ast)
    assert unparsed.total_string == "TWO foo bar"
    # Only the token metadata of the query is needed.
    _, metadata = SpaceTokenizer().tokenize("Hello there foo cow")
    copy_tree = make_copy_version_of_tree(ast, unparser, metadata)
    assert get_paths_to_all_copies(copy_tree) == ((0, 0, 0), )
コード例 #17
0
def test_word_parts_3():
    """Parsing "fooo.bar" as word parts succeeds and leaves ".bar" over."""
    tc = TypeContext()
    _create_root_types(tc)
    word_parts = [('f', True), ('ooo', True), ("bar", True), ("!", False)]
    _create_all_word_parts(tc, word_parts)
    tc.finalize_data()
    word_part_type = tc.get_type_by_name(WORD_PART_TYPE_NAME)
    parser = StringParser(tc)
    # Call the private choice-node parser directly to see the leftover text.
    _node, parse_data = parser._parse_object_choice_node(
        "fooo.bar", word_part_type.default_type_parser, word_part_type)
    assert parse_data.parse_success
    assert parse_data.remaining_string == ".bar"
コード例 #18
0
ファイル: test_copy_tools.py プロジェクト: AInixProject/AInix
def test_numbers_copys(numbers_type_context):
    """No copy is a known choice for "0" when the query is just "nil"."""
    tc = numbers_type_context
    unparser = AstUnparser(tc)
    ast = StringParser(tc).create_parse_tree("0", "Number")
    _, metadata = NonLetterTokenizer().tokenize("nil")
    ast_set = AstObjectChoiceSet(tc.get_type_by_name("Number"))
    ast_set.add(ast, True, 1, 1)
    add_copies_to_ast_set(ast, ast_set, unparser, metadata)
    assert not ast_set.copy_is_known_choice()
コード例 #19
0
ファイル: test_copy_tools.py プロジェクト: AInixProject/AInix
def test_partial_copy_numbers():
    """Load the builtin number types and parse "1000" as a Number.

    No assertions are visible in this block; it only checks that setup and
    parsing run without raising.
    """
    tc = TypeContext()
    builtin_files = ("builtin_types/generic_parsers.ainix.yaml",
                     "builtin_types/numbers.ainix.yaml")
    for type_file in builtin_files:
        loader.load_path(type_file, tc, up_search_limit=3)
    tc.finalize_data()
    parser = StringParser(tc)
    tokenizer = NonLetterTokenizer()
    unparser = AstUnparser(tc, tokenizer)
    ast = parser.create_parse_tree("1000", "Number")
コード例 #20
0
ファイル: test_seacr.py プロジェクト: AInixProject/AInix
def test_full_number_2(numbers_type_context):
    """With the builtin number examples loaded, "nineth" predicts "9"."""
    # Fill an in-RAM index from the builtin numbers example file.
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    examples_path = f"{BUILTIN_TYPES_PATH}/numbers_examples.ainix.yaml"
    ainix_kernel.indexing.exampleloader.load_path(examples_path, index)
    # Tree the prediction has to equal
    expected = StringParser(numbers_type_context).create_parse_tree(
        "9", "Number")
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("nineth", "Number", False)
    assert expected == prediction
コード例 #21
0
def get_parsable_commands(
        data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Keep only the pairs whose command parses as a "CommandSequence".

    Commands that start with "tar" (after stripping whitespace) are skipped
    without attempting a parse. Each kept command is echoed to stdout as
    "PASS <command>".

    Args:
        data: (natural language, command) pairs to filter.
    Returns:
        The pairs that parsed successfully, in their original order.
    """
    tc = get_a_tc()
    parser = StringParser(tc)
    parsable_data = []
    for nl, cm in data:
        if cm.strip().startswith("tar"):
            continue
        try:
            parser.create_parse_tree(cm, "CommandSequence")
        except AInixParseError:
            # Unparsable commands are exactly what this filter drops.
            continue
        parsable_data.append((nl, cm))
        print(f"PASS {cm}")
    return parsable_data
コード例 #22
0
ファイル: test_seacr.py プロジェクト: AInixProject/AInix
def test_predict_digit(numbers_type_context):
    """Rule-based SeaCR maps "two" to the BaseTen tree parsed from "2"."""
    digit_type = "BaseTen"
    # Index three digit examples
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    for word, digit in (("one", "1"), ("two", "2"), ("three", "3")):
        index.add_yset_default_weight([word], [digit], index.DEFAULT_X_TYPE,
                                      digit_type)
    # Expected tree
    expected = StringParser(numbers_type_context).create_parse_tree(
        "2", digit_type)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("two", digit_type, False)
    assert expected == prediction
コード例 #23
0
ファイル: test_seacr.py プロジェクト: AInixProject/AInix
def test_type_pred_gt_result(numbers_type_context):
    """Check the ground-truth compare result for a valid and an invalid AST.

    The valid set holds only the tree for "9": comparing against that tree
    gives probability 1 with score (1, "nine"); comparing against the tree
    for "6" gives probability 0 and no implementation scores.
    """
    parser = StringParser(numbers_type_context)
    base_ten = numbers_type_context.get_type_by_name("BaseTen")
    create_gt_result = models.SeaCR.comparer._create_gt_compare_result

    nine_ast = parser.create_parse_tree("9", "BaseTen")
    valid_set = AstObjectChoiceSet(base_ten, None)
    valid_set.add(nine_ast, True, 1, 1)
    choose = ObjectChoiceNode(base_ten)
    gt_res = create_gt_result(nine_ast, choose, valid_set)
    assert gt_res.prob_valid_in_example == 1
    assert gt_res.impl_scores == ((1, "nine"), )

    six_ast = parser.create_parse_tree("6", "BaseTen")
    gt_res = create_gt_result(six_ast, choose, valid_set)
    assert gt_res.prob_valid_in_example == 0
    assert gt_res.impl_scores is None
コード例 #24
0
def update_latent_store_from_examples(
        model: 'StringTypeTranslateCF', latent_store: LatentStore,
        examples: ExamplesStore, replacer: Replacer, parser: StringParser,
        splits: Optional[Tuple[DataSplits]], unparser: AstUnparser,
        tokenizer: StringTokenizer):
    """Recompute and store the model's latents for every example.

    The model is put in eval mode for the pass and switched to train mode
    afterwards. For each example the replaced y text is parsed, converted to
    its copy version against the replaced x tokens, and the model's latent
    select states are written to `latent_store`.

    Args:
        model: Model providing `get_latent_select_states`.
        latent_store: Store updated through `set_latent_for_example`.
        examples: Source of examples, queried with `splits`.
        replacer: Applies replacements to each (x, y) string pair.
        parser: Parses the replaced y text into an AST.
        splits: Passed through to `examples.get_all_x_values`.
        unparser: Used when building the copy version of the AST.
        tokenizer: Tokenizes the replaced x query for copy detection.
    """
    model.set_in_eval_mode()
    for example in examples.get_all_x_values(splits):
        # TODO multi sampling and average replacers
        seed = seed_from_x_val(example.xquery)
        x_replaced, y_replaced = replacer.strings_replace(
            example.xquery, example.ytext, seed)
        ast = parser.create_parse_tree(y_replaced, example.ytype)
        token_metadata = tokenizer.tokenize(x_replaced)[1]
        copy_ast = copy_tools.make_copy_version_of_tree(
            ast, unparser, token_metadata)
        # TODO Think about whether feeding in the raw x is good idea.
        # will change once have replacer sampling
        latents = model.get_latent_select_states(example.xquery, copy_ast)
        dfs_pointers = list(copy_ast.depth_first_iter())
        for latent_ind, latent in enumerate(latents):
            # The i-th latent is stored against depth-first position 2 * i,
            # which must hold an ObjectChoiceNode.
            dfs_depth = latent_ind * 2
            choice_node = dfs_pointers[dfs_depth].cur_node
            assert isinstance(choice_node, ObjectChoiceNode)
            detached = latent.detach()
            assert not detached.requires_grad
            latent_store.set_latent_for_example(
                detached, choice_node.type_to_choose.ind, example.id,
                dfs_depth)
    model.set_in_train_mode()
コード例 #25
0
def make_y_ast_set(y_type: AInixType, all_y_examples: List[YValue],
                   replacement_sample: ReplacementSampling,
                   string_parser: StringParser,
                   this_x_metadata: StringTokensMetadata,
                   unparser: AstUnparser):
    """Build the frozen set of valid y ASTs for one x query.

    Each y example has replacements applied; every distinct replaced text is
    parsed once, added to the set with the example's preference, and has its
    copy variants added as well.

    Returns:
        A tuple of (frozen AST set, set of distinct replaced y texts, one
        parsed AST sampled by a WeightedRandomChooser over the example
        preferences).
    """
    y_ast_set = AstObjectChoiceSet(y_type, None)
    y_texts = set()
    parsed_asts = []
    preferences = []
    for y_example in all_y_examples:
        replaced_y = replacement_sample.replace_x(y_example.y_text)
        if replaced_y in y_texts:
            # Same text after replacement as an earlier example.
            continue
        parsed_ast = string_parser.create_parse_tree(replaced_y, y_type.name)
        parsed_asts.append(parsed_ast)
        preferences.append(y_example.y_preference)
        y_ast_set.add(parsed_ast, True, y_example.y_preference, 1.0)
        y_texts.add(replaced_y)
        # handle copies
        # TODO figure how to weight the copy node??
        add_copies_to_ast_set(parsed_ast,
                              y_ast_set,
                              unparser,
                              this_x_metadata,
                              copy_node_weight=1)
    y_ast_set.freeze()
    chooser = WeightedRandomChooser(parsed_asts, preferences)
    teacher_force_path_ast = chooser.sample()
    return y_ast_set, y_texts, teacher_force_path_ast
コード例 #26
0
ファイル: encdecmodel.py プロジェクト: AInixProject/AInix
def get_default_encdec_model(examples: ExamplesStore,
                             standard_size=16,
                             replacer: Replacer = None,
                             use_retrieval_decoder: bool = False,
                             pretrain_checkpoint: str = None,
                             learn_on_top_pretrained: bool = False):
    """Assemble an EncDecModel with default encoder and decoder choices.

    Args:
        examples: Example store; supplies the type context and, when needed,
            the data for the x vocab and the retrieval decoder.
        standard_size: Hidden size used for both encoder and decoder.
        replacer: Passed to vocab construction and the retrieval decoder.
        use_retrieval_decoder: Selects the retrieval decoder instead of the
            non-retrieval one.
        pretrain_checkpoint: When given, word-piece tokenization is used and
            the checkpoint is passed to the query encoder.
        learn_on_top_pretrained: Not used in this function body.
    Returns:
        The constructed EncDecModel.
    """
    use_word_piece = pretrain_checkpoint is not None
    (x_tokenizer, x_vocab), y_tokenizer = get_default_tokenizers(
        use_word_piece=use_word_piece)
    if x_vocab is None or pretrain_checkpoint is None:
        # If we don't have a pretrained_checkpoint, then we might not know all
        # words in the vocab, so should filter the vocab to only tokens in dataset
        x_vocab = vocab.make_x_vocab_from_examples(examples,
                                                   x_tokenizer,
                                                   replacer,
                                                   min_freq=5,
                                                   replace_samples=3)
    hidden_size = standard_size
    tc = examples.type_context
    encoder = make_default_query_encoder(x_tokenizer, x_vocab, hidden_size,
                                         pretrain_checkpoint)
    if use_retrieval_decoder:
        parser = StringParser(tc)
        unparser = AstUnparser(tc, x_tokenizer)
        decoder = decoders.get_default_retrieval_decoder(
            tc, hidden_size, examples, replacer, parser, unparser)
    else:
        decoder = decoders.get_default_nonretrieval_decoder(tc, hidden_size)
    model = EncDecModel(examples.type_context, encoder, decoder)
    if use_retrieval_decoder:
        # TODO lolz, this is such a crappy interface
        def _latent_store_to_train():
            """Return the retrieval decoder's latent store."""
            return decoder.action_selector.latent_store

        model.plz_train_this_latent_store_thanks = _latent_store_to_train
    return model
コード例 #27
0
ファイル: encdecmodel.py プロジェクト: AInixProject/AInix
    def create_from_save_state_dict(
        cls,
        state_dict: dict,
        new_type_context: TypeContext,
        new_example_store: ExamplesStore,
    ) -> 'EncDecModel':
        """Rebuild a model from a saved state dict.

        Args:
            state_dict: Saved state; the keys read here are 'name',
                'query_encoder', 'decoder' and 'need_latent_train'.
            new_type_context: Type context for the parser, unparser,
                decoder and model.
            new_example_store: Example store handed to the decoder and, if
                latent training is needed, used to refill the latent store.
        Returns:
            The restored model.
        """
        # TODO (DNGros): acutally handle the new type context.

        # TODO check the name of the query encoder
        # I don't think this is right??
        if state_dict['name'] == "EncoderDecoder":
            encoder_cls = encoders.StringQueryEncoder
        else:
            encoder_cls = PretrainPoweredQueryEncoder
        query_encoder = encoder_cls.create_from_save_state_dict(
            state_dict['query_encoder'])
        parser = StringParser(new_type_context)
        unparser = AstUnparser(new_type_context, query_encoder.get_tokenizer())
        replacers = get_all_replacers()
        decoder = TreeRNNDecoder.create_from_save_state_dict(
            state_dict['decoder'], new_type_context, new_example_store,
            replacers, parser, unparser)
        model = cls(type_context=new_type_context,
                    query_encoder=query_encoder,
                    tree_decoder=decoder)
        if not state_dict['need_latent_train']:
            return model
        # Refill the decoder's latent store from the new example store;
        # None is passed as the splits argument.
        update_latent_store_from_examples(
            model, decoder.action_selector.latent_store,
            new_example_store, replacers, parser, None, unparser,
            query_encoder.get_tokenizer())
        return model
コード例 #28
0
ファイル: test_seacr.py プロジェクト: AInixProject/AInix
def test_digit_list_2(numbers_type_context):
    """Rule-based SeaCR maps "twenty" to the IntBase tree parsed from "20"."""
    int_type = "IntBase"
    # Index three multi-digit examples
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    for word, digits in (("ten", "10"), ("twenty", "20"), ("thirty", "30")):
        index.add_yset_default_weight([word], [digits], index.DEFAULT_X_TYPE,
                                      int_type)
    # Expected tree
    expected = StringParser(numbers_type_context).create_parse_tree(
        "20", int_type)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("twenty", int_type, False)
    print(expected.dump_str())
    print(prediction.dump_str())
    assert expected == prediction
コード例 #29
0
ファイル: test_seacr.py プロジェクト: AInixProject/AInix
def test_only_one_option(base_type_context):
    """With one object for the type, prediction equals the parsed tree."""
    # Register a type with exactly one implementation
    foo_type = AInixType(base_type_context, "FooType")
    AInixObject(base_type_context, "FooObj", "FooType")
    base_type_context.finalize_data()
    # Index a single example for that type
    index = ExamplesIndex(base_type_context,
                          ExamplesIndex.get_default_ram_backend())
    index.add_yset_default_weight(["example"], ["y"], index.DEFAULT_X_TYPE,
                                  "FooType")
    # Expected tree, parsed from an arbitrary string
    expected = StringParser(base_type_context).create_parse_tree(
        "string", foo_type.name)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("example", "FooType", False)
    assert expected == prediction
コード例 #30
0
def test_touch_set(all_the_stuff_context):
    """Regression test: copy variants of "touch out.txt" stay valid.

    Checks that the plain AST, its copy version, and a hand-built tree whose
    file list is a CopyNode are all recognised by the choice set.
    """
    tc = all_the_stuff_context
    x_str = 'set the last mod time of out.txt to now'
    string = "touch out.txt"
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast, x_str)
    assert result.total_string == string

    # A freshly parsed, equal tree must be known valid in the set.
    cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
    cset.add(ast, True, 1, 1)
    reparsed_ast = parser.create_parse_tree(string, "Program")
    assert cset.is_node_known_valid(reparsed_ast)

    # After adding copies, both the copy version and the original are valid.
    _, tok_metadata = NonLetterTokenizer().tokenize(x_str)
    ast_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)
    add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
    assert cset.is_node_known_valid(ast_copies)
    assert cset.is_node_known_valid(ast)

    # Scary complicated reconstruction of something that broke it.
    # could be made into a simpler unit test in copy_tools
    def not_present(arg):
        """Build the choice node selecting the arg's not-present object."""
        return ObjectChoiceNode(arg.present_choice_type,
                                ObjectNode(arg.not_present_object, pmap()))

    touch_o = tc.get_object_by_name("touch")
    file_list = tc.get_type_by_name("PathList")
    r_arg = touch_o.get_arg_by_name("r")
    m_arg = touch_o.get_arg_by_name("m")
    program_type = tc.get_type_by_name("Program")
    # NOTE(review): 12 and 14 are presumably the token span of "out.txt"
    # within x_str -- confirm against NonLetterTokenizer output.
    touch_children = pmap({
        "r": not_present(r_arg),
        "m": not_present(m_arg),
        "file_list": ObjectChoiceNode(file_list, CopyNode(file_list, 12, 14)),
    })
    other_copy = ObjectChoiceNode(program_type,
                                  ObjectNode(touch_o, touch_children))
    other_result = unparser.to_string(other_copy, x_str)
    assert other_result.total_string == string
    assert cset.is_node_known_valid(other_copy)