def main():
    """Distill a student network from a frozen teacher with BSS
    (boundary-supporting samples) on CIFAR-10/100.

    Relies on module-level globals set up elsewhere in the file:
    ``args``/``unparsed`` from the argument parser, plus the project
    helpers ``define_tsnet``, ``load_pretrained_model``,
    ``count_parameters_in_MB``, ``adjust_lr``, ``train``, ``test`` and
    ``save_checkpoint``.
    """
    # Seed every RNG in play so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name, num_class=args.num_class, cuda=args.cuda)
    ckpt_student = torch.load(args.s_init)
    load_pretrained_model(snet, ckpt_student['net'])
    logging.info('Student: %s', snet)
    logging.info('Student param size = %fMB', count_parameters_in_MB(snet))

    # The teacher is fixed: eval mode and no gradients anywhere.
    tnet = define_tsnet(name=args.t_name, num_class=args.num_class, cuda=args.cuda)
    ckpt_teacher = torch.load(args.t_model)
    load_pretrained_model(tnet, ckpt_teacher['net'])
    tnet.eval()
    for p in tnet.parameters():
        p.requires_grad = False
    logging.info('Teacher: %s', tnet)
    logging.info('Teacher param size = %fMB', count_parameters_in_MB(tnet))
    logging.info('-----------------------------------------------')

    # initialize optimizer -- only the student's parameters are updated
    optimizer = torch.optim.SGD(snet.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define attacker that crafts the boundary-supporting samples
    attacker = BSSAttacker(step_alpha=0.3, num_steps=10, eps=1e-4)

    # define loss functions
    criterionKD = BSS(args.T)
    criterionCls = torch.nn.CrossEntropyLoss()
    if args.cuda:
        criterionCls = criterionCls.cuda()

    # define transforms (per-dataset normalization statistics)
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        norm_mean = (0.4914, 0.4822, 0.4465)
        norm_std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        norm_mean = (0.5071, 0.4865, 0.4409)
        norm_std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=train_transform,
                train=True, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=test_transform,
                train=False, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'snet': snet, 'tnet': tnet}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        tic = time.time()
        train(train_loader, nets, optimizer, criterions, attacker, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, nets, criterions, epoch)
        logging.info('Epoch time: {}s'.format(int(time.time() - tic)))

        # save model -- best-so-far is tracked by top-1 accuracy only
        is_best = test_top1 > best_top1
        if is_best:
            best_top1 = test_top1
            best_top5 = test_top5
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
                'prec@1': test_top1,
                'prec@5': test_top5,
            }, is_best, args.save_root)
