Example #1
def evaluate_population(
    inputs, targets, pool, model, criterion, clipvalue=[-np.inf, np.inf]
):
    """Optimisation function of the platform """
    outputs_pool = torch.zeros((len(pool),) + (len(inputs), 1),
                               dtype=TorchUtils.get_data_type(),
                               device=TorchUtils.get_accelerator_type()
                               )
    criterion_pool = torch.zeros(len(pool),
                                 dtype=TorchUtils.get_data_type(),
                                 device=TorchUtils.get_accelerator_type())
    for j in range(len(pool)):

        model.set_control_voltages(pool[j])
        outputs_pool[j] = model(inputs)

        # Penalise genomes whose output clips against the amplifier limits or is
        # completely flat (zero variance), since neither carries useful information.
        if torch.any(outputs_pool[j] <= model.get_clipping_value()[0]) or torch.any(
            outputs_pool[j] >= model.get_clipping_value()[1]
        ) or (outputs_pool[j] - outputs_pool[j].mean() == 0.).all():
            criterion_pool[j] = criterion(None, None, default_value=True)
        else:
            criterion_pool[j] = criterion(outputs_pool[j], targets)

    return outputs_pool, criterion_pool
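
Stripped of the brains-py machinery, the loop above is simply "write the genome into the model, run it, score it". A minimal self-contained sketch of that pattern, using a plain torch.nn.Linear and MSE in place of the processor and its criterion (all names and the clipping range below are illustrative, not library API):

import torch

def evaluate_pool_sketch(inputs, targets, pool, model, clipping=(-1.0, 1.0)):
    outputs = torch.zeros(len(pool), len(inputs), 1)
    fitness = torch.full((len(pool),), float("inf"))  # worst score by default
    with torch.no_grad():  # a genetic algorithm needs no gradients
        for j, genome in enumerate(pool):
            model.bias.copy_(genome)  # stand-in for model.set_control_voltages()
            outputs[j] = model(inputs)
            clipped = (outputs[j] <= clipping[0]).any() or (outputs[j] >= clipping[1]).any()
            flat = (outputs[j] == outputs[j].mean()).all()
            if not (clipped or flat):  # invalid genomes keep the worst score
                fitness[j] = torch.nn.functional.mse_loss(outputs[j], targets)
    return outputs, fitness

model = torch.nn.Linear(2, 1)
inputs, targets = torch.rand(8, 2), torch.rand(8, 1)
pool = torch.rand(5, 1)  # five candidate genomes for the single bias term
print(evaluate_pool_sketch(inputs, targets, pool, model)[1])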
Example #2
def postprocess(configs,
                dataset,
                model,
                criterion,
                logger,
                node=None,
                waveform_transforms=None,
                save_dir=None,
                name="train"):
    results = {}
    with torch.no_grad():
        model.eval()
        inputs, targets = dataset[:]
        indices = torch.argsort(targets[:, 0], dim=0)  # sort samples by their target value
        inputs, targets = inputs[indices], targets[indices]
        if waveform_transforms is not None:
            inputs, targets = waveform_transforms([inputs, targets])
        if inputs.device != TorchUtils.get_accelerator_type():
            inputs = inputs.to(device=TorchUtils.get_accelerator_type())
        if targets.device != TorchUtils.get_accelerator_type():
            targets = targets.to(device=TorchUtils.get_accelerator_type())
        predictions = model(inputs)
        results["performance"] = criterion(predictions, targets)

    results["inputs"] = inputs
    results["targets"] = targets
    results["best_output"] = predictions
    results["accuracy"] = get_accuracy(
        predictions, targets, configs, node=node
    )  # accuracy(predictions.squeeze(), targets.squeeze(), plot=None, return_node=True)
    results["correlation"] = pearsons_correlation(predictions, targets)

    return results
Example #3
    def __init__(self, processor, inputs_list):
        super(DNPU_Base, self).__init__()
        if isinstance(processor, Processor):
            self.processor = processor
        else:
            self.processor = Processor(
                processor
            )  # A configuration dictionary is also accepted for initialising a Processor
        ######### Set up node #########
        # Freeze parameters of node
        for params in self.processor.parameters():
            params.requires_grad = False

        self.indices_node = np.arange(
            len(self.processor.data_input_indices) +
            len(self.processor.control_indices))
        ######### Set learnable parameters #########
        self.control_list = TorchUtils.get_tensor_from_list(
            self.set_controls(inputs_list), data_type=torch.int64)

        ######### Initialise data input ranges #########
        self.data_input_low = torch.stack([
            self.processor.processor.voltage_ranges[indx_cv, 0]
            for indx_cv in inputs_list
        ])
        self.data_input_high = torch.stack([
            self.processor.processor.voltage_ranges[indx_cv, 1]
            for indx_cv in inputs_list
        ])

        ###### Set everything as torch Tensors and send to DEVICE ######
        self.inputs_list = TorchUtils.get_tensor_from_list(
            inputs_list, data_type=torch.int64)
Example #4
def default_validate_gate(gate_base_dir, validation_processor_configs):
    model = torch.load(
        os.path.join(gate_base_dir, 'reproducibility', 'model.pt'),
        map_location=torch.device(TorchUtils.get_accelerator_type()))
    results = torch.load(
        os.path.join(gate_base_dir, 'reproducibility', "results.pickle"),
        map_location=torch.device(TorchUtils.get_accelerator_type()))
    experiment_configs = load_configs(
        os.path.join(gate_base_dir, 'reproducibility', 'configs.yaml'))

    results_dir = init_dirs(gate_base_dir, is_main=True)

    criterion = manager.get_criterion(experiment_configs["algorithm"])

    waveform_transforms = transforms.Compose([
        PlateausToPoints(
            experiment_configs['processor']["data"]['waveform']
        ),  # Plateaus are removed from training first because the perceptron cannot accept fewer than 10 values per gate
        PointsToPlateaus(validation_processor_configs["data"]["waveform"])
    ])

    validate_gate(model,
                  results,
                  validation_processor_configs,
                  criterion,
                  results_dir=results_dir,
                  transforms=waveform_transforms)
Example #5
    def crossover_blxab(self, parent1, parent2):
        """
        Crossover method: Blend alpha beta crossover returns a new genome (voltage combination)
        from two parents. Here, parent 1 has a higher fitness than parent 2
        """

        maximum = torch.max(parent1, parent2)
        minimum = torch.min(parent1, parent2)
        diff_maxmin = maximum - minimum
        offspring = torch.zeros(parent1.shape,
                                dtype=TorchUtils.get_data_type(),
                                device=TorchUtils.get_accelerator_type())
        for i in range(len(parent1)):
            if parent1[i] > parent2[i]:
                offspring[i] = uniform(
                    minimum[i] - diff_maxmin[i] * self.beta,
                    maximum[i] + diff_maxmin[i] * self.alpha,
                ).sample()
            else:
                offspring[i] = uniform(
                    minimum[i] - diff_maxmin[i] * self.alpha,
                    maximum[i] + diff_maxmin[i] * self.beta,
                ).sample()
        # Clamp each gene back into its allowed voltage range.
        for i in range(len(self.gene_range)):
            if offspring[i] < self.gene_range[i][0]:
                offspring[i] = self.gene_range[i][0]
            if offspring[i] > self.gene_range[i][1]:
                offspring[i] = self.gene_range[i][1]
        return offspring
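
The same blend alpha-beta crossover can be written without the per-gene loop. A standalone, vectorized sketch (alpha, beta and gene_range are illustrative values; the optimizer above stores them as attributes):

import torch

def blxab_sketch(parent1, parent2, gene_range, alpha=0.6, beta=0.4):
    maximum = torch.max(parent1, parent2)
    minimum = torch.min(parent1, parent2)
    diff = maximum - minimum
    # Extend the sampling interval further towards the fitter parent (parent1):
    # by alpha on its side of the interval and by beta on the other side.
    fitter_is_max = parent1 > parent2
    low = torch.where(fitter_is_max, minimum - beta * diff, minimum - alpha * diff)
    high = torch.where(fitter_is_max, maximum + alpha * diff, maximum + beta * diff)
    offspring = low + (high - low) * torch.rand_like(parent1)
    # Clamp every gene back into its allowed voltage range.
    return torch.max(torch.min(offspring, gene_range[:, 1]), gene_range[:, 0])

gene_range = torch.tensor([[-1.2, 0.6], [-0.7, 0.3], [-1.0, 1.0], [-1.2, 0.6]])
print(blxab_sketch(torch.rand(4), torch.rand(4), gene_range))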
Example #6
def plot_search_results(label, results, save_dir, extension="png", show_plots=False):
    accuracy_per_run = TorchUtils.get_numpy_from_tensor(results["accuracy_per_run"])
    performance_per_run = TorchUtils.get_numpy_from_tensor(
        results["performance_per_run"]
    )

    plt.figure()
    plt.plot(accuracy_per_run, performance_per_run, "o")
    plt.title("Accuracy vs Fisher (" + label + ")")
    plt.xlabel("Accuracy")
    plt.ylabel("Fisher value")
    plt.savefig(os.path.join(save_dir, "accuracy_vs_fisher_" + label + "." + extension))

    plt.figure()
    plt.hist(performance_per_run, 100)
    plt.title("Histogram of Fisher values (" + label + ")")
    plt.xlabel("Fisher values")
    plt.ylabel("Counts")
    plt.savefig(
        os.path.join(save_dir, "fisher_values_histogram_" + label + "." + extension)
    )

    plt.figure()
    plt.hist(accuracy_per_run, 100)
    plt.title("Histogram of Accuracy values")
    plt.xlabel("Accuracy values")
    plt.ylabel("Counts")
    plt.savefig(os.path.join(save_dir, "accuracy_histogram_" + label + "." + extension))

    if show_plots:
        plt.show()
Example #7
def plot_perceptron(results, save_dir=None, show_plot=False, name="train"):
    fig = plt.figure()
    plt.title(f"Accuracy: {results['accuracy_value']:.2f} %")
    plt.plot(
        TorchUtils.get_numpy_from_tensor(results["norm_inputs"]), label="Norm. Waveform"
    )
    plt.plot(
        TorchUtils.get_numpy_from_tensor(results["predicted_labels"]),
        ".",
        label="Predicted labels",
    )
    plt.plot(TorchUtils.get_numpy_from_tensor(results["targets"]), "g", label="Targets")
    plt.plot(
        np.arange(len(results["predicted_labels"])),
        TorchUtils.get_numpy_from_tensor(
            torch.ones_like(results["predicted_labels"]) * results["norm_threshold"]
        ),
        "k:",
        label="Norm. Threshold",
    )
    plt.legend()
    if show_plot:
        plt.show()
    if save_dir is not None:
        plt.savefig(os.path.join(save_dir, name + "_accuracy.jpg"))
    plt.close()
    return fig
Example #8
    def __call__(self, data):
        inputs, targets = data[0], data[1]
        if inputs.device != TorchUtils.get_accelerator_type():
            inputs = inputs.to(device=TorchUtils.get_accelerator_type())
        if targets.device != TorchUtils.get_accelerator_type():
            targets = targets.to(device=TorchUtils.get_accelerator_type())
        return (inputs, targets)
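
This transform is designed to be chained: torchvision's transforms.Compose simply calls each transform in turn on the same argument, as in this minimal sketch (the ToDevice class here is a hypothetical stand-in for the example above):

import torch
from torchvision import transforms

class ToDevice:
    """Hypothetical stand-in for the device-moving transform above."""

    def __init__(self, device):
        self.device = device

    def __call__(self, data):
        inputs, targets = data
        return inputs.to(self.device), targets.to(self.device)

pipeline = transforms.Compose([ToDevice(torch.device("cpu"))])
inputs, targets = pipeline((torch.rand(4, 2), torch.rand(4, 1)))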
Example #9
def load_reproducibility_results(base_dir, model_name="model.pt"):
    base_dir = os.path.join(base_dir, "reproducibility")
    model = torch.load(os.path.join(base_dir, model_name),
                       map_location=TorchUtils.get_accelerator_type())
    results = torch.load(os.path.join(base_dir, "results.pickle"),
                         map_location=TorchUtils.get_accelerator_type())
    return model, results
Example #10
    def __init__(self, configs):
        self.inputs, targets, self.info_dict = self.load_data(configs["data"])
        self.targets = targets / self.info_dict["processor"]["driver"][
            "amplification"]
        self.inputs = TorchUtils.get_tensor_from_numpy(self.inputs).cpu()
        self.targets = TorchUtils.get_tensor_from_numpy(self.targets).cpu()
        assert len(self.inputs) == len(
            self.targets), "Inputs and targets do not have the same length"
Example #11
    def _init_voltage_ranges(self):
        offset = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["input_data"]["offset"])
        amplitude = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["input_data"]["amplitude"])
        min_voltage = (offset - amplitude).unsqueeze(dim=1)
        max_voltage = (offset + amplitude).unsqueeze(dim=1)
        self.voltage_ranges = torch.cat((min_voltage, max_voltage), dim=1)
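
A small worked example of the range construction above: an offset of -0.3 V with an amplitude of 0.9 V yields the range [-1.2 V, 0.6 V] (values are illustrative):

import torch

offset = torch.tensor([-0.3, 0.1])
amplitude = torch.tensor([0.9, 0.5])
min_voltage = (offset - amplitude).unsqueeze(dim=1)  # [[-1.2], [-0.4]]
max_voltage = (offset + amplitude).unsqueeze(dim=1)  # [[0.6], [0.6]]
voltage_ranges = torch.cat((min_voltage, max_voltage), dim=1)
print(voltage_ranges)  # tensor([[-1.2000, 0.6000], [-0.4000, 0.6000]])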
Example #12
    def __call__(self, data):
        inputs, targets = data[0], data[1]
        inputs = torch.tensor(inputs,
                              device=self.device,
                              dtype=TorchUtils.get_data_type())
        targets = torch.tensor(targets,
                               device=self.device,
                               dtype=TorchUtils.get_data_type())
        return (inputs, targets)
Example #13
    def _init_electrode_info(self, configs):
        self.data_input_indices = TorchUtils.get_tensor_from_list(
            configs["data"]["input_indices"], data_type=torch.int64)
        self.control_indices = np.delete(np.arange(self.electrode_no),
                                         configs["data"]["input_indices"])
        self.control_indices = TorchUtils.get_tensor_from_list(
            self.control_indices, data_type=torch.int64
        )  # int64 is required: tensors used as indices must be long, byte or bool
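
The np.delete call is what splits the electrodes: every electrode that is not a data input becomes a control electrode. A tiny demonstration with a hypothetical 7-electrode device:

import numpy as np

electrode_no = 7
input_indices = [2, 3]  # data inputs on electrodes 2 and 3
control_indices = np.delete(np.arange(electrode_no), input_indices)
print(control_indices)  # [0 1 4 5 6]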
Example #14
    def __init__(self, configs):
        super().__init__()
        self._load(configs)
        self._init_voltage_ranges()
        self.amplification = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["processor"]['driver']["amplification"])
        self.clipping_value = TorchUtils.get_tensor_from_list(
            self.info["data_info"]["clipping_value"])
        self.noise = get_noise(configs)
Example #15
    def _init_dnpu(self, alpha):
        self.alpha = torch.tensor(alpha,
                                  device=TorchUtils.get_accelerator_type(),
                                  dtype=TorchUtils.get_data_type())

        # Freeze the parameters of the surrogate model's neural network
        for params in self.parameters():
            params.requires_grad = False
        self._init_bias()
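
The pattern above, freezing a pretrained network and training only a small set of control parameters, looks like this in plain PyTorch (surrogate and control_bias are illustrative stand-ins for the brains-py processor and its bias):

import torch
import torch.nn as nn

surrogate = nn.Linear(5, 1)  # stand-in for the pretrained surrogate network
for params in surrogate.parameters():
    params.requires_grad = False  # surrogate weights stay fixed

control_bias = nn.Parameter(torch.rand(1, 3))  # the only trainable tensor
optimizer = torch.optim.Adam([control_bias], lr=1e-3)
print([p.requires_grad for p in surrogate.parameters()])  # [False, False]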
Example #16
def plot_results(results, plots_dir=None, show_plots=False, extension="png"):
    plot_output(results["train_results"],
                "Train",
                plots_dir=plots_dir,
                extension=extension)
    plot_perceptron(results["train_results"]["accuracy"],
                    plots_dir,
                    name="train")
    if "dev_results" in results:
        plot_output(results["dev_results"],
                    "Dev",
                    plots_dir=plots_dir,
                    extension=extension)
        plot_perceptron(results["dev_results"]["accuracy"],
                        plots_dir,
                        name="dev")
    if "test_results" in results:
        plot_output(results["test_results"],
                    "Test",
                    plots_dir=plots_dir,
                    extension=extension)
        plot_perceptron(results["test_results"]["accuracy"],
                        plots_dir,
                        name="test")
    plt.figure()
    plt.title(f"Learning profile", fontsize=12)
    plt.plot(
        TorchUtils.get_numpy_from_tensor(
            results["train_results"]["performance_history"]),
        label="Train",
    )
    if "dev_results" in results:
        plt.plot(
            TorchUtils.get_numpy_from_tensor(
                results["dev_results"]["performance_history"]),
            label="Dev",
        )
    plt.legend()
    if plots_dir is not None:
        plt.savefig(os.path.join(plots_dir, "training_profile." + extension))

    plt.figure()
    plt.title(f"Inputs (V) \n {results['gap']} gap (-1 to 1 scale)",
              fontsize=12)
    plot_inputs(results["train_results"], "Train", ["blue", "cornflowerblue"])
    if "dev_results" in results:
        plot_inputs(results["dev_results"], "Dev", ["orange", "bisque"])
    if "test_results" in results:
        plot_inputs(results["test_results"], "Test", ["green", "springgreen"])
    plt.legend()
    if plots_dir is not None:
        plt.savefig(os.path.join(plots_dir, "input." + extension))

    if show_plots:
        plt.show()
    plt.close("all")
Example #17
    def _init_pool(self):
        pool = torch.zeros((self.genome_no, len(self.gene_range)),
                           device=TorchUtils.get_accelerator_type(),
                           dtype=TorchUtils.get_data_type()
                           )  # Dimensions: (genome number, gene number)
        for i in range(len(self.gene_range)):
            pool[:, i] = uniform(self.gene_range[i][0],
                                 self.gene_range[i][1]).sample(
                                     (self.genome_no, ))
        return pool
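
The per-gene sampling loop can also be collapsed into a single broadcasted expression. A standalone sketch with illustrative values:

import torch

genome_no = 20
gene_range = torch.tensor([[-1.2, 0.6], [-0.7, 0.3], [-1.0, 1.0]])
low, high = gene_range[:, 0], gene_range[:, 1]
# Broadcasting samples every genome and gene in one call.
pool = low + (high - low) * torch.rand(genome_no, len(gene_range))
print(pool.shape)  # torch.Size([20, 3])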
Example #18
def postprocess(dataloader, model, amplification, results_dir, label):
    print(f'Postprocessing {label} data ... ')
    # Note: this assumes every batch has exactly dataloader.batch_size samples,
    # i.e. the dataloader drops the last incomplete batch.
    predictions = TorchUtils.format_tensor(
        torch.zeros(len(dataloader), dataloader.batch_size))
    targets_log = TorchUtils.format_tensor(
        torch.zeros(len(dataloader), dataloader.batch_size))
    with torch.no_grad():
        model.eval()
        for i, (inputs, targets) in enumerate(dataloader):
            if inputs.device != TorchUtils.get_accelerator_type():
                inputs = inputs.to(device=TorchUtils.get_accelerator_type())
            if targets.device != TorchUtils.get_accelerator_type():
                targets = targets.to(device=TorchUtils.get_accelerator_type())
            targets_log[i] = targets.squeeze()
            predictions[i] = model(inputs).squeeze()
    targets_log = targets_log.view(targets_log.shape[0] * targets_log.shape[1])
    predictions = predictions.view(predictions.shape[0] * predictions.shape[1])
    train_targets = amplification * TorchUtils.get_numpy_from_tensor(
        targets_log)
    train_output = amplification * TorchUtils.get_numpy_from_tensor(
        predictions)
    plot_all(train_targets, train_output, results_dir, name=label)
Example #19
    def __init__(self, current_range, voltage_range, cut=True):
        assert len(current_range) == len(
            voltage_range), "Mapping ranges differ in length"
        self.map_variables = TorchUtils.get_tensor_from_list(
            [
                get_map_to_voltage_vars(
                    voltage_range[i][0],
                    voltage_range[i][1],
                    current_range[i][0],
                    current_range[i][1],
                ) for i in range(len(current_range))
            ],
            device=TorchUtils.get_accelerator_type())
        self.current_range = current_range
        self.cut = cut
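
get_map_to_voltage_vars belongs to brains-py; assuming it returns the scale and offset of the affine map from a current range onto a voltage range, the two variables can be derived as in this sketch (not the library implementation):

def map_to_voltage_vars_sketch(v_min, v_max, c_min, c_max):
    scale = (v_max - v_min) / (c_max - c_min)
    offset = v_min - scale * c_min
    return scale, offset  # voltage = scale * current + offset

scale, offset = map_to_voltage_vars_sketch(-1.2, 0.6, -10.0, 10.0)
assert abs(scale * -10.0 + offset - (-1.2)) < 1e-9  # c_min maps to v_min
assert abs(scale * 10.0 + offset - 0.6) < 1e-9      # c_max maps to v_max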
Example #20
    def full_check(self, point_no):
        points = torch.rand(point_no,
                            device=TorchUtils.get_accelerator_type(),
                            dtype=TorchUtils.get_data_type())
        waveform = self.waveform_mgr.points_to_waveform(points)
        assert (
            (waveform[0, :] == 0.0).all() and (waveform[-1, :] == 0.0).all()
        ), "Waveforms do not start and end with zero"
        assert len(waveform) == (
            (self.waveform_mgr.plateau_length * len(points))
            + (self.waveform_mgr.slope_length * (len(points) + 1))
        ), "Waveform has an incorrect shape"

        mask = self.waveform_mgr.generate_mask(len(waveform))
        assert len(mask) == len(waveform)

        waveform_to_points = self.waveform_mgr.waveform_to_points(waveform)
        plateaus_to_points = self.waveform_mgr.plateaus_to_points(waveform[mask])
        assert (points.half().float() == waveform_to_points.half().float()).all() and (
            points.half().float() == plateaus_to_points.half().float()
        ).all(), "Inconsistent to_point conversion"

        points_to_plateau = self.waveform_mgr.points_to_plateaus(points)
        waveform_to_plateau = self.waveform_mgr.waveform_to_plateaus(waveform)
        assert (waveform[mask] == points_to_plateau).all() and (
            waveform[mask] == waveform_to_plateau
        ).all(), "Inconsistent plateau conversion"

        plateaus_to_waveform, _ = self.waveform_mgr.plateaus_to_waveform(
            waveform[mask]
        )
        assert (
            waveform == plateaus_to_waveform
        ).all(), "Inconsistent waveform conversion"
Example #21
    def __init__(self, configs, verbose=False):
        super().__init__()
        self.configs = configs
        self.verbose = verbose
        self.load(configs["torch_model_dict"])
        if TorchUtils.get_accelerator_type() == torch.device("cuda"):
            self.raw_model.cuda()
Example #22
    def forward(self, x):
        with torch.no_grad():
            x, mask = self.waveform_mgr.plateaus_to_waveform(x, return_pytorch=False)
            output = self.forward_numpy(x)
            if self.logger is not None:
                self.logger.log_output(x)
        return TorchUtils.get_tensor_from_numpy(output[mask])
Example #23
    def set_control_voltages(self, bias):
        with torch.no_grad():
            bias = bias.unsqueeze(dim=0)
            assert (
                self.bias.shape == bias.shape
            ), "Control voltages could not be set due to a shape mismatch with the ones already in the model."
            self.bias = torch.nn.Parameter(TorchUtils.format_tensor(bias))
Example #24
def train_perceptron(results, configs, node=None):
    # Initialise key elements of the trainer
    dataloaders = get_data(results, configs, shuffle=True)
    loss = torch.nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(node.parameters(),
                                 lr=configs["learning_rate"],
                                 betas=configs["betas"])
    best_accuracy = -1
    best_labels = None
    looper = trange(configs["epochs"], desc="Calculating accuracy")
    node = node.to(device=TorchUtils.get_accelerator_type(), dtype=DTYPE)

    for epoch in looper:
        for inputs, targets in dataloaders[0]:
            if inputs.device != TorchUtils.get_accelerator_type():
                inputs = inputs.to(TorchUtils.get_accelerator_type())
            if targets.device != TorchUtils.get_accelerator_type():
                targets = targets.to(TorchUtils.get_accelerator_type())
            optimizer.zero_grad()
            predictions = node(inputs)
            cost = loss(predictions, targets)
            cost.backward()
            optimizer.step()
        with torch.no_grad():
            node.eval()
            accuracy, labels = evaluate_accuracy(results['norm_inputs'],
                                                 results['targets'], node)
            if accuracy > best_accuracy:
                best_accuracy = accuracy
                best_labels = labels
                # Keep a copy of the parameters that produced the best accuracy
                w, b = list(node.parameters())
                # TODO: Add a more efficient stopping mechanism ?
                if best_accuracy >= 100.0:
                    looper.set_description(
                        f"Reached 100% accuracy. Stopping at Epoch: {epoch+1}  Accuracy {best_accuracy}, loss: {cost}"
                    )
                    looper.close()
                    break
            node.train()
        looper.set_description(
            f"Epoch: {epoch+1}  Accuracy {accuracy}, loss: {cost}")
    # Restore the parameters from the best epoch before returning the node
    node.weight = w
    node.bias = b
    return best_accuracy, best_labels, node
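
The same train-and-keep-best loop, reduced to a self-contained toy problem: a plain nn.Linear perceptron with a hard 0.5 threshold standing in for the brains-py node and its evaluate_accuracy helper:

import torch

inputs = torch.randn(64, 1)
targets = (inputs > 0).float()  # a linearly separable toy task
node = torch.nn.Linear(1, 1)
loss = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(node.parameters(), lr=0.05)
best_accuracy = -1.0
for epoch in range(100):
    optimizer.zero_grad()
    cost = loss(node(inputs), targets)
    cost.backward()
    optimizer.step()
    with torch.no_grad():
        labels = (torch.sigmoid(node(inputs)) > 0.5).float()
        accuracy = 100.0 * (labels == targets).float().mean().item()
    if accuracy > best_accuracy:
        best_accuracy = accuracy  # keep the best accuracy seen so far
    if best_accuracy >= 100.0:
        break  # early stop once the data is perfectly separated
print(best_accuracy)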
Example #25
    def _init_bias(self):
        self.control_low = self.processor.get_control_ranges()[:, 0]
        self.control_high = self.processor.get_control_ranges()[:, 1]
        assert any(
            self.control_low < 0
        ), "Minimum control voltage is assumed to be negative, but all values are non-negative!"
        assert any(
            self.control_high > 0
        ), "Maximum control voltage is assumed to be positive, but all values are non-positive!"
        # Sample the initial control biases uniformly within their voltage ranges
        bias = self.control_low + (
            self.control_high - self.control_low) * torch.rand(
                1,
                len(self.processor.control_indices),
                dtype=TorchUtils.get_data_type(),
                device=TorchUtils.get_accelerator_type())

        self.bias = nn.Parameter(bias)
Example #26
def init_seed(configs):
    seed = TorchUtils.init_seed(configs.get("seed"), deterministic=True)
    configs["seed"] = seed
Example #27
def evaluate_model(model, dataset, criterion, results=None, transforms=None):
    # Avoid a mutable default argument, which would be shared across calls
    if results is None:
        results = {}
    with torch.no_grad():
        model.eval()
        if transforms is None:
            inputs, targets = dataset[:]
        else:
            inputs, targets = transforms(dataset[:])
        inputs = inputs.to(device=TorchUtils.get_accelerator_type())
        targets = targets.to(device=TorchUtils.get_accelerator_type())

        predictions = model(inputs)

    results["inputs"] = inputs
    results["targets"] = targets
    results["predictions"] = predictions
    results["performance"] = criterion(predictions, targets)
    return results
Example #28
def plot_results(results, base_dir=None, show_plots=False):
    fig = plt.figure()
    correlations = TorchUtils.get_numpy_from_tensor(torch.abs(results["correlations"]))
    threshold = TorchUtils.get_numpy_from_tensor(
        results["threshold"] * 100. * torch.ones(correlations.shape)
    )
    accuracies = TorchUtils.get_numpy_from_tensor(results["accuracies"])
    plt.plot(correlations, threshold, "k")
    plt.scatter(correlations, accuracies)
    plt.xlabel("Fitness / Performance")
    plt.ylabel("Accuracy")

    if base_dir is not None:
        plt.savefig(os.path.join(base_dir, "fitness_vs_accuracy.png"))
    if show_plots:
        plt.show()
    plt.close()
    return fig
Example #29
def search_solution(
    configs,
    custom_model,
    criterion,
    algorithm,
    transforms=None,
    logger=None,
    is_main=True,
):
    main_dir, search_stats_dir, results_dir, reproducibility_dir = init_dirs(
        configs["data"]["gap"], configs["results_base_dir"], is_main=is_main)
    configs["results_base_dir"] = main_dir
    dataloaders = get_ring_data(configs, transforms)
    all_results = init_all_results(dataloaders, configs["runs"])
    best_run = None

    for run in range(configs["runs"]):
        print(f"########### RUN {run} ################")
        all_results["seeds"][run] = TorchUtils.init_seed(None,
                                                         deterministic=True)

        results, model = ring_task(
            configs,
            dataloaders,
            custom_model,
            criterion,
            algorithm,
            logger=logger,
            is_main=False,
            save_data=False,
        )
        all_results = update_all_search_stats(all_results, results, run)
        if is_best_run(results, best_run):
            results["best_index"] = run
            best_run = results
            plot_results(results, plots_dir=results_dir)
            torch.save(model, os.path.join(reproducibility_dir, "model.pt"))
            torch.save(
                results,
                os.path.join(reproducibility_dir, "results.pickle"),
                pickle_protocol=p.HIGHEST_PROTOCOL,
            )
            save(
                "configs",
                os.path.join(reproducibility_dir, "configs.yaml"),
                data=configs,
            )
            torch.save(results,
                       os.path.join(search_stats_dir, "best_result.pickle"))

    close_search(
        all_results,
        search_stats_dir,
        "all_results_" + str(configs["data"]["gap"]) + "_gap_" +
        str(configs["runs"]) + "_runs",
    )
Example #30
    def init_batch_norm(self, batch_norm, affine, track_running_stats):
        if batch_norm:
            self.init_output_node_no()
            self.bn = nn.BatchNorm1d(
                self.bn_outputs,
                affine=affine,
                track_running_stats=track_running_stats).to(
                    device=TorchUtils.get_accelerator_type())
        else:
            self.bn = batch_norm  # batch normalisation stays disabled (False)