Example No. 1
def get_results(file, logger):
    """Grab all the results according to the hyperparameter file."""
    results = []
    params = []
    num_epochs_retraining = []
    # Loop through all experiments
    for param in util.file.get_parameters(file, 1, 0):
        # get number of retraining epochs
        n_e = param["generated"]["retraining"]["numEpochs"]
        n_e -= param["generated"]["retraining"]["startEpoch"]

        # initialize logger and setup parameters
        logger.initialize_from_param(param)
        # run the experiment (only if necessary)
        try:
            state = logger.get_global_state()
        except ValueError:
            experiment.Evaluator(logger).run()
            state = logger.get_global_state()
        # extract the results
        results.append(copy.deepcopy(state))
        params.append(copy.deepcopy(param))
        num_epochs_retraining.append(n_e)
    return (
        OrderedDict(zip(num_epochs_retraining, results)),
        OrderedDict(zip(num_epochs_retraining, params)),
    )
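
A minimal usage sketch for this variant, assuming the experiment.Logger class shown in the later snippets; the file name is only a placeholder for an actual hyperparameter file. Both returned OrderedDicts are keyed by the number of retraining epochs, so they can be traversed in parallel.

logger = experiment.Logger()
results_by_epochs, params_by_epochs = get_results("my_sweep.yaml", logger)  # placeholder file

# both OrderedDicts share the same keys: the number of retraining epochs
for n_epochs in results_by_epochs:
    param = params_by_epochs[n_epochs]
    print(n_epochs, param["generated"]["retraining"]["numEpochs"])
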
Example No. 2
def get_results(file, logger):
    """Grab all the results according to the hyperparameter file."""
    results = []
    params = []
    labels = []
    # Loop through all experiments
    for param in get_parameters(file, 1, 0):
        # initialize logger and setup parameters
        logger.initialize_from_param(param)
        # run the experiment (only if necessary)
        try:
            state = logger.get_global_state()
        except ValueError:
            experiment.Evaluator(logger).run()
            state = logger.get_global_state()
        # extract the results
        results.append(copy.deepcopy(state))
        params.append(copy.deepcopy(param))
        # extract the legend (based on heuristic)
        label = param["generated"]["datasetTest"].split("_")
        if len(label) > 2:
            label = label[2:]
        labels.append("_".join(label))
        # extract the plots
        graphers = logger.generate_plots(store_figs=False)

        if "imagenet" in file:
            graphers[0].graph(
                percentage_x=True,
                percentage_y=True,
                store=False,
                remove_outlier=False,
            )

        if "_rand" in file or "retraininit" in file:
            for i, grapher in enumerate(graphers[:6]):
                percentage_y = bool((i + 1) % 3)
                grapher.graph(
                    percentage_x=True,
                    percentage_y=percentage_y,
                    store=False,
                    show_ref=False,
                    show_delta=False,
                    remove_outlier=False,
                )
                if percentage_y:
                    # grapher._figure.gca().set_xlim([50, 99])
                    grapher._figure.gca().set_ylim([80, 95])

    # note: "graphers" holds the plots of the last experiment in the loop only
    return results, params, labels, graphers
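
Since each grapher exposes its matplotlib figure through the private _figure attribute (used above via grapher._figure.gca()), the returned plots can be saved afterwards. A sketch under that assumption; the file names are illustrative only:

logger = experiment.Logger()
results, params, labels, graphers = get_results("my_sweep.yaml", logger)  # placeholder file

for i, grapher in enumerate(graphers):
    # grapher._figure is the matplotlib figure accessed via .gca() above
    grapher._figure.savefig("figure_{}.pdf".format(i), bbox_inches="tight")
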
Example No. 3
def get_results(file, logger):
    """Grab all the results according to the parameter file."""
    # Loop through all experiments
    for param in get_parameters(file, 1, 0):
        # initialize logger and setup parameters
        logger.initialize_from_param(param)
        # get evaluator
        evaluator = experiment.Evaluator(logger)
        # run the experiment (only if necessary)
        try:
            state = logger.get_global_state()
        except ValueError:
            evaluator.run()
            state = logger.get_global_state()

        # return the evaluator and state of the first experiment only
        return evaluator, state
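
A hypothetical call of this variant; the parameter file name is a placeholder. Because of the early return, only the first experiment listed in the file is evaluated.

evaluator, state = get_results("my_params.yaml", experiment.Logger())  # placeholder file
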
Example No. 4
def get_results(file, logger, matches_exact, matches_partial):
    """Grab all the results according to the hyperparameter file."""
    results = []
    params = []
    labels = []
    # Loop through all experiments
    for param in util.file.get_parameters(file, 1, 0):
        # extract the legend (based on heuristic)
        label = param["generated"]["datasetTest"].split("_")
        if len(label) > 2:
            label = label[2:]
        label = "_".join(label)

        # check if label conforms to any of the provided filters
        if not (
            np.any([str(mtc) == label for mtc in matches_exact]).item()
            or np.any([str(mtc) in label for mtc in matches_partial]).item()
        ):
            continue

        # remove severity from label once identified
        label = label if "CIFAR10" in label else label.split("_")[0]

        # initialize logger and setup parameters
        logger.initialize_from_param(param)
        # run the experiment (only if necessary)
        try:
            state = logger.get_global_state()
        except ValueError:
            experiment.Evaluator(logger).run()
            state = logger.get_global_state()
        # extract the results
        results.append(copy.deepcopy(state))
        params.append(copy.deepcopy(param))
        labels.append(label)
    return OrderedDict(zip(labels, results)), OrderedDict(zip(labels, params))
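
A sketch of how the two filter arguments might be passed; the file name and filter strings are illustrative only, chosen to resemble the CIFAR10-style labels handled above.

results, params = get_results(
    "my_sweep.yaml",                 # placeholder hyperparameter file
    experiment.Logger(),
    matches_exact=["CIFAR10"],       # keep labels that equal one of these ...
    matches_partial=["Corrupted"],   # ... or contain one of these substrings
)
print(list(results.keys()))  # the labels that passed the filter
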
Example No. 5
# NOTE: the snippet begins mid-conditional; the opening branch is assumed to be the
# ImageNet case, since these are the standard ImageNet normalization constants.
if "ImageNet" in PARAM["network"]["dataset"]:
    MEAN_C = [0.485, 0.456, 0.406]
    STD_C = [0.229, 0.224, 0.225]
elif "CIFAR" in PARAM["network"]["dataset"]:
    MEAN_C = [0.4914, 0.4822, 0.4465]
    STD_C = [0.2023, 0.1994, 0.2010]
else:
    raise ValueError("Please adapt script to provide normalization of dset!")
MEAN_C = np.asarray(MEAN_C)[:, np.newaxis, np.newaxis]
STD_C = np.asarray(STD_C)[:, np.newaxis, np.newaxis]

# Now initialize the logger
LOGGER = experiment.Logger()
LOGGER.initialize_from_param(PARAM)

# Initialize the evaluator and run it (will return if nothing to compute)
COMPRESSOR = experiment.Evaluator(LOGGER)
COMPRESSOR.run()

# device settings
torch.cuda.set_device("cuda:0")
DEVICE = torch.device("cuda:0")
DEVICE_STORAGE = torch.device("cpu")

# Generate all the models we like.
# get a list of models
MODELS = [COMPRESSOR.get_all(**kw) for kw in MODELS_ALL_DESCRIPTION]
IDX_RANGES = []
NUM_TOTAL = 0
for model_array in MODELS:
    IDX_RANGES.append(np.arange(NUM_TOTAL, NUM_TOTAL + len(model_array)))
    NUM_TOTAL += len(model_array)
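
MEAN_C and STD_C are reshaped to (3, 1, 1) above so they broadcast over channel-first images. A minimal sketch of how they could be used to un-normalize an image for visualization, relying on the np alias and constants already defined in this script; the image tensor is synthetic.

# synthetic normalized image in (C, H, W) layout, for illustration only
IMG_NORMALIZED = np.random.randn(3, 32, 32).astype(np.float32)

# undo the per-channel normalization; MEAN_C and STD_C broadcast from shape (3, 1, 1)
IMG = IMG_NORMALIZED * STD_C + MEAN_C
IMG = np.clip(IMG, 0.0, 1.0)  # clip to a valid pixel range before plotting
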
Example No. 6
def get_results(file, logger, legend_on):
    """Grab all the results according to the hyperparameter file."""
    results = []
    params = []
    labels = []
    graphers_all = []
    # Loop through all experiments
    for param in get_parameters(file, 1, 0):
        # initialize logger and setup parameters
        logger.initialize_from_param(param)
        # run the experiment (only if necessary)
        try:
            state = logger.get_global_state()
        except ValueError:
            experiment.Evaluator(logger).run()
            state = logger.get_global_state()
        # extract the results
        results.append(copy.deepcopy(state))
        params.append(copy.deepcopy(param))
        # extract the legend (based on heuristic)
        label = param["generated"]["datasetTest"].split("_")
        if len(label) > 2:
            label = label[2:]
        labels.append("_".join(label))
        # extract the plots
        graphers = logger.generate_plots(store_figs=False)

        # modify label of x-axis
        graphers[0]._figure.gca().set_xlabel("Compression Ratio (Parameters)")

        if "cifar/retraininit" in file:
            for i, grapher in enumerate(graphers[:6]):
                percentage_y = bool((i + 1) % 3)
                grapher.graph(
                    percentage_x=True,
                    percentage_y=percentage_y,
                    store=False,
                    show_ref=False,
                    show_delta=False,
                    remove_outlier=False,
                )
                if percentage_y:
                    grapher._figure.gca().set_ylim([86, 98])
        elif "cifar/prune" in file and "_plus" in file:
            graphers[0]._figure.gca().set_xlim([20, 65])
            graphers[0]._figure.gca().set_ylim([-61, 2])
        elif "cifar/prune" in file:
            graphers[0]._figure.gca().set_ylim([-87, 5])
        elif "imagenet/prune" in file and "_plus" in file:
            graphers[0]._figure.gca().set_xlim([39, 81])
            graphers[0]._figure.gca().set_ylim([-61, 2])
        elif "imagenet/prune" in file:
            graphers[0]._figure.gca().set_xlim([0, 87])
            graphers[0]._figure.gca().set_ylim([-87, 5])
        elif "imagenet/retrain/" in file:
            graphers[0]._figure.gca().set_ylim([-3.5, 1.5])
        elif "imagenet/retraincascade" in file:
            # graphers[0]._figure.gca().set_xlim([-11, 2])
            graphers[0]._figure.gca().set_ylim([-2.5, 1])
        elif "imagenet/retrain" in file:
            graphers[0]._figure.gca().set_ylim([-11, 2])
        elif "cifar/retrainablation/" in file:
            graphers[0]._figure.gca().set_ylim([-3.2, 1.2])
        elif "cifar/retrain/densenet" in file:
            graphers[0]._figure.gca().set_xlim([14, 78])
            graphers[0]._figure.gca().set_ylim([-3.5, 1.5])
        elif "cifar/retrain/vgg" in file:
            graphers[0]._figure.gca().set_xlim([70, 97.5])
            graphers[0]._figure.gca().set_ylim([-3.5, 1.5])
        elif "cifar/retrain/wrn" in file:
            graphers[0]._figure.gca().set_xlim([85, 97.5])
            graphers[0]._figure.gca().set_ylim([-2.5, 0.0])
        elif "cifar/retrain/" in file:
            # graphers[0]._figure.gca().set_ylim([-11, 3])
            graphers[0]._figure.gca().set_ylim([-3.5, 1.5])
        elif "cifar/retrainlittle/" in file:
            graphers[0]._figure.gca().set_ylim([-3.5, 1.5])
        elif "cifar/retrain" in file:
            graphers[0]._figure.gca().set_ylim([-11, 2])
        elif "voc/prune" in file:
            graphers[0]._figure.gca().set_xlim([0, 90])
            graphers[0]._figure.gca().set_ylim([-87, 5])
        elif "voc/retrain" in file:
            graphers[0]._figure.gca().set_ylim([-3, 2])

        for grapher in graphers:
            legend = grapher._figure.gca().get_legend()
            if legend is not None:
                grapher._figure.gca().get_legend().remove()
                legend.set_bbox_to_anchor((1.1, 0.7))

        if legend_on:
            graphers[0].graph(
                percentage_x=True,
                percentage_y=True,
                store=False,
                kwargs_legend={
                    "loc": "upper left",
                    "ncol": 1,
                    "bbox_to_anchor": (1.1, 0.9),
                },
            )

        graphers_all.append(graphers)

    return results, params, labels, graphers_all