def main_worker(gpu, ngpus_per_node, opt):
    global best_acc, total_time
    opt.gpu = int(gpu)
    opt.gpu_id = int(gpu)

    if opt.gpu is not None:
        print("Use GPU: {} for training".format(opt.gpu))

    if opt.multiprocessing_distributed:
        # Only one node now.
        opt.rank = gpu
        dist_backend = 'nccl'
        dist.init_process_group(backend=dist_backend,
                                init_method=opt.dist_url,
                                world_size=opt.world_size,
                                rank=opt.rank)
        opt.batch_size = int(opt.batch_size / ngpus_per_node)
        opt.num_workers = int(
            (opt.num_workers + ngpus_per_node - 1) / ngpus_per_node)

    if opt.deterministic:
        torch.manual_seed(12345)
        cudnn.deterministic = True
        cudnn.benchmark = False
        numpy.random.seed(12345)

    class_num_map = {
        'cifar100': 100,
        'imagenet': 1000,
        'imagenette': 10,
    }
    if opt.dataset not in class_num_map:
        raise NotImplementedError(opt.dataset)
    n_cls = class_num_map[opt.dataset]

    # model
    model_t = load_teacher(opt.path_t, n_cls, opt.gpu, opt)
    module_args = {'num_classes': n_cls}
    model_s = model_dict[opt.model_s](**module_args)

    if opt.dataset == 'cifar100':
        data = torch.randn(2, 3, 32, 32)
    elif opt.dataset == 'imagenet':
        data = torch.randn(2, 3, 224, 224)

    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'semckd':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = SemCKDLoss()
        self_attention = SelfA(
            len(feat_s) - 2, len(feat_t) - 2, opt.batch_size, s_n, t_n)
        module_list.append(self_attention)
        trainable_list.append(self_attention)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = 50000
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'irg':
        criterion_kd = IRGLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'hkd':
        criterion_kd = HKDLoss(init_weight=opt.hkd_initial_weight,
                               decay=opt.hkd_decay)
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    module_list.append(model_t)

    if torch.cuda.is_available():
        # For multiprocessing distributed, the DistributedDataParallel
        # constructor should always set the single device scope; otherwise,
        # DistributedDataParallel will use all available devices.
        if opt.multiprocessing_distributed:
            if opt.gpu is not None:
                torch.cuda.set_device(opt.gpu)
                module_list.cuda(opt.gpu)
                distributed_modules = []
                for module in module_list:
                    DDP = torch.nn.parallel.DistributedDataParallel
                    distributed_modules.append(
                        DDP(module, device_ids=[opt.gpu]))
                module_list = distributed_modules
                criterion_list.cuda(opt.gpu)
            else:
                print('multiprocessing_distributed must be used with a specified gpu id')
        else:
            criterion_list.cuda()
            module_list.cuda()
        if not opt.deterministic:
            cudnn.benchmark = True

    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # dataloader
    if opt.dataset == 'cifar100':
        if opt.distill in ['crd']:
            train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                k=opt.nce_k,
                mode=opt.mode)
        else:
            train_loader, val_loader = get_cifar100_dataloaders(
                batch_size=opt.batch_size, num_workers=opt.num_workers)
    elif opt.dataset in imagenet_list:
        if opt.dali is None:
            train_loader, val_loader, train_sampler = get_imagenet_dataloader(
                dataset=opt.dataset,
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                multiprocessing_distributed=opt.multiprocessing_distributed)
        else:
            train_loader, val_loader = get_dali_data_loader(opt)
    else:
        raise NotImplementedError(opt.dataset)

    if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
        logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    if not opt.skip_validation:
        # validate teacher accuracy
        teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
        if opt.dali is not None:
            val_loader.reset()
        if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
            print('teacher accuracy: ', teacher_acc)
    else:
        print('Skipping teacher validation.')

    # routine
    for epoch in range(1, opt.epochs + 1):
        torch.cuda.empty_cache()
        if opt.multiprocessing_distributed:
            if opt.dali is None:
                train_sampler.set_epoch(epoch)

        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_acc_top5, train_loss, data_time = train(
            epoch, train_loader, module_list, criterion_list, optimizer, opt)
        time2 = time.time()

        if opt.multiprocessing_distributed:
            metrics = torch.tensor(
                [train_acc, train_acc_top5, train_loss, data_time]).cuda(
                    opt.gpu, non_blocking=True)
            reduced = reduce_tensor(
                metrics, opt.world_size if 'world_size' in opt else 1)
            train_acc, train_acc_top5, train_loss, data_time = reduced.tolist()

        if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
            print(' * Epoch {}, GPU {}, Acc@1 {:.3f}, Acc@5 {:.3f}, Time {:.2f}, Data {:.2f}'
                  .format(epoch, opt.gpu, train_acc, train_acc_top5,
                          time2 - time1, data_time))
            logger.log_value('train_acc', train_acc, epoch)
            logger.log_value('train_loss', train_loss, epoch)

        print('GPU %d validating' % (opt.gpu))
        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        if opt.dali is not None:
            train_loader.reset()
            val_loader.reset()

        if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
            print(' ** Acc@1 {:.3f}, Acc@5 {:.3f}'.format(
                test_acc, test_acc_top5))
            logger.log_value('test_acc', test_acc, epoch)
            logger.log_value('test_loss', test_loss, epoch)
            logger.log_value('test_acc_top5', test_acc_top5, epoch)

            # save the best model
            if test_acc > best_acc:
                best_acc = test_acc
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'best_acc': best_acc,
                }
                if opt.distill == 'semckd':
                    state['attention'] = trainable_list[-1].state_dict()
                save_file = os.path.join(opt.save_folder,
                                         '{}_best.pth'.format(opt.model_s))
                test_metrics = {
                    'test_loss': test_loss,
                    'test_acc': test_acc,
                    'test_acc_top5': test_acc_top5,
                    'epoch': epoch,
                }
                save_dict_to_json(
                    test_metrics,
                    os.path.join(opt.save_folder, "test_best_metrics.json"))
                print('saving the best model!')
                torch.save(state, save_file)

    if not opt.multiprocessing_distributed or opt.rank % ngpus_per_node == 0:
        # This best accuracy is only for printing purposes.
        print('best accuracy:', best_acc)

        # save parameters
        save_state = {k: v for k, v in opt._get_kwargs()}
        # No. parameters (M)
        num_params = (sum(p.numel() for p in model_s.parameters()) / 1000000.0)
        save_state['Total params'] = num_params
        save_state['Total time'] = (time.time() - total_time) / 3600.0
        params_json_path = os.path.join(opt.save_folder, "parameters.json")
        save_dict_to_json(save_state, params_json_path)
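# reduce_tensor is called in the distributed branch above to average the
# training metrics across DDP processes, but it is not defined in this section.
# A minimal sketch of a common implementation, assuming torch.distributed has
# already been initialized via init_process_group; the repository's actual
# helper may differ.
import torch
import torch.distributed as dist


def reduce_tensor(tensor, world_size):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)  # sum the values from every rank
    rt /= world_size                           # turn the sum into a mean
    return rt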
def main():
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    if opt.dataset == 'cifar100':
        if opt.distill in ['crd']:
            train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                k=opt.nce_k,
                mode=opt.mode,
                data_path=opt.datapath)
        else:
            train_loader, val_loader, n_data = get_cifar100_dataloaders(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                is_instance=True,
                data_path=opt.datapath)
        n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, logger, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init,
             train_loader, logger, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, logger, opt)
        # classification training
        pass
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_loss', test_loss, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'best_acc': best_acc,
            }
            save_file = os.path.join(opt.save_folder,
                                     '{}_best.pth'.format(opt.model_s))
            print('saving the best model!')
            torch.save(state, save_file)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'accuracy': test_acc,
            }
            save_file = os.path.join(
                opt.save_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # This best accuracy is only for printing purposes.
    # The results reported in the paper/README are from the last epoch.
    print('best accuracy:', best_acc)

    # save model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
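# adjust_learning_rate is called at the top of every epoch above but is not
# defined in this section. The sketch below shows the usual step-decay rule
# such scripts use, assuming opt carries learning_rate, lr_decay_epochs, and
# lr_decay_rate; treat the exact schedule as an assumption rather than the
# repository's definitive code.
import numpy as np


def adjust_learning_rate(epoch, opt, optimizer):
    """Multiply the base learning rate by lr_decay_rate once per passed milestone."""
    steps = np.sum(epoch > np.asarray(opt.lr_decay_epochs))
    if steps > 0:
        new_lr = opt.learning_rate * (opt.lr_decay_rate ** steps)
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lr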
    print('==> Load pretrained model from', args.pretrained, '...')
    pretrained_model = torch.load(args.pretrained)
    best_acc = pretrained_model['best_acc']
    print(pretrained_model['best_acc'])
    model.load_state_dict(pretrained_model['state_dict'])
    print(model)

    module_list = nn.ModuleList([])
    module_list.append(model)

    # define solver and criterion
    base_lr = float(args.lr)

    if args.distill == 'kd':
        criterion_kd = DistillKL(args.kd_T)
    elif args.distill == 'attention':
        criterion_kd = Attention()
    elif args.distill == 'nst':
        criterion_kd = NSTLoss()
    elif args.distill == 'similarity':
        criterion_kd = Similarity()
    elif args.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif args.distill == 'pkt':
        criterion_kd = PKT()
    elif args.distill == 'kdsvd':
        criterion_kd = KDSVD()
    else:
        raise NotImplementedError(args.distill)
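# Optional robustness note for the checkpoint load above: if the file was
# saved from a CUDA run but is restored on a CPU-only machine, torch.load
# needs an explicit map_location. A small, hedged variant of the same load
# (args.pretrained and model as in the snippet above):
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
pretrained_model = torch.load(args.pretrained, map_location=device)
model.load_state_dict(pretrained_model['state_dict'])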
def main():
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    if opt.dataset == 'cifar100':
        if opt.distill in ['crd']:
            train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                k=opt.nce_k,
                mode=opt.mode)
        else:
            train_loader, val_loader, n_data = get_cifar100_dataloaders(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                n_test=100,
                is_instance=True)
        n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, logger, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init,
             train_loader, logger, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, logger, opt)
        # classification training
        pass
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # embed the watermark into the teacher model
    wm_loader = None
    if opt.watermark == "usenix":
        # paper: https://www.usenix.org/system/files/conference/usenixsecurity18/sec18-adi.pdf
        # Define an optimizer for the teacher
        trainable_list_t = nn.ModuleList([model_t])
        optimizer_t = optim.SGD(trainable_list_t.parameters(),
                                lr=opt.learning_rate,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay)
        wm_loader = get_usenixwm_dataloader()

        print("## Train data + Watermark Val Acc")
        teacher_acc, _, _ = validate(val_loader, model_t,
                                     nn.CrossEntropyLoss(), opt)

        print("==> embedding USENIX watermark...")
        max_epochs = 250  # Cutoff val
        for epoch in range(1, max_epochs + 1):
            set_learning_rate(2e-4, optimizer_t)
            top1, top5 = train_vanilla(epoch, wm_loader, model_t,
                                       nn.CrossEntropyLoss(), optimizer_t, opt)
            if top1 >= 97:
                break

        print("## Watermark val")
        teacher_acc, _, _ = validate(wm_loader, model_t,
                                     nn.CrossEntropyLoss(), opt)
        print("==> done")

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # If teacher and student models match, copy over weights for initialization
    if (opt.init_strat == "noise") and (type(model_t) == type(model_s)):
        print("==> Copying teacher's weights to student with a weight of {}".format(
            opt.init_inv_corr))
        model_s.load_state_dict(model_t.state_dict())
        with torch.no_grad():
            for param in model_s.parameters():
                if torch.cuda.is_available():
                    noise = (torch.randn(param.size()) * opt.init_inv_corr).cuda()
                    param.add_(noise)
                else:
                    param.add_(torch.randn(param.size()) * opt.init_inv_corr)
        student_acc, _, _ = validate(val_loader, model_s, criterion_cls, opt)
        print('student accuracy: ', student_acc)
        print("==> done")

    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        if wm_loader is not None:
            print("==> wm retention")
            wm_top1, wm_top5, wm_loss = validate(wm_loader, model_s,
                                                 criterion_cls, opt)
            logger.log_value('wm_ret', wm_top1, epoch)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_loss', test_loss, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'best_acc': best_acc,
            }
            save_file = os.path.join(opt.save_folder,
                                     '{}_best.pth'.format(opt.model_s))
            print('saving the best model!')
            torch.save(state, save_file)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'accuracy': test_acc,
            }
            save_file = os.path.join(
                opt.save_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # This best accuracy is only for printing purposes.
    # The results reported in the paper/README are from the last epoch.
    print('best accuracy:', best_acc)

    # save model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
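# set_learning_rate is called in the watermark-embedding loop above but is not
# defined in this section. A minimal sketch of the presumed behaviour: force
# every parameter group of the given optimizer to a fixed learning rate.
def set_learning_rate(lr, optimizer):
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr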
def main():
    opt = parse_option()

    seed = opt.seed
    if opt.distribution:
        seed += opt.local_rank
        torch.cuda.set_device(opt.local_rank)
    set_seed(seed)

    best_acc = 0
    best_acc_top5 = 0
    best_epoch = 0

    # tensorboard logger
    if opt.distribution:
        dist.barrier()
    if not opt.distribution or opt.local_rank == 0:
        writer = SummaryWriter(log_dir=opt.tb_folder, flush_secs=2)

    # dataloader
    if opt.dataset == 'cifar100':
        train_loader, val_loader = get_cifar100_dataloaders(
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            distribution=opt.distribution)
        n_cls = 100
    elif opt.dataset == 'cifar10':
        train_loader, val_loader = get_cifar10_dataloaders(
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            distribution=opt.distribution)
        n_cls = 10
    else:
        raise NotImplementedError(opt.dataset)

    # model
    set_seed(seed)
    model_t_list = load_teacher_list(n_cls, opt.teacher_model_name,
                                     opt.model_t_name)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    # To get the feature maps' shapes
    data = torch.randn(2, 3, 32, 32)
    feat_t_list = []
    model_s.eval()
    for model_t in model_t_list:
        model_t.eval()
    for model_t in model_t_list:
        feat_t, _ = model_t(data, is_feat=True)
        feat_t_list.append(feat_t)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_KLdiv = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        for i, feat_t in enumerate(feat_t_list):
            regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                                feat_t[opt.hint_layer].shape)
            module_list.append(regress_s)
            trainable_list.append(regress_s)
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_KLdiv)
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay,
                          nesterov=opt.nesterov)

    # append teachers after optimizer to avoid weight_decay
    module_list.extend(model_t_list)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = False

    if opt.distribution:
        module_list = [
            DDP(model,
                device_ids=[opt.local_rank],
                output_device=opt.local_rank) for model in module_list
        ]

    # validate teacher accuracy
    if opt.show_teacher_acc:
        if opt.teacher_num > 1:
            teacher_acc, teacher_acc_top5, _, teacher_acc_list = validate_multi(
                val_loader, model_t_list, criterion_cls, opt)
        else:
            model_t = model_t_list[0]
            teacher_acc, teacher_acc_top5, _ = validate(
                val_loader, model_t, criterion_cls, opt)
        print('teacher accuracy: ', teacher_acc, teacher_acc_top5)

    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt)

        # tensorboard
        if not opt.distribution or opt.local_rank == 0:
            writer.add_scalar('train_acc', train_acc, epoch)
            writer.add_scalar('train_loss', train_loss, epoch)
            writer.add_scalar('test_acc', test_acc, epoch)
            writer.add_scalar('test_loss', test_loss, epoch)
            writer.add_scalar('test_acc_top5', test_acc_top5, epoch)

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            best_acc_top5 = test_acc_top5
            best_epoch = epoch
            if not opt.distribution or opt.local_rank == 0:
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'best_acc': best_acc,
                    'best_acc_top5': best_acc_top5,
                }
                save_file = os.path.join(opt.save_folder,
                                         '{}_best.pth'.format(opt.model_s))
                print('saving the best model!')
                torch.save(state, save_file)

        # regular saving
        if not opt.distribution or opt.local_rank == 0:
            if epoch % opt.save_freq == 0:
                print('==> Saving...')
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'accuracy': test_acc,
                    'accuracy_top5': test_acc_top5,
                }
                save_file = os.path.join(
                    opt.save_folder,
                    'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
                torch.save(state, save_file)
                print('best accuracy:', best_acc, best_acc_top5)
                f = open(os.path.join(opt.save_folder, 'log.txt'), 'a+')
                f.write(f'best accuracy: {best_acc}, {best_acc_top5}')
                f.close()

    if not opt.distribution or opt.local_rank == 0:
        if opt.show_teacher_acc:
            print('teacher accuracy: ', teacher_acc, teacher_acc_top5)
        print('best accuracy:', best_acc, best_acc_top5, "at epoch", best_epoch)
        f = open(os.path.join(opt.save_folder, 'log.txt'), 'a+')
        f.write(' * best Acc@1 {top1:.3f} Acc@5 {top5:.3f} at epoch {best_epoch}\n'
                .format(top1=best_acc, top5=best_acc_top5, best_epoch=best_epoch))
        f.close()

        # save model
        state = {
            'opt': opt,
            'model': model_s.state_dict(),
        }
        save_file = os.path.join(opt.save_folder,
                                 '{}_last.pth'.format(opt.model_s))
        torch.save(state, save_file)
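# set_seed is called twice above but is not defined in this section. A minimal
# sketch of a typical implementation using the usual Python / NumPy / PyTorch
# seeding calls; the repository's own helper may seed additional sources.
import random
import numpy as np
import torch


def set_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)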
    def __init__(self):
        super(BaseModel, self).__init__()
        self.device = torch.device(
            'cuda:0' if torch.cuda.is_available() else 'cpu')
        self.distill_loss = DistillKL(4)
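# DistillKL is instantiated throughout these scripts (DistillKL(opt.kd_T),
# DistillKL(4) above) but its definition is not part of this section. The
# sketch below is the standard temperature-scaled KL formulation such a module
# usually wraps; treat it as a reference implementation, not the exact source.
import torch.nn as nn
import torch.nn.functional as F


class DistillKL(nn.Module):
    """KL divergence between temperature-softened student and teacher logits."""

    def __init__(self, T):
        super(DistillKL, self).__init__()
        self.T = T

    def forward(self, y_s, y_t):
        p_s = F.log_softmax(y_s / self.T, dim=1)  # student log-probabilities
        p_t = F.softmax(y_t / self.T, dim=1)      # teacher probabilities
        # scale by T^2 so gradient magnitudes stay comparable across temperatures
        return F.kl_div(p_s, p_t, reduction='batchmean') * (self.T ** 2)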
def distill(opt):
    # refine the opt arguments
    opt.model_path = './save/student_model'

    iterations = opt.lr_decay_epochs.split(',')
    opt.lr_decay_epochs = list([])
    for it in iterations:
        opt.lr_decay_epochs.append(int(it))

    opt.model_t = get_teacher_name(opt.path_t)
    opt.print_freq = int(50000 / opt.batch_size / opt.print_freq)

    opt.model_name = 'S:{}_T:{}_{}_{}/r:{}_a:{}_b:{}_{}_{}_{}_{}_lam:{}_alp:{}_augsize:{}_T:{}'.format(
        opt.model_s, opt.model_t, opt.dataset, opt.distill, opt.gamma,
        opt.alpha, opt.beta, opt.trial, opt.device, opt.seed, opt.aug_type,
        opt.aug_lambda, opt.aug_alpha, opt.aug_size, opt.kd_T)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)

    opt.learning_rate = 0.1 * opt.batch_size / 128
    # use a smaller learning rate for these student architectures
    if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
        opt.learning_rate = opt.learning_rate / 5
    print("learning rate is set to:", opt.learning_rate)

    best_acc = 0

    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    # dataloader
    if opt.dataset == 'cifar100':
        if opt.distill in ['crd']:
            train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                k=opt.nce_k,
                mode=opt.mode)
        else:
            train_loader, val_loader, n_data = get_cifar100_dataloaders(
                opt, is_instance=True)
        n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    # set the interval for testing
    opt.test_freq = int(50000 / opt.batch_size)

    # compute the number of epochs using the original cifar100 dataset size
    opt.lr_decay_epochs = list(
        int(i * 50000 / opt.aug_size) for i in opt.lr_decay_epochs)
    opt.epochs = int(opt.epochs * 50000 / opt.aug_size)
    print('Decay epochs: ', opt.lr_decay_epochs)
    print('Max epochs: ', opt.epochs)

    # set the device
    if torch.cuda.is_available():
        device = torch.device(opt.device)
    else:
        device = torch.device('cpu')

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)
    # print(model_s)
    print("Size of the teacher:", count_parameters(model_t))
    print("Size of the student:", count_parameters(model_s))

    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init,
             train_loader, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, opt)
        # classification training
        pass
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.to(device)
        criterion_list.to(device)
        cudnn.benchmark = True

    # set up warmup
    warmup_scheduler = WarmUpLR(optimizer,
                                len(train_loader) * opt.epochs_warmup)

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: %.2f \n' % (teacher_acc))

    # create logger
    logger = Logger(dir=opt.save_folder,
                    var_names=[
                        'Epoch', 'l_xent', 'l_kd', 'l_other', 'acc_train',
                        'acc_test', 'acc_test_best', 'lr'
                    ],
                    format=[
                        '%02d', '%.4f', '%.4f', '%.4f', '%.2f', '%.2f',
                        '%.2f', '%.6f'
                    ],
                    args=opt)

    total_t = 0
    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)

        time1 = time.time()
        best_acc, total_t = train(epoch, train_loader, val_loader, module_list,
                                  criterion_list, optimizer, opt, best_acc,
                                  logger, device, warmup_scheduler, total_t)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

    print('Best accuracy: %.2f \n' % (best_acc))
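# count_parameters is used above to report teacher/student sizes but is not
# defined in this section. A one-line sketch of the conventional helper
# (counting only trainable parameters); the repository's version may differ
# slightly, e.g. by reporting millions of parameters.
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)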
def main():
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    if os.path.exists('./save/log/'):
        pass
    else:
        os.mkdir('./save/log/')
    log = open(
        os.path.join('./save/log/',
                     'log_{}.txt'.format(opt.path_config[-11:-7])), 'w')
    print_log('save path : {}'.format("./save/"), log)

    # dataloader
    if opt.dataset == 'imagenet':
        if opt.distill in ['crd']:
            train_loader, val_loader, n_data = get_dataloader_sample(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                k=opt.nce_k,
                mode=opt.mode)
        else:
            train_loader, val_loader, n_data = get_imagenet_dataloader(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                is_instance=True)
        '''import torchvision.datasets as dset
        import torchvision.transforms as transforms
        data_folder = '/gdata/ImageNet2012'
        traindir = os.path.join(data_folder, 'train')
        valdir = os.path.join(data_folder, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])
        train_data1 = dset.ImageFolder(traindir, transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]))
        valid_data = dset.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]))
        train_loader = torch.utils.data.DataLoader(
            train_data1, batch_size=opt.batch_size, shuffle=True,
            pin_memory=True, num_workers=opt.workers)
        val_loader = torch.utils.data.DataLoader(
            valid_data, batch_size=opt.batch_size, shuffle=True,
            pin_memory=True, num_workers=opt.workers)
        n_data = len(train_data1)'''
        n_cls = 1000
    else:
        raise NotImplementedError(opt.dataset)
    print("##")

    # model
    model_t = EfficientNet.from_pretrained(
        'efficientnet-b0',
        weights_path='./pretrain_efficientNet/pretrain_efficientNet.pth')
    # torch.save(model_t.state_dict(), './pretrain_efficientNet/pretrain_efficientNet.pt')
    from proxyless_nas.jj import get_proxyless_model
    model_s = get_proxyless_model(net_config_path=opt.path_config)

    gpus = [0, 1, 2, 3]
    torch.cuda.set_device('cuda:{}'.format(gpus[0]))

    data = torch.randn(2, 3, 224, 224)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, logger, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init,
             train_loader, logger, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, logger, opt)
        # classification training
        pass
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt, log)
    print('teacher accuracy: ', teacher_acc)
    print_log("teacher accuracy:{}".format(teacher_acc), log)

    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")
        print_log("==> training...", log)

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list,
                                      criterion_list, optimizer, opt, log)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
        print_log('epoch {}, total time {:.2f}'.format(epoch, time2 - time1), log)

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s,
                                                      criterion_cls, opt, log)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_loss', test_loss, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'best_acc': best_acc,
            }
            save_file = os.path.join(opt.save_folder,
                                     '{}_best.pth'.format(opt.model_s))
            print('saving the best model!')
            print_log('saving the best model!', log)
            torch.save(state, save_file)

        # regular saving
        if epoch % opt.save_freq == 0 or epoch < 10:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'accuracy': test_acc,
            }
            save_file = os.path.join(
                opt.save_folder,
                'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # This best accuracy is only for printing purposes.
    # The results reported in the paper/README are from the last epoch.
    print('best accuracy:', best_acc)
    print_log('best accuracy:{}'.format(best_acc), log)

    # save model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder,
                             '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)

    log.close()
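# print_log is used throughout the script above to echo messages to both
# stdout and the opened log file; it is not defined in this section. A minimal
# sketch of the presumed helper:
def print_log(message, log_file):
    print(message)
    log_file.write('{}\n'.format(message))
    log_file.flush()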
def main():
    best_acc = 0

    opt = parse_option()

    torch.manual_seed(2021)
    torch.cuda.manual_seed(2021)
    torch.backends.cudnn.deterministic = True

    # dataloader
    if opt.distill in ['crd']:
        train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
            opt.data_path,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            k=opt.nce_k,
            mode=opt.mode,
            use_fake_data=opt.use_fake_data,
            fake_data_folder=opt.fake_data_path,
            nfake=opt.nfake)
    else:
        train_loader, val_loader, n_data = get_cifar100_dataloaders(
            opt.data_path,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            is_instance=True,
            use_fake_data=opt.use_fake_data,
            fake_data_folder=opt.fake_data_path,
            nfake=opt.nfake)
    n_cls = 100

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    ## student model name, how to initialize student model, etc.
    student_model_filename = 'S_{}_T_{}_{}_r_{}_a_{}_b_{}_epoch_{}'.format(
        opt.model_s, opt.model_t, opt.distill, opt.gamma, opt.alpha, opt.beta,
        opt.epochs)
    if opt.finetune:
        ckpt_cnn_filename = os.path.join(
            opt.save_folder, student_model_filename + '_finetune_True_last.pth')
        ## load pre-trained model
        checkpoint = torch.load(opt.init_student_path)
        model_s.load_state_dict(checkpoint['model'])
    else:
        ckpt_cnn_filename = os.path.join(opt.save_folder,
                                         student_model_filename + '_last.pth')
    print('\n ' + ckpt_cnn_filename)

    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init,
             train_loader, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, opt)
        # classification training
        pass
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)   # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    if not os.path.isfile(ckpt_cnn_filename):
        print("\n Start training the {} >>>".format(opt.model_s))

        ## resume training
        if opt.resume_epoch > 0:
            save_file = opt.save_intrain_folder + "/ckpt_{}_epoch_{}.pth".format(
                opt.model_s, opt.resume_epoch)
            checkpoint = torch.load(save_file)
            model_s.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            module_list.load_state_dict(checkpoint['module_list'])
            trainable_list.load_state_dict(checkpoint['trainable_list'])
            criterion_list.load_state_dict(checkpoint['criterion_list'])
            # module_list = checkpoint['module_list']
            # criterion_list = checkpoint['criterion_list']
            # # trainable_list = checkpoint['trainable_list']
            # ckpt_test_accuracy = checkpoint['accuracy']
            # ckpt_epoch = checkpoint['epoch']
            # print('\n Resume training: epoch {}, test_acc {}...'.format(ckpt_epoch, ckpt_test_accuracy))
            if torch.cuda.is_available():
                module_list.cuda()
                criterion_list.cuda()
        # end if

        for epoch in range(opt.resume_epoch, opt.epochs):
            adjust_learning_rate(epoch, opt, optimizer)
            print("==> training...")

            time1 = time.time()
            train_acc, train_loss = train(epoch, train_loader, module_list,
                                          criterion_list, optimizer, opt)
            time2 = time.time()
            print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

            test_acc, test_acc_top5, test_loss = validate(
                val_loader, model_s, criterion_cls, opt)

            # regular saving
            if (epoch + 1) % opt.save_freq == 0:
                print('==> Saving...')
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'module_list': module_list.state_dict(),
                    'criterion_list': criterion_list.state_dict(),
                    'trainable_list': trainable_list.state_dict(),
                    'accuracy': test_acc,
                }
                save_file = os.path.join(
                    opt.save_intrain_folder,
                    'ckpt_{}_epoch_{}.pth'.format(opt.model_s, epoch + 1))
                torch.save(state, save_file)
        ## end for epoch

        # store model
        torch.save({
            'opt': opt,
            'model': model_s.state_dict(),
        }, ckpt_cnn_filename)
        print("\n End training CNN.")
    else:
        print("\n Loading pre-trained {}.".format(opt.model_s))
        checkpoint = torch.load(ckpt_cnn_filename)
        model_s.load_state_dict(checkpoint['model'])

    test_acc, test_acc_top5, _ = validate(val_loader, model_s, criterion_cls, opt)
    print("\n {}, test_acc:{:.3f}, test_acc_top5:{:.3f}.".format(
        opt.model_s, test_acc, test_acc_top5))

    eval_results_fullpath = opt.save_folder + "/test_result_" + opt.model_name + ".txt"
    if not os.path.isfile(eval_results_fullpath):
        eval_results_logging_file = open(eval_results_fullpath, "w")
        eval_results_logging_file.close()
    with open(eval_results_fullpath, 'a') as eval_results_logging_file:
        eval_results_logging_file.write(
            "\n===================================================================================================")
        eval_results_logging_file.write("\n Test results for {} \n".format(
            opt.model_name))
        print(opt, file=eval_results_logging_file)
        eval_results_logging_file.write(
            "\n Test accuracy: Top1 {:.3f}, Top5 {:.3f}.".format(
                test_acc, test_acc_top5))
        eval_results_logging_file.write(
            "\n Test error rate: Top1 {:.3f}, Top5 {:.3f}.".format(
                100 - test_acc, 100 - test_acc_top5))
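# load_teacher and get_teacher_name are referenced throughout these scripts but
# are not part of this section. The sketch below shows the usual pattern
# (derive the architecture name from the checkpoint path, rebuild it from
# model_dict, and restore the saved weights). The path-parsing rule is an
# assumption and may differ from the repository's actual helpers.
import torch


def get_teacher_name(model_path):
    """Guess the teacher architecture name from a path like .../wrn_40_2_xxx/ckpt.pth."""
    segments = model_path.split('/')[-2].split('_')
    if segments[0] != 'wrn':
        return segments[0]
    return '_'.join(segments[:3])


def load_teacher(model_path, n_cls):
    print('==> loading teacher model')
    model = model_dict[get_teacher_name(model_path)](num_classes=n_cls)
    model.load_state_dict(torch.load(model_path)['model'])
    print('==> done')
    return model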