def main():
    """Distil a CIFAR-100 student network from a pre-trained teacher.

    All settings come from ``parse_option()``.  The function builds the data
    loaders, the teacher/student pair and three criteria (classification, KL
    divergence and the method selected by ``opt.distill``), then trains for
    ``opt.epochs`` epochs.  Per-epoch metrics go to the tensorboard logger;
    the best, periodic and final student checkpoints are written under
    ``opt.save_folder``.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not ``'cifar100'`` or
            ``opt.distill`` is not one of the handled method names.
    """
    opt = parse_option()
    top_acc = 0

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    if opt.dataset != 'cifar100':
        raise NotImplementedError(opt.dataset)
    if opt.distill == 'crd':
        # CRD needs the contrastive sampling loader.
        train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            k=opt.nce_k,
            mode=opt.mode,
            data_path=opt.datapath,
        )
    else:
        train_loader, val_loader, n_data = get_cifar100_dataloaders(
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            is_instance=True,
            data_path=opt.datapath,
        )
    n_cls = 100

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    # One dummy forward pass per network; the returned feature lists are only
    # used below to read shapes when sizing the auxiliary modules.
    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([model_s])
    trainable_list = nn.ModuleList([model_s])

    def track(*modules, trainable=True):
        """Register modules for training-time use and, optionally, optimisation."""
        module_list.extend(modules)
        if trainable:
            trainable_list.extend(modules)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)

    method = opt.distill
    if method == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif method == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape)
        track(regress_s)
    elif method == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        track(criterion_kd.embed_s, criterion_kd.embed_t)
    elif method == 'attention':
        criterion_kd = Attention()
    elif method == 'nst':
        criterion_kd = NSTLoss()
    elif method == 'similarity':
        criterion_kd = Similarity()
    elif method == 'rkd':
        criterion_kd = RKDLoss()
    elif method == 'pkt':
        criterion_kd = PKT()
    elif method == 'kdsvd':
        criterion_kd = KDSVD()
    elif method == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        track(embed_s, embed_t)
    elif method == 'vid':
        criterion_kd = nn.ModuleList([
            VIDLoss(fs.shape[1], ft.shape[1], ft.shape[1])
            for fs, ft in zip(feat_s[1:-1], feat_t[1:-1])
        ])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif method == 'abound':
        inner_s = feat_s[1:-1]
        inner_t = feat_t[1:-1]
        connector = Connector([f.shape for f in inner_s], [f.shape for f in inner_t])
        # init stage training
        init_trainable_list = nn.ModuleList([connector, model_s.get_feat_modules()])
        criterion_kd = ABLoss(len(inner_s))
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
        # classification
        track(connector, trainable=False)
    elif method == 'factor':
        paraphraser = Paraphraser(feat_t[-2].shape)
        translator = Translator(feat_s[-2].shape, feat_t[-2].shape)
        # init stage training
        init_trainable_list = nn.ModuleList([paraphraser])
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init, train_loader, logger, opt)
        # classification
        criterion_kd = FactorTransfer()
        track(translator)
        track(paraphraser, trainable=False)
    elif method == 'fsp':
        criterion_kd = FSP([s.shape for s in feat_s[:-1]], [t.shape for t in feat_t[:-1]])
        # init stage training
        init_trainable_list = nn.ModuleList([model_s.get_feat_modules()])
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
        # classification training: nothing extra to register
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([
        criterion_cls,  # classification loss
        criterion_div,  # KL divergence loss, original knowledge distillation
        criterion_kd,   # other knowledge distillation loss
    ])

    # optimizer
    optimizer = optim.SGD(
        trainable_list.parameters(),
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        started = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt)
        finished = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, finished - started))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_loss', test_loss, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)

        # save the best model
        if test_acc > top_acc:
            top_acc = test_acc
            best_ckpt = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'best_acc': top_acc,
            }
            best_path = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))
            print('saving the best model!')
            torch.save(best_ckpt, best_path)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            periodic_ckpt = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'accuracy': test_acc,
            }
            periodic_path = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(periodic_ckpt, periodic_path)

    # This best accuracy is only for printing purpose.
    # The results reported in the paper/README is from the last epoch.
    print('best accuracy:', top_acc)

    # save model
    final_ckpt = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    final_path = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
    torch.save(final_ckpt, final_path)
