def train(self):
    self.train_writer = SummaryWriter(logdir=self.save_path)
    dictionary_size = self.dictionary_size
    # Arrays recording the accuracies for each incremental step and run
    top1_acc_list_cumul = np.zeros((int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    top1_acc_list_ori = np.zeros((int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    X_train_total = np.array(self.trainset.train_data)
    Y_train_total = np.array(self.trainset.train_labels)
    X_valid_total = np.array(self.testset.test_data)
    Y_valid_total = np.array(self.testset.test_labels)
    # Class orders are generated with the fixed seed 1993, following iCaRL
    np.random.seed(1993)
    for iteration_total in range(self.args.nb_runs):
        order_name = osp.join(self.save_path, "seed_{}_{}_order_run_{}.pkl".format(1993, self.args.dataset, iteration_total))
        print("Order name: {}".format(order_name))
        if osp.exists(order_name):
            print("Loading orders")
            order = utils.misc.unpickle(order_name)
        else:
            print("Generating orders")
            order = np.arange(self.args.num_classes)
            np.random.shuffle(order)
            utils.misc.savepickle(order, order_name)
        order_list = list(order)
        print(order_list)
        np.random.seed(self.args.random_seed)
        X_valid_cumuls = []
        X_protoset_cumuls = []
        X_train_cumuls = []
        Y_valid_cumuls = []
        Y_protoset_cumuls = []
        Y_train_cumuls = []
        alpha_dr_herding = np.zeros((int(self.args.num_classes / self.args.nb_cl), dictionary_size, self.args.nb_cl), np.float32)
        # Group the training samples by class, in the shuffled class order
        prototypes = np.zeros((self.args.num_classes, dictionary_size, X_train_total.shape[1], X_train_total.shape[2], X_train_total.shape[3]))
        for orde in range(self.args.num_classes):
            prototypes[orde, :, :, :, :] = X_train_total[np.where(Y_train_total == order[orde])]
        start_iter = int(self.args.nb_cl_fg / self.args.nb_cl) - 1
        for iteration in range(start_iter, int(self.args.num_classes / self.args.nb_cl)):
            if iteration == start_iter:
                # First phase: train the base model on the first group of classes
                last_iter = 0
                tg_model = self.network(num_classes=self.args.nb_cl_fg)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("Out_features:", out_features)
                ref_model = None
                free_model = None
                ref_free_model = None
            elif iteration == start_iter + 1:
                # Second phase: replace the single cosine classifier with a split one
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                print("Fusion mode: " + self.args.fusion_mode)
                tg_model = self.network(num_classes=self.args.nb_cl_fg)
                ref_dict = ref_model.state_dict()
                tg_dict = tg_model.state_dict()
                tg_dict.update(ref_dict)
                tg_model.load_state_dict(tg_dict)
                tg_model.to(self.device)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("Out_features:", out_features)
                new_fc = modified_linear.SplitCosineLinear(in_features, out_features, self.args.nb_cl)
                new_fc.fc1.weight.data = tg_model.fc.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = out_features * 1.0 / self.args.nb_cl
            else:
                # Later phases: grow fc1 with the previous fc1 and fc2 weights
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                in_features = tg_model.fc.in_features
                out_features1 = tg_model.fc.fc1.out_features
                out_features2 = tg_model.fc.fc2.out_features
                print("Out_features:", out_features1 + out_features2)
                new_fc = modified_linear.SplitCosineLinear(in_features, out_features1 + out_features2, self.args.nb_cl)
                new_fc.fc1.weight.data[:out_features1] = tg_model.fc.fc1.weight.data
                new_fc.fc1.weight.data[out_features1:] = tg_model.fc.fc2.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = (out_features1 + out_features2) * 1.0 / self.args.nb_cl
            # Set the lamda for the less-forget constraint
            if iteration > start_iter:
                cur_lamda = self.args.lamda * math.sqrt(lamda_mult)
            else:
                cur_lamda = self.args.lamda
            actual_cl = order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)]
            indices_train_10 = np.array([i in order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)] for i in Y_train_total])
            indices_test_10 = np.array([i in order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)] for i in Y_valid_total])
            X_train = X_train_total[indices_train_10]
            X_valid = X_valid_total[indices_test_10]
            X_valid_cumuls.append(X_valid)
            X_train_cumuls.append(X_train)
            X_valid_cumul = np.concatenate(X_valid_cumuls)
            X_train_cumul = np.concatenate(X_train_cumuls)
            Y_train = Y_train_total[indices_train_10]
            Y_valid = Y_valid_total[indices_test_10]
            Y_valid_cumuls.append(Y_valid)
            Y_train_cumuls.append(Y_train)
            Y_valid_cumul = np.concatenate(Y_valid_cumuls)
            Y_train_cumul = np.concatenate(Y_train_cumuls)
            if iteration == start_iter:
                X_valid_ori = X_valid
                Y_valid_ori = Y_valid
            else:
                # Add the stored exemplars to the training set
                X_protoset = np.concatenate(X_protoset_cumuls)
                Y_protoset = np.concatenate(Y_protoset_cumuls)
                if self.args.rs_ratio > 0:
                    # Scale the exemplar weights so exemplars make up rs_ratio of the sampled data
                    scale_factor = (len(X_train) * self.args.rs_ratio) / (len(X_protoset) * (1 - self.args.rs_ratio))
                    rs_sample_weights = np.concatenate((np.ones(len(X_train)), np.ones(len(X_protoset)) * scale_factor))
                    rs_num_samples = int(len(X_train) / (1 - self.args.rs_ratio))
                    print("X_train:{}, X_protoset:{}, rs_num_samples:{}".format(len(X_train), len(X_protoset), rs_num_samples))
                X_train = np.concatenate((X_train, X_protoset), axis=0)
                Y_train = np.concatenate((Y_train, Y_protoset))
            print('Batch of classes number {0} arrives'.format(iteration + 1))
            map_Y_train = np.array([order_list.index(i) for i in Y_train])
            map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
            is_start_iteration = (iteration == start_iter)
            if iteration > start_iter:
                # Imprint the weights of the new classes from their mean embeddings
                old_embedding_norm = tg_model.fc.fc1.weight.data.norm(dim=1, keepdim=True)
                average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor)
                tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
                num_features = tg_model.fc.in_features
                novel_embedding = torch.zeros((self.args.nb_cl, num_features))
                for cls_idx in range(iteration * self.args.nb_cl, (iteration + 1) * self.args.nb_cl):
                    cls_indices = np.array([i == cls_idx for i in map_Y_train])
                    assert len(np.where(cls_indices == 1)[0]) == dictionary_size
                    self.evalset.test_data = X_train[cls_indices].astype('uint8')
                    self.evalset.test_labels = np.zeros(self.evalset.test_data.shape[0])  # zero labels
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    cls_features = compute_features(tg_model, free_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                    norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1)
                    cls_embedding = torch.mean(norm_features, dim=0)
                    novel_embedding[cls_idx - iteration * self.args.nb_cl] = F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm
                tg_model.to(self.device)
                tg_model.fc.fc2.weight.data = novel_embedding.to(self.device)
            self.trainset.train_data = X_train.astype('uint8')
            self.trainset.train_labels = map_Y_train
            if iteration > start_iter and self.args.rs_ratio > 0 and scale_factor > 1:
                print("Weights from sampling:", rs_sample_weights)
                index1 = np.where(rs_sample_weights > 1)[0]
                index2 = np.where(map_Y_train < iteration * self.args.nb_cl)[0]
                assert (index1 == index2).all()
                train_sampler = torch.utils.data.sampler.WeightedRandomSampler(rs_sample_weights, rs_num_samples)
                trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=False, sampler=train_sampler, num_workers=self.args.num_workers)
            else:
                trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=True, num_workers=self.args.num_workers)
            self.testset.test_data = X_valid_cumul.astype('uint8')
            self.testset.test_labels = map_Y_valid_cumul
            testloader = torch.utils.data.DataLoader(self.testset, batch_size=self.args.test_batch_size, shuffle=False, num_workers=self.args.num_workers)
            print('Min and max of train labels: {}, {}'.format(min(map_Y_train), max(map_Y_train)))
            print('Min and max of valid labels: {}, {}'.format(min(map_Y_valid_cumul), max(map_Y_valid_cumul)))
            ckp_name = osp.join(self.save_path, 'run_{}_iteration_{}_model.pth'.format(iteration_total, iteration))
            ckp_name_free = osp.join(self.save_path, 'run_{}_iteration_{}_free_model.pth'.format(iteration_total, iteration))
            print('Checkpoint name:', ckp_name)
            if iteration == start_iter and self.args.resume_fg:
                print("Loading first group models from checkpoint")
                tg_model = torch.load(self.args.ckpt_dir_fg)
            elif self.args.resume and os.path.exists(ckp_name):
                print("Loading models from checkpoint")
                tg_model = torch.load(ckp_name)
            else:
                if iteration > start_iter:
                    ref_model = ref_model.to(self.device)
                    # Freeze the old-class classifier weights; train the rest
                    ignored_params = list(map(id, tg_model.fc.fc1.parameters()))
                    base_params = filter(lambda p: id(p) not in ignored_params, tg_model.parameters())
                    base_params = filter(lambda p: p.requires_grad, base_params)
                    tg_params_new = [
                        {'params': base_params, 'lr': self.args.base_lr2, 'weight_decay': self.args.custom_weight_decay},
                        {'params': tg_model.fc.fc1.parameters(), 'lr': 0, 'weight_decay': 0},
                    ]
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(tg_params_new, lr=self.args.base_lr2, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                else:
                    tg_params = tg_model.parameters()
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(tg_params, lr=self.args.base_lr1, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                if iteration > start_iter:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(tg_optimizer, milestones=self.lr_strat, gamma=self.args.lr_factor)
                else:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(tg_optimizer, milestones=self.lr_strat_first_phase, gamma=self.args.lr_factor)
                print("Incremental train")
                if iteration > start_iter:
                    tg_model = incremental_train_and_eval(self.args.epochs, tg_model, ref_model, free_model, ref_free_model, tg_optimizer, tg_lr_scheduler, trainloader, testloader, iteration, start_iter, cur_lamda, self.args.dist, self.args.K, self.args.lw_mr)
                else:
                    tg_model = incremental_train_and_eval(self.args.epochs, tg_model, ref_model, tg_optimizer, tg_lr_scheduler, trainloader, testloader, iteration, start_iter, cur_lamda, self.args.dist, self.args.K, self.args.lw_mr)
                torch.save(tg_model, ckp_name)
            # Exemplar budget: fixed total memory, or a fixed number per class
            if self.args.fix_budget:
                nb_protos_cl = int(np.ceil(self.args.nb_protos * 100. / self.args.nb_cl / (iteration + 1)))
            else:
                nb_protos_cl = self.args.nb_protos
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            num_features = tg_model.fc.in_features
            # Herding-based exemplar selection for the new classes
            for iter_dico in range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl):
                self.evalset.test_data = prototypes[iter_dico].astype('uint8')
                self.evalset.test_labels = np.zeros(self.evalset.test_data.shape[0])  # zero labels
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                num_samples = self.evalset.test_data.shape[0]
                mapped_prototypes = compute_features(tg_model, free_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                D = mapped_prototypes.T
                D = D / np.linalg.norm(D, axis=0)
                mu = np.mean(D, axis=1)
                index1 = int(iter_dico / self.args.nb_cl)
                index2 = iter_dico % self.args.nb_cl
                alpha_dr_herding[index1, :, index2] = alpha_dr_herding[index1, :, index2] * 0
                w_t = mu
                iter_herding = 0
                iter_herding_eff = 0
                while not (np.sum(alpha_dr_herding[index1, :, index2] != 0) == min(nb_protos_cl, 500)) and iter_herding_eff < 1000:
                    tmp_t = np.dot(w_t, D)
                    ind_max = np.argmax(tmp_t)
                    iter_herding_eff += 1
                    if alpha_dr_herding[index1, ind_max, index2] == 0:
                        alpha_dr_herding[index1, ind_max, index2] = 1 + iter_herding
                        iter_herding += 1
                    w_t = w_t + mu - D[:, ind_max]
            # Rebuild the protoset and compute the class means
            # (index 0: exemplar-based mean, index 1: full-data mean)
            X_protoset_cumuls = []
            Y_protoset_cumuls = []
            class_means = np.zeros((64, 100, 2))
            for iteration2 in range(iteration + 1):
                for iter_dico in range(self.args.nb_cl):
                    current_cl = order[range(iteration2 * self.args.nb_cl, (iteration2 + 1) * self.args.nb_cl)]
                    self.evalset.test_data = prototypes[iteration2 * self.args.nb_cl + iter_dico].astype('uint8')
                    self.evalset.test_labels = np.zeros(self.evalset.test_data.shape[0])  # zero labels
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    mapped_prototypes = compute_features(tg_model, free_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    # Features of the horizontally flipped images
                    self.evalset.test_data = prototypes[iteration2 * self.args.nb_cl + iter_dico][:, :, :, ::-1].astype('uint8')
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                    mapped_prototypes2 = compute_features(tg_model, free_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                    D2 = mapped_prototypes2.T
                    D2 = D2 / np.linalg.norm(D2, axis=0)
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                    X_protoset_cumuls.append(prototypes[iteration2 * self.args.nb_cl + iter_dico, np.where(alph == 1)[0]])
                    Y_protoset_cumuls.append(order[iteration2 * self.args.nb_cl + iter_dico] * np.ones(len(np.where(alph == 1)[0])))
                    alph = alph / np.sum(alph)
                    class_means[:, current_cl[iter_dico], 0] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 0])
                    alph = np.ones(dictionary_size) / dictionary_size
                    class_means[:, current_cl[iter_dico], 1] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                    class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 1])
            torch.save(class_means, osp.join(self.save_path, 'run_{}_iteration_{}_class_means.pth'.format(iteration_total, iteration)))
            current_means = class_means[:, order[range(0, (iteration + 1) * self.args.nb_cl)]]
            map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
            print('Computing accuracy on the original batch of classes')
            self.evalset.test_data = X_valid_ori.astype('uint8')
            self.evalset.test_labels = map_Y_valid_ori
            evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
            ori_acc, fast_fc = compute_accuracy(tg_model, free_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
            top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
            self.train_writer.add_scalar('ori_acc/LwF', float(ori_acc[0]), iteration)
            self.train_writer.add_scalar('ori_acc/iCaRL', float(ori_acc[1]), iteration)
            map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
            print('Computing cumulative accuracy')
            self.evalset.test_data = X_valid_cumul.astype('uint8')
            self.evalset.test_labels = map_Y_valid_cumul
            evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
            cumul_acc, _ = compute_accuracy(tg_model, free_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, fast_fc=fast_fc, maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
            top1_acc_list_cumul[iteration, :, iteration_total] = np.array(cumul_acc).T
            self.train_writer.add_scalar('cumul_acc/LwF', float(cumul_acc[0]), iteration)
            self.train_writer.add_scalar('cumul_acc/iCaRL', float(cumul_acc[1]), iteration)
        torch.save(top1_acc_list_ori, osp.join(self.save_path, 'run_{}_top1_acc_list_ori.pth'.format(iteration_total)))
        torch.save(top1_acc_list_cumul, osp.join(self.save_path, 'run_{}_top1_acc_list_cumul.pth'.format(iteration_total)))
    self.train_writer.close()
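# ---------------------------------------------------------------------------
# Illustrative sketch (added for clarity, not part of the original pipeline):
# the herding loop above greedily picks exemplars so that the running mean of
# the selected features tracks the true class mean. A minimal standalone
# version, assuming `features` is the (n_samples, n_features) array returned
# by compute_features for one class and `budget` is the per-class exemplar
# count:
def herding_select(features, budget):
    """Return the indices of `budget` exemplars chosen by herding (sketch)."""
    D = features.T / np.linalg.norm(features.T, axis=0)  # (d, n), unit-norm columns
    mu = np.mean(D, axis=1)                              # true class mean
    w_t = mu.copy()
    selected = []
    steps = 0
    # Step cap mirrors the iter_herding_eff < 1000 guard in the loop above
    while len(selected) < budget and steps < 10 * D.shape[1]:
        ind_max = int(np.argmax(np.dot(w_t, D)))  # sample most aligned with the residual
        if ind_max not in selected:
            selected.append(ind_max)
        w_t = w_t + mu - D[:, ind_max]            # pull the residual back toward the mean
        steps += 1
    return selected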
def train(self):
    self.train_writer = SummaryWriter(comment=self.save_path)
    dictionary_size = self.dictionary_size
    top1_acc_list_cumul = np.zeros((int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    top1_acc_list_ori = np.zeros((int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    if self.args.dataset == 'cifar100':
        X_train_total = np.array(self.trainset.data)
        Y_train_total = np.array(self.trainset.targets)
        X_valid_total = np.array(self.testset.data)
        Y_valid_total = np.array(self.testset.targets)
        # One learnable fusion scalar per residual stage (3 for CIFAR ResNets)
        self.fusion_vars = nn.ParameterList()
        if self.args.branch_mode == 'dual':
            for idx in range(3):
                self.fusion_vars.append(nn.Parameter(torch.FloatTensor([0.5])))
        elif self.args.branch_mode == 'single':
            for idx in range(3):
                self.fusion_vars.append(nn.Parameter(torch.FloatTensor([1.0])))
        else:
            raise ValueError('Please set a correct branch mode.')
        self.fusion_vars.to(self.device)
    elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
        X_train_total, Y_train_total = split_images_labels(self.trainset.imgs)
        X_valid_total, Y_valid_total = split_images_labels(self.testset.imgs)
        # One learnable fusion scalar per residual stage (4 for ImageNet ResNets)
        self.fusion_vars = nn.ParameterList()
        if self.args.branch_mode == 'dual':
            for idx in range(4):
                self.fusion_vars.append(nn.Parameter(torch.FloatTensor([0.5])))
        elif self.args.branch_mode == 'single':
            for idx in range(4):
                self.fusion_vars.append(nn.Parameter(torch.FloatTensor([1.0])))
        else:
            raise ValueError('Please set a correct branch mode.')
        self.fusion_vars.to(self.device)
    else:
        raise ValueError('Please set a correct dataset.')
    np.random.seed(1993)
    for iteration_total in range(self.args.nb_runs):
        order_name = osp.join(self.save_path, "seed_{}_{}_order_run_{}.pkl".format(1993, self.args.dataset, iteration_total))
        print("Order name: {}".format(order_name))
        if osp.exists(order_name):
            print("Loading orders")
            order = utils.misc.unpickle(order_name)
        else:
            print("Generating orders")
            order = np.arange(self.args.num_classes)
            np.random.shuffle(order)
            utils.misc.savepickle(order, order_name)
        order_list = list(order)
        print(order_list)
        np.random.seed(None)
        X_valid_cumuls = []
        X_protoset_cumuls = []
        X_train_cumuls = []
        Y_valid_cumuls = []
        Y_protoset_cumuls = []
        Y_train_cumuls = []
        alpha_dr_herding = np.zeros((int(self.args.num_classes / self.args.nb_cl), dictionary_size, self.args.nb_cl), np.float32)
        if self.args.dataset == 'cifar100':
            prototypes = np.zeros((self.args.num_classes, dictionary_size, X_train_total.shape[1], X_train_total.shape[2], X_train_total.shape[3]))
            for orde in range(self.args.num_classes):
                prototypes[orde, :, :, :, :] = X_train_total[np.where(Y_train_total == order[orde])]
        elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
            prototypes = [[] for i in range(self.args.num_classes)]
            for orde in range(self.args.num_classes):
                prototypes[orde] = X_train_total[np.where(Y_train_total == order[orde])]
            prototypes = np.array(prototypes)
        else:
            raise ValueError('Please set a correct dataset.')
        start_iter = int(self.args.nb_cl_fg / self.args.nb_cl) - 1
        for iteration in range(start_iter, int(self.args.num_classes / self.args.nb_cl)):
            if iteration == start_iter:
                last_iter = 0
                b1_model = self.network(num_classes=self.args.nb_cl_fg)
                in_features = b1_model.fc.in_features
                out_features = b1_model.fc.out_features
                print("Feature:", in_features, "Class:", out_features)
                ref_model = None
                b2_model = None
                ref_b2_model = None
            elif iteration == start_iter + 1:
                last_iter = iteration
                ref_model = copy.deepcopy(b1_model)
                self.ref_fusion_vars = copy.deepcopy(self.fusion_vars)
                # Build the two branches; 'ss' selects the scale-and-shift (MTL) variant
                if self.args.branch_1 == 'ss':
                    b1_model = self.network_mtl(num_classes=self.args.nb_cl_fg)
                else:
                    b1_model = self.network(num_classes=self.args.nb_cl_fg)
                ref_dict = ref_model.state_dict()
                tg_dict = b1_model.state_dict()
                tg_dict.update(ref_dict)
                b1_model.load_state_dict(tg_dict)
                b1_model.to(self.device)
                if self.args.branch_2 == 'ss':
                    b2_model = self.network_mtl(num_classes=self.args.nb_cl_fg)
                else:
                    b2_model = self.network(num_classes=self.args.nb_cl_fg)
                b2_dict = b2_model.state_dict()
                b2_dict.update(ref_dict)
                b2_model.load_state_dict(b2_dict)
                b2_model.to(self.device)
                in_features = b1_model.fc.in_features
                out_features = b1_model.fc.out_features
                print("Feature:", in_features, "Class:", out_features)
                new_fc = modified_linear.SplitCosineLinear(in_features, out_features, self.args.nb_cl)
                new_fc.fc1.weight.data = b1_model.fc.weight.data
                new_fc.sigma.data = b1_model.fc.sigma.data
                b1_model.fc = new_fc
                lamda_mult = out_features * 1.0 / self.args.nb_cl
            else:
                last_iter = iteration
                ref_model = copy.deepcopy(b1_model)
                self.ref_fusion_vars = copy.deepcopy(self.fusion_vars)
                ref_b2_model = copy.deepcopy(b2_model)
                in_features = b1_model.fc.in_features
                out_features1 = b1_model.fc.fc1.out_features
                out_features2 = b1_model.fc.fc2.out_features
                print("Feature:", in_features, "Class:", out_features1 + out_features2)
                new_fc = modified_linear.SplitCosineLinear(in_features, out_features1 + out_features2, self.args.nb_cl)
                new_fc.fc1.weight.data[:out_features1] = b1_model.fc.fc1.weight.data
                new_fc.fc1.weight.data[out_features1:] = b1_model.fc.fc2.weight.data
                new_fc.sigma.data = b1_model.fc.sigma.data
                b1_model.fc = new_fc
                lamda_mult = (out_features1 + out_features2) * 1.0 / self.args.nb_cl
            if iteration > start_iter:
                cur_lamda = self.args.lamda * math.sqrt(lamda_mult)
            else:
                cur_lamda = self.args.lamda
            actual_cl = order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)]
            indices_train_10 = np.array([i in order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)] for i in Y_train_total])
            indices_test_10 = np.array([i in order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)] for i in Y_valid_total])
            X_train = X_train_total[indices_train_10]
            X_valid = X_valid_total[indices_test_10]
            X_valid_cumuls.append(X_valid)
            X_train_cumuls.append(X_train)
            X_valid_cumul = np.concatenate(X_valid_cumuls)
            X_train_cumul = np.concatenate(X_train_cumuls)
            Y_train = Y_train_total[indices_train_10]
            Y_valid = Y_valid_total[indices_test_10]
            Y_valid_cumuls.append(Y_valid)
            Y_train_cumuls.append(Y_train)
            Y_valid_cumul = np.concatenate(Y_valid_cumuls)
            Y_train_cumul = np.concatenate(Y_train_cumuls)
            if iteration == start_iter:
                X_valid_ori = X_valid
                Y_valid_ori = Y_valid
            else:
                X_protoset = np.concatenate(X_protoset_cumuls)
                Y_protoset = np.concatenate(Y_protoset_cumuls)
                if self.args.rs_ratio > 0:
                    scale_factor = (len(X_train) * self.args.rs_ratio) / (len(X_protoset) * (1 - self.args.rs_ratio))
                    rs_sample_weights = np.concatenate((np.ones(len(X_train)), np.ones(len(X_protoset)) * scale_factor))
                    rs_num_samples = int(len(X_train) / (1 - self.args.rs_ratio))
                    print("X_train:{}, X_protoset:{}, rs_num_samples:{}".format(len(X_train), len(X_protoset), rs_num_samples))
                X_train = np.concatenate((X_train, X_protoset), axis=0)
                Y_train = np.concatenate((Y_train, Y_protoset))
            print('Batch of classes number {0} arrives ...'.format(iteration + 1))
            map_Y_train = np.array([order_list.index(i) for i in Y_train])
            map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
            is_start_iteration = (iteration == start_iter)
            if iteration > start_iter:
                # Imprint the weights of the new classes from their mean embeddings
                if self.args.dataset == 'cifar100':
                    old_embedding_norm = b1_model.fc.fc1.weight.data.norm(dim=1, keepdim=True)
                    average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor)
                    tg_feature_model = nn.Sequential(*list(b1_model.children())[:-1])
                    num_features = b1_model.fc.in_features
                    novel_embedding = torch.zeros((self.args.nb_cl, num_features))
                    for cls_idx in range(iteration * self.args.nb_cl, (iteration + 1) * self.args.nb_cl):
                        cls_indices = np.array([i == cls_idx for i in map_Y_train])
                        assert len(np.where(cls_indices == 1)[0]) == dictionary_size
                        self.evalset.data = X_train[cls_indices].astype('uint8')
                        self.evalset.targets = np.zeros(self.evalset.data.shape[0])
                        evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                        num_samples = self.evalset.data.shape[0]
                        cls_features = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                        norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1)
                        cls_embedding = torch.mean(norm_features, dim=0)
                        novel_embedding[cls_idx - iteration * self.args.nb_cl] = F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm
                    b1_model.to(self.device)
                    b1_model.fc.fc2.weight.data = novel_embedding.to(self.device)
                elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
                    old_embedding_norm = b1_model.fc.fc1.weight.data.norm(dim=1, keepdim=True)
                    average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor)
                    tg_feature_model = nn.Sequential(*list(b1_model.children())[:-1])
                    num_features = b1_model.fc.in_features
                    novel_embedding = torch.zeros((self.args.nb_cl, num_features))
                    for cls_idx in range(iteration * self.args.nb_cl, (iteration + 1) * self.args.nb_cl):
                        cls_indices = np.array([i == cls_idx for i in map_Y_train])
                        assert len(np.where(cls_indices == 1)[0]) <= dictionary_size
                        current_eval_set = merge_images_labels(X_train[cls_indices], np.zeros(len(X_train[cls_indices])))
                        self.evalset.imgs = self.evalset.samples = current_eval_set
                        evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=2)
                        num_samples = len(X_train[cls_indices])
                        cls_features = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                        norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1)
                        cls_embedding = torch.mean(norm_features, dim=0)
                        novel_embedding[cls_idx - iteration * self.args.nb_cl] = F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm
                    b1_model.to(self.device)
                    b1_model.fc.fc2.weight.data = novel_embedding.to(self.device)
                else:
                    raise ValueError('Please set a correct dataset.')
            if self.args.dataset == 'cifar100':
                self.trainset.data = X_train.astype('uint8')
                self.trainset.targets = map_Y_train
                if iteration > start_iter and self.args.rs_ratio > 0 and scale_factor > 1:
                    print("Weights from sampling:", rs_sample_weights)
                    index1 = np.where(rs_sample_weights > 1)[0]
                    index2 = np.where(map_Y_train < iteration * self.args.nb_cl)[0]
                    assert (index1 == index2).all()
                    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(rs_sample_weights, rs_num_samples)
                    trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=False, sampler=train_sampler, num_workers=self.args.num_workers)
                else:
                    trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=True, num_workers=self.args.num_workers)
                self.testset.data = X_valid_cumul.astype('uint8')
                self.testset.targets = map_Y_valid_cumul
                testloader = torch.utils.data.DataLoader(self.testset, batch_size=self.args.test_batch_size, shuffle=False, num_workers=self.args.num_workers)
                print('Min and max of train labels: {}, {}'.format(min(map_Y_train), max(map_Y_train)))
                print('Min and max of valid labels: {}, {}'.format(min(map_Y_valid_cumul), max(map_Y_valid_cumul)))
            elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
                current_train_imgs = merge_images_labels(X_train, map_Y_train)
                self.trainset.imgs = self.trainset.samples = current_train_imgs
                if iteration > start_iter and self.args.rs_ratio > 0 and scale_factor > 1:
                    print("Weights from sampling:", rs_sample_weights)
                    index1 = np.where(rs_sample_weights > 1)[0]
                    index2 = np.where(map_Y_train < iteration * self.args.nb_cl)[0]
                    assert (index1 == index2).all()
                    train_sampler = torch.utils.data.sampler.WeightedRandomSampler(rs_sample_weights, rs_num_samples)
                    trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=False, sampler=train_sampler, num_workers=self.args.num_workers, pin_memory=True)
                else:
                    trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=True, num_workers=self.args.num_workers, pin_memory=True)
                current_test_imgs = merge_images_labels(X_valid_cumul, map_Y_valid_cumul)
                self.testset.imgs = self.testset.samples = current_test_imgs
                testloader = torch.utils.data.DataLoader(self.testset, batch_size=self.args.test_batch_size, shuffle=False, num_workers=self.args.num_workers)
                print('Min and max of train labels: {}, {}'.format(min(map_Y_train), max(map_Y_train)))
                print('Min and max of valid labels: {}, {}'.format(min(map_Y_valid_cumul), max(map_Y_valid_cumul)))
            else:
                raise ValueError('Please set a correct dataset.')
            ckp_name = osp.join(self.save_path, 'run_{}_iteration_{}_model.pth'.format(iteration_total, iteration))
            ckp_name_free = osp.join(self.save_path, 'run_{}_iteration_{}_b2_model.pth'.format(iteration_total, iteration))
            print('Checkpoint name:', ckp_name)
            if iteration == start_iter and self.args.resume_fg:
                b1_model = torch.load(self.args.ckpt_dir_fg)
            elif self.args.resume and os.path.exists(ckp_name):
                b1_model = torch.load(ckp_name)
            else:
                if iteration > start_iter:
                    ref_model = ref_model.to(self.device)
                    ignored_params = list(map(id, b1_model.fc.fc1.parameters()))
                    base_params = filter(lambda p: id(p) not in ignored_params, b1_model.parameters())
                    base_params = filter(lambda p: p.requires_grad, base_params)
                    b2_params = b2_model.parameters()
                    # A 'fixed' branch gets zero learning rate and weight decay
                    if self.args.branch_1 == 'fixed':
                        branch1_lr = 0.0
                        branch1_weight_decay = 0
                    else:
                        branch1_lr = self.args.base_lr2
                        branch1_weight_decay = self.args.custom_weight_decay
                    if self.args.branch_2 == 'fixed':
                        branch2_lr = 0.0
                        branch2_weight_decay = 0
                    else:
                        branch2_lr = self.args.base_lr2
                        branch2_weight_decay = self.args.custom_weight_decay
                    tg_params_new = [
                        {'params': base_params, 'lr': branch1_lr, 'weight_decay': branch1_weight_decay},
                        {'params': b2_params, 'lr': branch2_lr, 'weight_decay': branch2_weight_decay},
                        {'params': b1_model.fc.fc1.parameters(), 'lr': 0, 'weight_decay': 0},
                    ]
                    b1_model = b1_model.to(self.device)
                    tg_optimizer = optim.SGD(tg_params_new, lr=self.args.base_lr2, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                    if self.args.branch_mode == 'dual':
                        fusion_optimizer = optim.SGD(self.fusion_vars, lr=self.args.fusion_lr, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                    elif self.args.branch_mode == 'single':
                        fusion_optimizer = optim.SGD(self.fusion_vars, lr=0.0, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                    else:
                        raise ValueError('Please set a correct branch mode.')
                else:
                    tg_params = b1_model.parameters()
                    b1_model = b1_model.to(self.device)
                    tg_optimizer = optim.SGD(tg_params, lr=self.args.base_lr1, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                    if self.args.branch_mode == 'dual':
                        fusion_optimizer = optim.SGD(self.fusion_vars, lr=self.args.fusion_lr, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                    elif self.args.branch_mode == 'single':
                        fusion_optimizer = optim.SGD(self.fusion_vars, lr=0.0, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                    else:
                        raise ValueError('Please set a correct branch mode.')
                if iteration > start_iter:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(tg_optimizer, milestones=self.lr_strat, gamma=self.args.lr_factor)
                    fusion_lr_scheduler = lr_scheduler.MultiStepLR(fusion_optimizer, milestones=self.lr_strat, gamma=self.args.lr_factor)
                else:
                    tg_lr_scheduler = lr_scheduler.MultiStepLR(tg_optimizer, milestones=self.lr_strat, gamma=self.args.lr_factor)
                    fusion_lr_scheduler = lr_scheduler.MultiStepLR(fusion_optimizer, milestones=self.lr_strat_first_phase, gamma=self.args.lr_factor)
                if iteration > start_iter:
                    # Build a class-balanced loader from a subset of the new data plus the exemplars
                    X_train_this_step = X_train_total[indices_train_10]
                    Y_train_this_step = Y_train_total[indices_train_10]
                    the_idx = np.random.randint(0, len(X_train_this_step), size=self.args.nb_cl * self.args.nb_protos)
                    X_balanced_this_step = np.concatenate((X_train_this_step[the_idx], X_protoset), axis=0)
                    Y_balanced_this_step = np.concatenate((Y_train_this_step[the_idx], Y_protoset), axis=0)
                    map_Y_train_this_step = np.array([order_list.index(i) for i in Y_balanced_this_step])
                    self.balancedset.data = X_balanced_this_step.astype('uint8')
                    self.balancedset.targets = map_Y_train_this_step
                    balancedloader = torch.utils.data.DataLoader(self.balancedset, batch_size=self.args.test_batch_size, shuffle=False, num_workers=self.args.num_workers)
                if self.args.baseline == 'lucir':
                    if iteration > start_iter:
                        b1_model, b2_model = incremental_train_and_eval_lucir(
                            self.args, self.args.epochs, self.fusion_vars, self.ref_fusion_vars, b1_model, ref_model,
                            b2_model, ref_b2_model, tg_optimizer, tg_lr_scheduler, fusion_optimizer, fusion_lr_scheduler,
                            trainloader, testloader, balancedloader, iteration, start_iter, X_protoset_cumuls,
                            Y_protoset_cumuls, order_list, cur_lamda, self.args.dist, self.args.K, self.args.lw_mr)
                    else:
                        b1_model = incremental_train_and_eval_first_phase_lucir(
                            self.args, self.args.epochs, b1_model, ref_model, tg_optimizer, tg_lr_scheduler,
                            trainloader, testloader, iteration, start_iter, cur_lamda, self.args.dist, self.args.K,
                            self.args.lw_mr)
                elif self.args.baseline == 'icarl':
                    if iteration > start_iter:
                        b1_model, b2_model = incremental_train_and_eval_icarl(
                            self.args, self.args.epochs, self.fusion_vars, self.ref_fusion_vars, b1_model, ref_model,
                            b2_model, ref_b2_model, tg_optimizer, tg_lr_scheduler, fusion_optimizer, fusion_lr_scheduler,
                            trainloader, testloader, balancedloader, iteration, start_iter, X_protoset_cumuls,
                            Y_protoset_cumuls, order_list, cur_lamda, self.args.dist, self.args.K, self.args.lw_mr)
                    else:
                        b1_model = incremental_train_and_eval_first_phase_icarl(
                            self.args, self.args.epochs, b1_model, ref_model, tg_optimizer, tg_lr_scheduler,
                            trainloader, testloader, iteration, start_iter, cur_lamda, self.args.dist, self.args.K,
                            self.args.lw_mr)
                else:
                    raise ValueError('Please set a correct baseline.')
                torch.save(b1_model, ckp_name)
                torch.save(b2_model, ckp_name_free)
            # Exemplar budget: a fixed number per class, or a fixed total memory
            if self.args.dynamic_budget:
                nb_protos_cl = self.args.nb_protos
            else:
                nb_protos_cl = int(np.ceil(self.args.nb_protos * 100. / self.args.nb_cl / (iteration + 1)))
            tg_feature_model = nn.Sequential(*list(b1_model.children())[:-1])
            num_features = b1_model.fc.in_features
            # Herding-based exemplar selection for the new classes
            if self.args.dataset == 'cifar100':
                for iter_dico in range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl):
                    self.evalset.data = prototypes[iter_dico].astype('uint8')
                    self.evalset.targets = np.zeros(self.evalset.data.shape[0])
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                    num_samples = self.evalset.data.shape[0]
                    mapped_prototypes = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    mu = np.mean(D, axis=1)
                    index1 = int(iter_dico / self.args.nb_cl)
                    index2 = iter_dico % self.args.nb_cl
                    alpha_dr_herding[index1, :, index2] = alpha_dr_herding[index1, :, index2] * 0
                    w_t = mu
                    iter_herding = 0
                    iter_herding_eff = 0
                    while not (np.sum(alpha_dr_herding[index1, :, index2] != 0) == min(nb_protos_cl, 500)) and iter_herding_eff < 1000:
                        tmp_t = np.dot(w_t, D)
                        ind_max = np.argmax(tmp_t)
                        iter_herding_eff += 1
                        if alpha_dr_herding[index1, ind_max, index2] == 0:
                            alpha_dr_herding[index1, ind_max, index2] = 1 + iter_herding
                            iter_herding += 1
                        w_t = w_t + mu - D[:, ind_max]
            elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
                for iter_dico in range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl):
                    current_eval_set = merge_images_labels(prototypes[iter_dico], np.zeros(len(prototypes[iter_dico])))
                    self.evalset.imgs = self.evalset.samples = current_eval_set
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True)
                    num_samples = len(prototypes[iter_dico])
                    mapped_prototypes = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                    D = mapped_prototypes.T
                    D = D / np.linalg.norm(D, axis=0)
                    mu = np.mean(D, axis=1)
                    index1 = int(iter_dico / self.args.nb_cl)
                    index2 = iter_dico % self.args.nb_cl
                    alpha_dr_herding[index1, :, index2] = alpha_dr_herding[index1, :, index2] * 0
                    w_t = mu
                    iter_herding = 0
                    iter_herding_eff = 0
                    while not (np.sum(alpha_dr_herding[index1, :, index2] != 0) == min(nb_protos_cl, 500)) and iter_herding_eff < 1000:
                        tmp_t = np.dot(w_t, D)
                        ind_max = np.argmax(tmp_t)
                        iter_herding_eff += 1
                        if alpha_dr_herding[index1, ind_max, index2] == 0:
                            alpha_dr_herding[index1, ind_max, index2] = 1 + iter_herding
                            iter_herding += 1
                        w_t = w_t + mu - D[:, ind_max]
            else:
                raise ValueError('Please set a correct dataset.')
            X_protoset_cumuls = []
            Y_protoset_cumuls = []
            # Compute the class means (index 0: exemplar-based mean, index 1: full-data mean)
            if self.args.dataset == 'cifar100':
                class_means = np.zeros((64, 100, 2))
                for iteration2 in range(iteration + 1):
                    for iter_dico in range(self.args.nb_cl):
                        current_cl = order[range(iteration2 * self.args.nb_cl, (iteration2 + 1) * self.args.nb_cl)]
                        self.evalset.data = prototypes[iteration2 * self.args.nb_cl + iter_dico].astype('uint8')
                        self.evalset.targets = np.zeros(self.evalset.data.shape[0])
                        evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                        num_samples = self.evalset.data.shape[0]
                        mapped_prototypes = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                        D = mapped_prototypes.T
                        D = D / np.linalg.norm(D, axis=0)
                        # Features of the horizontally flipped images
                        self.evalset.data = prototypes[iteration2 * self.args.nb_cl + iter_dico][:, :, :, ::-1].astype('uint8')
                        evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                        mapped_prototypes2 = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                        D2 = mapped_prototypes2.T
                        D2 = D2 / np.linalg.norm(D2, axis=0)
                        alph = alpha_dr_herding[iteration2, :, iter_dico]
                        alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                        X_protoset_cumuls.append(prototypes[iteration2 * self.args.nb_cl + iter_dico, np.where(alph == 1)[0]])
                        Y_protoset_cumuls.append(order[iteration2 * self.args.nb_cl + iter_dico] * np.ones(len(np.where(alph == 1)[0])))
                        alph = alph / np.sum(alph)
                        class_means[:, current_cl[iter_dico], 0] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                        class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 0])
                        alph = np.ones(dictionary_size) / dictionary_size
                        class_means[:, current_cl[iter_dico], 1] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                        class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 1])
            elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
                class_means = np.zeros((num_features, self.args.num_classes, 2))
                for iteration2 in range(iteration + 1):
                    for iter_dico in range(self.args.nb_cl):
                        current_cl = order[range(iteration2 * self.args.nb_cl, (iteration2 + 1) * self.args.nb_cl)]
                        current_eval_set = merge_images_labels(prototypes[iteration2 * self.args.nb_cl + iter_dico], np.zeros(len(prototypes[iteration2 * self.args.nb_cl + iter_dico])))
                        self.evalset.imgs = self.evalset.samples = current_eval_set
                        evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True)
                        num_samples = len(prototypes[iteration2 * self.args.nb_cl + iter_dico])
                        mapped_prototypes = compute_features(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, is_start_iteration, evalloader, num_samples, num_features)
                        D = mapped_prototypes.T
                        D = D / np.linalg.norm(D, axis=0)
                        # No flipped copy for ImageNet; reuse the same features
                        D2 = D
                        alph = alpha_dr_herding[iteration2, :, iter_dico]
                        assert (alph[num_samples:] == 0).all()
                        alph = alph[:num_samples]
                        alph = (alph > 0) * (alph < nb_protos_cl + 1) * 1.
                        X_protoset_cumuls.append(prototypes[iteration2 * self.args.nb_cl + iter_dico][np.where(alph == 1)[0]])
                        Y_protoset_cumuls.append(order[iteration2 * self.args.nb_cl + iter_dico] * np.ones(len(np.where(alph == 1)[0])))
                        alph = alph / np.sum(alph)
                        class_means[:, current_cl[iter_dico], 0] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                        class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 0])
                        alph = np.ones(num_samples) / num_samples
                        class_means[:, current_cl[iter_dico], 1] = (np.dot(D, alph) + np.dot(D2, alph)) / 2
                        class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 1])
            else:
                raise ValueError('Please set a correct dataset.')
            torch.save(class_means, osp.join(self.save_path, 'run_{}_iteration_{}_class_means.pth'.format(iteration_total, iteration)))
            current_means = class_means[:, order[range(0, (iteration + 1) * self.args.nb_cl)]]
            is_start_iteration = (iteration == start_iter)
            if self.args.dataset == 'cifar100':
                map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
                print('Computing accuracy on the original batch of classes ...')
                self.evalset.data = X_valid_ori.astype('uint8')
                self.evalset.targets = map_Y_valid_ori
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                ori_acc, fast_fc = compute_accuracy(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
                top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
                map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
                print('Computing cumulative accuracy ...')
                self.evalset.data = X_valid_cumul.astype('uint8')
                self.evalset.targets = map_Y_valid_cumul
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                cumul_acc, _ = compute_accuracy(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, fast_fc=fast_fc, maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
                top1_acc_list_cumul[iteration, :, iteration_total] = np.array(cumul_acc).T
            elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
                map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
                print('Computing accuracy on the original batch of classes ...')
                current_eval_set = merge_images_labels(X_valid_ori, map_Y_valid_ori)
                self.evalset.imgs = self.evalset.samples = current_eval_set
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True)
                ori_acc, fast_fc = compute_accuracy(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, cifar=False, imagenet=True, valdir=os.path.join(self.args.data_dir, 'val'), maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
                top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
                map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
                print('Computing cumulative accuracy ...')
                current_eval_set = merge_images_labels(X_valid_cumul, map_Y_valid_cumul)
                self.evalset.imgs = self.evalset.samples = current_eval_set
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True)
                cumul_acc, _ = compute_accuracy(self.args, self.fusion_vars, b1_model, b2_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, fast_fc=fast_fc, cifar=False, imagenet=True, valdir=os.path.join(self.args.data_dir, 'val'), maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
                top1_acc_list_cumul[iteration, :, iteration_total] = np.array(cumul_acc).T
            else:
                raise ValueError('Please set a correct dataset.')
        torch.save(top1_acc_list_ori, osp.join(self.save_path, 'run_{}_top1_acc_list_ori.pth'.format(iteration_total)))
        torch.save(top1_acc_list_cumul, osp.join(self.save_path, 'run_{}_top1_acc_list_cumul.pth'.format(iteration_total)))
    self.train_writer.close()
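# Illustrative sketch (an assumption about usage, not the repository's exact
# forward pass): `self.fusion_vars` above holds one learnable scalar per
# residual stage (3 for CIFAR ResNets, 4 for ImageNet ResNets). A common way
# to apply such scalars is to blend the per-stage outputs of the two branches;
# the actual blending lives inside compute_features and the
# incremental_train_and_eval_* routines.
def fuse_branches(fusion_vars, feats_b1, feats_b2):
    """Blend per-stage feature maps from two branches with learnable scalars (sketch)."""
    fused = []
    for alpha, f1, f2 in zip(fusion_vars, feats_b1, feats_b2):
        fused.append(alpha * f1 + (1.0 - alpha) * f2)  # 'dual' mode starts at alpha = 0.5
    return fused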
def eval(self):
    self.train_writer = SummaryWriter(comment=self.save_path)
    dictionary_size = self.dictionary_size
    top1_acc_list_cumul = np.zeros((int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    top1_acc_list_ori = np.zeros((int(self.args.num_classes / self.args.nb_cl), 4, self.args.nb_runs))
    if self.args.dataset == 'cifar100':
        X_train_total = np.array(self.trainset.train_data)
        Y_train_total = np.array(self.trainset.train_labels)
        X_valid_total = np.array(self.testset.test_data)
        Y_valid_total = np.array(self.testset.test_labels)
    elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
        X_train_total, Y_train_total = split_images_labels(self.trainset.imgs)
        X_valid_total, Y_valid_total = split_images_labels(self.testset.imgs)
    else:
        raise ValueError('Please set correct dataset.')
    for iteration_total in range(self.args.nb_runs):
        order_name = osp.join(self.save_path, "seed_{}_{}_order_run_{}.pkl".format(self.args.random_seed, self.args.dataset, iteration_total))
        print("Order name:{}".format(order_name))
        if osp.exists(order_name):
            print("Loading orders")
            order = utils.misc.unpickle(order_name)
        else:
            print("Generating orders")
            order = np.arange(self.args.num_classes)
            np.random.shuffle(order)
            utils.misc.savepickle(order, order_name)
        order_list = list(order)
        print(order_list)
        X_valid_cumuls = []
        X_protoset_cumuls = []
        Y_valid_cumuls = []
        Y_protoset_cumuls = []
        start_iter = int(self.args.nb_cl_fg / self.args.nb_cl) - 1
        for iteration in range(start_iter, int(self.args.num_classes / self.args.nb_cl)):
            if iteration == start_iter:
                last_iter = 0
                tg_model = self.network(num_classes=self.args.nb_cl_fg)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("in_features:", in_features, "out_features:", out_features)
                ref_model = None
            elif iteration == start_iter + 1:
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                if self.args.use_mtl:
                    tg_model = self.network_mtl(num_classes=self.args.nb_cl_fg)
                else:
                    tg_model = self.network(num_classes=self.args.nb_cl_fg)
                ref_dict = ref_model.state_dict()
                tg_dict = tg_model.state_dict()
                tg_dict.update(ref_dict)
                tg_model.load_state_dict(tg_dict)
                tg_model.to(self.device)
                in_features = tg_model.fc.in_features
                out_features = tg_model.fc.out_features
                print("in_features:", in_features, "out_features:", out_features)
                new_fc = modified_linear.SplitCosineLinear(in_features, out_features, self.args.nb_cl)
                new_fc.fc1.weight.data = tg_model.fc.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = out_features * 1.0 / self.args.nb_cl
            else:
                last_iter = iteration
                ref_model = copy.deepcopy(tg_model)
                in_features = tg_model.fc.in_features
                out_features1 = tg_model.fc.fc1.out_features
                out_features2 = tg_model.fc.fc2.out_features
                print("in_features:", in_features, "out_features1:", out_features1, "out_features2:", out_features2)
                new_fc = modified_linear.SplitCosineLinear(in_features, out_features1 + out_features2, self.args.nb_cl)
                new_fc.fc1.weight.data[:out_features1] = tg_model.fc.fc1.weight.data
                new_fc.fc1.weight.data[out_features1:] = tg_model.fc.fc2.weight.data
                new_fc.sigma.data = tg_model.fc.sigma.data
                tg_model.fc = new_fc
                lamda_mult = (out_features1 + out_features2) * 1.0 / self.args.nb_cl
            actual_cl = order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)]
            indices_train_10 = np.array([i in order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)] for i in Y_train_total])
            indices_test_10 = np.array([i in order[range(last_iter * self.args.nb_cl, (iteration + 1) * self.args.nb_cl)] for i in Y_valid_total])
            X_valid = X_valid_total[indices_test_10]
            X_valid_cumuls.append(X_valid)
            X_valid_cumul = np.concatenate(X_valid_cumuls)
            Y_valid = Y_valid_total[indices_test_10]
            Y_valid_cumuls.append(Y_valid)
            Y_valid_cumul = np.concatenate(Y_valid_cumuls)
            if iteration == start_iter:
                X_valid_ori = X_valid
                Y_valid_ori = Y_valid
            ckp_name = osp.join(self.save_path, 'run_{}_iteration_{}_model.pth'.format(iteration_total, iteration))
            print('ckp_name', ckp_name)
            print("[*] Loading models from checkpoint")
            tg_model = torch.load(ckp_name)
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            if self.args.dataset == 'cifar100':
                map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
                print('Computing accuracy on the original batch of classes...')
                self.evalset.test_data = X_valid_ori.astype('uint8')
                self.evalset.test_labels = map_Y_valid_ori
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                ori_acc = compute_accuracy(tg_model, tg_feature_model, evalloader)
                top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
                self.train_writer.add_scalar('ori_acc/cnn', float(ori_acc), iteration)
                map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
                print('Computing cumulative accuracy...')
                self.evalset.test_data = X_valid_cumul.astype('uint8')
                self.evalset.test_labels = map_Y_valid_cumul
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                cumul_acc = compute_accuracy(tg_model, tg_feature_model, evalloader)
                top1_acc_list_cumul[iteration, :, iteration_total] = np.array(cumul_acc).T
                self.train_writer.add_scalar('cumul_acc/cnn', float(cumul_acc), iteration)
            elif self.args.dataset == 'imagenet_sub' or self.args.dataset == 'imagenet':
                map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
                print('Computing accuracy on the original batch of classes...')
                current_eval_set = merge_images_labels(X_valid_ori, map_Y_valid_ori)
                self.evalset.imgs = self.evalset.samples = current_eval_set
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True)
                ori_acc = compute_accuracy(tg_model, tg_feature_model, evalloader)
                top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
                self.train_writer.add_scalar('ori_acc/cnn', float(ori_acc), iteration)
                map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
                print('Computing cumulative accuracy...')
                current_eval_set = merge_images_labels(X_valid_cumul, map_Y_valid_cumul)
                self.evalset.imgs = self.evalset.samples = current_eval_set
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers, pin_memory=True)
                cumul_acc = compute_accuracy(tg_model, tg_feature_model, evalloader)
                top1_acc_list_cumul[iteration, :, iteration_total] = np.array(cumul_acc).T
                self.train_writer.add_scalar('cumul_acc/cnn', float(cumul_acc), iteration)
            else:
                raise ValueError('Please set correct dataset.')
    self.train_writer.close()
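# Illustrative sketch (not the repository's compute_accuracy): the saved
# class_means tensor has shape (n_features, n_classes, 2), where slice 0 holds
# the exemplar-based iCaRL means and slice 1 the full-data means. Given
# features for a batch of test images, nearest-class-mean prediction is a
# distance argmin against current_means:
from scipy.spatial.distance import cdist  # assumed available, as in the iCaRL reference code

def nearest_mean_predict(features, current_means, mean_slice=0):
    """Predict labels by the nearest class mean in L2-normalised feature space (sketch)."""
    feats = features.T / np.linalg.norm(features.T, axis=0)  # (d, n), unit-norm columns
    means = current_means[:, :, mean_slice]                  # (d, c)
    sqd = cdist(feats.T, means.T, 'sqeuclidean')             # (n, c) squared distances
    return np.argmin(sqd, axis=1)                            # label index in the mapped class order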
def train(self): """The function for the train phase""" # Set tensorboard self.train_writer = SummaryWriter(logdir=self.save_path) # Load dictionary size dictionary_size = self.dictionary_size # Set array to record the accuracies top1_acc_list_cumul = np.zeros((int(self.args.num_classes/self.args.nb_cl), 4, self.args.nb_runs)) top1_acc_list_ori = np.zeros((int(self.args.num_classes/self.args.nb_cl), 4, self.args.nb_runs)) # Load the samples for CIFAR-100 X_train_total = np.array(self.trainset.train_data) Y_train_total = np.array(self.trainset.train_labels) X_valid_total = np.array(self.testset.test_data) Y_valid_total = np.array(self.testset.test_labels) # Generate order list using random seed 1993. This operation follows the related papers e.g., iCaRL for iteration_total in range(self.args.nb_runs): order_name = osp.join(self.save_path, "seed_{}_{}_order_run_{}.pkl".format(self.args.random_seed, self.args.dataset, iteration_total)) print("Order name:{}".format(order_name)) if osp.exists(order_name): print("Loading orders") order = utils.misc.unpickle(order_name) else: print("Generating orders") order = np.arange(self.args.num_classes) np.random.shuffle(order) utils.misc.savepickle(order, order_name) order_list = list(order) print(order_list) # Set empty lists and arrays X_valid_cumuls = [] X_protoset_cumuls = [] X_train_cumuls = [] Y_valid_cumuls = [] Y_protoset_cumuls = [] Y_train_cumuls = [] alpha_dr_herding = np.zeros((int(self.args.num_classes/self.args.nb_cl),dictionary_size,self.args.nb_cl),np.float32) # Set the initial exemplars prototypes = np.zeros((self.args.num_classes,dictionary_size,X_train_total.shape[1],X_train_total.shape[2],X_train_total.shape[3])) for orde in range(self.args.num_classes): prototypes[orde,:,:,:,:] = X_train_total[np.where(Y_train_total==order[orde])] # Set the start iteration start_iter = int(self.args.nb_cl_fg/self.args.nb_cl)-1 # Begin training for iteration in range(start_iter, int(self.args.num_classes/self.args.nb_cl)): # Initial model if iteration == start_iter: last_iter = 0 tg_model = self.network(num_classes=self.args.nb_cl_fg) in_features = tg_model.fc.in_features out_features = tg_model.fc.out_features print("In_features:", in_features, "out_features:", out_features) ref_model = None elif iteration == start_iter+1: last_iter = iteration ref_model = copy.deepcopy(tg_model) tg_model = self.network_mtl(num_classes=self.args.nb_cl_fg) ref_dict = ref_model.state_dict() tg_dict = tg_model.state_dict() tg_dict.update(ref_dict) tg_model.load_state_dict(tg_dict) tg_model.to(self.device) in_features = tg_model.fc.in_features out_features = tg_model.fc.out_features print("in_features:", in_features, "out_features:", out_features) new_fc = modified_linear.SplitCosineLinear(in_features, out_features, self.args.nb_cl) new_fc.fc1.weight.data = tg_model.fc.weight.data new_fc.sigma.data = tg_model.fc.sigma.data tg_model.fc = new_fc lamda_mult = out_features*1.0 / self.args.nb_cl else: last_iter = iteration ref_model = copy.deepcopy(tg_model) in_features = tg_model.fc.in_features out_features1 = tg_model.fc.fc1.out_features out_features2 = tg_model.fc.fc2.out_features print("in_features:", in_features, "out_features1:", out_features1, "out_features2:", out_features2) new_fc = modified_linear.SplitCosineLinear(in_features, out_features1+out_features2, self.args.nb_cl) new_fc.fc1.weight.data[:out_features1] = tg_model.fc.fc1.weight.data new_fc.fc1.weight.data[out_features1:] = tg_model.fc.fc2.weight.data new_fc.sigma.data = tg_model.fc.sigma.data tg_model.fc = 
new_fc lamda_mult = (out_features1+out_features2)*1.0 / (self.args.nb_cl) # Set lamda if iteration > start_iter: cur_lamda = self.args.lamda * math.sqrt(lamda_mult) else: cur_lamda = self.args.lamda if iteration > start_iter: print("Lamda for less forget is set to ", cur_lamda) # Add the current exemplars to the training set actual_cl = order[range(last_iter*self.args.nb_cl,(iteration+1)*self.args.nb_cl)] indices_train_10 = np.array([i in order[range(last_iter*self.args.nb_cl,(iteration+1)*self.args.nb_cl)] for i in Y_train_total]) indices_test_10 = np.array([i in order[range(last_iter*self.args.nb_cl,(iteration+1)*self.args.nb_cl)] for i in Y_valid_total]) X_train = X_train_total[indices_train_10] X_valid = X_valid_total[indices_test_10] X_valid_cumuls.append(X_valid) X_train_cumuls.append(X_train) X_valid_cumul = np.concatenate(X_valid_cumuls) X_train_cumul = np.concatenate(X_train_cumuls) Y_train = Y_train_total[indices_train_10] Y_valid = Y_valid_total[indices_test_10] Y_valid_cumuls.append(Y_valid) Y_train_cumuls.append(Y_train) Y_valid_cumul = np.concatenate(Y_valid_cumuls) Y_train_cumul = np.concatenate(Y_train_cumuls) if iteration == start_iter: X_valid_ori = X_valid Y_valid_ori = Y_valid else: X_protoset = np.concatenate(X_protoset_cumuls) Y_protoset = np.concatenate(Y_protoset_cumuls) if self.args.rs_ratio > 0: scale_factor = (len(X_train) * self.args.rs_ratio) / (len(X_protoset) * (1 - self.args.rs_ratio)) rs_sample_weights = np.concatenate((np.ones(len(X_train)), np.ones(len(X_protoset))*scale_factor)) rs_num_samples = int(len(X_train) / (1 - self.args.rs_ratio)) print("X_train:{}, X_protoset:{}, rs_num_samples:{}".format(len(X_train), len(X_protoset), rs_num_samples)) X_train = np.concatenate((X_train,X_protoset),axis=0) Y_train = np.concatenate((Y_train,Y_protoset)) # Launch the training loop print('Batch of classes number {0} arrives'.format(iteration+1)) map_Y_train = np.array([order_list.index(i) for i in Y_train]) map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul]) # Imprint weights if iteration > start_iter and self.args.imprint_weights: print("Imprint weights") # Compute the average norm of old embdding old_embedding_norm = tg_model.fc.fc1.weight.data.norm(dim=1, keepdim=True) average_old_embedding_norm = torch.mean(old_embedding_norm, dim=0).to('cpu').type(torch.DoubleTensor) tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1]) num_features = tg_model.fc.in_features novel_embedding = torch.zeros((self.args.nb_cl, num_features)) for cls_idx in range(iteration*self.args.nb_cl, (iteration+1)*self.args.nb_cl): cls_indices = np.array([i == cls_idx for i in map_Y_train]) assert(len(np.where(cls_indices==1)[0])==dictionary_size) self.evalset.test_data = X_train[cls_indices].astype('uint8') self.evalset.test_labels = np.zeros(self.evalset.test_data.shape[0]) #zero labels evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers) num_samples = self.evalset.test_data.shape[0] cls_features = compute_features(tg_feature_model, evalloader, num_samples, num_features) norm_features = F.normalize(torch.from_numpy(cls_features), p=2, dim=1) cls_embedding = torch.mean(norm_features, dim=0) novel_embedding[cls_idx-iteration*self.args.nb_cl] = F.normalize(cls_embedding, p=2, dim=0) * average_old_embedding_norm tg_model.to(self.device) tg_model.fc.fc2.weight.data = novel_embedding.to(self.device) self.trainset.train_data = X_train.astype('uint8') 
            self.trainset.train_labels = map_Y_train
            if iteration > start_iter and self.args.rs_ratio > 0 and scale_factor > 1:
                print("Weights from sampling:", rs_sample_weights)
                index1 = np.where(rs_sample_weights > 1)[0]
                index2 = np.where(map_Y_train < iteration*self.args.nb_cl)[0]
                assert (index1 == index2).all()
                train_sampler = torch.utils.data.sampler.WeightedRandomSampler(rs_sample_weights, rs_num_samples)
                trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=False, sampler=train_sampler, num_workers=self.args.num_workers)
            else:
                trainloader = torch.utils.data.DataLoader(self.trainset, batch_size=self.args.train_batch_size, shuffle=True, num_workers=self.args.num_workers)
            # Update the test set with all classes seen so far
            self.testset.test_data = X_valid_cumul.astype('uint8')
            self.testset.test_labels = map_Y_valid_cumul
            testloader = torch.utils.data.DataLoader(self.testset, batch_size=self.args.test_batch_size, shuffle=False, num_workers=self.args.num_workers)
            print('Min and max of train labels: {}, {}'.format(min(map_Y_train), max(map_Y_train)))
            print('Min and max of valid labels: {}, {}'.format(min(map_Y_valid_cumul), max(map_Y_valid_cumul)))
            # Set the checkpoint name
            ckp_name = osp.join(self.save_path, 'run_{}_iteration_{}_model.pth'.format(iteration_total, iteration))
            print('Checkpoint name:', ckp_name)
            # Resume the saved models or train new ones
            if iteration == start_iter and self.args.resume_fg:
                print("Loading first group models from checkpoint")
                tg_model = torch.load(self.args.ckpt_dir_fg)
            elif self.args.resume and os.path.exists(ckp_name):
                print("Loading models from checkpoint")
                tg_model = torch.load(ckp_name)
            else:
                if iteration > start_iter:
                    ref_model = ref_model.to(self.device)
                    # Freeze the old-class head fc1 by giving it lr=0 and no weight decay
                    ignored_params = list(map(id, tg_model.fc.fc1.parameters()))
                    base_params = filter(lambda p: id(p) not in ignored_params, tg_model.parameters())
                    base_params = filter(lambda p: p.requires_grad, base_params)
                    tg_params_new = [
                        {'params': base_params, 'lr': self.args.base_lr2, 'weight_decay': self.args.custom_weight_decay},
                        {'params': tg_model.fc.fc1.parameters(), 'lr': 0, 'weight_decay': 0}]
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(tg_params_new, lr=self.args.base_lr2, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                else:
                    tg_params = tg_model.parameters()
                    tg_model = tg_model.to(self.device)
                    tg_optimizer = optim.SGD(tg_params, lr=self.args.base_lr1, momentum=self.args.custom_momentum, weight_decay=self.args.custom_weight_decay)
                # The same schedule is used in both the first and the incremental phases
                tg_lr_scheduler = lr_scheduler.MultiStepLR(tg_optimizer, milestones=self.lr_strat, gamma=self.args.lr_factor)
                print("Incremental train")
                tg_model = incremental_train_and_eval(self.args.epochs, tg_model, ref_model, tg_optimizer, tg_lr_scheduler, trainloader, testloader, iteration, start_iter, cur_lamda, self.args.dist, self.args.K, self.args.lw_mr)
                torch.save(tg_model, ckp_name)
            # Process the exemplars
            if self.args.fix_budget:
                # Fixed total budget: nb_protos*100 exemplars shared by all seen classes
                nb_protos_cl = int(np.ceil(self.args.nb_protos*100./self.args.nb_cl/(iteration+1)))
            else:
                nb_protos_cl = self.args.nb_protos
            tg_feature_model = nn.Sequential(*list(tg_model.children())[:-1])
            num_features = tg_model.fc.in_features
            # Select the exemplars using the herding strategy
            print('Updating exemplars')
            for iter_dico in range(last_iter*self.args.nb_cl, (iteration+1)*self.args.nb_cl):
                # Map the candidate exemplars into the feature space, projected on the L2 sphere
                self.evalset.test_data = prototypes[iter_dico].astype('uint8')
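                # The loop below is greedy herding selection (Welling, 2009), as used in
                # iCaRL: samples are picked one by one so that the running mean of the
                # selected features tracks the true class mean mu. Simplified sketch
                # (the real loop also tolerates re-picks, bounded by iter_herding_eff):
                #   w = mu.copy(); selected = []
                #   while len(selected) < nb_protos_cl:
                #       k = np.argmax(w @ D)      # best match to the current residual
                #       selected.append(k)        # selection order = herding rank
                #       w += mu - D[:, k]         # update the residual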
                self.evalset.test_labels = np.zeros(self.evalset.test_data.shape[0])  # zero labels
                evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                num_samples = self.evalset.test_data.shape[0]
                mapped_prototypes = compute_features(tg_feature_model, evalloader, num_samples, num_features)
                D = mapped_prototypes.T
                D = D/np.linalg.norm(D, axis=0)
                # Rank the potential exemplars by herding
                mu = np.mean(D, axis=1)
                index1 = int(iter_dico/self.args.nb_cl)
                index2 = iter_dico % self.args.nb_cl
                alpha_dr_herding[index1, :, index2] = alpha_dr_herding[index1, :, index2]*0
                w_t = mu
                iter_herding = 0
                iter_herding_eff = 0
                while not (np.sum(alpha_dr_herding[index1, :, index2] != 0) == min(nb_protos_cl, 500)) and iter_herding_eff < 1000:
                    tmp_t = np.dot(w_t, D)
                    ind_max = np.argmax(tmp_t)
                    iter_herding_eff += 1
                    if alpha_dr_herding[index1, ind_max, index2] == 0:
                        # Store the (1-based) herding rank of the newly selected sample
                        alpha_dr_herding[index1, ind_max, index2] = 1 + iter_herding
                        iter_herding += 1
                    w_t = w_t + mu - D[:, ind_max]
            # Prepare the protoset
            X_protoset_cumuls = []
            Y_protoset_cumuls = []
            # Calculate the mean of exemplars
            print('Computing the mean of exemplars')
            # Shape (features, classes, 2): slot 0 is the exemplar mean, slot 1 the
            # full-class mean; for ResNet-32 on CIFAR-100 this is (64, 100, 2)
            class_means = np.zeros((num_features, self.args.num_classes, 2))
            for iteration2 in range(iteration+1):
                for iter_dico in range(self.args.nb_cl):
                    current_cl = order[range(iteration2*self.args.nb_cl, (iteration2+1)*self.args.nb_cl)]
                    # Collect the normalized features of each class
                    self.evalset.test_data = prototypes[iteration2*self.args.nb_cl+iter_dico].astype('uint8')
                    self.evalset.test_labels = np.zeros(self.evalset.test_data.shape[0])  # zero labels
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                    num_samples = self.evalset.test_data.shape[0]
                    mapped_prototypes = compute_features(tg_feature_model, evalloader, num_samples, num_features)
                    D = mapped_prototypes.T
                    D = D/np.linalg.norm(D, axis=0)
                    # Repeat with the last image axis reversed as a second ("flipped") view
                    self.evalset.test_data = prototypes[iteration2*self.args.nb_cl+iter_dico][:, :, :, ::-1].astype('uint8')
                    evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
                    mapped_prototypes2 = compute_features(tg_feature_model, evalloader, num_samples, num_features)
                    D2 = mapped_prototypes2.T
                    D2 = D2/np.linalg.norm(D2, axis=0)
                    # iCaRL: keep the samples whose herding rank is within the budget
                    alph = alpha_dr_herding[iteration2, :, iter_dico]
                    alph = (alph > 0)*(alph < nb_protos_cl+1)*1.
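                    # alph is now a 0/1 mask over the dictionary: 1 marks a sample whose
                    # herding rank falls in {1, ..., nb_protos_cl}; e.g. with
                    # nb_protos_cl=20, exactly the 20 earliest-picked samples per class
                    # are kept as exemplars.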
                    X_protoset_cumuls.append(prototypes[iteration2*self.args.nb_cl+iter_dico, np.where(alph == 1)[0]])
                    Y_protoset_cumuls.append(order[iteration2*self.args.nb_cl+iter_dico]*np.ones(len(np.where(alph == 1)[0])))
                    alph = alph/np.sum(alph)
                    # Exemplar mean (iCaRL), averaged over the two views
                    class_means[:, current_cl[iter_dico], 0] = (np.dot(D, alph)+np.dot(D2, alph))/2
                    class_means[:, current_cl[iter_dico], 0] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 0])
                    # Nearest-neighbor upper bound: the mean over all samples of the class
                    alph = np.ones(dictionary_size)/dictionary_size
                    class_means[:, current_cl[iter_dico], 1] = (np.dot(D, alph)+np.dot(D2, alph))/2
                    class_means[:, current_cl[iter_dico], 1] /= np.linalg.norm(class_means[:, current_cl[iter_dico], 1])
            torch.save(class_means, osp.join(self.save_path, 'run_{}_iteration_{}_class_means.pth'.format(iteration_total, iteration)))
            # Mnemonics
            print("Mnemonics training")
            # Initialize the mnemonics exemplars from the herded images
            self.mnemonics, self.mnemonics_lrs, self.mnemonics_label = self.MnemonicsTrainer.mnemonics_init_with_images_cifar(iteration, start_iter, self.args.nb_cl_fg, self.args.nb_cl, X_protoset_cumuls, Y_protoset_cumuls, order_list, self.device)
            # Train the mnemonics
            self.mnemonics, self.mnemonics_lrs, self.mnemonics_label = self.MnemonicsTrainer.mnemonics_train(tg_model, trainloader, testloader, iteration, start_iter, self.device)
            # Process the mnemonics and set them as the exemplars
            X_protoset_cumuls, Y_protoset_cumuls = process_mnemonics(X_protoset_cumuls, Y_protoset_cumuls, self.mnemonics, self.mnemonics_label, order_list)
            # Class means of the classes seen so far
            current_means = class_means[:, order[range(0, (iteration+1)*self.args.nb_cl)]]
            is_start_iteration = (iteration == start_iter)
            # Calculate the validation error on the first batch of classes
            map_Y_valid_ori = np.array([order_list.index(i) for i in Y_valid_ori])
            print('Computing accuracy on the original batch of classes')
            self.evalset.test_data = X_valid_ori.astype('uint8')
            self.evalset.test_labels = map_Y_valid_ori
            evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
            ori_acc, fast_fc = compute_accuracy(tg_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
            top1_acc_list_ori[iteration, :, iteration_total] = np.array(ori_acc).T
            self.train_writer.add_scalar('ori_acc/cosine', float(ori_acc[0]), iteration)
            self.train_writer.add_scalar('ori_acc/nearest_neighbor', float(ori_acc[1]), iteration)
            self.train_writer.add_scalar('ori_acc/fc_finetune', float(ori_acc[3]), iteration)
            # Calculate the validation error on the cumulative set of seen classes
            map_Y_valid_cumul = np.array([order_list.index(i) for i in Y_valid_cumul])
            print('Computing cumulative accuracy')
            self.evalset.test_data = X_valid_cumul.astype('uint8')
            self.evalset.test_labels = map_Y_valid_cumul
            evalloader = torch.utils.data.DataLoader(self.evalset, batch_size=self.args.eval_batch_size, shuffle=False, num_workers=self.args.num_workers)
            cumul_acc, _ = compute_accuracy(tg_model, tg_feature_model, current_means, X_protoset_cumuls, Y_protoset_cumuls, evalloader, order_list, is_start_iteration=is_start_iteration, fast_fc=fast_fc, maml_lr=self.args.maml_lr, maml_epoch=self.args.maml_epoch)
            top1_acc_list_cumul[iteration, :, iteration_total] = np.array(cumul_acc).T
            self.train_writer.add_scalar('cumul_acc/cosine', float(cumul_acc[0]), iteration)
            self.train_writer.add_scalar('cumul_acc/nearest_neighbor', float(cumul_acc[1]), iteration)
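            # Note: compute_accuracy is assumed to return four top-1 scores per call;
            # indices [0] (cosine classifier), [1] (nearest-exemplar-mean), and [3]
            # (finetuned fc) are logged to tensorboard, while index [2] is only kept
            # in the saved arrays (hence their second dimension of 4).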
            self.train_writer.add_scalar('cumul_acc/fc_finetune', float(cumul_acc[3]), iteration)
        # Save the accuracy records for this run and close tensorboard
        torch.save(top1_acc_list_ori, osp.join(self.save_path, 'run_{}_top1_acc_list_ori.pth'.format(iteration_total)))
        torch.save(top1_acc_list_cumul, osp.join(self.save_path, 'run_{}_top1_acc_list_cumul.pth'.format(iteration_total)))
    self.train_writer.close()
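    # Artifacts written to save_path by one call of train(), per run r and phase i:
    #   seed_{seed}_{dataset}_order_run_{r}.pkl  the class order
    #   run_{r}_iteration_{i}_model.pth          the full tg_model checkpoint
    #   run_{r}_iteration_{i}_class_means.pth    class means for the NME evaluation
    #   run_{r}_top1_acc_list_ori.pth            accuracies on the first class batch
    #   run_{r}_top1_acc_list_cumul.pth          cumulative accuracies; both arrays
    #                                            have shape (num_phases, 4, nb_runs)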