def get_a_tc() -> TypeContext:
    type_context = TypeContext()
    loader = TypeContextDataLoader(type_context, up_search_limit=4)
    # Load the core built-in type definitions, plus the types backing each example set.
    loader.load_path("builtin_types/generic_parsers.ainix.yaml")
    loader.load_path("builtin_types/command.ainix.yaml")
    loader.load_path("builtin_types/paths.ainix.yaml")
    allspecials.load_all_special_types(type_context)
    for f in ALL_EXAMPLE_NAMES:
        loader.load_path(f"builtin_types/{f}.ainix.yaml")
    type_context.finalize_data()
    return type_context
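# Hedged usage sketch (not part of the original module): shows what the TypeContext
# from get_a_tc() is for. It assumes StringParser, AstUnparser, and the "Path" type
# are available, exactly as they are used in test_file_replacer below; the default
# path string is a made-up example value.
def _example_parse_roundtrip_sketch(path_string: str = "./notes.txt") -> str:
    tc = get_a_tc()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    # Parse the string into an AST of the "Path" type, then unparse it back.
    ast = parser.create_parse_tree(path_string, "Path")
    return unparser.to_string(ast).total_string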
def serialize(
    model: StringTypeTranslateCF,
    loader: TypeContextDataLoader,
    save_path: str,
    eval_results: EvaluateLogger = None,
    trained_epochs: int = None
):
    # Bundle the model state, the type loader state, and training metadata into one dict.
    ser = {
        "version": 0,
        "model": model.get_save_state_dict(),
        "type_loader": loader.get_save_state_dict(),
        "eval_results": eval_results,
        "trained_epochs": trained_epochs
    }
    torch.save(ser, save_path)
def get_examples(
    split_proportions: SPLIT_PROPORTIONS_TYPE = DEFAULT_SPLITS,
    randomize_seed: bool = False
):
    type_context = TypeContext()
    loader = TypeContextDataLoader(type_context, up_search_limit=4)
    loader.load_path("builtin_types/generic_parsers.ainix.yaml")
    loader.load_path("builtin_types/command.ainix.yaml")
    loader.load_path("builtin_types/paths.ainix.yaml")
    allspecials.load_all_special_types(type_context)
    for f in ALL_EXAMPLE_NAMES:
        loader.load_path(f"builtin_types/{f}.ainix.yaml")
    type_context.finalize_data()
    # random.randint requires integer bounds, so cast the 1e8 float.
    split_seed = None if not randomize_seed else random.randint(1, int(1e8))
    index = load_all_examples(type_context, split_proportions, split_seed)
    #index = load_tellina_examples(type_context)
    #index = load_all_and_tellina(type_context)
    #print("num docs", index.get_num_x_values())
    #print("num train", len(list(index.get_all_x_values((DataSplits.TRAIN, )))))
    replacers = get_all_replacers()
    return type_context, index, replacers, loader
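# Hedged sketch (not in the original source): one way the values returned by
# get_examples() might be inspected. The index methods used here mirror the
# calls that appear commented out inside get_examples() above.
def _print_example_counts_sketch():
    type_context, index, replacers, loader = get_examples()
    print("num docs", index.get_num_x_values())
    print("num train", len(list(index.get_all_x_values((DataSplits.TRAIN,)))))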
def test_file_replacer():
    replacements = _load_replacer_relative(
        "../../../training/augmenting/data/FILENAME.tsv")
    tc = TypeContext()
    loader = TypeContextDataLoader(tc, up_search_limit=4)
    loader.load_path("builtin_types/generic_parsers.ainix.yaml")
    loader.load_path("builtin_types/command.ainix.yaml")
    loader.load_path("builtin_types/paths.ainix.yaml")
    allspecials.load_all_special_types(tc)
    tc.finalize_data()
    parser = StringParser(tc)
    unparser = AstUnparser(tc)
    for repl in replacements:
        x, y = repl.get_replacement()
        assert x == y
        # Every filename replacement should round-trip through parse and unparse.
        ast = parser.create_parse_tree(x, "Path")
        result = unparser.to_string(ast)
        assert result.total_string == x
def restore(file_name) -> Tuple[TypeContext, StringTypeTranslateCF, ExamplesStore]:
    save_dict = torch.load(file_name)
    type_context, loader = TypeContextDataLoader.restore_from_save_dict(
        save_dict['type_loader'])
    allspecials.load_all_special_types(type_context)
    type_context.finalize_data()
    need_example_store = save_dict['model'].get('need_example_store', False)
    if need_example_store:
        # TODO (DNGros): smart restoring.
        example_store = load_all_examples(type_context)
    else:
        example_store = None
    # The model name lives inside the 'model' entry written by serialize() above.
    model_name = save_dict['model'].get('name', None)
    if model_name == "fullret":
        from ainix_kernel.models.Fullretrieval.fullretmodel import FullRetModel
        model = FullRetModel.create_from_save_state_dict(
            save_dict['model'], type_context, example_store)
    elif model_name == 'EncoderDecoder':
        from ainix_kernel.models.EncoderDecoder.encdecmodel import EncDecModel
        model = EncDecModel.create_from_save_state_dict(
            save_dict['model'], type_context, example_store)
    else:
        raise ValueError(f"Unrecognized model name {model_name}")
    model.end_train_session()
    return type_context, model, example_store
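# Hedged round-trip sketch (not in the original source): serialize() above writes a
# dict with torch.save, and restore() rebuilds the type context, model, and (when
# needed) example store from that same file. The already-trained `model` and
# `loader` objects and the checkpoint path are assumed to exist.
def _save_and_restore_sketch(model, loader, save_path="model_checkpoint.pt"):
    serialize(model, loader, save_path)
    type_context, restored_model, example_store = restore(save_path)
    return restored_model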
if __name__ == "__main__":
    pretrained_checkpoint_path = "../../checkpoints/" \
        "lmchkp_iter152k_200_2rnn_total3.29_ns0.47_lm2.82.pt"
    output_size = 200
    (x_tokenizer, query_vocab), y_tokenizer = _get_default_tokenizers()
    base_enc = make_default_cookie_monster_base(query_vocab, output_size)
    model = PretrainPoweredQueryEncoder.create_with_pretrained_checkpoint(
        pretrained_checkpoint_path, x_tokenizer, query_vocab, output_size,
        freeze_base=True)
    model.eval()
    type_context = TypeContext()
    loader = TypeContextDataLoader(type_context, up_search_limit=4)
    loader.load_path("builtin_types/generic_parsers.ainix.yaml")
    loader.load_path("builtin_types/command.ainix.yaml")
    loader.load_path("builtin_types/paths.ainix.yaml")
    allspecials.load_all_special_types(type_context)
    for f in ALL_EXAMPLE_NAMES:
        loader.load_path(f"builtin_types/{f}.ainix.yaml")
    type_context.finalize_data()
    index = load_all_examples(type_context)
    #index = load_tellina_examples(type_context)
    print("num docs", index.backend.index.doc_count())
    replacers = get_all_replacers()