def main():
    """Distil a CIFAR-100 student, optionally from a watermarked teacher.

    Settings come from ``parse_option()``.  On top of the plain distillation
    routine this variant can:

    * fine-tune the teacher on a watermark set when ``opt.watermark`` is
      ``"usenix"`` and then log, each epoch, how much of that watermark the
      student retains (``wm_ret``);
    * initialise the student from the teacher's weights plus Gaussian noise
      scaled by ``opt.init_inv_corr`` when ``opt.init_strat`` is ``"noise"``
      and both models are instances of the same class.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not ``'cifar100'`` or
            ``opt.distill`` is not one of the handled method names.
    """
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # dataloader
    if opt.dataset == 'cifar100':
        if opt.distill in ['crd']:
            train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                k=opt.nce_k,
                mode=opt.mode)
        else:
            train_loader, val_loader, n_data = get_cifar100_dataloaders(
                batch_size=opt.batch_size,
                num_workers=opt.num_workers,
                n_test=100,
                is_instance=True)
        n_cls = 100
    else:
        raise NotImplementedError(opt.dataset)

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    # Dummy forward pass; the feature lists are only used for their shapes.
    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)]
        )
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init, train_loader, logger, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
        # classification training
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)    # classification loss
    criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)     # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        cudnn.benchmark = True

    # embed the watermark into teacher model
    wm_loader = None
    if opt.watermark == "usenix":
        # paper: https://www.usenix.org/system/files/conference/usenixsecurity18/sec18-adi.pdf
        # Define an optimizer for the teacher
        trainable_list_t = nn.ModuleList([model_t])
        optimizer_t = optim.SGD(trainable_list_t.parameters(),
                                lr=opt.learning_rate,
                                momentum=opt.momentum,
                                weight_decay=opt.weight_decay)
        # One criterion instance is reused for every watermark
        # train/validate call instead of building a new one each time.
        wm_criterion = nn.CrossEntropyLoss()

        wm_loader = get_usenixwm_dataloader()
        print("## Train data + Watermark Val Acc")
        validate(val_loader, model_t, wm_criterion, opt)

        print("==> embedding USENIX watermark...")
        max_epochs = 250  # Cutoff val
        for wm_epoch in range(1, max_epochs + 1):
            # Re-applied every epoch, as in the original loop, in case
            # train_vanilla touches the optimizer's learning rate.
            set_learning_rate(2e-4, optimizer_t)
            top1, top5 = train_vanilla(wm_epoch, wm_loader, model_t, wm_criterion, optimizer_t, opt)
            if top1 >= 97:
                # Watermark is considered embedded once top-1 reaches 97.
                break
        print("## Watermark val")
        validate(wm_loader, model_t, wm_criterion, opt)
        print("==> done")

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    # If teacher and student models match, copy over weights for initialization
    if (opt.init_strat == "noise") and (type(model_t) is type(model_s)):
        print("==> Copying teachers weights to student with a weight of {}".format(opt.init_inv_corr))
        model_s.load_state_dict(model_t.state_dict())
        with torch.no_grad():
            for param in model_s.parameters():
                # Draw the noise on the CPU (same RNG stream as before) and
                # move it to wherever this parameter actually lives, rather
                # than guessing the device from global CUDA availability.
                noise = torch.randn(param.size()) * opt.init_inv_corr
                param.add_(noise.to(param.device))
        student_acc, _, _ = validate(val_loader, model_s, criterion_cls, opt)
        print('student accuracy: ', student_acc)
        print("==> done")

    # routine
    for epoch in range(1, opt.epochs + 1):

        adjust_learning_rate(epoch, opt, optimizer)
        print("==> training...")

        time1 = time.time()
        train_acc, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt)
        time2 = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

        logger.log_value('train_acc', train_acc, epoch)
        logger.log_value('train_loss', train_loss, epoch)

        test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt)

        if wm_loader is not None:
            # How well the student reproduces the teacher's watermark labels.
            print("==> wm retention")
            wm_top1, wm_top5, wm_loss = validate(wm_loader, model_s, criterion_cls, opt)
            logger.log_value('wm_ret', wm_top1, epoch)

        logger.log_value('test_acc', test_acc, epoch)
        logger.log_value('test_loss', test_loss, epoch)
        logger.log_value('test_acc_top5', test_acc_top5, epoch)

        # save the best model
        if test_acc > best_acc:
            best_acc = test_acc
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'best_acc': best_acc,
            }
            save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))
            print('saving the best model!')
            torch.save(state, save_file)

        # regular saving
        if epoch % opt.save_freq == 0:
            print('==> Saving...')
            state = {
                'epoch': epoch,
                'model': model_s.state_dict(),
                'accuracy': test_acc,
            }
            save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
            torch.save(state, save_file)

    # This best accuracy is only for printing purpose.
    # The results reported in the paper/README is from the last epoch.
    print('best accuracy:', best_acc)

    # save model
    state = {
        'opt': opt,
        'model': model_s.state_dict(),
    }
    save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
    torch.save(state, save_file)
