def get_criterion(dataset, teacher, num_classes, gpu, mixup, label_smoothing):
    """Build the (distillation criterion, CE criterion, teacher model) triple used for training."""
    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()

    criterion_kd = None
    if mixup > 0.0:
        criterion_kd = NLLMultiLabelSmooth(label_smoothing)
    elif label_smoothing > 0.0:
        criterion_kd = LabelSmoothing(label_smoothing)
    if criterion_kd:
        # smoothing-based training: no teacher model is attached
        criterion_kd = criterion_kd.cuda()
        return criterion_kd, criterion, None
    # criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    # criterion_smooth = criterion_smooth.cuda()

    if teacher == 'none':
        return criterion, criterion, None

    if dataset == 'imagenet':
        # load model
        model_teacher = models.__dict__[teacher](pretrained=True)
        model_teacher = model_teacher.cuda()
        model_teacher = DDP(model_teacher, device_ids=[gpu])
        for p in model_teacher.parameters():
            p.requires_grad = False
        model_teacher.eval()
        criterion_kd = KD_loss.DistributionLoss()
        return criterion_kd, criterion, model_teacher
    else:
        model_teacher = get_model(teacher, num_classes=num_classes, num_channels=96)
        model_teacher = model_teacher.cuda()
        if args.distributed:
            model_teacher = DDP(model_teacher, device_ids=[gpu])
        checkpoint_tar = 'teacher.pth.tar'
        print('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(
            checkpoint_tar,
            map_location=lambda storage, loc: storage.cuda(args.gpu))
        model_teacher.load_state_dict(checkpoint['state_dict'], strict=False)
        print("loaded checkpoint {}".format(checkpoint_tar))
        for p in model_teacher.parameters():
            p.requires_grad = False
        model_teacher.eval()
        criterion_kd = KD_loss.DistributionLoss()
        return criterion_kd, criterion, model_teacher
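
# KD_loss.DistributionLoss is imported from the repo's KD_loss module and is not
# defined in this file. A minimal sketch of a distribution-matching loss of this
# kind (an assumption for reference, not the repo's exact implementation): the soft
# cross-entropy between the teacher's softmax distribution and the student's
# log-softmax, averaged over the batch.
import torch
import torch.nn as nn
import torch.nn.functional as F


class DistributionLossSketch(nn.Module):
    """Hypothetical stand-in for KD_loss.DistributionLoss: -sum_c p_teacher(c) * log p_student(c)."""

    def forward(self, student_logits, teacher_logits):
        # The teacher is frozen, so its logits should not carry gradients.
        teacher_prob = F.softmax(teacher_logits.detach(), dim=1)
        student_log_prob = F.log_softmax(student_logits, dim=1)
        # Per-sample soft cross-entropy, averaged over the batch.
        return -(teacher_prob * student_log_prob).sum(dim=1).mean()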
def main():
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # load model
    model_teacher = models.__dict__[args.teacher](pretrained=True)
    model_teacher = nn.DataParallel(model_teacher).cuda()
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()

    model_student = birealnet18()
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    criterion_kd = KD_loss.DistributionLoss()

    # apply weight decay only to the convolution-style weight parameters
    all_parameters = model_student.parameters()
    weight_parameters = []
    for pname, p in model_student.named_parameters():
        if p.ndimension() == 4 or 'conv' in pname:
            weight_parameters.append(p)
    weight_parameters_id = list(map(id, weight_parameters))
    other_parameters = list(
        filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

    optimizer = torch.optim.Adam(
        [{'params': other_parameters},
         {'params': weight_parameters, 'weight_decay': args.weight_decay}],
        lr=args.learning_rate,
    )

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: (1.0 - step / args.epochs), last_epoch=-1)
    start_epoch = 0
    best_top1_acc = 0

    # checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar')
    # checkpoint = torch.load(checkpoint_tar)
    # model_student.load_state_dict(checkpoint['state_dict'], strict=False)

    checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch']
        best_top1_acc = checkpoint['best_top1_acc']
        model_student.load_state_dict(checkpoint['state_dict'], strict=False)
        logging.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_tar, checkpoint['epoch']))

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load training data
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # data augmentation
    crop_scale = 0.08
    lighting_param = 0.1
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(crop_scale, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

    train_dataset = datasets.ImageFolder(traindir, transform=train_transforms)
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.workers,
                                               pin_memory=True)

    # load validation data
    val_loader = torch.utils.data.DataLoader(
        datasets.ImageFolder(valdir, transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])),
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.workers,
        pin_memory=True)

    # evaluate the (optionally resumed) student model
    valid_obj, valid_top1_acc, valid_top5_acc = validate(
        start_epoch, val_loader, model_student, criterion, args)
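
# CrossEntropyLabelSmooth is defined elsewhere in the repo. A minimal sketch of the
# usual label-smoothing cross-entropy it presumably implements (an assumption, shown
# only for reference): put (1 - epsilon) on the true class and spread epsilon
# uniformly over all classes before taking the cross-entropy.
import torch
import torch.nn as nn


class CrossEntropyLabelSmoothSketch(nn.Module):
    """Hypothetical sketch of CrossEntropyLabelSmooth(num_classes, epsilon)."""

    def __init__(self, num_classes, epsilon):
        super().__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        log_probs = self.logsoftmax(inputs)
        # One-hot targets smoothed towards the uniform distribution.
        one_hot = torch.zeros_like(log_probs).scatter_(1, targets.unsqueeze(1), 1)
        smoothed = (1 - self.epsilon) * one_hot + self.epsilon / self.num_classes
        return (-smoothed * log_probs).sum(dim=1).mean()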
def main():
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # load model
    # model_teacher = models.__dict__[args.teacher](pretrained=True)
    model_teacher = cifar_resnet56(pretrained="cifar100")
    model_teacher = nn.DataParallel(model_teacher).cuda()
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()

    model_student = reactnet(100)
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    criterion_kd = KD_loss.DistributionLoss()

    all_parameters = model_student.parameters()
    weight_parameters = []
    for pname, p in model_student.named_parameters():
        if p.ndimension() == 4 or 'conv' in pname:
            weight_parameters.append(p)
    weight_parameters_id = list(map(id, weight_parameters))
    other_parameters = list(
        filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

    optimizer = torch.optim.Adam(
        [{'params': other_parameters},
         {'params': weight_parameters, 'weight_decay': args.weight_decay}],
        lr=args.learning_rate,
    )

    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: (1.0 - step / args.epochs), last_epoch=-1)
    start_epoch = 0
    best_top1_acc = 0

    # checkpoint_tar = os.path.join(args.save, 'checkpoint_ba.pth.tar')
    # checkpoint = torch.load(checkpoint_tar)
    # model_student.load_state_dict(checkpoint['state_dict'], strict=False)

    checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch'] + 1
        best_top1_acc = checkpoint['best_top1_acc']
        model_student.load_state_dict(checkpoint['state_dict'], strict=False)
        logging.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_tar, checkpoint['epoch']))

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load training data
    # traindir = os.path.join(args.data, 'train')
    # valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # data augmentation (ImageNet-style pipeline, unused for the CIFAR-100 loaders below)
    crop_scale = 0.08
    lighting_param = 0.1
    train_transforms = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(crop_scale, 1.0)),
        Lighting(lighting_param),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize
    ])

    cifar100_train = datasets.CIFAR100(
        "../data", train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(size=32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5071, 0.4865, 0.4409],
                std=[0.2009, 0.1984, 0.2023],
            ),
        ]))
    cifar100_eval = datasets.CIFAR100(
        "../data", train=False, download=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.5071, 0.4865, 0.4409],
                std=[0.2009, 0.1984, 0.2023],
            ),
        ]))

    train_loader = DataLoader(
        cifar100_train,
        batch_size=64,
        shuffle=True,
    )
    val_loader = DataLoader(
        cifar100_eval,
        batch_size=64,
        shuffle=False,  # no need to shuffle the validation set
    )

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:
        train_obj, train_top1_acc, train_top5_acc = train(
            epoch, train_loader, model_student, model_teacher, criterion,
            optimizer, scheduler)
        valid_obj, valid_top1_acc, valid_top5_acc = validate(
            epoch, val_loader, model_student, criterion, args)

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            is_best = True

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model_student.state_dict(),
                'best_top1_acc': best_top1_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        epoch += 1

    training_time = (time.time() - start_t) / 3600  # seconds -> hours
    valid_obj, valid_top1_acc, valid_top5_acc = validate(
        epoch, val_loader, model_student, criterion, args)
    print('total training time = {} hours'.format(training_time))
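
# save_checkpoint is imported from the repo's utilities and not defined in this file.
# A minimal sketch of the conventional implementation (an assumption): write the
# checkpoint to disk and copy it to model_best.pth.tar when it is the best so far.
import os
import shutil
import torch


def save_checkpoint_sketch(state, is_best, save_dir, filename='checkpoint.pth.tar'):
    """Hypothetical stand-in for save_checkpoint(state, is_best, save)."""
    os.makedirs(save_dir, exist_ok=True)
    path = os.path.join(save_dir, filename)
    torch.save(state, path)
    if is_best:
        # Keep a separate copy of the best-performing checkpoint.
        shutil.copyfile(path, os.path.join(save_dir, 'model_best.pth.tar'))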
def main():
    if not torch.cuda.is_available():
        sys.exit(1)
    start_t = time.time()

    cudnn.benchmark = True
    cudnn.enabled = True
    logging.info("args = %s", args)

    # load model for imagenet
    # model_teacher = models.__dict__[args.teacher](pretrained=True)
    # model_teacher = nn.DataParallel(model_teacher).cuda()

    # load teacher model for cifar-10
    model_teacher = resnet.resnet56()
    # print(model_teacher)
    checkpoint = torch.load("../../models/resnet56-4bfd9763.th")
    state_dict = checkpoint['state_dict']
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove the `module.` prefix
        new_state_dict[name] = v
    model_teacher.load_state_dict(new_state_dict)
    model_teacher = nn.DataParallel(model_teacher).cuda()
    for p in model_teacher.parameters():
        p.requires_grad = False
    model_teacher.eval()

    model_student = reactnet(num_classes=CLASSES)
    # model_student = resnet.resnet32()
    logging.info('student:')
    logging.info(model_student)
    model_student = nn.DataParallel(model_student).cuda()

    criterion = nn.CrossEntropyLoss()
    criterion = criterion.cuda()
    criterion_smooth = CrossEntropyLabelSmooth(CLASSES, args.label_smooth)
    criterion_smooth = criterion_smooth.cuda()
    criterion_kd = KD_loss.DistributionLoss()

    all_parameters = model_student.parameters()
    weight_parameters = []
    for pname, p in model_student.named_parameters():
        if p.ndimension() == 4 or 'conv' in pname:
            weight_parameters.append(p)
    weight_parameters_id = list(map(id, weight_parameters))
    other_parameters = list(
        filter(lambda p: id(p) not in weight_parameters_id, all_parameters))

    optimizer = torch.optim.Adam(
        [{'params': other_parameters},
         {'params': weight_parameters, 'weight_decay': args.weight_decay}],
        lr=args.learning_rate,
    )
    # optimizer = torch.optim.SGD(model_student.parameters(), lr=args.learning_rate,
    #                             momentum=args.momentum,
    #                             weight_decay=args.weight_decay)

    # scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda step: (1.0 - step / args.epochs), last_epoch=-1)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=30, gamma=0.1)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[25, 45], gamma=0.1)

    start_epoch = 0
    best_top1_acc = 0

    checkpoint_tar = os.path.join(args.save, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_tar):
        logging.info('loading checkpoint {} ..........'.format(checkpoint_tar))
        checkpoint = torch.load(checkpoint_tar)
        start_epoch = checkpoint['epoch'] + 1
        best_top1_acc = checkpoint['best_top1_acc']
        model_student.load_state_dict(checkpoint['state_dict'], strict=False)
        logging.info("loaded checkpoint {} epoch = {}".format(
            checkpoint_tar, checkpoint['epoch']))

    # adjust the learning rate according to the checkpoint
    for epoch in range(start_epoch):
        scheduler.step()

    # load cifar-10 dataset
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])

    train_dataset = datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        train_dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=2, pin_memory=True)

    val_dataset = datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    val_loader = torch.utils.data.DataLoader(
        val_dataset, batch_size=args.batch_size, shuffle=False,
        num_workers=2, pin_memory=True)
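
# The train()/validate() helpers called in the loop below are defined elsewhere in
# the repo. They typically rely on a top-k accuracy helper along these lines (an
# assumption, shown for reference; the repo's own helper may differ slightly).
import torch


def accuracy_sketch(output, target, topk=(1,)):
    """Compute top-k classification accuracy (in percent) for the given logits."""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = target.size(0)
        # Indices of the maxk highest-scoring classes per sample, shape (maxk, batch).
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        pred = pred.t()
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            correct_k = correct[:k].reshape(-1).float().sum(0)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res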
    # CIFAR-10 class names (kept for reference; not used below)
    class_names = ('plane', 'car', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck')

    # train the model
    epoch = start_epoch
    while epoch < args.epochs:
        # train_obj, train_top1_acc, train_top5_acc = train(epoch, train_loader, model_student, model_teacher, criterion, optimizer, scheduler)
        # valid_obj, valid_top1_acc, valid_top5_acc = validate(epoch, val_loader, model_student, criterion, args)
        train_obj, train_top1_acc = train(
            epoch, train_loader, model_student, model_teacher, criterion,
            optimizer, scheduler)
        valid_obj, valid_top1_acc = validate(
            epoch, val_loader, model_student, criterion, args)

        is_best = False
        if valid_top1_acc > best_top1_acc:
            best_top1_acc = valid_top1_acc
            is_best = True

        save_checkpoint(
            {
                'epoch': epoch,
                'state_dict': model_student.state_dict(),
                'best_top1_acc': best_top1_acc,
                'optimizer': optimizer.state_dict(),
            }, is_best, args.save)

        epoch += 1

    training_time = (time.time() - start_t) / 3600  # seconds -> hours
    print('total training time = {} hours'.format(training_time))
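
# train() is defined elsewhere in the repo. A minimal sketch of the distillation
# step it presumably performs each iteration (an assumption based on the arguments
# passed above, not the repo's exact code): freeze the teacher, match the student's
# output distribution to the teacher's with criterion_kd, and update the student.
import torch


def distillation_step_sketch(images, model_student, model_teacher, criterion_kd,
                             optimizer):
    """One hypothetical KD training step for the binarized student."""
    images = images.cuda(non_blocking=True)
    with torch.no_grad():
        # The teacher is frozen and in eval mode; only its soft targets are needed.
        logits_teacher = model_teacher(images)
    logits_student = model_student(images)
    loss = criterion_kd(logits_student, logits_teacher)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()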