Ejemplo n.º 1
0
def test_full_number(numbers_type_context):
    """End-to-end: a rule-based SeaCR model maps number words ("one",
    "negative one", ...) onto their parsed "Number" ASTs."""
    # Create an index (RAM-backed) and seed it with word -> digit examples.
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    # Renamed from `type`, which shadowed the builtin.
    type_name = "Number"
    x_y = [("one", "1"), ("two", "2"), ("three", "3"), ("ten", "10"),
           ("negative one", "-1")]
    for x, y in x_y:
        index.add_yset_default_weight([x], [y], index.DEFAULT_X_TYPE,
                                      type_name)
    # Create an expected value
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("-1", type_name)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("negative one", type_name, False)
    print("Expected")
    print(expected.dump_str())
    print("predicted")
    print(prediction.dump_str())
    assert expected.dump_str() == prediction.dump_str()
    assert expected == prediction
    # A second query should also round-trip through the model.
    expected = parser.create_parse_tree("2", type_name)
    prediction = model.predict("two", type_name, False)
    assert expected == prediction
    # Nonsense input must raise rather than silently predict something.
    with pytest.raises(ModelCantPredictException):
        model.predict("sdsdfas sdfasdf asdf", type_name, False)
Ejemplo n.º 2
0
def test_string_parse_e2e_sequence(type_context):
    """Command sequences ("a | b") round-trip through parse + unparse,
    and a pipe without spaces unparses back to the spaced form."""
    AInixObject(
        type_context,
        "FooProgram",
        "Program",
        [
            AInixArgument(type_context, "a", None,
                          arg_data={"short_name": "a"},
                          parent_object_name="sdf"),
            AInixArgument(type_context, "barg", None,
                          arg_data={"short_name": "b"},
                          parent_object_name="bw")
        ],
        type_data={"invoke_name": "hello"})
    parser = StringParser(type_context)
    unparser = AstUnparser(type_context)

    def round_trip(source):
        # Parse then immediately unparse, returning the produced string.
        tree = parser.create_parse_tree(source, "CommandSequence")
        return unparser.to_string(tree).total_string

    spaced = "hello -a | hello -b"
    assert round_trip(spaced) == spaced

    # Unparsing normalizes "a|b" (no spaces) to the spaced canonical form.
    assert round_trip("hello -a|hello -b") == spaced

    single_flag = "hello -a | hello"
    assert round_trip(single_flag) == single_flag
Ejemplo n.º 3
0
def test_not_fail_find_expr(all_the_stuff_context, string):
    """A FindExpression must parse and unparse back to the exact input."""
    context = all_the_stuff_context
    tree = StringParser(context).create_parse_tree(string, "FindExpression")
    unparse_result = AstUnparser(context, NonLetterTokenizer()).to_string(tree)
    assert unparse_result.total_string == string
Ejemplo n.º 4
0
def update_latent_store_from_examples(
        model: 'StringTypeTranslateCF', latent_store: LatentStore,
        examples: ExamplesStore, replacer: Replacer, parser: StringParser,
        splits: Optional[Tuple[DataSplits]], unparser: AstUnparser,
        tokenizer: StringTokenizer):
    """Recomputes and stores the model's latent select states for every
    example in the given splits.

    Args:
        model: Model queried (in eval mode) for latent select states.
        latent_store: Destination, keyed by (type ind, example id, dfs depth).
        examples: Store providing the x values to process.
        replacer: Applies deterministic (seeded) string replacements.
        parser: Parses the replaced y text into an AST.
        splits: Which data splits to include; presumably None means all --
            TODO confirm against ExamplesStore.get_all_x_values.
        unparser: Used when building the copy-annotated version of the tree.
        tokenizer: Tokenizes the replaced x string for copy metadata.
    """
    model.set_in_eval_mode()
    for example in examples.get_all_x_values(splits):
        # TODO multi sampling and average replacers
        # Seeding from the x value keeps replacements deterministic per example.
        x_replaced, y_replaced = replacer.strings_replace(
            example.xquery, example.ytext, seed_from_x_val(example.xquery))
        ast = parser.create_parse_tree(y_replaced, example.ytype)
        _, token_metadata = tokenizer.tokenize(x_replaced)
        copy_ast = copy_tools.make_copy_version_of_tree(
            ast, unparser, token_metadata)
        # TODO Think about whether feeding in the raw x is good idea.
        # will change once have replacer sampling
        latents = model.get_latent_select_states(example.xquery, copy_ast)
        nodes = list(copy_ast.depth_first_iter())
        #print("LATENTS", latents)
        for i, l in enumerate(latents):
            # Latents line up with every other node in depth-first order
            # (the ObjectChoiceNodes), hence the *2 stride.
            dfs_depth = i * 2
            n = nodes[dfs_depth].cur_node
            assert isinstance(n, ObjectChoiceNode)
            # Detach so the stored latent carries no autograd history.
            c = l.detach()
            assert not c.requires_grad
            latent_store.set_latent_for_example(c, n.type_to_choose.ind,
                                                example.id, dfs_depth)
    model.set_in_train_mode()