def main():
    """FitNet-style distillation driver (print-based logging variant).

    Parses CLI arguments into the module-level ``args``, builds a student
    and a frozen teacher, then runs the train/test/checkpoint loop.
    Depends on the module-level ``parser`` and the project helpers
    ``define_tsnet``, ``load_pretrained_model``, ``adjust_lr``, ``train``,
    ``test``, ``save_checkpoint`` and ``transform_time``.
    """
    global args
    args = parser.parse_args()
    print(args)

    # Make sure the checkpoint directory exists before the first save.
    ckpt_dir = os.path.join(args.save_root, 'checkpoint')
    if not os.path.exists(ckpt_dir):
        os.makedirs(ckpt_dir)
    if args.cuda:
        cudnn.benchmark = True

    print('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name, num_class=args.num_class, cuda=args.cuda)
    state = torch.load(args.s_init)
    load_pretrained_model(snet, state['net'])

    # Freeze the teacher entirely: eval mode plus no gradients.
    tnet = define_tsnet(name=args.t_name, num_class=args.num_class, cuda=args.cuda)
    state = torch.load(args.t_model)
    load_pretrained_model(tnet, state['net'])
    tnet.eval()
    for p in tnet.parameters():
        p.requires_grad = False
    print('-----------------------------------------------')

    # initialize optimizer -- only the student is trained
    optimizer = torch.optim.SGD(snet.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define loss functions (hint loss is a plain MSE on features)
    criterionCls = torch.nn.CrossEntropyLoss()
    criterionFitnet = torch.nn.MSELoss()
    if args.cuda:
        criterionCls = criterionCls.cuda()
        criterionFitnet = criterionFitnet.cuda()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        norm_mean = (0.4914, 0.4822, 0.4465)
        norm_std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        norm_mean = (0.5071, 0.4865, 0.4409)
        norm_std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=train_transform,
                train=True, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=test_transform,
                train=False, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    for epoch in range(1, args.epochs + 1):
        epoch_tic = time.time()
        adjust_lr(optimizer, epoch)

        # train one epoch
        nets = {'snet': snet, 'tnet': tnet}
        criterions = {'criterionCls': criterionCls,
                      'criterionFitnet': criterionFitnet}
        train(train_loader, nets, optimizer, criterions, epoch)
        epoch_time = time.time() - epoch_tic
        print('one epoch time is {:02}h{:02}m{:02}s'.format(*transform_time(epoch_time)))

        # evaluate on testing set
        print('testing the models......')
        test_tic = time.time()
        test(test_loader, nets, criterions)
        test_time = time.time() - test_tic
        print('testing time is {:02}h{:02}m{:02}s'.format(*transform_time(test_time)))

        # save model -- the teacher weights are only stored once (epoch 1)
        print('saving models......')
        save_name = 'fitnet_r{}_r{}_{:>03}.ckp'.format(
            args.t_name[6:], args.s_name[6:], epoch)
        save_name = os.path.join(args.save_root, 'checkpoint', save_name)
        if epoch == 1:
            save_checkpoint({
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
            }, save_name)
        else:
            save_checkpoint({
                'epoch': epoch,
                'snet': snet.state_dict(),
            }, save_name)
def main():
    """Deep Mutual Learning (DML): train two peer networks jointly, each
    distilling from the other, and checkpoint the best pair.

    Depends on module-level globals ``args``/``unparsed`` and the project
    helpers ``define_tsnet``, ``load_pretrained_model``,
    ``count_parameters_in_MB``, ``adjust_lr``, ``train``, ``test`` and
    ``save_checkpoint``.
    """
    # Seed every RNG in play so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    net1 = define_tsnet(name=args.net1_name, num_class=args.num_class, cuda=args.cuda)
    ckpt = torch.load(args.net1_init)
    load_pretrained_model(net1, ckpt['net'])
    logging.info('Net1: %s', net1)
    logging.info('Net1 param size = %fMB', count_parameters_in_MB(net1))

    net2 = define_tsnet(name=args.net2_name, num_class=args.num_class, cuda=args.cuda)
    ckpt = torch.load(args.net2_init)
    load_pretrained_model(net2, ckpt['net'])
    logging.info('Net2: %s', net2)
    logging.info('Net2 param size = %fMB', count_parameters_in_MB(net2))
    logging.info('-----------------------------------------------')

    # initialize optimizer -- one SGD instance per peer network
    optimizer1 = torch.optim.SGD(net1.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)
    optimizer2 = torch.optim.SGD(net2.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    # define loss functions
    criterionKD = DML()
    criterionCls = torch.nn.CrossEntropyLoss()
    if args.cuda:
        criterionCls = criterionCls.cuda()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        norm_mean = (0.4914, 0.4822, 0.4465)
        norm_std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        norm_mean = (0.5071, 0.4865, 0.4409)
        norm_std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=train_transform,
                train=True, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=test_transform,
                train=False, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'net1': net1, 'net2': net2}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}
    optimizers = {'optimizer1': optimizer1, 'optimizer2': optimizer2}

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizers, epoch)

        # train one epoch
        tic = time.time()
        train(train_loader, nets, optimizers, criterions, epoch)

        # evaluate on testing set -- top1/top5 for each of the two nets
        logging.info('Testing the models......')
        test_top11, test_top15, test_top21, test_top25 = test(
            test_loader, nets, criterions)
        logging.info('Epoch time: {}s'.format(int(time.time() - tic)))

        # save model -- "best" means the better of the two peers improved
        is_best = max(test_top11, test_top21) > best_top1
        if is_best:
            best_top1 = max(test_top11, test_top21)
            best_top5 = max(test_top15, test_top25)
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'net1': net1.state_dict(),
                'net2': net2.state_dict(),
                'prec1@1': test_top11,
                'prec1@5': test_top15,
                'prec2@1': test_top21,
                'prec2@5': test_top25,
            }, is_best, args.save_root)
