Example #1
0
def load_experiment(data_folder="../data/data_cost/files"):
    """Create (or reload) a seeded SpArSe experiment from the TRAIN config.

    Seeds torch and NumPy with a fixed value, silences noisy warnings,
    prepares the dataset found in *data_folder*, ensures the experiment
    root directory exists, and builds a ``SparseExperiment`` from the
    options in the ``"TRAIN"`` configuration section.

    Args:
        data_folder: Folder handed to ``prepare_cifar2`` as ``folder``.
            NOTE(review): the default points at the COST data directory
            while the loader used is ``prepare_cifar2`` — confirm intended.

    Returns:
        Tuple ``(exp, data)`` as produced by
        ``SparseExperiment.create_load_experiment()``.
    """
    import os

    # Fixed seeds so repeated loads build the same experiment.
    manual_seed(42)
    np.random.seed(42)

    filterwarnings(action="ignore", category=DeprecationWarning, module=r".*")
    filterwarnings(action="ignore", module=r"torch.quantization")
    filterwarnings(action="ignore", category=UserWarning)

    datasets, n_classes = prepare_cifar2(folder=data_folder)
    sspace = search_space()
    quant_params = None
    collate_fn = None

    args = configuration("TRAIN")
    # makedirs(exist_ok=True) avoids the exists()/mkdir() race and also
    # creates any missing parent directories.
    os.makedirs(args["ROOT"], exist_ok=True)

    sparse_exp = SparseExperiment(
        name=str(args["NAME"]),
        root=args["ROOT"],
        objectives=int(args["OBJECTIVES"]),
        pruning=bool_converter(args["PRUNING"]),
        # Config values appear to be strings (see bool_converter/int on the
        # other options), so convert the epoch count like the other numeric
        # settings; "EPOCHS1" matches how the TRAIN section is read elsewhere.
        epochs=int(args["EPOCHS1"]),
        datasets=datasets,
        classes=n_classes,
        search_space=sspace,
        net=Net,
        flops=int(args["FLOPS"]),
        quant_scheme=str(args["QUANT_SCHEME"]),
        quant_params=quant_params,
        collate_fn=collate_fn,
        splitter=bool_converter(args["SPLITTER"]),
        models_path=os.path.join(args["ROOT"], "models"),
    )
    exp, data = sparse_exp.create_load_experiment()
    return exp, data
Example #2
0
File: main.py  Project: BCJuan/SpArSeMod
if __name__ == "__main__":

    manual_seed(42)
    np.random.seed(42)

    filterwarnings(action="ignore", category=DeprecationWarning, module=r".*")
    filterwarnings(action="ignore", module=r"torch.quantization")
    filterwarnings(action="ignore", category=UserWarning)

    datasets, n_classes = prepare_cost(folder="../data/data_cost/files",
                                       image=False)
    search_space = search_space()
    quant_params = {nn.LSTM, nn.Linear, nn.GRU}
    collate_fn = split_pad_n_pack

    if bool_converter(configuration("DEFAULT")["TRAIN"]):
        args = configuration("TRAIN")
        if not path.exists(args["ROOT"]):
            mkdir(args["ROOT"])
        time_init = time()
        sparse_instance = Sparse(r1=int(args["R1"]),
                                 r2=int(args["R2"]),
                                 r3=int(args["R3"]),
                                 epochs1=int(args["EPOCHS1"]),
                                 epochs2=int(args["EPOCHS2"]),
                                 epochs3=int(args["EPOCHS3"]),
                                 name=str(args["NAME"]),
                                 root=args["ROOT"],
                                 objectives=str_to_list(args["OBJECTIVES"]),
                                 batch_size=int(args["BATCH_SIZE"]),
                                 morphisms=bool_converter(args["MORPHISMS"]),