def get_model(args):
    """Build the WideResNet student for ``args.dataset`` and load weights from ``args.resume``.

    Args:
        args: namespace read for ``seed``, ``dataset`` (``"cifar10"`` or
            ``"cifar100"``), ``num_classes``, ``dense_dropout``, ``resume``
            (checkpoint file path) and ``device``.

    Returns:
        The student model with checkpoint weights loaded; moved to CUDA
        unless ``args.device == 'cpu'``.

    Raises:
        ValueError: if ``args.dataset`` has no known WideResNet configuration.
        SystemExit: with code 1 if ``args.resume`` is not an existing file.
    """
    import sys

    if args.seed is not None:
        set_seed(args)

    # (depth, widen_factor) of the WideResNet used for each supported dataset.
    wrn_configs = {"cifar10": (28, 2), "cifar100": (28, 8)}
    try:
        depth, widen_factor = wrn_configs[args.dataset]
    except KeyError:
        # Fail clearly here instead of hitting an unbound ``depth`` below.
        raise ValueError(
            f"unsupported dataset {args.dataset!r}; expected one of {sorted(wrn_configs)}"
        ) from None

    student_model = WideResNet(num_classes=args.num_classes,
                               depth=depth,
                               widen_factor=widen_factor,
                               dropout=0,
                               dense_dropout=args.dense_dropout)

    if not os.path.isfile(args.resume):
        print(f"=> no checkpoint found at '{args.resume}'")
        sys.exit(1)

    print(f"=> loading checkpoint '{args.resume}'")
    # Map to CPU so the checkpoint loads regardless of the device it was saved on.
    checkpoint = torch.load(args.resume, map_location='cpu')
    # Prefer averaged weights when the checkpoint has them; .get() tolerates
    # checkpoints that were saved without the 'avg_state_dict' key at all.
    state_dict = checkpoint.get('avg_state_dict')
    if state_dict is None:
        state_dict = checkpoint['student_state_dict']
    model_load_state_dict(student_model, state_dict)
    print(
        f"=> loaded checkpoint '{args.resume}' (step {checkpoint.get('step', 'unknown')})"
    )

    if args.device != 'cpu':
        student_model.cuda()
    return student_model
def main(args):
    """Train a WideResNet student on CIFAR-100 by distilling from a pretrained teacher.

    The student loss is cross-entropy plus two auxiliary terms:

    * ``aux_loss_1``: ``SM_Loss`` between the extra outputs returned by the
      teacher's and student's forward passes (index 2 only when
      ``args.aux_flag == 0``, otherwise indices 0-2 summed).
    * ``aux_loss_2``: mean squared error between the teacher's logits and the
      logits the teacher's head (``bn1`` -> ``relu`` -> ``avg_pool2d(8)`` ->
      ``fc``) produces from ``AdaIN(t_acts[2], s_acts[2])``.

    After every epoch the student is evaluated on the test split and a
    checkpoint is saved whenever top-1 accuracy improves.

    Args:
        args: namespace read for ``batch_size``, ``num_workers``,
            ``teacher_depth``, ``teacher_width_factor``, ``teacher_root_path``,
            ``student_depth``, ``student_width_factor``, ``lr``,
            ``total_epochs``, ``aux_flag`` and ``student_weight_path``.
    """
    # Per-channel (mean, std) normalization shared by the train and test pipelines.
    normalize = transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761))
    train_transform = transforms.Compose([
        transforms.Pad(4, padding_mode='reflect'),
        transforms.RandomRotation(15),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([transforms.ToTensor(), normalize])

    train_dataset = datasets.CIFAR100('./dataset', train=True, transform=train_transform, download=True)
    test_dataset = datasets.CIFAR100('./dataset', train=False, transform=test_transform, download=True)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True,
                              num_workers=args.num_workers)
    test_loader = DataLoader(test_dataset, batch_size=args.batch_size, shuffle=False,
                             num_workers=args.num_workers)

    # Pretrained teacher, used only as a frozen reference.
    Teacher = WideResNet(depth=args.teacher_depth, num_classes=100,
                         widen_factor=args.teacher_width_factor, drop_rate=0.3)
    teacher_weight_path = path.join(args.teacher_root_path, 'model_best.pth.tar')
    Teacher.load_state_dict(torch.load(teacher_weight_path, map_location='cpu')['state_dict'])
    Teacher.cuda()
    Teacher.eval()
    # The teacher is not in the optimizer, so its parameters must not build
    # autograd graphs or accumulate .grad (optimizer.zero_grad() would never
    # clear them). Gradients still flow *through* its bn1/fc layers to the
    # student in the AdaIN branch, because F_hat depends on student features.
    Teacher.requires_grad_(False)

    Student = WideResNet(depth=args.student_depth, num_classes=100,
                         widen_factor=args.student_width_factor, drop_rate=0.0)
    Student.cuda()
    cudnn.benchmark = True

    optimizer = torch.optim.SGD(Student.parameters(), lr=args.lr, momentum=0.9,
                                weight_decay=5e-4, nesterov=True)
    opt_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160],
                                                         gamma=2e-1)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    best_acc5 = 0
    for epoch in range(args.total_epochs):
        # evaluator() may leave the student in eval mode; restore training mode
        # (e.g. BatchNorm batch statistics) at the start of every epoch.
        Student.train()
        for images, labels in train_loader:
            images, labels = images.cuda(), labels.cuda()
            # Teacher outputs are fixed targets; no autograd graph is needed.
            with torch.no_grad():
                t_outs, *t_acts = Teacher(images)
            s_outs, *s_acts = Student(images)

            cls_loss = criterion(s_outs, labels)

            # Statistical-matching loss.
            if args.aux_flag == 0:
                aux_loss_1 = SM_Loss(t_acts[2], s_acts[2])  # group conv2
            else:
                aux_loss_1 = sum(SM_Loss(t_acts[i], s_acts[i]) for i in range(3))

            # AdaIN loss: feed AdaIN(teacher, student) features through the
            # teacher's head and match the teacher's own logits.
            F_hat = AdaIN(t_acts[2], s_acts[2])
            interim_out_q = Teacher.relu(Teacher.bn1(F_hat))
            interim_out_q = F.avg_pool2d(interim_out_q, 8)
            interim_out_q = interim_out_q.view(-1, Teacher.last_ch)
            q = Teacher.fc(interim_out_q)
            aux_loss_2 = torch.mean(torch.pow(t_outs - q, 2))

            total_loss = cls_loss + aux_loss_1 + aux_loss_2
            optimizer.zero_grad()
            total_loss.backward()
            optimizer.step()

        top1, top5 = evaluator(test_loader, Student)
        if top1 > best_acc:
            best_acc, best_acc5 = top1, top5
            state = {'epoch': epoch + 1,
                     'state_dict': Student.state_dict(),
                     'optimizer': optimizer.state_dict()}
            save_ckpt(state, is_best=True, root_path=args.student_weight_path)
        opt_scheduler.step()

    print("Best top 1 acc: {}".format(best_acc))
    print("Best top 5 acc: {}".format(best_acc5))