def load_experiment(data_folder="../data/data_cost/files"):
    # Make the run reproducible and silence known-noisy warnings.
    manual_seed(42)
    np.random.seed(42)
    filterwarnings(action="ignore", category=DeprecationWarning, module=r".*")
    filterwarnings(action="ignore", module=r"torch.quantization")
    filterwarnings(action="ignore", category=UserWarning)

    # Image data, search space, and (unused here) quantization/collate settings.
    datasets, n_classes = prepare_cifar2(folder=data_folder)
    sspace = search_space()
    quant_params = None
    collate_fn = None

    # Read the [TRAIN] section of the configuration and prepare the output root.
    args = configuration("TRAIN")
    if not path.exists(args["ROOT"]):
        mkdir(args["ROOT"])
    time_init = time()

    sparse_exp = SparseExperiment(
        name=str(args["NAME"]),
        root=args["ROOT"],
        objectives=int(args["OBJECTIVES"]),
        pruning=bool_converter(args["PRUNING"]),
        epochs=int(args["EPOCHS1"]),
        datasets=datasets,
        classes=n_classes,
        search_space=sspace,
        net=Net,
        flops=int(args["FLOPS"]),
        quant_scheme=str(args["QUANT_SCHEME"]),
        quant_params=quant_params,
        collate_fn=collate_fn,
        splitter=bool_converter(args["SPLITTER"]),
        models_path=path.join(args["ROOT"], "models"),
    )
    exp, data = sparse_exp.create_load_experiment()
    return exp, data
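# Hedged usage sketch (not part of the original script): a thin wrapper showing
# how load_experiment() is meant to be driven. The wrapper name, the "verbose"
# flag, and the timing printout are illustrative assumptions.
def run_load_experiment(data_folder="../data/data_cost/files", verbose=True):
    start = time()
    exp, data = load_experiment(data_folder=data_folder)
    if verbose:
        # load_experiment() seeds the RNGs, builds the CIFAR-2 datasets, reads
        # the [TRAIN] configuration, and delegates to
        # SparseExperiment.create_load_experiment(); here we only report how
        # long that setup took.
        print(f"Experiment ready in {time() - start:.1f}s")
    return exp, data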
if __name__ == "__main__":
    # Make the run reproducible and silence known-noisy warnings.
    manual_seed(42)
    np.random.seed(42)
    filterwarnings(action="ignore", category=DeprecationWarning, module=r".*")
    filterwarnings(action="ignore", module=r"torch.quantization")
    filterwarnings(action="ignore", category=UserWarning)

    # Non-image cost data, search space, and dynamic-quantization target layers.
    datasets, n_classes = prepare_cost(folder="../data/data_cost/files", image=False)
    search_space = search_space()
    quant_params = {nn.LSTM, nn.Linear, nn.GRU}
    collate_fn = split_pad_n_pack

    # Only train when the [DEFAULT] section enables it.
    if bool_converter(configuration("DEFAULT")["TRAIN"]):
        args = configuration("TRAIN")
        if not path.exists(args["ROOT"]):
            mkdir(args["ROOT"])
        time_init = time()

        sparse_instance = Sparse(
            r1=int(args["R1"]),
            r2=int(args["R2"]),
            r3=int(args["R3"]),
            epochs1=int(args["EPOCHS1"]),
            epochs2=int(args["EPOCHS2"]),
            epochs3=int(args["EPOCHS3"]),
            name=str(args["NAME"]),
            root=args["ROOT"],
            objectives=str_to_list(args["OBJECTIVES"]),
            batch_size=int(args["BATCH_SIZE"]),
            morphisms=bool_converter(args["MORPHISMS"]),