Ejemplo n.º 5
0
def test_path_parse_and_unparse_with_error(tc, in_str):
    """Strings that are not valid "Path" values must raise AInixParseError."""
    parser = StringParser(tc)
    # pytest.raises replaces the original try / pytest.fail / except pattern
    # and removes the unused `ast` and `e` bindings.
    with pytest.raises(AInixParseError):
        parser.create_parse_tree(in_str, "Path")
Ejemplo n.º 6
0
def test_string_parse_e2e_multiword3(type_context):
    """With a positional arg registered, "hello -a" still parses and
    unparses back to itself as a CommandSequence."""
    AInixType(type_context, "FooType")
    # Register an implementation of FooType recognized by the literal "foo".
    AInixObject(
        type_context,
        "fo",
        "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    AInixObject(
        type_context,
        "FooProgram",
        "Program",
        [
            AInixArgument(type_context, "a", None,
                          arg_data={"short_name": "a"},
                          parent_object_name="sdf"),
            AInixArgument(type_context, "barg", None,
                          arg_data={"short_name": "b"},
                          parent_object_name="bw"),
            _make_positional()
        ],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    source = "hello -a"
    tree = StringParser(type_context).create_parse_tree(
        source, "CommandSequence")
    assert AstUnparser(type_context).to_string(tree).total_string == source
Ejemplo n.º 7
0
def test_word_parts_2():
    """Parses camel-cased "fooBarBaz" into chained word parts and checks
    that unparse pointers map each AST node back to its substring."""
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(tc, [('foo', True), ("bar", True), ("fo", True),
                                ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    # "Baz" is not a registered word part, so only "fooBar" is consumed.
    ast = parser.create_parse_tree("fooBarBaz",
                                   WORD_PART_TYPE_NAME,
                                   allow_partial_consume=True)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    # "foo" is lower case, so the lower-case modifier is chosen.
    assert mod_type_object.implementation.name == MODIFIER_LOWER_NAME
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == _name_for_word_part("bar")
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "fooBar"
    # Each depth-first pointer should unparse to exactly its own span.
    pointers = list(ast.depth_first_iter())
    assert get_str_and_assert_same_part(result, pointers[1],
                                        word_part_o) == "fooBar"
    assert get_str_and_assert_same_part(result, pointers[2],
                                        mod_type_choice) == "foo"
    assert get_str_and_assert_same_part(result, pointers[3],
                                        mod_type_object) == ""
    assert get_str_and_assert_same_part(result, pointers[4],
                                        next_type_choice) == "Bar"
    assert get_str_and_assert_same_part(result, pointers[5],
                                        next_part_o) == "Bar"
Ejemplo n.º 8
0
def test_word_parts_upper():
    """Parses the all-caps word "FOO": the all-upper modifier and the
    terminal next-part should be chosen, and pointers map to substrings."""
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(tc, [('foo', True), ("bar", True), ("fo", True),
                                ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("FOO", WORD_PART_TYPE_NAME)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    # "FOO" is fully upper case -> the all-upper modifier applies.
    assert mod_type_object.implementation.name == MODIFIER_ALL_UPPER
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    # Nothing follows "FOO", so the terminal word part ends the chain.
    assert next_part_o.implementation.name == WORD_PART_TERMINAL_NAME
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "FOO"
    pointers = list(ast.depth_first_iter())
    assert ast == pointers[0].cur_node
    assert result.pointer_to_string(pointers[0]) == "FOO"
    assert word_part_o == pointers[1].cur_node
    assert result.pointer_to_string(pointers[1]) == "FOO"
    assert mod_type_choice == pointers[2].cur_node
    assert result.pointer_to_string(pointers[2]) == "FOO"
    assert mod_type_object == pointers[3].cur_node
    # The modifier object itself consumes no characters of the string.
    assert result.pointer_to_string(pointers[3]) == ""
Ejemplo n.º 9
0
def make_vocab_from_example_store(
    exampe_store: ExamplesStore,
    x_tokenizer: Tokenizer,
    y_tokenizer: Tokenizer,
    x_vocab_builder: VocabBuilder = None,
    y_vocab_builder: VocabBuilder = None
) -> typing.Tuple[Vocab, Vocab]:
    """Creates an x and y vocab based off all examples in an example store.

    Args:
        exampe_store: The example store to build from
        x_tokenizer: The tokenizer to use for x queries
        y_tokenizer: The tokenizer to use for y queries
        x_vocab_builder: The builder for the kind of x vocab we want to produce.
            If None, it just picks a reasonable default.
        y_vocab_builder: The builder for the kind of y vocab we want to produce.
            If None, it just picks a reasonable default.
    Returns:
        Tuple of the new (x vocab, y vocab).
    """
    if x_vocab_builder is None:
        x_vocab_builder = CounterVocabBuilder(min_freq=1)
    if y_vocab_builder is None:
        y_vocab_builder = CounterVocabBuilder()

    already_done_ys = set()
    parser = StringParser(exampe_store.type_context)
    for example in exampe_store.get_all_x_values():
        x_vocab_builder.add_sequence(x_tokenizer.tokenize(example.xquery)[0])
        if example.ytext not in already_done_ys:
            # BUG FIX: record the ytext so each distinct y string is parsed
            # and counted only once; the set was checked but never populated,
            # making the dedup guard a no-op.
            already_done_ys.add(example.ytext)
            ast = parser.create_parse_tree(example.ytext, example.ytype)
            y_tokens, _ = y_tokenizer.tokenize(ast)
            y_vocab_builder.add_sequence(y_tokens)
    return x_vocab_builder.produce_vocab(), y_vocab_builder.produce_vocab()
Ejemplo n.º 10
0
def test_type_pred_gt_result(numbers_type_context):
    """_create_gt_compare_result scores 1.0 (with impl scores) for an AST in
    the valid set and 0.0 (impl_scores None) for one that is not."""
    parser = StringParser(numbers_type_context)
    ast = parser.create_parse_tree("9", "BaseTen")
    # Renamed from `type`, which shadowed the builtin.
    base_ten_type = numbers_type_context.get_type_by_name("BaseTen")
    valid_set = AstObjectChoiceSet(base_ten_type, None)
    valid_set.add(ast, True, 1, 1)
    choose = ObjectChoiceNode(base_ten_type)
    gt_res = models.SeaCR.comparer._create_gt_compare_result(
        ast, choose, valid_set)
    assert gt_res.prob_valid_in_example == 1
    assert gt_res.impl_scores == ((1, "nine"), )
    # "6" was never added to the valid set, so it must not match.
    ast = parser.create_parse_tree("6", "BaseTen")
    gt_res = models.SeaCR.comparer._create_gt_compare_result(
        ast, choose, valid_set)
    assert gt_res.prob_valid_in_example == 0
    assert gt_res.impl_scores is None
Ejemplo n.º 11
0
def test_max_munch():
    """The max-munch type parser must pick the longest literal that matches
    the input ("foo"), not a shorter prefix ("f"/"fo")."""
    tc = TypeContext()
    loader.load_path("builtin_types/generic_parsers.ainix.yaml",
                     tc,
                     up_search_limit=3)
    foo_type = "MMTestType"
    AInixType(tc, foo_type, default_type_parser_name="max_munch_type_parser")

    def make_mock_with_parse_rep(representation: str):
        """Registers an object of foo_type parsed by a literal grammar."""
        loader._load_object(
            {
                "name": representation,
                "type": foo_type,
                "preferred_object_parser": {
                    "grammar": f'"{representation}"'
                }
            }, tc, "foopathsdf")
        assert tc.get_object_by_name(
            representation).preferred_object_parser_name is not None

    # Plain loop: the original used a list comprehension purely for its side
    # effects, building an unused list of Nones.
    for rep in ("fo", "bar", "f", "foo", "foot", 'baz'):
        make_mock_with_parse_rep(rep)

    parser = StringParser(tc)
    ast = parser.create_parse_tree("foobar",
                                   foo_type,
                                   allow_partial_consume=True)
    # "foot" does not match "foobar", so "foo" is the longest valid munch.
    assert ast.next_node_not_copy.implementation.name == "foo"
Ejemplo n.º 12
0
def test_string_parse_e2e_pos_unparse(type_context):
    """The unparse pointer for a positional arg's choice node must map back
    to exactly its substring ("foo") of the command."""
    AInixType(type_context, "FooType")
    AInixObject(
        type_context,
        "fo",
        "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    AInixObject(type_context,
                "FooProgram",
                "Program", [
                    AInixArgument(type_context,
                                  "p1",
                                  "FooType",
                                  arg_data={
                                      POSITION: 0,
                                      MULTIWORD_POS_ARG: False
                                  },
                                  parent_object_name="bw",
                                  required=True)
                ],
                type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo", "Program")
    unparser = AstUnparser(type_context)
    unparse_result = unparser.to_string(ast)
    assert unparse_result.total_string == "hello foo"
    # BUG FIX: `arg_node_pointer` was unbound (NameError) if the loop never
    # hit `break`; initialize it and assert the node was actually found so a
    # failure reads as a test failure, not a crash.
    arg_node_pointer = None
    for p in ast.depth_first_iter():
        n = p.cur_node
        if (isinstance(n, ObjectChoiceNode)
                and n.type_to_choose.name == "FooType"):
            arg_node_pointer = p
            break
    assert arg_node_pointer is not None, "no FooType ObjectChoiceNode in AST"
    assert unparse_result.pointer_to_string(arg_node_pointer) == "foo"
Ejemplo n.º 13
0
def test_string_parse_e2e_multiword2(type_context):
    """A multiword positional arg parses wherever it appears, but unparsing
    emits the canonical ordering (flag before the positional words)."""
    fooType = AInixType(type_context, "FooType")
    # Implementation of FooType recognized by the two-word literal "foo bar".
    fo = AInixObject(
        type_context,
        "fo",
        "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo bar"').name)
    twoargs = AInixObject(type_context,
                          "FooProgram",
                          "Program", [
                              AInixArgument(type_context,
                                            "a",
                                            None,
                                            arg_data={"short_name": "a"},
                                            parent_object_name="sdf"),
                              AInixArgument(type_context,
                                            "p1",
                                            "FooType",
                                            arg_data={
                                                POSITION: 0,
                                                MULTIWORD_POS_ARG: True
                                            },
                                            parent_object_name="sdf")
                          ],
                          type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo bar -a", "Program")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    # NOTE: intentionally not byte-identical to the input -- the "-a" flag
    # is emitted before the positional words in the canonical unparse.
    assert to_string.total_string == "hello -a foo bar"
Ejemplo n.º 14
0
def make_y_ast_set(y_type: AInixType, all_y_examples: List[YValue],
                   replacement_sample: ReplacementSampling,
                   string_parser: StringParser,
                   this_x_metadata: StringTokensMetadata,
                   unparser: AstUnparser):
    """Builds the set of valid y ASTs for a single x value.

    Parses every (replacement-substituted) y text, adds each parse and its
    copy-augmented variants to an AstObjectChoiceSet, then samples one AST
    (weighted by y preference) as the teacher-forcing path.

    Args:
        y_type: Root type of the y ASTs.
        all_y_examples: All y values associated with this x value.
        replacement_sample: Substitutes sampled replacement values into y.
        string_parser: Parses the replaced y text.
        this_x_metadata: Token metadata of the x string, used to find copies.
        unparser: Needed by add_copies_to_ast_set to locate copyable spans.
    Returns:
        Tuple of (frozen y AST set, set of distinct replaced y texts,
        one teacher-force AST sampled from the individual parses).
    """
    y_ast_set = AstObjectChoiceSet(y_type, None)
    y_texts = set()
    individual_asts = []
    individual_asts_preferences = []
    for y_example in all_y_examples:
        replaced_y = replacement_sample.replace_x(y_example.y_text)
        # Only process each distinct replaced y string once.
        if replaced_y not in y_texts:
            parsed_ast = string_parser.create_parse_tree(
                replaced_y, y_type.name)
            individual_asts.append(parsed_ast)
            individual_asts_preferences.append(y_example.y_preference)
            y_ast_set.add(parsed_ast, True, y_example.y_preference, 1.0)
            y_texts.add(replaced_y)
            # handle copies
            # TODO figure how to weight the copy node??
            add_copies_to_ast_set(parsed_ast,
                                  y_ast_set,
                                  unparser,
                                  this_x_metadata,
                                  copy_node_weight=1)
    y_ast_set.freeze()
    teacher_force_path_ast = WeightedRandomChooser(
        individual_asts, individual_asts_preferences).sample()
    return y_ast_set, y_texts, teacher_force_path_ast
Ejemplo n.º 15
0
def test_make_copy_optional_arg():
    """make_copy_version_of_tree replaces an optional arg's value with a
    copy node when its unparsed text ("bar") appears in the input tokens."""
    tc = TypeContext()
    ft = AInixType(tc, "ft")
    bt = AInixType(tc, "bt")
    # Optional arg of type bt, parsed via the grammar's `arg1?`.
    arg1 = AInixArgument(tc,
                         "arg1",
                         "bt",
                         required=False,
                         parent_object_name="fo")
    fo = AInixObject(
        tc,
        "fo",
        "ft", [arg1],
        preferred_object_parser_name=create_object_parser_from_grammar(
            tc, "masfoo_parser", '"foo" arg1?').name)
    bo = AInixObject(
        tc,
        "bo",
        "bt",
        None,
        preferred_object_parser_name=create_object_parser_from_grammar(
            tc, "masdfo_parser", '"bar"').name)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    ast = parser.create_parse_tree("foobar", "ft")
    tokenizer = SpaceTokenizer()
    # "bar" occurs in the input string, so arg1's value is copyable.
    in_str = "Hello bar sdf cow"
    tokens, metadata = tokenizer.tokenize(in_str)
    unpar_res = unparser.to_string(ast)
    assert unpar_res.total_string == "foobar"
    result = make_copy_version_of_tree(ast, unparser, metadata)
    assert result.next_node_not_copy.get_choice_node_for_arg(
        "arg1").copy_was_chosen
Ejemplo n.º 16
0
def test_touch_set(all_the_stuff_context):
    """End-to-end checks of AstObjectChoiceSet membership, including copy
    nodes, for the command "touch out.txt"."""
    x_str = 'set the last mod time of out.txt to now'
    tc = all_the_stuff_context
    parser = StringParser(tc)
    string = "touch out.txt"
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast, x_str)
    assert result.total_string == string

    cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
    cset.add(ast, True, 1, 1)
    # A freshly parsed identical AST must also be recognized as valid.
    new_ast = parser.create_parse_tree(string, "Program")
    assert cset.is_node_known_valid(new_ast)

    tokenizer = NonLetterTokenizer()
    _, tok_metadata = tokenizer.tokenize(x_str)
    ast_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)
    add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
    # After adding copies, both the copy-augmented and plain trees are valid.
    assert cset.is_node_known_valid(ast_copies)
    assert cset.is_node_known_valid(ast)

    # Scary complicated reconstruction of something that broke it.
    # could be made into a simpler unit test in copy_tools
    touch_o = tc.get_object_by_name("touch")
    file_list = tc.get_type_by_name("PathList")
    r_arg = touch_o.get_arg_by_name("r")
    m_arg = touch_o.get_arg_by_name("m")
    # Hand-built tree: -r and -m absent; file_list copied from x token span
    # 12..14 (presumably the "out.txt" tokens -- verify against tokenizer).
    other_copy = ObjectChoiceNode(
        tc.get_type_by_name("Program"),
        ObjectNode(
            touch_o,
            pmap({
                "r":
                ObjectChoiceNode(r_arg.present_choice_type,
                                 ObjectNode(r_arg.not_present_object, pmap())),
                "m":
                ObjectChoiceNode(m_arg.present_choice_type,
                                 ObjectNode(m_arg.not_present_object, pmap())),
                "file_list":
                ObjectChoiceNode(file_list, CopyNode(file_list, 12, 14))
            })))
    other_result = unparser.to_string(other_copy, x_str)
    assert other_result.total_string == string
    assert cset.is_node_known_valid(other_copy)
Ejemplo n.º 17
0
def test_fails(all_the_stuff_context, string):
    """Strings invalid as a CommandSequence must raise AInixParseError."""
    tc = all_the_stuff_context
    parser = StringParser(tc)
    # pytest.raises replaces the original try / print / pytest.fail / except
    # pattern; an unexpected successful parse now fails with a clear message.
    with pytest.raises(AInixParseError):
        parser.create_parse_tree(string, "CommandSequence")
Ejemplo n.º 18
0
def test_cp(all_the_stuff_context, string):
    """A Program string round-trips through parse/unparse, and the root
    pointer maps back to the entire string."""
    context = all_the_stuff_context
    tree = StringParser(context).create_parse_tree(string, "Program")
    unparse_res = AstUnparser(context, NonLetterTokenizer()).to_string(tree)
    assert unparse_res.total_string == string
    # The first depth-first pointer is the root; it covers the whole string.
    root_pointer = next(iter(tree.depth_first_iter()))
    assert unparse_res.pointer_to_string(root_pointer) == string
Ejemplo n.º 19
0
def test_numbers_copys(numbers_type_context):
    """No copy should become a known choice when the AST's text ("0") does
    not occur anywhere in the input string ("nil")."""
    context = numbers_type_context
    unparser = AstUnparser(context)
    parsed = StringParser(context).create_parse_tree("0", "Number")
    _, metadata = NonLetterTokenizer().tokenize("nil")
    choice_set = AstObjectChoiceSet(context.get_type_by_name("Number"))
    choice_set.add(parsed, True, 1, 1)
    add_copies_to_ast_set(parsed, choice_set, unparser, metadata)
    assert not choice_set.copy_is_known_choice()
Ejemplo n.º 20
0
def test_get_copy_paths():
    """get_paths_to_all_copies reports the path of the single copied span
    ("foo" is the only AST text present in the input string)."""
    tc = get_toy_strings_context()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    source = "TWO foo bar"
    tree = parser.create_parse_tree(source, "ToySimpleStrs")
    assert unparser.to_string(tree).total_string == source
    _, metadata = SpaceTokenizer().tokenize("Hello there foo cow")
    copy_tree = make_copy_version_of_tree(tree, unparser, metadata)
    assert get_paths_to_all_copies(copy_tree) == ((0, 0, 0), )
Ejemplo n.º 21
0
def test_full_number_2(numbers_type_context):
    """A model trained on the builtin numbers examples predicts the AST for
    9 given the query "nineth"."""
    # Index seeded from the shipped numbers example file.
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    ainix_kernel.indexing.exampleloader.load_path(
        f"{BUILTIN_TYPES_PATH}/numbers_examples.ainix.yaml", index)
    expected = StringParser(numbers_type_context).create_parse_tree(
        "9", "Number")
    model = make_rulebased_seacr(index)
    assert expected == model.predict("nineth", "Number", False)
Ejemplo n.º 22
0
def test_partial_copy_numbers():
    """Loads the number types and parses "1000".

    NOTE(review): nothing is asserted, and `tokenizer`/`unparser`/`ast` go
    unused -- this test looks unfinished; as written it only verifies that
    parsing "1000" does not raise.
    """
    tc = TypeContext()
    loader.load_path(f"builtin_types/generic_parsers.ainix.yaml",
                     tc,
                     up_search_limit=3)
    loader.load_path(f"builtin_types/numbers.ainix.yaml",
                     tc,
                     up_search_limit=3)
    tc.finalize_data()
    parser = StringParser(tc)
    tokenizer = NonLetterTokenizer()
    unparser = AstUnparser(tc, tokenizer)
    ast = parser.create_parse_tree("1000", "Number")
Ejemplo n.º 23
0
def get_parsable_commands(data: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Filters (natural-language, command) pairs down to those whose command
    parses as a CommandSequence.

    Args:
        data: Pairs of (natural language description, shell command). The
            original annotation `Tuple[str, str]` was wrong: the loop
            unpacks each element, so this is a sequence of pairs.
    Returns:
        The subset of pairs whose command parsed successfully. Commands
        starting with "tar" are skipped outright -- presumably unsupported
        by the grammar; confirm before relying on this.
    """
    tc = get_a_tc()
    parser = StringParser(tc)
    parsable_data = []
    for nl, cm in data:
        try:
            if cm.strip().startswith("tar"):
                continue
            ast = parser.create_parse_tree(cm, "CommandSequence")
            parsable_data.append((nl, cm))
            print(f"PASS {cm}")
        except AInixParseError as e:
            # Unparsable commands are silently dropped from the result.
            pass
    return parsable_data
Ejemplo n.º 24
0
def test_predict_digit(numbers_type_context):
    """The rule-based SeaCR model maps "two" to the BaseTen AST for "2"."""
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    # Seed the index with word -> digit training pairs.
    for word, digit in (("one", "1"), ("two", "2"), ("three", "3")):
        index.add_yset_default_weight([word], [digit], index.DEFAULT_X_TYPE,
                                      "BaseTen")
    expected = StringParser(numbers_type_context).create_parse_tree(
        "2", "BaseTen")
    model = make_rulebased_seacr(index)
    assert expected == model.predict("two", "BaseTen", False)
Ejemplo n.º 25
0
class OracleComparer(Comparer):
    """A comparer which peeks at the index to try and always returns the right
    results even if example not in training set. (Useful for testing)"""
    def __init__(self, index: ExamplesIndex):
        self.index = index
        self.parser = StringParser(self.index.type_context)

    def _get_actual_example_from_index(self, gen_query: str,
                                       gen_ast_current_leaf: ObjectChoiceNode):
        """Returns the stored example whose x query exactly equals gen_query.

        Raises:
            ValueError: If no exact match appears in the top 25 neighbors.
        """
        lookup_results = self.index.get_nearest_examples(
            gen_query, gen_ast_current_leaf.get_type_to_choose_name(), max_results=25)
        for result in lookup_results:
            if result.xquery == gen_query:
                return result
        raise ValueError(f"Oracle unable to find result for {gen_query}")

    def compare(
        self,
        gen_query: str,
        gen_ast_current_root: ObjectChoiceNode,
        gen_ast_current_leaf: ObjectChoiceNode,
        current_gen_depth: int,
        example_query: str,
        example_ast_root: ObjectChoiceNode,
    ) -> ComparerResult:
        """Builds the ground-truth comparison result by looking up the real
        example for gen_query in the index."""
        if current_gen_depth > 1:
            # BUG FIX: the original did `raise NotImplemented(...)`.
            # NotImplemented is a non-callable sentinel constant, so that
            # line raised TypeError instead of the intended exception.
            raise NotImplementedError(
                "Oracle does not currently support greater depths")
        oracle_gt_example = self._get_actual_example_from_index(gen_query, gen_ast_current_leaf)
        oracle_ast = self.parser.create_parse_tree(
            oracle_gt_example.ytext, oracle_gt_example.ytype)
        oracle_ast_set = AstObjectChoiceSet(oracle_ast.type_to_choose, None)
        oracle_ast_set.add(oracle_ast, True, 1, 1)
        return _create_gt_compare_result(
            example_ast_root, gen_ast_current_leaf, oracle_ast_set)

    def train(
        self,
        gen_query: str,
        gen_ast_current_root: ObjectChoiceNode,
        gen_ast_current_leaf: ObjectChoiceNode,
        current_gen_depth: int,
        example_query: str,
        example_ast_root: AstNode,
        expected_result: ComparerResult
    ):
        """No-op: the oracle already knows the answers."""
        # Phhhsh, I'm an oracle. I don't need your "training"...
        pass
Ejemplo n.º 26
0
def test_only_one_option(base_type_context):
    """When a type has exactly one implementation, the model must pick it
    regardless of the y text it was trained on."""
    # Register FooType with a single implementation.
    foo_type = AInixType(base_type_context, "FooType")
    AInixObject(base_type_context, "FooObj", "FooType")
    base_type_context.finalize_data()
    index = ExamplesIndex(base_type_context,
                          ExamplesIndex.get_default_ram_backend())
    index.add_yset_default_weight(["example"], ["y"], index.DEFAULT_X_TYPE,
                                  "FooType")
    expected = StringParser(base_type_context).create_parse_tree(
        "string", foo_type.name)
    model = make_rulebased_seacr(index)
    assert expected == model.predict("example", "FooType", False)
Ejemplo n.º 27
0
def test_digit_list_2(numbers_type_context):
    """The rule-based SeaCR model predicts a two-digit IntBase value ("20")
    from word -> number training examples."""
    index = ExamplesIndex(numbers_type_context,
                          ExamplesIndex.get_default_ram_backend())
    # Renamed from `type`, which shadowed the builtin.
    type_name = "IntBase"
    x_y = [("ten", "10"), ("twenty", "20"), ("thirty", "30")]
    for x, y in x_y:
        index.add_yset_default_weight([x], [y], index.DEFAULT_X_TYPE,
                                      type_name)
    # Create an expected value
    parser = StringParser(numbers_type_context)
    expected = parser.create_parse_tree("20", type_name)
    # Predict
    model = make_rulebased_seacr(index)
    prediction = model.predict("twenty", type_name, False)
    print(expected.dump_str())
    print(prediction.dump_str())
    assert expected == prediction
Ejemplo n.º 28
0
def test_get_latents():
    """TreeRNNDecoder.get_latent_select_states should yield one latent (the
    mocked cell output) per choice point in the AST."""
    out_v = torch.Tensor([[1, 2, 3, 4]])
    # The mocked cell always returns the same (output, hidden) pair.
    mock_cell = MagicMock(return_value=(out_v, torch.Tensor(1, 4)))
    mock_selector = MagicMock()
    mock_vectorizer = MagicMock()
    mock_vocab = MagicMock()
    decoder = TreeRNNDecoder(mock_cell, mock_selector, mock_vectorizer,
                             mock_vocab)
    tc = get_toy_strings_context()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("TWO foo bar", "ToySimpleStrs")

    latents = decoder.get_latent_select_states(torch.Tensor(1, 4),
                                               torch.Tensor(1, 3, 4),
                                               MagicMock(), ast)

    # Three choice points in "TWO foo bar" -> three identical mock latents.
    assert len(latents) == 3
    assert latents == [out_v for _ in range(3)]
Ejemplo n.º 29
0
def test_file_replacer():
    """Every FILENAME replacement value must parse and unparse as a Path."""
    replacements = _load_replacer_relative(
        "../../../training/augmenting/data/FILENAME.tsv")
    tc = TypeContext()
    data_loader = TypeContextDataLoader(tc, up_search_limit=4)
    for yaml_path in ("builtin_types/generic_parsers.ainix.yaml",
                      "builtin_types/command.ainix.yaml",
                      "builtin_types/paths.ainix.yaml"):
        data_loader.load_path(yaml_path)
    allspecials.load_all_special_types(tc)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)

    for repl in replacements:
        x, y = repl.get_replacement()
        # Filename replacements are symmetric: x and y are the same string.
        assert x == y
        parsed = parser.create_parse_tree(x, "Path")
        assert unparser.to_string(parsed).total_string == x
Ejemplo n.º 30
0
def test_add_copies_to_ast_set_other_arg():
    """add_copies_to_ast_set should mark copy as a valid choice for arg2
    (its text "bar" occurs in the input tokens) while keeping the original
    literal choice valid too."""
    tc = get_toy_strings_context()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    ast = parser.create_parse_tree("TWO foo bar", "ToySimpleStrs")
    unpar_res = unparser.to_string(ast)
    assert unpar_res.total_string == "TWO foo bar"
    tokenizer = SpaceTokenizer()
    # Only "bar" (not "foo") occurs in this input string.
    in_str = "Hello bar sf cow"
    tokens, metadata = tokenizer.tokenize(in_str)
    ast_set = AstObjectChoiceSet(tc.get_type_by_name("ToySimpleStrs"))
    ast_set.add(ast, True, 1, 1)
    n: ObjectNode = ast.next_node_not_copy
    # Drill down through the choice set to the entry for the "arg2" argument.
    arg1set = ast_set.get_next_node_for_choice("two_string").next_node. \
        get_arg_set_data(n.as_childless_node()).get_next_node_for_arg("arg2")
    assert arg1set.is_known_choice("bar")
    # Before adding copies, copy must not be a known choice.
    assert not arg1set.copy_is_known_choice()
    add_copies_to_ast_set(ast, ast_set, unparser, metadata)
    assert n.implementation.name == "two_string"
    assert arg1set.copy_is_known_choice()
    assert arg1set.is_known_choice("bar")