def compute_features(tg_model,
                     free_model,
                     tg_feature_model,
                     is_start_iteration,
                     evalloader,
                     num_samples,
                     num_features,
                     device=None):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tg_feature_model.eval()
    tg_model.eval()
    if free_model is not None:
        free_model.eval()
    features = np.zeros([num_samples, num_features])
    start_idx = 0
    with torch.no_grad():
        for inputs, targets in evalloader:
            inputs = inputs.to(device)
            if is_start_iteration:
                the_feature = tg_feature_model(inputs)
            else:
                the_feature = process_inputs_fp(tg_model,
                                                free_model,
                                                inputs,
                                                feature_mode=True)
            # Move the features to CPU before filling the preallocated numpy buffer
            features[start_idx:start_idx + inputs.shape[0], :] = \
                np.squeeze(the_feature.cpu())
            start_idx = start_idx + inputs.shape[0]
    assert (start_idx == num_samples)
    return features
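

# --- Hedged usage sketch (not from the original repo): compute_features expects a
# feature extractor that maps a batch to a (batch, num_features) tensor; in this
# code base it is typically built by stripping the classifier head. The toy model,
# loader, and sizes below are illustrative assumptions only.
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

toy_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 16), nn.Linear(16, 4))
toy_feature_model = nn.Sequential(*list(toy_model.children())[:-1])  # drop the head
toy_data = TensorDataset(torch.randn(10, 3, 8, 8), torch.zeros(10, dtype=torch.long))
toy_loader = DataLoader(toy_data, batch_size=4, shuffle=False)

toy_feats = compute_features(toy_model, None, toy_feature_model,
                             is_start_iteration=True, evalloader=toy_loader,
                             num_samples=10, num_features=16,
                             device=torch.device("cpu"))
print(toy_feats.shape)  # (10, 16)
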
# Example 2
def incremental_train_and_eval(the_args,
                               epochs,
                               fusion_vars,
                               ref_fusion_vars,
                               b1_model,
                               ref_model,
                               b2_model,
                               ref_b2_model,
                               tg_optimizer,
                               tg_lr_scheduler,
                               fusion_optimizer,
                               fusion_lr_scheduler,
                               trainloader,
                               testloader,
                               balancedloader,
                               iteration,
                               start_iteration,
                               X_protoset_cumuls,
                               Y_protoset_cumuls,
                               order_list,
                               lamda,
                               dist,
                               K,
                               lw_mr,
                               fix_bn=False,
                               weight_per_class=None,
                               device=None):

    T = 2.0
    beta = 0.25
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ref_model.eval()

    num_old_classes = ref_model.fc.out_features
    if iteration > start_iteration + 1:
        ref_b2_model.eval()

    for epoch in range(epochs):
        b1_model.train()
        b2_model.train()

        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()

        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        correct = 0
        total = 0

        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()

        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            inputs, targets = inputs.to(device), targets.to(device)

            tg_optimizer.zero_grad()

            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)

            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
            loss1 = nn.KLDivLoss()(F.log_softmax(outputs[:,:num_old_classes]/T, dim=1), \
                F.softmax(ref_outputs.detach()/T, dim=1)) * T * T * beta * num_old_classes
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss = loss1 + loss2

            loss.backward()
            tg_optimizer.step()

            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print(
            'Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, train loss: {:.4f} accuracy: {:.4f}'
            .format(len(trainloader), train_loss1 / (batch_idx + 1),
                    train_loss2 / (batch_idx + 1),
                    train_loss / (batch_idx + 1), 100. * correct / total))

        b1_model.eval()
        b2_model.eval()

        for batch_idx, (inputs, targets) in enumerate(balancedloader):
            fusion_optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss.backward()
            fusion_optimizer.step()

        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))

    print("Removing register forward hook")
    return b1_model, b2_model
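

# --- Minimal, self-contained sketch (for illustration only, not repo code) of the
# temperature-scaled distillation term used as loss1 above: KL divergence between
# the softened old-class logits of the current model and the softened logits of
# the reference model, rescaled by T*T, beta, and the number of old classes.
import torch
import torch.nn as nn
import torch.nn.functional as F

T, beta = 2.0, 0.25
num_old_classes, batch = 5, 4
outputs = torch.randn(batch, num_old_classes + 3)   # current model: old + new classes
ref_outputs = torch.randn(batch, num_old_classes)   # reference (old) model logits

kd_loss = nn.KLDivLoss()(F.log_softmax(outputs[:, :num_old_classes] / T, dim=1),
                         F.softmax(ref_outputs.detach() / T, dim=1)) * T * T * beta * num_old_classes
print(kd_loss.item())
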
def incremental_train_and_eval(the_args,
                               epochs,
                               fusion_vars,
                               ref_fusion_vars,
                               b1_model,
                               ref_model,
                               b2_model,
                               ref_b2_model,
                               tg_optimizer,
                               tg_lr_scheduler,
                               fusion_optimizer,
                               fusion_lr_scheduler,
                               trainloader,
                               testloader,
                               iteration,
                               start_iteration,
                               X_protoset_cumuls,
                               Y_protoset_cumuls,
                               order_list,
                               lamda,
                               dist,
                               K,
                               lw_mr,
                               balancedloader,
                               T=None,
                               beta=None,
                               fix_bn=False,
                               weight_per_class=None,
                               device=None):

    # Setting up the CUDA device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Set the 1st branch reference model to the evaluation mode
    ref_model.eval()

    # Get the number of old classes
    num_old_classes = ref_model.fc.out_features

    # If a 2nd-branch reference model exists (iteration > start_iteration + 1), set it to the evaluation mode
    if iteration > start_iteration + 1:
        ref_b2_model.eval()

    for epoch in range(epochs):
        # Start training for the current phase, set the two branch models to the training mode
        b1_model.train()
        b2_model.train()

        # Fix the batch norm parameters according to the config
        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()

        # Set all the losses to zeros
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        # Set the counters to zeros
        correct = 0
        total = 0

        # Learning rate decay
        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()

        # Print the information
        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            # Get a batch of training samples, transfer them to the device
            inputs, targets = inputs.to(device), targets.to(device)

            # Clear the gradients of the parameters for the tg_optimizer
            tg_optimizer.zero_grad()

            # Forward the samples in the deep networks
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)

            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
            # Loss 1: knowledge distillation on the old-class logits (softened with temperature T)
            loss1 = nn.KLDivLoss()(F.log_softmax(outputs[:,:num_old_classes]/T, dim=1), \
                F.softmax(ref_outputs.detach()/T, dim=1)) * T * T * beta * num_old_classes
            # Loss 2: classification loss
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            # Sum up all losses
            loss = loss1 + loss2

            # Backward and update the parameters
            loss.backward()
            tg_optimizer.step()

            # Record the losses and the number of samples to compute the accuracy
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        # Print the training losses and accuracies
        print(
            'Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, train loss: {:.4f} accuracy: {:.4f}'
            .format(len(trainloader), train_loss1 / (batch_idx + 1),
                    train_loss2 / (batch_idx + 1),
                    train_loss / (batch_idx + 1), 100. * correct / total))

        # Update the aggregation weights
        b1_model.eval()
        b2_model.eval()

        for batch_idx, (inputs, targets) in enumerate(balancedloader):
            fusion_optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss.backward()
            fusion_optimizer.step()

        # Running the test for this epoch
        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))

    print("Removing register forward hook")
    return b1_model, b2_model
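

# --- Hedged sketch of the aggregation-weight idea behind fusion_vars (the names
# branch_a, branch_b, and head are illustrative assumptions): a learnable scalar
# alpha mixes two branch outputs, and the fusion optimizer updates alpha alone,
# mirroring the balanced-loader pass above that only steps fusion_optimizer.
import torch
import torch.nn as nn

branch_a, branch_b, head = nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 3)
alpha = nn.Parameter(torch.tensor(0.5))            # one aggregation weight
fusion_optimizer = torch.optim.SGD([alpha], lr=0.1)

x, y = torch.randn(4, 8), torch.randint(0, 3, (4,))
fused = alpha * branch_a(x) + (1.0 - alpha) * branch_b(x)
loss = nn.CrossEntropyLoss()(head(fused), y)
fusion_optimizer.zero_grad()
loss.backward()
fusion_optimizer.step()                            # only alpha is updated
print(float(alpha))
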
# Example 4
def incremental_train_and_eval(the_args,
                               epochs,
                               fusion_vars,
                               ref_fusion_vars,
                               b1_model,
                               ref_model,
                               b2_model,
                               ref_b2_model,
                               tg_optimizer,
                               tg_lr_scheduler,
                               fusion_optimizer,
                               fusion_lr_scheduler,
                               trainloader,
                               testloader,
                               iteration,
                               start_iteration,
                               X_protoset_cumuls,
                               Y_protoset_cumuls,
                               order_list,
                               lamda,
                               dist,
                               K,
                               lw_mr,
                               fix_bn=False,
                               weight_per_class=None,
                               device=None):

    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ref_model.eval()

    num_old_classes = ref_model.fc.out_features
    handle_ref_features = ref_model.fc.register_forward_hook(get_ref_features)
    handle_cur_features = b1_model.fc.register_forward_hook(get_cur_features)
    handle_old_scores_bs = b1_model.fc.fc1.register_forward_hook(
        get_old_scores_before_scale)
    handle_new_scores_bs = b1_model.fc.fc2.register_forward_hook(
        get_new_scores_before_scale)
    if iteration > start_iteration + 1:
        ref_b2_model.eval()

    for epoch in range(epochs):
        b1_model.train()
        b2_model.train()

        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        train_loss3 = 0
        correct = 0
        total = 0

        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()

        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            inputs, targets = inputs.to(device), targets.to(device)

            tg_optimizer.zero_grad()

            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)

            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * lamda
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features_new.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * lamda
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            outputs_bs = torch.cat((old_scores, new_scores), dim=1)
            assert (outputs_bs.size() == outputs.size())
            gt_index = torch.zeros(outputs_bs.size()).to(device)
            gt_index = gt_index.scatter(1, targets.view(-1, 1), 1).ge(0.5)
            gt_scores = outputs_bs.masked_select(gt_index)
            max_novel_scores = outputs_bs[:, num_old_classes:].topk(K,
                                                                    dim=1)[0]
            hard_index = targets.lt(num_old_classes)
            hard_num = torch.nonzero(hard_index).size(0)
            if hard_num > 0:
                gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, K)
                max_novel_scores = max_novel_scores[hard_index]
                assert (gt_scores.size() == max_novel_scores.size())
                assert (gt_scores.size(0) == hard_num)
                loss3 = nn.MarginRankingLoss(margin=dist)(
                    gt_scores.view(-1, 1), max_novel_scores.view(-1, 1),
                    torch.ones(hard_num * K).to(device)) * lw_mr
            else:
                loss3 = torch.zeros(1).to(device)
            loss = loss1 + loss2 + loss3

            loss.backward()
            tg_optimizer.step()

            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        print(
            'Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, train loss3: {:.4f}, train loss: {:.4f} accuracy: {:.4f}'
            .format(len(trainloader), train_loss1 / (batch_idx + 1),
                    train_loss2 / (batch_idx + 1),
                    train_loss3 / (batch_idx + 1),
                    train_loss / (batch_idx + 1), 100. * correct / total))

        b1_model.eval()
        b2_model.eval()

        for batch_idx, (inputs, targets) in enumerate(testloader):
            fusion_optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss.backward()
            fusion_optimizer.step()
        print('Fusion index:    mtl1, free1, mtl2, free2, mtl3, free3')
        print(
            'Fusion vars:     {:.2f}, {:.2f},  {:.2f}, {:.2f},  {:.2f}, {:.2f}'
            .format(float(fusion_vars[0]), 1.0 - float(fusion_vars[0]),
                    float(fusion_vars[1]), 1.0 - float(fusion_vars[1]),
                    float(fusion_vars[2]), 1.0 - float(fusion_vars[2])))
        print(
            'Ref fusion vars: {:.2f}, {:.2f},  {:.2f}, {:.2f},  {:.2f}, {:.2f}'
            .format(float(ref_fusion_vars[0]), 1.0 - float(ref_fusion_vars[0]),
                    float(ref_fusion_vars[1]), 1.0 - float(ref_fusion_vars[1]),
                    float(ref_fusion_vars[2]),
                    1.0 - float(ref_fusion_vars[2])))

        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))

    print("Removing register forward hook")
    handle_ref_features.remove()
    handle_cur_features.remove()
    handle_old_scores_bs.remove()
    handle_new_scores_bs.remove()
    return b1_model, b2_model
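

# --- Hedged sketch (not part of the original repo) of the two auxiliary terms used
# above: loss1 is a cosine-embedding distillation between current and reference
# features, and loss3 is a margin-ranking term that keeps each old-class sample's
# ground-truth score above its K highest novel-class scores by at least `dist`.
# All tensor names and sizes below are illustrative assumptions.
import torch
import torch.nn as nn

lamda, dist, K, lw_mr = 5.0, 0.5, 2, 1.0
num_old_classes, batch = 5, 4

# Loss 1: pull current features toward the (detached) reference features.
cur_feat = torch.randn(batch, 64, requires_grad=True)
ref_feat = torch.randn(batch, 64)
loss1 = nn.CosineEmbeddingLoss()(cur_feat, ref_feat.detach(),
                                 torch.ones(batch)) * lamda

# Loss 3: margin ranking on the pre-scaling scores (5 old + 3 new classes).
outputs_bs = torch.randn(batch, 8)
targets = torch.tensor([1, 3, 6, 0])                 # samples 0, 1, 3 are "old"
gt_index = torch.zeros(outputs_bs.size()).scatter(1, targets.view(-1, 1), 1).ge(0.5)
gt_scores = outputs_bs.masked_select(gt_index)
max_novel_scores = outputs_bs[:, num_old_classes:].topk(K, dim=1)[0]
hard_index = targets.lt(num_old_classes)
hard_num = int(hard_index.sum())
gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, K)
max_novel_scores = max_novel_scores[hard_index]
loss3 = nn.MarginRankingLoss(margin=dist)(gt_scores.view(-1, 1),
                                          max_novel_scores.view(-1, 1),
                                          torch.ones(hard_num * K, 1)) * lw_mr
print(loss1.item(), loss3.item())
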
# Example 5
    def train(self):
        self.train_writer = SummaryWriter(logdir=self.save_path)
        dictionary_size = self.dictionary_size
        top1_acc_list_cumul = np.zeros(
            (int(self.args.num_classes / self.args.nb_cl), 4,
             self.args.nb_runs))
        top1_acc_list_ori = np.zeros(
            (int(self.args.num_classes / self.args.nb_cl), 4,
             self.args.nb_runs))
        X_train_total = np.array(self.trainset.train_data)
        Y_train_total = np.array(self.trainset.train_labels)
        X_valid_total = np.array(self.testset.test_data)
        Y_valid_total = np.array(self.testset.test_labels)
        np.random.seed(1993)
        for iteration_total in range(self.args.nb_runs):
            order_name = osp.join(
                self.save_path,
                "seed_{}_{}_order_run_{}.pkl".format(1993, self.args.dataset,
                                                     iteration_total))
            print("Order name:{}".format(order_name))
            if osp.exists(order_name):
                print("Loading orders")
                order = utils.misc.unpickle(order_name)
            else:
                print("Generating orders")
                order = np.arange(self.args.num_classes)
                np.random.shuffle(order)
                utils.misc.savepickle(order, order_name)
            order_list = list(order)
            print(order_list)
        np.random.seed(self.args.random_seed)
        X_valid_cumuls = []
        X_protoset_cumuls = []
        X_train_cumuls = []
        Y_valid_cumuls = []
        Y_protoset_cumuls = []
        Y_train_cumuls = []
        alpha_dr_herding = np.zeros(
            (int(self.args.num_classes / self.args.nb_cl), dictionary_size,
             self.args.nb_cl), np.float32)
        prototypes = np.zeros(
            (self.args.num_classes, dictionary_size, X_train_total.shape[1],
             X_train_total.shape[2], X_train_total.shape[3]))
        for orde in range(self.args.num_classes):
            prototypes[orde, :, :, :, :] = X_train_total[np.where(
                Y_train_total == order[orde])]
        start_iter = int(self.args.nb_cl_fg / self.args.nb_cl) - 1
        for iteration in range(start_iter,
                               int(self.args.num_classes / self.args.nb_cl)):
            if iteration == start_iter:
                last_iter = 0
                tg_model = self.network(num_classes=self.args.nb_cl_fg)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("Out_features:", out_features)
                ref_model = None
                free_model = None
                ref_free_model = None
            elif iteration == start_iter + 1:
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                print("Fusion Mode: " + self.args.fusion_mode)
                tg_model = self.network_mtl(num_classes=self.args.nb_cl_fg)
                ref_dict = ref_model.state_dict()
                tg_dict = tg_model.state_dict()
                tg_dict.update(ref_dict)
                tg_model.load_state_dict(tg_dict)
                tg_model.to(self.device)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("Out_features:", out_features)
                new_fc = modified_linear.SplitCosineLinear(
                    in_features, out_features, self.args.nb_cl)
                new_fc.fc1.weight.data = tg_model.fc.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = out_features * 1.0 / self.args.nb_cl
            else:
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                in_features = tg_model.fc.in_features
                out_features1 = tg_model.fc.fc1.out_features
                out_features2 = tg_model.fc.fc2.out_features
                print("Out_features:", out_features1 + out_features2)
                new_fc = modified_linear.SplitCosineLinear(
                    in_features, out_features1 + out_features2,
                    self.args.nb_cl)
                new_fc.fc1.weight.data[:
                                       out_features1] = tg_model.fc.fc1.weight.data
                new_fc.fc1.weight.data[
                    out_features1:] = tg_model.fc.fc2.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = (out_features1 +
                              out_features2) * 1.0 / (self.args.nb_cl)
            if iteration > start_iter:
                cur_lamda = self.args.lamda * math.sqrt(lamda_mult)
            else:
                cur_lamda = self.args.lamda
            actual_cl = order[range(last_iter * self.args.nb_cl,
                                    (iteration + 1) * self.args.nb_cl)]
            indices_train_10 = np.array([
                i in order[range(last_iter * self.args.nb_cl,
                                 (iteration + 1) * self.args.nb_cl)]
                for i in Y_train_total
            ])
            indices_test_10 = np.array([
                i in order[range(last_iter * self.args.nb_cl,
                                 (iteration + 1) * self.args.nb_cl)]
                for i in Y_valid_total
            ])
            X_train = X_train_total[indices_train_10]
            X_valid = X_valid_total[indices_test_10]
            X_valid_cumuls.append(X_valid)
            X_train_cumuls.append(X_train)
            X_valid_cumul = np.concatenate(X_valid_cumuls)
            X_train_cumul = np.concatenate(X_train_cumuls)
            Y_train = Y_train_total[indices_train_10]
            Y_valid = Y_valid_total[indices_test_10]
            Y_valid_cumuls.append(Y_valid)
            Y_train_cumuls.append(Y_train)
            Y_valid_cumul = np.concatenate(Y_valid_cumuls)
            Y_train_cumul = np.concatenate(Y_train_cumuls)
            if iteration == start_iter:
                X_valid_ori = X_valid
                Y_valid_ori = Y_valid
            else:
                X_protoset = np.concatenate(X_protoset_cumuls)
                Y_protoset = np.concatenate(Y_protoset_cumuls)
                if self.args.rs_ratio > 0:
                    scale_factor = (len(X_train) * self.args.rs_ratio) / (
                        len(X_protoset) * (1 - self.args.rs_ratio))
                    rs_sample_weights = np.concatenate(
                        (np.ones(len(X_train)),
                         np.ones(len(X_protoset)) * scale_factor))
                    rs_num_samples = int(
                        len(X_train) / (1 - self.args.rs_ratio))
                    print(
                        "X_train:{}, X_protoset:{}, rs_num_samples:{}".format(
                            len(X_train), len(X_protoset), rs_num_samples))
                X_train = np.concatenate((X_train, X_protoset), axis=0)
                Y_train = np.concatenate((Y_train, Y_protoset))
            print('Batch of classes number {0} arrives'.format(iteration + 1))
            map_Y_train = np.array([order_list.index(i) for i in Y_train])
            map_Y_valid_cumul = np.array(
                [order_list.index(i) for i in Y_valid_cumul])
            is_start_iteration = (iteration == start_iter)
            if iteration > start_iter:
                old_embedding_norm = tg_model.fc.fc1.weight.data.norm(
                    dim=1, keepdim=True)
                average_old_embedding_norm = torch.mean(old_embedding_norm,
                                                        dim=0).to('cpu').type(
                                                            torch.DoubleTensor)
                tg_feature_model = nn.Sequential(
                    *list(tg_model.children())[:-1])
                num_features = tg_model.fc.in_features
                novel_embedding = torch.zeros((self.args.nb_cl, num_features))
                for cls_idx in range(iteration * self.args.nb_cl,
                                     (iteration + 1) * self.args.nb_cl):
                    cls_indices = np.array([i == cls_idx for i in map_Y_train])
                    assert (len(
                        np.where(cls_indices == 1)[0]) == dictionary_size)
                    self.evalset.test_data = X_train[cls_indices].astype(
                        'uint8')
                    self.evalset.test_labels = np.zeros(
                        self.evalset.test_data.shape[0])
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    cls_features = compute_features(tg_model, free_model,
                                                    tg_feature_model,
                                                    is_start_iteration,
                                                    evalloader, num_samples,
                                                    num_features)
                    norm_features = F.normalize(torch.from_numpy(cls_features),
                                                p=2,
                                                dim=1)
                    cls_embedding = torch.mean(norm_features, dim=0)
                    novel_embedding[cls_idx -
                                    iteration * self.args.nb_cl] = F.normalize(
                                        cls_embedding, p=2,
                                        dim=0) * average_old_embedding_norm
                tg_model.to(self.device)
                tg_model.fc.fc2.weight.data = novel_embedding.to(self.device)
            self.trainset.train_data = X_train.astype('uint8')
            self.trainset.train_labels = map_Y_train
            if iteration > start_iter and self.args.rs_ratio > 0 and scale_factor > 1:
                print("Weights from sampling:", rs_sample_weights)
                index1 = np.where(rs_sample_weights > 1)[0]
                index2 = np.where(map_Y_train < iteration * self.args.nb_cl)[0]
                assert ((index1 == index2).all())
                train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                    rs_sample_weights, rs_num_samples)
                trainloader = torch.utils.data.DataLoader(
                    self.trainset,
                    batch_size=self.args.train_batch_size,
                    shuffle=False,
                    sampler=train_sampler,
                    num_workers=self.args.num_workers)
            else:
                trainloader = torch.utils.data.DataLoader(
                    self.trainset,
                    batch_size=self.args.train_batch_size,
                    shuffle=True,
                    num_workers=self.args.num_workers)
            self.testset.test_data = X_valid_cumul.astype('uint8')
            self.testset.test_labels = map_Y_valid_cumul
            testloader = torch.utils.data.DataLoader(
                self.testset,
                batch_size=self.args.test_batch_size,
                shuffle=False,
                num_workers=self.args.num_workers)
            print('Max and min of train labels: {}, {}'.format(
                min(map_Y_train), max(map_Y_train)))
            print('Max and min of valid labels: {}, {}'.format(
                min(map_Y_valid_cumul), max(map_Y_valid_cumul)))
            ckp_name = osp.join(
                self.save_path,
                'run_{}_iteration_{}_model.pth'.format(iteration_total,
                                                       iteration))
            ckp_name_free = osp.join(
                self.save_path, 'run_{}_iteration_{}_free_model.pth'.format(
                    iteration_total, iteration))
            print('Checkpoint name:', ckp_name)
            if iteration == start_iter and self.args.resume_fg:
                print("Loading first group models from checkpoint")
                tg_model = torch.load(self.args.ckpt_dir_fg)
            elif self.args.resume and os.path.exists(ckp_name):
                print("Loading models from checkpoint")
                tg_model = torch.load(ckp_name)
            else:
                if iteration > start_iter:
                    ref_model = ref_model.to(self.device)
                    ignored_params = list(map(id,
                                              tg_model.fc.fc1.parameters()))
                    base_params = filter(lambda p: id(p) not in ignored_params,
                                         tg_model.parameters())
                    base_params = filter(lambda p: p.requires_grad,
                                         base_params)
                    tg_params_new = [{
                        'params':
                        base_params,
                        'lr':
                        self.args.base_lr2,
                        'weight_decay':
                        self.args.custom_weight_decay
                    }, {
                        'params': tg_model.fc.fc1.parameters(),
                        'lr': 0,
                        'weight_decay': 0
                    }]
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(
                        tg_params_new,
                        lr=self.args.base_lr2,
                        momentum=self.args.custom_momentum,
                        weight_decay=self.args.custom_weight_decay)
                else:
                    tg_params = tg_model.parameters()
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(
                        tg_params,
                        lr=self.args.base_lr1,
                        momentum=self.args.custom_momentum,
                        weight_decay=self.args.custom_weight_decay)
                if iteration > start_iter:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(
                        tg_optimizer,
                        milestones=self.lr_strat,
                        gamma=self.args.lr_factor)
                else:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(
                        tg_optimizer,
                        milestones=self.lr_strat_first_phase,
                        gamma=self.args.lr_factor)
                print("Incremental train")
                if iteration > start_iter:
                    tg_model = incremental_train_and_eval(
                        self.args.epochs, tg_model, ref_model, free_model,
                        ref_free_model, tg_optimizer, tg_lr_scheduler,
                        trainloader, testloader, iteration, start_iter,
                        cur_lamda, self.args.dist, self.args.K,
                        self.args.lw_mr)
                else:
                    tg_model = incremental_train_and_eval(
                        self.args.epochs, tg_model, ref_model, tg_optimizer,
                        tg_lr_scheduler, trainloader, testloader, iteration,
                        start_iter, cur_lamda, self.args.dist, self.args.K,
                        self.args.lw_mr)
                torch.save(tg_model, ckp_name)
            if self.args.fix_budget:
                nb_protos_cl = int(
                    np.ceil(self.args.nb_protos * 100. / self.args.nb_cl /
                            (iteration + 1)))
            else:
                nb_protos_cl = self.args.nb_protos
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            num_features = tg_model.fc.in_features
            for iter_dico in range(last_iter * self.args.nb_cl,
                                   (iteration + 1) * self.args.nb_cl):
                self.evalset.test_data = prototypes[iter_dico].astype('uint8')
                self.evalset.test_labels = np.zeros(
                    self.evalset.test_data.shape[0])
                evalloader = torch.utils.data.DataLoader(
                    self.evalset,
                    batch_size=self.args.eval_batch_size,
                    shuffle=False,
                    num_workers=self.args.num_workers)
                num_samples = self.evalset.test_data.shape[0]
                mapped_prototypes = compute_features(tg_model, free_model,
                                                     tg_feature_model,
                                                     is_start_iteration,
                                                     evalloader, num_samples,
                                                     num_features)
                D = mapped_prototypes.T
                D = D / np.linalg.norm(D, axis=0)
                mu = np.mean(D, axis=1)
                index1 = int(iter_dico / self.args.nb_cl)
                index2 = iter_dico % self.args.nb_cl
                alpha_dr_herding[index1, :,
                                 index2] = alpha_dr_herding[index1, :,
                                                            index2] * 0
                w_t = mu
                iter_herding = 0
                iter_herding_eff = 0
                while not (np.sum(alpha_dr_herding[index1, :, index2] != 0)
                           == min(nb_protos_cl,
                                  500)) and iter_herding_eff < 1000:
                    tmp_t = np.dot(w_t, D)
                    ind_max = np.argmax(tmp_t)

                    iter_herding_eff += 1
                    if alpha_dr_herding[index1, ind_max, index2] == 0:
                        alpha_dr_herding[index1, ind_max,
                                         index2] = 1 + iter_herding
                        iter_herding += 1
                    w_t = w_t + mu - D[:, ind_max]
            X_protoset_cumuls = []
            Y_protoset_cumuls = []
            class_means = np.zeros((64, 100, 2))
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    current_cl = order[range(iteration2 * self.args.nb_cl,
                                             (iteration2 + 1) *
                                             self.args.nb_cl)]
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl +
                        iter_dico].astype('uint8')
                    self.evalset.test_labels = np.zeros(
                        self.evalset.test_data.shape[0])  #zero labels
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    mapped_prototypes = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl +
                        iter_dico][:, :, :, ::-1].astype('uint8')
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    mapped_prototypes2 = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D2 = mapped_prototypes2.T
                    D2 = D2 / np.linalg.norm(D2, axis=0)
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    X_protoset_cumuls.append(
                        prototypes[iteration2 * self.args.nb_cl + iter_dico,
                                   np.where(alph == 1)[0]])
                    Y_protoset_cumuls.append(
                        order[iteration2 * self.args.nb_cl + iter_dico] *
                        np.ones(len(np.where(alph == 1)[0])))
                    alph = alph / np.sum(alph)
                    class_means[:, current_cl[iter_dico],
                                0] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 0])
                    alph = np.ones(dictionary_size) / dictionary_size
                    class_means[:, current_cl[iter_dico],
                                1] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 1])
            current_means = class_means[:, order[range(0, (iteration + 1) *
                                                       self.args.nb_cl)]]
            X_protoset_array_old = np.array(X_protoset_cumuls)
            self.T = self.args.mnemonics_steps * self.args.mnemonics_epochs
            self.img_size = 32
            self.mnemonics_lrs = self.args.mnemonics_lr
            num_classes_incremental = self.args.nb_cl
            num_classes = self.args.nb_cl_fg
            nb_cl = self.args.nb_cl
            transform_proto = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4866, 0.4409),
                                     (0.2009, 0.1984, 0.2023)),
            ])
            self.mnemonics_label = []
            if iteration == start_iter:
                the_X_protoset_array = np.array(X_protoset_cumuls).astype(
                    'uint8')
                the_Y_protoset_cumuls = np.array(Y_protoset_cumuls)
            else:
                the_X_protoset_array = np.array(
                    X_protoset_cumuls[-num_classes_incremental:]).astype(
                        'uint8')
                the_Y_protoset_cumuls = np.array(
                    Y_protoset_cumuls[-num_classes_incremental:])
            self.mnemonics_data = torch.zeros(the_X_protoset_array.shape[0],
                                              the_X_protoset_array.shape[1], 3,
                                              self.img_size, self.img_size)
            for idx1 in range(the_X_protoset_array.shape[0]):
                for idx2 in range(the_X_protoset_array.shape[1]):
                    the_img = the_X_protoset_array[idx1][idx2]
                    the_PIL_image = Image.fromarray(the_img)
                    the_PIL_image = transform_proto(the_PIL_image)
                    self.mnemonics_data[idx1][idx2] = the_PIL_image
                map_Y_label = self.map_labels(order_list,
                                              the_Y_protoset_cumuls[idx1])
                self.mnemonics_label.append(map_Y_label)
            self.mnemonics = nn.ParameterList()
            self.mnemonics.append(nn.Parameter(self.mnemonics_data))
            start_iteration = start_iter
            device = self.device
            self.mnemonics.to(device)
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            tg_feature_model.eval()
            tg_model.eval()
            if free_model is not None:
                free_model.eval()
            self.mnemonics_optimizer = optim.SGD(
                self.mnemonics,
                lr=self.args.mnemonics_outer_lr,
                momentum=0.9,
                weight_decay=5e-4)
            self.mnemonics_lr_scheduler = optim.lr_scheduler.StepLR(
                self.mnemonics_optimizer,
                step_size=self.args.mnemonics_decay_epochs,
                gamma=self.args.mnemonics_decay_factor)
            current_means_new = current_means[:, :, 0].T
            for epoch in range(self.args.mnemonics_total_epochs):
                train_loss = 0
                self.mnemonics_lr_scheduler.step()
                for batch_idx, (q_inputs, q_targets) in enumerate(trainloader):
                    q_inputs, q_targets = q_inputs.to(device), q_targets.to(
                        device)
                    if iteration == start_iteration:
                        q_feature = tg_feature_model(q_inputs)
                    else:
                        q_feature = process_inputs_fp(tg_model,
                                                      free_model,
                                                      q_inputs,
                                                      feature_mode=True)
                    self.mnemonics_optimizer.zero_grad()
                    total_tr_loss = 0
                    if iteration == start_iteration:
                        mnemonics_outputs = tg_feature_model(
                            self.mnemonics[0][0])
                    else:
                        mnemonics_outputs = process_inputs_fp(
                            tg_model,
                            free_model,
                            self.mnemonics[0][0],
                            feature_mode=True)
                    this_class_mean_mnemonics = torch.mean(mnemonics_outputs,
                                                           dim=0)
                    this_class_mean_mnemonics = torch.squeeze(
                        this_class_mean_mnemonics)
                    total_class_mean_mnemonics = this_class_mean_mnemonics.unsqueeze(
                        dim=0)
                    for mnemonics_idx in range(len(self.mnemonics[0]) - 1):
                        if iteration == start_iteration:
                            mnemonics_outputs = tg_feature_model(
                                self.mnemonics[0][mnemonics_idx + 1])
                        else:
                            mnemonics_outputs = process_inputs_fp(
                                tg_model,
                                free_model,
                                self.mnemonics[0][mnemonics_idx + 1],
                                feature_mode=True)
                        this_class_mean_mnemonics = torch.mean(
                            mnemonics_outputs, dim=0)
                        this_class_mean_mnemonics = torch.squeeze(
                            this_class_mean_mnemonics)
                        total_class_mean_mnemonics = torch.cat(
                            (total_class_mean_mnemonics,
                             this_class_mean_mnemonics.unsqueeze(dim=0)),
                            dim=0)
                    if iteration == start_iteration:
                        all_cls_means = total_class_mean_mnemonics
                    else:
                        all_cls_means = torch.tensor(
                            current_means_new).float().to(device)
                        all_cls_means[-nb_cl:] = total_class_mean_mnemonics
                    the_logits = F.linear(
                        F.normalize(torch.squeeze(q_feature), p=2, dim=1),
                        F.normalize(all_cls_means, p=2, dim=1))
                    loss = F.cross_entropy(the_logits, q_targets)
                    loss.backward()
                    # Update the mnemonics exemplars with the accumulated gradients
                    self.mnemonics_optimizer.step()
                    train_loss += loss.item()
            X_protoset_cumuls = process_mnemonics(
                X_protoset_cumuls, Y_protoset_cumuls, self.mnemonics,
                self.mnemonics_label, order_list, self.args.nb_cl_fg,
                self.args.nb_cl, iteration, start_iter)
            X_protoset_array = np.array(X_protoset_cumuls)
            X_protoset_cumuls_idx = 0
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    this_X_protoset_array = X_protoset_array[
                        X_protoset_cumuls_idx]
                    X_protoset_cumuls_idx += 1
                    this_X_protoset_array = this_X_protoset_array.astype(
                        np.float64)
                    prototypes[iteration2 * self.args.nb_cl + iter_dico,
                               np.where(alph == 1)[0]] = this_X_protoset_array
            class_means = np.zeros((64, 100, 2))
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    current_cl = order[range(iteration2 * self.args.nb_cl,
                                             (iteration2 + 1) *
                                             self.args.nb_cl)]
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl +
                        iter_dico].astype('uint8')
                    self.evalset.test_labels = np.zeros(
                        self.evalset.test_data.shape[0])  #zero labels
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    mapped_prototypes = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl +
                        iter_dico][:, :, :, ::-1].astype('uint8')
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    mapped_prototypes2 = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D2 = mapped_prototypes2.T
                    D2 = D2 / np.linalg.norm(D2, axis=0)
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    alph = alph / np.sum(alph)
                    class_means[:, current_cl[iter_dico],
                                0] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 0])
                    alph = np.ones(dictionary_size) / dictionary_size
                    class_means[:, current_cl[iter_dico],
                                1] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 1])
            torch.save(
                class_means,
                osp.join(
                    self.save_path,
                    'run_{}_iteration_{}_class_means.pth'.format(
                        iteration_total, iteration)))
            current_means = class_means[:, order[range(0, (iteration + 1) *
                                                       self.args.nb_cl)]]
            is_start_iteration = (iteration == start_iter)
            map_Y_valid_ori = np.array(
                [order_list.index(i) for i in Y_valid_ori])
            print('Computing accuracy on the original batch of classes')
            self.evalset.test_data = X_valid_ori.astype('uint8')
            self.evalset.test_labels = map_Y_valid_ori
            evalloader = torch.utils.data.DataLoader(
                self.evalset,
                batch_size=self.args.eval_batch_size,
                shuffle=False,
                num_workers=self.args.num_workers)
            ori_acc, fast_fc = compute_accuracy(
                tg_model,
                free_model,
                tg_feature_model,
                current_means,
                X_protoset_cumuls,
                Y_protoset_cumuls,
                evalloader,
                order_list,
                is_start_iteration=is_start_iteration,
                maml_lr=self.args.maml_lr,
                maml_epoch=self.args.maml_epoch)
            top1_acc_list_ori[iteration, :,
                              iteration_total] = np.array(ori_acc).T
            self.train_writer.add_scalar('ori_acc/LwF', float(ori_acc[0]),
                                         iteration)
            self.train_writer.add_scalar('ori_acc/iCaRL', float(ori_acc[1]),
                                         iteration)
            map_Y_valid_cumul = np.array(
                [order_list.index(i) for i in Y_valid_cumul])
            print('Computing cumulative accuracy')
            self.evalset.test_data = X_valid_cumul.astype('uint8')
            self.evalset.test_labels = map_Y_valid_cumul
            evalloader = torch.utils.data.DataLoader(
                self.evalset,
                batch_size=self.args.eval_batch_size,
                shuffle=False,
                num_workers=self.args.num_workers)
            cumul_acc, _ = compute_accuracy(
                tg_model,
                free_model,
                tg_feature_model,
                current_means,
                X_protoset_cumuls,
                Y_protoset_cumuls,
                evalloader,
                order_list,
                is_start_iteration=is_start_iteration,
                fast_fc=fast_fc,
                maml_lr=self.args.maml_lr,
                maml_epoch=self.args.maml_epoch)
            top1_acc_list_cumul[iteration, :,
                                iteration_total] = np.array(cumul_acc).T
            self.train_writer.add_scalar('cumul_acc/LwF', float(cumul_acc[0]),
                                         iteration)
            self.train_writer.add_scalar('cumul_acc/iCaRL',
                                         float(cumul_acc[1]), iteration)
        torch.save(
            top1_acc_list_ori,
            osp.join(self.save_path,
                     'run_{}_top1_acc_list_ori.pth'.format(iteration_total)))
        torch.save(
            top1_acc_list_cumul,
            osp.join(self.save_path,
                     'run_{}_top1_acc_list_cumul.pth'.format(iteration_total)))
        self.train_writer.close()
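

# --- Hedged sketch of the herding selection used in the while loop inside train()
# above: exemplars are picked greedily so that the running mean of the selected
# features stays close to the class mean mu. The feature matrix D (features x
# samples) is random here, purely for illustration; in the code above it comes
# from compute_features.
import numpy as np

rng = np.random.default_rng(0)
D = rng.normal(size=(64, 20))
D = D / np.linalg.norm(D, axis=0)
mu = np.mean(D, axis=1)

nb_protos_cl = 5
alpha_herding = np.zeros(D.shape[1])
w_t, rank, steps = mu, 0, 0
while np.sum(alpha_herding != 0) < nb_protos_cl and steps < 1000:
    ind_max = np.argmax(np.dot(w_t, D))
    steps += 1
    if alpha_herding[ind_max] == 0:
        rank += 1
        alpha_herding[ind_max] = rank              # lower rank = selected earlier
    w_t = w_t + mu - D[:, ind_max]
print(np.where(alpha_herding > 0)[0])              # indices of the chosen exemplars
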
def incremental_train_and_eval(the_args, epochs, fusion_vars, ref_fusion_vars, b1_model, ref_model, b2_model, ref_b2_model, \
    tg_optimizer, tg_lr_scheduler, fusion_optimizer, fusion_lr_scheduler, trainloader, testloader, iteration, \
    start_iteration, X_protoset_cumuls, Y_protoset_cumuls, order_list, the_lambda, dist, \
    K, lw_mr, balancedloader, fix_bn=False, weight_per_class=None, device=None):

    # Setting up the CUDA device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Set the 1st branch reference model to the evaluation mode
    ref_model.eval()

    # Get the number of old classes
    num_old_classes = ref_model.fc.out_features

    # Get the features from the current and the reference model
    handle_ref_features = ref_model.fc.register_forward_hook(get_ref_features)
    handle_cur_features = b1_model.fc.register_forward_hook(get_cur_features)
    handle_old_scores_bs = b1_model.fc.fc1.register_forward_hook(
        get_old_scores_before_scale)
    handle_new_scores_bs = b1_model.fc.fc2.register_forward_hook(
        get_new_scores_before_scale)

    # If a 2nd-branch reference model exists (iteration > start_iteration + 1), set it to the evaluation mode
    if iteration > start_iteration + 1:
        ref_b2_model.eval()

    for epoch in range(epochs):
        # Start training for the current phase, set the two branch models to the training mode
        b1_model.train()
        b2_model.train()

        # Fix the batch norm parameters according to the config
        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()

        # Set all the losses to zeros
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        train_loss3 = 0
        # Set the counters to zeros
        correct = 0
        total = 0

        # Learning rate decay
        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()

        # Print the information
        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])

        for batch_idx, (inputs, targets) in enumerate(trainloader):

            # Get a batch of training samples, transfer them to the device
            inputs, targets = inputs.to(device), targets.to(device)

            # Clear the gradients of the parameters for the tg_optimizer
            tg_optimizer.zero_grad()

            # Forward the samples in the deep networks
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)

            # Loss 1: feature-level distillation loss
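            # In the first incremental phase the reference features come from the
            # single-branch ref_model; in later phases they come from the fused
            # reference model (ref_model and ref_b2_model with ref_fusion_vars).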
            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * the_lambda
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features_new.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * the_lambda

            # Loss 2: classification loss
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)

            # Loss 3: margin ranking loss
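            # For samples belonging to old classes, the ground-truth score is pushed
            # above each of the top-K novel-class scores by at least the margin `dist`.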
            outputs_bs = torch.cat((old_scores, new_scores), dim=1)
            assert (outputs_bs.size() == outputs.size())
            gt_index = torch.zeros(outputs_bs.size()).to(device)
            gt_index = gt_index.scatter(1, targets.view(-1, 1), 1).ge(0.5)
            gt_scores = outputs_bs.masked_select(gt_index)
            max_novel_scores = outputs_bs[:, num_old_classes:].topk(K,
                                                                    dim=1)[0]
            hard_index = targets.lt(num_old_classes)
            hard_num = torch.nonzero(hard_index).size(0)
            if hard_num > 0:
                gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, K)
                max_novel_scores = max_novel_scores[hard_index]
                assert (gt_scores.size() == max_novel_scores.size())
                assert (gt_scores.size(0) == hard_num)
                loss3 = nn.MarginRankingLoss(margin=dist)(
                    gt_scores.view(-1, 1), max_novel_scores.view(-1, 1),
                    torch.ones(hard_num * K).to(device)) * lw_mr
            else:
                loss3 = torch.zeros(1).to(device)

            # Sum up all losses
            loss = loss1 + loss2 + loss3

            # Backward and update the parameters
            loss.backward()
            tg_optimizer.step()

            # Record the losses and the number of samples to compute the accuracy
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        # Print the training losses and accuracies
        print(
            'Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, train loss3: {:.4f}, train loss: {:.4f}, accuracy: {:.4f}'
            .format(len(trainloader), train_loss1 / (batch_idx + 1),
                    train_loss2 / (batch_idx + 1),
                    train_loss3 / (batch_idx + 1),
                    train_loss / (batch_idx + 1), 100. * correct / total))

        # Update the aggregation weights
        b1_model.eval()
        b2_model.eval()
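        # Only the fusion (aggregation) weights are optimized here, on a
        # class-balanced loader; both branch backbones are kept in eval mode.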

        for batch_idx, (inputs, targets) in enumerate(balancedloader):
            if batch_idx <= 500:
                inputs, targets = inputs.to(device), targets.to(device)
                # Clear stale gradients before each fusion-weight update
                fusion_optimizer.zero_grad()
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                loss.backward()
                fusion_optimizer.step()

        # Running the test for this epoch
        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {}, test loss: {:.4f}, accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))

    print("Removing register forward hook")
    handle_ref_features.remove()
    handle_cur_features.remove()
    handle_old_scores_bs.remove()
    handle_new_scores_bs.remove()
    return b1_model, b2_model
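The hook callbacks registered above (get_ref_features, get_cur_features, get_old_scores_before_scale, get_new_scores_before_scale) are defined elsewhere in the source and are not shown in this example. A minimal sketch of how such callbacks typically stash the intermediate tensors into the module-level variables consumed by the loss terms (names assumed from the training loop above) could look like this:

cur_features = None
ref_features = None
old_scores = None
new_scores = None

def get_ref_features(module, inputs, outputs):
    # Capture the feature vector fed into the reference model's fc layer
    global ref_features
    ref_features = inputs[0]

def get_cur_features(module, inputs, outputs):
    # Capture the feature vector fed into the current model's fc layer
    global cur_features
    cur_features = inputs[0]

def get_old_scores_before_scale(module, inputs, outputs):
    # Capture the old-class logits before any scaling is applied
    global old_scores
    old_scores = outputs

def get_new_scores_before_scale(module, inputs, outputs):
    # Capture the new-class logits before any scaling is applied
    global new_scores
    new_scores = outputs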
Example 7
def compute_accuracy(the_args, fusion_vars, b1_model, b2_model, tg_feature_model, class_means, \
    X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=False, \
    fast_fc=None, scale=None, print_info=True, device=None, cifar=True, imagenet=False, \
    valdir=None, maml_lr=0.1, maml_epoch=50):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    b1_model.eval()
    tg_feature_model.eval()
    if b2_model is not None:
        b2_model.eval()
    fast_fc = 0.0
    correct = 0
    correct_icarl = 0
    correct_icarl_cosine = 0
    correct_icarl_cosine2 = 0
    correct_ncm = 0
    correct_maml = 0
    total = 0
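    # Each batch is scored by several classifiers: the network logits (CNN),
    # nearest-mean on class_means[:, :, 0] with squared Euclidean and cosine
    # distances (the iCaRL variants), and nearest-mean on class_means[:, :, 1]
    # (reported below as iCaRL-UB).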
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(evalloader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)

            if is_start_iteration:
                outputs = b1_model(inputs)
            else:
                outputs, outputs_feature = process_inputs_fp(
                    the_args, fusion_vars, b1_model, b2_model, inputs)

            outputs = F.softmax(outputs, dim=1)
            if scale is not None:
                assert (scale.shape[0] == 1)
                assert (outputs.shape[1] == scale.shape[1])
                outputs = outputs / scale.repeat(outputs.shape[0], 1).type(
                    torch.FloatTensor).to(device)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

            if is_start_iteration:
                outputs_feature = np.squeeze(tg_feature_model(inputs))
            sqd_icarl = cdist(class_means[:, :, 0].T, outputs_feature.cpu(),
                              'sqeuclidean')
            score_icarl = torch.from_numpy((-sqd_icarl).T).to(device)
            _, predicted_icarl = score_icarl.max(1)
            correct_icarl += predicted_icarl.eq(targets).sum().item()
            sqd_icarl_cosine = cdist(class_means[:, :, 0].T,
                                     outputs_feature.cpu(), 'cosine')
            score_icarl_cosine = torch.from_numpy(
                (-sqd_icarl_cosine).T).to(device)
            _, predicted_icarl_cosine = score_icarl_cosine.max(1)
            correct_icarl_cosine += predicted_icarl_cosine.eq(
                targets).sum().item()
            fast_weights = torch.from_numpy(np.float32(
                class_means[:, :, 0].T)).to(device)
            sqd_icarl_cosine2 = F.linear(
                F.normalize(torch.squeeze(outputs_feature), p=2, dim=1),
                F.normalize(fast_weights, p=2, dim=1))
            score_icarl_cosine2 = sqd_icarl_cosine2
            _, predicted_icarl_cosine2 = score_icarl_cosine2.max(1)
            correct_icarl_cosine2 += predicted_icarl_cosine2.eq(
                targets).sum().item()
            sqd_ncm = cdist(class_means[:, :, 1].T, outputs_feature.cpu(),
                            'sqeuclidean')
            score_ncm = torch.from_numpy((-sqd_ncm).T).to(device)
            _, predicted_ncm = score_ncm.max(1)
            correct_ncm += predicted_ncm.eq(targets).sum().item()
    if print_info:
        print("  Top 1 accuracy CNN            :\t\t{:.2f} %".format(
            100. * correct / total))
        print("  Top 1 accuracy iCaRL          :\t\t{:.2f} %".format(
            100. * correct_icarl / total))
        print("  Top 1 accuracy iCaRL-UB       :\t\t{:.2f} %".format(
            100. * correct_ncm / total))
        print("  The above results are the accuracy for the current phase.")
        print(
            "  For the average accuracy, you need to record the results for all phases and calculate the average value."
        )
    cnn_acc = 100. * correct / total
    icarl_acc = 100. * correct_icarl / total
    ncm_acc = 100. * correct_ncm / total
    maml_acc = 0.0
    return [cnn_acc, icarl_acc, ncm_acc, maml_acc], fast_fc
Example 8
def compute_accuracy(tg_model,
                     free_model,
                     tg_feature_model,
                     class_means,
                     X_protoset_cumuls,
                     Y_protoset_cumuls,
                     evalloader,
                     order_list,
                     is_start_iteration=False,
                     fast_fc=None,
                     scale=None,
                     print_info=True,
                     device=None,
                     maml_lr=0.1,
                     maml_epoch=50):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tg_feature_model.eval()
    tg_model.eval()
    if free_model is not None:
        free_model.eval()
    if fast_fc is None:
        transform_proto = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4866, 0.4409),
                                 (0.2009, 0.1984, 0.2023)),
        ])
        protoset = torchvision.datasets.CIFAR100(root='./data',
                                                 train=False,
                                                 download=False,
                                                 transform=transform_proto)
        X_protoset_array = np.array(X_protoset_cumuls).astype('uint8')
        protoset.test_data = X_protoset_array.reshape(
            -1, X_protoset_array.shape[2], X_protoset_array.shape[3],
            X_protoset_array.shape[4])
        Y_protoset_cumuls = np.array(Y_protoset_cumuls).reshape(-1)
        map_Y_protoset_cumuls = map_labels(order_list, Y_protoset_cumuls)
        protoset.test_labels = map_Y_protoset_cumuls
        protoloader = torch.utils.data.DataLoader(protoset,
                                                  batch_size=128,
                                                  shuffle=True,
                                                  num_workers=2)

        fast_fc = torch.from_numpy(np.float32(class_means[:, :,
                                                          0].T)).to(device)
        fast_fc.requires_grad = True
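        # fast_fc is a cosine classifier initialized from the iCaRL class means;
        # it is refined below for a few epochs on the exemplar (protoset) loader.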

        epoch_num = maml_epoch
        for epoch_idx in range(epoch_num):
            for the_inputs, the_targets in protoloader:
                the_inputs, the_targets = the_inputs.to(
                    device), the_targets.to(device)
                the_features = tg_feature_model(the_inputs)
                the_logits = F.linear(
                    F.normalize(torch.squeeze(the_features), p=2, dim=1),
                    F.normalize(fast_fc, p=2, dim=1))
                the_loss = F.cross_entropy(the_logits, the_targets)
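                # Manual inner-loop update: differentiate the loss w.r.t. fast_fc
                # and take a single SGD step with step size maml_lr.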
                the_grad = torch.autograd.grad(the_loss, fast_fc)
                fast_fc = fast_fc - maml_lr * the_grad[0]
    correct = 0
    correct_icarl = 0
    correct_ncm = 0
    correct_maml = 0
    total = 0
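    # Four classifiers are compared: the network logits (CNN), nearest-mean on
    # class_means[:, :, 0] (iCaRL), nearest-mean on class_means[:, :, 1] (NCM),
    # and the fine-tuned cosine classifier fast_fc (reported as maml_acc).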
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(evalloader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            if is_start_iteration:
                outputs = tg_model(inputs)
            else:
                outputs, outputs_feature = process_inputs_fp(
                    tg_model, free_model, inputs)
            outputs = F.softmax(outputs, dim=1)
            if scale is not None:
                assert (scale.shape[0] == 1)
                assert (outputs.shape[1] == scale.shape[1])
                outputs = outputs / scale.repeat(outputs.shape[0], 1).type(
                    torch.FloatTensor).to(device)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()

            if is_start_iteration:
                outputs_feature = np.squeeze(tg_feature_model(inputs))
            sqd_icarl = cdist(class_means[:, :, 0].T, outputs_feature,
                              'sqeuclidean')
            score_icarl = torch.from_numpy((-sqd_icarl).T).to(device)
            _, predicted_icarl = score_icarl.max(1)
            correct_icarl += predicted_icarl.eq(targets).sum().item()
            sqd_ncm = cdist(class_means[:, :, 1].T, outputs_feature,
                            'sqeuclidean')
            score_ncm = torch.from_numpy((-sqd_ncm).T).to(device)
            _, predicted_ncm = score_ncm.max(1)
            correct_ncm += predicted_ncm.eq(targets).sum().item()
            the_logits = F.linear(
                F.normalize(torch.squeeze(outputs_feature), p=2, dim=1),
                F.normalize(fast_fc, p=2, dim=1))
            _, predicted_maml = the_logits.max(1)
            correct_maml += predicted_maml.eq(targets).sum().item()
    cnn_acc = 100. * correct / total
    icarl_acc = 100. * correct_icarl / total
    ncm_acc = 100. * correct_ncm / total
    maml_acc = 100. * correct_maml / total
    if print_info:
        print("  Accuracy for LwF    :\t\t{:.2f} %".format(cnn_acc))
        print("  Accuracy for iCaRL  :\t\t{:.2f} %".format(icarl_acc))
        print("  The above results are the accuracy for the current phase.")
        print(
            "  For the average accuracy, you need to record the results for all phases and calculate the average value."
        )
    return [cnn_acc, icarl_acc, ncm_acc, maml_acc], fast_fc
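The printed note above points out that the returned accuracies cover only the current phase. A minimal sketch of the bookkeeping that averages them over all incremental phases (the phase loop, the list of per-phase loaders, and the model handles are assumptions, not part of this example) could look like:

import numpy as np

acc_history = []
fast_fc = None
for phase_idx, evalloader in enumerate(phase_evalloaders):  # assumed list of per-phase loaders
    accs, fast_fc = compute_accuracy(tg_model, free_model, tg_feature_model,
                                     class_means, X_protoset_cumuls,
                                     Y_protoset_cumuls, evalloader, order_list,
                                     is_start_iteration=(phase_idx == 0),
                                     fast_fc=fast_fc)
    acc_history.append(accs)

acc_history = np.array(acc_history)  # shape: (num_phases, 4)
print('Average accuracy (CNN, iCaRL, NCM, MAML):', acc_history.mean(axis=0))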