Пример #1
0
 def leave_one_out(self, subjects=None,
                   data_folder="../data/data_cost/files/"):
     """Evaluate the current arm with leave-one-subject-out splits.

     For every held-out subject the data is re-split with
     ``prepare_cost(test_subjects=[subject], folder=data_folder)``, a new
     ``AccuracyMetric`` is built for that split, the network is reloaded
     with the arm's parameters via ``self.reload_net`` and the value
     returned by ``train_evaluate`` is collected.  All collected values are
     written with numpy ``save`` to ``str(self.arm) + "_test_result.npy"``.

     Args:
         subjects: Iterable of subject ids to hold out, one per fold.
             ``None`` (the default) means ``range(1, 32)``, i.e. ids 1..31,
             which is exactly what this method iterated before the
             parameter existed.
         data_folder: Data folder forwarded to ``prepare_cost``; the
             default is the previously hard-coded path.

     Returns:
         None. Results are only written to disk.
     """
     if subjects is None:
         # NOTE(review): historical behaviour is 31 folds (ids 1..31) even
         # though the old code named its bound ``n_subjects = 32``;
         # confirm whether subject 0 or 32 is left out on purpose.
         subjects = range(1, 32)
     result_name = str(self.arm) + "_test_result"
     results = []
     for subject in tqdm(subjects):
         datasets, n_classes = prepare_cost(test_subjects=[subject],
                                            folder=data_folder)
         params = self.arm_parameters()
         net = copy.copy(self.net)
         acc_metric = AccuracyMetric(self.epochs, "Accuracy Test",
                                     self.pruning, datasets, n_classes,
                                     self.net, self.quant_scheme,
                                     self.quant_params, self.collate_fn,
                                     self.splitter, self.models_path)
         # A separate arm_parameters() call is kept so the metric does not
         # share the ``params`` object used below (assumes it returns a
         # fresh object each call; TODO confirm).
         acc_metric.parametrization = self.arm_parameters()
         acc_metric.reload = True
         collate_fn = copy.copy(self.collate_fn)
         if self.splitter:
             # Bind the arm's ``max_len`` into the collate function.
             collate_fn = partial(collate_fn, max_len=params.get("max_len"))
         acc_metric.trainer.load_dataloaders(params.get("batch_size", 4),
                                             collate_fn=collate_fn)
         input_shape = get_shape_from_dataloader(
             acc_metric.trainer.dataloader["train"], params)
         acc_metric.old_net = self.reload_net(params, n_classes,
                                              input_shape, net)
         results.append(acc_metric.train_evaluate(result_name))
     save(result_name + ".npy", asarray(results))
Пример #2
0
def load_experiment(data_folder="../data/data_cost/files"):
    """Build (or reload) the sparse-search experiment from the TRAIN config.

    Seeds torch and numpy with 42, silences deprecation, user and
    ``torch.quantization`` warnings, loads the COST data with
    ``image=True`` from *data_folder*, makes sure the configured ``ROOT``
    directory exists and delegates to
    ``SparseExperiment.create_load_experiment``.

    Args:
        data_folder: Folder passed to ``prepare_cost``.

    Returns:
        Tuple ``(exp, data)`` as produced by
        ``SparseExperiment.create_load_experiment``.
    """
    import os

    # Fixed seeds so repeated runs are reproducible.
    manual_seed(42)
    np.random.seed(42)

    filterwarnings(action="ignore", category=DeprecationWarning, module=r".*")
    filterwarnings(action="ignore", module=r"torch.quantization")
    filterwarnings(action="ignore", category=UserWarning)

    datasets, n_classes = prepare_cost(folder=data_folder, image=True)
    sspace = search_space()
    quant_params = None
    collate_fn = split_arrange_pad_n_pack

    args = configuration("TRAIN")
    # makedirs also creates missing parent directories and tolerates the
    # directory appearing concurrently, unlike exists() followed by mkdir().
    os.makedirs(args["ROOT"], exist_ok=True)
    sparse_exp = SparseExperiment(name=str(args["NAME"]),
                                  root=args["ROOT"],
                                  objectives=int(args["OBJECTIVES"]),
                                  pruning=bool_converter(args["PRUNING"]),
                                  # Config values are strings; convert like
                                  # every other numeric option here.
                                  epochs=int(args["epochs1"]),
                                  datasets=datasets,
                                  classes=n_classes,
                                  search_space=sspace,
                                  net=Net,
                                  flops=int(args["FLOPS"]),
                                  quant_scheme=str(args["QUANT_SCHEME"]),
                                  quant_params=quant_params,
                                  collate_fn=collate_fn,
                                  splitter=bool_converter(args["SPLITTER"]),
                                  models_path=os.path.join(
                                      args["ROOT"], "models"))
    exp, data = sparse_exp.create_load_experiment()
    return exp, data
Пример #3
0
from sparsemod.utils_data import configuration, bool_converter, str_to_list
from sparsemod.sparse import Sparse
from rnn import search_space, Net, operations, split_pad_n_pack
from load_data import prepare_cost
from test import ModelTester

if __name__ == "__main__":

    manual_seed(42)
    np.random.seed(42)

    filterwarnings(action="ignore", category=DeprecationWarning, module=r".*")
    filterwarnings(action="ignore", module=r"torch.quantization")
    filterwarnings(action="ignore", category=UserWarning)

    datasets, n_classes = prepare_cost(folder="../data/data_cost/files",
                                       image=False)
    search_space = search_space()
    quant_params = {nn.LSTM, nn.Linear, nn.GRU}
    collate_fn = split_pad_n_pack

    if bool_converter(configuration("DEFAULT")["TRAIN"]):
        args = configuration("TRAIN")
        if not path.exists(args["ROOT"]):
            mkdir(args["ROOT"])
        time_init = time()
        sparse_instance = Sparse(r1=int(args["R1"]),
                                 r2=int(args["R2"]),
                                 r3=int(args["R3"]),
                                 epochs1=int(args["EPOCHS1"]),
                                 epochs2=int(args["EPOCHS2"]),
                                 epochs3=int(args["EPOCHS3"]),