def __init__(self,
                 start_num=0,
                 end_num=10,
                 rehearsal_size=2000,
                 net=None,
                 save_path="",
                 data_name="cifar100",
                 epoch=50):
        """Set up one incremental-learning task.

        Builds train/test datasets for classes [start_num, end_num), an
        optional rehearsal loader over previously seen classes, the main
        network's optimizer/scheduler, and one decoder (with its own
        optimizer/scheduler) per new class.

        Args:
            start_num: first class index (inclusive) of this task.
            end_num: last class index (exclusive) of this task.
            rehearsal_size: number of exemplars kept from old classes.
            net: main classification network (already constructed).
            save_path: directory used to build checkpoint paths.
            data_name: "cifar100" or "cifar10"; any other value selects SVHN.
            epoch: number of training epochs.
        """
        self.lr = 0.001
        self.epoch = epoch
        self.warm = 1
        self.batch_size = 256
        self.start_num = start_num
        self.end_num = end_num
        self.class_num = end_num - start_num  # classes introduced by this task
        self.use_cuda = True
        self.task_num = 1
        self.save_path = save_path
        self.main_net_path = save_path + "/icarl_" + str(start_num) + ".ptn"
        self.rehearsal_size = rehearsal_size
        transform_train = transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.RandomRotation(10),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

        # FIX: compare strings with "==" rather than "is" -- identity of
        # string literals is a CPython interning accident, not a guarantee.
        if data_name == "cifar100":
            self.trainset = Cifar100Split(start_num=start_num,
                                          end_num=end_num,
                                          train=True,
                                          transform=transform_train)
            self.testset = Cifar100Split(start_num=start_num,
                                         end_num=end_num,
                                         train=False,
                                         transform=transform_test)
        elif data_name == "cifar10":
            self.trainset = Cifar10_SVNH_Split(isCifar10=True,
                                               start_num=start_num,
                                               end_num=end_num,
                                               train=True,
                                               transform=transform_train)
            self.testset = Cifar10_SVNH_Split(isCifar10=True,
                                              start_num=start_num,
                                              end_num=end_num,
                                              train=False,
                                              transform=transform_test)
        else:  # any other data_name falls through to SVHN
            self.trainset = Cifar10_SVNH_Split(isCifar10=False,
                                               start_num=start_num,
                                               end_num=end_num,
                                               train=True,
                                               transform=transform_train)
            self.testset = Cifar10_SVNH_Split(isCifar10=False,
                                              start_num=start_num,
                                              end_num=end_num,
                                              train=False,
                                              transform=transform_test)
        self.trainloader = data.DataLoader(self.trainset,
                                           batch_size=self.batch_size,
                                           shuffle=True,
                                           num_workers=0)
        self.testloader = data.DataLoader(self.testset,
                                          batch_size=self.batch_size,
                                          shuffle=False,
                                          num_workers=0)
        self.rehearsal_data = None
        self.rehearsal_loader = None
        # Rehearsal memory only exists when earlier tasks precede this one.
        if start_num > 0:
            # FIX: removed a redundant unconditional Cifar100Rehearsal build
            # that was always overwritten by the branch below, and honor the
            # rehearsal_size parameter instead of a hard-coded 2000.
            if data_name == "cifar100":
                self.rehearsal_data = Cifar100Rehearsal(
                    end_num=start_num,
                    rehearsal_size=rehearsal_size,
                    transform=transform_train)
            elif data_name == "cifar10":
                self.rehearsal_data = Cifar10_SVNH_Rehearsal(
                    isCifar10=True,
                    end_num=start_num,
                    rehearsal_size=rehearsal_size,
                    transform=transform_train)
            else:
                self.rehearsal_data = Cifar10_SVNH_Rehearsal(
                    isCifar10=False,
                    end_num=start_num,
                    rehearsal_size=rehearsal_size,
                    transform=transform_train)
            self.rehearsal_loader = data.DataLoader(self.rehearsal_data,
                                                    batch_size=self.batch_size,
                                                    shuffle=True,
                                                    num_workers=0)

        self.net = net
        self.optimizer = torch.optim.Adam(self.net.parameters(),
                                          lr=self.lr,
                                          betas=(0.9, 0.999),
                                          eps=1e-08,
                                          weight_decay=0,
                                          amsgrad=False)
        # Halve the learning rate at each milestone epoch.
        milestones = [10, 20, 30, 40]
        self.train_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            self.optimizer, milestones=milestones, gamma=0.5)
        # One decoder per new class, each with its own optimizer/scheduler.
        # (attribute name "optmizer_De_list" [sic] kept for external callers)
        self.optmizer_De_list = []
        self.scheduler_De_list = []
        self.decoder_list = []
        self.pixelwise_loss = torch.nn.L1Loss()
        # best (lowest) pixel loss observed so far, per class index
        self.best_loss_pix = {}

        for i in range(self.class_num):
            self.best_loss_pix[i] = 1000000  # sentinel: no loss recorded yet
            de = DecoderNet()
            if self.use_cuda:
                de = de.cuda()
            self.decoder_list.append(de)
            opt = torch.optim.Adam(de.parameters(),
                                   lr=self.lr,
                                   betas=(0.9, 0.999))
            self.optmizer_De_list.append(opt)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                opt, milestones=milestones, gamma=0.5)
            self.scheduler_De_list.append(scheduler)
        self.old_model = None
        # Distillation only makes sense when a previous task's model exists.
        if start_num > 0:
            self.distillation = True
        else:
            self.distillation = False
