def naive_ensembling(args, networks, test_loader):
    # simply average the weights of the networks
    if args.width_ratio != 1:
        print("Unfortunately naive ensembling can't work if models are not of the same shape!")
        return -1, None

    weights = [(1 - args.ensemble_step), args.ensemble_step]
    avg_pars = get_avg_parameters(networks, weights)
    ensemble_network = get_model_from_name(args)

    # put on GPU
    if args.gpu_id != -1:
        ensemble_network = ensemble_network.cuda(args.gpu_id)

    # check the test performance of the method before ensembling
    log_dict = {}
    log_dict['test_losses'] = []
    # log_dict['test_counter'] = [i * len(train_loader.dataset) for i in range(args.n_epochs + 1)]
    routines.test(args, ensemble_network, test_loader, log_dict)

    # set the weights of the ensembled network
    for idx, (name, param) in enumerate(ensemble_network.state_dict().items()):
        ensemble_network.state_dict()[name].copy_(avg_pars[idx].data)

    # check the test performance of the method after ensembling
    log_dict = {}
    log_dict['test_losses'] = []
    # log_dict['test_counter'] = [i * len(train_loader.dataset) for i in range(args.n_epochs + 1)]
    return routines.test(args, ensemble_network, test_loader, log_dict), ensemble_network
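# NOTE: get_avg_parameters used above is defined elsewhere in this codebase. The sketch
# below is only illustrative (not the repo's actual implementation); it shows one plausible
# version consistent with the usage above, i.e. a convex combination of the networks'
# parameters with weights (1 - args.ensemble_step, args.ensemble_step).
def get_avg_parameters_sketch(networks, weights=None):
    # iterate over corresponding parameter tensors of all networks, layer by layer
    avg_pars = []
    for par_group in zip(*[list(net.parameters()) for net in networks]):
        if weights is not None:
            # weighted sum of the corresponding tensors across networks
            avg_par = sum(w * par for w, par in zip(weights, par_group))
        else:
            # plain average when no weights are given
            avg_par = sum(par_group) / len(par_group)
        avg_pars.append(avg_par)
    return avg_pars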
def distillation(args, teachers, student, train_loader, test_loader, device):
    # Inspiration: https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/evaluate.py
    for teacher in teachers:
        teacher.eval()

    optimizer = optim.SGD(student.parameters(), lr=args.learning_rate, momentum=args.momentum)

    log_dict = {}
    log_dict['train_losses'] = []
    log_dict['train_counter'] = []
    log_dict['test_losses'] = []

    accuracies = []
    accuracies.append(routines.test(args, student, test_loader, log_dict))

    for epoch_idx in range(0, args.dist_epochs):
        student.train()
        for batch_idx, (data_batch, labels_batch) in enumerate(train_loader):
            # move to GPU if available
            if args.gpu_id != -1:
                data_batch, labels_batch = data_batch.to(device), labels_batch.to(device)

            # compute mean teacher output
            teacher_outputs = []
            for teacher in teachers:
                teacher_outputs.append(teacher(data_batch, disable_logits=True))
            teacher_outputs = torch.stack(teacher_outputs)
            teacher_outputs = teacher_outputs.mean(dim=0)

            optimizer.zero_grad()
            # get student output
            student_output = student(data_batch, disable_logits=True)

            # knowledge distillation loss
            loss = loss_fn_kd(student_output, labels_batch, teacher_outputs, args)
            loss.backward()
            # update student
            optimizer.step()

            if batch_idx % args.log_interval == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch_idx, batch_idx * len(data_batch), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))
                log_dict['train_losses'].append(loss.item())
                # counter: samples seen so far (offset within epoch + completed epochs)
                log_dict['train_counter'].append(
                    batch_idx * len(data_batch) + epoch_idx * len(train_loader.dataset))

        accuracies.append(routines.test(args, student, test_loader, log_dict))

    return student, accuracies
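# NOTE: loss_fn_kd used in distillation() above is imported from elsewhere in the repo;
# the sketch below is only illustrative of the standard Hinton-style knowledge-distillation
# loss it is expected to compute (cf. the peterliht repo linked above). The attribute names
# args.dist_temperature and args.dist_alpha are assumptions, and both the student and
# teacher outputs are assumed to be raw logits.
import torch.nn.functional as F

def loss_fn_kd_sketch(student_logits, labels, teacher_logits, args):
    T = args.dist_temperature   # softening temperature (assumed attribute name)
    alpha = args.dist_alpha     # weight on the distillation term (assumed attribute name)
    # KL divergence between the temperature-softened student and teacher distributions,
    # scaled by T^2 as in Hinton et al.
    kd_term = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                       F.softmax(teacher_logits / T, dim=1),
                       reduction='batchmean') * (T * T)
    # standard cross-entropy against the ground-truth labels
    ce_term = F.cross_entropy(student_logits, labels)
    return alpha * kd_term + (1.0 - alpha) * ce_term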
def update_model(args, model, new_params, test=False, test_loader=None, reversed=False, idx=-1):
    updated_model = get_model_from_name(args, idx=idx)
    if args.gpu_id != -1:
        updated_model = updated_model.cuda(args.gpu_id)

    layer_idx = 0
    model_state_dict = model.state_dict()

    print("len of model_state_dict is ", len(model_state_dict.items()))
    print("len of new_params is ", len(new_params))

    for key, value in model_state_dict.items():
        print("updated parameters for layer ", key)
        model_state_dict[key] = new_params[layer_idx]
        layer_idx += 1
        if layer_idx == len(new_params):
            break

    updated_model.load_state_dict(model_state_dict)

    if test:
        log_dict = {}
        log_dict['test_losses'] = []
        final_acc = routines.test(args, updated_model, test_loader, log_dict)
        print("accuracy after update is ", final_acc)
    else:
        final_acc = None

    return updated_model, final_acc
def recheck_accuracy(args, models, test_loader):
    # Additional flag of recheck_acc to supplement the legacy flag recheck_cifar
    if args.recheck_cifar or args.recheck_acc:
        recheck_accuracies = []
        for model in models:
            log_dict = {}
            log_dict['test_losses'] = []
            recheck_accuracies.append(routines.test(args, model, test_loader, log_dict))
        print("Rechecked accuracies are ", recheck_accuracies)
def get_network_from_param_list(args, param_list, test_loader):
    print("using independent method")
    new_network = get_model_from_name(args, idx=1)
    if args.gpu_id != -1:
        new_network = new_network.cuda(args.gpu_id)

    # check the test performance of the network before
    log_dict = {}
    log_dict['test_losses'] = []
    routines.test(args, new_network, test_loader, log_dict)

    # set the weights of the new network
    # print("before", new_network.state_dict())
    print("len of model parameters and avg aligned layers is ",
          len(list(new_network.parameters())), len(param_list))
    assert len(list(new_network.parameters())) == len(param_list)

    layer_idx = 0
    model_state_dict = new_network.state_dict()

    print("len of model_state_dict is ", len(model_state_dict.items()))
    print("len of param_list is ", len(param_list))

    for key, value in model_state_dict.items():
        model_state_dict[key] = param_list[layer_idx]
        layer_idx += 1

    new_network.load_state_dict(model_state_dict)

    # check the test performance of the network after
    log_dict = {}
    log_dict['test_losses'] = []
    acc = routines.test(args, new_network, test_loader, log_dict)
    return acc, new_network
def test_model(args, model, test_loader):
    log_dict = {}
    log_dict['test_losses'] = []
    return routines.test(args, model, test_loader, log_dict)
        else:
            model, accuracy, local_accuracy = routines.get_pretrained_model(
                args,
                os.path.join(ensemble_dir, 'model_{}/{}.checkpoint'.format(idx, args.ckpt_type)),
                data_separated=True, idx=idx
            )
        models.append(model)
        accuracies.append(accuracy)
        local_accuracies.append(local_accuracy)
    print("Done loading all the models")

    # Additional flag of recheck_acc to supplement the legacy flag recheck_cifar
    if args.recheck_cifar or args.recheck_acc:
        recheck_accuracies = []
        for model in models:
            log_dict = {}
            log_dict['test_losses'] = []
            recheck_accuracies.append(routines.test(args, model, test_loader, log_dict))
        print("Rechecked accuracies are ", recheck_accuracies)
    # print('checking named modules of model0 for use in compute_activations!', list(models[0].named_modules()))
else:
    # get dataloaders
    print("------- Obtain dataloaders -------")
    train_loader, test_loader = get_dataloader(args)

    if args.partition_type == 'labels':
        print("------- Split dataloaders by labels -------")
        choice = [int(x) for x in args.choice.split()]
        (trailo_a, teslo_a), (trailo_b, teslo_b), other = partition.split_mnist_by_labels(
            args, train_loader, test_loader, choice=choice)

    print("------- Training independent models -------")
    models, accuracies, local_accuracies = routines.train_data_separated_models(args, [trailo_a, trailo_b],
                ensemble_dir, 'model_{}/{}.checkpoint'.format(idx, args.ckpt_type)), idx=idx)
        models.append(model)
        accuracies.append(accuracy)
    print("Done loading all the models")

    # Additional flag of recheck_acc to supplement the legacy flag recheck_cifar
    if args.recheck_cifar or args.recheck_acc:
        recheck_accuracies = []
        for model in models:
            log_dict = {}
            log_dict['test_losses'] = []
            recheck_accuracies.append(routines.test(args, model, test_loader, log_dict))
        print("Rechecked accuracies are ", recheck_accuracies)
    # print('checking named modules of model0 for use in compute_activations!', list(models[0].named_modules()))
    # print('what about named parameters of model0 for use in compute_activations!', [tupl[0] for tupl in list(models[0].named_parameters())])
else:
    # get dataloaders
    print("------- Obtain dataloaders -------")
    train_loader, test_loader = get_dataloader(args)
    retrain_loader, _ = get_dataloader(args, no_randomness=args.no_random_trainloaders)

    print("------- Training independent models -------")
    models, accuracies = routines.train_models(args, train_loader,