# Beispiel #1 (example 1)
# 0
def main() -> None:
    """Train a student network by BSS adversarial knowledge distillation.

    Relies on module-level globals defined elsewhere in this file:
    ``args``/``unparsed`` (parsed CLI options) and the helpers
    ``define_tsnet``, ``load_pretrained_model``, ``count_parameters_in_MB``,
    ``adjust_lr``, ``train``, ``test`` and ``save_checkpoint``.
    """
    # Seed every RNG source so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        # Autotune cuDNN conv algorithms; safe because input sizes are fixed (32x32).
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    # Student network, warm-started from the saved initial checkpoint.
    snet = define_tsnet(name=args.s_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.s_init)
    load_pretrained_model(snet, checkpoint['net'])
    logging.info('Student: %s', snet)
    logging.info('Student param size = %fMB', count_parameters_in_MB(snet))

    # Teacher network: load pretrained weights, then freeze it so only the
    # student is updated by the optimizer.
    tnet = define_tsnet(name=args.t_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.t_model)
    load_pretrained_model(tnet, checkpoint['net'])
    tnet.eval()
    for param in tnet.parameters():
        param.requires_grad = False
    logging.info('Teacher: %s', tnet)
    logging.info('Teacher param size = %fMB', count_parameters_in_MB(tnet))
    logging.info('-----------------------------------------------')

    # initialize optimizer (student parameters only; teacher is frozen)
    optimizer = torch.optim.SGD(snet.parameters(),
                                lr=args.lr,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=True)

    # define attacker used to craft boundary-supporting adversarial samples
    attacker = BSSAttacker(step_alpha=0.3, num_steps=10, eps=1e-4)

    # define loss functions: BSS distillation loss + standard cross-entropy
    criterionKD = BSS(args.T)
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # define transforms (per-dataset channel statistics)
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    # Standard CIFAR augmentation: reflect-pad to 40x40, random 32x32 crop, flip.
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'snet': snet, 'tnet': tnet}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}

    # Track the best top-1 so far; top-5 recorded from the same epoch.
    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        epoch_start_time = time.time()
        train(train_loader, nets, optimizer, criterions, attacker, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, nets, criterions, epoch)

        epoch_duration = time.time() - epoch_start_time
        logging.info('Epoch time: {}s'.format(int(epoch_duration)))

        # save model (checkpoint written every epoch; is_best flags new record)
        is_best = False
        if test_top1 > best_top1:
            best_top1 = test_top1
            best_top5 = test_top5
            is_best = True
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
                'prec@1': test_top1,
                'prec@5': test_top5,
            }, is_best, args.save_root)