def main():
    """Baseline (no-distillation) CIFAR training: build one network,
    snapshot its random initialization, then train/test/checkpoint.

    Depends on module-level globals ``args``/``unparsed`` and the project
    helpers ``define_tsnet``, ``count_parameters_in_MB``, ``adjust_lr``,
    ``train``, ``test`` and ``save_checkpoint``.
    """
    # Seed every RNG in play so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    net = define_tsnet(name=args.net_name, num_class=args.num_class, cuda=args.cuda)
    logging.info('%s', net)
    logging.info("param size = %fMB", count_parameters_in_MB(net))
    logging.info('-----------------------------------------------')

    # save initial parameters -- later distillation runs reuse this file
    # as the student's starting point
    logging.info('Saving initial parameters......')
    save_path = os.path.join(args.save_root,
                             'initial_r{}.pth.tar'.format(args.net_name[6:]))
    torch.save({
        'epoch': 0,
        'net': net.state_dict(),
        'prec@1': 0.0,
        'prec@5': 0.0,
    }, save_path)

    # initialize optimizer
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define loss functions
    criterion = torch.nn.CrossEntropyLoss()
    if args.cuda:
        criterion = criterion.cuda()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        norm_mean = (0.4914, 0.4822, 0.4465)
        norm_std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        norm_mean = (0.5071, 0.4865, 0.4409)
        norm_std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=train_transform,
                train=True, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=test_transform,
                train=False, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        tic = time.time()
        train(train_loader, net, optimizer, criterion, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, net, criterion)
        logging.info('Epoch time: {}s'.format(int(time.time() - tic)))

        # save model -- best-so-far is tracked by top-1 accuracy only
        is_best = test_top1 > best_top1
        if is_best:
            best_top1 = test_top1
            best_top5 = test_top5
        logging.info('Saving models......')
        save_checkpoint({
            'epoch': epoch,
            'net': net.state_dict(),
            'prec@1': test_top1,
            'prec@5': test_top5,
        }, is_best, args.save_root)
