Example no. 1
    def _diag_fisher(self, cls_id, iteration, side_fc=False):
        print("Training MAS model: Classifier %d" % (cls_id))
        precision_matrices = {}
        mse_criterion = nn.MSELoss()
        mse_criterion = mse_criterion.to(device)
        # Initialize a zero importance tensor for every tracked parameter.
        for n, p in copy.deepcopy(self.params).items():
            p.data.zero_()
            precision_matrices[n] = variable(p.data)

        self.model.eval()
        for batch_idx, (inputs, targets) in enumerate(self.dataset):
            inputs, targets = inputs.to(device), targets.to(device)
            #num_old_classes = args.nb_cl * iteration
            #targets = targets - num_old_classes
            # Offset of the first output unit belonging to the current task;
            # side classifiers reserve args.side_classifier blocks per step.
            if side_fc is True:
                start_index = args.side_classifier * args.nb_cl * iteration
            else:
                start_index = args.nb_cl * iteration
            self.model.zero_grad()
            outputs = self.model(inputs, side_fc=side_fc)
            # MAS-style objective: the MSE of classifier cls_id's output block
            # against zeros equals its mean squared output magnitude, so the
            # gradient of this loss measures output sensitivity to each parameter.
            i = cls_id - 1
            Target_zeros = torch.zeros_like(
                outputs[:, (start_index +
                            args.nb_cl * i):(start_index + args.nb_cl *
                                             (i + 1))]).to(device)
            loss_cls = mse_criterion(
                outputs[:,
                        (start_index + args.nb_cl * i):(start_index +
                                                        args.nb_cl * (i + 1))],
                Target_zeros)
            loss_cls.backward()
            # Accumulate |gradient|, averaged over the number of batches,
            # for every non-classifier ('fc') parameter.
            for n, p in self.model.named_parameters():
                if 'fc' not in n:
                    precision_matrices[n].data += torch.abs(p.grad.data) / len(
                        self.dataset)
            if (batch_idx + 1) % 200 == 0:
                print(batch_idx + 1)

        precision_matrices = {n: p for n, p in precision_matrices.items()}
        save_name_side = os.path.join(
            ckp_prefix +
            'WI/Weigtht_Importance_step_{}_K_{}_classifier_{}.pkl').format(
                iteration, args.side_classifier, cls_id)
        utils_pytorch.savepickle(precision_matrices, save_name_side)

        return precision_matrices
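The dictionary returned above holds MAS-style weight-importance estimates: the mean absolute gradient obtained by pushing one classifier block's outputs toward zero. Such matrices are typically consumed as a quadratic penalty when the network is later trained on new classes. Below is a minimal sketch of that idea only; mas_penalty, old_params and lam are hypothetical names, and this is not necessarily the loss used in the repository this example comes from.

import torch

def mas_penalty(model, old_params, importance, lam=1.0):
    # Quadratic MAS regularizer: lam * sum_n Omega_n * (theta_n - theta_n_old)**2,
    # where Omega_n is the accumulated importance returned by _diag_fisher and
    # old_params holds the parameter values frozen after the previous task.
    penalty = torch.zeros((), device=next(model.parameters()).device)
    for n, p in model.named_parameters():
        if n in importance:
            penalty = penalty + (importance[n] * (p - old_params[n]) ** 2).sum()
    return lam * penalty

# Hypothetical usage: total_loss = task_loss + mas_penalty(model, old_params, importance, lam=0.1)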
Example no. 2
# Keep references to the full SVHN images and labels before per-step slicing.
svhn_data_copy = svhn_data.data
svhn_labels_copy = svhn_data.labels

# Launch the different runs
for n_run in range(args.nb_runs):
    # Select the order for the class learning
    order_name = "./checkpoint/{}_order_run_{}.pkl".format(args.dataset, n_run)
    print("Order name:{}".format(order_name))
    if os.path.exists(order_name):
        print("Loading orders")
        order = utils_pytorch.unpickle(order_name)
    else:
        print("Generating orders")
        order = np.arange(args.num_classes)
        np.random.shuffle(order)
        utils_pytorch.savepickle(order, order_name)
    order_list = list(order)
    print(order_list)

    start_iter = 0
    for iteration in range(start_iter, int(args.num_classes / args.nb_cl)):
        # Prepare the training data for the current batch of classes
        # (e.g. 100 total classes split into groups of args.nb_cl = 20)
        actual_cl = order[range(iteration * args.nb_cl,
                                (iteration + 1) * args.nb_cl)]
        # Boolean mask over the training labels: True for samples whose class
        # belongs to the current group of args.nb_cl classes.
        indices_train_subset = np.array([
            i in order[range(iteration * args.nb_cl,
                             (iteration + 1) * args.nb_cl)]
            for i in Y_train_total
        ])
        # Boolean mask over the evaluation labels: True for samples from every
        # class seen so far (assumption: Y_valid_total mirrors Y_train_total).
        indices_test_subset = np.array([
            i in order[range(0, (iteration + 1) * args.nb_cl)]
            for i in Y_valid_total
        ])
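The listing above ends once the boolean masks are built. For context, here is a minimal sketch of how such masks are usually applied in this kind of incremental-learning script, under the assumption that X_train_total, X_valid_total and Y_valid_total exist alongside Y_train_total; the *_cur, *_cum and map_Y_* names are illustrative, not taken from the original code.

import numpy as np

# Slice out this step's training samples and the cumulative evaluation set
# (assumed arrays; they are not shown in the snippet above).
X_train_cur = X_train_total[indices_train_subset]
Y_train_cur = Y_train_total[indices_train_subset]
X_valid_cum = X_valid_total[indices_test_subset]
Y_valid_cum = Y_valid_total[indices_test_subset]

# Remap original class ids to their position in the shuffled order so the
# classifier sees contiguous labels 0 .. (iteration + 1) * args.nb_cl - 1.
map_Y_train = np.array([order_list.index(y) for y in Y_train_cur])
map_Y_valid = np.array([order_list.index(y) for y in Y_valid_cum])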