def main() -> None:
	"""Train a student network with FitNet (hint-based) knowledge distillation.

	Parses CLI options into the module-level ``args`` and relies on helpers
	defined elsewhere in this file (``define_tsnet``, ``load_pretrained_model``,
	``adjust_lr``, ``train``, ``test``, ``save_checkpoint``, ``transform_time``).
	"""
	global args
	args = parser.parse_args()
	print(args)

	# Ensure the checkpoint directory exists before the first save.
	if not os.path.exists(os.path.join(args.save_root,'checkpoint')):
		os.makedirs(os.path.join(args.save_root,'checkpoint'))

	if args.cuda:
		# Autotune cuDNN conv algorithms; input sizes are fixed (32x32).
		cudnn.benchmark = True

	print('----------- Network Initialization --------------')
	# Student network, warm-started from the saved initial checkpoint.
	snet = define_tsnet(name=args.s_name, num_class=args.num_class, cuda=args.cuda)
	checkpoint = torch.load(args.s_init)
	load_pretrained_model(snet, checkpoint['net'])

	# Teacher network: load pretrained weights, then freeze it.
	tnet = define_tsnet(name=args.t_name, num_class=args.num_class, cuda=args.cuda)
	checkpoint = torch.load(args.t_model)
	load_pretrained_model(tnet, checkpoint['net'])
	tnet.eval()
	for param in tnet.parameters():
		param.requires_grad = False
	print('-----------------------------------------------')

	# initialize optimizer (student parameters only; teacher is frozen)
	optimizer = torch.optim.SGD(snet.parameters(),
								lr = args.lr, 
								momentum = args.momentum, 
								weight_decay = args.weight_decay,
								nesterov = True)

	# define loss functions: cross-entropy + MSE hint loss on features (FitNet)
	if args.cuda:
		criterionCls    = torch.nn.CrossEntropyLoss().cuda()
		criterionFitnet = torch.nn.MSELoss().cuda()
	else:
		criterionCls    = torch.nn.CrossEntropyLoss()
		criterionFitnet = torch.nn.MSELoss()

	# define transforms (per-dataset channel statistics)
	if args.data_name == 'cifar10':
		dataset = dst.CIFAR10
		mean = (0.4914, 0.4822, 0.4465)
		std  = (0.2470, 0.2435, 0.2616)
	elif args.data_name == 'cifar100':
		dataset = dst.CIFAR100
		mean = (0.5071, 0.4865, 0.4409)
		std  = (0.2673, 0.2564, 0.2762)
	else:
		raise Exception('invalid dataset name...')

	# Standard CIFAR augmentation: reflect-pad, random crop, horizontal flip.
	train_transform = transforms.Compose([
			transforms.Pad(4, padding_mode='reflect'),
			transforms.RandomCrop(32),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize(mean=mean,std=std)
		])
	test_transform = transforms.Compose([
			transforms.CenterCrop(32),
			transforms.ToTensor(),
			transforms.Normalize(mean=mean,std=std)
		])

	# define data loader
	train_loader = torch.utils.data.DataLoader(
			dataset(root      = args.img_root,
					transform = train_transform,
					train     = True,
					download  = True),
			batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
	test_loader = torch.utils.data.DataLoader(
			dataset(root      = args.img_root,
					transform = test_transform,
					train     = False,
					download  = True),
			batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

	for epoch in range(1, args.epochs+1):
		epoch_start_time = time.time()

		adjust_lr(optimizer, epoch)

		# train one epoch
		nets = {'snet':snet, 'tnet':tnet}
		criterions = {'criterionCls':criterionCls, 'criterionFitnet':criterionFitnet}
		train(train_loader, nets, optimizer, criterions, epoch)
		epoch_time = time.time() - epoch_start_time
		print('one epoch time is {:02}h{:02}m{:02}s'.format(*transform_time(epoch_time)))

		# evaluate on testing set
		print('testing the models......')
		test_start_time = time.time()
		test(test_loader, nets, criterions)
		test_time = time.time() - test_start_time
		print('testing time is {:02}h{:02}m{:02}s'.format(*transform_time(test_time)))

		# save model — teacher weights are only stored in the first epoch's
		# checkpoint (they never change); later epochs store the student only.
		print('saving models......')
		save_name = 'fitnet_r{}_r{}_{:>03}.ckp'.format(args.t_name[6:], args.s_name[6:], epoch)
		save_name = os.path.join(args.save_root, 'checkpoint', save_name)
		if epoch == 1:
			save_checkpoint({
				'epoch': epoch,
				'snet': snet.state_dict(),
				'tnet': tnet.state_dict(),
			}, save_name)
		else:
			save_checkpoint({
				'epoch': epoch,
				'snet': snet.state_dict(),
			}, save_name)
# Beispiel #3 (example 3)
# 0
def main() -> None:
    """Train two peer networks jointly via Deep Mutual Learning (DML).

    Unlike teacher/student distillation, both networks are trained and each
    gets its own SGD optimizer. Relies on module-level globals (``args``,
    ``unparsed``) and helpers (``define_tsnet``, ``load_pretrained_model``,
    ``count_parameters_in_MB``, ``adjust_lr``, ``train``, ``test``,
    ``save_checkpoint``) defined elsewhere in this file.
    """
    # Seed every RNG source so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        # Autotune cuDNN conv algorithms; input sizes are fixed (32x32).
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    # First peer network, warm-started from its initial checkpoint.
    net1 = define_tsnet(name=args.net1_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.net1_init)
    load_pretrained_model(net1, checkpoint['net'])
    logging.info('Net1: %s', net1)
    logging.info('Net1 param size = %fMB', count_parameters_in_MB(net1))

    # Second peer network; both peers remain trainable (no freezing).
    net2 = define_tsnet(name=args.net2_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.net2_init)
    load_pretrained_model(net2, checkpoint['net'])
    logging.info('Net2: %s', net2)
    logging.info('Net2 param size = %fMB', count_parameters_in_MB(net2))
    logging.info('-----------------------------------------------')

    # initialize optimizer (one SGD optimizer per peer)
    optimizer1 = torch.optim.SGD(net1.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)
    optimizer2 = torch.optim.SGD(net2.parameters(),
                                 lr=args.lr,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=True)

    # define loss functions: mutual KL term (DML) + standard cross-entropy
    criterionKD = DML()
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # define transforms (per-dataset channel statistics)
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    # Standard CIFAR augmentation: reflect-pad, random crop, horizontal flip.
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'net1': net1, 'net2': net2}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}
    optimizers = {'optimizer1': optimizer1, 'optimizer2': optimizer2}

    # Best accuracy tracked over BOTH peers (max of the two per epoch).
    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizers, epoch)

        # train one epoch
        epoch_start_time = time.time()
        train(train_loader, nets, optimizers, criterions, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        # top-1/top-5 accuracy for net1 (test_top11/15) and net2 (test_top21/25)
        test_top11, test_top15, test_top21, test_top25 = test(
            test_loader, nets, criterions)

        epoch_duration = time.time() - epoch_start_time
        logging.info('Epoch time: {}s'.format(int(epoch_duration)))

        # save model (checkpoint every epoch; is_best flags a new joint record)
        is_best = False
        if max(test_top11, test_top21) > best_top1:
            best_top1 = max(test_top11, test_top21)
            best_top5 = max(test_top15, test_top25)
            is_best = True
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'net1': net1.state_dict(),
                'net2': net2.state_dict(),
                'prec1@1': test_top11,
                'prec1@5': test_top15,
                'prec2@1': test_top21,
                'prec2@5': test_top25,
            }, is_best, args.save_root)