def main():
    """Baseline CIFAR training driver (print/CSV-logging variant).

    Parses CLI arguments into the module-level ``args``, snapshots the
    untrained network, then trains for ``args.epochs`` epochs, writing
    per-epoch rows to train.csv/test.csv and a checkpoint per epoch.
    Depends on the module-level ``parser`` and the project helpers
    ``define_tsnet``, ``adjust_lr``, ``train``, ``test``,
    ``save_checkpoint`` and ``transform_time``.
    """
    global args
    args = parser.parse_args()
    print(args)

    # Make sure the checkpoint directory exists before the first save.
    if not os.path.exists(os.path.join(args.save_root, 'checkpoint')):
        os.makedirs(os.path.join(args.save_root, 'checkpoint'))
    if args.cuda:
        cudnn.benchmark = True

    print('----------- Network Initialization --------------')
    net = define_tsnet(name=args.net_name, num_class=args.num_class, cuda=args.cuda)
    print('-----------------------------------------------')

    # save initial parameters (epoch 000) so distillation runs can reuse
    # the same random initialization
    print('saving initial parameters......')
    if args.net_name[:6] == 'resnet':
        save_name = 'baseline_r{}_{:>03}.ckp'.format(args.net_name[6:], 0)
    elif args.net_name[:6] == 'resnex':
        save_name = 'baseline_rx{}_{:>03}.ckp'.format(args.net_name[7:], 0)
    elif args.net_name[:6] == 'densen':
        save_name = 'baseline_dBC{}_{:>03}.ckp'.format(args.net_name[10:], 0)
    else:
        # BUG FIX: the original if/elif chain had no else branch, so an
        # unrecognized net name left `save_name` unbound and crashed with
        # a confusing NameError two statements later. Fail fast instead.
        raise Exception('invalid net name...')
    save_name = os.path.join(args.save_root, 'checkpoint', save_name)
    save_checkpoint({
        'epoch': 0,
        'net': net.state_dict(),
    }, save_name)

    # initialize optimizer
    optimizer = torch.optim.SGD(net.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define loss functions
    if args.cuda:
        criterion = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterion = torch.nn.CrossEntropyLoss()

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=train_transform,
                train=True, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=test_transform,
                train=False, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    trainF = open(os.path.join(args.save_log, 'train.csv'), 'w')
    testF = open(os.path.join(args.save_log, 'test.csv'), 'w')

    train_start_time = time.time()
    max_test_prec_1, max_test_prec_5 = 0, 0
    for epoch in range(1, args.epochs + 1):
        epoch_start_time = time.time()
        adjust_lr(optimizer, epoch)

        # train one epoch
        train(train_loader, net, optimizer, criterion, epoch, trainF)
        epoch_time = time.time() - epoch_start_time
        print('one epoch time is {:02}h{:02}m{:02}s'.format(*transform_time(epoch_time)))

        # evaluate on testing set
        print('testing the models......')
        test_start_time = time.time()
        test_prec_1, test_prec_5 = test(test_loader, net, criterion, testF)
        test_time = time.time() - test_start_time
        print('testing time is {:02}h{:02}m{:02}s'.format(*transform_time(test_time)))

        max_test_prec_1 = max(max_test_prec_1, test_prec_1)
        max_test_prec_5 = max(max_test_prec_5, test_prec_5)

        # save model
        # NOTE(review): this always uses the 'baseline_r{}' pattern even for
        # resnext/densenet names, unlike the epoch-0 snapshot above -- confirm
        # whether that asymmetry is intentional.
        print('saving models......')
        save_name = 'baseline_r{}_{:>03}.ckp'.format(args.net_name[6:], epoch)
        save_name = os.path.join(args.save_root, 'checkpoint', save_name)
        save_checkpoint({
            'epoch': epoch,
            'net': net.state_dict(),
        }, save_name)

    train_finish_time = time.time()
    # BUG FIX: was `train_start_time - train_finish_time`, which yielded a
    # negative duration and garbage h/m/s output in the summary line below.
    train_total_time = train_finish_time - train_start_time
    trainF.close()
    testF.close()
    print('the total train time is:{:02}h{:02}m{:02}s'.format(*transform_time(train_total_time)))
    print('the max test prec@1:{:.2f} , prec@5:{:.2f}'.format(max_test_prec_1, max_test_prec_5))
def main():
    """Multi-mode knowledge-distillation driver: pick a KD criterion by
    ``args.kd_mode``, train the student against a frozen teacher, and
    checkpoint the best result.

    Depends on module-level globals ``args``/``unparsed`` and the project
    helpers/criteria (``define_tsnet``, ``load_pretrained_model``,
    ``count_parameters_in_MB``, the KD criterion classes, ``train_init``,
    ``adjust_lr``, ``train``, ``test``, ``save_checkpoint``).
    """
    # Seed every RNG in play so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    snet = define_tsnet(name=args.s_name, num_class=args.num_class, cuda=args.cuda)
    ckpt = torch.load(args.s_init)
    load_pretrained_model(snet, ckpt['net'])
    logging.info('Student: %s', snet)
    logging.info('Student param size = %fMB', count_parameters_in_MB(snet))

    # The teacher is fixed: eval mode plus no gradients anywhere.
    tnet = define_tsnet(name=args.t_name, num_class=args.num_class, cuda=args.cuda)
    ckpt = torch.load(args.t_model)
    load_pretrained_model(tnet, ckpt['net'])
    tnet.eval()
    for p in tnet.parameters():
        p.requires_grad = False
    logging.info('Teacher: %s', tnet)
    logging.info('Teacher param size = %fMB', count_parameters_in_MB(tnet))
    logging.info('-----------------------------------------------')

    # define loss functions.  Single-criterion modes are dispatched through
    # a table of lazy factories (lambdas defer both construction and the
    # args-attribute lookups until the chosen mode is known, exactly as the
    # original if/elif chain did).
    kd_factories = {
        'logits': lambda: Logits(),
        'st': lambda: SoftTarget(args.T),
        'at': lambda: AT(args.p),
        'fitnet': lambda: Hint(),
        'nst': lambda: NST(),
        'pkt': lambda: PKTCosSim(),
        'fsp': lambda: FSP(),
        'rkd': lambda: RKD(args.w_dist, args.w_angle),
        'ab': lambda: AB(args.m),
        'sp': lambda: SP(),
        'sobolev': lambda: Sobolev(),
        'cc': lambda: CC(args.gamma, args.P_order),
        'lwm': lambda: LwM(),
        'irg': lambda: IRG(args.w_irg_vert, args.w_irg_edge, args.w_irg_tran),
    }
    if args.kd_mode in kd_factories:
        criterionKD = kd_factories[args.kd_mode]()
    elif args.kd_mode == 'vid':
        # VID learns one variational module per intermediate feature stage.
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = [VID(s_c, int(args.sf * t_c), t_c, args.init_var)
                       for s_c, t_c in zip(s_channels, t_channels)]
        if args.cuda:
            criterionKD = [c.cuda() for c in criterionKD]
        criterionKD = [None] + criterionKD  # None is a placeholder
    elif args.kd_mode == 'ofd':
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = [None] + [  # None is a placeholder
            OFD(s_c, t_c).cuda() if args.cuda else OFD(s_c, t_c)
            for s_c, t_c in zip(s_channels, t_channels)]
    elif args.kd_mode == 'afd':
        # t_channels is same with s_channels
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = [None] + [  # None is a placeholder
            AFD(t_c, args.att_f).cuda() if args.cuda else AFD(t_c, args.att_f)
            for t_c in t_channels]
    else:
        raise Exception('Invalid kd mode...')
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # initialize optimizer.  For modes whose KD criteria carry learnable
    # parameters, those parameters are trained jointly with the student
    # (index 0 of criterionKD is the None placeholder and is skipped).
    if args.kd_mode in ['vid', 'ofd', 'afd']:
        trainable = chain(snet.parameters(),
                          *[c.parameters() for c in criterionKD[1:]])
    else:
        trainable = snet.parameters()
    optimizer = torch.optim.SGD(trainable,
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define transforms
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        norm_mean = (0.4914, 0.4822, 0.4465)
        norm_std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        norm_mean = (0.5071, 0.4865, 0.4409)
        norm_std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=train_transform,
                train=True, download=True),
        batch_size=args.batch_size, shuffle=True,
        num_workers=4, pin_memory=True)
    test_loader = torch.utils.data.DataLoader(
        dataset(root=args.img_root, transform=test_transform,
                train=False, download=True),
        batch_size=args.batch_size, shuffle=False,
        num_workers=4, pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'snet': snet, 'tnet': tnet}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}

    # first initilizing the student nets: fsp/ab run a KD-only warm-up
    # stage, after which the KD term is disabled for the main stage
    if args.kd_mode in ['fsp', 'ab']:
        logging.info('The first stage, student initialization......')
        train_init(train_loader, nets, optimizer, criterions, 50)
        args.lambda_kd = 0.0
        logging.info('The second stage, softmax training......')

    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        tic = time.time()
        train(train_loader, nets, optimizer, criterions, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, nets, criterions, epoch)
        logging.info('Epoch time: {}s'.format(int(time.time() - tic)))

        # save model -- best-so-far is tracked by top-1 accuracy only
        is_best = test_top1 > best_top1
        if is_best:
            best_top1 = test_top1
            best_top5 = test_top5
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
                'prec@1': test_top1,
                'prec@5': test_top5,
            }, is_best, args.save_root)