def model_test(model_num=5, save_path="", per_num=10,model_name="",alpha=0.02):
    """Evaluate an incrementally-trained classifier on the first
    model_num * per_num CIFAR-100 classes, comparing three prediction rules.

    Reported metrics:
      * correct_top1        -- plain top-1 accuracy of the final network.
      * correct_rec         -- among the network's top-k candidates, pick the
                               class whose decoder reconstructs the input
                               image with the lowest L1 error.
      * correct_combination -- blend of normalized classifier scores and
                               reconstruction errors, weighted by alpha.

    Args:
        model_num: number of tasks trained so far.
        save_path: directory holding "<model_name>_<k>.ptn" classifier
            checkpoints and "decoder_<c>.ptn" per-class decoder checkpoints.
        per_num: classes introduced per task.
        model_name: checkpoint file-name prefix.
        alpha: base blending coefficient; decayed exponentially below.

    NOTE(review): requires CUDA and all checkpoint files above to exist --
    they are loaded unconditionally.
    """

    # Decay the classifier's blend weight with the total class count.
    alpha = np.exp(-alpha* model_num * per_num)
    torch.cuda.empty_cache()
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), ])

    # Test split covering every class seen so far.
    testset = Cifar100Split(start_num=0, end_num=model_num * per_num, train=False, transform=transform)
    testloader = data.DataLoader(testset, batch_size=64, shuffle=False, num_workers=0)
    decoder_list = []
    # Final classifier: checkpoint saved for the last task's start index.
    net = resnet18(num_classes=100).cuda()
    model_path = torch.load(save_path + "/"+model_name+"_" + str((model_num - 1) * per_num) + ".ptn")
    net.load_state_dict(model_path)
    net.eval()
    # One frozen feature extractor per task, restored from that task's
    # checkpoint (file index j * per_num).
    feature_extract_list = []
    for j in range(model_num):
        feature_extract = resnet18(num_classes=100).cuda()
        model_path = torch.load(save_path + "/"+model_name+"_" + str(j * per_num) + ".ptn")
        feature_extract.load_state_dict(model_path)
        feature_extract.eval()
        feature_extract_list.append(feature_extract)
    # One reconstruction decoder per class.
    for j in range(model_num * per_num):
        de = DecoderNet().cuda()
        model_path = torch.load(save_path + "/decoder_" + str(j) + ".ptn")
        de.load_state_dict(model_path)
        de.eval()
        decoder_list.append(de)

    correct_top1 = 0
    correct_rec = 0
    correct_combination = 0
    over_num = 0   # non-target classes scored above the target by the classifier
    over_num2 = 0  # same, but under the combined classifier+decoder score
    for n, (images, labels) in enumerate(testloader):
        images = images.cuda()
        labels = labels.cuda()
        with torch.no_grad():
            output = net(images)
            # Restrict logits to the classes actually trained so far.
            preds = output[:, 0:model_num * per_num]
            _, pre_label_top1 = preds.max(dim=1)
            correct_top1 += pre_label_top1.eq(labels).sum()
            # Re-rank the classifier's top-k candidates by decoder loss.
            k = model_num + 1
            pre_score, pre_label = preds.topk(k=k, dim=1)
            for i in range(images.size(0)):
                # Count classes the classifier scored strictly above the target.
                score = preds[i, labels[i]]
                over_confident = preds[i] > score
                over_num += over_confident.sum()

                # print(correct)
                # L1 reconstruction error for each top-k candidate, using the
                # feature extractor of the task that introduced that class.
                pix_loss = torch.zeros([k]).cuda()
                for j in range(len(pre_label[i])):
                    path = pre_label[i][j]
                    task = path // per_num  # task index owning this class
                    feature, _ = feature_extract_list[task].forward_feature(images[i:i + 1])
                    rec_img = decoder_list[path](feature)
                    a = images[i].view(-1)
                    b = rec_img.view(-1)
                    loss = torch.mean(torch.abs(a - b))
                    # pix_loss[j]= 0.1*(mean_pix_loss[path]-loss)
                    pix_loss[j] = loss

                # print(pix_loss)
                # Decoder-only rule: lowest reconstruction error wins.
                if pre_label[i][pix_loss.argmin()] == labels[i]:
                    correct_rec += 1
                    # print(correct_rec)
                # Min-max normalize both signals, softmax them, then blend:
                # alpha weights the classifier, (1 - alpha) the decoders.
                a = (pre_score[i] - min(pre_score[i])) / (max(pre_score[i]) - min(pre_score[i]))
                a = torch.nn.functional.softmax(a, dim=0)
                pix_loss = (pix_loss - min(pix_loss)) / (max(pix_loss) - min(pix_loss))
                b = torch.nn.functional.softmax(-pix_loss, dim=0)
                c = alpha * a + (1 - alpha) * b
                if pre_label[i][c.argmax()] == labels[i]:
                    correct_combination += 1

                # Over-confidence of the combined score: count candidates
                # ranked above the target, or all k if the target fell
                # outside the classifier's top-k entirely.
                if (pre_label[i].eq(labels[i])).sum() > 0:
                    index = torch.where(pre_label[i] == labels[i])
                    score = c[index]
                    a = c > score
                    over_num2 += a.sum()
                else:
                    over_num2 += k

    num_data = len(testloader.dataset)
    correct_combination = correct_combination / num_data
    correct_rec = correct_rec / num_data
    correct_top1 = correct_top1.float() / num_data
    print("Test model:", str(model_num), "  Average  Accuracy:", correct_top1, "rec acc", correct_rec,
          " Accuracy predict+decoder:", correct_combination, " alpha:", alpha)

    # Persist over-confidence counts and accuracies as CSV rows.
    utils.save_acc_csv(save_file="/over_confident100-" + str(per_num) + ".csv", class_num=model_num * per_num,
                       acc=over_num.cpu().numpy(), model_name=model_name+"_100-" + str(per_num))
    utils.save_acc_csv(save_file="/over_confident100-" + str(per_num) + ".csv", class_num=model_num * per_num,
                       acc=over_num2.cpu().numpy(), model_name=model_name+"_Decoder100-" + str(per_num))

    utils.save_acc_csv(save_file="/decoder_acc100-" + str(per_num) + ".csv", class_num=model_num * per_num,
                       acc=correct_top1.cpu().numpy(), model_name=model_name+"_100-" + str(per_num))
    utils.save_acc_csv(save_file="/decoder_acc100-" + str(per_num) + ".csv", class_num=model_num * per_num,
                       acc=correct_combination, model_name=model_name+"_Decoder100-" + str(per_num))