# Beispiel #4 (example 4)
# 0
def main() -> None:
	"""Train a single baseline network (no distillation) on CIFAR.

	Also saves the freshly-initialized weights up front so later distillation
	runs can start students from the same initialization. Relies on
	module-level globals (``args``, ``unparsed``) and helpers
	(``define_tsnet``, ``count_parameters_in_MB``, ``adjust_lr``, ``train``,
	``test``, ``save_checkpoint``) defined elsewhere in this file.
	"""
	# Seed every RNG source so runs are reproducible.
	np.random.seed(args.seed)
	torch.manual_seed(args.seed)
	if args.cuda:
		torch.cuda.manual_seed(args.seed)
		cudnn.enabled = True
		# Autotune cuDNN conv algorithms; input sizes are fixed (32x32).
		cudnn.benchmark = True
	logging.info("args = %s", args)
	logging.info("unparsed_args = %s", unparsed)

	logging.info('----------- Network Initialization --------------')
	net = define_tsnet(name=args.net_name, num_class=args.num_class, cuda=args.cuda)
	logging.info('%s', net)
	logging.info("param size = %fMB", count_parameters_in_MB(net))
	logging.info('-----------------------------------------------')

	# save initial parameters (epoch 0 snapshot, zero accuracy placeholders)
	logging.info('Saving initial parameters......') 
	# NOTE(review): assumes net_name starts with 'resnet' — [6:] strips that prefix.
	save_path = os.path.join(args.save_root, 'initial_r{}.pth.tar'.format(args.net_name[6:]))
	torch.save({
		'epoch': 0,
		'net': net.state_dict(),
		'prec@1': 0.0,
		'prec@5': 0.0,
	}, save_path)

	# initialize optimizer
	optimizer = torch.optim.SGD(net.parameters(),
								lr = args.lr, 
								momentum = args.momentum, 
								weight_decay = args.weight_decay,
								nesterov = True)

	# define loss functions
	if args.cuda:
		criterion = torch.nn.CrossEntropyLoss().cuda()
	else:
		criterion = torch.nn.CrossEntropyLoss()

	# define transforms (per-dataset channel statistics)
	if args.data_name == 'cifar10':
		dataset = dst.CIFAR10
		mean = (0.4914, 0.4822, 0.4465)
		std  = (0.2470, 0.2435, 0.2616)
	elif args.data_name == 'cifar100':
		dataset = dst.CIFAR100
		mean = (0.5071, 0.4865, 0.4409)
		std  = (0.2673, 0.2564, 0.2762)
	else:
		raise Exception('Invalid dataset name...')

	# Standard CIFAR augmentation: reflect-pad, random crop, horizontal flip.
	train_transform = transforms.Compose([
			transforms.Pad(4, padding_mode='reflect'),
			transforms.RandomCrop(32),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize(mean=mean,std=std)
		])
	test_transform = transforms.Compose([
			transforms.CenterCrop(32),
			transforms.ToTensor(),
			transforms.Normalize(mean=mean,std=std)
		])

	# define data loader
	train_loader = torch.utils.data.DataLoader(
			dataset(root      = args.img_root,
					transform = train_transform,
					train     = True,
					download  = True),
			batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
	test_loader = torch.utils.data.DataLoader(
			dataset(root      = args.img_root,
					transform = test_transform,
					train     = False,
					download  = True),
			batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

	# Track the best top-1 so far; top-5 recorded from the same epoch.
	best_top1 = 0
	best_top5 = 0
	for epoch in range(1, args.epochs+1):
		adjust_lr(optimizer, epoch)

		# train one epoch
		epoch_start_time = time.time()
		train(train_loader, net, optimizer, criterion, epoch)

		# evaluate on testing set
		logging.info('Testing the models......')
		test_top1, test_top5 = test(test_loader, net, criterion)

		epoch_duration = time.time() - epoch_start_time
		logging.info('Epoch time: {}s'.format(int(epoch_duration)))

		# save model (checkpoint every epoch; is_best flags new record)
		is_best = False
		if test_top1 > best_top1:
			best_top1 = test_top1
			best_top5 = test_top5
			is_best = True
		logging.info('Saving models......')
		save_checkpoint({
			'epoch': epoch,
			'net': net.state_dict(),
			'prec@1': test_top1,
			'prec@5': test_top5,
		}, is_best, args.save_root)
# Beispiel #5 (example 5)
# 0
def main() -> None:
	"""Train a single baseline network and log per-epoch metrics to CSV.

	Fixes relative to the previous revision:
	  * total training time was computed as ``start - finish`` (always
	    negative); now computed as ``finish - start``,
	  * an unrecognized ``net_name`` previously left ``save_name`` unbound
	    and crashed with ``NameError``; now raises a clear ``ValueError``,
	  * the CSV log files are closed in a ``finally`` block so they are not
	    leaked when training raises.

	Relies on module-level globals/helpers defined elsewhere in this file:
	``parser``, ``define_tsnet``, ``adjust_lr``, ``train``, ``test``,
	``save_checkpoint`` and ``transform_time``.
	"""
	global args
	args = parser.parse_args()
	print(args)

	# Ensure the checkpoint directory exists before the first save.
	if not os.path.exists(os.path.join(args.save_root, 'checkpoint')):
		os.makedirs(os.path.join(args.save_root, 'checkpoint'))

	if args.cuda:
		# Autotune cuDNN conv algorithms; input sizes are fixed (32x32).
		cudnn.benchmark = True

	print('----------- Network Initialization --------------')
	net = define_tsnet(name=args.net_name, num_class=args.num_class, cuda=args.cuda)
	print('-----------------------------------------------')

	# save initial parameters (epoch 0 snapshot, reused by distillation runs)
	print('saving initial parameters......')

	if args.net_name[:6] == 'resnet':
		save_name = 'baseline_r{}_{:>03}.ckp'.format(args.net_name[6:], 0)
	elif args.net_name[:6] == 'resnex':
		save_name = 'baseline_rx{}_{:>03}.ckp'.format(args.net_name[7:], 0)
	elif args.net_name[:6] == 'densen':
		save_name = 'baseline_dBC{}_{:>03}.ckp'.format(args.net_name[10:], 0)
	else:
		# Previously fell through with save_name unbound -> NameError below.
		raise ValueError('invalid net name: {}'.format(args.net_name))

	save_name = os.path.join(args.save_root, 'checkpoint', save_name)
	save_checkpoint({
		'epoch': 0,
		'net': net.state_dict(),
	}, save_name)

	# initialize optimizer
	optimizer = torch.optim.SGD(net.parameters(),
								lr = args.lr,
								momentum = args.momentum,
								weight_decay = args.weight_decay,
								nesterov = True)

	# define loss functions
	if args.cuda:
		criterion = torch.nn.CrossEntropyLoss().cuda()
	else:
		criterion = torch.nn.CrossEntropyLoss()

	# define transforms (per-dataset channel statistics)
	if args.data_name == 'cifar10':
		dataset = dst.CIFAR10
		mean = (0.4914, 0.4822, 0.4465)
		std  = (0.2470, 0.2435, 0.2616)
	elif args.data_name == 'cifar100':
		dataset = dst.CIFAR100
		mean = (0.5071, 0.4865, 0.4409)
		std  = (0.2673, 0.2564, 0.2762)
	else:
		raise Exception('invalid dataset name...')

	# Standard CIFAR augmentation: reflect-pad, random crop, horizontal flip.
	train_transform = transforms.Compose([
			transforms.Pad(4, padding_mode='reflect'),
			transforms.RandomCrop(32),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize(mean=mean,std=std)
		])
	test_transform = transforms.Compose([
			transforms.CenterCrop(32),
			transforms.ToTensor(),
			transforms.Normalize(mean=mean,std=std)
		])

	# define data loader
	train_loader = torch.utils.data.DataLoader(
			dataset(root      = args.img_root,
					transform = train_transform,
					train     = True,
					download  = True),
			batch_size=args.batch_size, shuffle=True, num_workers=4, pin_memory=True)
	test_loader = torch.utils.data.DataLoader(
			dataset(root      = args.img_root,
					transform = test_transform,
					train     = False,
					download  = True),
			batch_size=args.batch_size, shuffle=False, num_workers=4, pin_memory=True)

	# Per-epoch metrics are appended to these CSV files by train()/test().
	trainF = open(os.path.join(args.save_log, 'train.csv'), 'w')
	testF = open(os.path.join(args.save_log, 'test.csv'), 'w')

	train_start_time = time.time()

	max_test_prec_1, max_test_prec_5 = 0, 0

	try:
		for epoch in range(1, args.epochs+1):
			epoch_start_time = time.time()

			adjust_lr(optimizer, epoch)

			# train one epoch
			train(train_loader, net, optimizer, criterion, epoch, trainF)
			epoch_time = time.time() - epoch_start_time
			print('one epoch time is {:02}h{:02}m{:02}s'.format(*transform_time(epoch_time)))

			# evaluate on testing set
			print('testing the models......')
			test_start_time = time.time()
			test_prec_1, test_prec_5 = test(test_loader, net, criterion, testF)
			test_time = time.time() - test_start_time
			print('testing time is {:02}h{:02}m{:02}s'.format(*transform_time(test_time)))

			max_test_prec_1 = max(max_test_prec_1, test_prec_1)
			max_test_prec_5 = max(max_test_prec_5, test_prec_5)

			# save model
			print('saving models......')
			# NOTE(review): epoch checkpoints always use the 'resnet' naming
			# scheme even for resnext/densenet — kept as-is since downstream
			# scripts may rely on these exact filenames; confirm before changing.
			save_name = 'baseline_r{}_{:>03}.ckp'.format(args.net_name[6:], epoch)
			save_name = os.path.join(args.save_root, 'checkpoint', save_name)
			save_checkpoint({
				'epoch': epoch,
				'net': net.state_dict(),
			}, save_name)
	finally:
		trainF.close()
		testF.close()

	train_finish_time = time.time()
	# Fixed: was start - finish, which always printed a negative duration.
	train_total_time = train_finish_time - train_start_time

	print('the total train time is:{:02}h{:02}m{:02}s'.format(*transform_time(train_total_time)))
	print('the max test prec@1:{:.2f} , prec@5:{:.2f}'.format(max_test_prec_1, max_test_prec_5))
def main() -> None:
    """Train a student via the distillation method selected by ``args.kd_mode``.

    Dispatches over many knowledge-distillation losses (logits, soft target,
    attention transfer, FitNet, NST, PKT, FSP, RKD, AB, SP, Sobolev, CC, LwM,
    IRG, VID, OFD, AFD). For modes whose criterion has learnable parameters
    (vid/ofd/afd) those parameters are optimized jointly with the student.
    Relies on module-level globals (``args``, ``unparsed``) and helpers
    (``define_tsnet``, ``load_pretrained_model``, ``count_parameters_in_MB``,
    ``adjust_lr``, ``train_init``, ``train``, ``test``, ``save_checkpoint``)
    defined elsewhere in this file.
    """
    # Seed every RNG source so runs are reproducible.
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.cuda:
        torch.cuda.manual_seed(args.seed)
        cudnn.enabled = True
        # Autotune cuDNN conv algorithms; input sizes are fixed (32x32).
        cudnn.benchmark = True
    logging.info("args = %s", args)
    logging.info("unparsed_args = %s", unparsed)

    logging.info('----------- Network Initialization --------------')
    # Student network, warm-started from the saved initial checkpoint.
    snet = define_tsnet(name=args.s_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.s_init)
    load_pretrained_model(snet, checkpoint['net'])
    logging.info('Student: %s', snet)
    logging.info('Student param size = %fMB', count_parameters_in_MB(snet))

    # Teacher network: load pretrained weights, then freeze it.
    tnet = define_tsnet(name=args.t_name,
                        num_class=args.num_class,
                        cuda=args.cuda)
    checkpoint = torch.load(args.t_model)
    load_pretrained_model(tnet, checkpoint['net'])
    tnet.eval()
    for param in tnet.parameters():
        param.requires_grad = False
    logging.info('Teacher: %s', tnet)
    logging.info('Teacher param size = %fMB', count_parameters_in_MB(tnet))
    logging.info('-----------------------------------------------')

    # define loss functions — one KD criterion per supported mode
    if args.kd_mode == 'logits':
        criterionKD = Logits()
    elif args.kd_mode == 'st':
        criterionKD = SoftTarget(args.T)
    elif args.kd_mode == 'at':
        criterionKD = AT(args.p)
    elif args.kd_mode == 'fitnet':
        criterionKD = Hint()
    elif args.kd_mode == 'nst':
        criterionKD = NST()
    elif args.kd_mode == 'pkt':
        criterionKD = PKTCosSim()
    elif args.kd_mode == 'fsp':
        criterionKD = FSP()
    elif args.kd_mode == 'rkd':
        criterionKD = RKD(args.w_dist, args.w_angle)
    elif args.kd_mode == 'ab':
        criterionKD = AB(args.m)
    elif args.kd_mode == 'sp':
        criterionKD = SP()
    elif args.kd_mode == 'sobolev':
        criterionKD = Sobolev()
    elif args.kd_mode == 'cc':
        criterionKD = CC(args.gamma, args.P_order)
    elif args.kd_mode == 'lwm':
        criterionKD = LwM()
    elif args.kd_mode == 'irg':
        criterionKD = IRG(args.w_irg_vert, args.w_irg_edge, args.w_irg_tran)
    elif args.kd_mode == 'vid':
        # VID needs one criterion per intermediate stage, built from the
        # channel counts of stages 1..3 of both networks.
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = []
        for s_c, t_c in zip(s_channels, t_channels):
            criterionKD.append(VID(s_c, int(args.sf * t_c), t_c,
                                   args.init_var))
        criterionKD = [c.cuda()
                       for c in criterionKD] if args.cuda else criterionKD
        criterionKD = [None] + criterionKD  # None is a placeholder
    elif args.kd_mode == 'ofd':
        # OFD likewise builds one criterion per intermediate stage.
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = []
        for s_c, t_c in zip(s_channels, t_channels):
            criterionKD.append(
                OFD(s_c, t_c).cuda() if args.cuda else OFD(s_c, t_c))
        criterionKD = [None] + criterionKD  # None is a placeholder
    elif args.kd_mode == 'afd':
        # t_channels is same with s_channels
        s_channels = snet.module.get_channel_num()[1:4]
        t_channels = tnet.module.get_channel_num()[1:4]
        criterionKD = []
        for t_c in t_channels:
            criterionKD.append(
                AFD(t_c, args.att_f).cuda() if args.
                cuda else AFD(t_c, args.att_f))
        criterionKD = [None] + criterionKD  # None is a placeholder
        # # t_chws is same with s_chws
        # s_chws = snet.module.get_chw_num()[1:4]
        # t_chws = tnet.module.get_chw_num()[1:4]
        # criterionKD = []
        # for t_chw in t_chws:
        # 	criterionKD.append(AFD(t_chw).cuda() if args.cuda else AFD(t_chw))
        # criterionKD = [None] + criterionKD # None is a placeholder
    else:
        raise Exception('Invalid kd mode...')
    if args.cuda:
        criterionCls = torch.nn.CrossEntropyLoss().cuda()
    else:
        criterionCls = torch.nn.CrossEntropyLoss()

    # initialize optimizer — vid/ofd/afd criteria have trainable parameters,
    # so they are chained into the same SGD optimizer as the student.
    if args.kd_mode in ['vid', 'ofd', 'afd']:
        optimizer = torch.optim.SGD(chain(
            snet.parameters(), *[c.parameters() for c in criterionKD[1:]]),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)
    else:
        optimizer = torch.optim.SGD(snet.parameters(),
                                    lr=args.lr,
                                    momentum=args.momentum,
                                    weight_decay=args.weight_decay,
                                    nesterov=True)

    # define transforms (per-dataset channel statistics)
    if args.data_name == 'cifar10':
        dataset = dst.CIFAR10
        mean = (0.4914, 0.4822, 0.4465)
        std = (0.2470, 0.2435, 0.2616)
    elif args.data_name == 'cifar100':
        dataset = dst.CIFAR100
        mean = (0.5071, 0.4865, 0.4409)
        std = (0.2673, 0.2564, 0.2762)
    else:
        raise Exception('Invalid dataset name...')

    # Standard CIFAR augmentation: reflect-pad, random crop, horizontal flip.
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])
    test_transform = transforms.Compose([
        transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std)
    ])

    # define data loader
    train_loader = torch.utils.data.DataLoader(dataset(
        root=args.img_root,
        transform=train_transform,
        train=True,
        download=True),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=4,
                                               pin_memory=True)
    test_loader = torch.utils.data.DataLoader(dataset(root=args.img_root,
                                                      transform=test_transform,
                                                      train=False,
                                                      download=True),
                                              batch_size=args.batch_size,
                                              shuffle=False,
                                              num_workers=4,
                                              pin_memory=True)

    # warp nets and criterions for train and test
    nets = {'snet': snet, 'tnet': tnet}
    criterions = {'criterionCls': criterionCls, 'criterionKD': criterionKD}

    # first initilizing the student nets — fsp/ab require a pre-training
    # stage, after which the KD weight is zeroed for the softmax stage.
    if args.kd_mode in ['fsp', 'ab']:
        logging.info('The first stage, student initialization......')
        train_init(train_loader, nets, optimizer, criterions, 50)
        args.lambda_kd = 0.0
        logging.info('The second stage, softmax training......')

    # Track the best top-1 so far; top-5 recorded from the same epoch.
    best_top1 = 0
    best_top5 = 0
    for epoch in range(1, args.epochs + 1):
        adjust_lr(optimizer, epoch)

        # train one epoch
        epoch_start_time = time.time()
        train(train_loader, nets, optimizer, criterions, epoch)

        # evaluate on testing set
        logging.info('Testing the models......')
        test_top1, test_top5 = test(test_loader, nets, criterions, epoch)

        epoch_duration = time.time() - epoch_start_time
        logging.info('Epoch time: {}s'.format(int(epoch_duration)))

        # save model (checkpoint every epoch; is_best flags new record)
        is_best = False
        if test_top1 > best_top1:
            best_top1 = test_top1
            best_top5 = test_top5
            is_best = True
        logging.info('Saving models......')
        save_checkpoint(
            {
                'epoch': epoch,
                'snet': snet.state_dict(),
                'tnet': tnet.state_dict(),
                'prec@1': test_top1,
                'prec@5': test_top5,
            }, is_best, args.save_root)