def compute_features(tg_model, free_model, tg_feature_model, is_start_iteration,
                     evalloader, num_samples, num_features, device=None):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tg_feature_model.eval()
    tg_model.eval()
    if free_model is not None:
        free_model.eval()
    features = np.zeros([num_samples, num_features])
    start_idx = 0
    with torch.no_grad():
        for inputs, targets in evalloader:
            inputs = inputs.to(device)
            if is_start_iteration:
                the_feature = tg_feature_model(inputs)
            else:
                the_feature = process_inputs_fp(tg_model, free_model, inputs,
                                                feature_mode=True)
            # Move the batch features to the CPU before writing them into the
            # numpy buffer (np.squeeze on a CUDA tensor would fail)
            features[start_idx:start_idx + inputs.shape[0], :] = np.squeeze(
                the_feature.cpu())
            start_idx = start_idx + inputs.shape[0]
    assert start_idx == num_samples
    return features
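# Illustration (not part of the original code): a minimal, self-contained sketch of
# the feature-extraction pattern used by compute_features above. The toy backbone,
# loader, and sizes are hypothetical stand-ins; only the buffer-filling convention
# mirrors the function.
def _demo_compute_features():
    import numpy as np
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    torch.manual_seed(0)
    feature_model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 64))
    data = TensorDataset(torch.randn(20, 3, 32, 32),
                         torch.zeros(20, dtype=torch.long))
    loader = DataLoader(data, batch_size=8, shuffle=False)

    features = np.zeros([20, 64])
    start_idx = 0
    feature_model.eval()
    with torch.no_grad():
        for inputs, _ in loader:
            out = feature_model(inputs)
            # Fill the preallocated buffer batch by batch, as above
            features[start_idx:start_idx + inputs.shape[0], :] = out.cpu().numpy()
            start_idx += inputs.shape[0]
    assert start_idx == 20
    return features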
def incremental_train_and_eval(the_args, epochs, fusion_vars, ref_fusion_vars,
                               b1_model, ref_model, b2_model, ref_b2_model,
                               tg_optimizer, tg_lr_scheduler, fusion_optimizer,
                               fusion_lr_scheduler, trainloader, testloader,
                               balancedloader, iteration, start_iteration,
                               X_protoset_cumuls, Y_protoset_cumuls, order_list,
                               lamda, dist, K, lw_mr, fix_bn=False,
                               weight_per_class=None, device=None):
    # Temperature and weighting factor for the knowledge-distillation loss
    T = 2.0
    beta = 0.25
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # The reference (old) model is frozen and only used as a teacher
    ref_model.eval()
    num_old_classes = ref_model.fc.out_features
    if iteration > start_iteration + 1:
        ref_b2_model.eval()
    for epoch in range(epochs):
        b1_model.train()
        b2_model.train()
        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        correct = 0
        total = 0
        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()
        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            tg_optimizer.zero_grad()
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
            # Loss 1: temperature-scaled knowledge-distillation loss on the
            # old-class logits
            loss1 = nn.KLDivLoss()(
                F.log_softmax(outputs[:, :num_old_classes] / T, dim=1),
                F.softmax(ref_outputs.detach() / T, dim=1)) * T * T * beta * num_old_classes
            # Loss 2: standard classification loss
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss = loss1 + loss2
            loss.backward()
            tg_optimizer.step()
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        print('Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, '
              'train loss: {:.4f} accuracy: {:.4f}'.format(
                  len(trainloader), train_loss1 / (batch_idx + 1),
                  train_loss2 / (batch_idx + 1), train_loss / (batch_idx + 1),
                  100. * correct / total))
        # Update the aggregation (fusion) weights on the balanced loader
        b1_model.eval()
        b2_model.eval()
        for batch_idx, (inputs, targets) in enumerate(balancedloader):
            fusion_optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss.backward()
            fusion_optimizer.step()
        # Evaluate on the test set for this epoch
        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))
    return b1_model, b2_model
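# Illustration (not part of the original code): the temperature-scaled distillation
# term (loss1) above, shown in isolation on random logits. T and beta match the
# defaults in the function; all tensors here are synthetic.
def _demo_distillation_loss():
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    torch.manual_seed(0)
    T, beta, num_old_classes = 2.0, 0.25, 5
    outputs = torch.randn(8, 10)                   # current model: 5 old + 5 new classes
    ref_outputs = torch.randn(8, num_old_classes)  # teacher covers old classes only
    # KLDivLoss with the (legacy) default reduction averages over all elements,
    # hence the multiplication by num_old_classes; T*T compensates the 1/T^2
    # gradient scaling introduced by the softened softmax.
    loss1 = nn.KLDivLoss()(
        F.log_softmax(outputs[:, :num_old_classes] / T, dim=1),
        F.softmax(ref_outputs.detach() / T, dim=1)) * T * T * beta * num_old_classes
    return loss1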
def incremental_train_and_eval(the_args, epochs, fusion_vars, ref_fusion_vars,
                               b1_model, ref_model, b2_model, ref_b2_model,
                               tg_optimizer, tg_lr_scheduler, fusion_optimizer,
                               fusion_lr_scheduler, trainloader, testloader,
                               iteration, start_iteration, X_protoset_cumuls,
                               Y_protoset_cumuls, order_list, lamda, dist, K,
                               lw_mr, balancedloader, T=None, beta=None,
                               fix_bn=False, weight_per_class=None, device=None):
    # Setting up the CUDA device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Set the 1st-branch reference model to evaluation mode
    ref_model.eval()
    # Get the number of old classes
    num_old_classes = ref_model.fc.out_features
    # If a 2nd-branch reference model exists (from the 2nd incremental phase
    # on), set it to evaluation mode
    if iteration > start_iteration + 1:
        ref_b2_model.eval()
    for epoch in range(epochs):
        # Start training for the current phase; set the two branch models to
        # training mode
        b1_model.train()
        b2_model.train()
        # Fix the batch-norm parameters according to the config
        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        # Set all the losses to zero
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        # Set the counters to zero
        correct = 0
        total = 0
        # Learning rate decay
        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()
        # Print the information
        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            # Get a batch of training samples and transfer them to the device
            inputs, targets = inputs.to(device), targets.to(device)
            # Clear the gradients of the parameters for the tg_optimizer
            tg_optimizer.zero_grad()
            # Forward the samples through the deep networks
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
            # Loss 1: knowledge-distillation loss on the old-class logits
            loss1 = nn.KLDivLoss()(
                F.log_softmax(outputs[:, :num_old_classes] / T, dim=1),
                F.softmax(ref_outputs.detach() / T, dim=1)) * T * T * beta * num_old_classes
            # Loss 2: classification loss
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            # Sum up all losses
            loss = loss1 + loss2
            # Backward and update the parameters
            loss.backward()
            tg_optimizer.step()
            # Record the losses and the number of samples to compute the accuracy
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        # Print the training losses and accuracies
        print('Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, '
              'train loss: {:.4f} accuracy: {:.4f}'.format(
                  len(trainloader), train_loss1 / (batch_idx + 1),
                  train_loss2 / (batch_idx + 1), train_loss / (batch_idx + 1),
                  100. * correct / total))
        # Update the aggregation weights on the balanced loader
        b1_model.eval()
        b2_model.eval()
        for batch_idx, (inputs, targets) in enumerate(balancedloader):
            fusion_optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss.backward()
            fusion_optimizer.step()
        # Run the test for this epoch
        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))
    return b1_model, b2_model
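# Illustration (not part of the original code): the balanced-loader step above only
# updates the aggregation weights (fusion_vars), not the backbone parameters. The
# real aggregation is defined by process_inputs_fp (not shown here); this sketch
# only assumes a convex combination of two branch outputs, as the
# "Fusion vars: x, 1-x" prints elsewhere in this file suggest.
def _demo_fusion_update():
    import torch
    torch.manual_seed(0)
    fusion_var = torch.nn.Parameter(torch.tensor(0.5))
    optimizer = torch.optim.SGD([fusion_var], lr=0.1)
    feat_a, feat_b = torch.randn(4, 3), torch.randn(4, 3)
    target = torch.randn(4, 3)
    optimizer.zero_grad()
    # Convex combination of the two branch outputs
    fused = fusion_var * feat_a + (1.0 - fusion_var) * feat_b
    loss = torch.nn.functional.mse_loss(fused, target)
    loss.backward()
    optimizer.step()  # only fusion_var moves; the branch outputs carry no parameters
    return float(fusion_var)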
def incremental_train_and_eval(the_args, epochs, fusion_vars, ref_fusion_vars,
                               b1_model, ref_model, b2_model, ref_b2_model,
                               tg_optimizer, tg_lr_scheduler, fusion_optimizer,
                               fusion_lr_scheduler, trainloader, testloader,
                               iteration, start_iteration, X_protoset_cumuls,
                               Y_protoset_cumuls, order_list, lamda, dist, K,
                               lw_mr, fix_bn=False, weight_per_class=None,
                               device=None):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    ref_model.eval()
    num_old_classes = ref_model.fc.out_features
    # Register forward hooks to capture features and pre-scale scores
    handle_ref_features = ref_model.fc.register_forward_hook(get_ref_features)
    handle_cur_features = b1_model.fc.register_forward_hook(get_cur_features)
    handle_old_scores_bs = b1_model.fc.fc1.register_forward_hook(
        get_old_scores_before_scale)
    handle_new_scores_bs = b1_model.fc.fc2.register_forward_hook(
        get_new_scores_before_scale)
    if iteration > start_iteration + 1:
        ref_b2_model.eval()
    for epoch in range(epochs):
        b1_model.train()
        b2_model.train()
        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        train_loss3 = 0
        correct = 0
        total = 0
        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()
        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            inputs, targets = inputs.to(device), targets.to(device)
            tg_optimizer.zero_grad()
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            # Loss 1: feature-level distillation loss (cosine embedding)
            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * lamda
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features_new.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * lamda
            # Loss 2: classification loss
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            # Loss 3: margin ranking loss on the pre-scale scores
            outputs_bs = torch.cat((old_scores, new_scores), dim=1)
            assert outputs_bs.size() == outputs.size()
            gt_index = torch.zeros(outputs_bs.size()).to(device)
            gt_index = gt_index.scatter(1, targets.view(-1, 1), 1).ge(0.5)
            gt_scores = outputs_bs.masked_select(gt_index)
            max_novel_scores = outputs_bs[:, num_old_classes:].topk(K, dim=1)[0]
            hard_index = targets.lt(num_old_classes)
            hard_num = torch.nonzero(hard_index).size(0)
            if hard_num > 0:
                gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, K)
                max_novel_scores = max_novel_scores[hard_index]
                assert gt_scores.size() == max_novel_scores.size()
                assert gt_scores.size(0) == hard_num
                loss3 = nn.MarginRankingLoss(margin=dist)(
                    gt_scores.view(-1, 1), max_novel_scores.view(-1, 1),
                    torch.ones(hard_num * K).to(device)) * lw_mr
            else:
                loss3 = torch.zeros(1).to(device)
            loss = loss1 + loss2 + loss3
            loss.backward()
            tg_optimizer.step()
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        print('Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, '
              'train loss3: {:.4f}, train loss: {:.4f} accuracy: {:.4f}'.format(
                  len(trainloader), train_loss1 / (batch_idx + 1),
                  train_loss2 / (batch_idx + 1), train_loss3 / (batch_idx + 1),
                  train_loss / (batch_idx + 1), 100. * correct / total))
        # Update the aggregation weights (note: this variant iterates over
        # testloader for this step, unlike the balanced-loader variants)
        b1_model.eval()
        b2_model.eval()
        for batch_idx, (inputs, targets) in enumerate(testloader):
            fusion_optimizer.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            loss.backward()
            fusion_optimizer.step()
        print('Fusion index: mtl1, free1, mtl2, free2, mtl3, free3')
        print('Fusion vars: {:.2f}, {:.2f}, {:.2f}, {:.2f}, {:.2f}, {:.2f}'.format(
            float(fusion_vars[0]), 1.0 - float(fusion_vars[0]),
            float(fusion_vars[1]), 1.0 - float(fusion_vars[1]),
            float(fusion_vars[2]), 1.0 - float(fusion_vars[2])))
        print('Ref fusion vars: {:.2f}, {:.2f}, {:.2f}, {:.2f}, {:.2f}, {:.2f}'.format(
            float(ref_fusion_vars[0]), 1.0 - float(ref_fusion_vars[0]),
            float(ref_fusion_vars[1]), 1.0 - float(ref_fusion_vars[1]),
            float(ref_fusion_vars[2]), 1.0 - float(ref_fusion_vars[2])))
        # Evaluate on the test set for this epoch
        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))
    print("Removing register forward hook")
    handle_ref_features.remove()
    handle_cur_features.remove()
    handle_old_scores_bs.remove()
    handle_new_scores_bs.remove()
    return b1_model, b2_model
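# Illustration (not part of the original code): the margin ranking term (loss3)
# above in isolation. For each old-class ("hard") sample, the ground-truth score is
# pushed above each of the top-K novel-class scores by at least `dist`. Shapes are
# flattened here so inputs and target align exactly; all values are synthetic.
def _demo_margin_ranking():
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    K, dist, lw_mr = 2, 0.5, 1.0
    hard_num = 3
    gt_scores = torch.randn(hard_num, 1).repeat(1, K)  # one GT score per sample, repeated K times
    max_novel_scores = torch.randn(hard_num, K)        # top-K novel-class scores per sample
    # target = +1 means "first input should rank higher than the second"
    loss3 = nn.MarginRankingLoss(margin=dist)(
        gt_scores.reshape(-1), max_novel_scores.reshape(-1),
        torch.ones(hard_num * K)) * lw_mr
    return loss3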
def train(self):
    self.train_writer = SummaryWriter(logdir=self.save_path)
    dictionary_size = self.dictionary_size
    top1_acc_list_cumul = np.zeros(
        (int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    top1_acc_list_ori = np.zeros(
        (int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    X_train_total = np.array(self.trainset.train_data)
    Y_train_total = np.array(self.trainset.train_labels)
    X_valid_total = np.array(self.testset.test_data)
    Y_valid_total = np.array(self.testset.test_labels)
    np.random.seed(1993)
    for iteration_total in range(self.args.nb_runs):
        # Load or generate the class order for this run
        order_name = osp.join(
            self.save_path,
            "seed_{}_{}_order_run_{}.pkl".format(1993, self.args.dataset,
                                                 iteration_total))
        print("Order name:{}".format(order_name))
        if osp.exists(order_name):
            print("Loading orders")
            order = utils.misc.unpickle(order_name)
        else:
            print("Generating orders")
            order = np.arange(self.args.num_classes)
            np.random.shuffle(order)
            utils.misc.savepickle(order, order_name)
        order_list = list(order)
        print(order_list)
        np.random.seed(self.args.random_seed)
        X_valid_cumuls = []
        X_protoset_cumuls = []
        X_train_cumuls = []
        Y_valid_cumuls = []
        Y_protoset_cumuls = []
        Y_train_cumuls = []
        alpha_dr_herding = np.zeros(
            (int(self.args.num_classes / self.args.nb_cl), dictionary_size,
             self.args.nb_cl), np.float32)
        # Group the training images by class, in the shuffled order
        prototypes = np.zeros(
            (self.args.num_classes, dictionary_size, X_train_total.shape[1],
             X_train_total.shape[2], X_train_total.shape[3]))
        for orde in range(self.args.num_classes):
            prototypes[orde, :, :, :, :] = X_train_total[np.where(
                Y_train_total == order[orde])]
        start_iter = int(self.args.nb_cl_fg / self.args.nb_cl) - 1
        for iteration in range(start_iter,
                               int(self.args.num_classes / self.args.nb_cl)):
            if iteration == start_iter:
                # First phase: train a plain network on the first group of classes
                last_iter = 0
                tg_model = self.network(num_classes=self.args.nb_cl_fg)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("Out_features:", out_features)
                ref_model = None
                free_model = None
                ref_free_model = None
            elif iteration == start_iter + 1:
                # Second phase: switch to the MTL network and split the classifier
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                print("Fusion Mode: " + self.args.fusion_mode)
                tg_model = self.network_mtl(num_classes=self.args.nb_cl_fg)
                ref_dict = ref_model.state_dict()
                tg_dict = tg_model.state_dict()
                tg_dict.update(ref_dict)
                tg_model.load_state_dict(tg_dict)
                tg_model.to(self.device)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("Out_features:", out_features)
                new_fc = modified_linear.SplitCosineLinear(
                    in_features, out_features, self.args.nb_cl)
                new_fc.fc1.weight.data = tg_model.fc.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = out_features * 1.0 / self.args.nb_cl
            else:
                # Later phases: grow the split classifier by nb_cl new classes
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                in_features = tg_model.fc.in_features
                out_features1 = tg_model.fc.fc1.out_features
                out_features2 = tg_model.fc.fc2.out_features
                print("Out_features:", out_features1 + out_features2)
                new_fc = modified_linear.SplitCosineLinear(
                    in_features, out_features1 + out_features2, self.args.nb_cl)
                new_fc.fc1.weight.data[:out_features1] = tg_model.fc.fc1.weight.data
                new_fc.fc1.weight.data[out_features1:] = tg_model.fc.fc2.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = (out_features1 + out_features2) * 1.0 / self.args.nb_cl
            if iteration > start_iter:
                cur_lamda = self.args.lamda * math.sqrt(lamda_mult)
            else:
                cur_lamda = self.args.lamda
            # Select the data for the classes arriving in this phase
            actual_cl = order[range(last_iter * self.args.nb_cl,
                                    (iteration + 1) * self.args.nb_cl)]
            indices_train_10 = np.array([
                i in order[range(last_iter * self.args.nb_cl,
                                 (iteration + 1) * self.args.nb_cl)]
                for i in Y_train_total
            ])
            indices_test_10 = np.array([
                i in order[range(last_iter * self.args.nb_cl,
                                 (iteration + 1) * self.args.nb_cl)]
                for i in Y_valid_total
            ])
            X_train = X_train_total[indices_train_10]
            X_valid = X_valid_total[indices_test_10]
            X_valid_cumuls.append(X_valid)
            X_train_cumuls.append(X_train)
            X_valid_cumul = np.concatenate(X_valid_cumuls)
            X_train_cumul = np.concatenate(X_train_cumuls)
            Y_train = Y_train_total[indices_train_10]
            Y_valid = Y_valid_total[indices_test_10]
            Y_valid_cumuls.append(Y_valid)
            Y_train_cumuls.append(Y_train)
            Y_valid_cumul = np.concatenate(Y_valid_cumuls)
            Y_train_cumul = np.concatenate(Y_train_cumuls)
            if iteration == start_iter:
                X_valid_ori = X_valid
                Y_valid_ori = Y_valid
            else:
                # Mix the new data with the stored exemplars (protoset)
                X_protoset = np.concatenate(X_protoset_cumuls)
                Y_protoset = np.concatenate(Y_protoset_cumuls)
                if self.args.rs_ratio > 0:
                    # Oversample the exemplars to reach the target ratio
                    scale_factor = (len(X_train) * self.args.rs_ratio) / (
                        len(X_protoset) * (1 - self.args.rs_ratio))
                    rs_sample_weights = np.concatenate(
                        (np.ones(len(X_train)),
                         np.ones(len(X_protoset)) * scale_factor))
                    rs_num_samples = int(len(X_train) / (1 - self.args.rs_ratio))
                    print("X_train:{}, X_protoset:{}, rs_num_samples:{}".format(
                        len(X_train), len(X_protoset), rs_num_samples))
                X_train = np.concatenate((X_train, X_protoset), axis=0)
                Y_train = np.concatenate((Y_train, Y_protoset))
            print('Batch of classes number {0} arrives'.format(iteration + 1))
            map_Y_train = np.array([order_list.index(i) for i in Y_train])
            map_Y_valid_cumul = np.array(
                [order_list.index(i) for i in Y_valid_cumul])
            is_start_iteration = (iteration == start_iter)
            if iteration > start_iter:
                # Imprint the new-class weights: mean of normalized features,
                # rescaled to the average norm of the old-class embeddings
                old_embedding_norm = tg_model.fc.fc1.weight.data.norm(
                    dim=1, keepdim=True)
                average_old_embedding_norm = torch.mean(
                    old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor)
                tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
                num_features = tg_model.fc.in_features
                novel_embedding = torch.zeros((self.args.nb_cl, num_features))
                for cls_idx in range(iteration * self.args.nb_cl,
                                     (iteration + 1) * self.args.nb_cl):
                    cls_indices = np.array([i == cls_idx for i in map_Y_train])
                    assert len(np.where(cls_indices == 1)[0]) == dictionary_size
                    self.evalset.test_data = X_train[cls_indices].astype('uint8')
                    self.evalset.test_labels = np.zeros(
                        self.evalset.test_data.shape[0])
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    cls_features = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    norm_features = F.normalize(
                        torch.from_numpy(cls_features), p=2, dim=1)
                    cls_embedding = torch.mean(norm_features, dim=0)
                    novel_embedding[cls_idx - iteration * self.args.nb_cl] = (
                        F.normalize(cls_embedding, p=2, dim=0) *
                        average_old_embedding_norm)
                tg_model.to(self.device)
                tg_model.fc.fc2.weight.data = novel_embedding.to(self.device)
            self.trainset.train_data = X_train.astype('uint8')
            self.trainset.train_labels = map_Y_train
            if iteration > start_iter and self.args.rs_ratio > 0 and scale_factor > 1:
                print("Weights from sampling:", rs_sample_weights)
                index1 = np.where(rs_sample_weights > 1)[0]
                index2 = np.where(map_Y_train < iteration * self.args.nb_cl)[0]
                assert (index1 == index2).all()
                train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
                    rs_sample_weights, rs_num_samples)
                trainloader = torch.utils.data.DataLoader(
                    self.trainset,
                    batch_size=self.args.train_batch_size,
                    shuffle=False,
                    sampler=train_sampler,
                    num_workers=self.args.num_workers)
            else:
                trainloader = torch.utils.data.DataLoader(
                    self.trainset,
                    batch_size=self.args.train_batch_size,
                    shuffle=True,
                    num_workers=self.args.num_workers)
            self.testset.test_data = X_valid_cumul.astype('uint8')
            self.testset.test_labels = map_Y_valid_cumul
            testloader = torch.utils.data.DataLoader(
                self.testset,
                batch_size=self.args.test_batch_size,
                shuffle=False,
                num_workers=self.args.num_workers)
            print('Min and max of train labels: {}, {}'.format(
                min(map_Y_train), max(map_Y_train)))
            print('Min and max of valid labels: {}, {}'.format(
                min(map_Y_valid_cumul), max(map_Y_valid_cumul)))
            ckp_name = osp.join(
                self.save_path,
                'run_{}_iteration_{}_model.pth'.format(iteration_total,
                                                       iteration))
            ckp_name_free = osp.join(
                self.save_path,
                'run_{}_iteration_{}_free_model.pth'.format(iteration_total,
                                                            iteration))
            print('Checkpoint name:', ckp_name)
            if iteration == start_iter and self.args.resume_fg:
                print("Loading first group models from checkpoint")
                tg_model = torch.load(self.args.ckpt_dir_fg)
            elif self.args.resume and os.path.exists(ckp_name):
                print("Loading models from checkpoint")
                tg_model = torch.load(ckp_name)
            else:
                if iteration > start_iter:
                    ref_model = ref_model.to(self.device)
                    # Freeze the old-class classifier weights (lr=0) and train
                    # the remaining parameters
                    ignored_params = list(map(id, tg_model.fc.fc1.parameters()))
                    base_params = filter(lambda p: id(p) not in ignored_params,
                                         tg_model.parameters())
                    base_params = filter(lambda p: p.requires_grad, base_params)
                    tg_params_new = [{
                        'params': base_params,
                        'lr': self.args.base_lr2,
                        'weight_decay': self.args.custom_weight_decay
                    }, {
                        'params': tg_model.fc.fc1.parameters(),
                        'lr': 0,
                        'weight_decay': 0
                    }]
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(
                        tg_params_new,
                        lr=self.args.base_lr2,
                        momentum=self.args.custom_momentum,
                        weight_decay=self.args.custom_weight_decay)
                else:
                    tg_params = tg_model.parameters()
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(
                        tg_params,
                        lr=self.args.base_lr1,
                        momentum=self.args.custom_momentum,
                        weight_decay=self.args.custom_weight_decay)
                if iteration > start_iter:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(
                        tg_optimizer,
                        milestones=self.lr_strat,
                        gamma=self.args.lr_factor)
                else:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(
                        tg_optimizer,
                        milestones=self.lr_strat_first_phase,
                        gamma=self.args.lr_factor)
                print("Incremental train")
                if iteration > start_iter:
                    tg_model = incremental_train_and_eval(
                        self.args.epochs, tg_model, ref_model, free_model,
                        ref_free_model, tg_optimizer, tg_lr_scheduler,
                        trainloader, testloader, iteration, start_iter,
                        cur_lamda, self.args.dist, self.args.K, self.args.lw_mr)
                else:
                    tg_model = incremental_train_and_eval(
                        self.args.epochs, tg_model, ref_model, tg_optimizer,
                        tg_lr_scheduler, trainloader, testloader, iteration,
                        start_iter, cur_lamda, self.args.dist, self.args.K,
                        self.args.lw_mr)
                torch.save(tg_model, ckp_name)
            # Number of exemplars kept per class (fixed total budget or fixed
            # per-class count)
            if self.args.fix_budget:
                nb_protos_cl = int(
                    np.ceil(self.args.nb_protos * 100. / self.args.nb_cl /
                            (iteration + 1)))
            else:
                nb_protos_cl = self.args.nb_protos
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            num_features = tg_model.fc.in_features
            # Herding: greedily pick exemplars whose mean tracks the class mean
            for iter_dico in range(last_iter * self.args.nb_cl,
                                   (iteration + 1) * self.args.nb_cl):
                self.evalset.test_data = prototypes[iter_dico].astype('uint8')
                self.evalset.test_labels = np.zeros(
                    self.evalset.test_data.shape[0])
                evalloader = torch.utils.data.DataLoader(
                    self.evalset,
                    batch_size=self.args.eval_batch_size,
                    shuffle=False,
                    num_workers=self.args.num_workers)
                num_samples = self.evalset.test_data.shape[0]
                mapped_prototypes = compute_features(
                    tg_model, free_model, tg_feature_model, is_start_iteration,
                    evalloader, num_samples, num_features)
                D = mapped_prototypes.T
                D = D / np.linalg.norm(D, axis=0)
                mu = np.mean(D, axis=1)
                index1 = int(iter_dico / self.args.nb_cl)
                index2 = iter_dico % self.args.nb_cl
                alpha_dr_herding[index1, :, index2] = (
                    alpha_dr_herding[index1, :, index2] * 0)
                w_t = mu
                iter_herding = 0
                iter_herding_eff = 0
                while not (np.sum(alpha_dr_herding[index1, :, index2] != 0) ==
                           min(nb_protos_cl, 500)) and iter_herding_eff < 1000:
                    tmp_t = np.dot(w_t, D)
                    ind_max = np.argmax(tmp_t)
                    iter_herding_eff += 1
                    if alpha_dr_herding[index1, ind_max, index2] == 0:
                        alpha_dr_herding[index1, ind_max, index2] = 1 + iter_herding
                        iter_herding += 1
                    w_t = w_t + mu - D[:, ind_max]
            X_protoset_cumuls = []
            Y_protoset_cumuls = []
            # Compute the class means (index 0: herding exemplars; index 1:
            # all samples), averaging original and flipped features
            class_means = np.zeros((64, 100, 2))
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    current_cl = order[range(iteration2 * self.args.nb_cl,
                                             (iteration2 + 1) * self.args.nb_cl)]
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl + iter_dico].astype('uint8')
                    self.evalset.test_labels = np.zeros(
                        self.evalset.test_data.shape[0])  # zero labels
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    mapped_prototypes = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    # Features of the horizontally flipped images
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl +
                        iter_dico][:, :, :, ::-1].astype('uint8')
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    mapped_prototypes2 = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D2 = mapped_prototypes2.T
                    D2 = D2 / np.linalg.norm(D2, axis=0)
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    X_protoset_cumuls.append(
                        prototypes[iteration2 * self.args.nb_cl + iter_dico,
                                   np.where(alph == 1)[0]])
                    Y_protoset_cumuls.append(
                        order[iteration2 * self.args.nb_cl + iter_dico] *
                        np.ones(len(np.where(alph == 1)[0])))
                    alph = alph / np.sum(alph)
                    class_means[:, current_cl[iter_dico], 0] = (
                        np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 0])
                    alph = np.ones(dictionary_size) / dictionary_size
                    class_means[:, current_cl[iter_dico], 1] = (
                        np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 1])
            current_means = class_means[:, order[range(0, (iteration + 1) *
                                                       self.args.nb_cl)]]
            X_protoset_array_old = np.array(X_protoset_cumuls)
            # Mnemonics training: optimize the exemplar images themselves
            self.T = self.args.mnemonics_steps * self.args.mnemonics_epochs
            self.img_size = 32
            self.mnemonics_lrs = self.args.mnemonics_lr
            num_classes_incremental = self.args.nb_cl
            num_classes = self.args.nb_cl_fg
            nb_cl = self.args.nb_cl
            transform_proto = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.5071, 0.4866, 0.4409),
                                     (0.2009, 0.1984, 0.2023)),
            ])
            self.mnemonics_label = []
            if iteration == start_iter:
                the_X_protoset_array = np.array(X_protoset_cumuls).astype('uint8')
                the_Y_protoset_cumuls = np.array(Y_protoset_cumuls)
            else:
                the_X_protoset_array = np.array(
                    X_protoset_cumuls[-num_classes_incremental:]).astype('uint8')
                the_Y_protoset_cumuls = np.array(
                    Y_protoset_cumuls[-num_classes_incremental:])
            self.mnemonics_data = torch.zeros(the_X_protoset_array.shape[0],
                                              the_X_protoset_array.shape[1], 3,
                                              self.img_size, self.img_size)
            for idx1 in range(the_X_protoset_array.shape[0]):
                for idx2 in range(the_X_protoset_array.shape[1]):
                    the_img = the_X_protoset_array[idx1][idx2]
                    the_PIL_image = Image.fromarray(the_img)
                    the_PIL_image = transform_proto(the_PIL_image)
                    self.mnemonics_data[idx1][idx2] = the_PIL_image
                map_Y_label = self.map_labels(order_list,
                                              the_Y_protoset_cumuls[idx1])
                self.mnemonics_label.append(map_Y_label)
            self.mnemonics = nn.ParameterList()
            self.mnemonics.append(nn.Parameter(self.mnemonics_data))
            start_iteration = start_iter
            device = self.device
            self.mnemonics.to(device)
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            tg_feature_model.eval()
            tg_model.eval()
            if free_model is not None:
                free_model.eval()
            self.mnemonics_optimizer = optim.SGD(
                self.mnemonics,
                lr=self.args.mnemonics_outer_lr,
                momentum=0.9,
                weight_decay=5e-4)
            self.mnemonics_lr_scheduler = optim.lr_scheduler.StepLR(
                self.mnemonics_optimizer,
                step_size=self.args.mnemonics_decay_epochs,
                gamma=self.args.mnemonics_decay_factor)
            current_means_new = current_means[:, :, 0].T
            for epoch in range(self.args.mnemonics_total_epochs):
                train_loss = 0
                self.mnemonics_lr_scheduler.step()
                for batch_idx, (q_inputs, q_targets) in enumerate(trainloader):
                    q_inputs, q_targets = q_inputs.to(device), q_targets.to(device)
                    if iteration == start_iteration:
                        q_feature = tg_feature_model(q_inputs)
                    else:
                        q_feature = process_inputs_fp(tg_model, free_model,
                                                      q_inputs,
                                                      feature_mode=True)
                    self.mnemonics_optimizer.zero_grad()
                    total_tr_loss = 0
                    # Compute the per-class means of the mnemonics exemplars
                    if iteration == start_iteration:
                        mnemonics_outputs = tg_feature_model(self.mnemonics[0][0])
                    else:
                        mnemonics_outputs = process_inputs_fp(
                            tg_model, free_model, self.mnemonics[0][0],
                            feature_mode=True)
                    this_class_mean_mnemonics = torch.mean(mnemonics_outputs,
                                                           dim=0)
                    this_class_mean_mnemonics = torch.squeeze(
                        this_class_mean_mnemonics)
                    total_class_mean_mnemonics = this_class_mean_mnemonics.unsqueeze(
                        dim=0)
                    for mnemonics_idx in range(len(self.mnemonics[0]) - 1):
                        if iteration == start_iteration:
                            mnemonics_outputs = tg_feature_model(
                                self.mnemonics[0][mnemonics_idx + 1])
                        else:
                            mnemonics_outputs = process_inputs_fp(
                                tg_model, free_model,
                                self.mnemonics[0][mnemonics_idx + 1],
                                feature_mode=True)
                        this_class_mean_mnemonics = torch.mean(
                            mnemonics_outputs, dim=0)
                        this_class_mean_mnemonics = torch.squeeze(
                            this_class_mean_mnemonics)
                        total_class_mean_mnemonics = torch.cat(
                            (total_class_mean_mnemonics,
                             this_class_mean_mnemonics.unsqueeze(dim=0)), dim=0)
                    if iteration == start_iteration:
                        all_cls_means = total_class_mean_mnemonics
                    else:
                        all_cls_means = torch.tensor(
                            current_means_new).float().to(device)
                        all_cls_means[-nb_cl:] = total_class_mean_mnemonics
                    the_logits = F.linear(
                        F.normalize(torch.squeeze(q_feature), p=2, dim=1),
                        F.normalize(all_cls_means, p=2, dim=1))
                    loss = F.cross_entropy(the_logits, q_targets)
                    loss.backward()
                    # NOTE: an optimizer step appears to have been dropped from
                    # the flattened source here; without it the backward pass
                    # has no effect. Restored under that assumption.
                    self.mnemonics_optimizer.step()
                    train_loss += loss.item()
            X_protoset_cumuls = process_mnemonics(
                X_protoset_cumuls, Y_protoset_cumuls, self.mnemonics,
                self.mnemonics_label, order_list, self.args.nb_cl_fg,
                self.args.nb_cl, iteration, start_iter)
            # Write the updated mnemonics exemplars back into the prototype tensor
            X_protoset_array = np.array(X_protoset_cumuls)
            X_protoset_cumuls_idx = 0
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    this_X_protoset_array = X_protoset_array[X_protoset_cumuls_idx]
                    X_protoset_cumuls_idx += 1
                    this_X_protoset_array = this_X_protoset_array.astype(
                        np.float64)
                    prototypes[iteration2 * self.args.nb_cl + iter_dico,
                               np.where(alph == 1)[0]] = this_X_protoset_array
            # Recompute the class means with the updated exemplars
            class_means = np.zeros((64, 100, 2))
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    current_cl = order[range(iteration2 * self.args.nb_cl,
                                             (iteration2 + 1) * self.args.nb_cl)]
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl + iter_dico].astype('uint8')
                    self.evalset.test_labels = np.zeros(
                        self.evalset.test_data.shape[0])  # zero labels
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    mapped_prototypes = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    self.evalset.test_data = prototypes[
                        iteration2 * self.args.nb_cl +
                        iter_dico][:, :, :, ::-1].astype('uint8')
                    evalloader = torch.utils.data.DataLoader(
                        self.evalset,
                        batch_size=self.args.eval_batch_size,
                        shuffle=False,
                        num_workers=self.args.num_workers)
                    mapped_prototypes2 = compute_features(
                        tg_model, free_model, tg_feature_model,
                        is_start_iteration, evalloader, num_samples,
                        num_features)
                    D2 = mapped_prototypes2.T
                    D2 = D2 / np.linalg.norm(D2, axis=0)
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    alph = alph / np.sum(alph)
                    class_means[:, current_cl[iter_dico], 0] = (
                        np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 0])
                    alph = np.ones(dictionary_size) / dictionary_size
                    class_means[:, current_cl[iter_dico], 1] = (
                        np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(
                        class_means[:, current_cl[iter_dico], 1])
            torch.save(
                class_means,
                osp.join(self.save_path,
                         'run_{}_iteration_{}_class_means.pth'.format(
                             iteration_total, iteration)))
            current_means = class_means[:, order[range(0, (iteration + 1) *
                                                       self.args.nb_cl)]]
            is_start_iteration = (iteration == start_iter)
            # Evaluate on the original (first-phase) classes
            map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
            print('Computing accuracy on the original batch of classes')
            self.evalset.test_data = X_valid_ori.astype('uint8')
            self.evalset.test_labels = map_Y_valid_ori
            evalloader = torch.utils.data.DataLoader(
                self.evalset,
                batch_size=self.args.eval_batch_size,
                shuffle=False,
                num_workers=self.args.num_workers)
            ori_acc, fast_fc = compute_accuracy(
                tg_model, free_model, tg_feature_model, current_means,
                X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list,
                is_start_iteration=is_start_iteration,
                maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
            top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
            self.train_writer.add_scalar('ori_acc/LwF', float(ori_acc[0]),
                                         iteration)
            self.train_writer.add_scalar('ori_acc/iCaRL', float(ori_acc[1]),
                                         iteration)
            # Evaluate on all classes seen so far
            map_Y_valid_cumul = np.array(
                [order_list.index(i) for i in Y_valid_cumul])
            print('Computing cumulative accuracy')
            self.evalset.test_data = X_valid_cumul.astype('uint8')
            self.evalset.test_labels = map_Y_valid_cumul
            evalloader = torch.utils.data.DataLoader(
                self.evalset,
                batch_size=self.args.eval_batch_size,
                shuffle=False,
                num_workers=self.args.num_workers)
            cumul_acc, _ = compute_accuracy(
                tg_model, free_model, tg_feature_model, current_means,
                X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list,
                is_start_iteration=is_start_iteration, fast_fc=fast_fc,
                maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
            top1_acc_list_cumul[iteration, :, iteration_total] = np.array(
                cumul_acc).T
            self.train_writer.add_scalar('cumul_acc/LwF', float(cumul_acc[0]),
                                         iteration)
            self.train_writer.add_scalar('cumul_acc/iCaRL', float(cumul_acc[1]),
                                         iteration)
            # Save the accuracy lists (overwritten after each phase)
            torch.save(
                top1_acc_list_ori,
                osp.join(self.save_path,
                         'run_{}_top1_acc_list_ori.pth'.format(iteration_total)))
            torch.save(
                top1_acc_list_cumul,
                osp.join(self.save_path,
                         'run_{}_top1_acc_list_cumul.pth'.format(iteration_total)))
    # close() was missing its call parentheses in the source
    self.train_writer.close()
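# Illustration (not part of the original code): the herding selection used in
# train() above, reduced to its core. Given unit-norm feature columns D and their
# mean mu, it greedily picks exemplars so the running mean of the selection tracks
# mu. All inputs here are synthetic.
def _demo_herding(num_samples=10, feat_dim=4, num_exemplars=3):
    import numpy as np

    rng = np.random.default_rng(0)
    D = rng.normal(size=(feat_dim, num_samples))
    D = D / np.linalg.norm(D, axis=0)   # one unit-norm feature column per sample
    mu = np.mean(D, axis=1)
    herd_order = np.zeros(num_samples)  # rank at which each sample was picked (0 = unpicked)
    w_t = mu
    picks, steps = 0, 0
    while picks < num_exemplars and steps < 1000:
        # Pick the sample most aligned with the current residual direction
        ind_max = np.argmax(np.dot(w_t, D))
        steps += 1
        if herd_order[ind_max] == 0:
            picks += 1
            herd_order[ind_max] = picks
        w_t = w_t + mu - D[:, ind_max]
    return herd_order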
def incremental_train_and_eval(the_args, epochs, fusion_vars, ref_fusion_vars,
                               b1_model, ref_model, b2_model, ref_b2_model,
                               tg_optimizer, tg_lr_scheduler, fusion_optimizer,
                               fusion_lr_scheduler, trainloader, testloader,
                               iteration, start_iteration, X_protoset_cumuls,
                               Y_protoset_cumuls, order_list, the_lambda, dist,
                               K, lw_mr, balancedloader, fix_bn=False,
                               weight_per_class=None, device=None):
    # Setting up the CUDA device
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Set the 1st-branch reference model to evaluation mode
    ref_model.eval()
    # Get the number of old classes
    num_old_classes = ref_model.fc.out_features
    # Register forward hooks to capture features (and pre-scale scores) from
    # the current and reference models
    handle_ref_features = ref_model.fc.register_forward_hook(get_ref_features)
    handle_cur_features = b1_model.fc.register_forward_hook(get_cur_features)
    handle_old_scores_bs = b1_model.fc.fc1.register_forward_hook(
        get_old_scores_before_scale)
    handle_new_scores_bs = b1_model.fc.fc2.register_forward_hook(
        get_new_scores_before_scale)
    # If a 2nd-branch reference model exists (from the 2nd incremental phase
    # on), set it to evaluation mode
    if iteration > start_iteration + 1:
        ref_b2_model.eval()
    for epoch in range(epochs):
        # Start training for the current phase; set the two branch models to
        # training mode
        b1_model.train()
        b2_model.train()
        # Fix the batch-norm parameters according to the config
        if fix_bn:
            for m in b1_model.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eval()
        # Set all the losses to zero
        train_loss = 0
        train_loss1 = 0
        train_loss2 = 0
        train_loss3 = 0
        # Set the counters to zero
        correct = 0
        total = 0
        # Learning rate decay
        tg_lr_scheduler.step()
        fusion_lr_scheduler.step()
        # Print the information
        print('\nEpoch: %d, learning rate: ' % epoch, end='')
        print(tg_lr_scheduler.get_lr()[0])
        for batch_idx, (inputs, targets) in enumerate(trainloader):
            # Get a batch of training samples and transfer them to the device
            inputs, targets = inputs.to(device), targets.to(device)
            # Clear the gradients of the parameters for the tg_optimizer
            tg_optimizer.zero_grad()
            # Forward the samples through the deep networks
            outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                           b2_model, inputs)
            # Loss 1: feature-level distillation loss (cosine embedding)
            if iteration == start_iteration + 1:
                ref_outputs = ref_model(inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * the_lambda
            else:
                ref_outputs, ref_features_new = process_inputs_fp(
                    the_args, ref_fusion_vars, ref_model, ref_b2_model, inputs)
                loss1 = nn.CosineEmbeddingLoss()(
                    cur_features, ref_features_new.detach(),
                    torch.ones(inputs.shape[0]).to(device)) * the_lambda
            # Loss 2: classification loss
            loss2 = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
            # Loss 3: margin ranking loss
            outputs_bs = torch.cat((old_scores, new_scores), dim=1)
            assert outputs_bs.size() == outputs.size()
            gt_index = torch.zeros(outputs_bs.size()).to(device)
            gt_index = gt_index.scatter(1, targets.view(-1, 1), 1).ge(0.5)
            gt_scores = outputs_bs.masked_select(gt_index)
            max_novel_scores = outputs_bs[:, num_old_classes:].topk(K, dim=1)[0]
            hard_index = targets.lt(num_old_classes)
            hard_num = torch.nonzero(hard_index).size(0)
            if hard_num > 0:
                gt_scores = gt_scores[hard_index].view(-1, 1).repeat(1, K)
                max_novel_scores = max_novel_scores[hard_index]
                assert gt_scores.size() == max_novel_scores.size()
                assert gt_scores.size(0) == hard_num
                loss3 = nn.MarginRankingLoss(margin=dist)(
                    gt_scores.view(-1, 1), max_novel_scores.view(-1, 1),
                    torch.ones(hard_num * K).to(device)) * lw_mr
            else:
                loss3 = torch.zeros(1).to(device)
            # Sum up all losses
            loss = loss1 + loss2 + loss3
            # Backward and update the parameters
            loss.backward()
            tg_optimizer.step()
            # Record the losses and the number of samples to compute the accuracy
            train_loss += loss.item()
            train_loss1 += loss1.item()
            train_loss2 += loss2.item()
            train_loss3 += loss3.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
        # Print the training losses and accuracies
        print('Train set: {}, train loss1: {:.4f}, train loss2: {:.4f}, '
              'train loss3: {:.4f}, train loss: {:.4f} accuracy: {:.4f}'.format(
                  len(trainloader), train_loss1 / (batch_idx + 1),
                  train_loss2 / (batch_idx + 1), train_loss3 / (batch_idx + 1),
                  train_loss / (batch_idx + 1), 100. * correct / total))
        # Update the aggregation weights (first 500 balanced batches only)
        b1_model.eval()
        b2_model.eval()
        for batch_idx, (inputs, targets) in enumerate(balancedloader):
            if batch_idx <= 500:
                inputs, targets = inputs.to(device), targets.to(device)
                # NOTE: the flattened source has no fusion_optimizer.zero_grad()
                # here; it is restored on the assumption it was lost in
                # formatting, since the sibling variants include it.
                fusion_optimizer.zero_grad()
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                loss.backward()
                fusion_optimizer.step()
        # Run the test for this epoch
        b1_model.eval()
        b2_model.eval()
        test_loss = 0
        correct = 0
        total = 0
        with torch.no_grad():
            for batch_idx, (inputs, targets) in enumerate(testloader):
                inputs, targets = inputs.to(device), targets.to(device)
                outputs, _ = process_inputs_fp(the_args, fusion_vars, b1_model,
                                               b2_model, inputs)
                loss = nn.CrossEntropyLoss(weight_per_class)(outputs, targets)
                test_loss += loss.item()
                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
        print('Test set: {} test loss: {:.4f} accuracy: {:.4f}'.format(
            len(testloader), test_loss / (batch_idx + 1),
            100. * correct / total))
    print("Removing register forward hook")
    handle_ref_features.remove()
    handle_cur_features.remove()
    handle_old_scores_bs.remove()
    handle_new_scores_bs.remove()
    return b1_model, b2_model
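# Illustration (not part of the original code): the forward-hook mechanics the
# training functions above rely on. The real get_ref_features / get_cur_features
# callbacks (defined elsewhere in the repo) stash tensors in module-level globals
# such as cur_features and old_scores; this sketch shows the same pattern with a
# local dict instead.
def _demo_forward_hook():
    import torch
    import torch.nn as nn

    captured = {}

    def save_input(module, inputs, output):
        # For a classifier fc layer, inputs[0] is the penultimate feature
        captured['features'] = inputs[0].detach()

    model = nn.Sequential(nn.Flatten(), nn.Linear(8, 3))
    handle = model[1].register_forward_hook(save_input)
    model(torch.randn(2, 8))
    handle.remove()  # always remove hooks when done, as the training code does
    return captured['features'].shape  # torch.Size([2, 8])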
def compute_accuracy(the_args, fusion_vars, b1_model, b2_model, tg_feature_model,
                     class_means, X_protoset_cumuls, Y_protoset_cumuls,
                     evalloader, order_list, is_start_iteration=False,
                     fast_fc=None, scale=None, print_info=True, device=None,
                     cifar=True, imagenet=False, valdir=None, maml_lr=0.1,
                     maml_epoch=50):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    b1_model.eval()
    tg_feature_model.eval()
    if b2_model is not None:
        b2_model.eval()
    fast_fc = 0.0
    correct = 0
    correct_icarl = 0
    correct_icarl_cosine = 0
    correct_icarl_cosine2 = 0
    correct_ncm = 0
    correct_maml = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(evalloader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            if is_start_iteration:
                outputs = b1_model(inputs)
            else:
                outputs, outputs_feature = process_inputs_fp(
                    the_args, fusion_vars, b1_model, b2_model, inputs)
            outputs = F.softmax(outputs, dim=1)
            if scale is not None:
                assert scale.shape[0] == 1
                assert outputs.shape[1] == scale.shape[1]
                outputs = outputs / scale.repeat(outputs.shape[0], 1).type(
                    torch.FloatTensor).to(device)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            if is_start_iteration:
                # torch.squeeze (rather than np.squeeze) keeps this a tensor so
                # the .cpu() calls below also work on GPU
                outputs_feature = torch.squeeze(tg_feature_model(inputs))
            # iCaRL: nearest herding class mean under squared Euclidean distance
            sqd_icarl = cdist(class_means[:, :, 0].T, outputs_feature.cpu(),
                              'sqeuclidean')
            score_icarl = torch.from_numpy((-sqd_icarl).T).to(device)
            _, predicted_icarl = score_icarl.max(1)
            correct_icarl += predicted_icarl.eq(targets).sum().item()
            # iCaRL with cosine distance
            sqd_icarl_cosine = cdist(class_means[:, :, 0].T,
                                     outputs_feature.cpu(), 'cosine')
            score_icarl_cosine = torch.from_numpy(
                (-sqd_icarl_cosine).T).to(device)
            _, predicted_icarl_cosine = score_icarl_cosine.max(1)
            correct_icarl_cosine += predicted_icarl_cosine.eq(
                targets).sum().item()
            # iCaRL with cosine similarity on L2-normalized features and means
            fast_weights = torch.from_numpy(
                np.float32(class_means[:, :, 0].T)).to(device)
            sqd_icarl_cosine2 = F.linear(
                F.normalize(torch.squeeze(outputs_feature), p=2, dim=1),
                F.normalize(fast_weights, p=2, dim=1))
            score_icarl_cosine2 = sqd_icarl_cosine2
            _, predicted_icarl_cosine2 = score_icarl_cosine2.max(1)
            correct_icarl_cosine2 += predicted_icarl_cosine2.eq(
                targets).sum().item()
            # NCM: nearest class mean computed over all training samples
            sqd_ncm = cdist(class_means[:, :, 1].T, outputs_feature.cpu(),
                            'sqeuclidean')
            score_ncm = torch.from_numpy((-sqd_ncm).T).to(device)
            _, predicted_ncm = score_ncm.max(1)
            correct_ncm += predicted_ncm.eq(targets).sum().item()
    if print_info:
        print("  Top 1 accuracy CNN      :\t\t{:.2f} %".format(
            100. * correct / total))
        print("  Top 1 accuracy iCaRL    :\t\t{:.2f} %".format(
            100. * correct_icarl / total))
        print("  Top 1 accuracy iCaRL-UB :\t\t{:.2f} %".format(
            100. * correct_ncm / total))
        print("  The above results are the accuracy for the current phase.")
        print("  For the average accuracy, you need to record the results for "
              "all phases and calculate the average value.")
    cnn_acc = 100. * correct / total
    icarl_acc = 100. * correct_icarl / total
    ncm_acc = 100. * correct_ncm / total
    maml_acc = 0.0
    return [cnn_acc, icarl_acc, ncm_acc, maml_acc], fast_fc
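# Illustration (not part of the original code): nearest-class-mean scoring as used
# above, where negative squared Euclidean distances to the class means act as
# logits. The class_means and features here are synthetic stand-ins.
def _demo_ncm_scoring():
    import numpy as np
    import torch
    from scipy.spatial.distance import cdist

    rng = np.random.default_rng(0)
    class_means = rng.normal(size=(64, 10))  # feat_dim x num_classes
    features = rng.normal(size=(5, 64))      # batch of 5 feature vectors
    sqd = cdist(class_means.T, features, 'sqeuclidean')  # num_classes x batch
    scores = torch.from_numpy((-sqd).T)      # batch x num_classes
    _, predicted = scores.max(1)
    return predicted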
def compute_accuracy(tg_model, free_model, tg_feature_model, class_means,
                     X_protoset_cumuls, Y_protoset_cumuls, evalloader,
                     order_list, is_start_iteration=False, fast_fc=None,
                     scale=None, print_info=True, device=None, maml_lr=0.1,
                     maml_epoch=50):
    if device is None:
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tg_feature_model.eval()
    tg_model.eval()
    if free_model is not None:
        free_model.eval()
    if fast_fc is None:
        # Fine-tune a cosine classifier (fast_fc) on the protoset, starting
        # from the herding class means
        transform_proto = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5071, 0.4866, 0.4409),
                                 (0.2009, 0.1984, 0.2023)),
        ])
        protoset = torchvision.datasets.CIFAR100(root='./data', train=False,
                                                 download=False,
                                                 transform=transform_proto)
        X_protoset_array = np.array(X_protoset_cumuls).astype('uint8')
        protoset.test_data = X_protoset_array.reshape(
            -1, X_protoset_array.shape[2], X_protoset_array.shape[3],
            X_protoset_array.shape[4])
        Y_protoset_cumuls = np.array(Y_protoset_cumuls).reshape(-1)
        map_Y_protoset_cumuls = map_labels(order_list, Y_protoset_cumuls)
        protoset.test_labels = map_Y_protoset_cumuls
        protoloader = torch.utils.data.DataLoader(protoset, batch_size=128,
                                                  shuffle=True, num_workers=2)
        fast_fc = torch.from_numpy(np.float32(class_means[:, :, 0].T)).to(device)
        fast_fc.requires_grad = True
        epoch_num = maml_epoch
        for epoch_idx in range(epoch_num):
            for the_inputs, the_targets in protoloader:
                the_inputs, the_targets = the_inputs.to(device), the_targets.to(
                    device)
                the_features = tg_feature_model(the_inputs)
                the_logits = F.linear(
                    F.normalize(torch.squeeze(the_features), p=2, dim=1),
                    F.normalize(fast_fc, p=2, dim=1))
                the_loss = F.cross_entropy(the_logits, the_targets)
                the_grad = torch.autograd.grad(the_loss, fast_fc)
                # Manual (functional) SGD step on the classifier weights
                fast_fc = fast_fc - maml_lr * the_grad[0]
    correct = 0
    correct_icarl = 0
    correct_ncm = 0
    correct_maml = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(evalloader):
            inputs, targets = inputs.to(device), targets.to(device)
            total += targets.size(0)
            if is_start_iteration:
                outputs = tg_model(inputs)
            else:
                outputs, outputs_feature = process_inputs_fp(
                    tg_model, free_model, inputs)
            outputs = F.softmax(outputs, dim=1)
            if scale is not None:
                assert scale.shape[0] == 1
                assert outputs.shape[1] == scale.shape[1]
                outputs = outputs / scale.repeat(outputs.shape[0], 1).type(
                    torch.FloatTensor).to(device)
            _, predicted = outputs.max(1)
            correct += predicted.eq(targets).sum().item()
            if is_start_iteration:
                # torch.squeeze (rather than np.squeeze) keeps this a tensor;
                # the .cpu() calls below make cdist safe on GPU as well
                outputs_feature = torch.squeeze(tg_feature_model(inputs))
            sqd_icarl = cdist(class_means[:, :, 0].T, outputs_feature.cpu(),
                              'sqeuclidean')
            score_icarl = torch.from_numpy((-sqd_icarl).T).to(device)
            _, predicted_icarl = score_icarl.max(1)
            correct_icarl += predicted_icarl.eq(targets).sum().item()
            sqd_ncm = cdist(class_means[:, :, 1].T, outputs_feature.cpu(),
                            'sqeuclidean')
            score_ncm = torch.from_numpy((-sqd_ncm).T).to(device)
            _, predicted_ncm = score_ncm.max(1)
            correct_ncm += predicted_ncm.eq(targets).sum().item()
            # Score with the fine-tuned cosine classifier
            the_logits = F.linear(
                F.normalize(torch.squeeze(outputs_feature), p=2, dim=1),
                F.normalize(fast_fc, p=2, dim=1))
            _, predicted_maml = the_logits.max(1)
            correct_maml += predicted_maml.eq(targets).sum().item()
    cnn_acc = 100. * correct / total
    icarl_acc = 100. * correct_icarl / total
    ncm_acc = 100. * correct_ncm / total
    maml_acc = 100. * correct_maml / total
    if print_info:
        print("  Accuracy for LwF   :\t\t{:.2f} %".format(cnn_acc))
        print("  Accuracy for iCaRL :\t\t{:.2f} %".format(icarl_acc))
        print("  The above results are the accuracy for the current phase.")
        print("  For the average accuracy, you need to record the results for "
              "all phases and calculate the average value.")
    return [cnn_acc, icarl_acc, ncm_acc, maml_acc], fast_fc
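# Illustration (not part of the original code): the manual gradient step used to
# refine fast_fc above. torch.autograd.grad returns gradients without touching
# .grad buffers, so the weight can be updated functionally. All tensors here are
# synthetic; maml_lr = 0.1 matches the default argument.
def _demo_fast_fc_step():
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    fast_fc = torch.randn(10, 64, requires_grad=True)  # num_classes x feat_dim
    features = torch.randn(4, 64)
    targets = torch.randint(0, 10, (4,))
    # Cosine-similarity logits between normalized features and weights
    logits = F.linear(F.normalize(features, p=2, dim=1),
                      F.normalize(fast_fc, p=2, dim=1))
    loss = F.cross_entropy(logits, targets)
    grad = torch.autograd.grad(loss, fast_fc)
    fast_fc = fast_fc - 0.1 * grad[0]  # functional SGD step, as in the loop above
    return fast_fc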