Example #1
def test_not_fail_find_expr(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "FindExpression")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast)
    assert result.total_string == string
Example #2
def test_string_parse_e2e_pos_unparse(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context,
        "fo",
        "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    twoargs = AInixObject(type_context,
                          "FooProgram",
                          "Program", [
                              AInixArgument(type_context,
                                            "p1",
                                            "FooType",
                                            arg_data={
                                                POSITION: 0,
                                                MULTIWORD_POS_ARG: False
                                            },
                                            parent_object_name="bw",
                                            required=True)
                          ],
                          type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo", "Program")
    unparser = AstUnparser(type_context)
    unparse_result = unparser.to_string(ast)
    assert unparse_result.total_string == "hello foo"
    for p in ast.depth_first_iter():
        n = p.cur_node
        if isinstance(n,
                      ObjectChoiceNode) and n.type_to_choose.name == "FooType":
            arg_node_pointer = p
            break
    assert unparse_result.pointer_to_string(arg_node_pointer) == "foo"
Example #3
def test_string_parse_e2e_multiword2(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context,
        "fo",
        "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo bar"').name)
    twoargs = AInixObject(type_context,
                          "FooProgram",
                          "Program", [
                              AInixArgument(type_context,
                                            "a",
                                            None,
                                            arg_data={"short_name": "a"},
                                            parent_object_name="sdf"),
                              AInixArgument(type_context,
                                            "p1",
                                            "FooType",
                                            arg_data={
                                                POSITION: 0,
                                                MULTIWORD_POS_ARG: True
                                            },
                                            parent_object_name="sdf")
                          ],
                          type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo bar -a", "Program")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == "hello -a foo bar"
Example #4
def test_string_parse_e2e_sequence(type_context):
    twoargs = AInixObject(type_context,
                          "FooProgram",
                          "Program", [
                              AInixArgument(type_context,
                                            "a",
                                            None,
                                            arg_data={"short_name": "a"},
                                            parent_object_name="sdf"),
                              AInixArgument(type_context,
                                            "barg",
                                            None,
                                            arg_data={"short_name": "b"},
                                            parent_object_name="bw")
                          ],
                          type_data={"invoke_name": "hello"})
    parser = StringParser(type_context)
    unparser = AstUnparser(type_context)
    string = "hello -a | hello -b"
    ast = parser.create_parse_tree(string, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string

    no_space = "hello -a|hello -b"
    ast = parser.create_parse_tree(no_space, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string

    string = "hello -a | hello"
    ast = parser.create_parse_tree(string, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string
Example #5
def test_word_parts_2():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(tc, [('foo', True), ("bar", True), ("fo", True),
                                ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("fooBarBaz",
                                   WORD_PART_TYPE_NAME,
                                   allow_partial_consume=True)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_LOWER_NAME
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == _name_for_word_part("bar")
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "fooBar"
    pointers = list(ast.depth_first_iter())
    assert get_str_and_assert_same_part(result, pointers[1],
                                        word_part_o) == "fooBar"
    assert get_str_and_assert_same_part(result, pointers[2],
                                        mod_type_choice) == "foo"
    assert get_str_and_assert_same_part(result, pointers[3],
                                        mod_type_object) == ""
    assert get_str_and_assert_same_part(result, pointers[4],
                                        next_type_choice) == "Bar"
    assert get_str_and_assert_same_part(result, pointers[5],
                                        next_part_o) == "Bar"
Example #6
def test_string_parse_e2e_multiword3(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context,
        "fo",
        "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    twoargs = AInixObject(type_context,
                          "FooProgram",
                          "Program", [
                              AInixArgument(type_context,
                                            "a",
                                            None,
                                            arg_data={"short_name": "a"},
                                            parent_object_name="sdf"),
                              AInixArgument(type_context,
                                            "barg",
                                            None,
                                            arg_data={"short_name": "b"},
                                            parent_object_name="bw"),
                              _make_positional()
                          ],
                          type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello -a", "CommandSequence")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == "hello -a"
Example #7
def test_word_parts_upper():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(tc, [('foo', True), ("bar", True), ("fo", True),
                                ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("FOO", WORD_PART_TYPE_NAME)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_ALL_UPPER
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == WORD_PART_TERMINAL_NAME
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "FOO"
    pointers = list(ast.depth_first_iter())
    assert ast == pointers[0].cur_node
    assert result.pointer_to_string(pointers[0]) == "FOO"
    assert word_part_o == pointers[1].cur_node
    assert result.pointer_to_string(pointers[1]) == "FOO"
    assert mod_type_choice == pointers[2].cur_node
    assert result.pointer_to_string(pointers[2]) == "FOO"
    assert mod_type_object == pointers[3].cur_node
    assert result.pointer_to_string(pointers[3]) == ""
Example #8
def test_cp(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast)
    assert result.total_string == string
    pointers = list(ast.depth_first_iter())
    assert result.pointer_to_string(pointers[0]) == string
Example #9
    def __init__(self, file_name):
        #self.type_context, self.model, self.example_store = restore(file_name)

        # hacks
        from ainix_kernel.training import fullret_try
        model, index, replacers, type_context, loader = fullret_try.train_the_thing()
        self.type_context, self.model, self.example_store = type_context, model, index

        self.unparser = AstUnparser(self.type_context, self.model.get_string_tokenizer())
Example #10
def make_copy_version_of_tree(
        ast: ObjectChoiceNode, unparser: AstUnparser,
        token_metadata: StringTokensMetadata) -> ObjectChoiceNode:
    """Goes through and replaces anywhere can copy with a copy at earliest possible
    oppertunity. Makes a copy and does not mutate the original ast"""
    unparse = unparser.to_string(ast)
    cur_pointer = AstIterPointer(ast, None, None)
    last_pointer = None
    while cur_pointer:
        if isinstance(cur_pointer.cur_node, ObjectChoiceNode):
            this_node_str = unparse.pointer_to_string(cur_pointer)
            if this_node_str:
                copy_pos = string_in_tok_list(this_node_str, token_metadata)
                # Questionable hacky skip?
                # Need to avoid being overly generous with copying stuff
                other_no_copy_reason = this_node_str in STOP_WORDS or \
                                       this_node_str in ('"', ".", ",", "?")
                if copy_pos and not other_no_copy_reason:
                    copy_node = CopyNode(cur_pointer.cur_node.type_to_choose,
                                         copy_pos[0], copy_pos[1])
                    cur_pointer = cur_pointer.dfs_get_next().change_here(
                        copy_node, always_clone=True)
        last_pointer = cur_pointer
        cur_pointer = cur_pointer.dfs_get_next()
    return last_pointer.get_root().cur_node
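A minimal usage sketch for the function above, not taken from the original source: it assumes a finalized TypeContext tc with the command types loaded (like the fixtures in the surrounding tests) and mirrors the flow of test_touch_set further down; the utterance and command strings are illustrative placeholders.
# Hedged sketch (assumed setup): build an AST, tokenize the natural-language
# utterance, then swap in CopyNodes wherever a subtree's unparsed string
# appears in the utterance's tokens.
tokenizer = NonLetterTokenizer()
parser = StringParser(tc)
unparser = AstUnparser(tc, tokenizer)
ast = parser.create_parse_tree("touch out.txt", "Program")
_, tok_metadata = tokenizer.tokenize("set the last mod time of out.txt to now")
ast_with_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)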
Example #11
class Interface():
    def __init__(self, file_name):
        #self.type_context, self.model, self.example_store = restore(file_name)

        # hacks
        from ainix_kernel.training import fullret_try
        model, index, replacers, type_context, loader = fullret_try.train_the_thing()
        self.type_context, self.model, self.example_store = type_context, model, index

        self.unparser = AstUnparser(self.type_context, self.model.get_string_tokenizer())

    def predict(self, utterance: str, ytype: str) -> PredictReturn:
        try:
            result, metad = self.model.predict(utterance, ytype, False)
            assert result.is_frozen
            unparse = self.unparser.to_string(result, utterance)
            return PredictReturn(
                success=True,
                ast=result,
                unparse=unparse,
                metad=metad,
                error_message=None
            )
        except ModelException as e:
            return PredictReturn(
                False,
                None,
                None,
                None, str(e)
            )
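A hedged driver sketch for the Interface class above; the checkpoint path and query are placeholders (file_name is currently ignored by __init__, which retrains via train_the_thing()), and the field names follow the PredictReturn constructed in predict().
# Hypothetical usage; "CommandSequence" is the y type used elsewhere in
# these examples.
iface = Interface("unused_checkpoint_path")
ret = iface.predict("list all text files", "CommandSequence")
if ret.success:
    print(ret.unparse.total_string)
else:
    print(ret.error_message)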
Example #12
    def create_from_save_state_dict(
        cls,
        state_dict: dict,
        new_type_context: TypeContext,
        new_example_store: ExamplesStore,
    ) -> 'EncDecModel':
        # TODO (DNGros): actually handle the new type context.

        # TODO check the name of the query encoder
        if state_dict['name'] == "EncoderDecoder":
            # I don't think this is right??
            query_encoder = encoders.StringQueryEncoder.create_from_save_state_dict(
                state_dict['query_encoder'])
        else:
            query_encoder = PretrainPoweredQueryEncoder.create_from_save_state_dict(
                state_dict['query_encoder'])
        parser = StringParser(new_type_context)
        unparser = AstUnparser(new_type_context, query_encoder.get_tokenizer())
        replacers = get_all_replacers()
        decoder = TreeRNNDecoder.create_from_save_state_dict(
            state_dict['decoder'], new_type_context, new_example_store,
            replacers, parser, unparser)
        model = cls(type_context=new_type_context,
                    query_encoder=query_encoder,
                    tree_decoder=decoder)
        if state_dict['need_latent_train']:
            update_latent_store_from_examples(
                model, decoder.action_selector.latent_store,
                new_example_store, replacers, parser, None, unparser,
                query_encoder.get_tokenizer())
        return model
Example #13
def get_default_encdec_model(examples: ExamplesStore,
                             standard_size=16,
                             replacer: Replacer = None,
                             use_retrieval_decoder: bool = False,
                             pretrain_checkpoint: str = None,
                             learn_on_top_pretrained: bool = False):
    (x_tokenizer, x_vocab), y_tokenizer = get_default_tokenizers(
        use_word_piece=pretrain_checkpoint is not None)
    if x_vocab is None or pretrain_checkpoint is None:
        # If we don't have a pretrain_checkpoint, we might not know all the
        # words in the vocab, so we should filter the vocab to only tokens in the dataset.
        x_vocab = vocab.make_x_vocab_from_examples(examples,
                                                   x_tokenizer,
                                                   replacer,
                                                   min_freq=5,
                                                   replace_samples=3)
    hidden_size = standard_size
    tc = examples.type_context
    encoder = make_default_query_encoder(x_tokenizer, x_vocab, hidden_size,
                                         pretrain_checkpoint)
    if not use_retrieval_decoder:
        decoder = decoders.get_default_nonretrieval_decoder(tc, hidden_size)
    else:
        parser = StringParser(tc)
        unparser = AstUnparser(tc, x_tokenizer)
        decoder = decoders.get_default_retrieval_decoder(
            tc, hidden_size, examples, replacer, parser, unparser)
    model = EncDecModel(examples.type_context, encoder, decoder)
    if use_retrieval_decoder:
        # TODO lolz, this is such a crappy interface
        model.plz_train_this_latent_store_thanks = lambda: decoder.action_selector.latent_store
    return model
Example #14
 def __init__(self,
              prediction: ObjectChoiceNode,
              ground_truth: AstObjectChoiceSet,
              y_texts: Set[str],
              x_text: str,
              exception,
              unparser: AstUnparser,
              know_pred_str: str = None):
     self.data = {}
     self.prediction = prediction
     self.ground_truth = ground_truth
     self.y_texts = y_texts
     self.x_text = x_text
     self.p_exception = exception
     if self.prediction is not None:
         try:
             self.predicted_y = unparser.to_string(self.prediction,
                                                   self.x_text).total_string
         except RecursionError as e:
             self.predicted_y = f"UNPARSE RECURSION LIMIT HIT"
     else:
         self.predicted_y = f"EXCEPTION {str(self.p_exception)}"
     self.in_ast_set = self.ground_truth.is_node_known_valid(
         self.prediction)
     self.correct = self.in_ast_set or self.predicted_y in self.y_texts
     self.known_pred_str = know_pred_str
     if self.correct and not self.in_ast_set:
         warnings.warn(
             f"The prediction is not in ground truth but value "
             f"matches a y string. "
             f"Prediction text {self.predicted_y} actuals {self.y_texts}")
     self._fill_stats()
Example #15
def test_touch_set(all_the_stuff_context):
    x_str = 'set the last mod time of out.txt to now'
    tc = all_the_stuff_context
    parser = StringParser(tc)
    string = "touch out.txt"
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast, x_str)
    assert result.total_string == string

    cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
    cset.add(ast, True, 1, 1)
    new_ast = parser.create_parse_tree(string, "Program")
    assert cset.is_node_known_valid(new_ast)

    tokenizer = NonLetterTokenizer()
    _, tok_metadata = tokenizer.tokenize(x_str)
    ast_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)
    add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
    assert cset.is_node_known_valid(ast_copies)
    assert cset.is_node_known_valid(ast)

    # Scary, complicated reconstruction of something that broke it.
    # Could be made into a simpler unit test in copy_tools.
    touch_o = tc.get_object_by_name("touch")
    file_list = tc.get_type_by_name("PathList")
    r_arg = touch_o.get_arg_by_name("r")
    m_arg = touch_o.get_arg_by_name("m")
    other_copy = ObjectChoiceNode(
        tc.get_type_by_name("Program"),
        ObjectNode(
            touch_o,
            pmap({
                "r":
                ObjectChoiceNode(r_arg.present_choice_type,
                                 ObjectNode(r_arg.not_present_object, pmap())),
                "m":
                ObjectChoiceNode(m_arg.present_choice_type,
                                 ObjectNode(m_arg.not_present_object, pmap())),
                "file_list":
                ObjectChoiceNode(file_list, CopyNode(file_list, 12, 14))
            })))
    other_result = unparser.to_string(other_copy, x_str)
    assert other_result.total_string == string
    assert cset.is_node_known_valid(other_copy)
Example #16
def full_ret_from_example_store(example_store: ExamplesStore,
                                replacers: Replacer,
                                pretrained_checkpoint: str,
                                replace_samples: int = REPLACEMENT_SAMPLES,
                                encoder_name="CM") -> FullRetModel:
    output_size = 200

    # For GloVe the parseable vocab is huge; narrow it down.
    #query_vocab = make_x_vocab_from_examples(example_store, x_tokenizer, replacers,
    #                                         min_freq=2, replace_samples=2)
    #print(f"vocab len {len(query_vocab)}")

    if encoder_name == "CM":
        (x_tokenizer, query_vocab), y_tokenizer = get_default_tokenizers(
            use_word_piece=True)
        embedder = make_default_query_encoder(x_tokenizer, query_vocab,
                                              output_size,
                                              pretrained_checkpoint)
        embedder.eval()
    elif encoder_name == "BERT":
        from ainix_kernel.models.EncoderDecoder.bertencoders import BertEncoder
        embedder = BertEncoder()
    else:
        raise ValueError(f"Unsupported encoder_name {encoder_name}")
    print(
        f"Number of embedder params: {get_number_of_model_parameters(embedder)}"
    )
    parser = StringParser(example_store.type_context)
    unparser = AstUnparser(example_store.type_context,
                           embedder.get_tokenizer())
    summaries, example_refs, example_splits = [], [], []
    nb_update_fn, finalize_nb_fn = get_nb_learner()
    with torch.no_grad():
        for xval in tqdm(list(example_store.get_all_x_values())):
            # Get the most preferred y text
            all_y_examples = example_store.get_y_values_for_y_set(
                xval.y_set_id)
            most_preferable_y = all_y_examples[0]
            new_summary, new_example_ref = _preproc_example(
                xval,
                most_preferable_y,
                replacers,
                embedder,
                parser,
                unparser,
                nb_update_fn if xval.split == DataSplits.TRAIN else None,
                replace_samples=replace_samples)
            summaries.append(new_summary)
            example_refs.append(new_example_ref)
            example_splits.append(xval.split)
    return FullRetModel(
        embedder=embedder,
        summaries=torch.stack(summaries),
        dataset_splits=torch.tensor(example_splits),
        example_refs=np.array(example_refs),
        choice_models=None  #finalize_nb_fn()
    )
Example #17
def test_file_replacer():
    replacements = _load_replacer_relative(
        "../../../training/augmenting/data/FILENAME.tsv")
    tc = TypeContext()
    loader = TypeContextDataLoader(tc, up_search_limit=4)
    loader.load_path("builtin_types/generic_parsers.ainix.yaml")
    loader.load_path("builtin_types/command.ainix.yaml")
    loader.load_path("builtin_types/paths.ainix.yaml")
    allspecials.load_all_special_types(tc)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)

    for repl in replacements:
        x, y = repl.get_replacement()
        assert x == y
        ast = parser.create_parse_tree(x, "Path")
        result = unparser.to_string(ast)
        assert result.total_string == x
Example #18
def test_generic_word():
    context = TypeContext()
    loader.load_path("builtin_types/generic_parsers.ainix.yaml",
                     context,
                     up_search_limit=4)
    generic_strings.create_generic_strings(context)
    context.finalize_data()
    parser = StringParser(context)
    ast = parser.create_parse_tree("a", WORD_TYPE_NAME)
    generic_word_ob = ast.next_node_not_copy
    assert generic_word_ob.implementation.name == WORD_OBJ_NAME
    parts_arg = generic_word_ob.get_choice_node_for_arg("parts")
    parts_v = parts_arg.next_node_not_copy
    assert parts_v.implementation.name == "word_part_a"
    mod_type_choice = parts_v.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_LOWER_NAME
    unparser = AstUnparser(context)
    result = unparser.to_string(ast)
    assert result.total_string == "a"
Example #19
 def create_from_save_state_dict(cls, state_dict: dict,
                                 new_type_context: TypeContext,
                                 new_example_store: ExamplesStore):
     query_encoder = PretrainPoweredQueryEncoder.create_from_save_state_dict(
         state_dict['embedder'])
     parser = StringParser(new_type_context)
     unparser = AstUnparser(new_type_context, query_encoder.get_tokenizer())
     replacers = get_all_replacers()
     return cls(query_encoder,
                state_dict['summaries'],
                state_dict['dataset_splits'],
                example_refs=state_dict['example_refs'],
                choice_models=state_dict['choice_models'])
Example #20
 def __init__(self,
              model: StringTypeTranslateCF,
              example_store: ExamplesStore,
              batch_size: int = 1,
              replacer: Replacer = None,
              loader: TypeContextDataLoader = None):
     self.model = model
     self.example_store = example_store
     self.type_context = example_store.type_context
     self.string_parser = StringParser(self.type_context)
     self.batch_size = batch_size
     self.str_tokenizer = self.model.get_string_tokenizer()
     self.unparser = AstUnparser(self.type_context, self.str_tokenizer)
     self.replacer = replacer
     self.loader = loader
     if self.replacer is None:
         self.replacer = Replacer([])
Example #21
def do_train(num_replace_samples=DEFAULT_REPLACE_SAMPLES):
    type_context, index, replacers, loader = get_examples()
    (x_tokenizer, x_vocab), y_tokenizer = get_default_tokenizers()

    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, x_tokenizer)

    rsampled_examples = tqdm(itertools.chain.from_iterable(
        (iterate_data_pairs(index, replacers, string_parser, x_tokenizer,
                            unparser, None)
         for _ in range(num_replace_samples))),
                             total=index.get_num_x_values() *
                             num_replace_samples)

    # TODO this should probably be model dependent
    encoder = make_default_query_encoder(x_tokenizer, x_vocab, 200,
                                         pretrain_checkpoint)

    encoder_cache = EncoderCache(index)
    cache_examples_from_iterable(encoder_cache, encoder, rsampled_examples,
                                 DEFAULT_ENCODER_BATCH_SIZE)
Example #22
def get_y_ast_sets(
        xs: List[str], yids: List[int], yhashes: List[str],
        rsamples: List[ReplacementSampling], string_tokenizer: StringTokenizer,
        index: ExamplesStore
) -> Generator[Tuple[AstSet, Set[str]], None, None]:
    type_context = index.type_context
    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, string_tokenizer)
    verify_same_hashes(index, yids, yhashes)
    assert len(xs) == len(yids) == len(rsamples)
    for x, yid, rsample in zip(xs, yids, rsamples):
        yvalues = index.get_y_values_for_y_set(yid)
        this_x_toks, this_x_metadata = string_tokenizer.tokenize(x)
        y_ast_set, y_texts, teacher_force_path_ast = make_y_ast_set(
            y_type=type_context.get_type_by_name(yvalues[0].y_type),
            all_y_examples=yvalues,
            replacement_sample=rsample,
            string_parser=string_parser,
            this_x_metadata=this_x_metadata,
            unparser=unparser)
        yield y_ast_set, y_texts
Example #23
def add_copies_to_ast_set(ast: ObjectChoiceNode,
                          ast_set: AstObjectChoiceSet,
                          unparser: AstUnparser,
                          token_metadata: StringTokensMetadata,
                          copy_node_weight: float = 1) -> None:
    """Takes in an AST that has been parsed and adds copynodes where appropriate
    to an AstSet that contains that AST"""
    unparse = unparser.to_string(ast)
    df_ast_pointers = list(ast.depth_first_iter())
    df_ast_nodes = [pointer.cur_node for pointer in ast.depth_first_iter()]
    df_ast_set = list(
        depth_first_iterate_ast_set_along_path(ast_set, df_ast_nodes))
    assert len(df_ast_nodes) == len(df_ast_set)
    for pointer, cur_set in zip(df_ast_pointers, df_ast_set):
        if isinstance(pointer.cur_node, ObjectChoiceNode):
            # TODO (DNGros): Figure out if we are handling weight and probability right.
            # I think it works fine now if the node is known valid.
            _try_add_copy_node_at_object_choice(pointer, cur_set, True,
                                                copy_node_weight, 1, unparse,
                                                token_metadata)
        elif isinstance(pointer.cur_node, ObjectNode):
            pass
        else:
            raise ValueError("Unrecognized node?")
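A minimal sketch of the function above in context, assuming the same tc, ast, unparser, and tok_metadata set up as in the sketch after make_copy_version_of_tree (Example #10); it follows the flow of test_touch_set (Example #15).
# Hedged sketch: seed the set with the plain AST, then mark copy-containing
# variants of that AST as valid too.
cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
cset.add(ast, True, 1, 1)
add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
assert cset.is_node_known_valid(make_copy_version_of_tree(ast, unparser, tok_metadata))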
Example #24
def test_dot_separated_words(tc, in_str):
    parser = StringParser(tc)
    ast = parser.create_parse_tree(in_str, "DotSeparatedWords")
    unparser = AstUnparser(tc)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == in_str
Example #25
     predictions = f.readlines()
 with open(args.tgt_yids, 'r') as f:
     yinfo = f.readlines()
 if args.json_preds:
     import json
     predictions = [
         " ".join(json.loads(p)['predicted_tokens'][0]) for p in predictions
     ]
 # Parse all the strings
 xs = list(map(nonascii_untokenize, xs))
 predictions = list(map(nonascii_untokenize, predictions))
 yids = []
 yhashes = []
 rsamples = []
 for info in yinfo:
     split_info = info.split()
     yids.append(int(split_info[0]))
     yhashes.append(split_info[1])
     rsamples.append(
         ReplacementSampling.from_serialized_string("".join(
             split_info[2:])))
 # Sort the items so that multiple samples of same id end up together
 yids, xs, predictions, yhashes, rsamples = \
     zip(*sorted(zip(yids, xs, predictions, yhashes, rsamples)))
 # Set up for actually predicting
 tokenizer = get_tokenizer_by_name(args.tokenizer_name)
 type_context, index, replacers, loader = get_examples()
 string_parser = StringParser(type_context)
 unparser = AstUnparser(type_context, tokenizer)
 y_asts = get_y_ast_sets(xs, yids, yhashes, rsamples, tokenizer, index)
 eval_stuff(xs, ys, predictions, y_asts, string_parser, unparser)
Example #26
    for f in ALL_EXAMPLE_NAMES:
        loader.load_path(f"builtin_types/{f}.ainix.yaml")
    type_context.finalize_data()

    index = load_all_examples(type_context)
    #index = load_tellina_examples(type_context)

    print("num docs", index.get_num_x_values())
    print("num train", len(list(index.get_all_x_values((DataSplits.TRAIN, )))))

    replacers = get_all_replacers()

    model = full_ret_from_example_store(index, replacers,
                                        pretrained_checkpoint_path)
    unparser = AstUnparser(type_context, model.get_string_tokenizer())
    nb_models = model.nb_models
    program_nb = nb_models['Program']

    print(program_nb._model.sigma_)
    print(program_nb._model.theta_)

    #program_nb._model.sigma_ *= 100

    while True:
        q = input("Query: ")
        summary, mem, tokens = model.embedder([q])
        summary_and_depth = torch.cat((summary, torch.tensor([[2.0]])), dim=1)
        s1 = summary_and_depth[0].data.numpy()
        std_devs = np.sqrt(program_nb._model.sigma_)
        diffs_from_mean = program_nb._model.theta_ - s1
Example #27
def test_path_parse_extension(tc, in_str):
    parser = StringParser(tc)
    ast = parser.create_parse_tree(in_str, "FileExtension")
    unparser = AstUnparser(tc)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == in_str
Example #28
def test_path_list_parse_and_unparse_without_error(tc, in_str):
    parser = StringParser(tc)
    ast = parser.create_parse_tree(in_str, "PathList")
    unparser = AstUnparser(tc)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == in_str
Example #29
    argparer.add_argument("--replace_samples", type=int, default=1)
    default_split_train = DEFAULT_SPLITS[0]
    assert default_split_train[1] == DataSplits.TRAIN
    argparer.add_argument("--train_percent", type=float, default=default_split_train[0]*100)
    argparer.add_argument("--randomize_seed", action='store_true')
    args = argparer.parse_args()

    train_frac = args.train_percent / 100.0
    split_proportions = ((train_frac, DataSplits.TRAIN), (1-train_frac, DataSplits.VALIDATION))

    type_context, index, replacers, loader = get_examples(
        split_proportions, randomize_seed=args.randomize_seed)
    index.get_all_x_values(())
    string_parser = StringParser(type_context)
    tokenizer, vocab = get_default_pieced_tokenizer_word_list()
    unparser = AstUnparser(type_context, tokenizer)
    non_ascii_tokenizer = NonLetterTokenizer()
    def non_asci_do(string):
        toks, metad = non_ascii_tokenizer.tokenize(string)
        return " ".join(toks)

    split_to_sentences = defaultdict(list)
    num_to_do = args.replace_samples*index.get_num_x_values()
    data_iterator = itertools.chain.from_iterable(
        (iterate_data_pairs(
            index, replacers, string_parser, tokenizer, unparser, None)
         for epoch in range(args.replace_samples))
    )
    data_iterator = itertools.islice(data_iterator, num_to_do)
    for (example, this_example_replaced_x, y_ast_set,
         teacher_force_path_ast, y_texts, rsample) in tqdm(data_iterator, total=num_to_do):
Example #30
    argparer.add_argument("--randomize_seed", action='store_true')
    argparer.add_argument("--nointeractive", action='store_true')
    argparer.add_argument("--eval_replace_samples", type=int, default=5)
    argparer.add_argument("--replace_samples",
                          type=int,
                          default=REPLACEMENT_SAMPLES)
    argparer.add_argument("--encoder_name", type=str, default="CM")
    args = argparer.parse_args()
    train_frac = args.train_percent / 100.0
    split_proportions = ((train_frac, DataSplits.TRAIN),
                         (1 - train_frac, DataSplits.VALIDATION))

    model, index, replacers, type_context, loader = train_the_thing(
        split_proportions, args.randomize_seed, args.replace_samples,
        args.encoder_name)
    unparser = AstUnparser(type_context, model.get_string_tokenizer())

    tran_trainer = TypeTranslateCFTrainer(model,
                                          index,
                                          replacer=replacers,
                                          loader=loader)
    logger = EvaluateLogger()
    tran_trainer.evaluate(logger,
                          dump_each=True,
                          num_replace_samples=args.eval_replace_samples)
    print_ast_eval_log(logger)

    if not args.nointeractive:
        while True:
            q = input("Query: ")
            ast, metad = model.predict(q, "CommandSequence", True)