def main():
    """Distil an ImageNet student (ProxylessNAS config) from EfficientNet-B0.

    Settings come from ``parse_option()``.  Progress is written both to the
    tensorboard logger and to a text log under ``./save/log/``; the best,
    periodic (plus every epoch below 10) and final student checkpoints are
    saved under ``opt.save_folder``.

    The text log is opened in a ``with`` block so it is closed even when
    training aborts with an exception.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not ``'imagenet'`` or
            ``opt.distill`` is not one of the handled method names.
    """
    best_acc = 0

    opt = parse_option()

    # tensorboard logger
    logger = tb_logger.Logger(logdir=opt.tb_folder, flush_secs=2)

    # makedirs also creates a missing './save' parent and tolerates an
    # already existing directory.
    log_dir = './save/log/'
    os.makedirs(log_dir, exist_ok=True)
    # NOTE(review): the log name is a fixed slice of opt.path_config; this
    # presumably extracts an identifier from the config file name - confirm.
    log_path = os.path.join(log_dir, 'log_{}.txt'.format(opt.path_config[-11:-7]))

    with open(log_path, 'w') as log:
        print_log('save path : {}'.format("./save/"), log)

        # dataloader
        if opt.dataset == 'imagenet':
            if opt.distill in ['crd']:
                train_loader, val_loader, n_data = get_dataloader_sample(
                    batch_size=opt.batch_size,
                    num_workers=opt.num_workers,
                    k=opt.nce_k,
                    mode=opt.mode)
            else:
                train_loader, val_loader, n_data = get_imagenet_dataloader(
                    batch_size=opt.batch_size,
                    num_workers=opt.num_workers,
                    is_instance=True)
            n_cls = 1000
        else:
            raise NotImplementedError(opt.dataset)
        print("##")

        # model
        model_t = EfficientNet.from_pretrained(
            'efficientnet-b0',
            weights_path='./pretrain_efficientNet/pretrain_efficientNet.pth')
        from proxyless_nas.jj import get_proxyless_model
        model_s = get_proxyless_model(net_config_path=opt.path_config)

        gpus = [0, 1, 2, 3]
        # Only select a CUDA device when one exists, consistent with the
        # torch.cuda.is_available() guard used further down.
        if torch.cuda.is_available():
            torch.cuda.set_device('cuda:{}'.format(gpus[0]))

        # Dummy forward pass; the feature lists are only used for their shapes.
        data = torch.randn(2, 3, 224, 224)
        model_t.eval()
        model_s.eval()
        feat_t, _ = model_t(data, is_feat=True)
        feat_s, _ = model_s(data, is_feat=True)

        module_list = nn.ModuleList([])
        module_list.append(model_s)
        trainable_list = nn.ModuleList([])
        trainable_list.append(model_s)

        criterion_cls = nn.CrossEntropyLoss()
        criterion_div = DistillKL(opt.kd_T)
        if opt.distill == 'kd':
            criterion_kd = DistillKL(opt.kd_T)
        elif opt.distill == 'hint':
            criterion_kd = HintLoss()
            regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape)
            module_list.append(regress_s)
            trainable_list.append(regress_s)
        elif opt.distill == 'crd':
            opt.s_dim = feat_s[-1].shape[1]
            opt.t_dim = feat_t[-1].shape[1]
            opt.n_data = n_data
            criterion_kd = CRDLoss(opt)
            module_list.append(criterion_kd.embed_s)
            module_list.append(criterion_kd.embed_t)
            trainable_list.append(criterion_kd.embed_s)
            trainable_list.append(criterion_kd.embed_t)
        elif opt.distill == 'attention':
            criterion_kd = Attention()
        elif opt.distill == 'nst':
            criterion_kd = NSTLoss()
        elif opt.distill == 'similarity':
            criterion_kd = Similarity()
        elif opt.distill == 'rkd':
            criterion_kd = RKDLoss()
        elif opt.distill == 'pkt':
            criterion_kd = PKT()
        elif opt.distill == 'kdsvd':
            criterion_kd = KDSVD()
        elif opt.distill == 'correlation':
            criterion_kd = Correlation()
            embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
            embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
            module_list.append(embed_s)
            module_list.append(embed_t)
            trainable_list.append(embed_s)
            trainable_list.append(embed_t)
        elif opt.distill == 'vid':
            s_n = [f.shape[1] for f in feat_s[1:-1]]
            t_n = [f.shape[1] for f in feat_t[1:-1]]
            criterion_kd = nn.ModuleList(
                [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)]
            )
            # add this as some parameters in VIDLoss need to be updated
            trainable_list.append(criterion_kd)
        elif opt.distill == 'abound':
            s_shapes = [f.shape for f in feat_s[1:-1]]
            t_shapes = [f.shape for f in feat_t[1:-1]]
            connector = Connector(s_shapes, t_shapes)
            # init stage training
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(connector)
            init_trainable_list.append(model_s.get_feat_modules())
            criterion_kd = ABLoss(len(feat_s[1:-1]))
            init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
            # classification
            module_list.append(connector)
        elif opt.distill == 'factor':
            s_shape = feat_s[-2].shape
            t_shape = feat_t[-2].shape
            paraphraser = Paraphraser(t_shape)
            translator = Translator(s_shape, t_shape)
            # init stage training
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(paraphraser)
            criterion_init = nn.MSELoss()
            init(model_s, model_t, init_trainable_list, criterion_init, train_loader, logger, opt)
            # classification
            criterion_kd = FactorTransfer()
            module_list.append(translator)
            module_list.append(paraphraser)
            trainable_list.append(translator)
        elif opt.distill == 'fsp':
            s_shapes = [s.shape for s in feat_s[:-1]]
            t_shapes = [t.shape for t in feat_t[:-1]]
            criterion_kd = FSP(s_shapes, t_shapes)
            # init stage training
            init_trainable_list = nn.ModuleList([])
            init_trainable_list.append(model_s.get_feat_modules())
            init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, logger, opt)
            # classification training
        else:
            raise NotImplementedError(opt.distill)

        criterion_list = nn.ModuleList([])
        criterion_list.append(criterion_cls)    # classification loss
        criterion_list.append(criterion_div)    # KL divergence loss, original knowledge distillation
        criterion_list.append(criterion_kd)     # other knowledge distillation loss

        # optimizer
        optimizer = optim.SGD(trainable_list.parameters(),
                              lr=opt.learning_rate,
                              momentum=opt.momentum,
                              weight_decay=opt.weight_decay)

        # append teacher after optimizer to avoid weight_decay
        module_list.append(model_t)

        if torch.cuda.is_available():
            module_list.cuda()
            criterion_list.cuda()
            cudnn.benchmark = True

        # validate teacher accuracy
        teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt, log)
        print('teacher accuracy: ', teacher_acc)
        print_log("teacher accuracy:{}".format(teacher_acc), log)

        # routine
        for epoch in range(1, opt.epochs + 1):

            adjust_learning_rate(epoch, opt, optimizer)
            print("==> training...")
            print_log("==> training...", log)

            time1 = time.time()
            train_acc, train_loss = train(epoch, train_loader, module_list, criterion_list, optimizer, opt, log)
            time2 = time.time()
            print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))
            print_log('epoch {}, total time {:.2f}'.format(epoch, time2 - time1), log)

            logger.log_value('train_acc', train_acc, epoch)
            logger.log_value('train_loss', train_loss, epoch)

            test_acc, test_acc_top5, test_loss = validate(val_loader, model_s, criterion_cls, opt, log)

            logger.log_value('test_acc', test_acc, epoch)
            logger.log_value('test_loss', test_loss, epoch)
            logger.log_value('test_acc_top5', test_acc_top5, epoch)

            # save the best model
            if test_acc > best_acc:
                best_acc = test_acc
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'best_acc': best_acc,
                }
                save_file = os.path.join(opt.save_folder, '{}_best.pth'.format(opt.model_s))
                print('saving the best model!')
                print_log('saving the best model!', log)
                torch.save(state, save_file)

            # regular saving (additionally every epoch during the first nine)
            if epoch % opt.save_freq == 0 or epoch < 10:
                print('==> Saving...')
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'accuracy': test_acc,
                }
                save_file = os.path.join(opt.save_folder, 'ckpt_epoch_{epoch}.pth'.format(epoch=epoch))
                torch.save(state, save_file)

        # This best accuracy is only for printing purpose.
        # The results reported in the paper/README is from the last epoch.
        print('best accuracy:', best_acc)
        print_log('best accuracy:{}'.format(best_acc), log)

        # save model
        state = {
            'opt': opt,
            'model': model_s.state_dict(),
        }
        save_file = os.path.join(opt.save_folder, '{}_last.pth'.format(opt.model_s))
        torch.save(state, save_file)
