def test_not_fail_find_expr(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "FindExpression")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast)
    assert result.total_string == string

def test_string_parse_e2e_pos_unparse(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context, "fo", "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "p1", "FooType",
                       arg_data={POSITION: 0, MULTIWORD_POS_ARG: False},
                       parent_object_name="bw", required=True)],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo", "Program")
    unparser = AstUnparser(type_context)
    unparse_result = unparser.to_string(ast)
    assert unparse_result.total_string == "hello foo"
    for p in ast.depth_first_iter():
        n = p.cur_node
        if isinstance(n, ObjectChoiceNode) and n.type_to_choose.name == "FooType":
            arg_node_pointer = p
            break
    assert unparse_result.pointer_to_string(arg_node_pointer) == "foo"

def test_string_parse_e2e_multiword2(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context, "fo", "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo bar"').name)
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "a", None, arg_data={"short_name": "a"},
                       parent_object_name="sdf"),
         AInixArgument(type_context, "p1", "FooType",
                       arg_data={POSITION: 0, MULTIWORD_POS_ARG: True},
                       parent_object_name="sdf")],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello foo bar -a", "Program")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == "hello -a foo bar"

def test_string_parse_e2e_sequence(type_context):
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "a", None, arg_data={"short_name": "a"},
                       parent_object_name="sdf"),
         AInixArgument(type_context, "barg", None, arg_data={"short_name": "b"},
                       parent_object_name="bw")],
        type_data={"invoke_name": "hello"})
    # Finalize before parsing, as the other tests in this file do.
    type_context.finalize_data()
    parser = StringParser(type_context)
    unparser = AstUnparser(type_context)
    string = "hello -a | hello -b"
    ast = parser.create_parse_tree(string, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string
    no_space = "hello -a|hello -b"
    ast = parser.create_parse_tree(no_space, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string
    string = "hello -a | hello"
    ast = parser.create_parse_tree(string, "CommandSequence")
    to_string = unparser.to_string(ast)
    assert to_string.total_string == string

def test_word_parts_2():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(
        tc, [('foo', True), ("bar", True), ("fo", True), ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree(
        "fooBarBaz", WORD_PART_TYPE_NAME, allow_partial_consume=True)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_LOWER_NAME
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == _name_for_word_part("bar")
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "fooBar"
    pointers = list(ast.depth_first_iter())
    assert get_str_and_assert_same_part(result, pointers[1], word_part_o) == "fooBar"
    assert get_str_and_assert_same_part(result, pointers[2], mod_type_choice) == "foo"
    assert get_str_and_assert_same_part(result, pointers[3], mod_type_object) == ""
    assert get_str_and_assert_same_part(result, pointers[4], next_type_choice) == "Bar"
    assert get_str_and_assert_same_part(result, pointers[5], next_part_o) == "Bar"

def test_string_parse_e2e_multiword3(type_context):
    fooType = AInixType(type_context, "FooType")
    fo = AInixObject(
        type_context, "fo", "FooType", [],
        preferred_object_parser_name=create_object_parser_from_grammar(
            type_context, "fooname", '"foo"').name)
    twoargs = AInixObject(
        type_context, "FooProgram", "Program",
        [AInixArgument(type_context, "a", None, arg_data={"short_name": "a"},
                       parent_object_name="sdf"),
         AInixArgument(type_context, "barg", None, arg_data={"short_name": "b"},
                       parent_object_name="bw"),
         _make_positional()],
        type_data={"invoke_name": "hello"})
    type_context.finalize_data()
    parser = StringParser(type_context)
    ast = parser.create_parse_tree("hello -a", "CommandSequence")
    unparser = AstUnparser(type_context)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == "hello -a"

def test_word_parts_upper():
    tc = TypeContext()
    _create_root_types(tc)
    _create_all_word_parts(
        tc, [('foo', True), ("bar", True), ("fo", True), ("!", False)])
    tc.finalize_data()
    parser = StringParser(tc)
    ast = parser.create_parse_tree("FOO", WORD_PART_TYPE_NAME)
    word_part_o = ast.next_node_not_copy
    assert word_part_o.implementation.name == _name_for_word_part("foo")
    mod_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_ALL_UPPER
    next_type_choice = word_part_o.get_choice_node_for_arg(
        WORD_PART_NEXT_ARG_NAME)
    next_part_o = next_type_choice.next_node_not_copy
    assert next_part_o.implementation.name == WORD_PART_TERMINAL_NAME
    ### Unparse
    unparser = AstUnparser(tc)
    result = unparser.to_string(ast)
    assert result.total_string == "FOO"
    pointers = list(ast.depth_first_iter())
    assert ast == pointers[0].cur_node
    assert result.pointer_to_string(pointers[0]) == "FOO"
    assert word_part_o == pointers[1].cur_node
    assert result.pointer_to_string(pointers[1]) == "FOO"
    assert mod_type_choice == pointers[2].cur_node
    assert result.pointer_to_string(pointers[2]) == "FOO"
    assert mod_type_object == pointers[3].cur_node
    assert result.pointer_to_string(pointers[3]) == ""

def test_cp(all_the_stuff_context, string):
    tc = all_the_stuff_context
    parser = StringParser(tc)
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast)
    assert result.total_string == string
    pointers = list(ast.depth_first_iter())
    assert result.pointer_to_string(pointers[0]) == string

def make_copy_version_of_tree(
    ast: ObjectChoiceNode,
    unparser: AstUnparser,
    token_metadata: StringTokensMetadata
) -> ObjectChoiceNode:
    """Goes through the tree and replaces every node that could be expressed
    as a copy with a CopyNode at the earliest possible opportunity. Works on
    a clone and does not mutate the original ast."""
    unparse = unparser.to_string(ast)
    cur_pointer = AstIterPointer(ast, None, None)
    last_pointer = None
    while cur_pointer:
        if isinstance(cur_pointer.cur_node, ObjectChoiceNode):
            this_node_str = unparse.pointer_to_string(cur_pointer)
            if this_node_str:
                copy_pos = string_in_tok_list(this_node_str, token_metadata)
                # Questionable hacky skip?
                # Need to avoid being overly generous with copying stuff
                other_no_copy_reason = this_node_str in STOP_WORDS or \
                    this_node_str in ('"', ".", ",", "?")
                if copy_pos and not other_no_copy_reason:
                    copy_node = CopyNode(
                        cur_pointer.cur_node.type_to_choose,
                        copy_pos[0], copy_pos[1])
                    cur_pointer = cur_pointer.dfs_get_next().change_here(
                        copy_node, always_clone=True)
        last_pointer = cur_pointer
        cur_pointer = cur_pointer.dfs_get_next()
    return last_pointer.get_root().cur_node

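# Illustrative sketch (not from the original sources): how make_copy_version_of_tree
# is typically driven, following the same call pattern as test_touch_set further
# down. `tc` is assumed to be an already-finalized TypeContext; the utterance and
# command strings here are made up.
def _example_make_copy_tree(tc):
    parser = StringParser(tc)
    unparser = AstUnparser(tc, NonLetterTokenizer())
    x_str = "make an empty file named out.txt"
    ast = parser.create_parse_tree("touch out.txt", "Program")
    # The token metadata of the x utterance determines which spans are copyable.
    _, tok_metadata = NonLetterTokenizer().tokenize(x_str)
    return make_copy_version_of_tree(ast, unparser, tok_metadata)
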
class Interface():
    def __init__(self, file_name):
        #self.type_context, self.model, self.example_store = restore(file_name)
        # hacks
        from ainix_kernel.training import fullret_try
        model, index, replacers, type_context, loader = fullret_try.train_the_thing()
        self.type_context, self.model, self.example_store = type_context, model, index
        self.unparser = AstUnparser(self.type_context,
                                    self.model.get_string_tokenizer())

    def predict(self, utterance: str, ytype: str) -> PredictReturn:
        try:
            result, metad = self.model.predict(utterance, ytype, False)
            assert result.is_frozen
            unparse = self.unparser.to_string(result, utterance)
            return PredictReturn(
                success=True,
                ast=result,
                unparse=unparse,
                metad=metad,
                error_message=None
            )
        except ModelException as e:
            return PredictReturn(False, None, None, None, str(e))

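# Illustrative sketch (not from the original sources): driving the Interface
# wrapper. Note the file_name argument is currently ignored by __init__ (a model
# is retrained via train_the_thing instead), so the path here is hypothetical.
iface = Interface("saved_model.pt")
ret = iface.predict("make an empty file named out.txt", "CommandSequence")
if ret.success:
    print(ret.unparse.total_string)
else:
    print("prediction failed:", ret.error_message)
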
@classmethod
def create_from_save_state_dict(
    cls,
    state_dict: dict,
    new_type_context: TypeContext,
    new_example_store: ExamplesStore,
) -> 'EncDecModel':
    # TODO (DNGros): actually handle the new type context.
    # TODO check the name of the query encoder
    if state_dict['name'] == "EncoderDecoder":
        # I don't think this is right??
        query_encoder = encoders.StringQueryEncoder.create_from_save_state_dict(
            state_dict['query_encoder'])
    else:
        query_encoder = PretrainPoweredQueryEncoder.create_from_save_state_dict(
            state_dict['query_encoder'])
    parser = StringParser(new_type_context)
    unparser = AstUnparser(new_type_context, query_encoder.get_tokenizer())
    replacers = get_all_replacers()
    decoder = TreeRNNDecoder.create_from_save_state_dict(
        state_dict['decoder'], new_type_context, new_example_store,
        replacers, parser, unparser)
    model = cls(type_context=new_type_context,
                query_encoder=query_encoder,
                tree_decoder=decoder)
    if state_dict['need_latent_train']:
        update_latent_store_from_examples(
            model, decoder.action_selector.latent_store, new_example_store,
            replacers, parser, None, unparser, query_encoder.get_tokenizer())
    return model

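# Illustrative sketch (assumed usage, not from the original sources): restoring
# an EncDecModel from a previously saved state dict. Assumes `state_dict` was
# produced by the model's matching save path and that `tc` and `index` (a
# TypeContext and ExamplesStore) have already been rebuilt.
model = EncDecModel.create_from_save_state_dict(state_dict, tc, index)
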
def get_default_encdec_model(examples: ExamplesStore,
                             standard_size=16,
                             replacer: Replacer = None,
                             use_retrieval_decoder: bool = False,
                             pretrain_checkpoint: str = None,
                             learn_on_top_pretrained: bool = False):
    (x_tokenizer, x_vocab), y_tokenizer = get_default_tokenizers(
        use_word_piece=pretrain_checkpoint is not None)
    if x_vocab is None or pretrain_checkpoint is None:
        # If we don't have a pretrain_checkpoint, we might not know all the
        # words in the vocab, so we should filter the vocab to only tokens
        # seen in the dataset.
        x_vocab = vocab.make_x_vocab_from_examples(
            examples, x_tokenizer, replacer, min_freq=5, replace_samples=3)
    hidden_size = standard_size
    tc = examples.type_context
    encoder = make_default_query_encoder(x_tokenizer, x_vocab, hidden_size,
                                         pretrain_checkpoint)
    if not use_retrieval_decoder:
        decoder = decoders.get_default_nonretrieval_decoder(tc, hidden_size)
    else:
        parser = StringParser(tc)
        unparser = AstUnparser(tc, x_tokenizer)
        decoder = decoders.get_default_retrieval_decoder(
            tc, hidden_size, examples, replacer, parser, unparser)
    model = EncDecModel(examples.type_context, encoder, decoder)
    if use_retrieval_decoder:
        # TODO lolz, this is such a crappy interface
        model.plz_train_this_latent_store_thanks = \
            lambda: decoder.action_selector.latent_store
    return model

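# Illustrative sketch (assumed usage): building the default encoder-decoder
# model from an example store. `index` and `replacers` are assumed to come from
# a loader such as get_examples() used elsewhere in this section.
model = get_default_encdec_model(index, standard_size=64, replacer=replacers)
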
def __init__(self,
             prediction: ObjectChoiceNode,
             ground_truth: AstObjectChoiceSet,
             y_texts: Set[str],
             x_text: str,
             exception,
             unparser: AstUnparser,
             know_pred_str: str = None):
    self.data = {}
    self.prediction = prediction
    self.ground_truth = ground_truth
    self.y_texts = y_texts
    self.x_text = x_text
    self.p_exception = exception
    if self.prediction is not None:
        try:
            self.predicted_y = unparser.to_string(
                self.prediction, self.x_text).total_string
        except RecursionError:
            self.predicted_y = "UNPARSE RECURSION LIMIT HIT"
    else:
        self.predicted_y = f"EXCEPTION {str(self.p_exception)}"
    self.in_ast_set = self.ground_truth.is_node_known_valid(self.prediction)
    self.correct = self.in_ast_set or self.predicted_y in self.y_texts
    self.known_pred_str = know_pred_str
    if self.correct and not self.in_ast_set:
        warnings.warn(
            f"The prediction is not in the ground truth set but its value "
            f"matches a y string. "
            f"Prediction text {self.predicted_y}; actuals {self.y_texts}")
    self._fill_stats()

def test_touch_set(all_the_stuff_context):
    x_str = 'set the last mod time of out.txt to now'
    tc = all_the_stuff_context
    parser = StringParser(tc)
    string = "touch out.txt"
    ast = parser.create_parse_tree(string, "Program")
    unparser = AstUnparser(tc, NonLetterTokenizer())
    result = unparser.to_string(ast, x_str)
    assert result.total_string == string
    cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
    cset.add(ast, True, 1, 1)
    new_ast = parser.create_parse_tree(string, "Program")
    assert cset.is_node_known_valid(new_ast)

    tokenizer = NonLetterTokenizer()
    _, tok_metadata = tokenizer.tokenize(x_str)
    ast_copies = make_copy_version_of_tree(ast, unparser, tok_metadata)
    add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
    assert cset.is_node_known_valid(ast_copies)
    assert cset.is_node_known_valid(ast)

    # Scary complicated reconstruction of something that broke it.
    # Could be made into a simpler unit test in copy_tools.
    touch_o = tc.get_object_by_name("touch")
    file_list = tc.get_type_by_name("PathList")
    r_arg = touch_o.get_arg_by_name("r")
    m_arg = touch_o.get_arg_by_name("m")
    other_copy = ObjectChoiceNode(
        tc.get_type_by_name("Program"),
        ObjectNode(
            touch_o,
            pmap({
                "r": ObjectChoiceNode(
                    r_arg.present_choice_type,
                    ObjectNode(r_arg.not_present_object, pmap())),
                "m": ObjectChoiceNode(
                    m_arg.present_choice_type,
                    ObjectNode(m_arg.not_present_object, pmap())),
                "file_list": ObjectChoiceNode(
                    file_list, CopyNode(file_list, 12, 14))
            })))
    other_result = unparser.to_string(other_copy, x_str)
    assert other_result.total_string == string
    assert cset.is_node_known_valid(other_copy)

def full_ret_from_example_store(example_store: ExamplesStore,
                                replacers: Replacer,
                                pretrained_checkpoint: str,
                                replace_samples: int = REPLACEMENT_SAMPLES,
                                encoder_name="CM") -> FullRetModel:
    output_size = 200
    # For GloVe the parseable vocab is huge. Narrow it down.
    #query_vocab = make_x_vocab_from_examples(example_store, x_tokenizer, replacers,
    #                                         min_freq=2, replace_samples=2)
    #print(f"vocab len {len(query_vocab)}")
    if encoder_name == "CM":
        (x_tokenizer, query_vocab), y_tokenizer = get_default_tokenizers(
            use_word_piece=True)
        embedder = make_default_query_encoder(x_tokenizer, query_vocab,
                                              output_size, pretrained_checkpoint)
        embedder.eval()
    elif encoder_name == "BERT":
        from ainix_kernel.models.EncoderDecoder.bertencoders import BertEncoder
        embedder = BertEncoder()
    else:
        raise ValueError(f"Unsupported encoder_name {encoder_name}")
    print(f"Number of embedder params: {get_number_of_model_parameters(embedder)}")
    parser = StringParser(example_store.type_context)
    unparser = AstUnparser(example_store.type_context, embedder.get_tokenizer())
    summaries, example_refs, example_splits = [], [], []
    nb_update_fn, finalize_nb_fn = get_nb_learner()
    with torch.no_grad():
        for xval in tqdm(list(example_store.get_all_x_values())):
            # Get the most preferred y text
            all_y_examples = example_store.get_y_values_for_y_set(xval.y_set_id)
            most_preferable_y = all_y_examples[0]
            new_summary, new_example_ref = _preproc_example(
                xval, most_preferable_y, replacers, embedder, parser, unparser,
                nb_update_fn if xval.split == DataSplits.TRAIN else None,
                replace_samples=replace_samples)
            summaries.append(new_summary)
            example_refs.append(new_example_ref)
            example_splits.append(xval.split)
    return FullRetModel(
        embedder=embedder,
        summaries=torch.stack(summaries),
        dataset_splits=torch.tensor(example_splits),
        example_refs=np.array(example_refs),
        choice_models=None  #finalize_nb_fn()
    )

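# Usage sketch mirroring the interactive script later in this section: building
# the full retrieval model from a loaded example index and getting an unparser
# that shares the model's tokenizer. `pretrained_checkpoint_path` is assumed to
# be defined by the surrounding script.
model = full_ret_from_example_store(index, replacers, pretrained_checkpoint_path)
unparser = AstUnparser(type_context, model.get_string_tokenizer())
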
def test_file_replacer():
    replacements = _load_replacer_relative(
        "../../../training/augmenting/data/FILENAME.tsv")
    tc = TypeContext()
    loader = TypeContextDataLoader(tc, up_search_limit=4)
    loader.load_path("builtin_types/generic_parsers.ainix.yaml")
    loader.load_path("builtin_types/command.ainix.yaml")
    loader.load_path("builtin_types/paths.ainix.yaml")
    allspecials.load_all_special_types(tc)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    for repl in replacements:
        x, y = repl.get_replacement()
        assert x == y
        ast = parser.create_parse_tree(x, "Path")
        result = unparser.to_string(ast)
        assert result.total_string == x

def test_generic_word():
    context = TypeContext()
    loader.load_path("builtin_types/generic_parsers.ainix.yaml", context,
                     up_search_limit=4)
    generic_strings.create_generic_strings(context)
    context.finalize_data()
    parser = StringParser(context)
    ast = parser.create_parse_tree("a", WORD_TYPE_NAME)
    generic_word_ob = ast.next_node_not_copy
    assert generic_word_ob.implementation.name == WORD_OBJ_NAME
    parts_arg = generic_word_ob.get_choice_node_for_arg("parts")
    parts_v = parts_arg.next_node_not_copy
    assert parts_v.implementation.name == "word_part_a"
    mod_type_choice = parts_v.get_choice_node_for_arg(
        WORD_PART_MODIFIER_ARG_NAME)
    mod_type_object = mod_type_choice.next_node_not_copy
    assert mod_type_object.implementation.name == MODIFIER_LOWER_NAME
    unparser = AstUnparser(context)
    result = unparser.to_string(ast)
    assert result.total_string == "a"

@classmethod
def create_from_save_state_dict(cls,
                                state_dict: dict,
                                new_type_context: TypeContext,
                                new_example_store: ExamplesStore):
    query_encoder = PretrainPoweredQueryEncoder.create_from_save_state_dict(
        state_dict['embedder'])
    parser = StringParser(new_type_context)
    unparser = AstUnparser(new_type_context, query_encoder.get_tokenizer())
    replacers = get_all_replacers()
    return cls(query_encoder,
               state_dict['summaries'],
               state_dict['dataset_splits'],
               example_refs=state_dict['example_refs'],
               choice_models=state_dict['choice_models'])

def __init__(self,
             model: StringTypeTranslateCF,
             example_store: ExamplesStore,
             batch_size: int = 1,
             replacer: Replacer = None,
             loader: TypeContextDataLoader = None):
    self.model = model
    self.example_store = example_store
    self.type_context = example_store.type_context
    self.string_parser = StringParser(self.type_context)
    self.batch_size = batch_size
    self.str_tokenizer = self.model.get_string_tokenizer()
    self.unparser = AstUnparser(self.type_context, self.str_tokenizer)
    self.replacer = replacer
    self.loader = loader
    if self.replacer is None:
        self.replacer = Replacer([])

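# Usage sketch mirroring the training script at the end of this section:
# wrapping a trained model and its example store in the trainer and running an
# evaluation pass.
trainer = TypeTranslateCFTrainer(model, index, replacer=replacers, loader=loader)
logger = EvaluateLogger()
trainer.evaluate(logger, dump_each=True)
print_ast_eval_log(logger)
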
def do_train(num_replace_samples=DEFAULT_REPLACE_SAMPLES):
    type_context, index, replacers, loader = get_examples()
    (x_tokenizer, x_vocab), y_tokenizer = get_default_tokenizers()
    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, x_tokenizer)
    rsampled_examples = tqdm(
        itertools.chain.from_iterable(
            (iterate_data_pairs(index, replacers, string_parser, x_tokenizer,
                                unparser, None)
             for _ in range(num_replace_samples))),
        total=index.get_num_x_values() * num_replace_samples)
    # TODO this should probably be model dependent
    encoder = make_default_query_encoder(x_tokenizer, x_vocab, 200,
                                         pretrain_checkpoint)
    encoder_cache = EncoderCache(index)
    cache_examples_from_iterable(encoder_cache, encoder, rsampled_examples,
                                 DEFAULT_ENCODER_BATCH_SIZE)

def get_y_ast_sets(xs: List[str],
                   yids: List[int],
                   yhashes: List[str],
                   rsamples: List[ReplacementSampling],
                   string_tokenizer: StringTokenizer,
                   index: ExamplesStore
                   ) -> Generator[Tuple[AstSet, Set[str]], None, None]:
    type_context = index.type_context
    string_parser = StringParser(type_context)
    unparser = AstUnparser(type_context, string_tokenizer)
    verify_same_hashes(index, yids, yhashes)
    assert len(xs) == len(yids) == len(rsamples)
    for x, yid, rsample in zip(xs, yids, rsamples):
        yvalues = index.get_y_values_for_y_set(yid)
        this_x_toks, this_x_metadata = string_tokenizer.tokenize(x)
        y_ast_set, y_texts, teacher_force_path_ast = make_y_ast_set(
            y_type=type_context.get_type_by_name(yvalues[0].y_type),
            all_y_examples=yvalues,
            replacement_sample=rsample,
            string_parser=string_parser,
            this_x_metadata=this_x_metadata,
            unparser=unparser)
        yield y_ast_set, y_texts

def add_copies_to_ast_set(ast: ObjectChoiceNode,
                          ast_set: AstObjectChoiceSet,
                          unparser: AstUnparser,
                          token_metadata: StringTokensMetadata,
                          copy_node_weight: float = 1) -> None:
    """Takes an AST that has already been parsed and adds CopyNodes where
    appropriate to an AstSet that contains that AST."""
    unparse = unparser.to_string(ast)
    df_ast_pointers = list(ast.depth_first_iter())
    df_ast_nodes = [pointer.cur_node for pointer in ast.depth_first_iter()]
    df_ast_set = list(
        depth_first_iterate_ast_set_along_path(ast_set, df_ast_nodes))
    assert len(df_ast_nodes) == len(df_ast_set)
    for pointer, cur_set in zip(df_ast_pointers, df_ast_set):
        if isinstance(pointer.cur_node, ObjectChoiceNode):
            # TODO (DNGros): Figure out if we are handling weight and
            # probability right. I think it works fine now if known valid.
            _try_add_copy_node_at_object_choice(
                pointer, cur_set, True, copy_node_weight, 1,
                unparse, token_metadata)
        elif isinstance(pointer.cur_node, ObjectNode):
            pass
        else:
            raise ValueError("Unrecognized node?")

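# Usage sketch following test_touch_set above: augmenting an existing
# AstObjectChoiceSet so that copy-based versions of a known-valid tree also
# validate against it. `tc`, `ast`, `unparser`, and `x_str` are as in that test.
cset = AstObjectChoiceSet(tc.get_type_by_name("Program"))
cset.add(ast, True, 1, 1)
_, tok_metadata = NonLetterTokenizer().tokenize(x_str)
add_copies_to_ast_set(ast, cset, unparser, tok_metadata)
assert cset.is_node_known_valid(
    make_copy_version_of_tree(ast, unparser, tok_metadata))
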
def test_dot_separated_words(tc, in_str):
    parser = StringParser(tc)
    ast = parser.create_parse_tree(in_str, "DotSeparatedWords")
    unparser = AstUnparser(tc)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == in_str

predictions = f.readlines()
with open(args.tgt_yids, 'r') as f:
    yinfo = f.readlines()
if args.json_preds:
    import json
    predictions = [" ".join(json.loads(p)['predicted_tokens'][0])
                   for p in predictions]
# Parse all the strings
xs = list(map(nonascii_untokenize, xs))
predictions = list(map(nonascii_untokenize, predictions))
yids = []
yhashes = []
rsamples = []
for info in yinfo:
    split_info = info.split()
    yids.append(int(split_info[0]))
    yhashes.append(split_info[1])
    rsamples.append(ReplacementSampling.from_serialized_string(
        "".join(split_info[2:])))
# Sort the items so that multiple samples of the same id end up together
yids, xs, predictions, yhashes, rsamples = \
    zip(*sorted(zip(yids, xs, predictions, yhashes, rsamples)))
# Set up for actually predicting
tokenizer = get_tokenizer_by_name(args.tokenizer_name)
type_context, index, replacers, loader = get_examples()
string_parser = StringParser(type_context)
unparser = AstUnparser(type_context, tokenizer)
y_asts = get_y_ast_sets(xs, yids, yhashes, rsamples, tokenizer, index)
eval_stuff(xs, ys, predictions, y_asts, string_parser, unparser)

for f in ALL_EXAMPLE_NAMES:
    loader.load_path(f"builtin_types/{f}.ainix.yaml")
type_context.finalize_data()
index = load_all_examples(type_context)
#index = load_tellina_examples(type_context)
print("num docs", index.get_num_x_values())
print("num train", len(list(index.get_all_x_values((DataSplits.TRAIN,)))))
replacers = get_all_replacers()
model = full_ret_from_example_store(index, replacers, pretrained_checkpoint_path)
unparser = AstUnparser(type_context, model.get_string_tokenizer())
nb_models = model.nb_models
program_nb = nb_models['Program']
print(program_nb._model.sigma_)
print(program_nb._model.theta_)
#program_nb._model.sigma_ *= 100
while True:
    q = input("Query: ")
    summary, mem, tokens = model.embedder([q])
    summary_and_depth = torch.cat((summary, torch.tensor([[2.0]])), dim=1)
    s1 = summary_and_depth[0].data.numpy()
    std_devs = np.sqrt(program_nb._model.sigma_)
    diffs_from_mean = program_nb._model.theta_ - s1

def test_path_parse_extension(tc, in_str):
    parser = StringParser(tc)
    ast = parser.create_parse_tree(in_str, "FileExtension")
    unparser = AstUnparser(tc)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == in_str

def test_path_list_parse_and_unparse_without_error(tc, in_str):
    parser = StringParser(tc)
    ast = parser.create_parse_tree(in_str, "PathList")
    unparser = AstUnparser(tc)
    to_string = unparser.to_string(ast)
    assert to_string.total_string == in_str

argparer.add_argument("--replace_samples", type=int, default=1) default_split_train = DEFAULT_SPLITS[0] assert default_split_train[1] == DataSplits.TRAIN argparer.add_argument("--train_percent", type=float, default=default_split_train[0]*100) argparer.add_argument("--randomize_seed", action='store_true') args = argparer.parse_args() train_frac = args.train_percent / 100.0 split_proportions = ((train_frac, DataSplits.TRAIN), (1-train_frac, DataSplits.VALIDATION)) type_context, index, replacers, loader = get_examples( split_proportions, randomize_seed=args.randomize_seed) index.get_all_x_values(()) string_parser = StringParser(type_context) tokenizer, vocab = get_default_pieced_tokenizer_word_list() unparser = AstUnparser(type_context, tokenizer) non_ascii_tokenizer = NonLetterTokenizer() def non_asci_do(string): toks, metad = non_ascii_tokenizer.tokenize(string) return " ".join(toks) split_to_sentences = defaultdict(list) num_to_do = args.replace_samples*index.get_num_x_values() data_iterator = itertools.chain.from_iterable( (iterate_data_pairs( index, replacers, string_parser, tokenizer, unparser, None) for epoch in range(args.replace_samples)) ) data_iterator = itertools.islice(data_iterator, num_to_do) for (example, this_example_replaced_x, y_ast_set, teacher_force_path_ast, y_texts, rsample) in tqdm(data_iterator, total=num_to_do):
argparer.add_argument("--randomize_seed", action='store_true') argparer.add_argument("--nointeractive", action='store_true') argparer.add_argument("--eval_replace_samples", type=int, default=5) argparer.add_argument("--replace_samples", type=int, default=REPLACEMENT_SAMPLES) argparer.add_argument("--encoder_name", type=str, default="CM") args = argparer.parse_args() train_frac = args.train_percent / 100.0 split_proportions = ((train_frac, DataSplits.TRAIN), (1 - train_frac, DataSplits.VALIDATION)) model, index, replacers, type_context, loader = train_the_thing( split_proportions, args.randomize_seed, args.replace_samples, args.encoder_name) unparser = AstUnparser(type_context, model.get_string_tokenizer()) tran_trainer = TypeTranslateCFTrainer(model, index, replacer=replacers, loader=loader) logger = EvaluateLogger() tran_trainer.evaluate(logger, dump_each=True, num_replace_samples=args.eval_replace_samples) print_ast_eval_log(logger) if not args.nointeractive: while True: q = input("Query: ") ast, metad = model.predict(q, "CommandSequence", True)