Example #1
    def input_pruning(self, results_path, min_n_input_dims=20, minimum_neurons=20):
        """
        Prune input dimensions whose Hebbian values fall outside the configured thresholds
        and save the resulting binary masks as images.

        :param results_path: directory under which the mask images are written
        :param min_n_input_dims: only prune if the input layer has at least this many dimensions
        :param minimum_neurons: minimum number of input dimensions to keep alive
        :return: the binary keep-mask (valid_bool) and the indices of the surviving inputs (alive_inputs)
        """
        self.eval()
        with torch.no_grad():
            # Work on a detached NumPy copy of the input-layer Hebbian values.
            hebb_input = self.hebb_input_values.data.clone().cpu().numpy()
            if len(hebb_input) >= min_n_input_dims:
                # "Up" mask: inputs above the lower threshold are kept.
                # "Down" mask: inputs below the upper threshold are not over-used.
                to_keep = hebb_input > float(self.gt_input)
                not_too_used = hebb_input < float(self.lt_input)
                print("min_hebb_value:", self.gt_input)
                valid_indices = indices_h(to_keep)
                valid_indices_down = indices_h(not_too_used)
                total_valid = np.intersect1d(valid_indices, valid_indices_down)
                if len(valid_indices) < minimum_neurons:
                    # TODO Replace neurons that could not be removed?
                    # Fall back so that at least `minimum_neurons` inputs stay alive
                    # (hebb_input is a NumPy array here, so use np.argsort rather than torch.sort).
                    valid_indices = indices_h(np.argsort(hebb_input) < minimum_neurons)
                    print("Minimum neurons on layer 1", file=self.hebb_log)

                print("previous_valid_len", self.previous_valid_len)
                self.valid_bool = [1. if x in valid_indices else 0. for x in range(self.input_size)]
                self.valid_bool_down = [1. if x in valid_indices_down else 0. for x in range(self.input_size)]
                self.valid_bool_total = [1. if x in total_valid else 0. for x in range(self.input_size)]
                self.alive_inputs = [x for x in range(len(hebb_input)) if x in valid_indices]
                self.alive_inputs_down = [x for x in range(len(hebb_input)) if x in valid_indices_down]
                self.alive_inputs_total = [x for x in range(len(hebb_input)) if x in total_valid]
                alive_inputs = np.array(self.alive_inputs)
                #if len(self.alive_inputs) < self.previous_valid_len:
                masks_path = results_path + "/images/masks/" + str(self.dataset_name) + "/"
                create_missing_folders(masks_path)

                img_path = "_".join(["alive_inputs", str(len(valid_indices_down)), str(self.epoch), "down.png"])
                print("self.n_channels", self.n_channels)
                if len(self.input_shape) == 3:
                    print("SAVING MASK at", results_path)
                    mask = np.reshape(self.valid_bool_down, newshape=(28, 28))  # TODO change hard coding
                    plt.imsave(masks_path + img_path, mask)
                img_path = "_".join(["alive_inputs", str(len(total_valid)), str(self.epoch), "total.png"])
                print("self.n_channels", self.n_channels)
                if len(self.input_shape) == 3:
                    print("SAVING MASK at", results_path)
                    mask = np.reshape(self.valid_bool_total, newshape=(28, 28))  # TODO change hard coding
                    plt.imsave(masks_path + img_path, mask)
                img_path = "_".join(["alive_inputs", str(len(valid_indices)), str(self.epoch), "up.png"])
                print("self.n_channels", self.n_channels)
                if len(self.input_shape) == 3:
                    print("SAVING MASK at", results_path)
                    mask = np.reshape(self.valid_bool, newshape=(28, 28))  # TODO change hard coding
                    plt.imsave(masks_path + img_path, mask)

                self.previous_valid_len = len(valid_indices)
                # Accumulate the mask so inputs pruned in earlier rounds stay pruned.
                mask_tensor = torch.Tensor(self.valid_bool)
                if torch.cuda.is_available():
                    mask_tensor = mask_tensor.cuda()
                self.valid_bool_tensor = self.valid_bool_tensor * mask_tensor
                return self.valid_bool, self.alive_inputs
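
Both examples depend on a helper named indices_h that is not shown on this page. Below is a minimal sketch of what it presumably does, assuming it simply turns a boolean mask (NumPy array or torch tensor) into the positions that are True; this is a reconstruction for illustration, not the original helper:

import numpy as np
import torch

def indices_h(mask):
    # Assumed helper, reconstructed for illustration: return the indices
    # at which a boolean mask (NumPy array or torch tensor) is True/nonzero.
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    return np.where(np.asarray(mask))[0]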
Example #2
    def pruning(self, fcs, minimum_neurons=2):
        """
        Prune hidden neurons whose Hebbian values fall below the per-layer thresholds in
        self.gt, shrinking the weights, biases and gradients of the linear layers in fcs.
        """
        self.eval()
        with torch.no_grad():
            for i in range(len(self.gt)):
                # Neurons of layer i whose Hebbian value exceeds the threshold stay alive.
                alive_neurons_out = self.hebb_values[i] > float(self.gt[i])
                indices_alive_neurons_out = indices_h(alive_neurons_out)
                self.hebb_values_neurites[i] = self.hebb_values_neurites[i][
                    indices_alive_neurons_out, :]

                # Detached NumPy copies of the layer's weights, biases and their gradients.
                w2 = fcs[i].weight.data.clone().cpu().numpy()
                b2 = fcs[i].bias.data.clone().cpu().numpy()
                wg2 = fcs[i].weight.grad.clone().cpu().numpy()
                bg2 = fcs[i].bias.grad.clone().cpu().numpy()

                # Keep only the rows (output units) of the surviving neurons.
                bg2 = bg2[indices_alive_neurons_out]
                b2 = b2[indices_alive_neurons_out]
                wg2 = wg2[indices_alive_neurons_out, :]
                w2 = w2[indices_alive_neurons_out, :]

                if i > 0:
                    # For deeper layers, also drop the columns (inputs) coming from
                    # neurons pruned in the previous layer.
                    alive_neurons_in = self.hebb_values[i - 1] > float(self.gt[i - 1])
                    indices_alive_neurons_in = indices_h(alive_neurons_in)

                    self.hebb_values_neurites[i] = self.hebb_values_neurites[i][:, indices_alive_neurons_in]
                    wg2 = wg2[:, indices_alive_neurons_in]
                    w2 = w2[:, indices_alive_neurons_in]
                    fcs[i].in_features = wg2.shape[1]

                self.Ns[i] = len(b2)
                fcs[i].out_features = len(b2)

                # Move the pruned arrays back into torch tensors (on GPU if available)
                # and reassign them as the layer's parameters and gradients.
                b2 = torch.from_numpy(b2)
                bg2 = torch.from_numpy(bg2)
                w2 = torch.from_numpy(w2)
                wg2 = torch.from_numpy(wg2)

                if torch.cuda.is_available():
                    w2, wg2, b2, bg2 = w2.cuda(), wg2.cuda(), b2.cuda(), bg2.cuda()

                fcs[i].weight = nn.Parameter(w2)
                fcs[i].weight.grad = wg2
                fcs[i].bias = nn.Parameter(b2)
                fcs[i].bias.grad = bg2

                # alive_neurites = self.hebb_values_neurites[i] > self.gt_neurites[i]
                # alive_neurites = torch.Tensor(alive_neurites.data.cpu().numpy()).cuda()

                self.hebb_values[i] = self.hebb_values[i][
                    indices_alive_neurons_out]
                # fcs[i].weight.data = fcs[i].weight.data * alive_neurites
                # self.n_neurites[i] += [int(torch.sum(alive_neurites))]
                if len(indices_alive_neurons_out) < minimum_neurons:
                    # Fall back so that at least `minimum_neurons` neurons feed the next layer.
                    indices_alive_neurons_out = indices_h(
                        torch.sort(self.hebb_values[i])[1] < minimum_neurons)
                    print("Minimum neurons on layer", (i + 1),
                          sep="\t",
                          file=self.hebb_log)

            # Finally, prune the input columns of the last layer to match the neurons
            # that survived in the last hidden layer.
            w3 = fcs[-1].weight.data.clone().cpu().numpy()
            wg3 = fcs[-1].weight.grad.clone().cpu().numpy()

            try:
                w3, wg3 = w3[:, indices_alive_neurons_out], wg3[:, indices_alive_neurons_out]
                fcs[-1].in_features = len(indices_alive_neurons_out)
            except Exception:
                # If the indices are unavailable or the shapes no longer match,
                # keep the last layer's weights unchanged.
                pass

            w3 = torch.from_numpy(w3)
            wg3 = torch.from_numpy(wg3)
            if torch.cuda.is_available():
                w3, wg3 = w3.cuda(), wg3.cuda()
            fcs[-1].weight = nn.Parameter(w3)
            fcs[-1].weight.grad = wg3

            if torch.cuda.is_available():
                fcs = fcs.cuda()
            return fcs
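
For context, here is a minimal usage sketch showing how these pruning methods might be called during training. The class name HebbianMLP, the fcs attribute passed to pruning, and the helpers train_one_epoch / train_loader are assumptions for illustration, not part of the original code:

import torch

# Hypothetical setup -- the constructor and helpers below are assumed, not from the source.
model = HebbianMLP(input_size=784, hidden_sizes=[512, 256], num_classes=10)
if torch.cuda.is_available():
    model = model.cuda()

for epoch in range(100):
    train_one_epoch(model, train_loader)  # assumed training helper
    model.epoch = epoch
    if epoch > 0 and epoch % 10 == 0:
        # Prune weakly used input dimensions, then the hidden layers.
        valid_bool, alive_inputs = model.input_pruning(results_path="results")
        model.fcs = model.pruning(model.fcs, minimum_neurons=2)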