def distill(opt):
    """Run one CIFAR-100 distillation experiment described by ``opt``.

    ``opt`` is refined in place (save paths, parsed decay epochs, print/test
    frequency, learning rate, epoch counts rescaled by ``opt.aug_size``)
    before the loaders, models, criteria, optimizer and warm-up scheduler are
    built.  Training then runs for ``opt.epochs`` epochs; ``train`` receives
    and returns the running best accuracy and ``total_t``.

    Args:
        opt: option namespace; ``opt.lr_decay_epochs`` is expected to be a
            comma-separated string on entry and is replaced by a list of ints.

    Raises:
        NotImplementedError: if ``opt.dataset`` is not ``'cifar100'`` or
            ``opt.distill`` is not one of the handled method names.
    """
    # refine the opt arguments
    opt.model_path = './save/student_model'

    opt.lr_decay_epochs = [int(tok) for tok in opt.lr_decay_epochs.split(',')]

    opt.model_t = get_teacher_name(opt.path_t)

    opt.print_freq = int(50000 / opt.batch_size / opt.print_freq)

    name_fields = (
        opt.model_s, opt.model_t, opt.dataset, opt.distill, opt.gamma,
        opt.alpha, opt.beta, opt.trial, opt.device, opt.seed, opt.aug_type,
        opt.aug_lambda, opt.aug_alpha, opt.aug_size, opt.kd_T,
    )
    opt.model_name = 'S:{}_T:{}_{}_{}/r:{}_a:{}_b:{}_{}_{}_{}_{}_lam:{}_alp:{}_augsize:{}_T:{}'.format(
        *name_fields)

    opt.save_folder = os.path.join(opt.model_path, opt.model_name)
    if not os.path.isdir(opt.save_folder):
        os.makedirs(opt.save_folder)

    # Learning rate scales linearly with batch size.
    scaled_lr = 0.1 * opt.batch_size / 128
    # set different learning rate from these 4 models
    if opt.model_s in ['MobileNetV2', 'ShuffleV1', 'ShuffleV2']:
        scaled_lr = scaled_lr / 5
    opt.learning_rate = scaled_lr
    print("learning rate is set to:", opt.learning_rate)

    best_acc = 0
    np.random.seed(opt.seed)
    torch.manual_seed(opt.seed)

    # dataloader
    if opt.dataset != 'cifar100':
        raise NotImplementedError(opt.dataset)
    if opt.distill == 'crd':
        train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            k=opt.nce_k,
            mode=opt.mode,
        )
    else:
        train_loader, val_loader, n_data = get_cifar100_dataloaders(opt, is_instance=True)
    n_cls = 100

    # set the interval for testing
    opt.test_freq = int(50000 / opt.batch_size)

    # compute number of epochs using the original cifar100 dataset size
    opt.lr_decay_epochs = [int(ep * 50000 / opt.aug_size) for ep in opt.lr_decay_epochs]
    opt.epochs = int(opt.epochs * 50000 / opt.aug_size)
    print('Decay epochs: ', opt.lr_decay_epochs)
    print('Max epochs: ', opt.epochs)

    # set the device
    device = torch.device(opt.device) if torch.cuda.is_available() else torch.device('cpu')

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)
    print("Size of the teacher:", count_parameters(model_t))
    print("Size of the student:", count_parameters(model_s))

    # Dummy forward pass; the feature lists are only used for their shapes.
    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    feat_t, _ = model_t(data, is_feat=True)
    feat_s, _ = model_s(data, is_feat=True)

    module_list = nn.ModuleList([model_s])
    trainable_list = nn.ModuleList([model_s])

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)

    method = opt.distill
    if method == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif method == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape, feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif method == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.extend([criterion_kd.embed_s, criterion_kd.embed_t])
        trainable_list.extend([criterion_kd.embed_s, criterion_kd.embed_t])
    elif method == 'attention':
        criterion_kd = Attention()
    elif method == 'nst':
        criterion_kd = NSTLoss()
    elif method == 'similarity':
        criterion_kd = Similarity()
    elif method == 'rkd':
        criterion_kd = RKDLoss()
    elif method == 'pkt':
        criterion_kd = PKT()
    elif method == 'kdsvd':
        criterion_kd = KDSVD()
    elif method == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.extend([embed_s, embed_t])
        trainable_list.extend([embed_s, embed_t])
    elif method == 'vid':
        criterion_kd = nn.ModuleList([
            VIDLoss(fs.shape[1], ft.shape[1], ft.shape[1])
            for fs, ft in zip(feat_s[1:-1], feat_t[1:-1])
        ])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif method == 'abound':
        inner_s = feat_s[1:-1]
        inner_t = feat_t[1:-1]
        connector = Connector([f.shape for f in inner_s], [f.shape for f in inner_t])
        # init stage training
        init_trainable_list = nn.ModuleList([connector, model_s.get_feat_modules()])
        criterion_kd = ABLoss(len(inner_s))
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, opt)
        # classification
        module_list.append(connector)
    elif method == 'factor':
        paraphraser = Paraphraser(feat_t[-2].shape)
        translator = Translator(feat_s[-2].shape, feat_t[-2].shape)
        # init stage training
        init_trainable_list = nn.ModuleList([paraphraser])
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init, train_loader, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.extend([translator, paraphraser])
        trainable_list.append(translator)
    elif method == 'fsp':
        criterion_kd = FSP([s.shape for s in feat_s[:-1]], [t.shape for t in feat_t[:-1]])
        # init stage training
        init_trainable_list = nn.ModuleList([model_s.get_feat_modules()])
        init(model_s, model_t, init_trainable_list, criterion_kd, train_loader, opt)
        # classification training: nothing extra to register
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([
        criterion_cls,  # classification loss
        criterion_div,  # KL divergence loss, original knowledge distillation
        criterion_kd,   # other knowledge distillation loss
    ])

    # optimizer
    optimizer = optim.SGD(
        trainable_list.parameters(),
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.to(device)
        criterion_list.to(device)
        cudnn.benchmark = True

    # setup warmup
    warmup_scheduler = WarmUpLR(optimizer, len(train_loader) * opt.epochs_warmup)

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: %.2f \n' % (teacher_acc))

    # create logger
    columns = ['Epoch', 'l_xent', 'l_kd', 'l_other', 'acc_train', 'acc_test', 'acc_test_best', 'lr']
    column_formats = ['%02d', '%.4f', '%.4f', '%.4f', '%.2f', '%.2f', '%.2f', '%.6f']
    logger = Logger(dir=opt.save_folder, var_names=columns, format=column_formats, args=opt)

    total_t = 0
    # routine
    for epoch in range(1, opt.epochs + 1):
        adjust_learning_rate(epoch, opt, optimizer)

        started = time.time()
        best_acc, total_t = train(epoch, train_loader, val_loader, module_list,
                                  criterion_list, optimizer, opt, best_acc,
                                  logger, device, warmup_scheduler, total_t)
        finished = time.time()
        print('epoch {}, total time {:.2f}'.format(epoch, finished - started))

    # NOTE(review): the flattened source does not show whether this print sat
    # inside or after the epoch loop; placed after the loop - confirm.
    print('Best accuracy: %.2f \n' % (best_acc))
def main():
    """Train (or reload) a distilled CIFAR-100 student and log its test accuracy.

    Steps, in order:

    1. Parse options, seed torch, and build the train/val loaders.
    2. Load the teacher, build the student (optionally initialised from
       ``opt.init_student_path`` when ``opt.finetune`` is set).
    3. Build the distillation criterion selected by ``opt.distill`` together
       with any auxiliary modules it needs.
    4. If the final student checkpoint does not exist yet, train the student
       (optionally resuming from an in-training checkpoint) and store it;
       otherwise load the stored student.
    5. Validate the student and append the results to a text file in
       ``opt.save_folder``.

    Raises:
        NotImplementedError: if ``opt.distill`` is not a supported method.
    """
    opt = parse_option()

    torch.manual_seed(2021)
    torch.cuda.manual_seed(2021)
    torch.backends.cudnn.deterministic = True

    # dataloader
    if opt.distill in ['crd']:
        train_loader, val_loader, n_data = get_cifar100_dataloaders_sample(
            opt.data_path,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            k=opt.nce_k,
            mode=opt.mode,
            use_fake_data=opt.use_fake_data,
            fake_data_folder=opt.fake_data_path,
            nfake=opt.nfake)
    else:
        train_loader, val_loader, n_data = get_cifar100_dataloaders(
            opt.data_path,
            batch_size=opt.batch_size,
            num_workers=opt.num_workers,
            is_instance=True,
            use_fake_data=opt.use_fake_data,
            fake_data_folder=opt.fake_data_path,
            nfake=opt.nfake)
    n_cls = 100

    # model
    model_t = load_teacher(opt.path_t, n_cls)
    model_s = model_dict[opt.model_s](num_classes=n_cls)

    # The final checkpoint name encodes the student/teacher pair, the
    # distillation method, the loss weights and the epoch budget.
    student_model_filename = 'S_{}_T_{}_{}_r_{}_a_{}_b_{}_epoch_{}'.format(
        opt.model_s, opt.model_t, opt.distill, opt.gamma, opt.alpha, opt.beta,
        opt.epochs)
    if opt.finetune:
        ckpt_cnn_filename = os.path.join(
            opt.save_folder,
            student_model_filename + '_finetune_True_last.pth')
        # Start from a pre-trained student. Load onto the CPU so that a
        # checkpoint written on a GPU can also be restored on a CPU-only
        # host; load_state_dict copies into the existing parameters.
        checkpoint = torch.load(opt.init_student_path, map_location='cpu')
        model_s.load_state_dict(checkpoint['model'])
    else:
        ckpt_cnn_filename = os.path.join(opt.save_folder,
                                         student_model_filename + '_last.pth')
    print('\n ' + ckpt_cnn_filename)

    # Probe both networks with a dummy batch. Only the shapes of the returned
    # features are used below, so no autograd graph is needed.
    data = torch.randn(2, 3, 32, 32)
    model_t.eval()
    model_s.eval()
    with torch.no_grad():
        feat_t, _ = model_t(data, is_feat=True)
        feat_s, _ = model_s(data, is_feat=True)

    # module_list: everything moved to the GPU and handed to train();
    # trainable_list: everything whose parameters the optimizer updates.
    module_list = nn.ModuleList([])
    module_list.append(model_s)
    trainable_list = nn.ModuleList([])
    trainable_list.append(model_s)

    criterion_cls = nn.CrossEntropyLoss()
    criterion_div = DistillKL(opt.kd_T)
    if opt.distill == 'kd':
        criterion_kd = DistillKL(opt.kd_T)
    elif opt.distill == 'hint':
        criterion_kd = HintLoss()
        regress_s = ConvReg(feat_s[opt.hint_layer].shape,
                            feat_t[opt.hint_layer].shape)
        module_list.append(regress_s)
        trainable_list.append(regress_s)
    elif opt.distill == 'crd':
        opt.s_dim = feat_s[-1].shape[1]
        opt.t_dim = feat_t[-1].shape[1]
        opt.n_data = n_data
        criterion_kd = CRDLoss(opt)
        module_list.append(criterion_kd.embed_s)
        module_list.append(criterion_kd.embed_t)
        trainable_list.append(criterion_kd.embed_s)
        trainable_list.append(criterion_kd.embed_t)
    elif opt.distill == 'attention':
        criterion_kd = Attention()
    elif opt.distill == 'nst':
        criterion_kd = NSTLoss()
    elif opt.distill == 'similarity':
        criterion_kd = Similarity()
    elif opt.distill == 'rkd':
        criterion_kd = RKDLoss()
    elif opt.distill == 'pkt':
        criterion_kd = PKT()
    elif opt.distill == 'kdsvd':
        criterion_kd = KDSVD()
    elif opt.distill == 'correlation':
        criterion_kd = Correlation()
        embed_s = LinearEmbed(feat_s[-1].shape[1], opt.feat_dim)
        embed_t = LinearEmbed(feat_t[-1].shape[1], opt.feat_dim)
        module_list.append(embed_s)
        module_list.append(embed_t)
        trainable_list.append(embed_s)
        trainable_list.append(embed_t)
    elif opt.distill == 'vid':
        s_n = [f.shape[1] for f in feat_s[1:-1]]
        t_n = [f.shape[1] for f in feat_t[1:-1]]
        criterion_kd = nn.ModuleList(
            [VIDLoss(s, t, t) for s, t in zip(s_n, t_n)])
        # add this as some parameters in VIDLoss need to be updated
        trainable_list.append(criterion_kd)
    elif opt.distill == 'abound':
        s_shapes = [f.shape for f in feat_s[1:-1]]
        t_shapes = [f.shape for f in feat_t[1:-1]]
        connector = Connector(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(connector)
        init_trainable_list.append(model_s.get_feat_modules())
        criterion_kd = ABLoss(len(feat_s[1:-1]))
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, opt)
        # classification
        module_list.append(connector)
    elif opt.distill == 'factor':
        s_shape = feat_s[-2].shape
        t_shape = feat_t[-2].shape
        paraphraser = Paraphraser(t_shape)
        translator = Translator(s_shape, t_shape)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(paraphraser)
        criterion_init = nn.MSELoss()
        init(model_s, model_t, init_trainable_list, criterion_init,
             train_loader, opt)
        # classification
        criterion_kd = FactorTransfer()
        module_list.append(translator)
        module_list.append(paraphraser)
        trainable_list.append(translator)
    elif opt.distill == 'fsp':
        s_shapes = [s.shape for s in feat_s[:-1]]
        t_shapes = [t.shape for t in feat_t[:-1]]
        criterion_kd = FSP(s_shapes, t_shapes)
        # init stage training
        init_trainable_list = nn.ModuleList([])
        init_trainable_list.append(model_s.get_feat_modules())
        init(model_s, model_t, init_trainable_list, criterion_kd,
             train_loader, opt)
        # classification training: nothing else to register
    else:
        raise NotImplementedError(opt.distill)

    criterion_list = nn.ModuleList([])
    criterion_list.append(criterion_cls)  # classification loss
    criterion_list.append(
        criterion_div)  # KL divergence loss, original knowledge distillation
    criterion_list.append(criterion_kd)  # other knowledge distillation loss

    # optimizer
    optimizer = optim.SGD(trainable_list.parameters(),
                          lr=opt.learning_rate,
                          momentum=opt.momentum,
                          weight_decay=opt.weight_decay)

    # append teacher after optimizer to avoid weight_decay
    module_list.append(model_t)

    if torch.cuda.is_available():
        module_list.cuda()
        criterion_list.cuda()
        # NOTE(review): benchmark mode is enabled together with
        # cudnn.deterministic above; confirm this combination gives the
        # run-to-run reproducibility the fixed seed is meant to provide.
        cudnn.benchmark = True

    # validate teacher accuracy
    teacher_acc, _, _ = validate(val_loader, model_t, criterion_cls, opt)
    print('teacher accuracy: ', teacher_acc)

    if not os.path.isfile(ckpt_cnn_filename):
        print("\n Start training the {} >>>".format(opt.model_s))

        # resume training from the in-training checkpoint of resume_epoch
        if opt.resume_epoch > 0:
            save_file = os.path.join(
                opt.save_intrain_folder,
                'ckpt_{}_epoch_{}.pth'.format(opt.model_s, opt.resume_epoch))
            # Load onto the CPU; every load_state_dict below copies the
            # values onto whatever device the live objects already use.
            checkpoint = torch.load(save_file, map_location='cpu')
            model_s.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            module_list.load_state_dict(checkpoint['module_list'])
            trainable_list.load_state_dict(checkpoint['trainable_list'])
            criterion_list.load_state_dict(checkpoint['criterion_list'])
            if torch.cuda.is_available():
                module_list.cuda()
                criterion_list.cuda()

        for epoch in range(opt.resume_epoch, opt.epochs):
            adjust_learning_rate(epoch, opt, optimizer)
            print("==> training...")

            time1 = time.time()
            train_acc, train_loss = train(epoch, train_loader, module_list,
                                          criterion_list, optimizer, opt)
            time2 = time.time()
            print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1))

            # Only the top-1 accuracy is recorded in the checkpoint.
            test_acc, _, _ = validate(val_loader, model_s, criterion_cls, opt)

            # regular saving
            if (epoch + 1) % opt.save_freq == 0:
                print('==> Saving...')
                state = {
                    'epoch': epoch,
                    'model': model_s.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'module_list': module_list.state_dict(),
                    'criterion_list': criterion_list.state_dict(),
                    'trainable_list': trainable_list.state_dict(),
                    'accuracy': test_acc,
                }
                save_file = os.path.join(
                    opt.save_intrain_folder,
                    'ckpt_{}_epoch_{}.pth'.format(opt.model_s, epoch + 1))
                torch.save(state, save_file)

        # store model
        torch.save({
            'opt': opt,
            'model': model_s.state_dict(),
        }, ckpt_cnn_filename)
        print("\n End training CNN.")
    else:
        print("\n Loading pre-trained {}.".format(opt.model_s))
        # NOTE(review): this checkpoint also stores the options object next
        # to the weights; confirm it still unpickles under the torch.load
        # defaults of the PyTorch version in use.
        checkpoint = torch.load(ckpt_cnn_filename, map_location='cpu')
        model_s.load_state_dict(checkpoint['model'])

    test_acc, test_acc_top5, _ = validate(val_loader, model_s, criterion_cls,
                                          opt)
    print("\n {}, test_acc:{:.3f}, test_acc_top5:{:.3f}.".format(
        opt.model_s, test_acc, test_acc_top5))

    # Append mode creates the file when it does not exist yet, so no
    # separate "create empty file" step is required.
    eval_results_fullpath = os.path.join(
        opt.save_folder, "test_result_" + opt.model_name + ".txt")
    with open(eval_results_fullpath, 'a') as eval_results_logging_file:
        eval_results_logging_file.write(
            "\n==================================================================================================="
        )
        eval_results_logging_file.write("\n Test results for {} \n".format(
            opt.model_name))
        print(opt, file=eval_results_logging_file)
        eval_results_logging_file.write(
            "\n Test accuracy: Top1 {:.3f}, Top5 {:.3f}.".format(
                test_acc, test_acc_top5))
        eval_results_logging_file.write(
            "\n Test error rate: Top1 {:.3f}, Top5 {:.3f}.".format(
                100 - test_acc, 100 - test_acc_top5))